/*
 * Decompiled with CFR 0.152.
 */
package org.opensearch.neuralsearch.sparse.algorithm;

import java.io.IOException;
import java.util.ArrayList;
import java.util.Collections;
import java.util.HashMap;
import java.util.List;
import java.util.PriorityQueue;
import java.util.stream.Collectors;
import org.apache.commons.collections4.CollectionUtils;
import org.opensearch.neuralsearch.sparse.accessor.SparseVectorReader;
import org.opensearch.neuralsearch.sparse.common.DocWeightIterator;
import org.opensearch.neuralsearch.sparse.common.IteratorWrapper;
import org.opensearch.neuralsearch.sparse.data.DocWeight;
import org.opensearch.neuralsearch.sparse.data.DocumentCluster;
import org.opensearch.neuralsearch.sparse.data.SparseVector;
import org.opensearch.neuralsearch.sparse.quantization.ByteQuantizationUtil;

/**
 * Static helpers for trimming posting lists and for building the summary vector of a
 * {@link DocumentCluster}.
 */
public class PostingsProcessingUtils {
    /**
     * Value of {@code DocWeightIterator.nextDoc()} at which cluster iteration stops.
     * NOTE(review): presumably the same sentinel as Lucene's DocIdSetIterator.NO_MORE_DOCS — confirm.
     */
    private static final int NO_MORE_DOCS = Integer.MAX_VALUE;

    /**
     * Selects the {@code K} postings with the largest weights.
     *
     * <p>Weights are compared with {@link ByteQuantizationUtil#compareUnsignedByte}. The result is
     * NOT sorted: it is returned in the heap's internal order.
     *
     * @param postings candidate postings; may be {@code null} or empty
     * @param K        maximum number of postings to keep
     * @return an empty list when {@code postings} is null/empty or {@code K <= 0}; the
     *         {@code postings} instance itself (not a copy) when {@code K >= postings.size()};
     *         otherwise a new list holding {@code K} postings
     */
    public static List<DocWeight> getTopK(List<DocWeight> postings, int K) {
        // K <= 0 selects nothing. A negative K must not reach the PriorityQueue constructor,
        // which rejects an initial capacity below 1 with IllegalArgumentException.
        if (CollectionUtils.isEmpty(postings) || K <= 0) {
            return Collections.emptyList();
        }
        if (K >= postings.size()) {
            return postings;
        }
        // Bounded heap: whenever it grows past K entries, the head (the entry the comparator
        // ranks lowest) is evicted, so the K highest-ranked postings remain at the end.
        PriorityQueue<DocWeight> pq = new PriorityQueue<>(
            K,
            (o1, o2) -> ByteQuantizationUtil.compareUnsignedByte(o1.getWeight(), o2.getWeight())
        );
        for (DocWeight docWeight : postings) {
            pq.add(docWeight);
            if (pq.size() > K) {
                pq.poll();
            }
        }
        return new ArrayList<>(pq);
    }

    /**
     * Builds a pruned summary vector for {@code cluster} and stores it via
     * {@link DocumentCluster#setSummary}.
     *
     * <p>For every document yielded by the cluster's iterator, the document's vector is read and,
     * per token, the maximum integer weight across all documents is retained. The resulting items
     * are sorted with the {@code compareUnsignedByte} arguments swapped (presumably descending by
     * unsigned weight — confirm against that helper) and then truncated to the shortest prefix
     * whose cumulative weight exceeds {@code floor(totalWeight * summaryPruneRatio)}. The item
     * that crosses the threshold is kept; if the threshold is never exceeded, all items are kept.
     *
     * @param cluster           cluster whose documents are summarized and which receives the summary
     * @param reader            source of per-document vectors; documents for which it returns
     *                          {@code null} are skipped
     * @param summaryPruneRatio fraction of the total summary weight used as the pruning threshold
     * @throws IOException if iterating the cluster or reading a vector fails
     */
    public static void summarize(DocumentCluster cluster, SparseVectorReader reader, float summaryPruneRatio) throws IOException {
        // token -> largest integer weight seen for that token in any document of the cluster
        HashMap<Integer, Integer> summary = new HashMap<>();
        DocWeightIterator iterator = cluster.getDisi();
        while (iterator.nextDoc() != NO_MORE_DOCS) {
            SparseVector vector = reader.read(iterator.docID());
            if (vector == null) {
                continue;
            }
            IteratorWrapper<SparseVector.Item> vectorIterator = vector.iterator();
            while (vectorIterator.hasNext()) {
                SparseVector.Item item = vectorIterator.next();
                // Single lookup: insert on first sight, otherwise keep the larger weight.
                summary.merge(item.getToken(), item.getIntWeight(), Integer::max);
            }
        }

        List<SparseVector.Item> items = summary.entrySet()
            .stream()
            // The int weight is narrowed back to the byte representation Item is constructed with.
            .map(entry -> new SparseVector.Item(entry.getKey(), (byte) entry.getValue().intValue()))
            .sorted((o1, o2) -> ByteQuantizationUtil.compareUnsignedByte(o2.getWeight(), o1.getWeight()))
            .collect(Collectors.toList());

        double totalWeight = items.stream().mapToDouble(SparseVector.Item::getIntWeight).sum();
        int weightThreshold = (int) Math.floor(totalWeight * (double) summaryPruneRatio);

        // Count how many leading items are needed for the running sum to exceed the threshold.
        int weightSum = 0;
        int idx = 0;
        for (SparseVector.Item item : items) {
            ++idx;
            weightSum += item.getIntWeight();
            if (weightSum > weightThreshold) {
                break;
            }
        }
        cluster.setSummary(new SparseVector(items.subList(0, idx)));
    }
}

