Skip to content
Snippets Groups Projects
Commit 9b57cf87 authored by Martin Bauer's avatar Martin Bauer
Browse files

Merge branch 'functional_coordinate_transform' into 'master'

Allow functions for Field.coordinate_transform

See merge request pycodegen/pystencils!118
parents 5d3e1948 5257b4ae
No related branches found
No related tags found
1 merge request!118Allow functions for Field.coordinate_transform
Pipeline #20910 passed
...@@ -15,7 +15,8 @@ import pystencils ...@@ -15,7 +15,8 @@ import pystencils
from pystencils.alignedarray import aligned_empty from pystencils.alignedarray import aligned_empty
from pystencils.data_types import StructType, TypedSymbol, create_type from pystencils.data_types import StructType, TypedSymbol, create_type
from pystencils.kernelparameters import FieldShapeSymbol, FieldStrideSymbol from pystencils.kernelparameters import FieldShapeSymbol, FieldStrideSymbol
from pystencils.stencil import direction_string_to_offset, offset_to_direction_string, inverse_direction from pystencils.stencil import (
direction_string_to_offset, inverse_direction, offset_to_direction_string)
from pystencils.sympyextensions import is_integer_sequence from pystencils.sympyextensions import is_integer_sequence
__all__ = ['Field', 'fields', 'FieldType', 'AbstractField'] __all__ = ['Field', 'fields', 'FieldType', 'AbstractField']
...@@ -328,10 +329,10 @@ class Field(AbstractField): ...@@ -328,10 +329,10 @@ class Field(AbstractField):
self._layout = normalize_layout(layout) self._layout = normalize_layout(layout)
self.shape = shape self.shape = shape
self.strides = strides self.strides = strides
self.latex_name = None # type: Optional[str] self.latex_name: Optional[str] = None
self.coordinate_origin = sp.Matrix(tuple( self.coordinate_origin: tuple[float, sp.Symbol] = sp.Matrix(tuple(
0 for _ in range(self.spatial_dimensions) 0 for _ in range(self.spatial_dimensions)
)) # type: tuple[float,sp.Symbol] )) # type
self.coordinate_transform = sp.eye(self.spatial_dimensions) self.coordinate_transform = sp.eye(self.spatial_dimensions)
if field_type == FieldType.STAGGERED: if field_type == FieldType.STAGGERED:
assert self.staggered_stencil assert self.staggered_stencil
...@@ -432,7 +433,7 @@ class Field(AbstractField): ...@@ -432,7 +433,7 @@ class Field(AbstractField):
return sp.Matrix([[self(i, j) for j in range(index_shape[1])] for i in range(index_shape[0])]) return sp.Matrix([[self(i, j) for j in range(index_shape[1])] for i in range(index_shape[0])])
elif len(index_shape) == 3: elif len(index_shape) == 3:
return sp.Matrix([[[self(i, j, k) for k in range(index_shape[2])] return sp.Matrix([[[self(i, j, k) for k in range(index_shape[2])]
for j in range(index_shape[1])] for i in range(index_shape[0])]) for j in range(index_shape[1])] for i in range(index_shape[0])])
else: else:
raise NotImplementedError("center_vector is not implemented for more than 3 index dimensions") raise NotImplementedError("center_vector is not implemented for more than 3 index dimensions")
...@@ -453,7 +454,7 @@ class Field(AbstractField): ...@@ -453,7 +454,7 @@ class Field(AbstractField):
return sp.Matrix([self.__getitem__(offset)(i) for i in range(self.index_shape[0])]) return sp.Matrix([self.__getitem__(offset)(i) for i in range(self.index_shape[0])])
elif self.index_dimensions == 2: elif self.index_dimensions == 2:
return sp.Matrix([[self.__getitem__(offset)(i, k) for k in range(self.index_shape[1])] return sp.Matrix([[self.__getitem__(offset)(i, k) for k in range(self.index_shape[1])]
for i in range(self.index_shape[0])]) for i in range(self.index_shape[0])])
else: else:
raise NotImplementedError("neighbor_vector is not implemented for more than 2 index dimensions") raise NotImplementedError("neighbor_vector is not implemented for more than 2 index dimensions")
...@@ -528,7 +529,7 @@ class Field(AbstractField): ...@@ -528,7 +529,7 @@ class Field(AbstractField):
prefactor = -1 prefactor = -1
if neighbor not in self.staggered_stencil: if neighbor not in self.staggered_stencil:
raise ValueError("{} is not a valid neighbor for the {} stencil".format(offset_orig, raise ValueError("{} is not a valid neighbor for the {} stencil".format(offset_orig,
self.staggered_stencil_name)) self.staggered_stencil_name))
offset = tuple(sp.Matrix(offset) - sp.Rational(1, 2) * sp.Matrix(neighbor_vec)) offset = tuple(sp.Matrix(offset) - sp.Rational(1, 2) * sp.Matrix(neighbor_vec))
...@@ -562,7 +563,7 @@ class Field(AbstractField): ...@@ -562,7 +563,7 @@ class Field(AbstractField):
return sp.Matrix([self.staggered_access(offset, i) for i in range(self.index_shape[1])]) return sp.Matrix([self.staggered_access(offset, i) for i in range(self.index_shape[1])])
elif self.index_dimensions == 3: elif self.index_dimensions == 3:
return sp.Matrix([[self.staggered_access(offset, (i, k)) for k in range(self.index_shape[2])] return sp.Matrix([[self.staggered_access(offset, (i, k)) for k in range(self.index_shape[2])]
for i in range(self.index_shape[1])]) for i in range(self.index_shape[1])])
else: else:
raise NotImplementedError("staggered_vector_access is not implemented for more than 3 index dimensions") raise NotImplementedError("staggered_vector_access is not implemented for more than 3 index dimensions")
...@@ -613,7 +614,10 @@ class Field(AbstractField): ...@@ -613,7 +614,10 @@ class Field(AbstractField):
@property @property
def physical_coordinates(self): def physical_coordinates(self):
return self.coordinate_transform @ (self.coordinate_origin + pystencils.x_vector(self.spatial_dimensions)) if hasattr(self.coordinate_transform, '__call__'):
return self.coordinate_transform(self.coordinate_origin + pystencils.x_vector(self.spatial_dimensions))
else:
return self.coordinate_transform @ (self.coordinate_origin + pystencils.x_vector(self.spatial_dimensions))
@property @property
def physical_coordinates_staggered(self): def physical_coordinates_staggered(self):
...@@ -623,10 +627,23 @@ class Field(AbstractField): ...@@ -623,10 +627,23 @@ class Field(AbstractField):
def index_to_physical(self, index_coordinates, staggered=False): def index_to_physical(self, index_coordinates, staggered=False):
if staggered: if staggered:
index_coordinates = sp.Matrix([i + 0.5 for i in index_coordinates]) index_coordinates = sp.Matrix([i + 0.5 for i in index_coordinates])
return self.coordinate_transform @ (self.coordinate_origin + index_coordinates) if hasattr(self.coordinate_transform, '__call__'):
return self.coordinate_transform(self.coordinate_origin + index_coordinates)
else:
return self.coordinate_transform @ (self.coordinate_origin + index_coordinates)
def physical_to_index(self, physical_coordinates, staggered=False): def physical_to_index(self, physical_coordinates, staggered=False):
rtn = self.coordinate_transform.inv() @ physical_coordinates - self.coordinate_origin if hasattr(self.coordinate_transform, '__call__'):
if hasattr(self.coordinate_transform, 'inv'):
return self.coordinate_transform.inv()(physical_coordinates) - self.coordinate_origin
else:
idx = sp.Matrix(sp.symbols(f'index_coordinates:{self.ndim}', real=True))
rtn = sp.solve(self.index_to_physical(idx) - physical_coordinates, idx)
assert rtn, f'Could not find inverese of coordinate_transform: {self.index_to_physical(idx)}'
return rtn
else:
rtn = self.coordinate_transform.inv() @ physical_coordinates - self.coordinate_origin
if staggered: if staggered:
rtn = sp.Matrix([i - 0.5 for i in rtn]) rtn = sp.Matrix([i - 0.5 for i in rtn])
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Please register or to comment