#ifndef __JROTATION3D__
#define __JROTATION3D__

#include <cmath>

#include "JGeometry3D/JMatrix3D.hh"
#include "JGeometry3D/JAngle3D.hh"
#include "JGeometry3D/JVersor3D.hh"
#include "JGeometry3D/JVersor3Z.hh"
#include "JGeometry3D/JQuaternion3D.hh"
#include "JGeometry2D/JRotation2D.hh"
#include "JMath/JMath.hh"


/**
 * \author mdejong
 */

namespace JGEOMETRY3D {}
namespace JPP { using namespace JGEOMETRY3D; }

namespace JGEOMETRY3D {

  using JGEOMETRY2D::JRotation2D;
  using JMATH::JMath;
  

  /**
   * Rotation around the X-axis.
   *
   * The rotation is held as a 2D rotation; JRotation3D embeds it
   * in the (y, z) sub-matrix and leaves the x-component untouched.
   */
  class JRotation3X : public JRotation2D
  {
  public:

    /**
     * Default constructor.
     *
     * Creates the identity rotation (= identity matrix).
     */
    JRotation3X() : JRotation2D() {}


    /**
     * Constructor.
     *
     * \param  phi        rotation angle around the X-axis, anti-clockwise [rad]
     */
    JRotation3X(const double phi) : JRotation2D(phi) {}
  };


  /**
   * Rotation around the Y-axis.
   *
   * The rotation is held as a 2D rotation; JRotation3D embeds it
   * in the (x, z) sub-matrix and leaves the y-component untouched.
   */
  class JRotation3Y : public JRotation2D
  {
  public:

    /**
     * Default constructor.
     *
     * Creates the identity rotation (= identity matrix).
     */
    JRotation3Y() : JRotation2D() {}


    /**
     * Constructor.
     *
     * An anti-clockwise rotation around the Y-axis corresponds to a clockwise
     * rotation in the (x, z) plane, hence the 2D rotation is built with -phi.
     *
     * \param  phi        rotation angle around the Y-axis, anti-clockwise [rad]
     */
    JRotation3Y(const double phi) : JRotation2D(-phi) {}
  };


  /**
   * Rotation around the Z-axis.
   *
   * The rotation is held as a 2D rotation; JRotation3D embeds it
   * in the (x, y) sub-matrix and leaves the z-component untouched.
   */
  class JRotation3Z : public JRotation2D
  {
  public:

    /**
     * Default constructor.
     *
     * Creates the identity rotation (= identity matrix).
     */
    JRotation3Z() : JRotation2D() {}


    /**
     * Constructor.
     *
     * \param  phi        rotation angle around the Z-axis, anti-clockwise [rad]
     */
    JRotation3Z(const double phi) : JRotation2D(phi) {}
  };


  /**
   * Rotation matrix.
   *
   * The matrix elements a00 .. a22 are those of the JMatrix3D base class.
   */
  class JRotation3D :
    public JMatrix3D,
    public JMath<JRotation3D>
  {
  public:

    using JMath<JRotation3D>::mul;

    /**
     * Default constructor (= identity matrix).
     */
    JRotation3D() :
      JMatrix3D()
    {
      setIdentity();
    }


    /**
     * Constructor.
     *
     * The matrix is defined such that the rotation of a vector in the given direction ends up along the z-axis
     * and the back rotation of a vector parallel to the z-axis ends up in the given direction.
     * 
     * \param  dir        direction
     */
    JRotation3D(const JAngle3D& dir) :
      JMatrix3D()
    {
      const double ct = cos(dir.getTheta());
      const double st = sin(dir.getTheta());
      const double cp = cos(dir.getPhi());
      const double sp = sin(dir.getPhi());

      a00 =  ct*cp;  a01 =  ct*sp;  a02 = -st;
      a10 = -sp;     a11 =  cp;     a12 = 0.0;
      a20 =  st*cp;  a21 =  st*sp;  a22 = +ct;
    }


    /**
     * Constructor.
     *
     * The matrix is defined such that the rotation of a vector in the given direction ends up along the z-axis
     * and the back rotation of a vector parallel to the z-axis ends up in the given direction.
     * 
     * \param  dir        direction
     */
    JRotation3D(const JVersor3D& dir) :
      JMatrix3D()
    {
      const double ct  = dir.getDZ();
      const double st2 = (1.0 + ct)*(1.0 - ct);
      // Rounding may leave |dz| marginally above 1; sqrt of the (then negative)
      // argument would yield NaN and poison the whole matrix.
      const double st  = (st2 > 0.0 ? sqrt(st2) : 0.0);
      const double phi = atan2(dir.getDY(), dir.getDX());
      const double cp  = cos(phi);
      const double sp  = sin(phi);

      a00 =  ct*cp;  a01 =  ct*sp;  a02 = -st;
      a10 = -sp;     a11 =  cp;     a12 = 0.0;
      a20 =  st*cp;  a21 =  st*sp;  a22 = +ct;
    }


    /**
     * Constructor.
     *
     * The matrix is defined such that the rotation of a vector in the given direction ends up along the z-axis
     * and the back rotation of a vector parallel to the z-axis ends up in the given direction.
     * 
     * \param  dir        direction
     */
    JRotation3D(const JVersor3Z& dir) :
      JMatrix3D()
    {
      const double ct  = dir.getDZ();
      const double st2 = (1.0 + ct)*(1.0 - ct);
      // Rounding may leave |dz| marginally above 1; sqrt of the (then negative)
      // argument would yield NaN and poison the whole matrix.
      const double st  = (st2 > 0.0 ? sqrt(st2) : 0.0);
      const double phi = atan2(dir.getDY(), dir.getDX());
      const double cp  = cos(phi);
      const double sp  = sin(phi);

      a00 =  ct*cp;  a01 =  ct*sp;  a02 = -st;
      a10 = -sp;     a11 =  cp;     a12 = 0.0;
      a20 =  st*cp;  a21 =  st*sp;  a22 = +ct;
    }


    /**
     * Constructor.
     *
     * The 2D rotation is placed in the (y, z) sub-matrix; the x-component is unchanged.
     * 
     * \param  R          2D rotation matrix around X-axis
     */
    JRotation3D(const JRotation3X& R) :
      JMatrix3D()
    {
      a00 =  1.0;   a01 =  0.0;     a02 =  0.0;
      a10 =  0.0;   a11 =  R.a00;   a12 =  R.a01;
      a20 =  0.0;   a21 =  R.a10;   a22 =  R.a11;
    }


    /**
     * Constructor.
     *
     * The 2D rotation is placed in the (x, z) sub-matrix; the y-component is unchanged.
     * 
     * \param  R          2D rotation matrix around Y-axis
     */
    JRotation3D(const JRotation3Y& R) :
      JMatrix3D()
    {
      a00 =  R.a00;  a01 =  0.0;    a02 =  R.a01;
      a10 =  0.0;    a11 =  1.0;    a12 =  0.0;
      a20 =  R.a10;  a21 =  0.0;    a22 =  R.a11;
    }


    /**
     * Constructor.
     *
     * The 2D rotation is placed in the (x, y) sub-matrix; the z-component is unchanged.
     * 
     * \param  R          2D rotation matrix around Z-axis
     */
    JRotation3D(const JRotation3Z& R) :
      JMatrix3D()
    {
      a00 =  R.a00;  a01 =  R.a01;  a02 =  0.0;
      a10 =  R.a10;  a11 =  R.a11;  a12 =  0.0;
      a20 =  0.0;    a21 =  0.0;    a22 =  1.0;
    }


    /**
     * Constructor.
     *
     * Builds the rotation matrix corresponding to quaternion Q = (a, b, c, d).
     * The quaternion is not normalised here; presumably Q is expected to be
     * a unit quaternion (otherwise the matrix is scaled by |Q|^2) -- TODO confirm.
     * 
     * \param  Q          quaternion
     */
    JRotation3D(const JQuaternion3D& Q) :
      JMatrix3D()
    {
      const double a2 = Q.getA()*Q.getA();
      const double b2 = Q.getB()*Q.getB();
      const double c2 = Q.getC()*Q.getC();
      const double d2 = Q.getD()*Q.getD();

      const double ab = Q.getA()*Q.getB();
      const double ac = Q.getA()*Q.getC();
      const double ad = Q.getA()*Q.getD();

      const double bc = Q.getB()*Q.getC();
      const double bd = Q.getB()*Q.getD();

      const double cd = Q.getC()*Q.getD();

      a00 =  a2 + b2 - c2 - d2;  a01 =   2.0*bc - 2.0*ad;   a02 =   2.0*bd + 2.0*ac;
      a10 =   2.0*bc + 2.0*ad;   a11 =  a2 - b2 + c2 - d2;  a12 =   2.0*cd - 2.0*ab;
      a20 =   2.0*bd - 2.0*ac;   a21 =   2.0*cd + 2.0*ab;   a22 =  a2 - b2 - c2 + d2;
    }


    /**
     * Get rotation.
     *
     * \return            rotation
     */
    const JRotation3D& getRotation() const
    {
      return static_cast<const JRotation3D&>(*this);
    }


    /**
     * Type conversion operator.
     *
     * Extracts the quaternion (a, b, c, d) of this rotation, i.e. the inverse of
     * JRotation3D(const JQuaternion3D&).  For a matrix built from a unit quaternion:
     *
     *   1 + a00 + a11 + a22 = 4 a^2,    1 + a00 - a11 - a22 = 4 b^2,
     *   1 - a00 + a11 - a22 = 4 c^2,    1 - a00 - a11 + a22 = 4 d^2.
     *
     * These four values always sum to 4, so the largest is at least 1.  Using the
     * largest one as pivot keeps the division well-conditioned for every rotation
     * angle -- including rotations by pi, where 1 + trace vanishes and a trace-only
     * formula would break down.  The overall sign is chosen such that a >= 0.
     * If the matrix contains non-finite elements such that no pivot can be
     * selected, the identity quaternion is returned.
     *
     * \return               quaternion
     */
    operator JQuaternion3D() const
    {
      const double qa = 1.0 + a00 + a11 + a22;   // 4 a^2
      const double qb = 1.0 + a00 - a11 - a22;   // 4 b^2
      const double qc = 1.0 - a00 + a11 - a22;   // 4 c^2
      const double qd = 1.0 - a00 - a11 + a22;   // 4 d^2

      double a, b, c, d;

      if (qa >= qb && qa >= qc && qa >= qd) {

        a = sqrt(0.25 * qa);

        const double w = 0.25 / a;

        b = (a21 - a12) * w;
        c = (a02 - a20) * w;
        d = (a10 - a01) * w;

      } else if (qb >= qc && qb >= qd) {

        b = sqrt(0.25 * qb);

        const double w = 0.25 / b;

        a = (a21 - a12) * w;
        c = (a01 + a10) * w;
        d = (a02 + a20) * w;

      } else if (qc >= qd) {

        c = sqrt(0.25 * qc);

        const double w = 0.25 / c;

        a = (a02 - a20) * w;
        b = (a01 + a10) * w;
        d = (a12 + a21) * w;

      } else if (qd > 0.0) {

        d = sqrt(0.25 * qd);

        const double w = 0.25 / d;

        a = (a10 - a01) * w;
        b = (a02 + a20) * w;
        c = (a12 + a21) * w;

      } else {

        // Only reachable when all comparisons fail, i.e. for NaN matrix elements.
        return JQuaternion3D(1.0, 0.0, 0.0, 0.0);
      }

      // q and -q describe the same rotation; keep the convention a >= 0.
      if (a < 0.0) {
        a = -a;
        b = -b;
        c = -c;
        d = -d;
      }

      return JQuaternion3D(a,b,c,d).normalise();
    }

    
    /**
     * Transpose.
     *
     * For a rotation matrix the transpose is the inverse rotation.
     *
     * \return          this matrix
     */
    JRotation3D& transpose()
    {
      static_cast<JMatrix3D&>(*this).transpose();

      return *this;
    }

    
    /**
     * Matrix multiplication.
     *
     * The product is computed by JMatrix3D::mul(A, B) and stored in this matrix.
     *
     * \param  A        matrix
     * \param  B        matrix
     * \return          this matrix
     */
    JRotation3D& mul(const JRotation3D& A,
                     const JRotation3D& B)
    {
      static_cast<JMatrix3D&>(*this).mul(A, B);

      return *this;
    }      

    
    /**
     * Rotate.
     *
     * Replaces (x, y, z) in place by this matrix times (x, y, z).
     *
     * \param  __x      x value
     * \param  __y      y value
     * \param  __z      z value
     */
    void rotate(double& __x, double& __y, double& __z) const
    {
      const double x = a00 * __x  +  a01 * __y  +  a02 * __z;
      const double y = a10 * __x  +  a11 * __y  +  a12 * __z;
      const double z = a20 * __x  +  a21 * __y  +  a22 * __z;

      __x = x;
      __y = y;
      __z = z;
    }

    
    /**
     * Rotate back.
     *
     * Replaces (x, y, z) in place by the transpose of this matrix times (x, y, z),
     * which undoes rotate() for an orthogonal matrix.
     *
     * \param  __x      x value
     * \param  __y      y value
     * \param  __z      z value
     */
    void rotate_back(double& __x, double& __y, double& __z) const
    {
      const double x = a00 * __x  +  a10 * __y  +  a20 * __z;
      const double y = a01 * __x  +  a11 * __y  +  a21 * __z;
      const double z = a02 * __x  +  a12 * __y  +  a22 * __z;

      __x = x;
      __y = y;
      __z = z;
    }
  };
}

#endif