/*
 * Decompiled with CFR 0.152.
 */
package ai.konduit.serving.data.image.step.segmentation.index;

import ai.konduit.serving.annotation.runner.CanRun;
import ai.konduit.serving.data.image.convert.ImageToNDArray;
import ai.konduit.serving.data.image.step.segmentation.index.DrawSegmentationStep;
import ai.konduit.serving.data.image.util.ColorUtil;
import ai.konduit.serving.pipeline.api.context.Context;
import ai.konduit.serving.pipeline.api.data.BoundingBox;
import ai.konduit.serving.pipeline.api.data.Data;
import ai.konduit.serving.pipeline.api.data.Image;
import ai.konduit.serving.pipeline.api.data.NDArray;
import ai.konduit.serving.pipeline.api.data.NDArrayType;
import ai.konduit.serving.pipeline.api.step.PipelineStep;
import ai.konduit.serving.pipeline.api.step.PipelineStepRunner;
import ai.konduit.serving.pipeline.impl.data.ndarray.SerializedNDArray;
import java.nio.IntBuffer;
import java.nio.LongBuffer;
import java.util.Arrays;
import java.util.List;
import java.util.Random;
import lombok.NonNull;
import org.bytedeco.javacpp.indexer.UByteIndexer;
import org.bytedeco.javacpp.indexer.UByteRawIndexer;
import org.bytedeco.opencv.global.opencv_imgproc;
import org.bytedeco.opencv.opencv_core.Mat;
import org.bytedeco.opencv.opencv_core.Rect;
import org.bytedeco.opencv.opencv_core.Scalar;
import org.bytedeco.opencv.opencv_core.Size;
import org.nd4j.common.base.Preconditions;
import org.opencv.core.CvType;

@CanRun(value={DrawSegmentationStep.class})
public class DrawSegmentationRunner
implements PipelineStepRunner {
    protected final DrawSegmentationStep step;
    protected int[] colorsB;
    protected int[] colorsG;
    protected int[] colorsR;

    public DrawSegmentationRunner(@NonNull DrawSegmentationStep step) {
        if (step == null) {
            throw new NullPointerException("step is marked non-null but is null");
        }
        this.step = step;
    }

    public void close() {
    }

    public PipelineStep getPipelineStep() {
        return this.step;
    }

    public Data exec(Context ctx, Data data) {
        String outputName;
        int backgroundClass;
        boolean drawingOnImage;
        Mat drawOn;
        NDArray segmentArr;
        long[] shape;
        if (this.colorsB == null) {
            List classColors = this.step.classColors();
            this.initColors(classColors, 32);
        }
        Preconditions.checkState(((shape = (segmentArr = data.getNDArray(this.step.segmentArray())).shape()).length == 3 && shape[0] == 1L ? 1 : 0) != 0, (String)"Expected segment indices array with shape [1, height, width], got array with shape %s", (Object)shape);
        String imgName = this.step.image();
        Mat backgroundMask = null;
        boolean resizeRequired = false;
        if (imgName == null) {
            drawOn = new Mat((int)shape[1], (int)shape[2], CvType.CV_8UC3);
            drawingOnImage = false;
        } else {
            Image i = data.getImage(imgName);
            int iH = i.height();
            int iW = i.width();
            double arImg = (double)iW / (double)iH;
            double arSegment = (double)shape[2] / (double)shape[1];
            if ((long)iH != shape[1] && (long)iW != shape[2]) {
                resizeRequired = true;
                if (arImg != arSegment) {
                    Preconditions.checkState((this.step.imageToNDArrayConfig() != null ? 1 : 0) != 0, (String)"Image and segment indices array dimensions do not match in terms of aspect ratio, and no ImageToNDArrayConfig was provided. Expected segment indices array with shape [1, height, width] - got array with shape %s and image with h=%s, w=%s", (Object)shape, (Object)iH, (Object)iW);
                }
                drawOn = new Mat((int)shape[1], (int)shape[2], CvType.CV_8UC3);
                drawingOnImage = false;
            } else {
                drawOn = new Mat();
                ((Mat)i.getAs(Mat.class)).clone().convertTo(drawOn, CvType.CV_8UC3);
                drawingOnImage = true;
            }
        }
        SerializedNDArray nd = (SerializedNDArray)segmentArr.getAs(SerializedNDArray.class);
        long[] maskShape = nd.getShape();
        int h = (int)maskShape[1];
        int w = (int)maskShape[2];
        UByteIndexer idx = (UByteIndexer)drawOn.createIndexer();
        UByteRawIndexer uByteRawIndexer = (UByteRawIndexer)idx;
        IntGetter ig = null;
        if (nd.getType() == NDArrayType.INT32) {
            IntBuffer ib = nd.getBuffer().asIntBuffer();
            ig = ib::get;
        } else if (nd.getType() == NDArrayType.INT64) {
            nd.getBuffer().rewind();
            LongBuffer lb = nd.getBuffer().asLongBuffer();
            ig = () -> (int)lb.get();
        } else {
            throw new RuntimeException();
        }
        boolean skipBackgroundClass = this.step.backgroundClass() != null;
        int n = backgroundClass = skipBackgroundClass ? this.step.backgroundClass() : -1;
        if (skipBackgroundClass && !drawingOnImage) {
            backgroundMask = new Mat(drawOn.rows(), drawOn.cols(), CvType.CV_8UC1);
        }
        if (drawingOnImage) {
            double opacity;
            if (this.step.opacity() == null) {
                opacity = 0.5;
            } else {
                opacity = this.step.opacity();
                Preconditions.checkState((opacity >= 0.0 && opacity <= 1.0 ? 1 : 0) != 0, (String)"Opacity value (if set) must be between 0.0 and 1.0, got %s", (double)opacity);
            }
            double o2 = 1.0 - opacity;
            for (int y = 0; y < h; ++y) {
                for (int x = 0; x < w; ++x) {
                    int r;
                    int g;
                    int b;
                    int classIdx = ig.get();
                    if (classIdx >= this.colorsB.length) {
                        this.initColors(this.step.classColors(), this.colorsB.length + 32);
                    }
                    long idxB = 3 * w * y + 3 * x;
                    if (skipBackgroundClass && classIdx == backgroundClass) {
                        b = uByteRawIndexer.getRaw(idxB);
                        g = uByteRawIndexer.getRaw(idxB + 1L);
                        r = uByteRawIndexer.getRaw(idxB + 2L);
                    } else {
                        b = (int)(opacity * (double)this.colorsB[classIdx] + o2 * (double)uByteRawIndexer.getRaw(idxB));
                        g = (int)(opacity * (double)this.colorsG[classIdx] + o2 * (double)uByteRawIndexer.getRaw(idxB + 1L));
                        r = (int)(opacity * (double)this.colorsR[classIdx] + o2 * (double)uByteRawIndexer.getRaw(idxB + 2L));
                    }
                    uByteRawIndexer.putRaw(idxB, b);
                    uByteRawIndexer.putRaw(idxB + 1L, g);
                    uByteRawIndexer.putRaw(idxB + 2L, r);
                }
            }
        } else {
            UByteIndexer bMaskIdx = backgroundMask == null ? null : (UByteIndexer)backgroundMask.createIndexer();
            UByteRawIndexer uByteRawIndexer2 = (UByteRawIndexer)bMaskIdx;
            for (int y = 0; y < h; ++y) {
                for (int x = 0; x < w; ++x) {
                    int classIdx = ig.get();
                    if (classIdx >= this.colorsB.length) {
                        this.initColors(this.step.classColors(), this.colorsB.length + 32);
                    }
                    long idxB = 3 * w * y + 3 * x;
                    uByteRawIndexer.putRaw(idxB, this.colorsB[classIdx]);
                    uByteRawIndexer.putRaw(idxB + 1L, this.colorsG[classIdx]);
                    uByteRawIndexer.putRaw(idxB + 2L, this.colorsR[classIdx]);
                    if (backgroundMask == null) continue;
                    long idxMask = w * y + x;
                    uByteRawIndexer2.putRaw(idxMask, classIdx == backgroundClass ? 0 : 1);
                }
            }
        }
        if (resizeRequired) {
            Image im = data.getImage(imgName);
            BoundingBox bb = ImageToNDArray.getCropRegion(im, this.step.imageToNDArrayConfig());
            int oH = (int)(bb.height() * (double)im.height());
            int oW = (int)(bb.width() * (double)im.width());
            int x1 = (int)(bb.x1() * (double)im.width());
            int y1 = (int)(bb.y1() * (double)im.height());
            Mat resized = new Mat();
            opencv_imgproc.resize((Mat)drawOn, (Mat)resized, (Size)new Size(oW, oH));
            Mat resizedFloat = new Mat();
            resized.convertTo(resizedFloat, CvType.CV_32FC3);
            Mat asFloat = new Mat();
            ((Mat)im.getAs(Mat.class)).convertTo(asFloat, CvType.CV_32FC3);
            Mat subset = asFloat.apply(new Rect(x1, y1, oW, oH));
            double opacity = this.step.opacity();
            if (backgroundMask == null) {
                opencv_imgproc.accumulateWeighted((Mat)resized, (Mat)subset, (double)opacity);
            } else {
                Mat maskResized = new Mat();
                opencv_imgproc.resize((Mat)backgroundMask, (Mat)maskResized, (Size)new Size(oW, oH));
                opencv_imgproc.accumulateWeighted((Mat)resized, (Mat)subset, (double)opacity, (Mat)maskResized);
            }
            Mat out = new Mat();
            asFloat.convertTo(out, CvType.CV_8UC3);
            drawOn = out;
        }
        if ((outputName = this.step.outputName()) == null) {
            outputName = "image";
        }
        return Data.singleton((String)outputName, (Object)Image.create((Object)drawOn));
    }

    private void initColors(List<String> classColors, int max) {
        if (this.colorsB == null && classColors != null) {
            this.colorsB = new int[classColors.size()];
            this.colorsG = new int[classColors.size()];
            this.colorsR = new int[classColors.size()];
            for (int i = 0; i < this.colorsB.length; ++i) {
                Scalar c = ColorUtil.stringToColor(classColors.get(i));
                this.colorsB[i] = (int)c.blue();
                this.colorsG[i] = (int)c.green();
                this.colorsR[i] = (int)c.red();
            }
        }
        if (this.colorsB == null || this.colorsB.length < max) {
            int i;
            int start;
            if (this.colorsB == null) {
                this.colorsB = new int[max];
                this.colorsG = new int[max];
                this.colorsR = new int[max];
                start = 0;
            } else {
                start = this.colorsB.length;
                this.colorsB = Arrays.copyOf(this.colorsB, max);
                this.colorsG = Arrays.copyOf(this.colorsG, max);
                this.colorsR = Arrays.copyOf(this.colorsR, max);
            }
            Random rng = new Random(12345L);
            if (start > 0) {
                for (i = 0; i < start; ++i) {
                    rng.nextInt(255);
                    rng.nextInt(255);
                    rng.nextInt(255);
                }
            }
            for (i = start; i < max; ++i) {
                Scalar s = ColorUtil.randomColor(rng);
                this.colorsB[i] = (int)s.blue();
                this.colorsG[i] = (int)s.green();
                this.colorsR[i] = (int)s.red();
            }
        }
    }

    private static interface IntGetter {
        public int get();
    }
}

