/*
 * Copyright 2019 Amazon.com, Inc. or its affiliates. All Rights Reserved.
 *
 * Licensed under the Apache License, Version 2.0 (the "License"). You may not use this file except in compliance
 * with the License. A copy of the License is located at
 *
 * http://aws.amazon.com/apache2.0/
 *
 * or in the "license" file accompanying this file. This file is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES
 * OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions
 * and limitations under the License.
 */
package ai.djl.modality.cv.util;

import ai.djl.ndarray.NDArray;
import ai.djl.ndarray.types.Shape;

/**
 * {@code NDImageUtils} is an image processing utility to load, reshape, and convert images using
 * {@link NDArray} images.
 */
public final class NDImageUtils {

    private NDImageUtils() {}

    /**
     * Resizes an image to a square of the given size.
     *
     * @param image the image to resize
     * @param size the new size to use for both height and width
     * @return the resized {@link NDArray}
     * @see NDImageUtils#resize(NDArray, int, int)
     */
    public static NDArray resize(NDArray image, int size) {
        return resize(image, size, size);
    }

    /**
     * Resizes an image to the given width and height.
     *
     * @param image the image to resize
     * @param width the desired width
     * @param height the desired height
     * @return the resized {@link NDArray}
     */
    public static NDArray resize(NDArray image, int width, int height) {
        return image.getNDArrayInternal().resize(width, height);
    }

    /**
     * Normalizes an image NDArray of shape CHW or NCHW with a single mean and standard deviation to
     * apply to all channels.
     *
     * <p>The scalar {@code mean} and {@code std} are replicated into three-element arrays, so this
     * overload is intended for three-channel images.
     *
     * @param input the image to normalize
     * @param mean the mean to normalize with (for all channels)
     * @param std the standard deviation to normalize with (for all channels)
     * @return the normalized NDArray
     * @see NDImageUtils#normalize(NDArray, float[], float[])
     */
    public static NDArray normalize(NDArray input, float mean, float std) {
        return normalize(input, new float[] {mean, mean, mean}, new float[] {std, std, std});
    }

    /**
     * Normalizes an image NDArray of shape CHW or NCHW with a per-channel mean and standard
     * deviation.
     *
     * <p>Given mean {@code (m1, ..., mn)} and standard deviation {@code (s1, ..., sn)} for {@code n}
     * channels, this transform normalizes each channel {@code i} of the input tensor with: {@code
     * output[i] = (input[i] - mean[i]) / std[i]}.
     *
     * @param input the image to normalize
     * @param mean the mean to normalize with for each channel
     * @param std the standard deviation to normalize with for each channel
     * @return the normalized NDArray
     */
    public static NDArray normalize(NDArray input, float[] mean, float[] std) {
        return input.getNDArrayInternal().normalize(mean, std);
    }

    /**
     * Converts an image NDArray from preprocessing format to Neural Network format.
     *
     * <p>Converts an image NDArray of shape HWC in the range {@code [0, 255]} to a {@link
     * ai.djl.ndarray.types.DataType#FLOAT32} tensor NDArray of shape CHW in the range {@code [0,
     * 1]}.
     *
     * @param image the image to convert
     * @return the converted image
     */
    public static NDArray toTensor(NDArray image) {
        return image.getNDArrayInternal().toTensor();
    }

    /**
     * Crops an image to a centered square of size {@code min(width, height)}.
     *
     * <p>The image is read as HWC: dimension 0 is the height and dimension 1 is the width. If the
     * image is already square, the same {@code image} instance is returned unchanged.
     *
     * @param image the image to crop
     * @return the cropped image
     * @see NDImageUtils#centerCrop(NDArray, int, int)
     */
    public static NDArray centerCrop(NDArray image) {
        Shape shape = image.getShape();
        int w = (int) shape.get(1);
        int h = (int) shape.get(0);

        if (w == h) {
            return image;
        }

        int side = Math.min(w, h);
        return centerCrop(image, side, side);
    }

    /**
     * Crops an image to a given width and height from the center of the image.
     *
     * <p>The image is read as HWC: dimension 0 is the height and dimension 1 is the width. Along any
     * axis where the image is not larger than the requested size, the full extent of that axis is
     * kept, so the result may be smaller than requested. When the excess on an axis is odd, the
     * extra pixel is trimmed from the right/bottom edge.
     *
     * @param image the image to crop
     * @param width the desired width of the cropped image
     * @param height the desired height of the cropped image
     * @return the cropped image
     */
    public static NDArray centerCrop(NDArray image, int width, int height) {
        Shape shape = image.getShape();
        int w = (int) shape.get(1);
        int h = (int) shape.get(0);

        int x = 0;
        int y = 0;
        // Compare sizes directly instead of testing the halved difference: with (w - width) == 1,
        // (w - width) / 2 is 0, which previously skipped the crop and left the image too large.
        if (w > width) {
            x = (w - width) / 2;
            w = width;
        }
        if (h > height) {
            y = (h - height) / 2;
            h = height;
        }

        return crop(image, x, y, w, h);
    }

    /**
     * Crops an image with a given location and size.
     *
     * @param image the image to crop
     * @param x the x coordinate of the top-left corner of the crop
     * @param y the y coordinate of the top-left corner of the crop
     * @param width the width of the cropped image
     * @param height the height of the cropped image
     * @return the cropped image
     */
    public static NDArray crop(NDArray image, int x, int y, int width, int height) {
        return image.getNDArrayInternal().crop(x, y, width, height);
    }
}
