Package ai.libs.jaicore.ml.scikitwrapper
Class AScikitLearnWrapper<P extends org.api4.java.ai.ml.core.evaluation.IPrediction,B extends org.api4.java.ai.ml.core.evaluation.IPredictionBatch>
- java.lang.Object
-
- ai.libs.jaicore.ml.core.learner.ASupervisedLearner<org.api4.java.ai.ml.core.dataset.supervised.ILabeledInstance,org.api4.java.ai.ml.core.dataset.supervised.ILabeledDataset<? extends org.api4.java.ai.ml.core.dataset.supervised.ILabeledInstance>,P,B>
-
- ai.libs.jaicore.ml.scikitwrapper.AScikitLearnWrapper<P,B>
-
- All Implemented Interfaces:
IScikitLearnWrapper, org.api4.java.ai.ml.core.learner.IFittable<org.api4.java.ai.ml.core.dataset.supervised.ILabeledInstance,org.api4.java.ai.ml.core.dataset.supervised.ILabeledDataset<? extends org.api4.java.ai.ml.core.dataset.supervised.ILabeledInstance>>, org.api4.java.ai.ml.core.learner.IFittablePredictor<org.api4.java.ai.ml.core.dataset.supervised.ILabeledInstance,org.api4.java.ai.ml.core.dataset.supervised.ILabeledDataset<? extends org.api4.java.ai.ml.core.dataset.supervised.ILabeledInstance>>, org.api4.java.ai.ml.core.learner.ILearnerConfigHandler, org.api4.java.ai.ml.core.learner.IPredictor<org.api4.java.ai.ml.core.dataset.supervised.ILabeledInstance,org.api4.java.ai.ml.core.dataset.supervised.ILabeledDataset<? extends org.api4.java.ai.ml.core.dataset.supervised.ILabeledInstance>>, org.api4.java.ai.ml.core.learner.ISupervisedLearner<org.api4.java.ai.ml.core.dataset.supervised.ILabeledInstance,org.api4.java.ai.ml.core.dataset.supervised.ILabeledDataset<? extends org.api4.java.ai.ml.core.dataset.supervised.ILabeledInstance>>, org.api4.java.common.control.ILoggingCustomizable
- Direct Known Subclasses:
ScikitLearnClassificationWrapper, ScikitLearnMultiTargetRegressionWrapper, ScikitLearnRegressionWrapper, ScikitLearnTimeSeriesFeatureEngineeringWrapper
public abstract class AScikitLearnWrapper<P extends org.api4.java.ai.ml.core.evaluation.IPrediction,B extends org.api4.java.ai.ml.core.evaluation.IPredictionBatch> extends ASupervisedLearner<org.api4.java.ai.ml.core.dataset.supervised.ILabeledInstance,org.api4.java.ai.ml.core.dataset.supervised.ILabeledDataset<? extends org.api4.java.ai.ml.core.dataset.supervised.ILabeledInstance>,P,B> implements IScikitLearnWrapper
-
-
Field Summary
Fields
Modifier and Type | Field | Description
- protected java.lang.String configurationUID
- protected org.api4.java.ai.ml.core.dataset.supervised.ILabeledDataset<org.api4.java.ai.ml.core.dataset.supervised.ILabeledInstance> data
- protected org.slf4j.Logger logger
- protected java.io.File modelFile
- protected java.lang.String pipeline
- protected EScikitLearnProblemType problemType
- static int PYTHON_MINIMUM_REQUIRED_VERSION_MAJ
- static int PYTHON_MINIMUM_REQUIRED_VERSION_MIN
- static int PYTHON_MINIMUM_REQUIRED_VERSION_REL
- protected static java.lang.String[] PYTHON_OPTIONAL_MODULES
- protected static java.lang.String[] PYTHON_REQUIRED_MODULES
- protected ai.libs.python.IPythonConfig pythonConfig
- protected IScikitLearnWrapperConfig scikitLearnWrapperConfig
- protected long seed
- protected int[] targetIndices
- protected org.api4.java.algorithm.Timeout timeout
-
Constructor Summary
Constructors
Modifier | Constructor | Description
- protected AScikitLearnWrapper(EScikitLearnProblemType problemType, java.lang.String pipeline, java.lang.String imports)
-
Method Summary
All Methods | Instance Methods | Abstract Methods | Concrete Methods
Modifier and Type | Method | Description
- protected ScikitLearnWrapperCommandBuilder constructCommandLineParametersForFitAndPredictMode(java.io.File trainingDataFile, java.io.File testingDataFile, java.io.File testingOutputFile)
- protected ScikitLearnWrapperCommandBuilder constructCommandLineParametersForFitMode(java.io.File modelFile, java.io.File trainingDataFile)
- protected ScikitLearnWrapperCommandBuilder constructCommandLineParametersForPredictMode(java.io.File modelFile, java.io.File testingDataFile, java.io.File outputFile)
- protected abstract boolean doLabelsFitToProblemType(org.api4.java.ai.ml.core.dataset.supervised.ILabeledDataset<? extends org.api4.java.ai.ml.core.dataset.supervised.ILabeledInstance> data)
- void fit(java.lang.String trainingDataName)
- void fit(org.api4.java.ai.ml.core.dataset.supervised.ILabeledDataset<? extends org.api4.java.ai.ml.core.dataset.supervised.ILabeledInstance> trainingData)
- B fitAndPredict(java.io.File trainingDataFile, java.lang.String trainingDataName, java.io.File testingDataFile, java.lang.String testingDataName)
- B fitAndPredict(org.api4.java.ai.ml.core.dataset.supervised.ILabeledDataset<? extends org.api4.java.ai.ml.core.dataset.supervised.ILabeledInstance> trainingData, org.api4.java.ai.ml.core.dataset.supervised.ILabeledDataset<? extends org.api4.java.ai.ml.core.dataset.supervised.ILabeledInstance> testingData)
- protected ScikitLearnWrapperCommandBuilder getCommandBuilder()
- protected ScikitLearnWrapperCommandBuilder getCommandBuilder(ScikitLearnWrapperCommandBuilder commandBuilder)
- java.lang.String getDataName(org.api4.java.ai.ml.core.dataset.supervised.ILabeledDataset<? extends org.api4.java.ai.ml.core.dataset.supervised.ILabeledInstance> data)
- java.lang.String getLoggerName()
- java.io.File getModelFile()
- protected java.lang.String getModelFileName(java.lang.String dataFileName)
- java.io.File getModelPath()
- java.io.File getOutputFile(java.lang.String dataName)
- protected java.util.List<java.util.List<java.lang.Double>> getRawPredictionResults(java.io.File outputFile)
- java.io.File getSKLearnScriptFile()
- protected abstract B handleOutput(java.io.File outputFile)
- B predict(java.lang.String testingDataName)
- B predict(org.api4.java.ai.ml.core.dataset.supervised.ILabeledDataset<? extends org.api4.java.ai.ml.core.dataset.supervised.ILabeledInstance> testingData)
- P predict(org.api4.java.ai.ml.core.dataset.supervised.ILabeledInstance instance)
- B predict(org.api4.java.ai.ml.core.dataset.supervised.ILabeledInstance[] testingInstances)
- void setLoggerName(java.lang.String name)
- void setModelPath(java.lang.String modelPath)
- void setPythonConfig(ai.libs.python.IPythonConfig pythonConfig)
- void setPythonTemplate(java.lang.String pythonTemplatePath)
- void setScikitLearnWrapperConfig(IScikitLearnWrapperConfig scikitLearnWrapperConfig)
- void setSeed(long seed)
- void setTargetIndices(int... targetIndices)
- void setTimeout(org.api4.java.algorithm.Timeout timeout)
- java.lang.String toString()
-
Methods inherited from class ai.libs.jaicore.ml.core.learner.ASupervisedLearner
fitAndPredict, fitAndPredict, getConfig, setConfig
-
Methods inherited from class java.lang.Object
clone, equals, finalize, getClass, hashCode, notify, notifyAll, wait, wait, wait
-
-
-
-
Field Detail
-
PYTHON_MINIMUM_REQUIRED_VERSION_REL
public static final int PYTHON_MINIMUM_REQUIRED_VERSION_REL
- See Also:
- Constant Field Values
-
PYTHON_MINIMUM_REQUIRED_VERSION_MAJ
public static final int PYTHON_MINIMUM_REQUIRED_VERSION_MAJ
- See Also:
- Constant Field Values
-
PYTHON_MINIMUM_REQUIRED_VERSION_MIN
public static final int PYTHON_MINIMUM_REQUIRED_VERSION_MIN
- See Also:
- Constant Field Values
-
PYTHON_REQUIRED_MODULES
protected static final java.lang.String[] PYTHON_REQUIRED_MODULES
-
PYTHON_OPTIONAL_MODULES
protected static final java.lang.String[] PYTHON_OPTIONAL_MODULES
-
logger
protected org.slf4j.Logger logger
-
scikitLearnWrapperConfig
protected IScikitLearnWrapperConfig scikitLearnWrapperConfig
-
pythonConfig
protected ai.libs.python.IPythonConfig pythonConfig
-
configurationUID
protected final java.lang.String configurationUID
-
problemType
protected EScikitLearnProblemType problemType
-
pipeline
protected java.lang.String pipeline
-
modelFile
protected java.io.File modelFile
-
data
protected org.api4.java.ai.ml.core.dataset.supervised.ILabeledDataset<org.api4.java.ai.ml.core.dataset.supervised.ILabeledInstance> data
-
targetIndices
protected int[] targetIndices
-
seed
protected long seed
-
timeout
protected org.api4.java.algorithm.Timeout timeout
-
-
Constructor Detail
-
AScikitLearnWrapper
protected AScikitLearnWrapper(EScikitLearnProblemType problemType, java.lang.String pipeline, java.lang.String imports) throws java.io.IOException, java.lang.InterruptedException
- Throws:
java.io.IOException, java.lang.InterruptedException
-
-
Method Detail
-
setPythonTemplate
public void setPythonTemplate(java.lang.String pythonTemplatePath) throws java.io.IOException
- Specified by:
setPythonTemplate in interface IScikitLearnWrapper
- Throws:
java.io.IOException
-
setModelPath
public void setModelPath(java.lang.String modelPath) throws java.io.IOException
- Specified by:
setModelPath in interface IScikitLearnWrapper
- Throws:
java.io.IOException
-
getModelPath
public java.io.File getModelPath()
- Specified by:
getModelPath in interface IScikitLearnWrapper
-
setTargetIndices
public void setTargetIndices(int... targetIndices)
- Specified by:
setTargetIndices in interface IScikitLearnWrapper
-
setSeed
public void setSeed(long seed)
- Specified by:
setSeed in interface IScikitLearnWrapper
-
setTimeout
public void setTimeout(org.api4.java.algorithm.Timeout timeout)
- Specified by:
setTimeout in interface IScikitLearnWrapper
-
fit
public void fit(org.api4.java.ai.ml.core.dataset.supervised.ILabeledDataset<? extends org.api4.java.ai.ml.core.dataset.supervised.ILabeledInstance> trainingData) throws org.api4.java.ai.ml.core.exception.TrainingException, java.lang.InterruptedException
-
fit
public void fit(java.lang.String trainingDataName) throws org.api4.java.ai.ml.core.exception.TrainingException, java.lang.InterruptedException
- Specified by:
fit in interface IScikitLearnWrapper
- Throws:
org.api4.java.ai.ml.core.exception.TrainingException, java.lang.InterruptedException
-
predict
public B predict(org.api4.java.ai.ml.core.dataset.supervised.ILabeledDataset<? extends org.api4.java.ai.ml.core.dataset.supervised.ILabeledInstance> testingData) throws org.api4.java.ai.ml.core.exception.PredictionException, java.lang.InterruptedException
- Specified by:
predict in interface org.api4.java.ai.ml.core.learner.IPredictor<P extends org.api4.java.ai.ml.core.evaluation.IPrediction,B extends org.api4.java.ai.ml.core.evaluation.IPredictionBatch>
- Overrides:
predict in class ASupervisedLearner<org.api4.java.ai.ml.core.dataset.supervised.ILabeledInstance,org.api4.java.ai.ml.core.dataset.supervised.ILabeledDataset<? extends org.api4.java.ai.ml.core.dataset.supervised.ILabeledInstance>,P extends org.api4.java.ai.ml.core.evaluation.IPrediction,B extends org.api4.java.ai.ml.core.evaluation.IPredictionBatch>
- Throws:
org.api4.java.ai.ml.core.exception.PredictionException, java.lang.InterruptedException
-
predict
public B predict(java.lang.String testingDataName) throws org.api4.java.ai.ml.core.exception.PredictionException, java.lang.InterruptedException
- Throws:
org.api4.java.ai.ml.core.exception.PredictionException, java.lang.InterruptedException
-
predict
public B predict(org.api4.java.ai.ml.core.dataset.supervised.ILabeledInstance[] testingInstances) throws org.api4.java.ai.ml.core.exception.PredictionException, java.lang.InterruptedException
- Specified by:
predict in interface org.api4.java.ai.ml.core.learner.IPredictor<P extends org.api4.java.ai.ml.core.evaluation.IPrediction,B extends org.api4.java.ai.ml.core.evaluation.IPredictionBatch>
- Specified by:
predict in class ASupervisedLearner<org.api4.java.ai.ml.core.dataset.supervised.ILabeledInstance,org.api4.java.ai.ml.core.dataset.supervised.ILabeledDataset<? extends org.api4.java.ai.ml.core.dataset.supervised.ILabeledInstance>,P extends org.api4.java.ai.ml.core.evaluation.IPrediction,B extends org.api4.java.ai.ml.core.evaluation.IPredictionBatch>
- Throws:
org.api4.java.ai.ml.core.exception.PredictionException, java.lang.InterruptedException
-
predict
public P predict(org.api4.java.ai.ml.core.dataset.supervised.ILabeledInstance instance) throws org.api4.java.ai.ml.core.exception.PredictionException, java.lang.InterruptedException
- Specified by:
predict in interface org.api4.java.ai.ml.core.learner.IPredictor<P extends org.api4.java.ai.ml.core.evaluation.IPrediction,B extends org.api4.java.ai.ml.core.evaluation.IPredictionBatch>
- Specified by:
predict in class ASupervisedLearner<org.api4.java.ai.ml.core.dataset.supervised.ILabeledInstance,org.api4.java.ai.ml.core.dataset.supervised.ILabeledDataset<? extends org.api4.java.ai.ml.core.dataset.supervised.ILabeledInstance>,P extends org.api4.java.ai.ml.core.evaluation.IPrediction,B extends org.api4.java.ai.ml.core.evaluation.IPredictionBatch>
- Throws:
org.api4.java.ai.ml.core.exception.PredictionException, java.lang.InterruptedException
-
fitAndPredict
public B fitAndPredict(org.api4.java.ai.ml.core.dataset.supervised.ILabeledDataset<? extends org.api4.java.ai.ml.core.dataset.supervised.ILabeledInstance> trainingData, org.api4.java.ai.ml.core.dataset.supervised.ILabeledDataset<? extends org.api4.java.ai.ml.core.dataset.supervised.ILabeledInstance> testingData) throws org.api4.java.ai.ml.core.exception.TrainingException, org.api4.java.ai.ml.core.exception.PredictionException, java.lang.InterruptedException
- Specified by:
fitAndPredict in interface org.api4.java.ai.ml.core.learner.IFittablePredictor<P extends org.api4.java.ai.ml.core.evaluation.IPrediction,B extends org.api4.java.ai.ml.core.evaluation.IPredictionBatch>
- Overrides:
fitAndPredict in class ASupervisedLearner<org.api4.java.ai.ml.core.dataset.supervised.ILabeledInstance,org.api4.java.ai.ml.core.dataset.supervised.ILabeledDataset<? extends org.api4.java.ai.ml.core.dataset.supervised.ILabeledInstance>,P extends org.api4.java.ai.ml.core.evaluation.IPrediction,B extends org.api4.java.ai.ml.core.evaluation.IPredictionBatch>
- Throws:
org.api4.java.ai.ml.core.exception.TrainingException, org.api4.java.ai.ml.core.exception.PredictionException, java.lang.InterruptedException
-
fitAndPredict
public B fitAndPredict(java.io.File trainingDataFile, java.lang.String trainingDataName, java.io.File testingDataFile, java.lang.String testingDataName) throws org.api4.java.ai.ml.core.exception.TrainingException, org.api4.java.ai.ml.core.exception.PredictionException, java.lang.InterruptedException
- Throws:
org.api4.java.ai.ml.core.exception.TrainingException, org.api4.java.ai.ml.core.exception.PredictionException, java.lang.InterruptedException
-
getModelFileName
protected java.lang.String getModelFileName(java.lang.String dataFileName)
-
getDataName
public java.lang.String getDataName(org.api4.java.ai.ml.core.dataset.supervised.ILabeledDataset<? extends org.api4.java.ai.ml.core.dataset.supervised.ILabeledInstance> data)
- Specified by:
getDataName in interface IScikitLearnWrapper
-
doLabelsFitToProblemType
protected abstract boolean doLabelsFitToProblemType(org.api4.java.ai.ml.core.dataset.supervised.ILabeledDataset<? extends org.api4.java.ai.ml.core.dataset.supervised.ILabeledInstance> data)
-
getCommandBuilder
protected ScikitLearnWrapperCommandBuilder getCommandBuilder()
-
getCommandBuilder
protected ScikitLearnWrapperCommandBuilder getCommandBuilder(ScikitLearnWrapperCommandBuilder commandBuilder)
-
constructCommandLineParametersForFitMode
protected ScikitLearnWrapperCommandBuilder constructCommandLineParametersForFitMode(java.io.File modelFile, java.io.File trainingDataFile)
-
constructCommandLineParametersForPredictMode
protected ScikitLearnWrapperCommandBuilder constructCommandLineParametersForPredictMode(java.io.File modelFile, java.io.File testingDataFile, java.io.File outputFile)
-
constructCommandLineParametersForFitAndPredictMode
protected ScikitLearnWrapperCommandBuilder constructCommandLineParametersForFitAndPredictMode(java.io.File trainingDataFile, java.io.File testingDataFile, java.io.File testingOutputFile)
-
getOutputFile
public java.io.File getOutputFile(java.lang.String dataName)
- Specified by:
getOutputFile in interface IScikitLearnWrapper
-
handleOutput
protected abstract B handleOutput(java.io.File outputFile) throws org.api4.java.ai.ml.core.exception.PredictionException, org.api4.java.ai.ml.core.exception.TrainingException
- Throws:
org.api4.java.ai.ml.core.exception.PredictionException, org.api4.java.ai.ml.core.exception.TrainingException
-
getRawPredictionResults
protected java.util.List<java.util.List<java.lang.Double>> getRawPredictionResults(java.io.File outputFile) throws org.api4.java.ai.ml.core.exception.PredictionException
- Throws:
org.api4.java.ai.ml.core.exception.PredictionException
-
setPythonConfig
public void setPythonConfig(ai.libs.python.IPythonConfig pythonConfig)
- Specified by:
setPythonConfig in interface IScikitLearnWrapper
-
setScikitLearnWrapperConfig
public void setScikitLearnWrapperConfig(IScikitLearnWrapperConfig scikitLearnWrapperConfig)
- Specified by:
setScikitLearnWrapperConfig in interface IScikitLearnWrapper
-
getSKLearnScriptFile
public java.io.File getSKLearnScriptFile()
- Specified by:
getSKLearnScriptFile in interface IScikitLearnWrapper
-
getModelFile
public java.io.File getModelFile()
- Specified by:
getModelFile in interface IScikitLearnWrapper
-
getLoggerName
public java.lang.String getLoggerName()
- Specified by:
getLoggerName in interface org.api4.java.common.control.ILoggingCustomizable
-
setLoggerName
public void setLoggerName(java.lang.String name)
- Specified by:
setLoggerName in interface org.api4.java.common.control.ILoggingCustomizable
-
toString
public java.lang.String toString()
- Overrides:
toString in class java.lang.Object
-
-