From 7c3a2e6cbb4de2dfe8185ba6bc11a3dbe3a3a26b Mon Sep 17 00:00:00 2001
From: Daniel Bauer <daniel.j.bauer@fau.de>
Date: Mon, 25 Mar 2024 16:08:30 +0100
Subject: [PATCH] add integer division

---
 src/pystencils/backend/ast/expressions.py           |  6 ++++++
 src/pystencils/backend/emission.py                  |  3 ++-
 src/pystencils/backend/kernelcreation/freeze.py     | 13 ++++++++-----
 .../backend/kernelcreation/typification.py          |  4 +++-
 4 files changed, 19 insertions(+), 7 deletions(-)

diff --git a/src/pystencils/backend/ast/expressions.py b/src/pystencils/backend/ast/expressions.py
index b3a83295f..31d415fbd 100644
--- a/src/pystencils/backend/ast/expressions.py
+++ b/src/pystencils/backend/ast/expressions.py
@@ -485,6 +485,12 @@ class PsDiv(PsBinOp):
     pass
 
 
+class PsIntDiv(PsBinOp):
+    @property
+    def python_operator(self) -> Callable[[Any, Any], Any] | None:
+        return operator.floordiv
+
+
 class PsLeftShift(PsBinOp):
     @property
     def python_operator(self) -> Callable[[Any, Any], Any] | None:
diff --git a/src/pystencils/backend/emission.py b/src/pystencils/backend/emission.py
index b1eec1f71..0a582b10d 100644
--- a/src/pystencils/backend/emission.py
+++ b/src/pystencils/backend/emission.py
@@ -24,6 +24,7 @@ from .ast.expressions import (
     PsSub,
     PsMul,
     PsDiv,
+    PsIntDiv,
     PsNeg,
     PsDeref,
     PsAddressOf,
@@ -321,7 +322,7 @@ class CAstPrinter:
                 return ("-", Ops.Sub)
             case PsMul():
                 return ("*", Ops.Mul)
-            case PsDiv():
+            case PsDiv() | PsIntDiv():
                 return ("/", Ops.Div)
             case PsLeftShift():
                 return ("<<", Ops.LeftShift)
diff --git a/src/pystencils/backend/kernelcreation/freeze.py b/src/pystencils/backend/kernelcreation/freeze.py
index bedb51743..f5775affd 100644
--- a/src/pystencils/backend/kernelcreation/freeze.py
+++ b/src/pystencils/backend/kernelcreation/freeze.py
@@ -27,6 +27,7 @@ from ..ast.expressions import (
     PsCall,
     PsCast,
     PsConstantExpr,
+    PsIntDiv,
     PsLeftShift,
     PsLookup,
     PsRightShift,
@@ -308,15 +309,17 @@ class FreezeExpressions:
 
         match func:
             case sp.Abs():
-                PsCall(PsMathFunction(MathFunctions.Abs), args)
+                return PsCall(PsMathFunction(MathFunctions.Abs), args)
             case sp.exp():
-                PsCall(PsMathFunction(MathFunctions.Exp), args)
+                return PsCall(PsMathFunction(MathFunctions.Exp), args)
             case sp.sin():
-                PsCall(PsMathFunction(MathFunctions.Sin), args)
+                return PsCall(PsMathFunction(MathFunctions.Sin), args)
             case sp.cos():
-                PsCall(PsMathFunction(MathFunctions.Cos), args)
+                return PsCall(PsMathFunction(MathFunctions.Cos), args)
             case sp.tan():
-                PsCall(PsMathFunction(MathFunctions.Tan), args)
+                return PsCall(PsMathFunction(MathFunctions.Tan), args)
+            case integer_functions.int_div():
+                return PsIntDiv(*args)
             case integer_functions.bit_shift_left():
                 return PsLeftShift(*args)
             case integer_functions.bit_shift_right():
diff --git a/src/pystencils/backend/kernelcreation/typification.py b/src/pystencils/backend/kernelcreation/typification.py
index 3908a2c99..3fbb9c1a8 100644
--- a/src/pystencils/backend/kernelcreation/typification.py
+++ b/src/pystencils/backend/kernelcreation/typification.py
@@ -31,6 +31,7 @@ from ..ast.expressions import (
     PsCall,
     PsCast,
     PsConstantExpr,
+    PsIntDiv,
     PsLeftShift,
     PsLookup,
     PsRightShift,
@@ -265,7 +266,8 @@ class Typifier:
 
             # integer operations
             case (
-                PsLeftShift(op1, op2)
+                PsIntDiv(op1, op2)
+                | PsLeftShift(op1, op2)
                 | PsRightShift(op1, op2)
                 | PsBitwiseAnd(op1, op2)
                 | PsBitwiseXor(op1, op2)
-- 
GitLab