001/*
002 * Copyright (c) 2015, 2022, Oracle and/or its affiliates. All rights reserved.
003 *
004 * Licensed under the Apache License, Version 2.0 (the "License");
005 * you may not use this file except in compliance with the License.
006 * You may obtain a copy of the License at
007 *
008 *     http://www.apache.org/licenses/LICENSE-2.0
009 *
010 * Unless required by applicable law or agreed to in writing, software
011 * distributed under the License is distributed on an "AS IS" BASIS,
012 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
013 * See the License for the specific language governing permissions and
014 * limitations under the License.
015 */
016
017package org.tribuo.util.infotheory;
018
019import com.oracle.labs.mlrg.olcut.util.MutableLong;
020import org.tribuo.util.infotheory.impl.CachedPair;
021import org.tribuo.util.infotheory.impl.CachedTriple;
022import org.tribuo.util.infotheory.impl.PairDistribution;
023import org.tribuo.util.infotheory.impl.Row;
024import org.tribuo.util.infotheory.impl.RowList;
025import org.tribuo.util.infotheory.impl.TripleDistribution;
026
027import java.util.HashMap;
028import java.util.List;
029import java.util.Map;
030import java.util.Map.Entry;
031import java.util.Set;
032import java.util.logging.Level;
033import java.util.logging.Logger;
034import java.util.stream.DoubleStream;
035import java.util.stream.Stream;
036
037/**
038 * A class of (discrete) information theoretic functions. Gives warnings if
039 * there are insufficient samples to estimate the quantities accurately.
040 * <p>
041 * Defaults to log_2, so returns values in bits.
042 * <p>
043 * All functions expect that the element types have well defined equals and
044 * hashcode, and that equals is consistent with hashcode. The behaviour is undefined
045 * if this is not true.
046 */
047public final class InformationTheory {
048    private static final Logger logger = Logger.getLogger(InformationTheory.class.getName());
049
050    /**
051     * The ratio of samples to symbols before emitting a warning.
052     */
053    public static final double SAMPLES_RATIO = 5.0;
054    /**
055     * The initial size of the various maps.
056     */
057    public static final int DEFAULT_MAP_SIZE = 20;
058    /**
059     * Log base 2.
060     */
061    public static final double LOG_2 = Math.log(2);
062    /**
063     * Log base e.
064     */
065    public static final double LOG_E = Math.log(Math.E);
066
067    /**
068     * Sets the base of the logarithm used in the information theoretic calculations.
069     * For LOG_2 the unit is "bit", for LOG_E the unit is "nat".
070     */
071    public static double LOG_BASE = LOG_2;
072
073    /**
074     * Private constructor, only has static methods.
075     */
076    private InformationTheory() {}
077
078    /**
079     * Calculates the mutual information between the two sets of random variables.
080     * @param first The first set of random variables.
081     * @param second The second set of random variables.
082     * @param <T1> The first type.
083     * @param <T2> The second type.
084     * @return The mutual information I(first;second).
085     */
086    public static <T1,T2> double mi(Set<List<T1>> first, Set<List<T2>> second) {
087        List<Row<T1>> firstList = new RowList<>(first);
088        List<Row<T2>> secondList = new RowList<>(second);
089
090        return mi(firstList,secondList);
091    }
092
093    /**
094     * Calculates the conditional mutual information between first and second conditioned on the set.
095     * @param first A sample from the first random variable.
096     * @param second A sample from the second random variable.
097     * @param condition A sample from the conditioning set of random variables.
098     * @param <T1> The first type.
099     * @param <T2> The second type.
100     * @param <T3> The third type.
101     * @return The conditional mutual information I(first;second|condition).
102     */
103    public static <T1,T2,T3> double cmi(List<T1> first, List<T2> second, Set<List<T3>> condition) {
104        if (condition.isEmpty()) {
105            //logger.log(Level.INFO,"Empty conditioning set");
106            return mi(first,second);
107        } else {
108            List<Row<T3>> conditionList = new RowList<>(condition);
109        
110            return conditionalMI(first,second,conditionList);
111        }
112    }
113
114    /**
115     * Calculates the GTest statistics for the input variables conditioned on the set.
116     * @param first A sample from the first random variable.
117     * @param second A sample from the second random variable.
118     * @param condition A sample from the conditioning set of random variables.
119     * @param <T1> The first type.
120     * @param <T2> The second type.
121     * @param <T3> The third type.
122     * @return The GTest statistics.
123     */
124    public static <T1,T2,T3> GTestStatistics gTest(List<T1> first, List<T2> second, Set<List<T3>> condition) {
125        ScoreStateCountTuple tuple;
126        if (condition == null) {
127            //logger.log(Level.INFO,"Null conditioning set");
128            tuple = innerMI(first,second);
129        } else if (condition.isEmpty()) {
130            //logger.log(Level.INFO,"Empty conditioning set");
131            tuple = innerMI(first,second);
132        } else {
133            List<Row<T3>> conditionList = new RowList<>(condition);
134        
135            tuple = innerConditionalMI(first,second,conditionList);
136        }
137        double gMetric = 2 * second.size() * tuple.score;
138        double prob = computeChiSquaredProbability(tuple.stateCount, gMetric);
139        return new GTestStatistics(gMetric,tuple.stateCount,prob);
140    }
141
142    /**
143     * Computes the cumulative probability of the input value under a Chi-Squared distribution
144     * with the specified degrees of Freedom.
145     * @param degreesOfFreedom The degrees of freedom in the distribution.
146     * @param value The observed value.
147     * @return The cumulative probability of the observed value.
148     */
149    private static double computeChiSquaredProbability(int degreesOfFreedom, double value) {
150        if (value <= 0) {
151            return 0.0;
152        } else {
153            int shape = degreesOfFreedom / 2;
154            int scale = 2;
155            return Gamma.regularizedGammaP(shape, value / scale, 1e-14, Integer.MAX_VALUE);
156        }
157    }
158
159    /**
160     * Calculates the discrete Shannon joint mutual information, using
161     * histogram probability estimators. Arrays must be the same length.
162     * @param <T1> Type contained in the first array.
163     * @param <T2> Type contained in the second array.
164     * @param <T3> Type contained in the target array.
165     * @param first An array of values.
166     * @param second Another array of values.
167     * @param target Target array of values.
168     * @return The mutual information I(first,second;joint)
169     */
170    public static <T1,T2,T3> double jointMI(List<T1> first, List<T2> second, List<T3> target) {
171        if ((first.size() == second.size()) && (first.size() == target.size())) {
172            TripleDistribution<T1,T2,T3> tripleRV = TripleDistribution.constructFromLists(first,second,target);
173            return jointMI(tripleRV);
174        } else {
175            throw new IllegalArgumentException("Joint Mutual Information requires three vectors the same length. first.size() = " + first.size() + ", second.size() = " + second.size() + ", target.size() = " + target.size());
176        }
177    }
178
179    /**
180     * Calculates the discrete Shannon joint mutual information, using
181     * histogram probability estimators. Arrays must be the same length.
182     * @param <T1> Type contained in the first array.
183     * @param <T2> Type contained in the second array.
184     * @param <T3> Type contained in the target array.
185     * @param rv The random variable to calculate the joint mi of
186     * @return The mutual information I(first,second;joint)
187     */
188    public static <T1,T2,T3> double jointMI(TripleDistribution<T1,T2,T3> rv) {
189        double vecLength = rv.count;
190        Map<CachedTriple<T1,T2,T3>,MutableLong> jointCount = rv.getJointCount();
191        Map<CachedPair<T1,T2>,MutableLong> abCount = rv.getABCount();
192        Map<T3,MutableLong> cCount = rv.getCCount();
193
194        double jmi = 0.0;
195        for (Entry<CachedTriple<T1,T2,T3>,MutableLong> e : jointCount.entrySet()) {
196            double jointCurCount = e.getValue().doubleValue();
197            double prob = jointCurCount / vecLength;
198            CachedPair<T1,T2> pair = e.getKey().getAB();
199            double abCurCount = abCount.get(pair).doubleValue();
200            double cCurCount = cCount.get(e.getKey().getC()).doubleValue();
201
202            jmi += prob * Math.log((vecLength*jointCurCount)/(abCurCount*cCurCount));
203        }
204        jmi /= LOG_BASE;
205        
206        double stateRatio = vecLength / jointCount.size();
207        if (stateRatio < SAMPLES_RATIO) {
208            logger.log(Level.INFO, "Joint MI estimate of {0} had samples/state ratio of {1}, with {2} observations and {3} states", new Object[]{jmi, stateRatio, vecLength, jointCount.size()});
209        }
210
211        return jmi;
212    }
213
214    /**
215     * Calculates the conditional mutual information. If flipped == true, then calculates I(T1;T3|T2), otherwise calculates I(T1;T2|T3).
216     * @param <T1> The type of the first argument.
217     * @param <T2> The type of the second argument.
218     * @param <T3> The type of the third argument.
219     * @param rv The random variable.
220     * @param flipped If true then the second element is the conditional variable, otherwise the third element is.
221     * @return A ScoreStateCountTuple containing the conditional mutual information and the number of states in the joint random variable.
222     */
223    private static <T1,T2,T3> ScoreStateCountTuple innerConditionalMI(TripleDistribution<T1,T2,T3> rv, boolean flipped) {
224        Map<CachedTriple<T1,T2,T3>,MutableLong> jointCount = rv.getJointCount();
225        Map<CachedPair<T1,T2>,MutableLong> abCount = rv.getABCount();
226        Map<CachedPair<T1,T3>,MutableLong> acCount = rv.getACCount();
227        Map<CachedPair<T2,T3>,MutableLong> bcCount = rv.getBCCount();
228        Map<T2,MutableLong> bCount = rv.getBCount();
229        Map<T3,MutableLong> cCount = rv.getCCount();
230
231        double vectorLength = rv.count;
232        double cmi = 0.0;
233        if (flipped) {
234            for (Entry<CachedTriple<T1,T2,T3>, MutableLong> e : jointCount.entrySet()) {
235                double jointCurCount = e.getValue().doubleValue();
236                double prob = jointCurCount / vectorLength;
237                CachedPair<T1,T2> abPair = e.getKey().getAB();
238                CachedPair<T2,T3> bcPair = e.getKey().getBC();
239                double abCurCount = abCount.get(abPair).doubleValue();
240                double bcCurCount = bcCount.get(bcPair).doubleValue();
241                double bCurCount = bCount.get(e.getKey().getB()).doubleValue();
242
243                cmi += prob * Math.log((bCurCount * jointCurCount) / (abCurCount * bcCurCount));
244            }
245        } else {
246            for (Entry<CachedTriple<T1, T2, T3>, MutableLong> e : jointCount.entrySet()) {
247                double jointCurCount = e.getValue().doubleValue();
248                double prob = jointCurCount / vectorLength;
249                CachedPair<T1, T3> acPair = e.getKey().getAC();
250                CachedPair<T2, T3> bcPair = e.getKey().getBC();
251                double acCurCount = acCount.get(acPair).doubleValue();
252                double bcCurCount = bcCount.get(bcPair).doubleValue();
253                double cCurCount = cCount.get(e.getKey().getC()).doubleValue();
254
255                cmi += prob * Math.log((cCurCount * jointCurCount) / (acCurCount * bcCurCount));
256            }
257        }
258        cmi /= LOG_BASE;
259
260        double stateRatio = vectorLength / jointCount.size();
261        if (stateRatio < SAMPLES_RATIO) {
262            logger.log(Level.INFO, "Conditional MI estimate of {0} had samples/state ratio of {1}", new Object[]{cmi, stateRatio});
263        }
264        
265        return new ScoreStateCountTuple(cmi,jointCount.size());
266    }
267
268    /**
269     * Calculates the conditional mutual information, I(T1;T2|T3).
270     * @param <T1> The type of the first argument.
271     * @param <T2> The type of the second argument.
272     * @param <T3> The type of the third argument.
273     * @param first The first random variable.
274     * @param second The second random variable.
275     * @param condition The conditioning random variable.
276     * @return A ScoreStateCountTuple containing the conditional mutual information and the number of states in the joint random variable.
277     */
278    private static <T1,T2,T3> ScoreStateCountTuple innerConditionalMI(List<T1> first, List<T2> second, List<T3> condition) {
279        if ((first.size() == second.size()) && (first.size() == condition.size())) {
280            TripleDistribution<T1,T2,T3> tripleRV = TripleDistribution.constructFromLists(first,second,condition);
281
282            return innerConditionalMI(tripleRV,false);
283        } else {
284            throw new IllegalArgumentException("Conditional Mutual Information requires three vectors the same length. first.size() = " + first.size() + ", second.size() = " + second.size() + ", condition.size() = " + condition.size());
285        }
286    }
287    
288    /**
289     * Calculates the discrete Shannon conditional mutual information, using
290     * histogram probability estimators. Arrays must be the same length.
291     * @param <T1> Type contained in the first array.
292     * @param <T2> Type contained in the second array.
293     * @param <T3> Type contained in the condition array.
294     * @param first An array of values.
295     * @param second Another array of values.
296     * @param condition Array to condition upon.
297     * @return The conditional mutual information I(first;second|condition)
298     */
299    public static <T1,T2,T3> double conditionalMI(List<T1> first, List<T2> second, List<T3> condition) {
300        return innerConditionalMI(first,second,condition).score;
301    }
302
303    /**
304     * Calculates the discrete Shannon conditional mutual information, using
305     * histogram probability estimators. Note this calculates I(T1;T2|T3).
306     * @param <T1> Type of the first variable.
307     * @param <T2> Type of the second variable.
308     * @param <T3> Type of the condition variable.
309     * @param rv The triple random variable of the three inputs.
310     * @return The conditional mutual information I(first;second|condition)
311     */
312    public static <T1,T2,T3> double conditionalMI(TripleDistribution<T1,T2,T3> rv) {
313        return innerConditionalMI(rv,false).score;
314    }
315
316    /**
317     * Calculates the discrete Shannon conditional mutual information, using
318     * histogram probability estimators. Note this calculates I(T1;T3|T2).
319     * @param <T1> Type of the first variable.
320     * @param <T2> Type of the condition variable.
321     * @param <T3> Type of the second variable.
322     * @param rv The triple random variable of the three inputs.
323     * @return The conditional mutual information I(first;second|condition)
324     */
325    public static <T1,T2,T3> double conditionalMIFlipped(TripleDistribution<T1,T2,T3> rv) {
326        return innerConditionalMI(rv,true).score;
327    }
328
329    /**
330     * Calculates the mutual information from a joint random variable.
331     * @param pairDist The joint distribution.
332     * @param <T1> The first type.
333     * @param <T2> The second type.
334     * @return A ScoreStateCountTuple containing the mutual information and the number of states in the joint variable.
335     */
336    private static <T1,T2> ScoreStateCountTuple innerMI(PairDistribution<T1,T2> pairDist) {
337        Map<CachedPair<T1,T2>,MutableLong> countDist = pairDist.jointCounts;
338        Map<T1,MutableLong> firstCountDist = pairDist.firstCount;
339        Map<T2,MutableLong> secondCountDist = pairDist.secondCount;
340
341        double vectorLength = pairDist.count;
342        double mi = 0.0;
343        boolean error = false;
344        for (Entry<CachedPair<T1,T2>,MutableLong> e : countDist.entrySet()) {
345            double jointCount = e.getValue().doubleValue();
346            double prob = jointCount / vectorLength;
347            double firstProb = firstCountDist.get(e.getKey().getA()).doubleValue();
348            double secondProb = secondCountDist.get(e.getKey().getB()).doubleValue();
349
350            double top = vectorLength * jointCount;
351            double bottom = firstProb * secondProb;
352            double ratio = top/bottom;
353            double logRatio = Math.log(ratio);
354
355            if (Double.isNaN(logRatio) || Double.isNaN(prob) || Double.isNaN(mi)) {
356                logger.log(Level.WARNING, "State = " + e.getKey().toString());
357                logger.log(Level.WARNING, "mi = " + mi + " prob = " + prob + " top = " + top + " bottom = " + bottom + " ratio = " + ratio + " logRatio = " + logRatio);
358                error = true;
359            }
360            mi += prob * logRatio;
361            //mi += prob * Math.log((vectorLength*jointCount)/(firstProb*secondProb));
362        }
363        mi /= LOG_BASE;
364
365        double stateRatio = vectorLength / countDist.size();
366        if (stateRatio < SAMPLES_RATIO) {
367            logger.log(Level.INFO, "MI estimate of {0} had samples/state ratio of {1}", new Object[]{mi, stateRatio});
368        }
369        
370        if (error) {
371            logger.log(Level.SEVERE, "NanFound ", new IllegalStateException("NaN found"));
372        }
373        
374        return new ScoreStateCountTuple(mi,countDist.size());
375    }
376
377    /**
378     * Calculates the mutual information between the two lists.
379     * @param first The first list.
380     * @param second The second list.
381     * @param <T1> The first type.
382     * @param <T2> The second type.
383     * @return A ScoreStateCountTuple containing the mutual information and the number of states in the joint variable.
384     */
385    private static <T1,T2> ScoreStateCountTuple innerMI(List<T1> first, List<T2> second) {
386        if (first.size() == second.size()) {
387            PairDistribution<T1,T2> pairDist = PairDistribution.constructFromLists(first, second);
388            
389            return innerMI(pairDist);
390        } else {
391            throw new IllegalArgumentException("Mutual Information requires two vectors the same length. first.size() = " + first.size() + ", second.size() = " + second.size());
392        }
393    }
394    
395    /**
396     * Calculates the discrete Shannon mutual information, using histogram 
397     * probability estimators. Arrays must be the same length.
398     * @param <T1> Type of the first array
399     * @param <T2> Type of the second array
400     * @param first An array of values
401     * @param second Another array of values
402     * @return The mutual information I(first;second)
403     */
404    public static <T1,T2> double mi(List<T1> first, List<T2> second) {
405        return innerMI(first,second).score;
406    }
407
408    /**
409     * Calculates the discrete Shannon mutual information, using histogram 
410     * probability estimators.
411     * @param <T1> Type of the first variable
412     * @param <T2> Type of the second variable
413     * @param pairDist PairDistribution for the two variables.
414     * @return The mutual information I(first;second)
415     */
416    public static <T1,T2> double mi(PairDistribution<T1,T2> pairDist) {
417        return innerMI(pairDist).score;
418    }
419
420    /**
421     * Calculates the Shannon joint entropy of two arrays, using histogram 
422     * probability estimators. Arrays must be same length.
423     * @param <T1> Type of the first array.
424     * @param <T2> Type of the second array.
425     * @param first An array of values.
426     * @param second Another array of values.
427     * @return The entropy H(first,second)
428     */
429    public static <T1,T2> double jointEntropy(List<T1> first, List<T2> second) {
430        if (first.size() == second.size()) {
431            double vectorLength = first.size();
432            double jointEntropy = 0.0;
433            
434            PairDistribution<T1,T2> countPair = PairDistribution.constructFromLists(first,second); 
435            Map<CachedPair<T1,T2>,MutableLong> countDist = countPair.jointCounts;
436
437            for (Entry<CachedPair<T1,T2>,MutableLong> e : countDist.entrySet()) {
438                double prob = e.getValue().doubleValue() / vectorLength;
439
440                jointEntropy -= prob * Math.log(prob);
441            }
442            jointEntropy /= LOG_BASE;
443
444            double stateRatio = vectorLength / countDist.size();
445            if (stateRatio < SAMPLES_RATIO) {
446                logger.log(Level.INFO, "Joint Entropy estimate of {0} had samples/state ratio of {1}", new Object[]{jointEntropy, stateRatio});
447            }
448            
449            return jointEntropy;
450        } else {
451            throw new IllegalArgumentException("Joint Entropy requires two vectors the same length. first.size() = " + first.size() + ", second.size() = " + second.size());
452        }
453    }
454    
455    /**
456     * Calculates the discrete Shannon conditional entropy of two arrays, using
457     * histogram probability estimators. Arrays must be the same length.
458     * @param <T1> Type of the first array.
459     * @param <T2> Type of the second array.
460     * @param vector The main array of values.
461     * @param condition The array to condition on.
462     * @return The conditional entropy H(vector|condition).
463     */
464    public static <T1,T2> double conditionalEntropy(List<T1> vector, List<T2> condition) {
465        if (vector.size() == condition.size()) {
466            double vectorLength = vector.size();
467            double condEntropy = 0.0;
468            
469            PairDistribution<T1,T2> countPair = PairDistribution.constructFromLists(vector,condition); 
470            Map<CachedPair<T1,T2>,MutableLong> countDist = countPair.jointCounts;
471            Map<T2,MutableLong> conditionCountDist = countPair.secondCount;
472
473            for (Entry<CachedPair<T1,T2>,MutableLong> e : countDist.entrySet()) {
474                double prob = e.getValue().doubleValue() / vectorLength;
475                double condProb = conditionCountDist.get(e.getKey().getB()).doubleValue() / vectorLength;
476
477                condEntropy -= prob * Math.log(prob/condProb);
478            }
479            condEntropy /= LOG_BASE;
480
481            double stateRatio = vectorLength / countDist.size();
482            if (stateRatio < SAMPLES_RATIO) {
483                logger.log(Level.INFO, "Conditional Entropy estimate of {0} had samples/state ratio of {1}", new Object[]{condEntropy, stateRatio});
484            }
485            
486            return condEntropy;
487        } else {
488            throw new IllegalArgumentException("Conditional Entropy requires two vectors the same length. vector.size() = " + vector.size() + ", condition.size() = " + condition.size());
489        }
490    }
491
492    /**
493     * Calculates the discrete Shannon entropy, using histogram probability 
494     * estimators.
495     * @param <T> Type of the array.
496     * @param vector The array of values.
497     * @return The entropy H(vector).
498     */
499    public static <T> double entropy(List<T> vector) {
500        double vectorLength = vector.size();
501        double entropy = 0.0;
502
503        Map<T,Long> countDist = calculateCountDist(vector);
504        for (Entry<T,Long> e : countDist.entrySet()) {
505            double prob = e.getValue() / vectorLength;
506            entropy -= prob * Math.log(prob);
507        }
508        entropy /= LOG_BASE;
509
510        double stateRatio = vectorLength / countDist.size();
511        if (stateRatio < SAMPLES_RATIO) {
512            logger.log(Level.INFO, "Entropy estimate of {0} had samples/state ratio of {1}", new Object[]{entropy, stateRatio});
513        }
514        
515        return entropy;
516    }
517
518    /**
519     * Generate the counts for a single vector.
520     * @param <T> The type inside the vector.
521     * @param vector An array of values.
522     * @return A HashMap from states of T to counts.
523     */
524    public static <T> Map<T,Long> calculateCountDist(List<T> vector) {
525        HashMap<T,Long> countDist = new HashMap<>(DEFAULT_MAP_SIZE);
526        for (T e : vector) {
527            Long curCount = countDist.getOrDefault(e,0L);
528            curCount += 1;
529            countDist.put(e, curCount);
530        }
531
532        return countDist;
533    }
534
535    /**
536     * Calculates the discrete Shannon entropy of a stream, assuming each element of the stream is
537     * an element of the same probability distribution.
538     * @param vector The probability distribution.
539     * @return The entropy.
540     */
541    public static double calculateEntropy(Stream<Double> vector) {
542        return vector.map((p) -> (- p * Math.log(p) / LOG_BASE)).reduce(0.0, Double::sum);
543    }
544
545    /**
546     * Calculates the discrete Shannon entropy of a stream, assuming each element of the stream is
547     * an element of the same probability distribution.
548     * @param vector The probability distribution.
549     * @return The entropy.
550     */
551    public static double calculateEntropy(DoubleStream vector) {
552        return vector.map((p) -> (- p * Math.log(p) / LOG_BASE)).sum();
553    }
554
555    /**
556     * Compute the expected mutual information assuming randomized inputs.
557     *
558     * @param first The first vector.
559     * @param second The second vector.
560     * @param <T> The type inside the list. Must define equals and hashcode.
561     * @return The expected mutual information under a hypergeometric distribution.
562     */
563    public static <T> double expectedMI(List<T> first, List<T> second) {
564        PairDistribution<T,T> pd = PairDistribution.constructFromLists(first,second);
565
566        Map<T, MutableLong> firstCount = pd.firstCount;
567        Map<T,MutableLong> secondCount = pd.secondCount;
568        long count = pd.count;
569
570        double output = 0.0;
571
572        for (Entry<T,MutableLong> f : firstCount.entrySet()) {
573            for (Entry<T,MutableLong> s : secondCount.entrySet()) {
574                long fVal = f.getValue().longValue();
575                long sVal = s.getValue().longValue();
576                long minCount = Math.min(fVal, sVal);
577
578                long threshold = fVal + sVal - count;
579                long start = threshold > 1 ? threshold : 1;
580
581                for (long nij = start; nij <= minCount; nij++) {
582                    double acc = ((double) nij) / count;
583                    acc *= Math.log(((double) (count * nij)) / (fVal * sVal));
584                    //numerator
585                    double logSpace = Gamma.logGamma(fVal + 1);
586                    logSpace += Gamma.logGamma(sVal + 1);
587                    logSpace += Gamma.logGamma(count - fVal + 1);
588                    logSpace += Gamma.logGamma(count - sVal + 1);
589                    //denominator
590                    logSpace -= Gamma.logGamma(count + 1);
591                    logSpace -= Gamma.logGamma(nij + 1);
592                    logSpace -= Gamma.logGamma(fVal - nij + 1);
593                    logSpace -= Gamma.logGamma(sVal - nij + 1);
594                    logSpace -= Gamma.logGamma(count - fVal - sVal + nij + 1);
595                    acc *= Math.exp(logSpace);
596                    output += acc;
597                }
598            }
599        }
600        return output;
601    }
602
603    /**
604     * A tuple of the information theoretic value, along with the number of
605     * states in the random variable. Will be a record one day.
606     */
607    private static class ScoreStateCountTuple {
608        public final double score;
609        public final int stateCount;
610
611        /**
612         * Construct a score state tuple
613         * @param score The score.
614         * @param stateCount The number of states.
615         */
616        ScoreStateCountTuple(double score, int stateCount) {
617            this.score = score;
618            this.stateCount = stateCount;
619        }
620
621        @Override
622        public String toString() {
623            return "ScoreStateCount(score=" + score + ",stateCount=" + stateCount + ")";
624        }
625    }
626
627    /**
628     * An immutable named tuple containing the statistics from a G test.
629     * <p>
630     * Will be a record one day.
631     */
632    public static final class GTestStatistics {
633        /**
634         * The G test statistic.
635         */
636        public final double gStatistic;
637        /**
638         * The number of states.
639         */
640        public final int numStates;
641        /**
642         * The probability of that statistic.
643         */
644        public final double probability;
645
646        /**
647         * Constructs a GTestStatistics tuple with the supplied values.
648         * @param gStatistic The g test statistic.
649         * @param numStates The number of states.
650         * @param probability The probability of that statistic.
651         */
652        // TODO should be package private.
653        public GTestStatistics(double gStatistic, int numStates, double probability) {
654            this.gStatistic = gStatistic;
655            this.numStates = numStates;
656            this.probability = probability;
657        }
658
659        @Override
660        public String toString() {
661            return "GTest(statistic="+gStatistic+",probability="+probability+",numStates="+numStates+")";
662        }
663    }
664}
665