/*
 * Decompiled with CFR 0.152.
 */
package org.neo4j.gds;

import java.util.HashMap;
import java.util.List;
import java.util.Map;
import java.util.stream.Stream;
import org.neo4j.gds.AlgoBaseProc;
import org.neo4j.gds.Algorithm;
import org.neo4j.gds.AlgorithmFactory;
import org.neo4j.gds.config.AlgoBaseConfig;
import org.neo4j.gds.config.GraphProjectConfig;
import org.neo4j.gds.config.ToMapConvertible;
import org.neo4j.gds.core.model.Model;
import org.neo4j.gds.core.model.ModelCatalog;
import org.neo4j.gds.executor.AlgorithmSpec;
import org.neo4j.gds.executor.ComputationResult;
import org.neo4j.gds.executor.ComputationResultConsumer;
import org.neo4j.gds.executor.validation.BeforeLoadValidation;
import org.neo4j.gds.executor.validation.ValidationConfiguration;
import org.neo4j.gds.model.ModelConfig;

/**
 * Base class for procedures that run a training algorithm and store the resulting
 * {@link Model} in the procedure's {@link ModelCatalog}.
 *
 * <p>Subclasses provide the model type, a way to extract the trained model from the
 * algorithm result, and a way to turn the computation result into a procedure result row.
 *
 * @param <ALGO>         the training algorithm
 * @param <ALGO_RESULT>  the result type produced by {@code ALGO}
 * @param <TRAIN_CONFIG> the training configuration; it must carry a model name
 * @param <PROC_RESULT>  the row type streamed back by the procedure
 */
public abstract class TrainProc<ALGO extends Algorithm<ALGO_RESULT>, ALGO_RESULT, TRAIN_CONFIG extends AlgoBaseConfig & ModelConfig, PROC_RESULT>
extends AlgoBaseProc<ALGO, ALGO_RESULT, TRAIN_CONFIG, PROC_RESULT> {

    /**
     * Returns the model type that is passed to
     * {@code ModelCatalog.verifyModelCanBeStored} during before-load validation.
     */
    protected abstract String modelType();

    /**
     * Builds the procedure result row for a finished training run.
     *
     * @param computationResult the result of executing the training algorithm
     * @return the single row streamed back to the caller
     */
    protected abstract PROC_RESULT constructProcResult(ComputationResult<ALGO, ALGO_RESULT, TRAIN_CONFIG> computationResult);

    /**
     * Extracts the trained model from the algorithm result so it can be stored.
     *
     * @param algoResult the value returned by the training algorithm
     * @return the model to put into the model catalog
     */
    protected abstract Model<?, ?, ?> extractModel(ALGO_RESULT algoResult);

    /**
     * Returns a consumer that stores the trained model in the model catalog and then
     * emits a single result row built by {@link #constructProcResult}.
     */
    public ComputationResultConsumer<ALGO, ALGO_RESULT, TRAIN_CONFIG, Stream<PROC_RESULT>> computationResultConsumer() {
        return (computationResult, executionContext) -> {
            // The model is stored before the result row is built; the execution context is not used here.
            this.modelCatalog().set(this.extractModel(computationResult.result()));
            return Stream.of(this.constructProcResult(computationResult));
        };
    }

    /**
     * Stores the trained model and returns the procedure result stream by delegating to
     * {@link #computationResultConsumer()} with this procedure's execution context.
     *
     * @param computationResult the result of executing the training algorithm
     * @return a stream containing the procedure result row
     */
    protected Stream<PROC_RESULT> trainAndStoreModelWithResult(ComputationResult<ALGO, ALGO_RESULT, TRAIN_CONFIG> computationResult) {
        // consume(...) is already typed as Stream<PROC_RESULT>; no (raw) cast is needed.
        return this.computationResultConsumer().consume(computationResult, this.executionContext());
    }

    /**
     * Adds a before-load validation that checks, via the model catalog, whether a model with
     * the configured name and this procedure's model type can be stored for the current user.
     */
    @Override
    public ValidationConfiguration<TRAIN_CONFIG> validationConfig() {
        return new ValidationConfiguration<TRAIN_CONFIG>() {

            @Override
            public List<BeforeLoadValidation<TRAIN_CONFIG>> beforeLoadValidations() {
                // Diamond instead of a raw type: avoids an unchecked conversion to BeforeLoadValidation<TRAIN_CONFIG>.
                return List.of(new TrainingConfigValidation<>(
                    TrainProc.this.modelCatalog(),
                    TrainProc.this.username(),
                    TrainProc.this.modelType()
                ));
            }
        };
    }

    /**
     * Sets the model catalog used by this procedure and returns this procedure.
     *
     * @param modelCatalog the catalog to store trained models in
     * @return {@code this}
     */
    public AlgorithmSpec<ALGO, ALGO_RESULT, TRAIN_CONFIG, Stream<PROC_RESULT>, AlgorithmFactory<?, ALGO, TRAIN_CONFIG>> withModelCatalog(ModelCatalog modelCatalog) {
        this.setModelCatalog(modelCatalog);
        return this;
    }

    /**
     * Result row describing a trained model.
     *
     * <p>{@code modelInfo} holds {@code "modelName"}, {@code "modelType"} and the entries of the
     * model's custom info; {@code configuration} is the training configuration as a map.
     */
    public static class TrainResult {
        public final Map<String, Object> modelInfo;
        public final Map<String, Object> configuration;
        public final long trainMillis;

        /**
         * Builds a result row from a trained model.
         *
         * @param trainedModel      the trained model
         * @param trainMillis       the training time in milliseconds
         * @param nodeCount         currently unused; kept for signature compatibility
         * @param relationshipCount currently unused; kept for signature compatibility
         */
        public <TRAIN_RESULT, TRAIN_CONFIG extends ModelConfig & AlgoBaseConfig, TRAIN_INFO extends ToMapConvertible> TrainResult(Model<TRAIN_RESULT, TRAIN_CONFIG, TRAIN_INFO> trainedModel, long trainMillis, long nodeCount, long relationshipCount) {
            ModelConfig trainConfig = trainedModel.trainConfig();
            // Kept as a mutable HashMap: the field is public and may be extended by callers.
            this.modelInfo = new HashMap<String, Object>();
            this.modelInfo.put("modelName", trainedModel.name());
            this.modelInfo.put("modelType", trainedModel.algoType());
            this.modelInfo.putAll(trainedModel.customInfo().toMap());
            this.configuration = trainConfig.toMap();
            this.trainMillis = trainMillis;
        }
    }

    /**
     * Before-load validation that asks the model catalog whether a model with the configured
     * name and the given model type can be stored for the given user.
     *
     * @param <TRAIN_CONFIG> the training configuration type
     */
    public static class TrainingConfigValidation<TRAIN_CONFIG extends ModelConfig & AlgoBaseConfig>
    implements BeforeLoadValidation<TRAIN_CONFIG> {
        private final ModelCatalog modelCatalog;
        private final String username;
        private final String modelType;

        public TrainingConfigValidation(ModelCatalog modelCatalog, String username, String modelType) {
            this.modelCatalog = modelCatalog;
            this.username = username;
            this.modelType = modelType;
        }

        /**
         * Delegates to {@code ModelCatalog.verifyModelCanBeStored} with the configured model name.
         * The graph projection config is not consulted.
         */
        @Override
        public void validateConfigsBeforeLoad(GraphProjectConfig graphProjectConfig, TRAIN_CONFIG config) {
            this.modelCatalog.verifyModelCanBeStored(this.username, config.modelName(), this.modelType);
        }
    }
}

