From 3c53a568ca84ac6bbc7b2eb0cb9b2034ce8531fe Mon Sep 17 00:00:00 2001
From: Stephan Seitz <stephan.seitz@fau.de>
Date: Fri, 29 Nov 2019 14:42:38 +0100
Subject: [PATCH] Add csqrt, cpow to cuda_complex.hpp

---
 pystencils/include/cuda_complex.hpp | 15 +++++++++++++++
 1 file changed, 15 insertions(+)

diff --git a/pystencils/include/cuda_complex.hpp b/pystencils/include/cuda_complex.hpp
index ad555264..f3bdfee2 100644
--- a/pystencils/include/cuda_complex.hpp
+++ b/pystencils/include/cuda_complex.hpp
@@ -866,6 +866,14 @@ CUDA_CALLABLE_MEMBER complex<_Tp> sqrt(const complex<_Tp> &__x) {
   return polar(sqrt(abs(__x)), arg(__x) / _Tp(2));
 }
 
+template <class T> CUDA_CALLABLE_MEMBER complex<T> csqrt(complex<T> z) {
+  return sqrt<T>(z);
+};
+
+template<class T>
+CUDA_CALLABLE_MEMBER complex<T> csqrt(T z) {
+  return csqrt<T>({z, 0});
+};
 // exp
 
 template <class _Tp>
@@ -1224,5 +1232,12 @@ CUDA_CALLABLE_MEMBER auto operator/(const V scalar,
 
 using ComplexDouble = complex<double>;
 using ComplexFloat = complex<float>;
+
+CUDA_CALLABLE_MEMBER
+template <class T, class U> complex<T> cpow(const complex<T> &z, const U &n) {
+  return {(pow(abs(z), n) * cos(n * arg(z))),
+          (pow(abs(z), n) * sin(n * arg(z)))};
+}
+
 #endif // CUDA_COMPLEX_HPP
 }
-- 
GitLab