/*
 * Decompiled with CFR 0.152.
 */
package smile.data.formula;

import java.io.Serializable;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.HashSet;
import java.util.List;
import java.util.Optional;
import java.util.Set;
import java.util.regex.Matcher;
import java.util.regex.Pattern;
import java.util.stream.Collectors;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
import smile.data.AbstractTuple;
import smile.data.CategoricalEncoder;
import smile.data.DataFrame;
import smile.data.Tuple;
import smile.data.formula.Delete;
import smile.data.formula.Dot;
import smile.data.formula.FactorCrossing;
import smile.data.formula.Feature;
import smile.data.formula.Intercept;
import smile.data.formula.Term;
import smile.data.formula.Terms;
import smile.data.formula.Variable;
import smile.data.type.StructField;
import smile.data.type.StructType;
import smile.data.vector.ValueVector;
import smile.math.matrix.Matrix;

/**
 * A model formula in the symbolic form {@code response ~ predictors}: an optional
 * response term and a sequence of predictor terms.
 *
 * <p>A formula is bound lazily to an input schema (see {@link #bind(StructType)}). Binding
 * expands the predictors ({@code .} becomes every column not otherwise mentioned, deleted
 * terms are removed) and resolves them to concrete {@link Feature}s. The most recent binding
 * is cached in a {@link ThreadLocal}; the cache is {@code transient} and is released by
 * {@link #close()}.
 */
public class Formula implements AutoCloseable, Serializable {
    private static final long serialVersionUID = 2L;
    private static final Logger logger = LoggerFactory.getLogger(Formula.class);
    /** Matches the end of a parenthesized term: a ')' optionally followed by digits. */
    private static final Pattern GROUP_END = Pattern.compile("\\)\\d*");

    /** The response term, or {@code null} for a right-hand-side-only formula. */
    private final Term response;
    /** The predictor terms in declaration order. */
    private final Term[] predictors;
    /** Cached schema binding; rebuilt when a different schema instance is bound. */
    private transient ThreadLocal<Binding> binding;

    /**
     * Constructor.
     *
     * @param response the response term, or {@code null} if the formula has no response.
     * @param predictors the predictor terms.
     * @throws IllegalArgumentException if the response is {@code .} or a factor crossing.
     */
    public Formula(Term response, Term... predictors) {
        if (response instanceof Dot || response instanceof FactorCrossing) {
            throw new IllegalArgumentException("The response variable cannot be '.' or FactorCrossing.");
        }
        this.response = response;
        this.predictors = predictors;
    }

    /**
     * Returns the predictor terms.
     *
     * @return the predictor terms (the internal array, not a copy).
     */
    public Term[] predictors() {
        return this.predictors;
    }

    /**
     * Returns the response term.
     *
     * @return the response term, or {@code null} if the formula has none.
     */
    public Term response() {
        return this.response;
    }

    /** Releases the cached schema binding of the calling thread and drops the cache. */
    @Override
    public void close() {
        ThreadLocal<Binding> cache = this.binding;
        if (cache != null) {
            cache.remove();
            this.binding = null;
        }
    }

    /**
     * Returns the formula in the form {@code response ~ term1 + term2 - term3}.
     * A missing response is rendered as an empty string.
     */
    @Override
    public String toString() {
        String lhs = this.response == null ? "" : this.response.toString();
        String rhs = Arrays.stream(this.predictors)
                .map(Term::toString)
                // Deleted terms already render with a leading "- ".
                .map(term -> term.startsWith("- ") ? term : "+ " + term)
                .collect(Collectors.joining(" "));
        if (rhs.startsWith("+ ")) {
            rhs = rhs.substring(2);
        }
        return String.format("%s ~ %s", lhs, rhs);
    }

    /**
     * Two formulas are equal if their responses and predictors have equal string forms,
     * position by position.
     */
    @Override
    public boolean equals(Object o) {
        if (this == o) {
            return true;
        }
        if (!(o instanceof Formula f)) {
            return false;
        }
        if (this.predictors.length != f.predictors.length) {
            return false;
        }
        if (!String.valueOf(this.response).equals(String.valueOf(f.response))) {
            return false;
        }
        for (int i = 0; i < this.predictors.length; ++i) {
            if (!String.valueOf(this.predictors[i]).equals(String.valueOf(f.predictors[i]))) {
                return false;
            }
        }
        return true;
    }

    /** Hash code consistent with {@link #equals(Object)}: built from the same string forms. */
    @Override
    public int hashCode() {
        int h = String.valueOf(this.response).hashCode();
        for (Term predictor : this.predictors) {
            h = 31 * h + String.valueOf(predictor).hashCode();
        }
        return h;
    }

    /**
     * Returns the formula {@code lhs ~ .}.
     *
     * @param lhs the response variable name.
     * @return the formula.
     */
    public static Formula lhs(String lhs) {
        return Formula.lhs(new Variable(lhs));
    }

    /**
     * Returns the formula {@code lhs ~ .}.
     *
     * @param lhs the response term.
     * @return the formula.
     */
    public static Formula lhs(Term lhs) {
        return new Formula(lhs, new Dot());
    }

    /**
     * Returns a formula without a response.
     *
     * @param predictors the predictor names; see {@link #of(String, String...)} for the
     *                   special names {@code "."}, {@code "1"} and {@code "0"}.
     * @return the formula.
     */
    public static Formula rhs(String... predictors) {
        // Build directly so the response is absent rather than a Variable with a null name.
        return new Formula(null, toTerms(predictors));
    }

    /**
     * Returns a formula without a response.
     *
     * @param predictors the predictor terms.
     * @return the formula.
     */
    public static Formula rhs(Term... predictors) {
        return new Formula(null, predictors);
    }

    /**
     * Returns a formula from variable names. Among the predictors, {@code "."} maps to
     * {@link Dot}, {@code "1"} to an intercept, {@code "0"} to no intercept, and anything
     * else to a {@link Variable}.
     *
     * @param response the response variable name, or {@code null} for no response.
     * @param predictors the predictor names.
     * @return the formula.
     */
    public static Formula of(String response, String... predictors) {
        return new Formula(response == null ? null : new Variable(response), toTerms(predictors));
    }

    /**
     * Returns a formula.
     *
     * @param response the response variable name, or {@code null} for no response.
     * @param predictors the predictor terms.
     * @return the formula.
     */
    public static Formula of(String response, Term... predictors) {
        return new Formula(response == null ? null : new Variable(response), predictors);
    }

    /**
     * Returns a formula.
     *
     * @param response the response term, or {@code null} for no response.
     * @param predictors the predictor terms.
     * @return the formula.
     */
    public static Formula of(Term response, Term... predictors) {
        return new Formula(response, predictors);
    }

    /** Converts predictor names to terms, one per name. */
    private static Term[] toTerms(String... predictors) {
        return Arrays.stream(predictors).map(Formula::toTerm).toArray(Term[]::new);
    }

    /** Maps one predictor name to a term; see {@link #of(String, String...)}. */
    private static Term toTerm(String predictor) {
        return switch (predictor) {
            case "." -> new Dot();
            case "1" -> new Intercept(true);
            case "0" -> new Intercept(false);
            default -> new Variable(predictor);
        };
    }

    /**
     * Parses a formula string such as {@code "y ~ x1 + x2 - x3"}.
     *
     * <p>The string must contain exactly one {@code '~'}. An empty left side means no
     * response; an empty right side is allowed only with a response and yields
     * {@code response ~ .}. Right-side terms are separated by spaces and introduced by
     * {@code '+'} or {@code '-'} (a leading {@code '+'} may be omitted); {@code '-'} marks a
     * deleted term. A term starting with {@code '('} extends to the first {@code ')'} plus any
     * digits directly following it. Each term's text is parsed by {@code Terms.$}.
     *
     * @param s the formula string.
     * @return the formula.
     * @throws IllegalArgumentException if the string is not a valid formula.
     */
    public static Formula of(String s) {
        // Limit -1 keeps trailing empty tokens, so "y ~" parses like "y ~ " and "y ~ x ~" is rejected.
        String[] tokens = s.split("~", -1);
        if (tokens.length != 2) {
            throw new IllegalArgumentException("Invalid formula: " + s);
        }
        String lhs = tokens[0].trim();
        Term response = lhs.isEmpty() ? null : Terms.$(lhs);
        String rhs = tokens[1].trim();
        if (rhs.isEmpty()) {
            if (response != null) {
                return Formula.lhs(response);
            }
            throw new IllegalArgumentException("Invalid formula: " + s);
        }
        if (!rhs.startsWith("+") && !rhs.startsWith("-")) {
            rhs = "+ " + rhs;
        }

        List<Term> terms = new ArrayList<>();
        while (!rhs.isEmpty()) {
            boolean delete;
            if (rhs.startsWith("+")) {
                delete = false;
            } else if (rhs.startsWith("-")) {
                delete = true;
            } else {
                throw new IllegalArgumentException("Invalid formula: " + s);
            }
            rhs = rhs.substring(1).trim();

            int end;
            if (rhs.startsWith("(")) {
                Matcher matcher = GROUP_END.matcher(rhs);
                if (!matcher.find()) {
                    throw new IllegalArgumentException("Invalid formula: " + s);
                }
                end = matcher.end();
            } else {
                // Search from index 1 so a term always has at least one character.
                end = rhs.indexOf(' ', 1);
                if (end <= 0) {
                    end = rhs.length();
                }
            }
            String item = rhs.substring(0, end);
            rhs = rhs.substring(end).trim();

            Term term = Terms.$(item);
            terms.add(delete ? Terms.delete(term) : term);
        }
        return new Formula(response, terms.toArray(new Term[0]));
    }

    /**
     * Expands the formula against a schema. {@code .} becomes one variable per column not
     * referenced by the response or by Variable/FactorCrossing predictors. Other terms are
     * replaced by their own expansion. Deleted terms and their expansions are removed.
     *
     * @param inputSchema the schema of the input data.
     * @return the expanded formula with the same response.
     */
    public Formula expand(StructType inputSchema) {
        Set<String> columns = new HashSet<>();
        if (this.response != null) {
            columns.addAll(this.response.variables());
        }
        for (Term predictor : this.predictors) {
            if (predictor instanceof FactorCrossing || predictor instanceof Variable) {
                columns.addAll(predictor.variables());
            }
        }

        List<Variable> rest = inputSchema.fields().stream()
                .filter(field -> !columns.contains(field.name()))
                .map(field -> new Variable(field.name()))
                .toList();

        List<Term> expanded = new ArrayList<>();
        Set<String> deletes = new HashSet<>();
        for (Term predictor : this.predictors) {
            if (predictor instanceof Dot) {
                expanded.addAll(rest);
            } else if (predictor instanceof Delete) {
                // A deleted term renders as "- name"; strip the prefix to match expanded terms.
                for (Term term : predictor.expand()) {
                    deletes.add(term.toString().substring(2));
                }
            } else {
                expanded.addAll(predictor.expand());
            }
        }
        expanded.removeIf(term -> deletes.contains(term.toString()));
        return new Formula(this.response, expanded.toArray(new Term[0]));
    }

    /**
     * Binds the formula to a schema and returns the schema of the predictor features.
     * The binding is cached until a different schema instance is bound.
     *
     * @param inputSchema the schema of the input data.
     * @return the schema of the predictor features.
     */
    public StructType bind(StructType inputSchema) {
        return bindingOf(inputSchema).xschema;
    }

    /**
     * Returns the binding for the given schema, reusing the cached one when it was built for
     * the same schema instance. Callers should use the returned object instead of re-reading
     * the cache, which another thread may replace concurrently.
     */
    private Binding bindingOf(StructType inputSchema) {
        ThreadLocal<Binding> cache = this.binding;
        if (cache != null) {
            Binding cached = cache.get();
            if (cached.inputSchema == inputSchema) {
                return cached;
            }
        }

        Formula formula = this.expand(inputSchema);
        Binding bound = new Binding();
        bound.inputSchema = inputSchema;
        List<Feature> features = Arrays.stream(formula.predictors)
                .filter(predictor -> !(predictor instanceof Delete) && !(predictor instanceof Intercept))
                .flatMap(predictor -> predictor.bind(inputSchema).stream())
                .collect(Collectors.toCollection(ArrayList::new));
        bound.x = features.toArray(new Feature[0]);
        bound.xschema = schemaOf(features);

        if (this.response != null) {
            try {
                // The response features go first, so yx[0] is the response.
                features.addAll(0, this.response.bind(inputSchema));
                bound.yx = features.toArray(new Feature[0]);
                bound.yxschema = schemaOf(features);
            } catch (RuntimeException ex) {
                // Best effort: data without the response column (e.g. at prediction time) is valid.
                logger.debug("The response variable {} doesn't exist in the schema {}", this.response, inputSchema, ex);
            }
        }

        ThreadLocal<Binding> old = this.binding;
        if (old != null) {
            old.remove();
        }
        this.binding = ThreadLocal.withInitial(() -> bound);
        return bound;
    }

    /** Builds a struct type from the fields of the given features, in order. */
    private static StructType schemaOf(List<Feature> features) {
        return new StructType(features.stream().map(Feature::field).toArray(StructField[]::new));
    }

    /**
     * Applies the formula to a tuple. The result holds the response and predictor features.
     * If the tuple's schema has no response (or the formula has none), it holds the
     * predictor features only, as in {@link #frame(DataFrame)}.
     *
     * @param tuple the input tuple.
     * @return a tuple view whose fields are evaluated from the input on each access.
     */
    public Tuple apply(final Tuple tuple) {
        Binding bound = bindingOf(tuple.schema());
        boolean hasResponse = bound.yx != null;
        return featureTuple(tuple, hasResponse ? bound.yx : bound.x, hasResponse ? bound.yxschema : bound.xschema);
    }

    /**
     * Returns the predictor features of a tuple.
     *
     * @param tuple the input tuple.
     * @return a tuple view whose fields are evaluated from the input on each access.
     */
    public Tuple x(final Tuple tuple) {
        Binding bound = bindingOf(tuple.schema());
        return featureTuple(tuple, bound.x, bound.xschema);
    }

    /** Returns a tuple view that evaluates {@code features[i]} on {@code tuple} for field i. */
    private static Tuple featureTuple(final Tuple tuple, final Feature[] features, final StructType schema) {
        return new AbstractTuple(schema) {

            @Override
            public Object get(int i) {
                return features[i].apply(tuple);
            }

            @Override
            public int getInt(int i) {
                return features[i].applyAsInt(tuple);
            }

            @Override
            public long getLong(int i) {
                return features[i].applyAsLong(tuple);
            }

            @Override
            public float getFloat(int i) {
                return features[i].applyAsFloat(tuple);
            }

            @Override
            public double getDouble(int i) {
                return features[i].applyAsDouble(tuple);
            }

            @Override
            public String toString() {
                return schema.toString(this);
            }
        };
    }

    /**
     * Applies the formula to a data frame. The result has the response and predictor columns,
     * or the predictor columns only when no response is bound.
     *
     * @param data the input data frame.
     * @return the output data frame.
     */
    public DataFrame frame(DataFrame data) {
        Binding bound = bindingOf(data.schema());
        return toDataFrame(data, bound.yx != null ? bound.yx : bound.x);
    }

    /**
     * Returns the predictor columns of a data frame.
     *
     * @param data the input data frame.
     * @return the data frame of predictor features.
     */
    public DataFrame x(DataFrame data) {
        return toDataFrame(data, bindingOf(data.schema()).x);
    }

    /** Evaluates each feature on the data frame and assembles the resulting columns. */
    private static DataFrame toDataFrame(DataFrame data, Feature[] features) {
        ValueVector[] vectors = Arrays.stream(features)
                .map(feature -> feature.apply(data))
                .toArray(ValueVector[]::new);
        return new DataFrame(vectors);
    }

    /** Returns the bias flag of an Intercept predictor if present, otherwise {@code true}. */
    private boolean hasBias() {
        return Arrays.stream(this.predictors)
                .filter(term -> term instanceof Intercept)
                .map(term -> (Intercept) term)
                .findAny()
                .map(Intercept::bias)
                .orElse(true);
    }

    /**
     * Returns the design matrix of the predictors, with a bias column unless the formula
     * contains an intercept term that disables it.
     *
     * @param data the input data frame.
     * @return the design matrix.
     */
    public Matrix matrix(DataFrame data) {
        return this.matrix(data, this.hasBias());
    }

    /**
     * Returns the design matrix of the predictors, encoding categorical columns with
     * {@link CategoricalEncoder#DUMMY}.
     *
     * @param data the input data frame.
     * @param bias whether to include a bias column.
     * @return the design matrix.
     */
    public Matrix matrix(DataFrame data, boolean bias) {
        return this.x(data).toMatrix(bias, CategoricalEncoder.DUMMY, null);
    }

    /**
     * Returns the binding for a schema, requiring that the formula has a response and the
     * schema provides it.
     */
    private Binding responseBinding(StructType schema) {
        if (this.response == null) {
            throw new UnsupportedOperationException("The formula has no response variable.");
        }
        Binding bound = bindingOf(schema);
        if (bound.yx == null) {
            throw new UnsupportedOperationException("The data has no response variable.");
        }
        return bound;
    }

    /**
     * Returns the response column of a data frame.
     *
     * @param data the input data frame.
     * @return the response vector.
     * @throws UnsupportedOperationException if the formula or the data has no response.
     */
    public ValueVector y(DataFrame data) {
        return responseBinding(data.schema()).yx[0].apply(data);
    }

    /**
     * Returns the response value of a tuple as a double.
     *
     * @param tuple the input tuple.
     * @return the response value.
     * @throws UnsupportedOperationException if the formula or the tuple has no response.
     */
    public double y(Tuple tuple) {
        return responseBinding(tuple.schema()).yx[0].applyAsDouble(tuple);
    }

    /**
     * Returns the response value of a tuple as an int.
     *
     * @param tuple the input tuple.
     * @return the response value.
     * @throws UnsupportedOperationException if the formula or the tuple has no response.
     */
    public int yint(Tuple tuple) {
        return responseBinding(tuple.schema()).yx[0].applyAsInt(tuple);
    }

    /** The result of binding the formula to one schema instance. */
    private static class Binding {
        /** The schema this binding was built for; compared by identity. */
        StructType inputSchema;
        /** Schema of response + predictor features; null if the response is not bound. */
        StructType yxschema;
        /** Schema of the predictor features. */
        StructType xschema;
        /** Response features followed by predictor features; null if the response is not bound. */
        Feature[] yx;
        /** Predictor features. */
        Feature[] x;

        private Binding() {
        }
    }
}

