package libsvm.wrapper;

import java.io.IOException;
import java.util.ArrayList;
import java.util.Collections;
import java.util.Comparator;
import java.util.HashMap;
import java.util.List;
import java.util.Map;
import java.util.Map.Entry;

import libsvm.svm;
import libsvm.svm_model;

/**
 * <pre>
 * LibSvmの(主な)機能をまとめたもの
 * (svm)
 * </pre>
 *
 * @author hirainaoki
 */
public class Svm {

	/**
	 * バージョン
	 */
	public static final int LIBSVM_VERSION = svm.LIBSVM_VERSION;

	/**
	 * 学習を行い，生成したモデルを返します．
	 *
	 * @param problem 問題
	 * @param parameter 設定パラメータ
	 * @return 生成モデル
	 */
	public static SvmModel train(SvmProblem problem, SvmSettingParameter parameter) {
		return new SvmModel(svm.svm_train(problem.toUsable(), parameter.toUsable()));
	}

	/**
	 * クラスを判定します．
	 *
	 * @param model モデル
	 * @param vector 特徴ベクトル
	 * @return クラス
	 */
	public static double predict(SvmModel model, SvmFeatureVector vector) {
		return svm.svm_predict(model.toUsable(), vector.toUsable());
	}

	/**
	 * 各クラスの確率を取得します。
	 *
	 * @param model モデル
	 * @param vector 特徴ベクトル
	 * @return 各クラスIDの確率
	 */
	public static List<Entry<Integer, Double>> predictProbabilityRanking(SvmModel model, SvmFeatureVector vector) {

		Map<Integer, Double> classIdAndProbabilities = predictProbabilities(model, vector);

		// mapのvalue値でソート
		List<Entry<Integer, Double>> res = new ArrayList<Entry<Integer, Double>>(classIdAndProbabilities.entrySet());
		Collections.sort(res, new Comparator<Entry<Integer, Double>>() {

			public int compare(Entry<Integer, Double> o1, Entry<Integer, Double> o2) {
				return Double.compare(o2.getValue(), o1.getValue());
			}
		});

		return res;
	}

	/**
	 * 各クラスの確率を取得します。
	 *
	 * @param model モデル
	 * @param vector 特徴ベクトル
	 * @return 各クラスIDの確率
	 */
	public static Map<Integer, Double> predictProbabilities(SvmModel model, SvmFeatureVector vector) {

		Map<Integer, Double> res = new HashMap<Integer, Double>();
		svm_model usableModel = model.toUsable();

		double[] probabilities = new double[usableModel.nr_class];
		svm.svm_predict_probability(usableModel, vector.toUsable(), probabilities);

		for (int i = 0; i < probabilities.length; i++) {
			int classId = model.toUsable().label[i];
			res.put(classId, probabilities[i]);
		}
		return res;
	}

	/**
	 * 各クラスの票数を取得します。
	 *
	 * @param model モデル
	 * @param vector 特徴ベクトル
	 * @return 各クラスの票数
	 */
	public static List<Entry<Integer, Integer>> predictVotingRanking(SvmModel model, SvmFeatureVector vector) {

		Map<Integer, Integer> classIdAndVotes = predictVoting(model, vector);

		// mapのvalue値でソート
		List<Entry<Integer, Integer>> res = new ArrayList<Entry<Integer, Integer>>(classIdAndVotes.entrySet());
		Collections.sort(res, new Comparator<Entry<Integer, Integer>>() {

			public int compare(Entry<Integer, Integer> o1, Entry<Integer, Integer> o2) {
				return Double.compare(o2.getValue(), o1.getValue());
			}
		});

		return res;
	}

	/**
	 * 各クラスの票数を取得します。
	 *
	 * @param model モデル
	 * @param vector 特徴ベクトル
	 * @return 各クラスの票数
	 */
	public static Map<Integer, Integer> predictVoting(SvmModel model, SvmFeatureVector vector) {

		Map<Integer, Integer> res = new HashMap<Integer, Integer>();
		svm_model usableModel = model.toUsable();

		int[] votes = new int[usableModel.nr_class];
		svm.svm_predict_votes(usableModel, vector.toUsable(), votes);

		for (int i = 0; i < votes.length; i++) {
			int classId = model.toUsable().label[i];
			res.put(classId, votes[i]);
		}
		return res;
	}

	/**
	 * <pre>
	 * 評価値を取得します．
	 *
	 * ２クラス問題の場合のみ有効
	 * </pre>
	 *
	 * @param model モデル
	 * @param vector 特徴ベクトル
	 * @return 評価値
	 */
	public static double predictValue(SvmModel model, SvmFeatureVector vector) {

		// TODO 返り値をdouble[]にかえる必要ありか？そもそもいらないかもしれない。
		// TODO ラベル数を指定する必要がある。

		double[] values = new double[1];
		svm.svm_predict_values(model.toUsable(), vector.toUsable(), values);

		return values[0];
	}

	/**
	 * モデルを保存します．
	 *
	 * @param filePath ファイルパス
	 * @param model モデル
	 * @throws IOException 書き込みに失敗した場合
	 */
	public static void saveModel(String filePath, SvmModel model) throws IOException {
		svm.svm_save_model(filePath, model.toUsable());
	}

	/**
	 * モデルを読み込みます．
	 *
	 * @param filePath ファイルパス
	 * @return モデル
	 * @throws IOException 読み込みに失敗した場合
	 */
	public static SvmModel loadModel(String filePath) throws IOException {
		return new SvmModel(svm.svm_load_model(filePath));
	}
}
