/*
 * This file is part of ELKI:
 * Environment for Developing KDD-Applications Supported by Index-Structures
 *
 * Copyright (C) 2022
 * ELKI Development Team
 *
 * This program is free software: you can redistribute it and/or modify
 * it under the terms of the GNU Affero General Public License as published by
 * the Free Software Foundation, either version 3 of the License, or
 * (at your option) any later version.
 *
 * This program is distributed in the hope that it will be useful,
 * but WITHOUT ANY WARRANTY; without even the implied warranty of
 * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
 * GNU Affero General Public License for more details.
 *
 * You should have received a copy of the GNU Affero General Public License
 * along with this program. If not, see <http://www.gnu.org/licenses/>.
 */
package elki.evaluation.clustering;

import java.util.Iterator;
import java.util.List;

import elki.data.Cluster;
import elki.data.Clustering;
import elki.database.ids.DBIDUtil;
import elki.database.ids.DBIDs;
import elki.math.MeanVariance;
import elki.utilities.datastructures.BitsUtil;

/**
 * Class storing the contingency table and related data on two clusterings.
 * <p>
 * Layout of the {@link #contingency} matrix, with {@code i1 < size1} indexing
 * the clusters of the first clustering and {@code i2 < size2} the clusters of
 * the second clustering:
 * <ul>
 * <li>{@code contingency[i1][i2]}: size of the intersection of both
 * clusters</li>
 * <li>{@code contingency[i1][size2]}, {@code contingency[size1][i2]}: row and
 * column sums of the intersection sizes</li>
 * <li>{@code contingency[size1][size2]}: sum of all intersection sizes</li>
 * <li>{@code contingency[i1][size2 + 1]}, {@code contingency[size1 + 1][i2]}:
 * the size of the respective cluster</li>
 * <li>{@code contingency[size1][size2 + 1]},
 * {@code contingency[size1 + 1][size2]}: sum of the cluster sizes of the first
 * respectively second clustering</li>
 * </ul>
 * 
 * @author Erich Schubert
 * @since 0.5.0
 * 
 * @opt nodefillcolor LemonChiffon
 *
 * @assoc - evaluates - Clustering
 * @composed - - - PairCounting
 * @composed - - - Entropy
 * @composed - - - EditDistance
 * @composed - - - BCubed
 * @composed - - - SetMatchingPurity
 * @composed - - - MaximumMatchingAccuracy
 * @composed - - - PairSetsIndex
 */
public class ClusterContingencyTable {
  /**
   * Noise cluster handling
   */
  protected boolean breakNoiseClusters = false;

  /**
   * Self pairing
   */
  protected boolean selfPairing = true;

  /**
   * Number of clusters.
   */
  protected int size1 = -1, size2 = -1;

  /**
   * Contingency matrix, of dimensions {@code (size1 + 2) x (size2 + 2)}.
   */
  protected int[][] contingency = null;

  /**
   * Noise flags, one bit per cluster of the respective clustering.
   */
  protected long[] noise1 = null, noise2 = null;

  /**
   * Pair counting measures
   */
  protected PairCounting paircount = null;

  /**
   * Entropy-based measures
   */
  protected Entropy entropy = null;

  /**
   * Set matching purity measures
   */
  protected SetMatchingPurity smp = null;

  /**
   * Maximum Matching Accuracy
   */
  protected MaximumMatchingAccuracy mmacc = null;

  /**
   * Pair Sets Index Measures
   */
  protected PairSetsIndex psi = null;

  /**
   * Edit-Distance measures
   */
  protected EditDistance edit = null;

  /**
   * BCubed measures
   */
  protected BCubed bcubed = null;

  /**
   * Constructor.
   * <p>
   * Computes the intersection size of every pair of clusters and fills the
   * marginal rows and columns of the contingency matrix.
   * 
   * @param selfPairing Build self-pairs
   * @param breakNoiseClusters Break noise clusters into individual objects
   * @param result1 First clustering
   * @param result2 Second clustering
   */
  public ClusterContingencyTable(boolean selfPairing, boolean breakNoiseClusters, Clustering<?> result1, Clustering<?> result2) {
    super();
    this.selfPairing = selfPairing;
    this.breakNoiseClusters = breakNoiseClusters;
    // Get the clusters
    final List<? extends Cluster<?>> cs1 = result1.getAllClusters();
    final List<? extends Cluster<?>> cs2 = result2.getAllClusters();

    // Initialize
    size1 = cs1.size();
    size2 = cs2.size();
    // +2: 1 for cluster sizes, 1 for intersection sums
    // these rows are expected to be equal for strict partitionings but may
    // deviate if we have partial, hierarchical, or overlapping clusterings, in
    // which case many measures will no longer work correctly!
    contingency = new int[size1 + 2][size2 + 2];
    noise1 = BitsUtil.zero(size1);
    noise2 = BitsUtil.zero(size2);

    // Second clustering: noise flags and cluster sizes (row size1 + 1)
    {
      final Iterator<? extends Cluster<?>> it2 = cs2.iterator();
      for(int i2 = 0; it2.hasNext(); i2++) {
        final Cluster<?> c2 = it2.next();
        if(c2.isNoise()) {
          BitsUtil.setI(noise2, i2);
        }
        contingency[size1 + 1][i2] = c2.size();
        contingency[size1 + 1][size2] += c2.size();
      }
    }
    // First clustering: noise flags, cluster sizes (column size2 + 1), and
    // the main part of the matrix with its intersection sums.
    final Iterator<? extends Cluster<?>> it1 = cs1.iterator();
    for(int i1 = 0; it1.hasNext(); i1++) {
      final Cluster<?> c1 = it1.next();
      if(c1.isNoise()) {
        BitsUtil.setI(noise1, i1);
      }
      // Converted once per cluster of the first clustering, then reused for
      // all intersections with the second clustering.
      final DBIDs ids = DBIDUtil.ensureSet(c1.getIDs());
      contingency[i1][size2 + 1] = c1.size();
      contingency[size1][size2 + 1] += c1.size();

      final Iterator<? extends Cluster<?>> it2 = cs2.iterator();
      for(int i2 = 0; it2.hasNext(); i2++) {
        final Cluster<?> c2 = it2.next();
        int count = DBIDUtil.intersectionSize(ids, c2.getIDs());
        contingency[i1][i2] = count;
        contingency[i1][size2] += count; // row sum
        contingency[size1][i2] += count; // column sum
        contingency[size1][size2] += count; // total
      }
    }
  }

  /**
   * Check whether the marginal cluster sizes both sum to the total size.
   *
   * @return {@code true} when the clustering is a non-overlapping complete
   *         partitioning of the data set
   */
  public boolean isStrictPartitioning() {
    // Sum of all pairwise intersection sizes
    int expected = contingency[size1][size2];
    // Must agree with the summed cluster sizes of both clusterings
    return contingency[size1][size2 + 1] == expected && contingency[size1 + 1][size2] == expected;
  }

  /**
   * Format the intersection counts including the row and column sums; the
   * cluster size row and column are not printed.
   *
   * @return Textual representation of the contingency table
   */
  @Override
  public String toString() {
    StringBuilder buf = new StringBuilder(size1 * size2 * 10 + 10);
    if(contingency != null) {
      for(int i1 = 0; i1 <= size1; i1++) {
        for(int i2 = 0; i2 <= size2; i2++) {
          buf.append(contingency[i1][i2]).append(i2 < size2 ? " " : "| ");
        }
        buf.append(i1 < size1 ? "\n" : "------\n");
      }
    }
    return buf.toString();
  }

  /**
   * Get (compute) the pair counting measures.
   * <p>
   * Created on first access and cached afterwards.
   * 
   * @return Pair counting measures
   */
  public PairCounting getPaircount() {
    return paircount != null ? paircount : (paircount = new PairCounting(this));
  }

  /**
   * Get (compute) the entropy based measures
   * <p>
   * Created on first access and cached afterwards.
   * 
   * @return Entropy based measures
   */
  public Entropy getEntropy() {
    return entropy != null ? entropy : (entropy = new Entropy(this));
  }

  /**
   * Get (compute) the edit-distance based measures
   * <p>
   * Created on first access and cached afterwards.
   * 
   * @return Edit-distance based measures
   */
  public EditDistance getEdit() {
    return edit != null ? edit : (edit = new EditDistance(this));
  }

  /**
   * The BCubed based measures
   * <p>
   * Created on first access and cached afterwards.
   * 
   * @return BCubed measures
   */
  public BCubed getBCubed() {
    return bcubed != null ? bcubed : (bcubed = new BCubed(this));
  }

  /**
   * The set-matching purity measures
   * <p>
   * Created on first access and cached afterwards.
   * 
   * @return Set-Matching purity measures
   */
  public SetMatchingPurity getSetMatchingPurity() {
    return smp != null ? smp : (smp = new SetMatchingPurity(this));
  }

  /**
   * The Maximum Matching Accuracy
   * <p>
   * Created on first access and cached afterwards.
   * 
   * @return Maximum Matching Accuracy
   */
  public MaximumMatchingAccuracy getMaximumMatchingAccuracy() {
    return mmacc != null ? mmacc : (mmacc = new MaximumMatchingAccuracy(this));
  }

  /**
   * The Pair Sets Index measures
   * <p>
   * Created on first access and cached afterwards.
   * 
   * @return Pair Sets Index measures
   */
  public PairSetsIndex getPairSetsIndex() {
    return psi != null ? psi : (psi = new PairSetsIndex(this));
  }

  /**
   * Compute the average Gini for each cluster (in both clusterings -
   * symmetric).
   * <p>
   * For every cluster with a non-zero intersection sum, the sum of the squared
   * relative intersection sizes with the clusters of the other clustering is
   * added, weighted by this intersection sum.
   * 
   * @return Mean and variance of Gini
   */
  public MeanVariance averageSymmetricGini() {
    MeanVariance mv = new MeanVariance();
    // Clusters of the first clustering (rows)
    for(int i1 = 0; i1 < size1; i1++) {
      double purity = 0.0;
      if(contingency[i1][size2] > 0) {
        final double cs = contingency[i1][size2]; // sum, as double.
        for(int i2 = 0; i2 < size2; i2++) {
          double rel = contingency[i1][i2] / cs;
          purity += rel * rel;
        }
        mv.put(purity, cs);
      }
    }
    // Clusters of the second clustering (columns)
    for(int i2 = 0; i2 < size2; i2++) {
      double purity = 0.0;
      if(contingency[size1][i2] > 0) {
        final double cs = contingency[size1][i2]; // sum, as double.
        for(int i1 = 0; i1 < size1; i1++) {
          double rel = contingency[i1][i2] / cs;
          purity += rel * rel;
        }
        mv.put(purity, cs);
      }
    }
    return mv;
  }

  /**
   * Compute the adjusted average Gini for each cluster (in both clusterings -
   * symmetric).
   * <p>
   * Each value is adjusted as {@code (purity - exp) / (1 - exp)}, where
   * {@code exp} is the sum of the squared relative marginals of the other
   * clustering. If a single cluster of the other clustering accounts for the
   * entire total, then {@code exp} is 1 and the adjusted value is not finite.
   * 
   * @return Mean and variance of Gini
   */
  public MeanVariance adjustedSymmetricGini() {
    MeanVariance mv = new MeanVariance();
    final double total = contingency[size1][size2];

    // The expected value only depends on the column marginals, so it is the
    // same for every row and is computed once up front.
    double expCols = 0.0;
    for(int i2 = 0; i2 < size2; i2++) {
      final double e = contingency[size1][i2] / total;
      expCols += e * e;
    }
    for(int i1 = 0; i1 < size1; i1++) {
      if(contingency[i1][size2] > 0) {
        final double cs = contingency[i1][size2]; // sum, as double.
        double purity = 0.0;
        for(int i2 = 0; i2 < size2; i2++) {
          double rel = contingency[i1][i2] / cs;
          purity += rel * rel;
        }
        mv.put((purity - expCols) / (1 - expCols), cs);
      }
    }

    // Same for the columns, using the row marginals.
    double expRows = 0.0;
    for(int i1 = 0; i1 < size1; i1++) {
      final double e = contingency[i1][size2] / total;
      expRows += e * e;
    }
    for(int i2 = 0; i2 < size2; i2++) {
      if(contingency[size1][i2] > 0) {
        final double cs = contingency[size1][i2]; // sum, as double.
        double purity = 0.0;
        for(int i1 = 0; i1 < size1; i1++) {
          double rel = contingency[i1][i2] / cs;
          purity += rel * rel;
        }
        mv.put((purity - expRows) / (1 - expRows), cs);
      }
    }
    return mv;
  }

  /**
   * Utility class.
   * 
   * @author Erich Schubert
   *
   * @hidden
   */
  public static final class Util {
    /**
     * Private constructor. Static methods only.
     */
    private Util() {
      // Do not use.
    }

    /**
     * F-Measure
     * <p>
     * Returns 0 if both precision and recall are 0, instead of the NaN that
     * the formula would produce.
     * 
     * @param precision Precision
     * @param recall Recall
     * @param beta Beta value
     * @return F-Measure
     */
    public static double fMeasure(double precision, double recall, double beta) {
      if(precision == 0 && recall == 0) {
        return 0.; // avoid 0 / 0
      }
      final double beta2 = beta * beta;
      return (1 + beta2) * precision * recall / (beta2 * precision + recall);
    }

    /**
     * F1-Measure (F-Measure with beta = 1)
     * <p>
     * Returns 0 if both precision and recall are 0, instead of the NaN that
     * the formula would produce.
     * 
     * @param precision Precision
     * @param recall Recall
     * @return F-Measure
     */
    public static double f1Measure(double precision, double recall) {
      if(precision == 0 && recall == 0) {
        return 0.; // avoid 0 / 0
      }
      return 2 * precision * recall / (precision + recall);
    }
  }
}
