/*
 * Decompiled with CFR 0.152.
 */
package com.johnsnowlabs.ml.ai;

import com.johnsnowlabs.ml.ai.OpenAICompletion;
import com.johnsnowlabs.ml.ai.model.EmbeddingData;
import com.johnsnowlabs.ml.ai.model.TextEmbeddingResponse;
import com.johnsnowlabs.nlp.Annotation;
import com.johnsnowlabs.nlp.AnnotatorModel;
import com.johnsnowlabs.nlp.AnnotatorType$;
import com.johnsnowlabs.nlp.HasSimpleAnnotate;
import com.johnsnowlabs.util.ConfigHelper$;
import com.johnsnowlabs.util.ConfigLoader$;
import com.johnsnowlabs.util.JsonBuilder$;
import com.johnsnowlabs.util.JsonParser$;
import java.io.Serializable;
import org.apache.http.HttpEntity;
import org.apache.http.client.methods.CloseableHttpResponse;
import org.apache.http.client.methods.HttpPost;
import org.apache.http.client.methods.HttpUriRequest;
import org.apache.http.entity.ContentType;
import org.apache.http.entity.StringEntity;
import org.apache.http.impl.client.CloseableHttpClient;
import org.apache.http.impl.client.HttpClients;
import org.apache.http.util.EntityUtils;
import org.apache.spark.broadcast.Broadcast;
import org.apache.spark.ml.param.Param;
import org.apache.spark.ml.util.Identifiable;
import org.apache.spark.ml.util.Identifiable$;
import org.apache.spark.sql.Dataset;
import org.apache.spark.sql.SparkSession;
import org.apache.spark.sql.expressions.UserDefinedFunction;
import scala.Function0;
import scala.Function1;
import scala.MatchError;
import scala.None$;
import scala.Option;
import scala.Predef$;
import scala.Some;
import scala.Tuple2;
import scala.collection.Map;
import scala.collection.Seq;
import scala.collection.Seq$;
import scala.collection.immutable.List;
import scala.collection.immutable.List$;
import scala.collection.immutable.Nil$;
import scala.collection.immutable.StringOps;
import scala.reflect.ClassTag$;
import scala.reflect.ManifestFactory$;
import scala.reflect.ScalaSignature;

@ScalaSignature(bytes="\u0006\u0001\u0005}d\u0001B\f\u0019\u0001\u0005B\u0001b\f\u0001\u0003\u0006\u0004%\t\u0005\r\u0005\t}\u0001\u0011\t\u0011)A\u0005c!)q\b\u0001C\u0001\u0001\")q\b\u0001C\u0001\u0007\"9A\t\u0001b\u0001\n\u0003*\u0005B\u0002)\u0001A\u0003%a\tC\u0004R\u0001\t\u0007I\u0011\t*\t\rM\u0003\u0001\u0015!\u0003K\u0011\u001d!\u0006A1A\u0005\u0002UCaa\u0019\u0001!\u0002\u00131\u0006\"\u00023\u0001\t\u0003)\u0007b\u00025\u0001\u0005\u0004%\t!\u0016\u0005\u0007S\u0002\u0001\u000b\u0011\u0002,\t\u000b)\u0004A\u0011A6\t\u000f5\u0004\u0001\u0019!C\u0005]\"9\u0001\u0010\u0001a\u0001\n\u0013I\bBB@\u0001A\u0003&q\u000eC\u0004\u0002\u0002\u0001!\t!a\u0001\t\r\u0005e\u0001\u0001\"\u00011\u0011\u001d\tY\u0002\u0001C!\u0003;Aq!a\u0013\u0001\t\u0003\ni\u0005C\u0004\u0002l\u0001!I!!\u001c\u0003!=\u0003XM\\!J\u000b6\u0014W\r\u001a3j]\u001e\u001c(BA\r\u001b\u0003\t\t\u0017N\u0003\u0002\u001c9\u0005\u0011Q\u000e\u001c\u0006\u0003;y\tAB[8i]Ntwn\u001e7bENT\u0011aH\u0001\u0004G>l7\u0001A\n\u0004\u0001\tb\u0003cA\u0012'Q5\tAE\u0003\u0002&9\u0005\u0019a\u000e\u001c9\n\u0005\u001d\"#AD!o]>$\u0018\r^8s\u001b>$W\r\u001c\t\u0003S)j\u0011\u0001G\u0005\u0003Wa\u0011\u0001c\u00149f]\u0006K5i\\7qY\u0016$\u0018n\u001c8\u0011\u0007\rj\u0003&\u0003\u0002/I\t\t\u0002*Y:TS6\u0004H.Z!o]>$\u0018\r^3\u0002\u0007ULG-F\u00012!\t\u00114H\u0004\u00024sA\u0011AgN\u0007\u0002k)\u0011a\u0007I\u0001\u0007yI|w\u000e\u001e 
\u000b\u0003a\nQa]2bY\u0006L!AO\u001c\u0002\rA\u0013X\rZ3g\u0013\taTH\u0001\u0004TiJLgn\u001a\u0006\u0003u]\nA!^5eA\u00051A(\u001b8jiz\"\"!\u0011\"\u0011\u0005%\u0002\u0001\"B\u0018\u0004\u0001\u0004\tD#A!\u0002'%t\u0007/\u001e;B]:|G/\u0019;peRK\b/Z:\u0016\u0003\u0019\u00032a\u0012%K\u001b\u00059\u0014BA%8\u0005\u0015\t%O]1z!\tYE*D\u0001\u0001\u0013\tieJA\u0007B]:|G/\u0019;peRK\b/Z\u0005\u0003\u001f\u0012\u0012a\u0003S1t\u001fV$\b/\u001e;B]:|G/\u0019;peRK\b/Z\u0001\u0015S:\u0004X\u000f^!o]>$\u0018\r^8s)f\u0004Xm\u001d\u0011\u0002'=,H\u000f];u\u0003:tw\u000e^1u_J$\u0016\u0010]3\u0016\u0003)\u000bAc\\;uaV$\u0018I\u001c8pi\u0006$xN\u001d+za\u0016\u0004\u0013!B7pI\u0016dW#\u0001,\u0011\u0007]\u000b\u0017'D\u0001Y\u0015\tI&,A\u0003qCJ\fWN\u0003\u0002\u001c7*\u0011A,X\u0001\u0006gB\f'o\u001b\u0006\u0003=~\u000ba!\u00199bG\",'\"\u00011\u0002\u0007=\u0014x-\u0003\u0002c1\n)\u0001+\u0019:b[\u00061Qn\u001c3fY\u0002\n\u0001b]3u\u001b>$W\r\u001c\u000b\u0003\u0017\u001aDQaZ\u0006A\u0002E\nQA^1mk\u0016\fA!^:fe\u0006)Qo]3sA\u000591/\u001a;Vg\u0016\u0014HCA&m\u0011\u00159g\u00021\u00012\u0003-\u0011W-\u0019:feR{7.\u001a8\u0016\u0003=\u00042a\u00129s\u0013\t\txG\u0001\u0004PaRLwN\u001c\t\u0004gZ\fT\"\u0001;\u000b\u0005U\\\u0016!\u00032s_\u0006$7-Y:u\u0013\t9HOA\u0005Ce>\fGmY1ti\u0006y!-Z1sKJ$vn[3o?\u0012*\u0017\u000f\u0006\u0002{{B\u0011qi_\u0005\u0003y^\u0012A!\u00168ji\"9a\u0010EA\u0001\u0002\u0004y\u0017a\u0001=%c\u0005a!-Z1sKJ$vn[3oA\u000512/\u001a;CK\u0006\u0014XM\u001d+pW\u0016t\u0017J\u001a(piN+G\u000fF\u0003L\u0003\u000b\t\u0019\u0002\u0003\u0004]%\u0001\u0007\u0011q\u0001\t\u0005\u0003\u0013\ty!\u0004\u0002\u0002\f)\u0019\u0011QB.\u0002\u0007M\fH.\u0003\u0003\u0002\u0012\u0005-!\u0001D*qCJ\\7+Z:tS>t\u0007bBA\u000b%\u0001\u0007\u0011qC\u0001\n_B,g.Q%LKf\u00042a\u001292\u000399W\r\u001e\"fCJ,'\u000fV8lK:\faBY3g_J,\u0017I\u001c8pi\u0006$X\r\u0006\u0003\u0002 
\u0005u\u0002\u0007BA\u0011\u0003W\u0001b!!\u0003\u0002$\u0005\u001d\u0012\u0002BA\u0013\u0003\u0017\u0011q\u0001R1uCN,G\u000f\u0005\u0003\u0002*\u0005-B\u0002\u0001\u0003\f\u0003[!\u0012\u0011!A\u0001\u0006\u0003\tyCA\u0002`II\nB!!\r\u00028A\u0019q)a\r\n\u0007\u0005UrGA\u0004O_RD\u0017N\\4\u0011\u0007\u001d\u000bI$C\u0002\u0002<]\u00121!\u00118z\u0011\u001d\ty\u0004\u0006a\u0001\u0003\u0003\nq\u0001Z1uCN,G\u000f\r\u0003\u0002D\u0005\u001d\u0003CBA\u0005\u0003G\t)\u0005\u0005\u0003\u0002*\u0005\u001dC\u0001DA%\u0003{\t\t\u0011!A\u0003\u0002\u0005=\"aA0%c\u0005A\u0011M\u001c8pi\u0006$X\r\u0006\u0003\u0002P\u0005\u001d\u0004CBA)\u00037\n\tG\u0004\u0003\u0002T\u0005]cb\u0001\u001b\u0002V%\t\u0001(C\u0002\u0002Z]\nq\u0001]1dW\u0006<W-\u0003\u0003\u0002^\u0005}#aA*fc*\u0019\u0011\u0011L\u001c\u0011\u0007\r\n\u0019'C\u0002\u0002f\u0011\u0012!\"\u00118o_R\fG/[8o\u0011\u001d\tI'\u0006a\u0001\u0003\u001f\n1\"\u00198o_R\fG/[8og\u0006!\u0001o\\:u)\u0019\ty'a\u001e\u0002|A!q\tSA9!\r9\u00151O\u0005\u0004\u0003k:$!\u0002$m_\u0006$\bBBA=-\u0001\u0007\u0011'A\u0002ve2Da!! \u0017\u0001\u0004\t\u0014\u0001\u00036t_:\u0014u\u000eZ=")
/**
 * Annotator that attaches an embedding vector to each input DOCUMENT annotation by calling the
 * OpenAI embeddings REST endpoint ({@code https://api.openai.com/v1/embeddings}).
 *
 * <p>The API key is read from the Spark NLP configuration ({@code ConfigHelper.openAIAPIKey}) in
 * {@link #beforeAnnotate} and broadcast once via the Spark context; each input annotation then
 * results in one HTTP request.
 */
public class OpenAIEmbeddings
extends AnnotatorModel<OpenAICompletion>
implements HasSimpleAnnotate<OpenAICompletion> {
    private final String uid;
    private final String[] inputAnnotatorTypes;
    private final String outputAnnotatorType;
    private final Param<String> model;
    private final Param<String> user;
    /** Broadcast API key; {@code None} until a non-empty key has been supplied. */
    private Option<Broadcast<String>> bearerToken;

    @Override
    public UserDefinedFunction dfAnnotate() {
        return HasSimpleAnnotate.dfAnnotate$(this);
    }

    public String uid() {
        return this.uid;
    }

    @Override
    public String[] inputAnnotatorTypes() {
        return this.inputAnnotatorTypes;
    }

    @Override
    public String outputAnnotatorType() {
        return this.outputAnnotatorType;
    }

    /** ID of the OpenAI model to use; written into the {@code "model"} field of each request. */
    public Param<String> model() {
        return this.model;
    }

    public OpenAIEmbeddings setModel(String value) {
        return (OpenAIEmbeddings)this.set(this.model(), value);
    }

    /** Optional end-user identifier forwarded to OpenAI as the {@code "user"} field. */
    public Param<String> user() {
        return this.user;
    }

    public OpenAIEmbeddings setUser(String value) {
        return (OpenAIEmbeddings)this.set(this.user(), value);
    }

    private Option<Broadcast<String>> bearerToken() {
        return this.bearerToken;
    }

    private void bearerToken_$eq(Option<Broadcast<String>> x$1) {
        this.bearerToken = x$1;
    }

    /**
     * Broadcasts {@code openAIKey} as the bearer token unless a token is already set.
     *
     * <p>Null or empty keys are ignored so that a missing configuration value does not lock in an
     * empty token and block a real key supplied later.
     *
     * @param spark session whose SparkContext performs the broadcast
     * @param openAIKey candidate API key
     * @return this annotator, for chaining
     */
    public OpenAIEmbeddings setBearerTokenIfNotSet(SparkSession spark, Option<String> openAIKey) {
        String key = openAIKey.isDefined() ? openAIKey.get() : null;
        if (this.bearerToken().isEmpty() && key != null && !key.isEmpty()) {
            this.bearerToken_$eq((Option<Broadcast<String>>)new Some((Object)spark.sparkContext().broadcast(key, ClassTag$.MODULE$.apply(String.class))));
        }
        return this;
    }

    /** Returns the broadcast API key, or {@code ""} when none has been set. */
    public String getBearerToken() {
        return this.bearerToken().isDefined() ? (String)((Broadcast)this.bearerToken().get()).value() : "";
    }

    /**
     * Loads the OpenAI API key from the Spark NLP configuration and broadcasts it (if not already
     * set) before annotation starts. The dataset is returned unchanged.
     */
    @Override
    public Dataset<?> beforeAnnotate(Dataset<?> dataset) {
        this.setBearerTokenIfNotSet(dataset.sparkSession(), (Option<String>)new Some((Object)ConfigLoader$.MODULE$.getConfigStringValue(ConfigHelper$.MODULE$.openAIAPIKey())));
        return dataset;
    }

    /**
     * Requests one embedding per input annotation from the OpenAI embeddings endpoint.
     *
     * <p>Each annotation's {@code result} text is sent in its own request. The returned annotation
     * carries the input text as its result, begin {@code 0}, end {@code input.length()}, empty
     * metadata, and the vector returned by {@link #post}.
     *
     * @param annotations input annotations; only their {@code result} text is used
     * @return one annotation per input, in input order
     */
    @Override
    public Seq<Annotation> annotate(Seq<Annotation> annotations) {
        String openAIUrlEmbeddings = "https://api.openai.com/v1/embeddings";
        // NOTE(review): userJson comes from JsonBuilder.formatOptionalField; whether it escapes the
        // user value is not visible here — confirm.
        String userJson = JsonBuilder$.MODULE$.formatOptionalField("user", (Option<Object>)this.get(this.user()));
        String jsonTemplate = new StringOps(Predef$.MODULE$.augmentString("\n        |{\n        |    \"model\": \"%s\",\n        |    \"input\": \"%s\"\n        |    %s\n        |}\n        |")).stripMargin();
        Seq annotationsEmbeddings = (Seq)annotations.map((Function1<Annotation, Annotation> & Serializable & scala.Serializable)annotation -> {
            String input = annotation.result();
            // Values are spliced into a quoted JSON string, so they must be escaped; otherwise
            // quotes, backslashes or newlines in the document produce an invalid request body.
            String modelId = OpenAIEmbeddings.escapeJson((String)this.$(this.model()));
            String json = new StringOps(Predef$.MODULE$.augmentString(jsonTemplate)).format((Seq)Predef$.MODULE$.genericWrapArray((Object)new Object[]{modelId, OpenAIEmbeddings.escapeJson(input), userJson}));
            float[] response = this.post(openAIUrlEmbeddings, json);
            return new Annotation(this.outputAnnotatorType(), 0, input.length(), input, (Map<String, String>)((Map)Predef$.MODULE$.Map().apply((Seq)Nil$.MODULE$)), response);
        }, Seq$.MODULE$.canBuildFrom());
        return annotationsEmbeddings;
    }

    /**
     * Escapes {@code raw} for use inside a double-quoted JSON string literal (RFC 8259): quotes,
     * backslashes and all control characters below U+0020 are escaped.
     */
    private static String escapeJson(String raw) {
        StringBuilder escaped = new StringBuilder(raw.length() + 16);
        for (int i = 0; i < raw.length(); i++) {
            char c = raw.charAt(i);
            switch (c) {
                case '"':
                    escaped.append("\\\"");
                    break;
                case '\\':
                    escaped.append("\\\\");
                    break;
                case '\n':
                    escaped.append("\\n");
                    break;
                case '\r':
                    escaped.append("\\r");
                    break;
                case '\t':
                    escaped.append("\\t");
                    break;
                case '\b':
                    escaped.append("\\b");
                    break;
                case '\f':
                    escaped.append("\\f");
                    break;
                default:
                    if (c < 0x20) {
                        escaped.append(String.format("\\u%04x", (int) c));
                    } else {
                        escaped.append(c);
                    }
            }
        }
        return escaped.toString();
    }

    /**
     * POSTs {@code jsonBody} to {@code url} with the broadcast bearer token and returns the first
     * embedding in the parsed response.
     *
     * @param url endpoint to call
     * @param jsonBody request body, sent as {@code application/json}
     * @return the first embedding vector, or an empty array when the call or parsing failed and the
     *     response body does not mention {@code "error"}
     * @throws IllegalArgumentException via {@code Predef.require} when no API key is set
     * @throws IllegalStateException when the call or parsing failed and the response body contains
     *     {@code "error"}; the message is the raw response body and the cause is preserved
     */
    private float[] post(String url, String jsonBody) {
        HttpPost httpPost = new HttpPost(url);
        httpPost.setEntity((HttpEntity)new StringEntity(jsonBody, ContentType.APPLICATION_JSON));
        String bearerToken = this.getBearerToken();
        Predef$.MODULE$.require(new StringOps(Predef$.MODULE$.augmentString(bearerToken)).nonEmpty(), (Function0 & Serializable & scala.Serializable)() -> "OpenAI API Key required");
        httpPost.setHeader("Authorization", "Bearer " + bearerToken);
        List<Object> embeddings = List$.MODULE$.empty();
        String responseBody = "";
        // Close the response as well as the client so the underlying connection is released.
        try (CloseableHttpClient httpclient = HttpClients.createDefault();
             CloseableHttpResponse response = httpclient.execute((HttpUriRequest)httpPost)) {
            responseBody = EntityUtils.toString((HttpEntity)response.getEntity());
            TextEmbeddingResponse textEmbeddingResponse = (TextEmbeddingResponse)JsonParser$.MODULE$.parseObject(responseBody, ManifestFactory$.MODULE$.classType(TextEmbeddingResponse.class));
            embeddings = ((EmbeddingData)textEmbeddingResponse.data().head()).embedding();
        }
        catch (Exception ex) {
            if (responseBody.contains("error")) {
                throw new IllegalStateException(responseBody, ex);
            }
            // NOTE(review): other failures are only printed and an empty vector is returned
            // (original best-effort behavior); consider failing or logging through a logger.
            ex.printStackTrace();
        }
        return (float[])embeddings.toArray(ClassTag$.MODULE$.Float());
    }

    public OpenAIEmbeddings(String uid) {
        this.uid = uid;
        HasSimpleAnnotate.$init$(this);
        this.inputAnnotatorTypes = (String[])((Object[])new String[]{AnnotatorType$.MODULE$.DOCUMENT()});
        this.outputAnnotatorType = AnnotatorType$.MODULE$.DOCUMENT();
        this.model = new Param((Identifiable)this, "model", "ID of the OpenAI model to use");
        this.user = new Param((Identifiable)this, "user", "A unique identifier representing your end-user, which can help OpenAI to monitor and detect abuse.");
        this.bearerToken = None$.MODULE$;
    }

    public OpenAIEmbeddings() {
        this(Identifiable$.MODULE$.randomUID("OPENAI_EMBEDDINGS"));
    }
}

