/*
 * Decompiled with CFR 0.152.
 */
package hex.tree.xgboost.task;

import ai.h2o.xgboost4j.java.Booster;
import ai.h2o.xgboost4j.java.BoosterWrapper;
import ai.h2o.xgboost4j.java.DMatrix;
import ai.h2o.xgboost4j.java.Rabit;
import ai.h2o.xgboost4j.java.XGBoostError;
import hex.tree.xgboost.BoosterParms;
import hex.tree.xgboost.EvalMetric;
import java.util.Map;
import java.util.concurrent.SynchronousQueue;
import java.util.concurrent.TimeUnit;
import org.apache.log4j.Logger;
import water.H2O;
import water.Key;
import water.nbhm.NonBlockingHashMap;
import water.util.Log;

/**
 * Dedicated worker thread that owns the XGBoost booster of a single model on this node.
 *
 * <p>Every booster operation runs on this thread. A caller hands a {@link BoosterCallable} over through a
 * {@link SynchronousQueue} and blocks until the result is handed back through a second queue. Before it
 * serves any task, the thread initializes Rabit with the supplied environment. When it exits, it disposes
 * of the training/validation matrices and the booster and shuts Rabit down.
 *
 * <p>Instances are kept in a static registry keyed by model key. They are created with {@link #make},
 * looked up with {@link #getUpdater} and stopped with {@link #terminate}.
 */
public class XGBoostUpdater extends Thread {
    private static final Logger LOG = Logger.getLogger(XGBoostUpdater.class);
    /** Maximum time, in seconds, a caller waits for the updater thread to accept a task. */
    private static final long WORK_START_TIMEOUT_SECS = 300L;
    /** Interval, in seconds, between checks that the updater is still active while a caller awaits a result. */
    private static final long INACTIVE_CHECK_INTERVAL_SECS = 60L;
    /** After this many unanswered check intervals, each further interval logs a slow-task warning. */
    private static final int LONG_TASK_WARNING_AFTER_CHECKS = 5;
    /** Registry of the updaters currently active on this node, keyed by model key. */
    private static final NonBlockingHashMap<Key, XGBoostUpdater> updaters = new NonBlockingHashMap<>();

    private final Key _modelKey;
    private final DMatrix _trainMat;
    private final DMatrix _validMat; // may be null when no validation frame is used
    private final BoosterParms _boosterParms;
    private final String _evalMetricSpec; // "eval_metric" booster parameter, may be null
    private final byte[] _checkpointBoosterBytes;
    private final Map<String, String> _rabitEnv;
    // Hand-off queues between callers and the updater thread; set to null once the thread terminates.
    private volatile SynchronousQueue<BoosterCallable<?>> _in;
    private volatile SynchronousQueue<Object> _out;
    // Only touched from the updater thread (inside tasks and in run()'s finally block).
    private BoosterWrapper _booster;
    private volatile EvalMetric _evalMetric;

    private XGBoostUpdater(Key modelKey, DMatrix trainMat, DMatrix validMat, BoosterParms boosterParms, byte[] checkpointBoosterBytes, Map<String, String> rabitEnv) {
        super("XGBoostUpdater-" + modelKey);
        _modelKey = modelKey;
        _trainMat = trainMat;
        _validMat = validMat;
        _boosterParms = boosterParms;
        _checkpointBoosterBytes = checkpointBoosterBytes;
        _rabitEnv = rabitEnv;
        _evalMetricSpec = (String) _boosterParms.get().get("eval_metric");
        _in = new SynchronousQueue<>();
        _out = new SynchronousQueue<>();
    }

    /**
     * Initializes Rabit, then serves booster tasks one at a time until the thread is interrupted or a task
     * fails with an {@link XGBoostError}. The thread always deregisters itself, releases its native
     * resources and shuts Rabit down on exit.
     */
    @Override
    public void run() {
        try {
            Rabit.init(_rabitEnv);
            while (!Thread.interrupted()) {
                BoosterCallable<?> task = _in.take();
                Object result = task.call();
                _out.put(result);
            }
        } catch (InterruptedException e) {
            // terminate() removes the updater from the registry before interrupting it, so an interrupt
            // while this instance is still registered is unexpected.
            if (updaters.get(_modelKey) == this) {
                LOG.error("Updater thread was interrupted while it was still registered, name=" + getName(), e);
            } else {
                LOG.debug("Updater thread interrupted.", e);
            }
            Thread.currentThread().interrupt();
        } catch (XGBoostError e) {
            LOG.error("XGBoost training iteration failed", e);
        } finally {
            // Nulling the queues makes waiting/new callers of invoke() fail fast instead of hanging.
            _in = null;
            _out = null;
            // Conditional removal: terminate() may already have removed this instance, and a new updater
            // may have been registered under the same key since then - it must not be evicted.
            updaters.remove(_modelKey, this);
            try {
                _trainMat.dispose();
                if (_validMat != null) {
                    _validMat.dispose();
                }
                if (_booster != null) {
                    _booster.dispose();
                }
            } catch (Exception e) {
                LOG.warn("Failed to dispose of training matrix/booster", e);
            }
            try {
                Rabit.shutdown();
            } catch (Exception xgBoostError) {
                LOG.warn("Rabit shutdown during update failed", xgBoostError);
            }
        }
    }

    /**
     * Runs {@code callable} on the updater thread and waits for its result.
     *
     * @param callable task to execute on the updater thread
     * @return the value produced by the task
     * @throws InterruptedException if the calling thread is interrupted while waiting
     * @throws IllegalStateException if the updater is (or becomes) inactive, or does not accept the task
     *         within {@link #WORK_START_TIMEOUT_SECS} seconds
     */
    @SuppressWarnings("unchecked") // the updater thread puts exactly the value returned by callable.call()
    private <T> T invoke(BoosterCallable<T> callable) throws InterruptedException {
        SynchronousQueue<BoosterCallable<?>> inQ = _in;
        if (inQ == null) {
            throw new IllegalStateException("Updater is inactive on node " + H2O.SELF);
        }
        if (!inQ.offer(callable, WORK_START_TIMEOUT_SECS, TimeUnit.SECONDS)) {
            throw new IllegalStateException("XGBoostUpdater couldn't start work on task " + callable + " in " + WORK_START_TIMEOUT_SECS + "s.");
        }
        // Poll in bounded intervals so that a dead updater (queues nulled in run()'s finally) is detected.
        SynchronousQueue<Object> outQ;
        int checks = 0;
        while ((outQ = _out) != null) {
            checks++;
            Object result = outQ.poll(INACTIVE_CHECK_INTERVAL_SECS, TimeUnit.SECONDS);
            if (result != null) {
                return (T) result;
            }
            if (checks > LONG_TASK_WARNING_AFTER_CHECKS) {
                LOG.warn(String.format("XGBoost task of type '%s' is taking unexpectedly long, it didn't finish in %d seconds.", callable, INACTIVE_CHECK_INTERVAL_SECS * checks));
            }
        }
        throw new IllegalStateException("Cannot perform booster operation: updater is inactive on node " + H2O.SELF);
    }

    /** Parses an evaluation string produced by the booster using this updater's metric spec. */
    private EvalMetric parseEvalMetric(String evalMetricVal) {
        return parseEvalMetric(_evalMetricSpec, _validMat != null, evalMetricVal);
    }

    /**
     * Parses a tab-separated evaluation string into an {@link EvalMetric}.
     *
     * <p>The string must have 2 elements (3 when a validation set is present). The first element is
     * ignored - presumably the iteration prefix emitted by XGBoost. The next ones carry the training
     * value and, if present, the validation value.
     *
     * @param evalMetricSpec name/spec of the evaluation metric
     * @param hasValid whether a validation value is expected
     * @param evalMetricVal raw evaluation string; null or malformed values yield {@link EvalMetric#empty}
     * @return parsed metric; individual values that cannot be parsed are {@code NaN}
     */
    static EvalMetric parseEvalMetric(String evalMetricSpec, boolean hasValid, String evalMetricVal) {
        if (evalMetricVal == null) {
            Log.err(new Object[]{"Evaluation metric cannot be parsed, no value was provided for metric '" + evalMetricSpec + "'."});
            return EvalMetric.empty(evalMetricSpec);
        }
        String[] parts = evalMetricVal.split("\t");
        int expectedParts = hasValid ? 3 : 2;
        if (parts.length != expectedParts) {
            // Report the offending value itself (previously only the spec was logged).
            Log.err(new Object[]{"Evaluation metric cannot be parsed, unexpected number of elements. Value: '" + evalMetricVal + "' (metric: '" + evalMetricSpec + "')."});
            return EvalMetric.empty(evalMetricSpec);
        }
        double trainVal = parseEvalMetricPart(parts[1]);
        double validVal = hasValid ? parseEvalMetricPart(parts[2]) : Double.NaN;
        return new EvalMetric(evalMetricSpec, trainVal, validVal);
    }

    /**
     * Extracts the numeric value following the last {@code ':'} of a single evaluation element
     * (e.g. {@code "train-auc:0.9"}).
     *
     * @return the parsed value, or {@code NaN} if there is no separator or the value is not a number
     */
    static double parseEvalMetricPart(String evalMetricVal) {
        int sepPos = evalMetricVal.lastIndexOf(':');
        if (sepPos >= 0) {
            String valStr = evalMetricVal.substring(sepPos + 1).trim();
            try {
                return Double.parseDouble(valStr);
            } catch (NumberFormatException e) {
                Log.err(new Object[]{"Failed to parse value of evaluation metric: '" + evalMetricVal + "'.", e});
            }
        }
        return Double.NaN;
    }

    /**
     * Serializes the current booster on the updater thread.
     *
     * @throws IllegalStateException if the updater is inactive or the wait is interrupted
     */
    byte[] getBoosterBytes() {
        try {
            return invoke(new SerializeBooster());
        } catch (InterruptedException e) {
            throw new IllegalStateException("Failed to serialize Booster - operation was interrupted", e);
        }
    }

    /** Returns the metric computed by the most recent boosting task, or null if none is configured/computed yet. */
    EvalMetric getEvalMetric() {
        return _evalMetric;
    }

    /**
     * Runs boosting iteration {@code tid} on the updater thread. Iteration 0 creates the booster.
     *
     * @throws IllegalStateException if the updater is inactive or the wait is interrupted
     */
    Booster doUpdate(int tid) {
        try {
            return invoke(new UpdateBooster(tid));
        } catch (InterruptedException e) {
            throw new IllegalStateException("Boosting iteration failed - operation was interrupted", e);
        }
    }

    /**
     * Creates and registers an updater for {@code modelKey}. The returned thread is not started.
     *
     * @throws IllegalStateException if an updater for the key is already registered
     */
    static XGBoostUpdater make(Key modelKey, DMatrix trainMat, DMatrix validMat, BoosterParms boosterParms, byte[] checkpoint, Map<String, String> rabitEnv) {
        XGBoostUpdater updater = new XGBoostUpdater(modelKey, trainMat, validMat, boosterParms, checkpoint, rabitEnv);
        updater.setUncaughtExceptionHandler(LoggingExceptionHandler.INSTANCE);
        if (updaters.putIfAbsent(modelKey, updater) != null) {
            throw new IllegalStateException("XGBoostUpdater for modelKey=" + modelKey + " already exists!");
        }
        return updater;
    }

    /** Deregisters the updater for {@code modelKey} (if any) and interrupts its thread. */
    static void terminate(Key modelKey) {
        XGBoostUpdater updater = updaters.remove(modelKey);
        if (updater == null) {
            LOG.debug("XGBoostUpdater for modelKey=" + modelKey + " was already clean-up on node " + H2O.SELF);
        } else {
            updater.interrupt();
        }
    }

    /**
     * Looks up the registered updater for {@code modelKey}.
     *
     * @throws IllegalStateException if no updater is registered for the key
     */
    static XGBoostUpdater getUpdater(Key modelKey) {
        XGBoostUpdater updater = updaters.get(modelKey);
        if (updater == null) {
            throw new IllegalStateException("XGBoostUpdater for modelKey=" + modelKey + " was not found!");
        }
        return updater;
    }

    /** Logs exceptions that escape the updater thread. */
    private static class LoggingExceptionHandler implements Thread.UncaughtExceptionHandler {
        private static final LoggingExceptionHandler INSTANCE = new LoggingExceptionHandler();

        private LoggingExceptionHandler() {
        }

        @Override
        public void uncaughtException(Thread t, Throwable e) {
            LOG.error("Uncaught exception in " + t.getName(), e);
        }
    }

    /** A unit of work executed on the updater thread. */
    private interface BoosterCallable<E> {
        E call() throws XGBoostError;
    }

    /** Serializes the booster owned by the updater. */
    private class SerializeBooster implements BoosterCallable<byte[]> {
        private SerializeBooster() {
        }

        @Override
        public byte[] call() throws XGBoostError {
            // NOTE(review): assumes the booster was created by a prior UpdateBooster(tid=0) task.
            return XGBoostUpdater.this._booster.toByteArray();
        }

        @Override
        public String toString() {
            return "SerializeBooster";
        }
    }

    /**
     * Performs one boosting iteration. Iteration 0 with no existing booster creates the booster (from the
     * optional checkpoint bytes) instead of updating it.
     */
    private class UpdateBooster implements BoosterCallable<Booster> {
        private final int _tid;

        private UpdateBooster(int tid) {
            _tid = tid;
        }

        @Override
        public Booster call() throws XGBoostError {
            XGBoostUpdater outer = XGBoostUpdater.this;
            if (outer._booster == null && _tid == 0) {
                outer._booster = new BoosterWrapper(outer._checkpointBoosterBytes, outer._boosterParms.get(), outer._trainMat, outer._validMat);
                outer._evalMetric = computeEvalMetric();
                byte[] boosterBytes = outer._booster.toByteArray();
                LOG.info("Initial Booster created, size=" + boosterBytes.length);
            } else {
                // Explicit check instead of an assert so the failure is descriptive even without -ea.
                if (outer._booster == null) {
                    throw new IllegalStateException("Booster was not initialized before boosting iteration tid=" + _tid);
                }
                outer._booster.update(outer._trainMat, _tid);
                outer._evalMetric = computeEvalMetric();
                outer._booster.saveRabitCheckpoint();
            }
            return outer._booster.getBooster();
        }

        /** Evaluates the booster on the train (and validation) matrices; null when no eval metric is configured. */
        private EvalMetric computeEvalMetric() throws XGBoostError {
            XGBoostUpdater outer = XGBoostUpdater.this;
            if (outer._evalMetricSpec == null) {
                return null;
            }
            String evalMetricVal = outer._booster.evalSet(outer._trainMat, outer._validMat, _tid);
            return outer.parseEvalMetric(evalMetricVal);
        }

        @Override
        public String toString() {
            return "Boosting Iteration (tid=" + _tid + ")";
        }
    }
}

