/*
 * Decompiled with CFR 0.152.
 */
package deepboof.misc;

import deepboof.Tensor;
import deepboof.misc.TensorFactory_F32;
import deepboof.misc.TensorFactory_F64;
import deepboof.tensors.Tensor_F32;
import deepboof.tensors.Tensor_F64;
import java.util.Random;

/**
 * Factory that creates tensors whose concrete type is selected at runtime from
 * a {@link Class} token.
 *
 * <p>The token is compared by identity against {@code Tensor_F64.class} and
 * {@code Tensor_F32.class}; any other value (including {@code null}) makes every
 * creation method throw {@link IllegalArgumentException}.
 *
 * @param <T> tensor type produced by this factory
 */
public class TensorFactory<T extends Tensor<T>> {
    /** Class token used to pick the implementation; not validated until a tensor is requested. */
    Class tensorType;

    /**
     * Stores the class token. No validation is performed here; an unsupported
     * token is only reported when one of the creation methods is invoked.
     *
     * @param tensorType class of the tensors to create
     */
    public TensorFactory(Class tensorType) {
        this.tensorType = tensorType;
    }

    /**
     * Constructs a new tensor of the configured type by passing {@code shape}
     * to that type's constructor.
     *
     * @param shape shape passed to the tensor constructor
     * @return the newly constructed tensor
     * @throws IllegalArgumentException if the configured tensor type is not supported
     */
    @SuppressWarnings("unchecked") // the identity check on tensorType selects the class the caller bound to T
    public T create(int ... shape) {
        if (this.tensorType == Tensor_F64.class) {
            return (T)new Tensor_F64(shape);
        }
        if (this.tensorType == Tensor_F32.class) {
            return (T)new Tensor_F32(shape);
        }
        throw this.unsupportedType();
    }

    /**
     * Same as {@link #random(Random, boolean, int...)} but with {@code minibatch}
     * inserted as the leading dimension in front of {@code shape}.
     *
     * @param rand      random number generator forwarded to the type-specific factory
     * @param subTensor forwarded unchanged to the type-specific factory
     * @param minibatch value placed at index 0 of the resulting shape
     * @param shape     remaining dimensions, copied after {@code minibatch}; not modified
     * @return tensor returned by the type-specific factory
     * @throws IllegalArgumentException if the configured tensor type is not supported
     */
    @SuppressWarnings("unchecked") // the identity check on tensorType selects the class the caller bound to T
    public T randomM(Random rand, boolean subTensor, int minibatch, int[] shape) {
        // Build {minibatch, shape[0], shape[1], ...} without touching the caller's array.
        int[] modshape = new int[shape.length + 1];
        modshape[0] = minibatch;
        System.arraycopy(shape, 0, modshape, 1, shape.length);
        if (this.tensorType == Tensor_F64.class) {
            return (T)TensorFactory_F64.random(rand, subTensor, modshape);
        }
        if (this.tensorType == Tensor_F32.class) {
            return (T)TensorFactory_F32.random(rand, subTensor, modshape);
        }
        throw this.unsupportedType();
    }

    /**
     * Delegates to {@code TensorFactory_F64.random} or {@code TensorFactory_F32.random}
     * depending on the configured tensor type.
     *
     * @param rand      random number generator forwarded to the type-specific factory
     * @param subTensor forwarded unchanged to the type-specific factory
     * @param shape     shape forwarded to the type-specific factory
     * @return tensor returned by the type-specific factory
     * @throws IllegalArgumentException if the configured tensor type is not supported
     */
    @SuppressWarnings("unchecked") // the identity check on tensorType selects the class the caller bound to T
    public T random(Random rand, boolean subTensor, int ... shape) {
        if (this.tensorType == Tensor_F64.class) {
            return (T)TensorFactory_F64.random(rand, subTensor, shape);
        }
        if (this.tensorType == Tensor_F32.class) {
            return (T)TensorFactory_F32.random(rand, subTensor, shape);
        }
        throw this.unsupportedType();
    }

    /**
     * Delegates to {@code TensorFactory_F64.randomMM} or {@code TensorFactory_F32.randomMM}
     * depending on the configured tensor type.
     *
     * @param rand      random number generator forwarded to the type-specific factory
     * @param subTensor forwarded unchanged to the type-specific factory
     * @param min       lower bound forwarded to the type-specific factory
     * @param max       upper bound forwarded to the type-specific factory
     * @param shape     shape forwarded to the type-specific factory
     * @return tensor returned by the type-specific factory
     * @throws IllegalArgumentException if the configured tensor type is not supported
     */
    @SuppressWarnings("unchecked") // the identity check on tensorType selects the class the caller bound to T
    public T random(Random rand, boolean subTensor, double min, double max, int ... shape) {
        if (this.tensorType == Tensor_F64.class) {
            return (T)TensorFactory_F64.randomMM(rand, subTensor, min, max, shape);
        }
        if (this.tensorType == Tensor_F32.class) {
            // The 32-bit factory takes float bounds, so the doubles are narrowed here.
            return (T)TensorFactory_F32.randomMM(rand, subTensor, (float)min, (float)max, shape);
        }
        throw this.unsupportedType();
    }

    /**
     * Returns the class token this factory was constructed with.
     *
     * @return the stored class token, exactly as passed to the constructor
     */
    @SuppressWarnings("unchecked") // the field is a raw Class; callers are trusted to have passed Class<T>
    public Class<T> getTensorType() {
        return this.tensorType;
    }

    /**
     * Builds the exception reported when the class token matches no supported tensor type.
     * Handles a {@code null} token explicitly so that the failure is reported as an
     * {@link IllegalArgumentException} rather than a {@link NullPointerException}
     * raised while formatting the message.
     */
    private IllegalArgumentException unsupportedType() {
        String typeName = this.tensorType == null ? "null" : this.tensorType.getSimpleName();
        return new IllegalArgumentException("Unknown/unsupported tensor type " + typeName);
    }
}

