Skip to content
Snippets Groups Projects
Commit f1caa29c authored by Stephan Seitz's avatar Stephan Seitz
Browse files

Use automatic calculation of adjoint in resampling

parent ea23f11b
Branches
Tags
No related merge requests found
......@@ -56,7 +56,7 @@ def generic_spatial_matrix_transform(input_field,
self._autodiff = pystencils.autodiff.AutoDiffOp(
assignments, "", backward_assignments=backward_assignments, **kwargs)
assignments._create_autodiff = types.MethodType(create_autodiff, assignments)
# assignments._create_autodiff = types.MethodType(create_autodiff, assignments)
return assignments
......@@ -86,6 +86,9 @@ def rotation_transform(input_field,
rotation_angle,
rotation_axis=None,
interpolation_mode='linear'):
if isinstance(rotation_angle, pystencils.Field):
rotation_angle = rotation_angle.center
if input_field.spatial_dimensions == 3:
assert rotation_axis is not None, "You must specify a rotation_axis for 3d rotations!"
transform_matrix = getattr(sympy, 'rot_axis%i' % (rotation_axis + 1))(rotation_angle)
......@@ -124,10 +127,10 @@ def translate(input_field: pystencils.Field,
interpolation_mode='linear',
allow_spatial_derivatives=True):
def create_autodiff(self, constant_fields=None, **kwargs):
backward_assignments = translate(AdjointField(output_field), AdjointField(input_field), -translation)
self._autodiff = pystencils.autodiff.AutoDiffOp(
assignments, "", backward_assignments=backward_assignments, **kwargs)
# def create_autodiff(self, constant_fields=None, **kwargs):
# backward_assignments = translate(AdjointField(output_field), AdjointField(input_field), -translation)
# self._autodiff = pystencils.autodiff.AutoDiffOp(
# assignments, "", backward_assignments=backward_assignments, **kwargs)
if isinstance(translation, pystencils.Field):
translation = translation.center_vector
......@@ -138,8 +141,8 @@ def translate(input_field: pystencils.Field,
input_field.physical_to_index(output_field.physical_coordinates - translation), interpolation_mode)
})
if not allow_spatial_derivatives:
assignments._create_autodiff = types.MethodType(create_autodiff, assignments)
# if not allow_spatial_derivatives:
# assignments._create_autodiff = types.MethodType(create_autodiff, assignments)
return assignments
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Please register or to comment