Skip to content
Snippets Groups Projects

Support complex numbers

Merged Stephan Seitz requested to merge seitz/pystencils:support-complex-numbers into master
2 files
+ 36
5
Compare changes
  • Side-by-side
  • Inline
Files
2
+ 19
2
@@ -255,6 +255,8 @@ type_mapping = {
@@ -255,6 +255,8 @@ type_mapping = {
np.uint16: ('PyLong_AsUnsignedLong', 'uint16_t'),
np.uint16: ('PyLong_AsUnsignedLong', 'uint16_t'),
np.uint32: ('PyLong_AsUnsignedLong', 'uint32_t'),
np.uint32: ('PyLong_AsUnsignedLong', 'uint32_t'),
np.uint64: ('PyLong_AsUnsignedLong', 'uint64_t'),
np.uint64: ('PyLong_AsUnsignedLong', 'uint64_t'),
 
np.complex64: (('PyComplex_RealAsDouble', 'PyComplex_ImagAsDouble'), 'ComplexFloat'),
 
np.complex128: (('PyComplex_RealAsDouble', 'PyComplex_ImagAsDouble'), 'ComplexDouble'),
}
}
@@ -265,6 +267,13 @@ if( obj_{name} == NULL) {{ PyErr_SetString(PyExc_TypeError, "Keyword argument '
@@ -265,6 +267,13 @@ if( obj_{name} == NULL) {{ PyErr_SetString(PyExc_TypeError, "Keyword argument '
if( PyErr_Occurred() ) {{ return NULL; }}
if( PyErr_Occurred() ) {{ return NULL; }}
"""
"""
 
template_extract_complex = """
 
PyObject * obj_{name} = PyDict_GetItemString(kwargs, "{name}");
 
if( obj_{name} == NULL) {{ PyErr_SetString(PyExc_TypeError, "Keyword argument '{name}' missing"); return NULL; }};
 
{target_type} {name}{{ {extract_function_real}( obj_{name} ), {extract_function_imag}( obj_{name} ) }};
 
if( PyErr_Occurred() ) {{ return NULL; }}
 
"""
 
template_extract_array = """
template_extract_array = """
PyObject * obj_{name} = PyDict_GetItemString(kwargs, "{name}");
PyObject * obj_{name} = PyDict_GetItemString(kwargs, "{name}");
if( obj_{name} == NULL) {{ PyErr_SetString(PyExc_TypeError, "Keyword argument '{name}' missing"); return NULL; }};
if( obj_{name} == NULL) {{ PyErr_SetString(PyExc_TypeError, "Keyword argument '{name}' missing"); return NULL; }};
@@ -396,8 +405,16 @@ def create_function_boilerplate_code(parameter_info, name, insert_checks=True):
@@ -396,8 +405,16 @@ def create_function_boilerplate_code(parameter_info, name, insert_checks=True):
parameters.append("buffer_{name}.shape[{i}]".format(i=param.symbol.coordinate, name=param.field_name))
parameters.append("buffer_{name}.shape[{i}]".format(i=param.symbol.coordinate, name=param.field_name))
else:
else:
extract_function, target_type = type_mapping[param.symbol.dtype.numpy_dtype.type]
extract_function, target_type = type_mapping[param.symbol.dtype.numpy_dtype.type]
pre_call_code += template_extract_scalar.format(extract_function=extract_function, target_type=target_type,
if np.issubdtype(param.symbol.dtype.numpy_dtype, np.complexfloating):
name=param.symbol.name)
pre_call_code += template_extract_complex.format(extract_function_real=extract_function[0],
 
extract_function_imag=extract_function[1],
 
target_type=target_type,
 
name=param.symbol.name)
 
else:
 
pre_call_code += template_extract_scalar.format(extract_function=extract_function,
 
target_type=target_type,
 
name=param.symbol.name)
 
parameters.append(param.symbol.name)
parameters.append(param.symbol.name)
pre_call_code += equal_size_check(variable_sized_normal_fields)
pre_call_code += equal_size_check(variable_sized_normal_fields)
Loading