/*
 * Decompiled with CFR 0.152.
 */
package ai.sklearn4j.core.libraries.numpy;

import ai.sklearn4j.core.libraries.numpy.INumpyArrayElementOperation;
import ai.sklearn4j.core.libraries.numpy.INumpyReduceAxisFunction;
import ai.sklearn4j.core.libraries.numpy.NumpyArray;
import ai.sklearn4j.core.libraries.numpy.NumpyArrayFactory;
import ai.sklearn4j.core.libraries.numpy.NumpyArrayOperationWithAxisReduction;
import ai.sklearn4j.core.libraries.numpy.NumpyOperationException;

/**
 * NumPy-style operations on {@link NumpyArray}.
 *
 * <p>Element types are identified by {@link NumpyArray#isFloatingPoint()} together with
 * {@link NumpyArray#numberOfBytes()}. Integer arrays of 1, 2, 4 and 8 bytes hold {@link Byte},
 * {@link Short}, {@link Integer} and {@link Long} values; floating point arrays of 4 and 8 bytes
 * hold {@link Float} and {@link Double} values. Methods that build a result write values boxed
 * as the result array's element type.
 */
public final class Numpy {

    /**
     * Returns the indices of the maximum values along an axis, like {@code numpy.argmax}.
     *
     * <p>Elements may be of any numeric type; they are compared as {@code double}s. Because the
     * comparison is a strict {@code >}, ties resolve to the first occurrence and a NaN is only
     * reported when it is the first element of its slice.
     *
     * @param array  the input array
     * @param axis   the axis to reduce
     * @param <Type> the element type of {@code array}
     * @return an int64 array holding the index of the maximum of every slice along {@code axis}
     */
    public static <Type> NumpyArray<Long> argmax(NumpyArray<Type> array, int axis) {
        NumpyArrayOperationWithAxisReduction<Type, Long> operation = new NumpyArrayOperationWithAxisReduction<Type, Long>() {

            @Override
            public NumpyArray<Long> createInstanceResultNumpyArray(int[] shape) {
                return NumpyArrayFactory.arrayOfInt64WithShape(shape);
            }

            @Override
            public Object reduceAxisValues(Object[] valuesInAxis) {
                long maxIndex = 0L;
                double max = ((Number) valuesInAxis[0]).doubleValue();
                for (int i = 1; i < valuesInAxis.length; ++i) {
                    double candidate = ((Number) valuesInAxis[i]).doubleValue();
                    if (candidate > max) {
                        max = candidate;
                        maxIndex = i;
                    }
                }
                return maxIndex;
            }
        };
        return operation.apply(array, axis);
    }

    /**
     * Raises every element of {@code array} to {@code power} using {@link Math#pow}.
     *
     * @param array the input array; elements may be of any numeric type
     * @param power the exponent
     * @return a new double array with the same shape as {@code array}
     */
    public static NumpyArray<Double> pow(NumpyArray array, double power) {
        NumpyArray<Double> result = NumpyArrayFactory.arrayOfDoubleWithShape(array.getShape());
        array.applyToEachElementAnsSaveToTarget(result, value -> Math.pow(((Number) value).doubleValue(), power));
        return result;
    }

    /**
     * Sums the values of {@code array} along {@code axis}, like {@code numpy.sum}.
     *
     * <p>The accumulator has the same type as the elements, so integer sums wrap on overflow
     * exactly as the corresponding Java primitive does.
     *
     * @param array the input array
     * @param axis  the axis to reduce
     * @return the reduced array, as produced by {@code NumpyArrayOperationWithAxisReduction.apply}
     * @throws NumpyOperationException if the element type of {@code array} is not supported
     */
    public static NumpyArray sum(NumpyArray array, int axis) {
        final INumpyReduceAxisFunction reducer = sumReducerFor(array);
        NumpyArrayOperationWithAxisReduction<Double, Double> operation = new NumpyArrayOperationWithAxisReduction<Double, Double>() {

            @Override
            public Object reduceAxisValues(Object[] valuesInAxis) {
                return reducer.reduceAxisValues(valuesInAxis);
            }
        };
        return operation.apply(array, axis);
    }

    /** Returns the summation routine matching the element type of {@code array}. */
    private static INumpyReduceAxisFunction sumReducerFor(NumpyArray array) {
        if (array.isFloatingPoint()) {
            if (array.numberOfBytes() == 8) {
                return values -> {
                    double total = 0.0;
                    for (Object value : values) {
                        total += (Double) value;
                    }
                    return total;
                };
            }
            if (array.numberOfBytes() == 4) {
                return values -> {
                    float total = 0.0f;
                    for (Object value : values) {
                        total += (Float) value;
                    }
                    return total;
                };
            }
        } else {
            switch (array.numberOfBytes()) {
                case 1:
                    return values -> {
                        byte total = 0;
                        for (Object value : values) {
                            // Compound assignment narrows back to byte, wrapping on overflow.
                            total += (Byte) value;
                        }
                        return total;
                    };
                case 2:
                    return values -> {
                        short total = 0;
                        for (Object value : values) {
                            total += (Short) value;
                        }
                        return total;
                    };
                case 4:
                    return values -> {
                        int total = 0;
                        for (Object value : values) {
                            total += (Integer) value;
                        }
                        return total;
                    };
                case 8:
                    return values -> {
                        long total = 0L;
                        for (Object value : values) {
                            total += (Long) value;
                        }
                        return total;
                    };
                default:
                    break;
            }
        }
        throw new NumpyOperationException(unsupportedTypeMessage("sum", array));
    }

    /**
     * Applies {@link Math#exp} to every element.
     *
     * @param array the input array; elements may be of any numeric type
     * @return a new 8-byte floating point array with the same shape as {@code array}
     */
    public static NumpyArray<Double> exp(NumpyArray array) {
        NumpyArray result = NumpyArrayFactory.createArrayOfShapeAndTypeInfo(true, 8, array.getShape());
        array.applyToEachElementAnsSaveToTarget(result, value -> Math.exp(((Number) value).doubleValue()));
        return result;
    }

    /**
     * Applies {@link Math#log} (natural logarithm) to every element.
     *
     * @param array the input array; elements may be of any numeric type
     * @return a new 8-byte floating point array with the same shape as {@code array}
     */
    public static NumpyArray<Double> log(NumpyArray array) {
        NumpyArray result = NumpyArrayFactory.createArrayOfShapeAndTypeInfo(true, 8, array.getShape());
        array.applyToEachElementAnsSaveToTarget(result, value -> Math.log(((Number) value).doubleValue()));
        return result;
    }

    /**
     * Returns {@code true} when {@code a2} has more effective dimensions than {@code a1}, so that
     * {@link #add(NumpyArray, NumpyArray)} can put the higher-rank operand on the left.
     */
    private static boolean shouldSwapForAdd(NumpyArray a1, NumpyArray a2) {
        return getEffectiveShapeWithRemovingEndingDimensions(a1.getShape())
                < getEffectiveShapeWithRemovingEndingDimensions(a2.getShape());
    }

    /**
     * Returns the rank of {@code shape} after dropping trailing dimensions of size 1,
     * e.g. {@code [3, 1, 1] -> 1} and {@code [1, 1] -> 0}.
     */
    private static int getEffectiveShapeWithRemovingEndingDimensions(int[] shape) {
        int effective = shape.length;
        while (effective > 0 && shape[effective - 1] == 1) {
            --effective;
        }
        return effective;
    }

    /**
     * Checks that two shapes can be combined by {@link #add(NumpyArray, NumpyArray)}.
     *
     * <p>Shapes are accepted when the longer one ends with the shorter one. Otherwise the effective
     * ranks (trailing 1s dropped) may differ by at most one, and each of the first
     * {@code effective1} dimensions must be equal, or 1 in either shape.
     *
     * @throws NumpyOperationException if the shapes are incompatible
     */
    private static void validateDimensionsForAdd(int[] shape1, int[] shape2) {
        int effective1 = getEffectiveShapeWithRemovingEndingDimensions(shape1);
        int effective2 = getEffectiveShapeWithRemovingEndingDimensions(shape2);
        if (effective1 != effective2 && Math.abs(effective2 - effective1) != 1) {
            throw new NumpyOperationException("The effective shape of the two numpy array has different number of dimensions.");
        }
        boolean oneEndsWithTheOther = shape1.length > shape2.length && isShapeEndingLike(shape1, shape2)
                || shape2.length > shape1.length && isShapeEndingLike(shape2, shape1);
        if (oneEndsWithTheOther) {
            return;
        }
        for (int i = 0; i < effective1; ++i) {
            // shape2 may be shorter than effective1: that is a mismatch, not an out-of-bounds read.
            boolean compatible = i < shape2.length
                    && (shape1[i] == shape2[i] || shape1[i] == 1 || shape2[i] == 1);
            if (!compatible) {
                throw new NumpyOperationException(String.format("Dimension %d of the two numpy arrays doesn't match.", i + 1));
            }
        }
    }

    /**
     * Returns {@code true} when the last {@code shape2.length} dimensions of {@code shape1} equal
     * {@code shape2}. Callers only invoke this with {@code shape1} longer than {@code shape2}.
     */
    private static boolean isShapeEndingLike(int[] shape1, int[] shape2) {
        int offset = shape1.length - shape2.length;
        for (int i = 0; i < shape2.length; ++i) {
            if (shape1[offset + i] != shape2[i]) {
                return false;
            }
        }
        return true;
    }

    /**
     * Element-wise {@code a1 - a2}, computed as {@code a1 + (-a2)}.
     *
     * <p>The negated copy of {@code a2} keeps {@code a2}'s element type; integer negation wraps
     * (e.g. negating {@code Byte.MIN_VALUE} yields {@code Byte.MIN_VALUE}).
     *
     * @throws NumpyOperationException if the shapes are incompatible or the element type of
     *                                 {@code a2} is not supported
     */
    public static NumpyArray subtract(NumpyArray a1, NumpyArray a2) {
        INumpyArrayElementOperation<Object> negate = negationFor(a2);
        NumpyArray negA2 = NumpyArrayFactory.createArrayOfShapeAndTypeInfo(a2);
        a2.applyToEachElementAnsSaveToTarget(negA2, value -> negate.apply(value));
        return add(a1, negA2);
    }

    /** Returns a negation that keeps the element type of {@code array} (Byte stays Byte, ...). */
    private static INumpyArrayElementOperation<Object> negationFor(NumpyArray array) {
        if (array.isFloatingPoint()) {
            if (array.numberOfBytes() == 8) {
                return value -> -((Double) value).doubleValue();
            }
            if (array.numberOfBytes() == 4) {
                return value -> Float.valueOf(-((Float) value).floatValue());
            }
        } else {
            switch (array.numberOfBytes()) {
                case 1:
                    return value -> (byte) -((Byte) value).byteValue();
                case 2:
                    return value -> (short) -((Short) value).shortValue();
                case 4:
                    return value -> -((Integer) value).intValue();
                case 8:
                    return value -> -((Long) value).longValue();
                default:
                    break;
            }
        }
        throw new NumpyOperationException(unsupportedTypeMessage("subtract", array));
    }

    /**
     * Element-wise {@code a1 + a2} with limited broadcasting.
     *
     * <p>If {@code a2} has more effective dimensions than {@code a1}, the operands are swapped, and
     * the result takes the (possibly swapped) left operand's shape. The result is floating point if
     * either operand is. Its byte width is that of the floating point operand when exactly one
     * operand is floating point, otherwise the wider of the two. Apart from a single-value
     * right operand (which is converted), elements must already be boxed as the result type.
     *
     * @throws NumpyOperationException if the shapes are incompatible or the element type is not supported
     */
    public static NumpyArray add(NumpyArray a1, NumpyArray a2) {
        validateDimensionsForAdd(a1.getShape(), a2.getShape());
        if (shouldSwapForAdd(a1, a2)) {
            return add(a2, a1);
        }
        boolean isFloatingPoint = a1.isFloatingPoint() || a2.isFloatingPoint();
        int size;
        if (a1.isFloatingPoint() == a2.isFloatingPoint()) {
            size = Math.max(a1.numberOfBytes(), a2.numberOfBytes());
        } else if (a1.isFloatingPoint()) {
            size = a1.numberOfBytes();
        } else {
            size = a2.numberOfBytes();
        }
        NumpyArray result = NumpyArrayFactory.createArrayOfShapeAndTypeInfo(isFloatingPoint, size, a1.getShape());
        addInPlace(result, a1, a2, (byte) 1);
        return result;
    }

    /** Adds {@code value} to every element of a double array; returns a new 8-byte float array. */
    public static NumpyArray<Double> add(NumpyArray array, double value) {
        NumpyArray result = NumpyArrayFactory.createArrayOfShapeAndTypeInfo(true, 8, array.getShape());
        addInPlace(result, array, value);
        return result;
    }

    /** Subtracts {@code value} from every element of a double array. */
    public static NumpyArray<Double> subtract(NumpyArray array, double value) {
        return add(array, -value);
    }

    /** Adds {@code value} to every element of a float array; returns a new 4-byte float array. */
    public static NumpyArray<Float> add(NumpyArray array, float value) {
        NumpyArray result = NumpyArrayFactory.createArrayOfShapeAndTypeInfo(true, 4, array.getShape());
        addInPlace(result, array, value);
        return result;
    }

    /** Subtracts {@code value} from every element of a float array. */
    public static NumpyArray<Float> subtract(NumpyArray array, float value) {
        return add(array, -value);
    }

    /** Adds {@code value} to every element of a byte array; returns a new 1-byte integer array (wrapping). */
    public static NumpyArray add(NumpyArray array, byte value) {
        NumpyArray result = NumpyArrayFactory.createArrayOfShapeAndTypeInfo(false, 1, array.getShape());
        addInPlace(result, array, value);
        return result;
    }

    /**
     * Subtracts {@code value} from every element of a byte array.
     *
     * <p>The negation is narrowed back to {@code byte} so the byte overload of {@code add} is used.
     * Without the cast, {@code -value} is an {@code int} and would select the int32 path. In
     * wrapping byte arithmetic, adding {@code (byte) -value} equals subtracting {@code value}.
     */
    public static NumpyArray subtract(NumpyArray array, byte value) {
        return add(array, (byte) -value);
    }

    /** Adds {@code value} to every element of a short array; returns a new 2-byte integer array (wrapping). */
    public static NumpyArray add(NumpyArray array, short value) {
        NumpyArray result = NumpyArrayFactory.createArrayOfShapeAndTypeInfo(false, 2, array.getShape());
        addInPlace(result, array, value);
        return result;
    }

    /** Subtracts {@code value} from every element of a short array (see the byte overload for the cast). */
    public static NumpyArray subtract(NumpyArray array, short value) {
        return add(array, (short) -value);
    }

    /** Adds {@code value} to every element of an int array; returns a new 4-byte integer array. */
    public static NumpyArray add(NumpyArray array, int value) {
        NumpyArray result = NumpyArrayFactory.createArrayOfShapeAndTypeInfo(false, 4, array.getShape());
        addInPlace(result, array, value);
        return result;
    }

    /** Subtracts {@code value} from every element of an int array. */
    public static NumpyArray subtract(NumpyArray array, int value) {
        return add(array, -value);
    }

    /** Adds {@code value} to every element of a long array; returns a new 8-byte integer array. */
    public static NumpyArray add(NumpyArray array, long value) {
        NumpyArray result = NumpyArrayFactory.createArrayOfShapeAndTypeInfo(false, 8, array.getShape());
        addInPlace(result, array, value);
        return result;
    }

    /** Subtracts {@code value} from every element of a long array. */
    public static NumpyArray subtract(NumpyArray array, long value) {
        return add(array, -value);
    }

    /**
     * Writes {@code a1 + sign * a2} into {@code target}.
     *
     * <p>Cases, checked in this order:
     * <ol>
     *   <li>{@code a2} is a single-value array: its value is converted to the target element type
     *       and added to every element of {@code a1};</li>
     *   <li>{@code a1} has more than one dimension and {@code a2} is 1-D: {@code a2} is added along
     *       the last axis of {@code a1};</li>
     *   <li>both are 1-D: element-wise over {@code target}'s length;</li>
     *   <li>otherwise: recurse over the first dimension using {@code wrapInnerSubsetArray}.</li>
     * </ol>
     */
    private static void addInPlace(NumpyArray target, NumpyArray a1, NumpyArray a2, byte sign) {
        if (a2.isSingleValueArray()) {
            addSingleValueInPlace(target, a1, (Number) a2.getSingleValue(), sign);
        } else if (a1.numberOfDimensions() > 1 && a2.numberOfDimensions() == 1) {
            addAlongLastAxisInPlace(target, a1, a2, sign);
        } else if (a1.numberOfDimensions() == 1 && a2.numberOfDimensions() == 1) {
            int length = target.getShape()[0];
            for (int i = 0; i < length; ++i) {
                target.set(addElements(target, a1.get(i), a2.get(i), sign), i);
            }
        } else {
            int firstDim = target.getShape()[0];
            for (int i = 0; i < firstDim; ++i) {
                addInPlace(target.wrapInnerSubsetArray(i), a1.wrapInnerSubsetArray(i), a2.wrapInnerSubsetArray(i), sign);
            }
        }
    }

    /**
     * Adds {@code sign * scalar}, converted to {@code target}'s element type, to every element of
     * {@code array} via the matching typed scalar overload.
     */
    private static void addSingleValueInPlace(NumpyArray target, NumpyArray array, Number scalar, byte sign) {
        if (target.isFloatingPoint()) {
            if (target.numberOfBytes() == 8) {
                addInPlace(target, array, scalar.doubleValue() * sign);
                return;
            }
            if (target.numberOfBytes() == 4) {
                addInPlace(target, array, scalar.floatValue() * sign);
                return;
            }
        } else {
            switch (target.numberOfBytes()) {
                case 1:
                    addInPlace(target, array, (byte) (scalar.byteValue() * sign));
                    return;
                case 2:
                    addInPlace(target, array, (short) (scalar.shortValue() * sign));
                    return;
                case 4:
                    addInPlace(target, array, scalar.intValue() * sign);
                    return;
                case 8:
                    addInPlace(target, array, scalar.longValue() * sign);
                    return;
                default:
                    break;
            }
        }
        throw new NumpyOperationException(unsupportedTypeMessage("add", target));
    }

    /**
     * Adds the 1-D {@code row} to every last-axis slice of the multi-dimensional {@code matrix}.
     *
     * <p>{@code counter} has one slot per leading dimension plus one extra trailing slot, and is
     * advanced by {@code NumpyArray.addCounter} before each slice. The loop stops once that extra
     * slot becomes non-zero. NOTE(review): this presumably acts as a wrap-around flag, so the slice
     * at all-zero leading indices is visited last; confirm against {@code addCounter}.
     */
    private static void addAlongLastAxisInPlace(NumpyArray target, NumpyArray matrix, NumpyArray row, byte sign) {
        int[] leadingShape = new int[matrix.numberOfDimensions() - 1];
        System.arraycopy(matrix.getShape(), 0, leadingShape, 0, leadingShape.length);
        int[] index = new int[matrix.numberOfDimensions()];
        int[] counter = new int[leadingShape.length + 1];
        int rowLength = row.getShape()[0];
        do {
            NumpyArray.addCounter(counter, leadingShape);
            System.arraycopy(counter, 0, index, 0, leadingShape.length);
            for (int j = 0; j < rowLength; ++j) {
                index[index.length - 1] = j;
                target.set(addElements(target, matrix.get(index), row.get(j), sign), index);
            }
        } while (counter[counter.length - 1] == 0);
    }

    /**
     * Returns {@code left + sign * right}, computed and boxed in {@code target}'s element type.
     * Byte and short results are narrowed (wrapping) so they are stored as Byte/Short.
     */
    private static Number addElements(NumpyArray target, Object left, Object right, byte sign) {
        if (target.isFloatingPoint()) {
            if (target.numberOfBytes() == 8) {
                return (Double) left + sign * (Double) right;
            }
            if (target.numberOfBytes() == 4) {
                return (Float) left + sign * (Float) right;
            }
        } else {
            switch (target.numberOfBytes()) {
                case 1:
                    return (byte) ((Byte) left + sign * (Byte) right);
                case 2:
                    return (short) ((Short) left + sign * (Short) right);
                case 4:
                    return (Integer) left + sign * (Integer) right;
                case 8:
                    return (Long) left + sign * (Long) right;
                default:
                    break;
            }
        }
        throw new NumpyOperationException(unsupportedTypeMessage("add", target));
    }

    /** Builds the error message used when an operation meets an element type it cannot handle. */
    private static String unsupportedTypeMessage(String operation, NumpyArray array) {
        return String.format("%s does not support %d-byte %s elements.", operation, array.numberOfBytes(),
                array.isFloatingPoint() ? "floating point" : "integer");
    }

    /** Writes {@code value + element} for every (Double) element of {@code array} into {@code target}. */
    private static void addInPlace(NumpyArray target, NumpyArray array, double value) {
        array.applyToEachElementAnsSaveToTarget(target, element -> value + (Double) element);
    }

    /** Writes {@code value + element} for every (Float) element of {@code array} into {@code target}. */
    private static void addInPlace(NumpyArray target, NumpyArray array, float value) {
        array.applyToEachElementAnsSaveToTarget(target, element -> Float.valueOf(value + ((Float) element).floatValue()));
    }

    /** Writes {@code value + element} for every (Long) element of {@code array} into {@code target}. */
    private static void addInPlace(NumpyArray target, NumpyArray array, long value) {
        array.applyToEachElementAnsSaveToTarget(target, element -> value + (Long) element);
    }

    /** Writes {@code value + element} for every (Integer) element of {@code array} into {@code target}. */
    private static void addInPlace(NumpyArray target, NumpyArray array, int value) {
        array.applyToEachElementAnsSaveToTarget(target, element -> value + (Integer) element);
    }

    /** Writes {@code value + element}, narrowed to short (wrapping), for every (Short) element. */
    private static void addInPlace(NumpyArray target, NumpyArray array, short value) {
        array.applyToEachElementAnsSaveToTarget(target, element -> (short) (value + (Short) element));
    }

    /** Writes {@code value + element}, narrowed to byte (wrapping), for every (Byte) element. */
    private static void addInPlace(NumpyArray target, NumpyArray array, byte value) {
        array.applyToEachElementAnsSaveToTarget(target, element -> (byte) (value + (Byte) element));
    }

    /** Multiplies every element by {@code factor}; the result has {@code array}'s shape and type. */
    public static NumpyArray<Double> multiply(NumpyArray<Double> array, double factor) {
        NumpyArray result = NumpyArrayFactory.createArrayOfShapeAndTypeInfo(array);
        array.applyToEachElementAnsSaveToTarget(result, value -> value * factor);
        return result;
    }

    /** Multiplies every element by {@code factor}; the result has {@code array}'s shape and type. */
    public static NumpyArray<Float> multiply(NumpyArray<Float> array, float factor) {
        NumpyArray result = NumpyArrayFactory.createArrayOfShapeAndTypeInfo(array);
        array.applyToEachElementAnsSaveToTarget(result, value -> Float.valueOf(value.floatValue() * factor));
        return result;
    }

    /**
     * Divides every element by {@code factor}. The division is performed directly, not as
     * multiplication by {@code 1 / factor}, which would add a second rounding step.
     */
    public static NumpyArray<Double> divide(NumpyArray<Double> array, double factor) {
        NumpyArray result = NumpyArrayFactory.createArrayOfShapeAndTypeInfo(array);
        array.applyToEachElementAnsSaveToTarget(result, value -> value / factor);
        return result;
    }

    /** Divides every element by {@code factor} directly (see the double overload). */
    public static NumpyArray<Float> divide(NumpyArray<Float> array, float factor) {
        NumpyArray result = NumpyArrayFactory.createArrayOfShapeAndTypeInfo(array);
        array.applyToEachElementAnsSaveToTarget(result, value -> Float.valueOf(value.floatValue() / factor));
        return result;
    }

    /** Wraps {@code value} in a new 1x1 8-byte floating point array. */
    public static NumpyArray<Double> atLeast2D(double value) {
        NumpyArray result = NumpyArrayFactory.createArrayOfShapeAndTypeInfo(true, 8, new int[]{1, 1});
        result.set(value, 0, 0);
        return result;
    }

    /** Wraps {@code value} in a new 1x1 4-byte floating point array. */
    public static NumpyArray<Float> atLeast2D(float value) {
        NumpyArray result = NumpyArrayFactory.createArrayOfShapeAndTypeInfo(true, 4, new int[]{1, 1});
        result.set(Float.valueOf(value), 0, 0);
        return result;
    }

    /** Wraps {@code value} in a new 1x1 8-byte integer array. */
    public static NumpyArray<Long> atLeast2D(long value) {
        NumpyArray result = NumpyArrayFactory.createArrayOfShapeAndTypeInfo(false, 8, new int[]{1, 1});
        result.set(value, 0, 0);
        return result;
    }

    /** Wraps {@code value} in a new 1x1 4-byte integer array. */
    public static NumpyArray<Integer> atLeast2D(int value) {
        NumpyArray result = NumpyArrayFactory.createArrayOfShapeAndTypeInfo(false, 4, new int[]{1, 1});
        result.set(value, 0, 0);
        return result;
    }

    /** Wraps {@code value} in a new 1x1 2-byte integer array. */
    public static NumpyArray<Short> atLeast2D(short value) {
        NumpyArray result = NumpyArrayFactory.createArrayOfShapeAndTypeInfo(false, 2, new int[]{1, 1});
        result.set(value, 0, 0);
        return result;
    }

    /** Wraps {@code value} in a new 1x1 1-byte integer array. */
    public static NumpyArray<Byte> atLeast2D(byte value) {
        NumpyArray result = NumpyArrayFactory.createArrayOfShapeAndTypeInfo(false, 1, new int[]{1, 1});
        result.set(value, 0, 0);
        return result;
    }

    /**
     * Views {@code array} as at least two-dimensional, like {@code numpy.atleast_2d}.
     *
     * @return a new {@code 1 x n} copy (same element type) for a 1-D array of length {@code n};
     *         {@code array} itself (not a copy) when it already has two or more dimensions
     * @throws NumpyOperationException if {@code array} has zero dimensions
     */
    public static <Type> NumpyArray<Type> atLeast2D(NumpyArray<Type> array) {
        if (array.numberOfDimensions() > 1) {
            return array;
        }
        if (array.numberOfDimensions() != 1) {
            throw new NumpyOperationException("The input for atLeast2D is invalid");
        }
        int length = array.getShape()[0];
        NumpyArray result = NumpyArrayFactory.createArrayOfShapeAndTypeInfo(array.isFloatingPoint(), array.numberOfBytes(), new int[]{1, length});
        for (int i = 0; i < length; ++i) {
            result.set(array.get(i), 0, i);
        }
        return result;
    }

    /**
     * Returns the maximum of every slice of a double array along {@code axis}, like
     * {@code numpy.max}. Uses a strict {@code >} comparison, so a NaN is only returned when it
     * is the first element of its slice.
     */
    public static NumpyArray<Double> arrayMax(NumpyArray<Double> array, int axis) {
        NumpyArrayOperationWithAxisReduction<Double, Double> operation = new NumpyArrayOperationWithAxisReduction<Double, Double>() {

            @Override
            public NumpyArray<Double> createInstanceResultNumpyArray(int[] shape) {
                return NumpyArrayFactory.arrayOfDoubleWithShape(shape);
            }

            @Override
            public Object reduceAxisValues(Object[] valuesInAxis) {
                double max = (Double) valuesInAxis[0];
                for (int i = 1; i < valuesInAxis.length; ++i) {
                    double candidate = (Double) valuesInAxis[i];
                    if (candidate > max) {
                        max = candidate;
                    }
                }
                return max;
            }
        };
        return operation.apply(array, axis);
    }
}

