/*
 * Decompiled with CFR 0.152.
 */
package ai.djl.training.metrics;

import ai.djl.modality.cv.MultiBoxTarget;
import ai.djl.ndarray.NDArray;
import ai.djl.ndarray.NDList;
import ai.djl.training.metrics.TrainingMetric;

/**
 * Training metric that tracks the masked absolute difference between bounding-box targets and
 * bounding-box predictions, averaged over the number of target elements seen so far.
 *
 * <p>Targets and masks are derived on every {@link #update(NDList, NDList)} call from the
 * anchors, the labels and the class predictions via a default-configured {@link MultiBoxTarget}.
 *
 * <p>Running totals are kept in wide accumulators ({@code double} for the error sum, {@code long}
 * for the element count) so that long training runs do not lose precision; a {@code float}
 * counter cannot represent every integer above 2^24 and would start dropping increments.
 */
public class BoundingBoxError extends TrainingMetric {
    /** Sum of masked absolute errors accumulated since the last {@link #reset()}. */
    private double ssdBoxPredictionError;

    /** Number of bounding-box target elements accumulated since the last {@link #reset()}. */
    private long numInstances;

    /** Produces bounding-box targets and masks from anchors, labels and class predictions. */
    private final MultiBoxTarget multiBoxTarget = new MultiBoxTarget.Builder().build();

    /**
     * Creates the metric.
     *
     * @param name the metric name, forwarded to {@link TrainingMetric}
     */
    public BoundingBoxError(String name) {
        super(name);
    }

    /**
     * Adds one batch to the running totals.
     *
     * @param labels the ground-truth labels; only the head element is used
     * @param predictions the network outputs, read positionally: index 0 is used as the anchors,
     *     index 1 as the class predictions and index 2 as the bounding-box predictions
     */
    @Override
    public void update(NDList labels, NDList predictions) {
        NDArray anchors = (NDArray) predictions.get(0);
        NDArray classPredictions = (NDArray) predictions.get(1);
        NDArray boundingBoxPredictions = (NDArray) predictions.get(2);

        // The class predictions have their last two axes swapped before being handed to
        // MultiBoxTarget. NOTE(review): presumably to match the layout it expects - confirm.
        NDList targets =
                this.multiBoxTarget.target(
                        new NDList(anchors, labels.head(), classPredictions.transpose(0, 2, 1)));
        NDArray boundingBoxLabels = (NDArray) targets.get(0);
        NDArray boundingBoxMasks = (NDArray) targets.get(1);

        // Masked L1 distance between targets and predictions, reduced to a single scalar.
        NDArray boundingBoxError =
                boundingBoxLabels.sub(boundingBoxPredictions).mul(boundingBoxMasks).abs().sum();

        this.ssdBoxPredictionError += boundingBoxError.getFloat(new long[0]);
        // Keep the count integral: size() is already a long, no need to round it through float.
        this.numInstances += boundingBoxLabels.size();
    }

    /**
     * Creates a fresh metric with the same name; accumulated totals are not copied.
     *
     * @return a new {@code BoundingBoxError} with zeroed totals
     */
    @Override
    public TrainingMetric duplicate() {
        return new BoundingBoxError(this.getName());
    }

    /** Clears the accumulated error sum and element count. */
    @Override
    public void reset() {
        this.ssdBoxPredictionError = 0.0;
        this.numInstances = 0L;
    }

    /**
     * Returns the accumulated error divided by the accumulated element count.
     *
     * @return the mean masked absolute error, or {@code NaN} if nothing has been accumulated
     *     since construction or the last {@link #reset()}
     */
    @Override
    public float getValue() {
        // Divide in double precision and narrow once; 0.0 / 0 still yields NaN as before.
        return (float) (this.ssdBoxPredictionError / this.numInstances);
    }
}

