/*
 * Decompiled with CFR 0.152.
 */
package ml.dmlc.xgboost4j.java;

import java.io.IOException;
import java.io.InputStream;
import java.io.OutputStream;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.Comparator;
import java.util.List;
import java.util.UUID;
import java.util.stream.Collectors;
import ml.dmlc.xgboost4j.java.Booster;
import ml.dmlc.xgboost4j.java.XGBoost;
import ml.dmlc.xgboost4j.java.XGBoostError;
import org.apache.commons.logging.Log;
import org.apache.commons.logging.LogFactory;
import org.apache.hadoop.fs.FSDataInputStream;
import org.apache.hadoop.fs.FSDataOutputStream;
import org.apache.hadoop.fs.FileSystem;
import org.apache.hadoop.fs.Path;

/**
 * Stores and restores XGBoost training checkpoints in a directory on a Hadoop {@link FileSystem}.
 *
 * <p>Each checkpoint is a file named {@code <round>.ubj} directly under the checkpoint directory,
 * where {@code <round>} is the number of boosted rounds of the saved model minus one.
 * {@link #updateCheckpoint(Booster)} keeps only the newest checkpoint.
 */
public class ExternalCheckpointManager {
    private static final Log LOG = LogFactory.getLog("ExternalCheckpointManager");
    /** File-name suffix that identifies checkpoint files in the checkpoint directory. */
    private static final String MODEL_SUFFIX = ".ubj";
    private final Path checkpointPath;
    private final FileSystem fs;

    /**
     * Creates a manager for checkpoints stored under {@code checkpointPath}.
     *
     * @param checkpointPath directory holding the checkpoint files; must be non-null and non-empty
     * @param fs file system on which the checkpoint directory lives
     * @throws XGBoostError if {@code checkpointPath} is null or empty
     */
    public ExternalCheckpointManager(String checkpointPath, FileSystem fs) throws XGBoostError {
        if (checkpointPath == null || checkpointPath.isEmpty()) {
            throw new XGBoostError("cannot create ExternalCheckpointManager with null or empty checkpoint path");
        }
        this.checkpointPath = new Path(checkpointPath);
        this.fs = fs;
    }

    /** Returns the path of the checkpoint file for the given round. */
    private String getPath(int version) {
        return this.checkpointPath.toUri().getPath() + "/" + version + MODEL_SUFFIX;
    }

    /**
     * Lists the rounds of all checkpoint files currently in the checkpoint directory.
     *
     * <p>Files whose name is not {@code <integer>.ubj} are skipped instead of aborting the scan,
     * so a stray file in the directory cannot break loading or updating checkpoints.
     *
     * @return the rounds found, in directory-listing order; empty if the directory does not exist
     */
    private List<Integer> getExistingVersions() throws IOException {
        if (!this.fs.exists(this.checkpointPath)) {
            return new ArrayList<Integer>();
        }
        return Arrays.stream(this.fs.listStatus(this.checkpointPath))
                .map(status -> parseVersion(status.getPath().getName()))
                .filter(version -> version != null)
                .collect(Collectors.toList());
    }

    /**
     * Extracts the round number from a checkpoint file name.
     *
     * @return the round, or {@code null} if {@code fileName} is not a checkpoint file name
     */
    private static Integer parseVersion(String fileName) {
        if (!fileName.endsWith(MODEL_SUFFIX)) {
            return null;
        }
        String stem = fileName.substring(0, fileName.length() - MODEL_SUFFIX.length());
        try {
            return Integer.valueOf(stem);
        } catch (NumberFormatException e) {
            LOG.warn("ignoring file with checkpoint suffix but non-numeric name: " + fileName);
            return null;
        }
    }

    /** Returns the highest round in {@code versions}, which must be non-empty. */
    private int latest(List<Integer> versions) {
        return versions.stream().max(Comparator.naturalOrder()).get();
    }

    /**
     * Recursively deletes the whole checkpoint directory.
     *
     * @throws IOException if the file system reports an error
     */
    public void cleanPath() throws IOException {
        this.fs.delete(this.checkpointPath, true);
    }

    /**
     * Loads the checkpoint with the highest round.
     *
     * @return the restored booster, or {@code null} if no checkpoint exists
     * @throws IOException if the checkpoint file cannot be read
     * @throws XGBoostError if the model cannot be deserialized
     */
    public Booster loadCheckpointAsBooster() throws IOException, XGBoostError {
        List<Integer> versions = this.getExistingVersions();
        if (versions.isEmpty()) {
            return null;
        }
        String latestPath = this.getPath(this.latest(versions));
        // try-with-resources guarantees the stream is released even if loading fails; closing an
        // already-closed stream is a no-op per the Closeable contract.
        try (FSDataInputStream in = this.fs.open(new Path(latestPath))) {
            Booster booster = XGBoost.loadModel(in);
            LOG.info("loaded checkpoint from " + latestPath);
            return booster;
        }
    }

    /**
     * Saves {@code boosterToCheckpoint} as the newest checkpoint and removes the older ones.
     *
     * <p>The model is written to a uniquely named temporary file, which is fully closed before it
     * is renamed to {@code <round>.ubj}, so the final name never refers to a partial file. Older
     * checkpoints are deleted only after the rename succeeded; failures to delete them are logged,
     * not thrown.
     *
     * @param boosterToCheckpoint booster to save; its boosted round count minus one is the version
     * @throws IOException if writing or renaming the checkpoint fails
     * @throws XGBoostError if the booster cannot be serialized
     */
    public void updateCheckpoint(Booster boosterToCheckpoint) throws IOException, XGBoostError {
        // Checkpointing happens after a round is trained, so (rounds boosted - 1) is the
        // current iteration, which also accounts for training continuation.
        int iter = boosterToCheckpoint.getNumBoostedRound() - 1;
        String eventualPath = this.getPath(iter);
        String tempPath = eventualPath + "-" + UUID.randomUUID();
        // Exclude the round being written: a stale file for the same round is replaced below and
        // must not be deleted again after the new checkpoint has been moved into place.
        List<String> prevModelPaths = this.getExistingVersions().stream()
                .filter(version -> version != iter)
                .map(this::getPath)
                .collect(Collectors.toList());
        Path temp = new Path(tempPath);
        Path target = new Path(eventualPath);
        boolean moved = false;
        try {
            // Close (and thereby flush) the stream before renaming the file.
            try (FSDataOutputStream out = this.fs.create(temp, true)) {
                boosterToCheckpoint.saveModel(out);
            }
            // FileSystem.rename does not overwrite an existing destination on every file system
            // (HDFS returns false), so remove a stale checkpoint for this round first.
            if (this.fs.exists(target)) {
                this.fs.delete(target, true);
            }
            moved = this.fs.rename(temp, target);
        } finally {
            if (!moved) {
                this.deleteQuietly(temp, "failed to delete temporary checkpoint at " + tempPath);
            }
        }
        if (!moved) {
            throw new IOException("failed to rename checkpoint " + tempPath + " to " + eventualPath);
        }
        LOG.info("saving checkpoint with version " + iter);
        for (String path : prevModelPaths) {
            this.deleteQuietly(new Path(path), "failed to delete outdated checkpoint at " + path);
        }
    }

    /**
     * Deletes every checkpoint whose round is greater than {@code currentRound}.
     * Individual deletion failures are logged, not thrown.
     *
     * @param currentRound highest round whose checkpoint is kept
     * @throws IOException if the checkpoint directory cannot be listed
     */
    public void cleanUpHigherVersions(int currentRound) throws IOException {
        for (int version : this.getExistingVersions()) {
            if (version > currentRound) {
                this.deleteQuietly(new Path(this.getPath(version)),
                        "failed to clean checkpoint from other training instance");
            }
        }
    }

    /** Recursively deletes {@code path}, logging {@code errorMessage} instead of propagating I/O errors. */
    private void deleteQuietly(Path path, String errorMessage) {
        try {
            this.fs.delete(path, true);
        } catch (IOException e) {
            LOG.error(errorMessage, e);
        }
    }

    /**
     * Computes the rounds after which a checkpoint should be taken.
     *
     * @param firstRound first round to be trained
     * @param checkpointInterval rounds between checkpoints; values {@code <= 0} disable periodic
     *     checkpoints so that only the last round is returned
     * @param numOfRounds number of rounds to train starting at {@code firstRound}
     * @return {@code firstRound, firstRound + checkpointInterval, ...} (all below
     *     {@code firstRound + numOfRounds}), always ending with the last round
     * @throws IllegalArgumentException if the last round, {@code firstRound + numOfRounds - 1},
     *     is negative
     */
    public List<Integer> getCheckpointRounds(int firstRound, int checkpointInterval, int numOfRounds) throws IOException {
        int end = firstRound + numOfRounds;
        int lastRound = end - 1;
        if (lastRound < 0) {
            throw new IllegalArgumentException("Invalid `numOfRounds`.");
        }
        List<Integer> rounds = new ArrayList<>();
        if (checkpointInterval > 0) {
            for (int round = firstRound; round < end; round += checkpointInterval) {
                rounds.add(round);
            }
        }
        // Rounds are added in increasing order and never exceed lastRound, so lastRound can
        // only already be present as the final element.
        if (rounds.isEmpty() || rounds.get(rounds.size() - 1) != lastRound) {
            rounds.add(lastRound);
        }
        return rounds;
    }
}

