/*
 * Decompiled with CFR 0.152.
 */
package org.apache.paimon.flink.shuffle;

import java.io.Serializable;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.Comparator;
import java.util.Iterator;
import java.util.List;
import java.util.PriorityQueue;
import java.util.Random;
import org.apache.flink.annotation.Internal;
import org.apache.flink.api.common.functions.MapFunction;
import org.apache.flink.api.common.functions.Partitioner;
import org.apache.flink.api.common.typeinfo.BasicTypeInfo;
import org.apache.flink.api.common.typeinfo.TypeInformation;
import org.apache.flink.api.dag.Transformation;
import org.apache.flink.api.java.functions.KeySelector;
import org.apache.flink.api.java.tuple.Tuple2;
import org.apache.flink.api.java.typeutils.ListTypeInfo;
import org.apache.flink.api.java.typeutils.TupleTypeInfo;
import org.apache.flink.streaming.api.datastream.DataStream;
import org.apache.flink.streaming.api.operators.BoundedOneInput;
import org.apache.flink.streaming.api.operators.InputSelectable;
import org.apache.flink.streaming.api.operators.InputSelection;
import org.apache.flink.streaming.api.operators.OneInputStreamOperator;
import org.apache.flink.streaming.api.operators.StreamMap;
import org.apache.flink.streaming.api.operators.TwoInputStreamOperator;
import org.apache.flink.streaming.api.transformations.OneInputTransformation;
import org.apache.flink.streaming.api.transformations.PartitionTransformation;
import org.apache.flink.streaming.api.transformations.StreamExchangeMode;
import org.apache.flink.streaming.api.transformations.TwoInputTransformation;
import org.apache.flink.streaming.runtime.partitioner.BroadcastPartitioner;
import org.apache.flink.streaming.runtime.partitioner.CustomPartitionerWrapper;
import org.apache.flink.streaming.runtime.partitioner.ForwardPartitioner;
import org.apache.flink.streaming.runtime.partitioner.StreamPartitioner;
import org.apache.flink.streaming.runtime.streamrecord.StreamRecord;
import org.apache.flink.table.data.RowData;
import org.apache.flink.table.runtime.operators.TableStreamOperator;
import org.apache.flink.table.runtime.util.StreamRecordCollector;
import org.apache.flink.util.Collector;
import org.apache.flink.util.Preconditions;
import org.apache.flink.util.XORShiftRandom;
import org.apache.paimon.utils.SerializableSupplier;

/**
 * Builds a range-partitioning shuffle for a stream of {@code (key, row)} tuples.
 *
 * <p>The topology produced by {@link #rangeShuffleByKey} is:
 *
 * <ol>
 *   <li>"ABSTRACT KEY": project every record onto its key;
 *   <li>"LOCAL SAMPLE": in every input subtask, keep the {@code sampleSize} keys with the largest
 *       random weights;
 *   <li>"GLOBAL SAMPLE": in a single subtask, merge the local samples by weight, sort the survivors
 *       and pick {@code rangeNum - 1} range boundaries;
 *   <li>"ASSIGN RANGE INDEX": broadcast the boundaries and tag every record with the index of the
 *       range its key falls into;
 *   <li>"REMOVE KEY": shuffle by range index and strip the index again.
 * </ol>
 *
 * <p>Every exchange is created with {@link StreamExchangeMode#BATCH}.
 * NOTE(review): presumably this avoids a deadlock, because "ASSIGN RANGE INDEX" reads its data
 * input only after the boundaries have arrived (see {@code nextSelection}) while both inputs
 * derive from the same upstream. Confirm before changing the exchange mode.
 */
public class RangeShuffle {

    /**
     * Range-partitions {@code inputDataStream} by the key in field {@code f0} of each tuple.
     *
     * @param inputDataStream records as {@code (key, row)} tuples
     * @param keyComparator supplier of the key ordering; it is invoked in each operator's {@code
     *     open()}, so only the supplier needs to be serializable
     * @param keyTypeInformation type information of the key
     * @param sampleSize number of keys retained by each local sampler and by the global sampler
     * @param rangeNum number of key ranges; {@code rangeNum - 1} boundaries are computed
     * @param outParallelism parallelism of the final "REMOVE KEY" operator, i.e. the number of
     *     partitions the ranges are spread across
     * @return the input records, range-partitioned by key
     * @throws IllegalArgumentException if {@code rangeNum} is not positive
     */
    public static <T> DataStream<Tuple2<T, RowData>> rangeShuffleByKey(
            DataStream<Tuple2<T, RowData>> inputDataStream,
            SerializableSupplier<Comparator<T>> keyComparator,
            TypeInformation<T> keyTypeInformation,
            int sampleSize,
            int rangeNum,
            int outParallelism) {
        // Fail while building the graph instead of at runtime
        // (a non-positive rangeNum used to break GlobalSampleOperator / RangePartitioner).
        Preconditions.checkArgument(
                rangeNum > 0, "rangeNum should be positive, but is: " + rangeNum);

        Transformation<Tuple2<T, RowData>> input = inputDataStream.getTransformation();

        // 0. Project every record onto its key.
        OneInputTransformation<Tuple2<T, RowData>, T> keyInput =
                new OneInputTransformation<Tuple2<T, RowData>, T>(
                        input,
                        "ABSTRACT KEY",
                        new StreamMap<Tuple2<T, RowData>, T>(record -> record.f0),
                        keyTypeInformation,
                        input.getParallelism());

        // 1. Fixed-size random sample of keys in every input subtask.
        OneInputTransformation<T, Tuple2<Double, T>> localSample =
                new OneInputTransformation<T, Tuple2<Double, T>>(
                        keyInput,
                        "LOCAL SAMPLE",
                        new LocalSampleOperator<T>(sampleSize),
                        new TupleTypeInfo<Tuple2<Double, T>>(
                                BasicTypeInfo.DOUBLE_TYPE_INFO, keyTypeInformation),
                        keyInput.getParallelism());

        // 2. Merge all local samples in one subtask and derive the sorted range boundaries.
        OneInputTransformation<Tuple2<Double, T>, List<T>> sampleAndHistogram =
                new OneInputTransformation<Tuple2<Double, T>, List<T>>(
                        localSample,
                        "GLOBAL SAMPLE",
                        new GlobalSampleOperator<T>(sampleSize, keyComparator, rangeNum),
                        new ListTypeInfo<T>(keyTypeInformation),
                        1);

        // 3. Broadcast the boundaries (input 1) next to the original records (input 2) and tag
        // every record with its range index.
        PartitionTransformation<List<T>> broadcastBoundaries =
                new PartitionTransformation<List<T>>(
                        sampleAndHistogram,
                        new BroadcastPartitioner<List<T>>(),
                        StreamExchangeMode.BATCH);
        PartitionTransformation<Tuple2<T, RowData>> forwardedInput =
                new PartitionTransformation<Tuple2<T, RowData>>(
                        input, new ForwardPartitioner<Tuple2<T, RowData>>(), StreamExchangeMode.BATCH);
        TwoInputTransformation<List<T>, Tuple2<T, RowData>, Tuple2<Integer, Tuple2<T, RowData>>>
                preparePartition =
                        new TwoInputTransformation<
                                List<T>, Tuple2<T, RowData>, Tuple2<Integer, Tuple2<T, RowData>>>(
                                broadcastBoundaries,
                                forwardedInput,
                                "ASSIGN RANGE INDEX",
                                new AssignRangeIndexOperator<T>(keyComparator),
                                new TupleTypeInfo<Tuple2<Integer, Tuple2<T, RowData>>>(
                                        BasicTypeInfo.INT_TYPE_INFO, input.getOutputType()),
                                input.getParallelism());

        // 4. Shuffle by range index, then drop the index again.
        PartitionTransformation<Tuple2<Integer, Tuple2<T, RowData>>> rangePartitioned =
                new PartitionTransformation<Tuple2<Integer, Tuple2<T, RowData>>>(
                        preparePartition,
                        new CustomPartitionerWrapper<Integer, Tuple2<Integer, Tuple2<T, RowData>>>(
                                new AssignRangeIndexOperator.RangePartitioner(rangeNum),
                                new AssignRangeIndexOperator.Tuple2KeySelector<T>()),
                        StreamExchangeMode.BATCH);
        OneInputTransformation<Tuple2<Integer, Tuple2<T, RowData>>, Tuple2<T, RowData>> removeKey =
                new OneInputTransformation<Tuple2<Integer, Tuple2<T, RowData>>, Tuple2<T, RowData>>(
                        rangePartitioned,
                        "REMOVE KEY",
                        new RemoveRangeIndexOperator<T>(),
                        input.getOutputType(),
                        outParallelism);
        return new DataStream<Tuple2<T, RowData>>(
                inputDataStream.getExecutionEnvironment(), removeKey);
    }

    /**
     * Retains the {@code numSamples} elements with the largest weights offered so far.
     *
     * <p>{@link #collect(Object)} draws the weight uniformly at random, which yields a uniform
     * random sample. {@link #collect(double, Object)} re-uses weights drawn elsewhere, so that
     * merging several such samples by weight stays uniform.
     */
    private static class Sampler<T> {
        private final int numSamples;
        private final Random random;
        /** Min-heap on weight: the head is the retained element that is evicted first. */
        private final PriorityQueue<Tuple2<Double, T>> queue;

        Sampler(int numSamples, long seed) {
            Preconditions.checkArgument(numSamples >= 0, "numSamples should be non-negative.");
            this.numSamples = numSamples;
            this.random = new XORShiftRandom(seed);
            // PriorityQueue rejects an initial capacity below 1, yet numSamples == 0 is allowed.
            this.queue =
                    new PriorityQueue<Tuple2<Double, T>>(
                            Math.max(1, numSamples), (a, b) -> Double.compare(a.f0, b.f0));
        }

        /** Offers {@code element} with a fresh uniformly random weight in [0, 1). */
        void collect(T element) {
            collect(random.nextDouble(), element);
        }

        /** Offers {@code key} with the given {@code weight}. */
        void collect(double weight, T key) {
            // Compare against the queue size rather than a running int counter, which would
            // overflow after 2^31 elements and let the queue grow without bound.
            if (queue.size() < numSamples) {
                queue.add(new Tuple2<>(weight, key));
            } else if (numSamples > 0 && weight > queue.peek().f0) {
                queue.remove();
                queue.add(new Tuple2<>(weight, key));
            }
        }

        /** Returns the retained {@code (weight, element)} pairs in no particular order. */
        Iterator<Tuple2<Double, T>> sample() {
            return queue.iterator();
        }
    }

    /** Drops the range index, emitting the original {@code (key, row)} tuple. */
    private static class RemoveRangeIndexOperator<T>
            extends TableStreamOperator<Tuple2<T, RowData>>
            implements OneInputStreamOperator<
                    Tuple2<Integer, Tuple2<T, RowData>>, Tuple2<T, RowData>> {

        private static final long serialVersionUID = 1L;

        private transient Collector<Tuple2<T, RowData>> collector;

        private RemoveRangeIndexOperator() {}

        @Override
        public void open() throws Exception {
            super.open();
            this.collector = new StreamRecordCollector<>(output);
        }

        @Override
        public void processElement(
                StreamRecord<Tuple2<Integer, Tuple2<T, RowData>>> streamRecord) throws Exception {
            collector.collect(streamRecord.getValue().f1);
        }
    }

    /**
     * Tags every record with the index of the key range it falls into.
     *
     * <p>Input 1 carries the sorted range boundaries. {@link #nextSelection()} reads only input 1
     * until boundaries have been received. Input 2 carries the {@code (key, row)} records.
     */
    private static class AssignRangeIndexOperator<T>
            extends TableStreamOperator<Tuple2<Integer, Tuple2<T, RowData>>>
            implements TwoInputStreamOperator<
                            List<T>, Tuple2<T, RowData>, Tuple2<Integer, Tuple2<T, RowData>>>,
                    InputSelectable {

        private static final long serialVersionUID = 1L;

        private final SerializableSupplier<Comparator<T>> keyComparatorSupplier;

        private transient List<T> boundaries;
        private transient Collector<Tuple2<Integer, Tuple2<T, RowData>>> collector;
        private transient Comparator<T> keyComparator;

        public AssignRangeIndexOperator(SerializableSupplier<Comparator<T>> keyComparatorSupplier) {
            this.keyComparatorSupplier = keyComparatorSupplier;
        }

        @Override
        public void open() throws Exception {
            super.open();
            this.keyComparator = keyComparatorSupplier.get();
            this.collector = new StreamRecordCollector<>(output);
        }

        @Override
        public void processElement1(StreamRecord<List<T>> streamRecord) {
            this.boundaries = streamRecord.getValue();
        }

        @Override
        public void processElement2(StreamRecord<Tuple2<T, RowData>> streamRecord) {
            if (boundaries == null) {
                throw new IllegalStateException("There should be one data from the first input.");
            }
            Tuple2<T, RowData> row = streamRecord.getValue();
            collector.collect(new Tuple2<>(binarySearch(row.f0), row));
        }

        @Override
        public InputSelection nextSelection() {
            return boundaries == null ? InputSelection.FIRST : InputSelection.ALL;
        }

        /**
         * Returns the range index of {@code key}, in {@code [0, boundaries.size()]}.
         *
         * <p>If a boundary compares equal to the key, that boundary's index is returned.
         * Otherwise the result is the insertion point, i.e. the number of boundaries smaller
         * than the key.
         */
        private int binarySearch(T key) {
            int low = 0;
            int high = boundaries.size() - 1;
            while (low <= high) {
                int mid = (low + high) >>> 1;
                int result = keyComparator.compare(key, boundaries.get(mid));
                if (result > 0) {
                    low = mid + 1;
                } else if (result < 0) {
                    high = mid - 1;
                } else {
                    return mid;
                }
            }
            return low;
        }

        /** Spreads range indices {@code [0, totalRangeNum)} evenly over the output partitions. */
        public static class RangePartitioner implements Partitioner<Integer> {

            private static final long serialVersionUID = 1L;

            private final int totalRangeNum;

            public RangePartitioner(int totalRangeNum) {
                this.totalRangeNum = totalRangeNum;
            }

            @Override
            public int partition(Integer key, int numPartitions) {
                Preconditions.checkArgument(
                        numPartitions <= totalRangeNum,
                        "Num of subPartitions should <= totalRangeNum: " + totalRangeNum);
                // Scale the range index proportionally onto [0, numPartitions). When
                // totalRangeNum is a multiple of numPartitions this equals
                // key / (totalRangeNum / numPartitions). Unlike that formula, it does not pile
                // every leftover range onto the last partition otherwise. The long product
                // avoids int overflow.
                int partition = (int) ((long) key * numPartitions / totalRangeNum);
                return Math.min(numPartitions - 1, partition);
            }
        }

        /** Uses the range index ({@code f0}) as the partitioning key. */
        public static class Tuple2KeySelector<T>
                implements KeySelector<Tuple2<Integer, Tuple2<T, RowData>>, Integer> {

            private static final long serialVersionUID = 1L;

            @Override
            public Integer getKey(Tuple2<Integer, Tuple2<T, RowData>> tuple2) throws Exception {
                return tuple2.f0;
            }
        }
    }

    /**
     * Merges the weighted local samples and, at end of input, emits the sorted range boundaries.
     *
     * <p>At most {@code rangesNum - 1} boundaries are emitted. An empty list is emitted when no
     * sample was received.
     */
    private static class GlobalSampleOperator<T> extends TableStreamOperator<List<T>>
            implements OneInputStreamOperator<Tuple2<Double, T>, List<T>>, BoundedOneInput {

        private static final long serialVersionUID = 1L;

        private final int numSample;
        private final int rangesNum;
        private final SerializableSupplier<Comparator<T>> comparatorSupplier;

        private transient Comparator<T> keyComparator;
        private transient Collector<List<T>> collector;
        private transient Sampler<T> sampler;

        public GlobalSampleOperator(
                int numSample, SerializableSupplier<Comparator<T>> comparatorSupplier, int rangesNum) {
            this.numSample = numSample;
            this.comparatorSupplier = comparatorSupplier;
            this.rangesNum = rangesNum;
        }

        @Override
        public void open() throws Exception {
            super.open();
            this.keyComparator = comparatorSupplier.get();
            this.sampler = new Sampler<>(numSample, 0L);
            this.collector = new StreamRecordCollector<>(output);
        }

        @Override
        public void processElement(StreamRecord<Tuple2<Double, T>> record) throws Exception {
            Tuple2<Double, T> weighted = record.getValue();
            // Keep the weights drawn by the local samplers so the merged sample stays uniform.
            sampler.collect(weighted.f0, weighted.f1);
        }

        @Override
        public void endInput() throws Exception {
            List<T> sampledData = new ArrayList<>();
            Iterator<Tuple2<Double, T>> sampled = sampler.sample();
            while (sampled.hasNext()) {
                sampledData.add(sampled.next().f1);
            }
            sampledData.sort(keyComparator);

            List<T> boundaries = new ArrayList<>(Math.max(0, rangesNum - 1));
            if (!sampledData.isEmpty()) {
                double avgRange = (double) sampledData.size() / rangesNum;
                for (int i = 1; i < rangesNum; i++) {
                    boundaries.add(sampledData.get((int) (i * avgRange)));
                }
            }
            // With no samples, emit no boundaries (every key then maps to range 0) instead of a
            // list of nulls that serializers and the key comparator would have to cope with.
            collector.collect(boundaries);
        }
    }

    /**
     * Samples up to {@code numSample} elements of its input and, at end of input, emits them
     * together with their random weights.
     */
    @Internal
    public static class LocalSampleOperator<T> extends TableStreamOperator<Tuple2<Double, T>>
            implements OneInputStreamOperator<T, Tuple2<Double, T>>, BoundedOneInput {

        private static final long serialVersionUID = 1L;

        private final int numSample;

        private transient Collector<Tuple2<Double, T>> collector;
        private transient Sampler<T> sampler;

        public LocalSampleOperator(int numSample) {
            this.numSample = numSample;
        }

        @Override
        public void open() throws Exception {
            super.open();
            this.collector = new StreamRecordCollector<>(output);
            this.sampler = new Sampler<>(numSample, System.nanoTime());
        }

        @Override
        public void processElement(StreamRecord<T> streamRecord) throws Exception {
            sampler.collect(streamRecord.getValue());
        }

        @Override
        public void endInput() throws Exception {
            Iterator<Tuple2<Double, T>> sampled = sampler.sample();
            while (sampled.hasNext()) {
                collector.collect(sampled.next());
            }
        }
    }
}

