Skip to content
Snippets Groups Projects
Commit af798d17 authored by Stephan Seitz's avatar Stephan Seitz
Browse files

Enable use of custom printer for classes with printer variable

parent d3c15dd7
Branches
Tags
No related merge requests found
Pipeline #19313 failed
...@@ -15,6 +15,12 @@ class FrameworkIntegrationPrinter(pystencils.backends.cbackend.CBackend): ...@@ -15,6 +15,12 @@ class FrameworkIntegrationPrinter(pystencils.backends.cbackend.CBackend):
super().__init__(sympy_printer=None, super().__init__(sympy_printer=None,
dialect='c') dialect='c')
def _print(self, node):
from pystencils_autodiff.framework_integration.astnodes import JinjaCppFile
if isinstance(node, JinjaCppFile):
node.printer = self
return super()._print(node)
def _print_WrapperFunction(self, node): def _print_WrapperFunction(self, node):
super_result = super()._print_KernelFunction(node) super_result = super()._print_KernelFunction(node)
return super_result.replace('FUNC_PREFIX ', '') return super_result.replace('FUNC_PREFIX ', '')
......
...@@ -8,11 +8,12 @@ ...@@ -8,11 +8,12 @@
""" """
import pytest import pytest
import sympy
import pystencils import pystencils
import sympy
from pystencils_autodiff import create_backward_assignments from pystencils_autodiff import create_backward_assignments
from pystencils_autodiff.backends.astnodes import PybindModule, TensorflowModule, TorchModule from pystencils_autodiff.backends.astnodes import PybindModule, TensorflowModule, TorchModule
from pystencils_autodiff.framework_integration.printer import FrameworkIntegrationPrinter
try: try:
from pystencils.interpolation_astnodes import TextureCachedField from pystencils.interpolation_astnodes import TextureCachedField
...@@ -101,5 +102,23 @@ def test_module_printing_globals(): ...@@ -101,5 +102,23 @@ def test_module_printing_globals():
print(module) print(module)
if __name__ == "__main__": def test_custom_printer():
test_module_printing_globals()
class DoesNotLikeTorchPrinter(FrameworkIntegrationPrinter):
def _print_TorchModule(self, node):
return 'Error: I don\'t like Torch'
z, y, x = pystencils.fields("z, y, x: [20,40]")
forward_assignments = pystencils.AssignmentCollection({
z[0, 0]: x[0, 0] * sympy.log(TextureCachedField(x).at(sympy.Matrix((0.43, 3))) * y[0, 0])
})
backward_assignments = create_backward_assignments(forward_assignments)
forward_ast = pystencils.create_kernel(forward_assignments)
forward_ast.function_name = 'forward'
backward_ast = pystencils.create_kernel(backward_assignments)
backward_ast.function_name = 'backward'
module = TorchModule("hallo", [forward_ast, backward_ast])
print(DoesNotLikeTorchPrinter()(module))
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Please register or to comment