001/*
002 * Copyright (c) 2015-2021, Oracle and/or its affiliates. All rights reserved.
003 *
004 * Licensed under the Apache License, Version 2.0 (the "License");
005 * you may not use this file except in compliance with the License.
006 * You may obtain a copy of the License at
007 *
008 *     http://www.apache.org/licenses/LICENSE-2.0
009 *
010 * Unless required by applicable law or agreed to in writing, software
011 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
013 * See the License for the specific language governing permissions and
014 * limitations under the License.
015 */
016package org.tribuo.util.tokens.impl.wordpiece;
017
import java.text.Normalizer;
import java.util.ArrayList;
import java.util.Collections;
import java.util.List;
import java.util.Locale;
import java.util.Set;
import java.util.regex.Pattern;

import org.tribuo.util.tokens.Token;
import org.tribuo.util.tokens.Token.TokenType;
import org.tribuo.util.tokens.Tokenizer;
import org.tribuo.util.tokens.impl.WhitespaceTokenizer;

import com.oracle.labs.mlrg.olcut.config.Config;
import com.oracle.labs.mlrg.olcut.provenance.ConfiguredObjectProvenance;
import com.oracle.labs.mlrg.olcut.provenance.impl.ConfiguredObjectProvenanceImpl;
033
034/**
035 * This Tokenizer is meant to be a reasonable approximation of the BertTokenizer
036 * defined <a href=
037 * "https://github.com/huggingface/transformers/blob/master/src/transformers/models/bert/tokenization_bert.py#L117">here</a>.
038 * Please see class definition for <code>BertTokenizer</code> (the line numbers
 * may change). Please see notes in WordpieceTokenizerTest for information about
040 * how we tested the similarity between this tokenizer and the referenced python
041 * implementation and for regression test examples that fail. In short, there
042 * are outstanding discrepancies for texts that include Arabic and other
043 * non-latin scripts that generate so many "[UNK]" tokens for an English-based
044 * BPE vocabulary as to render the discrepancies as practically meaningless.
045 * <p>
046 * As in the reference implementation, the input text is whitespace tokenized
047 * and each token is further tokenized to account for things like punctuation
048 * and Chinese characters. The resulting tokens are then applied to the
049 * wordpiece algorithm implemented in {@link Wordpiece} which is driven by an
050 * input vocabulary which matches tokens and token suffixes as it can. Any
 * tokens that are not found in the input vocabulary are marked as "unknown".
052 */
053public class WordpieceTokenizer implements Tokenizer {
054
055    private static final Pattern accentsPattern = Pattern.compile("\\p{Mn}");
056
057    @Config(mandatory=true, description="an instance of Wordpiece which applies the 'wordpiece' algorithm")
058    private Wordpiece wordpiece;
059    @Config(description="determines whether or not to lowercase the input text")
060    private boolean toLowerCase = true;
061    @Config(description="performs whitespace tokenization before 'basic' tokenizer is applied (see basicTokenizer)")
062    private Tokenizer whitespaceTokenizer = new WhitespaceTokenizer();
063    @Config(description="performs some tokenization work on the input text before the wordpiece algorithm is applied to each resulting token.")
064    private Tokenizer basicTokenizer = new WordpieceBasicTokenizer();
065    @Config(description="determines whether or not to strip accents/diacritics from the input text")
066    private boolean stripAccents = true;
067    @Config(description="a set of 'token' strings that should never be split regardless of whether they have e.g., punctuation in the middle.  No entries should have whitespace in them.")
068    private Set<String> neverSplitTokens = Collections.emptySet();
069
070    // internal state member variables
071    private boolean reset;
072
073    private Token currentToken;
074
075    private List<Token> currentWordpieceTokens = new ArrayList<>();
076
077    private int currentWordpieceIndex;
078
079    /**
080     * For OLCUT.
081     */
082    @SuppressWarnings("unused")
083    private WordpieceTokenizer() {
084    }
085
086    /**
087     * Constructs a wordpiece tokenizer.
088     * @param wordpiece        an instance of {@link Wordpiece}
089     * @param tokenizer        Wordpiece is run on the tokens generated by the
090     *                         tokenizer provided here.
091     * @param toLowerCase      determines whether or not to lowercase each token
092     *                         before running Wordpiece on it
093     * @param stripAccents     determines whether or not to strip out accents from
094     *                         each token before running Wordpiece on it
095     * @param neverSplit       a set of token values that we will not apply
096     *                         Wordpiece to. 
097     */
098    public WordpieceTokenizer(Wordpiece wordpiece, Tokenizer tokenizer, boolean toLowerCase, boolean stripAccents,
099            Set<String> neverSplit) {
100        this.wordpiece = wordpiece;
101        this.basicTokenizer = tokenizer;
102        this.toLowerCase = toLowerCase;
103        this.stripAccents = stripAccents;
104        this.neverSplitTokens = neverSplit;
105    }
106
107    @Override
108    public ConfiguredObjectProvenance getProvenance() {
109        return new ConfiguredObjectProvenanceImpl(this, "Tokenizer");
110    }
111
112    @Override
113    public void reset(CharSequence cs) {
114        this.reset = true;
115        this.whitespaceTokenizer.reset(cs);
116        this.currentWordpieceTokens.clear();
117        currentWordpieceIndex = -1;
118        if (this.whitespaceTokenizer.advance()) {
119            this.currentToken = this.whitespaceTokenizer.getToken();
120            getWordpieceTokens();
121        }
122    }
123
124    @Override
125    public boolean advance() {
126        if (!reset) {
127            throw new IllegalStateException("WordpieceTokenizer has not been reset.");
128        }
129        currentWordpieceIndex++;
130        if (currentWordpieceIndex < currentWordpieceTokens.size()) {
131            return true;
132        } else if (whitespaceTokenizer.advance()) {
133            currentToken = this.whitespaceTokenizer.getToken();
134            getWordpieceTokens();
135            currentWordpieceIndex = 0;
136            if (currentWordpieceTokens.size() == 0) {
137                return advance();
138            }
139            return true;
140        } else {
141            return false;
142        }
143    }
144
145    /**
146     * Normalizes the text by converting it into the canonical unicode decomposition
147     * and then replacing accents.
148     * @param text The input text to normalize.
149     * @return A normalized form of the text.
150     */
151    private static String normalize(String text) {
152        text = Normalizer.normalize(text, Normalizer.Form.NFD);
153        text = accentsPattern.matcher(text).replaceAll("");
154        return text;
155    }
156
157    /**
158     * Generates the wordpiece tokens from the next token.
159     */
160    private void getWordpieceTokens() {
161        this.currentWordpieceTokens.clear();
162
163        String text = currentToken.text;
164        if(neverSplitTokens.contains(text)) {
165            currentWordpieceTokens.add(currentToken);
166            return;
167        }
168        
169        List<Token> basicTokens = this.basicTokenizer.tokenize(text);
170        for(Token basicToken : basicTokens) {
171            
172            text = basicToken.text;
173            
174            if (toLowerCase) {
175                text = text.toLowerCase();
176            }
177    
178            if (this.stripAccents) {
179                text = normalize(text);
180            }
181    
182            List<String> wordpieces = wordpiece.wordpiece(text);
183    
184            if (wordpieces.size() == 0) {
185                return;
186            } else if (wordpieces.size() == 1) {
187                String wp = wordpieces.get(0);
188                int start = basicToken.start + currentToken.start;
189                int end = basicToken.end + currentToken.start;
190                if (wp.equals(this.wordpiece.getUnknownToken())) {
191                    currentWordpieceTokens.add(new Token(wp, start, end, TokenType.UNKNOWN));
192                } else {
193                    currentWordpieceTokens.add(new Token(wp, start, end, TokenType.WORD));
194                }
195            } else {
196                int begin = currentToken.start + basicToken.start;
197                for (String wp : wordpieces) {
198                    TokenType type = TokenType.PREFIX;
199                    int end = begin + wp.length();
200                    if (wp.startsWith("##")) {
201                        end -= 2;
202                        type = TokenType.SUFFIX;
203                    }
204                    currentWordpieceTokens.add(new Token(wp, begin, end, type));
205                    begin = end;
206                }
207            }
208        }
209    }
210
211    @Override
212    public Token getToken() {
213        if (currentWordpieceIndex < currentWordpieceTokens.size()) {
214            return currentWordpieceTokens.get(currentWordpieceIndex);
215        } else {
216            throw new IllegalStateException("WordpieceTokenizer is not ready.");
217        }
218    }
219
220    @Override
221    public String getText() {
222        return getToken().text;
223    }
224
225    @Override
226    public int getStart() {
227        return getToken().start;
228    }
229
230    @Override
231    public int getEnd() {
232        return getToken().end;
233    }
234
235    @Override
236    public TokenType getType() {
237        return getToken().type;
238    }
239
240    @Override
241    public WordpieceTokenizer clone() {
242      try {
243          WordpieceTokenizer copy = (WordpieceTokenizer) super.clone();
244          copy.whitespaceTokenizer = whitespaceTokenizer.clone();
245          copy.basicTokenizer = basicTokenizer.clone();
246          copy.reset = false;
247          copy.currentToken = null;
248          copy.currentWordpieceTokens.clear();
249          copy.currentWordpieceIndex = -1;
250          return copy;
251      } catch (CloneNotSupportedException e) {
252          throw new AssertionError("WordpieceTokenizer is Cloneable, but clone call failed");
253      }
254    }
255}