/*
 * Decompiled with CFR 0.152.
 */
package deepnetts.util;

import deepnetts.util.RandomGenerator;
import java.io.Serializable;
import java.util.Arrays;
import java.util.function.Function;

/**
 * Dense tensor of up to four dimensions stored in a single flat {@code float} array.
 *
 * <p>Elements are laid out with the column index varying fastest, followed by the
 * row, the depth ({@code z}) and finally the fourth dimension. The flat position
 * of an element is
 * {@code fourth * depth * rows * cols + z * rows * cols + row * cols + col}.
 *
 * <p>Constructors and factory methods that accept a {@code float[]}, as well as
 * {@link #setValues(float...)} and {@link #getValues()}, share the array with the
 * caller instead of copying it, so writes on either side are visible to the other.
 *
 * <p>Element accessors perform no bounds checking beyond what the backing array
 * itself enforces.
 */
public class Tensor
implements Serializable {
    private final int cols;
    private final int rows;
    private final int depth;
    private final int fourthDim;
    /** Number of dimensions this tensor was created with (1 to 4). */
    private final int dimensions;
    /** Dimension sizes in storage order: {fourthDim, depth, rows, cols}. */
    private final int[] shape = new int[4];
    // rank and size are neither read nor written by any code in this class; they are
    // left declared so that the set of serialized fields stays exactly as it was.
    private int rank;
    private int size;
    private float[] values;

    /**
     * Creates a one-dimensional tensor backed directly by the given array.
     *
     * @param values element values; the array is used as-is, not copied
     */
    public Tensor(float ... values) {
        this.rows = 1;
        this.cols = values.length;
        this.depth = 1;
        this.fourthDim = 1;
        this.dimensions = 1;
        this.values = values;
        this.recordShape();
    }

    /**
     * Creates a two-dimensional tensor holding a copy of the given matrix.
     *
     * @param vals values indexed as {@code [row][col]}; must contain at least one
     *             row, and every row must be at least as long as the first one
     */
    public Tensor(float[][] vals) {
        this.rows = vals.length;
        this.cols = vals[0].length;
        this.depth = 1;
        this.fourthDim = 1;
        this.dimensions = 2;
        this.values = new float[Tensor.elementCount(this.rows, this.cols, 1, 1)];
        this.recordShape();
        for (int row = 0; row < this.rows; ++row) {
            System.arraycopy(vals[row], 0, this.values, this.offset(row, 0), this.cols);
        }
    }

    /**
     * Creates a three-dimensional tensor holding a copy of the given array.
     *
     * @param vals values indexed as {@code [z][row][col]}; the sizes are taken from
     *             {@code vals}, {@code vals[0]} and {@code vals[0][0]}
     */
    public Tensor(float[][][] vals) {
        this.depth = vals.length;
        this.rows = vals[0].length;
        this.cols = vals[0][0].length;
        this.fourthDim = 1;
        this.dimensions = 3;
        this.values = new float[Tensor.elementCount(this.rows, this.cols, this.depth, 1)];
        this.recordShape();
        for (int z = 0; z < this.depth; ++z) {
            for (int row = 0; row < this.rows; ++row) {
                System.arraycopy(vals[z][row], 0, this.values, this.offset(row, 0, z), this.cols);
            }
        }
    }

    /**
     * Creates a four-dimensional tensor holding a copy of the given array.
     *
     * @param vals values indexed as {@code [fourth][z][row][col]}; the sizes are
     *             taken from the first element at each nesting level
     */
    public Tensor(float[][][][] vals) {
        this.fourthDim = vals.length;
        this.depth = vals[0].length;
        this.rows = vals[0][0].length;
        this.cols = vals[0][0][0].length;
        this.dimensions = 4;
        this.values = new float[Tensor.elementCount(this.rows, this.cols, this.depth, this.fourthDim)];
        this.recordShape();
        for (int f = 0; f < this.fourthDim; ++f) {
            for (int z = 0; z < this.depth; ++z) {
                for (int row = 0; row < this.rows; ++row) {
                    System.arraycopy(vals[f][z][row], 0, this.values, this.offset(row, 0, z, f), this.cols);
                }
            }
        }
    }

    /**
     * Creates a zero-filled one-dimensional tensor.
     *
     * @param cols number of elements
     * @throws IllegalArgumentException if {@code cols} is negative
     */
    public Tensor(int cols) {
        Tensor.requireNonNegative(cols, "Number of cols");
        this.cols = cols;
        this.rows = 1;
        this.depth = 1;
        this.fourthDim = 1;
        this.dimensions = 1;
        this.values = new float[cols];
        this.recordShape();
    }

    /**
     * Creates a one-dimensional tensor with every element set to {@code val}.
     *
     * @param cols number of elements
     * @param val  initial value of every element
     * @throws IllegalArgumentException if {@code cols} is negative
     */
    public Tensor(int cols, float val) {
        Tensor.requireNonNegative(cols, "Number of cols");
        this.cols = cols;
        this.rows = 1;
        this.depth = 1;
        this.fourthDim = 1;
        this.dimensions = 1;
        this.values = new float[cols];
        Arrays.fill(this.values, val);
        this.recordShape();
    }

    /**
     * Creates a zero-filled two-dimensional tensor.
     *
     * @throws IllegalArgumentException if a size is negative or the element count
     *                                  does not fit in an {@code int}
     */
    public Tensor(int rows, int cols) {
        Tensor.requireNonNegative(rows, "Number of rows");
        Tensor.requireNonNegative(cols, "Number of cols");
        this.rows = rows;
        this.cols = cols;
        this.depth = 1;
        this.fourthDim = 1;
        this.dimensions = 2;
        this.values = new float[Tensor.elementCount(rows, cols, 1, 1)];
        this.recordShape();
    }

    /**
     * Creates a two-dimensional tensor backed directly by the given array.
     *
     * @param values element values in storage order; used as-is, not copied
     * @throws IllegalArgumentException if a size is negative or
     *                                  {@code values.length != rows * cols}
     */
    public Tensor(int rows, int cols, float[] values) {
        Tensor.requireNonNegative(rows, "Number of rows");
        Tensor.requireNonNegative(cols, "Number of cols");
        Tensor.requireLength(values, Tensor.elementCount(rows, cols, 1, 1));
        this.rows = rows;
        this.cols = cols;
        this.depth = 1;
        this.fourthDim = 1;
        this.dimensions = 2;
        this.values = values;
        this.recordShape();
    }

    /**
     * Creates a zero-filled three-dimensional tensor.
     *
     * @throws IllegalArgumentException if a size is negative or the element count
     *                                  does not fit in an {@code int}
     */
    public Tensor(int rows, int cols, int depth) {
        Tensor.requireNonNegative(rows, "Number of rows");
        Tensor.requireNonNegative(cols, "Number of cols");
        Tensor.requireNonNegative(depth, "Depth");
        this.rows = rows;
        this.cols = cols;
        this.depth = depth;
        this.fourthDim = 1;
        this.dimensions = 3;
        this.values = new float[Tensor.elementCount(rows, cols, depth, 1)];
        this.recordShape();
    }

    /**
     * Creates a zero-filled four-dimensional tensor.
     *
     * @throws IllegalArgumentException if a size is negative or the element count
     *                                  does not fit in an {@code int}
     */
    public Tensor(int rows, int cols, int depth, int fourthDim) {
        Tensor.requireNonNegative(rows, "Number of rows");
        Tensor.requireNonNegative(cols, "Number of cols");
        Tensor.requireNonNegative(depth, "Depth");
        Tensor.requireNonNegative(fourthDim, "fourthDim");
        this.rows = rows;
        this.cols = cols;
        this.depth = depth;
        this.fourthDim = fourthDim;
        this.dimensions = 4;
        this.values = new float[Tensor.elementCount(rows, cols, depth, fourthDim)];
        this.recordShape();
    }

    /**
     * Creates a four-dimensional tensor backed directly by the given array.
     *
     * @param values element values in storage order; used as-is, not copied
     * @throws IllegalArgumentException if a size is negative or
     *                                  {@code values.length != rows * cols * depth * fourthDim}
     */
    public Tensor(int rows, int cols, int depth, int fourthDim, float[] values) {
        Tensor.requireNonNegative(rows, "Number of rows");
        Tensor.requireNonNegative(cols, "Number of cols");
        Tensor.requireNonNegative(depth, "Depth");
        Tensor.requireNonNegative(fourthDim, "fourthDim");
        // Same length check the 2-D and 3-D array-backed constructors perform.
        Tensor.requireLength(values, Tensor.elementCount(rows, cols, depth, fourthDim));
        this.rows = rows;
        this.cols = cols;
        this.depth = depth;
        this.fourthDim = fourthDim;
        this.dimensions = 4;
        this.values = values;
        this.recordShape();
    }

    /**
     * Creates a three-dimensional tensor backed directly by the given array.
     *
     * @param values element values in storage order; used as-is, not copied
     * @throws IllegalArgumentException if a size is negative or
     *                                  {@code values.length != rows * cols * depth}
     */
    public Tensor(int rows, int cols, int depth, float[] values) {
        Tensor.requireNonNegative(rows, "Number of rows");
        Tensor.requireNonNegative(cols, "Number of cols");
        Tensor.requireNonNegative(depth, "Depth");
        Tensor.requireLength(values, Tensor.elementCount(rows, cols, depth, 1));
        this.cols = cols;
        this.rows = rows;
        this.depth = depth;
        this.fourthDim = 1;
        this.dimensions = 3;
        this.values = values;
        this.recordShape();
    }

    /** Copy constructor: same sizes and dimension count, independent backing array. */
    private Tensor(Tensor t) {
        this.cols = t.cols;
        this.rows = t.rows;
        this.depth = t.depth;
        this.fourthDim = t.fourthDim;
        this.dimensions = t.dimensions;
        this.values = Arrays.copyOf(t.values, t.values.length);
        this.recordShape();
    }

    /** Rejects a negative dimension size, naming the offending dimension in the message. */
    private static void requireNonNegative(int value, String label) {
        if (value < 0) {
            throw new IllegalArgumentException(label + " cannot be negative: " + value);
        }
    }

    /** Rejects a caller-supplied array whose length differs from the expected element count. */
    private static void requireLength(float[] values, int expected) {
        if (expected != values.length) {
            throw new IllegalArgumentException("Number of values does not match tensor dimensions! " + values.length);
        }
    }

    /**
     * Multiplies the (non-negative) dimension sizes in {@code long} arithmetic so that
     * an element count too large for an array is reported instead of wrapping around.
     */
    private static int elementCount(int rows, int cols, int depth, int fourthDim) {
        if (rows == 0 || cols == 0 || depth == 0 || fourthDim == 0) {
            return 0;
        }
        // Each partial product is checked before the next factor is applied, so the
        // running value never exceeds Integer.MAX_VALUE * Integer.MAX_VALUE.
        long count = (long)rows * cols;
        if (count <= Integer.MAX_VALUE) {
            count *= depth;
        }
        if (count <= Integer.MAX_VALUE) {
            count *= fourthDim;
        }
        if (count > Integer.MAX_VALUE) {
            throw new IllegalArgumentException("Tensor is too large: " + rows + " x " + cols + " x " + depth + " x " + fourthDim);
        }
        return (int)count;
    }

    /** Stores the dimension sizes in {@code shape}, outermost (slowest varying) first. */
    private void recordShape() {
        this.shape[0] = this.fourthDim;
        this.shape[1] = this.depth;
        this.shape[2] = this.rows;
        this.shape[3] = this.cols;
    }

    /** Flat position of {@code (row, col)} in the first depth slice. */
    private int offset(int row, int col) {
        return row * this.cols + col;
    }

    /** Flat position of {@code (row, col, z)} in the first fourth-dimension block. */
    private int offset(int row, int col, int z) {
        return z * this.rows * this.cols + this.offset(row, col);
    }

    /** Flat position of {@code (row, col, z, fourth)}. */
    private int offset(int row, int col, int z, int fourth) {
        return fourth * this.depth * this.rows * this.cols + this.offset(row, col, z);
    }

    /** Returns the element at flat position {@code idx}. */
    public final float get(int idx) {
        return this.values[idx];
    }

    /**
     * Stores {@code val} at flat position {@code idx}.
     *
     * @return the value now stored at {@code idx}
     */
    public final float set(int idx, float val) {
        this.values[idx] = val;
        return this.values[idx];
    }

    /** Returns the element at {@code (row, col)}. */
    public final float get(int row, int col) {
        return this.values[this.offset(row, col)];
    }

    /** Stores {@code val} at {@code (row, col)}. */
    public final void set(int row, int col, float val) {
        this.values[this.offset(row, col)] = val;
    }

    /** Returns the element at {@code (row, col, z)}. */
    public final float get(int row, int col, int z) {
        return this.values[this.offset(row, col, z)];
    }

    /** Stores {@code val} at {@code (row, col, z)}. */
    public final void set(int row, int col, int z, float val) {
        this.values[this.offset(row, col, z)] = val;
    }

    /** Returns the element at {@code (row, col, z, fourth)}. */
    public final float get(int row, int col, int z, int fourth) {
        return this.values[this.offset(row, col, z, fourth)];
    }

    /** Stores {@code val} at {@code (row, col, z, fourth)}. */
    public final void set(int row, int col, int z, int fourth, float val) {
        this.values[this.offset(row, col, z, fourth)] = val;
    }

    /**
     * Returns the element addressed by an index array given in storage order.
     *
     * @param idxs four indices, outermost first: {@code {fourth, z, row, col}}
     * @return the same element as {@code get(idxs[2], idxs[3], idxs[1], idxs[0])}
     */
    public final float getWithStride(int[] idxs) {
        // Strides come from the dimension fields, which are always set, so the result
        // matches get(row, col, z, fourth) for every instance.
        return this.values[this.offset(idxs[2], idxs[3], idxs[1], idxs[0])];
    }

    /** Returns the backing array itself (not a copy). */
    public final float[] getValues() {
        return this.values;
    }

    /**
     * Replaces the backing array with the given one without copying it.
     * The length is not checked against the tensor's dimensions.
     */
    public final void setValues(float ... values) {
        this.values = values;
    }

    /**
     * Overwrites every element with the leading elements of {@code src}.
     *
     * @param src source array; must hold at least {@link #size()} elements
     */
    public final void copyFrom(float[] src) {
        System.arraycopy(src, 0, this.values, 0, this.values.length);
    }

    public final int getCols() {
        return this.cols;
    }

    public final int getRows() {
        return this.rows;
    }

    public final int getDepth() {
        return this.depth;
    }

    public final int getFourthDim() {
        return this.fourthDim;
    }

    /** Returns the number of dimensions this tensor was created with. */
    public final int getDimensions() {
        return this.dimensions;
    }

    /** Returns the length of the backing array. */
    public final int size() {
        return this.values.length;
    }

    /**
     * Formats all elements in storage order inside square brackets, separating
     * elements with {@code ", "} and using {@code "; "} after every {@code cols}-th
     * element.
     */
    @Override
    public String toString() {
        StringBuilder sb = new StringBuilder("[");
        int last = this.values.length - 1;
        for (int i = 0; i <= last; ++i) {
            sb.append(this.values[i]);
            if (i == last) {
                break;
            }
            // cols > 0 guard: a zero-column tensor can still carry values after setValues().
            boolean endOfRow = this.cols > 0 && (i + 1) % this.cols == 0;
            sb.append(endOfRow ? "; " : ", ");
        }
        return sb.append("]").toString();
    }

    /** Adds {@code value} to the element at flat position {@code idx}. */
    public final void add(int idx, float value) {
        this.values[idx] += value;
    }

    /** Adds {@code value} to the element at {@code (row, col)}. */
    public final void add(int row, int col, float value) {
        this.values[this.offset(row, col)] += value;
    }

    /** Adds {@code value} to the element at {@code (row, col, z)}. */
    public final void add(int row, int col, int z, float value) {
        this.values[this.offset(row, col, z)] += value;
    }

    /** Adds {@code value} to the element at {@code (row, col, z, fourth)}. */
    public final void add(int row, int col, int z, int fourth, float value) {
        this.values[this.offset(row, col, z, fourth)] += value;
    }

    /**
     * Adds {@code t} to this tensor element by element, in place.
     *
     * @param t tensor with at least as many elements as this one
     */
    public final void add(Tensor t) {
        for (int i = 0; i < this.values.length; ++i) {
            this.values[i] += t.values[i];
        }
    }

    /** Subtracts {@code value} from the element at {@code (row, col)}. */
    public final void sub(int row, int col, float value) {
        this.values[this.offset(row, col)] -= value;
    }

    /** Subtracts {@code value} from the element at {@code (row, col, z)}. */
    public final void sub(int row, int col, int z, float value) {
        this.values[this.offset(row, col, z)] -= value;
    }

    /** Subtracts {@code value} from the element at {@code (row, col, z, fourth)}. */
    public final void sub(int row, int col, int z, int fourth, float value) {
        this.values[this.offset(row, col, z, fourth)] -= value;
    }

    /**
     * Subtracts {@code t} from this tensor element by element, in place.
     *
     * @param t tensor with at least as many elements as this one
     */
    public final void sub(Tensor t) {
        for (int i = 0; i < this.values.length; ++i) {
            this.values[i] -= t.values[i];
        }
    }

    /** Subtracts {@code val} from every element, in place. */
    public final void sub(float val) {
        for (int i = 0; i < this.values.length; ++i) {
            this.values[i] -= val;
        }
    }

    /**
     * Subtracts {@code t2} from {@code t1} element by element, storing the result in {@code t1}.
     *
     * @param t1 tensor that is modified
     * @param t2 tensor with at least as many elements as {@code t1}
     */
    public static final void sub(Tensor t1, Tensor t2) {
        for (int i = 0; i < t1.values.length; ++i) {
            t1.values[i] -= t2.values[i];
        }
    }

    /** Divides every element by {@code value}, in place. */
    public final void div(float value) {
        for (int i = 0; i < this.values.length; ++i) {
            this.values[i] /= value;
        }
    }

    /**
     * Divides each element by the divisor at the same flat position, in place.
     *
     * @param divisors array with at least {@link #size()} elements
     */
    public final void div(float[] divisors) {
        for (int i = 0; i < this.values.length; ++i) {
            this.values[i] /= divisors[i];
        }
    }

    /** Sets every element to {@code value}. */
    public final void fill(float value) {
        Arrays.fill(this.values, value);
    }

    /** Sets every element of {@code array} to {@code val}. */
    public static final void fill(float[] array, float val) {
        Arrays.fill(array, val);
    }

    /**
     * Divides this tensor by {@code t} element by element, in place.
     *
     * @param t tensor with at least as many elements as this one
     */
    public final void div(Tensor t) {
        for (int i = 0; i < this.values.length; ++i) {
            this.values[i] /= t.values[i];
        }
    }

    /**
     * Returns an independent copy with the same sizes, the same number of
     * dimensions and the same element values, so that the copy is
     * {@link #equals(Object) equal} to this tensor.
     */
    public Tensor copy() {
        return new Tensor(this);
    }

    /**
     * Copies all elements of {@code src} into the start of {@code dest}'s backing array.
     *
     * @param dest tensor with at least as many elements as {@code src}
     */
    public static final void copy(Tensor src, Tensor dest) {
        System.arraycopy(src.values, 0, dest.values, 0, src.values.length);
    }

    /**
     * Copies all elements of {@code src} into the start of {@code dest}.
     *
     * @param dest array at least as long as {@code src}
     */
    public static final void copy(float[] src, float[] dest) {
        System.arraycopy(src, 0, dest, 0, src.length);
    }

    /** Replaces every element with the result of applying {@code f} to it. */
    public void apply(Function<Float, Float> f) {
        for (int i = 0; i < this.values.length; ++i) {
            this.values[i] = f.apply(this.values[i]);
        }
    }

    /**
     * Two tensors are equal when they are of the same class and have the same
     * sizes, the same number of dimensions and the same element values.
     */
    @Override
    public boolean equals(Object obj) {
        if (this == obj) {
            return true;
        }
        if (obj == null || this.getClass() != obj.getClass()) {
            return false;
        }
        Tensor other = (Tensor)obj;
        return this.cols == other.cols
                && this.rows == other.rows
                && this.depth == other.depth
                && this.fourthDim == other.fourthDim
                && this.dimensions == other.dimensions
                && Arrays.equals(this.values, other.values);
    }

    @Override
    public int hashCode() {
        int hash = 3;
        hash = 41 * hash + this.cols;
        hash = 41 * hash + this.rows;
        hash = 41 * hash + this.depth;
        hash = 41 * hash + this.fourthDim;
        hash = 41 * hash + this.dimensions;
        hash = 41 * hash + Arrays.hashCode(this.values);
        return hash;
    }

    /**
     * Compares element values only (sizes per dimension are not compared), allowing
     * an absolute difference of up to {@code delta} per element.
     *
     * @return {@code false} if the element counts differ or any pair of elements
     *         differs by more than {@code delta}; {@code true} otherwise
     */
    public boolean equals(Tensor t2, float delta) {
        float[] other = t2.getValues();
        if (other.length != this.values.length) {
            return false;
        }
        for (int i = 0; i < this.values.length; ++i) {
            if (Math.abs(this.values[i] - other[i]) > delta) {
                return false;
            }
        }
        return true;
    }

    /** Concatenates the {@link #toString()} form of every tensor, with no separator. */
    public static String valuesAsString(Tensor[] tensors) {
        StringBuilder sb = new StringBuilder();
        for (Tensor t : tensors) {
            sb.append(t.toString());
        }
        return sb.toString();
    }

    /**
     * Parses a comma-separated list of numbers and stores them at flat positions
     * {@code 0, 1, 2, ...}. Elements beyond the number of parsed tokens are left
     * unchanged.
     *
     * @param values comma-separated values, each accepted by {@link Float#parseFloat(String)}
     */
    public void setValuesFromString(String values) {
        String[] tokens = values.split(",");
        for (int i = 0; i < tokens.length; ++i) {
            this.values[i] = Float.parseFloat(tokens[i]);
        }
    }

    /** Factory equivalent of {@link #Tensor(int, int, float[])}. */
    public static Tensor create(int rows, int cols, float[] values) {
        return new Tensor(rows, cols, values);
    }

    /** Factory equivalent of {@link #Tensor(int, int, int, float[])}. */
    public static Tensor create(int rows, int cols, int depth, float[] values) {
        return new Tensor(rows, cols, depth, values);
    }

    /** Factory equivalent of {@link #Tensor(int, int, int, int, float[])}. */
    public static Tensor create(int rows, int cols, int depth, int fourthDim, float[] values) {
        return new Tensor(rows, cols, depth, fourthDim, values);
    }

    /** Returns the sum of the absolute values of all elements. */
    public float sumAbs() {
        float sum = 0.0f;
        for (float v : this.values) {
            sum += Math.abs(v);
        }
        return sum;
    }

    /** Returns the sum of the squares of all elements. */
    public float sumSqr() {
        float sum = 0.0f;
        for (float v : this.values) {
            sum += v * v;
        }
        return sum;
    }

    /**
     * Overwrites every element, in storage order, with the next float drawn from
     * {@code RandomGenerator.getDefault()}.
     */
    public void randomize() {
        // Walk the whole backing array so that tensors with depth or a fourth
        // dimension are randomized completely, not only their first rows * cols slice.
        for (int i = 0; i < this.values.length; ++i) {
            this.values[i] = RandomGenerator.getDefault().nextFloat();
        }
    }

    /**
     * Multiplies this tensor by {@code tensor2} element by element, in place.
     *
     * @param tensor2 tensor with at least as many elements as this one
     */
    public void multiplyElementWise(Tensor tensor2) {
        for (int i = 0; i < this.values.length; ++i) {
            this.values[i] *= tensor2.values[i];
        }
    }

    /** Multiplies every element by {@code m}, in place. */
    public void multiply(float m) {
        for (int i = 0; i < this.values.length; ++i) {
            this.values[i] *= m;
        }
    }

    /** Replaces every element with its square root, in place. */
    public void sqrt() {
        for (int i = 0; i < this.values.length; ++i) {
            this.values[i] = (float)Math.sqrt(this.values[i]);
        }
    }
}

