/*
 * Decompiled with CFR 0.152.
 */
package org.opensearch.search.startree;

import java.io.IOException;
import java.util.ArrayDeque;
import java.util.ArrayList;
import java.util.Collections;
import java.util.HashSet;
import java.util.Iterator;
import java.util.List;
import java.util.Map;
import java.util.Set;
import java.util.stream.Collectors;
import org.apache.logging.log4j.LogManager;
import org.apache.logging.log4j.Logger;
import org.apache.lucene.util.DocIdSetBuilder;
import org.apache.lucene.util.FixedBitSet;
import org.opensearch.index.compositeindex.datacube.Dimension;
import org.opensearch.index.compositeindex.datacube.startree.index.StarTreeValues;
import org.opensearch.index.compositeindex.datacube.startree.node.StarTreeNode;
import org.opensearch.index.compositeindex.datacube.startree.node.StarTreeNodeType;
import org.opensearch.index.compositeindex.datacube.startree.utils.iterator.SortedNumericStarTreeValuesIterator;

/**
 * Evaluates exact-match dimension predicates against a star tree.
 *
 * <p>Matching runs in two phases. {@link #traverseStarTree} walks the tree breadth-first and collects the
 * star-tree document ids that can still satisfy the predicates, together with the predicate columns the walk
 * could not resolve. {@link #getStarTreeResult} then checks those unresolved columns against the per-document
 * dimension values and returns the surviving ids as a bit set.
 */
public class StarTreeFilter {
    private static final Logger logger = LogManager.getLogger(StarTreeFilter.class);

    /**
     * Exhaustion sentinel: the value this class expects the star-tree value iterators to return once exhausted, and
     * the value {@link FixedBitSet#nextSetBit} returns (Lucene's {@code NO_MORE_DOCS}) when no further bit is set.
     */
    private static final int NO_MORE_ENTRIES = Integer.MAX_VALUE;

    /**
     * Returns the star-tree document ids that satisfy every predicate in {@code predicateEvaluators}.
     *
     * @param starTreeValues      star tree to search, including its per-dimension value iterators
     * @param predicateEvaluators dimension name to the value that dimension must equal; {@code null} is treated as
     *                            "no predicates"
     * @return a bit set of length {@code maxMatchedDoc + 1} where bit {@code i} is set when star-tree document
     *         {@code i} matches; a zero-length bit set when the traversal matched nothing
     * @throws IOException if reading the star tree or its dimension values fails
     */
    public static FixedBitSet getStarTreeResult(StarTreeValues starTreeValues, Map<String, Long> predicateEvaluators) throws IOException {
        Map<String, Long> queryMap = predicateEvaluators != null ? predicateEvaluators : Collections.emptyMap();
        StarTreeResult starTreeResult = traverseStarTree(starTreeValues, queryMap);
        int size = starTreeResult.maxMatchedDoc + 1;
        FixedBitSet bitSet = new FixedBitSet(size);
        if (starTreeResult.maxMatchedDoc == -1) {
            // The traversal collected nothing, so there is nothing to post-filter.
            return bitSet;
        }

        SortedNumericStarTreeValuesIterator matchedIterator = new SortedNumericStarTreeValuesIterator(
            starTreeResult.matchedDocIds.build().iterator()
        );
        while (matchedIterator.nextEntry() != NO_MORE_ENTRIES) {
            bitSet.set(matchedIterator.entryId());
        }

        // Scratch set reused for every remaining predicate column instead of allocating one per column.
        FixedBitSet tempBitSet = new FixedBitSet(size);
        for (String remainingPredicateColumn : starTreeResult.remainingPredicateColumns) {
            logger.debug("remainingPredicateColumn : {}, maxMatchedDoc : {} ", remainingPredicateColumn, starTreeResult.maxMatchedDoc);
            SortedNumericStarTreeValuesIterator ndv = (SortedNumericStarTreeValuesIterator) starTreeValues.getDimensionValuesIterator(
                remainingPredicateColumn
            );
            long queryValue = queryMap.get(remainingPredicateColumn);
            retainEntriesWithValue(bitSet, tempBitSet, ndv, queryValue);
        }
        return bitSet;
    }

    /**
     * Clears every bit of {@code candidates} whose entry has no value equal to {@code queryValue} in {@code ndv}.
     * Entries for which {@code ndv} has no value at all are cleared too.
     *
     * @param candidates entry ids to filter, modified in place
     * @param scratch    working set of the same length as {@code candidates}; its previous contents are discarded
     * @param ndv        fresh values iterator for one dimension, consumed forward-only by this call
     * @param queryValue value the dimension must equal
     * @throws IOException if reading dimension values fails
     */
    private static void retainEntriesWithValue(
        FixedBitSet candidates,
        FixedBitSet scratch,
        SortedNumericStarTreeValuesIterator ndv,
        long queryValue
    ) throws IOException {
        scratch.clear(0, scratch.length());
        int length = candidates.length();
        int entryId = length > 0 ? candidates.nextSetBit(0) : NO_MORE_ENTRIES;
        while (entryId != NO_MORE_ENTRIES) {
            // The iterator can already sit past entryId when an earlier advance skipped entries that have no value
            // for this dimension. Only advance when it is behind: Lucene-style iterators leave advance(target)
            // undefined for target <= current position. (Assumes entryId() is -1 before the first advance, as with
            // Lucene's docID(); confirm against StarTreeValuesIterator.)
            int current = ndv.entryId();
            if (current < entryId) {
                current = ndv.advance(entryId);
            }
            if (current == NO_MORE_ENTRIES) {
                // No later entry carries a value for this dimension, so no remaining candidate can match.
                break;
            }
            // current > entryId means this candidate has no value for the dimension; do not read another entry's values.
            if (current == entryId && hasValue(ndv, queryValue)) {
                scratch.set(entryId);
            }
            // nextSetBit requires an index below length, hence the bounds check.
            entryId = entryId + 1 < length ? candidates.nextSetBit(entryId + 1) : NO_MORE_ENTRIES;
        }
        candidates.and(scratch);
    }

    /**
     * Returns whether the entry {@code ndv} is currently positioned on has at least one value equal to
     * {@code queryValue}. Stops reading values at the first match.
     */
    private static boolean hasValue(SortedNumericStarTreeValuesIterator ndv, long queryValue) throws IOException {
        int valuesCount = ndv.entryValueCount();
        for (int i = 0; i < valuesCount; i++) {
            if (ndv.nextValue() == queryValue) {
                return true;
            }
        }
        return false;
    }

    /**
     * Walks the star tree breadth-first and collects candidate star-tree document ids for {@code queryMap}.
     *
     * <p>Children of a node split on dimension {@code dimensionId + 1}. When the first node of a deeper dimension is
     * polled, that dimension's predicate counts as resolved. From each non-leaf node the walk descends as follows:
     * <ul>
     *   <li>if the child dimension has a predicate, it descends only into the child for the queried value;</li>
     *   <li>otherwise it descends into the star child, when one exists and that dimension is not left for
     *       post-filtering;</li>
     *   <li>otherwise it descends into every non-star child.</li>
     * </ul>
     *
     * <p>Once no predicates remain, a node's aggregated doc id is collected. A leaf contributes its whole doc range,
     * which may still violate unresolved predicates. The predicates still unresolved at the level where the first
     * leaf appeared are therefore returned for post-filtering.
     *
     * @param starTreeValues star tree to walk
     * @param queryMap       dimension name to the value that dimension must equal (never {@code null})
     * @return collected ids, the columns that still need post-filtering, the number of ids and the largest id
     *         ({@code -1} when nothing was collected)
     * @throws IOException if reading tree nodes fails
     */
    private static StarTreeResult traverseStarTree(StarTreeValues starTreeValues, Map<String, Long> queryMap) throws IOException {
        DocIdSetBuilder docsWithField = new DocIdSetBuilder(starTreeValues.getStarTreeDocumentCount());
        Set<String> globalRemainingPredicateColumns = null;
        StarTreeNode starTree = starTreeValues.getRoot();
        List<String> dimensionNames = starTreeValues.getStarTreeField()
            .getDimensionsOrder()
            .stream()
            .map(Dimension::getField)
            .collect(Collectors.toList());
        boolean foundLeafNode = starTree.isLeaf();
        // The traversal expects a non-leaf root.
        assert !foundLeafNode;
        ArrayDeque<StarTreeNode> queue = new ArrayDeque<>();
        queue.add(starTree);
        int currentDimensionId = -1;
        Set<String> remainingPredicateColumns = new HashSet<>(queryMap.keySet());
        int matchedDocsCountInStarTree = 0;
        int maxDocNum = -1;
        List<Integer> docIds = new ArrayList<>();

        StarTreeNode starTreeNode;
        while ((starTreeNode = queue.poll()) != null) {
            int dimensionId = starTreeNode.getDimensionId();
            if (dimensionId > currentDimensionId) {
                // First node of a deeper level: the walk has now resolved this dimension.
                remainingPredicateColumns.remove(dimensionNames.get(dimensionId));
                if (foundLeafNode && globalRemainingPredicateColumns == null) {
                    // Leaf ranges reached at this level still need checking on whatever is unresolved here.
                    globalRemainingPredicateColumns = new HashSet<>(remainingPredicateColumns);
                }
                currentDimensionId = dimensionId;
            }

            if (remainingPredicateColumns.isEmpty()) {
                int docId = starTreeNode.getAggregatedDocId();
                docIds.add(docId);
                matchedDocsCountInStarTree++;
                maxDocNum = Math.max(docId, maxDocNum);
                continue;
            }

            if (starTreeNode.isLeaf()) {
                for (long i = starTreeNode.getStartDocId(); i < starTreeNode.getEndDocId(); i++) {
                    docIds.add((int) i);
                    matchedDocsCountInStarTree++;
                    maxDocNum = Math.max((int) i, maxDocNum);
                }
                continue;
            }

            String childDimension = dimensionNames.get(dimensionId + 1);
            StarTreeNode starNode = null;
            // NOTE(review): the star child is skipped for post-filtered columns, presumably because documents under
            // a star node carry no concrete value for that dimension and could never pass the value check.
            if (globalRemainingPredicateColumns == null || !globalRemainingPredicateColumns.contains(childDimension)) {
                starNode = starTreeNode.getChildStarNode();
            }

            if (remainingPredicateColumns.contains(childDimension)) {
                long queryValue = queryMap.get(childDimension);
                StarTreeNode matchingChild = starTreeNode.getChildForDimensionValue(queryValue);
                if (matchingChild != null) {
                    queue.add(matchingChild);
                    foundLeafNode |= matchingChild.isLeaf();
                }
            } else if (starNode != null) {
                queue.add(starNode);
                foundLeafNode |= starNode.isLeaf();
            } else {
                Iterator<? extends StarTreeNode> childrenIterator = starTreeNode.getChildrenIterator();
                while (childrenIterator.hasNext()) {
                    StarTreeNode childNode = childrenIterator.next();
                    if (childNode.getStarTreeNodeType() != StarTreeNodeType.STAR.getValue()) {
                        queue.add(childNode);
                        foundLeafNode |= childNode.isLeaf();
                    }
                }
            }
        }

        DocIdSetBuilder.BulkAdder adder = docsWithField.grow(docIds.size());
        for (int id : docIds) {
            adder.add(id);
        }
        return new StarTreeResult(
            docsWithField,
            globalRemainingPredicateColumns != null ? globalRemainingPredicateColumns : Collections.emptySet(),
            matchedDocsCountInStarTree,
            maxDocNum
        );
    }

    /** Outcome of {@link #traverseStarTree}. */
    private static final class StarTreeResult {
        /** Star-tree document ids collected during traversal. */
        public final DocIdSetBuilder matchedDocIds;
        /** Predicate columns the traversal did not resolve; {@link #getStarTreeResult} checks every collected id against them. */
        public final Set<String> remainingPredicateColumns;
        /** Number of ids added to {@link #matchedDocIds}. */
        public final int numOfMatchedDocs;
        /** Largest collected id, or {@code -1} when nothing was collected. */
        public final int maxMatchedDoc;

        public StarTreeResult(DocIdSetBuilder matchedDocIds, Set<String> remainingPredicateColumns, int numOfMatchedDocs, int maxMatchedDoc) {
            this.matchedDocIds = matchedDocIds;
            this.remainingPredicateColumns = remainingPredicateColumns;
            this.numOfMatchedDocs = numOfMatchedDocs;
            this.maxMatchedDoc = maxMatchedDoc;
        }
    }
}

