/*
 * Decompiled with CFR 0.152.
 */
package org.deeplearning4j.nn.layers.objdetect;

import java.util.ArrayList;
import java.util.Arrays;
import java.util.Iterator;
import java.util.List;
import lombok.NonNull;
import org.deeplearning4j.nn.layers.objdetect.DetectedObject;
import org.deeplearning4j.nn.workspace.ArrayType;
import org.deeplearning4j.nn.workspace.LayerWorkspaceMgr;
import org.nd4j.linalg.api.memory.MemoryWorkspace;
import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.factory.Broadcast;
import org.nd4j.linalg.factory.Nd4j;
import org.nd4j.linalg.indexing.INDArrayIndex;
import org.nd4j.linalg.indexing.NDArrayIndex;
import org.nd4j.linalg.ops.transforms.Transforms;

/**
 * Static helpers for YOLO-style object detection output:
 * <ul>
 *   <li>activation of raw network output;</li>
 *   <li>box overlap / IoU computation;</li>
 *   <li>non-maximum suppression;</li>
 *   <li>decoding of activated output into {@link DetectedObject}s.</li>
 * </ul>
 *
 * <p>Activations are handled as rank-4 arrays of shape {@code [minibatch, b * (5 + c), h, w]}. Here {@code b} is the
 * number of bounding box priors and {@code c} the number of classes. Internally they are viewed as
 * {@code [minibatch, b, 5 + c, h, w]}. Per box, channels 0-1 hold x/y, 2-3 width/height, 4 the confidence and
 * 5..5+c-1 the class scores.
 */
public class YoloUtils {

    /**
     * Applies the YOLO output activations without any workspace.
     *
     * @see #activate(INDArray, INDArray, LayerWorkspaceMgr)
     */
    public static INDArray activate(INDArray boundingBoxPriors, INDArray input) {
        return activate(boundingBoxPriors, input, LayerWorkspaceMgr.noWorkspaces());
    }

    /**
     * Applies the YOLO output activations to raw network activations.
     *
     * <p>Per box the transforms are:
     * <ul>
     *   <li>channels 0-1 (x/y): sigmoid;</li>
     *   <li>channels 2-3 (width/height): exp, then multiplied by the matching bounding box prior;</li>
     *   <li>channel 4 (confidence): sigmoid;</li>
     *   <li>channels 5..5+c-1: softmax over the classes.</li>
     * </ul>
     * {@code input} itself is not modified.
     *
     * @param boundingBoxPriors priors, broadcast over dimensions 1 and 2 of the 5d view, so expected to have shape
     *                          {@code [b, 2]}; its first dimension defines the number of boxes {@code b}
     * @param input             raw activations of shape {@code [minibatch, b * (5 + c), h, w]}
     * @param layerWorkspaceMgr used to allocate the output array as {@link ArrayType#ACTIVATIONS}
     * @return a new 'c'-order array with the same shape and data type as {@code input}
     * @throws NullPointerException  if {@code boundingBoxPriors} or {@code input} is null
     * @throws IllegalStateException if {@code input} is not rank 4, or its channel count is not {@code b * (5 + c)}
     */
    public static INDArray activate(@NonNull INDArray boundingBoxPriors, @NonNull INDArray input, LayerWorkspaceMgr layerWorkspaceMgr) {
        if (boundingBoxPriors == null) {
            throw new NullPointerException("boundingBoxPriors is marked @NonNull but is null");
        }
        if (input == null) {
            throw new NullPointerException("input is marked @NonNull but is null");
        }
        if (input.rank() != 4) {
            throw new IllegalStateException("Invalid input activations array: should be rank 4. Got array with shape " + Arrays.toString(input.shape()));
        }
        long mb = input.size(0);
        long h = input.size(2);
        long w = input.size(3);
        long b = boundingBoxPriors.size(0);
        long c = numClasses(input.size(1), b);

        INDArray output = layerWorkspaceMgr.create(ArrayType.ACTIVATIONS, input.dataType(), input.shape(), 'c');
        INDArray output5 = output.reshape('c', new long[]{mb, b, 5L + c, h, w});

        // Work on a c-order copy so the caller's array is untouched.
        // The in-place ops below act on slices of input5 and are expected to show through input4.
        // This relies on reshape of a contiguous c-order array returning a view.
        INDArray input4 = input.dup('c');
        INDArray input5 = input4.reshape('c', new long[]{mb, b, 5L + c, h, w});

        Transforms.sigmoid(channelSlice(input5, NDArrayIndex.interval(0, 2)), false);
        INDArray predictedWH = Transforms.exp(channelSlice(input5, NDArrayIndex.interval(2, 4)), false);
        Broadcast.mul(predictedWH, boundingBoxPriors.castTo(input.dataType()), predictedWH, new int[]{1, 2});
        Transforms.sigmoid(channelSlice(input5, NDArrayIndex.point(4L)), false);

        // All channels except the class scores are final here. The class channels are overwritten below.
        output.assign(input4);

        // Softmax over the class dimension:
        // move classes to the last axis, flatten to [N, c], apply softmax, then restore the 5d layout.
        INDArray classesPreSoftmax = channelSlice(input5, NDArrayIndex.interval(5L, 5L + c));
        INDArray classes2d = classesPreSoftmax.permute(new int[]{0, 1, 3, 4, 2}).dup('c').reshape('c', new long[]{mb * b * h * w, c});
        Transforms.softmax(classes2d, false);
        INDArray postSoftmax5d = classes2d.reshape('c', new long[]{mb, b, h, w, c}).permute(new int[]{0, 1, 4, 2, 3});
        channelSlice(output5, NDArrayIndex.interval(5L, 5L + c)).assign(postSoftmax5d);
        return output;
    }

    /**
     * Returns the length of the overlap between the 1d intervals {@code [x1, x2]} and {@code [x3, x4]}.
     *
     * <p>Disjoint or touching intervals give 0. Malformed intervals (start after end) also give 0 rather than a
     * negative length.
     */
    public static double overlap(double x1, double x2, double x3, double x4) {
        return Math.max(0.0, Math.min(x2, x4) - Math.max(x1, x3));
    }

    /**
     * Computes the intersection-over-union of two detections, treating each as an axis-aligned box built from its
     * center, width and height.
     *
     * @return the IoU; 0.0 when the union area is not positive (e.g. two zero-area boxes), instead of NaN
     */
    public static double iou(DetectedObject o1, DetectedObject o2) {
        double x1min = o1.getCenterX() - o1.getWidth() / 2.0;
        double x1max = o1.getCenterX() + o1.getWidth() / 2.0;
        double y1min = o1.getCenterY() - o1.getHeight() / 2.0;
        double y1max = o1.getCenterY() + o1.getHeight() / 2.0;
        double x2min = o2.getCenterX() - o2.getWidth() / 2.0;
        double x2max = o2.getCenterX() + o2.getWidth() / 2.0;
        double y2min = o2.getCenterY() - o2.getHeight() / 2.0;
        double y2max = o2.getCenterY() + o2.getHeight() / 2.0;
        double intersection = overlap(x1min, x1max, x2min, x2max) * overlap(y1min, y1max, y2min, y2max);
        double union = o1.getWidth() * o1.getHeight() + o2.getWidth() * o2.getHeight() - intersection;
        if (union <= 0.0) {
            return 0.0;
        }
        return intersection / union;
    }

    /**
     * Greedy non-maximum suppression, applied in place.
     *
     * <p>Detections are visited from most to least confident. A detection that is still kept suppresses every less
     * confident detection of the same predicted class whose IoU with it exceeds {@code iouThreshold}. Suppressed
     * detections never suppress others. Detections with equal confidence never suppress each other.
     *
     * <p>Suppressed entries, and any {@code null} entries, are removed from {@code objects}. Survivors keep their
     * original relative order. Because the visiting order is by confidence, the result does not depend on the input
     * list order.
     *
     * @param objects      detections to filter; must support removal through its iterator
     * @param iouThreshold IoU strictly above which a less confident detection is suppressed
     */
    public static void nms(List<DetectedObject> objects, double iouThreshold) {
        int n = objects.size();
        List<Integer> byConfidence = new ArrayList<>(n);
        for (int k = 0; k < n; k++) {
            if (objects.get(k) != null) {
                byConfidence.add(k);
            }
        }
        // Descending confidence. List.sort is stable, so ties keep their list order.
        byConfidence.sort((first, second) -> Double.compare(objects.get(second).getConfidence(), objects.get(first).getConfidence()));

        boolean[] suppressed = new boolean[n];
        for (int p = 0; p < byConfidence.size(); p++) {
            int keptIdx = byConfidence.get(p);
            if (suppressed[keptIdx]) {
                continue;
            }
            DetectedObject kept = objects.get(keptIdx);
            for (int q = p + 1; q < byConfidence.size(); q++) {
                int candidateIdx = byConfidence.get(q);
                if (suppressed[candidateIdx]) {
                    continue;
                }
                DetectedObject candidate = objects.get(candidateIdx);
                if (kept.getPredictedClass() == candidate.getPredictedClass()
                        && kept.getConfidence() > candidate.getConfidence()
                        && iou(kept, candidate) > iouThreshold) {
                    suppressed[candidateIdx] = true;
                }
            }
        }

        Iterator<DetectedObject> it = objects.iterator();
        int idx = 0;
        while (it.hasNext()) {
            DetectedObject o = it.next();
            if (o == null || suppressed[idx]) {
                it.remove();
            }
            idx++;
        }
    }

    /**
     * Decodes activated network output into detections.
     *
     * <p>{@code networkOutput} is expected to be already activated, e.g. the result of {@link #activate}. The
     * confidence channel is compared directly against a threshold in [0, 1].
     *
     * <p>For each (example, cell, box) whose confidence is at least {@code confThreshold}, one
     * {@link DetectedObject} is created:
     * <ul>
     *   <li>the center is the predicted x/y offset plus the cell's column/row index;</li>
     *   <li>width and height are taken from channels 2 and 3 as-is;</li>
     *   <li>class scores are copied from channels 5..5+c-1.</li>
     * </ul>
     *
     * @param boundingBoxPriors priors; only the first dimension (number of boxes {@code b}) is read here
     * @param networkOutput     rank-4 array of shape {@code [minibatch, b * (5 + c), h, w]}
     * @param confThreshold     minimum confidence, in [0, 1]
     * @param nmsThreshold      IoU threshold passed to {@link #nms}; NMS is skipped when this is {@code <= 0}
     * @return the detections, in (example, x, y, box) iteration order minus any removed by NMS
     * @throws IllegalStateException if {@code networkOutput} is not rank 4, its channel count is not
     *                               {@code b * (5 + c)}, or {@code confThreshold} is outside [0, 1]
     */
    public static List<DetectedObject> getPredictedObjects(INDArray boundingBoxPriors, INDArray networkOutput, double confThreshold, double nmsThreshold) {
        if (networkOutput.rank() != 4) {
            throw new IllegalStateException("Invalid network output activations array: should be rank 4. Got array with shape " + Arrays.toString(networkOutput.shape()));
        }
        if (confThreshold < 0.0 || confThreshold > 1.0) {
            throw new IllegalStateException("Invalid confidence threshold: must be in range [0,1]. Got: " + confThreshold);
        }
        long mb = networkOutput.size(0);
        long h = networkOutput.size(2);
        long w = networkOutput.size(3);
        long b = boundingBoxPriors.size(0);
        long c = numClasses(networkOutput.size(1), b);
        INDArray output5 = networkOutput.dup('c').reshape(new long[]{mb, b, 5L + c, h, w});
        INDArray predictedConfidence = channelSlice(output5, NDArrayIndex.point(4L));
        INDArray classScores = channelSlice(output5, NDArrayIndex.interval(5L, 5L + c));

        List<DetectedObject> out = new ArrayList<>();
        for (int i = 0; i < mb; i++) {
            for (int x = 0; x < w; x++) {
                for (int y = 0; y < h; y++) {
                    for (int box = 0; box < b; box++) {
                        double conf = predictedConfidence.getDouble(new int[]{i, box, y, x});
                        if (conf < confThreshold) {
                            continue;
                        }
                        // x/y are offsets within the cell; adding the cell index gives a position in grid units.
                        double px = output5.getDouble(new int[]{i, box, 0, y, x}) + x;
                        double py = output5.getDouble(new int[]{i, box, 1, y, x}) + y;
                        double pw = output5.getDouble(new int[]{i, box, 2, y, x});
                        double ph = output5.getDouble(new int[]{i, box, 3, y, x});
                        INDArray classPredictions;
                        // The copy is kept in the returned DetectedObject, so allocate it outside any active
                        // workspace.
                        try (MemoryWorkspace ignored = Nd4j.getMemoryManager().scopeOutOfWorkspaces()) {
                            classPredictions = classScores.get(NDArrayIndex.point(i), NDArrayIndex.point(box), NDArrayIndex.all(), NDArrayIndex.point(y), NDArrayIndex.point(x)).dup();
                        }
                        out.add(new DetectedObject(i, px, py, pw, ph, classPredictions, conf));
                    }
                }
            }
        }
        if (nmsThreshold > 0.0) {
            nms(out, nmsThreshold);
        }
        return out;
    }

    /** Selects {@code channelIndex} along dimension 2 of a {@code [mb, b, 5 + c, h, w]} array, keeping all others. */
    private static INDArray channelSlice(INDArray arr5d, INDArrayIndex channelIndex) {
        return arr5d.get(NDArrayIndex.all(), NDArrayIndex.all(), channelIndex, NDArrayIndex.all(), NDArrayIndex.all());
    }

    /** Derives the class count {@code c} from a channel count of {@code b * (5 + c)}, rejecting shapes that don't fit. */
    private static long numClasses(long channels, long numBoxes) {
        if (numBoxes <= 0 || channels % numBoxes != 0 || channels / numBoxes < 5) {
            throw new IllegalStateException("Invalid number of channels: expected b * (5 + c) channels for b = "
                    + numBoxes + " bounding box priors, got " + channels);
        }
        return channels / numBoxes - 5L;
    }
}

