Commit bb88ec51 authored by Stephan Seitz

Instantiate function class

parent 374d0fd5
@@ -139,8 +139,7 @@ def create_autograd_function(autodiff_obj, inputfield_to_tensor_dict, forward_lo
         os.mkdir(cache_dir)
     # TODO: create function and stuff
-    compiled_operation = generate_torch(cache_dir, autodiff_obj, is_cuda,
-                                        dtype)
+    compiled_operation = generate_torch(cache_dir, autodiff_obj, is_cuda, dtype)
     field_to_tensor_dict = inputfield_to_tensor_dict
     # Allocate output tensor for forward and backward pass
     for field in chain(autodiff_obj.forward_output_fields, autodiff_obj.backward_output_fields):
@@ -164,7 +163,7 @@ def create_autograd_function(autodiff_obj, inputfield_to_tensor_dict, forward_lo
         cls.saved = None
         cls.forward = forward
         cls.backward = backward
-        return cls
+        return cls()
     else:
         op = pystencils_autodiff.backends._pytorch.create_autograd_function(autodiff_obj,
                                                                             inputfield_to_tensor_dict,
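The substantive change is in the second hunk: create_autograd_function previously returned the dynamically assembled class itself, so forward and backward stayed plain functions and a caller had to fill the implicit self slot by hand. Returning cls() yields an instance whose methods are bound and can be called without arguments. A minimal, self-contained sketch of that binding difference (TorchOp and its method body are hypothetical stand-ins, not the pystencils_autodiff code):

def forward(self):
    # Inputs are assumed to be pre-bound elsewhere (e.g. through a
    # field-to-tensor dict), so no arguments are needed here.
    return "forward result"

cls = type("TorchOp", (object,), {})
cls.forward = forward

# return cls   -> cls.forward is a plain function; op.forward() would raise
#                 TypeError because nothing fills the 'self' parameter.
# return cls() -> methods are bound to the instance, so op.forward() works.
op = cls()
print(op.forward())  # -> forward result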
@@ -168,7 +168,7 @@ def test_execute_torch():
     y_tensor = pystencils_autodiff.torch_tensor_from_field(y, 1, cuda=False)
     op_cpp = create_autograd_function(autodiff, {x: x_tensor, y: y_tensor})
-    foo = op_cpp.forward(x_tensor)
+    foo = op_cpp.forward()
     print(foo)
     assert op_cpp is not None
@@ -189,7 +189,7 @@ def test_execute_torch_gpu():
     op_cuda = create_autograd_function(autodiff, {x: x_tensor, y: y_tensor})
     assert op_cuda is not None
-    rtn = op_cuda.forward(y_tensor, x_tensor)
+    rtn = op_cuda.forward()
     print(y_tensor)
     print(rtn)
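Correspondingly, the tests drop the explicit tensor arguments: the input tensors are registered once through the dict passed to create_autograd_function, and forward() reads them from there. A hedged sketch of that call pattern (make_op is a stand-in for create_autograd_function, not the real API):

def make_op(tensor_dict):
    # Stand-in for create_autograd_function: inputs are captured once, here.
    class Op:
        def forward(self):
            # Read the pre-bound inputs instead of taking them as arguments.
            return {name: value * 2 for name, value in tensor_dict.items()}
    return Op()  # an instance, mirroring the 'return cls()' change

op = make_op({"x": 1.0, "y": 2.0})
print(op.forward())  # -> {'x': 2.0, 'y': 4.0}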