Skip to content
Snippets Groups Projects
Commit 65b2fd79 authored by Markus Holzer's avatar Markus Holzer
Browse files

Merge branch 'vectorization' into 'master'

fix some constant types for vectorization

See merge request !65
parents a1076133 f9327edc
No related branches found
No related tags found
1 merge request!65fix some constant types for vectorization
Pipeline #30420 passed
...@@ -142,7 +142,7 @@ class PdfsToCentralMomentsByMatrix(AbstractMomentTransform): ...@@ -142,7 +142,7 @@ class PdfsToCentralMomentsByMatrix(AbstractMomentTransform):
central_moments = self.forward_matrix * f_vec central_moments = self.forward_matrix * f_vec
main_assignments = [Assignment(sq_sym(moment_symbol_base, e), eq) main_assignments = [Assignment(sq_sym(moment_symbol_base, e), eq)
for e, eq in zip(self.moment_exponents, central_moments)] for e, eq in zip(self.moment_exponents, central_moments)]
symbol_gen = SymbolGen(subexpression_base) symbol_gen = SymbolGen(subexpression_base, dtype=float)
ac = AssignmentCollection(main_assignments, subexpression_symbol_generator=symbol_gen) ac = AssignmentCollection(main_assignments, subexpression_symbol_generator=symbol_gen)
if simplification: if simplification:
...@@ -157,7 +157,7 @@ class PdfsToCentralMomentsByMatrix(AbstractMomentTransform): ...@@ -157,7 +157,7 @@ class PdfsToCentralMomentsByMatrix(AbstractMomentTransform):
moment_vec = sp.Matrix(moments) moment_vec = sp.Matrix(moments)
pdfs_from_moments = self.backward_matrix * moment_vec pdfs_from_moments = self.backward_matrix * moment_vec
main_assignments = [Assignment(f, eq) for f, eq in zip(pdf_symbols, pdfs_from_moments)] main_assignments = [Assignment(f, eq) for f, eq in zip(pdf_symbols, pdfs_from_moments)]
symbol_gen = SymbolGen(subexpression_base) symbol_gen = SymbolGen(subexpression_base, dtype=float)
ac = AssignmentCollection(main_assignments, subexpression_symbol_generator=symbol_gen) ac = AssignmentCollection(main_assignments, subexpression_symbol_generator=symbol_gen)
if simplification: if simplification:
...@@ -228,7 +228,7 @@ class FastCentralMomentTransform(AbstractMomentTransform): ...@@ -228,7 +228,7 @@ class FastCentralMomentTransform(AbstractMomentTransform):
collect_partial_sums(e) collect_partial_sums(e)
subexpressions = [Assignment(lhs, rhs) for lhs, rhs in subexpressions_dict.items()] subexpressions = [Assignment(lhs, rhs) for lhs, rhs in subexpressions_dict.items()]
symbol_gen = SymbolGen(subexpression_base) symbol_gen = SymbolGen(subexpression_base, dtype=float)
ac = AssignmentCollection(main_assignments, subexpressions=subexpressions, ac = AssignmentCollection(main_assignments, subexpressions=subexpressions,
subexpression_symbol_generator=symbol_gen) subexpression_symbol_generator=symbol_gen)
if simplification: if simplification:
...@@ -244,7 +244,7 @@ class FastCentralMomentTransform(AbstractMomentTransform): ...@@ -244,7 +244,7 @@ class FastCentralMomentTransform(AbstractMomentTransform):
pdf_symbols, moment_symbol_base=POST_COLLISION_CENTRAL_MOMENT, simplification=False) pdf_symbols, moment_symbol_base=POST_COLLISION_CENTRAL_MOMENT, simplification=False)
raw_equations = raw_equations.new_without_subexpressions() raw_equations = raw_equations.new_without_subexpressions()
symbol_gen = SymbolGen(subexpression_base) symbol_gen = SymbolGen(subexpression_base, dtype=float)
ac = self._split_backward_equations(raw_equations, symbol_gen) ac = self._split_backward_equations(raw_equations, symbol_gen)
if simplification: if simplification:
...@@ -441,7 +441,7 @@ class PdfsToRawMomentsTransform(AbstractMomentTransform): ...@@ -441,7 +441,7 @@ class PdfsToRawMomentsTransform(AbstractMomentTransform):
collect_partial_sums(e) collect_partial_sums(e)
subexpressions += [Assignment(lhs, rhs) for lhs, rhs in partial_sums_dict.items()] subexpressions += [Assignment(lhs, rhs) for lhs, rhs in partial_sums_dict.items()]
symbol_gen = SymbolGen(subexpression_base) symbol_gen = SymbolGen(subexpression_base, dtype=float)
ac = AssignmentCollection(main_assignments, subexpressions=subexpressions, ac = AssignmentCollection(main_assignments, subexpressions=subexpressions,
subexpression_symbol_generator=symbol_gen) subexpression_symbol_generator=symbol_gen)
ac.add_simplification_hint('cq_symbols_to_moments', self.get_cq_to_moment_symbols_dict(moment_symbol_base)) ac.add_simplification_hint('cq_symbols_to_moments', self.get_cq_to_moment_symbols_dict(moment_symbol_base))
...@@ -457,7 +457,7 @@ class PdfsToRawMomentsTransform(AbstractMomentTransform): ...@@ -457,7 +457,7 @@ class PdfsToRawMomentsTransform(AbstractMomentTransform):
post_collision_moments = [sq_sym(moment_symbol_base, e) for e in self.moment_exponents] post_collision_moments = [sq_sym(moment_symbol_base, e) for e in self.moment_exponents]
rm_to_f_vec = self.inv_moment_matrix * sp.Matrix(post_collision_moments) rm_to_f_vec = self.inv_moment_matrix * sp.Matrix(post_collision_moments)
main_assignments = [Assignment(f, eq) for f, eq in zip(pdf_symbols, rm_to_f_vec)] main_assignments = [Assignment(f, eq) for f, eq in zip(pdf_symbols, rm_to_f_vec)]
symbol_gen = SymbolGen(subexpression_base) symbol_gen = SymbolGen(subexpression_base, dtype=float)
ac = AssignmentCollection(main_assignments, subexpression_symbol_generator=symbol_gen) ac = AssignmentCollection(main_assignments, subexpression_symbol_generator=symbol_gen)
ac.add_simplification_hint('stencil', self.stencil) ac.add_simplification_hint('stencil', self.stencil)
...@@ -547,7 +547,7 @@ class PdfsToCentralMomentsByShiftMatrix(AbstractMomentTransform): ...@@ -547,7 +547,7 @@ class PdfsToCentralMomentsByShiftMatrix(AbstractMomentTransform):
rm_to_cm_dict = self._undo_remaining_cq_subexpressions(rm_to_cm_dict, cq_subs) rm_to_cm_dict = self._undo_remaining_cq_subexpressions(rm_to_cm_dict, cq_subs)
subexpressions = rm_ac.all_assignments subexpressions = rm_ac.all_assignments
symbol_gen = SymbolGen(subexpression_base) symbol_gen = SymbolGen(subexpression_base, dtype=float)
ac = AssignmentCollection(rm_to_cm_dict, subexpressions=subexpressions, ac = AssignmentCollection(rm_to_cm_dict, subexpressions=subexpressions,
subexpression_symbol_generator=symbol_gen) subexpression_symbol_generator=symbol_gen)
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Please register or to comment