/*
 * Copyright (c) 2015-2021, Oracle and/or its affiliates. All rights reserved.
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 *     http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */
package org.tribuo.util.tokens.impl.wordpiece;

import java.text.Normalizer;
import java.util.ArrayList;
import java.util.Collections;
import java.util.List;
import java.util.Locale;
import java.util.Set;
import java.util.regex.Pattern;

import org.tribuo.util.tokens.Token;
import org.tribuo.util.tokens.Token.TokenType;
import org.tribuo.util.tokens.Tokenizer;
import org.tribuo.util.tokens.impl.WhitespaceTokenizer;

import com.oracle.labs.mlrg.olcut.config.Config;
import com.oracle.labs.mlrg.olcut.provenance.ConfiguredObjectProvenance;
import com.oracle.labs.mlrg.olcut.provenance.impl.ConfiguredObjectProvenanceImpl;

/**
 * This Tokenizer is meant to be a reasonable approximation of the BertTokenizer
 * defined <a href=
 * "https://github.com/huggingface/transformers/blob/master/src/transformers/models/bert/tokenization_bert.py#L117">here</a>.
 * Please see class definition for <code>BertTokenizer</code> (the line numbers
 * may change.) Please see notes in WordpieceTokenizerTest for information about
 * how we tested the similarity between this tokenizer and the referenced python
 * implementation and for regression test examples that fail. In short, there
 * are outstanding discrepancies for texts that include Arabic and other
 * non-latin scripts that generate so many "[UNK]" tokens for an English-based
 * BPE vocabulary as to render the discrepancies as practically meaningless.
 * <p>
 * As in the reference implementation, the input text is whitespace tokenized
 * and each token is further tokenized to account for things like punctuation
 * and Chinese characters. The resulting tokens are then applied to the
 * wordpiece algorithm implemented in {@link Wordpiece} which is driven by an
 * input vocabulary which matches tokens and token suffixes as it can. Any
 * tokens that are not found in the input vocabulary are marked as "unknown".
 */
public class WordpieceTokenizer implements Tokenizer {

    /** Matches non-spacing combining marks left behind by NFD decomposition. */
    private static final Pattern accentsPattern = Pattern.compile("\\p{Mn}");

    @Config(mandatory=true, description="an instance of Wordpiece which applies the 'wordpiece' algorithm")
    private Wordpiece wordpiece;
    @Config(description="determines whether or not to lowercase the input text")
    private boolean toLowerCase = true;
    @Config(description="performs whitespace tokenization before 'basic' tokenizer is applied (see basicTokenizer)")
    private Tokenizer whitespaceTokenizer = new WhitespaceTokenizer();
    @Config(description="performs some tokenization work on the input text before the wordpiece algorithm is applied to each resulting token.")
    private Tokenizer basicTokenizer = new WordpieceBasicTokenizer();
    @Config(description="determines whether or not to strip accents/diacritics from the input text")
    private boolean stripAccents = true;
    @Config(description="a set of 'token' strings that should never be split regardless of whether they have e.g., punctuation in the middle. No entries should have whitespace in them.")
    private Set<String> neverSplitTokens = Collections.emptySet();

    // internal state member variables

    // true once reset(CharSequence) has been called; guards advance()
    private boolean reset;

    // the current whitespace-delimited token being expanded into wordpieces
    private Token currentToken;

    // wordpiece tokens generated from currentToken, consumed via advance()/getToken()
    private List<Token> currentWordpieceTokens = new ArrayList<>();

    // index into currentWordpieceTokens of the token returned by getToken()
    private int currentWordpieceIndex;

    /**
     * For OLCUT.
     */
    @SuppressWarnings("unused")
    private WordpieceTokenizer() {
    }

    /**
     * Constructs a wordpiece tokenizer.
     * @param wordpiece an instance of {@link Wordpiece}
     * @param tokenizer Wordpiece is run on the tokens generated by the
     *                  tokenizer provided here.
     * @param toLowerCase determines whether or not to lowercase each token
     *                    before running Wordpiece on it
     * @param stripAccents determines whether or not to strip out accents from
     *                     each token before running Wordpiece on it
     * @param neverSplit a set of token values that we will not apply
     *                   Wordpiece to.
     */
    public WordpieceTokenizer(Wordpiece wordpiece, Tokenizer tokenizer, boolean toLowerCase, boolean stripAccents,
                              Set<String> neverSplit) {
        this.wordpiece = wordpiece;
        this.basicTokenizer = tokenizer;
        this.toLowerCase = toLowerCase;
        this.stripAccents = stripAccents;
        this.neverSplitTokens = neverSplit;
    }

    @Override
    public ConfiguredObjectProvenance getProvenance() {
        return new ConfiguredObjectProvenanceImpl(this, "Tokenizer");
    }

    @Override
    public void reset(CharSequence cs) {
        this.reset = true;
        this.whitespaceTokenizer.reset(cs);
        this.currentWordpieceTokens.clear();
        currentWordpieceIndex = -1;
        if (this.whitespaceTokenizer.advance()) {
            this.currentToken = this.whitespaceTokenizer.getToken();
            getWordpieceTokens();
        }
    }

    @Override
    public boolean advance() {
        if (!reset) {
            throw new IllegalStateException("WordpieceTokenizer has not been reset.");
        }
        currentWordpieceIndex++;
        if (currentWordpieceIndex < currentWordpieceTokens.size()) {
            return true;
        }
        // Loop (rather than recurse as before) past whitespace tokens that yield no
        // wordpiece tokens, so pathological input cannot grow the call stack.
        while (whitespaceTokenizer.advance()) {
            currentToken = this.whitespaceTokenizer.getToken();
            getWordpieceTokens();
            currentWordpieceIndex = 0;
            if (!currentWordpieceTokens.isEmpty()) {
                return true;
            }
        }
        return false;
    }

    /**
     * Normalizes the text by converting it into the canonical unicode decomposition
     * and then replacing accents.
     * @param text The input text to normalize.
     * @return A normalized form of the text.
     */
    private static String normalize(String text) {
        text = Normalizer.normalize(text, Normalizer.Form.NFD);
        text = accentsPattern.matcher(text).replaceAll("");
        return text;
    }

    /**
     * Generates the wordpiece tokens from the current whitespace token and
     * stores them in {@link #currentWordpieceTokens}.
     */
    private void getWordpieceTokens() {
        this.currentWordpieceTokens.clear();

        String text = currentToken.text;
        // never-split tokens bypass the basic tokenizer and wordpiece entirely
        if (neverSplitTokens.contains(text)) {
            currentWordpieceTokens.add(currentToken);
            return;
        }

        List<Token> basicTokens = this.basicTokenizer.tokenize(text);
        for (Token basicToken : basicTokens) {

            text = basicToken.text;

            if (toLowerCase) {
                // Locale.ROOT keeps lowercasing locale-independent (matching the
                // Python reference's str.lower()); the default locale would e.g.
                // map "I" to dotless "ı" on a Turkish-locale JVM and break
                // lookups against an English vocabulary.
                text = text.toLowerCase(Locale.ROOT);
            }

            if (this.stripAccents) {
                text = normalize(text);
            }

            List<String> wordpieces = wordpiece.wordpiece(text);

            if (wordpieces.isEmpty()) {
                // Skip just this basic token; the previous 'return' here dropped
                // all remaining basic tokens of the current whitespace token.
                continue;
            } else if (wordpieces.size() == 1) {
                String wp = wordpieces.get(0);
                // basicToken offsets are relative to currentToken; rebase to the input text
                int start = basicToken.start + currentToken.start;
                int end = basicToken.end + currentToken.start;
                if (wp.equals(this.wordpiece.getUnknownToken())) {
                    currentWordpieceTokens.add(new Token(wp, start, end, TokenType.UNKNOWN));
                } else {
                    currentWordpieceTokens.add(new Token(wp, start, end, TokenType.WORD));
                }
            } else {
                int begin = currentToken.start + basicToken.start;
                for (String wp : wordpieces) {
                    TokenType type = TokenType.PREFIX;
                    int end = begin + wp.length();
                    if (wp.startsWith("##")) {
                        // the "##" continuation marker is not part of the input text
                        end -= 2;
                        type = TokenType.SUFFIX;
                    }
                    currentWordpieceTokens.add(new Token(wp, begin, end, type));
                    begin = end;
                }
            }
        }
    }

    @Override
    public Token getToken() {
        if (currentWordpieceIndex < currentWordpieceTokens.size()) {
            return currentWordpieceTokens.get(currentWordpieceIndex);
        } else {
            throw new IllegalStateException("WordpieceTokenizer is not ready.");
        }
    }

    @Override
    public String getText() {
        return getToken().text;
    }

    @Override
    public int getStart() {
        return getToken().start;
    }

    @Override
    public int getEnd() {
        return getToken().end;
    }

    @Override
    public TokenType getType() {
        return getToken().type;
    }

    @Override
    public WordpieceTokenizer clone() {
        try {
            WordpieceTokenizer copy = (WordpieceTokenizer) super.clone();
            copy.whitespaceTokenizer = whitespaceTokenizer.clone();
            copy.basicTokenizer = basicTokenizer.clone();
            copy.reset = false;
            copy.currentToken = null;
            // super.clone() is shallow: the copy's list field aliases this
            // instance's list, so clearing it would wipe this tokenizer's
            // pending tokens. Give the copy its own empty list instead.
            copy.currentWordpieceTokens = new ArrayList<>();
            copy.currentWordpieceIndex = -1;
            return copy;
        } catch (CloneNotSupportedException e) {
            throw new AssertionError("WordpieceTokenizer is Cloneable, but clone call failed");
        }
    }
}