/*
 * Decompiled with CFR 0.152.
 */
package hex.deepwater;

import deepwater.backends.BackendModel;
import deepwater.backends.BackendParams;
import deepwater.backends.BackendTrain;
import deepwater.backends.RuntimeOptions;
import deepwater.datasets.ImageDataSet;
import hex.DataInfo;
import hex.deepwater.DeepWaterParameters;
import hex.genmodel.algos.deepwater.DeepwaterMojoModel;
import java.io.File;
import java.io.FileInputStream;
import java.io.FileOutputStream;
import java.io.IOException;
import java.util.Arrays;
import water.H2O;
import water.Iced;
import water.Key;
import water.exceptions.H2OIllegalArgumentException;
import water.util.Log;
import water.util.PrettyPrint;
import water.util.TwoDimTable;

public final class DeepWaterModelInfo
extends Iced {
    private int _classes;
    byte[] _network;
    byte[] _modelparams;
    private TwoDimTable summaryTable;
    transient BackendTrain _backend;
    transient BackendModel _model;
    int _height;
    int _width;
    int _channels;
    float[] _meanData;
    DataInfo _dataInfo;
    volatile boolean _unstable = false;
    public DeepWaterParameters parameters;
    private long processed_global;
    private long processed_local;
    private final boolean _classification;

    void nukeBackend() {
        if (this._backend != null && this._model != null) {
            this._backend.delete(this._model);
        }
        this._backend = null;
        this._model = null;
    }

    void saveNativeState(String path, int iteration) {
        assert (this._backend != null);
        assert (this._model != null);
        this._backend.saveModel(this._model, path + ".json");
        this._backend.saveParam(this._model, path + "." + iteration + ".params");
    }

    float[] predict(float[] data) {
        assert (this._backend != null);
        assert (this._model != null);
        return this._backend.predict(this._model, data);
    }

    public int hashCode() {
        return Arrays.hashCode(this._network) + Arrays.hashCode(this._modelparams);
    }

    public long size() {
        long res = 0L;
        if (this._network != null) {
            res += (long)this._network.length;
        }
        if (this._modelparams != null) {
            res += (long)this._modelparams.length;
        }
        return res;
    }

    public final DeepWaterParameters get_params() {
        return this.parameters;
    }

    synchronized long get_processed_global() {
        return this.processed_global;
    }

    synchronized void set_processed_global(long p) {
        this.processed_global = p;
    }

    synchronized void add_processed_global(long p) {
        this.processed_global += p;
    }

    synchronized long get_processed_local() {
        return this.processed_local;
    }

    synchronized void set_processed_local(long p) {
        this.processed_local = p;
    }

    synchronized void add_processed_local(long p) {
        this.processed_local += p;
    }

    synchronized long get_processed_total() {
        return this.processed_global + this.processed_local;
    }

    private RuntimeOptions getRuntimeOptions() {
        RuntimeOptions opts = new RuntimeOptions();
        opts.setSeed((long)((int)this.get_params().getOrMakeRealSeed()));
        opts.setUseGPU(this.get_params()._gpu);
        opts.setDeviceID(this.get_params()._device_id);
        return opts;
    }

    private BackendParams getBackendParams() {
        String network;
        BackendParams backendParams = new BackendParams();
        backendParams.set("mini_batch_size", (Object)this.get_params()._mini_batch_size);
        backendParams.set("clip_gradient", (Object)this.get_params()._clip_gradient);
        String string = network = this.parameters._network == null ? null : this.parameters._network.toString();
        if (network == null) {
            String acti;
            assert (this.parameters._activation != null);
            assert (this.parameters._hidden != null);
            Object[] acts = new String[this.parameters._hidden.length];
            if (this.parameters._activation.toString().startsWith("Rectifier")) {
                acti = "relu";
            } else if (this.parameters._activation.toString().startsWith("Tanh")) {
                acti = "tanh";
            } else {
                throw H2O.unimpl();
            }
            Arrays.fill(acts, acti);
            backendParams.set("activations", (Object)acts);
            backendParams.set("hidden", (Object)this.parameters._hidden);
            backendParams.set("input_dropout_ratio", (Object)this.parameters._input_dropout_ratio);
            backendParams.set("hidden_dropout_ratios", (Object)this.parameters._hidden_dropout_ratios);
        }
        return backendParams;
    }

    private ImageDataSet getImageDataSet() {
        return new ImageDataSet(this._width, this._height, this._channels, this._classes);
    }

    DeepWaterModelInfo(DeepWaterParameters origParams, int nClasses, int nFeatures) {
        this._classes = nClasses;
        this._classification = this._classes > 1;
        this.parameters = (DeepWaterParameters)origParams.clone();
        this._width = nFeatures;
        this._height = 0;
        this._channels = 0;
        if (this.parameters._problem_type == DeepWaterParameters.ProblemType.image) {
            this._width = this.parameters._image_shape[0];
            this._height = this.parameters._image_shape[1];
            this._channels = this.parameters._channels;
            if (this._width == 0 || this._height == 0) {
                switch (this.parameters._network) {
                    case lenet: {
                        this._width = 28;
                        this._height = 28;
                        break;
                    }
                    case auto: 
                    case alexnet: 
                    case inception_bn: 
                    case googlenet: 
                    case resnet: {
                        this._width = 224;
                        this._height = 224;
                        break;
                    }
                    case vgg: {
                        this._width = 320;
                        this._height = 320;
                        break;
                    }
                    case user: {
                        throw new H2OIllegalArgumentException("Please specify width and height for user-given model definition.");
                    }
                    default: {
                        throw H2O.unimpl((String)("Unknown network type: " + (Object)((Object)this.parameters._network)));
                    }
                }
            }
            assert (this._width > 0);
            assert (this._height > 0);
        } else if (this.parameters._problem_type == DeepWaterParameters.ProblemType.dataset) {
            if (this.parameters._image_shape != null) {
                if (this.parameters._image_shape[0] > 0) {
                    this._width = this.parameters._image_shape[0];
                }
                if (this.parameters._image_shape[1] > 0) {
                    this._height = this.parameters._image_shape[1];
                }
                this._channels = this._width > 0 && this._height > 0 ? this.parameters._channels : 0;
            }
        } else if (this.parameters._problem_type == DeepWaterParameters.ProblemType.text) {
            this._width = 100;
        } else {
            Log.warn((Object[])new Object[]{"unknown problem_type:", this.parameters._problem_type});
            throw H2O.unimpl();
        }
        this.setupNativeBackend();
    }

    private void setupNativeBackend() {
        try {
            String networkDef;
            this._backend = DeepwaterMojoModel.createDeepWaterBackend((String)this.parameters._backend.toString());
            if (this._backend == null) {
                throw new IllegalArgumentException("No backend found. Cannot build a Deep Water model.");
            }
            ImageDataSet imageDataSet = this.getImageDataSet();
            RuntimeOptions opts = this.getRuntimeOptions();
            BackendParams bparms = this.getBackendParams();
            if (this.parameters._network != DeepWaterParameters.Network.user) {
                String network;
                String string = network = this.parameters._network == null ? null : this.parameters._network.toString();
                if (network != null) {
                    Log.info((Object[])new Object[]{"Creating a fresh model of the following network type: " + network});
                    this._model = this._backend.buildNet(imageDataSet, opts, bparms, this._classes, network);
                } else {
                    Log.info((Object[])new Object[]{"Creating a fresh model of the following network type: MLP"});
                    this._model = this._backend.buildNet(imageDataSet, opts, bparms, this._classes, "MLP");
                }
            }
            if ((networkDef = this.parameters._network_definition_file) != null && !networkDef.isEmpty()) {
                File f = new File(networkDef);
                if (!f.exists() || f.isDirectory()) {
                    throw new RuntimeException("Network definition file " + f + " not found.");
                }
                Log.info((Object[])new Object[]{"Loading the network from: " + f.getAbsolutePath()});
                Log.info((Object[])new Object[]{"Setting the optimizer and initializing the first and last layer."});
                this._model = this._backend.buildNet(imageDataSet, opts, bparms, this._classes, f.getAbsolutePath());
            }
            if (this.parameters._mean_image_file != null && !this.parameters._mean_image_file.isEmpty()) {
                imageDataSet.setMeanData(this._backend.loadMeanImage(this._model, this.parameters._mean_image_file));
            }
            this._meanData = imageDataSet.getMeanData();
            String networkParms = this.parameters._network_parameters_file;
            if (networkParms != null && !networkParms.isEmpty()) {
                File f = new File(networkParms);
                if (!f.exists() || f.isDirectory()) {
                    throw new RuntimeException("Network parameter file " + f + " not found.");
                }
                Log.info((Object[])new Object[]{"Loading the parameters (weights/biases) from: " + f.getAbsolutePath()});
                assert (this._model != null);
                this._backend.loadParam(this._model, f.getAbsolutePath());
            } else {
                Log.warn((Object[])new Object[]{"No network parameters file specified. Starting from scratch."});
            }
            this.nativeToJava();
        }
        catch (Throwable t) {
            throw new RuntimeException("Unable to initialize the native Deep Learning backend: " + t.getMessage());
        }
    }

    /*
     * WARNING - Removed try catching itself - possible behaviour change.
     */
    void nativeToJava() {
        FileInputStream is;
        if (this._backend == null) {
            return;
        }
        Log.info((Object[])new Object[]{"Native backend -> Java."});
        long now = System.currentTimeMillis();
        File file = null;
        if (this._network == null) {
            try {
                file = new File(System.getProperty("java.io.tmpdir"), Key.make().toString());
                this._backend.saveModel(this._model, file.toString());
                is = new FileInputStream(file);
                this._network = new byte[(int)file.length()];
                is.read(this._network);
                is.close();
            }
            catch (IOException e) {
                e.printStackTrace();
            }
            finally {
                if (file != null) {
                    file.delete();
                }
            }
        }
        try {
            file = new File(System.getProperty("java.io.tmpdir"), Key.make().toString());
            this._backend.saveParam(this._model, file.toString());
            is = new FileInputStream(file);
            this._modelparams = new byte[(int)file.length()];
            is.read(this._modelparams);
            is.close();
        }
        catch (IOException e) {
            e.printStackTrace();
        }
        finally {
            if (file != null) {
                file.delete();
            }
        }
        long time = System.currentTimeMillis() - now;
        Log.info((Object[])new Object[]{"Took: " + PrettyPrint.msecs((long)time, (boolean)true)});
    }

    void javaToNative() {
        this.javaToNative(null, null);
    }

    /*
     * WARNING - Removed try catching itself - possible behaviour change.
     */
    private void javaToNative(byte[] network, byte[] parameters) {
        FileOutputStream os;
        long now = System.currentTimeMillis();
        if (this._backend != null && (network == null || Arrays.equals(network, this._network)) && (parameters == null || Arrays.equals(parameters, this._modelparams))) {
            Log.warn((Object[])new Object[]{"No need to move the state from Java to native."});
            return;
        }
        if (this._backend == null) {
            this._backend = DeepwaterMojoModel.createDeepWaterBackend((String)this.get_params()._backend.toString());
            if (this._backend == null) {
                throw new IllegalArgumentException("No backend found. Cannot build a Deep Water model.");
            }
        }
        if (network == null) {
            network = this._network;
        }
        if (parameters == null) {
            parameters = this._modelparams;
        }
        if (network == null || parameters == null) {
            return;
        }
        Log.info((Object[])new Object[]{"Java state -> native backend."});
        File file = null;
        try {
            file = new File(System.getProperty("java.io.tmpdir"), Key.make().toString() + ".json");
            os = new FileOutputStream(file);
            os.write(network);
            os.close();
            this._model = this._backend.buildNet(this.getImageDataSet(), this.getRuntimeOptions(), this.getBackendParams(), this._classes, file.toString());
        }
        catch (IOException e) {
            e.printStackTrace();
        }
        finally {
            file.delete();
        }
        try {
            file = new File(System.getProperty("java.io.tmpdir"), Key.make().toString());
            os = new FileOutputStream(file);
            os.write(parameters);
            os.close();
            this._backend.loadParam(this._model, file.toString());
        }
        catch (IOException e) {
            e.printStackTrace();
        }
        finally {
            file.delete();
        }
        long time = System.currentTimeMillis() - now;
        Log.info((Object[])new Object[]{"Took: " + PrettyPrint.msecs((long)time, (boolean)true)});
    }

    TwoDimTable createSummaryTable() {
        TwoDimTable table = new TwoDimTable("Status of Deep Learning Model", (this.get_params()._network == null ? "MLP: " + Arrays.toString(this.get_params()._hidden) : this.get_params()._network.toString()) + ", " + PrettyPrint.bytes((long)this.size()) + ", " + (!this.get_params()._autoencoder ? "predicting " + this.get_params()._response_column + ", " : "") + (this.get_params()._autoencoder ? "auto-encoder" : (this._classification ? this._classes + "-class classification" : "regression")) + ", " + String.format("%,d", this.get_processed_global()) + " training samples, " + "mini-batch size " + String.format("%,d", this.get_params()._mini_batch_size), new String[1], new String[]{"Input Neurons", "Rate", "Momentum"}, new String[]{"int", "double", "double"}, new String[]{"%d", "%5f", "%5f"}, "");
        table.set(0, 0, (Object)(this._dataInfo != null ? this._dataInfo.fullN() : this._width * this._height * this._channels));
        table.set(0, 1, (Object)Float.valueOf(this.get_params().learningRate(this.get_processed_global())));
        table.set(0, 2, (Object)Float.valueOf(this.get_params().momentum(this.get_processed_global())));
        this.summaryTable = table;
        return this.summaryTable;
    }

    public String toString() {
        StringBuilder sb = new StringBuilder();
        if (!this.get_params()._quiet_mode) {
            this.createSummaryTable();
            if (this.summaryTable != null) {
                sb.append(this.summaryTable.toString(1));
            }
        }
        return sb.toString();
    }

    public String toStringAll() {
        StringBuilder sb = new StringBuilder();
        sb.append(this.toString());
        sb.append("\nprocessed global: ").append(this.get_processed_global());
        sb.append("\nprocessed local:  ").append(this.get_processed_local());
        sb.append("\nprocessed total:  ").append(this.get_processed_total());
        sb.append("\n");
        return sb.toString();
    }

    public void add(DeepWaterModelInfo other) {
        throw H2O.unimpl();
    }

    public void mult(double N) {
        throw H2O.unimpl();
    }

    public void div(double N) {
        throw H2O.unimpl();
    }
}

