/*
 *  Copyright (C) 2022 Cojen.org
 *
 *  This program is free software: you can redistribute it and/or modify
 *  it under the terms of the GNU Affero General Public License as
 *  published by the Free Software Foundation, either version 3 of the
 *  License, or (at your option) any later version.
 *
 *  This program is distributed in the hope that it will be useful,
 *  but WITHOUT ANY WARRANTY; without even the implied warranty of
 *  MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
 *  GNU Affero General Public License for more details.
 *
 *  You should have received a copy of the GNU Affero General Public License
 *  along with this program.  If not, see <https://www.gnu.org/licenses/>.
 */

package org.cojen.tupl.rows;

import java.io.IOException;

import java.lang.invoke.MethodHandle;

import java.util.Arrays;
import java.util.Comparator;

import org.cojen.tupl.Entry;
import org.cojen.tupl.Scanner;
import org.cojen.tupl.Sorter;
import org.cojen.tupl.Transaction;

/**
 * Sorts the rows produced by a {@link SortedQueryLauncher}'s source query. While scanning,
 * this class acts as the {@link RowConsumer} which receives the undecoded key/value entries,
 * collecting them into a chain of {@link ScanBatch} instances, the first of which is this
 * object itself. If fewer than {@code EXTERNAL_THRESHOLD} rows are produced, the rows are
 * decoded into an array and sorted in memory. Otherwise, the collected entries and all the
 * remaining ones are transcoded into an external {@link Sorter}.
 *
 * @author Brian S O'Neill
 */
final class RowSorter<R> extends ScanBatch<R> implements RowConsumer<R> {
    // FIXME: Make configurable and/or "smart".
    /**
     * Number of rows at which sorting switches from in-memory to an external merge sort.
     */
    private static final int EXTERNAL_THRESHOLD = 1_000_000;

    // Chain of collected batches. Once a batch has begun, mFirstBatch is this instance.
    private ScanBatch<R> mFirstBatch, mLastBatch;

    /**
     * Scans the launcher's source query and returns a scanner over the resulting rows,
     * sorted by the launcher's comparator.
     *
     * @param launcher supplies the source query, the comparator, and the table
     * @param txn transaction passed to the source scanner
     * @param args query arguments passed to the source scanner
     * @return a scanner whose {@code getComparator} returns the launcher's comparator
     */
    @SuppressWarnings("unchecked")
    static <R> Scanner<R> sort(SortedQueryLauncher<R> launcher, Transaction txn, Object... args)
        throws IOException
    {
        var sorter = new RowSorter<R>();

        // Pass sorter as if it's a row, but it's actually a RowConsumer.
        Scanner source = launcher.mSource.newScanner(txn, (R) sorter, args);

        int numRows = 0;
        for (Object c = source.row(); c != null; c = source.step(c)) {
            if (++numRows >= EXTERNAL_THRESHOLD) {
                // Too many rows to sort in memory. Hand off everything collected so far,
                // along with the rest of the source, to an external sort.
                return sorter.finishExternal(launcher, source);
            }
        }

        Comparator<R> comparator = launcher.mComparator;

        if (numRows == 0) {
            return new ARS<>(comparator);
        }

        var rows = (R[]) new Object[numRows];
        ScanBatch first = sorter.mFirstBatch;
        // Drop the sorter's references to the batch chain; decoding proceeds from first.
        sorter.mFirstBatch = null;
        sorter.mLastBatch = null;
        first.decodeAllRows(rows, 0);

        Arrays.parallelSort(rows, comparator);

        return new ARS<>(launcher.mTable, rows, comparator);
    }

    /**
     * Scans the launcher's source query, sorts all of the entries with an external sorter,
     * and passes the key and value of each sorted entry to a generated write-row handle,
     * which writes to the given writer.
     *
     * @param launcher supplies the source query and caches the generated write-row handle
     * @param writer destination passed to the write-row handle for each sorted entry
     * @param txn transaction passed to the source scanner
     * @param args query arguments passed to the source scanner
     */
    @SuppressWarnings("unchecked")
    static <R> void sortWrite(SortedQueryLauncher<R> launcher, RowWriter writer,
                              Transaction txn, Object... args)
        throws IOException
    {
        var ext = new External<R>(launcher);

        Scanner<Entry> sorted;

        try {
            // Pass ext as if it's a row, but it's actually a RowConsumer.
            ext.transferAll(launcher.mSource.newScanner(txn, (R) ext, args));
            sorted = ext.finishScan();
        } catch (Throwable e) {
            throw ext.failed(e);
        }

        // The write-row handle is resolved inside the try-with-resources block, so that the
        // sorted scanner is still closed if generating the handle fails.
        try (sorted) {
            MethodHandle mh = launcher.mWriteRow;

            if (mh == null) {
                SecondaryInfo info = ext.mSortedInfo;
                RowGen rowGen = info.rowGen();
                // This is a bit ugly -- creating a projection specification only to
                // immediately crack it open.
                byte[] spec = DecodePartialMaker.makeFullSpec(rowGen, null, launcher.projection());
                launcher.mWriteRow = mh = WriteRowMaker.makeWriteRowHandle(info, spec);
            }

            for (Entry e = sorted.row(); e != null; e = sorted.step(e)) {
                mh.invokeExact(writer, e.key(), e.value());
            }
        } catch (Throwable e) {
            throw RowUtils.rethrow(e);
        }
    }

    /**
     * Starts a new batch for the given evaluator. The first batch is this instance itself;
     * subsequent batches are appended to the end of the chain.
     */
    @Override
    public void beginBatch(Scanner scanner, RowEvaluator<R> evaluator) {
        ScanBatch<R> batch;
        if (mLastBatch == null) {
            mFirstBatch = batch = this;
        } else {
            batch = new ScanBatch<R>();
            mLastBatch.appendNext(batch);
        }
        mLastBatch = batch;
        batch.mEvaluator = evaluator;
    }

    /**
     * Adds an undecoded entry to the current (last) batch.
     */
    @Override
    public void accept(byte[] key, byte[] value) throws IOException {
        mLastBatch.addEntry(key, value);
    }

    /**
     * Switches to an external sort: transcodes every batch collected so far into an
     * external sorter, transfers the rest of the source, and returns a scanner over the
     * sorted results. On failure, the external sorter is reset.
     *
     * @param source the source scanner, positioned at the last row already collected
     */
    private Scanner<R> finishExternal(SortedQueryLauncher<R> launcher, Scanner source)
        throws IOException
    {
        var ext = new External<R>(launcher);

        try {
            RowDecoder<R> decoder = launcher.mDecoder;

            if (decoder == null) {
                launcher.mDecoder = decoder = SortDecoderMaker
                    .findDecoder(ext.mRowType, ext.mSortedInfo, launcher.projection());
            }

            // Transfer all the undecoded rows into the sorter.

            ScanBatch<R> batch = mFirstBatch;

            mFirstBatch = null;
            mLastBatch = null;

            // Scratch array, passed to and returned by each transcode call.
            byte[][] kvPairs = null;

            do {
                ext.assignTranscoder(batch.mEvaluator);
                kvPairs = batch.transcode(ext.mTranscoder, ext.mSorter, kvPairs);
            } while ((batch = batch.detachNext()) != null);

            // Transfer all the rest. The transcoder assigned for the last batch remains in
            // effect until the source begins a new batch.
            ext.transferAll(source);

            return new SRS<>(ext.finishScan(), decoder, launcher.mComparator);
        } catch (Throwable e) {
            throw ext.failed(e);
        }
    }

    /**
     * Scanner over rows which were sorted in memory.
     */
    private static class ARS<R> extends ArrayScanner<R> {
        private final Comparator<R> mComparator;

        /**
         * Constructs an empty scanner.
         */
        ARS(Comparator<R> comparator) {
            mComparator = comparator;
        }

        ARS(BaseTable<R> table, R[] rows, Comparator<R> comparator) {
            super(table, rows);
            mComparator = comparator;
        }

        @Override
        public int characteristics() {
            // The rows are sorted by mComparator, so SORTED (which requires ORDERED) is
            // reported to keep getComparator consistent with the Spliterator contract.
            return super.characteristics() | ORDERED | SORTED;
        }

        @Override
        public Comparator<R> getComparator() {
            return mComparator;
        }
    }

    /**
     * Scanner over entries which were sorted by an external sorter, constructed with the
     * decoder for those entries.
     */
    private static class SRS<R> extends ScannerScanner<R> {
        private final Comparator<R> mComparator;

        SRS(Scanner<Entry> scanner, RowDecoder<R> decoder, Comparator<R> comparator)
            throws IOException
        {
            super(scanner, decoder);
            mComparator = comparator;
        }

        @Override
        public int characteristics() {
            return NONNULL | ORDERED | IMMUTABLE | SORTED;
        }

        @Override
        public Comparator<R> getComparator() {
            return mComparator;
        }
    }

    /**
     * Performs an external merge sort.
     */
    private static final class External<R> implements RowConsumer<R> {
        final RowStore mRowStore;
        final Sorter mSorter;
        final Class<?> mRowType;
        final SecondaryInfo mSortedInfo;

        // Transcoder for the current batch; assigned per evaluator.
        Transcoder mTranscoder;

        // Pending transcoded entries, two array slots per entry; null once the scan finishes.
        private byte[][] mBatch;
        // Number of occupied slots in mBatch (twice the number of pending entries).
        private int mBatchSize;

        External(SortedQueryLauncher<R> launcher) throws IOException {
            mRowStore = launcher.mTable.rowStore();
            mSorter = mRowStore.mDatabase.newSorter();
            mRowType = launcher.mTable.rowType();

            SecondaryInfo sortedInfo = launcher.mSortedInfo;

            if (sortedInfo == null) {
                launcher.mSortedInfo = sortedInfo = SortDecoderMaker.findSortedInfo
                    (mRowType, launcher.mSpec, launcher.projection(), true);
            }

            mSortedInfo = sortedInfo;

            mBatch = new byte[100][];
        }

        /**
         * Selects the transcoder to use for entries accepted under the given evaluator.
         */
        void assignTranscoder(RowEvaluator<R> evaluator) {
            mTranscoder = mRowStore.findSortTranscoder(mRowType, evaluator, mSortedInfo);
        }

        /**
         * Transfers all remaining undecoded rows from the source into the sorter.
         */
        @SuppressWarnings("unchecked")
        void transferAll(Scanner source) throws IOException {
            // Pass `this` as if it's a row, but it's actually a RowConsumer.
            while (source.step(this) != null);
        }

        /**
         * Flushes any pending entries and returns a scanner over the sorted entries.
         */
        Scanner<Entry> finishScan() throws IOException {
            flush();
            mBatch = null;
            return mSorter.finishScan();
        }

        /**
         * Resets the sorter and rethrows the given exception. Declared to return an
         * exception so that callers can write {@code throw ext.failed(e)}.
         */
        RuntimeException failed(Throwable e) {
            try {
                mSorter.reset();
            } catch (Throwable e2) {
                RowUtils.suppress(e, e2);
            }
            throw RowUtils.rethrow(e);
        }

        @Override
        public void beginBatch(Scanner scanner, RowEvaluator<R> evaluator) throws IOException {
            // Pending entries were transcoded with the previous transcoder, so add them first.
            flush();
            assignTranscoder(evaluator);
        }

        @Override
        public void accept(byte[] key, byte[] value) throws IOException {
            byte[][] batch = mBatch;
            int size = mBatchSize;
            mTranscoder.transcode(key, value, batch, size);
            size += 2;
            if (size < batch.length) {
                mBatchSize = size;
            } else {
                // Batch array is full; the sorter's size argument is the number of entries.
                mSorter.addBatch(batch, 0, size >> 1);
                mBatchSize = 0;
            }
        }

        /**
         * Adds any pending entries to the sorter.
         */
        private void flush() throws IOException {
            if (mBatchSize > 0) {
                mSorter.addBatch(mBatch, 0, mBatchSize >> 1);
                mBatchSize = 0;
            }
        }
    }
}
