Skip to content
Snippets Groups Projects
Commit 61800b73 authored by Martin Bauer's avatar Martin Bauer
Browse files

pystencils type system: distinction between static and reinterpret cast

parent 2d925329
Branches
Tags
No related merge requests found
...@@ -12,7 +12,7 @@ from pystencils.integer_functions import bitwise_xor, bit_shift_right, bit_shift ...@@ -12,7 +12,7 @@ from pystencils.integer_functions import bitwise_xor, bit_shift_right, bit_shift
bitwise_or, modulo_ceil bitwise_or, modulo_ceil
from pystencils.astnodes import Node, KernelFunction from pystencils.astnodes import Node, KernelFunction
from pystencils.data_types import create_type, PointerType, get_type_of_expression, VectorType, cast_func, \ from pystencils.data_types import create_type, PointerType, get_type_of_expression, VectorType, cast_func, \
vector_memory_access vector_memory_access, reinterpret_cast_func
__all__ = ['generate_c', 'CustomCodeNode', 'PrintNode', 'get_headers', 'CustomSympyPrinter'] __all__ = ['generate_c', 'CustomCodeNode', 'PrintNode', 'get_headers', 'CustomSympyPrinter']
...@@ -251,7 +251,10 @@ class CustomSympyPrinter(CCodePrinter): ...@@ -251,7 +251,10 @@ class CustomSympyPrinter(CCodePrinter):
} }
if hasattr(expr, 'to_c'): if hasattr(expr, 'to_c'):
return expr.to_c(self._print) return expr.to_c(self._print)
if isinstance(expr, cast_func): if isinstance(expr, reinterpret_cast_func):
arg, data_type = expr.args
return "*((%s)(& %s))" % (PointerType(data_type, restrict=False), self._print(arg))
elif isinstance(expr, cast_func):
arg, data_type = expr.args arg, data_type = expr.args
if isinstance(arg, sp.Number): if isinstance(arg, sp.Number):
return self._typed_number(arg, data_type) return self._typed_number(arg, data_type)
......
...@@ -56,6 +56,11 @@ class vector_memory_access(cast_func): ...@@ -56,6 +56,11 @@ class vector_memory_access(cast_func):
nargs = (4,) nargs = (4,)
# noinspection PyPep8Naming
class reinterpret_cast_func(cast_func):
pass
# noinspection PyPep8Naming # noinspection PyPep8Naming
class pointer_arithmetic_func(sp.Function, Boolean): class pointer_arithmetic_func(sp.Function, Boolean):
@property @property
......
...@@ -10,8 +10,8 @@ from sympy.tensor import IndexedBase ...@@ -10,8 +10,8 @@ from sympy.tensor import IndexedBase
from pystencils.assignment import Assignment from pystencils.assignment import Assignment
from pystencils.assignment_collection.nestedscopes import NestedScopes from pystencils.assignment_collection.nestedscopes import NestedScopes
from pystencils.field import Field, FieldType from pystencils.field import Field, FieldType
from pystencils.data_types import TypedSymbol, PointerType, StructType, get_base_type, cast_func, \ from pystencils.data_types import TypedSymbol, PointerType, StructType, get_base_type, reinterpret_cast_func, \
pointer_arithmetic_func, get_type_of_expression, collate_types, create_type cast_func, pointer_arithmetic_func, get_type_of_expression, collate_types, create_type
from pystencils.kernelparameters import FieldPointerSymbol from pystencils.kernelparameters import FieldPointerSymbol
from pystencils.slicing import normalize_slice from pystencils.slicing import normalize_slice
import pystencils.astnodes as ast import pystencils.astnodes as ast
...@@ -427,7 +427,7 @@ def resolve_field_accesses(ast_node, read_only_field_names=set(), ...@@ -427,7 +427,7 @@ def resolve_field_accesses(ast_node, read_only_field_names=set(),
if isinstance(get_base_type(field_access.field.dtype), StructType): if isinstance(get_base_type(field_access.field.dtype), StructType):
new_type = field_access.field.dtype.get_element_type(field_access.index[0]) new_type = field_access.field.dtype.get_element_type(field_access.index[0])
result = cast_func(result, new_type) result = reinterpret_cast_func(result, new_type)
return visit_sympy_expr(result, enclosing_block, sympy_assignment) return visit_sympy_expr(result, enclosing_block, sympy_assignment)
else: else:
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Please register or to comment