/*
 * Decompiled with CFR 0.152.
 */
package org.tribuo.regression.liblinear;

import ai.onnx.proto.OnnxMl;
import com.google.protobuf.Any;
import com.google.protobuf.ByteString;
import com.google.protobuf.InvalidProtocolBufferException;
import com.oracle.labs.mlrg.olcut.provenance.Provenancable;
import com.oracle.labs.mlrg.olcut.util.Pair;
import de.bwaldvogel.liblinear.FeatureNode;
import de.bwaldvogel.liblinear.Linear;
import de.bwaldvogel.liblinear.Model;
import java.io.ByteArrayInputStream;
import java.io.IOException;
import java.io.ObjectInputStream;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.Collections;
import java.util.Comparator;
import java.util.HashMap;
import java.util.List;
import java.util.Map;
import java.util.PriorityQueue;
import java.util.logging.Logger;
import org.tribuo.Example;
import org.tribuo.Excuse;
import org.tribuo.Feature;
import org.tribuo.ImmutableFeatureMap;
import org.tribuo.ImmutableOutputInfo;
import org.tribuo.ONNXExportable;
import org.tribuo.Output;
import org.tribuo.Prediction;
import org.tribuo.common.liblinear.LibLinearModel;
import org.tribuo.common.liblinear.LibLinearTrainer;
import org.tribuo.common.liblinear.protos.LibLinearModelProto;
import org.tribuo.impl.ModelDataCarrier;
import org.tribuo.protos.core.ModelDataProto;
import org.tribuo.provenance.ModelProvenance;
import org.tribuo.regression.ImmutableRegressionInfo;
import org.tribuo.regression.Regressor;
import org.tribuo.util.onnx.ONNXContext;
import org.tribuo.util.onnx.ONNXInitializer;
import org.tribuo.util.onnx.ONNXNode;
import org.tribuo.util.onnx.ONNXOperator;
import org.tribuo.util.onnx.ONNXOperators;
import org.tribuo.util.onnx.ONNXPlaceholder;
import org.tribuo.util.onnx.ONNXRef;

/**
 * A {@link org.tribuo.Model} which wraps a set of LibLinear-java models, one per regression dimension.
 * <p>
 * The wrapped liblinear models are held in output id order. The {@code mapping} array converts an
 * output id into the natural-order index used to address {@code dimensionNames} and the values of
 * the produced {@link Regressor}.
 */
public class LibLinearRegressionModel
extends LibLinearModel<Regressor>
implements ONNXExportable {
    private static final long serialVersionUID = 2L;
    private static final Logger logger = Logger.getLogger(LibLinearRegressionModel.class.getName());

    /** The highest protobuf serialization version this class can read. */
    private static final int CURRENT_VERSION = 0;

    /** Regression dimension names, indexed by natural-order position. */
    private final String[] dimensionNames;

    /**
     * Maps an output id to its natural-order index. Not final because {@code readObject} must
     * populate it for serialized forms that were written without it.
     */
    private int[] mapping;

    /**
     * Constructs a regression model from one liblinear model per output dimension.
     *
     * @param name The model name.
     * @param description The model provenance.
     * @param featureIDMap The feature domain.
     * @param outputInfo The output domain; must be an {@link ImmutableRegressionInfo}.
     * @param models The liblinear models, one per output, in output id order.
     */
    LibLinearRegressionModel(String name, ModelProvenance description, ImmutableFeatureMap featureIDMap,
                             ImmutableOutputInfo<Regressor> outputInfo, List<Model> models) {
        super(name, description, featureIDMap, outputInfo, false, models);
        this.dimensionNames = Regressor.extractNames(outputInfo);
        this.mapping = ((ImmutableRegressionInfo) outputInfo).getIDtoNaturalOrderMapping();
    }

    /**
     * Deserialization factory.
     *
     * @param version The serialized object version.
     * @param className The class name; must name this class.
     * @param message The serialized data, a packed {@link LibLinearModelProto}.
     * @return The deserialized model.
     * @throws InvalidProtocolBufferException If the message does not contain a {@link LibLinearModelProto}.
     * @throws IllegalArgumentException If the version is not supported.
     * @throws IllegalStateException If the class name, output domain, model count or wrapped models are invalid.
     */
    @SuppressWarnings("unchecked") // the output type is checked against Regressor before the cast below
    public static LibLinearRegressionModel deserializeFromProto(int version, String className, Any message) throws InvalidProtocolBufferException {
        if (version < 0 || version > CURRENT_VERSION) {
            throw new IllegalArgumentException("Unknown version " + version + ", this class supports at most version " + CURRENT_VERSION);
        }
        if (!"org.tribuo.regression.liblinear.LibLinearRegressionModel".equals(className)) {
            throw new IllegalStateException("Invalid protobuf, this class can only deserialize LibLinearRegressionModel");
        }
        LibLinearModelProto proto = message.unpack(LibLinearModelProto.class);
        ModelDataCarrier<?> carrier = ModelDataCarrier.deserialize(proto.getMetadata());
        ImmutableOutputInfo<?> outputDomain = carrier.outputDomain();
        // Guard before probing id 0, otherwise an empty domain has no output to type-check.
        if (outputDomain.size() == 0) {
            throw new IllegalStateException("Invalid protobuf, output domain is empty");
        }
        if (!outputDomain.getOutput(0).getClass().equals(Regressor.class)) {
            throw new IllegalStateException("Invalid protobuf, output domain is not a regression domain, found " + outputDomain.getClass());
        }
        if (proto.getModelsCount() != outputDomain.size()) {
            throw new IllegalStateException("Invalid protobuf, expected " + outputDomain.size() + " model, found " + proto.getModelsCount());
        }
        try {
            List<Model> models = new ArrayList<>(proto.getModelsCount());
            for (ByteString modelArray : proto.getModelsList()) {
                // NOTE(review): this is Java native deserialization of bytes taken from the protobuf, so a
                // crafted protobuf could potentially instantiate unexpected classes. Only load trusted
                // protobufs, or consider installing an ObjectInputFilter on this stream.
                try (ObjectInputStream ois = new ObjectInputStream(new ByteArrayInputStream(modelArray.toByteArray()))) {
                    models.add((Model) ois.readObject());
                }
            }
            return new LibLinearRegressionModel(carrier.name(), carrier.provenance(), carrier.featureDomain(),
                    (ImmutableOutputInfo<Regressor>) outputDomain, Collections.unmodifiableList(models));
        } catch (IOException | ClassNotFoundException e) {
            throw new IllegalStateException("Invalid protobuf, failed to deserialize liblinear model", e);
        }
    }

    /**
     * Predicts every regression dimension for the supplied example.
     *
     * @param example The example to predict.
     * @return A prediction whose regressor holds one value per dimension, in natural order.
     * @throws IllegalArgumentException If no features were found in the example.
     */
    @Override
    public Prediction<Regressor> predict(Example<Regressor> example) {
        FeatureNode[] features = LibLinearTrainer.exampleToNodes(example, featureIDMap, null);
        // The node array carries one extra node beyond the example's features (presumably the bias),
        // so a length of one means no features were found.
        if (features.length == 1) {
            throw new IllegalArgumentException("No features found in Example " + example.toString());
        }
        // Scratch buffer required by predictValues, reused across the per-dimension models.
        double[] scores = new double[models.get(0).getNrClass()];
        double[] regressedValues = new double[models.size()];
        for (int i = 0; i < regressedValues.length; i++) {
            // Model i predicts output id i; mapping[i] is that dimension's natural-order slot.
            regressedValues[mapping[i]] = Linear.predictValues(models.get(i), features, scores);
        }
        Regressor regressor = new Regressor(dimensionNames, regressedValues);
        return new Prediction<>(regressor, features.length - 1, example);
    }

    /**
     * Returns the {@code n} features with the largest absolute weight for each regression dimension.
     *
     * @param n The number of features to return per dimension. A negative value returns all features,
     *          and zero returns empty lists.
     * @return A map from dimension name to (feature name, weight) pairs, ordered by descending absolute weight.
     */
    @Override
    public Map<String, List<Pair<String, Double>>> getTopFeatures(int n) {
        int maxFeatures = n < 0 ? featureIDMap.size() : n;
        double[][] featureWeights = getFeatureWeights();
        Comparator<Pair<String, Double>> comparator = Comparator.comparingDouble(p -> Math.abs(p.getB()));
        Map<String, List<Pair<String, Double>>> map = new HashMap<>();
        // PriorityQueue rejects an initial capacity below 1, but maxFeatures may legitimately be 0.
        PriorityQueue<Pair<String, Double>> q = new PriorityQueue<>(Math.max(1, maxFeatures), comparator);
        for (int i = 0; i < featureWeights.length; i++) {
            // The final weight is the bias term, not a feature, so it is excluded.
            int numFeatures = featureWeights[i].length - 1;
            for (int j = 0; j < numFeatures; j++) {
                Pair<String, Double> cur = new Pair<>(featureIDMap.get(j).getName(), featureWeights[i][j]);
                if (q.size() < maxFeatures) {
                    q.offer(cur);
                } else if (!q.isEmpty() && comparator.compare(cur, q.peek()) > 0) {
                    // Min-heap on |weight|: evict the weakest retained feature.
                    q.poll();
                    q.offer(cur);
                }
            }
            List<Pair<String, Double>> list = new ArrayList<>(q.size());
            while (!q.isEmpty()) {
                list.add(q.poll());
            }
            // The heap drains weakest-first; reverse to get strongest-first.
            Collections.reverse(list);
            map.put(dimensionNames[mapping[i]], list);
        }
        return map;
    }

    /**
     * Deep copies this model, copying each wrapped liblinear model.
     *
     * @param newName The name of the copy.
     * @param newProvenance The provenance of the copy.
     * @return A copy of this model.
     */
    @Override
    protected LibLinearRegressionModel copy(String newName, ModelProvenance newProvenance) {
        List<Model> newModels = new ArrayList<>(models.size());
        for (Model m : models) {
            newModels.add(copyModel(m));
        }
        return new LibLinearRegressionModel(newName, newProvenance, featureIDMap, outputIDInfo, newModels);
    }

    /**
     * Extracts the weight vector of every wrapped model.
     *
     * @return One weight array per model, in output id order; the final element of each is the bias weight.
     */
    @Override
    protected double[][] getFeatureWeights() {
        double[][] featureWeights = new double[models.size()][];
        for (int i = 0; i < models.size(); i++) {
            featureWeights[i] = models.get(i).getFeatureWeights();
        }
        return featureWeights;
    }

    /**
     * Builds an excuse listing each known feature's contribution (weight * value) per dimension.
     *
     * @param e The example to explain.
     * @param allFeatureWeights The per-model weight arrays, in output id order.
     * @return An excuse mapping each dimension name to features sorted by descending contribution.
     */
    @Override
    protected Excuse<Regressor> innerGetExcuse(Example<Regressor> e, double[][] allFeatureWeights) {
        Prediction<Regressor> prediction = predict(e);
        Map<String, List<Pair<String, Double>>> weightMap = new HashMap<>();
        for (int i = 0; i < allFeatureWeights.length; i++) {
            List<Pair<String, Double>> scores = new ArrayList<>();
            for (Feature f : e) {
                int id = featureIDMap.getID(f.getName());
                // Features unknown to the model are skipped.
                if (id > -1) {
                    scores.add(new Pair<>(f.getName(), allFeatureWeights[i][id] * f.getValue()));
                }
            }
            scores.sort((o1, o2) -> o2.getB().compareTo(o1.getB()));
            weightMap.put(dimensionNames[mapping[i]], scores);
        }
        return new Excuse<>(e, prediction, weightMap);
    }

    /**
     * Exports this model as a standalone ONNX model with one float input and one float output.
     *
     * @param domain The ONNX domain.
     * @param modelVersion The model version recorded in the ONNX model.
     * @return The ONNX model proto.
     */
    @Override
    public OnnxMl.ModelProto exportONNXModel(String domain, long modelVersion) {
        ONNXContext onnx = new ONNXContext();
        ONNXPlaceholder input = onnx.floatInput(featureIDMap.size());
        ONNXPlaceholder output = onnx.floatOutput(outputIDInfo.size());
        onnx.setName("Regression-LibLinear");
        return ONNXExportable.buildModel(writeONNXGraph(input).assignTo(output).onnxContext(), domain, modelVersion, this);
    }

    /**
     * Writes this model's computation as a single GEMM against the supplied input.
     * <p>
     * The weight initializer has shape [numFeatures, numOutputs] with columns in output id order, and
     * the bias initializer holds each model's final (bias) weight.
     *
     * @param input The node supplying the feature vector.
     * @return The GEMM node producing the regression outputs.
     */
    @Override
    public ONNXNode writeONNXGraph(ONNXRef<?> input) {
        ONNXContext onnx = input.onnxContext();
        double[][] weights = getFeatureWeights();
        int numFeatures = featureIDMap.size();
        int numOutputs = outputIDInfo.size();
        // Row-major fill: row j holds feature j's weight for every output.
        ONNXInitializer onnxWeights = onnx.floatTensor("liblinear-weights", Arrays.asList(numFeatures, numOutputs), fb -> {
            for (int j = 0; j < numFeatures; j++) {
                for (int i = 0; i < weights.length; i++) {
                    fb.put((float) weights[i][j]);
                }
            }
        });
        ONNXInitializer onnxBiases = onnx.floatTensor("liblinear-bias", Collections.singletonList(numOutputs), fb -> {
            for (int i = 0; i < weights.length; i++) {
                fb.put((float) weights[i][numFeatures]);
            }
        });
        return input.apply(ONNXOperators.GEMM, Arrays.asList(onnxWeights, onnxBiases));
    }

    /**
     * Custom deserialization hook. If {@code mapping} was absent from the serialized form, this
     * populates it and reorders the models into output id order, treating the stored list as
     * natural-ordered.
     */
    private void readObject(ObjectInputStream in) throws IOException, ClassNotFoundException {
        in.defaultReadObject();
        if (mapping == null) {
            mapping = ((ImmutableRegressionInfo) outputIDInfo).getIDtoNaturalOrderMapping();
            List<Model> idOrdered = new ArrayList<>(models);
            for (int i = 0; i < mapping.length; i++) {
                // Output id i lives at natural index mapping[i] in the old list.
                idOrdered.set(i, models.get(mapping[i]));
            }
            models = Collections.unmodifiableList(idOrdered);
        }
    }
}

