Skip to content
Snippets Groups Projects
Commit 3f5fe292 authored by Martin Bauer's avatar Martin Bauer
Browse files

Merge branch 'support-complex-numbers' into 'master'

Support complex numbers

See merge request pycodegen/pystencils!72
parents a834955b eda2f772
No related branches found
No related tags found
No related merge requests found
...@@ -5,7 +5,7 @@ from typing import Any, List, Optional, Sequence, Set, Union ...@@ -5,7 +5,7 @@ from typing import Any, List, Optional, Sequence, Set, Union
import sympy as sp import sympy as sp
from pystencils.data_types import TypedSymbol, cast_func, create_type from pystencils.data_types import TypedImaginaryUnit, TypedSymbol, cast_func, create_type
from pystencils.field import Field from pystencils.field import Field
from pystencils.kernelparameters import FieldPointerSymbol, FieldShapeSymbol, FieldStrideSymbol from pystencils.kernelparameters import FieldPointerSymbol, FieldShapeSymbol, FieldStrideSymbol
from pystencils.sympyextensions import fast_subs from pystencils.sympyextensions import fast_subs
...@@ -569,6 +569,7 @@ class SympyAssignment(Node): ...@@ -569,6 +569,7 @@ class SympyAssignment(Node):
if isinstance(symbol, Field.Access): if isinstance(symbol, Field.Access):
for i in range(len(symbol.offsets)): for i in range(len(symbol.offsets)):
loop_counters.add(LoopOverCoordinate.get_loop_counter_symbol(i)) loop_counters.add(LoopOverCoordinate.get_loop_counter_symbol(i))
result = {r for r in result if not isinstance(r, TypedImaginaryUnit)}
result.update(loop_counters) result.update(loop_counters)
result.update(self._lhs_symbol.atoms(sp.Symbol)) result.update(self._lhs_symbol.atoms(sp.Symbol))
return result return result
......
...@@ -80,8 +80,8 @@ def get_global_declarations(ast): ...@@ -80,8 +80,8 @@ def get_global_declarations(ast):
global_declarations = [] global_declarations = []
def visit_node(sub_ast): def visit_node(sub_ast):
nonlocal global_declarations
if hasattr(sub_ast, "required_global_declarations"): if hasattr(sub_ast, "required_global_declarations"):
nonlocal global_declarations
global_declarations += sub_ast.required_global_declarations global_declarations += sub_ast.required_global_declarations
if hasattr(sub_ast, "args"): if hasattr(sub_ast, "args"):
...@@ -103,7 +103,7 @@ def get_headers(ast_node: Node) -> Set[str]: ...@@ -103,7 +103,7 @@ def get_headers(ast_node: Node) -> Set[str]:
if hasattr(ast_node, 'headers'): if hasattr(ast_node, 'headers'):
headers.update(ast_node.headers) headers.update(ast_node.headers)
for a in ast_node.args: for a in ast_node.args:
if isinstance(a, Node): if isinstance(a, (sp.Expr, Node)):
headers.update(get_headers(a)) headers.update(get_headers(a))
for g in get_global_declarations(ast_node): for g in get_global_declarations(ast_node):
...@@ -234,7 +234,8 @@ class CBackend: ...@@ -234,7 +234,8 @@ class CBackend:
else: else:
prefix = '' prefix = ''
data_type = prefix + self._print(node.lhs.dtype).replace(' const', '') + " " data_type = prefix + self._print(node.lhs.dtype).replace(' const', '') + " "
return "%s%s = %s;" % (data_type, self.sympy_printer.doprint(node.lhs), return "%s%s = %s;" % (data_type,
self.sympy_printer.doprint(node.lhs),
self.sympy_printer.doprint(node.rhs)) self.sympy_printer.doprint(node.rhs))
else: else:
lhs_type = get_type_of_expression(node.lhs) lhs_type = get_type_of_expression(node.lhs)
...@@ -443,6 +444,27 @@ class CustomSympyPrinter(CCodePrinter): ...@@ -443,6 +444,27 @@ class CustomSympyPrinter(CCodePrinter):
_print_Max = C89CodePrinter._print_Max _print_Max = C89CodePrinter._print_Max
_print_Min = C89CodePrinter._print_Min _print_Min = C89CodePrinter._print_Min
def _print_re(self, expr):
    """Print the real part of ``expr``'s first argument as a ``real(...)`` call."""
    operand_code = self._print(expr.args[0])
    return "real({})".format(operand_code)
def _print_im(self, expr):
    """Print the imaginary part of ``expr``'s first argument as an ``imag(...)`` call."""
    operand_code = self._print(expr.args[0])
    return "imag({})".format(operand_code)
def _print_ImaginaryUnit(self, expr):
    """Print the untyped imaginary unit as a double-precision complex literal.

    ``expr`` is not inspected; the same literal is returned for every call.
    """
    imaginary_unit_literal = "complex<double>{0,1}"
    return imaginary_unit_literal
def _print_TypedImaginaryUnit(self, expr):
    """Print a typed imaginary unit as a complex literal of matching precision.

    Raises:
        NotImplementedError: if ``expr.dtype.numpy_dtype`` is neither
            ``complex64`` nor ``complex128``.
    """
    literal_by_dtype = (
        (np.complex64, "complex<float>{0,1}"),
        (np.complex128, "complex<double>{0,1}"),
    )
    # Checked in the same order as before: single precision first.
    for candidate, literal in literal_by_dtype:
        if expr.dtype.numpy_dtype == candidate:
            return literal
    raise NotImplementedError(
        "only complex64 and complex128 supported")
def _print_Complex(self, expr):
    """Print ``expr`` via ``self._typed_number`` using ``np.complex64`` as dtype."""
    target_dtype = np.complex64
    return self._typed_number(expr, target_dtype)
# noinspection PyPep8Naming # noinspection PyPep8Naming
class VectorizedCustomSympyPrinter(CustomSympyPrinter): class VectorizedCustomSympyPrinter(CustomSympyPrinter):
......
...@@ -255,6 +255,8 @@ type_mapping = { ...@@ -255,6 +255,8 @@ type_mapping = {
np.uint16: ('PyLong_AsUnsignedLong', 'uint16_t'), np.uint16: ('PyLong_AsUnsignedLong', 'uint16_t'),
np.uint32: ('PyLong_AsUnsignedLong', 'uint32_t'), np.uint32: ('PyLong_AsUnsignedLong', 'uint32_t'),
np.uint64: ('PyLong_AsUnsignedLong', 'uint64_t'), np.uint64: ('PyLong_AsUnsignedLong', 'uint64_t'),
np.complex64: (('PyComplex_RealAsDouble', 'PyComplex_ImagAsDouble'), 'ComplexFloat'),
np.complex128: (('PyComplex_RealAsDouble', 'PyComplex_ImagAsDouble'), 'ComplexDouble'),
} }
...@@ -265,6 +267,13 @@ if( obj_{name} == NULL) {{ PyErr_SetString(PyExc_TypeError, "Keyword argument ' ...@@ -265,6 +267,13 @@ if( obj_{name} == NULL) {{ PyErr_SetString(PyExc_TypeError, "Keyword argument '
if( PyErr_Occurred() ) {{ return NULL; }} if( PyErr_Occurred() ) {{ return NULL; }}
""" """
# C source template (filled in with str.format) that extracts one complex
# keyword argument: it looks up "{name}" in `kwargs`, sets a TypeError and
# returns NULL when the key is missing, builds a `{target_type}` value from
# the results of `{extract_function_real}` and `{extract_function_imag}`, and
# returns NULL if a Python error is pending afterwards.
# Doubled braces ({{ }}) are str.format escapes for literal C braces.
template_extract_complex = """
PyObject * obj_{name} = PyDict_GetItemString(kwargs, "{name}");
if( obj_{name} == NULL) {{ PyErr_SetString(PyExc_TypeError, "Keyword argument '{name}' missing"); return NULL; }};
{target_type} {name}{{ {extract_function_real}( obj_{name} ), {extract_function_imag}( obj_{name} ) }};
if( PyErr_Occurred() ) {{ return NULL; }}
"""
template_extract_array = """ template_extract_array = """
PyObject * obj_{name} = PyDict_GetItemString(kwargs, "{name}"); PyObject * obj_{name} = PyDict_GetItemString(kwargs, "{name}");
if( obj_{name} == NULL) {{ PyErr_SetString(PyExc_TypeError, "Keyword argument '{name}' missing"); return NULL; }}; if( obj_{name} == NULL) {{ PyErr_SetString(PyExc_TypeError, "Keyword argument '{name}' missing"); return NULL; }};
...@@ -358,7 +367,8 @@ def create_function_boilerplate_code(parameter_info, name, insert_checks=True): ...@@ -358,7 +367,8 @@ def create_function_boilerplate_code(parameter_info, name, insert_checks=True):
np_dtype = field.dtype.numpy_dtype np_dtype = field.dtype.numpy_dtype
item_size = np_dtype.itemsize item_size = np_dtype.itemsize
if np_dtype.isbuiltin and FieldType.is_generic(field): if (np_dtype.isbuiltin and FieldType.is_generic(field)
and not np.issubdtype(field.dtype.numpy_dtype, np.complexfloating)):
dtype_cond = "buffer_{name}.format[0] == '{format}'".format(name=field.name, dtype_cond = "buffer_{name}.format[0] == '{format}'".format(name=field.name,
format=field.dtype.numpy_dtype.char) format=field.dtype.numpy_dtype.char)
pre_call_code += template_check_array.format(cond=dtype_cond, what="data type", name=field.name, pre_call_code += template_check_array.format(cond=dtype_cond, what="data type", name=field.name,
...@@ -395,8 +405,16 @@ def create_function_boilerplate_code(parameter_info, name, insert_checks=True): ...@@ -395,8 +405,16 @@ def create_function_boilerplate_code(parameter_info, name, insert_checks=True):
parameters.append("buffer_{name}.shape[{i}]".format(i=param.symbol.coordinate, name=param.field_name)) parameters.append("buffer_{name}.shape[{i}]".format(i=param.symbol.coordinate, name=param.field_name))
else: else:
extract_function, target_type = type_mapping[param.symbol.dtype.numpy_dtype.type] extract_function, target_type = type_mapping[param.symbol.dtype.numpy_dtype.type]
pre_call_code += template_extract_scalar.format(extract_function=extract_function, target_type=target_type, if np.issubdtype(param.symbol.dtype.numpy_dtype, np.complexfloating):
name=param.symbol.name) pre_call_code += template_extract_complex.format(extract_function_real=extract_function[0],
extract_function_imag=extract_function[1],
target_type=target_type,
name=param.symbol.name)
else:
pre_call_code += template_extract_scalar.format(extract_function=extract_function,
target_type=target_type,
name=param.symbol.name)
parameters.append(param.symbol.name) parameters.append(param.symbol.name)
pre_call_code += equal_size_check(variable_sized_normal_fields) pre_call_code += equal_size_check(variable_sized_normal_fields)
......
...@@ -4,14 +4,14 @@ from functools import partial ...@@ -4,14 +4,14 @@ from functools import partial
from typing import Tuple from typing import Tuple
import numpy as np import numpy as np
import sympy as sp
import sympy.codegen.ast
from sympy.core.cache import cacheit
from sympy.logic.boolalg import Boolean
import pystencils import pystencils
import sympy as sp
import sympy.codegen.ast
from pystencils.cache import memorycache, memorycache_if_hashable from pystencils.cache import memorycache, memorycache_if_hashable
from pystencils.utils import all_equal from pystencils.utils import all_equal
from sympy.core.cache import cacheit
from sympy.logic.boolalg import Boolean
try: try:
import llvmlite.ir as ir import llvmlite.ir as ir
...@@ -250,6 +250,22 @@ class TypedSymbol(sp.Symbol): ...@@ -250,6 +250,22 @@ class TypedSymbol(sp.Symbol):
def reversed(self): def reversed(self):
return self return self
@property
def headers(self):
    """Return the C/C++ headers this symbol needs in generated code.

    The result is ``['"cuda_complex.hpp"']`` when either the symbol's own
    ``dtype`` or that dtype's ``base_type`` exposes a complex floating-point
    ``numpy_dtype``; otherwise the result is an empty list.

    Types that have no ``numpy_dtype`` / ``base_type`` attribute (or whose
    ``base_type`` is ``None``) are skipped instead of raising, so the lookup
    stays best-effort. The header is listed at most once, even when both the
    dtype and its base type are complex.
    """
    dtype = getattr(self, 'dtype', None)
    # Check the type itself first, then the type it points to / wraps.
    for candidate in (dtype, getattr(dtype, 'base_type', None)):
        numpy_dtype = getattr(candidate, 'numpy_dtype', None)
        if numpy_dtype is None:
            continue
        try:
            is_complex = np.issubdtype(numpy_dtype, np.complexfloating)
        except TypeError:
            # numpy could not interpret the value as a dtype: not complex.
            continue
        if is_complex:
            return ['"cuda_complex.hpp"']
    return []
def create_type(specification): def create_type(specification):
"""Creates a subclass of Type according to a string or an object of subclass Type. """Creates a subclass of Type according to a string or an object of subclass Type.
...@@ -420,16 +436,29 @@ def peel_off_type(dtype, type_to_peel_off): ...@@ -420,16 +436,29 @@ def peel_off_type(dtype, type_to_peel_off):
return dtype return dtype
def collate_types(types, forbid_collation_to_float=False): def collate_types(types,
forbid_collation_to_complex=False,
forbid_collation_to_float=False,
default_float_type='float64',
default_int_type='int64'):
""" """
Takes a sequence of types and returns their "common type" e.g. (float, double, float) -> double Takes a sequence of types and returns their "common type" e.g. (float, double, float) -> double
Uses the collation rules from numpy. Uses the collation rules from numpy.
""" """
if forbid_collation_to_complex:
types = [
t for t in types
if not np.issubdtype(t.numpy_dtype, np.complexfloating)
]
if not types:
return create_type(default_float_type)
if forbid_collation_to_float: if forbid_collation_to_float:
types = [t for t in types if not (hasattr(t, 'is_float') and t.is_float())] types = [
t for t in types if not np.issubdtype(t.numpy_dtype, np.floating)
]
if not types: if not types:
return create_type('int32') return create_type(default_int_type)
# Pointer arithmetic case i.e. pointer + integer is allowed # Pointer arithmetic case i.e. pointer + integer is allowed
if any(type(t) is PointerType for t in types): if any(type(t) is PointerType for t in types):
...@@ -484,6 +513,8 @@ def get_type_of_expression(expr, ...@@ -484,6 +513,8 @@ def get_type_of_expression(expr,
expr = sp.sympify(expr) expr = sp.sympify(expr)
if isinstance(expr, sp.Integer): if isinstance(expr, sp.Integer):
return create_type(default_int_type) return create_type(default_int_type)
elif expr.is_real is False:
return create_type((np.zeros((1,), default_float_type) * 1j).dtype)
elif isinstance(expr, sp.Rational) or isinstance(expr, sp.Float): elif isinstance(expr, sp.Rational) or isinstance(expr, sp.Float):
return create_type(default_float_type) return create_type(default_float_type)
elif isinstance(expr, ResolvedFieldAccess): elif isinstance(expr, ResolvedFieldAccess):
...@@ -510,7 +541,7 @@ def get_type_of_expression(expr, ...@@ -510,7 +541,7 @@ def get_type_of_expression(expr,
elif isinstance(expr, sp.Indexed): elif isinstance(expr, sp.Indexed):
typed_symbol = expr.base.label typed_symbol = expr.base.label
return typed_symbol.dtype.base_type return typed_symbol.dtype.base_type
elif isinstance(expr, sp.boolalg.Boolean) or isinstance(expr, sp.boolalg.BooleanFunction): elif isinstance(expr, (sp.boolalg.Boolean, sp.boolalg.BooleanFunction)):
# if any arg is of vector type return a vector boolean, else return a normal scalar boolean # if any arg is of vector type return a vector boolean, else return a normal scalar boolean
result = create_type("bool") result = create_type("bool")
vec_args = [get_type(a) for a in expr.args if isinstance(get_type(a), VectorType)] vec_args = [get_type(a) for a in expr.args if isinstance(get_type(a), VectorType)]
...@@ -523,7 +554,12 @@ def get_type_of_expression(expr, ...@@ -523,7 +554,12 @@ def get_type_of_expression(expr,
expr: sp.Expr expr: sp.Expr
if expr.args: if expr.args:
types = tuple(get_type(a) for a in expr.args) types = tuple(get_type(a) for a in expr.args)
return collate_types(types) return collate_types(
types,
forbid_collation_to_complex=expr.is_real is True,
forbid_collation_to_float=expr.is_integer is True,
default_float_type=default_float_type,
default_int_type=default_int_type)
else: else:
if expr.is_integer: if expr.is_integer:
return create_type(default_int_type) return create_type(default_int_type)
...@@ -550,6 +586,10 @@ class BasicType(Type): ...@@ -550,6 +586,10 @@ class BasicType(Type):
return 'double' return 'double'
elif name == 'float32': elif name == 'float32':
return 'float' return 'float'
elif name == 'complex64':
return 'ComplexFloat'
elif name == 'complex128':
return 'ComplexDouble'
elif name.startswith('int'): elif name.startswith('int'):
width = int(name[len("int"):]) width = int(name[len("int"):])
return "int%d_t" % (width,) return "int%d_t" % (width,)
...@@ -761,3 +801,23 @@ class StructType: ...@@ -761,3 +801,23 @@ class StructType:
def __hash__(self): def __hash__(self):
return hash((self.numpy_dtype, self.const)) return hash((self.numpy_dtype, self.const))
# Typed symbol with the fixed name "_i" and the assumption imaginary=True,
# carrying an explicit dtype; it stands in for the imaginary unit.
class TypedImaginaryUnit(TypedSymbol):
# Construction is routed through the cacheit-wrapped constructor defined at
# the bottom of the class, so equal arguments can reuse a cached instance.
def __new__(cls, *args, **kwds):
obj = TypedImaginaryUnit.__xnew_cached_(cls, *args, **kwds)
return obj
# Actual constructor: forwards to the parent's __xnew__ with the hard-coded
# name "_i", the caller-supplied dtype and imaginary=True; any further
# positional/keyword arguments are passed through unchanged.
def __new_stage2__(cls, dtype, *args, **kwargs):
obj = super(TypedImaginaryUnit, cls).__xnew__(cls,
"_i",
dtype,
imaginary=True,
*args,
**kwargs)
return obj
# Header required by generated code that uses this symbol.
# NOTE(review): presumably shadows a `headers` attribute of TypedSymbol - confirm.
headers = ['"cuda_complex.hpp"']
# Uncached and cached entry points for the constructor above. Note that
# `__xnew_cached_` ends in a single underscore and is therefore subject to
# Python's private-name mangling inside this class body.
__xnew__ = staticmethod(__new_stage2__)
__xnew_cached_ = staticmethod(cacheit(__new_stage2__))
// An implementation of C++ std::complex for use on CUDA devices.
// Written by John C. Travers <jtravs@gmail.com> (2012)
//
// Missing:
// - long double support (not supported on CUDA)
// - some integral pow functions (due to lack of C++11 support on CUDA)
//
// Heavily derived from the LLVM libcpp project (svn revision 147853).
// Based on libcxx/include/complex.
// The git history contains the complete change history from the original.
// The modifications are licensed as per the original LLVM license below.
//
// -*- C++ -*-
//===--------------------------- complex ----------------------------------===//
//
// The LLVM Compiler Infrastructure
//
// This file is dual licensed under the MIT and the University of Illinois Open
// Source Licenses. See LICENSE.TXT for details.
//
//===----------------------------------------------------------------------===//
extern "C++" {
#ifndef CUDA_COMPLEX_HPP
#define CUDA_COMPLEX_HPP
// CUDA_CALLABLE_MEMBER is placed on every function in this header. When
// __CUDACC__ is defined it expands to `__host__ __device__`; otherwise it
// expands to nothing, so the same declarations also build as plain C++.
#ifdef __CUDACC__
#define CUDA_CALLABLE_MEMBER __host__ __device__
#else
#define CUDA_CALLABLE_MEMBER
#endif
/*
complex synopsis
template<class T>
class complex
{
public:
typedef T value_type;
complex(const T& re = T(), const T& im = T());
complex(const complex&);
template<class X> complex(const complex<X>&);
T real() const;
T imag() const;
void real(T);
void imag(T);
complex<T>& operator= (const T&);
complex<T>& operator+=(const T&);
complex<T>& operator-=(const T&);
complex<T>& operator*=(const T&);
complex<T>& operator/=(const T&);
complex& operator=(const complex&);
template<class X> complex<T>& operator= (const complex<X>&);
template<class X> complex<T>& operator+=(const complex<X>&);
template<class X> complex<T>& operator-=(const complex<X>&);
template<class X> complex<T>& operator*=(const complex<X>&);
template<class X> complex<T>& operator/=(const complex<X>&);
};
template<>
class complex<float>
{
public:
typedef float value_type;
constexpr complex(float re = 0.0f, float im = 0.0f);
explicit constexpr complex(const complex<double>&);
constexpr float real() const;
void real(float);
constexpr float imag() const;
void imag(float);
complex<float>& operator= (float);
complex<float>& operator+=(float);
complex<float>& operator-=(float);
complex<float>& operator*=(float);
complex<float>& operator/=(float);
complex<float>& operator=(const complex<float>&);
template<class X> complex<float>& operator= (const complex<X>&);
template<class X> complex<float>& operator+=(const complex<X>&);
template<class X> complex<float>& operator-=(const complex<X>&);
template<class X> complex<float>& operator*=(const complex<X>&);
template<class X> complex<float>& operator/=(const complex<X>&);
};
template<>
class complex<double>
{
public:
typedef double value_type;
constexpr complex(double re = 0.0, double im = 0.0);
constexpr complex(const complex<float>&);
constexpr double real() const;
void real(double);
constexpr double imag() const;
void imag(double);
complex<double>& operator= (double);
complex<double>& operator+=(double);
complex<double>& operator-=(double);
complex<double>& operator*=(double);
complex<double>& operator/=(double);
complex<double>& operator=(const complex<double>&);
template<class X> complex<double>& operator= (const complex<X>&);
template<class X> complex<double>& operator+=(const complex<X>&);
template<class X> complex<double>& operator-=(const complex<X>&);
template<class X> complex<double>& operator*=(const complex<X>&);
template<class X> complex<double>& operator/=(const complex<X>&);
};
// 26.3.6 operators:
template<class T> complex<T> operator+(const complex<T>&, const complex<T>&);
template<class T> complex<T> operator+(const complex<T>&, const T&);
template<class T> complex<T> operator+(const T&, const complex<T>&);
template<class T> complex<T> operator-(const complex<T>&, const complex<T>&);
template<class T> complex<T> operator-(const complex<T>&, const T&);
template<class T> complex<T> operator-(const T&, const complex<T>&);
template<class T> complex<T> operator*(const complex<T>&, const complex<T>&);
template<class T> complex<T> operator*(const complex<T>&, const T&);
template<class T> complex<T> operator*(const T&, const complex<T>&);
template<class T> complex<T> operator/(const complex<T>&, const complex<T>&);
template<class T> complex<T> operator/(const complex<T>&, const T&);
template<class T> complex<T> operator/(const T&, const complex<T>&);
template<class T> complex<T> operator+(const complex<T>&);
template<class T> complex<T> operator-(const complex<T>&);
template<class T> bool operator==(const complex<T>&, const complex<T>&);
template<class T> bool operator==(const complex<T>&, const T&);
template<class T> bool operator==(const T&, const complex<T>&);
template<class T> bool operator!=(const complex<T>&, const complex<T>&);
template<class T> bool operator!=(const complex<T>&, const T&);
template<class T> bool operator!=(const T&, const complex<T>&);
template<class T, class charT, class traits>
basic_istream<charT, traits>&
operator>>(basic_istream<charT, traits>&, complex<T>&);
template<class T, class charT, class traits>
basic_ostream<charT, traits>&
operator<<(basic_ostream<charT, traits>&, const complex<T>&);
// 26.3.7 values:
template<class T> T real(const complex<T>&);
double real(double);
template<Integral T> double real(T);
float real(float);
template<class T> T imag(const complex<T>&);
double imag(double);
template<Integral T> double imag(T);
float imag(float);
template<class T> T abs(const complex<T>&);
template<class T> T arg(const complex<T>&);
double arg(double);
template<Integral T> double arg(T);
float arg(float);
template<class T> T norm(const complex<T>&);
double norm(double);
template<Integral T> double norm(T);
float norm(float);
template<class T> complex<T> conj(const complex<T>&);
complex<double> conj(double);
template<Integral T> complex<double> conj(T);
complex<float> conj(float);
template<class T> complex<T> proj(const complex<T>&);
complex<double> proj(double);
template<Integral T> complex<double> proj(T);
complex<float> proj(float);
template<class T> complex<T> polar(const T&, const T& = 0);
// 26.3.8 transcendentals:
template<class T> complex<T> acos(const complex<T>&);
template<class T> complex<T> asin(const complex<T>&);
template<class T> complex<T> atan(const complex<T>&);
template<class T> complex<T> acosh(const complex<T>&);
template<class T> complex<T> asinh(const complex<T>&);
template<class T> complex<T> atanh(const complex<T>&);
template<class T> complex<T> cos (const complex<T>&);
template<class T> complex<T> cosh (const complex<T>&);
template<class T> complex<T> exp (const complex<T>&);
template<class T> complex<T> log (const complex<T>&);
template<class T> complex<T> log10(const complex<T>&);
template<class T> complex<T> pow(const complex<T>&, const T&);
template<class T> complex<T> pow(const complex<T>&, const complex<T>&);
template<class T> complex<T> pow(const T&, const complex<T>&);
template<class T> complex<T> sin (const complex<T>&);
template<class T> complex<T> sinh (const complex<T>&);
template<class T> complex<T> sqrt (const complex<T>&);
template<class T> complex<T> tan (const complex<T>&);
template<class T> complex<T> tanh (const complex<T>&);
template<class T, class charT, class traits>
basic_istream<charT, traits>&
operator>>(basic_istream<charT, traits>& is, complex<T>& x);
template<class T, class charT, class traits>
basic_ostream<charT, traits>&
operator<<(basic_ostream<charT, traits>& o, const complex<T>& x);
*/
#include <math.h>
#include <sstream>
// Forward declarations: the class template itself, plus the binary * and /
// operators that the member operator*= / operator/= bodies call before the
// operators are defined further down in this header.
template <class _Tp> class complex;
template <class _Tp>
complex<_Tp> operator*(const complex<_Tp> &__z, const complex<_Tp> &__w);
template <class _Tp>
complex<_Tp> operator/(const complex<_Tp> &__x, const complex<_Tp> &__y);
// Primary template: a complex number whose real and imaginary parts are of
// type _Tp. All members carry CUDA_CALLABLE_MEMBER so they can be used from
// both host and device code when compiled as CUDA.
template <class _Tp> class complex {
public:
  typedef _Tp value_type;

private:
  value_type __re_; // real part
  value_type __im_; // imaginary part

public:
  // Constructs from a real and an imaginary part; both default to
  // value_type().
  CUDA_CALLABLE_MEMBER
  complex(const value_type &__re = value_type(),
          const value_type &__im = value_type())
      : __re_(__re), __im_(__im) {}

  // Converting constructor from a complex number of another component type.
  template <class _Xp>
  CUDA_CALLABLE_MEMBER complex(const complex<_Xp> &__c)
      : __re_(__c.real()), __im_(__c.imag()) {}

  // Component accessors (getter and setter overloads).
  CUDA_CALLABLE_MEMBER value_type real() const { return __re_; }
  CUDA_CALLABLE_MEMBER value_type imag() const { return __im_; }
  CUDA_CALLABLE_MEMBER void real(value_type __re) { __re_ = __re; }
  CUDA_CALLABLE_MEMBER void imag(value_type __im) { __im_ = __im; }

  // Assignment from a real scalar resets the imaginary part to value_type().
  CUDA_CALLABLE_MEMBER complex &operator=(const value_type &__re) {
    __re_ = __re;
    __im_ = value_type();
    return *this;
  }

  // Compound assignment with a real scalar: += and -= only touch the real
  // part, *= and /= scale both parts.
  CUDA_CALLABLE_MEMBER complex &operator+=(const value_type &__re) {
    __re_ += __re;
    return *this;
  }
  CUDA_CALLABLE_MEMBER complex &operator-=(const value_type &__re) {
    __re_ -= __re;
    return *this;
  }
  CUDA_CALLABLE_MEMBER complex &operator*=(const value_type &__re) {
    __re_ *= __re;
    __im_ *= __re;
    return *this;
  }
  CUDA_CALLABLE_MEMBER complex &operator/=(const value_type &__re) {
    __re_ /= __re;
    __im_ /= __re;
    return *this;
  }

  // Assignment / compound assignment with a complex number of any component
  // type; the other operand's parts are converted to value_type.
  template <class _Xp>
  CUDA_CALLABLE_MEMBER complex &operator=(const complex<_Xp> &__c) {
    __re_ = __c.real();
    __im_ = __c.imag();
    return *this;
  }
  template <class _Xp>
  CUDA_CALLABLE_MEMBER complex &operator+=(const complex<_Xp> &__c) {
    __re_ += __c.real();
    __im_ += __c.imag();
    return *this;
  }
  template <class _Xp>
  CUDA_CALLABLE_MEMBER complex &operator-=(const complex<_Xp> &__c) {
    __re_ -= __c.real();
    __im_ -= __c.imag();
    return *this;
  }
  template <class _Xp>
  CUDA_CALLABLE_MEMBER complex &operator*=(const complex<_Xp> &__c) {
    // The binary operator* only accepts two operands of the same
    // complex<T>, so convert __c to this type first; otherwise template
    // argument deduction fails whenever _Xp differs from value_type.
    *this = *this * complex(static_cast<value_type>(__c.real()),
                            static_cast<value_type>(__c.imag()));
    return *this;
  }
  template <class _Xp>
  CUDA_CALLABLE_MEMBER complex &operator/=(const complex<_Xp> &__c) {
    // Same conversion as in operator*=, for the same reason.
    *this = *this / complex(static_cast<value_type>(__c.real()),
                            static_cast<value_type>(__c.imag()));
    return *this;
  }
};
// complex<double> is declared here because the float specialization below
// names it in its converting constructor.
template <> class complex<double>;

// Explicit specialization for single precision. It mirrors the primary
// template but takes scalars by value and only converts from
// complex<double> explicitly.
template <> class complex<float> {
  float __re_; // real part
  float __im_; // imaginary part

public:
  typedef float value_type;

  // Constructs from a real and an imaginary part; both default to 0.0f.
  /*constexpr*/ CUDA_CALLABLE_MEMBER complex(float __re = 0.0f,
                                             float __im = 0.0f)
      : __re_(__re), __im_(__im) {}

  // Explicit conversion from double precision; defined out of class once
  // complex<double> is complete. The declaration carries
  // CUDA_CALLABLE_MEMBER so that it has the same execution-space qualifiers
  // as that out-of-class definition.
  explicit /*constexpr*/ CUDA_CALLABLE_MEMBER
  complex(const complex<double> &__c);

  // Component accessors (getter and setter overloads).
  /*constexpr*/ CUDA_CALLABLE_MEMBER float real() const { return __re_; }
  /*constexpr*/ CUDA_CALLABLE_MEMBER float imag() const { return __im_; }
  CUDA_CALLABLE_MEMBER void real(value_type __re) { __re_ = __re; }
  CUDA_CALLABLE_MEMBER void imag(value_type __im) { __im_ = __im; }

  // Assignment from a real scalar resets the imaginary part to zero.
  CUDA_CALLABLE_MEMBER complex &operator=(float __re) {
    __re_ = __re;
    __im_ = value_type();
    return *this;
  }

  // Compound assignment with a real scalar: += and -= only touch the real
  // part, *= and /= scale both parts.
  CUDA_CALLABLE_MEMBER complex &operator+=(float __re) {
    __re_ += __re;
    return *this;
  }
  CUDA_CALLABLE_MEMBER complex &operator-=(float __re) {
    __re_ -= __re;
    return *this;
  }
  CUDA_CALLABLE_MEMBER complex &operator*=(float __re) {
    __re_ *= __re;
    __im_ *= __re;
    return *this;
  }
  CUDA_CALLABLE_MEMBER complex &operator/=(float __re) {
    __re_ /= __re;
    __im_ /= __re;
    return *this;
  }

  // Assignment / compound assignment with a complex number of any component
  // type; the other operand's parts are converted to float.
  template <class _Xp>
  CUDA_CALLABLE_MEMBER complex &operator=(const complex<_Xp> &__c) {
    __re_ = __c.real();
    __im_ = __c.imag();
    return *this;
  }
  template <class _Xp>
  CUDA_CALLABLE_MEMBER complex &operator+=(const complex<_Xp> &__c) {
    __re_ += __c.real();
    __im_ += __c.imag();
    return *this;
  }
  template <class _Xp>
  CUDA_CALLABLE_MEMBER complex &operator-=(const complex<_Xp> &__c) {
    __re_ -= __c.real();
    __im_ -= __c.imag();
    return *this;
  }
  template <class _Xp>
  CUDA_CALLABLE_MEMBER complex &operator*=(const complex<_Xp> &__c) {
    // The binary operator* only accepts two operands of the same
    // complex<T>, so convert __c to complex<float> first; otherwise
    // template argument deduction fails whenever _Xp is not float.
    *this = *this * complex(static_cast<value_type>(__c.real()),
                            static_cast<value_type>(__c.imag()));
    return *this;
  }
  template <class _Xp>
  CUDA_CALLABLE_MEMBER complex &operator/=(const complex<_Xp> &__c) {
    // Same conversion as in operator*=, for the same reason.
    *this = *this / complex(static_cast<value_type>(__c.real()),
                            static_cast<value_type>(__c.imag()));
    return *this;
  }
};
// Explicit specialization for double precision. It mirrors the primary
// template but takes scalars by value and converts implicitly from
// complex<float>.
template <> class complex<double> {
  double __re_; // real part
  double __im_; // imaginary part

public:
  typedef double value_type;

  // Constructs from a real and an imaginary part; both default to 0.0.
  /*constexpr*/ CUDA_CALLABLE_MEMBER complex(double __re = 0.0,
                                             double __im = 0.0)
      : __re_(__re), __im_(__im) {}

  // Implicit widening conversion from single precision; defined out of
  // class. The declaration carries CUDA_CALLABLE_MEMBER so that it has the
  // same execution-space qualifiers as that out-of-class definition.
  /*constexpr*/ CUDA_CALLABLE_MEMBER complex(const complex<float> &__c);

  // Component accessors (getter and setter overloads).
  /*constexpr*/ CUDA_CALLABLE_MEMBER double real() const { return __re_; }
  /*constexpr*/ CUDA_CALLABLE_MEMBER double imag() const { return __im_; }
  CUDA_CALLABLE_MEMBER void real(value_type __re) { __re_ = __re; }
  CUDA_CALLABLE_MEMBER void imag(value_type __im) { __im_ = __im; }

  // Assignment from a real scalar resets the imaginary part to zero.
  CUDA_CALLABLE_MEMBER complex &operator=(double __re) {
    __re_ = __re;
    __im_ = value_type();
    return *this;
  }

  // Compound assignment with a real scalar: += and -= only touch the real
  // part, *= and /= scale both parts.
  CUDA_CALLABLE_MEMBER complex &operator+=(double __re) {
    __re_ += __re;
    return *this;
  }
  CUDA_CALLABLE_MEMBER complex &operator-=(double __re) {
    __re_ -= __re;
    return *this;
  }
  CUDA_CALLABLE_MEMBER complex &operator*=(double __re) {
    __re_ *= __re;
    __im_ *= __re;
    return *this;
  }
  CUDA_CALLABLE_MEMBER complex &operator/=(double __re) {
    __re_ /= __re;
    __im_ /= __re;
    return *this;
  }

  // Assignment / compound assignment with a complex number of any component
  // type; the other operand's parts are converted to double.
  template <class _Xp>
  CUDA_CALLABLE_MEMBER complex &operator=(const complex<_Xp> &__c) {
    __re_ = __c.real();
    __im_ = __c.imag();
    return *this;
  }
  template <class _Xp>
  CUDA_CALLABLE_MEMBER complex &operator+=(const complex<_Xp> &__c) {
    __re_ += __c.real();
    __im_ += __c.imag();
    return *this;
  }
  template <class _Xp>
  CUDA_CALLABLE_MEMBER complex &operator-=(const complex<_Xp> &__c) {
    __re_ -= __c.real();
    __im_ -= __c.imag();
    return *this;
  }
  template <class _Xp>
  CUDA_CALLABLE_MEMBER complex &operator*=(const complex<_Xp> &__c) {
    // The binary operator* only accepts two operands of the same
    // complex<T>, so convert __c to complex<double> first; otherwise
    // template argument deduction fails whenever _Xp is not double.
    *this = *this * complex(static_cast<value_type>(__c.real()),
                            static_cast<value_type>(__c.imag()));
    return *this;
  }
  template <class _Xp>
  CUDA_CALLABLE_MEMBER complex &operator/=(const complex<_Xp> &__c) {
    // Same conversion as in operator*=, for the same reason.
    *this = *this / complex(static_cast<value_type>(__c.real()),
                            static_cast<value_type>(__c.imag()));
    return *this;
  }
};
// Out-of-class definition of complex<float>'s constructor from
// complex<double>: copies both parts, converting them to float.
// (constexpr in the original libcxx code; not used here.)
inline CUDA_CALLABLE_MEMBER complex<float>::complex(const complex<double> &__c)
: __re_(__c.real()), __im_(__c.imag()) {}
// Out-of-class definition of complex<double>'s constructor from
// complex<float>: copies both parts, converting them to double.
// (constexpr in the original libcxx code; not used here.)
inline CUDA_CALLABLE_MEMBER complex<double>::complex(const complex<float> &__c)
: __re_(__c.real()), __im_(__c.imag()) {}
// 26.3.6 operators:
// complex + complex: copy the left operand and accumulate the right one.
template <class _Tp>
inline CUDA_CALLABLE_MEMBER complex<_Tp> operator+(const complex<_Tp> &__x,
                                                   const complex<_Tp> &__y) {
  complex<_Tp> __sum(__x);
  return __sum += __y;
}

// complex + scalar: the scalar is added to the real part only.
template <class _Tp>
inline CUDA_CALLABLE_MEMBER complex<_Tp> operator+(const complex<_Tp> &__x,
                                                   const _Tp &__y) {
  complex<_Tp> __sum(__x);
  return __sum += __y;
}

// scalar + complex: starts from the complex operand and adds the scalar.
template <class _Tp>
inline CUDA_CALLABLE_MEMBER complex<_Tp> operator+(const _Tp &__x,
                                                   const complex<_Tp> &__y) {
  complex<_Tp> __sum(__y);
  return __sum += __x;
}
// complex - complex: copy the left operand and subtract the right one.
template <class _Tp>
inline CUDA_CALLABLE_MEMBER complex<_Tp> operator-(const complex<_Tp> &__x,
                                                   const complex<_Tp> &__y) {
  complex<_Tp> __diff(__x);
  return __diff -= __y;
}

// complex - scalar: the scalar is subtracted from the real part only.
template <class _Tp>
inline CUDA_CALLABLE_MEMBER complex<_Tp> operator-(const complex<_Tp> &__x,
                                                   const _Tp &__y) {
  complex<_Tp> __diff(__x);
  return __diff -= __y;
}

// scalar - complex: negate the complex operand, then add the scalar.
template <class _Tp>
inline CUDA_CALLABLE_MEMBER complex<_Tp> operator-(const _Tp &__x,
                                                   const complex<_Tp> &__y) {
  complex<_Tp> __diff(-__y);
  return __diff += __x;
}
// Complex multiplication: (a + ib) * (c + id) = (ac - bd) + i(ad + bc).
// When both parts of the plain product come out as NaN, the operands are
// inspected and, where an infinity is involved, the product is recomputed so
// that the result is an infinity instead of NaN + iNaN.
template <class _Tp>
CUDA_CALLABLE_MEMBER complex<_Tp> operator*(const complex<_Tp> &__z,
const complex<_Tp> &__w) {
_Tp __a = __z.real();
_Tp __b = __z.imag();
_Tp __c = __w.real();
_Tp __d = __w.imag();
_Tp __ac = __a * __c;
_Tp __bd = __b * __d;
_Tp __ad = __a * __d;
_Tp __bc = __b * __c;
_Tp __x = __ac - __bd;
_Tp __y = __ad + __bc;
if (isnan(__x) && isnan(__y)) {
bool __recalc = false;
// Left operand has an infinite part: reduce its parts to signed 1 (infinite)
// or signed 0 (otherwise) and replace NaN parts of the right operand by
// signed zeros.
if (isinf(__a) || isinf(__b)) {
__a = copysign(isinf(__a) ? _Tp(1) : _Tp(0), __a);
__b = copysign(isinf(__b) ? _Tp(1) : _Tp(0), __b);
if (isnan(__c))
__c = copysign(_Tp(0), __c);
if (isnan(__d))
__d = copysign(_Tp(0), __d);
__recalc = true;
}
// Same treatment with the roles of the two operands swapped.
if (isinf(__c) || isinf(__d)) {
__c = copysign(isinf(__c) ? _Tp(1) : _Tp(0), __c);
__d = copysign(isinf(__d) ? _Tp(1) : _Tp(0), __d);
if (isnan(__a))
__a = copysign(_Tp(0), __a);
if (isnan(__b))
__b = copysign(_Tp(0), __b);
__recalc = true;
}
// Neither operand was infinite but one of the partial products is (it
// overflowed): replace NaN parts by signed zeros and recompute.
if (!__recalc &&
(isinf(__ac) || isinf(__bd) || isinf(__ad) || isinf(__bc))) {
if (isnan(__a))
__a = copysign(_Tp(0), __a);
if (isnan(__b))
__b = copysign(_Tp(0), __b);
if (isnan(__c))
__c = copysign(_Tp(0), __c);
if (isnan(__d))
__d = copysign(_Tp(0), __d);
__recalc = true;
}
// Recompute from the adjusted parts and scale by infinity.
if (__recalc) {
__x = _Tp(INFINITY) * (__a * __c - __b * __d);
__y = _Tp(INFINITY) * (__a * __d + __b * __c);
}
}
return complex<_Tp>(__x, __y);
}
// complex * scalar: both parts are scaled by the scalar.
template <class _Tp>
inline CUDA_CALLABLE_MEMBER complex<_Tp> operator*(const complex<_Tp> &__x,
                                                   const _Tp &__y) {
  complex<_Tp> __scaled(__x);
  return __scaled *= __y;
}

// scalar * complex: starts from the complex operand and scales it.
template <class _Tp>
inline CUDA_CALLABLE_MEMBER complex<_Tp> operator*(const _Tp &__x,
                                                   const complex<_Tp> &__y) {
  complex<_Tp> __scaled(__y);
  return __scaled *= __x;
}
// Complex division z / w with z = a + ib and w = c + id:
//   ((ac + bd) + i(bc - ad)) / (c^2 + d^2).
// The divisor is first rescaled by a power of two taken from
// logb(max(|c|, |d|)), and the quotient is scaled by the same power
// afterwards. If both result parts are NaN, special cases are patched up
// below.
template <class _Tp>
CUDA_CALLABLE_MEMBER complex<_Tp> operator/(const complex<_Tp> &__z,
const complex<_Tp> &__w) {
int __ilogbw = 0;
_Tp __a = __z.real();
_Tp __b = __z.imag();
_Tp __c = __w.real();
_Tp __d = __w.imag();
_Tp __logbw = logb(fmax(fabs(__c), fabs(__d)));
// Rescale only when the exponent is finite (the divisor is neither zero,
// infinite nor NaN in its larger part).
if (isfinite(__logbw)) {
__ilogbw = static_cast<int>(__logbw);
__c = scalbn(__c, -__ilogbw);
__d = scalbn(__d, -__ilogbw);
}
_Tp __denom = __c * __c + __d * __d;
_Tp __x = scalbn((__a * __c + __b * __d) / __denom, -__ilogbw);
_Tp __y = scalbn((__b * __c - __a * __d) / __denom, -__ilogbw);
if (isnan(__x) && isnan(__y)) {
// Zero denominator with a numerator that is not entirely NaN: signed infinity.
if ((__denom == _Tp(0)) && (!isnan(__a) || !isnan(__b))) {
__x = copysign(_Tp(INFINITY), __c) * __a;
__y = copysign(_Tp(INFINITY), __c) * __b;
// Infinite numerator, finite divisor: result is an infinity.
} else if ((isinf(__a) || isinf(__b)) && isfinite(__c) && isfinite(__d)) {
__a = copysign(isinf(__a) ? _Tp(1) : _Tp(0), __a);
__b = copysign(isinf(__b) ? _Tp(1) : _Tp(0), __b);
__x = _Tp(INFINITY) * (__a * __c + __b * __d);
__y = _Tp(INFINITY) * (__b * __c - __a * __d);
// Infinite divisor, finite numerator: result is a (signed) zero.
} else if (isinf(__logbw) && __logbw > _Tp(0) && isfinite(__a) &&
isfinite(__b)) {
__c = copysign(isinf(__c) ? _Tp(1) : _Tp(0), __c);
__d = copysign(isinf(__d) ? _Tp(1) : _Tp(0), __d);
__x = _Tp(0) * (__a * __c + __b * __d);
__y = _Tp(0) * (__b * __c - __a * __d);
}
}
return complex<_Tp>(__x, __y);
}
// Single-precision specialization of complex division. It follows the same
// algorithm as the generic operator/ but calls the float variants of the
// math functions (logbf, fmaxf, fabsf, scalbnf, copysignf).
//
// An explicit (full) specialization of a function template is an ordinary
// function as far as the one-definition rule is concerned, so it must be
// `inline` to be defined in a header that may be included from more than
// one translation unit.
template <>
inline CUDA_CALLABLE_MEMBER complex<float>
operator/(const complex<float> &__z, const complex<float> &__w) {
  int __ilogbw = 0;
  float __a = __z.real();
  float __b = __z.imag();
  float __c = __w.real();
  float __d = __w.imag();
  float __logbw = logbf(fmaxf(fabsf(__c), fabsf(__d)));
  // Rescale the divisor by a power of two when its exponent is finite; the
  // quotient is scaled by the same power below.
  if (isfinite(__logbw)) {
    __ilogbw = static_cast<int>(__logbw);
    __c = scalbnf(__c, -__ilogbw);
    __d = scalbnf(__d, -__ilogbw);
  }
  float __denom = __c * __c + __d * __d;
  float __x = scalbnf((__a * __c + __b * __d) / __denom, -__ilogbw);
  float __y = scalbnf((__b * __c - __a * __d) / __denom, -__ilogbw);
  if (isnan(__x) && isnan(__y)) {
    if ((__denom == float(0)) && (!isnan(__a) || !isnan(__b))) {
      // Zero denominator with a numerator that is not entirely NaN:
      // signed infinity.
#pragma warning(suppress : 4756) // Ignore INFINITY related warning
      __x = copysignf(INFINITY, __c) * __a;
#pragma warning(suppress : 4756) // Ignore INFINITY related warning
      __y = copysignf(INFINITY, __c) * __b;
    } else if ((isinf(__a) || isinf(__b)) && isfinite(__c) && isfinite(__d)) {
      // Infinite numerator, finite divisor: result is an infinity.
      __a = copysignf(isinf(__a) ? float(1) : float(0), __a);
      __b = copysignf(isinf(__b) ? float(1) : float(0), __b);
#pragma warning(suppress : 4756) // Ignore INFINITY related warning
      __x = INFINITY * (__a * __c + __b * __d);
#pragma warning(suppress : 4756) // Ignore INFINITY related warning
      __y = INFINITY * (__b * __c - __a * __d);
    } else if (isinf(__logbw) && __logbw > float(0) && isfinite(__a) &&
               isfinite(__b)) {
      // Infinite divisor, finite numerator: result is a (signed) zero.
      __c = copysignf(isinf(__c) ? float(1) : float(0), __c);
      __d = copysignf(isinf(__d) ? float(1) : float(0), __d);
      __x = float(0) * (__a * __c + __b * __d);
      __y = float(0) * (__b * __c - __a * __d);
    }
  }
  return complex<float>(__x, __y);
}
// complex / scalar: divide both components by the scalar.
template <class _Tp>
inline CUDA_CALLABLE_MEMBER complex<_Tp> operator/(const complex<_Tp> &__x,
                                                   const _Tp &__y) {
  const _Tp re = __x.real() / __y;
  const _Tp im = __x.imag() / __y;
  return complex<_Tp>(re, im);
}
// scalar / complex: promote the scalar to a complex value, then divide.
template <class _Tp>
inline CUDA_CALLABLE_MEMBER complex<_Tp> operator/(const _Tp &__x,
                                                   const complex<_Tp> &__y) {
  complex<_Tp> quotient(__x);
  quotient /= __y;
  return quotient;
}
// Unary plus: yields a copy of the operand.
template <class _Tp>
inline CUDA_CALLABLE_MEMBER complex<_Tp> operator+(const complex<_Tp> &__x) {
  return complex<_Tp>(__x);
}
// Unary minus: negates the real and the imaginary component.
template <class _Tp>
inline CUDA_CALLABLE_MEMBER complex<_Tp> operator-(const complex<_Tp> &__x) {
  const _Tp negated_re = -__x.real();
  const _Tp negated_im = -__x.imag();
  return complex<_Tp>(negated_re, negated_im);
}
// complex == complex: both components must compare equal.
template <class _Tp>
inline CUDA_CALLABLE_MEMBER bool operator==(const complex<_Tp> &__x,
                                            const complex<_Tp> &__y) {
  if (!(__x.real() == __y.real()))
    return false;
  return __x.imag() == __y.imag();
}
// complex == scalar: the scalar is treated as having a zero imaginary part.
template <class _Tp>
inline CUDA_CALLABLE_MEMBER bool operator==(const complex<_Tp> &__x,
                                            const _Tp &__y) {
  if (!(__x.real() == __y))
    return false;
  return __x.imag() == 0;
}
// scalar == complex: the scalar is treated as having a zero imaginary part.
template <class _Tp>
inline CUDA_CALLABLE_MEMBER bool operator==(const _Tp &__x,
                                            const complex<_Tp> &__y) {
  if (!(__x == __y.real()))
    return false;
  return 0 == __y.imag();
}
// complex != complex: negation of the matching operator==.
template <class _Tp>
inline CUDA_CALLABLE_MEMBER bool operator!=(const complex<_Tp> &__x,
                                            const complex<_Tp> &__y) {
  const bool same = (__x == __y);
  return !same;
}
// complex != scalar: negation of the matching operator==.
template <class _Tp>
inline CUDA_CALLABLE_MEMBER bool operator!=(const complex<_Tp> &__x,
                                            const _Tp &__y) {
  const bool same = (__x == __y);
  return !same;
}
// scalar != complex: negation of the matching operator==.
template <class _Tp>
inline CUDA_CALLABLE_MEMBER bool operator!=(const _Tp &__x,
                                            const complex<_Tp> &__y) {
  const bool same = (__x == __y);
  return !same;
}
// 26.3.7 values:
// real
// Real part of a complex value.
template <class _Tp>
inline CUDA_CALLABLE_MEMBER _Tp real(const complex<_Tp> &__c) {
  const _Tp re = __c.real();
  return re;
}
// Real part of a plain scalar is the scalar itself.
inline CUDA_CALLABLE_MEMBER double real(double __re) {
  return __re;
}
inline CUDA_CALLABLE_MEMBER float real(float __re) {
  return __re;
}
// imag
// Imaginary part of a complex value.
template <class _Tp>
inline CUDA_CALLABLE_MEMBER _Tp imag(const complex<_Tp> &__c) {
  const _Tp im = __c.imag();
  return im;
}
// A plain scalar has no imaginary part; the argument is ignored.
inline CUDA_CALLABLE_MEMBER double imag(double /*__re*/) {
  return 0.0;
}
inline CUDA_CALLABLE_MEMBER float imag(float /*__re*/) {
  return 0.0f;
}
// abs
// Magnitude of a complex value, computed with hypot(re, im).
template <class _Tp>
inline CUDA_CALLABLE_MEMBER _Tp abs(const complex<_Tp> &__c) {
  const _Tp re = __c.real();
  const _Tp im = __c.imag();
  return hypot(re, im);
}
// arg
// Phase angle of a complex value: atan2(im, re).
template <class _Tp>
inline CUDA_CALLABLE_MEMBER _Tp arg(const complex<_Tp> &__c) {
  const _Tp im = __c.imag();
  const _Tp re = __c.real();
  return atan2(im, re);
}
// Phase angle of a plain scalar: atan2 with a zero imaginary part.
inline CUDA_CALLABLE_MEMBER double arg(double __re) {
  return atan2(0.0, __re);
}
inline CUDA_CALLABLE_MEMBER float arg(float __re) {
  return atan2f(0.0f, __re);
}
// norm
// Squared magnitude re^2 + im^2. If either component is infinite its absolute
// value is returned directly (the real part is checked first).
template <class _Tp>
inline CUDA_CALLABLE_MEMBER _Tp norm(const complex<_Tp> &__c) {
  const _Tp re = __c.real();
  if (isinf(re))
    return fabs(re);
  const _Tp im = __c.imag();
  if (isinf(im))
    return fabs(im);
  return re * re + im * im;
}
// Squared magnitude of a plain scalar.
inline CUDA_CALLABLE_MEMBER double norm(double __re) {
  return __re * __re;
}
inline CUDA_CALLABLE_MEMBER float norm(float __re) {
  return __re * __re;
}
// conj
// Complex conjugate: same real part, negated imaginary part.
template <class _Tp>
inline CUDA_CALLABLE_MEMBER complex<_Tp> conj(const complex<_Tp> &__c) {
  const _Tp re = __c.real();
  const _Tp flipped_im = -__c.imag();
  return complex<_Tp>(re, flipped_im);
}
// Conjugate of a plain scalar: the scalar promoted to a complex value.
inline CUDA_CALLABLE_MEMBER complex<double> conj(double __re) {
  complex<double> promoted(__re);
  return promoted;
}
inline CUDA_CALLABLE_MEMBER complex<float> conj(float __re) {
  complex<float> promoted(__re);
  return promoted;
}
// proj
// Projection: if either component is infinite the result is
// (INFINITY, zero carrying the sign of the imaginary part); otherwise the
// argument is returned unchanged.
template <class _Tp>
inline CUDA_CALLABLE_MEMBER complex<_Tp> proj(const complex<_Tp> &__c) {
  const bool has_infinite_part = isinf(__c.real()) || isinf(__c.imag());
  if (!has_infinite_part)
    return __c;
  return complex<_Tp>(INFINITY, copysign(_Tp(0), __c.imag()));
}
// Projection of a plain scalar: an infinite value is mapped to its absolute
// value before being promoted to a complex value.
inline CUDA_CALLABLE_MEMBER complex<double> proj(double __re) {
  return complex<double>(isinf(__re) ? fabs(__re) : __re);
}
inline CUDA_CALLABLE_MEMBER complex<float> proj(float __re) {
  return complex<float>(isinf(__re) ? fabs(__re) : __re);
}
// polar
// Build a complex value from magnitude __rho and phase __theta (default 0).
// Special cases handled before the regular rho*cos / rho*sin evaluation:
//  - NaN or sign-bit-set rho        -> (NaN, NaN)
//  - NaN theta                      -> (rho, theta) if rho is infinite,
//                                      otherwise (theta, theta)
//  - infinite theta                 -> (rho, NaN) if rho is infinite,
//                                      otherwise (NaN, NaN)
// A NaN produced by either product in the regular path is replaced by 0.
template <class _Tp>
CUDA_CALLABLE_MEMBER complex<_Tp> polar(const _Tp &__rho,
const _Tp &__theta = _Tp(0)) {
if (isnan(__rho) || signbit(__rho))
return complex<_Tp>(_Tp(NAN), _Tp(NAN));
if (isnan(__theta)) {
if (isinf(__rho))
return complex<_Tp>(__rho, __theta);
return complex<_Tp>(__theta, __theta);
}
if (isinf(__theta)) {
if (isinf(__rho))
return complex<_Tp>(__rho, _Tp(NAN));
return complex<_Tp>(_Tp(NAN), _Tp(NAN));
}
_Tp __x = __rho * cos(__theta);
// e.g. inf * 0 yields NaN; it is mapped to 0 here.
if (isnan(__x))
__x = 0;
_Tp __y = __rho * sin(__theta);
if (isnan(__y))
__y = 0;
return complex<_Tp>(__x, __y);
}
// log
// Natural logarithm: (log|x|, arg(x)).
template <class _Tp>
inline CUDA_CALLABLE_MEMBER complex<_Tp> log(const complex<_Tp> &__x) {
  const _Tp log_magnitude = log(abs(__x));
  const _Tp phase = arg(__x);
  return complex<_Tp>(log_magnitude, phase);
}
// log10
// Base-10 logarithm: natural logarithm divided by the real value log(10).
template <class _Tp>
inline CUDA_CALLABLE_MEMBER complex<_Tp> log10(const complex<_Tp> &__x) {
  const complex<_Tp> natural = log(__x);
  return natural / log(_Tp(10));
}
// sqrt
// Complex square root.
// Infinite components are handled explicitly; every other input goes through
// polar(sqrt(|x|), arg(x) / 2).
template <class _Tp>
CUDA_CALLABLE_MEMBER complex<_Tp> sqrt(const complex<_Tp> &__x) {
// Infinite imaginary part: result is (+inf, imag) whatever the real part is.
if (isinf(__x.imag()))
return complex<_Tp>(_Tp(INFINITY), __x.imag());
if (isinf(__x.real())) {
// +inf real part: keep it, imaginary part is NaN (if it was NaN) or a
// zero carrying the sign of the original imaginary part.
if (__x.real() > _Tp(0))
return complex<_Tp>(__x.real(), isnan(__x.imag())
? __x.imag()
: copysign(_Tp(0), __x.imag()));
// -inf real part: the infinity moves to the imaginary component.
return complex<_Tp>(isnan(__x.imag()) ? __x.imag() : _Tp(0),
copysign(__x.real(), __x.imag()));
}
return polar(sqrt(abs(__x)), arg(__x) / _Tp(2));
}
// exp
// Complex exponential: exp(re) * (cos(im), sin(im)), with explicit handling
// of an infinite or NaN real part.
template <class _Tp>
CUDA_CALLABLE_MEMBER complex<_Tp> exp(const complex<_Tp> &__x) {
_Tp __i = __x.imag();
if (isinf(__x.real())) {
if (__x.real() < _Tp(0)) {
// -inf real part: a non-finite imaginary part is replaced by 1 so that
// the cos/sin below stay finite and the result is a zero.
if (!isfinite(__i))
__i = _Tp(1);
} else if (__i == 0 || !isfinite(__i)) {
// +inf real part with zero or non-finite imaginary part: returned
// directly; an infinite imaginary part becomes NaN.
if (isinf(__i))
__i = _Tp(NAN);
return complex<_Tp>(__x.real(), __i);
}
} else if (isnan(__x.real()) && __x.imag() == 0)
// NaN real part with zero imaginary part: argument returned unchanged.
return __x;
_Tp __e = exp(__x.real());
return complex<_Tp>(__e * cos(__i), __e * sin(__i));
}
// pow
// complex ^ complex, evaluated as exp(y * log(x)).
template <class _Tp>
inline CUDA_CALLABLE_MEMBER complex<_Tp> pow(const complex<_Tp> &__x,
                                             const complex<_Tp> &__y) {
  const complex<_Tp> exponent = __y * log(__x);
  return exp(exponent);
}
// complex ^ scalar: the exponent is promoted to a complex value.
template <class _Tp>
inline CUDA_CALLABLE_MEMBER complex<_Tp> pow(const complex<_Tp> &__x,
                                             const _Tp &__y) {
  const complex<_Tp> exponent(__y);
  return pow(__x, exponent);
}
// scalar ^ complex: the base is promoted to a complex value.
template <class _Tp>
inline CUDA_CALLABLE_MEMBER complex<_Tp> pow(const _Tp &__x,
                                             const complex<_Tp> &__y) {
  const complex<_Tp> base(__x);
  return pow(base, __y);
}
// asinh
// Inverse hyperbolic sine.
// Infinite and NaN components are handled case by case; finite inputs use
// log(x + sqrt(x^2 + 1)) with the signs of the result copied from the input.
template <class _Tp>
CUDA_CALLABLE_MEMBER complex<_Tp> asinh(const complex<_Tp> &__x) {
// pi obtained as atan2(+0, -0).
const _Tp __pi(atan2(+0., -0.));
if (isinf(__x.real())) {
if (isnan(__x.imag()))
return __x;
if (isinf(__x.imag()))
return complex<_Tp>(__x.real(), copysign(__pi * _Tp(0.25), __x.imag()));
return complex<_Tp>(__x.real(), copysign(_Tp(0), __x.imag()));
}
if (isnan(__x.real())) {
if (isinf(__x.imag()))
return complex<_Tp>(__x.imag(), __x.real());
if (__x.imag() == 0)
return __x;
return complex<_Tp>(__x.real(), __x.real());
}
// Finite real part, infinite imaginary part.
if (isinf(__x.imag()))
return complex<_Tp>(copysign(__x.imag(), __x.real()),
copysign(__pi / _Tp(2), __x.imag()));
complex<_Tp> __z = log(__x + sqrt(pow(__x, _Tp(2)) + _Tp(1)));
// Force the result components to carry the signs of the input components.
return complex<_Tp>(copysign(__z.real(), __x.real()),
copysign(__z.imag(), __x.imag()));
}
// acosh
// Inverse hyperbolic cosine.
// Infinite and NaN components are handled case by case; finite inputs use
// log(x + sqrt(x^2 - 1)). The real part of the result is made non-negative and
// the imaginary part carries the sign of the input's imaginary part.
template <class _Tp>
CUDA_CALLABLE_MEMBER complex<_Tp> acosh(const complex<_Tp> &__x) {
  // pi obtained as atan2(+0, -0).
  const _Tp __pi(atan2(+0., -0.));
  if (isinf(__x.real())) {
    if (isnan(__x.imag()))
      return complex<_Tp>(fabs(__x.real()), __x.imag());
    // Braces make the if/else pairing explicit (the original relied on the
    // dangling-else rule).
    if (isinf(__x.imag())) {
      if (__x.real() > 0)
        return complex<_Tp>(__x.real(),
                            copysign(__pi * _Tp(0.25), __x.imag()));
      else
        return complex<_Tp>(-__x.real(),
                            copysign(__pi * _Tp(0.75), __x.imag()));
    }
    if (__x.real() < 0)
      return complex<_Tp>(-__x.real(), copysign(__pi, __x.imag()));
    return complex<_Tp>(__x.real(), copysign(_Tp(0), __x.imag()));
  }
  if (isnan(__x.real())) {
    if (isinf(__x.imag()))
      return complex<_Tp>(fabs(__x.imag()), __x.real());
    return complex<_Tp>(__x.real(), __x.real());
  }
  // Finite real part, infinite imaginary part.
  if (isinf(__x.imag()))
    return complex<_Tp>(fabs(__x.imag()), copysign(__pi / _Tp(2), __x.imag()));
  complex<_Tp> __z = log(__x + sqrt(pow(__x, _Tp(2)) - _Tp(1)));
  // copysign(..., 0) clears the sign bit of the real part.
  return complex<_Tp>(copysign(__z.real(), _Tp(0)),
                      copysign(__z.imag(), __x.imag()));
}
// atanh
// Inverse hyperbolic tangent.
// Infinite/NaN components and the poles at x = +-1 are handled explicitly;
// everything else uses log((1 + x) / (1 - x)) / 2 with the signs of the result
// copied from the input.
template <class _Tp>
CUDA_CALLABLE_MEMBER complex<_Tp> atanh(const complex<_Tp> &__x) {
// pi obtained as atan2(+0, -0).
const _Tp __pi(atan2(+0., -0.));
if (isinf(__x.imag())) {
return complex<_Tp>(copysign(_Tp(0), __x.real()),
copysign(__pi / _Tp(2), __x.imag()));
}
if (isnan(__x.imag())) {
if (isinf(__x.real()) || __x.real() == 0)
return complex<_Tp>(copysign(_Tp(0), __x.real()), __x.imag());
return complex<_Tp>(__x.imag(), __x.imag());
}
if (isnan(__x.real())) {
return complex<_Tp>(__x.real(), __x.real());
}
if (isinf(__x.real())) {
return complex<_Tp>(copysign(_Tp(0), __x.real()),
copysign(__pi / _Tp(2), __x.imag()));
}
// Poles at (+-1, 0): signed infinity in the real part.
if (fabs(__x.real()) == _Tp(1) && __x.imag() == _Tp(0)) {
return complex<_Tp>(copysign(_Tp(INFINITY), __x.real()),
copysign(_Tp(0), __x.imag()));
}
complex<_Tp> __z = log((_Tp(1) + __x) / (_Tp(1) - __x)) / _Tp(2);
return complex<_Tp>(copysign(__z.real(), __x.real()),
copysign(__z.imag(), __x.imag()));
}
// sinh
// Hyperbolic sine: (sinh(re) cos(im), cosh(re) sin(im)).
// Special cases:
//  - non-finite imaginary part with an infinite or zero real part
//    -> (re, NaN)
//  - zero imaginary part with a non-finite real part -> argument unchanged
template <class _Tp>
CUDA_CALLABLE_MEMBER complex<_Tp> sinh(const complex<_Tp> &__x) {
  const _Tp re = __x.real();
  const _Tp im = __x.imag();
  if (!isfinite(im) && (isinf(re) || re == 0))
    return complex<_Tp>(re, _Tp(NAN));
  if (im == 0 && !isfinite(re))
    return __x;
  const _Tp out_re = sinh(re) * cos(im);
  const _Tp out_im = cosh(re) * sin(im);
  return complex<_Tp>(out_re, out_im);
}
// cosh
// Hyperbolic cosine: (cosh(re) cos(im), sinh(re) sin(im)).
// Special cases:
//  - non-finite im, infinite re -> (|re|, NaN)
//  - non-finite im, zero re     -> (NaN, re)
//  - zero im, zero re           -> (1, im)
//  - zero im, non-finite re     -> (|re|, im)
template <class _Tp>
CUDA_CALLABLE_MEMBER complex<_Tp> cosh(const complex<_Tp> &__x) {
  const _Tp re = __x.real();
  const _Tp im = __x.imag();
  if (!isfinite(im)) {
    if (isinf(re))
      return complex<_Tp>(fabs(re), _Tp(NAN));
    if (re == 0)
      return complex<_Tp>(_Tp(NAN), re);
  }
  if (im == 0) {
    if (re == 0)
      return complex<_Tp>(_Tp(1), im);
    if (!isfinite(re))
      return complex<_Tp>(fabs(re), im);
  }
  const _Tp out_re = cosh(re) * cos(im);
  const _Tp out_im = sinh(re) * sin(im);
  return complex<_Tp>(out_re, out_im);
}
// tanh
// Hyperbolic tangent.
//
// Regular path: with r2 = 2*re and i2 = 2*im,
//   tanh(x) = (sinh(r2) / d, sin(i2) / d),  d = cosh(r2) + cos(i2).
//
// Special cases:
//  - infinite real part: the real part of the result is +-1 with the sign of
//    the input's real part (tanh is an odd function, so -inf must give -1);
//    the imaginary part is a zero, signed by sin(2*im) when im is finite.
//  - NaN real part with zero imaginary part: argument returned unchanged.
//  - large finite real part where both sinh(r2) and d overflow to infinity:
//    inf / inf would be NaN, so the limit (+-1, +-0) is returned instead.
template <class _Tp>
CUDA_CALLABLE_MEMBER complex<_Tp> tanh(const complex<_Tp> &__x) {
  if (isinf(__x.real())) {
    const _Tp __unit = copysign(_Tp(1), __x.real());
    if (!isfinite(__x.imag()))
      return complex<_Tp>(__unit, _Tp(0));
    return complex<_Tp>(__unit, copysign(_Tp(0), sin(_Tp(2) * __x.imag())));
  }
  if (isnan(__x.real()) && __x.imag() == 0)
    return __x;
  _Tp __2r(_Tp(2) * __x.real());
  _Tp __2i(_Tp(2) * __x.imag());
  _Tp __d(cosh(__2r) + cos(__2i));
  _Tp __2rsh(sinh(__2r));
  // Overflow guard: avoid inf / inf = NaN for large finite real parts.
  if (isinf(__2rsh) && isinf(__d))
    return complex<_Tp>(__2rsh > _Tp(0) ? _Tp(1) : _Tp(-1),
                        __2i > _Tp(0) ? _Tp(0) : _Tp(-0.));
  return complex<_Tp>(__2rsh / __d, sin(__2i) / __d);
}
// asin
// Inverse sine via asinh: evaluate asinh at (-im, re), then return
// (imag, -real) of that intermediate result.
template <class _Tp>
CUDA_CALLABLE_MEMBER complex<_Tp> asin(const complex<_Tp> &__x) {
  const complex<_Tp> rotated(-__x.imag(), __x.real());
  const complex<_Tp> hyperbolic = asinh(rotated);
  return complex<_Tp>(hyperbolic.imag(), -hyperbolic.real());
}
// acos
// Inverse cosine.
// Infinite/NaN components and a zero real part are handled explicitly; the
// remaining inputs use log(x + sqrt(x^2 - 1)), whose components are swapped
// and sign-adjusted according to the sign bit of the input's imaginary part.
template <class _Tp>
CUDA_CALLABLE_MEMBER complex<_Tp> acos(const complex<_Tp> &__x) {
// pi obtained as atan2(+0, -0).
const _Tp __pi(atan2(+0., -0.));
if (isinf(__x.real())) {
if (isnan(__x.imag()))
return complex<_Tp>(__x.imag(), __x.real());
if (isinf(__x.imag())) {
if (__x.real() < _Tp(0))
return complex<_Tp>(_Tp(0.75) * __pi, -__x.imag());
return complex<_Tp>(_Tp(0.25) * __pi, -__x.imag());
}
// Infinite real part, finite imaginary part: the infinite imaginary result
// gets the sign opposite to the input's imaginary sign bit.
if (__x.real() < _Tp(0))
return complex<_Tp>(__pi, signbit(__x.imag()) ? -__x.real() : __x.real());
return complex<_Tp>(_Tp(0), signbit(__x.imag()) ? __x.real() : -__x.real());
}
if (isnan(__x.real())) {
if (isinf(__x.imag()))
return complex<_Tp>(__x.real(), -__x.imag());
return complex<_Tp>(__x.real(), __x.real());
}
if (isinf(__x.imag()))
return complex<_Tp>(__pi / _Tp(2), -__x.imag());
// NOTE(review): for a zero real part this returns (pi/2, -imag) directly
// rather than going through the log formula - confirm this is intended for
// finite non-zero imaginary parts.
if (__x.real() == 0)
return complex<_Tp>(__pi / _Tp(2), -__x.imag());
complex<_Tp> __z = log(__x + sqrt(pow(__x, _Tp(2)) - _Tp(1)));
if (signbit(__x.imag()))
return complex<_Tp>(fabs(__z.imag()), fabs(__z.real()));
return complex<_Tp>(fabs(__z.imag()), -fabs(__z.real()));
}
// atan
// Inverse tangent via atanh: evaluate atanh at (-im, re), then return
// (imag, -real) of that intermediate result.
template <class _Tp>
CUDA_CALLABLE_MEMBER complex<_Tp> atan(const complex<_Tp> &__x) {
  const complex<_Tp> rotated(-__x.imag(), __x.real());
  const complex<_Tp> hyperbolic = atanh(rotated);
  return complex<_Tp>(hyperbolic.imag(), -hyperbolic.real());
}
// sin
// Sine via sinh: evaluate sinh at (-im, re), then return (imag, -real).
template <class _Tp>
CUDA_CALLABLE_MEMBER complex<_Tp> sin(const complex<_Tp> &__x) {
  const complex<_Tp> rotated(-__x.imag(), __x.real());
  const complex<_Tp> hyperbolic = sinh(rotated);
  return complex<_Tp>(hyperbolic.imag(), -hyperbolic.real());
}
// cos
// Cosine via cosh evaluated at (-im, re).
template <class _Tp>
inline CUDA_CALLABLE_MEMBER complex<_Tp> cos(const complex<_Tp> &__x) {
  const complex<_Tp> rotated(-__x.imag(), __x.real());
  return cosh(rotated);
}
// tan
// Tangent via tanh: evaluate tanh at (-im, re), then return (imag, -real).
template <class _Tp>
CUDA_CALLABLE_MEMBER complex<_Tp> tan(const complex<_Tp> &__x) {
  const complex<_Tp> rotated(-__x.imag(), __x.real());
  const complex<_Tp> hyperbolic = tanh(rotated);
  return complex<_Tp>(hyperbolic.imag(), -hyperbolic.real());
}
// Stream extraction for complex values.
// Accepted forms (leading whitespace is skipped where ws() is called):
//   (re,im)   -> complex(re, im)
//   (re)      -> complex(re, 0)
//   re        -> complex(re, 0)
// Any other input, or a stream that is not good() on entry, sets failbit.
// __x is only assigned when a complete value was parsed.
template <class _Tp, class _CharT, class _Traits>
std::basic_istream<_CharT, _Traits> &
operator>>(std::basic_istream<_CharT, _Traits> &__is, complex<_Tp> &__x) {
if (__is.good()) {
ws(__is);
if (__is.peek() == _CharT('(')) {
// Parenthesised form: consume '(' and read the real part.
__is.get();
_Tp __r;
__is >> __r;
if (!__is.fail()) {
ws(__is);
_CharT __c = __is.peek();
if (__c == _CharT(',')) {
// Two-component form: read the imaginary part and expect ')'.
__is.get();
_Tp __i;
__is >> __i;
if (!__is.fail()) {
ws(__is);
__c = __is.peek();
if (__c == _CharT(')')) {
__is.get();
__x = complex<_Tp>(__r, __i);
} else
__is.setstate(std::ios_base::failbit);
} else
__is.setstate(std::ios_base::failbit);
} else if (__c == _CharT(')')) {
// One-component parenthesised form: imaginary part is zero.
__is.get();
__x = complex<_Tp>(__r, _Tp(0));
} else
__is.setstate(std::ios_base::failbit);
} else
__is.setstate(std::ios_base::failbit);
} else {
// Bare scalar: imaginary part is zero.
_Tp __r;
__is >> __r;
if (!__is.fail())
__x = complex<_Tp>(__r, _Tp(0));
else
__is.setstate(std::ios_base::failbit);
}
} else
__is.setstate(std::ios_base::failbit);
return __is;
}
template <class _Tp, class _CharT, class _Traits>
std::basic_ostream<_CharT, _Traits> &
operator<<(std::basic_ostream<_CharT, _Traits> &__os, const complex<_Tp> &__x) {
std::basic_ostringstream<_CharT, _Traits> __s;
__s.flags(__os.flags());
__s.imbue(__os.getloc());
__s.precision(__os.precision());
__s << '(' << __x.real() << ',' << __x.imag() << ')';
return __os << __s.str();
}
//} // close namespace cuda_complex
// complex<U> * scalar of an arbitrary type V: both components are scaled and
// the result is stored as complex<U>.
template <class U, class V>
CUDA_CALLABLE_MEMBER complex<U> operator*(const complex<U> &complexNumber,
                                          const V &scalar) {
  return complex<U>{complexNumber.real() * scalar,
                    complexNumber.imag() * scalar};
}
// scalar of an arbitrary type V * complex<U>: same as above with the operands
// swapped.
template <class U, class V>
CUDA_CALLABLE_MEMBER complex<U> operator*(const V &scalar,
                                          const complex<U> &complexNumber) {
  return complex<U>{complexNumber.real() * scalar,
                    complexNumber.imag() * scalar};
}
// complex<U> + scalar of an arbitrary type V: the scalar is added to the real
// part only; the imaginary part is passed through.
template <class U, class V>
CUDA_CALLABLE_MEMBER complex<U> operator+(const complex<U> &complexNumber,
                                          const V &scalar) {
  return complex<U>{complexNumber.real() + scalar, complexNumber.imag()};
}
// scalar of an arbitrary type V + complex<U>: same as above with the operands
// swapped.
template <class U, class V>
CUDA_CALLABLE_MEMBER complex<U> operator+(const V &scalar,
                                          const complex<U> &complexNumber) {
  return complex<U>{complexNumber.real() + scalar, complexNumber.imag()};
}
// complex<U> - scalar of an arbitrary type V: the scalar is subtracted from
// the real part only; the imaginary part is passed through.
template <class U, class V>
CUDA_CALLABLE_MEMBER auto operator-(const complex<U> &complexNumber,
                                    const V &scalar) -> complex<U> {
  return complex<U>{real(complexNumber) - scalar, imag(complexNumber)};
}
// scalar of an arbitrary type V - complex<U>:
//   s - (a + bi) = (s - a) + (-b)i
// The imaginary part must be negated; the previous implementation returned
// +imag(complexNumber), i.e. the conjugate of the correct result.
template <class U, class V>
CUDA_CALLABLE_MEMBER auto operator-(const V &scalar,
                                    const complex<U> &complexNumber)
    -> complex<U> {
  return complex<U>{scalar - real(complexNumber), -imag(complexNumber)};
}
// complex<U> / scalar of an arbitrary type V: both components are divided by
// the scalar (taken by value).
template <class U, class V>
CUDA_CALLABLE_MEMBER complex<U> operator/(const complex<U> &complexNumber,
                                          const V scalar) {
  return complex<U>{complexNumber.real() / scalar,
                    complexNumber.imag() / scalar};
}
// scalar of an arbitrary type V / complex<U>: the scalar is promoted to
// complex<U>{scalar, 0} and the complex/complex division is used.
template <class U, class V>
CUDA_CALLABLE_MEMBER complex<U> operator/(const V scalar,
                                          const complex<U> &complexNumber) {
  const complex<U> promoted = complex<U>{scalar, 0};
  return promoted / complexNumber;
}
// Convenience aliases for the double- and single-precision instantiations.
using ComplexDouble = complex<double>;
using ComplexFloat = complex<float>;
#endif // CUDA_COMPLEX_HPP
}
...@@ -7,14 +7,15 @@ from types import MappingProxyType ...@@ -7,14 +7,15 @@ from types import MappingProxyType
import numpy as np import numpy as np
import sympy as sp import sympy as sp
from sympy.core.numbers import ImaginaryUnit
from sympy.logic.boolalg import Boolean from sympy.logic.boolalg import Boolean
import pystencils.astnodes as ast import pystencils.astnodes as ast
import pystencils.integer_functions import pystencils.integer_functions
from pystencils.assignment import Assignment from pystencils.assignment import Assignment
from pystencils.data_types import ( from pystencils.data_types import (
PointerType, StructType, TypedSymbol, cast_func, collate_types, create_type, get_base_type, PointerType, StructType, TypedImaginaryUnit, TypedSymbol, cast_func, collate_types, create_type,
get_type_of_expression, pointer_arithmetic_func, reinterpret_cast_func) get_base_type, get_type_of_expression, pointer_arithmetic_func, reinterpret_cast_func)
from pystencils.field import AbstractField, Field, FieldType from pystencils.field import AbstractField, Field, FieldType
from pystencils.kernelparameters import FieldPointerSymbol from pystencils.kernelparameters import FieldPointerSymbol
from pystencils.simp.assignment_collection import AssignmentCollection from pystencils.simp.assignment_collection import AssignmentCollection
...@@ -830,6 +831,8 @@ class KernelConstraintsCheck: ...@@ -830,6 +831,8 @@ class KernelConstraintsCheck:
if new_args: if new_args:
rhs.offsets = new_args rhs.offsets = new_args
return rhs return rhs
elif isinstance(rhs, ImaginaryUnit):
return TypedImaginaryUnit(create_type(self._type_for_symbol['_complex_type']))
elif isinstance(rhs, TypedSymbol): elif isinstance(rhs, TypedSymbol):
return rhs return rhs
elif isinstance(rhs, sp.Symbol): elif isinstance(rhs, sp.Symbol):
...@@ -929,7 +932,7 @@ def add_types(eqs, type_for_symbol, check_independence_condition): ...@@ -929,7 +932,7 @@ def add_types(eqs, type_for_symbol, check_independence_condition):
``fields_read, fields_written, typed_equations`` set of read fields, set of written fields, ``fields_read, fields_written, typed_equations`` set of read fields, set of written fields,
list of equations where symbols have been replaced by typed symbols list of equations where symbols have been replaced by typed symbols
""" """
if isinstance(type_for_symbol, str) or not hasattr(type_for_symbol, '__getitem__'): if isinstance(type_for_symbol, (str, type)) or not hasattr(type_for_symbol, '__getitem__'):
type_for_symbol = typing_from_sympy_inspection(eqs, type_for_symbol) type_for_symbol = typing_from_sympy_inspection(eqs, type_for_symbol)
check = KernelConstraintsCheck(type_for_symbol, check_independence_condition) check = KernelConstraintsCheck(type_for_symbol, check_independence_condition)
...@@ -1093,6 +1096,10 @@ def typing_from_sympy_inspection(eqs, default_type="double", default_int_type='i ...@@ -1093,6 +1096,10 @@ def typing_from_sympy_inspection(eqs, default_type="double", default_int_type='i
dictionary, mapping symbol name to type dictionary, mapping symbol name to type
""" """
result = defaultdict(lambda: default_type) result = defaultdict(lambda: default_type)
if hasattr(default_type, 'numpy_dtype'):
result['_complex_type'] = (np.zeros((1,), default_type.numpy_dtype) * 1j).dtype
else:
result['_complex_type'] = (np.zeros((1,), default_type) * 1j).dtype
for eq in eqs: for eq in eqs:
if isinstance(eq, ast.Conditional): if isinstance(eq, ast.Conditional):
result.update(typing_from_sympy_inspection(eq.true_block.args)) result.update(typing_from_sympy_inspection(eq.true_block.args))
......
# -*- coding: utf-8 -*-
#
# Copyright © 2019 Stephan Seitz <stephan.seitz@fau.de>
#
# Distributed under terms of the GPLv3 license.
"""
"""
import itertools
import numpy as np
import pytest
import sympy
from sympy.functions import im, re
import pystencils
from pystencils import AssignmentCollection
from pystencils.data_types import TypedSymbol, create_type
# Single-precision fixtures: complex64 fields, float32 fields and untyped
# sympy symbols used by the assignment collections below.
X, Y = pystencils.fields('x, y: complex64[2d]')
A, B = pystencils.fields('a, b: float32[2d]')
S1, S2, T = sympy.symbols('S1, S2, T')
# Assignment collections exercised by test_complex_numbers.
TEST_ASSIGNMENTS = [
# Imaginary constant written to a complex field.
AssignmentCollection({X[0, 0]: 1j}),
# Real/imaginary parts stored in scalar temporaries, then recombined.
AssignmentCollection({
S1: re(Y.center),
S2: im(Y.center),
X[0, 0]: 2j * S1 + S2
}),
# Real/imaginary parts written to real-valued fields.
AssignmentCollection({
A.center: re(Y.center),
B.center: im(Y.center),
}),
# Mixed real, complex-field and complex-constant arithmetic.
AssignmentCollection({
Y.center: re(Y.center) + X.center + 2j,
}),
# Complex-valued temporary used as a divisor.
AssignmentCollection({
T: 2 + 4j,
Y.center: X.center / T,
})
]
# NOTE(review): not referenced by the parametrization of test_complex_numbers,
# which uses (np.float32,) directly - confirm whether this is still needed.
SCALAR_DTYPES = ['float32', 'float64']
@pytest.mark.parametrize("assignment, scalar_dtypes",
                         itertools.product(TEST_ASSIGNMENTS, (np.float32,)))
@pytest.mark.parametrize('target', ('cpu', 'gpu'))
def test_complex_numbers(assignment, scalar_dtypes, target):
    """Check that single-precision complex kernels generate and compile.

    For every assignment collection and target, a kernel is created with the
    given scalar data type, its generated code must not contain the marker
    "Not supported", and compiling it must return a kernel object.

    The 'gpu' variant is skipped (instead of failing) when pycuda cannot be
    imported.
    """
    if target == 'gpu':
        pytest.importorskip('pycuda')

    ast = pystencils.create_kernel(assignment,
                                   target=target,
                                   data_type=scalar_dtypes)
    code = str(pystencils.show_code(ast))
    print(code)
    assert "Not supported" not in code

    kernel = ast.compile()
    assert kernel is not None
# Double-precision fixtures. These rebind the module-level names X, Y, A, B,
# S1, S2, TEST_ASSIGNMENTS and SCALAR_DTYPES defined earlier in the file.
X, Y = pystencils.fields('x, y: complex128[2d]')
A, B = pystencils.fields('a, b: float64[2d]')
S1, S2 = sympy.symbols('S1, S2')
# Explicitly typed complex128 temporary (instead of an untyped sympy symbol).
T128 = TypedSymbol('ts', create_type('complex128'))
# Assignment collections exercised by test_complex_numbers_64.
TEST_ASSIGNMENTS = [
# Imaginary constant written to a complex field.
AssignmentCollection({X[0, 0]: 1j}),
# Real/imaginary parts stored in scalar temporaries, then recombined.
AssignmentCollection({
S1: re(Y.center),
S2: im(Y.center),
X[0, 0]: 2j * S1 + S2
}),
# Real/imaginary parts written to real-valued fields.
AssignmentCollection({
A.center: re(Y.center),
B.center: im(Y.center),
}),
# Mixed real, complex-field and complex-constant arithmetic.
AssignmentCollection({
Y.center: re(Y.center) + X.center + 2j,
}),
# Typed complex temporary used as a divisor.
AssignmentCollection({
T128: 2 + 4j,
Y.center: X.center / T128,
})
]
# NOTE(review): not referenced by test_complex_numbers_64, which passes
# data_type='double' directly - confirm whether this is still needed.
SCALAR_DTYPES = ['float64']
@pytest.mark.parametrize("assignment", TEST_ASSIGNMENTS)
@pytest.mark.parametrize('target', ('cpu', 'gpu'))
def test_complex_numbers_64(assignment, target):
    """Check that double-precision complex kernels generate and compile.

    For every assignment collection and target, a kernel is created with
    ``data_type='double'``, its generated code must not contain the marker
    "Not supported", and compiling it must return a kernel object.

    The 'gpu' variant is skipped (instead of failing) when pycuda cannot be
    imported.
    """
    if target == 'gpu':
        pytest.importorskip('pycuda')

    ast = pystencils.create_kernel(assignment,
                                   target=target,
                                   data_type='double')
    code = str(pystencils.show_code(ast))
    print(code)
    assert "Not supported" not in code

    kernel = ast.compile()
    assert kernel is not None
@pytest.mark.parametrize('dtype', (np.float32, np.float64))
@pytest.mark.parametrize('target', ('cpu', 'gpu'))
@pytest.mark.parametrize('with_complex_argument', ('with_complex_argument', False))
def test_complex_execution(dtype, target, with_complex_argument):
    """Run ``y = x + a`` on zero-initialised complex fields and check ``y == 1+2j``.

    ``a`` is either a complex-typed kernel parameter (truthy
    ``with_complex_argument``) whose value is passed at call time, or the
    literal ``2j + 1`` baked into the kernel.

    The 'gpu' variant is skipped (instead of failing) when pycuda cannot be
    imported.
    """
    shape = (20, 30)
    complex_dtype = f'complex{64 if dtype == np.float32 else 128}'
    x, y = pystencils.fields(f'x, y: {complex_dtype}[2d]')

    if with_complex_argument:
        a = pystencils.TypedSymbol('a', create_type(complex_dtype))
    else:
        a = (2j + 1)

    assignments = AssignmentCollection({
        y.center: x.center + a
    })

    # Allocate the arrays only once, on the device the kernel will run on.
    if target == 'gpu':
        gpuarray = pytest.importorskip('pycuda.gpuarray')
        x_arr = gpuarray.zeros(shape, complex_dtype)
        y_arr = gpuarray.zeros(shape, complex_dtype)
    else:
        x_arr = np.zeros(shape, complex_dtype)
        y_arr = np.zeros(shape, complex_dtype)

    kernel = pystencils.create_kernel(assignments, target=target, data_type=dtype).compile()

    if with_complex_argument:
        kernel(x=x_arr, y=y_arr, a=2j + 1)
    else:
        kernel(x=x_arr, y=y_arr)

    if target == 'gpu':
        # Copy the device result back to the host before comparing.
        y_arr = y_arr.get()
    assert np.allclose(y_arr, 2j + 1)
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Please register or to comment