/*
 * Decompiled with CFR 0.152.
 */
package com.linkedin.feathr.common;

import com.google.common.collect.Maps;
import com.linkedin.feathr.common.tensor.DimensionType;
import com.linkedin.feathr.common.tensor.LOLTensorData;
import com.linkedin.feathr.common.tensor.Primitive;
import com.linkedin.feathr.common.tensor.ReadableTuple;
import com.linkedin.feathr.common.tensor.Representable;
import com.linkedin.feathr.common.tensor.SimpleWriteableTuple;
import com.linkedin.feathr.common.tensor.StandaloneReadableTuple;
import com.linkedin.feathr.common.tensor.TensorData;
import com.linkedin.feathr.common.tensor.TensorIterator;
import com.linkedin.feathr.common.tensor.TensorType;
import com.linkedin.feathr.common.tensor.TensorTypes;
import com.linkedin.feathr.common.tensor.WriteableTuple;
import com.linkedin.feathr.common.tensorbuilder.TensorBuilder;
import com.linkedin.feathr.common.tensorbuilder.TensorBuilderFactory;
import com.linkedin.feathr.common.tensorbuilder.UniversalTensorBuilderFactory;
import java.util.ArrayList;
import java.util.HashMap;
import java.util.List;
import java.util.Map;
import java.util.function.Function;

/**
 * Static helpers for converting tensors to and from maps, strings and
 * list-of-lists form, plus a few small shape/arithmetic utilities.
 *
 * <p>The class is not instantiable; every member is static.
 */
public final class TensorUtils {
    /** Suggested upper bound for the length of strings produced by {@link #getDebugString}. */
    public static final int DEFAULT_MAX_STRING_LEN = 10240;
    private static final String SEPARATOR = ",";
    private static final String NEXT_LINE = "\n";
    private static final String VALUE_DIM_NAME = "Value";
    private static final String EXCEED_MAX_LIMIT = "...";

    private TensorUtils() {
    }

    /**
     * Renders a tensor as comma-separated text.
     *
     * <p>The first line lists the dimension names followed by {@code "Value"}; each
     * subsequent line holds one row of {@code data}. Before each row is written the
     * accumulated length is compared with {@code maxStringLenLimit}; once the limit is
     * reached, {@code "..."} is appended and no further rows are written.
     *
     * @param type tensor type supplying dimension names and column converters
     * @param data tensor whose rows are rendered
     * @param maxStringLenLimit length threshold after which output is truncated
     * @return the debug string
     */
    public static String getDebugString(TensorType type, TensorData data, int maxStringLenLimit) {
        StringBuilder debugStringBuilder = new StringBuilder();
        for (String dimensionName : type.getDimensionNames()) {
            debugStringBuilder.append(dimensionName).append(SEPARATOR);
        }
        debugStringBuilder.append(VALUE_DIM_NAME);

        // One column per dimension plus the trailing value column. Only the count is
        // needed, so no list of column types (and no cast of the value type) is built.
        int numCols = type.getDimensionTypes().size() + 1;

        TensorIterator dataIterator = data.iterator();
        dataIterator.start();
        while (dataIterator.isValid()) {
            debugStringBuilder.append(NEXT_LINE);
            if (debugStringBuilder.length() >= maxStringLenLimit) {
                debugStringBuilder.append(EXCEED_MAX_LIMIT);
                break;
            }
            String[] strings = convertToStrings(type, dataIterator, numCols);
            debugStringBuilder.append(String.join(SEPARATOR, strings));
            dataIterator.next();
        }
        return debugStringBuilder.toString();
    }

    /**
     * Converts a nested map to a tensor using {@link UniversalTensorBuilderFactory#INSTANCE}.
     *
     * @param map nested map; see the three-argument overload for the expected layout
     * @param tensorType type of the tensor to build
     * @return the built tensor
     */
    public static TensorData convertNestedMapToTensor(Map<String, Object> map, TensorType tensorType) {
        return convertNestedMapToTensor(map, tensorType, UniversalTensorBuilderFactory.INSTANCE);
    }

    /**
     * Converts a nested map to a tensor.
     *
     * <p>Keys at nesting level {@code d} become the value of dimension {@code d}; the
     * non-map leaf becomes the tensor value. Each leaf produces one row.
     *
     * @param map nested map of dimension keys to either further maps or leaf values
     * @param tensorType type of the tensor to build
     * @param tensorBuilderFactory factory used to obtain the builder for {@code tensorType}
     * @return the built tensor
     * @throws IllegalArgumentException if the nesting depth of {@code map} does not match
     *         the number of columns in {@code tensorType}
     */
    public static TensorData convertNestedMapToTensor(Map<String, Object> map, TensorType tensorType, TensorBuilderFactory tensorBuilderFactory) {
        TensorBuilder<?> tensorBuilder = tensorBuilderFactory.getTensorBuilder(tensorType);
        tensorBuilder.start(map.size());
        populateTensorBuilder(map, tensorType, tensorBuilder, 0, new SimpleWriteableTuple(tensorType.getColumnTypes()));
        return tensorBuilder.build();
    }

    /**
     * Converts a tensor to a map from row tuple to float value.
     *
     * <p>Numbers are narrowed with {@link Number#floatValue()}, booleans map to
     * {@code 1.0f}/{@code 0.0f}, and strings are parsed with {@link Float#parseFloat}.
     *
     * @param tensor tensor to convert; may be {@code null}
     * @return the map, or {@code null} when {@code tensor} is {@code null}
     * @throws IllegalArgumentException if a value is a string that cannot be parsed as a
     *         float, or is not a number, boolean or string
     */
    public static Map<ReadableTuple, Float> convertTensorToMap(TensorData tensor) {
        if (tensor == null) {
            return null;
        }
        Map<ReadableTuple, Object> outputMap = convertTensorToMapWithGenericValues(tensor);
        Representable[] columnTypes = tensor.getTypes();
        if (columnTypes[columnTypes.length - 1] == Primitive.FLOAT) {
            // NOTE(review): relies on FLOAT-typed value columns yielding Float objects from
            // Primitive.toObject, so the generic map can be reused as-is - confirm.
            @SuppressWarnings("unchecked")
            Map<ReadableTuple, Float> floatView = (Map<ReadableTuple, Float>) (Map<ReadableTuple, ?>) outputMap;
            return floatView;
        }
        Map<ReadableTuple, Float> termValues = Maps.newHashMapWithExpectedSize(outputMap.size());
        for (Map.Entry<ReadableTuple, Object> entry : outputMap.entrySet()) {
            termValues.put(entry.getKey(), toFloat(entry.getValue()));
        }
        return termValues;
    }

    /** Coerces a single tensor value (number, boolean or numeric string) to a {@code Float}. */
    private static Float toFloat(Object value) {
        if (value instanceof Number) {
            return Float.valueOf(((Number) value).floatValue());
        }
        if (value instanceof Boolean) {
            return Float.valueOf(((Boolean) value).booleanValue() ? 1.0f : 0.0f);
        }
        if (value instanceof String) {
            try {
                return Float.valueOf(Float.parseFloat((String) value));
            } catch (NumberFormatException e) {
                throw new IllegalArgumentException(String.format("String value %s can not be formatted to a float", value), e);
            }
        }
        // Guard against null so the caller gets the intended exception, not an NPE.
        throw new IllegalArgumentException("Expecting Primitive value but received " + (value == null ? "null" : value.getClass()));
    }

    /**
     * Converts a tensor to a map from row tuple to the row's value, with the value
     * produced by the value column's {@link Primitive#toObject} conversion.
     *
     * @param tensor tensor to convert; may be {@code null}
     * @return the map, or {@code null} when {@code tensor} is {@code null}
     */
    public static Map<ReadableTuple, Object> convertTensorToMapWithGenericValues(TensorData tensor) {
        if (tensor == null) {
            return null;
        }
        Map<ReadableTuple, Object> outputMap = Maps.newHashMapWithExpectedSize(tensor.estimatedCardinality());
        TensorIterator iterator = tensor.iterator();
        Representable[] columnTypes = iterator.getTypes();
        // The value is stored in the last column.
        int valueDim = columnTypes.length - 1;
        Primitive valueType = columnTypes[valueDim].getRepresentation();
        iterator.start();
        while (iterator.isValid()) {
            // A standalone tuple is created per row to serve as the map key.
            StandaloneReadableTuple row = new StandaloneReadableTuple(iterator, true);
            outputMap.put(row, valueType.toObject(iterator, valueDim));
            iterator.next();
        }
        return outputMap;
    }

    /**
     * Recursively walks {@code map}, writing the key found at each nesting level into
     * {@code writeableTuple} and appending one row to {@code tensorBuilder} per leaf.
     *
     * @throws IllegalArgumentException if a map is nested deeper than the tensor type
     *         allows, or a leaf value appears at a depth other than the last dimension
     */
    @SuppressWarnings({"unchecked", "rawtypes"})
    private static void populateTensorBuilder(Map<String, Object> map, TensorType tensorType, TensorBuilder tensorBuilder, int depth, SimpleWriteableTuple writeableTuple) {
        int numColumns = tensorType.getColumnTypes().length;
        for (Map.Entry<String, Object> entry : map.entrySet()) {
            Object v = entry.getValue();
            tensorType.getDimensionTypes().get(depth).setDimensionValue(writeableTuple, depth, entry.getKey());
            if (v instanceof Map) {
                // A nested map consumes one more dimension; the last column is reserved for the value.
                if (depth + 2 >= numColumns) {
                    throw new IllegalArgumentException(String.format("Expected only %d columns, but found more", numColumns));
                }
                populateTensorBuilder((Map<String, Object>) v, tensorType, tensorBuilder, depth + 1, writeableTuple);
            } else {
                if (depth + 2 != numColumns) {
                    // v is formatted via %s (not v.toString()) so a null leaf cannot raise an NPE here.
                    throw new IllegalArgumentException(String.format("Value %s is at depth %d but tensorType suggests it should be at %d", v, depth, numColumns));
                }
                tensorType.getValueType().getRepresentation().from(v, (WriteableTuple) writeableTuple, depth + 1);
                setRow(tensorBuilder, writeableTuple);
            }
        }
    }

    /** Copies every column of {@code writeableTuple} into {@code tensorBuilder} and appends the row. */
    @SuppressWarnings({"unchecked", "rawtypes"})
    private static void setRow(TensorBuilder tensorBuilder, SimpleWriteableTuple writeableTuple) {
        Representable[] types = tensorBuilder.getTypes();
        int numColumns = writeableTuple.getTypes().length;
        for (int i = 0; i < numColumns; ++i) {
            types[i].getRepresentation().copy(writeableTuple, i, tensorBuilder, i);
        }
        tensorBuilder.append();
    }

    /**
     * Copies a tensor into list-of-lists form: one list per dimension column and one
     * list holding the values, all indexed by row.
     *
     * @param ut tensor to convert
     * @return a {@link LOLTensorData} built from the column types of {@code ut} and the copied rows
     */
    @SuppressWarnings({"unchecked", "rawtypes"})
    public static LOLTensorData convertToLOLTensor(TensorData ut) {
        Representable[] columnTypes = ut.getTypes();
        int numColumns = columnTypes.length;
        int valueDim = numColumns - 1;
        Primitive valueType = columnTypes[valueDim].getRepresentation();

        // Kept as a raw list so it stays assignable to whatever element type the
        // LOLTensorData constructor declares.
        ArrayList dimensions = new ArrayList(valueDim);
        ArrayList<Object> values = new ArrayList<Object>(ut.estimatedCardinality());
        for (int i = 0; i < valueDim; ++i) {
            dimensions.add(new ArrayList(ut.estimatedCardinality()));
        }

        TensorIterator iter = ut.iterator();
        // Position the iterator explicitly, as the other iteration sites in this class do.
        iter.start();
        while (iter.isValid()) {
            for (int i = 0; i < valueDim; ++i) {
                ((List) dimensions.get(i)).add(iter.getValue(i));
            }
            values.add(valueType.toObject(iter, valueDim));
            iter.next();
        }
        return new LOLTensorData(ut.getTypes(), dimensions, values);
    }

    /**
     * Converts the first {@code numCols} columns of a tuple to strings.
     *
     * <p>Dimension columns are converted through their {@link DimensionType}; when
     * {@code numCols} is one more than the number of dimensions, the value column is
     * converted as well and placed last.
     *
     * @param tensorType type describing the tuple's columns
     * @param readableTuple tuple to read from
     * @param numCols number of leading columns to convert
     * @return array of length {@code numCols} holding the string forms
     * @throws IllegalArgumentException if {@code numCols} exceeds the number of dimensions plus one
     */
    public static String[] convertToStrings(TensorType tensorType, ReadableTuple readableTuple, int numCols) {
        List<DimensionType> dims = tensorType.getDimensionTypes();
        int numDims = dims.size();
        if (numCols > numDims + 1) {
            throw new IllegalArgumentException("Number of columns in the output is greater than number of dims & value");
        }
        boolean convertValue = numCols > numDims;
        if (!convertValue) {
            // Fewer columns requested than dimensions available: convert only a prefix.
            numDims = numCols;
        }
        String[] ret = new String[numCols];
        for (int i = 0; i < numDims; ++i) {
            ret[i] = dims.get(i).getDimensionValue(readableTuple, i).toString();
        }
        if (convertValue) {
            ret[numDims] = tensorType.getValueType().getRepresentation().toString(readableTuple, numDims);
        }
        return ret;
    }

    /**
     * Adapts a key generator that works on string arrays so that it accepts tuples.
     *
     * @param inputTensor type of the tensor whose tuples will be passed to the result
     * @param keyGen function receiving the string forms of a tuple's dimension columns
     * @param <K> key type produced by {@code keyGen}
     * @return function that stringifies a tuple's dimensions and applies {@code keyGen}
     */
    public static <K> Function<ReadableTuple, K> wrapKeyGen(TensorType inputTensor, Function<String[], K> keyGen) {
        int numDims = inputTensor.getDimensionTypes().size();
        return readableTuple -> keyGen.apply(convertToStrings(inputTensor, readableTuple, numDims));
    }

    /**
     * Collects the shape reported by each dimension type of a tensor type.
     *
     * @param tensorType tensor type to inspect
     * @return array with one entry per dimension, in dimension order
     */
    public static long[] getShape(TensorType tensorType) {
        List<DimensionType> dimensionTypes = tensorType.getDimensionTypes();
        int numDims = dimensionTypes.size();
        long[] shape = new long[numDims];
        for (int i = 0; i < numDims; ++i) {
            shape[i] = dimensionTypes.get(i).getShape();
        }
        return shape;
    }

    /**
     * Appends every row of {@code data} to {@code tensorBuilder} and builds the tensor.
     *
     * @param columnTypes column types used to convert each cell of a row
     * @param data rows to append; each row must have exactly {@code columnTypes.length} cells
     * @param tensorBuilder builder receiving the rows
     * @return the tensor built by {@code tensorBuilder}
     * @throws IllegalArgumentException if a row's length differs from {@code columnTypes.length}
     */
    @SuppressWarnings({"unchecked", "rawtypes"})
    public static TensorData populateTensor(Representable[] columnTypes, Object[][] data, TensorBuilder tensorBuilder) {
        for (int i = 0; i < data.length; ++i) {
            Object[] row = data[i];
            // Validate once per row, before any cell is written, so that empty rows are rejected too.
            if (row.length != columnTypes.length) {
                throw new IllegalArgumentException(String.format("data[i] length should be equal to columnType length. Found data[i].length = %s and columnType.length = %s", row.length, columnTypes.length));
            }
            for (int j = 0; j < row.length; ++j) {
                columnTypes[j].getRepresentation().from(row[j], (WriteableTuple) tensorBuilder, j);
            }
            tensorBuilder.append();
        }
        return tensorBuilder.build();
    }

    /**
     * Parses a tensor type from its textual syntax.
     *
     * @param syntax textual tensor type
     * @return the parsed tensor type
     * @deprecated call {@link TensorTypes#parseTensorType(String)} directly; this method only delegates to it.
     */
    @Deprecated
    public static TensorType parseTensorType(String syntax) {
        return TensorTypes.parseTensorType(syntax);
    }

    /**
     * Divides two integers, insisting that the division is exact.
     *
     * @param numerator dividend
     * @param denominator divisor
     * @return {@code numerator / denominator}, or {@code 0} when both arguments are zero
     * @throws IllegalArgumentException if a non-zero numerator is divided by zero, or the
     *         division leaves a remainder
     */
    public static int safeRatio(int numerator, int denominator) {
        if (denominator == 0) {
            if (numerator == 0) {
                return 0;
            }
            throw new IllegalArgumentException("Dividing a non-zero " + numerator + " by zero.");
        }
        if (numerator % denominator != 0) {
            throw new IllegalArgumentException("Integer division has a non-zero remainder " + numerator + "/" + denominator + ".");
        }
        return numerator / denominator;
    }
}

