/*
 * Decompiled with CFR 0.152.
 */
package deepnetts.util;

import deepnetts.net.ConvolutionalNetwork;
import deepnetts.net.FeedForwardNetwork;
import deepnetts.net.NetworkType;
import deepnetts.net.NeuralNetwork;
import deepnetts.net.layers.AbstractLayer;
import deepnetts.net.layers.ConvolutionalLayer;
import deepnetts.net.layers.FullyConnectedLayer;
import deepnetts.net.layers.InputLayer;
import deepnetts.net.layers.LayerType;
import deepnetts.net.layers.MaxPoolingLayer;
import deepnetts.net.layers.OutputLayer;
import deepnetts.net.layers.SoftmaxOutputLayer;
import deepnetts.net.layers.activation.ActivationType;
import deepnetts.net.loss.CrossEntropyLoss;
import deepnetts.net.loss.LossType;
import deepnetts.net.loss.MeanSquaredErrorLoss;
import deepnetts.util.Tensor;
import java.io.BufferedReader;
import java.io.File;
import java.io.FileInputStream;
import java.io.FileNotFoundException;
import java.io.FileOutputStream;
import java.io.FileReader;
import java.io.IOException;
import java.io.InputStreamReader;
import java.io.ObjectInputStream;
import java.io.ObjectOutputStream;
import java.io.OutputStreamWriter;
import java.io.PrintWriter;
import java.io.Writer;
import java.nio.charset.StandardCharsets;
import java.util.ArrayList;
import java.util.Locale;
import org.json.JSONArray;
import org.json.JSONObject;

/**
 * Static helpers for saving and loading {@link NeuralNetwork} instances.
 *
 * <p>Two formats are supported: Java object serialization ({@link #writeToFile} /
 * {@link #createFromFile}), and a JSON description of the network architecture
 * ({@link #writeToFileAsJson} / {@link #createFromJson}). The JSON format stores only the network
 * type, the layer structure and the loss function. Weights and convolution filters are NOT
 * written, so a network rebuilt from JSON has whatever weights its builder assigns.
 */
public final class FileIO {

    /** File extension conventionally used for serialized networks (not appended automatically). */
    public static final String NETWORK_FILE_EXT = "dnet";

    private FileIO() {
        // static utility class; not instantiable
    }

    /**
     * Writes the network to a file using Java object serialization.
     *
     * @param neuralNet network to save
     * @param fileName  path of the file to create or overwrite
     * @throws IOException if the file cannot be written
     */
    public static void writeToFile(NeuralNetwork neuralNet, String fileName) throws IOException {
        try (ObjectOutputStream oos = new ObjectOutputStream(new FileOutputStream(fileName))) {
            oos.writeObject(neuralNet);
        }
    }

    /**
     * Writes the JSON architecture description produced by {@link #toJson} to a UTF-8 file.
     *
     * @param neuralNet network to describe
     * @param fileName  path of the file to create or overwrite
     * @throws IOException if the file cannot be written or closed
     */
    public static void writeToFileAsJson(NeuralNetwork neuralNet, String fileName) throws IOException {
        String jsonStr = toJson(neuralNet);
        // Unlike PrintWriter, a plain Writer reports write/flush/close failures as IOException.
        // An explicit charset keeps the file readable on any platform.
        try (Writer writer = new OutputStreamWriter(new FileOutputStream(fileName), StandardCharsets.UTF_8)) {
            writer.write(jsonStr);
        }
    }

    /**
     * Reads a Java-serialized object from a file and casts it to the requested type.
     *
     * <p>Security: this uses Java deserialization. Only load files from trusted sources.
     *
     * @param fileName path of the serialized file
     * @param clazz    expected type of the stored object
     * @param <T>      expected type of the stored object
     * @return the deserialized object
     * @throws IOException            if the file cannot be read
     * @throws ClassNotFoundException if a class of the serialized object is not on the classpath
     * @throws ClassCastException     if the stored object is not an instance of {@code clazz}
     */
    public static <T> T createFromFile(String fileName, Class<T> clazz) throws IOException, ClassNotFoundException {
        try (ObjectInputStream ois = new ObjectInputStream(new FileInputStream(fileName))) {
            return clazz.cast(ois.readObject());
        }
    }

    /**
     * Reads a Java-serialized neural network from a file.
     *
     * <p>Security: this uses Java deserialization. Only load files from trusted sources.
     *
     * @param file serialized network file
     * @return the deserialized network
     * @throws IOException            if the file cannot be read
     * @throws ClassNotFoundException if a class of the serialized object is not on the classpath
     * @throws ClassCastException     if the stored object is not a {@link NeuralNetwork}
     */
    public static NeuralNetwork createFromFile(File file) throws IOException, ClassNotFoundException {
        try (ObjectInputStream ois = new ObjectInputStream(new FileInputStream(file))) {
            // Cast to the common supertype so any serialized network type can be loaded,
            // not only ConvolutionalNetwork.
            return (NeuralNetwork) ois.readObject();
        }
    }

    /**
     * Describes the network architecture as a JSON string.
     *
     * <p>The JSON object has a {@code networkType}, a {@code layers} array (input layer first,
     * then one entry per supported layer in {@code getLayers()} order), and a {@code lossFunction}.
     * Layers of other types are omitted. Weights and filters are not included.
     *
     * @param nnet network to describe
     * @return JSON text
     */
    public static String toJson(NeuralNetwork<?> nnet) {
        JSONArray layers = new JSONArray();
        layers.put(inputLayerToJson(nnet.getInputLayer(), nnet instanceof ConvolutionalNetwork));
        for (AbstractLayer layer : nnet.getLayers()) {
            JSONObject layerJson = layerToJson(layer);
            if (layerJson != null) {
                layers.put(layerJson);
            }
        }
        JSONObject json = new JSONObject();
        json.put("networkType", NetworkType.Of(nnet.getClass()));
        json.put("layers", layers);
        json.put("lossFunction", LossType.of(nnet.getLossFunction().getClass()));
        return json.toString();
    }

    /** Builds the JSON entry for the input layer; height/channels are only written for convolutional nets. */
    private static JSONObject inputLayerToJson(InputLayer inputLayer, boolean includeHeightAndChannels) {
        JSONObject json = new JSONObject();
        // Stored as the enum itself, like every other layerType entry.
        json.put("layerType", LayerType.INPUT);
        json.put("width", inputLayer.getWidth());
        if (includeHeightAndChannels) {
            json.put("height", inputLayer.getHeight());
            json.put("channels", inputLayer.getDepth());
        }
        return json;
    }

    /** Builds the JSON entry for one non-input layer, or returns null if its type is not serialized. */
    private static JSONObject layerToJson(AbstractLayer layer) {
        JSONObject json = new JSONObject();
        if (layer instanceof ConvolutionalLayer) {
            ConvolutionalLayer convLayer = (ConvolutionalLayer) layer;
            json.put("layerType", LayerType.CONVOLUTIONAL);
            json.put("filterWidth", convLayer.getFilterWidth());
            json.put("filterHeight", convLayer.getFilterHeight());
            json.put("channels", convLayer.getDepth());
            json.put("stride", convLayer.getStride());
            json.put("activation", convLayer.getActivationType());
        } else if (layer instanceof MaxPoolingLayer) {
            MaxPoolingLayer maxPooling = (MaxPoolingLayer) layer;
            json.put("layerType", LayerType.MAXPOOLING);
            json.put("filterWidth", maxPooling.getFilterWidth());
            json.put("filterHeight", maxPooling.getFilterHeight());
            json.put("stride", maxPooling.getStride());
        } else if (layer instanceof FullyConnectedLayer) {
            json.put("layerType", LayerType.FULLY_CONNECTED);
            json.put("width", layer.getWidth());
            json.put("activation", layer.getActivationType());
        } else if (layer instanceof SoftmaxOutputLayer || layer instanceof OutputLayer) {
            // Both output layer kinds share one format; the activation tells them apart on load.
            json.put("layerType", LayerType.OUTPUT);
            json.put("width", layer.getWidth());
            json.put("activation", layer.getActivationType());
        } else {
            return null;
        }
        return json;
    }

    /**
     * Builds a network from JSON text produced by {@link #toJson}.
     *
     * @param jsonStr JSON text
     * @return the rebuilt network
     * @throws IllegalArgumentException if the network type is unknown
     */
    public static NeuralNetwork createFromJson(String jsonStr) {
        JSONObject obj = new JSONObject(jsonStr);
        return createFromJson(obj);
    }

    /**
     * Builds a network from a JSON file, such as one written by {@link #writeToFileAsJson}.
     *
     * <p>The file is read as UTF-8. The JSON produced by {@link #toJson} appears to be ASCII only
     * (enum names and numbers), so files written with another default charset remain readable.
     *
     * @param file JSON file
     * @return the rebuilt network
     * @throws FileNotFoundException if the file does not exist or cannot be opened
     * @throws IOException           if reading fails
     */
    public static NeuralNetwork createFromJson(File file) throws FileNotFoundException, IOException {
        StringBuilder sb = new StringBuilder();
        // try-with-resources guarantees the file handle is released even if reading fails.
        try (BufferedReader br = new BufferedReader(
                new InputStreamReader(new FileInputStream(file), StandardCharsets.UTF_8))) {
            String line;
            while ((line = br.readLine()) != null) {
                sb.append(line).append(System.lineSeparator());
            }
        }
        return createFromJson(sb.toString());
    }

    /**
     * Builds a network from a parsed JSON object, dispatching on its {@code networkType} field.
     *
     * @param jsonObj parsed JSON description
     * @return the rebuilt network
     * @throws IllegalArgumentException if the network type is neither feed-forward nor convolutional
     */
    public static NeuralNetwork createFromJson(JSONObject jsonObj) {
        String networkType = jsonObj.getString("networkType");
        if (networkType.equals(NetworkType.FEEDFORWARD.toString())) {
            return createFeedForwardNetworkFromJson(jsonObj);
        }
        if (networkType.equals(NetworkType.CONVOLUTIONAL.toString())) {
            return createConvolutionalNetworkFromJson(jsonObj);
        }
        throw new IllegalArgumentException("Unknown network type: " + networkType);
    }

    /**
     * Builds a convolutional network from its JSON description.
     *
     * <p>A SIGMOID output layer becomes an {@link OutputLayer} and a SOFTMAX output layer becomes a
     * {@link SoftmaxOutputLayer}. The top-level {@code lossFunction} entry is applied last, so it
     * overrides the default loss chosen for the output layer.
     *
     * @param jsonObj parsed JSON description
     * @return the built network
     * @throws IllegalArgumentException if a layer type or activation name is unknown, or the
     *                                  output activation is neither SIGMOID nor SOFTMAX
     */
    public static ConvolutionalNetwork createConvolutionalNetworkFromJson(JSONObject jsonObj) {
        JSONArray jsonLayers = jsonObj.getJSONArray("layers");
        ConvolutionalNetwork.Builder builder = new ConvolutionalNetwork.Builder();
        for (Object jsonLayerObject : jsonLayers) {
            JSONObject layerObj = (JSONObject) jsonLayerObject;
            switch (layerTypeOf(layerObj)) {
                case INPUT: {
                    int width = layerObj.getInt("width");
                    int height = layerObj.getInt("height");
                    int channels = layerObj.getInt("channels");
                    builder.addInputLayer(width, height, channels);
                    break;
                }
                case CONVOLUTIONAL: {
                    int filterWidth = layerObj.getInt("filterWidth");
                    int filterHeight = layerObj.getInt("filterHeight");
                    int stride = layerObj.getInt("stride");
                    int channels = layerObj.getInt("channels");
                    builder.addConvolutionalLayer(filterWidth, filterHeight, channels, stride, activationOf(layerObj));
                    break;
                }
                case MAXPOOLING: {
                    int filterWidth = layerObj.getInt("filterWidth");
                    int filterHeight = layerObj.getInt("filterHeight");
                    int stride = layerObj.getInt("stride");
                    builder.addMaxPoolingLayer(filterWidth, filterHeight, stride);
                    break;
                }
                case FULLY_CONNECTED: {
                    int width = layerObj.getInt("width");
                    builder.addFullyConnectedLayer(width, activationOf(layerObj));
                    break;
                }
                case OUTPUT: {
                    int width = layerObj.getInt("width");
                    ActivationType activation = activationOf(layerObj);
                    if (activation == ActivationType.SIGMOID) {
                        builder.addOutputLayer(width, OutputLayer.class);
                        builder.lossFunction(MeanSquaredErrorLoss.class);
                    } else if (activation == ActivationType.SOFTMAX) {
                        builder.addOutputLayer(width, SoftmaxOutputLayer.class);
                        builder.lossFunction(CrossEntropyLoss.class);
                    } else {
                        // Fail fast rather than silently building a network with no output layer.
                        throw new IllegalArgumentException(
                                "Unsupported output layer activation for convolutional network: " + activation);
                    }
                    break;
                }
                default:
                    // Other layer types are not part of the convolutional JSON format and are ignored.
                    break;
            }
        }
        String lossFunction = jsonObj.getString("lossFunction");
        builder.lossFunction(LossType.valueOf(lossFunction));
        return builder.build();
    }

    /**
     * Builds a feed-forward network from its JSON description.
     *
     * <p>NOTE: a per-layer {@code "weights"} entry, if present, is currently not applied to the
     * built network.
     *
     * @param jsonObj parsed JSON description
     * @return the built network
     * @throws IllegalArgumentException if a layer type, activation or loss name is unknown
     */
    public static FeedForwardNetwork createFeedForwardNetworkFromJson(JSONObject jsonObj) {
        JSONArray jsonLayers = jsonObj.getJSONArray("layers");
        FeedForwardNetwork.Builder builder = new FeedForwardNetwork.Builder();
        for (Object jsonLayerObject : jsonLayers) {
            JSONObject layerObj = (JSONObject) jsonLayerObject;
            switch (layerTypeOf(layerObj)) {
                case INPUT:
                    builder.addInputLayer(layerObj.getInt("width"));
                    break;
                case FULLY_CONNECTED: {
                    int width = layerObj.getInt("width");
                    builder.addFullyConnectedLayer(width, activationOf(layerObj));
                    break;
                }
                case OUTPUT: {
                    int width = layerObj.getInt("width");
                    builder.addOutputLayer(width, activationOf(layerObj));
                    break;
                }
                default:
                    // Other layer types are not part of the feed-forward JSON format and are ignored.
                    break;
            }
        }
        String lossFunction = jsonObj.getString("lossFunction");
        builder.lossFunction(LossType.valueOf(lossFunction));
        return builder.build();
    }

    /** Parses a layer's "layerType" field, upper-casing it locale-independently before the enum lookup. */
    private static LayerType layerTypeOf(JSONObject layerObj) {
        return LayerType.valueOf(layerObj.getString("layerType").toUpperCase(Locale.ROOT));
    }

    /** Parses a layer's "activation" field, upper-casing it locale-independently before the enum lookup. */
    private static ActivationType activationOf(JSONObject layerObj) {
        return ActivationType.valueOf(layerObj.getString("activation").toUpperCase(Locale.ROOT));
    }
}

