diff --git a/codegen/generate_wrappers.py b/codegen/generate_wrappers.py new file mode 100644 index 0000000000000000000000000000000000000000..350d5927e03854461b9031dfcafa133b5a15083b --- /dev/null +++ b/codegen/generate_wrappers.py @@ -0,0 +1,116 @@ +# +# Copyright © 2020 Stephan Seitz <stephan.seitz@fau.de> +# +# Distributed under terms of the GPLv3 license. + +""" + +""" + + +import argparse +from glob import glob +from os import makedirs +from os.path import basename, dirname, join +from shutil import copyfile, copytree, rmtree + +import pystencils +from pystencils.astnodes import Block +from pystencils.cpu.cpujit import get_cache_config +from pystencils.data_types import TypedSymbol, create_type +from pystencils.kernelparameters import FieldPointerSymbol, FieldShapeSymbol +from pystencils_autodiff.backends.astnodes import TorchModule +from pystencils_autodiff.framework_integration.astnodes import ( + CustomFunctionCall, WrapperFunction) +from pystencils_autodiff.framework_integration.printer import \ + FrameworkIntegrationPrinter + +volume = pystencils.fields('volume: float32[3d]') +projection = pystencils.fields('projection: float32[2d]') +projection_matrices = pystencils.fields('matrices: float32[3d]') +inv_matrices = pystencils.fields('inv_matrices: float32[3d]') +source_points = pystencils.fields('source_points: float32[1d]') + +FUNCTIONS = { + 'Cone_Backprojection3D_Kernel_Launcher': CustomFunctionCall('Cone_Backprojection3D_Kernel_Launcher', + FieldPointerSymbol(projection.name, projection.dtype, const=True), + FieldPointerSymbol(volume.name, volume.dtype, const=False), + FieldPointerSymbol(projection_matrices.name, + projection_matrices.dtype, const=True), + FieldShapeSymbol(['matrices'], 0), + *[FieldShapeSymbol(['volume'], i) for i in range(2, -1, -1)], + TypedSymbol('volume_spacing_x', create_type('float32'), const=True), + TypedSymbol('volume_spacing_y', create_type('float32'), const=True), + TypedSymbol('volume_spacing_z', create_type('float32'), const=True), + TypedSymbol('volume_origin_x', create_type('float32'), const=True), + TypedSymbol('volume_origin_y', create_type('float32'), const=True), + TypedSymbol('volume_origin_z', create_type('float32'), const=True), + *[FieldShapeSymbol(['projection'], i) for i in range(1, -1, -1)], + TypedSymbol('projection_multiplier', create_type('float32'), const=True), + fields_accessed=[volume, projection, projection_matrices], custom_signature=""" +void Cone_Backprojection3D_Kernel_Launcher(const float *sinogram_ptr, float *out, const float *projection_matrices, const int number_of_projections, + const int volume_width, const int volume_height, const int volume_depth, + const float volume_spacing_x, const float volume_spacing_y, const float volume_spacing_z, + const float volume_origin_x, const float volume_origin_y, const float volume_origin_z, + const int detector_width, const int detector_height, const float projection_multiplier); +"""), # noqa + +'Cone_Projection_Kernel_Launcher': CustomFunctionCall('Cone_Projection_Kernel_Launcher', + FieldPointerSymbol(volume.name, volume.dtype, const=True), + FieldPointerSymbol(projection.name, projection.dtype, const=False), + FieldPointerSymbol(inv_matrices.name, + inv_matrices.dtype, const=True), + FieldPointerSymbol(source_points.name, + source_points.dtype, const=True), + FieldShapeSymbol([source_points.name], 0), + *[FieldShapeSymbol(['volume'], i) for i in range(2, -1, -1)], + TypedSymbol('volume_spacing_x', create_type('float32'), const=True), + TypedSymbol('volume_spacing_y', create_type('float32'), const=True), + TypedSymbol('volume_spacing_z', create_type('float32'), const=True), + *[FieldShapeSymbol(['projection'], i) for i in range(1, -1, -1)], + TypedSymbol('step_size', create_type('float32'), const=True), + fields_accessed=[volume, projection, inv_matrices, source_points], custom_signature=""" +void Cone_Projection_Kernel_Launcher(const float* volume_ptr, float *out, const float *inv_AR_matrix, const float *src_points, + const int number_of_projections, const int volume_width, const int volume_height, const int volume_depth, + const float volume_spacing_x, const float volume_spacing_y, const float volume_spacing_z, + const int detector_width, const int detector_height, const float step_size); +""") # noqa + } + + +def main(): + + parser = argparse.ArgumentParser() + parser.add_argument('--output-folder', default=join(dirname(__file__), '..', 'src', 'pyronn_torch')) + parser.add_argument('--source-files', default=glob(join(dirname(__file__), + '..', 'src', 'pyronn_torch', 'PYRO-NN-Layers', '*.cu.cc'))) + args = parser.parse_args() + object_cache = get_cache_config()['object_cache'] + + module_name = 'PYRO_NN' + + cuda_sources = [] + makedirs(join(object_cache, module_name), exist_ok=True) + rmtree(join(object_cache, module_name, 'helper_headers')) + copytree(join(dirname(__file__), '..', 'src', 'pyronn_torch', + 'PYRO-NN-Layers', 'helper_headers'), join(object_cache, module_name, 'helper_headers')) + for s in args.source_files: + + dst = join(object_cache, module_name, basename(s).replace('.cu.cc', '.cu')) + copyfile(s, dst) # Torch only accepts *.cu as CUDA + cuda_sources.append(dst) + + functions = [WrapperFunction(Block([v]), function_name=k) for k, v in FUNCTIONS.items()] + module = TorchModule(module_name, functions, wrap_wrapper_functions=True) + + pystencils.show_code(module, custom_backend=FrameworkIntegrationPrinter()) + + extension = module.compile(extra_source_files=cuda_sources, extra_cuda_flags=['-arch=sm_35'], with_cuda=True) + + for v in extension.__dict__.values(): + if hasattr(v, '__doc__'): + print(v.__doc__) + + +if __name__ == '__main__': + main() diff --git a/src/pyronn_torch/PYRO-NN-Layers b/src/pyronn_torch/PYRO-NN-Layers index f74139b7e3211f54f57e8264d5e62850409e7c81..32dd8e7e0f43fc80a5ef905535c5212dea0c0773 160000 --- a/src/pyronn_torch/PYRO-NN-Layers +++ b/src/pyronn_torch/PYRO-NN-Layers @@ -1 +1 @@ -Subproject commit f74139b7e3211f54f57e8264d5e62850409e7c81 +Subproject commit 32dd8e7e0f43fc80a5ef905535c5212dea0c0773 diff --git a/src/pyronn_torch/__init__.py b/src/pyronn_torch/__init__.py index c121499895045854f83f8af4a39405f9d96d5137..7452884d6cd15d8045be2d8bca61ca6f2ea182ba 100644 --- a/src/pyronn_torch/__init__.py +++ b/src/pyronn_torch/__init__.py @@ -1,5 +1,6 @@ # -*- coding: utf-8 -*- -from pkg_resources import get_distribution, DistributionNotFound + +from pkg_resources import DistributionNotFound, get_distribution try: # Change here if project is renamed and does not equal the package name @@ -9,3 +10,4 @@ except DistributionNotFound: __version__ = 'unknown' finally: del get_distribution, DistributionNotFound +