/*
 * Decompiled with CFR 0.152.
 */
package deepnetts.data;

import deepnetts.core.DeepNetts;
import deepnetts.data.ExampleImage;
import deepnetts.data.TabularDataSet;
import deepnetts.util.DeepNettsException;
import deepnetts.util.ImageSetUtils;
import deepnetts.util.ImageUtils;
import deepnetts.util.Tensor;
import java.awt.image.BufferedImage;
import java.io.BufferedReader;
import java.io.File;
import java.io.FileNotFoundException;
import java.io.FileReader;
import java.io.IOException;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.HashMap;
import java.util.Iterator;
import java.util.LinkedList;
import java.util.List;
import java.util.Map;
import java.util.Objects;
import java.util.concurrent.Callable;
import javax.imageio.ImageIO;
import javax.visrec.ml.data.Column;
import javax.visrec.ml.data.DataSet;
import org.apache.logging.log4j.LogManager;
import org.apache.logging.log4j.Logger;

public class ImageSet
extends TabularDataSet<ExampleImage> {
    private final int imageWidth;
    private final int imageHeight;
    private boolean scaleImages = true;
    private boolean invertImages = false;
    private Tensor mean;
    private String delimiter = " ";
    private static String NEGATIVE_LABEL = "negative";
    private static final Logger LOGGER = LogManager.getLogger((String)DeepNetts.class.getName());
    private final Object LOCK = new Object();

    public ImageSet(int imageWidth, int imageHeight) {
        this.imageWidth = imageWidth;
        this.imageHeight = imageHeight;
    }

    public ImageSet(int imageWidth, int imageHeight, String imageDirPath) throws IOException {
        this.imageWidth = imageWidth;
        this.imageHeight = imageHeight;
        ImageSetUtils.createImageIndex(imageDirPath);
        ImageSetUtils.createLabelsIndex(imageDirPath);
        this.setScaleImages(true);
    }

    /*
     * WARNING - Removed try catching itself - possible behaviour change.
     */
    public DataSet<ExampleImage> add(ExampleImage exImage) throws DeepNettsException {
        if (exImage == null) {
            throw new DeepNettsException("Example image cannot be null!");
        }
        Object object = this.LOCK;
        synchronized (object) {
            this.items.add(exImage);
        }
        return this;
    }

    public void loadImages(String imageIdxFile) throws FileNotFoundException {
        this.loadImages(new File(imageIdxFile));
    }

    public void loadImages(File imageIdxFile) throws FileNotFoundException {
        Objects.requireNonNull(imageIdxFile, "Index file cannot be null!");
        if (this.columnNames == null) {
            throw new DeepNettsException("Error: Labels are not loaded. In order to load images correctly you have to load labels first using ImageSet.loadLabels method.");
        }
        String rootPath = imageIdxFile.getPath().substring(0, imageIdxFile.getPath().lastIndexOf(File.separator));
        String imgFileName = null;
        String label = null;
        LinkedList<BufferedImage> images = new LinkedList<BufferedImage>();
        LinkedList<String> labels = new LinkedList<String>();
        try (BufferedReader br = new BufferedReader(new FileReader(imageIdxFile));){
            String line = null;
            int lineCount = 0;
            while ((line = br.readLine()) != null) {
                ++lineCount;
                if (line.isEmpty()) continue;
                String[] parts = line.split(this.delimiter);
                if (parts.length > 2) {
                    throw new DeepNettsException("Bad file format: image paths and labels should not contain spaces! At line " + lineCount);
                }
                imgFileName = parts[0];
                if (parts.length == 2) {
                    label = parts[1];
                } else if (parts.length == 1) {
                    int labelEndIdx = imgFileName.lastIndexOf(File.separator);
                    label = imgFileName.substring(0, labelEndIdx);
                }
                imgFileName = rootPath + File.separator + imgFileName;
                BufferedImage image = ImageIO.read(new File(imgFileName));
                images.add(image);
                labels.add(label);
            }
            this.processImages(images, labels);
            if (this.isEmpty()) {
                throw new DeepNettsException("Zero images loaded!");
            }
            LOGGER.info("Loaded " + this.size() + " images");
        }
        catch (FileNotFoundException ex) {
            LOGGER.error((Object)ex);
            throw new DeepNettsException("Could not find image file: " + imgFileName, ex);
        }
        catch (IOException ex) {
            LOGGER.error((Object)ex);
            throw new DeepNettsException("Error loading image file: " + imgFileName, ex);
        }
        catch (NullPointerException ex) {
            LOGGER.error((Object)ex);
            throw new DeepNettsException("Error loading image file: " + imgFileName, ex);
        }
    }

    public void loadImages(File imageIdxFile, int numOfImages) throws DeepNettsException {
        Objects.requireNonNull(imageIdxFile, "Index file cannot be null!");
        if (this.columnNames == null) {
            throw new DeepNettsException("Error: Labels are not loaded. In order to load images correctly you have to load labels first using ImageSet.loadLabels method.");
        }
        String rootPath = imageIdxFile.getPath().substring(0, imageIdxFile.getPath().lastIndexOf(File.separator));
        String imgFileName = null;
        String label = null;
        LinkedList<BufferedImage> images = new LinkedList<BufferedImage>();
        LinkedList<String> labels = new LinkedList<String>();
        try (BufferedReader br = new BufferedReader(new FileReader(imageIdxFile));){
            String line = null;
            for (int i = 0; i < numOfImages; ++i) {
                line = br.readLine();
                if (line.isEmpty()) continue;
                String[] parts = line.split(this.delimiter);
                if (parts.length > 2) {
                    throw new DeepNettsException("Bad file format: image paths and labels should not contain spaces! At line " + i);
                }
                imgFileName = parts[0];
                if (parts.length == 2) {
                    label = parts[1];
                } else if (parts.length == 1) {
                    int labelEndIdx = imgFileName.lastIndexOf(File.separator);
                    label = imgFileName.substring(0, labelEndIdx);
                }
                imgFileName = rootPath + File.separator + imgFileName;
                BufferedImage image = ImageIO.read(new File(imgFileName));
                images.add(image);
                labels.add(label);
            }
            this.processImages(images, labels);
        }
        catch (FileNotFoundException ex) {
            LOGGER.error((Object)ex);
            throw new DeepNettsException("Could not find image file: " + imgFileName, ex);
        }
        catch (IOException ex) {
            LOGGER.error((Object)ex);
            throw new DeepNettsException("Error loading image file: " + imgFileName, ex);
        }
        if (this.isEmpty()) {
            throw new DeepNettsException("Zero images loaded!");
        }
        LOGGER.info("Loaded " + this.size() + " images");
    }

    private void processImages(List<BufferedImage> images, List<String> labels) throws IOException {
        for (int i = 0; i < images.size(); ++i) {
            BufferedImage img = images.get(i);
            String lbl = labels.get(i);
            if (this.scaleImages) {
                img = ImageUtils.scaleImage(img, this.imageWidth, this.imageHeight);
            }
            ExampleImage exImg = new ExampleImage(img, lbl);
            exImg.setTargetOutput(new Tensor(this.oneHotEncode(lbl, this.columnNames)));
            if (this.invertImages) {
                exImg.invert();
            }
            this.add(exImg);
        }
    }

    public void invert() {
        throw new UnsupportedOperationException("Not supported yet.");
    }

    private float[] oneHotEncode(String label, String[] labels) {
        float[] returnArr = new float[labels.length];
        if (label.equalsIgnoreCase(NEGATIVE_LABEL)) {
            return returnArr;
        }
        for (int i = 0; i < labels.length; ++i) {
            if (!labels[i].equals(label)) continue;
            returnArr[i] = 1.0f;
        }
        return returnArr;
    }

    public int getLabelsCount() {
        return this.columnNames.length;
    }

    public ImageSet[] split(double ... partSizes) {
        if (partSizes.length < 2) {
            throw new IllegalArgumentException("Must specify at least two parts");
        }
        int partsSum = 0;
        for (int i = 0; i < partSizes.length; ++i) {
            if (partSizes[i] <= 0.0) {
                throw new IllegalArgumentException("Value of the part cannot be zero or negative!");
            }
            partsSum = (int)((double)partsSum + partSizes[i]);
        }
        if (partsSum > 1) {
            throw new IllegalArgumentException("Sum of parts/percents cannot be larger than 1!");
        }
        LOGGER.info("Splitting data set: " + Arrays.toString(partSizes));
        ImageSet[] subSets = new ImageSet[partSizes.length];
        int itemIdx = 0;
        for (int p = 0; p < partSizes.length; ++p) {
            ImageSet subSet = new ImageSet(this.imageWidth, this.imageHeight);
            int itemsCount = (int)((double)this.size() * partSizes[p]);
            for (int j = 0; j < itemsCount; ++j) {
                subSet.add((ExampleImage)this.items.get(itemIdx));
                ++itemIdx;
            }
            subSets[p] = subSet;
            subSet.setColumnNames(this.columnNames);
            subSet.setColumns(this.getColumns());
            subSet.setColumns(this.getColumns());
        }
        return subSets;
    }

    public String[] loadLabels(String filePath) throws DeepNettsException {
        return this.loadLabels(new File(filePath));
    }

    public String[] loadLabels(File file) throws DeepNettsException {
        String[] stringArray;
        BufferedReader br = new BufferedReader(new FileReader(file));
        try {
            String line = null;
            ArrayList<String> labelsList = new ArrayList<String>();
            while ((line = br.readLine()) != null) {
                if (line.isEmpty()) continue;
                if ((line = line.trim()).contains(" ")) {
                    throw new DeepNettsException("Bad label format: Labels should not contain space characters! For label:" + line);
                }
                labelsList.add(line);
                this.getColumns().add(new Column(line, Column.Type.BINARY, true));
            }
            this.columnNames = labelsList.toArray(new String[labelsList.size()]);
            this.setAsTargetColumns(this.columnNames);
            LOGGER.info("Loaded " + labelsList.size() + " labels");
            stringArray = this.columnNames;
        }
        catch (Throwable throwable) {
            try {
                try {
                    br.close();
                }
                catch (Throwable throwable2) {
                    throwable.addSuppressed(throwable2);
                }
                throw throwable;
            }
            catch (FileNotFoundException ex) {
                LOGGER.error("Could not find labels file: " + file.getAbsolutePath(), (Throwable)ex);
                throw new DeepNettsException("Could not find labels file: " + file.getAbsolutePath(), ex);
            }
            catch (IOException ex) {
                LOGGER.error("Error reading labels file: " + file.getAbsolutePath(), (Throwable)ex);
                throw new DeepNettsException("Error reading labels file: " + file.getAbsolutePath(), ex);
            }
        }
        br.close();
        return stringArray;
    }

    public Tensor zeroMean() {
        this.mean = new Tensor(this.imageHeight, this.imageWidth, 3);
        this.items.forEach(img -> this.mean.add(img.getInput()));
        this.mean.div(this.items.size());
        for (ExampleImage image : this.items) {
            image.getInput().sub(this.mean);
        }
        return this.mean;
    }

    public boolean getScaleImages() {
        return this.scaleImages;
    }

    public final void setScaleImages(boolean scaleImages) {
        this.scaleImages = scaleImages;
    }

    public boolean getInvertImages() {
        return this.invertImages;
    }

    public void setInvertImages(boolean invertImages) {
        this.invertImages = invertImages;
    }

    public Map<String, Integer> countByClasses() {
        HashMap<String, Integer> map = new HashMap<String, Integer>();
        for (ExampleImage item : this.items) {
            if (map.containsKey(item.getLabel())) {
                String key = item.getLabel();
                map.put(key, map.get(key) + 1);
                continue;
            }
            map.put(item.getLabel(), 0);
        }
        LOGGER.info("Number of images by label/class");
        for (String key : map.keySet()) {
            LOGGER.info(key + " : " + map.get(key));
        }
        return map;
    }

    public String getDelimiter() {
        return this.delimiter;
    }

    public void setDelimiter(String delimiter) {
        this.delimiter = delimiter;
    }

    private class ImageProcessor
    implements Callable<Boolean> {
        private final List<BufferedImage> images;
        private final List<String> labels;
        private final int start;
        private final int end;

        public ImageProcessor(List<BufferedImage> images, List<String> labels, int start, int end) {
            this.images = images;
            this.labels = labels;
            this.start = start;
            this.end = end;
        }

        @Override
        public Boolean call() throws IOException {
            Iterator<String> li = this.labels.iterator();
            for (int i = this.start; i < this.end; ++i) {
                BufferedImage img = this.images.get(i);
                String lbl = li.next();
                if (ImageSet.this.scaleImages) {
                    img = ImageUtils.scaleImage(img, ImageSet.this.imageWidth, ImageSet.this.imageHeight);
                }
                ExampleImage exImg = new ExampleImage(img, lbl);
                exImg.setTargetOutput(new Tensor(ImageSet.this.oneHotEncode(lbl, ImageSet.this.columnNames)));
                if (ImageSet.this.invertImages) {
                    exImg.invert();
                }
                ImageSet.this.add(exImg);
            }
            return Boolean.TRUE;
        }
    }
}

