Skip to content
Snippets Groups Projects
Commit 42c29c46 authored by Stephan Seitz's avatar Stephan Seitz
Browse files

Automatically generate adjoint for matrix transforms

parent df040780
Branches
Tags
No related merge requests found
......@@ -13,7 +13,9 @@ from collections.abc import Iterable
import sympy
import pystencils
from pystencils.autodiff import AdjointField
from pystencils_reco import AssignmentCollection, crazy
import types
@crazy
......@@ -23,13 +25,28 @@ def generic_spatial_matrix_transform(input_field, output_field, transform_matrix
if inverse_matrix is None:
inverse_matrix = transform_matrix.inv()
# output_coordinate = input_field.coordinate_transform.inv() @ (
# inverse_matrix @ output_field.physical_coordinates_staggered) - input_field.coordinate_origin
output_coordinate = input_field.physical_to_index(
inverse_matrix @ output_field.physical_coordinates_staggered, staggered=False)
assignments = AssignmentCollection({
output_field.center():
texture.at(input_field.coordinate_transform.inv() @
(inverse_matrix @ output_field.physical_coordinates_staggered) - input_field.coordinate_origin)
texture.at(output_coordinate)
})
assignments.transform_matrix = transform_matrix
def create_autodiff(self, constant_fields=None):
assignments.transform_matrix = transform_matrix
texture = pystencils.astnodes.TextureCachedField(AdjointField(output_field))
output_coordinate = output_field.physical_to_index(
transform_matrix @ input_field.physical_coordinates_staggered, staggered=True)
backward_assignments = AssignmentCollection({
AdjointField(input_field).center(): texture.at(output_coordinate)
})
self._autodiff = pystencils.autodiff.AutoDiffOp(assignments, "", backward_assignments=backward_assignments)
assignments._create_autodiff = types.MethodType(create_autodiff, assignments)
return assignments
......@@ -54,7 +71,7 @@ def scale_transform(input_field, output_field, scaling_factor):
def rotation_transform(input_field, output_field, rotation_angle, rotation_axis=None):
if input_field.spatial_dimensions == 3:
assert rotation_axis is not None, "You must specify a rotation_axis for 3d rotations!"
transform_matrix = getattr(sympy, 'rot_axis%i' % (rotation_axis+1))(rotation_angle)
transform_matrix = getattr(sympy, 'rot_axis%i' % (rotation_axis + 1))(rotation_angle)
elif input_field.spatial_dimensions == 2:
# 2d rotation is 3d rotation around 3rd axis
transform_matrix = sympy.rot_axis3(rotation_angle)[:2, :2]
......
......@@ -8,6 +8,7 @@
"""
import pystencils
import pystencils_reco.resampling
from pystencils.autodiff import torch_tensor_from_field
from pystencils_reco.filters import mean_filter
from pystencils_reco.stencils import BallStencil
......@@ -54,9 +55,22 @@ def test_pytorch_from_tensors():
print(torch_op)
def test_texture():
x, y = pystencils.fields('x,y: float32[100,100]')
assignments = pystencils_reco.resampling.scale_transform(x, y, 2)
x_tensor = torch_tensor_from_field(x, requires_grad=True, cuda=True)
y_tensor = torch_tensor_from_field(y, cuda=True)
kernel = assignments.create_pytorch_op(x=x_tensor, y=y_tensor)
print(assignments)
print(kernel)
def main():
# test_pytorch()
test_pytorch_from_tensors()
# test_pytorch_from_tensors()
test_texture()
if __name__ == '__main__':
......
......@@ -21,7 +21,7 @@ from pystencils_reco.resampling import rotation_transform, scale_transform
def test_scaling():
for ndim in range(1, 5):
for scale in (0.5, [(s+1)*0.1 for s in range(ndim)]):
for scale in (0.5, [(s + 1) * 0.1 for s in range(ndim)]):
x, y = pystencils.fields('x,y: float32[%id]' % ndim)
transform = scale_transform(x, y, scale)
print(transform)
......@@ -40,7 +40,7 @@ def test_rotation():
def test_scaling_compilation():
for ndim in range(1, 4):
for scale in (0.5, [(s+1)*0.1 for s in range(ndim)]):
for scale in (0.5, [(s + 1) * 0.1 for s in range(ndim)]):
x, y = pystencils.fields('x,y: float32[%id]' % ndim)
scale_transform(x, y, scale).compile('gpu')
......@@ -63,7 +63,7 @@ def test_scaling_visualize():
s = pystencils.data_types.TypedSymbol('s', 'float32')
transform = scale_transform(x, y, s).compile('gpu')
test_image = 1-skimage.io.imread(join(dirname(__file__), "test_data", "test_vessel2d_mask.png"), as_gray=True)
test_image = 1 - skimage.io.imread(join(dirname(__file__), "test_data", "test_vessel2d_mask.png"), as_gray=True)
test_image = np.ascontiguousarray(test_image, np.float32)
test_image = to_gpu(test_image)
tmp = zeros_like(test_image)
......@@ -82,11 +82,10 @@ def test_rotation_visualize():
s = pystencils.data_types.TypedSymbol('s', 'float32')
transform = rotation_transform(x, y, s).compile('gpu')
test_image = 1 - skimage.io.imread(join(dirname(__file__), "test_data", "test_vessel2d_mask.png"), as_gray=True)
test_image = 1 - skimage.io.imread(join(dirname(__file__), "test_data", "test_vessel2d_mask.png"), as_gray=True)
test_image = np.ascontiguousarray(test_image, np.float32)
test_image = to_gpu(test_image)
tmp = zeros_like(test_image)
print(transform.code)
for s in (0.2, 0.5, 0.7, 1, 2):
transform(x=test_image, y=tmp, s=s)
......@@ -98,7 +97,7 @@ def test_rotation_around_center_visualize():
import pyconrad.autoinit
from pycuda.gpuarray import to_gpu, zeros_like
test_image = 1 - skimage.io.imread(join(dirname(__file__), "test_data", "test_vessel2d_mask.png"), as_gray=True)
test_image = 1 - skimage.io.imread(join(dirname(__file__), "test_data", "test_vessel2d_mask.png"), as_gray=True)
test_image = np.ascontiguousarray(test_image, np.float32)
test_image = to_gpu(test_image)
......@@ -109,7 +108,6 @@ def test_rotation_around_center_visualize():
print(x.coordinate_origin)
s = pystencils.data_types.TypedSymbol('s', 'float32')
transform = rotation_transform(x, y, s).compile('gpu')
print(transform.code)
for s in (0, 0.2, 0.5, 0.7, 1, 2):
transform(x=test_image, y=tmp, s=s)
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Please register or to comment