/*
 * Decompiled with CFR 0.152.
 */
package org.tribuo.clustering.hdbscan;

import com.google.protobuf.Any;
import com.google.protobuf.InvalidProtocolBufferException;
import com.google.protobuf.Message;
import com.oracle.labs.mlrg.olcut.util.Pair;
import java.io.IOException;
import java.io.ObjectInputStream;
import java.util.ArrayList;
import java.util.Collections;
import java.util.List;
import java.util.Map;
import java.util.Optional;
import org.tribuo.Example;
import org.tribuo.Excuse;
import org.tribuo.Feature;
import org.tribuo.ImmutableFeatureMap;
import org.tribuo.ImmutableOutputInfo;
import org.tribuo.Model;
import org.tribuo.Output;
import org.tribuo.Prediction;
import org.tribuo.clustering.ClusterID;
import org.tribuo.clustering.hdbscan.HdbscanTrainer;
import org.tribuo.clustering.hdbscan.protos.ClusterExemplarProto;
import org.tribuo.clustering.hdbscan.protos.HdbscanModelProto;
import org.tribuo.impl.ModelDataCarrier;
import org.tribuo.math.distance.Distance;
import org.tribuo.math.la.DenseVector;
import org.tribuo.math.la.SGDVector;
import org.tribuo.math.la.SparseVector;
import org.tribuo.math.la.Tensor;
import org.tribuo.math.la.VectorTuple;
import org.tribuo.math.protos.DistanceProto;
import org.tribuo.math.protos.TensorProto;
import org.tribuo.protos.ProtoUtil;
import org.tribuo.protos.core.ModelDataProto;
import org.tribuo.protos.core.ModelProto;
import org.tribuo.provenance.ModelProvenance;

/**
 * A trained HDBSCAN* clustering model.
 * <p>
 * The model holds the cluster labels and outlier scores it was constructed with, plus a list of cluster
 * exemplars. {@link #predict} assigns an example to the label of its nearest exemplar under the model's
 * {@link Distance}, optionally flagging it as noise when it is not within any exemplar's maximum distance
 * to the cluster edge.
 */
public final class HdbscanModel
extends Model<ClusterID> {
    private static final long serialVersionUID = 1L;

    /** Protobuf version written by {@link #serialize()} and the highest accepted by {@link #deserializeFromProto}. */
    public static final int CURRENT_VERSION = 0;

    /** Label returned by {@link #predict} for noise points (and when there are no exemplars). */
    private static final int NOISE_CLUSTER_LABEL = 0;

    private final List<Integer> clusterLabels;
    private final DenseVector outlierScoresVector;
    /**
     * Legacy distance field. The constructor never sets it; {@link #readObject} reads it only to rebuild
     * {@link #dist} when a Java-serialized model has no {@code dist} value.
     */
    @Deprecated
    private HdbscanTrainer.Distance distanceType;
    /** Distance used to compare examples with exemplars. Not final because {@link #readObject} may assign it. */
    private Distance dist;
    private final List<HdbscanTrainer.ClusterExemplar> clusterExemplars;
    /**
     * Outlier score given to noise points. Noise detection in {@link #predict} is enabled only when
     * {@code Double.compare(noisePointsOutlierScore, 0.0) > 0} (note this is also true for NaN).
     */
    private final double noisePointsOutlierScore;

    HdbscanModel(String name, ModelProvenance description, ImmutableFeatureMap featureIDMap, ImmutableOutputInfo<ClusterID> outputIDInfo, List<Integer> clusterLabels, DenseVector outlierScoresVector, List<HdbscanTrainer.ClusterExemplar> clusterExemplars, Distance dist, double noisePointsOutlierScore) {
        super(name, description, featureIDMap, outputIDInfo, false);
        // Defensive copies: later mutation of the caller's lists must not change this model.
        this.clusterLabels = Collections.unmodifiableList(new ArrayList<>(clusterLabels));
        this.outlierScoresVector = outlierScoresVector;
        this.clusterExemplars = Collections.unmodifiableList(new ArrayList<>(clusterExemplars));
        this.dist = dist;
        this.noisePointsOutlierScore = noisePointsOutlierScore;
    }

    /**
     * Deserialization factory.
     *
     * @param version   The serialized object version.
     * @param className The class name.
     * @param message   The serialized data.
     * @return The deserialized model.
     * @throws InvalidProtocolBufferException If the message is not a {@link HdbscanModelProto}.
     * @throws IllegalArgumentException       If the version is unsupported.
     * @throws IllegalStateException          If the protobuf contents are inconsistent.
     */
    public static HdbscanModel deserializeFromProto(int version, String className, Any message) throws InvalidProtocolBufferException {
        if (version < 0 || version > CURRENT_VERSION) {
            throw new IllegalArgumentException("Unknown version " + version + ", this class supports at most version " + CURRENT_VERSION);
        }
        HdbscanModelProto proto = message.unpack(HdbscanModelProto.class);
        ModelDataCarrier<?> carrier = ModelDataCarrier.deserialize(proto.getMetadata());

        // Validate every domain element rather than looking up id 0: a clustering domain need not contain
        // id 0, in which case getOutput(0) returns null and the old check threw a NullPointerException.
        ImmutableOutputInfo<?> rawDomain = carrier.outputDomain();
        for (Object output : rawDomain.getDomain()) {
            if (!output.getClass().equals(ClusterID.class)) {
                throw new IllegalStateException("Invalid protobuf, output domain is not a clustering domain, found " + rawDomain.getClass());
            }
        }
        @SuppressWarnings("unchecked") // every domain element was verified to be a ClusterID above
        ImmutableOutputInfo<ClusterID> outputDomain = (ImmutableOutputInfo<ClusterID>) rawDomain;

        Tensor outlierScoresTensor = Tensor.deserialize(proto.getOutlierScoresVector());
        if (!(outlierScoresTensor instanceof DenseVector)) {
            throw new IllegalStateException("Invalid protobuf, outlier scores must be a dense vector, found " + outlierScoresTensor.getClass());
        }
        DenseVector outlierScoresVector = (DenseVector) outlierScoresTensor;

        List<Integer> clusterLabels = new ArrayList<>(proto.getClusterLabelsList());
        for (Integer n : clusterLabels) {
            // NOTE(review): label -1 is tolerated even when absent from the domain, while predict() labels
            // noise as 0 — confirm which sentinel the trainer writes for noise points.
            if (outputDomain.getOutput(n) == null && n != -1) {
                throw new IllegalStateException("Invalid protobuf, found cluster id " + n + " which is not present in the domain " + outputDomain);
            }
        }
        if (clusterLabels.size() != outlierScoresVector.size()) {
            throw new IllegalStateException("Invalid protobuf, expected the same number of outlier scores as cluster labels, found " + outlierScoresVector.size() + " scores and " + clusterLabels.size() + " labels");
        }

        List<HdbscanTrainer.ClusterExemplar> exemplars = new ArrayList<>(proto.getClusterExemplarsCount());
        for (ClusterExemplarProto p : proto.getClusterExemplarsList()) {
            exemplars.add(HdbscanTrainer.ClusterExemplar.deserialize(p));
        }
        Distance distance = (Distance) ProtoUtil.deserialize(proto.getDistance());
        return new HdbscanModel(carrier.name(), carrier.provenance(), carrier.featureDomain(), outputDomain, clusterLabels, outlierScoresVector, exemplars, distance, proto.getNoisePointsOutlierScore());
    }

    /**
     * Returns the cluster labels this model was constructed with.
     *
     * @return An unmodifiable list of cluster labels.
     */
    public List<Integer> getClusterLabels() {
        return clusterLabels;
    }

    /**
     * Returns the outlier scores as a new list, in the same order as the stored outlier score vector.
     *
     * @return A freshly allocated, mutable list of outlier scores.
     */
    public List<Double> getOutlierScores() {
        List<Double> outlierScores = new ArrayList<>(outlierScoresVector.size());
        for (double outlierScore : outlierScoresVector.toArray()) {
            outlierScores.add(outlierScore);
        }
        return outlierScores;
    }

    /**
     * Returns copies of the cluster exemplars, so callers cannot mutate the model's exemplars.
     *
     * @return A new list of exemplar copies.
     */
    public List<HdbscanTrainer.ClusterExemplar> getClusterExemplars() {
        List<HdbscanTrainer.ClusterExemplar> copies = new ArrayList<>(clusterExemplars.size());
        for (HdbscanTrainer.ClusterExemplar e : clusterExemplars) {
            copies.add(e.copy());
        }
        return copies;
    }

    /**
     * Returns each exemplar as a pair of its cluster label and its active features, with feature names
     * resolved through the model's feature map.
     *
     * @return One pair per exemplar, in exemplar order.
     */
    public List<Pair<Integer, List<Feature>>> getClusters() {
        List<Pair<Integer, List<Feature>>> clusters = new ArrayList<>(clusterExemplars.size());
        for (HdbscanTrainer.ClusterExemplar e : clusterExemplars) {
            List<Feature> features = new ArrayList<>(e.getFeatures().numActiveElements());
            for (VectorTuple v : e.getFeatures()) {
                features.add(new Feature(featureIDMap.get(v.index).getName(), v.value));
            }
            clusters.add(new Pair<>(e.getLabel(), features));
        }
        return clusters;
    }

    /**
     * Assigns the example to the label of its nearest exemplar under {@link #dist}.
     * <p>
     * When noise detection is enabled (see {@link #noisePointsOutlierScore}) and the example is farther than
     * every exemplar's {@code getMaxDistToEdge()}, it is labelled {@link #NOISE_CLUSTER_LABEL} and given the
     * noise outlier score instead.
     *
     * @param example The example to cluster.
     * @return The predicted cluster and outlier score.
     * @throws IllegalArgumentException If the example has no features known to this model.
     */
    @Override
    public Prediction<ClusterID> predict(Example<ClusterID> example) {
        SGDVector vector;
        if (example.size() == featureIDMap.size()) {
            vector = DenseVector.createDenseVector(example, featureIDMap, false);
        } else {
            vector = SparseVector.createSparseVector(example, featureIDMap, false);
        }
        if (vector.numActiveElements() == 0) {
            throw new IllegalArgumentException("No features found in Example " + example);
        }

        boolean detectNoise = Double.compare(noisePointsOutlierScore, 0.0) > 0;
        double minDistance = Double.POSITIVE_INFINITY;
        int clusterLabel = NOISE_CLUSTER_LABEL;
        double outlierScore = 0.0;
        boolean isNoisePoint = true;
        for (HdbscanTrainer.ClusterExemplar exemplar : clusterExemplars) {
            double distance = dist.computeDistance(exemplar.getFeatures(), vector);
            // getMaxDistToEdge() is only consulted when noise detection is enabled, as in the original code.
            if (detectNoise && isNoisePoint && distance <= exemplar.getMaxDistToEdge()) {
                isNoisePoint = false;
            }
            if (distance < minDistance) {
                minDistance = distance;
                clusterLabel = exemplar.getLabel();
                outlierScore = exemplar.getOutlierScore();
            }
        }
        if (detectNoise && isNoisePoint) {
            clusterLabel = NOISE_CLUSTER_LABEL;
            outlierScore = noisePointsOutlierScore;
        }
        return new Prediction<>(new ClusterID(clusterLabel, outlierScore), vector.size(), example);
    }

    /**
     * This model does not compute feature importances.
     *
     * @param n Ignored.
     * @return An empty map.
     */
    @Override
    public Map<String, List<Pair<String, Double>>> getTopFeatures(int n) {
        return Collections.emptyMap();
    }

    /**
     * This model does not produce excuses.
     *
     * @param example Ignored.
     * @return An empty optional.
     */
    @Override
    public Optional<Excuse<ClusterID>> getExcuse(Example<ClusterID> example) {
        return Optional.empty();
    }

    /**
     * Serializes this model into a {@link ModelProto} wrapping a {@link HdbscanModelProto}.
     *
     * @return The serialized model.
     */
    @Override
    public ModelProto serialize() {
        ModelDataCarrier<ClusterID> carrier = createDataCarrier();
        HdbscanModelProto.Builder modelBuilder = HdbscanModelProto.newBuilder();
        modelBuilder.setMetadata(carrier.serialize());
        modelBuilder.addAllClusterLabels(clusterLabels);
        modelBuilder.setOutlierScoresVector(outlierScoresVector.serialize());
        modelBuilder.setDistance(dist.serialize());
        for (HdbscanTrainer.ClusterExemplar e : clusterExemplars) {
            modelBuilder.addClusterExemplars(e.serialize());
        }
        modelBuilder.setNoisePointsOutlierScore(noisePointsOutlierScore);

        ModelProto.Builder builder = ModelProto.newBuilder();
        builder.setSerializedData(Any.pack(modelBuilder.build()));
        builder.setClassName(HdbscanModel.class.getName());
        builder.setVersion(CURRENT_VERSION);
        return builder.build();
    }

    /**
     * Copies this model under a new name and provenance. The outlier score vector is copied; the label and
     * exemplar lists are copied by the constructor.
     */
    @Override
    protected HdbscanModel copy(String newName, ModelProvenance newProvenance) {
        return new HdbscanModel(newName, newProvenance, featureIDMap, outputIDInfo, clusterLabels, outlierScoresVector.copy(), clusterExemplars, dist, noisePointsOutlierScore);
    }

    /**
     * Java deserialization hook. Rebuilds {@link #dist} from the legacy {@link #distanceType} field when the
     * stream did not contain a {@code dist} value.
     *
     * @throws IOException If neither distance field is present in the stream.
     */
    private void readObject(ObjectInputStream in) throws IOException, ClassNotFoundException {
        in.defaultReadObject();
        if (dist == null) {
            if (distanceType == null) {
                throw new IOException("Invalid serialized HdbscanModel, neither dist nor distanceType is set");
            }
            dist = distanceType.getDistanceType().getDistance();
        }
    }
}

