/*
 * Decompiled with CFR 0.152.
 */
package com.facebook.presto.orc;

import com.facebook.presto.orc.metadata.ColumnEncoding;
import com.facebook.presto.orc.metadata.DwrfSequenceEncoding;
import com.facebook.presto.orc.metadata.OrcType;
import com.facebook.presto.orc.metadata.Stream;
import com.facebook.presto.orc.proto.DwrfProto;
import com.facebook.presto.orc.stream.StreamDataOutput;
import com.google.common.base.Preconditions;
import it.unimi.dsi.fastutil.ints.Int2IntMap;
import it.unimi.dsi.fastutil.ints.Int2IntOpenHashMap;
import it.unimi.dsi.fastutil.ints.Int2LongMap;
import it.unimi.dsi.fastutil.ints.Int2LongOpenHashMap;
import it.unimi.dsi.fastutil.ints.Int2ObjectMap;
import it.unimi.dsi.fastutil.ints.Int2ObjectOpenHashMap;
import it.unimi.dsi.fastutil.ints.IntArrayList;
import it.unimi.dsi.fastutil.ints.IntIterator;
import it.unimi.dsi.fastutil.ints.IntList;
import it.unimi.dsi.fastutil.objects.Object2LongMap;
import it.unimi.dsi.fastutil.objects.Object2LongOpenHashMap;
import java.util.List;
import java.util.Map;
import java.util.Objects;
import java.util.Set;
import java.util.SortedMap;

/**
 * Accumulates the sizes of written streams per ORC type node and, when map statistics are
 * enabled and at least one flat map node is configured, per key of each flat map node.
 *
 * <p>Sizes accumulate across calls to {@link #collectStreamSizes}. Instances hold mutable state
 * and perform no synchronization.
 */
public class StreamSizeHelper {
    // Sentinel returned by flatMapNodeTrees lookups for nodes outside every flat map value subtree.
    private static final int NOT_IN_FLAT_MAP = -1;

    private final List<OrcType> orcTypes;
    // True only when map statistics are enabled AND there is at least one flat map node.
    private final boolean collectKeyStats;
    private final Set<Integer> flatMapNodes;
    // Maps every node of a flat map's value subtree (the value node included) to that flat map node.
    // Empty when key stats are not collected.
    private final Int2IntMap flatMapNodeTrees;
    // Own (not rolled-up) accumulated stream size of each node, indexed by node id.
    private final long[] nodeSizes;
    // flat map node -> (key -> accumulated size of the value-subtree streams whose sequence maps to the key)
    private final Int2ObjectMap<Object2LongMap<DwrfProto.KeyInfo>> keySizes = new Int2ObjectOpenHashMap<>();

    /**
     * @param orcTypes the ORC type tree; a node id is an index into this list
     * @param flatMapNodes ids of the nodes written as flat maps
     * @param mapStatisticsEnabled whether per-key sizes of flat map nodes should be collected
     * @throws NullPointerException if {@code orcTypes} or {@code flatMapNodes} is null
     * @throws IllegalArgumentException if key stats are collected and a flat map node is not a
     *     MAP type with exactly two sub-fields
     */
    public StreamSizeHelper(List<OrcType> orcTypes, Set<Integer> flatMapNodes, boolean mapStatisticsEnabled) {
        this.orcTypes = Objects.requireNonNull(orcTypes, "orcTypes is null");
        this.flatMapNodes = Objects.requireNonNull(flatMapNodes, "flatMapNodes is null");
        this.collectKeyStats = mapStatisticsEnabled && !flatMapNodes.isEmpty();
        this.nodeSizes = new long[orcTypes.size()];
        // Must run last: it reads orcTypes, flatMapNodes and collectKeyStats.
        this.flatMapNodeTrees = buildFlattenedNodeTrees();
    }

    /**
     * Builds the map from each node in a flat map's value subtree to its owning flat map node.
     * Returns an empty map when key stats are not collected.
     */
    private Int2IntMap buildFlattenedNodeTrees() {
        Int2IntMap flattenedNodeTrees = new Int2IntOpenHashMap();
        if (!collectKeyStats) {
            return flattenedNodeTrees;
        }
        for (int mapNode : flatMapNodes) {
            OrcType mapType = orcTypes.get(mapNode);
            Preconditions.checkArgument(mapType.getOrcTypeKind() == OrcType.OrcTypeKind.MAP,
                    "flat map node %s must be a map, but was %s", mapNode, mapType.getOrcTypeKind());
            Preconditions.checkArgument(mapType.getFieldCount() == 2,
                    "flat map node %s must have exactly 2 sub-fields but had %s", mapNode, mapType.getFieldCount());
            // Sub-field 1 is the map value; only streams under the value subtree are attributed to keys.
            int mapValueNode = mapType.getFieldTypeIndex(1);
            IntList valueTreeNodes = collectDeepTreeNodes(orcTypes, mapValueNode);
            for (int i = 0; i < valueTreeNodes.size(); i++) {
                flattenedNodeTrees.put(valueTreeNodes.getInt(i), mapNode);
            }
        }
        return flattenedNodeTrees;
    }

    /**
     * Adds the given streams to the accumulated sizes.
     *
     * <p>Every stream's {@code size()} is added to the own size of its column node. When key stats
     * are collected, the streams are iterated a second time, so {@code streamDataOutputs} must be
     * re-iterable. Each flat map's stream lengths are then attributed to keys through the
     * additional sequence encodings of its value node's column encoding. A key whose sequence has
     * no streams in this batch is still recorded, with 0 added.
     *
     * @param streamDataOutputs the streams to account for; elements must be non-null
     * @param columnEncodings column encodings by node id; consulted only when key stats are collected
     * @throws NullPointerException if an element of {@code streamDataOutputs} is null
     * @throws IllegalArgumentException if a flat map value node that has streams lacks a column
     *     encoding, or its encoding has no additional sequence encodings
     */
    public void collectStreamSizes(Iterable<StreamDataOutput> streamDataOutputs, Map<Integer, ColumnEncoding> columnEncodings) {
        collectNodeSizes(streamDataOutputs);
        if (collectKeyStats) {
            collectKeySizes(streamDataOutputs, columnEncodings);
        }
    }

    // Adds each stream's size to the own size of the column node the stream belongs to.
    private void collectNodeSizes(Iterable<StreamDataOutput> streamDataOutputs) {
        for (StreamDataOutput streamDataOutput : streamDataOutputs) {
            Objects.requireNonNull(streamDataOutput, "streamDataOutput is null");
            nodeSizes[streamDataOutput.getStream().getColumn()] += streamDataOutput.size();
        }
    }

    // Sums stream lengths per (flat map node, sequence), then attributes each flat map's sums to its keys.
    private void collectKeySizes(Iterable<StreamDataOutput> streamDataOutputs, Map<Integer, ColumnEncoding> columnEncodings) {
        Int2ObjectMap<Int2LongMap> flatMapNodeSizes = new Int2ObjectOpenHashMap<>();
        for (StreamDataOutput streamDataOutput : streamDataOutputs) {
            Stream stream = streamDataOutput.getStream();
            int flatMapNode = flatMapNodeTrees.getOrDefault(stream.getColumn(), NOT_IN_FLAT_MAP);
            if (flatMapNode == NOT_IN_FLAT_MAP) {
                continue;
            }
            // Use a lambda, not Int2LongOpenHashMap::new: the method reference would receive the node id
            // as the constructor's "expected size" argument and oversize the map.
            Int2LongMap sequenceToSize = flatMapNodeSizes.computeIfAbsent(flatMapNode, ignore -> new Int2LongOpenHashMap());
            // NOTE(review): key sizes use stream.getLength() while node sizes use streamDataOutput.size();
            // presumably the same byte count -- confirm.
            sequenceToSize.mergeLong(stream.getSequence(), (long) stream.getLength(), Long::sum);
        }

        IntIterator flatMapNodeIterator = flatMapNodeSizes.keySet().iterator();
        while (flatMapNodeIterator.hasNext()) {
            int flatMapNode = flatMapNodeIterator.nextInt();
            addKeySizes(flatMapNode, flatMapNodeSizes.get(flatMapNode), columnEncodings);
        }
    }

    // Adds each key's size (the summed length of its sequence's streams, or 0) to keySizes for flatMapNode.
    private void addKeySizes(int flatMapNode, Int2LongMap sequenceToSize, Map<Integer, ColumnEncoding> columnEncodings) {
        int flatMapValueNode = orcTypes.get(flatMapNode).getFieldTypeIndex(1);
        ColumnEncoding columnEncoding = columnEncodings.get(flatMapValueNode);
        Preconditions.checkArgument(columnEncoding != null, "columnEncoding for flat map node %s is null", flatMapNode);
        Preconditions.checkArgument(columnEncoding.getAdditionalSequenceEncodings().isPresent(),
                "columnEncoding for flat map node %s does not have keys", flatMapNode);
        SortedMap<Integer, DwrfSequenceEncoding> sequenceToKey = columnEncoding.getAdditionalSequenceEncodings().get();

        Object2LongMap<DwrfProto.KeyInfo> keyToSize = keySizes.computeIfAbsent(flatMapNode, ignore -> new Object2LongOpenHashMap<>());
        for (Map.Entry<Integer, DwrfSequenceEncoding> entry : sequenceToKey.entrySet()) {
            // Explicit int local selects the primitive getOrDefault(int, long) overload.
            int sequence = entry.getKey();
            DwrfProto.KeyInfo key = entry.getValue().getKey();
            long size = sequenceToSize.getOrDefault(sequence, 0L);
            keyToSize.mergeLong(key, size, Long::sum);
        }
    }

    /**
     * Returns the accumulated per-key sizes, keyed by flat map node.
     *
     * <p>The returned map is this helper's internal, live map. Later calls to
     * {@link #collectStreamSizes} keep updating it.
     */
    public Int2ObjectMap<Object2LongMap<DwrfProto.KeyInfo>> getMapKeySizes() {
        return keySizes;
    }

    /**
     * Returns rolled-up sizes. Each node's entry is its own stream size plus the rolled-up sizes
     * of its sub-fields.
     *
     * <p>The computation starts at node 0 (presumably the root type). Only nodes reachable from
     * node 0 through field type indexes appear in the result. A fresh map is returned on each call.
     */
    public Int2LongMap getNodeSizes() {
        Int2LongMap result = new Int2LongOpenHashMap(nodeSizes.length);
        rollupNodeSizes(result, 0);
        return result;
    }

    // Recursively records the subtree size of node (and of each descendant) in result, and returns it.
    private long rollupNodeSizes(Int2LongMap result, int node) {
        long size = nodeSizes[node];
        for (int subNode : orcTypes.get(node).getFieldTypeIndexes()) {
            size += rollupNodeSizes(result, subNode);
        }
        result.put(node, size);
        return size;
    }

    // Returns startNode followed by all of its descendants, in breadth-first order.
    private static IntList collectDeepTreeNodes(List<OrcType> orcTypes, int startNode) {
        IntList result = new IntArrayList();
        result.add(startNode);
        // result doubles as the BFS queue: children are appended while the list is being scanned.
        for (int i = 0; i < result.size(); i++) {
            OrcType orcType = orcTypes.get(result.getInt(i));
            for (int field = 0; field < orcType.getFieldCount(); field++) {
                result.add(orcType.getFieldTypeIndex(field));
            }
        }
        return result;
    }
}

