/*
 * Decompiled with CFR 0.152.
 */
package net.finmath.singleswaprate.calibration;

import java.time.LocalDate;
import java.util.ArrayList;
import java.util.Iterator;
import java.util.Map;
import java.util.TreeMap;
import java.util.TreeSet;
import net.finmath.functions.AnalyticFormulas;
import net.finmath.marketdata.model.AnalyticModel;
import net.finmath.marketdata.model.volatilities.SwaptionDataLattice;
import net.finmath.marketdata.products.Swap;
import net.finmath.marketdata.products.SwapAnnuity;
import net.finmath.optimizer.LevenbergMarquardt;
import net.finmath.optimizer.SolverException;
import net.finmath.singleswaprate.data.DataTable;
import net.finmath.singleswaprate.data.DataTableExtrapolated;
import net.finmath.singleswaprate.data.DataTableLight;
import net.finmath.singleswaprate.data.DataTableLinear;
import net.finmath.singleswaprate.model.volatilities.SABRVolatilityCube;
import net.finmath.time.Schedule;
import net.finmath.time.SchedulePrototype;

/**
 * Calibrates a {@link SABRVolatilityCube} from cash-settled payer and receiver swaption premiums and
 * physically-settled ATM swaption premiums.
 *
 * <p>{@link #build(String)} proceeds in the following steps:
 * <ol>
 * <li>select the (maturity, termination) nodes of the physical ATM lattice that carry at least one
 * non-ATM cash payer or receiver quote, then add any missing combinations of the selected maturities
 * and terminations so the node set forms a full grid;</li>
 * <li>compute forward par swap rates on the nodes;</li>
 * <li>imply Bachelier (normal) volatilities from the cash payer and cash receiver premiums, using the
 * cash annuity {@link #cashFunction(double, Schedule)};</li>
 * <li>imply physical ATM normal volatilities and shift the cash smiles onto them
 * (physical = cash smile - cash ATM + physical ATM);</li>
 * <li>fit SABR base volatility, vol-vol and rho on every node with a Levenberg-Marquardt optimizer.</li>
 * </ol>
 *
 * <p>Maturities and terminations are in months. Moneyness keys are in basis points relative to the par
 * swap rate. Receiver smiles are stored under negative moneyness keys. Intermediate results are kept in
 * mutable fields and are recomputed on every call to {@link #build(String)}.
 */
public class SABRShiftedSmileCalibration {
    private final LocalDate referenceDate;
    private final AnalyticModel model;
    private final SwaptionDataLattice cashPayerPremiums;
    private final SwaptionDataLattice cashReceiverPremiums;
    private final SwaptionDataLattice physicalPremiumsATM;
    private final SchedulePrototype fixMetaSchedule;
    private final SchedulePrototype floatMetaSchedule;
    private final String discountCurveName;
    private final String forwardCurveName;
    private final double sabrDisplacement;
    private final double sabrBeta;
    private final double correlationDecay;
    private final double iborOisDecorrelation;
    /** Physical normal volatilities per moneyness (bp); negative keys hold receiver-derived smiles. */
    private Map<Integer, DataTableLight> physicalVolatilities;
    /** Normal volatilities implied from cash payer premiums, per moneyness (bp). */
    private Map<Integer, DataTableLight> cashPayerVolatilities;
    /** Normal volatilities implied from cash receiver premiums, per moneyness (bp). */
    private Map<Integer, DataTableLight> cashReceiverVolatilities;
    /** Calibration nodes; the value is the number of quotes at the node (1.0 for grid-filling nodes). */
    private DataTableLight interpolationNodes;
    private boolean useLinearInterpolation = true;
    private DataTable swapRateTable;
    private DataTable rhoTable;
    private DataTable baseVolTable;
    private DataTable volvolTable;
    private int maxIterations = 500;
    private int numberOfThreads = Runtime.getRuntime().availableProcessors();

    /**
     * Convenience factory: creates a calibration and immediately builds the cube.
     *
     * @param name name of the resulting cube
     * @param referenceDate reference date of the market data
     * @param cashPayerPremiums premiums of cash-settled payer swaptions; also supplies the schedules and curve names
     * @param cashReceiverPremiums premiums of cash-settled receiver swaptions
     * @param physicalPremiumsATM premiums of physically-settled ATM swaptions
     * @param model model providing the discount and forward curves
     * @param sabrDisplacement SABR displacement used during calibration and in the cube
     * @param sabrBeta SABR beta used during calibration and in the cube
     * @param correlationDecay passed through to the cube
     * @param iborOisDecorrelation passed through to the cube
     * @return the calibrated cube
     * @throws SolverException if a smile calibration fails
     */
    public static SABRVolatilityCube createSABRVolatilityCube(String name, LocalDate referenceDate, SwaptionDataLattice cashPayerPremiums, SwaptionDataLattice cashReceiverPremiums, SwaptionDataLattice physicalPremiumsATM, AnalyticModel model, double sabrDisplacement, double sabrBeta, double correlationDecay, double iborOisDecorrelation) throws SolverException {
        SABRShiftedSmileCalibration factory = new SABRShiftedSmileCalibration(referenceDate, cashPayerPremiums, cashReceiverPremiums, physicalPremiumsATM, model, sabrDisplacement, sabrBeta, correlationDecay, iborOisDecorrelation);
        return factory.build(name);
    }

    /**
     * Computes the lattice of physical normal volatilities without running the SABR calibration.
     *
     * @param name currently unused; retained for interface compatibility
     * @param referenceDate reference date of the market data
     * @param cashPayerPremiums premiums of cash-settled payer swaptions
     * @param cashReceiverPremiums premiums of cash-settled receiver swaptions
     * @param physicalPremiumsATM premiums of physically-settled ATM swaptions
     * @param model model providing the discount and forward curves
     * @return cloned physical volatility tables keyed by moneyness in bp (negative keys: receiver side)
     */
    public static Map<Integer, DataTable> createVolatilityCubeLattice(String name, LocalDate referenceDate, SwaptionDataLattice cashPayerPremiums, SwaptionDataLattice cashReceiverPremiums, SwaptionDataLattice physicalPremiumsATM, AnalyticModel model) {
        SABRShiftedSmileCalibration factory = new SABRShiftedSmileCalibration(referenceDate, cashPayerPremiums, cashReceiverPremiums, physicalPremiumsATM, model, 0.0, 0.0, 0.0, 0.0);
        // Only the physical volatilities are needed here. Running just the preparation steps avoids an
        // expensive SABR fit whose result (and any SolverException) would be thrown away anyway.
        factory.prepareMarketVolatilities();
        TreeMap<Integer, DataTable> returnMap = new TreeMap<Integer, DataTable>();
        for (Map.Entry<Integer, DataTableLight> entry : factory.physicalVolatilities.entrySet()) {
            returnMap.put(entry.getKey(), entry.getValue().clone());
        }
        return returnMap;
    }

    /**
     * Creates a calibration. Schedules and curve names are taken from {@code cashPayerPremiums}.
     *
     * @param referenceDate reference date of the market data
     * @param cashPayerPremiums premiums of cash-settled payer swaptions
     * @param cashReceiverPremiums premiums of cash-settled receiver swaptions
     * @param physicalPremiumsATM premiums of physically-settled ATM swaptions
     * @param model model providing the discount and forward curves
     * @param sabrDisplacement SABR displacement
     * @param sabrBeta SABR beta
     * @param correlationDecay passed through to the cube
     * @param iborOisDecorrelation passed through to the cube
     */
    public SABRShiftedSmileCalibration(LocalDate referenceDate, SwaptionDataLattice cashPayerPremiums, SwaptionDataLattice cashReceiverPremiums, SwaptionDataLattice physicalPremiumsATM, AnalyticModel model, double sabrDisplacement, double sabrBeta, double correlationDecay, double iborOisDecorrelation) {
        this.referenceDate = referenceDate;
        this.physicalPremiumsATM = physicalPremiumsATM;
        this.cashPayerPremiums = cashPayerPremiums;
        this.cashReceiverPremiums = cashReceiverPremiums;
        this.model = model;
        this.sabrDisplacement = sabrDisplacement;
        this.sabrBeta = sabrBeta;
        this.correlationDecay = correlationDecay;
        this.iborOisDecorrelation = iborOisDecorrelation;
        this.fixMetaSchedule = cashPayerPremiums.getFixMetaSchedule();
        this.floatMetaSchedule = cashPayerPremiums.getFloatMetaSchedule();
        this.discountCurveName = cashPayerPremiums.getDiscountCurveName();
        this.forwardCurveName = cashPayerPremiums.getForwardCurveName();
    }

    /**
     * Runs the full calibration and returns the resulting cube.
     *
     * @param name name of the resulting cube
     * @return the calibrated cube
     * @throws SolverException if a smile calibration fails
     */
    public SABRVolatilityCube build(String name) throws SolverException {
        this.prepareMarketVolatilities();
        this.calibrateSmilesOnNodes();
        return new SABRVolatilityCube(name, this.referenceDate, this.swapRateTable, this.sabrDisplacement, this.sabrBeta, this.rhoTable, this.baseVolTable, this.volvolTable, this.correlationDecay, this.iborOisDecorrelation);
    }

    /** Computes nodes, par swap rates, cash volatilities and physical volatilities, in dependency order. */
    private void prepareMarketVolatilities() {
        this.findInterpolationNodes();
        this.makeSwapRateTable();
        this.findPayerVolatilities();
        this.findReceiverVolatilities();
        this.makePhysicalVolatilities();
    }

    /**
     * Fits SABR (base volatility, vol-vol, rho) on every maturity/termination combination of the nodes and
     * stores the parameters in {@link #baseVolTable}, {@link #volvolTable} and {@link #rhoTable}.
     *
     * @throws SolverException if the optimizer fails on a node
     */
    private void calibrateSmilesOnNodes() throws SolverException {
        ArrayList<Integer> maturities = new ArrayList<Integer>();
        ArrayList<Integer> terminations = new ArrayList<Integer>();
        ArrayList<Double> sabrRhos = new ArrayList<Double>();
        ArrayList<Double> sabrBaseVols = new ArrayList<Double>();
        ArrayList<Double> sabrVolvols = new ArrayList<Double>();
        // Order: base volatility, vol-vol, rho.
        double[] initialParameters = new double[]{0.01, 0.15, 0.3};

        // Nodes are visited in reverse iteration order of the node sets (descending if those sets iterate
        // ascending); each fit is warm-started from the previous node's result.
        int[] maturityArray = new int[this.interpolationNodes.getMaturities().size()];
        int position = maturityArray.length - 1;
        for (int maturity : this.interpolationNodes.getMaturities()) {
            maturityArray[position--] = maturity;
        }
        int[] terminationArray = new int[this.interpolationNodes.getTerminations().size()];
        position = terminationArray.length - 1;
        for (int termination : this.interpolationNodes.getTerminations()) {
            terminationArray[position--] = termination;
        }

        for (int maturity : maturityArray) {
            for (int termination : terminationArray) {
                final double parSwapRate = this.swapRateTable.getValue(maturity, termination);
                // NOTE(review): the end date here is referenceDate + (maturity + termination) months, while the
                // other methods use (referenceDate + maturity months) + termination months. These can differ
                // through end-of-month clipping in plusMonths; confirm which is intended.
                final double sabrMaturity = this.floatMetaSchedule.generateSchedule(this.referenceDate, this.referenceDate.plusMonths(maturity), this.referenceDate.plusMonths(maturity + termination)).getFixing(0);

                // Collect all physical smile points quoted at this node, in ascending moneyness order.
                int count = 0;
                for (DataTableLight smileTable : this.physicalVolatilities.values()) {
                    if (smileTable.containsEntryFor(maturity, termination)) {
                        ++count;
                    }
                }
                final double[] marketStrikes = new double[count];
                double[] marketVolatilities = new double[count];
                int strikeIndex = 0;
                for (Map.Entry<Integer, DataTableLight> smile : this.physicalVolatilities.entrySet()) {
                    DataTableLight smileTable = smile.getValue();
                    if (!smileTable.containsEntryFor(maturity, termination)) {
                        continue;
                    }
                    // Moneyness keys are in basis points.
                    marketStrikes[strikeIndex] = parSwapRate + smile.getKey() / 10000.0;
                    marketVolatilities[strikeIndex] = smileTable.getValue(maturity, termination);
                    ++strikeIndex;
                }

                LevenbergMarquardt optimizer = new LevenbergMarquardt(initialParameters, marketVolatilities, this.maxIterations, this.numberOfThreads){
                    private static final long serialVersionUID = -7551690451877166912L;

                    @Override
                    public void setValues(double[] parameters, double[] values) {
                        // Keep base volatility and vol-vol non-negative and rho inside [-1, 1].
                        parameters[0] = Math.max(parameters[0], 0.0);
                        parameters[1] = Math.max(parameters[1], 0.0);
                        parameters[2] = Math.max(Math.min(parameters[2], 1.0), -1.0);
                        for (int i = 0; i < marketStrikes.length; ++i) {
                            values[i] = AnalyticFormulas.sabrBerestyckiNormalVolatilityApproximation(parameters[0], SABRShiftedSmileCalibration.this.sabrBeta, parameters[2], parameters[1], SABRShiftedSmileCalibration.this.sabrDisplacement, parSwapRate, marketStrikes[i], sabrMaturity);
                        }
                    }
                };
                optimizer.run();
                double[] parameters = optimizer.getBestFitParameters();
                maturities.add(maturity);
                terminations.add(termination);
                sabrBaseVols.add(parameters[0]);
                sabrVolvols.add(parameters[1]);
                sabrRhos.add(parameters[2]);
                initialParameters = parameters;
            }
        }
        if (this.useLinearInterpolation) {
            this.baseVolTable = new DataTableLinear("MarketBaseVolatilityTable", DataTable.TableConvention.MONTHS, this.referenceDate, this.floatMetaSchedule, maturities, terminations, sabrBaseVols);
            this.volvolTable = new DataTableLinear("MarketVolVolTable", DataTable.TableConvention.MONTHS, this.referenceDate, this.floatMetaSchedule, maturities, terminations, sabrVolvols);
            this.rhoTable = new DataTableLinear("MarketRhoTable", DataTable.TableConvention.MONTHS, this.referenceDate, this.floatMetaSchedule, maturities, terminations, sabrRhos);
        } else {
            this.baseVolTable = new DataTableExtrapolated("MarketBaseVolatilityTable", DataTable.TableConvention.MONTHS, this.referenceDate, this.floatMetaSchedule, maturities, terminations, sabrBaseVols);
            this.volvolTable = new DataTableExtrapolated("MarketVolVolTable", DataTable.TableConvention.MONTHS, this.referenceDate, this.floatMetaSchedule, maturities, terminations, sabrVolvols);
            this.rhoTable = new DataTableExtrapolated("MarketRhoTable", DataTable.TableConvention.MONTHS, this.referenceDate, this.floatMetaSchedule, maturities, terminations, sabrRhos);
        }
    }

    /**
     * Selects the calibration nodes and stores them in {@link #interpolationNodes}.
     *
     * <p>A node of the physical ATM lattice is kept when at least one non-ATM cash payer or receiver quote
     * exists for it. Its value is the number of quotes (ATM included). Missing combinations of the kept
     * maturities and terminations are then added with value 1.0, so the nodes form a full grid.
     */
    private void findInterpolationNodes() {
        ArrayList<Integer> nodeMaturities = new ArrayList<Integer>();
        ArrayList<Integer> nodeTerminations = new ArrayList<Integer>();
        ArrayList<Double> nodeCardinalities = new ArrayList<Double>();
        TreeSet<Integer> payerStrikes = new TreeSet<Integer>(this.cashPayerPremiums.getGridNodesPerMoneyness().keySet());
        payerStrikes.remove(0);
        TreeSet<Integer> receiverStrikes = new TreeSet<Integer>(this.cashReceiverPremiums.getGridNodesPerMoneyness().keySet());
        receiverStrikes.remove(0);
        for (int maturity : this.physicalPremiumsATM.getMaturities()) {
            for (int termination : this.physicalPremiumsATM.getTenors(0, maturity)) {
                // The ATM quote counts as the first point of the smile.
                int count = 1;
                for (int strike : payerStrikes) {
                    if (this.cashPayerPremiums.containsEntryFor(maturity, termination, strike)) {
                        ++count;
                    }
                }
                for (int strike : receiverStrikes) {
                    if (this.cashReceiverPremiums.containsEntryFor(maturity, termination, strike)) {
                        ++count;
                    }
                }
                // Skip nodes whose smile consists of the ATM point only.
                if (count <= 1) {
                    continue;
                }
                nodeMaturities.add(maturity);
                nodeTerminations.add(termination);
                nodeCardinalities.add(Double.valueOf(count));
            }
        }
        this.interpolationNodes = new DataTableLight("NodesWithCardinality", DataTable.TableConvention.MONTHS, nodeMaturities, nodeTerminations, nodeCardinalities);
        // Fill holes so every maturity/termination combination of the node sets is present.
        if (this.interpolationNodes.size() != this.interpolationNodes.getMaturities().size() * this.interpolationNodes.getTerminations().size()) {
            for (int maturity : this.interpolationNodes.getMaturities()) {
                for (int termination : this.interpolationNodes.getTerminations()) {
                    if (!this.interpolationNodes.containsEntryFor(maturity, termination)) {
                        this.interpolationNodes = this.interpolationNodes.addPoint(maturity, termination, 1.0);
                    }
                }
            }
        }
    }

    /**
     * Builds {@link #physicalVolatilities}.
     *
     * <p>The ATM entry (key 0) is implied from the physical ATM premiums with the physical annuity. Every
     * other cash smile point is shifted by the difference between the physical and cash ATM volatilities.
     * Payer points are stored under {@code +strike}, receiver points under {@code -strike}.
     */
    private void makePhysicalVolatilities() {
        int[] maturitiesArray = new int[this.interpolationNodes.size()];
        int[] terminationsArray = new int[this.interpolationNodes.size()];
        double[] volatilitiesArray = new double[this.interpolationNodes.size()];
        int index = 0;
        for (int maturity : this.interpolationNodes.getMaturities()) {
            for (int termination : this.interpolationNodes.getTerminationsForMaturity(maturity)) {
                maturitiesArray[index] = maturity;
                terminationsArray[index] = termination;
                LocalDate maturityDate = this.referenceDate.plusMonths(maturity);
                LocalDate terminationDate = maturityDate.plusMonths(termination);
                Schedule fixSchedule = this.fixMetaSchedule.generateSchedule(this.referenceDate, maturityDate, terminationDate);
                double annuity = SwapAnnuity.getSwapAnnuity(fixSchedule.getFixing(0), fixSchedule, this.model.getDiscountCurve(this.discountCurveName), this.model);
                double swapRate = this.swapRateTable.getValue(maturity, termination);
                volatilitiesArray[index++] = AnalyticFormulas.bachelierOptionImpliedVolatility(swapRate, fixSchedule.getFixing(0), swapRate, annuity, this.physicalPremiumsATM.getValue(maturity, termination, 0));
            }
        }
        DataTableLight physicalATMTable = new DataTableLight("VolatilitiesPhysicalATM", DataTable.TableConvention.MONTHS, maturitiesArray, terminationsArray, volatilitiesArray);
        this.physicalVolatilities = new TreeMap<Integer, DataTableLight>();
        this.physicalVolatilities.put(0, physicalATMTable);
        DataTableLight payerATMTable = this.cashPayerVolatilities.get(0);
        DataTableLight receiverATMTable = this.cashReceiverVolatilities.get(0);
        TreeSet<Integer> strikes = new TreeSet<Integer>(this.cashPayerVolatilities.keySet());
        strikes.addAll(this.cashReceiverVolatilities.keySet());
        strikes.remove(0);
        for (int strike : strikes) {
            // Either may be null if the strike is quoted on one side only; map values themselves are never null.
            DataTableLight payerSmileTable = this.cashPayerVolatilities.get(strike);
            DataTableLight receiverSmileTable = this.cashReceiverVolatilities.get(strike);
            ArrayList<Integer> maturitiesPositive = new ArrayList<Integer>();
            ArrayList<Integer> terminationsPositive = new ArrayList<Integer>();
            ArrayList<Double> physicalVolatilitiesPositive = new ArrayList<Double>();
            ArrayList<Integer> maturitiesNegative = new ArrayList<Integer>();
            ArrayList<Integer> terminationsNegative = new ArrayList<Integer>();
            ArrayList<Double> physicalVolatilitiesNegative = new ArrayList<Double>();
            for (int maturity : this.interpolationNodes.getMaturities()) {
                for (int termination : this.interpolationNodes.getTerminationsForMaturity(maturity)) {
                    double physicalATM = physicalATMTable.getValue(maturity, termination);
                    if (payerSmileTable != null && payerSmileTable.containsEntryFor(maturity, termination)) {
                        double payerATM = payerATMTable.getValue(maturity, termination);
                        double payerSmile = payerSmileTable.getValue(maturity, termination);
                        maturitiesPositive.add(maturity);
                        terminationsPositive.add(termination);
                        physicalVolatilitiesPositive.add(payerSmile - payerATM + physicalATM);
                    }
                    if (receiverSmileTable != null && receiverSmileTable.containsEntryFor(maturity, termination)) {
                        double receiverATM = receiverATMTable.getValue(maturity, termination);
                        double receiverSmile = receiverSmileTable.getValue(maturity, termination);
                        maturitiesNegative.add(maturity);
                        terminationsNegative.add(termination);
                        physicalVolatilitiesNegative.add(receiverSmile - receiverATM + physicalATM);
                    }
                }
            }
            DataTableLight physicalPositiveSmileTable = new DataTableLight("VolatilitiesPhysical" + strike, DataTable.TableConvention.MONTHS, maturitiesPositive, terminationsPositive, physicalVolatilitiesPositive);
            DataTableLight physicalNegativeSmileTable = new DataTableLight("VolatilitiesPhysical" + -strike, DataTable.TableConvention.MONTHS, maturitiesNegative, terminationsNegative, physicalVolatilitiesNegative);
            this.physicalVolatilities.put(strike, physicalPositiveSmileTable);
            this.physicalVolatilities.put(-strike, physicalNegativeSmileTable);
        }
    }

    /**
     * Implies normal volatilities from cash payer premiums at strike {@code swapRate + moneyness / 10000},
     * using the cash annuity, for every node that has a quote.
     */
    private void findPayerVolatilities() {
        this.cashPayerVolatilities = new TreeMap<Integer, DataTableLight>();
        for (int moneyness : this.cashPayerPremiums.getGridNodesPerMoneyness().keySet()) {
            ArrayList<Integer> maturities = new ArrayList<Integer>();
            ArrayList<Integer> terminations = new ArrayList<Integer>();
            ArrayList<Double> values = new ArrayList<Double>();
            for (int maturity : this.interpolationNodes.getMaturities()) {
                for (int termination : this.interpolationNodes.getTerminationsForMaturity(maturity)) {
                    if (!this.cashPayerPremiums.containsEntryFor(maturity, termination, moneyness)) {
                        continue;
                    }
                    LocalDate maturityDate = this.referenceDate.plusMonths(maturity);
                    LocalDate terminationDate = maturityDate.plusMonths(termination);
                    Schedule fixSchedule = this.fixMetaSchedule.generateSchedule(this.referenceDate, maturityDate, terminationDate);
                    double swapRate = this.swapRateTable.getValue(maturity, termination);
                    double cashAnnuity = SABRShiftedSmileCalibration.cashFunction(swapRate, fixSchedule);
                    maturities.add(maturity);
                    terminations.add(termination);
                    values.add(AnalyticFormulas.bachelierOptionImpliedVolatility(swapRate, fixSchedule.getFixing(0), swapRate + (double)moneyness / 10000.0, cashAnnuity, this.cashPayerPremiums.getValue(maturity, termination, moneyness)));
                }
            }
            DataTableLight volatilityTable = new DataTableLight("VolatilitiesPayer" + moneyness, DataTable.TableConvention.MONTHS, maturities, terminations, values);
            this.cashPayerVolatilities.put(moneyness, volatilityTable);
        }
    }

    /**
     * Implies normal volatilities from cash receiver premiums at strike {@code swapRate - moneyness / 10000}.
     * The receiver premium is first converted to the equivalent payer premium by adding
     * {@code moneyness / 10000 * cashAnnuity} (payer - receiver = annuity * (swapRate - strike)).
     */
    private void findReceiverVolatilities() {
        this.cashReceiverVolatilities = new TreeMap<Integer, DataTableLight>();
        for (int moneyness : this.cashReceiverPremiums.getGridNodesPerMoneyness().keySet()) {
            ArrayList<Integer> maturities = new ArrayList<Integer>();
            ArrayList<Integer> terminations = new ArrayList<Integer>();
            ArrayList<Double> values = new ArrayList<Double>();
            for (int maturity : this.interpolationNodes.getMaturities()) {
                for (int termination : this.interpolationNodes.getTerminationsForMaturity(maturity)) {
                    if (!this.cashReceiverPremiums.containsEntryFor(maturity, termination, moneyness)) {
                        continue;
                    }
                    LocalDate maturityDate = this.referenceDate.plusMonths(maturity);
                    LocalDate terminationDate = maturityDate.plusMonths(termination);
                    Schedule fixSchedule = this.fixMetaSchedule.generateSchedule(this.referenceDate, maturityDate, terminationDate);
                    double swapRate = this.swapRateTable.getValue(maturity, termination);
                    double cashAnnuity = SABRShiftedSmileCalibration.cashFunction(swapRate, fixSchedule);
                    maturities.add(maturity);
                    terminations.add(termination);
                    values.add(AnalyticFormulas.bachelierOptionImpliedVolatility(swapRate, fixSchedule.getFixing(0), swapRate - (double)moneyness / 10000.0, cashAnnuity, this.cashReceiverPremiums.getValue(maturity, termination, moneyness) + (double)moneyness / 10000.0 * cashAnnuity));
                }
            }
            DataTableLight volatilityTable = new DataTableLight("VolatilitiesReceiver" + moneyness, DataTable.TableConvention.MONTHS, maturities, terminations, values);
            this.cashReceiverVolatilities.put(moneyness, volatilityTable);
        }
    }

    /**
     * Computes forward par swap rates on all nodes from the forward curve and stores them in
     * {@link #swapRateTable}, using linear or extrapolating interpolation per {@link #useLinearInterpolation}.
     */
    private void makeSwapRateTable() {
        int[] maturitiesArray = new int[this.interpolationNodes.size()];
        int[] terminationsArray = new int[this.interpolationNodes.size()];
        double[] swapRateArray = new double[this.interpolationNodes.size()];
        int index = 0;
        for (int maturity : this.interpolationNodes.getMaturities()) {
            for (int termination : this.interpolationNodes.getTerminationsForMaturity(maturity)) {
                maturitiesArray[index] = maturity;
                terminationsArray[index] = termination;
                LocalDate maturityDate = this.referenceDate.plusMonths(maturity);
                LocalDate terminationDate = maturityDate.plusMonths(termination);
                Schedule floatSchedule = this.floatMetaSchedule.generateSchedule(this.referenceDate, maturityDate, terminationDate);
                Schedule fixSchedule = this.fixMetaSchedule.generateSchedule(this.referenceDate, maturityDate, terminationDate);
                swapRateArray[index++] = Swap.getForwardSwapRate(fixSchedule, floatSchedule, this.model.getForwardCurve(this.forwardCurveName), this.model);
            }
        }
        if (this.useLinearInterpolation) {
            this.swapRateTable = new DataTableLinear("MarketParSwapRates", DataTable.TableConvention.MONTHS, this.referenceDate, this.floatMetaSchedule, maturitiesArray, terminationsArray, swapRateArray);
        } else {
            this.swapRateTable = new DataTableExtrapolated("MarketParSwapRates", DataTable.TableConvention.MONTHS, this.referenceDate, this.floatMetaSchedule, maturitiesArray, terminationsArray, swapRateArray);
        }
    }

    /**
     * Sets the optimizer limits used by the smile calibration.
     *
     * @param maxIterations maximum number of Levenberg-Marquardt iterations per node
     * @param numberOfThreads number of threads handed to the optimizer
     */
    public void setCalibrationParameters(int maxIterations, int numberOfThreads) {
        this.maxIterations = maxIterations;
        this.numberOfThreads = numberOfThreads;
    }

    /** @return the maximum number of optimizer iterations per node (default 500) */
    public int getMaxIterations() {
        return this.maxIterations;
    }

    /** @return the number of optimizer threads (default: available processors) */
    public int getNumberOfThreads() {
        return this.numberOfThreads;
    }

    /** @return {@code true} if the result tables use {@code DataTableLinear}, otherwise {@code DataTableExtrapolated} */
    public boolean isUseLinearInterpolation() {
        return this.useLinearInterpolation;
    }

    /** @param useLinearInterpolation {@code true} for {@code DataTableLinear}, {@code false} for {@code DataTableExtrapolated} */
    public void setUseLinearInterpolation(boolean useLinearInterpolation) {
        this.useLinearInterpolation = useLinearInterpolation;
    }

    /**
     * Cash annuity of a schedule: {@code (1 - (1 + tau * S)^(-n)) / S}, where {@code n} is the number of periods
     * and {@code tau} is their average length. Returns the limit {@code n * tau} when {@code S == 0}, and 0.0
     * for a schedule without periods.
     *
     * @param swapRate the swap rate S used for discounting
     * @param schedule the fixed-leg schedule
     * @return the cash annuity
     */
    private static double cashFunction(double swapRate, Schedule schedule) {
        int numberOfPeriods = schedule.getNumberOfPeriods();
        if (numberOfPeriods == 0) {
            // Avoid 0/0 for the average period length; an empty schedule has no annuity.
            return 0.0;
        }
        double periodLength = 0.0;
        for (int index = 0; index < numberOfPeriods; ++index) {
            periodLength += schedule.getPeriodLength(index);
        }
        periodLength /= (double)numberOfPeriods;
        if (swapRate == 0.0) {
            return (double)numberOfPeriods * periodLength;
        }
        return (1.0 - Math.pow(1.0 + periodLength * swapRate, -numberOfPeriods)) / swapRate;
    }
}

