/*
 * Decompiled with CFR 0.152.
 */
package deepnetts.data;

import deepnetts.data.MLDataItem;
import deepnetts.util.ImageUtils;
import deepnetts.util.Tensor;
import java.awt.Graphics2D;
import java.awt.image.BufferedImage;
import java.awt.image.WritableRaster;
import java.io.File;
import java.io.IOException;
import javax.imageio.ImageIO;

/**
 * A single image example, optionally labeled, converted into a normalized RGB input vector/tensor.
 *
 * <p>Pixel data is stored channel-planar: the first {@code width * height} entries hold the red
 * channel, the next {@code width * height} the green channel and the last {@code width * height}
 * the blue channel, each in row-major order ({@code y * width + x}) and scaled to [0, 1].
 * The alpha channel is ignored.
 */
public class ExampleImage
implements MLDataItem {
    private final int width;
    private final int height;
    private final String label;
    private Tensor targetOutput;
    private float[] rgbVector;
    private Tensor rgbTensor;
    /** Source file of the image, or {@code null} when constructed from an in-memory image. */
    private final File file;

    /**
     * Loads an image from a file and converts it to input data.
     *
     * @param imgFile image file to read
     * @param label class label for this example (may be {@code null})
     * @throws IOException if the file cannot be read or no installed ImageIO reader can decode it
     */
    public ExampleImage(File imgFile, String label) throws IOException {
        this.label = label;
        this.file = imgFile;
        BufferedImage image = ImageIO.read(imgFile);
        // ImageIO.read signals "no suitable reader" by returning null rather than throwing.
        if (image == null) {
            throw new IOException("Unsupported or unreadable image file: " + imgFile);
        }
        this.width = image.getWidth();
        this.height = image.getHeight();
        this.createInputFromPixels(image);
    }

    /**
     * Creates an example from an in-memory image, keeping its original dimensions.
     *
     * @param image source image
     * @param label class label for this example (may be {@code null})
     */
    public ExampleImage(BufferedImage image, String label) {
        this.label = label;
        this.file = null;
        this.width = image.getWidth();
        this.height = image.getHeight();
        this.createInputFromPixels(image);
    }

    /**
     * Creates an unlabeled example from an in-memory image.
     *
     * @param image source image
     */
    public ExampleImage(BufferedImage image) {
        this(image, null);
    }

    /**
     * Creates an example from an in-memory image, scaling it to the given dimensions if needed.
     *
     * @param image source image
     * @param label class label for this example (may be {@code null})
     * @param targetWidth width of the resulting input
     * @param targetHeight height of the resulting input
     * @throws IOException declared for compatibility with existing callers
     */
    public ExampleImage(BufferedImage image, String label, int targetWidth, int targetHeight) throws IOException {
        this.label = label;
        this.file = null;
        this.width = targetWidth;
        this.height = targetHeight;
        if (image.getWidth() != targetWidth || image.getHeight() != targetHeight) {
            image = ImageUtils.scaleImage(image, targetWidth, targetHeight);
        }
        this.createInputFromPixels(image);
    }

    /**
     * Fills {@link #rgbVector} (channel-planar R, G, B, scaled to [0, 1]) from the image pixels and
     * wraps it in {@link #rgbTensor}.
     */
    private void createInputFromPixels(BufferedImage image) {
        final int planeSize = this.width * this.height;
        this.rgbVector = new float[planeSize * 3];
        // Normalize to TYPE_INT_ARGB so raster bands are always R, G, B, A with 8-bit samples.
        if (image.getType() != BufferedImage.TYPE_INT_ARGB) {
            BufferedImage imageCopy = new BufferedImage(image.getWidth(), image.getHeight(), BufferedImage.TYPE_INT_ARGB);
            Graphics2D g = imageCopy.createGraphics();
            try {
                g.drawImage(image, 0, 0, null);
            } finally {
                g.dispose(); // release native graphics resources promptly
            }
            image = imageCopy;
        }
        WritableRaster raster = image.getRaster();
        float[] pixel = null; // reused buffer returned by getPixel
        for (int y = 0; y < this.height; ++y) {
            final int rowOffset = y * this.width;
            for (int x = 0; x < this.width; ++x) {
                pixel = raster.getPixel(x, y, pixel);
                final int idx = rowOffset + x;
                this.rgbVector[idx] = pixel[0] / 255.0f;
                this.rgbVector[planeSize + idx] = pixel[1] / 255.0f;
                this.rgbVector[2 * planeSize + idx] = pixel[2] / 255.0f;
            }
        }
        // NOTE(review): assumes Tensor(rows, cols, depth, values) interprets values in the
        // channel-planar layout built above — confirm against Tensor's indexing.
        this.rgbTensor = new Tensor(this.height, this.width, 3, this.rgbVector);
    }

    /**
     * Inverts every normalized channel value in place ({@code v -> 1 - v}).
     *
     * <p>NOTE(review): this mutates {@link #getRgbVector()}'s array; whether {@link #getInput()}
     * reflects the change depends on whether {@code Tensor} wraps or copies the array — confirm.
     */
    public void invert() {
        for (int i = 0; i < this.rgbVector.length; ++i) {
            this.rgbVector[i] = 1.0f - this.rgbVector[i];
        }
    }

    /** @return the expected output for this example, or {@code null} if not set */
    @Override
    public Tensor getTargetOutput() {
        return this.targetOutput;
    }

    /**
     * Returns the internal channel-planar RGB array (not a copy; modifications are visible).
     *
     * @return normalized RGB values
     */
    public float[] getRgbVector() {
        return this.rgbVector;
    }

    /** @param targetOutput expected output for this example */
    public final void setTargetOutput(Tensor targetOutput) {
        this.targetOutput = targetOutput;
    }

    /** @return width of the input in pixels */
    public int getWidth() {
        return this.width;
    }

    /** @return height of the input in pixels */
    public int getHeight() {
        return this.height;
    }

    /** @return class label of this example, possibly {@code null} */
    public String getLabel() {
        return this.label;
    }

    /** @return the input tensor built from the image's RGB values */
    @Override
    public Tensor getInput() {
        return this.rgbTensor;
    }

    /** @return the source image file, or {@code null} if created from an in-memory image */
    public File getFile() {
        return this.file;
    }
}

