001/* 002 * Copyright (c) 2015-2020, Oracle and/or its affiliates. All rights reserved. 003 * 004 * Licensed under the Apache License, Version 2.0 (the "License"); 005 * you may not use this file except in compliance with the License. 006 * You may obtain a copy of the License at 007 * 008 * http://www.apache.org/licenses/LICENSE-2.0 009 * 010 * Unless required by applicable law or agreed to in writing, software 011 * distributed under the License is distributed on an "AS IS" BASIS, 012 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express implied. 013 * See the License for the specific language governing permissions and 014 * limitations under the License. 015 */ 016 017package org.tribuo.util.infotheory.impl; 018 019import com.oracle.labs.mlrg.olcut.util.MutableLong; 020import org.tribuo.util.infotheory.InformationTheory; 021 022import java.util.HashMap; 023import java.util.LinkedHashMap; 024import java.util.List; 025import java.util.Map; 026import java.util.Map.Entry; 027 028/** 029 * A count distribution over {@link CachedPair} objects. 030 * @param <T1> The type of the first element 031 * @param <T2> The type of the second element 032 */ 033public class PairDistribution<T1,T2> { 034 035 /** 036 * The number of samples this distribution has seen. 037 */ 038 public final long count; 039 040 /** 041 * The joint distribution. 042 */ 043 public final Map<CachedPair<T1,T2>,MutableLong> jointCounts; 044 /** 045 * The first marginal distribution. 046 */ 047 public final Map<T1,MutableLong> firstCount; 048 /** 049 * The second marginal distribution. 050 */ 051 public final Map<T2,MutableLong> secondCount; 052 053 /** 054 * Constructs a pair distribution. 055 * @param count The total sample count. 056 * @param jointCounts The joint counts. 057 * @param firstCount The first variable count. 058 * @param secondCount The second variable count. 059 */ 060 public PairDistribution(long count, Map<CachedPair<T1,T2>,MutableLong> jointCounts, Map<T1,MutableLong> firstCount, Map<T2,MutableLong> secondCount) { 061 this.count = count; 062 this.jointCounts = new LinkedHashMap<>(jointCounts); 063 this.firstCount = new LinkedHashMap<>(firstCount); 064 this.secondCount = new LinkedHashMap<>(secondCount); 065 } 066 067 /** 068 * Constructs a pair distribution. 069 * @param count The total sample count. 070 * @param jointCounts The joint counts. 071 * @param firstCount The first variable count. 072 * @param secondCount The second variable count. 073 */ 074 public PairDistribution(long count, LinkedHashMap<CachedPair<T1,T2>,MutableLong> jointCounts, LinkedHashMap<T1,MutableLong> firstCount, LinkedHashMap<T2,MutableLong> secondCount) { 075 this.count = count; 076 this.jointCounts = jointCounts; 077 this.firstCount = firstCount; 078 this.secondCount = secondCount; 079 } 080 081 /** 082 * Generates the counts for two vectors. Returns a PairDistribution containing the joint 083 * count, and the two marginal counts. 084 * @param <T1> Type of the first array. 085 * @param <T2> Type of the second array. 086 * @param first An array of values. 087 * @param second Another array of values. 088 * @return The joint counts and the two marginal counts. 089 */ 090 public static <T1,T2> PairDistribution<T1,T2> constructFromLists(List<T1> first, List<T2> second) { 091 LinkedHashMap<CachedPair<T1,T2>,MutableLong> abCountDist = new LinkedHashMap<>(InformationTheory.DEFAULT_MAP_SIZE); 092 LinkedHashMap<T1,MutableLong> aCountDist = new LinkedHashMap<>(InformationTheory.DEFAULT_MAP_SIZE); 093 LinkedHashMap<T2,MutableLong> bCountDist = new LinkedHashMap<>(InformationTheory.DEFAULT_MAP_SIZE); 094 095 if (first.size() == second.size()) { 096 long count = 0; 097 for (int i = 0; i < first.size(); i++) { 098 T1 a = first.get(i); 099 T2 b = second.get(i); 100 CachedPair<T1,T2> pair = new CachedPair<>(a,b); 101 102 MutableLong abCount = abCountDist.computeIfAbsent(pair, k -> new MutableLong()); 103 abCount.increment(); 104 105 MutableLong aCount = aCountDist.computeIfAbsent(a, k -> new MutableLong()); 106 aCount.increment(); 107 108 MutableLong bCount = bCountDist.computeIfAbsent(b, k -> new MutableLong()); 109 bCount.increment(); 110 111 count++; 112 } 113 114 return new PairDistribution<>(count,abCountDist,aCountDist,bCountDist); 115 } else { 116 throw new IllegalArgumentException("Counting requires arrays of the same length. first.size() = " + first.size() + ", second.size() = " + second.size()); 117 } 118 } 119 120 /** 121 * Constructs a distribution from a joint count. 122 * @param jointCount The joint count. 123 * @param <T1> The type of the first variable. 124 * @param <T2> The type of the second variable. 125 * @return A pair distribution. 126 */ 127 public static <T1,T2> PairDistribution<T1,T2> constructFromMap(Map<CachedPair<T1,T2>,MutableLong> jointCount) { 128 Map<T1,MutableLong> aCount = new HashMap<>(InformationTheory.DEFAULT_MAP_SIZE); 129 Map<T2,MutableLong> bCount = new HashMap<>(InformationTheory.DEFAULT_MAP_SIZE); 130 131 return constructFromMap(jointCount,aCount,bCount); 132 } 133 134 /** 135 * Constructs a distribution from a joint count. 136 * @param jointCount The joint count. 137 * @param aSize The initial size of the first marginal hash map. 138 * @param bSize The initial size of the second marginal hash map. 139 * @param <T1> The type of the first variable. 140 * @param <T2> The type of the second variable. 141 * @return A pair distribution. 142 */ 143 public static <T1,T2> PairDistribution<T1,T2> constructFromMap(Map<CachedPair<T1,T2>,MutableLong> jointCount, int aSize, int bSize) { 144 Map<T1,MutableLong> aCount = new HashMap<>(aSize); 145 Map<T2,MutableLong> bCount = new HashMap<>(bSize); 146 147 return constructFromMap(jointCount,aCount,bCount); 148 } 149 150 /** 151 * Constructs a joint distribution from the counts. 152 * @param jointCount The joint count. 153 * @param aCount The first marginal count. 154 * @param bCount The second marginal count. 155 * @param <T1> The type of the first variable. 156 * @param <T2> The type of the second variable. 157 * @return A pair distribution. 158 */ 159 public static <T1,T2> PairDistribution<T1,T2> constructFromMap(Map<CachedPair<T1,T2>,MutableLong> jointCount, 160 Map<T1,MutableLong> aCount, 161 Map<T2,MutableLong> bCount) { 162 long count = 0L; 163 164 for (Entry<CachedPair<T1,T2>,MutableLong> e : jointCount.entrySet()) { 165 CachedPair<T1,T2> pair = e.getKey(); 166 long curCount = e.getValue().longValue(); 167 T1 a = pair.getA(); 168 T2 b = pair.getB(); 169 170 MutableLong curACount = aCount.computeIfAbsent(a, k -> new MutableLong()); 171 curACount.increment(curCount); 172 173 MutableLong curBCount = bCount.computeIfAbsent(b, k -> new MutableLong()); 174 curBCount.increment(curCount); 175 count += curCount; 176 } 177 178 return new PairDistribution<>(count,jointCount,aCount,bCount); 179 } 180 181}