diff --git a/src/pystencils_autodiff/backends/astnodes.py b/src/pystencils_autodiff/backends/astnodes.py index e35d1454d537eeb129be7a57f2ab02c952d95cfc..10228c8b776ed9b157c3d0e2712b9235ffe2253b 100644 --- a/src/pystencils_autodiff/backends/astnodes.py +++ b/src/pystencils_autodiff/backends/astnodes.py @@ -108,12 +108,14 @@ class TorchModule(JinjaCppFile): :param forward_kernel_ast: one or more kernel ASTs (can have any C dialect) :param backward_kernel_ast: """ + from pystencils_autodiff.framework_integration.astnodes import CustomFunctionCall + if not isinstance(kernel_asts, Iterable): kernel_asts = [kernel_asts] wrapper_functions = [self.generate_wrapper_function(k) - if not isinstance(k, WrapperFunction) or wrap_wrapper_functions + if not isinstance(k, WrapperFunction) else k for k in kernel_asts] - kernel_asts = [k for k in kernel_asts if not isinstance(k, WrapperFunction)] + kernel_asts = [k for k in kernel_asts if not isinstance(k, (WrapperFunction, CustomFunctionCall))] self.module_name = module_name self.compiled_file = None @@ -137,7 +139,12 @@ class TorchModule(JinjaCppFile): return WrapperFunction(cls.DESTRUCTURING_CLASS(generate_kernel_call(kernel_ast)), function_name='call_' + kernel_ast.function_name) - def compile(self, extra_source_files=[], extra_cuda_flags=[], with_cuda=None, build_dir=None): + def compile(self, + extra_source_files=[], + extra_cuda_flags=[], + with_cuda=None, + build_dir=None, + compile_module_name=None): from torch.utils.cpp_extension import load file_extension = '.cu' if self.is_cuda else '.cpp' source_code = str(self) @@ -147,7 +154,7 @@ class TorchModule(JinjaCppFile): os.makedirs(build_dir, exist_ok=True) file_name = join(build_dir, f'{hash}{file_extension}') - self.compiled_file = file_name + self.compiled_file = (join(build_dir, compile_module_name) or file_name).replace('.cpp', '') + '.so' if not exists(file_name): write_file(file_name, source_code) @@ -155,7 +162,7 @@ class TorchModule(JinjaCppFile): # Torch regards CXX os.environ['CXX'] = get_compiler_config()['command'] - torch_extension = load(hash, + torch_extension = load(compile_module_name or hash, [file_name] + extra_source_files, with_cuda=self.is_cuda or with_cuda, extra_cflags=['--std=c++14', get_compiler_config() diff --git a/src/pystencils_autodiff/framework_integration/astnodes.py b/src/pystencils_autodiff/framework_integration/astnodes.py index 75dd5a1ca6b1d0484533d3264a11b94f2928081d..b7c78dd7d0ee2167021381c0ab53130a0fefde8e 100644 --- a/src/pystencils_autodiff/framework_integration/astnodes.py +++ b/src/pystencils_autodiff/framework_integration/astnodes.py @@ -43,10 +43,11 @@ class DestructuringBindingsForFieldClass(Node): """Set of Field instances: fields which are accessed inside this kernel function""" from pystencils.interpolation_astnodes import InterpolatorAccess - return (set(o.field for o in self.atoms(ResolvedFieldAccess) - | self.atoms(InterpolatorAccess)) - | set(itertools.chain.from_iterable((k.kernel_function.fields_accessed - for k in self.atoms(FunctionCall))))) + return set(itertools.chain.from_iterable(((a.field for a in self.atoms(pystencils.Field.Access)), + (a.field for a in self.atoms(InterpolatorAccess)), + (a.field for a in self.atoms(ResolvedFieldAccess))),)) \ + | set(itertools.chain.from_iterable((k.kernel_function.fields_accessed + for k in self.atoms(FunctionCall)))) def __init__(self, body): super(DestructuringBindingsForFieldClass, self).__init__() @@ -133,6 +134,9 @@ class WrapperFunction(pystencils.astnodes.KernelFunction): def generate_kernel_call(kernel_function): + if isinstance(kernel_function, CustomFunctionCall): + return kernel_function + from pystencils.interpolation_astnodes import InterpolatorAccess from pystencils.kernelparameters import FieldPointerSymbol @@ -152,6 +156,7 @@ def generate_kernel_call(kernel_function): FunctionCall(kernel_function), CudaErrorCheck(), ]) + elif kernel_function.backend == 'gpucuda': return pystencils.astnodes.Block([CudaErrorCheck(), FunctionCall(kernel_function), @@ -385,3 +390,11 @@ class CustomFunctionCall(JinjaCppFile): @property def symbols_defined(self): return set(self.ast_dict.fields_accessed) + + @property + def function_name(self): + return self.ast_dict.function_name + + @property + def undefined_symbols(self): + return set(self.ast_dict.args)