/*
 * Decompiled with CFR 0.152.
 */
package io.trino.parquet;

import com.google.common.base.MoreObjects;
import com.google.common.base.Preconditions;
import com.google.common.collect.ImmutableList;
import com.google.common.collect.ImmutableSet;
import io.airlift.slice.SizeOf;
import io.airlift.slice.Slice;
import io.airlift.slice.Slices;
import io.airlift.slice.XxHash64;
import io.trino.parquet.ColumnStatisticsValidation;
import io.trino.parquet.ParquetCorruptionException;
import io.trino.parquet.ParquetDataSourceId;
import io.trino.parquet.ParquetValidationUtils;
import io.trino.parquet.ValidationHash;
import io.trino.spi.Page;
import io.trino.spi.block.Block;
import io.trino.spi.type.Type;
import java.util.Arrays;
import java.util.Collection;
import java.util.List;
import java.util.Locale;
import java.util.Objects;
import java.util.Optional;
import java.util.Set;
import org.apache.parquet.column.ColumnDescriptor;
import org.apache.parquet.column.statistics.Statistics;
import org.apache.parquet.format.ColumnChunk;
import org.apache.parquet.format.ColumnMetaData;
import org.apache.parquet.format.Encoding;
import org.apache.parquet.format.RowGroup;
import org.apache.parquet.format.converter.ParquetMetadataConverter;
import org.apache.parquet.hadoop.metadata.BlockMetaData;
import org.apache.parquet.hadoop.metadata.ColumnChunkMetaData;
import org.apache.parquet.hadoop.metadata.ColumnPath;
import org.apache.parquet.internal.hadoop.metadata.IndexReference;
import org.apache.parquet.schema.MessageType;
import org.apache.parquet.schema.PrimitiveType;

/**
 * Snapshot of what a Parquet writer produced, used to re-read a freshly written file and
 * compare it against the writer's own view of it. The snapshot holds the writer identity
 * ({@code createdBy}), the time zone, the schema leaf columns, the footer row groups, and
 * per-column content hashes.
 *
 * <p>Instances are created through {@link ParquetWriteValidationBuilder}. The {@code validate*}
 * methods report mismatches by throwing {@link ParquetCorruptionException}.
 */
public class ParquetWriteValidation {
    private static final ParquetMetadataConverter METADATA_CONVERTER = new ParquetMetadataConverter();

    private final String createdBy;
    private final Optional<String> timeZoneId;
    private final List<ColumnDescriptor> columns;
    private final List<RowGroup> rowGroups;
    private final WriteChecksum checksum;
    private final List<Type> types;
    private final List<String> columnNames;

    private ParquetWriteValidation(String createdBy, Optional<String> timeZoneId, List<ColumnDescriptor> columns, List<RowGroup> rowGroups, WriteChecksum checksum, List<Type> types, List<String> columnNames) {
        this.createdBy = Objects.requireNonNull(createdBy, "createdBy is null");
        Preconditions.checkArgument(!createdBy.isEmpty(), "createdBy is empty");
        this.timeZoneId = Objects.requireNonNull(timeZoneId, "timeZoneId is null");
        this.columns = Objects.requireNonNull(columns, "columns is null");
        this.rowGroups = Objects.requireNonNull(rowGroups, "rowGroups is null");
        this.checksum = Objects.requireNonNull(checksum, "checksum is null");
        this.types = Objects.requireNonNull(types, "types is null");
        this.columnNames = Objects.requireNonNull(columnNames, "columnNames is null");
    }

    /** Returns the non-empty writer identity recorded at write time. */
    public String getCreatedBy() {
        return this.createdBy;
    }

    /** Returns the Trino types of the written columns. */
    public List<Type> getTypes() {
        return this.types;
    }

    /** Returns the names of the written columns. */
    public List<String> getColumnNames() {
        return this.columnNames;
    }

    /**
     * Checks that {@code actualTimeZoneId} equals the time zone captured at write time.
     *
     * @throws ParquetCorruptionException if the two differ
     */
    public void validateTimeZone(ParquetDataSourceId dataSourceId, Optional<String> actualTimeZoneId) throws ParquetCorruptionException {
        ParquetValidationUtils.validateParquet(this.timeZoneId.equals(actualTimeZoneId), dataSourceId, "Found unexpected time zone %s, expected %s", actualTimeZoneId, this.timeZoneId);
    }

    /**
     * Checks that the leaf columns of {@code schema} match the captured column descriptors,
     * position by position.
     *
     * @throws ParquetCorruptionException if the column counts differ or any descriptor differs
     */
    public void validateColumns(ParquetDataSourceId dataSourceId, MessageType schema) throws ParquetCorruptionException {
        List<ColumnDescriptor> actualColumns = schema.getColumns();
        ParquetValidationUtils.validateParquet(actualColumns.size() == this.columns.size(), dataSourceId, "Found columns %s, expected %s", actualColumns, this.columns);
        for (int columnIndex = 0; columnIndex < this.columns.size(); ++columnIndex) {
            validateColumnDescriptorsSame(actualColumns.get(columnIndex), this.columns.get(columnIndex), dataSourceId);
        }
    }

    /**
     * Compares the row-group metadata read from the file footer with the row groups captured at
     * write time. The comparison covers the row group count, the row and column counts of each
     * row group, and the metadata of every column chunk.
     *
     * @throws ParquetCorruptionException on the first mismatch
     */
    public void validateBlocksMetadata(ParquetDataSourceId dataSourceId, List<BlockMetaData> blocksMetaData) throws ParquetCorruptionException {
        ParquetValidationUtils.validateParquet(blocksMetaData.size() == this.rowGroups.size(), dataSourceId, "Number of row groups %d did not match %d", blocksMetaData.size(), this.rowGroups.size());
        for (int rowGroupIndex = 0; rowGroupIndex < blocksMetaData.size(); ++rowGroupIndex) {
            BlockMetaData block = blocksMetaData.get(rowGroupIndex);
            RowGroup rowGroup = this.rowGroups.get(rowGroupIndex);
            ParquetValidationUtils.validateParquet(block.getRowCount() == rowGroup.getNum_rows(), dataSourceId, "Number of rows %d in row group %d did not match %d", block.getRowCount(), rowGroupIndex, rowGroup.getNum_rows());
            List<ColumnChunkMetaData> columnChunkMetaData = block.getColumns();
            ParquetValidationUtils.validateParquet(columnChunkMetaData.size() == rowGroup.getColumnsSize(), dataSourceId, "Number of columns %d in row group %d did not match %d", columnChunkMetaData.size(), rowGroupIndex, rowGroup.getColumnsSize());
            for (int columnIndex = 0; columnIndex < columnChunkMetaData.size(); ++columnIndex) {
                validateColumnChunkMetadata(dataSourceId, rowGroupIndex, columnChunkMetaData.get(columnIndex), rowGroup.getColumns().get(columnIndex));
            }
        }
    }

    // Compares one column chunk read from the footer with the chunk captured at write time.
    private static void validateColumnChunkMetadata(ParquetDataSourceId dataSourceId, int rowGroupIndex, ColumnChunkMetaData actualColumnMetadata, ColumnChunk columnChunk) throws ParquetCorruptionException {
        ColumnMetaData expectedColumnMetadata = columnChunk.getMeta_data();
        ColumnPath path = actualColumnMetadata.getPath();
        verifyColumnMetadataMatch(actualColumnMetadata.getCodec().getParquetCompressionCodec().equals(expectedColumnMetadata.getCodec()), "Compression codec", actualColumnMetadata.getCodec(), path, rowGroupIndex, dataSourceId, expectedColumnMetadata.getCodec());
        verifyColumnMetadataMatch(actualColumnMetadata.getPrimitiveType().getPrimitiveTypeName().equals(METADATA_CONVERTER.getPrimitive(expectedColumnMetadata.getType())), "Type", actualColumnMetadata.getPrimitiveType().getPrimitiveTypeName(), path, rowGroupIndex, dataSourceId, expectedColumnMetadata.getType());
        verifyColumnMetadataMatch(areEncodingsSame(actualColumnMetadata.getEncodings(), expectedColumnMetadata.getEncodings()), "Encodings", actualColumnMetadata.getEncodings(), path, rowGroupIndex, dataSourceId, expectedColumnMetadata.getEncodings());
        verifyColumnMetadataMatch(areStatisticsSame(actualColumnMetadata.getStatistics(), expectedColumnMetadata.getStatistics()), "Statistics", actualColumnMetadata.getStatistics(), path, rowGroupIndex, dataSourceId, expectedColumnMetadata.getStatistics());
        verifyColumnMetadataMatch(actualColumnMetadata.getFirstDataPageOffset() == expectedColumnMetadata.getData_page_offset(), "Data page offset", actualColumnMetadata.getFirstDataPageOffset(), path, rowGroupIndex, dataSourceId, expectedColumnMetadata.getData_page_offset());
        verifyColumnMetadataMatch(actualColumnMetadata.getDictionaryPageOffset() == expectedColumnMetadata.getDictionary_page_offset(), "Dictionary page offset", actualColumnMetadata.getDictionaryPageOffset(), path, rowGroupIndex, dataSourceId, expectedColumnMetadata.getDictionary_page_offset());
        verifyColumnMetadataMatch(actualColumnMetadata.getValueCount() == expectedColumnMetadata.getNum_values(), "Value count", actualColumnMetadata.getValueCount(), path, rowGroupIndex, dataSourceId, expectedColumnMetadata.getNum_values());
        verifyColumnMetadataMatch(actualColumnMetadata.getTotalUncompressedSize() == expectedColumnMetadata.getTotal_uncompressed_size(), "Total uncompressed size", actualColumnMetadata.getTotalUncompressedSize(), path, rowGroupIndex, dataSourceId, expectedColumnMetadata.getTotal_uncompressed_size());
        verifyColumnMetadataMatch(actualColumnMetadata.getTotalSize() == expectedColumnMetadata.getTotal_compressed_size(), "Total size", actualColumnMetadata.getTotalSize(), path, rowGroupIndex, dataSourceId, expectedColumnMetadata.getTotal_compressed_size());

        // Index references are only compared when the reader exposes one; a null reference is accepted.
        IndexReferenceValidation expectedColumnIndexReference = new IndexReferenceValidation(columnChunk.getColumn_index_offset(), columnChunk.getColumn_index_length());
        IndexReference actualColumnIndexReference = actualColumnMetadata.getColumnIndexReference();
        verifyColumnMetadataMatch(actualColumnIndexReference == null || IndexReferenceValidation.fromIndexReference(actualColumnIndexReference).equals(expectedColumnIndexReference), "Column index reference", actualColumnIndexReference, path, rowGroupIndex, dataSourceId, expectedColumnIndexReference);

        IndexReferenceValidation expectedOffsetIndexReference = new IndexReferenceValidation(columnChunk.getOffset_index_offset(), columnChunk.getOffset_index_length());
        IndexReference actualOffsetIndexReference = actualColumnMetadata.getOffsetIndexReference();
        verifyColumnMetadataMatch(actualOffsetIndexReference == null || IndexReferenceValidation.fromIndexReference(actualOffsetIndexReference).equals(expectedOffsetIndexReference), "Offset index reference", actualOffsetIndexReference, path, rowGroupIndex, dataSourceId, expectedOffsetIndexReference);
    }

    /**
     * Compares the checksum computed while re-reading the file with the one computed while
     * writing. The comparison covers the total row count, the number of column hashes, and every
     * per-column hash.
     *
     * @throws ParquetCorruptionException on the first mismatch
     */
    public void validateChecksum(ParquetDataSourceId dataSourceId, WriteChecksum actualChecksum) throws ParquetCorruptionException {
        Objects.requireNonNull(actualChecksum, "actualChecksum is null");
        ParquetValidationUtils.validateParquet(this.checksum.totalRowCount() == actualChecksum.totalRowCount(), dataSourceId, "Write validation failed: Expected row count %d, found %d", this.checksum.totalRowCount(), actualChecksum.totalRowCount());

        List<Long> expectedHashes = this.checksum.columnHashes();
        List<Long> actualHashes = actualChecksum.columnHashes();
        // Compare counts first. Without this check, extra actual columns would fail with an
        // IndexOutOfBoundsException, and missing actual columns would silently never be compared.
        ParquetValidationUtils.validateParquet(expectedHashes.size() == actualHashes.size(), dataSourceId, "Write validation failed: Expected %d column hashes, found %d", expectedHashes.size(), actualHashes.size());
        for (int columnIndex = 0; columnIndex < actualHashes.size(); ++columnIndex) {
            long expectedHash = expectedHashes.get(columnIndex);
            long actualHash = actualHashes.get(columnIndex);
            ParquetValidationUtils.validateParquet(expectedHash == actualHash, dataSourceId, "Invalid checksum for column %s: Expected hash %d, found %d", columnIndex, expectedHash, actualHash);
        }
    }

    /**
     * Compares the footer metadata of one row group with statistics gathered while reading it.
     * For every column chunk, the footer value count must equal {@code valuesCount()}. When the
     * chunk's statistics carry a null count, that count must equal {@code nonLeafValuesCount()}.
     * NOTE(review): the null-count/nonLeafValuesCount pairing relies on ColumnStatisticsValidation semantics; confirm there.
     *
     * @throws IllegalArgumentException if the number of column chunks differs from the number of statistics
     * @throws ParquetCorruptionException on the first count mismatch
     */
    public void validateRowGroupStatistics(ParquetDataSourceId dataSourceId, BlockMetaData blockMetaData, List<ColumnStatisticsValidation.ColumnStatistics> actualColumnStatistics) throws ParquetCorruptionException {
        List<ColumnChunkMetaData> columnChunks = blockMetaData.getColumns();
        Preconditions.checkArgument(columnChunks.size() == actualColumnStatistics.size(), "Column chunk metadata count %s did not match column fields count %s", columnChunks.size(), actualColumnStatistics.size());
        for (int columnIndex = 0; columnIndex < columnChunks.size(); ++columnIndex) {
            ColumnChunkMetaData columnMetaData = columnChunks.get(columnIndex);
            ColumnStatisticsValidation.ColumnStatistics columnStatistics = actualColumnStatistics.get(columnIndex);
            long expectedValuesCount = columnMetaData.getValueCount();
            ParquetValidationUtils.validateParquet(expectedValuesCount == columnStatistics.valuesCount(), dataSourceId, "Invalid values count for column %s: Expected %d, found %d", columnIndex, expectedValuesCount, columnStatistics.valuesCount());
            Statistics<?> parquetStatistics = columnMetaData.getStatistics();
            if (parquetStatistics.isNumNullsSet()) {
                long expectedNullsCount = parquetStatistics.getNumNulls();
                ParquetValidationUtils.validateParquet(expectedNullsCount == columnStatistics.nonLeafValuesCount(), dataSourceId, "Invalid nulls count for column %s: Expected %d, found %d", columnIndex, expectedNullsCount, columnStatistics.nonLeafValuesCount());
            }
        }
    }

    // Throws a corruption exception naming the metadata field, column path and row group when condition is false.
    private static <T, U> void verifyColumnMetadataMatch(boolean condition, String name, T actual, ColumnPath path, int rowGroup, ParquetDataSourceId dataSourceId, U expected) throws ParquetCorruptionException {
        if (!condition) {
            throw new ParquetCorruptionException(dataSourceId, "%s [%s] for column %s in row group %d did not match [%s]", name, actual, path, rowGroup, expected);
        }
    }

    // Converts the thrift encodings to parquet-column encodings and compares them as sets (order and duplicates ignored).
    private static boolean areEncodingsSame(Set<org.apache.parquet.column.Encoding> actual, List<Encoding> expected) {
        return actual.equals(expected.stream().map(METADATA_CONVERTER::getEncoding).collect(ImmutableSet.toImmutableSet()));
    }

    // Rebuilds reader-side Statistics from the thrift statistics (null count, min, max when set) and compares.
    private static boolean areStatisticsSame(Statistics<?> actual, org.apache.parquet.format.Statistics expected) {
        Statistics.Builder expectedStatsBuilder = Statistics.getBuilderForReading(actual.type());
        if (expected.isSetNull_count()) {
            expectedStatsBuilder.withNumNulls(expected.getNull_count());
        }
        if (expected.isSetMin_value()) {
            expectedStatsBuilder.withMin(expected.getMin_value());
        }
        if (expected.isSetMax_value()) {
            expectedStatsBuilder.withMax(expected.getMax_value());
        }
        return actual.equals(expectedStatsBuilder.build());
    }

    // Expected path components and primitive names are lower-cased (Locale.ENGLISH) before comparison with the actual ones.
    private static void validateColumnDescriptorsSame(ColumnDescriptor actual, ColumnDescriptor expected, ParquetDataSourceId dataSourceId) throws ParquetCorruptionException {
        String[] expectedLowerCasePath = Arrays.stream(expected.getPath())
                .map(field -> field.toLowerCase(Locale.ENGLISH))
                .toArray(String[]::new);
        // Arrays.asList gives readable "[a, b]" output in messages; raw String[] would print as an identity string.
        List<String> actualPath = Arrays.asList(actual.getPath());
        List<String> expectedPath = Arrays.asList(expected.getPath());
        ParquetValidationUtils.validateParquet(Arrays.equals(actual.getPath(), expectedLowerCasePath), dataSourceId, "Column path %s did not match expected column path %s", actualPath, expectedPath);
        ParquetValidationUtils.validateParquet(actual.getMaxDefinitionLevel() == expected.getMaxDefinitionLevel(), dataSourceId, "Column %s max definition level %d did not match expected max definition level %d", actualPath, actual.getMaxDefinitionLevel(), expected.getMaxDefinitionLevel());
        ParquetValidationUtils.validateParquet(actual.getMaxRepetitionLevel() == expected.getMaxRepetitionLevel(), dataSourceId, "Column %s max repetition level %d did not match expected max repetition level %d", actualPath, actual.getMaxRepetitionLevel(), expected.getMaxRepetitionLevel());
        PrimitiveType actualPrimitiveType = actual.getPrimitiveType();
        PrimitiveType expectedPrimitiveType = expected.getPrimitiveType();
        boolean primitiveTypesMatch = actualPrimitiveType.getPrimitiveTypeName().equals(expectedPrimitiveType.getPrimitiveTypeName())
                && actualPrimitiveType.getTypeLength() == expectedPrimitiveType.getTypeLength()
                && actualPrimitiveType.getRepetition().equals(expectedPrimitiveType.getRepetition())
                && actualPrimitiveType.getName().equals(expectedPrimitiveType.getName().toLowerCase(Locale.ENGLISH))
                && Objects.equals(actualPrimitiveType.getLogicalTypeAnnotation(), expectedPrimitiveType.getLogicalTypeAnnotation());
        ParquetValidationUtils.validateParquet(primitiveTypesMatch, dataSourceId, "Column %s primitive type %s did not match expected primitive type %s", actualPath, actualPrimitiveType, expectedPrimitiveType);
    }

    // Estimated retained size of a String[]: the array itself plus each element string.
    private static long estimatedSizeOfStringArray(String[] path) {
        long size = SizeOf.sizeOf(path);
        for (String field : path) {
            size += SizeOf.estimatedSizeOf(field);
        }
        return size;
    }

    /**
     * Total row count plus one hash per column. {@code columnHashes} is defensively copied into an
     * immutable list.
     */
    public record WriteChecksum(long totalRowCount, List<Long> columnHashes) {
        public WriteChecksum {
            columnHashes = ImmutableList.copyOf(Objects.requireNonNull(columnHashes, "columnHashes is null"));
        }
    }

    /** Value holder for an index reference (offset, length), compared by value. */
    static class IndexReferenceValidation {
        private final long offset;
        private final int length;

        private IndexReferenceValidation(long offset, int length) {
            this.offset = offset;
            this.length = length;
        }

        static IndexReferenceValidation fromIndexReference(IndexReference indexReference) {
            return new IndexReferenceValidation(indexReference.getOffset(), indexReference.getLength());
        }

        @Override
        public boolean equals(Object o) {
            if (this == o) {
                return true;
            }
            if (o == null || this.getClass() != o.getClass()) {
                return false;
            }
            IndexReferenceValidation that = (IndexReferenceValidation) o;
            return this.offset == that.offset && this.length == that.length;
        }

        @Override
        public int hashCode() {
            return Objects.hash(this.offset, this.length);
        }

        @Override
        public String toString() {
            return MoreObjects.toStringHelper(this).add("offset", this.offset).add("length", this.length).toString();
        }
    }

    /**
     * Mutable collector for a {@link ParquetWriteValidation}. It accumulates written pages into a
     * checksum and tracks an estimate of its own retained memory. Row groups passed to
     * {@link #setRowGroups} are not included in that estimate.
     */
    public static class ParquetWriteValidationBuilder {
        private static final int INSTANCE_SIZE = SizeOf.instanceSize(ParquetWriteValidationBuilder.class);
        private static final int COLUMN_DESCRIPTOR_INSTANCE_SIZE = SizeOf.instanceSize(ColumnDescriptor.class);
        private static final int PRIMITIVE_TYPE_INSTANCE_SIZE = SizeOf.instanceSize(PrimitiveType.class);

        private final List<Type> types;
        private final List<String> columnNames;
        private final WriteChecksumBuilder checksum;
        private String createdBy;
        private Optional<String> timeZoneId = Optional.empty();
        private List<ColumnDescriptor> columns;
        private List<RowGroup> rowGroups;
        private long retainedSize = INSTANCE_SIZE;

        /**
         * @throws IllegalArgumentException if {@code types} and {@code columnNames} differ in size
         */
        public ParquetWriteValidationBuilder(List<Type> types, List<String> columnNames) {
            this.types = ImmutableList.copyOf(Objects.requireNonNull(types, "types is null"));
            this.columnNames = ImmutableList.copyOf(Objects.requireNonNull(columnNames, "columnNames is null"));
            Preconditions.checkArgument(types.size() == columnNames.size(), "Types count %s did not match column names count %s", types.size(), columnNames.size());
            this.checksum = new WriteChecksumBuilder(types);
            this.retainedSize += SizeOf.estimatedSizeOf(types, type -> 0L) + SizeOf.estimatedSizeOf(columnNames, SizeOf::estimatedSizeOf);
        }

        public long getRetainedSize() {
            return this.retainedSize;
        }

        public void setCreatedBy(String createdBy) {
            this.createdBy = createdBy;
            this.retainedSize += SizeOf.estimatedSizeOf(createdBy);
        }

        public void setTimeZone(Optional<String> timeZoneId) {
            this.timeZoneId = timeZoneId;
            timeZoneId.ifPresent(id -> this.retainedSize += SizeOf.estimatedSizeOf(id));
        }

        public void setColumns(List<ColumnDescriptor> columns) {
            this.columns = ImmutableList.copyOf(Objects.requireNonNull(columns, "columns is null"));
            this.retainedSize += SizeOf.estimatedSizeOf(columns, descriptor -> (long) (COLUMN_DESCRIPTOR_INSTANCE_SIZE + 8) + estimatedSizeOfStringArray(descriptor.getPath()) + (long) PRIMITIVE_TYPE_INSTANCE_SIZE + 12L);
        }

        public void setRowGroups(List<RowGroup> rowGroups) {
            this.rowGroups = ImmutableList.copyOf(Objects.requireNonNull(rowGroups, "rowGroups is null"));
        }

        /** Folds a written page into the running checksum. */
        public void addPage(Page page) {
            this.checksum.addPage(page);
        }

        public ParquetWriteValidation build() {
            return new ParquetWriteValidation(this.createdBy, this.timeZoneId, this.columns, this.rowGroups, this.checksum.build(), this.types, this.columnNames);
        }
    }

    /** Accumulates per-column statistics from read pages; {@link #reset()} starts a fresh accumulation. */
    public static class StatisticsValidation {
        private final List<Type> types;
        private List<ColumnStatisticsValidation> columnStatisticsValidations;

        private StatisticsValidation(List<Type> types) {
            this.types = Objects.requireNonNull(types, "types is null");
            this.columnStatisticsValidations = createColumnValidations(types);
        }

        public static StatisticsValidation createStatisticsValidationBuilder(List<Type> readTypes) {
            return new StatisticsValidation(readTypes);
        }

        /**
         * @throws IllegalArgumentException if the page's channel count differs from the column count
         */
        public void addPage(Page page) {
            Objects.requireNonNull(page, "page is null");
            Preconditions.checkArgument(page.getChannelCount() == this.columnStatisticsValidations.size(), "Invalid page: page channels count %s did not match columns count %s", page.getChannelCount(), this.columnStatisticsValidations.size());
            for (int channel = 0; channel < this.columnStatisticsValidations.size(); ++channel) {
                this.columnStatisticsValidations.get(channel).addBlock(page.getBlock(channel));
            }
        }

        public void reset() {
            this.columnStatisticsValidations = createColumnValidations(this.types);
        }

        /** Flattens the statistics produced by every column validation, in column order. */
        public List<ColumnStatisticsValidation.ColumnStatistics> build() {
            return this.columnStatisticsValidations.stream()
                    .flatMap(validation -> validation.build().stream())
                    .collect(ImmutableList.toImmutableList());
        }

        // One fresh ColumnStatisticsValidation per type.
        private static List<ColumnStatisticsValidation> createColumnValidations(List<Type> types) {
            return types.stream()
                    .map(ColumnStatisticsValidation::new)
                    .collect(ImmutableList.toImmutableList());
        }
    }

    /**
     * Computes a {@link WriteChecksum}. Each position of each column block is hashed with the
     * column's {@link ValidationHash}. The resulting 8-byte value is then fed into a per-column
     * {@link XxHash64}.
     */
    public static class WriteChecksumBuilder {
        private final List<ValidationHash> validationHashes;
        private final List<XxHash64> columnHashes;
        // Reusable scratch buffer: longSlice writes each position hash into longBuffer before it is hashed.
        private final byte[] longBuffer = new byte[8];
        private final Slice longSlice = Slices.wrappedBuffer(this.longBuffer);
        private long totalRowCount;

        private WriteChecksumBuilder(List<Type> types) {
            this.validationHashes = Objects.requireNonNull(types, "types is null").stream()
                    .map(ValidationHash::createValidationHash)
                    .collect(ImmutableList.toImmutableList());
            ImmutableList.Builder<XxHash64> hashers = ImmutableList.builder();
            for (Type ignored : types) {
                hashers.add(new XxHash64());
            }
            this.columnHashes = hashers.build();
        }

        public static WriteChecksumBuilder createWriteChecksumBuilder(List<Type> readTypes) {
            return new WriteChecksumBuilder(readTypes);
        }

        /**
         * @throws IllegalArgumentException if the page's channel count differs from the column count
         */
        public void addPage(Page page) {
            Objects.requireNonNull(page, "page is null");
            Preconditions.checkArgument(page.getChannelCount() == this.columnHashes.size(), "Invalid page: page channels count %s did not match columns count %s", page.getChannelCount(), this.columnHashes.size());
            for (int channel = 0; channel < this.columnHashes.size(); ++channel) {
                ValidationHash validationHash = this.validationHashes.get(channel);
                Block block = page.getBlock(channel);
                XxHash64 xxHash64 = this.columnHashes.get(channel);
                for (int position = 0; position < block.getPositionCount(); ++position) {
                    long hash = validationHash.hash(block, position);
                    this.longSlice.setLong(0, hash);
                    xxHash64.update(this.longBuffer);
                }
            }
            this.totalRowCount += page.getPositionCount();
        }

        public WriteChecksum build() {
            return new WriteChecksum(this.totalRowCount, this.columnHashes.stream()
                    .map(XxHash64::hash)
                    .collect(ImmutableList.toImmutableList()));
        }
    }
}

