/*
 * Decompiled with CFR 0.152.
 */
package org.apache.paimon.flink.sink.index;

import java.io.File;
import java.io.IOException;
import java.io.Serializable;
import java.util.HashMap;
import java.util.Map;
import java.util.TreeMap;
import java.util.UUID;
import java.util.function.BiConsumer;
import org.apache.paimon.CoreOptions;
import org.apache.paimon.data.BinaryRow;
import org.apache.paimon.data.InternalRow;
import org.apache.paimon.data.serializer.RowCompactedSerializer;
import org.apache.paimon.flink.lookup.RocksDBStateFactory;
import org.apache.paimon.flink.lookup.RocksDBValueState;
import org.apache.paimon.options.Options;
import org.apache.paimon.schema.TableSchema;
import org.apache.paimon.table.AbstractFileStoreTable;
import org.apache.paimon.table.Table;
import org.apache.paimon.table.sink.PartitionKeyExtractor;
import org.apache.paimon.types.RowKind;
import org.apache.paimon.types.RowType;
import org.apache.paimon.utils.FileIOUtils;
import org.apache.paimon.utils.Filter;
import org.apache.paimon.utils.IDMapping;
import org.apache.paimon.utils.PositiveIntInt;
import org.apache.paimon.utils.PositiveIntIntSerializer;
import org.apache.paimon.utils.SerBiFunction;
import org.apache.paimon.utils.SerializableFunction;

/**
 * Assigns a bucket to every record of a dynamic-bucket table using a global index that maps the
 * trimmed primary key of a record to (partition id, bucket). The index lives in a local RocksDB
 * instance created in {@link #open}.
 *
 * <p>Several assigners may run side by side. This instance only hands out buckets {@code b} with
 * {@code |b % numAssigners| == assignId}. Within a partition it fills its buckets up to the
 * configured target row number before opening a new one. Every assigned record is passed to the
 * collector together with its bucket.
 *
 * <p>A record whose key is already indexed for the same partition goes to the previously assigned
 * bucket. If the key was indexed under a different partition, the merge engine decides:
 *
 * <ul>
 *   <li>{@code DEDUPLICATE}: a {@link RowKind#DELETE} copy is emitted to the old partition and
 *       bucket, the old bucket's count is decremented, and the record is assigned as new in its
 *       own partition;
 *   <li>{@code PARTIAL_UPDATE} / {@code AGGREGATE}: the record is rewritten to the old partition
 *       and routed to the old bucket;
 *   <li>{@code FIRST_ROW}: the new record is dropped.
 * </ul>
 *
 * <p>The only non-transient fields are the table and the supplied functions. All runtime state is
 * transient and initialised in {@link #open}.
 *
 * @param <T> record type
 */
public class GlobalIndexAssigner<T> implements Serializable {
    private static final long serialVersionUID = 1L;

    private final AbstractFileStoreTable table;
    /** Builds the extractor used for incoming records in {@link #process}. */
    private final SerializableFunction<TableSchema, PartitionKeyExtractor<T>> extractorFunction;
    /** Builds the extractor used for existing records in {@link #bootstrap}. */
    private final SerializableFunction<TableSchema, PartitionKeyExtractor<T>> keyPartExtractorFunction;
    /** Returns the record with its partition fields replaced by the given partition. */
    private final SerBiFunction<T, BinaryRow, T> setPartition;
    /** Returns the record with the given row kind. */
    private final SerBiFunction<T, RowKind, T> setRowKind;

    // ---- runtime state, initialised in open() ----
    private transient int targetBucketRowNumber;
    private transient int assignId;
    private transient BiConsumer<T, Integer> collector;
    private transient int numAssigners;
    private transient PartitionKeyExtractor<T> extractor;
    private transient PartitionKeyExtractor<T> keyPartExtractor;
    /** Private RocksDB directory for this instance; deleted in close(). */
    private transient File path;
    private transient RocksDBStateFactory stateFactory;
    /** Trimmed primary key -> (partition id, bucket). */
    private transient RocksDBValueState<InternalRow, PositiveIntInt> keyIndex;
    /** Partition row <-> compact int id used inside keyIndex values. */
    private transient IDMapping<BinaryRow> partMapping;
    private transient BucketAssigner bucketAssigner;
    private transient ExistsAction existsAction;

    public GlobalIndexAssigner(
            Table table,
            SerializableFunction<TableSchema, PartitionKeyExtractor<T>> extractorFunction,
            SerializableFunction<TableSchema, PartitionKeyExtractor<T>> keyPartExtractorFunction,
            SerBiFunction<T, BinaryRow, T> setPartition,
            SerBiFunction<T, RowKind, T> setRowKind) {
        this.table = (AbstractFileStoreTable) table;
        this.extractorFunction = extractorFunction;
        this.keyPartExtractorFunction = keyPartExtractorFunction;
        this.setPartition = setPartition;
        this.setRowKind = setRowKind;
    }

    /**
     * Initializes the runtime state. Must be called before {@link #bootstrap} or {@link #process}.
     *
     * @param tmpDir directory under which a unique {@code lookup-<uuid>} RocksDB directory is
     *     created
     * @param numAssigners total number of parallel assigners; must be positive
     * @param assignId index of this assigner, in {@code [0, numAssigners)}
     * @param collector receives each emitted record together with its bucket
     * @throws IllegalArgumentException if {@code numAssigners} or {@code assignId} is out of range
     */
    public void open(File tmpDir, int numAssigners, int assignId, BiConsumer<T, Integer> collector)
            throws Exception {
        // Fail fast. With an out-of-range assignId no bucket ever passes the ownership filter,
        // and BucketAssigner.assignBucket would loop forever. numAssigners == 0 would divide by
        // zero in computeAssignId.
        if (numAssigners <= 0) {
            throw new IllegalArgumentException(
                    "numAssigners must be positive, but is " + numAssigners);
        }
        if (assignId < 0 || assignId >= numAssigners) {
            throw new IllegalArgumentException(
                    "assignId must be in [0, " + numAssigners + "), but is " + assignId);
        }
        this.numAssigners = numAssigners;
        this.assignId = assignId;
        this.collector = collector;

        CoreOptions coreOptions = this.table.coreOptions();
        // Bucket counts are tracked as int. Clamp instead of letting a large value wrap to a
        // negative limit, which would force a brand-new bucket for every record.
        this.targetBucketRowNumber =
                (int) Math.min(Integer.MAX_VALUE, coreOptions.dynamicBucketTargetRowNum());
        this.extractor = this.extractorFunction.apply(this.table.schema());
        this.keyPartExtractor = this.keyPartExtractorFunction.apply(this.table.schema());

        Options options = coreOptions.toConfiguration();
        this.path = new File(tmpDir, "lookup-" + UUID.randomUUID());
        this.stateFactory = new RocksDBStateFactory(this.path.toString(), options);
        long cacheSize = options.get(CoreOptions.LOOKUP_CACHE_MAX_MEMORY_SIZE).getBytes();
        RowType keyType = this.table.schema().logicalTrimmedPrimaryKeysType();
        this.keyIndex =
                this.stateFactory.valueState(
                        "keyIndex",
                        new RowCompactedSerializer(keyType),
                        new PositiveIntIntSerializer(),
                        cacheSize);
        this.partMapping = new IDMapping<BinaryRow>(BinaryRow::copy);
        this.bucketAssigner = new BucketAssigner();
        this.existsAction = this.fromMergeEngine(coreOptions.mergeEngine());
    }

    /**
     * Assigns a bucket to an incoming record and emits it, plus a retraction for the old partition
     * when the merge engine requires one (see the class documentation).
     */
    public void process(T value) throws Exception {
        BinaryRow partition = this.extractor.partition(value);
        BinaryRow key = this.extractor.trimmedPrimaryKey(value);
        int partId = this.partMapping.index(partition);

        PositiveIntInt partitionBucket = this.keyIndex.get(key);
        if (partitionBucket == null) {
            // Key never seen: assign a fresh bucket in the record's partition.
            this.processNewRecord(partition, partId, key, value);
            return;
        }

        int previousPartId = partitionBucket.i1();
        int previousBucket = partitionBucket.i2();
        if (previousPartId == partId) {
            this.collect(value, previousBucket);
            return;
        }

        // The key moved to another partition.
        switch (this.existsAction) {
            case DELETE: {
                // Retract the row from the old partition/bucket, then insert it as new.
                BinaryRow previousPart = this.partMapping.get(previousPartId);
                T retract = this.setPartition.apply(value, previousPart);
                retract = this.setRowKind.apply(retract, RowKind.DELETE);
                this.collect(retract, previousBucket);
                this.bucketAssigner.decrement(previousPart, previousBucket);
                this.processNewRecord(partition, partId, key, value);
                break;
            }
            case USE_OLD: {
                // Keep the key where it already lives.
                BinaryRow previousPart = this.partMapping.get(previousPartId);
                T newValue = this.setPartition.apply(value, previousPart);
                this.collect(newValue, previousBucket);
                break;
            }
            case SKIP_NEW:
                // The first row wins; drop the new one.
                break;
        }
    }

    /**
     * Records an existing row's key in the index and assigns it a bucket in its partition, without
     * emitting anything to the collector. Uses the key/partition extractor built by
     * {@code keyPartExtractorFunction}.
     */
    public void bootstrap(T value) throws IOException {
        BinaryRow partition = this.keyPartExtractor.partition(value);
        this.keyIndex.put(
                this.keyPartExtractor.trimmedPrimaryKey(value),
                new PositiveIntInt(this.partMapping.index(partition), this.assignBucket(partition)));
    }

    /** Assigns a bucket in {@code partition}, indexes the key, and emits the record. */
    private void processNewRecord(BinaryRow partition, int partId, BinaryRow key, T value)
            throws IOException {
        int bucket = this.assignBucket(partition);
        this.keyIndex.put(key, new PositiveIntInt(partId, bucket));
        this.collect(value, bucket);
    }

    private int assignBucket(BinaryRow partition) {
        return this.bucketAssigner.assignBucket(
                partition, this::isAssignBucket, this.targetBucketRowNumber);
    }

    /** Whether {@code bucket} is owned by this assigner. */
    private boolean isAssignBucket(int bucket) {
        return this.computeAssignId(bucket) == this.assignId;
    }

    private int computeAssignId(int hash) {
        return Math.abs(hash % this.numAssigners);
    }

    private void collect(T value, int bucket) {
        this.collector.accept(value, bucket);
    }

    /**
     * Closes the RocksDB state and deletes its temporary directory. The directory is removed even
     * if closing the state fails. Calling this more than once is harmless.
     */
    public void close() throws IOException {
        RocksDBStateFactory factory = this.stateFactory;
        this.stateFactory = null;
        try {
            if (factory != null) {
                factory.close();
            }
        } finally {
            if (this.path != null) {
                FileIOUtils.deleteDirectoryQuietly(this.path);
                this.path = null;
            }
        }
    }

    private ExistsAction fromMergeEngine(CoreOptions.MergeEngine mergeEngine) {
        switch (mergeEngine) {
            case DEDUPLICATE:
                return ExistsAction.DELETE;
            case PARTIAL_UPDATE:
            case AGGREGATE:
                return ExistsAction.USE_OLD;
            case FIRST_ROW:
                return ExistsAction.SKIP_NEW;
            default:
                throw new UnsupportedOperationException("Unsupported engine: " + mergeEngine);
        }
    }

    /** What to do when a known key shows up in a different partition. */
    private enum ExistsAction {
        /** Retract from the old partition, insert into the new one. */
        DELETE,
        /** Route the record to the old partition and bucket. */
        USE_OLD,
        /** Drop the new record. */
        SKIP_NEW
    }

    /**
     * Tracks, per partition, how many keys each bucket holds, and picks buckets for new keys.
     */
    private static class BucketAssigner {
        /** Partition -> (bucket -> key count), buckets kept in ascending order. */
        private final Map<BinaryRow, TreeMap<Integer, Integer>> stats =
                new HashMap<BinaryRow, TreeMap<Integer, Integer>>();

        private BucketAssigner() {}

        /**
         * Returns the lowest-numbered bucket of {@code part} that passes {@code filter} and holds
         * fewer than {@code maxCount} keys, and increments its count. If there is none, returns the
         * smallest non-negative bucket id that passes {@code filter} and is not yet used in
         * {@code part}, starting it at a count of 1. Loops forever if no id passes {@code filter}.
         */
        public int assignBucket(BinaryRow part, Filter<Integer> filter, int maxCount) {
            TreeMap<Integer, Integer> bucketMap = this.bucketMap(part);
            for (Map.Entry<Integer, Integer> entry : bucketMap.entrySet()) {
                int bucket = entry.getKey();
                int count = entry.getValue();
                if (filter.test(bucket) && count < maxCount) {
                    bucketMap.put(bucket, count + 1);
                    return bucket;
                }
            }
            for (int i = 0; ; i++) {
                if (filter.test(i) && !bucketMap.containsKey(i)) {
                    bucketMap.put(i, 1);
                    return i;
                }
            }
        }

        /** Decrements the key count of {@code bucket} in {@code part} (absent counts become 0). */
        public void decrement(BinaryRow part, int bucket) {
            this.bucketMap(part).compute(bucket, (k, v) -> v == null ? 0 : v - 1);
        }

        private TreeMap<Integer, Integer> bucketMap(BinaryRow part) {
            TreeMap<Integer, Integer> map = this.stats.get(part);
            if (map == null) {
                map = new TreeMap<>();
                // Store a copy: the caller's row object may be reused for later records.
                this.stats.put(part.copy(), map);
            }
            return map;
        }
    }
}

