/*
 * Decompiled with CFR 0.152.
 */
package org.opensearch.neuralsearch.search.collector;

import java.io.IOException;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.List;
import java.util.Locale;
import java.util.Objects;
import lombok.Generated;
import org.apache.logging.log4j.LogManager;
import org.apache.logging.log4j.Logger;
import org.apache.lucene.index.LeafReaderContext;
import org.apache.lucene.search.HitQueue;
import org.apache.lucene.search.LeafCollector;
import org.apache.lucene.search.Scorable;
import org.apache.lucene.search.ScoreDoc;
import org.apache.lucene.search.ScoreMode;
import org.apache.lucene.search.TopDocs;
import org.apache.lucene.search.TotalHits;
import org.apache.lucene.util.PriorityQueue;
import org.opensearch.neuralsearch.query.HybridSubQueryScorer;
import org.opensearch.neuralsearch.search.HitsThresholdChecker;
import org.opensearch.neuralsearch.search.collector.HybridLeafCollector;
import org.opensearch.neuralsearch.search.collector.HybridSearchCollector;

/**
 * Collector for hybrid queries that keeps a separate top-N {@link HitQueue} of {@link ScoreDoc}s for every sub-query.
 * <p>
 * Each leaf collector returned by {@link #getLeafCollector(LeafReaderContext)} writes into the same per-sub-query
 * queues and counters, so results accumulate across all leaves this collector visits.
 */
public class HybridTopScoreDocCollector
implements HybridSearchCollector {
    @Generated
    private static final Logger log = LogManager.getLogger(HybridTopScoreDocCollector.class);
    // Doc id offset of the leaf currently being collected; added to leaf-local ids before they enter the queues.
    private int docBase;
    private final HitsThresholdChecker hitsThresholdChecker;
    // Switched to GREATER_THAN_OR_EQUAL_TO once the threshold checker reports its threshold as reached.
    private TotalHits.Relation totalHitsRelation = TotalHits.Relation.EQUAL_TO;
    // Number of collect() calls made while a hybrid sub-query scorer was available.
    private int totalHits;
    // Per sub-query count of documents that passed the competitive-score check; allocated lazily.
    private int[] collectedHitsPerSubQuery;
    // Capacity of each per-sub-query queue.
    private final int numOfHits;
    // One priority queue per sub-query; null until the first document is collected.
    private List<PriorityQueue<ScoreDoc>> compoundScores;
    // Highest sub-query score among documents that passed the competitive-score check.
    private float maxScore = 0.0f;

    /**
     * @param numHits maximum number of hits retained per sub-query
     * @param hitsThresholdChecker supplies the score mode and signals when total hit counts become a lower bound
     */
    public HybridTopScoreDocCollector(int numHits, HitsThresholdChecker hitsThresholdChecker) {
        this.numOfHits = numHits;
        this.hitsThresholdChecker = hitsThresholdChecker;
    }

    /**
     * Remembers the doc base of {@code context} and returns a new leaf collector feeding this collector's queues.
     */
    public LeafCollector getLeafCollector(LeafReaderContext context) {
        this.docBase = context.docBase;
        return new HybridTopScoreLeafCollector();
    }

    /**
     * Delegates the score mode decision to the {@link HitsThresholdChecker}.
     */
    public ScoreMode scoreMode() {
        return this.hitsThresholdChecker.scoreMode();
    }

    /**
     * Builds one {@link TopDocs} per sub-query, in sub-query order, with the highest-ranked hit first.
     * <p>
     * This drains the underlying priority queues, so only the first call returns the collected hits.
     *
     * @return per-sub-query results, or an empty list if no document was collected
     */
    public List<TopDocs> topDocs() {
        if (this.compoundScores == null) {
            return new ArrayList<TopDocs>();
        }
        List<TopDocs> topDocs = new ArrayList<TopDocs>(this.compoundScores.size());
        for (int i = 0; i < this.compoundScores.size(); ++i) {
            PriorityQueue<ScoreDoc> pq = this.compoundScores.get(i);
            int collected = this.collectedHitsPerSubQuery[i];
            topDocs.add(this.topDocsPerQuery(0, Math.min(collected, pq.size()), pq, collected));
        }
        return topDocs;
    }

    /**
     * Extracts hits ranked {@code [start, howMany)} from {@code pq}, consuming the queue.
     *
     * @param start first rank to return (0 = best hit)
     * @param howMany exclusive end rank; capped to the number of entries in {@code pq}
     * @param pq queue to drain
     * @param totalHits hit count reported in the result's {@link TotalHits}
     * @throws IllegalArgumentException if {@code howMany} or {@code start} is negative
     */
    private TopDocs topDocsPerQuery(int start, int howMany, PriorityQueue<ScoreDoc> pq, int totalHits) {
        if (howMany < 0) {
            throw new IllegalArgumentException(
                String.format(Locale.ROOT, "Number of hits requested must be greater than or equal to 0 but value was %d", howMany)
            );
        }
        if (start < 0) {
            throw new IllegalArgumentException(
                String.format(Locale.ROOT, "Expected value of starting position is between 0 and %d, got %d", howMany, start)
            );
        }
        // The queue cannot supply more entries than it currently holds.
        int end = Math.min(howMany, pq.size());
        if (start >= end) {
            return this.emptyTopDocs(totalHits);
        }
        // pop() yields the lowest-ranked entry first, so drop everything ranked at or beyond 'end'
        // before filling the result array.
        for (int excess = pq.size() - end; excess > 0; --excess) {
            pq.pop();
        }
        int size = end - start;
        ScoreDoc[] results = new ScoreDoc[size];
        this.populateResults(results, size, pq);
        return new TopDocs(new TotalHits((long) totalHits, this.totalHitsRelation), results);
    }

    /**
     * Returns a fresh empty result that still reports {@code totalHits}.
     * <p>
     * A new instance is built each time because {@link TopDocs#scoreDocs} is publicly mutable. The count is kept
     * even when no ScoreDocs are returned (e.g. a queue capacity of 0), mirroring Lucene's TopDocsCollector.
     */
    private TopDocs emptyTopDocs(int totalHits) {
        TotalHits.Relation relation = totalHits == 0 ? TotalHits.Relation.EQUAL_TO : this.totalHitsRelation;
        return new TopDocs(new TotalHits((long) totalHits, relation), new ScoreDoc[0]);
    }

    /**
     * Pops up to {@code howMany} entries from {@code pq} into {@code results}, filling from the back.
     * <p>
     * pop() returns the lowest-ranked remaining entry, so writing from the last index forward leaves the array
     * ordered best-first. Indices at or beyond {@code results.length} are skipped without popping.
     */
    protected void populateResults(ScoreDoc[] results, int howMany, PriorityQueue<ScoreDoc> pq) {
        for (int i = howMany - 1; i >= 0 && pq.size() > 0; --i) {
            if (i < results.length) {
                results[i] = pq.pop();
            }
        }
    }

    @Override
    @Generated
    public int getTotalHits() {
        return this.totalHits;
    }

    @Override
    @Generated
    public float getMaxScore() {
        return this.maxScore;
    }

    /**
     * Per-leaf collector that routes every sub-query score of a document into the enclosing collector's queues.
     */
    protected class HybridTopScoreLeafCollector
    extends HybridLeafCollector {
        // Per sub-query eviction score seen by this leaf collector; only ever raised.
        float[] minScoreThresholds;

        protected HybridTopScoreLeafCollector() {
        }

        /**
         * Lazily allocates the per-sub-query thresholds once a hybrid sub-query scorer is available.
         * The null check matches collect(), which also tolerates a missing hybrid scorer.
         */
        @Override
        public void setScorer(Scorable scorer) throws IOException {
            super.setScorer(scorer);
            HybridSubQueryScorer compoundQueryScorer = this.getCompoundQueryScorer();
            if (Objects.isNull(this.minScoreThresholds) && Objects.nonNull(compoundQueryScorer)) {
                this.minScoreThresholds = new float[compoundQueryScorer.getNumOfSubQueries()];
                Arrays.fill(this.minScoreThresholds, Float.MIN_VALUE);
            }
        }

        /**
         * Offers {@code doc}'s score for every sub-query to that sub-query's queue.
         * Documents are ignored when no hybrid sub-query scorer is set.
         */
        @Override
        public void collect(int doc) throws IOException {
            HybridSubQueryScorer compoundQueryScorer = this.getCompoundQueryScorer();
            if (Objects.isNull(compoundQueryScorer)) {
                return;
            }
            this.ensureSubQueryScoreQueues();
            HybridTopScoreDocCollector.this.totalHits++;
            float[] scores = compoundQueryScorer.getSubQueryScores();
            int docWithBase = doc + HybridTopScoreDocCollector.this.docBase;
            for (int subQueryIndex = 0; subQueryIndex < scores.length; ++subQueryIndex) {
                float score = scores[subQueryIndex];
                if (this.isNonCompetitiveScore(score, subQueryIndex)) {
                    continue;
                }
                if (HybridTopScoreDocCollector.this.hitsThresholdChecker.isThresholdReached()
                    && HybridTopScoreDocCollector.this.totalHitsRelation == TotalHits.Relation.EQUAL_TO) {
                    HybridTopScoreDocCollector.this.totalHitsRelation = TotalHits.Relation.GREATER_THAN_OR_EQUAL_TO;
                }
                HybridTopScoreDocCollector.this.collectedHitsPerSubQuery[subQueryIndex]++;
                PriorityQueue<ScoreDoc> pq = HybridTopScoreDocCollector.this.compoundScores.get(subQueryIndex);
                HybridTopScoreDocCollector.this.maxScore = Math.max(score, HybridTopScoreDocCollector.this.maxScore);
                // insertWithOverflow returns null while the queue has room; once full it returns whichever entry
                // fell out (possibly the new one), whose score becomes the new minimum to beat.
                ScoreDoc evictedScoreDoc = pq.insertWithOverflow(new ScoreDoc(docWithBase, score));
                if (evictedScoreDoc == null) {
                    continue;
                }
                float newThresholdScore = evictedScoreDoc.score;
                this.minScoreThresholds[subQueryIndex] = Math.max(this.minScoreThresholds[subQueryIndex], newThresholdScore);
                // Feed the raised minimum back into the scorer's per-sub-query minimum scores.
                compoundQueryScorer.getMinScores()[subQueryIndex] = Math.max(
                    compoundQueryScorer.getMinScores()[subQueryIndex],
                    newThresholdScore
                );
            }
        }

        /**
         * Returns true when {@code score} must not be offered to the sub-query's queue.
         * <p>
         * NOTE(review): thresholds start at Float.MIN_VALUE (> 0) and are only raised via Math.max, so whenever
         * {@code score <= 0} holds the second clause is also true. The expression is therefore equivalent to
         * {@code score <= 0}. Confirm whether {@code ||} was intended; that change would also stop counting
         * below-threshold docs in the per-sub-query hit totals.
         */
        private boolean isNonCompetitiveScore(float score, int subQueryIndex) {
            return score <= 0.0f && score < this.minScoreThresholds[subQueryIndex];
        }

        /**
         * Creates one queue and one hit counter per sub-query on first use; sizes come from the current scorer.
         */
        private void ensureSubQueryScoreQueues() {
            if (Objects.nonNull(HybridTopScoreDocCollector.this.compoundScores)) {
                return;
            }
            int numOfSubQueries = this.compoundQueryScorer.getNumOfSubQueries();
            List<PriorityQueue<ScoreDoc>> queues = new ArrayList<PriorityQueue<ScoreDoc>>(numOfSubQueries);
            for (int i = 0; i < numOfSubQueries; ++i) {
                queues.add(new HitQueue(HybridTopScoreDocCollector.this.numOfHits, false));
            }
            HybridTopScoreDocCollector.this.compoundScores = queues;
            HybridTopScoreDocCollector.this.collectedHitsPerSubQuery = new int[numOfSubQueries];
        }
    }
}

