/*
 * Decompiled with CFR 0.152.
 */
package org.apache.paimon.flink.shuffle;

import java.util.ArrayList;
import java.util.Arrays;
import java.util.Comparator;
import java.util.Iterator;
import java.util.List;
import java.util.PriorityQueue;
import java.util.Random;
import java.util.function.BiFunction;
import java.util.stream.Collectors;
import org.apache.flink.annotation.Internal;
import org.apache.flink.api.common.functions.OpenContext;
import org.apache.flink.api.common.functions.Partitioner;
import org.apache.flink.api.common.functions.RichMapFunction;
import org.apache.flink.api.common.typeinfo.BasicTypeInfo;
import org.apache.flink.api.common.typeinfo.TypeInformation;
import org.apache.flink.api.dag.Transformation;
import org.apache.flink.api.java.functions.KeySelector;
import org.apache.flink.api.java.tuple.Tuple2;
import org.apache.flink.api.java.tuple.Tuple3;
import org.apache.flink.api.java.typeutils.ListTypeInfo;
import org.apache.flink.api.java.typeutils.TupleTypeInfo;
import org.apache.flink.configuration.Configuration;
import org.apache.flink.streaming.api.datastream.DataStream;
import org.apache.flink.streaming.api.operators.BoundedOneInput;
import org.apache.flink.streaming.api.operators.InputSelectable;
import org.apache.flink.streaming.api.operators.InputSelection;
import org.apache.flink.streaming.api.operators.OneInputStreamOperator;
import org.apache.flink.streaming.api.operators.StreamMap;
import org.apache.flink.streaming.api.operators.TwoInputStreamOperator;
import org.apache.flink.streaming.api.transformations.OneInputTransformation;
import org.apache.flink.streaming.api.transformations.PartitionTransformation;
import org.apache.flink.streaming.api.transformations.StreamExchangeMode;
import org.apache.flink.streaming.api.transformations.TwoInputTransformation;
import org.apache.flink.streaming.runtime.partitioner.BroadcastPartitioner;
import org.apache.flink.streaming.runtime.partitioner.CustomPartitionerWrapper;
import org.apache.flink.streaming.runtime.partitioner.ForwardPartitioner;
import org.apache.flink.streaming.runtime.partitioner.StreamPartitioner;
import org.apache.flink.streaming.runtime.streamrecord.StreamRecord;
import org.apache.flink.table.data.RowData;
import org.apache.flink.table.runtime.operators.TableStreamOperator;
import org.apache.flink.table.runtime.util.StreamRecordCollector;
import org.apache.flink.util.Collector;
import org.apache.flink.util.Preconditions;
import org.apache.flink.util.XORShiftRandom;
import org.apache.paimon.annotation.VisibleForTesting;
import org.apache.paimon.data.DataGetters;
import org.apache.paimon.flink.FlinkRowWrapper;
import org.apache.paimon.types.InternalRowToSizeVisitor;
import org.apache.paimon.types.RowType;
import org.apache.paimon.utils.Pair;
import org.apache.paimon.utils.SerializableSupplier;

public class RangeShuffle {
    /**
     * Redistributes {@code inputDataStream} so that records are routed by key range.
     *
     * <p>The graph built here has five stages, all connected through blocking
     * ({@link StreamExchangeMode#BATCH}) exchanges:
     *
     * <ol>
     *   <li>project every record to a (key, weight) pair via {@link KeyAndSizeExtractor};
     *   <li>sample those pairs on every input subtask ({@link LocalSampleOperator});
     *   <li>merge the local samples on a single subtask and turn them into sorted range
     *       boundaries ({@link GlobalSampleOperator});
     *   <li>broadcast the boundaries and tag each original record with a range index
     *       ({@link AssignRangeIndexOperator});
     *   <li>partition by that index and drop it again ({@link RemoveRangeIndexOperator}).
     * </ol>
     *
     * @param inputDataStream records as (key, row) pairs
     * @param keyComparator supplier of the ordering used for the keys
     * @param keyTypeInformation type information of the key
     * @param localSampleSize number of samples kept per input subtask
     * @param globalSampleSize number of samples kept after merging the local samples
     * @param rangeNum number of key ranges; the range partitioner rejects a downstream partition
     *     count larger than this
     * @param outParallelism parallelism of the final (index-removing) operator
     * @param valueRowType row type used to compute per-record sizes when sorting by size
     * @param isSortBySize if true each record is weighted by its computed size, otherwise by 1
     * @return the range-shuffled stream, with the same element type as the input
     */
    public static <T> DataStream<Tuple2<T, RowData>> rangeShuffleByKey(
            DataStream<Tuple2<T, RowData>> inputDataStream,
            SerializableSupplier<Comparator<T>> keyComparator,
            TypeInformation<T> keyTypeInformation,
            int localSampleSize,
            int globalSampleSize,
            int rangeNum,
            int outParallelism,
            RowType valueRowType,
            boolean isSortBySize) {
        Transformation<Tuple2<T, RowData>> input = inputDataStream.getTransformation();

        // Stage 1: (key, row) -> (key, weight), running alongside the input.
        OneInputTransformation<Tuple2<T, RowData>, Tuple2<T, Integer>> keyAndSize =
                new OneInputTransformation<>(
                        input,
                        "ABSTRACT KEY AND SIZE",
                        new StreamMap<>(new KeyAndSizeExtractor<T>(valueRowType, isSortBySize)),
                        new TupleTypeInfo<Tuple2<T, Integer>>(
                                keyTypeInformation, BasicTypeInfo.INT_TYPE_INFO),
                        input.getParallelism(),
                        input.isParallelismConfigured());

        // Stage 2: fixed-size random sample per subtask, emitted as (weight, key, size).
        OneInputTransformation<Tuple2<T, Integer>, Tuple3<Double, T, Integer>> localSample =
                new OneInputTransformation<>(
                        keyAndSize,
                        "LOCAL SAMPLE",
                        new LocalSampleOperator<T>(localSampleSize),
                        new TupleTypeInfo<Tuple3<Double, T, Integer>>(
                                BasicTypeInfo.DOUBLE_TYPE_INFO,
                                keyTypeInformation,
                                BasicTypeInfo.INT_TYPE_INFO),
                        keyAndSize.getParallelism(),
                        keyAndSize.isParallelismConfigured());

        // Stage 3: single-subtask merge of all local samples into range boundaries.
        OneInputTransformation<Tuple3<Double, T, Integer>, List<T>> globalSample =
                new OneInputTransformation<>(
                        localSample,
                        "GLOBAL SAMPLE",
                        new GlobalSampleOperator<T>(globalSampleSize, keyComparator, rangeNum),
                        new ListTypeInfo<T>(keyTypeInformation),
                        1,
                        true);

        // Stage 4: boundaries are broadcast (first input); records are forwarded (second input).
        // Construction order is kept identical to the nested form so node ids are unchanged.
        PartitionTransformation<List<T>> broadcastBoundaries =
                new PartitionTransformation<>(
                        globalSample, new BroadcastPartitioner<>(), StreamExchangeMode.BATCH);
        PartitionTransformation<Tuple2<T, RowData>> forwardedRecords =
                new PartitionTransformation<>(
                        input, new ForwardPartitioner<>(), StreamExchangeMode.BATCH);
        TwoInputTransformation<List<T>, Tuple2<T, RowData>, Tuple2<Integer, Tuple2<T, RowData>>>
                assignRangeIndex =
                        new TwoInputTransformation<>(
                                broadcastBoundaries,
                                forwardedRecords,
                                "ASSIGN RANGE INDEX",
                                new AssignRangeIndexOperator<T>(keyComparator),
                                new TupleTypeInfo<Tuple2<Integer, Tuple2<T, RowData>>>(
                                        BasicTypeInfo.INT_TYPE_INFO, input.getOutputType()),
                                input.getParallelism(),
                                input.isParallelismConfigured());

        // Stage 5: route by range index, then strip the index.
        PartitionTransformation<Tuple2<Integer, Tuple2<T, RowData>>> rangePartitioned =
                new PartitionTransformation<>(
                        assignRangeIndex,
                        new CustomPartitionerWrapper<>(
                                new AssignRangeIndexOperator.RangePartitioner(rangeNum),
                                new AssignRangeIndexOperator.Tuple2KeySelector<T>()),
                        StreamExchangeMode.BATCH);
        OneInputTransformation<Tuple2<Integer, Tuple2<T, RowData>>, Tuple2<T, RowData>>
                removeRangeIndex =
                        new OneInputTransformation<>(
                                rangePartitioned,
                                "REMOVE RANGE INDEX",
                                new RemoveRangeIndexOperator<T>(),
                                input.getOutputType(),
                                outParallelism,
                                true);

        return new DataStream<>(inputDataStream.getExecutionEnvironment(), removeRangeIndex);
    }

    /**
     * Chooses {@code rangesNum - 1} boundary keys from a weighted, key-sorted sample so that the
     * ranges carry roughly equal total weight.
     *
     * <p>The sample is consumed in order. For each boundary, samples are consumed until the
     * accumulated weight reaches the current step; the boundary is the key of the last sample
     * consumed. After each boundary the step is recomputed as the remaining weight divided over
     * the remaining ranges. Boundaries left unassigned because the sample ran out are filled with
     * the last (largest) sampled key.
     *
     * @param sampledData (key, weight) pairs; expected to be sorted by key (the caller sorts them)
     * @param rangesNum number of ranges to cut the sample into; must be positive
     * @return {@code rangesNum - 1} boundary keys. The array's runtime type is {@code Object[]},
     *     so callers must not assign it to a reified {@code T[]} of a concrete type.
     * @throws IllegalArgumentException if {@code rangesNum < 1}, or if {@code sampledData} is
     *     empty while at least one boundary is requested
     */
    @VisibleForTesting
    static <T> T[] allocateRangeBaseSize(List<Tuple2<T, Integer>> sampledData, int rangesNum) {
        Preconditions.checkArgument(
                rangesNum > 0, "rangesNum should be positive, but is %s.", rangesNum);
        int sampleNum = sampledData.size();
        int boundarySize = rangesNum - 1;
        // Generic arrays cannot be created directly; the Object[] is only ever read as T values.
        @SuppressWarnings("unchecked")
        T[] boundaries = (T[]) new Object[boundarySize];
        if (boundarySize == 0) {
            return boundaries;
        }
        Preconditions.checkArgument(
                !sampledData.isEmpty(),
                "Cannot allocate %s range boundaries from an empty sample.",
                boundarySize);

        long restSize = sampledData.stream().mapToLong(t -> t.f1).sum();
        double stepRange = (double) restSize / rangesNum;
        // long rather than int: the summed weight (e.g. row sizes) of one range can exceed
        // Integer.MAX_VALUE, and an overflow would make the loop below over-consume samples.
        long currentWeight = 0;
        int index = 0;
        for (int i = 0; i < boundarySize; i++) {
            while (currentWeight < stepRange && index < sampleNum) {
                Tuple2<T, Integer> sample = sampledData.get(index++);
                boundaries[i] = sample.f0;
                currentWeight += sample.f1;
                restSize -= sample.f1;
            }
            currentWeight = 0;
            // Spread the remaining weight evenly over the ranges that are still open.
            stepRange = (double) restSize / (rangesNum - i - 1);
        }

        // The sample can be exhausted before every boundary is assigned; pad with the largest key.
        T lastKey = sampledData.get(sampleNum - 1).f0;
        for (int i = 0; i < boundarySize; i++) {
            if (boundaries[i] == null) {
                boundaries[i] = lastKey;
            }
        }
        return boundaries;
    }

    /**
     * A collection of boundary positions that returns one of its members chosen uniformly at
     * random. It is used to spread records whose key equals a repeated boundary across every
     * range sharing that boundary.
     */
    private static class RandomList {
        /** Shared source of randomness for all instances. */
        private static final Random RANDOM = new Random();

        private final List<Integer> candidates = new ArrayList<>();

        private RandomList() {}

        /** Registers {@code i} as a candidate. */
        public void add(int i) {
            candidates.add(i);
        }

        /**
         * Returns one registered candidate, chosen uniformly at random.
         *
         * @throws IllegalArgumentException if no candidate has been added (propagated from
         *     {@link Random#nextInt(int)})
         */
        public int get() {
            int pick = RANDOM.nextInt(candidates.size());
            return candidates.get(pick);
        }
    }

    /**
     * Keeps the {@code numSamples} elements with the largest weights seen so far.
     *
     * <p>{@link #collect(Object)} assigns each element a random weight from a seeded
     * {@link XORShiftRandom}. Keeping the top-k by random weight yields a uniform sample without
     * replacement. {@link #collect(double, Object)} lets a caller supply the weight itself, for
     * example to merge samples that were already weighted upstream.
     *
     * <p>Not thread-safe; each operator instance owns its sampler.
     */
    private static class Sampler<T> {
        private final int numSamples;
        private final Random random;
        /** Min-heap on weight, so the head is the current eviction candidate. */
        private final PriorityQueue<Tuple2<Double, T>> queue;
        /** Number of elements offered so far. */
        private int index = 0;
        /** Cached head of {@link #queue}; null until the first element is kept. */
        private Tuple2<Double, T> smallest = null;

        /**
         * @param numSamples maximum number of elements to retain; must be non-negative. Zero
         *     retains nothing.
         * @param seed seed for the weight generator used by {@link #collect(Object)}
         */
        Sampler(int numSamples, long seed) {
            Preconditions.checkArgument(numSamples >= 0, "numSamples should be non-negative.");
            this.numSamples = numSamples;
            this.random = new XORShiftRandom(seed);
            // PriorityQueue rejects an initial capacity below 1, which would make the
            // (allowed) numSamples == 0 case fail here.
            this.queue =
                    new PriorityQueue<>(
                            Math.max(1, numSamples), Comparator.comparingDouble(o -> o.f0));
        }

        /** Offers {@code rowData} with a freshly drawn random weight. */
        void collect(T rowData) {
            collect(random.nextDouble(), rowData);
        }

        /** Offers {@code key} with the given {@code weight}. */
        void collect(double weight, T key) {
            if (index < numSamples) {
                // Still filling up: keep every element.
                addQueue(weight, key);
            } else if (numSamples > 0 && weight > smallest.f0) {
                // Full: replace the lightest retained element if this one is heavier.
                // The numSamples > 0 guard avoids dereferencing the still-null smallest.
                queue.remove();
                addQueue(weight, key);
            }
            index++;
        }

        private void addQueue(double weight, T row) {
            queue.add(new Tuple2<>(weight, row));
            smallest = queue.peek();
        }

        /**
         * Returns an iterator over the retained (weight, element) pairs. As with any
         * {@link PriorityQueue} iterator, the iteration order is unspecified.
         */
        Iterator<Tuple2<Double, T>> sample() {
            return queue.iterator();
        }
    }

    private static class RemoveRangeIndexOperator<T>
    extends TableStreamOperator<Tuple2<T, RowData>>
    implements OneInputStreamOperator<Tuple2<Integer, Tuple2<T, RowData>>, Tuple2<T, RowData>> {
        private static final long serialVersionUID = 1L;
        private transient Collector<Tuple2<T, RowData>> collector;

        private RemoveRangeIndexOperator() {
        }

        public void open() throws Exception {
            super.open();
            this.collector = new StreamRecordCollector(this.output);
        }

        public void processElement(StreamRecord<Tuple2<Integer, Tuple2<T, RowData>>> streamRecord) throws Exception {
            this.collector.collect(((Tuple2)streamRecord.getValue()).f1);
        }
    }

    /**
     * Tags each record with the index of the key range it belongs to.
     *
     * <ul>
     *   <li>Input 1 receives the sorted list of range boundaries (broadcast).
     *   <li>Input 2 receives the records themselves.
     * </ul>
     *
     * Until a boundary list has been received, only input 1 is selected (see
     * {@link #nextSelection()}). A record arriving on input 2 before that is rejected.
     */
    private static class AssignRangeIndexOperator<T>
            extends TableStreamOperator<Tuple2<Integer, Tuple2<T, RowData>>>
            implements TwoInputStreamOperator<
                            List<T>, Tuple2<T, RowData>, Tuple2<Integer, Tuple2<T, RowData>>>,
                    InputSelectable {

        private static final long serialVersionUID = 1L;

        private final SerializableSupplier<Comparator<T>> keyComparatorSupplier;

        /**
         * Distinct boundary keys in list order. Each key is paired with every boundary position
         * that holds it. Null until input 1 has delivered a list.
         */
        private transient List<Pair<T, RandomList>> keyIndex;

        private transient Collector<Tuple2<Integer, Tuple2<T, RowData>>> collector;
        private transient Comparator<T> keyComparator;

        public AssignRangeIndexOperator(SerializableSupplier<Comparator<T>> keyComparatorSupplier) {
            this.keyComparatorSupplier = keyComparatorSupplier;
        }

        @Override
        public void open() throws Exception {
            super.open();
            keyComparator = keyComparatorSupplier.get();
            collector = new StreamRecordCollector<>(output);
        }

        /**
         * Rebuilds {@link #keyIndex} from a boundary list. Adjacent equal keys collapse into one
         * entry that remembers all of their positions.
         */
        @Override
        public void processElement1(StreamRecord<List<T>> streamRecord) {
            keyIndex = new ArrayList<>();
            T previous = null;
            int position = 0;
            for (T boundary : streamRecord.getValue()) {
                boolean repeatsPrevious =
                        previous != null && keyComparator.compare(previous, boundary) == 0;
                if (!repeatsPrevious) {
                    keyIndex.add(Pair.of(boundary, new RandomList()));
                    previous = boundary;
                }
                keyIndex.get(keyIndex.size() - 1).getRight().add(position);
                position++;
            }
        }

        /**
         * Emits (rangeIndex, record). With an empty boundary list every record goes to range 0.
         */
        @Override
        public void processElement2(StreamRecord<Tuple2<T, RowData>> streamRecord) {
            if (keyIndex == null) {
                throw new RuntimeException("There should be one data from the first input.");
            }
            Tuple2<T, RowData> row = streamRecord.getValue();
            int rangeIndex = keyIndex.isEmpty() ? 0 : binarySearch(row.f0);
            collector.collect(new Tuple2<>(rangeIndex, row));
        }

        @Override
        public InputSelection nextSelection() {
            // Block the record input until the boundaries are known.
            if (keyIndex == null) {
                return InputSelection.FIRST;
            }
            return InputSelection.ALL;
        }

        /**
         * Resolves the range index of {@code key}.
         *
         * <ul>
         *   <li>An exact boundary match picks one of that key's positions at random.
         *   <li>A key between boundaries takes a position of the next larger boundary.
         *   <li>A key above every boundary takes a position of the last boundary plus one.
         * </ul>
         */
        private int binarySearch(T key) {
            int lastIndex = keyIndex.size() - 1;
            int lo = 0;
            int hi = lastIndex;
            while (lo <= hi) {
                int mid = (lo + hi) >>> 1;
                Pair<T, RandomList> candidate = keyIndex.get(mid);
                int cmp = keyComparator.compare(key, candidate.getLeft());
                if (cmp == 0) {
                    return candidate.getRight().get();
                } else if (cmp < 0) {
                    hi = mid - 1;
                } else {
                    lo = mid + 1;
                }
            }
            if (lo > lastIndex) {
                return keyIndex.get(lastIndex).getRight().get() + 1;
            }
            return keyIndex.get(lo).getRight().get();
        }

        /**
         * Maps a range index to a downstream partition by grouping consecutive ranges
         * ({@code totalRangeNum / numPartitions} per partition). The last partition absorbs any
         * remainder.
         */
        public static class RangePartitioner implements Partitioner<Integer> {
            private static final long serialVersionUID = 1L;

            private final int totalRangeNum;

            public RangePartitioner(int totalRangeNum) {
                this.totalRangeNum = totalRangeNum;
            }

            @Override
            public int partition(Integer key, int numPartitions) {
                Preconditions.checkArgument(
                        numPartitions <= totalRangeNum,
                        "Num of subPartitions should <= totalRangeNum: " + totalRangeNum);
                int target = key / (totalRangeNum / numPartitions);
                return Math.min(numPartitions - 1, target);
            }
        }

        /** Extracts the range index (field {@code f0}) as the partitioning key. */
        public static class Tuple2KeySelector<T>
                implements KeySelector<Tuple2<Integer, Tuple2<T, RowData>>, Integer> {
            private static final long serialVersionUID = 1L;

            @Override
            public Integer getKey(Tuple2<Integer, Tuple2<T, RowData>> tuple2) throws Exception {
                return tuple2.f0;
            }
        }
    }

    private static class GlobalSampleOperator<T>
    extends TableStreamOperator<List<T>>
    implements OneInputStreamOperator<Tuple3<Double, T, Integer>, List<T>>,
    BoundedOneInput {
        private static final long serialVersionUID = 1L;
        private final int numSample;
        private final int rangesNum;
        private final SerializableSupplier<Comparator<T>> comparatorSupplier;
        private transient Comparator<T> keyComparator;
        private transient Collector<List<T>> collector;
        private transient Sampler<Tuple2<T, Integer>> sampler;

        public GlobalSampleOperator(int numSample, SerializableSupplier<Comparator<T>> comparatorSupplier, int rangesNum) {
            this.numSample = numSample;
            this.comparatorSupplier = comparatorSupplier;
            this.rangesNum = rangesNum;
        }

        public void open() throws Exception {
            super.open();
            this.keyComparator = (Comparator)this.comparatorSupplier.get();
            this.sampler = new Sampler(this.numSample, 0L);
            this.collector = new StreamRecordCollector(this.output);
        }

        public void processElement(StreamRecord<Tuple3<Double, T, Integer>> record) throws Exception {
            Tuple3 tuple = (Tuple3)record.getValue();
            this.sampler.collect((Double)tuple.f0, new Tuple2(tuple.f1, tuple.f2));
        }

        public void endInput() {
            Iterator<Tuple2<Double, Tuple2<T, Integer>>> sampled = this.sampler.sample();
            ArrayList sampledData = new ArrayList();
            while (sampled.hasNext()) {
                sampledData.add(sampled.next().f1);
            }
            sampledData.sort((o1, o2) -> this.keyComparator.compare(o1.f0, o2.f0));
            List<Object> range = sampledData.isEmpty() ? new ArrayList() : Arrays.asList(RangeShuffle.allocateRangeBaseSize(sampledData, this.rangesNum));
            this.collector.collect(range);
        }
    }

    @Internal
    public static class LocalSampleOperator<T>
    extends TableStreamOperator<Tuple3<Double, T, Integer>>
    implements OneInputStreamOperator<Tuple2<T, Integer>, Tuple3<Double, T, Integer>>,
    BoundedOneInput {
        private static final long serialVersionUID = 1L;
        private final int numSample;
        private transient Collector<Tuple3<Double, T, Integer>> collector;
        private transient Sampler<Tuple2<T, Integer>> sampler;

        public LocalSampleOperator(int numSample) {
            this.numSample = numSample;
        }

        public void open() throws Exception {
            super.open();
            this.collector = new StreamRecordCollector(this.output);
            this.sampler = new Sampler(this.numSample, System.nanoTime());
        }

        public void processElement(StreamRecord<Tuple2<T, Integer>> streamRecord) throws Exception {
            this.sampler.collect((Tuple2<Object, Integer>)streamRecord.getValue());
        }

        public void endInput() {
            Iterator<Tuple2<Double, Tuple2<T, Integer>>> sampled = this.sampler.sample();
            while (sampled.hasNext()) {
                Tuple2<Double, Tuple2<T, Integer>> next = sampled.next();
                this.collector.collect((Object)new Tuple3(next.f0, ((Tuple2)next.f1).f0, ((Tuple2)next.f1).f1));
            }
        }
    }

    public static class KeyAndSizeExtractor<T>
    extends RichMapFunction<Tuple2<T, RowData>, Tuple2<T, Integer>> {
        private final RowType rowType;
        private final boolean isSortBySize;
        private transient List<BiFunction<DataGetters, Integer, Integer>> fieldSizeCalculator;

        public KeyAndSizeExtractor(RowType rowType, boolean isSortBySize) {
            this.rowType = rowType;
            this.isSortBySize = isSortBySize;
        }

        public void open(OpenContext openContext) throws Exception {
            this.open(new Configuration());
        }

        public void open(Configuration parameters) throws Exception {
            InternalRowToSizeVisitor internalRowToSizeVisitor = new InternalRowToSizeVisitor();
            this.fieldSizeCalculator = this.rowType.getFieldTypes().stream().map((? super T dataType) -> dataType.accept(internalRowToSizeVisitor)).collect(Collectors.toList());
        }

        public Tuple2<T, Integer> map(Tuple2<T, RowData> keyAndRowData) throws Exception {
            if (this.isSortBySize) {
                int size = 0;
                for (int i = 0; i < this.fieldSizeCalculator.size(); ++i) {
                    size += this.fieldSizeCalculator.get(i).apply(new FlinkRowWrapper((RowData)keyAndRowData.f1), i).intValue();
                }
                return new Tuple2(keyAndRowData.f0, (Object)size);
            }
            return new Tuple2(keyAndRowData.f0, (Object)1);
        }
    }
}

