/*
 * Decompiled with CFR 0.152.
 */
package org.tribuo.clustering.hdbscan;

import com.oracle.labs.mlrg.olcut.config.Config;
import com.oracle.labs.mlrg.olcut.config.PropertyException;
import com.oracle.labs.mlrg.olcut.provenance.Provenance;
import com.oracle.labs.mlrg.olcut.util.MutableLong;
import com.oracle.labs.mlrg.olcut.util.Pair;
import java.io.Serializable;
import java.time.OffsetDateTime;
import java.util.ArrayDeque;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.BitSet;
import java.util.Collections;
import java.util.HashMap;
import java.util.HashSet;
import java.util.Iterator;
import java.util.List;
import java.util.Map;
import java.util.Objects;
import java.util.PriorityQueue;
import java.util.TreeMap;
import java.util.TreeSet;
import java.util.logging.Level;
import java.util.logging.Logger;
import java.util.stream.IntStream;
import java.util.stream.Stream;
import org.tribuo.Dataset;
import org.tribuo.Example;
import org.tribuo.ImmutableFeatureMap;
import org.tribuo.ImmutableOutputInfo;
import org.tribuo.Trainer;
import org.tribuo.clustering.ClusterID;
import org.tribuo.clustering.ImmutableClusteringInfo;
import org.tribuo.clustering.hdbscan.ExtendedMinimumSpanningTree;
import org.tribuo.clustering.hdbscan.HdbscanCluster;
import org.tribuo.clustering.hdbscan.HdbscanModel;
import org.tribuo.clustering.hdbscan.protos.ClusterExemplarProto;
import org.tribuo.math.distance.DistanceType;
import org.tribuo.math.la.DenseVector;
import org.tribuo.math.la.SGDVector;
import org.tribuo.math.la.SparseVector;
import org.tribuo.math.la.Tensor;
import org.tribuo.math.neighbour.NeighboursQuery;
import org.tribuo.math.neighbour.NeighboursQueryFactory;
import org.tribuo.math.neighbour.NeighboursQueryFactoryType;
import org.tribuo.math.neighbour.bruteforce.NeighboursBruteForceFactory;
import org.tribuo.math.protos.TensorProto;
import org.tribuo.provenance.DatasetProvenance;
import org.tribuo.provenance.ModelProvenance;
import org.tribuo.provenance.TrainerProvenance;
import org.tribuo.provenance.impl.TrainerProvenanceImpl;

/**
 * A trainer for HDBSCAN* clustering models.
 * <p>
 * Training proceeds in these stages:
 * <ol>
 *   <li>Compute each point's core distance: the distance to its k-th nearest neighbour, where the
 *   neighbour list includes the point itself.</li>
 *   <li>Build a minimum spanning tree over the mutual reachability distances, i.e.
 *   {@code max(core(a), core(b), d(a, b))}, and extend it with one self edge per point weighted by
 *   that point's core distance.</li>
 *   <li>Remove tree edges level by level to build the cluster hierarchy, then select the prominent
 *   clusters.</li>
 *   <li>Score each point as an outlier and pick per-cluster exemplars for prediction.</li>
 * </ol>
 */
public final class HdbscanTrainer
implements Trainer<ClusterID> {
    private static final Logger logger = Logger.getLogger(HdbscanTrainer.class.getName());
    /** Label assigned to points that are classified as noise/outliers. */
    static final int OUTLIER_NOISE_CLUSTER_LABEL = 0;
    /** Outlier score returned when no point was assigned to the noise cluster. */
    private static final double MAX_OUTLIER_SCORE = 0.9999;
    @Config(mandatory=true, description="The minimum number of points required to form a cluster.")
    private int minClusterSize;
    @Deprecated
    @Config(description="The distance function to use. This is now deprecated.")
    private Distance distanceType;
    @Config(description="The distance function to use.")
    private org.tribuo.math.distance.Distance dist;
    @Config(mandatory=true, description="The number of nearest-neighbors to use in the initial density approximation. This includes the point itself.")
    private int k;
    @Deprecated
    @Config(description="The number of threads to use for training. This is now deprecated since it is a field on the NeighboursQueryFactory object.")
    private int numThreads = 1;
    @Config(description="The nearest neighbour implementation factory to use.")
    private NeighboursQueryFactory neighboursQueryFactory;
    /** Number of times {@link #train} has been called; guarded by {@code this}. */
    private int trainInvocationCounter;

    /**
     * No-argument constructor; fields are expected to be populated by the configuration system,
     * followed by a call to {@link #postConfig()}.
     */
    private HdbscanTrainer() {
    }

    /**
     * Constructs a trainer using L2 distance, brute-force neighbour search on a single thread, and
     * {@code k = minClusterSize}.
     *
     * @param minClusterSize The minimum number of points required to form a cluster.
     */
    public HdbscanTrainer(int minClusterSize) {
        this(minClusterSize, DistanceType.L2.getDistance(), minClusterSize, 1, NeighboursQueryFactoryType.BRUTE_FORCE);
    }

    /**
     * Constructs a trainer using the deprecated {@link Distance} enum and brute-force neighbour search.
     *
     * @param minClusterSize The minimum number of points required to form a cluster.
     * @param distanceType The distance function.
     * @param k The number of nearest neighbours used for the core distance (includes the point itself).
     * @param numThreads The number of threads for the neighbour query factory.
     * @deprecated Use a constructor taking a {@link org.tribuo.math.distance.Distance}.
     */
    @Deprecated
    public HdbscanTrainer(int minClusterSize, Distance distanceType, int k, int numThreads) {
        this(minClusterSize, distanceType.getDistanceType().getDistance(), k, numThreads, NeighboursQueryFactoryType.BRUTE_FORCE);
    }

    /**
     * Constructs a trainer whose neighbour query factory is built from the supplied type, distance
     * and thread count.
     *
     * @param minClusterSize The minimum number of points required to form a cluster.
     * @param dist The distance function.
     * @param k The number of nearest neighbours used for the core distance (includes the point itself).
     * @param numThreads The number of threads for the neighbour query factory.
     * @param nqFactoryType The kind of neighbour query factory to build.
     */
    public HdbscanTrainer(int minClusterSize, org.tribuo.math.distance.Distance dist, int k, int numThreads, NeighboursQueryFactoryType nqFactoryType) {
        this.minClusterSize = minClusterSize;
        this.dist = dist;
        this.k = k;
        this.numThreads = numThreads;
        this.neighboursQueryFactory = NeighboursQueryFactoryType.getNeighboursQueryFactory(nqFactoryType, dist, numThreads);
    }

    /**
     * Constructs a trainer that uses the supplied neighbour query factory; the distance function is
     * taken from the factory.
     *
     * @param minClusterSize The minimum number of points required to form a cluster.
     * @param k The number of nearest neighbours used for the core distance (includes the point itself).
     * @param neighboursQueryFactory The neighbour query factory.
     */
    public HdbscanTrainer(int minClusterSize, int k, NeighboursQueryFactory neighboursQueryFactory) {
        this.minClusterSize = minClusterSize;
        this.dist = neighboursQueryFactory.getDistance();
        this.k = k;
        this.neighboursQueryFactory = neighboursQueryFactory;
    }

    /**
     * Reconciles the configured fields after configuration.
     * <p>
     * This method:
     * <ul>
     *   <li>converts the deprecated {@code distanceType} into {@code dist};</li>
     *   <li>creates a brute-force neighbour query factory if none was configured;</li>
     *   <li>otherwise checks that {@code dist} agrees with the factory's distance. When only the
     *   factory was configured, {@code dist} is taken from it, as in
     *   {@link #HdbscanTrainer(int, int, NeighboursQueryFactory)}.</li>
     * </ul>
     *
     * @throws PropertyException If both distance fields are set, or if {@code dist} and the
     * factory's distance disagree.
     */
    public synchronized void postConfig() {
        if (this.distanceType != null) {
            if (this.dist != null) {
                throw new PropertyException("distType", "Both distType and distanceType must not both be set.");
            }
            this.dist = this.distanceType.getDistanceType().getDistance();
            this.distanceType = null;
        }
        if (this.neighboursQueryFactory == null) {
            int numberThreads = this.numThreads <= 0 ? 1 : this.numThreads;
            this.neighboursQueryFactory = new NeighboursBruteForceFactory(this.dist, numberThreads);
        } else if (this.dist == null) {
            // Only the factory was configured; previously this dereferenced a null dist.
            this.dist = this.neighboursQueryFactory.getDistance();
        } else if (!this.dist.equals(this.neighboursQueryFactory.getDistance())) {
            throw new PropertyException("neighboursQueryFactory", "distType and its field on the NeighboursQueryFactory must be equal.");
        }
    }

    /**
     * Trains an HDBSCAN* model on the supplied dataset.
     *
     * @param examples The training data; must be non-empty.
     * @param runProvenance Extra provenance recorded with the model.
     * @return The trained model.
     * @throws IllegalArgumentException If {@code examples} is empty.
     */
    public HdbscanModel train(Dataset<ClusterID> examples, Map<String, Provenance> runProvenance) {
        if (examples.size() == 0) {
            // constructEMST would otherwise fail with an opaque NegativeArraySizeException.
            throw new IllegalArgumentException("Cannot train an HdbscanModel on an empty dataset.");
        }
        TrainerProvenance trainerProvenance;
        synchronized (this) {
            trainerProvenance = this.getProvenance();
            ++this.trainInvocationCounter;
        }
        ImmutableFeatureMap featureMap = examples.getFeatureIDMap();
        SGDVector[] data = new SGDVector[examples.size()];
        int n = 0;
        for (Example<ClusterID> example : examples) {
            // Examples containing every feature are stored densely, others sparsely.
            data[n] = example.size() == featureMap.size()
                ? DenseVector.createDenseVector(example, featureMap, false)
                : SparseVector.createSparseVector(example, featureMap, false);
            ++n;
        }
        DenseVector coreDistances = calculateCoreDistances(data, this.k, this.neighboursQueryFactory);
        ExtendedMinimumSpanningTree emst = constructEMST(data, coreDistances, this.dist);
        double[] pointNoiseLevels = new double[data.length];
        int[] pointLastClusters = new int[data.length];
        Map<Integer, int[]> hierarchy = new HashMap<>();
        List<HdbscanCluster> clusters = computeHierarchyAndClusterTree(emst, this.minClusterSize, pointNoiseLevels, pointLastClusters, hierarchy);
        propagateTree(clusters);
        List<Integer> clusterLabels = findProminentClusters(hierarchy, clusters, data.length);
        DenseVector outlierScoresVector = calculateOutlierScores(pointNoiseLevels, pointLastClusters, clusters);
        Map<Integer, List<Pair<Double, Integer>>> clusterAssignments = generateClusterAssignments(clusterLabels, outlierScoresVector);
        HashMap<Integer, MutableLong> counts = new HashMap<>();
        for (Map.Entry<Integer, List<Pair<Double, Integer>>> e : clusterAssignments.entrySet()) {
            counts.put(e.getKey(), new MutableLong((long) e.getValue().size()));
        }
        ImmutableClusteringInfo outputMap = new ImmutableClusteringInfo(counts);
        List<ClusterExemplar> clusterExemplars = computeExemplars(data, clusterAssignments, this.dist);
        double noisePointsOutlierScore = getNoisePointsOutlierScore(clusterAssignments);
        logger.log(Level.INFO, "Hdbscan is done.");
        ModelProvenance provenance = new ModelProvenance(HdbscanModel.class.getName(), OffsetDateTime.now(), (DatasetProvenance) examples.getProvenance(), trainerProvenance, runProvenance);
        return new HdbscanModel("hdbscan-model", provenance, featureMap, (ImmutableOutputInfo<ClusterID>) outputMap, clusterLabels, outlierScoresVector, clusterExemplars, this.dist, noisePointsOutlierScore);
    }

    /**
     * Trains an HDBSCAN* model with empty run provenance.
     *
     * @param dataset The training data; must be non-empty.
     * @return The trained model.
     */
    public HdbscanModel train(Dataset<ClusterID> dataset) {
        return this.train(dataset, Collections.emptyMap());
    }

    /**
     * Returns the number of times {@link #train} has been invoked.
     *
     * @return The invocation count.
     */
    public synchronized int getInvocationCount() {
        return this.trainInvocationCounter;
    }

    /**
     * Sets the invocation count. This is synchronized on the same lock {@link #train} uses to
     * increment it, so concurrent updates are not lost.
     *
     * @param newInvocationCount The new count; must be non-negative.
     * @throws IllegalArgumentException If {@code newInvocationCount} is negative.
     */
    public synchronized void setInvocationCount(int newInvocationCount) {
        if (newInvocationCount < 0) {
            throw new IllegalArgumentException("The supplied invocationCount is less than zero.");
        }
        this.trainInvocationCounter = newInvocationCount;
    }

    /**
     * Computes each point's core distance, i.e. the distance to its k-th nearest neighbour, where
     * the neighbour list includes the point itself. When {@code k == 1} every core distance is zero
     * and no neighbour query is performed.
     */
    private static DenseVector calculateCoreDistances(SGDVector[] data, int k, NeighboursQueryFactory neighboursQueryFactory) {
        DenseVector coreDistances = new DenseVector(data.length);
        if (k == 1) {
            return coreDistances;
        }
        NeighboursQuery nq = neighboursQueryFactory.createNeighboursQuery(data);
        List<List<Pair<Integer, Double>>> neighbourLists = nq.queryAll(k);
        for (int point = 0; point < data.length; ++point) {
            // Element k-1 holds the k-th neighbour; getB() is its distance.
            coreDistances.set(point, neighbourLists.get(point).get(k - 1).getB());
        }
        return coreDistances;
    }

    /**
     * Builds the extended minimum spanning tree over mutual reachability distances.
     * <p>
     * The tree is grown Prim-style from the last point, using an O(n^2) dense scan. It is then
     * extended with one self edge per vertex, weighted by that vertex's core distance.
     */
    private static ExtendedMinimumSpanningTree constructEMST(SGDVector[] data, DenseVector coreDistances, org.tribuo.math.distance.Distance dist) {
        int numPoints = data.length;
        BitSet attachedPoints = new BitSet(numPoints);
        // The first numPoints-1 slots hold tree edges; the last numPoints slots hold self edges.
        int[] nearestMRDNeighbors = new int[2 * numPoints - 1];
        double[] nearestMRDDistances = new double[2 * numPoints - 1];
        for (int i = 0; i < numPoints - 1; ++i) {
            nearestMRDDistances[i] = Double.MAX_VALUE;
        }
        int currentPoint = numPoints - 1;
        attachedPoints.set(numPoints - 1);
        for (int numAttachedPoints = 1; numAttachedPoints < numPoints; ++numAttachedPoints) {
            int nearestMRDPoint = -1;
            double nearestMRDDistance = Double.MAX_VALUE;
            for (int neighbor = 0; neighbor < numPoints; ++neighbor) {
                if (currentPoint == neighbor || attachedPoints.get(neighbor)) {
                    continue;
                }
                // Mutual reachability distance = max(d(a, b), core(a), core(b)).
                double mutualReachabilityDistance = dist.computeDistance(data[currentPoint], data[neighbor]);
                if (coreDistances.get(currentPoint) > mutualReachabilityDistance) {
                    mutualReachabilityDistance = coreDistances.get(currentPoint);
                }
                if (coreDistances.get(neighbor) > mutualReachabilityDistance) {
                    mutualReachabilityDistance = coreDistances.get(neighbor);
                }
                if (mutualReachabilityDistance < nearestMRDDistances[neighbor]) {
                    nearestMRDDistances[neighbor] = mutualReachabilityDistance;
                    nearestMRDNeighbors[neighbor] = currentPoint;
                }
                if (nearestMRDDistances[neighbor] <= nearestMRDDistance) {
                    nearestMRDDistance = nearestMRDDistances[neighbor];
                    nearestMRDPoint = neighbor;
                }
            }
            attachedPoints.set(nearestMRDPoint);
            currentPoint = nearestMRDPoint;
        }
        int[] otherVertexIndices = new int[2 * numPoints - 1];
        for (int i = 0; i < numPoints - 1; ++i) {
            otherVertexIndices[i] = i;
        }
        for (int i = numPoints - 1; i < numPoints * 2 - 1; ++i) {
            int vertex = i - (numPoints - 1);
            nearestMRDNeighbors[i] = vertex;
            otherVertexIndices[i] = vertex;
            nearestMRDDistances[i] = coreDistances.get(vertex);
        }
        return new ExtendedMinimumSpanningTree(numPoints, nearestMRDNeighbors, otherVertexIndices, nearestMRDDistances);
    }

    /**
     * Builds the cluster hierarchy by removing EMST edges from the highest index downward.
     * <p>
     * All edges tied at the same weight are removed together, and the affected clusters are then
     * re-split into sub-clusters, new clusters, or noise. This assumes the EMST orders its edges by
     * ascending weight (TODO confirm against ExtendedMinimumSpanningTree).
     * <p>
     * Mutates the EMST edge lists, {@code pointNoiseLevels}, {@code pointLastClusters} and
     * {@code hierarchy}, which maps a hierarchy level to a snapshot of every point's cluster label.
     *
     * @return The clusters, indexed by label; index 0 holds {@link HdbscanCluster#NOT_A_CLUSTER}.
     */
    private static List<HdbscanCluster> computeHierarchyAndClusterTree(ExtendedMinimumSpanningTree emst, int minClusterSize, double[] pointNoiseLevels, int[] pointLastClusters, Map<Integer, int[]> hierarchy) {
        int lineCount = 0;
        int currentEdgeIndex = emst.getNumEdges() - 1;
        int nextClusterLabel = 2;
        boolean nextLevelSignificant = true;
        int[] currentClusterLabels = new int[emst.getNumVertices()];
        Arrays.fill(currentClusterLabels, 1);
        List<HdbscanCluster> clusters = new ArrayList<>();
        clusters.add(HdbscanCluster.NOT_A_CLUSTER);
        clusters.add(new HdbscanCluster(1, HdbscanCluster.NOT_A_CLUSTER, Double.NaN, emst.getNumVertices()));
        TreeSet<Integer> affectedClusterLabels = new TreeSet<>();
        HashSet<Integer> affectedVertices = new HashSet<>();
        while (currentEdgeIndex >= 0) {
            double currentEdgeWeight = emst.getEdgeWeightAtIndex(currentEdgeIndex);
            List<HdbscanCluster> newClusters = new ArrayList<>();
            // Remove every edge tied at the current weight, recording the clusters and vertices touched.
            while (currentEdgeIndex >= 0 && emst.getEdgeWeightAtIndex(currentEdgeIndex) == currentEdgeWeight) {
                int firstVertex = emst.getFirstVertexAtIndex(currentEdgeIndex);
                int secondVertex = emst.getSecondVertexAtIndex(currentEdgeIndex);
                // Boxed so a List edge list removes by value, not by index.
                emst.getEdgeListForVertex(firstVertex).remove(Integer.valueOf(secondVertex));
                emst.getEdgeListForVertex(secondVertex).remove(Integer.valueOf(firstVertex));
                if (currentClusterLabels[firstVertex] != OUTLIER_NOISE_CLUSTER_LABEL) {
                    affectedVertices.add(firstVertex);
                    affectedVertices.add(secondVertex);
                    affectedClusterLabels.add(currentClusterLabels[firstVertex]);
                }
                --currentEdgeIndex;
            }
            if (affectedClusterLabels.isEmpty()) {
                continue;
            }
            while (!affectedClusterLabels.isEmpty()) {
                int examinedClusterLabel = affectedClusterLabels.pollLast();
                TreeSet<Integer> examinedVertices = new TreeSet<>();
                Iterator<Integer> vertexIterator = affectedVertices.iterator();
                while (vertexIterator.hasNext()) {
                    int vertex = vertexIterator.next();
                    if (currentClusterLabels[vertex] == examinedClusterLabel) {
                        examinedVertices.add(vertex);
                        vertexIterator.remove();
                    }
                }
                TreeSet<Integer> firstChildCluster = null;
                ArrayDeque<Integer> unexploredFirstChildClusterPoints = null;
                int numChildClusters = 0;
                while (!examinedVertices.isEmpty()) {
                    // Breadth-first search for the connected component containing the largest unexamined vertex.
                    TreeSet<Integer> constructingSubCluster = new TreeSet<>();
                    ArrayDeque<Integer> unexploredSubClusterPoints = new ArrayDeque<>();
                    boolean anyEdges = false;
                    boolean incrementedChildCount = false;
                    int rootVertex = examinedVertices.pollLast();
                    constructingSubCluster.add(rootVertex);
                    unexploredSubClusterPoints.add(rootVertex);
                    while (!unexploredSubClusterPoints.isEmpty()) {
                        int vertexToExplore = unexploredSubClusterPoints.poll();
                        for (int neighbor : emst.getEdgeListForVertex(vertexToExplore)) {
                            anyEdges = true;
                            if (constructingSubCluster.add(neighbor)) {
                                unexploredSubClusterPoints.add(neighbor);
                                examinedVertices.remove(neighbor);
                            }
                        }
                        if (incrementedChildCount || constructingSubCluster.size() < minClusterSize || !anyEdges) {
                            continue;
                        }
                        incrementedChildCount = true;
                        ++numChildClusters;
                        if (firstChildCluster == null) {
                            // Defer full exploration of the first child until a second child proves a real split.
                            firstChildCluster = constructingSubCluster;
                            unexploredFirstChildClusterPoints = unexploredSubClusterPoints;
                            break;
                        }
                    }
                    if (numChildClusters >= 2 && constructingSubCluster.size() >= minClusterSize && anyEdges) {
                        int firstChildClusterMember = firstChildCluster.last();
                        if (constructingSubCluster.contains(firstChildClusterMember)) {
                            // This component is the (partially explored) first child itself, not a new split.
                            --numChildClusters;
                        } else {
                            HdbscanCluster parentCluster = clusters.get(examinedClusterLabel);
                            HdbscanCluster newCluster = parentCluster.createNewCluster(constructingSubCluster, currentClusterLabels, nextClusterLabel, currentEdgeWeight);
                            newClusters.add(newCluster);
                            clusters.add(newCluster);
                            ++nextClusterLabel;
                        }
                    } else if (constructingSubCluster.size() < minClusterSize || !anyEdges) {
                        // Too small or isolated: these points become noise at this edge weight.
                        HdbscanCluster parentCluster = clusters.get(examinedClusterLabel);
                        parentCluster.createNewCluster(constructingSubCluster, currentClusterLabels, OUTLIER_NOISE_CLUSTER_LABEL, currentEdgeWeight);
                        for (int point : constructingSubCluster) {
                            pointNoiseLevels[point] = currentEdgeWeight;
                            pointLastClusters[point] = examinedClusterLabel;
                        }
                    }
                }
                if (numChildClusters < 2 || currentClusterLabels[firstChildCluster.first()] != examinedClusterLabel) {
                    continue;
                }
                // A real split happened: finish exploring the first child and give it its own cluster.
                while (!unexploredFirstChildClusterPoints.isEmpty()) {
                    int vertexToExplore = unexploredFirstChildClusterPoints.poll();
                    for (int neighbor : emst.getEdgeListForVertex(vertexToExplore)) {
                        if (firstChildCluster.add(neighbor)) {
                            unexploredFirstChildClusterPoints.add(neighbor);
                        }
                    }
                }
                HdbscanCluster parentCluster = clusters.get(examinedClusterLabel);
                HdbscanCluster newCluster = parentCluster.createNewCluster(firstChildCluster, currentClusterLabels, nextClusterLabel, currentEdgeWeight);
                newClusters.add(newCluster);
                clusters.add(newCluster);
                ++nextClusterLabel;
            }
            if (nextLevelSignificant || !newClusters.isEmpty()) {
                ++lineCount;
            }
            if (!newClusters.isEmpty()) {
                // One snapshot per level; previously an identical copy was made for every new cluster.
                hierarchy.put(lineCount, Arrays.copyOf(currentClusterLabels, currentClusterLabels.length));
                for (HdbscanCluster newCluster : newClusters) {
                    newCluster.setHierarchyLevel(lineCount);
                }
            }
            nextLevelSignificant = !newClusters.isEmpty();
        }
        return clusters;
    }

    /**
     * Calls {@code propagate()} on every real cluster, processing leaves (clusters without
     * children) first and then their parents, in the order given by the cluster's natural ordering
     * in a {@link PriorityQueue}.
     */
    private static void propagateTree(List<HdbscanCluster> clusters) {
        PriorityQueue<HdbscanCluster> clustersToExamine = new PriorityQueue<>();
        BitSet addedToExaminationList = new BitSet(clusters.size());
        for (HdbscanCluster cluster : clusters) {
            if (cluster != HdbscanCluster.NOT_A_CLUSTER && !cluster.hasChildren()) {
                clustersToExamine.add(cluster);
                addedToExaminationList.set(cluster.getLabel());
            }
        }
        while (!clustersToExamine.isEmpty()) {
            HdbscanCluster currentCluster = clustersToExamine.poll();
            currentCluster.propagate();
            HdbscanCluster parent = currentCluster.getParent();
            if (parent != HdbscanCluster.NOT_A_CLUSTER && !addedToExaminationList.get(parent.getLabel())) {
                clustersToExamine.add(parent);
                addedToExaminationList.set(parent.getLabel());
            }
        }
    }

    /**
     * Labels each point with the selected (propagated) cluster it belonged to at that cluster's
     * hierarchy level; all other points are labelled as noise.
     *
     * @return An unmodifiable list of per-point cluster labels.
     */
    private static List<Integer> findProminentClusters(Map<Integer, int[]> hierarchy, List<HdbscanCluster> clusters, int numPoints) {
        List<HdbscanCluster> solution = clusters.get(1).getPropagatedDescendants();
        List<Integer> clusterLabels = new ArrayList<>(Collections.nCopies(numPoints, OUTLIER_NOISE_CLUSTER_LABEL));
        // Levels are visited in ascending order, so labels from later levels overwrite earlier ones.
        TreeMap<Integer, HashSet<Integer>> significantLevels = new TreeMap<>();
        for (HdbscanCluster cluster : solution) {
            significantLevels.computeIfAbsent(cluster.getHierarchyLevel(), level -> new HashSet<>()).add(cluster.getLabel());
        }
        for (Map.Entry<Integer, HashSet<Integer>> entry : significantLevels.entrySet()) {
            HashSet<Integer> labelsAtLevel = entry.getValue();
            int[] hierarchyLevelLabels = hierarchy.get(entry.getKey());
            for (int i = 0; i < hierarchyLevelLabels.length; ++i) {
                int label = hierarchyLevelLabels[i];
                if (labelsAtLevel.contains(label)) {
                    clusterLabels.set(i, label);
                }
            }
        }
        return Collections.unmodifiableList(clusterLabels);
    }

    /**
     * Computes each point's outlier score as {@code 1 - epsilonMax / epsilon}.
     * <p>
     * Here {@code epsilon} is the edge weight at which the point became noise, and
     * {@code epsilonMax} is the propagated lowest child split level of the last cluster it belonged
     * to. The score is 0 when {@code epsilon} is 0.
     */
    private static DenseVector calculateOutlierScores(double[] pointNoiseLevels, int[] pointLastClusters, List<HdbscanCluster> clusters) {
        int numPoints = pointNoiseLevels.length;
        DenseVector outlierScores = new DenseVector(numPoints);
        for (int i = 0; i < numPoints; ++i) {
            double epsilonMax = clusters.get(pointLastClusters[i]).getPropagatedLowestChildSplitLevel();
            double epsilon = pointNoiseLevels[i];
            double score = 0.0;
            if (epsilon != 0.0) {
                score = 1.0 - epsilonMax / epsilon;
            }
            outlierScores.set(i, score);
        }
        return outlierScores;
    }

    /**
     * Groups points by cluster label.
     *
     * @return A map from label to the list of (outlier score, point index) pairs, in point order.
     */
    private static Map<Integer, List<Pair<Double, Integer>>> generateClusterAssignments(List<Integer> clusterLabels, DenseVector outlierScoresVector) {
        Map<Integer, List<Pair<Double, Integer>>> clusterAssignments = new HashMap<>();
        for (int i = 0; i < clusterLabels.size(); ++i) {
            clusterAssignments.computeIfAbsent(clusterLabels.get(i), label -> new ArrayList<>())
                    .add(new Pair<Double, Integer>(outlierScoresVector.get(i), i));
        }
        return clusterAssignments;
    }

    /**
     * Selects exemplars for every non-noise cluster.
     * <p>
     * Exemplars are the points with the lowest outlier scores. The number chosen per cluster is
     * proportional to the cluster's size, with a minimum of one. Each exemplar records its maximum
     * distance to the cluster's remaining (non-exemplar) points.
     */
    private static List<ClusterExemplar> computeExemplars(SGDVector[] data, Map<Integer, List<Pair<Double, Integer>>> clusterAssignments, org.tribuo.math.distance.Distance dist) {
        List<ClusterExemplar> clusterExemplars = new ArrayList<>();
        int numExemplars = (int) Math.sqrt(data.length / 2.0) + clusterAssignments.size();
        for (Map.Entry<Integer, List<Pair<Double, Integer>>> e : clusterAssignments.entrySet()) {
            int clusterLabel = e.getKey();
            if (clusterLabel == OUTLIER_NOISE_CLUSTER_LABEL) {
                continue;
            }
            List<Pair<Double, Integer>> outlierScoreIndexList = e.getValue();
            // NOTE(review): keyed by outlier score, so points with equal scores collapse to the last
            // such index. This shrinks both the exemplar pool and the maxDistToEdge computation. It is
            // kept as-is because changing it would alter trained models.
            TreeMap<Double, Integer> outlierScoreIndexTree = new TreeMap<>();
            for (Pair<Double, Integer> p : outlierScoreIndexList) {
                outlierScoreIndexTree.put(p.getA(), p.getB());
            }
            // long arithmetic: size * numExemplars overflows int on large datasets.
            int numExemplarsThisCluster = (int) ((long) outlierScoreIndexList.size() * numExemplars / data.length);
            if (numExemplarsThisCluster == 0) {
                numExemplarsThisCluster = 1;
            } else if (numExemplarsThisCluster > outlierScoreIndexTree.size()) {
                numExemplarsThisCluster = outlierScoreIndexTree.size();
            }
            List<Map.Entry<Double, Integer>> partialClusterExemplars = new ArrayList<>(numExemplarsThisCluster);
            for (int i = 0; i < numExemplarsThisCluster; ++i) {
                partialClusterExemplars.add(outlierScoreIndexTree.pollFirstEntry());
            }
            for (Map.Entry<Double, Integer> partialClusterExemplar : partialClusterExemplars) {
                SGDVector features = data[partialClusterExemplar.getValue()];
                double maxInnerDist = Double.NEGATIVE_INFINITY;
                for (Integer remainingIndex : outlierScoreIndexTree.values()) {
                    double distance = dist.computeDistance(features, data[remainingIndex]);
                    if (distance > maxInnerDist) {
                        maxInnerDist = distance;
                    }
                }
                clusterExemplars.add(new ClusterExemplar(clusterLabel, partialClusterExemplar.getKey(), features, maxInnerDist));
            }
        }
        return clusterExemplars;
    }

    /**
     * Returns the largest outlier score among noise points. If there are no noise points, returns
     * {@link #MAX_OUTLIER_SCORE}. NaN scores are ignored.
     */
    private static double getNoisePointsOutlierScore(Map<Integer, List<Pair<Double, Integer>>> clusterAssignments) {
        List<Pair<Double, Integer>> outlierScoreIndexList = clusterAssignments.get(OUTLIER_NOISE_CLUSTER_LABEL);
        if (outlierScoreIndexList == null || outlierScoreIndexList.isEmpty()) {
            return MAX_OUTLIER_SCORE;
        }
        double upperOutlierScoreBound = Double.NEGATIVE_INFINITY;
        for (Pair<Double, Integer> outlierScoreIndex : outlierScoreIndexList) {
            double score = outlierScoreIndex.getA();
            if (score > upperOutlierScoreBound) {
                upperOutlierScoreBound = score;
            }
        }
        return upperOutlierScoreBound;
    }

    @Override
    public String toString() {
        return "HdbscanTrainer(minClusterSize=" + this.minClusterSize + ",distanceType=" + this.dist + ",k=" + this.k + ",numThreads=" + this.numThreads + ")";
    }

    /**
     * Returns the provenance of this trainer.
     *
     * @return The trainer provenance.
     */
    public TrainerProvenance getProvenance() {
        return new TrainerProvenanceImpl(this);
    }

    /**
     * The legacy distance functions supported by this trainer.
     *
     * @deprecated Use {@link org.tribuo.math.distance.Distance} instead.
     */
    @Deprecated
    public enum Distance {
        EUCLIDEAN(DistanceType.L2),
        COSINE(DistanceType.COSINE),
        L1(DistanceType.L1);

        private final DistanceType distanceType;

        Distance(DistanceType distanceType) {
            this.distanceType = distanceType;
        }

        /**
         * Returns the {@link DistanceType} that this legacy value maps to.
         *
         * @return The distance type.
         */
        public DistanceType getDistanceType() {
            return this.distanceType;
        }
    }

    /**
     * A representative point of a cluster, together with its outlier score and its maximum
     * distance to the cluster's other (non-exemplar) points.
     */
    public static final class ClusterExemplar
    implements Serializable {
        private static final long serialVersionUID = 1L;
        private final Integer label;
        private final Double outlierScore;
        private final SGDVector features;
        /** May be null; treated as negative infinity by the getter, toString and serialize. */
        private final Double maxDistToEdge;

        ClusterExemplar(Integer label, Double outlierScore, SGDVector features, Double maxDistToEdge) {
            this.label = label;
            this.outlierScore = outlierScore;
            this.features = features;
            this.maxDistToEdge = maxDistToEdge;
        }

        /** @return The cluster label. */
        public Integer getLabel() {
            return this.label;
        }

        /** @return The exemplar's outlier score. */
        public Double getOutlierScore() {
            return this.outlierScore;
        }

        /** @return The exemplar's feature vector. */
        public SGDVector getFeatures() {
            return this.features;
        }

        /**
         * Returns the maximum distance from this exemplar to the cluster's other points.
         *
         * @return The distance, or negative infinity if it is unset.
         */
        public Double getMaxDistToEdge() {
            if (this.maxDistToEdge != null) {
                return this.maxDistToEdge;
            }
            return Double.NEGATIVE_INFINITY;
        }

        /**
         * Returns a copy of this exemplar with a copied feature vector.
         *
         * @return The copy.
         */
        public ClusterExemplar copy() {
            return new ClusterExemplar(this.label, this.outlierScore, this.features.copy(), this.maxDistToEdge);
        }

        @Override
        public String toString() {
            double dist = this.maxDistToEdge == null ? Double.NEGATIVE_INFINITY : this.maxDistToEdge;
            return "ClusterExemplar(label=" + this.label + ",outlierScore=" + this.outlierScore + ",vector=" + this.features + ",maxDistToEdge=" + dist + ")";
        }

        @Override
        public boolean equals(Object o) {
            if (this == o) {
                return true;
            }
            if (o == null || this.getClass() != o.getClass()) {
                return false;
            }
            ClusterExemplar that = (ClusterExemplar) o;
            return this.label.equals(that.label) && this.outlierScore.equals(that.outlierScore) && this.features.equals(that.features) && Objects.equals(this.maxDistToEdge, that.maxDistToEdge);
        }

        @Override
        public int hashCode() {
            return Objects.hash(this.label, this.outlierScore, this.features, this.maxDistToEdge);
        }

        /**
         * Serializes this exemplar to protobuf.
         * <p>
         * A null {@code maxDistToEdge} is written as negative infinity; previously it was unboxed
         * directly and threw a {@link NullPointerException}.
         */
        ClusterExemplarProto serialize() {
            ClusterExemplarProto.Builder builder = ClusterExemplarProto.newBuilder();
            builder.setLabel(this.label);
            builder.setOutlierScore(this.outlierScore);
            builder.setFeatures((TensorProto) this.features.serialize());
            builder.setMaxDistToEdge(this.getMaxDistToEdge());
            return builder.build();
        }

        /**
         * Deserializes an exemplar from protobuf.
         *
         * @throws IllegalStateException If the features tensor is not an {@link SGDVector}.
         */
        static ClusterExemplar deserialize(ClusterExemplarProto proto) {
            Tensor tensor = Tensor.deserialize(proto.getFeatures());
            if (!(tensor instanceof SGDVector)) {
                throw new IllegalStateException("Invalid protobuf, features must be an SGDVector, found " + tensor.getClass());
            }
            SGDVector vector = (SGDVector) tensor;
            return new ClusterExemplar(proto.getLabel(), proto.getOutlierScore(), vector, proto.getMaxDistToEdge());
        }
    }
}

