/*
 * Decompiled with CFR 0.152.
 */
package com.johnsnowlabs.ml.crf;

import com.johnsnowlabs.ml.crf.CrfDataset;
import com.johnsnowlabs.ml.crf.CrfParams;
import com.johnsnowlabs.ml.crf.DatasetMetadata;
import com.johnsnowlabs.ml.crf.FbCalculator;
import com.johnsnowlabs.ml.crf.Instance;
import com.johnsnowlabs.ml.crf.InstanceLabels;
import com.johnsnowlabs.ml.crf.L2DecayStrategy;
import com.johnsnowlabs.ml.crf.LinearChainCrfModel;
import com.johnsnowlabs.ml.crf.VectorMath$;
import com.johnsnowlabs.nlp.annotators.ner.Verbose$;
import java.io.Serializable;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
import scala.Array$;
import scala.Enumeration;
import scala.Function0;
import scala.Function1;
import scala.MatchError;
import scala.Predef$;
import scala.Tuple2;
import scala.collection.Seq;
import scala.collection.Seq$;
import scala.collection.TraversableOnce;
import scala.collection.mutable.ArrayOps;
import scala.math.Numeric;
import scala.math.Ordering;
import scala.reflect.ClassTag$;
import scala.reflect.ScalaSignature;
import scala.runtime.BoxedUnit;
import scala.runtime.BoxesRunTime;
import scala.runtime.FloatRef;
import scala.runtime.IntRef;
import scala.runtime.RichInt$;
import scala.runtime.java8.JFunction1;
import scala.util.Random$;

@ScalaSignature(bytes="\u0006\u0001\u0005\u0005a\u0001\u0002\u0006\f\u0001QA\u0001b\u0007\u0001\u0003\u0006\u0004%\t\u0001\b\u0005\tC\u0001\u0011\t\u0011)A\u0005;!)!\u0005\u0001C\u0001G!9a\u0005\u0001b\u0001\n\u00139\u0003B\u0002\u0019\u0001A\u0003%\u0001\u0006C\u00032\u0001\u0011\u0005!\u0007C\u0003W\u0001\u0011\u0005q\u000bC\u0003a\u0001\u0011%\u0011\rC\u0003u\u0001\u0011\u0005QO\u0001\bMS:,\u0017M]\"iC&t7I\u001d4\u000b\u00051i\u0011aA2sM*\u0011abD\u0001\u0003[2T!\u0001E\t\u0002\u0019)|\u0007N\\:o_^d\u0017MY:\u000b\u0003I\t1aY8n\u0007\u0001\u0019\"\u0001A\u000b\u0011\u0005YIR\"A\f\u000b\u0003a\tQa]2bY\u0006L!AG\f\u0003\r\u0005s\u0017PU3g\u0003\u0019\u0001\u0018M]1ngV\tQ\u0004\u0005\u0002\u001f?5\t1\"\u0003\u0002!\u0017\tI1I\u001d4QCJ\fWn]\u0001\ba\u0006\u0014\u0018-\\:!\u0003\u0019a\u0014N\\5u}Q\u0011A%\n\t\u0003=\u0001AQaG\u0002A\u0002u\ta\u0001\\8hO\u0016\u0014X#\u0001\u0015\u0011\u0005%rS\"\u0001\u0016\u000b\u0005-b\u0013!B:mMRR'\"A\u0017\u0002\u0007=\u0014x-\u0003\u00020U\t1Aj\\4hKJ\fq\u0001\\8hO\u0016\u0014\b%A\u0002m_\u001e$2a\r\u001cG!\t1B'\u0003\u00026/\t!QK\\5u\u0011\u00199d\u0001\"a\u0001q\u0005)a/\u00197vKB\u0019a#O\u001e\n\u0005i:\"\u0001\u0003\u001fcs:\fW.\u001a 
\u0011\u0005q\u001aeBA\u001fB!\tqt#D\u0001@\u0015\t\u00015#\u0001\u0004=e>|GOP\u0005\u0003\u0005^\ta\u0001\u0015:fI\u00164\u0017B\u0001#F\u0005\u0019\u0019FO]5oO*\u0011!i\u0006\u0005\u0006\u000f\u001a\u0001\r\u0001S\u0001\t[&tG*\u001a<fYB\u0011\u0011j\u0015\b\u0003\u0015Fk\u0011a\u0013\u0006\u0003\u00196\u000b1A\\3s\u0015\tqu*\u0001\u0006b]:|G/\u0019;peNT!\u0001U\b\u0002\u00079d\u0007/\u0003\u0002S\u0017\u00069a+\u001a:c_N,\u0017B\u0001+V\u0005\u0015aUM^3m\u0015\t\u00116*\u0001\u0005ue\u0006LgnU$E)\tA6\f\u0005\u0002\u001f3&\u0011!l\u0003\u0002\u0014\u0019&tW-\u0019:DQ\u0006Lgn\u0011:g\u001b>$W\r\u001c\u0005\u00069\u001e\u0001\r!X\u0001\bI\u0006$\u0018m]3u!\tqb,\u0003\u0002`\u0017\tQ1I\u001d4ECR\f7/\u001a;\u0002\u000f\u001d,G\u000fT8tgR!!-\u001a6p!\t12-\u0003\u0002e/\t)a\t\\8bi\")a\r\u0003a\u0001O\u0006A1/\u001a8uK:\u001cW\r\u0005\u0002\u001fQ&\u0011\u0011n\u0003\u0002\t\u0013:\u001cH/\u00198dK\")1\u000e\u0003a\u0001Y\u00061A.\u00192fYN\u0004\"AH7\n\u00059\\!AD%ogR\fgnY3MC\n,Gn\u001d\u0005\u0006a\"\u0001\r!]\u0001\bG>tG/\u001a=u!\tq\"/\u0003\u0002t\u0017\taaIY\"bY\u000e,H.\u0019;pe\u0006IAm\\*hIN#X\r\u001d\u000b\u0007gY<\bP_@\t\u000b\u0019L\u0001\u0019A4\t\u000b-L\u0001\u0019\u00017\t\u000beL\u0001\u0019\u00012\u0002\u0003\u0005DQa_\u0005A\u0002q\fqa^3jO\"$8\u000fE\u0002\u0017{\nL!A`\f\u0003\u000b\u0005\u0013(/Y=\t\u000bAL\u0001\u0019A9")
public class LinearChainCrf {
    private final CrfParams params;
    private final Logger logger;

    /**
     * Returns the training parameters this trainer was constructed with.
     *
     * @return the {@link CrfParams} instance passed to the constructor
     */
    public CrfParams params() {
        final CrfParams configured = params;
        return configured;
    }

    /**
     * Returns the SLF4J logger used by {@code log}.
     *
     * @return the logger created in the constructor
     */
    private Logger logger() {
        final Logger sink = logger;
        return sink;
    }

    /**
     * Writes a message at INFO level when {@code minLevel} is greater than or
     * equal to the verbosity configured in {@code params().verbose()}.
     *
     * @param value    lazily evaluated message; it is only applied when the
     *                 level check passes
     * @param minLevel level attached to this message
     */
    public void log(Function0<String> value, Enumeration.Value minLevel) {
        final Object configuredLevel = this.params().verbose();
        if (minLevel.$greater$eq(configuredLevel)) {
            // The message thunk is evaluated only on this path.
            final String message = (String)value.apply();
            this.logger().info(message);
        }
    }

    /**
     * Trains a model on {@code dataset} by running up to
     * {@code params().maxEpochs()} passes over the shuffled instances and
     * returns the weights of the epoch with the lowest total loss.
     *
     * <p>Per epoch, every instance goes through: {@code decayStrategy.nextStep()},
     * {@code context.calculate(...)}, {@link #doSgdStep}, and loss accumulation.
     * The total loss of an epoch is the accumulated instance loss plus
     * {@code params().l2()} times the sum of squared weights.
     *
     * <p>Early stopping: an epoch counts as "stalled" when it does not beat the
     * best loss so far, or beats it by a relative margin smaller than
     * {@code params().lossEps()}. Training stops once 10 stalled epochs have
     * accumulated and at least {@code params().minEpochs()} epochs have run.
     * An epoch that improves on the best loss by at least the margin resets
     * the stalled counter.
     *
     * <p>The previous code compared the best loss with itself after it had
     * already been overwritten, so the relative improvement was always zero and
     * {@code lossEps} had no effect.
     *
     * @param dataset training instances together with their metadata; the
     *                instance sequence must be non-empty because the maximum
     *                sentence length is taken with {@code max}
     * @return a model built from the best weights seen and the dataset metadata;
     *         if no epoch ever improves on the initial best loss the weights are
     *         whatever {@code VectorMath.Vector} initialises them to
     */
    public LinearChainCrfModel trainSGD(CrfDataset dataset) {
        final DatasetMetadata metadata = dataset.metadata();
        final float[] weights = VectorMath$.MODULE$.Vector(metadata.attrFeatures().length + metadata.transitions().length, VectorMath$.MODULE$.Vector$default$2());
        final int labels = metadata.labels().length;
        if (this.params().randomSeed().isDefined()) {
            // Seeds the shared scala.util.Random instance used for shuffling below.
            Random$.MODULE$.setSeed((long)BoxesRunTime.unboxToInt((Object)this.params().randomSeed().get()));
        }
        final int maxLength = BoxesRunTime.unboxToInt((Object)((TraversableOnce)dataset.instances().map((Function1 & Serializable & scala.Serializable)pair -> BoxesRunTime.boxToInteger((int)LinearChainCrf.$anonfun$trainSGD$1(pair)), Seq$.MODULE$.canBuildFrom())).max((Ordering)Ordering.Int$.MODULE$));
        this.log((Function0<String>)(Function0 & Serializable & scala.Serializable)() -> new StringBuilder(8).append("labels: ").append(labels).toString(), Verbose$.MODULE$.TrainingStat());
        this.log((Function0<String>)(Function0 & Serializable & scala.Serializable)() -> new StringBuilder(11).append("instances: ").append(dataset.instances().size()).toString(), Verbose$.MODULE$.TrainingStat());
        this.log((Function0<String>)(Function0 & Serializable & scala.Serializable)() -> new StringBuilder(10).append("features: ").append(weights.length).toString(), Verbose$.MODULE$.TrainingStat());
        this.log((Function0<String>)(Function0 & Serializable & scala.Serializable)() -> new StringBuilder(11).append("maxLength: ").append(maxLength).toString(), Verbose$.MODULE$.TrainingStat());

        // Reusable per-sentence workspace, sized for the longest instance.
        final FbCalculator context = new FbCalculator(maxLength, metadata);
        final float[] bestW = VectorMath$.MODULE$.Vector(weights.length, VectorMath$.MODULE$.Vector$default$2());
        final L2DecayStrategy decayStrategy = new L2DecayStrategy(dataset.instances().size(), this.params().l2(), this.params().c0());

        final int maxEpochs = this.params().maxEpochs();
        final int maxStalledEpochs = 10;
        float bestLoss = Float.MAX_VALUE;
        int notImprovedEpochs = 0;

        for (int epoch = 0; epoch < maxEpochs; ++epoch) {
            // Once this condition holds it holds for every later epoch as well
            // (the counter only changes inside the loop body), so stopping here
            // is equivalent to skipping all remaining epochs.
            if (notImprovedEpochs >= maxStalledEpochs && epoch >= this.params().minEpochs()) {
                break;
            }

            final int currentEpoch = epoch;
            final FloatRef loss = FloatRef.create((float)0.0f);
            this.log((Function0<String>)(Function0 & Serializable & scala.Serializable)() -> new StringBuilder(15).append("\nEpoch: ").append(currentEpoch).append(", eta: ").append(decayStrategy.eta()).toString(), Verbose$.MODULE$.Epochs());
            final long started = System.nanoTime();

            final Seq shuffled = (Seq)Random$.MODULE$.shuffle(dataset.instances(), Seq$.MODULE$.canBuildFrom());
            final IntRef instancesCount = IntRef.create((int)0);
            shuffled.withFilter((Function1 & Serializable & scala.Serializable)candidate -> BoxesRunTime.boxToBoolean((boolean)LinearChainCrf.$anonfun$trainSGD$9((Tuple2)candidate))).foreach((Function1 & Serializable & scala.Serializable)pair -> {
                LinearChainCrf.$anonfun$trainSGD$10(this, decayStrategy, context, weights, loss, instancesCount, (Tuple2)pair);
                return BoxedUnit.UNIT;
            });

            // NOTE(review): reset presumably folds the decay scale back into the
            // raw weights before they are read below - confirm in L2DecayStrategy.
            decayStrategy.reset(weights);

            // Sum of squared weights, accumulated left to right in float precision.
            float squaredNorm = 0.0f;
            for (final float weight : weights) {
                squaredNorm += weight * weight;
            }
            final float l2Loss = this.params().l2() * squaredNorm;
            final float logLoss = loss.elem;
            final float totalLoss = logLoss + l2Loss;

            this.log((Function0<String>)(Function0 & Serializable & scala.Serializable)() -> new StringBuilder(16).append("finished, time: ").append((double)(System.nanoTime() - started) / 1.0E9).toString(), Verbose$.MODULE$.Epochs());
            this.log((Function0<String>)(Function0 & Serializable & scala.Serializable)() -> new StringBuilder(30).append("Loss = ").append(totalLoss).append(", logLoss = ").append(logLoss).append(", l2Loss = ").append(l2Loss).toString(), Verbose$.MODULE$.Epochs());

            if (totalLoss < bestLoss) {
                // Measure the gain against the previous best BEFORE overwriting it.
                // On the first improving epoch the previous best is Float.MAX_VALUE,
                // so the gain is huge and the counter is reset.
                final float relativeGain = (bestLoss - totalLoss) / totalLoss;
                bestLoss = totalLoss;
                VectorMath$.MODULE$.copy(weights, bestW);
                if (relativeGain < this.params().lossEps()) {
                    // A new best, but by less than the configured margin.
                    ++notImprovedEpochs;
                } else {
                    notImprovedEpochs = 0;
                }
            } else {
                ++notImprovedEpochs;
            }
        }
        return new LinearChainCrfModel(bestW, metadata);
    }

    private float getLoss(Instance sentence, InstanceLabels labels, FbCalculator context) {
        FloatRef result;
        block0: {
            int length = sentence.items().length();
            IntRef prevLabel = IntRef.create((int)0);
            result = FloatRef.create((float)0.0f);
            RichInt$.MODULE$.until$extension0(Predef$.MODULE$.intWrapper(0), length).foreach$mVc$sp((Function1)(JFunction1.mcVI.sp & Serializable & scala.Serializable)i -> {
                result$1.elem -= context.logPhi()[i][prevLabel$1.elem][BoxesRunTime.unboxToInt((Object)labels.labels().apply(i))];
                prevLabel$1.elem = BoxesRunTime.unboxToInt((Object)labels.labels().apply(i));
                result$1.elem += (float)Math.log(context.c()[i]);
            });
            if (!(result.elem >= 0.0f)) break block0;
            Predef$.MODULE$.assert(result.elem >= 0.0f);
        }
        return result.elem;
    }

    /**
     * Performs one gradient update of {@code weights} for a single sentence by
     * delegating to {@code context}: observed expectations are added with
     * factor {@code a}, then model expectations are added with factor {@code -a}.
     *
     * @param sentence instance being trained on
     * @param labels   gold labels of {@code sentence}
     * @param a        step factor applied to the observed expectations; its
     *                 negation is applied to the model expectations
     * @param weights  weight vector handed to {@code context} for updating
     * @param context  calculator that carries out both updates
     */
    public void doSgdStep(Instance sentence, InstanceLabels labels, float a, float[] weights, FbCalculator context) {
        final float observedFactor = a;
        final float modelFactor = -a;
        context.addObservedExpectations(weights, sentence, labels, observedFactor);
        context.addModelExpectations(weights, sentence, modelFactor);
    }

    /**
     * Compiler-generated lambda body: returns the number of items of the
     * {@link Instance} stored as the second element of the pair.
     *
     * @param w pair whose second element is an {@link Instance}
     * @return {@code items().size()} of that instance
     */
    public static final /* synthetic */ int $anonfun$trainSGD$1(Tuple2 w) {
        final Instance instance = (Instance)w._2();
        final int itemCount = instance.items().size();
        return itemCount;
    }

    /**
     * Compiler-generated filter used before destructuring a pair in the
     * training loop: accepts every non-null tuple.
     *
     * @param check$ifrefutable$1 candidate element
     * @return {@code true} when the element is not {@code null}
     */
    public static final /* synthetic */ boolean $anonfun$trainSGD$9(Tuple2 check$ifrefutable$1) {
        return check$ifrefutable$1 != null;
    }

    /**
     * Compiler-generated body of the per-instance training step.
     *
     * <p>In order: advances the decay strategy, lets {@code context$1}
     * calculate its values for the sentence with the current weights and
     * scale, performs one SGD step, adds the sentence loss to {@code loss$1},
     * and increments {@code instancesCount$1}. After every 1000th instance
     * {@code decayStrategy$1.reset(weights$1)} is called.
     *
     * @param $this            trainer whose {@code doSgdStep}/{@code getLoss} are used
     * @param decayStrategy$1  decay strategy supplying scale and alpha
     * @param context$1        calculator workspace shared across instances
     * @param weights$1        weight vector being trained
     * @param loss$1           running loss of the current epoch (mutated)
     * @param instancesCount$1 number of instances processed in the epoch (mutated)
     * @param x$1              pair of (labels, sentence)
     * @throws MatchError if {@code x$1} is {@code null}
     */
    public static final /* synthetic */ void $anonfun$trainSGD$10(LinearChainCrf $this, L2DecayStrategy decayStrategy$1, FbCalculator context$1, float[] weights$1, FloatRef loss$1, IntRef instancesCount$1, Tuple2 x$1) {
        if (x$1 == null) {
            throw new MatchError((Object)x$1);
        }
        final InstanceLabels goldLabels = (InstanceLabels)x$1._1();
        final Instance instance = (Instance)x$1._2();

        decayStrategy$1.nextStep();
        context$1.calculate(instance, weights$1, decayStrategy$1.getScale());
        $this.doSgdStep(instance, goldLabels, decayStrategy$1.alpha(), weights$1, context$1);
        loss$1.elem += $this.getLoss(instance, goldLabels, context$1);

        instancesCount$1.elem += 1;
        final boolean atResetBoundary = instancesCount$1.elem % 1000 == 0;
        if (atResetBoundary) {
            decayStrategy$1.reset(weights$1);
        }
    }

    public LinearChainCrf(CrfParams params) {
        this.params = params;
        this.logger = LoggerFactory.getLogger((String)"CRF");
    }
}

