/*
 * Decompiled with CFR 0.152.
 */
package org.apache.lucene.classification;

import java.io.IOException;
import java.io.Reader;
import java.io.StringReader;
import java.util.HashMap;
import java.util.LinkedHashMap;
import java.util.Map;
import org.apache.lucene.analysis.Analyzer;
import org.apache.lucene.classification.ClassificationResult;
import org.apache.lucene.classification.Classifier;
import org.apache.lucene.index.AtomicReader;
import org.apache.lucene.index.IndexReader;
import org.apache.lucene.index.Term;
import org.apache.lucene.queries.mlt.MoreLikeThis;
import org.apache.lucene.search.BooleanClause;
import org.apache.lucene.search.BooleanQuery;
import org.apache.lucene.search.IndexSearcher;
import org.apache.lucene.search.Query;
import org.apache.lucene.search.ScoreDoc;
import org.apache.lucene.search.TopDocs;
import org.apache.lucene.search.WildcardQuery;
import org.apache.lucene.util.BytesRef;

/**
 * A k-nearest-neighbor {@link Classifier}.
 *
 * <p>The text to classify is turned into {@link MoreLikeThis} queries, one per text field. The
 * {@code k} best-matching documents that have a term in the class field are retrieved. The class
 * that occurs most often among them is assigned.
 */
public class KNearestNeighborClassifier implements Classifier<BytesRef> {
    private MoreLikeThis mlt;
    private String[] textFieldNames;
    private String classFieldName;
    private IndexSearcher indexSearcher;
    private final int k;
    private Query query;
    private final int minDocsFreq;
    private final int minTermFreq;

    /**
     * Creates a classifier that votes among the {@code k} nearest neighbors.
     *
     * <p>The {@link MoreLikeThis} minimum document/term frequency settings are left untouched.
     *
     * @param k the number of neighbors to retrieve; must be at least 1
     * @throws IllegalArgumentException if {@code k < 1}
     */
    public KNearestNeighborClassifier(int k) {
        this(k, 0, 0);
    }

    /**
     * Creates a classifier that votes among the {@code k} nearest neighbors, with optional
     * {@link MoreLikeThis} frequency thresholds.
     *
     * @param k the number of neighbors to retrieve; must be at least 1
     * @param minDocsFreq if positive, passed to {@link MoreLikeThis#setMinDocFreq(int)} in train
     * @param minTermFreq if positive, passed to {@link MoreLikeThis#setMinTermFreq(int)} in train
     * @throws IllegalArgumentException if {@code k < 1}
     */
    public KNearestNeighborClassifier(int k, int minDocsFreq, int minTermFreq) {
        if (k < 1) {
            // k is used both as the search hit count and as the score denominator
            throw new IllegalArgumentException("k must be at least 1, got " + k);
        }
        this.k = k;
        this.minDocsFreq = minDocsFreq;
        this.minTermFreq = minTermFreq;
    }

    /**
     * Assigns a class to {@code text} by majority vote among its nearest neighbors.
     *
     * <p>The per-field MoreLikeThis queries are combined as SHOULD clauses. A document is only
     * considered if it has some term in the class field. If a filter query was given to
     * {@code train}, the document must match it as well.
     *
     * @param text the text to classify
     * @return the winning class and its vote count divided by {@code k}; an empty {@link BytesRef}
     *     with score 0 if no neighbor with a stored class value is found
     * @throws IOException if {@code train} has not been called, or if searching the index fails
     */
    @Override
    public ClassificationResult<BytesRef> assignClass(String text) throws IOException {
        if (this.mlt == null) {
            throw new IOException("You must first call Classifier#train");
        }
        BooleanQuery mltQuery = new BooleanQuery();
        for (String textFieldName : this.textFieldNames) {
            // A Reader can only be consumed once, so every field gets its own.
            Reader reader = new StringReader(text);
            Query likeQuery = this.mlt.like(textFieldName, new Reader[]{reader});
            mltQuery.add(new BooleanClause(likeQuery, BooleanClause.Occur.SHOULD));
        }
        // Restrict candidates to documents that carry any value in the class field.
        WildcardQuery classFieldQuery = new WildcardQuery(new Term(this.classFieldName, "*"));
        mltQuery.add(new BooleanClause(classFieldQuery, BooleanClause.Occur.MUST));
        if (this.query != null) {
            mltQuery.add(this.query, BooleanClause.Occur.MUST);
        }
        TopDocs topDocs = this.indexSearcher.search(mltQuery, this.k);
        return this.selectClassFromNeighbors(topDocs);
    }

    /**
     * Counts the stored class values of the retrieved neighbors and returns the most frequent one.
     *
     * @param topDocs the neighbors, in rank order
     * @return the most frequent class with score {@code votes / k}
     * @throws IOException if loading a stored document fails
     */
    private ClassificationResult<BytesRef> selectClassFromNeighbors(TopDocs topDocs) throws IOException {
        // Insertion order equals neighbor rank order. Combined with the strict '>' below, a tie
        // is won by the class whose first vote came from the better-ranked neighbor, instead of
        // depending on HashMap iteration order.
        Map<BytesRef, Integer> classCounts = new LinkedHashMap<BytesRef, Integer>();
        for (ScoreDoc scoreDoc : topDocs.scoreDocs) {
            String classValue = this.indexSearcher.doc(scoreDoc.doc).get(this.classFieldName);
            if (classValue == null) {
                // The neighbor has no stored class value, so it cannot vote.
                continue;
            }
            BytesRef cl = new BytesRef(classValue);
            Integer count = classCounts.get(cl);
            classCounts.put(cl, count == null ? 1 : count + 1);
        }
        int max = 0;
        BytesRef assignedClass = new BytesRef();
        for (Map.Entry<BytesRef, Integer> entry : classCounts.entrySet()) {
            int count = entry.getValue();
            if (count > max) {
                max = count;
                // Keys are private BytesRefs built above, so no copy is needed.
                assignedClass = entry.getKey();
            }
        }
        double score = max / (double) this.k;
        return new ClassificationResult<BytesRef>(assignedClass, score);
    }

    /**
     * Trains on a single text field without a filter query.
     *
     * @throws IOException if setting up the searcher fails
     */
    @Override
    public void train(AtomicReader atomicReader, String textFieldName, String classFieldName, Analyzer analyzer) throws IOException {
        this.train(atomicReader, textFieldName, classFieldName, analyzer, null);
    }

    /**
     * Trains on a single text field with an optional filter query.
     *
     * @throws IOException if setting up the searcher fails
     */
    @Override
    public void train(AtomicReader atomicReader, String textFieldName, String classFieldName, Analyzer analyzer, Query query) throws IOException {
        this.train(atomicReader, new String[]{textFieldName}, classFieldName, analyzer, query);
    }

    /**
     * Prepares the classifier to search {@code atomicReader}.
     *
     * <p>Replaces any previous training state.
     *
     * @param atomicReader the index holding the training documents
     * @param textFieldNames the fields whose text is compared; copied, must not be null
     * @param classFieldName the field holding each document's class
     * @param analyzer the analyzer MoreLikeThis uses on the text to classify
     * @param query an optional filter every neighbor must match; may be null
     * @throws IOException declared by the {@link Classifier} contract
     */
    @Override
    public void train(AtomicReader atomicReader, String[] textFieldNames, String classFieldName, Analyzer analyzer, Query query) throws IOException {
        // Defensive copy: later mutation of the caller's array must not change the classifier.
        this.textFieldNames = textFieldNames.clone();
        this.classFieldName = classFieldName;
        this.mlt = new MoreLikeThis(atomicReader);
        this.mlt.setAnalyzer(analyzer);
        this.mlt.setFieldNames(this.textFieldNames);
        this.indexSearcher = new IndexSearcher(atomicReader);
        // Non-positive values mean "keep MoreLikeThis's own setting".
        if (this.minDocsFreq > 0) {
            this.mlt.setMinDocFreq(this.minDocsFreq);
        }
        if (this.minTermFreq > 0) {
            this.mlt.setMinTermFreq(this.minTermFreq);
        }
        this.query = query;
    }
}

