001/*
002 * Copyright (c) 2015-2021, Oracle and/or its affiliates. All rights reserved.
003 *
004 * Licensed under the Apache License, Version 2.0 (the "License");
005 * you may not use this file except in compliance with the License.
006 * You may obtain a copy of the License at
007 *
008 *     http://www.apache.org/licenses/LICENSE-2.0
009 *
010 * Unless required by applicable law or agreed to in writing, software
011 * distributed under the License is distributed on an "AS IS" BASIS,
012 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
013 * See the License for the specific language governing permissions and
014 * limitations under the License.
015 */
016package org.tribuo.util.tokens.impl.wordpiece;
017
018import org.tribuo.util.tokens.impl.SplitFunctionTokenizer;
019
020import com.oracle.labs.mlrg.olcut.config.Config;
021import com.oracle.labs.mlrg.olcut.provenance.ConfiguredObjectProvenance;
022import com.oracle.labs.mlrg.olcut.provenance.impl.ConfiguredObjectProvenanceImpl;
023
024/**
025 * This is a tokenizer that is used "upstream" of {@link WordpieceTokenizer} and
026 * implements much of the functionality of the '<a href=
027 * "https://github.com/huggingface/transformers/blob/master/src/transformers/models/bert/tokenization_bert.py#L355">BasicTokenizer</a>'
028 * implementation in huggingface. One minor difference in this implementation is
029 * that there is no set of "never_split" tokens used here. Those are handled by
030 * {@link WordpieceTokenizer}.
031 */
032public class WordpieceBasicTokenizer extends SplitFunctionTokenizer {
033
034    /**
035     * Creates a {@link SplitFunction} that is used by the super class
036     * {@link SplitFunctionTokenizer} to determine how and where the tokenizer
037     * splits the input.
038     * 
039     * @param tokenizeChineseChars split Chinese characters into separate tokens?
040     * @return The splitting function.
041     */
042    public static SplitFunction createSplitFunction(boolean tokenizeChineseChars) {
043
044        return (codepoint, index, cs) -> {
045            if (Character.isWhitespace(codepoint)) {
046                return SplitResult.SPLIT_AT;
047            }
048            if (codepoint == 160) { // \u00a0 (NO-BREAK SPACE)
049                return SplitResult.SPLIT_AT;
050            }
051            if (isPunctuation(codepoint)) {
052                return SplitResult.SPLIT_BEFORE_AND_AFTER_PUNCTUATION;
053            }
054            if (tokenizeChineseChars && isChinese(codepoint)) {
055                return SplitResult.SPLIT_BEFORE_AND_AFTER_WORD;
056            }
057            if (codepoint == 0 || codepoint == 0xFFFD || isControl(codepoint)) {
058                return SplitResult.SPLIT_AT;
059            }
060
061            return SplitResult.NO_SPLIT_WORD;
062        };
063
064    }
065
066    /**
067     * Determines if the input code point should be considered a character that is punctuation.
068     * This will return true for all ascii characters that are not letters or digits and for any
069     * character whose Character type is defined as punctuation.  See {@link Character#getType(int)}.
070     * @param codepoint The codepoint to check.
071     * @return True if the codepoint is punctuation, false otherwise.
072     */
073    public static boolean isPunctuation(int codepoint) {
074        if (codepoint >= 33 && codepoint <= 47) {
075            return true;
076        }
077        if (codepoint >= 58 && codepoint <= 64) {
078            return true;
079        }
080        if (codepoint >= 91 && codepoint <= 96) {
081            return true;
082        }
083        if (codepoint >= 123 && codepoint <= 126) {
084            return true;
085        }
086
087        int charType = Character.getType(codepoint);
088        if (charType == Character.DASH_PUNCTUATION || charType == Character.START_PUNCTUATION
089                || charType == Character.END_PUNCTUATION || charType == Character.CONNECTOR_PUNCTUATION
090                || charType == Character.OTHER_PUNCTUATION || charType == Character.INITIAL_QUOTE_PUNCTUATION
091                || charType == Character.FINAL_QUOTE_PUNCTUATION) {
092            return true;
093        }
094
095        return false;
096    }
097
098    /**
099     * Determines if the provided codepoint is a Chinese character or not.
100     * @param codepoint a codepoint
101     * @return True if the codepoint is a Chinese character, false otherwise.
102     */
103    public static boolean isChinese(int codepoint) {
104        if ((codepoint >= 0x4E00 && codepoint <= 0x9FFF) || (codepoint >= 0x3400 && codepoint <= 0x4DBF)
105                || (codepoint >= 0x20000 && codepoint <= 0x2A6DF) || (codepoint >= 0x2A700 && codepoint <= 0x2B73F)
106                || (codepoint >= 0x2B740 && codepoint <= 0x2B81F) || (codepoint >= 0x2B820 && codepoint <= 0x2CEAF)
107                || (codepoint >= 0xF900 && codepoint <= 0xFAFF) || (codepoint >= 0x2F800 && codepoint <= 0x2FA1F)) {
108            return true;
109        }
110        return false;
111    }
112
113    /**
114     * Determines if the provided codepoint is a control character or not.
115     * @param codepoint The codepoint to check.
116     * @return True if it's a control character, false otherwise.
117     */
118    public static boolean isControl(int codepoint) {
119        char c = Character.toChars(codepoint)[0];
120        if (c == '\t' || c == '\n' || c == '\r') {
121            return false;
122        }
123        int charType = Character.getType(codepoint);
124        if (charType == Character.CONTROL || charType == Character.FORMAT || charType == Character.PRIVATE_USE
125                || charType == Character.SURROGATE) {
126            return true;
127        }
128        return false;
129    }
130
131    @Config(description = "split on Chinese tokens?")
132    private boolean tokenizeChineseChars = true;
133
134    /**
135     * Constructs a default tokenizer which tokenizes Chinese characters.
136     */
137    public WordpieceBasicTokenizer() {
138        this.postConfig();
139    }
140
141    /**
142     * Constructs a tokenizer.
143     * @param tokenizeChineseChars Should the Chinese characters be split into individual tokens.
144     */
145    public WordpieceBasicTokenizer(boolean tokenizeChineseChars) {
146        this.tokenizeChineseChars = tokenizeChineseChars;
147        this.postConfig();
148    }
149
150    /**
151     * Used by the OLCUT configuration system, and should not be called by external code.
152     */
153    @Override
154    public void postConfig() {
155        this.splitFunction = createSplitFunction(this.tokenizeChineseChars);
156    }
157
158    @Override
159    public ConfiguredObjectProvenance getProvenance() {
160        return new ConfiguredObjectProvenanceImpl(this, "Tokenizer");
161    }
162
163    @Override
164    public WordpieceBasicTokenizer clone() {
165        return new WordpieceBasicTokenizer(this.tokenizeChineseChars);
166    }
167}