001/*
002 * Licensed to the Apache Software Foundation (ASF) under one
003 * or more contributor license agreements.  See the NOTICE file
004 * distributed with this work for additional information
005 * regarding copyright ownership.  The ASF licenses this file
006 * to you under the Apache License, Version 2.0 (the
007 * "License"); you may not use this file except in compliance
008 * with the License.  You may obtain a copy of the License at
009 *
010 *   http://www.apache.org/licenses/LICENSE-2.0
011 *
012 * Unless required by applicable law or agreed to in writing,
013 * software distributed under the License is distributed on an
014 * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
015 * KIND, either express or implied.  See the License for the
016 * specific language governing permissions and limitations
017 * under the License.
018 */
019package hivemall.xgboost.utils;
020
021import hivemall.utils.collections.lists.FloatArrayList;
022import hivemall.utils.collections.lists.LongArrayList;
023import matrix4j.utils.collections.lists.IntArrayList;
024import ml.dmlc.xgboost4j.java.DMatrix;
025import ml.dmlc.xgboost4j.java.XGBoostError;
026
027import java.util.ArrayList;
028import java.util.Collections;
029import java.util.List;
030
031import javax.annotation.Nonnegative;
032import javax.annotation.Nonnull;
033
034public final class SparseDMatrixBuilder extends DMatrixBuilder {
035
036    private final boolean sortRequired;
037
038    @Nonnull
039    private final LongArrayList rowPointers;
040    @Nonnull
041    private final IntArrayList columnIndices;
042    @Nonnull
043    private final FloatArrayList values;
044
045    @Nonnull
046    private final List<ColValue> colCache;
047
048    private int maxNumColumns;
049
050    public SparseDMatrixBuilder(@Nonnegative int initSize) {
051        this(initSize, true);
052    }
053
054    public SparseDMatrixBuilder(@Nonnegative int initSize, boolean sortRequired) {
055        super();
056        this.sortRequired = sortRequired;
057        this.rowPointers = new LongArrayList(initSize + 1);
058        rowPointers.add(0);
059        this.columnIndices = new IntArrayList(initSize);
060        this.values = new FloatArrayList(initSize);
061        this.colCache = new ArrayList<>(32);
062        this.maxNumColumns = 0;
063    }
064
065    @Nonnull
066    public SparseDMatrixBuilder nextRow() {
067        if (sortRequired) {
068            Collections.sort(colCache);
069        }
070        for (ColValue e : colCache) {
071            columnIndices.add(e.col);
072            values.add(e.value);
073        }
074        colCache.clear();
075
076        int ptr = values.size();
077        rowPointers.add(ptr);
078        return this;
079    }
080
081    @Nonnull
082    public SparseDMatrixBuilder nextColumn(@Nonnegative int col, float value) {
083        checkColIndex(col);
084
085        this.maxNumColumns = Math.max(col + 1, maxNumColumns);
086        if (value == 0.d) {
087            return this;
088        }
089
090        colCache.add(new ColValue(col, value));
091        return this;
092    }
093
094    @Nonnull
095    public DMatrix buildMatrix(@Nonnull float[] labels) throws XGBoostError {
096        DMatrix matrix = new DMatrix(rowPointers.toArray(true), columnIndices.toArray(true),
097            values.toArray(true), DMatrix.SparseType.CSR, maxNumColumns);
098        matrix.setLabel(labels);
099        return matrix;
100    }
101
102    private static final class ColValue implements Comparable<ColValue> {
103        final int col;
104        final float value;
105
106        ColValue(int col, float value) {
107            this.col = col;
108            this.value = value;
109        }
110
111        @Override
112        public int compareTo(ColValue o) {
113            return Integer.compare(col, o.col);
114        }
115
116        @Override
117        public String toString() {
118            return "[column=" + col + ", value=" + value + ']';
119        }
120
121    }
122
123}