Skip to content
Snippets Groups Projects

Add TypedMatrixSymbol (for usage of `MatrixSymbol` in kernels)

Closed Stephan Seitz requested to merge seitz/pystencils:matrix-symbols into master
Viewing commit e8db1ac3
Prev
Show latest version
1 file
+ 55
0
Preferences
Compare changes
+ 55
0
import sympy as sp
import pystencils
from pystencils.data_types import TypedMatrixSymbol, TypedSymbol, create_type
class DynamicFunction(sp.Function):
"""
Function that is passed as an argument to a kernel.
Can be printed for example as `std::function` or as a functor template.
"""
def __new__(cls, typed_function_symbol, return_dtype, *args):
obj = sp.Function.__new__(cls, typed_function_symbol, return_dtype, *args)
if hasattr(return_dtype, 'shape'):
obj.shape = return_dtype.shape
return obj
@property
def function_dtype(self):
return self.args[0].dtype
@property
def dtype(self):
return self.args[1].dtype
@property
def name(self):
return self.args[0].name
def __str__(self):
return f'{self.name}({", ".join(str(a) for a in self.args[2:])})'
def __repr__(self):
return self.__str__()
def test_dynamic_matrix_location_dependent():
x, y = pystencils.fields('x, y: float32[3d]')
A = TypedMatrixSymbol('A', 3, 1, create_type('double'), 'Vector3<double>')
my_fun_call = DynamicFunction(TypedSymbol('my_fun',
'std::function<Vector3<double>(int, int, int)>'),
A.dtype,
*pystencils.x_vector(3))
assignments = pystencils.AssignmentCollection({
A: my_fun_call,
y.center: A[0] + A[1] + A[2]
})
ast = pystencils.create_kernel(assignments)
pystencils.show_code(ast)