Skip to content
Snippets Groups Projects
Commit 59b6c5c4 authored by Stephan Seitz's avatar Stephan Seitz
Browse files

Form backward_assignments with spatial derivatives

parent db815251
Branches
Tags
No related merge requests found
Pipeline #20944 failed
...@@ -121,7 +121,8 @@ def resample(input_field, output_field, interpolation_mode='linear'): ...@@ -121,7 +121,8 @@ def resample(input_field, output_field, interpolation_mode='linear'):
def translate(input_field: pystencils.Field, def translate(input_field: pystencils.Field,
output_field: pystencils.Field, output_field: pystencils.Field,
translation, translation,
interpolation_mode='linear'): interpolation_mode='linear',
allow_spatial_derivatives=True):
def create_autodiff(self, constant_fields=None, **kwargs): def create_autodiff(self, constant_fields=None, **kwargs):
backward_assignments = translate(AdjointField(output_field), AdjointField(input_field), -translation) backward_assignments = translate(AdjointField(output_field), AdjointField(input_field), -translation)
...@@ -136,7 +137,9 @@ def translate(input_field: pystencils.Field, ...@@ -136,7 +137,9 @@ def translate(input_field: pystencils.Field,
output_field.center: input_field.interpolated_access( output_field.center: input_field.interpolated_access(
input_field.physical_to_index(output_field.physical_coordinates - translation), interpolation_mode) input_field.physical_to_index(output_field.physical_coordinates - translation), interpolation_mode)
}) })
assignments._create_autodiff = types.MethodType(create_autodiff, assignments)
if not allow_spatial_derivatives:
assignments._create_autodiff = types.MethodType(create_autodiff, assignments)
return assignments return assignments
......
...@@ -93,7 +93,7 @@ def test_polar_transform2(): ...@@ -93,7 +93,7 @@ def test_polar_transform2():
class PolarTransform(sympy.Function): class PolarTransform(sympy.Function):
def eval(args): def eval(args):
return sympy.Matrix( return sympy.Matrix(
(args.norm(), sympy.atan2(args[1]-x.shape[1]/2, args[0]-x.shape[0]/2) / sympy.pi * x.shape[1]/2)) (args.norm(), sympy.atan2(args[1] - x.shape[1] / 2, args[0] - x.shape[0] / 2) / sympy.pi * x.shape[1] / 2))
x.set_coordinate_origin_to_field_center() x.set_coordinate_origin_to_field_center()
y.coordinate_transform = PolarTransform y.coordinate_transform = PolarTransform
...@@ -117,11 +117,11 @@ def test_polar_inverted_transform(): ...@@ -117,11 +117,11 @@ def test_polar_inverted_transform():
class PolarTransform(sympy.Function): class PolarTransform(sympy.Function):
def eval(args): def eval(args):
return sympy.Matrix( return sympy.Matrix(
(args.norm(), sympy.atan2(args[1]-x.shape[1]/2, args[0]-x.shape[0]/2) / sympy.pi * x.shape[1]/2)) (args.norm(), sympy.atan2(args[1] - x.shape[1] / 2, args[0] - x.shape[0] / 2) / sympy.pi * x.shape[1] / 2))
def inv(): def inv():
return lambda l: (sympy.Matrix((sympy.cos(l[1] * sympy.pi / x.shape[1]*2) * l[0], return lambda l: (sympy.Matrix((sympy.cos(l[1] * sympy.pi / x.shape[1] * 2) * l[0],
sympy.sin(l[1] * sympy.pi / x.shape[1]*2) * l[0])) sympy.sin(l[1] * sympy.pi / x.shape[1] * 2) * l[0]))
+ sympy.Matrix(x.shape) * 0.5) + sympy.Matrix(x.shape) * 0.5)
lenna_file = join(dirname(__file__), "test_data", "lenna.png") lenna_file = join(dirname(__file__), "test_data", "lenna.png")
...@@ -251,3 +251,16 @@ def test_motion_model2(): ...@@ -251,3 +251,16 @@ def test_motion_model2():
# while True: # while True:
# sleep(100) # sleep(100)
def test_spatial_derivative():
x, y = pystencils.fields('x, y: float32[2d]')
tx, ty = pystencils.fields('t_x, t_y: float32[2d]')
assignments = pystencils.AssignmentCollection({
y.center: x.interpolated_access((tx.center, 2 * ty.center))
})
backward_assignments = pystencils.autodiff.create_backward_assignments(assignments)
print("backward_assignments: " + str(backward_assignments))
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Please register or to comment