001/*
002 * Copyright (c) 2015-2020, Oracle and/or its affiliates. All rights reserved.
003 *
004 * Licensed under the Apache License, Version 2.0 (the "License");
005 * you may not use this file except in compliance with the License.
006 * You may obtain a copy of the License at
007 *
008 *     http://www.apache.org/licenses/LICENSE-2.0
009 *
010 * Unless required by applicable law or agreed to in writing, software
011 * distributed under the License is distributed on an "AS IS" BASIS,
012 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
013 * See the License for the specific language governing permissions and
014 * limitations under the License.
015 */
016
017package org.tribuo.util.infotheory.impl;
018
019import org.tribuo.util.infotheory.InformationTheory;
020import org.tribuo.util.infotheory.WeightedInformationTheory;
021
022import java.util.LinkedHashMap;
023import java.util.List;
024import java.util.Map;
025import java.util.Map.Entry;
026
027/**
028 * Generates the counts for a pair of vectors. Contains the joint
029 * count and the two marginal counts.
030 * @param <T1> Type of the first list.
031 * @param <T2> Type of the second list.
032 */
033public class WeightedPairDistribution<T1,T2> {
034
035    /**
036     * The sample count.
037     */
038    public final long count;
039
040    private final Map<CachedPair<T1,T2>,WeightCountTuple> jointCounts;
041    private final Map<T1,WeightCountTuple> firstCount;
042    private final Map<T2,WeightCountTuple> secondCount;
043
044    /**
045     * Constructs a weighted pair distribution from the supplied values.
046     * <p>
047     * Copies the maps out into LinkedHashMaps for iteration speed.
048     * @param count The sample count.
049     * @param jointCounts The joint distribution.
050     * @param firstCount The first marginal distribution.
051     * @param secondCount The second marginal distribution.
052     */
053    public WeightedPairDistribution(long count, Map<CachedPair<T1,T2>,WeightCountTuple> jointCounts, Map<T1,WeightCountTuple> firstCount, Map<T2,WeightCountTuple> secondCount) {
054        this.count = count;
055        this.jointCounts = new LinkedHashMap<>(jointCounts);
056        this.firstCount = new LinkedHashMap<>(firstCount);
057        this.secondCount = new LinkedHashMap<>(secondCount);
058    }
059
060    /**
061     * Constructs a weighted pair distribution from the supplied values.
062     * @param count The sample count.
063     * @param jointCounts The joint distribution.
064     * @param firstCount The first marginal distribution.
065     * @param secondCount The second marginal distribution.
066     */
067    public WeightedPairDistribution(long count, LinkedHashMap<CachedPair<T1,T2>,WeightCountTuple> jointCounts, LinkedHashMap<T1,WeightCountTuple> firstCount, LinkedHashMap<T2,WeightCountTuple> secondCount) {
068        this.count = count;
069        this.jointCounts = jointCounts;
070        this.firstCount = firstCount;
071        this.secondCount = secondCount;
072    }
073
074    /**
075     * Gets the joint distribution.
076     * @return The joint distribution.
077     */
078    public Map<CachedPair<T1,T2>,WeightCountTuple> getJointCounts() {
079        return jointCounts;
080    }
081
082    /**
083     * Gets the first marginal distribution.
084     * @return The first marginal distribution.
085     */
086    public Map<T1,WeightCountTuple> getFirstCount() {
087        return firstCount;
088    }
089
090    /**
091     * Gets the second marginal distribution.
092     * @return The second marginal distribution.
093     */
094    public Map<T2,WeightCountTuple> getSecondCount() {
095        return secondCount;
096    }
097    
098    /**
099     * Generates the counts for two vectors. Returns a pair containing the joint
100     * count, and a pair of the two marginal counts.
101     * @param <T1> Type of the first list.
102     * @param <T2> Type of the second list.
103     * @param first An list of values.
104     * @param second Another list of values.
105     * @param weights An list of per example weights.
106     * @return A WeightedPairDistribution.
107     */
108    public static <T1,T2> WeightedPairDistribution<T1,T2> constructFromLists(List<T1> first, List<T2> second, List<Double> weights) {
109        LinkedHashMap<CachedPair<T1,T2>,WeightCountTuple> countDist = new LinkedHashMap<>(InformationTheory.DEFAULT_MAP_SIZE);
110        LinkedHashMap<T1,WeightCountTuple> aCountDist = new LinkedHashMap<>(InformationTheory.DEFAULT_MAP_SIZE);
111        LinkedHashMap<T2,WeightCountTuple> bCountDist = new LinkedHashMap<>(InformationTheory.DEFAULT_MAP_SIZE);
112
113        if ((first.size() == second.size()) && (first.size() == weights.size())) {
114            long count = 0;
115            for (int i = 0; i < first.size(); i++) {
116                T1 a = first.get(i);
117                T2 b = second.get(i);
118                double weight = weights.get(i);
119                CachedPair<T1,T2> pair = new CachedPair<>(a,b);
120
121                WeightCountTuple abCurCount = countDist.computeIfAbsent(pair,(k) -> new WeightCountTuple());
122                abCurCount.weight += weight;
123                abCurCount.count++;
124
125                WeightCountTuple aCurCount = aCountDist.computeIfAbsent(a,(k) -> new WeightCountTuple());
126                aCurCount.weight += weight;
127                aCurCount.count++;
128
129                WeightCountTuple bCurCount = bCountDist.computeIfAbsent(b,(k) -> new WeightCountTuple());
130                bCurCount.weight += weight;
131                bCurCount.count++;
132
133                count++;
134            }
135
136            WeightedInformationTheory.normaliseWeights(countDist);
137            WeightedInformationTheory.normaliseWeights(aCountDist);
138            WeightedInformationTheory.normaliseWeights(bCountDist);
139
140            return new WeightedPairDistribution<>(count,countDist,aCountDist,bCountDist);
141        } else {
142            throw new IllegalArgumentException("Counting requires lists of the same length. first.size() = " + first.size() + ", second.size() = " + second.size() + ", weights.size() = " + weights.size());
143        }
144    }
145
146    /**
147     * Generates a WeightedPairDistribution by generating the marginal distributions for the first and second elements.
148     * This assumes the weights have already been normalised.
149     * @param <T1> Type of the first element.
150     * @param <T2> Type of the second element.
151     * @param jointCount The (normalised) input map.
152     * @return A WeightedPairDistribution
153     */
154    public static <T1,T2> WeightedPairDistribution<T1,T2> constructFromMap(Map<CachedPair<T1,T2>,WeightCountTuple> jointCount) {
155        LinkedHashMap<CachedPair<T1,T2>,WeightCountTuple> countDist = new LinkedHashMap<>(jointCount);
156        LinkedHashMap<T1,WeightCountTuple> aCountDist = new LinkedHashMap<>(InformationTheory.DEFAULT_MAP_SIZE);
157        LinkedHashMap<T2,WeightCountTuple> bCountDist = new LinkedHashMap<>(InformationTheory.DEFAULT_MAP_SIZE);
158
159        long count = 0L;
160        
161        for (Entry<CachedPair<T1,T2>,WeightCountTuple> e : countDist.entrySet()) {
162            CachedPair<T1,T2> pair = e.getKey();
163            WeightCountTuple tuple = e.getValue();
164            T1 a = pair.getA();
165            T2 b = pair.getB();
166            double weight = tuple.weight * tuple.count;
167
168            WeightCountTuple aCurCount = aCountDist.computeIfAbsent(a,(k) -> new WeightCountTuple());
169            aCurCount.weight += weight;
170            aCurCount.count += tuple.count;
171
172            WeightCountTuple bCurCount = bCountDist.computeIfAbsent(b,(k) -> new WeightCountTuple());
173            bCurCount.weight += weight;
174            bCurCount.count += tuple.count;
175
176            count += tuple.count;
177        }
178
179        WeightedInformationTheory.normaliseWeights(aCountDist);
180        WeightedInformationTheory.normaliseWeights(bCountDist);
181
182        return new WeightedPairDistribution<>(count,countDist,aCountDist,bCountDist);
183    }
184
185}