/*
 * Decompiled with CFR 0.152.
 */
package com.linkedin.feathr.offline.generation.aggregations;

import com.linkedin.feathr.common.Params;
import com.linkedin.feathr.common.exception.ErrorLabel;
import com.linkedin.feathr.common.exception.FeathrConfigException;
import com.typesafe.config.Config;
import java.io.Serializable;
import org.apache.spark.sql.Row;
import org.apache.spark.sql.expressions.MutableAggregationBuffer;
import org.apache.spark.sql.expressions.UserDefinedAggregateFunction;
import org.apache.spark.sql.types.ArrayType;
import org.apache.spark.sql.types.DataType;
import org.apache.spark.sql.types.DoubleType$;
import org.apache.spark.sql.types.StructField;
import org.apache.spark.sql.types.StructField$;
import org.apache.spark.sql.types.StructType;
import org.apache.spark.sql.types.StructType$;
import scala.Function0;
import scala.Function1;
import scala.MatchError;
import scala.Option;
import scala.Tuple2;
import scala.collection.GenIterable;
import scala.collection.Seq;
import scala.collection.Seq$;
import scala.collection.TraversableLike;
import scala.collection.immutable.$colon$colon;
import scala.collection.immutable.List;
import scala.collection.immutable.Nil$;
import scala.reflect.ScalaSignature;
import scala.runtime.BoxesRunTime;
import scala.runtime.java8.JFunction0;

@ScalaSignature(bytes="\u0006\u0001\u00055a!B\b\u0011\u0001Qa\u0002\"B\u0019\u0001\t\u0003\u0019\u0004b\u0002\u001c\u0001\u0001\u0004%\ta\u000e\u0005\b}\u0001\u0001\r\u0011\"\u0001@\u0011\u0019)\u0005\u0001)Q\u0005q!)a\t\u0001C!\u000f\")!\u000b\u0001C!'\")!\f\u0001C!'\")1\f\u0001C!9\")\u0001\r\u0001C!C\")Q\r\u0001C!M\")A\u000e\u0001C![\")Q\u000f\u0001C!m\")1\u0010\u0001C!y\"9\u00111\u0001\u0001\u0005\n\u0005\u0015!AD'j]B{w\u000e\\5oOV#\u0015I\u0012\u0006\u0003#I\tA\"Y4he\u0016<\u0017\r^5p]NT!a\u0005\u000b\u0002\u0015\u001d,g.\u001a:bi&|gN\u0003\u0002\u0016-\u00059qN\u001a4mS:,'BA\f\u0019\u0003\u00191W-\u0019;ie*\u0011\u0011DG\u0001\tY&t7.\u001a3j]*\t1$A\u0002d_6\u001c2\u0001A\u000f,!\tq\u0012&D\u0001 \u0015\t\u0001\u0013%A\u0006fqB\u0014Xm]:j_:\u001c(B\u0001\u0012$\u0003\r\u0019\u0018\u000f\u001c\u0006\u0003I\u0015\nQa\u001d9be.T!AJ\u0014\u0002\r\u0005\u0004\u0018m\u00195f\u0015\u0005A\u0013aA8sO&\u0011!f\b\u0002\u001d+N,'\u000fR3gS:,G-Q4he\u0016<\u0017\r^3Gk:\u001cG/[8o!\tas&D\u0001.\u0015\tqc#\u0001\u0004d_6lwN\\\u0005\u0003a5\u0012a\u0001U1sC6\u001c\u0018A\u0002\u001fj]&$hh\u0001\u0001\u0015\u0003Q\u0002\"!\u000e\u0001\u000e\u0003A\tQ\"Z7cK\u0012$\u0017N\\4TSj,W#\u0001\u001d\u0011\u0005ebT\"\u0001\u001e\u000b\u0003m\nQa]2bY\u0006L!!\u0010\u001e\u0003\u0007%sG/A\tf[\n,G\rZ5oONK'0Z0%KF$\"\u0001Q\"\u0011\u0005e\n\u0015B\u0001\";\u0005\u0011)f.\u001b;\t\u000f\u0011\u001b\u0011\u0011!a\u0001q\u0005\u0019\u0001\u0010J\u0019\u0002\u001d\u0015l'-\u001a3eS:<7+\u001b>fA\u0005!\u0011N\\5u)\t\u0001\u0005\nC\u0003J\u000b\u0001\u0007!*\u0001\u0004qCJ\fWn\u001d\t\u0003\u0017Bk\u0011\u0001\u0014\u0006\u0003\u001b:\u000baaY8oM&<'BA(\u001b\u0003!!\u0018\u0010]3tC\u001a,\u0017BA)M\u0005\u0019\u0019uN\u001c4jO\u0006Y\u0011N\u001c9viN\u001b\u0007.Z7b+\u0005!\u0006CA+Y\u001b\u00051&BA,\"\u0003\u0015!\u0018\u0010]3t\u0013\tIfK\u0001\u0006TiJ,8\r\u001e+za\u0016\fABY;gM\u0016\u00148k\u00195f[\u0006\f\u0001\u0002Z1uCRK\b/Z\u000b\u0002;B\u0011QKX\u0005\u0003?Z\u0013\u0001\u0002R1uCRK\b/Z\u
0001\u000eI\u0016$XM]7j]&\u001cH/[2\u0016\u0003\t\u0004\"!O2\n\u0005\u0011T$a\u0002\"p_2,\u0017M\\\u0001\u000bS:LG/[1mSj,GC\u0001!h\u0011\u0015A'\u00021\u0001j\u0003\u0019\u0011WO\u001a4feB\u0011aD[\u0005\u0003W~\u0011\u0001$T;uC\ndW-Q4he\u0016<\u0017\r^5p]\n+hMZ3s\u0003\u0019)\b\u000fZ1uKR\u0019\u0001I\\8\t\u000b!\\\u0001\u0019A5\t\u000bA\\\u0001\u0019A9\u0002\u000b%t\u0007/\u001e;\u0011\u0005I\u001cX\"A\u0011\n\u0005Q\f#a\u0001*po\u0006)Q.\u001a:hKR\u0019\u0001i^=\t\u000bad\u0001\u0019A5\u0002\u000f\t,hMZ3sc!)!\u0010\u0004a\u0001c\u00069!-\u001e4gKJ\u0014\u0014\u0001C3wC2,\u0018\r^3\u0015\u0007u\f\t\u0001\u0005\u0002:}&\u0011qP\u000f\u0002\u0004\u0003:L\b\"\u00025\u000e\u0001\u0004\t\u0018!C2bY\u000e,H.\u0019;f)\u0015\u0001\u0015qAA\u0005\u0011\u0015Ag\u00021\u0001j\u0011\u0019\tYA\u0004a\u0001c\u0006\u0019!o\\<")
public class MinPoolingUDAF
extends UserDefinedAggregateFunction
implements Params {
    private int embeddingSize;
    private Option<Config> _params;

    @Override
    public Option<Config> _params() {
        return this._params;
    }

    @Override
    public void _params_$eq(Option<Config> x$1) {
        this._params = x$1;
    }

    public int embeddingSize() {
        return this.embeddingSize;
    }

    public void embeddingSize_$eq(int x$1) {
        this.embeddingSize = x$1;
    }

    @Override
    public void init(Config params) {
        Params.init$(this, params);
        this.embeddingSize_$eq(((Config)this._params().get()).getInt("embeddingSize"));
    }

    public StructType inputSchema() {
        return StructType$.MODULE$.apply((Seq)new .colon.colon((Object)new StructField("value", (DataType)new ArrayType((DataType)DoubleType$.MODULE$, false), StructField$.MODULE$.apply$default$3(), StructField$.MODULE$.apply$default$4()), (List)Nil$.MODULE$));
    }

    public StructType bufferSchema() {
        return StructType$.MODULE$.apply((Seq)new .colon.colon((Object)new StructField("agg", (DataType)new ArrayType((DataType)DoubleType$.MODULE$, false), StructField$.MODULE$.apply$default$3(), StructField$.MODULE$.apply$default$4()), (List)Nil$.MODULE$));
    }

    public DataType dataType() {
        return new ArrayType((DataType)DoubleType$.MODULE$, false);
    }

    public boolean deterministic() {
        return true;
    }

    public void initialize(MutableAggregationBuffer buffer) {
        buffer.update(0, (Object)Seq$.MODULE$.fill(this.embeddingSize(), (Function0)(JFunction0.mcD.sp & Serializable & scala.Serializable)() -> Double.MAX_VALUE));
    }

    public void update(MutableAggregationBuffer buffer, Row input) {
        this.calculate(buffer, input);
    }

    public void merge(MutableAggregationBuffer buffer1, Row buffer2) {
        this.calculate(buffer1, buffer2);
    }

    public Object evaluate(Row buffer) {
        return buffer.getAs(0);
    }

    private void calculate(MutableAggregationBuffer buffer, Row row) {
        block1: {
            Seq embedding = (Seq)row.getAs(0);
            Seq aggregate = (Seq)((TraversableLike)buffer.getAs(0)).map((Function1 & Serializable & scala.Serializable)x -> BoxesRunTime.boxToDouble((double)x.doubleValue()), Seq$.MODULE$.canBuildFrom());
            if (embedding == null) break block1;
            if (embedding.size() != this.embeddingSize()) {
                throw new FeathrConfigException(ErrorLabel.FEATHR_USER_ERROR, new StringBuilder(69).append("embedding vector size has a length of ").append(embedding.size()).append(", different from expected size ").append(this.embeddingSize()).toString());
            }
            Seq newAgg = (Seq)((TraversableLike)aggregate.zip((GenIterable)embedding, Seq$.MODULE$.canBuildFrom())).map((Function1 & Serializable & scala.Serializable)x0$1 -> BoxesRunTime.boxToDouble((double)MinPoolingUDAF.$anonfun$calculate$2(x0$1)), Seq$.MODULE$.canBuildFrom());
            buffer.update(0, (Object)newAgg);
        }
    }

    public static final /* synthetic */ double $anonfun$calculate$2(Tuple2 x0$1) {
        Tuple2 tuple2 = x0$1;
        if (tuple2 == null) {
            throw new MatchError((Object)tuple2);
        }
        double x = tuple2._1$mcD$sp();
        double y = tuple2._2$mcD$sp();
        double d = Math.min(x, y);
        return d;
    }

    public MinPoolingUDAF() {
        Params.$init$(this);
        this.embeddingSize = 0;
    }
}

