/*
 * Decompiled with CFR 0.152.
 */
package io.cdap.plugin.gcp.speech;

import com.google.cloud.speech.v1.RecognitionAudio;
import com.google.cloud.speech.v1.RecognitionConfig;
import com.google.cloud.speech.v1.RecognizeResponse;
import com.google.cloud.speech.v1.SpeechClient;
import com.google.cloud.speech.v1.SpeechRecognitionAlternative;
import com.google.cloud.speech.v1.SpeechRecognitionResult;
import com.google.cloud.speech.v1.SpeechSettings;
import com.google.common.base.Strings;
import com.google.protobuf.ByteString;
import io.cdap.cdap.api.annotation.Description;
import io.cdap.cdap.api.annotation.Macro;
import io.cdap.cdap.api.annotation.Name;
import io.cdap.cdap.api.annotation.Plugin;
import io.cdap.cdap.api.data.format.StructuredRecord;
import io.cdap.cdap.api.data.schema.Schema;
import io.cdap.cdap.etl.api.Emitter;
import io.cdap.cdap.etl.api.FailureCollector;
import io.cdap.cdap.etl.api.PipelineConfigurer;
import io.cdap.cdap.etl.api.StageSubmitterContext;
import io.cdap.cdap.etl.api.Transform;
import io.cdap.cdap.etl.api.TransformContext;
import io.cdap.plugin.gcp.common.GCPConfig;
import io.cdap.plugin.gcp.common.GCPUtils;
import java.io.IOException;
import java.util.ArrayList;
import java.util.List;
import java.util.Set;
import java.util.stream.Collectors;
import javax.annotation.Nullable;

@Plugin(type="transform")
@Name(value="SpeechToText")
@Description(value="Converts audio files to text by applying powerful neural network models.")
public class SpeechToTextTransform
extends Transform<StructuredRecord, StructuredRecord> {
    /** Plugin name; the same literal is used in the class-level {@code @Name} annotation. */
    public static final String NAME = "SpeechToText";

    /** Plugin description; the same literal is used in the class-level {@code @Description} annotation. */
    public static final String DESCRIPTION = "Converts audio files to text by applying powerful neural network models.";

    /**
     * Schema of a single entry in the transcription-parts array: a nullable {@code transcript}
     * string and a nullable {@code confidence} float.
     */
    private static final Schema SPEECH = Schema.recordOf(
        "speech",
        Schema.Field.of("transcript", Schema.nullableOf(Schema.of(Schema.Type.STRING))),
        Schema.Field.of("confidence", Schema.nullableOf(Schema.of(Schema.Type.FLOAT))));

    /** Plugin configuration; read by every lifecycle method of this transform. */
    private SpeechTransformConfig config;

    /** Output schema taken from the runtime context in {@code initialize}; {@code null} until then. */
    private Schema outputSchema;

    /** Recognition settings built from the configuration in {@code setRecognitionConfig}. */
    private RecognitionConfig recognitionConfig;

    /** Speech client created in {@code initialize} and closed in {@code destroy}. */
    private SpeechClient speech;

    /**
     * Validates the configuration against the stage's input schema and publishes the
     * output schema derived from it.
     *
     * @param configurer pipeline configurer that supplies the input schema and failure collector
     * @throws IllegalArgumentException if the output schema cannot be derived because no
     *     transcription field is configured
     */
    @Override
    public void configurePipeline(PipelineConfigurer configurer) throws IllegalArgumentException {
        super.configurePipeline(configurer);

        Schema inputSchema = configurer.getStageConfigurer().getInputSchema();
        this.config.validate(inputSchema, configurer.getStageConfigurer().getFailureCollector());

        configurer.getStageConfigurer().setOutputSchema(this.getSchema(inputSchema));
    }

    /**
     * Re-validates the configuration against the input schema reported by the submitter context.
     *
     * @param context stage submitter context that supplies the input schema and failure collector
     * @throws Exception if the superclass hook or the validation fails
     */
    @Override
    public void prepareRun(StageSubmitterContext context) throws Exception {
        super.prepareRun(context);

        Schema inputSchema = context.getInputSchema();
        FailureCollector collector = context.getFailureCollector();
        this.config.validate(inputSchema, collector);
    }

    /**
     * Prepares the transform for processing records: stores the output schema, builds the
     * recognition configuration and creates the speech client.
     *
     * @param context runtime context that supplies the output schema and failure collector
     * @throws Exception if the service account type is undefined or the client cannot be created
     */
    @Override
    public void initialize(TransformContext context) throws Exception {
        super.initialize(context);
        this.outputSchema = context.getOutputSchema();
        this.setRecognitionConfig();

        // The service account type must be known before credentials can be loaded.
        if (this.config.isServiceAccountFilePath() == null) {
            FailureCollector collector = context.getFailureCollector();
            collector.addFailure("Service account type is undefined.", "Must be `filePath` or `JSON`");
            collector.getOrThrowException();
        }

        SpeechSettings settings = this.getSettings();
        this.speech = SpeechClient.create(settings);
    }

    /**
     * Sends the audio bytes of one record to the speech service and emits the record
     * extended with the configured transcription fields.
     *
     * <p>When a parts field is configured, every alternative of every result is stored as a
     * {@code speech} record in that field. When a text field is configured, the first
     * alternative of each result is concatenated into that field. All input fields are
     * copied to the output unchanged.
     *
     * @param input record holding the audio bytes in the configured audio field
     * @param emitter receives the single output record
     * @throws IllegalArgumentException if the audio field value is null or is not bytes
     */
    @Override
    public void transform(StructuredRecord input, Emitter<StructuredRecord> emitter) {
        ByteString audioBytes = toByteString(this.config.audioField, input.get(this.config.audioField));
        RecognitionAudio audio = RecognitionAudio.newBuilder().setContent(audioBytes).build();
        RecognizeResponse response = this.speech.recognize(this.recognitionConfig, audio);
        List<SpeechRecognitionResult> results = response.getResultsList();

        // The configured field names do not change per result, so resolve them once.
        String partsField = this.config.getPartsField();
        String textField = this.config.getTextField();

        List<StructuredRecord> transcriptParts = new ArrayList<>();
        StringBuilder completeTranscript = new StringBuilder();
        for (SpeechRecognitionResult result : results) {
            List<SpeechRecognitionAlternative> alternatives = result.getAlternativesList();
            if (partsField != null) {
                this.addTranscriptWithConfidence(transcriptParts, alternatives);
            }
            // Guard against a result without alternatives instead of failing on get(0).
            if (textField != null && !alternatives.isEmpty()) {
                completeTranscript.append(alternatives.get(0).getTranscript());
            }
        }

        // Fall back to a schema derived from the record when no output schema was provided.
        Schema currentSchema = this.outputSchema != null ? this.outputSchema : this.getSchema(input.getSchema());
        StructuredRecord.Builder outputBuilder = StructuredRecord.builder(currentSchema);
        if (partsField != null) {
            outputBuilder.set(partsField, transcriptParts);
        }
        if (textField != null) {
            outputBuilder.set(textField, completeTranscript.toString());
        }
        this.copyFields(input, outputBuilder);
        emitter.emit(outputBuilder.build());
    }

    /**
     * Converts the raw value of the audio field into a {@link ByteString}.
     *
     * @param fieldName name of the audio field, used only in error messages
     * @param value field value; expected to be a {@code byte[]} or a {@code java.nio.ByteBuffer}
     * @return a copy of the audio bytes
     * @throws IllegalArgumentException if the value is null or of any other type
     */
    private static ByteString toByteString(String fieldName, @Nullable Object value) {
        if (value instanceof byte[]) {
            return ByteString.copyFrom((byte[]) value);
        }
        if (value instanceof java.nio.ByteBuffer) {
            // Copy from a duplicate so the position of the buffer that is later copied
            // to the output record is left untouched.
            return ByteString.copyFrom(((java.nio.ByteBuffer) value).duplicate());
        }
        if (value == null) {
            throw new IllegalArgumentException(
                String.format("Audio field '%s' is null; audio data is required for transcription.", fieldName));
        }
        throw new IllegalArgumentException(
            String.format("Audio field '%s' holds a value of unsupported type '%s'; expected bytes.",
                          fieldName, value.getClass().getName()));
    }

    /**
     * Releases the speech client. Closing is best-effort: any failure while closing is
     * deliberately ignored, and nothing happens if the client was never created.
     */
    @Override
    public void destroy() {
        super.destroy();
        if (this.speech == null) {
            return;
        }
        try {
            this.speech.close();
        }
        catch (Exception ignored) {
            // Best-effort cleanup: a failure to close must not fail stage teardown.
        }
    }

    /**
     * Derives the output schema: all input fields followed by the configured transcription
     * fields (an array of {@code speech} records for the parts field, a nullable string for
     * the text field).
     *
     * <p>Field names are resolved through {@code getPartsField()} and {@code getTextField()},
     * the same accessors {@code transform} uses, so a blank name is treated as "not configured"
     * here as well and never produces a schema field that the transform will not populate.
     *
     * @param inputSchema schema of the incoming records, or {@code null} if unknown
     * @return the output schema, or {@code null} when {@code inputSchema} is {@code null}
     * @throws IllegalArgumentException if neither transcription field is configured and
     *     neither of them is a macro
     */
    @Nullable
    private Schema getSchema(@Nullable Schema inputSchema) {
        if (inputSchema == null) {
            return null;
        }
        List<Schema.Field> fields = new ArrayList<>();
        if (inputSchema.getFields() != null) {
            fields.addAll(inputSchema.getFields());
        }

        String partsField = this.config.getPartsField();
        String textField = this.config.getTextField();
        if (partsField != null) {
            fields.add(Schema.Field.of(partsField, Schema.arrayOf(SPEECH)));
        }
        if (textField != null) {
            fields.add(Schema.Field.of(textField, Schema.nullableOf(Schema.of(Schema.Type.STRING))));
        }

        boolean hasTranscriptField = partsField != null || textField != null;
        // A macro may still resolve to a field name at runtime, so only fail when
        // neither field is set and neither is a macro.
        if (!hasTranscriptField
            && !this.config.containsMacro("transcriptionPartsField")
            && !this.config.containsMacro("transcriptionTextField")) {
            throw new IllegalArgumentException("Either 'Transcript Parts Field' or 'Transcript Text Field' or both must be specified.");
        }
        return Schema.recordOf("record", fields);
    }

    /**
     * Builds the client settings. A credentials provider is installed only when a service
     * account is configured; otherwise the builder's defaults are used unchanged.
     *
     * @return the settings used to create the speech client
     * @throws IOException if the settings cannot be built
     */
    private SpeechSettings getSettings() throws IOException {
        SpeechSettings.Builder settingsBuilder = SpeechSettings.newBuilder();
        if (Strings.isNullOrEmpty(this.config.getServiceAccount())) {
            return settingsBuilder.build();
        }
        settingsBuilder.setCredentialsProvider(
            () -> GCPUtils.loadServiceAccountCredentials(this.config.getServiceAccount(),
                                                         this.config.isServiceAccountFilePath()));
        return settingsBuilder.build();
    }

    /**
     * Builds the recognition configuration (encoding, sample rate, profanity filter and
     * language) from the plugin configuration and stores it in {@code recognitionConfig}.
     */
    private void setRecognitionConfig() {
        RecognitionConfig.Builder builder = RecognitionConfig.newBuilder();
        builder.setEncoding(this.getAudioEncoding());
        builder.setSampleRateHertz(this.config.getSampleRate().intValue());
        builder.setProfanityFilter(this.config.shouldFilterProfanity());
        builder.setLanguageCode(this.config.getLanguage());
        this.recognitionConfig = builder.build();
    }

    /**
     * Maps the configured encoding name, compared case-insensitively, to the matching
     * {@link RecognitionConfig.AudioEncoding}.
     *
     * @return the matching encoding, or {@code LINEAR16} when the name matches none of
     *     {@code amr}, {@code amr_wb}, {@code flac}, {@code mulaw} or {@code ogg_opus}
     */
    private RecognitionConfig.AudioEncoding getAudioEncoding() {
        String name = this.config.encoding;
        if (name.equalsIgnoreCase("amr")) {
            return RecognitionConfig.AudioEncoding.AMR;
        }
        if (name.equalsIgnoreCase("amr_wb")) {
            return RecognitionConfig.AudioEncoding.AMR_WB;
        }
        if (name.equalsIgnoreCase("flac")) {
            return RecognitionConfig.AudioEncoding.FLAC;
        }
        if (name.equalsIgnoreCase("mulaw")) {
            return RecognitionConfig.AudioEncoding.MULAW;
        }
        if (name.equalsIgnoreCase("ogg_opus")) {
            return RecognitionConfig.AudioEncoding.OGG_OPUS;
        }
        // Any other name falls back to the default encoding.
        return RecognitionConfig.AudioEncoding.LINEAR16;
    }

    /**
     * Copies every field of the input record into the output builder under the same name.
     * Does nothing when the input schema has no field list.
     *
     * @param input record to copy values from
     * @param outputBuilder builder that receives the copied values
     */
    private void copyFields(StructuredRecord input, StructuredRecord.Builder outputBuilder) {
        // Typed list: a raw List cannot be iterated as Schema.Field without a cast.
        List<Schema.Field> fields = input.getSchema().getFields();
        if (fields == null) {
            return;
        }
        for (Schema.Field field : fields) {
            String fieldName = field.getName();
            outputBuilder.set(fieldName, input.get(fieldName));
        }
    }

    /**
     * Appends one {@code speech} record (confidence and transcript) to {@code speechArray}
     * for each recognition alternative, in the order given.
     *
     * @param speechArray list that receives the generated records
     * @param alternatives recognition alternatives to convert
     */
    private void addTranscriptWithConfidence(List<StructuredRecord> speechArray, List<SpeechRecognitionAlternative> alternatives) {
        for (SpeechRecognitionAlternative candidate : alternatives) {
            StructuredRecord part = StructuredRecord.builder(SPEECH)
                .set("confidence", candidate.getConfidence())
                .set("transcript", candidate.getTranscript())
                .build();
            speechArray.add(part);
        }
    }

    /**
     * Configuration of the speech-to-text transform: the audio input field, the audio
     * properties passed to the recognizer, and the names of the output fields that receive
     * the transcription.
     */
    private static class SpeechTransformConfig
    extends GCPConfig {
        private static final String NAME_AUDIOFIELD = "audiofield";
        private static final String NAME_TRANS_PART = "transcriptionPartsField";
        private static final String NAME_TRANS_TEXT = "transcriptionTextField";
        private static final String NAME_RATE = "samplerate";

        @Macro
        @Name(value=NAME_AUDIOFIELD)
        @Description(value="Name of field containing binary audio file data")
        private String audioField;

        @Macro
        @Description(value="Audio encoding of the data sent in the audio message. All encodings support\nonly 1 channel (mono) audio. Only `FLAC` and `WAV` include a header that\ndescribes the bytes of audio that follow the header. The other encodings\nare raw audio bytes with no header.")
        private String encoding;

        @Macro
        @Name(value=NAME_RATE)
        @Description(value="Sample rate in Hertz of the audio data sent in all `RecognitionAudio` messages. Valid values are: 8000-48000. 16000 is optimal. For best results, set the sampling rate of the audio source to 16000 Hz. If that's not possible, use the native sample rate of the audio source (instead of re-sampling).")
        private String sampleRate;

        @Macro
        @Description(value="Whether to mask profanity, replacing all but the initial character in each masked word with asterisks. For example, 'f***'.")
        private String profanity;

        @Macro
        @Description(value="The language of the supplied audio as a [BCP-47](https://www.rfc-editor.org/rfc/bcp/bcp47.txt) language tag. Example: \"en-US\". See [Language Support](https://cloud.google.com/speech/docs/languages) for a list of the currently supported language codes.")
        private String language;

        @Macro
        @Nullable
        @Description(value="The name of the field to store all the different chunks of transcription with all the different possibility and their confidence score. Defaults to 'parts'")
        private String transcriptionPartsField;

        @Macro
        @Nullable
        @Description(value="If a field name is specified then the transcription with highest confidence score will be stored as text.")
        private String transcriptionTextField;

        private SpeechTransformConfig() {
        }

        /**
         * @return the transcription-parts field name, or {@code null} when it is a macro,
         *     unset or empty
         */
        @Nullable
        public String getPartsField() {
            if (this.containsMacro(NAME_TRANS_PART) || this.transcriptionPartsField == null || this.transcriptionPartsField.isEmpty()) {
                return null;
            }
            return this.transcriptionPartsField;
        }

        /**
         * @return the transcription-text field name, or {@code null} when it is a macro,
         *     unset or empty
         */
        @Nullable
        public String getTextField() {
            if (this.containsMacro(NAME_TRANS_TEXT) || this.transcriptionTextField == null || this.transcriptionTextField.isEmpty()) {
                return null;
            }
            return this.transcriptionTextField;
        }

        /**
         * @return the audio field name, or {@code null} when it is a macro
         */
        @Nullable
        public String getAudioField() {
            if (this.containsMacro(NAME_AUDIOFIELD)) {
                return null;
            }
            return this.audioField;
        }

        /**
         * @return {@code true} only when the profanity setting equals "true" ignoring case;
         *     an unset value counts as {@code false}
         */
        public boolean shouldFilterProfanity() {
            // Constant-first comparison so an unset value does not raise a NullPointerException.
            return "true".equalsIgnoreCase(this.profanity);
        }

        /**
         * @return the configured language tag, or "en-US" when none is set
         */
        public String getLanguage() {
            return this.language == null ? "en-US" : this.language;
        }

        /**
         * Parses the configured sample rate.
         *
         * @return the sample rate, or {@code null} when it is a macro
         * @throws IllegalArgumentException if the value is not a valid integer; the
         *     underlying {@link NumberFormatException} is attached as the cause
         */
        @Nullable
        public Integer getSampleRate() {
            if (this.containsMacro(NAME_RATE)) {
                return null;
            }
            try {
                return Integer.parseInt(this.sampleRate);
            }
            catch (NumberFormatException e) {
                throw new IllegalArgumentException("Sample rate should be a valid number", e);
            }
        }

        /**
         * Validates the configuration and throws through the collector if any failure was
         * recorded. Schema-dependent checks run only when an input schema is available; the
         * sample-rate check always runs.
         *
         * @param inputSchema schema of the incoming records, or {@code null} if unknown
         * @param collector collector that accumulates validation failures
         */
        private void validate(@Nullable Schema inputSchema, FailureCollector collector) {
            if (inputSchema != null) {
                this.validateAudioField(inputSchema, collector);
                this.validateOutputFields(inputSchema, collector);
            }
            this.validateSampleRate(collector);
            collector.getOrThrowException();
        }

        /** Checks that the audio field exists in the input schema and is of plain type bytes. */
        private void validateAudioField(Schema inputSchema, FailureCollector collector) {
            String audioFieldName = this.getAudioField();
            if (audioFieldName == null) {
                return;
            }
            Schema.Field field = inputSchema.getField(audioFieldName);
            if (field == null) {
                collector.addFailure(String.format("Field '%s' does not exist in the input schema.", audioFieldName), "Change audio field to be one of the schema fields.").withConfigProperty(NAME_AUDIOFIELD);
                return;
            }
            Schema fieldSchema = field.getSchema();
            if (fieldSchema.isNullable()) {
                fieldSchema = fieldSchema.getNonNullable();
            }
            if (fieldSchema.getLogicalType() != null || fieldSchema.getType() != Schema.Type.BYTES) {
                collector.addFailure(String.format("Field '%s' is of unsupported type '%s'.", audioFieldName, fieldSchema.getDisplayName()), "Ensure it is of type 'bytes'.").withConfigProperty(NAME_AUDIOFIELD).withInputSchemaField(audioFieldName);
            }
        }

        /**
         * Checks that at least one transcription field is configured and that neither
         * collides with a field already present in the input schema.
         */
        private void validateOutputFields(Schema inputSchema, FailureCollector collector) {
            String textField = this.getTextField();
            String partsField = this.getPartsField();
            if (!this.containsMacro(NAME_TRANS_TEXT) && textField == null && !this.containsMacro(NAME_TRANS_PART) && partsField == null) {
                collector.addFailure("'Transcript Parts Field' or 'Transcript Text Field' are not provided.", "Provide atleast one of them.").withConfigProperty(NAME_TRANS_PART).withConfigProperty(NAME_TRANS_TEXT);
            }

            // A schema without a field list has no names to collide with.
            List<Schema.Field> inputFields = inputSchema.getFields();
            if (inputFields == null) {
                return;
            }
            Set<String> fieldNames = inputFields.stream().map(Schema.Field::getName).collect(Collectors.toSet());
            if (textField != null && fieldNames.contains(textField)) {
                collector.addFailure(String.format("Transcript text field '%s' already exists in the input schema.", textField), "Change the field name.").withConfigProperty(NAME_TRANS_TEXT).withInputSchemaField(textField);
            }
            if (partsField != null && fieldNames.contains(partsField)) {
                collector.addFailure(String.format("Transcript parts field '%s' already exists in the input schema.", partsField), "Change the field name.").withConfigProperty(NAME_TRANS_PART).withInputSchemaField(partsField);
            }
        }

        /** Checks that the sample rate, when not a macro, is an integer between 8000 and 48000. */
        private void validateSampleRate(FailureCollector collector) {
            try {
                Integer rate = this.getSampleRate();
                if (rate != null && (rate < 8000 || rate > 48000)) {
                    collector.addFailure("Invalid sample rate.", "Ensure the value is between 8000 and 48000.").withConfigProperty(NAME_RATE);
                }
            }
            catch (IllegalArgumentException e) {
                // A non-numeric value is reported the same way as an out-of-range one.
                collector.addFailure("Invalid sample rate.", "Ensure the value is between 8000 and 48000.").withConfigProperty(NAME_RATE);
            }
        }
    }
}

