Coverage for src/pystencilssfg/composer/basic_composer.py: 87%
270 statements
« prev ^ index » next coverage.py v7.8.0, created at 2025-04-04 07:16 +0000
1from __future__ import annotations
3from typing import Sequence, TypeAlias
4from abc import ABC, abstractmethod
5import sympy as sp
6from functools import reduce
7from warnings import warn
9from pystencils import (
10 Field,
11 CreateKernelConfig,
12 create_kernel,
13 Assignment,
14 AssignmentCollection,
15)
16from pystencils.codegen import Kernel, Lambda
17from pystencils.types import create_type, UserTypeSpec, PsType
19from ..context import SfgContext, SfgCursor
20from .custom import CustomGenerator
21from ..ir import (
22 SfgCallTreeNode,
23 SfgKernelCallNode,
24 SfgStatements,
25 SfgFunctionParams,
26 SfgRequireIncludes,
27 SfgSequence,
28 SfgBlock,
29 SfgBranch,
30 SfgSwitch,
31)
32from ..ir.postprocessing import (
33 SfgDeferredParamSetter,
34 SfgDeferredFieldMapping,
35 SfgDeferredVectorMapping,
36)
37from ..ir import (
38 SfgFunction,
39 SfgKernelNamespace,
40 SfgKernelHandle,
41 SfgEntityDecl,
42 SfgEntityDef,
43 SfgNamespaceBlock,
44)
45from ..lang import (
46 VarLike,
47 ExprLike,
48 _VarLike,
49 _ExprLike,
50 asvar,
51 depends,
52 HeaderFile,
53 includes,
54 SfgVar,
55 SfgKernelParamVar,
56 AugExpr,
57 SupportsFieldExtraction,
58 SupportsVectorExtraction,
59 void,
60)
61from ..exceptions import SfgException
class SfgIComposer(ABC):
    """Abstract base of all composers.

    Stores the code generation context together with the context's cursor,
    through which composers emit their output.
    """

    def __init__(self, ctx: SfgContext):
        # Keep the context itself and a shortcut to its cursor.
        self._ctx = ctx
        self._cursor = ctx.cursor

    @property
    def context(self):
        """The code generation context this composer operates on."""
        return self._ctx
class SfgNodeBuilder(ABC):
    """Base class for node builders used by the composer.

    A node builder collects its configuration (possibly over several calls)
    and produces the resulting call tree node when `resolve` is invoked.
    """

    @abstractmethod
    def resolve(self) -> SfgCallTreeNode:
        """Produce the call tree node described by this builder."""
SequencerArg: TypeAlias = tuple | ExprLike | SfgCallTreeNode | SfgNodeBuilder
"""Valid arguments to `make_sequence` and any sequencer that uses it."""

# Runtime counterpart of `SequencerArg`, spelled as a plain tuple of the same members.
_SequencerArg = (tuple, ExprLike, SfgCallTreeNode, SfgNodeBuilder)
class KernelsAdder:
    """Handle on a kernel namespace that permits registering kernels."""

    def __init__(self, cursor: SfgCursor, knamespace: SfgKernelNamespace):
        self._cursor = cursor
        self._kernel_namespace = knamespace
        self._inline: bool = False
        # Namespace block the kernel definitions are written into; created lazily by `_get_loc`.
        self._loc: SfgNamespaceBlock | None = None

    def inline(self) -> KernelsAdder:
        """Generate kernel definitions ``inline`` in the header file."""
        self._inline = True
        return self

    def add(self, kernel: Kernel, name: str | None = None):
        """Adds an existing pystencils AST to this namespace.
        If a name is specified, the AST's function name is changed.

        Args:
            kernel: The pystencils kernel to register.
            name: Optional new name for the kernel; defaults to ``kernel.name``.

        Returns:
            The `SfgKernelHandle` of the registered kernel.

        Raises:
            ValueError: If a kernel with the same name already exists in this namespace.
        """
        kernel_name = kernel.name if name is None else name

        self._check_unique(kernel_name)

        if name is not None:
            kernel.name = kernel_name

        khandle = SfgKernelHandle(
            kernel_name, self._kernel_namespace, kernel, inline=self._inline
        )
        self._kernel_namespace.add_kernel(khandle)

        loc = self._get_loc()
        loc.elements.append(SfgEntityDef(khandle))

        # Headers required by the kernel go wherever its definition is emitted.
        for header in kernel.required_headers:
            hfile = HeaderFile.parse(header)
            if self._inline:
                self._cursor.context.header_file.includes.append(hfile)
            else:
                impl_file = self._cursor.context.impl_file
                assert impl_file is not None
                impl_file.includes.append(hfile)

        return khandle

    def create(
        self,
        assignments: Assignment | Sequence[Assignment] | AssignmentCollection,
        name: str | None = None,
        config: CreateKernelConfig | None = None,
    ):
        """Creates a new pystencils kernel from a list of assignments and a configuration.
        This is a wrapper around `create_kernel <pystencils.codegen.create_kernel>`
        with a subsequent call to `add`.

        Args:
            assignments: The assignments making up the kernel.
            name: Optional name for the kernel.
            config: Optional kernel configuration. It is not modified by this method.

        Raises:
            ValueError: If a kernel called ``name`` already exists in this namespace.
        """
        if config is None:
            config = CreateKernelConfig()

        # Detect name clashes before spending time on kernel generation.
        if name is not None:
            self._check_unique(name)

        kernel = create_kernel(assignments, config=config)
        # The kernel is renamed by `add` rather than by writing ``name`` into
        # ``config``, so a caller-supplied configuration object is left untouched.
        return self.add(kernel, name)

    def _check_unique(self, kernel_name: str) -> None:
        """Raise `ValueError` if a kernel called ``kernel_name`` already exists here."""
        if self._kernel_namespace.find_kernel(kernel_name) is not None:
            raise ValueError(
                f"Duplicate kernels: A kernel called {kernel_name} already exists "
                f"in namespace {self._kernel_namespace.fqname}"
            )

    def _get_loc(self) -> SfgNamespaceBlock:
        """Return the namespace block for kernel definitions, creating it on first use.

        The block is written to the header file if kernels are inline,
        and to the implementation file otherwise.
        """
        if self._loc is None:
            kns_block = SfgNamespaceBlock(self._kernel_namespace)

            if self._inline:
                self._cursor.write_header(kns_block)
            else:
                self._cursor.write_impl(kns_block)

            self._loc = kns_block
        return self._loc
class SfgBasicComposer(SfgIComposer):
    """Composer for basic source components, and base class for all composer mix-ins."""

    def __init__(self, sfg: SfgContext | SfgIComposer):
        # Accept either a raw context or another composer, whose context is then shared.
        ctx: SfgContext = sfg if isinstance(sfg, SfgContext) else sfg.context
        super().__init__(ctx)

    def prelude(self, content: str, end: str = "\n"):
        """Append a string to the prelude comment, to be printed at the top of both generated files.

        The string should not contain C/C++ comment delimiters, since these will be added automatically
        during code generation.

        :Example:
            >>> sfg.prelude("This file was generated using pystencils-sfg; do not modify it directly!")

            will appear in the generated files as

            .. code-block:: C++

                /*
                 * This file was generated using pystencils-sfg; do not modify it directly!
                 */

        """
        for f in self._ctx.files:
            if f.prelude is None:
                f.prelude = content + end
            else:
                f.prelude += content + end

    def code(self, *code: str, impl: bool = False):
        """Add arbitrary lines of code to the generated header file.

        :Example:

            >>> sfg.code(
            ...     "#define PI 3.14 // more than enough for engineers",
            ...     "using namespace std;"
            ... )

            will appear as

            .. code-block:: C++

                #define PI 3.14 // more than enough for engineers
                using namespace std;

        Args:
            code: Sequence of code strings to be written to the output file
            impl: If `True`, write the code to the implementation file; otherwise, to the header file.
        """
        for c in code:
            if impl:
                self._cursor.write_impl(c)
            else:
                self._cursor.write_header(c)

    def define(self, *definitions: str):
        """Deprecated: write the given lines to the header file.

        Use `code` instead.
        """
        warn(
            "The `define` method of `SfgBasicComposer` is deprecated and will be removed in a future version. "
            "Use `sfg.code()` instead.",
            FutureWarning,
            stacklevel=2,
        )

        self.code(*definitions)

    def namespace(self, namespace: str):
        """Enter a new namespace block.

        Calling `namespace` as a regular function will open a new namespace as a child of the
        currently active namespace; this new namespace will then become active instead.
        Using `namespace` as a context manager will instead activate the given namespace
        only for the length of the ``with`` block.

        Args:
            namespace: Qualified name of the namespace

        :Example:

        The following calls will set the current namespace to ``outer::inner``
        for the remaining code generation run:

        .. code-block::

            sfg.namespace("outer")
            sfg.namespace("inner")

        Subsequent calls to `namespace` can only create further nested namespaces.

        To step back out of a namespace, `namespace` can also be used as a context manager:

        .. code-block::

            with sfg.namespace("detail"):
                ...

        This way, code generated inside the ``with`` region is placed in the ``detail`` namespace,
        and code after this block will again live in the enclosing namespace.

        """
        return self._cursor.enter_namespace(namespace)

    def generate(self, generator: CustomGenerator):
        """Invoke a custom code generator with the underlying context."""
        # Function-local import; presumably avoids an import cycle with `.composer`.
        from .composer import SfgComposer

        generator.generate(SfgComposer(self))

    @property
    def kernels(self) -> KernelsAdder:
        """The default kernel namespace.

        Add kernels like::

            sfg.kernels.add(ast, "kernel_name")
            sfg.kernels.create(assignments, "kernel_name", config)
        """
        return self.kernel_namespace("kernels")

    def kernel_namespace(self, name: str) -> KernelsAdder:
        """Return a view on a kernel namespace in order to add kernels to it.

        The namespace is looked up through the cursor; if it does not exist,
        it is created as a child of the cursor's current namespace.
        If no implementation file is being generated, kernels are added inline.

        Raises:
            ValueError: If an entity called ``name`` exists but is not a kernel namespace.
        """
        kns = self._cursor.get_entity(name)
        if kns is None:
            kns = SfgKernelNamespace(name, self._cursor.current_namespace)
            self._cursor.add_entity(kns)
        elif not isinstance(kns, SfgKernelNamespace):
            raise ValueError(
                f"The existing entity {kns.fqname} is not a kernel namespace"
            )

        kadder = KernelsAdder(self._cursor, kns)
        if self._ctx.impl_file is None:
            kadder.inline()
        return kadder

    def include(self, header: str | HeaderFile, private: bool = False):
        """Include a header file.

        Args:
            header: Path to the header file. Enclose in ``<>`` for a system header.
            private: If ``True``, in header-implementation code generation, the header file is
                only included in the implementation file.

        Raises:
            ValueError: If ``private`` is set but no implementation file is being generated.

        :Example:

            >>> sfg.include("<vector>")
            >>> sfg.include("custom.h")

            will be printed as

            .. code-block:: C++

                #include <vector>
                #include "custom.h"
        """
        header_file = HeaderFile.parse(header)

        if private:
            if self._ctx.impl_file is None:
                raise ValueError(
                    "Cannot emit a private include since no implementation file is being generated"
                )
            self._ctx.impl_file.includes.append(header_file)
        else:
            self._ctx.header_file.includes.append(header_file)

    def kernel_function(self, name: str, kernel: Kernel | SfgKernelHandle):
        """Create a function comprising just a single kernel call.

        Args:
            name: Name of the generated function. If ``kernel`` is a pystencils `Kernel`,
                it is also registered under this name in the default ``kernels`` namespace.
            kernel: Either a pystencils kernel, or a handle for an already registered kernel.
        """
        if isinstance(kernel, Kernel):
            khandle = self.kernels.add(kernel, name)
        else:
            khandle = kernel

        self.function(name)(self.call(khandle))

    def function(
        self,
        name: str,
        return_type: UserTypeSpec | None = None,
    ) -> SfgFunctionSequencer:
        """Add a function.

        The syntax of this function adder uses a chain of two calls to mimic C++ syntax:

        .. code-block:: Python

            sfg.function("FunctionName")(
                # Function Body
            )

        The function body is constructed via sequencing (see `make_sequence`).
        If no implementation file is being generated, the function is marked ``inline``.
        """
        seq = SfgFunctionSequencer(self._cursor, name)

        if return_type is not None:
            warn(
                "The parameter `return_type` to `function()` is deprecated and will be removed by version 0.1. "
                "Use `.returns()` instead.",
                FutureWarning,
                stacklevel=2,
            )
            seq.returns(return_type)

        if self._ctx.impl_file is None:
            seq.inline()

        return seq

    def call(self, kernel_handle: SfgKernelHandle) -> SfgCallTreeNode:
        """Use inside a function body to directly call a kernel.

        When using `call`, the given kernel will simply be called as a function.
        To invoke a GPU kernel on a specified launch grid,
        use `gpu_invoke <SfgGpuComposer.gpu_invoke>` instead.

        Args:
            kernel_handle: Handle to a kernel previously added to some kernel namespace.
        """
        return SfgKernelCallNode(kernel_handle)

    def seq(self, *args: SequencerArg) -> SfgSequence:
        """Syntax sequencing. For details, see `make_sequence`"""
        return make_sequence(*args)

    def params(self, *args: AugExpr) -> SfgFunctionParams:
        """Use inside a function body to add parameters to the function."""
        return SfgFunctionParams([x.as_variable() for x in args])

    def require(self, *incls: str | HeaderFile) -> SfgRequireIncludes:
        """Use inside a function body to require the inclusion of headers."""
        return SfgRequireIncludes((HeaderFile.parse(incl) for incl in incls))

    def var(self, name: str, dtype: UserTypeSpec) -> AugExpr:
        """Create a variable with given name and data type."""
        return AugExpr(create_type(dtype)).var(name)

    def vars(self, names: str, dtype: UserTypeSpec) -> tuple[AugExpr, ...]:
        """Create multiple variables with given names and the same data type.

        Example:

        >>> sfg.vars("x, y, z", "float32")
        (x, y, z)

        """
        varnames = names.split(",")
        return tuple(self.var(n.strip(), dtype) for n in varnames)

    def init(self, lhs: VarLike):
        """Create a C++ in-place initialization.

        Usage:

        .. code-block:: Python

            obj = sfg.var("obj", "SomeClass")
            sfg.init(obj)(arg1, arg2, arg3)

        becomes

        .. code-block:: C++

            SomeClass obj { arg1, arg2, arg3 };
        """
        lhs_var = asvar(lhs)

        def parse_args(*args: ExprLike):
            args_str = ", ".join(str(arg) for arg in args)
            # Both folds start from an empty set, so an empty argument list
            # (``sfg.init(obj)()``) does not make `reduce` raise a TypeError.
            deps: set[SfgVar] = reduce(
                set.union, (depends(arg) for arg in args), set()
            )
            incls: set[HeaderFile] = reduce(
                set.union, (includes(arg) for arg in args), set()
            )
            return SfgStatements(
                # Doubled braces emit literal ``{`` / ``}`` for C++ brace-initialization.
                f"{lhs_var.dtype.c_string()} {lhs_var.name} {{ {args_str} }};",
                (lhs_var,),
                deps,
                incls,
            )

        return parse_args

    def expr(self, fmt: str, *deps, **kwdeps) -> AugExpr:
        """Create an expression while keeping track of variables it depends on.

        This method is meant to be used similarly to `str.format`; in fact,
        it calls `str.format` internally and therefore supports all of its
        formatting features.
        In addition, however, the format arguments are scanned for *variables*
        (e.g. created using `var`), which are attached to the expression.
        This way, *pystencils-sfg* keeps track of any variables an expression depends on.

        :Example:

            >>> x, y, z, w = sfg.vars("x, y, z, w", "float32")
            >>> expr = sfg.expr("{} + {} * {}", x, y, z)
            >>> expr
            x + y * z

            You can look at the expression's dependencies:

            >>> sorted(expr.depends, key=lambda v: v.name)
            [x: float32, y: float32, z: float32]

            If you use an existing expression to create a larger one, the new expression
            inherits all variables from its parts:

            >>> expr2 = sfg.expr("{} + {}", expr, w)
            >>> expr2
            x + y * z + w
            >>> sorted(expr2.depends, key=lambda v: v.name)
            [w: float32, x: float32, y: float32, z: float32]

        """
        return AugExpr.format(fmt, *deps, **kwdeps)

    def expr_from_lambda(self, lamb: Lambda) -> AugExpr:
        """Create an expression from a pystencils `Lambda`.

        The lambda's C code becomes the expression's code, each of the lambda's
        parameters is recorded as a dependency, and the expression's data type
        is the lambda's return type.
        """
        # Named so as not to shadow the module-level `depends` function.
        lamb_deps = set(SfgKernelParamVar(p) for p in lamb.parameters)
        code = lamb.c_code()
        return AugExpr.make(code, lamb_deps, dtype=lamb.return_type)

    @property
    def branch(self) -> SfgBranchBuilder:
        """Use inside a function body to create an if/else conditional branch.

        The syntax is:

        .. code-block:: Python

            sfg.branch("condition")(
                # then-body
            )(
                # else-body (may be omitted)
            )
        """
        return SfgBranchBuilder()

    def switch(self, switch_arg: ExprLike, autobreak: bool = True) -> SfgSwitchBuilder:
        """Use inside a function to construct a switch-case statement.

        Args:
            switch_arg: Argument to the `switch()` statement
            autobreak: Whether to automatically print a ``break;`` at the end of each case block
        """
        return SfgSwitchBuilder(switch_arg, autobreak=autobreak)

    def map_field(
        self,
        field: Field,
        index_provider: SupportsFieldExtraction,
        cast_indexing_symbols: bool = True,
    ) -> SfgDeferredFieldMapping:
        """Map a pystencils field to a field data structure, from which pointers, sizes
        and strides should be extracted.

        Args:
            field: The pystencils field to be mapped
            index_provider: An object that provides the field indexing information
            cast_indexing_symbols: Whether to always introduce explicit casts for indexing symbols
        """
        return SfgDeferredFieldMapping(
            field, index_provider, cast_indexing_symbols=cast_indexing_symbols
        )

    def set_param(
        self, param: VarLike | sp.Symbol, expr: ExprLike
    ) -> SfgDeferredParamSetter:
        """Set a kernel parameter to an expression.

        Code setting the parameter will only be generated if the parameter
        is actually alive (i.e. required by some kernel, and not yet set) at
        the point this method is called.
        """
        var: SfgVar | sp.Symbol = asvar(param) if isinstance(param, _VarLike) else param
        return SfgDeferredParamSetter(var, expr)

    def map_vector(
        self,
        lhs_components: Sequence[VarLike | sp.Symbol],
        rhs: SupportsVectorExtraction,
    ) -> SfgDeferredVectorMapping:
        """Extracts scalar numerical values from a vector data type.

        Args:
            lhs_components: Vector components as a list of symbols.
            rhs: An object providing access to vector components
        """
        components: list[SfgVar | sp.Symbol] = [
            (asvar(c) if isinstance(c, _VarLike) else c) for c in lhs_components
        ]
        return SfgDeferredVectorMapping(components, rhs)
def make_statements(arg: ExprLike) -> SfgStatements:
    """Wrap a single expression or code string into an `SfgStatements` node.

    The statement's code is ``str(arg)``. It defines no variables, and its
    dependencies and required headers are taken from `depends` and `includes`.
    """
    code_text = str(arg)
    required_vars = depends(arg)
    required_headers = includes(arg)
    return SfgStatements(code_text, (), required_vars, required_headers)
def make_sequence(*args: SequencerArg) -> SfgSequence:
    """Construct a sequence of C++ code from various kinds of arguments.

    `make_sequence` is ubiquitous throughout the function building front-end;
    among others, it powers the syntax of `SfgBasicComposer.function`
    and `SfgBasicComposer.branch`.

    `make_sequence` constructs an abstract syntax tree for code within a function body, accepting various
    types of arguments which then get turned into C++ code. These are

    - Strings (`str`) are printed as-is
    - Tuples (`tuple`) signify *blocks*, i.e. C++ code regions enclosed in ``{ }``
    - Sub-ASTs and AST builders, which are often produced by the syntactic sugar and
      factory methods of `SfgComposer`.

    Node builders are resolved immediately, in argument order.

    :Example:

        .. code-block:: Python

            tree = make_sequence(
                "int a = 0;",
                "int b = 1;",
                (
                    "int tmp = b;",
                    "b = a;",
                    "a = tmp;"
                ),
                SfgKernelCallNode(kernel_handle)
            )

        Used as the body of a function ``myFunction``, this translates to

        .. code-block:: C++

            void myFunction() {
                int a = 0;
                int b = 1;
                {
                    int tmp = b;
                    b = a;
                    a = tmp;
                }
                kernels::kernel( ... );
            }

    Raises:
        TypeError: If an argument is of none of the accepted kinds.
    """
    children = []
    for i, arg in enumerate(args):
        if isinstance(arg, SfgNodeBuilder):
            children.append(arg.resolve())
        elif isinstance(arg, SfgCallTreeNode):
            children.append(arg)
        elif isinstance(arg, _ExprLike):
            children.append(make_statements(arg))
        elif isinstance(arg, tuple):
            # Tuples are treated as blocks
            subseq = make_sequence(*arg)
            children.append(SfgBlock(subseq))
        else:
            raise TypeError(f"Sequence argument {i} has invalid type.")

    return SfgSequence(children)
class SfgFunctionSequencerBase:
    """Common base class for function and method sequencers.

    This builder uses call sequencing to specify the function or method's properties.

    Example:

    >>> sfg.function(
    ...     "myFunction"
    ... ).returns(
    ...     "float32"
    ... ).attr(
    ...     "nodiscard", "maybe_unused"
    ... ).inline().constexpr()(
    ...     "return 31.2;"
    ... )
    """

    def __init__(self, cursor: SfgCursor, name: str) -> None:
        self._cursor = cursor
        self._name = name
        self._return_type: PsType = void
        # ``None`` means: no explicit parameter list was given via `params`.
        self._params: list[SfgVar] | None = None

        # Qualifiers
        self._inline: bool = False
        self._constexpr: bool = False

        # Attributes
        self._attributes: list[str] = []

    def returns(self, rtype: UserTypeSpec):
        """Set the function's return type; returns this sequencer for chaining."""
        self._return_type = create_type(rtype)
        return self

    def params(self, *args: VarLike):
        """Specify the parameters for this function.

        Use this to manually specify the function's parameter list.

        If any free variables collected from the function body are not contained
        in the parameter list, an error will be raised.
        """
        self._params = list(map(asvar, args))
        return self

    def inline(self):
        """Mark this function as ``inline``; returns this sequencer for chaining."""
        self._inline = True
        return self

    def constexpr(self):
        """Mark this function as ``constexpr``; returns this sequencer for chaining."""
        self._constexpr = True
        return self

    def attr(self, *attrs: str):
        """Append attributes to this function; returns this sequencer for chaining."""
        self._attributes.extend(attrs)
        return self
class SfgFunctionSequencer(SfgFunctionSequencerBase):
    """Sequencer for constructing functions.

    Calling the sequencer with the function body registers the function in the
    cursor's current namespace. Inline functions are defined in the header;
    all others are declared in the header and defined in the implementation file.
    """

    def __call__(self, *args: SequencerArg) -> None:
        """Populate the function body"""
        body = make_sequence(*args)
        func = SfgFunction(
            self._name,
            self._cursor.current_namespace,
            body,
            return_type=self._return_type,
            inline=self._inline,
            constexpr=self._constexpr,
            attributes=self._attributes,
            required_params=self._params,
        )
        self._cursor.add_entity(func)

        if not self._inline:
            self._cursor.write_header(SfgEntityDecl(func))
            self._cursor.write_impl(SfgEntityDef(func))
            return

        self._cursor.write_header(SfgEntityDef(func))
726class SfgBranchBuilder(SfgNodeBuilder):
727 """Multi-call builder for C++ ``if/else`` statements."""
729 def __init__(self) -> None:
730 self._phase = 0
732 self._cond: ExprLike | None = None
733 self._branch_true = SfgSequence(())
734 self._branch_false: SfgSequence | None = None
736 def __call__(self, *args) -> SfgBranchBuilder:
737 match self._phase:
738 case 0: # Condition
739 if len(args) != 1:
740 raise ValueError(
741 "Must specify exactly one argument as branch condition!"
742 )
744 self._cond = args[0]
746 case 1: # Then-branch
747 self._branch_true = make_sequence(*args)
748 case 2: # Else-branch
749 self._branch_false = make_sequence(*args)
750 case _: # There's no third branch!
751 raise TypeError("Branch construct already complete.")
753 self._phase += 1
755 return self
757 def resolve(self) -> SfgCallTreeNode:
758 assert self._cond is not None
759 return SfgBranch(
760 make_statements(self._cond), self._branch_true, self._branch_false
761 )
class SfgSwitchBuilder(SfgNodeBuilder):
    """Builder for C++ switches.

    Cases are added via `case` (or in bulk via `cases`), and the default case via `default`.
    """

    def __init__(self, switch_arg: ExprLike, autobreak: bool = True):
        self._switch_arg = switch_arg
        self._cases: dict[str, SfgSequence] = dict()
        self._default: SfgSequence | None = None
        # If set, ``break;`` is appended to every (non-default) case body.
        self._autobreak = autobreak

    def case(self, label: str):
        """Start a ``case`` block; call the returned sequencer with the case body.

        Raises:
            SfgException: If a case with the same label was already added.
        """
        if label in self._cases:
            raise SfgException(f"Duplicate case: {label}")

        def sequencer(*args: SequencerArg):
            # Re-check at registration: two `case()` calls with the same label can
            # both pass the check above before either sequencer is invoked.
            if label in self._cases:
                raise SfgException(f"Duplicate case: {label}")
            if self._autobreak:
                args += ("break;",)
            tree = make_sequence(*args)
            self._cases[label] = tree
            return self

        return sequencer

    def cases(self, cases_dict: dict[str, SequencerArg]):
        """Add several cases at once, mapping each label to its case body."""
        for key, value in cases_dict.items():
            self.case(key)(value)
        return self

    def default(self, *args):
        """Set the body of the ``default`` case.

        Raises:
            SfgException: If a default case was already set.
        """
        if self._default is not None:
            raise SfgException("Duplicate default case")

        tree = make_sequence(*args)
        self._default = tree

        return self

    def resolve(self) -> SfgCallTreeNode:
        """Build the `SfgSwitch` node from the collected cases."""
        return SfgSwitch(make_statements(self._switch_arg), self._cases, self._default)