/*
 * Decompiled with CFR 0.152.
 */
package com.opengamma.strata.pricer.capfloor;

import com.google.common.collect.ImmutableList;
import com.opengamma.strata.basics.ReferenceData;
import com.opengamma.strata.basics.index.IborIndex;
import com.opengamma.strata.collect.ArgChecker;
import com.opengamma.strata.collect.array.DoubleArray;
import com.opengamma.strata.collect.array.DoubleMatrix;
import com.opengamma.strata.collect.tuple.Pair;
import com.opengamma.strata.market.ValueType;
import com.opengamma.strata.market.curve.ConstantCurve;
import com.opengamma.strata.market.curve.Curve;
import com.opengamma.strata.market.curve.CurveMetadata;
import com.opengamma.strata.market.curve.InterpolatedNodalCurve;
import com.opengamma.strata.market.curve.interpolator.CurveExtrapolator;
import com.opengamma.strata.market.curve.interpolator.CurveInterpolator;
import com.opengamma.strata.market.curve.interpolator.CurveInterpolators;
import com.opengamma.strata.market.param.CurrencyParameterSensitivity;
import com.opengamma.strata.math.impl.linearalgebra.CholeskyDecompositionCommons;
import com.opengamma.strata.math.impl.minimization.PositiveOrZero;
import com.opengamma.strata.math.impl.statistics.leastsquare.LeastSquareWithPenaltyResults;
import com.opengamma.strata.math.impl.statistics.leastsquare.NonLinearLeastSquareWithPenalty;
import com.opengamma.strata.math.linearalgebra.Decomposition;
import com.opengamma.strata.pricer.capfloor.BlackIborCapletFloorletExpiryFlatVolatilities;
import com.opengamma.strata.pricer.capfloor.DirectIborCapletFloorletFlatVolatilityDefinition;
import com.opengamma.strata.pricer.capfloor.IborCapletFloorletVolatilities;
import com.opengamma.strata.pricer.capfloor.IborCapletFloorletVolatilityCalibrationResult;
import com.opengamma.strata.pricer.capfloor.IborCapletFloorletVolatilityCalibrator;
import com.opengamma.strata.pricer.capfloor.IborCapletFloorletVolatilityDefinition;
import com.opengamma.strata.pricer.capfloor.NormalIborCapletFloorletExpiryFlatVolatilities;
import com.opengamma.strata.pricer.capfloor.VolatilityIborCapFloorLegPricer;
import com.opengamma.strata.pricer.option.RawOptionData;
import com.opengamma.strata.pricer.rate.RatesProvider;
import com.opengamma.strata.product.capfloor.IborCapletFloorletPeriod;
import com.opengamma.strata.product.capfloor.ResolvedIborCapFloorLeg;
import java.time.LocalDate;
import java.time.Period;
import java.time.ZonedDateTime;
import java.time.temporal.TemporalAmount;
import java.util.ArrayList;
import java.util.List;
import java.util.function.Function;

/**
 * Calibrates a caplet/floorlet volatility curve (in expiry time) to cap flat volatility quotes.
 * <p>
 * The input data must contain exactly one strike. For each quoted expiry a cap is built and priced
 * with its own constant (flat) volatility to obtain a target price. Caplet volatilities are then
 * fitted to those prices with a penalized non-linear least-squares solver. The nodes of the
 * fitted curve are the caplet fixing times of the longest cap.
 */
public class DirectIborCapletFloorletFlatVolatilityCalibrator
extends IborCapletFloorletVolatilityCalibrator {
    /** Shared default instance: default leg pricer, solver epsilon 1e-8, standard reference data. */
    private static final DirectIborCapletFloorletFlatVolatilityCalibrator STANDARD =
            DirectIborCapletFloorletFlatVolatilityCalibrator.of(
                    VolatilityIborCapFloorLegPricer.DEFAULT, 1.0E-8, ReferenceData.standard());
    /** Constraint handed to the solver; presumably only accepts parameter vectors with no negative entry. */
    private static final Function<DoubleArray, Boolean> POSITIVE = new PositiveOrZero();
    /** Penalized least-squares solver backed by a Cholesky decomposition. */
    private final NonLinearLeastSquareWithPenalty solver;

    /**
     * Returns the shared default calibrator.
     *
     * @return the calibrator built from the default leg pricer, epsilon 1e-8 and standard reference data
     */
    public static DirectIborCapletFloorletFlatVolatilityCalibrator standard() {
        return STANDARD;
    }

    /**
     * Creates a calibrator.
     *
     * @param pricer  the cap/floor leg pricer used to compute target prices and sensitivities
     * @param epsilon  the epsilon passed to the least-squares solver (presumably a convergence tolerance)
     * @param referenceData  the reference data used to resolve dates and caps
     * @return the calibrator
     */
    public static DirectIborCapletFloorletFlatVolatilityCalibrator of(
            VolatilityIborCapFloorLegPricer pricer, double epsilon, ReferenceData referenceData) {
        return new DirectIborCapletFloorletFlatVolatilityCalibrator(pricer, epsilon, referenceData);
    }

    private DirectIborCapletFloorletFlatVolatilityCalibrator(
            VolatilityIborCapFloorLegPricer pricer, double epsilon, ReferenceData referenceData) {
        super(pricer, referenceData);
        this.solver = new NonLinearLeastSquareWithPenalty(new CholeskyDecompositionCommons(), epsilon);
    }

    /**
     * Calibrates caplet volatilities to single-strike cap flat volatility data.
     *
     * @param definition  the definition; must be a {@link DirectIborCapletFloorletFlatVolatilityDefinition}
     * @param calibrationDateTime  the calibration date-time; its date must equal the rates provider's valuation date
     * @param capFloorData  the cap volatility data, with exactly one strike and at least one expiry
     * @param ratesProvider  the rates provider
     * @return the calibrated volatilities together with the solver's chi-square
     * @throws IllegalArgumentException if any of the above preconditions fails, if an expiry has no
     *   finite volatility quote, or if the data type is neither Black nor normal volatility
     */
    @Override
    public IborCapletFloorletVolatilityCalibrationResult calibrate(
            IborCapletFloorletVolatilityDefinition definition,
            ZonedDateTime calibrationDateTime,
            RawOptionData capFloorData,
            RatesProvider ratesProvider) {

        ArgChecker.isTrue(ratesProvider.getValuationDate().equals(calibrationDateTime.toLocalDate()),
                "valuationDate of ratesProvider should be coherent to calibrationDateTime");
        ArgChecker.isTrue(definition instanceof DirectIborCapletFloorletFlatVolatilityDefinition,
                "definition should be DirectIborCapletFloorletFlatVolatilityDefinition");
        DirectIborCapletFloorletFlatVolatilityDefinition directDefinition =
                (DirectIborCapletFloorletFlatVolatilityDefinition) definition;
        DoubleArray strikes = capFloorData.getStrikes();
        ArgChecker.isTrue(strikes.size() == 1, "strike size should be 1");
        // Single strike, so it is the same for every expiry.
        double strike = strikes.get(0);

        IborIndex index = directDefinition.getIndex();
        LocalDate calibrationDate = calibrationDateTime.toLocalDate();
        LocalDate baseDate = index.getEffectiveDateOffset().adjust(calibrationDate, getReferenceData());
        // Caps start one index tenor after the base date, so the first period from the base date is
        // not part of any cap (presumably the usual quoted-cap convention -- TODO confirm).
        LocalDate startDate = baseDate.plus(index.getTenor());
        Function<Curve, IborCapletFloorletVolatilities> volatilitiesFunction =
                flatVolatilitiesFunction(directDefinition, calibrationDateTime, capFloorData);
        CurveMetadata metadata = directDefinition.createCurveMetadata(capFloorData);

        List<Period> expiries = capFloorData.getExpiries();
        int nExpiries = expiries.size();
        // Without at least one cap there is no longest cap to take caplet nodes from.
        ArgChecker.isTrue(nExpiries > 0, "capFloorData should contain at least one expiry");

        List<Double> timeList = new ArrayList<>();
        List<Double> volList = new ArrayList<>();
        List<ResolvedIborCapFloorLeg> capList = new ArrayList<>();
        List<Double> priceList = new ArrayList<>();
        List<Double> errorList = new ArrayList<>();
        // Unit errors (equal weighting) when the data carries no error matrix.
        DoubleMatrix errorMatrix = capFloorData.getError()
                .orElseGet(() -> DoubleMatrix.filled(nExpiries, 1, 1d));
        int[] startIndex = new int[nExpiries + 1];
        for (int i = 0; i < nExpiries; ++i) {
            LocalDate endDate = baseDate.plus(expiries.get(i));
            double volatilityForTime = capFloorData.getData().row(i).get(0);
            double errorForTime = errorMatrix.row(i).get(0);
            reduceRawData(directDefinition, ratesProvider, strike, volatilityForTime, errorForTime,
                    startDate, endDate, metadata, volatilitiesFunction,
                    timeList, volList, capList, priceList, errorList);
            startIndex[i + 1] = volList.size();
            // reduceRawData skips non-finite quotes; reject an expiry that produced no data point.
            ArgChecker.isTrue(startIndex[i + 1] > startIndex[i], "no valid option data for {}", expiries.get(i));
        }

        // Caplet nodes are the fixing times of the last (presumably longest) cap.
        ResolvedIborCapFloorLeg longestCap = capList.get(capList.size() - 1);
        List<IborCapletFloorletPeriod> capletPeriods = longestCap.getCapletFloorletPeriods();
        DoubleArray capletExpiries = DoubleArray.of(
                capletPeriods.size(),
                n -> directDefinition.getDayCount().relativeYearFraction(
                        calibrationDate, capletPeriods.get(n).getFixingDateTime().toLocalDate()));

        // Initial guess: cap vols linearly interpolated (in cap final fixing time) at each caplet node.
        InterpolatedNodalCurve capVolCurve = InterpolatedNodalCurve.of(
                metadata, DoubleArray.copyOf(timeList), DoubleArray.copyOf(volList), CurveInterpolators.LINEAR);
        Pair<DoubleArray, DoubleArray> capletNodes = createCapletNodes(capVolCurve, capletExpiries);
        InterpolatedNodalCurve baseCurve = InterpolatedNodalCurve.of(
                metadata, capletNodes.getFirst(), capletNodes.getSecond(), CurveInterpolators.LINEAR);

        DoubleMatrix penaltyMatrix = directDefinition.computePenaltyMatrix(capletExpiries);
        LeastSquareWithPenaltyResults res = solver.solve(
                DoubleArray.copyOf(priceList),
                DoubleArray.copyOf(errorList),
                getPriceFunction(capList, ratesProvider, volatilitiesFunction, baseCurve),
                getJacobianFunction(capList, ratesProvider, volatilitiesFunction, baseCurve),
                capletNodes.getSecond(),
                penaltyMatrix,
                POSITIVE);

        // The fitted curve uses the definition's interpolator and extrapolators rather than LINEAR.
        InterpolatedNodalCurve resCurve = InterpolatedNodalCurve.of(
                metadata,
                capletNodes.getFirst(),
                res.getFitParameters(),
                directDefinition.getInterpolator(),
                directDefinition.getExtrapolatorLeft(),
                directDefinition.getExtrapolatorRight());
        return IborCapletFloorletVolatilityCalibrationResult.ofLeastSquare(
                volatilitiesFunction.apply(resCurve), res.getChiSq());
    }

    /**
     * Evaluates the cap volatility curve at each caplet expiry.
     *
     * @param capVolCurve  the cap volatility curve
     * @param capletExpiries  the caplet expiry times
     * @return the pair (caplet expiry times, cap volatilities evaluated at those times)
     */
    private Pair<DoubleArray, DoubleArray> createCapletNodes(
            InterpolatedNodalCurve capVolCurve, DoubleArray capletExpiries) {
        return Pair.of(capletExpiries, capletExpiries.map(capVolCurve::yValue));
    }

    /**
     * Creates the model price function for the solver.
     * It maps caplet volatilities to the present values of the caps, in {@code capList} order.
     */
    private Function<DoubleArray, DoubleArray> getPriceFunction(
            List<ResolvedIborCapFloorLeg> capList,
            RatesProvider ratesProvider,
            Function<Curve, IborCapletFloorletVolatilities> volatilitiesFunction,
            InterpolatedNodalCurve baseCurve) {
        int nCaps = capList.size();
        return capletVols -> {
            IborCapletFloorletVolatilities newVols = volatilitiesFunction.apply(baseCurve.withYValues(capletVols));
            return DoubleArray.of(
                    nCaps, n -> getLegPricer().presentValue(capList.get(n), ratesProvider, newVols).getAmount());
        };
    }

    /**
     * Creates the Jacobian function for the solver.
     * Row {@code n} holds the sensitivity of cap {@code n}'s present value to each curve node.
     */
    private Function<DoubleArray, DoubleMatrix> getJacobianFunction(
            List<ResolvedIborCapFloorLeg> capList,
            RatesProvider ratesProvider,
            Function<Curve, IborCapletFloorletVolatilities> volatilitiesFunction,
            InterpolatedNodalCurve baseCurve) {
        int nCaps = capList.size();
        int nNodes = baseCurve.getParameterCount();
        return capletVols -> {
            IborCapletFloorletVolatilities newVols = volatilitiesFunction.apply(baseCurve.withYValues(capletVols));
            return DoubleMatrix.ofArrayObjects(
                    nCaps, nNodes, n -> nodeSensitivity(capList.get(n), ratesProvider, newVols));
        };
    }

    /**
     * Computes the present-value sensitivity of one cap to the volatility curve nodes.
     * It takes the first parameter-sensitivity entry; NOTE(review): this assumes the volatilities
     * expose a single curve, so that entry is the only one.
     */
    private DoubleArray nodeSensitivity(
            ResolvedIborCapFloorLeg cap, RatesProvider ratesProvider, IborCapletFloorletVolatilities vols) {
        return vols.parameterSensitivity(
                        getLegPricer().presentValueSensitivityModelParamsVolatility(cap, ratesProvider, vols).build())
                .getSensitivities().get(0).getSensitivity();
    }

    /**
     * Converts one flat cap volatility quote into a calibration target and appends it to the lists.
     * <p>
     * For a finite {@code volatility} this resolves the cap from {@code startDate} to {@code endDate}
     * and appends data to each list:
     * <ul>
     * <li>{@code timeList}: the cap's final fixing time;</li>
     * <li>{@code volList}: the quoted volatility;</li>
     * <li>{@code capList}: the resolved cap;</li>
     * <li>{@code priceList}: its present value under a constant curve at that volatility;</li>
     * <li>{@code errorList}: the error.</li>
     * </ul>
     * A non-finite volatility is skipped and nothing is appended.
     */
    private void reduceRawData(
            IborCapletFloorletVolatilityDefinition definition,
            RatesProvider ratesProvider,
            double strike,
            double volatility,
            double error,
            LocalDate startDate,
            LocalDate endDate,
            CurveMetadata metadata,
            Function<Curve, IborCapletFloorletVolatilities> volatilityFunction,
            List<Double> timeList,
            List<Double> volList,
            List<ResolvedIborCapFloorLeg> capList,
            List<Double> priceList,
            List<Double> errorList) {
        if (!Double.isFinite(volatility)) {
            // Missing or invalid quote: pricing with it would feed NaN/infinity into the solver.
            return;
        }
        ResolvedIborCapFloorLeg capFloor = definition.createCap(startDate, endDate, strike).resolve(getReferenceData());
        capList.add(capFloor);
        volList.add(volatility);
        ConstantCurve constVolCurve = ConstantCurve.of(metadata, volatility);
        IborCapletFloorletVolatilities vols = volatilityFunction.apply(constVolCurve);
        timeList.add(vols.relativeTime(capFloor.getFinalFixingDateTime()));
        priceList.add(getLegPricer().presentValue(capFloor, ratesProvider, vols).getAmount());
        errorList.add(error);
    }

    /**
     * Chooses the curve-to-volatilities factory from the data's value type.
     *
     * @throws IllegalArgumentException if the data type is neither Black nor normal volatility
     */
    private Function<Curve, IborCapletFloorletVolatilities> flatVolatilitiesFunction(
            IborCapletFloorletVolatilityDefinition definition,
            ZonedDateTime calibrationDateTime,
            RawOptionData capFloorData) {
        IborIndex index = definition.getIndex();
        if (capFloorData.getDataType().equals(ValueType.BLACK_VOLATILITY)) {
            return blackVolatilitiesFunction(index, calibrationDateTime);
        }
        if (capFloorData.getDataType().equals(ValueType.NORMAL_VOLATILITY)) {
            return normalVolatilitiesFunction(index, calibrationDateTime);
        }
        throw new IllegalArgumentException("Data type not supported");
    }

    /** Wraps a curve as Black expiry-flat caplet volatilities for the given index and date-time. */
    private Function<Curve, IborCapletFloorletVolatilities> blackVolatilitiesFunction(
            IborIndex index, ZonedDateTime calibrationDateTime) {
        return curve -> BlackIborCapletFloorletExpiryFlatVolatilities.of(index, calibrationDateTime, curve);
    }

    /** Wraps a curve as normal expiry-flat caplet volatilities for the given index and date-time. */
    private Function<Curve, IborCapletFloorletVolatilities> normalVolatilitiesFunction(
            IborIndex index, ZonedDateTime calibrationDateTime) {
        return curve -> NormalIborCapletFloorletExpiryFlatVolatilities.of(index, calibrationDateTime, curve);
    }
}

