/*
 * Decompiled with CFR 0.152.
 */
package com.opengamma.strata.pricer.capfloor;

import com.google.common.collect.ImmutableList;
import com.opengamma.strata.basics.ReferenceData;
import com.opengamma.strata.basics.date.DayCount;
import com.opengamma.strata.basics.index.IborIndex;
import com.opengamma.strata.collect.ArgChecker;
import com.opengamma.strata.collect.array.DoubleArray;
import com.opengamma.strata.collect.array.DoubleMatrix;
import com.opengamma.strata.collect.tuple.Triple;
import com.opengamma.strata.market.ValueType;
import com.opengamma.strata.market.curve.Curve;
import com.opengamma.strata.market.curve.interpolator.CurveInterpolator;
import com.opengamma.strata.market.curve.interpolator.CurveInterpolators;
import com.opengamma.strata.market.param.CurrencyParameterSensitivity;
import com.opengamma.strata.market.surface.InterpolatedNodalSurface;
import com.opengamma.strata.market.surface.Surface;
import com.opengamma.strata.market.surface.SurfaceMetadata;
import com.opengamma.strata.market.surface.Surfaces;
import com.opengamma.strata.market.surface.interpolator.GridSurfaceInterpolator;
import com.opengamma.strata.market.surface.interpolator.SurfaceInterpolator;
import com.opengamma.strata.math.impl.linearalgebra.CholeskyDecompositionCommons;
import com.opengamma.strata.math.impl.minimization.PositiveOrZero;
import com.opengamma.strata.math.impl.statistics.leastsquare.LeastSquareWithPenaltyResults;
import com.opengamma.strata.math.impl.statistics.leastsquare.NonLinearLeastSquareWithPenalty;
import com.opengamma.strata.math.linearalgebra.Decomposition;
import com.opengamma.strata.pricer.capfloor.DirectIborCapletFloorletVolatilityDefinition;
import com.opengamma.strata.pricer.capfloor.IborCapletFloorletVolatilities;
import com.opengamma.strata.pricer.capfloor.IborCapletFloorletVolatilityCalibrationResult;
import com.opengamma.strata.pricer.capfloor.IborCapletFloorletVolatilityCalibrator;
import com.opengamma.strata.pricer.capfloor.IborCapletFloorletVolatilityDefinition;
import com.opengamma.strata.pricer.capfloor.ShiftedBlackIborCapletFloorletExpiryStrikeVolatilities;
import com.opengamma.strata.pricer.capfloor.VolatilityIborCapFloorLegPricer;
import com.opengamma.strata.pricer.option.RawOptionData;
import com.opengamma.strata.pricer.rate.RatesProvider;
import com.opengamma.strata.product.capfloor.IborCapletFloorletPeriod;
import com.opengamma.strata.product.capfloor.ResolvedIborCapFloorLeg;
import java.time.LocalDate;
import java.time.Period;
import java.time.ZonedDateTime;
import java.time.temporal.TemporalAmount;
import java.util.ArrayList;
import java.util.List;
import java.util.function.Function;

/**
 * Calibrates caplet/floorlet volatilities directly to cap/floor market data.
 * <p>
 * The fitted parameters are the volatilities at a grid of nodes made of one node per
 * (caplet expiry, strike) pair. Starting values are read from a surface interpolated
 * through the input cap data, and the node values are then fitted to the cap prices
 * with a non-linear least square solver that takes a penalty matrix supplied by the
 * {@link DirectIborCapletFloorletVolatilityDefinition}.
 */
public class DirectIborCapletFloorletVolatilityCalibrator
        extends IborCapletFloorletVolatilityCalibrator {

    /**
     * The standard instance: default cap/floor leg pricer, solver epsilon of 1.0E-8
     * and the standard reference data.
     */
    private static final DirectIborCapletFloorletVolatilityCalibrator STANDARD =
            DirectIborCapletFloorletVolatilityCalibrator.of(VolatilityIborCapFloorLegPricer.DEFAULT, 1.0E-8, ReferenceData.standard());
    /**
     * Constraint handed to the solver as the "allowed value" function.
     */
    private static final Function<DoubleArray, Boolean> POSITIVE = new PositiveOrZero();
    /**
     * Linear/linear grid interpolator used for the intermediate surfaces built during calibration.
     * The surface returned to the caller uses the interpolator of the definition instead.
     */
    private static final GridSurfaceInterpolator INTERPOLATOR =
            GridSurfaceInterpolator.of(CurveInterpolators.LINEAR, CurveInterpolators.LINEAR);
    /**
     * The non-linear least square solver with penalty.
     */
    private final NonLinearLeastSquareWithPenalty solver;

    /**
     * Obtains the standard instance.
     *
     * @return the shared standard calibrator
     */
    public static DirectIborCapletFloorletVolatilityCalibrator standard() {
        return STANDARD;
    }

    /**
     * Obtains an instance.
     *
     * @param pricer  the cap/floor leg pricer used to price the caps during calibration
     * @param epsilon  the value passed to the least square solver (presumably its tolerance - confirm against the solver)
     * @param referenceData  the reference data
     * @return the calibrator
     */
    public static DirectIborCapletFloorletVolatilityCalibrator of(
            VolatilityIborCapFloorLegPricer pricer,
            double epsilon,
            ReferenceData referenceData) {

        return new DirectIborCapletFloorletVolatilityCalibrator(pricer, epsilon, referenceData);
    }

    // restricted constructor, the solver always uses a Cholesky decomposition
    private DirectIborCapletFloorletVolatilityCalibrator(
            VolatilityIborCapFloorLegPricer pricer,
            double epsilon,
            ReferenceData referenceData) {

        super(pricer, referenceData);
        this.solver = new NonLinearLeastSquareWithPenalty(new CholeskyDecompositionCommons(), epsilon);
    }

    /**
     * Calibrates the caplet volatility surface to the cap/floor data.
     *
     * @param definition  the definition, must be a {@link DirectIborCapletFloorletVolatilityDefinition}
     * @param calibrationDateTime  the calibration date-time, its date must equal the valuation date of {@code ratesProvider}
     * @param capFloorData  the cap/floor data, must contain at least one expiry and
     *   at least one usable data point for every expiry
     * @param ratesProvider  the rates provider
     * @return the calibrated volatilities together with the chi-square of the fit
     * @throws IllegalArgumentException if one of the argument checks fails
     */
    @Override
    public IborCapletFloorletVolatilityCalibrationResult calibrate(
            IborCapletFloorletVolatilityDefinition definition,
            ZonedDateTime calibrationDateTime,
            RawOptionData capFloorData,
            RatesProvider ratesProvider) {

        ArgChecker.isTrue(
                ratesProvider.getValuationDate().equals(calibrationDateTime.toLocalDate()),
                "valuationDate of ratesProvider should be coherent to calibrationDateTime");
        ArgChecker.isTrue(
                definition instanceof DirectIborCapletFloorletVolatilityDefinition,
                "definition should be DirectIborCapletFloorletVolatilityDefinition");
        DirectIborCapletFloorletVolatilityDefinition directDefinition =
                (DirectIborCapletFloorletVolatilityDefinition) definition;

        // unpack the cap data and create one cap per usable (expiry, strike) data point
        IborIndex index = directDefinition.getIndex();
        LocalDate calibrationDate = calibrationDateTime.toLocalDate();
        LocalDate baseDate = index.getEffectiveDateOffset().adjust(calibrationDate, getReferenceData());
        LocalDate startDate = baseDate.plus(index.getTenor());
        Function<Surface, IborCapletFloorletVolatilities> volatilitiesFunction =
                volatilitiesFunction(directDefinition, calibrationDateTime, capFloorData);
        SurfaceMetadata metadata = directDefinition.createMetadata(capFloorData);
        List<Period> expiries = capFloorData.getExpiries();
        DoubleArray strikes = capFloorData.getStrikes();
        int nExpiries = expiries.size();
        // without this check an empty data set fails below with an opaque IndexOutOfBoundsException
        ArgChecker.isTrue(nExpiries > 0, "capFloorData should contain at least one expiry");
        List<Double> timeList = new ArrayList<>();
        List<Double> strikeList = new ArrayList<>();
        List<Double> volList = new ArrayList<>();
        List<ResolvedIborCapFloorLeg> capList = new ArrayList<>();
        List<Double> priceList = new ArrayList<>();
        List<Double> errorList = new ArrayList<>();
        // when no error matrix is supplied every data point gets an error of 1.0
        DoubleMatrix errorMatrix = capFloorData.getError().orElse(DoubleMatrix.filled(nExpiries, strikes.size(), 1.0));
        for (int i = 0; i < nExpiries; ++i) {
            Period expiry = expiries.get(i);
            LocalDate endDate = baseDate.plus(expiry);
            DoubleArray volatilityForTime = capFloorData.getData().row(i);
            DoubleArray errorForTime = errorMatrix.row(i);
            int sizeBefore = volList.size();
            reduceRawData(
                    directDefinition, ratesProvider, strikes, volatilityForTime, errorForTime, startDate, endDate,
                    metadata, volatilitiesFunction, timeList, strikeList, volList, capList, priceList, errorList);
            // every expiry must have contributed at least one data point
            ArgChecker.isTrue(volList.size() > sizeBefore, "no valid option data for {}", expiry);
        }

        // the caplets of the last cap added define the expiry axis of the node grid
        ResolvedIborCapFloorLeg cap = capList.get(capList.size() - 1);
        int nCaplets = cap.getCapletFloorletPeriods().size();
        DoubleArray capletExpiries = DoubleArray.of(nCaplets, n -> directDefinition.getDayCount().relativeYearFraction(
                calibrationDate, cap.getCapletFloorletPeriods().get(n).getFixingDateTime().toLocalDate()));

        // create the caplet nodes and the initial guess of the node volatilities
        Triple<DoubleArray, DoubleArray, DoubleArray> capletNodes;
        DoubleArray initialVols = DoubleArray.copyOf(volList);
        if (directDefinition.getShiftCurve().isPresent()) {
            // with a shift curve the result is a shifted Black surface
            metadata = Surfaces.blackVolatilityByExpiryStrike(
                    directDefinition.getName().getName(), directDefinition.getDayCount());
            Curve shiftCurve = directDefinition.getShiftCurve().get();
            if (capFloorData.getDataType().equals(ValueType.NORMAL_VOLATILITY)) {
                // rescale normal volatilities by (forward of the final period + shift)
                initialVols = DoubleArray.of(capList.size(), n -> volList.get(n) /
                        (ratesProvider.iborIndexRates(index)
                                .rate(capList.get(n).getFinalPeriod().getIborRate().getObservation()) +
                                shiftCurve.yValue(timeList.get(n))));
            }
            InterpolatedNodalSurface capVolSurface = InterpolatedNodalSurface.of(
                    metadata, DoubleArray.copyOf(timeList), DoubleArray.copyOf(strikeList), initialVols, INTERPOLATOR);
            capletNodes = createCapletNodes(capVolSurface, capletExpiries, strikes, shiftCurve);
            volatilitiesFunction = createShiftedBlackVolatilitiesFunction(index, calibrationDateTime, shiftCurve);
        } else {
            InterpolatedNodalSurface capVolSurface = InterpolatedNodalSurface.of(
                    metadata, DoubleArray.copyOf(timeList), DoubleArray.copyOf(strikeList), initialVols, INTERPOLATOR);
            capletNodes = createCapletNodes(capVolSurface, capletExpiries, strikes);
        }
        InterpolatedNodalSurface baseSurface = InterpolatedNodalSurface.of(
                metadata, capletNodes.getFirst(), capletNodes.getSecond(), capletNodes.getThird(), INTERPOLATOR);

        // fit the node volatilities to the cap prices
        DoubleMatrix penaltyMatrix = directDefinition.computePenaltyMatrix(strikes, capletExpiries);
        LeastSquareWithPenaltyResults res = solver.solve(
                DoubleArray.copyOf(priceList),
                DoubleArray.copyOf(errorList),
                getPriceFunction(capList, ratesProvider, volatilitiesFunction, baseSurface),
                getJacobianFunction(capList, ratesProvider, volatilitiesFunction, baseSurface),
                capletNodes.getThird(),
                penaltyMatrix,
                POSITIVE);
        InterpolatedNodalSurface resSurface = InterpolatedNodalSurface.of(
                metadata,
                capletNodes.getFirst(),
                capletNodes.getSecond(),
                res.getFitParameters(),
                directDefinition.getInterpolator());
        return IborCapletFloorletVolatilityCalibrationResult.ofLeastSquare(
                volatilitiesFunction.apply(resSurface), res.getChiSq());
    }

    // creates the function wrapping a surface into shifted Black volatilities for the given shift curve
    private Function<Surface, IborCapletFloorletVolatilities> createShiftedBlackVolatilitiesFunction(
            IborIndex index,
            ZonedDateTime calibrationDateTime,
            Curve shiftCurve) {

        return s -> ShiftedBlackIborCapletFloorletExpiryStrikeVolatilities.of(index, calibrationDateTime, s, shiftCurve);
    }

    // creates the node grid (expiries, strikes, initial volatilities), one node per strike for each
    // caplet expiry; the initial volatility of a node is read from the cap volatility surface
    private Triple<DoubleArray, DoubleArray, DoubleArray> createCapletNodes(
            InterpolatedNodalSurface capVolSurface,
            DoubleArray capletExpiries,
            DoubleArray strikes) {

        List<Double> timeCapletList = new ArrayList<>();
        List<Double> strikeCapletList = new ArrayList<>();
        List<Double> volCapletList = new ArrayList<>();
        int nTimes = capletExpiries.size();
        int nStrikes = strikes.size();
        for (int i = 0; i < nTimes; ++i) {
            double expiry = capletExpiries.get(i);
            timeCapletList.addAll(DoubleArray.filled(nStrikes, expiry).toList());
            strikeCapletList.addAll(strikes.toList());
            volCapletList.addAll(
                    DoubleArray.of(nStrikes, n -> capVolSurface.zValue(expiry, strikes.get(n))).toList());
        }
        return Triple.of(
                DoubleArray.copyOf(timeCapletList),
                DoubleArray.copyOf(strikeCapletList),
                DoubleArray.copyOf(volCapletList));
    }

    // same as above, but the strike coordinate of every node is moved by the value of the
    // shift curve at the node expiry, and the initial volatility is read at the moved strike
    private Triple<DoubleArray, DoubleArray, DoubleArray> createCapletNodes(
            InterpolatedNodalSurface capVolSurface,
            DoubleArray capletExpiries,
            DoubleArray strikes,
            Curve shiftCurve) {

        List<Double> timeCapletList = new ArrayList<>();
        List<Double> strikeCapletList = new ArrayList<>();
        List<Double> volCapletList = new ArrayList<>();
        int nTimes = capletExpiries.size();
        int nStrikes = strikes.size();
        for (int i = 0; i < nTimes; ++i) {
            double expiry = capletExpiries.get(i);
            double shift = shiftCurve.yValue(expiry);
            timeCapletList.addAll(DoubleArray.filled(nStrikes, expiry).toList());
            strikeCapletList.addAll(strikes.plus(shift).toList());
            volCapletList.addAll(
                    DoubleArray.of(nStrikes, n -> capVolSurface.zValue(expiry, strikes.get(n) + shift)).toList());
        }
        return Triple.of(
                DoubleArray.copyOf(timeCapletList),
                DoubleArray.copyOf(strikeCapletList),
                DoubleArray.copyOf(volCapletList));
    }

    // creates the function mapping trial node volatilities to the present value of each cap
    private Function<DoubleArray, DoubleArray> getPriceFunction(
            List<ResolvedIborCapFloorLeg> capList,
            RatesProvider ratesProvider,
            Function<Surface, IborCapletFloorletVolatilities> volatilitiesFunction,
            InterpolatedNodalSurface baseSurface) {

        int nCaps = capList.size();
        return capletVols -> {
            IborCapletFloorletVolatilities newVols = volatilitiesFunction.apply(baseSurface.withZValues(capletVols));
            return DoubleArray.of(
                    nCaps,
                    n -> getLegPricer().presentValue(capList.get(n), ratesProvider, newVols).getAmount());
        };
    }

    // creates the function mapping trial node volatilities to the matrix of cap price
    // sensitivities, one row per cap and one column per node
    private Function<DoubleArray, DoubleMatrix> getJacobianFunction(
            List<ResolvedIborCapFloorLeg> capList,
            RatesProvider ratesProvider,
            Function<Surface, IborCapletFloorletVolatilities> volatilitiesFunction,
            InterpolatedNodalSurface baseSurface) {

        int nCaps = capList.size();
        int nNodes = baseSurface.getParameterCount();
        return capletVols -> {
            IborCapletFloorletVolatilities newVols = volatilitiesFunction.apply(baseSurface.withZValues(capletVols));
            return DoubleMatrix.ofArrayObjects(
                    nCaps,
                    nNodes,
                    n -> newVols.parameterSensitivity(
                            getLegPricer()
                                    .presentValueSensitivityModelParamsVolatility(capList.get(n), ratesProvider, newVols)
                                    .build())
                            .getSensitivities()
                            .get(0)
                            .getSensitivity());
        };
    }
}

