001/*
002 * Copyright (c) 2015-2021, Oracle and/or its affiliates. All rights reserved.
003 *
004 * Licensed under the Apache License, Version 2.0 (the "License");
005 * you may not use this file except in compliance with the License.
006 * You may obtain a copy of the License at
007 *
008 *     http://www.apache.org/licenses/LICENSE-2.0
009 *
010 * Unless required by applicable law or agreed to in writing, software
011 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
013 * See the License for the specific language governing permissions and
014 * limitations under the License.
015 */
016
017package org.tribuo.util.tokens.impl.wordpiece;
018
import java.io.IOException;
import java.io.UncheckedIOException;
import java.util.ArrayList;
import java.util.Collections;
import java.util.HashSet;
import java.util.List;
import java.util.Objects;
import java.util.Set;

import com.oracle.labs.mlrg.olcut.config.Config;
import com.oracle.labs.mlrg.olcut.config.Configurable;
import com.oracle.labs.mlrg.olcut.util.IOUtil;
030
031/**
 * This is a vanilla implementation of the Wordpiece algorithm as found here:
033 * 
034 * <a href=
035 * "https://github.com/huggingface/transformers/blob/master/src/transformers/models/bert/tokenization_bert.py">
036 * https://github.com/huggingface/transformers/blob/master/src/transformers/models/bert/tokenization_bert.py</a>
037 * 
038 * <p>
039 * Please refer to the class definition for <code>WordpieceTokenizer</code>. It
040 * does not include any of the tokenization work that is typically performed
041 * before wordpiece is called as is done in the above-referenced implementation.
042 * That functionality is provided by {@link WordpieceTokenizer} and
043 * {@link WordpieceBasicTokenizer}.
044 * 
045 */
046public class Wordpiece implements Configurable {
047
048    /**
049     * The default unknown token string.
050     */
051    public static final String DEFAULT_UNKNOWN_TOKEN = "[UNK]";
052
053    @Config(mandatory=true, description="path to a vocabulary data file.")
054    private String vocabPath;
055    @Config(mandatory=false, description="the value to use for 'UNKNOWN' tokens. Defaults to '[UNK]' which is a common default in BERT-based solutions.")
056    private String unknownToken = DEFAULT_UNKNOWN_TOKEN;
057    @Config(mandatory=false, description="the maximum number of characters per word to consider. This helps eliminate doing extra work on pathological cases.")
058    private int maxInputCharactersPerWord = 100;
059
060    private Set<String> vocab;
061
062    /**
063     * For OLCUT.
064     */
065    @SuppressWarnings("unused")
066    private Wordpiece() { }
067
068    /**
069     * Constructs a Wordpiece using the supplied vocab.
070     * <p>
071     * Sets the unknown token to {@link #DEFAULT_UNKNOWN_TOKEN}.
072     * @param vocab The wordpiece vocabulary.
073     */
074    public Wordpiece(Set<String> vocab) {
075        this(vocab, DEFAULT_UNKNOWN_TOKEN);
076    }
077
078    /**
079     * Constructs a Wordpiece using the supplied vocabulary and unknown token.
080     * @param vocab The wordpiece vocabulary.
081     * @param unknownToken The unknown token.
082     */
083    public Wordpiece(Set<String> vocab, String unknownToken) {
084        this(vocab, unknownToken, 100);
085    }
086
087    /**
088     * Initializes an instance of Wordpiece with the given vocabulary, unknown
089     * token, and max word length.
090     * 
091     * @param vocab                     the pre-trained wordpiece vocabulary. See
092     *                                  the contents of e.g.,
093     *                                  <a href="https://huggingface.co/bert-base-uncased/resolve/main/vocab.txt">
094     *                                  https://huggingface.co/bert-base-uncased/resolve/main/vocab.txt</a>
095     * @param unknownToken              a string used to indicate a token was not
096     *                                  found in the vocabulary - typically "[UNK]"
097     * @param maxInputCharactersPerWord a maximum to shield against looping over
098     *                                  character-by-character pathologically long
099     *                                  "tokens"
100     */
101    public Wordpiece(Set<String> vocab, String unknownToken, int maxInputCharactersPerWord) {
102        this.vocab = Collections.unmodifiableSet(vocab);
103        this.unknownToken = unknownToken;
104        this.maxInputCharactersPerWord = maxInputCharactersPerWord;
105    }
106
107    /**
108     * Constructs a wordpiece by reading the vocabulary from the supplied path.
109     * @param vocabPath The path to the wordpiece vocabulary.
110     */
111    public Wordpiece(String vocabPath) {
112        this.vocabPath = vocabPath;
113        try {
114            this.postConfig();
115        } catch (IOException e) {
116            throw new UncheckedIOException(e);
117        }
118    }
119
120    /**
121     * Initializes an instance of Wordpiece with the given vocabulary, unknown
122     * token, and max word length.
123     * 
124     * @param vocabPath                 Path to the pre-trained wordpiece vocabulary. See
125     *                                  the contents of e.g.
126     *                                  https://huggingface.co/bert-base-uncased/resolve/main/vocab.txt
127     * @param unknownToken              a string used to indicate a token was not
128     *                                  found in the vocabulary - typically "[UNK]"
129     * @param maxInputCharactersPerWord a maximum to shield against looping over
130     *                                  character-by-character pathologically long
131     *                                  "tokens"
132     */
133    public Wordpiece(String vocabPath, String unknownToken, int maxInputCharactersPerWord) {
134        this.vocabPath = vocabPath;
135        this.unknownToken = unknownToken;
136        this.maxInputCharactersPerWord = maxInputCharactersPerWord;
137        try {
138            this.postConfig();
139        } catch (IOException e) {
140            throw new UncheckedIOException(e);
141        }
142    }
143
144    /**
145     * Used by the OLCUT configuration system, and should not be called by external code.
146     */
147    @Override
148    public void postConfig() throws IOException {
149        this.vocab = Collections.unmodifiableSet(new HashSet<>(IOUtil.getLines(this.vocabPath)));
150    }
151
152    /**
153     * Executes Wordpiece tokenization on the provided token. Note that tokens
154     * corresponding to word suffixes as indicated in the provided vocabulary with
155     * the sequence "##" prepended to the entry may be returned by this method. This
156     * method does not perform whitespace tokenization or any other preprocessing.
157     * This method does not lowercase the input token or otherwise modify it in any
158     * way.
159     * 
160     * @param token the token to apply Wordpiece tokenization to.
161     * @return tokens corresponding to Wordpiece tokenization applied to the input
162     *         text. Some tokens may have a prefix "##" as described above. Some
163     *         tokens may correspond to an unknown token as specified during
164     *         initialization (default "[UNK]")
165     */
166    public List<String> wordpiece(String token) {
167        if (token.length() > this.maxInputCharactersPerWord) {
168            return Collections.singletonList(this.unknownToken);
169        }
170
171        List<String> subTokens = new ArrayList<>();
172
173        boolean isBad = false;
174        int start = 0;
175        while (start < token.length()) {
176            int end = token.length();
177            String currentSubstring = null;
178            while (start < end) {
179                String substring = token.substring(start, end);
180                if (start > 0) {
181                    substring = "##" + substring;
182                }
183                if (this.vocab.contains(substring)) {
184                    currentSubstring = substring;
185                    break;
186                }
187                end--;
188            }
189            if (currentSubstring == null) {
190                isBad = true;
191                break;
192            }
193            subTokens.add(currentSubstring);
194            start = end;
195        }
196        if (isBad) {
197            return Collections.singletonList(this.unknownToken);
198        } else {
199            return subTokens;
200        }
201    }
202
203    /**
204     * a getter for the "unknown" token specified during initialization.
205     * 
206     * @return the "unknown" token name - defaults to "[UNK]"
207     */
208    public String getUnknownToken() {
209        return unknownToken;
210    }
211
212    /**
213     * a getter for the maximum character count for a token to consider when
214     * {@link #wordpiece(String)} is applied to a token. This value is set at
215     * initialization and defaults to 100. Token values passed to that method that
216     * are not tokenized and the result of {@link #getUnknownToken()} is returned
217     * instead.
218     * 
219     * @return the maximum length of a token that will be analyzed by
220     *         {@link #wordpiece(String)}.
221     */
222    public int getMaxInputCharactersPerWord() {
223        return maxInputCharactersPerWord;
224    }
225
226}