/*
 * Decompiled with CFR 0.152.
 */
package ai.konduit.serving.data.image.step.point.perspective.convert;

import ai.konduit.serving.annotation.runner.CanRun;
import ai.konduit.serving.data.image.step.point.perspective.convert.PerspectiveTransformStep;
import ai.konduit.serving.pipeline.api.context.Context;
import ai.konduit.serving.pipeline.api.data.BoundingBox;
import ai.konduit.serving.pipeline.api.data.Data;
import ai.konduit.serving.pipeline.api.data.Image;
import ai.konduit.serving.pipeline.api.data.Point;
import ai.konduit.serving.pipeline.api.data.ValueType;
import ai.konduit.serving.pipeline.api.step.PipelineStep;
import ai.konduit.serving.pipeline.api.step.PipelineStepRunner;
import java.util.Arrays;
import java.util.LinkedList;
import java.util.List;
import java.util.stream.Collectors;
import java.util.stream.DoubleStream;
import lombok.NonNull;
import org.bytedeco.javacpp.indexer.DoubleIndexer;
import org.bytedeco.javacpp.indexer.DoubleRawIndexer;
import org.bytedeco.javacpp.indexer.FloatIndexer;
import org.bytedeco.opencv.global.opencv_core;
import org.bytedeco.opencv.global.opencv_imgproc;
import org.bytedeco.opencv.opencv_core.Mat;
import org.bytedeco.opencv.opencv_core.Size;
import org.opencv.core.CvType;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

/**
 * Runner for {@link PerspectiveTransformStep}.
 *
 * <p>Builds a perspective transform from four source points and four target points and applies
 * it to the image, point and bounding box fields (or lists thereof) of the incoming {@link Data}.
 */
@CanRun(value={PerspectiveTransformStep.class})
public class PerspectiveTransformRunner
implements PipelineStepRunner {
    private static final Logger log = LoggerFactory.getLogger(PerspectiveTransformRunner.class);

    /** Number of point pairs needed to define the transform. */
    private static final int REQUIRED_POINT_COUNT = 4;

    /** Upper bound (per side, in pixels) applied to the size of warped output images. */
    private static final int MAX_OUTPUT_DIMENSION = 4096;

    /** Marker for "no reference size configured". */
    private static final int NO_REFERENCE = -1;

    protected final PerspectiveTransformStep step;

    /**
     * @param step the step configuration this runner executes; must not be null
     * @throws NullPointerException if {@code step} is null
     */
    public PerspectiveTransformRunner(@NonNull PerspectiveTransformStep step) {
        if (step == null) {
            throw new NullPointerException("step is marked non-null but is null");
        }
        this.step = step;
    }

    /** Nothing to release; this runner holds no resources of its own. */
    public void close() {
    }

    /** @return the step configuration this runner was created with */
    public PipelineStep getPipelineStep() {
        return this.step;
    }

    /**
     * Applies the configured perspective transform to the selected fields of {@code data}.
     *
     * @param ctx  pipeline context (not used by this runner)
     * @param data input data
     * @return a new {@link Data} holding the transformed fields and, if {@code keepOtherFields}
     *         is set on the step, a copy of every input field
     * @throws IllegalStateException    if the step configuration is inconsistent, no transformable
     *                                  field is found, or a selected field has an unsupported type
     * @throws IllegalArgumentException if a referenced field has the wrong type or point count
     */
    public Data exec(Context ctx, Data data) {
        List<Point> source = this.resolveSourcePoints(data);
        List<Point> target = this.resolveTargetPoints(data, source);
        int[] referenceSize = this.resolveReferenceSize(data);
        int refWidth = referenceSize[0];
        int refHeight = referenceSize[1];

        Mat sourceMat = this.pointsToMat(source);
        Mat targetMat = this.pointsToMat(target);
        Mat transMat = this.getPerspectiveTransform(sourceMat, targetMat, refWidth, refHeight);

        List<String> fields = this.step.inputNames();
        if (fields == null) {
            fields = this.findTransformableFields(data);
        }
        if (fields.isEmpty()) {
            throw new IllegalStateException("No fields found where PerspectiveTransformRunner could be applied.");
        }

        List<String> outNames = this.step.outputNames();
        if (outNames == null || outNames.isEmpty()) {
            // Without explicit output names every field is written back under its own name.
            outNames = fields;
        } else if (outNames.size() != fields.size()) {
            throw new IllegalStateException("You must provide only as many outputNames as there are fields to be transformed! outputNames.size = " + outNames.size() + " fields.size = " + fields.size());
        }

        Data out = Data.empty();
        if (this.step.keepOtherFields()) {
            for (String key : data.keys()) {
                out.copyFrom(key, data);
            }
        }
        for (int i = 0; i < fields.size(); ++i) {
            this.transformField(data, out, fields.get(i), outNames.get(i), transMat, refWidth, refHeight);
        }
        return out;
    }

    /** Determines the four source points, either from the step itself or from a named data field. */
    private List<Point> resolveSourcePoints(Data data) {
        if (this.step.sourcePoints() != null && this.step.sourcePointsName() != null) {
            throw new IllegalStateException("You must not define both sourcePoints and sourcePointsName simultaneously on PerspectiveTransformStep!");
        }
        if (this.step.sourcePoints() == null && this.step.sourcePointsName() == null) {
            throw new IllegalStateException("You have to define either sourcePoints or sourcePointsName on PerspectiveTransformStep!");
        }
        if (this.step.sourcePoints() != null) {
            return this.requireFourPoints(this.step.sourcePoints(), "sourcePoints");
        }
        return this.readPointsField(data, this.step.sourcePointsName(), "source");
    }

    /**
     * Determines the four target points: from the step, from a named data field, or - when
     * neither is configured - derived from the source points.
     */
    private List<Point> resolveTargetPoints(Data data, List<Point> source) {
        if (this.step.targetPoints() != null && this.step.targetPointsName() != null) {
            throw new IllegalStateException("You must not define both targetPoints and targetPointsName simultaneously on PerspectiveTransformStep!");
        }
        if (this.step.targetPoints() != null) {
            return this.requireFourPoints(this.step.targetPoints(), "targetPoints");
        }
        if (this.step.targetPointsName() != null) {
            return this.readPointsField(data, this.step.targetPointsName(), "target");
        }
        return this.calculateTargetPoints(source);
    }

    /** Ensures that a point list configured directly on the step holds exactly four points. */
    private List<Point> requireFourPoints(List<Point> points, String property) {
        if (points.size() != REQUIRED_POINT_COUNT) {
            throw new IllegalArgumentException(property + " in PerspectiveTransformStep does not contain exactly 4 points (found: " + points.size() + ")");
        }
        return points;
    }

    /**
     * Reads a list of exactly four points from {@code fieldName}.
     *
     * @param role either "source" or "target"; only used in error messages
     */
    private List<Point> readPointsField(Data data, String fieldName, String role) {
        boolean isPointList = data.type(fieldName) == ValueType.LIST && data.listType(fieldName) == ValueType.POINT;
        if (!isPointList) {
            // Fail here with a clear message instead of with a NullPointerException further down.
            throw new IllegalArgumentException("field " + fieldName + " for " + role + " points in PerspectiveTransformStep is not a list of points");
        }
        List<Point> points = data.getListPoint(fieldName);
        if (points.size() != REQUIRED_POINT_COUNT) {
            throw new IllegalArgumentException("field " + fieldName + " for " + role + " points in PerspectiveTransformStep does not contain exactly 4 points (found: " + points.size() + ")");
        }
        return points;
    }

    /**
     * Determines the reference size used to shift the transform and to resolve points.
     *
     * @return {@code {width, height}}; both are {@code -1} when neither a reference image nor
     *         both reference dimensions are configured
     */
    private int[] resolveReferenceSize(Data data) {
        String referenceField = this.step.referenceImage();
        if (referenceField != null) {
            Image refImg;
            ValueType type = data.type(referenceField);
            if (type == ValueType.IMAGE) {
                refImg = data.getImage(referenceField);
            } else if (type == ValueType.LIST && data.listType(referenceField) == ValueType.IMAGE) {
                List<Image> images = data.getListImage(referenceField);
                if (images.isEmpty()) {
                    throw new IllegalArgumentException("field " + referenceField + " is an empty list");
                }
                // For a list of images the first one defines the reference size.
                refImg = images.get(0);
            } else {
                throw new IllegalArgumentException("field " + referenceField + " is neither an image nor a list of images");
            }
            return new int[]{refImg.width(), refImg.height()};
        }
        if (this.step.referenceWidth() != null && this.step.referenceHeight() != null) {
            return new int[]{this.step.referenceWidth(), this.step.referenceHeight()};
        }
        return new int[]{NO_REFERENCE, NO_REFERENCE};
    }

    /**
     * Collects every field holding an image, bounding box or point (or a list of those),
     * skipping the fields that supply the source and target points.
     */
    private List<String> findTransformableFields(Data data) {
        List<String> fields = new LinkedList<String>();
        for (String key : data.keys()) {
            if (key.equals(this.step.targetPointsName()) || key.equals(this.step.sourcePointsName())) {
                continue;
            }
            ValueType keyType = data.type(key);
            if (keyType == ValueType.LIST) {
                keyType = data.listType(key);
            }
            if (keyType == ValueType.IMAGE || keyType == ValueType.BOUNDING_BOX || keyType == ValueType.POINT) {
                fields.add(key);
            }
        }
        return fields;
    }

    /** Transforms the single field {@code key} of {@code data} and stores it in {@code out} as {@code outKey}. */
    private void transformField(Data data, Data out, String key, String outKey, Mat transMat, int refW, int refH) {
        ValueType keyType = data.type(key);
        if (keyType == ValueType.LIST) {
            ValueType elementType = data.listType(key);
            switch (elementType) {
                case POINT: {
                    out.putListPoint(outKey, data.getListPoint(key).stream().map(p -> this.transform(transMat, p, refW, refH)).collect(Collectors.toList()));
                    return;
                }
                case IMAGE: {
                    out.putListImage(outKey, data.getListImage(key).stream().map(img -> this.transform(transMat, img)).collect(Collectors.toList()));
                    return;
                }
                case BOUNDING_BOX: {
                    out.putListBoundingBox(outKey, data.getListBoundingBox(key).stream().map(box -> this.transform(transMat, box, refW, refH)).collect(Collectors.toList()));
                    return;
                }
                default: {
                    throw new IllegalStateException("Field " + key + " with data type " + elementType + " is not supported for perspective transform!");
                }
            }
        }
        switch (keyType) {
            case POINT: {
                out.put(outKey, this.transform(transMat, data.getPoint(key), refW, refH));
                return;
            }
            case IMAGE: {
                out.put(outKey, this.transform(transMat, data.getImage(key)));
                return;
            }
            case BOUNDING_BOX: {
                out.put(outKey, this.transform(transMat, data.getBoundingBox(key), refW, refH));
                return;
            }
            default: {
                throw new IllegalStateException("Field " + key + " with data type " + keyType + " is not supported for perspective transform!");
            }
        }
    }

    /**
     * Applies {@code transform} to a single point. The point is first passed through
     * {@code toAbsolute} with the reference size; label and probability are carried over.
     */
    private Point transform(Mat transform, Point it, int refW, int refH) {
        it = it.toAbsolute(new double[]{refW, refH});
        int dimensions = it.dimensions();

        // One-element matrix with one channel per coordinate, as perspectiveTransform expects.
        Mat src = new Mat(1, 1, CvType.CV_64FC(dimensions));
        DoubleRawIndexer srcIdx = (DoubleRawIndexer)src.createIndexer();
        for (int i = 0; i < dimensions; ++i) {
            srcIdx.putRaw((long)i, it.get(i));
        }

        Mat dst = new Mat();
        opencv_core.perspectiveTransform(src, dst, transform);

        DoubleIndexer dstIdx = (DoubleIndexer)dst.createIndexer();
        double[] coords = new double[dimensions];
        dstIdx.get(0L, coords);
        return Point.create(coords, it.label(), it.probability());
    }

    /**
     * Applies {@code transform} to a bounding box. Only the center is transformed; width, height,
     * label and probability are copied unchanged.
     */
    private BoundingBox transform(Mat transform, BoundingBox it, int refW, int refH) {
        Point transformedCenter = this.transform(transform, Point.create(it.cx(), it.cy()), refW, refH);
        return BoundingBox.create(transformedCenter.x(), transformedCenter.y(), it.width(), it.height(), it.label(), it.probability());
    }

    /** Warps an image with {@code transform}; the output size is derived from the image's own extents. */
    private Image transform(Mat transform, Image it) {
        Mat src = it.getAs(Mat.class);
        Mat dst = new Mat();
        Size outputSize = this.calculateOutputSize(transform, it.width(), it.height());
        opencv_imgproc.warpPerspective(src, dst, transform, outputSize);
        return Image.create(dst);
    }

    /**
     * Computes the perspective transform from {@code sourceMat} to {@code targetMat}.
     *
     * <p>When a reference size is given, the target points are shifted by the minimum x and y that
     * the reference rectangle reaches under the initial transform, and the transform is recomputed
     * from the shifted points. Note that {@code targetMat} is modified in place in that case.
     */
    private Mat getPerspectiveTransform(Mat sourceMat, Mat targetMat, int refWidth, int refHeight) {
        Mat initialTransform = opencv_imgproc.getPerspectiveTransform(sourceMat, targetMat);
        if (refWidth == NO_REFERENCE || refHeight == NO_REFERENCE) {
            return initialTransform;
        }
        double[] extremes = this.calculateExtremes(initialTransform, refWidth, refHeight);
        double minX = extremes[0];
        double minY = extremes[1];
        FloatIndexer targetIdx = (FloatIndexer)targetMat.createIndexer();
        long rows = targetIdx.size(0);
        for (long row = 0L; row < rows; ++row) {
            targetIdx.put(row, 0L, (float)((double)targetIdx.get(row, 0L) - minX));
            targetIdx.put(row, 1L, (float)((double)targetIdx.get(row, 1L) - minY));
        }
        return opencv_imgproc.getPerspectiveTransform(sourceMat, targetMat);
    }

    /**
     * Transforms the four corners of a {@code width} x {@code height} rectangle anchored at the
     * origin and returns the bounds of the result.
     *
     * @return {@code {minX, minY, maxX, maxY}}
     */
    private double[] calculateExtremes(Mat transform, int width, int height) {
        // Four 2-channel elements, stored interleaved as x0, y0, x1, y1, ...
        Mat src = new Mat(4, 1, CvType.CV_64FC2);
        DoubleRawIndexer srcIdx = (DoubleRawIndexer)src.createIndexer();
        double[] corners = new double[]{0.0, 0.0, width, 0.0, 0.0, height, width, height};
        for (int i = 0; i < corners.length; ++i) {
            srcIdx.putRaw((long)i, corners[i]);
        }

        Mat dst = new Mat();
        opencv_core.perspectiveTransform(src, dst, transform);

        DoubleRawIndexer dstIdx = (DoubleRawIndexer)dst.createIndexer();
        double[] xValues = new double[]{dstIdx.getRaw(0L), dstIdx.getRaw(2L), dstIdx.getRaw(4L), dstIdx.getRaw(6L)};
        double[] yValues = new double[]{dstIdx.getRaw(1L), dstIdx.getRaw(3L), dstIdx.getRaw(5L), dstIdx.getRaw(7L)};
        double minX = DoubleStream.of(xValues).min().getAsDouble();
        double maxX = DoubleStream.of(xValues).max().getAsDouble();
        double minY = DoubleStream.of(yValues).min().getAsDouble();
        double maxY = DoubleStream.of(yValues).max().getAsDouble();
        return new double[]{minX, minY, maxX, maxY};
    }

    /**
     * Computes the size of the warped image from the bounds of its transformed corners, capped at
     * {@link #MAX_OUTPUT_DIMENSION} per side (a warning is logged when the cap applies).
     */
    private Size calculateOutputSize(Mat transform, int width, int height) {
        double[] extremes = this.calculateExtremes(transform, width, height);
        int outputWidth = (int)Math.round(extremes[2] - extremes[0]);
        int outputHeight = (int)Math.round(extremes[3] - extremes[1]);
        if (outputWidth > MAX_OUTPUT_DIMENSION || outputHeight > MAX_OUTPUT_DIMENSION) {
            log.warn("Selected transform would create a too large output image ({}, {})", outputWidth, outputHeight);
            outputWidth = Math.min(outputWidth, MAX_OUTPUT_DIMENSION);
            outputHeight = Math.min(outputHeight, MAX_OUTPUT_DIMENSION);
        }
        return new Size(outputWidth, outputHeight);
    }

    /**
     * Derives an axis-aligned target rectangle from the four source points, which are read in the
     * order top-left, top-right, bottom-left, bottom-right.
     *
     * <p>NOTE(review): {@code width} is taken from the top-left/bottom-left and
     * top-right/bottom-right edges and {@code height} from the top and bottom edges. With the
     * point order above that looks swapped - confirm the intended order before changing it.
     */
    private List<Point> calculateTargetPoints(List<Point> source) {
        Point topLeft = source.get(0);
        Point topRight = source.get(1);
        Point bottomLeft = source.get(2);
        Point bottomRight = source.get(3);

        double width = Math.max(this.distance(topLeft, bottomLeft), this.distance(topRight, bottomRight));
        double height = Math.max(this.distance(topLeft, topRight), this.distance(bottomLeft, bottomRight));

        double originX = topLeft.x() <= width / 2.0 ? topLeft.x() : width - topLeft.x();
        double originY = topLeft.y() <= height / 2.0 ? topLeft.y() : height - topLeft.y();
        return Arrays.asList(
                Point.create(originX, originY),
                Point.create(originX + width, originY),
                Point.create(originX, originY + height),
                Point.create(originX + width, originY + height));
    }

    /** Euclidean distance between the x/y coordinates of two points. */
    private double distance(Point a, Point b) {
        return Math.sqrt(Math.pow(a.x() - b.x(), 2.0) + Math.pow(a.y() - b.y(), 2.0));
    }

    /**
     * Copies the points into a single-channel 32-bit float matrix with one row per point and one
     * column per coordinate. The column count is taken from the first point.
     */
    private Mat pointsToMat(List<Point> points) {
        int rows = points.size();
        int cols = points.get(0).dimensions();
        Mat mat = new Mat(rows, cols, CvType.CV_32F);
        FloatIndexer idx = (FloatIndexer)mat.createIndexer();
        for (int i = 0; i < rows; ++i) {
            Point point = points.get(i);
            for (int j = 0; j < cols; ++j) {
                idx.put((long)i, (long)j, (float)point.get(j));
            }
        }
        return mat;
    }
}

