/*
 * Decompiled with CFR 0.152.
 */
package org.tribuo.interop.tensorflow;

import com.google.protobuf.Any;
import com.google.protobuf.ByteString;
import com.oracle.labs.mlrg.olcut.config.Configurable;
import com.oracle.labs.mlrg.olcut.provenance.ConfiguredObjectProvenance;
import com.oracle.labs.mlrg.olcut.provenance.impl.ConfiguredObjectProvenanceImpl;
import com.oracle.labs.mlrg.olcut.util.Pair;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.List;
import java.util.function.BiFunction;
import org.tensorflow.Operand;
import org.tensorflow.Tensor;
import org.tensorflow.framework.losses.MeanSquaredError;
import org.tensorflow.framework.losses.Reduction;
import org.tensorflow.ndarray.FloatNdArray;
import org.tensorflow.ndarray.Shape;
import org.tensorflow.ndarray.index.Index;
import org.tensorflow.ndarray.index.Indices;
import org.tensorflow.op.Op;
import org.tensorflow.op.Ops;
import org.tensorflow.op.core.Placeholder;
import org.tensorflow.types.TFloat32;
import org.tensorflow.types.family.TNumber;
import org.tribuo.Example;
import org.tribuo.ImmutableOutputInfo;
import org.tribuo.Output;
import org.tribuo.Prediction;
import org.tribuo.interop.tensorflow.OutputConverter;
import org.tribuo.interop.tensorflow.protos.OutputConverterProto;
import org.tribuo.protos.ProtoSerializable;
import org.tribuo.protos.ProtoSerializableClass;
import org.tribuo.protos.ProtoUtil;
import org.tribuo.regression.ImmutableRegressionInfo;
import org.tribuo.regression.Regressor;

/**
 * Converts between Tribuo {@link Regressor} outputs and TensorFlow {@link TFloat32} tensors.
 * <p>
 * Training uses a mean squared error loss reduced with {@link Reduction#SUM_OVER_BATCH_SIZE}.
 * The output transform is the identity, so the network output is read directly as the
 * regression value. This converter does not generate probabilities.
 */
@ProtoSerializableClass(version = RegressorConverter.CURRENT_VERSION)
public class RegressorConverter implements OutputConverter<Regressor> {
    private static final long serialVersionUID = 1L;

    /**
     * Protobuf serialization version.
     */
    public static final int CURRENT_VERSION = 0;

    /**
     * Deserialization factory.
     *
     * @param version   The serialized object version.
     * @param className The class name (not used by this converter).
     * @param message   The serialized payload, which must be empty for this converter.
     * @return A new RegressorConverter.
     * @throws IllegalArgumentException If the version is unsupported or the payload is non-empty.
     */
    public static RegressorConverter deserializeFromProto(int version, String className, Any message) {
        if (version < 0 || version > CURRENT_VERSION) {
            throw new IllegalArgumentException("Unknown version " + version + ", this class supports at most version " + CURRENT_VERSION);
        }
        // Compare by content rather than by reference: an empty ByteString is not guaranteed
        // to be the ByteString.EMPTY singleton instance.
        if (!message.getValue().isEmpty()) {
            throw new IllegalArgumentException("Invalid proto");
        }
        return new RegressorConverter();
    }

    @Override
    public OutputConverterProto serialize() {
        return ProtoUtil.serialize(this);
    }

    /**
     * Returns a mean squared error loss over the (label, prediction) pair.
     *
     * @return The loss function.
     */
    @Override
    @SuppressWarnings({"unchecked", "rawtypes"})
    public BiFunction<Ops, Pair<Placeholder<? extends TNumber>, Operand<TNumber>>, Operand<TNumber>> loss() {
        // Raw casts adapt the wildcard-typed placeholder to MeanSquaredError's generic signature.
        return (ops, pair) -> new MeanSquaredError("tribuo-mse", Reduction.SUM_OVER_BATCH_SIZE).call(ops, (Operand) pair.getA(), (Operand) pair.getB());
    }

    /**
     * Returns the identity transform; regression outputs are used as-is.
     *
     * @param <U> The numeric tensor type.
     * @return The identity op factory.
     */
    @Override
    public <U extends TNumber> BiFunction<Ops, Operand<U>, Op> outputTransformFunction() {
        return Ops::identity;
    }

    @Override
    public Prediction<Regressor> convertToPrediction(Tensor tensor, ImmutableOutputInfo<Regressor> outputIDInfo, int numValidFeatures, Example<Regressor> example) {
        Regressor r = convertToOutput(tensor, outputIDInfo);
        return new Prediction<>(r, numValidFeatures, example);
    }

    /**
     * Converts a tensor holding exactly one prediction into a {@link Regressor}.
     *
     * @param tensor       A float tensor of shape {@code [1, numDims]}, or rank 1 when there is a single dimension.
     * @param outputIDInfo The output domain.
     * @return The regressor.
     * @throws IllegalArgumentException If the tensor type or shape is invalid.
     */
    @Override
    public Regressor convertToOutput(Tensor tensor, ImmutableOutputInfo<Regressor> outputIDInfo) {
        FloatNdArray predictions = getBatchPredictions(tensor, outputIDInfo.size());
        long[] shape = predictions.shape().asArray();
        if (shape[0] != 1L) {
            throw new IllegalArgumentException("Supplied tensor must contain exactly one result, found " + shape[0]);
        }
        if (shape[1] != (long) outputIDInfo.size()) {
            throw new IllegalArgumentException("Supplied tensor has an incorrect number of dimensions, shape[1] = " + shape[1] + ", expected " + outputIDInfo.size());
        }
        String[] names = dimensionNames(outputIDInfo);
        double[] values = new double[names.length];
        for (int id = 0; id < values.length; ++id) {
            values[id] = predictions.getFloat(0L, id);
        }
        return new Regressor(names, values);
    }

    /**
     * Validates the tensor and views it as a 2-d {@code [batch, outputDims]} float array.
     * <p>
     * A rank-1 tensor is interpreted as a batch of single-dimensional outputs. It is only
     * accepted when {@code outputDims == 1}.
     *
     * @param tensor     The tensor to check.
     * @param outputDims The expected number of output dimensions.
     * @return A 2-d view of the predictions.
     * @throws IllegalArgumentException If the tensor is not a {@link TFloat32} or has an incompatible shape.
     */
    private FloatNdArray getBatchPredictions(Tensor tensor, int outputDims) {
        if (!(tensor instanceof TFloat32)) {
            throw new IllegalArgumentException("Tensor is not a 32-bit float. Found type " + tensor.getClass().getName());
        }
        TFloat32 floatTensor = (TFloat32) tensor;
        long[] shape = tensor.shape().asArray();
        if (shape.length == 1) {
            // Expanding [n] to [n, 1] only makes sense for a single output dimension; otherwise
            // downstream column reads would run past the end of the view.
            if (outputDims != 1) {
                throw new IllegalArgumentException("Supplied tensor has incorrect number of elements, tensor value dimension: " + Arrays.toString(shape) + ", output dimension: " + outputDims);
            }
            return floatTensor.slice(Indices.all(), Indices.newAxis());
        } else if (shape.length == 2) {
            if (shape[1] != (long) outputDims) {
                throw new IllegalArgumentException("Supplied tensor has incorrect number of elements, tensor value dimension: " + Arrays.toString(shape) + ", output dimension: " + outputDims);
            }
            return floatTensor;
        } else {
            throw new IllegalArgumentException("Supplied tensor has the wrong number of dimensions, shape = " + Arrays.toString(shape));
        }
    }

    /**
     * Builds the dimension-name array indexed by output id.
     * <p>
     * Uses the first name of each regressor in the output info.
     *
     * @param outputIDInfo The output domain.
     * @return Names where element {@code id} is the name of the dimension with that id.
     */
    private static String[] dimensionNames(ImmutableOutputInfo<Regressor> outputIDInfo) {
        String[] names = new String[outputIDInfo.size()];
        for (Pair<Integer, Regressor> p : outputIDInfo) {
            names[p.getA()] = p.getB().getNames()[0];
        }
        return names;
    }

    @Override
    public List<Prediction<Regressor>> convertToBatchPrediction(Tensor tensor, ImmutableOutputInfo<Regressor> outputIDInfo, int[] numValidFeatures, List<Example<Regressor>> examples) {
        List<Regressor> regressors = convertToBatchOutput(tensor, outputIDInfo);
        if (regressors.size() != examples.size() || regressors.size() != numValidFeatures.length) {
            throw new IllegalArgumentException("Invalid number of predictions received from Tensorflow, expected " + examples.size() + " examples and " + numValidFeatures.length + " feature counts, received " + regressors.size());
        }
        List<Prediction<Regressor>> output = new ArrayList<>(regressors.size());
        for (int i = 0; i < regressors.size(); ++i) {
            output.add(new Prediction<>(regressors.get(i), numValidFeatures[i], examples.get(i)));
        }
        return output;
    }

    /**
     * Converts a batch prediction tensor into one {@link Regressor} per row.
     *
     * @param tensor       A float tensor of shape {@code [batch, numDims]}, or rank 1 when there is a single dimension.
     * @param outputIDInfo The output domain.
     * @return The regressors, in row order.
     * @throws IllegalArgumentException If the tensor type or shape is invalid.
     */
    @Override
    public List<Regressor> convertToBatchOutput(Tensor tensor, ImmutableOutputInfo<Regressor> outputIDInfo) {
        FloatNdArray predictions = getBatchPredictions(tensor, outputIDInfo.size());
        int batchSize = (int) predictions.shape().asArray()[0];
        String[] names = dimensionNames(outputIDInfo);
        List<Regressor> output = new ArrayList<>(batchSize);
        for (int i = 0; i < batchSize; ++i) {
            double[] values = new double[names.length];
            for (int j = 0; j < names.length; ++j) {
                values[j] = predictions.getFloat(i, j);
            }
            output.add(new Regressor(names, values));
        }
        return output;
    }

    /**
     * Converts a single regressor into a {@code [1, numDims]} float tensor.
     * <p>
     * Column {@code id} receives the value located via the info's id-to-natural-order mapping.
     * The returned tensor is owned by the caller. If filling it fails, the tensor is closed
     * before the exception propagates.
     *
     * @param example      The regressor.
     * @param outputIDInfo The output domain, which must be an {@link ImmutableRegressionInfo}.
     * @return The tensor.
     */
    @Override
    public Tensor convertToTensor(Regressor example, ImmutableOutputInfo<Regressor> outputIDInfo) {
        // Resolve the mapping before allocating, so a failed cast cannot leak a native tensor.
        int[] ids = ((ImmutableRegressionInfo) outputIDInfo).getIDtoNaturalOrderMapping();
        double[] values = example.getValues();
        TFloat32 output = TFloat32.tensorOf(Shape.of(1L, outputIDInfo.size()));
        try {
            for (Pair<Integer, Regressor> p : outputIDInfo) {
                int id = p.getA();
                output.setFloat((float) values[ids[id]], 0L, id);
            }
            return output;
        } catch (RuntimeException e) {
            output.close();
            throw e;
        }
    }

    /**
     * Converts the outputs of a batch of examples into a {@code [numExamples, numDims]} float tensor.
     * <p>
     * The returned tensor is owned by the caller. If filling it fails, the tensor is closed
     * before the exception propagates.
     *
     * @param examples     The examples whose outputs are converted.
     * @param outputIDInfo The output domain, which must be an {@link ImmutableRegressionInfo}.
     * @return The tensor.
     */
    @Override
    public Tensor convertToTensor(List<Example<Regressor>> examples, ImmutableOutputInfo<Regressor> outputIDInfo) {
        // Resolve the mapping before allocating, so a failed cast cannot leak a native tensor.
        int[] ids = ((ImmutableRegressionInfo) outputIDInfo).getIDtoNaturalOrderMapping();
        int numDims = outputIDInfo.size();
        TFloat32 output = TFloat32.tensorOf(Shape.of(examples.size(), numDims));
        try {
            int i = 0;
            for (Example<Regressor> e : examples) {
                double[] values = e.getOutput().getValues();
                for (int j = 0; j < numDims; ++j) {
                    output.setFloat((float) values[ids[j]], i, j);
                }
                ++i;
            }
            return output;
        } catch (RuntimeException e) {
            output.close();
            throw e;
        }
    }

    @Override
    public boolean generatesProbabilities() {
        return false;
    }

    @Override
    public String toString() {
        return "RegressorConverter()";
    }

    @Override
    public ConfiguredObjectProvenance getProvenance() {
        return new ConfiguredObjectProvenanceImpl(this, "OutputConverter");
    }

    @Override
    public Class<Regressor> getTypeWitness() {
        return Regressor.class;
    }
}

