/*
 * Decompiled with CFR 0.152.
 */
package org.apache.iceberg.spark.source;

import java.io.IOException;
import java.io.UncheckedIOException;
import java.util.ArrayList;
import java.util.Collections;
import java.util.HashMap;
import java.util.List;
import java.util.Map;
import java.util.Set;
import java.util.function.Supplier;
import java.util.stream.Collectors;
import java.util.stream.IntStream;
import org.apache.iceberg.PartitionField;
import org.apache.iceberg.PartitionScanTask;
import org.apache.iceberg.PartitionSpec;
import org.apache.iceberg.Partitioning;
import org.apache.iceberg.Scan;
import org.apache.iceberg.ScanTask;
import org.apache.iceberg.ScanTaskGroup;
import org.apache.iceberg.Schema;
import org.apache.iceberg.Table;
import org.apache.iceberg.exceptions.ValidationException;
import org.apache.iceberg.expressions.Expression;
import org.apache.iceberg.io.CloseableIterable;
import org.apache.iceberg.metrics.ScanReport;
import org.apache.iceberg.relocated.com.google.common.collect.Lists;
import org.apache.iceberg.relocated.com.google.common.collect.Maps;
import org.apache.iceberg.spark.Spark3Util;
import org.apache.iceberg.spark.SparkReadConf;
import org.apache.iceberg.spark.source.SparkScan;
import org.apache.iceberg.types.Types;
import org.apache.iceberg.util.SnapshotUtil;
import org.apache.iceberg.util.StructLikeSet;
import org.apache.iceberg.util.TableScanUtil;
import org.apache.spark.sql.SparkSession;
import org.apache.spark.sql.connector.expressions.Transform;
import org.apache.spark.sql.connector.read.SupportsReportPartitioning;
import org.apache.spark.sql.connector.read.partitioning.KeyGroupedPartitioning;
import org.apache.spark.sql.connector.read.partitioning.UnknownPartitioning;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

/**
 * Base class for Spark scans over {@link PartitionScanTask}s that can report how their data is
 * grouped to Spark through {@link SupportsReportPartitioning}.
 *
 * <p>Planned tasks, the partition specs they use, and the resulting task groups are computed
 * lazily and cached. When data grouping is preserved (see {@link SparkReadConf}) and a non-empty
 * grouping key can be derived from the specs, the scan reports a {@link KeyGroupedPartitioning};
 * otherwise it reports an {@link UnknownPartitioning}.
 *
 * @param <T> the concrete task type this scan plans and groups
 */
abstract class SparkPartitioningAwareScan<T extends PartitionScanTask>
extends SparkScan
implements SupportsReportPartitioning {
    private static final Logger LOG = LoggerFactory.getLogger(SparkPartitioningAwareScan.class);
    private final Scan<?, ? extends ScanTask, ? extends ScanTaskGroup<?>> scan;
    private final boolean preserveDataGrouping;
    // Lazily computed caches; all of them start out empty (not null) when the scan is null.
    private Set<PartitionSpec> specs = null;
    private List<T> tasks = null;
    private List<ScanTaskGroup<T>> taskGroups = null;
    private Types.StructType groupingKeyType = null;
    private Transform[] groupingKeyTransforms = null;

    /**
     * Creates a partitioning-aware scan.
     *
     * @param scan the underlying Iceberg scan; when {@code null}, the scan is treated as having no
     *     specs, no tasks and no task groups
     */
    SparkPartitioningAwareScan(SparkSession spark, Table table, Scan<?, ? extends ScanTask, ? extends ScanTaskGroup<?>> scan, SparkReadConf readConf, Schema expectedSchema, List<Expression> filters, Supplier<ScanReport> scanReportSupplier) {
        super(spark, table, readConf, expectedSchema, filters, scanReportSupplier);
        this.scan = scan;
        this.preserveDataGrouping = readConf.preserveDataGrouping();
        if (scan == null) {
            this.specs = Collections.emptySet();
            this.tasks = Collections.emptyList();
            this.taskGroups = Collections.emptyList();
        }
    }

    /** Returns the class every planned task must be an instance of. */
    protected abstract Class<T> taskJavaClass();

    /** Returns the underlying Iceberg scan, which may be {@code null}. */
    protected Scan<?, ? extends ScanTask, ? extends ScanTaskGroup<?>> scan() {
        return this.scan;
    }

    /**
     * Reports the output partitioning to Spark.
     *
     * @return an {@link UnknownPartitioning} when the grouping key type has no fields, otherwise
     *     a {@link KeyGroupedPartitioning} keyed by the grouping key transforms; in both cases the
     *     partition count is the number of task groups
     */
    @Override
    public org.apache.spark.sql.connector.read.partitioning.Partitioning outputPartitioning() {
        if (groupingKeyType().fields().isEmpty()) {
            LOG.info("Reporting UnknownPartitioning with {} partition(s) for table {}", taskGroups().size(), table().name());
            return new UnknownPartitioning(taskGroups().size());
        }
        LOG.info("Reporting KeyGroupedPartitioning by {} with {} partition(s) for table {}", groupingKeyTransforms(), taskGroups().size(), table().name());
        // Transform is a subtype of the connector Expression, so the array can be passed directly.
        return new KeyGroupedPartitioning(groupingKeyTransforms(), taskGroups().size());
    }

    /**
     * Returns the grouping key type, computing and caching it on first use.
     *
     * @return the grouping key derived from the expected schema and the scanned specs when data
     *     grouping is preserved, otherwise an empty struct
     */
    @Override
    protected Types.StructType groupingKeyType() {
        if (this.groupingKeyType == null) {
            this.groupingKeyType = preserveDataGrouping ? computeGroupingKeyType() : Types.StructType.of();
        }
        return this.groupingKeyType;
    }

    /** Derives the grouping key type from the expected schema and the specs of the planned tasks. */
    private Types.StructType computeGroupingKeyType() {
        return Partitioning.groupingKeyType(expectedSchema(), specs());
    }

    /**
     * Returns the Spark transforms for each grouping key field, computing and caching them on first
     * use. The transforms are resolved against the schema of the scanned branch.
     */
    private Transform[] groupingKeyTransforms() {
        if (this.groupingKeyTransforms == null) {
            Map<Integer, PartitionField> fieldsById = indexFieldsById(specs());
            // NOTE(review): assumes every grouping key field id is present in one of the specs;
            // a missing id would put a null entry into the list passed to toTransforms.
            List<PartitionField> groupingKeyFields = groupingKeyType().fields().stream()
                .map(field -> fieldsById.get(field.fieldId()))
                .collect(Collectors.toList());
            Schema schema = SnapshotUtil.schemaFor(table(), branch());
            this.groupingKeyTransforms = Spark3Util.toTransforms(schema, groupingKeyFields);
        }
        return this.groupingKeyTransforms;
    }

    /**
     * Indexes the partition fields of the given specs by field id. When several specs contain the
     * same field id, the first one encountered wins.
     */
    private Map<Integer, PartitionField> indexFieldsById(Iterable<PartitionSpec> specIterable) {
        Map<Integer, PartitionField> fieldsById = Maps.newHashMap();
        for (PartitionSpec spec : specIterable) {
            for (PartitionField field : spec.fields()) {
                fieldsById.putIfAbsent(field.fieldId(), field);
            }
        }
        return fieldsById;
    }

    /**
     * Returns the distinct partition specs used by the planned tasks, computing and caching them on
     * first use. Specs are looked up in the table's specs by the tasks' spec ids.
     */
    protected Set<PartitionSpec> specs() {
        if (this.specs == null) {
            Map<Integer, PartitionSpec> specsById = table().specs();
            this.specs = tasks().stream()
                .mapToInt(task -> task.spec().specId())
                .distinct()
                .mapToObj(specsById::get)
                .collect(Collectors.toSet());
        }
        return this.specs;
    }

    /**
     * Plans the scan's files on first use and caches the resulting tasks.
     *
     * @return the planned tasks, each an instance of {@link #taskJavaClass()}
     * @throws ValidationException if the scan produces a task of an unsupported type
     * @throws UncheckedIOException if closing the planned task iterable fails
     */
    protected synchronized List<T> tasks() {
        if (this.tasks == null) {
            Class<T> taskClass = taskJavaClass();
            try (CloseableIterable<? extends ScanTask> taskIterable = scan.planFiles()) {
                List<T> plannedTasks = Lists.newArrayList();
                for (ScanTask task : taskIterable) {
                    ValidationException.check(taskClass.isInstance(task), "Unsupported task type, expected a subtype of %s: %s", taskClass.getName(), task.getClass().getName());
                    plannedTasks.add(taskClass.cast(task));
                }
                this.tasks = plannedTasks;
            }
            catch (IOException e) {
                throw new UncheckedIOException("Failed to close scan: " + scan, e);
            }
        }
        return this.tasks;
    }

    /**
     * Groups the planned tasks on first use and caches the result.
     *
     * <p>Without a grouping key, tasks are combined purely by split size, lookback and open-file
     * cost. With a grouping key, {@link TableScanUtil} is additionally given the grouping key
     * type so that tasks are grouped by that key.
     */
    protected synchronized List<ScanTaskGroup<T>> taskGroups() {
        if (this.taskGroups == null) {
            List<T> plannedTasks = tasks();
            long splitSize = adjustSplitSize(plannedTasks, scan.targetSplitSize());
            if (groupingKeyType().fields().isEmpty()) {
                CloseableIterable<ScanTaskGroup<T>> plannedTaskGroups = TableScanUtil.planTaskGroups(CloseableIterable.withNoopClose(plannedTasks), splitSize, scan.splitLookback(), scan.splitOpenFileCost());
                this.taskGroups = Lists.newArrayList(plannedTaskGroups);
                LOG.debug("Planned {} task group(s) without data grouping for table {}", this.taskGroups.size(), table().name());
            } else {
                List<ScanTaskGroup<T>> plannedTaskGroups = TableScanUtil.planTaskGroups(plannedTasks, splitSize, scan.splitLookback(), scan.splitOpenFileCost(), groupingKeyType());
                // Unique keys are collected only to report them in the debug log below.
                StructLikeSet plannedGroupingKeys = collectGroupingKeys(plannedTaskGroups);
                LOG.debug("Planned {} task group(s) with {} grouping key type and {} unique grouping key(s) for table {}", plannedTaskGroups.size(), groupingKeyType(), plannedGroupingKeys.size(), table().name());
                this.taskGroups = plannedTaskGroups;
            }
        }
        return this.taskGroups;
    }

    /**
     * Replaces the planned tasks and invalidates the cached task groups so they are regrouped on
     * the next call to {@link #taskGroups()}. The cached specs, grouping key type and grouping key
     * transforms are left unchanged.
     */
    protected void resetTasks(List<T> filteredTasks) {
        this.taskGroups = null;
        this.tasks = filteredTasks;
    }

    /** Collects the distinct grouping keys of the given task groups. */
    private StructLikeSet collectGroupingKeys(Iterable<ScanTaskGroup<T>> taskGroupIterable) {
        StructLikeSet keys = StructLikeSet.create(groupingKeyType());
        for (ScanTaskGroup<T> taskGroup : taskGroupIterable) {
            keys.add(taskGroup.groupingKey());
        }
        return keys;
    }
}

