Skip to content
Snippets Groups Projects
Commit bdd0e1e6 authored by Stephan Seitz's avatar Stephan Seitz
Browse files

Allow sole WrapperFunctions in TorchModule

parent ef057629
Branches
Tags
No related merge requests found
Pipeline #19193 failed
......@@ -55,7 +55,7 @@ class PybindArrayDestructuring(DestructuringBindingsForFieldClass):
CLASS_TO_MEMBER_DICT = {
FieldPointerSymbol: "mutable_data()",
FieldShapeSymbol: "shape({dim})",
FieldStrideSymbol: "strides({dim})"
FieldStrideSymbol: "strides({dim}) / sizeof({dtype})"
}
CLASS_NAME_TEMPLATE = "pybind11::array_t<{dtype}>"
......@@ -77,7 +77,10 @@ class TorchModule(JinjaCppFile):
"""
if not isinstance(kernel_asts, Iterable):
kernel_asts = [kernel_asts]
wrapper_functions = [self.generate_wrapper_function(k) for k in kernel_asts]
wrapper_functions = [self.generate_wrapper_function(k)
if not isinstance(k, WrapperFunction)
else k for k in kernel_asts]
kernel_asts = [k for k in kernel_asts if not isinstance(k, WrapperFunction)]
self.module_name = module_name
ast_dict = {
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Please register or to comment