diff --git a/pystencils/astnodes.py b/pystencils/astnodes.py index b9c02cff7a4cde369a29e1470d89bae1fa8bf78c..2553db45ab4741c432aa928b583613a059d6c8e7 100644 --- a/pystencils/astnodes.py +++ b/pystencils/astnodes.py @@ -1,5 +1,6 @@ from typing import Any, List, Optional, Sequence, Set, Union +import jinja2 import sympy as sp from pystencils.data_types import TypedSymbol, cast_func, create_type @@ -661,7 +662,12 @@ class DestructuringBindingsForFieldClass(Node): FieldShapeSymbol: "shape", FieldStrideSymbol: "stride" } - CLASS_NAME = "Field" + CLASS_NAME_TEMPLATE = jinja2.Template("Field<{{ dtype }}, {{ ndim }}>") + + @property + def fields_accessed(self) -> Set['ResolvedFieldAccess']: + """Set of Field instances: fields which are accessed inside this kernel function""" + return set(o.field for o in self.atoms(ResolvedFieldAccess)) def __init__(self, body): super(DestructuringBindingsForFieldClass, self).__init__() @@ -682,10 +688,12 @@ class DestructuringBindingsForFieldClass(Node): @property def undefined_symbols(self) -> Set[sp.Symbol]: + field_map = {f.name: f for f in self.fields_accessed} undefined_field_symbols = self.symbols_defined corresponding_field_names = {s.field_name for s in undefined_field_symbols if hasattr(s, 'field_name')} corresponding_field_names |= {s.field_names[0] for s in undefined_field_symbols if hasattr(s, 'field_names')} - return {TypedSymbol(f, self.CLASS_NAME + '&') for f in corresponding_field_names} | \ + return {TypedSymbol(f, self.CLASS_NAME_TEMPLATE.render(dtype=field_map[f].dtype, ndim=field_map[f].ndim) + '&') + for f in corresponding_field_names} | \ (self.body.undefined_symbols - undefined_field_symbols) def subs(self, subs_dict) -> None: diff --git a/pystencils/data_types.py b/pystencils/data_types.py index 7bdc9d340664fa20178ec941cc9e5305d99fd02c..a602b8eb558d93493ee762ad24e22c4e72866edc 100644 --- a/pystencils/data_types.py +++ b/pystencils/data_types.py @@ -81,7 +81,7 @@ class TypedSymbol(sp.Symbol): obj = super(TypedSymbol, cls).__xnew__(cls, name) try: obj._dtype = create_type(dtype) - except TypeError: + except (TypeError, ValueError): # on error keep the string obj._dtype = dtype return obj diff --git a/pystencils/field.py b/pystencils/field.py index 82ece20709f1b4b5f0114dbb409331b077541b1f..731755ff6a33ea8b8fa6542bced0308a8b58aa48 100644 --- a/pystencils/field.py +++ b/pystencils/field.py @@ -302,6 +302,10 @@ class Field(AbstractField): def index_dimensions(self) -> int: return len(self.shape) - len(self._layout) + @property + def ndim(self) -> int: + return len(self.shape) + @property def layout(self): return self._layout