/*
 * Decompiled with CFR 0.152.
 */
package ai.djl.modality.nlp.embedding;

import ai.djl.modality.nlp.SimpleVocabulary;
import ai.djl.modality.nlp.embedding.WordEmbedding;
import ai.djl.ndarray.NDArray;
import ai.djl.ndarray.NDManager;
import ai.djl.nn.core.Embedding;
import java.nio.charset.StandardCharsets;
import java.util.List;
import java.util.Optional;

/**
 * A {@link WordEmbedding} backed by an {@link Embedding} keyed on {@code String} tokens.
 *
 * <p>Instances can be created from a {@link Builder}, from a {@link SimpleVocabulary}, or from an
 * existing embedding {@link NDArray} together with its list of items.
 */
public class TrainableWordEmbedding
extends Embedding<String>
implements WordEmbedding {
    /** Token used as the default (fall-through) item when none is supplied explicitly. */
    private static final String DEFAULT_UNKNOWN_TOKEN = "<unk>";

    /**
     * Constructs a word embedding from the configuration held by the given builder.
     *
     * @param builder the builder passed on to the {@link Embedding} constructor
     */
    public TrainableWordEmbedding(Builder builder) {
        super(builder);
    }

    /**
     * Constructs a word embedding over all tokens of the given vocabulary.
     *
     * <p>The builder is configured with sparse gradients disabled, the vocabulary's unknown token
     * as the default item, and {@code optUseDefault(false)}.
     *
     * @param simpleVocabulary the vocabulary supplying the tokens and the unknown token
     * @param embeddingSize the embedding size handed to the builder
     */
    public TrainableWordEmbedding(SimpleVocabulary simpleVocabulary, int embeddingSize) {
        super(
                TrainableWordEmbedding.builder()
                        .setEmbeddingSize(embeddingSize)
                        .setItems(simpleVocabulary.getAllTokens())
                        .optSparseGrad(false)
                        .optDefaultItem(simpleVocabulary.getUnknownToken())
                        .optUseDefault(false));
    }

    /**
     * Constructs a word embedding from an existing embedding array and its items.
     *
     * <p>The fall-through embedding is set to a default item for {@code "<unk>"}.
     *
     * @param embedding the embedding array passed to the superclass
     * @param items the items passed to the superclass
     */
    public TrainableWordEmbedding(NDArray embedding, List<String> items) {
        super(embedding, items);
        this.fallthroughEmbedding = new Embedding.DefaultItem(DEFAULT_UNKNOWN_TOKEN);
    }

    /**
     * Constructs a word embedding from an existing embedding array and its items.
     *
     * <p>The fall-through embedding is set to a default item for {@code "<unk>"}.
     *
     * @param embedding the embedding array passed to the superclass
     * @param items the items passed to the superclass
     * @param sparseGrad the sparse-gradient flag passed to the superclass
     */
    public TrainableWordEmbedding(NDArray embedding, List<String> items, boolean sparseGrad) {
        super(embedding, items, sparseGrad);
        this.fallthroughEmbedding = new Embedding.DefaultItem(DEFAULT_UNKNOWN_TOKEN);
    }

    /**
     * Reports whether the embedder has an entry for the given word.
     *
     * @param word the word to look up
     * @return {@code true} if {@code embedder} contains {@code word} as a key
     */
    @Override
    public boolean vocabularyContains(String word) {
        return this.embedder.containsKey(word);
    }

    /**
     * Converts a word to the integer produced by {@code embed(word)}.
     *
     * @param word the word to convert
     * @return the result of {@code embed(word)}
     */
    @Override
    public int preprocessWordToEmbed(String word) {
        return this.embed(word);
    }

    /**
     * Not supported by this class.
     *
     * @param manager unused
     * @param index unused
     * @return never returns normally
     * @throws UnsupportedOperationException always
     */
    @Override
    public NDArray embedWord(NDManager manager, int index) {
        throw new UnsupportedOperationException("This operation is not supported by this class.");
    }

    /**
     * Maps a scalar index array back to its word.
     *
     * <p>The index is first resolved through this embedding; only when that yields nothing is the
     * fall-through embedding consulted.
     *
     * @param word a scalar {@link NDArray} holding the word index
     * @return the word found for the index
     * @throws IllegalArgumentException if {@code word} is not scalar, or if neither this embedding
     *     nor the fall-through embedding can resolve the index
     */
    @Override
    public String unembedWord(NDArray word) {
        if (!word.isScalar()) {
            throw new IllegalArgumentException("NDArray word must be scalar index");
        }
        int wordIndex = word.toIntArray()[0];

        Optional<String> result = this.unembed(wordIndex);
        if (!result.isPresent()) {
            // Primary lookup found nothing; try the fall-through embedding before giving up.
            result = this.fallthroughEmbedding.unembed(wordIndex);
        }
        return result.orElseThrow(() -> new IllegalArgumentException("Failed to unembed word"));
    }

    /**
     * Encodes a word as its UTF-8 bytes.
     *
     * @param input the word to encode
     * @return the UTF-8 bytes of {@code input}
     */
    @Override
    public byte[] encode(String input) {
        return input.getBytes(StandardCharsets.UTF_8);
    }

    /**
     * Decodes UTF-8 bytes back into a word.
     *
     * @param byteArray the UTF-8 bytes to decode
     * @return the decoded string
     */
    @Override
    public String decode(byte[] byteArray) {
        return new String(byteArray, StandardCharsets.UTF_8);
    }

    /**
     * Creates a new {@link Builder}.
     *
     * @return a fresh builder
     */
    public static Builder builder() {
        return new Builder();
    }

    /** Builder for {@link TrainableWordEmbedding}. */
    public static class Builder
    extends Embedding.BaseBuilder<String, Builder> {
        /** Initializes the embedding type to {@code String} and the default item to {@code "<unk>"}. */
        Builder() {
            this.embeddingType = String.class;
            this.defaultItem = TrainableWordEmbedding.DEFAULT_UNKNOWN_TOKEN;
        }

        /**
         * Returns this builder without using the argument; the embedding type is already set to
         * {@code String} by the constructor.
         *
         * @param embeddingType ignored
         * @return this builder
         */
        @Override
        protected Builder setType(Class<String> embeddingType) {
            return this.self();
        }

        /**
         * Returns this builder.
         *
         * @return this builder
         */
        @Override
        protected Builder self() {
            return this;
        }

        /**
         * Sets the unknown token by delegating to {@code optDefaultItem}.
         *
         * @param unknownToken the token to use as the default item
         * @return this builder
         */
        public Builder optUnknownToken(String unknownToken) {
            return this.optDefaultItem(unknownToken);
        }

        /**
         * Builds a {@link TrainableWordEmbedding} from this builder.
         *
         * @return the new embedding
         */
        public TrainableWordEmbedding build() {
            return new TrainableWordEmbedding(this);
        }
    }
}

