/*
 * Decompiled with CFR 0.152.
 */
package ai.h2o.xgboost4j.java;

import ai.h2o.xgboost4j.java.Booster;
import ai.h2o.xgboost4j.java.XGBoost;
import ai.h2o.xgboost4j.java.XGBoostError;
import java.io.IOException;
import java.io.InputStream;
import java.io.OutputStream;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.Comparator;
import java.util.List;
import java.util.UUID;
import java.util.stream.Collectors;
import org.apache.commons.logging.Log;
import org.apache.commons.logging.LogFactory;
import org.apache.hadoop.fs.FSDataInputStream;
import org.apache.hadoop.fs.FSDataOutputStream;
import org.apache.hadoop.fs.FileSystem;
import org.apache.hadoop.fs.Path;

/**
 * Stores and restores XGBoost {@link Booster} checkpoints in a directory on a Hadoop
 * {@link FileSystem}.
 *
 * <p>Each checkpoint is a file named {@code <version>.model} inside the checkpoint directory,
 * where {@code version} is the value of {@code Booster#getVersion()} at save time. Throughout
 * this class a version is mapped to a training round as {@code version / 2} (presumably because
 * the booster version advances twice per boosting round -- confirm against {@code Booster}).
 */
public class ExternalCheckpointManager {
    /** Logger name kept identical to the historical string-based name. */
    private static final Log LOGGER = LogFactory.getLog("ExternalCheckpointManager");
    /** Suffix marking a checkpoint file; the part before it is the booster version. */
    private static final String MODEL_SUFFIX = ".model";

    /** Directory that holds the checkpoint files. */
    private final Path checkpointPath;
    /** File system used for all checkpoint I/O. */
    private final FileSystem fs;

    /**
     * Creates a manager for checkpoints stored under {@code checkpointPath}.
     *
     * @param checkpointPath directory holding the checkpoint files; must be non-null and non-empty
     * @param fs file system on which {@code checkpointPath} lives
     * @throws XGBoostError if {@code checkpointPath} is null or empty
     */
    public ExternalCheckpointManager(String checkpointPath, FileSystem fs) throws XGBoostError {
        if (checkpointPath == null || checkpointPath.isEmpty()) {
            throw new XGBoostError("cannot create ExternalCheckpointManager with null or empty checkpoint path");
        }
        this.checkpointPath = new Path(checkpointPath);
        this.fs = fs;
    }

    /**
     * Builds the file path (URI path component only, no scheme/authority) of the checkpoint
     * for {@code version}.
     */
    private String getPath(int version) {
        return this.checkpointPath.toUri().getPath() + "/" + version + MODEL_SUFFIX;
    }

    /**
     * Extracts the version from a checkpoint file name of the form {@code <integer>.model}.
     *
     * @return the parsed version, or {@code null} if {@code fileName} is not such a name
     */
    private static Integer parseVersion(String fileName) {
        if (!fileName.endsWith(MODEL_SUFFIX)) {
            return null;
        }
        String stem = fileName.substring(0, fileName.length() - MODEL_SUFFIX.length());
        try {
            return Integer.valueOf(stem);
        } catch (NumberFormatException ignored) {
            // A foreign file such as "best.model" must not break checkpoint discovery.
            return null;
        }
    }

    /**
     * Lists the versions of all checkpoints currently in the checkpoint directory, in
     * directory-listing order.
     *
     * @return the versions found; empty if the directory does not exist. Files whose names are
     *     not {@code <integer>.model} (including in-flight temp files) are ignored.
     * @throws IOException if the file system cannot be queried
     */
    private List<Integer> getExistingVersions() throws IOException {
        if (!this.fs.exists(this.checkpointPath)) {
            return new ArrayList<>();
        }
        return Arrays.stream(this.fs.listStatus(this.checkpointPath))
                .map(status -> status.getPath().getName())
                .map(ExternalCheckpointManager::parseVersion)
                .filter(version -> version != null)
                .collect(Collectors.toList());
    }

    /**
     * Recursively deletes {@code path}, logging instead of propagating any {@link IOException}.
     * Used where a failed delete should not abort the surrounding operation.
     */
    private void tryDelete(Path path, String failureMessage) {
        try {
            this.fs.delete(path, true);
        } catch (IOException e) {
            LOGGER.error(failureMessage, e);
        }
    }

    /**
     * Recursively deletes the checkpoint directory and everything in it.
     *
     * @throws IOException if the delete fails
     */
    public void cleanPath() throws IOException {
        this.fs.delete(this.checkpointPath, true);
    }

    /**
     * Loads the checkpoint with the highest version.
     *
     * @return the loaded booster with its version set to that checkpoint's version, or
     *     {@code null} if no checkpoint exists
     * @throws IOException if the checkpoint cannot be listed or read
     * @throws XGBoostError if the model cannot be deserialized
     */
    public Booster loadCheckpointAsBooster() throws IOException, XGBoostError {
        List<Integer> versions = this.getExistingVersions();
        if (versions.isEmpty()) {
            return null;
        }
        // Non-empty list, so max() is always present.
        int latestVersion = versions.stream().mapToInt(Integer::intValue).max().getAsInt();
        String modelPath = this.getPath(latestVersion);
        Booster booster;
        // try-with-resources closes the stream even if loading fails.
        try (FSDataInputStream in = this.fs.open(new Path(modelPath))) {
            booster = XGBoost.loadModel(in);
        }
        LOGGER.info("loaded checkpoint from " + modelPath);
        booster.setVersion(latestVersion);
        return booster;
    }

    /**
     * Saves {@code boosterToCheckpoint} as the checkpoint for its current version, then removes
     * all previously existing checkpoints.
     *
     * <p>The model is first written to a uniquely named temporary file. That file is closed and
     * then renamed into place, so a reader never sees a partially written checkpoint. Older
     * checkpoints are only removed after the rename succeeds.
     *
     * @param boosterToCheckpoint booster to persist
     * @throws IOException if writing fails or the temporary file cannot be renamed into place
     *     (for example because a checkpoint with the same version already exists on a file
     *     system whose rename does not overwrite)
     * @throws XGBoostError if the model cannot be serialized
     */
    public void updateCheckpoint(Booster boosterToCheckpoint) throws IOException, XGBoostError {
        int version = boosterToCheckpoint.getVersion();
        List<String> prevModelPaths = this.getExistingVersions().stream()
                .map(this::getPath)
                .collect(Collectors.toList());
        String eventualPath = this.getPath(version);
        Path tempPath = new Path(eventualPath + "-" + UUID.randomUUID());

        boolean published = false;
        try {
            // Close (and thereby flush) the stream before renaming the file.
            try (FSDataOutputStream out = this.fs.create(tempPath, true)) {
                boosterToCheckpoint.saveModel(out);
            }
            published = this.fs.rename(tempPath, new Path(eventualPath));
        } finally {
            if (!published) {
                // Do not leave orphaned temp files behind on failure.
                this.tryDelete(tempPath, "failed to delete temporary checkpoint at " + tempPath);
            }
        }
        if (!published) {
            // Keep the older checkpoints: without the new one they are the only recovery point.
            throw new IOException("failed to rename temporary checkpoint " + tempPath + " to " + eventualPath);
        }

        LOGGER.info("saving checkpoint with version " + version);
        for (String path : prevModelPaths) {
            // Never delete the file just published (same version may have existed before).
            if (!path.equals(eventualPath)) {
                this.tryDelete(new Path(path), "failed to delete outdated checkpoint at " + path);
            }
        }
    }

    /**
     * Deletes every checkpoint whose round ({@code version / 2}) is at or beyond
     * {@code currentRound}. Individual delete failures are logged, not thrown.
     *
     * @param currentRound first round whose checkpoints should be removed
     * @throws IOException if the checkpoint directory cannot be listed
     */
    public void cleanUpHigherVersions(int currentRound) throws IOException {
        for (int version : this.getExistingVersions()) {
            if (version / 2 >= currentRound) {
                this.tryDelete(new Path(this.getPath(version)),
                        "failed to clean checkpoint from other training instance");
            }
        }
    }

    /**
     * Computes the training rounds at which a checkpoint should be written.
     *
     * <p>With a positive {@code checkpointInterval}, the rounds start {@code checkpointInterval}
     * after the highest already-checkpointed round (or after round 0 if there is none). They
     * continue every {@code checkpointInterval} rounds up to {@code numOfRounds}. With a
     * non-positive interval, only {@code numOfRounds} is returned. In both cases
     * {@code numOfRounds} is always the last element, and it appears exactly once.
     *
     * @param checkpointInterval rounds between checkpoints; {@code <= 0} disables periodic ones
     * @param numOfRounds total number of training rounds
     * @return checkpoint rounds in increasing order (a new, mutable list)
     * @throws IOException if existing checkpoints cannot be listed
     */
    public List<Integer> getCheckpointRounds(int checkpointInterval, int numOfRounds) throws IOException {
        List<Integer> rounds = new ArrayList<>();
        if (checkpointInterval > 0) {
            int lastCheckpointedRound = Math.max(0,
                    this.getExistingVersions().stream().mapToInt(v -> v / 2).max().orElse(0));
            // long index so that "round += checkpointInterval" cannot overflow near Integer.MAX_VALUE.
            for (long round = (long) lastCheckpointedRound + checkpointInterval;
                    round <= numOfRounds;
                    round += checkpointInterval) {
                rounds.add((int) round);
            }
        }
        if (rounds.isEmpty() || rounds.get(rounds.size() - 1).intValue() != numOfRounds) {
            rounds.add(numOfRounds);
        }
        return rounds;
    }
}

