Skip to content
Snippets Groups Projects

WIP: Revamp the type system

Compare and
47 files
+ 1383
546
Compare changes
  • Side-by-side
  • Inline

Files

+ 18
18
@@ -11,9 +11,9 @@ from sympy.logic.boolalg import BooleanFalse, BooleanTrue
from pystencils.astnodes import KernelFunction, LoopOverCoordinate, Node
from pystencils.cpu.vectorization import vec_all, vec_any, CachelineSize
from pystencils.data_types import (
PointerType, VectorType, address_of, cast_func, create_type, get_type_of_expression,
reinterpret_cast_func, vector_memory_access, BasicType, TypedSymbol)
from pystencils.typing import (
PointerType, VectorType, address_of, CastFunc, create_type, get_type_of_expression,
ReinterpretCastFunc, VectorMemoryAccess, BasicType, TypedSymbol)
from pystencils.enums import Backend
from pystencils.fast_approximation import fast_division, fast_inv_sqrt, fast_sqrt
from pystencils.integer_functions import (
@@ -276,7 +276,7 @@ class CBackend:
else:
lhs_type = get_type_of_expression(node.lhs)
printed_mask = ""
if type(lhs_type) is VectorType and isinstance(node.lhs, cast_func):
if type(lhs_type) is VectorType and isinstance(node.lhs, CastFunc):
arg, data_type, aligned, nontemporal, mask, stride = node.lhs.args
instr = 'storeU'
if aligned:
@@ -289,12 +289,12 @@ class CBackend:
self._vector_instruction_set['load' + instr[-1]].format('{0}', **self._kwargs),
'{1}', '{2}', **self._kwargs), **self._kwargs)
printed_mask = self.sympy_printer.doprint(mask)
if data_type.base_type.base_name == 'double':
if data_type.base_type.c_name == 'double':
if self._vector_instruction_set['double'] == '__m256d':
printed_mask = f"_mm256_castpd_si256({printed_mask})"
elif self._vector_instruction_set['double'] == '__m128d':
printed_mask = f"_mm_castpd_si128({printed_mask})"
elif data_type.base_type.base_name == 'float':
elif data_type.base_type.c_name == 'float':
if self._vector_instruction_set['float'] == '__m256':
printed_mask = f"_mm256_castps_si256({printed_mask})"
elif self._vector_instruction_set['float'] == '__m128':
@@ -302,7 +302,7 @@ class CBackend:
rhs_type = get_type_of_expression(node.rhs)
if type(rhs_type) is not VectorType:
rhs = cast_func(node.rhs, VectorType(rhs_type))
rhs = CastFunc(node.rhs, VectorType(rhs_type))
else:
rhs = node.rhs
@@ -322,7 +322,7 @@ class CBackend:
if stride == 1:
offset = offset.subs({node.lhs.args[0].field.spatial_strides[0]: 1})
size = sp.Mul(*node.lhs.args[0].field.spatial_shape)
element_size = 8 if data_type.base_type.base_name == 'double' else 4
element_size = 8 if data_type.base_type.c_name == 'double' else 4
size_cond = f"({offset} + {CachelineSize.symbol/element_size}) < {size}"
pre_code = f"if ({first_cond} && {size_cond}) " + "{\n\t" + \
self._vector_instruction_set['cachelineZero'].format(ptr, **self._kwargs) + ';\n}\n'
@@ -483,13 +483,13 @@ class CustomSympyPrinter(CCodePrinter):
}
if hasattr(expr, 'to_c'):
return expr.to_c(self._print)
if isinstance(expr, reinterpret_cast_func):
if isinstance(expr, ReinterpretCastFunc):
arg, data_type = expr.args
return f"*(({self._print(PointerType(data_type, restrict=False))})(& {self._print(arg)}))"
elif isinstance(expr, address_of):
assert len(expr.args) == 1, "address_of must only have one argument"
return f"&({self._print(expr.args[0])})"
elif isinstance(expr, cast_func):
elif isinstance(expr, CastFunc):
arg, data_type = expr.args
if isinstance(arg, sp.Number) and arg.is_finite:
return self._typed_number(arg, data_type)
@@ -648,22 +648,22 @@ class VectorizedCustomSympyPrinter(CustomSympyPrinter):
return None
def _print_Abs(self, expr):
if 'abs' in self.instruction_set and isinstance(expr.args[0], vector_memory_access):
if 'abs' in self.instruction_set and isinstance(expr.args[0], VectorMemoryAccess):
return self.instruction_set['abs'].format(self._print(expr.args[0]), **self._kwargs)
return super()._print_Abs(expr)
def _print_Function(self, expr):
if isinstance(expr, vector_memory_access):
if isinstance(expr, VectorMemoryAccess):
arg, data_type, aligned, _, mask, stride = expr.args
if stride != 1:
return self.instruction_set['loadS'].format(f"& {self._print(arg)}", stride, **self._kwargs)
instruction = self.instruction_set['loadA'] if aligned else self.instruction_set['loadU']
return instruction.format(f"& {self._print(arg)}", **self._kwargs)
elif isinstance(expr, cast_func):
elif isinstance(expr, CastFunc):
arg, data_type = expr.args
if type(data_type) is VectorType:
# vector_memory_access is a cast_func itself so it should't be directly inside a cast_func
assert not isinstance(arg, vector_memory_access)
assert not isinstance(arg, VectorMemoryAccess)
if isinstance(arg, sp.Tuple):
is_boolean = get_type_of_expression(arg[0]) == create_type("bool")
is_integer = get_type_of_expression(arg[0]) == create_type("int")
@@ -747,12 +747,12 @@ class VectorizedCustomSympyPrinter(CustomSympyPrinter):
# special treatment for all-integer args, for loop index arithmetic until we have proper int vectorization
suffix = ""
if all([(type(e) is cast_func and str(e.dtype) == self.instruction_set['int']) or isinstance(e, sp.Integer)
if all([(type(e) is CastFunc and str(e.dtype) == self.instruction_set['int']) or isinstance(e, sp.Integer)
or (type(e) is TypedSymbol and isinstance(e.dtype, BasicType) and e.dtype.is_int()) for e in args]):
dtype = set([e.dtype for e in args if type(e) is cast_func])
dtype = set([e.dtype for e in args if type(e) is CastFunc])
assert len(dtype) == 1
dtype = dtype.pop()
args = [cast_func(e, dtype) if (isinstance(e, sp.Integer) or isinstance(e, TypedSymbol)) else e
args = [CastFunc(e, dtype) if (isinstance(e, sp.Integer) or isinstance(e, TypedSymbol)) else e
for e in args]
suffix = "int"
@@ -880,7 +880,7 @@ class VectorizedCustomSympyPrinter(CustomSympyPrinter):
result = self._print(expr.args[-1][0])
for true_expr, condition in reversed(expr.args[:-1]):
if isinstance(condition, cast_func) and get_type_of_expression(condition.args[0]) == create_type("bool"):
if isinstance(condition, CastFunc) and get_type_of_expression(condition.args[0]) == create_type("bool"):
if not KERNCRAFT_NO_TERNARY_MODE:
result = "(({}) ? ({}) : ({}))".format(self._print(condition.args[0]), self._print(true_expr),
result, **self._kwargs)
Loading