﻿#pragma once

#include "tuple4.h"
#include "vector3.h"
#include "matrix4.h"

#include "math_util.h"
#include "math_exception.h"


namespace lm
{

template<class T>
class quat4 : public tuple4<T>
{
public:
	quat4(const quat4<T>& _q)                                 : tuple4<T>( _q.x , _q.y , _q.z , _q.w ) {}
	quat4(const T* _q)                                        : tuple4<T>( _q ) {}
	quat4(const T& _x, const T& _y, const T& _z, const T& _w) : tuple4<T>( _x , _y , _z , _w ) {}
	quat4(const vector3<T>& _v , const T& _w )                : tuple4<T>( _v.x , _v.y , _v.z , _w ) {}
	quat4(const vector3<T>& _v )                              : tuple4<T>( _v.x , _v.y , _v.z , T(1) ) {}
	// 値指定無しの場合は単位クォータニオンで初期化する
	quat4(void)                                               : tuple4<T>( T(0) , T(0) , T(0) , T(1) ) {}


	//! 要素型の変換
	template<typename U>
	operator quat4<U>(void) const
	{
		return quat4<U>( static_cast<U>(x) , static_cast<U>(y) , static_cast<U>(z) , static_cast<U>(w) );
	}


	//! ノムルを取得する
	T norm(void) const;
	//! ノムルの二乗を取得する
	T square_norm(void) const;

	// 虚数成分をvector3形式に変換して取得する
	vector3<T> get_vector(void) const;
	void get_vector( vector3<T>& o_vector ) const;

	// 単位クォータニオン
	void identity(void);
	static const quat4<T>& get_identity(void);

	//! 共役数化
	void conjugate(void);
	//! 共役数を取得
	quat4<T> get_conjugate(void) const;

	//! 逆クォータニオン化
	void inverse(void);
	//! 逆クォータニオンを取得
	quat4<T> get_inverse(void) const;

	//! 正規化
	void normalize(void);
	//! 正規化した値を取得
	quat4<T> get_normalize(void) const;

	quat4<T>& mult(const quat4<T>& _q);
	quat4<T>& operator*=(const quat4<T>& _q);
	//friend quat4<T> operator*(const quat4<T>& _a, const quat4<T>& _b);


	// 現在のクォータニオンが表す位置を、指定した軸、角度で回転
	void rotate_position( const vector3<T>& i_axis , T angle );
	// 現在のクォータニオンが表す位置を、指定した回転クォータニオンで回転
	void rotate_position( const quat4<T>& i_quat );
	// マイナス回転
	void rotate_position_rev( const vector3<T>& i_axis , T angle );
	void rotate_position_rev( const quat4<T>& i_quat );

	// 回転クォータニオンをベクトルに適用する
	void vec_rotate( vector3<T>& o_vec ) const;
	// マイナス回転
	void vec_rotate_rev( vector3<T>& o_vec ) const;


	// 回転クォータニオンの取得
	static quat4<T> get_rotate( T axis_x , T axis_y , T axis_z , T angle );
	static quat4<T> get_rotate( const vector3<T>& i_axis , T angle );

	// 回転クォータニオン化
	void set_rotate( T axis_x , T axis_y , T axis_z , T angle );
	void set_rotate( const vector3<T>& i_axis , T angle );

	// 回転クォータニオンを積算
	void mult_rotate( T axis_x , T axis_y , T axis_z , T angle );
	void mult_rotate( const vector3<T>& i_axis , T angle );
	void mult_rotate_x( T angle );
	void mult_rotate_y( T angle );
	void mult_rotate_z( T angle );

	// 回転クォータニオンを逆から積算
	void mult_rotate_rev( T axis_x , T axis_y , T axis_z , T angle );
	void mult_rotate_rev( const vector3<T>& i_axis , T angle );
	void mult_rotate_x_rev( T angle );
	void mult_rotate_y_rev( T angle );
	void mult_rotate_z_rev( T angle );
	
	// 基底ベクトルから回転クォータニオンをセットする
	void set_rotate_from_coord( const vector3<T>& i_ex , const vector3<T>& i_ey , const vector3<T>& i_ez );


	// 回転クォータニオンが表す現在の姿勢の基底ベクトルを求める
	vector3<T> get_ex(void) const;
	void       get_ex( vector3<T>& o_ex ) const;
	void       get_ex( T* o_ex ) const;
	vector3<T> get_ey(void) const;
	void       get_ey( vector3<T>& o_ey ) const;
	void       get_ey( T* o_ey ) const;
	vector3<T> get_ez(void) const;
	void       get_ez( vector3<T>& o_ez ) const;
	void       get_ez( T* o_ez ) const;
	void get_bases( vector3<T>& o_ex , vector3<T>& o_ey , vector3<T>& o_ez ) const;

	// 回転クォータニオンを直交変換行列に変換する
	matrix4<T> get_ortho_matrix(void) const;
	void get_ortho_matrix( matrix4<T>& ortho_matrix ) const;

	// 回転クォータニオンの成分を求める
	void get_rotate_params( vector3<T>& o_axis , T& angular_velocity ) const;

	// 補間
	void interpolate( const lm::quat4<T>& q , T alpha );
	static quat4<T> interpolate( const lm::quat4<T>& q0 , const lm::quat4<T>& q1 , T alpha );

	// 2つのクォータニオンが示す座標間の角度差が最小になるように
	// 回転方向を補正する.
	static void minimize_rotation(const lm::quat4<T>& quat_src, lm::quat4<T>& quat_dst);
};

typedef quat4<float > quat4f;
typedef quat4<double> quat4d;


// global method

template<typename T> inline
quat4<T> operator*( const quat4<T>& _a , const quat4<T>& _b )
{
	quat4<T> ws = _a;
	ws *= _b;
	return ws;
}


// implement

template<typename T> inline
T quat4<T>::norm(void) const
{
	return sqrt( square_norm() );
}

template<typename T> inline
T quat4<T>::square_norm(void) const
{
	return ( x * x + y * y + z * z + w * w );
}


template<typename T> inline
vector3<T> quat4<T>::get_vector(void) const
{
	return vector3<T>( x , y , z );
}

template<typename T> inline
void quat4<T>::get_vector( vector3<T>& o_vector ) const
{
	o_vector.set( x , y , z );
}


template<typename T> inline
void quat4<T>::identity(void)
{
	this->set( T(0) , T(0) , T(0) , T(1) );
}

template<typename T> inline
const quat4<T>& quat4<T>::get_identity(void)
{
	static quat4<T> q( T(0) , T(0) , T(0) , T(1) );
	return q;
}


template<typename T> inline
void quat4<T>::conjugate(void)
{
	this->set( -x , -y , -z , w );
}

template<typename T> inline
quat4<T> quat4<T>::get_conjugate(void) const
{
	quat4<T> q = (*this);
	q.conjugate();
	return q;
}


template<typename T> inline
void quat4<T>::inverse(void)
{
	T sn = square_norm();
	this->set( -x / sn , -y / sn , -z / sn , w / sn );
}

template<typename T> inline
quat4<T> quat4<T>::get_inverse(void) const
{
	quat4<T> q = (*this);
	q.inverse();
	return q;
}


template<typename T> inline
void quat4<T>::normalize(void)
{
	T n = norm();
	this->set( x / n , y / n , z / n , w / n );
}

template<typename T> inline
quat4<T> quat4<T>::get_normalize(void) const
{
	quat4<T> q = (*this);
	q.normalize();
	return q;
}


template<typename T> inline
quat4<T>& quat4<T>::mult(const quat4<T>& _q)
{
	quat4<T> wq = *this;
	w = ( wq.w * _q.w ) - ( wq.x * _q.x + wq.y * _q.y + wq.z * _q.z );
	x = ( wq.w * _q.x ) + ( wq.x * _q.w ) - ( wq.y * _q.z - wq.z * _q.y );
	y = ( wq.w * _q.y ) + ( wq.y * _q.w ) - ( wq.z * _q.x - wq.x * _q.z );
	z = ( wq.w * _q.z ) + ( wq.z * _q.w ) - ( wq.x * _q.y - wq.y * _q.x );
	return *this;
}

template<typename T> inline
quat4<T>& quat4<T>::operator*=(const quat4<T>& _q)
{
	mult( _q );
	return *this;
}


template<typename T> inline
void quat4<T>::rotate_position( const vector3<T>& i_axis , T angle )
{
	quat4<T> rq;
	rq.set_rotate( i_axis , angle );
	rotate_position( rq );
}

template<typename T> inline
void quat4<T>::rotate_position( const quat4<T>& i_quat )
{
	quat4<T> wq = i_quat;
	wq.conjugate();
	wq *= (*this);
	wq *= i_quat;

	*this = wq;
}

template<typename T> inline
void quat4<T>::rotate_position_rev( const vector3<T>& i_axis , T angle )
{
	quat4<T> rq;
	rq.set_rotate( i_axis , angle );
	rotate_position_rev( rq );
}

template<typename T> inline
void quat4<T>::rotate_position_rev( const quat4<T>& i_quat )
{
	quat4<T> c_quat  = i_quat;
	c_quat.conjugate();
	quat4<T> this_tmp = *this;

	*this  = i_quat;
	*this *= this_tmp;
	*this *= c_quat;
}

template<typename T> inline
void quat4<T>::vec_rotate( vector3<T>& o_vec ) const
{
	quat4<T> wq( o_vec.x , o_vec.y , o_vec.z , T(1.0) );
	wq.rotate_position( *this );
	o_vec.set( wq.x , wq.y , wq.z );
}

template<typename T> inline
void quat4<T>::vec_rotate_rev( vector3<T>& o_vec ) const
{
	quat4<T> wq( o_vec.x , o_vec.y , o_vec.z , T(1.0) );
	wq.rotate_position_rev( *this );
	o_vec.set( wq.x , wq.y , wq.z );
}

template<typename T> inline
quat4<T> quat4<T>::get_rotate( T axis_x , T axis_y , T axis_z , T angle )
{
	T c = cos( angle * 0.5f ) ;
	T s = sin( angle * 0.5f ) ;
	return quat4<T>( axis_x * s , axis_y * s , axis_z * s , c ) ;
}

template<typename T> inline
quat4<T> quat4<T>::get_rotate( const vector3<T>& i_axis , T angle )
{
	T c = cos( angle * 0.5f ) ;
	T s = sin( angle * 0.5f ) ;
	return quat4<T>( i_axis.x * s , i_axis.y * s , i_axis.z * s , c ) ;
}

template<typename T> inline
void quat4<T>::set_rotate( T axis_x , T axis_y , T axis_z , T angle )
{
	T c = cos( angle * 0.5f ) ;
	T s = sin( angle * 0.5f ) ;
	this->set( axis_x * s , axis_y * s , axis_z * s , c ) ;
}

template<typename T> inline
void quat4<T>::set_rotate( const vector3<T>& i_axis , T angle )
{
	set_rotate( i_axis.x , i_axis.y , i_axis.z , angle );
}

template<typename T> inline
void quat4<T>::mult_rotate( T axis_x , T axis_y , T axis_z , T angle )
{
	(*this) *= get_rotate( axis_x , axis_y , axis_z , angle );
}

template<typename T> inline
void quat4<T>::mult_rotate( const vector3<T>& i_axis , T angle )
{
	mult_rotate( i_axis.x , i_axis.y , i_axis.z , angle );
}

template<typename T> inline
void quat4<T>::mult_rotate_x( T angle )
{
	mult_rotate( T(1) , T(0) , T(0) , angle );
}

template<typename T> inline
void quat4<T>::mult_rotate_y( T angle )
{
	mult_rotate( T(0) , T(1) , T(0) , angle );
}

template<typename T> inline
void quat4<T>::mult_rotate_z( T angle )
{
	mult_rotate( T(0) , T(0) , T(1) , angle );
}

template<typename T> inline
void quat4<T>::mult_rotate_rev( T axis_x , T axis_y , T axis_z , T angle )
{
	quat4<T> q;
	q.set_rotate( axis_x , axis_y , axis_z , angle );
	q *= (*this);
	(*this) = q;
}

template<typename T> inline
void quat4<T>::mult_rotate_rev( const vector3<T>& i_axis , T angle )
{
	mult_rotate_rev( i_axis.x , i_axis.y , i_axis.z , angle );
}

template<typename T> inline
void quat4<T>::mult_rotate_x_rev( T angle )
{
	mult_rotate_rev( T(1) , T(0) , T(0) , angle );
}

template<typename T> inline
void quat4<T>::mult_rotate_y_rev( T angle )
{
	mult_rotate_rev( T(0) , T(1) , T(0) , angle );
}

template<typename T> inline
void quat4<T>::mult_rotate_z_rev( T angle )
{
	mult_rotate_rev( T(0) , T(0) , T(1) , angle );
}

template<typename T> inline
void quat4<T>::set_rotate_from_coord( const vector3<T>& i_ex , const vector3<T>& i_ey , const vector3<T>& i_ez )
{
	// http://marupeke296.com/DXG_No58_RotQuaternionTrans.html

	// 最大成分を検索
	T elem[ 4 ]; // 0:x, 1:y, 2:z, 3:w
	elem[ 0 ] = i_ex.x - i_ey.y - i_ez.z + T(1.0);
	elem[ 1 ] = -i_ex.x + i_ey.y - i_ez.z + T(1.0);
	elem[ 2 ] = -i_ex.x - i_ey.y + i_ez.z + T(1.0);
	elem[ 3 ] = i_ex.x + i_ey.y + i_ez.z + T(1.0);

	int biggestIndex = 0;
	for ( int i = 1; i < 4; i++ )
	{
		if ( elem[i] > elem[biggestIndex] )
			biggestIndex = i;
	}

	// 最大要素の値を算出
	T v = sqrt( elem[biggestIndex] ) * T(0.5);
	at(biggestIndex) = v;
	T mult = T(0.25) / v;

	switch ( biggestIndex )
	{
	case 0: // x
		y = (i_ex.y + i_ey.x) * mult;
		z = (i_ez.x + i_ex.z) * mult;
		w = (i_ey.z - i_ez.y) * mult;
		break;
	case 1: // y
		x = (i_ex.y + i_ey.x) * mult;
		z = (i_ey.z + i_ez.y) * mult;
		w = (i_ez.x - i_ex.z) * mult;
		break;
	case 2: // z
		x = (i_ez.x + i_ex.z) * mult;
		y = (i_ey.z + i_ez.y) * mult;
		w = (i_ex.y - i_ey.x) * mult;
		break;
	case 3: // w
		x = (i_ey.z - i_ez.y) * mult;
		y = (i_ez.x - i_ex.z) * mult;
		z = (i_ex.y - i_ey.x) * mult;
		break;
	}

	this->normalize();
}


template<typename T> inline
vector3<T> quat4<T>::get_ex(void) const
{
	vector3<T> v;
	get_ex(v);
	return v;
}

template<typename T> inline
void quat4<T>::get_ex( vector3<T>& o_ex ) const
{
	o_ex.x = T(1) - T(2) * ( y * y + z * z );
	o_ex.y = T(2) * ( x * y + z * w );
	o_ex.z = T(2) * ( z * x - w * y );
	o_ex.normalize();
}

template<typename T> inline
void quat4<T>::get_ex( T* o_ex ) const
{
	vector3<T> v;
	get_ex(v);
	o_ex[0] = v.x;
	o_ex[1] = v.y;
	o_ex[2] = v.z;
}

template<typename T> inline
vector3<T> quat4<T>::get_ey(void) const
{
	vector3<T> v;
	get_ey(v);
	return v;
}

template<typename T> inline
void quat4<T>::get_ey( vector3<T>& o_ey ) const
{
	o_ey.x = T(2) * ( x * y - z * w );
	o_ey.y = T(1) - T(2) * ( z * z + x * x );
	o_ey.z = T(2) * ( y * z + x * w );
	o_ey.normalize();
}

template<typename T> inline
void quat4<T>::get_ey( T* o_ey ) const
{
	vector3<T> v;
	get_ey(v);
	o_ey[0] = v.x;
	o_ey[1] = v.y;
	o_ey[2] = v.z;
}

template<typename T> inline
vector3<T> quat4<T>::get_ez(void) const
{
	vector3<T> v;
	get_ez(v);
	return v;
}

template<typename T> inline
void quat4<T>::get_ez( vector3<T>& o_ez ) const
{
	o_ez.x = T(2) * ( z * x + w * y );
	o_ez.y = T(2) * ( y * z - w * x );
	o_ez.z = T(1) - T(2) * ( x * x + y * y );
	o_ez.normalize();
}

template<typename T> inline
void quat4<T>::get_ez( T* o_ez ) const
{
	vector3<T> v;
	get_ez(v);
	o_ez[0] = v.x;
	o_ez[1] = v.y;
	o_ez[2] = v.z;
}

template<typename T> inline
void quat4<T>::get_bases( vector3<T>& o_ex , vector3<T>& o_ey , vector3<T>& o_ez ) const 
{
	get_ex( o_ex );
	get_ey( o_ey );
	get_ez( o_ez );
}

template<typename T> inline
matrix4<T> quat4<T>::get_ortho_matrix(void) const
{
	matrix4<T> m;
	get_ortho_matrix( m );
	return m;
}

template<typename T> inline
void quat4<T>::get_ortho_matrix( matrix4<T>& ortho_matrix ) const
{
	ortho_matrix.fill( 0.0f );
	get_ex( &ortho_matrix.m00 );
	get_ey( &ortho_matrix.m10 );
	get_ez( &ortho_matrix.m20 );
	ortho_matrix.m33 = 1.0f;
}

template<typename T> inline
void quat4<T>::get_rotate_params( vector3<T>& o_axis , T& angular_velocity ) const
{
	if (w >= T(1))
		angular_velocity = 0.0f;
	else if (w <= -T(1))
		angular_velocity = T(M_PI) * T(2);
	else
		angular_velocity = acos(w) * T(2);

	o_axis.set( x , y , z );
	o_axis.normalize();
}


//! 線形補間
//! @param[in] q     - 補間する対となるクォータニオン
//! @param[in] alpha - 補間位置. 0でthisと一致, 1でqと一致する結果を返す
template<typename T> inline
void quat4<T>::interpolate( const lm::quat4<T>& q , T alpha )
{
	(*this) = interpolate( (*this) , q , alpha );
}

//! 線形補間
//! @param[in] q0    - 始点となるクォータニオン
//! @param[in] q1    - 終点となるクォータニオン
//! @param[in] alpha - 補間位置. 0でq0と一致, 1でq1と一致する結果を返す
template<typename T> inline
quat4<T> quat4<T>::interpolate( const lm::quat4<T>& q0 , const lm::quat4<T>& q1 , T alpha )
{
	lm::quat4<T> tq0 = q0;
	tq0.normalize();

	lm::quat4<T> tq1 = q1;

	float n1 = tq1.norm();
	if( n1 == 0.0f )
		return tq0;

	tq1.x /= n1;
	tq1.y /= n1;
	tq1.z /= n1;
	tq1.w /= n1;

	// 正規化クォータニオンの内積
	float t = tq0.x * tq1.x + tq0.y * tq1.y + tq0.z * tq1.z + tq0.w * tq1.w;

	// 正規化後の内積が1 => q0 == q1
	if(t >= 1.0f)
		return tq0;

	float u = acos(t);

	float sin_t = sin(u);

	float s0 = sin((T(1.0) - alpha) * u) / sin_t;
	float s1 = sin(alpha * u) / sin_t;

	// set values
	tq0.x = s0 * tq0.x + s1 * tq1.x;
	tq0.y = s0 * tq0.y + s1 * tq1.y;
	tq0.z = s0 * tq0.z + s1 * tq1.z;
	tq0.w = s0 * tq0.w + s1 * tq1.w;

	return tq0;
}


// 2つのクォータニオンが示す座標間の角度差が最小になるように
// 回転方向を補正する.
template<typename T> inline
void quat4<T>::minimize_rotation(const quat4<T>& quat_src, quat4<T>& quat_dst)
{
	// 回転量を求める
	quat4<T> qrot = quat_dst;
	qrot *= quat_src.get_conjugate();
	qrot.normalize();

	vector3<T> axis;
	T rot;
	qrot.get_rotate_params(axis, rot);

	// 回転量を180度以下にする
	if(rot <= T(M_PI))
		return;

	rot = T(2) * T(M_PI) - rot;
	axis *= T(-1);
	qrot.set_rotate(axis, rot);

	quat_dst = qrot;
	quat_dst *= quat_src;
}


}
