001/* 002 * Licensed to the Apache Software Foundation (ASF) under one 003 * or more contributor license agreements. See the NOTICE file 004 * distributed with this work for additional information 005 * regarding copyright ownership. The ASF licenses this file 006 * to you under the Apache License, Version 2.0 (the 007 * "License"); you may not use this file except in compliance 008 * with the License. You may obtain a copy of the License at 009 * 010 * http://www.apache.org/licenses/LICENSE-2.0 011 * 012 * Unless required by applicable law or agreed to in writing, 013 * software distributed under the License is distributed on an 014 * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY 015 * KIND, either express or implied. See the License for the 016 * specific language governing permissions and limitations 017 * under the License. 018 */ 019package hivemall.xgboost.utils; 020 021import hivemall.utils.collections.arrays.SparseFloatArray; 022import ml.dmlc.xgboost4j.java.DMatrix; 023import ml.dmlc.xgboost4j.java.XGBoostError; 024 025import java.util.ArrayList; 026import java.util.Arrays; 027import java.util.List; 028 029import javax.annotation.Nonnegative; 030import javax.annotation.Nonnull; 031 032public final class DenseDMatrixBuilder extends DMatrixBuilder { 033 034 @Nonnull 035 private final List<float[]> rows; 036 private int maxNumColumns; 037 038 @Nonnull 039 private final SparseFloatArray rowProbe; 040 041 public DenseDMatrixBuilder(@Nonnegative int initSize) { 042 super(); 043 this.rows = new ArrayList<float[]>(initSize); 044 this.maxNumColumns = 0; 045 this.rowProbe = new SparseFloatArray(32); 046 } 047 048 @Override 049 public DenseDMatrixBuilder nextColumn(@Nonnegative final int col, final float value) { 050 checkColIndex(col); 051 052 this.maxNumColumns = Math.max(col + 1, maxNumColumns); 053 if (value == 0.d) { 054 return this; 055 } 056 rowProbe.put(col, value); 057 return this; 058 } 059 060 @Override 061 public DenseDMatrixBuilder nextRow() { 062 float[] row = rowProbe.toArray(); 063 rowProbe.clear(); 064 rows.add(row); 065 return this; 066 } 067 068 @Override 069 public DMatrix buildMatrix(@Nonnull float[] labels) throws XGBoostError { 070 final int numRows = rows.size(); 071 if (labels.length != numRows) { 072 throw new XGBoostError( 073 String.format("labels.length does not match to nrows. labels.length=%d, nrows=%d", 074 labels.length, numRows)); 075 } 076 077 final float[] data = new float[numRows * maxNumColumns]; 078 Arrays.fill(data, Float.NaN); 079 for (int i = 0; i < numRows; i++) { 080 final float[] row = rows.get(i); 081 final int rowPtr = i * maxNumColumns; 082 for (int j = 0; j < row.length; j++) { 083 int ij = rowPtr + j; 084 data[ij] = row[j]; 085 } 086 } 087 088 DMatrix matrix = new DMatrix(data, numRows, maxNumColumns, Float.NaN); 089 matrix.setLabel(labels); 090 return matrix; 091 } 092 093}