/*
 * Decompiled with CFR 0.152.
 */
package com.opengamma.strata.pricer.impl.volatility.local;

import com.google.common.collect.ImmutableList;
import com.opengamma.strata.collect.array.DoubleArray;
import com.opengamma.strata.collect.array.DoubleMatrix;
import com.opengamma.strata.collect.tuple.DoublesPair;
import com.opengamma.strata.collect.tuple.Pair;
import com.opengamma.strata.market.ValueType;
import com.opengamma.strata.market.curve.interpolator.CurveInterpolator;
import com.opengamma.strata.market.curve.interpolator.CurveInterpolators;
import com.opengamma.strata.market.surface.DefaultSurfaceMetadata;
import com.opengamma.strata.market.surface.InterpolatedNodalSurface;
import com.opengamma.strata.market.surface.Surface;
import com.opengamma.strata.market.surface.SurfaceMetadata;
import com.opengamma.strata.market.surface.SurfaceName;
import com.opengamma.strata.market.surface.interpolator.GridSurfaceInterpolator;
import com.opengamma.strata.market.surface.interpolator.SurfaceInterpolator;
import com.opengamma.strata.math.MathUtils;
import com.opengamma.strata.pricer.fxopt.RecombiningTrinomialTreeData;
import com.opengamma.strata.pricer.impl.option.BlackFormulaRepository;
import com.opengamma.strata.pricer.impl.option.BlackScholesFormulaRepository;
import com.opengamma.strata.pricer.impl.volatility.local.LocalVolatilityCalculator;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.List;
import java.util.Objects;
import java.util.function.Function;

public class ImpliedTrinomialTreeLocalVolatilityCalculator
implements LocalVolatilityCalculator {
    private static final GridSurfaceInterpolator DEFAULT_INTERPOLATOR = GridSurfaceInterpolator.of((CurveInterpolator)CurveInterpolators.TIME_SQUARE, (CurveInterpolator)CurveInterpolators.LINEAR);
    private final int nSteps;
    private final double maxTime;
    private final SurfaceInterpolator interpolator;

    public ImpliedTrinomialTreeLocalVolatilityCalculator() {
        this(20, 3.0, (SurfaceInterpolator)DEFAULT_INTERPOLATOR);
    }

    public ImpliedTrinomialTreeLocalVolatilityCalculator(int nSteps, double maxTime) {
        this(nSteps, maxTime, (SurfaceInterpolator)DEFAULT_INTERPOLATOR);
    }

    public ImpliedTrinomialTreeLocalVolatilityCalculator(int nSteps, double maxTime, SurfaceInterpolator interpolator) {
        this.nSteps = nSteps;
        this.maxTime = maxTime;
        this.interpolator = interpolator;
    }

    public InterpolatedNodalSurface localVolatilityFromImpliedVolatility(final Surface impliedVolatilitySurface, double spot, Function<Double, Double> interestRate, Function<Double, Double> dividendRate) {
        Function<DoublesPair, Double> surface = new Function<DoublesPair, Double>(){

            @Override
            public Double apply(DoublesPair tk) {
                return impliedVolatilitySurface.zValue(tk);
            }
        };
        ImmutableList localVolData = (ImmutableList)this.calibrate(surface, spot, interestRate, dividendRate).getFirst();
        DefaultSurfaceMetadata metadata = DefaultSurfaceMetadata.builder().xValueType(ValueType.YEAR_FRACTION).yValueType(ValueType.STRIKE).zValueType(ValueType.LOCAL_VOLATILITY).surfaceName(SurfaceName.of((String)("localVol_" + impliedVolatilitySurface.getName()))).build();
        return InterpolatedNodalSurface.ofUnsorted((SurfaceMetadata)metadata, (DoubleArray)DoubleArray.ofUnsafe((double[])((double[])localVolData.get(0))), (DoubleArray)DoubleArray.ofUnsafe((double[])((double[])localVolData.get(1))), (DoubleArray)DoubleArray.ofUnsafe((double[])((double[])localVolData.get(2))), (SurfaceInterpolator)this.interpolator);
    }

    public RecombiningTrinomialTreeData calibrateImpliedVolatility(Function<DoublesPair, Double> impliedVolatilitySurface, double spot, Function<Double, Double> interestRate, Function<Double, Double> dividendRate) {
        return (RecombiningTrinomialTreeData)this.calibrate(impliedVolatilitySurface, spot, interestRate, dividendRate).getSecond();
    }

    public InterpolatedNodalSurface localVolatilityFromPrice(Surface callPriceSurface, double spot, Function<Double, Double> interestRate, Function<Double, Double> dividendRate) {
        double[][] stateValue = new double[this.nSteps + 1][];
        double[] df = new double[this.nSteps];
        ArrayList<DoubleMatrix> probability = new ArrayList<DoubleMatrix>(this.nSteps);
        int nTotal = (this.nSteps - 1) * (this.nSteps - 1) + 1;
        double[] timeRes = new double[nTotal];
        double[] spotRes = new double[nTotal];
        double[] volRes = new double[nTotal];
        double refPrice = callPriceSurface.zValue(this.maxTime, spot) * Math.exp(interestRate.apply(this.maxTime) * this.maxTime);
        double refForward = spot * Math.exp((interestRate.apply(this.maxTime) - dividendRate.apply(this.maxTime)) * this.maxTime);
        double refVolatility = BlackFormulaRepository.impliedVolatility(refPrice, refForward, spot, this.maxTime, true);
        double dt = this.maxTime / (double)this.nSteps;
        double dx = refVolatility * Math.sqrt(3.0 * dt);
        double upFactor = Math.exp(dx);
        double downFactor = Math.exp(-dx);
        double[] adSec = new double[2 * this.nSteps + 1];
        double[] assetPrice = new double[2 * this.nSteps + 1];
        for (int i = this.nSteps; i > -1; --i) {
            int j;
            if (i == 0) {
                this.resolveFirstLayer(interestRate, dividendRate, nTotal, dt, spot, adSec, assetPrice, timeRes, spotRes, volRes, df, stateValue, probability);
                continue;
            }
            double time = dt * (double)i;
            double zeroRate = interestRate.apply(time);
            double zeroDividendRate = dividendRate.apply(time);
            int nNodes = 2 * i + 1;
            double[] assetPriceLocal = new double[nNodes];
            double[] callOptionPrice = new double[nNodes];
            double[] putOptionPrice = new double[nNodes];
            int position = i - 1;
            double assetTmp = spot * Math.pow(upFactor, i);
            for (j = nNodes - 1; j > position - 1; --j) {
                assetPriceLocal[j] = assetTmp;
                callOptionPrice[j] = callPriceSurface.zValue(time, assetPriceLocal[j]);
                assetTmp *= downFactor;
            }
            assetTmp = spot * Math.pow(downFactor, i);
            for (j = 0; j < position + 2; ++j) {
                assetPriceLocal[j] = assetTmp;
                putOptionPrice[j] = callPriceSurface.zValue(time, assetPriceLocal[j]) - spot * Math.exp(-zeroDividendRate * time) + Math.exp(-zeroRate * time) * assetPriceLocal[j];
                assetTmp *= upFactor;
            }
            this.resolveLayer(interestRate, dividendRate, i, nTotal, position, dt, zeroRate, zeroDividendRate, callOptionPrice, putOptionPrice, adSec, assetPrice, assetPriceLocal, timeRes, spotRes, volRes, df, stateValue, probability);
        }
        DefaultSurfaceMetadata metadata = DefaultSurfaceMetadata.builder().xValueType(ValueType.YEAR_FRACTION).yValueType(ValueType.STRIKE).zValueType(ValueType.LOCAL_VOLATILITY).surfaceName(SurfaceName.of((String)("localVol_" + callPriceSurface.getName()))).build();
        return InterpolatedNodalSurface.ofUnsorted((SurfaceMetadata)metadata, (DoubleArray)DoubleArray.ofUnsafe((double[])timeRes), (DoubleArray)DoubleArray.ofUnsafe((double[])spotRes), (DoubleArray)DoubleArray.ofUnsafe((double[])volRes), (SurfaceInterpolator)this.interpolator);
    }

    private Pair<ImmutableList<double[]>, RecombiningTrinomialTreeData> calibrate(Function<DoublesPair, Double> impliedVolatilitySurface, double spot, Function<Double, Double> interestRate, Function<Double, Double> dividendRate) {
        double[][] stateValue = new double[this.nSteps + 1][];
        double[] df = new double[this.nSteps];
        double[] timePrim = new double[this.nSteps + 1];
        ArrayList<DoubleMatrix> probability = new ArrayList<DoubleMatrix>(this.nSteps);
        int nTotal = (this.nSteps - 1) * (this.nSteps - 1) + 1;
        double[] timeRes = new double[nTotal];
        double[] spotRes = new double[nTotal];
        double[] volRes = new double[nTotal];
        double volatility = impliedVolatilitySurface.apply(DoublesPair.of((double)this.maxTime, (double)spot));
        double dt = this.maxTime / (double)this.nSteps;
        double dx = volatility * Math.sqrt(3.0 * dt);
        double upFactor = Math.exp(dx);
        double downFactor = Math.exp(-dx);
        double[] adSec = new double[2 * this.nSteps + 1];
        double[] assetPrice = new double[2 * this.nSteps + 1];
        for (int i = this.nSteps; i > -1; --i) {
            double impliedVol;
            int j;
            timePrim[i] = dt * (double)i;
            if (i == 0) {
                this.resolveFirstLayer(interestRate, dividendRate, nTotal, dt, spot, adSec, assetPrice, timeRes, spotRes, volRes, df, stateValue, probability);
                continue;
            }
            double zeroRate = interestRate.apply(timePrim[i]);
            double zeroDividendRate = dividendRate.apply(timePrim[i]);
            double zeroCostRate = zeroRate - zeroDividendRate;
            int nNodes = 2 * i + 1;
            double[] assetPriceLocal = new double[nNodes];
            double[] callOptionPrice = new double[nNodes];
            double[] putOptionPrice = new double[nNodes];
            int position = i - 1;
            double assetTmp = spot * Math.pow(upFactor, i);
            for (j = nNodes - 1; j > position - 1; --j) {
                assetPriceLocal[j] = assetTmp;
                impliedVol = impliedVolatilitySurface.apply(DoublesPair.of((double)timePrim[i], (double)assetPriceLocal[j]));
                callOptionPrice[j] = BlackScholesFormulaRepository.price(spot, assetPriceLocal[j], timePrim[i], impliedVol, zeroRate, zeroCostRate, true);
                assetTmp *= downFactor;
            }
            assetTmp = spot * Math.pow(downFactor, i);
            for (j = 0; j < position + 2; ++j) {
                assetPriceLocal[j] = assetTmp;
                impliedVol = impliedVolatilitySurface.apply(DoublesPair.of((double)timePrim[i], (double)assetPriceLocal[j]));
                putOptionPrice[j] = BlackScholesFormulaRepository.price(spot, assetPriceLocal[j], timePrim[i], impliedVol, zeroRate, zeroCostRate, false);
                assetTmp *= upFactor;
            }
            this.resolveLayer(interestRate, dividendRate, i, nTotal, position, dt, zeroRate, zeroDividendRate, callOptionPrice, putOptionPrice, adSec, assetPrice, assetPriceLocal, timeRes, spotRes, volRes, df, stateValue, probability);
        }
        ImmutableList localVolData = ImmutableList.of((Object)timeRes, (Object)spotRes, (Object)volRes);
        RecombiningTrinomialTreeData treeData = RecombiningTrinomialTreeData.of(DoubleMatrix.ofUnsafe((double[][])stateValue), probability, DoubleArray.ofUnsafe((double[])df), DoubleArray.ofUnsafe((double[])timePrim));
        return Pair.of((Object)localVolData, (Object)treeData);
    }

    private void resolveFirstLayer(Function<Double, Double> interestRate, Function<Double, Double> dividendRate, int nTotal, double dt, double spot, double[] adSec, double[] assetPrice, double[] timeRes, double[] spotRes, double[] volRes, double[] df, double[][] stateValue, List<DoubleMatrix> probability) {
        double discountFactor = Math.exp(-interestRate.apply(dt).doubleValue() * dt);
        double fwdFactor = Math.exp((interestRate.apply(dt) - dividendRate.apply(dt)) * dt);
        double upProb = adSec[2] / discountFactor;
        double midProb = this.getMiddle(upProb, fwdFactor, spot, assetPrice[0], assetPrice[1], assetPrice[2]);
        double dwProb = 1.0 - upProb - midProb;
        double fwd = spot * fwdFactor;
        timeRes[nTotal - 1] = dt;
        spotRes[nTotal - 1] = spot;
        double var = (dwProb * MathUtils.pow2((double)(assetPrice[0] - fwd)) + midProb * MathUtils.pow2((double)(assetPrice[1] - fwd)) + upProb * MathUtils.pow2((double)(assetPrice[2] - fwd))) / (fwd * fwd * dt);
        volRes[nTotal - 1] = Math.sqrt(0.5 * (var + volRes[nTotal - 2] * volRes[nTotal - 2]));
        probability.add(0, DoubleMatrix.ofUnsafe((double[][])new double[][]{{dwProb, midProb, upProb}}));
        df[0] = discountFactor;
        stateValue[0] = new double[]{spot};
    }

    private void resolveLayer(Function<Double, Double> interestRate, Function<Double, Double> dividendRate, int i, int nTotal, int position, double dt, double zeroRate, double zeroDividendRate, double[] callOptionPrice, double[] putOptionPrice, double[] adSec, double[] assetPrice, double[] assetPriceLocal, double[] timeRes, double[] spotRes, double[] volRes, double[] df, double[][] stateValue, List<DoubleMatrix> probability) {
        int k;
        int j;
        int positionLocal = position;
        int nNodes = callOptionPrice.length;
        double[] adSecLocal = new double[nNodes];
        for (j = nNodes - 1; j > positionLocal; --j) {
            adSecLocal[j] = callOptionPrice[j - 1];
            for (k = j + 1; k < nNodes; ++k) {
                int n = j;
                adSecLocal[n] = adSecLocal[n] - (assetPriceLocal[k] - assetPriceLocal[j - 1]) * adSecLocal[k];
            }
            int n = j;
            adSecLocal[n] = adSecLocal[n] / (assetPriceLocal[j] - assetPriceLocal[j - 1]);
        }
        ++positionLocal;
        for (j = 0; j < positionLocal; ++j) {
            adSecLocal[j] = putOptionPrice[j + 1];
            for (k = 0; k < j; ++k) {
                int n = j;
                adSecLocal[n] = adSecLocal[n] - (assetPriceLocal[j + 1] - assetPriceLocal[k]) * adSecLocal[k];
            }
            int n = j;
            adSecLocal[n] = adSecLocal[n] / (assetPriceLocal[j + 1] - assetPriceLocal[j]);
        }
        if (i != this.nSteps) {
            int k2;
            double time = dt * (double)i;
            double timeNext = dt * (double)(i - 1);
            double rate = (zeroRate * time - interestRate.apply(timeNext) * timeNext) / dt;
            double dividend = (zeroDividendRate * time - dividendRate.apply(timeNext) * timeNext) / dt;
            double cost = rate - dividend;
            double discountFactor = Math.exp(-rate * dt);
            double fwdFactor = Math.exp(cost * dt);
            double[][] prob = new double[nNodes][3];
            prob[nNodes - 1][2] = adSec[nNodes + 1] / adSecLocal[nNodes - 1] / discountFactor;
            prob[nNodes - 1][1] = this.getMiddle(prob[nNodes - 1][2], fwdFactor, assetPriceLocal[nNodes - 1], assetPrice[nNodes - 1], assetPrice[nNodes], assetPrice[nNodes + 1]);
            prob[nNodes - 1][0] = 1.0 - prob[nNodes - 1][2] - prob[nNodes - 1][1];
            this.correctProbability(prob[nNodes - 1], fwdFactor, assetPriceLocal[nNodes - 1], assetPrice[nNodes - 1], assetPrice[nNodes], assetPrice[nNodes + 1]);
            prob[nNodes - 2][2] = (adSec[nNodes] / discountFactor - prob[nNodes - 1][1] * adSecLocal[nNodes - 1]) / adSecLocal[nNodes - 2];
            prob[nNodes - 2][1] = this.getMiddle(prob[nNodes - 2][2], fwdFactor, assetPriceLocal[nNodes - 2], assetPrice[nNodes - 2], assetPrice[nNodes - 1], assetPrice[nNodes]);
            prob[nNodes - 2][0] = 1.0 - prob[nNodes - 2][2] - prob[nNodes - 2][1];
            this.correctProbability(prob[nNodes - 2], fwdFactor, assetPriceLocal[nNodes - 2], assetPrice[nNodes - 2], assetPrice[nNodes - 1], assetPrice[nNodes]);
            for (int j2 = nNodes - 3; j2 > -1; --j2) {
                prob[j2][2] = (adSec[j2 + 2] / discountFactor - prob[j2 + 2][0] * adSecLocal[j2 + 2] - prob[j2 + 1][1] * adSecLocal[j2 + 1]) / adSecLocal[j2];
                prob[j2][1] = this.getMiddle(prob[j2][2], fwdFactor, assetPriceLocal[j2], assetPrice[j2], assetPrice[j2 + 1], assetPrice[j2 + 2]);
                prob[j2][0] = 1.0 - prob[j2][1] - prob[j2][2];
                this.correctProbability(prob[j2], fwdFactor, assetPriceLocal[j2], assetPrice[j2], assetPrice[j2 + 1], assetPrice[j2 + 2]);
            }
            int offset = nTotal - i * i - 1;
            double[] varBare = new double[nNodes];
            for (k2 = 0; k2 < nNodes; ++k2) {
                double fwd = assetPriceLocal[k2] * fwdFactor;
                varBare[k2] = (prob[k2][0] * MathUtils.pow2((double)(assetPrice[k2] - fwd)) + prob[k2][1] * MathUtils.pow2((double)(assetPrice[k2 + 1] - fwd)) + prob[k2][2] * MathUtils.pow2((double)(assetPrice[k2 + 2] - fwd))) / (fwd * fwd * dt);
                if (!(varBare[k2] < 0.0)) continue;
                throw new IllegalArgumentException("Negative variance");
            }
            for (k2 = 0; k2 < nNodes - 2; ++k2) {
                double var = k2 == 0 || k2 == nNodes - 3 ? (varBare[k2] + varBare[k2 + 1] + varBare[k2 + 2]) / 3.0 : (varBare[k2 - 1] + varBare[k2] + varBare[k2 + 1] + varBare[k2 + 2] + varBare[k2 + 3]) / 5.0;
                volRes[offset + k2] = i == this.nSteps - 1 ? Math.sqrt(var) : Math.sqrt(0.5 * (var + volRes[offset - (2 * i - k2)] * volRes[offset - (2 * i - k2)]));
                timeRes[offset + k2] = dt * ((double)i + 1.0);
                spotRes[offset + k2] = assetPriceLocal[k2 + 1];
            }
            probability.add(0, DoubleMatrix.ofUnsafe((double[][])prob));
            df[i] = discountFactor;
        }
        stateValue[i] = Arrays.copyOf(assetPriceLocal, nNodes);
        System.arraycopy(adSecLocal, 0, adSec, 0, nNodes);
        System.arraycopy(assetPriceLocal, 0, assetPrice, 0, nNodes);
    }

    private void correctProbability(double[] probability, double factor, double assetBase, double assertPriceLow, double assertPriceMid, double assetPriceHigh) {
        if (!(probability[2] > 0.0 && probability[1] > 0.0 && probability[0] > 0.0)) {
            double fwd = assetBase * factor;
            if (fwd <= assertPriceMid && fwd > assertPriceLow) {
                probability[0] = 0.5 * (fwd - assertPriceLow) / (assetPriceHigh - assertPriceLow);
                probability[2] = 0.5 * ((assetPriceHigh - fwd) / (assetPriceHigh - assertPriceLow) + (assertPriceMid - fwd) / (assertPriceMid - assertPriceLow));
            } else if (fwd < assetPriceHigh && fwd > assertPriceMid) {
                probability[0] = 0.5 * ((fwd - assertPriceMid) / (assetPriceHigh - assertPriceLow) + (fwd - assertPriceLow) / (assetPriceHigh - assertPriceLow));
                probability[2] = 0.5 * (assetPriceHigh - fwd) / assetPriceHigh;
            }
            probability[1] = 1.0 - probability[0] - probability[2];
        }
    }

    private double getMiddle(double upProbability, double factor, double assetBase, double assetPrevDw, double assetPrevMd, double assetPrevUp) {
        return (factor * assetBase - assetPrevDw - upProbability * (assetPrevUp - assetPrevDw)) / (assetPrevMd - assetPrevDw);
    }
}

