001/*
002 * Licensed to the Apache Software Foundation (ASF) under one
003 * or more contributor license agreements.  See the NOTICE file
004 * distributed with this work for additional information
005 * regarding copyright ownership.  The ASF licenses this file
006 * to you under the Apache License, Version 2.0 (the
007 * "License"); you may not use this file except in compliance
008 * with the License.  You may obtain a copy of the License at
009 *
010 *   http://www.apache.org/licenses/LICENSE-2.0
011 *
012 * Unless required by applicable law or agreed to in writing,
013 * software distributed under the License is distributed on an
014 * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
015 * KIND, either express or implied.  See the License for the
016 * specific language governing permissions and limitations
017 * under the License.
018 */
019package hivemall.xgboost.utils;
020
021import ml.dmlc.xgboost4j.java.DMatrix;
022import ml.dmlc.xgboost4j.java.XGBoostError;
023
024import javax.annotation.Nonnegative;
025import javax.annotation.Nonnull;
026
027public abstract class DMatrixBuilder {
028
029    public DMatrixBuilder() {}
030
031    protected static final void checkColIndex(final int col) {
032        if (col < 0) {
033            throw new IllegalArgumentException("Found negative column index: " + col);
034        }
035    }
036
037    public void nextRow(@Nonnull final float[] row) {
038        for (int col = 0; col < row.length; col++) {
039            nextColumn(col, row[col]);
040        }
041        nextRow();
042    }
043
044    public void nextRow(@Nonnull final String[] row) {
045        for (String col : row) {
046            if (col == null) {
047                continue;
048            }
049            nextColumn(col);
050        }
051        nextRow();
052    }
053
054    public void nextRow(@Nonnull final String[] row, final int start, final int endEx) {
055        for (int i = start, last = Math.min(endEx, row.length); i < last; i++) {
056            String col = row[i];
057            if (col == null) {
058                continue;
059            }
060            nextColumn(col);
061        }
062        nextRow();
063    }
064
065    @Nonnull
066    public abstract DMatrixBuilder nextRow();
067
068    @Nonnull
069    public abstract DMatrixBuilder nextColumn(@Nonnegative int col, float value);
070
071    /**
072     * @throws IllegalArgumentException
073     * @throws NumberFormatException
074     */
075    @Nonnull
076    public DMatrixBuilder nextColumn(@Nonnull final String col) {
077        final int pos = col.indexOf(':');
078        if (pos == 0) {
079            throw new IllegalArgumentException("Invalid feature value representation: " + col);
080        }
081
082        final String feature;
083        final float value;
084        if (pos > 0) {
085            feature = col.substring(0, pos);
086            String s2 = col.substring(pos + 1);
087            value = Float.parseFloat(s2);
088        } else {
089            feature = col;
090            value = 1.f;
091        }
092
093        if (feature.indexOf(':') != -1) {
094            throw new IllegalArgumentException("Invalid feature format `<index>:<value>`: " + col);
095        }
096
097        int colIndex = Integer.parseInt(feature);
098        if (colIndex < 0) {
099            throw new IllegalArgumentException(
100                "Col index MUST be greater than or equals to 0: " + colIndex);
101        }
102
103        return nextColumn(colIndex, value);
104    }
105
106    @Nonnull
107    public abstract DMatrix buildMatrix(@Nonnull float[] labels) throws XGBoostError;
108
109}