From e5294b75bd34a6a8cc6db5b82cc08ab39ee4ada4 Mon Sep 17 00:00:00 2001
From: Stephan Seitz <stephan.seitz@fau.de>
Date: Thu, 11 Jul 2019 09:59:39 +0200
Subject: [PATCH] Allow `cast_func(x, 'float')` (previously only `cast_func(x,
 create_type('float'))`

---
 pystencils/data_types.py           |  1 +
 pystencils_tests/test_cast_cast.py | 63 ++++++++++++++++++++++++++++++
 2 files changed, 64 insertions(+)
 create mode 100644 pystencils_tests/test_cast_cast.py

diff --git a/pystencils/data_types.py b/pystencils/data_types.py
index 3f1a02c6..55788fa7 100644
--- a/pystencils/data_types.py
+++ b/pystencils/data_types.py
@@ -56,6 +56,7 @@ class cast_func(sp.Function):
         # -> thus a separate class boolean_cast_func is introduced
         if isinstance(args[0], Boolean):
             cls = boolean_cast_func
+        args = (args[0], create_type(args[1]))
         return sp.Function.__new__(cls, *args, **kwargs)
 
     @property
diff --git a/pystencils_tests/test_cast_cast.py b/pystencils_tests/test_cast_cast.py
new file mode 100644
index 00000000..a19e0803
--- /dev/null
+++ b/pystencils_tests/test_cast_cast.py
@@ -0,0 +1,63 @@
+# -*- coding: utf-8 -*-
+#
+# Copyright © 2019 Stephan Seitz <stephan.seitz@fau.de>
+#
+# Distributed under terms of the GPLv3 license.
+
+"""
+
+"""
+import pystencils
+from pystencils.data_types import cast_func
+import numpy as np
+
+
+def test_cast():
+    float_type = pystencils.data_types.create_type('float32')
+    x, y = pystencils.fields('x,y: [1d]')
+    assignments = pystencils.AssignmentCollection({
+        y.center(): cast_func(x.center(), float_type)}, {})
+    kernel = pystencils.create_kernel(assignments)
+    print(pystencils.show_code(kernel))
+
+    kernel.compile()
+
+
+def test_cast_cast():
+    float_type = pystencils.data_types.create_type('float32')
+    x, y = pystencils.fields('x,y: [1d]')
+    assignments = pystencils.AssignmentCollection({
+        y.center(): cast_func(cast_func(x.center(), float_type), float_type)}, {})
+    kernel = pystencils.create_kernel(assignments)
+    print(pystencils.show_code(kernel))
+
+    kernel.compile()
+
+
+def test_cast_with_string():
+    x, y = pystencils.fields('x,y: [1d]')
+    assignments = pystencils.AssignmentCollection({
+        y.center(): cast_func(x.center(), np.float32)}, {})
+    kernel = pystencils.create_kernel(assignments)
+    print(pystencils.show_code(kernel))
+
+    kernel.compile()
+
+
+def test_cast_cast_with_string():
+    x, y = pystencils.fields('x,y: [1d]')
+    assignments = pystencils.AssignmentCollection({
+        y.center(): cast_func(cast_func(x.center(), 'float32'), np.float32)}, {})
+    kernel = pystencils.create_kernel(assignments)
+    print(pystencils.show_code(kernel))
+
+
+def main():
+    test_cast()
+    test_cast_cast()
+    test_cast_with_string()
+    test_cast_cast_with_string()
+
+
+if __name__ == '__main__':
+    main()
-- 
GitLab