From c143c626d48c1ee30d9a65aa79c8e762c8e98771 Mon Sep 17 00:00:00 2001
From: Stephan Seitz <stephan.seitz@fau.de>
Date: Mon, 24 Feb 2020 13:19:49 +0100
Subject: [PATCH] Use StrictUndefined and add CustomFunctionCall

---
 src/pystencils_autodiff/_file_io.py           |  2 +-
 src/pystencils_autodiff/backends/astnodes.py  |  4 +--
 .../backends/module.tmpl.hpp                  |  1 -
 .../field_tensor_conversion.py                |  5 ++-
 .../framework_integration/astnodes.py         | 35 +++++++++++++++++--
 .../framework_integration/printer.py          | 11 ++++++
 .../framework_integration/texture_astnodes.py |  2 +-
 .../framework_integration/types.py            |  9 +++++
 src/pystencils_autodiff/graph_datahandling.py | 26 +++++++++++++-
 src/pystencils_autodiff/tensorflow_jit.py     |  2 +-
 tests/test_dynamic_function.py                | 20 +++++++++--
 11 files changed, 104 insertions(+), 13 deletions(-)

diff --git a/src/pystencils_autodiff/_file_io.py b/src/pystencils_autodiff/_file_io.py
index 6807151..8bf0425 100644
--- a/src/pystencils_autodiff/_file_io.py
+++ b/src/pystencils_autodiff/_file_io.py
@@ -18,7 +18,7 @@ _hash = hashlib.md5
 
 
 def read_template_from_file(file):
-    return jinja2.Template(read_file(file))
+    return jinja2.Template(read_file(file), undefined=jinja2.StrictUndefined)
 
 
 def read_file(file):
diff --git a/src/pystencils_autodiff/backends/astnodes.py b/src/pystencils_autodiff/backends/astnodes.py
index dd65c4e..e853a4e 100644
--- a/src/pystencils_autodiff/backends/astnodes.py
+++ b/src/pystencils_autodiff/backends/astnodes.py
@@ -102,7 +102,7 @@ class TorchModule(JinjaCppFile):
     def backend(self):
         return 'gpucuda' if self.is_cuda else 'c'
 
-    def __init__(self, module_name, kernel_asts, with_python_bindings=True):
+    def __init__(self, module_name, kernel_asts, with_python_bindings=True, wrap_wrapper_functions=False):
         """Create a C++ module with forward and optional backward_kernels
 
         :param forward_kernel_ast: one or more kernel ASTs (can have any C dialect)
@@ -111,7 +111,7 @@ class TorchModule(JinjaCppFile):
         if not isinstance(kernel_asts, Iterable):
             kernel_asts = [kernel_asts]
         wrapper_functions = [self.generate_wrapper_function(k)
-                             if not isinstance(k, WrapperFunction)
+                             if not isinstance(k, WrapperFunction) or wrap_wrapper_functions
                              else k for k in kernel_asts]
         kernel_asts = [k for k in kernel_asts if not isinstance(k, WrapperFunction)]
         self.module_name = module_name
diff --git a/src/pystencils_autodiff/backends/module.tmpl.hpp b/src/pystencils_autodiff/backends/module.tmpl.hpp
index bf7b99c..3803d53 100644
--- a/src/pystencils_autodiff/backends/module.tmpl.hpp
+++ b/src/pystencils_autodiff/backends/module.tmpl.hpp
@@ -12,4 +12,3 @@
 {% endfor %}
 
 {{ declarations | join('\n\n') }}
-
diff --git a/src/pystencils_autodiff/field_tensor_conversion.py b/src/pystencils_autodiff/field_tensor_conversion.py
index 1d34bda..46e00d1 100644
--- a/src/pystencils_autodiff/field_tensor_conversion.py
+++ b/src/pystencils_autodiff/field_tensor_conversion.py
@@ -99,7 +99,10 @@ def is_array_like(a):
     return (hasattr(a, '__array__')
             or isinstance(a, pycuda.gpuarray.GPUArray)
             or ('tensorflow' in str(type(a)) and 'Tensor' in str(type(a)))
-            or 'torch.Tensor' in str(type(a))) and not isinstance(a, (sympy.Matrix, sympy.MutableDenseMatrix))
+            or 'torch.Tensor' in str(type(a))) and not isinstance(a,
+                                                                  (sympy.Matrix,
+                                                                   sympy.MutableDenseMatrix,
+                                                                   sympy.MatrixSymbol))
 
 
 def tf_constant_from_field(field, init_val=0):
diff --git a/src/pystencils_autodiff/framework_integration/astnodes.py b/src/pystencils_autodiff/framework_integration/astnodes.py
index 35620f0..8e9cbba 100644
--- a/src/pystencils_autodiff/framework_integration/astnodes.py
+++ b/src/pystencils_autodiff/framework_integration/astnodes.py
@@ -166,7 +166,7 @@ class JinjaCppFile(Node):
     TEMPLATE: jinja2.Template = None
     NOT_PRINT_TYPES = (pystencils.Field, pystencils.TypedSymbol, bool)
 
-    def __init__(self, ast_dict):
+    def __init__(self, ast_dict={}):
         self.ast_dict = pystencils.utils.DotDict(ast_dict)
         self.printer = FrameworkIntegrationPrinter()
         Node.__init__(self)
@@ -349,4 +349,35 @@ class MeshNormalFunctor(DynamicFunction):
     @property
     def name(self):
         return self.mesh_name
-    
+
+
+class CustomFunctionDeclaration(JinjaCppFile):
+    TEMPLATE = jinja2.Template("""{{function_name}}({{ args | join(', ') }});""", undefined=jinja2.StrictUndefined)
+
+    def __init__(self, function_name, args):
+        super().__init__({})
+        self.ast_dict.update({
+            'function_name': function_name,
+            'args': [f'{self._print(a.dtype)} {self._print(a)}' for a in args]
+        })
+
+    @property
+    def function_name(self):
+        return self.ast_dict.function_name
+
+
+class CustomFunctionCall(JinjaCppFile):
+    TEMPLATE = jinja2.Template("""{{function_name}}({{ args | join(', ') }});""", undefined=jinja2.StrictUndefined)
+
+    def __init__(self, function_name, *args, fields_accessed=[]):
+        ast_dict = {
+            'function_name': function_name,
+            'args': args,
+            'fields_accessed': [f.center for f in fields_accessed]
+        }
+        super().__init__(ast_dict)
+        self.required_global_declarations = [CustomFunctionDeclaration(self.ast_dict.function_name, self.ast_dict.args)]
+
+    @property
+    def symbols_defined(self):
+        return set(self.ast_dict.fields_accessed)
diff --git a/src/pystencils_autodiff/framework_integration/printer.py b/src/pystencils_autodiff/framework_integration/printer.py
index f038f85..36477c9 100644
--- a/src/pystencils_autodiff/framework_integration/printer.py
+++ b/src/pystencils_autodiff/framework_integration/printer.py
@@ -19,6 +19,8 @@ class FrameworkIntegrationPrinter(pystencils.backends.cbackend.CBackend):
         super().__init__(dialect='c')
         self.sympy_printer.__class__._print_DynamicFunction = self._print_DynamicFunction
         self.sympy_printer.__class__._print_MeshNormalFunctor = self._print_DynamicFunction
+        self.sympy_printer.__class__._print_MatrixElement = self._print_MatrixElement
+        self.sympy_printer.__class__._print_TypedMatrixElement = self._print_MatrixElement
 
     def _print(self, node):
         from pystencils_autodiff.framework_integration.astnodes import JinjaCppFile
@@ -126,6 +128,15 @@ class FrameworkIntegrationPrinter(pystencils.backends.cbackend.CBackend):
         arg_str = ', '.join(self._print(a) for a in expr.args[2:])
         return f'{name}({arg_str})'
 
+    def _print_MatrixElement(self, expr):
+        name = expr.name
+        if expr.args[0].args[1] == 1 or expr.args[0].args[2] == 1 or (hasattr(expr.args[0], 'linear_indexing')
+                                                                      and expr.args[0].linear_indexing):
+            return f'{name}[{self._print(expr.args[1])}]'
+        else:
+            arg_str = ', '.join(self._print(a) for a in expr.args[1:])
+            return f'{name}({arg_str})'
+
     def _print_CustomCodeNode(self, node):
         super_code = super()._print_CustomCodeNode(node)
         if super_code:
diff --git a/src/pystencils_autodiff/framework_integration/texture_astnodes.py b/src/pystencils_autodiff/framework_integration/texture_astnodes.py
index 21e43f2..8598426 100644
--- a/src/pystencils_autodiff/framework_integration/texture_astnodes.py
+++ b/src/pystencils_autodiff/framework_integration/texture_astnodes.py
@@ -168,7 +168,7 @@ std::shared_ptr<void> {{array}}Destroyer(nullptr, [&](...){
             texture_object='tex_' + texture_name,
             array='array_' + texture_name,
             texture_name=texture_name,
-            texture_namespace=self._texture_namespace + '::',
+            texture_namespace=self._texture_namespace + '::' if self._texture_namespace else '',
             ndim=self._ndim,
             device_ptr=self._device_ptr,
             create_array=self._get_create_array_call(),
diff --git a/src/pystencils_autodiff/framework_integration/types.py b/src/pystencils_autodiff/framework_integration/types.py
index f392934..2f96ef9 100644
--- a/src/pystencils_autodiff/framework_integration/types.py
+++ b/src/pystencils_autodiff/framework_integration/types.py
@@ -16,3 +16,12 @@ class TemplateType(Type):
 
     def _sympystr(self, *args, **kwargs):
         return str(self._name)
+
+
+class CustomCppType(Type):
+
+    def __init__(self, name):
+        self._name = name
+
+    def _sympystr(self, *args, **kwargs):
+        return str(self._name)
diff --git a/src/pystencils_autodiff/graph_datahandling.py b/src/pystencils_autodiff/graph_datahandling.py
index bc04f5e..27fc13b 100644
--- a/src/pystencils_autodiff/graph_datahandling.py
+++ b/src/pystencils_autodiff/graph_datahandling.py
@@ -15,7 +15,6 @@ import numpy as np
 import pystencils.datahandling
 import pystencils.kernel_wrapper
 import pystencils.timeloop
-from pystencils.data_types import create_type
 from pystencils.field import FieldType
 
 
@@ -62,6 +61,19 @@ class DataTransfer:
         return f'DataTransferKind: {self.kind} with {self.field}'
 
 
+class GhostTensorExtraction:
+    def __init__(self, field: pystencils.Field, on_gpu: bool, with_ghost_layers=False):
+        self.field = field
+        self.on_gpu = on_gpu
+        self.with_ghost_layers = with_ghost_layers
+
+    def __str__(self):
+        return f'GhostTensorExtraction: {self.field}, on_gpu: {self.on_gpu}'
+
+    def __repr__(self):
+        return self.__str__(self)
+
+
 class Swap(DataTransfer):
     def __init__(self, source, destination, gpu):
         self.kind = DataTransferKind.DEVICE_SWAP if gpu else DataTransferKind.HOST_SWAP
@@ -88,6 +100,13 @@ class KernelCall:
     def __str__(self):
         return "Call " + str(self.kernel.ast.function_name)
 
+    @property
+    def ast(self):
+        if isinstance(self.kernel, pystencils.astnodes.Node):
+            return self.kernel
+        else:
+            return self.kernel.ast
+
 
 class FieldOutput:
     def __init__(self, fields, output_path, flag_field):
@@ -321,6 +340,11 @@ class GraphDataHandling(pystencils.datahandling.SerialDataHandling):
         fields = [self.fields[f] if isinstance(f, str) else f for f in fields]
         self.call_queue.append(FieldOutput(fields, output_path, flag_field))
 
+    def extract_tensor(self, field, on_gpu, with_ghost_layers=False):
+        if isinstance(field, str):
+            field = self.fields[field]
+        self.call_queue.append(GhostTensorExtraction(field, on_gpu, with_ghost_layers))
+
     # TODO
     # def reduce_float_sequence(self, sequence, operation, all_reduce=False) -> np.array:
         # return np.array(sequence)
diff --git a/src/pystencils_autodiff/tensorflow_jit.py b/src/pystencils_autodiff/tensorflow_jit.py
index caac7bc..c8d2a17 100644
--- a/src/pystencils_autodiff/tensorflow_jit.py
+++ b/src/pystencils_autodiff/tensorflow_jit.py
@@ -151,7 +151,7 @@ def compile_file(file,
         command_prefix = [nvcc or NVCC_BINARY,
                           '--expt-relaxed-constexpr',
                           '-ccbin',
-                          get_compiler_config()['tensorflow_host_compiler'],
+                          get_compiler_config()['command'],
                           '-Xcompiler',
                           get_compiler_config()['flags'].replace('c++11', 'c++14'),
                           *_nvcc_flags,
diff --git a/tests/test_dynamic_function.py b/tests/test_dynamic_function.py
index 0c0a765..9fb7e8d 100644
--- a/tests/test_dynamic_function.py
+++ b/tests/test_dynamic_function.py
@@ -5,7 +5,7 @@ from pystencils.data_types import TypedSymbol, create_type
 from pystencils_autodiff.framework_integration.astnodes import DynamicFunction
 from pystencils_autodiff.framework_integration.printer import (
     DebugFrameworkPrinter, FrameworkIntegrationPrinter)
-from pystencils_autodiff.framework_integration.types import TemplateType
+from pystencils_autodiff.framework_integration.types import CustomCppType, TemplateType
 
 
 def test_dynamic_function():
@@ -55,14 +55,28 @@ def test_dynamic_matrix():
     pystencils.show_code(ast, custom_backend=FrameworkIntegrationPrinter())
 
 
+def test_typed_matrix():
+    x, y = pystencils.fields('x, y:  float32[3d]')
+    from pystencils.data_types import TypedMatrixSymbol
+
+    A = TypedMatrixSymbol('A', 3, 1, create_type('double'), CustomCppType('Vector3<real_t>'))
+
+    assignments = pystencils.AssignmentCollection({
+        y.center: A[0] + A[1] + A[2]
+    })
+
+    ast = pystencils.create_kernel(assignments)
+    pystencils.show_code(ast, custom_backend=FrameworkIntegrationPrinter())
+
+
 def test_dynamic_matrix_location_dependent():
     x, y = pystencils.fields('x, y:  float32[3d]')
     from pystencils.data_types import TypedMatrixSymbol
 
-    A = TypedMatrixSymbol('A', 3, 1, create_type('double'), 'Vector3<double>')
+    A = TypedMatrixSymbol('A', 3, 1, create_type('double'), CustomCppType('Vector3<double>'))
 
     my_fun_call = DynamicFunction(TypedSymbol('my_fun',
-                                              'std::function< Vector3 < double >(int, int, int) >'),
+                                              'std::function<Vector3<double>(int, int, int)>'),
                                   A.dtype,
                                   *pystencils.x_vector(3))
 
-- 
GitLab