/*
 * Decompiled with CFR 0.152.
 */
package io.trino.jdbc.$internal.airlift.compress.v3.zstd;

import io.trino.jdbc.$internal.airlift.compress.v3.zstd.BitOutputStream;
import io.trino.jdbc.$internal.airlift.compress.v3.zstd.FiniteStateEntropy;
import io.trino.jdbc.$internal.airlift.compress.v3.zstd.FseCompressionTable;
import io.trino.jdbc.$internal.airlift.compress.v3.zstd.Histogram;
import io.trino.jdbc.$internal.airlift.compress.v3.zstd.HuffmanCompressionTableWorkspace;
import io.trino.jdbc.$internal.airlift.compress.v3.zstd.HuffmanTableWriterWorkspace;
import io.trino.jdbc.$internal.airlift.compress.v3.zstd.NodeTable;
import io.trino.jdbc.$internal.airlift.compress.v3.zstd.UnsafeUtil;
import io.trino.jdbc.$internal.airlift.compress.v3.zstd.Util;
import java.util.Arrays;

/**
 * Huffman code table for byte symbols ({@code 0..255}) used by the zstd compressor.
 *
 * <p>For every symbol the table stores its code length ({@code numberOfBits}) and its code value
 * ({@code values}). A table is built from a symbol histogram with {@link #initialize}, serialized as a
 * zstd Huffman tree description with {@link #write}, and used to emit codes with {@link #encodeSymbol}.
 *
 * <p>Instances are mutable: the per-symbol arrays are allocated once in the constructor and overwritten
 * by every call to {@link #initialize}.
 */
final class HuffmanCompressionTable {
    /** Largest accepted symbol value (symbols are bytes). */
    private static final int MAX_SYMBOL = 255;
    /** Number of leaf slots in the node table; internal tree nodes start at this index. */
    private static final int MAX_SYMBOL_COUNT = MAX_SYMBOL + 1;
    /** Lower clamp applied by {@link #optimalNumberOfBits}. */
    private static final int MIN_TABLE_LOG = 5;
    /** Upper bound on any code length. */
    private static final int MAX_TABLE_LOG = 12;
    /** Maximum FSE table log used when compressing the Huffman weights. */
    private static final int MAX_FSE_TABLE_LOG = 6;
    /** Most weights the direct (4 bits per weight) description can hold: header byte 255 minus 127. */
    private static final int MAX_DIRECT_WEIGHTS = 128;
    /** Count given to not-yet-built internal nodes so that a remaining leaf is always preferred over them. */
    private static final int SENTINEL_COUNT = 1 << 30;
    /** Marker for an empty slot in the rank table used by {@link #setMaxHeight}. */
    private static final int NO_SYMBOL = 0xF0F0F0F0;

    private final short[] values;
    private final byte[] numberOfBits;
    private int maxSymbol;
    private int maxNumberOfBits;

    /**
     * Creates an empty table.
     *
     * @param capacity number of symbol slots; must be greater than any {@code maxSymbol} later passed to
     *                 {@link #initialize}, because that method writes entries {@code 0..maxSymbol}
     */
    public HuffmanCompressionTable(int capacity) {
        this.values = new short[capacity];
        this.numberOfBits = new byte[capacity];
    }

    /**
     * Chooses the maximum code length for a table over {@code inputSize} input symbols drawn from
     * {@code 0..maxSymbol}.
     *
     * <p>Starts from {@code maxNumberOfBits}, lowers it to {@code highestBit(inputSize - 1) - 1} for small
     * inputs, raises it to at least {@code Util.minTableLog(inputSize, maxSymbol)}, and finally clamps it to
     * {@code [5, 12]}.
     *
     * @throws IllegalArgumentException if {@code inputSize <= 1}
     */
    public static int optimalNumberOfBits(int maxNumberOfBits, int inputSize, int maxSymbol) {
        if (inputSize <= 1) {
            throw new IllegalArgumentException("inputSize must be greater than 1: " + inputSize);
        }
        int result = maxNumberOfBits;
        result = Math.min(result, Util.highestBit(inputSize - 1) - 1);
        result = Math.max(result, Util.minTableLog(inputSize, maxSymbol));
        result = Math.max(result, MIN_TABLE_LOG);
        result = Math.min(result, MAX_TABLE_LOG);
        return result;
    }

    /**
     * Rebuilds this table from a histogram.
     *
     * @param counts          occurrence count per symbol; entries {@code 0..maxSymbol} are read
     * @param maxSymbol       largest symbol to include (at most 255)
     * @param maxNumberOfBits code-length limit enforced by {@link #setMaxHeight}
     * @param workspace       scratch buffers; reset at the start of this call
     */
    public void initialize(int[] counts, int maxSymbol, int maxNumberOfBits, HuffmanCompressionTableWorkspace workspace) {
        Util.checkArgument(maxSymbol <= MAX_SYMBOL, "Max symbol value too large");
        workspace.reset();
        NodeTable nodeTable = workspace.nodeTable;
        nodeTable.reset();

        int lastNonZero = buildTree(counts, maxSymbol, nodeTable);

        // Enforce the length limit; the result is the longest code length actually in use.
        int tableLog = setMaxHeight(nodeTable, lastNonZero, maxNumberOfBits, workspace);
        Util.checkArgument(tableLog <= MAX_TABLE_LOG, "Max number of bits larger than max table size");

        // The node table is ordered by count, so copy code lengths back into symbol order.
        for (int node = 0; node <= maxSymbol; node++) {
            numberOfBits[nodeTable.symbols[node]] = nodeTable.numberOfBits[node];
        }

        // Canonical code assignment, step 1: count codes of each length.
        short[] entriesPerRank = workspace.entriesPerRank;
        short[] valuesPerRank = workspace.valuesPerRank;
        for (int node = 0; node <= lastNonZero; node++) {
            entriesPerRank[nodeTable.numberOfBits[node]]++;
        }

        // Step 2: starting code value for each length, from the longest codes to the shortest.
        short startingValue = 0;
        for (int rank = tableLog; rank > 0; rank--) {
            valuesPerRank[rank] = startingValue;
            startingValue += entriesPerRank[rank];
            startingValue >>>= 1;
        }

        // Step 3: consecutive values within each length, in ascending symbol order.
        // The post-increment result MUST be stored, otherwise every code value stays 0.
        for (int symbol = 0; symbol <= maxSymbol; symbol++) {
            values[symbol] = valuesPerRank[numberOfBits[symbol]]++;
        }

        this.maxSymbol = maxSymbol;
        this.maxNumberOfBits = tableLog;
    }

    /**
     * Builds a Huffman tree for symbols {@code 0..maxSymbol} inside {@code nodeTable}.
     *
     * <p>Leaves occupy indices {@code 0..maxSymbol}, insertion-sorted by descending count (ties keep ascending
     * symbol order). Internal nodes are created from index {@code MAX_SYMBOL_COUNT} upward, with the root last.
     * At each merge step, the smaller of the next remaining leaf (walking down from {@code lastNonZero}) and the
     * next unmerged internal node is taken. The depth of every node is then stored in
     * {@code nodeTable.numberOfBits}.
     *
     * <p>NOTE(review): assumes at least two symbols have non-zero counts ({@code currentLeaf - 1} is read
     * unconditionally); confirm that callers guarantee this.
     *
     * @return index of the last leaf with a non-zero count
     */
    private int buildTree(int[] counts, int maxSymbol, NodeTable nodeTable) {
        for (int symbol = 0; symbol <= maxSymbol; symbol++) {
            int count = counts[symbol];
            // NOTE(review): the "> 1" bound means index 0 is never displaced, so symbol 0 always stays at
            // node 0 whatever its count. This preserves the shipped behavior; "> 0" may have been intended,
            // but changing it alters generated code lengths, so verify against upstream first.
            int position = symbol;
            while (position > 1 && count > nodeTable.count[position - 1]) {
                nodeTable.copyNode(position - 1, position);
                position--;
            }
            nodeTable.count[position] = count;
            nodeTable.symbols[position] = symbol;
        }

        int lastNonZero = maxSymbol;
        while (nodeTable.count[lastNonZero] == 0) {
            lastNonZero--;
        }

        int current = MAX_SYMBOL_COUNT;
        int currentLeaf = lastNonZero;
        int currentNonLeaf = current;

        // The first internal node combines the two smallest leaves.
        nodeTable.count[current] = nodeTable.count[currentLeaf] + nodeTable.count[currentLeaf - 1];
        nodeTable.parents[currentLeaf] = (short) current;
        nodeTable.parents[currentLeaf - 1] = (short) current;
        current++;
        currentLeaf -= 2;

        int root = MAX_SYMBOL_COUNT + lastNonZero - 1;

        // Sentinels: internal nodes not built yet must never look smaller than a remaining leaf.
        for (int n = current; n <= root; n++) {
            nodeTable.count[n] = SENTINEL_COUNT;
        }

        while (current <= root) {
            int child1;
            if (currentLeaf >= 0 && nodeTable.count[currentLeaf] < nodeTable.count[currentNonLeaf]) {
                child1 = currentLeaf--;
            } else {
                child1 = currentNonLeaf++;
            }

            int child2;
            if (currentLeaf >= 0 && nodeTable.count[currentLeaf] < nodeTable.count[currentNonLeaf]) {
                child2 = currentLeaf--;
            } else {
                child2 = currentNonLeaf++;
            }

            nodeTable.count[current] = nodeTable.count[child1] + nodeTable.count[child2];
            nodeTable.parents[child1] = (short) current;
            nodeTable.parents[child2] = (short) current;
            current++;
        }

        // Depths: the root is 0; every other node is its parent's depth + 1.
        // Parents always have higher indices, so walking downward visits them first.
        nodeTable.numberOfBits[root] = 0;
        for (int n = root - 1; n >= MAX_SYMBOL_COUNT; n--) {
            nodeTable.numberOfBits[n] = (byte) (nodeTable.numberOfBits[nodeTable.parents[n]] + 1);
        }
        for (int n = 0; n <= lastNonZero; n++) {
            nodeTable.numberOfBits[n] = (byte) (nodeTable.numberOfBits[nodeTable.parents[n]] + 1);
        }
        return lastNonZero;
    }

    /**
     * Emits the code for {@code symbol} through {@link BitOutputStream#addBitsFast}.
     * Callers should make sure the symbol has a non-zero code length in this table (see {@link #isValid}).
     */
    public void encodeSymbol(BitOutputStream output, int symbol) {
        output.addBitsFast(this.values[symbol], this.numberOfBits[symbol]);
    }

    /**
     * Serializes this table as a zstd Huffman tree description (RFC 8478, section 4.2.1).
     *
     * <p>Code lengths are converted to weights ({@code maxNumberOfBits + 1 - bits}, or 0 for unused symbols).
     * The weight of {@code maxSymbol} is not written because the format makes the last weight implicit.
     * The weights are FSE-compressed when that yields more than one byte and fewer than {@code maxSymbol / 2}
     * bytes. Otherwise they are written directly, 4 bits per weight, which the format limits to 128 weights.
     *
     * @return number of bytes written starting at {@code outputAddress}
     * @throws AssertionError if FSE-compressed weights exceed 127 bytes for {@code maxSymbol > 127}
     *                        (not expected to happen)
     * Fails via {@code Util.checkArgument} if the direct form is needed but there are too many symbols for it,
     * or if {@code outputSize} is too small.
     */
    public int write(Object outputBase, long outputAddress, int outputSize, HuffmanTableWriterWorkspace workspace) {
        byte[] weights = workspace.weights;
        int maxNumberOfBits = this.maxNumberOfBits;
        int maxSymbol = this.maxSymbol;

        for (int symbol = 0; symbol < maxSymbol; symbol++) {
            int bits = numberOfBits[symbol];
            if (bits == 0) {
                weights[symbol] = 0;
            } else {
                weights[symbol] = (byte) (maxNumberOfBits + 1 - bits);
            }
        }

        long output = outputAddress;

        // Try FSE first; its output goes after the one-byte header.
        int size = compressWeights(outputBase, output + 1, outputSize - 1, weights, maxSymbol, workspace);
        if (maxSymbol > 127 && size > 127) {
            // Should be impossible: the header byte for FSE-compressed weights must stay below 128.
            throw new AssertionError();
        }
        if (size != 0 && size != 1 && size < maxSymbol / 2) {
            UnsafeUtil.UNSAFE.putByte(outputBase, output, (byte) size);
            return size + 1;
        }

        // Direct representation: header = 127 + number of explicit weights, followed by two weights per byte.
        int entryCount = maxSymbol;
        // Beyond 128 weights, the header byte would wrap into the FSE range and the stream could not be decoded.
        Util.checkArgument(entryCount <= MAX_DIRECT_WEIGHTS, "Too many symbols for direct Huffman weight representation");
        size = (entryCount + 1) / 2;
        Util.checkArgument(size + 1 <= outputSize, "Output size too small");
        UnsafeUtil.UNSAFE.putByte(outputBase, output, (byte) (127 + entryCount));
        output++;

        // Pads the final nibble when entryCount is odd.
        weights[maxSymbol] = 0;
        for (int i = 0; i < entryCount; i += 2) {
            UnsafeUtil.UNSAFE.putByte(outputBase, output, (byte) ((weights[i] << 4) + weights[i + 1]));
            output++;
        }
        return (int) (output - outputAddress);
    }

    /**
     * Returns whether this table can encode every symbol that has a non-zero count in {@code counts}.
     *
     * @return {@code false} if {@code maxSymbol} exceeds the table's range or some used symbol has no code
     */
    public boolean isValid(int[] counts, int maxSymbol) {
        if (maxSymbol > this.maxSymbol) {
            return false;
        }
        for (int symbol = 0; symbol <= maxSymbol; symbol++) {
            if (counts[symbol] != 0 && numberOfBits[symbol] == 0) {
                return false;
            }
        }
        return true;
    }

    /**
     * Estimates the encoded payload size in bytes: the sum of {@code codeLength * count} over symbols in range,
     * rounded down to whole bytes. Symbols above this table's {@code maxSymbol} are ignored.
     */
    public int estimateCompressedSize(int[] counts, int maxSymbol) {
        int limit = Math.min(maxSymbol, this.maxSymbol);
        int totalBits = 0;
        for (int symbol = 0; symbol <= limit; symbol++) {
            totalBits += numberOfBits[symbol] * counts[symbol];
        }
        return totalBits >>> 3;
    }

    /**
     * Caps the code lengths of leaves {@code 0..lastNonZero} at {@code maxNumberOfBits}.
     *
     * <p>Truncating over-long codes borrows code space. {@code totalCost} tracks how much was borrowed, in
     * units of a {@code maxNumberOfBits}-long code. The loops below repay it by lengthening shorter codes, and
     * shorten codes again if the repayment overshoots. Relies on code lengths being non-decreasing with leaf
     * index, with the deepest leaf at {@code lastNonZero}.
     *
     * @return the longest code length in use: the original tree depth if it was already within the limit,
     *         otherwise {@code maxNumberOfBits}
     */
    private static int setMaxHeight(NodeTable nodeTable, int lastNonZero, int maxNumberOfBits, HuffmanCompressionTableWorkspace workspace) {
        int largestBits = nodeTable.numberOfBits[lastNonZero];
        if (largestBits <= maxNumberOfBits) {
            return largestBits;
        }

        // Truncate over-long codes and accumulate the code space they borrowed.
        int totalCost = 0;
        int baseCost = 1 << (largestBits - maxNumberOfBits);
        int n = lastNonZero;
        while (nodeTable.numberOfBits[n] > maxNumberOfBits) {
            totalCost += baseCost - (1 << (largestBits - nodeTable.numberOfBits[n]));
            nodeTable.numberOfBits[n] = (byte) maxNumberOfBits;
            n--;
        }
        // n ends at the last leaf that is strictly shorter than maxNumberOfBits.
        while (nodeTable.numberOfBits[n] == maxNumberOfBits) {
            n--;
        }
        totalCost >>>= (largestBits - maxNumberOfBits);

        // rankLast[k] = highest leaf index whose length is maxNumberOfBits - k (NO_SYMBOL if none).
        int[] rankLast = workspace.rankLast;
        Arrays.fill(rankLast, NO_SYMBOL);
        int currentNbBits = maxNumberOfBits;
        for (int pos = n; pos >= 0; pos--) {
            if (nodeTable.numberOfBits[pos] >= currentNbBits) {
                continue;
            }
            currentNbBits = nodeTable.numberOfBits[pos];
            rankLast[maxNumberOfBits - currentNbBits] = pos;
        }

        // Repay: lengthen one code at a time, choosing the rank that is cheapest in terms of counts.
        while (totalCost > 0) {
            int numberOfBitsToDecrease = Util.highestBit(totalCost) + 1;
            for (; numberOfBitsToDecrease > 1; numberOfBitsToDecrease--) {
                int highPosition = rankLast[numberOfBitsToDecrease];
                int lowPosition = rankLast[numberOfBitsToDecrease - 1];
                if (highPosition == NO_SYMBOL) {
                    continue;
                }
                if (lowPosition == NO_SYMBOL) {
                    break;
                }
                if (nodeTable.count[highPosition] <= 2 * nodeTable.count[lowPosition]) {
                    break;
                }
            }
            // If the chosen rank is empty, move to the nearest non-empty larger rank.
            while (numberOfBitsToDecrease <= MAX_TABLE_LOG && rankLast[numberOfBitsToDecrease] == NO_SYMBOL) {
                numberOfBitsToDecrease++;
            }
            totalCost -= 1 << (numberOfBitsToDecrease - 1);
            if (rankLast[numberOfBitsToDecrease - 1] == NO_SYMBOL) {
                rankLast[numberOfBitsToDecrease - 1] = rankLast[numberOfBitsToDecrease];
            }
            nodeTable.numberOfBits[rankLast[numberOfBitsToDecrease]]++;
            if (rankLast[numberOfBitsToDecrease] == 0) {
                rankLast[numberOfBitsToDecrease] = NO_SYMBOL;
            } else {
                rankLast[numberOfBitsToDecrease]--;
                if (nodeTable.numberOfBits[rankLast[numberOfBitsToDecrease]] != maxNumberOfBits - numberOfBitsToDecrease) {
                    rankLast[numberOfBitsToDecrease] = NO_SYMBOL;
                }
            }
        }

        // Overshoot correction: shorten codes of length maxNumberOfBits by one until the cost is balanced.
        while (totalCost < 0) {
            if (rankLast[1] == NO_SYMBOL) {
                while (nodeTable.numberOfBits[n] == maxNumberOfBits) {
                    n--;
                }
                nodeTable.numberOfBits[n + 1]--;
                rankLast[1] = n + 1;
                totalCost++;
                continue;
            }
            nodeTable.numberOfBits[rankLast[1] + 1]--;
            rankLast[1]++;
            totalCost++;
        }
        return maxNumberOfBits;
    }

    /**
     * FSE-compresses the Huffman weights (normalized-count header followed by the compressed weights).
     *
     * @return bytes written; 0 if the weights are not compressible (at most one weight, every weight value
     *         distinct, or FSE produced nothing); 1 if all weights are identical, in which case nothing is written
     */
    private static int compressWeights(Object outputBase, long outputAddress, int outputSize, byte[] weights, int weightsLength, HuffmanTableWriterWorkspace workspace) {
        if (weightsLength <= 1) {
            return 0;
        }

        int[] counts = workspace.counts;
        Histogram.count(weights, weightsLength, counts);
        int maxSymbol = Histogram.findMaxSymbol(counts, MAX_TABLE_LOG);
        int maxCount = Histogram.findLargestCount(counts, maxSymbol);
        if (maxCount == weightsLength) {
            return 1;
        }
        if (maxCount == 1) {
            return 0;
        }

        short[] normalizedCounts = workspace.normalizedCounts;
        int tableLog = FiniteStateEntropy.optimalTableLog(MAX_FSE_TABLE_LOG, weightsLength, maxSymbol);
        FiniteStateEntropy.normalizeCounts(normalizedCounts, tableLog, counts, weightsLength, maxSymbol);

        long output = outputAddress;
        long outputLimit = outputAddress + outputSize;

        int headerSize = FiniteStateEntropy.writeNormalizedCounts(outputBase, output, outputSize, normalizedCounts, maxSymbol, tableLog);
        output += headerSize;

        FseCompressionTable compressionTable = workspace.fseTable;
        compressionTable.initialize(normalizedCounts, maxSymbol, tableLog);
        int compressedSize = FiniteStateEntropy.compress(outputBase, output, (int) (outputLimit - output), weights, weightsLength, compressionTable);
        if (compressedSize == 0) {
            return 0;
        }
        output += compressedSize;
        return (int) (output - outputAddress);
    }
}

