#include "GmmSet.h"

#include "NormalDistribution.h"
#include "Matrix.h"
#include "MathSet.h"
#include "GmmData.h"

#include <stdio.h>
#include <math.h>
#include <float.h>

#include <vector>

#include <QtGlobal>

using namespace stand::math;
using namespace stand::math::matrix;

// Relative convergence tolerance used by estimate(): EM stops once the change
// in total log-likelihood between iterations is no larger than this fraction
// of the first iteration's (absolute) log-likelihood.
const double GmmSet::DefaultEpsilon = 1e-4;

/**
 * Constructs a mixture of `numMix` components of dimension `dim`.
 *
 * If either argument is non-positive, no storage is allocated. The object
 * then stays empty until setNumMix() or estimate() is called.
 */
GmmSet::GmmSet(int numMix, int dim)
{
    _gaussians = NULL;
    _pi = NULL;
    _numMix = DefaultMixtureNumber;
    // setNumMix() compares the requested dimension against _dim, so _dim must
    // hold a defined value. 0 can never match a valid (positive) dimension,
    // which guarantees that the first setNumMix() call allocates storage.
    _dim = 0;
    setNumMix(numMix, dim);
}

// Releases the mixture weights and the component distributions.
GmmSet::~GmmSet()
{
    this->_destroy();
}

/**
 * (Re)allocates storage for `numMix` components of dimension `dim`,
 * discarding any previous parameters.
 *
 * Mixture weights start at the uniform value 1/numMix. This way f() never
 * reads indeterminate memory when called before estimate(); estimate()
 * overwrites the weights anyway.
 *
 * This function does not validate its arguments. setNumMix() checks them
 * before calling it.
 */
void GmmSet::_create(int numMix, int dim)
{
    _destroy();
    _pi = new double[numMix];
    _gaussians = new NormalDistribution[numMix];
    for(int i = 0; i < numMix; i++)
    {
        _pi[i] = 1.0 / numMix;
        _gaussians[i].set(dim);
    }
    _numMix = numMix;
    _dim = dim;
}

void GmmSet::setNumMix(int numMix, int dim)
{
    if(numMix <= 0 || dim <= 0)
    {
        return;
    }
    if(numMix == _numMix && _dim == dim)
    {
        return;
    }
    _create(numMix, dim);
}

void GmmSet::_destroy()
{
    delete[] _pi;
    delete[] _gaussians;
    _pi = NULL;
    _gaussians = NULL;
    _numMix = DefaultMixtureNumber;
}

/**
 * Log-likelihood of one observation under the mixture:
 *   log( sum_k pi_k * p_k(x) )
 * evaluated with logSumExp() for numerical stability. This assumes
 * NormalDistribution::f() returns a log-density, since it is added to
 * log(pi_k) below.
 *
 * @param x   observation vector; must not be null.
 * @param dim must equal the mixture's dimension.
 * @return the log-likelihood, or -1.0 if the arguments are invalid or the
 *         mixture has not been allocated.
 */
double GmmSet::f(const double *x, int dim)
{
    if(!x || dim != _dim)
    {
        qDebug("GmmSet::f(%p, %d); // invalid args.", static_cast<const void *>(x), dim);
        return -1.0;
    }
    if(!_pi || !_gaussians || _numMix <= 0)
    {
        qDebug("GmmSet::f(%p, %d); // GMM not set.", static_cast<const void *>(x), dim);
        return -1.0;
    }

    std::vector<double> logTerms(_numMix);
    for(int i = 0; i < _numMix; i++)
    {
        // The tiny offset keeps log() finite for components with zero weight.
        logTerms[i] = log(_pi[i] + 1.0e-40) + _gaussians[i].f(x, dim);
    }
    // _numMix > 0 was checked above, so element 0 exists.
    return logSumExp(&logTerms[0], _numMix);
}

/**
 * Fits the mixture to every frame stored in data[0 .. length).
 *
 * All GmmData entries must be non-empty and share the same dimension.
 * Their row pointers are gathered into one flat list (the values themselves
 * are not copied) and passed to estimate(double **, int, int).
 *
 * @param data   array of `length` GmmData objects.
 * @param length number of GmmData objects in `data`.
 */
void GmmSet::estimate(GmmData *data, int length)
{
    if(!data || length <= 0)
    {
        qDebug("GmmSet::estimate(%p, %d); // invalid args.", static_cast<void *>(data), length);
        return;
    }

    int total = 0;
    bool dimCheck = true;
    const int dim = data[0].dim();
    for(int i = 0; i < length; i++)
    {
        total += data[i].length();
        dimCheck &= (dim == data[i].dim() && !data[i].empty());
    }
    if(total <= 0)
    {
        qDebug("GmmSet::estimate(); // data is empty.");
        return;
    }
    if(!dimCheck)
    {
        qDebug("GmmSet::estimate(); // data dimension is invalid.");
        return;
    }

    // Flat list of observation rows; the vector frees itself on every exit path.
    std::vector<double *> rows;
    rows.reserve(total);
    for(int i = 0; i < length; i++)
    {
        double **val = data[i].data();
        for(int j = 0; j < data[i].length(); j++)
        {
            rows.push_back(val[j]);
        }
    }
    if(rows.empty())
    {
        return;
    }

    estimate(&rows[0], static_cast<int>(rows.size()), dim);
}

void GmmSet::estimate(double **x, int length, int dim)
{
    if(!x || length <= 0 || dim <= 0)
    {
        qDebug("GmmSet::estimate(%d, %d, %d); // invalid args.", x, length, dim);
        return;
    }
    _create(_numMix, dim);

    double **mus = new double*[_numMix];
    double ***sigmas = new double**[_numMix];

    // set initial values.
    _setInitialValues(x, length, mus, sigmas);

    // 推定ここから．
    double prev = 0.0;
    int c = 0;
    double **gamma, *N = new double[_numMix];
    CreateMatarix(&gamma, _numMix, length);
    double epsilon = DBL_MAX;
    while(1)
    {
        c++;
        double pres = 0.0;
        for(int n = 0; n < length; n++)
        {
            pres += this->f(x[n], dim);
        }
        qDebug("log probability = %f", pres);
        // 終了条件を超えたら．
        if(epsilon == DBL_MAX)
        {
            epsilon = fabs(pres * DefaultEpsilon);
        }
        if(fabs(pres - prev) <= epsilon || c >= DefaultMaxRepeat)
        {
            break;
        }
        prev = pres;
        // 負担率を求める．
        for(int n = 0; n < length; n++)
        {
            double *buf = new double[_numMix];
            for(int k = 0; k < _numMix; k++)
            {
                buf[k] = log(_pi[k] + 1.0e-40) + _gaussians[k].f(x[n], dim);
            }
            double sum = logSumExp(buf, _numMix);
            for(int k = 0; k < _numMix; k++)
            {
                gamma[k][n] = buf[k] - sum;
            }
            delete[] buf;
        }
        for(int k = 0; k < _numMix; k++)
        {
            N[k] = logSumExp(gamma[k], length);
        }
        // 更新を行う．
        for(int k = 0; k < _numMix; k++)
        {
            // 平均の更新
            double *buf = new double[length];
            for(int j = 0; j < dim; j++)
            {
                mus[k][j] = 0;
                for(int n = 0; n < length; n++)
                {
                    mus[k][j] += exp(gamma[k][n]) * x[n][j];
                }
                mus[k][j] /= exp(N[k]);
            }
            // 寄与率の更新
            _pi[k] = exp(N[k]) / length;

            // 分散共分散行列の更新
            for(int i = 0; i < dim; i++)
            {
                for(int j = 0; j < dim; j++)
                {
                    sigmas[k][i][j] = 0.0;
                    for(int n = 0; n < length; n++)
                    {
                        sigmas[k][i][j] += exp(gamma[k][n]) * (x[n][i] - mus[k][i]) * (x[n][j] - mus[k][j]);
                    }
                    sigmas[k][i][j] /= exp(N[k]);
                }
            }
            _gaussians[k].update();
            delete[] buf;
        }
    }

    DestroyMatrix(&gamma);
    delete[] N;
    delete[] mus;
    delete[] sigmas;
}

/**
 * Seeds the EM iteration.
 *
 * The first (length / _numMix) * _numMix samples are split into _numMix
 * consecutive chunks of equal size. Component k is initialised from chunk k
 * using average() and Covariance(), and every weight is set to 1/_numMix.
 *
 * On return, mus[k] and sigmas[k] hold the buffers returned by component k's
 * average() and convariance(), so the caller can keep updating them.
 */
void GmmSet::_setInitialValues(double **x, int length, double **mus, double ***sigmas)
{
    const int chunk = length / _numMix;
    const double uniform = 1.0 / _numMix;
    for(int k = 0; k < _numMix; k++)
    {
        double **first = x + k * chunk;

        mus[k] = _gaussians[k].average();
        sigmas[k] = _gaussians[k].convariance();

        average(mus[k], first, chunk, _dim);
        Covariance(sigmas[k], first, mus[k], _dim, chunk);

        _pi[k] = uniform;
        _gaussians[k].update();
    }
}

/**
 * Numerically stable log( sum_i exp(array[i]) ).
 *
 * The largest element is factored out before exponentiating, so large
 * magnitudes neither overflow nor all underflow to zero.
 *
 * @param array  values in log space; may be null only if length <= 0.
 * @param length number of elements in `array`.
 * @return the log of the sum. The result is -infinity for an empty or null
 *         array, or when every element is -infinity. It is +infinity if any
 *         element is +infinity.
 */
double GmmSet::logSumExp(double *array, int length)
{
    if(!array || length <= 0)
    {
        // The log of an empty sum.
        return -HUGE_VAL;
    }

    double maxVal = -DBL_MAX;
    for(int i = 0; i < length; i++)
    {
        if(array[i] > maxVal)
        {
            maxVal = array[i];
        }
    }
    if(maxVal > DBL_MAX)
    {
        // +inf dominates the sum. Without this check, inf - inf gives NaN below.
        return maxVal;
    }

    double sum = 0.0;
    for(int i = 0; i < length; i++)
    {
        sum += exp(array[i] - maxVal);
    }
    return maxVal + log(sum);
}

/**
 * Writes a text dump of the mixture to `path`.
 *
 * For each component it writes the index line, the weight line
 * ("pi = ..."), then whatever NormalDistribution::write() emits, followed
 * by a "---" separator.
 *
 * @return true if everything was written and the file closed cleanly.
 *         Returns false if `path` is null, the mixture is not allocated,
 *         the file cannot be opened, or an I/O error occurred.
 */
bool GmmSet::write(const char *path)
{
    if(!path || !_gaussians || !_pi)
    {
        return false;
    }
    FILE *fp = fopen(path, "w");
    if(!fp)
    {
        qDebug("GmmSet::write(%s); // cannot open file", path);
        return false;
    }
    for(int i = 0; i < _numMix; i++)
    {
        fprintf(fp, "Normal Distribution id.%d\n", i);
        fprintf(fp, "pi = %f\n", _pi[i]);
        _gaussians[i].write(fp);
        fprintf(fp, "---\n");
    }
    // Report buffered write failures (e.g. a full disk) instead of always
    // claiming success. fclose() must run regardless of the stream state.
    const bool writeOk = (ferror(fp) == 0);
    const bool closeOk = (fclose(fp) == 0);
    return writeOk && closeOk;
}
