/*
 * Decompiled with CFR 0.152.
 */
package ai.djl.pytorch.engine;

import ai.djl.BaseModel;
import ai.djl.Device;
import ai.djl.MalformedModelException;
import ai.djl.Model;
import ai.djl.ndarray.types.DataType;
import ai.djl.pytorch.engine.PtNDManager;
import ai.djl.pytorch.jni.JniUtils;
import ai.djl.training.Trainer;
import ai.djl.training.TrainingConfig;
import ai.djl.training.initializer.Initializer;
import ai.djl.util.Pair;
import ai.djl.util.PairList;
import ai.djl.util.Utils;
import java.io.FileNotFoundException;
import java.io.IOException;
import java.io.InputStream;
import java.nio.file.FileVisitOption;
import java.nio.file.Files;
import java.nio.file.LinkOption;
import java.nio.file.Path;
import java.nio.file.attribute.FileAttribute;
import java.util.ArrayList;
import java.util.List;
import java.util.Map;
import java.util.function.Predicate;
import java.util.stream.Collectors;
import java.util.stream.Stream;

/**
 * {@link Model} implementation of the PyTorch engine package.
 *
 * <p>A model is loaded either from a file system location or from an {@link InputStream}; in
 * both cases the resulting block is obtained through {@link JniUtils#loadModule}. When a block
 * has already been set before {@code load} is called, only a parameter file is read.
 */
public class PtModel extends BaseModel {

    /** File extension used to recognise module files. */
    private static final String MODEL_FILE_EXTENSION = ".pt";

    /**
     * Constructs a model backed by a new sub-manager of the PyTorch system manager.
     *
     * @param name the model name
     * @param device the device the sub-manager is created for
     */
    PtModel(String name, Device device) {
        super(name);
        this.manager = PtNDManager.getSystemManager().newSubManager(device);
        this.manager.setName("ptModel");
        this.dataType = DataType.FLOAT32;
    }

    /**
     * Loads the model from {@code modelPath}.
     *
     * <p>If no block has been set, a module file is located (see the private
     * {@code findModelFile}) and loaded through {@link JniUtils#loadModule}. Otherwise the
     * parameter file resolved by {@code paramPathResolver} is read into the existing block,
     * unless the {@code "hasParameter"} option is {@code false}.
     *
     * <p>Recognised options: {@code "extraFiles"} (comma separated keys), {@code "trainParam"},
     * {@code "mapLocation"} and {@code "hasParameter"}. Boolean options are parsed with
     * {@link Boolean#parseBoolean(String)} applied to the value's {@code toString()}, so both
     * {@code String} and {@code Boolean} values are accepted.
     *
     * @param modelPath the directory or file to load from
     * @param prefix the model file name prefix; the model name is used when {@code null}
     * @param options the load options, may be {@code null}
     * @throws FileNotFoundException if no module file can be located
     * @throws IOException if the parameter file cannot be located or read
     * @throws MalformedModelException declared for the parameter reading path
     */
    public void load(Path modelPath, String prefix, Map<String, ?> options) throws IOException, MalformedModelException {
        this.setModelDir(modelPath);
        this.wasLoaded = true;
        if (prefix == null) {
            prefix = this.modelName;
        }
        if (this.block == null) {
            this.loadPtModule(prefix, options);
        } else if (PtModel.parseBooleanOption(options, "hasParameter", true)) {
            this.loadPtParameters(prefix, options);
        }
    }

    /**
     * Loads the model from a stream, reading {@code "mapLocation"} from {@code options}.
     *
     * @param modelStream the stream containing the module
     * @param options the load options, may be {@code null}
     * @throws IOException if the temporary model directory cannot be created
     */
    public void load(InputStream modelStream, Map<String, ?> options) throws IOException {
        this.load(modelStream, PtModel.parseBooleanOption(options, "mapLocation", false));
    }

    /**
     * Loads the model from a stream.
     *
     * <p>A temporary directory (prefix {@code "pt-model"}) becomes the model directory and is
     * registered for deletion on JVM exit.
     *
     * @param modelStream the stream containing the module
     * @param mapLocation forwarded to {@link JniUtils#loadModule}
     * @throws IOException if the temporary model directory cannot be created
     */
    public void load(InputStream modelStream, boolean mapLocation) throws IOException {
        this.modelDir = Files.createTempDirectory("pt-model", new FileAttribute[0]);
        this.modelDir.toFile().deleteOnExit();
        this.block = JniUtils.loadModule((PtNDManager)this.manager, modelStream, mapLocation, false);
    }

    /**
     * Creates a trainer for this model.
     *
     * <p>If the model was loaded from a path, its parameters are unfrozen first. Every
     * initializer pair of the configuration whose key and value are both non-null is applied
     * to the block.
     *
     * @param trainingConfig the training configuration
     * @return a new {@link Trainer} bound to this model
     * @throws IllegalStateException if no block has been set
     */
    @SuppressWarnings({"unchecked", "rawtypes"})
    public Trainer newTrainer(TrainingConfig trainingConfig) {
        // Wildcards instead of raw types: the predicate's type argument is not imported in this
        // file, so the value is handed over through a raw Predicate cast below.
        PairList<?, ?> initializers = trainingConfig.getInitializers();
        if (this.block == null) {
            throw new IllegalStateException("You must set a block for the model before creating a new trainer");
        }
        if (this.wasLoaded) {
            this.block.freezeParameters(false);
        }
        for (Pair<?, ?> pair : initializers) {
            if (pair.getKey() == null || pair.getValue() == null) {
                continue;
            }
            this.block.setInitializer((Initializer)pair.getKey(), (Predicate)pair.getValue());
        }
        return new Trainer((Model)this, trainingConfig);
    }

    /**
     * Lists every regular file below the model directory whose name does not end with
     * {@code ".pt"}.
     *
     * @return the matching paths, relative to the model directory
     * @throws AssertionError if the directory walk cannot be started
     */
    public String[] getArtifactNames() {
        // Files.walk holds open directory handles; the stream must be closed explicitly.
        try (Stream<Path> paths = Files.walk(this.modelDir, new FileVisitOption[0])) {
            List<String> names = new ArrayList<>();
            paths.filter(Files::isRegularFile).forEachOrdered(path -> {
                if (!path.getFileName().toString().endsWith(MODEL_FILE_EXTENSION)) {
                    names.add(this.modelDir.relativize(path).toString());
                }
            });
            return names.toArray(Utils.EMPTY_ARRAY);
        }
        catch (IOException e) {
            throw new AssertionError("Failed list files", e);
        }
    }

    /**
     * Locates the module file and loads it into {@code block}.
     *
     * @param prefix the preferred file name prefix
     * @param options the load options, may be {@code null}
     * @throws FileNotFoundException if no candidate file exists
     */
    private void loadPtModule(String prefix, Map<String, ?> options) throws IOException, MalformedModelException {
        Path modelFile = this.findModelFile(prefix, this.modelDir.toFile().getName(), "model.pt");
        if (modelFile == null) {
            String fileName = prefix.endsWith(MODEL_FILE_EXTENSION) ? prefix : prefix + MODEL_FILE_EXTENSION;
            throw new FileNotFoundException(fileName + " file not found in: " + this.modelDir);
        }
        String[] extraFileKeys = Utils.EMPTY_ARRAY;
        String[] extraFileValues = Utils.EMPTY_ARRAY;
        Object extraFiles = options == null ? null : options.get("extraFiles");
        if (extraFiles != null) {
            extraFileKeys = extraFiles.toString().split(",");
            extraFileValues = new String[extraFileKeys.length];
        }
        boolean trainParam = PtModel.parseBooleanOption(options, "trainParam", false);
        boolean mapLocation = PtModel.parseBooleanOption(options, "mapLocation", false);
        this.block = JniUtils.loadModule((PtNDManager)this.manager, modelFile, mapLocation, extraFileKeys, extraFileValues, trainParam);
        // extraFileValues is read back after the call - presumably loadModule fills it in
        // (TODO confirm against JniUtils).
        for (int i = 0; i < extraFileKeys.length; ++i) {
            this.properties.put(extraFileKeys[i], extraFileValues[i]);
        }
        // Parameters stay frozen unless the caller explicitly asked to train them.
        this.block.freezeParameters(!trainParam);
    }

    /**
     * Resolves the parameter file and reads it into the already configured block.
     *
     * @param prefix the parameter file name prefix
     * @param options the load options, may be {@code null}
     * @throws IOException if no parameter file can be resolved
     */
    private void loadPtParameters(String prefix, Map<String, ?> options) throws IOException, MalformedModelException {
        Path paramFile = this.paramPathResolver(prefix, options);
        if (paramFile == null) {
            throw new IOException("Parameter file not found in: " + this.modelDir + ". If you only specified model path, make sure path name match your saved model file name.");
        }
        this.readParameters(paramFile, options);
    }

    /**
     * Finds the module file to load.
     *
     * <p>If the model directory is itself a regular file, that file is returned; as a side
     * effect the model directory becomes its parent and the model name becomes the file name
     * with a trailing {@code ".pt"} removed. Otherwise each prefix is tried in order, first
     * verbatim and then with {@code ".pt"} appended (unless it already ends with it).
     *
     * @param prefixes the candidate file names, in priority order
     * @return the first existing regular file, or {@code null} if none exists
     */
    private Path findModelFile(String ... prefixes) {
        if (Files.isRegularFile(this.modelDir, new LinkOption[0])) {
            Path file = this.modelDir;
            this.modelDir = this.modelDir.getParent();
            String fileName = file.toFile().getName();
            this.modelName = fileName.endsWith(MODEL_FILE_EXTENSION)
                    ? fileName.substring(0, fileName.length() - MODEL_FILE_EXTENSION.length())
                    : fileName;
            return file;
        }
        for (String prefix : prefixes) {
            Path candidate = this.modelDir.resolve(prefix);
            if (Files.isRegularFile(candidate, new LinkOption[0])) {
                return candidate;
            }
            if (prefix.endsWith(MODEL_FILE_EXTENSION)) {
                continue;
            }
            candidate = this.modelDir.resolve(prefix + MODEL_FILE_EXTENSION);
            if (Files.isRegularFile(candidate, new LinkOption[0])) {
                return candidate;
            }
        }
        return null;
    }

    /**
     * Reads a boolean option without assuming the value is a {@code String}.
     *
     * @param options the options map, may be {@code null}
     * @param key the option name
     * @param defaultValue the value returned when the map or the entry is absent
     * @return the parsed option, or {@code defaultValue}
     */
    private static boolean parseBooleanOption(Map<String, ?> options, String key, boolean defaultValue) {
        if (options == null) {
            return defaultValue;
        }
        Object value = options.get(key);
        return value == null ? defaultValue : Boolean.parseBoolean(value.toString());
    }
}

