/*
 * Decompiled with CFR 0.152.
 */
package com.microsoft.azure.synapse.ml.lightgbm;

import com.microsoft.azure.synapse.ml.lightgbm.BasePartitionTask;
import com.microsoft.azure.synapse.ml.lightgbm.LightGBMUtils$;
import com.microsoft.azure.synapse.ml.lightgbm.PartitionDataState;
import com.microsoft.azure.synapse.ml.lightgbm.PartitionTaskContext;
import com.microsoft.azure.synapse.ml.lightgbm.SharedDatasetState;
import com.microsoft.azure.synapse.ml.lightgbm.SharedState;
import com.microsoft.azure.synapse.ml.lightgbm.TrainingContext;
import com.microsoft.azure.synapse.ml.lightgbm.dataset.BaseAggregatedColumns;
import com.microsoft.azure.synapse.ml.lightgbm.dataset.BaseChunkedColumns;
import com.microsoft.azure.synapse.ml.lightgbm.dataset.DenseAggregatedColumns;
import com.microsoft.azure.synapse.ml.lightgbm.dataset.DenseChunkedColumns;
import com.microsoft.azure.synapse.ml.lightgbm.dataset.LightGBMDataset;
import com.microsoft.azure.synapse.ml.lightgbm.dataset.PeekingIterator;
import com.microsoft.azure.synapse.ml.lightgbm.dataset.SparseAggregatedColumns;
import com.microsoft.azure.synapse.ml.lightgbm.dataset.SparseChunkedColumns;
import java.io.Serializable;
import org.apache.spark.sql.Row;
import scala.Function1;
import scala.None$;
import scala.Option;
import scala.Option$;
import scala.Predef$;
import scala.collection.Iterator;
import scala.collection.mutable.ArrayOps;
import scala.reflect.ScalaSignature;
import scala.runtime.BoxedUnit;
import scala.runtime.BoxesRunTime;

@ScalaSignature(bytes="\u0006\u0001\u0005\ra\u0001B\u0005\u000b\u0001]AQ\u0001\b\u0001\u0005\u0002uAQa\b\u0001\u0005R\u0001BQa\r\u0001\u0005\u0012QBQA\u0016\u0001\u0005\u0012]CQ!\u0019\u0001\u0005\u0012\tDQa\u001a\u0001\u0005\u0012!DQa\u001d\u0001\u0005\nQDQA\u001f\u0001\u0005\nm\u0014\u0011CQ;mWB\u000b'\u000f^5uS>tG+Y:l\u0015\tYA\"\u0001\u0005mS\u001eDGo\u001a2n\u0015\tia\"\u0001\u0002nY*\u0011q\u0002E\u0001\bgft\u0017\r]:f\u0015\t\t\"#A\u0003buV\u0014XM\u0003\u0002\u0014)\u0005IQ.[2s_N|g\r\u001e\u0006\u0002+\u0005\u00191m\\7\u0004\u0001M\u0011\u0001\u0001\u0007\t\u00033ii\u0011AC\u0005\u00037)\u0011\u0011CQ1tKB\u000b'\u000f^5uS>tG+Y:l\u0003\u0019a\u0014N\\5u}Q\ta\u0004\u0005\u0002\u001a\u0001\u0005\u0011\u0012N\\5uS\u0006d\u0017N_3J]R,'O\\1m)\u0011\ts\u0005L\u0019\u0011\u0005\t*S\"A\u0012\u000b\u0003\u0011\nQa]2bY\u0006L!AJ\u0012\u0003\tUs\u0017\u000e\u001e\u0005\u0006Q\t\u0001\r!K\u0001\u0004GRD\bCA\r+\u0013\tY#BA\bUe\u0006Lg.\u001b8h\u0007>tG/\u001a=u\u0011\u0015i#\u00011\u0001/\u0003U\u0019\bn\\;mI\u0016CXmY;uKR\u0013\u0018-\u001b8j]\u001e\u0004\"AI\u0018\n\u0005A\u001a#a\u0002\"p_2,\u0017M\u001c\u0005\u0006e\t\u0001\rAL\u0001\u0011SN,U\u000e\u001d;z!\u0006\u0014H/\u001b;j_:\fA\u0004\u001d:fa\u0006\u0014X\rU1si&$\u0018n\u001c8ECR\f\u0017J\u001c;fe:\fG\u000eF\u00026qq\u0002\"!\u0007\u001c\n\u0005]R!A\u0005)beRLG/[8o\t\u0006$\u0018m\u0015;bi\u0016DQ\u0001K\u0002A\u0002e\u0002\"!\u0007\u001e\n\u0005mR!\u0001\u0006)beRLG/[8o)\u0006\u001c8nQ8oi\u0016DH\u000fC\u0003>\u0007\u0001\u0007a(A\u0005j]B,HOU8xgB\u0019qh\u0012&\u000f\u0005\u0001+eBA!E\u001b\u0005\u0011%BA\"\u0017\u0003\u0019a$o\\8u}%\tA%\u0003\u0002GG\u00059\u0001/Y2lC\u001e,\u0017B\u0001%J\u0005!IE/\u001a:bi>\u0014(B\u0001$$!\tYE+D\u0001M\u0015\tie*A\u0002tc2T!a\u0014)\u0002\u000bM\u0004\u0018M]6\u000b\u0005E\u0013\u0016AB1qC\u000eDWMC\u0001T\u0003\ry'oZ\u0005\u0003+2\u00131AU8x\u0003i9W\r\u001e+sC&t\u0017N\\4ECR\f7/\u001a;J]R,'O\\1m)\rAfl\u0018\t\u00033rk\u0011A\u0017\u0006\u00037*\tq\u0001Z1uCN,G/\u0003\u0
002^5\nyA*[4ii\u001e\u0013U\nR1uCN,G\u000fC\u0003)\t\u0001\u0007\u0011\bC\u0003a\t\u0001\u0007Q'A\u0005eCR\f7\u000b^1uK\u0006ar-\u001a;WC2LG-\u0019;j_:$\u0015\r^1tKRLe\u000e^3s]\u0006dG\u0003\u0002-dI\u0016DQ\u0001K\u0003A\u0002eBQ\u0001Y\u0003A\u0002UBQAZ\u0003A\u0002a\u000b\u0001C]3gKJ,gnY3ECR\f7/\u001a;\u0002\u001f\u001d,g.\u001a:bi\u0016$\u0015\r^1tKR$B\u0001W5k_\")\u0001F\u0002a\u0001s!)1N\u0002a\u0001Y\u0006\u0011\u0011m\u0019\t\u000336L!A\u001c.\u0003+\t\u000b7/Z!hOJ,w-\u0019;fI\u000e{G.^7og\")aM\u0002a\u0001aB\u0019!%\u001d-\n\u0005I\u001c#AB(qi&|g.A\thKR\u001c\u0005.\u001e8lK\u0012\u001cu\u000e\\;n]N$2!\u001e=z!\tIf/\u0003\u0002x5\n\u0011\")Y:f\u0007\",hn[3e\u0007>dW/\u001c8t\u0011\u0015As\u00011\u0001:\u0011\u0015it\u00011\u0001?\u0003}iWM]4f\u0007\",hn[:J]R|\u0017iZ4sK\u001e\fG/\u001a3BeJ\f\u0017p\u001d\u000b\u0005Yrlx\u0010C\u0003)\u0011\u0001\u0007\u0011\bC\u0003\u007f\u0011\u0001\u0007Q/\u0001\u0002ug\"1\u0011\u0011\u0001\u0005A\u00029\nq\"[:G_J4\u0016\r\\5eCRLwN\u001c")
/**
 * Partition task that reads a partition's rows into chunked columns, merges them into aggregated
 * (dense or sparse) column arrays, and then builds the LightGBM dataset from those arrays.
 *
 * <p>Whether dense or sparse storage is used is read from {@code ctx.sharedState().isSparse()}
 * after {@code determineMatrixType} has run. In single-dataset mode the aggregated arrays come
 * from the shared dataset state rather than being allocated per task. Participating tasks then
 * rendezvous on {@code arrayProcessedSignal} between counting rows and copying them.
 *
 * <p>NOTE(review): this file is CFR-decompiled Scala. If {@code helperStartSignal} /
 * {@code arrayProcessedSignal} are {@code java.util.concurrent.CountDownLatch}es, their
 * {@code await()} declares {@code InterruptedException}. That is legal from Scala but not from
 * Java without handling. Confirm the type before compiling this class as Java.
 */
public class BulkPartitionTask
extends BasePartitionTask {
    /**
     * Registers this task with the shared signals used in single-dataset mode.
     *
     * <p>Does nothing unless single-dataset mode is on and the partition is non-empty.
     *
     * @param ctx                   training context holding the shared state
     * @param shouldExecuteTraining whether this task runs training (otherwise it only prepares data)
     * @param isEmptyPartition      whether the partition has no rows
     */
    @Override
    public void initializeInternal(TrainingContext ctx, boolean shouldExecuteTraining, boolean isEmptyPartition) {
        if (!ctx.useSingleDatasetMode() || isEmptyPartition) {
            return;
        }
        ctx.sharedState().incrementArrayProcessedSignal(this.log());
        if (!shouldExecuteTraining) {
            // Non-training tasks additionally register with the data-prep-done signal.
            ctx.sharedState().incrementDataPrepDoneSignal(this.log());
        }
    }

    /**
     * Reads and aggregates the training rows of this partition, plus the validation rows if any.
     *
     * <p>Non-training tasks block on {@code helperStartSignal} until a training task in
     * single-dataset mode counts it down.
     *
     * @param ctx       partition task context
     * @param inputRows rows of this partition
     * @return state holding the aggregated training columns and the optional validation columns
     */
    @Override
    public PartitionDataState preparePartitionDataInternal(PartitionTaskContext ctx, Iterator<Row> inputRows) {
        if (!ctx.shouldExecuteTraining()) {
            this.log().info("Waiting for helper start signal on task " + ctx.taskId() + ", partition " + ctx.partitionId());
            ctx.sharedState().helperStartSignal().await();
        } else if (ctx.trainingCtx().useSingleDatasetMode()) {
            ctx.sharedState().helperStartSignal().countDown();
            this.log().info("Initiated helper start signal on task " + ctx.taskId() + ", partition " + ctx.partitionId());
        }

        this.log().info("Reading data on task " + ctx.taskId() + ", partition " + ctx.partitionId());
        BaseChunkedColumns trainingChunks = this.getChunkedColumns(ctx, inputRows);
        this.log().info("Merging data on task " + ctx.taskId() + ", partition " + ctx.partitionId());
        BaseAggregatedColumns trainingColumns = this.mergeChunksIntoAggregatedArrays(ctx, trainingChunks, false);

        // Lambda locals must not reuse names of the enclosing method's locals (Java forbids shadowing),
        // and the lambda type is left to inference so that `data` keeps its real element type.
        Option<BaseAggregatedColumns> validationColumns = ctx.trainingCtx().validationData().map(data -> {
            Iterator<Row> validationRows =
                (Iterator<Row>) new ArrayOps.ofRef(Predef$.MODULE$.refArrayOps((Object[]) data.value())).toIterator();
            BaseChunkedColumns validationChunks = this.getChunkedColumns(ctx, validationRows);
            return this.mergeChunksIntoAggregatedArrays(ctx, validationChunks, true);
        });

        return new PartitionDataState(Option$.MODULE$.apply(trainingColumns), validationColumns);
    }

    /**
     * Builds the training dataset from the aggregated training columns, without a reference dataset.
     *
     * <p>Assumes {@code aggregatedTrainingData()} is defined; {@code preparePartitionDataInternal}
     * always sets it.
     */
    @Override
    public LightGBMDataset getTrainingDatasetInternal(PartitionTaskContext ctx, PartitionDataState dataState) {
        return this.generateDataset(ctx, (BaseAggregatedColumns) dataState.aggregatedTrainingData().get(),
            Option$.MODULE$.<LightGBMDataset>empty());
    }

    /**
     * Builds the validation dataset from the aggregated validation columns, binned against
     * {@code referenceDataset}.
     *
     * <p>Calls {@code get()} on {@code aggregatedValidationData()}, so it throws if no validation
     * data was prepared. Presumably callers only invoke it when validation data exists; confirm.
     */
    @Override
    public LightGBMDataset getValidationDatasetInternal(PartitionTaskContext ctx, PartitionDataState dataState, LightGBMDataset referenceDataset) {
        return this.generateDataset(ctx, (BaseAggregatedColumns) dataState.aggregatedValidationData().get(),
            Option$.MODULE$.apply(referenceDataset));
    }

    /**
     * Creates a LightGBM dataset from aggregated columns, then adds the group column (when one is
     * configured) and the feature names.
     *
     * <p>{@code ac.cleanup()} always runs, even if dataset creation fails.
     *
     * @param ctx              partition task context
     * @param ac               aggregated columns to build from
     * @param referenceDataset optional dataset passed through to {@code ac.generateDataset}
     * @return the dataset returned by {@code setFeatureNames}
     */
    public LightGBMDataset generateDataset(PartitionTaskContext ctx, BaseAggregatedColumns ac, Option<LightGBMDataset> referenceDataset) {
        try {
            LightGBMDataset dataset = ac.generateDataset(ctx, referenceDataset);
            if (ctx.trainingCtx().columnParams().groupColumn().isDefined()) {
                dataset.addGroupColumn(ac.getGroups());
            }
            return dataset.setFeatureNames(ctx.trainingCtx().featureNames(), ac.getNumCols());
        }
        finally {
            ac.cleanup();
        }
    }

    /**
     * Wraps the rows in dense or sparse chunked columns, depending on the detected matrix type.
     */
    private BaseChunkedColumns getChunkedColumns(PartitionTaskContext ctx, Iterator<Row> inputRows) {
        TrainingContext trainingCtx = ctx.trainingCtx();
        // determineMatrixType must run first: isSparse() is read right after it.
        PeekingIterator<Row> rows = this.determineMatrixType(ctx, inputRows);
        boolean isSparse = BoxesRunTime.unboxToBoolean(ctx.sharedState().isSparse().get());
        if (isSparse) {
            return new SparseChunkedColumns(rows, trainingCtx.columnParams(), trainingCtx.schema(),
                trainingCtx.trainingParams().executionParams().chunkSize(), trainingCtx.useSingleDatasetMode());
        }
        return new DenseChunkedColumns(rows, trainingCtx.columnParams(), trainingCtx.schema(),
            trainingCtx.trainingParams().executionParams().chunkSize());
    }

    /**
     * Merges chunked columns into aggregated column arrays and releases the chunks.
     *
     * <p>In single-dataset mode the aggregated arrays come from the shared (training or validation)
     * dataset state. Every participating task counts down and then awaits
     * {@code arrayProcessedSignal} between counting rows and copying them. The barrier is reached
     * even by tasks that do not merge rows.
     *
     * @param ctx             partition task context
     * @param ts              chunked columns to merge; released before returning
     * @param isForValidation whether these are validation rows
     * @return the aggregated columns the rows were merged into (possibly shared)
     */
    private BaseAggregatedColumns mergeChunksIntoAggregatedArrays(PartitionTaskContext ctx, BaseChunkedColumns ts, boolean isForValidation) {
        SharedState sharedState = ctx.sharedState();
        boolean useSingleDataset = ctx.trainingCtx().useSingleDatasetMode();
        boolean isSparse = BoxesRunTime.unboxToBoolean(sharedState.isSparse().get());
        SharedDatasetState sharedDatasetState = isForValidation
            ? sharedState.validationDatasetState()
            : sharedState.datasetState();

        // Assign in every branch: the decompiled form only assigned this in the sparse arm.
        BaseAggregatedColumns aggregatedColumns;
        if (isSparse) {
            aggregatedColumns = useSingleDataset
                ? sharedDatasetState.sparseAggregatedColumns()
                : new SparseAggregatedColumns(ctx.trainingParams().executionParams().chunkSize());
        } else {
            aggregatedColumns = useSingleDataset
                ? sharedDatasetState.denseAggregatedColumns()
                : new DenseAggregatedColumns(ctx.trainingParams().executionParams().chunkSize());
        }

        // Training rows are always merged. In single-dataset mode, validation rows (taken from the
        // same validationData() value on every partition) are merged only by the main executor
        // worker, so the shared validation arrays hold a single copy.
        boolean mergeRowsIntoDataset = !isForValidation
            || !useSingleDataset
            || BoxesRunTime.unboxToLong(sharedState.mainExecutorWorker().get()) == LightGBMUtils$.MODULE$.getTaskId();

        if (mergeRowsIntoDataset) {
            aggregatedColumns.incrementCount(ts, ctx.partitionId());
        }
        if (useSingleDataset) {
            // Wait until every participating task has contributed its counts before copying rows.
            sharedDatasetState.arrayProcessedSignal().countDown();
            sharedDatasetState.arrayProcessedSignal().await();
        }
        if (mergeRowsIntoDataset) {
            aggregatedColumns.addRows(ts);
        }
        ts.release();
        return aggregatedColumns;
    }
}

