/*
 * Decompiled with CFR 0.152.
 */
package org.tribuo.regression.slm;

import ai.onnx.proto.OnnxMl;
import com.google.protobuf.Any;
import com.google.protobuf.InvalidProtocolBufferException;
import com.google.protobuf.Message;
import com.oracle.labs.mlrg.olcut.provenance.PrimitiveProvenance;
import com.oracle.labs.mlrg.olcut.provenance.Provenancable;
import com.oracle.labs.mlrg.olcut.util.Pair;
import java.io.IOException;
import java.io.ObjectInputStream;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.Collections;
import java.util.Comparator;
import java.util.HashMap;
import java.util.List;
import java.util.Map;
import java.util.Optional;
import java.util.PriorityQueue;
import java.util.logging.Logger;
import java.util.stream.Collectors;
import org.tribuo.Example;
import org.tribuo.Excuse;
import org.tribuo.ImmutableFeatureMap;
import org.tribuo.ImmutableOutputInfo;
import org.tribuo.Model;
import org.tribuo.ONNXExportable;
import org.tribuo.Prediction;
import org.tribuo.VariableIDInfo;
import org.tribuo.impl.ModelDataCarrier;
import org.tribuo.math.la.DenseVector;
import org.tribuo.math.la.SGDVector;
import org.tribuo.math.la.SparseVector;
import org.tribuo.math.la.Tensor;
import org.tribuo.math.la.VectorTuple;
import org.tribuo.math.protos.TensorProto;
import org.tribuo.protos.core.ModelDataProto;
import org.tribuo.protos.core.ModelProto;
import org.tribuo.provenance.ModelProvenance;
import org.tribuo.regression.ImmutableRegressionInfo;
import org.tribuo.regression.Regressor;
import org.tribuo.regression.impl.SkeletalIndependentRegressionSparseModel;
import org.tribuo.regression.slm.protos.SparseLinearModelProto;
import org.tribuo.util.Util;
import org.tribuo.util.onnx.ONNXContext;
import org.tribuo.util.onnx.ONNXInitializer;
import org.tribuo.util.onnx.ONNXNode;
import org.tribuo.util.onnx.ONNXOperator;
import org.tribuo.util.onnx.ONNXOperators;
import org.tribuo.util.onnx.ONNXPlaceholder;
import org.tribuo.util.onnx.ONNXRef;

/**
 * A sparse linear regression model storing one {@link SparseVector} of weights per output dimension.
 * <p>
 * Prediction steps:
 * <ul>
 *   <li>Each example is converted into a sparse feature vector. When {@code bias} is true, an extra
 *       feature is added at index {@code featureIDMap.size()}.</li>
 *   <li>The vector is centred with {@code featureMeans} and divided elementwise by {@code featureVariance}.</li>
 *   <li>Each output dimension is scored as {@code (weights · x) * yVariance + yMean}.</li>
 * </ul>
 * <p>
 * The model can be written to protobuf ({@link #serialize()}) and exported to ONNX
 * ({@link #exportONNXModel(String, long)}). Its Java deserialization hook reorders the per-dimension
 * state of ElasticNetCDTrainer models written by the Tribuo versions listed in {@code readObject}.
 */
public class SparseLinearModel extends SkeletalIndependentRegressionSparseModel implements ONNXExportable {
    private static final long serialVersionUID = 3L;
    private static final Logger logger = Logger.getLogger(SparseLinearModel.class.getName());

    /** Protobuf serialization version written by {@link #serialize()} and accepted by {@link #deserializeFromProto}. */
    public static final int CURRENT_VERSION = 0;

    /** One weight vector per output dimension; index {@code featureIDMap.size()} holds the bias weight when {@code bias} is true. */
    private SparseVector[] weights;
    /** Per-feature values subtracted from the input features. */
    private final DenseVector featureMeans;
    /** Per-feature divisors applied after centring (called "featureNorms" in the constructor and protobuf). */
    private final DenseVector featureVariance;
    /** True if the feature vectors carry an extra bias feature at index {@code featureIDMap.size()}. */
    private final boolean bias;
    /** Per-dimension offset added to the rescaled prediction. */
    private double[] yMean;
    /** Per-dimension multiplier applied to the raw dot product (called "yNorms" in the constructor and protobuf). */
    private double[] yVariance;
    /** Set once the ElasticNet output-ordering fix in {@code readObject} has been applied (or was never needed). */
    private boolean enet41MappingFix;

    /**
     * Constructs a sparse linear model.
     *
     * @param name           The model name.
     * @param dimensionNames The output dimension names, aligned with {@code weights}.
     * @param description    The model provenance.
     * @param featureIDMap   The feature domain.
     * @param labelIDMap     The output domain.
     * @param weights        One weight vector per output dimension.
     * @param featureMeans   Per-feature centring values.
     * @param featureNorms   Per-feature divisors.
     * @param yMean          Per-dimension output offsets.
     * @param yNorms         Per-dimension output multipliers.
     * @param bias           Whether the features include a bias term.
     */
    SparseLinearModel(String name, String[] dimensionNames, ModelProvenance description, ImmutableFeatureMap featureIDMap, ImmutableOutputInfo<Regressor> labelIDMap, SparseVector[] weights, DenseVector featureMeans, DenseVector featureNorms, double[] yMean, double[] yNorms, boolean bias) {
        super(name, dimensionNames, description, featureIDMap, labelIDMap, generateActiveFeatures(dimensionNames, featureIDMap, weights));
        this.weights = weights;
        this.featureMeans = featureMeans;
        this.featureVariance = featureNorms;
        this.bias = bias;
        this.yVariance = yNorms;
        this.yMean = yMean;
        // Freshly constructed models already use the correct output ordering.
        this.enet41MappingFix = true;
    }

    /**
     * Deserialization factory.
     *
     * @param version   The serialized object version.
     * @param className The class name.
     * @param message   The serialized data.
     * @return The deserialized model.
     * @throws InvalidProtocolBufferException If the message is not a {@link SparseLinearModelProto}.
     * @throws IllegalArgumentException       If the version is not supported.
     * @throws IllegalStateException          If the protobuf contents are inconsistent.
     */
    @SuppressWarnings("unchecked") // The output domain is checked to hold Regressor before the cast below.
    public static SparseLinearModel deserializeFromProto(int version, String className, Any message) throws InvalidProtocolBufferException {
        if (version < 0 || version > CURRENT_VERSION) {
            throw new IllegalArgumentException("Unknown version " + version + ", this class supports at most version " + CURRENT_VERSION);
        }
        SparseLinearModelProto proto = message.unpack(SparseLinearModelProto.class);
        ModelDataCarrier<?> carrier = ModelDataCarrier.deserialize(proto.getMetadata());
        // Guard against an empty domain, where getOutput(0) yields null.
        Object firstOutput = carrier.outputDomain().getOutput(0);
        if (firstOutput == null || !firstOutput.getClass().equals(Regressor.class)) {
            throw new IllegalStateException("Invalid protobuf, output domain is not a regression domain, found " + carrier.outputDomain().getClass());
        }
        ImmutableOutputInfo<Regressor> outputDomain = (ImmutableOutputInfo<Regressor>) carrier.outputDomain();

        String[] dimensions = new String[proto.getDimensionsCount()];
        if (dimensions.length != outputDomain.size()) {
            throw new IllegalStateException("Invalid protobuf, found insufficient dimension names, expected " + outputDomain.size() + ", found " + dimensions.length);
        }
        for (int i = 0; i < dimensions.length; i++) {
            dimensions[i] = proto.getDimensions(i);
        }

        SparseVector[] weights = new SparseVector[outputDomain.size()];
        if (weights.length != proto.getWeightsCount()) {
            throw new IllegalStateException("Invalid protobuf, expected same weight dimension as output domain size, found " + proto.getWeightsCount() + " weights and " + outputDomain.size() + " output dimensions");
        }
        // With a bias term every vector carries one extra slot for the intercept.
        int featureSize = proto.getBias() ? carrier.featureDomain().size() + 1 : carrier.featureDomain().size();
        for (int i = 0; i < weights.length; i++) {
            Tensor deser = Tensor.deserialize(proto.getWeights(i));
            if (!(deser instanceof SparseVector)) {
                throw new IllegalStateException("Invalid protobuf, expected a SparseVector, found " + deser.getClass());
            }
            SparseVector v = (SparseVector) deser;
            if (v.size() != featureSize) {
                throw new IllegalStateException("Invalid protobuf, weights size and feature domain do not match, expected " + featureSize + ", found " + v.size());
            }
            weights[i] = v;
        }

        DenseVector featureMeans = deserializeDenseVector(proto.getFeatureMeans(), "feature means", featureSize);
        DenseVector featureNorms = deserializeDenseVector(proto.getFeatureNorms(), "feature norms", featureSize);

        double[] yMean = Util.toPrimitiveDouble(proto.getYMeanList());
        if (yMean.length != outputDomain.size()) {
            throw new IllegalStateException("Invalid protobuf, y means not the right size, expected " + outputDomain.size() + " found " + yMean.length);
        }
        double[] yNorm = Util.toPrimitiveDouble(proto.getYNormList());
        if (yNorm.length != outputDomain.size()) {
            throw new IllegalStateException("Invalid protobuf, y norms not the right size, expected " + outputDomain.size() + " found " + yNorm.length);
        }
        return new SparseLinearModel(carrier.name(), dimensions, carrier.provenance(), carrier.featureDomain(), outputDomain, weights, featureMeans, featureNorms, yMean, yNorm, proto.getBias());
    }

    /**
     * Deserializes a tensor which must be a {@link DenseVector} of the expected size.
     *
     * @param tensorProto  The serialized tensor.
     * @param label        Human readable name used in error messages.
     * @param expectedSize The required vector size.
     * @return The deserialized dense vector.
     * @throws IllegalStateException If the tensor is not a dense vector or has the wrong size.
     */
    private static DenseVector deserializeDenseVector(TensorProto tensorProto, String label, int expectedSize) {
        Tensor tensor = Tensor.deserialize(tensorProto);
        if (!(tensor instanceof DenseVector)) {
            throw new IllegalStateException("Invalid protobuf, " + label + " must be a dense vector, found " + tensor.getClass());
        }
        DenseVector vector = (DenseVector) tensor;
        if (vector.size() != expectedSize) {
            throw new IllegalStateException("Invalid protobuf, " + label + " not the right size, expected " + expectedSize + ", found " + vector.size());
        }
        return vector;
    }

    /**
     * Lists, for each output dimension, the names of the features with stored weights.
     * The bias slot (index {@code featureMap.size()}) is reported as "BIAS".
     */
    private static Map<String, List<String>> generateActiveFeatures(String[] dimensionNames, ImmutableFeatureMap featureMap, SparseVector[] weightsArray) {
        Map<String, List<String>> map = new HashMap<>();
        int biasIndex = featureMap.size();
        for (int i = 0; i < dimensionNames.length; i++) {
            List<String> featureNames = new ArrayList<>();
            for (VectorTuple v : weightsArray[i]) {
                if (v.index == biasIndex) {
                    featureNames.add("BIAS");
                } else {
                    featureNames.add(featureMap.get(v.index).getName());
                }
            }
            map.put(dimensionNames[i], featureNames);
        }
        return map;
    }

    /**
     * Converts an example into a centred and scaled sparse feature vector.
     * The vector includes a bias feature when {@code bias} is true.
     */
    @Override
    protected SparseVector createFeatures(Example<Regressor> example) {
        SparseVector features = SparseVector.createSparseVector(example, featureIDMap, bias);
        features.intersectAndAddInPlace(featureMeans, a -> -a);
        features.hadamardProductInPlace(featureVariance, a -> 1.0 / a);
        return features;
    }

    /**
     * Scores one output dimension as {@code (weights · features) * yVariance + yMean}.
     * <p>
     * A dimension whose weight vector has no active elements contributes 0 from the dot product,
     * so it predicts {@code yMean}. This matches the ONNX graph, where all-zero weights and bias
     * give a GEMM output of 0. (Previously 1.0 was used, which predicted {@code yMean + yVariance}.)
     */
    @Override
    protected Regressor.DimensionTuple scoreDimension(int dimensionIdx, SparseVector features) {
        SparseVector dimWeights = weights[dimensionIdx];
        double rawScore = dimWeights.numActiveElements() > 0 ? dimWeights.dot(features) : 0.0;
        double prediction = rawScore * yVariance[dimensionIdx] + yMean[dimensionIdx];
        return new Regressor.DimensionTuple(dimensions[dimensionIdx], prediction);
    }

    /**
     * Returns, for each output dimension, up to {@code n} (feature name, weight) pairs in order of
     * decreasing absolute weight.
     * <p>
     * Special cases:
     * <ul>
     *   <li>A negative {@code n} returns every stored weight.</li>
     *   <li>{@code n == 0} returns an empty list per dimension.</li>
     *   <li>Weights whose index is not in the feature map (the bias slot) are named "BIAS".</li>
     * </ul>
     */
    @Override
    public Map<String, List<Pair<String, Double>>> getTopFeatures(int n) {
        int maxFeatures = n < 0 ? featureIDMap.size() + 1 : n;
        Map<String, List<Pair<String, Double>>> map = new HashMap<>();
        if (maxFeatures == 0) {
            // PriorityQueue rejects a capacity of 0, and there is nothing to report anyway.
            for (String dimension : dimensions) {
                map.put(dimension, new ArrayList<>());
            }
            return map;
        }
        Comparator<Pair<String, Double>> comparator = Comparator.comparingDouble(p -> Math.abs(p.getB()));
        // No vector holds more than featureIDMap.size() + 1 entries, so cap the initial heap capacity.
        PriorityQueue<Pair<String, Double>> q = new PriorityQueue<>(Math.min(maxFeatures, featureIDMap.size() + 1), comparator);
        for (int i = 0; i < dimensions.length; i++) {
            q.clear();
            for (VectorTuple v : weights[i]) {
                VariableIDInfo info = featureIDMap.get(v.index);
                String name = info == null ? "BIAS" : info.getName();
                Pair<String, Double> curr = new Pair<>(name, v.value);
                if (q.size() < maxFeatures) {
                    q.offer(curr);
                } else if (comparator.compare(curr, q.peek()) > 0) {
                    // Replace the current smallest-magnitude entry.
                    q.poll();
                    q.offer(curr);
                }
            }
            List<Pair<String, Double>> top = new ArrayList<>(q.size());
            while (!q.isEmpty()) {
                top.add(q.poll());
            }
            // The min-heap yields ascending magnitude; report largest first.
            Collections.reverse(top);
            map.put(dimensions[i], top);
        }
        return map;
    }

    /**
     * Explains a prediction by listing, for each output dimension, every active feature of the
     * processed example with its contribution {@code weight * value}, sorted in descending order.
     * The bias feature (present when {@code bias} is true) is reported as "BIAS".
     */
    @Override
    public Optional<Excuse<Regressor>> getExcuse(Example<Regressor> example) {
        Prediction<Regressor> prediction = predict(example);
        Map<String, List<Pair<String, Double>>> weightMap = new HashMap<>();
        SparseVector features = createFeatures(example);
        for (int i = 0; i < dimensions.length; i++) {
            List<Pair<String, Double>> classScores = new ArrayList<>();
            for (VectorTuple f : features) {
                double score = weights[i].get(f.index) * f.value;
                // The bias feature index is not in the feature map, so the lookup returns null.
                VariableIDInfo info = featureIDMap.get(f.index);
                String name = info == null ? "BIAS" : info.getName();
                classScores.add(new Pair<>(name, score));
            }
            classScores.sort((o1, o2) -> o2.getB().compareTo(o1.getB()));
            weightMap.put(dimensions[i], classScores);
        }
        return Optional.of(new Excuse<>(example, prediction, weightMap));
    }

    /** Returns a copy of this model with a new name and provenance; all mutable state is copied. */
    @Override
    protected Model<Regressor> copy(String newName, ModelProvenance newProvenance) {
        return new SparseLinearModel(newName, Arrays.copyOf(dimensions, dimensions.length), newProvenance, featureIDMap, outputIDInfo, copyWeights(), featureMeans.copy(), featureVariance.copy(), Arrays.copyOf(yMean, yMean.length), Arrays.copyOf(yVariance, yVariance.length), bias);
    }

    /** Copies each per-dimension weight vector. */
    private SparseVector[] copyWeights() {
        SparseVector[] newWeights = new SparseVector[weights.length];
        for (int i = 0; i < weights.length; i++) {
            newWeights[i] = weights[i].copy();
        }
        return newWeights;
    }

    /**
     * Returns copies of the weight vectors keyed by output dimension name.
     *
     * @return A map from dimension name to a copy of that dimension's weights.
     */
    public Map<String, SparseVector> getWeights() {
        SparseVector[] newWeights = copyWeights();
        Map<String, SparseVector> output = new HashMap<>();
        for (int i = 0; i < dimensions.length; i++) {
            output.put(dimensions[i], newWeights[i]);
        }
        return output;
    }

    /** Serializes this model into a {@link ModelProto} wrapping a {@link SparseLinearModelProto}. */
    @Override
    public ModelProto serialize() {
        ModelDataCarrier<Regressor> carrier = createDataCarrier();
        SparseLinearModelProto.Builder modelBuilder = SparseLinearModelProto.newBuilder();
        modelBuilder.setMetadata(carrier.serialize());
        modelBuilder.addAllDimensions(Arrays.asList(dimensions));
        for (SparseVector v : weights) {
            modelBuilder.addWeights(v.serialize());
        }
        modelBuilder.setFeatureMeans(featureMeans.serialize());
        modelBuilder.setFeatureNorms(featureVariance.serialize());
        modelBuilder.setBias(bias);
        modelBuilder.addAllYMean(Arrays.stream(yMean).boxed().collect(Collectors.toList()));
        modelBuilder.addAllYNorm(Arrays.stream(yVariance).boxed().collect(Collectors.toList()));

        ModelProto.Builder builder = ModelProto.newBuilder();
        builder.setSerializedData(Any.pack(modelBuilder.build()));
        builder.setClassName(SparseLinearModel.class.getName());
        builder.setVersion(CURRENT_VERSION);
        return builder.build();
    }

    /**
     * Exports this model as a standalone ONNX model.
     * The input is a float tensor of width {@code featureIDMap.size()}; the output is a float tensor
     * of width {@code outputIDInfo.size()}.
     */
    @Override
    public OnnxMl.ModelProto exportONNXModel(String domain, long modelVersion) {
        ONNXContext onnx = new ONNXContext();
        ONNXPlaceholder input = onnx.floatInput(featureIDMap.size());
        ONNXPlaceholder output = onnx.floatOutput(outputIDInfo.size());
        onnx.setName("Regression-SparseLinearModel");
        return ONNXExportable.buildModel(writeONNXGraph(input).assignTo(output).onnxContext(), domain, modelVersion, this);
    }

    /**
     * Writes this model's computation into the ONNX graph of {@code input}:
     * {@code ((input - featureMean) / featureVariance) * W + b}, then multiplied by
     * {@code y_variance} and shifted by {@code y_mean}.
     */
    @Override
    public ONNXNode writeONNXGraph(ONNXRef<?> input) {
        ONNXContext onnx = input.onnxContext();
        int numFeatures = featureIDMap.size();
        int numOutputs = outputIDInfo.size();

        // Weight matrix of shape [numFeatures, numOutputs], written one feature row at a time.
        ONNXInitializer onnxWeights = onnx.floatTensor("slm_weights", Arrays.asList(numFeatures, numOutputs), fb -> {
            for (int j = 0; j < numFeatures; j++) {
                for (SparseVector w : weights) {
                    fb.put((float) w.get(j));
                }
            }
        });
        // With a bias term the intercept lives at index numFeatures of each weight vector. Without
        // one, the vectors have only numFeatures entries and the intercept is zero.
        ONNXInitializer onnxBiases = onnx.floatTensor("slm_biases", Collections.singletonList(numOutputs), fb -> {
            for (SparseVector w : weights) {
                fb.put(bias ? (float) w.get(numFeatures) : 0.0f);
            }
        });
        // Drop the bias slot from the feature statistics; the bias goes through the GEMM bias input instead.
        double[] xMean = bias ? Arrays.copyOf(featureMeans.toArray(), numFeatures) : featureMeans.toArray();
        ONNXInitializer featureMean = onnx.array("feature_mean", xMean);
        ONNXInitializer outputMean = onnx.array("y_mean", yMean);
        double[] xVariance = bias ? Arrays.copyOf(featureVariance.toArray(), numFeatures) : featureVariance.toArray();
        ONNXInitializer featureVarianceInit = onnx.array("feature_variance", xVariance);
        ONNXInitializer outputVariance = onnx.array("y_variance", yVariance);

        ONNXNode scaledFeatures = input.apply(ONNXOperators.SUB, featureMean).apply(ONNXOperators.DIV, featureVarianceInit);
        ONNXNode gemm = scaledFeatures.apply(ONNXOperators.GEMM, Arrays.asList(onnxWeights, onnxBiases));
        return gemm.apply(ONNXOperators.MUL, outputVariance).apply(ONNXOperators.ADD, outputMean);
    }

    /**
     * Java deserialization hook.
     * <p>
     * For ElasticNetCDTrainer models from the Tribuo versions checked in
     * {@link #isAffectedTribuoVersion()}, and not yet marked as fixed, it reorders the weights,
     * y means and y variances using the output domain's id-to-natural-order mapping.
     */
    private void readObject(ObjectInputStream in) throws IOException, ClassNotFoundException {
        in.defaultReadObject();
        // Check the cheap, always-present conditions first, so the version provenance is read only when needed.
        if (!enet41MappingFix
                && provenance.getTrainerProvenance().getClassName().equals("org.tribuo.regression.slm.ElasticNetCDTrainer")
                && isAffectedTribuoVersion()) {
            enet41MappingFix = true;
            int[] mapping = ((ImmutableRegressionInfo) outputIDInfo).getIDtoNaturalOrderMapping();
            SparseVector[] newWeights = new SparseVector[weights.length];
            double[] newYMeans = new double[weights.length];
            double[] newYVariances = new double[weights.length];
            for (int i = 0; i < mapping.length; i++) {
                newWeights[i] = weights[mapping[i]];
                newYMeans[i] = yMean[mapping[i]];
                newYVariances[i] = yVariance[mapping[i]];
            }
            yMean = newYMeans;
            yVariance = newYVariances;
            weights = newWeights;
        }
    }

    /**
     * Checks whether the trainer provenance records one of the Tribuo versions whose ElasticNet
     * models need their output ordering fixed. A missing version entry is treated as unaffected.
     */
    private boolean isAffectedTribuoVersion() {
        PrimitiveProvenance<?> versionProv = (PrimitiveProvenance<?>) provenance.getTrainerProvenance().getInstanceValues().get("tribuo-version");
        if (versionProv == null) {
            return false;
        }
        String tribuoVersion = (String) versionProv.getValue();
        return tribuoVersion.startsWith("4.0.0")
                || tribuoVersion.startsWith("4.0.1")
                || tribuoVersion.startsWith("4.0.2")
                || tribuoVersion.startsWith("4.1.0")
                || tribuoVersion.equals("4.1.1-SNAPSHOT");
    }
}

