/*
 * Decompiled with CFR 0.152.
 */
package dev.langchain4j.model.vertexai;

import com.google.cloud.aiplatform.v1.EndpointName;
import com.google.cloud.aiplatform.v1.PredictResponse;
import com.google.cloud.aiplatform.v1.PredictionServiceClient;
import com.google.cloud.aiplatform.v1.PredictionServiceSettings;
import com.google.protobuf.InvalidProtocolBufferException;
import com.google.protobuf.Message;
import com.google.protobuf.Value;
import com.google.protobuf.util.JsonFormat;
import dev.langchain4j.data.image.Image;
import dev.langchain4j.internal.Json;
import dev.langchain4j.internal.RetryUtils;
import dev.langchain4j.internal.ValidationUtils;
import dev.langchain4j.model.image.ImageModel;
import dev.langchain4j.model.output.FinishReason;
import dev.langchain4j.model.output.Response;
import dev.langchain4j.model.output.TokenUsage;
import dev.langchain4j.model.vertexai.spi.VertexAiImageModelBuilderFactory;
import dev.langchain4j.spi.ServiceHelper;
import java.io.IOException;
import java.net.URI;
import java.nio.file.Files;
import java.nio.file.OpenOption;
import java.nio.file.Path;
import java.nio.file.attribute.FileAttribute;
import java.util.Base64;
import java.util.Collections;
import java.util.HashMap;
import java.util.Iterator;
import java.util.List;
import java.util.stream.Collectors;

public class VertexAiImageModel
implements ImageModel {
    private final Long seed;
    private final String endpoint;
    private final EndpointName endpointName;
    private final String language;
    private final Integer guidanceScale;
    private final String negativePrompt;
    private final ImageStyle sampleImageStyle;
    private final Integer sampleImageSize;
    private final int maxRetries;
    private final Boolean withPersisting;
    private Path tempDirectory;

    public VertexAiImageModel(String endpoint, String project, String location, String publisher, String modelName, Long seed, String language, Integer guidanceScale, String negativePrompt, ImageStyle sampleImageStyle, Integer sampleImageSize, Integer maxRetries, Boolean withPersisting, Path persistTo) {
        this.endpoint = ValidationUtils.ensureNotBlank((String)endpoint, (String)"endpoint");
        this.endpointName = EndpointName.ofProjectLocationPublisherModelName((String)ValidationUtils.ensureNotBlank((String)project, (String)"project"), (String)ValidationUtils.ensureNotBlank((String)location, (String)"location"), (String)ValidationUtils.ensureNotBlank((String)publisher, (String)"publisher"), (String)ValidationUtils.ensureNotBlank((String)modelName, (String)"modelName"));
        this.seed = seed == null ? null : Long.valueOf(ValidationUtils.ensureBetween((Long)seed, (long)0L, (long)0xFFFFFFFFL, (String)"seed"));
        this.language = language;
        this.guidanceScale = guidanceScale;
        this.negativePrompt = negativePrompt;
        this.sampleImageStyle = sampleImageStyle;
        this.sampleImageSize = sampleImageSize;
        this.maxRetries = maxRetries == null ? 3 : maxRetries;
        this.withPersisting = withPersisting;
        if (this.withPersisting != null && this.withPersisting.booleanValue()) {
            try {
                if (persistTo != null) {
                    if (!persistTo.toFile().exists() && !persistTo.toFile().mkdirs()) {
                        throw new IOException("Impossible to create persistTo temporary directory");
                    }
                    this.tempDirectory = persistTo;
                } else {
                    this.tempDirectory = Files.createTempDirectory("imagen-directory-", new FileAttribute[0]);
                }
            }
            catch (IOException e) {
                throw new RuntimeException("Impossible to create persistence temporary directory", e);
            }
        }
    }

    public Response<Image> generate(String prompt) {
        Response<List<Image>> generatedImageResponse = this.generate(prompt, 1);
        return Response.from((Object)((Image)((List)generatedImageResponse.content()).get(0)), (TokenUsage)generatedImageResponse.tokenUsage(), (FinishReason)generatedImageResponse.finishReason());
    }

    public Response<List<Image>> generate(String prompt, int n) {
        return this.generate(prompt, null, null, n);
    }

    private Response<List<Image>> generate(String prompt, Image image, Image mask, int n) {
        Response response;
        block8: {
            PredictionServiceSettings serviceSettings = ((PredictionServiceSettings.Builder)PredictionServiceSettings.newBuilder().setEndpoint(this.endpoint)).build();
            PredictionServiceClient client = PredictionServiceClient.create((PredictionServiceSettings)serviceSettings);
            try {
                List<Value> instances = this.prepareInstance(prompt, image, mask);
                Value parameters = this.prepareParameters(n);
                PredictResponse predictResponse = (PredictResponse)RetryUtils.withRetry(() -> client.predict(this.endpointName, instances, parameters), (int)this.maxRetries);
                List allImages = predictResponse.getPredictionsList().stream().map(v -> {
                    String bytesBase64Encoded = ((Value)v.getStructValue().getFieldsMap().get("bytesBase64Encoded")).getStringValue();
                    return Image.builder().base64Data(bytesBase64Encoded).url(this.persistAndGetURI(bytesBase64Encoded)).build();
                }).collect(Collectors.toList());
                response = Response.from(allImages);
                if (client == null) break block8;
            }
            catch (Throwable throwable) {
                try {
                    if (client != null) {
                        try {
                            client.close();
                        }
                        catch (Throwable throwable2) {
                            throwable.addSuppressed(throwable2);
                        }
                    }
                    throw throwable;
                }
                catch (IOException e) {
                    throw new RuntimeException(e);
                }
            }
            client.close();
        }
        return response;
    }

    private Value prepareParameters(int n) throws InvalidProtocolBufferException {
        HashMap<String, Object> paramsMap = new HashMap<String, Object>();
        paramsMap.put("sampleCount", n);
        if (this.seed != null) {
            paramsMap.put("seed", this.seed);
        }
        if (this.sampleImageStyle != null) {
            paramsMap.put("sampleImageStyle", this.sampleImageStyle.name());
        }
        if (this.sampleImageSize != null) {
            paramsMap.put("mode", "upscale");
            paramsMap.put("sampleImageSize", this.sampleImageSize.toString());
        }
        if (this.guidanceScale != null) {
            paramsMap.put("guidanceScale", this.guidanceScale);
        }
        if (this.negativePrompt != null) {
            paramsMap.put("negativePrompt", this.negativePrompt);
        }
        if (this.language != null) {
            paramsMap.put("language", this.language);
        }
        Value.Builder parametersBuilder = Value.newBuilder();
        JsonFormat.parser().merge(Json.toJson(paramsMap), (Message.Builder)parametersBuilder);
        return parametersBuilder.build();
    }

    private List<Value> prepareInstance(String prompt, Image image, Image mask) throws InvalidProtocolBufferException {
        HashMap<String, String> imageMap;
        HashMap<String, Object> promptMap = new HashMap<String, Object>();
        promptMap.put("prompt", prompt);
        if (image != null && image.base64Data() != null) {
            imageMap = new HashMap<String, String>();
            imageMap.put("bytesBase64Encoded", image.base64Data());
            promptMap.put("image", imageMap);
        }
        if (mask != null && mask.base64Data() != null) {
            imageMap = new HashMap();
            imageMap.put("bytesBase64Encoded", mask.base64Data());
            HashMap<String, HashMap<String, String>> maskMap = new HashMap<String, HashMap<String, String>>();
            maskMap.put("image", imageMap);
            promptMap.put("mask", maskMap);
        }
        Value.Builder instanceBuilder = Value.newBuilder();
        JsonFormat.parser().merge(Json.toJson(promptMap), (Message.Builder)instanceBuilder);
        return Collections.singletonList(instanceBuilder.build());
    }

    public Response<Image> edit(Image image, String prompt) {
        Response<Image> generatedImageResponse = this.edit(image, null, prompt);
        return Response.from((Object)((Image)generatedImageResponse.content()), (TokenUsage)generatedImageResponse.tokenUsage(), (FinishReason)generatedImageResponse.finishReason());
    }

    public Response<Image> edit(Image image, Image mask, String prompt) {
        Response<List<Image>> generatedImageResponse = this.generate(prompt, image, mask, 1);
        return Response.from((Object)((Image)((List)generatedImageResponse.content()).get(0)), (TokenUsage)generatedImageResponse.tokenUsage(), (FinishReason)generatedImageResponse.finishReason());
    }

    private URI persistAndGetURI(String bytesBase64Encoded) {
        if (this.withPersisting != null && this.withPersisting.booleanValue()) {
            try {
                Path tempFile = Files.createTempFile(this.tempDirectory, "imagen-image-", ".png", new FileAttribute[0]);
                Files.write(tempFile, Base64.getDecoder().decode(bytesBase64Encoded), new OpenOption[0]);
                return tempFile.toUri();
            }
            catch (IOException e) {
                throw new RuntimeException(e);
            }
        }
        return null;
    }

    public static Builder builder() {
        Iterator iterator = ServiceHelper.loadFactories(VertexAiImageModelBuilderFactory.class).iterator();
        if (iterator.hasNext()) {
            VertexAiImageModelBuilderFactory factory = (VertexAiImageModelBuilderFactory)iterator.next();
            return (Builder)factory.get();
        }
        return new Builder();
    }

    public static enum ImageStyle {
        photograph,
        digital_art,
        landscape,
        sketch,
        watercolor,
        cyberpunk,
        pop_art;

    }

    public static class Builder {
        private String endpoint;
        private String project;
        private String location;
        private String publisher;
        private String modelName;
        private Long seed;
        private String language;
        private String negativePrompt;
        private ImageStyle sampleImageStyle;
        private Integer sampleImageSize;
        private Integer maxRetries;
        private Integer guidanceScale;
        private Boolean withPersisting;
        private Path persistTo;

        public Builder endpoint(String endpoint) {
            this.endpoint = endpoint;
            return this;
        }

        public Builder project(String project) {
            this.project = project;
            return this;
        }

        public Builder location(String location) {
            this.location = location;
            return this;
        }

        public Builder publisher(String publisher) {
            this.publisher = publisher;
            return this;
        }

        public Builder modelName(String modelName) {
            this.modelName = modelName;
            return this;
        }

        public Builder seed(Long seed) {
            this.seed = seed;
            return this;
        }

        public Builder language(String language) {
            this.language = language;
            return this;
        }

        public Builder guidanceScale(Integer guidanceScale) {
            this.guidanceScale = guidanceScale;
            return this;
        }

        public Builder negativePrompt(String negativePrompt) {
            this.negativePrompt = negativePrompt;
            return this;
        }

        public Builder sampleImageStyle(ImageStyle sampleImageStyle) {
            this.sampleImageStyle = sampleImageStyle;
            return this;
        }

        public Builder sampleImageSize(Integer sampleImageSize) {
            this.sampleImageSize = sampleImageSize;
            return this;
        }

        public Builder maxRetries(Integer maxRetries) {
            this.maxRetries = maxRetries;
            return this;
        }

        public Builder withPersisting() {
            this.withPersisting = Boolean.TRUE;
            return this;
        }

        public Builder persistTo(Path persistTo) {
            this.persistTo = persistTo;
            return this.withPersisting();
        }

        public VertexAiImageModel build() {
            return new VertexAiImageModel(this.endpoint, this.project, this.location, this.publisher, this.modelName, this.seed, this.language, this.guidanceScale, this.negativePrompt, this.sampleImageStyle, this.sampleImageSize, this.maxRetries, this.withPersisting, this.persistTo);
        }
    }
}

