From 36c557937b4a7174cf586876779149995b97b48c Mon Sep 17 00:00:00 2001
From: Frederik Hennig <frederik.hennig@fau.de>
Date: Fri, 22 Dec 2023 18:24:56 +0100
Subject: [PATCH] add define_once and custom generators; made define variadic

---
 src/pystencilssfg/composer/basic_composer.py | 18 +++++++++++++++---
 src/pystencilssfg/composer/custom.py         | 11 +++++++++++
 2 files changed, 26 insertions(+), 3 deletions(-)
 create mode 100644 src/pystencilssfg/composer/custom.py

diff --git a/src/pystencilssfg/composer/basic_composer.py b/src/pystencilssfg/composer/basic_composer.py
index 56583f0..f5962b8 100644
--- a/src/pystencilssfg/composer/basic_composer.py
+++ b/src/pystencilssfg/composer/basic_composer.py
@@ -6,6 +6,7 @@ import numpy as np
 from pystencils import Field, TypedSymbol
 from pystencils.astnodes import KernelFunction
 
+from .custom import CustomGenerator
 from ..tree import (
     SfgCallTreeNode,
     SfgKernelCallNode,
@@ -53,14 +54,25 @@ class SfgBasicComposer:
         """
         self._ctx.append_to_prelude(content)
 
-    def define(self, definition: str):
-        """Add a custom definition to the generated header file."""
-        self._ctx.add_definition(definition)
+    def define(self, *definitions: str):
+        """Add custom definitions to the generated header file."""
+        for d in definitions:
+            self._ctx.add_definition(d)
+
+    def define_once(self, *definitions: str):
+        """Same as `define`, but only adds definitions only if the same code string was not already added."""
+        for definition in definitions:
+            if all(d != definition for d in self._ctx.definitions()):
+                self._ctx.add_definition(definition)
 
     def namespace(self, namespace: str):
         """Set the inner code namespace. Throws an exception if a namespace was already set."""
         self._ctx.set_namespace(namespace)
 
+    def generate(self, generator: CustomGenerator):
+        """Invokes a custom code generator with the underlying context."""
+        generator.generate(self._ctx)
+
     @property
     def kernels(self) -> SfgKernelNamespace:
         """The default kernel namespace. Add kernels like:
diff --git a/src/pystencilssfg/composer/custom.py b/src/pystencilssfg/composer/custom.py
new file mode 100644
index 0000000..1b43dd3
--- /dev/null
+++ b/src/pystencilssfg/composer/custom.py
@@ -0,0 +1,11 @@
+from abc import ABC, abstractmethod
+from ..context import SfgContext
+
+
+class CustomGenerator(ABC):
+    """Abstract base class for custom code generators that may be passed to
+    [SfgComposer.generate][pystencilssfg.SfgComposer.generate]."""
+
+    @abstractmethod
+    def generate(self, ctx: SfgContext) -> None:
+        ...
-- 
GitLab