001/*
002 * Copyright (c) 2015-2020, Oracle and/or its affiliates. All rights reserved.
003 *
004 * Licensed under the Apache License, Version 2.0 (the "License");
005 * you may not use this file except in compliance with the License.
006 * You may obtain a copy of the License at
007 *
008 *     http://www.apache.org/licenses/LICENSE-2.0
009 *
010 * Unless required by applicable law or agreed to in writing, software
011 * distributed under the License is distributed on an "AS IS" BASIS,
012 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
013 * See the License for the specific language governing permissions and
014 * limitations under the License.
015 */
016
017package org.tribuo.util.infotheory.impl;
018
019import com.oracle.labs.mlrg.olcut.util.MutableLong;
020
021import java.util.HashMap;
022import java.util.LinkedHashMap;
023import java.util.List;
024import java.util.Map;
025import java.util.Map.Entry;
026
027/**
028 * Generates the counts for a triplet of vectors. Contains the joint
029 * count, the three pairwise counts, and the three marginal counts.
030 * <p>
031 * Relies upon hashCode and equals to determine element equality for counting.
032 * @param <T1> Type of the first list.
033 * @param <T2> Type of the second list.
034 * @param <T3> Type of the third list.
035 */
036public class TripleDistribution<T1,T2,T3> {
037    /**
038     * The default map size to initialise the marginalised count maps with.
039     */
040    public static final int DEFAULT_MAP_SIZE = 20;
041
042    /**
043     * The number of samples in this distribution.
044     */
045    public final long count;
046
047    private final Map<CachedTriple<T1,T2,T3>,MutableLong> jointCount;
048    private final Map<CachedPair<T1,T2>,MutableLong> abCount;
049    private final Map<CachedPair<T1,T3>,MutableLong> acCount;
050    private final Map<CachedPair<T2,T3>,MutableLong> bcCount;
051    private final Map<T1,MutableLong> aCount;
052    private final Map<T2,MutableLong> bCount;
053    private final Map<T3,MutableLong> cCount;
054
055    /**
056     * Constructs a triple distribution from the supplied distributions.
057     * @param count The sample count.
058     * @param jointCount The joint distribution over the three variables.
059     * @param abCount The joint distribution over A and B.
060     * @param acCount The joint distribution over A and C.
061     * @param bcCount The joint distribution over B and C.
062     * @param aCount The marginal distribution over A.
063     * @param bCount The marginal distribution over B.
064     * @param cCount The marginal distribution over C.
065     */
066    public TripleDistribution(long count, Map<CachedTriple<T1,T2,T3>,MutableLong> jointCount, Map<CachedPair<T1,T2>,MutableLong> abCount, Map<CachedPair<T1,T3>,MutableLong> acCount, Map<CachedPair<T2,T3>,MutableLong> bcCount, Map<T1,MutableLong> aCount, Map<T2,MutableLong> bCount, Map<T3,MutableLong> cCount) {
067        this.count = count;
068        this.jointCount = jointCount;
069        this.abCount = abCount;
070        this.acCount = acCount;
071        this.bcCount = bcCount;
072        this.aCount = aCount;
073        this.bCount = bCount;
074        this.cCount = cCount;
075    }
076
077    /**
078     * The joint distribution over the three variables.
079     * @return The joint distribution.
080     */
081    public Map<CachedTriple<T1,T2,T3>,MutableLong> getJointCount() {
082        return jointCount;
083    }
084
085    /**
086     * The joint distribution over the first and second variables.
087     * @return The joint distribution over A and B.
088     */
089    public Map<CachedPair<T1,T2>,MutableLong> getABCount() {
090        return abCount;
091    }
092
093    /**
094     * The joint distribution over the first and third variables.
095     * @return The joint distribution over A and C.
096     */
097    public Map<CachedPair<T1,T3>,MutableLong> getACCount() {
098        return acCount;
099    }
100
101    /**
102     * The joint distribution over the second and third variables.
103     * @return The joint distribution over B and C.
104     */
105    public Map<CachedPair<T2,T3>,MutableLong> getBCCount() {
106        return bcCount;
107    }
108
109    /**
110     * The marginal distribution over the first variable.
111     * @return The marginal distribution for A.
112     */
113    public Map<T1,MutableLong> getACount() {
114        return aCount;
115    }
116
117    /**
118     * The marginal distribution over the second variable.
119     * @return The marginal distribution for B.
120     */
121    public Map<T2,MutableLong> getBCount() {
122        return bCount;
123    }
124
125    /**
126     * The marginal distribution over the third variable.
127     * @return The marginal distribution for C.
128     */
129    public Map<T3,MutableLong> getCCount() {
130        return cCount;
131    }
132
133    /**
134     * Constructs a TripleDistribution from three lists of the same length.
135     * <p>
136     * If they are not the same length it throws IllegalArgumentException.
137     * @param first The first list.
138     * @param second The second list.
139     * @param third The third list.
140     * @param <T1> The first type.
141     * @param <T2> The second type.
142     * @param <T3> The third type.
143     * @return The TripleDistribution.
144     */
145    public static <T1,T2,T3> TripleDistribution<T1,T2,T3> constructFromLists(List<T1> first, List<T2> second, List<T3> third) {
146        Map<CachedTriple<T1,T2,T3>,MutableLong> jointCount = new LinkedHashMap<>(DEFAULT_MAP_SIZE);
147        Map<CachedPair<T1,T2>,MutableLong> abCount = new HashMap<>(DEFAULT_MAP_SIZE);
148        Map<CachedPair<T1,T3>,MutableLong> acCount = new HashMap<>(DEFAULT_MAP_SIZE);
149        Map<CachedPair<T2,T3>,MutableLong> bcCount = new HashMap<>(DEFAULT_MAP_SIZE);
150        Map<T1,MutableLong> aCount = new HashMap<>(DEFAULT_MAP_SIZE);
151        Map<T2,MutableLong> bCount = new HashMap<>(DEFAULT_MAP_SIZE);
152        Map<T3,MutableLong> cCount = new HashMap<>(DEFAULT_MAP_SIZE);
153
154        long count = first.size();
155
156        if ((first.size() == second.size()) && (first.size() == third.size())) {
157            for (int i = 0; i < first.size(); i++) {
158                T1 a = first.get(i);
159                T2 b = second.get(i);
160                T3 c = third.get(i);
161                CachedTriple<T1,T2,T3> abc = new CachedTriple<>(a,b,c);
162                CachedPair<T1,T2> ab = abc.getAB();
163                CachedPair<T1,T3> ac = abc.getAC();
164                CachedPair<T2,T3> bc = abc.getBC();
165
166                MutableLong abcCurCount = jointCount.computeIfAbsent(abc, k -> new MutableLong());
167                abcCurCount.increment();
168
169                MutableLong abCurCount = abCount.computeIfAbsent(ab, k -> new MutableLong());
170                abCurCount.increment();
171                MutableLong acCurCount = acCount.computeIfAbsent(ac, k -> new MutableLong());
172                acCurCount.increment();
173                MutableLong bcCurCount = bcCount.computeIfAbsent(bc, k -> new MutableLong());
174                bcCurCount.increment();
175
176                MutableLong aCurCount = aCount.computeIfAbsent(a, k -> new MutableLong());
177                aCurCount.increment();
178                MutableLong bCurCount = bCount.computeIfAbsent(b, k -> new MutableLong());
179                bCurCount.increment();
180                MutableLong cCurCount = cCount.computeIfAbsent(c, k -> new MutableLong());
181                cCurCount.increment();
182            }
183
184            return new TripleDistribution<>(count,jointCount,abCount,acCount,bcCount,aCount,bCount,cCount);
185        } else {
186            throw new IllegalArgumentException("Counting requires lists of the same length. first.size() = " + first.size() + ", second.size() = " + second.size() + ", third.size() = " + third.size());
187        }
188    }
189
190    /**
191     * Constructs a TripleDistribution by marginalising the supplied joint distribution.
192     * @param jointCount The joint distribution.
193     * @param <T1> The type of A.
194     * @param <T2> The type of B.
195     * @param <T3> The type of C.
196     * @return A TripleDistribution.
197     */
198    public static <T1,T2,T3> TripleDistribution<T1,T2,T3> constructFromMap(Map<CachedTriple<T1,T2,T3>,MutableLong> jointCount) {
199        Map<CachedPair<T1,T2>,MutableLong> abCount = new HashMap<>(DEFAULT_MAP_SIZE);
200        Map<CachedPair<T1,T3>,MutableLong> acCount = new HashMap<>(DEFAULT_MAP_SIZE);
201        Map<CachedPair<T2,T3>,MutableLong> bcCount = new HashMap<>(DEFAULT_MAP_SIZE);
202        Map<T1,MutableLong> aCount = new HashMap<>(DEFAULT_MAP_SIZE);
203        Map<T2,MutableLong> bCount = new HashMap<>(DEFAULT_MAP_SIZE);
204        Map<T3,MutableLong> cCount = new HashMap<>(DEFAULT_MAP_SIZE);
205
206        return constructFromMap(jointCount,abCount,acCount,bcCount,aCount,bCount,cCount);
207    }
208
209    /**
210     * Constructs a TripleDistribution by marginalising the supplied joint distribution.
211     * <p>
212     * Sizes are used to preallocate the HashMaps.
213     * @param jointCount The joint distribution.
214     * @param abSize The number of unique AB states.
215     * @param acSize The number of unique AC states.
216     * @param bcSize The number of unique BC states.
217     * @param aSize The number of unique A states.
218     * @param bSize The number of unique B states.
219     * @param cSize The number of unique C states.
220     * @param <T1> The type of A.
221     * @param <T2> The type of B.
222     * @param <T3> The type of C.
223     * @return A TripleDistribution.
224     */
225    public static <T1,T2,T3> TripleDistribution<T1,T2,T3> constructFromMap(Map<CachedTriple<T1,T2,T3>,MutableLong> jointCount,
226                                                                           int abSize, int acSize, int bcSize,
227                                                                           int aSize, int bSize, int cSize) {
228        Map<CachedPair<T1,T2>,MutableLong> abCount = new HashMap<>(abSize);
229        Map<CachedPair<T1,T3>,MutableLong> acCount = new HashMap<>(acSize);
230        Map<CachedPair<T2,T3>,MutableLong> bcCount = new HashMap<>(bcSize);
231        Map<T1,MutableLong> aCount = new HashMap<>(aSize);
232        Map<T2,MutableLong> bCount = new HashMap<>(bSize);
233        Map<T3,MutableLong> cCount = new HashMap<>(cSize);
234
235        return constructFromMap(jointCount,abCount,acCount,bcCount,aCount,bCount,cCount);
236    }
237
238    /**
239     * Constructs a TripleDistribution by marginalising the supplied joint distribution.
240     * <p>
241     * Sizes are used to preallocate the HashMaps.
242     * @param jointCount The joint distribution.
243     * @param abCount An empty hashmap for AB.
244     * @param acCount An empty hashmap for AC.
245     * @param bcCount An empty hashmap for BC.
246     * @param aCount An empty hashmap for A.
247     * @param bCount An empty hashmap for B.
248     * @param cCount An empty hashmap for C.
249     * @param <T1> The type of A.
250     * @param <T2> The type of B.
251     * @param <T3> The type of C.
252     * @return A TripleDistribution.
253     */
254    public static <T1,T2,T3> TripleDistribution<T1,T2,T3> constructFromMap(Map<CachedTriple<T1,T2,T3>,MutableLong> jointCount,
255                                                                           Map<CachedPair<T1,T2>,MutableLong> abCount,
256                                                                           Map<CachedPair<T1,T3>,MutableLong> acCount,
257                                                                           Map<CachedPair<T2,T3>,MutableLong> bcCount,
258                                                                           Map<T1,MutableLong> aCount,
259                                                                           Map<T2,MutableLong> bCount,
260                                                                           Map<T3,MutableLong> cCount) {
261        long count = 0L;
262
263        for (Entry<CachedTriple<T1,T2,T3>,MutableLong> e : jointCount.entrySet()) {
264            CachedTriple<T1,T2,T3> abc = e.getKey();
265            long curCount = e.getValue().longValue();
266            CachedPair<T1,T2> ab = abc.getAB();
267            CachedPair<T1,T3> ac = abc.getAC();
268            CachedPair<T2,T3> bc = abc.getBC();
269            T1 a = abc.getA();
270            T2 b = abc.getB();
271            T3 c = abc.getC();
272
273            count += curCount;
274
275            MutableLong abCurCount = abCount.computeIfAbsent(ab, k -> new MutableLong());
276            abCurCount.increment(curCount);
277            MutableLong acCurCount = acCount.computeIfAbsent(ac, k -> new MutableLong());
278            acCurCount.increment(curCount);
279            MutableLong bcCurCount = bcCount.computeIfAbsent(bc, k -> new MutableLong());
280            bcCurCount.increment(curCount);
281
282            MutableLong aCurCount = aCount.computeIfAbsent(a, k -> new MutableLong());
283            aCurCount.increment(curCount);
284            MutableLong bCurCount = bCount.computeIfAbsent(b, k -> new MutableLong());
285            bCurCount.increment(curCount);
286            MutableLong cCurCount = cCount.computeIfAbsent(c, k -> new MutableLong());
287            cCurCount.increment(curCount);
288        }
289
290        return new TripleDistribution<>(count,jointCount,abCount,acCount,bcCount,aCount,bCount,cCount);
291    }
292    
293}