/*
 * Decompiled with CFR 0.152.
 */
package datafu.pig.sampling;

import java.io.IOException;
import java.util.Iterator;
import java.util.Random;
import org.apache.pig.EvalFunc;
import org.apache.pig.builtin.Nondeterministic;
import org.apache.pig.data.BagFactory;
import org.apache.pig.data.DataBag;
import org.apache.pig.data.Tuple;
import org.apache.pig.impl.logicalLayer.FrontendException;
import org.apache.pig.impl.logicalLayer.schema.Schema;

@Nondeterministic
public class WeightedSample
extends EvalFunc<DataBag> {
    BagFactory bagFactory = BagFactory.getInstance();
    Long seed = null;

    public WeightedSample() {
    }

    public WeightedSample(String seed) {
        this.seed = Long.parseLong(seed);
    }

    public DataBag exec(Tuple input) throws IOException {
        DataBag output = this.bagFactory.newDefaultBag();
        DataBag samples = (DataBag)input.get(0);
        if (samples == null || samples.size() == 0L) {
            return output;
        }
        int numSamples = (int)samples.size();
        if (numSamples == 1) {
            return samples;
        }
        Tuple[] tuples = new Tuple[numSamples];
        int tupleIndex = 0;
        Iterator iterator = samples.iterator();
        while (iterator.hasNext()) {
            Tuple tuple;
            tuples[tupleIndex] = tuple = (Tuple)iterator.next();
            ++tupleIndex;
        }
        double[] scores = new double[numSamples];
        int scoreIndex = ((Number)input.get(1)).intValue();
        tupleIndex = 0;
        for (Tuple tuple : samples) {
            double score = ((Number)tuple.get(scoreIndex)).doubleValue();
            scores[tupleIndex] = score = Math.max(score, Double.MIN_NORMAL);
            ++tupleIndex;
        }
        int limitSamples = numSamples;
        if (input.size() == 3) {
            limitSamples = Math.min(((Number)input.get(2)).intValue(), numSamples);
        }
        Random rng = null;
        rng = this.seed == null ? new Random() : new Random(this.seed);
        for (int k = 0; k < limitSamples; ++k) {
            double val = rng.nextDouble();
            int idx = this.find_cumsum_interval(scores, val, k, numSamples);
            if (idx == numSamples) {
                idx = rng.nextInt(numSamples - k) + k;
            }
            output.add(tuples[idx]);
            scores[idx] = scores[k];
            tuples[idx] = tuples[k];
        }
        return output;
    }

    public int find_cumsum_interval(double[] scores, double val, int begin, int end) {
        int i;
        double sum = 0.0;
        double cumsum = 0.0;
        for (i = begin; i < end; ++i) {
            sum += scores[i];
        }
        for (i = begin; i < end; ++i) {
            if (!((cumsum += scores[i]) / sum > val)) continue;
            return i;
        }
        return end;
    }

    public Schema outputSchema(Schema input) {
        try {
            if (input.size() != 2 && input.size() != 3) {
                throw new RuntimeException("Expected input to have two or three fields");
            }
            Schema.FieldSchema inputFieldSchema = input.getField(0);
            if (inputFieldSchema.type != 120) {
                throw new RuntimeException("Expected a BAG as first input, got: " + inputFieldSchema.type);
            }
            if (input.getField((int)1).type != 10) {
                throw new RuntimeException("Expected an INT as second input, got: " + input.getField((int)1).type);
            }
            if (input.size() == 3 && input.getField((int)2).type != 10 && input.getField((int)2).type != 15) {
                throw new RuntimeException("Expected an INT or LONG as second input, got: " + input.getField((int)2).type);
            }
            return new Schema(new Schema.FieldSchema(this.getSchemaName(((Object)((Object)this)).getClass().getName().toLowerCase(), input), inputFieldSchema.schema, 120));
        }
        catch (FrontendException e) {
            e.printStackTrace();
            throw new RuntimeException(e);
        }
    }
}

