From d05067ad75d37d9b169c309e80385c2de3e01645 Mon Sep 17 00:00:00 2001 From: Stephan Seitz <stephan.seitz@fau.de> Date: Wed, 26 Feb 2020 13:33:02 +0100 Subject: [PATCH] Make generate_kernel_call regard CustomFunctionCall --- src/pystencils_autodiff/backends/astnodes.py | 17 ++++++++++----- .../framework_integration/astnodes.py | 21 +++++++++++++++---- 2 files changed, 29 insertions(+), 9 deletions(-) diff --git a/src/pystencils_autodiff/backends/astnodes.py b/src/pystencils_autodiff/backends/astnodes.py index e35d145..10228c8 100644 --- a/src/pystencils_autodiff/backends/astnodes.py +++ b/src/pystencils_autodiff/backends/astnodes.py @@ -108,12 +108,14 @@ class TorchModule(JinjaCppFile): :param forward_kernel_ast: one or more kernel ASTs (can have any C dialect) :param backward_kernel_ast: """ + from pystencils_autodiff.framework_integration.astnodes import CustomFunctionCall + if not isinstance(kernel_asts, Iterable): kernel_asts = [kernel_asts] wrapper_functions = [self.generate_wrapper_function(k) - if not isinstance(k, WrapperFunction) or wrap_wrapper_functions + if not isinstance(k, WrapperFunction) else k for k in kernel_asts] - kernel_asts = [k for k in kernel_asts if not isinstance(k, WrapperFunction)] + kernel_asts = [k for k in kernel_asts if not isinstance(k, (WrapperFunction, CustomFunctionCall))] self.module_name = module_name self.compiled_file = None @@ -137,7 +139,12 @@ class TorchModule(JinjaCppFile): return WrapperFunction(cls.DESTRUCTURING_CLASS(generate_kernel_call(kernel_ast)), function_name='call_' + kernel_ast.function_name) - def compile(self, extra_source_files=[], extra_cuda_flags=[], with_cuda=None, build_dir=None): + def compile(self, + extra_source_files=[], + extra_cuda_flags=[], + with_cuda=None, + build_dir=None, + compile_module_name=None): from torch.utils.cpp_extension import load file_extension = '.cu' if self.is_cuda else '.cpp' source_code = str(self) @@ -147,7 +154,7 @@ class TorchModule(JinjaCppFile): os.makedirs(build_dir, exist_ok=True) file_name = join(build_dir, f'{hash}{file_extension}') - self.compiled_file = file_name + self.compiled_file = (join(build_dir, compile_module_name) or file_name).replace('.cpp', '') + '.so' if not exists(file_name): write_file(file_name, source_code) @@ -155,7 +162,7 @@ class TorchModule(JinjaCppFile): # Torch regards CXX os.environ['CXX'] = get_compiler_config()['command'] - torch_extension = load(hash, + torch_extension = load(compile_module_name or hash, [file_name] + extra_source_files, with_cuda=self.is_cuda or with_cuda, extra_cflags=['--std=c++14', get_compiler_config() diff --git a/src/pystencils_autodiff/framework_integration/astnodes.py b/src/pystencils_autodiff/framework_integration/astnodes.py index 75dd5a1..b7c78dd 100644 --- a/src/pystencils_autodiff/framework_integration/astnodes.py +++ b/src/pystencils_autodiff/framework_integration/astnodes.py @@ -43,10 +43,11 @@ class DestructuringBindingsForFieldClass(Node): """Set of Field instances: fields which are accessed inside this kernel function""" from pystencils.interpolation_astnodes import InterpolatorAccess - return (set(o.field for o in self.atoms(ResolvedFieldAccess) - | self.atoms(InterpolatorAccess)) - | set(itertools.chain.from_iterable((k.kernel_function.fields_accessed - for k in self.atoms(FunctionCall))))) + return set(itertools.chain.from_iterable(((a.field for a in self.atoms(pystencils.Field.Access)), + (a.field for a in self.atoms(InterpolatorAccess)), + (a.field for a in self.atoms(ResolvedFieldAccess))),)) \ + | set(itertools.chain.from_iterable((k.kernel_function.fields_accessed + for k in self.atoms(FunctionCall)))) def __init__(self, body): super(DestructuringBindingsForFieldClass, self).__init__() @@ -133,6 +134,9 @@ class WrapperFunction(pystencils.astnodes.KernelFunction): def generate_kernel_call(kernel_function): + if isinstance(kernel_function, CustomFunctionCall): + return kernel_function + from pystencils.interpolation_astnodes import InterpolatorAccess from pystencils.kernelparameters import FieldPointerSymbol @@ -152,6 +156,7 @@ def generate_kernel_call(kernel_function): FunctionCall(kernel_function), CudaErrorCheck(), ]) + elif kernel_function.backend == 'gpucuda': return pystencils.astnodes.Block([CudaErrorCheck(), FunctionCall(kernel_function), @@ -385,3 +390,11 @@ class CustomFunctionCall(JinjaCppFile): @property def symbols_defined(self): return set(self.ast_dict.fields_accessed) + + @property + def function_name(self): + return self.ast_dict.function_name + + @property + def undefined_symbols(self): + return set(self.ast_dict.args) -- GitLab