/*
 * Tencent is pleased to support the open source community by making Angel available.
 *
 * Copyright (C) 2017-2018 THL A29 Limited, a Tencent company. All rights reserved.
 *
 * Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in 
 * compliance with the License. You may obtain a copy of the License at
 *
 * https://opensource.org/licenses/Apache-2.0
 *
 * Unless required by applicable law or agreed to in writing, software distributed under the License
 * is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express
 * or implied. See the License for the specific language governing permissions and limitations under
 * the License.
 *
 */


package com.tencent.angel.ml.GBDT.objective;

import com.tencent.angel.ml.GBDT.algo.RegTree.GradPair;
import com.tencent.angel.ml.GBDT.algo.RegTree.RegTDataStore;
import org.apache.commons.logging.Log;
import org.apache.commons.logging.LogFactory;

import java.util.ArrayList;
import java.util.List;
import java.util.Objects;

/**
 * Gradient-boosting objective that delegates all loss-specific math to a {@link LossHelper}.
 *
 * <p>For every instance, the raw prediction is mapped through {@link LossHelper#transPred}. The
 * first- and second-order gradients returned by the helper are then multiplied by the instance
 * weight. Instances whose label is exactly {@code 1.0f} additionally have their weight multiplied
 * by {@code scalePosWeight} (default {@code 1.0f}).
 */
public class RegLossObj implements ObjFunc {

  private static final Log LOG = LogFactory.getLog(RegLossObj.class);

  /** Weight multiplier used for label-{@code 1.0f} instances when none is supplied. */
  private static final float DEFAULT_SCALE_POS_WEIGHT = 1.0f;

  /** Supplies prediction transforms, gradient formulas, label checks and metric defaults. */
  private final LossHelper loss;
  /** Extra weight multiplier applied to instances whose label equals {@code 1.0f}. */
  private final float scalePosWeight;

  /**
   * Creates an objective using the default positive-instance weight scale ({@code 1.0f}).
   *
   * @param loss loss helper to delegate to; must not be null
   * @throws NullPointerException if {@code loss} is null
   */
  public RegLossObj(LossHelper loss) {
    this(loss, DEFAULT_SCALE_POS_WEIGHT);
  }

  /**
   * Creates an objective with an explicit positive-instance weight scale.
   *
   * @param loss loss helper to delegate to; must not be null
   * @param scalePosWeight multiplier applied to the weight of every instance whose label is
   *     exactly {@code 1.0f}; must be finite and non-negative
   * @throws NullPointerException if {@code loss} is null
   * @throws IllegalArgumentException if {@code scalePosWeight} is NaN, infinite or negative
   */
  public RegLossObj(LossHelper loss, float scalePosWeight) {
    this.loss = Objects.requireNonNull(loss, "loss");
    if (Float.isNaN(scalePosWeight) || Float.isInfinite(scalePosWeight) || scalePosWeight < 0.0f) {
      throw new IllegalArgumentException(
          "scalePosWeight must be finite and non-negative, got " + scalePosWeight);
    }
    this.scalePosWeight = scalePosWeight;
  }

  /**
   * Computes the first- and second-order gradient pair of every instance.
   *
   * <p>For instance {@code i}, the steps are:
   * <ol>
   *   <li>{@code preds[i]} is transformed with {@link LossHelper#transPred}.</li>
   *   <li>Both gradients from the loss helper are multiplied by {@code dataStore.getWeight(i)}.</li>
   *   <li>That weight is further scaled by {@code scalePosWeight} when the label is exactly
   *       {@code 1.0f}.</li>
   * </ol>
   * If any label fails {@link LossHelper#checkLabel}, a single error is logged. The gradients are
   * still computed and returned.
   *
   * @param preds predictions before {@code transPred} is applied, one per instance; expected to
   *     have the same length as {@code dataStore.labels}
   * @param dataStore data holding the labels and per-instance weights
   * @param iteration current boosting iteration (not used by this objective)
   * @return gradient pairs in the same order as {@code preds}
   */
  @Override public GradPair[] calGrad(float[] preds, RegTDataStore dataStore, int iteration) {
    assert preds.length > 0 : "preds must not be empty";
    assert preds.length == dataStore.labels.length
        : "preds length " + preds.length + " != labels length " + dataStore.labels.length;
    int numInstances = preds.length;
    GradPair[] gradPairs = new GradPair[numInstances];
    // Track label validity across the loop so the error is logged once, not per instance.
    boolean labelsValid = true;
    for (int i = 0; i < numInstances; i++) {
      float transformed = loss.transPred(preds[i]);
      float weight = dataStore.getWeight(i);
      if (dataStore.labels[i] == 1.0f) {
        weight *= scalePosWeight;
      }
      if (!loss.checkLabel(dataStore.labels[i])) {
        labelsValid = false;
      }
      gradPairs[i] = new GradPair(loss.firOrderGrad(transformed, dataStore.labels[i]) * weight,
        loss.secOrderGrad(transformed, dataStore.labels[i]) * weight);
    }
    if (!labelsValid) {
      LOG.error(loss.labelErrorMsg());
    }
    return gradPairs;
  }

  /**
   * Transforms prediction values in place with {@link LossHelper#transPred}.
   *
   * @param preds prediction values; each element is overwritten with its transformed value
   */
  @Override public void transPred(float[] preds) {
    for (int i = 0; i < preds.length; ++i) {
      preds[i] = loss.transPred(preds[i]);
    }
  }

  /**
   * Transforms prediction values in place for evaluation.
   *
   * <p>This objective uses the same transform as {@link #transPred(float[])}.
   *
   * @param preds prediction values; each element is overwritten with its transformed value
   */
  @Override public void transEval(float[] preds) {
    this.transPred(preds);
  }

  /**
   * Converts a base score back to the margin used by gradient boosting.
   *
   * <p>The conversion is delegated to {@link LossHelper#prob2Margin}.
   *
   * @param baseScore user-supplied base score
   * @return the corresponding margin value
   */
  @Override public float prob2Margin(float baseScore) {
    return loss.prob2Margin(baseScore);
  }

  /**
   * Returns the default evaluation metric name for this objective.
   *
   * @return the metric name reported by the loss helper
   */
  @Override public String defaultEvalMetric() {
    return loss.defaultEvalMetric();
  }
}