Skip to content
Snippets Groups Projects
Commit f1caa29c authored by Stephan Seitz's avatar Stephan Seitz
Browse files

Use automatic calculation of adjoint in resampling

parent ea23f11b
No related branches found
No related tags found
No related merge requests found
...@@ -56,7 +56,7 @@ def generic_spatial_matrix_transform(input_field, ...@@ -56,7 +56,7 @@ def generic_spatial_matrix_transform(input_field,
self._autodiff = pystencils.autodiff.AutoDiffOp( self._autodiff = pystencils.autodiff.AutoDiffOp(
assignments, "", backward_assignments=backward_assignments, **kwargs) assignments, "", backward_assignments=backward_assignments, **kwargs)
assignments._create_autodiff = types.MethodType(create_autodiff, assignments) # assignments._create_autodiff = types.MethodType(create_autodiff, assignments)
return assignments return assignments
...@@ -86,6 +86,9 @@ def rotation_transform(input_field, ...@@ -86,6 +86,9 @@ def rotation_transform(input_field,
rotation_angle, rotation_angle,
rotation_axis=None, rotation_axis=None,
interpolation_mode='linear'): interpolation_mode='linear'):
if isinstance(rotation_angle, pystencils.Field):
rotation_angle = rotation_angle.center
if input_field.spatial_dimensions == 3: if input_field.spatial_dimensions == 3:
assert rotation_axis is not None, "You must specify a rotation_axis for 3d rotations!" assert rotation_axis is not None, "You must specify a rotation_axis for 3d rotations!"
transform_matrix = getattr(sympy, 'rot_axis%i' % (rotation_axis + 1))(rotation_angle) transform_matrix = getattr(sympy, 'rot_axis%i' % (rotation_axis + 1))(rotation_angle)
...@@ -124,10 +127,10 @@ def translate(input_field: pystencils.Field, ...@@ -124,10 +127,10 @@ def translate(input_field: pystencils.Field,
interpolation_mode='linear', interpolation_mode='linear',
allow_spatial_derivatives=True): allow_spatial_derivatives=True):
def create_autodiff(self, constant_fields=None, **kwargs): # def create_autodiff(self, constant_fields=None, **kwargs):
backward_assignments = translate(AdjointField(output_field), AdjointField(input_field), -translation) # backward_assignments = translate(AdjointField(output_field), AdjointField(input_field), -translation)
self._autodiff = pystencils.autodiff.AutoDiffOp( # self._autodiff = pystencils.autodiff.AutoDiffOp(
assignments, "", backward_assignments=backward_assignments, **kwargs) # assignments, "", backward_assignments=backward_assignments, **kwargs)
if isinstance(translation, pystencils.Field): if isinstance(translation, pystencils.Field):
translation = translation.center_vector translation = translation.center_vector
...@@ -138,8 +141,8 @@ def translate(input_field: pystencils.Field, ...@@ -138,8 +141,8 @@ def translate(input_field: pystencils.Field,
input_field.physical_to_index(output_field.physical_coordinates - translation), interpolation_mode) input_field.physical_to_index(output_field.physical_coordinates - translation), interpolation_mode)
}) })
if not allow_spatial_derivatives: # if not allow_spatial_derivatives:
assignments._create_autodiff = types.MethodType(create_autodiff, assignments) # assignments._create_autodiff = types.MethodType(create_autodiff, assignments)
return assignments return assignments
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Please register or to comment