/*
 * Decompiled with CFR 0.152.
 */
package com.linkedin.feathr.common.tensor;

import com.linkedin.feathr.common.tensor.DimensionType;
import com.linkedin.feathr.common.tensor.Representable;
import com.linkedin.feathr.common.tensor.TensorCategory;
import com.linkedin.feathr.common.tensor.WriteableTuple;
import com.linkedin.feathr.common.types.PrimitiveType;
import com.linkedin.feathr.common.types.ValueType;
import java.io.Serializable;
import java.util.ArrayList;
import java.util.Collections;
import java.util.List;
import java.util.Objects;
import java.util.stream.Collectors;

/**
 * Describes the type of a tensor: its {@link TensorCategory}, the {@link ValueType} of its values, and an
 * ordered list of dimensions, each with a {@link DimensionType} and a name.
 *
 * <p>The dimension lists are copied on construction and exposed as unmodifiable views, so later changes to the
 * caller's lists cannot affect this instance's equality, hash code, or cached column types. Equality and hash code
 * are based on the tensor category, value type, dimension names, and dimension types.
 */
public final class TensorType
implements Serializable {
    /** A {@link TensorCategory#SPARSE} tensor type with {@link PrimitiveType#FLOAT} values and no dimensions. */
    public static final TensorType EMPTY = new TensorType(PrimitiveType.FLOAT, Collections.emptyList(), Collections.emptyList());
    private final TensorCategory _tensorCategory;
    private final ValueType _valueType;
    private final List<DimensionType> _dimensionTypes;
    private final List<String> _dimensionNames;
    // Lazily computed by getColumnTypes(); derived purely from the dimension types and the value type.
    private volatile Representable[] _columnTypes = null;

    /**
     * Creates a {@link TensorCategory#SPARSE} tensor type whose dimension names are taken from
     * {@link DimensionType#getName()} of each dimension type.
     *
     * @param valueType the type of the tensor's values
     * @param dimensionTypes the types of the tensor's dimensions, in order
     */
    public TensorType(ValueType valueType, List<DimensionType> dimensionTypes) {
        this(valueType, dimensionTypes, null);
    }

    /**
     * Creates a {@link TensorCategory#SPARSE} tensor type.
     *
     * @param valueType the type of the tensor's values
     * @param dimensionTypes the types of the tensor's dimensions, in order
     * @param dimensionNames the names of the dimensions, or {@code null} to use each dimension type's name
     * @throws IllegalArgumentException if {@code dimensionNames} is non-null and its size differs from
     *         {@code dimensionTypes}
     */
    public TensorType(ValueType valueType, List<DimensionType> dimensionTypes, List<String> dimensionNames) {
        this(TensorCategory.SPARSE, valueType, dimensionTypes, dimensionNames);
    }

    /**
     * Creates a tensor type.
     *
     * @param tensorCategory the category of the tensor
     * @param valueType the type of the tensor's values
     * @param dimensionTypes the types of the tensor's dimensions, in order; copied defensively
     * @param dimensionNames the names of the dimensions, or {@code null} to use each dimension type's
     *        {@link DimensionType#getName()}; copied defensively
     * @throws NullPointerException if {@code dimensionTypes} is null
     * @throws IllegalArgumentException if {@code dimensionNames} is non-null and its size differs from
     *         {@code dimensionTypes}
     */
    public TensorType(TensorCategory tensorCategory, ValueType valueType, List<DimensionType> dimensionTypes, List<String> dimensionNames) {
        List<String> dimNames;
        if (dimensionNames == null) {
            dimNames = new ArrayList<String>(dimensionTypes.size());
            for (DimensionType dt : dimensionTypes) {
                dimNames.add(dt.getName());
            }
        } else {
            if (dimensionTypes.size() != dimensionNames.size()) {
                throw new IllegalArgumentException("The numbers of dimension types " + dimensionTypes + " and names " + dimensionNames + " have to be equal.");
            }
            // Copy so later mutation of the caller's list cannot change this instance.
            dimNames = new ArrayList<String>(dimensionNames);
        }
        this._tensorCategory = tensorCategory;
        this._valueType = valueType;
        this._dimensionTypes = Collections.unmodifiableList(new ArrayList<DimensionType>(dimensionTypes));
        this._dimensionNames = Collections.unmodifiableList(dimNames);
    }

    /**
     * Creates a tensor type whose dimension names are taken from {@link DimensionType#getName()} of each
     * dimension type.
     *
     * @param tensorCategory the category of the tensor
     * @param valueType the type of the tensor's values
     * @param dimensionTypes the types of the tensor's dimensions, in order
     */
    public TensorType(TensorCategory tensorCategory, ValueType valueType, List<DimensionType> dimensionTypes) {
        this(tensorCategory, valueType, dimensionTypes, null);
    }

    /** @return the category of this tensor type */
    public TensorCategory getTensorCategory() {
        return this._tensorCategory;
    }

    /** @return the type of this tensor's values */
    public ValueType getValueType() {
        return this._valueType;
    }

    /** @return an unmodifiable, ordered view of the dimension types */
    public List<DimensionType> getDimensionTypes() {
        return this._dimensionTypes;
    }

    /** @return an unmodifiable, ordered view of the dimension names (same length as the dimension types) */
    public List<String> getDimensionNames() {
        return this._dimensionNames;
    }

    /**
     * Returns the representation of every column: one entry per dimension, in order, followed by the
     * representation of the value type as the last entry.
     *
     * <p>The array is computed once and cached. A fresh copy is returned on every call, so callers may modify
     * it without affecting this instance or other callers.
     *
     * @return an array of length {@code getDimensionTypes().size() + 1}
     */
    public Representable[] getColumnTypes() {
        // Read the volatile field once. Concurrent first calls may each compute an equal array; that is harmless.
        Representable[] columnTypes = this._columnTypes;
        if (columnTypes == null) {
            columnTypes = new Representable[this._dimensionTypes.size() + 1];
            int i = 0;
            for (DimensionType dimensionType : this._dimensionTypes) {
                columnTypes[i++] = dimensionType.getRepresentation();
            }
            columnTypes[i] = this._valueType.getRepresentation();
            this._columnTypes = columnTypes;
        }
        return columnTypes.clone();
    }

    /**
     * Writes each dimension value into {@code target}, delegating to the matching
     * {@link DimensionType#setDimensionValue} at the same column index.
     *
     * @param target the tuple to write into
     * @param dimensions one value per dimension, in dimension order
     * @throws NullPointerException if {@code target} or {@code dimensions} is null
     * @throws IllegalArgumentException if {@code dimensions.length} differs from the number of dimensions
     */
    public void setDimensions(WriteableTuple target, Object[] dimensions) {
        Objects.requireNonNull(target);
        Objects.requireNonNull(dimensions);
        if (dimensions.length != this._dimensionTypes.size()) {
            throw new IllegalArgumentException("Wrong number of dimensions. Got " + dimensions.length + ", expected " + this._dimensionTypes.size());
        }
        for (int i = 0; i < dimensions.length; ++i) {
            this._dimensionTypes.get(i).setDimensionValue(target, i, dimensions[i]);
        }
    }

    /**
     * @return a new array holding {@link DimensionType#getShape()} of each dimension, in order
     */
    public int[] getShape() {
        int rank = this._dimensionTypes.size();
        int[] shape = new int[rank];
        for (int i = 0; i < rank; ++i) {
            shape[i] = this._dimensionTypes.get(i).getShape();
        }
        return shape;
    }

    /**
     * Computes the number of elements a dense layout of this tensor would hold: the product of all dimension
     * shapes (1 for a tensor with no dimensions).
     *
     * @return the product of the shapes, or {@code -1} if any dimension's shape is {@code -1}
     *         (presumably meaning "unknown/unbounded" — TODO confirm against DimensionType)
     * @throws ArithmeticException if the product does not fit in an {@code int}
     */
    public int getDenseSize() {
        int[] shape = this.getShape();
        // Check for the -1 marker before multiplying, so an unknown dimension yields -1 rather than an overflow.
        for (int value : shape) {
            if (value == -1) {
                return -1;
            }
        }
        int size = 1;
        for (int value : shape) {
            // Fail loudly instead of silently wrapping around for very large dense shapes.
            size = Math.multiplyExact(size, value);
        }
        return size;
    }

    @Override
    public boolean equals(Object o) {
        if (this == o) {
            return true;
        }
        if (o == null || this.getClass() != o.getClass()) {
            return false;
        }
        TensorType that = (TensorType) o;
        return Objects.equals(this._tensorCategory, that._tensorCategory)
                && Objects.equals(this._valueType, that._valueType)
                && Objects.equals(this._dimensionNames, that._dimensionNames)
                && Objects.equals(this._dimensionTypes, that._dimensionTypes);
    }

    @Override
    public int hashCode() {
        return Objects.hash(this._tensorCategory, this._valueType, this._dimensionNames, this._dimensionTypes);
    }

    /**
     * @return a string of the form {@code TENSOR<category>[dim1][dim2]...:valueType}
     */
    @Override
    public String toString() {
        String dimensions = this._dimensionTypes.stream()
                .map(dimensionType -> "[" + dimensionType.toString() + "]")
                .collect(Collectors.joining());
        return "TENSOR<" + this._tensorCategory + ">" + dimensions + ":" + this._valueType;
    }
}

