/*
 * Decompiled with CFR 0.152.
 */
package com.facebook.presto.operator.aggregation;

import com.facebook.presto.block.Block;
import com.facebook.presto.block.BlockBuilder;
import com.facebook.presto.block.BlockCursor;
import com.facebook.presto.operator.GroupByIdBlock;
import com.facebook.presto.operator.aggregation.Accumulator;
import com.facebook.presto.operator.aggregation.GroupedAccumulator;
import com.facebook.presto.operator.aggregation.SimpleAggregationFunction;
import com.facebook.presto.tuple.TupleInfo;
import com.facebook.presto.util.array.DoubleBigArray;
import com.facebook.presto.util.array.LongBigArray;
import com.google.common.base.Optional;
import com.google.common.base.Preconditions;
import io.airlift.slice.Slice;
import io.airlift.slice.Slices;

/**
 * Aggregation function computing the arithmetic mean ({@code avg}) of a FIXED_INT_64 or DOUBLE column.
 *
 * <p>The intermediate state is a 16-byte VARBINARY slice. It holds the running (sample-weighted) count as a
 * long at offset 0 and the running (sample-weighted) sum as a double at offset 8. The final result is a
 * DOUBLE equal to {@code sum / count}, or NULL when no row contributed (count is zero).
 *
 * <p>Approximate execution is not supported: any confidence other than {@code 1.0} is rejected.
 */
public class AverageAggregation extends SimpleAggregationFunction {
    // Byte layout of the intermediate state slice: [count: long][sum: double].
    private static final int COUNT_OFFSET = 0;
    private static final int SUM_OFFSET = 8;
    private static final int INTERMEDIATE_SIZE = 16;

    private final boolean inputIsLong;

    /**
     * Creates an average aggregation over a column of the given type.
     *
     * @param parameterType the input column type; must be FIXED_INT_64 or DOUBLE
     * @throws IllegalArgumentException if {@code parameterType} is neither FIXED_INT_64 nor DOUBLE
     */
    public AverageAggregation(TupleInfo.Type parameterType) {
        super(TupleInfo.SINGLE_DOUBLE, TupleInfo.SINGLE_VARBINARY, parameterType);
        if (parameterType == TupleInfo.Type.FIXED_INT_64) {
            this.inputIsLong = true;
        }
        else if (parameterType == TupleInfo.Type.DOUBLE) {
            this.inputIsLong = false;
        }
        else {
            throw new IllegalArgumentException("Expected parameter type to be FIXED_INT_64 or DOUBLE, but was " + parameterType);
        }
    }

    /**
     * Creates a per-group accumulator.
     *
     * @throws IllegalArgumentException if {@code confidence} is not exactly 1.0 (approximate queries are unsupported)
     */
    @Override
    protected GroupedAccumulator createGroupedAccumulator(Optional<Integer> maskChannel, Optional<Integer> sampleWeightChannel, double confidence, int valueChannel) {
        Preconditions.checkArgument(confidence == 1.0, "avg does not support approximate queries");
        return new AverageGroupedAccumulator(valueChannel, this.inputIsLong, maskChannel, sampleWeightChannel);
    }

    /**
     * Creates a single (ungrouped) accumulator.
     *
     * @throws IllegalArgumentException if {@code confidence} is not exactly 1.0 (approximate queries are unsupported)
     */
    @Override
    protected Accumulator createAccumulator(Optional<Integer> maskChannel, Optional<Integer> sampleWeightChannel, double confidence, int valueChannel) {
        Preconditions.checkArgument(confidence == 1.0, "avg does not support approximate queries");
        return new AverageAccumulator(valueChannel, this.inputIsLong, maskChannel, sampleWeightChannel);
    }

    /** Returns a cursor over the block if present, or {@code null} when the optional block is absent. */
    private static BlockCursor cursorOrNull(Optional<Block> block) {
        return block.isPresent() ? block.get().cursor() : null;
    }

    /** Encodes a (count, sum) pair into the 16-byte intermediate state layout. */
    private static Slice encodeIntermediate(long count, double sum) {
        Slice value = Slices.allocate(INTERMEDIATE_SIZE);
        value.setLong(COUNT_OFFSET, count);
        value.setDouble(SUM_OFFSET, sum);
        return value;
    }

    /** Accumulates a single running (count, sum) pair over all input rows. */
    public static class AverageAccumulator extends SimpleAggregationFunction.SimpleAccumulator {
        private final boolean inputIsLong;
        private long count;
        private double sum;

        public AverageAccumulator(int valueChannel, boolean inputIsLong, Optional<Integer> maskChannel, Optional<Integer> sampleWeightChannel) {
            super(valueChannel, TupleInfo.SINGLE_DOUBLE, TupleInfo.SINGLE_VARBINARY, maskChannel, sampleWeightChannel);
            this.inputIsLong = inputIsLong;
        }

        /**
         * Adds each non-null row with a positive sample weight to the running count and sum.
         * The row's value is weighted by its sample weight.
         */
        @Override
        protected void processInput(Block block, Optional<Block> maskBlock, Optional<Block> sampleWeightBlock) {
            BlockCursor values = block.cursor();
            BlockCursor masks = cursorOrNull(maskBlock);
            BlockCursor sampleWeights = cursorOrNull(sampleWeightBlock);

            for (int position = 0; position < block.getPositionCount(); position++) {
                // All cursors must stay in lock-step with the value cursor.
                Preconditions.checkState(values.advanceNextPosition());
                Preconditions.checkState(masks == null || masks.advanceNextPosition());
                Preconditions.checkState(sampleWeights == null || sampleWeights.advanceNextPosition());

                long sampleWeight = SimpleAggregationFunction.computeSampleWeight(masks, sampleWeights);
                if (values.isNull() || sampleWeight <= 0L) {
                    continue;
                }
                this.count += sampleWeight;
                // Widen to double BEFORE multiplying: sampleWeight * getLong() in long arithmetic can
                // silently overflow. This matches AverageGroupedAccumulator.processInput.
                double value = this.inputIsLong ? (double) values.getLong() : values.getDouble();
                this.sum += (double) sampleWeight * value;
            }
        }

        /** Merges intermediate (count, sum) slices produced by {@link #evaluateIntermediate}. */
        @Override
        protected void processIntermediate(Block block) {
            BlockCursor intermediates = block.cursor();
            for (int position = 0; position < block.getPositionCount(); position++) {
                Preconditions.checkState(intermediates.advanceNextPosition());
                Slice value = intermediates.getSlice();
                this.count += value.getLong(COUNT_OFFSET);
                this.sum += value.getDouble(SUM_OFFSET);
            }
        }

        /** Writes the running (count, sum) as a 16-byte intermediate slice. */
        @Override
        public void evaluateIntermediate(BlockBuilder out) {
            out.append(encodeIntermediate(this.count, this.sum));
        }

        /** Writes {@code sum / count}, or NULL if no row contributed. */
        @Override
        public void evaluateFinal(BlockBuilder out) {
            if (this.count != 0L) {
                out.append(this.sum / (double) this.count);
            }
            else {
                out.appendNull();
            }
        }
    }

    /** Accumulates a running (count, sum) pair per group id. */
    public static class AverageGroupedAccumulator extends SimpleAggregationFunction.SimpleGroupedAccumulator {
        private final boolean inputIsLong;
        private final LongBigArray counts;
        private final DoubleBigArray sums;

        public AverageGroupedAccumulator(int valueChannel, boolean inputIsLong, Optional<Integer> maskChannel, Optional<Integer> sampleWeightChannel) {
            super(valueChannel, TupleInfo.SINGLE_DOUBLE, TupleInfo.SINGLE_VARBINARY, maskChannel, sampleWeightChannel);
            this.inputIsLong = inputIsLong;
            this.counts = new LongBigArray();
            this.sums = new DoubleBigArray();
        }

        @Override
        public long getEstimatedSize() {
            return this.counts.sizeOf() + this.sums.sizeOf();
        }

        /**
         * Adds each non-null row with a positive sample weight to its group's count and sum.
         * The row's value is weighted by its sample weight.
         */
        @Override
        public void processInput(GroupByIdBlock groupIdsBlock, Block valuesBlock, Optional<Block> maskBlock, Optional<Block> sampleWeightBlock) {
            this.counts.ensureCapacity(groupIdsBlock.getGroupCount());
            this.sums.ensureCapacity(groupIdsBlock.getGroupCount());

            BlockCursor values = valuesBlock.cursor();
            BlockCursor masks = cursorOrNull(maskBlock);
            BlockCursor sampleWeights = cursorOrNull(sampleWeightBlock);

            for (int position = 0; position < groupIdsBlock.getPositionCount(); position++) {
                Preconditions.checkState(values.advanceNextPosition());
                Preconditions.checkState(masks == null || masks.advanceNextPosition());
                Preconditions.checkState(sampleWeights == null || sampleWeights.advanceNextPosition());

                long groupId = groupIdsBlock.getGroupId(position);
                long sampleWeight = SimpleAggregationFunction.computeSampleWeight(masks, sampleWeights);
                if (values.isNull() || sampleWeight <= 0L) {
                    continue;
                }
                this.counts.add(groupId, sampleWeight);
                double value = this.inputIsLong ? (double) values.getLong() : values.getDouble();
                this.sums.add(groupId, (double) sampleWeight * value);
            }
            // The values block must have exactly as many positions as the group-id block.
            Preconditions.checkState(!values.advanceNextPosition());
        }

        /** Merges intermediate (count, sum) slices into the corresponding groups. */
        @Override
        public void processIntermediate(GroupByIdBlock groupIdsBlock, Block block) {
            this.counts.ensureCapacity(groupIdsBlock.getGroupCount());
            this.sums.ensureCapacity(groupIdsBlock.getGroupCount());

            BlockCursor intermediateValues = block.cursor();
            for (int position = 0; position < groupIdsBlock.getPositionCount(); position++) {
                Preconditions.checkState(intermediateValues.advanceNextPosition());
                long groupId = groupIdsBlock.getGroupId(position);
                Slice value = intermediateValues.getSlice();
                this.counts.add(groupId, value.getLong(COUNT_OFFSET));
                this.sums.add(groupId, value.getDouble(SUM_OFFSET));
            }
            Preconditions.checkState(!intermediateValues.advanceNextPosition());
        }

        /** Writes the group's running (count, sum) as a 16-byte intermediate slice. */
        @Override
        public void evaluateIntermediate(int groupId, BlockBuilder output) {
            output.append(encodeIntermediate(this.counts.get(groupId), this.sums.get(groupId)));
        }

        /** Writes the group's {@code sum / count}, or NULL if no row contributed to the group. */
        @Override
        public void evaluateFinal(int groupId, BlockBuilder output) {
            long count = this.counts.get(groupId);
            if (count != 0L) {
                output.append(this.sums.get(groupId) / (double) count);
            }
            else {
                output.appendNull();
            }
        }
    }
}

