/*
 * Decompiled with CFR 0.152.
 */
package org.tribuo.multilabel.sgd.linear;

import com.google.protobuf.Any;
import com.google.protobuf.InvalidProtocolBufferException;
import com.google.protobuf.Message;
import java.util.HashMap;
import java.util.HashSet;
import org.tribuo.Example;
import org.tribuo.ImmutableFeatureMap;
import org.tribuo.ImmutableOutputInfo;
import org.tribuo.ONNXExportable;
import org.tribuo.Output;
import org.tribuo.Prediction;
import org.tribuo.classification.Label;
import org.tribuo.common.sgd.AbstractLinearSGDModel;
import org.tribuo.common.sgd.AbstractSGDModel;
import org.tribuo.impl.ModelDataCarrier;
import org.tribuo.math.LinearParameters;
import org.tribuo.math.Parameters;
import org.tribuo.math.la.DenseVector;
import org.tribuo.math.protos.NormalizerProto;
import org.tribuo.math.protos.ParametersProto;
import org.tribuo.math.util.VectorNormalizer;
import org.tribuo.multilabel.MultiLabel;
import org.tribuo.multilabel.sgd.protos.MultiLabelLinearSGDProto;
import org.tribuo.protos.core.ModelDataProto;
import org.tribuo.protos.core.ModelProto;
import org.tribuo.provenance.ModelProvenance;
import org.tribuo.util.onnx.ONNXNode;

/**
 * A linear SGD model for multi-label outputs.
 * <p>
 * At prediction time each output dimension receives an independent score. The score
 * vector is passed through the model's {@link VectorNormalizer}. Every label whose
 * normalized score is strictly greater than the model's threshold is included in the
 * predicted {@link MultiLabel}. The per-label scores are all reported in the
 * prediction's score map regardless of the threshold.
 */
public class LinearSGDModel
extends AbstractLinearSGDModel<MultiLabel>
implements ONNXExportable {
    private static final long serialVersionUID = 2L;

    /**
     * Protobuf serialization version written by {@link #serialize()} and accepted by
     * {@link #deserializeFromProto(int, String, Any)}.
     */
    public static final int CURRENT_VERSION = 0;

    /** Applied to the raw output scores before thresholding. */
    private final VectorNormalizer normalizer;

    /** Labels whose normalized score is strictly greater than this value are predicted. */
    private final double threshold;

    /**
     * Constructs a multi-label linear SGD model.
     *
     * @param name the model name.
     * @param provenance the model provenance.
     * @param featureIDMap the feature domain.
     * @param outputIDInfo the multi-label output domain.
     * @param parameters the linear model parameters.
     * @param normalizer the normalizer applied to the output scores before thresholding.
     * @param generatesProbabilities whether the model produces probabilistic scores.
     * @param threshold the score threshold; labels scoring strictly above it are predicted.
     */
    LinearSGDModel(String name, ModelProvenance provenance, ImmutableFeatureMap featureIDMap, ImmutableOutputInfo<MultiLabel> outputIDInfo, LinearParameters parameters, VectorNormalizer normalizer, boolean generatesProbabilities, double threshold) {
        super(name, provenance, featureIDMap, outputIDInfo, parameters, generatesProbabilities);
        this.normalizer = normalizer;
        this.threshold = threshold;
    }

    /**
     * Deserialization factory.
     *
     * @param version the serialized version number.
     * @param className the class name.
     * @param message the serialized data, expected to contain a {@link MultiLabelLinearSGDProto}.
     * @return the deserialized model.
     * @throws InvalidProtocolBufferException if the message does not contain a {@link MultiLabelLinearSGDProto}.
     * @throws IllegalArgumentException if the version is not supported.
     * @throws IllegalStateException if the output domain is not a multi-label domain, or the
     *                               parameters are not {@link LinearParameters}.
     */
    public static LinearSGDModel deserializeFromProto(int version, String className, Any message) throws InvalidProtocolBufferException {
        if (version < 0 || version > CURRENT_VERSION) {
            throw new IllegalArgumentException("Unknown version " + version + ", this class supports at most version " + CURRENT_VERSION);
        }
        MultiLabelLinearSGDProto proto = message.unpack(MultiLabelLinearSGDProto.class);
        ModelDataCarrier<?> carrier = ModelDataCarrier.deserialize(proto.getMetadata());

        // NOTE(review): getOutput(0) is assumed to return null for an empty/unknown id.
        // Guard it so a malformed proto raises IllegalStateException rather than an NPE.
        Output<?> firstOutput = carrier.outputDomain().getOutput(0);
        if (firstOutput == null || !firstOutput.getClass().equals(MultiLabel.class)) {
            throw new IllegalStateException("Invalid protobuf, output domain is not a multi-label domain, found " + carrier.outputDomain().getClass());
        }
        @SuppressWarnings("unchecked") // guarded by the MultiLabel class check above
        ImmutableOutputInfo<MultiLabel> outputDomain = (ImmutableOutputInfo<MultiLabel>) carrier.outputDomain();

        Parameters params = Parameters.deserialize(proto.getParams());
        if (!(params instanceof LinearParameters)) {
            throw new IllegalStateException("Invalid protobuf, parameters must be LinearParameters, found " + params.getClass());
        }

        VectorNormalizer normalizer = VectorNormalizer.deserialize(proto.getNormalizer());

        return new LinearSGDModel(carrier.name(), carrier.provenance(), carrier.featureDomain(), outputDomain, (LinearParameters) params, normalizer, carrier.generatesProbabilities(), proto.getThreshold());
    }

    /**
     * Predicts the set of labels for the supplied example.
     * <p>
     * The raw scores are normalized in place, then each label scoring strictly above the
     * threshold is added to the predicted {@link MultiLabel}. The score map contains every
     * label in the output domain, each wrapped in a single-label {@link MultiLabel}.
     *
     * @param example the example to predict.
     * @return the prediction.
     */
    @Override
    public Prediction<MultiLabel> predict(Example<MultiLabel> example) {
        AbstractSGDModel.PredAndActive predTuple = this.predictSingle(example);
        DenseVector outputs = predTuple.prediction;
        outputs.normalize(this.normalizer);
        HashMap<String, MultiLabel> fullLabels = new HashMap<>();
        HashSet<Label> predictedLabels = new HashSet<>();
        for (int i = 0; i < outputs.size(); ++i) {
            String labelName = this.outputIDInfo.getOutput(i).getLabelString();
            double labelScore = outputs.get(i);
            Label score = new Label(labelName, labelScore);
            if (labelScore > this.threshold) {
                predictedLabels.add(score);
            }
            fullLabels.put(labelName, new MultiLabel(score));
        }
        // The "- 1" presumably excludes the bias feature counted by predictSingle — TODO confirm.
        return new Prediction<>(new MultiLabel(predictedLabels), fullLabels, predTuple.numActiveFeatures - 1, example, this.generatesProbabilities);
    }

    /**
     * Serializes this model into a {@link ModelProto} wrapping a {@link MultiLabelLinearSGDProto}.
     *
     * @return the serialized model.
     */
    @Override
    public ModelProto serialize() {
        ModelDataCarrier<?> carrier = this.createDataCarrier();
        MultiLabelLinearSGDProto.Builder modelBuilder = MultiLabelLinearSGDProto.newBuilder();
        modelBuilder.setMetadata(carrier.serialize());
        modelBuilder.setParams(this.modelParameters.serialize());
        modelBuilder.setNormalizer(this.normalizer.serialize());
        modelBuilder.setThreshold(this.threshold);

        ModelProto.Builder builder = ModelProto.newBuilder();
        builder.setVersion(CURRENT_VERSION);
        builder.setClassName(LinearSGDModel.class.getName());
        builder.setSerializedData(Any.pack(modelBuilder.build()));
        return builder.build();
    }

    /**
     * Returns the label string for the given output dimension.
     *
     * @param index the output id.
     * @return the label string of that output.
     */
    @Override
    protected String getDimensionName(int index) {
        return this.outputIDInfo.getOutput(index).getLabelString();
    }

    /**
     * Copies this model under a new name and provenance.
     * <p>
     * The parameters are copied. The normalizer instance is shared with this model.
     *
     * @param newName the new model name.
     * @param newProvenance the new provenance.
     * @return the copy.
     */
    @Override
    protected LinearSGDModel copy(String newName, ModelProvenance newProvenance) {
        return new LinearSGDModel(newName, newProvenance, this.featureIDMap, this.outputIDInfo, (LinearParameters) this.modelParameters.copy(), this.normalizer, this.generatesProbabilities, this.threshold);
    }

    /**
     * Appends the normalizer's ONNX export to the supplied node.
     *
     * @param input the node producing the raw scores.
     * @return the node produced by the normalizer export.
     */
    @Override
    protected ONNXNode onnxOutput(ONNXNode input) {
        return this.normalizer.exportNormalizer(input);
    }

    /**
     * Returns the name used for the exported ONNX model.
     *
     * @return the ONNX model name.
     */
    @Override
    protected String onnxModelName() {
        return "MultiLabel-LinearSGDModel";
    }
}

