From 490adaae5fc1af045d8f7ed933d756c4203bede4 Mon Sep 17 00:00:00 2001
From: Frederik Hennig <frederik.hennig@fau.de>
Date: Wed, 29 May 2024 21:47:50 +0200
Subject: [PATCH] refine printing of integer literals

---
 src/pystencils/types/types.py     | 22 +++++++++++++++++++---
 tests/nbackend/test_extensions.py |  6 +++---
 2 files changed, 22 insertions(+), 6 deletions(-)

diff --git a/src/pystencils/types/types.py b/src/pystencils/types/types.py
index e7f1ce208..2f0f2ff46 100644
--- a/src/pystencils/types/types.py
+++ b/src/pystencils/types/types.py
@@ -483,7 +483,20 @@ class PsIntegerType(PsScalarType, ABC):
         if not isinstance(value, np_dtype):
             raise PsTypeError(f"Given value {value} is not of required type {np_dtype}")
         unsigned_suffix = "" if self.signed else "u"
-        return f"{value}{unsigned_suffix}"
+
+        match self.width:
+            case w if w < 32:
+                #   Plain integer literals get at least type `int`, which is 32 bit in all relevant cases
+                #   So we need to explicitly cast to smaller types
+                return f"(({self._c_type_without_const()}) {value}{unsigned_suffix})"
+            case 32:
+                #   No suffix here - becomes `int`, which is 32 bit
+                return f"{value}{unsigned_suffix}"
+            case 64:
+                #   LL suffix: `long long` is the only type guaranteed to be 64 bit wide
+                return f"{value}{unsigned_suffix}LL"
+            case _:
+                assert False, "unreachable code"
 
     def create_constant(self, value: Any) -> Any:
         np_type = self.NUMPY_TYPES[self._width]
@@ -498,9 +511,12 @@ class PsIntegerType(PsScalarType, ABC):
 
         raise PsTypeError(f"Could not interpret {value} as {repr(self)}")
 
-    def c_string(self) -> str:
+    def _c_type_without_const(self) -> str:
         prefix = "" if self._signed else "u"
-        return f"{self._const_string()}{prefix}int{self._width}_t"
+        return f"{prefix}int{self._width}_t"
+
+    def c_string(self) -> str:
+        return f"{self._const_string()}{self._c_type_without_const()}"
 
     def __repr__(self) -> str:
         return f"PsIntegerType( width={self.width}, signed={self.signed}, const={self.const} )"
diff --git a/tests/nbackend/test_extensions.py b/tests/nbackend/test_extensions.py
index 75726a351..16e610a55 100644
--- a/tests/nbackend/test_extensions.py
+++ b/tests/nbackend/test_extensions.py
@@ -54,6 +54,6 @@ def test_literals():
     print(code)
 
     assert "const double x = C;" in code
-    assert "CELLS[0]" in code
-    assert "CELLS[1]" in code
-    assert "CELLS[2]" in code
+    assert "CELLS[0LL]" in code
+    assert "CELLS[1LL]" in code
+    assert "CELLS[2LL]" in code
-- 
GitLab