/*
 * Decompiled with CFR 0.152.
 */
package org.datavec.image.loader;

import com.github.jaiimageio.impl.plugins.tiff.TIFFImageReaderSpi;
import com.github.jaiimageio.impl.plugins.tiff.TIFFImageWriterSpi;
import com.twelvemonkeys.imageio.plugins.bmp.BMPImageReaderSpi;
import com.twelvemonkeys.imageio.plugins.bmp.CURImageReaderSpi;
import com.twelvemonkeys.imageio.plugins.bmp.ICOImageReaderSpi;
import com.twelvemonkeys.imageio.plugins.jpeg.JPEGImageReaderSpi;
import com.twelvemonkeys.imageio.plugins.jpeg.JPEGImageWriterSpi;
import com.twelvemonkeys.imageio.plugins.psd.PSDImageReaderSpi;
import java.awt.Graphics2D;
import java.awt.Image;
import java.awt.image.BufferedImage;
import java.awt.image.DataBufferByte;
import java.awt.image.WritableRaster;
import java.io.BufferedInputStream;
import java.io.File;
import java.io.FileInputStream;
import java.io.IOException;
import java.io.InputStream;
import java.util.Arrays;
import javax.imageio.ImageIO;
import javax.imageio.spi.IIORegistry;
import org.datavec.image.loader.BaseImageLoader;
import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.factory.Nd4j;
import org.nd4j.linalg.util.ArrayUtil;
import org.nd4j.linalg.util.NDArrayUtil;

/**
 * Loads images through {@link ImageIO} and converts them to and from ND4J arrays.
 *
 * <p>The target {@code height}, {@code width} and {@code channels}, as well as the
 * {@code centerCropIfNeeded} flag, are fields inherited from {@link BaseImageLoader}. Images are
 * resized only when both {@code height} and {@code width} are positive (see
 * {@link #scalingIfNeed(BufferedImage, int, int, boolean)}).
 *
 * <p>NOTE(review): several code paths compare {@code channels} with {@link BufferedImage} type
 * constants ({@code TYPE_4BYTE_ABGR == 6}, {@code TYPE_BYTE_GRAY == 10}) rather than with a
 * channel count; confirm this is intended before relying on those values.
 */
public class ImageLoader
extends BaseImageLoader {
    /** Creates a loader that uses the settings inherited from {@link BaseImageLoader}. */
    public ImageLoader() {
    }

    /**
     * Creates a loader with a target image size.
     *
     * @param height target height in pixels (resizing only happens when height and width are both positive)
     * @param width  target width in pixels
     */
    public ImageLoader(int height, int width) {
        this.height = height;
        this.width = width;
    }

    /**
     * Creates a loader with a target image size and channel setting.
     *
     * @param height   target height in pixels
     * @param width    target width in pixels
     * @param channels channel setting; {@code 3} selects the BGR tensor code paths
     */
    public ImageLoader(int height, int width, int channels) {
        this.height = height;
        this.width = width;
        this.channels = channels;
    }

    /**
     * Creates a loader with a target size, channel setting and optional center cropping.
     *
     * @param centerCropIfNeeded when {@code true}, {@link #asRowVector(BufferedImage)} crops the
     *                           image to a centered square before resizing
     */
    public ImageLoader(int height, int width, int channels, boolean centerCropIfNeeded) {
        this(height, width, channels);
        this.centerCropIfNeeded = centerCropIfNeeded;
    }

    /**
     * Reads {@code f} and flattens it; see {@link #asRowVector(BufferedImage)}.
     *
     * @throws IOException           if the file cannot be read
     * @throws IllegalStateException if no registered ImageIO reader can decode the file
     */
    @Override
    public INDArray asRowVector(File f) throws IOException {
        return this.asRowVector(ImageIO.read(f));
    }

    /**
     * Reads an image from {@code inputStream} and flattens it; see {@link #asRowVector(BufferedImage)}.
     * The stream is not closed.
     *
     * @throws IOException           if the stream cannot be read
     * @throws IllegalStateException if no registered ImageIO reader can decode the data
     */
    @Override
    public INDArray asRowVector(InputStream inputStream) throws IOException {
        return this.asRowVector(ImageIO.read(inputStream));
    }

    /**
     * Flattens an image into a row vector.
     *
     * <p>The image is center-cropped first when {@code centerCropIfNeeded} is set, then resized.
     * With {@code channels == 3} the result is the raveled channel-first BGR tensor produced by
     * {@link #toINDArrayBGR}; otherwise it is the flattened per-pixel values of
     * {@link #toIntArrayArray}.
     *
     * @throws IllegalStateException if {@code image} is {@code null} (ImageIO found no decoder)
     */
    public INDArray asRowVector(BufferedImage image) {
        checkLoaded(image);
        if (this.centerCropIfNeeded) {
            image = this.centerCropIfNeeded(image);
        }
        image = this.scalingIfNeed(image, true);
        if (this.channels == 3) {
            return this.toINDArrayBGR(image).ravel();
        }
        int[][] ret = this.toIntArrayArray(image);
        return NDArrayUtil.toNDArray((int[]) ArrayUtil.flatten((int[][]) ret));
    }

    /**
     * Reads {@code file} and returns its BGR tensor raveled to one dimension.
     *
     * @throws RuntimeException wrapping any {@link IOException} raised while reading or closing
     */
    public INDArray toRaveledTensor(File file) {
        // try-with-resources: the stream used to leak when decoding threw.
        try (BufferedInputStream bis = new BufferedInputStream(new FileInputStream(file))) {
            return this.toRaveledTensor(bis);
        }
        catch (IOException e) {
            throw new RuntimeException(e);
        }
    }

    /** Reads an image from {@code is} and returns its BGR tensor raveled to one dimension. */
    public INDArray toRaveledTensor(InputStream is) {
        return this.toBgr(is).ravel();
    }

    /**
     * Resizes {@code image} if needed and returns its BGR tensor raveled to one dimension.
     *
     * @throws RuntimeException "Unable to load image", wrapping any failure during conversion
     */
    public INDArray toRaveledTensor(BufferedImage image) {
        try {
            image = this.scalingIfNeed(image, false);
            return this.toINDArrayBGR(image).ravel();
        }
        catch (Exception e) {
            throw new RuntimeException("Unable to load image", e);
        }
    }

    /**
     * Reads {@code file} and returns its channel-first BGR tensor; see {@link #toBgr(BufferedImage)}.
     *
     * @throws RuntimeException wrapping any {@link IOException} raised while reading or closing
     */
    public INDArray toBgr(File file) {
        // try-with-resources: the stream used to leak when decoding threw.
        try (BufferedInputStream bis = new BufferedInputStream(new FileInputStream(file))) {
            return this.toBgr(bis);
        }
        catch (IOException e) {
            throw new RuntimeException(e);
        }
    }

    /**
     * Reads an image from {@code inputStream} and returns its channel-first BGR tensor.
     *
     * @throws RuntimeException "Unable to load image", wrapping any {@link IOException}
     */
    public INDArray toBgr(InputStream inputStream) {
        try {
            BufferedImage image = ImageIO.read(inputStream);
            return this.toBgr(image);
        }
        catch (IOException e) {
            throw new RuntimeException("Unable to load image", e);
        }
    }

    /**
     * Reads an image and wraps its BGR tensor in a DataVec image.
     *
     * <p>The recorded band count, height and width are those of the decoded image before any
     * resizing.
     */
    private org.datavec.image.data.Image toBgrImage(InputStream inputStream) {
        try {
            BufferedImage image = ImageIO.read(inputStream);
            INDArray img = this.toBgr(image);
            return new org.datavec.image.data.Image(img, image.getSampleModel().getNumBands(), image.getHeight(), image.getWidth());
        }
        catch (IOException e) {
            throw new RuntimeException("Unable to load image", e);
        }
    }

    /**
     * Resizes {@code image} if needed and returns it as a {@code [bands, height, width]} tensor in
     * the image's byte order (B, G, R[, ...] for BGR images).
     *
     * @throws IllegalStateException if {@code image} is {@code null}
     */
    public INDArray toBgr(BufferedImage image) {
        checkLoaded(image);
        image = this.scalingIfNeed(image, false);
        return this.toINDArrayBGR(image);
    }

    /**
     * Returns the 2-D per-pixel matrix produced by {@link #fromFile(File)}.
     *
     * <p>NOTE(review): unlike {@link #asMatrix(InputStream)}, this ignores {@code channels} and
     * never returns the BGR tensor.
     */
    @Override
    public INDArray asMatrix(File f) throws IOException {
        return NDArrayUtil.toNDArray((int[][]) this.fromFile(f));
    }

    /**
     * Reads an image from {@code inputStream}. Returns the BGR tensor when {@code channels == 3},
     * otherwise the matrix from {@link #asMatrix(BufferedImage)}.
     *
     * @throws IOException "Unable to load image" if the stream cannot be read
     */
    @Override
    public INDArray asMatrix(InputStream inputStream) throws IOException {
        if (this.channels == 3) {
            return this.toBgr(inputStream);
        }
        try {
            BufferedImage image = ImageIO.read(inputStream);
            return this.asMatrix(image);
        }
        catch (IOException e) {
            throw new IOException("Unable to load image", e);
        }
    }

    /** Reads {@code f} into a DataVec image; see {@link #asImageMatrix(InputStream)}. */
    @Override
    public org.datavec.image.data.Image asImageMatrix(File f) throws IOException {
        try (BufferedInputStream bis = new BufferedInputStream(new FileInputStream(f))) {
            return this.asImageMatrix(bis);
        }
    }

    /**
     * Reads an image from {@code inputStream} and wraps its matrix in a DataVec image, recording the
     * decoded image's band count, height and width.
     *
     * @throws IOException           "Unable to load image" if the stream cannot be read
     * @throws IllegalStateException if no registered ImageIO reader can decode the data
     */
    @Override
    public org.datavec.image.data.Image asImageMatrix(InputStream inputStream) throws IOException {
        if (this.channels == 3) {
            return this.toBgrImage(inputStream);
        }
        try {
            BufferedImage image = ImageIO.read(inputStream);
            checkLoaded(image);
            INDArray asMatrix = this.asMatrix(image);
            // getSampleModel() avoids the full raster copy that getData() makes just to count bands.
            return new org.datavec.image.data.Image(asMatrix, image.getSampleModel().getNumBands(), image.getHeight(), image.getWidth());
        }
        catch (IOException e) {
            throw new IOException("Unable to load image", e);
        }
    }

    /**
     * Converts an image into a matrix.
     *
     * <p>With {@code channels == 3} this is {@link #toBgr(BufferedImage)}. Otherwise the image is
     * resized if needed and returned as a {@code [height, width]} matrix of packed
     * {@link BufferedImage#getRGB(int, int)} values.
     *
     * @throws IllegalStateException if {@code image} is {@code null}
     */
    public INDArray asMatrix(BufferedImage image) {
        checkLoaded(image);
        if (this.channels == 3) {
            return this.toBgr(image);
        }
        image = this.scalingIfNeed(image, true);
        int w = image.getWidth();
        int h = image.getHeight();
        INDArray ret = Nd4j.create(h, w);
        for (int i = 0; i < h; ++i) {
            for (int j = 0; j < w; ++j) {
                ret.putScalar(new int[]{i, j}, image.getRGB(j, i));
            }
        }
        return ret;
    }

    /**
     * Allocates a {@code [numMiniBatches, numRowsPerSlice, columns]} array, where {@code columns} is
     * the column count of {@link #asMatrix(File)} for {@code f}.
     *
     * <p>NOTE(review): the image is only used for its column count. None of its pixel data is
     * copied into the returned array; confirm this is intended.
     */
    public INDArray asImageMiniBatches(File f, int numMiniBatches, int numRowsPerSlice) {
        try {
            INDArray d = this.asMatrix(f);
            return Nd4j.create(new int[]{numMiniBatches, numRowsPerSlice, d.columns()});
        }
        catch (Exception e) {
            throw new RuntimeException(e);
        }
    }

    /** Returns {@link #fromFile(File)} flattened row by row. */
    public int[] flattenedImageFromFile(File f) throws IOException {
        return ArrayUtil.flatten((int[][]) this.fromFile(f));
    }

    /**
     * Reads {@code file}, resizes it if needed and returns its {@code [height][width]} pixel values
     * (see {@link #toIntArrayArray}).
     *
     * @throws IOException if the file cannot be read or no registered ImageIO reader can decode it
     */
    public int[][] fromFile(File file) throws IOException {
        BufferedImage image = ImageIO.read(file);
        if (image == null) {
            throw new IOException("Unable to load image: " + file);
        }
        image = this.scalingIfNeed(image, true);
        return this.toIntArrayArray(image);
    }

    /**
     * Reads {@code file} and splits it into {@code [channels][height][width]} unsigned byte samples.
     *
     * <p>Only the first {@code min(channels, bands)} planes are filled; any remaining planes stay
     * zero. Samples appear in the image's byte order (B, G, R[, ...] for BGR images).
     *
     * @throws IOException if the file cannot be read or no registered ImageIO reader can decode it
     */
    public int[][][] fromFileMultipleChannels(File file) throws IOException {
        BufferedImage image = ImageIO.read(file);
        if (image == null) {
            throw new IOException("Unable to load image: " + file);
        }
        image = toByteInterleaved(this.scalingIfNeed(image, this.channels > 3));
        int w = image.getWidth();
        int h = image.getHeight();
        int bands = image.getSampleModel().getNumBands();
        int[][][] ret = new int[this.channels][h][w];
        byte[] pixels = ((DataBufferByte) image.getRaster().getDataBuffer()).getData();
        int planes = Math.min(this.channels, bands);
        for (int i = 0; i < h; ++i) {
            for (int j = 0; j < w; ++j) {
                // Pixels are interleaved with `bands` bytes each; striding by `channels`
                // read the wrong pixels (or ran off the buffer) whenever the two differed.
                int base = (i * w + j) * bands;
                for (int k = 0; k < planes; ++k) {
                    // Mask to 0..255 (as toINDArrayBGR does) instead of sign-extending the byte.
                    ret[k][i][j] = pixels[base + k] & 0xFF;
                }
            }
        }
        return ret;
    }

    /**
     * Builds a {@code TYPE_INT_ARGB} image from a matrix of packed ARGB values, truncated to int.
     *
     * <p>Matrix row {@code r} becomes image row {@code y = r} and column {@code c} becomes
     * {@code x = c}, so a {@code [height, width]} matrix gives a {@code width x height} image.
     * Elements are read by linear index, assumed to be row-major (c order); confirm for
     * f-ordered inputs.
     */
    public static BufferedImage toImage(INDArray matrix) {
        int rows = matrix.rows();
        int columns = matrix.columns();
        // Width is the column count: using rows as width transposed/scrambled non-square inputs.
        BufferedImage img = new BufferedImage(columns, rows, BufferedImage.TYPE_INT_ARGB);
        WritableRaster r = img.getRaster();
        int[] equiv = new int[matrix.length()];
        for (int i = 0; i < equiv.length; ++i) {
            equiv[i] = (int) matrix.getDouble(i);
        }
        r.setDataElements(0, 0, columns, rows, equiv);
        return img;
    }

    /**
     * Writes a channel-first array into {@code image} as opaque RGB pixels.
     *
     * <p>Slices 0, 1 and 2 of {@code arr} are read as blue, green and red. Each value is masked to
     * 8 bits so it cannot bleed into a neighbouring channel.
     *
     * @throws IllegalArgumentException if {@code arr} has rank below 3, or if its last two
     *                                  dimensions do not match the image's height and width
     */
    public void toBufferedImageRGB(INDArray arr, BufferedImage image) {
        if (arr.rank() < 3) {
            throw new IllegalArgumentException("Arr must be 3d");
        }
        int rows = arr.size(-2);
        int cols = arr.size(-1);
        // Write straight into the caller's image. Previously a rescaled/converted copy was
        // written and then discarded, so the caller's image was silently left untouched.
        if (image.getHeight() != rows || image.getWidth() != cols) {
            throw new IllegalArgumentException("Image is " + image.getHeight() + "x" + image.getWidth()
                    + " but array is " + rows + "x" + cols);
        }
        INDArray blue = arr.slice(0);
        INDArray green = arr.slice(1);
        INDArray red = arr.slice(2);
        for (int i = 0; i < rows; ++i) {
            for (int j = 0; j < cols; ++j) {
                int r = red.getInt(new int[]{i, j}) & 0xFF;
                int g = green.getInt(new int[]{i, j}) & 0xFF;
                int b = blue.getInt(new int[]{i, j}) & 0xFF;
                // Fully opaque alpha (0xFF); the old value 1 made ARGB output nearly transparent.
                image.setRGB(j, i, 0xFF << 24 | r << 16 | g << 8 | b);
            }
        }
    }

    /**
     * Returns {@code img} itself if it is already a {@link BufferedImage}; otherwise draws it into a
     * new {@link BufferedImage} of the given {@code type}.
     */
    public static BufferedImage toBufferedImage(Image img, int type) {
        if (img instanceof BufferedImage) {
            return (BufferedImage) img;
        }
        BufferedImage bimage = new BufferedImage(img.getWidth(null), img.getHeight(null), type);
        Graphics2D bGr = bimage.createGraphics();
        try {
            bGr.drawImage(img, 0, 0, null);
        }
        finally {
            bGr.dispose();
        }
        return bimage;
    }

    /**
     * Returns the image as a {@code [height][width]} int array.
     *
     * <p>When the raster stores one data element per pixel, the value is the band-0 sample.
     * Otherwise it is the packed {@link BufferedImage#getRGB(int, int)} value.
     */
    protected int[][] toIntArrayArray(BufferedImage image) {
        int w = image.getWidth();
        int h = image.getHeight();
        int[][] ret = new int[h][w];
        WritableRaster raster = image.getRaster();
        boolean singleElement = raster.getNumDataElements() == 1;
        for (int i = 0; i < h; ++i) {
            for (int j = 0; j < w; ++j) {
                ret[i][j] = singleElement ? raster.getSample(j, i, 0) : image.getRGB(j, i);
            }
        }
        return ret;
    }

    /**
     * Returns the image's raw interleaved bytes as an unsigned {@code [bands, height, width]} tensor.
     * Non-byte-backed images and sub-images are first copied into a compact byte image.
     */
    protected INDArray toINDArrayBGR(BufferedImage image) {
        image = toByteInterleaved(image);
        int height = image.getHeight();
        int width = image.getWidth();
        int bands = image.getSampleModel().getNumBands();
        byte[] pixels = ((DataBufferByte) image.getRaster().getDataBuffer()).getData();
        int[] shape = new int[]{height, width, bands};
        INDArray ret2 = Nd4j.create(1, pixels.length);
        for (int i = 0; i < ret2.length(); ++i) {
            ret2.putScalar(i, pixels[i] & 0xFF);
        }
        return ret2.reshape(shape).permute(new int[]{2, 0, 1});
    }

    /**
     * Returns {@code image} when its backing byte array holds exactly its own interleaved pixels.
     * Otherwise returns a compact, byte-backed copy.
     *
     * <p>Two cases need the copy:
     * <ul>
     *   <li>Int- or short-backed images, because the cast to {@link DataBufferByte} would throw.</li>
     *   <li>{@link BufferedImage#getSubimage} results (such as center crops), because they share
     *       their parent's whole buffer.</li>
     * </ul>
     */
    private static BufferedImage toByteInterleaved(BufferedImage image) {
        WritableRaster raster = image.getRaster();
        int w = image.getWidth();
        int h = image.getHeight();
        if (raster.getDataBuffer() instanceof DataBufferByte
                && raster.getParent() == null
                && ((DataBufferByte) raster.getDataBuffer()).getData().length == w * h * raster.getNumBands()) {
            return image;
        }
        int type = image.getType();
        if (type == BufferedImage.TYPE_3BYTE_BGR || type == BufferedImage.TYPE_4BYTE_ABGR
                || type == BufferedImage.TYPE_BYTE_GRAY) {
            // Same pixel layout: copy samples exactly into a fresh, compact buffer.
            BufferedImage copy = new BufferedImage(w, h, type);
            copy.setData(raster);
            return copy;
        }
        int target = image.getColorModel().hasAlpha() ? BufferedImage.TYPE_4BYTE_ABGR : BufferedImage.TYPE_3BYTE_BGR;
        BufferedImage converted = new BufferedImage(w, h, target);
        Graphics2D g = converted.createGraphics();
        try {
            g.drawImage(image, 0, 0, null);
        }
        finally {
            g.dispose();
        }
        return converted;
    }

    /** Throws the loader's "Unable to load image" error when ImageIO produced no image. */
    private static void checkLoaded(BufferedImage image) {
        if (image == null) {
            throw new IllegalStateException("Unable to load image");
        }
    }

    /**
     * Crops the largest centered square out of {@code img}.
     *
     * <p>An image that is already square is returned unchanged. Otherwise the result is a
     * {@link BufferedImage#getSubimage sub-image}, which shares pixel data with {@code img}.
     */
    public BufferedImage centerCropIfNeeded(BufferedImage img) {
        int width = img.getWidth();
        int height = img.getHeight();
        if (width == height) {
            return img;
        }
        // Trim the long side to the short side, centered. The old code removed only half the
        // excess (from one side), giving an off-center, non-square crop.
        int side = Math.min(width, height);
        int x = (width - side) / 2;
        int y = (height - side) / 2;
        return img.getSubimage(x, y, side, side);
    }

    /** Resizes {@code image} to this loader's {@code height} x {@code width} if needed. */
    protected BufferedImage scalingIfNeed(BufferedImage image, boolean needAlpha) {
        return this.scalingIfNeed(image, this.height, this.width, needAlpha);
    }

    /**
     * Resizes {@code image} to {@code dstWidth} x {@code dstHeight} when both are positive and the
     * size differs.
     *
     * <p>A resized image is smooth-scaled and materialized as
     * <ul>
     *   <li>{@code TYPE_4BYTE_ABGR} when {@code needAlpha} is set, the source has alpha and
     *       {@code channels == TYPE_4BYTE_ABGR};</li>
     *   <li>otherwise {@code TYPE_BYTE_GRAY} when {@code channels == TYPE_BYTE_GRAY};</li>
     *   <li>otherwise {@code TYPE_3BYTE_BGR}.</li>
     * </ul>
     * When no resize is needed the image is returned unchanged.
     */
    protected BufferedImage scalingIfNeed(BufferedImage image, int dstHeight, int dstWidth, boolean needAlpha) {
        boolean resize = dstHeight > 0 && dstWidth > 0
                && (image.getHeight() != dstHeight || image.getWidth() != dstWidth);
        if (!resize) {
            // toBufferedImage(...) hands BufferedImage inputs back as-is, so the unscaled path
            // always returned the original image (and pixel type) even before this rewrite.
            return image;
        }
        Image scaled = image.getScaledInstance(dstWidth, dstHeight, Image.SCALE_SMOOTH);
        return ImageLoader.toBufferedImage(scaled, this.targetImageType(image, needAlpha));
    }

    /** Chooses the {@link BufferedImage} type for a resized copy of {@code image}. */
    private int targetImageType(BufferedImage image, boolean needAlpha) {
        if (needAlpha && image.getColorModel().hasAlpha() && this.channels == BufferedImage.TYPE_4BYTE_ABGR) {
            return BufferedImage.TYPE_4BYTE_ABGR;
        }
        if (this.channels == BufferedImage.TYPE_BYTE_GRAY) {
            return BufferedImage.TYPE_BYTE_GRAY;
        }
        return BufferedImage.TYPE_3BYTE_BGR;
    }

    // Registers the TIFF, JPEG, PSD and BMP/CUR/ICO plugins with the default ImageIO registry.
    static {
        ImageIO.scanForPlugins();
        IIORegistry registry = IIORegistry.getDefaultInstance();
        registry.registerServiceProvider(new TIFFImageWriterSpi());
        registry.registerServiceProvider(new TIFFImageReaderSpi());
        registry.registerServiceProvider(new JPEGImageReaderSpi());
        registry.registerServiceProvider(new JPEGImageWriterSpi());
        registry.registerServiceProvider(new PSDImageReaderSpi());
        // registerServiceProvider(Object) would register the List object itself, which implements
        // no ImageIO SPI category, so these readers were never actually registered.
        registry.registerServiceProviders(Arrays.asList(
                new BMPImageReaderSpi(), new CURImageReaderSpi(), new ICOImageReaderSpi()).iterator());
    }
}

