#ifndef ML_TRANSFORMERS_AUDIO_NOISE_SUPPRESSION_ML_TRANSFORMERS_AUDIO_NOISE_SUPPRESSION_H_
#define ML_TRANSFORMERS_AUDIO_NOISE_SUPPRESSION_ML_TRANSFORMERS_AUDIO_NOISE_SUPPRESSION_H_

#include <complex>
#include <cstddef>
#include <cstdint>
#include <functional>
#include <memory>
#include <string>
#include <utility>
#include <vector>

// VONAGE_EXPORT decorates the public classes of this library.
// On Windows it becomes __declspec(dllexport) when VONAGE_LIBRARY_IMPL is
// defined (presumably while compiling the library itself) and
// __declspec(dllimport) otherwise; on all other platforms it expands to
// nothing.
#ifdef _WIN32
#  if defined(VONAGE_LIBRARY_IMPL)
#    define VONAGE_EXPORT __declspec(dllexport)
#  else
#    define VONAGE_EXPORT __declspec(dllimport)
#  endif
#else
#  define VONAGE_EXPORT
#endif

namespace vonage {

struct NSMLHelper;
class ResampleInterface;

/// Abstract contract for an audio transformer that consumes and produces
/// 16-bit PCM buffers.
///
/// Every operation reports its outcome as a uint8_t. NOTE(review): presumably
/// this is a ReturnCode value cast to uint8_t — confirm in implementations.
/// Details of the most recent warning / error are exposed as (code, message)
/// pairs through GetLastWarning() / GetLastError().
///
/// The declaration order of the virtual functions below is part of the
/// exported ABI (vtable layout) and must not be rearranged.
class VONAGE_EXPORT AudioMLTransformerInterface {
public:
    /// Coarse outcome of a call.
    enum class ReturnCode { kSuccess = 0, kWarning = 1, kError = 2 };

    /// Warning detail codes (presumably stored in GetLastWarning().first).
    enum class WarningCode { kOkWarning = 0, kAudioOptionsWarning = 1 };

    /// Error detail codes (presumably stored in GetLastError().first).
    enum class ErrorCode { kOkError = 0, kInitializedError = 1, kProcessError = 2 };

    virtual ~AudioMLTransformerInterface() = default;

    /// Reports, via the out-parameters, the channel count and sample rate the
    /// transformer wants its input in, and whether it can resample internally.
    virtual uint8_t GetRemixAndResampleConfig(
        size_t& num_channels,
        int& sample_rate_hz,
        bool& internal_resample_supported) const = 0;

    /// Feeds one input frame and writes transformed samples to @p out_data.
    /// @param out_data_len     receives the number of samples written.
    /// @param out_data_max_len capacity of @p out_data.
    virtual uint8_t Transform(
        const int16_t* in_data, size_t samples_per_channel, size_t num_channels,
        int sample_rate_hz,
        int16_t* out_data, size_t* out_data_len, size_t out_data_max_len) = 0;

    /// Produces output without new input. NOTE(review): presumably drains
    /// samples still buffered inside the transformer — confirm.
    virtual uint8_t CompleteTransform(
        size_t samples_per_channel, size_t num_channels, int sample_rate_hz,
        int16_t* out_data, size_t* out_data_len, size_t out_data_max_len) = 0;

    /// Informs the transformer which audio-processing options are enabled.
    virtual uint8_t SetAudioOptions(
        bool echo_cancellation, bool auto_gain_control, bool noise_suppression,
        bool stereo_swapping, bool highpass_filter) = 0;

    /// (code, message) of the most recent warning.
    virtual const std::pair<uint8_t, std::string>& GetLastWarning() const = 0;

    /// (code, message) of the most recent error.
    virtual const std::pair<uint8_t, std::string>& GetLastError() const = 0;
};

/// An ML model supplied as an in-memory byte buffer rather than a file path
/// (see the NSAudioMLTransformer constructor overloads).
///
/// NOTE(review): `content` is a raw pointer and nothing here owns or frees it;
/// confirm who owns the buffer and how long it must remain valid.
struct NSAudioMLTransformerModelFileContent {
    size_t size;       // Byte count — presumably the length of `content`.
    uint8_t* content;  // First byte of the model data.
};

/// Noise-suppression implementation of AudioMLTransformerInterface backed by
/// one or more ML models.
///
/// Models are provided either as file paths or as in-memory buffers.
/// NOTE(review): presumably Init() must be called before Transform() /
/// CompleteTransform(); the `initialized_` member suggests this is checked —
/// confirm in the implementation.
class VONAGE_EXPORT NSAudioMLTransformer final : public AudioMLTransformerInterface {
public:
    /// Static (class-wide) switch. NOTE(review): the name suggests it forces
    /// the C++ resampler implementation to be used — confirm.
    static void SetForceCPPResampler(bool val);

    /// Callback receiving a block of 16-bit samples together with its sample
    /// rate and channel count.
    using WriteWavSamplesCallback = std::function<void(int sample_rate_hz, size_t num_channels, const int16_t* samples, size_t num_samples)>;

    /// Creates a transformer whose models are loaded from the given paths.
    NSAudioMLTransformer(const std::vector<std::string>& ml_model_file_paths);
    /// Creates a transformer whose models are read from caller-owned memory.
    /// NOTE(review): each entry carries a raw pointer; confirm it must stay
    /// valid for as long as this object uses it.
    NSAudioMLTransformer(const std::vector<NSAudioMLTransformerModelFileContent>& ml_model_file_contents);
    virtual ~NSAudioMLTransformer();

    // Non-copyable: the helper and resamplers are held by std::unique_ptr.
    // The implicit copy operations are already deleted for that reason;
    // declaring them explicitly documents the intent.
    NSAudioMLTransformer(const NSAudioMLTransformer&) = delete;
    NSAudioMLTransformer& operator=(const NSAudioMLTransformer&) = delete;

    /// Prepares the transformer for use.
    /// @param xnnpack_delegate_options_num_threads thread count, presumably
    ///        forwarded to the XNNPACK delegate options.
    /// @return status as uint8_t (presumably an
    ///         AudioMLTransformerInterface::ReturnCode value).
    uint8_t Init(int32_t xnnpack_delegate_options_num_threads);

    // AudioMLTransformerInterface

    uint8_t GetRemixAndResampleConfig(size_t& num_channels,
                                      int& sample_rate_hz,
                                      bool& internal_resample_supported) const override;
    uint8_t Transform(const int16_t* in_data,
                      size_t samples_per_channel,
                      size_t num_channels,
                      int sample_rate_hz,
                      int16_t* out_data,
                      size_t* out_data_len,
                      size_t out_data_max_len) override;
    uint8_t CompleteTransform(size_t samples_per_channel,
                              size_t num_channels,
                              int sample_rate_hz,
                              int16_t* out_data,
                              size_t* out_data_len,
                              size_t out_data_max_len) override;
    uint8_t SetAudioOptions(bool echo_cancellation,
                            bool auto_gain_control,
                            bool noise_suppression,
                            bool stereo_swapping,
                            bool highpass_filter) override;
    const std::pair<uint8_t, std::string>& GetLastWarning() const override;
    const std::pair<uint8_t, std::string>& GetLastError() const override;

    /// Installs a sample sink. NOTE(review): presumably used for debug
    /// dumping of audio to WAV — confirm.
    void SetWriteWavSamplesCallback(WriteWavSamplesCallback callback);
    /// NOTE(review): presumably the mean of `frame_transform_latency_ns_`.
    float GetAverageLatencyNs();

private:
    NSAudioMLTransformer();

    std::unique_ptr<int16_t[]> AllocateData(size_t data_length);
    void GetPhase(std::vector<std::complex<double>> fft_res, float* in_mag, float* in_phase, int count);
    uint8_t RunMLInference();
    uint8_t ProcessAudio();
    uint8_t InitModels();
    uint8_t InitModelsFromPath();
    uint8_t InitModelsFromContent();
    uint8_t TransformInternal(const int16_t* in_data,
                              size_t samples_per_channel,
                              size_t num_channels,
                              int sample_rate_hz);

private:
    // Member order is intentionally unchanged (layout / initialization order).
    std::pair<uint8_t, std::string> last_warning_ = {static_cast<uint8_t>(AudioMLTransformerInterface::WarningCode::kOkWarning), {""}};
    std::pair<uint8_t, std::string> last_error_ = {static_cast<uint8_t>(AudioMLTransformerInterface::ErrorCode::kOkError), {""}};
    std::vector<std::string> ml_model_file_paths_;
    std::vector<NSAudioMLTransformerModelFileContent> ml_model_file_contents_;
    std::unique_ptr<NSMLHelper> ns_ml_helper_;
    bool initialized_ = false;
    std::vector<float> audio_input_container_;
    std::vector<float> audio_output_container_;
    std::vector<int16_t> clean_audio_buffer_;
    std::vector<int16_t> clean_and_resample_audio_buffer_;
    std::vector<float> frame_transform_latency_ns_;
    WriteWavSamplesCallback write_wav_samples_callback_;
    std::unique_ptr<ResampleInterface> in_resampler_;
    std::unique_ptr<ResampleInterface> out_resampler_;
    std::unique_ptr<int16_t[]> leftover_buffer_;
    // Default-initialized to 0 so the count is never indeterminate, even on a
    // constructor path that does not assign it; a constructor's
    // mem-initializer still takes precedence.
    size_t leftover_size_ = 0;
};

} // namespace vonage

#endif // ML_TRANSFORMERS_AUDIO_NOISE_SUPPRESSION_ML_TRANSFORMERS_AUDIO_NOISE_SUPPRESSION_H_
