From c143c626d48c1ee30d9a65aa79c8e762c8e98771 Mon Sep 17 00:00:00 2001 From: Stephan Seitz <stephan.seitz@fau.de> Date: Mon, 24 Feb 2020 13:19:49 +0100 Subject: [PATCH] Use StrictUndefined and add CustomFunctionCall --- src/pystencils_autodiff/_file_io.py | 2 +- src/pystencils_autodiff/backends/astnodes.py | 4 +-- .../backends/module.tmpl.hpp | 1 - .../field_tensor_conversion.py | 5 ++- .../framework_integration/astnodes.py | 35 +++++++++++++++++-- .../framework_integration/printer.py | 11 ++++++ .../framework_integration/texture_astnodes.py | 2 +- .../framework_integration/types.py | 9 +++++ src/pystencils_autodiff/graph_datahandling.py | 26 +++++++++++++- src/pystencils_autodiff/tensorflow_jit.py | 2 +- tests/test_dynamic_function.py | 20 +++++++++-- 11 files changed, 104 insertions(+), 13 deletions(-) diff --git a/src/pystencils_autodiff/_file_io.py b/src/pystencils_autodiff/_file_io.py index 6807151..8bf0425 100644 --- a/src/pystencils_autodiff/_file_io.py +++ b/src/pystencils_autodiff/_file_io.py @@ -18,7 +18,7 @@ _hash = hashlib.md5 def read_template_from_file(file): - return jinja2.Template(read_file(file)) + return jinja2.Template(read_file(file), undefined=jinja2.StrictUndefined) def read_file(file): diff --git a/src/pystencils_autodiff/backends/astnodes.py b/src/pystencils_autodiff/backends/astnodes.py index dd65c4e..e853a4e 100644 --- a/src/pystencils_autodiff/backends/astnodes.py +++ b/src/pystencils_autodiff/backends/astnodes.py @@ -102,7 +102,7 @@ class TorchModule(JinjaCppFile): def backend(self): return 'gpucuda' if self.is_cuda else 'c' - def __init__(self, module_name, kernel_asts, with_python_bindings=True): + def __init__(self, module_name, kernel_asts, with_python_bindings=True, wrap_wrapper_functions=False): """Create a C++ module with forward and optional backward_kernels :param forward_kernel_ast: one or more kernel ASTs (can have any C dialect) @@ -111,7 +111,7 @@ class TorchModule(JinjaCppFile): if not isinstance(kernel_asts, Iterable): 
kernel_asts = [kernel_asts] wrapper_functions = [self.generate_wrapper_function(k) - if not isinstance(k, WrapperFunction) + if not isinstance(k, WrapperFunction) or wrap_wrapper_functions else k for k in kernel_asts] kernel_asts = [k for k in kernel_asts if not isinstance(k, WrapperFunction)] self.module_name = module_name diff --git a/src/pystencils_autodiff/backends/module.tmpl.hpp b/src/pystencils_autodiff/backends/module.tmpl.hpp index bf7b99c..3803d53 100644 --- a/src/pystencils_autodiff/backends/module.tmpl.hpp +++ b/src/pystencils_autodiff/backends/module.tmpl.hpp @@ -12,4 +12,3 @@ {% endfor %} {{ declarations | join('\n\n') }} - diff --git a/src/pystencils_autodiff/field_tensor_conversion.py b/src/pystencils_autodiff/field_tensor_conversion.py index 1d34bda..46e00d1 100644 --- a/src/pystencils_autodiff/field_tensor_conversion.py +++ b/src/pystencils_autodiff/field_tensor_conversion.py @@ -99,7 +99,10 @@ def is_array_like(a): return (hasattr(a, '__array__') or isinstance(a, pycuda.gpuarray.GPUArray) or ('tensorflow' in str(type(a)) and 'Tensor' in str(type(a))) - or 'torch.Tensor' in str(type(a))) and not isinstance(a, (sympy.Matrix, sympy.MutableDenseMatrix)) + or 'torch.Tensor' in str(type(a))) and not isinstance(a, + (sympy.Matrix, + sympy.MutableDenseMatrix, + sympy.MatrixSymbol)) def tf_constant_from_field(field, init_val=0): diff --git a/src/pystencils_autodiff/framework_integration/astnodes.py b/src/pystencils_autodiff/framework_integration/astnodes.py index 35620f0..8e9cbba 100644 --- a/src/pystencils_autodiff/framework_integration/astnodes.py +++ b/src/pystencils_autodiff/framework_integration/astnodes.py @@ -166,7 +166,7 @@ class JinjaCppFile(Node): TEMPLATE: jinja2.Template = None NOT_PRINT_TYPES = (pystencils.Field, pystencils.TypedSymbol, bool) - def __init__(self, ast_dict): + def __init__(self, ast_dict={}): self.ast_dict = pystencils.utils.DotDict(ast_dict) self.printer = FrameworkIntegrationPrinter() Node.__init__(self) @@ -349,4 +349,35 @@ 
class MeshNormalFunctor(DynamicFunction): @property def name(self): return self.mesh_name - + + +class CustomFunctionDeclaration(JinjaCppFile): + TEMPLATE = jinja2.Template("""{{function_name}}({{ args | join(', ') }});""", undefined=jinja2.StrictUndefined) + + def __init__(self, function_name, args): + super().__init__({}) + self.ast_dict.update({ + 'function_name': function_name, + 'args': [f'{self._print(a.dtype)} {self._print(a)}' for a in args] + }) + + @property + def function_name(self): + return self.ast_dict.function_name + + +class CustomFunctionCall(JinjaCppFile): + TEMPLATE = jinja2.Template("""{{function_name}}({{ args | join(', ') }});""", undefined=jinja2.StrictUndefined) + + def __init__(self, function_name, *args, fields_accessed=[]): + ast_dict = { + 'function_name': function_name, + 'args': args, + 'fields_accessed': [f.center for f in fields_accessed] + } + super().__init__(ast_dict) + self.required_global_declarations = [CustomFunctionDeclaration(self.ast_dict.function_name, self.ast_dict.args)] + + @property + def symbols_defined(self): + return set(self.ast_dict.fields_accessed) diff --git a/src/pystencils_autodiff/framework_integration/printer.py b/src/pystencils_autodiff/framework_integration/printer.py index f038f85..36477c9 100644 --- a/src/pystencils_autodiff/framework_integration/printer.py +++ b/src/pystencils_autodiff/framework_integration/printer.py @@ -19,6 +19,8 @@ class FrameworkIntegrationPrinter(pystencils.backends.cbackend.CBackend): super().__init__(dialect='c') self.sympy_printer.__class__._print_DynamicFunction = self._print_DynamicFunction self.sympy_printer.__class__._print_MeshNormalFunctor = self._print_DynamicFunction + self.sympy_printer.__class__._print_MatrixElement = self._print_MatrixElement + self.sympy_printer.__class__._print_TypedMatrixElement = self._print_MatrixElement def _print(self, node): from pystencils_autodiff.framework_integration.astnodes import JinjaCppFile @@ -126,6 +128,15 @@ class 
FrameworkIntegrationPrinter(pystencils.backends.cbackend.CBackend): arg_str = ', '.join(self._print(a) for a in expr.args[2:]) return f'{name}({arg_str})' + def _print_MatrixElement(self, expr): + name = expr.name + if expr.args[0].args[1] == 1 or expr.args[0].args[2] == 1 or (hasattr(expr.args[0], 'linear_indexing') + and expr.args[0].linear_indexing): + return f'{name}[{self._print(expr.args[1])}]' + else: + arg_str = ', '.join(self._print(a) for a in expr.args[1:]) + return f'{name}({arg_str})' + def _print_CustomCodeNode(self, node): super_code = super()._print_CustomCodeNode(node) if super_code: diff --git a/src/pystencils_autodiff/framework_integration/texture_astnodes.py b/src/pystencils_autodiff/framework_integration/texture_astnodes.py index 21e43f2..8598426 100644 --- a/src/pystencils_autodiff/framework_integration/texture_astnodes.py +++ b/src/pystencils_autodiff/framework_integration/texture_astnodes.py @@ -168,7 +168,7 @@ std::shared_ptr<void> {{array}}Destroyer(nullptr, [&](...){ texture_object='tex_' + texture_name, array='array_' + texture_name, texture_name=texture_name, - texture_namespace=self._texture_namespace + '::', + texture_namespace=self._texture_namespace + '::' if self._texture_namespace else '', ndim=self._ndim, device_ptr=self._device_ptr, create_array=self._get_create_array_call(), diff --git a/src/pystencils_autodiff/framework_integration/types.py b/src/pystencils_autodiff/framework_integration/types.py index f392934..2f96ef9 100644 --- a/src/pystencils_autodiff/framework_integration/types.py +++ b/src/pystencils_autodiff/framework_integration/types.py @@ -16,3 +16,12 @@ class TemplateType(Type): def _sympystr(self, *args, **kwargs): return str(self._name) + + +class CustomCppType(Type): + + def __init__(self, name): + self._name = name + + def _sympystr(self, *args, **kwargs): + return str(self._name) diff --git a/src/pystencils_autodiff/graph_datahandling.py b/src/pystencils_autodiff/graph_datahandling.py index bc04f5e..27fc13b 
100644 --- a/src/pystencils_autodiff/graph_datahandling.py +++ b/src/pystencils_autodiff/graph_datahandling.py @@ -15,7 +15,6 @@ import numpy as np import pystencils.datahandling import pystencils.kernel_wrapper import pystencils.timeloop -from pystencils.data_types import create_type from pystencils.field import FieldType @@ -62,6 +61,19 @@ class DataTransfer: return f'DataTransferKind: {self.kind} with {self.field}' +class GhostTensorExtraction: + def __init__(self, field: pystencils.Field, on_gpu: bool, with_ghost_layers=False): + self.field = field + self.on_gpu = on_gpu + self.with_ghost_layers = with_ghost_layers + + def __str__(self): + return f'GhostTensorExtraction: {self.field}, on_gpu: {self.on_gpu}' + + def __repr__(self): + return self.__str__() + + class Swap(DataTransfer): def __init__(self, source, destination, gpu): self.kind = DataTransferKind.DEVICE_SWAP if gpu else DataTransferKind.HOST_SWAP @@ -88,6 +100,13 @@ class KernelCall: def __str__(self): return "Call " + str(self.kernel.ast.function_name) + @property + def ast(self): + if isinstance(self.kernel, pystencils.astnodes.Node): + return self.kernel + else: + return self.kernel.ast + class FieldOutput: def __init__(self, fields, output_path, flag_field): @@ -321,6 +340,11 @@ class GraphDataHandling(pystencils.datahandling.SerialDataHandling): fields = [self.fields[f] if isinstance(f, str) else f for f in fields] self.call_queue.append(FieldOutput(fields, output_path, flag_field)) + def extract_tensor(self, field, on_gpu, with_ghost_layers=False): + if isinstance(field, str): + field = self.fields[field] + self.call_queue.append(GhostTensorExtraction(field, on_gpu, with_ghost_layers)) + # TODO # def reduce_float_sequence(self, sequence, operation, all_reduce=False) -> np.array: # return np.array(sequence) diff --git a/src/pystencils_autodiff/tensorflow_jit.py index caac7bc..c8d2a17 100644 --- a/src/pystencils_autodiff/tensorflow_jit.py +++
b/src/pystencils_autodiff/tensorflow_jit.py @@ -151,7 +151,7 @@ def compile_file(file, command_prefix = [nvcc or NVCC_BINARY, '--expt-relaxed-constexpr', '-ccbin', - get_compiler_config()['tensorflow_host_compiler'], + get_compiler_config()['command'], '-Xcompiler', get_compiler_config()['flags'].replace('c++11', 'c++14'), *_nvcc_flags, diff --git a/tests/test_dynamic_function.py b/tests/test_dynamic_function.py index 0c0a765..9fb7e8d 100644 --- a/tests/test_dynamic_function.py +++ b/tests/test_dynamic_function.py @@ -5,7 +5,7 @@ from pystencils.data_types import TypedSymbol, create_type from pystencils_autodiff.framework_integration.astnodes import DynamicFunction from pystencils_autodiff.framework_integration.printer import ( DebugFrameworkPrinter, FrameworkIntegrationPrinter) -from pystencils_autodiff.framework_integration.types import TemplateType +from pystencils_autodiff.framework_integration.types import CustomCppType, TemplateType def test_dynamic_function(): @@ -55,14 +55,28 @@ def test_dynamic_matrix(): pystencils.show_code(ast, custom_backend=FrameworkIntegrationPrinter()) +def test_typed_matrix(): + x, y = pystencils.fields('x, y: float32[3d]') + from pystencils.data_types import TypedMatrixSymbol + + A = TypedMatrixSymbol('A', 3, 1, create_type('double'), CustomCppType('Vector3<real_t>')) + + assignments = pystencils.AssignmentCollection({ + y.center: A[0] + A[1] + A[2] + }) + + ast = pystencils.create_kernel(assignments) + pystencils.show_code(ast, custom_backend=FrameworkIntegrationPrinter()) + + def test_dynamic_matrix_location_dependent(): x, y = pystencils.fields('x, y: float32[3d]') from pystencils.data_types import TypedMatrixSymbol - A = TypedMatrixSymbol('A', 3, 1, create_type('double'), 'Vector3<double>') + A = TypedMatrixSymbol('A', 3, 1, create_type('double'), CustomCppType('Vector3<double>')) my_fun_call = DynamicFunction(TypedSymbol('my_fun', - 'std::function< Vector3 < double >(int, int, int) >'), + 'std::function<Vector3<double>(int, 
int, int)>'), A.dtype, *pystencils.x_vector(3)) -- GitLab