Skip to content
Snippets Groups Projects

fixed create_kernel parameter data_type="float" to procucde single precision

Merged Christoph Alt requested to merge ob28imeq/pystencils:fix_single_precision into master
Compare and
2 files
+ 51
0
Preferences
Compare changes
Files
2
@@ -960,6 +960,8 @@ def add_types(eqs, type_for_symbol, check_independence_condition, check_double_w
if isinstance(type_for_symbol, (str, type)) or not hasattr(type_for_symbol, '__getitem__'):
type_for_symbol = typing_from_sympy_inspection(eqs, type_for_symbol)
type_for_symbol = adjust_c_single_precision_type(type_for_symbol)
check = KernelConstraintsCheck(type_for_symbol, check_independence_condition,
check_double_write_condition=check_double_write_condition)
@@ -1397,3 +1399,16 @@ def implement_interpolations(ast_node: ast.Node,
ast_node.subs(substitutions)
return ast_node
def adjust_c_single_precision_type(type_for_symbol):
"""Replaces every occurrence of 'float' with 'single' to enforce the numpy single precision type."""
def single_factory():
return "single"
for symbol in type_for_symbol:
if type_for_symbol[symbol] == "float":
type_for_symbol[symbol] = single_factory()
if hasattr(type_for_symbol, "default_factory") and type_for_symbol.default_factory() == "float":
type_for_symbol.default_factory = single_factory
return type_for_symbol