diff --git a/src/pystencils_autodiff/_autodiff.py b/src/pystencils_autodiff/_autodiff.py index b182280ae67dc82a2b7aeebd5991d1ad1804631f..5032b5f8d3a1cdf60b7ac5b451a3d97b77eabee6 100644 --- a/src/pystencils_autodiff/_autodiff.py +++ b/src/pystencils_autodiff/_autodiff.py @@ -15,6 +15,7 @@ from pystencils_autodiff.backends import AVAILABLE_BACKENDS from pystencils_autodiff.transformations import add_fixed_constant_boundary_handling REMOVE_CASTS = ReplaceOptim(lambda x: isinstance(x, pystencils.data_types.cast_func), lambda x: x.args[0]) +DEFAULT_OP_NAME = "autodiffop" @pystencils.cache.disk_cache_no_fallback @@ -220,7 +221,7 @@ Backward: def __init__(self, forward_assignments: List[ps.Assignment], - op_name: str = "autodiffop", + op_name: str = DEFAULT_OP_NAME, boundary_handling: AutoDiffBoundaryHandling = None, time_constant_fields: List[ps.Field] = None, constant_fields: List[ps.Field] = [], @@ -604,8 +605,8 @@ Backward: def time_constant_fields(self): return self._time_constant_fields - def create_torch_op(self, *args, **kwags): - return self.create_tensorflow_op(*args, backend='torch_native', **kwags) + def create_torch_op(self, *args, **kwargs): + return self.create_tensorflow_op(*args, backend='torch_native', **kwargs) def create_tensorflow_op(self, inputfield_tensor_dict={}, @@ -685,7 +686,8 @@ Backward: self, inputfield_tensor_dict, forward_loop, backward_loop) elif backend == 'torch_native': import pystencils_autodiff.backends._torch_native - op = pystencils_autodiff.backends._torch_native.create_autograd_function(self, use_cuda) + op = pystencils_autodiff.backends._torch_native.create_autograd_function( + self, use_cuda, op_name=self.op_name if self.op_name != DEFAULT_OP_NAME else None) elif backend == 'tensorflow': import pystencils_autodiff.backends._tensorflow op = pystencils_autodiff.backends._tensorflow.tensorflowop_from_autodiffop( diff --git a/src/pystencils_autodiff/backends/_torch_native.py b/src/pystencils_autodiff/backends/_torch_native.py index c9c78ddc143fdf1b57bbfb495d4bbeac677d0e93..64be3e07e4f32aa195d21af09d29ffcdcd96bed4 100644 --- a/src/pystencils_autodiff/backends/_torch_native.py +++ b/src/pystencils_autodiff/backends/_torch_native.py @@ -7,7 +7,7 @@ from pystencils_autodiff.backends.astnodes import TorchModule from pystencils_autodiff.tensorflow_jit import _hash -def create_autograd_function(autodiff_obj, use_cuda): +def create_autograd_function(autodiff_obj, use_cuda, op_name=None): import torch field_to_tensor_dict = dict() # Allocate output tensor for forward and backward pass @@ -24,10 +24,11 @@ def create_autograd_function(autodiff_obj, use_cuda): forward_ast = autodiff_obj.forward_ast_cpu backward_ast = autodiff_obj.backward_ast_cpu if autodiff_obj.backward_output_fields else None - op_name = f'{autodiff_obj.op_name}_{_hash((str(pystencils.show_code(forward_ast)) + str(autodiff_obj)+str(autodiff_obj.constant_fields)).encode()).hexdigest()}' # noqa - forward_ast.function_name = f'{op_name}_{forward_ast.function_name}' - if backward_ast: - backward_ast.function_name = f'{op_name}_{backward_ast.function_name}' + if not op_name: + op_name = f'{autodiff_obj.op_name}_{_hash((str(pystencils.get_code_str(forward_ast)) + str(autodiff_obj)+str(autodiff_obj.constant_fields)).encode()).hexdigest()}' # noqa + forward_ast.function_name = f'{op_name}_{forward_ast.function_name}' + if backward_ast: + backward_ast.function_name = f'{op_name}_{backward_ast.function_name}' module = TorchModule(op_name, [forward_ast, backward_ast] if backward_ast else [forward_ast]) compiled_op = module.compile()