/*
 * Decompiled with CFR 0.152.
 */
package com.opengamma.strata.pricer.impl.option;

import com.opengamma.strata.basics.value.ValueDerivatives;
import com.opengamma.strata.collect.ArgChecker;
import com.opengamma.strata.collect.array.DoubleArray;
import com.opengamma.strata.collect.array.DoubleMatrix;
import com.opengamma.strata.collect.tuple.Pair;
import com.opengamma.strata.math.impl.linearalgebra.SVDecompositionCommons;
import com.opengamma.strata.math.impl.linearalgebra.SVDecompositionResult;
import com.opengamma.strata.math.impl.rootfinding.BracketRoot;
import com.opengamma.strata.math.impl.rootfinding.RidderSingleRootFinder;
import com.opengamma.strata.pricer.impl.option.BlackFormulaRepository;
import com.opengamma.strata.pricer.impl.volatility.smile.SabrFormulaData;
import com.opengamma.strata.pricer.impl.volatility.smile.SabrHaganVolatilityFunctionProvider;
import com.opengamma.strata.pricer.impl.volatility.smile.VolatilityFunctionProvider;
import com.opengamma.strata.product.common.PutCall;
import java.util.Arrays;
import java.util.function.Function;

/**
 * Option pricing with a SABR volatility function up to a cut-off strike and an
 * extrapolation of the call price above it.
 * <p>
 * For strikes at or below the cut-off strike the price is the Black price computed with the
 * volatility returned by the volatility function. For strikes above the cut-off the call price is
 * {@code K^(-mu) * exp(a + b / K + c / K^2)}, where {@code (a, b, c)} are fitted at construction
 * from the call price at the cut-off strike and its first and second strike derivatives.
 * Put prices above the cut-off are derived from the call price by call/put parity.
 * <p>
 * The sensitivities of {@code (a, b, c)} to the forward and to the SABR parameters are computed
 * lazily on first use and then cached.
 */
public final class SabrExtrapolationRightFunction {
    /** Decomposition used to solve the 3x3 linear systems for the parameter sensitivities. */
    private static final SVDecompositionCommons SVD = new SVDecompositionCommons();
    /** Time to expiry at or below which no fit is attempted and the extrapolated price is forced to be tiny. */
    private static final double SMALL_EXPIRY = 1.0E-6;
    /** Value of the first extrapolation parameter used when the expiry is too small to fit. */
    private static final double SMALL_PARAMETER = -10000.0;
    /** Threshold below which the price and its strike derivatives at the cut-off are treated as zero. */
    private static final double SMALL_PRICE = 1.0E-15;
    /** Number of SABR parameters; they are handled in the order alpha, beta, rho, nu. */
    private static final int SABR_PARAMETER_COUNT = 4;

    private final VolatilityFunctionProvider<SabrFormulaData> sabrFunction;
    private final double forward;
    private final double timeToExpiry;
    private final SabrFormulaData sabrData;
    private final double cutOffStrike;
    private final double mu;
    /** Extrapolation parameters {@code (a, b, c)}. */
    private final double[] parameter;
    /** Derivatives of {@code (a, b, c)} with respect to the forward; {@code null} until first needed. */
    private volatile double[] parameterDerivativeForward;
    /** Derivatives of {@code (a, b, c)} with respect to each SABR parameter; {@code null} until first needed. */
    private volatile double[][] parameterDerivativeSabr;
    /** Volatility at the cut-off strike, set while fitting the parameters. */
    private volatile double volatilityK;
    /** Call price at the cut-off strike followed by its first and second strike derivatives. */
    private final double[] priceK = new double[3];

    /**
     * Creates an instance using {@link SabrHaganVolatilityFunctionProvider#DEFAULT} as the volatility function.
     * <p>
     * Note that the argument order differs from the other factory: here {@code timeToExpiry} is the second argument.
     *
     * @param forward  the forward
     * @param timeToExpiry  the time to expiry
     * @param sabrData  the SABR parameters, not null
     * @param cutOffStrike  the strike above which the extrapolation is used
     * @param mu  the exponent of the {@code K^(-mu)} factor in the extrapolation
     * @return the function
     */
    public static SabrExtrapolationRightFunction of(double forward, double timeToExpiry, SabrFormulaData sabrData, double cutOffStrike, double mu) {
        return new SabrExtrapolationRightFunction(forward, sabrData, cutOffStrike, timeToExpiry, mu, SabrHaganVolatilityFunctionProvider.DEFAULT);
    }

    /**
     * Creates an instance with an explicit volatility function.
     *
     * @param forward  the forward
     * @param sabrData  the SABR parameters, not null
     * @param cutOffStrike  the strike above which the extrapolation is used
     * @param timeToExpiry  the time to expiry
     * @param mu  the exponent of the {@code K^(-mu)} factor in the extrapolation
     * @param volatilityFunction  the volatility function used at and below the cut-off strike, not null
     * @return the function
     */
    public static SabrExtrapolationRightFunction of(double forward, SabrFormulaData sabrData, double cutOffStrike, double timeToExpiry, double mu, VolatilityFunctionProvider<SabrFormulaData> volatilityFunction) {
        return new SabrExtrapolationRightFunction(forward, sabrData, cutOffStrike, timeToExpiry, mu, volatilityFunction);
    }

    private SabrExtrapolationRightFunction(double forward, SabrFormulaData sabrData, double cutOffStrike, double timeToExpiry, double mu, VolatilityFunctionProvider<SabrFormulaData> volatilityFunction) {
        ArgChecker.notNull(sabrData, "sabrData");
        ArgChecker.notNull(volatilityFunction, "volatilityFunction");
        this.sabrFunction = volatilityFunction;
        this.forward = forward;
        this.sabrData = sabrData;
        this.cutOffStrike = cutOffStrike;
        this.timeToExpiry = timeToExpiry;
        this.mu = mu;
        if (timeToExpiry > SMALL_EXPIRY) {
            this.parameter = computesFittingParameters();
        } else {
            // Expiry too small to fit: exp(SMALL_PARAMETER) makes the extrapolated price negligible,
            // and all parameter sensitivities are fixed at zero so they are never computed lazily.
            this.parameter = new double[] {SMALL_PARAMETER, 0.0, 0.0};
            this.parameterDerivativeForward = new double[3];
            this.parameterDerivativeSabr = new double[SABR_PARAMETER_COUNT][3];
        }
    }

    /**
     * Computes the option price.
     *
     * @param strike  the strike
     * @param putCall  whether the option is a put or a call
     * @return the price
     */
    public double price(double strike, PutCall putCall) {
        if (strike <= cutOffStrike) {
            double volatility = sabrFunction.volatility(forward, strike, timeToExpiry, sabrData);
            return BlackFormulaRepository.price(forward, strike, timeToExpiry, volatility, putCall.isCall());
        }
        // The extrapolation gives the call price; the put follows from call/put parity.
        double price = extrapolation(strike);
        if (putCall.isPut()) {
            price -= forward - strike;
        }
        return price;
    }

    /**
     * Computes the derivative of the option price with respect to the strike.
     *
     * @param strike  the strike
     * @param putCall  whether the option is a put or a call
     * @return the derivative of the price with respect to the strike
     */
    public double priceDerivativeStrike(double strike, PutCall putCall) {
        if (strike <= cutOffStrike) {
            ValueDerivatives volatilityAdjoint = sabrFunction.volatilityAdjoint(forward, strike, timeToExpiry, sabrData);
            ValueDerivatives bsAdjoint = BlackFormulaRepository.priceAdjoint(forward, strike, timeToExpiry, volatilityAdjoint.getValue(), putCall.isCall());
            // Direct strike sensitivity plus the sensitivity through the strike-dependent volatility.
            return bsAdjoint.getDerivative(1) + bsAdjoint.getDerivative(3) * volatilityAdjoint.getDerivative(1);
        }
        double pDK = extrapolationDerivative(strike);
        if (putCall.isPut()) {
            // Parity term (strike - forward) contributes +1.
            pDK += 1.0;
        }
        return pDK;
    }

    /**
     * Computes the derivative of the option price with respect to the forward.
     *
     * @param strike  the strike
     * @param putCall  whether the option is a put or a call
     * @return the derivative of the price with respect to the forward
     */
    public double priceDerivativeForward(double strike, PutCall putCall) {
        if (strike <= cutOffStrike) {
            ValueDerivatives volatilityA = sabrFunction.volatilityAdjoint(forward, strike, timeToExpiry, sabrData);
            ValueDerivatives pA = BlackFormulaRepository.priceAdjoint(forward, strike, timeToExpiry, volatilityA.getValue(), putCall.isCall());
            // Direct forward sensitivity plus the sensitivity through the forward-dependent volatility.
            return pA.getDerivative(0) + pA.getDerivative(3) * volatilityA.getDerivative(0);
        }
        double[] parameterDF = getParameterDerivativeForward();
        // Derivatives of the extrapolated price with respect to a, b and c are f, f/K and f/K^2.
        double f = extrapolation(strike);
        double fDb = f / strike;
        double fDc = fDb / strike;
        double priceDerivative = f * parameterDF[0] + fDb * parameterDF[1] + fDc * parameterDF[2];
        if (putCall.isPut()) {
            // Parity term (strike - forward) contributes -1.
            priceDerivative -= 1.0;
        }
        return priceDerivative;
    }

    /**
     * Computes the option price and its derivatives with respect to the SABR parameters.
     *
     * @param strike  the strike
     * @param putCall  whether the option is a put or a call
     * @return the price together with four derivatives, in the order alpha, beta, rho, nu
     */
    public ValueDerivatives priceAdjointSabr(double strike, PutCall putCall) {
        double price;
        double[] priceDerivativeSabr = new double[SABR_PARAMETER_COUNT];
        if (strike <= cutOffStrike) {
            ValueDerivatives volatilityA = sabrFunction.volatilityAdjoint(forward, strike, timeToExpiry, sabrData);
            ValueDerivatives pA = BlackFormulaRepository.priceAdjoint(forward, strike, timeToExpiry, volatilityA.getValue(), putCall.isCall());
            price = pA.getValue();
            for (int i = 0; i < SABR_PARAMETER_COUNT; ++i) {
                // The volatility derivatives for the SABR parameters start at index 2.
                priceDerivativeSabr[i] = pA.getDerivative(3) * volatilityA.getDerivative(i + 2);
            }
        } else {
            double[][] parameterDSabr = getParameterDerivativeSabr();
            // Derivatives of the extrapolated price with respect to a, b and c are f, f/K and f/K^2.
            double f = extrapolation(strike);
            double fDb = f / strike;
            double fDc = fDb / strike;
            // The SABR sensitivities are the same for put and call: the parity term has no SABR dependency.
            price = putCall.isCall() ? f : f - forward + strike;
            for (int i = 0; i < SABR_PARAMETER_COUNT; ++i) {
                priceDerivativeSabr[i] = f * parameterDSabr[i][0] + fDb * parameterDSabr[i][1] + fDc * parameterDSabr[i][2];
            }
        }
        return ValueDerivatives.of(price, DoubleArray.ofUnsafe(priceDerivativeSabr));
    }

    /**
     * Gets the SABR parameters.
     *
     * @return the SABR parameters
     */
    public SabrFormulaData getSabrData() {
        return sabrData;
    }

    /**
     * Gets the cut-off strike above which the extrapolation is used.
     *
     * @return the cut-off strike
     */
    public double getCutOffStrike() {
        return cutOffStrike;
    }

    /**
     * Gets the exponent of the {@code K^(-mu)} factor in the extrapolation.
     *
     * @return the mu parameter
     */
    public double getMu() {
        return mu;
    }

    /**
     * Gets the time to expiry.
     *
     * @return the time to expiry
     */
    public double getTimeToExpiry() {
        return timeToExpiry;
    }

    /**
     * Gets the three extrapolation parameters {@code (a, b, c)}.
     * <p>
     * The internal array is returned, not a copy; callers must not modify it.
     *
     * @return the extrapolation parameters
     */
    public double[] getParameter() {
        return parameter;
    }

    /**
     * Gets the derivatives of the three extrapolation parameters with respect to the forward,
     * computing them on first call.
     * <p>
     * The internal array is returned, not a copy; callers must not modify it.
     *
     * @return the derivatives of {@code (a, b, c)} with respect to the forward
     */
    public double[] getParameterDerivativeForward() {
        if (parameterDerivativeForward == null) {
            parameterDerivativeForward = computesParametersDerivativeForward();
        }
        return parameterDerivativeForward;
    }

    /**
     * Gets the derivatives of the three extrapolation parameters with respect to the SABR parameters,
     * computing them on first call.
     * <p>
     * The first index is the SABR parameter (alpha, beta, rho, nu), the second the extrapolation
     * parameter. The internal array is returned, not a copy; callers must not modify it.
     *
     * @return the derivatives of {@code (a, b, c)} with respect to the SABR parameters
     */
    public double[][] getParameterDerivativeSabr() {
        if (parameterDerivativeSabr == null) {
            parameterDerivativeSabr = computesParametersDerivativeSabr();
        }
        return parameterDerivativeSabr;
    }

    /**
     * Fits the extrapolation parameters to the call price at the cut-off strike and its first and
     * second strike derivatives. Also fills {@code volatilityK} and {@code priceK} as a side effect.
     *
     * @return the parameters {@code (a, b, c)}
     */
    private double[] computesFittingParameters() {
        // NOTE(review): index conventions are inferred from usage in this class - vD[0]/vD[1] are the
        // forward/strike derivatives of the volatility, bsD[1]/bsD[3] the strike/volatility derivatives of
        // the Black price, and bsD2 is indexed 0 = forward, 1 = strike, 2 = volatility. Confirm against
        // volatilityAdjoint2 and priceAdjoint2.
        double[] vD = new double[6];
        double[][] vD2 = new double[2][2];
        this.volatilityK = sabrFunction.volatilityAdjoint2(forward, cutOffStrike, timeToExpiry, sabrData, vD, vD2);
        Pair<ValueDerivatives, double[][]> pa2 = BlackFormulaRepository.priceAdjoint2(forward, cutOffStrike, timeToExpiry, volatilityK, true);
        double[] bsD = pa2.getFirst().getDerivatives().toArrayUnsafe();
        double[][] bsD2 = pa2.getSecond();
        priceK[0] = pa2.getFirst().getValue();
        priceK[1] = bsD[1] + bsD[3] * vD[1];
        priceK[2] = bsD2[1][1] + bsD2[1][2] * vD[1] + (bsD2[2][1] + bsD2[2][2] * vD[1]) * vD[1] + bsD[3] * vD2[1][1];
        if (isPriceNegligibleAtCutOff()) {
            // Nothing to fit: choose parameters that make the extrapolated price very small.
            return new double[] {-100.0, 0.0, 0.0};
        }
        // Solve for c first; b and a then follow in closed form from the price and its first derivative.
        Function<Double, Double> toSolveC = getCFunction(priceK, cutOffStrike, mu);
        BracketRoot bracketer = new BracketRoot();
        double accuracy = 1.0E-5;
        RidderSingleRootFinder rootFinder = new RidderSingleRootFinder(accuracy);
        double[] range = bracketer.getBracketedPoints(toSolveC, -1.0, 1.0);
        double[] param = new double[3];
        param[2] = rootFinder.getRoot(toSolveC, Double.valueOf(range[0]), Double.valueOf(range[1]));
        param[1] = -2.0 * param[2] / cutOffStrike - (priceK[1] / priceK[0] * cutOffStrike + mu) * cutOffStrike;
        param[0] = Math.log(priceK[0] / Math.pow(cutOffStrike, -mu)) - param[1] / cutOffStrike - param[2] / (cutOffStrike * cutOffStrike);
        return param;
    }

    /**
     * Computes the derivatives of the extrapolation parameters with respect to the forward.
     * <p>
     * The forward derivatives of the price and of its first two strike derivatives at the cut-off are
     * computed (third-order terms by one-sided finite difference) and mapped to parameter derivatives
     * by solving a 3x3 linear system.
     *
     * @return the derivatives of {@code (a, b, c)} with respect to the forward
     */
    private double[] computesParametersDerivativeForward() {
        if (isPriceNegligibleAtCutOff()) {
            return new double[] {0.0, 0.0, 0.0};
        }
        double[] pDF = new double[3];
        double shift = 1.0E-5;
        double volK = this.volatilityK;
        double strikeShift = cutOffStrike * shift;
        double volShift = volK * shift;
        double[] vD = new double[6];
        double[][] vD2 = new double[2][2];
        sabrFunction.volatilityAdjoint2(forward, cutOffStrike, timeToExpiry, sabrData, vD, vD2);
        Pair<ValueDerivatives, double[][]> pa2 = BlackFormulaRepository.priceAdjoint2(forward, cutOffStrike, timeToExpiry, volK, true);
        double[] bsD = pa2.getFirst().getDerivatives().toArrayUnsafe();
        double[][] bsD2 = pa2.getSecond();
        pDF[0] = bsD[0] + bsD[3] * vD[0];
        pDF[1] = bsD2[0][1] + bsD2[2][0] * vD[1] + (bsD2[1][2] + bsD2[2][2] * vD[1]) * vD[0] + bsD[3] * vD2[1][0];
        // Third-order Black derivatives by finite difference: strike bumped, then volatility bumped.
        double[][] bsD2KP = BlackFormulaRepository.priceAdjoint2(forward, cutOffStrike * (1.0 + shift), timeToExpiry, volK, true).getSecond();
        double bsD3FKK = (bsD2KP[1][0] - bsD2[1][0]) / strikeShift;
        double[][] bsD2VP = BlackFormulaRepository.priceAdjoint2(forward, cutOffStrike, timeToExpiry, volK * (1.0 + shift), true).getSecond();
        double bsD3sss = (bsD2VP[2][2] - bsD2[2][2]) / volShift;
        double bsD3sFK = (bsD2VP[0][1] - bsD2[0][1]) / volShift;
        double bsD3sFs = (bsD2VP[0][2] - bsD2[0][2]) / volShift;
        double bsD3sKK = (bsD2VP[1][1] - bsD2[1][1]) / volShift;
        double bsD3ssK = (bsD2VP[2][1] - bsD2[2][1]) / volShift;
        // Third-order volatility derivative by finite difference with the strike bumped.
        double[] vDKP = new double[6];
        double[][] vD2KP = new double[2][2];
        sabrFunction.volatilityAdjoint2(forward, cutOffStrike * (1.0 + shift), timeToExpiry, sabrData, vDKP, vD2KP);
        double vD3KKF = (vD2KP[1][0] - vD2[1][0]) / strikeShift;
        pDF[2] = bsD3FKK + bsD3sFK * vD[1] + (bsD3sFK + bsD3sFs * vD[1]) * vD[1] + bsD2[2][0] * vD2[1][1] + (bsD3sKK + bsD3ssK * vD[1] + (bsD3ssK + bsD3sss * vD[1]) * vD[1] + bsD2[2][2] * vD2[1][1]) * vD[0] + 2.0 * (bsD2[1][2] + bsD2[2][2] * vD[1]) * vD2[1][0] + bsD[3] * vD3KKF;
        return decomposeFittingMatrix().solve(pDF);
    }

    /**
     * Computes the derivatives of the extrapolation parameters with respect to the SABR parameters.
     * <p>
     * For each SABR parameter, the derivatives of the price and of its first two strike derivatives at
     * the cut-off are computed (higher-order terms by one-sided finite difference) and mapped to
     * parameter derivatives by solving a 3x3 linear system.
     *
     * @return the derivatives, first index the SABR parameter (alpha, beta, rho, nu), second index {@code (a, b, c)}
     */
    private double[][] computesParametersDerivativeSabr() {
        double[][] result = new double[SABR_PARAMETER_COUNT][3];
        if (isPriceNegligibleAtCutOff()) {
            return result;
        }
        double[][] pdSabr = new double[SABR_PARAMETER_COUNT][3];
        double shift = 1.0E-5;
        double volK = this.volatilityK;
        double volShift = volK * shift;
        double[] vD = new double[6];
        double[][] vD2 = new double[2][2];
        sabrFunction.volatilityAdjoint2(forward, cutOffStrike, timeToExpiry, sabrData, vD, vD2);

        // The Black quantities do not depend on which SABR parameter is bumped, so they are computed once.
        Pair<ValueDerivatives, double[][]> pa2 = BlackFormulaRepository.priceAdjoint2(forward, cutOffStrike, timeToExpiry, volK, true);
        double[] bsD = pa2.getFirst().getDerivatives().toArrayUnsafe();
        double[][] bsD2 = pa2.getSecond();
        double[][] bsD2VP = BlackFormulaRepository.priceAdjoint2(forward, cutOffStrike, timeToExpiry, volK * (1.0 + shift), true).getSecond();
        double bsD3sss = (bsD2VP[2][2] - bsD2[2][2]) / volShift;
        double bsD3sKK = (bsD2VP[1][1] - bsD2[1][1]) / volShift;
        double bsD3ssK = (bsD2VP[2][1] - bsD2[2][1]) / volShift;
        // Coefficients multiplying the volatility sensitivity in the first and second strike-derivative rows.
        double firstOrderVolFactor = bsD2[1][2] + bsD2[2][2] * vD[1];
        double secondOrderVolFactor = bsD3sKK + bsD3ssK * vD[1] + (bsD3ssK + bsD3sss * vD[1]) * vD[1] + bsD2[2][2] * vD2[1][1];
        double crossFactor = 2.0 * (bsD2[2][1] + bsD2[2][2] * vD[1]);

        for (int i = 0; i < SABR_PARAMETER_COUNT; ++i) {
            int paramIndex = 2 + i;
            // Bump one SABR parameter: alpha is bumped relatively, the others by an absolute shift.
            SabrFormulaData bumpedData;
            double paramShift;
            switch (i) {
                case 0: {
                    double alpha = sabrData.getAlpha();
                    paramShift = alpha * shift;
                    bumpedData = sabrData.withAlpha(alpha + paramShift);
                    break;
                }
                case 1: {
                    paramShift = shift;
                    bumpedData = sabrData.withBeta(sabrData.getBeta() + paramShift);
                    break;
                }
                case 2: {
                    paramShift = shift;
                    bumpedData = sabrData.withRho(sabrData.getRho() + paramShift);
                    break;
                }
                default: {
                    paramShift = shift;
                    bumpedData = sabrData.withNu(sabrData.getNu() + paramShift);
                }
            }
            double[] vDpP = new double[6];
            double[][] vD2pP = new double[2][2];
            sabrFunction.volatilityAdjoint2(forward, cutOffStrike, timeToExpiry, bumpedData, vDpP, vD2pP);
            // Finite-difference cross derivatives of the volatility: strike/parameter and strike/strike/parameter.
            double vD2Kp = (vDpP[1] - vD[1]) / paramShift;
            double vD3KKa = (vD2pP[1][1] - vD2[1][1]) / paramShift;
            pdSabr[i][0] = bsD[3] * vD[paramIndex];
            pdSabr[i][1] = firstOrderVolFactor * vD[paramIndex] + bsD[3] * vD2Kp;
            pdSabr[i][2] = secondOrderVolFactor * vD[paramIndex] + crossFactor * vD2Kp + bsD[3] * vD3KKa;
        }
        SVDecompositionResult decmp = decomposeFittingMatrix();
        for (int i = 0; i < SABR_PARAMETER_COUNT; ++i) {
            result[i] = decmp.solve(pdSabr[i]);
        }
        return result;
    }

    /**
     * Checks whether the call price at the cut-off and its first two strike derivatives are all
     * below {@link #SMALL_PRICE} in absolute value.
     *
     * @return true if all three values are negligible
     */
    private boolean isPriceNegligibleAtCutOff() {
        return Math.abs(priceK[0]) < SMALL_PRICE && Math.abs(priceK[1]) < SMALL_PRICE && Math.abs(priceK[2]) < SMALL_PRICE;
    }

    /**
     * Builds and decomposes the 3x3 matrix used to convert sensitivities of the price and its first
     * two strike derivatives at the cut-off into sensitivities of {@code (a, b, c)}.
     * <p>
     * The first row holds the derivatives of the extrapolated price with respect to a, b and c
     * ({@code f}, {@code f/K}, {@code f/K^2}); the remaining rows presumably hold the corresponding
     * derivatives of the first and second strike derivatives.
     *
     * @return the decomposition of the matrix
     */
    private SVDecompositionResult decomposeFittingMatrix() {
        double k = cutOffStrike;
        double k2 = k * k;
        double f = priceK[0];
        double fp = priceK[1];
        double fpp = priceK[2];
        double[][] fD = new double[3][3];
        fD[0][0] = f;
        fD[0][1] = f / k;
        fD[0][2] = fD[0][1] / k;
        fD[1][0] = fp;
        fD[1][1] = (fp - fD[0][1]) / k;
        fD[1][2] = (fp - 2.0 * fD[0][1]) / k2;
        fD[2][0] = fpp;
        fD[2][1] = (fpp + fD[0][2] * (2.0 * (mu + 1.0) + 2.0 * parameter[1] / k + 4.0 * parameter[2] / k2)) / k;
        fD[2][2] = (fpp + fD[0][2] * (2.0 * (2.0 * mu + 3.0) + 4.0 * parameter[1] / k + 8.0 * parameter[2] / k2)) / k2;
        return SVD.apply(DoubleMatrix.ofUnsafe(fD));
    }

    /**
     * Computes the extrapolated call price {@code K^(-mu) * exp(a + b / K + c / K^2)}.
     *
     * @param strike  the strike
     * @return the extrapolated call price
     */
    private double extrapolation(double strike) {
        return Math.pow(strike, -mu) * Math.exp(parameter[0] + parameter[1] / strike + parameter[2] / (strike * strike));
    }

    /**
     * Computes the strike derivative of the extrapolated call price.
     *
     * @param strike  the strike
     * @return the derivative of the extrapolated call price with respect to the strike
     */
    private double extrapolationDerivative(double strike) {
        return -extrapolation(strike) * (mu + (parameter[1] + 2.0 * parameter[2] / strike) / strike) / strike;
    }

    /**
     * Creates the function of {@code c} whose root is the third extrapolation parameter.
     * <p>
     * For a candidate {@code c}, {@code b} is eliminated using the price and its first strike
     * derivative; the returned value is the residual of the condition involving the second strike derivative.
     *
     * @param price  the price at the cut-off and its first and second strike derivatives; copied, not retained
     * @param cutOffStrike  the cut-off strike
     * @param mu  the exponent of the {@code K^(-mu)} factor
     * @return the function to solve for {@code c}
     */
    private static Function<Double, Double> getCFunction(double[] price, double cutOffStrike, double mu) {
        // Copy so that later changes to the caller's array cannot affect the returned function.
        double[] cPrice = Arrays.copyOf(price, price.length);
        return c -> {
            double b = -2.0 * c / cutOffStrike - (cPrice[1] / cPrice[0] * cutOffStrike + mu) * cutOffStrike;
            double k2 = cutOffStrike * cutOffStrike;
            return -cPrice[2] / cPrice[0] * k2 + mu * (mu + 1.0) + 2.0 * b * (mu + 1.0) / cutOffStrike + (2.0 * c * (2.0 * mu + 3.0) + b * b) / k2 + 4.0 * b * c / (k2 * cutOffStrike) + 4.0 * c * c / (k2 * k2);
        };
    }
}

