/*
 * Decompiled with CFR 0.152.
 */
package org.opensearch.knn.quantization.quantizer;

import java.io.IOException;
import java.util.Objects;
import org.opensearch.knn.quantization.enums.ScalarQuantizationType;
import org.opensearch.knn.quantization.models.quantizationOutput.QuantizationOutput;
import org.opensearch.knn.quantization.models.quantizationParams.ScalarQuantizationParams;
import org.opensearch.knn.quantization.models.quantizationState.MultiBitScalarQuantizationState;
import org.opensearch.knn.quantization.models.quantizationState.QuantizationState;
import org.opensearch.knn.quantization.models.requests.TrainingRequest;
import org.opensearch.knn.quantization.quantizer.BitPacker;
import org.opensearch.knn.quantization.quantizer.Quantizer;
import org.opensearch.knn.quantization.quantizer.QuantizerHelper;
import org.opensearch.knn.quantization.sampler.Sampler;
import org.opensearch.knn.quantization.sampler.SamplerType;
import org.opensearch.knn.quantization.sampler.SamplingFactory;
import oshi.util.tuples.Pair;

/**
 * Multi-bit scalar quantizer that converts a {@code float[]} vector into a packed {@code byte[]}.
 *
 * <p>Training asks the configured {@link Sampler} for {@code samplingSize} indices into the
 * training set, computes the per-dimension mean and standard deviation of those vectors, and
 * derives {@code bitsPerCoordinate} rows of per-dimension thresholds. Quantization validates the
 * supplied state and delegates comparison/packing against those thresholds to {@link BitPacker}.
 *
 * <p>Only bit widths 2 and 4 map to a {@link ScalarQuantizationType}; {@link #train} rejects any
 * other width instead of mislabeling the resulting state.
 */
public class MultiBitScalarQuantizer implements Quantizer<float[], byte[]> {
    private static final boolean IS_TRAINING_REQUIRED = true;
    private static final int DEFAULT_SAMPLE_SIZE = 25000;

    private final int bitsPerCoordinate;
    private final int samplingSize;
    private final Sampler sampler;

    /**
     * Creates a quantizer that samples {@value #DEFAULT_SAMPLE_SIZE} vectors with a reservoir sampler.
     *
     * @param bitsPerCoordinate bits used per coordinate; must be at least 2 (only 2 and 4 are trainable)
     * @throws IllegalArgumentException if {@code bitsPerCoordinate < 2}
     */
    public MultiBitScalarQuantizer(int bitsPerCoordinate) {
        this(bitsPerCoordinate, DEFAULT_SAMPLE_SIZE, SamplingFactory.getSampler(SamplerType.RESERVOIR));
    }

    /**
     * Creates a quantizer with an explicit sample size and sampler.
     *
     * @param bitsPerCoordinate bits used per coordinate; must be at least 2 (only 2 and 4 are trainable)
     * @param samplingSize number of vector indices requested from {@code sampler} during training; must be positive
     * @param sampler strategy used to pick training vectors; must not be null
     * @throws IllegalArgumentException if {@code bitsPerCoordinate < 2} or {@code samplingSize <= 0}
     * @throws NullPointerException if {@code sampler} is null
     */
    public MultiBitScalarQuantizer(int bitsPerCoordinate, int samplingSize, Sampler sampler) {
        if (bitsPerCoordinate < 2) {
            throw new IllegalArgumentException("bitsPerCoordinate must be greater than or equal to 2 for multibit quantizer.");
        }
        if (samplingSize <= 0) {
            // A non-positive sample size would produce statistics computed over zero vectors.
            throw new IllegalArgumentException("samplingSize must be positive for multibit quantizer, got " + samplingSize + ".");
        }
        this.bitsPerCoordinate = bitsPerCoordinate;
        this.samplingSize = samplingSize;
        this.sampler = Objects.requireNonNull(sampler, "sampler must not be null for multibit quantizer.");
    }

    /**
     * Trains a {@link MultiBitScalarQuantizationState} from a sample of the training vectors.
     *
     * @param trainingRequest source of the training vectors
     * @return a state holding the quantization type and {@code bitsPerCoordinate} rows of thresholds
     * @throws IOException if reading the training vectors fails
     * @throws IllegalStateException if {@code bitsPerCoordinate} has no matching {@link ScalarQuantizationType}
     */
    @Override
    public QuantizationState train(TrainingRequest<float[]> trainingRequest) throws IOException {
        // Resolve the type first so an unsupported width fails before the (potentially costly) sampling pass.
        ScalarQuantizationParams params = new ScalarQuantizationParams(resolveQuantizationType());
        int[] sampledIndices = this.sampler.sample(trainingRequest.getTotalNumberOfVectors(), this.samplingSize);
        Pair<float[], float[]> meanAndStdDev = QuantizerHelper.calculateMeanAndStdDev(trainingRequest, sampledIndices);
        float[][] thresholds = calculateThresholds(meanAndStdDev.getA(), meanAndStdDev.getB());
        return new MultiBitScalarQuantizationState(params, thresholds);
    }

    /**
     * Quantizes {@code vector} using the thresholds stored in {@code state} and writes the packed
     * result into {@code output}.
     *
     * @param vector vector to quantize; must not be null
     * @param state a {@link MultiBitScalarQuantizationState} whose thresholds have exactly
     *     {@code bitsPerCoordinate} rows, each as long as {@code vector}
     * @param output receives the packed bytes; it is prepared for {@code vector.length} coordinates
     * @throws IllegalArgumentException if the vector is null, the state has the wrong type, or the
     *     thresholds are null, empty, ragged, or inconsistent with the vector or bit width
     */
    @Override
    public void quantize(float[] vector, QuantizationState state, QuantizationOutput<byte[]> output) {
        if (vector == null) {
            throw new IllegalArgumentException("Vector to quantize must not be null.");
        }
        validateState(state);
        float[][] thresholds = ((MultiBitScalarQuantizationState) state).getThresholds();
        validateThresholds(thresholds, vector.length);
        output.prepareQuantizedVector(vector.length);
        BitPacker.quantizeAndPackBits(vector, thresholds, this.bitsPerCoordinate, output.getQuantizedVector());
    }

    /** Maps the configured bit width to its {@link ScalarQuantizationType}, rejecting widths that have none. */
    private ScalarQuantizationType resolveQuantizationType() {
        switch (this.bitsPerCoordinate) {
            case 2:
                return ScalarQuantizationType.TWO_BIT;
            case 4:
                return ScalarQuantizationType.FOUR_BIT;
            default:
                throw new IllegalStateException(
                    "Unsupported bitsPerCoordinate for multibit quantizer: " + this.bitsPerCoordinate + " (supported: 2, 4)."
                );
        }
    }

    /** Checks that {@code thresholds} is a non-empty, rectangular array of {@code bitsPerCoordinate} rows of {@code dimension} values. */
    private void validateThresholds(float[][] thresholds, int dimension) {
        if (thresholds == null || thresholds.length == 0) {
            throw new IllegalArgumentException("Thresholds must not be null and must match the dimension of the vector.");
        }
        if (thresholds.length != this.bitsPerCoordinate) {
            // A state trained with a different bit width would be packed with the wrong number of bits.
            throw new IllegalArgumentException(
                "Thresholds must have " + this.bitsPerCoordinate + " rows, one per bit, but had " + thresholds.length + "."
            );
        }
        for (float[] row : thresholds) {
            if (row == null || row.length != dimension) {
                throw new IllegalArgumentException("Thresholds must not be null and must match the dimension of the vector.");
            }
        }
    }

    /**
     * Builds {@code bitsPerCoordinate} rows of per-dimension thresholds.
     *
     * <p>Row {@code i} holds {@code mean[j] + c_i * stdDev[j]} for every dimension {@code j}, where
     * {@code c_i = -1 + 2(i + 1) / (bitsPerCoordinate + 1)}. The coefficients are therefore evenly
     * spaced strictly inside (-1, 1).
     *
     * @param meanArray per-dimension mean
     * @param stdDevArray per-dimension standard deviation; assumed to have the same length as {@code meanArray}
     * @return a {@code [bitsPerCoordinate][meanArray.length]} threshold matrix
     */
    private float[][] calculateThresholds(float[] meanArray, float[] stdDevArray) {
        int dimension = meanArray.length;
        float[][] thresholds = new float[this.bitsPerCoordinate][dimension];
        float coef = this.bitsPerCoordinate + 1;
        for (int i = 0; i < this.bitsPerCoordinate; ++i) {
            float iCoef = -1.0f + (float) (2 * (i + 1)) / coef;
            for (int j = 0; j < dimension; ++j) {
                thresholds[i][j] = meanArray[j] + iCoef * stdDevArray[j];
            }
        }
        return thresholds;
    }

    /** Rejects states (including null) that are not {@link MultiBitScalarQuantizationState}. */
    private void validateState(QuantizationState state) {
        if (!(state instanceof MultiBitScalarQuantizationState)) {
            throw new IllegalArgumentException("Quantization state must be of type MultiBitScalarQuantizationState.");
        }
    }

    /**
     * Returns the number of bits used per coordinate.
     *
     * @return the bit width passed at construction
     */
    public int getBitsPerCoordinate() {
        return this.bitsPerCoordinate;
    }
}

