Skip to content
Snippets Groups Projects
Commit 15f1af31 authored by Stephan Seitz's avatar Stephan Seitz
Browse files

Prepare codegen API for non-Torch usage

parent 6d7364f4
No related branches found
No related tags found
No related merge requests found
......@@ -154,7 +154,19 @@ except Exception as e:
warnings.warn(str(e))
def generate_shared_object(output_folder=None, source_files=None, show_code=False, framework_module_class=TorchModule):
def get_pyronn_cuda_kernels():
return glob(join(dirname(__file__), 'PYRO-NN-Layers', '*.cu.cc'))
def get_pyronn_include_paths():
return join(dirname(__file__), 'PYRO-NN-Layers')
def generate_shared_object(output_folder=None,
source_files=None,
show_code=False,
framework_module_class=TorchModule,
generate_code_only=False):
object_cache = get_cache_config()['object_cache']
......@@ -182,6 +194,9 @@ def generate_shared_object(output_folder=None, source_files=None, show_code=Fals
if show_code:
pystencils.show_code(module, custom_backend=FrameworkIntegrationPrinter())
if generate_code_only:
return module
extension = module.compile(extra_source_files=cuda_sources,
extra_cuda_flags=['-arch=sm_35'],
with_cuda=True,
......@@ -207,7 +222,7 @@ def compile_shared_object(output_folder=None, source_files=None):
output_folder = dirname(__file__)
if not source_files:
source_files = glob(join(dirname(__file__), 'PYRO-NN-Layers', '*.cu.cc'))
source_files = get_pyronn_cuda_kernels()
generated_file = join(dirname(__file__), 'pyronn_torch.cpp')
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Please register or to comment