Skip to content
Snippets Groups Projects
Select Git revision
  • 2d925329a8a53ae7a77a402a379dc7261a60a793
  • master default protected
  • v2.0-dev protected
  • zikeliml/Task-96-dotExporterForAST
  • zikeliml/124-rework-tutorials
  • fma
  • fhennig/v2.0-deprecations
  • holzer-master-patch-46757
  • 66-absolute-access-is-probably-not-copied-correctly-after-_eval_subs
  • gpu_bufferfield_fix
  • hyteg
  • vectorization_sqrt_fix
  • target_dh_refactoring
  • const_fix
  • improved_comm
  • gpu_liveness_opts
  • release/1.3.7 protected
  • release/1.3.6 protected
  • release/2.0.dev0 protected
  • release/1.3.5 protected
  • release/1.3.4 protected
  • release/1.3.3 protected
  • release/1.3.2 protected
  • release/1.3.1 protected
  • release/1.3 protected
  • release/1.2 protected
  • release/1.1.1 protected
  • release/1.1 protected
  • release/1.0.1 protected
  • release/1.0 protected
  • release/0.4.4 protected
  • last/Kerncraft
  • last/OpenCL
  • last/LLVM
  • release/0.4.3 protected
  • release/0.4.2 protected
36 results

random.py

Blame
  • random.py 4.55 KiB
    import sympy as sp
    import numpy as np
    from pystencils import TypedSymbol
    from pystencils.astnodes import LoopOverCoordinate
    from pystencils.backends.cbackend import CustomCodeNode
    
    # C/CUDA code template emitted by PhiloxTwoDoubles.get_code. Filled with str.format:
    #   parameters     -- comma-separated argument list for philox_double2
    #   result_symbols -- two symbols exposing ``.dtype`` and ``.name``
    # It declares both result variables and then passes them to philox_double2 after the
    # parameters (presumably as output references -- see philox_rand.h to confirm).
    philox_two_doubles_call = """
    {result_symbols[0].dtype} {result_symbols[0].name};
    {result_symbols[1].dtype} {result_symbols[1].name};
    philox_double2({parameters}, {result_symbols[0].name}, {result_symbols[1].name});
    """
    
    # C/CUDA code template emitted by PhiloxFourFloats.get_code; same placeholders as above,
    # but with four result symbols filled by a single philox_float4 call.
    philox_four_floats_call = """
    {result_symbols[0].dtype} {result_symbols[0].name};
    {result_symbols[1].dtype} {result_symbols[1].name};
    {result_symbols[2].dtype} {result_symbols[2].name};
    {result_symbols[3].dtype} {result_symbols[3].name};
    philox_float4({parameters}, 
                  {result_symbols[0].name}, {result_symbols[1].name}, {result_symbols[2].name}, {result_symbols[3].name});
    
    """
    
    
    class PhiloxTwoDoubles(CustomCodeNode):
        """Code node producing two ``double`` random numbers per iteration via Philox.

        The generated code declares two fresh result variables and fills them with a call to
        ``philox_double2`` (header ``philox_rand.h``). The call receives four counter words and
        then two key words:
        - counter words: ``time_step`` followed by the loop counters of the first ``dim``
          coordinates, zero-padded;
        - key words: ``keys``, zero-padded.

        Args:
            dim: number of loop coordinates whose counters feed the generator (at most 3).
            time_step: first counter word; by default a ``uint32`` symbol named ``time_step``.
            keys: up to two key words, each an integer or a sympy symbol.
        """

        def __init__(self, dim, time_step=TypedSymbol("time_step", np.uint32), keys=(0, 0)):
            # sp.Dummy() supplies an auto-generated name for each result variable.
            self.result_symbols = tuple(TypedSymbol(sp.Dummy().name, np.float64) for _ in range(2))

            symbols_read = [s for s in keys if isinstance(s, sp.Symbol)]
            super().__init__("", symbols_read=symbols_read, symbols_defined=self.result_symbols)
            self._time_step = time_step
            self.headers = ['"philox_rand.h"']
            self.keys = list(keys)
            self._args = (time_step, *sp.sympify(keys))
            self._dim = dim

        @property
        def args(self):
            """The time step followed by the sympified keys."""
            return self._args

        @property
        def undefined_symbols(self):
            """Symbols read by this node: symbolic args plus the loop counters of ``dim`` coordinates."""
            result = {a for a in self.args if isinstance(a, sp.Symbol)}
            loop_counters = [LoopOverCoordinate.get_loop_counter_symbol(i)
                             for i in range(self._dim)]
            result.update(loop_counters)
            return result

        def get_code(self, dialect, vector_instruction_set):
            """Return the code snippet that declares and fills the two result variables.

            Args:
                dialect: target dialect; only ``'cuda'`` and ``'c'`` are supported.
                vector_instruction_set: must be ``None`` for the ``'c'`` dialect.

            Raises:
                ValueError: if ``dim`` exceeds 3 or more than two keys were given. The call has
                    room for only four counter words (one used by the time step) and two key words.
                NotImplementedError: for any other dialect, or vectorized C.
            """
            if self._dim > 3:
                raise ValueError("Philox supports at most 3 loop coordinates, got {}".format(self._dim))
            if len(self.keys) > 2:
                raise ValueError("Philox supports at most 2 keys, got {}".format(len(self.keys)))

            # Pad the counter part to exactly four words *before* appending the keys, so the keys
            # always occupy the two key slots. Padding only at the end would shift the keys into
            # the counter slots whenever dim < 3.
            counters = [self._time_step] + [LoopOverCoordinate.get_loop_counter_symbol(i)
                                            for i in range(self._dim)]
            counters += [0] * (4 - len(counters))
            keys = self.keys + [0] * (2 - len(self.keys))
            parameters = counters + keys

            if dialect == 'cuda' or (dialect == 'c' and vector_instruction_set is None):
                return philox_two_doubles_call.format(parameters=', '.join(str(p) for p in parameters),
                                                      result_symbols=self.result_symbols)
            else:
                raise NotImplementedError("Not yet implemented for this backend")

        def __repr__(self):
            """Show the two result symbols this node defines."""
            return "{}, {} <- PhiloxRNG".format(*self.result_symbols)
    
    
    class PhiloxFourFloats(CustomCodeNode):
        """Code node producing four ``float`` random numbers per iteration via Philox.

        The generated code declares four fresh result variables and fills them with a call to
        ``philox_float4`` (header ``philox_rand.h``). The call receives four counter words and
        then two key words:
        - counter words: ``time_step`` followed by the loop counters of the first ``dim``
          coordinates, zero-padded;
        - key words: ``keys``, zero-padded.

        Args:
            dim: number of loop coordinates whose counters feed the generator (at most 3).
            time_step: first counter word; by default a ``uint32`` symbol named ``time_step``.
            keys: up to two key words, each an integer or a sympy symbol.
        """

        def __init__(self, dim, time_step=TypedSymbol("time_step", np.uint32), keys=(0, 0)):
            # sp.Dummy() supplies an auto-generated name for each result variable.
            self.result_symbols = tuple(TypedSymbol(sp.Dummy().name, np.float32) for _ in range(4))
            symbols_read = [s for s in keys if isinstance(s, sp.Symbol)]

            super().__init__("", symbols_read=symbols_read, symbols_defined=self.result_symbols)
            self._time_step = time_step
            self.headers = ['"philox_rand.h"']
            self.keys = list(keys)
            self._args = (time_step, *sp.sympify(keys))
            self._dim = dim

        @property
        def args(self):
            """The time step followed by the sympified keys."""
            return self._args

        @property
        def undefined_symbols(self):
            """Symbols read by this node: symbolic args plus the loop counters of ``dim`` coordinates."""
            result = {a for a in self.args if isinstance(a, sp.Symbol)}
            loop_counters = [LoopOverCoordinate.get_loop_counter_symbol(i)
                             for i in range(self._dim)]
            result.update(loop_counters)
            return result

        def get_code(self, dialect, vector_instruction_set):
            """Return the code snippet that declares and fills the four result variables.

            Args:
                dialect: target dialect; only ``'cuda'`` and ``'c'`` are supported.
                vector_instruction_set: must be ``None`` for the ``'c'`` dialect.

            Raises:
                ValueError: if ``dim`` exceeds 3 or more than two keys were given. The call has
                    room for only four counter words (one used by the time step) and two key words.
                NotImplementedError: for any other dialect, or vectorized C.
            """
            if self._dim > 3:
                raise ValueError("Philox supports at most 3 loop coordinates, got {}".format(self._dim))
            if len(self.keys) > 2:
                raise ValueError("Philox supports at most 2 keys, got {}".format(len(self.keys)))

            # Pad the counter part to exactly four words *before* appending the keys, so the keys
            # always occupy the two key slots. Padding only at the end would shift the keys into
            # the counter slots whenever dim < 3.
            counters = [self._time_step] + [LoopOverCoordinate.get_loop_counter_symbol(i)
                                            for i in range(self._dim)]
            counters += [0] * (4 - len(counters))
            keys = self.keys + [0] * (2 - len(self.keys))
            parameters = counters + keys

            if dialect == 'cuda' or (dialect == 'c' and vector_instruction_set is None):
                return philox_four_floats_call.format(parameters=', '.join(str(p) for p in parameters),
                                                      result_symbols=self.result_symbols)
            else:
                raise NotImplementedError("Not yet implemented for this backend")

        def __repr__(self):
            """Show the four result symbols this node defines."""
            return "{}, {}, {}, {} <- PhiloxRNG".format(*self.result_symbols)
    
    
    def random_symbol(assignment_list, rng_node=PhiloxTwoDoubles, *args, **kwargs):
        """Yield an endless stream of random-number symbols, building generator nodes on demand.

        Whenever the symbols of the current node run out, a new node is created with
        ``rng_node(*args, **kwargs)``. A node is inserted at index 0 of ``assignment_list`` just
        before its first result symbol is yielded. Nodes are therefore added only once one of
        their symbols is actually requested.

        Note that ``rng_node`` precedes ``*args``, so extra positional arguments after
        ``assignment_list`` bind to ``rng_node`` first.

        Args:
            assignment_list: sequence supporting ``insert``; receives the created nodes.
            rng_node: callable returning a node with a ``result_symbols`` attribute.

        Yields:
            The result symbols of successive nodes, in order.
        """
        while True:
            generator_node = rng_node(*args, **kwargs)
            for position, result in enumerate(generator_node.result_symbols):
                if position == 0:
                    # Register the node lazily, right before its first symbol is handed out.
                    assignment_list.insert(0, generator_node)
                yield result