/*
 * Decompiled with CFR 0.152.
 */
package io.trino.plugin.ml;

import com.google.common.base.Preconditions;
import com.google.common.cache.Cache;
import com.google.common.cache.CacheBuilder;
import com.google.common.hash.HashCode;
import io.airlift.slice.Slice;
import io.airlift.slice.Slices;
import io.trino.plugin.ml.Classifier;
import io.trino.plugin.ml.FeatureVector;
import io.trino.plugin.ml.Model;
import io.trino.plugin.ml.ModelUtils;
import io.trino.plugin.ml.Regressor;
import io.trino.plugin.ml.type.ClassifierType;
import io.trino.plugin.ml.type.RegressorType;
import io.trino.spi.block.Block;
import io.trino.spi.function.ScalarFunction;
import io.trino.spi.function.SqlType;

/**
 * Scalar SQL functions that evaluate serialized models against a feature map.
 *
 * <p>Every function follows the same steps:
 * <ol>
 *   <li>Convert the {@code map(bigint,double)} argument with {@code ModelUtils.toFeatures}.</li>
 *   <li>Obtain the model through a small cache keyed by {@code ModelUtils.modelHash}.</li>
 *   <li>Verify that the model's type matches the function being called.</li>
 *   <li>Apply the model.</li>
 * </ol>
 */
public final class MLFunctions {
    // Deserialized models keyed by the hash of their serialized form; at most 5 entries are kept.
    private static final Cache<HashCode, Model> MODEL_CACHE = CacheBuilder.newBuilder().maximumSize(5L).build();
    // SQL type of the feature-map argument shared by every function in this class.
    private static final String MAP_BIGINT_DOUBLE = "map(bigint,double)";

    private MLFunctions() {
    }

    /**
     * Classifies a feature map with a {@code Classifier(varchar)} model.
     *
     * @param featuresMap feature map converted by {@code ModelUtils.toFeatures}
     * @param modelSlice serialized model
     * @return the predicted label, encoded as a UTF-8 slice
     * @throws IllegalArgumentException if the model is not a {@code Classifier(varchar)}
     */
    @ScalarFunction(value="classify")
    @SqlType(value="varchar")
    public static Slice varcharClassify(@SqlType(value=MAP_BIGINT_DOUBLE) Block featuresMap, @SqlType(value="Classifier(varchar)") Slice modelSlice) {
        FeatureVector features = ModelUtils.toFeatures(featuresMap);
        Model model = getOrLoadModel(modelSlice);
        Preconditions.checkArgument(model.getType().equals(ClassifierType.VARCHAR_CLASSIFIER), "model is not a Classifier(varchar)");
        Classifier varcharClassifier = (Classifier) model;
        // A varchar classifier is expected to yield String labels; any other result fails this cast.
        String label = (String) varcharClassifier.classify(features);
        return Slices.utf8Slice(label);
    }

    /**
     * Classifies a feature map with a {@code Classifier(bigint)} model.
     *
     * @param featuresMap feature map converted by {@code ModelUtils.toFeatures}
     * @param modelSlice serialized model
     * @return the predicted label, widened from {@code int} to {@code long}
     * @throws IllegalArgumentException if the model is not a {@code Classifier(bigint)}
     */
    @ScalarFunction
    @SqlType(value="bigint")
    public static long classify(@SqlType(value=MAP_BIGINT_DOUBLE) Block featuresMap, @SqlType(value="Classifier(bigint)") Slice modelSlice) {
        FeatureVector features = ModelUtils.toFeatures(featuresMap);
        Model model = getOrLoadModel(modelSlice);
        Preconditions.checkArgument(model.getType().equals(ClassifierType.BIGINT_CLASSIFIER), "model is not a Classifier(bigint)");
        Classifier classifier = (Classifier) model;
        // A bigint classifier is expected to yield Integer labels.
        // A null result causes an NPE on unboxing; a non-Integer result causes a ClassCastException.
        int label = (Integer) classifier.classify(features);
        return label;
    }

    /**
     * Runs a {@code Regressor} model on a feature map.
     *
     * @param featuresMap feature map converted by {@code ModelUtils.toFeatures}
     * @param modelSlice serialized model
     * @return the regressor's prediction
     * @throws IllegalArgumentException if the model is not a regressor
     */
    @ScalarFunction
    @SqlType(value="double")
    public static double regress(@SqlType(value=MAP_BIGINT_DOUBLE) Block featuresMap, @SqlType(value="Regressor") Slice modelSlice) {
        FeatureVector features = ModelUtils.toFeatures(featuresMap);
        Model model = getOrLoadModel(modelSlice);
        Preconditions.checkArgument(model.getType().equals(RegressorType.REGRESSOR), "model is not a regressor");
        Regressor regressor = (Regressor) model;
        return regressor.regress(features);
    }

    /**
     * Returns the cached model for {@code slice}, deserializing and caching it on a miss.
     *
     * @param slice serialized model
     * @return the deserialized model
     */
    private static Model getOrLoadModel(Slice slice) {
        HashCode modelHash = ModelUtils.modelHash(slice);
        Model model = MODEL_CACHE.getIfPresent(modelHash);
        if (model == null) {
            // The lookup and the insert are not atomic. Concurrent misses on the same hash
            // may each deserialize the model, and the last put replaces the earlier entry.
            model = ModelUtils.deserialize(slice);
            MODEL_CACHE.put(modelHash, model);
        }
        return model;
    }
}

