/*
 * Decompiled with CFR 0.152.
 */
package com.linkedin.feathr.common;

import com.google.common.annotations.VisibleForTesting;
import com.linkedin.feathr.common.FeatureTypeConfig;
import com.linkedin.feathr.common.FeatureTypes;
import com.linkedin.feathr.common.tensor.DimensionType;
import com.linkedin.feathr.common.tensor.Primitive;
import com.linkedin.feathr.common.tensor.PrimitiveDimensionType;
import com.linkedin.feathr.common.tensor.TensorCategory;
import com.linkedin.feathr.common.tensor.TensorType;
import com.linkedin.feathr.common.types.PrimitiveType;
import com.linkedin.feathr.common.types.ValueType;
import com.linkedin.feathr.compute.Dimension;
import com.linkedin.feathr.compute.FeatureVersion;
import com.linkedin.feathr.compute.TensorFeatureFormat;
import java.util.List;
import java.util.Objects;
import java.util.Optional;
import java.util.stream.Collectors;

/**
 * Converts Pegasus feature metadata ({@link FeatureVersion}, {@link TensorFeatureFormat},
 * {@link Dimension} and the compute-model value-type enum) into the common
 * {@link FeatureTypeConfig} / {@link TensorType} model.
 *
 * <p>The resolver has no instance fields; a single shared instance is exposed via
 * {@link #getInstance()}.
 */
public final class PegasusFeatureTypeResolver {
    private static final PegasusFeatureTypeResolver INSTANCE = new PegasusFeatureTypeResolver();

    /** Documentation text attached to every tensor-typed {@link FeatureTypeConfig} built here. */
    private static final String NO_DOCUMENTATION = "No documentation";

    /**
     * Returns the shared resolver instance.
     *
     * @return the singleton {@code PegasusFeatureTypeResolver}
     */
    public static PegasusFeatureTypeResolver getInstance() {
        return INSTANCE;
    }

    private PegasusFeatureTypeResolver() {
    }

    /**
     * Resolves the feature type configuration declared by a Pegasus {@link FeatureVersion}.
     *
     * <p>The Pegasus type is mapped to {@link FeatureTypes} by enum constant name. A {@link TensorType}
     * is derived and attached only when the resolved type is {@link FeatureTypes#TENSOR} and the version
     * carries a format; in every other case the config holds just the feature type.
     *
     * @param featureVersion the Pegasus feature version to inspect
     * @return the resolved feature type configuration
     * @throws IllegalArgumentException if an enum name cannot be mapped (feature type, tensor category),
     *     or if the format uses an unsupported dimension or value type
     */
    public FeatureTypeConfig resolveFeatureType(FeatureVersion featureVersion) {
        FeatureTypes featureType = FeatureTypes.valueOf(featureVersion.getType().name());
        if (featureType == FeatureTypes.TENSOR && featureVersion.hasFormat()) {
            TensorType tensorType = fromFeatureFormat(featureVersion.getFormat());
            return new FeatureTypeConfig(featureType, tensorType, NO_DOCUMENTATION);
        }
        return new FeatureTypeConfig(featureType);
    }

    /**
     * Returns the embedding size implied by a feature version's tensor format.
     *
     * <p>A size is reported only when the feature type is {@link FeatureTypes#UNSPECIFIED},
     * {@link FeatureTypes#DENSE_VECTOR} or {@link FeatureTypes#TENSOR}, a format is present, and the
     * shape of the derived {@link TensorType} has exactly one entry; that entry is returned.
     *
     * @param featureVersion the Pegasus feature version to inspect
     * @return the single shape entry, or {@link Optional#empty()} when the conditions above do not hold
     * @deprecated NOTE(review): deprecated without a documented replacement in this file; confirm the
     *     intended migration path before adding new callers.
     */
    @Deprecated
    public Optional<Integer> resolveEmbeddingSize(FeatureVersion featureVersion) {
        FeatureTypes featureType = FeatureTypes.valueOf(featureVersion.getType().name());
        boolean mayCarryEmbedding = featureType == FeatureTypes.UNSPECIFIED
                || featureType == FeatureTypes.DENSE_VECTOR
                || featureType == FeatureTypes.TENSOR;
        if (!mayCarryEmbedding || !featureVersion.hasFormat()) {
            return Optional.empty();
        }
        int[] shape = fromFeatureFormat(featureVersion.getFormat()).getShape();
        if (shape.length != 1) {
            return Optional.empty();
        }
        return Optional.of(shape[0]);
    }

    /**
     * Builds a {@link TensorType} from a Pegasus tensor format: the value type, the tensor category
     * (mapped by enum constant name) and the dimensions in their declared order.
     */
    private TensorType fromFeatureFormat(TensorFeatureFormat featureFormat) {
        ValueType valType = fromValueTypeEnum(featureFormat.getValueType());
        TensorCategory tensorCategory = TensorCategory.valueOf(featureFormat.getTensorCategory().name());
        List<DimensionType> dimensionTypes = featureFormat.getDimensions().stream()
                .map(this::fromDimension)
                .collect(Collectors.toList());
        // NOTE(review): the fourth TensorType argument is always null here; its meaning is defined by
        // TensorType's constructor, which is not visible in this file.
        return new TensorType(tensorCategory, valType, dimensionTypes, null);
    }

    /**
     * Maps a Pegasus {@link Dimension} to a {@link PrimitiveDimensionType}.
     *
     * <p>LONG, INT and STRING dimensions are supported. When the dimension declares a shape, a new
     * dimension type is constructed with that shape; otherwise the shared unshaped constant for the
     * primitive (e.g. {@link PrimitiveDimensionType#LONG}) is returned.
     *
     * @param pegasusDimension the Pegasus dimension to convert
     * @return the equivalent dimension type
     * @throws NullPointerException if the dimension's type is null
     * @throws IllegalArgumentException if the dimension's type is not LONG, INT or STRING
     */
    @VisibleForTesting
    DimensionType fromDimension(Dimension pegasusDimension) {
        Integer shape = pegasusDimension.getShape();
        switch (Objects.requireNonNull(pegasusDimension.getType(), "Pegasus dimension type must not be null")) {
            case LONG:
                return primitiveDimension(Primitive.LONG, PrimitiveDimensionType.LONG, shape);
            case INT:
                return primitiveDimension(Primitive.INT, PrimitiveDimensionType.INT, shape);
            case STRING:
                return primitiveDimension(Primitive.STRING, PrimitiveDimensionType.STRING, shape);
            default:
                break;
        }
        throw new IllegalArgumentException("Unsupported dimension types from pegasus model: " + pegasusDimension.getType());
    }

    /** Returns a shaped dimension for {@code primitive} when {@code shape} is set, else {@code unshaped}. */
    private static DimensionType primitiveDimension(Primitive primitive, DimensionType unshaped, Integer shape) {
        return shape != null ? new PrimitiveDimensionType(primitive, shape) : unshaped;
    }

    /**
     * Maps a Pegasus compute-model value type to the corresponding {@link PrimitiveType}.
     *
     * @param pegasusValType the Pegasus value type; must not be null
     * @return the matching primitive value type
     * @throws NullPointerException if {@code pegasusValType} is null
     * @throws IllegalArgumentException if the value type is not INT, LONG, FLOAT, DOUBLE, STRING or BOOLEAN
     */
    @VisibleForTesting
    ValueType fromValueTypeEnum(com.linkedin.feathr.compute.ValueType pegasusValType) {
        // Fail with a descriptive message instead of the bare NPE a switch on a null enum would raise.
        Objects.requireNonNull(pegasusValType, "Pegasus value type must not be null");
        switch (pegasusValType) {
            case INT:
                return PrimitiveType.INT;
            case LONG:
                return PrimitiveType.LONG;
            case FLOAT:
                return PrimitiveType.FLOAT;
            case DOUBLE:
                return PrimitiveType.DOUBLE;
            case STRING:
                return PrimitiveType.STRING;
            case BOOLEAN:
                return PrimitiveType.BOOLEAN;
            default:
                break;
        }
        throw new IllegalArgumentException("Unsupported value type from the pegasus model: " + pegasusValType);
    }
}

