Skip to content
Snippets Groups Projects

Compare revisions

Changes are shown as if the source revision was being merged into the target revision. Learn more about comparing revisions.

Source

Select target project
No results found

Target

Select target project
No results found
Show changes
Commits on Source (1)
......@@ -506,7 +506,6 @@ class SympyAssignment(Node):
super(SympyAssignment, self).__init__(parent=None)
self._lhs_symbol = lhs_symbol
self.rhs = sp.sympify(rhs_expr)
self._is_const = is_const
self._is_declaration = self.__is_declaration()
def __is_declaration(self):
......@@ -563,10 +562,6 @@ class SympyAssignment(Node):
def is_declaration(self):
return self._is_declaration
@property
def is_const(self):
return self._is_const
def replace(self, child, replacement):
if child == self.lhs:
replacement.parent = self
......
......@@ -225,13 +225,9 @@ class CBackend:
def _print_SympyAssignment(self, node):
if node.is_declaration:
if node.is_const:
prefix = 'const '
else:
prefix = ''
data_type = prefix + self._print(node.lhs.dtype).replace(' const', '') + " "
return "%s%s = %s;" % (data_type, self.sympy_printer.doprint(node.lhs),
self.sympy_printer.doprint(node.rhs))
data_type = self._print(node.lhs.dtype)
return "%s %s = %s;" % (data_type, self.sympy_printer.doprint(node.lhs),
self.sympy_printer.doprint(node.rhs))
else:
lhs_type = get_type_of_expression(node.lhs)
if type(lhs_type) is VectorType and isinstance(node.lhs, cast_func):
......
......@@ -63,11 +63,11 @@ def vectorize(kernel_ast: ast.KernelFunction, instruction_set: str = 'avx',
if assume_inner_stride_one:
replace_inner_stride_with_one(kernel_ast)
field_float_dtypes = set(f.dtype for f in all_fields if f.dtype.is_float())
field_float_dtypes = set(f.dtype.numpy_dtype for f in all_fields if f.dtype.is_float())
if len(field_float_dtypes) != 1:
raise NotImplementedError("Cannot vectorize kernels that contain accesses "
"to differently typed floating point fields")
float_size = field_float_dtypes.pop().numpy_dtype.itemsize
float_size = field_float_dtypes.pop().itemsize
assert float_size in (8, 4)
vector_is = get_vector_instruction_set('double' if float_size == 8 else 'float',
instruction_set=instruction_set)
......@@ -148,7 +148,7 @@ def insert_vector_casts(ast_node):
return expr
else:
target_type = collate_types(arg_types)
casted_args = [cast_func(a, target_type) if t != target_type else a
casted_args = [cast_func(a, target_type) if not t.equal_ignoring_const(target_type) else a
for a, t in zip(new_args, arg_types)]
return expr.func(*casted_args)
elif expr.func is sp.Pow:
......@@ -167,11 +167,11 @@ def insert_vector_casts(ast_node):
if type(condition_target_type) is not VectorType and type(result_target_type) is VectorType:
condition_target_type = VectorType(condition_target_type, width=result_target_type.width)
casted_results = [cast_func(a, result_target_type) if t != result_target_type else a
casted_results = [cast_func(a, result_target_type) if not t.equal_ignoring_const(result_target_type) else a
for a, t in zip(new_results, types_of_results)]
casted_conditions = [cast_func(a, condition_target_type)
if t != condition_target_type and a is not True else a
if not t.equal_ignoring_const(condition_target_type) and a is not True else a
for a, t in zip(new_conditions, types_of_conditions)]
return sp.Piecewise(*[(r, c) for r, c in zip(casted_results, casted_conditions)])
......
......@@ -453,7 +453,7 @@ def collate_types(types, forbid_collation_to_float=False):
types = tuple(t for t in types if t.is_float())
# use numpy collation -> create type from numpy type -> and, put vector type around if necessary
result_numpy_type = np.result_type(*(t.numpy_dtype for t in types))
result = BasicType(result_numpy_type)
result = BasicType(result_numpy_type, const=any(t.const for t in types))
if vector_type:
result = VectorType(result, vector_type[0].width)
return result
......@@ -618,6 +618,12 @@ class BasicType(Type):
else:
return (self.numpy_dtype, self.const) == (other.numpy_dtype, other.const)
def equal_ignoring_const(self, other):
if not isinstance(other, BasicType):
return False
else:
return self.numpy_dtype == other.numpy_dtype
def __hash__(self):
return hash(str(self))
......@@ -643,17 +649,23 @@ class VectorType(Type):
else:
return (self.base_type, self.width) == (other.base_type, other.width)
def equal_ignoring_const(self, other):
if not isinstance(other, VectorType):
return False
else:
return self.base_type.equal_ignoring_const(other.base_type)
def __str__(self):
if self.instruction_set is None:
return "%s[%d]" % (self.base_type, self.width)
else:
if self.base_type == create_type("int64"):
if self.base_type.numpy_dtype == np.int64:
return self.instruction_set['int']
elif self.base_type == create_type("float64"):
elif self.base_type.numpy_dtype == np.float64:
return self.instruction_set['double']
elif self.base_type == create_type("float32"):
elif self.base_type.numpy_dtype == np.float32:
return self.instruction_set['float']
elif self.base_type == create_type("bool"):
elif self.base_type.numpy_dtype == np.bool:
return self.instruction_set['bool']
else:
raise NotImplementedError()
......@@ -692,6 +704,12 @@ class PointerType(Type):
else:
return (self.base_type, self.const, self.restrict) == (other.base_type, other.const, other.restrict)
def equal_ignoring_const(self, other):
if not isinstance(other, PointerType):
return False
else:
return self.base_type.equal_ignoring_const(other.base_type)
def __str__(self):
components = [str(self.base_type), '*']
if self.restrict:
......@@ -743,6 +761,12 @@ class StructType:
else:
return (self.numpy_dtype, self.const) == (other.numpy_dtype, other.const)
def equal_ignoring_const(self, other):
if not isinstance(other, StructType):
return False
else:
return self.numpy_dtype == other.numpy_dtype
def __str__(self):
# structs are handled byte-wise
result = "uint8_t"
......
......@@ -16,7 +16,7 @@ would reference back to the field.
from sympy.core.cache import cacheit
from pystencils.data_types import (
PointerType, TypedSymbol, create_composite_type_from_string, get_base_type)
BasicType, PointerType, TypedSymbol, create_composite_type_from_string, get_base_type)
SHAPE_DTYPE = create_composite_type_from_string("const int64")
STRIDE_DTYPE = create_composite_type_from_string("const int64")
......@@ -78,7 +78,8 @@ class FieldPointerSymbol(TypedSymbol):
def __new_stage2__(cls, field_name, field_dtype, const):
name = "_data_{name}".format(name=field_name)
dtype = PointerType(get_base_type(field_dtype), const=const, restrict=True)
base_type = BasicType(get_base_type(field_dtype), const=const)
dtype = PointerType(base_type, const=True, restrict=True)
obj = super(FieldPointerSymbol, cls).__xnew__(cls, name, dtype)
obj.field_name = field_name
return obj
......
......@@ -878,7 +878,9 @@ class KernelConstraintsCheck:
assert isinstance(lhs, sp.Symbol)
self._update_accesses_lhs(lhs)
if not isinstance(lhs, (AbstractField.AbstractAccess, TypedSymbol)):
return TypedSymbol(lhs.name, self._type_for_symbol[lhs.name])
dtype = create_type(self._type_for_symbol[lhs.name])
dtype.const = True
return TypedSymbol(lhs.name, dtype)
else:
return lhs
......