/*
 * Copyright (c) 2015-2020, Oracle and/or its affiliates. All rights reserved.
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 *     http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

package org.tribuo.util.infotheory.example;

import com.oracle.labs.mlrg.olcut.config.ConfigurationManager;
import com.oracle.labs.mlrg.olcut.config.Option;
import com.oracle.labs.mlrg.olcut.config.Options;
import com.oracle.labs.mlrg.olcut.config.UsageException;
import org.tribuo.util.infotheory.InformationTheory;
import org.tribuo.util.infotheory.impl.CachedTriple;

import java.util.ArrayList;
import java.util.List;
import java.util.Random;
import java.util.logging.Level;
import java.util.logging.Logger;

/**
 * Demo showing how to calculate various mutual informations and entropies.
 * <p>
 * All generators draw from a single shared {@link Random} seeded with 1, so the
 * values produced depend on the order in which the generators are called.
 */
public class InformationTheoryDemo {

    private static final Logger logger = Logger.getLogger(InformationTheoryDemo.class.getName());

    private static final Random rng = new Random(1);

    /** Number of samples drawn for each variable by {@link #main(String[])}. */
    private static final int SAMPLE_COUNT = 1000;

    /** Alphabet size used by {@link #main(String[])} for the non-binary distributions. */
    private static final int ALPHABET_SIZE = 5;

    /**
     * Generates a sample from a uniform distribution over the integers.
     * @param length The number of samples.
     * @param alphabetSize The alphabet size (i.e., the number of unique values).
     * @return A sample from a uniform distribution.
     */
    public static List<Integer> generateUniform(int length, int alphabetSize) {
        List<Integer> vector = new ArrayList<>(length);

        for (int i = 0; i < length; i++) {
            vector.add(rng.nextInt(alphabetSize));
        }

        return vector;
    }

    /**
     * Generates a sample from a three variable XOR function.
     * <p>
     * Each list is a binary variable, and the third is the XOR of the first two.
     * @param length The number of samples.
     * @return A sample from an XOR function.
     */
    public static CachedTriple<List<Integer>,List<Integer>,List<Integer>> generateXOR(int length) {
        List<Integer> first = new ArrayList<>(length);
        List<Integer> second = new ArrayList<>(length);
        List<Integer> xor = new ArrayList<>(length);

        for (int i = 0; i < length; i++) {
            // Draw order (first, then second) is kept fixed so the seeded output is reproducible.
            int firstVal = rng.nextInt(2);
            int secondVal = rng.nextInt(2);
            first.add(firstVal);
            second.add(secondVal);
            xor.add(firstVal ^ secondVal);
        }

        return new CachedTriple<>(first,second,xor);
    }

    /**
     * Generates a sample from three correlated random variables X, Y and Z.
     * <p>
     * X is drawn uniformly. Y copies X when a uniform draw falls below {@code xyCorrelation},
     * otherwise it is drawn uniformly and independently; Z is generated the same way
     * using {@code xzCorrelation}.
     * <p>
     * These correlations don't map to mutual information values, as if xyDraw is above
     * xyCorrelation then the draw is completely random.
     * <p>
     * To make it generate correlations of a specific mutual information then it needs to
     * specify the full joint distribution and draw from that.
     * @param length The number of samples.
     * @param alphabetSize The alphabet size (i.e., the number of unique values).
     * @param xyCorrelation Value between 0.0 and 1.0 specifying how likely it is that Y has the same value as X.
     * @param xzCorrelation Value between 0.0 and 1.0 specifying how likely it is that Z has the same value as X.
     * @return A triple of samples drawn from correlated random variables.
     */
    public static CachedTriple<List<Integer>,List<Integer>,List<Integer>> generateCorrelated(int length, int alphabetSize, double xyCorrelation, double xzCorrelation) {
        List<Integer> first = new ArrayList<>(length);
        List<Integer> second = new ArrayList<>(length);
        List<Integer> third = new ArrayList<>(length);

        for (int i = 0; i < length; i++) {
            int firstVal = rng.nextInt(alphabetSize);
            first.add(firstVal);

            // Y: copy X with probability xyCorrelation, otherwise an independent uniform draw.
            double xyDraw = rng.nextDouble();
            if (xyDraw < xyCorrelation) {
                second.add(firstVal);
            } else {
                second.add(rng.nextInt(alphabetSize));
            }

            // Z: copy X with probability xzCorrelation, otherwise an independent uniform draw.
            double xzDraw = rng.nextDouble();
            if (xzDraw < xzCorrelation) {
                third.add(firstVal);
            } else {
                third.add(rng.nextInt(alphabetSize));
            }
        }

        return new CachedTriple<>(first,second,third);
    }

    /**
     * Type of data distribution.
     */
    public enum DistributionType {
        /**
         * Uniformly randomly generated data.
         */
        RANDOM,
        /**
         * Data generated from an XOR function.
         */
        XOR,
        /**
         * Correlated data.
         */
        CORRELATED
    }

    /**
     * Command line options.
     */
    public static class DemoOptions implements Options {
        @Override
        public String getOptionsDescription() {
            return "A demo class showing how to calculate various mutual informations from different inputs.";
        }

        /**
         * The type of the input distribution.
         */
        @Option(charName = 't', longName = "type", usage = "The type of the input distribution.")
        public DistributionType type = DistributionType.RANDOM;
    }

    /**
     * Runs a simple demo of the information theory functions.
     * <p>
     * Generates three variables according to the requested {@link DistributionType},
     * then logs their entropies, pairwise joint entropies, pairwise mutual informations
     * and pairwise G-test statistics. If argument parsing raises a {@link UsageException}
     * the usage text is printed and the demo exits without computing anything.
     * @param args The CLI arguments.
     */
    public static void main(String[] args) {

        DemoOptions options = new DemoOptions();

        try {
            // Invoked for its side effect of parsing args into options;
            // the manager itself is not used afterwards.
            new ConfigurationManager(args, options, false);
        } catch (UsageException e) {
            System.out.println(e.getUsage());
            // Do not run the demo after printing the usage statement.
            return;
        }

        List<Integer> x;
        List<Integer> y;
        List<Integer> z;

        switch (options.type) {
            case RANDOM:
                x = generateUniform(SAMPLE_COUNT, ALPHABET_SIZE);
                y = generateUniform(SAMPLE_COUNT, ALPHABET_SIZE);
                z = generateUniform(SAMPLE_COUNT, ALPHABET_SIZE);
                break;
            case XOR:
                CachedTriple<List<Integer>,List<Integer>,List<Integer>> trip = generateXOR(SAMPLE_COUNT);
                x = trip.getA();
                y = trip.getB();
                z = trip.getC();
                break;
            case CORRELATED:
                CachedTriple<List<Integer>,List<Integer>,List<Integer>> tripC = generateCorrelated(SAMPLE_COUNT,ALPHABET_SIZE,0.7,0.5);
                x = tripC.getA();
                y = tripC.getB();
                z = tripC.getC();
                break;
            default:
                // Also required so that x, y and z are definitely assigned below.
                logger.log(Level.WARNING, "Unknown test case, exiting");
                return;
        }

        double hx = InformationTheory.entropy(x);
        double hy = InformationTheory.entropy(y);
        double hz = InformationTheory.entropy(z);

        double hxy = InformationTheory.jointEntropy(x,y);
        double hxz = InformationTheory.jointEntropy(x,z);
        double hyz = InformationTheory.jointEntropy(y,z);

        double ixy = InformationTheory.mi(x,y);
        double ixz = InformationTheory.mi(x,z);
        double iyz = InformationTheory.mi(y,z);

        InformationTheory.GTestStatistics gxy = InformationTheory.gTest(x,y,null);
        InformationTheory.GTestStatistics gxz = InformationTheory.gTest(x,z,null);
        InformationTheory.GTestStatistics gyz = InformationTheory.gTest(y,z,null);

        if (InformationTheory.LOG_BASE == InformationTheory.LOG_2) {
            logger.log(Level.INFO, "Using log_2");
        } else if (InformationTheory.LOG_BASE == InformationTheory.LOG_E) {
            logger.log(Level.INFO, "Using log_e");
        } else {
            logger.log(Level.INFO, "Using unexpected log base, LOG_BASE = " + InformationTheory.LOG_BASE);
        }

        logger.log(Level.INFO, "The entropy of X, H(X) is " + hx);
        logger.log(Level.INFO, "The entropy of Y, H(Y) is " + hy);
        logger.log(Level.INFO, "The entropy of Z, H(Z) is " + hz);

        logger.log(Level.INFO, "The joint entropy of X and Y, H(X,Y) is " + hxy);
        logger.log(Level.INFO, "The joint entropy of X and Z, H(X,Z) is " + hxz);
        logger.log(Level.INFO, "The joint entropy of Y and Z, H(Y,Z) is " + hyz);

        logger.log(Level.INFO, "The mutual information between X and Y, I(X;Y) is " + ixy);
        logger.log(Level.INFO, "The mutual information between X and Z, I(X;Z) is " + ixz);
        logger.log(Level.INFO, "The mutual information between Y and Z, I(Y;Z) is " + iyz);

        logger.log(Level.INFO, "The G-Test between X and Y, G(X;Y) is " + gxy);
        logger.log(Level.INFO, "The G-Test between X and Z, G(X;Z) is " + gxz);
        logger.log(Level.INFO, "The G-Test between Y and Z, G(Y;Z) is " + gyz);
    }

}