/*
 * Decompiled with CFR 0.152.
 */
package com.facebook.presto.ml;

import com.facebook.presto.common.block.BlockBuilder;
import com.facebook.presto.common.type.VarcharType;
import com.facebook.presto.ml.EvaluateClassifierPredictionsState;
import com.facebook.presto.spi.function.AggregationFunction;
import com.facebook.presto.spi.function.AggregationState;
import com.facebook.presto.spi.function.CombineFunction;
import com.facebook.presto.spi.function.InputFunction;
import com.facebook.presto.spi.function.LiteralParameters;
import com.facebook.presto.spi.function.OutputFunction;
import com.facebook.presto.spi.function.SqlType;
import com.google.common.collect.Sets;
import io.airlift.slice.Slice;
import io.airlift.slice.Slices;
import java.nio.charset.StandardCharsets;
import java.util.Locale;
import java.util.Map;
import java.util.Set;

/**
 * Aggregation {@code evaluate_classifier_predictions}: consumes (truth, prediction)
 * label pairs and produces a textual report containing the overall accuracy plus
 * per-class precision and recall.
 *
 * <p>Counts are kept in three label-keyed maps on the aggregation state: true
 * positives (truth equals prediction), false positives (keyed by the wrong
 * prediction) and false negatives (keyed by the missed truth).
 */
@AggregationFunction(value="evaluate_classifier_predictions")
public final class EvaluateClassifierPredictionsAggregation {
    /**
     * Fixed charge added to the key's byte length whenever a new label is recorded.
     * NOTE(review): presumably accounts for the 4-byte int counter stored per label - confirm.
     */
    private static final int ENTRY_OVERHEAD_BYTES = 4;

    private EvaluateClassifierPredictionsAggregation() {
    }

    /**
     * Records one (truth, prediction) pair of {@code bigint} labels by converting both
     * to their decimal string form and delegating to the varchar overload.
     *
     * @param state aggregation state to update
     * @param truth the actual label
     * @param prediction the predicted label
     */
    @InputFunction
    public static void input(@AggregationState EvaluateClassifierPredictionsState state, @SqlType(value="bigint") long truth, @SqlType(value="bigint") long prediction) {
        input(state, Slices.utf8Slice(String.valueOf(truth)), Slices.utf8Slice(String.valueOf(prediction)));
    }

    /**
     * Records one (truth, prediction) pair of {@code varchar} labels.
     *
     * <p>A matching pair increments the true-positive count of that label. A mismatch
     * increments the false-positive count of the predicted label and the false-negative
     * count of the true label.
     *
     * @param state aggregation state to update
     * @param truth the actual label
     * @param prediction the predicted label
     */
    @InputFunction
    @LiteralParameters(value={"x", "y"})
    public static void input(@AggregationState EvaluateClassifierPredictionsState state, @SqlType(value="varchar(x)") Slice truth, @SqlType(value="varchar(y)") Slice prediction) {
        if (truth.equals(prediction)) {
            incrementCount(state, state.getTruePositives(), truth);
            return;
        }
        incrementCount(state, state.getFalsePositives(), prediction);
        incrementCount(state, state.getFalseNegatives(), truth);
    }

    /**
     * Adds one to the counter of {@code label} in {@code counts}, charging the state's
     * memory usage when the label is seen for the first time in that map.
     */
    private static void incrementCount(EvaluateClassifierPredictionsState state, Map<String, Integer> counts, Slice label) {
        String key = label.toStringUtf8();
        if (!counts.containsKey(key)) {
            state.addMemoryUsage(label.length() + ENTRY_OVERHEAD_BYTES);
        }
        counts.merge(key, 1, Integer::sum);
    }

    /**
     * Folds the counts of {@code scratchState} into {@code state} and charges
     * {@code state} for every label that is new to one of its maps.
     *
     * @param state the state that receives the merged counts
     * @param scratchState the state whose counts are added; it is only read
     */
    @CombineFunction
    public static void combine(@AggregationState EvaluateClassifierPredictionsState state, @AggregationState EvaluateClassifierPredictionsState scratchState) {
        int size = 0;
        size += mergeMaps(state.getTruePositives(), scratchState.getTruePositives());
        size += mergeMaps(state.getFalsePositives(), scratchState.getFalsePositives());
        size += mergeMaps(state.getFalseNegatives(), scratchState.getFalseNegatives());
        state.addMemoryUsage(size);
    }

    /**
     * Adds every count in {@code other} to the matching entry of {@code map}.
     *
     * @param map destination map, modified in place
     * @param other source map, only read
     * @return the memory charge for keys that were not previously present in {@code map}
     */
    private static int mergeMaps(Map<String, Integer> map, Map<String, Integer> other) {
        int deltaSize = 0;
        for (Map.Entry<String, Integer> entry : other.entrySet()) {
            String key = entry.getKey();
            if (!map.containsKey(key)) {
                // Measure in UTF-8 so the charge matches the Slice byte length used by input(),
                // independent of the JVM's default charset.
                deltaSize += key.getBytes(StandardCharsets.UTF_8).length + ENTRY_OVERHEAD_BYTES;
            }
            map.merge(key, entry.getValue(), Integer::sum);
        }
        return deltaSize;
    }

    /**
     * Writes the evaluation report as a single varchar value.
     *
     * <p>The report starts with an {@code Accuracy} line, followed by a {@code Class},
     * {@code Precision} and {@code Recall} line for every label present in any of the
     * three count maps. A ratio with a zero denominator is formatted as {@code NaN}.
     *
     * @param state aggregation state holding the accumulated counts
     * @param out block builder that receives the report
     */
    @OutputFunction(value="varchar")
    public static void output(@AggregationState EvaluateClassifierPredictionsState state, BlockBuilder out) {
        Map<String, Integer> truePositiveCounts = state.getTruePositives();
        Map<String, Integer> falsePositiveCounts = state.getFalsePositives();
        Map<String, Integer> falseNegativeCounts = state.getFalseNegatives();

        StringBuilder sb = new StringBuilder();
        long correct = sumCounts(truePositiveCounts);
        // Every mismatched row adds exactly one false positive, so TP + FP is the row count.
        long total = correct + sumCounts(falsePositiveCounts);
        sb.append(String.format(Locale.US, "Accuracy: %d/%d (%.2f%%)%n", correct, total, 100.0 * (double) correct / (double) total));

        Set<String> labels = Sets.union(Sets.union(truePositiveCounts.keySet(), falsePositiveCounts.keySet()), falseNegativeCounts.keySet());
        for (String label : labels) {
            int truePositives = truePositiveCounts.getOrDefault(label, 0);
            int falsePositives = falsePositiveCounts.getOrDefault(label, 0);
            int falseNegatives = falseNegativeCounts.getOrDefault(label, 0);
            // Widen before adding so two large int counters cannot overflow.
            long predicted = (long) truePositives + falsePositives;
            long actual = (long) truePositives + falseNegatives;
            sb.append(String.format(Locale.US, "Class '%s'%n", label));
            sb.append(String.format(Locale.US, "Precision: %d/%d (%.2f%%)%n", truePositives, predicted, 100.0 * (double) truePositives / (double) predicted));
            sb.append(String.format(Locale.US, "Recall: %d/%d (%.2f%%)%n", truePositives, actual, 100.0 * (double) truePositives / (double) actual));
        }
        VarcharType.VARCHAR.writeString(out, sb.toString());
    }

    /** Sums all counters of {@code counts} using long arithmetic. */
    private static long sumCounts(Map<String, Integer> counts) {
        return counts.values().stream().mapToLong(Integer::longValue).sum();
    }
}

