/*
 * Decompiled with CFR 0.152.
 */
package com.opengamma.strata.pricer.impl.volatility.smile;

import com.opengamma.strata.basics.value.ValueDerivatives;
import com.opengamma.strata.collect.ArgChecker;
import com.opengamma.strata.collect.array.DoubleArray;
import com.opengamma.strata.market.ValueType;
import com.opengamma.strata.math.MathUtils;
import com.opengamma.strata.pricer.model.SabrVolatilityFormula;

/**
 * Hagan et al. SABR approximation of the normal implied volatility.
 * <p>
 * Two variants are provided:
 * <ul>
 * <li>a general formula for non-zero beta. It works with {@code log(forward / strike)} and powers of
 *   {@code forward * strike}, and therefore requires positive forward and strike;
 * <li>a dedicated beta = 0 formula expressed in terms of {@code forward - strike}. It places no sign
 *   restriction on forward or strike.
 * </ul>
 * The adjoint methods compute the volatility together with its first-order derivatives by reverse-mode
 * algorithmic differentiation. Each {@code xxxBar} local holds the derivative of the volatility with respect
 * to the intermediate value {@code xxx}.
 */
public final class SabrHaganNormalVolatilityFormula implements SabrVolatilityFormula {

    /** The default instance. */
    public static final SabrHaganNormalVolatilityFormula DEFAULT = new SabrHaganNormalVolatilityFormula();
    /** Below this absolute value of zeta, zeta / xhat(zeta) is replaced by its first-order expansion. */
    private static final double SMALL_Z = 1.0E-6;
    /** Minimal distance required between |rho| and 1. */
    private static final double RHO_EPS = 1.0E-5;

    private SabrHaganNormalVolatilityFormula() {
    }

    @Override
    public ValueType getVolatilityType() {
        return ValueType.NORMAL_VOLATILITY;
    }

    /**
     * Computes the normal volatility.
     * <p>
     * Dispatches to {@link #volatilityBeta0} when {@code beta == 0} and to {@link #volatilityBetaNonZero} otherwise.
     */
    @Override
    public double volatility(double forward, double strike, double timeToExpiry, double alpha, double beta, double rho, double nu) {
        if (beta == 0.0) {
            return volatilityBeta0(forward, strike, timeToExpiry, alpha, rho, nu);
        }
        return volatilityBetaNonZero(forward, strike, timeToExpiry, alpha, beta, rho, nu);
    }

    /**
     * Computes the normal volatility and its first-order derivatives.
     * <p>
     * The derivatives are always returned in the order: forward, strike, alpha, beta, rho, nu.
     * When {@code beta == 0}, the dedicated beta = 0 formula is used. That formula does not depend on beta,
     * so the beta derivative (index 3) is reported as zero. This keeps the indices of rho and nu identical
     * in both branches.
     */
    @Override
    public ValueDerivatives volatilityAdjoint(double forward, double strike, double timeToExpiry, double alpha, double beta, double rho, double nu) {
        if (beta == 0.0) {
            // volatilityBeta0Adjoint has no beta entry: re-layout to the 6-entry order of the non-zero branch
            ValueDerivatives beta0 = volatilityBeta0Adjoint(forward, strike, timeToExpiry, alpha, rho, nu);
            return ValueDerivatives.of(
                beta0.getValue(),
                DoubleArray.of(
                    beta0.getDerivative(0),
                    beta0.getDerivative(1),
                    beta0.getDerivative(2),
                    0.0,
                    beta0.getDerivative(3),
                    beta0.getDerivative(4)));
        }
        return volatilityBetaNonZeroAdjoint(forward, strike, timeToExpiry, alpha, beta, rho, nu);
    }

    /**
     * Computes the normal volatility with the general (non-zero beta) Hagan formula.
     *
     * @param forward  the forward, must be positive
     * @param strike  the strike, must be positive
     * @param timeToExpiry  the time to expiry
     * @param alpha  the SABR alpha
     * @param beta  the SABR beta
     * @param rho  the SABR rho, must lie strictly between {@code -1 + 1e-5} and {@code 1 - 1e-5}
     * @param nu  the SABR nu
     * @return the normal volatility
     * @throws IllegalArgumentException if forward, strike or rho violate the constraints above
     */
    public double volatilityBetaNonZero(double forward, double strike, double timeToExpiry, double alpha, double beta, double rho, double nu) {
        checkForwardAndStrikePositive(forward, strike);
        checkRho(rho);
        double logfK = Math.log(forward / strike);
        double logfK2 = logfK * logfK;
        double logfK4 = logfK2 * logfK2;
        double fK = forward * strike;
        double oneminusbeta = 1.0 - beta;
        double fKoneminusbeta = Math.pow(fK, oneminusbeta);
        double oneminusbeta2 = oneminusbeta * oneminusbeta;
        double oneminusbeta4 = oneminusbeta2 * oneminusbeta2;
        double zeta = nu / alpha * Math.pow(fK, 0.5 * oneminusbeta) * logfK;
        double term1 = alpha * Math.pow(fK, 0.5 * beta);
        double term2 = (1.0 + logfK2 / 24.0 + logfK4 / 1920.0)
            / (1.0 + oneminusbeta2 * logfK2 / 24.0 + oneminusbeta4 * logfK4 / 1920.0);
        double term3 = zetaOverXhat(zeta, rho);
        double term4 = 1.0 + (-beta * (2.0 - beta) * alpha * alpha / (24.0 * fKoneminusbeta)
            + rho * alpha * nu * beta / (4.0 * Math.sqrt(fKoneminusbeta))
            + (2.0 - 3.0 * rho * rho) / 24.0 * nu * nu) * timeToExpiry;
        return term1 * term2 * term3 * term4;
    }

    /**
     * Computes the normal volatility with the general (non-zero beta) Hagan formula, and its derivatives.
     * <p>
     * The derivatives are returned in the order: forward, strike, alpha, beta, rho, nu.
     *
     * @param forward  the forward, must be positive
     * @param strike  the strike, must be positive
     * @param timeToExpiry  the time to expiry
     * @param alpha  the SABR alpha
     * @param beta  the SABR beta
     * @param rho  the SABR rho, must lie strictly between {@code -1 + 1e-5} and {@code 1 - 1e-5}
     * @param nu  the SABR nu
     * @return the volatility and its 6 derivatives
     * @throws IllegalArgumentException if forward, strike or rho violate the constraints above
     */
    public ValueDerivatives volatilityBetaNonZeroAdjoint(double forward, double strike, double timeToExpiry, double alpha, double beta, double rho, double nu) {
        checkForwardAndStrikePositive(forward, strike);
        checkRho(rho);
        // Forward sweep: same intermediates as volatilityBetaNonZero
        double logfK = Math.log(forward / strike);
        double logfK2 = logfK * logfK;
        double logfK4 = logfK2 * logfK2;
        double fK = forward * strike;
        double oneminusbeta = 1.0 - beta;
        double fKoneminusbeta = Math.pow(fK, oneminusbeta);
        double oneminusbeta2 = oneminusbeta * oneminusbeta;
        double oneminusbeta4 = oneminusbeta2 * oneminusbeta2;
        double zeta = nu / alpha * Math.pow(fK, 0.5 * oneminusbeta) * logfK;
        double term1 = alpha * Math.pow(fK, 0.5 * beta);
        double term2Num = 1.0 + logfK2 / 24.0 + logfK4 / 1920.0;
        double term2Den = 1.0 + oneminusbeta2 * logfK2 / 24.0 + oneminusbeta4 * logfK4 / 1920.0;
        double term2 = term2Num / term2Den;
        ValueDerivatives term3 = zetaOverXhatAdjoint(zeta, rho);
        double term3Value = term3.getValue();
        double term4 = 1.0 + (-beta * (2.0 - beta) * alpha * alpha / (24.0 * fKoneminusbeta)
            + rho * alpha * nu * beta / (4.0 * Math.sqrt(fKoneminusbeta))
            + (2.0 - 3.0 * rho * rho) / 24.0 * nu * nu) * timeToExpiry;
        double volatility = term1 * term2 * term3Value * term4;

        // Backward sweep: volatility = term1 * term2 * term3 * term4
        double term1Bar = term2 * term3Value * term4;
        double term2Bar = term1 * term3Value * term4;
        double term3Bar = term1 * term2 * term4;
        double term4Bar = term1 * term2 * term3Value;

        // term4: explicit dependence on alpha, beta, rho, nu, plus dependence on fKoneminusbeta = fK^(1 - beta)
        double betaBar = ((-2.0 + 2.0 * beta) * alpha * alpha / (24.0 * fKoneminusbeta)
            + rho * alpha * nu / (4.0 * Math.sqrt(fKoneminusbeta))) * timeToExpiry * term4Bar;
        double alphaBar = -beta * (2.0 - beta) * alpha / (12.0 * fKoneminusbeta) * timeToExpiry * term4Bar;
        alphaBar += rho * nu * beta / (4.0 * Math.sqrt(fKoneminusbeta)) * timeToExpiry * term4Bar;
        double fKoneminusbetaBar = beta * (2.0 - beta) * alpha * alpha / (24.0 * fKoneminusbeta * fKoneminusbeta)
            * timeToExpiry * term4Bar;
        fKoneminusbetaBar += -0.5 * rho * alpha * nu * beta / (4.0 * Math.pow(fKoneminusbeta, 1.5))
            * timeToExpiry * term4Bar;
        double rhoBar = alpha * nu * beta / (4.0 * Math.sqrt(fKoneminusbeta)) * timeToExpiry * term4Bar;
        rhoBar += -6.0 * rho / 24.0 * nu * nu * timeToExpiry * term4Bar;
        double nuBar = rho * alpha * beta / (4.0 * Math.sqrt(fKoneminusbeta)) * timeToExpiry * term4Bar;
        nuBar += (2.0 - 3.0 * rho * rho) / 12.0 * nu * timeToExpiry * term4Bar;

        // term3 = zeta / xhat(zeta, rho)
        double zetaBar = term3.getDerivative(0) * term3Bar;
        rhoBar += term3.getDerivative(1) * term3Bar;

        // term2 = term2Num / term2Den
        double term2NumBar = 1.0 / term2Den * term2Bar;
        double term2DenBar = -term2Num / (term2Den * term2Den) * term2Bar;
        double oneminusbeta2Bar = logfK2 / 24.0 * term2DenBar;
        double logfK2Bar = oneminusbeta2 / 24.0 * term2DenBar;
        double oneminusbeta4Bar = logfK4 / 1920.0 * term2DenBar;
        double logfK4Bar = oneminusbeta4 / 1920.0 * term2DenBar;
        logfK2Bar += 1.0 / 24.0 * term2NumBar;
        logfK4Bar += 1.0 / 1920.0 * term2NumBar;

        // term1 = alpha * fK^(beta / 2)
        alphaBar += Math.pow(fK, 0.5 * beta) * term1Bar;
        double fKBar = 0.5 * beta * alpha * Math.pow(fK, 0.5 * beta - 1.0) * term1Bar;
        betaBar += alpha * Math.pow(fK, 0.5 * beta) * 0.5 * Math.log(fK) * term1Bar;

        // zeta = nu / alpha * fK^((1 - beta) / 2) * logfK
        nuBar += 1.0 / alpha * Math.pow(fK, 0.5 * oneminusbeta) * logfK * zetaBar;
        alphaBar += -nu / (alpha * alpha) * Math.pow(fK, 0.5 * oneminusbeta) * logfK * zetaBar;
        fKBar += 0.5 * oneminusbeta * nu / alpha * Math.pow(fK, 0.5 * oneminusbeta - 1.0) * logfK * zetaBar;
        double oneminusbetaBar = nu / alpha * Math.pow(fK, 0.5 * oneminusbeta) * 0.5 * Math.log(fK) * logfK * zetaBar;
        double logfKBar = nu / alpha * Math.pow(fK, 0.5 * oneminusbeta) * zetaBar;

        // oneminusbeta4 = oneminusbeta2^2 and oneminusbeta2 = oneminusbeta^2
        oneminusbeta2Bar += 2.0 * oneminusbeta2 * oneminusbeta4Bar;
        oneminusbetaBar += 2.0 * oneminusbeta * oneminusbeta2Bar;

        // fKoneminusbeta = fK^(1 - beta)
        fKBar += oneminusbeta * Math.pow(fK, oneminusbeta - 1.0) * fKoneminusbetaBar;
        oneminusbetaBar += Math.pow(fK, oneminusbeta) * Math.log(fK) * fKoneminusbetaBar;

        // fK = forward * strike
        double forwardBar = strike * fKBar;
        double strikeBar = forward * fKBar;

        // logfK4 = logfK2^2, logfK2 = logfK^2, logfK = log(forward / strike)
        logfK2Bar += 2.0 * logfK2 * logfK4Bar;
        logfKBar += 2.0 * logfK * logfK2Bar;
        forwardBar += 1.0 / forward * logfKBar;
        strikeBar += -1.0 / strike * logfKBar;

        // oneminusbeta = 1 - beta
        betaBar += -oneminusbetaBar;

        return ValueDerivatives.of(volatility, DoubleArray.of(forwardBar, strikeBar, alphaBar, betaBar, rhoBar, nuBar));
    }

    /**
     * Computes the normal volatility with the beta = 0 formula.
     * <p>
     * Forward and strike are only used through {@code forward - strike}, so they are not required to be positive.
     *
     * @param forward  the forward
     * @param strike  the strike
     * @param timeToExpiry  the time to expiry
     * @param alpha  the SABR alpha
     * @param rho  the SABR rho, must lie strictly between {@code -1 + 1e-5} and {@code 1 - 1e-5}
     * @param nu  the SABR nu
     * @return the normal volatility
     * @throws IllegalArgumentException if rho violates the constraint above
     */
    public double volatilityBeta0(double forward, double strike, double timeToExpiry, double alpha, double rho, double nu) {
        checkRho(rho);
        double zeta = nu / alpha * (forward - strike);
        double term3 = zetaOverXhat(zeta, rho);
        double term4 = 1.0 + (2.0 - 3.0 * rho * rho) / 24.0 * nu * nu * timeToExpiry;
        return alpha * term3 * term4;
    }

    /**
     * Computes the normal volatility with the beta = 0 formula, and its derivatives.
     * <p>
     * The derivatives are returned in the order: forward, strike, alpha, rho, nu (5 entries, no beta entry).
     * {@link #volatilityAdjoint} converts this to the 6-entry layout.
     *
     * @param forward  the forward
     * @param strike  the strike
     * @param timeToExpiry  the time to expiry
     * @param alpha  the SABR alpha
     * @param rho  the SABR rho, must lie strictly between {@code -1 + 1e-5} and {@code 1 - 1e-5}
     * @param nu  the SABR nu
     * @return the volatility and its 5 derivatives
     * @throws IllegalArgumentException if rho violates the constraint above
     */
    public ValueDerivatives volatilityBeta0Adjoint(double forward, double strike, double timeToExpiry, double alpha, double rho, double nu) {
        checkRho(rho);
        // Forward sweep
        double zeta = nu / alpha * (forward - strike);
        ValueDerivatives term3 = zetaOverXhatAdjoint(zeta, rho);
        double term3Value = term3.getValue();
        double term4 = 1.0 + (2.0 - 3.0 * rho * rho) / 24.0 * nu * nu * timeToExpiry;
        double volatility = alpha * term3Value * term4;
        // Backward sweep: volatility = alpha * term3 * term4
        double alphaBar = term3Value * term4;
        double term3Bar = alpha * term4;
        double term4Bar = alpha * term3Value;
        double rhoBar = -0.25 * rho * nu * nu * timeToExpiry * term4Bar;
        double nuBar = (2.0 - 3.0 * rho * rho) / 12.0 * nu * timeToExpiry * term4Bar;
        double zetaBar = term3.getDerivative(0) * term3Bar;
        rhoBar += term3.getDerivative(1) * term3Bar;
        // zeta = nu / alpha * (forward - strike)
        double forwardBar = nu / alpha * zetaBar;
        double strikeBar = -nu / alpha * zetaBar;
        alphaBar += -nu / (alpha * alpha) * (forward - strike) * zetaBar;
        nuBar += 1.0 / alpha * (forward - strike) * zetaBar;
        return ValueDerivatives.of(volatility, DoubleArray.of(forwardBar, strikeBar, alphaBar, rhoBar, nuBar));
    }

    /**
     * Computes zeta / xhat(zeta), where xhat(zeta) = log((sqrt(1 - 2 rho zeta + zeta^2) - rho + zeta) / (1 - rho)).
     * <p>
     * For |zeta| within {@code SMALL_Z} of zero, the first-order expansion 1 - rho * zeta / 2 is used.
     * This avoids the 0/0 form.
     *
     * @param zeta  the zeta value
     * @param rho  the SABR rho
     * @return zeta / xhat(zeta)
     */
    protected double zetaOverXhat(double zeta, double rho) {
        if (MathUtils.nearZero(zeta, SMALL_Z)) {
            return 1.0 - rho * zeta / 2.0;
        }
        double c0 = 1.0 - 2.0 * rho * zeta + zeta * zeta;
        double c1 = Math.sqrt(c0) - rho + zeta;
        double c2 = 1.0 - rho;
        double c3 = c1 / c2;
        double xhat = Math.log(c3);
        return zeta / xhat;
    }

    /**
     * Computes zeta / xhat(zeta) and its derivatives with respect to zeta and rho, in that order.
     * <p>
     * The small-zeta expansion is used under the same condition as {@link #zetaOverXhat}.
     *
     * @param zeta  the zeta value
     * @param rho  the SABR rho
     * @return the value and its 2 derivatives (zeta, rho)
     */
    protected ValueDerivatives zetaOverXhatAdjoint(double zeta, double rho) {
        if (MathUtils.nearZero(zeta, SMALL_Z)) {
            double zetaOverXhat = 1.0 - 0.5 * rho * zeta;
            double rhoBar = -0.5 * zeta;
            double zetaBar = -0.5 * rho;
            return ValueDerivatives.of(zetaOverXhat, DoubleArray.of(zetaBar, rhoBar));
        }
        // Forward sweep
        double c0 = 1.0 - 2.0 * rho * zeta + zeta * zeta;
        double c1 = Math.sqrt(c0) - rho + zeta;
        double c2 = 1.0 - rho;
        double c3 = c1 / c2;
        double xhat = Math.log(c3);
        double zetaOverXhat = zeta / xhat;
        // Backward sweep
        double zetaBar = 1.0 / xhat;
        double xhatBar = -zeta / (xhat * xhat);
        double c3Bar = 1.0 / c3 * xhatBar;
        double c1Bar = 1.0 / c2 * c3Bar;
        double c2Bar = -c1 / (c2 * c2) * c3Bar;
        double rhoBar = -c2Bar;
        rhoBar += -c1Bar;
        zetaBar += c1Bar;
        double c0Bar = 0.5 / Math.sqrt(c0) * c1Bar;
        zetaBar += (-2.0 * rho + 2.0 * zeta) * c0Bar;
        rhoBar += -2.0 * zeta * c0Bar;
        return ValueDerivatives.of(zetaOverXhat, DoubleArray.of(zetaBar, rhoBar));
    }

    // Checks the positivity required by log(forward / strike) in the non-zero beta formula.
    private static void checkForwardAndStrikePositive(double forward, double strike) {
        ArgChecker.isTrue(forward > 0.0, "forward must be positive");
        ArgChecker.isTrue(strike > 0.0, "strike must be positive");
    }

    // Checks that rho lies strictly inside (-1 + RHO_EPS, 1 - RHO_EPS).
    private static void checkRho(double rho) {
        ArgChecker.isTrue(rho < 1.0 - RHO_EPS, "rho must be below 1 and not too close to 1");
        ArgChecker.isTrue(rho > -1.0 + RHO_EPS, "rho must be above -1 and not too close to -1");
    }
}

