diff --git a/codegen/generate_wrappers.py b/codegen/generate_wrappers.py index 350d5927e03854461b9031dfcafa133b5a15083b..d82b487667d592d0647d13008851f8b02dc2e7fe 100644 --- a/codegen/generate_wrappers.py +++ b/codegen/generate_wrappers.py @@ -54,7 +54,6 @@ void Cone_Backprojection3D_Kernel_Launcher(const float *sinogram_ptr, float *out const float volume_origin_x, const float volume_origin_y, const float volume_origin_z, const int detector_width, const int detector_height, const float projection_multiplier); """), # noqa - 'Cone_Projection_Kernel_Launcher': CustomFunctionCall('Cone_Projection_Kernel_Launcher', FieldPointerSymbol(volume.name, volume.dtype, const=True), FieldPointerSymbol(projection.name, projection.dtype, const=False), @@ -74,7 +73,30 @@ void Cone_Projection_Kernel_Launcher(const float* volume_ptr, float *out, const const int number_of_projections, const int volume_width, const int volume_height, const int volume_depth, const float volume_spacing_x, const float volume_spacing_y, const float volume_spacing_z, const int detector_width, const int detector_height, const float step_size); -""") # noqa +"""), # noqa +'Cone_Projection_Kernel_Tex_Interp_Launcher': CustomFunctionCall('Cone_Projection_Kernel_Tex_Interp_Launcher', + FieldPointerSymbol(volume.name, volume.dtype, const=True), + FieldPointerSymbol(projection.name, projection.dtype, const=False), + FieldPointerSymbol(inv_matrices.name, + inv_matrices.dtype, const=True), + FieldPointerSymbol(source_points.name, + source_points.dtype, const=True), + FieldShapeSymbol([source_points.name], 0), + *[FieldShapeSymbol(['volume'], i) for i in range(2, -1, -1)], + TypedSymbol('volume_spacing_x', create_type('float32'), const=True), + TypedSymbol('volume_spacing_y', create_type('float32'), const=True), + TypedSymbol('volume_spacing_z', create_type('float32'), const=True), + *[FieldShapeSymbol(['projection'], i) for i in range(1, -1, -1)], + TypedSymbol('step_size', create_type('float32'), const=True), + fields_accessed=[volume, projection, inv_matrices, source_points], custom_signature=""" +void Cone_Projection_Kernel_Tex_Interp_Launcher( + const float *__restrict__ volume_ptr, float *out, + const float *inv_AR_matrix, const float *src_points, + const int number_of_projections, const int volume_width, + const int volume_height, const int volume_depth, + const float volume_spacing_x, const float volume_spacing_y, + const float volume_spacing_z, const int detector_width, + const int detector_height, const float step_size);"""), # noqa } @@ -94,8 +116,8 @@ def main(): rmtree(join(object_cache, module_name, 'helper_headers')) copytree(join(dirname(__file__), '..', 'src', 'pyronn_torch', 'PYRO-NN-Layers', 'helper_headers'), join(object_cache, module_name, 'helper_headers')) - for s in args.source_files: + for s in args.source_files: dst = join(object_cache, module_name, basename(s).replace('.cu.cc', '.cu')) copyfile(s, dst) # Torch only accepts *.cu as CUDA cuda_sources.append(dst)