/*
 * This file is part of ELKI:
 * Environment for Developing KDD-Applications Supported by Index-Structures
 *
 * Copyright (C) 2022
 * ELKI Development Team
 *
 * This program is free software: you can redistribute it and/or modify
 * it under the terms of the GNU Affero General Public License as published by
 * the Free Software Foundation, either version 3 of the License, or
 * (at your option) any later version.
 *
 * This program is distributed in the hope that it will be useful,
 * but WITHOUT ANY WARRANTY; without even the implied warranty of
 * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
 * GNU Affero General Public License for more details.
 *
 * You should have received a copy of the GNU Affero General Public License
 * along with this program. If not, see <http://www.gnu.org/licenses/>.
 */
package elki.clustering.kmedoids;

import java.util.Random;

import elki.Algorithm;
import elki.clustering.ClusteringAlgorithmUtil;
import elki.clustering.kmeans.KMeans;
import elki.data.Cluster;
import elki.data.Clustering;
import elki.data.model.MedoidModel;
import elki.data.type.TypeInformation;
import elki.data.type.TypeUtil;
import elki.database.datastore.DataStoreFactory;
import elki.database.datastore.DataStoreUtil;
import elki.database.datastore.WritableDoubleDataStore;
import elki.database.datastore.WritableIntegerDataStore;
import elki.database.ids.*;
import elki.database.query.QueryBuilder;
import elki.database.query.distance.DistanceQuery;
import elki.database.relation.Relation;
import elki.distance.Distance;
import elki.distance.minkowski.EuclideanDistance;
import elki.logging.Logging;
import elki.logging.progress.FiniteProgress;
import elki.logging.statistics.DoubleStatistic;
import elki.result.Metadata;
import elki.utilities.documentation.Reference;
import elki.utilities.exceptions.AbortException;
import elki.utilities.optionhandling.OptionID;
import elki.utilities.optionhandling.Parameterizer;
import elki.utilities.optionhandling.constraints.CommonConstraints;
import elki.utilities.optionhandling.parameterization.Parameterization;
import elki.utilities.optionhandling.parameters.DoubleParameter;
import elki.utilities.optionhandling.parameters.IntParameter;
import elki.utilities.optionhandling.parameters.ObjectParameter;
import elki.utilities.optionhandling.parameters.RandomParameter;
import elki.utilities.random.RandomFactory;

/**
 * CLARANS: a method for clustering objects for spatial data mining
 * is inspired by PAM (partitioning around medoids, {@link PAM})
 * and CLARA and also based on sampling.
 * <p>
 * This implementation tries to balance memory and computation time.
 * By caching the distances to the two nearest medoids, we usually only need
 * O(n) instead of O(nk) distance computations for one iteration, at
 * the cost of needing O(2n) memory to store them.
 * <p>
 * The implementation is fairly ugly, because we have three solutions (the best
 * found so far, the current solution, and a neighbor candidate); and for each
 * point in each solution we need the best and second best assignments. Once
 * Java supports value types (Project Valhalla), this code could be cleaned up
 * significantly, without the overhead of O(n) objects.
 * <p>
 * Reference:
 * <p>
 * R. T. Ng, J. Han<br>
 * CLARANS: a method for clustering objects for spatial data mining<br>
 * IEEE Transactions on Knowledge and Data Engineering 14(5)
 *
 * @author Erich Schubert
 * @since 0.7.5
 *
 * @navassoc - - - elki.data.model.MedoidModel
 * @has - - - Assignment
 *
 * @param <O> Input data type
 */
@Reference(authors = "R. T. Ng, J. Han", //
    title = "CLARANS: a method for clustering objects for spatial data mining", //
    booktitle = "IEEE Transactions on Knowledge and Data Engineering 14(5)", //
    url = "https://doi.org/10.1109/TKDE.2002.1033770", //
    bibkey = "DBLP:journals/tkde/NgH02")
public class CLARANS<O> implements KMedoidsClustering<O> {
  /**
   * Class logger.
   */
  private static final Logging LOG = Logging.getLogger(CLARANS.class);

  /**
   * Distance function used.
   */
  protected Distance<? super O> distance;

  /**
   * Number of clusters to find.
   */
  protected int k;

  /**
   * Number of samples to draw (i.e. restarts).
   */
  protected int numlocal;

  /**
   * Sampling rate. If less than 1, it is considered to be a relative value.
   */
  protected double maxneighbor;

  /**
   * Random factory for initialization.
   */
  protected RandomFactory random;

  /**
   * Constructor.
   *
   * @param distance Distance function to use
   * @param k Number of clusters to produce
   * @param numlocal Number of samples (restarts)
   * @param maxneighbor Neighbor sampling rate (absolute or relative)
   * @param random Random generator
   */
  public CLARANS(Distance<? super O> distance, int k, int numlocal, double maxneighbor, RandomFactory random) {
    super();
    this.distance = distance;
    this.k = k;
    this.numlocal = numlocal;
    this.maxneighbor = maxneighbor;
    this.random = random;
  }

  /**
   * Run CLARANS clustering.
   *
   * @param relation Data relation
   * @return Clustering
   */
  @Override
  public Clustering<MedoidModel> run(Relation<O> relation) {
    return run(relation, k, new QueryBuilder<>(relation, distance).distanceQuery());
  }

  @Override
  public Clustering<MedoidModel> run(Relation<O> relation, int k, DistanceQuery<? super O> distQ) {
    if(relation.size() <= 0) {
      Clustering<MedoidModel> empty = new Clustering<>();
      Metadata.of(empty).setLongName("CLARANS Clustering");
      return empty;
    }
    if(k * 2 >= relation.size()) {
      // Random sampling of non-medoids will be slow for huge k
      LOG.warning("A very large k was chosen. This implementation is not optimized for this case.");
    }
    DBIDs ids = relation.getDBIDs();
    // Use the distance of the query that is actually used, which may differ
    // from the configured default when this method is invoked directly.
    final boolean metric = distQ.getDistance().isMetric();

    // Number of neighbors to examine (maxneighbor in the paper): either a rate
    // relative to the k*(n-k) neighbors of each node, or an absolute count.
    final int retries = (int) Math.ceil(maxneighbor < 1 ? maxneighbor * k * (ids.size() - k) : maxneighbor);
    Random rnd = random.getSingleThreadedRandom();
    DBIDArrayIter cand = DBIDUtil.ensureArray(ids).iter(); // Might copy!

    // Setup cluster assignment store
    Assignment best = new Assignment(distQ, ids, k);
    Assignment curr = new Assignment(distQ, ids, k);
    Assignment scratch = new Assignment(distQ, ids, k);

    // 1. initialize
    double bestscore = Double.POSITIVE_INFINITY;
    FiniteProgress prog = LOG.isVerbose() ? new FiniteProgress("CLARANS sampling restarts", numlocal, LOG) : null;
    for(int i = 0; i < numlocal; i++) {
      // 2. choose random initial medoids
      curr.medoids.clear().addDBIDs(DBIDUtil.randomSample(ids, k, rnd));
      // Cost of initial solution:
      double total = curr.assignToNearestCluster();

      // 3. Set j to 1.
      int j = 1;
      // 6. Keep going while j <= maxneighbor, so that up to maxneighbor
      // consecutive non-improving neighbors are examined (as in the paper).
      step: while(j <= retries) {
        // 4 part a. choose a random non-medoid (~ neighbor in G):
        for(int r = 0;; r++) {
          cand.seek(rnd.nextInt(ids.size())); // Random point
          if(curr.nearest.doubleValue(cand) > 0) {
            break; // Good: not a medoid.
          }
          // We may have many duplicate points
          if(metric && curr.second.doubleValue(cand) == 0) {
            ++j; // Cannot yield an improvement if we are metric.
            continue step;
          }
          else if(!metric && !curr.medoids.contains(cand)) {
            // Probably not a good candidate, but try nevertheless
            break;
          }
          if(r >= 1000) {
            throw new AbortException("Failed to choose a non-medoid in 1000 attempts. Choose k << N.");
          }
          // else: this must be the medoid.
        }
        // 4 part b. choose a random medoid to replace:
        final int otherm = rnd.nextInt(k);
        // 5. check lower cost
        double cost = curr.computeCostDifferential(cand, otherm, scratch);
        if(!(cost < -1e-12 * total)) {
          ++j; // 6. try again
          continue;
        }
        total += cost; // cost is negative!
        // Swap:
        Assignment tmp = curr;
        curr = scratch;
        scratch = tmp;
        j = 1;
      }
      if(LOG.isStatistics()) {
        LOG.statistics(new DoubleStatistic(getClass().getName() + ".sample-" + i + ".cost", total));
      }
      // 7. New best:
      if(total < bestscore) {
        // Swap:
        Assignment tmp = curr;
        curr = best;
        best = tmp;
        bestscore = total;
      }
      LOG.incrementProcessed(prog);
    }
    LOG.ensureCompleted(prog);
    if(LOG.isStatistics()) {
      LOG.statistics(new DoubleStatistic(getClass().getName() + ".final-cost", bestscore));
    }

    ArrayModifiableDBIDs[] clusters = ClusteringAlgorithmUtil.partitionsFromIntegerLabels(ids, best.assignment, k);

    // Wrap result
    Clustering<MedoidModel> result = new Clustering<>();
    for(DBIDArrayIter it = best.medoids.iter(); it.valid(); it.advance()) {
      result.addToplevelCluster(new Cluster<>(clusters[it.getOffset()], new MedoidModel(DBIDUtil.deref(it))));
    }
    Metadata.of(result).setLongName("CLARANS Clustering");
    return result;
  }

  /**
   * Assignment state.
   * <p>
   * For every point, this stores the index of the nearest and the second
   * nearest medoid together with the corresponding distances. The second
   * nearest medoid index always differs from the nearest medoid index (or is
   * -1 if there is no second medoid).
   * 
   * @author Erich Schubert
   */
  protected static class Assignment {
    /**
     * Ids to process.
     */
    DBIDs ids;

    /**
     * Distance function to use.
     */
    DistanceQuery<?> distQ;

    /**
     * Distance to the nearest medoid of each point.
     */
    WritableDoubleDataStore nearest;

    /**
     * Distance to the second nearest medoid.
     */
    WritableDoubleDataStore second;

    /**
     * Cluster mapping.
     */
    WritableIntegerDataStore assignment;

    /**
     * Medoid id of the second closest. Needs some more memory, but saves
     * recomputations in the common case where not much changed.
     */
    WritableIntegerDataStore secondid;

    /**
     * Medoids
     */
    ArrayModifiableDBIDs medoids;

    /**
     * Medoid iterator
     */
    DBIDArrayMIter miter;

    /**
     * Constructor.
     *
     * @param distQ Distance query
     * @param ids IDs to process
     * @param k Number of medoids
     */
    public Assignment(DistanceQuery<?> distQ, DBIDs ids, int k) {
      this.distQ = distQ;
      this.ids = ids;
      this.medoids = DBIDUtil.newArray(k);
      this.miter = medoids.iter();
      this.assignment = DataStoreUtil.makeIntegerStorage(ids, DataStoreFactory.HINT_HOT | DataStoreFactory.HINT_TEMP, -1);
      this.nearest = DataStoreUtil.makeDoubleStorage(ids, DataStoreFactory.HINT_HOT | DataStoreFactory.HINT_TEMP);
      this.secondid = DataStoreUtil.makeIntegerStorage(ids, DataStoreFactory.HINT_HOT | DataStoreFactory.HINT_TEMP, -1);
      this.second = DataStoreUtil.makeDoubleStorage(ids, DataStoreFactory.HINT_HOT | DataStoreFactory.HINT_TEMP);
    }

    /**
     * Compute the reassignment cost for one swap, and fill the scratch
     * assignment with the resulting state.
     *
     * @param h Current object to swap with any medoid.
     * @param mnum Medoid number to swap with h.
     * @param scratch Scratch assignment to fill.
     * @return Cost change (change of total deviation)
     */
    protected double computeCostDifferential(DBIDRef h, int mnum, Assignment scratch) {
      // Update medoids of scratch copy.
      scratch.medoids.clear().addDBIDs(medoids);
      scratch.medoids.set(mnum, h);
      double cost = 0;
      // Compute costs of reassigning other objects j:
      for(DBIDIter j = ids.iter(); j.valid(); j.advance()) {
        if(DBIDUtil.equal(h, j)) {
          // h becomes a medoid: its contribution drops from nearest(h) to 0.
          cost -= nearest.doubleValue(j);
          scratch.recompute(j, mnum, 0., -1, Double.POSITIVE_INFINITY);
          continue;
        }
        // distance(j, i) to nearest medoid
        final double distcur = nearest.doubleValue(j);
        // distance(j, h) to new medoid
        final double dist_h = distQ.distance(h, j);
        // current assignment of j
        final int jcur = assignment.intValue(j);
        // Check if current medoid of j is removed:
        if(jcur == mnum) {
          // distance(j, o) to second nearest / possible reassignment
          final double distsec = second.doubleValue(j);
          // index of the second nearest medoid, which is not removed
          final int jsec = secondid.intValue(j);
          // Case 1b: j switches to new medoid, or to the second nearest:
          if(dist_h < distsec) {
            cost += dist_h - distcur;
            scratch.assignment.putInt(j, mnum);
            scratch.nearest.putDouble(j, dist_h);
            // The previous second nearest remains the second nearest.
            scratch.second.putDouble(j, distsec);
            scratch.secondid.putInt(j, jsec);
          }
          else {
            // Second nearest is the new assignment.
            cost += distsec - distcur;
            // We have to recompute, because we do not know the true new second
            // nearest; but both the distance to h and to the old second
            // nearest are already known.
            scratch.recompute(j, mnum, dist_h, jsec, distsec);
          }
        }
        else if(dist_h < distcur) {
          // Case 1c: j is closer to h than its current medoid
          // and the current medoid is not removed (jcur != mnum).
          cost += dist_h - distcur;
          // Second nearest is the previous assignment
          scratch.assignment.putInt(j, mnum);
          scratch.nearest.putDouble(j, dist_h);
          scratch.second.putDouble(j, distcur);
          scratch.secondid.putInt(j, jcur);
        }
        else { // else Case 1a): j is closer to i than h and m, so no change.
          final int jsec = secondid.intValue(j);
          final double distsec = second.doubleValue(j);
          // Second nearest is still valid.
          if(jsec != mnum && distsec <= dist_h) {
            scratch.assignment.putInt(j, jcur);
            scratch.nearest.putDouble(j, distcur);
            scratch.secondid.putInt(j, jsec);
            scratch.second.putDouble(j, distsec);
          }
          else {
            scratch.recompute(j, jcur, distcur, mnum, dist_h);
          }
        }
      }
      return cost;
    }

    /**
     * Recompute the assignment of one point by scanning all medoids, reusing
     * up to two distances that are already known.
     *
     * @param id Point id
     * @param mnum Medoid number for known distance, or -1 if none is known
     * @param known Known distance to medoid {@code mnum}
     * @param snum Second medoid number with a known distance, or -1
     * @param sknown Known distance to medoid {@code snum}
     * @return Distance to the nearest medoid
     */
    protected double recompute(DBIDRef id, int mnum, double known, int snum, double sknown) {
      double mindist = mnum >= 0 ? known : Double.POSITIVE_INFINITY,
          mindist2 = Double.POSITIVE_INFINITY;
      int minIndex = mnum, minIndex2 = -1;
      for(int i = 0; miter.seek(i).valid(); i++) {
        if(i == mnum) {
          continue;
        }
        final double dist = i == snum ? sknown : distQ.distance(id, miter);
        // A medoid is always assigned to itself.
        if(DBIDUtil.equal(id, miter) || dist < mindist) {
          minIndex2 = minIndex;
          mindist2 = mindist;
          minIndex = i;
          mindist = dist;
        }
        else if(dist < mindist2) {
          minIndex2 = i;
          mindist2 = dist;
        }
      }
      if(minIndex < 0) {
        throw new AbortException("Too many infinite distances. Cannot assign objects.");
      }
      assignment.putInt(id, minIndex);
      nearest.putDouble(id, mindist);
      secondid.putInt(id, minIndex2);
      second.putDouble(id, mindist2);
      return mindist;
    }

    /**
     * Assign each point to the nearest medoid.
     *
     * @return Assignment cost
     */
    protected double assignToNearestCluster() {
      double cost = 0.;
      for(DBIDIter iditer = ids.iter(); iditer.valid(); iditer.advance()) {
        cost += recompute(iditer, -1, Double.POSITIVE_INFINITY, -1, Double.POSITIVE_INFINITY);
      }
      return cost;
    }
  }

  @Override
  public TypeInformation[] getInputTypeRestriction() {
    return TypeUtil.array(distance.getInputTypeRestriction());
  }

  /**
   * Parameterization class.
   *
   * @author Erich Schubert
   */
  public static class Par<V> implements Parameterizer {
    /**
     * The number of restarts to run.
     */
    public static final OptionID RESTARTS_ID = new OptionID("clara.numlocal", "Number of samples (restarts) to run.");

    /**
     * The number of neighbors to explore.
     */
    public static final OptionID NEIGHBORS_ID = new OptionID("clara.numneighbor", "Number of tries to find a neighbor.");

    /**
     * Random generator.
     */
    public static final OptionID RANDOM_ID = new OptionID("clarans.random", "Random generator seed.");

    /**
     * Maximum neighbors to explore. If less than 1, it is considered to be a
     * relative value.
     */
    double maxneighbor;

    /**
     * Number of restarts to do.
     */
    int numlocal;

    /**
     * Number of cluster centers to find.
     */
    int k;

    /**
     * Random factory for initialization.
     */
    RandomFactory random;

    /**
     * The distance function to use.
     */
    protected Distance<? super V> distance;

    /**
     * Default sampling rate.
     *
     * @return Default sampling rate.
     */
    protected double defaultRate() {
      return 0.0125;
    }

    @Override
    public void configure(Parameterization config) {
      new ObjectParameter<Distance<? super V>>(Algorithm.Utils.DISTANCE_FUNCTION_ID, Distance.class, EuclideanDistance.class) //
          .grab(config, x -> distance = x);
      new IntParameter(KMeans.K_ID) //
          .addConstraint(CommonConstraints.GREATER_EQUAL_ONE_INT) //
          .grab(config, x -> k = x);
      new IntParameter(RESTARTS_ID, 2) //
          .addConstraint(CommonConstraints.GREATER_EQUAL_ONE_INT) //
          .grab(config, x -> numlocal = x);
      new DoubleParameter(NEIGHBORS_ID, defaultRate()) //
          .addConstraint(CommonConstraints.GREATER_THAN_ZERO_DOUBLE) //
          .grab(config, x -> maxneighbor = x);
      new RandomParameter(RANDOM_ID).grab(config, x -> random = x);
    }

    @Override
    public CLARANS<V> make() {
      return new CLARANS<>(distance, k, numlocal, maxneighbor, random);
    }
  }
}
