/*
 * Decompiled with CFR 0.152.
 */
package io.trino.plugin.ml;

import com.google.common.collect.ImmutableSet;
import com.google.common.collect.UnmodifiableIterator;
import io.airlift.slice.BasicSliceInput;
import io.airlift.slice.SliceOutput;
import io.airlift.slice.Slices;
import io.trino.plugin.ml.AbstractFeatureTransformation;
import io.trino.plugin.ml.Dataset;
import io.trino.plugin.ml.FeatureVector;
import io.trino.plugin.ml.type.ModelType;
import it.unimi.dsi.fastutil.ints.Int2DoubleMap;
import it.unimi.dsi.fastutil.ints.Int2DoubleOpenHashMap;
import it.unimi.dsi.fastutil.ints.IntIterator;
import java.util.Collection;
import java.util.HashMap;
import java.util.Map;

/**
 * Feature transformation that rescales each feature into the {@code [0, 1]} interval using the
 * minimum and maximum value observed for that feature during {@link #train(Dataset)}.
 *
 * <p>Features that showed no variation during training (minimum equal to maximum) are discarded,
 * and {@link #transform(FeatureVector)} maps them, like features never seen during training, to
 * {@code 0}. Note that {@code train} does not clear previously learned ranges, so repeated calls
 * accumulate into the existing state.
 */
public class FeatureUnitNormalizer extends AbstractFeatureTransformation {
    /** Bytes per serialized record: feature key ({@code int}), minimum and maximum ({@code double}). */
    private static final int SERIALIZED_ENTRY_SIZE = Integer.BYTES + 2 * Double.BYTES;

    // Per-feature extrema. The infinite default return values guarantee that the first observed
    // (non-NaN, finite) value of a feature wins both comparisons in train().
    private final Int2DoubleMap mins = new Int2DoubleOpenHashMap();
    private final Int2DoubleMap maxs = new Int2DoubleOpenHashMap();

    public FeatureUnitNormalizer() {
        mins.defaultReturnValue(Double.POSITIVE_INFINITY);
        maxs.defaultReturnValue(Double.NEGATIVE_INFINITY);
    }

    @Override
    public ModelType getType() {
        return ModelType.MODEL;
    }

    /**
     * Serializes the learned ranges as a flat sequence of {@code (int key, double min, double max)}
     * records, one per retained feature, in the iteration order of the minimum map.
     *
     * @return the serialized model; its length is a multiple of {@link #SERIALIZED_ENTRY_SIZE}
     */
    @Override
    public byte[] getSerializedData() {
        SliceOutput output = Slices.allocate(SERIALIZED_ENTRY_SIZE * mins.size()).getOutput();
        IntIterator keys = mins.keySet().iterator();
        while (keys.hasNext()) {
            // nextInt() avoids the boxing performed by the generic Iterator#next()
            int key = keys.nextInt();
            output.appendInt(key);
            output.appendDouble(mins.get(key));
            output.appendDouble(maxs.get(key));
        }
        return output.slice().getBytes();
    }

    /**
     * Rebuilds a normalizer from bytes produced by {@link #getSerializedData()}.
     *
     * @param modelData serialized {@code (key, min, max)} records
     * @return a normalizer holding the deserialized ranges
     */
    public static FeatureUnitNormalizer deserialize(byte[] modelData) {
        BasicSliceInput input = Slices.wrappedBuffer(modelData).getInput();
        FeatureUnitNormalizer model = new FeatureUnitNormalizer();
        while (input.isReadable()) {
            int key = input.readInt();
            // Read order must match getSerializedData(): minimum first, then maximum.
            model.mins.put(key, input.readDouble());
            model.maxs.put(key, input.readDouble());
        }
        return model;
    }

    /**
     * Widens the per-feature ranges with every value in {@code dataset}, then drops features whose
     * minimum equals their maximum.
     *
     * @param dataset the training data points
     */
    @Override
    public void train(Dataset dataset) {
        for (FeatureVector vector : dataset.getDatapoints()) {
            for (Map.Entry<Integer, Double> feature : vector.getFeatures().entrySet()) {
                int key = feature.getKey();
                double value = feature.getValue();
                if (value < mins.get(key)) {
                    mins.put(key, value);
                }
                if (value > maxs.get(key)) {
                    maxs.put(key, value);
                }
            }
        }

        // A constant feature has a zero-width range and cannot be rescaled, so forget it.
        // Iterate over a snapshot of the keys because entries are removed inside the loop.
        for (int key : ImmutableSet.copyOf(mins.keySet())) {
            if (mins.get(key) == maxs.get(key)) {
                mins.remove(key);
                maxs.remove(key);
            }
        }
    }

    /**
     * Maps each feature to {@code (value - min) / (max - min)} using its learned range and clamps
     * the result to {@code [0, 1]}.
     *
     * <p>Features without a learned range become {@code 0}. A {@code NaN} input for a feature with
     * a learned range stays {@code NaN}, because {@link Math#clamp(double, double, double)}
     * propagates it.
     *
     * @param features the vector to normalize
     * @return a new vector with the same keys and normalized values
     */
    @Override
    public FeatureVector transform(FeatureVector features) {
        Map<Integer, Double> transformed = new HashMap<>();
        for (Map.Entry<Integer, Double> entry : features.getFeatures().entrySet()) {
            int key = entry.getKey();
            double value = entry.getValue();
            if (mins.containsKey(key)) {
                double min = mins.get(key);
                value = (value - min) / (maxs.get(key) - min);
            }
            else {
                // Unseen or constant-during-training features carry no scale information.
                value = 0.0;
            }
            // Inputs outside the trained range, and floating-point rounding, can land outside [0, 1].
            transformed.put(key, Math.clamp(value, 0.0, 1.0));
        }
        return new FeatureVector(transformed);
    }
}

