/*
 * Decompiled with CFR 0.152.
 */
package io.trino.spi.exchange;

import com.fasterxml.jackson.annotation.JsonCreator;
import com.fasterxml.jackson.annotation.JsonProperty;
import io.airlift.slice.SizeOf;
import io.airlift.slice.Slice;
import io.airlift.slice.Slices;
import io.trino.spi.exchange.ExchangeId;
import java.util.Arrays;
import java.util.HashMap;
import java.util.HashSet;
import java.util.Map;
import java.util.Objects;
import java.util.Set;
import java.util.function.Function;
import java.util.stream.Collectors;
import org.openjdk.jol.info.ClassLayout;

/**
 * Versioned description of which task attempt's output should be consumed for each task
 * partition of each source exchange.
 *
 * <p>For every {@link ExchangeId} the selector stores a {@link Slice} with one byte per task
 * partition. The byte has one of three meanings:
 * <ul>
 *   <li>a non-negative byte is the id of the selected attempt;</li>
 *   <li>{@link Selection#EXCLUDED}'s value marks a partition whose output is dropped;</li>
 *   <li>{@link Selection#UNKNOWN}'s value marks a partition with no decision yet.</li>
 * </ul>
 * Partitions beyond the end of a slice, and exchanges absent from the map, are also treated as
 * undecided.
 *
 * <p>A <em>final</em> selector must have a decision for every partition it is queried about;
 * {@link #getSelection} throws otherwise. The values map is defensively copied on construction,
 * but the slices themselves are shared, not copied.
 */
public class ExchangeSourceOutputSelector {
    private static final long INSTANCE_SIZE = ClassLayout.parseClass(ExchangeSourceOutputSelector.class).instanceSize();
    // attempt ids are stored in a single byte and must stay non-negative to be told apart from
    // the negative Selection marker values
    private static final int MAX_ATTEMPT_ID = Byte.MAX_VALUE;

    private final int version;
    private final Map<ExchangeId, Slice> values;
    private final boolean finalSelector;

    /**
     * @param version version of this selector; {@link #checkValidTransition} requires it to strictly increase
     * @param values per-exchange decision bytes, indexed by task partition id; copied into an unmodifiable map
     * @param finalSelector whether every queried partition must have a decision
     * @throws NullPointerException if {@code values} is null or contains null keys or values
     */
    @JsonCreator
    public ExchangeSourceOutputSelector(
            @JsonProperty("version") int version,
            @JsonProperty("values") Map<ExchangeId, Slice> values,
            @JsonProperty("finalSelector") boolean finalSelector) {
        this.version = version;
        this.values = Map.copyOf(Objects.requireNonNull(values, "values is null"));
        this.finalSelector = finalSelector;
    }

    @JsonProperty
    public int getVersion() {
        return this.version;
    }

    /**
     * Returns the unmodifiable map of per-exchange decision slices. The slices are shared with
     * this selector, not copied.
     */
    @JsonProperty
    public Map<ExchangeId, Slice> getValues() {
        return this.values;
    }

    @JsonProperty("finalSelector")
    public boolean isFinal() {
        return this.finalSelector;
    }

    /**
     * Decides whether the output of the given attempt of the given task partition should be consumed.
     *
     * @return {@link Selection#INCLUDED} if {@code attemptId} is the selected attempt,
     *         {@link Selection#EXCLUDED} if the partition is excluded or a different attempt was selected,
     *         {@link Selection#UNKNOWN} if no decision exists yet and this selector is not final
     * @throws IllegalArgumentException if {@code taskPartitionId} is negative, {@code attemptId} is outside
     *         {@code [0, 127]}, the stored value is not a recognized encoding, or this selector is final and
     *         has no decision for the partition
     */
    public Selection getSelection(ExchangeId exchangeId, int taskPartitionId, int attemptId) {
        Objects.requireNonNull(exchangeId, "exchangeId is null");
        checkTaskPartitionId(taskPartitionId);
        checkAttemptId(attemptId);
        // getValue folds "exchange missing", "partition beyond slice" and "stored UNKNOWN" into UNKNOWN
        byte selectedAttempt = this.getValue(exchangeId, taskPartitionId);
        if (selectedAttempt == Selection.UNKNOWN.getValue()) {
            this.throwIfFinal(exchangeId, taskPartitionId);
            return Selection.UNKNOWN;
        }
        if (selectedAttempt == Selection.EXCLUDED.getValue()) {
            return Selection.EXCLUDED;
        }
        if (selectedAttempt < 0) {
            // INCLUDED's marker (or any other negative byte) is never a valid stored value
            throw new IllegalArgumentException("unexpected selectedAttempt: " + selectedAttempt);
        }
        return selectedAttempt == attemptId ? Selection.INCLUDED : Selection.EXCLUDED;
    }

    /**
     * Estimates the memory retained by this selector, including the map, its keys and its slices.
     */
    public long getRetainedSizeInBytes() {
        return INSTANCE_SIZE + SizeOf.estimatedSizeOf(this.values, ExchangeId::getRetainedSizeInBytes, Slice::getRetainedSize);
    }

    /**
     * Verifies that moving from this selector to {@code newSelector} is allowed.
     *
     * <p>The new version must be strictly greater, this selector must not be final, and every
     * partition that already has a decision here must keep exactly the same decision. Undecided
     * partitions may change to anything.
     *
     * @throws NullPointerException if {@code newSelector} is null
     * @throws IllegalArgumentException if the transition is not allowed
     */
    public void checkValidTransition(ExchangeSourceOutputSelector newSelector) {
        Objects.requireNonNull(newSelector, "newSelector is null");
        if (this.version >= newSelector.version) {
            throw new IllegalArgumentException("Invalid transition to the same or an older version");
        }
        if (this.isFinal()) {
            throw new IllegalArgumentException("Invalid transition from final selector");
        }
        // Only decisions present in this selector can be violated: every partition outside its
        // stored slices (including whole exchanges it does not know) is UNKNOWN and may take any value.
        for (Map.Entry<ExchangeId, Slice> entry : this.values.entrySet()) {
            ExchangeId exchangeId = entry.getKey();
            Slice currentValues = entry.getValue();
            for (int taskPartitionId = 0; taskPartitionId < currentValues.length(); ++taskPartitionId) {
                byte currentValue = currentValues.getByte(taskPartitionId);
                if (currentValue == Selection.UNKNOWN.getValue()) {
                    continue;
                }
                byte newValue = newSelector.getValue(exchangeId, taskPartitionId);
                if (currentValue != newValue) {
                    throw new IllegalArgumentException("Invalid transition for exchange %s, taskPartitionId %s: %s -> %s".formatted(exchangeId, taskPartitionId, currentValue, newValue));
                }
            }
        }
    }

    /**
     * Returns the raw stored byte for the partition, or {@link Selection#UNKNOWN}'s value when the
     * exchange is absent or the partition lies beyond the stored slice.
     */
    private byte getValue(ExchangeId exchangeId, int taskPartitionId) {
        Slice exchangeValues = this.values.get(exchangeId);
        if (exchangeValues == null || exchangeValues.length() <= taskPartitionId) {
            return Selection.UNKNOWN.getValue();
        }
        return exchangeValues.getByte(taskPartitionId);
    }

    /** Fails when a final selector is asked about a partition it has no decision for. */
    private void throwIfFinal(ExchangeId exchangeId, int taskPartitionId) {
        if (this.isFinal()) {
            throw new IllegalArgumentException("selection not found for exchangeId %s, taskPartitionId %s".formatted(exchangeId, taskPartitionId));
        }
    }

    /** Rejects negative task partition ids. */
    private static void checkTaskPartitionId(int taskPartitionId) {
        if (taskPartitionId < 0) {
            throw new IllegalArgumentException("unexpected taskPartitionId: " + taskPartitionId);
        }
    }

    /** Rejects attempt ids that cannot be stored as a non-negative byte. */
    private static void checkAttemptId(int attemptId) {
        if (attemptId < 0 || attemptId > MAX_ATTEMPT_ID) {
            throw new IllegalArgumentException("unexpected attemptId: " + attemptId);
        }
    }

    /**
     * Creates a builder that accepts decisions only for the given source exchanges.
     */
    public static Builder builder(Set<ExchangeId> sourceExchanges) {
        return new Builder(sourceExchanges);
    }

    /**
     * Result of {@link #getSelection}. The negative byte values double as the stored markers for
     * "excluded" and "unknown"; non-negative bytes are attempt ids.
     */
    public static enum Selection {
        // explicit casts are required: int constants are not implicitly narrowed in constructor arguments
        INCLUDED((byte) -1),
        EXCLUDED((byte) -2),
        UNKNOWN((byte) -3);

        private final byte value;

        private Selection(byte value) {
            this.value = value;
        }

        public byte getValue() {
            return this.value;
        }
    }

    /**
     * Mutable accumulator of per-partition decisions. Each {@link #build()} call produces a
     * selector with the next version number, starting at 0.
     */
    public static class Builder {
        private int nextVersion;
        private final Map<ExchangeId, ValuesBuilder> exchangeValues;
        private boolean finalSelector;
        private final Map<ExchangeId, Integer> exchangeTaskPartitionCount = new HashMap<ExchangeId, Integer>();

        public Builder(Set<ExchangeId> sourceExchanges) {
            Objects.requireNonNull(sourceExchanges, "sourceExchanges is null");
            this.exchangeValues = sourceExchanges.stream().collect(Collectors.toUnmodifiableMap(Function.identity(), exchangeId -> new ValuesBuilder()));
        }

        /**
         * Selects {@code attemptId} for the partition.
         *
         * @throws IllegalArgumentException for an unknown exchange, invalid ids, or an already decided partition
         */
        public Builder include(ExchangeId exchangeId, int taskPartitionId, int attemptId) {
            this.getValuesBuilderForExchange(exchangeId).include(taskPartitionId, attemptId);
            return this;
        }

        /**
         * Excludes the partition's output entirely.
         *
         * @throws IllegalArgumentException for an unknown exchange, an invalid id, or an already decided partition
         */
        public Builder exclude(ExchangeId exchangeId, int taskPartitionId) {
            this.getValuesBuilderForExchange(exchangeId).exclude(taskPartitionId);
            return this;
        }

        private ValuesBuilder getValuesBuilderForExchange(ExchangeId exchangeId) {
            ValuesBuilder result = this.exchangeValues.get(exchangeId);
            if (result == null) {
                throw new IllegalArgumentException("Unexpected exchange: " + exchangeId);
            }
            return result;
        }

        /**
         * Records the total number of task partitions of an exchange; required for every source
         * exchange before {@link #setFinal()}.
         *
         * @throws IllegalStateException if the count was already set for this exchange
         */
        public Builder setPartitionCount(ExchangeId exchangeId, int count) {
            Integer previousCount = this.exchangeTaskPartitionCount.putIfAbsent(exchangeId, count);
            if (previousCount != null) {
                throw new IllegalStateException("Partition count for exchange is already set: " + exchangeId);
            }
            return this;
        }

        /**
         * Marks subsequently built selectors as final.
         *
         * @throws IllegalStateException if already final or a source exchange has no partition count
         */
        public Builder setFinal() {
            if (this.finalSelector) {
                throw new IllegalStateException("selector is already marked as final");
            }
            for (ExchangeId exchangeId : this.exchangeValues.keySet()) {
                if (!this.exchangeTaskPartitionCount.containsKey(exchangeId)) {
                    throw new IllegalStateException("partition count is missing for exchange: " + exchangeId);
                }
            }
            this.finalSelector = true;
            return this;
        }

        /**
         * Creates a selector from the decisions recorded so far.
         *
         * <p>When final, each exchange's slice spans exactly its configured partition count.
         *
         * @throws IllegalStateException if final and some partition within the count is undecided
         * @throws IllegalArgumentException if final and a decided partition id is not below the count
         */
        public ExchangeSourceOutputSelector build() {
            Map<ExchangeId, Slice> values = new HashMap<>();
            for (Map.Entry<ExchangeId, ValuesBuilder> entry : this.exchangeValues.entrySet()) {
                ExchangeId exchangeId = entry.getKey();
                ValuesBuilder valuesBuilder = entry.getValue();
                Slice exchangeSlice = this.finalSelector
                        ? valuesBuilder.buildFinal(this.exchangeTaskPartitionCount.get(exchangeId))
                        : valuesBuilder.build();
                values.put(exchangeId, exchangeSlice);
            }
            // the version is consumed only once all slices were built successfully
            return new ExchangeSourceOutputSelector(this.nextVersion++, values, this.finalSelector);
        }
    }

    /**
     * Growable per-exchange byte array of decisions, initialized to {@link Selection#UNKNOWN}.
     */
    private static class ValuesBuilder {
        private Slice values = Slices.allocate(0);
        private int maxTaskPartitionId = -1;

        private ValuesBuilder() {
        }

        public void include(int taskPartitionId, int attemptId) {
            // validate everything before touching state so a rejected call leaves no trace
            checkTaskPartitionId(taskPartitionId);
            checkAttemptId(attemptId);
            this.setDecision(taskPartitionId, (byte) attemptId);
        }

        public void exclude(int taskPartitionId) {
            checkTaskPartitionId(taskPartitionId);
            this.setDecision(taskPartitionId, Selection.EXCLUDED.getValue());
        }

        /** Stores {@code value} for a still-undecided partition, growing the buffer as needed. */
        private void setDecision(int taskPartitionId, byte value) {
            this.updateMaxTaskPartitionIdAndEnsureCapacity(taskPartitionId);
            byte currentValue = this.values.getByte(taskPartitionId);
            if (currentValue != Selection.UNKNOWN.getValue()) {
                throw new IllegalArgumentException("decision for partition %s is already made: %s".formatted(taskPartitionId, currentValue));
            }
            this.values.setByte(taskPartitionId, value);
        }

        private void updateMaxTaskPartitionIdAndEnsureCapacity(int taskPartitionId) {
            this.maxTaskPartitionId = Math.max(this.maxTaskPartitionId, taskPartitionId);
            if (taskPartitionId < this.values.length()) {
                return;
            }
            // over-allocate (twice the needed size) so repeated growth stays cheap
            byte[] newValues = new byte[(this.maxTaskPartitionId + 1) * 2];
            Arrays.fill(newValues, Selection.UNKNOWN.getValue());
            this.values.getBytes(0, newValues, 0, this.values.length());
            this.values = Slices.wrappedBuffer(newValues);
        }

        /** Returns a copy trimmed to the highest partition id seen so far. */
        public Slice build() {
            return this.createResult(this.maxTaskPartitionId + 1);
        }

        /**
         * Returns a copy of exactly {@code totalPartitionCount} bytes, requiring every partition to be decided.
         */
        public Slice buildFinal(int totalPartitionCount) {
            Slice result = this.createResult(totalPartitionCount);
            for (int partitionId = 0; partitionId < totalPartitionCount; ++partitionId) {
                if (result.getByte(partitionId) == Selection.UNKNOWN.getValue()) {
                    throw new IllegalStateException("Attempt is unknown for partition: " + partitionId);
                }
            }
            return result;
        }

        /** Copies the decisions into a fresh slice of {@code partitionCount} bytes, padding with UNKNOWN. */
        private Slice createResult(int partitionCount) {
            if (this.maxTaskPartitionId >= partitionCount) {
                throw new IllegalArgumentException("expected maxTaskPartitionId to be less than or equal to " + (partitionCount - 1));
            }
            byte[] result = new byte[partitionCount];
            Arrays.fill(result, Selection.UNKNOWN.getValue());
            this.values.getBytes(0, result, 0, this.maxTaskPartitionId + 1);
            return Slices.wrappedBuffer(result);
        }
    }
}

