/*
 * Decompiled with CFR 0.152.
 */
package com.opengamma.strata.pricer.impl.volatility.local;

import com.opengamma.strata.basics.value.ValueDerivatives;
import com.opengamma.strata.collect.array.DoubleArray;
import com.opengamma.strata.collect.array.DoubleMatrix;
import com.opengamma.strata.collect.tuple.DoublesPair;
import com.opengamma.strata.market.ValueType;
import com.opengamma.strata.market.surface.DefaultSurfaceMetadata;
import com.opengamma.strata.market.surface.DeformedSurface;
import com.opengamma.strata.market.surface.Surface;
import com.opengamma.strata.market.surface.SurfaceMetadata;
import com.opengamma.strata.market.surface.SurfaceName;
import com.opengamma.strata.math.MathUtils;
import com.opengamma.strata.math.impl.differentiation.ScalarFirstOrderDifferentiator;
import com.opengamma.strata.math.impl.differentiation.ScalarSecondOrderDifferentiator;
import com.opengamma.strata.math.impl.differentiation.VectorFieldFirstOrderDifferentiator;
import com.opengamma.strata.math.impl.differentiation.VectorFieldSecondOrderDifferentiator;
import com.opengamma.strata.pricer.impl.volatility.local.LocalVolatilityCalculator;
import java.util.function.Function;

public class DupireLocalVolatilityCalculator
implements LocalVolatilityCalculator {
    private static final double SMALL = 1.0E-10;
    private static final ScalarFirstOrderDifferentiator FIRST_DERIV = new ScalarFirstOrderDifferentiator();
    private static final ScalarSecondOrderDifferentiator SECOND_DERIV = new ScalarSecondOrderDifferentiator();
    private static final VectorFieldFirstOrderDifferentiator FIRST_DERIV_SENSI = new VectorFieldFirstOrderDifferentiator();
    private static final VectorFieldSecondOrderDifferentiator SECOND_DERIV_SENSI = new VectorFieldSecondOrderDifferentiator();

    public DeformedSurface localVolatilityFromImpliedVolatility(final Surface impliedVolatilitySurface, final double spot, final Function<Double, Double> interestRate, final Function<Double, Double> dividendRate) {
        Function<DoublesPair, ValueDerivatives> func = new Function<DoublesPair, ValueDerivatives>(){

            @Override
            public ValueDerivatives apply(DoublesPair x) {
                double localVol;
                double t = x.getFirst();
                double k = x.getSecond();
                double r = (Double)interestRate.apply(t);
                double q = (Double)dividendRate.apply(t);
                double vol = impliedVolatilitySurface.zValue(t, k);
                DoubleArray volSensi = impliedVolatilitySurface.zValueParameterSensitivity(t, k).getSensitivity();
                double divT = (Double)FIRST_DERIV.differentiate(u -> impliedVolatilitySurface.zValue(u.doubleValue(), k)).apply(t);
                DoubleArray divTSensi = ((DoubleMatrix)FIRST_DERIV_SENSI.differentiate(u -> impliedVolatilitySurface.zValueParameterSensitivity(u.get(0), k).getSensitivity()).apply(DoubleArray.of((double)t))).column(0);
                DoubleArray localVolSensi = DoubleArray.of();
                if (k < 1.0E-10) {
                    localVol = Math.sqrt(vol * vol + 2.0 * vol * t * divT);
                    localVolSensi = volSensi.multipliedBy((vol + t * divT) / localVol).plus(divTSensi.multipliedBy(vol * t / localVol));
                } else {
                    double h2;
                    double divK = (Double)FIRST_DERIV.differentiate(l -> impliedVolatilitySurface.zValue(t, l.doubleValue())).apply(k);
                    DoubleArray divKSensi = ((DoubleMatrix)FIRST_DERIV_SENSI.differentiate(l -> impliedVolatilitySurface.zValueParameterSensitivity(t, l.get(0)).getSensitivity()).apply(DoubleArray.of((double)k))).column(0);
                    double divK2 = (Double)SECOND_DERIV.differentiate(l -> impliedVolatilitySurface.zValue(t, l.doubleValue())).apply(k);
                    DoubleArray divK2Sensi = ((DoubleMatrix)SECOND_DERIV_SENSI.differentiateNoCross(l -> impliedVolatilitySurface.zValueParameterSensitivity(t, l.get(0)).getSensitivity()).apply(DoubleArray.of((double)k))).column(0);
                    double rq = r - q;
                    double h1 = (Math.log(spot / k) + (rq + 0.5 * vol * vol) * t) / vol;
                    double den = 1.0 + 2.0 * h1 * k * divK + k * k * (h1 * (h2 = h1 - vol * t) * divK * divK + t * vol * divK2);
                    double var = (vol * vol + 2.0 * vol * t * (divT + k * rq * divK)) / den;
                    if (var < 0.0) {
                        throw new IllegalArgumentException("Negative variance");
                    }
                    localVol = Math.sqrt(var);
                    localVolSensi = volSensi.multipliedBy(localVol * k * h2 * divK * (1.0 + 0.5 * k * h2 * divK) / vol / den + 0.5 * localVol * MathUtils.pow2((double)(k * h1 * divK)) / vol / den + (vol + divT * t + rq * t * k * divK) / (localVol * den) - 0.5 * divK2 * localVol * k * k * t / den).plus(divKSensi.multipliedBy((vol * t * rq * k / localVol - localVol * k * h1 * (1.0 + k * h2 * divK)) / den)).plus(divTSensi.multipliedBy(vol * t / (localVol * den))).plus(divK2Sensi.multipliedBy(-0.5 * vol * localVol * k * k * t / den));
                }
                return ValueDerivatives.of((double)localVol, (DoubleArray)localVolSensi);
            }
        };
        DefaultSurfaceMetadata metadata = DefaultSurfaceMetadata.builder().xValueType(ValueType.YEAR_FRACTION).yValueType(ValueType.STRIKE).zValueType(ValueType.LOCAL_VOLATILITY).surfaceName(SurfaceName.of((String)("localVol_" + impliedVolatilitySurface.getName()))).build();
        return DeformedSurface.of((SurfaceMetadata)metadata, (Surface)impliedVolatilitySurface, (Function)func);
    }

    public DeformedSurface localVolatilityFromPrice(final Surface callPriceSurface, double spot, final Function<Double, Double> interestRate, final Function<Double, Double> dividendRate) {
        Function<DoublesPair, ValueDerivatives> func = new Function<DoublesPair, ValueDerivatives>(){

            @Override
            public ValueDerivatives apply(DoublesPair x) {
                double t = x.getFirst();
                double k = x.getSecond();
                double r = (Double)interestRate.apply(t);
                double q = (Double)dividendRate.apply(t);
                double price = callPriceSurface.zValue(t, k);
                DoubleArray priceSensi = callPriceSurface.zValueParameterSensitivity(t, k).getSensitivity();
                double divT = (Double)FIRST_DERIV.differentiate(u -> callPriceSurface.zValue(u.doubleValue(), k)).apply(t);
                DoubleArray divTSensi = ((DoubleMatrix)FIRST_DERIV_SENSI.differentiate(u -> callPriceSurface.zValueParameterSensitivity(u.get(0), k).getSensitivity()).apply(DoubleArray.of((double)t))).column(0);
                double divK = (Double)FIRST_DERIV.differentiate(l -> callPriceSurface.zValue(t, l.doubleValue())).apply(k);
                DoubleArray divKSensi = ((DoubleMatrix)FIRST_DERIV_SENSI.differentiate(l -> callPriceSurface.zValueParameterSensitivity(t, l.get(0)).getSensitivity()).apply(DoubleArray.of((double)k))).column(0);
                double divK2 = (Double)SECOND_DERIV.differentiate(l -> callPriceSurface.zValue(t, l.doubleValue())).apply(k);
                DoubleArray divK2Sensi = ((DoubleMatrix)SECOND_DERIV_SENSI.differentiateNoCross(l -> callPriceSurface.zValueParameterSensitivity(t, l.get(0)).getSensitivity()).apply(DoubleArray.of((double)k))).column(0);
                double var = 2.0 * (divT + q * price + (r - q) * k * divK) / (k * k * divK2);
                if (var < 0.0) {
                    throw new IllegalArgumentException("Negative variance");
                }
                double localVol = Math.sqrt(var);
                double factor = 1.0 / (localVol * k * k * divK2);
                DoubleArray localVolSensi = divTSensi.multipliedBy(factor).plus(divKSensi.multipliedBy((r - q) * k * factor)).plus(priceSensi.multipliedBy(q * factor)).plus(divK2Sensi.multipliedBy(-0.5 * localVol / divK2));
                return ValueDerivatives.of((double)localVol, (DoubleArray)localVolSensi);
            }
        };
        DefaultSurfaceMetadata metadata = DefaultSurfaceMetadata.builder().xValueType(ValueType.YEAR_FRACTION).yValueType(ValueType.STRIKE).zValueType(ValueType.LOCAL_VOLATILITY).surfaceName(SurfaceName.of((String)("localVol_" + callPriceSurface.getName()))).build();
        return DeformedSurface.of((SurfaceMetadata)metadata, (Surface)callPriceSurface, (Function)func);
    }
}

