diff --git a/pystencils/rng.py b/pystencils/rng.py index b5aeb90a24ce4da576cba748a491c0e11d722e8f..f567e0c1b77b6e07d9107825364b9227cf17b26a 100644 --- a/pystencils/rng.py +++ b/pystencils/rng.py @@ -19,7 +19,7 @@ def _get_rng_template(name, data_type, num_vars): return template -def _get_rng_code(template, dialect, vector_instruction_set, args, dim, result_symbols): +def _get_rng_code(template, dialect, vector_instruction_set, args, result_symbols): if dialect == 'cuda' or (dialect == 'c' and vector_instruction_set is None): return template.format(parameters=', '.join(str(a) for a in args), result_symbols=result_symbols) @@ -44,14 +44,14 @@ class RNGBase(CustomCodeNode): if dim < 3: coordinates.append(0) + self._args = sp.sympify([time_step, *coordinates, *keys]) self.result_symbols = tuple(TypedSymbol(f'random_{self.id}_{i}', self._data_type) for i in range(self._num_vars)) - symbols_read = set.union(*[sp.sympify(s).atoms(sp.Symbol) for s in list(keys) + coordinates + [time_step]]) + symbols_read = set.union(*[s.atoms(sp.Symbol) for s in self.args]) super().__init__("", symbols_read=symbols_read, symbols_defined=self.result_symbols) self.headers = [f'"{self._name}_rand.h"'] - self._args = sp.sympify([time_step, *coordinates, *keys]) - self._dim = dim + RNGBase.id += 1 @property @@ -60,8 +60,7 @@ class RNGBase(CustomCodeNode): def get_code(self, dialect, vector_instruction_set): template = _get_rng_template(self._name, self._data_type, self._num_vars) - return _get_rng_code(template, dialect, vector_instruction_set, - self.args, self._dim, self.result_symbols) + return _get_rng_code(template, dialect, vector_instruction_set, self.args, self.result_symbols) def __repr__(self): return (", ".join(['{}'] * self._num_vars) + " \\leftarrow {}RNG").format(*self.result_symbols,