diff --git a/pystencils/include/cuda_complex.hpp b/pystencils/include/cuda_complex.hpp index ad555264a87881d8eaee6b2476c482039d606f71..f3bdfee2abf1e1bde14a23dccaf888b4524c3d08 100644 --- a/pystencils/include/cuda_complex.hpp +++ b/pystencils/include/cuda_complex.hpp @@ -866,6 +866,14 @@ CUDA_CALLABLE_MEMBER complex<_Tp> sqrt(const complex<_Tp> &__x) { return polar(sqrt(abs(__x)), arg(__x) / _Tp(2)); } +template <class T> CUDA_CALLABLE_MEMBER complex<T> csqrt(complex<T> z) { + return sqrt<T>(z); +}; + +template<class T> +CUDA_CALLABLE_MEMBER complex<T> csqrt(T z) { + return csqrt<T>({z, 0}); +}; // exp template <class _Tp> @@ -1224,5 +1232,12 @@ CUDA_CALLABLE_MEMBER auto operator/(const V scalar, using ComplexDouble = complex<double>; using ComplexFloat = complex<float>; + +CUDA_CALLABLE_MEMBER +template <class T, class U> complex<T> cpow(const complex<T> &z, const U &n) { + return {(pow(abs(z), n) * cos(n * arg(z))), + (pow(abs(z), n) * sin(n * arg(z)))}; +} + #endif // CUDA_COMPLEX_HPP }