From 9d01978336436d2415d63df67816e1f6b9147b56 Mon Sep 17 00:00:00 2001
From: Stephan Seitz <stephan.seitz@fau.de>
Date: Sun, 18 Aug 2019 19:52:57 +0200
Subject: [PATCH] Add assumption functions to cast_func

---
 pystencils/data_types.py         | 31 +++++++++++++++++++++++++++++++
 pystencils/math_optimizations.py |  3 +--
 2 files changed, 32 insertions(+), 2 deletions(-)

diff --git a/pystencils/data_types.py b/pystencils/data_types.py
index 3b2fe35d..3880edbc 100644
--- a/pystencils/data_types.py
+++ b/pystencils/data_types.py
@@ -84,6 +84,37 @@ class cast_func(sp.Function):
     def dtype(self):
         return self.args[1]
 
+    @property
+    def is_integer(self):
+        if hasattr(self.dtype, 'numpy_dtype'):
+            return np.issubdtype(self.dtype.numpy_dtype, np.integer) or super().is_integer
+        else:
+            return super().is_integer
+
+    @property
+    def is_negative(self):
+        if hasattr(self.dtype, 'numpy_dtype'):
+            if np.issubdtype(self.dtype.numpy_dtype, np.unsignedinteger):
+                return False
+
+        return super().is_negative
+
+    @property
+    def is_nonnegative(self):
+        if self.is_negative is False:
+            return True
+        else:
+            return super().is_nonnegative
+
+    @property
+    def is_real(self):
+        if hasattr(self.dtype, 'numpy_dtype'):
+            return np.issubdtype(self.dtype.numpy_dtype, np.integer) or \
+                np.issubdtype(self.dtype.numpy_dtype, np.floating) or \
+                super().is_real
+        else:
+            return super().is_real
+
 
 # noinspection PyPep8Naming
 class boolean_cast_func(cast_func, Boolean):
diff --git a/pystencils/math_optimizations.py b/pystencils/math_optimizations.py
index aa866570..ad011478 100644
--- a/pystencils/math_optimizations.py
+++ b/pystencils/math_optimizations.py
@@ -9,7 +9,6 @@ import itertools
 
 from pystencils import Assignment
 from pystencils.astnodes import SympyAssignment
-from pystencils.integer_functions import IntegerFunctionTwoArgsMixIn
 
 try:
     from sympy.codegen.rewriting import optims_c99, optimize
@@ -38,7 +37,7 @@ def optimize_assignments(assignments, optimizations):
 
     if HAS_REWRITING:
         assignments = [Assignment(a.lhs, optimize(a.rhs, optimizations))
-                       if hasattr(a, 'lhs') and not a.rhs.atoms(IntegerFunctionTwoArgsMixIn)
+                       if hasattr(a, 'lhs')
                        else a for a in assignments]
         assignments_nodes = [a.atoms(SympyAssignment) for a in assignments]
         for a in itertools.chain.from_iterable(assignments_nodes):
-- 
GitLab