Skip to content
Snippets Groups Projects

Reimplement create_staggered_kernel

Merged Michael Kuron requested to merge staggered_kernel into master
Compare and
5 files
+ 172
15
Compare changes
  • Side-by-side
  • Inline
Files
5
+ 27
15
@@ -15,7 +15,7 @@ import pystencils
@@ -15,7 +15,7 @@ import pystencils
from pystencils.alignedarray import aligned_empty
from pystencils.alignedarray import aligned_empty
from pystencils.data_types import StructType, TypedSymbol, create_type
from pystencils.data_types import StructType, TypedSymbol, create_type
from pystencils.kernelparameters import FieldShapeSymbol, FieldStrideSymbol
from pystencils.kernelparameters import FieldShapeSymbol, FieldStrideSymbol
from pystencils.stencil import direction_string_to_offset, offset_to_direction_string
from pystencils.stencil import direction_string_to_offset, offset_to_direction_string, inverse_direction
from pystencils.sympyextensions import is_integer_sequence
from pystencils.sympyextensions import is_integer_sequence
__all__ = ['Field', 'fields', 'FieldType', 'AbstractField']
__all__ = ['Field', 'fields', 'FieldType', 'AbstractField']
@@ -34,6 +34,8 @@ class FieldType(Enum):
@@ -34,6 +34,8 @@ class FieldType(Enum):
CUSTOM = 3
CUSTOM = 3
# staggered field
# staggered field
STAGGERED = 4
STAGGERED = 4
 
# staggered field that reverses sign when accessed via opposite direction
 
STAGGERED_FLUX = 5
@staticmethod
@staticmethod
def is_generic(field):
def is_generic(field):
@@ -58,7 +60,12 @@ class FieldType(Enum):
@@ -58,7 +60,12 @@ class FieldType(Enum):
@staticmethod
@staticmethod
def is_staggered(field):
def is_staggered(field):
assert isinstance(field, Field)
assert isinstance(field, Field)
return field.field_type == FieldType.STAGGERED
return field.field_type == FieldType.STAGGERED or field.field_type == FieldType.STAGGERED_FLUX
 
 
@staticmethod
 
def is_staggered_flux(field):
 
assert isinstance(field, Field)
 
return field.field_type == FieldType.STAGGERED_FLUX
def fields(description=None, index_dimensions=0, layout=None, field_type=FieldType.GENERIC, **kwargs):
def fields(description=None, index_dimensions=0, layout=None, field_type=FieldType.GENERIC, **kwargs):
@@ -490,24 +497,29 @@ class Field(AbstractField):
@@ -490,24 +497,29 @@ class Field(AbstractField):
raise ValueError("Wrong number of spatial indices: "
raise ValueError("Wrong number of spatial indices: "
"Got %d, expected %d" % (len(offset), self.spatial_dimensions))
"Got %d, expected %d" % (len(offset), self.spatial_dimensions))
offset = list(offset)
prefactor = 1
neighbor = [0] * len(offset)
neighbor_vec = [0] * len(offset)
for i, o in enumerate(offset):
for i in range(self.spatial_dimensions):
if (o + sp.Rational(1, 2)).is_Integer:
if (offset[i] + sp.Rational(1, 2)).is_Integer:
offset[i] += sp.Rational(1, 2)
neighbor_vec[i] = sp.sign(offset[i])
neighbor[i] = -1
neighbor = offset_to_direction_string(neighbor_vec)
neighbor = offset_to_direction_string(neighbor)
if neighbor not in self.staggered_stencil:
try:
neighbor_vec = inverse_direction(neighbor_vec)
idx = self.staggered_stencil.index(neighbor)
neighbor = offset_to_direction_string(neighbor_vec)
except ValueError:
if FieldType.is_staggered_flux(self):
 
prefactor = -1
 
if neighbor not in self.staggered_stencil:
raise ValueError("{} is not a valid neighbor for the {} stencil".format(offset_orig,
raise ValueError("{} is not a valid neighbor for the {} stencil".format(offset_orig,
self.staggered_stencil_name))
self.staggered_stencil_name))
offset = tuple(offset)
 
offset = tuple(sp.Matrix(offset) - sp.Rational(1, 2) * sp.Matrix(neighbor_vec))
 
 
idx = self.staggered_stencil.index(neighbor)
if self.index_dimensions == 1: # this field stores a scalar value at each staggered position
if self.index_dimensions == 1: # this field stores a scalar value at each staggered position
if index is not None:
if index is not None:
raise ValueError("Cannot specify an index for a scalar staggered field")
raise ValueError("Cannot specify an index for a scalar staggered field")
return Field.Access(self, offset, (idx,))
return prefactor * Field.Access(self, offset, (idx,))
else: # this field stores a vector or tensor at each staggered position
else: # this field stores a vector or tensor at each staggered position
if index is None:
if index is None:
raise ValueError("Wrong number of indices: "
raise ValueError("Wrong number of indices: "
@@ -520,7 +532,7 @@ class Field(AbstractField):
@@ -520,7 +532,7 @@ class Field(AbstractField):
raise ValueError("Wrong number of indices: "
raise ValueError("Wrong number of indices: "
"Got %d, expected %d" % (len(index), self.index_dimensions - 1))
"Got %d, expected %d" % (len(index), self.index_dimensions - 1))
return Field.Access(self, offset, (idx, *index))
return prefactor * Field.Access(self, offset, (idx, *index))
@property
@property
def staggered_stencil(self):
def staggered_stencil(self):
Loading