/*
 * Decompiled with CFR 0.152.
 */
package com.opengamma.strata.pricer.impl.tree;

import com.opengamma.strata.collect.ArgChecker;
import com.opengamma.strata.collect.array.DoubleArray;
import com.opengamma.strata.collect.array.DoubleMatrix;
import com.opengamma.strata.pricer.impl.tree.OptionFunction;
import com.opengamma.strata.product.option.BarrierType;
import java.util.Arrays;

/**
 * Option function for a single-barrier knock-out option evaluated on a trinomial tree.
 * <p>
 * Nodes on or beyond the barrier take the rebate. Nodes on the surviving side carry the payoff
 * {@code max(sign * (S - strike), 0)} at expiry, and the discounted expected value at earlier steps.
 * The surviving node adjacent to the barrier is additionally blended with the rebate according to
 * where the barrier falls between its two bracketing nodes.
 */
abstract class SingleBarrierKnockoutFunction implements OptionFunction {

    SingleBarrierKnockoutFunction() {
    }

    /**
     * Gets the strike.
     *
     * @return the strike used in the expiry payoff {@code max(sign * (S - strike), 0)}
     */
    public abstract double getStrike();

    /**
     * Gets the barrier level at a time step.
     *
     * @param step  the time step index ({@link #getNumberOfSteps()} at expiry)
     * @return the barrier level at that step
     */
    public abstract double getBarrierLevel(int step);

    /**
     * Gets the sign applied to {@code (S - strike)} in the expiry payoff.
     * <p>
     * NOTE(review): presumably +1 for a call and -1 for a put; confirm against implementations.
     *
     * @return the payoff sign
     */
    public abstract double getSign();

    /**
     * Gets the barrier type.
     * <p>
     * Only {@link BarrierType#isDown()} is consulted by this class.
     *
     * @return the barrier type
     */
    public abstract BarrierType getBarrierType();

    /**
     * Gets the rebate paid at a knocked-out node at a time step.
     *
     * @param step  the time step index ({@link #getNumberOfSteps()} at expiry)
     * @return the rebate at that step
     */
    public abstract double getRebate(int step);

    /**
     * Computes the option values at expiry.
     * <p>
     * Every node starts at the rebate. Nodes strictly above a down barrier, or strictly below an
     * up barrier, receive {@code max(sign * (S - strike), 0)}. The surviving node next to the
     * barrier is then blended with the rebate (see {@link #smoothNearBarrier}). An up-barrier node
     * lying exactly on the barrier takes the rebate.
     *
     * @param stateValue  the values of the underlying at the expiry nodes; assumed sorted ascending,
     *   as required by the binary search used to locate the barrier
     * @return the option values at expiry, one per node
     */
    @Override
    public DoubleArray getPayoffAtExpiryTrinomial(DoubleArray stateValue) {
        int nNodes = stateValue.size();
        int nSteps = getNumberOfSteps();
        double rebate = getRebate(nSteps);
        double barrierLevel = getBarrierLevel(nSteps);
        boolean isDown = getBarrierType().isDown();
        int index = getLowerBoundIndex(stateValue, barrierLevel);
        // the barrier must be bracketed: lowest node <= barrier < highest node
        ArgChecker.isTrue(index > -1 && index < nNodes - 1, "barrier is covered by tree");

        double[] values = new double[nNodes];
        Arrays.fill(values, rebate);
        // surviving range: nodes above index for a down barrier, nodes up to and including index for an up barrier
        int iMin = isDown ? index + 1 : 0;
        int iMax = isDown ? nNodes : index + 1;
        double sign = getSign();
        double strike = getStrike();
        for (int i = iMin; i < iMax; ++i) {
            values[i] = Math.max(sign * (stateValue.get(i) - strike), 0.0);
        }

        if (isDown) {
            values[index + 1] = smoothNearBarrier(stateValue, index, barrierLevel, true, values[index + 1], rebate);
        } else if (barrierLevel == stateValue.get(index)) {
            // node index was given the payoff above but lies exactly on an up barrier, so it is knocked out
            values[index] = rebate;
        } else {
            values[index] = smoothNearBarrier(stateValue, index, barrierLevel, false, values[index], rebate);
        }
        return DoubleArray.ofUnsafe(values);
    }

    /**
     * Rolls the option values back one step through the trinomial tree.
     * <p>
     * Nodes at or below a down barrier, or at or above an up barrier, take the rebate for step
     * {@code i}. Every other node {@code j} takes
     * {@code discountFactor * (pUp * values[j + 2] + pMiddle * values[j + 1] + pDown * values[j])}.
     * If the barrier is bracketed by two nodes, the surviving node next to it is then blended with
     * the rebate (see {@link #smoothNearBarrier}).
     *
     * @param discountFactor  the discount factor applied over the step
     * @param transitionProbability  per-node probabilities: column 0 down, column 1 middle, column 2 up
     * @param stateValue  the values of the underlying at the {@code 2 * i + 1} nodes of step {@code i};
     *   assumed sorted ascending
     * @param values  the option values at step {@code i + 1}; entries {@code 0} to {@code 2 * i + 2} are read
     * @param i  the step index
     * @return the {@code 2 * i + 1} option values at step {@code i}
     */
    @Override
    public DoubleArray getNextOptionValues(
            double discountFactor,
            DoubleMatrix transitionProbability,
            DoubleArray stateValue,
            DoubleArray values,
            int i) {
        int nNodes = 2 * i + 1;
        double[] res = new double[nNodes];
        double barrierLevel = getBarrierLevel(i);
        double rebate = getRebate(i);
        boolean isDown = getBarrierType().isDown();
        for (int j = 0; j < nNodes; ++j) {
            double spot = stateValue.get(j);
            boolean knockedOut = isDown ? spot <= barrierLevel : spot >= barrierLevel;
            if (knockedOut) {
                res[j] = rebate;
                continue;
            }
            double upProb = transitionProbability.get(j, 2);
            double middleProb = transitionProbability.get(j, 1);
            double downProb = transitionProbability.get(j, 0);
            res[j] = discountFactor *
                    (upProb * values.get(j + 2) + middleProb * values.get(j + 1) + downProb * values.get(j));
        }

        int index = getLowerBoundIndex(stateValue, barrierLevel);
        if (index > -1 && index < nNodes - 1) {
            if (isDown) {
                res[index + 1] = smoothNearBarrier(stateValue, index, barrierLevel, true, res[index + 1], rebate);
            } else {
                // No exact-hit special case here, unlike at expiry: a node on an up barrier already
                // holds the rebate from the loop above, and the blend then returns the rebate up to rounding.
                res[index] = smoothNearBarrier(stateValue, index, barrierLevel, false, res[index], rebate);
            }
        }
        return DoubleArray.ofUnsafe(res);
    }

    /**
     * Blends the surviving node adjacent to the barrier with the rebate.
     * <p>
     * The barrier is bracketed by {@code stateValue[index]} and {@code stateValue[index + 1]}. The
     * surviving node is {@code index + 1} for a down barrier and {@code index} for an up barrier.
     * The result is the equal-weight average of {@code nodeValue} and
     * {@code (rebateWeight * rebate + nodeWeight * nodeValue) / (stateValue[index + 1] - stateValue[index])},
     * where:
     * <ul>
     * <li>{@code rebateWeight} is the distance from the knocked-out neighbour to the barrier;</li>
     * <li>{@code nodeWeight} is the distance from the barrier to the surviving node.</li>
     * </ul>
     * The closer the barrier lies to the surviving node, the more of the rebate that node receives.
     *
     * @param stateValue  the node values of the underlying
     * @param index  the index of the node just below (or at) the barrier
     * @param barrierLevel  the barrier level
     * @param isDown  whether the barrier is a down barrier
     * @param nodeValue  the unadjusted option value at the surviving node
     * @param rebate  the rebate
     * @return the adjusted option value at the surviving node
     */
    private static double smoothNearBarrier(
            DoubleArray stateValue,
            int index,
            double barrierLevel,
            boolean isDown,
            double nodeValue,
            double rebate) {
        double bd = barrierLevel - stateValue.get(index);
        double ub = stateValue.get(index + 1) - barrierLevel;
        double ud = stateValue.get(index + 1) - stateValue.get(index);
        double rebateWeight = isDown ? bd : ub;
        double nodeWeight = isDown ? ub : bd;
        return 0.5 * nodeValue + 0.5 * (rebateWeight * rebate + nodeWeight * nodeValue) / ud;
    }

    /**
     * Finds the index of the largest element of {@code set} that is not greater than {@code value}.
     * <p>
     * {@code set} is assumed sorted ascending and non-empty.
     *
     * @param set  the sorted values to search
     * @param value  the value to locate
     * @return -1 if {@code value} is below the first element, {@code n - 1} if it is above the last,
     *   the index of an exact match if there is one, otherwise the index of the element just below {@code value}
     */
    private int getLowerBoundIndex(DoubleArray set, double value) {
        int n = set.size();
        if (value < set.get(0)) {
            return -1;
        }
        if (value > set.get(n - 1)) {
            return n - 1;
        }
        int index = Arrays.binarySearch(set.toArrayUnsafe(), value);
        if (index >= 0) {
            // exact match
            return index;
        }
        // The insertion point is the first element greater than value; the lower bound is the one before it.
        // This step must run unconditionally, for every non-matching value.
        index = -(index + 1) - 1;
        // Arrays.binarySearch orders -0.0 before 0.0, so a signed-zero key can land just below a numerically
        // equal zero entry; step onto it. (value == 0.0 is true for both signed zeros.)
        if (value == 0.0 && index < n - 1 && set.get(index + 1) == 0.0) {
            ++index;
        }
        return index;
    }
}

