/*
 * Decompiled with CFR 0.152.
 */
package org.opensearch.neuralsearch.processor.collapse;

import java.util.ArrayList;
import java.util.Comparator;
import java.util.HashMap;
import java.util.List;
import java.util.Locale;
import java.util.Map;
import lombok.Generated;
import org.apache.logging.log4j.LogManager;
import org.apache.logging.log4j.Logger;
import org.apache.lucene.search.FieldDoc;
import org.apache.lucene.search.ScoreDoc;
import org.apache.lucene.search.grouping.CollapseTopFieldDocs;
import org.apache.lucene.util.BytesRef;
import org.opensearch.neuralsearch.processor.CompoundTopDocs;
import org.opensearch.neuralsearch.processor.collapse.CollapseDTO;
import org.opensearch.neuralsearch.search.query.HybridQueryFieldDocComparator;

/**
 * Collects, for every distinct collapse value, the {@link FieldDoc} that sorts first under the
 * collapse comparator across all shards, together with the index of the shard it came from.
 *
 * <p>State is kept in plain {@link HashMap}s with no synchronization.
 *
 * @param <T> type of the collapse value used as map key; type enforcement is only applied when the
 *            first collapse value seen at construction time is a {@link BytesRef} or a {@link Long}
 */
public class CollapseDataCollector<T> {
    @Generated
    private static final Logger log = LogManager.getLogger(CollapseDataCollector.class);
    /** Best document seen so far for each collapse value. */
    private final Map<T, FieldDoc> collapseValueToTopDocMap = new HashMap<T, FieldDoc>();
    /** Index of the shard that produced the best document for each collapse value. */
    private final Map<T, Integer> collapseValueToShardMap = new HashMap<T, Integer>();
    private final HybridQueryFieldDocComparator collapseComparator;
    /** Expected runtime type of collapse values, or {@code null} when no type is enforced. */
    private final Class<T> expectedType;
    /** Collapse field name; stays {@code null} until a non-empty shard has been collected. */
    private String collapseField;

    /**
     * Builds the comparator from the sort fields of the first non-empty shard result and derives
     * the expected collapse value type from that same result.
     *
     * @param collapseDTO holder of the per-shard collapse results and the index of the first
     *                    non-empty one
     */
    public CollapseDataCollector(CollapseDTO collapseDTO) {
        // The score-based comparator is handed to HybridQueryFieldDocComparator as its second
        // argument - presumably used as a tie-breaker; confirm against that class.
        this.collapseComparator = new HybridQueryFieldDocComparator(
            firstNonEmptyCollapseTopDocs(collapseDTO).fields,
            Comparator.comparing(scoreDoc -> Float.valueOf(scoreDoc.score))
        );
        this.expectedType = this.determineExpectedType(collapseDTO);
    }

    /**
     * Returns the first top docs entry of the shard at {@code getIndexOfFirstNonEmpty()}, cast to
     * {@link CollapseTopFieldDocs}.
     */
    private static CollapseTopFieldDocs firstNonEmptyCollapseTopDocs(CollapseDTO collapseDTO) {
        return (CollapseTopFieldDocs) collapseDTO.getCollapseQueryTopDocs()
            .get((int) collapseDTO.getIndexOfFirstNonEmpty())
            .getTopDocs()
            .getFirst();
    }

    /**
     * Inspects the first collapse value of the first non-empty shard result to decide which
     * runtime type later collapse values must have.
     *
     * @return {@code BytesRef.class} or {@code Long.class}, or {@code null} when the first value is
     *         of any other type (or there is no value), in which case no type check is performed
     */
    @SuppressWarnings("unchecked") // T is only ever bound to the class returned here
    private Class<T> determineExpectedType(CollapseDTO collapseDTO) {
        Object[] collapseValues = firstNonEmptyCollapseTopDocs(collapseDTO).collapseValues;
        // Without a sample value there is nothing to infer; fall back to "no enforcement".
        if (collapseValues == null || collapseValues.length == 0) {
            return null;
        }
        Object firstCollapseValue = collapseValues[0];
        if (firstCollapseValue instanceof BytesRef) {
            return (Class<T>) BytesRef.class;
        }
        if (firstCollapseValue instanceof Long) {
            return (Class<T>) Long.class;
        }
        return null;
    }

    /**
     * Walks every shard's collapse results and records, per collapse value, the best document and
     * the shard it came from. Shards without score docs are skipped.
     *
     * @param collapseDTO holder of the per-shard collapse results
     * @throws IllegalArgumentException if a shard's first top docs entry is not a
     *                                  {@link CollapseTopFieldDocs}, a score doc is not a
     *                                  {@link FieldDoc}, or a collapse value has an unexpected type
     */
    public void collectCollapseData(CollapseDTO collapseDTO) {
        // NOTE(review): the bound comes from getCollapseQuerySearchResults() while elements are read
        // from getCollapseQueryTopDocs(); this assumes both lists are index-aligned - confirm.
        for (int shardIndex = 0; shardIndex < collapseDTO.getCollapseQuerySearchResults().size(); ++shardIndex) {
            CompoundTopDocs shardTopDocs = collapseDTO.getCollapseQueryTopDocs().get(shardIndex);
            List<ScoreDoc> shardScoreDocs = shardTopDocs.getScoreDocs();
            if (shardScoreDocs.isEmpty()) {
                continue;
            }
            Object firstTopDocs = shardTopDocs.getTopDocs().getFirst();
            if (!(firstTopDocs instanceof CollapseTopFieldDocs collapseTopFieldDocs)) {
                throw new IllegalArgumentException(
                    String.format(Locale.ROOT, "Expected CollapseTopFieldDocs but got: %s", firstTopDocs.getClass().getSimpleName())
                );
            }
            this.collapseField = collapseTopFieldDocs.field;
            for (ScoreDoc scoreDoc : shardScoreDocs) {
                try {
                    this.processCollapseDoc(scoreDoc, shardIndex);
                } catch (ClassCastException | IllegalArgumentException e) {
                    // Pass the throwable so the stack trace is logged, not just the message.
                    log.error(String.format(Locale.ROOT, "Error processing collapse doc in shard %d: %s", shardIndex, e.getMessage()), e);
                    throw e;
                }
            }
        }
    }

    /**
     * Records {@code scoreDoc} as the best document for its collapse value when no document is
     * stored yet for that value or when it sorts before the stored one. The collapse value is read
     * from the last element of the doc's {@code fields} array; docs without fields are ignored.
     *
     * @throws IllegalArgumentException if the doc is not a {@link FieldDoc} or its collapse value is
     *                                  not of the expected type
     */
    private void processCollapseDoc(ScoreDoc scoreDoc, int shardIndex) {
        if (!(scoreDoc instanceof FieldDoc fieldDoc)) {
            throw new IllegalArgumentException(
                String.format(Locale.ROOT, "Expected FieldDoc but got: %s", scoreDoc.getClass().getSimpleName())
            );
        }
        if (fieldDoc.fields == null || fieldDoc.fields.length == 0) {
            log.info("Field doc 'fields' attribute does not contain any values");
            return;
        }
        Object collapseValueObj = fieldDoc.fields[fieldDoc.fields.length - 1];
        // A null collapse value is allowed and ends up as the map's null key.
        if (collapseValueObj != null && this.expectedType != null && !this.expectedType.isInstance(collapseValueObj)) {
            throw new IllegalArgumentException(
                String.format(
                    Locale.ROOT,
                    "Expected collapse value of type %s but got: %s",
                    this.expectedType.getSimpleName(),
                    collapseValueObj.getClass().getSimpleName()
                )
            );
        }
        FieldDoc currentBestFieldDoc = this.collapseValueToTopDocMap.get(collapseValueObj);
        if (currentBestFieldDoc != null && this.collapseComparator.compare(fieldDoc, currentBestFieldDoc) >= 0) {
            return;
        }
        // Store a deep copy of BytesRef values so the map key owns its own bytes.
        Object keyObj = collapseValueObj instanceof BytesRef bytesRef ? BytesRef.deepCopyOf(bytesRef) : collapseValueObj;
        @SuppressWarnings("unchecked") // verified against expectedType above when a type is enforced
        T key = (T) keyObj;
        this.collapseValueToTopDocMap.put(key, fieldDoc);
        this.collapseValueToShardMap.put(key, shardIndex);
    }

    /**
     * Returns the collected (collapse value, best document) entries ordered by the collapse
     * comparator applied to the documents.
     *
     * @return a new list of the map's entries, sorted by value
     */
    public List<Map.Entry<T, FieldDoc>> getSortedCollapseEntries() {
        List<Map.Entry<T, FieldDoc>> collapseEntryList = new ArrayList<Map.Entry<T, FieldDoc>>(this.collapseValueToTopDocMap.entrySet());
        collapseEntryList.sort(Map.Entry.comparingByValue(this.collapseComparator));
        return collapseEntryList;
    }

    /**
     * Looks up the shard that produced the best document for the given collapse value.
     *
     * @param key collapse value; may be {@code null}
     * @return the shard index, or {@code null} if nothing was recorded for {@code key}
     * @throws IllegalArgumentException if {@code key} is non-null and not of the expected type
     */
    public Integer getCollapseShardIndex(T key) {
        if (key != null && this.expectedType != null && !this.expectedType.isInstance(key)) {
            throw new IllegalArgumentException(
                String.format(Locale.ROOT, "Expected key of type %s but got: %s", this.expectedType.getSimpleName(), key.getClass().getSimpleName())
            );
        }
        return this.collapseValueToShardMap.get(key);
    }

    /**
     * Returns the collapse field name taken from the most recently collected non-empty shard.
     *
     * @return the field name, or {@code null} if no non-empty shard has been collected yet
     */
    @Generated
    public String getCollapseField() {
        return this.collapseField;
    }
}

