Skip to content
Snippets Groups Projects
Commit 324c3211 authored by Stephan Seitz's avatar Stephan Seitz
Browse files

WIP: Add test case print globals

parent 540eaa2c
No related branches found
No related tags found
No related merge requests found
Pipeline #17261 failed
......@@ -7,16 +7,18 @@
"""
"""
import sympy
import pytest
import sympy
import pystencils
from pystencils_autodiff import create_backward_assignments
from pystencils_autodiff.backends.astnodes import PybindModule, TensorflowModule, TorchModule
TARGET_TO_DIALECT = {
'cpu': 'c',
'gpu': 'cuda'
}
try:
from pystencils.interpolation_astnodes import TextureCachedField
HAS_INTERPOLATION = True
except ImportError:
HAS_INTERPOLATION = False
def test_module_printing():
......@@ -76,9 +78,30 @@ def test_module_printing_parameter():
print(module)
@pytest.mark.skipif(not HAS_INTERPOLATION, reason="")
def test_module_printing_globals():
for target in ('gpu',):
z, y, x = pystencils.fields("z, y, x: [20,40]")
forward_assignments = pystencils.AssignmentCollection({
z[0, 0]: x[0, 0] * sympy.log(TextureCachedField(x).at(sympy.Matrix((0.43, 3))) * y[0, 0])
})
backward_assignments = create_backward_assignments(forward_assignments)
forward_ast = pystencils.create_kernel(forward_assignments, target)
forward_ast.function_name = 'forward'
backward_ast = pystencils.create_kernel(backward_assignments, target)
backward_ast.function_name = 'backward'
module = TorchModule([forward_ast, backward_ast])
print(module)
def main():
test_module_printing()
# test_module_printing_parameter()
test_module_printing_parameter()
test_module_printing_globals()
if __name__ == '__main__':
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Please register or to comment