/*
 * Decompiled with CFR 0.152.
 */
package org.tribuo.interop.tensorflow;

import com.google.protobuf.Any;
import com.google.protobuf.InvalidProtocolBufferException;
import com.oracle.labs.mlrg.olcut.config.Config;
import com.oracle.labs.mlrg.olcut.config.Configurable;
import com.oracle.labs.mlrg.olcut.provenance.ConfiguredObjectProvenance;
import com.oracle.labs.mlrg.olcut.provenance.impl.ConfiguredObjectProvenanceImpl;
import java.util.Collections;
import java.util.List;
import java.util.Set;
import java.util.logging.Logger;
import org.tensorflow.Tensor;
import org.tensorflow.ndarray.NdArray;
import org.tensorflow.ndarray.NdArrays;
import org.tensorflow.ndarray.Shape;
import org.tensorflow.ndarray.buffer.DataBuffers;
import org.tensorflow.ndarray.buffer.FloatDataBuffer;
import org.tensorflow.types.TFloat32;
import org.tribuo.Example;
import org.tribuo.Feature;
import org.tribuo.ImmutableFeatureMap;
import org.tribuo.interop.tensorflow.FeatureConverter;
import org.tribuo.interop.tensorflow.TensorMap;
import org.tribuo.interop.tensorflow.protos.DenseFeatureConverterProto;
import org.tribuo.interop.tensorflow.protos.FeatureConverterProto;
import org.tribuo.math.la.DenseVector;
import org.tribuo.math.la.SGDVector;
import org.tribuo.math.la.VectorTuple;
import org.tribuo.protos.ProtoSerializable;
import org.tribuo.protos.ProtoSerializableClass;
import org.tribuo.protos.ProtoSerializableField;
import org.tribuo.protos.ProtoUtil;

/**
 * Converts Tribuo features or vectors into a dense {@link TFloat32} TensorFlow tensor
 * which is fed to a single named placeholder input.
 * <p>
 * A single example or vector produces a tensor of shape {@code [1, dimension]}; a batch
 * produces a tensor of shape {@code [batchSize, dimension]}. Every feature slot is
 * materialised. When the dense dimension exceeds {@link #THRESHOLD}, a warning is logged,
 * at most {@link #WARNING_THRESHOLD} times per converter instance.
 */
@ProtoSerializableClass(serializedDataClass = DenseFeatureConverterProto.class, version = DenseFeatureConverter.CURRENT_VERSION)
public class DenseFeatureConverter implements FeatureConverter {
    private static final long serialVersionUID = 1L;

    private static final Logger logger = Logger.getLogger(DenseFeatureConverter.class.getName());

    /** Protobuf serialization version written and accepted by this class. */
    public static final int CURRENT_VERSION = 0;

    /** Dense dimension above which a "large dense example" warning is logged. */
    public static final int THRESHOLD = 1000000;

    /** Maximum number of large-dimension warnings logged by a single instance. */
    public static final int WARNING_THRESHOLD = 10;

    /** Number of large-dimension warnings logged so far (updated without synchronization). */
    private int warningCount = 0;

    /** Name of the TensorFlow placeholder input that receives the dense tensor. */
    @Config(mandatory=true, description="TensorFlow Placeholder Input name.")
    @ProtoSerializableField
    private String inputName;

    /**
     * No-arg constructor for configuration-driven instantiation.
     * {@code inputName} is populated via its {@code @Config} annotation.
     */
    private DenseFeatureConverter() {
    }

    /**
     * Builds a dense feature converter.
     *
     * @param inputName The name of the TensorFlow placeholder input to feed.
     */
    public DenseFeatureConverter(String inputName) {
        this.inputName = inputName;
    }

    /**
     * Deserialization factory.
     *
     * @param version The serialized object version.
     * @param className The class name (not used by this implementation).
     * @param message The serialized data, expected to pack a {@link DenseFeatureConverterProto}.
     * @return The deserialized converter.
     * @throws InvalidProtocolBufferException If {@code message} does not contain a
     *     {@link DenseFeatureConverterProto}.
     * @throws IllegalArgumentException If {@code version} is negative or newer than
     *     {@link #CURRENT_VERSION}.
     */
    public static DenseFeatureConverter deserializeFromProto(int version, String className, Any message) throws InvalidProtocolBufferException {
        if (version < 0 || version > CURRENT_VERSION) {
            throw new IllegalArgumentException("Unknown version " + version + ", this class supports at most version " + CURRENT_VERSION);
        }
        DenseFeatureConverterProto proto = message.unpack(DenseFeatureConverterProto.class);
        return new DenseFeatureConverter(proto.getInputName());
    }

    /**
     * Serializes this converter to its protobuf representation.
     *
     * @return The protobuf.
     */
    public FeatureConverterProto serialize() {
        return (FeatureConverterProto) ProtoUtil.serialize((ProtoSerializable) this);
    }

    /**
     * Densifies an example using the supplied feature map.
     * <p>
     * Features whose ID lookup is negative (i.e. not present in the map) are skipped. If a
     * feature name occurs more than once, the last occurrence's value wins.
     *
     * @param example The example to densify.
     * @param featureIDMap The feature domain, which fixes the output length and index order.
     * @return A float array of length {@code featureIDMap.size()}.
     */
    float[] innerTransform(Example<?> example, ImmutableFeatureMap featureIDMap) {
        int numFeatures = featureIDMap.size();
        if (warningCount < WARNING_THRESHOLD && numFeatures > THRESHOLD) {
            logger.warning("Large dense example requested, featureIDMap.size() = " + numFeatures + ", example.size() = " + example.size());
            warningCount++;
        }
        float[] output = new float[numFeatures];
        for (Feature f : example) {
            int id = featureIDMap.getID(f.getName());
            if (id >= 0) {
                output[id] = (float) f.getValue();
            }
        }
        return output;
    }

    /**
     * Densifies a vector into a float array of length {@code vector.size()}.
     *
     * @param vector The vector to densify (dense or sparse).
     * @return The dense float copy.
     */
    private float[] innerTransform(SGDVector vector) {
        int dimension = vector.size();
        if (warningCount < WARNING_THRESHOLD && dimension > THRESHOLD) {
            logger.warning("Large dense example requested, dimension = " + dimension + ", numActiveElements = " + vector.numActiveElements());
            warningCount++;
        }
        float[] output = new float[dimension];
        if (vector instanceof DenseVector) {
            DenseVector denseVec = (DenseVector) vector;
            for (int i = 0; i < output.length; i++) {
                output[i] = (float) denseVec.get(i);
            }
        } else {
            // Only the active (stored) elements are visited; every other slot stays 0.
            for (VectorTuple f : vector) {
                output[f.index] = (float) f.value;
            }
        }
        return output;
    }

    /**
     * Wraps a single dense row in a {@code [1, row.length]} tensor bound to the input name.
     *
     * @param row The dense row values.
     * @return The tensor map.
     */
    private TensorMap singleRowTensor(float[] row) {
        return new TensorMap(inputName, TFloat32.tensorOf(Shape.of(1, row.length), DataBuffers.of(row)));
    }

    @Override
    public TensorMap convert(Example<?> example, ImmutableFeatureMap featureIDMap) {
        return singleRowTensor(innerTransform(example, featureIDMap));
    }

    @Override
    public TensorMap convert(List<? extends Example<?>> examples, ImmutableFeatureMap featureIDMap) {
        TFloat32 output = TFloat32.tensorOf(Shape.of(examples.size(), featureIDMap.size()));
        try {
            long row = 0;
            for (Example<?> example : examples) {
                output.set(NdArrays.vectorOf(innerTransform(example, featureIDMap)), row);
                row++;
            }
        } catch (RuntimeException e) {
            // Tensors must be closed explicitly; release it before propagating so it is not leaked.
            output.close();
            throw e;
        }
        return new TensorMap(inputName, output);
    }

    @Override
    public TensorMap convert(SGDVector vector) {
        return singleRowTensor(innerTransform(vector));
    }

    /**
     * Converts a batch of vectors into a {@code [vectors.size(), dimension]} tensor.
     *
     * @param vectors The vectors, which must be non-empty and all of the same dimension.
     * @return The tensor map.
     * @throws IllegalArgumentException If {@code vectors} is empty or the dimensions differ.
     */
    @Override
    public TensorMap convert(List<? extends SGDVector> vectors) {
        if (vectors.isEmpty()) {
            throw new IllegalArgumentException("Cannot convert an empty list of vectors, the input dimension is unknown.");
        }
        int dimension = vectors.get(0).size();
        // Validate before allocating so a bad input cannot leak a partially filled tensor.
        int idx = 0;
        for (SGDVector v : vectors) {
            if (v.size() != dimension) {
                throw new IllegalArgumentException("Vector " + idx + " has dimension " + v.size() + ", expected " + dimension);
            }
            idx++;
        }
        TFloat32 output = TFloat32.tensorOf(Shape.of(vectors.size(), dimension));
        try {
            long row = 0;
            for (SGDVector v : vectors) {
                output.set(NdArrays.vectorOf(innerTransform(v)), row);
                row++;
            }
        } catch (RuntimeException e) {
            // Tensors must be closed explicitly; release it before propagating so it is not leaked.
            output.close();
            throw e;
        }
        return new TensorMap(inputName, output);
    }

    @Override
    public Set<String> inputNamesSet() {
        return Collections.singleton(inputName);
    }

    @Override
    public String toString() {
        return "DenseFeatureConverter(inputName='" + inputName + "')";
    }

    /**
     * Returns the configured-object provenance for this converter.
     *
     * @return The provenance.
     */
    public ConfiguredObjectProvenance getProvenance() {
        return new ConfiguredObjectProvenanceImpl((Configurable) this, "FeatureConverter");
    }
}

