diff --git a/src/pystencils_autodiff/backends/_torch_native.py b/src/pystencils_autodiff/backends/_torch_native.py index 9fae28c5f8ebf28bca5f5f95e4c10d04b20edd9b..48417af13d375c9c624d7b783bb207b437a07d5d 100644 --- a/src/pystencils_autodiff/backends/_torch_native.py +++ b/src/pystencils_autodiff/backends/_torch_native.py @@ -19,15 +19,16 @@ def create_autograd_function(autodiff_obj, use_cuda): if use_cuda: forward_ast = autodiff_obj.forward_ast_gpu - backward_ast = autodiff_obj.backward_ast_gpu + backward_ast = autodiff_obj.backward_ast_gpu if autodiff_obj.backward_output_fields else None else: forward_ast = autodiff_obj.forward_ast_cpu - backward_ast = autodiff_obj.backward_ast_cpu + backward_ast = autodiff_obj.backward_ast_cpu if autodiff_obj.backward_output_fields else None - op_name = f'{autodiff_obj.op_name}_{_hash((str(pystencils.show_code(forward_ast))+ str(autodiff_obj)).encode()).hexdigest()}' # noqa + op_name = f'{autodiff_obj.op_name}_{_hash((str(pystencils.show_code(forward_ast)) + str(autodiff_obj)+str(autodiff_obj.constant_fields)).encode()).hexdigest()}' # noqa forward_ast.function_name = f'{op_name}_{forward_ast.function_name}' - backward_ast.function_name = f'{op_name}_{backward_ast.function_name}' - module = TorchModule(op_name, [forward_ast, backward_ast]) + if backward_ast: + backward_ast.function_name = f'{op_name}_{backward_ast.function_name}' + module = TorchModule(op_name, [forward_ast, backward_ast] if backward_ast else [forward_ast]) compiled_op = module.compile() # print(TorchModule(op_name, [forward_ast, backward_ast])) @@ -103,16 +104,26 @@ def create_autograd_function(autodiff_obj, use_cuda): return tuple(backward_output_tensors.values()) + def call(self, **kwargs): + rtn = self.apply(*[kwargs[p.symbol.name] for p in self.forward_parameters]) + if len(rtn) == 1: + rtn = rtn[0] + return rtn + cls = type(op_name, (torch.autograd.Function,), {}) cls.class_kwargs = class_kwargs cls.forward = forward cls.backward = backward cls.kernel = forward cls.ast = module - cls.parameters = forward_ast.get_parameters() + cls.parameters = [f for f in module.kernel_wrappers + if f.function_name == "call_" + forward_ast.function_name][0].get_parameters() + cls.forward_parameters = [p for p in cls.parameters if p.symbol.name in [ + f.name for f in autodiff_obj.forward_input_fields]] cls.forward_ast = forward_ast cls.backward_ast = backward_ast cls.num_regs = None + cls.call = call cls.code = str(module) return cls