/*
 * Decompiled with CFR 0.152.
 */
package com.opengamma.strata.pricer.capfloor;

import com.google.common.collect.ImmutableList;
import com.opengamma.strata.basics.ReferenceData;
import com.opengamma.strata.basics.currency.Currency;
import com.opengamma.strata.basics.index.IborIndex;
import com.opengamma.strata.collect.ArgChecker;
import com.opengamma.strata.collect.array.DoubleArray;
import com.opengamma.strata.collect.array.DoubleMatrix;
import com.opengamma.strata.data.MarketDataName;
import com.opengamma.strata.market.curve.Curve;
import com.opengamma.strata.market.curve.CurveMetadata;
import com.opengamma.strata.market.curve.CurveName;
import com.opengamma.strata.market.param.CurrencyParameterSensitivities;
import com.opengamma.strata.market.sensitivity.PointSensitivities;
import com.opengamma.strata.market.surface.Surface;
import com.opengamma.strata.market.surface.SurfaceMetadata;
import com.opengamma.strata.math.impl.linearalgebra.DecompositionFactory;
import com.opengamma.strata.math.impl.matrix.MatrixAlgebra;
import com.opengamma.strata.math.impl.matrix.MatrixAlgebraFactory;
import com.opengamma.strata.math.impl.minimization.DoubleRangeLimitTransform;
import com.opengamma.strata.math.impl.minimization.NonLinearParameterTransforms;
import com.opengamma.strata.math.impl.minimization.NonLinearTransformFunction;
import com.opengamma.strata.math.impl.minimization.ParameterLimitsTransform;
import com.opengamma.strata.math.impl.minimization.SingleRangeLimitTransform;
import com.opengamma.strata.math.impl.minimization.UncoupledParameterTransforms;
import com.opengamma.strata.math.impl.statistics.leastsquare.LeastSquareResults;
import com.opengamma.strata.math.impl.statistics.leastsquare.LeastSquareResultsWithTransform;
import com.opengamma.strata.math.impl.statistics.leastsquare.NonLinearLeastSquare;
import com.opengamma.strata.pricer.capfloor.IborCapletFloorletVolatilities;
import com.opengamma.strata.pricer.capfloor.IborCapletFloorletVolatilityCalibrationResult;
import com.opengamma.strata.pricer.capfloor.IborCapletFloorletVolatilityCalibrator;
import com.opengamma.strata.pricer.capfloor.IborCapletFloorletVolatilityDefinition;
import com.opengamma.strata.pricer.capfloor.SabrIborCapFloorLegPricer;
import com.opengamma.strata.pricer.capfloor.SabrIborCapletFloorletVolatilityCalibrationDefinition;
import com.opengamma.strata.pricer.capfloor.SabrParametersIborCapletFloorletVolatilities;
import com.opengamma.strata.pricer.capfloor.VolatilityIborCapFloorLegPricer;
import com.opengamma.strata.pricer.model.SabrParameters;
import com.opengamma.strata.pricer.option.RawOptionData;
import com.opengamma.strata.pricer.rate.RatesProvider;
import com.opengamma.strata.product.capfloor.ResolvedIborCapFloorLeg;
import java.time.LocalDate;
import java.time.Period;
import java.time.ZonedDateTime;
import java.time.temporal.TemporalAmount;
import java.util.ArrayList;
import java.util.BitSet;
import java.util.List;
import java.util.function.Function;

/**
 * Calibrates SABR parameter curves to raw cap/floor volatility data using a
 * non-linear least-squares fit.
 * <p>
 * The fit is carried out on relative prices: for every cap retained from the raw data
 * the fitted function returns the SABR present value divided by the quoted price,
 * and the solver is given a vector of ones alongside it.
 */
public class SabrIborCapletFloorletVolatilityCalibrator
extends IborCapletFloorletVolatilityCalibrator {
    /**
     * Default instance: default volatility and SABR leg pricers, a solver built with
     * epsilon {@code 1.0E-10}, and the standard reference data.
     */
    public static final SabrIborCapletFloorletVolatilityCalibrator DEFAULT = SabrIborCapletFloorletVolatilityCalibrator.of(VolatilityIborCapFloorLegPricer.DEFAULT, SabrIborCapFloorLegPricer.DEFAULT, 1.0E-10, ReferenceData.standard());
    /** Magnitude of the symmetric bound applied to the third transform (see {@link #TRANSFORMS}). */
    private static final double RHO_LIMIT = 0.999;
    /**
     * Range-limit transforms handed to the definition's {@code createFullTransform}.
     * NOTE(review): the four entries presumably correspond to alpha, beta, rho, nu, i.e.
     * the same order in which the curves are assembled in {@code updateParameters} - confirm.
     */
    private static final ParameterLimitsTransform[] TRANSFORMS = {
        new SingleRangeLimitTransform(0.0, ParameterLimitsTransform.LimitType.GREATER_THAN),
        new DoubleRangeLimitTransform(0.0, 1.0),
        new DoubleRangeLimitTransform(-RHO_LIMIT, RHO_LIMIT),
        new DoubleRangeLimitTransform(0.001, 2.5),
    };
    /** Least-squares solver used by {@link #calibrate}. */
    private final NonLinearLeastSquare solver;
    /** Pricer used to compute SABR present values and SABR parameter sensitivities of cap legs. */
    private final SabrIborCapFloorLegPricer sabrPricer;

    /**
     * Creates a calibrator.
     * <p>
     * A {@link NonLinearLeastSquare} solver is built from {@code DecompositionFactory.SV_COMMONS},
     * {@code MatrixAlgebraFactory.OG_ALGEBRA} and the supplied {@code epsilon}.
     *
     * @param pricer  the cap/floor leg pricer passed to the parent calibrator
     * @param sabrPricer  the SABR cap/floor leg pricer, checked not null
     * @param epsilon  the epsilon forwarded to the solver constructor
     *   (presumably its convergence tolerance - confirm against the solver)
     * @param referenceData  the reference data passed to the parent calibrator
     * @return the calibrator
     */
    public static SabrIborCapletFloorletVolatilityCalibrator of(VolatilityIborCapFloorLegPricer pricer, SabrIborCapFloorLegPricer sabrPricer, double epsilon, ReferenceData referenceData) {
        NonLinearLeastSquare solver = new NonLinearLeastSquare(DecompositionFactory.SV_COMMONS, MatrixAlgebraFactory.OG_ALGEBRA, epsilon);
        return new SabrIborCapletFloorletVolatilityCalibrator(pricer, sabrPricer, solver, referenceData);
    }

    // Restricted constructor; instances are obtained through of(...) or DEFAULT.
    private SabrIborCapletFloorletVolatilityCalibrator(VolatilityIborCapFloorLegPricer pricer, SabrIborCapFloorLegPricer sabrPricer, NonLinearLeastSquare solver, ReferenceData referenceData) {
        super(pricer, referenceData);
        this.sabrPricer = ArgChecker.notNull(sabrPricer, "sabrPricer");
        this.solver = ArgChecker.notNull(solver, "solver");
    }

    /**
     * Calibrates SABR parameters to the supplied cap/floor data.
     * <p>
     * The raw data is reduced expiry by expiry via {@code reduceRawData}; every expiry must
     * contribute at least one data point. The SABR parameters are then fitted, starting from
     * the definition's initial values, so that model price over quoted price matches one.
     *
     * @param definition  the calibration definition, which must be a
     *   {@link SabrIborCapletFloorletVolatilityCalibrationDefinition}
     * @param calibrationDateTime  the calibration date-time; its local date must equal the
     *   valuation date of {@code ratesProvider}
     * @param capFloorData  the raw cap/floor data (expiries, strikes, data matrix, optional errors)
     * @param ratesProvider  the rates provider
     * @return the calibrated volatilities together with the chi-square of the fit
     */
    @Override
    public IborCapletFloorletVolatilityCalibrationResult calibrate(IborCapletFloorletVolatilityDefinition definition, ZonedDateTime calibrationDateTime, RawOptionData capFloorData, RatesProvider ratesProvider) {
        ArgChecker.isTrue(ratesProvider.getValuationDate().equals(calibrationDateTime.toLocalDate()), "valuationDate of ratesProvider should be coherent to calibrationDateTime");
        ArgChecker.isTrue(definition instanceof SabrIborCapletFloorletVolatilityCalibrationDefinition, "definition should be SabrIborCapletFloorletVolatilityCalibrationDefinition");
        SabrIborCapletFloorletVolatilityCalibrationDefinition sabrDefinition = (SabrIborCapletFloorletVolatilityCalibrationDefinition)definition;

        // Date setup: caps start one index tenor after the index effective date.
        IborIndex index = sabrDefinition.getIndex();
        LocalDate calibrationDate = calibrationDateTime.toLocalDate();
        LocalDate baseDate = index.getEffectiveDateOffset().adjust(calibrationDate, this.getReferenceData());
        LocalDate startDate = baseDate.plus(index.getTenor());

        Function<Surface, IborCapletFloorletVolatilities> volatilitiesFunction = this.volatilitiesFunction(sabrDefinition, calibrationDateTime, capFloorData);
        SurfaceMetadata metadata = sabrDefinition.createMetadata(capFloorData);
        List<Period> expiries = capFloorData.getExpiries();
        DoubleArray strikes = capFloorData.getStrikes();
        int nExpiries = expiries.size();

        // Accumulators filled by reduceRawData, one entry per retained data point.
        List<Double> timeList = new ArrayList<>();
        List<Double> strikeList = new ArrayList<>();
        List<Double> volList = new ArrayList<>();
        List<ResolvedIborCapFloorLeg> capList = new ArrayList<>();
        List<Double> priceList = new ArrayList<>();
        List<Double> errorList = new ArrayList<>();
        // When no error matrix is supplied every data point gets an error of 1.
        DoubleMatrix errorMatrix = capFloorData.getError().orElse(DoubleMatrix.filled(nExpiries, strikes.size(), 1.0));
        // startIndex[i + 1] is the number of data points collected after processing expiry i.
        int[] startIndex = new int[nExpiries + 1];
        for (int i = 0; i < nExpiries; ++i) {
            LocalDate endDate = baseDate.plus(expiries.get(i));
            DoubleArray volatilityForTime = capFloorData.getData().row(i);
            DoubleArray errorForTime = errorMatrix.row(i);
            this.reduceRawData(sabrDefinition, ratesProvider, strikes, volatilityForTime, errorForTime, startDate, endDate, metadata, volatilitiesFunction, timeList, strikeList, volList, capList, priceList, errorList);
            startIndex[i + 1] = volList.size();
            // An expiry that added nothing to volList has no usable data.
            ArgChecker.isTrue(startIndex[i + 1] > startIndex[i], "no valid option data for {}", expiries.get(i));
        }

        // Initial SABR parameters built from the definition.
        List<CurveMetadata> metadataList = sabrDefinition.createSabrParameterMetadata();
        DoubleArray initialValues = sabrDefinition.createFullInitialValues();
        List<Curve> curveList = sabrDefinition.createSabrParameterCurve(metadataList, initialValues);
        SabrParameters sabrParamsInitial = SabrParameters.of(curveList.get(0), curveList.get(1), curveList.get(2), curveList.get(3), sabrDefinition.getShiftCurve(), sabrDefinition.getSabrVolatilityFormula());
        SabrParametersIborCapletFloorletVolatilities vols = SabrParametersIborCapletFloorletVolatilities.of(sabrDefinition.getName(), index, calibrationDateTime, sabrParamsInitial);

        // The solver works on transformed parameters; the empty BitSet is passed as the third argument.
        UncoupledParameterTransforms transform = new UncoupledParameterTransforms(initialValues, sabrDefinition.createFullTransform(TRANSFORMS), new BitSet());
        Function<DoubleArray, DoubleArray> valueFunction = this.createPriceFunction(sabrDefinition, ratesProvider, vols, capList, priceList);
        Function<DoubleArray, DoubleMatrix> jacobianFunction = this.createJacobianFunction(sabrDefinition, ratesProvider, vols, capList, priceList, index.getCurrency());
        NonLinearTransformFunction transFunc = new NonLinearTransformFunction(valueFunction, jacobianFunction, transform);
        // First argument is a vector of ones: the value function returns price ratios.
        LeastSquareResults res = this.solver.solve(DoubleArray.filled(priceList.size(), 1.0), DoubleArray.copyOf(errorList), transFunc.getFittingFunction(), transFunc.getFittingJacobian(), transform.transform(initialValues));
        LeastSquareResultsWithTransform resTransform = new LeastSquareResultsWithTransform(res, transform);
        vols = this.updateParameters(sabrDefinition, vols, resTransform.getModelParameters());
        return IborCapletFloorletVolatilityCalibrationResult.ofLeastSquare(vols, res.getChiSq());
    }

    /**
     * Builds the function mapping a parameter vector to, for each cap in {@code capList},
     * the SABR present value amount divided by the matching entry of {@code priceList}.
     */
    private Function<DoubleArray, DoubleArray> createPriceFunction(SabrIborCapletFloorletVolatilityCalibrationDefinition sabrDefinition, RatesProvider ratesProvider, SabrParametersIborCapletFloorletVolatilities volatilities, List<ResolvedIborCapFloorLeg> capList, List<Double> priceList) {
        return x -> {
            SabrParametersIborCapletFloorletVolatilities volsNew = this.updateParameters(sabrDefinition, volatilities, x);
            return DoubleArray.of(capList.size(), n -> this.sabrPricer.presentValue(capList.get(n), ratesProvider, volsNew).getAmount() / priceList.get(n));
        };
    }

    /**
     * Builds the Jacobian of the price-ratio function: one row per cap, holding the
     * sensitivities to the fitted SABR parameter curves scaled by one over the quoted price.
     */
    private Function<DoubleArray, DoubleMatrix> createJacobianFunction(SabrIborCapletFloorletVolatilityCalibrationDefinition sabrDefinition, RatesProvider ratesProvider, SabrParametersIborCapletFloorletVolatilities volatilities, List<ResolvedIborCapFloorLeg> capList, List<Double> priceList, Currency currency) {
        int nCaps = capList.size();
        SabrParameters sabrParams = volatilities.getParameters();
        CurveName alphaName = sabrParams.getAlphaCurve().getName();
        CurveName betaName = sabrParams.getBetaCurve().getName();
        CurveName rhoName = sabrParams.getRhoCurve().getName();
        CurveName nuName = sabrParams.getNuCurve().getName();
        return x -> {
            SabrParametersIborCapletFloorletVolatilities volsNew = this.updateParameters(sabrDefinition, volatilities, x);
            double[][] jacobian = new double[nCaps][];
            for (int i = 0; i < nCaps; ++i) {
                PointSensitivities point = this.sabrPricer.presentValueSensitivityModelParamsSabr(capList.get(i), ratesProvider, volsNew).build();
                CurrencyParameterSensitivities sensi = volsNew.parameterSensitivity(point);
                double targetPriceInv = 1.0 / priceList.get(i);
                // Row layout: alpha, then rho or beta, then nu.
                DoubleArray sensitivities = sensi.getSensitivity(alphaName, currency).getSensitivity();
                if (sabrDefinition.getBetaCurve().isPresent()) {
                    // Definition supplies a beta curve (presumably beta is held fixed): use rho.
                    sensitivities = sensitivities.concat(sensi.getSensitivity(rhoName, currency).getSensitivity());
                } else {
                    // No beta curve in the definition (presumably rho is held fixed): use beta.
                    sensitivities = sensitivities.concat(sensi.getSensitivity(betaName, currency).getSensitivity());
                }
                jacobian[i] = sensitivities.concat(sensi.getSensitivity(nuName, currency).getSensitivity()).multipliedBy(targetPriceInv).toArray();
            }
            return DoubleMatrix.ofUnsafe(jacobian);
        };
    }

    /**
     * Returns a copy of {@code volatilities} whose SABR parameter curves are rebuilt from
     * {@code newValues}, keeping the existing curve metadata, name, index and valuation
     * date-time. The shift curve and volatility formula come from the definition.
     */
    private SabrParametersIborCapletFloorletVolatilities updateParameters(SabrIborCapletFloorletVolatilityCalibrationDefinition sabrDefinition, SabrParametersIborCapletFloorletVolatilities volatilities, DoubleArray newValues) {
        SabrParameters sabrParams = volatilities.getParameters();
        CurveMetadata alphaMetadata = sabrParams.getAlphaCurve().getMetadata();
        CurveMetadata betaMetadata = sabrParams.getBetaCurve().getMetadata();
        CurveMetadata rhoMetadata = sabrParams.getRhoCurve().getMetadata();
        CurveMetadata nuMetadata = sabrParams.getNuCurve().getMetadata();
        // No element casts here: the list must be inferred as a list of CurveMetadata.
        List<CurveMetadata> metadataList = ImmutableList.of(alphaMetadata, betaMetadata, rhoMetadata, nuMetadata);
        List<Curve> newCurveList = sabrDefinition.createSabrParameterCurve(metadataList, newValues);
        SabrParameters newSabrParams = SabrParameters.of(newCurveList.get(0), newCurveList.get(1), newCurveList.get(2), newCurveList.get(3), sabrDefinition.getShiftCurve(), sabrDefinition.getSabrVolatilityFormula());
        return SabrParametersIborCapletFloorletVolatilities.of(volatilities.getName(), volatilities.getIndex(), volatilities.getValuationDateTime(), newSabrParams);
    }
}

