From 9d01978336436d2415d63df67816e1f6b9147b56 Mon Sep 17 00:00:00 2001 From: Stephan Seitz <stephan.seitz@fau.de> Date: Sun, 18 Aug 2019 19:52:57 +0200 Subject: [PATCH] Add assumption functions to cast_func --- pystencils/data_types.py | 31 +++++++++++++++++++++++++++++++ pystencils/math_optimizations.py | 3 +-- 2 files changed, 32 insertions(+), 2 deletions(-) diff --git a/pystencils/data_types.py b/pystencils/data_types.py index 3b2fe35d..3880edbc 100644 --- a/pystencils/data_types.py +++ b/pystencils/data_types.py @@ -84,6 +84,37 @@ class cast_func(sp.Function): def dtype(self): return self.args[1] + @property + def is_integer(self): + if hasattr(self.dtype, 'numpy_dtype'): + return np.issubdtype(self.dtype.numpy_dtype, np.integer) or super().is_integer + else: + return super().is_integer + + @property + def is_negative(self): + if hasattr(self.dtype, 'numpy_dtype'): + if np.issubdtype(self.dtype.numpy_dtype, np.unsignedinteger): + return False + + return super().is_negative + + @property + def is_nonnegative(self): + if self.is_negative is False: + return True + else: + return super().is_nonnegative + + @property + def is_real(self): + if hasattr(self.dtype, 'numpy_dtype'): + return np.issubdtype(self.dtype.numpy_dtype, np.integer) or \ + np.issubdtype(self.dtype.numpy_dtype, np.floating) or \ + super().is_real + else: + return super().is_real + # noinspection PyPep8Naming class boolean_cast_func(cast_func, Boolean): diff --git a/pystencils/math_optimizations.py b/pystencils/math_optimizations.py index aa866570..ad011478 100644 --- a/pystencils/math_optimizations.py +++ b/pystencils/math_optimizations.py @@ -9,7 +9,6 @@ import itertools from pystencils import Assignment from pystencils.astnodes import SympyAssignment -from pystencils.integer_functions import IntegerFunctionTwoArgsMixIn try: from sympy.codegen.rewriting import optims_c99, optimize @@ -38,7 +37,7 @@ def optimize_assignments(assignments, optimizations): if HAS_REWRITING: assignments = [Assignment(a.lhs, optimize(a.rhs, optimizations)) - if hasattr(a, 'lhs') and not a.rhs.atoms(IntegerFunctionTwoArgsMixIn) + if hasattr(a, 'lhs') else a for a in assignments] assignments_nodes = [a.atoms(SympyAssignment) for a in assignments] for a in itertools.chain.from_iterable(assignments_nodes): -- GitLab