/*
 * Decompiled with CFR 0.152.
 */
package com.xxdb.data;

import com.xxdb.data.AbstractTensor;
import com.xxdb.data.BasicEntityFactory;
import com.xxdb.data.Entity;
import com.xxdb.data.Vector;
import com.xxdb.io.ExtendedDataInput;
import com.xxdb.io.ExtendedDataOutput;
import java.io.IOException;

/**
 * A read-only n-dimensional tensor deserialized from an {@link ExtendedDataInput}.
 *
 * <p>The header (tensor type, device type, flags, dimension count, per-dimension shapes and
 * strides, a reserved value and the element count) is read first, followed by {@code elemCount}
 * elements that are stored in a single backing {@link Vector}. Serialization is not supported:
 * {@link #write(ExtendedDataOutput)} always throws.
 */
public class BasicTensor extends AbstractTensor {
    /** Number of elements printed along the innermost dimension before {@code "..."} is appended. */
    private static final int MAX_PRINTED_INNER_ELEMENTS = 11;

    private Entity.DATA_TYPE dataType;
    // The next three header fields are read from the stream but not otherwise used in this class.
    private int tensorType;
    private int deviceType;
    private int tensorFlags;
    private int dimensions;
    private long[] shapes;
    private long[] strides;
    // Read from the stream; not otherwise used in this class.
    private long preserveValue;
    private long elemCount;
    private Vector data;

    /**
     * Creates a tensor by reading its header and elements from {@code in}.
     *
     * @param dataType element type of the tensor
     * @param in       stream positioned at the tensor header
     * @throws IOException if reading from {@code in} fails
     */
    protected BasicTensor(Entity.DATA_TYPE dataType, ExtendedDataInput in) throws IOException {
        this.deserialize(dataType, in);
    }

    /**
     * Reads the tensor header and its elements from {@code in}, replacing this tensor's state.
     *
     * @param dataType element type of the tensor
     * @param in       stream positioned at the tensor header
     * @throws IOException      if reading from {@code in} fails
     * @throws RuntimeException if the header declares a negative dimension count, a negative
     *                          element count, or more than {@link Integer#MAX_VALUE} elements
     */
    protected void deserialize(Entity.DATA_TYPE dataType, ExtendedDataInput in) throws IOException {
        this.dataType = dataType;
        this.tensorType = in.readByte();
        this.deviceType = in.readByte();
        this.tensorFlags = in.readInt();
        this.dimensions = in.readInt();
        if (this.dimensions < 0) {
            // Fail with a descriptive message instead of a NegativeArraySizeException below.
            throw new RuntimeException("Invalid tensor dimension count: " + this.dimensions);
        }
        // Wire order: all shapes first, then all strides.
        this.shapes = readLongs(in, this.dimensions);
        this.strides = readLongs(in, this.dimensions);
        this.preserveValue = in.readLong();
        this.elemCount = in.readLong();
        if (this.elemCount > Integer.MAX_VALUE) {
            throw new RuntimeException("tensor element count more than 2,147,483,647(Integer.MAX_VALUE).");
        }
        if (this.elemCount < 0) {
            throw new RuntimeException("Invalid tensor element count: " + this.elemCount);
        }
        // Safe narrowing: 0 <= elemCount <= Integer.MAX_VALUE was checked above.
        int count = (int) this.elemCount;
        Vector subVector = BasicEntityFactory.instance().createVectorWithDefaultValue(dataType, count, -1);
        subVector.deserialize(0, count, in);
        this.data = subVector;
    }

    /** Reads {@code length} consecutive longs from {@code in} into a new array. */
    private static long[] readLongs(ExtendedDataInput in, int length) throws IOException {
        long[] values = new long[length];
        for (int i = 0; i < length; ++i) {
            values[i] = in.readLong();
        }
        return values;
    }

    @Override
    public Entity.DATA_CATEGORY getDataCategory() {
        return this.getDataCategory(this.dataType);
    }

    @Override
    public Entity.DATA_TYPE getDataType() {
        return this.dataType;
    }

    /** Returns the row count reported by the backing element vector. */
    @Override
    public int rows() {
        return this.data.rows();
    }

    /**
     * Not supported for tensors.
     *
     * @throws UnsupportedOperationException always
     */
    @Override
    public void write(ExtendedDataOutput output) throws IOException {
        throw new UnsupportedOperationException("BasicTensor not support write method.");
    }

    /** Returns the number of dimensions read from the tensor header. */
    public int getDimensions() {
        return this.dimensions;
    }

    /** Returns the per-dimension sizes. The internal array is returned, not a copy. */
    public long[] getShapes() {
        return this.shapes;
    }

    /** Returns the per-dimension strides as read from the stream. The internal array is returned, not a copy. */
    public long[] getStrides() {
        return this.strides;
    }

    /** Returns the number of elements read into the backing vector. */
    public long getElemCount() {
        return this.elemCount;
    }

    /** Returns the backing vector holding all tensor elements. */
    public Vector getData() {
        return this.data;
    }

    /**
     * Renders the tensor as {@code tensor<type[d0][d1]...>(nested elements)}.
     *
     * <p>Elements are visited in row-major order over {@link #getShapes()}. Only the innermost
     * dimension is truncated: after {@value #MAX_PRINTED_INNER_ELEMENTS} elements,
     * {@code "..."} is appended with no separating comma.
     *
     * @throws IllegalArgumentException if the element type has no tensor type name
     */
    @Override
    public String getString() {
        StringBuilder sb = new StringBuilder();
        sb.append("tensor<").append(this.getDataTypeString());
        for (long shape : this.shapes) {
            sb.append("[").append(shape).append("]");
        }
        sb.append(">(");
        this.printTensor(sb, 0, 0);
        sb.append(")");
        return sb.toString();
    }

    /**
     * Appends the sub-tensor at {@code depth} to {@code sb}.
     *
     * @param sb        output buffer
     * @param depth     current dimension being expanded; equal to {@code dimensions} at a leaf
     * @param flatIndex row-major offset of the sub-tensor's first element. At a leaf this is
     *                  exactly the element's index in {@link #data}, so no separate index
     *                  computation is needed.
     */
    private void printTensor(StringBuilder sb, int depth, int flatIndex) {
        if (depth == this.dimensions) {
            sb.append(this.data.get(flatIndex));
            return;
        }
        long size = this.shapes[depth];
        boolean innermost = depth == this.dimensions - 1;
        sb.append("[");
        for (int i = 0; i < size; ++i) {
            // Reaching this index implies size > MAX_PRINTED_INNER_ELEMENTS.
            if (innermost && i == MAX_PRINTED_INNER_ELEMENTS) {
                sb.append("...");
                break;
            }
            if (i > 0) {
                sb.append(",");
            }
            this.printTensor(sb, depth + 1, flatIndex * (int) size + i);
        }
        sb.append("]");
    }

    /**
     * Maps the element type to the name used in {@link #getString()}.
     *
     * @throws IllegalArgumentException for any type other than bool, byte, short, int, long,
     *                                  float or double
     */
    private String getDataTypeString() {
        switch (this.dataType) {
            case DT_BOOL:
                return "bool";
            case DT_BYTE:
                return "char";
            case DT_SHORT:
                return "short";
            case DT_INT:
                return "int";
            case DT_LONG:
                return "long";
            case DT_FLOAT:
                return "float";
            case DT_DOUBLE:
                return "double";
            default:
                throw new IllegalArgumentException("Unsupported data type: " + this.dataType);
        }
    }
}

