/*
 * Copyright (c) 2015-2020, Oracle and/or its affiliates. All rights reserved.
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 *     http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

package org.tribuo.util.infotheory;

import com.oracle.labs.mlrg.olcut.util.MutableLong;
import org.tribuo.util.infotheory.impl.CachedPair;
import org.tribuo.util.infotheory.impl.CachedTriple;
import org.tribuo.util.infotheory.impl.PairDistribution;
import org.tribuo.util.infotheory.impl.TripleDistribution;
import org.tribuo.util.infotheory.impl.WeightCountTuple;
import org.tribuo.util.infotheory.impl.WeightedPairDistribution;
import org.tribuo.util.infotheory.impl.WeightedTripleDistribution;

import java.util.ArrayList;
import java.util.LinkedHashMap;
import java.util.List;
import java.util.Map;
import java.util.Map.Entry;
import java.util.logging.Level;
import java.util.logging.Logger;

/**
 * A class of (discrete) weighted information theoretic functions. Gives warnings if
 * there are insufficient samples to estimate the quantities accurately.
 * <p>
 * Defaults to log_2, so returns values in bits.
 * <p>
 * All functions expect that the element types have well defined equals and
 * hashcode, and that equals is consistent with hashcode. The behaviour is undefined
 * if this is not true.
045 */ 046public final class WeightedInformationTheory { 047 private static final Logger logger = Logger.getLogger(WeightedInformationTheory.class.getName()); 048 049 /** 050 * The ratio of samples to symbols before emitting a warning. 051 */ 052 public static final double SAMPLES_RATIO = 5.0; 053 /** 054 * The initial size of the various maps. 055 */ 056 public static final int DEFAULT_MAP_SIZE = 20; 057 /** 058 * Log base 2. 059 */ 060 public static final double LOG_2 = Math.log(2); 061 /** 062 * Log base e. 063 */ 064 public static final double LOG_E = Math.log(Math.E); 065 066 /** 067 * Sets the base of the logarithm used in the information theoretic calculations. 068 * For LOG_2 the unit is "bit", for LOG_E the unit is "nat". 069 */ 070 public static double LOG_BASE = LOG_2; 071 072 /** 073 * Chooses which variable is the one with associated weights. 074 */ 075 public enum VariableSelector { 076 /** 077 * The first variable is weighted. 078 */ 079 FIRST, 080 /** 081 * The second variable is weighted. 082 */ 083 SECOND, 084 /** 085 * The third variable is weighted. 086 */ 087 THIRD 088 } 089 090 /** 091 * Private constructor, only has static methods. 092 */ 093 private WeightedInformationTheory() {} 094 095 /** 096 * Calculates the discrete weighted joint mutual information, using 097 * histogram probability estimators. Arrays must be the same length. 098 * @param <T1> Type contained in the first array. 099 * @param <T2> Type contained in the second array. 100 * @param <T3> Type contained in the target array. 101 * @param first An array of values. 102 * @param second Another array of values. 103 * @param target Target array of values. 104 * @param weights Array of weight values. 
105 * @return The weighted mutual information I_w(first,second;joint) 106 */ 107 public static <T1,T2,T3> double jointMI(List<T1> first, List<T2> second, List<T3> target, List<Double> weights) { 108 WeightedTripleDistribution<T1, T2, T3> tripleRV = WeightedTripleDistribution.constructFromLists(first, second, target, weights); 109 110 return jointMI(tripleRV); 111 } 112 113 /** 114 * Calculates the discrete weighted joint mutual information, using 115 * histogram probability estimators. 116 * @param tripleRV The weighted triple distribution. 117 * @param <T1> The first element type. 118 * @param <T2> The second element type. 119 * @param <T3> The third element type. 120 * @return The weighted mutual information I_w(first,second;joint) 121 */ 122 public static <T1,T2,T3> double jointMI(WeightedTripleDistribution<T1,T2,T3> tripleRV) { 123 Map<CachedTriple<T1,T2,T3>, WeightCountTuple> jointCount = tripleRV.getJointCount(); 124 Map<CachedPair<T1,T2>,WeightCountTuple> abCount = tripleRV.getABCount(); 125 Map<T3,WeightCountTuple> cCount = tripleRV.getCCount(); 126 127 double vectorLength = tripleRV.count; 128 double jmi = 0.0; 129 for (Entry<CachedTriple<T1,T2,T3>,WeightCountTuple> e : jointCount.entrySet()) { 130 double jointCurCount = e.getValue().count; 131 double jointCurWeight = e.getValue().weight; 132 double prob = jointCurCount / vectorLength; 133 CachedPair<T1,T2> pair = e.getKey().getAB(); 134 double abCurCount = abCount.get(pair).count; 135 double cCurCount = cCount.get(e.getKey().getC()).count; 136 137 jmi += jointCurWeight * prob * Math.log((vectorLength*jointCurCount)/(abCurCount*cCurCount)); 138 } 139 jmi /= LOG_BASE; 140 141 double stateRatio = vectorLength / jointCount.size(); 142 if (stateRatio < SAMPLES_RATIO) { 143 logger.log(Level.INFO, "Joint MI estimate of {0} had samples/state ratio of {1}", new Object[]{jmi, stateRatio}); 144 } 145 146 return jmi; 147 } 148 149 /** 150 * Calculates the discrete weighted joint mutual information, using 151 * 
histogram probability estimators. 152 * @param rv The triple distribution. 153 * @param weights The weights for one of the variables. 154 * @param vs The weighted variable id. 155 * @param <T1> The first element type. 156 * @param <T2> The second element type. 157 * @param <T3> The third element type. 158 * @return The weighted mutual information I_w(first,second;joint) 159 */ 160 public static <T1,T2,T3> double jointMI(TripleDistribution<T1,T2,T3> rv, Map<?,Double> weights, VariableSelector vs){ 161 Double boxedWeight; 162 double vecLength = rv.count; 163 Map<CachedTriple<T1,T2,T3>,MutableLong> jointCount = rv.getJointCount(); 164 Map<CachedPair<T1,T2>,MutableLong> abCount = rv.getABCount(); 165 Map<T3,MutableLong> cCount = rv.getCCount(); 166 167 double jmi = 0.0; 168 for (Entry<CachedTriple<T1,T2,T3>,MutableLong> e : jointCount.entrySet()) { 169 double jointCurCount = e.getValue().doubleValue(); 170 double prob = jointCurCount / vecLength; 171 CachedPair<T1,T2> pair = new CachedPair<>(e.getKey().getA(),e.getKey().getB()); 172 double abCurCount = abCount.get(pair).doubleValue(); 173 double cCurCount = cCount.get(e.getKey().getC()).doubleValue(); 174 175 double weight = 1.0; 176 switch (vs) { 177 case FIRST: 178 boxedWeight = weights.get(e.getKey().getA()); 179 weight = boxedWeight == null ? 1.0 : boxedWeight; 180 break; 181 case SECOND: 182 boxedWeight = weights.get(e.getKey().getB()); 183 weight = boxedWeight == null ? 1.0 : boxedWeight; 184 break; 185 case THIRD: 186 boxedWeight = weights.get(e.getKey().getC()); 187 weight = boxedWeight == null ? 
1.0 : boxedWeight; 188 break; 189 } 190 191 jmi += weight * prob * Math.log((vecLength*jointCurCount)/(abCurCount*cCurCount)); 192 } 193 jmi /= LOG_BASE; 194 195 double stateRatio = vecLength / jointCount.size(); 196 if (stateRatio < SAMPLES_RATIO) { 197 logger.log(Level.INFO, "Joint MI estimate of {0} had samples/state ratio of {1}, with {2} observations and {3} states", new Object[]{jmi, stateRatio, vecLength, jointCount.size()}); 198 } 199 200 return jmi; 201 } 202 203 /** 204 * Calculates the discrete weighted conditional mutual information, using 205 * histogram probability estimators. Arrays must be the same length. 206 * @param <T1> Type contained in the first array. 207 * @param <T2> Type contained in the second array. 208 * @param <T3> Type contained in the condition array. 209 * @param first An array of values. 210 * @param second Another array of values. 211 * @param condition Array to condition upon. 212 * @param weights Array of weight values. 213 * @return The weighted conditional mutual information I_w(first;second|condition) 214 */ 215 public static <T1,T2,T3> double conditionalMI(List<T1> first, List<T2> second, List<T3> condition, List<Double> weights) { 216 if ((first.size() == second.size()) && (first.size() == condition.size()) && (first.size() == weights.size())) { 217 WeightedTripleDistribution<T1,T2,T3> tripleRV = WeightedTripleDistribution.constructFromLists(first, second, condition, weights); 218 219 return conditionalMI(tripleRV); 220 } else { 221 throw new IllegalArgumentException("Weighted Conditional Mutual Information requires four vectors the same length. first.size() = " + first.size() + ", second.size() = " + second.size() + ", condition.size() = " + condition.size() + ", weights.size() = "+ weights.size()); 222 } 223 } 224 225 /** 226 * Calculates the discrete weighted conditional mutual information, using 227 * histogram probability estimators. 228 * @param tripleRV The weighted triple distribution. 
229 * @param <T1> The first element type. 230 * @param <T2> The second element type. 231 * @param <T3> The condition element type. 232 * @return The weighted conditional mutual information I_w(first;second|condition) 233 */ 234 public static <T1,T2,T3> double conditionalMI(WeightedTripleDistribution<T1,T2,T3> tripleRV) { 235 Map<CachedTriple<T1,T2,T3>,WeightCountTuple> jointCount = tripleRV.getJointCount(); 236 Map<CachedPair<T1,T3>,WeightCountTuple> acCount = tripleRV.getACCount(); 237 Map<CachedPair<T2,T3>,WeightCountTuple> bcCount = tripleRV.getBCCount(); 238 Map<T3,WeightCountTuple> cCount = tripleRV.getCCount(); 239 240 double vectorLength = tripleRV.count; 241 double cmi = 0.0; 242 for (Entry<CachedTriple<T1,T2,T3>,WeightCountTuple> e : jointCount.entrySet()) { 243 double weight = e.getValue().weight; 244 double jointCurCount = e.getValue().count; 245 double prob = jointCurCount / vectorLength; 246 CachedPair<T1,T3> acPair = e.getKey().getAC(); 247 CachedPair<T2,T3> bcPair = e.getKey().getBC(); 248 double acCurCount = acCount.get(acPair).count; 249 double bcCurCount = bcCount.get(bcPair).count; 250 double cCurCount = cCount.get(e.getKey().getC()).count; 251 252 cmi += weight * prob * Math.log((cCurCount*jointCurCount)/(acCurCount*bcCurCount)); 253 } 254 cmi /= LOG_BASE; 255 256 double stateRatio = vectorLength / jointCount.size(); 257 if (stateRatio < SAMPLES_RATIO) { 258 logger.log(Level.INFO, "Conditional MI estimate of {0} had samples/state ratio of {1}", new Object[]{cmi, stateRatio}); 259 } 260 261 return cmi; 262 } 263 264 /** 265 * Calculates the discrete weighted conditional mutual information, using 266 * histogram probability estimators. 267 * @param rv The triple distribution. 268 * @param weights The element weights. 269 * @param vs The variable to apply the weights to. 270 * @param <T1> The first element type. 271 * @param <T2> The second element type. 272 * @param <T3> The condition element type. 
273 * @return The weighted conditional mutual information I_w(first;second|condition) 274 */ 275 public static <T1,T2,T3> double conditionalMI(TripleDistribution<T1,T2,T3> rv, Map<?,Double> weights, VariableSelector vs) { 276 Double boxedWeight; 277 Map<CachedTriple<T1,T2,T3>,MutableLong> jointCount = rv.getJointCount(); 278 Map<CachedPair<T1,T3>,MutableLong> acCount = rv.getACCount(); 279 Map<CachedPair<T2,T3>,MutableLong> bcCount = rv.getBCCount(); 280 Map<T3,MutableLong> cCount = rv.getCCount(); 281 282 double vectorLength = rv.count; 283 double cmi = 0.0; 284 for (Entry<CachedTriple<T1, T2, T3>, MutableLong> e : jointCount.entrySet()) { 285 double jointCurCount = e.getValue().doubleValue(); 286 double prob = jointCurCount / vectorLength; 287 CachedPair<T1, T3> acPair = new CachedPair<>(e.getKey().getA(), e.getKey().getC()); 288 CachedPair<T2, T3> bcPair = new CachedPair<>(e.getKey().getB(), e.getKey().getC()); 289 double acCurCount = acCount.get(acPair).doubleValue(); 290 double bcCurCount = bcCount.get(bcPair).doubleValue(); 291 double cCurCount = cCount.get(e.getKey().getC()).doubleValue(); 292 293 double weight = 1.0; 294 switch (vs) { 295 case FIRST: 296 boxedWeight = weights.get(e.getKey().getA()); 297 weight = boxedWeight == null ? 1.0 : boxedWeight; 298 break; 299 case SECOND: 300 boxedWeight = weights.get(e.getKey().getB()); 301 weight = boxedWeight == null ? 1.0 : boxedWeight; 302 break; 303 case THIRD: 304 boxedWeight = weights.get(e.getKey().getC()); 305 weight = boxedWeight == null ? 
1.0 : boxedWeight; 306 break; 307 } 308 cmi += weight * prob * Math.log((cCurCount * jointCurCount) / (acCurCount * bcCurCount)); 309 } 310 cmi /= LOG_BASE; 311 312 double stateRatio = vectorLength / jointCount.size(); 313 if (stateRatio < SAMPLES_RATIO) { 314 logger.log(Level.INFO, "Conditional MI estimate of {0} had samples/state ratio of {1}", new Object[]{cmi, stateRatio}); 315 } 316 317 return cmi; 318 } 319 320 /** 321 * Calculates the discrete weighted mutual information, using histogram 322 * probability estimators. 323 * <p> 324 * Arrays must be the same length. 325 * @param <T1> Type of the first array 326 * @param <T2> Type of the second array 327 * @param first An array of values 328 * @param second Another array of values 329 * @param weights Array of weight values. 330 * @return The weighted mutual information I_w(first;Second) 331 */ 332 public static <T1,T2> double mi(ArrayList<T1> first, ArrayList<T2> second, ArrayList<Double> weights) { 333 if ((first.size() == second.size()) && (first.size() == weights.size())) { 334 WeightedPairDistribution<T1,T2> countPair = WeightedPairDistribution.constructFromLists(first,second,weights); 335 return mi(countPair); 336 } else { 337 throw new IllegalArgumentException("Weighted Mutual Information requires three vectors the same length. first.size() = " + first.size() + ", second.size() = " + second.size() + ", weights.size() = " + weights.size()); 338 } 339 } 340 341 /** 342 * Calculates the discrete weighted mutual information, using histogram 343 * probability estimators. 344 * @param jointDist The weighted joint distribution. 345 * @param <T1> Type of the first element. 346 * @param <T2> Type of the second element. 
347 * @return The weighted mutual information I_w(first;Second) 348 */ 349 public static <T1,T2> double mi(WeightedPairDistribution<T1,T2> jointDist) { 350 double vectorLength = jointDist.count; 351 double mi = 0.0; 352 Map<CachedPair<T1,T2>,WeightCountTuple> countDist = jointDist.getJointCounts(); 353 Map<T1,WeightCountTuple> firstCountDist = jointDist.getFirstCount(); 354 Map<T2,WeightCountTuple> secondCountDist = jointDist.getSecondCount(); 355 356 for (Entry<CachedPair<T1,T2>,WeightCountTuple> e : countDist.entrySet()) { 357 double weight = e.getValue().weight; 358 double jointCount = e.getValue().count; 359 double prob = jointCount / vectorLength; 360 double firstCount = firstCountDist.get(e.getKey().getA()).count; 361 double secondCount = secondCountDist.get(e.getKey().getB()).count; 362 363 mi += weight * prob * Math.log((vectorLength*jointCount)/(firstCount*secondCount)); 364 } 365 mi /= LOG_BASE; 366 367 double stateRatio = vectorLength / countDist.size(); 368 if (stateRatio < SAMPLES_RATIO) { 369 logger.log(Level.INFO, "MI estimate of {0} had samples/state ratio of {1}", new Object[]{mi, stateRatio}); 370 } 371 372 return mi; 373 } 374 375 /** 376 * Calculates the discrete weighted mutual information, using histogram 377 * probability estimators. 378 * @param pairDist The joint distribution. 379 * @param weights The element weights. 380 * @param vs The variable to apply the weights to. 381 * @param <T1> Type of the first element. 382 * @param <T2> Type of the second element. 
383 * @return The weighted mutual information I_w(first;Second) 384 */ 385 public static <T1,T2> double mi(PairDistribution<T1,T2> pairDist, Map<?,Double> weights, VariableSelector vs) { 386 if (vs == VariableSelector.THIRD) { 387 throw new IllegalArgumentException("MI only has two variables"); 388 } 389 Map<CachedPair<T1,T2>,MutableLong> countDist = pairDist.jointCounts; 390 Map<T1,MutableLong> firstCountDist = pairDist.firstCount; 391 Map<T2,MutableLong> secondCountDist = pairDist.secondCount; 392 393 Double boxedWeight; 394 double vectorLength = pairDist.count; 395 double mi = 0.0; 396 boolean error = false; 397 for (Entry<CachedPair<T1,T2>,MutableLong> e : countDist.entrySet()) { 398 double jointCount = e.getValue().doubleValue(); 399 double prob = jointCount / vectorLength; 400 double firstProb = firstCountDist.get(e.getKey().getA()).doubleValue(); 401 double secondProb = secondCountDist.get(e.getKey().getB()).doubleValue(); 402 403 double top = vectorLength * jointCount; 404 double bottom = firstProb * secondProb; 405 double ratio = top/bottom; 406 double logRatio = Math.log(ratio); 407 408 if (Double.isNaN(logRatio) || Double.isNaN(prob) || Double.isNaN(mi)) { 409 logger.log(Level.WARNING, "State = " + e.getKey().toString()); 410 logger.log(Level.WARNING, "mi = " + mi + " prob = " + prob + " top = " + top + " bottom = " + bottom + " ratio = " + ratio + " logRatio = " + logRatio); 411 error = true; 412 } 413 414 double weight = 1.0; 415 switch (vs) { 416 case FIRST: 417 boxedWeight = weights.get(e.getKey().getA()); 418 weight = boxedWeight == null ? 1.0 : boxedWeight; 419 break; 420 case SECOND: 421 boxedWeight = weights.get(e.getKey().getB()); 422 weight = boxedWeight == null ? 
1.0 : boxedWeight; 423 break; 424 default: 425 throw new IllegalArgumentException("VariableSelector.THIRD not allowed in a two variable calculation."); 426 } 427 mi += weight * prob * logRatio; 428 //mi += prob * Math.log((vectorLength*jointCount)/(firstProb*secondProb)); 429 } 430 mi /= LOG_BASE; 431 432 double stateRatio = vectorLength / countDist.size(); 433 if (stateRatio < SAMPLES_RATIO) { 434 logger.log(Level.INFO, "MI estimate of {0} had samples/state ratio of {1}", new Object[]{mi, stateRatio}); 435 } 436 437 if (error) { 438 logger.log(Level.SEVERE, "NanFound ", new IllegalStateException("NaN found")); 439 } 440 441 return mi; 442 } 443 444 /** 445 * Calculates the Shannon/Guiasu weighted joint entropy of two arrays, 446 * using histogram probability estimators. 447 * <p> 448 * Arrays must be same length. 449 * @param <T1> Type of the first array. 450 * @param <T2> Type of the second array. 451 * @param first An array of values. 452 * @param second Another array of values. 453 * @param weights Array of weight values. 
454 * @return The entropy H(first,second) 455 */ 456 public static <T1,T2> double jointEntropy(ArrayList<T1> first, ArrayList<T2> second, ArrayList<Double> weights) { 457 if ((first.size() == second.size()) && (first.size() == weights.size())) { 458 double vectorLength = first.size(); 459 double jointEntropy = 0.0; 460 461 WeightedPairDistribution<T1,T2> pairDist = WeightedPairDistribution.constructFromLists(first,second,weights); 462 Map<CachedPair<T1,T2>,WeightCountTuple> countDist = pairDist.getJointCounts(); 463 464 for (Entry<CachedPair<T1,T2>,WeightCountTuple> e : countDist.entrySet()) { 465 double prob = e.getValue().count / vectorLength; 466 double weight = e.getValue().weight; 467 468 jointEntropy -= weight * prob * Math.log(prob); 469 } 470 jointEntropy /= LOG_BASE; 471 472 double stateRatio = vectorLength / countDist.size(); 473 if (stateRatio < SAMPLES_RATIO) { 474 logger.log(Level.INFO, "Weighted Joint Entropy estimate of {0} had samples/state ratio of {1}", new Object[]{jointEntropy, stateRatio}); 475 } 476 477 return jointEntropy; 478 } else { 479 throw new IllegalArgumentException("Weighted Joint Entropy requires three vectors the same length. first.size() = " + first.size() + ", second.size() = " + second.size() + ", weights.size() = " + weights.size()); 480 } 481 } 482 483 /** 484 * Calculates the discrete Shannon/Guiasu Weighted Conditional Entropy of 485 * two arrays, using histogram probability estimators. 486 * <p> 487 * Arrays must be the same length. 488 * @param <T1> Type of the first array. 489 * @param <T2> Type of the second array. 490 * @param vector The main array of values. 491 * @param condition The array to condition on. 492 * @param weights Array of weight values. 493 * @return The weighted conditional entropy H_w(vector|condition). 
494 */ 495 public static <T1,T2> double weightedConditionalEntropy(ArrayList<T1> vector, ArrayList<T2> condition, ArrayList<Double> weights) { 496 if ((vector.size() == condition.size()) && (vector.size() == weights.size())) { 497 double vectorLength = vector.size(); 498 double condEntropy = 0.0; 499 500 WeightedPairDistribution<T1,T2> pairDist = WeightedPairDistribution.constructFromLists(vector,condition,weights); 501 Map<CachedPair<T1,T2>,WeightCountTuple> countDist = pairDist.getJointCounts(); 502 Map<T2,WeightCountTuple> conditionCountDist = pairDist.getSecondCount(); 503 504 for (Entry<CachedPair<T1,T2>,WeightCountTuple> e : countDist.entrySet()) { 505 double prob = e.getValue().count / vectorLength; 506 double condProb = conditionCountDist.get(e.getKey().getB()).count / vectorLength; 507 double weight = e.getValue().weight; 508 509 condEntropy -= weight * prob * Math.log(prob/condProb); 510 } 511 condEntropy /= LOG_BASE; 512 513 double stateRatio = vectorLength / countDist.size(); 514 if (stateRatio < SAMPLES_RATIO) { 515 logger.log(Level.INFO, "Weighted Conditional Entropy estimate of {0} had samples/state ratio of {1}", new Object[]{condEntropy, stateRatio}); 516 } 517 518 return condEntropy; 519 } else { 520 throw new IllegalArgumentException("Weighted Conditional Entropy requires three vectors the same length. vector.size() = " + vector.size() + ", condition.size() = " + condition.size() + ", weights.size() = " + weights.size()); 521 } 522 } 523 524 /** 525 * Calculates the discrete Shannon/Guiasu Weighted Entropy, using histogram 526 * probability estimators. 527 * @param <T> Type of the array. 528 * @param vector The array of values. 529 * @param weights Array of weight values. 530 * @return The weighted entropy H_w(vector). 
531 */ 532 public static <T> double weightedEntropy(ArrayList<T> vector, ArrayList<Double> weights) { 533 if (vector.size() == weights.size()) { 534 double vectorLength = vector.size(); 535 double entropy = 0.0; 536 537 Map<T,WeightCountTuple> countDist = calculateWeightedCountDist(vector,weights); 538 for (Entry<T,WeightCountTuple> e : countDist.entrySet()) { 539 long count = e.getValue().count; 540 double weight = e.getValue().weight; 541 double prob = count / vectorLength; 542 entropy -= weight * prob * Math.log(prob); 543 } 544 entropy /= LOG_BASE; 545 546 double stateRatio = vectorLength / countDist.size(); 547 if (stateRatio < SAMPLES_RATIO) { 548 logger.log(Level.INFO, "Weighted Entropy estimate of {0} had samples/state ratio of {1}", new Object[]{entropy, stateRatio}); 549 } 550 551 return entropy; 552 } else { 553 throw new IllegalArgumentException("Weighted Entropy requires two vectors the same length. vector.size() = " + vector.size() + ",weights.size() = " + weights.size()); 554 } 555 } 556 557 /** 558 * Generate the counts for a single vector. 559 * @param <T> The type inside the vector. 560 * @param vector An array of values. 561 * @param weights The array of weight values. 562 * @return A HashMap from states of T to Pairs of count and total weight for that state. 563 */ 564 public static <T> Map<T,WeightCountTuple> calculateWeightedCountDist(ArrayList<T> vector, ArrayList<Double> weights) { 565 Map<T,WeightCountTuple> dist = new LinkedHashMap<>(DEFAULT_MAP_SIZE); 566 for (int i = 0; i < vector.size(); i++) { 567 T e = vector.get(i); 568 Double weight = weights.get(i); 569 WeightCountTuple curVal = dist.computeIfAbsent(e,(k) -> new WeightCountTuple()); 570 curVal.count += 1; 571 curVal.weight += weight; 572 } 573 574 normaliseWeights(dist); 575 576 return dist; 577 } 578 579 /** 580 * Normalizes the weights in the map, i.e., divides each weight by it's count. 581 * @param map The map to normalize. 
582 * @param <T> The type of the variable that was counted. 583 */ 584 public static <T> void normaliseWeights(Map<T,WeightCountTuple> map) { 585 for (Entry<T,WeightCountTuple> e : map.entrySet()) { 586 WeightCountTuple tuple = e.getValue(); 587 tuple.weight /= tuple.count; 588 } 589 } 590 591}