diff --git a/src/pystencils_autodiff/_adjoint_field.py b/src/pystencils_autodiff/_adjoint_field.py index c964f65d0ecd9233a978bd34a80762bb46acb546..b76d2d418b97a325caa996f9ea829e83cf3e8e05 100644 --- a/src/pystencils_autodiff/_adjoint_field.py +++ b/src/pystencils_autodiff/_adjoint_field.py @@ -11,7 +11,9 @@ class AdjointField(pystencils.Field): def __init__(self, forward_field, name_prefix='diff'): new_name = name_prefix + forward_field.name - super().__init__(new_name, forward_field.field_type, forward_field._dtype, + super().__init__(new_name, pystencils.FieldType.GENERIC + if forward_field.field_type != pystencils.FieldType.BUFFER + else pystencils.FieldType.BUFFER, forward_field._dtype, forward_field._layout, forward_field.shape, forward_field.strides) self.corresponding_forward_field = forward_field self.name_prefix = name_prefix