/*
 * Decompiled with CFR 0.152.
 */
package oracle.pgx.api.mllib;

import java.util.ArrayList;
import java.util.List;
import java.util.concurrent.ExecutionException;
import java.util.function.BiFunction;
import java.util.function.Supplier;
import oracle.pgx.api.PgxEdge;
import oracle.pgx.api.PgxFuture;
import oracle.pgx.api.PgxGraph;
import oracle.pgx.api.PgxSession;
import oracle.pgx.api.frames.PgxFrame;
import oracle.pgx.api.internal.Core;
import oracle.pgx.api.internal.Graph;
import oracle.pgx.api.internal.mllib.EdgeWiseModelMetadata;
import oracle.pgx.api.mllib.Model;
import oracle.pgx.config.mllib.EdgeWiseModelConfig;
import oracle.pgx.config.mllib.GraphWiseConvLayerConfig;
import oracle.pgx.config.mllib.edgecombination.EdgeCombinationMethod;

/**
 * Base class for edge-wise ML models.
 *
 * <p>Most accessors delegate to the model configuration held by {@link #modelMetadata}.
 * Lifecycle operations (destroy) are forwarded to the {@link Core} using this model's name.
 *
 * @param <Config> the concrete edge-wise model configuration type
 * @param <Metadata> the metadata type, which carries a {@code Config}
 * @param <ModelType> the concrete model type (self type)
 */
public abstract class EdgeWiseModel<Config extends EdgeWiseModelConfig, Metadata extends EdgeWiseModelMetadata<Config>, ModelType extends EdgeWiseModel<Config, Metadata, ModelType>>
extends Model<ModelType> {
    /** Metadata (model name and configuration) backing this model. */
    // Kept package-private and non-final as in the original, since other classes in this package may rely on it.
    Metadata modelMetadata;

    /** Factory used to wrap an internal {@link Graph} into a {@link PgxGraph} for a given session. */
    protected final BiFunction<PgxSession, Graph, PgxGraph> graphConstructor;

    /**
     * Creates an edge-wise model.
     *
     * @param session the session this model belongs to
     * @param core the core used to execute model operations
     * @param keystorePathSupplier supplies the keystore path (passed through to {@link Model})
     * @param keystorePasswordSupplier supplies the keystore password (passed through to {@link Model})
     * @param modelMetadata metadata (name and configuration) of this model
     * @param graphConstructor factory that turns an internal {@link Graph} into a {@link PgxGraph}
     */
    public EdgeWiseModel(PgxSession session, Core core, Supplier<String> keystorePathSupplier, Supplier<char[]> keystorePasswordSupplier, Metadata modelMetadata, BiFunction<PgxSession, Graph, PgxGraph> graphConstructor) {
        super(session, core, keystorePathSupplier, keystorePasswordSupplier);
        this.modelMetadata = modelMetadata;
        this.graphConstructor = graphConstructor;
    }

    /** Returns the model name stored in this model's metadata. */
    @Override
    String getModelName() {
        return this.modelMetadata.getModelName();
    }

    /**
     * Asynchronously destroys this model.
     *
     * <p>Asks the core to destroy the ML model registered under this model's name
     * in the current session context.
     *
     * @return a future that completes when the destroy request finishes
     */
    @Override
    public PgxFuture<Void> destroyAsync() {
        return this.core.destroyMlModel(this.session.getSessionContext(), this.getModelName());
    }

    /**
     * Destroys this model and blocks until the operation completes.
     *
     * @throws ExecutionException if the destroy request failed
     * @throws InterruptedException if the waiting thread was interrupted
     */
    public void destroy() throws ExecutionException, InterruptedException {
        this.destroyAsync().get();
    }

    /** Returns the {@code numEpochs} value of this model's configuration. */
    public int getNumEpochs() {
        return this.getConfig().getNumEpochs();
    }

    /** Returns the {@code learningRate} value of this model's configuration. */
    public double getLearningRate() {
        return this.getConfig().getLearningRate();
    }

    /** Returns the {@code batchSize} value of this model's configuration. */
    public int getBatchSize() {
        return this.getConfig().getBatchSize();
    }

    /** Returns the {@code embeddingDim} value of this model's configuration. */
    public int getEmbeddingDim() {
        return this.getConfig().getEmbeddingDim();
    }

    /** Returns the {@code seed} value of this model's configuration. */
    public int getSeed() {
        return this.getConfig().getSeed();
    }

    /** Returns the convolution layer configurations of this model's configuration. */
    public GraphWiseConvLayerConfig[] getConvLayerConfigs() {
        return this.getConfig().getConvLayerConfigs();
    }

    /** Returns the vertex input property names of this model's configuration. */
    public List<String> getVertexInputPropertyNames() {
        return this.getConfig().getVertexInputPropertyNames();
    }

    /** Returns the edge input property names of this model's configuration. */
    public List<String> getEdgeInputPropertyNames() {
        return this.getConfig().getEdgeInputPropertyNames();
    }

    /** Returns the {@code fitted} flag of this model's configuration. */
    public boolean isFitted() {
        return this.getConfig().isFitted();
    }

    /** Returns the training loss recorded in this model's configuration. */
    public double getTrainingLoss() {
        return this.getConfig().getTrainingLoss();
    }

    /** Returns the input feature dimension recorded in this model's configuration. */
    public int getInputFeatureDim() {
        return this.getConfig().getInputFeatureDim();
    }

    /** Returns the edge input feature dimension recorded in this model's configuration. */
    public int getEdgeInputFeatureDim() {
        return this.getConfig().getEdgeInputFeatureDim();
    }

    /** Returns the edge combination method of this model's configuration. */
    public EdgeCombinationMethod getEdgeCombinationMethod() {
        return this.getConfig().getEdgeCombinationMethod();
    }

    /**
     * Returns this model's configuration as stored in its metadata.
     *
     * <p>Typed via the {@code Metadata extends EdgeWiseModelMetadata<Config>} bound,
     * so no raw-type or unchecked cast is needed.
     */
    public Config getConfig() {
        return this.modelMetadata.getConfig();
    }

    /**
     * Asynchronously fits this model on the given graph.
     *
     * @param graph the graph to train on
     * @return a future holding the value produced by training (see implementations)
     */
    public abstract PgxFuture<Double> fitAsync(PgxGraph graph);

    /**
     * Fits this model on the given graph and blocks until training completes.
     *
     * @param graph the graph to train on
     * @return the (unboxed) value produced by {@link #fitAsync(PgxGraph)}
     * @throws ExecutionException if fitting failed
     * @throws InterruptedException if the waiting thread was interrupted
     */
    public double fit(PgxGraph graph) throws ExecutionException, InterruptedException {
        return this.fitAsync(graph).get();
    }

    /**
     * Asynchronously infers embeddings for the given edges of a graph.
     *
     * @param graph the graph the edges belong to
     * @param edges the edges to infer embeddings for
     * @return a future holding a frame with the inferred embeddings
     */
    public abstract PgxFuture<PgxFrame> inferEmbeddingsAsync(PgxGraph graph, Iterable<PgxEdge> edges);

    /**
     * Infers embeddings for the given edges and blocks until the result is available.
     *
     * <p>Unlike {@link #fit(PgxGraph)}, this waits with {@code join()} rather than {@code get()}.
     * Failures therefore presumably surface as an unchecked exception (CompletableFuture
     * {@code join} semantics — confirm against {@link PgxFuture}) instead of a checked
     * {@link ExecutionException}.
     *
     * @param graph the graph the edges belong to
     * @param edges the edges to infer embeddings for
     * @return a frame with the inferred embeddings
     */
    public PgxFrame inferEmbeddings(PgxGraph graph, Iterable<PgxEdge> edges) {
        return this.inferEmbeddingsAsync(graph, edges).join();
    }

    /**
     * Serializes each edge via {@code PgxEdge#serialize()}, preserving iteration order.
     *
     * @param edges the edges to serialize
     * @return a new mutable list holding one serialized entry per edge
     */
    protected List<Object> serializeEdges(Iterable<PgxEdge> edges) {
        List<Object> serializedEdges = new ArrayList<>();
        for (PgxEdge edge : edges) {
            serializedEdges.add(edge.serialize());
        }
        return serializedEdges;
    }
}

