/*
 * Decompiled with CFR 0.152.
 */
package smile.datasets;

import java.io.BufferedInputStream;
import java.io.DataInputStream;
import java.io.FileInputStream;
import java.io.IOException;
import java.nio.file.Files;
import java.nio.file.Path;
import java.util.stream.IntStream;
import org.apache.commons.csv.CSVFormat;
import smile.data.CategoricalEncoder;
import smile.data.DataFrame;
import smile.data.formula.Formula;
import smile.data.type.DataTypes;
import smile.data.type.StructField;
import smile.data.type.StructType;
import smile.data.vector.IntVector;
import smile.io.Read;
import smile.util.Paths;

/**
 * MNIST handwritten digit data.
 *
 * <p>Each row holds the pixel values of one image as float features plus an
 * integer label in the {@code class} column. Data can be loaded either from a
 * pair of space-delimited text files or from the MNIST binary files (image file
 * magic number 2051, label file magic number 2049).
 *
 * @param data the data frame of pixel features followed by the {@code class} label column.
 * @param formula the model formula, built with {@code Formula.lhs("class")} so that
 *                {@code class} is the response.
 */
public record MNIST(DataFrame data, Formula formula) {
    /** Magic number expected at the start of a binary image file. */
    private static final int IMAGE_FILE_MAGIC = 2051;
    /** Magic number expected at the start of a binary label file. */
    private static final int LABEL_FILE_MAGIC = 2049;
    /** Number of pixel columns (V1..V784) expected in the text format. */
    private static final int TEXT_PIXEL_COUNT = 784;
    /** Divisor that scales unsigned-byte pixels (0..255) into [0, 1]. */
    private static final float MAX_PIXEL_VALUE = 255.0f;

    /**
     * Loads the bundled test-data files {@code mnist/mnist2500_X.txt} and
     * {@code mnist/mnist2500_labels.txt} in text format.
     *
     * @throws IOException if the files cannot be read.
     */
    public MNIST() throws IOException {
        this(Paths.getTestData("mnist/mnist2500_X.txt"), Paths.getTestData("mnist/mnist2500_labels.txt"));
    }

    /**
     * Loads MNIST data from an image file and a label file.
     *
     * <p>The format is chosen from the image file name only. If it ends with
     * {@code .txt}, both files are parsed as space-delimited text. Otherwise both
     * are parsed as MNIST binary files.
     *
     * @param dataFilePath the image data file.
     * @param labelFilePath the label file.
     * @throws IOException if a file cannot be read or is malformed.
     */
    public MNIST(Path dataFilePath, Path labelFilePath) throws IOException {
        this(dataFilePath.toString().endsWith(".txt")
                        ? MNIST.loadText(dataFilePath, labelFilePath)
                        : MNIST.loadBinary(dataFilePath, labelFilePath),
                Formula.lhs("class"));
    }

    /**
     * Reads the MNIST binary image and label files.
     *
     * <p>Each pixel is read as an unsigned byte and scaled into [0, 1] by dividing by
     * 255. Labels are read as unsigned bytes.
     *
     * @param dataFilePath the binary image file.
     * @param labelFilePath the binary label file.
     * @return the pixel features plus an int {@code class} column.
     * @throws IOException if a file cannot be read, has a wrong magic number, declares
     *                     negative or overflowing dimensions, has a sample count that
     *                     differs from the other file, or is truncated.
     */
    private static DataFrame loadBinary(Path dataFilePath, Path labelFilePath) throws IOException {
        try (var dataIn = new DataInputStream(new BufferedInputStream(Files.newInputStream(dataFilePath)));
             var labelIn = new DataInputStream(new BufferedInputStream(Files.newInputStream(labelFilePath)))) {
            int magicNumber = dataIn.readInt();
            if (magicNumber != IMAGE_FILE_MAGIC) {
                throw new IOException("Invalid MNIST data file magic number: " + magicNumber);
            }
            int size = readCount(dataIn, "image count");
            int nrow = readCount(dataIn, "row count");
            int ncol = readCount(dataIn, "column count");
            int length;
            try {
                // A corrupt header must not silently wrap around to a bogus image length.
                length = Math.multiplyExact(nrow, ncol);
            } catch (ArithmeticException e) {
                throw new IOException("MNIST image dimensions overflow: " + nrow + " x " + ncol, e);
            }

            int labelMagicNumber = labelIn.readInt();
            if (labelMagicNumber != LABEL_FILE_MAGIC) {
                throw new IOException("Invalid MNIST label file magic number: " + labelMagicNumber);
            }
            int labelSize = labelIn.readInt();
            if (labelSize != size) {
                throw new IOException("Data file and label file have different size: " + size + " vs " + labelSize);
            }

            float[][] data = new float[size][length];
            int[] y = new int[size];
            // Reused buffer: one bulk read per image instead of one call per pixel.
            byte[] pixels = new byte[length];
            for (int i = 0; i < size; i++) {
                y[i] = labelIn.readUnsignedByte();
                dataIn.readFully(pixels);
                float[] x = data[i];
                for (int j = 0; j < length; j++) {
                    // Mask to treat the byte as unsigned (0..255) before scaling.
                    x[j] = (pixels[j] & 0xFF) / MAX_PIXEL_VALUE;
                }
            }
            // NOTE(review): feature column names come from DataFrame.of's defaults; confirm they match expectations.
            return DataFrame.of(data, new String[0]).add(new IntVector("class", y));
        }
    }

    /**
     * Reads a header integer from a binary MNIST file and rejects negative values.
     *
     * @param in the stream positioned at the header field.
     * @param what a description of the field, used in the error message.
     * @return the non-negative value.
     * @throws IOException if the value is negative or the stream ends early.
     */
    private static int readCount(DataInputStream in, String what) throws IOException {
        int value = in.readInt();
        if (value < 0) {
            throw new IOException("Invalid MNIST " + what + ": " + value);
        }
        return value;
    }

    /**
     * Reads space-delimited text files.
     *
     * <p>The image file is parsed with a schema of 784 float columns named
     * {@code V1..V784}. The labels are taken from the first column of the label file.
     *
     * @param dataFilePath the text image file.
     * @param labelFilePath the text label file.
     * @return the pixel features plus an int {@code class} column.
     * @throws IOException if a file cannot be read.
     */
    private static DataFrame loadText(Path dataFilePath, Path labelFilePath) throws IOException {
        StructType schema = new StructType(IntStream.rangeClosed(1, TEXT_PIXEL_COUNT)
                .mapToObj(i -> new StructField("V" + i, DataTypes.FloatType))
                .toList());
        CSVFormat format = CSVFormat.Builder.create().setDelimiter(' ').get();
        DataFrame data = Read.csv(dataFilePath, format, schema);
        int[] y = Read.csv(labelFilePath, format).column(0).toIntArray();
        return data.add(new IntVector("class", y));
    }

    /**
     * Returns the feature matrix derived from {@link #formula()} applied to {@link #data()}.
     * Categorical variables, if any, are dummy-encoded.
     *
     * @return the feature matrix.
     */
    public double[][] x() {
        // NOTE(review): the leading 'false' presumably disables a bias/intercept column; confirm against DataFrame.toArray.
        return this.formula.x(this.data).toArray(false, CategoricalEncoder.DUMMY, new String[0]);
    }

    /**
     * Returns the response labels from the {@code class} column selected by {@link #formula()}.
     *
     * @return the class labels.
     */
    public int[] y() {
        return this.formula.y(this.data).toIntArray();
    }
}

