diff --git a/ast/arrays.py b/ast/arrays.py index 56f2f03a1b1081fc0f594e87bc689d5c7011d617..198269318437a734be400b81841632f31962aeeb 100644 --- a/ast/arrays.py +++ b/ast/arrays.py @@ -199,6 +199,10 @@ class ArrayAccess: def transform(self, fn): self.array = self.array.transform(fn) self.indexes = [i.transform(fn) for i in self.indexes] + + if self.index is not None: + self.index = self.index.transform(fn) + return fn(self) diff --git a/ast/expr.py b/ast/expr.py index 371ca0e909305fdd7200783cd4bec36ab0905380..93db1df3775f001ccd9315f8d5fd6e02aeaa0730 100644 --- a/ast/expr.py +++ b/ast/expr.py @@ -227,5 +227,6 @@ class BinOp: def transform(self, fn): self.lhs = self.lhs.transform(fn) self.rhs = self.rhs.transform(fn) + self.bin_op_vector_index_mapping = {i: e.transform(fn) for i, e in self.bin_op_vector_index_mapping.items()} return fn(self) diff --git a/sim/particle_simulation.py b/sim/particle_simulation.py index ddabe68d080cfcc07b88b6eb4046e12f11465953..766b7f2ba06725e573b4d56f554197343b9322db 100644 --- a/sim/particle_simulation.py +++ b/sim/particle_simulation.py @@ -194,7 +194,7 @@ class ParticleSimulation: self.global_scope = program Block.set_block_levels(program) Transform.apply(program, Transform.flatten) - #Transform.apply(program, Transform.simplify) + Transform.apply(program, Transform.simplify) #Transform.apply(program, Transform.reuse_index_expressions) #Transform.apply(program, Transform.reuse_expr_expressions) #Transform.apply(program, Transform.reuse_array_access_expressions)