From 23c8768b46dbdc15ec8f2de232daa64e5b6e5d8e Mon Sep 17 00:00:00 2001 From: Behzad Safaei <iwia103h@a0226.nhr.fau.de> Date: Thu, 13 Feb 2025 23:02:50 +0100 Subject: [PATCH] Add support for modules with non-void return types --- src/pairs/analysis/__init__.py | 5 ++++- src/pairs/analysis/modules.py | 12 ++++++++++++ src/pairs/code_gen/cgen.py | 8 +++++++- src/pairs/ir/functions.py | 2 +- src/pairs/ir/module.py | 15 ++++++++++++++- src/pairs/ir/mutator.py | 6 +++++- src/pairs/ir/ret.py | 13 +++++++++++++ src/pairs/ir/types.py | 5 ++++- src/pairs/transformations/__init__.py | 7 ++++++- 9 files changed, 66 insertions(+), 7 deletions(-) create mode 100644 src/pairs/ir/ret.py diff --git a/src/pairs/analysis/__init__.py b/src/pairs/analysis/__init__.py index ba2204c..7b200b2 100644 --- a/src/pairs/analysis/__init__.py +++ b/src/pairs/analysis/__init__.py @@ -2,7 +2,7 @@ import time from pairs.analysis.expressions import DetermineExpressionsTerminals, ResetInPlaceOperations, DetermineInPlaceOperations, ListDeclaredExpressions from pairs.analysis.blocks import DiscoverBlockVariants, DetermineExpressionsOwnership, DetermineParentBlocks from pairs.analysis.devices import FetchKernelReferences, MarkCandidateLoops -from pairs.analysis.modules import FetchModulesReferences +from pairs.analysis.modules import FetchModulesReferences, InferModulesReturnTypes class Analysis: @@ -51,3 +51,6 @@ class Analysis: def mark_candidate_loops(self): self.apply(MarkCandidateLoops()) + + def infer_modules_return_types(self): + self.apply(InferModulesReturnTypes()) \ No newline at end of file diff --git a/src/pairs/analysis/modules.py b/src/pairs/analysis/modules.py index d5311bc..fd7bd11 100644 --- a/src/pairs/analysis/modules.py +++ b/src/pairs/analysis/modules.py @@ -1,5 +1,17 @@ from pairs.ir.visitor import Visitor +class InferModulesReturnTypes(Visitor): + def __init__(self, ast=None): + super().__init__(ast) + + def visit_Module(self, ast_node): + self.current_module = ast_node + self.visit_children(ast_node) + + def visit_Return(self, ast_node): + self.current_module._return_type = ast_node.expr.type() + self.visit_children(ast_node) + class FetchModulesReferences(Visitor): def __init__(self, ast=None): diff --git a/src/pairs/code_gen/cgen.py b/src/pairs/code_gen/cgen.py index 9cfce8b..313101d 100644 --- a/src/pairs/code_gen/cgen.py +++ b/src/pairs/code_gen/cgen.py @@ -30,6 +30,7 @@ from pairs.ir.print import Print, PrintCode from pairs.ir.variables import Var, DeclareVariable, Deref from pairs.ir.parameters import Parameter from pairs.ir.vectors import Vector, VectorAccess, VectorOp, ZeroVector +from pairs.ir.ret import Return from pairs.sim.domain_partitioners import DomainPartitioners from pairs.sim.timestep import Timestep from pairs.code_gen.printer import Printer @@ -185,7 +186,8 @@ class CGen: print_params = ", ".join(module_params) ending = "{" if definition else ";" - self.print(f"void {module.name}({print_params}){ending}") + tkw = Types.c_keyword(self.sim, module.return_type) + self.print(f"{tkw} {module.name}({print_params}){ending}") def generate_module_decls(self): self.print("") @@ -1140,6 +1142,10 @@ class CGen: self.generate_statement(ast_node.block) self.print("}") + if isinstance(ast_node, Return): + expr = self.generate_expression(ast_node.expr) + self.print(f"return {expr};") + def generate_expression(self, ast_node, mem=False, index=None): if isinstance(ast_node, Array): return self.generate_object_reference(ast_node) diff --git a/src/pairs/ir/functions.py b/src/pairs/ir/functions.py index efaf3e6..ffb49d5 100644 --- a/src/pairs/ir/functions.py +++ b/src/pairs/ir/functions.py @@ -36,7 +36,7 @@ class Call_Int(Call): class Call_Void(Call): def __init__(self, sim, func_name, parameters): - super().__init__(sim, func_name, parameters, Types.Invalid) + super().__init__(sim, func_name, parameters, Types.Void) sim.add_statement(self) def __str__(self): diff --git a/src/pairs/ir/module.py b/src/pairs/ir/module.py index 04b4f85..44539db 100644 --- a/src/pairs/ir/module.py +++ b/src/pairs/ir/module.py @@ -5,12 +5,20 @@ from pairs.ir.features import FeatureProperty from pairs.ir.properties import Property, ContactProperty from pairs.ir.variables import Var from pairs.ir.parameters import Parameter +from pairs.ir.types import Types class Module(ASTNode): last_module = 0 - def __init__(self, sim, name=None, block=None, resizes_to_check={}, check_properties_resize=False, run_on_device=False, user_defined=False, interface=False): + def __init__(self, sim, + name=None, + block=None, + resizes_to_check={}, + check_properties_resize=False, + run_on_device=False, + user_defined=False, + interface=False): super().__init__(sim) self._id = Module.last_module self._name = name if name is not None else "module" + str(Module.last_module) @@ -27,6 +35,7 @@ class Module(ASTNode): self._run_on_device = run_on_device self._user_defined = user_defined self._interface = interface + self._return_type = Types.Void self._profile = False sim.add_module(self) Module.last_module += 1 @@ -58,6 +67,10 @@ class Module(ASTNode): def interface(self): return self._interface + @property + def return_type(self): + return self._return_type + def profile(self): self._profile = True self.sim.enable_profiler() diff --git a/src/pairs/ir/mutator.py b/src/pairs/ir/mutator.py index 669dbc4..3fb017f 100644 --- a/src/pairs/ir/mutator.py +++ b/src/pairs/ir/mutator.py @@ -54,7 +54,11 @@ class Mutator: ast_node._reduction_variable = self.mutate(ast_node._reduction_variable) return ast_node - + + def mutate_Return(self, ast_node): + ast_node.expr = self.mutate(ast_node.expr) + return ast_node + def mutate_Print(self, ast_node): ast_node.args = [self.mutate(arg) for arg in ast_node.args] return ast_node diff --git a/src/pairs/ir/ret.py b/src/pairs/ir/ret.py new file mode 100644 index 0000000..bb23504 --- /dev/null +++ b/src/pairs/ir/ret.py @@ -0,0 +1,13 @@ +from pairs.ir.ast_node import ASTNode + +class Return(ASTNode): + def __init__(self, sim, expr): + super().__init__(sim) + self.expr = expr + self.sim.add_statement(self) + + def __str__(self): + return f"Return<{self.expr}>" + + def children(self): + return [self.expr] \ No newline at end of file diff --git a/src/pairs/ir/types.py b/src/pairs/ir/types.py index f4ab048..ab27939 100644 --- a/src/pairs/ir/types.py +++ b/src/pairs/ir/types.py @@ -1,5 +1,6 @@ class Types: - Invalid = -1 + Invalid = -2 + Void = -1 Int32 = 0 Int64 = 1 UInt64 = 2 @@ -26,6 +27,7 @@ class Types: else 'long long int' if t == Types.Int64 else 'unsigned long long int' if t == Types.UInt64 else 'bool' if t == Types.Boolean + else 'void' if t == Types.Void else '<invalid type>' ) @@ -39,6 +41,7 @@ class Types: else 'long long int' if t == Types.Int64 else 'unsigned long long int' if t == Types.UInt64 else 'bool' if t == Types.Boolean + else 'void' if t == Types.Void else '<invalid type>' ) diff --git a/src/pairs/transformations/__init__.py b/src/pairs/transformations/__init__.py index 3a902ef..733d5c1 100644 --- a/src/pairs/transformations/__init__.py +++ b/src/pairs/transformations/__init__.py @@ -67,6 +67,7 @@ class Transformations: self._module_resizes = add_resize_logic.module_resizes self.analysis().fetch_modules_references() self.apply(DereferenceWriteVariables()) + self.analysis().infer_modules_return_types() self.apply(ReplaceModulesByCalls(), [self._module_resizes]) self.apply(MergeAdjacentBlocks()) @@ -108,5 +109,9 @@ class Transformations: self.add_expression_declarations() self.add_host_references_to_modules() self.add_device_references_to_modules() - self.add_instrumentation() + + # TODO: Place stop timers before the function returns + # or simply don't instrument modules that have a non-void return type + # to avoid having to deal with returns within conditional blocks + # self.add_instrumentation() -- GitLab