/*
 * Decompiled with CFR 0.152.
 */
package org.javimmutable.collections.array;

import java.util.ArrayList;
import java.util.function.IntFunction;
import javax.annotation.Nonnull;
import javax.annotation.Nullable;
import org.javimmutable.collections.Func1;
import org.javimmutable.collections.Holder;
import org.javimmutable.collections.Holders;
import org.javimmutable.collections.IndexedProc1;
import org.javimmutable.collections.IndexedProc1Throws;
import org.javimmutable.collections.IntFunc2;
import org.javimmutable.collections.JImmutableMap;
import org.javimmutable.collections.MapEntry;
import org.javimmutable.collections.Proc1;
import org.javimmutable.collections.Proc1Throws;
import org.javimmutable.collections.array.ArrayAssignMapper;
import org.javimmutable.collections.array.ArrayContainsMapper;
import org.javimmutable.collections.array.ArrayDeleteMapper;
import org.javimmutable.collections.array.ArrayGetMapper;
import org.javimmutable.collections.array.ArrayIterationMapper;
import org.javimmutable.collections.array.ArraySizeMapper;
import org.javimmutable.collections.array.ArrayUpdateMapper;
import org.javimmutable.collections.common.ArrayHelper;
import org.javimmutable.collections.common.BitmaskMath;
import org.javimmutable.collections.common.IntArrayMappedTrieMath;
import org.javimmutable.collections.indexed.IndexedList;
import org.javimmutable.collections.iterators.GenericIterator;

public class TrieArrayNode<T> {
    // Shift count value 0; going by its name, the shift count of leaf-level nodes.
    static final int LEAF_SHIFT_COUNT = 0;
    // Shift count used for the shared EMPTY node and passed as the starting level
    // by assign(), mappedAssign() and mappedUpdate().
    static final int ROOT_SHIFT_COUNT = IntArrayMappedTrieMath.maxShiftsForBitCount(30);
    // NOTE(review): presumably the bit toggled by flip() when converting between
    // user-visible and internal indexes - confirm against flip().
    private static final int SIGN_BIT = Integer.MIN_VALUE;
    // Shared zero-length arrays backing the EMPTY node.
    private static final Object[] EMPTY_VALUES = new Object[0];
    private static final TrieArrayNode[] EMPTY_NODES = new TrieArrayNode[0];
    // Singleton node with no values and no children; handed out by empty().
    private static final TrieArrayNode EMPTY = new TrieArrayNode<Object>(ROOT_SHIFT_COUNT, 0, 0L, EMPTY_VALUES, 0L, EMPTY_NODES, 0);
    // Level of this node; lookups compare it with the shift count computed for the index.
    private final int shiftCount;
    // Compared with baseIndexAtShift(shiftCount, index) to decide whether an index falls under this node.
    private final int baseIndex;
    // One bit per slot that holds a value; its bit count equals values.length (asserted in the constructor).
    private final long valuesBitmask;
    // Packed values; the position for a bit is arrayIndexForBit(valuesBitmask, bit).
    private final T[] values;
    // One bit per slot that holds a child node; its bit count equals nodes.length (asserted in the constructor).
    private final long nodesBitmask;
    // Packed child nodes; the position for a bit is arrayIndexForBit(nodesBitmask, bit).
    private final TrieArrayNode<T>[] nodes;
    // Cached element count for this node and everything below it; returned by size().
    private final int size;

    /**
     * Creates a node from pre-built, packed arrays. The arrays are stored as-is (no copy).
     *
     * @param shiftCount    level of this node
     * @param baseIndex     base index shared by every index stored under this node
     * @param valuesBitmask bitmask of slots holding values; bit count must equal {@code values.length}
     * @param values        packed values
     * @param nodesBitmask  bitmask of slots holding child nodes; bit count must equal {@code nodes.length}
     * @param nodes         packed child nodes
     * @param size          total element count for this node and its descendants
     */
    TrieArrayNode(int shiftCount, int baseIndex, long valuesBitmask, T[] values, long nodesBitmask, @Nonnull TrieArrayNode<T>[] nodes, int size) {
        // Each bitmask must describe exactly the entries in its array.
        assert BitmaskMath.bitCount(valuesBitmask) == values.length;
        assert BitmaskMath.bitCount(nodesBitmask) == nodes.length;
        this.size = size;
        this.nodes = nodes;
        this.nodesBitmask = nodesBitmask;
        this.values = values;
        this.valuesBitmask = valuesBitmask;
        this.baseIndex = baseIndex;
        this.shiftCount = shiftCount;
        assert TrieArrayNode.checkChildShifts(shiftCount, nodes);
    }

    /**
     * Returns the shared node that holds no values and no children.
     *
     * @return the singleton empty node, typed for the caller
     */
    @Nonnull
    @SuppressWarnings("unchecked")
    public static <T> TrieArrayNode<T> empty() {
        // The singleton holds no elements, so it can be viewed as a node of any element type.
        return (TrieArrayNode<T>) EMPTY;
    }

    /**
     * Builds a node of size 1 that holds only {@code value} and has no children.
     *
     * @param shiftCount level of the new node; asserted to equal {@code findShiftForIndex(index)}
     * @param index      internal index of the value
     * @param value      value to store
     * @return a new single-value node
     */
    @Nonnull
    private static <T> TrieArrayNode<T> forValue(int shiftCount, int index, T value) {
        assert shiftCount == TrieArrayNode.findShiftForIndex(index);
        final int nodeBase = IntArrayMappedTrieMath.baseIndexAtShift(shiftCount, index);
        final int slot = IntArrayMappedTrieMath.indexAtShift(shiftCount, index);
        final long slotBit = BitmaskMath.bitFromIndex(slot);
        final T[] singleValue = ArrayHelper.newArray(value);
        final TrieArrayNode<T>[] noChildren = TrieArrayNode.emptyNodes();
        return new TrieArrayNode<T>(shiftCount, nodeBase, slotBit, singleValue, 0L, noChildren, 1);
    }

    /**
     * Builds a node with no values and {@code node} as its only child.
     *
     * @param shiftCount    level of the new parent node
     * @param nodeBaseIndex base index of the child, used to locate the child's slot in the parent
     * @param node          the child to wrap
     * @return a new parent whose size equals {@code node.size()}
     */
    @Nonnull
    private static <T> TrieArrayNode<T> forNode(int shiftCount, int nodeBaseIndex, @Nonnull TrieArrayNode<T> node) {
        final int parentBase = IntArrayMappedTrieMath.baseIndexAtShift(shiftCount, nodeBaseIndex);
        final T[] noValues = TrieArrayNode.emptyValues();
        final int slot = IntArrayMappedTrieMath.indexAtShift(shiftCount, nodeBaseIndex);
        final long slotBit = BitmaskMath.bitFromIndex(slot);
        final TrieArrayNode<T>[] onlyChild = TrieArrayNode.allocateNodes(1);
        onlyChild[0] = node;
        return new TrieArrayNode<T>(shiftCount, parentBase, 0L, noValues, slotBit, onlyChild, node.size());
    }

    /**
     * @return true when the cached size of this node is zero
     */
    public boolean isEmpty() {
        return size == 0;
    }

    /**
     * Looks up the value stored at a user-visible index.
     *
     * @param index        user-visible index to look up
     * @param defaultValue result to use when nothing is stored at {@code index}
     * @return the stored value, or {@code defaultValue} when absent
     */
    public T getValueOr(int index, T defaultValue) {
        final int internalIndex = TrieArrayNode.flip(index);
        final int valueShift = TrieArrayNode.findShiftForIndex(internalIndex);
        return getValueOrImpl(valueShift, internalIndex, defaultValue);
    }

    /**
     * Looks up the value stored at a user-visible index.
     *
     * @param index user-visible index to look up
     * @return a holder containing the stored value, or an empty holder when absent
     */
    @Nonnull
    public Holder<T> find(int index) {
        final int internalIndex = TrieArrayNode.flip(index);
        final int valueShift = TrieArrayNode.findShiftForIndex(internalIndex);
        return findImpl(valueShift, internalIndex);
    }

    /**
     * Produces a node in which {@code value} is stored at a user-visible index.
     * The operation starts from {@code ROOT_SHIFT_COUNT}.
     *
     * @param index user-visible index to assign
     * @param value value to store
     * @return the resulting node
     */
    @Nonnull
    public TrieArrayNode<T> assign(int index, T value) {
        final int internalIndex = TrieArrayNode.flip(index);
        final int valueShift = TrieArrayNode.findShiftForIndex(internalIndex);
        return assignImpl(ROOT_SHIFT_COUNT, valueShift, internalIndex, value);
    }

    /**
     * Produces a node without the value stored at a user-visible index.
     *
     * @param index user-visible index to remove
     * @return the resulting node; this node itself when nothing is stored at {@code index}
     */
    @Nonnull
    public TrieArrayNode<T> delete(int index) {
        final int internalIndex = TrieArrayNode.flip(index);
        final int valueShift = TrieArrayNode.findShiftForIndex(internalIndex);
        return deleteImpl(valueShift, internalIndex);
    }

    /**
     * Tests whether {@code key} is present, using the key's hash code as the array index
     * and delegating to {@code mapper} to inspect the mapping stored there.
     *
     * @param mapper strategy that knows how to search a stored mapping for a key
     * @param key    key to look for
     * @return true when a mapping exists for the key's hash code and the mapper reports the key present
     */
    public <K> boolean mappedContains(@Nonnull ArrayContainsMapper<K, T> mapper, @Nonnull K key) {
        final int index = TrieArrayNode.flip(key.hashCode());
        final int shiftCountForValue = TrieArrayNode.findShiftForIndex(index);
        // Typed as T (not Object) so it can be handed to the mapper without a cast.
        final T mapping = getValueOrImpl(shiftCountForValue, index, null);
        return mapping != null && mapper.mappedContains(mapping, key);
    }

    /**
     * Looks up the value for {@code key}, using the key's hash code as the array index
     * and delegating to {@code mapper} to search the mapping stored there.
     *
     * @param mapper       strategy that knows how to read a value out of a stored mapping
     * @param key          key to look up
     * @param defaultValue result to use when no mapping exists for the key's hash code
     * @return the mapper's result, or {@code defaultValue} when no mapping is stored
     */
    public <K, V> V mappedGetValueOr(@Nonnull ArrayGetMapper<K, V, T> mapper, @Nonnull K key, V defaultValue) {
        final int index = TrieArrayNode.flip(key.hashCode());
        final int shiftCountForValue = TrieArrayNode.findShiftForIndex(index);
        // Typed as T (not Object) so it can be handed to the mapper without a cast.
        final T mapping = getValueOrImpl(shiftCountForValue, index, null);
        if (mapping == null) {
            return defaultValue;
        }
        return mapper.mappedGetValueOr(mapping, key, defaultValue);
    }

    /**
     * Looks up the value for {@code key}, using the key's hash code as the array index
     * and delegating to {@code mapper} to search the mapping stored there.
     *
     * @param mapper strategy that knows how to read a value out of a stored mapping
     * @param key    key to look up
     * @return the mapper's result, or an empty holder when no mapping is stored
     */
    @Nonnull
    public <K, V> Holder<V> mappedFind(@Nonnull ArrayGetMapper<K, V, T> mapper, @Nonnull K key) {
        final int index = TrieArrayNode.flip(key.hashCode());
        final int shiftCountForValue = TrieArrayNode.findShiftForIndex(index);
        // Typed as T (not Object) so it can be handed to the mapper without a cast.
        final T mapping = getValueOrImpl(shiftCountForValue, index, null);
        if (mapping == null) {
            return Holders.of();
        }
        return mapper.mappedFind(mapping, key);
    }

    /**
     * Produces a node in which {@code key} is associated with {@code value}. The key's hash code
     * selects the array index and {@code mapper} builds or updates the mapping stored there.
     *
     * @param mapper strategy for creating and updating stored mappings
     * @param key    key to assign
     * @param value  value to associate with the key
     * @return the resulting node
     */
    @Nonnull
    public <K, V> TrieArrayNode<T> mappedAssign(@Nonnull ArrayAssignMapper<K, V, T> mapper, @Nonnull K key, V value) {
        final int hashIndex = TrieArrayNode.flip(key.hashCode());
        final int valueShift = TrieArrayNode.findShiftForIndex(hashIndex);
        return mappedAssignImpl(ROOT_SHIFT_COUNT, valueShift, hashIndex, mapper, key, value);
    }

    /**
     * Produces a node in which the value for {@code key} is replaced by the result of
     * {@code generator}. The key's hash code selects the array index and {@code mapper}
     * updates the mapping stored there.
     *
     * @param mapper    strategy for creating and updating stored mappings
     * @param key       key to update
     * @param generator function that produces the new value
     * @return the resulting node
     */
    @Nonnull
    public <K, V> TrieArrayNode<T> mappedUpdate(@Nonnull ArrayUpdateMapper<K, V, T> mapper, @Nonnull K key, @Nonnull Func1<Holder<V>, V> generator) {
        final int hashIndex = TrieArrayNode.flip(key.hashCode());
        final int valueShift = TrieArrayNode.findShiftForIndex(hashIndex);
        return mappedUpdateImpl(ROOT_SHIFT_COUNT, valueShift, hashIndex, mapper, key, generator);
    }

    /**
     * Produces a node without {@code key}. The key's hash code selects the array index
     * and {@code mapper} removes the key from the mapping stored there.
     *
     * @param mapper strategy for removing a key from a stored mapping
     * @param key    key to remove
     * @return the resulting node
     */
    @Nonnull
    public <K> TrieArrayNode<T> mappedDelete(@Nonnull ArrayDeleteMapper<K, T> mapper, @Nonnull K key) {
        final int hashIndex = TrieArrayNode.flip(key.hashCode());
        final int valueShift = TrieArrayNode.findShiftForIndex(hashIndex);
        return mappedDeleteImpl(valueShift, hashIndex, mapper, key);
    }

    /**
     * Applies {@code proc} to every value in this node and, recursively, in its children.
     * Slots are taken in the order given by {@code BitmaskMath.leastBit}; within a slot the
     * value (if any) is visited before the child node (if any).
     *
     * @param proc callback invoked once per value
     */
    public void forEach(@Nonnull Proc1<T> proc) {
        long remaining = BitmaskMath.addBit(valuesBitmask, nodesBitmask);
        while (remaining != 0L) {
            final long bit = BitmaskMath.leastBit(remaining);
            remaining = BitmaskMath.removeBit(remaining, bit);
            if (BitmaskMath.bitIsPresent(valuesBitmask, bit)) {
                proc.apply(values[BitmaskMath.arrayIndexForBit(valuesBitmask, bit)]);
            }
            if (BitmaskMath.bitIsPresent(nodesBitmask, bit)) {
                nodes[BitmaskMath.arrayIndexForBit(nodesBitmask, bit)].forEach(proc);
            }
        }
    }

    /**
     * Applies {@code proc} to every value in this node and, recursively, in its children,
     * propagating any exception the callback throws. Slots are taken in the order given by
     * {@code BitmaskMath.leastBit}; within a slot the value is visited before the child node.
     *
     * @param proc callback invoked once per value
     * @throws E whatever {@code proc} throws
     */
    public <E extends Exception> void forEachThrows(@Nonnull Proc1Throws<T, E> proc) throws E {
        long remaining = BitmaskMath.addBit(valuesBitmask, nodesBitmask);
        while (remaining != 0L) {
            final long bit = BitmaskMath.leastBit(remaining);
            remaining = BitmaskMath.removeBit(remaining, bit);
            if (BitmaskMath.bitIsPresent(valuesBitmask, bit)) {
                proc.apply(values[BitmaskMath.arrayIndexForBit(valuesBitmask, bit)]);
            }
            if (BitmaskMath.bitIsPresent(nodesBitmask, bit)) {
                nodes[BitmaskMath.arrayIndexForBit(nodesBitmask, bit)].forEachThrows(proc);
            }
        }
    }

    /**
     * Applies {@code proc} to every (user index, value) pair in this node and, recursively,
     * in its children. Slots are taken in the order given by {@code BitmaskMath.leastBit};
     * within a slot the value is visited before the child node.
     *
     * @param proc callback invoked with the user-visible index and the value
     */
    public void forEach(@Nonnull IndexedProc1<T> proc) {
        long remaining = BitmaskMath.addBit(valuesBitmask, nodesBitmask);
        while (remaining != 0L) {
            final long bit = BitmaskMath.leastBit(remaining);
            remaining = BitmaskMath.removeBit(remaining, bit);
            if (BitmaskMath.bitIsPresent(valuesBitmask, bit)) {
                final int slot = BitmaskMath.indexForBit(bit);
                final int position = BitmaskMath.arrayIndexForBit(valuesBitmask, bit);
                // Rebuild the internal index from base + slot, then flip it back to the user's index.
                final int internalIndex = baseIndex + IntArrayMappedTrieMath.shift(shiftCount, slot);
                proc.apply(TrieArrayNode.flip(internalIndex), values[position]);
            }
            if (BitmaskMath.bitIsPresent(nodesBitmask, bit)) {
                nodes[BitmaskMath.arrayIndexForBit(nodesBitmask, bit)].forEach(proc);
            }
        }
    }

    /**
     * Applies {@code proc} to every (user index, value) pair in this node and, recursively,
     * in its children, propagating any exception the callback throws. Slots are taken in the
     * order given by {@code BitmaskMath.leastBit}; within a slot the value is visited before
     * the child node.
     *
     * @param proc callback invoked with the user-visible index and the value
     * @throws E whatever {@code proc} throws
     */
    public <E extends Exception> void forEachThrows(@Nonnull IndexedProc1Throws<T, E> proc) throws E {
        long remaining = BitmaskMath.addBit(valuesBitmask, nodesBitmask);
        while (remaining != 0L) {
            final long bit = BitmaskMath.leastBit(remaining);
            remaining = BitmaskMath.removeBit(remaining, bit);
            if (BitmaskMath.bitIsPresent(valuesBitmask, bit)) {
                final int slot = BitmaskMath.indexForBit(bit);
                final int position = BitmaskMath.arrayIndexForBit(valuesBitmask, bit);
                // Rebuild the internal index from base + slot, then flip it back to the user's index.
                final int internalIndex = baseIndex + IntArrayMappedTrieMath.shift(shiftCount, slot);
                proc.apply(TrieArrayNode.flip(internalIndex), values[position]);
            }
            if (BitmaskMath.bitIsPresent(nodesBitmask, bit)) {
                nodes[BitmaskMath.arrayIndexForBit(nodesBitmask, bit)].forEachThrows(proc);
            }
        }
    }

    /**
     * Recursive lookup of the value stored at an internal (already flipped) index.
     *
     * @param shiftCountForValue level at which a value for {@code index} would be stored
     * @param index              internal index to look up
     * @param defaultValue       result to use when nothing is stored at {@code index}
     * @return the stored value, or {@code defaultValue} when absent
     */
    private T getValueOrImpl(int shiftCountForValue, int index, T defaultValue) {
        final int shiftCount = this.shiftCount;
        // A value that lives above this node's level cannot be inside this subtree.
        if (shiftCountForValue > shiftCount) {
            return defaultValue;
        }
        // The index must share this node's base index to fall under it.
        if (IntArrayMappedTrieMath.baseIndexAtShift(shiftCount, index) != this.baseIndex) {
            return defaultValue;
        }
        final int myIndex = IntArrayMappedTrieMath.indexAtShift(shiftCount, index);
        final long bit = BitmaskMath.bitFromIndex(myIndex);
        if (shiftCountForValue == shiftCount) {
            final long bitmask = this.valuesBitmask;
            if (BitmaskMath.bitIsPresent(bitmask, bit)) {
                final int arrayIndex = BitmaskMath.arrayIndexForBit(bitmask, bit);
                return this.values[arrayIndex];
            }
        } else {
            final long bitmask = this.nodesBitmask;
            if (BitmaskMath.bitIsPresent(bitmask, bit)) {
                final int arrayIndex = BitmaskMath.arrayIndexForBit(bitmask, bit);
                // Descend into the child node occupying this slot.
                return this.nodes[arrayIndex].getValueOrImpl(shiftCountForValue, index, defaultValue);
            }
        }
        return defaultValue;
    }

    /**
     * Recursive lookup of the value stored at an internal (already flipped) index.
     *
     * @param shiftCountForValue level at which a value for {@code index} would be stored
     * @param index              internal index to look up
     * @return a holder containing the stored value, or an empty holder when absent
     */
    @Nonnull
    private Holder<T> findImpl(int shiftCountForValue, int index) {
        final int shiftCount = this.shiftCount;
        // A value that lives above this node's level cannot be inside this subtree.
        if (shiftCountForValue > shiftCount) {
            return Holders.of();
        }
        // The index must share this node's base index to fall under it.
        if (IntArrayMappedTrieMath.baseIndexAtShift(shiftCount, index) != this.baseIndex) {
            return Holders.of();
        }
        final int myIndex = IntArrayMappedTrieMath.indexAtShift(shiftCount, index);
        final long bit = BitmaskMath.bitFromIndex(myIndex);
        if (shiftCountForValue == shiftCount) {
            final long bitmask = this.valuesBitmask;
            if (BitmaskMath.bitIsPresent(bitmask, bit)) {
                final int arrayIndex = BitmaskMath.arrayIndexForBit(bitmask, bit);
                return Holders.of(this.values[arrayIndex]);
            }
        } else {
            final long bitmask = this.nodesBitmask;
            if (BitmaskMath.bitIsPresent(bitmask, bit)) {
                final int arrayIndex = BitmaskMath.arrayIndexForBit(bitmask, bit);
                // Descend into the child node occupying this slot.
                return this.nodes[arrayIndex].findImpl(shiftCountForValue, index);
            }
        }
        return Holders.of();
    }

    /**
     * Recursive assignment of {@code value} at an internal (already flipped) index.
     *
     * <p>When this node sits at a lower level than {@code shiftCount} and the common ancestor
     * level of this node and {@code index} is above this node's level, the node is wrapped in
     * a new ancestor and the assignment continues there. Otherwise the assignment proceeds at
     * this node's own level.
     *
     * @param shiftCount         level the caller expects at this position; asserted to be at least
     *                           this node's level and at least {@code shiftCountForValue}
     * @param shiftCountForValue level at which the value for {@code index} is stored
     * @param index              internal index to assign
     * @param value              value to store
     * @return the resulting node
     */
    @Nonnull
    private TrieArrayNode<T> assignImpl(int shiftCount, int shiftCountForValue, int index, T value) {
        final int thisShiftCount = this.shiftCount;
        final int baseIndex = this.baseIndex;
        assert (IntArrayMappedTrieMath.baseIndexAtShift(shiftCount, index) == IntArrayMappedTrieMath.baseIndexAtShift(shiftCount, baseIndex));
        assert (shiftCount >= thisShiftCount);
        assert (shiftCount >= shiftCountForValue);
        if (shiftCount != thisShiftCount) {
            final int ancestorShiftCount = TrieArrayNode.findCommonAncestorShift(baseIndex + IntArrayMappedTrieMath.shift(thisShiftCount, 1), index);
            assert (ancestorShiftCount <= shiftCount);
            if (ancestorShiftCount > thisShiftCount) {
                // This node cannot hold the index; wrap it in an ancestor that covers both and assign there.
                final TrieArrayNode<T> ancestor = TrieArrayNode.forNode(ancestorShiftCount, baseIndex, this);
                return ancestor.assignImpl(ancestorShiftCount, shiftCountForValue, index, value);
            }
            shiftCount = thisShiftCount;
        }
        assert (IntArrayMappedTrieMath.baseIndexAtShift(shiftCount, index) == baseIndex);
        final int myIndex = IntArrayMappedTrieMath.indexAtShift(shiftCount, index);
        final long bit = BitmaskMath.bitFromIndex(myIndex);
        final long valuesBitmask = this.valuesBitmask;
        final long nodesBitmask = this.nodesBitmask;
        if (shiftCount == shiftCountForValue) {
            // The value belongs directly in this node.
            final T[] values = this.values;
            final long newBitmask = BitmaskMath.addBit(valuesBitmask, bit);
            final int arrayIndex = BitmaskMath.arrayIndexForBit(valuesBitmask, bit);
            if (BitmaskMath.bitIsPresent(valuesBitmask, bit)) {
                // Replace an existing value; size is unchanged.
                final T[] newValues = ArrayHelper.assign(values, arrayIndex, value);
                return new TrieArrayNode<>(shiftCount, baseIndex, newBitmask, newValues, nodesBitmask, this.nodes, this.size);
            }
            final T[] newValues = ArrayHelper.insert(TrieArrayNode::allocateValues, values, arrayIndex, value);
            return new TrieArrayNode<>(shiftCount, baseIndex, newBitmask, newValues, nodesBitmask, this.nodes, this.size + 1);
        }
        final int arrayIndex = BitmaskMath.arrayIndexForBit(nodesBitmask, bit);
        if (BitmaskMath.bitIsPresent(nodesBitmask, bit)) {
            // A child already occupies the slot; assign inside it and swap in the result.
            final TrieArrayNode<T> node = this.nodes[arrayIndex];
            final TrieArrayNode<T> newNode = node.assignImpl(shiftCount - 1, shiftCountForValue, index, value);
            final TrieArrayNode<T>[] newNodes = ArrayHelper.assign(this.nodes, arrayIndex, newNode);
            final int newSize = this.size - node.size() + newNode.size();
            return new TrieArrayNode<>(shiftCount, baseIndex, valuesBitmask, this.values, nodesBitmask, newNodes, newSize);
        }
        final long newBitmask = BitmaskMath.addBit(nodesBitmask, bit);
        final TrieArrayNode<T> newNode = TrieArrayNode.forValue(shiftCountForValue, index, value);
        if (valuesBitmask == 0L && nodesBitmask == 0L) {
            // This node holds nothing, so the new single-value node replaces it outright.
            return newNode;
        }
        final TrieArrayNode<T>[] newNodes = ArrayHelper.insert(TrieArrayNode::allocateNodes, this.nodes, arrayIndex, newNode);
        return new TrieArrayNode<>(shiftCount, baseIndex, valuesBitmask, this.values, newBitmask, newNodes, this.size + 1);
    }

    /**
     * Recursive assignment of a key/value pair into the mapping stored at an internal
     * (already flipped) index, using {@code mapper} to create or update that mapping.
     *
     * <p>When this node sits at a lower level than {@code shiftCount} and the common ancestor
     * level of this node and {@code index} is above this node's level, the node is wrapped in
     * a new ancestor and the assignment continues there. Returns {@code this} when the mapper
     * hands back the same mapping instance (nothing changed).
     *
     * @param shiftCount         level the caller expects at this position; asserted to be at least
     *                           this node's level and at least {@code shiftCountForValue}
     * @param shiftCountForValue level at which the mapping for {@code index} is stored
     * @param index              internal index derived from the key's hash code
     * @param mapper             strategy for creating, updating and sizing stored mappings
     * @param key                key to assign
     * @param value              value to associate with the key
     * @return the resulting node
     */
    @Nonnull
    private <K, V> TrieArrayNode<T> mappedAssignImpl(int shiftCount, int shiftCountForValue, int index, @Nonnull ArrayAssignMapper<K, V, T> mapper, @Nonnull K key, V value) {
        final int thisShiftCount = this.shiftCount;
        final int baseIndex = this.baseIndex;
        assert (IntArrayMappedTrieMath.baseIndexAtShift(shiftCount, index) == IntArrayMappedTrieMath.baseIndexAtShift(shiftCount, baseIndex));
        assert (shiftCount >= thisShiftCount);
        assert (shiftCount >= shiftCountForValue);
        if (shiftCount != thisShiftCount) {
            final int ancestorShiftCount = TrieArrayNode.findCommonAncestorShift(baseIndex + IntArrayMappedTrieMath.shift(thisShiftCount, 1), index);
            assert (ancestorShiftCount <= shiftCount);
            if (ancestorShiftCount > thisShiftCount) {
                // This node cannot hold the index; wrap it in an ancestor that covers both and assign there.
                final TrieArrayNode<T> ancestor = TrieArrayNode.forNode(ancestorShiftCount, baseIndex, this);
                return ancestor.mappedAssignImpl(ancestorShiftCount, shiftCountForValue, index, mapper, key, value);
            }
            shiftCount = thisShiftCount;
        }
        assert (IntArrayMappedTrieMath.baseIndexAtShift(shiftCount, index) == baseIndex);
        final int myIndex = IntArrayMappedTrieMath.indexAtShift(shiftCount, index);
        final long bit = BitmaskMath.bitFromIndex(myIndex);
        final long valuesBitmask = this.valuesBitmask;
        final long nodesBitmask = this.nodesBitmask;
        final T[] values = this.values;
        final TrieArrayNode<T>[] nodes = this.nodes;
        if (shiftCount == shiftCountForValue) {
            // The mapping belongs directly in this node.
            final long newBitmask = BitmaskMath.addBit(valuesBitmask, bit);
            final int arrayIndex = BitmaskMath.arrayIndexForBit(valuesBitmask, bit);
            if (BitmaskMath.bitIsPresent(valuesBitmask, bit)) {
                final T oldValue = values[arrayIndex];
                final T newValue = mapper.mappedAssign(oldValue, key, value);
                // Identity check: the mapper returns the same instance when nothing changed.
                if (newValue == oldValue) {
                    return this;
                }
                final int newSize = this.size - mapper.mappedSize(oldValue) + mapper.mappedSize(newValue);
                assert (newSize == this.size || newSize == this.size + 1);
                final T[] newValues = ArrayHelper.assign(values, arrayIndex, newValue);
                assert (newSize == TrieArrayNode.computeSize(mapper, nodes, newValues));
                return new TrieArrayNode<>(shiftCount, baseIndex, newBitmask, newValues, nodesBitmask, nodes, newSize);
            }
            final T newValue = mapper.mappedAssign(key, value);
            assert (mapper.mappedSize(newValue) == 1);
            final T[] newValues = ArrayHelper.insert(TrieArrayNode::allocateValues, values, arrayIndex, newValue);
            assert (this.size + 1 == TrieArrayNode.computeSize(mapper, nodes, newValues));
            return new TrieArrayNode<>(shiftCount, baseIndex, newBitmask, newValues, nodesBitmask, nodes, this.size + 1);
        }
        final int arrayIndex = BitmaskMath.arrayIndexForBit(nodesBitmask, bit);
        if (BitmaskMath.bitIsPresent(nodesBitmask, bit)) {
            // A child already occupies the slot; assign inside it and swap in the result.
            final TrieArrayNode<T> node = nodes[arrayIndex];
            final TrieArrayNode<T> newNode = node.mappedAssignImpl(shiftCount - 1, shiftCountForValue, index, mapper, key, value);
            if (newNode == node) {
                return this;
            }
            final TrieArrayNode<T>[] newNodes = ArrayHelper.assign(nodes, arrayIndex, newNode);
            final int newSize = this.size - node.size() + newNode.size();
            assert (newSize == TrieArrayNode.computeSize(mapper, newNodes, values));
            return new TrieArrayNode<>(shiftCount, baseIndex, valuesBitmask, values, nodesBitmask, newNodes, newSize);
        }
        final long newBitmask = BitmaskMath.addBit(nodesBitmask, bit);
        final TrieArrayNode<T> newNode = TrieArrayNode.forValue(shiftCountForValue, index, mapper.mappedAssign(key, value));
        if (valuesBitmask == 0L && nodesBitmask == 0L) {
            // This node holds nothing, so the new single-mapping node replaces it outright.
            return newNode;
        }
        final TrieArrayNode<T>[] newNodes = ArrayHelper.insert(TrieArrayNode::allocateNodes, nodes, arrayIndex, newNode);
        assert (this.size + 1 == TrieArrayNode.computeSize(mapper, newNodes, values));
        return new TrieArrayNode<>(shiftCount, baseIndex, valuesBitmask, values, newBitmask, newNodes, this.size + 1);
    }

    /**
     * Recursive update of the value for {@code key} in the mapping stored at an internal
     * (already flipped) index. When no mapping exists yet, {@code generator} is applied to an
     * empty holder and the result is stored as a new mapping.
     *
     * <p>When this node sits at a lower level than {@code shiftCount} and the common ancestor
     * level of this node and {@code index} is above this node's level, the node is wrapped in
     * a new ancestor and the update continues there. Returns {@code this} when the mapper
     * hands back the same mapping instance (nothing changed).
     *
     * @param shiftCount         level the caller expects at this position; asserted to be at least
     *                           this node's level and at least {@code shiftCountForValue}
     * @param shiftCountForValue level at which the mapping for {@code index} is stored
     * @param index              internal index derived from the key's hash code
     * @param mapper             strategy for creating, updating and sizing stored mappings
     * @param key                key to update
     * @param generator          function that produces the new value
     * @return the resulting node
     */
    @Nonnull
    private <K, V> TrieArrayNode<T> mappedUpdateImpl(int shiftCount, int shiftCountForValue, int index, @Nonnull ArrayUpdateMapper<K, V, T> mapper, @Nonnull K key, @Nonnull Func1<Holder<V>, V> generator) {
        final int thisShiftCount = this.shiftCount;
        final int baseIndex = this.baseIndex;
        assert (IntArrayMappedTrieMath.baseIndexAtShift(shiftCount, index) == IntArrayMappedTrieMath.baseIndexAtShift(shiftCount, baseIndex));
        assert (shiftCount >= thisShiftCount);
        assert (shiftCount >= shiftCountForValue);
        if (shiftCount != thisShiftCount) {
            final int ancestorShiftCount = TrieArrayNode.findCommonAncestorShift(baseIndex + IntArrayMappedTrieMath.shift(thisShiftCount, 1), index);
            assert (ancestorShiftCount <= shiftCount);
            if (ancestorShiftCount > thisShiftCount) {
                // This node cannot hold the index; wrap it in an ancestor that covers both and update there.
                final TrieArrayNode<T> ancestor = TrieArrayNode.forNode(ancestorShiftCount, baseIndex, this);
                return ancestor.mappedUpdateImpl(ancestorShiftCount, shiftCountForValue, index, mapper, key, generator);
            }
            shiftCount = thisShiftCount;
        }
        assert (IntArrayMappedTrieMath.baseIndexAtShift(shiftCount, index) == baseIndex);
        final int myIndex = IntArrayMappedTrieMath.indexAtShift(shiftCount, index);
        final long bit = BitmaskMath.bitFromIndex(myIndex);
        final long valuesBitmask = this.valuesBitmask;
        final long nodesBitmask = this.nodesBitmask;
        final T[] values = this.values;
        final TrieArrayNode<T>[] nodes = this.nodes;
        if (shiftCount == shiftCountForValue) {
            // The mapping belongs directly in this node.
            final long newBitmask = BitmaskMath.addBit(valuesBitmask, bit);
            final int arrayIndex = BitmaskMath.arrayIndexForBit(valuesBitmask, bit);
            if (BitmaskMath.bitIsPresent(valuesBitmask, bit)) {
                final T oldValue = values[arrayIndex];
                final T newValue = mapper.mappedUpdate(oldValue, key, generator);
                // Identity check: the mapper returns the same instance when nothing changed.
                if (newValue == oldValue) {
                    return this;
                }
                final int newSize = this.size - mapper.mappedSize(oldValue) + mapper.mappedSize(newValue);
                assert (newSize == this.size || newSize == this.size + 1);
                final T[] newValues = ArrayHelper.assign(values, arrayIndex, newValue);
                assert (newSize == TrieArrayNode.computeSize(mapper, nodes, newValues));
                return new TrieArrayNode<>(shiftCount, baseIndex, newBitmask, newValues, nodesBitmask, nodes, newSize);
            }
            // No mapping yet: generate the value from an empty holder and store a fresh mapping.
            final T newValue = mapper.mappedAssign(key, generator.apply(Holders.of()));
            assert (mapper.mappedSize(newValue) == 1);
            final T[] newValues = ArrayHelper.insert(TrieArrayNode::allocateValues, values, arrayIndex, newValue);
            assert (this.size + 1 == TrieArrayNode.computeSize(mapper, nodes, newValues));
            return new TrieArrayNode<>(shiftCount, baseIndex, newBitmask, newValues, nodesBitmask, nodes, this.size + 1);
        }
        final int arrayIndex = BitmaskMath.arrayIndexForBit(nodesBitmask, bit);
        if (BitmaskMath.bitIsPresent(nodesBitmask, bit)) {
            // A child already occupies the slot; update inside it and swap in the result.
            final TrieArrayNode<T> node = nodes[arrayIndex];
            final TrieArrayNode<T> newNode = node.mappedUpdateImpl(shiftCount - 1, shiftCountForValue, index, mapper, key, generator);
            if (newNode == node) {
                return this;
            }
            final TrieArrayNode<T>[] newNodes = ArrayHelper.assign(nodes, arrayIndex, newNode);
            final int newSize = this.size - node.size() + newNode.size();
            assert (newSize == TrieArrayNode.computeSize(mapper, newNodes, values));
            return new TrieArrayNode<>(shiftCount, baseIndex, valuesBitmask, values, nodesBitmask, newNodes, newSize);
        }
        final long newBitmask = BitmaskMath.addBit(nodesBitmask, bit);
        final V value = generator.apply(Holders.of());
        final TrieArrayNode<T> newNode = TrieArrayNode.forValue(shiftCountForValue, index, mapper.mappedAssign(key, value));
        if (valuesBitmask == 0L && nodesBitmask == 0L) {
            // This node holds nothing, so the new single-mapping node replaces it outright.
            return newNode;
        }
        final TrieArrayNode<T>[] newNodes = ArrayHelper.insert(TrieArrayNode::allocateNodes, nodes, arrayIndex, newNode);
        assert (this.size + 1 == TrieArrayNode.computeSize(mapper, newNodes, values));
        return new TrieArrayNode<>(shiftCount, baseIndex, valuesBitmask, values, newBitmask, newNodes, this.size + 1);
    }

    /**
     * Recursive removal of the value stored at an internal (already flipped) index.
     *
     * @param shiftCountForValue level at which a value for {@code index} would be stored
     * @param index              internal index to remove
     * @return {@code this} when nothing is stored at {@code index}; the shared empty node when the
     *         last element was removed; otherwise a node without the value
     */
    @Nonnull
    private TrieArrayNode<T> deleteImpl(int shiftCountForValue, int index) {
        final int shiftCount = this.shiftCount;
        // A value that lives above this node's level cannot be inside this subtree.
        if (shiftCountForValue > shiftCount) {
            return this;
        }
        // The index must share this node's base index to fall under it.
        if (IntArrayMappedTrieMath.baseIndexAtShift(shiftCount, index) != this.baseIndex) {
            return this;
        }
        final int myIndex = IntArrayMappedTrieMath.indexAtShift(shiftCount, index);
        final long bit = BitmaskMath.bitFromIndex(myIndex);
        final long valuesBitmask = this.valuesBitmask;
        final long nodesBitmask = this.nodesBitmask;
        final T[] values = this.values;
        final TrieArrayNode<T>[] nodes = this.nodes;
        if (shiftCountForValue == shiftCount) {
            if (BitmaskMath.bitIsPresent(valuesBitmask, bit)) {
                if (this.size == 1) {
                    return TrieArrayNode.empty();
                }
                final long newBitmask = BitmaskMath.removeBit(valuesBitmask, bit);
                final int arrayIndex = BitmaskMath.arrayIndexForBit(valuesBitmask, bit);
                final T[] newValues = ArrayHelper.delete(TrieArrayNode::allocateValues, values, arrayIndex);
                return new TrieArrayNode<>(shiftCount, this.baseIndex, newBitmask, newValues, nodesBitmask, nodes, this.size - 1);
            }
        } else if (BitmaskMath.bitIsPresent(nodesBitmask, bit)) {
            final int arrayIndex = BitmaskMath.arrayIndexForBit(nodesBitmask, bit);
            final TrieArrayNode<T> node = nodes[arrayIndex];
            final TrieArrayNode<T> newNode = node.deleteImpl(shiftCountForValue, index);
            // Identity check: the child returns itself when it did not contain the index.
            if (newNode != node) {
                final int newSize = this.size - node.size() + newNode.size();
                if (newSize == 0) {
                    return TrieArrayNode.empty();
                }
                if (newNode.isEmpty()) {
                    final long newBitmask = BitmaskMath.removeBit(nodesBitmask, bit);
                    if (valuesBitmask == 0L && BitmaskMath.bitCount(newBitmask) == 1) {
                        // Only one child and no values would remain: return that child instead of a wrapper.
                        return nodes[BitmaskMath.arrayIndexForBit(nodesBitmask, newBitmask)];
                    }
                    final TrieArrayNode<T>[] newNodes = ArrayHelper.delete(TrieArrayNode::allocateNodes, nodes, arrayIndex);
                    return new TrieArrayNode<>(shiftCount, this.baseIndex, valuesBitmask, values, newBitmask, newNodes, newSize);
                }
                final TrieArrayNode<T>[] newNodes = ArrayHelper.assign(nodes, arrayIndex, newNode);
                return new TrieArrayNode<>(shiftCount, this.baseIndex, valuesBitmask, values, nodesBitmask, newNodes, newSize);
            }
        }
        return this;
    }

    /**
     * Recursive removal of {@code key} from the mapping stored at an internal (already flipped)
     * index, using {@code mapper} to edit that mapping.
     *
     * <p>The mapper's result is interpreted as follows: the same instance means nothing changed,
     * {@code null} means the mapping became empty and its slot is removed, and any other value
     * replaces the stored mapping.
     *
     * @param shiftCountForValue level at which the mapping for {@code index} is stored
     * @param index              internal index derived from the key's hash code
     * @param mapper             strategy for removing a key from a stored mapping
     * @param key                key to remove
     * @return {@code this} when nothing changed; the shared empty node when the last element was
     *         removed; otherwise a node without the key
     */
    @Nonnull
    private <K> TrieArrayNode<T> mappedDeleteImpl(int shiftCountForValue, int index, @Nonnull ArrayDeleteMapper<K, T> mapper, @Nonnull K key) {
        final int shiftCount = this.shiftCount;
        // A mapping that lives above this node's level cannot be inside this subtree.
        if (shiftCountForValue > shiftCount) {
            return this;
        }
        // The index must share this node's base index to fall under it.
        if (IntArrayMappedTrieMath.baseIndexAtShift(shiftCount, index) != this.baseIndex) {
            return this;
        }
        final int myIndex = IntArrayMappedTrieMath.indexAtShift(shiftCount, index);
        final long bit = BitmaskMath.bitFromIndex(myIndex);
        final long valuesBitmask = this.valuesBitmask;
        final T[] values = this.values;
        if (shiftCountForValue == shiftCount) {
            if (BitmaskMath.bitIsPresent(valuesBitmask, bit)) {
                final int arrayIndex = BitmaskMath.arrayIndexForBit(valuesBitmask, bit);
                final T mapping = values[arrayIndex];
                final T newMapping = mapper.mappedDelete(mapping, key);
                // Identity check: the mapper returns the same instance when the key was not present.
                if (newMapping != mapping) {
                    final long newBitmask;
                    final T[] newValues;
                    if (newMapping == null) {
                        // The mapping became empty, so its slot disappears.
                        if (this.size == 1) {
                            return TrieArrayNode.empty();
                        }
                        newBitmask = BitmaskMath.removeBit(valuesBitmask, bit);
                        newValues = ArrayHelper.delete(TrieArrayNode::allocateValues, values, arrayIndex);
                    } else {
                        newBitmask = valuesBitmask;
                        newValues = ArrayHelper.assign(values, arrayIndex, newMapping);
                    }
                    assert (this.size - 1 == TrieArrayNode.computeSize(mapper, this.nodes, newValues));
                    return new TrieArrayNode<>(shiftCount, this.baseIndex, newBitmask, newValues, this.nodesBitmask, this.nodes, this.size - 1);
                }
            }
        } else {
            final long bitmask = this.nodesBitmask;
            if (BitmaskMath.bitIsPresent(bitmask, bit)) {
                final TrieArrayNode<T>[] nodes = this.nodes;
                final int arrayIndex = BitmaskMath.arrayIndexForBit(bitmask, bit);
                final TrieArrayNode<T> node = nodes[arrayIndex];
                final TrieArrayNode<T> newNode = node.mappedDeleteImpl(shiftCountForValue, index, mapper, key);
                // Identity check: the child returns itself when it did not contain the key.
                if (newNode != node) {
                    final int newSize = this.size - node.size() + newNode.size();
                    if (newSize == 0) {
                        return TrieArrayNode.empty();
                    }
                    if (newNode.isEmpty()) {
                        final long newBitmask = BitmaskMath.removeBit(bitmask, bit);
                        if (valuesBitmask == 0L && BitmaskMath.bitCount(newBitmask) == 1) {
                            // Only one child and no values would remain: return that child instead of a wrapper.
                            return nodes[BitmaskMath.arrayIndexForBit(bitmask, newBitmask)];
                        }
                        final TrieArrayNode<T>[] newNodes = ArrayHelper.delete(TrieArrayNode::allocateNodes, nodes, arrayIndex);
                        assert (newSize == TrieArrayNode.computeSize(mapper, newNodes, values));
                        return new TrieArrayNode<>(shiftCount, this.baseIndex, valuesBitmask, values, newBitmask, newNodes, newSize);
                    }
                    final TrieArrayNode<T>[] newNodes = ArrayHelper.assign(nodes, arrayIndex, newNode);
                    assert (newSize == TrieArrayNode.computeSize(mapper, newNodes, values));
                    return new TrieArrayNode<>(shiftCount, this.baseIndex, valuesBitmask, values, bitmask, newNodes, newSize);
                }
            }
        }
        return this;
    }

    /**
     * Verifies this node's internal consistency: both bitmasks match their array lengths,
     * the child nodes pass {@code checkChildShifts}, and the cached size equals a freshly
     * computed size.
     *
     * @param mapper when non-null, used to size each stored value; when null each stored
     *               value counts as one element
     * @throws IllegalStateException when any check fails
     */
    public void checkInvariants(@Nullable ArraySizeMapper<T> mapper) {
        if (BitmaskMath.bitCount(valuesBitmask) != values.length) {
            throw new IllegalStateException(String.format("invalid bitmask for values array: bitmask=%s length=%d", Long.toBinaryString(valuesBitmask), values.length));
        }
        if (BitmaskMath.bitCount(nodesBitmask) != nodes.length) {
            throw new IllegalStateException(String.format("invalid bitmask for nodes array: bitmask=%s length=%d", Long.toBinaryString(nodesBitmask), nodes.length));
        }
        if (!TrieArrayNode.checkChildShifts(shiftCount, nodes)) {
            throw new IllegalStateException("one or more nodes invalid for this branch");
        }
        final int computedSize;
        if (mapper == null) {
            computedSize = TrieArrayNode.computeSize(nodes) + values.length;
        } else {
            computedSize = TrieArrayNode.computeSize(mapper, nodes, values);
        }
        if (computedSize != size) {
            throw new IllegalStateException(String.format("size mismatch: size=%d computed=%d", size, computedSize));
        }
    }

    /**
     * Iterates the indexes of everything reachable from this node.
     * Values held directly by this node are reported through
     * {@code computeUserIndexForValue}; each child node contributes its own {@code keys()}.
     *
     * @return iterable over the indexes of this node's values and its children's values
     */
    @Nonnull
    public GenericIterator.Iterable<Integer> keys() {
        final IntFunc2<Integer> ownKey = (slot, position) -> this.computeUserIndexForValue(slot);
        final IntFunction<GenericIterator.Iterable<Integer>> childKeys = child -> this.nodes[child].keys();
        return this.iterable(ownKey, childKeys);
    }

    /**
     * Iterates the values reachable from this node.
     * Values held directly by this node are read from the values array; each child
     * node contributes its own {@code values()}.
     *
     * @return iterable over this node's values and its children's values
     */
    @Nonnull
    public GenericIterator.Iterable<T> values() {
        final IntFunc2<T> ownValue = (slot, position) -> this.values[position];
        final IntFunction<GenericIterator.Iterable<T>> childValues = child -> this.nodes[child].values();
        return this.iterable(ownValue, childValues);
    }

    /**
     * Iterates index/value pairs for everything reachable from this node.
     * For a value held directly by this node the key comes from
     * {@code computeUserIndexForValue} and the value from the values array; each
     * child node contributes its own {@code entries()}.
     *
     * @return iterable over entries of this node and its children
     */
    @Nonnull
    public GenericIterator.Iterable<JImmutableMap.Entry<Integer, T>> entries() {
        final IntFunc2<JImmutableMap.Entry<Integer, T>> ownEntry =
            (slot, position) -> MapEntry.entry(this.computeUserIndexForValue(slot), this.values[position]);
        final IntFunction<GenericIterator.Iterable<JImmutableMap.Entry<Integer, T>>> childEntries =
            child -> this.nodes[child].entries();
        return this.iterable(ownEntry, childEntries);
    }

    /**
     * Iterates the keys that {@code mapper} derives from each stored value.
     *
     * @param mapper supplies, via {@code mappedKeys}, an iterable of keys for one stored value
     * @param <K>    key type produced by the mapper
     * @return iterable concatenating the mapper's keys for every value under this node
     */
    @Nonnull
    public <K> GenericIterator.Iterable<K> mappedKeys(@Nonnull ArrayIterationMapper<K, ?, T> mapper) {
        // Bound method reference (not a lambda) so mapper is dereferenced right here.
        final Func1<T, GenericIterator.Iterable<K>> keysOfValue = mapper::mappedKeys;
        return this.mappedIterable(keysOfValue);
    }

    /**
     * Iterates the values that {@code mapper} derives from each stored value.
     *
     * @param mapper supplies, via {@code mappedValues}, an iterable of values for one stored value
     * @param <V>    value type produced by the mapper
     * @return iterable concatenating the mapper's values for every value under this node
     */
    @Nonnull
    public <V> GenericIterator.Iterable<V> mappedValues(@Nonnull ArrayIterationMapper<?, V, T> mapper) {
        // Bound method reference (not a lambda) so mapper is dereferenced right here.
        final Func1<T, GenericIterator.Iterable<V>> valuesOfValue = mapper::mappedValues;
        return this.mappedIterable(valuesOfValue);
    }

    /**
     * Iterates the entries that {@code mapper} derives from each stored value.
     *
     * @param mapper supplies, via {@code mappedEntries}, an iterable of entries for one stored value
     * @param <K>    key type produced by the mapper
     * @param <V>    value type produced by the mapper
     * @return iterable concatenating the mapper's entries for every value under this node
     */
    @Nonnull
    public <K, V> GenericIterator.Iterable<JImmutableMap.Entry<K, V>> mappedEntries(@Nonnull ArrayIterationMapper<K, V, T> mapper) {
        // Bound method reference (not a lambda) so mapper is dereferenced right here.
        final Func1<T, GenericIterator.Iterable<JImmutableMap.Entry<K, V>>> entriesOfValue = mapper::mappedEntries;
        return this.mappedIterable(entriesOfValue);
    }

    /**
     * Returns the size recorded in this node (the same field that
     * {@code checkInvariants} compares against a recomputed total).
     *
     * @return the stored size
     */
    public int size() {
        return size;
    }

    /**
     * Converts a slot index inside this node into the index reported by
     * {@code keys()} and {@code entries()}: the slot is shifted by this node's
     * {@code shiftCount}, added to {@code baseIndex}, and the sum is passed through
     * {@code flip}.
     *
     * @param valueIndex slot index within this node (primitive, so callers passing an
     *                   {@code int} no longer box it just to have it unboxed here)
     * @return the flipped absolute index for that slot
     */
    private int computeUserIndexForValue(int valueIndex) {
        return TrieArrayNode.flip(this.baseIndex + IntArrayMappedTrieMath.shift(this.shiftCount, valueIndex));
    }

    /**
     * Builds an iterable over this node's contents, one slot at a time.
     * <p>
     * The slots are taken from the combination of the values and nodes bitmasks and
     * visited in the order {@code BitmaskMath.leastBit} yields them. For a slot that
     * has a value, {@code valueFunction} produces a single element; for a slot that
     * has a child node, {@code nodeFunction} produces that child's iterable. When a
     * slot has both, the value comes before the child.
     *
     * @param valueFunction receives (slot index, position in the values array) and
     *                      returns the element to emit for that value
     * @param nodeFunction  receives the position in the nodes array and returns the
     *                      iterable to emit for that child
     * @param <V>           element type of the resulting iterable
     * @return iterable whose reported size is this node's stored size
     */
    private <V> GenericIterator.Iterable<V> iterable(final @Nonnull IntFunc2<V> valueFunction, final @Nonnull IntFunction<GenericIterator.Iterable<V>> nodeFunction) {
        return new GenericIterator.Iterable<V>() {

            @Override
            @Nullable
            public GenericIterator.State<V> iterateOverRange(@Nullable GenericIterator.State<V> parent, int offset, int limit) {
                // Typed list: every element is an iterable of V, which is what multiIterableState consumes.
                final ArrayList<GenericIterator.Iterable<V>> iterables =
                    new ArrayList<>(TrieArrayNode.this.values.length + TrieArrayNode.this.nodes.length);
                long combinedBitmask = BitmaskMath.addBit(TrieArrayNode.this.valuesBitmask, TrieArrayNode.this.nodesBitmask);
                while (combinedBitmask != 0L) {
                    final long bit = BitmaskMath.leastBit(combinedBitmask);
                    if (BitmaskMath.bitIsPresent(TrieArrayNode.this.valuesBitmask, bit)) {
                        final int valueIndex = BitmaskMath.indexForBit(bit);
                        final int arrayIndex = BitmaskMath.arrayIndexForBit(TrieArrayNode.this.valuesBitmask, bit);
                        iterables.add(GenericIterator.singleValueIterable(valueFunction.apply(valueIndex, arrayIndex)));
                    }
                    if (BitmaskMath.bitIsPresent(TrieArrayNode.this.nodesBitmask, bit)) {
                        final int nodeIndex = BitmaskMath.arrayIndexForBit(TrieArrayNode.this.nodesBitmask, bit);
                        iterables.add(nodeFunction.apply(nodeIndex));
                    }
                    combinedBitmask = BitmaskMath.removeBit(combinedBitmask, bit);
                }
                // One iterable per stored value plus one per child node.
                assert (iterables.size() == TrieArrayNode.this.values.length + TrieArrayNode.this.nodes.length);
                return GenericIterator.multiIterableState(parent, IndexedList.retained(iterables), offset, limit);
            }

            @Override
            public int iterableSize() {
                return TrieArrayNode.this.size;
            }
        };
    }

    /**
     * Builds an iterable in which every stored value is expanded by
     * {@code valueFunction} into its own iterable, and every child node contributes
     * the result of calling this same method on it with the same function.
     * <p>
     * Slots are taken from the combination of the values and nodes bitmasks and
     * visited in the order {@code BitmaskMath.leastBit} yields them; when a slot has
     * both a value and a child node, the value's iterable comes first.
     *
     * @param valueFunction turns one stored value into the iterable to emit for it
     * @param <V>           element type of the resulting iterable
     * @return iterable whose reported size is this node's stored size
     */
    @Nonnull
    private <V> GenericIterator.Iterable<V> mappedIterable(final @Nonnull Func1<T, GenericIterator.Iterable<V>> valueFunction) {
        return new GenericIterator.Iterable<V>() {

            @Override
            @Nullable
            public GenericIterator.State<V> iterateOverRange(@Nullable GenericIterator.State<V> parent, int offset, int limit) {
                // Typed list: every element is an iterable of V, which is what multiIterableState consumes.
                final ArrayList<GenericIterator.Iterable<V>> iterables =
                    new ArrayList<>(TrieArrayNode.this.values.length + TrieArrayNode.this.nodes.length);
                long combinedBitmask = BitmaskMath.addBit(TrieArrayNode.this.valuesBitmask, TrieArrayNode.this.nodesBitmask);
                while (combinedBitmask != 0L) {
                    final long bit = BitmaskMath.leastBit(combinedBitmask);
                    if (BitmaskMath.bitIsPresent(TrieArrayNode.this.valuesBitmask, bit)) {
                        final int arrayIndex = BitmaskMath.arrayIndexForBit(TrieArrayNode.this.valuesBitmask, bit);
                        iterables.add(valueFunction.apply(TrieArrayNode.this.values[arrayIndex]));
                    }
                    if (BitmaskMath.bitIsPresent(TrieArrayNode.this.nodesBitmask, bit)) {
                        final int nodeIndex = BitmaskMath.arrayIndexForBit(TrieArrayNode.this.nodesBitmask, bit);
                        iterables.add(TrieArrayNode.this.nodes[nodeIndex].mappedIterable(valueFunction));
                    }
                    combinedBitmask = BitmaskMath.removeBit(combinedBitmask, bit);
                }
                // One iterable per stored value plus one per child node.
                assert (iterables.size() == TrieArrayNode.this.values.length + TrieArrayNode.this.nodes.length);
                return GenericIterator.multiIterableState(parent, IndexedList.retained(iterables), offset, limit);
            }

            @Override
            public int iterableSize() {
                return TrieArrayNode.this.size;
            }
        };
    }

    /**
     * Allocates an array for stored values.
     *
     * @param size number of slots required
     * @param <T>  element type the caller treats the array as holding
     * @return the result of {@code emptyValues()} when {@code size} is zero,
     *         otherwise a new array of {@code size} null slots
     */
    @Nonnull
    @SuppressWarnings("unchecked")
    static <T> T[] allocateValues(int size) {
        if (size == 0) {
            return TrieArrayNode.emptyValues();
        }
        // Generic arrays cannot be created directly; the backing array is an Object[]
        // and the cast is unchecked, so the result must only be used as a T[] inside
        // generic code where T is erased.
        return (T[]) new Object[size];
    }

    /**
     * Allocates an array for child nodes.
     *
     * @param size number of slots required
     * @param <T>  value type of the nodes the array will hold
     * @return the result of {@code emptyNodes()} when {@code size} is zero,
     *         otherwise a new array of {@code size} null slots
     */
    @Nonnull
    @SuppressWarnings("unchecked")
    static <T> TrieArrayNode<T>[] allocateNodes(int size) {
        if (size == 0) {
            return TrieArrayNode.emptyNodes();
        }
        // Arrays of a parameterized type cannot be created directly; a raw array is
        // created and viewed as TrieArrayNode<T>[] via an unchecked cast.
        return (TrieArrayNode<T>[]) new TrieArrayNode[size];
    }

    /**
     * Returns the shared zero-length values array, typed for the caller.
     *
     * @param <T> element type the caller treats the array as holding
     * @return the {@code EMPTY_VALUES} constant viewed as a {@code T[]}
     */
    @Nonnull
    @SuppressWarnings("unchecked")
    private static <T> T[] emptyValues() {
        // Unchecked but harmless: a zero-length array holds no element that could
        // violate T.
        return (T[]) EMPTY_VALUES;
    }

    /**
     * Returns the shared zero-length nodes array, typed for the caller.
     *
     * @param <T> value type of the nodes the caller expects
     * @return the {@code EMPTY_NODES} constant viewed as a {@code TrieArrayNode<T>[]}
     */
    @Nonnull
    @SuppressWarnings("unchecked")
    private static <T> TrieArrayNode<T>[] emptyNodes() {
        // Unchecked but harmless: a zero-length array holds no node whose value type
        // could disagree with T.
        return (TrieArrayNode<T>[]) EMPTY_NODES;
    }

    /**
     * Returns the shift count for {@code index}, as computed by
     * {@code IntArrayMappedTrieMath.findMinimumShiftForZeroBelowHashCode}.
     *
     * @param index index to examine
     * @return the shift count reported by the helper
     */
    static int findShiftForIndex(int index) {
        final int shift = IntArrayMappedTrieMath.findMinimumShiftForZeroBelowHashCode(index);
        return shift;
    }

    /**
     * Finds the smallest shift count, starting from the larger of the two indexes'
     * own shift counts, at which both indexes have the same base index.
     *
     * @param index1 first index
     * @param index2 second index
     * @return the first shift count at which
     *         {@code IntArrayMappedTrieMath.baseIndexAtShift} agrees for both indexes
     */
    static int findCommonAncestorShift(int index1, int index2) {
        int candidate = Math.max(TrieArrayNode.findShiftForIndex(index1), TrieArrayNode.findShiftForIndex(index2));
        while (true) {
            final int base1 = IntArrayMappedTrieMath.baseIndexAtShift(candidate, index1);
            final int base2 = IntArrayMappedTrieMath.baseIndexAtShift(candidate, index2);
            if (base1 == base2) {
                break;
            }
            candidate++;
        }
        assert (candidate <= ROOT_SHIFT_COUNT);
        return candidate;
    }

    /**
     * Toggles the sign bit of {@code index} and leaves every other bit unchanged.
     * Applying it twice yields the original value.
     *
     * @param index value whose sign bit is toggled
     * @return {@code index} with its top bit inverted
     */
    static int flip(int index) {
        final int signBit = Integer.MIN_VALUE;
        return signBit ^ index;
    }

    /**
     * Checks that every child is acceptable under a parent with the given shift count:
     * the child's shift count must be strictly smaller than the parent's and the child
     * must not be empty.
     *
     * @param shiftCount shift count of the parent
     * @param nodes      the parent's children
     * @return {@code true} if all children pass (including when there are none),
     *         {@code false} as soon as one fails
     */
    private static <T> boolean checkChildShifts(int shiftCount, @Nonnull TrieArrayNode<T>[] nodes) {
        for (int i = 0; i < nodes.length; i++) {
            final TrieArrayNode<T> child = nodes[i];
            if (shiftCount <= child.shiftCount || child.isEmpty()) {
                return false;
            }
        }
        return true;
    }

    /**
     * Computes the total size of a node from its parts: the sum of every child's
     * {@code size()} plus the mapper-reported size of every directly stored value.
     *
     * @param mapper   reports, via {@code mappedSize}, how many elements one stored value represents
     * @param children child nodes to total
     * @param values   directly stored values to total
     * @return sum of child sizes and mapped value sizes
     */
    private static <K, T> int computeSize(@Nonnull ArraySizeMapper<T> mapper, @Nonnull TrieArrayNode<T>[] children, @Nonnull T[] values) {
        int total = 0;
        for (TrieArrayNode<T> child : children) {
            total += child.size();
        }
        // Elements of values are stored values of type T, not nodes.
        for (T value : values) {
            total += mapper.mappedSize(value);
        }
        return total;
    }

    /**
     * Adds up the {@code size()} of every child node.
     *
     * @param children child nodes to total
     * @return sum of the children's sizes; zero for an empty array
     */
    private static <T> int computeSize(@Nonnull TrieArrayNode<T>[] children) {
        int sum = 0;
        int position = 0;
        while (position < children.length) {
            sum += children[position].size();
            position++;
        }
        return sum;
    }
}

