/*
 * Decompiled with CFR 0.152.
 */
package io.trino.plugin.ml;

import io.trino.plugin.ml.Classifier;
import io.trino.plugin.ml.Dataset;
import io.trino.plugin.ml.FeatureTransformation;
import io.trino.plugin.ml.FeatureVector;
import io.trino.plugin.ml.Model;
import io.trino.plugin.ml.ModelUtils;
import io.trino.plugin.ml.type.ClassifierType;
import io.trino.plugin.ml.type.ModelType;
import java.util.List;
import java.util.Objects;

/**
 * A {@link Classifier} that passes every input through a {@link FeatureTransformation}
 * before handing it to a wrapped classifier.
 *
 * <p>Both training and classification go through the same transformation, so the wrapped
 * classifier only ever sees transformed features.
 */
public class ClassifierFeatureTransformer
implements Classifier<Integer> {
    private final Classifier<Integer> classifier;
    private final FeatureTransformation transformation;

    /**
     * Creates a transformer-backed classifier.
     *
     * @param classifier the classifier that receives transformed features
     * @param transformation the transformation applied to features before classification
     * @throws NullPointerException if either argument is {@code null}
     */
    public ClassifierFeatureTransformer(Classifier<Integer> classifier, FeatureTransformation transformation) {
        this.classifier = Objects.requireNonNull(classifier, "classifier is null");
        this.transformation = Objects.requireNonNull(transformation, "transformation is null");
    }

    /**
     * Returns {@link ClassifierType#BIGINT_CLASSIFIER}.
     *
     * <p>This value is returned regardless of the wrapped classifier's own type.
     */
    @Override
    public ModelType getType() {
        return ClassifierType.BIGINT_CLASSIFIER;
    }

    /**
     * Serializes the wrapped classifier and transformation, in that order, via
     * {@link ModelUtils#serializeModels}.
     *
     * <p>{@link #deserialize(byte[])} relies on this order: index 0 is the classifier and
     * index 1 is the transformation.
     */
    @Override
    public byte[] getSerializedData() {
        return ModelUtils.serializeModels(this.classifier, this.transformation);
    }

    /**
     * Reconstructs an instance from bytes produced by {@link #getSerializedData()}.
     *
     * @param data serialized models; element 0 must be a {@code Classifier<Integer>} and
     *        element 1 a {@link FeatureTransformation}
     * @return a new instance wrapping the deserialized models
     */
    public static ClassifierFeatureTransformer deserialize(byte[] data) {
        List<Model> models = ModelUtils.deserializeModels(data);
        // The label type is erased at runtime, so this cast cannot be checked. It is safe because
        // getSerializedData() always writes this class's Classifier<Integer> at index 0.
        @SuppressWarnings("unchecked")
        Classifier<Integer> deserializedClassifier = (Classifier<Integer>) models.get(0);
        FeatureTransformation deserializedTransformation = (FeatureTransformation) models.get(1);
        return new ClassifierFeatureTransformer(deserializedClassifier, deserializedTransformation);
    }

    /**
     * Transforms {@code features} and classifies the result with the wrapped classifier.
     *
     * @param features the raw (untransformed) features
     * @return the wrapped classifier's label for the transformed features
     */
    @Override
    public Integer classify(FeatureVector features) {
        return this.classifier.classify(this.transformation.transform(features));
    }

    /**
     * Trains the transformation on {@code dataset} first. Then trains the wrapped classifier
     * on the dataset as transformed by that freshly trained transformation.
     *
     * @param dataset the raw (untransformed) training data
     */
    @Override
    public void train(Dataset dataset) {
        // Order matters: the transformation must be fitted before it is used to transform the data.
        this.transformation.train(dataset);
        this.classifier.train(this.transformation.transform(dataset));
    }
}

