/*
 * Decompiled with CFR 0.152.
 */
package com.opengamma.strata.pricer.credit;

import com.google.common.collect.ImmutableList;
import com.opengamma.strata.basics.ReferenceData;
import com.opengamma.strata.basics.currency.Currency;
import com.opengamma.strata.basics.date.BusinessDayAdjustment;
import com.opengamma.strata.basics.date.DayCount;
import com.opengamma.strata.collect.ArgChecker;
import com.opengamma.strata.collect.array.DoubleArray;
import com.opengamma.strata.collect.array.DoubleMatrix;
import com.opengamma.strata.collect.array.Matrix;
import com.opengamma.strata.data.MarketData;
import com.opengamma.strata.data.MarketDataId;
import com.opengamma.strata.market.curve.ConstantNodalCurve;
import com.opengamma.strata.market.curve.CurveInfoType;
import com.opengamma.strata.market.curve.CurveMetadata;
import com.opengamma.strata.market.curve.CurveName;
import com.opengamma.strata.market.curve.CurveParameterSize;
import com.opengamma.strata.market.curve.DepositIsdaCreditCurveNode;
import com.opengamma.strata.market.curve.InterpolatedNodalCurve;
import com.opengamma.strata.market.curve.IsdaCreditCurveDefinition;
import com.opengamma.strata.market.curve.IsdaCreditCurveNode;
import com.opengamma.strata.market.curve.JacobianCalibrationMatrix;
import com.opengamma.strata.market.curve.NodalCurve;
import com.opengamma.strata.market.curve.SwapIsdaCreditCurveNode;
import com.opengamma.strata.market.param.ParameterMetadata;
import com.opengamma.strata.market.param.UnitParameterSensitivities;
import com.opengamma.strata.market.param.UnitParameterSensitivity;
import com.opengamma.strata.math.impl.matrix.CommonsMatrixAlgebra;
import com.opengamma.strata.math.impl.matrix.MatrixAlgebra;
import com.opengamma.strata.math.impl.rootfinding.BracketRoot;
import com.opengamma.strata.math.impl.rootfinding.NewtonRaphsonSingleRootFinder;
import com.opengamma.strata.pricer.credit.IsdaCreditDiscountFactors;
import java.time.LocalDate;
import java.time.Period;
import java.time.temporal.TemporalAmount;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.List;
import java.util.function.Function;

/**
 * Calibrates an ISDA compliant discount curve to term deposit and swap quotes.
 * <p>
 * Term deposit nodes, which must all precede the swap nodes, are converted to zero rates in closed form.
 * Swap nodes are then fitted one at a time, in node order, so that each swap prices at par.
 * The calibrated curve is anchored at the spot date shared by all nodes; when the curve valuation date
 * differs from that spot date, the curve is re-based onto the valuation date.
 */
public final class IsdaCompliantDiscountCurveCalibrator {
    /**
     * The standard calibrator, using a root-finding accuracy of 1.0E-12.
     */
    public static final IsdaCompliantDiscountCurveCalibrator STANDARD = IsdaCompliantDiscountCurveCalibrator.of(1.0E-12);
    /** Matrix algebra used to invert the quote sensitivity matrix and to compose Jacobians. */
    private static final MatrixAlgebra MATRIX_ALGEBRA = new CommonsMatrixAlgebra();
    /** Brackets the root before the Newton-Raphson refinement when fitting a swap node. */
    private static final BracketRoot BRACKETER = new BracketRoot();
    /**
     * Half-width of the initial bracket used when a swap node's starting value is exactly zero.
     * <p>
     * The usual multiplicative seeds (0.8 and 1.25 times the guess) collapse to the single point
     * {@code [0, 0]} in that case, which is not a usable starting interval.
     */
    private static final double ZERO_GUESS_BRACKET_HALF_WIDTH = 1.0E-2;

    /** Root finder used to fit each swap node. */
    private final NewtonRaphsonSingleRootFinder rootFinder;

    /**
     * Obtains the standard calibrator.
     *
     * @return the calibrator with a root-finding accuracy of 1.0E-12
     */
    public static IsdaCompliantDiscountCurveCalibrator standard() {
        return STANDARD;
    }

    /**
     * Obtains a calibrator with the specified root-finding accuracy.
     *
     * @param accuracy  the accuracy passed to the Newton-Raphson root finder
     * @return the calibrator
     */
    public static IsdaCompliantDiscountCurveCalibrator of(double accuracy) {
        return new IsdaCompliantDiscountCurveCalibrator(accuracy);
    }

    private IsdaCompliantDiscountCurveCalibrator(double accuracy) {
        this.rootFinder = new NewtonRaphsonSingleRootFinder(accuracy);
    }

    /**
     * Calibrates the discount curve described by {@code curveDefinition} to the quotes in {@code marketData}.
     * <p>
     * The nodes are validated as follows:
     * <ul>
     * <li>there must be more than one node;</li>
     * <li>all nodes must share the same spot date;</li>
     * <li>node times must be strictly increasing, and the first must not be negative;</li>
     * <li>every term deposit node must come before any swap node.</li>
     * </ul>
     * When the definition requests it, the Jacobian of the curve parameters with respect to the
     * quotes is stored in the curve metadata.
     *
     * @param curveDefinition  the curve definition, supplying nodes, day count, currency and valuation date
     * @param marketData  the market data, supplying the node quotes and the snap date
     * @param refData  the reference data used to resolve node dates
     * @return the calibrated discount factors, based at the curve valuation date
     * @throws IllegalArgumentException if a node is neither a deposit nor a swap node
     */
    public IsdaCreditDiscountFactors calibrate(IsdaCreditCurveDefinition curveDefinition, MarketData marketData, ReferenceData refData) {
        List<? extends IsdaCreditCurveNode> curveNodes = curveDefinition.getCurveNodes();
        int nNodes = curveNodes.size();
        ArgChecker.isTrue(nNodes > 1, "the number of curve nodes must be greater than 1");
        LocalDate curveSnapDate = marketData.getValuationDate();
        LocalDate curveValuationDate = curveDefinition.getCurveValuationDate();
        DayCount curveDayCount = curveDefinition.getDayCount();
        boolean computeJacobian = curveDefinition.isComputeJacobian();
        BasicFixedLeg[] swapLeg = new BasicFixedLeg[nNodes];
        double[] termDepositYearFraction = new double[nNodes];
        double[] curveNodeTime = new double[nNodes];
        double[] rates = new double[nNodes];
        ImmutableList.Builder<ParameterMetadata> paramMetadata = ImmutableList.builder();
        int nTermDeposit = 0;
        LocalDate curveSpotDate = null;
        for (int i = 0; i < nNodes; ++i) {
            IsdaCreditCurveNode node = curveNodes.get(i);
            rates[i] = marketData.getValue(node.getObservableId());
            LocalDate adjMatDate = node.date(curveSnapDate, refData);
            paramMetadata.add(node.metadata(adjMatDate));
            LocalDate nodeSpotDate;
            if (node instanceof DepositIsdaCreditCurveNode) {
                DepositIsdaCreditCurveNode termDeposit = (DepositIsdaCreditCurveNode) node;
                nodeSpotDate = termDeposit.getSpotDateOffset().adjust(curveSnapDate, refData);
                curveNodeTime[i] = curveDayCount.relativeYearFraction(nodeSpotDate, adjMatDate);
                termDepositYearFraction[i] = termDeposit.getDayCount().relativeYearFraction(nodeSpotDate, adjMatDate);
                // deposits are solved in closed form below, so they must form a prefix of the node list
                ArgChecker.isTrue(nTermDeposit == i, "TermDepositCurveNode should not be after FixedIborSwapCurveNode");
                ++nTermDeposit;
            } else if (node instanceof SwapIsdaCreditCurveNode) {
                SwapIsdaCreditCurveNode swap = (SwapIsdaCreditCurveNode) node;
                nodeSpotDate = swap.getSpotDateOffset().adjust(curveSnapDate, refData);
                curveNodeTime[i] = curveDayCount.relativeYearFraction(nodeSpotDate, adjMatDate);
                swapLeg[i] = new BasicFixedLeg(
                        nodeSpotDate,
                        nodeSpotDate.plus(swap.getTenor()),
                        swap.getPaymentFrequency().getPeriod(),
                        swap.getDayCount(),
                        curveDayCount,
                        swap.getBusinessDayAdjustment(),
                        refData);
            } else {
                throw new IllegalArgumentException("unsupported curve node type");
            }
            if (i == 0) {
                ArgChecker.isTrue(curveNodeTime[i] >= 0.0, "the first node should be after curve spot date");
                curveSpotDate = nodeSpotDate;
            } else {
                ArgChecker.isTrue(curveNodeTime[i] > curveNodeTime[i - 1], "curve nodes should be ascending in terms of tenor");
                ArgChecker.isTrue(nodeSpotDate.equals(curveSpotDate), "spot lag should be common for all of the curve nodes");
            }
        }
        List<ParameterMetadata> parameterMetadata = paramMetadata.build();

        // convert each deposit quote R into the zero rate z satisfying exp(z * t) = 1 + R * tau
        double[] ratesMod = Arrays.copyOf(rates, nNodes);
        for (int i = 0; i < nTermDeposit; ++i) {
            ratesMod[i] = Math.log(1.0 + rates[i] * termDepositYearFraction[i]) / curveNodeTime[i];
        }
        // swap nodes start from their quoted rate and are then bootstrapped in node order
        InterpolatedNodalCurve curve = curveDefinition.curve(DoubleArray.ofUnsafe(curveNodeTime), DoubleArray.ofUnsafe(ratesMod));
        for (int i = nTermDeposit; i < nNodes; ++i) {
            curve = fitSwap(i, swapLeg[i], curve, rates[i]);
        }

        Currency currency = curveDefinition.getCurrency();
        // the sensitivity rows are derivatives of the market quotes, so the deposit rows need the quoted
        // rates rather than the converted zero rates (for swap nodes the two arrays are identical)
        DoubleMatrix sensi = quoteValueSensitivity(nTermDeposit, termDepositYearFraction, swapLeg, rates, curve, computeJacobian);
        if (curveValuationDate.isEqual(curveSpotDate)) {
            CurveMetadata metadata = curve.getMetadata();
            if (computeJacobian) {
                JacobianCalibrationMatrix jacobian = JacobianCalibrationMatrix.of(
                        ImmutableList.of(CurveParameterSize.of(curveDefinition.getName(), nNodes)),
                        (DoubleMatrix) MATRIX_ALGEBRA.getInverse(sensi));
                metadata = metadata.withInfo(CurveInfoType.JACOBIAN, jacobian);
            }
            InterpolatedNodalCurve curveWithParamMetadata = curve.withMetadata(metadata.withParameterMetadata(parameterMetadata));
            return IsdaCreditDiscountFactors.of(currency, curveValuationDate, curveWithParamMetadata);
        }
        double offset = curveDayCount.relativeYearFraction(curveSpotDate, curveValuationDate);
        return IsdaCreditDiscountFactors.of(
                currency, curveValuationDate, withShift(curve, parameterMetadata, sensi, computeJacobian, offset));
    }

    /**
     * Solves for the curve parameter at {@code curveIndex} so that the swap prices at par.
     * <p>
     * The residual is {@code 1 - sum(paymentAmount * df)}, where the last payment includes the notional.
     * Payments at or before the previous knot, or at or after the next knot, are valued once on the
     * input curve and held fixed during the search.
     *
     * @param curveIndex  the index of the node being fitted
     * @param swap  the fixed leg of the swap at that node
     * @param curve  the curve, already fitted up to the previous node
     * @param swapRate  the quoted swap rate
     * @return the curve with the parameter at {@code curveIndex} replaced by the root
     */
    private InterpolatedNodalCurve fitSwap(int curveIndex, BasicFixedLeg swap, InterpolatedNodalCurve curve, double swapRate) {
        int nPayments = swap.getNumPayments();
        int nNodes = curve.getParameterCount();
        double t1 = curveIndex == 0 ? 0.0 : curve.getXValues().get(curveIndex - 1);
        double t2 = curveIndex == nNodes - 1 ? Double.POSITIVE_INFINITY : curve.getXValues().get(curveIndex + 1);
        double[] paymentAmounts = new double[nPayments];
        double fixedValue = 0.0;
        double fixedSense = 0.0;
        int firstFree = 0;
        int endFree = nPayments;
        for (int i = 0; i < nPayments; ++i) {
            double t = swap.getPaymentTime(i);
            paymentAmounts[i] = swap.getPaymentAmounts(i, swapRate);
            if (t <= t1) {
                double df = Math.exp(-curve.yValue(t) * t);
                fixedValue += paymentAmounts[i] * df;
                fixedSense += paymentAmounts[i] * t * df * curve.yValueParameterSensitivity(t).getSensitivity().get(curveIndex);
                ++firstFree;
            } else if (t >= t2) {
                double df = Math.exp(-curve.yValue(t) * t);
                fixedValue += paymentAmounts[i] * df;
                // NOTE(review): sign differs from the t <= t1 branch; presumably this term is zero for the
                // curve's interpolator beyond the next knot — confirm before relying on it
                fixedSense -= paymentAmounts[i] * t * df * curve.yValueParameterSensitivity(t).getSensitivity().get(curveIndex);
                --endFree;
            }
        }
        final double cachedValue = fixedValue;
        final double cachedSense = fixedSense;
        final int start = firstFree;
        final int end = endFree;

        // par residual as a function of the fitted parameter
        Function<Double, Double> func = x -> {
            InterpolatedNodalCurve tempCurve = curve.withParameter(curveIndex, x);
            double sum = 1.0 - cachedValue;
            for (int i = start; i < end; ++i) {
                double t = swap.getPaymentTime(i);
                sum -= paymentAmounts[i] * Math.exp(-tempCurve.yValue(t) * t);
            }
            return sum;
        };
        // derivative of the residual with respect to the fitted parameter
        Function<Double, Double> grad = x -> {
            InterpolatedNodalCurve tempCurve = curve.withParameter(curveIndex, x);
            double sum = cachedSense;
            for (int i = start; i < end; ++i) {
                double t = swap.getPaymentTime(i);
                sum += paymentAmounts[i] * t * Math.exp(-tempCurve.yValue(t) * t)
                        * tempCurve.yValueParameterSensitivity(t).getSensitivity().get(curveIndex);
            }
            return sum;
        };

        double guess = curve.getParameter(curveIndex);
        double[] bracket;
        if (guess == 0.0) {
            if (func.apply(guess) == 0.0) {
                return curve;
            }
            // multiplicative seeds would give the degenerate interval [0, 0]; search symmetrically instead
            bracket = BRACKETER.getBracketedPoints(
                    func, -ZERO_GUESS_BRACKET_HALF_WIDTH, ZERO_GUESS_BRACKET_HALF_WIDTH,
                    Double.NEGATIVE_INFINITY, Double.POSITIVE_INFINITY);
        } else if (guess > 0.0) {
            bracket = BRACKETER.getBracketedPoints(func, 0.8 * guess, 1.25 * guess, Double.NEGATIVE_INFINITY, Double.POSITIVE_INFINITY);
        } else {
            bracket = BRACKETER.getBracketedPoints(func, 1.25 * guess, 0.8 * guess, Double.NEGATIVE_INFINITY, Double.POSITIVE_INFINITY);
        }
        double r = rootFinder.getRoot(func, grad, bracket[0], bracket[1]);
        return curve.withParameter(curveIndex, r);
    }

    /**
     * Builds the matrix of quote sensitivities to the curve parameters.
     * <p>
     * Row {@code i} holds the derivatives of node {@code i}'s quote with respect to every curve parameter.
     *
     * @param nTermDeposit  the number of leading term deposit nodes
     * @param termDepositYearFraction  the deposit accrual year fractions, indexed by node
     * @param swapLeg  the swap fixed legs, indexed by node (entries for deposits are unused)
     * @param rates  the market quotes, indexed by node
     * @param curve  the calibrated curve
     * @param computeJacobian  whether the matrix is needed at all
     * @return the square sensitivity matrix, or {@link DoubleMatrix#EMPTY} when not computing the Jacobian
     */
    private DoubleMatrix quoteValueSensitivity(int nTermDeposit, double[] termDepositYearFraction, BasicFixedLeg[] swapLeg, double[] rates, InterpolatedNodalCurve curve, boolean computeJacobian) {
        if (!computeJacobian) {
            return DoubleMatrix.EMPTY;
        }
        int nNode = curve.getParameterCount();
        double[][] sensiTotal = new double[nNode][];
        for (int i = 0; i < nTermDeposit; ++i) {
            sensiTotal[i] = sensitivityDeposit(curve, termDepositYearFraction[i], i, rates[i]).toArray();
        }
        for (int i = nTermDeposit; i < nNode; ++i) {
            sensiTotal[i] = sensitivitySwap(swapLeg[i], curve, rates[i]).toArray();
        }
        return DoubleMatrix.ofUnsafe(sensiTotal);
    }

    /**
     * Computes the sensitivity of a deposit quote to the curve parameters.
     * <p>
     * From {@code 1 + R * tau = exp(z * t)} we get {@code dR/dz = t * (1 + R * tau) / tau}. Only the
     * deposit's own node is non-zero.
     *
     * @param curve  the calibrated curve
     * @param termDepositYearFraction  the deposit accrual year fraction tau
     * @param index  the node index of the deposit
     * @param fixedRate  the quoted deposit rate R (not the converted zero rate)
     * @return the sensitivity row
     */
    private DoubleArray sensitivityDeposit(InterpolatedNodalCurve curve, double termDepositYearFraction, int index, double fixedRate) {
        double[] sensi = new double[curve.getParameterCount()];
        sensi[index] = curve.getXValues().get(index) * (1.0 + fixedRate * termDepositYearFraction) / termDepositYearFraction;
        return DoubleArray.ofUnsafe(sensi);
    }

    /**
     * Computes the sensitivity of a par swap rate to the curve parameters.
     * <p>
     * With {@code S * A + df_N = 1}, where {@code A} is the annuity, the result is
     * {@code dS = -(S * dA + d(df_N)) / A}.
     *
     * @param swap  the swap fixed leg
     * @param curve  the calibrated curve
     * @param swapRate  the quoted swap rate S
     * @return the sensitivity row
     */
    private DoubleArray sensitivitySwap(BasicFixedLeg swap, NodalCurve curve, double swapRate) {
        int lastIndex = swap.getNumPayments() - 1;
        double annuity = 0.0;
        UnitParameterSensitivities sensi = UnitParameterSensitivities.empty();
        for (int i = 0; i < lastIndex; ++i) {
            double t = swap.getPaymentTime(i);
            double df = Math.exp(-curve.yValue(t) * t);
            annuity += swap.getYearFraction(i) * df;
            sensi = sensi.combinedWith(curve.yValueParameterSensitivity(t).multipliedBy(-df * t * swap.getYearFraction(i) * swapRate));
        }
        // the final payment also carries the notional
        double tLast = swap.getPaymentTime(lastIndex);
        double dfLast = Math.exp(-curve.yValue(tLast) * tLast);
        annuity += swap.getYearFraction(lastIndex) * dfLast;
        sensi = sensi.combinedWith(curve.yValueParameterSensitivity(tLast)
                .multipliedBy(-dfLast * tLast * (1.0 + swap.getYearFraction(lastIndex) * swapRate)));
        sensi = sensi.multipliedBy(-1.0 / annuity);
        ArgChecker.isTrue(sensi.size() == 1, "swap sensitivity should refer to a single curve");
        return sensi.getSensitivities().get(0).getSensitivity();
    }

    /**
     * Re-bases the curve so that time zero moves from the spot date to {@code shift}.
     * <p>
     * With {@code eta} the value of {@code r(t) * t} at {@code shift}, each retained knot becomes
     * {@code t' = t - shift} and {@code r' = (r * t - eta) / t'}. Discount factors are therefore taken
     * relative to the new origin. Knots at or before {@code shift} are dropped. If no knot lies after
     * {@code shift}, the result is a {@link ConstantNodalCurve} built from the forward rate between the
     * last two knots.
     *
     * @param curve  the curve calibrated from the spot date
     * @param parameterMetadata  the metadata of every node of {@code curve}
     * @param sensitivity  the quote sensitivity matrix (only used when computing the Jacobian)
     * @param computeJacobian  whether to attach the Jacobian to the result
     * @param shift  the time from the spot date to the valuation date, in the curve day count
     * @return the re-based curve
     */
    private NodalCurve withShift(InterpolatedNodalCurve curve, List<ParameterMetadata> parameterMetadata, DoubleMatrix sensitivity, boolean computeJacobian, double shift) {
        int nNode = curve.getParameterCount();
        DoubleArray xValues = curve.getXValues();
        DoubleArray yValues = curve.getYValues();

        if (shift < xValues.get(0)) {
            // new origin before the first knot: every knot is kept, and r * t at the origin uses the first rate
            double eta = yValues.get(0) * shift;
            DoubleArray time = DoubleArray.of(nNode, i -> xValues.get(i) - shift);
            DoubleArray rate = DoubleArray.of(nNode, i -> (yValues.get(i) * xValues.get(i) - eta) / time.get(i));
            CurveMetadata metadata = curve.getMetadata().withParameterMetadata(parameterMetadata);
            if (computeJacobian) {
                double[][] transf = new double[nNode][nNode];
                for (int i = 0; i < nNode; ++i) {
                    transf[i][0] = -shift / time.get(i);
                    transf[i][i] += xValues.get(i) / time.get(i);
                }
                metadata = metadata.withInfo(CurveInfoType.JACOBIAN, shiftedJacobian(curve.getName(), nNode, transf, sensitivity));
            }
            return curve.withValues(time, rate).withMetadata(metadata);
        }

        if (shift >= xValues.get(nNode - 1)) {
            // new origin at or beyond the last knot: hold the forward rate between the last two knots
            // on a single node at time 1.0
            double time = 1.0;
            double interval = xValues.get(nNode - 1) - xValues.get(nNode - 2);
            double rate = (yValues.get(nNode - 1) * xValues.get(nNode - 1) - yValues.get(nNode - 2) * xValues.get(nNode - 2)) / interval;
            CurveMetadata metadata = curve.getMetadata();
            if (computeJacobian) {
                double[][] transf = new double[1][nNode];
                transf[0][nNode - 2] = -xValues.get(nNode - 2) / interval;
                transf[0][nNode - 1] = xValues.get(nNode - 1) / interval;
                metadata = metadata.withInfo(CurveInfoType.JACOBIAN, shiftedJacobian(curve.getName(), nNode, transf, sensitivity));
            }
            return ConstantNodalCurve.of(metadata, time, rate);
        }

        // new origin inside the curve: keep only knots strictly after it (a knot exactly at it is dropped)
        int searchIndex = Arrays.binarySearch(xValues.toArray(), shift);
        int index = searchIndex < 0 ? -(searchIndex + 1) : searchIndex + 1;
        double interval = xValues.get(index) - xValues.get(index - 1);
        // r * t at the origin, linearly interpolated between the two surrounding knots
        double tt1 = xValues.get(index - 1) * (xValues.get(index) - shift);
        double tt2 = xValues.get(index) * (shift - xValues.get(index - 1));
        double eta = (yValues.get(index - 1) * tt1 + yValues.get(index) * tt2) / interval;
        int m = nNode - index;
        CurveMetadata metadata = curve.getMetadata().withParameterMetadata(parameterMetadata.subList(index, nNode));
        DoubleArray time = DoubleArray.of(m, i -> xValues.get(i + index) - shift);
        DoubleArray rate = DoubleArray.of(m, i -> (yValues.get(i + index) * xValues.get(i + index) - eta) / time.get(i));
        if (computeJacobian) {
            double[][] transf = new double[m][nNode];
            for (int i = 0; i < m; ++i) {
                transf[i][index - 1] -= tt1 / (time.get(i) * interval);
                transf[i][index] -= tt2 / (time.get(i) * interval);
                transf[i][i + index] += xValues.get(i + index) / time.get(i);
            }
            metadata = metadata.withInfo(CurveInfoType.JACOBIAN, shiftedJacobian(curve.getName(), nNode, transf, sensitivity));
        }
        return curve.withValues(time, rate).withMetadata(metadata);
    }

    /**
     * Composes the re-basing transform with the inverted quote sensitivity to form the Jacobian.
     *
     * @param curveName  the curve name
     * @param nNode  the number of nodes of the un-shifted curve
     * @param transform  the derivatives of the shifted parameters with respect to the un-shifted ones
     * @param sensitivity  the quote sensitivity matrix of the un-shifted curve
     * @return the Jacobian calibration matrix
     */
    private static JacobianCalibrationMatrix shiftedJacobian(CurveName curveName, int nNode, double[][] transform, DoubleMatrix sensitivity) {
        DoubleMatrix jacobianMatrix = (DoubleMatrix) MATRIX_ALGEBRA.multiply(
                DoubleMatrix.ofUnsafe(transform), MATRIX_ALGEBRA.getInverse(sensitivity));
        return JacobianCalibrationMatrix.of(ImmutableList.of(CurveParameterSize.of(curveName, nNode)), jacobianMatrix);
    }

    /**
     * A minimal fixed leg: payment times and accrual fractions generated by rolling back from maturity.
     */
    private static final class BasicFixedLeg {
        /** The number of fixed payments. */
        private final int nPayment;
        /** The payment times, measured from the spot date with the curve day count. */
        private final double[] swapPaymentTime;
        /** The accrual year fraction of each period, measured with the swap day count. */
        private final double[] yearFraction;

        /**
         * Creates the leg.
         *
         * @param curveSpotDate  the spot (start) date
         * @param maturityDate  the unadjusted maturity date
         * @param swapInterval  the payment period
         * @param swapDCC  the day count for accrual fractions
         * @param curveDcc  the day count for payment times
         * @param busAdj  the adjustment applied to every payment date
         * @param refData  the reference data used by the adjustment
         */
        public BasicFixedLeg(LocalDate curveSpotDate, LocalDate maturityDate, Period swapInterval, DayCount swapDCC, DayCount curveDcc, BusinessDayAdjustment busAdj, ReferenceData refData) {
            // roll back from maturity in whole periods; every collected date is strictly after the spot date
            List<LocalDate> unadjustedDates = new ArrayList<>();
            LocalDate date = maturityDate;
            int step = 1;
            while (date.isAfter(curveSpotDate)) {
                unadjustedDates.add(date);
                date = maturityDate.minus(swapInterval.multipliedBy(step++));
            }
            this.nPayment = unadjustedDates.size();
            this.swapPaymentTime = new double[nPayment];
            this.yearFraction = new double[nPayment];
            LocalDate prev = curveSpotDate;
            for (int i = 0; i < nPayment; ++i) {
                // dates were collected latest-first, so read them back in reverse
                LocalDate adjusted = busAdj.adjust(unadjustedDates.get(nPayment - 1 - i), refData);
                yearFraction[i] = swapDCC.relativeYearFraction(prev, adjusted);
                swapPaymentTime[i] = curveDcc.relativeYearFraction(curveSpotDate, adjusted);
                prev = adjusted;
            }
        }

        /**
         * @return the number of payments
         */
        public int getNumPayments() {
            return nPayment;
        }

        /**
         * Gets the cash flow of a payment per unit notional; the last payment also repays the notional.
         *
         * @param index  the payment index
         * @param rate  the fixed rate
         * @return the payment amount
         */
        public double getPaymentAmounts(int index, double rate) {
            return index == nPayment - 1 ? 1.0 + rate * yearFraction[index] : rate * yearFraction[index];
        }

        /**
         * @param index  the payment index
         * @return the payment time from the spot date
         */
        public double getPaymentTime(int index) {
            return swapPaymentTime[index];
        }

        /**
         * @param index  the payment index
         * @return the accrual year fraction of the period ending at that payment
         */
        public double getYearFraction(int index) {
            return yearFraction[index];
        }
    }
}

