diff --git a/src/pystencils_reco/resampling.py b/src/pystencils_reco/resampling.py index 74712acd45b766f03df07382873edf04b57297fc..0021c990ae8737b0924572da15e7dd23b3745b66 100644 --- a/src/pystencils_reco/resampling.py +++ b/src/pystencils_reco/resampling.py @@ -8,7 +8,6 @@ Implements common resampling operations like rotations and scalings """ -import itertools import types from collections.abc import Iterable @@ -17,6 +16,7 @@ import sympy import pystencils import pystencils.autodiff from pystencils.autodiff import AdjointField +from pystencils.data_types import cast_func, create_type from pystencils_reco import AssignmentCollection, crazy @@ -122,10 +122,46 @@ def translate(input_field: pystencils.Field, translation, interpolation_mode='linear'): - return { - output_field.center: input_field.interpolated_access( - input_field.physical_to_index(output_field.physical_coordinates - translation), interpolation_mode) - } + def create_autodiff(self, constant_fields=None, **kwargs): + backward_assignments = translate(AdjointField(output_field), AdjointField(input_field), -translation) + self._autodiff = pystencils.autodiff.AutoDiffOp( + assignments, "", backward_assignments=backward_assignments, **kwargs) + + if isinstance(translation, pystencils.Field): + translation = translation.center_vector + + assignments = AssignmentCollection( + { + output_field.center: input_field.interpolated_access( + input_field.physical_to_index(output_field.physical_coordinates - translation), interpolation_mode) + }) + assignments._create_autodiff = types.MethodType(create_autodiff, assignments) + return assignments + + +@crazy +def upsample(input: {'field_type': pystencils.field.FieldType.CUSTOM}, + result, + factor): + + ndim = input.spatial_dimensions + here = pystencils.x_vector(ndim) + + assignments = AssignmentCollection( + {result.center: + pystencils.astnodes.ConditionalFieldAccess( + input.absolute_access(tuple(cast_func(sympy.S(1) / factor * h, + create_type('int64')) for h in here), ()), + sympy.Or(*[s % cast_func(factor, 'int64') > 0 for s in here])) + }) + + def create_autodiff(self, constant_fields=None, **kwargs): + backward_assignments = downsample(AdjointField(result), AdjointField(input), factor) + self._autodiff = pystencils.autodiff.AutoDiffOp( + assignments, "", backward_assignments=backward_assignments, **kwargs) + + assignments._create_autodiff = types.MethodType(create_autodiff, assignments) + return assignments @crazy @@ -137,5 +173,14 @@ def downsample(input: {'field_type': pystencils.field.FieldType.CUSTOM}, ndim = input.spatial_dimensions - return {result.center, - input.absolute_access(factor * pystencils.x_vector(ndim), ())} + assignments = AssignmentCollection({result.center: + input.absolute_access(factor * pystencils.x_vector(ndim), ())}) + + def create_autodiff(self, constant_fields=None, **kwargs): + backward_assignments = upsample(AdjointField(result), AdjointField(input), factor) + self._autodiff = pystencils.autodiff.AutoDiffOp( + assignments, "", backward_assignments=backward_assignments, **kwargs) + + assignments._create_autodiff = types.MethodType(create_autodiff, assignments) + + return assignments diff --git a/tests/test_superresolution.py b/tests/test_superresolution.py index 19463ffc3d93b8d1835225fb6fda8b840c4eeaf6..00becc3b0a70c3378a0d2db5e4e827bb4d1617dc 100644 --- a/tests/test_superresolution.py +++ b/tests/test_superresolution.py @@ -7,11 +7,13 @@ """ """ +from os.path import dirname, join + import numpy as np +import skimage.io import pystencils -from pystencils_reco.resampling import downsample, scale_transform -from pystencils_reco.unet import max_pooling +from pystencils_reco.resampling import downsample, scale_transform, translate, upsample try: import pyconrad.autoinit @@ -22,7 +24,7 @@ except Exception: def test_superresolution(): - x, y = np.random.rand(20, 10), np.zeros((20, 10)) + x, y = np.random.rand(20, 10), np.zeros((20, 10)) kernel = scale_transform(x, y, 0.5).compile() print(pystencils.show_code(kernel)) @@ -34,10 +36,32 @@ def test_superresolution(): def test_downsample(): shape = (20, 10) - x, y = np.random.rand(*shape), np.zeros(shape) + x, y = np.random.rand(*shape), np.zeros(tuple(s // 2 for s in shape)) kernel = downsample(x, y, 2).compile() print(pystencils.show_code(kernel)) kernel() pyconrad.show_everything() + + +def test_warp(): + import torch + NUM_LENNAS = 5 + perturbation = 0.1 + + lenna_file = join(dirname(__file__), "test_data", "lenna.png") + lenna = skimage.io.imread(lenna_file, as_gray=True).astype(np.float32) + + warp_vectors = list(perturbation * torch.randn(lenna.shape + (2,)) for _ in range(NUM_LENNAS)) + + warped = [torch.zeros(lenna.shape) for _ in range(NUM_LENNAS)] + + warp_kernel = translate(lenna, warped[0], pystencils.autodiff.ArrayWrapper( + warp_vectors[0], index_dimensions=1), interpolation_mode='linear').compile() + + for i in range(len(warped)): + warp_kernel(lenna[i], warped[i], warp_vectors[i]) + + +test_warp()