001/* 002 * Copyright (c) 2015-2020, Oracle and/or its affiliates. All rights reserved. 003 * 004 * Licensed under the Apache License, Version 2.0 (the "License"); 005 * you may not use this file except in compliance with the License. 006 * You may obtain a copy of the License at 007 * 008 * http://www.apache.org/licenses/LICENSE-2.0 009 * 010 * Unless required by applicable law or agreed to in writing, software 011 * distributed under the License is distributed on an "AS IS" BASIS, 012 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express implied. 013 * See the License for the specific language governing permissions and 014 * limitations under the License. 015 */ 016 017package org.tribuo.util.infotheory.impl; 018 019import com.oracle.labs.mlrg.olcut.util.MutableLong; 020 021import java.util.HashMap; 022import java.util.LinkedHashMap; 023import java.util.List; 024import java.util.Map; 025import java.util.Map.Entry; 026 027/** 028 * Generates the counts for a triplet of vectors. Contains the joint 029 * count, the three pairwise counts, and the three marginal counts. 030 * <p> 031 * Relies upon hashCode and equals to determine element equality for counting. 032 * @param <T1> Type of the first list. 033 * @param <T2> Type of the second list. 034 * @param <T3> Type of the third list. 035 */ 036public class TripleDistribution<T1,T2,T3> { 037 /** 038 * The default map size to initialise the marginalised count maps with. 039 */ 040 public static final int DEFAULT_MAP_SIZE = 20; 041 042 /** 043 * The number of samples in this distribution. 044 */ 045 public final long count; 046 047 private final Map<CachedTriple<T1,T2,T3>,MutableLong> jointCount; 048 private final Map<CachedPair<T1,T2>,MutableLong> abCount; 049 private final Map<CachedPair<T1,T3>,MutableLong> acCount; 050 private final Map<CachedPair<T2,T3>,MutableLong> bcCount; 051 private final Map<T1,MutableLong> aCount; 052 private final Map<T2,MutableLong> bCount; 053 private final Map<T3,MutableLong> cCount; 054 055 /** 056 * Constructs a triple distribution from the supplied distributions. 057 * @param count The sample count. 058 * @param jointCount The joint distribution over the three variables. 059 * @param abCount The joint distribution over A and B. 060 * @param acCount The joint distribution over A and C. 061 * @param bcCount The joint distribution over B and C. 062 * @param aCount The marginal distribution over A. 063 * @param bCount The marginal distribution over B. 064 * @param cCount The marginal distribution over C. 065 */ 066 public TripleDistribution(long count, Map<CachedTriple<T1,T2,T3>,MutableLong> jointCount, Map<CachedPair<T1,T2>,MutableLong> abCount, Map<CachedPair<T1,T3>,MutableLong> acCount, Map<CachedPair<T2,T3>,MutableLong> bcCount, Map<T1,MutableLong> aCount, Map<T2,MutableLong> bCount, Map<T3,MutableLong> cCount) { 067 this.count = count; 068 this.jointCount = jointCount; 069 this.abCount = abCount; 070 this.acCount = acCount; 071 this.bcCount = bcCount; 072 this.aCount = aCount; 073 this.bCount = bCount; 074 this.cCount = cCount; 075 } 076 077 /** 078 * The joint distribution over the three variables. 079 * @return The joint distribution. 080 */ 081 public Map<CachedTriple<T1,T2,T3>,MutableLong> getJointCount() { 082 return jointCount; 083 } 084 085 /** 086 * The joint distribution over the first and second variables. 087 * @return The joint distribution over A and B. 088 */ 089 public Map<CachedPair<T1,T2>,MutableLong> getABCount() { 090 return abCount; 091 } 092 093 /** 094 * The joint distribution over the first and third variables. 095 * @return The joint distribution over A and C. 096 */ 097 public Map<CachedPair<T1,T3>,MutableLong> getACCount() { 098 return acCount; 099 } 100 101 /** 102 * The joint distribution over the second and third variables. 103 * @return The joint distribution over B and C. 104 */ 105 public Map<CachedPair<T2,T3>,MutableLong> getBCCount() { 106 return bcCount; 107 } 108 109 /** 110 * The marginal distribution over the first variable. 111 * @return The marginal distribution for A. 112 */ 113 public Map<T1,MutableLong> getACount() { 114 return aCount; 115 } 116 117 /** 118 * The marginal distribution over the second variable. 119 * @return The marginal distribution for B. 120 */ 121 public Map<T2,MutableLong> getBCount() { 122 return bCount; 123 } 124 125 /** 126 * The marginal distribution over the third variable. 127 * @return The marginal distribution for C. 128 */ 129 public Map<T3,MutableLong> getCCount() { 130 return cCount; 131 } 132 133 /** 134 * Constructs a TripleDistribution from three lists of the same length. 135 * <p> 136 * If they are not the same length it throws IllegalArgumentException. 137 * @param first The first list. 138 * @param second The second list. 139 * @param third The third list. 140 * @param <T1> The first type. 141 * @param <T2> The second type. 142 * @param <T3> The third type. 143 * @return The TripleDistribution. 144 */ 145 public static <T1,T2,T3> TripleDistribution<T1,T2,T3> constructFromLists(List<T1> first, List<T2> second, List<T3> third) { 146 Map<CachedTriple<T1,T2,T3>,MutableLong> jointCount = new LinkedHashMap<>(DEFAULT_MAP_SIZE); 147 Map<CachedPair<T1,T2>,MutableLong> abCount = new HashMap<>(DEFAULT_MAP_SIZE); 148 Map<CachedPair<T1,T3>,MutableLong> acCount = new HashMap<>(DEFAULT_MAP_SIZE); 149 Map<CachedPair<T2,T3>,MutableLong> bcCount = new HashMap<>(DEFAULT_MAP_SIZE); 150 Map<T1,MutableLong> aCount = new HashMap<>(DEFAULT_MAP_SIZE); 151 Map<T2,MutableLong> bCount = new HashMap<>(DEFAULT_MAP_SIZE); 152 Map<T3,MutableLong> cCount = new HashMap<>(DEFAULT_MAP_SIZE); 153 154 long count = first.size(); 155 156 if ((first.size() == second.size()) && (first.size() == third.size())) { 157 for (int i = 0; i < first.size(); i++) { 158 T1 a = first.get(i); 159 T2 b = second.get(i); 160 T3 c = third.get(i); 161 CachedTriple<T1,T2,T3> abc = new CachedTriple<>(a,b,c); 162 CachedPair<T1,T2> ab = abc.getAB(); 163 CachedPair<T1,T3> ac = abc.getAC(); 164 CachedPair<T2,T3> bc = abc.getBC(); 165 166 MutableLong abcCurCount = jointCount.computeIfAbsent(abc, k -> new MutableLong()); 167 abcCurCount.increment(); 168 169 MutableLong abCurCount = abCount.computeIfAbsent(ab, k -> new MutableLong()); 170 abCurCount.increment(); 171 MutableLong acCurCount = acCount.computeIfAbsent(ac, k -> new MutableLong()); 172 acCurCount.increment(); 173 MutableLong bcCurCount = bcCount.computeIfAbsent(bc, k -> new MutableLong()); 174 bcCurCount.increment(); 175 176 MutableLong aCurCount = aCount.computeIfAbsent(a, k -> new MutableLong()); 177 aCurCount.increment(); 178 MutableLong bCurCount = bCount.computeIfAbsent(b, k -> new MutableLong()); 179 bCurCount.increment(); 180 MutableLong cCurCount = cCount.computeIfAbsent(c, k -> new MutableLong()); 181 cCurCount.increment(); 182 } 183 184 return new TripleDistribution<>(count,jointCount,abCount,acCount,bcCount,aCount,bCount,cCount); 185 } else { 186 throw new IllegalArgumentException("Counting requires lists of the same length. first.size() = " + first.size() + ", second.size() = " + second.size() + ", third.size() = " + third.size()); 187 } 188 } 189 190 /** 191 * Constructs a TripleDistribution by marginalising the supplied joint distribution. 192 * @param jointCount The joint distribution. 193 * @param <T1> The type of A. 194 * @param <T2> The type of B. 195 * @param <T3> The type of C. 196 * @return A TripleDistribution. 197 */ 198 public static <T1,T2,T3> TripleDistribution<T1,T2,T3> constructFromMap(Map<CachedTriple<T1,T2,T3>,MutableLong> jointCount) { 199 Map<CachedPair<T1,T2>,MutableLong> abCount = new HashMap<>(DEFAULT_MAP_SIZE); 200 Map<CachedPair<T1,T3>,MutableLong> acCount = new HashMap<>(DEFAULT_MAP_SIZE); 201 Map<CachedPair<T2,T3>,MutableLong> bcCount = new HashMap<>(DEFAULT_MAP_SIZE); 202 Map<T1,MutableLong> aCount = new HashMap<>(DEFAULT_MAP_SIZE); 203 Map<T2,MutableLong> bCount = new HashMap<>(DEFAULT_MAP_SIZE); 204 Map<T3,MutableLong> cCount = new HashMap<>(DEFAULT_MAP_SIZE); 205 206 return constructFromMap(jointCount,abCount,acCount,bcCount,aCount,bCount,cCount); 207 } 208 209 /** 210 * Constructs a TripleDistribution by marginalising the supplied joint distribution. 211 * <p> 212 * Sizes are used to preallocate the HashMaps. 213 * @param jointCount The joint distribution. 214 * @param abSize The number of unique AB states. 215 * @param acSize The number of unique AC states. 216 * @param bcSize The number of unique BC states. 217 * @param aSize The number of unique A states. 218 * @param bSize The number of unique B states. 219 * @param cSize The number of unique C states. 220 * @param <T1> The type of A. 221 * @param <T2> The type of B. 222 * @param <T3> The type of C. 223 * @return A TripleDistribution. 224 */ 225 public static <T1,T2,T3> TripleDistribution<T1,T2,T3> constructFromMap(Map<CachedTriple<T1,T2,T3>,MutableLong> jointCount, 226 int abSize, int acSize, int bcSize, 227 int aSize, int bSize, int cSize) { 228 Map<CachedPair<T1,T2>,MutableLong> abCount = new HashMap<>(abSize); 229 Map<CachedPair<T1,T3>,MutableLong> acCount = new HashMap<>(acSize); 230 Map<CachedPair<T2,T3>,MutableLong> bcCount = new HashMap<>(bcSize); 231 Map<T1,MutableLong> aCount = new HashMap<>(aSize); 232 Map<T2,MutableLong> bCount = new HashMap<>(bSize); 233 Map<T3,MutableLong> cCount = new HashMap<>(cSize); 234 235 return constructFromMap(jointCount,abCount,acCount,bcCount,aCount,bCount,cCount); 236 } 237 238 /** 239 * Constructs a TripleDistribution by marginalising the supplied joint distribution. 240 * <p> 241 * Sizes are used to preallocate the HashMaps. 242 * @param jointCount The joint distribution. 243 * @param abCount An empty hashmap for AB. 244 * @param acCount An empty hashmap for AC. 245 * @param bcCount An empty hashmap for BC. 246 * @param aCount An empty hashmap for A. 247 * @param bCount An empty hashmap for B. 248 * @param cCount An empty hashmap for C. 249 * @param <T1> The type of A. 250 * @param <T2> The type of B. 251 * @param <T3> The type of C. 252 * @return A TripleDistribution. 253 */ 254 public static <T1,T2,T3> TripleDistribution<T1,T2,T3> constructFromMap(Map<CachedTriple<T1,T2,T3>,MutableLong> jointCount, 255 Map<CachedPair<T1,T2>,MutableLong> abCount, 256 Map<CachedPair<T1,T3>,MutableLong> acCount, 257 Map<CachedPair<T2,T3>,MutableLong> bcCount, 258 Map<T1,MutableLong> aCount, 259 Map<T2,MutableLong> bCount, 260 Map<T3,MutableLong> cCount) { 261 long count = 0L; 262 263 for (Entry<CachedTriple<T1,T2,T3>,MutableLong> e : jointCount.entrySet()) { 264 CachedTriple<T1,T2,T3> abc = e.getKey(); 265 long curCount = e.getValue().longValue(); 266 CachedPair<T1,T2> ab = abc.getAB(); 267 CachedPair<T1,T3> ac = abc.getAC(); 268 CachedPair<T2,T3> bc = abc.getBC(); 269 T1 a = abc.getA(); 270 T2 b = abc.getB(); 271 T3 c = abc.getC(); 272 273 count += curCount; 274 275 MutableLong abCurCount = abCount.computeIfAbsent(ab, k -> new MutableLong()); 276 abCurCount.increment(curCount); 277 MutableLong acCurCount = acCount.computeIfAbsent(ac, k -> new MutableLong()); 278 acCurCount.increment(curCount); 279 MutableLong bcCurCount = bcCount.computeIfAbsent(bc, k -> new MutableLong()); 280 bcCurCount.increment(curCount); 281 282 MutableLong aCurCount = aCount.computeIfAbsent(a, k -> new MutableLong()); 283 aCurCount.increment(curCount); 284 MutableLong bCurCount = bCount.computeIfAbsent(b, k -> new MutableLong()); 285 bCurCount.increment(curCount); 286 MutableLong cCurCount = cCount.computeIfAbsent(c, k -> new MutableLong()); 287 cCurCount.increment(curCount); 288 } 289 290 return new TripleDistribution<>(count,jointCount,abCount,acCount,bcCount,aCount,bCount,cCount); 291 } 292 293}