/*
 * Decompiled with CFR 0.152.
 */
package ai.djl.serving.wlm;

import ai.djl.serving.util.ConfigManager;
import ai.djl.serving.wlm.BatchAggregator;
import ai.djl.serving.wlm.ModelInfo;
import ai.djl.serving.wlm.WorkerState;
import ai.djl.serving.wlm.WorkerThread;
import java.util.ArrayList;
import java.util.Collections;
import java.util.List;
import java.util.concurrent.ConcurrentHashMap;
import java.util.concurrent.CopyOnWriteArrayList;
import java.util.concurrent.ExecutorService;
import java.util.concurrent.Executors;
import java.util.concurrent.atomic.AtomicInteger;

/**
 * Keeps track of the {@link WorkerThread}s started for each model and scales them to the
 * minimum/maximum worker counts reported by {@link ModelInfo}.
 *
 * <p>Per-model worker lists are {@link CopyOnWriteArrayList}s, so the read-only queries ({@link
 * #getWorkers}, {@link #hasWorker}, {@link #getNumRunningWorkers}) may iterate them while {@link
 * #modelChanged} adds or removes workers on another thread.
 */
class WorkLoadManager {

    private final ConfigManager configManager;
    /** Round-robin cursor used to spread new workers across GPU ids. */
    private final AtomicInteger gpuCounter;
    private final ExecutorService threadPool;
    private final ConcurrentHashMap<String, List<WorkerThread>> workers;
    /**
     * One private monitor per model name. Used instead of locking on the name {@code String}
     * itself, because equal names held by different {@code ModelInfo} objects are not
     * guaranteed to be the same instance. Entries are never removed; each one is a single
     * {@code Object}.
     */
    private final ConcurrentHashMap<String, Object> modelLocks;

    /**
     * Creates a manager with no registered workers.
     *
     * @param configManager consulted for the number of GPUs when assigning device ids
     */
    public WorkLoadManager(ConfigManager configManager) {
        this.configManager = configManager;
        this.gpuCounter = new AtomicInteger(0);
        this.threadPool = Executors.newCachedThreadPool();
        this.workers = new ConcurrentHashMap<>();
        this.modelLocks = new ConcurrentHashMap<>();
    }

    /**
     * Returns the workers currently registered for a model.
     *
     * @param modelName the model name
     * @return the live worker list for the model, or an empty immutable list if none exists
     */
    public List<WorkerThread> getWorkers(String modelName) {
        List<WorkerThread> list = workers.get(modelName);
        if (list == null) {
            return Collections.emptyList();
        }
        return list;
    }

    /**
     * Tells whether at least one worker of the model reports {@code isRunning()}.
     *
     * @param modelName the model name
     * @return {@code true} if a running worker exists, {@code false} otherwise
     */
    public boolean hasWorker(String modelName) {
        List<WorkerThread> threads = workers.get(modelName);
        if (threads == null) {
            return false;
        }
        for (WorkerThread thread : threads) {
            if (thread.isRunning()) {
                return true;
            }
        }
        return false;
    }

    /**
     * Counts the model's workers whose state is not stopped, errored, or scaled down.
     *
     * @param modelName the model name
     * @return the number of such workers; {@code 0} if the model has no worker list
     */
    public int getNumRunningWorkers(String modelName) {
        List<WorkerThread> threads = workers.get(modelName);
        if (threads == null) {
            return 0;
        }
        int numWorking = 0;
        for (WorkerThread thread : threads) {
            // read the state once so all three comparisons see the same value
            WorkerState state = thread.getState();
            if (state != WorkerState.WORKER_STOPPED
                    && state != WorkerState.WORKER_ERROR
                    && state != WorkerState.WORKER_SCALED_DOWN) {
                ++numWorking;
            }
        }
        return numWorking;
    }

    /**
     * Re-applies the model's worker limits.
     *
     * <p>If {@code getMinWorkers()} is {@code 0}, the model's worker list is unregistered and
     * every worker in it is shut down. Otherwise, workers are started until the minimum is
     * reached, or workers beyond {@code getMaxWorkers()} are shut down (newest first).
     *
     * @param modelInfo the model whose configuration changed
     */
    public void modelChanged(ModelInfo modelInfo) {
        String modelName = modelInfo.getModelName();
        Object lock = modelLocks.computeIfAbsent(modelName, k -> new Object());
        synchronized (lock) {
            int minWorker = modelInfo.getMinWorkers();
            int maxWorker = modelInfo.getMaxWorkers();
            if (minWorker == 0) {
                List<WorkerThread> removed = workers.remove(modelName);
                if (removed != null) {
                    // Once the list leaves the map, nothing can reach these workers any more,
                    // so all of them must be stopped, not only those above maxWorker.
                    scaleDown(removed, 0);
                }
                return;
            }

            List<WorkerThread> threads =
                    workers.computeIfAbsent(modelName, k -> new CopyOnWriteArrayList<>());
            int currentWorkers = threads.size();
            if (currentWorkers < minWorker) {
                addThreads(threads, modelInfo, minWorker - currentWorkers);
            } else {
                scaleDown(threads, maxWorker);
            }
        }
    }

    /**
     * Runs a task on the manager's shared cached thread pool.
     *
     * @param r the task to execute
     */
    public void scheduleAsync(Runnable r) {
        threadPool.execute(r);
    }

    /**
     * Removes workers from the end of {@code threads} and shuts each one down with {@link
     * WorkerState#WORKER_SCALED_DOWN} until at most {@code keep} remain.
     */
    private static void scaleDown(List<WorkerThread> threads, int keep) {
        int floor = Math.max(keep, 0); // a negative limit must not drive the index below 0
        for (int i = threads.size() - 1; i >= floor; --i) {
            threads.remove(i).shutdown(WorkerState.WORKER_SCALED_DOWN);
        }
    }

    /**
     * Creates {@code count} new workers for {@code model}, appends them to {@code threads}, and
     * submits them to the thread pool.
     *
     * <p>When the configuration reports GPUs, device ids are handed out round-robin via {@code
     * gpuCounter} (the first id assigned is {@code 1 % maxGpu}); otherwise {@code -1} is used.
     */
    private void addThreads(List<WorkerThread> threads, ModelInfo model, int count) {
        int maxGpu = configManager.getNumberOfGpu();
        for (int i = 0; i < count; ++i) {
            int gpuId = -1;
            if (maxGpu > 0) {
                gpuId = gpuCounter.accumulateAndGet(maxGpu, (prev, maxGpuId) -> (prev + 1) % maxGpuId);
            }
            BatchAggregator aggregator = new BatchAggregator(model);
            WorkerThread thread = new WorkerThread(gpuId, model, aggregator);
            threads.add(thread);
            threadPool.submit(thread);
        }
    }
}

