001/*
002 * Copyright (c) 2015-2020, Oracle and/or its affiliates. All rights reserved.
003 *
004 * Licensed under the Apache License, Version 2.0 (the "License");
005 * you may not use this file except in compliance with the License.
006 * You may obtain a copy of the License at
007 *
008 *     http://www.apache.org/licenses/LICENSE-2.0
009 *
010 * Unless required by applicable law or agreed to in writing, software
011 * distributed under the License is distributed on an "AS IS" BASIS,
012 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
013 * See the License for the specific language governing permissions and
014 * limitations under the License.
015 */
016
017package org.tribuo.util.infotheory.impl;
018
019import org.tribuo.util.infotheory.WeightedInformationTheory;
020
021import java.util.HashMap;
022import java.util.List;
023import java.util.Map;
024import java.util.Map.Entry;
025
026/**
027 * Generates the counts for a triplet of vectors. Contains the joint
028 * count, the three pairwise counts, and the three marginal counts.
029 * @param <T1> Type of the first list.
030 * @param <T2> Type of the second list.
031 * @param <T3> Type of the third list.
032 */
033public class WeightedTripleDistribution<T1,T2,T3> {
034    /**
035     * The default map size.
036     */
037    public static final int DEFAULT_MAP_SIZE = 20;
038
039    /**
040     * The sample count.
041     */
042    public final long count;
043
044    private final Map<CachedTriple<T1,T2,T3>,WeightCountTuple> jointCount;
045    private final Map<CachedPair<T1,T2>,WeightCountTuple> abCount;
046    private final Map<CachedPair<T1,T3>,WeightCountTuple> acCount;
047    private final Map<CachedPair<T2,T3>,WeightCountTuple> bcCount;
048    private final Map<T1,WeightCountTuple> aCount;
049    private final Map<T2,WeightCountTuple> bCount;
050    private final Map<T3,WeightCountTuple> cCount;
051
052    /**
053     * Constructs a weighted triple distribution from the supplied values.
054     * @param count The sample count.
055     * @param jointCount The ABC joint distribution.
056     * @param abCount The AB joint distribution.
057     * @param acCount The AC joint distribution.
058     * @param bcCount The BC joint distribution.
059     * @param aCount The A marginal distribution.
060     * @param bCount The B marginal distribution.
061     * @param cCount The C marginal distribution.
062     */
063    public WeightedTripleDistribution(long count, Map<CachedTriple<T1,T2,T3>,WeightCountTuple> jointCount, Map<CachedPair<T1,T2>,WeightCountTuple> abCount, Map<CachedPair<T1,T3>,WeightCountTuple> acCount, Map<CachedPair<T2,T3>,WeightCountTuple> bcCount, Map<T1,WeightCountTuple> aCount, Map<T2,WeightCountTuple> bCount, Map<T3,WeightCountTuple> cCount) {
064        this.count = count;
065        this.jointCount = jointCount;
066        this.abCount = abCount;
067        this.acCount = acCount;
068        this.bcCount = bcCount;
069        this.aCount = aCount;
070        this.bCount = bCount;
071        this.cCount = cCount;
072    }
073
074    /**
075     * The joint distribution over the three variables.
076     * @return The joint distribution.
077     */
078    public Map<CachedTriple<T1,T2,T3>,WeightCountTuple> getJointCount() {
079        return jointCount;
080    }
081
082    /**
083     * The joint distribution over the first and second variables.
084     * @return The joint distribution over A and B.
085     */
086    public Map<CachedPair<T1,T2>,WeightCountTuple> getABCount() {
087        return abCount;
088    }
089
090    /**
091     * The joint distribution over the first and third variables.
092     * @return The joint distribution over A and C.
093     */
094    public Map<CachedPair<T1,T3>,WeightCountTuple> getACCount() {
095        return acCount;
096    }
097
098    /**
099     * The joint distribution over the second and third variables.
100     * @return The joint distribution over B and C.
101     */
102    public Map<CachedPair<T2,T3>,WeightCountTuple> getBCCount() {
103        return bcCount;
104    }
105
106    /**
107     * The marginal distribution over the first variable.
108     * @return The marginal distribution for A.
109     */
110    public Map<T1,WeightCountTuple> getACount() {
111        return aCount;
112    }
113
114    /**
115     * The marginal distribution over the second variable.
116     * @return The marginal distribution for B.
117     */
118    public Map<T2,WeightCountTuple> getBCount() {
119        return bCount;
120    }
121
122    /**
123     * The marginal distribution over the third variable.
124     * @return The marginal distribution for C.
125     */
126    public Map<T3,WeightCountTuple> getCCount() {
127        return cCount;
128    }
129
130    /**
131     * Constructs a WeightedTripleDistribution from three lists of the same length and a list of weights of the same length.
132     * <p>
133     * If they are not the same length it throws IllegalArgumentException.
134     * @param first The first list.
135     * @param second The second list.
136     * @param third The third list.
137     * @param weights The weight list.
138     * @param <T1> The first type.
139     * @param <T2> The second type.
140     * @param <T3> The third type.
141     * @return The WeightedTripleDistribution.
142     */
143    public static <T1,T2,T3> WeightedTripleDistribution<T1,T2,T3> constructFromLists(List<T1> first, List<T2> second, List<T3> third, List<Double> weights) {
144        Map<CachedTriple<T1,T2,T3>,WeightCountTuple> jointCount = new HashMap<>(DEFAULT_MAP_SIZE);
145        Map<CachedPair<T1,T2>,WeightCountTuple> abCount = new HashMap<>(DEFAULT_MAP_SIZE);
146        Map<CachedPair<T1,T3>,WeightCountTuple> acCount = new HashMap<>(DEFAULT_MAP_SIZE);
147        Map<CachedPair<T2,T3>,WeightCountTuple> bcCount = new HashMap<>(DEFAULT_MAP_SIZE);
148        Map<T1,WeightCountTuple> aCount = new HashMap<>(DEFAULT_MAP_SIZE);
149        Map<T2,WeightCountTuple> bCount = new HashMap<>(DEFAULT_MAP_SIZE);
150        Map<T3,WeightCountTuple> cCount = new HashMap<>(DEFAULT_MAP_SIZE);
151
152        long count = first.size();
153
154        if ((first.size() == second.size()) && (first.size() == third.size()) && (first.size() == weights.size())) {
155            for (int i = 0; i < first.size(); i++) {
156                double weight = weights.get(i);
157                T1 a = first.get(i);
158                T2 b = second.get(i);
159                T3 c = third.get(i);
160                CachedTriple<T1,T2,T3> triple = new CachedTriple<>(a,b,c);
161                CachedPair<T1,T2> abPair = triple.getAB();
162                CachedPair<T1,T3> acPair = triple.getAC();
163                CachedPair<T2,T3> bcPair = triple.getBC();
164
165                WeightCountTuple abcCurCount = jointCount.computeIfAbsent(triple,(k) -> new WeightCountTuple());
166                abcCurCount.weight += weight;
167                abcCurCount.count++;
168
169                WeightCountTuple abCurCount = abCount.computeIfAbsent(abPair,(k) -> new WeightCountTuple());
170                abCurCount.weight += weight;
171                abCurCount.count++;
172
173                WeightCountTuple acCurCount = acCount.computeIfAbsent(acPair,(k) -> new WeightCountTuple());
174                acCurCount.weight += weight;
175                acCurCount.count++;
176                
177                WeightCountTuple bcCurCount = bcCount.computeIfAbsent(bcPair,(k) -> new WeightCountTuple());
178                bcCurCount.weight += weight;
179                bcCurCount.count++;
180                
181                WeightCountTuple aCurCount = aCount.computeIfAbsent(a,(k) -> new WeightCountTuple());
182                aCurCount.weight += weight;
183                aCurCount.count++;
184
185                WeightCountTuple bCurCount = bCount.computeIfAbsent(b,(k) -> new WeightCountTuple());
186                bCurCount.weight += weight;
187                bCurCount.count++;
188
189                WeightCountTuple cCurCount = cCount.computeIfAbsent(c,(k) -> new WeightCountTuple());
190                cCurCount.weight += weight;
191                cCurCount.count++;
192            }
193
194            WeightedInformationTheory.normaliseWeights(jointCount);
195            WeightedInformationTheory.normaliseWeights(abCount);
196            WeightedInformationTheory.normaliseWeights(acCount);
197            WeightedInformationTheory.normaliseWeights(bcCount);
198            WeightedInformationTheory.normaliseWeights(aCount);
199            WeightedInformationTheory.normaliseWeights(bCount);
200            WeightedInformationTheory.normaliseWeights(cCount);
201
202            return new WeightedTripleDistribution<>(count,jointCount,abCount,acCount,bcCount,aCount,bCount,cCount);
203        } else {
204            throw new IllegalArgumentException("Counting requires lists of the same length. first.size() = " + first.size() + ", second.size() = " + second.size() + ", third.size() = " + third.size() + ", weights.size() = " + weights.size());
205        }
206    }
207
208    /**
209     * Constructs a WeightedTripleDistribution by marginalising the supplied joint distribution.
210     * @param jointCount The joint distribution.
211     * @param <T1> The type of A.
212     * @param <T2> The type of B.
213     * @param <T3> The type of C.
214     * @return A WeightedTripleDistribution.
215     */
216    public static <T1,T2,T3> WeightedTripleDistribution<T1,T2,T3> constructFromMap(Map<CachedTriple<T1,T2,T3>,WeightCountTuple> jointCount) {
217        Map<CachedPair<T1,T2>,WeightCountTuple> abCount = new HashMap<>(DEFAULT_MAP_SIZE);
218        Map<CachedPair<T1,T3>,WeightCountTuple> acCount = new HashMap<>(DEFAULT_MAP_SIZE);
219        Map<CachedPair<T2,T3>,WeightCountTuple> bcCount = new HashMap<>(DEFAULT_MAP_SIZE);
220        Map<T1,WeightCountTuple> aCount = new HashMap<>(DEFAULT_MAP_SIZE);
221        Map<T2,WeightCountTuple> bCount = new HashMap<>(DEFAULT_MAP_SIZE);
222        Map<T3,WeightCountTuple> cCount = new HashMap<>(DEFAULT_MAP_SIZE);
223        
224        long count = 0L;
225
226        for (Entry<CachedTriple<T1,T2,T3>,WeightCountTuple> e : jointCount.entrySet()) {
227            CachedTriple<T1,T2,T3> triple = e.getKey();
228            WeightCountTuple tuple = e.getValue();
229            CachedPair<T1,T2> abPair = triple.getAB();
230            CachedPair<T1,T3> acPair = triple.getAC();
231            CachedPair<T2,T3> bcPair = triple.getBC();
232            T1 a = triple.getA();
233            T2 b = triple.getB();
234            T3 c = triple.getC();
235
236            count += tuple.count;
237
238            double weight = tuple.weight * tuple.count;
239
240            WeightCountTuple abCurCount = abCount.computeIfAbsent(abPair,(k) -> new WeightCountTuple());
241            abCurCount.weight += weight;
242            abCurCount.count += tuple.count;
243
244            WeightCountTuple acCurCount = acCount.computeIfAbsent(acPair,(k) -> new WeightCountTuple());
245            acCurCount.weight += weight;
246            acCurCount.count += tuple.count;
247
248            WeightCountTuple bcCurCount = bcCount.computeIfAbsent(bcPair,(k) -> new WeightCountTuple());
249            bcCurCount.weight += weight;
250            bcCurCount.count += tuple.count;
251
252            WeightCountTuple aCurCount = aCount.computeIfAbsent(a,(k) -> new WeightCountTuple());
253            aCurCount.weight += weight;
254            aCurCount.count += tuple.count;
255
256            WeightCountTuple bCurCount = bCount.computeIfAbsent(b,(k) -> new WeightCountTuple());
257            bCurCount.weight += weight;
258            bCurCount.count += tuple.count;
259
260            WeightCountTuple cCurCount = cCount.computeIfAbsent(c,(k) -> new WeightCountTuple());
261            cCurCount.weight += weight;
262            cCurCount.count += tuple.count;
263        }
264
265        WeightedInformationTheory.normaliseWeights(abCount);
266        WeightedInformationTheory.normaliseWeights(acCount);
267        WeightedInformationTheory.normaliseWeights(bcCount);
268        WeightedInformationTheory.normaliseWeights(aCount);
269        WeightedInformationTheory.normaliseWeights(bCount);
270        WeightedInformationTheory.normaliseWeights(cCount);
271
272        return new WeightedTripleDistribution<>(count,jointCount,abCount,acCount,bcCount,aCount,bCount,cCount);
273    }
274    
275}