/*
 * Decompiled with CFR 0.152.
 */
package com.linkedin.feathr.common.tensor;

import com.google.common.collect.Lists;
import com.linkedin.feathr.common.tensor.DenseTensor;
import com.linkedin.feathr.common.tensor.DimensionType;
import com.linkedin.feathr.common.tensor.Primitive;
import com.linkedin.feathr.common.tensor.PrimitiveDimensionType;
import com.linkedin.feathr.common.tensor.Representable;
import com.linkedin.feathr.common.tensor.TensorCategory;
import com.linkedin.feathr.common.tensor.TensorData;
import com.linkedin.feathr.common.tensor.TensorType;
import com.linkedin.feathr.common.types.PrimitiveType;
import java.util.ArrayList;
import java.util.List;
import java.util.Objects;
import java.util.regex.Matcher;
import java.util.regex.Pattern;

/**
 * Static helpers for building {@link TensorType} instances, either by parsing the textual
 * tensor-type syntax or by deriving a type from existing {@link TensorData}.
 *
 * <p>The accepted syntax is {@code TENSOR<CATEGORY>DIMENSIONS:VALUE}, where {@code CATEGORY} is one
 * of {@code SPARSE}, {@code DENSE} or {@code RAGGED}; {@code DIMENSIONS} is zero or more
 * {@code [BASE]} or {@code [BASE(shape)]} entries with {@code BASE} one of {@code INT},
 * {@code LONG}, {@code STRING} or {@code BYTES}; and {@code VALUE} is one of {@code INT},
 * {@code LONG}, {@code FLOAT}, {@code DOUBLE}, {@code STRING}, {@code BOOLEAN} or {@code BYTES}.
 * Example: {@code TENSOR<SPARSE>[STRING][INT(10)]:FLOAT}.
 */
public final class TensorTypes {
    /**
     * Primitive names accepted as dimension bases. Both patterns below are built from this single
     * list so that every dimension the whole-type pattern accepts is also captured when the
     * dimensions are parsed (previously {@code BYTES} was accepted but then silently dropped).
     */
    private static final String DIMENSION_PRIMITIVES = "INT|LONG|STRING|BYTES";

    /** Captures one {@code [BASE]} or {@code [BASE(shape)]} dimension entry. */
    private static final Pattern CAPTURING_DIMENSION_TYPE =
            Pattern.compile("\\[(?<base>" + DIMENSION_PRIMITIVES + ")(?:\\((?<shape>\\d+)\\))?]");

    /** Matches a complete tensor-type string; the dimensions group is re-parsed separately. */
    private static final Pattern TENSOR_TYPE = Pattern.compile(
            "TENSOR<(?<category>SPARSE|DENSE|RAGGED)>"
                    + "(?<dimensions>(?:\\[(?:" + DIMENSION_PRIMITIVES + ")(?:\\(\\d+\\))?])*)"
                    + ":(?<value>INT|LONG|FLOAT|DOUBLE|STRING|BOOLEAN|BYTES)");

    private TensorTypes() {
        // Static utility class; not instantiable.
    }

    /**
     * Parses a tensor type from its textual syntax.
     *
     * @param syntax the tensor-type string, e.g. {@code TENSOR<DENSE>[INT(3)]:FLOAT}; the entire
     *     string must match the grammar (no leading or trailing characters)
     * @return the parsed tensor type, with one dimension per {@code [...]} entry, in order
     * @throws NullPointerException if {@code syntax} is {@code null}
     * @throws IllegalArgumentException if {@code syntax} does not match the grammar, or a shape is
     *     too large to fit in an {@code int} ({@link NumberFormatException})
     */
    public static TensorType parseTensorType(String syntax) {
        Objects.requireNonNull(syntax, "syntax");
        Matcher matcher = TENSOR_TYPE.matcher(syntax);
        if (!matcher.matches()) {
            throw new IllegalArgumentException("Not a valid tensor type: " + syntax);
        }
        TensorCategory category = TensorCategory.valueOf(matcher.group("category"));
        PrimitiveType valueType = new PrimitiveType(Primitive.valueOf(matcher.group("value")));
        List<DimensionType> dimensions = parseDimensions(matcher.group("dimensions"));
        return new TensorType(category, valueType, dimensions, null);
    }

    /**
     * Builds a tensor type from an array of {@link Representable}s.
     *
     * <p>The last element supplies the value type; every preceding element supplies one dimension,
     * in array order.
     *
     * @param sparse {@code true} for {@link TensorCategory#SPARSE}, {@code false} for
     *     {@link TensorCategory#DENSE}
     * @param representables the dimension representables followed by the value representable; must
     *     contain at least one element
     * @return the corresponding tensor type
     * @throws NullPointerException if {@code representables} is {@code null}
     * @throws IllegalArgumentException if {@code representables} is empty
     */
    public static TensorType fromRepresentables(boolean sparse, Representable[] representables) {
        Objects.requireNonNull(representables, "representables");
        if (representables.length == 0) {
            // Fail with a clear message instead of an index exception on the value lookup below.
            throw new IllegalArgumentException("representables must contain at least the value type");
        }
        int valueIndex = representables.length - 1;
        PrimitiveType valueType = new PrimitiveType(representables[valueIndex].getRepresentation());
        List<DimensionType> dimensionTypes = new ArrayList<>(valueIndex);
        for (int i = 0; i < valueIndex; i++) {
            dimensionTypes.add(new PrimitiveDimensionType(representables[i].getRepresentation()));
        }
        TensorCategory category = sparse ? TensorCategory.SPARSE : TensorCategory.DENSE;
        return new TensorType(category, valueType, dimensionTypes);
    }

    /**
     * Derives the tensor type of existing tensor data.
     *
     * <p>A {@link DenseTensor} yields a {@code DENSE} type; any other {@link TensorData}
     * implementation yields {@code SPARSE}. The array from {@link TensorData#getTypes()} is passed
     * to {@link #fromRepresentables(boolean, Representable[])}, so its last element is treated as
     * the value type.
     *
     * @param tensorData the data whose type to derive
     * @return the derived tensor type
     * @throws NullPointerException if {@code tensorData} is {@code null}
     */
    public static TensorType fromTensorData(TensorData tensorData) {
        Objects.requireNonNull(tensorData, "tensorData");
        boolean sparse = !(tensorData instanceof DenseTensor);
        return fromRepresentables(sparse, tensorData.getTypes());
    }

    /**
     * Parses the dimensions section of a tensor-type string.
     *
     * @param syntax zero or more {@code [BASE]} / {@code [BASE(shape)]} entries, or {@code null}
     * @return a new mutable list with one dimension type per entry, in order; empty when
     *     {@code syntax} is {@code null} or empty
     */
    private static List<DimensionType> parseDimensions(String syntax) {
        List<DimensionType> dimensions = new ArrayList<>();
        if (syntax == null) {
            return dimensions;
        }
        Matcher matcher = CAPTURING_DIMENSION_TYPE.matcher(syntax);
        while (matcher.find()) {
            Primitive base = Primitive.valueOf(matcher.group("base"));
            PrimitiveDimensionType dimensionType = new PrimitiveDimensionType(base);
            String shape = matcher.group("shape");
            // An explicit "(n)" suffix fixes the dimension's shape; otherwise the base type is used as-is.
            dimensions.add(shape == null ? dimensionType : dimensionType.withShape(Integer.parseInt(shape)));
        }
        return dimensions;
    }
}

