/*
 * Decompiled with CFR 0.152.
 */
package io.trino.plugin.kafka;

import com.google.common.annotations.VisibleForTesting;
import com.google.common.base.Verify;
import com.google.common.collect.ImmutableList;
import com.google.common.collect.ImmutableMap;
import com.google.common.collect.ImmutableSet;
import com.google.common.collect.Iterables;
import io.trino.plugin.kafka.KafkaAdminFactory;
import io.trino.plugin.kafka.KafkaColumnHandle;
import io.trino.plugin.kafka.KafkaConsumerFactory;
import io.trino.plugin.kafka.KafkaErrorCode;
import io.trino.plugin.kafka.KafkaFilteringResult;
import io.trino.plugin.kafka.KafkaInternalFieldManager;
import io.trino.plugin.kafka.KafkaSessionProperties;
import io.trino.plugin.kafka.KafkaTableHandle;
import io.trino.plugin.kafka.Range;
import io.trino.spi.ErrorCodeSupplier;
import io.trino.spi.TrinoException;
import io.trino.spi.connector.ColumnHandle;
import io.trino.spi.connector.ConnectorSession;
import io.trino.spi.predicate.Domain;
import io.trino.spi.predicate.Ranges;
import io.trino.spi.predicate.SortedRangeSet;
import io.trino.spi.predicate.TupleDomain;
import io.trino.spi.predicate.ValueSet;
import io.trino.spi.type.BigintType;
import io.trino.spi.type.TimestampType;
import io.trino.spi.type.Type;
import java.util.Collections;
import java.util.List;
import java.util.Map;
import java.util.Objects;
import java.util.Optional;
import java.util.Set;
import java.util.function.Function;
import javax.inject.Inject;
import org.apache.kafka.clients.admin.Admin;
import org.apache.kafka.clients.admin.Config;
import org.apache.kafka.clients.admin.DescribeConfigsResult;
import org.apache.kafka.clients.consumer.KafkaConsumer;
import org.apache.kafka.clients.consumer.OffsetAndTimestamp;
import org.apache.kafka.common.PartitionInfo;
import org.apache.kafka.common.TopicPartition;
import org.apache.kafka.common.config.ConfigResource;

public class KafkaFilterManager {
    // Sentinel stored in a Range bound when the predicate leaves that side unconstrained.
    private static final long INVALID_KAFKA_RANGE_INDEX = -1L;
    // Name of the topic-level configuration entry inspected by isTimestampUpperBoundPushdownEnabled.
    private static final String TOPIC_CONFIG_TIMESTAMP_KEY = "message.timestamp.type";
    // Value of that entry for which timestamp upper-bound pushdown is always enabled.
    private static final String TOPIC_CONFIG_TIMESTAMP_VALUE_LOG_APPEND_TIME = "LogAppendTime";
    // Creates the consumer used to translate timestamps into partition offsets.
    private final KafkaConsumerFactory consumerFactory;
    // Creates the admin client used to read the topic configuration.
    private final KafkaAdminFactory adminFactory;
    // Resolves internal field ids to the column names used as keys of the pushed-down domains.
    private final KafkaInternalFieldManager kafkaInternalFieldManager;

    /**
     * Creates the filter manager from its injected collaborators.
     *
     * @param consumerFactory factory for Kafka consumers; must not be null
     * @param adminFactory factory for Kafka admin clients; must not be null
     * @param kafkaInternalFieldManager registry of the connector's internal fields; must not be null
     * @throws NullPointerException if any argument is null
     */
    @Inject
    public KafkaFilterManager(KafkaConsumerFactory consumerFactory, KafkaAdminFactory adminFactory, KafkaInternalFieldManager kafkaInternalFieldManager) {
        // Validate everything up front (same order as the assignments), then store.
        Objects.requireNonNull(consumerFactory, "consumerFactory is null");
        Objects.requireNonNull(adminFactory, "adminFactory is null");
        Objects.requireNonNull(kafkaInternalFieldManager, "kafkaInternalFieldManager is null");

        this.consumerFactory = consumerFactory;
        this.adminFactory = adminFactory;
        this.kafkaInternalFieldManager = kafkaInternalFieldManager;
    }

    /**
     * Narrows the partitions and per-partition offset ranges using the predicates that the
     * table handle carries on the internal partition-id, partition-offset and offset-timestamp
     * columns.
     *
     * <p>When the constraint is "all" the inputs are returned unchanged. Otherwise:
     * <ul>
     *   <li>a partition-offset range tightens the begin/end offset of every partition;</li>
     *   <li>an offset-timestamp range is translated to offsets through a Kafka consumer; the
     *       upper bound is only applied when {@code isTimestampUpperBoundPushdownEnabled}
     *       allows it;</li>
     *   <li>a partition-id domain removes non-matching entries from {@code partitionInfos}.</li>
     * </ul>
     * Begin offsets are only ever raised and end offsets only ever lowered.
     *
     * @param session current connector session
     * @param kafkaTableHandle table handle whose constraint is evaluated
     * @param partitionInfos partitions of the topic
     * @param partitionBeginOffsets current begin offset per partition
     * @param partitionEndOffsets current end offset per partition
     * @return the filtered partitions together with the adjusted begin and end offsets
     * @throws NullPointerException if any argument is null
     * @throws com.google.common.base.VerifyException if the constraint is "none"
     */
    public KafkaFilteringResult getKafkaFilterResult(ConnectorSession session, KafkaTableHandle kafkaTableHandle, List<PartitionInfo> partitionInfos, Map<TopicPartition, Long> partitionBeginOffsets, Map<TopicPartition, Long> partitionEndOffsets) {
        Objects.requireNonNull(session, "session is null");
        Objects.requireNonNull(kafkaTableHandle, "kafkaTableHandle is null");
        Objects.requireNonNull(partitionInfos, "partitionInfos is null");
        Objects.requireNonNull(partitionBeginOffsets, "partitionBeginOffsets is null");
        Objects.requireNonNull(partitionEndOffsets, "partitionEndOffsets is null");

        TupleDomain<ColumnHandle> constraint = kafkaTableHandle.getConstraint();
        Verify.verify(!constraint.isNone(), "constraint is none");
        if (constraint.isAll()) {
            return new KafkaFilteringResult(partitionInfos, partitionBeginOffsets, partitionEndOffsets);
        }

        // PartitionInfo.partition() is an int: widen it BEFORE boxing so the set really holds
        // Longs. Boxing the int directly yields Integers, which never equal the Long values
        // that filterValuesByDomain works with.
        Set<Long> partitionIds = partitionInfos.stream()
                .map(partitionInfo -> (long) partitionInfo.partition())
                .collect(ImmutableSet.toImmutableSet());
        Map<String, Domain> domains = constraint.getDomains().orElseThrow().entrySet().stream()
                .collect(ImmutableMap.toImmutableMap(entry -> ((KafkaColumnHandle) entry.getKey()).getName(), Map.Entry::getValue));

        Optional<Range> offsetRanged = this.getDomain(KafkaInternalFieldManager.InternalFieldId.PARTITION_OFFSET_FIELD, domains)
                .flatMap(KafkaFilterManager::filterRangeByDomain);
        Set<Long> partitionIdsFiltered = this.getDomain(KafkaInternalFieldManager.InternalFieldId.PARTITION_ID_FIELD, domains)
                .map(domain -> KafkaFilterManager.filterValuesByDomain(domain, partitionIds))
                .orElse(partitionIds);
        Optional<Range> offsetTimestampRanged = this.getDomain(KafkaInternalFieldManager.InternalFieldId.OFFSET_TIMESTAMP_FIELD, domains)
                .flatMap(KafkaFilterManager::filterRangeByDomain);

        // Work on locals instead of reassigning the parameters.
        Map<TopicPartition, Long> beginOffsets = partitionBeginOffsets;
        Map<TopicPartition, Long> endOffsets = partitionEndOffsets;

        if (offsetRanged.isPresent()) {
            Range range = offsetRanged.get();
            beginOffsets = KafkaFilterManager.overridePartitionBeginOffsets(beginOffsets,
                    partition -> range.getBegin() != INVALID_KAFKA_RANGE_INDEX ? Optional.of(range.getBegin()) : Optional.empty());
            endOffsets = KafkaFilterManager.overridePartitionEndOffsets(endOffsets,
                    partition -> range.getEnd() != INVALID_KAFKA_RANGE_INDEX ? Optional.of(range.getEnd()) : Optional.empty());
        }

        if (offsetTimestampRanged.isPresent()) {
            Range timestampRange = offsetTimestampRanged.get();
            boolean hasLowerBound = timestampRange.getBegin() > INVALID_KAFKA_RANGE_INDEX;
            // Test the cheap condition first so the admin round-trip is skipped when the
            // predicate has no upper bound at all.
            boolean pushDownUpperBound = timestampRange.getEnd() > INVALID_KAFKA_RANGE_INDEX
                    && this.isTimestampUpperBoundPushdownEnabled(session, kafkaTableHandle.getTopicName());
            // Only open a consumer when there is a timestamp to translate.
            if (hasLowerBound || pushDownUpperBound) {
                try (KafkaConsumer<byte[], byte[]> kafkaConsumer = this.consumerFactory.create(session)) {
                    if (hasLowerBound) {
                        beginOffsets = KafkaFilterManager.overridePartitionBeginOffsets(beginOffsets,
                                partition -> KafkaFilterManager.findOffsetsForTimestampGreaterOrEqual(kafkaConsumer, partition, timestampRange.getBegin()));
                    }
                    if (pushDownUpperBound) {
                        endOffsets = KafkaFilterManager.overridePartitionEndOffsets(endOffsets,
                                partition -> KafkaFilterManager.findOffsetsForTimestampGreaterOrEqual(kafkaConsumer, partition, timestampRange.getEnd()));
                    }
                }
            }
        }

        // Same int -> long widening as above, otherwise contains() would be asked for an Integer.
        List<PartitionInfo> partitionFilteredInfos = partitionInfos.stream()
                .filter(partitionInfo -> partitionIdsFiltered.contains((long) partitionInfo.partition()))
                .collect(ImmutableList.toImmutableList());
        return new KafkaFilteringResult(partitionFilteredInfos, beginOffsets, endOffsets);
    }

    /**
     * Looks up the domain pushed down for one of the connector's internal fields.
     *
     * @param internalFieldId id of the internal field
     * @param columnNameToDomain pushed-down domains keyed by column name
     * @return the domain registered under the field's column name, or empty when there is none
     */
    private Optional<Domain> getDomain(KafkaInternalFieldManager.InternalFieldId internalFieldId, Map<String, Domain> columnNameToDomain) {
        Domain domain = columnNameToDomain.get(this.kafkaInternalFieldManager.getFieldById(internalFieldId).getColumnName());
        return domain == null ? Optional.empty() : Optional.of(domain);
    }

    /**
     * Decides whether an upper bound on the offset-timestamp column may be translated into an
     * end offset for the given topic.
     *
     * <p>The topic configuration is read through an admin client. When its
     * {@code message.timestamp.type} entry equals {@code LogAppendTime} the pushdown is always
     * enabled; in every other case (including a missing configuration or entry) the decision
     * is delegated to the session property.
     *
     * @param session current connector session
     * @param topic name of the Kafka topic
     * @return {@code true} when the timestamp upper bound may be pushed down
     * @throws TrinoException with {@code KAFKA_SPLIT_ERROR} when the topic configuration
     *         cannot be retrieved
     */
    private boolean isTimestampUpperBoundPushdownEnabled(ConnectorSession session, String topic) {
        try (Admin adminClient = this.adminFactory.create(session)) {
            ConfigResource topicResource = new ConfigResource(ConfigResource.Type.TOPIC, topic);
            DescribeConfigsResult describeResult = adminClient.describeConfigs(Collections.singleton(topicResource));
            Map<ConfigResource, Config> configMap = describeResult.all().get();
            if (configMap != null) {
                // Null-safe chain: a missing Config, a missing entry or a null value all fall
                // through to the session property instead of surfacing as a NullPointerException.
                String timestampType = Optional.ofNullable(configMap.get(topicResource))
                        .map(config -> config.get(TOPIC_CONFIG_TIMESTAMP_KEY))
                        .map(entry -> entry.value())
                        .orElse(null);
                if (TOPIC_CONFIG_TIMESTAMP_VALUE_LOG_APPEND_TIME.equals(timestampType)) {
                    return true;
                }
            }
        }
        catch (InterruptedException e) {
            // Restore the interrupt flag so callers further up can still observe the interruption.
            Thread.currentThread().interrupt();
            throw new TrinoException(KafkaErrorCode.KAFKA_SPLIT_ERROR, String.format("Failed to get configuration for topic '%s'", topic), e);
        }
        catch (Exception e) {
            throw new TrinoException(KafkaErrorCode.KAFKA_SPLIT_ERROR, String.format("Failed to get configuration for topic '%s'", topic), e);
        }
        // Evaluated outside the try block so that a failure while reading the session property
        // is not misreported as a topic-configuration failure.
        return KafkaSessionProperties.isTimestampUpperBoundPushdownEnabled(session);
    }

    /**
     * Asks Kafka for the offset that corresponds to a timestamp in a single partition.
     *
     * @param kafkaConsumer consumer used for the {@code offsetsForTimes} lookup
     * @param topicPartition partition to query
     * @param timestamp timestamp to look up; it is floor-divided by 1000 before being sent to
     *        Kafka (presumably microseconds to milliseconds -- confirm against the caller)
     * @return the offset reported by Kafka, or empty when Kafka reports no match for the partition
     */
    private static Optional<Long> findOffsetsForTimestampGreaterOrEqual(KafkaConsumer<byte[], byte[]> kafkaConsumer, TopicPartition topicPartition, long timestamp) {
        long lookupTimestamp = Math.floorDiv(timestamp, 1000);
        Map<TopicPartition, Long> query = ImmutableMap.of(topicPartition, lookupTimestamp);
        Map<TopicPartition, OffsetAndTimestamp> answer = kafkaConsumer.offsetsForTimes(query);
        // Zero entries or a single null entry both mean "no offset found".
        OffsetAndTimestamp match = Iterables.getOnlyElement(answer.values(), null);
        if (match == null) {
            return Optional.empty();
        }
        return Optional.of(match.offset());
    }

    /**
     * Builds a new begin-offset map in which each partition's offset is raised to the override
     * value when one is supplied and it is larger than the current offset.
     *
     * @param partitionBeginOffsets current begin offset per partition
     * @param overrideFunction yields the candidate begin offset for a partition, or empty to
     *        keep the current one
     * @return an immutable map with the same keys and the adjusted begin offsets
     */
    private static Map<TopicPartition, Long> overridePartitionBeginOffsets(Map<TopicPartition, Long> partitionBeginOffsets, Function<TopicPartition, Optional<Long>> overrideFunction) {
        ImmutableMap.Builder<TopicPartition, Long> adjustedOffsets = ImmutableMap.builder();
        partitionBeginOffsets.forEach((partition, currentOffset) -> {
            Optional<Long> candidate = overrideFunction.apply(partition);
            // Never move the begin offset backwards: keep the larger of the two.
            adjustedOffsets.put(partition, candidate.map(offset -> Long.max(currentOffset, offset)).orElse(currentOffset));
        });
        return adjustedOffsets.buildOrThrow();
    }

    /**
     * Builds a new end-offset map in which each partition's offset is lowered to the override
     * value when one is supplied and it is smaller than the current offset.
     *
     * @param partitionEndOffsets current end offset per partition
     * @param overrideFunction yields the candidate end offset for a partition, or empty to keep
     *        the current one
     * @return an immutable map with the same keys and the adjusted end offsets
     */
    private static Map<TopicPartition, Long> overridePartitionEndOffsets(Map<TopicPartition, Long> partitionEndOffsets, Function<TopicPartition, Optional<Long>> overrideFunction) {
        ImmutableMap.Builder<TopicPartition, Long> adjustedOffsets = ImmutableMap.builder();
        partitionEndOffsets.forEach((partition, currentOffset) -> {
            Optional<Long> candidate = overrideFunction.apply(partition);
            // Never move the end offset forwards: keep the smaller of the two.
            adjustedOffsets.put(partition, candidate.map(offset -> Long.min(currentOffset, offset)).orElse(currentOffset));
        });
        return adjustedOffsets.buildOrThrow();
    }

    /**
     * Converts a domain on a long-valued internal column into a half-open {@code [begin, end)}
     * range.
     *
     * <ul>
     *   <li>A single value {@code v} becomes {@code [v, v + 1)}.</li>
     *   <li>A set of discrete values becomes {@code [min, max + 1)}.</li>
     *   <li>Any other sorted range set is reduced to its span; an unbounded side is reported as
     *       {@code -1}.</li>
     *   <li>A value set that is not a {@link SortedRangeSet} yields {@code (-1, -1)}.</li>
     * </ul>
     *
     * @param domain domain to convert
     * @return the resulting range, or empty when the domain holds no value ranges at all (for
     *         example a null-only domain), in which case there is nothing to push down
     */
    @VisibleForTesting
    public static Optional<Range> filterRangeByDomain(Domain domain) {
        long low = INVALID_KAFKA_RANGE_INDEX;
        long high = INVALID_KAFKA_RANGE_INDEX;
        if (domain.isSingleValue()) {
            // Equality predicate: both bounds collapse onto the same value.
            long value = (Long) domain.getSingleValue();
            low = value;
            high = value;
        } else {
            ValueSet valueSet = domain.getValues();
            if (valueSet instanceof SortedRangeSet) {
                Ranges ranges = ((SortedRangeSet) valueSet).getRanges();
                List<io.trino.spi.predicate.Range> rangeList = ranges.getOrderedRanges();
                if (rangeList.isEmpty()) {
                    // Without this guard the min/max lookup below would throw
                    // NoSuchElementException on an empty list.
                    return Optional.empty();
                }
                if (rangeList.stream().allMatch(io.trino.spi.predicate.Range::isSingleValue)) {
                    // IN-list style predicate: cover everything between the smallest and largest value.
                    List<Long> values = rangeList.stream()
                            .map(range -> (Long) range.getSingleValue())
                            .collect(ImmutableList.toImmutableList());
                    low = Collections.min(values);
                    high = Collections.max(values);
                } else {
                    io.trino.spi.predicate.Range span = ranges.getSpan();
                    low = KafkaFilterManager.getLowIncludedValue(span).orElse(low);
                    high = KafkaFilterManager.getHighIncludedValue(span).orElse(high);
                }
            }
        }
        // Turn the inclusive upper bound into an exclusive one; saturate instead of wrapping
        // around to a negative number at Long.MAX_VALUE.
        if (high != INVALID_KAFKA_RANGE_INDEX && high != Long.MAX_VALUE) {
            high = high + 1L;
        }
        return Optional.of(new Range(low, high));
    }

    /**
     * Keeps the source values that can satisfy the given domain.
     *
     * <ul>
     *   <li>A single-value domain keeps only that value.</li>
     *   <li>A set of discrete values keeps those that occur in {@code sourceValues}.</li>
     *   <li>Any other sorted range set keeps the values inside its span (inclusive bounds; a
     *       missing lower bound defaults to 0, a missing upper bound to {@code Long.MAX_VALUE}).</li>
     *   <li>A value set that is not a {@link SortedRangeSet} leaves the values untouched.</li>
     * </ul>
     *
     * @param domain domain to evaluate
     * @param sourceValues candidate values; must not be null
     * @return the matching subset of {@code sourceValues}
     * @throws NullPointerException if {@code sourceValues} is null
     */
    @VisibleForTesting
    public static Set<Long> filterValuesByDomain(Domain domain, Set<Long> sourceValues) {
        Objects.requireNonNull(sourceValues, "sourceValues is none");
        if (domain.isSingleValue()) {
            long singleValue = (Long) domain.getSingleValue();
            return sourceValues.stream()
                    .filter(sourceValue -> sourceValue == singleValue)
                    .collect(ImmutableSet.toImmutableSet());
        }
        ValueSet valueSet = domain.getValues();
        if (!(valueSet instanceof SortedRangeSet)) {
            return sourceValues;
        }
        Ranges ranges = ((SortedRangeSet) valueSet).getRanges();
        List<io.trino.spi.predicate.Range> rangeList = ranges.getOrderedRanges();
        if (rangeList.stream().allMatch(io.trino.spi.predicate.Range::isSingleValue)) {
            // IN-list style predicate: intersect the listed values with the candidates.
            return rangeList.stream()
                    .map(range -> (Long) range.getSingleValue())
                    .filter(sourceValues::contains)
                    .collect(ImmutableSet.toImmutableSet());
        }
        // General ranges are approximated by their overall span.
        io.trino.spi.predicate.Range span = ranges.getSpan();
        long low = KafkaFilterManager.getLowIncludedValue(span).orElse(0L);
        long high = KafkaFilterManager.getHighIncludedValue(span).orElse(Long.MAX_VALUE);
        return sourceValues.stream()
                .filter(item -> item >= low && item <= high)
                .collect(ImmutableSet.toImmutableSet());
    }

    /**
     * Returns the smallest value contained in the range.
     *
     * @param range range whose lower bound is inspected
     * @return the lower bound itself when it is inclusive, the bound plus one granularity step
     *         of the range's type when it is exclusive, or empty when the range has no lower bound
     * @throws IllegalArgumentException if the range's type is not supported
     */
    private static Optional<Long> getLowIncludedValue(io.trino.spi.predicate.Range range) {
        // Resolve the step first so an unsupported type is rejected even for unbounded ranges.
        long step = KafkaFilterManager.nativeRepresentationGranularity(range.getType());
        Optional<Long> lowerBound = range.getLowValue().map(Long.class::cast);
        if (!lowerBound.isPresent()) {
            return Optional.empty();
        }
        long bound = lowerBound.get();
        if (range.isLowInclusive()) {
            return Optional.of(bound);
        }
        return Optional.of(bound + step);
    }

    /**
     * Returns the largest value contained in the range.
     *
     * @param range range whose upper bound is inspected
     * @return the upper bound itself when it is inclusive, the bound minus one granularity step
     *         of the range's type when it is exclusive, or empty when the range has no upper bound
     * @throws IllegalArgumentException if the range's type is not supported
     */
    private static Optional<Long> getHighIncludedValue(io.trino.spi.predicate.Range range) {
        // Resolve the step first so an unsupported type is rejected even for unbounded ranges.
        long step = KafkaFilterManager.nativeRepresentationGranularity(range.getType());
        Optional<Long> upperBound = range.getHighValue().map(Long.class::cast);
        if (!upperBound.isPresent()) {
            return Optional.empty();
        }
        long bound = upperBound.get();
        if (range.isHighInclusive()) {
            return Optional.of(bound);
        }
        return Optional.of(bound - step);
    }

    /**
     * Returns the distance between two adjacent values of the given type, as used when turning
     * an exclusive range bound into an inclusive one.
     *
     * @param type type of the range values
     * @return 1 for {@code BIGINT}, 1000 for a timestamp type with precision 3
     * @throws IllegalArgumentException for any other type
     */
    private static long nativeRepresentationGranularity(Type type) {
        boolean isBigint = type == BigintType.BIGINT;
        if (!isBigint) {
            boolean isPrecisionThreeTimestamp = type instanceof TimestampType && ((TimestampType) type).getPrecision() == 3;
            if (!isPrecisionThreeTimestamp) {
                throw new IllegalArgumentException("Unsupported type: " + type);
            }
            return 1000L;
        }
        return 1L;
    }
}

