/*
 * Copyright (c) 2015-2021, Oracle and/or its affiliates. All rights reserved.
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 *     http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

package org.tribuo.util.tokens.impl.wordpiece;

import java.io.IOException;
import java.io.UncheckedIOException;
import java.util.ArrayList;
import java.util.Collections;
import java.util.HashSet;
import java.util.List;
import java.util.Objects;
import java.util.Set;

import com.oracle.labs.mlrg.olcut.config.Config;
import com.oracle.labs.mlrg.olcut.config.Configurable;
import com.oracle.labs.mlrg.olcut.util.IOUtil;

/**
 * This is a vanilla implementation of the Wordpiece algorithm as found here:
 *
 * <a href=
 * "https://github.com/huggingface/transformers/blob/master/src/transformers/models/bert/tokenization_bert.py">
 * https://github.com/huggingface/transformers/blob/master/src/transformers/models/bert/tokenization_bert.py</a>
 *
 * <p>
 * Please refer to the class definition for <code>WordpieceTokenizer</code>. It
 * does not include any of the tokenization work that is typically performed
 * before wordpiece is called as is done in the above-referenced implementation.
 * That functionality is provided by {@link WordpieceTokenizer} and
 * {@link WordpieceBasicTokenizer}.
 *
 */
public class Wordpiece implements Configurable {

    /**
     * The default unknown token string.
     */
    public static final String DEFAULT_UNKNOWN_TOKEN = "[UNK]";

    // Default cap on word length; shared by the field initializer and the
    // two-argument constructor so the value is defined in exactly one place.
    private static final int DEFAULT_MAX_INPUT_CHARACTERS_PER_WORD = 100;

    @Config(mandatory=true, description="path to a vocabulary data file.")
    private String vocabPath;
    @Config(mandatory=false, description="the value to use for 'UNKNOWN' tokens. Defaults to '[UNK]' which is a common default in BERT-based solutions.")
    private String unknownToken = DEFAULT_UNKNOWN_TOKEN;
    @Config(mandatory=false, description="the maximum number of characters per word to consider. This helps eliminate doing extra work on pathological cases.")
    private int maxInputCharactersPerWord = DEFAULT_MAX_INPUT_CHARACTERS_PER_WORD;

    // Immutable view over a private copy of the vocabulary (see constructor and postConfig).
    private Set<String> vocab;

    /**
     * For OLCUT.
     */
    @SuppressWarnings("unused")
    private Wordpiece() { }

    /**
     * Constructs a Wordpiece using the supplied vocab.
     * <p>
     * Sets the unknown token to {@link #DEFAULT_UNKNOWN_TOKEN}.
     * @param vocab The wordpiece vocabulary.
     */
    public Wordpiece(Set<String> vocab) {
        this(vocab, DEFAULT_UNKNOWN_TOKEN);
    }

    /**
     * Constructs a Wordpiece using the supplied vocabulary and unknown token.
     * @param vocab The wordpiece vocabulary.
     * @param unknownToken The unknown token.
     */
    public Wordpiece(Set<String> vocab, String unknownToken) {
        this(vocab, unknownToken, DEFAULT_MAX_INPUT_CHARACTERS_PER_WORD);
    }

    /**
     * Initializes an instance of Wordpiece with the given vocabulary, unknown
     * token, and max word length.
     *
     * @param vocab the pre-trained wordpiece vocabulary. See
     *        the contents of e.g.,
     *        <a href="https://huggingface.co/bert-base-uncased/resolve/main/vocab.txt">
     *        https://huggingface.co/bert-base-uncased/resolve/main/vocab.txt</a>
     * @param unknownToken a string used to indicate a token was not
     *        found in the vocabulary - typically "[UNK]"
     * @param maxInputCharactersPerWord a maximum to shield against looping over
     *        character-by-character pathologically long "tokens"
     * @throws NullPointerException if {@code vocab} or {@code unknownToken} is null.
     * @throws IllegalArgumentException if {@code maxInputCharactersPerWord} is not positive.
     */
    public Wordpiece(Set<String> vocab, String unknownToken, int maxInputCharactersPerWord) {
        Objects.requireNonNull(vocab, "vocab");
        Objects.requireNonNull(unknownToken, "unknownToken");
        if (maxInputCharactersPerWord < 1) {
            throw new IllegalArgumentException("maxInputCharactersPerWord must be positive, found " + maxInputCharactersPerWord);
        }
        // Defensive copy before wrapping: without it, mutations of the caller's
        // set would leak through the unmodifiable view. This also matches the
        // copy performed in postConfig().
        this.vocab = Collections.unmodifiableSet(new HashSet<>(vocab));
        this.unknownToken = unknownToken;
        this.maxInputCharactersPerWord = maxInputCharactersPerWord;
    }

    /**
     * Constructs a wordpiece by reading the vocabulary from the supplied path.
     * @param vocabPath The path to the wordpiece vocabulary.
     * @throws UncheckedIOException if the vocabulary could not be read.
     */
    public Wordpiece(String vocabPath) {
        this.vocabPath = vocabPath;
        try {
            this.postConfig();
        } catch (IOException e) {
            throw new UncheckedIOException(e);
        }
    }

    /**
     * Initializes an instance of Wordpiece with the given vocabulary, unknown
     * token, and max word length.
     *
     * @param vocabPath Path to the pre-trained wordpiece vocabulary. See
     *        the contents of e.g.
     *        https://huggingface.co/bert-base-uncased/resolve/main/vocab.txt
     * @param unknownToken a string used to indicate a token was not
     *        found in the vocabulary - typically "[UNK]"
     * @param maxInputCharactersPerWord a maximum to shield against looping over
     *        character-by-character pathologically long "tokens"
     * @throws NullPointerException if {@code unknownToken} is null.
     * @throws IllegalArgumentException if {@code maxInputCharactersPerWord} is not positive.
     * @throws UncheckedIOException if the vocabulary could not be read.
     */
    public Wordpiece(String vocabPath, String unknownToken, int maxInputCharactersPerWord) {
        Objects.requireNonNull(unknownToken, "unknownToken");
        if (maxInputCharactersPerWord < 1) {
            throw new IllegalArgumentException("maxInputCharactersPerWord must be positive, found " + maxInputCharactersPerWord);
        }
        this.vocabPath = vocabPath;
        this.unknownToken = unknownToken;
        this.maxInputCharactersPerWord = maxInputCharactersPerWord;
        try {
            this.postConfig();
        } catch (IOException e) {
            throw new UncheckedIOException(e);
        }
    }

    /**
     * Used by the OLCUT configuration system, and should not be called by external code.
     * @throws IOException if the vocabulary file at {@code vocabPath} could not be read.
     */
    @Override
    public void postConfig() throws IOException {
        // One vocabulary entry per line of the file.
        this.vocab = Collections.unmodifiableSet(new HashSet<>(IOUtil.getLines(this.vocabPath)));
    }

    /**
     * Executes Wordpiece tokenization on the provided token. Note that tokens
     * corresponding to word suffixes as indicated in the provided vocabulary with
     * the sequence "##" prepended to the entry may be returned by this method. This
     * method does not perform whitespace tokenization or any other preprocessing.
     * This method does not lowercase the input token or otherwise modify it in any
     * way.
     *
     * @param token the token to apply Wordpiece tokenization to.
     * @return tokens corresponding to Wordpiece tokenization applied to the input
     *         text. Some tokens may have a prefix "##" as described above. Some
     *         tokens may correspond to an unknown token as specified during
     *         initialization (default "[UNK]")
     */
    public List<String> wordpiece(String token) {
        // Pathologically long inputs are not analyzed character-by-character.
        if (token.length() > this.maxInputCharactersPerWord) {
            return Collections.singletonList(this.unknownToken);
        }

        List<String> subTokens = new ArrayList<>();

        boolean isBad = false;
        int start = 0;
        // Greedy longest-match-first: repeatedly take the longest prefix (or
        // "##"-prefixed suffix piece) present in the vocabulary.
        while (start < token.length()) {
            int end = token.length();
            String currentSubstring = null;
            while (start < end) {
                String substring = token.substring(start, end);
                if (start > 0) {
                    // Non-initial pieces are looked up with the continuation marker.
                    substring = "##" + substring;
                }
                if (this.vocab.contains(substring)) {
                    currentSubstring = substring;
                    break;
                }
                end--;
            }
            if (currentSubstring == null) {
                // No vocabulary entry matches any piece starting here: the whole
                // token is mapped to the unknown token.
                isBad = true;
                break;
            }
            subTokens.add(currentSubstring);
            start = end;
        }
        if (isBad) {
            return Collections.singletonList(this.unknownToken);
        } else {
            return subTokens;
        }
    }

    /**
     * a getter for the "unknown" token specified during initialization.
     *
     * @return the "unknown" token name - defaults to "[UNK]"
     */
    public String getUnknownToken() {
        return unknownToken;
    }

    /**
     * a getter for the maximum character count for a token to consider when
     * {@link #wordpiece(String)} is applied to a token. This value is set at
     * initialization and defaults to 100. Token values passed to that method that
     * exceed this length are not tokenized, and the result of
     * {@link #getUnknownToken()} is returned instead.
     *
     * @return the maximum length of a token that will be analyzed by
     *         {@link #wordpiece(String)}.
     */
    public int getMaxInputCharactersPerWord() {
        return maxInputCharactersPerWord;
    }

}