/*
 * Decompiled with CFR 0.152.
 */
package jsat.classifiers.neuralnetwork;

import java.util.ArrayList;
import java.util.Collections;
import java.util.List;
import java.util.Random;
import java.util.concurrent.CountDownLatch;
import java.util.concurrent.ExecutorService;
import java.util.logging.Level;
import java.util.logging.Logger;
import jsat.DataSet;
import jsat.classifiers.CategoricalResults;
import jsat.classifiers.ClassificationDataSet;
import jsat.classifiers.Classifier;
import jsat.classifiers.DataPoint;
import jsat.distributions.empirical.kernelfunc.EpanechnikovKF;
import jsat.distributions.empirical.kernelfunc.KernelFunction;
import jsat.exceptions.UntrainedModelException;
import jsat.linear.DenseVector;
import jsat.linear.Vec;
import jsat.linear.VecPaired;
import jsat.linear.distancemetrics.DistanceMetric;
import jsat.linear.distancemetrics.EuclideanDistance;
import jsat.linear.vectorcollection.DefaultVectorCollectionFactory;
import jsat.linear.vectorcollection.VectorCollection;
import jsat.linear.vectorcollection.VectorCollectionFactory;
import jsat.math.decayrates.DecayRate;
import jsat.math.decayrates.ExponetialDecay;
import jsat.parameters.Parameter;
import jsat.parameters.Parameterized;
import jsat.utils.ArrayUtils;
import jsat.utils.PairedReturn;
import jsat.utils.SystemInfo;
import jsat.utils.random.RandomUtil;

/**
 * A Self-Organizing Map (SOM) used as a classifier.
 * <p>
 * Training fits a 2-D lattice of {@code somHeight x somWidth} weight vectors to the numeric
 * features of a data set. Afterwards every lattice node is labelled with the (weighted) class
 * distribution of the training points whose best matching unit (the nearest node under the
 * distance metric) it is. {@link #classify(DataPoint)} returns the distribution of the node
 * nearest to the query.
 * <p>
 * Without an executor, each iteration visits the points in a random order. For each point it
 * moves every node within the current neighborhood radius of the point's best matching unit
 * toward the point, by the current learning rate scaled by the kernel of the normalized lattice
 * distance. With an executor, each iteration is a batch step. First every node within the
 * neighborhood of a point's best matching unit collects that point. Then each node that
 * collected anything moves toward the weighted mean of its points by the current learning rate.
 */
public class SOM implements Classifier, Parameterized {
    private static final long serialVersionUID = -6444988770441043797L;
    /** Default number of training iterations. */
    public static final int DEFAULT_MAX_ITERS = 500;
    /** Default kernel applied to the normalized lattice distance. */
    public static final KernelFunction DEFAULT_KF = EpanechnikovKF.getInstance();
    /** Default initial learning rate. */
    public static final double DEFAULT_LEARNING_RATE = 0.1;
    /** Default decay schedule for the learning rate. */
    public static final DecayRate DEFAULT_LEARNING_DECAY = new ExponetialDecay();
    /** Default decay schedule for the neighborhood radius. */
    public static final DecayRate DEFAULT_NEIGHBOR_DECAY = new ExponetialDecay();
    private int somWidth;
    private int somHeight;
    private int maxIters;
    private KernelFunction kf;
    private double initialLearningRate;
    private DecayRate learningDecay;
    private DecayRate neighborDecay;
    private DistanceMetric dm;
    private VectorCollectionFactory<VecPaired<Vec, Integer>> vcFactory;
    /** Lattice weights indexed as {@code weights[row][column]}, i.e. {@code [somHeight][somWidth]}. */
    private Vec[][] weights;
    /** Class distribution per lattice node, indexed by the integer paired with the node in {@link #vcCollection}. */
    private CategoricalResults[] crWeightPairs;
    /** Nearest-neighbor index over the trained lattice nodes, used to find a query's best matching unit. */
    private VectorCollection<VecPaired<Vec, Integer>> vcCollection;
    /**
     * Per-node lists of points collected during a batch (parallel) iteration, same shape as
     * {@link #weights}. Released once training completes so the model does not retain the training data.
     */
    private List<List<List<DataPoint>>> weightUpdates;

    /**
     * Creates a SOM using the Euclidean distance and default settings.
     *
     * @param somHeight number of rows in the lattice
     * @param somWidth number of columns in the lattice
     */
    public SOM(int somHeight, int somWidth) {
        this(new EuclideanDistance(), somHeight, somWidth);
    }

    /**
     * Creates a SOM using the given distance metric and default settings.
     *
     * @param dm distance metric used to find best matching units
     * @param somHeight number of rows in the lattice
     * @param somWidth number of columns in the lattice
     */
    public SOM(DistanceMetric dm, int somHeight, int somWidth) {
        this(dm, somHeight, somWidth, new DefaultVectorCollectionFactory<VecPaired<Vec, Integer>>());
    }

    /**
     * Creates a SOM using the given distance metric and vector collection factory.
     *
     * @param dm distance metric used to find best matching units
     * @param somHeight number of rows in the lattice
     * @param somWidth number of columns in the lattice
     * @param vcFactory factory for the nearest-neighbor structure built over the trained nodes
     */
    public SOM(DistanceMetric dm, int somHeight, int somWidth, VectorCollectionFactory<VecPaired<Vec, Integer>> vcFactory) {
        this(DEFAULT_MAX_ITERS, DEFAULT_KF, DEFAULT_LEARNING_RATE, DEFAULT_LEARNING_DECAY, DEFAULT_NEIGHBOR_DECAY, dm, somHeight, somWidth, vcFactory);
    }

    private SOM(int maxIters, KernelFunction kf, double initialLearningRate, DecayRate learningDecay, DecayRate neighborDecay, DistanceMetric dm, int somHeight, int somWidth, VectorCollectionFactory<VecPaired<Vec, Integer>> vcFactory) {
        // Same validation as the public setters; checked inline to avoid calling overridable methods here.
        if (somHeight < 1) {
            throw new ArithmeticException("Lattice height must be positive, not " + somHeight);
        }
        if (somWidth < 1) {
            throw new ArithmeticException("Lattice width must be positive, not " + somWidth);
        }
        this.somHeight = somHeight;
        this.somWidth = somWidth;
        this.maxIters = maxIters;
        this.kf = kf;
        this.initialLearningRate = initialLearningRate;
        this.learningDecay = learningDecay;
        this.neighborDecay = neighborDecay;
        this.dm = dm;
        this.vcFactory = vcFactory;
    }

    /**
     * Sets the number of training iterations.
     *
     * @param maxIters number of iterations, at least 1
     * @throws ArithmeticException if {@code maxIters < 1}
     */
    public void setMaxIterations(int maxIters) {
        if (maxIters < 1) {
            throw new ArithmeticException("At least one iteration must be performed");
        }
        this.maxIters = maxIters;
    }

    /** @return the number of training iterations */
    public int getMaxIterations() {
        return this.maxIters;
    }

    /**
     * Sets the number of lattice columns used by the next training run.
     *
     * @param somWidth lattice width, at least 1
     * @throws ArithmeticException if {@code somWidth < 1}
     */
    public void setSomWidth(int somWidth) {
        if (somWidth < 1) {
            throw new ArithmeticException("Lattice width must be positive, not " + somWidth);
        }
        this.somWidth = somWidth;
    }

    /**
     * Sets the number of lattice rows used by the next training run.
     *
     * @param somHeight lattice height, at least 1
     * @throws ArithmeticException if {@code somHeight < 1}
     */
    public void setSomHeight(int somHeight) {
        if (somHeight < 1) {
            throw new ArithmeticException("Lattice height must be positive, not " + somHeight);
        }
        this.somHeight = somHeight;
    }

    /** @return the number of lattice rows */
    public int getSomHeight() {
        return this.somHeight;
    }

    /** @return the number of lattice columns */
    public int getSomWidth() {
        return this.somWidth;
    }

    /**
     * Sets the learning rate used at iteration 0, before decay is applied.
     *
     * @param initialLearningRate a finite positive value
     * @throws ArithmeticException if the value is not finite and positive
     */
    public void setInitialLearningRate(double initialLearningRate) {
        if (Double.isInfinite(initialLearningRate) || Double.isNaN(initialLearningRate) || initialLearningRate <= 0.0) {
            throw new ArithmeticException("Learning rate must be a positive constant, not " + initialLearningRate);
        }
        this.initialLearningRate = initialLearningRate;
    }

    /** @return the learning rate used at iteration 0 */
    public double getInitialLearningRate() {
        return this.initialLearningRate;
    }

    /**
     * Sets the decay schedule applied to the learning rate over iterations.
     *
     * @param learningDecay the decay rate, non-null
     * @throws NullPointerException if {@code learningDecay} is null
     */
    public void setLearningDecay(DecayRate learningDecay) {
        if (learningDecay == null) {
            throw new NullPointerException("Can not set a decay rate to null");
        }
        this.learningDecay = learningDecay;
    }

    /** @return the learning rate decay schedule */
    public DecayRate getLearningDecay() {
        return this.learningDecay;
    }

    /**
     * Sets the decay schedule applied to the neighborhood radius over iterations.
     *
     * @param neighborDecay the decay rate, non-null
     * @throws NullPointerException if {@code neighborDecay} is null
     */
    public void setNeighborDecay(DecayRate neighborDecay) {
        if (neighborDecay == null) {
            throw new NullPointerException("Can not set a decay rate to null");
        }
        this.neighborDecay = neighborDecay;
    }

    /** @return the neighborhood radius decay schedule */
    public DecayRate getNeighborDecay() {
        return this.neighborDecay;
    }

    /**
     * Fills {@link #weights} with {@code DenseVector.random(D)} vectors.
     *
     * @param D dimension of each weight vector
     * @return the initial neighborhood radius: the larger lattice side
     */
    private double initializeWeights(int D) {
        for (int i = 0; i < this.weights.length; ++i) {
            for (int j = 0; j < this.weights[i].length; ++j) {
                this.weights[i][j] = DenseVector.random(D);
            }
        }
        return Math.max(this.somWidth, this.somHeight);
    }

    /**
     * Processes a single training point.
     * <p>
     * When {@code execServ} is null, every node within {@code nbrRange} (lattice distance) of
     * the point's best matching unit is moved toward the point immediately. Otherwise the
     * point is only appended to those nodes' {@link #weightUpdates} lists.
     * NOTE(review): the immediate update does not scale by the DataPoint's weight, although
     * {@link #supportsWeightedData()} returns true; the batch path does use it.
     */
    private void iterationStep(ExecutorService execServ, int i, DataSet dataSet, double nbrRange, double nbrRangeSqrd, Vec scratch, double learnRate) {
        DataPoint dp = dataSet.getDataPoint(i);
        Vec input_i = dp.getNumericalValues();
        PairedReturn<Integer, Integer> closestBMUPR = this.getBMU(input_i);
        int xBest = closestBMUPR.getFirstItem(); // row index into weights
        int yBest = closestBMUPR.getSecondItem(); // column index into weights
        if (xBest < 0 || yBest < 0) {
            return; // no node had a distance below Double.MAX_VALUE (e.g. NaN features); nothing to update
        }
        int rows = this.weights.length;
        int cols = this.weights[xBest].length;
        // x walks rows (bounded by the lattice height), y walks columns (bounded by the width).
        int xStart = Math.max((int) (xBest - nbrRange) - 1, 0);
        int yStart = Math.max((int) (yBest - nbrRange) - 1, 0);
        int xEnd = Math.min((int) (xBest + nbrRange) + 1, rows);
        int yEnd = Math.min((int) (yBest + nbrRange) + 1, cols);
        for (int x = xStart; x < xEnd; ++x) {
            Vec[] weights_x = this.weights[x];
            for (int y = yStart; y < yEnd; ++y) {
                int xLength = xBest - x;
                int yLength = yBest - y;
                int pointDistSqrd = xLength * xLength + yLength * yLength;
                if (!(pointDistSqrd < nbrRangeSqrd)) {
                    continue; // outside the circular neighborhood
                }
                if (execServ == null) {
                    double distWeight = this.kf.k(Math.sqrt(pointDistSqrd) / nbrRange);
                    this.updateWeight(input_i, scratch, weights_x[y], distWeight * learnRate);
                } else {
                    this.weightUpdates.get(x).get(y).add(dp);
                }
            }
        }
    }

    /**
     * Builds {@link #vcCollection} over the current lattice nodes, pairing each node with its
     * row-major index.
     *
     * @param threadPool optional executor forwarded to the collection factory; may be null
     * @return the paired node list, in row-major order
     */
    private List<VecPaired<Vec, Integer>> setUpVectorCollection(ExecutorService threadPool) {
        ArrayList<VecPaired<Vec, Integer>> vecList = new ArrayList<VecPaired<Vec, Integer>>(this.somWidth * this.somHeight);
        for (int i = 0; i < this.weights.length; ++i) {
            for (int j = 0; j < this.weights[i].length; ++j) {
                vecList.add(new VecPaired<Vec, Integer>(this.weights[i][j], vecList.size()));
            }
        }
        this.vcCollection = threadPool == null ? this.vcFactory.getVectorCollection(vecList, this.dm) : this.vcFactory.getVectorCollection(vecList, this.dm, threadPool);
        return vecList;
    }

    /**
     * Moves {@code weightVec} in place by {@code scale * (input_i - weightVec)}.
     *
     * @param scratch work vector of the same dimension; overwritten
     */
    private void updateWeight(Vec input_i, Vec scratch, Vec weightVec, double scale) {
        input_i.copyTo(scratch);
        scratch.mutableSubtract(weightVec);
        weightVec.mutableAdd(scale, scratch);
    }

    /**
     * Finds the best matching unit by a linear scan of the lattice.
     *
     * @return (row, column) of the closest node, or (-1, -1) if no distance is below
     *         {@code Double.MAX_VALUE}
     */
    private PairedReturn<Integer, Integer> getBMU(Vec numericalValues) {
        double bestDist = Double.MAX_VALUE;
        int x = -1;
        int y = -1;
        for (int i = 0; i < this.weights.length; ++i) {
            Vec[] weights_i = this.weights[i];
            for (int j = 0; j < weights_i.length; ++j) {
                double dist = this.dm.dist(weights_i[j], numericalValues);
                if (dist < bestDist) {
                    bestDist = dist;
                    x = i;
                    y = j;
                }
            }
        }
        return new PairedReturn<Integer, Integer>(x, y);
    }

    @Override
    public List<Parameter> getParameters() {
        return Parameter.getParamsFromMethods(this);
    }

    @Override
    public Parameter getParameter(String paramName) {
        return Parameter.toParameterMap(this.getParameters()).get(paramName);
    }

    /** Creates a thread-local scratch vector of dimension {@code D}. */
    private static ThreadLocal<Vec> newScratch(final int D) {
        return new ThreadLocal<Vec>() {
            @Override
            protected Vec initialValue() {
                return new DenseVector(D);
            }
        };
    }

    /**
     * Re-initializes the lattice and runs {@link #maxIters} training iterations.
     *
     * @param execServ executor for batch (parallel) training, or null for sequential training
     * @throws InterruptedException if interrupted while waiting for parallel tasks
     */
    private void trainSOM(final DataSet dataSet, final ExecutorService execServ) throws InterruptedException {
        final int D = dataSet.getNumNumericalVars();
        final int N = dataSet.getSampleSize();
        this.weights = new Vec[this.somHeight][this.somWidth];
        final double neighborRadius = this.initializeWeights(D);
        final Random rand = RandomUtil.getRandom();
        final DenseVector scratch = new DenseVector(D);
        final int[] pointAccessOrder = new int[N];
        for (int i = 0; i < N; ++i) {
            pointAccessOrder[i] = i;
        }

        final ThreadLocal<Vec> localScratch1;
        final ThreadLocal<Vec> localScratch2;
        if (execServ != null) {
            this.weightUpdates = new ArrayList<List<List<DataPoint>>>(this.somHeight);
            for (int i = 0; i < this.somHeight; ++i) {
                List<List<DataPoint>> row = new ArrayList<List<DataPoint>>(this.somWidth);
                for (int j = 0; j < this.somWidth; ++j) {
                    // Synchronized: several assignment tasks may append to the same node's list.
                    row.add(Collections.synchronizedList(new ArrayList<DataPoint>()));
                }
                this.weightUpdates.add(row);
            }
            localScratch1 = newScratch(D);
            localScratch2 = newScratch(D);
        } else {
            localScratch1 = null;
            localScratch2 = null;
        }

        for (int iter = 0; iter < this.maxIters; ++iter) {
            final double nbrRange = this.neighborDecay.rate(iter, this.maxIters, neighborRadius);
            final double nbrRangeSqrd = nbrRange * nbrRange;
            final double learnRate = this.learningDecay.rate(iter, this.maxIters, this.initialLearningRate);
            if (execServ == null) {
                ArrayUtils.shuffle(pointAccessOrder, rand);
                for (int idx : pointAccessOrder) {
                    this.iterationStep(null, idx, dataSet, nbrRange, nbrRangeSqrd, scratch, learnRate);
                }
            } else {
                for (List<List<DataPoint>> row : this.weightUpdates) {
                    for (List<DataPoint> nodeList : row) {
                        nodeList.clear();
                    }
                }
                this.assignPointsInParallel(dataSet, execServ, nbrRange, nbrRangeSqrd, learnRate, localScratch1);
                this.applyBatchUpdates(execServ, learnRate, localScratch1, localScratch2);
            }
        }
        // Training finished: drop the per-node point lists so the model does not hold the training data.
        this.weightUpdates = null;
    }

    /**
     * Batch phase 1: splits the points into contiguous chunks, one task per non-empty chunk,
     * and records every point in the {@link #weightUpdates} lists of its BMU's neighborhood.
     */
    private void assignPointsInParallel(final DataSet dataSet, final ExecutorService execServ, final double nbrRange, final double nbrRangeSqrd, final double learnRate, final ThreadLocal<Vec> localScratch) throws InterruptedException {
        final int N = dataSet.getSampleSize();
        // One task per non-empty chunk, so the latch count always matches the tasks submitted.
        final int chunks = Math.min(Math.max(1, SystemInfo.LogicalCores), N);
        if (chunks <= 0) {
            return;
        }
        final CountDownLatch cdl = new CountDownLatch(chunks);
        for (int c = 0; c < chunks; ++c) {
            final int start = (int) ((long) c * N / chunks);
            final int to = (int) ((long) (c + 1) * N / chunks);
            execServ.submit(new Runnable() {
                @Override
                public void run() {
                    try {
                        for (int i = start; i < to; ++i) {
                            SOM.this.iterationStep(execServ, i, dataSet, nbrRange, nbrRangeSqrd, localScratch.get(), learnRate);
                        }
                    } finally {
                        cdl.countDown(); // always release, even if a task fails, so await() cannot hang
                    }
                }
            });
        }
        cdl.await();
    }

    /**
     * Batch phase 2: moves each node toward the weighted mean of the points it collected, by
     * {@code learnRate}. Nodes that collected no positively-weighted points are left unchanged.
     * NOTE(review): unlike the sequential path, no kernel weighting by lattice distance is applied here.
     */
    private void applyBatchUpdates(final ExecutorService execServ, final double learnRate, final ThreadLocal<Vec> localMean, final ThreadLocal<Vec> localScratch) throws InterruptedException {
        int total = 0;
        for (Vec[] row : this.weights) {
            total += row.length;
        }
        final CountDownLatch cdl = new CountDownLatch(total);
        for (int i = 0; i < this.weights.length; ++i) {
            for (int j = 0; j < this.weights[i].length; ++j) {
                final List<DataPoint> dataList = this.weightUpdates.get(i).get(j);
                final Vec weightVec = this.weights[i][j];
                execServ.submit(new Runnable() {
                    @Override
                    public void run() {
                        try {
                            Vec mean = localMean.get();
                            mean.zeroOut();
                            double denom = 0.0;
                            for (DataPoint dp : dataList) {
                                denom += dp.getWeight();
                                mean.mutableAdd(dp.getWeight(), dp.getNumericalValues());
                            }
                            // Without this guard an empty node would be pulled toward the zero vector.
                            if (denom > 0.0) {
                                mean.mutableDivide(denom);
                                SOM.this.updateWeight(mean, localScratch.get(), weightVec, learnRate);
                            }
                        } finally {
                            cdl.countDown();
                        }
                    }
                });
            }
        }
        cdl.await();
    }

    /**
     * Returns the class distribution of the lattice node nearest to {@code data}.
     *
     * @throws UntrainedModelException if the model has not been trained
     */
    @Override
    public CategoricalResults classify(DataPoint data) {
        if (this.crWeightPairs == null) {
            throw new UntrainedModelException();
        }
        return this.crWeightPairs[this.vcCollection.search(data.getNumericalValues(), 1).get(0).getVector().getPair()];
    }

    /**
     * Trains the lattice, then labels each node with the normalized, weight-summed class
     * distribution of the training points whose nearest node it is.
     *
     * @param threadPool executor for parallel training, or null for sequential training
     */
    @Override
    public void trainC(ClassificationDataSet dataSet, ExecutorService threadPool) {
        try {
            this.trainSOM(dataSet, threadPool);
        } catch (InterruptedException ex) {
            Thread.currentThread().interrupt(); // preserve the interrupt status for the caller
            Logger.getLogger(SOM.class.getName()).log(Level.SEVERE, null, ex);
            return;
        }
        List<VecPaired<Vec, Integer>> vecList = this.setUpVectorCollection(threadPool);
        this.crWeightPairs = new CategoricalResults[vecList.size()];
        for (int i = 0; i < this.crWeightPairs.length; ++i) {
            this.crWeightPairs[i] = new CategoricalResults(dataSet.getClassSize());
        }
        for (int i = 0; i < dataSet.getSampleSize(); ++i) {
            DataPoint dp = dataSet.getDataPoint(i);
            VecPaired<Vec, Integer> vpBMU = this.vcCollection.search(dp.getNumericalValues(), 1).get(0).getVector();
            int index = vpBMU.getPair();
            this.crWeightPairs[index].incProb(dataSet.getDataPointCategory(i), dp.getWeight());
        }
        for (CategoricalResults cr : this.crWeightPairs) {
            cr.normalize();
        }
    }

    @Override
    public void trainC(ClassificationDataSet dataSet) {
        this.trainC(dataSet, null);
    }

    @Override
    public boolean supportsWeightedData() {
        return true;
    }

    /**
     * Returns a copy with the same settings and, if trained, deep copies of the lattice weights,
     * vector collection and node class distributions.
     */
    @Override
    public SOM clone() {
        SOM clone = new SOM(this.maxIters, this.kf, this.initialLearningRate, this.learningDecay, this.neighborDecay, this.dm.clone(), this.somHeight, this.somWidth, this.vcFactory.clone());
        if (this.weights != null) {
            clone.weights = new Vec[this.weights.length][];
            for (int i = 0; i < this.weights.length; ++i) {
                clone.weights[i] = new Vec[this.weights[i].length];
                for (int j = 0; j < this.weights[i].length; ++j) {
                    clone.weights[i][j] = this.weights[i][j].clone();
                }
            }
        }
        if (this.vcCollection != null) {
            clone.vcCollection = this.vcCollection.clone();
        }
        if (this.crWeightPairs != null) {
            clone.crWeightPairs = new CategoricalResults[this.crWeightPairs.length];
            for (int i = 0; i < this.crWeightPairs.length; ++i) {
                clone.crWeightPairs[i] = this.crWeightPairs[i].clone();
            }
        }
        return clone;
    }
}

