/*
 * Decompiled with CFR 0.152.
 */
package ml.dmlc.xgboost4j.java;

import java.io.IOException;
import java.io.InputStream;
import java.io.OutputStream;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.Comparator;
import java.util.List;
import java.util.UUID;
import java.util.stream.Collectors;
import ml.dmlc.xgboost4j.java.Booster;
import ml.dmlc.xgboost4j.java.XGBoost;
import ml.dmlc.xgboost4j.java.XGBoostError;
import org.apache.commons.logging.Log;
import org.apache.commons.logging.LogFactory;
import org.apache.hadoop.fs.FSDataInputStream;
import org.apache.hadoop.fs.FSDataOutputStream;
import org.apache.hadoop.fs.FileSystem;
import org.apache.hadoop.fs.Path;

/**
 * Saves and restores XGBoost {@link Booster} checkpoints in a directory on a Hadoop
 * {@link FileSystem}.
 *
 * <p>Each checkpoint is stored as {@code <checkpointPath>/<version>.model}, where {@code version}
 * is {@link Booster#getVersion()} at save time. The round-based helpers in this class convert a
 * version to a boosting round by integer division by two.
 */
public class ExternalCheckpointManager {
    private final Log logger = LogFactory.getLog("ExternalCheckpointManager");
    /** File-name suffix shared by every checkpoint file. */
    private final String modelSuffix = ".model";
    /** Directory that holds the checkpoint files. */
    private final Path checkpointPath;
    /** File system used for every read, write, rename and delete. */
    private final FileSystem fs;

    /**
     * Creates a manager that keeps checkpoints under {@code checkpointPath} on {@code fileSystem}.
     *
     * @param checkpointPath directory for checkpoint files; must be non-null and non-empty
     * @param fileSystem file system on which {@code checkpointPath} lives; must be non-null
     * @throws XGBoostError if either argument is missing
     */
    public ExternalCheckpointManager(String checkpointPath, FileSystem fileSystem) throws XGBoostError {
        if (checkpointPath == null || checkpointPath.isEmpty()) {
            throw new XGBoostError("cannot create ExternalCheckpointManager with null or empty checkpoint path");
        }
        // Fail fast here instead of with an NPE on the first file-system call.
        if (fileSystem == null) {
            throw new XGBoostError("cannot create ExternalCheckpointManager with null file system");
        }
        this.checkpointPath = new Path(checkpointPath);
        this.fs = fileSystem;
    }

    /** Returns the path string of the checkpoint file for {@code version}. */
    private String getPath(int version) {
        return checkpointPath.toUri().getPath() + "/" + version + modelSuffix;
    }

    /**
     * Parses the version encoded in a checkpoint file name.
     *
     * @return the version, or {@code null} if {@code fileName} is not a checkpoint file
     */
    private Integer parseVersion(String fileName) {
        if (!fileName.endsWith(modelSuffix)) {
            return null;
        }
        String versionText = fileName.substring(0, fileName.length() - modelSuffix.length());
        try {
            return Integer.valueOf(versionText);
        } catch (NumberFormatException e) {
            // A foreign "*.model" file in the directory must not break checkpoint handling.
            logger.warn("ignoring non-checkpoint file " + fileName + " in " + checkpointPath);
            return null;
        }
    }

    /**
     * Lists the versions of all checkpoint files currently in the checkpoint directory.
     *
     * @return a mutable list of versions, empty if the directory does not exist
     */
    private List<Integer> getExistingVersions() throws IOException {
        if (!fs.exists(checkpointPath)) {
            return new ArrayList<>();
        }
        return Arrays.stream(fs.listStatus(checkpointPath))
            .map(status -> parseVersion(status.getPath().getName()))
            .filter(version -> version != null)
            .collect(Collectors.toList());
    }

    /** Recursively deletes the whole checkpoint directory. */
    public void cleanPath() throws IOException {
        fs.delete(checkpointPath, true);
    }

    /**
     * Loads the checkpoint with the highest version.
     *
     * @return the restored booster, with its version set to the checkpoint's version, or
     *     {@code null} if no checkpoint exists
     */
    public Booster loadCheckpointAsBooster() throws IOException, XGBoostError {
        List<Integer> versions = getExistingVersions();
        if (versions.isEmpty()) {
            return null;
        }
        int latestVersion = versions.stream().max(Integer::compare).get();
        String modelPath = getPath(latestVersion);
        Booster booster;
        // Close the input stream once the model has been read; it was previously leaked.
        try (FSDataInputStream in = fs.open(new Path(modelPath))) {
            booster = XGBoost.loadModel(in);
        }
        logger.info("loaded checkpoint from " + modelPath);
        booster.setVersion(latestVersion);
        return booster;
    }

    /**
     * Writes {@code booster} as the checkpoint for its current version and then deletes all
     * other checkpoints.
     *
     * <p>The model is first written to a uniquely named temporary file. That file is closed and
     * then renamed to its final name. Older checkpoints are removed only after the rename has
     * succeeded, so a failure never leaves the directory without a complete checkpoint.
     *
     * @throws IOException if writing or renaming the new checkpoint fails
     */
    public void updateCheckpoint(Booster booster) throws IOException, XGBoostError {
        int version = booster.getVersion();
        String finalPathName = getPath(version);
        // Snapshot the older checkpoints up front. Exclude the target path so a same-version
        // checkpoint that is about to be replaced is not deleted after the rename.
        List<String> outdatedPaths = getExistingVersions().stream()
            .map(this::getPath)
            .filter(path -> !path.equals(finalPathName))
            .collect(Collectors.toList());

        Path tempPath = new Path(finalPathName + "-" + UUID.randomUUID());
        Path finalPath = new Path(finalPathName);
        // Close (and flush) the temporary file BEFORE renaming it. Renaming inside the
        // try-with-resources could publish an incomplete model if close() later fails.
        try (FSDataOutputStream out = fs.create(tempPath, true)) {
            booster.saveModel(out);
        }
        // FileSystem.rename does not overwrite an existing destination on every implementation
        // (e.g. HDFS), so drop a stale checkpoint with the same version first.
        if (fs.exists(finalPath)) {
            fs.delete(finalPath, true);
        }
        if (!fs.rename(tempPath, finalPath)) {
            try {
                fs.delete(tempPath, true);
            } catch (IOException cleanupError) {
                logger.error("failed to delete temporary checkpoint at " + tempPath, cleanupError);
            }
            throw new IOException("failed to rename checkpoint " + tempPath + " to " + finalPath);
        }
        logger.info("saving checkpoint with version " + version);
        for (String path : outdatedPaths) {
            try {
                fs.delete(new Path(path), true);
            } catch (IOException e) {
                // Best effort: a leftover old checkpoint is harmless because the newest one is used.
                logger.error("failed to delete outdated checkpoint at " + path, e);
            }
        }
    }

    /**
     * Deletes every checkpoint whose round ({@code version / 2}) is at or beyond
     * {@code currentRound}. Deletion failures are logged and do not stop the cleanup.
     */
    public void cleanUpHigherVersions(int currentRound) throws IOException {
        for (int version : getExistingVersions()) {
            if (version / 2 < currentRound) {
                continue;
            }
            try {
                fs.delete(new Path(getPath(version)), true);
            } catch (IOException e) {
                logger.error("failed to clean checkpoint from other training instance", e);
            }
        }
    }

    /**
     * Computes the boosting rounds at which a checkpoint should be written.
     *
     * <p>If {@code checkpointInterval > 0}, the result holds every multiple of the interval after
     * the latest already-checkpointed round (or after round 0) up to {@code numOfRounds}. In every
     * case the list ends with {@code numOfRounds}, and that value is never listed twice.
     *
     * @param checkpointInterval rounds between checkpoints; non-positive disables intermediate ones
     * @param numOfRounds total number of boosting rounds
     * @return ascending list of checkpoint rounds, always ending with {@code numOfRounds}
     */
    public List<Integer> getCheckpointRounds(int checkpointInterval, int numOfRounds) throws IOException {
        List<Integer> rounds = new ArrayList<>();
        if (checkpointInterval > 0) {
            int lastCheckpointedRound = 0;
            for (int version : getExistingVersions()) {
                lastCheckpointedRound = Math.max(lastCheckpointedRound, version / 2);
            }
            // A long counter prevents int overflow (and an endless loop) near Integer.MAX_VALUE.
            for (long round = (long) lastCheckpointedRound + checkpointInterval;
                 round <= numOfRounds;
                 round += checkpointInterval) {
                rounds.add((int) round);
            }
        }
        // The final round is always checkpointed; skip it if the interval already produced it.
        if (rounds.isEmpty() || rounds.get(rounds.size() - 1).intValue() != numOfRounds) {
            rounds.add(numOfRounds);
        }
        return rounds;
    }
}

