/*
 * Decompiled with CFR 0.152.
 */
package ai.djl.serving.wlm;

import ai.djl.inference.Predictor;
import ai.djl.modality.Input;
import ai.djl.modality.Output;
import ai.djl.serving.wlm.BatchAggregator;
import ai.djl.serving.wlm.ModelInfo;
import ai.djl.serving.wlm.WorkerState;
import ai.djl.translate.TranslateException;
import java.util.List;
import java.util.concurrent.atomic.AtomicBoolean;
import java.util.concurrent.atomic.AtomicInteger;
import java.util.concurrent.atomic.AtomicReference;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

/**
 * A {@link Runnable} worker that serves inference requests for one model.
 *
 * <p>While running, the worker repeatedly pulls a batch of requests from its {@link
 * BatchAggregator}, runs it through a {@link Predictor} created for this worker, and hands the
 * results back to the aggregator (or calls {@code aggregator.sendError()} on failure).
 *
 * <p>{@link #shutdown(WorkerState)} may be called from a thread other than the one executing
 * {@link #run()}: it interrupts the thread that {@code run()} recorded for itself.
 */
class WorkerThread implements Runnable {

    private static final Logger logger = LoggerFactory.getLogger(WorkerThread.class);

    /** Source of process-wide unique worker ids, starting at 1. */
    private static final AtomicInteger WORKER_COUNTER = new AtomicInteger(1);

    private final ModelInfo model;
    private final Predictor<Input, Output> predictor;
    private final AtomicBoolean running = new AtomicBoolean(true);

    /** Ensures {@link #predictor} is closed at most once. */
    private final AtomicBoolean predictorClosed = new AtomicBoolean(false);

    private final BatchAggregator aggregator;
    private final int gpuId;

    /** The thread currently executing {@link #run()}, or {@code null} if there is none. */
    private final AtomicReference<Thread> currentThread = new AtomicReference<>();

    /**
     * Last recorded state. Written by the worker thread and by external callers of
     * {@link #shutdown(WorkerState)}, and read via {@link #getState()}, hence {@code volatile}.
     */
    private volatile WorkerState state;

    private final int workerId;

    /** Creation time of this worker, from {@link System#currentTimeMillis()}. */
    private final long startTime;

    /**
     * Creates a worker and a new predictor for the given model.
     *
     * @param gpuId the GPU id associated with this worker; stored and reported by {@link
     *     #getGpuId()}
     * @param model the model to serve; a new predictor is created from {@code model.getModel()}
     * @param aggregator the source of request batches and the sink for their responses
     */
    public WorkerThread(int gpuId, ModelInfo model, BatchAggregator aggregator) {
        this.model = model;
        this.aggregator = aggregator;
        this.gpuId = gpuId;
        this.workerId = WORKER_COUNTER.getAndIncrement();
        this.startTime = System.currentTimeMillis();
        this.predictor = model.getModel().newPredictor();
    }

    /**
     * Serves request batches until the worker is shut down.
     *
     * <p>A {@link TranslateException} from the predictor fails only the current batch. Any other
     * throwable, or an {@link InterruptedException} while waiting for requests, ends the loop. On
     * exit the worker is shut down with {@link WorkerState#WORKER_STOPPED}. If a batch was still in
     * flight when the loop ended, {@code aggregator.sendError()} is called for it.
     */
    @Override
    public void run() {
        Thread thread = Thread.currentThread();
        thread.setName(getWorkerName());
        currentThread.set(thread);
        List<Input> req = null;
        boolean interrupted = false;
        try {
            while (isRunning()) {
                req = aggregator.getRequest();
                try {
                    List<Output> reply = predictor.batchPredict(req);
                    aggregator.sendResponse(reply);
                } catch (TranslateException e) {
                    logger.warn("Failed to predict", e);
                    aggregator.sendError();
                }
                // This batch has been answered; nothing is left for the finally block to fail.
                req = null;
            }
        } catch (InterruptedException e) {
            interrupted = true;
            logger.debug("Shutting down the thread .. Scaling down.");
        } catch (Throwable t) {
            logger.error("Server error", t);
        } finally {
            // Clear the reference first so the shutdown() call below neither interrupts this
            // thread nor calls sendError() on its behalf.
            currentThread.set(null);
            shutdown(WorkerState.WORKER_STOPPED);
            if (req != null) {
                aggregator.sendError();
            }
            if (interrupted) {
                // Restore the interrupt status for the thread's owner (e.g. an executor). This is
                // done after cleanup so the flag cannot disturb close()/sendError() above.
                Thread.currentThread().interrupt();
            }
        }
    }

    /** Returns the unique id assigned to this worker at construction. */
    public int getWorkerId() {
        return workerId;
    }

    /** Returns {@code true} until {@link #shutdown(WorkerState)} has been called. */
    public boolean isRunning() {
        return running.get();
    }

    /** Returns the GPU id passed to the constructor. */
    public int getGpuId() {
        return gpuId;
    }

    /** Returns the creation time of this worker in milliseconds since the epoch. */
    public long getStartTime() {
        return startTime;
    }

    /** Returns the last state recorded by {@link #setState(WorkerState)}; may be {@code null}. */
    public WorkerState getState() {
        return state;
    }

    /**
     * Stops this worker and releases its predictor.
     *
     * <p>Marks the worker as no longer running and records {@code state} via {@link
     * #setState(WorkerState)}. If a thread is currently executing {@link #run()}, it is
     * interrupted and {@code aggregator.sendError()} is called. The predictor is closed at most
     * once, however many times this method is called.
     *
     * @param state the state to record for this worker
     */
    public void shutdown(WorkerState state) {
        running.set(false);
        setState(state);
        Thread thread = currentThread.getAndSet(null);
        if (thread != null) {
            thread.interrupt();
            aggregator.sendError();
        }
        // Both an external caller and the worker's own finally block reach this point, so guard
        // against closing the predictor twice.
        // NOTE(review): when called from another thread, the worker may still be inside
        // batchPredict() while the predictor is closed here; confirm the engine tolerates this.
        if (predictorClosed.compareAndSet(false, true)) {
            predictor.close();
        }
    }

    /**
     * Builds the thread name {@code "W-<model>-<workerId>"}. The model name is truncated to its
     * first 25 characters.
     */
    private String getWorkerName() {
        String modelName = model.getModelName();
        if (modelName.length() > 25) {
            modelName = modelName.substring(0, 25);
        }
        return "W-" + modelName + '-' + workerId;
    }

    /**
     * Records a new state unless the worker has already been scaled down.
     *
     * <p>{@link WorkerState#WORKER_SCALED_DOWN} is sticky: once recorded, later transitions are
     * still logged but no longer applied.
     *
     * @param newState the state to record
     */
    void setState(WorkerState newState) {
        logger.debug("{} State change {} -> {}", getWorkerName(), state, newState);
        if (state != WorkerState.WORKER_SCALED_DOWN) {
            state = newState;
        }
    }
}

