Skip to content
Snippets Groups Projects
Commit 4f43b51a authored by Nils Kohl's avatar Nils Kohl :full_moon_with_face: Committed by Martin Bauer
Browse files

Improved support for arbitrary field classes.

- introduced AbstractField and AbstractAccess

Fixes #28
parent eec4dc4b
No related merge requests found
......@@ -13,7 +13,7 @@ from pystencils.sympyextensions import is_integer_sequence
import pickle
import hashlib
__all__ = ['Field', 'fields', 'FieldType']
__all__ = ['Field', 'fields', 'FieldType', 'AbstractField']
def fields(description=None, index_dimensions=0, layout=None, **kwargs):
......@@ -116,7 +116,13 @@ class FieldType(Enum):
return field.field_type == FieldType.CUSTOM
class Field:
class AbstractField:
class AbstractAccess:
pass
class Field(AbstractField):
"""
With fields one can formulate stencil-like update rules on structured grids.
This Field class knows about the dimension, memory layout (strides) and optionally about the size of an array.
......@@ -394,7 +400,7 @@ class Field:
return self.hashable_contents() == other.hashable_contents()
# noinspection PyAttributeOutsideInit,PyUnresolvedReferences
class Access(sp.Symbol):
class Access(sp.Symbol, AbstractField.AbstractAccess):
"""Class representing a relative access into a `Field`.
This class behaves like a normal sympy Symbol, it is actually derived from it. One can built up
......
......@@ -9,7 +9,7 @@ from sympy.logic.boolalg import Boolean
from sympy.tensor import IndexedBase
from pystencils.simp.assignment_collection import AssignmentCollection
from pystencils.assignment import Assignment
from pystencils.field import Field, FieldType
from pystencils.field import AbstractField, FieldType, Field
from pystencils.data_types import TypedSymbol, PointerType, StructType, get_base_type, reinterpret_cast_func, \
cast_func, pointer_arithmetic_func, get_type_of_expression, collate_types, create_type
from pystencils.kernelparameters import FieldPointerSymbol
......@@ -160,7 +160,7 @@ def make_loop_over_domain(body, function_name, iteration_slice=None, ghost_layer
:class:`LoopOverCoordinate` instance with nested loops, ordered according to field layouts
"""
# find correct ordering by inspecting participating FieldAccesses
field_accesses = body.atoms(Field.Access)
field_accesses = body.atoms(AbstractField.AbstractAccess)
field_accesses = {e for e in field_accesses if not e.is_absolute_access}
# exclude accesses to buffers from field_list, because buffers are treated separately
......@@ -353,7 +353,7 @@ def get_base_buffer_index(ast_node, loop_counters=None, loop_iterations=None):
loop_iterations = [(l.stop - l.start) / l.step for l in loops]
loop_counters = [l.loop_counter_symbol for l in loops]
field_accesses = ast_node.atoms(Field.Access)
field_accesses = ast_node.atoms(AbstractField.AbstractAccess)
buffer_accesses = {fa for fa in field_accesses if FieldType.is_buffer(fa.field)}
loop_counters = [v * len(buffer_accesses) for v in loop_counters]
......@@ -369,7 +369,7 @@ def get_base_buffer_index(ast_node, loop_counters=None, loop_iterations=None):
def resolve_buffer_accesses(ast_node, base_buffer_index, read_only_field_names=set()):
def visit_sympy_expr(expr, enclosing_block, sympy_assignment):
if isinstance(expr, Field.Access):
if isinstance(expr, AbstractField.AbstractAccess):
field_access = expr
# Do not apply transformation if field is not a buffer
......@@ -433,7 +433,7 @@ def resolve_field_accesses(ast_node, read_only_field_names=set(),
field_to_fixed_coordinates = OrderedDict(sorted(field_to_fixed_coordinates.items(), key=lambda pair: pair[0]))
def visit_sympy_expr(expr, enclosing_block, sympy_assignment):
if isinstance(expr, Field.Access):
if isinstance(expr, AbstractField.AbstractAccess):
field_access = expr
field = field_access.field
......@@ -654,12 +654,13 @@ def split_inner_loop(ast_node: ast.Node, symbol_groups):
if s in assignment_map: # if there is no assignment inside the loop body it is independent already
for new_symbol in assignment_map[s].rhs.atoms(sp.Symbol):
if type(new_symbol) is not Field.Access and new_symbol not in symbols_with_temporary_array:
if not isinstance(new_symbol, AbstractField.AbstractAccess) and \
new_symbol not in symbols_with_temporary_array:
symbols_to_process.append(new_symbol)
symbols_resolved.add(s)
for symbol in symbol_group:
if type(symbol) is not Field.Access:
if not isinstance(symbol, AbstractField.AbstractAccess):
assert type(symbol) is TypedSymbol
new_ts = TypedSymbol(symbol.name, PointerType(symbol.dtype))
symbols_with_temporary_array[symbol] = IndexedBase(new_ts, shape=(1,))[inner_loop.loop_counter_symbol]
......@@ -668,7 +669,7 @@ def split_inner_loop(ast_node: ast.Node, symbol_groups):
for assignment in inner_loop.body.args:
if assignment.lhs in symbols_resolved:
new_rhs = assignment.rhs.subs(symbols_with_temporary_array.items())
if type(assignment.lhs) is not Field.Access and assignment.lhs in symbol_group:
if not isinstance(assignment.lhs, AbstractField.AbstractAccess) and assignment.lhs in symbol_group:
assert type(assignment.lhs) is TypedSymbol
new_ts = TypedSymbol(assignment.lhs.name, PointerType(assignment.lhs.dtype))
new_lhs = IndexedBase(new_ts, shape=(1,))[inner_loop.loop_counter_symbol]
......@@ -792,7 +793,7 @@ class KernelConstraintsCheck:
def process_expression(self, rhs, type_constants=True):
self._update_accesses_rhs(rhs)
if isinstance(rhs, Field.Access):
if isinstance(rhs, AbstractField.AbstractAccess):
self.fields_read.add(rhs.field)
self.fields_read.update(rhs.indirect_addressing_fields)
return rhs
......@@ -822,13 +823,13 @@ class KernelConstraintsCheck:
def _process_lhs(self, lhs):
assert isinstance(lhs, sp.Symbol)
self._update_accesses_lhs(lhs)
if not isinstance(lhs, Field.Access) and not isinstance(lhs, TypedSymbol):
if not isinstance(lhs, AbstractField.AbstractAccess) and not isinstance(lhs, TypedSymbol):
return TypedSymbol(lhs.name, self._type_for_symbol[lhs.name])
else:
return lhs
def _update_accesses_lhs(self, lhs):
if isinstance(lhs, Field.Access):
if isinstance(lhs, AbstractField.AbstractAccess):
fai = self.FieldAndIndex(lhs.field, lhs.index)
self._field_writes[fai].add(lhs.offsets)
if len(self._field_writes[fai]) > 1:
......@@ -841,7 +842,7 @@ class KernelConstraintsCheck:
self.scopes.define_symbol(lhs)
def _update_accesses_rhs(self, rhs):
if isinstance(rhs, Field.Access) and self.check_independence_condition:
if isinstance(rhs, AbstractField.AbstractAccess) and self.check_independence_condition:
writes = self._field_writes[self.FieldAndIndex(rhs.field, rhs.index)]
for write_offset in writes:
assert len(writes) == 1
......
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment