diff --git a/tests/test_alternative_wrappers.py b/tests/test_alternative_wrappers.py index afbf7af184d2d9165a3516c13bcf8da83f238563..3e215f8250235204aebfc9a7ebc50b2d2ae8b2ba 100644 --- a/tests/test_alternative_wrappers.py +++ b/tests/test_alternative_wrappers.py @@ -31,3 +31,13 @@ def test_wrap_tensorflow(): generate_shared_object(tempfile.TemporaryDirectory, None, show_code=True, framework_module_class=TensorflowModule, generate_code_only=True) + + +def test_wrap_torch(): + import pytest + pytest.importorskip("pystencils_autodiff") + + from pystencils_autodiff.backends.astnodes import TorchModule + + generate_shared_object(tempfile.TemporaryDirectory, None, show_code=True, + framework_module_class=TorchModule, generate_code_only=True)