/*
 * Decompiled with CFR 0.152.
 */
package org.tribuo.regression.sgd.fm;

import com.google.protobuf.Any;
import com.google.protobuf.InvalidProtocolBufferException;
import com.google.protobuf.Message;
import java.util.Arrays;
import org.tribuo.Example;
import org.tribuo.ImmutableFeatureMap;
import org.tribuo.ImmutableOutputInfo;
import org.tribuo.ONNXExportable;
import org.tribuo.Output;
import org.tribuo.Prediction;
import org.tribuo.common.sgd.AbstractFMModel;
import org.tribuo.common.sgd.AbstractSGDModel;
import org.tribuo.common.sgd.FMParameters;
import org.tribuo.impl.ModelDataCarrier;
import org.tribuo.math.Parameters;
import org.tribuo.math.protos.ParametersProto;
import org.tribuo.protos.core.ModelDataProto;
import org.tribuo.protos.core.ModelProto;
import org.tribuo.provenance.ModelProvenance;
import org.tribuo.regression.ImmutableRegressionInfo;
import org.tribuo.regression.Regressor;
import org.tribuo.regression.sgd.protos.FMRegressionModelProto;
import org.tribuo.util.onnx.ONNXInitializer;
import org.tribuo.util.onnx.ONNXNode;
import org.tribuo.util.onnx.ONNXOperator;
import org.tribuo.util.onnx.ONNXOperators;
import org.tribuo.util.onnx.ONNXRef;

public class FMRegressionModel
extends AbstractFMModel<Regressor>
implements ONNXExportable {
    private static final long serialVersionUID = 3L;
    public static final int CURRENT_VERSION = 0;
    private final String[] dimensionNames;
    private final boolean standardise;

    FMRegressionModel(String name, String[] dimensionNames, ModelProvenance provenance, ImmutableFeatureMap featureIDMap, ImmutableOutputInfo<Regressor> outputIDInfo, FMParameters parameters, boolean standardise) {
        super(name, provenance, featureIDMap, outputIDInfo, parameters, false);
        this.dimensionNames = dimensionNames;
        this.standardise = standardise;
    }

    public static FMRegressionModel deserializeFromProto(int version, String className, Any message) throws InvalidProtocolBufferException {
        if (version < 0 || version > 0) {
            throw new IllegalArgumentException("Unknown version " + version + ", this class supports at most version " + 0);
        }
        FMRegressionModelProto proto = (FMRegressionModelProto)message.unpack(FMRegressionModelProto.class);
        ModelDataCarrier carrier = ModelDataCarrier.deserialize((ModelDataProto)proto.getMetadata());
        if (!carrier.outputDomain().getOutput(0).getClass().equals(Regressor.class)) {
            throw new IllegalStateException("Invalid protobuf, output domain is not a regression domain, found " + carrier.outputDomain().getClass());
        }
        ImmutableOutputInfo outputDomain = carrier.outputDomain();
        Parameters params = Parameters.deserialize((ParametersProto)proto.getParams());
        if (!(params instanceof FMParameters)) {
            throw new IllegalStateException("Invalid protobuf, parameters must be FMParameters, found " + params.getClass());
        }
        String[] dimensionNames = (String[])proto.getDimensionNamesList().toArray((Object[])new String[0]);
        if (dimensionNames.length != outputDomain.size()) {
            throw new IllegalStateException("Invalid protobuf, found a different number of dimension names to the output dimensions, found " + dimensionNames.length + " , expected " + outputDomain.size());
        }
        return new FMRegressionModel(carrier.name(), dimensionNames, carrier.provenance(), carrier.featureDomain(), (ImmutableOutputInfo<Regressor>)outputDomain, (FMParameters)params, proto.getStandardise());
    }

    public Prediction<Regressor> predict(Example<Regressor> example) {
        AbstractSGDModel.PredAndActive predTuple = this.predictSingle(example);
        double[] predictions = predTuple.prediction.toArray();
        if (this.standardise) {
            predictions = this.unstandardisePredictions(predictions);
        }
        return new Prediction((Output)new Regressor(this.dimensionNames, predictions), predTuple.numActiveFeatures, example);
    }

    public ModelProto serialize() {
        ModelDataCarrier carrier = this.createDataCarrier();
        FMRegressionModelProto.Builder modelBuilder = FMRegressionModelProto.newBuilder();
        modelBuilder.setMetadata(carrier.serialize());
        modelBuilder.setParams((ParametersProto)this.modelParameters.serialize());
        modelBuilder.addAllDimensionNames(Arrays.asList(this.dimensionNames));
        modelBuilder.setStandardise(this.standardise);
        ModelProto.Builder builder = ModelProto.newBuilder();
        builder.setVersion(0);
        builder.setClassName(FMRegressionModel.class.getName());
        builder.setSerializedData(Any.pack((Message)modelBuilder.build()));
        return builder.build();
    }

    private double[] unstandardisePredictions(double[] predictions) {
        ImmutableRegressionInfo info = (ImmutableRegressionInfo)this.outputIDInfo;
        for (int i = 0; i < predictions.length; ++i) {
            double mean = info.getMean(i);
            double variance = info.getVariance(i);
            predictions[i] = predictions[i] * variance + mean;
        }
        return predictions;
    }

    protected FMRegressionModel copy(String newName, ModelProvenance newProvenance) {
        return new FMRegressionModel(newName, Arrays.copyOf(this.dimensionNames, this.dimensionNames.length), newProvenance, this.featureIDMap, (ImmutableOutputInfo<Regressor>)this.outputIDInfo, (FMParameters)this.modelParameters.copy(), this.standardise);
    }

    protected String getDimensionName(int index) {
        return this.dimensionNames[index];
    }

    protected String onnxModelName() {
        return "FMRegressionModel";
    }

    protected ONNXNode onnxOutput(ONNXNode fmOutput) {
        if (this.standardise) {
            ImmutableRegressionInfo info = (ImmutableRegressionInfo)this.outputIDInfo;
            double[] means = new double[this.outputIDInfo.size()];
            double[] variances = new double[this.outputIDInfo.size()];
            for (int i = 0; i < means.length; ++i) {
                means[i] = info.getMean(i);
                variances[i] = info.getVariance(i);
            }
            ONNXInitializer outputMean = fmOutput.onnxContext().array("y_mean", means);
            ONNXInitializer outputVariance = fmOutput.onnxContext().array("y_var", variances);
            return fmOutput.apply((ONNXOperator)ONNXOperators.MUL, (ONNXRef)outputVariance).apply((ONNXOperator)ONNXOperators.ADD, (ONNXRef)outputMean);
        }
        return fmOutput;
    }
}

