Skip to content
Snippets Groups Projects

WIP: Revamp the type system

Closed Markus Holzer requested to merge holzer/pystencils:TypeSystem into master
Compare and
60 files
+ 1728
1562
Compare changes
  • Side-by-side
  • Inline
Files
60
@@ -11,9 +11,9 @@ from sympy.logic.boolalg import BooleanFalse, BooleanTrue
@@ -11,9 +11,9 @@ from sympy.logic.boolalg import BooleanFalse, BooleanTrue
from pystencils.astnodes import KernelFunction, LoopOverCoordinate, Node
from pystencils.astnodes import KernelFunction, LoopOverCoordinate, Node
from pystencils.cpu.vectorization import vec_all, vec_any, CachelineSize
from pystencils.cpu.vectorization import vec_all, vec_any, CachelineSize
from pystencils.data_types import (
from pystencils.typing import (
PointerType, VectorType, address_of, cast_func, create_type, get_type_of_expression,
PointerType, VectorType, address_of, CastFunc, create_type, get_type_of_expression,
reinterpret_cast_func, vector_memory_access, BasicType, TypedSymbol)
ReinterpretCastFunc, VectorMemoryAccess, BasicType, TypedSymbol)
from pystencils.enums import Backend
from pystencils.enums import Backend
from pystencils.fast_approximation import fast_division, fast_inv_sqrt, fast_sqrt
from pystencils.fast_approximation import fast_division, fast_inv_sqrt, fast_sqrt
from pystencils.integer_functions import (
from pystencils.integer_functions import (
@@ -219,7 +219,7 @@ class CBackend:
@@ -219,7 +219,7 @@ class CBackend:
return getattr(self, method_name)(node)
return getattr(self, method_name)(node)
raise NotImplementedError(f"{self.__class__.__name__} does not support node of type {node.__class__.__name__}")
raise NotImplementedError(f"{self.__class__.__name__} does not support node of type {node.__class__.__name__}")
def _print_Type(self, node):
def _print_AbstractType(self, node):
return str(node)
return str(node)
def _print_KernelFunction(self, node):
def _print_KernelFunction(self, node):
@@ -276,7 +276,7 @@ class CBackend:
@@ -276,7 +276,7 @@ class CBackend:
else:
else:
lhs_type = get_type_of_expression(node.lhs)
lhs_type = get_type_of_expression(node.lhs)
printed_mask = ""
printed_mask = ""
if type(lhs_type) is VectorType and isinstance(node.lhs, cast_func):
if type(lhs_type) is VectorType and isinstance(node.lhs, CastFunc):
arg, data_type, aligned, nontemporal, mask, stride = node.lhs.args
arg, data_type, aligned, nontemporal, mask, stride = node.lhs.args
instr = 'storeU'
instr = 'storeU'
if aligned:
if aligned:
@@ -289,12 +289,12 @@ class CBackend:
@@ -289,12 +289,12 @@ class CBackend:
self._vector_instruction_set['load' + instr[-1]].format('{0}', **self._kwargs),
self._vector_instruction_set['load' + instr[-1]].format('{0}', **self._kwargs),
'{1}', '{2}', **self._kwargs), **self._kwargs)
'{1}', '{2}', **self._kwargs), **self._kwargs)
printed_mask = self.sympy_printer.doprint(mask)
printed_mask = self.sympy_printer.doprint(mask)
if data_type.base_type.base_name == 'double':
if data_type.base_type.c_name == 'double':
if self._vector_instruction_set['double'] == '__m256d':
if self._vector_instruction_set['double'] == '__m256d':
printed_mask = f"_mm256_castpd_si256({printed_mask})"
printed_mask = f"_mm256_castpd_si256({printed_mask})"
elif self._vector_instruction_set['double'] == '__m128d':
elif self._vector_instruction_set['double'] == '__m128d':
printed_mask = f"_mm_castpd_si128({printed_mask})"
printed_mask = f"_mm_castpd_si128({printed_mask})"
elif data_type.base_type.base_name == 'float':
elif data_type.base_type.c_name == 'float':
if self._vector_instruction_set['float'] == '__m256':
if self._vector_instruction_set['float'] == '__m256':
printed_mask = f"_mm256_castps_si256({printed_mask})"
printed_mask = f"_mm256_castps_si256({printed_mask})"
elif self._vector_instruction_set['float'] == '__m128':
elif self._vector_instruction_set['float'] == '__m128':
@@ -302,7 +302,7 @@ class CBackend:
@@ -302,7 +302,7 @@ class CBackend:
rhs_type = get_type_of_expression(node.rhs)
rhs_type = get_type_of_expression(node.rhs)
if type(rhs_type) is not VectorType:
if type(rhs_type) is not VectorType:
rhs = cast_func(node.rhs, VectorType(rhs_type))
rhs = CastFunc(node.rhs, VectorType(rhs_type))
else:
else:
rhs = node.rhs
rhs = node.rhs
@@ -322,7 +322,7 @@ class CBackend:
@@ -322,7 +322,7 @@ class CBackend:
if stride == 1:
if stride == 1:
offset = offset.subs({node.lhs.args[0].field.spatial_strides[0]: 1})
offset = offset.subs({node.lhs.args[0].field.spatial_strides[0]: 1})
size = sp.Mul(*node.lhs.args[0].field.spatial_shape)
size = sp.Mul(*node.lhs.args[0].field.spatial_shape)
element_size = 8 if data_type.base_type.base_name == 'double' else 4
element_size = 8 if data_type.base_type.c_name == 'double' else 4
size_cond = f"({offset} + {CachelineSize.symbol/element_size}) < {size}"
size_cond = f"({offset} + {CachelineSize.symbol/element_size}) < {size}"
pre_code = f"if ({first_cond} && {size_cond}) " + "{\n\t" + \
pre_code = f"if ({first_cond} && {size_cond}) " + "{\n\t" + \
self._vector_instruction_set['cachelineZero'].format(ptr, **self._kwargs) + ';\n}\n'
self._vector_instruction_set['cachelineZero'].format(ptr, **self._kwargs) + ';\n}\n'
@@ -483,13 +483,13 @@ class CustomSympyPrinter(CCodePrinter):
@@ -483,13 +483,13 @@ class CustomSympyPrinter(CCodePrinter):
}
}
if hasattr(expr, 'to_c'):
if hasattr(expr, 'to_c'):
return expr.to_c(self._print)
return expr.to_c(self._print)
if isinstance(expr, reinterpret_cast_func):
if isinstance(expr, ReinterpretCastFunc):
arg, data_type = expr.args
arg, data_type = expr.args
return f"*(({self._print(PointerType(data_type, restrict=False))})(& {self._print(arg)}))"
return f"*(({self._print(PointerType(data_type, restrict=False))})(& {self._print(arg)}))"
elif isinstance(expr, address_of):
elif isinstance(expr, address_of):
assert len(expr.args) == 1, "address_of must only have one argument"
assert len(expr.args) == 1, "address_of must only have one argument"
return f"&({self._print(expr.args[0])})"
return f"&({self._print(expr.args[0])})"
elif isinstance(expr, cast_func):
elif isinstance(expr, CastFunc):
arg, data_type = expr.args
arg, data_type = expr.args
if isinstance(arg, sp.Number) and arg.is_finite:
if isinstance(arg, sp.Number) and arg.is_finite:
return self._typed_number(arg, data_type)
return self._typed_number(arg, data_type)
@@ -648,22 +648,22 @@ class VectorizedCustomSympyPrinter(CustomSympyPrinter):
@@ -648,22 +648,22 @@ class VectorizedCustomSympyPrinter(CustomSympyPrinter):
return None
return None
def _print_Abs(self, expr):
def _print_Abs(self, expr):
if 'abs' in self.instruction_set and isinstance(expr.args[0], vector_memory_access):
if 'abs' in self.instruction_set and isinstance(expr.args[0], VectorMemoryAccess):
return self.instruction_set['abs'].format(self._print(expr.args[0]), **self._kwargs)
return self.instruction_set['abs'].format(self._print(expr.args[0]), **self._kwargs)
return super()._print_Abs(expr)
return super()._print_Abs(expr)
def _print_Function(self, expr):
def _print_Function(self, expr):
if isinstance(expr, vector_memory_access):
if isinstance(expr, VectorMemoryAccess):
arg, data_type, aligned, _, mask, stride = expr.args
arg, data_type, aligned, _, mask, stride = expr.args
if stride != 1:
if stride != 1:
return self.instruction_set['loadS'].format(f"& {self._print(arg)}", stride, **self._kwargs)
return self.instruction_set['loadS'].format(f"& {self._print(arg)}", stride, **self._kwargs)
instruction = self.instruction_set['loadA'] if aligned else self.instruction_set['loadU']
instruction = self.instruction_set['loadA'] if aligned else self.instruction_set['loadU']
return instruction.format(f"& {self._print(arg)}", **self._kwargs)
return instruction.format(f"& {self._print(arg)}", **self._kwargs)
elif isinstance(expr, cast_func):
elif isinstance(expr, CastFunc):
arg, data_type = expr.args
arg, data_type = expr.args
if type(data_type) is VectorType:
if type(data_type) is VectorType:
# vector_memory_access is a cast_func itself so it should't be directly inside a cast_func
# vector_memory_access is a cast_func itself so it should't be directly inside a cast_func
assert not isinstance(arg, vector_memory_access)
assert not isinstance(arg, VectorMemoryAccess)
if isinstance(arg, sp.Tuple):
if isinstance(arg, sp.Tuple):
is_boolean = get_type_of_expression(arg[0]) == create_type("bool")
is_boolean = get_type_of_expression(arg[0]) == create_type("bool")
is_integer = get_type_of_expression(arg[0]) == create_type("int")
is_integer = get_type_of_expression(arg[0]) == create_type("int")
@@ -747,12 +747,12 @@ class VectorizedCustomSympyPrinter(CustomSympyPrinter):
@@ -747,12 +747,12 @@ class VectorizedCustomSympyPrinter(CustomSympyPrinter):
# special treatment for all-integer args, for loop index arithmetic until we have proper int vectorization
# special treatment for all-integer args, for loop index arithmetic until we have proper int vectorization
suffix = ""
suffix = ""
if all([(type(e) is cast_func and str(e.dtype) == self.instruction_set['int']) or isinstance(e, sp.Integer)
if all([(type(e) is CastFunc and str(e.dtype) == self.instruction_set['int']) or isinstance(e, sp.Integer)
or (type(e) is TypedSymbol and isinstance(e.dtype, BasicType) and e.dtype.is_int()) for e in args]):
or (type(e) is TypedSymbol and isinstance(e.dtype, BasicType) and e.dtype.is_int()) for e in args]):
dtype = set([e.dtype for e in args if type(e) is cast_func])
dtype = set([e.dtype for e in args if type(e) is CastFunc])
assert len(dtype) == 1
assert len(dtype) == 1
dtype = dtype.pop()
dtype = dtype.pop()
args = [cast_func(e, dtype) if (isinstance(e, sp.Integer) or isinstance(e, TypedSymbol)) else e
args = [CastFunc(e, dtype) if (isinstance(e, sp.Integer) or isinstance(e, TypedSymbol)) else e
for e in args]
for e in args]
suffix = "int"
suffix = "int"
@@ -880,7 +880,7 @@ class VectorizedCustomSympyPrinter(CustomSympyPrinter):
@@ -880,7 +880,7 @@ class VectorizedCustomSympyPrinter(CustomSympyPrinter):
result = self._print(expr.args[-1][0])
result = self._print(expr.args[-1][0])
for true_expr, condition in reversed(expr.args[:-1]):
for true_expr, condition in reversed(expr.args[:-1]):
if isinstance(condition, cast_func) and get_type_of_expression(condition.args[0]) == create_type("bool"):
if isinstance(condition, CastFunc) and get_type_of_expression(condition.args[0]) == create_type("bool"):
if not KERNCRAFT_NO_TERNARY_MODE:
if not KERNCRAFT_NO_TERNARY_MODE:
result = "(({}) ? ({}) : ({}))".format(self._print(condition.args[0]), self._print(true_expr),
result = "(({}) ? ({}) : ({}))".format(self._print(condition.args[0]), self._print(true_expr),
result, **self._kwargs)
result, **self._kwargs)
Loading