/*
 * Decompiled with CFR 0.152.
 */
package boofcv.alg.segmentation.ms;

import boofcv.alg.interpolate.InterpolatePixelMB;
import boofcv.alg.misc.ImageMiscOps;
import boofcv.alg.segmentation.ms.SegmentMeanShiftSearch;
import boofcv.struct.feature.ColorQueue_F32;
import boofcv.struct.image.GrayS32;
import boofcv.struct.image.ImageMultiBand;
import boofcv.struct.image.ImageType;
import georegression.struct.point.Point2D_F32;
import georegression.struct.point.Point2D_I32;
import java.util.Arrays;
import org.ddogleg.struct.FastQueue;

/**
 * Mean-shift mode search for multi-band (color) images.
 *
 * <p>For every pixel that has not yet been assigned to a mode, a search is started at that pixel.
 * Each iteration moves the search location to the weighted mean of the samples inside a
 * {@code widthX x widthY} window. A sample's weight combines its spatial kernel value
 * ({@code spacialTable}) with its normalized squared color distance to the current mean color.
 * Samples whose normalized color distance exceeds 1 get zero weight. Pixels visited by a search
 * are labeled with the mode it converged to.</p>
 *
 * @param <T> multi-band image type being segmented
 */
public class SegmentMeanShiftSearchColor<T extends ImageMultiBand<T>>
extends SegmentMeanShiftSearch<T> {
    /** Interpolates multi-band pixel values at (possibly sub-pixel) locations. */
    protected InterpolatePixelMB<T> interpolate;
    /** Scratch buffer holding the color of the sample currently being evaluated. */
    protected float[] pixelColor;
    /** Scratch buffer holding the mean color of the search in progress. */
    protected float[] meanColor;
    /** Weighted color sum accumulated during a single mean-shift iteration. */
    protected float[] sumColor;
    /** Locations visited by the current search; these pixels are labeled once its mode is known. */
    protected FastQueue<Point2D_F32> history = new FastQueue<>(Point2D_F32.class, true);
    /** Type of image this search was configured for. */
    ImageType<T> imageType;

    /**
     * Creates the search.
     *
     * @param maxIterations maximum number of mean-shift iterations per search
     * @param convergenceTol the search stops once both x and y move less than this amount
     * @param interpolate interpolation used to read pixel colors
     * @param radiusX horizontal radius of the search window
     * @param radiusY vertical radius of the search window
     * @param maxColorDistance color distance beyond which a sample receives zero weight
     * @param fast if true, a search stops as soon as it lands on a pixel already assigned to a mode
     * @param imageType type of input image; its band count sizes the color buffers
     */
    public SegmentMeanShiftSearchColor(int maxIterations, float convergenceTol, InterpolatePixelMB<T> interpolate, int radiusX, int radiusY, float maxColorDistance, boolean fast, ImageType<T> imageType) {
        super(maxIterations, convergenceTol, radiusX, radiusY, maxColorDistance, fast);
        final int numBands = imageType.getNumBands();
        this.interpolate = interpolate;
        this.pixelColor = new float[numBands];
        this.meanColor = new float[numBands];
        this.sumColor = new float[numBands];
        this.imageType = imageType;
        this.modeColor = new ColorQueue_F32(numBands);
    }

    /**
     * Assigns every pixel in the image to a mode.
     *
     * <p>On return, {@code pixelToMode} maps each visited pixel to a mode index. {@code modeLocation},
     * {@code modeColor} and {@code modeMemberCount} describe each mode. Processing can be cut short
     * between rows by setting {@code stopRequested}.</p>
     *
     * @param image input image
     */
    @Override
    public void process(T image) {
        this.image = image;
        this.stopRequested = false;
        this.modeLocation.reset();
        this.modeColor.reset();
        this.modeMemberCount.reset();
        this.interpolate.setImage(image);

        this.pixelToMode.reshape(image.width, image.height);
        this.quickMode.reshape(image.width, image.height);
        // -1 marks "not yet assigned" in both lookup images
        ImageMiscOps.fill(this.pixelToMode, -1);
        ImageMiscOps.fill(this.quickMode, -1);

        for (int y = 0; y < image.height && !this.stopRequested; ++y) {
            // Use getIndex so row addressing agrees with every other pixelToMode lookup
            int indexImg = this.pixelToMode.getIndex(0, y);
            for (int x = 0; x < image.width; ++x, ++indexImg) {
                int assigned = this.pixelToMode.data[indexImg];
                if (assigned != -1) {
                    // An earlier search passed through this pixel and already resolved its mode
                    this.modeMemberCount.data[assigned]++;
                    continue;
                }

                this.interpolate.get((float)x, (float)y, this.meanColor);
                this.findPeak(x, y, this.meanColor);

                int modePx = (int)(this.modeX + 0.5f);
                int modePy = (int)(this.modeY + 0.5f);
                int modePixelIndex = this.quickMode.getIndex(modePx, modePy);
                int modeIndex = this.quickMode.data[modePixelIndex];
                if (modeIndex < 0) {
                    // First search to converge on this pixel: register a new mode
                    modeIndex = this.modeLocation.size();
                    ((Point2D_I32)this.modeLocation.grow()).set(modePx, modePy);
                    this.savePeakColor(this.meanColor);
                    this.quickMode.data[modePixelIndex] = modeIndex;
                    this.modeMemberCount.add(0);
                }
                this.modeMemberCount.data[modeIndex]++;

                // Label every still-unassigned pixel on this search's path with the mode found
                for (int i = 0; i < this.history.size(); ++i) {
                    Point2D_F32 p = this.history.get(i);
                    int index = this.pixelToMode.getIndex((int)(p.x + 0.5f), (int)(p.y + 0.5f));
                    if (this.pixelToMode.data[index] == -1) {
                        this.pixelToMode.data[index] = modeIndex;
                    }
                }
            }
        }
    }

    @Override
    public ImageType<T> getImageType() {
        return this.imageType;
    }

    /**
     * Runs mean-shift from (cx, cy) until convergence, a zero total weight, or
     * {@code maxIterations}. The result is written to {@code modeX}/{@code modeY}, and
     * {@code meanColor} is updated in place with the weighted mean color of each completed
     * iteration.
     *
     * @param cx initial x coordinate
     * @param cy initial y coordinate
     * @param meanColor color at the start point on input; mean color near the mode on output
     */
    protected void findPeak(float cx, float cy, float[] meanColor) {
        this.history.reset();
        this.history.grow().set(cx, cy);

        for (int iter = 0; iter < this.maxIterations; ++iter) {
            float total = 0.0f;
            float sumX = 0.0f;
            float sumY = 0.0f;
            Arrays.fill(this.sumColor, 0.0f);

            float x0 = cx - (float)this.radiusX;
            float y0 = cy - (float)this.radiusY;

            if (this.interpolate.isInFastBounds(x0, y0) && this.interpolate.isInFastBounds(x0 + (float)this.widthX - 1.0f, y0 + (float)this.widthY - 1.0f)) {
                // Whole window is safely inside the image: no per-sample bounds checks needed
                for (int yy = 0; yy < this.widthY; ++yy) {
                    float sampleY = y0 + (float)yy;
                    int rowStart = yy * this.widthX;
                    for (int xx = 0; xx < this.widthX; ++xx) {
                        float sampleX = x0 + (float)xx;
                        float w = this.accumulateSample(this.spacialTable[rowStart + xx], sampleX, sampleY, meanColor);
                        total += w;
                        sumX += w * sampleX;
                        sumY += w * sampleY;
                    }
                }
            } else {
                // Window touches the border: samples outside the image are skipped
                final float maxX = (float)(this.image.width - 1);
                final float maxY = (float)(this.image.height - 1);
                for (int yy = 0; yy < this.widthY; ++yy) {
                    float sampleY = y0 + (float)yy;
                    if (sampleY < 0.0f) {
                        continue;
                    }
                    if (sampleY > maxY) {
                        // Every remaining row is further below the image
                        break;
                    }
                    int rowStart = yy * this.widthX;
                    for (int xx = 0; xx < this.widthX; ++xx) {
                        float sampleX = x0 + (float)xx;
                        if (sampleX < 0.0f || sampleX > maxX) {
                            continue;
                        }
                        float w = this.accumulateSample(this.spacialTable[rowStart + xx], sampleX, sampleY, meanColor);
                        total += w;
                        sumX += w * sampleX;
                        sumY += w * sampleY;
                    }
                }
            }

            if (total == 0.0f) {
                // No sample contributed; stay at the current location
                break;
            }

            float peakX = sumX / total;
            float peakY = sumY / total;

            if (this.fast) {
                this.history.grow().set(peakX, peakY);
                int index = this.pixelToMode.getIndex((int)(peakX + 0.5f), (int)(peakY + 0.5f));
                int modeIndex = this.pixelToMode.data[index];
                if (modeIndex != -1) {
                    // Reached a pixel whose mode is already known; reuse it instead of iterating on
                    Point2D_I32 modeP = (Point2D_I32)this.modeLocation.get(modeIndex);
                    this.modeX = modeP.x;
                    this.modeY = modeP.y;
                    return;
                }
            }

            float dx = peakX - cx;
            float dy = peakY - cy;
            cx = peakX;
            cy = peakY;
            SegmentMeanShiftSearchColor.meanColor(this.sumColor, meanColor, total);

            if (Math.abs(dx) < this.convergenceTol && Math.abs(dy) < this.convergenceTol) {
                break;
            }
        }
        this.modeX = cx;
        this.modeY = cy;
    }

    /**
     * Evaluates one sample of the search window and adds its weighted color to {@code sumColor}.
     * The caller accumulates the returned weight into the total and the position sums.
     *
     * @param ds spatial kernel value for this sample
     * @param sampleX x coordinate of the sample
     * @param sampleY y coordinate of the sample
     * @param meanColor current mean color of the search
     * @return combined weight; 0 if the normalized color distance exceeds 1
     */
    private float accumulateSample(float ds, float sampleX, float sampleY, float[] meanColor) {
        this.interpolate.get(sampleX, sampleY, this.pixelColor);
        float dc = SegmentMeanShiftSearchColor.distanceSq(this.pixelColor, meanColor) / this.maxColorDistanceSq;
        float w = dc > 1.0f ? 0.0f : this.weight((ds + dc) / 2.0f);
        SegmentMeanShiftSearchColor.sumColor(this.sumColor, this.pixelColor, w);
        return w;
    }

    /**
     * Writes {@code sum[i] / total} into {@code mean[i]} for every band.
     *
     * @param sum weighted color sum
     * @param mean output mean color; must be at least as long as {@code sum}
     * @param total sum of weights
     */
    protected static void meanColor(float[] sum, float[] mean, float total) {
        for (int i = 0; i < sum.length; ++i) {
            mean[i] = sum[i] / total;
        }
    }

    /**
     * Adds {@code pixel * weight} to {@code sum}, band by band.
     *
     * @param sum accumulator, modified in place
     * @param pixel color to add; must be at least as long as {@code sum}
     * @param weight scale applied to {@code pixel}
     */
    protected static void sumColor(float[] sum, float[] pixel, float weight) {
        for (int i = 0; i < sum.length; ++i) {
            sum[i] += pixel[i] * weight;
        }
    }

    /**
     * Appends a copy of the given color to {@code modeColor}.
     *
     * @param a color of the newly found mode
     */
    protected void savePeakColor(float[] a) {
        float[] b = (float[])this.modeColor.grow();
        System.arraycopy(a, 0, b, 0, a.length);
    }
}

