diff --git a/src/pystencils_reco/resampling.py b/src/pystencils_reco/resampling.py index 5ac543f367f44bba403997d620f950be62cde486..be067291a2973c5fe53411c386172f9d7eb308f2 100644 --- a/src/pystencils_reco/resampling.py +++ b/src/pystencils_reco/resampling.py @@ -13,7 +13,9 @@ from collections.abc import Iterable import sympy import pystencils +from pystencils.autodiff import AdjointField from pystencils_reco import AssignmentCollection, crazy +import types @crazy @@ -23,13 +25,28 @@ def generic_spatial_matrix_transform(input_field, output_field, transform_matrix if inverse_matrix is None: inverse_matrix = transform_matrix.inv() + # output_coordinate = input_field.coordinate_transform.inv() @ ( + # inverse_matrix @ output_field.physical_coordinates_staggered) - input_field.coordinate_origin + output_coordinate = input_field.physical_to_index( + inverse_matrix @ output_field.physical_coordinates_staggered, staggered=False) + assignments = AssignmentCollection({ output_field.center(): - texture.at(input_field.coordinate_transform.inv() @ - (inverse_matrix @ output_field.physical_coordinates_staggered) - input_field.coordinate_origin) + texture.at(output_coordinate) }) - assignments.transform_matrix = transform_matrix + def create_autodiff(self, constant_fields=None): + assignments.transform_matrix = transform_matrix + + texture = pystencils.astnodes.TextureCachedField(AdjointField(output_field)) + output_coordinate = output_field.physical_to_index( + transform_matrix @ input_field.physical_coordinates_staggered, staggered=True) + backward_assignments = AssignmentCollection({ + AdjointField(input_field).center(): texture.at(output_coordinate) + }) + self._autodiff = pystencils.autodiff.AutoDiffOp(assignments, "", backward_assignments=backward_assignments) + + assignments._create_autodiff = types.MethodType(create_autodiff, assignments) return assignments @@ -54,7 +71,7 @@ def scale_transform(input_field, output_field, scaling_factor): def rotation_transform(input_field, output_field, rotation_angle, rotation_axis=None): if input_field.spatial_dimensions == 3: assert rotation_axis is not None, "You must specify a rotation_axis for 3d rotations!" - transform_matrix = getattr(sympy, 'rot_axis%i' % (rotation_axis+1))(rotation_angle) + transform_matrix = getattr(sympy, 'rot_axis%i' % (rotation_axis + 1))(rotation_angle) elif input_field.spatial_dimensions == 2: # 2d rotation is 3d rotation around 3rd axis transform_matrix = sympy.rot_axis3(rotation_angle)[:2, :2] diff --git a/tests/test_pytorch.py b/tests/test_pytorch.py index 983468173d57d920788fa74fb63942e45632cb73..8c1e376125c97e824cf52588948fabc1187b771b 100644 --- a/tests/test_pytorch.py +++ b/tests/test_pytorch.py @@ -8,6 +8,7 @@ """ import pystencils +import pystencils_reco.resampling from pystencils.autodiff import torch_tensor_from_field from pystencils_reco.filters import mean_filter from pystencils_reco.stencils import BallStencil @@ -54,9 +55,22 @@ def test_pytorch_from_tensors(): print(torch_op) +def test_texture(): + + x, y = pystencils.fields('x,y: float32[100,100]') + assignments = pystencils_reco.resampling.scale_transform(x, y, 2) + + x_tensor = torch_tensor_from_field(x, requires_grad=True, cuda=True) + y_tensor = torch_tensor_from_field(y, cuda=True) + kernel = assignments.create_pytorch_op(x=x_tensor, y=y_tensor) + print(assignments) + print(kernel) + + def main(): # test_pytorch() - test_pytorch_from_tensors() + # test_pytorch_from_tensors() + test_texture() if __name__ == '__main__': diff --git a/tests/test_resampling.py b/tests/test_resampling.py index e1fc0f9c5196ef61fb0bab374da4c64a9812feb5..28cb0f656aab48d3d57fcb32588fea3e5e1f3116 100644 --- a/tests/test_resampling.py +++ b/tests/test_resampling.py @@ -21,7 +21,7 @@ from pystencils_reco.resampling import rotation_transform, scale_transform def test_scaling(): for ndim in range(1, 5): - for scale in (0.5, [(s+1)*0.1 for s in range(ndim)]): + for scale in (0.5, [(s + 1) * 0.1 for s in range(ndim)]): x, y = pystencils.fields('x,y: float32[%id]' % ndim) transform = scale_transform(x, y, scale) print(transform) @@ -40,7 +40,7 @@ def test_rotation(): def test_scaling_compilation(): for ndim in range(1, 4): - for scale in (0.5, [(s+1)*0.1 for s in range(ndim)]): + for scale in (0.5, [(s + 1) * 0.1 for s in range(ndim)]): x, y = pystencils.fields('x,y: float32[%id]' % ndim) scale_transform(x, y, scale).compile('gpu') @@ -63,7 +63,7 @@ def test_scaling_visualize(): s = pystencils.data_types.TypedSymbol('s', 'float32') transform = scale_transform(x, y, s).compile('gpu') - test_image = 1-skimage.io.imread(join(dirname(__file__), "test_data", "test_vessel2d_mask.png"), as_gray=True) + test_image = 1 - skimage.io.imread(join(dirname(__file__), "test_data", "test_vessel2d_mask.png"), as_gray=True) test_image = np.ascontiguousarray(test_image, np.float32) test_image = to_gpu(test_image) tmp = zeros_like(test_image) @@ -82,11 +82,10 @@ def test_rotation_visualize(): s = pystencils.data_types.TypedSymbol('s', 'float32') transform = rotation_transform(x, y, s).compile('gpu') - test_image = 1 - skimage.io.imread(join(dirname(__file__), "test_data", "test_vessel2d_mask.png"), as_gray=True) + test_image = 1 - skimage.io.imread(join(dirname(__file__), "test_data", "test_vessel2d_mask.png"), as_gray=True) test_image = np.ascontiguousarray(test_image, np.float32) test_image = to_gpu(test_image) tmp = zeros_like(test_image) - print(transform.code) for s in (0.2, 0.5, 0.7, 1, 2): transform(x=test_image, y=tmp, s=s) @@ -98,7 +97,7 @@ def test_rotation_around_center_visualize(): import pyconrad.autoinit from pycuda.gpuarray import to_gpu, zeros_like - test_image = 1 - skimage.io.imread(join(dirname(__file__), "test_data", "test_vessel2d_mask.png"), as_gray=True) + test_image = 1 - skimage.io.imread(join(dirname(__file__), "test_data", "test_vessel2d_mask.png"), as_gray=True) test_image = np.ascontiguousarray(test_image, np.float32) test_image = to_gpu(test_image) @@ -109,7 +108,6 @@ def test_rotation_around_center_visualize(): print(x.coordinate_origin) s = pystencils.data_types.TypedSymbol('s', 'float32') transform = rotation_transform(x, y, s).compile('gpu') - print(transform.code) for s in (0, 0.2, 0.5, 0.7, 1, 2): transform(x=test_image, y=tmp, s=s)