/*
 * Decompiled with CFR 0.152.
 */
package io.trino.orc;

import com.google.common.base.Preconditions;
import com.google.common.base.Predicate;
import com.google.common.base.Predicates;
import com.google.common.collect.ImmutableList;
import com.google.common.collect.ImmutableMap;
import com.google.common.collect.ImmutableSet;
import com.google.common.collect.Maps;
import io.airlift.slice.Slice;
import io.trino.memory.context.AggregatedMemoryContext;
import io.trino.orc.DiskRange;
import io.trino.orc.OrcColumn;
import io.trino.orc.OrcDataSource;
import io.trino.orc.OrcDecompressor;
import io.trino.orc.OrcPredicate;
import io.trino.orc.OrcWriteValidation;
import io.trino.orc.RowGroup;
import io.trino.orc.StreamId;
import io.trino.orc.Stripe;
import io.trino.orc.checkpoint.Checkpoints;
import io.trino.orc.checkpoint.InvalidCheckpointException;
import io.trino.orc.checkpoint.StreamCheckpoint;
import io.trino.orc.metadata.ColumnEncoding;
import io.trino.orc.metadata.ColumnMetadata;
import io.trino.orc.metadata.MetadataReader;
import io.trino.orc.metadata.OrcColumnId;
import io.trino.orc.metadata.OrcType;
import io.trino.orc.metadata.PostScript;
import io.trino.orc.metadata.RowGroupIndex;
import io.trino.orc.metadata.Stream;
import io.trino.orc.metadata.StripeFooter;
import io.trino.orc.metadata.StripeInformation;
import io.trino.orc.metadata.statistics.BloomFilter;
import io.trino.orc.metadata.statistics.ColumnStatistics;
import io.trino.orc.stream.CheckpointInputStreamSource;
import io.trino.orc.stream.InputStreamSource;
import io.trino.orc.stream.InputStreamSources;
import io.trino.orc.stream.OrcChunkLoader;
import io.trino.orc.stream.OrcInputStream;
import io.trino.orc.stream.ValueInputStream;
import io.trino.orc.stream.ValueInputStreamSource;
import io.trino.orc.stream.ValueStreams;
import java.io.IOException;
import java.time.ZoneId;
import java.util.ArrayList;
import java.util.Collection;
import java.util.HashMap;
import java.util.LinkedHashSet;
import java.util.List;
import java.util.Map;
import java.util.Objects;
import java.util.Optional;
import java.util.OptionalInt;
import java.util.Set;

/**
 * Reads one stripe of an ORC file and assembles it into a {@link Stripe}.
 *
 * <p>For a stripe, the reader loads the stripe footer, keeps the streams that belong to the
 * requested columns, reads their byte ranges from the {@link OrcDataSource}, and builds the
 * value streams, dictionary stream sources and {@link RowGroup}s. When the stripe holds more
 * rows than one row group, the row indexes (and bloom filters, if present) are used together
 * with the {@link OrcPredicate} to select which row groups to materialize.
 */
public class StripeReader {
    private final OrcDataSource orcDataSource;
    private final ZoneId legacyFileTimeZone;
    private final Optional<OrcDecompressor> decompressor;
    private final ColumnMetadata<OrcType> types;
    private final PostScript.HiveWriterVersion hiveWriterVersion;
    private final Set<OrcColumnId> includedOrcColumnIds;
    private final OptionalInt rowsInRowGroup;
    private final OrcPredicate predicate;
    private final MetadataReader metadataReader;
    private final Optional<OrcWriteValidation> writeValidation;

    /**
     * Creates a stripe reader.
     *
     * @param orcDataSource source the stripe bytes are read from
     * @param legacyFileTimeZone time zone handed to the metadata reader when decoding the stripe footer
     * @param decompressor decompressor applied to the stream data, or empty
     * @param types ORC type of every column, indexed by column id
     * @param readColumns columns to read; their nested columns are included as well
     * @param rowsInRowGroup number of rows per row group, or empty when unknown
     * @param predicate predicate evaluated against row group statistics to select row groups
     * @param hiveWriterVersion writer version passed to the row index reader
     * @param metadataReader reader used for the stripe footer, row indexes and bloom filters
     * @param writeValidation optional validation applied to the time zone and row group statistics
     * @throws NullPointerException if any argument is null
     */
    public StripeReader(OrcDataSource orcDataSource, ZoneId legacyFileTimeZone, Optional<OrcDecompressor> decompressor, ColumnMetadata<OrcType> types, Set<OrcColumn> readColumns, OptionalInt rowsInRowGroup, OrcPredicate predicate, PostScript.HiveWriterVersion hiveWriterVersion, MetadataReader metadataReader, Optional<OrcWriteValidation> writeValidation) {
        this.orcDataSource = Objects.requireNonNull(orcDataSource, "orcDataSource is null");
        this.legacyFileTimeZone = Objects.requireNonNull(legacyFileTimeZone, "legacyFileTimeZone is null");
        this.decompressor = Objects.requireNonNull(decompressor, "decompressor is null");
        this.types = Objects.requireNonNull(types, "types is null");
        this.includedOrcColumnIds = StripeReader.getIncludeColumns(Objects.requireNonNull(readColumns, "readColumns is null"));
        // Fail fast here instead of with an anonymous NullPointerException on the first readStripe call.
        this.rowsInRowGroup = Objects.requireNonNull(rowsInRowGroup, "rowsInRowGroup is null");
        this.predicate = Objects.requireNonNull(predicate, "predicate is null");
        this.hiveWriterVersion = Objects.requireNonNull(hiveWriterVersion, "hiveWriterVersion is null");
        this.metadataReader = Objects.requireNonNull(metadataReader, "metadataReader is null");
        this.writeValidation = Objects.requireNonNull(writeValidation, "writeValidation is null");
    }

    /**
     * Reads the given stripe.
     *
     * <p>If the stripe has more rows than one row group, row groups are selected with the
     * predicate and only those are returned. If building the row groups fails with an
     * {@link InvalidCheckpointException}, or the stripe fits in a single row group (or the
     * row group size is unknown), the whole stripe is returned as one row group.
     *
     * @param stripe location and size information of the stripe
     * @param memoryUsage memory context passed to the chunk loaders; it is closed when no row group is selected
     * @return the stripe, or {@code null} when the predicate rejects every row group
     * @throws IOException if reading from the data source or decoding metadata fails
     */
    public Stripe readStripe(StripeInformation stripe, AggregatedMemoryContext memoryUsage) throws IOException {
        StripeFooter stripeFooter = this.readStripeFooter(stripe, memoryUsage);
        ColumnMetadata<ColumnEncoding> columnEncodings = stripeFooter.getColumnEncodings();
        if (this.writeValidation.isPresent()) {
            this.writeValidation.get().validateTimeZone(this.orcDataSource.getId(), stripeFooter.getTimeZone());
        }
        ZoneId fileTimeZone = stripeFooter.getTimeZone();

        // Keep only streams of the requested columns whose kind is supported for the column type.
        Map<StreamId, Stream> streams = new HashMap<>();
        for (Stream stream : stripeFooter.getStreams()) {
            if (this.includedOrcColumnIds.contains(stream.getColumnId())
                    && StripeReader.isSupportedStreamType(stream, this.types.get(stream.getColumnId()).getOrcTypeKind())) {
                streams.put(new StreamId(stream), stream);
            }
        }

        boolean invalidCheckPoint = false;
        if (this.rowsInRowGroup.isPresent() && stripe.getNumberOfRows() > this.rowsInRowGroup.getAsInt()) {
            // The stripe spans several row groups: read indexes and prune row groups with the predicate.
            Map<StreamId, DiskRange> diskRanges = StripeReader.getDiskRanges(stripeFooter.getStreams());
            diskRanges = Maps.filterKeys(diskRanges, Predicates.in(streams.keySet()));
            Map<StreamId, OrcChunkLoader> streamsData = this.readDiskRanges(stripe.getOffset(), diskRanges, memoryUsage);
            Map<OrcColumnId, List<BloomFilter>> bloomFilterIndexes = this.readBloomFilterIndexes(streams, streamsData);
            Map<StreamId, List<RowGroupIndex>> columnIndexes = this.readColumnIndexes(streams, streamsData, bloomFilterIndexes);
            if (this.writeValidation.isPresent()) {
                this.writeValidation.get().validateRowGroupStatistics(this.orcDataSource.getId(), stripe.getOffset(), columnIndexes);
            }

            Set<Integer> selectedRowGroups = this.selectRowGroups(stripe, columnIndexes);
            if (selectedRowGroups.isEmpty()) {
                // Nothing from this stripe will be returned, so the memory context is closed.
                memoryUsage.close();
                return null;
            }

            Map<StreamId, ValueInputStream<?>> valueStreams = this.createValueStreams(streams, streamsData, columnEncodings);
            InputStreamSources dictionaryStreamSources = this.createDictionaryStreamSources(streams, valueStreams, columnEncodings);
            try {
                List<RowGroup> rowGroups = this.createRowGroups(stripe.getNumberOfRows(), streams, valueStreams, columnIndexes, selectedRowGroups, columnEncodings);
                return new Stripe(stripe.getNumberOfRows(), fileTimeZone, columnEncodings, rowGroups, dictionaryStreamSources);
            }
            catch (InvalidCheckpointException ignored) {
                // Deliberate fallback: the checkpoints are unusable, so the stripe is re-read
                // below as a single row group. The flag relaxes the single-row-index check there.
                invalidCheckPoint = true;
            }
        }
        return this.readSingleRowGroupStripe(stripe, stripeFooter, streams, fileTimeZone, columnEncodings, invalidCheckPoint, memoryUsage);
    }

    /** Reads all selected streams of the stripe and exposes them as one row group covering every row. */
    private Stripe readSingleRowGroupStripe(StripeInformation stripe, StripeFooter stripeFooter, Map<StreamId, Stream> streams, ZoneId fileTimeZone, ColumnMetadata<ColumnEncoding> columnEncodings, boolean invalidCheckPoint, AggregatedMemoryContext memoryUsage) throws IOException {
        ImmutableMap.Builder<StreamId, DiskRange> diskRangesBuilder = ImmutableMap.builder();
        for (Map.Entry<StreamId, DiskRange> entry : StripeReader.getDiskRanges(stripeFooter.getStreams()).entrySet()) {
            if (streams.containsKey(entry.getKey())) {
                diskRangesBuilder.put(entry);
            }
        }
        Map<StreamId, DiskRange> diskRanges = diskRangesBuilder.buildOrThrow();
        Map<StreamId, OrcChunkLoader> streamsData = this.readDiskRanges(stripe.getOffset(), diskRanges, memoryUsage);

        long minAverageRowBytes = this.estimateMinAverageRowBytes(streams, streamsData, invalidCheckPoint);

        Map<StreamId, ValueInputStream<?>> valueStreams = this.createValueStreams(streams, streamsData, columnEncodings);
        InputStreamSources dictionaryStreamSources = this.createDictionaryStreamSources(streams, valueStreams, columnEncodings);

        ImmutableMap.Builder<StreamId, InputStreamSource<?>> sources = ImmutableMap.builder();
        for (Map.Entry<StreamId, ValueInputStream<?>> entry : valueStreams.entrySet()) {
            sources.put(entry.getKey(), new ValueInputStreamSource<>(entry.getValue()));
        }
        RowGroup rowGroup = new RowGroup(0, 0L, stripe.getNumberOfRows(), minAverageRowBytes, new InputStreamSources(sources.buildOrThrow()));
        return new Stripe(stripe.getNumberOfRows(), fileTimeZone, columnEncodings, ImmutableList.of(rowGroup), dictionaryStreamSources);
    }

    /**
     * Sums, over every ROW_INDEX stream, the value-count-weighted average of the minimum
     * average value size reported by that stream's row group statistics.
     *
     * @throws IllegalStateException if a row index has more than one entry and {@code invalidCheckPoint} is false
     */
    private long estimateMinAverageRowBytes(Map<StreamId, Stream> streams, Map<StreamId, OrcChunkLoader> streamsData, boolean invalidCheckPoint) throws IOException {
        long minAverageRowBytes = 0L;
        for (StreamId streamId : streams.keySet()) {
            if (streamId.getStreamKind() != Stream.StreamKind.ROW_INDEX) {
                continue;
            }
            List<RowGroupIndex> rowGroupIndexes = this.metadataReader.readRowIndexes(this.hiveWriterVersion, new OrcInputStream(streamsData.get(streamId)));
            Preconditions.checkState(rowGroupIndexes.size() == 1 || invalidCheckPoint, "expect a single row group or an invalid check point");
            long totalBytes = 0L;
            long totalRows = 0L;
            for (RowGroupIndex rowGroupIndex : rowGroupIndexes) {
                ColumnStatistics columnStatistics = rowGroupIndex.getColumnStatistics();
                if (columnStatistics.hasMinAverageValueSizeInBytes()) {
                    totalBytes += columnStatistics.getMinAverageValueSizeInBytes() * columnStatistics.getNumberOfValues();
                    totalRows += columnStatistics.getNumberOfValues();
                }
            }
            // Columns without any counted values contribute nothing (and would divide by zero).
            if (totalRows > 0L) {
                minAverageRowBytes += totalBytes / totalRows;
            }
        }
        return minAverageRowBytes;
    }

    /**
     * Tells whether a stream should be read for a column of the given type.
     *
     * <p>BLOOM_FILTER streams are rejected for STRING, VARCHAR, CHAR, TIMESTAMP and
     * TIMESTAMP_INSTANT columns; BLOOM_FILTER_UTF8 streams are rejected for CHAR columns;
     * every other stream kind is accepted.
     */
    private static boolean isSupportedStreamType(Stream stream, OrcType.OrcTypeKind orcTypeKind) {
        Stream.StreamKind streamKind = stream.getStreamKind();
        if (streamKind == Stream.StreamKind.BLOOM_FILTER) {
            switch (orcTypeKind) {
                case STRING:
                case VARCHAR:
                case CHAR:
                case TIMESTAMP:
                case TIMESTAMP_INSTANT:
                    return false;
                default:
                    return true;
            }
        }
        if (streamKind == Stream.StreamKind.BLOOM_FILTER_UTF8) {
            return orcTypeKind != OrcType.OrcTypeKind.CHAR;
        }
        return true;
    }

    /**
     * Reads the given stripe-relative ranges from the data source and wraps each result in an
     * {@link OrcChunkLoader}.
     *
     * @param stripeOffset offset of the stripe; added to every range offset before reading
     * @param diskRanges ranges keyed by stream, with offsets relative to the stripe start
     * @param memoryUsage memory context handed to each chunk loader
     */
    private Map<StreamId, OrcChunkLoader> readDiskRanges(long stripeOffset, Map<StreamId, DiskRange> diskRanges, AggregatedMemoryContext memoryUsage) throws IOException {
        ImmutableMap.Builder<StreamId, DiskRange> absoluteRanges = ImmutableMap.builder();
        for (Map.Entry<StreamId, DiskRange> entry : diskRanges.entrySet()) {
            DiskRange relative = entry.getValue();
            absoluteRanges.put(entry.getKey(), new DiskRange(stripeOffset + relative.getOffset(), relative.getLength()));
        }

        ImmutableMap.Builder<StreamId, OrcChunkLoader> loaders = ImmutableMap.builder();
        // forEach lets the compiler infer the data source's value type for the lambda parameter.
        this.orcDataSource.readFully(absoluteRanges.buildOrThrow())
                .forEach((streamId, data) -> loaders.put(streamId, OrcChunkLoader.create(data, this.decompressor, memoryUsage)));
        return loaders.buildOrThrow();
    }

    /**
     * Creates a value stream for every non-index, non-empty stream.
     *
     * @return value streams keyed by stream id; index streams and zero-length streams are absent
     */
    private Map<StreamId, ValueInputStream<?>> createValueStreams(Map<StreamId, Stream> streams, Map<StreamId, OrcChunkLoader> streamsData, ColumnMetadata<ColumnEncoding> columnEncodings) {
        ImmutableMap.Builder<StreamId, ValueInputStream<?>> valueStreams = ImmutableMap.builder();
        for (Map.Entry<StreamId, Stream> entry : streams.entrySet()) {
            StreamId streamId = entry.getKey();
            Stream stream = entry.getValue();
            ColumnEncoding.ColumnEncodingKind columnEncoding = columnEncodings.get(stream.getColumnId()).getColumnEncodingKind();
            if (StripeReader.isIndexStream(stream) || stream.getLength() == 0) {
                continue;
            }
            OrcChunkLoader chunkLoader = streamsData.get(streamId);
            OrcType.OrcTypeKind columnType = this.types.get(stream.getColumnId()).getOrcTypeKind();
            valueStreams.put(streamId, ValueStreams.createValueStreams(streamId, chunkLoader, columnType, columnEncoding));
        }
        return valueStreams.buildOrThrow();
    }

    /**
     * Builds checkpointed stream sources for the dictionary streams (see {@link #isDictionary})
     * that have a value stream, positioned at the dictionary checkpoint for their column.
     */
    private InputStreamSources createDictionaryStreamSources(Map<StreamId, Stream> streams, Map<StreamId, ValueInputStream<?>> valueStreams, ColumnMetadata<ColumnEncoding> columnEncodings) {
        ImmutableMap.Builder<StreamId, InputStreamSource<?>> dictionaryStreamBuilder = ImmutableMap.builder();
        for (Map.Entry<StreamId, Stream> entry : streams.entrySet()) {
            StreamId streamId = entry.getKey();
            Stream stream = entry.getValue();
            ColumnEncoding.ColumnEncodingKind columnEncoding = columnEncodings.get(stream.getColumnId()).getColumnEncodingKind();
            if (!StripeReader.isDictionary(stream, columnEncoding)) {
                continue;
            }
            // Streams skipped by createValueStreams (for example empty ones) have no value stream.
            ValueInputStream<?> valueStream = valueStreams.get(streamId);
            if (valueStream == null) {
                continue;
            }
            OrcType.OrcTypeKind columnType = this.types.get(stream.getColumnId()).getOrcTypeKind();
            StreamCheckpoint streamCheckpoint = Checkpoints.getDictionaryStreamCheckpoint(streamId, columnType, columnEncoding);
            InputStreamSource<?> streamSource = CheckpointInputStreamSource.createCheckpointStreamSource(valueStream, streamCheckpoint);
            dictionaryStreamBuilder.put(streamId, streamSource);
        }
        return new InputStreamSources(dictionaryStreamBuilder.buildOrThrow());
    }

    /**
     * Builds one {@link RowGroup} per selected row group id, in the iteration order of
     * {@code selectedRowGroups}.
     *
     * @throws InvalidCheckpointException if the stream checkpoints of a row group cannot be derived
     * @throws IllegalStateException if the row group size is unknown
     */
    private List<RowGroup> createRowGroups(int rowsInStripe, Map<StreamId, Stream> streams, Map<StreamId, ValueInputStream<?>> valueStreams, Map<StreamId, List<RowGroupIndex>> columnIndexes, Set<Integer> selectedRowGroups, ColumnMetadata<ColumnEncoding> encodings) throws InvalidCheckpointException {
        int groupSize = this.rowsInRowGroup.orElseThrow(() -> new IllegalStateException("Cannot create row groups if row group info is missing"));
        ImmutableList.Builder<RowGroup> rowGroupBuilder = ImmutableList.builder();
        for (int rowGroupId : selectedRowGroups) {
            Map<StreamId, StreamCheckpoint> checkpoints = Checkpoints.getStreamCheckpoints(this.includedOrcColumnIds, this.types, this.decompressor.isPresent(), rowGroupId, encodings, streams, columnIndexes);
            int rowOffset = rowGroupId * groupSize;
            // The last row group of the stripe may be shorter than the nominal group size.
            int rowsInGroup = Math.min(rowsInStripe - rowOffset, groupSize);
            long minAverageRowBytes = columnIndexes.values().stream()
                    .mapToLong(indexes -> indexes.get(rowGroupId).getColumnStatistics().getMinAverageValueSizeInBytes())
                    .sum();
            rowGroupBuilder.add(StripeReader.createRowGroup(rowGroupId, rowOffset, rowsInGroup, minAverageRowBytes, valueStreams, checkpoints));
        }
        return rowGroupBuilder.build();
    }

    /**
     * Builds a row group whose stream sources are the value streams positioned at the given
     * checkpoints. Checkpoints without a matching value stream are skipped.
     */
    private static RowGroup createRowGroup(int groupId, int rowOffset, int rowCount, long minAverageRowBytes, Map<StreamId, ValueInputStream<?>> valueStreams, Map<StreamId, StreamCheckpoint> checkpoints) {
        ImmutableMap.Builder<StreamId, InputStreamSource<?>> builder = ImmutableMap.builder();
        for (Map.Entry<StreamId, StreamCheckpoint> entry : checkpoints.entrySet()) {
            StreamId streamId = entry.getKey();
            ValueInputStream<?> valueStream = valueStreams.get(streamId);
            if (valueStream != null) {
                builder.put(streamId, CheckpointInputStreamSource.createCheckpointStreamSource(valueStream, entry.getValue()));
            }
        }
        InputStreamSources rowGroupStreams = new InputStreamSources(builder.buildOrThrow());
        return new RowGroup(groupId, rowOffset, rowCount, minAverageRowBytes, rowGroupStreams);
    }

    /**
     * Reads and decodes the stripe footer, which is located after the index and data sections
     * of the stripe.
     *
     * @throws ArithmeticException if the footer length does not fit in an {@code int}
     */
    private StripeFooter readStripeFooter(StripeInformation stripe, AggregatedMemoryContext memoryUsage) throws IOException {
        long footerOffset = stripe.getOffset() + stripe.getIndexLength() + stripe.getDataLength();
        int footerLength = Math.toIntExact(stripe.getFooterLength());
        Slice footerBytes = this.orcDataSource.readFully(footerOffset, footerLength);
        try (OrcInputStream inputStream = new OrcInputStream(OrcChunkLoader.create(this.orcDataSource.getId(), footerBytes, this.decompressor, memoryUsage))) {
            return this.metadataReader.readStripeFooter(this.types, inputStream, this.legacyFileTimeZone);
        }
    }

    /**
     * Tells whether the stream is one of the kinds this reader treats as an index stream:
     * ROW_INDEX, DICTIONARY_COUNT, BLOOM_FILTER or BLOOM_FILTER_UTF8.
     */
    static boolean isIndexStream(Stream stream) {
        Stream.StreamKind kind = stream.getStreamKind();
        return kind == Stream.StreamKind.ROW_INDEX
                || kind == Stream.StreamKind.DICTIONARY_COUNT
                || kind == Stream.StreamKind.BLOOM_FILTER
                || kind == Stream.StreamKind.BLOOM_FILTER_UTF8;
    }

    /**
     * Reads the bloom filters of every column that has a bloom filter stream.
     *
     * <p>BLOOM_FILTER_UTF8 streams are read first; a BLOOM_FILTER stream is only used for a
     * column that has no BLOOM_FILTER_UTF8 stream.
     *
     * @return immutable map from column id to that column's bloom filters
     */
    private Map<OrcColumnId, List<BloomFilter>> readBloomFilterIndexes(Map<StreamId, Stream> streams, Map<StreamId, OrcChunkLoader> streamsData) throws IOException {
        Map<OrcColumnId, List<BloomFilter>> bloomFilters = new HashMap<>();
        for (Map.Entry<StreamId, Stream> entry : streams.entrySet()) {
            Stream stream = entry.getValue();
            if (stream.getStreamKind() == Stream.StreamKind.BLOOM_FILTER_UTF8) {
                OrcInputStream inputStream = new OrcInputStream(streamsData.get(entry.getKey()));
                bloomFilters.put(stream.getColumnId(), this.metadataReader.readBloomFilterIndexes(inputStream));
            }
        }
        for (Map.Entry<StreamId, Stream> entry : streams.entrySet()) {
            Stream stream = entry.getValue();
            if (stream.getStreamKind() == Stream.StreamKind.BLOOM_FILTER && !bloomFilters.containsKey(stream.getColumnId())) {
                OrcInputStream inputStream = new OrcInputStream(streamsData.get(entry.getKey()));
                bloomFilters.put(entry.getKey().getColumnId(), this.metadataReader.readBloomFilterIndexes(inputStream));
            }
        }
        return ImmutableMap.copyOf(bloomFilters);
    }

    /**
     * Reads the row group indexes of every ROW_INDEX stream. When bloom filters exist for the
     * stream's column, the i-th bloom filter is attached to the statistics of the i-th row
     * group index.
     */
    private Map<StreamId, List<RowGroupIndex>> readColumnIndexes(Map<StreamId, Stream> streams, Map<StreamId, OrcChunkLoader> streamsData, Map<OrcColumnId, List<BloomFilter>> bloomFilterIndexes) throws IOException {
        ImmutableMap.Builder<StreamId, List<RowGroupIndex>> columnIndexes = ImmutableMap.builder();
        for (Map.Entry<StreamId, Stream> entry : streams.entrySet()) {
            if (entry.getValue().getStreamKind() != Stream.StreamKind.ROW_INDEX) {
                continue;
            }
            StreamId streamId = entry.getKey();
            OrcInputStream inputStream = new OrcInputStream(streamsData.get(streamId));
            List<BloomFilter> bloomFilters = bloomFilterIndexes.get(streamId.getColumnId());
            List<RowGroupIndex> rowGroupIndexes = this.metadataReader.readRowIndexes(this.hiveWriterVersion, inputStream);
            if (bloomFilters != null && !bloomFilters.isEmpty()) {
                ImmutableList.Builder<RowGroupIndex> withBloomFilters = ImmutableList.builder();
                for (int i = 0; i < rowGroupIndexes.size(); ++i) {
                    RowGroupIndex rowGroupIndex = rowGroupIndexes.get(i);
                    ColumnStatistics columnStatistics = rowGroupIndex.getColumnStatistics().withBloomFilter(bloomFilters.get(i));
                    withBloomFilters.add(new RowGroupIndex(rowGroupIndex.getPositions(), columnStatistics));
                }
                rowGroupIndexes = withBloomFilters.build();
            }
            columnIndexes.put(streamId, rowGroupIndexes);
        }
        return columnIndexes.buildOrThrow();
    }

    /**
     * Returns the ids of the row groups whose statistics satisfy the predicate.
     *
     * @throws IllegalStateException if the row group size is unknown
     */
    private Set<Integer> selectRowGroups(StripeInformation stripe, Map<StreamId, List<RowGroupIndex>> columnIndexes) {
        int groupSize = this.rowsInRowGroup.orElseThrow(() -> new IllegalStateException("Cannot create row groups if row group info is missing"));
        int rowsInStripe = stripe.getNumberOfRows();
        int groupsInStripe = StripeReader.ceil(rowsInStripe, groupSize);
        ImmutableSet.Builder<Integer> selectedRowGroups = ImmutableSet.builder();
        int remainingRows = rowsInStripe;
        for (int rowGroup = 0; rowGroup < groupsInStripe; ++rowGroup) {
            // The final row group holds whatever is left over.
            int rows = Math.min(remainingRows, groupSize);
            ColumnMetadata<ColumnStatistics> statistics = StripeReader.getRowGroupStatistics(this.types, columnIndexes, rowGroup);
            if (this.predicate.matches(rows, statistics)) {
                selectedRowGroups.add(rowGroup);
            }
            remainingRows -= rows;
        }
        return selectedRowGroups.build();
    }

    /**
     * Collects, for every column id in {@code types}, the statistics of the given row group.
     * Columns without a row index get a {@code null} entry.
     *
     * @throws NullPointerException if {@code columnIndexes} is null
     * @throws IllegalArgumentException if {@code rowGroup} is negative
     */
    private static ColumnMetadata<ColumnStatistics> getRowGroupStatistics(ColumnMetadata<OrcType> types, Map<StreamId, List<RowGroupIndex>> columnIndexes, int rowGroup) {
        Objects.requireNonNull(columnIndexes, "columnIndexes is null");
        Preconditions.checkArgument(rowGroup >= 0, "rowGroup is negative");
        Map<Integer, List<RowGroupIndex>> rowGroupIndexesByColumn = columnIndexes.entrySet().stream()
                .collect(ImmutableMap.toImmutableMap(entry -> entry.getKey().getColumnId().getId(), Map.Entry::getValue));
        List<ColumnStatistics> statistics = new ArrayList<>(types.size());
        for (int columnIndex = 0; columnIndex < types.size(); ++columnIndex) {
            List<RowGroupIndex> rowGroupIndexes = rowGroupIndexesByColumn.get(columnIndex);
            statistics.add(rowGroupIndexes == null ? null : rowGroupIndexes.get(rowGroup).getColumnStatistics());
        }
        return new ColumnMetadata<>(statistics);
    }

    /**
     * Tells whether the stream carries dictionary content: a DICTIONARY_DATA stream, or a
     * LENGTH stream of a column encoded with DICTIONARY or DICTIONARY_V2.
     */
    private static boolean isDictionary(Stream stream, ColumnEncoding.ColumnEncodingKind columnEncoding) {
        if (stream.getStreamKind() == Stream.StreamKind.DICTIONARY_DATA) {
            return true;
        }
        boolean dictionaryEncoded = columnEncoding == ColumnEncoding.ColumnEncodingKind.DICTIONARY
                || columnEncoding == ColumnEncoding.ColumnEncodingKind.DICTIONARY_V2;
        return stream.getStreamKind() == Stream.StreamKind.LENGTH && dictionaryEncoded;
    }

    /**
     * Computes the stripe-relative byte range of every non-empty stream, assuming the streams
     * are laid out back to back in list order starting at offset zero.
     */
    private static Map<StreamId, DiskRange> getDiskRanges(List<Stream> streams) {
        ImmutableMap.Builder<StreamId, DiskRange> streamDiskRanges = ImmutableMap.builder();
        long stripeOffset = 0L;
        for (Stream stream : streams) {
            int streamLength = stream.getLength();
            if (streamLength > 0) {
                streamDiskRanges.put(new StreamId(stream), new DiskRange(stripeOffset, streamLength));
            }
            stripeOffset += streamLength;
        }
        return streamDiskRanges.buildOrThrow();
    }

    /** Returns the ids of the given columns and of all their nested columns, in encounter order. */
    private static Set<OrcColumnId> getIncludeColumns(Set<OrcColumn> includedColumns) {
        Set<OrcColumnId> result = new LinkedHashSet<>();
        StripeReader.includeColumnsRecursive(result, includedColumns);
        return result;
    }

    /** Adds the id of every column in {@code readColumns}, then descends into its nested columns. */
    private static void includeColumnsRecursive(Set<OrcColumnId> result, Collection<OrcColumn> readColumns) {
        for (OrcColumn column : readColumns) {
            result.add(column.getColumnId());
            StripeReader.includeColumnsRecursive(result, column.getNestedColumns());
        }
    }

    /**
     * Integer division rounding up, for a non-negative dividend and a positive divisor.
     *
     * <p>The sum is computed in {@code long}: in {@code int}, {@code dividend + divisor - 1}
     * wraps to a negative value for large dividends, which would make the caller see a
     * negative number of row groups.
     */
    private static int ceil(int dividend, int divisor) {
        return (int) (((long) dividend + divisor - 1) / divisor);
    }
}

