//
// Copyright (C) 2011-15 DyND Developers
// BSD 2-Clause License, see LICENSE.txt
//

#pragma once

#include <complex>
#include <ostream>
#include <type_traits>

namespace dynd {

// NOTE(review): the #include targets of this header were unreadable in the
// reviewed copy. The project names used below are only forward-declared here
// (the std::common_type specializations just name them), which keeps the
// header self-contained. Confirm these are class types in the project (not
// typedefs), and restore any project include that dependents relied on
// receiving transitively through this header.
class bool1;
class int128;
class uint128;
class float128;

template <typename T>
struct is_complex;

// A complex number whose real and imaginary parts are public data members of
// type T.
template <typename T>
class complex {
public:
  T m_real, m_imag;
  typedef T value_type;

  // Constructs re + im*i. Not explicit, so it doubles as the implicit
  // conversion from T (with a zero imaginary part).
  complex(const T &re = 0.0, const T &im = 0.0) : m_real(re), m_imag(im) {}

  // Converts from a complex with a different component type by
  // static_cast-ing each component.
  template <typename U>
  complex(const complex<U> &rhs)
      : m_real(static_cast<T>(rhs.m_real)), m_imag(static_cast<T>(rhs.m_imag))
  {
  }

  // Converts from the standard library's complex type.
  complex(const std::complex<T> &rhs) : m_real(rhs.real()), m_imag(rhs.imag()) {}

  T real() const { return m_real; }

  T imag() const { return m_imag; }

  // Copy/move construction and assignment are the implicit memberwise ones.

  complex &operator+=(const complex &rhs)
  {
    m_real += rhs.m_real;
    m_imag += rhs.m_imag;
    return *this;
  }

  complex &operator+=(const T &rhs)
  {
    m_real += rhs;
    return *this;
  }

  complex &operator-=(const complex &rhs)
  {
    m_real -= rhs.m_real;
    m_imag -= rhs.m_imag;
    return *this;
  }

  complex &operator-=(const T &rhs)
  {
    m_real -= rhs;
    return *this;
  }

  // (a + bi)(c + di) = (ac - bd) + (ad + bc)i. Both parts are computed before
  // either member is written, so `z *= z` sees the original values.
  complex &operator*=(const complex &rhs)
  {
    const T re = m_real * rhs.m_real - m_imag * rhs.m_imag;
    const T im = m_real * rhs.m_imag + rhs.m_real * m_imag;
    m_real = re;
    m_imag = im;
    return *this;
  }

  complex &operator*=(const T &rhs)
  {
    m_real *= rhs;
    m_imag *= rhs;
    return *this;
  }

  // Textbook division: multiply by the conjugate of rhs and divide by |rhs|^2.
  // No scaling is applied, so very large or very small components can
  // overflow/underflow, and a zero divisor follows T's division semantics.
  // Both parts are computed before either member is written (alias-safe).
  complex &operator/=(const complex &rhs)
  {
    const T denom = rhs.m_real * rhs.m_real + rhs.m_imag * rhs.m_imag;
    const T re = (m_real * rhs.m_real + m_imag * rhs.m_imag) / denom;
    const T im = (rhs.m_real * m_imag - m_real * rhs.m_imag) / denom;
    m_real = re;
    m_imag = im;
    return *this;
  }

  complex &operator/=(const T &rhs)
  {
    m_real /= rhs;
    m_imag /= rhs;
    return *this;
  }

  // True when either component is nonzero.
  explicit operator bool() const { return m_real || m_imag; }

  // Returns the real part; the imaginary part is discarded.
  explicit operator T() const { return m_real; }

  // Explicit conversion to any other built-in arithmetic type (except bool,
  // which is handled by operator bool above): returns the real part cast to U.
  template <typename U,
            typename = typename std::enable_if<std::is_arithmetic<U>::value &&
                                               !std::is_same<U, bool>::value>::type>
  explicit operator U() const
  {
    return static_cast<U>(m_real);
  }

  operator std::complex<T>() const { return std::complex<T>(m_real, m_imag); }
};

template <typename T>
struct is_complex<complex<T>> : std::true_type {
};

} // namespace dynd

namespace std {

// common_type of a dynd complex with a scalar is always a dynd complex:
// boolean and integer operands keep the complex's component type T, while
// floating-point and complex operands promote the component type through
// common_type.

template <typename T>
struct common_type<dynd::complex<T>, bool> {
  typedef dynd::complex<T> type;
};

template <typename T>
struct common_type<dynd::complex<T>, dynd::bool1> {
  typedef dynd::complex<T> type;
};

template <typename T>
struct common_type<dynd::complex<T>, char> {
  typedef dynd::complex<T> type;
};

template <typename T>
struct common_type<dynd::complex<T>, signed char> {
  typedef dynd::complex<T> type;
};

template <typename T>
struct common_type<dynd::complex<T>, unsigned char> {
  typedef dynd::complex<T> type;
};

template <typename T>
struct common_type<dynd::complex<T>, short> {
  typedef dynd::complex<T> type;
};

template <typename T>
struct common_type<dynd::complex<T>, unsigned short> {
  typedef dynd::complex<T> type;
};

template <typename T>
struct common_type<dynd::complex<T>, int> {
  typedef dynd::complex<T> type;
};

template <typename T>
struct common_type<dynd::complex<T>, unsigned int> {
  typedef dynd::complex<T> type;
};

template <typename T>
struct common_type<dynd::complex<T>, long> {
  typedef dynd::complex<T> type;
};

template <typename T>
struct common_type<dynd::complex<T>, unsigned long> {
  typedef dynd::complex<T> type;
};

template <typename T>
struct common_type<dynd::complex<T>, long long> {
  typedef dynd::complex<T> type;
};

template <typename T>
struct common_type<dynd::complex<T>, unsigned long long> {
  typedef dynd::complex<T> type;
};

template <typename T>
struct common_type<dynd::complex<T>, dynd::int128> {
  typedef dynd::complex<T> type;
};

template <typename T>
struct common_type<dynd::complex<T>, dynd::uint128> {
  typedef dynd::complex<T> type;
};

template <typename T>
struct common_type<dynd::complex<T>, float> {
  typedef dynd::complex<typename common_type<T, float>::type> type;
};

template <typename T>
struct common_type<dynd::complex<T>, double> {
  typedef dynd::complex<typename common_type<T, double>::type> type;
};

template <typename T>
struct common_type<dynd::complex<T>, dynd::float128> {
  typedef dynd::complex<typename common_type<T, dynd::float128>::type> type;
};

template <typename T, typename U>
struct common_type<dynd::complex<T>, dynd::complex<U>> {
  typedef dynd::complex<typename common_type<T, U>::type> type;
};

// Operand order does not matter: common_type<X, complex<U>> is defined as
// common_type<complex<U>, X>.
template <typename T, typename U>
struct common_type<T, dynd::complex<U>> : common_type<dynd::complex<U>, T> {
};

} // namespace std

namespace dynd {

typedef complex<float> complex64;
typedef complex<double> complex128;

namespace detail {

// complex<common_type<T, U>>: the result type of a mixed-precision operation.
template <typename T, typename U>
using complex_common_t = complex<typename std::common_type<T, U>::type>;

// R when S is a built-in integer or floating-point type; otherwise a
// substitution failure, which removes the operator from overload resolution.
template <typename S, typename R>
using enable_if_real_scalar_t =
    typename std::enable_if<std::is_integral<S>::value || std::is_floating_point<S>::value, R>::type;

} // namespace detail

// ---- Equality ---------------------------------------------------------------

template <typename T>
bool operator==(complex<T> lhs, complex<T> rhs)
{
  return (lhs.m_real == rhs.m_real) && (lhs.m_imag == rhs.m_imag);
}

// A complex equals a real scalar when its real part matches and its imaginary
// part is zero.
template <typename T, typename U>
detail::enable_if_real_scalar_t<U, bool> operator==(complex<T> lhs, U rhs)
{
  return (lhs.m_real == rhs) && !lhs.m_imag;
}

template <typename T, typename U>
detail::enable_if_real_scalar_t<T, bool> operator==(T lhs, complex<U> rhs)
{
  return rhs == lhs;
}

template <typename T>
bool operator==(complex<T> lhs, std::complex<T> rhs)
{
  return (lhs.m_real == rhs.real()) && (lhs.m_imag == rhs.imag());
}

template <typename T>
bool operator==(std::complex<T> lhs, complex<T> rhs)
{
  return rhs == lhs;
}

// ---- Inequality: always the negation of the matching operator== ----------

template <typename T>
bool operator!=(complex<T> lhs, complex<T> rhs)
{
  return !(lhs == rhs);
}

template <typename T, typename U>
bool operator!=(complex<T> lhs, U rhs)
{
  return !(lhs == rhs);
}

template <typename T, typename U>
detail::enable_if_real_scalar_t<T, bool> operator!=(T lhs, complex<U> rhs)
{
  return !(lhs == rhs);
}

// ---- Unary operators --------------------------------------------------------

template <typename T>
complex<T> operator+(const complex<T> &rhs)
{
  return complex<T>(+rhs.m_real, +rhs.m_imag);
}

template <typename T>
complex<T> operator-(const complex<T> &rhs)
{
  return complex<T>(-rhs.m_real, -rhs.m_imag);
}

// Logical negation: true exactly when both components are zero.
template <typename T>
bool operator!(const complex<T> &rhs)
{
  return !rhs.m_real && !rhs.m_imag;
}

// ---- Addition ---------------------------------------------------------------

template <typename T>
complex<T> operator+(complex<T> lhs, complex<T> rhs)
{
  return lhs += rhs;
}

template <typename T, typename U>
detail::complex_common_t<T, U> operator+(complex<T> lhs, complex<U> rhs)
{
  typedef detail::complex_common_t<T, U> result_t;
  return static_cast<result_t>(lhs) + static_cast<result_t>(rhs);
}

template <typename T>
complex<T> operator+(complex<T> lhs, T rhs)
{
  return lhs += rhs;
}

template <typename T, typename U>
detail::enable_if_real_scalar_t<U, detail::complex_common_t<T, U>> operator+(complex<T> lhs, U rhs)
{
  typedef typename std::common_type<T, U>::type common_t;
  return static_cast<complex<common_t>>(lhs) + static_cast<common_t>(rhs);
}

template <typename T>
complex<T> operator+(T lhs, complex<T> rhs)
{
  return complex<T>(lhs + rhs.m_real, rhs.m_imag);
}

template <typename T, typename U>
detail::enable_if_real_scalar_t<T, detail::complex_common_t<T, U>> operator+(T lhs, complex<U> rhs)
{
  typedef typename std::common_type<T, U>::type common_t;
  return static_cast<common_t>(lhs) + static_cast<complex<common_t>>(rhs);
}

// ---- Subtraction ------------------------------------------------------------

template <typename T>
complex<T> operator-(complex<T> lhs, complex<T> rhs)
{
  return lhs -= rhs;
}

template <typename T, typename U>
detail::complex_common_t<T, U> operator-(complex<T> lhs, complex<U> rhs)
{
  typedef detail::complex_common_t<T, U> result_t;
  return static_cast<result_t>(lhs) - static_cast<result_t>(rhs);
}

template <typename T>
complex<T> operator-(complex<T> lhs, T rhs)
{
  return lhs -= rhs;
}

template <typename T, typename U>
detail::enable_if_real_scalar_t<U, detail::complex_common_t<T, U>> operator-(complex<T> lhs, U rhs)
{
  typedef typename std::common_type<T, U>::type common_t;
  return static_cast<complex<common_t>>(lhs) - static_cast<common_t>(rhs);
}

template <typename T>
complex<T> operator-(T lhs, complex<T> rhs)
{
  return complex<T>(lhs - rhs.m_real, -rhs.m_imag);
}

template <typename T, typename U>
detail::enable_if_real_scalar_t<T, detail::complex_common_t<T, U>> operator-(T lhs, complex<U> rhs)
{
  typedef typename std::common_type<T, U>::type common_t;
  return static_cast<common_t>(lhs) - static_cast<complex<common_t>>(rhs);
}

// ---- Multiplication ---------------------------------------------------------

template <typename T>
complex<T> operator*(complex<T> lhs, complex<T> rhs)
{
  return lhs *= rhs;
}

template <typename T, typename U>
detail::complex_common_t<T, U> operator*(complex<T> lhs, complex<U> rhs)
{
  typedef detail::complex_common_t<T, U> result_t;
  return static_cast<result_t>(lhs) * static_cast<result_t>(rhs);
}

template <typename T>
complex<T> operator*(complex<T> lhs, T rhs)
{
  return lhs *= rhs;
}

template <typename T, typename U>
detail::enable_if_real_scalar_t<U, detail::complex_common_t<T, U>> operator*(complex<T> lhs, U rhs)
{
  typedef typename std::common_type<T, U>::type common_t;
  return static_cast<complex<common_t>>(lhs) * static_cast<common_t>(rhs);
}

template <typename T>
complex<T> operator*(T lhs, complex<T> rhs)
{
  return complex<T>(lhs * rhs.m_real, lhs * rhs.m_imag);
}

template <typename T, typename U>
detail::enable_if_real_scalar_t<T, detail::complex_common_t<T, U>> operator*(T lhs, complex<U> rhs)
{
  typedef typename std::common_type<T, U>::type common_t;
  return static_cast<common_t>(lhs) * static_cast<complex<common_t>>(rhs);
}

// ---- Division ---------------------------------------------------------------

template <typename T>
complex<T> operator/(complex<T> lhs, complex<T> rhs)
{
  return lhs /= rhs;
}

template <typename T, typename U>
detail::complex_common_t<T, U> operator/(complex<T> lhs, complex<U> rhs)
{
  typedef detail::complex_common_t<T, U> result_t;
  return static_cast<result_t>(lhs) / static_cast<result_t>(rhs);
}

template <typename T>
complex<T> operator/(complex<T> lhs, T rhs)
{
  return lhs /= rhs;
}

// Mixed-type complex / scalar, matching the +, - and * overload sets.
template <typename T, typename U>
detail::enable_if_real_scalar_t<U, detail::complex_common_t<T, U>> operator/(complex<T> lhs, U rhs)
{
  typedef typename std::common_type<T, U>::type common_t;
  return static_cast<complex<common_t>>(lhs) / static_cast<common_t>(rhs);
}

// Real / complex: lhs * conj(rhs) / |rhs|^2 (no scaling applied).
template <typename T>
complex<T> operator/(T lhs, complex<T> rhs)
{
  const T denom = rhs.m_real * rhs.m_real + rhs.m_imag * rhs.m_imag;
  return complex<T>(lhs * rhs.m_real / denom, -lhs * rhs.m_imag / denom);
}

// Mixed-type scalar / complex, matching the +, - and * overload sets.
template <typename T, typename U>
detail::enable_if_real_scalar_t<T, detail::complex_common_t<T, U>> operator/(T lhs, complex<U> rhs)
{
  typedef typename std::common_type<T, U>::type common_t;
  return static_cast<common_t>(lhs) / static_cast<complex<common_t>>(rhs);
}

// ---- Output and helpers -----------------------------------------------------

// Writes the value as "(re + imj)".
template <typename T>
std::ostream &operator<<(std::ostream &out, const complex<T> &val)
{
  return (out << "(" << val.m_real << " + " << val.m_imag << "j)");
}

// The imaginary unit, complex<T>(0, 1); specialized for float and double.
template <typename T>
complex<T> _i();

template <>
inline complex<float> _i<float>()
{
  return complex<float>(0.0f, 1.0f);
}

template <>
inline complex<double> _i<double>()
{
  return complex<double>(0.0, 1.0);
}

// Complex conjugate: same real part, negated imaginary part.
template <typename T>
complex<T> conj(const complex<T> &z)
{
  return complex<T>(z.real(), -z.imag());
}

} // namespace dynd