diff --git a/pystencils/data_types.py b/pystencils/data_types.py index 3f1a02c6833756985308282c49148fe0ecaa0153..55788fa711a928464526054d18651651b7098e3b 100644 --- a/pystencils/data_types.py +++ b/pystencils/data_types.py @@ -56,6 +56,7 @@ class cast_func(sp.Function): # -> thus a separate class boolean_cast_func is introduced if isinstance(args[0], Boolean): cls = boolean_cast_func + args = (args[0], create_type(args[1])) return sp.Function.__new__(cls, *args, **kwargs) @property diff --git a/pystencils_tests/test_cast_cast.py b/pystencils_tests/test_cast_cast.py new file mode 100644 index 0000000000000000000000000000000000000000..a19e08030dc9243479e172948309d7ed6b92655b --- /dev/null +++ b/pystencils_tests/test_cast_cast.py @@ -0,0 +1,63 @@ +# -*- coding: utf-8 -*- +# +# Copyright © 2019 Stephan Seitz <stephan.seitz@fau.de> +# +# Distributed under terms of the GPLv3 license. + +""" + +""" +import pystencils +from pystencils.data_types import cast_func +import numpy as np + + +def test_cast(): + float_type = pystencils.data_types.create_type('float32') + x, y = pystencils.fields('x,y: [1d]') + assignments = pystencils.AssignmentCollection({ + y.center(): cast_func(x.center(), float_type)}, {}) + kernel = pystencils.create_kernel(assignments) + print(pystencils.show_code(kernel)) + + kernel.compile() + + +def test_cast_cast(): + float_type = pystencils.data_types.create_type('float32') + x, y = pystencils.fields('x,y: [1d]') + assignments = pystencils.AssignmentCollection({ + y.center(): cast_func(cast_func(x.center(), float_type), float_type)}, {}) + kernel = pystencils.create_kernel(assignments) + print(pystencils.show_code(kernel)) + + kernel.compile() + + +def test_cast_with_string(): + x, y = pystencils.fields('x,y: [1d]') + assignments = pystencils.AssignmentCollection({ + y.center(): cast_func(x.center(), np.float32)}, {}) + kernel = pystencils.create_kernel(assignments) + print(pystencils.show_code(kernel)) + + kernel.compile() + + +def test_cast_cast_with_string(): + x, y = pystencils.fields('x,y: [1d]') + assignments = pystencils.AssignmentCollection({ + y.center(): cast_func(cast_func(x.center(), 'float32'), np.float32)}, {}) + kernel = pystencils.create_kernel(assignments) + print(pystencils.show_code(kernel)) + + +def main(): + test_cast() + test_cast_cast() + test_cast_with_string() + test_cast_cast_with_string() + + +if __name__ == '__main__': + main()