/*
 * Decompiled with CFR 0.152.
 */
package com.spotify.zoltar.tf;

import com.google.auto.value.AutoValue;
import com.spotify.zoltar.Model;
import com.spotify.zoltar.fs.FileSystemExtras;
import com.spotify.zoltar.tf.AutoValue_TensorFlowModel;
import com.spotify.zoltar.tf.AutoValue_TensorFlowModel_Options;
import java.io.IOException;
import java.io.Serializable;
import java.net.URI;
import java.util.Collections;
import java.util.List;
import java.util.Map;
import java.util.stream.Collectors;
import org.tensorflow.SavedModelBundle;
import org.tensorflow.proto.framework.MetaGraphDef;
import org.tensorflow.proto.framework.SignatureDef;
import org.tensorflow.proto.framework.TensorInfo;

/**
 * A {@link Model} whose underlying instance is a TensorFlow {@link SavedModelBundle}.
 *
 * <p>Instances are obtained through the {@code create} factories, which load a saved model from
 * a URI and capture one of its signature definitions together with the input/output name maps
 * derived from that signature.
 */
@AutoValue
public abstract class TensorFlowModel
implements Model<SavedModelBundle> {
    /** Model id used by the factories that do not take an explicit id. */
    private static final Model.Id DEFAULT_ID = Model.Id.create("tensorflow");
    /** Options used by the factories that do not take explicit options: the single tag "serve". */
    private static final Options DEFAULT_OPTIONS = Options.builder().tags(Collections.singletonList("serve")).build();
    /** Signature definition key used by the factories that do not take an explicit one. */
    private static final String DEFAULT_SIGNATURE_DEF = "serving_default";

    /**
     * Loads a model using the default id, default options and default signature definition.
     *
     * @param modelResource location of the saved model
     * @return the loaded model
     * @throws IOException propagated from the four-argument {@code create} factory
     */
    public static TensorFlowModel create(URI modelResource) throws IOException {
        return TensorFlowModel.create(DEFAULT_ID, modelResource, DEFAULT_OPTIONS);
    }

    /**
     * Loads a model using the default options and default signature definition.
     *
     * @param id id to assign to the model
     * @param modelResource location of the saved model
     * @return the loaded model
     * @throws IOException propagated from the four-argument {@code create} factory
     */
    public static TensorFlowModel create(Model.Id id, URI modelResource) throws IOException {
        return TensorFlowModel.create(id, modelResource, DEFAULT_OPTIONS);
    }

    /**
     * Loads a model using the default id and default signature definition.
     *
     * @param modelResource location of the saved model
     * @param options load options (tags)
     * @return the loaded model
     * @throws IOException propagated from the four-argument {@code create} factory
     */
    public static TensorFlowModel create(URI modelResource, Options options) throws IOException {
        return TensorFlowModel.create(DEFAULT_ID, modelResource, options);
    }

    /**
     * Loads a model using the default signature definition ("serving_default").
     *
     * @param id id to assign to the model
     * @param modelResource location of the saved model
     * @param options load options (tags)
     * @return the loaded model
     * @throws IOException propagated from the four-argument {@code create} factory
     */
    public static TensorFlowModel create(Model.Id id, URI modelResource, Options options) throws IOException {
        return TensorFlowModel.create(id, modelResource, options, DEFAULT_SIGNATURE_DEF);
    }

    /**
     * Loads a saved model and binds it to the named signature definition.
     *
     * <p>The resource URI is first normalised (see {@link #withTrailingSlashForGcs(URI)}), then
     * passed to {@code FileSystemExtras.downloadIfNonLocal} to obtain the directory handed to
     * {@link SavedModelBundle#load}. If anything fails after the bundle has been loaded (for
     * example the signature key is not present in the meta graph), the bundle is closed before
     * the exception is rethrown so that it is not leaked.
     *
     * @param id id to assign to the model
     * @param modelResource location of the saved model
     * @param options load options; its tags are passed to {@link SavedModelBundle#load}
     * @param signatureDefinition key of the signature definition to look up in the meta graph
     * @return the loaded model
     * @throws IOException declared by this factory; presumably raised while resolving the
     *     resource to a local directory (the helper's contract is not visible here)
     */
    public static TensorFlowModel create(Model.Id id, URI modelResource, Options options, String signatureDefinition) throws IOException {
        URI normalizedUri = TensorFlowModel.withTrailingSlashForGcs(modelResource);
        URI localDir = FileSystemExtras.downloadIfNonLocal(normalizedUri);
        SavedModelBundle model = SavedModelBundle.load(localDir.toString(), options.tags().toArray(new String[0]));
        try {
            MetaGraphDef metaGraphDef = model.metaGraphDef();
            SignatureDef signatureDef = metaGraphDef.getSignatureDefOrThrow(signatureDefinition);
            return new AutoValue_TensorFlowModel(
                    id,
                    model,
                    options,
                    metaGraphDef,
                    signatureDef,
                    TensorFlowModel.toNameMap(signatureDef.getInputsMap()),
                    TensorFlowModel.toNameMap(signatureDef.getOutputsMap()));
        } catch (RuntimeException e) {
            // Nobody else holds a reference to the bundle at this point, so release it here;
            // otherwise a bad signature key would leak the loaded model.
            try {
                model.close();
            } catch (RuntimeException closeFailure) {
                // Keep the original failure as the primary exception.
                e.addSuppressed(closeFailure);
            }
            throw e;
        }
    }

    /** Closes the underlying {@link SavedModelBundle}, if there is one. */
    public void close() {
        SavedModelBundle bundle = this.instance();
        if (bundle != null) {
            bundle.close();
        }
    }

    /** Returns the loaded TensorFlow saved model bundle. */
    public abstract SavedModelBundle instance();

    /** Returns the options this model was loaded with. */
    public abstract Options options();

    /** Returns the meta graph definition read from the loaded bundle. */
    public abstract MetaGraphDef metaGraphDefinition();

    /** Returns the signature definition selected when the model was created. */
    public abstract SignatureDef signatureDefinition();

    /** Returns a map from each signature input key to the name of its {@link TensorInfo}. */
    public abstract Map<String, String> inputsNameMap();

    /** Returns a map from each signature output key to the name of its {@link TensorInfo}. */
    public abstract Map<String, String> outputsNameMap();

    /**
     * Appends a trailing "/" to URIs with the "gs" scheme (case-insensitive) that do not already
     * end with one; every other URI is returned unchanged.
     *
     * <p>NOTE(review): presumably this makes the GCS location resolve as a directory; confirm
     * against {@code FileSystemExtras}.
     */
    private static URI withTrailingSlashForGcs(URI modelResource) {
        String text = modelResource.toString();
        boolean isGcs = "gs".equalsIgnoreCase(modelResource.getScheme());
        if (isGcs && !text.endsWith("/")) {
            return URI.create(text + "/");
        }
        return modelResource;
    }

    /** Maps each key of {@code infoMap} to the name reported by its {@link TensorInfo}. */
    private static Map<String, String> toNameMap(Map<String, TensorInfo> infoMap) {
        return infoMap.entrySet().stream().collect(Collectors.toMap(Map.Entry::getKey, entry -> entry.getValue().getName()));
    }

    /** Serializable value type holding the options used when loading a saved model. */
    @AutoValue
    public static abstract class Options
    implements Serializable {
        /** Returns the tags passed to {@link SavedModelBundle#load} when loading the model. */
        public abstract List<String> tags();

        /** Returns a new builder for {@link Options}. */
        public static Builder builder() {
            return new AutoValue_TensorFlowModel_Options.Builder();
        }

        /** Builder for {@link Options}. */
        @AutoValue.Builder
        public static abstract class Builder {
            /** Sets the tags to load the saved model with. */
            public abstract Builder tags(List<String> var1);

            /** Builds the {@link Options} value. */
            public abstract Options build();
        }
    }
}

