/*
 * Licensed to the Apache Software Foundation (ASF) under one
 * or more contributor license agreements.  See the NOTICE file
 * distributed with this work for additional information
 * regarding copyright ownership.  The ASF licenses this file
 * to you under the Apache License, Version 2.0 (the
 * "License"); you may not use this file except in compliance
 * with the License.  You may obtain a copy of the License at
 *
 *   http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing,
 * software distributed under the License is distributed on an
 * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
 * KIND, either express or implied.  See the License for the
 * specific language governing permissions and limitations
 * under the License.
 */
package hivemall.xgboost.utils;

import hivemall.utils.collections.lists.FloatArrayList;
import hivemall.utils.collections.lists.LongArrayList;
import matrix4j.utils.collections.lists.IntArrayList;
import ml.dmlc.xgboost4j.java.DMatrix;
import ml.dmlc.xgboost4j.java.XGBoostError;

import java.util.ArrayList;
import java.util.Collections;
import java.util.List;

import javax.annotation.Nonnegative;
import javax.annotation.Nonnull;

/**
 * Accumulates sparse rows and hands them to {@link DMatrix} using
 * {@link DMatrix.SparseType#CSR}.
 *
 * <p>
 * Usage pattern: call {@link #nextColumn(int, float)} for each entry of the current row, then
 * {@link #nextRow()} to close the row, and finally {@link #buildMatrix(float[])}. Entries that
 * were added through {@code nextColumn} but not yet closed by {@code nextRow} stay in the
 * per-row cache and are not part of the arrays passed to {@code DMatrix}.
 */
public final class SparseDMatrixBuilder extends DMatrixBuilder {

    /** Whether the cached entries of a row are sorted by column index when the row is closed. */
    private final boolean sortRequired;

    /** Offsets into {@link #values}; starts with a single 0 and gains one entry per closed row. */
    @Nonnull
    private final LongArrayList rowPointers;
    /** Column index of every stored entry, in the same order as {@link #values}. */
    @Nonnull
    private final IntArrayList columnIndices;
    /** Stored (non-zero) entry values of all closed rows. */
    @Nonnull
    private final FloatArrayList values;

    /** Entries of the row currently being assembled; emptied by {@link #nextRow()}. */
    @Nonnull
    private final List<ColValue> colCache;

    /** Largest {@code col + 1} seen so far, including columns whose value was zero. */
    private int maxNumColumns;

    /**
     * Creates a builder that sorts each row's entries by column index.
     *
     * @param initSize initial capacity passed to the backing lists
     */
    public SparseDMatrixBuilder(@Nonnegative int initSize) {
        this(initSize, true);
    }

    /**
     * Creates a builder.
     *
     * @param initSize initial capacity passed to the backing lists
     * @param sortRequired {@code true} to sort each row's entries by column index in
     *        {@link #nextRow()}; {@code false} to keep them in insertion order
     */
    public SparseDMatrixBuilder(@Nonnegative int initSize, boolean sortRequired) {
        super();
        this.sortRequired = sortRequired;
        this.maxNumColumns = 0;
        this.colCache = new ArrayList<>(32);
        this.values = new FloatArrayList(initSize);
        this.columnIndices = new IntArrayList(initSize);
        this.rowPointers = new LongArrayList(initSize + 1);
        // the first row always begins at offset 0
        rowPointers.add(0);
    }

    /**
     * Closes the current row: moves the cached entries (sorted first when requested) into the
     * column-index and value lists, empties the cache, and records the new end offset.
     *
     * @return this builder
     */
    @Nonnull
    public SparseDMatrixBuilder nextRow() {
        if (sortRequired) {
            Collections.sort(colCache);
        }
        final int numEntries = colCache.size();
        for (int i = 0; i < numEntries; i++) {
            final ColValue entry = colCache.get(i);
            columnIndices.add(entry.col);
            values.add(entry.value);
        }
        colCache.clear();

        // end offset of this row, which is also the start offset of the next one
        rowPointers.add(values.size());
        return this;
    }

    /**
     * Adds one entry to the current row.
     *
     * <p>
     * The column index is first passed to {@code checkColIndex}, which is declared outside this
     * file (presumably in {@code DMatrixBuilder}). The column count is widened even when the
     * value is zero; zero values themselves are not stored.
     *
     * @param col column index of the entry
     * @param value entry value
     * @return this builder
     */
    @Nonnull
    public SparseDMatrixBuilder nextColumn(@Nonnegative int col, float value) {
        checkColIndex(col);

        final int width = col + 1;
        if (width > maxNumColumns) {
            this.maxNumColumns = width;
        }
        if (value != 0.d) {
            colCache.add(new ColValue(col, value));
        }
        return this;
    }

    /**
     * Builds a {@link DMatrix} from the rows closed so far and attaches the given labels.
     *
     * @param labels labels handed to {@link DMatrix#setLabel(float[])}
     * @return the constructed matrix
     * @throws XGBoostError propagated from the {@code DMatrix} constructor or {@code setLabel}
     */
    @Nonnull
    public DMatrix buildMatrix(@Nonnull float[] labels) throws XGBoostError {
        final long[] rowHeaders = rowPointers.toArray(true);
        final int[] colIndices = columnIndices.toArray(true);
        final float[] data = values.toArray(true);

        final DMatrix result =
                new DMatrix(rowHeaders, colIndices, data, DMatrix.SparseType.CSR, maxNumColumns);
        result.setLabel(labels);
        return result;
    }

    /** A (column index, value) pair ordered by column index only. */
    private static final class ColValue implements Comparable<ColValue> {
        final int col;
        final float value;

        ColValue(int col, float value) {
            this.value = value;
            this.col = col;
        }

        @Override
        public int compareTo(ColValue other) {
            return Integer.compare(this.col, other.col);
        }

        @Override
        public String toString() {
            final StringBuilder buf = new StringBuilder();
            buf.append("[column=").append(col);
            buf.append(", value=").append(value);
            buf.append(']');
            return buf.toString();
        }

    }

}