Class AScikitLearnWrapper<P extends org.api4.java.ai.ml.core.evaluation.IPrediction, B extends org.api4.java.ai.ml.core.evaluation.IPredictionBatch>

  • All Implemented Interfaces:
    IScikitLearnWrapper, org.api4.java.ai.ml.core.learner.IFittable<org.api4.java.ai.ml.core.dataset.supervised.ILabeledInstance, org.api4.java.ai.ml.core.dataset.supervised.ILabeledDataset<? extends org.api4.java.ai.ml.core.dataset.supervised.ILabeledInstance>>, org.api4.java.ai.ml.core.learner.IFittablePredictor<org.api4.java.ai.ml.core.dataset.supervised.ILabeledInstance, org.api4.java.ai.ml.core.dataset.supervised.ILabeledDataset<? extends org.api4.java.ai.ml.core.dataset.supervised.ILabeledInstance>>, org.api4.java.ai.ml.core.learner.ILearnerConfigHandler, org.api4.java.ai.ml.core.learner.IPredictor<org.api4.java.ai.ml.core.dataset.supervised.ILabeledInstance, org.api4.java.ai.ml.core.dataset.supervised.ILabeledDataset<? extends org.api4.java.ai.ml.core.dataset.supervised.ILabeledInstance>>, org.api4.java.ai.ml.core.learner.ISupervisedLearner<org.api4.java.ai.ml.core.dataset.supervised.ILabeledInstance, org.api4.java.ai.ml.core.dataset.supervised.ILabeledDataset<? extends org.api4.java.ai.ml.core.dataset.supervised.ILabeledInstance>>, org.api4.java.common.control.ILoggingCustomizable
    Direct Known Subclasses:
    ScikitLearnClassificationWrapper, ScikitLearnMultiTargetRegressionWrapper, ScikitLearnRegressionWrapper, ScikitLearnTimeSeriesFeatureEngineeringWrapper

    public abstract class AScikitLearnWrapper<P extends org.api4.java.ai.ml.core.evaluation.IPrediction, B extends org.api4.java.ai.ml.core.evaluation.IPredictionBatch>
    extends ASupervisedLearner<org.api4.java.ai.ml.core.dataset.supervised.ILabeledInstance, org.api4.java.ai.ml.core.dataset.supervised.ILabeledDataset<? extends org.api4.java.ai.ml.core.dataset.supervised.ILabeledInstance>, P, B>
    implements IScikitLearnWrapper
    • Field Detail

      • PYTHON_MINIMUM_REQUIRED_VERSION_REL

        public static final int PYTHON_MINIMUM_REQUIRED_VERSION_REL
        See Also:
        Constant Field Values
      • PYTHON_MINIMUM_REQUIRED_VERSION_MAJ

        public static final int PYTHON_MINIMUM_REQUIRED_VERSION_MAJ
        See Also:
        Constant Field Values
      • PYTHON_MINIMUM_REQUIRED_VERSION_MIN

        public static final int PYTHON_MINIMUM_REQUIRED_VERSION_MIN
        See Also:
        Constant Field Values
      • PYTHON_REQUIRED_MODULES

        protected static final java.lang.String[] PYTHON_REQUIRED_MODULES
      • PYTHON_OPTIONAL_MODULES

        protected static final java.lang.String[] PYTHON_OPTIONAL_MODULES
      • logger

        protected org.slf4j.Logger logger
      • pythonConfig

        protected ai.libs.python.IPythonConfig pythonConfig
      • configurationUID

        protected final java.lang.String configurationUID
      • pipeline

        protected java.lang.String pipeline
      • modelFile

        protected java.io.File modelFile
      • data

        protected org.api4.java.ai.ml.core.dataset.supervised.ILabeledDataset<org.api4.java.ai.ml.core.dataset.supervised.ILabeledInstance> data
      • targetIndices

        protected int[] targetIndices
      • seed

        protected long seed
      • timeout

        protected org.api4.java.algorithm.Timeout timeout
    • Constructor Detail

      • AScikitLearnWrapper

        protected AScikitLearnWrapper(EScikitLearnProblemType problemType,
                                      java.lang.String pipeline,
                                      java.lang.String imports)
                               throws java.io.IOException,
                                      java.lang.InterruptedException
        Throws:
        java.io.IOException
        java.lang.InterruptedException
    • Method Detail

      • setPythonTemplate

        public void setPythonTemplate(java.lang.String pythonTemplatePath)
                               throws java.io.IOException
        Specified by:
        setPythonTemplate in interface IScikitLearnWrapper
        Throws:
        java.io.IOException
      • setModelPath

        public void setModelPath(java.lang.String modelPath)
                          throws java.io.IOException
        Specified by:
        setModelPath in interface IScikitLearnWrapper
        Throws:
        java.io.IOException
      • fit

        public void fit(org.api4.java.ai.ml.core.dataset.supervised.ILabeledDataset<? extends org.api4.java.ai.ml.core.dataset.supervised.ILabeledInstance> trainingData)
                 throws org.api4.java.ai.ml.core.exception.TrainingException,
                        java.lang.InterruptedException
        Specified by:
        fit in interface org.api4.java.ai.ml.core.learner.IFittable<org.api4.java.ai.ml.core.dataset.supervised.ILabeledInstance, org.api4.java.ai.ml.core.dataset.supervised.ILabeledDataset<? extends org.api4.java.ai.ml.core.dataset.supervised.ILabeledInstance>>
        Throws:
        org.api4.java.ai.ml.core.exception.TrainingException
        java.lang.InterruptedException
      • fit

        public void fit(java.lang.String trainingDataName)
                 throws org.api4.java.ai.ml.core.exception.TrainingException,
                        java.lang.InterruptedException
        Specified by:
        fit in interface IScikitLearnWrapper
        Throws:
        org.api4.java.ai.ml.core.exception.TrainingException
        java.lang.InterruptedException
      • predict

        public B predict(org.api4.java.ai.ml.core.dataset.supervised.ILabeledDataset<? extends org.api4.java.ai.ml.core.dataset.supervised.ILabeledInstance> testingData)
                  throws org.api4.java.ai.ml.core.exception.PredictionException,
                         java.lang.InterruptedException
        Specified by:
        predict in interface org.api4.java.ai.ml.core.learner.IPredictor<org.api4.java.ai.ml.core.dataset.supervised.ILabeledInstance, org.api4.java.ai.ml.core.dataset.supervised.ILabeledDataset<? extends org.api4.java.ai.ml.core.dataset.supervised.ILabeledInstance>>
        Overrides:
        predict in class ASupervisedLearner<org.api4.java.ai.ml.core.dataset.supervised.ILabeledInstance, org.api4.java.ai.ml.core.dataset.supervised.ILabeledDataset<? extends org.api4.java.ai.ml.core.dataset.supervised.ILabeledInstance>, P, B>
        Throws:
        org.api4.java.ai.ml.core.exception.PredictionException
        java.lang.InterruptedException
      • predict

        public B predict(java.lang.String testingDataName)
                  throws org.api4.java.ai.ml.core.exception.PredictionException,
                         java.lang.InterruptedException
        Throws:
        org.api4.java.ai.ml.core.exception.PredictionException
        java.lang.InterruptedException
      • predict

        public B predict(org.api4.java.ai.ml.core.dataset.supervised.ILabeledInstance[] testingInstances)
                  throws org.api4.java.ai.ml.core.exception.PredictionException,
                         java.lang.InterruptedException
        Specified by:
        predict in interface org.api4.java.ai.ml.core.learner.IPredictor<org.api4.java.ai.ml.core.dataset.supervised.ILabeledInstance, org.api4.java.ai.ml.core.dataset.supervised.ILabeledDataset<? extends org.api4.java.ai.ml.core.dataset.supervised.ILabeledInstance>>
        Specified by:
        predict in class ASupervisedLearner<org.api4.java.ai.ml.core.dataset.supervised.ILabeledInstance, org.api4.java.ai.ml.core.dataset.supervised.ILabeledDataset<? extends org.api4.java.ai.ml.core.dataset.supervised.ILabeledInstance>, P, B>
        Throws:
        org.api4.java.ai.ml.core.exception.PredictionException
        java.lang.InterruptedException
      • predict

        public P predict(org.api4.java.ai.ml.core.dataset.supervised.ILabeledInstance instance)
                  throws org.api4.java.ai.ml.core.exception.PredictionException,
                         java.lang.InterruptedException
        Specified by:
        predict in interface org.api4.java.ai.ml.core.learner.IPredictor<org.api4.java.ai.ml.core.dataset.supervised.ILabeledInstance, org.api4.java.ai.ml.core.dataset.supervised.ILabeledDataset<? extends org.api4.java.ai.ml.core.dataset.supervised.ILabeledInstance>>
        Specified by:
        predict in class ASupervisedLearner<org.api4.java.ai.ml.core.dataset.supervised.ILabeledInstance, org.api4.java.ai.ml.core.dataset.supervised.ILabeledDataset<? extends org.api4.java.ai.ml.core.dataset.supervised.ILabeledInstance>, P, B>
        Throws:
        org.api4.java.ai.ml.core.exception.PredictionException
        java.lang.InterruptedException
      • fitAndPredict

        public B fitAndPredict(org.api4.java.ai.ml.core.dataset.supervised.ILabeledDataset<? extends org.api4.java.ai.ml.core.dataset.supervised.ILabeledInstance> trainingData,
                               org.api4.java.ai.ml.core.dataset.supervised.ILabeledDataset<? extends org.api4.java.ai.ml.core.dataset.supervised.ILabeledInstance> testingData)
                        throws org.api4.java.ai.ml.core.exception.TrainingException,
                               org.api4.java.ai.ml.core.exception.PredictionException,
                               java.lang.InterruptedException
        Specified by:
        fitAndPredict in interface org.api4.java.ai.ml.core.learner.IFittablePredictor<org.api4.java.ai.ml.core.dataset.supervised.ILabeledInstance, org.api4.java.ai.ml.core.dataset.supervised.ILabeledDataset<? extends org.api4.java.ai.ml.core.dataset.supervised.ILabeledInstance>>
        Overrides:
        fitAndPredict in class ASupervisedLearner<org.api4.java.ai.ml.core.dataset.supervised.ILabeledInstance, org.api4.java.ai.ml.core.dataset.supervised.ILabeledDataset<? extends org.api4.java.ai.ml.core.dataset.supervised.ILabeledInstance>, P, B>
        Throws:
        org.api4.java.ai.ml.core.exception.TrainingException
        org.api4.java.ai.ml.core.exception.PredictionException
        java.lang.InterruptedException
      • fitAndPredict

        public B fitAndPredict(java.io.File trainingDataFile,
                               java.lang.String trainingDataName,
                               java.io.File testingDataFile,
                               java.lang.String testingDataName)
                        throws org.api4.java.ai.ml.core.exception.TrainingException,
                               org.api4.java.ai.ml.core.exception.PredictionException,
                               java.lang.InterruptedException
        Throws:
        org.api4.java.ai.ml.core.exception.TrainingException
        org.api4.java.ai.ml.core.exception.PredictionException
        java.lang.InterruptedException
      • getModelFileName

        protected java.lang.String getModelFileName(java.lang.String dataFileName)
      • getDataName

        public java.lang.String getDataName(org.api4.java.ai.ml.core.dataset.supervised.ILabeledDataset<? extends org.api4.java.ai.ml.core.dataset.supervised.ILabeledInstance> data)
        Specified by:
        getDataName in interface IScikitLearnWrapper
      • doLabelsFitToProblemType

        protected abstract boolean doLabelsFitToProblemType(org.api4.java.ai.ml.core.dataset.supervised.ILabeledDataset<? extends org.api4.java.ai.ml.core.dataset.supervised.ILabeledInstance> data)
      • constructCommandLineParametersForFitMode

        protected ScikitLearnWrapperCommandBuilder constructCommandLineParametersForFitMode(java.io.File modelFile,
                                                                                            java.io.File trainingDataFile)
      • constructCommandLineParametersForPredictMode

        protected ScikitLearnWrapperCommandBuilder constructCommandLineParametersForPredictMode(java.io.File modelFile,
                                                                                                java.io.File testingDataFile,
                                                                                                java.io.File outputFile)
      • constructCommandLineParametersForFitAndPredictMode

        protected ScikitLearnWrapperCommandBuilder constructCommandLineParametersForFitAndPredictMode(java.io.File trainingDataFile,
                                                                                                      java.io.File testingDataFile,
                                                                                                      java.io.File testingOutputFile)
      • handleOutput

        protected abstract B handleOutput(java.io.File outputFile)
                                   throws org.api4.java.ai.ml.core.exception.PredictionException,
                                          org.api4.java.ai.ml.core.exception.TrainingException
        Throws:
        org.api4.java.ai.ml.core.exception.PredictionException
        org.api4.java.ai.ml.core.exception.TrainingException
      • getRawPredictionResults

        protected java.util.List<java.util.List<java.lang.Double>> getRawPredictionResults(java.io.File outputFile)
                                                                                    throws org.api4.java.ai.ml.core.exception.PredictionException
        Throws:
        org.api4.java.ai.ml.core.exception.PredictionException
      • getLoggerName

        public java.lang.String getLoggerName()
        Specified by:
        getLoggerName in interface org.api4.java.common.control.ILoggingCustomizable
      • setLoggerName

        public void setLoggerName(java.lang.String name)
        Specified by:
        setLoggerName in interface org.api4.java.common.control.ILoggingCustomizable
      • toString

        public java.lang.String toString()
        Overrides:
        toString in class java.lang.Object