/*
 * Decompiled with CFR 0.152.
 */
package org.datavec.image.loader;

import java.io.BufferedReader;
import java.io.File;
import java.io.FileInputStream;
import java.io.FileReader;
import java.io.IOException;
import java.io.InputStream;
import java.io.SequenceInputStream;
import java.io.Serializable;
import java.nio.ByteBuffer;
import java.util.ArrayList;
import java.util.Collection;
import java.util.HashMap;
import java.util.Iterator;
import java.util.List;
import java.util.Map;
import java.util.Random;
import org.apache.commons.io.FileUtils;
import org.apache.commons.io.FilenameUtils;
import org.bytedeco.javacpp.opencv_core;
import org.bytedeco.javacv.OpenCVFrameConverter;
import org.datavec.api.berkeley.Pair;
import org.datavec.image.data.ImageWritable;
import org.datavec.image.loader.NativeImageLoader;
import org.datavec.image.transform.ColorConversionTransform;
import org.datavec.image.transform.EqualizeHistTransform;
import org.datavec.image.transform.ImageTransform;
import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.api.ops.Accumulation;
import org.nd4j.linalg.api.ops.impl.accum.Sum;
import org.nd4j.linalg.dataset.DataSet;
import org.nd4j.linalg.factory.Nd4j;
import org.nd4j.linalg.util.FeatureUtil;

public class CifarLoader
extends NativeImageLoader
implements Serializable {
    public static final int NUM_TRAIN_IMAGES = 50000;
    public static final int NUM_TEST_IMAGES = 10000;
    public static final int NUM_LABELS = 10;
    public static final int HEIGHT = 32;
    public static final int WIDTH = 32;
    public static final int CHANNELS = 3;
    public static final int BYTEFILELEN = 3073;
    public static String dataBinUrl = "https://www.cs.toronto.edu/~kriz/cifar-10-binary.tar.gz";
    public static String localDir = "cifar";
    public static String dataBinFile = "cifar-10-batches-bin";
    public static File fullDir = new File(BASE_DIR, FilenameUtils.concat((String)localDir, (String)dataBinFile));
    public static File meanVarPath = new File(fullDir, "meanVarPath.txt");
    protected static String labelFileName = "batches.meta.txt";
    protected static InputStream inputStream;
    protected static InputStream trainInputStream;
    protected static InputStream testInputStream;
    protected static List<DataSet> inputBatched;
    protected static List<String> labels;
    public static String[] TRAINFILENAMES;
    public static String TESTFILENAME;
    protected static String trainFilesSerialized;
    protected static String testFilesSerialized;
    protected static boolean train;
    public static boolean useSpecialPreProcessCifar;
    public static Map<String, String> cifarDataMap;
    protected static int height;
    protected static int width;
    protected static int channels;
    protected static long seed;
    protected static boolean shuffle;
    protected int numExamples = 0;
    protected static int numToConvertDS;
    protected double uMean = 0.0;
    protected double uStd = 0.0;
    protected double vMean = 0.0;
    protected double vStd = 0.0;
    protected boolean meanStdStored = false;
    protected int loadDSIndex = 0;
    protected DataSet loadDS = new DataSet();
    protected int fileNum = 0;

    public CifarLoader() {
        this(height, width, channels, null, train, useSpecialPreProcessCifar, fullDir, seed, shuffle);
    }

    public CifarLoader(boolean train) {
        this(height, width, channels, null, train, useSpecialPreProcessCifar, fullDir, seed, shuffle);
    }

    public CifarLoader(boolean train, File fullPath) {
        this(height, width, channels, null, train, useSpecialPreProcessCifar, fullPath, seed, shuffle);
    }

    public CifarLoader(int height, int width, int channels, boolean train, boolean useSpecialPreProcessCifar) {
        this(height, width, channels, null, train, useSpecialPreProcessCifar, fullDir, seed, shuffle);
    }

    public CifarLoader(int height, int width, int channels, ImageTransform imgTransform, boolean train, boolean useSpecialPreProcessCifar) {
        this(height, width, channels, imgTransform, train, useSpecialPreProcessCifar, fullDir, seed, shuffle);
    }

    public CifarLoader(int height, int width, int channels, ImageTransform imgTransform, boolean train, boolean useSpecialPreProcessCifar, boolean shuffle) {
        this(height, width, channels, imgTransform, train, useSpecialPreProcessCifar, fullDir, seed, shuffle);
    }

    public CifarLoader(int height, int width, int channels, ImageTransform imgTransform, boolean train, boolean useSpecialPreProcessCifar, File fullPath, long seed, boolean shuffle) {
        super(height, width, channels, imgTransform);
        CifarLoader.height = height;
        CifarLoader.width = width;
        CifarLoader.channels = channels;
        CifarLoader.train = train;
        CifarLoader.useSpecialPreProcessCifar = useSpecialPreProcessCifar;
        fullDir = fullPath;
        CifarLoader.seed = seed;
        CifarLoader.shuffle = shuffle;
        this.load();
    }

    @Override
    public INDArray asRowVector(File f) throws IOException {
        throw new UnsupportedOperationException();
    }

    @Override
    public INDArray asRowVector(InputStream inputStream) throws IOException {
        throw new UnsupportedOperationException();
    }

    @Override
    public INDArray asMatrix(File f) throws IOException {
        throw new UnsupportedOperationException();
    }

    @Override
    public INDArray asMatrix(InputStream inputStream) throws IOException {
        throw new UnsupportedOperationException();
    }

    public void generateMaps() {
        cifarDataMap.put("filesFilename", new File(dataBinUrl).getName());
        cifarDataMap.put("filesURL", dataBinUrl);
        cifarDataMap.put("filesFilenameUnzipped", dataBinFile);
    }

    private void defineLabels() {
        try {
            String line;
            File path = new File(fullDir, labelFileName);
            BufferedReader br = new BufferedReader(new FileReader(path));
            while ((line = br.readLine()) != null) {
                labels.add(line);
            }
        }
        catch (IOException e) {
            e.printStackTrace();
        }
    }

    public void load() {
        if (!this.cifarRawFilesExist() && !fullDir.exists()) {
            this.generateMaps();
            fullDir.mkdir();
            log.info("Downloading {}...", (Object)localDir);
            CifarLoader.downloadAndUntar(cifarDataMap, new File(BASE_DIR, localDir));
        }
        try {
            Collection subFiles = FileUtils.listFiles((File)fullDir, (String[])new String[]{"bin"}, (boolean)true);
            Iterator trainIter = subFiles.iterator();
            trainInputStream = new SequenceInputStream(new FileInputStream((File)trainIter.next()), new FileInputStream((File)trainIter.next()));
            while (trainIter.hasNext()) {
                File nextFile = (File)trainIter.next();
                if (TESTFILENAME.equals(nextFile.getName())) continue;
                trainInputStream = new SequenceInputStream(trainInputStream, new FileInputStream(nextFile));
            }
            testInputStream = new FileInputStream(new File(fullDir, TESTFILENAME));
        }
        catch (Exception e) {
            e.printStackTrace();
        }
        if (labels.isEmpty()) {
            this.defineLabels();
        }
        if (useSpecialPreProcessCifar && train && !this.cifarProcessedFilesExists()) {
            for (int i = this.fileNum + 1; i <= TRAINFILENAMES.length; ++i) {
                inputStream = trainInputStream;
                DataSet result = this.convertDataSet(numToConvertDS);
                result.save(new File(trainFilesSerialized + i + ".ser"));
            }
            inputStream = testInputStream;
            DataSet result = this.convertDataSet(numToConvertDS);
            result.save(new File(testFilesSerialized));
        }
        this.setInputStream();
    }

    public boolean cifarRawFilesExist() {
        File f = new File(fullDir, TESTFILENAME);
        if (!f.exists()) {
            return false;
        }
        for (String name : TRAINFILENAMES) {
            f = new File(fullDir, name);
            if (f.exists()) continue;
            return false;
        }
        return true;
    }

    private boolean cifarProcessedFilesExists() {
        File f;
        return !(train ? !(f = new File(trainFilesSerialized + 1 + ".ser")).exists() : !(f = new File(testFilesSerialized)).exists());
    }

    public opencv_core.Mat convertCifar(opencv_core.Mat orgImage) {
        ++this.numExamples;
        opencv_core.Mat resImage = new opencv_core.Mat();
        OpenCVFrameConverter.ToMat converter = new OpenCVFrameConverter.ToMat();
        ColorConversionTransform yuvTransform = new ColorConversionTransform(new Random(seed), 36);
        EqualizeHistTransform histEqualization = new EqualizeHistTransform(new Random(seed), 36);
        if (converter != null) {
            ImageWritable writable = new ImageWritable(converter.convert(orgImage));
            writable = yuvTransform.transform(writable);
            writable = histEqualization.transform(writable);
            resImage = converter.convert(writable.getFrame());
        }
        return resImage;
    }

    public void normalizeCifar(File fileName) {
        DataSet result = new DataSet();
        result.load(fileName);
        if (!this.meanStdStored && train) {
            this.uMean = Math.abs(this.uMean / (double)this.numExamples);
            this.uStd = Math.sqrt(this.uStd);
            this.vMean = Math.abs(this.vMean / (double)this.numExamples);
            this.vStd = Math.sqrt(this.vStd);
            try {
                FileUtils.write((File)meanVarPath, (CharSequence)(this.uMean + "," + this.uStd + "," + this.vMean + "," + this.vStd));
            }
            catch (IOException e) {
                e.printStackTrace();
            }
            this.meanStdStored = true;
        } else if (this.uMean == 0.0 && this.meanStdStored) {
            try {
                String[] values = FileUtils.readFileToString((File)meanVarPath).split(",");
                this.uMean = Double.parseDouble(values[0]);
                this.uStd = Double.parseDouble(values[1]);
                this.vMean = Double.parseDouble(values[2]);
                this.vStd = Double.parseDouble(values[3]);
            }
            catch (IOException e) {
                e.printStackTrace();
            }
        }
        for (int i = 0; i < result.numExamples(); ++i) {
            INDArray newFeatures = result.get(i).getFeatureMatrix();
            newFeatures.tensorAlongDimension(0, new int[]{0, 2, 3}).divi((Number)255);
            newFeatures.tensorAlongDimension(1, new int[]{0, 2, 3}).subi((Number)this.uMean).divi((Number)this.uStd);
            newFeatures.tensorAlongDimension(2, new int[]{0, 2, 3}).subi((Number)this.vMean).divi((Number)this.vStd);
            result.get(i).setFeatures(newFeatures);
        }
        result.save(fileName);
    }

    public Pair<INDArray, opencv_core.Mat> convertMat(byte[] byteFeature) {
        INDArray label = FeatureUtil.toOutcomeVector((int)byteFeature[0], (int)10);
        opencv_core.Mat image = new opencv_core.Mat(32, 32, opencv_core.CV_8UC((int)3));
        ByteBuffer imageData = (ByteBuffer)image.createBuffer();
        for (int i = 0; i < 1024; ++i) {
            imageData.put(3 * i, byteFeature[i + 1 + 2 * height * width]);
            imageData.put(3 * i + 1, byteFeature[i + 1 + height * width]);
            imageData.put(3 * i + 2, byteFeature[i + 1]);
        }
        return new Pair((Object)label, (Object)image);
    }

    public DataSet convertDataSet(int num) {
        ArrayList<DataSet> dataSets = new ArrayList<DataSet>();
        byte[] byteFeature = new byte[3073];
        try {
            for (int batchNumCount = 0; inputStream.read(byteFeature) != -1 && batchNumCount != num; ++batchNumCount) {
                Pair<INDArray, opencv_core.Mat> matConversion = this.convertMat(byteFeature);
                try {
                    dataSets.add(new DataSet(this.asMatrix((opencv_core.Mat)matConversion.getSecond()), (INDArray)matConversion.getFirst()));
                    continue;
                }
                catch (Exception e) {
                    break;
                }
            }
        }
        catch (IOException e) {
            e.printStackTrace();
        }
        DataSet result = new DataSet();
        try {
            result = DataSet.merge(dataSets);
        }
        catch (IllegalArgumentException e) {
            return result;
        }
        for (DataSet data : result) {
            try {
                if (useSpecialPreProcessCifar) {
                    INDArray uChannel = data.getFeatures().tensorAlongDimension(1, new int[]{0, 2, 3});
                    INDArray vChannel = data.getFeatures().tensorAlongDimension(2, new int[]{0, 2, 3});
                    double uTempMean = uChannel.meanNumber().doubleValue();
                    this.uStd += this.varManual(uChannel, uTempMean);
                    this.uMean += uTempMean;
                    double vTempMean = vChannel.meanNumber().doubleValue();
                    this.vStd += this.varManual(vChannel, vTempMean);
                    this.vMean += vTempMean;
                    data.setFeatures(data.getFeatureMatrix().div((Number)255));
                    continue;
                }
                data.setFeatures(data.getFeatureMatrix().div((Number)255));
            }
            catch (IllegalArgumentException e) {
                throw new IllegalStateException("The number of channels must be 3 to special preProcess Cifar with.");
            }
        }
        if (shuffle && num > 1) {
            result.shuffle(seed);
        }
        return result;
    }

    public double varManual(INDArray x, double mean) {
        INDArray xSubMean = x.sub((Number)mean);
        INDArray squared = xSubMean.muli(xSubMean);
        double accum = Nd4j.getExecutioner().execAndReturn((Accumulation)new Sum(squared)).getFinalResult().doubleValue();
        return accum / (double)x.ravel().length();
    }

    public DataSet next(int batchSize) {
        return this.next(batchSize, 0);
    }

    public DataSet next(int batchSize, int exampleNum) {
        DataSet result;
        ArrayList<DataSet> temp = new ArrayList<DataSet>();
        if (this.cifarProcessedFilesExists() && useSpecialPreProcessCifar) {
            if (exampleNum == 0 || exampleNum / this.fileNum == numToConvertDS && train) {
                ++this.fileNum;
                if (train) {
                    this.loadDS.load(new File(trainFilesSerialized + this.fileNum + ".ser"));
                }
                this.loadDS.load(new File(testFilesSerialized));
                if (shuffle && batchSize > 1) {
                    this.loadDS.shuffle(seed);
                }
                this.loadDSIndex = 0;
            }
            for (int i = 0; i < batchSize && this.loadDS.get(this.loadDSIndex) != null; ++i) {
                temp.add(this.loadDS.get(this.loadDSIndex));
                ++this.loadDSIndex;
            }
            result = temp.size() > 1 ? DataSet.merge(temp) : (DataSet)temp.get(0);
        } else {
            result = this.convertDataSet(batchSize);
        }
        return result;
    }

    public InputStream getInputStream() {
        return inputStream;
    }

    public void setInputStream() {
        inputStream = train ? trainInputStream : testInputStream;
    }

    public List<String> getLabels() {
        return labels;
    }

    public void reset() {
        this.numExamples = 0;
        this.fileNum = 0;
        this.load();
    }

    public void train() {
        train = true;
        this.setInputStream();
    }

    public void test() {
        train = false;
        this.setInputStream();
        shuffle = false;
        this.numExamples = 0;
        this.fileNum = 0;
    }

    static {
        labels = new ArrayList<String>();
        TRAINFILENAMES = new String[]{"data_batch_1.bin", "data_batch_2.bin", "data_batch_3.bin", "data_batch_4.bin", "data_batch5.bin"};
        TESTFILENAME = "test_batch.bin";
        trainFilesSerialized = FilenameUtils.concat((String)fullDir.toString(), (String)"cifar_train_serialized");
        testFilesSerialized = FilenameUtils.concat((String)fullDir.toString(), (String)"cifar_test_serialized.ser");
        train = true;
        useSpecialPreProcessCifar = false;
        cifarDataMap = new HashMap<String, String>();
        height = 32;
        width = 32;
        channels = 3;
        seed = System.currentTimeMillis();
        shuffle = true;
        numToConvertDS = 10000;
    }
}

