From 40c7e207e012457dd6e9201b2e01733f594d12f7 Mon Sep 17 00:00:00 2001
From: Stephan Seitz <stephan.seitz@fau.de>
Date: Sun, 24 Nov 2019 15:02:12 +0100
Subject: [PATCH] Commit stuff

---
 src/pystencils_reco/__init__.py               |   3 +-
 src/pystencils_reco/_assignment_collection.py |  25 ++++-
 src/pystencils_reco/_crazy_decorator.py       |  58 +++++++---
 src/pystencils_reco/numpy_array_handler.py    |  49 +++++++++
 src/pystencils_reco/registration.py           |  70 ++++++++++++
 src/pystencils_reco/resampling.py             |  12 ++
 src/pystencils_reco/vesselness.py             |  49 +++++++++
 tests/test_registration.py                    |  30 +++++
 tests/test_tensorflow.py                      |  45 ++++++++
 tests/test_vesselness.py                      | 103 ++++++++++++++++++
 10 files changed, 428 insertions(+), 16 deletions(-)
 create mode 100644 src/pystencils_reco/numpy_array_handler.py
 create mode 100644 src/pystencils_reco/registration.py
 create mode 100644 src/pystencils_reco/vesselness.py
 create mode 100644 tests/test_registration.py
 create mode 100644 tests/test_tensorflow.py
 create mode 100644 tests/test_vesselness.py

diff --git a/src/pystencils_reco/__init__.py b/src/pystencils_reco/__init__.py
index 4d60143..c0d4bcc 100644
--- a/src/pystencils_reco/__init__.py
+++ b/src/pystencils_reco/__init__.py
@@ -2,7 +2,7 @@
 from pkg_resources import DistributionNotFound, get_distribution
 
 from pystencils_reco._assignment_collection import AssignmentCollection
-from pystencils_reco._crazy_decorator import crazy
+from pystencils_reco._crazy_decorator import crazy, fixed_boundary_handling
 from pystencils_reco._projective_matrix import ProjectiveMatrix
 from pystencils_reco._typed_symbols import matrix_symbols, typed_symbols
 
@@ -18,6 +18,7 @@ finally:
 
 __all__ = ['AssignmentCollection',
            'crazy',
+           'crazy_fixed_boundary_handling'
            'ProjectiveMatrix',
            'matrix_symbols',
            'typed_symbols']
diff --git a/src/pystencils_reco/_assignment_collection.py b/src/pystencils_reco/_assignment_collection.py
index d11bd7b..78d750d 100644
--- a/src/pystencils_reco/_assignment_collection.py
+++ b/src/pystencils_reco/_assignment_collection.py
@@ -13,6 +13,8 @@ from functools import partial
 from itertools import chain
 
 import pystencils
+from pystencils.cache import disk_cache
+from pystencils_autodiff.tensorflow_jit import _hash
 
 
 class NdArrayType(str, Enum):
@@ -51,6 +53,12 @@ def get_type_of_arrays(*args):
     return NdArrayType.UNKNOWN
 
 
+@disk_cache
+def get_module_file(assignments, target):
+    kernel = assignments._create_ml_op('torch_native', target)
+    return kernel.ast.module_name
+
+
 class AssignmentCollection(pystencils.AssignmentCollection):
     """
     A high-level wrapper around pystencils.AssignmentCollection that provides some convenience methods
@@ -68,8 +76,10 @@ class AssignmentCollection(pystencils.AssignmentCollection):
 
         assignments = pystencils.AssignmentCollection(assignments, {})
         if perform_cse:
-            main_assignments = [a for a in assignments if isinstance(a.lhs, pystencils.Field.Access)]
-            subexpressions = [a for a in assignments if not isinstance(a.lhs, pystencils.Field.Access)]
+            main_assignments = [a for a in assignments if not hasattr(
+                a, 'lhs') or isinstance(a.lhs, pystencils.Field.Access)]
+            subexpressions = [a for a in assignments if hasattr(
+                a, 'lhs') and not isinstance(a.lhs, pystencils.Field.Access)]
             assignments = pystencils.AssignmentCollection(main_assignments, subexpressions)
             assignments = pystencils.simp.sympy_cse(assignments)
         super(AssignmentCollection, self).__init__(assignments.all_assignments, {}, *args, **kwargs)
@@ -78,6 +88,17 @@ class AssignmentCollection(pystencils.AssignmentCollection):
         self._autodiff = None
         self.kernel = None
 
+    @property
+    def reproducible_hash(self):
+        fields = sorted(self.free_fields | self.bound_fields, key=lambda f: f.name)
+        hashable_contents = [f.hashable_contents() for f in fields]
+        hash_str = str(self)
+        hash_str += str(hashable_contents)
+        return _hash(hash_str.encode()).hexdigest()
+
+    def __getstate__(self):
+        return self.reproducible_hash
+
     def compile(self, target=None, *args, **kwargs):
         """Convenience wrapper for pystencils.create_kernel(...).compile()
         See :func: ~pystencils.create_kernel
diff --git a/src/pystencils_reco/_crazy_decorator.py b/src/pystencils_reco/_crazy_decorator.py
index 143c6ec..ae2f9a5 100644
--- a/src/pystencils_reco/_crazy_decorator.py
+++ b/src/pystencils_reco/_crazy_decorator.py
@@ -14,8 +14,10 @@ import inspect
 import sympy
 
 import pystencils
+import pystencils_autodiff.transformations
 import pystencils_reco
-from pystencils_autodiff.field_tensor_conversion import ArrayWithIndexDimensions
+from pystencils.cache import disk_cache
+from pystencils_autodiff.field_tensor_conversion import ArrayWrapper
 
 
 def crazy(function):
@@ -41,11 +43,11 @@ def crazy(function):
 
         kwargs.update({arg_names[i]: a for i, a in enumerate(args)})
 
-        if (isinstance(assignments, (pystencils.AssignmentCollection, list)) and
+        if (isinstance(assignments, (pystencils.AssignmentCollection, list, dict)) and
                 not isinstance(assignments, pystencils_reco.AssignmentCollection)):
             assignments = pystencils_reco.AssignmentCollection(assignments)
 
-        kwargs = {k: v if not isinstance(v, ArrayWithIndexDimensions) else v.array for k, v in kwargs.items()}
+        kwargs = {k: v if not isinstance(v, ArrayWrapper) else v.array for k, v in kwargs.items()}
 
         try:
             assignments.kwargs = kwargs
@@ -57,17 +59,47 @@ def crazy(function):
     return wrapper
 
 
+def fixed_boundary_handling(function):
+
+    @functools.wraps(function)
+    def wrapper(*args, **kwargs):
+
+        assignments = function(*args, **kwargs)
+        kwargs = assignments.__dict__.get('kwargs', {})
+        args = assignments.__dict__.get('args', {})
+
+        if (isinstance(assignments, (pystencils.AssignmentCollection, list, dict)) and
+                not isinstance(assignments, pystencils_reco.AssignmentCollection)):
+            assignments = pystencils_reco.AssignmentCollection(assignments)
+
+        assignments = pystencils_autodiff.transformations.add_fixed_constant_boundary_handling(assignments)
+
+        assignments = pystencils_reco.AssignmentCollection(assignments)
+        assignments.args = args
+        assignments.kwargs = kwargs
+
+        return assignments
+
+    return wrapper
+
+
+@disk_cache
+def crazy_compile(crazy_function, *args, **kwargs):
+
+    return crazy_function(*args, **kwargs).compile()
+
+
 # class requires(object):
-    # """ Decorator for registering requirements on print methods. """
+# """ Decorator for registering requirements on print methods. """
 
-    # def __init__(self, **kwargs):
-    # self._decorator_kwargs = kwargs
+# def __init__(self, **kwargs):
+# self._decorator_kwargs = kwargs
 
-    # def __call__(self, function):
-    # def _method_wrapper(self, *args, **kwargs):
-    # for k, v in self._decorator_kwargs.items():
-    # obj, member = k.split('__')
-    # setattr(kwargs[obj], member, v)
+# def __call__(self, function):
+# def _method_wrapper(self, *args, **kwargs):
+# for k, v in self._decorator_kwargs.items():
+# obj, member = k.split('__')
+# setattr(kwargs[obj], member, v)
 
-    # return function(*args, **kwargs)
-    # return functools.wraps(function)(_method_wrapper)
+# return function(*args, **kwargs)
+# return functools.wraps(function)(_method_wrapper)
diff --git a/src/pystencils_reco/numpy_array_handler.py b/src/pystencils_reco/numpy_array_handler.py
new file mode 100644
index 0000000..3d01f78
--- /dev/null
+++ b/src/pystencils_reco/numpy_array_handler.py
@@ -0,0 +1,49 @@
+# -*- coding: utf-8 -*-
+#
+# Copyright © 2019 Stephan Seitz <stephan.seitz@fau.de>
+#
+# Distributed under terms of the GPLv3 license.
+
+"""
+
+"""
+import numpy as np
+
+import pystencils
+
+
+class NumpyArrayHandler:
+
+    def zeros(self, shape, dtype=np.float64, order='C'):
+        return np.zeros(shape, dtype, order)
+
+    def ones(self, shape, dtype, order='C'):
+        return np.ones(shape, dtype, order)
+
+    def empty(self, shape, dtype=np.float64, layout=None):
+        if layout:
+            cpu_array = pystencils.field.create_numpy_array_with_layout(shape, dtype, layout)
+            return self.from_numpy(cpu_array)
+        else:
+            return np.empty(shape, dtype)
+
+    def empty_like(self, array):
+        return self.empty(array.shape, array.dtype)
+
+    def ones_like(self, array):
+        return self.ones(array.shape, array.dtype)
+
+    def zeros_like(self, array):
+        return self.ones(array.shape, array.dtype)
+
+    def to_gpu(self, array):
+        return array
+
+    def upload(self, gpuarray, numpy_array):
+        gpuarray[...] = numpy_array
+
+    def download(self, gpuarray, numpy_array):
+        numpy_array[...] = gpuarray
+
+    def randn(self, shape, dtype=np.float64):
+        return np.random.randn(*shape).astype(dtype)
diff --git a/src/pystencils_reco/registration.py b/src/pystencils_reco/registration.py
new file mode 100644
index 0000000..9118450
--- /dev/null
+++ b/src/pystencils_reco/registration.py
@@ -0,0 +1,70 @@
+# -*- coding: utf-8 -*-
+#
+# Copyright © 2019 Stephan Seitz <stephan.seitz@fau.de>
+#
+# Distributed under terms of the GPLv3 license.
+
+"""
+
+"""
+
+import sympy as sp
+
+import pystencils
+from pystencils.rng import PhiloxFourFloats
+from pystencils_autodiff.transformations import get_random_sampling
+from pystencils_reco import crazy, fixed_boundary_handling
+
+
+@fixed_boundary_handling
+@crazy
+def autocorrelation(x: {'field_type': pystencils.FieldType.CUSTOM},
+                    y,
+                    r_xx,
+                    r_yy,
+                    r_xy,
+                    stencil):
+    print(locals())
+
+    sum_xx = []
+    sum_yy = []
+    sum_xy = []
+
+    for s in stencil:
+        sum_xx.append(x[s] * x.center)
+        sum_yy.append(y[s] * y.center)
+        sum_xy.append(y[s] * x.center)
+
+    cross = sp.Symbol('cross')
+    return {
+        cross: sp.Add(*sum_xy),
+        r_xy.center: cross,
+        r_xx.center: sp.Piecewise((sp.Add(*sum_xx), sp.Abs(cross) > 1e-3), (1., True)),
+        r_yy.center: sp.Piecewise((sp.Add(*sum_yy), sp.Abs(cross) > 1e-3), (1., True)),
+    }
+
+
+@crazy
+def autocorrelation_random_sampling(x,
+                                    y,
+                                    auto_correlation,
+                                    time_step=pystencils.data_types.TypedSymbol(
+                                        'time_step', pystencils.data_types.create_type('int32')),
+                                    eps=1e-3):
+
+    assert auto_correlation.spatial_dimensions == 1
+
+    random_floats = PhiloxFourFloats(1, time_step)
+    random_point = get_random_sampling(random_floats.result_symbols[:x.spatial_dimensions], y.aabb_min, y.aabb_max)
+
+    x_point = x.interpolated_access(x.physical_to_index(random_point))
+    y_point = y.interpolated_access(y.physical_to_index(random_point))
+
+    xx = x_point * x_point
+    yy = y_point * y_point
+    xy = y_point * x_point
+
+    return [
+        random_floats,
+        pystencils.Assignment(auto_correlation.center, xy ** 2 / (xx * yy + eps))
+    ]
diff --git a/src/pystencils_reco/resampling.py b/src/pystencils_reco/resampling.py
index 849f3b0..855ff79 100644
--- a/src/pystencils_reco/resampling.py
+++ b/src/pystencils_reco/resampling.py
@@ -113,3 +113,15 @@ def resample(input_field, output_field, interpolation_mode='linear'):
                                             output_field,
                                             sympy.Matrix(sympy.Identity(input_field.spatial_dimensions)),
                                             interpolation_mode=interpolation_mode)
+
+
+@crazy
+def translate(input_field: pystencils.Field,
+              output_field: pystencils.Field,
+              translation,
+              interpolation_mode='linear'):
+
+    return {
+        output_field.center: input_field.interpolated_access(
+            input_field.physical_to_index(output_field.physical_coordinates - translation), interpolation_mode)
+    }
diff --git a/src/pystencils_reco/vesselness.py b/src/pystencils_reco/vesselness.py
new file mode 100644
index 0000000..48161ce
--- /dev/null
+++ b/src/pystencils_reco/vesselness.py
@@ -0,0 +1,49 @@
+import numpy as np
+import sympy
+
+import pystencils
+from pystencils.math_optimizations import evaluate_constant_terms, optimize
+from pystencils.simp import sympy_cse
+from pystencils_reco import crazy
+
+
+@crazy
+def eigenvalues_3d(H_field: {'index_dimensions': 1}, xx, xy, xz, yy, yz, zz):
+
+    H = sympy.Matrix([[xx.center, xy.center, xz.center],
+                      [xy.center, yy.center, yz.center],
+                      [xz.center, yz.center, zz.center]]
+                     )
+
+    eigenvalues = list(H.eigenvals())
+
+    assignments = pystencils.AssignmentCollection({
+        H_field.center(i): sympy.re(eigenvalues[i]) for i in range(3)
+    })
+
+    # class complex_symbol_generator():
+    # def __iter__(self):
+    # counter = 0
+    # while True:
+    # yield TypedSymbol('xi_%i' % counter, create_type(np.complex64))
+    # counter += 1
+
+    # assignments.subexpression_symbol_generator = complex_symbol_generator()
+
+    assignments = sympy_cse(assignments)
+
+    assignments = pystencils.AssignmentCollection(
+        [pystencils.Assignment(a.lhs, sympy.re(optimize(a.rhs, [evaluate_constant_terms]))) for a in assignments]
+    )
+
+    # complex_symbols = [(a.lhs, a.rhs) for a in assignments if any(atom.is_real is False for atom in a.rhs.atoms())]
+    # assignments = assignments.subs({a.lhs, a.rhs})
+    # print(complex_symbols)
+
+    return assignments
+
+
+@crazy
+def vesselness(eigenvalues: {'index_dimensions': 1}, vesselness):
+
+    lamda1 = sympy.max()
diff --git a/tests/test_registration.py b/tests/test_registration.py
new file mode 100644
index 0000000..794dad8
--- /dev/null
+++ b/tests/test_registration.py
@@ -0,0 +1,30 @@
+# -*- coding: utf-8 -*-
+#
+# Copyright © 2019 Stephan Seitz <stephan.seitz@fau.de>
+#
+# Distributed under terms of the GPLv3 license.
+
+"""
+
+"""
+import pytest
+
+from pystencils_reco.numpy_array_handler import NumpyArrayHandler
+from pystencils_reco.registration import autocorrelation
+from pystencils_reco.stencils import BallStencil
+
+
+@pytest.mark.parametrize('array_handler', (NumpyArrayHandler(),))
+def test_registration(array_handler):
+    x = array_handler.randn((20, 30, 40))
+    y = array_handler.randn((22, 32, 42))
+    r_xx = array_handler.empty_like(y)
+    r_yy = array_handler.empty_like(y)
+    r_xy = array_handler.empty_like(y)
+
+    assignments = autocorrelation(x, y, r_xx, r_yy, r_xy, BallStencil(3))
+
+    print(assignments)
+
+    kernel = assignments.compile()
+    kernel()
diff --git a/tests/test_tensorflow.py b/tests/test_tensorflow.py
new file mode 100644
index 0000000..4555fa0
--- /dev/null
+++ b/tests/test_tensorflow.py
@@ -0,0 +1,45 @@
+# -*- coding: utf-8 -*-
+#
+# Copyright © 2019 Stephan Seitz <stephan.seitz@fau.de>
+#
+# Distributed under terms of the GPLv3 license.
+
+"""
+
+"""
+import pytest
+import tensorflow as tf
+
+import pystencils_reco.resampling
+from pystencils_reco import crazy
+
+pytest.importorskip('tensorflow')
+
+
+def test_field_generation():
+
+    x = tf.random.normal((20, 30, 100), name='x')
+    y = tf.zeros((20, 30, 100), dtype=tf.float64)
+
+    @crazy
+    def kernel(x, y):
+        print(x.dtype)
+        print(y.dtype)
+        return {x.center: y.center + 2}
+
+    kernel(x, y)
+
+
+@pytest.mark.parametrize('with_texture', (False,))
+def test_texture(with_texture):
+
+    x = tf.random.normal((20, 30, 100), name='x')
+    y = tf.zeros_like(x, name='y')
+    assignments = pystencils_reco.resampling.scale_transform(x, y, 2)
+
+    kernel = assignments.compile()
+    rtn = kernel()
+    rtn = rtn[0].cpu()
+    print(rtn)
+    import pyconrad.autoinit
+    pyconrad.show_everything()
diff --git a/tests/test_vesselness.py b/tests/test_vesselness.py
new file mode 100644
index 0000000..0b32db5
--- /dev/null
+++ b/tests/test_vesselness.py
@@ -0,0 +1,103 @@
+# -*- coding: utf-8 -*-
+#
+# Copyright © 2019 Stephan Seitz <stephan.seitz@fau.de>
+#
+# Distributed under terms of the GPLv3 license.
+
+"""
+
+"""
+
+
+import sympy
+import torch
+
+import pystencils
+from pystencils.backends.json import print_json, write_json
+from pystencils_reco._assignment_collection import get_module_file
+from pystencils_reco.vesselness import eigenvalues_3d
+
+
+def test_vesselness():
+
+    # import numpy as np
+
+    # image = np.random.rand(30, 40, 50).astype(np.float32)
+    # result = np.random.rand(30, 40, 50, 3).astype(np.float32)
+
+    # assignments = eigenvalues_3d(result, image, image, image, image, image, image)
+    # print(assignments)
+    # print(assignments.reproducible_hash)
+    # print(get_module_file(assignments, 'cuda'))
+    # kernel = assignments.compile()
+    # kernel()
+
+    import torch
+    torch.set_default_dtype(torch.double)
+
+    result = torch.randn((2, 3, 4, 3))
+    image0 = torch.zeros((2, 3, 4), requires_grad=True)
+    image1 = torch.zeros((2, 3, 4), requires_grad=True)
+    image2 = torch.zeros((2, 3, 4), requires_grad=True)
+    image3 = torch.zeros((2, 3, 4), requires_grad=True)
+    image4 = torch.zeros((2, 3, 4), requires_grad=True)
+    image5 = torch.zeros((2, 3, 4), requires_grad=True)
+
+    assignments = eigenvalues_3d(result, image0, image1, image2, image3, image4, image5)
+    assignments.compile()
+    # print(assignments)
+    # # assignments.lambdify(sympy.
+    # # kernel=assignments.compile()
+
+    # main_assignments = [a for a in assignments if isinstance(a.lhs, pystencils.Field.Access)]
+    # subexpressions = [a for a in assignments if not isinstance(a.lhs, pystencils.Field.Access)]
+    # assignments = pystencils.AssignmentCollection(main_assignments, subexpressions)
+    # lam = assignments.lambdify(sympy.symbols('xx_C xy_C xz_C yy_C yz_C zz_C'), module='tensorflow')
+
+    # import tensorflow as tf
+
+    # image0 = tf.random.uniform((20, 30, 40))
+    # image1 = tf.random.uniform((20, 30, 40))
+    # image2 = tf.random.uniform((20, 30, 40))
+    # image3 = tf.random.uniform((20, 30, 40))
+    # image4 = tf.random.uniform((20, 30, 40))
+    # image5 = tf.random.uniform((20, 30, 40))
+
+    # a = lam(image0, image1, image2, image3, image4, image5)
+    import tensorflow as tf
+
+    image0 = tf.random.uniform((20, 30, 40, 3, 3))
+    eigenvalues, _ = tf.linalg.eigh(image0)
+
+    print(eigenvalues)
+    print(eigenvalues.shape)
+
+    sorted_eigenvalues = tf.sort(eigenvalues, axis=-1)
+    print(sorted_eigenvalues)
+
+    l1 = sorted_eigenvalues[..., -1]
+    l2 = sorted_eigenvalues[..., -2]
+    l3 = sorted_eigenvalues[..., -3]
+    print(l1)
+    # assignments = assignments.new_without_subexpressions().main_assignments
+    # lambdas = {assignment.lhs: sympy.lambdify(sympy.symbols('H_field xx xy xz yy yz zz'),
+    # assignment.rhs, 'tensorflow') for assignment in assignments}
+    # print(lambdas)
+
+    # torch.autograd.gradcheck(kernel.apply, tuple(
+    # [image0,
+    # image1,
+    # image2,
+    # image3,
+    # image4,
+    # image5]),
+    # atol=1e-4,
+    # raise_exception=True)
+
+    # image = tf.random.uniform((30, 40, 50))
+    # result = tf.random.normal((30, 40, 50, 3))
+
+    # #
+    # kernel = eigenvalues_3d(result, image, image, image, image, image, image).compile()
+
+    # kernel()
-- 
GitLab