Skip to content
Snippets Groups Projects
Commit 5955f84b authored by Stephan Seitz's avatar Stephan Seitz
Browse files

Apply dirty hack when adjoint calculation depends on size of forward

fields but not on the forward fields them selves (just take size of
corresponding adjoint fields)
parent 3e5b4877
No related branches found
No related tags found
No related merge requests found
Pipeline #28021 failed
......@@ -151,12 +151,11 @@ class DestructuringBindingsForFieldClass(Node):
undefined_field_symbols = self.symbols_defined
corresponding_field_names = {s.field_name for s in undefined_field_symbols if hasattr(s, 'field_name')}
corresponding_field_names |= {s.field_names[0] for s in undefined_field_symbols if hasattr(s, 'field_names')}
return {TypedSymbol(f,
self.CLASS_NAME_TEMPLATE.format(dtype=field_map[f].dtype,
ndim=field_map[f].ndim) + ('&'
if self.ARGS_AS_REFERENCE
else ''))
for f in corresponding_field_names} | (self.body.undefined_symbols - undefined_field_symbols)
return {TypedSymbol(f, self.CLASS_NAME_TEMPLATE.format(dtype=(field_map.get(f) or field_map.get('diff' + f)).dtype,
ndim=(field_map.get(f) or field_map.get('diff' + f)).ndim) + ('&'
if self.ARGS_AS_REFERENCE
else ''))
for f in corresponding_field_names} | (self.body.undefined_symbols - undefined_field_symbols)
def subs(self, subs_dict) -> None:
"""Inplace! substitute, similar to sympy's but modifies the AST inplace."""
......
......@@ -123,9 +123,12 @@ class FrameworkIntegrationPrinter(pystencils.backends.cbackend.CBackend):
+ (node.field_suffix if hasattr(node, 'field_suffix') else ''),
node.CLASS_TO_MEMBER_DICT[u.__class__].format(
dtype=(u.dtype.base_type if type(u) == FieldPointerSymbol
else fields_dtype[u.field_name
if hasattr(u, 'field_name')
else u.field_names[0]]),
else ((fields_dtype.get(u.field_name
if hasattr(u, 'field_name')
else u.field_names[0]))
or (fields_dtype.get('diff' + u.field_name
if hasattr(u, 'field_name')
else 'diff' + u.field_names[0])))),
field_name=(u.field_name if hasattr(u, "field_name") else ""),
dim=("" if type(u) == FieldPointerSymbol else u.coordinate),
dim_letter=("" if type(u) == FieldPointerSymbol else 'xyz'[u.coordinate])
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Please register or to comment