/*
 * Decompiled with CFR 0.152.
 */
package org.deeplearning4j.models.embeddings.reader.impl;

import java.util.ArrayList;
import java.util.Arrays;
import java.util.Collection;
import java.util.HashSet;
import lombok.NonNull;
import org.deeplearning4j.clustering.sptree.DataPoint;
import org.deeplearning4j.clustering.vptree.VPTree;
import org.deeplearning4j.models.embeddings.WeightLookupTable;
import org.deeplearning4j.models.embeddings.reader.impl.BasicModelUtils;
import org.deeplearning4j.models.sequencevectors.sequence.SequenceElement;
import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.factory.Nd4j;
import org.nd4j.util.SetUtils;

/**
 * Model utilities that answer nearest-word queries with a vantage-point tree ({@link VPTree})
 * built over the vectors of every word in the vocabulary.
 *
 * <p>The tree is built lazily on the first nearest-neighbour query (see {@link #checkTree()}) and
 * discarded whenever {@link #init(WeightLookupTable)} is called again.
 *
 * <p>NOTE(review): the tree is created with {@code new VPTree(points)}, so neighbours are ranked by
 * that constructor's default distance function. It may differ from the similarity used by
 * {@link BasicModelUtils}; confirm against the VPTree implementation.
 *
 * @param <T> vocabulary element type
 */
public class TreeModelUtils<T extends SequenceElement>
extends BasicModelUtils<T> {
    /** Lazily built tree over all vocabulary word vectors; {@code null} until first needed. */
    protected VPTree vpTree;

    /**
     * Initialises the base utilities with {@code lookupTable} and discards any previously built tree.
     *
     * @param lookupTable weight lookup table to query; must not be {@code null}
     * @throws NullPointerException if {@code lookupTable} is {@code null}
     */
    @Override
    public void init(@NonNull WeightLookupTable<T> lookupTable) {
        if (lookupTable == null) {
            throw new NullPointerException("lookupTable is marked @NonNull but is null");
        }
        super.init(lookupTable);
        // Any tree built for a previous table is stale; checkTree() rebuilds it on demand.
        this.vpTree = null;
    }

    /**
     * Builds {@link #vpTree} from every vocabulary word if it has not been built yet.
     *
     * <p>Each word becomes a {@link DataPoint} keyed by its vocabulary index and holding its vector
     * from the lookup table.
     */
    protected synchronized void checkTree() {
        if (this.vpTree != null) {
            return;
        }
        ArrayList<DataPoint> points = new ArrayList<DataPoint>();
        for (String word : this.vocabCache.words()) {
            points.add(new DataPoint(this.vocabCache.indexOf(word), this.lookupTable.vector(word)));
        }
        this.vpTree = new VPTree(points);
    }

    /**
     * Returns up to {@code n} words nearest to {@code label}, excluding {@code label} itself.
     *
     * @param label word to search around
     * @param n     maximum number of neighbours to return
     * @return nearest words, or an empty list if {@code label} is not in the vocabulary
     */
    @Override
    public Collection<String> wordsNearest(String label, int n) {
        if (!this.vocabCache.hasToken(label)) {
            return new ArrayList<String>();
        }
        // Ask for one extra result because the label's own vector is normally among its neighbours.
        Collection<String> candidates =
                this.wordsNearest(Arrays.asList(label), new ArrayList<String>(), n + 1);
        ArrayList<String> result = new ArrayList<String>();
        for (String word : candidates) {
            if (result.size() >= n) {
                // Cap at n even when the label itself was not among the candidates.
                break;
            }
            if (!label.equals(word)) {
                result.add(word);
            }
        }
        return result;
    }

    /**
     * Returns the {@code top} words nearest to the mean of the positive word vectors and the
     * negated negative word vectors.
     *
     * @param positive words whose vectors are added
     * @param negative words whose vectors are subtracted
     * @param top      number of neighbours requested from the tree
     * @return nearest words, or an empty list if any input word is missing from the vocabulary or
     *     both inputs are empty
     */
    @Override
    public Collection<String> wordsNearest(Collection<String> positive, Collection<String> negative, int top) {
        for (String p : SetUtils.union(new HashSet<String>(positive), new HashSet<String>(negative))) {
            if (!this.vocabCache.containsWord(p)) {
                return new ArrayList<String>();
            }
        }
        int rows = positive.size() + negative.size();
        if (rows == 0) {
            // Nothing to average; avoid creating and reducing a zero-row matrix.
            return new ArrayList<String>();
        }
        INDArray words = Nd4j.create(rows, (int) this.lookupTable.layerSize());
        int row = 0;
        for (String s : positive) {
            words.putRow((long) row++, this.lookupTable.vector(s));
        }
        for (String s : negative) {
            words.putRow((long) row++, this.lookupTable.vector(s).mul(-1));
        }
        // Average over rows (dimension 0) to get a single query vector.
        INDArray mean = words.isMatrix() ? words.mean(new int[]{0}) : words;
        return this.wordsNearest(mean, top);
    }

    /**
     * Returns the words whose vectors are nearest to {@code words}, according to the VP-tree.
     *
     * @param words query vector
     * @param top   number of neighbours to request from the tree
     * @return words mapped from the tree's result points, in the order the tree returned them
     */
    @Override
    public Collection<String> wordsNearest(INDArray words, int top) {
        this.checkTree();
        ArrayList<DataPoint> neighbours = new ArrayList<DataPoint>();
        ArrayList<Double> distances = new ArrayList<Double>();
        this.vpTree.search(words, top, neighbours, distances);
        ArrayList<String> ret = new ArrayList<String>(neighbours.size());
        for (DataPoint point : neighbours) {
            ret.add(this.vocabCache.wordAtIndex(point.getIndex()));
        }
        // Previously this list was discarded in favour of super.wordsNearest(...), wasting the tree search.
        return ret;
    }
}

