/*
 * Decompiled with CFR 0.152.
 */
package com.tencent.angel.ml.psf.columns;

import com.tencent.angel.exception.AngelException;
import com.tencent.angel.ml.math2.VFactory;
import com.tencent.angel.ml.math2.vector.CompIntDoubleVector;
import com.tencent.angel.ml.math2.vector.CompIntFloatVector;
import com.tencent.angel.ml.math2.vector.IntDoubleVector;
import com.tencent.angel.ml.math2.vector.IntFloatVector;
import com.tencent.angel.ml.math2.vector.Vector;
import com.tencent.angel.ml.matrix.psf.get.base.GetFunc;
import com.tencent.angel.ml.matrix.psf.get.base.GetParam;
import com.tencent.angel.ml.matrix.psf.get.base.GetResult;
import com.tencent.angel.ml.matrix.psf.get.base.PartitionGetParam;
import com.tencent.angel.ml.matrix.psf.get.base.PartitionGetResult;
import com.tencent.angel.ml.psf.columns.GetColsParam;
import com.tencent.angel.ml.psf.columns.GetColsResult;
import com.tencent.angel.ml.psf.columns.PartitionGetColsParam;
import com.tencent.angel.ml.psf.columns.PartitionGetColsResult;
import com.tencent.angel.ps.server.data.request.InitFunc;
import com.tencent.angel.ps.storage.partition.RowBasedPartition;
import com.tencent.angel.ps.storage.vector.ServerIntDoubleRow;
import com.tencent.angel.ps.storage.vector.ServerIntFloatRow;
import com.tencent.angel.ps.storage.vector.ServerLongDoubleRow;
import com.tencent.angel.ps.storage.vector.ServerLongFloatRow;
import com.tencent.angel.ps.storage.vector.ServerRow;
import java.util.Arrays;
import java.util.HashMap;
import java.util.List;

/**
 * Parameter-server GET function that reads a set of columns from a set of rows of a
 * row-based matrix.
 *
 * <p>On each partition, {@link #partitionGet(PartitionGetParam)} builds, for every requested
 * column, a dense vector whose entry {@code j} is that column's value in the {@code j}-th
 * requested row (row ids are sorted ascending first). {@link #merge(List)} then combines the
 * per-partition results into a single column-id to column-vector map.
 *
 * <p>Supported row types are int- and long-keyed double and float rows; any other row type is
 * rejected with an {@link AngelException}.
 */
public class GetColsFunc extends GetFunc {

    /**
     * Creates the function for a client-side request.
     *
     * @param param the request parameter (presumably carrying the rows and columns to fetch)
     */
    public GetColsFunc(GetColsParam param) {
        super(param);
    }

    /**
     * Creates the function with a {@code null} parameter.
     *
     * <p>NOTE(review): presumably needed so the function can be instantiated reflectively
     * (e.g. on the parameter-server side) — confirm before removing.
     */
    public GetColsFunc() {
        super(null);
    }

    /**
     * Reads the requested columns from the requested rows stored in one partition.
     *
     * <p>The row ids carried by {@code partParam} are sorted in place (this mutates the
     * parameter's array). Position {@code j} in every returned column vector therefore
     * corresponds to {@code rows[j]} of the sorted array that is also returned in the result.
     *
     * @param partParam must be a {@link PartitionGetColsParam}
     * @return a {@link PartitionGetColsResult} carrying the sorted row ids, the column ids and a
     *     composite vector with one dense sub-vector per requested column
     * @throws AngelException if no rows were requested or the row type is not supported
     */
    public PartitionGetResult partitionGet(PartitionGetParam partParam) {
        PartitionGetColsParam param = (PartitionGetColsParam) partParam;
        int[] rows = param.rows;
        long[] cols = param.cols;
        int matId = param.getMatrixId();
        int partitionId = param.getPartKey().getPartitionId();

        // After sorting, the row order of every column vector matches the returned rows array.
        Arrays.sort(rows);

        RowBasedPartition partition =
            (RowBasedPartition) psContext.getMatrixStorageManager().getPart(matId, partitionId);
        ServerRow[] splits = new ServerRow[rows.length];
        for (int i = 0; i < rows.length; ++i) {
            splits[i] = partition.getRow(rows[i]);
        }
        Vector result = doGet(splits, cols, param.func);
        return new PartitionGetColsResult(rows, cols, result);
    }

    /**
     * Reads the requested values while holding the first row's lock.
     *
     * <p>When {@code func} is non-null, values are fetched through {@code initAndGet}, which
     * presumably initializes (writes) absent entries, so a write lock is taken. Otherwise a read
     * lock is used.
     *
     * <p>NOTE(review): only {@code rows[0]} is locked although every row in {@code rows} is
     * accessed. Confirm that the rows share this lock (e.g. a partition-level lock) or that this
     * is otherwise safe.
     *
     * @param rows the server rows to read; must not be empty
     * @param cols the column ids to read from every row
     * @param func optional initializer; {@code null} means plain reads
     * @return a composite vector with one dense sub-vector per column
     * @throws AngelException if {@code rows} is empty or the row type is not supported
     */
    private Vector doGet(ServerRow[] rows, long[] cols, InitFunc func) {
        if (rows.length == 0) {
            // Without a first row there is neither a lock to take nor a row type to dispatch on.
            throw new AngelException("GetColsFunc: at least one row must be requested");
        }
        ServerRow lockHolder = rows[0];
        if (func != null) {
            lockHolder.startWrite();
            try {
                return doGetLockFree(rows, cols, func);
            } finally {
                lockHolder.endWrite();
            }
        }
        lockHolder.startRead();
        try {
            return doGetLockFree(rows, cols, func);
        } finally {
            lockHolder.endRead();
        }
    }

    /**
     * Builds the column vectors without taking any lock; the caller must hold the lock.
     *
     * <p>The concrete row type is taken from {@code rows[0]}, and every row is cast to that same
     * type. For int-keyed rows, column ids are narrowed to {@code int}.
     *
     * @param rows the server rows to read; must not be empty
     * @param cols the column ids to read from every row
     * @param func optional initializer; when non-null, values are read via {@code initAndGet}
     *     instead of {@code get}
     * @return a composite int-keyed double or float vector with one dense sub-vector of length
     *     {@code rows.length} per column
     * @throws AngelException if the row type is not one of the four supported types
     */
    private Vector doGetLockFree(ServerRow[] rows, long[] cols, InitFunc func) {
        ServerRow first = rows[0];
        if (first instanceof ServerIntDoubleRow) {
            IntDoubleVector[] columns = new IntDoubleVector[cols.length];
            for (int i = 0; i < cols.length; ++i) {
                int col = (int) cols[i];
                IntDoubleVector column = VFactory.denseDoubleVector(rows.length);
                for (int j = 0; j < rows.length; ++j) {
                    ServerIntDoubleRow row = (ServerIntDoubleRow) rows[j];
                    column.set(j, func != null ? row.initAndGet(col, func) : row.get(col));
                }
                columns[i] = column;
            }
            return VFactory.compIntDoubleVector(cols.length, columns, rows.length);
        }
        if (first instanceof ServerLongDoubleRow) {
            IntDoubleVector[] columns = new IntDoubleVector[cols.length];
            for (int i = 0; i < cols.length; ++i) {
                long col = cols[i];
                IntDoubleVector column = VFactory.denseDoubleVector(rows.length);
                for (int j = 0; j < rows.length; ++j) {
                    ServerLongDoubleRow row = (ServerLongDoubleRow) rows[j];
                    column.set(j, func != null ? row.initAndGet(col, func) : row.get(col));
                }
                columns[i] = column;
            }
            return VFactory.compIntDoubleVector(cols.length, columns, rows.length);
        }
        if (first instanceof ServerIntFloatRow) {
            IntFloatVector[] columns = new IntFloatVector[cols.length];
            for (int i = 0; i < cols.length; ++i) {
                int col = (int) cols[i];
                IntFloatVector column = VFactory.denseFloatVector(rows.length);
                for (int j = 0; j < rows.length; ++j) {
                    ServerIntFloatRow row = (ServerIntFloatRow) rows[j];
                    column.set(j, func != null ? row.initAndGet(col, func) : row.get(col));
                }
                columns[i] = column;
            }
            return VFactory.compIntFloatVector(cols.length, columns, rows.length);
        }
        if (first instanceof ServerLongFloatRow) {
            IntFloatVector[] columns = new IntFloatVector[cols.length];
            for (int i = 0; i < cols.length; ++i) {
                long col = cols[i];
                IntFloatVector column = VFactory.denseFloatVector(rows.length);
                for (int j = 0; j < rows.length; ++j) {
                    ServerLongFloatRow row = (ServerLongFloatRow) rows[j];
                    column.set(j, func != null ? row.initAndGet(col, func) : row.get(col));
                }
                columns[i] = column;
            }
            return VFactory.compIntFloatVector(cols.length, columns, rows.length);
        }
        throw new AngelException("The rowType " + first.getRowType() + " is not support!");
    }

    /**
     * Combines per-partition results into one map from column id to column vector.
     *
     * <p>The value type (double or float) is taken from the first partition result, and every
     * other result is cast to that type. If the same column id appears in several partition
     * results, the one merged last wins ({@link HashMap#put} semantics).
     *
     * @param partResults the per-partition results; each must be a {@link PartitionGetColsResult}
     * @return a {@link GetColsResult} wrapping the column-id to column-vector map
     * @throws AngelException if {@code partResults} is null/empty or the vector type is neither
     *     double nor float
     */
    public GetResult merge(List<PartitionGetResult> partResults) {
        if (partResults == null || partResults.isEmpty()) {
            throw new AngelException("GetColsFunc: no partition results to merge");
        }
        PartitionGetColsResult first = (PartitionGetColsResult) partResults.get(0);
        boolean isDouble = first.vector instanceof CompIntDoubleVector;
        if (!isDouble && !(first.vector instanceof CompIntFloatVector)) {
            throw new AngelException("Data type should be double or float!");
        }

        HashMap<Long, Vector> maps = new HashMap<>();
        for (PartitionGetResult r : partResults) {
            PartitionGetColsResult part = (PartitionGetColsResult) r;
            long[] cols = part.cols;
            // Fetch the sub-vector array once per partition rather than once per column.
            if (isDouble) {
                IntDoubleVector[] columns = ((CompIntDoubleVector) part.vector).getPartitions();
                for (int i = 0; i < cols.length; ++i) {
                    maps.put(cols[i], columns[i]);
                }
            } else {
                IntFloatVector[] columns = ((CompIntFloatVector) part.vector).getPartitions();
                for (int i = 0; i < cols.length; ++i) {
                    maps.put(cols[i], columns[i]);
                }
            }
        }
        return new GetColsResult(maps);
    }
}

