diff --git a/src/pyronn_torch/codegen.py b/src/pyronn_torch/codegen.py index 0baa4223634c030246b9ef4c74d454d0bdb53d89..47e00fa8b80cafa9daf013ffedf6659e7f6e4c43 100644 --- a/src/pyronn_torch/codegen.py +++ b/src/pyronn_torch/codegen.py @@ -154,7 +154,19 @@ except Exception as e: warnings.warn(str(e)) -def generate_shared_object(output_folder=None, source_files=None, show_code=False, framework_module_class=TorchModule): +def get_pyronn_cuda_kernels(): + return glob(join(dirname(__file__), 'PYRO-NN-Layers', '*.cu.cc')) + + +def get_pyronn_include_paths(): + return join(dirname(__file__), 'PYRO-NN-Layers') + + +def generate_shared_object(output_folder=None, + source_files=None, + show_code=False, + framework_module_class=TorchModule, + generate_code_only=False): object_cache = get_cache_config()['object_cache'] @@ -182,6 +194,9 @@ def generate_shared_object(output_folder=None, source_files=None, show_code=Fals if show_code: pystencils.show_code(module, custom_backend=FrameworkIntegrationPrinter()) + if generate_code_only: + return module + extension = module.compile(extra_source_files=cuda_sources, extra_cuda_flags=['-arch=sm_35'], with_cuda=True, @@ -207,7 +222,7 @@ def compile_shared_object(output_folder=None, source_files=None): output_folder = dirname(__file__) if not source_files: - source_files = glob(join(dirname(__file__), 'PYRO-NN-Layers', '*.cu.cc')) + source_files = get_pyronn_cuda_kernels() generated_file = join(dirname(__file__), 'pyronn_torch.cpp')