001/*
002 * Copyright (c) 2015-2020, Oracle and/or its affiliates. All rights reserved.
003 *
004 * Licensed under the Apache License, Version 2.0 (the "License");
005 * you may not use this file except in compliance with the License.
006 * You may obtain a copy of the License at
007 *
008 *     http://www.apache.org/licenses/LICENSE-2.0
009 *
010 * Unless required by applicable law or agreed to in writing, software
011 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
013 * See the License for the specific language governing permissions and
014 * limitations under the License.
015 */
016
017package org.tribuo.util.infotheory;
018
019import com.oracle.labs.mlrg.olcut.util.MutableLong;
020import org.tribuo.util.infotheory.impl.CachedPair;
021import org.tribuo.util.infotheory.impl.CachedTriple;
022import org.tribuo.util.infotheory.impl.PairDistribution;
023import org.tribuo.util.infotheory.impl.TripleDistribution;
024import org.tribuo.util.infotheory.impl.WeightCountTuple;
025import org.tribuo.util.infotheory.impl.WeightedPairDistribution;
026import org.tribuo.util.infotheory.impl.WeightedTripleDistribution;
027
028import java.util.ArrayList;
029import java.util.LinkedHashMap;
030import java.util.List;
031import java.util.Map;
032import java.util.Map.Entry;
033import java.util.logging.Level;
034import java.util.logging.Logger;
035
036/**
037 * A class of (discrete) weighted information theoretic functions. Gives warnings if
038 * there are insufficient samples to estimate the quantities accurately.
039 * <p>
040 * Defaults to log_2, so returns values in bits.
041 * <p>
042 * All functions expect that the element types have well defined equals and
043 * hashcode, and that equals is consistent with hashcode. The behaviour is undefined
044 * if this is not true.
045 */
046public final class WeightedInformationTheory {
047    private static final Logger logger = Logger.getLogger(WeightedInformationTheory.class.getName());
048
049    /**
050     * The ratio of samples to symbols before emitting a warning.
051     */
052    public static final double SAMPLES_RATIO = 5.0;
053    /**
054     * The initial size of the various maps.
055     */
056    public static final int DEFAULT_MAP_SIZE = 20;
057    /**
058     * Log base 2.
059     */
060    public static final double LOG_2 = Math.log(2);
061    /**
062     * Log base e.
063     */
064    public static final double LOG_E = Math.log(Math.E);
065
066    /**
067     * Sets the base of the logarithm used in the information theoretic calculations.
068     * For LOG_2 the unit is "bit", for LOG_E the unit is "nat".
069     */
070    public static double LOG_BASE = LOG_2;
071
072    /**
073     * Chooses which variable is the one with associated weights.
074     */
075    public enum VariableSelector {
076        /**
077         * The first variable is weighted.
078         */
079        FIRST,
080        /**
081         * The second variable is weighted.
082         */
083        SECOND,
084        /**
085         * The third variable is weighted.
086         */
087        THIRD
088    }
089
090    /**
091     * Private constructor, only has static methods.
092     */
093    private WeightedInformationTheory() {}
094
095    /**
096     * Calculates the discrete weighted joint mutual information, using
097     * histogram probability estimators. Arrays must be the same length.
098     * @param <T1> Type contained in the first array.
099     * @param <T2> Type contained in the second array.
100     * @param <T3> Type contained in the target array.
101     * @param first An array of values.
102     * @param second Another array of values.
103     * @param target Target array of values.
104     * @param weights Array of weight values.
105     * @return The weighted mutual information I_w(first,second;joint)
106     */
107    public static <T1,T2,T3> double jointMI(List<T1> first, List<T2> second, List<T3> target, List<Double> weights) {
108        WeightedTripleDistribution<T1, T2, T3> tripleRV = WeightedTripleDistribution.constructFromLists(first, second, target, weights);
109
110        return jointMI(tripleRV);
111    }
112
113    /**
114     * Calculates the discrete weighted joint mutual information, using
115     * histogram probability estimators.
116     * @param tripleRV The weighted triple distribution.
117     * @param <T1> The first element type.
118     * @param <T2> The second element type.
119     * @param <T3> The third element type.
120     * @return The weighted mutual information I_w(first,second;joint)
121     */
122    public static <T1,T2,T3> double jointMI(WeightedTripleDistribution<T1,T2,T3> tripleRV) {
123        Map<CachedTriple<T1,T2,T3>, WeightCountTuple> jointCount = tripleRV.getJointCount();
124        Map<CachedPair<T1,T2>,WeightCountTuple> abCount = tripleRV.getABCount();
125        Map<T3,WeightCountTuple> cCount = tripleRV.getCCount();
126
127        double vectorLength = tripleRV.count;
128        double jmi = 0.0;
129        for (Entry<CachedTriple<T1,T2,T3>,WeightCountTuple> e : jointCount.entrySet()) {
130            double jointCurCount = e.getValue().count;
131            double jointCurWeight = e.getValue().weight;
132            double prob = jointCurCount / vectorLength;
133            CachedPair<T1,T2> pair = e.getKey().getAB();
134            double abCurCount = abCount.get(pair).count;
135            double cCurCount = cCount.get(e.getKey().getC()).count;
136
137            jmi += jointCurWeight * prob * Math.log((vectorLength*jointCurCount)/(abCurCount*cCurCount));
138        }
139        jmi /= LOG_BASE;
140
141        double stateRatio = vectorLength / jointCount.size();
142        if (stateRatio < SAMPLES_RATIO) {
143            logger.log(Level.INFO, "Joint MI estimate of {0} had samples/state ratio of {1}", new Object[]{jmi, stateRatio});
144        }
145        
146        return jmi;
147    }
148
149    /**
150     * Calculates the discrete weighted joint mutual information, using
151     * histogram probability estimators.
152     * @param rv The triple distribution.
153     * @param weights The weights for one of the variables.
154     * @param vs The weighted variable id.
155     * @param <T1> The first element type.
156     * @param <T2> The second element type.
157     * @param <T3> The third element type.
158     * @return The weighted mutual information I_w(first,second;joint)
159     */
160    public static <T1,T2,T3> double jointMI(TripleDistribution<T1,T2,T3> rv, Map<?,Double> weights, VariableSelector vs){
161        Double boxedWeight;
162        double vecLength = rv.count;
163        Map<CachedTriple<T1,T2,T3>,MutableLong> jointCount = rv.getJointCount();
164        Map<CachedPair<T1,T2>,MutableLong> abCount = rv.getABCount();
165        Map<T3,MutableLong> cCount = rv.getCCount();
166
167        double jmi = 0.0;
168        for (Entry<CachedTriple<T1,T2,T3>,MutableLong> e : jointCount.entrySet()) {
169            double jointCurCount = e.getValue().doubleValue();
170            double prob = jointCurCount / vecLength;
171            CachedPair<T1,T2> pair = new CachedPair<>(e.getKey().getA(),e.getKey().getB());
172            double abCurCount = abCount.get(pair).doubleValue();
173            double cCurCount = cCount.get(e.getKey().getC()).doubleValue();
174
175            double weight = 1.0;
176            switch (vs) {
177                case FIRST:
178                    boxedWeight = weights.get(e.getKey().getA());
179                    weight = boxedWeight == null ? 1.0 : boxedWeight;
180                    break;
181                case SECOND:
182                    boxedWeight = weights.get(e.getKey().getB());
183                    weight = boxedWeight == null ? 1.0 : boxedWeight;
184                    break;
185                case THIRD:
186                    boxedWeight = weights.get(e.getKey().getC());
187                    weight = boxedWeight == null ? 1.0 : boxedWeight;
188                    break;
189            }
190
191            jmi += weight * prob * Math.log((vecLength*jointCurCount)/(abCurCount*cCurCount));
192        }
193        jmi /= LOG_BASE;
194
195        double stateRatio = vecLength / jointCount.size();
196        if (stateRatio < SAMPLES_RATIO) {
197            logger.log(Level.INFO, "Joint MI estimate of {0} had samples/state ratio of {1}, with {2} observations and {3} states", new Object[]{jmi, stateRatio, vecLength, jointCount.size()});
198        }
199
200        return jmi;
201    }
202
203    /**
204     * Calculates the discrete weighted conditional mutual information, using
205     * histogram probability estimators. Arrays must be the same length.
206     * @param <T1> Type contained in the first array.
207     * @param <T2> Type contained in the second array.
208     * @param <T3> Type contained in the condition array.
209     * @param first An array of values.
210     * @param second Another array of values.
211     * @param condition Array to condition upon.
212     * @param weights Array of weight values.
213     * @return The weighted conditional mutual information I_w(first;second|condition)
214     */
215    public static <T1,T2,T3> double conditionalMI(List<T1> first, List<T2> second, List<T3> condition, List<Double> weights) {
216        if ((first.size() == second.size()) && (first.size() == condition.size()) && (first.size() == weights.size())) {
217            WeightedTripleDistribution<T1,T2,T3> tripleRV = WeightedTripleDistribution.constructFromLists(first, second, condition, weights);
218
219            return conditionalMI(tripleRV);
220        } else {
221            throw new IllegalArgumentException("Weighted Conditional Mutual Information requires four vectors the same length. first.size() = " + first.size() + ", second.size() = " + second.size() + ", condition.size() = " + condition.size() + ", weights.size() = "+ weights.size());
222        }
223    }
224
225    /**
226     * Calculates the discrete weighted conditional mutual information, using
227     * histogram probability estimators.
228     * @param tripleRV The weighted triple distribution.
229     * @param <T1> The first element type.
230     * @param <T2> The second element type.
231     * @param <T3> The condition element type.
232     * @return The weighted conditional mutual information I_w(first;second|condition)
233     */
234    public static <T1,T2,T3> double conditionalMI(WeightedTripleDistribution<T1,T2,T3> tripleRV) {
235        Map<CachedTriple<T1,T2,T3>,WeightCountTuple> jointCount = tripleRV.getJointCount();
236        Map<CachedPair<T1,T3>,WeightCountTuple> acCount = tripleRV.getACCount();
237        Map<CachedPair<T2,T3>,WeightCountTuple> bcCount = tripleRV.getBCCount();
238        Map<T3,WeightCountTuple> cCount = tripleRV.getCCount();
239
240        double vectorLength = tripleRV.count;
241        double cmi = 0.0;
242        for (Entry<CachedTriple<T1,T2,T3>,WeightCountTuple> e : jointCount.entrySet()) {
243            double weight = e.getValue().weight;
244            double jointCurCount = e.getValue().count;
245            double prob = jointCurCount / vectorLength;
246            CachedPair<T1,T3> acPair = e.getKey().getAC();
247            CachedPair<T2,T3> bcPair = e.getKey().getBC();
248            double acCurCount = acCount.get(acPair).count;
249            double bcCurCount = bcCount.get(bcPair).count;
250            double cCurCount = cCount.get(e.getKey().getC()).count;
251
252            cmi += weight * prob * Math.log((cCurCount*jointCurCount)/(acCurCount*bcCurCount));
253        }
254        cmi /= LOG_BASE;
255
256        double stateRatio = vectorLength / jointCount.size();
257        if (stateRatio < SAMPLES_RATIO) {
258            logger.log(Level.INFO, "Conditional MI estimate of {0} had samples/state ratio of {1}", new Object[]{cmi, stateRatio});
259        }
260
261        return cmi;
262    }
263
264    /**
265     * Calculates the discrete weighted conditional mutual information, using
266     * histogram probability estimators.
267     * @param rv The triple distribution.
268     * @param weights The element weights.
269     * @param vs The variable to apply the weights to.
270     * @param <T1> The first element type.
271     * @param <T2> The second element type.
272     * @param <T3> The condition element type.
273     * @return The weighted conditional mutual information I_w(first;second|condition)
274     */
275    public static <T1,T2,T3> double conditionalMI(TripleDistribution<T1,T2,T3> rv, Map<?,Double> weights, VariableSelector vs) {
276        Double boxedWeight;
277        Map<CachedTriple<T1,T2,T3>,MutableLong> jointCount = rv.getJointCount();
278        Map<CachedPair<T1,T3>,MutableLong> acCount = rv.getACCount();
279        Map<CachedPair<T2,T3>,MutableLong> bcCount = rv.getBCCount();
280        Map<T3,MutableLong> cCount = rv.getCCount();
281
282        double vectorLength = rv.count;
283        double cmi = 0.0;
284        for (Entry<CachedTriple<T1, T2, T3>, MutableLong> e : jointCount.entrySet()) {
285            double jointCurCount = e.getValue().doubleValue();
286            double prob = jointCurCount / vectorLength;
287            CachedPair<T1, T3> acPair = new CachedPair<>(e.getKey().getA(), e.getKey().getC());
288            CachedPair<T2, T3> bcPair = new CachedPair<>(e.getKey().getB(), e.getKey().getC());
289            double acCurCount = acCount.get(acPair).doubleValue();
290            double bcCurCount = bcCount.get(bcPair).doubleValue();
291            double cCurCount = cCount.get(e.getKey().getC()).doubleValue();
292
293            double weight = 1.0;
294            switch (vs) {
295                case FIRST:
296                    boxedWeight = weights.get(e.getKey().getA());
297                    weight = boxedWeight == null ? 1.0 : boxedWeight;
298                    break;
299                case SECOND:
300                    boxedWeight = weights.get(e.getKey().getB());
301                    weight = boxedWeight == null ? 1.0 : boxedWeight;
302                    break;
303                case THIRD:
304                    boxedWeight = weights.get(e.getKey().getC());
305                    weight = boxedWeight == null ? 1.0 : boxedWeight;
306                    break;
307            }
308            cmi += weight * prob * Math.log((cCurCount * jointCurCount) / (acCurCount * bcCurCount));
309        }
310        cmi /= LOG_BASE;
311
312        double stateRatio = vectorLength / jointCount.size();
313        if (stateRatio < SAMPLES_RATIO) {
314            logger.log(Level.INFO, "Conditional MI estimate of {0} had samples/state ratio of {1}", new Object[]{cmi, stateRatio});
315        }
316
317        return cmi;
318    }
319
320    /**
321     * Calculates the discrete weighted mutual information, using histogram
322     * probability estimators.
323     * <p>
324     * Arrays must be the same length.
325     * @param <T1> Type of the first array
326     * @param <T2> Type of the second array
327     * @param first An array of values
328     * @param second Another array of values
329     * @param weights Array of weight values.
330     * @return The weighted mutual information I_w(first;Second)
331     */
332    public static <T1,T2> double mi(ArrayList<T1> first, ArrayList<T2> second, ArrayList<Double> weights) {
333        if ((first.size() == second.size()) && (first.size() == weights.size())) {
334            WeightedPairDistribution<T1,T2> countPair = WeightedPairDistribution.constructFromLists(first,second,weights);
335            return mi(countPair);
336        } else {
337            throw new IllegalArgumentException("Weighted Mutual Information requires three vectors the same length. first.size() = " + first.size() + ", second.size() = " + second.size() + ", weights.size() = " + weights.size());
338        }
339    }
340
341    /**
342     * Calculates the discrete weighted mutual information, using histogram
343     * probability estimators.
344     * @param jointDist The weighted joint distribution.
345     * @param <T1> Type of the first element.
346     * @param <T2> Type of the second element.
347     * @return The weighted mutual information I_w(first;Second)
348     */
349    public static <T1,T2> double mi(WeightedPairDistribution<T1,T2> jointDist) {
350        double vectorLength = jointDist.count;
351        double mi = 0.0;
352        Map<CachedPair<T1,T2>,WeightCountTuple> countDist = jointDist.getJointCounts();
353        Map<T1,WeightCountTuple> firstCountDist = jointDist.getFirstCount();
354        Map<T2,WeightCountTuple> secondCountDist = jointDist.getSecondCount();
355
356        for (Entry<CachedPair<T1,T2>,WeightCountTuple> e : countDist.entrySet()) {
357            double weight = e.getValue().weight;
358            double jointCount = e.getValue().count;
359            double prob = jointCount / vectorLength;
360            double firstCount = firstCountDist.get(e.getKey().getA()).count;
361            double secondCount = secondCountDist.get(e.getKey().getB()).count;
362
363            mi += weight * prob * Math.log((vectorLength*jointCount)/(firstCount*secondCount));
364        }
365        mi /= LOG_BASE;
366
367        double stateRatio = vectorLength / countDist.size();
368        if (stateRatio < SAMPLES_RATIO) {
369            logger.log(Level.INFO, "MI estimate of {0} had samples/state ratio of {1}", new Object[]{mi, stateRatio});
370        }
371
372        return mi;
373    }
374
375    /**
376     * Calculates the discrete weighted mutual information, using histogram
377     * probability estimators.
378     * @param pairDist The joint distribution.
379     * @param weights The element weights.
380     * @param vs The variable to apply the weights to.
381     * @param <T1> Type of the first element.
382     * @param <T2> Type of the second element.
383     * @return The weighted mutual information I_w(first;Second)
384     */
385    public static <T1,T2> double mi(PairDistribution<T1,T2> pairDist, Map<?,Double> weights, VariableSelector vs) {
386        if (vs == VariableSelector.THIRD) {
387            throw new IllegalArgumentException("MI only has two variables");
388        }
389        Map<CachedPair<T1,T2>,MutableLong> countDist = pairDist.jointCounts;
390        Map<T1,MutableLong> firstCountDist = pairDist.firstCount;
391        Map<T2,MutableLong> secondCountDist = pairDist.secondCount;
392
393        Double boxedWeight;
394        double vectorLength = pairDist.count;
395        double mi = 0.0;
396        boolean error = false;
397        for (Entry<CachedPair<T1,T2>,MutableLong> e : countDist.entrySet()) {
398            double jointCount = e.getValue().doubleValue();
399            double prob = jointCount / vectorLength;
400            double firstProb = firstCountDist.get(e.getKey().getA()).doubleValue();
401            double secondProb = secondCountDist.get(e.getKey().getB()).doubleValue();
402
403            double top = vectorLength * jointCount;
404            double bottom = firstProb * secondProb;
405            double ratio = top/bottom;
406            double logRatio = Math.log(ratio);
407
408            if (Double.isNaN(logRatio) || Double.isNaN(prob) || Double.isNaN(mi)) {
409                logger.log(Level.WARNING, "State = " + e.getKey().toString());
410                logger.log(Level.WARNING, "mi = " + mi + " prob = " + prob + " top = " + top + " bottom = " + bottom + " ratio = " + ratio + " logRatio = " + logRatio);
411                error = true;
412            }
413
414            double weight = 1.0;
415            switch (vs) {
416                case FIRST:
417                    boxedWeight = weights.get(e.getKey().getA());
418                    weight = boxedWeight == null ? 1.0 : boxedWeight;
419                    break;
420                case SECOND:
421                    boxedWeight = weights.get(e.getKey().getB());
422                    weight = boxedWeight == null ? 1.0 : boxedWeight;
423                    break;
424                default:
425                    throw new IllegalArgumentException("VariableSelector.THIRD not allowed in a two variable calculation.");
426            }
427            mi += weight * prob * logRatio;
428            //mi += prob * Math.log((vectorLength*jointCount)/(firstProb*secondProb));
429        }
430        mi /= LOG_BASE;
431
432        double stateRatio = vectorLength / countDist.size();
433        if (stateRatio < SAMPLES_RATIO) {
434            logger.log(Level.INFO, "MI estimate of {0} had samples/state ratio of {1}", new Object[]{mi, stateRatio});
435        }
436
437        if (error) {
438            logger.log(Level.SEVERE, "NanFound ", new IllegalStateException("NaN found"));
439        }
440
441        return mi;
442    }
443
444    /**
445     * Calculates the Shannon/Guiasu weighted joint entropy of two arrays, 
446     * using histogram probability estimators. 
447     * <p>
448     * Arrays must be same length.
449     * @param <T1> Type of the first array.
450     * @param <T2> Type of the second array.
451     * @param first An array of values.
452     * @param second Another array of values.
453     * @param weights Array of weight values.
454     * @return The entropy H(first,second)
455     */
456    public static <T1,T2> double jointEntropy(ArrayList<T1> first, ArrayList<T2> second, ArrayList<Double> weights) {
457        if ((first.size() == second.size()) && (first.size() == weights.size())) {
458            double vectorLength = first.size();
459            double jointEntropy = 0.0;
460            
461            WeightedPairDistribution<T1,T2> pairDist = WeightedPairDistribution.constructFromLists(first,second,weights);
462            Map<CachedPair<T1,T2>,WeightCountTuple> countDist = pairDist.getJointCounts();
463
464            for (Entry<CachedPair<T1,T2>,WeightCountTuple> e : countDist.entrySet()) {
465                double prob = e.getValue().count / vectorLength;
466                double weight = e.getValue().weight;
467
468                jointEntropy -= weight * prob * Math.log(prob);
469            }
470            jointEntropy /= LOG_BASE;
471
472            double stateRatio = vectorLength / countDist.size();
473            if (stateRatio < SAMPLES_RATIO) {
474                logger.log(Level.INFO, "Weighted Joint Entropy estimate of {0} had samples/state ratio of {1}", new Object[]{jointEntropy, stateRatio});
475            }
476            
477            return jointEntropy;
478        } else {
479            throw new IllegalArgumentException("Weighted Joint Entropy requires three vectors the same length. first.size() = " + first.size() + ", second.size() = " + second.size() + ", weights.size() = " + weights.size());
480        }
481    }
482    
483    /**
484     * Calculates the discrete Shannon/Guiasu Weighted Conditional Entropy of 
485     * two arrays, using histogram probability estimators. 
486     * <p>
487     * Arrays must be the same length.
488     * @param <T1> Type of the first array.
489     * @param <T2> Type of the second array.
490     * @param vector The main array of values.
491     * @param condition The array to condition on.
492     * @param weights Array of weight values.
493     * @return The weighted conditional entropy H_w(vector|condition).
494     */
495    public static <T1,T2> double weightedConditionalEntropy(ArrayList<T1> vector, ArrayList<T2> condition, ArrayList<Double> weights) {
496        if ((vector.size() == condition.size()) && (vector.size() == weights.size())) {
497            double vectorLength = vector.size();
498            double condEntropy = 0.0;
499            
500            WeightedPairDistribution<T1,T2> pairDist = WeightedPairDistribution.constructFromLists(vector,condition,weights);
501            Map<CachedPair<T1,T2>,WeightCountTuple> countDist = pairDist.getJointCounts();
502            Map<T2,WeightCountTuple> conditionCountDist = pairDist.getSecondCount();
503
504            for (Entry<CachedPair<T1,T2>,WeightCountTuple> e : countDist.entrySet()) {
505                double prob = e.getValue().count / vectorLength;
506                double condProb = conditionCountDist.get(e.getKey().getB()).count / vectorLength;
507                double weight = e.getValue().weight;
508
509                condEntropy -= weight * prob * Math.log(prob/condProb);
510            }
511            condEntropy /= LOG_BASE;
512
513            double stateRatio = vectorLength / countDist.size();
514            if (stateRatio < SAMPLES_RATIO) {
515                logger.log(Level.INFO, "Weighted Conditional Entropy estimate of {0} had samples/state ratio of {1}", new Object[]{condEntropy, stateRatio});
516            }
517            
518            return condEntropy;
519        } else {
520            throw new IllegalArgumentException("Weighted Conditional Entropy requires three vectors the same length. vector.size() = " + vector.size() + ", condition.size() = " + condition.size() + ", weights.size() = " + weights.size());
521        }
522    }
523
524    /**
525     * Calculates the discrete Shannon/Guiasu Weighted Entropy, using histogram 
526     * probability estimators.
527     * @param <T> Type of the array.
528     * @param vector The array of values.
529     * @param weights Array of weight values.
530     * @return The weighted entropy H_w(vector).
531     */
532    public static <T> double weightedEntropy(ArrayList<T> vector, ArrayList<Double> weights) {
533        if (vector.size() == weights.size()) {
534            double vectorLength = vector.size();
535            double entropy = 0.0;
536
537            Map<T,WeightCountTuple> countDist = calculateWeightedCountDist(vector,weights);
538            for (Entry<T,WeightCountTuple> e : countDist.entrySet()) {
539                long count = e.getValue().count;
540                double weight = e.getValue().weight;
541                double prob = count / vectorLength;
542                entropy -= weight * prob * Math.log(prob);
543            }
544            entropy /= LOG_BASE;
545
546            double stateRatio = vectorLength / countDist.size();
547            if (stateRatio < SAMPLES_RATIO) {
548                logger.log(Level.INFO, "Weighted Entropy estimate of {0} had samples/state ratio of {1}", new Object[]{entropy, stateRatio});
549            }
550            
551            return entropy;
552        } else {
553            throw new IllegalArgumentException("Weighted Entropy requires two vectors the same length. vector.size() = " + vector.size() + ",weights.size() = " + weights.size());
554        }
555    }
556
557    /**
558     * Generate the counts for a single vector.
559     * @param <T> The type inside the vector.
560     * @param vector An array of values.
561     * @param weights The array of weight values.
562     * @return A HashMap from states of T to Pairs of count and total weight for that state.
563     */
564    public static <T> Map<T,WeightCountTuple> calculateWeightedCountDist(ArrayList<T> vector, ArrayList<Double> weights) {
565        Map<T,WeightCountTuple> dist = new LinkedHashMap<>(DEFAULT_MAP_SIZE);
566        for (int i = 0; i < vector.size(); i++) {
567            T e = vector.get(i);
568            Double weight = weights.get(i);
569            WeightCountTuple curVal = dist.computeIfAbsent(e,(k) -> new WeightCountTuple());
570            curVal.count += 1;
571            curVal.weight += weight;
572        }
573
574        normaliseWeights(dist);
575
576        return dist;
577    }
578
579    /**
580     * Normalizes the weights in the map, i.e., divides each weight by it's count.
581     * @param map The map to normalize.
582     * @param <T> The type of the variable that was counted.
583     */
584    public static <T> void normaliseWeights(Map<T,WeightCountTuple> map) {
585        for (Entry<T,WeightCountTuple> e : map.entrySet()) {
586            WeightCountTuple tuple = e.getValue();
587            tuple.weight /= tuple.count;
588        }
589    }
590    
591}