diff --git a/src/pairs/analysis/__init__.py b/src/pairs/analysis/__init__.py index ba2204c5c87f81f6f101a26d72ff12ec5067b294..7b200b201ef6b1126275c6656c98419b36e2d89a 100644 --- a/src/pairs/analysis/__init__.py +++ b/src/pairs/analysis/__init__.py @@ -2,7 +2,7 @@ import time from pairs.analysis.expressions import DetermineExpressionsTerminals, ResetInPlaceOperations, DetermineInPlaceOperations, ListDeclaredExpressions from pairs.analysis.blocks import DiscoverBlockVariants, DetermineExpressionsOwnership, DetermineParentBlocks from pairs.analysis.devices import FetchKernelReferences, MarkCandidateLoops -from pairs.analysis.modules import FetchModulesReferences +from pairs.analysis.modules import FetchModulesReferences, InferModulesReturnTypes class Analysis: @@ -51,3 +51,6 @@ class Analysis: def mark_candidate_loops(self): self.apply(MarkCandidateLoops()) + + def infer_modules_return_types(self): + self.apply(InferModulesReturnTypes()) \ No newline at end of file diff --git a/src/pairs/analysis/modules.py b/src/pairs/analysis/modules.py index d5311bc89780774ec5b647cb6110bb018b7b4743..fd7bd11393525e91f96cd0e3a8baaa451209a7d2 100644 --- a/src/pairs/analysis/modules.py +++ b/src/pairs/analysis/modules.py @@ -1,5 +1,17 @@ from pairs.ir.visitor import Visitor +class InferModulesReturnTypes(Visitor): + def __init__(self, ast=None): + super().__init__(ast) + + def visit_Module(self, ast_node): + self.current_module = ast_node + self.visit_children(ast_node) + + def visit_Return(self, ast_node): + self.current_module._return_type = ast_node.expr.type() + self.visit_children(ast_node) + class FetchModulesReferences(Visitor): def __init__(self, ast=None): diff --git a/src/pairs/code_gen/cgen.py b/src/pairs/code_gen/cgen.py index 9cfce8be8674082432f14cc03ae21bf35929ed87..313101da28f8fae613fdab6a7448355fc2381997 100644 --- a/src/pairs/code_gen/cgen.py +++ b/src/pairs/code_gen/cgen.py @@ -30,6 +30,7 @@ from pairs.ir.print import Print, PrintCode from pairs.ir.variables import Var, DeclareVariable, Deref from pairs.ir.parameters import Parameter from pairs.ir.vectors import Vector, VectorAccess, VectorOp, ZeroVector +from pairs.ir.ret import Return from pairs.sim.domain_partitioners import DomainPartitioners from pairs.sim.timestep import Timestep from pairs.code_gen.printer import Printer @@ -185,7 +186,8 @@ class CGen: print_params = ", ".join(module_params) ending = "{" if definition else ";" - self.print(f"void {module.name}({print_params}){ending}") + tkw = Types.c_keyword(self.sim, module.return_type) + self.print(f"{tkw} {module.name}({print_params}){ending}") def generate_module_decls(self): self.print("") @@ -1140,6 +1142,10 @@ class CGen: self.generate_statement(ast_node.block) self.print("}") + if isinstance(ast_node, Return): + expr = self.generate_expression(ast_node.expr) + self.print(f"return {expr};") + def generate_expression(self, ast_node, mem=False, index=None): if isinstance(ast_node, Array): return self.generate_object_reference(ast_node) diff --git a/src/pairs/ir/functions.py b/src/pairs/ir/functions.py index efaf3e6044c6f1c649c6b56a3745f5a0944ce938..ffb49d53717b9a15940c6c9039ddd1ee6c7baf69 100644 --- a/src/pairs/ir/functions.py +++ b/src/pairs/ir/functions.py @@ -36,7 +36,7 @@ class Call_Int(Call): class Call_Void(Call): def __init__(self, sim, func_name, parameters): - super().__init__(sim, func_name, parameters, Types.Invalid) + super().__init__(sim, func_name, parameters, Types.Void) sim.add_statement(self) def __str__(self): diff --git a/src/pairs/ir/module.py b/src/pairs/ir/module.py index 04b4f850d57ed8c0e1ad4b28a4a9f30cce8384da..44539dba0eab376716cfd4920818889b81e85690 100644 --- a/src/pairs/ir/module.py +++ b/src/pairs/ir/module.py @@ -5,12 +5,20 @@ from pairs.ir.features import FeatureProperty from pairs.ir.properties import Property, ContactProperty from pairs.ir.variables import Var from pairs.ir.parameters import Parameter +from pairs.ir.types import Types class Module(ASTNode): last_module = 0 - def __init__(self, sim, name=None, block=None, resizes_to_check={}, check_properties_resize=False, run_on_device=False, user_defined=False, interface=False): + def __init__(self, sim, + name=None, + block=None, + resizes_to_check={}, + check_properties_resize=False, + run_on_device=False, + user_defined=False, + interface=False): super().__init__(sim) self._id = Module.last_module self._name = name if name is not None else "module" + str(Module.last_module) @@ -27,6 +35,7 @@ class Module(ASTNode): self._run_on_device = run_on_device self._user_defined = user_defined self._interface = interface + self._return_type = Types.Void self._profile = False sim.add_module(self) Module.last_module += 1 @@ -58,6 +67,10 @@ class Module(ASTNode): def interface(self): return self._interface + @property + def return_type(self): + return self._return_type + def profile(self): self._profile = True self.sim.enable_profiler() diff --git a/src/pairs/ir/mutator.py b/src/pairs/ir/mutator.py index 669dbc44792e36f5af75c6be2c35b2347bd48427..3fb017f80cc517c309075658225a1e6a4d435e03 100644 --- a/src/pairs/ir/mutator.py +++ b/src/pairs/ir/mutator.py @@ -54,7 +54,11 @@ class Mutator: ast_node._reduction_variable = self.mutate(ast_node._reduction_variable) return ast_node - + + def mutate_Return(self, ast_node): + ast_node.expr = self.mutate(ast_node.expr) + return ast_node + def mutate_Print(self, ast_node): ast_node.args = [self.mutate(arg) for arg in ast_node.args] return ast_node diff --git a/src/pairs/ir/ret.py b/src/pairs/ir/ret.py new file mode 100644 index 0000000000000000000000000000000000000000..bb235044e87e5ca08f2587b2e406399e796865f8 --- /dev/null +++ b/src/pairs/ir/ret.py @@ -0,0 +1,13 @@ +from pairs.ir.ast_node import ASTNode + +class Return(ASTNode): + def __init__(self, sim, expr): + super().__init__(sim) + self.expr = expr + self.sim.add_statement(self) + + def __str__(self): + return f"Return<{self.expr}>" + + def children(self): + return [self.expr] \ No newline at end of file diff --git a/src/pairs/ir/types.py b/src/pairs/ir/types.py index f4ab048507e3ee8dcf38ded75d2170c34bfc1673..ab27939a83c2da5b43da4ccf46a672cbb1e7943b 100644 --- a/src/pairs/ir/types.py +++ b/src/pairs/ir/types.py @@ -1,5 +1,6 @@ class Types: - Invalid = -1 + Invalid = -2 + Void = -1 Int32 = 0 Int64 = 1 UInt64 = 2 @@ -26,6 +27,7 @@ class Types: else 'long long int' if t == Types.Int64 else 'unsigned long long int' if t == Types.UInt64 else 'bool' if t == Types.Boolean + else 'void' if t == Types.Void else '<invalid type>' ) @@ -39,6 +41,7 @@ class Types: else 'long long int' if t == Types.Int64 else 'unsigned long long int' if t == Types.UInt64 else 'bool' if t == Types.Boolean + else 'void' if t == Types.Void else '<invalid type>' ) diff --git a/src/pairs/transformations/__init__.py b/src/pairs/transformations/__init__.py index 3a902ef2348c041c3c1734bec3bcbb83e21d5fda..733d5c10fbaec621d37db6a7009e64493f752719 100644 --- a/src/pairs/transformations/__init__.py +++ b/src/pairs/transformations/__init__.py @@ -67,6 +67,7 @@ class Transformations: self._module_resizes = add_resize_logic.module_resizes self.analysis().fetch_modules_references() self.apply(DereferenceWriteVariables()) + self.analysis().infer_modules_return_types() self.apply(ReplaceModulesByCalls(), [self._module_resizes]) self.apply(MergeAdjacentBlocks()) @@ -108,5 +109,9 @@ class Transformations: self.add_expression_declarations() self.add_host_references_to_modules() self.add_device_references_to_modules() - self.add_instrumentation() + + # TODO: Place stop timers before the function returns + # or simply don't instrument modules that have a non-void return type + # to avoid having to deal with returns within conditional blocks + # self.add_instrumentation()