/*
 * Decompiled with CFR 0.152.
 */
package org.tribuo.interop.tensorflow;

import com.google.protobuf.Any;
import com.google.protobuf.InvalidProtocolBufferException;
import com.oracle.labs.mlrg.olcut.config.Config;
import com.oracle.labs.mlrg.olcut.config.Configurable;
import com.oracle.labs.mlrg.olcut.config.PropertyException;
import com.oracle.labs.mlrg.olcut.provenance.ConfiguredObjectProvenance;
import com.oracle.labs.mlrg.olcut.provenance.impl.ConfiguredObjectProvenanceImpl;
import java.util.Collections;
import java.util.List;
import java.util.Set;
import org.tensorflow.Tensor;
import org.tensorflow.ndarray.NdArray;
import org.tensorflow.ndarray.NdArrays;
import org.tensorflow.ndarray.Shape;
import org.tensorflow.ndarray.buffer.DataBuffers;
import org.tensorflow.ndarray.buffer.FloatDataBuffer;
import org.tensorflow.types.TFloat32;
import org.tribuo.Example;
import org.tribuo.Feature;
import org.tribuo.ImmutableFeatureMap;
import org.tribuo.interop.tensorflow.FeatureConverter;
import org.tribuo.interop.tensorflow.TensorMap;
import org.tribuo.interop.tensorflow.protos.FeatureConverterProto;
import org.tribuo.interop.tensorflow.protos.ImageConverterProto;
import org.tribuo.math.la.SGDVector;
import org.tribuo.math.la.VectorTuple;
import org.tribuo.protos.ProtoSerializable;
import org.tribuo.protos.ProtoSerializableClass;
import org.tribuo.protos.ProtoSerializableField;
import org.tribuo.protos.ProtoUtil;

@ProtoSerializableClass(serializedDataClass=ImageConverterProto.class, version=0)
public class ImageConverter
implements FeatureConverter {
    private static final long serialVersionUID = 1L;
    public static final int CURRENT_VERSION = 0;
    @Config(mandatory=true, description="TensorFlow Placeholder Input name.")
    @ProtoSerializableField
    private String inputName;
    @Config(mandatory=true, description="Image width.")
    @ProtoSerializableField
    private int width;
    @Config(mandatory=true, description="Image height.")
    @ProtoSerializableField
    private int height;
    @Config(mandatory=true, description="Number of channels.")
    @ProtoSerializableField
    private int channels;
    private int totalPixels;

    private ImageConverter() {
    }

    public ImageConverter(String inputName, int width, int height, int channels) {
        if (width < 1 || height < 1 || channels < 1) {
            throw new IllegalArgumentException("Inputs must be positive integers, found [" + width + "," + height + "," + channels + "]");
        }
        if (inputName == null || inputName.isEmpty()) {
            throw new IllegalArgumentException("The input name must be a valid String");
        }
        long values = (long)width * (long)height * (long)channels;
        if (values > Integer.MAX_VALUE) {
            throw new IllegalArgumentException("Image size must be less than 2^31, found " + values);
        }
        this.inputName = inputName;
        this.totalPixels = (int)values;
        this.width = width;
        this.height = height;
        this.channels = channels;
    }

    public void postConfig() {
        if (this.width < 1 || this.height < 1 || this.channels < 1) {
            throw new PropertyException("", "Inputs must be positive integers, found [" + this.width + "," + this.height + "," + this.channels + "]");
        }
        long values = (long)this.width * (long)this.height * (long)this.channels;
        if (values > Integer.MAX_VALUE) {
            throw new PropertyException("", "Image size must be less than 2^31, found " + values);
        }
        this.totalPixels = (int)values;
    }

    public static ImageConverter deserializeFromProto(int version, String className, Any message) throws InvalidProtocolBufferException {
        if (version < 0 || version > 0) {
            throw new IllegalArgumentException("Unknown version " + version + ", this class supports at most version " + 0);
        }
        ImageConverterProto proto = (ImageConverterProto)message.unpack(ImageConverterProto.class);
        return new ImageConverter(proto.getInputName(), proto.getWidth(), proto.getHeight(), proto.getChannels());
    }

    public FeatureConverterProto serialize() {
        return (FeatureConverterProto)ProtoUtil.serialize((ProtoSerializable)this);
    }

    @Override
    public Set<String> inputNamesSet() {
        return Collections.singleton(this.inputName);
    }

    @Override
    public TensorMap convert(Example<?> example, ImmutableFeatureMap featureIDMap) {
        float[] image = this.innerTransform(example, featureIDMap);
        return new TensorMap(this.inputName, (Tensor)TFloat32.tensorOf((Shape)Shape.of((long[])new long[]{1L, this.width, this.height, this.channels}), (FloatDataBuffer)DataBuffers.of((float[])image)));
    }

    float[] innerTransform(Example<?> example, ImmutableFeatureMap featureIDMap) {
        if (featureIDMap.size() > this.totalPixels) {
            throw new IllegalArgumentException("Found more values than expected, expected " + this.totalPixels + ", found " + featureIDMap.size());
        }
        float[] output = new float[this.totalPixels];
        for (Feature f : example) {
            int id = featureIDMap.getID(f.getName());
            output[id] = (float)f.getValue();
        }
        return output;
    }

    float[] innerTransform(SGDVector vector) {
        if (vector.size() > this.totalPixels) {
            throw new IllegalArgumentException("Found more values than expected, expected " + this.totalPixels + ", found " + vector.size());
        }
        float[] output = new float[this.totalPixels];
        for (VectorTuple f : vector) {
            output[f.index] = (float)f.value;
        }
        return output;
    }

    @Override
    public TensorMap convert(List<? extends Example<?>> examples, ImmutableFeatureMap featureIDMap) {
        TFloat32 output = TFloat32.tensorOf((Shape)Shape.of((long[])new long[]{examples.size(), this.width, this.height, this.channels}));
        int i = 0;
        for (Example<?> example : examples) {
            float[] features = this.innerTransform(example, featureIDMap);
            output.set((NdArray)NdArrays.wrap((Shape)Shape.of((long[])new long[]{this.width, this.height, this.channels}), (FloatDataBuffer)DataBuffers.of((float[])features)), new long[]{i});
            ++i;
        }
        return new TensorMap(this.inputName, (Tensor)output);
    }

    @Override
    public TensorMap convert(SGDVector vector) {
        float[] image = this.innerTransform(vector);
        return new TensorMap(this.inputName, (Tensor)TFloat32.tensorOf((Shape)Shape.of((long[])new long[]{1L, this.width, this.height, this.channels}), (FloatDataBuffer)DataBuffers.of((float[])image)));
    }

    @Override
    public TensorMap convert(List<? extends SGDVector> vectors) {
        TFloat32 output = TFloat32.tensorOf((Shape)Shape.of((long[])new long[]{vectors.size(), this.width, this.height, this.channels}));
        int i = 0;
        for (SGDVector sGDVector : vectors) {
            float[] features = this.innerTransform(sGDVector);
            output.set((NdArray)NdArrays.wrap((Shape)Shape.of((long[])new long[]{this.width, this.height, this.channels}), (FloatDataBuffer)DataBuffers.of((float[])features)), new long[]{i});
            ++i;
        }
        return new TensorMap(this.inputName, (Tensor)output);
    }

    public String toString() {
        return "ImageConverter(inputName='" + this.inputName + "',width=" + this.width + ",height=" + this.height + ",channels=" + this.channels + ")";
    }

    public ConfiguredObjectProvenance getProvenance() {
        return new ConfiguredObjectProvenanceImpl((Configurable)this, "FeatureConverter");
    }
}

