/*
 * Decompiled with CFR 0.152.
 */
package org.apache.iceberg.spark.source;

import java.io.File;
import java.io.IOException;
import java.io.Serializable;
import java.net.URI;
import java.sql.Timestamp;
import java.time.Instant;
import java.util.Arrays;
import java.util.List;
import java.util.Map;
import java.util.Random;
import java.util.Set;
import java.util.concurrent.TimeUnit;
import java.util.function.Predicate;
import java.util.stream.Collectors;
import org.apache.hadoop.conf.Configuration;
import org.apache.hadoop.fs.FSDataInputStream;
import org.apache.hadoop.fs.Path;
import org.apache.hadoop.fs.RawLocalFileSystem;
import org.apache.iceberg.PartitionSpec;
import org.apache.iceberg.Schema;
import org.apache.iceberg.Table;
import org.apache.iceberg.expressions.Literal;
import org.apache.iceberg.hadoop.HadoopTables;
import org.apache.iceberg.relocated.com.google.common.collect.ImmutableList;
import org.apache.iceberg.relocated.com.google.common.collect.ImmutableMap;
import org.apache.iceberg.relocated.com.google.common.collect.Maps;
import org.apache.iceberg.relocated.com.google.common.collect.Sets;
import org.apache.iceberg.spark.SparkSchemaUtil;
import org.apache.iceberg.spark.source.LogMessage;
import org.apache.iceberg.transforms.Transform;
import org.apache.iceberg.transforms.Transforms;
import org.apache.iceberg.types.Type;
import org.apache.iceberg.types.Types;
import org.apache.spark.SparkContext;
import org.apache.spark.api.java.JavaRDD;
import org.apache.spark.api.java.JavaSparkContext;
import org.apache.spark.sql.Dataset;
import org.apache.spark.sql.Row;
import org.apache.spark.sql.SparkSession;
import org.apache.spark.sql.api.java.UDF1;
import org.apache.spark.sql.catalyst.expressions.GenericInternalRow;
import org.apache.spark.sql.catalyst.util.DateTimeUtils;
import org.apache.spark.sql.types.DataTypes;
import org.apache.spark.unsafe.types.UTF8String;
import org.junit.AfterClass;
import org.junit.Assert;
import org.junit.BeforeClass;
import org.junit.Rule;
import org.junit.Test;
import org.junit.rules.TemporaryFolder;
import org.junit.runner.RunWith;
import org.junit.runners.Parameterized;

@RunWith(value=Parameterized.class)
public class TestPartitionPruning {
    /** Hadoop configuration behind {@link #TABLES}; {@code startSpark()} registers the counting file system on it. */
    private static final Configuration CONF = new Configuration();
    /** Path-based table loader used by {@code createTable(File)}. */
    private static final HadoopTables TABLES = new HadoopTables(CONF);
    /** Value written to the table's {@code write.format.default} property. */
    private final String format;
    /** Value passed as the {@code vectorization-enabled} read option. */
    private final boolean vectorized;
    /** Shared session, created in {@code startSpark()} and cleared in {@code stopSpark()}. */
    private static SparkSession spark = null;
    private static JavaSparkContext sparkContext = null;
    // Transforms mirroring the bucket/truncate/hour partition fields of the spec below; they back the
    // "bucket3", "truncate5" and "hour" UDFs and the expected-partition predicates in the tests.
    private static final Transform<Object, Integer> bucketTransform = Transforms.bucket(Types.IntegerType.get(), 3);
    private static final Transform<Object, Object> truncateTransform = Transforms.truncate(Types.StringType.get(), 5);
    // NOTE(review): built on timestamp-without-zone while the schema column is timestamp-with-zone — confirm intended.
    private static final Transform<Object, Integer> hourTransform = Transforms.hour(Types.TimestampType.withoutZone());
    /** Schema of the log table: id, date, level, message, timestamp (all optional). */
    private static final Schema LOG_SCHEMA = new Schema(
            Types.NestedField.optional(1, "id", Types.IntegerType.get()),
            Types.NestedField.optional(2, "date", Types.StringType.get()),
            Types.NestedField.optional(3, "level", Types.StringType.get()),
            Types.NestedField.optional(4, "message", Types.StringType.get()),
            Types.NestedField.optional(5, "timestamp", Types.TimestampType.withZone()));
    /**
     * Fixture rows written to the table. Tests refer to entries by position (see
     * {@code testPartitionPruningBucketingInteger}), so keep the order stable.
     * The elements are passed without {@code Object} casts so that the list is inferred as
     * {@code ImmutableList<LogMessage>} and matches the declared type.
     */
    private static final List<LogMessage> LOGS = ImmutableList.of(
            LogMessage.debug("2020-02-02", "debug event 1", getInstant("2020-02-02T00:00:00")),
            LogMessage.info("2020-02-02", "info event 1", getInstant("2020-02-02T01:00:00")),
            LogMessage.debug("2020-02-02", "debug event 2", getInstant("2020-02-02T02:00:00")),
            LogMessage.info("2020-02-03", "info event 2", getInstant("2020-02-03T00:00:00")),
            LogMessage.debug("2020-02-03", "debug event 3", getInstant("2020-02-03T01:00:00")),
            LogMessage.info("2020-02-03", "info event 3", getInstant("2020-02-03T02:00:00")),
            LogMessage.error("2020-02-03", "error event 1", getInstant("2020-02-03T03:00:00")),
            LogMessage.debug("2020-02-04", "debug event 4", getInstant("2020-02-04T01:00:00")),
            LogMessage.warn("2020-02-04", "warn event 1", getInstant("2020-02-04T02:00:00")),
            LogMessage.debug("2020-02-04", "debug event 5", getInstant("2020-02-04T03:00:00")));
    /** Supplies a fresh directory per test for the table location. */
    @Rule
    public TemporaryFolder temp = new TemporaryFolder();
    /**
     * Partition spec under test. Field order matters: the test predicates read the partition struct by
     * position (0 = date, 1 = level, 2 = id bucket, 3 = truncated message, 4 = timestamp hour).
     */
    private final PartitionSpec spec = PartitionSpec.builderFor(LOG_SCHEMA)
            .identity("date")
            .identity("level")
            .bucket("id", 3)
            .truncate("message", 5)
            .hour("timestamp")
            .build();

    /**
     * Supplies the (file format, vectorized read) combinations each test is run with.
     *
     * @return one row per combination: format name followed by the vectorization flag
     */
    @Parameterized.Parameters(name="format = {0}, vectorized = {1}")
    public static Object[][] parameters() {
        Object[][] combinations = {
            {"parquet", false},
            {"parquet", true},
            {"avro", false},
            {"orc", false},
            {"orc", true},
        };
        return combinations;
    }

    /**
     * Creates one parameterized test instance.
     *
     * @param format file format stored as the table's default write format
     * @param vectorized whether reads are issued with vectorization enabled
     */
    public TestPartitionPruning(String format, boolean vectorized) {
        this.vectorized = vectorized;
        this.format = format;
    }

    /**
     * Starts the shared local Spark session, maps the counting file system's scheme to
     * {@link CountOpenLocalFileSystem} for both Hadoop and Spark, pins the session time zone to UTC and
     * registers the {@code bucket3}, {@code truncate5} and {@code hour} UDFs used to build the test dataset.
     */
    @BeforeClass
    public static void startSpark() {
        spark = SparkSession.builder().master("local[2]").getOrCreate();
        sparkContext = JavaSparkContext.fromSparkContext(spark.sparkContext());

        // Same key/value goes to the Hadoop configuration used by HadoopTables and to the Spark session.
        String fsImplKey = String.format("fs.%s.impl", CountOpenLocalFileSystem.scheme);
        String fsImplClass = CountOpenLocalFileSystem.class.getName();
        CONF.set(fsImplKey, fsImplClass);
        spark.conf().set(fsImplKey, fsImplClass);
        spark.conf().set("spark.sql.session.timeZone", "UTC");

        // UDF1 already extends Serializable, so plain typed lambdas are sufficient here.
        UDF1<Object, Integer> bucket3 = num -> bucketTransform.apply(num);
        UDF1<Object, Object> truncate5 = str -> truncateTransform.apply(str);
        UDF1<Timestamp, Integer> hour = ts -> hourTransform.apply(DateTimeUtils.fromJavaTimestamp(ts));

        spark.udf().register("bucket3", bucket3, DataTypes.IntegerType);
        spark.udf().register("truncate5", truncate5, DataTypes.StringType);
        spark.udf().register("hour", hour, DataTypes.IntegerType);
    }

    /**
     * Stops the shared Spark session and clears the static references to it.
     *
     * <p>Safe to call when the session was never created (for example when {@code startSpark()} failed),
     * so a setup failure is not followed by a {@code NullPointerException} from teardown.
     */
    @AfterClass
    public static void stopSpark() {
        SparkSession currentSpark = spark;
        // Clear the references first so nothing keeps using a session that is being stopped.
        spark = null;
        sparkContext = null;
        if (currentSpark != null) {
            currentSpark.stop();
        }
    }

    /**
     * Parses a zone-less timestamp string through an Iceberg timestamp literal and converts the resulting
     * epoch microseconds to an {@link Instant} with millisecond precision.
     *
     * @param timestampWithoutZone timestamp text such as {@code 2020-02-02T00:00:00}
     * @return the instant for that timestamp, truncated to milliseconds
     */
    private static Instant getInstant(String timestampWithoutZone) {
        Type timestampType = Types.TimestampType.withoutZone();
        Long micros = (Long) Literal.of((CharSequence) timestampWithoutZone).to(timestampType).value();
        long millis = TimeUnit.MICROSECONDS.toMillis(micros);
        return Instant.ofEpochMilli(millis);
    }

    /**
     * Filters on the two identity-partitioned string columns and expects only files whose partition has
     * {@code date >= 2020-02-03} and {@code level = DEBUG} to be opened.
     */
    @Test
    public void testPartitionPruningIdentityString() {
        Predicate<Row> partCondition = r -> {
            // Partition positions 0 and 1 hold the identity values of date and level.
            String date = r.getString(0);
            String level = r.getString(1);
            if (date.compareTo("2020-02-03") < 0) {
                return false;
            }
            return level.equals("DEBUG");
        };
        runTest("date >= '2020-02-03' AND level = 'DEBUG'", partCondition);
    }

    /**
     * Filters on two concrete ids with an {@code IN} clause and expects only files whose bucket partition
     * value matches the bucket of one of those ids to be opened.
     */
    @Test
    public void testPartitionPruningBucketingInteger() {
        int[] ids = new int[] {LOGS.get(3).getId(), LOGS.get(7).getId()};
        String condForIds = Arrays.stream(ids)
                .mapToObj(String::valueOf)
                .collect(Collectors.joining(",", "(", ")"));
        String filterCond = "id in " + condForIds;

        // The expected buckets depend only on the ids, so compute them once instead of once per partition row.
        Set<Integer> buckets = Arrays.stream(ids)
                .map(id -> bucketTransform.apply(id))
                .boxed()
                .collect(Collectors.toSet());

        // Partition position 2 holds the bucket of the id column.
        Predicate<Row> partCondition = r -> buckets.contains(r.getInt(2));
        runTest(filterCond, partCondition);
    }

    /**
     * Filters with a prefix that is longer than the truncation width and expects only files whose
     * truncated-message partition value is exactly {@code "info "} to be opened.
     */
    @Test
    public void testPartitionPruningTruncatedString() {
        // Partition position 3 holds the truncated message.
        Predicate<Row> partCondition = r -> r.getString(3).equals("info ");
        runTest("message like 'info event%'", partCondition);
    }

    /**
     * Filters with a prefix that is shorter than the truncation width and expects only files whose
     * truncated-message partition value starts with {@code "inf"} to be opened.
     */
    @Test
    public void testPartitionPruningTruncatedStringComparingValueShorterThanPartitionValue() {
        // Partition position 3 holds the truncated message.
        Predicate<Row> partCondition = r -> r.getString(3).startsWith("inf");
        runTest("message like 'inf%'", partCondition);
    }

    /**
     * Filters on the timestamp column and expects only files whose hour partition value is at or after
     * the hour of {@code 2020-02-03T01:00:00} to be opened.
     */
    @Test
    public void testPartitionPruningHourlyPartition() {
        // Spark 2.x gets an explicit to_timestamp() around the literal; other versions use the bare string.
        String filterCond = spark.version().startsWith("2")
                ? "timestamp >= to_timestamp('2020-02-03T01:00:00')"
                : "timestamp >= '2020-02-03T01:00:00'";

        // The threshold is the same for every row, so derive it once up front.
        Instant instant = getInstant("2020-02-03T01:00:00");
        int hourValueToFilter = hourTransform.apply(TimeUnit.MILLISECONDS.toMicros(instant.toEpochMilli()));

        // Partition position 4 holds the hour of the timestamp column.
        Predicate<Row> partCondition = r -> r.getInt(4) >= hourValueToFilter;
        runTest(filterCond, partCondition);
    }

    /**
     * Runs one pruning scenario end to end: creates a table in a fresh directory, writes the test dataset,
     * checks that the filtered Iceberg read returns the same rows as filtering the in-memory dataset, and
     * then verifies which data files were opened.
     *
     * @param filterCond SQL filter applied to both the source dataset and the Iceberg read
     * @param partCondition predicate over a data file's partition struct selecting files expected to be read
     */
    private void runTest(String filterCond, Predicate<Row> partCondition) {
        File originTableLocation = createTempDir();
        Assert.assertTrue("Temp folder should exist", originTableLocation.exists());

        Table table = createTable(originTableLocation);
        Dataset<Row> logs = createTestDataset();
        saveTestDatasetToTable(logs, table);

        List<Row> expected = logs
                .select("id", "date", "level", "message", "timestamp")
                .filter(filterCond)
                .orderBy("id")
                .collectAsList();
        Assert.assertFalse("Expected rows should be not empty", expected.isEmpty());

        // Forget opens recorded so far (e.g. during the write) so only the read below is counted.
        CountOpenLocalFileSystem.resetRecordsInPathPrefix(originTableLocation.getAbsolutePath());

        Dataset<Row> icebergRead = spark.read()
                .format("iceberg")
                .option("vectorization-enabled", String.valueOf(vectorized))
                .load(table.location());
        List<Row> actual = icebergRead
                .select("id", "date", "level", "message", "timestamp")
                .filter(filterCond)
                .orderBy("id")
                .collectAsList();
        Assert.assertFalse("Actual rows should not be empty", actual.isEmpty());
        Assert.assertEquals("Rows should match", expected, actual);

        assertAccessOnDataFiles(originTableLocation, table, partCondition);
    }

    /**
     * Creates a new directory under the {@code temp} rule's folder.
     *
     * @return the newly created directory
     * @throws RuntimeException wrapping the {@link IOException} if the directory cannot be created
     */
    private File createTempDir() {
        try {
            return this.temp.newFolder();
        }
        catch (IOException e) {
            // Only the checked I/O failure is wrapped; unchecked exceptions propagate unchanged.
            throw new RuntimeException("Failed to create temp folder for the table location", e);
        }
    }

    /**
     * Creates the log table at the given directory, addressed through the counting file system's scheme
     * and using this instance's file format as the default write format.
     *
     * @param originTableLocation local directory that will hold the table
     * @return the created table
     */
    private Table createTable(File originTableLocation) {
        Map<String, String> properties = ImmutableMap.of("write.format.default", format);
        String trackedTableLocation = CountOpenLocalFileSystem.convertPath(originTableLocation);
        return TABLES.create(LOG_SCHEMA, spec, properties, trackedTableLocation);
    }

    /**
     * Builds the source dataset from {@code LOGS}: the five table columns plus three helper columns
     * ({@code bucket_id}, {@code truncated_message}, {@code ts_hour}) computed with the registered UDFs.
     *
     * @return the dataset with table columns and the derived partition-value columns
     */
    private Dataset<Row> createTestDataset() {
        // Raw types are kept on purpose: the rows go through Spark's internal-row API.
        List rows = LOGS.stream()
                .map(logMessage -> new GenericInternalRow(new Object[] {
                    logMessage.getId(),
                    UTF8String.fromString(logMessage.getDate()),
                    UTF8String.fromString(logMessage.getLevel()),
                    UTF8String.fromString(logMessage.getMessage()),
                    // the row stores the timestamp as epoch microseconds
                    TimeUnit.MILLISECONDS.toMicros(logMessage.getTimestamp().toEpochMilli())
                }))
                .collect(Collectors.toList());

        JavaRDD rdd = sparkContext.parallelize(rows);
        Dataset df = spark.internalCreateDataFrame(JavaRDD.toRDD(rdd), SparkSchemaUtil.convert(LOG_SCHEMA), false);

        String[] tableColumns = {"id", "date", "level", "message", "timestamp"};
        String[] columnsWithPartitionValues = {
            "id", "date", "level", "message", "timestamp",
            "bucket3(id) AS bucket_id",
            "truncate5(message) AS truncated_message",
            "hour(timestamp) AS ts_hour"
        };
        return df.selectExpr(tableColumns).selectExpr(columnsWithPartitionValues);
    }

    /**
     * Appends the dataset to the table, ordering rows by the partition-value columns first and writing
     * only the five table columns.
     *
     * @param logs dataset produced by {@code createTestDataset()}
     * @param table destination table
     */
    private void saveTestDatasetToTable(Dataset<Row> logs, Table table) {
        Dataset<Row> orderedByPartition =
                logs.orderBy("date", "level", "bucket_id", "truncated_message", "ts_hour");
        Dataset<Row> tableColumnsOnly =
                orderedByPartition.select("id", "date", "level", "message", "timestamp");
        tableColumnsOnly.write()
                .format("iceberg")
                .mode("append")
                .save(table.location());
    }

    /**
     * Verifies the pruning outcome: compares the paths opened under the table directory with the data
     * files listed in the table's {@code #files} metadata, split by whether their partition satisfies
     * {@code partCondition}.
     *
     * @param originTableLocation local directory of the table, used to select the recorded opens
     * @param table table whose data files are inspected
     * @param partCondition predicate over the partition struct selecting files expected to be read
     */
    private void assertAccessOnDataFiles(File originTableLocation, Table table, Predicate<Row> partCondition) {
        String tableRoot = originTableLocation.getAbsolutePath();
        Set<String> readFilesInQuery = CountOpenLocalFileSystem.pathToNumOpenCalled.keySet().stream()
                .filter(path -> path.startsWith(tableRoot))
                .collect(Collectors.toSet());

        List<Row> files = spark.read()
                .format("iceberg")
                .load(table.location() + "#files")
                .collectAsList();
        Set<String> filesToRead = extractFilePathsMatchingConditionOnPartition(files, partCondition);
        Set<String> filesToNotRead = extractFilePathsNotIn(files, filesToRead);

        // Sanity checks on the two groups before comparing them with what was actually opened.
        Assert.assertTrue(Sets.intersection(filesToRead, filesToNotRead).isEmpty());
        Assert.assertFalse("The query should prune some data files.", filesToNotRead.isEmpty());

        String shouldReadMessage = "Some of data files in partition range should be read. Read files in query: "
                + readFilesInQuery + " / data files in partition range: " + filesToRead;
        Assert.assertFalse(shouldReadMessage, Sets.intersection(filesToRead, readFilesInQuery).isEmpty());

        String shouldNotReadMessage = "Data files outside of partition range should not be read. Read files in query: "
                + readFilesInQuery + " / data files outside of partition range: " + filesToNotRead;
        Assert.assertTrue(shouldNotReadMessage, Sets.intersection(filesToNotRead, readFilesInQuery).isEmpty());
    }

    /**
     * Collects the scheme-less paths of the file rows whose partition struct satisfies the condition.
     *
     * @param files rows of the table's files listing; column 1 is read as the path, column 4 as the partition
     * @param condition predicate applied to each row's partition struct
     * @return paths of the matching files with the custom scheme stripped
     */
    private Set<String> extractFilePathsMatchingConditionOnPartition(List<Row> files, Predicate<Row> condition) {
        return files.stream()
                .filter(fileRow -> condition.test(fileRow.getStruct(4)))
                .map(fileRow -> fileRow.getString(1))
                .map(CountOpenLocalFileSystem::stripScheme)
                .collect(Collectors.toSet());
    }

    /**
     * Collects the scheme-less paths of all file rows that are not contained in {@code filePaths}.
     *
     * <p>This is a one-sided set difference: entries of {@code filePaths} that do not appear in
     * {@code files} are never part of the result.
     *
     * @param files rows of the table's files listing; column 1 is read as the path
     * @param filePaths paths to exclude
     * @return a new mutable set with every listed file path that is absent from {@code filePaths}
     */
    private Set<String> extractFilePathsNotIn(List<Row> files, Set<String> filePaths) {
        Set<String> allFilePaths = files.stream()
                .map(r -> CountOpenLocalFileSystem.stripScheme(r.getString(1)))
                .collect(Collectors.toSet());
        // Copy out of the live view so the result does not change with the inputs.
        return Sets.newHashSet(Sets.difference(allFilePaths, filePaths));
    }

    public static class CountOpenLocalFileSystem
    extends RawLocalFileSystem {
        public static String scheme = String.format("TestIdentityPartitionData%dfs", new Random().nextInt());
        public static Map<String, Long> pathToNumOpenCalled = Maps.newConcurrentMap();

        public static String convertPath(String absPath) {
            return scheme + "://" + absPath;
        }

        public static String convertPath(File file) {
            return CountOpenLocalFileSystem.convertPath(file.getAbsolutePath());
        }

        public static String stripScheme(String pathWithScheme) {
            if (!pathWithScheme.startsWith(scheme + ":")) {
                throw new IllegalArgumentException("Received unexpected path: " + pathWithScheme);
            }
            int idxToCut = scheme.length() + 1;
            while (pathWithScheme.charAt(idxToCut) == '/') {
                ++idxToCut;
            }
            return pathWithScheme.substring(--idxToCut);
        }

        public static void resetRecordsInPathPrefix(String pathPrefix) {
            pathToNumOpenCalled.keySet().stream().filter(p -> p.startsWith(pathPrefix)).forEach(key -> pathToNumOpenCalled.remove(key));
        }

        public URI getUri() {
            return URI.create(scheme + ":///");
        }

        public String getScheme() {
            return scheme;
        }

        public FSDataInputStream open(Path f, int bufferSize) throws IOException {
            String path = f.toUri().getPath();
            pathToNumOpenCalled.compute(path, (ignored, v) -> {
                if (v == null) {
                    return 1L;
                }
                return v + 1L;
            });
            return super.open(f, bufferSize);
        }
    }
}

