/*
 * Decompiled with CFR 0.152.
 */
package org.tribuo.interop.tensorflow;

import com.oracle.labs.mlrg.olcut.config.Config;
import com.oracle.labs.mlrg.olcut.config.PropertyException;
import com.oracle.labs.mlrg.olcut.provenance.ObjectProvenance;
import com.oracle.labs.mlrg.olcut.provenance.PrimitiveProvenance;
import com.oracle.labs.mlrg.olcut.provenance.Provenance;
import com.oracle.labs.mlrg.olcut.provenance.ProvenanceUtil;
import com.oracle.labs.mlrg.olcut.provenance.impl.SkeletalConfiguredObjectProvenance;
import com.oracle.labs.mlrg.olcut.provenance.primitives.DateTimeProvenance;
import com.oracle.labs.mlrg.olcut.provenance.primitives.HashProvenance;
import com.oracle.labs.mlrg.olcut.util.Pair;
import java.io.BufferedInputStream;
import java.io.FileInputStream;
import java.io.IOException;
import java.io.InputStream;
import java.nio.file.Path;
import java.nio.file.Paths;
import java.time.Instant;
import java.time.OffsetDateTime;
import java.time.ZoneId;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.Collections;
import java.util.HashMap;
import java.util.List;
import java.util.Map;
import java.util.Objects;
import java.util.logging.Level;
import java.util.logging.Logger;
import org.tensorflow.ExecutionEnvironment;
import org.tensorflow.Graph;
import org.tensorflow.GraphOperation;
import org.tensorflow.Operand;
import org.tensorflow.Session;
import org.tensorflow.Tensor;
import org.tensorflow.exceptions.TensorFlowException;
import org.tensorflow.ndarray.Shape;
import org.tensorflow.op.Op;
import org.tensorflow.op.Ops;
import org.tensorflow.op.core.Placeholder;
import org.tensorflow.proto.framework.ConfigProto;
import org.tensorflow.proto.framework.GraphDef;
import org.tensorflow.types.TFloat32;
import org.tensorflow.types.family.TNumber;
import org.tribuo.Dataset;
import org.tribuo.Example;
import org.tribuo.ImmutableFeatureMap;
import org.tribuo.ImmutableOutputInfo;
import org.tribuo.Output;
import org.tribuo.Trainer;
import org.tribuo.interop.tensorflow.FeatureConverter;
import org.tribuo.interop.tensorflow.GradientOptimiser;
import org.tribuo.interop.tensorflow.OutputConverter;
import org.tribuo.interop.tensorflow.TensorFlowCheckpointModel;
import org.tribuo.interop.tensorflow.TensorFlowModel;
import org.tribuo.interop.tensorflow.TensorFlowNativeModel;
import org.tribuo.interop.tensorflow.TensorFlowUtil;
import org.tribuo.interop.tensorflow.TensorMap;
import org.tribuo.provenance.DatasetProvenance;
import org.tribuo.provenance.ModelProvenance;
import org.tribuo.provenance.SkeletalTrainerProvenance;
import org.tribuo.provenance.TrainerProvenance;

/**
 * Trains a TensorFlow graph on a Tribuo {@link Dataset} and wraps the result as a {@link TensorFlowModel}.
 * <p>
 * The graph is supplied as a protobuf file, an in-memory {@link GraphDef} or a live {@link Graph}.
 * During training the trainer attaches three things to the named output operation:
 * <ul>
 *     <li>a target placeholder;</li>
 *     <li>the loss and output transform supplied by the {@link OutputConverter};</li>
 *     <li>the optimiser selected by {@link GradientOptimiser}.</li>
 * </ul>
 * It then runs minibatch training for a fixed number of epochs. The trained graph is packaged as either a
 * {@link TensorFlowNativeModel} or a {@link TensorFlowCheckpointModel}, depending on {@link TFModelFormat}.
 * @param <T> The output type.
 */
public final class TensorFlowTrainer<T extends Output<T>>
implements Trainer<T> {
    private static final Logger logger = Logger.getLogger(TensorFlowTrainer.class.getName());
    @Config(mandatory=true, description="Path to the protobuf containing the graph.")
    private Path graphPath;
    // Either loaded from graphPath or supplied directly by a constructor.
    private GraphDef graphDef;
    @Config(description="Test time batch size.")
    private int testBatchSize = 16;
    @Config(mandatory=true, description="Name of the output operation before the loss.")
    private String outputName;
    @Config(mandatory=true, description="Feature extractor.")
    private FeatureConverter featureConverter;
    @Config(mandatory=true, description="Response extractor.")
    private OutputConverter<T> outputConverter;
    @Config(description="Training time batch size.")
    private int trainBatchSize = 1;
    @Config(description="Number of SGD epochs to run.")
    private int epochs = 5;
    @Config(description="Logging interval to print out the loss.")
    private int loggingInterval = 100;
    @Config(mandatory=true, description="The gradient optimiser to use.")
    private GradientOptimiser optimiserEnum;
    @Config(mandatory=true, description="The gradient optimiser parameters.")
    private Map<String, Float> gradientParams;
    @Config(description="Saved model format.")
    private TFModelFormat modelFormat = TFModelFormat.TRIBUO_NATIVE;
    @Config(description="Checkpoint output directory.")
    private Path checkpointPath;
    @Config(description="Inter operation thread pool size. -1 uses the default TF value. Tribuo defaults to 1 for deterministic behaviour.")
    private int interOpParallelism = 1;
    @Config(description="Intra operation thread pool size. -1 uses the default TF value. Tribuo defaults to 1 for deterministic behaviour.")
    private int intraOpParallelism = 1;
    // Number of train calls so far; used to build per-invocation checkpoint directories.
    private int trainInvocationCounter = 0;

    /**
     * No-arg constructor for configuration-driven construction.
     * Presumably invoked reflectively by OLCUT, which then calls {@link #postConfig()}.
     */
    private TensorFlowTrainer() {
    }

    /**
     * Builds a trainer which loads its graph from disk and emits a {@link TFModelFormat#TRIBUO_NATIVE} model.
     * @param graphPath Path to the serialized {@link GraphDef} protobuf.
     * @param outputName Name of the output operation before the loss.
     * @param optimiser The gradient optimiser to use.
     * @param gradientParams The optimiser's parameters (copied defensively).
     * @param featureConverter Converts examples into input tensors.
     * @param outputConverter Converts outputs into target tensors and supplies the loss.
     * @param trainBatchSize Minibatch size used during training (must be positive).
     * @param epochs Number of passes over the training data.
     * @param testBatchSize Batch size the produced model uses at prediction time.
     * @param loggingInterval Log the loss every this many minibatches; non-positive values disable logging.
     * @throws IOException If the graph could not be read from {@code graphPath}.
     */
    public TensorFlowTrainer(Path graphPath, String outputName, GradientOptimiser optimiser, Map<String, Float> gradientParams, FeatureConverter featureConverter, OutputConverter<T> outputConverter, int trainBatchSize, int epochs, int testBatchSize, int loggingInterval) throws IOException {
        this(graphPath, TensorFlowTrainer.loadGraphDef(graphPath), outputName, optimiser, gradientParams, featureConverter, outputConverter, trainBatchSize, epochs, testBatchSize, loggingInterval, null, TFModelFormat.TRIBUO_NATIVE);
    }

    /**
     * Builds a trainer which loads its graph from disk and emits a {@link TFModelFormat#CHECKPOINT} model.
     * <p>
     * The other parameters are as for
     * {@link #TensorFlowTrainer(Path, String, GradientOptimiser, Map, FeatureConverter, OutputConverter, int, int, int, int)}.
     * @param checkpointPath Root directory for checkpoints (must not be null).
     * @throws IOException If the graph could not be read from {@code graphPath}.
     */
    public TensorFlowTrainer(Path graphPath, String outputName, GradientOptimiser optimiser, Map<String, Float> gradientParams, FeatureConverter featureConverter, OutputConverter<T> outputConverter, int trainBatchSize, int epochs, int testBatchSize, int loggingInterval, Path checkpointPath) throws IOException {
        this(graphPath, TensorFlowTrainer.loadGraphDef(graphPath), outputName, optimiser, gradientParams, featureConverter, outputConverter, trainBatchSize, epochs, testBatchSize, loggingInterval, checkpointPath, TFModelFormat.CHECKPOINT);
    }

    /**
     * Builds a trainer from an in-memory {@link GraphDef} which emits a {@link TFModelFormat#TRIBUO_NATIVE} model.
     * <p>
     * The other parameters are as for
     * {@link #TensorFlowTrainer(Path, String, GradientOptimiser, Map, FeatureConverter, OutputConverter, int, int, int, int)}.
     * @param graphDef The graph definition to train.
     */
    public TensorFlowTrainer(GraphDef graphDef, String outputName, GradientOptimiser optimiser, Map<String, Float> gradientParams, FeatureConverter featureConverter, OutputConverter<T> outputConverter, int trainBatchSize, int epochs, int testBatchSize, int loggingInterval) {
        this(null, graphDef, outputName, optimiser, gradientParams, featureConverter, outputConverter, trainBatchSize, epochs, testBatchSize, loggingInterval, null, TFModelFormat.TRIBUO_NATIVE);
    }

    /**
     * Builds a trainer from an in-memory {@link GraphDef} which emits a {@link TFModelFormat#CHECKPOINT} model.
     * <p>
     * The other parameters are as for
     * {@link #TensorFlowTrainer(Path, String, GradientOptimiser, Map, FeatureConverter, OutputConverter, int, int, int, int)}.
     * @param graphDef The graph definition to train.
     * @param checkpointPath Root directory for checkpoints (must not be null).
     */
    public TensorFlowTrainer(GraphDef graphDef, String outputName, GradientOptimiser optimiser, Map<String, Float> gradientParams, FeatureConverter featureConverter, OutputConverter<T> outputConverter, int trainBatchSize, int epochs, int testBatchSize, int loggingInterval, Path checkpointPath) {
        this(null, graphDef, outputName, optimiser, gradientParams, featureConverter, outputConverter, trainBatchSize, epochs, testBatchSize, loggingInterval, checkpointPath, TFModelFormat.CHECKPOINT);
    }

    /**
     * Builds a trainer from a live {@link Graph} which emits a {@link TFModelFormat#TRIBUO_NATIVE} model.
     * <p>
     * The graph is copied via {@link Graph#toGraphDef()}; later changes to {@code graph} are not seen.
     * The other parameters are as for
     * {@link #TensorFlowTrainer(Path, String, GradientOptimiser, Map, FeatureConverter, OutputConverter, int, int, int, int)}.
     * @param graph The graph to train.
     */
    public TensorFlowTrainer(Graph graph, String outputName, GradientOptimiser optimiser, Map<String, Float> gradientParams, FeatureConverter featureConverter, OutputConverter<T> outputConverter, int trainBatchSize, int epochs, int testBatchSize, int loggingInterval) {
        this(null, graph.toGraphDef(), outputName, optimiser, gradientParams, featureConverter, outputConverter, trainBatchSize, epochs, testBatchSize, loggingInterval, null, TFModelFormat.TRIBUO_NATIVE);
    }

    /**
     * Builds a trainer from a live {@link Graph} which emits a {@link TFModelFormat#CHECKPOINT} model.
     * <p>
     * The graph is copied via {@link Graph#toGraphDef()}; later changes to {@code graph} are not seen.
     * The other parameters are as for
     * {@link #TensorFlowTrainer(Path, String, GradientOptimiser, Map, FeatureConverter, OutputConverter, int, int, int, int)}.
     * @param graph The graph to train.
     * @param checkpointPath Root directory for checkpoints (must not be null).
     */
    public TensorFlowTrainer(Graph graph, String outputName, GradientOptimiser optimiser, Map<String, Float> gradientParams, FeatureConverter featureConverter, OutputConverter<T> outputConverter, int trainBatchSize, int epochs, int testBatchSize, int loggingInterval, Path checkpointPath) {
        this(null, graph.toGraphDef(), outputName, optimiser, gradientParams, featureConverter, outputConverter, trainBatchSize, epochs, testBatchSize, loggingInterval, checkpointPath, TFModelFormat.CHECKPOINT);
    }

    /**
     * Canonical constructor; validates its arguments and the graph structure.
     * @throws IllegalArgumentException If neither a path nor a GraphDef is supplied, if
     *                                  {@code trainBatchSize} is not positive, if a checkpoint model is requested
     *                                  without a {@code checkpointPath}, or if the graph fails validation.
     */
    private TensorFlowTrainer(Path graphPath, GraphDef graphDef, String outputName, GradientOptimiser optimiser, Map<String, Float> gradientParams, FeatureConverter featureConverter, OutputConverter<T> outputConverter, int trainBatchSize, int epochs, int testBatchSize, int loggingInterval, Path checkpointPath, TFModelFormat modelFormat) {
        if (graphPath == null && graphDef == null) {
            throw new IllegalArgumentException("Must supply either a GraphDef or a path to a Graph");
        }
        // A non-positive batch size would make the minibatch loop in train() never advance.
        if (trainBatchSize < 1) {
            throw new IllegalArgumentException("trainBatchSize must be positive, found " + trainBatchSize);
        }
        // Mirror the postConfig check; otherwise train() only fails (NPE) after every epoch has run.
        if (modelFormat == TFModelFormat.CHECKPOINT && checkpointPath == null) {
            throw new IllegalArgumentException("Must supply a checkpointPath when using TFModelFormat.CHECKPOINT");
        }
        this.graphPath = graphPath;
        this.graphDef = graphDef;
        this.outputName = outputName;
        this.optimiserEnum = optimiser;
        this.gradientParams = Collections.unmodifiableMap(new HashMap<>(gradientParams));
        this.featureConverter = featureConverter;
        this.outputConverter = outputConverter;
        this.trainBatchSize = trainBatchSize;
        this.epochs = epochs;
        this.testBatchSize = testBatchSize;
        this.loggingInterval = loggingInterval;
        this.checkpointPath = checkpointPath;
        this.modelFormat = modelFormat;
        this.validateGraph(false);
    }

    /**
     * Loads the graph from {@code graphPath} and validates the configuration.
     * @throws IOException If the graph file could not be read.
     * @throws PropertyException If the configuration is inconsistent or the graph fails validation.
     */
    @Override
    public void postConfig() throws IOException {
        this.graphDef = TensorFlowTrainer.loadGraphDef(this.graphPath);
        if (this.checkpointPath == null && this.modelFormat == TFModelFormat.CHECKPOINT) {
            throw new PropertyException("", "checkpointPath", "Must set 'checkpointPath' when using TFModelFormat.CHECKPOINT");
        }
        if (this.trainBatchSize < 1) {
            throw new PropertyException("", "trainBatchSize", "Must set 'trainBatchSize' to a positive value, found " + this.trainBatchSize);
        }
        this.validateGraph(true);
    }

    /**
     * Checks that every input named by the feature converter exists, and that the output op exists and is rank 2.
     * @param throwPropertyException If true, failures are reported as {@link PropertyException};
     *                               otherwise as {@link IllegalArgumentException}.
     */
    private void validateGraph(boolean throwPropertyException) {
        try (Graph graph = new Graph()) {
            graph.importGraphDef(this.graphDef);
            for (String inputName : this.featureConverter.inputNamesSet()) {
                if (graph.operation(inputName) == null) {
                    throw validationError(throwPropertyException, "featureConverter",
                            "Unable to find an input operation, expected an op with name '" + inputName + "'");
                }
            }
            GraphOperation outputOp = graph.operation(this.outputName);
            if (outputOp == null) {
                throw validationError(throwPropertyException, "outputName",
                        "Unable to find the output operation, expected an op with name '" + this.outputName + "'");
            }
            Shape outputShape = outputOp.output(0).shape();
            if (outputShape.numDimensions() != 2) {
                throw validationError(throwPropertyException, "outputName",
                        "Expected a 2 dimensional output, found " + Arrays.toString(outputShape.asArray()));
            }
        }
    }

    /**
     * Creates the exception type {@link #validateGraph(boolean)} should throw for the current construction route.
     */
    private static RuntimeException validationError(boolean asPropertyException, String fieldName, String msg) {
        if (asPropertyException) {
            return new PropertyException("", fieldName, msg);
        }
        return new IllegalArgumentException(msg);
    }

    /**
     * Reads a serialized {@link GraphDef} protobuf from disk.
     * @param path The protobuf file.
     * @return The parsed graph definition.
     * @throws IOException If the file cannot be opened or parsed.
     */
    private static GraphDef loadGraphDef(Path path) throws IOException {
        try (InputStream stream = new BufferedInputStream(new FileInputStream(path.toFile()))) {
            return GraphDef.parseFrom(stream);
        }
    }

    @Override
    public TensorFlowModel<T> train(Dataset<T> examples) {
        return this.train(examples, Collections.emptyMap());
    }

    @Override
    public TensorFlowModel<T> train(Dataset<T> examples, Map<String, Provenance> runProvenance) {
        return this.train(examples, runProvenance, -1);
    }

    /**
     * Trains a model on {@code examples}.
     * <p>
     * Each call uses a fresh {@link Graph} and {@link Session}, built from the stored {@link GraphDef}.
     * In {@link TFModelFormat#CHECKPOINT} mode the session is saved under
     * {@code checkpointPath/invocation-N/tribuo}, where N is the invocation counter.
     * @param examples The training data.
     * @param runProvenance Extra provenance recorded in the model.
     * @param invocationCount If not -1, the invocation counter is reset to this value before training.
     * @return The trained model.
     * @throws IllegalArgumentException If the graph's output shape is incompatible with the batch size and output domain.
     * @throws IllegalStateException If TensorFlow reports an error.
     */
    // TF-Java operand generics cannot be expressed precisely through the converter APIs, hence the raw types.
    @SuppressWarnings({"rawtypes", "unchecked"})
    public TensorFlowModel<T> train(Dataset<T> examples, Map<String, Provenance> runProvenance, int invocationCount) {
        ImmutableFeatureMap featureMap = examples.getFeatureIDMap();
        ImmutableOutputInfo<T> outputInfo = examples.getOutputIDInfo();
        Path curCheckpointPath = this.advanceInvocation(invocationCount);
        ConfigProto config = this.buildSessionConfig();
        try (Graph graph = new Graph();
             Session session = new Session(graph, config)) {
            graph.importGraphDef(this.graphDef);
            Ops tf = Ops.create(graph).withName("tribuo-internal");
            org.tensorflow.Output intermediateOutputOp = graph.operation(this.outputName).output(0);
            Shape outputShape = intermediateOutputOp.shape();
            Shape expectedShape = Shape.of(this.trainBatchSize, outputInfo.size());
            if (!outputShape.isCompatibleWith(expectedShape)) {
                throw new IllegalArgumentException("Incompatible output shape, expected " + expectedShape.toString() + " found " + outputShape.toString());
            }
            Placeholder<TFloat32> targetPlaceholder = tf.placeholder(TFloat32.class, Placeholder.shape(expectedShape));
            Op outputOp = this.outputConverter.outputTransformFunction().apply(tf, intermediateOutputOp);
            Operand<TNumber> lossOp = this.outputConverter.loss().apply(tf, new Pair(targetPlaceholder, intermediateOutputOp));
            Op optimiser = this.optimiserEnum.applyOptimiser(graph, lossOp, this.gradientParams);
            session.initialize();
            logger.info("Initialised the model parameters");

            List<Example<T>> batch = new ArrayList<>();
            int interval = 0;
            for (int epoch = 0; epoch < this.epochs; ++epoch) {
                logger.log(Level.INFO, "Starting epoch " + epoch);
                for (int start = 0; start < examples.size(); start += this.trainBatchSize) {
                    // The final batch may be smaller than trainBatchSize.
                    batch.clear();
                    int end = Math.min(start + this.trainBatchSize, examples.size());
                    for (int k = start; k < end; ++k) {
                        batch.add(examples.getExample(k));
                    }
                    try (TensorMap input = this.featureConverter.convert(batch, featureMap);
                         Tensor target = this.outputConverter.convertToTensor(batch, outputInfo);
                         Tensor lossTensor = (Tensor) input.feedInto(session.runner()).feed(targetPlaceholder, target).addTarget(optimiser).fetch(lossOp).run().get(0)) {
                        // A non-positive interval disables logging; this also avoids a division by zero.
                        if (this.loggingInterval > 0 && interval % this.loggingInterval == 0) {
                            logger.log(Level.INFO, "Training loss at itr " + interval + " = " + ((TFloat32) lossTensor).getFloat());
                        }
                    }
                    ++interval;
                }
            }

            TensorFlowUtil.annotateGraph(graph, session);
            if (this.modelFormat == TFModelFormat.CHECKPOINT) {
                // Non-null: checkpointPath is validated for CHECKPOINT in the constructor and postConfig.
                session.save(curCheckpointPath.toString());
            }
            GraphDef trainedGraphDef = graph.toGraphDef();
            ModelProvenance modelProvenance = new ModelProvenance(TensorFlowModel.class.getName(), OffsetDateTime.now(), examples.getProvenance(), this.getProvenance(), runProvenance);
            TensorFlowModel<T> tfModel;
            switch (this.modelFormat) {
                case TRIBUO_NATIVE: {
                    Map<String, TensorFlowUtil.TensorTuple> tensorMap = TensorFlowUtil.extractMarshalledVariables(graph, session);
                    tfModel = new TensorFlowNativeModel<T>("tf-native-model", modelProvenance, featureMap, outputInfo, trainedGraphDef, tensorMap, this.testBatchSize, outputOp.op().name(), this.featureConverter, this.outputConverter);
                    break;
                }
                case CHECKPOINT: {
                    tfModel = new TensorFlowCheckpointModel<T>("tf-checkpoint-model", modelProvenance, featureMap, outputInfo, trainedGraphDef, curCheckpointPath.getParent().toString(), curCheckpointPath.getFileName().toString(), this.testBatchSize, outputOp.op().name(), this.featureConverter, this.outputConverter);
                    break;
                }
                default: {
                    throw new IllegalStateException("Unexpected enum constant " + this.modelFormat);
                }
            }
            return tfModel;
        } catch (TensorFlowException e) {
            logger.log(Level.SEVERE, "TensorFlow threw an error", e);
            throw new IllegalStateException(e);
        }
    }

    /**
     * Atomically applies an optional invocation-count override and increments the counter.
     * @param invocationCount The value to reset the counter to, or -1 to keep the current value.
     * @return This invocation's checkpoint path ({@code checkpointPath/invocation-N/tribuo}),
     *         or null if no checkpoint path is configured.
     */
    private synchronized Path advanceInvocation(int invocationCount) {
        if (invocationCount != -1) {
            this.setInvocationCount(invocationCount);
        }
        Path curCheckpointPath = this.checkpointPath != null
                ? Paths.get(this.checkpointPath.toString(), "invocation-" + this.trainInvocationCounter, "tribuo")
                : null;
        ++this.trainInvocationCounter;
        return curCheckpointPath;
    }

    /**
     * Builds the session configuration.
     * Thread pool sizes of -1 (or below) are left unset, so TensorFlow's defaults apply.
     */
    private ConfigProto buildSessionConfig() {
        ConfigProto.Builder configBuilder = ConfigProto.newBuilder();
        if (this.interOpParallelism > -1) {
            configBuilder.setInterOpParallelismThreads(this.interOpParallelism);
        }
        if (this.intraOpParallelism > -1) {
            configBuilder.setIntraOpParallelismThreads(this.intraOpParallelism);
        }
        return configBuilder.build();
    }

    @Override
    public String toString() {
        String path = this.graphPath == null ? "" : this.graphPath.toString();
        String output = "TFTrainer(graphPath=" + path + ",exampleConverter=" + this.featureConverter.toString() + ",outputConverter=" + this.outputConverter.toString() + ",minibatchSize=" + this.trainBatchSize + ",epochs=" + this.epochs + ",gradientOptimizer=" + this.optimiserEnum + ",gradientParams=" + this.gradientParams.toString() + ",modelFormat=" + this.modelFormat;
        if (this.modelFormat == TFModelFormat.CHECKPOINT) {
            // String concatenation is null-safe, unlike checkpointPath.toString().
            return output + ",checkpointPath=" + this.checkpointPath + ")";
        }
        return output + ")";
    }

    /**
     * Returns the number of times {@code train} has been invoked (or the last value set).
     * @return The invocation count.
     */
    @Override
    public int getInvocationCount() {
        return this.trainInvocationCounter;
    }

    /**
     * Sets the invocation counter, which selects the next checkpoint subdirectory.
     * @param invocationCount The new count.
     * @throws IllegalArgumentException If {@code invocationCount} is negative.
     */
    public void setInvocationCount(int invocationCount) {
        if (invocationCount < 0) {
            throw new IllegalArgumentException("The supplied invocationCount is less than zero.");
        }
        this.trainInvocationCounter = invocationCount;
    }

    @Override
    public TrainerProvenance getProvenance() {
        return new TensorFlowTrainerProvenance(this);
    }

    /**
     * The format in which the trained model state is stored.
     */
    public static enum TFModelFormat {
        /**
         * The trained variables are extracted from the session and stored inside a {@link TensorFlowNativeModel}.
         */
        TRIBUO_NATIVE,
        /**
         * The session is saved to a checkpoint on disk, which a {@link TensorFlowCheckpointModel} refers to.
         */
        CHECKPOINT;

    }

    /**
     * Provenance for {@link TensorFlowTrainer}.
     * It records a hash of the graph and a last-modified timestamp alongside the standard trainer provenance.
     */
    public static final class TensorFlowTrainerProvenance
    extends SkeletalTrainerProvenance {
        private static final long serialVersionUID = 1L;
        public static final String GRAPH_HASH = "graph-hash";
        public static final String GRAPH_LAST_MOD = "graph-last-modified";
        private final HashProvenance graphHash;
        private final DateTimeProvenance graphLastModified;

        /**
         * Captures provenance from a live trainer.
         * <p>
         * If the trainer was built from a file, the file is hashed and its modification time is recorded.
         * Otherwise the in-memory GraphDef bytes are hashed and the current time is recorded.
         */
        <T extends Output<T>> TensorFlowTrainerProvenance(TensorFlowTrainer<T> host) {
            super(host);
            if (host.graphPath != null) {
                this.graphHash = new HashProvenance(DEFAULT_HASH_TYPE, GRAPH_HASH, ProvenanceUtil.hashResource(DEFAULT_HASH_TYPE, host.graphPath));
                this.graphLastModified = new DateTimeProvenance(GRAPH_LAST_MOD, OffsetDateTime.ofInstant(Instant.ofEpochMilli(host.graphPath.toFile().lastModified()), ZoneId.systemDefault()));
            } else {
                this.graphHash = new HashProvenance(DEFAULT_HASH_TYPE, GRAPH_HASH, ProvenanceUtil.hashArray(DEFAULT_HASH_TYPE, host.graphDef.toByteArray()));
                this.graphLastModified = new DateTimeProvenance(GRAPH_LAST_MOD, OffsetDateTime.now());
            }
        }

        /**
         * Deserialization constructor.
         * @param map The provenance map.
         */
        public TensorFlowTrainerProvenance(Map<String, Provenance> map) {
            this(TensorFlowTrainerProvenance.extractTFProvenanceInfo(map));
        }

        private TensorFlowTrainerProvenance(SkeletalConfiguredObjectProvenance.ExtractedInfo info) {
            super(info);
            this.graphHash = (HashProvenance) info.instanceValues.get(GRAPH_HASH);
            this.graphLastModified = (DateTimeProvenance) info.instanceValues.get(GRAPH_LAST_MOD);
        }

        @Override
        public Map<String, PrimitiveProvenance<?>> getInstanceValues() {
            Map<String, PrimitiveProvenance<?>> map = super.getInstanceValues();
            map.put(this.graphHash.getKey(), this.graphHash);
            map.put(this.graphLastModified.getKey(), this.graphLastModified);
            return map;
        }

        /**
         * Extracts the standard trainer provenance, then the graph hash and timestamp entries.
         * @param map The provenance map.
         * @return The extracted info, with the graph entries added to its instance values.
         */
        protected static SkeletalConfiguredObjectProvenance.ExtractedInfo extractTFProvenanceInfo(Map<String, Provenance> map) {
            SkeletalConfiguredObjectProvenance.ExtractedInfo info = SkeletalTrainerProvenance.extractProvenanceInfo(map);
            String provClassName = TensorFlowTrainerProvenance.class.getSimpleName();
            info.instanceValues.put(GRAPH_HASH, ObjectProvenance.checkAndExtractProvenance(info.configuredParameters, GRAPH_HASH, HashProvenance.class, provClassName));
            info.instanceValues.put(GRAPH_LAST_MOD, ObjectProvenance.checkAndExtractProvenance(info.configuredParameters, GRAPH_LAST_MOD, DateTimeProvenance.class, provClassName));
            return info;
        }

        @Override
        public boolean equals(Object o) {
            if (this == o) {
                return true;
            }
            if (o == null || this.getClass() != o.getClass()) {
                return false;
            }
            if (!super.equals(o)) {
                return false;
            }
            TensorFlowTrainerProvenance other = (TensorFlowTrainerProvenance) o;
            return this.graphHash.equals(other.graphHash) && this.graphLastModified.equals(other.graphLastModified);
        }

        @Override
        public int hashCode() {
            return Objects.hash(super.hashCode(), this.graphHash, this.graphLastModified);
        }
    }
}

