/*
 * Decompiled with CFR 0.152.
 */
package net.finmath.montecarlo.assetderivativevaluation.models;

import java.util.Map;
import net.finmath.marketdata.model.curves.DiscountCurve;
import net.finmath.modelling.descriptor.HestonModelDescriptor;
import net.finmath.montecarlo.RandomVariableFactory;
import net.finmath.montecarlo.RandomVariableFromArrayFactory;
import net.finmath.montecarlo.model.AbstractProcessModel;
import net.finmath.montecarlo.process.MonteCarloProcess;
import net.finmath.stochastic.RandomVariable;
import net.finmath.stochastic.Scalar;

/**
 * Heston stochastic volatility model for a single asset, simulated in the state space
 * (X, V) = (log S, V).
 *
 * <p>The drift and factor loadings implemented below correspond to the dynamics
 * <pre>
 *   dX = (r - V+/2) dt + sqrt(V+) dW1
 *   dV = kappa (theta - V+) dt + xi sqrt(V+) (rho dW1 + sqrt(1 - rho^2) dW2)
 * </pre>
 * where W1, W2 are independent Brownian drivers. V+ is the simulated variance made
 * non-negative according to the chosen {@link Scheme}: floored at zero for
 * {@link Scheme#FULL_TRUNCATION}, absolute value for {@link Scheme#REFLECTION}.
 *
 * <p>The drift rate r and the numeraire are taken either from constant rates or from
 * discount curves, depending on the constructor used.
 */
public class HestonModel
extends AbstractProcessModel {
    private static final RandomVariable ZERO = new Scalar(0.0);
    private final RandomVariable initialValue;
    private final DiscountCurve discountCurveForForwardRate;
    private final RandomVariable riskFreeRate;
    private final RandomVariable volatility;
    private final DiscountCurve discountCurveForDiscountRate;
    private final RandomVariable discountRate;
    private final RandomVariable theta;
    private final RandomVariable kappa;
    private final RandomVariable xi;
    private final RandomVariable rho;
    // sqrt(1 - rho^2): loading of the variance on the second, independent Brownian driver.
    private final RandomVariable rhoBar;
    private final Scheme scheme;
    private final RandomVariableFactory randomVariableFactory;
    // Initial state {log(S0), sigma0^2}. Computed once in the constructor so the model stays
    // immutable and can be shared safely; getInitialState hands out copies.
    private final RandomVariable[] initialValueVector;

    /**
     * Creates a model from a descriptor. The drift and numeraire are taken from the
     * descriptor's discount curves.
     *
     * @param descriptor model descriptor supplying initial value, curves and Heston parameters.
     * @param scheme scheme used to keep the simulated variance non-negative.
     * @param randomVariableFactory factory used to create random variables.
     */
    public HestonModel(HestonModelDescriptor descriptor, Scheme scheme, RandomVariableFactory randomVariableFactory) {
        this(randomVariableFactory.createRandomVariable(descriptor.getInitialValue()), descriptor.getDiscountCurveForForwardRate(), randomVariableFactory.createRandomVariable(descriptor.getVolatility()), descriptor.getDiscountCurveForDiscountRate(), randomVariableFactory.createRandomVariable(descriptor.getTheta()), randomVariableFactory.createRandomVariable(descriptor.getKappa()), randomVariableFactory.createRandomVariable(descriptor.getXi()), randomVariableFactory.createRandomVariable(descriptor.getRho()), scheme, randomVariableFactory);
    }

    /**
     * Creates a model whose drift rate and numeraire are derived from discount curves.
     *
     * @param initialValue initial asset value S0.
     * @param discountCurveForForwardRate curve from which the per-time-step drift rate is derived.
     * @param volatility initial volatility sigma0 (the initial variance is sigma0^2).
     * @param discountCurveForDiscountRate curve defining the numeraire as 1 / df(t).
     * @param theta long-run variance level.
     * @param kappa mean-reversion speed of the variance.
     * @param xi volatility of the variance.
     * @param rho correlation between the asset and variance drivers.
     * @param scheme scheme used to keep the simulated variance non-negative.
     * @param randomVariableFactory factory used to create random variables.
     */
    public HestonModel(RandomVariable initialValue, DiscountCurve discountCurveForForwardRate, RandomVariable volatility, DiscountCurve discountCurveForDiscountRate, RandomVariable theta, RandomVariable kappa, RandomVariable xi, RandomVariable rho, Scheme scheme, RandomVariableFactory randomVariableFactory) {
        this(initialValue, discountCurveForForwardRate, null, volatility, discountCurveForDiscountRate, null, theta, kappa, xi, rho, scheme, randomVariableFactory);
    }

    /**
     * Creates a model with a constant drift rate and a numeraire exp(discountRate * t).
     *
     * @param initialValue initial asset value S0.
     * @param riskFreeRate rate used in the drift of log S.
     * @param volatility initial volatility sigma0 (the initial variance is sigma0^2).
     * @param discountRate rate defining the numeraire exp(discountRate * t).
     * @param theta long-run variance level.
     * @param kappa mean-reversion speed of the variance.
     * @param xi volatility of the variance.
     * @param rho correlation between the asset and variance drivers.
     * @param scheme scheme used to keep the simulated variance non-negative.
     * @param randomVariableFactory factory used to create random variables.
     */
    public HestonModel(RandomVariable initialValue, RandomVariable riskFreeRate, RandomVariable volatility, RandomVariable discountRate, RandomVariable theta, RandomVariable kappa, RandomVariable xi, RandomVariable rho, Scheme scheme, RandomVariableFactory randomVariableFactory) {
        this(initialValue, null, riskFreeRate, volatility, null, discountRate, theta, kappa, xi, rho, scheme, randomVariableFactory);
    }

    /**
     * Creates a model with constant rates from primitive parameters.
     *
     * @see #HestonModel(RandomVariable, RandomVariable, RandomVariable, RandomVariable, RandomVariable, RandomVariable, RandomVariable, RandomVariable, Scheme, RandomVariableFactory)
     */
    public HestonModel(double initialValue, double riskFreeRate, double volatility, double discountRate, double theta, double kappa, double xi, double rho, Scheme scheme, RandomVariableFactory randomVariableFactory) {
        this(randomVariableFactory.createRandomVariable(initialValue), randomVariableFactory.createRandomVariable(riskFreeRate), randomVariableFactory.createRandomVariable(volatility), randomVariableFactory.createRandomVariable(discountRate), randomVariableFactory.createRandomVariable(theta), randomVariableFactory.createRandomVariable(kappa), randomVariableFactory.createRandomVariable(xi), randomVariableFactory.createRandomVariable(rho), scheme, randomVariableFactory);
    }

    /**
     * Creates a model with constant rates, using a {@link RandomVariableFromArrayFactory}.
     */
    public HestonModel(double initialValue, double riskFreeRate, double volatility, double discountRate, double theta, double kappa, double xi, double rho, Scheme scheme) {
        this(initialValue, riskFreeRate, volatility, discountRate, theta, kappa, xi, rho, scheme, (RandomVariableFactory)new RandomVariableFromArrayFactory());
    }

    /**
     * Creates a model whose discount rate equals its risk-free rate, using a
     * {@link RandomVariableFromArrayFactory}.
     */
    public HestonModel(double initialValue, double riskFreeRate, double volatility, double theta, double kappa, double xi, double rho, Scheme scheme) {
        this(initialValue, riskFreeRate, volatility, riskFreeRate, theta, kappa, xi, rho, scheme, (RandomVariableFactory)new RandomVariableFromArrayFactory());
    }

    /**
     * Canonical constructor holding every field.
     *
     * <p>For each of forward rate and discount rate, either the curve or the constant rate
     * is expected to be non-null. Where both are set, the curve takes precedence in
     * {@link #getDrift} and {@link #getNumeraire}.
     */
    private HestonModel(RandomVariable initialValue, DiscountCurve discountCurveForForwardRate, RandomVariable riskFreeRate, RandomVariable volatility, DiscountCurve discountCurveForDiscountRate, RandomVariable discountRate, RandomVariable theta, RandomVariable kappa, RandomVariable xi, RandomVariable rho, Scheme scheme, RandomVariableFactory randomVariableFactory) {
        this.initialValue = initialValue;
        this.discountCurveForForwardRate = discountCurveForForwardRate;
        this.riskFreeRate = riskFreeRate;
        this.volatility = volatility;
        this.discountCurveForDiscountRate = discountCurveForDiscountRate;
        this.discountRate = discountRate;
        this.theta = theta;
        this.kappa = kappa;
        this.xi = xi;
        this.rho = rho;
        // sqrt(-(rho^2 - 1)) = sqrt(1 - rho^2)
        this.rhoBar = rho.squared().sub(1.0).mult(-1.0).sqrt();
        this.scheme = scheme;
        this.randomVariableFactory = randomVariableFactory;
        this.initialValueVector = new RandomVariable[] { initialValue.log(), volatility.squared() };
    }

    /**
     * Returns the initial state in the internal (log S, V) state space.
     *
     * @param process the process (not used).
     * @return a fresh array {log(S0), sigma0^2}; callers may modify it without affecting the model.
     */
    @Override
    public RandomVariable[] getInitialState(MonteCarloProcess process) {
        return this.initialValueVector.clone();
    }

    /**
     * Returns the drift of (log S, V): {r - V+/2, kappa (theta - V+)}.
     *
     * <p>If a forward-rate curve is present, r is the continuously compounded forward rate
     * over [t_i, t_{i+1}] implied by that curve. Otherwise r is the constant risk-free rate.
     *
     * @throws UnsupportedOperationException if the scheme is not supported.
     */
    @Override
    public RandomVariable[] getDrift(MonteCarloProcess process, int timeIndex, RandomVariable[] realizationAtTimeIndex, RandomVariable[] realizationPredictor) {
        final RandomVariable stochasticVariance = getNonNegativeVariance(realizationAtTimeIndex[1]);

        final RandomVariable riskFreeRateAtTimeStep;
        if (this.discountCurveForForwardRate != null) {
            double time = process.getTime(timeIndex);
            double timeNext = process.getTime(timeIndex + 1);
            // r = log(df(t_i) / df(t_{i+1})) / (t_{i+1} - t_i)
            double rate = Math.log(this.discountCurveForForwardRate.getDiscountFactor(time) / this.discountCurveForForwardRate.getDiscountFactor(timeNext)) / (timeNext - time);
            riskFreeRateAtTimeStep = this.randomVariableFactory.createRandomVariable(rate);
        } else {
            riskFreeRateAtTimeStep = this.riskFreeRate;
        }

        RandomVariable[] drift = new RandomVariable[2];
        drift[0] = riskFreeRateAtTimeStep.sub(stochasticVariance.div(2.0));
        drift[1] = this.theta.sub(stochasticVariance).mult(this.kappa);
        return drift;
    }

    /**
     * Returns the loadings of the given component on the two Brownian factors.
     *
     * <p>Component 0 (log S) loads {sqrt(V+), 0}. Component 1 (V) loads
     * {xi sqrt(V+) rho, xi sqrt(V+) sqrt(1 - rho^2)}.
     *
     * @return an array of length {@link #getNumberOfFactors()}.
     * @throws UnsupportedOperationException if the scheme is not supported or the component does not exist.
     */
    @Override
    public RandomVariable[] getFactorLoading(MonteCarloProcess process, int timeIndex, int component, RandomVariable[] realizationAtTimeIndex) {
        final RandomVariable stochasticVolatility = getNonNegativeVariance(realizationAtTimeIndex[1]).sqrt();

        RandomVariable[] factorLoadings = new RandomVariable[2];
        if (component == 0) {
            factorLoadings[0] = stochasticVolatility;
            factorLoadings[1] = ZERO;
        } else if (component == 1) {
            RandomVariable volatility = stochasticVolatility.mult(this.xi);
            factorLoadings[0] = volatility.mult(this.rho);
            factorLoadings[1] = volatility.mult(this.rhoBar);
        } else {
            throw new UnsupportedOperationException("Component " + component + " does not exist.");
        }
        return factorLoadings;
    }

    /**
     * Makes the simulated variance non-negative according to {@link #scheme}.
     *
     * @param simulatedVariance the simulated variance component (may be negative).
     * @return the floored variance (full truncation) or its absolute value (reflection).
     * @throws UnsupportedOperationException if the scheme is not supported.
     */
    private RandomVariable getNonNegativeVariance(RandomVariable simulatedVariance) {
        if (this.scheme == Scheme.FULL_TRUNCATION) {
            return simulatedVariance.floor(0.0);
        } else if (this.scheme == Scheme.REFLECTION) {
            return simulatedVariance.abs();
        }
        throw new UnsupportedOperationException("Scheme " + this.scheme.name() + " not supported.");
    }

    /**
     * Maps an internal state to the model value: exp for component 0 (log S to S),
     * identity for component 1 (V).
     */
    @Override
    public RandomVariable applyStateSpaceTransform(MonteCarloProcess process, int timeIndex, int componentIndex, RandomVariable randomVariable) {
        if (componentIndex == 0) {
            return randomVariable.exp();
        }
        if (componentIndex == 1) {
            return randomVariable;
        }
        throw new UnsupportedOperationException("Component " + componentIndex + " does not exist.");
    }

    /**
     * Inverse of {@link #applyStateSpaceTransform}: log for component 0, identity for component 1.
     */
    @Override
    public RandomVariable applyStateSpaceTransformInverse(MonteCarloProcess process, int timeIndex, int componentIndex, RandomVariable randomVariable) {
        if (componentIndex == 0) {
            return randomVariable.log();
        }
        if (componentIndex == 1) {
            return randomVariable;
        }
        throw new UnsupportedOperationException("Component " + componentIndex + " does not exist.");
    }

    /**
     * Returns the numeraire at the given time.
     *
     * @return 1 / df(time) if a discount curve is present, otherwise exp(discountRate * time).
     */
    @Override
    public RandomVariable getNumeraire(MonteCarloProcess process, double time) {
        if (this.discountCurveForDiscountRate != null) {
            return this.randomVariableFactory.createRandomVariable(1.0 / this.discountCurveForDiscountRate.getDiscountFactor(time));
        }
        return this.discountRate.mult(time).exp();
    }

    /** Returns 2: log S and V. */
    @Override
    public int getNumberOfComponents() {
        return 2;
    }

    /**
     * Returns 2. The variance is driven by a second Brownian factor (loaded with
     * sqrt(1 - rho^2)) that is independent of the asset's driver. This matches the
     * length of the arrays returned by {@link #getFactorLoading}.
     */
    @Override
    public int getNumberOfFactors() {
        return 2;
    }

    @Override
    public RandomVariable getRandomVariableForConstant(double value) {
        return this.randomVariableFactory.createRandomVariable(value);
    }

    /**
     * Returns a copy of this model with some parameters replaced.
     *
     * <p>Recognized keys: "randomVariableFactory", "initialValue", "riskFreeRate", "volatility",
     * "discountRate", "theta", "kappa", "xi", "rho". Values other than the factory are converted
     * through {@link RandomVariableFactory#getRandomVariableOrDefault}. Absent or null entries
     * keep the current value.
     *
     * <p>A non-null "riskFreeRate" (respectively "discountRate") replaces the forward-rate
     * (respectively discount-rate) curve. Without such an override, an existing curve is kept.
     *
     * @param dataModified map of parameter name to new value.
     * @return a new model. The scheme is unchanged and the (possibly new) factory is used.
     */
    @Override
    public HestonModel getCloneWithModifiedData(Map<String, Object> dataModified) {
        RandomVariableFactory newRandomVariableFactory = (RandomVariableFactory)dataModified.getOrDefault("randomVariableFactory", this.randomVariableFactory);

        Object riskFreeRateModified = dataModified.get("riskFreeRate");
        Object discountRateModified = dataModified.get("discountRate");

        RandomVariable newInitialValue = RandomVariableFactory.getRandomVariableOrDefault(newRandomVariableFactory, dataModified.get("initialValue"), this.initialValue);
        RandomVariable newRiskFreeRate = RandomVariableFactory.getRandomVariableOrDefault(newRandomVariableFactory, riskFreeRateModified, this.riskFreeRate);
        RandomVariable newVolatility = RandomVariableFactory.getRandomVariableOrDefault(newRandomVariableFactory, dataModified.get("volatility"), this.volatility);
        RandomVariable newDiscountRate = RandomVariableFactory.getRandomVariableOrDefault(newRandomVariableFactory, discountRateModified, this.discountRate);
        RandomVariable newTheta = RandomVariableFactory.getRandomVariableOrDefault(newRandomVariableFactory, dataModified.get("theta"), this.theta);
        RandomVariable newKappa = RandomVariableFactory.getRandomVariableOrDefault(newRandomVariableFactory, dataModified.get("kappa"), this.kappa);
        RandomVariable newXi = RandomVariableFactory.getRandomVariableOrDefault(newRandomVariableFactory, dataModified.get("xi"), this.xi);
        RandomVariable newRho = RandomVariableFactory.getRandomVariableOrDefault(newRandomVariableFactory, dataModified.get("rho"), this.rho);

        // Keep the curves unless an explicit rate replaces them. Otherwise a curve-based model
        // would be cloned with null rates and fail in getDrift/getNumeraire.
        DiscountCurve newDiscountCurveForForwardRate = riskFreeRateModified != null ? null : this.discountCurveForForwardRate;
        DiscountCurve newDiscountCurveForDiscountRate = discountRateModified != null ? null : this.discountCurveForDiscountRate;

        return new HestonModel(newInitialValue, newDiscountCurveForForwardRate, newRiskFreeRate, newVolatility, newDiscountCurveForDiscountRate, newDiscountRate, newTheta, newKappa, newXi, newRho, this.scheme, newRandomVariableFactory);
    }

    @Override
    public String toString() {
        return "HestonModel [initialValue=" + this.initialValue + ", riskFreeRate=" + this.riskFreeRate + ", discountCurveForForwardRate=" + this.discountCurveForForwardRate + ", volatility=" + this.volatility + ", discountRate=" + this.discountRate + ", discountCurveForDiscountRate=" + this.discountCurveForDiscountRate + ", theta=" + this.theta + ", kappa=" + this.kappa + ", xi=" + this.xi + ", rho=" + this.rho + ", scheme=" + this.scheme + "]";
    }

    /** @return the initial asset value S0. */
    public RandomVariable getInitialValue() {
        return this.initialValue;
    }

    /** @return the constant risk-free rate, or null if the model uses a forward-rate curve. */
    public RandomVariable getRiskFreeRate() {
        return this.riskFreeRate;
    }

    /** @return the initial volatility sigma0. */
    public RandomVariable getVolatility() {
        return this.volatility;
    }

    /** @return the constant discount rate, or null if the model uses a discount curve. */
    public RandomVariable getDiscountRate() {
        return this.discountRate;
    }

    /** @return the forward-rate curve, or null if the model uses a constant risk-free rate. */
    public DiscountCurve getDiscountCurveForForwardRate() {
        return this.discountCurveForForwardRate;
    }

    /** @return the discount curve, or null if the model uses a constant discount rate. */
    public DiscountCurve getDiscountCurveForDiscountRate() {
        return this.discountCurveForDiscountRate;
    }

    /** @return the long-run variance level theta. */
    public RandomVariable getTheta() {
        return this.theta;
    }

    /** @return the mean-reversion speed kappa. */
    public RandomVariable getKappa() {
        return this.kappa;
    }

    /** @return the volatility of the variance xi. */
    public RandomVariable getXi() {
        return this.xi;
    }

    /** @return the correlation rho between the asset and variance drivers. */
    public RandomVariable getRho() {
        return this.rho;
    }

    /** @return the scheme used to keep the simulated variance non-negative. */
    public Scheme getScheme() {
        return this.scheme;
    }

    /**
     * Treatment of negative simulated variance: {@link #REFLECTION} uses |V|,
     * {@link #FULL_TRUNCATION} uses max(V, 0) in the drift and diffusion.
     */
    public static enum Scheme {
        REFLECTION,
        FULL_TRUNCATION;

    }
}

