Skip to content
Snippets Groups Projects
Commit 15f1af31 authored by Stephan Seitz's avatar Stephan Seitz
Browse files

Prepare codegen API for non-Torch usage

parent 6d7364f4
No related merge requests found
...@@ -154,7 +154,19 @@ except Exception as e: ...@@ -154,7 +154,19 @@ except Exception as e:
warnings.warn(str(e)) warnings.warn(str(e))
def generate_shared_object(output_folder=None, source_files=None, show_code=False, framework_module_class=TorchModule): def get_pyronn_cuda_kernels():
return glob(join(dirname(__file__), 'PYRO-NN-Layers', '*.cu.cc'))
def get_pyronn_include_paths():
return join(dirname(__file__), 'PYRO-NN-Layers')
def generate_shared_object(output_folder=None,
source_files=None,
show_code=False,
framework_module_class=TorchModule,
generate_code_only=False):
object_cache = get_cache_config()['object_cache'] object_cache = get_cache_config()['object_cache']
...@@ -182,6 +194,9 @@ def generate_shared_object(output_folder=None, source_files=None, show_code=Fals ...@@ -182,6 +194,9 @@ def generate_shared_object(output_folder=None, source_files=None, show_code=Fals
if show_code: if show_code:
pystencils.show_code(module, custom_backend=FrameworkIntegrationPrinter()) pystencils.show_code(module, custom_backend=FrameworkIntegrationPrinter())
if generate_code_only:
return module
extension = module.compile(extra_source_files=cuda_sources, extension = module.compile(extra_source_files=cuda_sources,
extra_cuda_flags=['-arch=sm_35'], extra_cuda_flags=['-arch=sm_35'],
with_cuda=True, with_cuda=True,
...@@ -207,7 +222,7 @@ def compile_shared_object(output_folder=None, source_files=None): ...@@ -207,7 +222,7 @@ def compile_shared_object(output_folder=None, source_files=None):
output_folder = dirname(__file__) output_folder = dirname(__file__)
if not source_files: if not source_files:
source_files = glob(join(dirname(__file__), 'PYRO-NN-Layers', '*.cu.cc')) source_files = get_pyronn_cuda_kernels()
generated_file = join(dirname(__file__), 'pyronn_torch.cpp') generated_file = join(dirname(__file__), 'pyronn_torch.cpp')
......
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment