From 490adaae5fc1af045d8f7ed933d756c4203bede4 Mon Sep 17 00:00:00 2001 From: Frederik Hennig <frederik.hennig@fau.de> Date: Wed, 29 May 2024 21:47:50 +0200 Subject: [PATCH] refine printing of integer literals --- src/pystencils/types/types.py | 22 +++++++++++++++++++--- tests/nbackend/test_extensions.py | 6 +++--- 2 files changed, 22 insertions(+), 6 deletions(-) diff --git a/src/pystencils/types/types.py b/src/pystencils/types/types.py index e7f1ce208..2f0f2ff46 100644 --- a/src/pystencils/types/types.py +++ b/src/pystencils/types/types.py @@ -483,7 +483,20 @@ class PsIntegerType(PsScalarType, ABC): if not isinstance(value, np_dtype): raise PsTypeError(f"Given value {value} is not of required type {np_dtype}") unsigned_suffix = "" if self.signed else "u" - return f"{value}{unsigned_suffix}" + + match self.width: + case w if w < 32: + # Plain integer literals get at least type `int`, which is 32 bit in all relevant cases + # So we need to explicitly cast to smaller types + return f"(({self._c_type_without_const()}) {value}{unsigned_suffix})" + case 32: + # No suffix here - becomes `int`, which is 32 bit + return f"{value}{unsigned_suffix}" + case 64: + # LL suffix: `long long` is the only type guaranteed to be 64 bit wide + return f"{value}{unsigned_suffix}LL" + case _: + assert False, "unreachable code" def create_constant(self, value: Any) -> Any: np_type = self.NUMPY_TYPES[self._width] @@ -498,9 +511,12 @@ class PsIntegerType(PsScalarType, ABC): raise PsTypeError(f"Could not interpret {value} as {repr(self)}") - def c_string(self) -> str: + def _c_type_without_const(self) -> str: prefix = "" if self._signed else "u" - return f"{self._const_string()}{prefix}int{self._width}_t" + return f"{prefix}int{self._width}_t" + + def c_string(self) -> str: + return f"{self._const_string()}{self._c_type_without_const()}" def __repr__(self) -> str: return f"PsIntegerType( width={self.width}, signed={self.signed}, const={self.const} )" diff --git a/tests/nbackend/test_extensions.py b/tests/nbackend/test_extensions.py index 75726a351..16e610a55 100644 --- a/tests/nbackend/test_extensions.py +++ b/tests/nbackend/test_extensions.py @@ -54,6 +54,6 @@ def test_literals(): print(code) assert "const double x = C;" in code - assert "CELLS[0]" in code - assert "CELLS[1]" in code - assert "CELLS[2]" in code + assert "CELLS[0LL]" in code + assert "CELLS[1LL]" in code + assert "CELLS[2LL]" in code -- GitLab