/*
 * Decompiled with CFR 0.152.
 */
package org.apache.lucene.search;

import java.util.ArrayList;
import java.util.Comparator;
import java.util.HashMap;
import java.util.List;
import java.util.Map;
import org.apache.lucene.search.FieldComparator;
import org.apache.lucene.search.FieldDoc;
import org.apache.lucene.search.Pruning;
import org.apache.lucene.search.ScoreDoc;
import org.apache.lucene.search.Sort;
import org.apache.lucene.search.SortField;
import org.apache.lucene.search.TopFieldDocs;
import org.apache.lucene.search.TotalHits;
import org.apache.lucene.util.PriorityQueue;

public class TopDocs {
    public TotalHits totalHits;
    public ScoreDoc[] scoreDocs;
    private static final Comparator<ScoreDoc> SHARD_INDEX_TIE_BREAKER = Comparator.comparingInt(d -> d.shardIndex);
    private static final Comparator<ScoreDoc> DOC_ID_TIE_BREAKER = Comparator.comparingInt(d -> d.doc);
    private static final Comparator<ScoreDoc> DEFAULT_TIE_BREAKER = SHARD_INDEX_TIE_BREAKER.thenComparing(DOC_ID_TIE_BREAKER);

    public TopDocs(TotalHits totalHits, ScoreDoc[] scoreDocs) {
        this.totalHits = totalHits;
        this.scoreDocs = scoreDocs;
    }

    static boolean tieBreakLessThan(ShardRef first, ScoreDoc firstDoc, ShardRef second, ScoreDoc secondDoc, Comparator<ScoreDoc> tieBreaker) {
        assert (tieBreaker != null);
        int value = tieBreaker.compare(firstDoc, secondDoc);
        if (value == 0) {
            assert (first.hitIndex != second.hitIndex);
            return first.hitIndex < second.hitIndex;
        }
        return value < 0;
    }

    public static TopDocs merge(int topN, TopDocs[] shardHits) {
        return TopDocs.merge(0, topN, shardHits);
    }

    public static TopDocs merge(int start, int topN, TopDocs[] shardHits) {
        return TopDocs.mergeAux(null, start, topN, shardHits, DEFAULT_TIE_BREAKER);
    }

    public static TopDocs merge(int start, int topN, TopDocs[] shardHits, Comparator<ScoreDoc> tieBreaker) {
        return TopDocs.mergeAux(null, start, topN, shardHits, tieBreaker);
    }

    public static TopFieldDocs merge(Sort sort, int topN, TopFieldDocs[] shardHits) {
        return TopDocs.merge(sort, 0, topN, shardHits);
    }

    public static TopFieldDocs merge(Sort sort, int start, int topN, TopFieldDocs[] shardHits) {
        if (sort == null) {
            throw new IllegalArgumentException("sort must be non-null when merging field-docs");
        }
        return (TopFieldDocs)TopDocs.mergeAux(sort, start, topN, shardHits, DEFAULT_TIE_BREAKER);
    }

    public static TopFieldDocs merge(Sort sort, int start, int topN, TopFieldDocs[] shardHits, Comparator<ScoreDoc> tieBreaker) {
        if (sort == null) {
            throw new IllegalArgumentException("sort must be non-null when merging field-docs");
        }
        return (TopFieldDocs)TopDocs.mergeAux(sort, start, topN, shardHits, tieBreaker);
    }

    private static TopDocs mergeAux(Sort sort, int start, int size, TopDocs[] shardHits, Comparator<ScoreDoc> tieBreaker) {
        ScoreDoc[] hits;
        PriorityQueue queue = sort == null ? new ScoreMergeSortQueue(shardHits, tieBreaker) : new MergeSortQueue(sort, shardHits, tieBreaker);
        long totalHitCount = 0L;
        TotalHits.Relation totalHitsRelation = TotalHits.Relation.EQUAL_TO;
        int availHitCount = 0;
        for (int shardIDX = 0; shardIDX < shardHits.length; ++shardIDX) {
            TopDocs shard = shardHits[shardIDX];
            totalHitCount += shard.totalHits.value();
            if (shard.totalHits.relation() == TotalHits.Relation.GREATER_THAN_OR_EQUAL_TO) {
                totalHitsRelation = TotalHits.Relation.GREATER_THAN_OR_EQUAL_TO;
            }
            if (shard.scoreDocs == null || shard.scoreDocs.length <= 0) continue;
            availHitCount += shard.scoreDocs.length;
            queue.add(new ShardRef(shardIDX));
        }
        boolean unsetShardIndex = false;
        if (availHitCount <= start) {
            hits = new ScoreDoc[]{};
        } else {
            hits = new ScoreDoc[Math.min(size, availHitCount - start)];
            int requestedResultWindow = start + size;
            int numIterOnHits = Math.min(availHitCount, requestedResultWindow);
            for (int hitUpto = 0; hitUpto < numIterOnHits; ++hitUpto) {
                assert (queue.size() > 0);
                ShardRef ref = (ShardRef)queue.top();
                ScoreDoc hit = shardHits[ref.shardIndex].scoreDocs[ref.hitIndex++];
                if (hitUpto > 0 && unsetShardIndex != (hit.shardIndex == -1)) {
                    throw new IllegalArgumentException("Inconsistent order of shard indices");
                }
                unsetShardIndex |= hit.shardIndex == -1;
                if (hitUpto >= start) {
                    hits[hitUpto - start] = hit;
                }
                if (ref.hitIndex < shardHits[ref.shardIndex].scoreDocs.length) {
                    queue.updateTop();
                    continue;
                }
                queue.pop();
            }
        }
        TotalHits totalHits = new TotalHits(totalHitCount, totalHitsRelation);
        if (sort == null) {
            return new TopDocs(totalHits, hits);
        }
        return new TopFieldDocs(totalHits, hits, sort.getSort());
    }

    public static TopDocs rrf(int topN, int k, TopDocs[] hits) {
        if (topN < 1) {
            throw new IllegalArgumentException("topN must be >= 1, got " + topN);
        }
        if (k < 1) {
            throw new IllegalArgumentException("k must be >= 1, got " + k);
        }
        Boolean shardIndexSet = null;
        for (TopDocs topDocs : hits) {
            ScoreDoc[] scoreDocArray = topDocs.scoreDocs;
            int n = scoreDocArray.length;
            for (int i = 0; i < n; ++i) {
                boolean thisShardIndexSet;
                ScoreDoc scoreDoc = scoreDocArray[i];
                boolean bl = thisShardIndexSet = scoreDoc.shardIndex != -1;
                if (shardIndexSet == null) {
                    shardIndexSet = thisShardIndexSet;
                    continue;
                }
                if (shardIndexSet == thisShardIndexSet) continue;
                throw new IllegalArgumentException("All hits must either have their ScoreDoc#shardIndex set, or unset (-1), not a mix of both.");
            }
        }
        HashMap<ShardIndexAndDoc, Double> rrfScore = new HashMap<ShardIndexAndDoc, Double>();
        long totalHitCount = 0L;
        for (TopDocs topDoc : hits) {
            totalHitCount = Math.max(totalHitCount, topDoc.totalHits.value());
            for (int i = 0; i < topDoc.scoreDocs.length; ++i) {
                ScoreDoc scoreDoc = topDoc.scoreDocs[i];
                int rank = i + 1;
                double rrfScoreContribution = 1.0 / (double)Math.addExact(k, rank);
                rrfScore.compute(new ShardIndexAndDoc(scoreDoc.shardIndex, scoreDoc.doc), (key2, score) -> (score == null ? 0.0 : score) + rrfScoreContribution);
            }
        }
        ArrayList arrayList = new ArrayList(rrfScore.entrySet());
        arrayList.sort(Map.Entry.comparingByValue().reversed().thenComparing(Map.Entry.comparingByKey(Comparator.comparingInt(ShardIndexAndDoc::doc))).thenComparing(Map.Entry.comparingByKey(Comparator.comparingInt(ShardIndexAndDoc::shardIndex))));
        ScoreDoc[] rrfScoreDocs = new ScoreDoc[Math.min(topN, arrayList.size())];
        for (int i = 0; i < rrfScoreDocs.length; ++i) {
            Map.Entry entry = (Map.Entry)arrayList.get(i);
            int doc = ((ShardIndexAndDoc)entry.getKey()).doc;
            int shardIndex = ((ShardIndexAndDoc)entry.getKey()).shardIndex();
            float score2 = ((Double)entry.getValue()).floatValue();
            rrfScoreDocs[i] = new ScoreDoc(doc, score2, shardIndex);
        }
        TotalHits totalHits = new TotalHits(totalHitCount, TotalHits.Relation.GREATER_THAN_OR_EQUAL_TO);
        return new TopDocs(totalHits, rrfScoreDocs);
    }

    private static final class ShardRef {
        final int shardIndex;
        int hitIndex;

        ShardRef(int shardIndex) {
            this.shardIndex = shardIndex;
        }

        public String toString() {
            return "ShardRef(shardIndex=" + this.shardIndex + " hitIndex=" + this.hitIndex + ")";
        }
    }

    private static class ScoreMergeSortQueue
    extends PriorityQueue<ShardRef> {
        final ScoreDoc[][] shardHits;
        final Comparator<ScoreDoc> tieBreakerComparator;

        public ScoreMergeSortQueue(TopDocs[] shardHits, Comparator<ScoreDoc> tieBreakerComparator) {
            super(shardHits.length);
            this.shardHits = new ScoreDoc[shardHits.length][];
            for (int shardIDX = 0; shardIDX < shardHits.length; ++shardIDX) {
                this.shardHits[shardIDX] = shardHits[shardIDX].scoreDocs;
            }
            this.tieBreakerComparator = tieBreakerComparator;
        }

        @Override
        public boolean lessThan(ShardRef first, ShardRef second) {
            assert (first != second);
            ScoreDoc firstScoreDoc = this.shardHits[first.shardIndex][first.hitIndex];
            ScoreDoc secondScoreDoc = this.shardHits[second.shardIndex][second.hitIndex];
            if (firstScoreDoc.score < secondScoreDoc.score) {
                return false;
            }
            if (firstScoreDoc.score > secondScoreDoc.score) {
                return true;
            }
            return TopDocs.tieBreakLessThan(first, firstScoreDoc, second, secondScoreDoc, this.tieBreakerComparator);
        }
    }

    private static class MergeSortQueue
    extends PriorityQueue<ShardRef> {
        final ScoreDoc[][] shardHits;
        final FieldComparator<?>[] comparators;
        final int[] reverseMul;
        final Comparator<ScoreDoc> tieBreaker;

        public MergeSortQueue(Sort sort, TopDocs[] shardHits, Comparator<ScoreDoc> tieBreaker) {
            super(shardHits.length);
            this.shardHits = new ScoreDoc[shardHits.length][];
            this.tieBreaker = tieBreaker;
            for (int shardIDX = 0; shardIDX < shardHits.length; ++shardIDX) {
                ScoreDoc[] shard = shardHits[shardIDX].scoreDocs;
                if (shard == null) continue;
                this.shardHits[shardIDX] = shard;
                for (int hitIDX = 0; hitIDX < shard.length; ++hitIDX) {
                    ScoreDoc sd = shard[hitIDX];
                    if (!(sd instanceof FieldDoc)) {
                        throw new IllegalArgumentException("shard " + shardIDX + " was not sorted by the provided Sort (expected FieldDoc but got ScoreDoc)");
                    }
                    FieldDoc fd = (FieldDoc)sd;
                    if (fd.fields != null) continue;
                    throw new IllegalArgumentException("shard " + shardIDX + " did not set sort field values (FieldDoc.fields is null)");
                }
            }
            SortField[] sortFields = sort.getSort();
            this.comparators = new FieldComparator[sortFields.length];
            this.reverseMul = new int[sortFields.length];
            for (int compIDX = 0; compIDX < sortFields.length; ++compIDX) {
                SortField sortField = sortFields[compIDX];
                this.comparators[compIDX] = sortField.getComparator(1, Pruning.NONE);
                this.reverseMul[compIDX] = sortField.getReverse() ? -1 : 1;
            }
        }

        @Override
        public boolean lessThan(ShardRef first, ShardRef second) {
            assert (first != second);
            FieldDoc firstFD = (FieldDoc)this.shardHits[first.shardIndex][first.hitIndex];
            FieldDoc secondFD = (FieldDoc)this.shardHits[second.shardIndex][second.hitIndex];
            for (int compIDX = 0; compIDX < this.comparators.length; ++compIDX) {
                FieldComparator<?> comp = this.comparators[compIDX];
                int cmp = this.reverseMul[compIDX] * comp.compareValues(firstFD.fields[compIDX], secondFD.fields[compIDX]);
                if (cmp == 0) continue;
                return cmp < 0;
            }
            return TopDocs.tieBreakLessThan(first, firstFD, second, secondFD, this.tieBreaker);
        }
    }

    private record ShardIndexAndDoc(int shardIndex, int doc) {
    }
}

