/*
 * Copyright (c) 2015, 2022, Oracle and/or its affiliates. All rights reserved.
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 *     http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

package org.tribuo.util.infotheory;

import com.oracle.labs.mlrg.olcut.util.MutableLong;
import org.tribuo.util.infotheory.impl.CachedPair;
import org.tribuo.util.infotheory.impl.CachedTriple;
import org.tribuo.util.infotheory.impl.PairDistribution;
import org.tribuo.util.infotheory.impl.Row;
import org.tribuo.util.infotheory.impl.RowList;
import org.tribuo.util.infotheory.impl.TripleDistribution;

import java.util.HashMap;
import java.util.List;
import java.util.Map;
import java.util.Map.Entry;
import java.util.Set;
import java.util.logging.Level;
import java.util.logging.Logger;
import java.util.stream.DoubleStream;
import java.util.stream.Stream;

/**
 * A class of (discrete) information theoretic functions. Gives warnings if
 * there are insufficient samples to estimate the quantities accurately.
 * <p>
 * Defaults to log_2, so returns values in bits.
 * <p>
 * All functions expect that the element types have well defined equals and
 * hashcode, and that equals is consistent with hashcode. The behaviour is undefined
 * if this is not true.
 */
public final class InformationTheory {
    private static final Logger logger = Logger.getLogger(InformationTheory.class.getName());

    /**
     * The ratio of samples to symbols before emitting a warning.
     */
    public static final double SAMPLES_RATIO = 5.0;
    /**
     * The initial size of the various maps.
     */
    public static final int DEFAULT_MAP_SIZE = 20;
    /**
     * Log base 2.
     */
    public static final double LOG_2 = Math.log(2);
    /**
     * Log base e.
     */
    public static final double LOG_E = Math.log(Math.E);

    /**
     * Sets the base of the logarithm used in the information theoretic calculations.
     * For LOG_2 the unit is "bit", for LOG_E the unit is "nat".
     * <p>
     * NOTE(review): this is a mutable public static field, so reassigning it
     * affects every caller in the JVM and is not thread-safe - confirm whether
     * it can be made final or replaced with per-call configuration.
     */
    public static double LOG_BASE = LOG_2;

    /**
     * Private constructor, only has static methods.
     */
    private InformationTheory() {}

    /**
     * Calculates the mutual information between the two sets of random variables.
     * <p>
     * Wraps each set of lists in a {@link RowList} view and delegates to
     * {@link #mi(List,List)}.
     * @param first The first set of random variables.
     * @param second The second set of random variables.
     * @param <T1> The first type.
     * @param <T2> The second type.
     * @return The mutual information I(first;second).
     */
    public static <T1,T2> double mi(Set<List<T1>> first, Set<List<T2>> second) {
        List<Row<T1>> firstList = new RowList<>(first);
        List<Row<T2>> secondList = new RowList<>(second);

        return mi(firstList,secondList);
    }

    /**
     * Calculates the conditional mutual information between first and second conditioned on the set.
     * @param first A sample from the first random variable.
     * @param second A sample from the second random variable.
     * @param condition A sample from the conditioning set of random variables.
     * @param <T1> The first type.
     * @param <T2> The second type.
     * @param <T3> The third type.
     * @return The conditional mutual information I(first;second|condition).
102 */ 103 public static <T1,T2,T3> double cmi(List<T1> first, List<T2> second, Set<List<T3>> condition) { 104 if (condition.isEmpty()) { 105 //logger.log(Level.INFO,"Empty conditioning set"); 106 return mi(first,second); 107 } else { 108 List<Row<T3>> conditionList = new RowList<>(condition); 109 110 return conditionalMI(first,second,conditionList); 111 } 112 } 113 114 /** 115 * Calculates the GTest statistics for the input variables conditioned on the set. 116 * @param first A sample from the first random variable. 117 * @param second A sample from the second random variable. 118 * @param condition A sample from the conditioning set of random variables. 119 * @param <T1> The first type. 120 * @param <T2> The second type. 121 * @param <T3> The third type. 122 * @return The GTest statistics. 123 */ 124 public static <T1,T2,T3> GTestStatistics gTest(List<T1> first, List<T2> second, Set<List<T3>> condition) { 125 ScoreStateCountTuple tuple; 126 if (condition == null) { 127 //logger.log(Level.INFO,"Null conditioning set"); 128 tuple = innerMI(first,second); 129 } else if (condition.isEmpty()) { 130 //logger.log(Level.INFO,"Empty conditioning set"); 131 tuple = innerMI(first,second); 132 } else { 133 List<Row<T3>> conditionList = new RowList<>(condition); 134 135 tuple = innerConditionalMI(first,second,conditionList); 136 } 137 double gMetric = 2 * second.size() * tuple.score; 138 double prob = computeChiSquaredProbability(tuple.stateCount, gMetric); 139 return new GTestStatistics(gMetric,tuple.stateCount,prob); 140 } 141 142 /** 143 * Computes the cumulative probability of the input value under a Chi-Squared distribution 144 * with the specified degrees of Freedom. 145 * @param degreesOfFreedom The degrees of freedom in the distribution. 146 * @param value The observed value. 147 * @return The cumulative probability of the observed value. 
148 */ 149 private static double computeChiSquaredProbability(int degreesOfFreedom, double value) { 150 if (value <= 0) { 151 return 0.0; 152 } else { 153 int shape = degreesOfFreedom / 2; 154 int scale = 2; 155 return Gamma.regularizedGammaP(shape, value / scale, 1e-14, Integer.MAX_VALUE); 156 } 157 } 158 159 /** 160 * Calculates the discrete Shannon joint mutual information, using 161 * histogram probability estimators. Arrays must be the same length. 162 * @param <T1> Type contained in the first array. 163 * @param <T2> Type contained in the second array. 164 * @param <T3> Type contained in the target array. 165 * @param first An array of values. 166 * @param second Another array of values. 167 * @param target Target array of values. 168 * @return The mutual information I(first,second;joint) 169 */ 170 public static <T1,T2,T3> double jointMI(List<T1> first, List<T2> second, List<T3> target) { 171 if ((first.size() == second.size()) && (first.size() == target.size())) { 172 TripleDistribution<T1,T2,T3> tripleRV = TripleDistribution.constructFromLists(first,second,target); 173 return jointMI(tripleRV); 174 } else { 175 throw new IllegalArgumentException("Joint Mutual Information requires three vectors the same length. first.size() = " + first.size() + ", second.size() = " + second.size() + ", target.size() = " + target.size()); 176 } 177 } 178 179 /** 180 * Calculates the discrete Shannon joint mutual information, using 181 * histogram probability estimators. Arrays must be the same length. 182 * @param <T1> Type contained in the first array. 183 * @param <T2> Type contained in the second array. 184 * @param <T3> Type contained in the target array. 
185 * @param rv The random variable to calculate the joint mi of 186 * @return The mutual information I(first,second;joint) 187 */ 188 public static <T1,T2,T3> double jointMI(TripleDistribution<T1,T2,T3> rv) { 189 double vecLength = rv.count; 190 Map<CachedTriple<T1,T2,T3>,MutableLong> jointCount = rv.getJointCount(); 191 Map<CachedPair<T1,T2>,MutableLong> abCount = rv.getABCount(); 192 Map<T3,MutableLong> cCount = rv.getCCount(); 193 194 double jmi = 0.0; 195 for (Entry<CachedTriple<T1,T2,T3>,MutableLong> e : jointCount.entrySet()) { 196 double jointCurCount = e.getValue().doubleValue(); 197 double prob = jointCurCount / vecLength; 198 CachedPair<T1,T2> pair = e.getKey().getAB(); 199 double abCurCount = abCount.get(pair).doubleValue(); 200 double cCurCount = cCount.get(e.getKey().getC()).doubleValue(); 201 202 jmi += prob * Math.log((vecLength*jointCurCount)/(abCurCount*cCurCount)); 203 } 204 jmi /= LOG_BASE; 205 206 double stateRatio = vecLength / jointCount.size(); 207 if (stateRatio < SAMPLES_RATIO) { 208 logger.log(Level.INFO, "Joint MI estimate of {0} had samples/state ratio of {1}, with {2} observations and {3} states", new Object[]{jmi, stateRatio, vecLength, jointCount.size()}); 209 } 210 211 return jmi; 212 } 213 214 /** 215 * Calculates the conditional mutual information. If flipped == true, then calculates I(T1;T3|T2), otherwise calculates I(T1;T2|T3). 216 * @param <T1> The type of the first argument. 217 * @param <T2> The type of the second argument. 218 * @param <T3> The type of the third argument. 219 * @param rv The random variable. 220 * @param flipped If true then the second element is the conditional variable, otherwise the third element is. 221 * @return A ScoreStateCountTuple containing the conditional mutual information and the number of states in the joint random variable. 
 */
    private static <T1,T2,T3> ScoreStateCountTuple innerConditionalMI(TripleDistribution<T1,T2,T3> rv, boolean flipped) {
        // Histogram counts for the joint (A,B,C) and the pairwise/single marginals.
        Map<CachedTriple<T1,T2,T3>,MutableLong> jointCount = rv.getJointCount();
        Map<CachedPair<T1,T2>,MutableLong> abCount = rv.getABCount();
        Map<CachedPair<T1,T3>,MutableLong> acCount = rv.getACCount();
        Map<CachedPair<T2,T3>,MutableLong> bcCount = rv.getBCCount();
        Map<T2,MutableLong> bCount = rv.getBCount();
        Map<T3,MutableLong> cCount = rv.getCCount();

        double vectorLength = rv.count;
        double cmi = 0.0;
        if (flipped) {
            // I(A;C|B): sum p(a,b,c) * log( n(b)*n(a,b,c) / (n(a,b)*n(b,c)) );
            // the sample-size factors cancel so raw counts are used directly.
            for (Entry<CachedTriple<T1,T2,T3>, MutableLong> e : jointCount.entrySet()) {
                double jointCurCount = e.getValue().doubleValue();
                double prob = jointCurCount / vectorLength;
                CachedPair<T1,T2> abPair = e.getKey().getAB();
                CachedPair<T2,T3> bcPair = e.getKey().getBC();
                double abCurCount = abCount.get(abPair).doubleValue();
                double bcCurCount = bcCount.get(bcPair).doubleValue();
                double bCurCount = bCount.get(e.getKey().getB()).doubleValue();

                cmi += prob * Math.log((bCurCount * jointCurCount) / (abCurCount * bcCurCount));
            }
        } else {
            // I(A;B|C): the third element is the conditioning variable.
            for (Entry<CachedTriple<T1, T2, T3>, MutableLong> e : jointCount.entrySet()) {
                double jointCurCount = e.getValue().doubleValue();
                double prob = jointCurCount / vectorLength;
                CachedPair<T1, T3> acPair = e.getKey().getAC();
                CachedPair<T2, T3> bcPair = e.getKey().getBC();
                double acCurCount = acCount.get(acPair).doubleValue();
                double bcCurCount = bcCount.get(bcPair).doubleValue();
                double cCurCount = cCount.get(e.getKey().getC()).doubleValue();

                cmi += prob * Math.log((cCurCount * jointCurCount) / (acCurCount * bcCurCount));
            }
        }
        // Convert from nats into the configured log base (bits by default).
        cmi /= LOG_BASE;

        // Warn when there are too few samples per observed joint state.
        double stateRatio = vectorLength / jointCount.size();
        if (stateRatio < SAMPLES_RATIO) {
            logger.log(Level.INFO, "Conditional MI estimate of {0} had samples/state ratio of {1}", new Object[]{cmi, stateRatio});
        }

        return new ScoreStateCountTuple(cmi,jointCount.size());
    }

    /**
     * Calculates the conditional mutual information, I(T1;T2|T3).
     * @param <T1> The type of the first argument.
     * @param <T2> The type of the second argument.
     * @param <T3> The type of the third argument.
     * @param first The first random variable.
     * @param second The second random variable.
     * @param condition The conditioning random variable.
     * @return A ScoreStateCountTuple containing the conditional mutual information and the number of states in the joint random variable.
     */
    private static <T1,T2,T3> ScoreStateCountTuple innerConditionalMI(List<T1> first, List<T2> second, List<T3> condition) {
        if ((first.size() == second.size()) && (first.size() == condition.size())) {
            TripleDistribution<T1,T2,T3> tripleRV = TripleDistribution.constructFromLists(first,second,condition);

            return innerConditionalMI(tripleRV,false);
        } else {
            throw new IllegalArgumentException("Conditional Mutual Information requires three vectors the same length. first.size() = " + first.size() + ", second.size() = " + second.size() + ", condition.size() = " + condition.size());
        }
    }

    /**
     * Calculates the discrete Shannon conditional mutual information, using
     * histogram probability estimators. Arrays must be the same length.
     * @param <T1> Type contained in the first array.
     * @param <T2> Type contained in the second array.
     * @param <T3> Type contained in the condition array.
     * @param first An array of values.
     * @param second Another array of values.
     * @param condition Array to condition upon.
297 * @return The conditional mutual information I(first;second|condition) 298 */ 299 public static <T1,T2,T3> double conditionalMI(List<T1> first, List<T2> second, List<T3> condition) { 300 return innerConditionalMI(first,second,condition).score; 301 } 302 303 /** 304 * Calculates the discrete Shannon conditional mutual information, using 305 * histogram probability estimators. Note this calculates I(T1;T2|T3). 306 * @param <T1> Type of the first variable. 307 * @param <T2> Type of the second variable. 308 * @param <T3> Type of the condition variable. 309 * @param rv The triple random variable of the three inputs. 310 * @return The conditional mutual information I(first;second|condition) 311 */ 312 public static <T1,T2,T3> double conditionalMI(TripleDistribution<T1,T2,T3> rv) { 313 return innerConditionalMI(rv,false).score; 314 } 315 316 /** 317 * Calculates the discrete Shannon conditional mutual information, using 318 * histogram probability estimators. Note this calculates I(T1;T3|T2). 319 * @param <T1> Type of the first variable. 320 * @param <T2> Type of the condition variable. 321 * @param <T3> Type of the second variable. 322 * @param rv The triple random variable of the three inputs. 323 * @return The conditional mutual information I(first;second|condition) 324 */ 325 public static <T1,T2,T3> double conditionalMIFlipped(TripleDistribution<T1,T2,T3> rv) { 326 return innerConditionalMI(rv,true).score; 327 } 328 329 /** 330 * Calculates the mutual information from a joint random variable. 331 * @param pairDist The joint distribution. 332 * @param <T1> The first type. 333 * @param <T2> The second type. 334 * @return A ScoreStateCountTuple containing the mutual information and the number of states in the joint variable. 
 */
    private static <T1,T2> ScoreStateCountTuple innerMI(PairDistribution<T1,T2> pairDist) {
        // Joint and marginal histogram counts.
        Map<CachedPair<T1,T2>,MutableLong> countDist = pairDist.jointCounts;
        Map<T1,MutableLong> firstCountDist = pairDist.firstCount;
        Map<T2,MutableLong> secondCountDist = pairDist.secondCount;

        double vectorLength = pairDist.count;
        double mi = 0.0;
        boolean error = false;
        for (Entry<CachedPair<T1,T2>,MutableLong> e : countDist.entrySet()) {
            double jointCount = e.getValue().doubleValue();
            double prob = jointCount / vectorLength;
            // NOTE(review): despite the names these hold raw marginal counts,
            // not probabilities; the sample-size factors cancel in the ratio below.
            double firstProb = firstCountDist.get(e.getKey().getA()).doubleValue();
            double secondProb = secondCountDist.get(e.getKey().getB()).doubleValue();

            // n * n(a,b) / (n(a) * n(b)) == p(a,b) / (p(a)p(b))
            double top = vectorLength * jointCount;
            double bottom = firstProb * secondProb;
            double ratio = top/bottom;
            double logRatio = Math.log(ratio);

            // NaN diagnostics: log every offending state but keep summing so all
            // problems are reported before the single SEVERE log below.
            if (Double.isNaN(logRatio) || Double.isNaN(prob) || Double.isNaN(mi)) {
                logger.log(Level.WARNING, "State = " + e.getKey().toString());
                logger.log(Level.WARNING, "mi = " + mi + " prob = " + prob + " top = " + top + " bottom = " + bottom + " ratio = " + ratio + " logRatio = " + logRatio);
                error = true;
            }
            mi += prob * logRatio;
            //mi += prob * Math.log((vectorLength*jointCount)/(firstProb*secondProb));
        }
        // Convert from nats into the configured log base.
        mi /= LOG_BASE;

        // Warn when there are too few samples per observed state.
        double stateRatio = vectorLength / countDist.size();
        if (stateRatio < SAMPLES_RATIO) {
            logger.log(Level.INFO, "MI estimate of {0} had samples/state ratio of {1}", new Object[]{mi, stateRatio});
        }

        if (error) {
            logger.log(Level.SEVERE, "NanFound ", new IllegalStateException("NaN found"));
        }

        return new ScoreStateCountTuple(mi,countDist.size());
    }

    /**
     * Calculates the mutual information between the two lists.
     * @param first The first list.
     * @param second The second list.
     * @param <T1> The first type.
     * @param <T2> The second type.
383 * @return A ScoreStateCountTuple containing the mutual information and the number of states in the joint variable. 384 */ 385 private static <T1,T2> ScoreStateCountTuple innerMI(List<T1> first, List<T2> second) { 386 if (first.size() == second.size()) { 387 PairDistribution<T1,T2> pairDist = PairDistribution.constructFromLists(first, second); 388 389 return innerMI(pairDist); 390 } else { 391 throw new IllegalArgumentException("Mutual Information requires two vectors the same length. first.size() = " + first.size() + ", second.size() = " + second.size()); 392 } 393 } 394 395 /** 396 * Calculates the discrete Shannon mutual information, using histogram 397 * probability estimators. Arrays must be the same length. 398 * @param <T1> Type of the first array 399 * @param <T2> Type of the second array 400 * @param first An array of values 401 * @param second Another array of values 402 * @return The mutual information I(first;second) 403 */ 404 public static <T1,T2> double mi(List<T1> first, List<T2> second) { 405 return innerMI(first,second).score; 406 } 407 408 /** 409 * Calculates the discrete Shannon mutual information, using histogram 410 * probability estimators. 411 * @param <T1> Type of the first variable 412 * @param <T2> Type of the second variable 413 * @param pairDist PairDistribution for the two variables. 414 * @return The mutual information I(first;second) 415 */ 416 public static <T1,T2> double mi(PairDistribution<T1,T2> pairDist) { 417 return innerMI(pairDist).score; 418 } 419 420 /** 421 * Calculates the Shannon joint entropy of two arrays, using histogram 422 * probability estimators. Arrays must be same length. 423 * @param <T1> Type of the first array. 424 * @param <T2> Type of the second array. 425 * @param first An array of values. 426 * @param second Another array of values. 
427 * @return The entropy H(first,second) 428 */ 429 public static <T1,T2> double jointEntropy(List<T1> first, List<T2> second) { 430 if (first.size() == second.size()) { 431 double vectorLength = first.size(); 432 double jointEntropy = 0.0; 433 434 PairDistribution<T1,T2> countPair = PairDistribution.constructFromLists(first,second); 435 Map<CachedPair<T1,T2>,MutableLong> countDist = countPair.jointCounts; 436 437 for (Entry<CachedPair<T1,T2>,MutableLong> e : countDist.entrySet()) { 438 double prob = e.getValue().doubleValue() / vectorLength; 439 440 jointEntropy -= prob * Math.log(prob); 441 } 442 jointEntropy /= LOG_BASE; 443 444 double stateRatio = vectorLength / countDist.size(); 445 if (stateRatio < SAMPLES_RATIO) { 446 logger.log(Level.INFO, "Joint Entropy estimate of {0} had samples/state ratio of {1}", new Object[]{jointEntropy, stateRatio}); 447 } 448 449 return jointEntropy; 450 } else { 451 throw new IllegalArgumentException("Joint Entropy requires two vectors the same length. first.size() = " + first.size() + ", second.size() = " + second.size()); 452 } 453 } 454 455 /** 456 * Calculates the discrete Shannon conditional entropy of two arrays, using 457 * histogram probability estimators. Arrays must be the same length. 458 * @param <T1> Type of the first array. 459 * @param <T2> Type of the second array. 460 * @param vector The main array of values. 461 * @param condition The array to condition on. 462 * @return The conditional entropy H(vector|condition). 
463 */ 464 public static <T1,T2> double conditionalEntropy(List<T1> vector, List<T2> condition) { 465 if (vector.size() == condition.size()) { 466 double vectorLength = vector.size(); 467 double condEntropy = 0.0; 468 469 PairDistribution<T1,T2> countPair = PairDistribution.constructFromLists(vector,condition); 470 Map<CachedPair<T1,T2>,MutableLong> countDist = countPair.jointCounts; 471 Map<T2,MutableLong> conditionCountDist = countPair.secondCount; 472 473 for (Entry<CachedPair<T1,T2>,MutableLong> e : countDist.entrySet()) { 474 double prob = e.getValue().doubleValue() / vectorLength; 475 double condProb = conditionCountDist.get(e.getKey().getB()).doubleValue() / vectorLength; 476 477 condEntropy -= prob * Math.log(prob/condProb); 478 } 479 condEntropy /= LOG_BASE; 480 481 double stateRatio = vectorLength / countDist.size(); 482 if (stateRatio < SAMPLES_RATIO) { 483 logger.log(Level.INFO, "Conditional Entropy estimate of {0} had samples/state ratio of {1}", new Object[]{condEntropy, stateRatio}); 484 } 485 486 return condEntropy; 487 } else { 488 throw new IllegalArgumentException("Conditional Entropy requires two vectors the same length. vector.size() = " + vector.size() + ", condition.size() = " + condition.size()); 489 } 490 } 491 492 /** 493 * Calculates the discrete Shannon entropy, using histogram probability 494 * estimators. 495 * @param <T> Type of the array. 496 * @param vector The array of values. 497 * @return The entropy H(vector). 
498 */ 499 public static <T> double entropy(List<T> vector) { 500 double vectorLength = vector.size(); 501 double entropy = 0.0; 502 503 Map<T,Long> countDist = calculateCountDist(vector); 504 for (Entry<T,Long> e : countDist.entrySet()) { 505 double prob = e.getValue() / vectorLength; 506 entropy -= prob * Math.log(prob); 507 } 508 entropy /= LOG_BASE; 509 510 double stateRatio = vectorLength / countDist.size(); 511 if (stateRatio < SAMPLES_RATIO) { 512 logger.log(Level.INFO, "Entropy estimate of {0} had samples/state ratio of {1}", new Object[]{entropy, stateRatio}); 513 } 514 515 return entropy; 516 } 517 518 /** 519 * Generate the counts for a single vector. 520 * @param <T> The type inside the vector. 521 * @param vector An array of values. 522 * @return A HashMap from states of T to counts. 523 */ 524 public static <T> Map<T,Long> calculateCountDist(List<T> vector) { 525 HashMap<T,Long> countDist = new HashMap<>(DEFAULT_MAP_SIZE); 526 for (T e : vector) { 527 Long curCount = countDist.getOrDefault(e,0L); 528 curCount += 1; 529 countDist.put(e, curCount); 530 } 531 532 return countDist; 533 } 534 535 /** 536 * Calculates the discrete Shannon entropy of a stream, assuming each element of the stream is 537 * an element of the same probability distribution. 538 * @param vector The probability distribution. 539 * @return The entropy. 540 */ 541 public static double calculateEntropy(Stream<Double> vector) { 542 return vector.map((p) -> (- p * Math.log(p) / LOG_BASE)).reduce(0.0, Double::sum); 543 } 544 545 /** 546 * Calculates the discrete Shannon entropy of a stream, assuming each element of the stream is 547 * an element of the same probability distribution. 548 * @param vector The probability distribution. 549 * @return The entropy. 550 */ 551 public static double calculateEntropy(DoubleStream vector) { 552 return vector.map((p) -> (- p * Math.log(p) / LOG_BASE)).sum(); 553 } 554 555 /** 556 * Compute the expected mutual information assuming randomized inputs. 
     *
     * @param first The first vector.
     * @param second The second vector.
     * @param <T> The type inside the list. Must define equals and hashcode.
     * @return The expected mutual information under a hypergeometric distribution.
     */
    public static <T> double expectedMI(List<T> first, List<T> second) {
        PairDistribution<T,T> pd = PairDistribution.constructFromLists(first,second);

        Map<T, MutableLong> firstCount = pd.firstCount;
        Map<T,MutableLong> secondCount = pd.secondCount;
        long count = pd.count;

        double output = 0.0;

        // Sum the expected contribution of each (state_i, state_j) cell over all
        // feasible joint counts nij given the fixed marginals.
        for (Entry<T,MutableLong> f : firstCount.entrySet()) {
            for (Entry<T,MutableLong> s : secondCount.entrySet()) {
                long fVal = f.getValue().longValue();
                long sVal = s.getValue().longValue();
                long minCount = Math.min(fVal, sVal);

                // Smallest feasible joint count for this cell; clamp to 1 as the
                // nij = 0 term contributes nothing to the sum.
                long threshold = fVal + sVal - count;
                long start = threshold > 1 ? threshold : 1;

                for (long nij = start; nij <= minCount; nij++) {
                    double acc = ((double) nij) / count;
                    acc *= Math.log(((double) (count * nij)) / (fVal * sVal));
                    // Probability of observing nij, computed in log space via
                    // logGamma(x+1) = log(x!) to avoid overflowing the factorials.
                    //numerator
                    double logSpace = Gamma.logGamma(fVal + 1);
                    logSpace += Gamma.logGamma(sVal + 1);
                    logSpace += Gamma.logGamma(count - fVal + 1);
                    logSpace += Gamma.logGamma(count - sVal + 1);
                    //denominator
                    logSpace -= Gamma.logGamma(count + 1);
                    logSpace -= Gamma.logGamma(nij + 1);
                    logSpace -= Gamma.logGamma(fVal - nij + 1);
                    logSpace -= Gamma.logGamma(sVal - nij + 1);
                    logSpace -= Gamma.logGamma(count - fVal - sVal + nij + 1);
                    acc *= Math.exp(logSpace);
                    output += acc;
                }
            }
        }
        // NOTE(review): the result is not divided by LOG_BASE, unlike the other
        // estimators in this class, so it is in nats - confirm this is intended.
        return output;
    }

    /**
     * A tuple of the information theoretic value, along with the number of
     * states in the random variable. Will be a record one day.
     */
    private static class ScoreStateCountTuple {
        // The information theoretic score (e.g. MI, conditional MI).
        public final double score;
        // The number of observed states in the joint random variable.
        public final int stateCount;

        /**
         * Construct a score state tuple
         * @param score The score.
         * @param stateCount The number of states.
         */
        ScoreStateCountTuple(double score, int stateCount) {
            this.score = score;
            this.stateCount = stateCount;
        }

        @Override
        public String toString() {
            return "ScoreStateCount(score=" + score + ",stateCount=" + stateCount + ")";
        }
    }

    /**
     * An immutable named tuple containing the statistics from a G test.
     * <p>
     * Will be a record one day.
     */
    public static final class GTestStatistics {
        /**
         * The G test statistic.
         */
        public final double gStatistic;
        /**
         * The number of states.
         */
        public final int numStates;
        /**
         * The probability of that statistic.
         */
        public final double probability;

        /**
         * Constructs a GTestStatistics tuple with the supplied values.
         * @param gStatistic The g test statistic.
         * @param numStates The number of states.
         * @param probability The probability of that statistic.
         */
        // TODO should be package private.
        public GTestStatistics(double gStatistic, int numStates, double probability) {
            this.gStatistic = gStatistic;
            this.numStates = numStates;
            this.probability = probability;
        }

        @Override
        public String toString() {
            return "GTest(statistic="+gStatistic+",probability="+probability+",numStates="+numStates+")";
        }
    }
}