001/*
002 * Licensed to the Apache Software Foundation (ASF) under one
003 * or more contributor license agreements.  See the NOTICE file
004 * distributed with this work for additional information
005 * regarding copyright ownership.  The ASF licenses this file
006 * to you under the Apache License, Version 2.0 (the
007 * "License"); you may not use this file except in compliance
008 * with the License.  You may obtain a copy of the License at
009 *
010 *   http://www.apache.org/licenses/LICENSE-2.0
011 *
012 * Unless required by applicable law or agreed to in writing,
013 * software distributed under the License is distributed on an
014 * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
015 * KIND, either express or implied.  See the License for the
016 * specific language governing permissions and limitations
017 * under the License.
018 */
019package hivemall.xgboost.utils;
020
021import hivemall.utils.collections.arrays.SparseFloatArray;
022import ml.dmlc.xgboost4j.java.DMatrix;
023import ml.dmlc.xgboost4j.java.XGBoostError;
024
025import java.util.ArrayList;
026import java.util.Arrays;
027import java.util.List;
028
029import javax.annotation.Nonnegative;
030import javax.annotation.Nonnull;
031
032public final class DenseDMatrixBuilder extends DMatrixBuilder {
033
034    @Nonnull
035    private final List<float[]> rows;
036    private int maxNumColumns;
037
038    @Nonnull
039    private final SparseFloatArray rowProbe;
040
041    public DenseDMatrixBuilder(@Nonnegative int initSize) {
042        super();
043        this.rows = new ArrayList<float[]>(initSize);
044        this.maxNumColumns = 0;
045        this.rowProbe = new SparseFloatArray(32);
046    }
047
048    @Override
049    public DenseDMatrixBuilder nextColumn(@Nonnegative final int col, final float value) {
050        checkColIndex(col);
051
052        this.maxNumColumns = Math.max(col + 1, maxNumColumns);
053        if (value == 0.d) {
054            return this;
055        }
056        rowProbe.put(col, value);
057        return this;
058    }
059
060    @Override
061    public DenseDMatrixBuilder nextRow() {
062        float[] row = rowProbe.toArray();
063        rowProbe.clear();
064        rows.add(row);
065        return this;
066    }
067
068    @Override
069    public DMatrix buildMatrix(@Nonnull float[] labels) throws XGBoostError {
070        final int numRows = rows.size();
071        if (labels.length != numRows) {
072            throw new XGBoostError(
073                String.format("labels.length does not match to nrows. labels.length=%d, nrows=%d",
074                    labels.length, numRows));
075        }
076
077        final float[] data = new float[numRows * maxNumColumns];
078        Arrays.fill(data, Float.NaN);
079        for (int i = 0; i < numRows; i++) {
080            final float[] row = rows.get(i);
081            final int rowPtr = i * maxNumColumns;
082            for (int j = 0; j < row.length; j++) {
083                int ij = rowPtr + j;
084                data[ij] = row[j];
085            }
086        }
087
088        DMatrix matrix = new DMatrix(data, numRows, maxNumColumns, Float.NaN);
089        matrix.setLabel(labels);
090        return matrix;
091    }
092
093}