/*
 * Decompiled with CFR 0.152.
 */
package org.nd4j.autodiff.listeners.checkpoint;

import java.io.BufferedInputStream;
import java.io.File;
import java.io.FileInputStream;
import java.io.IOException;
import java.io.Serializable;
import java.nio.charset.StandardCharsets;
import java.text.SimpleDateFormat;
import java.util.ArrayList;
import java.util.Collections;
import java.util.Date;
import java.util.Iterator;
import java.util.List;
import java.util.concurrent.TimeUnit;
import lombok.NonNull;
import org.apache.commons.io.IOUtils;
import org.nd4j.autodiff.listeners.At;
import org.nd4j.autodiff.listeners.BaseListener;
import org.nd4j.autodiff.listeners.ListenerResponse;
import org.nd4j.autodiff.listeners.Loss;
import org.nd4j.autodiff.listeners.Operation;
import org.nd4j.autodiff.listeners.checkpoint.Checkpoint;
import org.nd4j.autodiff.listeners.records.LossCurve;
import org.nd4j.autodiff.samediff.SameDiff;
import org.nd4j.common.base.Preconditions;
import org.nd4j.linalg.dataset.api.MultiDataSet;
import org.nd4j.shade.guava.io.Files;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

/**
 * Training listener that periodically saves a {@link SameDiff} model to disk as a checkpoint.
 * <p>
 * Saving can be triggered every N epochs, every N iterations, and/or every T units of wall-clock time;
 * at least one trigger must be configured via {@link Builder}. Each model is written into the root
 * directory with a file name of the form
 * {@code [prefix_]checkpoint-<num>_epoch-<epoch>_iter-<iteration>_<yyyy-MM-dd_HH-mm-ss>.bin}, and one line
 * describing it ({@link Checkpoint#toFileString()}) is appended to {@code checkpointInfo.txt} in the same
 * directory (the first line of that file is a header, {@link Checkpoint#getFileHeader()}).
 * <p>
 * After every save, older model files may be deleted according to the configured keep mode: keep all
 * (the default), keep only the last N, or keep the last N plus every M-th checkpoint.
 * <p>
 * The listener is only active for {@link Operation#TRAINING}.
 */
public class CheckpointListener
extends BaseListener
implements Serializable {
    private static final Logger log = LoggerFactory.getLogger(CheckpointListener.class);

    /** Name of the record file (inside the root directory) listing the saved checkpoints. */
    private static final String CHECKPOINT_RECORD_FILE_NAME = "checkpointInfo.txt";
    /** Extension of the model files written by this listener (see {@link #getFileName}). */
    private static final String MODEL_FILE_EXTENSION = ".bin";

    private File rootDir;
    private String fileNamePrefix;
    private KeepMode keepMode;
    private int keepLast;
    private int keepEvery;
    private boolean logSaving;
    private boolean deleteExisting;
    private boolean saveUpdaterState;
    private Integer saveEveryNEpochs;
    private Integer saveEveryNIterations;
    private boolean saveEveryNIterSinceLast;
    private Long saveEveryAmount;
    private TimeUnit saveEveryUnit;
    // saveEveryAmount converted to milliseconds; null when no time-based trigger is configured
    private Long saveEveryMs;
    private boolean saveEverySinceLast;
    // Number of the most recently saved checkpoint; -1 before the first save
    private int lastCheckpointNum = -1;
    private File checkpointRecordFile;
    private Checkpoint lastCheckpoint;
    // Wall-clock time and iteration of the first iterationDone call; -1 until then
    private long startTime = -1L;
    private int startIter = -1;
    // Time of the last save made by the time trigger in non-"since last" mode
    private Long lastSaveEveryMsNoSinceLast;

    private CheckpointListener(Builder builder) {
        this.rootDir = builder.rootDir;
        this.fileNamePrefix = builder.fileNamePrefix;
        this.keepMode = builder.keepMode;
        this.keepLast = builder.keepLast;
        this.keepEvery = builder.keepEvery;
        this.logSaving = builder.logSaving;
        this.deleteExisting = builder.deleteExisting;
        this.saveUpdaterState = builder.saveUpdaterState;
        this.saveEveryNEpochs = builder.saveEveryNEpochs;
        this.saveEveryNIterations = builder.saveEveryNIterations;
        this.saveEveryNIterSinceLast = builder.saveEveryNIterSinceLast;
        this.saveEveryAmount = builder.saveEveryAmount;
        this.saveEveryUnit = builder.saveEveryUnit;
        this.saveEverySinceLast = builder.saveEverySinceLast;
        if (this.saveEveryAmount != null) {
            this.saveEveryMs = TimeUnit.MILLISECONDS.convert(this.saveEveryAmount, this.saveEveryUnit);
        }
        // mkdirs (not mkdir) so that a root directory whose parents do not exist yet also works.
        // The isDirectory() re-check tolerates the directory having been created concurrently.
        if (!this.rootDir.exists() && !this.rootDir.mkdirs() && !this.rootDir.isDirectory()) {
            throw new IllegalStateException("Could not create checkpoint directory: " + this.rootDir.getAbsolutePath());
        }
        this.checkpointRecordFile = new File(this.rootDir, CHECKPOINT_RECORD_FILE_NAME);
        if (this.checkpointRecordFile.exists() && this.checkpointRecordFile.length() > 0L) {
            if (!this.deleteExisting) {
                throw new IllegalStateException("Detected existing checkpoint files at directory " + this.rootDir.getAbsolutePath() + ". Use deleteExisting(true) to delete existing checkpoint files when present.");
            }
            this.checkpointRecordFile.delete();
            // Checkpoint numbering restarts at 0, so old model files must go too: otherwise
            // getFileForCheckpoint(n) could return a stale file from the previous run.
            CheckpointListener.deleteCheckpointModelFiles(this.rootDir);
        }
    }

    /**
     * Deletes files in {@code dir} that look like checkpoint model files: those written by this listener
     * (see {@link #isSameDiffCheckpointFileName}) and, for backward compatibility, the DL4J-style
     * {@code checkpoint_*MultiLayerNetwork.zip} / {@code checkpoint_*ComputationGraph.zip} files.
     */
    private static void deleteCheckpointModelFiles(File dir) {
        File[] files = dir.listFiles();
        if (files == null) {
            return;
        }
        for (File f : files) {
            String name = f.getName();
            boolean legacyDl4jFile = name.startsWith("checkpoint_")
                    && (name.endsWith("MultiLayerNetwork.zip") || name.endsWith("ComputationGraph.zip"));
            if (legacyDl4jFile || CheckpointListener.isSameDiffCheckpointFileName(name)) {
                f.delete();
            }
        }
    }

    /**
     * Returns true if {@code name} matches the model file name pattern produced by {@link #getFileName}:
     * {@code [prefix_]checkpoint-<num>_epoch-...<MODEL_FILE_EXTENSION>}.
     */
    private static boolean isSameDiffCheckpointFileName(String name) {
        return name.endsWith(MODEL_FILE_EXTENSION)
                && (name.startsWith("checkpoint-") || name.contains("_checkpoint-"))
                && name.contains("_epoch-");
    }

    /**
     * Saves a checkpoint at the end of every epoch for which {@code (epoch + 1) % saveEveryNEpochs == 0},
     * when an epoch trigger is configured.
     */
    @Override
    public ListenerResponse epochEnd(SameDiff sameDiff, At at, LossCurve lossCurve, long epochTimeMillis) {
        if (this.saveEveryNEpochs != null && (at.epoch() + 1) % this.saveEveryNEpochs == 0) {
            this.saveCheckpoint(sameDiff, at);
        }
        return ListenerResponse.CONTINUE;
    }

    /** This listener only acts during training. */
    @Override
    public boolean isActive(Operation operation) {
        return operation == Operation.TRAINING;
    }

    /**
     * Applies the iteration- and time-based save triggers. At most one checkpoint is saved per call.
     * The first call only records the start time and iteration (used as the reference point when no
     * checkpoint has been saved yet) and never saves.
     */
    @Override
    public void iterationDone(SameDiff sd, At at, MultiDataSet dataSet, Loss loss) {
        if (this.startTime < 0L) {
            this.startTime = System.currentTimeMillis();
            this.startIter = at.iteration();
            return;
        }
        if (this.saveEveryNIterations != null) {
            if (this.saveEveryNIterSinceLast) {
                // Count iterations since the last checkpoint saved by ANY trigger (or since the start)
                long lastSaveIter = this.lastCheckpoint != null ? this.lastCheckpoint.getIteration() : this.startIter;
                if ((long) at.iteration() - lastSaveIter >= (long) this.saveEveryNIterations.intValue()) {
                    this.saveCheckpoint(sd, at);
                    return;
                }
            } else if ((at.iteration() + 1) % this.saveEveryNIterations == 0) {
                this.saveCheckpoint(sd, at);
                return;
            }
        }
        if (this.saveEveryUnit != null) {
            long time = System.currentTimeMillis();
            if (this.saveEverySinceLast) {
                // Elapsed time since the last checkpoint saved by ANY trigger (or since the start)
                long lastSaveTime = this.lastCheckpoint != null ? this.lastCheckpoint.getTimestamp() : this.startTime;
                if (time - lastSaveTime >= this.saveEveryMs) {
                    this.saveCheckpoint(sd, at);
                }
            } else {
                // Elapsed time since the last save made by this time trigger only (or since the start)
                long lastSave = this.lastSaveEveryMsNoSinceLast != null ? this.lastSaveEveryMsNoSinceLast : this.startTime;
                if (time - lastSave > this.saveEveryMs) {
                    this.saveCheckpoint(sd, at);
                    this.lastSaveEveryMsNoSinceLast = time;
                }
            }
        }
    }

    /** Saves a checkpoint, wrapping any failure in a RuntimeException. */
    private void saveCheckpoint(SameDiff sd, At at) {
        try {
            this.saveCheckpointHelper(sd, at);
        }
        catch (Exception e) {
            throw new RuntimeException("Error saving checkpoint", e);
        }
    }

    /**
     * Writes the model file, appends its record line, optionally logs, and then prunes old model files
     * according to the keep mode.
     */
    private void saveCheckpointHelper(SameDiff model, At at) throws Exception {
        // Write the header if the record file is missing OR empty; otherwise the first checkpoint line
        // would later be skipped by availableCheckpoints(), which treats line 0 as the header.
        if (!this.checkpointRecordFile.exists() || this.checkpointRecordFile.length() == 0L) {
            CheckpointListener.writeCheckpointInfo(Checkpoint.getFileHeader() + "\n", this.checkpointRecordFile);
        }
        Checkpoint c = new Checkpoint(++this.lastCheckpointNum, System.currentTimeMillis(), at.iteration(), at.epoch(), null);
        c.setFilename(this.getFileName(this.lastCheckpointNum, at, c.getTimestamp()));
        File saveFile = new File(this.rootDir, c.getFilename());
        model.save(saveFile, this.saveUpdaterState);
        CheckpointListener.writeCheckpointInfo(c.toFileString() + "\n", this.checkpointRecordFile);
        if (this.logSaving) {
            log.info("Model checkpoint saved: epoch {}, iteration {}, path: {}", c.getEpoch(), c.getIteration(), saveFile.getPath());
        }
        this.lastCheckpoint = c;
        this.pruneOldCheckpoints();
    }

    /**
     * Deletes model files that the keep mode says should not be retained. Checkpoints are processed in
     * record-file order (oldest first).
     */
    private void pruneOldCheckpoints() {
        if (this.keepMode == null || this.keepMode == KeepMode.ALL) {
            return;
        }
        List<Checkpoint> checkpoints = this.availableCheckpoints();
        if (this.keepMode == KeepMode.LAST) {
            int numToDelete = checkpoints.size() - this.keepLast;
            for (int i = 0; i < numToDelete; i++) {
                this.deleteModelFile(checkpoints.get(i));
            }
            return;
        }
        // LAST_AND_EVERY: keep the every-N-th checkpoint (numbers keepEvery-1, 2*keepEvery-1, ...) and the
        // keepLast most recent ones. Example: latest is 5, keepLast 2 -> 4 and 5 are "recent".
        for (Checkpoint cp : checkpoints) {
            int num = cp.getCheckpointNum();
            boolean everyN = (num + 1) % this.keepEvery == 0;
            boolean recent = num > this.lastCheckpointNum - this.keepLast;
            if (!everyN && !recent) {
                this.deleteModelFile(cp);
            }
        }
    }

    /**
     * Deletes the model file recorded for {@code checkpoint}. The recorded file name is used directly
     * (rather than a name search) so this works for any file name prefix, including none.
     */
    private void deleteModelFile(Checkpoint checkpoint) {
        File f = new File(this.rootDir, checkpoint.getFilename());
        if (!f.delete() && f.exists()) {
            log.warn("Failed to delete old checkpoint model file: {}", f.getAbsolutePath());
        }
    }

    /**
     * Builds the model file name:
     * {@code [prefix_]checkpoint-<num>_epoch-<epoch>_iter-<iteration>_<yyyy-MM-dd_HH-mm-ss>.bin}.
     * An underscore is appended to the prefix unless it already ends with one.
     */
    private String getFileName(int checkpointNum, At at, long time) {
        StringBuilder sb = new StringBuilder();
        if (this.fileNamePrefix != null) {
            sb.append(this.fileNamePrefix);
            if (!this.fileNamePrefix.endsWith("_")) {
                sb.append("_");
            }
        }
        sb.append("checkpoint-").append(checkpointNum).append("_epoch-").append(at.epoch()).append("_iter-").append(at.iteration());
        // Lowercase "yyyy" is the calendar year; uppercase "YYYY" is the week-based year, which gives the
        // wrong year for dates in the first/last days of a year. A local instance is used because
        // SimpleDateFormat is not thread-safe.
        SimpleDateFormat sdf = new SimpleDateFormat("yyyy-MM-dd_HH-mm-ss");
        String date = sdf.format(new Date(time));
        sb.append("_").append(date).append(MODEL_FILE_EXTENSION);
        return sb.toString();
    }

    /**
     * Appends {@code str} (UTF-8) to file {@code f}, creating the file if needed.
     *
     * @return {@code str}
     * @throws RuntimeException wrapping any IOException
     */
    private static String writeCheckpointInfo(String str, File f) {
        try {
            if (!f.exists()) {
                f.createNewFile();
            }
            Files.append(str, f, StandardCharsets.UTF_8);
        }
        catch (IOException e) {
            throw new RuntimeException(e);
        }
        return str;
    }

    /**
     * Returns the checkpoints recorded by this listener whose model files still exist, oldest first.
     * Returns an empty list if nothing has been saved yet (no record file).
     */
    public List<Checkpoint> availableCheckpoints() {
        if (!this.checkpointRecordFile.exists()) {
            return Collections.emptyList();
        }
        return CheckpointListener.availableCheckpoints(this.rootDir);
    }

    /**
     * Reads {@code checkpointInfo.txt} in {@code directory} and returns the recorded checkpoints whose
     * model files still exist, in file order (oldest first). The first line is treated as a header;
     * blank lines are ignored.
     *
     * @throws IllegalStateException if the record file does not exist
     * @throws RuntimeException      if the record file cannot be read
     */
    public static List<Checkpoint> availableCheckpoints(File directory) {
        List<String> lines;
        File checkpointRecordFile = new File(directory, CHECKPOINT_RECORD_FILE_NAME);
        Preconditions.checkState(checkpointRecordFile.exists(), "Could not find checkpoint record file at expected path %s", (Object)checkpointRecordFile.getAbsolutePath());
        try (BufferedInputStream is = new BufferedInputStream(new FileInputStream(checkpointRecordFile))) {
            // The record file is written as UTF-8, so read it as UTF-8, not the platform default charset
            lines = IOUtils.readLines(is, StandardCharsets.UTF_8);
        }
        catch (IOException e) {
            throw new RuntimeException("Error loading checkpoint data from file: " + checkpointRecordFile.getAbsolutePath(), e);
        }
        // Math.max guards against an empty record file (size 0 -> capacity -1 would throw)
        List<Checkpoint> out = new ArrayList<Checkpoint>(Math.max(0, lines.size() - 1));
        for (int i = 1; i < lines.size(); ++i) {
            String line = lines.get(i);
            if (line.trim().isEmpty()) {
                continue;
            }
            Checkpoint c = Checkpoint.fromFileString(line);
            if (new File(directory, c.getFilename()).exists()) {
                out.add(c);
            }
        }
        return out;
    }

    /** Returns the most recent available checkpoint, or null if none has been saved. */
    public Checkpoint lastCheckpoint() {
        if (!this.checkpointRecordFile.exists()) {
            return null;
        }
        return CheckpointListener.lastCheckpoint(this.rootDir);
    }

    /**
     * Returns the most recent available checkpoint in {@code rootDir}, or null if the record lists none
     * whose model file exists.
     *
     * @throws IllegalStateException if the record file does not exist
     */
    public static Checkpoint lastCheckpoint(File rootDir) {
        List<Checkpoint> all = CheckpointListener.availableCheckpoints(rootDir);
        if (all.isEmpty()) {
            return null;
        }
        return all.get(all.size() - 1);
    }

    /** Returns the model file for the given checkpoint (looked up by its checkpoint number). */
    public File getFileForCheckpoint(Checkpoint checkpoint) {
        return this.getFileForCheckpoint(checkpoint.getCheckpointNum());
    }

    /** Returns the model file for the given checkpoint number in this listener's root directory. */
    public File getFileForCheckpoint(int checkpointNum) {
        return CheckpointListener.getFileForCheckpoint(this.rootDir, checkpointNum);
    }

    /**
     * Finds the model file for {@code checkpointNum} in {@code rootDir}. This matches file names of the
     * form {@code checkpoint-<num>_epoch-...} (no prefix) or {@code ..._checkpoint-<num>_epoch-...}
     * (with prefix).
     *
     * @throws IllegalArgumentException if {@code checkpointNum < 0}
     * @throws IllegalStateException    if no matching file exists
     */
    public static File getFileForCheckpoint(File rootDir, int checkpointNum) {
        if (checkpointNum < 0) {
            throw new IllegalArgumentException("Invalid checkpoint number: " + checkpointNum);
        }
        // The "_epoch-" suffix prevents e.g. checkpoint 1 from matching checkpoint 11.
        String marker = "checkpoint-" + checkpointNum + "_epoch-";
        String prefixedMarker = "_" + marker;
        File[] allFiles = rootDir.listFiles();
        if (allFiles != null) {
            for (File f : allFiles) {
                // Match on the file name only: the directory path must not produce false hits
                String name = f.getName();
                if (name.startsWith(marker) || name.contains(prefixedMarker)) {
                    return f;
                }
            }
        }
        throw new IllegalStateException("Model file for checkpoint " + checkpointNum + " does not exist");
    }

    /** Loads the model saved as checkpoint {@code checkpointNum} from this listener's root directory. */
    public SameDiff loadCheckpoint(int checkpointNum, boolean loadUpdaterState) {
        return CheckpointListener.loadCheckpoint(this.rootDir, checkpointNum, loadUpdaterState);
    }

    /** Loads the model saved as checkpoint {@code checkpointNum} from {@code rootDir}. */
    public static SameDiff loadCheckpoint(File rootDir, int checkpointNum, boolean loadUpdaterState) {
        File f = CheckpointListener.getFileForCheckpoint(rootDir, checkpointNum);
        return SameDiff.load(f, loadUpdaterState);
    }

    /**
     * Loads the most recent available checkpoint from {@code rootDir}.
     *
     * @throws IllegalStateException if no checkpoint is available
     */
    public static SameDiff loadLastCheckpoint(File rootDir, boolean loadUpdaterState) {
        Checkpoint last = CheckpointListener.lastCheckpoint(rootDir);
        if (last == null) {
            throw new IllegalStateException("No checkpoints available in directory " + rootDir.getAbsolutePath());
        }
        return CheckpointListener.loadCheckpoint(rootDir, last.getCheckpointNum(), loadUpdaterState);
    }

    /** Creates a builder for a listener that saves into {@code rootDir}. */
    public static Builder builder(@NonNull File rootDir) {
        if (rootDir == null) {
            throw new NullPointerException("rootDir is marked non-null but is null");
        }
        return new Builder(rootDir);
    }

    /**
     * Builder for {@link CheckpointListener}. Defaults: file name prefix {@code "SameDiff"}, keep all
     * checkpoints, save updater state, log saving, do not delete existing checkpoints.
     */
    public static class Builder {
        private File rootDir;
        private String fileNamePrefix = "SameDiff";
        private KeepMode keepMode;
        private int keepLast;
        private int keepEvery;
        private boolean saveUpdaterState = true;
        private boolean logSaving = true;
        private boolean deleteExisting = false;
        private Integer saveEveryNEpochs;
        private Integer saveEveryNIterations;
        private boolean saveEveryNIterSinceLast;
        private Long saveEveryAmount;
        private TimeUnit saveEveryUnit;
        private boolean saveEverySinceLast;

        /** @param rootDir path of the directory to save checkpoints into (created if missing) */
        public Builder(@NonNull String rootDir) {
            this(Builder.toFile(rootDir));
        }

        /** @param rootDir directory to save checkpoints into (created if missing) */
        public Builder(@NonNull File rootDir) {
            if (rootDir == null) {
                throw new NullPointerException("rootDir is marked non-null but is null");
            }
            this.rootDir = rootDir;
        }

        /** Null-checks {@code rootDir} before converting it; needed because this(...) must come first. */
        private static File toFile(String rootDir) {
            if (rootDir == null) {
                throw new NullPointerException("rootDir is marked non-null but is null");
            }
            return new File(rootDir);
        }

        /** Sets the model file name prefix (an underscore separator is added if missing); null for none. */
        public Builder fileNamePrefix(String fileNamePrefix) {
            this.fileNamePrefix = fileNamePrefix;
            return this;
        }

        /** Saves a checkpoint at the end of every epoch. */
        public Builder saveEveryEpoch() {
            return this.saveEveryNEpochs(1);
        }

        /**
         * Saves a checkpoint at the end of each epoch where {@code (epoch + 1) % n == 0}.
         *
         * @throws IllegalArgumentException if {@code n <= 0}
         */
        public Builder saveEveryNEpochs(int n) {
            if (n <= 0) {
                throw new IllegalArgumentException("Number of epochs between checkpoints should be > 0 (got: " + n + ")");
            }
            this.saveEveryNEpochs = n;
            return this;
        }

        /** Same as {@code saveEveryNIterations(n, false)}. */
        public Builder saveEveryNIterations(int n) {
            return this.saveEveryNIterations(n, false);
        }

        /**
         * Saves a checkpoint based on iteration count. If {@code sinceLast} is false, saves when
         * {@code (iteration + 1) % n == 0}. If true, saves when at least {@code n} iterations have passed
         * since the last saved checkpoint (or since the first iteration seen).
         *
         * @throws IllegalArgumentException if {@code n <= 0}
         */
        public Builder saveEveryNIterations(int n, boolean sinceLast) {
            if (n <= 0) {
                throw new IllegalArgumentException("Number of iterations between checkpoints should be > 0 (got: " + n + ")");
            }
            this.saveEveryNIterations = n;
            this.saveEveryNIterSinceLast = sinceLast;
            return this;
        }

        /** Same as {@code saveEvery(amount, timeUnit, false)}. */
        public Builder saveEvery(long amount, TimeUnit timeUnit) {
            return this.saveEvery(amount, timeUnit, false);
        }

        /**
         * Saves a checkpoint based on wall-clock time, checked after each iteration. If {@code sinceLast}
         * is true, saves once at least the given time has elapsed since the last saved checkpoint (from
         * any trigger). If false, saves once more than the given time has elapsed since the last
         * time-triggered save (or since the first iteration seen).
         *
         * @throws NullPointerException if {@code timeUnit} is null
         */
        public Builder saveEvery(long amount, TimeUnit timeUnit, boolean sinceLast) {
            if (timeUnit == null) {
                throw new NullPointerException("timeUnit must not be null");
            }
            this.saveEveryAmount = amount;
            this.saveEveryUnit = timeUnit;
            this.saveEverySinceLast = sinceLast;
            return this;
        }

        /** Keeps all checkpoint model files (equivalent to the default). */
        public Builder keepAll() {
            this.keepMode = KeepMode.ALL;
            return this;
        }

        /**
         * Keeps only the {@code n} most recent checkpoint model files.
         *
         * @throws IllegalArgumentException if {@code n <= 0}
         */
        public Builder keepLast(int n) {
            if (n <= 0) {
                throw new IllegalArgumentException("Number of model files to keep should be > 0 (got: " + n + ")");
            }
            this.keepMode = KeepMode.LAST;
            this.keepLast = n;
            return this;
        }

        /**
         * Keeps the {@code nLast} most recent checkpoint model files, plus every {@code everyN}-th
         * checkpoint (checkpoint numbers {@code everyN-1, 2*everyN-1, ...}).
         *
         * @throws IllegalArgumentException if either argument is {@code <= 0}
         */
        public Builder keepLastAndEvery(int nLast, int everyN) {
            if (nLast <= 0) {
                throw new IllegalArgumentException("Most recent number of model files to keep should be > 0 (got: " + nLast + ")");
            }
            if (everyN <= 0) {
                throw new IllegalArgumentException("Every n model files to keep should be > 0 (got: " + everyN + ")");
            }
            this.keepMode = KeepMode.LAST_AND_EVERY;
            this.keepLast = nLast;
            this.keepEvery = everyN;
            return this;
        }

        /** Whether to log (at INFO) each saved checkpoint. Default: true. */
        public Builder logSaving(boolean logSaving) {
            this.logSaving = logSaving;
            return this;
        }

        /** Whether to save the updater state with each model. Default: true. */
        public Builder saveUpdaterState(boolean saveUpdaterState) {
            this.saveUpdaterState = saveUpdaterState;
            return this;
        }

        /**
         * If true, a non-empty {@code checkpointInfo.txt} already in the root directory is deleted, along
         * with the checkpoint model files found there, when the listener is built. If false (default),
         * building fails with IllegalStateException in that case.
         */
        public Builder deleteExisting(boolean deleteExisting) {
            this.deleteExisting = deleteExisting;
            return this;
        }

        /**
         * Builds the listener (creating the root directory if needed).
         *
         * @throws IllegalStateException if no save trigger is configured, or if existing checkpoints are
         *                               present and {@code deleteExisting} is false
         */
        public CheckpointListener build() {
            if (this.saveEveryNEpochs == null && this.saveEveryAmount == null && this.saveEveryNIterations == null) {
                throw new IllegalStateException("Cannot construct listener: no models will be saved (must use at least one of: save every N epochs, every N iterations, or every T time periods)");
            }
            return new CheckpointListener(this);
        }
    }

    /** Policy for which checkpoint model files to retain after each save. */
    private enum KeepMode {
        ALL,
        LAST,
        LAST_AND_EVERY
    }
}

