/*
 * Decompiled with CFR 0.152.
 */
package org.apache.lucene.classification.utils;

import java.io.IOException;
import java.util.HashMap;
import org.apache.lucene.analysis.Analyzer;
import org.apache.lucene.document.Document;
import org.apache.lucene.document.Field;
import org.apache.lucene.document.FieldType;
import org.apache.lucene.document.TextField;
import org.apache.lucene.index.IndexReader;
import org.apache.lucene.index.IndexWriter;
import org.apache.lucene.index.IndexWriterConfig;
import org.apache.lucene.index.IndexableField;
import org.apache.lucene.index.LeafReader;
import org.apache.lucene.index.Terms;
import org.apache.lucene.search.IndexSearcher;
import org.apache.lucene.search.MatchAllDocsQuery;
import org.apache.lucene.search.Query;
import org.apache.lucene.search.ScoreDoc;
import org.apache.lucene.search.Sort;
import org.apache.lucene.search.grouping.GroupDocs;
import org.apache.lucene.search.grouping.GroupingSearch;
import org.apache.lucene.search.grouping.TopGroups;
import org.apache.lucene.store.Directory;
import org.apache.lucene.uninverting.UninvertingReader;

/**
 * Splits the documents of a single {@link LeafReader} into a training index, a test index and a
 * cross-validation index. Documents are grouped by the value of a class field so that each
 * class contributes to every split in proportion to its own size.
 */
public class DatasetSplitter {
    /** Maximum number of segments each output index is force-merged down to after writing. */
    private static final int MAX_SEGMENTS_AFTER_MERGE = 3;
    /** Group limit used when the class field's distinct-term count is not reported (a guess). */
    private static final long DEFAULT_NO_OF_CLASSES = 10000L;

    private final double crossValidationRatio;
    private final double testRatio;

    /**
     * Creates a splitter.
     *
     * @param testRatio upper bound on the fraction of each class's documents placed in the test index
     * @param crossValidationRatio upper bound on the fraction of each class's documents placed in
     *     the cross-validation index
     */
    public DatasetSplitter(double testRatio, double crossValidationRatio) {
        this.crossValidationRatio = crossValidationRatio;
        this.testRatio = testRatio;
    }

    /**
     * Copies the documents of {@code originalIndex} into the three target directories.
     *
     * <p>Documents are grouped by the value of {@code classFieldName} and visited in index order.
     * A document goes to the test index when a running counter (shared across all groups) is even
     * and the group's test quota ({@code totalHits * testRatio}) has not been reached. Otherwise it
     * goes to the cross-validation index until that group's quota
     * ({@code totalHits * crossValidationRatio}) is reached. All remaining documents go to the
     * training index. Each target index is then committed and force-merged.
     *
     * <p>All writers, and the reader wrapping {@code originalIndex}, are closed before this method
     * returns, including when it fails. NOTE(review): closing the wrapping
     * {@link UninvertingReader} presumably also closes {@code originalIndex} (FilterLeafReader
     * semantics) — confirm before reusing {@code originalIndex} afterwards.
     *
     * @param originalIndex reader over the source documents
     * @param trainingIndex directory receiving the training split
     * @param testIndex directory receiving the test split
     * @param crossValidationIndex directory receiving the cross-validation split
     * @param analyzer analyzer used by all three index writers
     * @param termVectors whether copied fields store term vectors (with offsets and positions)
     * @param classFieldName field whose values define the classes
     * @param fieldNames fields to copy; when {@code null} or empty, every stored field is copied
     * @throws IOException on I/O failure; unchecked failures while splitting are wrapped in an
     *     {@code IOException}
     */
    public void split(LeafReader originalIndex, Directory trainingIndex, Directory testIndex, Directory crossValidationIndex, Analyzer analyzer, boolean termVectors, String classFieldName, String ... fieldNames) throws IOException {
        // try-with-resources: every writer that was opened gets closed even if opening a later
        // one (or the split itself) fails, and close failures do not mask the primary exception.
        try (IndexWriter testWriter = new IndexWriter(testIndex, new IndexWriterConfig(analyzer));
             IndexWriter cvWriter = new IndexWriter(crossValidationIndex, new IndexWriterConfig(analyzer));
             IndexWriter trainingWriter = new IndexWriter(trainingIndex, new IndexWriterConfig(analyzer))) {
            int groupLimit = groupLimit(originalIndex, classFieldName);
            // Grouping needs sorted doc values on the class field; uninvert it on the fly.
            HashMap<String, UninvertingReader.Type> mapping = new HashMap<>();
            mapping.put(classFieldName, UninvertingReader.Type.SORTED);
            try (UninvertingReader uninvertingReader = new UninvertingReader(originalIndex, mapping)) {
                try {
                    writeSplits(uninvertingReader, originalIndex, testWriter, cvWriter, trainingWriter,
                            termVectors, classFieldName, groupLimit, fieldNames);
                } catch (RuntimeException e) {
                    // Historical contract: unchecked failures during the split surface as IOException.
                    // IOExceptions propagate unchanged instead of being double-wrapped.
                    throw new IOException(e);
                }
            }
        }
    }

    /**
     * Returns the number of groups to request: the class field's term count when the index reports
     * it, otherwise {@link #DEFAULT_NO_OF_CLASSES}, clamped to {@code [1, Integer.MAX_VALUE]}.
     */
    private static int groupLimit(LeafReader originalIndex, String classFieldName) throws IOException {
        Terms terms = originalIndex.terms(classFieldName);
        long noOfClasses = terms == null ? -1L : terms.size();
        if (noOfClasses == -1L) {
            noOfClasses = DEFAULT_NO_OF_CLASSES;
        }
        // Avoid a silent overflow in the narrowing cast.
        return (int) Math.max(1L, Math.min(noOfClasses, (long) Integer.MAX_VALUE));
    }

    /** Groups the documents by class, routes each one to a writer, then commits and merges. */
    private void writeSplits(IndexReader searchReader, LeafReader originalIndex, IndexWriter testWriter, IndexWriter cvWriter, IndexWriter trainingWriter, boolean termVectors, String classFieldName, int groupLimit, String[] fieldNames) throws IOException {
        IndexSearcher indexSearcher = new IndexSearcher(searchReader);
        GroupingSearch gs = new GroupingSearch(classFieldName);
        gs.setGroupSort(Sort.INDEXORDER);
        gs.setSortWithinGroup(Sort.INDEXORDER);
        gs.setAllGroups(true);
        gs.setGroupDocsLimit(originalIndex.maxDoc());
        TopGroups<?> topGroups = gs.search(indexSearcher, new MatchAllDocsQuery(), 0, groupLimit);
        FieldType ft = createFieldType(termVectors);

        // Counter shared across all groups: test documents are only taken at even positions.
        int position = 0;
        for (GroupDocs<?> group : topGroups.groups) {
            int totalHits = group.totalHits;
            double testSize = (double) totalHits * this.testRatio;
            double cvSize = (double) totalHits * this.crossValidationRatio;
            int tc = 0;
            int cvc = 0;
            for (ScoreDoc scoreDoc : group.scoreDocs) {
                Document doc = createNewDoc(originalIndex, ft, scoreDoc, fieldNames);
                if (position % 2 == 0 && tc < testSize) {
                    testWriter.addDocument(doc);
                    ++tc;
                } else if (cvc < cvSize) {
                    cvWriter.addDocument(doc);
                    ++cvc;
                } else {
                    trainingWriter.addDocument(doc);
                }
                ++position;
            }
        }
        testWriter.commit();
        cvWriter.commit();
        trainingWriter.commit();
        testWriter.forceMerge(MAX_SEGMENTS_AFTER_MERGE);
        cvWriter.forceMerge(MAX_SEGMENTS_AFTER_MERGE);
        trainingWriter.forceMerge(MAX_SEGMENTS_AFTER_MERGE);
    }

    /**
     * Builds the stored text field type used for every copied field, optionally with term vectors.
     */
    private static FieldType createFieldType(boolean termVectors) {
        FieldType ft = new FieldType(TextField.TYPE_STORED);
        if (termVectors) {
            ft.setStoreTermVectors(true);
            ft.setStoreTermVectorOffsets(true);
            ft.setStoreTermVectorPositions(true);
        }
        return ft;
    }

    /**
     * Builds a new document from the stored fields of {@code scoreDoc.doc} in {@code originalIndex}.
     * When {@code fieldNames} is non-empty, it copies only the first field with each given name
     * (missing names are skipped). Otherwise it copies every stored field. All copies use {@code ft}.
     */
    private static Document createNewDoc(LeafReader originalIndex, FieldType ft, ScoreDoc scoreDoc, String[] fieldNames) throws IOException {
        Document doc = new Document();
        Document document = originalIndex.document(scoreDoc.doc);
        if (fieldNames != null && fieldNames.length > 0) {
            for (String fieldName : fieldNames) {
                IndexableField field = document.getField(fieldName);
                if (field != null) {
                    addCopy(doc, field, ft);
                }
            }
        } else {
            for (IndexableField field : document.getFields()) {
                addCopy(doc, field, ft);
            }
        }
        return doc;
    }

    /**
     * Adds to {@code target} a copy of {@code field} typed as {@code ft}. It uses the first non-null
     * value among the reader, binary, string and numeric values; a field with none is skipped.
     */
    private static void addCopy(Document target, IndexableField field, FieldType ft) {
        String name = field.name();
        if (field.readerValue() != null) {
            target.add(new Field(name, field.readerValue(), ft));
        } else if (field.binaryValue() != null) {
            target.add(new Field(name, field.binaryValue(), ft));
        } else if (field.stringValue() != null) {
            target.add(new Field(name, field.stringValue(), ft));
        } else if (field.numericValue() != null) {
            target.add(new Field(name, field.numericValue().toString(), ft));
        }
    }
}

