001/*
002 * Copyright (c) 2015-2020, Oracle and/or its affiliates. All rights reserved.
003 *
004 * Licensed under the Apache License, Version 2.0 (the "License");
005 * you may not use this file except in compliance with the License.
006 * You may obtain a copy of the License at
007 *
008 *     http://www.apache.org/licenses/LICENSE-2.0
009 *
010 * Unless required by applicable law or agreed to in writing, software
011 * distributed under the License is distributed on an "AS IS" BASIS,
012 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
013 * See the License for the specific language governing permissions and
014 * limitations under the License.
015 */
016
017package org.tribuo.util.infotheory.impl;
018
019import com.oracle.labs.mlrg.olcut.util.MutableLong;
020import org.tribuo.util.infotheory.InformationTheory;
021
import java.util.HashMap;
import java.util.Iterator;
import java.util.LinkedHashMap;
import java.util.List;
import java.util.Map;
import java.util.Map.Entry;
027
028/**
029 * A count distribution over {@link CachedPair} objects.
030 * @param <T1> The type of the first element
031 * @param <T2> The type of the second element
032 */
033public class PairDistribution<T1,T2> {
034
035    /**
036     * The number of samples this distribution has seen.
037     */
038    public final long count;
039
040    /**
041     * The joint distribution.
042     */
043    public final Map<CachedPair<T1,T2>,MutableLong> jointCounts;
044    /**
045     * The first marginal distribution.
046     */
047    public final Map<T1,MutableLong> firstCount;
048    /**
049     * The second marginal distribution.
050     */
051    public final Map<T2,MutableLong> secondCount;
052
053    /**
054     * Constructs a pair distribution.
055     * @param count The total sample count.
056     * @param jointCounts The joint counts.
057     * @param firstCount The first variable count.
058     * @param secondCount The second variable count.
059     */
060    public PairDistribution(long count, Map<CachedPair<T1,T2>,MutableLong> jointCounts, Map<T1,MutableLong> firstCount, Map<T2,MutableLong> secondCount) {
061        this.count = count;
062        this.jointCounts = new LinkedHashMap<>(jointCounts);
063        this.firstCount = new LinkedHashMap<>(firstCount);
064        this.secondCount = new LinkedHashMap<>(secondCount);
065    }
066
067    /**
068     * Constructs a pair distribution.
069     * @param count The total sample count.
070     * @param jointCounts The joint counts.
071     * @param firstCount The first variable count.
072     * @param secondCount The second variable count.
073     */
074    public PairDistribution(long count, LinkedHashMap<CachedPair<T1,T2>,MutableLong> jointCounts, LinkedHashMap<T1,MutableLong> firstCount, LinkedHashMap<T2,MutableLong> secondCount) {
075        this.count = count;
076        this.jointCounts = jointCounts;
077        this.firstCount = firstCount;
078        this.secondCount = secondCount;
079    }
080    
081    /**
082     * Generates the counts for two vectors. Returns a PairDistribution containing the joint
083     * count, and the two marginal counts.
084     * @param <T1> Type of the first array.
085     * @param <T2> Type of the second array.
086     * @param first An array of values.
087     * @param second Another array of values.
088     * @return The joint counts and the two marginal counts.
089     */
090    public static <T1,T2> PairDistribution<T1,T2> constructFromLists(List<T1> first, List<T2> second) {
091        LinkedHashMap<CachedPair<T1,T2>,MutableLong> abCountDist = new LinkedHashMap<>(InformationTheory.DEFAULT_MAP_SIZE);
092        LinkedHashMap<T1,MutableLong> aCountDist = new LinkedHashMap<>(InformationTheory.DEFAULT_MAP_SIZE);
093        LinkedHashMap<T2,MutableLong> bCountDist = new LinkedHashMap<>(InformationTheory.DEFAULT_MAP_SIZE);
094
095        if (first.size() == second.size()) {
096            long count = 0;
097            for (int i = 0; i < first.size(); i++) {
098                T1 a = first.get(i);
099                T2 b = second.get(i);
100                CachedPair<T1,T2> pair = new CachedPair<>(a,b);
101
102                MutableLong abCount = abCountDist.computeIfAbsent(pair, k -> new MutableLong());
103                abCount.increment();
104
105                MutableLong aCount = aCountDist.computeIfAbsent(a, k -> new MutableLong());
106                aCount.increment();
107
108                MutableLong bCount = bCountDist.computeIfAbsent(b, k -> new MutableLong());
109                bCount.increment();
110
111                count++;
112            }
113
114            return new PairDistribution<>(count,abCountDist,aCountDist,bCountDist);
115        } else {
116            throw new IllegalArgumentException("Counting requires arrays of the same length. first.size() = " + first.size() + ", second.size() = " + second.size());
117        }
118    }
119
120    /**
121     * Constructs a distribution from a joint count.
122     * @param jointCount The joint count.
123     * @param <T1> The type of the first variable.
124     * @param <T2> The type of the second variable.
125     * @return A pair distribution.
126     */
127    public static <T1,T2> PairDistribution<T1,T2> constructFromMap(Map<CachedPair<T1,T2>,MutableLong> jointCount) {
128        Map<T1,MutableLong> aCount = new HashMap<>(InformationTheory.DEFAULT_MAP_SIZE);
129        Map<T2,MutableLong> bCount = new HashMap<>(InformationTheory.DEFAULT_MAP_SIZE);
130
131        return constructFromMap(jointCount,aCount,bCount);
132    }
133
134    /**
135     * Constructs a distribution from a joint count.
136     * @param jointCount The joint count.
137     * @param aSize The initial size of the first marginal hash map.
138     * @param bSize The initial size of the second marginal hash map.
139     * @param <T1> The type of the first variable.
140     * @param <T2> The type of the second variable.
141     * @return A pair distribution.
142     */
143    public static <T1,T2> PairDistribution<T1,T2> constructFromMap(Map<CachedPair<T1,T2>,MutableLong> jointCount, int aSize, int bSize) {
144        Map<T1,MutableLong> aCount = new HashMap<>(aSize);
145        Map<T2,MutableLong> bCount = new HashMap<>(bSize);
146
147        return constructFromMap(jointCount,aCount,bCount);
148    }
149
150    /**
151     * Constructs a joint distribution from the counts.
152     * @param jointCount The joint count.
153     * @param aCount The first marginal count.
154     * @param bCount The second marginal count.
155     * @param <T1> The type of the first variable.
156     * @param <T2> The type of the second variable.
157     * @return A pair distribution.
158     */
159    public static <T1,T2> PairDistribution<T1,T2> constructFromMap(Map<CachedPair<T1,T2>,MutableLong> jointCount,
160                                                                           Map<T1,MutableLong> aCount,
161                                                                           Map<T2,MutableLong> bCount) {
162        long count = 0L;
163        
164        for (Entry<CachedPair<T1,T2>,MutableLong> e : jointCount.entrySet()) {
165            CachedPair<T1,T2> pair = e.getKey();
166            long curCount = e.getValue().longValue();
167            T1 a = pair.getA();
168            T2 b = pair.getB();
169
170            MutableLong curACount = aCount.computeIfAbsent(a, k -> new MutableLong());
171            curACount.increment(curCount);
172
173            MutableLong curBCount = bCount.computeIfAbsent(b, k -> new MutableLong());
174            curBCount.increment(curCount);
175            count += curCount;
176        }
177
178        return new PairDistribution<>(count,jointCount,aCount,bCount);
179    }
180
181}