/*
 * Decompiled with CFR 0.152.
 */
package boofcv.deepboof;

import boofcv.abst.scene.ImageClassifier;
import boofcv.deepboof.ClipAndReduce;
import boofcv.deepboof.DataManipulationOps;
import boofcv.struct.image.GrayF32;
import boofcv.struct.image.ImageType;
import boofcv.struct.image.Planar;
import deepboof.Function;
import deepboof.Tensor;
import deepboof.graph.FunctionSequence;
import deepboof.tensors.Tensor_F32;
import java.util.ArrayList;
import java.util.Collections;
import java.util.Comparator;
import java.util.List;
import org.ddogleg.struct.DogArray;

/**
 * Base implementation of {@link ImageClassifier} for DeepBoof networks whose input is a square,
 * three-band {@link Planar} {@link GrayF32} image of side {@link #imageSize}.
 *
 * <p>This class does not assign {@link #network}, allocate {@link #tensorOutput}, or fill
 * {@link #categories}; subclasses are expected to do so before {@link #classify} is called.
 *
 * <p>Internal buffers ({@link #imageRgb}, {@link #tensorInput}, {@link #categoryScores}) are reused
 * on every call, so each call overwrites the previous results and a single instance should not be
 * used from multiple threads at once.
 */
public abstract class BaseImageClassifier implements ImageClassifier<Planar<GrayF32>> {
    /** Network evaluated on {@link #tensorInput}; assigned by subclasses. */
    protected FunctionSequence<Tensor_F32, Function<Tensor_F32>> network;
    /** Category names; populated by subclasses. */
    protected List<String> categories = new ArrayList<>();
    /** Input image type accepted by {@link #classify}: three-band planar {@link GrayF32}. */
    protected ImageType<Planar<GrayF32>> imageType = ImageType.pl(3, GrayF32.class);
    /** Used to fit input images whose size differs from {@link #imageSize} into {@link #imageRgb}. */
    protected ClipAndReduce<Planar<GrayF32>> massage = new ClipAndReduce<>(true, this.imageType);
    /** Width and height, in pixels, of the image fed to the network. */
    protected int imageSize;
    /** Reusable buffer holding the most recent preprocessed input image. */
    protected Planar<GrayF32> imageRgb;
    /** Reusable network input tensor with shape (1, 3, imageSize, imageSize). */
    protected Tensor_F32 tensorInput;
    /** Network output tensor, read as (0, category); allocated by subclasses. */
    protected Tensor_F32 tensorOutput;
    /** Per-category scores from the most recent call, sorted highest first. */
    protected DogArray<ImageClassifier.Score> categoryScores = new DogArray<>(ImageClassifier.Score::new);
    /** Index of the highest scoring category from the most recent call, or -1 if none. */
    protected int categoryBest;

    /**
     * Orders scores from highest to lowest, with NaN scores last.
     *
     * <p>NaN is handled explicitly because treating it as "equal" to every value would break
     * transitivity, violating the {@link Comparator} contract that {@link Collections#sort} relies
     * on.
     */
    Comparator<ImageClassifier.Score> comparator = (o1, o2) -> {
        boolean nan1 = Double.isNaN(o1.score);
        boolean nan2 = Double.isNaN(o2.score);
        if (nan1 || nan2) {
            // false < true, so real-valued scores sort ahead of NaN scores
            return Boolean.compare(nan1, nan2);
        }
        // Arguments swapped to obtain descending order
        return Double.compare(o2.score, o1.score);
    };

    /**
     * Allocates the input image and tensor buffers.
     *
     * @param imageSize width and height, in pixels, of the image the network consumes
     */
    protected BaseImageClassifier(int imageSize) {
        this.imageSize = imageSize;
        this.imageRgb = new Planar<>(GrayF32.class, imageSize, imageSize, 3);
        this.tensorInput = new Tensor_F32(new int[]{1, 3, imageSize, imageSize});
    }

    @Override
    public ImageType<Planar<GrayF32>> getInputType() {
        return this.imageType;
    }

    /**
     * Preprocesses the image, converts it into {@link #tensorInput} and runs the network.
     * Results are available through {@link #getBestResult()} and {@link #getAllResults()}.
     *
     * @param image input image
     * @throws IllegalArgumentException if the image is smaller than {@link #imageSize} in either
     *     dimension
     */
    @Override
    public void classify(Planar<GrayF32> image) {
        Planar<GrayF32> prepared = this.preprocess(image);
        DataManipulationOps.imageToTensor(prepared, this.tensorInput, 0);
        this.innerProcess(this.tensorInput);
    }

    /**
     * Produces an {@code imageSize x imageSize} image in {@link #imageRgb} from the input.
     *
     * <p>An input that already has that exact size is copied directly. Any other input is
     * passed through {@link #massage}, but only if it is at least {@code imageSize} in both
     * dimensions.
     *
     * @param image input image
     * @return {@link #imageRgb}, which is overwritten on every call
     * @throws IllegalArgumentException if the image is smaller than {@link #imageSize} in either
     *     dimension
     */
    protected Planar<GrayF32> preprocess(Planar<GrayF32> image) {
        boolean exactSize = image.width == this.imageSize && image.height == this.imageSize;
        if (exactSize) {
            this.imageRgb.setTo(image);
            return this.imageRgb;
        }
        if (image.width < this.imageSize || image.height < this.imageSize) {
            throw new IllegalArgumentException("Image width or height is too small");
        }
        this.massage.massage(image, this.imageRgb);
        return this.imageRgb;
    }

    /**
     * Runs the network on the tensor and records per-category scores.
     *
     * <p>Scores are read from {@code tensorOutput} at index (0, category). They are stored in
     * {@link #categoryScores}, which is then sorted highest first. {@link #categoryBest} is the
     * lowest category index holding the maximum score. It stays -1 when there are no categories,
     * or when no score is greater than {@code -Double.MAX_VALUE}, which is the case for NaN
     * scores.
     *
     * @param tensorInput network input tensor
     */
    protected void innerProcess(Tensor_F32 tensorInput) {
        this.network.process(tensorInput, this.tensorOutput);

        this.categoryScores.reset();
        double scoreBest = -Double.MAX_VALUE;
        this.categoryBest = -1;

        final int numCategories = this.tensorOutput.length(1);
        for (int category = 0; category < numCategories; category++) {
            double score = this.tensorOutput.get(new int[]{0, category});
            this.categoryScores.grow().set(score, category);
            // Strict '>' keeps the first index among ties and never selects a NaN score
            if (score > scoreBest) {
                scoreBest = score;
                this.categoryBest = category;
            }
        }
        Collections.sort(this.categoryScores.toList(), this.comparator);
    }

    /**
     * @return index of the highest scoring category from the most recent call, or -1 if none
     */
    @Override
    public int getBestResult() {
        return this.categoryBest;
    }

    /**
     * @return scores from the most recent call, sorted highest first. The list comes from the
     *     internal {@link #categoryScores} buffer, which is reset on the next call.
     */
    @Override
    public List<ImageClassifier.Score> getAllResults() {
        return this.categoryScores.toList();
    }

    /**
     * @return category names, as populated by the subclass
     */
    @Override
    public List<String> getCategories() {
        return this.categories;
    }

    /**
     * @return the internal buffer holding the most recently preprocessed input image
     */
    public Planar<GrayF32> getImageRgb() {
        return this.imageRgb;
    }
}

