001/* 002 * Copyright (c) 2015-2020, Oracle and/or its affiliates. All rights reserved. 003 * 004 * Licensed under the Apache License, Version 2.0 (the "License"); 005 * you may not use this file except in compliance with the License. 006 * You may obtain a copy of the License at 007 * 008 * http://www.apache.org/licenses/LICENSE-2.0 009 * 010 * Unless required by applicable law or agreed to in writing, software 011 * distributed under the License is distributed on an "AS IS" BASIS, 012 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express implied. 013 * See the License for the specific language governing permissions and 014 * limitations under the License. 015 */ 016 017package org.tribuo.util.infotheory.impl; 018 019import org.tribuo.util.infotheory.WeightedInformationTheory; 020 021import java.util.HashMap; 022import java.util.List; 023import java.util.Map; 024import java.util.Map.Entry; 025 026/** 027 * Generates the counts for a triplet of vectors. Contains the joint 028 * count, the three pairwise counts, and the three marginal counts. 029 * @param <T1> Type of the first list. 030 * @param <T2> Type of the second list. 031 * @param <T3> Type of the third list. 032 */ 033public class WeightedTripleDistribution<T1,T2,T3> { 034 /** 035 * The default map size. 036 */ 037 public static final int DEFAULT_MAP_SIZE = 20; 038 039 /** 040 * The sample count. 041 */ 042 public final long count; 043 044 private final Map<CachedTriple<T1,T2,T3>,WeightCountTuple> jointCount; 045 private final Map<CachedPair<T1,T2>,WeightCountTuple> abCount; 046 private final Map<CachedPair<T1,T3>,WeightCountTuple> acCount; 047 private final Map<CachedPair<T2,T3>,WeightCountTuple> bcCount; 048 private final Map<T1,WeightCountTuple> aCount; 049 private final Map<T2,WeightCountTuple> bCount; 050 private final Map<T3,WeightCountTuple> cCount; 051 052 /** 053 * Constructs a weighted triple distribution from the supplied values. 054 * @param count The sample count. 055 * @param jointCount The ABC joint distribution. 056 * @param abCount The AB joint distribution. 057 * @param acCount The AC joint distribution. 058 * @param bcCount The BC joint distribution. 059 * @param aCount The A marginal distribution. 060 * @param bCount The B marginal distribution. 061 * @param cCount The C marginal distribution. 062 */ 063 public WeightedTripleDistribution(long count, Map<CachedTriple<T1,T2,T3>,WeightCountTuple> jointCount, Map<CachedPair<T1,T2>,WeightCountTuple> abCount, Map<CachedPair<T1,T3>,WeightCountTuple> acCount, Map<CachedPair<T2,T3>,WeightCountTuple> bcCount, Map<T1,WeightCountTuple> aCount, Map<T2,WeightCountTuple> bCount, Map<T3,WeightCountTuple> cCount) { 064 this.count = count; 065 this.jointCount = jointCount; 066 this.abCount = abCount; 067 this.acCount = acCount; 068 this.bcCount = bcCount; 069 this.aCount = aCount; 070 this.bCount = bCount; 071 this.cCount = cCount; 072 } 073 074 /** 075 * The joint distribution over the three variables. 076 * @return The joint distribution. 077 */ 078 public Map<CachedTriple<T1,T2,T3>,WeightCountTuple> getJointCount() { 079 return jointCount; 080 } 081 082 /** 083 * The joint distribution over the first and second variables. 084 * @return The joint distribution over A and B. 085 */ 086 public Map<CachedPair<T1,T2>,WeightCountTuple> getABCount() { 087 return abCount; 088 } 089 090 /** 091 * The joint distribution over the first and third variables. 092 * @return The joint distribution over A and C. 093 */ 094 public Map<CachedPair<T1,T3>,WeightCountTuple> getACCount() { 095 return acCount; 096 } 097 098 /** 099 * The joint distribution over the second and third variables. 100 * @return The joint distribution over B and C. 101 */ 102 public Map<CachedPair<T2,T3>,WeightCountTuple> getBCCount() { 103 return bcCount; 104 } 105 106 /** 107 * The marginal distribution over the first variable. 108 * @return The marginal distribution for A. 109 */ 110 public Map<T1,WeightCountTuple> getACount() { 111 return aCount; 112 } 113 114 /** 115 * The marginal distribution over the second variable. 116 * @return The marginal distribution for B. 117 */ 118 public Map<T2,WeightCountTuple> getBCount() { 119 return bCount; 120 } 121 122 /** 123 * The marginal distribution over the third variable. 124 * @return The marginal distribution for C. 125 */ 126 public Map<T3,WeightCountTuple> getCCount() { 127 return cCount; 128 } 129 130 /** 131 * Constructs a WeightedTripleDistribution from three lists of the same length and a list of weights of the same length. 132 * <p> 133 * If they are not the same length it throws IllegalArgumentException. 134 * @param first The first list. 135 * @param second The second list. 136 * @param third The third list. 137 * @param weights The weight list. 138 * @param <T1> The first type. 139 * @param <T2> The second type. 140 * @param <T3> The third type. 141 * @return The WeightedTripleDistribution. 142 */ 143 public static <T1,T2,T3> WeightedTripleDistribution<T1,T2,T3> constructFromLists(List<T1> first, List<T2> second, List<T3> third, List<Double> weights) { 144 Map<CachedTriple<T1,T2,T3>,WeightCountTuple> jointCount = new HashMap<>(DEFAULT_MAP_SIZE); 145 Map<CachedPair<T1,T2>,WeightCountTuple> abCount = new HashMap<>(DEFAULT_MAP_SIZE); 146 Map<CachedPair<T1,T3>,WeightCountTuple> acCount = new HashMap<>(DEFAULT_MAP_SIZE); 147 Map<CachedPair<T2,T3>,WeightCountTuple> bcCount = new HashMap<>(DEFAULT_MAP_SIZE); 148 Map<T1,WeightCountTuple> aCount = new HashMap<>(DEFAULT_MAP_SIZE); 149 Map<T2,WeightCountTuple> bCount = new HashMap<>(DEFAULT_MAP_SIZE); 150 Map<T3,WeightCountTuple> cCount = new HashMap<>(DEFAULT_MAP_SIZE); 151 152 long count = first.size(); 153 154 if ((first.size() == second.size()) && (first.size() == third.size()) && (first.size() == weights.size())) { 155 for (int i = 0; i < first.size(); i++) { 156 double weight = weights.get(i); 157 T1 a = first.get(i); 158 T2 b = second.get(i); 159 T3 c = third.get(i); 160 CachedTriple<T1,T2,T3> triple = new CachedTriple<>(a,b,c); 161 CachedPair<T1,T2> abPair = triple.getAB(); 162 CachedPair<T1,T3> acPair = triple.getAC(); 163 CachedPair<T2,T3> bcPair = triple.getBC(); 164 165 WeightCountTuple abcCurCount = jointCount.computeIfAbsent(triple,(k) -> new WeightCountTuple()); 166 abcCurCount.weight += weight; 167 abcCurCount.count++; 168 169 WeightCountTuple abCurCount = abCount.computeIfAbsent(abPair,(k) -> new WeightCountTuple()); 170 abCurCount.weight += weight; 171 abCurCount.count++; 172 173 WeightCountTuple acCurCount = acCount.computeIfAbsent(acPair,(k) -> new WeightCountTuple()); 174 acCurCount.weight += weight; 175 acCurCount.count++; 176 177 WeightCountTuple bcCurCount = bcCount.computeIfAbsent(bcPair,(k) -> new WeightCountTuple()); 178 bcCurCount.weight += weight; 179 bcCurCount.count++; 180 181 WeightCountTuple aCurCount = aCount.computeIfAbsent(a,(k) -> new WeightCountTuple()); 182 aCurCount.weight += weight; 183 aCurCount.count++; 184 185 WeightCountTuple bCurCount = bCount.computeIfAbsent(b,(k) -> new WeightCountTuple()); 186 bCurCount.weight += weight; 187 bCurCount.count++; 188 189 WeightCountTuple cCurCount = cCount.computeIfAbsent(c,(k) -> new WeightCountTuple()); 190 cCurCount.weight += weight; 191 cCurCount.count++; 192 } 193 194 WeightedInformationTheory.normaliseWeights(jointCount); 195 WeightedInformationTheory.normaliseWeights(abCount); 196 WeightedInformationTheory.normaliseWeights(acCount); 197 WeightedInformationTheory.normaliseWeights(bcCount); 198 WeightedInformationTheory.normaliseWeights(aCount); 199 WeightedInformationTheory.normaliseWeights(bCount); 200 WeightedInformationTheory.normaliseWeights(cCount); 201 202 return new WeightedTripleDistribution<>(count,jointCount,abCount,acCount,bcCount,aCount,bCount,cCount); 203 } else { 204 throw new IllegalArgumentException("Counting requires lists of the same length. first.size() = " + first.size() + ", second.size() = " + second.size() + ", third.size() = " + third.size() + ", weights.size() = " + weights.size()); 205 } 206 } 207 208 /** 209 * Constructs a WeightedTripleDistribution by marginalising the supplied joint distribution. 210 * @param jointCount The joint distribution. 211 * @param <T1> The type of A. 212 * @param <T2> The type of B. 213 * @param <T3> The type of C. 214 * @return A WeightedTripleDistribution. 215 */ 216 public static <T1,T2,T3> WeightedTripleDistribution<T1,T2,T3> constructFromMap(Map<CachedTriple<T1,T2,T3>,WeightCountTuple> jointCount) { 217 Map<CachedPair<T1,T2>,WeightCountTuple> abCount = new HashMap<>(DEFAULT_MAP_SIZE); 218 Map<CachedPair<T1,T3>,WeightCountTuple> acCount = new HashMap<>(DEFAULT_MAP_SIZE); 219 Map<CachedPair<T2,T3>,WeightCountTuple> bcCount = new HashMap<>(DEFAULT_MAP_SIZE); 220 Map<T1,WeightCountTuple> aCount = new HashMap<>(DEFAULT_MAP_SIZE); 221 Map<T2,WeightCountTuple> bCount = new HashMap<>(DEFAULT_MAP_SIZE); 222 Map<T3,WeightCountTuple> cCount = new HashMap<>(DEFAULT_MAP_SIZE); 223 224 long count = 0L; 225 226 for (Entry<CachedTriple<T1,T2,T3>,WeightCountTuple> e : jointCount.entrySet()) { 227 CachedTriple<T1,T2,T3> triple = e.getKey(); 228 WeightCountTuple tuple = e.getValue(); 229 CachedPair<T1,T2> abPair = triple.getAB(); 230 CachedPair<T1,T3> acPair = triple.getAC(); 231 CachedPair<T2,T3> bcPair = triple.getBC(); 232 T1 a = triple.getA(); 233 T2 b = triple.getB(); 234 T3 c = triple.getC(); 235 236 count += tuple.count; 237 238 double weight = tuple.weight * tuple.count; 239 240 WeightCountTuple abCurCount = abCount.computeIfAbsent(abPair,(k) -> new WeightCountTuple()); 241 abCurCount.weight += weight; 242 abCurCount.count += tuple.count; 243 244 WeightCountTuple acCurCount = acCount.computeIfAbsent(acPair,(k) -> new WeightCountTuple()); 245 acCurCount.weight += weight; 246 acCurCount.count += tuple.count; 247 248 WeightCountTuple bcCurCount = bcCount.computeIfAbsent(bcPair,(k) -> new WeightCountTuple()); 249 bcCurCount.weight += weight; 250 bcCurCount.count += tuple.count; 251 252 WeightCountTuple aCurCount = aCount.computeIfAbsent(a,(k) -> new WeightCountTuple()); 253 aCurCount.weight += weight; 254 aCurCount.count += tuple.count; 255 256 WeightCountTuple bCurCount = bCount.computeIfAbsent(b,(k) -> new WeightCountTuple()); 257 bCurCount.weight += weight; 258 bCurCount.count += tuple.count; 259 260 WeightCountTuple cCurCount = cCount.computeIfAbsent(c,(k) -> new WeightCountTuple()); 261 cCurCount.weight += weight; 262 cCurCount.count += tuple.count; 263 } 264 265 WeightedInformationTheory.normaliseWeights(abCount); 266 WeightedInformationTheory.normaliseWeights(acCount); 267 WeightedInformationTheory.normaliseWeights(bcCount); 268 WeightedInformationTheory.normaliseWeights(aCount); 269 WeightedInformationTheory.normaliseWeights(bCount); 270 WeightedInformationTheory.normaliseWeights(cCount); 271 272 return new WeightedTripleDistribution<>(count,jointCount,abCount,acCount,bcCount,aCount,bCount,cCount); 273 } 274 275}