/*
 * Decompiled with CFR 0.152.
 */
package org.jpmml.sparkml;

import com.google.common.collect.Iterables;
import java.io.IOException;
import java.io.InputStream;
import java.io.OutputStream;
import java.lang.reflect.Constructor;
import java.net.URL;
import java.util.ArrayList;
import java.util.Collections;
import java.util.Enumeration;
import java.util.LinkedHashMap;
import java.util.List;
import java.util.Map;
import java.util.Properties;
import java.util.Set;
import javax.xml.bind.JAXBException;
import org.apache.commons.io.output.ByteArrayOutputStream;
import org.apache.log4j.LogManager;
import org.apache.log4j.Logger;
import org.apache.spark.ml.PipelineModel;
import org.apache.spark.ml.Transformer;
import org.apache.spark.sql.types.StructType;
import org.dmg.pmml.MiningField;
import org.dmg.pmml.MiningSchema;
import org.dmg.pmml.Model;
import org.dmg.pmml.PMML;
import org.dmg.pmml.mining.MiningModel;
import org.jpmml.converter.Schema;
import org.jpmml.converter.mining.MiningModelUtil;
import org.jpmml.model.MetroJAXBUtil;
import org.jpmml.sparkml.FeatureConverter;
import org.jpmml.sparkml.ModelConverter;
import org.jpmml.sparkml.SparkMLEncoder;
import org.jpmml.sparkml.TransformerConverter;

/**
 * Static entry points for converting a fitted Spark ML {@link PipelineModel} into a {@link PMML} document.
 *
 * <p>
 * The class keeps a registry that maps {@link Transformer} classes to {@link TransformerConverter} classes.
 * The registry is populated once, during class initialization, from every
 * {@code META-INF/sparkml2pmml.properties} resource visible to the class loader, and can be extended
 * afterwards via {@link #putConverterClazz(Class, Class)}.
 * </p>
 *
 * <p>
 * NOTE(review): the registry is a plain {@link LinkedHashMap} without any synchronization; concurrent
 * registration from several threads is not guarded against here.
 * </p>
 */
public class ConverterUtil {
    /** Classpath location of the property files that map transformer class names to converter class names. */
    private static final String CONVERTER_MAPPINGS_RESOURCE = "META-INF/sparkml2pmml.properties";

    /** Registry of converter classes, keyed by the exact transformer class they handle. */
    private static final Map<Class<? extends Transformer>, Class<? extends TransformerConverter>> converters = new LinkedHashMap<>();
    private static final Logger logger = LogManager.getLogger(ConverterUtil.class);

    private ConverterUtil() {
    }

    /**
     * Converts a fitted pipeline to a PMML object.
     *
     * <p>
     * The pipeline is flattened (nested pipeline models are expanded in place), and every stage is handed to
     * its registered converter. Feature converters register features with the encoder; model converters
     * produce models. A single model becomes the root model as-is; several models are combined into a
     * model chain.
     * </p>
     *
     * @param schema The schema handed to the {@link SparkMLEncoder}.
     * @param pipelineModel The fitted pipeline.
     * @return The encoded PMML object.
     * @throws IllegalArgumentException If a stage is not supported, if its converter is neither a
     * {@link FeatureConverter} nor a {@link ModelConverter}, or if the pipeline contains no models.
     */
    public static PMML toPMML(StructType schema, PipelineModel pipelineModel) {
        SparkMLEncoder encoder = new SparkMLEncoder(schema);
        List<Model> models = new ArrayList<>();
        for (Transformer transformer : getTransformers(pipelineModel)) {
            TransformerConverter<Transformer> converter = createConverter(transformer);
            if (converter instanceof FeatureConverter) {
                FeatureConverter featureConverter = (FeatureConverter) converter;
                featureConverter.registerFeatures(encoder);
            } else if (converter instanceof ModelConverter) {
                ModelConverter modelConverter = (ModelConverter) converter;
                Model model = modelConverter.registerModel(encoder);
                models.add(model);
            } else {
                throw new IllegalArgumentException("Expected a " + FeatureConverter.class.getName() + " or " + ModelConverter.class.getName() + " instance, got " + converter);
            }
        }
        return encoder.encodePMML(createRootModel(models));
    }

    /**
     * Converts a fitted pipeline to a PMML document and marshals it to a byte array.
     *
     * @param schema The schema handed to the {@link SparkMLEncoder}.
     * @param pipelineModel The fitted pipeline.
     * @return The marshalled PMML document.
     * @throws RuntimeException If marshalling fails; the {@link JAXBException} is attached as the cause.
     */
    public static byte[] toPMMLByteArray(StructType schema, PipelineModel pipelineModel) {
        PMML pmml = toPMML(schema, pipelineModel);
        // Start with a 1 MB buffer
        ByteArrayOutputStream os = new ByteArrayOutputStream(0x100000);
        try {
            MetroJAXBUtil.marshalPMML(pmml, os);
        } catch (JAXBException je) {
            throw new RuntimeException(je);
        }
        return os.toByteArray();
    }

    /**
     * Creates the converter for the given transformer and casts it to a feature converter.
     *
     * @throws ClassCastException If the registered converter is not a {@link FeatureConverter}.
     * @throws IllegalArgumentException See {@link #createConverter(Transformer)}.
     */
    public static FeatureConverter<?> createFeatureConverter(Transformer transformer) {
        return (FeatureConverter) createConverter(transformer);
    }

    /**
     * Creates the converter for the given transformer and casts it to a model converter.
     *
     * @throws ClassCastException If the registered converter is not a {@link ModelConverter}.
     * @throws IllegalArgumentException See {@link #createConverter(Transformer)}.
     */
    public static ModelConverter<?> createModelConverter(Transformer transformer) {
        return (ModelConverter) createConverter(transformer);
    }

    /**
     * Instantiates the converter that is registered for the runtime class of the given transformer.
     *
     * <p>
     * The registry is consulted with the exact runtime class, and the converter class must declare a
     * constructor whose single parameter is of that same class.
     * </p>
     *
     * @param transformer The transformer to convert.
     * @return A new converter instance wrapping the transformer.
     * @throws IllegalArgumentException If no converter is registered for the transformer class, or if the
     * converter cannot be instantiated reflectively (the reflective exception is attached as the cause).
     */
    public static <T extends Transformer> TransformerConverter<T> createConverter(T transformer) {
        Class<? extends Transformer> clazz = transformer.getClass();
        Class<? extends TransformerConverter> converterClazz = getConverterClazz(clazz);
        if (converterClazz == null) {
            throw new IllegalArgumentException("Transformer class " + clazz.getName() + " is not supported");
        }
        try {
            Constructor<? extends TransformerConverter> constructor = converterClazz.getDeclaredConstructor(clazz);
            // The registry is keyed by the transformer class, so the raw converter is taken to be a converter for T
            @SuppressWarnings("unchecked")
            TransformerConverter<T> converter = constructor.newInstance(transformer);
            return converter;
        } catch (ReflectiveOperationException roe) {
            throw new IllegalArgumentException(roe);
        }
    }

    /**
     * Looks up the converter class that is registered for the given transformer class.
     *
     * @return The converter class, or {@code null} if there is no registration for exactly this class.
     */
    public static Class<? extends TransformerConverter> getConverterClazz(Class<? extends Transformer> clazz) {
        return converters.get(clazz);
    }

    /**
     * Registers (or replaces) the converter class for the given transformer class.
     *
     * <p>
     * Both arguments are re-validated at runtime, because callers may reach this method through raw or
     * unchecked {@link Class} references.
     * </p>
     *
     * @throws IllegalArgumentException If either argument is {@code null} or not a subclass of the expected
     * base type.
     */
    public static void putConverterClazz(Class<? extends Transformer> clazz, Class<? extends TransformerConverter<?>> converterClazz) {
        if (clazz == null || !Transformer.class.isAssignableFrom(clazz)) {
            throw new IllegalArgumentException("Expected " + Transformer.class.getName() + " subclass, got " + (clazz != null ? clazz.getName() : null));
        }
        if (converterClazz == null || !TransformerConverter.class.isAssignableFrom(converterClazz)) {
            throw new IllegalArgumentException("Expected " + TransformerConverter.class.getName() + " subclass, got " + (converterClazz != null ? converterClazz.getName() : null));
        }
        converters.put(clazz, converterClazz);
    }

    /**
     * Selects the root model: the only model if there is exactly one, otherwise a model chain whose mining
     * schema consists of the predicted/target mining fields of all member models.
     */
    private static Model createRootModel(List<Model> models) {
        if (models.isEmpty()) {
            throw new IllegalArgumentException("Expected a pipeline with one or more models, got a pipeline with zero models");
        }
        if (models.size() == 1) {
            return Iterables.getOnlyElement(models);
        }
        List<MiningField> targetMiningFields = new ArrayList<>();
        for (Model model : models) {
            MiningSchema miningSchema = model.getMiningSchema();
            List<MiningField> miningFields = miningSchema.getMiningFields();
            for (MiningField miningField : miningFields) {
                MiningField.UsageType usageType = miningField.getUsageType();
                switch (usageType) {
                    case PREDICTED:
                    case TARGET:
                        targetMiningFields.add(miningField);
                        break;
                    default:
                        break;
                }
            }
        }
        MiningSchema miningSchema = new MiningSchema(targetMiningFields);
        return MiningModelUtil.createModelChain(models, new Schema(null, Collections.emptyList())).setMiningSchema(miningSchema);
    }

    /**
     * Flattens a pipeline into a list of stages, expanding nested pipeline models depth-first and in order.
     */
    private static List<Transformer> getTransformers(PipelineModel pipelineModel) {
        List<Transformer> result = new ArrayList<>();
        for (Transformer stage : pipelineModel.stages()) {
            if (stage instanceof PipelineModel) {
                result.addAll(getTransformers((PipelineModel) stage));
            } else {
                result.add(stage);
            }
        }
        return result;
    }

    /**
     * Populates the registry from every mapping resource visible to the context class loader (or to the
     * system class loader, if the current thread has no context class loader).
     *
     * <p>
     * Loading is best-effort: a resource that cannot be read is logged as a warning and skipped, so that it
     * does not prevent the remaining resources from being processed.
     * </p>
     */
    private static void init() {
        ClassLoader classLoader = Thread.currentThread().getContextClassLoader();
        if (classLoader == null) {
            classLoader = ClassLoader.getSystemClassLoader();
        }
        Enumeration<URL> urls;
        try {
            urls = classLoader.getResources(CONVERTER_MAPPINGS_RESOURCE);
        } catch (IOException ioe) {
            logger.warn("Failed to find resources", ioe);
            return;
        }
        while (urls.hasMoreElements()) {
            URL url = urls.nextElement();
            logger.trace("Loading resource " + url);
            // try-with-resources closes the stream and lets a failure reach the catch clause below;
            // Properties#load reports malformed Unicode escapes as IllegalArgumentException
            try (InputStream is = url.openStream()) {
                Properties properties = new Properties();
                properties.load(is);
                init(classLoader, properties);
            } catch (IOException | IllegalArgumentException e) {
                logger.warn("Failed to load resource", e);
            }
        }
    }

    /**
     * Registers every {@code transformer class name = converter class name} entry of the given properties.
     *
     * <p>
     * Entries whose classes cannot be loaded, or that fail the validation of
     * {@link #putConverterClazz(Class, Class)}, are logged as warnings and skipped.
     * </p>
     */
    @SuppressWarnings({"rawtypes", "unchecked"})
    private static void init(ClassLoader classLoader, Properties properties) {
        for (String key : properties.stringPropertyNames()) {
            String value = properties.getProperty(key);
            logger.trace("Mapping transformer class " + key + " to transformer converter class " + value);
            // Raw Class references: the actual types are only known at runtime and are validated on registration
            Class clazz;
            try {
                clazz = classLoader.loadClass(key);
            } catch (ClassNotFoundException cnfe) {
                logger.warn("Failed to load transformer class", cnfe);
                continue;
            }
            Class converterClazz;
            try {
                converterClazz = classLoader.loadClass(value);
            } catch (ClassNotFoundException cnfe) {
                logger.warn("Failed to load transformer converter class", cnfe);
                continue;
            }
            try {
                putConverterClazz(clazz, converterClazz);
            } catch (IllegalArgumentException iae) {
                // One invalid mapping must not abort class initialization or the remaining mappings
                logger.warn("Failed to register transformer converter class", iae);
            }
        }
    }

    // Must stay below the static field declarations, because init() uses the registry and the logger
    static {
        init();
    }
}

