Skip to content
Snippets Groups Projects
Commit 0460532f authored by Stephan Seitz's avatar Stephan Seitz
Browse files

Add call code for complex scalars as arguments

parent 5a5a878c
Branches
Tags
No related merge requests found
Pipeline #18831 passed
...@@ -255,6 +255,8 @@ type_mapping = { ...@@ -255,6 +255,8 @@ type_mapping = {
np.uint16: ('PyLong_AsUnsignedLong', 'uint16_t'), np.uint16: ('PyLong_AsUnsignedLong', 'uint16_t'),
np.uint32: ('PyLong_AsUnsignedLong', 'uint32_t'), np.uint32: ('PyLong_AsUnsignedLong', 'uint32_t'),
np.uint64: ('PyLong_AsUnsignedLong', 'uint64_t'), np.uint64: ('PyLong_AsUnsignedLong', 'uint64_t'),
np.complex64: (('PyComplex_RealAsDouble', 'PyComplex_ImagAsDouble'), 'ComplexFloat'),
np.complex128: (('PyComplex_RealAsDouble', 'PyComplex_ImagAsDouble'), 'ComplexDouble'),
} }
...@@ -265,6 +267,13 @@ if( obj_{name} == NULL) {{ PyErr_SetString(PyExc_TypeError, "Keyword argument ' ...@@ -265,6 +267,13 @@ if( obj_{name} == NULL) {{ PyErr_SetString(PyExc_TypeError, "Keyword argument '
if( PyErr_Occurred() ) {{ return NULL; }} if( PyErr_Occurred() ) {{ return NULL; }}
""" """
template_extract_complex = """
PyObject * obj_{name} = PyDict_GetItemString(kwargs, "{name}");
if( obj_{name} == NULL) {{ PyErr_SetString(PyExc_TypeError, "Keyword argument '{name}' missing"); return NULL; }};
{target_type} {name}{{ {extract_function_real}( obj_{name} ), {extract_function_imag}( obj_{name} ) }};
if( PyErr_Occurred() ) {{ return NULL; }}
"""
template_extract_array = """ template_extract_array = """
PyObject * obj_{name} = PyDict_GetItemString(kwargs, "{name}"); PyObject * obj_{name} = PyDict_GetItemString(kwargs, "{name}");
if( obj_{name} == NULL) {{ PyErr_SetString(PyExc_TypeError, "Keyword argument '{name}' missing"); return NULL; }}; if( obj_{name} == NULL) {{ PyErr_SetString(PyExc_TypeError, "Keyword argument '{name}' missing"); return NULL; }};
...@@ -396,8 +405,16 @@ def create_function_boilerplate_code(parameter_info, name, insert_checks=True): ...@@ -396,8 +405,16 @@ def create_function_boilerplate_code(parameter_info, name, insert_checks=True):
parameters.append("buffer_{name}.shape[{i}]".format(i=param.symbol.coordinate, name=param.field_name)) parameters.append("buffer_{name}.shape[{i}]".format(i=param.symbol.coordinate, name=param.field_name))
else: else:
extract_function, target_type = type_mapping[param.symbol.dtype.numpy_dtype.type] extract_function, target_type = type_mapping[param.symbol.dtype.numpy_dtype.type]
pre_call_code += template_extract_scalar.format(extract_function=extract_function, target_type=target_type, if np.issubdtype(param.symbol.dtype.numpy_dtype, np.complexfloating):
name=param.symbol.name) pre_call_code += template_extract_complex.format(extract_function_real=extract_function[0],
extract_function_imag=extract_function[1],
target_type=target_type,
name=param.symbol.name)
else:
pre_call_code += template_extract_scalar.format(extract_function=extract_function,
target_type=target_type,
name=param.symbol.name)
parameters.append(param.symbol.name) parameters.append(param.symbol.name)
pre_call_code += equal_size_check(variable_sized_normal_fields) pre_call_code += equal_size_check(variable_sized_normal_fields)
......
...@@ -106,7 +106,8 @@ def test_complex_numbers_64(assignment, target): ...@@ -106,7 +106,8 @@ def test_complex_numbers_64(assignment, target):
@pytest.mark.parametrize('dtype', (np.float32, np.float64)) @pytest.mark.parametrize('dtype', (np.float32, np.float64))
@pytest.mark.parametrize('target', ('cpu', 'gpu')) @pytest.mark.parametrize('target', ('cpu', 'gpu'))
def test_complex_execution(dtype, target): @pytest.mark.parametrize('with_complex_argument', ('with_complex_argument', False))
def test_complex_execution(dtype, target, with_complex_argument):
complex_dtype = f'complex{64 if dtype ==np.float32 else 128}' complex_dtype = f'complex{64 if dtype ==np.float32 else 128}'
x, y = pystencils.fields(f'x, y: {complex_dtype}[2d]') x, y = pystencils.fields(f'x, y: {complex_dtype}[2d]')
...@@ -114,8 +115,13 @@ def test_complex_execution(dtype, target): ...@@ -114,8 +115,13 @@ def test_complex_execution(dtype, target):
x_arr = np.zeros((20, 30), complex_dtype) x_arr = np.zeros((20, 30), complex_dtype)
y_arr = np.zeros((20, 30), complex_dtype) y_arr = np.zeros((20, 30), complex_dtype)
if with_complex_argument:
a = pystencils.TypedSymbol('a', create_type(complex_dtype))
else:
a = (2j+1)
assignments = AssignmentCollection({ assignments = AssignmentCollection({
y.center: x.center * (2j+1) y.center: x.center + a
}) })
if target == 'gpu': if target == 'gpu':
...@@ -125,4 +131,12 @@ def test_complex_execution(dtype, target): ...@@ -125,4 +131,12 @@ def test_complex_execution(dtype, target):
kernel = pystencils.create_kernel(assignments, target=target, data_type=dtype).compile() kernel = pystencils.create_kernel(assignments, target=target, data_type=dtype).compile()
kernel(x=x_arr, y=y_arr) if with_complex_argument:
kernel(x=x_arr, y=y_arr, a=2j+1)
else:
kernel(x=x_arr, y=y_arr)
if target == 'gpu':
y_arr = y_arr.get()
assert np.allclose(y_arr, 2j+1)
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Please register or to comment