Skip to content
Snippets Groups Projects

Compare revisions

Changes are shown as if the source revision was being merged into the target revision. Learn more about comparing revisions.

Source

Select target project
No results found
Select Git revision

Target

Select target project
  • anirudh.jonnalagadda/pystencils
  • hyteg/pystencils
  • jbadwaik/pystencils
  • jngrad/pystencils
  • itischler/pystencils
  • ob28imeq/pystencils
  • hoenig/pystencils
  • Bindgen/pystencils
  • hammer/pystencils
  • da15siwa/pystencils
  • holzer/pystencils
  • alexander.reinauer/pystencils
  • ec93ujoh/pystencils
  • Harke/pystencils
  • seitz/pystencils
  • pycodegen/pystencils
16 results
Select Git revision
Show changes
Commits on Source (1)
...@@ -506,7 +506,6 @@ class SympyAssignment(Node): ...@@ -506,7 +506,6 @@ class SympyAssignment(Node):
super(SympyAssignment, self).__init__(parent=None) super(SympyAssignment, self).__init__(parent=None)
self._lhs_symbol = lhs_symbol self._lhs_symbol = lhs_symbol
self.rhs = sp.sympify(rhs_expr) self.rhs = sp.sympify(rhs_expr)
self._is_const = is_const
self._is_declaration = self.__is_declaration() self._is_declaration = self.__is_declaration()
def __is_declaration(self): def __is_declaration(self):
...@@ -563,10 +562,6 @@ class SympyAssignment(Node): ...@@ -563,10 +562,6 @@ class SympyAssignment(Node):
def is_declaration(self): def is_declaration(self):
return self._is_declaration return self._is_declaration
@property
def is_const(self):
return self._is_const
def replace(self, child, replacement): def replace(self, child, replacement):
if child == self.lhs: if child == self.lhs:
replacement.parent = self replacement.parent = self
......
...@@ -225,13 +225,9 @@ class CBackend: ...@@ -225,13 +225,9 @@ class CBackend:
def _print_SympyAssignment(self, node): def _print_SympyAssignment(self, node):
if node.is_declaration: if node.is_declaration:
if node.is_const: data_type = self._print(node.lhs.dtype)
prefix = 'const ' return "%s %s = %s;" % (data_type, self.sympy_printer.doprint(node.lhs),
else: self.sympy_printer.doprint(node.rhs))
prefix = ''
data_type = prefix + self._print(node.lhs.dtype).replace(' const', '') + " "
return "%s%s = %s;" % (data_type, self.sympy_printer.doprint(node.lhs),
self.sympy_printer.doprint(node.rhs))
else: else:
lhs_type = get_type_of_expression(node.lhs) lhs_type = get_type_of_expression(node.lhs)
if type(lhs_type) is VectorType and isinstance(node.lhs, cast_func): if type(lhs_type) is VectorType and isinstance(node.lhs, cast_func):
......
...@@ -63,11 +63,11 @@ def vectorize(kernel_ast: ast.KernelFunction, instruction_set: str = 'avx', ...@@ -63,11 +63,11 @@ def vectorize(kernel_ast: ast.KernelFunction, instruction_set: str = 'avx',
if assume_inner_stride_one: if assume_inner_stride_one:
replace_inner_stride_with_one(kernel_ast) replace_inner_stride_with_one(kernel_ast)
field_float_dtypes = set(f.dtype for f in all_fields if f.dtype.is_float()) field_float_dtypes = set(f.dtype.numpy_dtype for f in all_fields if f.dtype.is_float())
if len(field_float_dtypes) != 1: if len(field_float_dtypes) != 1:
raise NotImplementedError("Cannot vectorize kernels that contain accesses " raise NotImplementedError("Cannot vectorize kernels that contain accesses "
"to differently typed floating point fields") "to differently typed floating point fields")
float_size = field_float_dtypes.pop().numpy_dtype.itemsize float_size = field_float_dtypes.pop().itemsize
assert float_size in (8, 4) assert float_size in (8, 4)
vector_is = get_vector_instruction_set('double' if float_size == 8 else 'float', vector_is = get_vector_instruction_set('double' if float_size == 8 else 'float',
instruction_set=instruction_set) instruction_set=instruction_set)
...@@ -148,7 +148,7 @@ def insert_vector_casts(ast_node): ...@@ -148,7 +148,7 @@ def insert_vector_casts(ast_node):
return expr return expr
else: else:
target_type = collate_types(arg_types) target_type = collate_types(arg_types)
casted_args = [cast_func(a, target_type) if t != target_type else a casted_args = [cast_func(a, target_type) if not t.equal_ignoring_const(target_type) else a
for a, t in zip(new_args, arg_types)] for a, t in zip(new_args, arg_types)]
return expr.func(*casted_args) return expr.func(*casted_args)
elif expr.func is sp.Pow: elif expr.func is sp.Pow:
...@@ -167,11 +167,11 @@ def insert_vector_casts(ast_node): ...@@ -167,11 +167,11 @@ def insert_vector_casts(ast_node):
if type(condition_target_type) is not VectorType and type(result_target_type) is VectorType: if type(condition_target_type) is not VectorType and type(result_target_type) is VectorType:
condition_target_type = VectorType(condition_target_type, width=result_target_type.width) condition_target_type = VectorType(condition_target_type, width=result_target_type.width)
casted_results = [cast_func(a, result_target_type) if t != result_target_type else a casted_results = [cast_func(a, result_target_type) if not t.equal_ignoring_const(result_target_type) else a
for a, t in zip(new_results, types_of_results)] for a, t in zip(new_results, types_of_results)]
casted_conditions = [cast_func(a, condition_target_type) casted_conditions = [cast_func(a, condition_target_type)
if t != condition_target_type and a is not True else a if not t.equal_ignoring_const(condition_target_type) and a is not True else a
for a, t in zip(new_conditions, types_of_conditions)] for a, t in zip(new_conditions, types_of_conditions)]
return sp.Piecewise(*[(r, c) for r, c in zip(casted_results, casted_conditions)]) return sp.Piecewise(*[(r, c) for r, c in zip(casted_results, casted_conditions)])
......
...@@ -453,7 +453,7 @@ def collate_types(types, forbid_collation_to_float=False): ...@@ -453,7 +453,7 @@ def collate_types(types, forbid_collation_to_float=False):
types = tuple(t for t in types if t.is_float()) types = tuple(t for t in types if t.is_float())
# use numpy collation -> create type from numpy type -> and, put vector type around if necessary # use numpy collation -> create type from numpy type -> and, put vector type around if necessary
result_numpy_type = np.result_type(*(t.numpy_dtype for t in types)) result_numpy_type = np.result_type(*(t.numpy_dtype for t in types))
result = BasicType(result_numpy_type) result = BasicType(result_numpy_type, const=any(t.const for t in types))
if vector_type: if vector_type:
result = VectorType(result, vector_type[0].width) result = VectorType(result, vector_type[0].width)
return result return result
...@@ -618,6 +618,12 @@ class BasicType(Type): ...@@ -618,6 +618,12 @@ class BasicType(Type):
else: else:
return (self.numpy_dtype, self.const) == (other.numpy_dtype, other.const) return (self.numpy_dtype, self.const) == (other.numpy_dtype, other.const)
def equal_ignoring_const(self, other):
if not isinstance(other, BasicType):
return False
else:
return self.numpy_dtype == other.numpy_dtype
def __hash__(self): def __hash__(self):
return hash(str(self)) return hash(str(self))
...@@ -643,17 +649,23 @@ class VectorType(Type): ...@@ -643,17 +649,23 @@ class VectorType(Type):
else: else:
return (self.base_type, self.width) == (other.base_type, other.width) return (self.base_type, self.width) == (other.base_type, other.width)
def equal_ignoring_const(self, other):
if not isinstance(other, VectorType):
return False
else:
return self.base_type.equal_ignoring_const(other.base_type)
def __str__(self): def __str__(self):
if self.instruction_set is None: if self.instruction_set is None:
return "%s[%d]" % (self.base_type, self.width) return "%s[%d]" % (self.base_type, self.width)
else: else:
if self.base_type == create_type("int64"): if self.base_type.numpy_dtype == np.int64:
return self.instruction_set['int'] return self.instruction_set['int']
elif self.base_type == create_type("float64"): elif self.base_type.numpy_dtype == np.float64:
return self.instruction_set['double'] return self.instruction_set['double']
elif self.base_type == create_type("float32"): elif self.base_type.numpy_dtype == np.float32:
return self.instruction_set['float'] return self.instruction_set['float']
elif self.base_type == create_type("bool"): elif self.base_type.numpy_dtype == np.bool:
return self.instruction_set['bool'] return self.instruction_set['bool']
else: else:
raise NotImplementedError() raise NotImplementedError()
...@@ -692,6 +704,12 @@ class PointerType(Type): ...@@ -692,6 +704,12 @@ class PointerType(Type):
else: else:
return (self.base_type, self.const, self.restrict) == (other.base_type, other.const, other.restrict) return (self.base_type, self.const, self.restrict) == (other.base_type, other.const, other.restrict)
def equal_ignoring_const(self, other):
if not isinstance(other, PointerType):
return False
else:
return self.base_type.equal_ignoring_const(other.base_type)
def __str__(self): def __str__(self):
components = [str(self.base_type), '*'] components = [str(self.base_type), '*']
if self.restrict: if self.restrict:
...@@ -743,6 +761,12 @@ class StructType: ...@@ -743,6 +761,12 @@ class StructType:
else: else:
return (self.numpy_dtype, self.const) == (other.numpy_dtype, other.const) return (self.numpy_dtype, self.const) == (other.numpy_dtype, other.const)
def equal_ignoring_const(self, other):
if not isinstance(other, StructType):
return False
else:
return self.numpy_dtype == other.numpy_dtype
def __str__(self): def __str__(self):
# structs are handled byte-wise # structs are handled byte-wise
result = "uint8_t" result = "uint8_t"
......
...@@ -16,7 +16,7 @@ would reference back to the field. ...@@ -16,7 +16,7 @@ would reference back to the field.
from sympy.core.cache import cacheit from sympy.core.cache import cacheit
from pystencils.data_types import ( from pystencils.data_types import (
PointerType, TypedSymbol, create_composite_type_from_string, get_base_type) BasicType, PointerType, TypedSymbol, create_composite_type_from_string, get_base_type)
SHAPE_DTYPE = create_composite_type_from_string("const int64") SHAPE_DTYPE = create_composite_type_from_string("const int64")
STRIDE_DTYPE = create_composite_type_from_string("const int64") STRIDE_DTYPE = create_composite_type_from_string("const int64")
...@@ -78,7 +78,8 @@ class FieldPointerSymbol(TypedSymbol): ...@@ -78,7 +78,8 @@ class FieldPointerSymbol(TypedSymbol):
def __new_stage2__(cls, field_name, field_dtype, const): def __new_stage2__(cls, field_name, field_dtype, const):
name = "_data_{name}".format(name=field_name) name = "_data_{name}".format(name=field_name)
dtype = PointerType(get_base_type(field_dtype), const=const, restrict=True) base_type = BasicType(get_base_type(field_dtype), const=const)
dtype = PointerType(base_type, const=True, restrict=True)
obj = super(FieldPointerSymbol, cls).__xnew__(cls, name, dtype) obj = super(FieldPointerSymbol, cls).__xnew__(cls, name, dtype)
obj.field_name = field_name obj.field_name = field_name
return obj return obj
......
...@@ -878,7 +878,9 @@ class KernelConstraintsCheck: ...@@ -878,7 +878,9 @@ class KernelConstraintsCheck:
assert isinstance(lhs, sp.Symbol) assert isinstance(lhs, sp.Symbol)
self._update_accesses_lhs(lhs) self._update_accesses_lhs(lhs)
if not isinstance(lhs, (AbstractField.AbstractAccess, TypedSymbol)): if not isinstance(lhs, (AbstractField.AbstractAccess, TypedSymbol)):
return TypedSymbol(lhs.name, self._type_for_symbol[lhs.name]) dtype = create_type(self._type_for_symbol[lhs.name])
dtype.const = True
return TypedSymbol(lhs.name, dtype)
else: else:
return lhs return lhs
......