/*
 * Decompiled with CFR 0.152.
 */
package com.opengamma.strata.pricer.capfloor;

import com.google.common.collect.ImmutableList;
import com.opengamma.strata.basics.ReferenceData;
import com.opengamma.strata.basics.currency.Currency;
import com.opengamma.strata.basics.index.IborIndex;
import com.opengamma.strata.collect.ArgChecker;
import com.opengamma.strata.collect.array.DoubleArray;
import com.opengamma.strata.collect.array.DoubleMatrix;
import com.opengamma.strata.data.MarketDataName;
import com.opengamma.strata.market.ValueType;
import com.opengamma.strata.market.curve.Curve;
import com.opengamma.strata.market.curve.CurveMetadata;
import com.opengamma.strata.market.curve.InterpolatedNodalCurve;
import com.opengamma.strata.market.curve.interpolator.CurveExtrapolator;
import com.opengamma.strata.market.curve.interpolator.CurveInterpolator;
import com.opengamma.strata.market.param.CurrencyParameterSensitivities;
import com.opengamma.strata.market.sensitivity.PointSensitivities;
import com.opengamma.strata.market.surface.Surface;
import com.opengamma.strata.market.surface.SurfaceMetadata;
import com.opengamma.strata.math.impl.linearalgebra.DecompositionFactory;
import com.opengamma.strata.math.impl.matrix.MatrixAlgebra;
import com.opengamma.strata.math.impl.matrix.MatrixAlgebraFactory;
import com.opengamma.strata.math.impl.minimization.DoubleRangeLimitTransform;
import com.opengamma.strata.math.impl.minimization.NonLinearParameterTransforms;
import com.opengamma.strata.math.impl.minimization.NonLinearTransformFunction;
import com.opengamma.strata.math.impl.minimization.ParameterLimitsTransform;
import com.opengamma.strata.math.impl.minimization.SingleRangeLimitTransform;
import com.opengamma.strata.math.impl.minimization.UncoupledParameterTransforms;
import com.opengamma.strata.math.impl.statistics.leastsquare.LeastSquareResults;
import com.opengamma.strata.math.impl.statistics.leastsquare.LeastSquareResultsWithTransform;
import com.opengamma.strata.math.impl.statistics.leastsquare.NonLinearLeastSquare;
import com.opengamma.strata.pricer.capfloor.IborCapletFloorletVolatilities;
import com.opengamma.strata.pricer.capfloor.IborCapletFloorletVolatilityCalibrationResult;
import com.opengamma.strata.pricer.capfloor.IborCapletFloorletVolatilityCalibrator;
import com.opengamma.strata.pricer.capfloor.IborCapletFloorletVolatilityDefinition;
import com.opengamma.strata.pricer.capfloor.SabrIborCapletFloorletPeriodPricer;
import com.opengamma.strata.pricer.capfloor.SabrIborCapletFloorletVolatilityBootstrapDefinition;
import com.opengamma.strata.pricer.capfloor.SabrParametersIborCapletFloorletVolatilities;
import com.opengamma.strata.pricer.capfloor.VolatilityIborCapFloorLegPricer;
import com.opengamma.strata.pricer.model.SabrParameters;
import com.opengamma.strata.pricer.option.RawOptionData;
import com.opengamma.strata.pricer.rate.RatesProvider;
import com.opengamma.strata.product.capfloor.IborCapletFloorletPeriod;
import com.opengamma.strata.product.capfloor.ResolvedIborCapFloorLeg;
import java.time.LocalDate;
import java.time.Period;
import java.time.ZonedDateTime;
import java.time.temporal.TemporalAmount;
import java.util.ArrayList;
import java.util.BitSet;
import java.util.List;
import java.util.function.Function;
import java.util.stream.Collectors;

/**
 * Calibrates SABR caplet/floorlet volatilities to cap/floor market data by bootstrapping
 * along the expiry dimension.
 * <p>
 * The cap expiries are processed in order. For each expiry a non-linear least-squares problem
 * is solved for the SABR node values (alpha, beta or rho, nu) at that expiry only, using the
 * caplets that fix after the previous cap expiry. The node values obtained for earlier
 * expiries are kept unchanged.
 * <p>
 * Exactly one of beta and rho is calibrated: if the definition supplies a beta curve, beta is
 * held fixed and rho is calibrated; otherwise the rho curve of the definition is used as is
 * and beta is calibrated.
 */
public class SabrIborCapletFloorletVolatilityBootstrapper
        extends IborCapletFloorletVolatilityCalibrator {

    /**
     * Default instance, built from the default cap/floor leg pricer, the default SABR period
     * pricer, an epsilon of 1.0E-10 and the standard reference data.
     */
    public static final SabrIborCapletFloorletVolatilityBootstrapper DEFAULT = SabrIborCapletFloorletVolatilityBootstrapper.of(
            VolatilityIborCapFloorLegPricer.DEFAULT, SabrIborCapletFloorletPeriodPricer.DEFAULT, 1.0E-10, ReferenceData.standard());

    /**
     * Range constraints applied to the fitted parameters, indexed as
     * 0 = alpha, 1 = beta, 2 = rho, 3 = nu. Populated in the static initializer below.
     */
    private static final ParameterLimitsTransform[] TRANSFORMS = new ParameterLimitsTransform[4];

    /** Rho is constrained to the open interval (-RHO_LIMIT, RHO_LIMIT). */
    private static final double RHO_LIMIT = 0.999;

    static {
        // alpha > 0
        TRANSFORMS[0] = new SingleRangeLimitTransform(0.0, ParameterLimitsTransform.LimitType.GREATER_THAN);
        // 0 < beta < 1
        TRANSFORMS[1] = new DoubleRangeLimitTransform(0.0, 1.0);
        // -RHO_LIMIT < rho < RHO_LIMIT
        TRANSFORMS[2] = new DoubleRangeLimitTransform(-RHO_LIMIT, RHO_LIMIT);
        // 0.001 < nu < 2.5
        TRANSFORMS[3] = new DoubleRangeLimitTransform(0.001, 2.5);
    }

    /** The least-squares solver used for each expiry. */
    private final NonLinearLeastSquare solver;

    /** The SABR pricer used to value individual caplet/floorlet periods. */
    private final SabrIborCapletFloorletPeriodPricer sabrPeriodPricer;

    /**
     * Creates an instance.
     *
     * @param pricer  the cap/floor leg pricer handed to the parent calibrator
     * @param sabrPeriodPricer  the SABR pricer for caplet/floorlet periods
     * @param epsilon  the epsilon passed to the least-squares solver
     *   (presumably its convergence tolerance; confirm against {@code NonLinearLeastSquare})
     * @param referenceData  the reference data handed to the parent calibrator
     * @return the bootstrapper
     */
    public static SabrIborCapletFloorletVolatilityBootstrapper of(
            VolatilityIborCapFloorLegPricer pricer,
            SabrIborCapletFloorletPeriodPricer sabrPeriodPricer,
            double epsilon,
            ReferenceData referenceData) {

        NonLinearLeastSquare solver =
                new NonLinearLeastSquare(DecompositionFactory.SV_COMMONS, MatrixAlgebraFactory.OG_ALGEBRA, epsilon);
        return new SabrIborCapletFloorletVolatilityBootstrapper(pricer, sabrPeriodPricer, solver, referenceData);
    }

    // restricted constructor; rejects a null SABR pricer or solver
    private SabrIborCapletFloorletVolatilityBootstrapper(
            VolatilityIborCapFloorLegPricer pricer,
            SabrIborCapletFloorletPeriodPricer sabrPeriodPricer,
            NonLinearLeastSquare solver,
            ReferenceData referenceData) {

        super(pricer, referenceData);
        this.sabrPeriodPricer = ArgChecker.notNull(sabrPeriodPricer, "sabrPeriodPricer");
        this.solver = ArgChecker.notNull(solver, "solver");
    }

    /**
     * Calibrates the SABR parameter curves to the supplied cap/floor data.
     * <p>
     * The valuation date of {@code ratesProvider} must equal the date of
     * {@code calibrationDateTime}, the definition must be a
     * {@link SabrIborCapletFloorletVolatilityBootstrapDefinition}, and every expiry must
     * retain at least one data point after the raw data has been reduced; each of these
     * conditions is enforced through {@code ArgChecker.isTrue}.
     *
     * @param definition  the calibration definition
     * @param calibrationDateTime  the calibration date-time
     * @param capFloorData  the raw cap/floor data (expiries x strikes)
     * @param ratesProvider  the rates provider
     * @return the calibrated volatilities together with the chi-square summed over all expiries
     */
    @Override
    public IborCapletFloorletVolatilityCalibrationResult calibrate(
            IborCapletFloorletVolatilityDefinition definition,
            ZonedDateTime calibrationDateTime,
            RawOptionData capFloorData,
            RatesProvider ratesProvider) {

        LocalDate calibrationDate = calibrationDateTime.toLocalDate();
        ArgChecker.isTrue(
                ratesProvider.getValuationDate().equals(calibrationDate),
                "valuationDate of ratesProvider should be coherent to calibrationDateTime");
        ArgChecker.isTrue(
                definition instanceof SabrIborCapletFloorletVolatilityBootstrapDefinition,
                "definition should be SabrIborCapletFloorletVolatilityBootstrapDefinition");
        SabrIborCapletFloorletVolatilityBootstrapDefinition bsDefinition =
                (SabrIborCapletFloorletVolatilityBootstrapDefinition) definition;

        IborIndex index = bsDefinition.getIndex();
        LocalDate baseDate = index.getEffectiveDateOffset().adjust(calibrationDate, getReferenceData());
        LocalDate startDate = baseDate.plus(index.getTenor());
        Function<Surface, IborCapletFloorletVolatilities> volatilitiesFunction =
                volatilitiesFunction(bsDefinition, calibrationDateTime, capFloorData);
        SurfaceMetadata metaData = bsDefinition.createMetadata(capFloorData);
        ImmutableList<Period> expiries = capFloorData.getExpiries();
        int nExpiries = expiries.size();
        DoubleArray strikes = capFloorData.getStrikes();
        // when no errors are supplied, every data point gets unit error
        DoubleMatrix errorsMatrix =
                capFloorData.getError().orElse(DoubleMatrix.filled(nExpiries, strikes.size(), 1.0));

        // flattened per-option data, filled by reduceRawData one expiry at a time
        List<Double> timeList = new ArrayList<>();
        List<Double> strikeList = new ArrayList<>();
        List<Double> volList = new ArrayList<>();
        List<ResolvedIborCapFloorLeg> capList = new ArrayList<>();
        List<Double> priceList = new ArrayList<>();
        List<Double> errorList = new ArrayList<>();
        // startIndex[i] .. startIndex[i + 1] - 1 are the positions in the lists above for expiry i
        int[] startIndex = new int[nExpiries + 1];
        for (int i = 0; i < nExpiries; ++i) {
            LocalDate endDate = baseDate.plus(expiries.get(i));
            DoubleArray volatilityData = capFloorData.getData().row(i);
            DoubleArray errors = errorsMatrix.row(i);
            reduceRawData(
                    bsDefinition, ratesProvider, strikes, volatilityData, errors, startDate, endDate, metaData,
                    volatilitiesFunction, timeList, strikeList, volList, capList, priceList, errorList);
            startIndex[i + 1] = volList.size();
            ArgChecker.isTrue(startIndex[i + 1] > startIndex[i], "no valid option data for {}", expiries.get(i));
        }

        ImmutableList<CurveMetadata> metadataList = bsDefinition.createSabrParameterMetadata();
        // one curve node per expiry, located at the time of the first data point of that expiry
        DoubleArray timeToExpiries = DoubleArray.of(nExpiries, k -> timeList.get(startIndex[k]));

        // bit 1 (beta) or bit 2 (rho) marks the parameter that is excluded from the fit
        BitSet fixed = new BitSet();
        boolean betaFix = bsDefinition.getBetaCurve().isPresent();
        final Curve betaCurve;
        final Curve rhoCurve;
        if (betaFix) {
            fixed.set(1);
            betaCurve = bsDefinition.getBetaCurve().get();
            rhoCurve = zeroNodalCurve(metadataList.get(2), timeToExpiries, bsDefinition);
        } else {
            fixed.set(2);
            betaCurve = zeroNodalCurve(metadataList.get(1), timeToExpiries, bsDefinition);
            // NOTE(review): unchecked get() - presumably the definition guarantees that a rho curve
            // is present whenever the beta curve is absent; confirm against the definition class
            rhoCurve = bsDefinition.getRhoCurve().get();
        }
        Curve alphaCurve = zeroNodalCurve(metadataList.get(0), timeToExpiries, bsDefinition);
        Curve nuCurve = zeroNodalCurve(metadataList.get(3), timeToExpiries, bsDefinition);
        Curve shiftCurve = bsDefinition.getShiftCurve();
        SabrParameters sabrParams = SabrParameters.of(
                alphaCurve, betaCurve, rhoCurve, nuCurve, shiftCurve, bsDefinition.getSabrVolatilityFormula());
        SabrParametersIborCapletFloorletVolatilities vols =
                SabrParametersIborCapletFloorletVolatilities.of(bsDefinition.getName(), index, calibrationDateTime, sabrParams);

        double totalChiSq = 0.0;
        // starting one day before calibration means no caplet is filtered out for the first expiry
        // as long as it fixes after that instant
        ZonedDateTime prevExpiry = calibrationDateTime.minusDays(1L);
        for (int i = 0; i < nExpiries; ++i) {
            DoubleArray start = computeInitialValues(
                    ratesProvider, betaCurve, shiftCurve, timeList, volList, capList, startIndex, i, betaFix,
                    capFloorData.getDataType());
            UncoupledParameterTransforms transform = new UncoupledParameterTransforms(start, TRANSFORMS, fixed);
            int currentStart = startIndex[i];
            int nCaplets = startIndex[i + 1] - currentStart;
            Function<DoubleArray, DoubleArray> valueFunction = createPriceFunction(
                    ratesProvider, vols, prevExpiry, capList, priceList, startIndex, nExpiries, i, nCaplets, betaFix);
            Function<DoubleArray, DoubleMatrix> jacobianFunction = createJacobianFunction(
                    ratesProvider, vols, prevExpiry, capList, priceList, index.getCurrency(), startIndex, nExpiries, i,
                    nCaplets, betaFix);
            NonLinearTransformFunction transFunc = new NonLinearTransformFunction(valueFunction, jacobianFunction, transform);
            DoubleArray adjustedPrices =
                    adjustedPrices(ratesProvider, vols, prevExpiry, capList, priceList, startIndex, i, nCaplets);
            DoubleArray errors = DoubleArray.of(nCaplets, n -> errorList.get(currentStart + n));
            LeastSquareResults res = solver.solve(
                    adjustedPrices, errors, transFunc.getFittingFunction(), transFunc.getFittingJacobian(),
                    transform.transform(start));
            LeastSquareResultsWithTransform resTransform = new LeastSquareResultsWithTransform(res, transform);
            // freeze the fitted node values before moving to the next expiry
            vols = updateParameters(vols, nExpiries, i, betaFix, resTransform.getModelParameters());
            totalChiSq += res.getChiSq();
            prevExpiry = capList.get(startIndex[i + 1] - 1).getFinalFixingDateTime();
        }
        return IborCapletFloorletVolatilityCalibrationResult.ofLeastSquare(vols, totalChiSq);
    }

    // creates a nodal curve with the given node times, all node values zero, and the
    // interpolator/extrapolators of the definition
    private static InterpolatedNodalCurve zeroNodalCurve(
            CurveMetadata metadata,
            DoubleArray nodeTimes,
            SabrIborCapletFloorletVolatilityBootstrapDefinition definition) {

        return InterpolatedNodalCurve.of(
                metadata,
                nodeTimes,
                DoubleArray.filled(nodeTimes.size()),
                definition.getInterpolator(),
                definition.getExtrapolatorLeft(),
                definition.getExtrapolatorRight());
    }

    /**
     * Computes the starting point (alpha, beta, rho, nu) of the fit for one expiry.
     * <p>
     * Beta starts at the value of {@code betaCurve} when beta is fixed and at 0.5 otherwise.
     * Alpha starts at the smallest equivalent volatility multiplied by {@code fwd^(1 - beta)}.
     * Rho starts at {@code -0.5 * beta + 0.5 * (1 - beta)}. Nu starts at 1.0, unless the
     * initial alpha equals the first or the last equivalent volatility, in which case nu
     * starts at 0.1 and alpha is scaled by 0.95.
     *
     * @param ratesProvider  the rates provider, used for the forward rate of the final period
     * @param betaCurve  the beta curve, read only if {@code betaFixed} is true
     * @param shiftCurve  the shift curve
     * @param timeList  the times of all data points
     * @param volList  the volatilities of all data points
     * @param capList  the caps of all data points
     * @param startIndex  the start positions of each expiry in the lists
     * @param postion  the index of the expiry being processed
     * @param betaFixed  true if beta is fixed
     * @param valueType  the type of the volatility data
     * @return the initial values, ordered alpha, beta, rho, nu
     */
    private DoubleArray computeInitialValues(
            RatesProvider ratesProvider,
            Curve betaCurve,
            Curve shiftCurve,
            List<Double> timeList,
            List<Double> volList,
            List<ResolvedIborCapFloorLeg> capList,
            int[] startIndex,
            int postion,
            boolean betaFixed,
            ValueType valueType) {

        int first = startIndex[postion];
        List<Double> vols = volList.subList(first, startIndex[postion + 1]);
        ResolvedIborCapFloorLeg cap = capList.get(first);
        double fwd = ratesProvider.iborIndexRates(cap.getIndex())
                .rate(cap.getFinalPeriod().getIborRate().getObservation());
        double shift = shiftCurve.yValue(timeList.get(first));
        // Black volatilities are used as they are; any other type is divided by the shifted forward
        // (presumably converting normal volatilities to approximate log-normal ones - TODO confirm)
        double factor = valueType.equals(ValueType.BLACK_VOLATILITY) ? 1.0 : 1.0 / (fwd + shift);
        List<Double> volsEquiv = vols.stream().map(v -> v * factor).collect(Collectors.toList());
        double betaInitial = betaFixed ? betaCurve.yValue(timeList.get(first)) : 0.5;
        double alphaInitial = DoubleArray.copyOf(volsEquiv).min() * Math.pow(fwd, 1.0 - betaInitial);
        double nuFirst;
        // NOTE(review): exact floating-point comparison of the scaled alpha against the unscaled
        // end-point volatilities; presumably meant to detect a smile whose minimum sits at an edge.
        // Kept as is to preserve calibration results - confirm intent before changing.
        if (alphaInitial == volsEquiv.get(0) || alphaInitial == volsEquiv.get(volsEquiv.size() - 1)) {
            nuFirst = 0.1;
            alphaInitial *= 0.95;
        } else {
            nuFirst = 1.0;
        }
        return DoubleArray.of(alphaInitial, betaInitial, -0.5 * betaInitial + 0.5 * (1.0 - betaInitial), nuFirst);
    }

    /**
     * Creates the model function for one expiry.
     * <p>
     * For trial parameters {@code x} the function returns, per cap of the expiry, the SABR
     * present value of the caplets fixing strictly after {@code prevExpiry}, divided by the
     * target price of the cap.
     */
    private Function<DoubleArray, DoubleArray> createPriceFunction(
            RatesProvider ratesProvider,
            SabrParametersIborCapletFloorletVolatilities volatilities,
            ZonedDateTime prevExpiry,
            List<ResolvedIborCapFloorLeg> capList,
            List<Double> priceList,
            int[] startIndex,
            int nExpiries,
            int timeIndex,
            int nCaplets,
            boolean betaFixed) {

        int currentStart = startIndex[timeIndex];
        return x -> {
            SabrParametersIborCapletFloorletVolatilities volsNew =
                    updateParameters(volatilities, nExpiries, timeIndex, betaFixed, x);
            return DoubleArray.of(
                    nCaplets,
                    n -> capList.get(currentStart + n).getCapletFloorletPeriods().stream()
                            .filter(p -> p.getFixingDateTime().isAfter(prevExpiry))
                            .mapToDouble(p -> sabrPeriodPricer.presentValue(p, ratesProvider, volsNew).getAmount())
                            .sum() / priceList.get(currentStart + n));
        };
    }

    /**
     * Creates the Jacobian of the model function for one expiry.
     * <p>
     * Each row corresponds to one cap of the expiry and holds the sensitivity of the price
     * function to the node values (alpha, beta, rho, nu) at {@code timeIndex}, divided by the
     * target price. The column of the parameter that is not calibrated is zero.
     */
    private Function<DoubleArray, DoubleMatrix> createJacobianFunction(
            RatesProvider ratesProvider,
            SabrParametersIborCapletFloorletVolatilities volatilities,
            ZonedDateTime prevExpiry,
            List<ResolvedIborCapFloorLeg> capList,
            List<Double> priceList,
            Currency currency,
            int[] startIndex,
            int nExpiries,
            int timeIndex,
            int nCaplets,
            boolean betaFixed) {

        Curve alphaCurve = volatilities.getParameters().getAlphaCurve();
        Curve betaCurve = volatilities.getParameters().getBetaCurve();
        Curve rhoCurve = volatilities.getParameters().getRhoCurve();
        Curve nuCurve = volatilities.getParameters().getNuCurve();
        int currentStart = startIndex[timeIndex];
        return x -> {
            SabrParametersIborCapletFloorletVolatilities volsNew =
                    updateParameters(volatilities, nExpiries, timeIndex, betaFixed, x);
            double[][] jacobian = new double[nCaplets][4];
            for (int i = 0; i < nCaplets; ++i) {
                // combined model-parameter sensitivity of the caplets fixing after the previous expiry;
                // get() throws if no such caplet exists
                PointSensitivities point = capList.get(currentStart + i).getCapletFloorletPeriods().stream()
                        .filter(p -> p.getFixingDateTime().isAfter(prevExpiry))
                        .map(p -> sabrPeriodPricer.presentValueSensitivityModelParamsSabr(p, ratesProvider, volsNew))
                        .reduce((c1, c2) -> c1.combinedWith(c2))
                        .get()
                        .build();
                double targetPrice = priceList.get(currentStart + i);
                CurrencyParameterSensitivities sensi = volsNew.parameterSensitivity(point);
                jacobian[i][0] =
                        sensi.getSensitivity(alphaCurve.getName(), currency).getSensitivity().get(timeIndex) / targetPrice;
                if (betaFixed) {
                    jacobian[i][1] = 0.0;
                    jacobian[i][2] =
                            sensi.getSensitivity(rhoCurve.getName(), currency).getSensitivity().get(timeIndex) / targetPrice;
                } else {
                    jacobian[i][1] =
                            sensi.getSensitivity(betaCurve.getName(), currency).getSensitivity().get(timeIndex) / targetPrice;
                    jacobian[i][2] = 0.0;
                }
                jacobian[i][3] =
                        sensi.getSensitivity(nuCurve.getName(), currency).getSensitivity().get(timeIndex) / targetPrice;
            }
            return DoubleMatrix.ofUnsafe(jacobian);
        };
    }

    /**
     * Returns a copy of the volatilities with the node values at {@code timeIndex} replaced.
     * <p>
     * Alpha and nu are always written; rho is written when beta is fixed, beta otherwise.
     * The flat parameter index is computed as: alpha at {@code timeIndex}, beta at
     * {@code timeIndex + nExpiries}, rho at {@code timeIndex + nExpiries + nBetaParams} and
     * nu at {@code timeIndex + nExpiries + nBetaParams + nRhoParams}.
     *
     * @param volatilities  the volatilities to copy
     * @param nExpiries  the number of expiries, used as the parameter count of the alpha curve
     * @param timeIndex  the index of the expiry node to update
     * @param betaFixed  true if beta is fixed
     * @param newParameters  the new values, ordered alpha, beta, rho, nu
     * @return the updated volatilities
     */
    private SabrParametersIborCapletFloorletVolatilities updateParameters(
            SabrParametersIborCapletFloorletVolatilities volatilities,
            int nExpiries,
            int timeIndex,
            boolean betaFixed,
            DoubleArray newParameters) {

        int nBetaParams = volatilities.getParameters().getBetaCurve().getParameterCount();
        int nRhoParams = volatilities.getParameters().getRhoCurve().getParameterCount();
        int betaIndex = timeIndex + nExpiries;
        int rhoIndex = betaIndex + nBetaParams;
        int nuIndex = rhoIndex + nRhoParams;
        SabrParametersIborCapletFloorletVolatilities newVols = volatilities
                .withParameter(timeIndex, newParameters.get(0))
                .withParameter(nuIndex, newParameters.get(3));
        return betaFixed
                ? newVols.withParameter(rhoIndex, newParameters.get(2))
                : newVols.withParameter(betaIndex, newParameters.get(1));
    }

    /**
     * Computes the target values of the fit for one expiry.
     * <p>
     * For the first expiry every target is 1. For later expiries the target is the cap price
     * less the present value, under the current volatilities, of the caplets fixing on or
     * before {@code prevExpiry}, divided by the cap price.
     */
    private DoubleArray adjustedPrices(
            RatesProvider ratesProvider,
            IborCapletFloorletVolatilities vols,
            ZonedDateTime prevExpiry,
            List<ResolvedIborCapFloorLeg> capList,
            List<Double> priceList,
            int[] startIndex,
            int timeIndex,
            int nCaplets) {

        if (timeIndex == 0) {
            return DoubleArray.filled(nCaplets, 1.0);
        }
        int currentStart = startIndex[timeIndex];
        return DoubleArray.of(
                nCaplets,
                n -> (priceList.get(currentStart + n) - capList.get(currentStart + n).getCapletFloorletPeriods().stream()
                        .filter(p -> !p.getFixingDateTime().isAfter(prevExpiry))
                        .mapToDouble(p -> sabrPeriodPricer.presentValue(p, ratesProvider, vols).getAmount())
                        .sum()) / priceList.get(currentStart + n));
    }
}

