/*
 * Decompiled with CFR 0.152.
 */
package com.johnsnowlabs.nlp.embeddings;

import com.johnsnowlabs.nlp.Annotation;
import com.johnsnowlabs.nlp.AnnotatorModel;
import com.johnsnowlabs.nlp.AnnotatorType$;
import com.johnsnowlabs.nlp.HasProtectedParams;
import com.johnsnowlabs.nlp.HasSimpleAnnotate;
import com.johnsnowlabs.nlp.annotators.common.Sentence;
import com.johnsnowlabs.nlp.annotators.common.SentenceSplit$;
import com.johnsnowlabs.nlp.annotators.common.WordpieceEmbeddingsSentence;
import com.johnsnowlabs.nlp.annotators.common.WordpieceEmbeddingsSentence$;
import com.johnsnowlabs.nlp.embeddings.HasEmbeddingsProperties;
import com.johnsnowlabs.nlp.embeddings.SentenceEmbeddings$;
import com.johnsnowlabs.storage.Database;
import com.johnsnowlabs.storage.HasStorageRef;
import com.johnsnowlabs.storage.HasStorageRef$;
import com.johnsnowlabs.storage.RocksDBConnection;
import java.io.Serializable;
import org.apache.spark.ml.param.IntParam;
import org.apache.spark.ml.param.Param;
import org.apache.spark.ml.param.ParamPair;
import org.apache.spark.ml.util.Identifiable;
import org.apache.spark.ml.util.Identifiable$;
import org.apache.spark.ml.util.MLReader;
import org.apache.spark.sql.Column;
import org.apache.spark.sql.Dataset;
import org.apache.spark.sql.Row;
import org.apache.spark.sql.expressions.UserDefinedFunction;
import scala.Array$;
import scala.Function1;
import scala.MatchError;
import scala.Option;
import scala.Predef;
import scala.Predef$;
import scala.Some;
import scala.Tuple2;
import scala.collection.Map;
import scala.collection.Seq;
import scala.collection.Seq$;
import scala.collection.TraversableOnce;
import scala.collection.mutable.ArrayOps;
import scala.reflect.ClassTag$;
import scala.reflect.ScalaSignature;
import scala.runtime.BoxedUnit;
import scala.runtime.BoxesRunTime;
import scala.runtime.ScalaRunTime$;
import scala.runtime.java8.JFunction1;

@ScalaSignature(bytes="\u0006\u0001\u0005Mf\u0001B\f\u0019\u0001\u0005B\u0001\u0002\u000e\u0001\u0003\u0006\u0004%\t%\u000e\u0005\t\u0007\u0002\u0011\t\u0011)A\u0005m!)A\t\u0001C\u0001\u000b\"9q\t\u0001b\u0001\n\u0003B\u0005BB(\u0001A\u0003%\u0011\nC\u0004Q\u0001\t\u0007I\u0011I)\t\rY\u0003\u0001\u0015!\u0003S\u0011\u001d9\u0006A1A\u0005BaCa!\u0019\u0001!\u0002\u0013I\u0006\"\u00022\u0001\t\u0003\u001a\u0007b\u00023\u0001\u0005\u0004%\t!\u001a\u0005\u0007i\u0002\u0001\u000b\u0011\u00024\t\u000bU\u0004A\u0011\u0001<\t\u000b\u0011\u0003A\u0011A=\t\u000bi\u0004A\u0011B>\t\u000f\u0005\u001d\u0001\u0001\"\u0011\u0002\n!9\u0011q\u0005\u0001\u0005R\u0005%\u0002bBA/\u0001\u0011E\u0013qL\u0004\b\u0003\u007fB\u0002\u0012AAA\r\u00199\u0002\u0004#\u0001\u0002\u0004\"1A\t\u0006C\u0001\u0003;C\u0011\"a(\u0015\u0003\u0003%I!!)\u0003%M+g\u000e^3oG\u0016,UNY3eI&twm\u001d\u0006\u00033i\t!\"Z7cK\u0012$\u0017N\\4t\u0015\tYB$A\u0002oYBT!!\b\u0010\u0002\u0019)|\u0007N\\:o_^d\u0017MY:\u000b\u0003}\t1aY8n\u0007\u0001\u0019R\u0001\u0001\u0012)W9\u00022a\t\u0013'\u001b\u0005Q\u0012BA\u0013\u001b\u00059\teN\\8uCR|'/T8eK2\u0004\"a\n\u0001\u000e\u0003a\u00012aI\u0015'\u0013\tQ#DA\tICN\u001c\u0016.\u001c9mK\u0006sgn\u001c;bi\u0016\u0004\"a\n\u0017\n\u00055B\"a\u0006%bg\u0016k'-\u001a3eS:<7\u000f\u0015:pa\u0016\u0014H/[3t!\ty#'D\u00011\u0015\t\tD$A\u0004ti>\u0014\u0018mZ3\n\u0005M\u0002$!\u0004%bgN#xN]1hKJ+g-A\u0002vS\u0012,\u0012A\u000e\t\u0003o\u0001s!\u0001\u000f \u0011\u0005ebT\"\u0001\u001e\u000b\u0005m\u0002\u0013A\u0002\u001fs_>$hHC\u0001>\u0003\u0015\u00198-\u00197b\u0013\tyD(\u0001\u0004Qe\u0016$WMZ\u0005\u0003\u0003\n\u0013aa\u0015;sS:<'BA =\u0003\u0011)\u0018\u000e\u001a\u0011\u0002\rqJg.\u001b;?)\t1c\tC\u00035\u0007\u0001\u0007a'A\npkR\u0004X\u000f^!o]>$\u0018\r^8s)f\u0004X-F\u0001J!\tQ5*D\u0001\u0001\u0013\taUJA\u0007B]:|G/\u0019;peRK\b/Z\u0005\u0003\u001dj\u0011a\u0003S1t\u001fV$\b/\u001e;B]:|G/\u0019;peRK\b/Z\u0001\u0015_V$\b/\u001e;B]:|G/\u0019;peRK\b/\u001a\u0011\u0002'%t\u0007/\u001e;B]:|G/\u0019;peRK\b/Z:\u0016\u0003I\u00032a\u0015+J\u001b\u0005a\u0014BA+=\u0005\u0015\t%O]1z\u0003QIg\u000e];u\u0003:tw\u000e^1u_J$\u0016\u0010]3tA\u0005IA-[7f]NLwN\\\u000b\u00023B\u0019!J\u00170\n\u0005mc&A\u0004)s_R,7\r^3e!\u0006\u0014\u0018-\\\u0005\u0003;j\u0011!\u0003S1t!J|G/Z2uK\u0012\u0004\u0016M]1ngB\u00111kX\u0005\u0003Ar\u00121!\u00138u\u0003)!\u0017.\\3og&|g\u000eI\u0001\rO\u0016$H)[7f]NLwN\\\u000b\u0002=\u0006y\u0001o\\8mS:<7\u000b\u001e:bi\u0016<\u00170F\u0001g!\r9'ON\u0007\u0002Q*\u0011\u0011N[\u0001\u0006a\u0006\u0014\u0018-\u001c\u0006\u0003W2\f!!\u001c7\u000b\u00055t\u0017!B:qCJ\\'BA8q\u0003\u0019\t\u0007/Y2iK*\t\u0011/A\u0002pe\u001eL!a\u001d5\u0003\u000bA\u000b'/Y7\u0002!A|w\u000e\\5oON#(/\u0019;fOf\u0004\u0013AE:fiB{w\u000e\\5oON#(/\u0019;fOf$\"AS<\t\u000bal\u0001\u0019\u0001\u001c\u0002\u0011M$(/\u0019;fOf$\u0012AJ\u0001\u001cG\u0006d7-\u001e7bi\u0016\u001cVM\u001c;f]\u000e,W)\u001c2fI\u0012LgnZ:\u0015\u0007q\f\t\u0001E\u0002T)v\u0004\"a\u0015@\n\u0005}d$!\u0002$m_\u0006$\bbBA\u0002\u001f\u0001\u0007\u0011QA\u0001\u0007[\u0006$(/\u001b=\u0011\u0007M#F0\u0001\u0005b]:|G/\u0019;f)\u0011\tY!a\t\u0011\r\u00055\u0011qCA\u000f\u001d\u0011\ty!a\u0005\u000f\u0007e\n\t\"C\u0001>\u0013\r\t)\u0002P\u0001\ba\u0006\u001c7.Y4f\u0013\u0011\tI\"a\u0007\u0003\u0007M+\u0017OC\u0002\u0002\u0016q\u00022aIA\u0010\u0013\r\t\tC\u0007\u0002\u000b\u0003:tw\u000e^1uS>t\u0007bBA\u0013!\u0001\u0007\u00111B\u0001\fC:tw\u000e^1uS>t7/\u0001\bcK\u001a|'/Z!o]>$\u0018\r^3\u0015\t\u0005-\u0012q\n\u0019\u0005\u0003[\ti\u0004\u0005\u0004\u00020\u0005U\u0012\u0011H\u0007\u0003\u0003cQ1!a\rm\u0003\r\u0019\u0018\u000f\\\u0005\u0005\u0003o\t\tDA\u0004ECR\f7/\u001a;\u0011\t\u0005m\u0012Q\b\u0007\u0001\t-\ty$EA\u0001\u0002\u0003\u0015\t!!\u0011\u0003\u0007}##'\u0005\u0003\u0002D\u0005%\u0003cA*\u0002F%\u0019\u0011q\t\u001f\u0003\u000f9{G\u000f[5oOB\u00191+a\u0013\n\u0007\u00055CHA\u0002B]fDq!!\u0015\u0012\u0001\u0004\t\u0019&A\u0004eCR\f7/\u001a;1\t\u0005U\u0013\u0011\f\t\u0007\u0003_\t)$a\u0016\u0011\t\u0005m\u0012\u0011\f\u0003\r\u00037\ny%!A\u0001\u0002\u000b\u0005\u0011\u0011\t\u0002\u0004?\u0012\n\u0014!D1gi\u0016\u0014\u0018I\u001c8pi\u0006$X\r\u0006\u0003\u0002b\u0005u\u0004\u0003BA2\u0003orA!!\u001a\u0002v9!\u0011qMA:\u001d\u0011\tI'!\u001d\u000f\t\u0005-\u0014q\u000e\b\u0004s\u00055\u0014\"A9\n\u0005=\u0004\u0018BA7o\u0013\r\t\u0019\u0004\\\u0005\u0005\u0003+\t\t$\u0003\u0003\u0002z\u0005m$!\u0003#bi\u00064%/Y7f\u0015\u0011\t)\"!\r\t\u000f\u0005E#\u00031\u0001\u0002b\u0005\u00112+\u001a8uK:\u001cW-R7cK\u0012$\u0017N\\4t!\t9CcE\u0004\u0015\u0003\u000b\u000bY)a&\u0011\u0007M\u000b9)C\u0002\u0002\nr\u0012a!\u00118z%\u00164\u0007#BAG\u0003'3SBAAH\u0015\r\t\tJ[\u0001\u0005kRLG.\u0003\u0003\u0002\u0016\u0006=%!\u0006#fM\u0006,H\u000e\u001e)be\u0006l7OU3bI\u0006\u0014G.\u001a\t\u0004'\u0006e\u0015bAANy\ta1+\u001a:jC2L'0\u00192mKR\u0011\u0011\u0011Q\u0001\fe\u0016\fGMU3t_24X\r\u0006\u0002\u0002$B!\u0011QUAX\u001b\t\t9K\u0003\u0003\u0002*\u0006-\u0016\u0001\u00027b]\u001eT!!!,\u0002\t)\fg/Y\u0005\u0005\u0003c\u000b9K\u0001\u0004PE*,7\r\u001e")
public class SentenceEmbeddings
extends AnnotatorModel<SentenceEmbeddings>
implements HasSimpleAnnotate<SentenceEmbeddings>,
HasEmbeddingsProperties,
HasStorageRef {
    private final String uid;
    private final String outputAnnotatorType;
    private final String[] inputAnnotatorTypes;
    private final HasProtectedParams.ProtectedParam<Object> dimension;
    private final Param<String> poolingStrategy;
    private final Param<String> storageRef;

    public static MLReader<SentenceEmbeddings> read() {
        return SentenceEmbeddings$.MODULE$.read();
    }

    public static Object load(String string) {
        return SentenceEmbeddings$.MODULE$.load(string);
    }

    @Override
    public RocksDBConnection createDatabaseConnection(Database database) {
        return HasStorageRef.createDatabaseConnection$(this, database);
    }

    @Override
    public HasStorageRef setStorageRef(String value) {
        return HasStorageRef.setStorageRef$(this, value);
    }

    @Override
    public String getStorageRef() {
        return HasStorageRef.getStorageRef$(this);
    }

    @Override
    public void validateStorageRef(Dataset<?> dataset, String[] inputCols, String annotatorType) {
        HasStorageRef.validateStorageRef$(this, dataset, inputCols, annotatorType);
    }

    @Override
    public HasEmbeddingsProperties setDimension(int value) {
        return HasEmbeddingsProperties.setDimension$(this, value);
    }

    @Override
    public Column wrapEmbeddingsMetadata(Column col, int embeddingsDim, Option<String> embeddingsRef) {
        return HasEmbeddingsProperties.wrapEmbeddingsMetadata$(this, col, embeddingsDim, embeddingsRef);
    }

    @Override
    public Option<String> wrapEmbeddingsMetadata$default$3() {
        return HasEmbeddingsProperties.wrapEmbeddingsMetadata$default$3$(this);
    }

    @Override
    public Column wrapSentenceEmbeddingsMetadata(Column col, int embeddingsDim, Option<String> embeddingsRef) {
        return HasEmbeddingsProperties.wrapSentenceEmbeddingsMetadata$(this, col, embeddingsDim, embeddingsRef);
    }

    @Override
    public Option<String> wrapSentenceEmbeddingsMetadata$default$3() {
        return HasEmbeddingsProperties.wrapSentenceEmbeddingsMetadata$default$3$(this);
    }

    @Override
    public <T> HasProtectedParams.ProtectedParam<T> ProtectedParam(Param<T> baseParam) {
        return HasProtectedParams.ProtectedParam$(this, baseParam);
    }

    @Override
    public <T> HasProtectedParams set(HasProtectedParams.ProtectedParam<T> param, T value) {
        return HasProtectedParams.set$(this, param, value);
    }

    @Override
    public UserDefinedFunction dfAnnotate() {
        return HasSimpleAnnotate.dfAnnotate$(this);
    }

    @Override
    public Param<String> storageRef() {
        return this.storageRef;
    }

    @Override
    public void com$johnsnowlabs$storage$HasStorageRef$_setter_$storageRef_$eq(Param<String> x$1) {
        this.storageRef = x$1;
    }

    @Override
    public void com$johnsnowlabs$nlp$embeddings$HasEmbeddingsProperties$_setter_$dimension_$eq(HasProtectedParams.ProtectedParam<Object> x$1) {
    }

    public String uid() {
        return this.uid;
    }

    @Override
    public String outputAnnotatorType() {
        return this.outputAnnotatorType;
    }

    @Override
    public String[] inputAnnotatorTypes() {
        return this.inputAnnotatorTypes;
    }

    @Override
    public HasProtectedParams.ProtectedParam<Object> dimension() {
        return this.dimension;
    }

    @Override
    public int getDimension() {
        return BoxesRunTime.unboxToInt((Object)this.$(this.dimension()));
    }

    public Param<String> poolingStrategy() {
        return this.poolingStrategy;
    }

    public SentenceEmbeddings setPoolingStrategy(String strategy) {
        SentenceEmbeddings sentenceEmbeddings;
        String string = strategy.toLowerCase();
        if ("average".equals(string)) {
            sentenceEmbeddings = (SentenceEmbeddings)this.set(this.poolingStrategy(), "AVERAGE");
        } else if ("sum".equals(string)) {
            sentenceEmbeddings = (SentenceEmbeddings)this.set(this.poolingStrategy(), "SUM");
        } else {
            throw new MatchError((Object)"poolingStrategy must be either AVERAGE or SUM");
        }
        return sentenceEmbeddings;
    }

    /*
     * WARNING - void declaration
     */
    private float[] calculateSentenceEmbeddings(float[][] matrix) {
        void var2_2;
        float[] res = (float[])Array$.MODULE$.ofDim(matrix[0].length, ClassTag$.MODULE$.Float());
        this.setDimension(matrix[0].length);
        new ArrayOps.ofFloat(Predef$.MODULE$.floatArrayOps(matrix[0])).indices().foreach$mVc$sp((Function1)(JFunction1.mcVI.sp & Serializable & scala.Serializable)j -> {
            block0: {
                new ArrayOps.ofRef(Predef$.MODULE$.refArrayOps((Object[])matrix)).indices().foreach$mVc$sp((Function1)(JFunction1.mcVI.sp & Serializable & scala.Serializable)i -> {
                    res$1[j$1] = res[j] + matrix[i][j];
                });
                Object object = this.$(this.poolingStrategy());
                String string = "AVERAGE";
                if (object != null ? !object.equals(string) : string != null) break block0;
                res$1[j] = res[j] / (float)matrix.length;
            }
        });
        return var2_2;
    }

    @Override
    public Seq<Annotation> annotate(Seq<Annotation> annotations) {
        Seq<Sentence> sentences = SentenceSplit$.MODULE$.unpack(annotations);
        Seq<WordpieceEmbeddingsSentence> embeddingsSentences = WordpieceEmbeddingsSentence$.MODULE$.unpack(annotations);
        return (Seq)sentences.map((Function1 & Serializable & scala.Serializable)sentence -> {
            Seq embeddings2 = (Seq)embeddingsSentences.filter((Function1 & Serializable & scala.Serializable)embeddings -> BoxesRunTime.boxToBoolean((boolean)SentenceEmbeddings.$anonfun$annotate$2(sentence, embeddings)));
            float[] sentenceEmbeddings = (float[])((TraversableOnce)embeddings2.flatMap((Function1 & Serializable & scala.Serializable)tokenEmbedding -> new ArrayOps.ofFloat(SentenceEmbeddings.$anonfun$annotate$3(this, tokenEmbedding)), Seq$.MODULE$.canBuildFrom())).toArray(ClassTag$.MODULE$.Float());
            return new Annotation(this.outputAnnotatorType(), sentence.start(), sentence.end(), sentence.content(), (Map<String, String>)((Map)Predef$.MODULE$.Map().apply((Seq)Predef$.MODULE$.wrapRefArray((Object[])new Tuple2[]{Predef.ArrowAssoc$.MODULE$.$minus$greater$extension(Predef$.MODULE$.ArrowAssoc((Object)"sentence"), (Object)Integer.toString(sentence.index())), Predef.ArrowAssoc$.MODULE$.$minus$greater$extension(Predef$.MODULE$.ArrowAssoc((Object)"token"), (Object)sentence.content()), Predef.ArrowAssoc$.MODULE$.$minus$greater$extension(Predef$.MODULE$.ArrowAssoc((Object)"pieceId"), (Object)"-1"), Predef.ArrowAssoc$.MODULE$.$minus$greater$extension(Predef$.MODULE$.ArrowAssoc((Object)"isWordStart"), (Object)"true")}))), sentenceEmbeddings);
        }, Seq$.MODULE$.canBuildFrom());
    }

    @Override
    public Dataset<?> beforeAnnotate(Dataset<?> dataset) {
        String ref = HasStorageRef$.MODULE$.getStorageRefFromInput(dataset, (String[])this.$((Param)this.inputCols()), AnnotatorType$.MODULE$.WORD_EMBEDDINGS());
        Object object = this.get(this.storageRef()).isEmpty() ? this.setStorageRef(ref) : BoxedUnit.UNIT;
        return dataset;
    }

    @Override
    public Dataset<Row> afterAnnotate(Dataset<Row> dataset) {
        return dataset.withColumn(this.getOutputCol(), this.wrapSentenceEmbeddingsMetadata(dataset.col(this.getOutputCol()), BoxesRunTime.unboxToInt((Object)this.$(this.dimension())), (Option<String>)new Some(this.$(this.storageRef()))));
    }

    public static final /* synthetic */ boolean $anonfun$annotate$2(Sentence sentence$1, WordpieceEmbeddingsSentence embeddings) {
        return embeddings.sentenceId() == sentence$1.index();
    }

    public static final /* synthetic */ float[] $anonfun$annotate$3(SentenceEmbeddings $this, WordpieceEmbeddingsSentence tokenEmbedding) {
        float[][] allEmbeddings = (float[][])new ArrayOps.ofRef(Predef$.MODULE$.refArrayOps((Object[])tokenEmbedding.tokens())).map((Function1 & Serializable & scala.Serializable)token -> token.embeddings(), Array$.MODULE$.canBuildFrom(ClassTag$.MODULE$.apply(ScalaRunTime$.MODULE$.arrayClass(Float.TYPE))));
        return Predef$.MODULE$.floatArrayOps($this.calculateSentenceEmbeddings(allEmbeddings));
    }

    public SentenceEmbeddings(String uid) {
        this.uid = uid;
        HasSimpleAnnotate.$init$(this);
        HasProtectedParams.$init$(this);
        HasEmbeddingsProperties.$init$(this);
        HasStorageRef.$init$(this);
        this.outputAnnotatorType = AnnotatorType$.MODULE$.SENTENCE_EMBEDDINGS();
        this.inputAnnotatorTypes = (String[])((Object[])new String[]{AnnotatorType$.MODULE$.DOCUMENT(), AnnotatorType$.MODULE$.WORD_EMBEDDINGS()});
        this.dimension = this.ProtectedParam((Param)new IntParam((Identifiable)this, "dimension", "Number of embedding dimensions"));
        this.poolingStrategy = new Param((Identifiable)this, "poolingStrategy", "Choose how you would like to aggregate Word Embeddings to Sentence Embeddings: AVERAGE or SUM");
        this.setDefault((Seq)Predef$.MODULE$.wrapRefArray((Object[])new ParamPair[]{this.inputCols().$minus$greater((Object)new String[]{AnnotatorType$.MODULE$.DOCUMENT(), AnnotatorType$.MODULE$.WORD_EMBEDDINGS()}), this.outputCol().$minus$greater((Object)"sentence_embeddings"), this.poolingStrategy().$minus$greater((Object)"AVERAGE"), this.dimension().$minus$greater(BoxesRunTime.boxToInteger((int)100))}));
    }

    public SentenceEmbeddings() {
        this(Identifiable$.MODULE$.randomUID("SENTENCE_EMBEDDINGS"));
    }
}

