package ai.xpl.android.client;

import android.content.Context;
import android.graphics.Bitmap;
import android.net.wifi.WifiInfo;
import android.net.wifi.WifiManager;
import android.util.Log;

import com.android.volley.Request;
import com.android.volley.RequestQueue;
import com.android.volley.Response;
import com.android.volley.VolleyError;

import org.pytorch.Device;
import org.pytorch.IValue;
import org.pytorch.LiteModuleLoader;
import org.pytorch.Module;
import org.pytorch.Tensor;
import org.pytorch.torchvision.TensorImageUtils;

import java.io.ByteArrayOutputStream;
import java.io.File;
import java.io.FileNotFoundException;
import java.io.FileOutputStream;
import java.io.IOException;
import java.util.ArrayList;
import java.util.HashMap;
import java.util.List;
import java.util.Map;
import java.util.Objects;
import java.util.UUID;


/**
 * Runs a task's on-device model components and manages their download and loading.
 *
 * <p>A {@code Task} wraps the {@link TaskData} describing a task, indexes its concepts by id,
 * loads the PyTorch Lite ({@code .ptl}) component files stored in the app's private component
 * directory into {@link #modules}, and runs inference on bitmaps via {@link #execute(Bitmap)}.
 */
public class Task {
    /** Per-channel mean of zero: pixel values are not mean-shifted before inference. */
    static final float[] NO_MEAN_RGB = new float[] {0.0f, 0.0f, 0.0f};
    /** Per-channel std of one: pixel values are not divided by a std before inference. */
    static final float[] NO_STD_RGB = new float[] {1.0f, 1.0f, 1.0f};

    private static final String TAG = "Task";

    /** Width, in pixels, that input bitmaps are scaled to before inference. */
    private static final int INPUT_WIDTH = 320;
    /** Height, in pixels, that input bitmaps are scaled to before inference. */
    private static final int INPUT_HEIGHT = 512;

    public TaskData data;
    /** Device identifier attached to uploaded data points; null if it could not be read. */
    private final String deviceFingerprints;
    private final Context context;
    private final RequestQueue requestQueue;
    /** Concepts of {@link #data}, indexed by concept id. */
    private final Map<String, Concept> concepts = new HashMap<>();

    /** Loaded model components keyed by component name; populated by {@link #loadComponents}. */
    public Map<String, Module> modules = new HashMap<>();

    /**
     * Creates a task and indexes its concepts by id.
     *
     * @param taskData     task description; must not be null
     * @param context      Android context, used to read the device fingerprint; must not be null
     * @param requestQueue Volley queue used to upload sampled data points
     */
    public Task(TaskData taskData, Context context, RequestQueue requestQueue) {
        this.data = Objects.requireNonNull(taskData, "taskData");
        if (taskData.concepts != null) {
            for (Concept concept : taskData.concepts) {
                concepts.put(concept.conceptId, concept);
            }
        }

        this.context = Objects.requireNonNull(context, "context");
        this.requestQueue = requestQueue;

        this.deviceFingerprints = getMacAddress();
    }

    /**
     * Runs the task's head component on an image and converts its output to predictions.
     *
     * <p>The bitmap is scaled to {@value #INPUT_WIDTH}x{@value #INPUT_HEIGHT}, converted to a
     * float32 tensor using {@link #NO_MEAN_RGB}/{@link #NO_STD_RGB}, and passed to the head
     * module as the dict {@code {"image": tensor}}. The module must return a list of dicts with
     * the keys {@code concept_id}, {@code text}, {@code center_x}, {@code center_y},
     * {@code half_width} and {@code half_height}.
     *
     * <p>Timing information is stored in the first prediction's {@code comment}, if there is one.
     *
     * @param bitmap image to run inference on
     * @return predicted instances in model output order; empty if the model predicted none
     * @throws IllegalStateException if the head component has not been loaded with
     *     {@link #loadComponents(Context)}, or a prediction lacks a required key
     */
    public List<PredictedInstance> execute(Bitmap bitmap) {
        List<PredictedInstance> result = new ArrayList<>();
        StringBuilder performanceMetrics = new StringBuilder();
        // 1. Prepare the input.
        long start = System.currentTimeMillis();

        Bitmap resizedBitmap = Bitmap.createScaledBitmap(bitmap, INPUT_WIDTH, INPUT_HEIGHT, true);
        Tensor inputTensor = TensorImageUtils.bitmapToFloat32Tensor(resizedBitmap, NO_MEAN_RGB, NO_STD_RGB);

        Map<String, IValue> input = new HashMap<>();
        input.put("image", IValue.from(inputTensor));
        IValue inputs = IValue.dictStringKeyFrom(input);

        String headName = getHeadComponentName();
        Module head = this.modules.get(headName);
        if (head == null) {
            throw new IllegalStateException(
                    "Component '" + headName + "' is not loaded; call loadComponents() first.");
        }

        // Disabled: split-model path that runs a shared tail before the head.
//        Module tail = this.modules.get(getTailComponentName());
//
//        performanceMetrics.append("1. Preparing input " + Long.toString(System.currentTimeMillis() - start) + "ms\n");
//
//        // 2.
//        start = System.currentTimeMillis();
//
//        IValue output = tail.forward(inputs);
//
//        performanceMetrics.append("2. tail.forward " + Long.toString(System.currentTimeMillis() - start) + "ms\n");

        // 3. Run the head and decode its predictions.
        start = System.currentTimeMillis();

        IValue output = head.forward(inputs);
        for (IValue prediction : output.toList()) {
            result.add(toPredictedInstance(prediction.toDictStringKey()));
        }

        performanceMetrics.append("3. head.forward ")
                .append(System.currentTimeMillis() - start)
                .append("ms\n");

        // 4. Optionally upload the input as a training sample (currently disabled).
        start = System.currentTimeMillis();

//        Map<String, IValue> dictOutput = output.toDictStringKey();
//        Tensor locationsTensor = dictOutput.get("pred_" + data.name+ "_locations").toTensor();
//        Tensor labelTensor = dictOutput.get("pred_" + data.name + "_label").toTensor();
//        Tensor objectivenessTensor = dictOutput.get("pred_" + data.name + "_objectiveness").toTensor();

//        IValue informativeness = dictOutput.get("pred_" + data.name + "_informativeness");
//        if (locationsTensor != null) { // test stub
////        if (informativeness != null) {
////            if (this.shouldSample(informativeness.toTensor())){
//            if (this.shouldSample(locationsTensor)){
//                this.sendSampleDataPoint(bitmap, locationsTensor, labelTensor, objectivenessTensor);
//            }
//        }

        performanceMetrics.append("4.Taking sample ")
                .append(System.currentTimeMillis() - start)
                .append("ms\n");

        if (!result.isEmpty()) {
            result.get(0).comment = performanceMetrics.toString();
        }

        return result;
    }

    /**
     * Converts one prediction dict produced by the head module into a {@link PredictedInstance}.
     *
     * @throws IllegalStateException if a required key is missing from {@code predictionDict}
     */
    private PredictedInstance toPredictedInstance(Map<String, IValue> predictionDict) {
        PredictedInstance instance = new PredictedInstance();
        instance.conceptId = requireField(predictionDict, "concept_id").toStr();
        instance.text = requireField(predictionDict, "text").toStr();

        // An id missing from the task's concept list yields a null display name rather than
        // failing the whole inference call.
        Concept concept = this.concepts.get(instance.conceptId);
        instance.conceptDisplayName = concept != null ? concept.displayName : null;

        instance.location = new Location();
        instance.location.centerX = requireField(predictionDict, "center_x").toDouble();
        instance.location.centerY = requireField(predictionDict, "center_y").toDouble();
        instance.location.halfWidth = requireField(predictionDict, "half_width").toDouble();
        instance.location.halfHeight = requireField(predictionDict, "half_height").toDouble();
        return instance;
    }

    /** Returns {@code dict.get(key)}, failing with a descriptive message when it is absent. */
    private static IValue requireField(Map<String, IValue> dict, String key) {
        IValue value = dict.get(key);
        if (value == null) {
            throw new IllegalStateException("Model prediction is missing key '" + key + "'");
        }
        return value;
    }

    /**
     * Text inference entry point. Not implemented yet.
     *
     * @param text input text (currently ignored)
     * @return a new, empty, mutable list
     */
    public List<PredictedInstance> execute(String text) {
        return new ArrayList<>();
    }

    /** Decides whether the current input should be uploaded as a sample. Stub: always true. */
    private boolean shouldSample(Tensor informativeness) {
        return true;
    }

    /**
     * Uploads a data point containing {@code imageSample}.
     *
     * <p>Currently a test stub. The data item uses placeholder concept, text, confidence and
     * location values, and {@code location}, {@code label} and {@code confidence} are ignored.
     * The binary payload is a one-byte placeholder rather than the PNG-encoded image.
     */
    private void sendSampleDataPoint(Bitmap imageSample, Tensor location, Tensor label, Tensor confidence) {
        DataPoint dataPoint = new DataPoint();
        dataPoint.dataPointId = compactUuid();
        dataPoint.taskId = this.data.taskId;
        dataPoint.dataItems = new ArrayList<>();
        dataPoint.collectedByDeviceFingerprintId = this.deviceFingerprints;
        dataPoint.fileUris = new ArrayList<>();

        DataItem dataItem = new DataItem();
        dataItem.conceptId = "test_mobile_concept";
        dataItem.instanceId = compactUuid();
        dataItem.predictorType = "model";
        dataItem.predictorId = this.data.model.modelId;
        dataItem.inputSet = "train";
        dataItem.logInformativeness = 0.5f;
        dataItem.logitConfidence = 0.5f;
        dataItem.location = new HashMap<>();
        dataItem.text = "Test text";
        dataItem.value = 0.12345f;

        dataItem.location.put("center_x", 0.5f);
        dataItem.location.put("center_y", 0.5f);
        dataItem.location.put("half_width", 0.5f);
        dataItem.location.put("half_height", 0.5f);

        dataPoint.dataItems.add(dataItem);

        dataPoint.bitmaps = new HashMap<>();
        dataPoint.bitmaps.put("input.png", imageSample);

        dataPoint.binaries = new HashMap<>();
        // TODO: send makePng(imageSample) instead of this placeholder byte.
        dataPoint.binaries.put("input.png", new byte[] {1});

        DataApiClient.postDataPoint(this.requestQueue, dataPoint, this.data.taskApiKey);
    }

    /** Returns a random UUID as 32 hex characters, without dashes. */
    private static String compactUuid() {
        return UUID.randomUUID().toString().replace("-", "");
    }

    /**
     * Reads the Wi-Fi MAC address used as the device fingerprint.
     *
     * <p>NOTE(review): since Android 6.0, {@code WifiInfo.getMacAddress()} returns the constant
     * {@code 02:00:00:00:00:00} to apps, so this is not a unique identifier on modern devices.
     *
     * @return the MAC address, or null if the Wi-Fi service or connection info is unavailable
     */
    private String getMacAddress() {
        // Use the application context: on pre-N Android, obtaining WifiManager from an
        // Activity context can leak that Activity.
        Context appContext = this.context.getApplicationContext();
        Context serviceContext = appContext != null ? appContext : this.context;
        WifiManager manager = (WifiManager) serviceContext.getSystemService(Context.WIFI_SERVICE);
        if (manager == null) {
            return null;
        }
        WifiInfo info = manager.getConnectionInfo();
        return info != null ? info.getMacAddress() : null;
    }

    /** Encodes {@code bitmap} as PNG bytes. */
    private byte[] makePng(Bitmap bitmap) {
        ByteArrayOutputStream byteArrayOutputStream = new ByteArrayOutputStream();
        bitmap.compress(Bitmap.CompressFormat.PNG, 100, byteArrayOutputStream);
        return byteArrayOutputStream.toByteArray();
    }

    /**
     * Loads every model component that is not already present in {@link #modules}.
     *
     * @param context context used to locate the component directory
     * @throws FileNotFoundException if a component's file has not been downloaded to the device
     */
    public void loadComponents(Context context) throws Exception {
        File dir = getComponentDir(context);
        for (ModelComponent component : this.data.model.components.values()) {
            if (this.modules.get(component.name) != null) {
                continue; // already loaded
            }

            String filename = getComponentFileName(component.name);
            File componentFile = new File(dir, filename);
            if (!componentFile.exists()) {
                throw new FileNotFoundException("Component file '" + filename + "' is not present on device.");
            }

            Module module = LiteModuleLoader.load(componentFile.getAbsolutePath(), new HashMap<>(), Device.CPU);
            modules.put(component.name, module);
        }
    }

    /**
     * Asynchronously downloads a model component and stores it in the component directory.
     *
     * <p>Failures are logged and not reported to the caller. The file appears under its final name
     * only after the download has been fully written, so a failed download never leaves a
     * truncated component that {@link #componentFileExists} would report as present.
     *
     * @param context       context used to locate the component directory
     * @param requestQueue  Volley queue the download request is added to
     * @param componentName key of the component in the task's model component map
     * @throws NullPointerException if {@code componentName} is not a known component
     */
    public void downloadModelComponent(Context context, RequestQueue requestQueue, String componentName) {
        ModelComponent component = Objects.requireNonNull(
                this.data.model.components.get(componentName),
                "Unknown model component: " + componentName);
        String url = component.urlForDownload;

        InputStreamVolleyRequest request = new InputStreamVolleyRequest(Request.Method.GET, url,
                new Response.Listener<byte[]>() {
                    @Override
                    public void onResponse(byte[] response) {
                        if (response == null) {
                            Log.w(TAG, "Empty response while downloading component '" + componentName + "'");
                            return;
                        }
                        try {
                            writeComponentFile(context, componentName, response);
                        } catch (Exception e) {
                            // Callback boundary: report the failure without crashing the app.
                            Log.e("KEY_ERROR", "UNABLE TO DOWNLOAD FILE", e);
                        }
                    }
                },
                new Response.ErrorListener() {
                    @Override
                    public void onErrorResponse(VolleyError error) {
                        Log.e(TAG, "Failed to download component '" + componentName + "'", error);
                    }
                },
                null);
        requestQueue.add(request);
    }

    /**
     * Writes {@code bytes} to the component's file.
     *
     * <p>The bytes go to a temporary {@code .part} file first, which is then renamed into place.
     *
     * @throws IOException if writing or renaming fails; the temporary file is removed in that case
     */
    private void writeComponentFile(Context context, String componentName, byte[] bytes) throws IOException {
        File dir = getComponentDir(context);
        File componentFile = new File(dir, getComponentFileName(componentName));
        File partialFile = new File(dir, componentFile.getName() + ".part");
        try {
            try (FileOutputStream outputStream = new FileOutputStream(partialFile)) {
                outputStream.write(bytes);
            }
            if (!partialFile.renameTo(componentFile)) {
                throw new IOException("Could not move '" + partialFile + "' to '" + componentFile + "'");
            }
        } finally {
            // After a successful rename the partial file no longer exists; otherwise remove
            // the incomplete download.
            if (partialFile.exists() && !partialFile.delete()) {
                Log.w(TAG, "Could not delete partial download '" + partialFile + "'");
            }
        }
    }

    /** Returns whether the file for {@code componentName} is present in the component directory. */
    public boolean componentFileExists(Context context, String componentName) {
        String componentFileName = getComponentFileName(componentName);
        return fileExists(context, componentFileName);
    }

    /** Returns whether {@code filename} exists in the component directory. */
    private boolean fileExists(Context context, String filename) {
        return new File(getComponentDir(context), filename).exists();
    }

    /**
     * Returns the app-private directory holding model component files.
     *
     * <p>Downloads, existence checks and loading all go through this method so they agree on
     * the location.
     */
    private static File getComponentDir(Context context) {
        return context.getDir(Config.ROOT_DIR, Context.MODE_PRIVATE);
    }

    /**
     * Builds the on-device file name for a component.
     *
     * <p>The format is {@code <taskId>__v<modelVersion>__<componentName>}.
     *
     * @throws NullPointerException if {@code componentName} is not a known component
     */
    public String getComponentFileName(String componentName) {
        String name = Objects.requireNonNull(
                this.data.model.components.get(componentName),
                "Unknown model component: " + componentName).name;
        return this.data.taskId + "__" + "v" + this.data.model.version + "__" + name;
    }

    /** Name of the shared representation ("tail") component for this task's modality. */
    private String getTailComponentName() {
        return this.data.modality + "_rep.ptl";
    }

    /** Name of this task's head component, the module run by {@link #execute(Bitmap)}. */
    private String getHeadComponentName() {
        return this.data.name + "_head.ptl";
    }

    /** Name of the component file for a single, unsplit model of this task. */
    private String getModelSingleComponentName() {
        return this.data.name + ".ptl";
    }
}
