/*
 * Decompiled with CFR 0.152.
 */
package org.tribuo.interop.tensorflow.sequence;

import com.oracle.labs.mlrg.olcut.config.Config;
import com.oracle.labs.mlrg.olcut.provenance.PrimitiveProvenance;
import com.oracle.labs.mlrg.olcut.provenance.Provenance;
import com.oracle.labs.mlrg.olcut.provenance.ProvenanceException;
import com.oracle.labs.mlrg.olcut.provenance.ProvenanceUtil;
import com.oracle.labs.mlrg.olcut.provenance.impl.SkeletalConfiguredObjectProvenance;
import com.oracle.labs.mlrg.olcut.provenance.primitives.DateTimeProvenance;
import com.oracle.labs.mlrg.olcut.provenance.primitives.HashProvenance;
import com.oracle.labs.mlrg.olcut.provenance.primitives.StringProvenance;
import java.io.BufferedInputStream;
import java.io.FileInputStream;
import java.io.IOException;
import java.io.InputStream;
import java.nio.file.Path;
import java.time.Instant;
import java.time.OffsetDateTime;
import java.time.ZoneId;
import java.util.ArrayList;
import java.util.Collections;
import java.util.List;
import java.util.Map;
import java.util.SplittableRandom;
import java.util.logging.Level;
import java.util.logging.Logger;
import org.tensorflow.Graph;
import org.tensorflow.Session;
import org.tensorflow.Tensor;
import org.tensorflow.exceptions.TensorFlowException;
import org.tensorflow.proto.framework.GraphDef;
import org.tensorflow.types.TFloat32;
import org.tribuo.ImmutableFeatureMap;
import org.tribuo.ImmutableOutputInfo;
import org.tribuo.Output;
import org.tribuo.interop.tensorflow.TensorFlowUtil;
import org.tribuo.interop.tensorflow.TensorMap;
import org.tribuo.interop.tensorflow.sequence.SequenceFeatureConverter;
import org.tribuo.interop.tensorflow.sequence.SequenceOutputConverter;
import org.tribuo.interop.tensorflow.sequence.TensorFlowSequenceModel;
import org.tribuo.provenance.DatasetProvenance;
import org.tribuo.provenance.ModelProvenance;
import org.tribuo.provenance.SkeletalTrainerProvenance;
import org.tribuo.provenance.TrainerProvenance;
import org.tribuo.sequence.SequenceDataset;
import org.tribuo.sequence.SequenceExample;
import org.tribuo.sequence.SequenceModel;
import org.tribuo.sequence.SequenceTrainer;
import org.tribuo.util.Util;

/**
 * A {@link SequenceTrainer} that trains a TensorFlow graph, loaded from a serialized
 * {@link GraphDef} protobuf on disk, using minibatches of sequence examples.
 * <p>
 * Each epoch shuffles the dataset, encodes each minibatch with the configured feature and
 * output converters, and runs the configured training operation while fetching the loss
 * operation. After training, the graph's variables are extracted and wrapped in a
 * {@link TensorFlowSequenceModel}.
 * <p>
 * The RNG split and the invocation counter are updated while holding this trainer's monitor.
 * The TensorFlow graph and session are created fresh inside each {@link #train} call.
 *
 * @param <T> The output type.
 */
public class TensorFlowSequenceTrainer<T extends Output<T>>
implements SequenceTrainer<T> {
    private static final Logger log = Logger.getLogger(TensorFlowSequenceTrainer.class.getName());

    @Config(mandatory=true, description="Path to the protobuf containing the TensorFlow graph.")
    protected Path graphPath;

    /** Graph definition parsed from {@link #graphPath} in {@link #postConfig()}. */
    private GraphDef graphDef;

    @Config(mandatory=true, description="Sequence feature extractor.")
    protected SequenceFeatureConverter featureConverter;

    @Config(mandatory=true, description="Sequence output extractor.")
    protected SequenceOutputConverter<T> outputConverter;

    @Config(description="Minibatch size")
    protected int minibatchSize = 1;

    @Config(description="Number of SGD epochs to run.")
    protected int epochs = 5;

    @Config(description="Logging interval to print the loss.")
    protected int loggingInterval = 100;

    @Config(description="Seed for the RNG.")
    protected long seed = 1L;

    @Config(mandatory=true, description="Name of the training operation.")
    protected String trainOp;

    @Config(mandatory=true, description="Name of the loss operation (to inspect the loss).")
    protected String getLossOp;

    @Config(mandatory=true, description="Name of the prediction operation.")
    protected String predictOp;

    /** Master RNG; each train call takes an independent split of it. */
    protected SplittableRandom rng;

    /** Number of times {@link #train} has been invoked; guarded by {@code this}. */
    protected int trainInvocationCounter;

    /**
     * Constructs a trainer and immediately loads the graph from {@code graphPath}.
     *
     * @param graphPath Path to the serialized {@link GraphDef} protobuf.
     * @param featureConverter Converts sequence examples into feature tensors.
     * @param outputConverter Converts sequence outputs into supervision tensors, and back.
     * @param minibatchSize Number of examples per training step; must be positive.
     * @param epochs Number of passes over the data.
     * @param loggingInterval Log the loss every this many training steps; must be positive.
     * @param seed Seed for the RNG used to shuffle the data.
     * @param trainOp Name of the training operation (added as a run target).
     * @param getLossOp Name of the loss operation (fetched each step).
     * @param predictOp Name of the prediction operation, passed to the produced model.
     * @throws IOException If the graph file cannot be read or parsed.
     * @throws IllegalArgumentException If {@code minibatchSize} or {@code loggingInterval} is not positive.
     */
    public TensorFlowSequenceTrainer(Path graphPath, SequenceFeatureConverter featureConverter, SequenceOutputConverter<T> outputConverter, int minibatchSize, int epochs, int loggingInterval, long seed, String trainOp, String getLossOp, String predictOp) throws IOException {
        this.graphPath = graphPath;
        this.featureConverter = featureConverter;
        this.outputConverter = outputConverter;
        this.minibatchSize = minibatchSize;
        this.epochs = epochs;
        this.loggingInterval = loggingInterval;
        this.seed = seed;
        this.trainOp = trainOp;
        this.getLossOp = getLossOp;
        this.predictOp = predictOp;
        this.postConfig();
    }

    /**
     * For the configuration system.
     */
    private TensorFlowSequenceTrainer() {
    }

    /**
     * Validates the numeric parameters, seeds the RNG and parses the graph protobuf from
     * {@link #graphPath}.
     *
     * @throws IOException If the graph file cannot be read or parsed.
     * @throws IllegalArgumentException If {@code minibatchSize} or {@code loggingInterval} is not positive.
     */
    @Override
    public synchronized void postConfig() throws IOException {
        // A non-positive minibatch size would make the batch loop in train() never advance.
        if (this.minibatchSize < 1) {
            throw new IllegalArgumentException("minibatchSize must be positive, found " + this.minibatchSize);
        }
        // loggingInterval is used as a modulus in train(); zero would throw ArithmeticException there.
        if (this.loggingInterval < 1) {
            throw new IllegalArgumentException("loggingInterval must be positive, found " + this.loggingInterval);
        }
        this.rng = new SplittableRandom(this.seed);
        // Close the file handle once the protobuf has been parsed.
        try (InputStream graphStream = new BufferedInputStream(new FileInputStream(this.graphPath.toFile()))) {
            this.graphDef = GraphDef.parseFrom(graphStream);
        }
    }

    /**
     * Trains a TensorFlow sequence model on the supplied examples.
     * <p>
     * A fresh {@link Graph} and {@link Session} are created from the stored graph definition.
     * {@link #preTrainingHook} is called, then {@link #epochs} passes are made over a reshuffled
     * permutation of the dataset in minibatches of {@link #minibatchSize}. Afterwards the graph
     * is annotated, its variables are extracted, and both are wrapped into a
     * {@link TensorFlowSequenceModel}.
     *
     * @param examples The training data.
     * @param runProvenance Extra provenance recorded in the model provenance.
     * @return The trained model.
     * @throws IllegalStateException If TensorFlow throws a {@link TensorFlowException}.
     */
    @Override
    public SequenceModel<T> train(SequenceDataset<T> examples, Map<String, Provenance> runProvenance) {
        final SplittableRandom localRNG;
        final TrainerProvenance provenance;
        synchronized (this) {
            localRNG = this.rng.split();
            provenance = this.getProvenance();
            ++this.trainInvocationCounter;
        }
        ImmutableFeatureMap featureMap = examples.getFeatureIDMap();
        ImmutableOutputInfo<T> labelMap = examples.getOutputIDInfo();
        List<SequenceExample<T>> batch = new ArrayList<>();
        int[] indices = Util.randperm(examples.size(), localRNG);
        try (Graph graph = new Graph();
             Session session = new Session(graph)) {
            graph.importGraphDef(this.graphDef);
            this.preTrainingHook(session, examples);
            // Global step counter across all epochs; drives loss logging.
            int interval = 0;
            for (int i = 0; i < this.epochs; ++i) {
                log.log(Level.INFO, "Starting epoch " + i);
                Util.randpermInPlace(indices, localRNG);
                for (int j = 0; j < examples.size(); j += this.minibatchSize) {
                    int batchEnd = Math.min(examples.size(), j + this.minibatchSize);
                    batch.clear();
                    for (int k = j; k < batchEnd; ++k) {
                        batch.add(examples.getExample(indices[k]));
                    }
                    // try-with-resources releases the native tensors even if the step throws.
                    try (TensorMap featureTensors = this.featureConverter.encode(batch, featureMap);
                         TensorMap supervisionTensors = this.outputConverter.encode(batch, labelMap);
                         TensorMap parameterTensors = this.getHyperparameterFeed()) {
                        Session.Runner runner = session.runner();
                        featureTensors.feedInto(runner);
                        supervisionTensors.feedInto(runner);
                        parameterTensors.feedInto(runner);
                        try (Tensor loss = runner.addTarget(this.trainOp).fetch(this.getLossOp).run().get(0)) {
                            if (interval % this.loggingInterval == 0) {
                                // NOTE(review): assumes the loss op yields a float32 tensor with an element at index 0.
                                float lossVal = ((TFloat32) loss).getFloat(new long[]{0L});
                                log.info(String.format("loss %-5.6f [epoch %-2d batch %-4d #(%d - %d)/%d]", lossVal, i, interval, j, batchEnd, examples.size()));
                            }
                            ++interval;
                        }
                    }
                }
            }
            TensorFlowUtil.annotateGraph(graph, session);
            GraphDef trainedGraphDef = graph.toGraphDef();
            Map<String, TensorFlowUtil.TensorTuple> tensorMap = TensorFlowUtil.extractMarshalledVariables(graph, session);
            ModelProvenance modelProvenance = new ModelProvenance(TensorFlowSequenceModel.class.getName(), OffsetDateTime.now(), (DatasetProvenance) examples.getProvenance(), provenance, runProvenance);
            return new TensorFlowSequenceModel<T>("tf-sequence-model", modelProvenance, featureMap, labelMap, trainedGraphDef, this.featureConverter, this.outputConverter, this.predictOp, tensorMap);
        } catch (TensorFlowException e) {
            log.log(Level.SEVERE, "TensorFlow threw an error", e);
            throw new IllegalStateException(e);
        }
    }

    /**
     * Returns the number of times {@link #train} has been called on this trainer.
     * <p>
     * Synchronized because the counter is incremented while holding this trainer's monitor.
     *
     * @return The invocation count.
     */
    @Override
    public synchronized int getInvocationCount() {
        return this.trainInvocationCounter;
    }

    @Override
    public String toString() {
        return "TensorflowSequenceTrainer(graphPath=" + this.graphPath.toString() + ",exampleConverter=" + this.featureConverter.toString() + ",outputConverter=" + this.outputConverter.toString() + ",minibatchSize=" + this.minibatchSize + ",epochs=" + this.epochs + ",seed=" + this.seed + ")";
    }

    /**
     * Hook called after the graph has been imported into the session and before the first
     * epoch. The default implementation does nothing; subclasses may override it, e.g. to run
     * initialisation ops.
     *
     * @param session The session that will be used for training.
     * @param examples The training data.
     */
    protected void preTrainingHook(Session session, SequenceDataset<T> examples) {
    }

    /**
     * Supplies extra tensors to feed on every training step, alongside the feature and
     * supervision tensors. The returned map is closed after each step.
     * The default implementation returns an empty map.
     *
     * @return The hyperparameter tensors to feed.
     */
    protected TensorMap getHyperparameterFeed() {
        return new TensorMap(Collections.emptyMap());
    }

    @Override
    public TrainerProvenance getProvenance() {
        return new TensorFlowSequenceTrainerProvenance(this);
    }

    /**
     * Provenance for {@link TensorFlowSequenceTrainer}.
     * <p>
     * In addition to the configured parameters, it records a hash of the graph protobuf and
     * the file's last-modified time.
     */
    public static class TensorFlowSequenceTrainerProvenance
    extends SkeletalTrainerProvenance {
        private static final long serialVersionUID = 1L;
        public static final String GRAPH_HASH = "graph-hash";
        public static final String GRAPH_LAST_MOD = "graph-last-modified";
        private final StringProvenance graphHash;
        private final DateTimeProvenance graphLastModified;

        /**
         * Builds provenance from a live trainer by hashing its graph file and reading the file's
         * last-modified timestamp.
         *
         * @param host The trainer.
         * @param <T> The output type.
         */
        <T extends Output<T>> TensorFlowSequenceTrainerProvenance(TensorFlowSequenceTrainer<T> host) {
            super(host);
            this.graphHash = new StringProvenance(GRAPH_HASH, ProvenanceUtil.hashResource(DEFAULT_HASH_TYPE, host.graphPath));
            this.graphLastModified = new DateTimeProvenance(GRAPH_LAST_MOD, OffsetDateTime.ofInstant(Instant.ofEpochMilli(host.graphPath.toFile().lastModified()), ZoneId.systemDefault()));
        }

        /**
         * Deserialization constructor.
         *
         * @param map The provenance map.
         * @throws ProvenanceException If the graph hash or last-modified entries are missing or mistyped.
         */
        public TensorFlowSequenceTrainerProvenance(Map<String, Provenance> map) {
            this(TensorFlowSequenceTrainerProvenance.extractTFProvenanceInfo(map));
        }

        private TensorFlowSequenceTrainerProvenance(SkeletalConfiguredObjectProvenance.ExtractedInfo info) {
            super(info);
            // extractTFProvenanceInfo guarantees these types.
            this.graphHash = (StringProvenance) info.instanceValues.get(GRAPH_HASH);
            this.graphLastModified = (DateTimeProvenance) info.instanceValues.get(GRAPH_LAST_MOD);
        }

        @Override
        public Map<String, PrimitiveProvenance<?>> getInstanceValues() {
            Map<String, PrimitiveProvenance<?>> map = super.getInstanceValues();
            map.put(this.graphHash.getKey(), this.graphHash);
            map.put(this.graphLastModified.getKey(), this.graphLastModified);
            return map;
        }

        /**
         * Extracts the standard trainer provenance, then moves the graph hash and
         * last-modified entries from the configured parameters into the instance values.
         * <p>
         * The graph hash is written by this class as a {@link StringProvenance}. A
         * {@link HashProvenance} is also accepted and converted to a {@link StringProvenance}
         * carrying the same value.
         *
         * @param map The provenance map.
         * @return The extracted info.
         * @throws ProvenanceException If either entry is missing or has an unexpected type.
         */
        protected static SkeletalConfiguredObjectProvenance.ExtractedInfo extractTFProvenanceInfo(Map<String, Provenance> map) {
            SkeletalConfiguredObjectProvenance.ExtractedInfo info = SkeletalTrainerProvenance.extractProvenanceInfo(map);

            if (!info.configuredParameters.containsKey(GRAPH_HASH)) {
                throw new ProvenanceException("Failed to find graph-hash when constructing SkeletalTrainerProvenance");
            }
            Provenance hashProv = info.configuredParameters.remove(GRAPH_HASH);
            if (hashProv instanceof StringProvenance) {
                info.instanceValues.put(GRAPH_HASH, (StringProvenance) hashProv);
            } else if (hashProv instanceof HashProvenance) {
                info.instanceValues.put(GRAPH_HASH, new StringProvenance(GRAPH_HASH, ((HashProvenance) hashProv).getValue()));
            } else {
                throw new ProvenanceException("graph-hash was not of type StringProvenance or HashProvenance in class " + info.className);
            }

            if (!info.configuredParameters.containsKey(GRAPH_LAST_MOD)) {
                throw new ProvenanceException("Failed to find graph-last-modified when constructing SkeletalTrainerProvenance");
            }
            Provenance lastModProv = info.configuredParameters.remove(GRAPH_LAST_MOD);
            if (!(lastModProv instanceof DateTimeProvenance)) {
                throw new ProvenanceException("graph-last-modified was not of type DateTimeProvenance in class " + info.className);
            }
            info.instanceValues.put(GRAPH_LAST_MOD, (DateTimeProvenance) lastModProv);
            return info;
        }
    }
}

