Skip to content
Snippets Groups Projects
Commit 1c955a6a authored by Stephan Seitz's avatar Stephan Seitz
Browse files

Regard constant_fields in ops hash in _torch_native

parent 943505f7
No related branches found
No related tags found
No related merge requests found
Pipeline #21053 failed
......@@ -19,15 +19,16 @@ def create_autograd_function(autodiff_obj, use_cuda):
if use_cuda:
forward_ast = autodiff_obj.forward_ast_gpu
backward_ast = autodiff_obj.backward_ast_gpu
backward_ast = autodiff_obj.backward_ast_gpu if autodiff_obj.backward_output_fields else None
else:
forward_ast = autodiff_obj.forward_ast_cpu
backward_ast = autodiff_obj.backward_ast_cpu
backward_ast = autodiff_obj.backward_ast_cpu if autodiff_obj.backward_output_fields else None
op_name = f'{autodiff_obj.op_name}_{_hash((str(pystencils.show_code(forward_ast))+ str(autodiff_obj)).encode()).hexdigest()}' # noqa
op_name = f'{autodiff_obj.op_name}_{_hash((str(pystencils.show_code(forward_ast)) + str(autodiff_obj)+str(autodiff_obj.constant_fields)).encode()).hexdigest()}' # noqa
forward_ast.function_name = f'{op_name}_{forward_ast.function_name}'
backward_ast.function_name = f'{op_name}_{backward_ast.function_name}'
module = TorchModule(op_name, [forward_ast, backward_ast])
if backward_ast:
backward_ast.function_name = f'{op_name}_{backward_ast.function_name}'
module = TorchModule(op_name, [forward_ast, backward_ast] if backward_ast else [forward_ast])
compiled_op = module.compile()
# print(TorchModule(op_name, [forward_ast, backward_ast]))
......@@ -103,16 +104,26 @@ def create_autograd_function(autodiff_obj, use_cuda):
return tuple(backward_output_tensors.values())
def call(self, **kwargs):
rtn = self.apply(*[kwargs[p.symbol.name] for p in self.forward_parameters])
if len(rtn) == 1:
rtn = rtn[0]
return rtn
cls = type(op_name, (torch.autograd.Function,), {})
cls.class_kwargs = class_kwargs
cls.forward = forward
cls.backward = backward
cls.kernel = forward
cls.ast = module
cls.parameters = forward_ast.get_parameters()
cls.parameters = [f for f in module.kernel_wrappers
if f.function_name == "call_" + forward_ast.function_name][0].get_parameters()
cls.forward_parameters = [p for p in cls.parameters if p.symbol.name in [
f.name for f in autodiff_obj.forward_input_fields]]
cls.forward_ast = forward_ast
cls.backward_ast = backward_ast
cls.num_regs = None
cls.call = call
cls.code = str(module)
return cls
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Please register or to comment