001/* 002 * Copyright (c) 2015-2020, Oracle and/or its affiliates. All rights reserved. 003 * 004 * Licensed under the Apache License, Version 2.0 (the "License"); 005 * you may not use this file except in compliance with the License. 006 * You may obtain a copy of the License at 007 * 008 * http://www.apache.org/licenses/LICENSE-2.0 009 * 010 * Unless required by applicable law or agreed to in writing, software 011 * distributed under the License is distributed on an "AS IS" BASIS, 012 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express implied. 013 * See the License for the specific language governing permissions and 014 * limitations under the License. 015 */ 016 017package org.tribuo.util.infotheory.impl; 018 019import org.tribuo.util.infotheory.InformationTheory; 020import org.tribuo.util.infotheory.WeightedInformationTheory; 021 022import java.util.LinkedHashMap; 023import java.util.List; 024import java.util.Map; 025import java.util.Map.Entry; 026 027/** 028 * Generates the counts for a pair of vectors. Contains the joint 029 * count and the two marginal counts. 030 * @param <T1> Type of the first list. 031 * @param <T2> Type of the second list. 032 */ 033public class WeightedPairDistribution<T1,T2> { 034 035 /** 036 * The sample count. 037 */ 038 public final long count; 039 040 private final Map<CachedPair<T1,T2>,WeightCountTuple> jointCounts; 041 private final Map<T1,WeightCountTuple> firstCount; 042 private final Map<T2,WeightCountTuple> secondCount; 043 044 /** 045 * Constructs a weighted pair distribution from the supplied values. 046 * <p> 047 * Copies the maps out into LinkedHashMaps for iteration speed. 048 * @param count The sample count. 049 * @param jointCounts The joint distribution. 050 * @param firstCount The first marginal distribution. 051 * @param secondCount The second marginal distribution. 052 */ 053 public WeightedPairDistribution(long count, Map<CachedPair<T1,T2>,WeightCountTuple> jointCounts, Map<T1,WeightCountTuple> firstCount, Map<T2,WeightCountTuple> secondCount) { 054 this.count = count; 055 this.jointCounts = new LinkedHashMap<>(jointCounts); 056 this.firstCount = new LinkedHashMap<>(firstCount); 057 this.secondCount = new LinkedHashMap<>(secondCount); 058 } 059 060 /** 061 * Constructs a weighted pair distribution from the supplied values. 062 * @param count The sample count. 063 * @param jointCounts The joint distribution. 064 * @param firstCount The first marginal distribution. 065 * @param secondCount The second marginal distribution. 066 */ 067 public WeightedPairDistribution(long count, LinkedHashMap<CachedPair<T1,T2>,WeightCountTuple> jointCounts, LinkedHashMap<T1,WeightCountTuple> firstCount, LinkedHashMap<T2,WeightCountTuple> secondCount) { 068 this.count = count; 069 this.jointCounts = jointCounts; 070 this.firstCount = firstCount; 071 this.secondCount = secondCount; 072 } 073 074 /** 075 * Gets the joint distribution. 076 * @return The joint distribution. 077 */ 078 public Map<CachedPair<T1,T2>,WeightCountTuple> getJointCounts() { 079 return jointCounts; 080 } 081 082 /** 083 * Gets the first marginal distribution. 084 * @return The first marginal distribution. 085 */ 086 public Map<T1,WeightCountTuple> getFirstCount() { 087 return firstCount; 088 } 089 090 /** 091 * Gets the second marginal distribution. 092 * @return The second marginal distribution. 093 */ 094 public Map<T2,WeightCountTuple> getSecondCount() { 095 return secondCount; 096 } 097 098 /** 099 * Generates the counts for two vectors. Returns a pair containing the joint 100 * count, and a pair of the two marginal counts. 101 * @param <T1> Type of the first list. 102 * @param <T2> Type of the second list. 103 * @param first An list of values. 104 * @param second Another list of values. 105 * @param weights An list of per example weights. 106 * @return A WeightedPairDistribution. 107 */ 108 public static <T1,T2> WeightedPairDistribution<T1,T2> constructFromLists(List<T1> first, List<T2> second, List<Double> weights) { 109 LinkedHashMap<CachedPair<T1,T2>,WeightCountTuple> countDist = new LinkedHashMap<>(InformationTheory.DEFAULT_MAP_SIZE); 110 LinkedHashMap<T1,WeightCountTuple> aCountDist = new LinkedHashMap<>(InformationTheory.DEFAULT_MAP_SIZE); 111 LinkedHashMap<T2,WeightCountTuple> bCountDist = new LinkedHashMap<>(InformationTheory.DEFAULT_MAP_SIZE); 112 113 if ((first.size() == second.size()) && (first.size() == weights.size())) { 114 long count = 0; 115 for (int i = 0; i < first.size(); i++) { 116 T1 a = first.get(i); 117 T2 b = second.get(i); 118 double weight = weights.get(i); 119 CachedPair<T1,T2> pair = new CachedPair<>(a,b); 120 121 WeightCountTuple abCurCount = countDist.computeIfAbsent(pair,(k) -> new WeightCountTuple()); 122 abCurCount.weight += weight; 123 abCurCount.count++; 124 125 WeightCountTuple aCurCount = aCountDist.computeIfAbsent(a,(k) -> new WeightCountTuple()); 126 aCurCount.weight += weight; 127 aCurCount.count++; 128 129 WeightCountTuple bCurCount = bCountDist.computeIfAbsent(b,(k) -> new WeightCountTuple()); 130 bCurCount.weight += weight; 131 bCurCount.count++; 132 133 count++; 134 } 135 136 WeightedInformationTheory.normaliseWeights(countDist); 137 WeightedInformationTheory.normaliseWeights(aCountDist); 138 WeightedInformationTheory.normaliseWeights(bCountDist); 139 140 return new WeightedPairDistribution<>(count,countDist,aCountDist,bCountDist); 141 } else { 142 throw new IllegalArgumentException("Counting requires lists of the same length. first.size() = " + first.size() + ", second.size() = " + second.size() + ", weights.size() = " + weights.size()); 143 } 144 } 145 146 /** 147 * Generates a WeightedPairDistribution by generating the marginal distributions for the first and second elements. 148 * This assumes the weights have already been normalised. 149 * @param <T1> Type of the first element. 150 * @param <T2> Type of the second element. 151 * @param jointCount The (normalised) input map. 152 * @return A WeightedPairDistribution 153 */ 154 public static <T1,T2> WeightedPairDistribution<T1,T2> constructFromMap(Map<CachedPair<T1,T2>,WeightCountTuple> jointCount) { 155 LinkedHashMap<CachedPair<T1,T2>,WeightCountTuple> countDist = new LinkedHashMap<>(jointCount); 156 LinkedHashMap<T1,WeightCountTuple> aCountDist = new LinkedHashMap<>(InformationTheory.DEFAULT_MAP_SIZE); 157 LinkedHashMap<T2,WeightCountTuple> bCountDist = new LinkedHashMap<>(InformationTheory.DEFAULT_MAP_SIZE); 158 159 long count = 0L; 160 161 for (Entry<CachedPair<T1,T2>,WeightCountTuple> e : countDist.entrySet()) { 162 CachedPair<T1,T2> pair = e.getKey(); 163 WeightCountTuple tuple = e.getValue(); 164 T1 a = pair.getA(); 165 T2 b = pair.getB(); 166 double weight = tuple.weight * tuple.count; 167 168 WeightCountTuple aCurCount = aCountDist.computeIfAbsent(a,(k) -> new WeightCountTuple()); 169 aCurCount.weight += weight; 170 aCurCount.count += tuple.count; 171 172 WeightCountTuple bCurCount = bCountDist.computeIfAbsent(b,(k) -> new WeightCountTuple()); 173 bCurCount.weight += weight; 174 bCurCount.count += tuple.count; 175 176 count += tuple.count; 177 } 178 179 WeightedInformationTheory.normaliseWeights(aCountDist); 180 WeightedInformationTheory.normaliseWeights(bCountDist); 181 182 return new WeightedPairDistribution<>(count,countDist,aCountDist,bCountDist); 183 } 184 185}