/*
 * Decompiled with CFR 0.152.
 */
package org.tribuo.interop.tensorflow;

import com.google.protobuf.Any;
import com.google.protobuf.ByteString;
import com.oracle.labs.mlrg.olcut.config.Configurable;
import com.oracle.labs.mlrg.olcut.provenance.ConfiguredObjectProvenance;
import com.oracle.labs.mlrg.olcut.provenance.impl.ConfiguredObjectProvenanceImpl;
import com.oracle.labs.mlrg.olcut.util.Pair;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.HashMap;
import java.util.HashSet;
import java.util.List;
import java.util.function.BiFunction;
import java.util.logging.Logger;
import org.tensorflow.Operand;
import org.tensorflow.Tensor;
import org.tensorflow.framework.op.FrameworkOps;
import org.tensorflow.ndarray.FloatNdArray;
import org.tensorflow.ndarray.Shape;
import org.tensorflow.ndarray.index.Index;
import org.tensorflow.ndarray.index.Indices;
import org.tensorflow.op.Op;
import org.tensorflow.op.Ops;
import org.tensorflow.op.core.Placeholder;
import org.tensorflow.op.math.Mean;
import org.tensorflow.types.TFloat16;
import org.tensorflow.types.TFloat32;
import org.tensorflow.types.family.TNumber;
import org.tribuo.Example;
import org.tribuo.ImmutableOutputInfo;
import org.tribuo.Output;
import org.tribuo.Prediction;
import org.tribuo.classification.Label;
import org.tribuo.interop.tensorflow.OutputConverter;
import org.tribuo.interop.tensorflow.protos.OutputConverterProto;
import org.tribuo.math.la.SparseVector;
import org.tribuo.math.la.VectorTuple;
import org.tribuo.multilabel.MultiLabel;
import org.tribuo.protos.ProtoSerializable;
import org.tribuo.protos.ProtoSerializableClass;
import org.tribuo.protos.ProtoUtil;

/**
 * Converts between Tribuo {@link MultiLabel} outputs and TensorFlow tensors.
 * <p>
 * Training targets are written as dense float matrices with one row per example
 * and one column per output dimension. Predicted scores are read back from a
 * rank-2 float tensor and a label is considered predicted when its score is
 * strictly greater than {@link #THRESHOLD}.
 * <p>
 * The converter has no fields, so its serialized form carries no payload.
 */
@ProtoSerializableClass(version=0)
public class MultiLabelConverter
implements OutputConverter<MultiLabel> {
    private static final long serialVersionUID = 1L;
    private static final Logger logger = Logger.getLogger(MultiLabelConverter.class.getName());

    /**
     * Protobuf serialization version.
     */
    public static final int CURRENT_VERSION = 0;

    /**
     * Scores strictly above this value cause the label to be emitted.
     */
    public static final double THRESHOLD = 0.5;

    /**
     * Deserialization factory.
     *
     * @param version   The serialized object version.
     * @param className The class name (not used by this implementation).
     * @param message   The serialized data; its value must be empty.
     * @return A new converter.
     * @throws IllegalArgumentException If the version is unsupported or the message carries a payload.
     */
    public static MultiLabelConverter deserializeFromProto(int version, String className, Any message) {
        if (version < 0 || version > CURRENT_VERSION) {
            throw new IllegalArgumentException("Unknown version " + version + ", this class supports at most version " + CURRENT_VERSION);
        }
        // Compare by content rather than by reference: an empty payload is valid
        // whether or not it happens to be the ByteString.EMPTY singleton.
        if (!message.getValue().isEmpty()) {
            throw new IllegalArgumentException("Invalid proto");
        }
        return new MultiLabelConverter();
    }

    /**
     * Serializes this converter by delegating to {@link ProtoUtil#serialize}.
     *
     * @return The protobuf representation.
     */
    public OutputConverterProto serialize() {
        return ProtoUtil.serialize(this);
    }

    /**
     * Builds the loss: sigmoid cross entropy between the target placeholder and
     * the logits, reduced with a mean over axis 0.
     *
     * @return A function from the ops and the (targets, logits) pair to the loss operand.
     */
    @Override
    public BiFunction<Ops, Pair<Placeholder<? extends TNumber>, Operand<TNumber>>, Operand<TNumber>> loss() {
        return (ops, pair) -> {
            FrameworkOps frameworkOps = FrameworkOps.create(ops);
            @SuppressWarnings("unchecked") // cast off the wildcard so both operands share a type parameter
            Placeholder<TNumber> targets = (Placeholder<TNumber>) pair.getA();
            Operand<TNumber> perElementLoss = frameworkOps.nn.sigmoidCrossEntropyWithLogits(targets, pair.getB());
            return ops.math.mean(perElementLoss, ops.constant(0));
        };
    }

    /**
     * Builds the output transformation, which applies an element-wise sigmoid to the logits.
     *
     * @param <V> The numeric type of the logits.
     * @return A function from the ops and the logits to the transformed op.
     */
    @Override
    public <V extends TNumber> BiFunction<Ops, Operand<V>, Op> outputTransformFunction() {
        return (ops, logits) -> ops.math.sigmoid(logits);
    }

    /**
     * Converts a tensor holding exactly one row of scores into a prediction.
     *
     * @param tensor           The scores, shape {@code [1, outputIDInfo.size()]}.
     * @param outputIDInfo     The output domain.
     * @param numValidFeatures The number of features used by the example.
     * @param example          The example that was scored.
     * @return The prediction for the example.
     * @throws IllegalArgumentException If the tensor shape or type is invalid, or the batch size is not one.
     */
    @Override
    public Prediction<MultiLabel> convertToPrediction(Tensor tensor, ImmutableOutputInfo<MultiLabel> outputIDInfo, int numValidFeatures, Example<MultiLabel> example) {
        FloatNdArray predictions = this.getBatchPredictions(tensor, outputIDInfo);
        return this.generatePrediction(MultiLabelConverter.onlyRow(predictions), outputIDInfo, numValidFeatures, example);
    }

    /**
     * Builds a prediction from a single row of scores. Every label is recorded in
     * the score map; only labels scoring above {@link #THRESHOLD} are placed in the
     * predicted output.
     */
    private Prediction<MultiLabel> generatePrediction(FloatNdArray predictions, ImmutableOutputInfo<MultiLabel> outputIDInfo, int numUsed, Example<MultiLabel> example) {
        int length = MultiLabelConverter.vectorLength(predictions);
        HashMap<String, MultiLabel> fullLabels = new HashMap<>(outputIDInfo.size());
        HashSet<Label> predictedLabels = new HashSet<>();
        for (int i = 0; i < length; ++i) {
            String labelName = outputIDInfo.getOutput(i).getLabelString();
            double labelScore = predictions.getFloat(i);
            Label score = new Label(labelName, labelScore);
            if (labelScore > THRESHOLD) {
                predictedLabels.add(score);
            }
            fullLabels.put(labelName, new MultiLabel(score));
        }
        return new Prediction<>(new MultiLabel(predictedLabels), fullLabels, numUsed, example, true);
    }

    /**
     * Converts a tensor holding exactly one row of scores into a {@link MultiLabel}.
     *
     * @param tensor       The scores, shape {@code [1, outputIDInfo.size()]}.
     * @param outputIDInfo The output domain.
     * @return The labels whose score exceeds {@link #THRESHOLD}.
     * @throws IllegalArgumentException If the tensor shape or type is invalid, or the batch size is not one.
     */
    @Override
    public MultiLabel convertToOutput(Tensor tensor, ImmutableOutputInfo<MultiLabel> outputIDInfo) {
        FloatNdArray predictions = this.getBatchPredictions(tensor, outputIDInfo);
        return this.generateMultiLabel(MultiLabelConverter.onlyRow(predictions), outputIDInfo);
    }

    /**
     * Builds a {@link MultiLabel} from a single row of scores, keeping only the
     * labels scoring above {@link #THRESHOLD}.
     */
    private MultiLabel generateMultiLabel(FloatNdArray predictions, ImmutableOutputInfo<MultiLabel> outputIDInfo) {
        int length = MultiLabelConverter.vectorLength(predictions);
        HashSet<Label> predictedLabels = new HashSet<>();
        for (int i = 0; i < length; ++i) {
            double labelScore = predictions.getFloat(i);
            // Only allocate a Label for scores that are actually emitted.
            if (labelScore > THRESHOLD) {
                predictedLabels.add(new Label(outputIDInfo.getOutput(i).getLabelString(), labelScore));
            }
        }
        return new MultiLabel(predictedLabels);
    }

    /**
     * Checks that the supplied array is rank 1 and returns its length as an int.
     *
     * @throws IllegalArgumentException If the array is not rank 1 or is longer than {@link Integer#MAX_VALUE}.
     */
    private static int vectorLength(FloatNdArray predictions) {
        long[] shape = predictions.shape().asArray();
        if (shape.length != 1) {
            throw new IllegalArgumentException("Failed to get scalar predictions. Found " + Arrays.toString(shape));
        }
        if (shape[0] > Integer.MAX_VALUE) {
            throw new IllegalArgumentException("More than Integer.MAX_VALUE predictions. Found " + shape[0]);
        }
        return (int) shape[0];
    }

    /**
     * Returns row zero of a batch after checking that the batch holds exactly one row.
     *
     * @throws IllegalArgumentException If the leading dimension is not one.
     */
    private static FloatNdArray onlyRow(FloatNdArray predictions) {
        long batchSize = predictions.shape().asArray()[0];
        if (batchSize != 1L) {
            throw new IllegalArgumentException("Supplied tensor has too many results, batchSize = " + batchSize);
        }
        return predictions.slice(Indices.at(0L), Indices.all());
    }

    /**
     * Validates the tensor and exposes it as a float array.
     * <p>
     * The tensor must be rank 2, its second dimension must equal the number of
     * outputs, and it must be a {@link TFloat16} or a {@link TFloat32}.
     *
     * @throws IllegalArgumentException If any of those checks fail.
     */
    private FloatNdArray getBatchPredictions(Tensor tensor, ImmutableOutputInfo<MultiLabel> outputIDInfo) {
        long[] shape = tensor.shape().asArray();
        if (shape.length != 2) {
            throw new IllegalArgumentException("Supplied tensor has the wrong number of dimensions, shape = " + Arrays.toString(shape));
        }
        // Compare as long so an oversized dimension cannot wrap around to a matching int.
        long numValues = shape[1];
        if (numValues != outputIDInfo.size()) {
            throw new IllegalArgumentException("Supplied tensor has incorrect number of elements, tensor output dimension: " + numValues + ", outputInfo dimension: " + outputIDInfo.size());
        }
        if (tensor instanceof TFloat16) {
            return (TFloat16) tensor;
        }
        if (tensor instanceof TFloat32) {
            return (TFloat32) tensor;
        }
        throw new IllegalArgumentException("Tensor is not a probability distribution. Found type " + tensor.getClass().getName());
    }

    /**
     * Converts a batch of scores into one prediction per example.
     *
     * @param tensor           The scores, shape {@code [examples.size(), outputIDInfo.size()]}.
     * @param outputIDInfo     The output domain.
     * @param numValidFeatures The number of features used by each example.
     * @param examples         The examples that were scored.
     * @return The predictions, in the same order as the examples.
     * @throws IllegalArgumentException If the tensor is invalid or the batch size does not match
     *                                  both {@code examples} and {@code numValidFeatures}.
     */
    @Override
    public List<Prediction<MultiLabel>> convertToBatchPrediction(Tensor tensor, ImmutableOutputInfo<MultiLabel> outputIDInfo, int[] numValidFeatures, List<Example<MultiLabel>> examples) {
        FloatNdArray predictions = this.getBatchPredictions(tensor, outputIDInfo);
        // Validate before narrowing so a huge leading dimension cannot alias a valid int.
        long batchSize = predictions.shape().asArray()[0];
        if (batchSize != examples.size() || batchSize != numValidFeatures.length) {
            throw new IllegalArgumentException("Invalid number of predictions received from Tensorflow, expected " + numValidFeatures.length + ", received " + batchSize);
        }
        int numExamples = examples.size();
        List<Prediction<MultiLabel>> output = new ArrayList<>(numExamples);
        for (int i = 0; i < numExamples; ++i) {
            FloatNdArray slice = predictions.slice(Indices.at((long) i), Indices.all());
            output.add(this.generatePrediction(slice, outputIDInfo, numValidFeatures[i], examples.get(i)));
        }
        return output;
    }

    /**
     * Converts a batch of scores into one {@link MultiLabel} per row.
     *
     * @param tensor       The scores, shape {@code [batchSize, outputIDInfo.size()]}.
     * @param outputIDInfo The output domain.
     * @return The outputs, in row order.
     * @throws IllegalArgumentException If the tensor shape or type is invalid.
     */
    @Override
    public List<MultiLabel> convertToBatchOutput(Tensor tensor, ImmutableOutputInfo<MultiLabel> outputIDInfo) {
        FloatNdArray predictions = this.getBatchPredictions(tensor, outputIDInfo);
        List<MultiLabel> output = new ArrayList<>();
        long batchSize = predictions.shape().asArray()[0];
        for (long row = 0; row < batchSize; ++row) {
            FloatNdArray slice = predictions.slice(Indices.at(row), Indices.all());
            output.add(this.generateMultiLabel(slice, outputIDInfo));
        }
        return output;
    }

    /**
     * Encodes a single output as a {@code [1, outputIDInfo.size()]} float tensor.
     *
     * @param example      The output to encode.
     * @param outputIDInfo The output domain.
     * @return The dense target tensor.
     */
    @Override
    public Tensor convertToTensor(MultiLabel example, ImmutableOutputInfo<MultiLabel> outputIDInfo) {
        SparseVector vec = example.convertToSparseVector(outputIDInfo);
        int numOutputs = outputIDInfo.size();
        TFloat32 returnVal = TFloat32.tensorOf(Shape.of(1, numOutputs));
        // Write zeros to every cell first, then overwrite the active entries.
        for (int j = 0; j < numOutputs; ++j) {
            returnVal.setFloat(0.0f, 0, j);
        }
        for (VectorTuple v : vec) {
            returnVal.setFloat((float) v.value, 0, v.index);
        }
        return returnVal;
    }

    /**
     * Encodes the outputs of a list of examples as a
     * {@code [examples.size(), outputIDInfo.size()]} float tensor, one row per example.
     *
     * @param examples     The examples whose outputs are encoded.
     * @param outputIDInfo The output domain.
     * @return The dense target tensor.
     */
    @Override
    public Tensor convertToTensor(List<Example<MultiLabel>> examples, ImmutableOutputInfo<MultiLabel> outputIDInfo) {
        int numOutputs = outputIDInfo.size();
        TFloat32 returnVal = TFloat32.tensorOf(Shape.of(examples.size(), numOutputs));
        int row = 0;
        for (Example<MultiLabel> e : examples) {
            SparseVector vec = e.getOutput().convertToSparseVector(outputIDInfo);
            // Write zeros to every cell of the row first, then overwrite the active entries.
            for (int j = 0; j < numOutputs; ++j) {
                returnVal.setFloat(0.0f, row, j);
            }
            for (VectorTuple v : vec) {
                returnVal.setFloat((float) v.value, row, v.index);
            }
            ++row;
        }
        return returnVal;
    }

    /**
     * Reports whether this converter produces probabilities.
     *
     * @return Always {@code true}.
     */
    @Override
    public boolean generatesProbabilities() {
        return true;
    }

    @Override
    public String toString() {
        return "MultiLabelConverter()";
    }

    /**
     * Builds the provenance of this converter, tagged as an {@code OutputConverter}.
     *
     * @return The provenance.
     */
    public ConfiguredObjectProvenance getProvenance() {
        return new ConfiguredObjectProvenanceImpl(this, "OutputConverter");
    }

    /**
     * Returns the output class handled by this converter.
     *
     * @return {@code MultiLabel.class}.
     */
    @Override
    public Class<MultiLabel> getTypeWitness() {
        return MultiLabel.class;
    }
}

