/*
 * Decompiled with CFR 0.152.
 */
package eu.verdelhan.ta4j.analysis.criteria;

import eu.verdelhan.ta4j.Order;
import eu.verdelhan.ta4j.TimeSeries;
import eu.verdelhan.ta4j.Trade;
import eu.verdelhan.ta4j.TradingRecord;
import eu.verdelhan.ta4j.analysis.criteria.AbstractAnalysisCriterion;
import eu.verdelhan.ta4j.analysis.criteria.TotalProfitCriterion;

/**
 * Analysis criterion that estimates the transaction costs incurred by a trade or by a whole
 * trading record, using a linear cost model per order.
 * <p>
 * Every non-null order costs {@code a * tradedAmount + b}, where {@code tradedAmount} is the
 * amount being traded at that point. The traded amount starts at {@code initialAmount} and is
 * carried from one trade to the next by deducting that trade's cost and multiplying by the
 * value returned by {@link TotalProfitCriterion} for the trade.
 * <p>
 * A lower value is considered better (see {@link #betterThan(double, double)}).
 */
public class LinearTransactionCostCriterion extends AbstractAnalysisCriterion {

    /** Amount traded by the first order. */
    private final double initialAmount;

    /** Proportional coefficient applied to the traded amount of each order. */
    private final double a;

    /** Fixed amount added to the cost of each order. */
    private final double b;

    /** Criterion whose per-trade value scales the traded amount across a trade. */
    private final TotalProfitCriterion profit;

    /**
     * Creates a criterion with a purely proportional cost model (fixed cost {@code b = 0}).
     *
     * @param initialAmount the amount traded by the first order
     * @param a the proportional coefficient applied to each order's traded amount
     */
    public LinearTransactionCostCriterion(double initialAmount, double a) {
        this(initialAmount, a, 0.0);
    }

    /**
     * Creates a criterion whose per-order cost is {@code a * tradedAmount + b}.
     *
     * @param initialAmount the amount traded by the first order
     * @param a the proportional coefficient applied to each order's traded amount
     * @param b the fixed amount added to each order's cost
     */
    public LinearTransactionCostCriterion(double initialAmount, double a, double b) {
        this.initialAmount = initialAmount;
        this.a = a;
        this.b = b;
        this.profit = new TotalProfitCriterion();
    }

    /**
     * Computes the cost of a single trade, starting from {@code initialAmount}.
     *
     * @param series the time series the trade was made on
     * @param trade the trade to cost; {@code null} or a trade without an entry costs 0
     * @return the summed cost of the trade's entry and (if present) exit orders
     */
    @Override
    public double calculate(TimeSeries series, Trade trade) {
        return getTradeCost(series, trade, initialAmount);
    }

    /**
     * Computes the total cost of all trades in a record, plus the entry order of the current
     * trade if it is opened.
     *
     * @param series the time series the trades were made on
     * @param tradingRecord the record whose trades are costed
     * @return the accumulated transaction cost
     */
    @Override
    public double calculate(TimeSeries series, TradingRecord tradingRecord) {
        double totalCosts = 0.0;
        double tradedAmount = initialAmount;
        for (Trade trade : tradingRecord.getTrades()) {
            double tradeCost = getTradeCost(series, trade, tradedAmount);
            totalCosts += tradeCost;
            // NOTE(review): the whole trade cost (entry + exit) is deducted before applying the
            // profit ratio, while getTradeCost deducts only the entry cost before sizing the
            // exit order. This is the existing, test-pinned behavior; confirm before changing.
            tradedAmount = (tradedAmount - tradeCost) * profit.calculate(series, trade);
        }
        // For an opened current trade, only its entry order is costed.
        Trade currentTrade = tradingRecord.getCurrentTrade();
        if (currentTrade != null && currentTrade.isOpened()) {
            totalCosts += getOrderCost(currentTrade.getEntry(), tradedAmount);
        }
        return totalCosts;
    }

    /**
     * Lower transaction costs are better.
     *
     * @return {@code true} if {@code criterionValue1} is strictly less than {@code criterionValue2}
     */
    @Override
    public boolean betterThan(double criterionValue1, double criterionValue2) {
        return criterionValue1 < criterionValue2;
    }

    /**
     * Computes the cost of one order under the linear model.
     *
     * @param order the order; only its presence is checked, its contents are not read
     * @param tradedAmount the amount traded by this order
     * @return {@code a * tradedAmount + b}, or 0 if {@code order} is {@code null}
     */
    private double getOrderCost(Order order, double tradedAmount) {
        return order == null ? 0.0 : a * tradedAmount + b;
    }

    /**
     * Computes the cost of a trade's entry order and, if present, its exit order.
     * The exit order is sized as (amount minus entry cost) times the trade's profit value.
     *
     * @param series the time series the trade was made on
     * @param trade the trade to cost
     * @param tradedAmount the amount traded by the entry order
     * @return the trade cost, or 0 if {@code trade} or its entry is {@code null}
     */
    private double getTradeCost(TimeSeries series, Trade trade, double tradedAmount) {
        if (trade == null || trade.getEntry() == null) {
            return 0.0;
        }
        double entryCost = getOrderCost(trade.getEntry(), tradedAmount);
        if (trade.getExit() == null) {
            return entryCost;
        }
        double exitAmount = (tradedAmount - entryCost) * profit.calculate(series, trade);
        return entryCost + getOrderCost(trade.getExit(), exitAmount);
    }
}

