/*
 * Decompiled with CFR 0.152.
 */
package datafu.pig.sampling;

import datafu.pig.sampling.ScoredTuple;
import java.io.IOException;
import java.util.Comparator;
import org.apache.commons.math.random.RandomDataImpl;
import org.apache.pig.AlgebraicEvalFunc;
import org.apache.pig.EvalFunc;
import org.apache.pig.data.BagFactory;
import org.apache.pig.data.DataBag;
import org.apache.pig.data.Tuple;
import org.apache.pig.data.TupleFactory;
import org.apache.pig.impl.logicalLayer.FrontendException;
import org.apache.pig.impl.logicalLayer.schema.Schema;

/**
 * Pig UDF that draws a simple random sample from a bag, implemented as an
 * algebraic function with {@link Initial}, {@link Intermediate} and
 * {@link Final} stages.
 *
 * <p>{@code Initial} gives every item a uniform random score in [0, 1). Items
 * scoring below a lower threshold (see {@code getQ1}) are selected outright,
 * items scoring below an upper threshold (see {@code getQ2}) are put on a
 * waiting list together with their score, and the rest are dropped.
 * {@code Final} then tops the selection up from the waiting list, in
 * ascending score order, until {@code ceil(p * n)} items have been chosen.
 */
public class SimpleRandomSample
extends AlgebraicEvalFunc<DataBag> {
    /** Prefix handed to {@code getSchemaName} when naming the output bag field. */
    public static final String OUTPUT_BAG_NAME_PREFIX = "SRS";
    /** Factory shared by all stages for creating tuples. */
    private static final TupleFactory _TUPLE_FACTORY = TupleFactory.getInstance();
    /** Factory shared by all stages for creating bags. */
    private static final BagFactory _BAG_FACTORY = BagFactory.getInstance();

    /**
     * Creates the UDF without constructor arguments. The sampling probability
     * must then be supplied as a field of the input tuple; {@code Initial.exec}
     * rejects single-field input when no probability was configured.
     */
    public SimpleRandomSample() {
        super(new String[0]);
    }

    /**
     * Creates the UDF with a fixed sampling probability given as a constructor
     * argument.
     *
     * <p>The argument is forwarded to the superclass. Per the Pig
     * {@code AlgebraicEvalFunc} contract, the arguments passed to its
     * constructor are the ones used to instantiate the stage classes when the
     * function is evaluated through the non-combiner {@code exec} path, so
     * omitting them would leave {@code Initial} without a probability and make
     * it reject single-field input.
     *
     * @param samplingProbability the sampling probability as text; must parse
     *        as a double inside [0, 1]
     * @throws NumberFormatException if the text is not a valid double
     * @throws IllegalArgumentException if the parsed value is outside [0, 1]
     * @deprecated presumably superseded by passing the probability as the
     *             second field of the input tuple (see {@code Initial.exec})
     */
    @Deprecated
    public SimpleRandomSample(String samplingProbability) {
        super(samplingProbability);
        SimpleRandomSample.verifySamplingProbability(Double.parseDouble(samplingProbability));
    }

    /**
     * Names the class implementing the initial stage.
     *
     * @return the value of {@code Initial.class.getName()}
     */
    public String getInitial() {
        return Initial.class.getName();
    }

    /**
     * Names the class implementing the intermediate stage.
     *
     * @return the value of {@code Intermediate.class.getName()}
     */
    public String getIntermed() {
        return Intermediate.class.getName();
    }

    /**
     * Names the class implementing the final stage.
     *
     * @return the value of {@code Final.class.getName()}
     */
    public String getFinal() {
        return Final.class.getName();
    }

    /**
     * Derives the output schema: a single bag field that reuses the inner
     * schema of the input bag.
     *
     * @param input schema of the UDF arguments; its first field must be a bag
     * @return a schema holding one bag field, named via {@code getSchemaName}
     *         with the {@link #OUTPUT_BAG_NAME_PREFIX} prefix
     * @throws RuntimeException if the first input field is not a bag, or if
     *         reading or building the schema raises a {@code FrontendException}
     */
    public Schema outputSchema(Schema input) {
        // Pig's type code for a bag (DataType.BAG). Kept in a byte local because
        // the FieldSchema constructor takes a byte type code, and a bare int
        // literal is not implicitly narrowed when passed as a method argument.
        final byte bagType = 120;
        try {
            Schema.FieldSchema inputFieldSchema = input.getField(0);
            if (inputFieldSchema.type != bagType) {
                throw new RuntimeException("Expected a BAG as input");
            }
            String outputName = super.getSchemaName(OUTPUT_BAG_NAME_PREFIX, input);
            return new Schema(new Schema.FieldSchema(outputName, inputFieldSchema.schema, bagType));
        }
        catch (FrontendException e) {
            throw new RuntimeException(e);
        }
    }

    /**
     * Computes the lower score threshold. The stages select an item outright
     * when its random score is strictly below this value.
     *
     * @param n population size (or lower bound of it) the threshold is tuned for
     * @param p sampling probability
     * @return {@code p + t - sqrt(t^2 + 3tp)} with {@code t = 20 / (3n)}
     */
    private static double getQ1(long n, double p) {
        final double slack = 20.0 / (3.0 * (double)n);
        final double radicand = slack * slack + 3.0 * slack * p;
        return p + slack - Math.sqrt(radicand);
    }

    /**
     * Computes the upper score threshold. The stages drop an item when its
     * random score is not strictly below this value; scores between the lower
     * threshold and this one go to the waiting list.
     *
     * @param n population size (or lower bound of it) the threshold is tuned for
     * @param p sampling probability
     * @return {@code p + t + sqrt(t^2 + 2tp)} with {@code t = 10 / n}
     */
    private static double getQ2(long n, double p) {
        final double slack = 10.0 / (double)n;
        final double radicand = slack * slack + 2.0 * slack * p;
        return p + slack + Math.sqrt(radicand);
    }

    /**
     * Checks that a sampling probability lies in the closed interval [0, 1].
     *
     * @param p the sampling probability to validate
     * @throws IllegalArgumentException if {@code p} is NaN or outside [0, 1]
     */
    private static void verifySamplingProbability(double p) {
        // Expressed as a negated in-range test so that NaN, for which every
        // ordered comparison is false, is rejected rather than accepted.
        if (!(p >= 0.0 && p <= 1.0)) {
            throw new IllegalArgumentException("Sampling probability must be inside [0, 1].");
        }
    }

    /**
     * Orders intermediate tuples by the score they carry, ascending. Each
     * argument is decoded with {@code ScoredTuple.fromIntermediateTuple} and
     * the scores are compared with {@code compareTo}.
     */
    static class ScoredTupleComparator
    implements Comparator<Tuple> {
        private static final ScoredTupleComparator _instance = new ScoredTupleComparator();

        ScoredTupleComparator() {
        }

        /**
         * Returns the shared comparator instance.
         *
         * @return the singleton comparator
         */
        public static final ScoredTupleComparator getInstance() {
            return _instance;
        }

        /**
         * Compares two intermediate tuples by score.
         *
         * @param o1 first intermediate tuple
         * @param o2 second intermediate tuple
         * @return negative, zero or positive as the score of {@code o1} is
         *         less than, equal to or greater than that of {@code o2}
         * @throws RuntimeException if either tuple cannot be decoded or compared
         */
        @Override
        public int compare(Tuple o1, Tuple o2) {
            try {
                ScoredTuple t1 = ScoredTuple.fromIntermediateTuple(o1);
                ScoredTuple t2 = ScoredTuple.fromIntermediateTuple(o2);
                return t1.getScore().compareTo(t2.getScore());
            }
            // Only exceptions are wrapped; JVM errors such as OutOfMemoryError
            // propagate unchanged instead of being disguised as RuntimeException.
            catch (Exception e) {
                throw new RuntimeException("Cannot compare " + o1 + " and " + o2 + ".", e);
            }
        }
    }

    /**
     * Final stage: merges all partial results and tops the pre-selected items
     * up from the waiting list until the target sample size is reached.
     */
    public static class Final
    extends EvalFunc<DataBag> {
        public Final() {
        }

        /**
         * Accepts and ignores the constructor argument; this stage reads the
         * sampling probability from the partial tuples it receives.
         *
         * @param samplingProbability unused
         */
        @Deprecated
        public Final(String samplingProbability) {
        }

        /**
         * Produces the sample.
         *
         * <p>Each partial tuple is read as (p, item count, n1, selected bag,
         * waiting bag). The probability is taken from the first partial, the
         * counts are summed into {@code n}, and the target size is
         * {@code ceil(p * n)}. Waitlisted tuples are kept in a bag sorted by
         * score and moved into the result, lowest score first, while more
         * items are needed. Over- and under-selection are only reported on
         * stderr; the bag is returned either way.
         *
         * @param input tuple whose first field is the bag of partial tuples
         * @return the bag of sampled items
         * @throws IOException if reading a tuple field or a bag fails
         */
        public DataBag exec(Tuple input) throws IOException {
            DataBag partials = (DataBag)input.get(0);
            boolean probabilityRead = false;
            double p = 0.0;
            long n = 0L;
            DataBag selected = _BAG_FACTORY.newDefaultBag();
            DataBag waiting = _BAG_FACTORY.newSortedBag(ScoredTupleComparator.getInstance());

            for (Tuple partial : partials) {
                if (!probabilityRead) {
                    p = (Double)partial.get(0);
                    probabilityRead = true;
                }
                long partialCount = (Long)partial.get(1);
                n += partialCount;
                selected.addAll((DataBag)partial.get(3));
                waiting.addAll((DataBag)partial.get(4));
            }

            final long numSelected = selected.size();
            final long numWaiting = waiting.size();
            final long s = (long)Math.ceil(p * (double)n);
            System.out.println("To sample " + s + " items from " + n + ", we pre-selected " + numSelected + ", and waitlisted " + numWaiting + ".");

            long numNeeded = s - numSelected;
            if (numNeeded < 0L) {
                System.err.println("Pre-selected " + numSelected + " items, but only needed " + s + ".");
            }

            // The waiting bag iterates in ascending score order.
            for (Tuple scored : waiting) {
                if (numNeeded <= 0L) {
                    break;
                }
                selected.add(ScoredTuple.fromIntermediateTuple(scored).getTuple());
                numNeeded--;
            }

            if (numNeeded > 0L) {
                System.err.println("The waiting list only has " + numWaiting + " items, but needed " + numNeeded + " more.");
            }
            return selected;
        }
    }

    /**
     * Intermediate stage: merges partial results and re-filters the combined
     * waiting list against thresholds computed for the merged population bound.
     */
    public static class Intermediate
    extends EvalFunc<Tuple> {
        public Intermediate() {
        }

        /**
         * Accepts and ignores the constructor argument; this stage reads the
         * sampling probability from the partial tuples it receives.
         *
         * @param samplingProbability unused
         */
        @Deprecated
        public Intermediate(String samplingProbability) {
        }

        /**
         * Combines partial tuples of the form (p, item count, n1, selected
         * bag, waiting bag) into one tuple of the same layout.
         *
         * <p>Waitlisted tuples whose score is below the lower threshold are
         * promoted to the selected bag, those below the upper threshold stay
         * waitlisted, and the rest are discarded. When the bound is zero the
         * merged waiting list is dropped entirely.
         *
         * @param input tuple whose first field is the bag of partial tuples
         * @return the merged partial tuple
         * @throws IOException if reading a tuple field or a bag fails
         */
        public Tuple exec(Tuple input) throws IOException {
            DataBag partials = (DataBag)input.get(0);
            DataBag selected = _BAG_FACTORY.newDefaultBag();
            DataBag mergedWaiting = _BAG_FACTORY.newDefaultBag();
            boolean probabilityRead = false;
            double p = 0.0;
            long numItems = 0L;
            long n1 = 0L;

            for (Tuple partial : partials) {
                if (!probabilityRead) {
                    p = (Double)partial.get(0);
                    probabilityRead = true;
                }
                // Field 2 is read before field 1, as in the combined expression
                // this replaces.
                // NOTE(review): n1 is overwritten on every pass, so only the last
                // partial's bound (and the running total) survive; confirm that
                // bounds from earlier partials are meant to be forgotten.
                long partialBound = (Long)partial.get(2);
                numItems += ((Long)partial.get(1)).longValue();
                n1 = Math.max(partialBound, numItems);
                selected.addAll((DataBag)partial.get(3));
                mergedWaiting.addAll((DataBag)partial.get(4));
            }

            DataBag waiting = _BAG_FACTORY.newDefaultBag();
            if (n1 > 0L) {
                final double q1 = SimpleRandomSample.getQ1(n1, p);
                final double q2 = SimpleRandomSample.getQ2(n1, p);
                for (Tuple candidate : mergedWaiting) {
                    ScoredTuple scored = ScoredTuple.fromIntermediateTuple(candidate);
                    double score = scored.getScore();
                    if (score < q1) {
                        selected.add(scored.getTuple());
                    } else if (score < q2) {
                        waiting.add(candidate);
                    }
                }
            }

            Tuple output = _TUPLE_FACTORY.newTuple();
            output.append(p);
            output.append(numItems);
            output.append(n1);
            output.append(selected);
            output.append(waiting);
            return output;
        }
    }

    /**
     * Initial stage: scores every incoming item with a uniform random number
     * and sorts it into "selected", "waiting" or dropped.
     *
     * <p>The instance remembers the probability and population bound it saw on
     * its first call and the number of items it has processed so far; later
     * calls must supply the same probability and bound.
     */
    public static class Initial
    extends EvalFunc<Tuple> {
        /** Random source shared by all instances; accessed only via {@link #nextDouble()}. */
        private static final RandomDataImpl _RNG = new RandomDataImpl();
        /** True until the first call to {@code exec} has passed validation. */
        private boolean _first = true;
        /** Sampling probability; negative while none has been provided. */
        private double _p = -1.0;
        /** Population bound given on the first call that carried one. */
        private long _n1 = 0L;
        /** Number of items seen by this instance across all calls. */
        private long _localCount = 0L;

        /** Draws the next uniform score from (0, 1) under the class lock. */
        private static synchronized double nextDouble() {
            return _RNG.nextUniform(0.0, 1.0);
        }

        public Initial() {
        }

        /**
         * Creates the stage with a probability fixed at construction time.
         *
         * @param samplingProbability the sampling probability as text
         */
        @Deprecated
        public Initial(String samplingProbability) {
            this._p = Double.parseDouble(samplingProbability);
        }

        /**
         * Scores the items of one input bag.
         *
         * <p>The input has one to three fields: the bag of items, optionally
         * the sampling probability, and optionally a population bound. With a
         * single field the probability must have been given to the
         * constructor. The thresholds are computed for the larger of the
         * supplied bound and the number of items this instance has seen.
         *
         * @param input tuple of (items[, probability[, population bound]])
         * @return tuple of (p, size of this bag, bound used, selected bag,
         *         waiting bag of scored tuples)
         * @throws IllegalArgumentException if the field count is wrong, the
         *         probability is missing or outside [0, 1], or the probability
         *         or bound differs from an earlier call
         * @throws IOException if reading a tuple field or a bag fails
         */
        public Tuple exec(Tuple input) throws IOException {
            final int numArgs = input.size();
            switch (numArgs) {
                case 1:
                    if (this._p < 0.0) {
                        throw new IllegalArgumentException("Sampling probability is not given.");
                    }
                    break;
                case 2:
                case 3:
                    break;
                default:
                    throw new IllegalArgumentException("The input tuple should have either two or three fields: a bag of items, the sampling probability, and optionally a good lower bound of the size of the population or the exact number.");
            }

            DataBag items = (DataBag)input.get(0);
            final long numItems = items.size();
            this._localCount += numItems;

            final double p;
            if (numArgs == 1) {
                p = this._p;
            } else {
                p = ((Number)input.get(1)).doubleValue();
            }

            if (!this._first) {
                if (p != this._p) {
                    throw new IllegalArgumentException("The sampling probability must be a scalar, but found two different values: " + this._p + " and " + p + ".");
                }
            } else {
                this._p = p;
                SimpleRandomSample.verifySamplingProbability(p);
            }

            long n1 = 0L;
            if (numArgs > 2) {
                n1 = ((Number)input.get(2)).longValue();
                if (this._first) {
                    this._n1 = n1;
                } else if (n1 != this._n1) {
                    throw new IllegalArgumentException("The lower bound of the population size must be a scalar, but found two different values: " + this._n1 + " and " + n1 + ".");
                }
            }
            this._first = false;

            // The items already seen by this instance are themselves a lower
            // bound on the population, so use whichever bound is larger.
            n1 = Math.max(n1, this._localCount);

            DataBag selected = _BAG_FACTORY.newDefaultBag();
            DataBag waiting = _BAG_FACTORY.newDefaultBag();
            if (n1 > 0L) {
                final double q1 = SimpleRandomSample.getQ1(n1, p);
                final double q2 = SimpleRandomSample.getQ2(n1, p);
                for (Tuple item : items) {
                    double score = Initial.nextDouble();
                    if (score < q1) {
                        selected.add(item);
                    } else if (score < q2) {
                        waiting.add(new ScoredTuple(score, item).getIntermediateTuple(_TUPLE_FACTORY));
                    }
                }
            }

            Tuple output = _TUPLE_FACTORY.newTuple();
            output.append(p);
            output.append(numItems);
            output.append(n1);
            output.append(selected);
            output.append(waiting);
            return output;
        }
    }
}

