/*
 * Decompiled with CFR 0.152.
 */
package smile.math;

import java.io.Serializable;
import java.util.Arrays;
import java.util.Locale;
import java.util.regex.Matcher;
import java.util.regex.Pattern;

/**
 * A function of an integer time step {@code t}, typically a learning-rate
 * style schedule. The static factories build constant, piecewise,
 * polynomial/linear, inverse-time, exponential and cosine schedules, and
 * {@link #of(String)} parses a textual form that mirrors their
 * {@code toString()} output.
 */
@FunctionalInterface
public interface TimeFunction extends Serializable {
    /**
     * Returns the function value at the given time step.
     *
     * @param t the time step.
     * @return the function value at {@code t}.
     */
    public double apply(int t);

    /**
     * Returns a function that ignores the time step.
     *
     * @param alpha the value returned for every time step.
     * @return the constant function.
     */
    public static TimeFunction constant(final double alpha) {
        return new TimeFunction() {

            @Override
            public double apply(int t) {
                return alpha;
            }

            @Override
            public String toString() {
                // Locale.ROOT guarantees '.' as decimal separator so of() can parse it back.
                return String.format(Locale.ROOT, "%f", alpha);
            }
        };
    }

    /**
     * Returns a piecewise constant function. Each value is wrapped with
     * {@link #constant(double)}; see {@link #piecewise(int[], TimeFunction...)}
     * for how the time step selects a piece.
     *
     * @param milestones the time steps at which the value changes, in ascending order.
     * @param values the values; must have exactly one more element than {@code milestones}.
     * @return the piecewise function.
     * @throws IllegalArgumentException if the lengths do not match or the milestones are not sorted.
     */
    public static TimeFunction piecewise(int[] milestones, double[] values) {
        TimeFunction[] schedules = new TimeFunction[values.length];
        for (int i = 0; i < values.length; ++i) {
            schedules[i] = TimeFunction.constant(values[i]);
        }
        return TimeFunction.piecewise(milestones, schedules);
    }

    /**
     * Returns a piecewise function that delegates to one schedule per interval.
     * For {@code t <= milestones[0]} it uses {@code schedules[0]}; for
     * {@code milestones[i-1] < t <= milestones[i]} it uses {@code schedules[i]};
     * after the last milestone it uses the last schedule. The arrays are used
     * directly (not copied), so mutating them later changes this function.
     *
     * @param milestones the time steps at which the schedule changes, in ascending order.
     * @param schedules the schedules; must have exactly one more element than {@code milestones}.
     * @return the piecewise function.
     * @throws IllegalArgumentException if the lengths do not match or the milestones are not sorted.
     */
    public static TimeFunction piecewise(final int[] milestones, final TimeFunction... schedules) {
        if (schedules.length != milestones.length + 1) {
            throw new IllegalArgumentException("values should have one more element than milestones");
        }
        // Arrays.binarySearch is only defined on sorted input; fail fast instead of
        // silently picking an arbitrary piece.
        for (int i = 1; i < milestones.length; ++i) {
            if (milestones[i] < milestones[i - 1]) {
                throw new IllegalArgumentException(
                        "milestones should be sorted in ascending order: " + Arrays.toString(milestones));
            }
        }
        return new TimeFunction() {

            @Override
            public double apply(int t) {
                int i = Arrays.binarySearch(milestones, t);
                if (i < 0) {
                    // Not an exact milestone: the insertion point is the number of milestones < t.
                    i = -i - 1;
                }
                return schedules[i].apply(t);
            }

            @Override
            public String toString() {
                return String.format(Locale.ROOT, "Piecewise(%s, %s)",
                        Arrays.toString(milestones), Arrays.toString(schedules));
            }
        };
    }

    /**
     * Returns a linear decay from {@code initLearningRate} to
     * {@code endLearningRate} over {@code decaySteps}, i.e.
     * {@code polynomial(1.0, initLearningRate, decaySteps, endLearningRate, false)}.
     *
     * @param initLearningRate the value at {@code t = 0}.
     * @param decaySteps the number of steps over which the value decays.
     * @param endLearningRate the value from {@code t = decaySteps} on.
     * @return the linear decay function.
     */
    public static TimeFunction linear(double initLearningRate, double decaySteps, double endLearningRate) {
        return TimeFunction.polynomial(1.0, initLearningRate, decaySteps, endLearningRate, false);
    }

    /**
     * Returns a non-cycling polynomial decay; same as the five-argument
     * overload with {@code cycle = false}.
     *
     * @param degree the polynomial degree.
     * @param initLearningRate the value at {@code t = 0}.
     * @param decaySteps the number of steps over which the value decays.
     * @param endLearningRate the value from {@code t = decaySteps} on (for positive degree).
     * @return the polynomial decay function.
     */
    public static TimeFunction polynomial(double degree, double initLearningRate, double decaySteps, double endLearningRate) {
        return TimeFunction.polynomial(degree, initLearningRate, decaySteps, endLearningRate, false);
    }

    /**
     * Returns a polynomial decay
     * {@code (init - end) * (1 - s / T)^degree + end}.
     * Without cycling, {@code s = min(t, decaySteps)} and {@code T = decaySteps}.
     * With cycling, {@code s = t} and {@code T} is {@code decaySteps} times
     * {@code max(1, ceil(t / decaySteps))}, which restarts the decay at each new period.
     *
     * @param degree the polynomial degree.
     * @param initLearningRate the value at {@code t = 0}.
     * @param decaySteps the length of one decay period.
     * @param endLearningRate the value approached at the end of a period.
     * @param cycle whether to restart the decay after {@code decaySteps}.
     * @return the polynomial decay function.
     */
    public static TimeFunction polynomial(final double degree, final double initLearningRate, final double decaySteps, final double endLearningRate, final boolean cycle) {
        return new TimeFunction() {

            @Override
            public double apply(int t) {
                if (cycle) {
                    // Stretch the horizon to the next multiple of decaySteps (at least one period).
                    double horizon = decaySteps * Math.max(1.0, Math.ceil((double) t / decaySteps));
                    return (initLearningRate - endLearningRate) * Math.pow(1.0 - (double) t / horizon, degree) + endLearningRate;
                }
                // Clamp so that, for positive degree, the value stays at endLearningRate after decaySteps.
                double steps = Math.min((double) t, decaySteps);
                return (initLearningRate - endLearningRate) * Math.pow(1.0 - steps / decaySteps, degree) + endLearningRate;
            }

            @Override
            public String toString() {
                // The LinearDecay form has no cycle flag, so only the non-cycling
                // degree-1 case may use it; otherwise of() would drop cycle=true.
                if (degree == 1.0 && !cycle) {
                    return String.format(Locale.ROOT, "LinearDecay(%f, %.0f, %f)", initLearningRate, decaySteps, endLearningRate);
                }
                return String.format(Locale.ROOT, "PolynomialDecay(%f, %f, %.0f, %f, %s)", degree, initLearningRate, decaySteps, endLearningRate, cycle);
            }
        };
    }

    /**
     * Returns the inverse-time decay {@code init * decaySteps / (decaySteps + t)}.
     *
     * @param initLearningRate the value at {@code t = 0}.
     * @param decaySteps the step count at which the value has halved.
     * @return the inverse-time decay function.
     */
    public static TimeFunction inverse(final double initLearningRate, final double decaySteps) {
        return new TimeFunction() {

            @Override
            public double apply(int t) {
                return initLearningRate * decaySteps / (decaySteps + (double) t);
            }

            @Override
            public String toString() {
                return String.format(Locale.ROOT, "InverseTimeDecay(%f, %.0f)", initLearningRate, decaySteps);
            }
        };
    }

    /**
     * Returns the continuous inverse-time decay; same as the four-argument
     * overload with {@code staircase = false}.
     *
     * @param initLearningRate the value at {@code t = 0}.
     * @param decaySteps the step scale.
     * @param decayRate the decay rate.
     * @return the inverse-time decay function.
     */
    public static TimeFunction inverse(double initLearningRate, double decaySteps, double decayRate) {
        return TimeFunction.inverse(initLearningRate, decaySteps, decayRate, false);
    }

    /**
     * Returns the inverse-time decay {@code init / (1 + decayRate * p)}, where
     * {@code p = t / decaySteps}, or {@code floor(t / decaySteps)} when
     * {@code staircase} is set.
     *
     * @param initLearningRate the value at {@code t = 0}.
     * @param decaySteps the step scale.
     * @param decayRate the decay rate.
     * @param staircase whether to decay in discrete steps.
     * @return the inverse-time decay function.
     */
    public static TimeFunction inverse(final double initLearningRate, final double decaySteps, final double decayRate, final boolean staircase) {
        return new TimeFunction() {

            @Override
            public double apply(int t) {
                if (staircase) {
                    return initLearningRate / (1.0 + decayRate * Math.floor((double) t / decaySteps));
                }
                return initLearningRate / (1.0 + decayRate * (double) t / decaySteps);
            }

            @Override
            public String toString() {
                return String.format(Locale.ROOT, "InverseTimeDecay(%f, %.0f, %f, %s)", initLearningRate, decaySteps, decayRate, staircase);
            }
        };
    }

    /**
     * Returns the exponential decay {@code init * exp(-t / decaySteps)}.
     *
     * @param initLearningRate the value at {@code t = 0}.
     * @param decaySteps the step scale of the decay.
     * @return the exponential decay function.
     */
    public static TimeFunction exp(final double initLearningRate, final double decaySteps) {
        return new TimeFunction() {

            @Override
            public double apply(int t) {
                // Negate after widening so that t = Integer.MIN_VALUE cannot overflow.
                return initLearningRate * Math.exp(-(double) t / decaySteps);
            }

            @Override
            public String toString() {
                return String.format(Locale.ROOT, "ExponentialDecay(%f, %.0f)", initLearningRate, decaySteps);
            }
        };
    }

    /**
     * Returns an exponential decay from {@code initLearningRate} to
     * {@code endLearningRate}:
     * {@code init * (end / init)^(min(t, decaySteps) / decaySteps)}.
     * The value is held at {@code endLearningRate} after {@code decaySteps}.
     *
     * @param initLearningRate the value at {@code t = 0}.
     * @param decaySteps the number of steps over which the value decays.
     * @param endLearningRate the value from {@code t = decaySteps} on.
     * @return the exponential decay function.
     */
    public static TimeFunction exp(final double initLearningRate, final double decaySteps, final double endLearningRate) {
        // Base chosen so that the value reaches endLearningRate at t == decaySteps.
        final double decayRate = endLearningRate / initLearningRate;
        return new TimeFunction() {

            @Override
            public double apply(int t) {
                return initLearningRate * Math.pow(decayRate, Math.min((double) t, decaySteps) / decaySteps);
            }

            @Override
            public String toString() {
                return String.format(Locale.ROOT, "ExponentialDecay(%f, %.0f, %f)", initLearningRate, decaySteps, endLearningRate);
            }
        };
    }

    /**
     * Returns the exponential decay {@code init * decayRate^p}, where
     * {@code p = t / decaySteps}, or {@code floor(t / decaySteps)} when
     * {@code staircase} is set.
     *
     * @param initLearningRate the value at {@code t = 0}.
     * @param decaySteps the step scale.
     * @param decayRate the multiplier applied per {@code decaySteps}.
     * @param staircase whether to decay in discrete steps.
     * @return the exponential decay function.
     */
    public static TimeFunction exp(final double initLearningRate, final double decaySteps, final double decayRate, final boolean staircase) {
        return new TimeFunction() {

            @Override
            public double apply(int t) {
                if (staircase) {
                    return initLearningRate * Math.pow(decayRate, Math.floor((double) t / decaySteps));
                }
                return initLearningRate * Math.pow(decayRate, (double) t / decaySteps);
            }

            @Override
            public String toString() {
                return String.format(Locale.ROOT, "ExponentialDecay(%f, %.0f, %f, %s)", initLearningRate, decaySteps, decayRate, staircase);
            }
        };
    }

    /**
     * Returns the cosine schedule
     * {@code min + 0.5 * (max - min) * (1 + cos(pi * t / decaySteps))}.
     * {@code t} is not clamped, so after {@code decaySteps} the value rises
     * again (the period is {@code 2 * decaySteps}).
     *
     * @param minLearningRate the value at {@code t = decaySteps}.
     * @param decaySteps the number of steps from maximum to minimum.
     * @param maxLearningRate the value at {@code t = 0}.
     * @return the cosine function.
     */
    public static TimeFunction cosine(final double minLearningRate, final double decaySteps, final double maxLearningRate) {
        return new TimeFunction() {

            @Override
            public double apply(int t) {
                return minLearningRate + 0.5 * (maxLearningRate - minLearningRate) * (1.0 + Math.cos((double) t / decaySteps * Math.PI));
            }

            @Override
            public String toString() {
                return String.format(Locale.ROOT, "CosineDecay(%f, %.0f, %f)", minLearningRate, decaySteps, maxLearningRate);
            }
        };
    }

    /**
     * Parses a time function. The input is trimmed and lower-cased with
     * {@link Locale#ROOT}, so names are case-insensitive. Accepted forms:
     * <ul>
     * <li>{@code linear[decay](init, steps, end)}</li>
     * <li>{@code polynomial[decay](degree, init, steps, end[, cycle])}</li>
     * <li>{@code piecewise([m1, m2, ...], [v0, v1, ...])}</li>
     * <li>{@code inverse[timedecay](init, steps[, rate][, staircase])}</li>
     * <li>{@code exp[onentialdecay](init, steps[, rate][, staircase])}</li>
     * <li>{@code cosine[decay](min, steps, max)}</li>
     * <li>a plain number, giving a constant function</li>
     * </ul>
     * For inverse and exp, a staircase flag without a rate is ignored.
     *
     * @param time the textual representation.
     * @return the parsed time function.
     * @throws IllegalArgumentException if the text matches none of the forms.
     */
    public static TimeFunction of(String time) {
        // Numeric argument syntax: optional sign, optional fraction, optional exponent.
        final String num = "[-+]?[0-9]*\\.?[0-9]+(?:[eE][-+]?[0-9]+)?";
        final String bool = "(true|false)";
        time = time.trim().toLowerCase(Locale.ROOT);

        Matcher m = Pattern.compile(String.format("linear(?:decay)?\\((%s),\\s*(%s),\\s*(%s)\\)", num, num, num)).matcher(time);
        if (m.matches()) {
            double initLearningRate = Double.parseDouble(m.group(1));
            double decaySteps = Double.parseDouble(m.group(2));
            double endLearningRate = Double.parseDouble(m.group(3));
            return TimeFunction.linear(initLearningRate, decaySteps, endLearningRate);
        }

        m = Pattern.compile(String.format("polynomial(?:decay)?\\((%s),\\s*(%s),\\s*(%s),\\s*(%s)(,\\s*(%s))?\\)", num, num, num, num, bool)).matcher(time);
        if (m.matches()) {
            double degree = Double.parseDouble(m.group(1));
            double initLearningRate = Double.parseDouble(m.group(2));
            double decaySteps = Double.parseDouble(m.group(3));
            double endLearningRate = Double.parseDouble(m.group(4));
            // group(5) is the optional ", flag" suffix; group(6) is the flag itself.
            boolean cycle = m.group(5) != null && m.group(6).equals("true");
            return TimeFunction.polynomial(degree, initLearningRate, decaySteps, endLearningRate, cycle);
        }

        final String piecewisePrefix = "piecewise([";
        if (time.startsWith(piecewisePrefix) && time.endsWith("])")) {
            String[] tokens = time.substring(piecewisePrefix.length(), time.length() - 2).split("\\],\\s*\\[");
            if (tokens.length == 2) {
                // Split on ',' and trim so whitespace on either side of a comma is accepted.
                int[] milestones = Arrays.stream(tokens[0].split(",")).map(String::trim).mapToInt(Integer::parseInt).toArray();
                double[] values = Arrays.stream(tokens[1].split(",")).map(String::trim).mapToDouble(Double::parseDouble).toArray();
                return TimeFunction.piecewise(milestones, values);
            }
        }

        m = Pattern.compile(String.format("inverse(?:timedecay)?\\((%s),\\s*(%s)(,\\s*(%s))?(,\\s*(%s))?\\)", num, num, num, bool)).matcher(time);
        if (m.matches()) {
            double initLearningRate = Double.parseDouble(m.group(1));
            double decaySteps = Double.parseDouble(m.group(2));
            if (m.group(3) == null) {
                return TimeFunction.inverse(initLearningRate, decaySteps);
            }
            double decayRate = Double.parseDouble(m.group(4));
            boolean staircase = m.group(5) != null && m.group(6).equals("true");
            return TimeFunction.inverse(initLearningRate, decaySteps, decayRate, staircase);
        }

        m = Pattern.compile(String.format("exp(?:onentialdecay)?\\((%s),\\s*(%s)(,\\s*(%s))?(,\\s*(%s))?\\)", num, num, num, bool)).matcher(time);
        if (m.matches()) {
            double initLearningRate = Double.parseDouble(m.group(1));
            double decaySteps = Double.parseDouble(m.group(2));
            if (m.group(3) == null) {
                return TimeFunction.exp(initLearningRate, decaySteps);
            }
            // NOTE(review): a third number is read as a decay rate (four-argument overload), although
            // exp(init, steps, endLearningRate).toString() produces the same three-number shape,
            // so that form does not round-trip. Kept as-is for backward compatibility.
            double decayRate = Double.parseDouble(m.group(4));
            boolean staircase = m.group(5) != null && m.group(6).equals("true");
            return TimeFunction.exp(initLearningRate, decaySteps, decayRate, staircase);
        }

        m = Pattern.compile(String.format("cosine(?:decay)?\\((%s),\\s*(%s),\\s*(%s)\\)", num, num, num)).matcher(time);
        if (m.matches()) {
            double minLearningRate = Double.parseDouble(m.group(1));
            double decaySteps = Double.parseDouble(m.group(2));
            double maxLearningRate = Double.parseDouble(m.group(3));
            return TimeFunction.cosine(minLearningRate, decaySteps, maxLearningRate);
        }

        try {
            return TimeFunction.constant(Double.parseDouble(time));
        } catch (NumberFormatException ex) {
            // Keep the parse failure as the cause for easier diagnosis.
            throw new IllegalArgumentException("Unknown time function: " + time, ex);
        }
    }
}

