Skip to content
Snippets Groups Projects

Make the RNG node behave more like a regular node

Merged Michael Kuron requested to merge rng into master
All threads resolved!
Viewing commit b1adf2ae
Show latest version
1 file
+ 5
6
Preferences
Compare changes
+ 5
6
@@ -19,7 +19,7 @@ def _get_rng_template(name, data_type, num_vars):
return template
def _get_rng_code(template, dialect, vector_instruction_set, args, dim, result_symbols):
def _get_rng_code(template, dialect, vector_instruction_set, args, result_symbols):
if dialect == 'cuda' or (dialect == 'c' and vector_instruction_set is None):
return template.format(parameters=', '.join(str(a) for a in args),
result_symbols=result_symbols)
@@ -44,14 +44,14 @@ class RNGBase(CustomCodeNode):
if dim < 3:
coordinates.append(0)
self._args = sp.sympify([time_step, *coordinates, *keys])
self.result_symbols = tuple(TypedSymbol(f'random_{self.id}_{i}', self._data_type)
for i in range(self._num_vars))
symbols_read = set.union(*[sp.sympify(s).atoms(sp.Symbol) for s in list(keys) + coordinates + [time_step]])
symbols_read = set.union(*[s.atoms(sp.Symbol) for s in self.args])
super().__init__("", symbols_read=symbols_read, symbols_defined=self.result_symbols)
self.headers = [f'"{self._name}_rand.h"']
self._args = sp.sympify([time_step, *coordinates, *keys])
self._dim = dim
RNGBase.id += 1
@property
@@ -60,8 +60,7 @@ class RNGBase(CustomCodeNode):
def get_code(self, dialect, vector_instruction_set):
template = _get_rng_template(self._name, self._data_type, self._num_vars)
return _get_rng_code(template, dialect, vector_instruction_set,
self.args, self._dim, self.result_symbols)
return _get_rng_code(template, dialect, vector_instruction_set, self.args, self.result_symbols)
def __repr__(self):
return (", ".join(['{}'] * self._num_vars) + " \\leftarrow {}RNG").format(*self.result_symbols,