/*
 * Decompiled with CFR 0.152.
 */
package com.linkedin.feathr.common.tensor;

import com.linkedin.feathr.common.tensor.TensorData;
import com.linkedin.feathr.common.tensor.TensorType;
import com.linkedin.feathr.common.tensor.scalar.ScalarTensor;
import com.linkedin.feathr.common.tensorbuilder.BulkTensorBuilder;
import com.linkedin.feathr.common.tensorbuilder.DenseTensorBuilderFactory;
import com.linkedin.feathr.common.tensorbuilder.TensorBuilder;
import com.linkedin.feathr.common.tensorbuilder.UniversalTensorBuilderFactory;
import java.util.List;
import java.util.Map;
import java.util.Set;

/**
 * Static helpers that wrap scalars, primitive arrays, lists, sets and maps as {@link TensorData}
 * of a caller-supplied {@link TensorType}.
 */
public final class Tensors {
    private Tensors() {
        // Utility class; not instantiable.
    }

    /**
     * Obtains a dense bulk builder for {@code type} from {@link DenseTensorBuilderFactory#INSTANCE}
     * and checks that {@code size} values are compatible with the builder's cardinality.
     *
     * <p>When the builder does not have variable cardinality, {@code size} must equal its static
     * cardinality. In every case {@code size} must be a multiple of the static cardinality. A static
     * cardinality of 0 only admits {@code size == 0}. That case is tested explicitly because
     * {@code size % 0} would throw {@link ArithmeticException} rather than the intended
     * {@link IllegalArgumentException}.
     *
     * @param type the tensor type to build
     * @param size the number of values the caller intends to supply
     * @return a bulk builder for {@code type}
     * @throws IllegalArgumentException if {@code size} is incompatible with the builder's cardinality
     */
    private static BulkTensorBuilder getBulkBuilder(TensorType type, int size) {
        BulkTensorBuilder builder = DenseTensorBuilderFactory.INSTANCE.getBulkTensorBuilder(type);
        int staticCardinality = builder.getStaticCardinality();
        if (!builder.hasVariableCardinality() && staticCardinality != size) {
            throw new IllegalArgumentException("The number of values " + size + " is not equal to the size of the type " + staticCardinality + ".");
        }
        // Only 0 is a multiple of 0; avoid the division by zero that '%' would perform.
        boolean isMultiple = staticCardinality == 0 ? size == 0 : size % staticCardinality == 0;
        if (!isMultiple) {
            throw new IllegalArgumentException("The number of values " + size + " is not a multiple of the static size of the type " + staticCardinality + ".");
        }
        return builder;
    }

    /**
     * Wraps a single value as a scalar tensor.
     *
     * @param type a tensor type with no dimensions
     * @param scalar the value to wrap. It is passed to {@link ScalarTensor#wrap} together with the
     *     representation of {@code type}'s value type.
     * @return the wrapped scalar tensor
     * @throws IllegalArgumentException if {@code type} has one or more dimensions
     */
    public static TensorData asScalarTensor(TensorType type, Object scalar) {
        if (type.getDimensionTypes().size() > 0) {
            throw new IllegalArgumentException("Scalar tensors cannot have dimensions.");
        }
        return ScalarTensor.wrap(scalar, type.getValueType().getRepresentation());
    }

    /**
     * Builds a dense tensor of {@code type} from {@code floats} using a dense bulk builder.
     *
     * @throws IllegalArgumentException if the number of values is incompatible with the type's cardinality
     */
    public static TensorData asDenseTensor(TensorType type, float[] floats) {
        BulkTensorBuilder builder = getBulkBuilder(type, floats.length);
        return builder.build(floats);
    }

    /**
     * Builds a dense tensor of {@code type} from {@code ints} using a dense bulk builder.
     *
     * @throws IllegalArgumentException if the number of values is incompatible with the type's cardinality
     */
    public static TensorData asDenseTensor(TensorType type, int[] ints) {
        BulkTensorBuilder builder = getBulkBuilder(type, ints.length);
        return builder.build(ints);
    }

    /**
     * Builds a dense tensor of {@code type} from {@code longs} using a dense bulk builder.
     *
     * @throws IllegalArgumentException if the number of values is incompatible with the type's cardinality
     */
    public static TensorData asDenseTensor(TensorType type, long[] longs) {
        BulkTensorBuilder builder = getBulkBuilder(type, longs.length);
        return builder.build(longs);
    }

    /**
     * Builds a dense tensor of {@code type} from {@code doubles} using a dense bulk builder.
     *
     * @throws IllegalArgumentException if the number of values is incompatible with the type's cardinality
     */
    public static TensorData asDenseTensor(TensorType type, double[] doubles) {
        BulkTensorBuilder builder = getBulkBuilder(type, doubles.length);
        return builder.build(doubles);
    }

    /**
     * Builds a dense tensor of {@code type} from {@code values} using a dense bulk builder.
     *
     * @throws IllegalArgumentException if the number of values is incompatible with the type's cardinality
     */
    public static TensorData asDenseTensor(TensorType type, List<?> values) {
        BulkTensorBuilder builder = getBulkBuilder(type, values.size());
        return builder.build(values);
    }

    /**
     * Builds a one-dimensional sparse tensor with one entry per element of {@code dimensionValues}.
     *
     * <p>Each element is written at builder position 0 and the constant {@code 1} at position 1,
     * in the set's iteration order. Presumably position 0 is the dimension and position 1 the value
     * (TODO: confirm against {@link TensorBuilder}).
     *
     * @param type a tensor type with exactly one dimension
     * @param dimensionValues the elements to mark as present
     * @return the built tensor
     * @throws IllegalArgumentException if {@code type} does not have exactly one dimension
     */
    public static TensorData asSparseTensor(TensorType type, Set<?> dimensionValues) {
        if (type.getDimensionTypes().size() != 1) {
            throw new IllegalArgumentException("Only one-dimensional tensors can represent sets.");
        }
        TensorBuilder<?> builder = UniversalTensorBuilderFactory.INSTANCE.getTensorBuilder(type);
        builder.start(dimensionValues.size());
        for (Object value : dimensionValues) {
            builder.setValue(0, value);
            builder.setValue(1, 1);
            builder.append();
        }
        return builder.build();
    }

    /**
     * Builds a one-dimensional sparse tensor with one entry per mapping in {@code values}.
     *
     * <p>Each key is written at builder position 0 and its value at position 1, in the map's
     * entry-set iteration order. Presumably position 0 is the dimension and position 1 the value
     * (TODO: confirm against {@link TensorBuilder}).
     *
     * @param type a tensor type with exactly one dimension
     * @param values the key/value pairs to store
     * @return the built tensor
     * @throws IllegalArgumentException if {@code type} does not have exactly one dimension
     */
    public static TensorData asSparseTensor(TensorType type, Map<?, ?> values) {
        if (type.getDimensionTypes().size() != 1) {
            throw new IllegalArgumentException("Only one-dimensional tensors can represent maps.");
        }
        TensorBuilder<?> builder = UniversalTensorBuilderFactory.INSTANCE.getTensorBuilder(type);
        builder.start(values.size());
        for (Map.Entry<?, ?> entry : values.entrySet()) {
            builder.setValue(0, entry.getKey());
            builder.setValue(1, entry.getValue());
            builder.append();
        }
        return builder.build();
    }
}

