/*
 * Decompiled with CFR 0.152.
 */
package net.finmath.equities.models;

import java.time.LocalDate;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.Comparator;
import java.util.List;
import java.util.Map;
import java.util.stream.Collectors;
import net.finmath.equities.marketdata.VolatilityPoint;
import net.finmath.equities.models.EquityForwardStructure;
import net.finmath.equities.models.ShiftedVolatilitySurface;
import net.finmath.equities.models.SviVolatilitySmile;
import net.finmath.equities.models.VolatilitySurface;
import net.finmath.interpolation.RationalFunctionInterpolation;
import net.finmath.optimizer.LevenbergMarquardt;
import net.finmath.optimizer.SolverException;
import net.finmath.time.daycount.DayCountConvention;

/**
 * Volatility surface made of one SVI smile per expiry date.
 *
 * <p>Between smile dates the total variance at a given log-strike is interpolated
 * linearly in time, anchored at zero total variance at time zero, and extrapolated
 * linearly beyond the first and last node. An optional additive volatility shift can be
 * applied on top of the interpolated surface, see {@link #getShiftedSurface(double)}.
 *
 * <p>State checks are expressed with {@code assert} and are therefore only active when
 * the JVM runs with assertions enabled.
 */
public class SviVolatilitySurface
implements VolatilitySurface,
ShiftedVolatilitySurface {
    /** Times to maturity below this threshold (but not negative) are evaluated at this threshold. */
    private static final double MIN_TIME_TO_MATURITY = 1.0E-16;

    private final DayCountConvention dayCounter;
    /** If true, log-moneyness is measured against the forward structure used at calibration. */
    private final boolean useStickyStrike;
    private LocalDate valuationDate;
    private EquityForwardStructure forwardStructure;
    /** Smiles ordered by smile date; {@code smiles[i]} belongs to {@code smileTimes[i + 1]}. */
    private SviVolatilitySmile[] smiles = new SviVolatilitySmile[0];
    /** Time grid of the surface; entry 0 is always 0.0, followed by one entry per smile. */
    private double[] smileTimes = new double[0];
    private boolean isCalibrated = false;
    /** Additive shift of the implied volatility; 0.0 for an unshifted surface. */
    private final double volShift;

    /**
     * Creates an empty, unshifted surface that has to be filled via
     * {@link #calibrate(EquityForwardStructure, ArrayList)} before it can be queried.
     *
     * @param dayCounter day count convention used to convert dates into year fractions
     * @param useStickyStrike whether log-moneyness refers to the calibration-time forward structure
     */
    public SviVolatilitySurface(DayCountConvention dayCounter, boolean useStickyStrike) {
        this.dayCounter = dayCounter;
        this.useStickyStrike = useStickyStrike;
        this.volShift = 0.0;
    }

    /**
     * Creates an unshifted surface from already calibrated smiles.
     *
     * @param valuationDate date from which the smile times are measured
     * @param dayCounter day count convention used to convert dates into year fractions
     * @param forwardStructure forward structure the smiles refer to
     * @param smiles the SVI smiles, in any order
     * @param useStickyStrike whether log-moneyness refers to the calibration-time forward structure
     */
    public SviVolatilitySurface(LocalDate valuationDate, DayCountConvention dayCounter, EquityForwardStructure forwardStructure, SviVolatilitySmile[] smiles, boolean useStickyStrike) {
        this(valuationDate, dayCounter, forwardStructure, smiles, useStickyStrike, 0.0);
    }

    /**
     * Creates a (possibly shifted) surface from already calibrated smiles.
     *
     * <p>NOTE(review): the smile times are measured from the {@code valuationDate} argument,
     * while the surface's own valuation date is taken from {@code forwardStructure};
     * presumably callers always pass matching dates - confirm.
     */
    private SviVolatilitySurface(LocalDate valuationDate, DayCountConvention dayCounter, EquityForwardStructure forwardStructure, SviVolatilitySmile[] smiles, boolean useStickyStrike, double volShift) {
        this.dayCounter = dayCounter;
        this.setForwardStructure(forwardStructure);
        this.useStickyStrike = useStickyStrike;
        this.volShift = volShift;

        // Sort a private copy: sorting through Arrays.asList(smiles) would write through
        // to the caller's array and reorder it as a side effect.
        SviVolatilitySmile[] sortedSmiles = smiles.clone();
        Arrays.sort(sortedSmiles, Comparator.comparing(SviVolatilitySmile::getSmileDate));
        this.smiles = sortedSmiles;

        this.smileTimes = new double[sortedSmiles.length + 1];
        this.smileTimes[0] = 0.0;
        for (int i = 0; i < sortedSmiles.length; ++i) {
            this.smileTimes[i + 1] = dayCounter.getDaycountFraction(valuationDate, sortedSmiles[i].getSmileDate());
        }
        this.isCalibrated = true;
    }

    /**
     * Returns a copy of this surface whose implied volatilities are shifted by {@code shift}.
     *
     * @param shift additive volatility shift of the new surface
     * @return a new surface sharing this surface's smiles, forward structure and conventions
     */
    @Override
    public SviVolatilitySurface getShiftedSurface(double shift) {
        assert (this.volShift == 0.0) : "Surface is already shifted";
        // Pass the requested shift; passing this.volShift (always 0.0 here) would return
        // an unshifted surface.
        return new SviVolatilitySurface(this.valuationDate, this.dayCounter, this.forwardStructure, this.smiles, this.useStickyStrike, shift);
    }

    /** @return the additive volatility shift of this surface, 0.0 if unshifted */
    @Override
    public double getShift() {
        return this.volShift;
    }

    /** @return the internal smile array (not a copy), ordered by smile date */
    public SviVolatilitySmile[] getSmiles() {
        return this.smiles;
    }

    /** Stores the forward structure and adopts its valuation date as the surface's valuation date. */
    private void setForwardStructure(EquityForwardStructure forwardStructure) {
        this.forwardStructure = forwardStructure;
        this.valuationDate = forwardStructure.getValuationDate();
    }

    /**
     * Implied volatility for a strike and an expiry date.
     *
     * @param strike the option strike
     * @param expiryDate the expiry date, converted to a year fraction from the valuation date
     * @param currentForwardStructure forward structure used for log-moneyness unless sticky strike is active
     * @return the implied volatility
     */
    @Override
    public double getVolatility(double strike, LocalDate expiryDate, EquityForwardStructure currentForwardStructure) {
        double timeToMaturity = this.dayCounter.getDaycountFraction(this.valuationDate, expiryDate);
        return this.getVolatility(strike, timeToMaturity, currentForwardStructure);
    }

    /**
     * Implied volatility for a strike and a time to maturity.
     *
     * @param strike the option strike
     * @param timeToMaturity time to maturity as year fraction
     * @param currentForwardStructure forward structure used for log-moneyness unless sticky strike is active
     * @return the implied volatility; 0.0 for a negative time to maturity
     */
    @Override
    public double getVolatility(double strike, double timeToMaturity, EquityForwardStructure currentForwardStructure) {
        assert (this.isCalibrated) : "Surface is not calibrated yet";
        EquityForwardStructure moneynessReference = this.useStickyStrike ? this.forwardStructure : currentForwardStructure;
        double logStrike = moneynessReference.getLogMoneyness(strike, timeToMaturity);
        return this.interpolateVolatility(logStrike, timeToMaturity);
    }

    /**
     * Local volatility for a strike and an expiry date.
     *
     * @param strike the option strike
     * @param expiryDate the expiry date
     * @param currentForwardStructure forward structure used to convert the strike to log-moneyness
     * @param strikeShift bump size in log-strike for the finite differences
     * @param timeShift bump size in time for the finite difference
     * @return the local volatility
     */
    @Override
    public double getLocalVolatility(double strike, LocalDate expiryDate, EquityForwardStructure currentForwardStructure, double strikeShift, double timeShift) {
        assert (this.isCalibrated) : "Surface is not calibrated yet";
        double logStrike = currentForwardStructure.getLogMoneyness(strike, expiryDate);
        double timeToMaturity = this.dayCounter.getDaycountFraction(this.valuationDate, expiryDate);
        return this.getLocalVolatility(logStrike, timeToMaturity, currentForwardStructure, strikeShift, timeShift);
    }

    /**
     * Local volatility for a log-strike and a time to maturity, computed from finite
     * differences of the interpolated total variance {@code w}:
     * {@code sqrt(w_t / ((x * w_x / (2w) - 1)^2 + w_xx / 2 - (1/4 + 1/w) * w_x^2 / 4))}.
     *
     * @param logStrike log-moneyness with respect to {@code currentForwardStructure}
     * @param timeToMaturity time to maturity as year fraction
     * @param currentForwardStructure the current forward structure
     * @param strikeShift bump size in log-strike (central differences)
     * @param timeShift bump size in time (forward difference)
     * @return the local volatility; 0.0 for a negative time to maturity
     */
    @Override
    public double getLocalVolatility(double logStrike, double timeToMaturity, EquityForwardStructure currentForwardStructure, double strikeShift, double timeShift) {
        assert (this.isCalibrated) : "Surface is not calibrated yet";
        if (this.useStickyStrike) {
            // Re-express the log-strike relative to the forward known at calibration time.
            double expiryTimeAsofCalib = timeToMaturity + this.dayCounter.getDaycountFraction(this.valuationDate, currentForwardStructure.getValuationDate());
            logStrike += Math.log(currentForwardStructure.getForward(timeToMaturity) / this.forwardStructure.getForward(expiryTimeAsofCalib));
        }
        if (timeToMaturity < 0.0) {
            return 0.0;
        }
        if (timeToMaturity < MIN_TIME_TO_MATURITY) {
            // Avoid evaluating at (almost) zero total variance.
            return this.getLocalVolatility(logStrike, MIN_TIME_TO_MATURITY, currentForwardStructure, strikeShift, timeShift);
        }
        double f = this.interpolateTotalVariance(logStrike, timeToMaturity);
        double f_t = (this.interpolateTotalVariance(logStrike, timeToMaturity + timeShift) - f) / timeShift;
        double f_plu = this.interpolateTotalVariance(logStrike + strikeShift, timeToMaturity);
        double f_min = this.interpolateTotalVariance(logStrike - strikeShift, timeToMaturity);
        double f_x = 0.5 * (f_plu - f_min) / strikeShift;
        double f_xx = (f_plu + f_min - 2.0 * f) / strikeShift / strikeShift;
        double g = 0.5 * f_x * logStrike / f - 1.0;
        double denominator = g * g + (0.5 * f_xx - 0.25 * (0.25 + 1.0 / f) * f_x * f_x);
        return Math.sqrt(f_t / denominator);
    }

    /**
     * Converts the interpolated total variance into a volatility.
     *
     * @return {@code sqrt(totalVariance / timeToMaturity)}; 0.0 for a negative time to maturity
     */
    private double interpolateVolatility(double logStrike, double timeToMaturity) {
        if (timeToMaturity < 0.0) {
            return 0.0;
        }
        if (timeToMaturity < MIN_TIME_TO_MATURITY) {
            // Avoid the division by (almost) zero.
            return this.interpolateVolatility(logStrike, MIN_TIME_TO_MATURITY);
        }
        return Math.sqrt(this.interpolateTotalVariance(logStrike, timeToMaturity) / timeToMaturity);
    }

    /**
     * Interpolates the total variance at the given log-strike linearly in time across
     * all smiles and applies the volatility shift, if any.
     */
    private double interpolateTotalVariance(double logStrike, double timeToMaturity) {
        int len = this.smileTimes.length;
        double[] totalVariances = new double[len];
        totalVariances[0] = 0.0;
        for (int i = 1; i < len; ++i) {
            totalVariances[i] = this.smiles[i - 1].getTotalVariance(logStrike);
        }
        RationalFunctionInterpolation interpolator = new RationalFunctionInterpolation(this.smileTimes, totalVariances, RationalFunctionInterpolation.InterpolationMethod.LINEAR, RationalFunctionInterpolation.ExtrapolationMethod.LINEAR);
        double totalVariance = interpolator.getValue(timeToMaturity);
        if (this.volShift == 0.0) {
            return totalVariance;
        }
        // (sigma + shift)^2 * T expanded with sigma = sqrt(totalVariance / T).
        return totalVariance + this.volShift * (2.0 * Math.sqrt(totalVariance * timeToMaturity) + this.volShift * timeToMaturity);
    }

    /**
     * Calibrates one SVI smile per distinct date of the given volatility points.
     *
     * <p>Calibration is best effort: a date whose fit throws a {@link SolverException}
     * is left out of the surface. The surface is flagged as calibrated only if at
     * least one smile could be fitted.
     *
     * @param forwardStructure forward structure the quotes refer to; becomes the surface's reference
     * @param volaPoints the quoted volatilities
     */
    @Override
    public void calibrate(EquityForwardStructure forwardStructure, ArrayList<VolatilityPoint> volaPoints) {
        assert (this.volShift == 0.0) : "A shifted SVI surface cannot be calibrated";
        this.setForwardStructure(forwardStructure);
        Map<LocalDate, List<VolatilityPoint>> groupedPoints = volaPoints.stream().collect(Collectors.groupingBy(VolatilityPoint::getDate));
        List<LocalDate> sortedSmileDates = new ArrayList<>(groupedPoints.keySet());
        sortedSmileDates.sort(Comparator.naturalOrder());

        // Collect only successfully fitted smiles so that smiles and smileTimes stay dense
        // and aligned; a skipped date must not leave a null smile or a 0.0 time behind.
        List<SviVolatilitySmile> fittedSmiles = new ArrayList<>(sortedSmileDates.size());
        List<Double> fittedTimes = new ArrayList<>(sortedSmileDates.size());
        for (LocalDate date : sortedSmileDates) {
            List<VolatilityPoint> thisPoints = groupedPoints.get(date);
            thisPoints.sort(Comparator.comparing(pt -> pt.getStrike()));
            double forward = forwardStructure.getDividendAdjustedStrike(forwardStructure.getForward(date), date);
            double ttm = this.dayCounter.getDaycountFraction(this.valuationDate, date);
            ArrayList<Double> logStrikes = new ArrayList<>(thisPoints.size());
            ArrayList<Double> totalVariances = new ArrayList<>(thisPoints.size());
            for (VolatilityPoint point : thisPoints) {
                totalVariances.add(ttm * point.getVolatility() * point.getVolatility());
                logStrikes.add(Math.log(forwardStructure.getDividendAdjustedStrike(point.getStrike(), date) / forward));
            }
            double[] sviParams;
            try {
                sviParams = SviVolatilitySurface.calibrateSviSmile(ttm, logStrikes, totalVariances);
            }
            catch (SolverException ignored) {
                // Deliberate best effort: drop this expiry and keep the remaining smiles.
                continue;
            }
            fittedTimes.add(ttm);
            fittedSmiles.add(new SviVolatilitySmile(date, sviParams[0], sviParams[1], sviParams[2], sviParams[3], sviParams[4]));
        }

        this.smiles = fittedSmiles.toArray(new SviVolatilitySmile[0]);
        this.smileTimes = new double[fittedTimes.size() + 1];
        this.smileTimes[0] = 0.0;
        for (int i = 0; i < fittedTimes.size(); ++i) {
            this.smileTimes[i + 1] = fittedTimes.get(i);
        }
        this.isCalibrated = this.smiles.length > 0;
    }

    /**
     * Fits the five SVI parameters to the given total variances with an equally
     * weighted Levenberg-Marquardt optimization, limited to 100 iterations.
     *
     * @param ttm time to maturity of the smile (currently unused)
     * @param logStrikes log-strikes of the quotes
     * @param totalVariances total variances of the quotes, aligned with {@code logStrikes}
     * @return the best-fit parameters in the order expected by {@code SviVolatilitySmile.sviTotalVariance}
     * @throws SolverException if the optimizer fails
     */
    private static double[] calibrateSviSmile(double ttm, final ArrayList<Double> logStrikes, ArrayList<Double> totalVariances) throws SolverException {
        final int numberOfPoints = logStrikes.size();
        LevenbergMarquardt optimizer = new LevenbergMarquardt(){
            private static final long serialVersionUID = -2542034123359128169L;

            @Override
            public void setValues(double[] parameters, double[] values) {
                for (int i = 0; i < logStrikes.size(); ++i) {
                    values[i] = SviVolatilitySmile.sviTotalVariance(logStrikes.get(i), parameters[0], parameters[1], parameters[2], parameters[3], parameters[4]);
                }
            }
        };
        double[] initialGuess = SviVolatilitySmile.sviInitialGuess(logStrikes, totalVariances);
        double[] weights = new double[numberOfPoints];
        Arrays.fill(weights, 1.0);
        double[] targetValues = new double[numberOfPoints];
        for (int i = 0; i < numberOfPoints; ++i) {
            targetValues[i] = totalVariances.get(i);
        }
        optimizer.setInitialParameters(initialGuess);
        optimizer.setWeights(weights);
        optimizer.setMaxIteration(100);
        optimizer.setTargetValues(targetValues);
        optimizer.run();
        return optimizer.getBestFitParameters();
    }
}

