diff --git a/src/pystencilssfg/composer/basic_composer.py b/src/pystencilssfg/composer/basic_composer.py index 56583f063f3aba0623fa4b8b5f84836a93795e08..f5962b8cfaab7da9789021f162d7284855c6919f 100644 --- a/src/pystencilssfg/composer/basic_composer.py +++ b/src/pystencilssfg/composer/basic_composer.py @@ -6,6 +6,7 @@ import numpy as np from pystencils import Field, TypedSymbol from pystencils.astnodes import KernelFunction +from .custom import CustomGenerator from ..tree import ( SfgCallTreeNode, SfgKernelCallNode, @@ -53,14 +54,25 @@ class SfgBasicComposer: """ self._ctx.append_to_prelude(content) - def define(self, definition: str): - """Add a custom definition to the generated header file.""" - self._ctx.add_definition(definition) + def define(self, *definitions: str): + """Add custom definitions to the generated header file.""" + for d in definitions: + self._ctx.add_definition(d) + + def define_once(self, *definitions: str): + """Same as `define`, but only adds definitions only if the same code string was not already added.""" + for definition in definitions: + if all(d != definition for d in self._ctx.definitions()): + self._ctx.add_definition(definition) def namespace(self, namespace: str): """Set the inner code namespace. Throws an exception if a namespace was already set.""" self._ctx.set_namespace(namespace) + def generate(self, generator: CustomGenerator): + """Invokes a custom code generator with the underlying context.""" + generator.generate(self._ctx) + @property def kernels(self) -> SfgKernelNamespace: """The default kernel namespace. Add kernels like: diff --git a/src/pystencilssfg/composer/custom.py b/src/pystencilssfg/composer/custom.py new file mode 100644 index 0000000000000000000000000000000000000000..1b43dd3d5a0e46bbbd88cf3e0bd1758d3835a4a4 --- /dev/null +++ b/src/pystencilssfg/composer/custom.py @@ -0,0 +1,11 @@ +from abc import ABC, abstractmethod +from ..context import SfgContext + + +class CustomGenerator(ABC): + """Abstract base class for custom code generators that may be passed to + [SfgComposer.generate][pystencilssfg.SfgComposer.generate].""" + + @abstractmethod + def generate(self, ctx: SfgContext) -> None: + ...