/*
 * Decompiled with CFR 0.152.
 */
package com.tencent.angel.graph.client.getNodeAttrs;

import com.tencent.angel.graph.client.getNodeAttrs.GetNodeAttrsParam;
import com.tencent.angel.graph.client.getNodeAttrs.GetNodeAttrsResult;
import com.tencent.angel.graph.client.getNodeAttrs.PartGetNodeAttrsParam;
import com.tencent.angel.graph.client.getNodeAttrs.PartGetNodeAttrsResult;
import com.tencent.angel.ml.matrix.psf.get.base.GetFunc;
import com.tencent.angel.ml.matrix.psf.get.base.GetParam;
import com.tencent.angel.ml.matrix.psf.get.base.GetResult;
import com.tencent.angel.ml.matrix.psf.get.base.PartitionGetParam;
import com.tencent.angel.ml.matrix.psf.get.base.PartitionGetResult;
import com.tencent.angel.ps.storage.matrix.ServerMatrix;
import com.tencent.angel.ps.storage.partition.RowBasedPartition;
import com.tencent.angel.ps.storage.partition.ServerPartition;
import com.tencent.angel.ps.storage.vector.ServerLongAnyRow;
import com.tencent.angel.ps.storage.vector.element.FloatArrayElement;
import it.unimi.dsi.fastutil.ints.Int2ObjectArrayMap;
import it.unimi.dsi.fastutil.longs.Long2ObjectOpenHashMap;
import java.util.List;
import java.util.Random;

/**
 * Get function that looks up float-array attributes for a batch of node ids.
 *
 * <p>{@link #partitionGet(PartitionGetParam)} reads the attributes stored in row 0 of one
 * partition, and {@link #merge(List)} stitches the per-partition answers back into a single
 * node-id-to-attributes map.
 */
public class GetNodeAttrs
extends GetFunc {
    /**
     * Creates the function for the given request.
     *
     * @param param request parameter; it is read back in {@link #merge(List)} through
     *              {@code getParam()}
     */
    public GetNodeAttrs(GetNodeAttrsParam param) {
        super((GetParam)param);
    }

    /**
     * Creates the function without a request parameter.
     * NOTE(review): presumably needed for reflective instantiation on the server side - confirm.
     */
    public GetNodeAttrs() {
        this(null);
    }

    /**
     * Fetches the attributes of every requested node id from one partition.
     *
     * <p>For each node id the result slot is:
     * <ul>
     *   <li>{@code null} when the row has no element for the id, or the element's data is
     *       {@code null} or empty;</li>
     *   <li>the stored array itself (not a copy) when {@code count <= 0} or the array holds
     *       at most {@code count} values;</li>
     *   <li>otherwise a new array of {@code count} consecutive values that starts at a random
     *       position and wraps around to the beginning of the stored array.</li>
     * </ul>
     *
     * @param partParam partition request; must be a {@link PartGetNodeAttrsParam}
     * @return a {@link PartGetNodeAttrsResult} whose attribute array is index-aligned with
     *         the node ids of {@code partParam}
     */
    public PartitionGetResult partitionGet(PartitionGetParam partParam) {
        PartGetNodeAttrsParam param = (PartGetNodeAttrsParam)partParam;
        ServerMatrix matrix = this.psContext.getMatrixStorageManager().getMatrix(partParam.getMatrixId());
        ServerPartition part = matrix.getPartition(partParam.getPartKey().getPartitionId());
        ServerLongAnyRow row = (ServerLongAnyRow)((RowBasedPartition)part).getRow(0);
        long[] nodeIds = param.getNodeIds();
        float[][] attrs = new float[nodeIds.length][];
        int count = param.getCount();
        Random random = new Random();
        for (int i = 0; i < nodeIds.length; ++i) {
            FloatArrayElement element = (FloatArrayElement)row.get(nodeIds[i]);
            float[] nodeAttrs = element == null ? null : element.getData();
            if (nodeAttrs == null || nodeAttrs.length == 0) {
                // Unknown node or node without attributes: leave the slot null.
                attrs[i] = null;
            } else if (count <= 0 || nodeAttrs.length <= count) {
                // No sampling requested, or nothing to cut: hand back the stored array as is.
                attrs[i] = nodeAttrs;
            } else {
                attrs[i] = GetNodeAttrs.sampleWindow(nodeAttrs, count, random);
            }
        }
        return new PartGetNodeAttrsResult(part.getPartitionKey().getPartitionId(), attrs);
    }

    /**
     * Copies {@code count} consecutive values out of {@code nodeAttrs}, starting at a random
     * position and wrapping around to index 0 when the end of the array is reached.
     *
     * @param nodeAttrs source values; must hold more than {@code count} elements
     * @param count     number of values to copy; must be positive
     * @param random    source of the start position
     * @return a new array of length {@code count}
     */
    private static float[] sampleWindow(float[] nodeAttrs, int count, Random random) {
        float[] sampled = new float[count];
        // Bounded nextInt always yields a value in [0, length). The previous
        // Math.abs(nextInt()) % length went negative for Integer.MIN_VALUE and
        // made System.arraycopy throw.
        int startPos = random.nextInt(nodeAttrs.length);
        int tailLen = Math.min(count, nodeAttrs.length - startPos);
        System.arraycopy(nodeAttrs, startPos, sampled, 0, tailLen);
        if (tailLen < count) {
            // Window ran past the end: take the remainder from the front.
            System.arraycopy(nodeAttrs, 0, sampled, tailLen, count - tailLen);
        }
        return sampled;
    }

    /**
     * Combines the per-partition results into one node-id-to-attributes map.
     *
     * <p>Each partition parameter covers the slice {@code [startIndex, endIndex)} of the
     * request's node id array; the matching partition result is looked up by partition id and
     * its attribute arrays are assigned to the node ids of that slice in order.
     *
     * @param partResults results returned by the partitions; each must be a
     *                    {@link PartGetNodeAttrsResult}
     * @return a {@link GetNodeAttrsResult} wrapping the merged map
     */
    public GetResult merge(List<PartitionGetResult> partResults) {
        Int2ObjectArrayMap<PartGetNodeAttrsResult> partIdToResultMap =
            new Int2ObjectArrayMap<PartGetNodeAttrsResult>(partResults.size());
        for (PartitionGetResult result : partResults) {
            PartGetNodeAttrsResult attrsResult = (PartGetNodeAttrsResult)result;
            partIdToResultMap.put(attrsResult.getPartId(), attrsResult);
        }
        GetNodeAttrsParam param = (GetNodeAttrsParam)this.getParam();
        long[] nodeIds = param.getNodeIds();
        List<PartitionGetParam> partParams = param.getPartParams();
        Long2ObjectOpenHashMap<float[]> nodeIdToAttrs = new Long2ObjectOpenHashMap<float[]>(nodeIds.length);
        for (PartitionGetParam partParam : partParams) {
            PartGetNodeAttrsParam attrsParam = (PartGetNodeAttrsParam)partParam;
            int start = attrsParam.getStartIndex();
            int end = attrsParam.getEndIndex();
            PartGetNodeAttrsResult partResult = partIdToResultMap.get(partParam.getPartKey().getPartitionId());
            float[][] results = partResult.getNodeIdToAttrs();
            for (int i = start; i < end; ++i) {
                // results is indexed relative to this partition's slice of nodeIds.
                nodeIdToAttrs.put(nodeIds[i], results[i - start]);
            }
        }
        return new GetNodeAttrsResult(nodeIdToAttrs);
    }
}

