/*
 * Decompiled with CFR 0.152.
 */
package com.opengamma.strata.pricer.sensitivity;

import com.google.common.collect.ImmutableList;
import com.google.common.primitives.Doubles;
import com.opengamma.strata.basics.currency.Currency;
import com.opengamma.strata.basics.index.Index;
import com.opengamma.strata.basics.index.PriceIndex;
import com.opengamma.strata.basics.index.RateIndex;
import com.opengamma.strata.collect.Guavate;
import com.opengamma.strata.collect.array.DoubleArray;
import com.opengamma.strata.collect.array.DoubleMatrix;
import com.opengamma.strata.collect.tuple.Pair;
import com.opengamma.strata.data.MarketDataName;
import com.opengamma.strata.market.curve.Curve;
import com.opengamma.strata.market.curve.CurveName;
import com.opengamma.strata.market.curve.LegalEntityGroup;
import com.opengamma.strata.market.curve.ParallelShiftedCurve;
import com.opengamma.strata.market.curve.RepoGroup;
import com.opengamma.strata.market.param.CrossGammaParameterSensitivities;
import com.opengamma.strata.market.param.CrossGammaParameterSensitivity;
import com.opengamma.strata.market.param.CurrencyParameterSensitivities;
import com.opengamma.strata.market.param.CurrencyParameterSensitivity;
import com.opengamma.strata.math.impl.differentiation.FiniteDifferenceType;
import com.opengamma.strata.math.impl.differentiation.VectorFieldFirstOrderDifferentiator;
import com.opengamma.strata.pricer.DiscountFactors;
import com.opengamma.strata.pricer.SimpleDiscountFactors;
import com.opengamma.strata.pricer.ZeroRateDiscountFactors;
import com.opengamma.strata.pricer.ZeroRatePeriodicDiscountFactors;
import com.opengamma.strata.pricer.bond.ImmutableLegalEntityDiscountingProvider;
import com.opengamma.strata.pricer.bond.LegalEntityDiscountingProvider;
import com.opengamma.strata.pricer.rate.ImmutableRatesProvider;
import com.opengamma.strata.pricer.rate.RatesProvider;
import java.time.LocalDate;
import java.util.ArrayList;
import java.util.HashMap;
import java.util.List;
import java.util.Map;
import java.util.function.Function;
import java.util.stream.IntStream;

/**
 * Computes gamma-type sensitivities by finite-difference differentiation of first-order parameter sensitivities.
 * <p>
 * Each calculation bumps the parameters of a curve and rebuilds the provider around the bumped curve. It then
 * re-runs the supplied sensitivity function and differentiates the resulting sensitivity vector with a
 * {@link VectorFieldFirstOrderDifferentiator}.
 */
public final class CurveGammaCalculator {

    /**
     * Default calculator, using forward finite differences with a shift of {@code 1.0E-4}.
     */
    public static final CurveGammaCalculator DEFAULT = new CurveGammaCalculator(FiniteDifferenceType.FORWARD, 1.0E-4);

    /** The finite-difference differentiator applied to the sensitivity functions. */
    private final VectorFieldFirstOrderDifferentiator fd;

    /**
     * Obtains a calculator using forward finite differences.
     *
     * @param shift  the finite-difference shift
     * @return the calculator
     */
    public static CurveGammaCalculator ofForwardDifference(double shift) {
        return new CurveGammaCalculator(FiniteDifferenceType.FORWARD, shift);
    }

    /**
     * Obtains a calculator using central finite differences.
     *
     * @param shift  the finite-difference shift
     * @return the calculator
     */
    public static CurveGammaCalculator ofCentralDifference(double shift) {
        return new CurveGammaCalculator(FiniteDifferenceType.CENTRAL, shift);
    }

    /**
     * Obtains a calculator using backward finite differences.
     *
     * @param shift  the finite-difference shift
     * @return the calculator
     */
    public static CurveGammaCalculator ofBackwardDifference(double shift) {
        return new CurveGammaCalculator(FiniteDifferenceType.BACKWARD, shift);
    }

    private CurveGammaCalculator(FiniteDifferenceType fdType, double shift) {
        this.fd = new VectorFieldFirstOrderDifferentiator(fdType, shift);
    }

    //-------------------------------------------------------------------------
    /**
     * Computes intra-curve cross gamma for a rates provider.
     * <p>
     * Every discount curve, and every rate or price index curve, is processed. For each curve whose name and
     * currency appear in the base first-order sensitivity, that sensitivity is differentiated with respect to the
     * parameters of the same curve. If the curve itself has no sensitivity but splits into several underlying
     * curves, each underlying curve that has a sensitivity is processed instead.
     *
     * @param ratesProvider  the rates provider
     * @param sensitivitiesFn  the function computing the first-order parameter sensitivities
     * @return the intra-curve cross gamma
     */
    public CrossGammaParameterSensitivities calculateCrossGammaIntraCurve(
            RatesProvider ratesProvider,
            Function<ImmutableRatesProvider, CurrencyParameterSensitivities> sensitivitiesFn) {

        ImmutableRatesProvider immProv = ratesProvider.toImmutableRatesProvider();
        // the base delta is only used to decide which curves have a sensitivity to differentiate
        CurrencyParameterSensitivities baseDelta = sensitivitiesFn.apply(immProv);
        return gammaForRatesCurves(
                immProv,
                baseDelta,
                (bumpedCurve, ccy, fn) -> computeGammaForCurve(bumpedCurve, ccy, fn, sensitivitiesFn));
    }

    /**
     * Computes intra-curve cross gamma for a legal entity discounting provider.
     * <p>
     * Issuer curves and repo curves are processed with the same rule as the rates provider variant. A curve
     * with a sensitivity is differentiated against itself. Otherwise each split underlying curve with a
     * sensitivity is differentiated against itself.
     *
     * @param ratesProvider  the legal entity discounting provider
     * @param sensitivitiesFn  the function computing the first-order parameter sensitivities
     * @return the intra-curve cross gamma
     * @throws IllegalArgumentException if a curve is held in an unsupported {@code DiscountFactors} type
     */
    public CrossGammaParameterSensitivities calculateCrossGammaIntraCurve(
            LegalEntityDiscountingProvider ratesProvider,
            Function<ImmutableLegalEntityDiscountingProvider, CurrencyParameterSensitivities> sensitivitiesFn) {

        LocalDate valuationDate = ratesProvider.getValuationDate();
        ImmutableLegalEntityDiscountingProvider immProv = ratesProvider.toImmutableLegalEntityDiscountingProvider();
        // the base delta is only used to decide which curves have a sensitivity to differentiate
        CurrencyParameterSensitivities baseDelta = sensitivitiesFn.apply(immProv);
        // the sensitivity of each bump is looked up under the name of the curve being bumped
        CurveGammaFn<ImmutableLegalEntityDiscountingProvider> gammaFn = (bumpedCurve, ccy, fn) ->
                computeGammaForCurve(bumpedCurve.getName(), bumpedCurve, ccy, fn, sensitivitiesFn);

        CrossGammaParameterSensitivities result = CrossGammaParameterSensitivities.empty();
        // issuer curves
        for (Map.Entry<Pair<LegalEntityGroup, Currency>, DiscountFactors> entry
                : immProv.getIssuerCurves().entrySet()) {
            Pair<LegalEntityGroup, Currency> legCcy = entry.getKey();
            Currency currency = legCcy.getSecond();
            Function<Curve, ImmutableLegalEntityDiscountingProvider> providerFn =
                    c -> replaceIssuerCurve(immProv, legCcy, DiscountFactors.of(currency, valuationDate, c));
            result = addGammaForCurve(result, getCurve(entry.getValue()), currency, baseDelta, providerFn, gammaFn);
        }
        // repo curves
        for (Map.Entry<Pair<RepoGroup, Currency>, DiscountFactors> entry : immProv.getRepoCurves().entrySet()) {
            Pair<RepoGroup, Currency> rgCcy = entry.getKey();
            Currency currency = rgCcy.getSecond();
            Function<Curve, ImmutableLegalEntityDiscountingProvider> providerFn =
                    c -> replaceRepoCurve(immProv, rgCcy, DiscountFactors.of(currency, valuationDate, c));
            result = addGammaForCurve(result, getCurve(entry.getValue()), currency, baseDelta, providerFn, gammaFn);
        }
        return result;
    }

    /**
     * Computes cross-curve gamma for a rates provider.
     * <p>
     * For each curve sensitivity in the base first-order sensitivity, this computes one gamma block per curve of
     * the provider. Each block differentiates that base sensitivity with respect to the parameters of a discount
     * curve or rate/price index curve, or of an underlying curve when the curve itself has no sensitivity.
     * The blocks are then concatenated column-wise into one cross gamma per base sensitivity.
     *
     * @param ratesProvider  the rates provider
     * @param sensitivitiesFn  the function computing the first-order parameter sensitivities
     * @return the cross-curve gamma
     */
    public CrossGammaParameterSensitivities calculateCrossGammaCrossCurve(
            RatesProvider ratesProvider,
            Function<ImmutableRatesProvider, CurrencyParameterSensitivities> sensitivitiesFn) {

        ImmutableRatesProvider immProv = ratesProvider.toImmutableRatesProvider();
        CurrencyParameterSensitivities baseDelta = sensitivitiesFn.apply(immProv);
        CrossGammaParameterSensitivities result = CrossGammaParameterSensitivities.empty();
        for (CurrencyParameterSensitivity baseDeltaSingle : baseDelta.getSensitivities()) {
            CrossGammaParameterSensitivities blocks = gammaForRatesCurves(
                    immProv,
                    baseDelta,
                    (bumpedCurve, ccy, fn) -> computeGammaForCurve(baseDeltaSingle, bumpedCurve, fn, sensitivitiesFn));
            result = result.combinedWith(combineSensitivities(baseDeltaSingle, blocks));
        }
        return result;
    }

    // Returns the currency of a rate or price index; other index types are rejected.
    private Currency getCurrency(Index index) {
        if (index instanceof RateIndex) {
            return ((RateIndex) index).getCurrency();
        }
        if (index instanceof PriceIndex) {
            return ((PriceIndex) index).getCurrency();
        }
        throw new IllegalArgumentException("unsupported index");
    }

    //-------------------------------------------------------------------------
    /**
     * Computes one gamma block for a curve, given the function that rebuilds the provider around a bumped curve.
     *
     * @param <P>  the provider type
     */
    @FunctionalInterface
    private interface CurveGammaFn<P> {
        CrossGammaParameterSensitivity apply(Curve curve, Currency currency, Function<Curve, P> providerFn);
    }

    // Gamma for every discount curve and every rate/price index curve of the provider, starting from empty.
    private CrossGammaParameterSensitivities gammaForRatesCurves(
            ImmutableRatesProvider immProv,
            CurrencyParameterSensitivities baseDelta,
            CurveGammaFn<ImmutableRatesProvider> gammaFn) {

        CrossGammaParameterSensitivities result = CrossGammaParameterSensitivities.empty();
        // discount curves
        for (Map.Entry<Currency, Curve> entry : immProv.getDiscountCurves().entrySet()) {
            Currency currency = entry.getKey();
            Function<Curve, ImmutableRatesProvider> providerFn =
                    c -> immProv.toBuilder().discountCurve(currency, c).build();
            result = addGammaForCurve(result, entry.getValue(), currency, baseDelta, providerFn, gammaFn);
        }
        // index curves; only rate and price indices are handled, matching getCurrency
        for (Map.Entry<Index, Curve> entry : immProv.getIndexCurves().entrySet()) {
            Index index = entry.getKey();
            if (index instanceof RateIndex || index instanceof PriceIndex) {
                Currency currency = getCurrency(index);
                Function<Curve, ImmutableRatesProvider> providerFn =
                        c -> immProv.toBuilder().indexCurve(index, c).build();
                result = addGammaForCurve(result, entry.getValue(), currency, baseDelta, providerFn, gammaFn);
            }
        }
        return result;
    }

    // Adds the gamma of a curve, or of each of its underlying curves, to the accumulated result.
    private <P> CrossGammaParameterSensitivities addGammaForCurve(
            CrossGammaParameterSensitivities result,
            Curve curve,
            Currency currency,
            CurrencyParameterSensitivities baseDelta,
            Function<Curve, P> providerFn,
            CurveGammaFn<P> gammaFn) {

        if (baseDelta.findSensitivity(curve.getName(), currency).isPresent()) {
            return result.combinedWith(gammaFn.apply(curve, currency, providerFn));
        }
        ImmutableList<Curve> curves = curve.split();
        int nCurves = curves.size();
        if (nCurves <= 1) {
            // no sensitivity to the curve and no underlying curves to fall back to
            return result;
        }
        CrossGammaParameterSensitivities combined = result;
        for (int i = 0; i < nCurves; ++i) {
            int currentIndex = i;  // effectively-final copy for the lambda below
            Curve underlyingCurve = curves.get(currentIndex);
            if (baseDelta.findSensitivity(underlyingCurve.getName(), currency).isPresent()) {
                // a bumped underlying curve is put back into the combined curve before rebuilding the provider
                Function<Curve, P> underlyingProviderFn =
                        c -> providerFn.apply(curve.withUnderlyingCurve(currentIndex, c));
                combined = combined.combinedWith(gammaFn.apply(underlyingCurve, currency, underlyingProviderFn));
            }
        }
        return combined;
    }

    //-------------------------------------------------------------------------
    /**
     * Differentiates the sensitivity to a curve with respect to the parameters of that same curve.
     *
     * @param curve  the curve whose parameters are bumped
     * @param sensitivityCurrency  the currency of the sensitivity
     * @param ratesProviderFn  the function rebuilding the provider around a bumped curve
     * @param sensitivitiesFn  the function computing the first-order parameter sensitivities
     * @return the gamma block, square in the curve parameters
     */
    CrossGammaParameterSensitivity computeGammaForCurve(
            Curve curve,
            Currency sensitivityCurrency,
            Function<Curve, ImmutableRatesProvider> ratesProviderFn,
            Function<ImmutableRatesProvider, CurrencyParameterSensitivities> sensitivitiesFn) {

        Function<DoubleArray, DoubleArray> function = t -> {
            Curve newCurve = replaceParameters(curve, t);
            ImmutableRatesProvider newRates = ratesProviderFn.apply(newCurve);
            CurrencyParameterSensitivities sensiMulti = sensitivitiesFn.apply(newRates);
            return sensiMulti.getSensitivity(newCurve.getName(), sensitivityCurrency).getSensitivity();
        };
        DoubleMatrix sensi = differentiateAtCurve(curve, function);
        return CrossGammaParameterSensitivity.of(
                curve.getName(),
                IntStream.range(0, curve.getParameterCount())
                        .mapToObj(i -> curve.getParameterMetadata(i))
                        .collect(Guavate.toImmutableList()),
                sensitivityCurrency,
                sensi);
    }

    /**
     * Differentiates one base sensitivity with respect to the parameters of another curve.
     *
     * @param baseDeltaSingle  the base sensitivity to differentiate
     * @param curve  the curve whose parameters are bumped
     * @param ratesProviderFn  the function rebuilding the provider around a bumped curve
     * @param sensitivitiesFn  the function computing the first-order parameter sensitivities
     * @return the gamma block, with rows for the base sensitivity and columns for the bumped curve
     */
    CrossGammaParameterSensitivity computeGammaForCurve(
            CurrencyParameterSensitivity baseDeltaSingle,
            Curve curve,
            Function<Curve, ImmutableRatesProvider> ratesProviderFn,
            Function<ImmutableRatesProvider, CurrencyParameterSensitivities> sensitivitiesFn) {

        Function<DoubleArray, DoubleArray> function = t -> {
            Curve newCurve = replaceParameters(curve, t);
            ImmutableRatesProvider newRates = ratesProviderFn.apply(newCurve);
            CurrencyParameterSensitivities sensiMulti = sensitivitiesFn.apply(newRates);
            return sensiMulti.getSensitivity(baseDeltaSingle.getMarketDataName(), baseDeltaSingle.getCurrency())
                    .getSensitivity();
        };
        DoubleMatrix sensi = differentiateAtCurve(curve, function);
        return CrossGammaParameterSensitivity.of(
                baseDeltaSingle.getMarketDataName(),
                baseDeltaSingle.getParameterMetadata(),
                curve.getName(),
                IntStream.range(0, curve.getParameterCount())
                        .mapToObj(i -> curve.getParameterMetadata(i))
                        .collect(Guavate.toImmutableList()),
                baseDeltaSingle.getCurrency(),
                sensi);
    }

    // Concatenates the per-curve gamma blocks of one base sensitivity column-wise, row by row.
    private CrossGammaParameterSensitivity combineSensitivities(
            CurrencyParameterSensitivity baseDeltaSingle,
            CrossGammaParameterSensitivities blockCrossGamma) {

        ImmutableList<CrossGammaParameterSensitivity> blocks = blockCrossGamma.getSensitivities();
        int nRows = baseDeltaSingle.getParameterCount();
        double[][] valuesTotal = new double[nRows][];
        for (int i = 0; i < nRows; ++i) {
            List<Double> row = new ArrayList<>();
            for (CrossGammaParameterSensitivity block : blocks) {
                row.addAll(block.getSensitivity().row(i).toList());
            }
            valuesTotal[i] = Doubles.toArray(row);
        }
        // The column order is the first order entry of each block. When there are no rows there are no
        // columns either, so the order is left empty in that case.
        return CrossGammaParameterSensitivity.of(
                baseDeltaSingle.getMarketDataName(),
                baseDeltaSingle.getParameterMetadata(),
                nRows == 0
                        ? ImmutableList.of()
                        : blocks.stream().map(block -> block.getOrder().get(0)).collect(Guavate.toImmutableList()),
                baseDeltaSingle.getCurrency(),
                DoubleMatrix.ofUnsafe(valuesTotal));
    }

    //-------------------------------------------------------------------------
    /**
     * Computes the "semi-parallel" gamma of a curve.
     * <p>
     * The parameter sensitivity of the curve is differentiated with respect to an absolute parallel shift of the
     * whole curve, evaluated at a zero shift. The result is returned as a parameter sensitivity of the curve.
     *
     * @param curve  the curve
     * @param curveCurrency  the currency of the resulting sensitivity
     * @param sensitivitiesFn  the function computing the parameter sensitivity of a (shifted) curve
     * @return the semi-parallel gamma
     */
    public CurrencyParameterSensitivity calculateSemiParallelGamma(
            Curve curve,
            Currency curveCurrency,
            Function<Curve, CurrencyParameterSensitivity> sensitivitiesFn) {

        Delta deltaShift = new Delta(curve, sensitivitiesFn);
        Function<DoubleArray, DoubleMatrix> gammaFn = fd.differentiate(deltaShift);
        // a single-element input: the parallel shift, differentiated at 0
        DoubleArray gamma = gammaFn.apply(DoubleArray.filled(1)).column(0);
        return curve.createParameterSensitivity(curveCurrency, gamma);
    }

    // Differentiates the function by finite difference at the curve's current parameter values.
    private DoubleMatrix differentiateAtCurve(Curve curve, Function<DoubleArray, DoubleArray> function) {
        DoubleArray parameters = DoubleArray.of(curve.getParameterCount(), i -> curve.getParameter(i));
        return fd.differentiate(function).apply(parameters);
    }

    // Returns a copy of the curve whose parameters are replaced by the given values.
    private Curve replaceParameters(Curve curve, DoubleArray newParameters) {
        return curve.withPerturbation((i, v, m) -> newParameters.get(i));
    }

    // Extracts the curve from the supported DiscountFactors implementations.
    private Curve getCurve(DiscountFactors discountFactors) {
        if (discountFactors instanceof SimpleDiscountFactors) {
            return ((SimpleDiscountFactors) discountFactors).getCurve();
        }
        if (discountFactors instanceof ZeroRateDiscountFactors) {
            return ((ZeroRateDiscountFactors) discountFactors).getCurve();
        }
        if (discountFactors instanceof ZeroRatePeriodicDiscountFactors) {
            return ((ZeroRatePeriodicDiscountFactors) discountFactors).getCurve();
        }
        throw new IllegalArgumentException("Unsupported DiscountFactors type");
    }

    // Legal entity variant: the sensitivity is looked up under the supplied curve name.
    private CrossGammaParameterSensitivity computeGammaForCurve(
            CurveName curveName,
            Curve curve,
            Currency sensitivityCurrency,
            Function<Curve, ImmutableLegalEntityDiscountingProvider> ratesProviderFn,
            Function<ImmutableLegalEntityDiscountingProvider, CurrencyParameterSensitivities> sensitivitiesFn) {

        Function<DoubleArray, DoubleArray> function = t -> {
            Curve newCurve = replaceParameters(curve, t);
            ImmutableLegalEntityDiscountingProvider newRates = ratesProviderFn.apply(newCurve);
            CurrencyParameterSensitivities sensiMulti = sensitivitiesFn.apply(newRates);
            return sensiMulti.getSensitivity(curveName, sensitivityCurrency).getSensitivity();
        };
        DoubleMatrix sensi = differentiateAtCurve(curve, function);
        return CrossGammaParameterSensitivity.of(
                curveName,
                IntStream.range(0, curve.getParameterCount())
                        .mapToObj(i -> curve.getParameterMetadata(i))
                        .collect(Guavate.toImmutableList()),
                sensitivityCurrency,
                sensi);
    }

    // Returns a copy of the provider with one issuer curve replaced.
    private ImmutableLegalEntityDiscountingProvider replaceIssuerCurve(
            ImmutableLegalEntityDiscountingProvider ratesProvider,
            Pair<LegalEntityGroup, Currency> legCcy,
            DiscountFactors discountFactors) {

        Map<Pair<LegalEntityGroup, Currency>, DiscountFactors> curves = new HashMap<>(ratesProvider.getIssuerCurves());
        curves.put(legCcy, discountFactors);
        return ratesProvider.toBuilder().issuerCurves(curves).build();
    }

    // Returns a copy of the provider with one repo curve replaced.
    private ImmutableLegalEntityDiscountingProvider replaceRepoCurve(
            ImmutableLegalEntityDiscountingProvider ratesProvider,
            Pair<RepoGroup, Currency> rgCcy,
            DiscountFactors discountFactors) {

        Map<Pair<RepoGroup, Currency>, DiscountFactors> curves = new HashMap<>(ratesProvider.getRepoCurves());
        curves.put(rgCcy, discountFactors);
        return ratesProvider.toBuilder().repoCurves(curves).build();
    }

    //-------------------------------------------------------------------------
    /**
     * Maps a one-element array holding an absolute parallel shift to the parameter sensitivity of the
     * correspondingly shifted curve.
     */
    static class Delta implements Function<DoubleArray, DoubleArray> {
        private final Curve curve;
        private final Function<Curve, CurrencyParameterSensitivity> sensitivitiesFn;

        Delta(Curve curve, Function<Curve, CurrencyParameterSensitivity> sensitivitiesFn) {
            this.curve = curve;
            this.sensitivitiesFn = sensitivitiesFn;
        }

        @Override
        public DoubleArray apply(DoubleArray s) {
            double shift = s.get(0);
            Curve curveBumped = ParallelShiftedCurve.absolute(curve, shift);
            return sensitivitiesFn.apply(curveBumped).getSensitivity();
        }
    }
}

