/*
 * Decompiled with CFR 0.152.
 */
package hivemall.factorization.fm;

import hivemall.utils.buffer.HeapBuffer;
import hivemall.utils.lang.NumberUtils;
import hivemall.utils.lang.Preconditions;
import hivemall.utils.math.MathUtils;
import java.util.Arrays;
import javax.annotation.Nonnegative;
import javax.annotation.Nonnull;

/**
 * A lightweight view over one model parameter whose values live in a {@link HeapBuffer}.
 *
 * <p>The entry does not hold the values itself; it records where they start in the buffer
 * ({@link #_offset}), how many float values they comprise ({@link #_factors}) and how many bytes
 * the whole entry occupies ({@link #_size}). An entry whose key is negative is treated as a
 * linear weight {@code W} (see {@link #isEntryW(int)}); any other key is treated as a factor
 * vector {@code V}. Values are stored as consecutive 4-byte floats starting at {@link #_offset},
 * so {@link #getW()} reads the same slot as {@code getV(0)}.
 *
 * <p>The base class carries no optimizer state: its gradient / FTRL hooks throw
 * {@link UnsupportedOperationException}. {@link FTRLEntry} and {@link AdaGradEntry} append
 * per-factor optimizer state directly after the values.
 */
class Entry {
    /** Byte width of a {@code float} slot in the buffer. */
    private static final int FLOAT_BYTES = 4;
    /** Byte width of a {@code double} slot in the buffer. */
    private static final int DOUBLE_BYTES = 8;

    /** Backing storage for the values (and any optimizer state) of this entry. */
    @Nonnull
    protected final HeapBuffer _buf;
    /** Total number of bytes this entry occupies in {@link #_buf}. */
    @Nonnegative
    protected final int _size;
    /** Number of float values stored for this entry. */
    @Nonnegative
    protected final int _factors;
    /** Feature key; a negative key marks a W entry (see {@link #isEntryW(int)}). */
    protected int _key;
    /** Byte offset of the first value of this entry within {@link #_buf}. */
    @Nonnegative
    protected long _offset;

    /**
     * Creates an entry with the given number of factors, leaving key and offset at 0.
     *
     * <p>NOTE(review): presumably positioned later through {@link #setOffset(long)}; confirm
     * against callers.
     *
     * @throws IllegalArgumentException (via {@code Preconditions}) if {@code factors < 1}
     */
    Entry(@Nonnull HeapBuffer buf, int factors) {
        this._buf = buf;
        this._size = Entry.sizeOf(factors);
        this._factors = factors;
    }

    /** Creates a single-value entry (factors = 1) for {@code key} located at {@code offset}. */
    Entry(@Nonnull HeapBuffer buf, int key, @Nonnegative long offset) {
        this(buf, 1, key, offset);
    }

    /** Creates an entry of {@code factors} float values for {@code key} located at {@code offset}. */
    Entry(@Nonnull HeapBuffer buf, int factors, int key, @Nonnegative long offset) {
        this(buf, factors, Entry.sizeOf(factors), key, offset);
    }

    /** Common constructor; {@code size} lets subclasses reserve room for optimizer state. */
    private Entry(@Nonnull HeapBuffer buf, int factors, int size, int key, @Nonnegative long offset) {
        this._buf = buf;
        this._size = size;
        this._factors = factors;
        this._key = key;
        this._offset = offset;
    }

    /** @return total number of bytes this entry occupies in the buffer */
    final int getSize() {
        return _size;
    }

    /** @return the feature key of this entry */
    final int getKey() {
        return _key;
    }

    /** @return byte offset of this entry's first value within the buffer */
    final long getOffset() {
        return _offset;
    }

    /**
     * Moves this view to another position in the buffer.
     *
     * <p>NOTE(review): subclasses derive their optimizer-state offsets from the constructor-time
     * offset and do not observe this change.
     */
    final void setOffset(long offset) {
        this._offset = offset;
    }

    /** @return the value stored in the first float slot (the weight of a W entry) */
    final float getW() {
        return _buf.getFloat(_offset);
    }

    /** Stores {@code value} in the first float slot (the weight of a W entry). */
    final void setW(float value) {
        _buf.putFloat(_offset, value);
    }

    /**
     * Copies the first {@code Vf.length} float values of this entry into {@code Vf}.
     *
     * @param Vf destination array; its length determines how many values are read
     */
    final void getV(@Nonnull float[] Vf) {
        final long offset = _offset;
        for (int f = 0; f < Vf.length; f++) {
            Vf[f] = _buf.getFloat(offset + (long) FLOAT_BYTES * f);
        }
    }

    /**
     * Writes all of {@code Vf} into this entry's float slots, starting at slot 0.
     *
     * @param Vf source values; its length determines how many slots are written
     */
    final void setV(@Nonnull float[] Vf) {
        final long offset = _offset;
        for (int f = 0; f < Vf.length; f++) {
            _buf.putFloat(offset + (long) FLOAT_BYTES * f, Vf[f]);
        }
    }

    /** @return the value in float slot {@code f} (no bounds check against {@link #_factors}) */
    final float getV(int f) {
        return _buf.getFloat(_offset + (long) FLOAT_BYTES * f);
    }

    /** Stores {@code value} in float slot {@code f} (no bounds check against {@link #_factors}). */
    final void setV(int f, float value) {
        _buf.putFloat(_offset + (long) FLOAT_BYTES * f, value);
    }

    /**
     * Returns the accumulated squared gradients of factor {@code f}.
     *
     * @throws UnsupportedOperationException in the base class (overridden by {@link AdaGradEntry})
     */
    double getSumOfSquaredGradients(@Nonnegative int f) {
        throw new UnsupportedOperationException();
    }

    /**
     * Accumulates {@code grad * grad} for factor {@code f}.
     *
     * @throws UnsupportedOperationException in the base class (overridden by {@link AdaGradEntry})
     */
    void addGradient(@Nonnegative int f, float grad) {
        throw new UnsupportedOperationException();
    }

    /**
     * Applies {@link #updateZ(int, float, float, float)} to slot 0 using the current
     * {@link #getW()} value as the weight.
     */
    final float updateZ(float gradW, float alpha) {
        final float w = getW();
        return updateZ(0, w, gradW, alpha);
    }

    /**
     * Updates the FTRL {@code z} accumulator of factor {@code f} and returns its new value.
     *
     * @throws UnsupportedOperationException in the base class (overridden by {@link FTRLEntry})
     */
    float updateZ(@Nonnegative int f, float W, float gradW, float alpha) {
        throw new UnsupportedOperationException();
    }

    /** Applies {@link #updateN(int, float)} to slot 0. */
    final double updateN(float gradW) {
        return updateN(0, gradW);
    }

    /**
     * Updates the FTRL {@code n} accumulator of factor {@code f} and returns its new value.
     *
     * @throws UnsupportedOperationException in the base class (overridden by {@link FTRLEntry})
     */
    double updateN(@Nonnegative int f, float gradW) {
        throw new UnsupportedOperationException();
    }

    /**
     * Reports whether this entry can be dropped.
     *
     * <p>For a V entry this is {@code true} only when every one of its {@link #_factors} values
     * is close to zero as judged by {@code MathUtils.closeToZero} with tolerance {@code 1e-9}.
     *
     * <p>NOTE(review): W entries are always reported removable, whatever their value; confirm
     * callers handle W entries separately.
     */
    boolean removable() {
        if (isEntryW(_key)) {
            return true;
        }
        for (int f = 0; f < _factors; f++) {
            if (!MathUtils.closeToZero(getV(f), 1.0E-9f)) {
                return false;
            }
        }
        return true;
    }

    /** Resets optimizer state; the base class has none, so this is a no-op. */
    void clear() {
    }

    /**
     * Returns the number of bytes needed for {@code factors} float values.
     *
     * @throws IllegalArgumentException (via {@code Preconditions}) if {@code factors < 1}
     */
    static int sizeOf(@Nonnegative int factors) {
        Preconditions.checkArgument(factors >= 1, "Factors must be greater than 0: " + factors);
        return FLOAT_BYTES * factors;
    }

    /** @return {@code true} when {@code i} denotes a W entry key (i.e. it is negative) */
    static boolean isEntryW(int i) {
        return i < 0;
    }

    @Override
    public String toString() {
        if (isEntryW(_key)) {
            return "W=" + getW();
        }
        final float[] Vf = new float[_factors];
        getV(Vf);
        return "V=" + Arrays.toString(Vf);
    }

    /**
     * Entry with FTRL optimizer state.
     *
     * <p>Layout starting at {@link #_offset}: {@code _factors} value floats, then
     * {@code _factors} {@code z} floats (starting at {@link #_z_offset}), then {@code _factors}
     * {@code n} floats.
     */
    static final class FTRLEntry
    extends Entry {
        /**
         * Byte offset of the first {@code z} slot. Computed once in the constructor; it is NOT
         * updated by {@link #setOffset(long)}.
         */
        final long _z_offset;

        /** Creates a single-factor FTRL entry. */
        FTRLEntry(@Nonnull HeapBuffer buf, int key, long offset) {
            this(buf, 1, key, offset);
        }

        /** Creates an FTRL entry with {@code factors} values plus their z/n state. */
        FTRLEntry(@Nonnull HeapBuffer buf, @Nonnegative int factors, int key, long offset) {
            super(buf, factors, FTRLEntry.sizeOf(factors), key, offset);
            this._z_offset = this._offset + (long) Entry.sizeOf(factors);
        }

        /**
         * Computes {@code z + gradW - sigma * W}, where
         * {@code sigma = (sqrt(n + gradW^2) - sqrt(n)) / alpha}, stores it as the new {@code z}
         * of factor {@code f} and returns it. The stored {@code n} is only read, not updated.
         *
         * @throws IllegalArgumentException (via {@code Preconditions}) if {@code f < 0}
         * @throws IllegalStateException if the new {@code z} is not finite
         */
        @Override
        float updateZ(final int f, final float W, final float gradW, final float alpha) {
            Preconditions.checkArgument(f >= 0);
            final long zOffset = offsetZ(f);
            final float z = _buf.getFloat(zOffset);
            final double n = _buf.getFloat(offsetN(f));
            // Square in double: a float product rounds to float precision and overflows to
            // Infinity once |gradW| exceeds ~1.8e19.
            final double gg = (double) gradW * gradW;
            final float sigma = (float) ((Math.sqrt(n + gg) - Math.sqrt(n)) / (double) alpha);
            final float newZ = z + gradW - sigma * W;
            if (!NumberUtils.isFinite(newZ)) {
                throw new IllegalStateException("Got newZ " + newZ + " where z=" + z + ", gradW=" + gradW + ", sigma=" + sigma + ", W=" + W + ", n=" + n + ", gg=" + gg + ", alpha=" + alpha);
            }
            _buf.putFloat(zOffset, newZ);
            return newZ;
        }

        /**
         * Adds {@code gradW^2} to the {@code n} accumulator of factor {@code f}, stores it as a
         * float (via {@code NumberUtils.castToFloat}) and returns the double-precision sum.
         *
         * @throws IllegalArgumentException (via {@code Preconditions}) if {@code f < 0}
         * @throws IllegalStateException if the new {@code n} is not finite
         */
        @Override
        double updateN(final int f, final float gradW) {
            Preconditions.checkArgument(f >= 0);
            final long nOffset = offsetN(f);
            final double n = _buf.getFloat(nOffset);
            // Promote before multiplying so the square is not rounded/overflowed in float.
            final double newN = n + (double) gradW * gradW;
            if (!NumberUtils.isFinite(newN)) {
                throw new IllegalStateException("Got newN " + newN + " where n=" + n + ", gradW=" + gradW);
            }
            _buf.putFloat(nOffset, NumberUtils.castToFloat(newN));
            return newN;
        }

        /** @return byte offset of the {@code z} slot for factor {@code f} */
        private long offsetZ(@Nonnegative int f) {
            return _z_offset + (long) FLOAT_BYTES * f;
        }

        /** @return byte offset of the {@code n} slot for factor {@code f} (after all z slots) */
        private long offsetN(@Nonnegative int f) {
            return _z_offset + (long) FLOAT_BYTES * ((long) _factors + f);
        }

        /** Zeroes the z and n state of every factor; the values themselves are left untouched. */
        @Override
        void clear() {
            for (int f = 0; f < _factors; f++) {
                _buf.putFloat(offsetZ(f), 0.0f);
                _buf.putFloat(offsetN(f), 0.0f);
            }
        }

        /** @return bytes for {@code factors} values plus one z and one n float per factor */
        static int sizeOf(@Nonnegative int factors) {
            return Entry.sizeOf(factors) + 2 * FLOAT_BYTES * factors;
        }

        @Override
        public String toString() {
            final float[] Z = new float[_factors];
            final float[] N = new float[_factors];
            for (int f = 0; f < _factors; f++) {
                Z[f] = _buf.getFloat(offsetZ(f));
                N[f] = _buf.getFloat(offsetN(f));
            }
            return super.toString() + ", Z=" + Arrays.toString(Z) + ", N=" + Arrays.toString(N);
        }
    }

    /**
     * Entry with AdaGrad optimizer state.
     *
     * <p>Layout starting at {@link #_offset}: {@code _factors} value floats, then
     * {@code _factors} doubles (starting at {@link #_gg_offset}) holding each factor's sum of
     * squared gradients.
     */
    static final class AdaGradEntry
    extends Entry {
        /**
         * Byte offset of the first squared-gradient sum. Computed once in the constructor; it is
         * NOT updated by {@link #setOffset(long)}.
         */
        final long _gg_offset;

        /** Creates a single-factor AdaGrad entry. */
        AdaGradEntry(@Nonnull HeapBuffer buf, int key, @Nonnegative long offset) {
            this(buf, 1, key, offset);
        }

        /** Creates an AdaGrad entry with {@code factors} values plus their gradient sums. */
        AdaGradEntry(@Nonnull HeapBuffer buf, @Nonnegative int factors, int key, @Nonnegative long offset) {
            super(buf, factors, AdaGradEntry.sizeOf(factors), key, offset);
            this._gg_offset = this._offset + (long) Entry.sizeOf(factors);
        }

        /**
         * @return the accumulated squared gradients of factor {@code f}
         * @throws IllegalArgumentException (via {@code Preconditions}) if {@code f < 0}
         */
        @Override
        double getSumOfSquaredGradients(@Nonnegative int f) {
            Preconditions.checkArgument(f >= 0);
            return _buf.getDouble(offsetGG(f));
        }

        /**
         * Adds {@code grad * grad} to the squared-gradient sum of factor {@code f}.
         *
         * @throws IllegalArgumentException (via {@code Preconditions}) if {@code f < 0}
         */
        @Override
        void addGradient(@Nonnegative int f, float grad) {
            Preconditions.checkArgument(f >= 0);
            final long offset = offsetGG(f);
            // Square in double so each addend keeps the precision of the double accumulator.
            final double sum = _buf.getDouble(offset) + (double) grad * grad;
            _buf.putDouble(offset, sum);
        }

        /** Zeroes every squared-gradient sum; the values themselves are left untouched. */
        @Override
        void clear() {
            for (int f = 0; f < _factors; f++) {
                _buf.putDouble(offsetGG(f), 0.0);
            }
        }

        /** @return byte offset of the squared-gradient sum for factor {@code f} */
        private long offsetGG(@Nonnegative int f) {
            return _gg_offset + (long) DOUBLE_BYTES * f;
        }

        /** @return bytes for {@code factors} values plus one double per factor */
        static int sizeOf(@Nonnegative int factors) {
            return Entry.sizeOf(factors) + DOUBLE_BYTES * factors;
        }

        @Override
        public String toString() {
            final double[] gg = new double[_factors];
            for (int f = 0; f < _factors; f++) {
                gg[f] = getSumOfSquaredGradients(f);
            }
            return super.toString() + ", gg=" + Arrays.toString(gg);
        }
    }
}

