/*
 * Decompiled with CFR 0.152.
 */
package org.opensearch.neuralsearch.processor;

import java.util.List;
import java.util.Map;
import java.util.function.BiConsumer;
import java.util.function.Consumer;
import java.util.stream.Collectors;
import lombok.Generated;
import org.apache.logging.log4j.LogManager;
import org.apache.logging.log4j.Logger;
import org.opensearch.cluster.service.ClusterService;
import org.opensearch.core.action.ActionListener;
import org.opensearch.env.Environment;
import org.opensearch.ingest.IngestDocument;
import org.opensearch.neuralsearch.ml.MLCommonsClientAccessor;
import org.opensearch.neuralsearch.processor.InferenceProcessor;
import org.opensearch.neuralsearch.util.TokenWeightUtil;
import org.opensearch.neuralsearch.util.prune.PruneType;
import org.opensearch.neuralsearch.util.prune.PruneUtils;

/**
 * Inference processor that produces pruned sparse vectors for a list of input strings.
 *
 * <p>Both execution paths send the inputs to the model identified by {@code modelId} through
 * {@code MLCommonsClientAccessor.inferenceSentencesWithMapResult}, turn the returned result maps
 * into token-weight maps with {@code TokenWeightUtil.fetchListOfTokenWeightMap}, and run every
 * map through {@code PruneUtils.pruneSparseVector} with the configured {@link PruneType} and
 * prune ratio.
 */
public final class SparseEncodingProcessor
extends InferenceProcessor {
    @Generated
    private static final Logger log = LogManager.getLogger(SparseEncodingProcessor.class);
    public static final String TYPE = "sparse_encoding";
    public static final String LIST_TYPE_NESTED_MAP_KEY = "sparse_encoding";
    private final PruneType pruneType;
    private final float pruneRatio;

    /**
     * Creates a sparse-encoding processor.
     *
     * <p>{@link #TYPE} and {@link #LIST_TYPE_NESTED_MAP_KEY} are forwarded to the
     * {@link InferenceProcessor} constructor together with the remaining arguments;
     * {@code pruneType} and {@code pruneRatio} are kept on this instance.
     *
     * @param tag            forwarded to {@link InferenceProcessor}
     * @param description    forwarded to {@link InferenceProcessor}
     * @param batchSize      forwarded to {@link InferenceProcessor}
     * @param modelId        forwarded to {@link InferenceProcessor}; used as the model id for inference calls
     * @param fieldMap       forwarded to {@link InferenceProcessor}
     * @param pruneType      prune strategy handed to {@code PruneUtils.pruneSparseVector}
     * @param pruneRatio     prune ratio handed to {@code PruneUtils.pruneSparseVector}
     * @param clientAccessor forwarded to {@link InferenceProcessor}
     * @param environment    forwarded to {@link InferenceProcessor}
     * @param clusterService forwarded to {@link InferenceProcessor}
     */
    public SparseEncodingProcessor(String tag, String description, int batchSize, String modelId, Map<String, Object> fieldMap, PruneType pruneType, float pruneRatio, MLCommonsClientAccessor clientAccessor, Environment environment, ClusterService clusterService) {
        // Use the named constants instead of repeating the "sparse_encoding" literal.
        super(tag, description, batchSize, TYPE, LIST_TYPE_NESTED_MAP_KEY, modelId, fieldMap, clientAccessor, environment, clusterService);
        this.pruneType = pruneType;
        this.pruneRatio = pruneRatio;
    }

    /**
     * Runs inference for a single document and stores the pruned sparse vectors on it.
     *
     * <p>On success the vectors are written with {@code setVectorFieldsToDocument} and
     * {@code handler} receives {@code (ingestDocument, null)}. On failure {@code handler}
     * receives {@code (null, exception)}.
     *
     * @param ingestDocument document that receives the vector fields
     * @param ProcessMap     map passed through unchanged to {@code setVectorFieldsToDocument}
     * @param inferenceList  input strings sent to the model
     * @param handler        completion callback
     */
    @Override
    public void doExecute(IngestDocument ingestDocument, Map<String, Object> ProcessMap, List<String> inferenceList, BiConsumer<IngestDocument, Exception> handler) {
        this.mlCommonsClientAccessor.inferenceSentencesWithMapResult(this.modelId, inferenceList, ActionListener.wrap(resultMaps -> {
            List<Map<String, Float>> sparseVectors = this.toPrunedSparseVectors(resultMaps);
            this.setVectorFieldsToDocument(ingestDocument, ProcessMap, sparseVectors);
            handler.accept(ingestDocument, null);
        }, e -> handler.accept(null, e)));
    }

    /**
     * Runs inference for a batch of inputs and hands the pruned sparse vectors to {@code handler}.
     *
     * @param inferenceList input strings sent to the model
     * @param handler       receives the list of pruned sparse vectors on success
     * @param onException   receives the exception on failure
     */
    @Override
    public void doBatchExecute(List<String> inferenceList, Consumer<List<?>> handler, Consumer<Exception> onException) {
        this.mlCommonsClientAccessor.inferenceSentencesWithMapResult(
            this.modelId,
            inferenceList,
            ActionListener.wrap(resultMaps -> handler.accept(this.toPrunedSparseVectors(resultMaps)), onException)
        );
    }

    /**
     * Extracts the token-weight maps from raw inference results and prunes each of them
     * with this processor's prune type and prune ratio.
     *
     * @param resultMaps result maps delivered by {@code inferenceSentencesWithMapResult}
     * @return pruned token-weight maps, in the order produced by
     *         {@code TokenWeightUtil.fetchListOfTokenWeightMap}
     */
    private List<Map<String, Float>> toPrunedSparseVectors(List<Map<String, ?>> resultMaps) {
        return TokenWeightUtil.fetchListOfTokenWeightMap(resultMaps)
            .stream()
            .map(vector -> PruneUtils.pruneSparseVector(this.pruneType, this.pruneRatio, vector))
            .collect(Collectors.toList());
    }

    /**
     * @return the prune type this processor was constructed with
     */
    @Generated
    public PruneType getPruneType() {
        return this.pruneType;
    }

    /**
     * @return the prune ratio this processor was constructed with
     */
    @Generated
    public float getPruneRatio() {
        return this.pruneRatio;
    }
}

