Skip to content
Snippets Groups Projects
Commit 9d019783 authored by Stephan Seitz's avatar Stephan Seitz
Browse files

Add assumption functions to cast_func

parent e07600cf
Branches
No related tags found
No related merge requests found
...@@ -84,6 +84,37 @@ class cast_func(sp.Function): ...@@ -84,6 +84,37 @@ class cast_func(sp.Function):
def dtype(self): def dtype(self):
return self.args[1] return self.args[1]
@property
def is_integer(self):
if hasattr(self.dtype, 'numpy_dtype'):
return np.issubdtype(self.dtype.numpy_dtype, np.integer) or super().is_integer
else:
return super().is_integer
@property
def is_negative(self):
if hasattr(self.dtype, 'numpy_dtype'):
if np.issubdtype(self.dtype.numpy_dtype, np.unsignedinteger):
return False
return super().is_negative
@property
def is_nonnegative(self):
if self.is_negative is False:
return True
else:
return super().is_nonnegative
@property
def is_real(self):
if hasattr(self.dtype, 'numpy_dtype'):
return np.issubdtype(self.dtype.numpy_dtype, np.integer) or \
np.issubdtype(self.dtype.numpy_dtype, np.floating) or \
super().is_real
else:
return super().is_real
# noinspection PyPep8Naming # noinspection PyPep8Naming
class boolean_cast_func(cast_func, Boolean): class boolean_cast_func(cast_func, Boolean):
......
...@@ -9,7 +9,6 @@ import itertools ...@@ -9,7 +9,6 @@ import itertools
from pystencils import Assignment from pystencils import Assignment
from pystencils.astnodes import SympyAssignment from pystencils.astnodes import SympyAssignment
from pystencils.integer_functions import IntegerFunctionTwoArgsMixIn
try: try:
from sympy.codegen.rewriting import optims_c99, optimize from sympy.codegen.rewriting import optims_c99, optimize
...@@ -38,7 +37,7 @@ def optimize_assignments(assignments, optimizations): ...@@ -38,7 +37,7 @@ def optimize_assignments(assignments, optimizations):
if HAS_REWRITING: if HAS_REWRITING:
assignments = [Assignment(a.lhs, optimize(a.rhs, optimizations)) assignments = [Assignment(a.lhs, optimize(a.rhs, optimizations))
if hasattr(a, 'lhs') and not a.rhs.atoms(IntegerFunctionTwoArgsMixIn) if hasattr(a, 'lhs')
else a for a in assignments] else a for a in assignments]
assignments_nodes = [a.atoms(SympyAssignment) for a in assignments] assignments_nodes = [a.atoms(SympyAssignment) for a in assignments]
for a in itertools.chain.from_iterable(assignments_nodes): for a in itertools.chain.from_iterable(assignments_nodes):
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Please register or to comment