Package ai.libs.jaicore.ml.scikitwrapper
Class ScikitLearnWrapper
- java.lang.Object
-
- ai.libs.jaicore.ml.core.learner.ASupervisedLearner<org.api4.java.ai.ml.core.dataset.supervised.ILabeledInstance,org.api4.java.ai.ml.core.dataset.supervised.ILabeledDataset<? extends org.api4.java.ai.ml.core.dataset.supervised.ILabeledInstance>,org.api4.java.ai.ml.classification.singlelabel.evaluation.ISingleLabelClassification,org.api4.java.ai.ml.classification.singlelabel.evaluation.ISingleLabelClassificationPredictionBatch>
-
- ai.libs.jaicore.ml.scikitwrapper.ScikitLearnWrapper
-
- All Implemented Interfaces:
org.api4.java.ai.ml.core.learner.IFittable<org.api4.java.ai.ml.core.dataset.supervised.ILabeledInstance,org.api4.java.ai.ml.core.dataset.supervised.ILabeledDataset<? extends org.api4.java.ai.ml.core.dataset.supervised.ILabeledInstance>>, org.api4.java.ai.ml.core.learner.IFittablePredictor<org.api4.java.ai.ml.core.dataset.supervised.ILabeledInstance,org.api4.java.ai.ml.core.dataset.supervised.ILabeledDataset<? extends org.api4.java.ai.ml.core.dataset.supervised.ILabeledInstance>>, org.api4.java.ai.ml.core.learner.ILearnerConfigHandler, org.api4.java.ai.ml.core.learner.IPredictor<org.api4.java.ai.ml.core.dataset.supervised.ILabeledInstance,org.api4.java.ai.ml.core.dataset.supervised.ILabeledDataset<? extends org.api4.java.ai.ml.core.dataset.supervised.ILabeledInstance>>, org.api4.java.ai.ml.core.learner.ISupervisedLearner<org.api4.java.ai.ml.core.dataset.supervised.ILabeledInstance,org.api4.java.ai.ml.core.dataset.supervised.ILabeledDataset<? extends org.api4.java.ai.ml.core.dataset.supervised.ILabeledInstance>>
public class ScikitLearnWrapper extends ASupervisedLearner<org.api4.java.ai.ml.core.dataset.supervised.ILabeledInstance,org.api4.java.ai.ml.core.dataset.supervised.ILabeledDataset<? extends org.api4.java.ai.ml.core.dataset.supervised.ILabeledInstance>,org.api4.java.ai.ml.classification.singlelabel.evaluation.ISingleLabelClassification,org.api4.java.ai.ml.classification.singlelabel.evaluation.ISingleLabelClassificationPredictionBatch> implements org.api4.java.ai.ml.core.learner.ISupervisedLearner<org.api4.java.ai.ml.core.dataset.supervised.ILabeledInstance,org.api4.java.ai.ml.core.dataset.supervised.ILabeledDataset<? extends org.api4.java.ai.ml.core.dataset.supervised.ILabeledInstance>>
Wraps a Scikit-Learn Python process by using a script template to start the given classifier in Scikit-Learn.
Usage: Set the constructInstruction to exactly the command with which the classifier should be instantiated, e.g. "LinearRegression()" or "MLPRegressor(solver = 'lbfgs')". Set the imports to exactly the additional import lines that are necessary to run the construction command. It is up to the user whether fully qualified names or only the class names themselves are used, as long as the import matches the construct call, e.g. (without namespace in the construct call) "from sklearn.linear_model import LinearRegression" or (with namespace in the construct call) "import sklearn.linear_model".
createImportStatementFromImportFolder may help to import one's own folder of modules. It initializes the folder so that it can be used as a source of modules. Depending on the form of the construct call, the keepNamespace flag must be set (as described above).
Before starting the classification, it must be set whether the given dataset is a categorical or a regression task (setProblemType). If the task is a multi-target prediction, setTargets must be used to define which columns of the dataset are the targets. If no targets are defined, it is assumed that only the last column is the target vector. Moreover, the outputFolder may be set to something other than the default (setOutputFolder).
Now fit can be run. If predict is run on the same ScikitLearnWrapper instance after training, the previously trained model is used for testing. If another model is to be used, or if there was no training prior to predict, the model must be set with setModelPath.
After a multi-target prediction, the results may be more accessible in the unflattened representation that can be obtained with getRawLastClassificationResults. For debugging purposes, the wrapper can be made verbose with setIsVerbose.
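The pairing of import line and construct call described above can be illustrated with a minimal sketch. The class ImportStyleSketch and its methods are hypothetical helpers written for this illustration; they are not part of the library and only assemble the two strings that would be passed to the ScikitLearnWrapper(constructInstruction, imports) constructor.

```java
// Illustration only: ImportStyleSketch is a hypothetical helper, not a library class.
// It shows how the import line and the construct call have to agree with each other.
class ImportStyleSketch {

    /** Import line for a construct call WITHOUT namespace, e.g. "LinearRegression()". */
    static String importWithoutNamespace(String module, String className) {
        return "from " + module + " import " + className;
    }

    /** Import line for a construct call WITH namespace, e.g. "sklearn.linear_model.LinearRegression()". */
    static String importWithNamespace(String module) {
        return "import " + module;
    }

    /** Builds the construct call that matches the chosen import style. */
    static String constructCall(String module, String className, String args, boolean keepNamespace) {
        String name = keepNamespace ? module + "." + className : className;
        return name + "(" + args + ")";
    }

    public static void main(String[] args) {
        String module = "sklearn.linear_model";
        String cls = "LinearRegression";

        // Variant 1: class name only in the construct call.
        System.out.println(importWithoutNamespace(module, cls)); // from sklearn.linear_model import LinearRegression
        System.out.println(constructCall(module, cls, "", false)); // LinearRegression()

        // Variant 2: fully qualified name in the construct call.
        System.out.println(importWithNamespace(module)); // import sklearn.linear_model
        System.out.println(constructCall(module, cls, "", true)); // sklearn.linear_model.LinearRegression()

        // With the library on the classpath, either pair would then be handed to
        // the documented constructor: new ScikitLearnWrapper(constructInstruction, imports)
    }
}
```

Mixing the variants, for example "import sklearn.linear_model" together with "LinearRegression()", would leave the class name undefined in the generated Python script.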
-
-
Nested Class Summary
Nested Classes (Modifier and Type, Class, Description):
- static class ScikitLearnWrapper.ProblemType
-
Constructor Summary
Constructors (Constructor, Description):
- ScikitLearnWrapper(java.lang.String constructInstruction, java.lang.String imports): Starts a new wrapper and creates its underlying script with the given parameters.
- ScikitLearnWrapper(java.lang.String constructInstruction, java.lang.String imports, boolean withoutModelDump): Starts a new wrapper and creates its underlying script with the given parameters.
- ScikitLearnWrapper(java.lang.String constructInstruction, java.lang.String imports, java.io.File trainedModelPath)
-
Method Summary
Methods (Modifier and Type, Method, Description):
- static java.lang.String createImportStatementFromImportFolder(java.io.File importsFolder, boolean keepNamespace): Makes the given folder a module to be usable as an import for python and creates a string that adds the folder to the python environment and then imports the folder itself as a module.
- double[] distributionForInstance(org.api4.java.ai.ml.core.dataset.supervised.ILabeledInstance instance)
- void fit(org.api4.java.ai.ml.core.dataset.supervised.ILabeledDataset<? extends org.api4.java.ai.ml.core.dataset.supervised.ILabeledInstance> data)
- static java.lang.String getImportString(java.util.Collection<java.lang.String> imports)
- java.io.File getModelPath()
- java.util.List<java.util.List<java.lang.Double>> getRawLastClassificationResults()
- org.api4.java.ai.ml.classification.singlelabel.evaluation.ISingleLabelClassification predict(org.api4.java.ai.ml.core.dataset.supervised.ILabeledInstance instance)
- org.api4.java.ai.ml.classification.singlelabel.evaluation.ISingleLabelClassificationPredictionBatch predict(org.api4.java.ai.ml.core.dataset.supervised.ILabeledInstance[] dTest)
- void setModelPath(java.io.File modelFile)
- void setProblemType(ScikitLearnWrapper.ProblemType problemType)
- void setTargets(int... targetColumns)
- java.lang.String toString()
Methods inherited from class ai.libs.jaicore.ml.core.learner.ASupervisedLearner
fitAndPredict, fitAndPredict, fitAndPredict, getConfig, predict, setConfig
-
Methods inherited from class java.lang.Object
clone, equals, finalize, getClass, hashCode, notify, notifyAll, wait, wait, wait
-
Methods inherited from interface org.api4.java.ai.ml.core.learner.IFittablePredictor
fitAndPredict, fitAndPredict, fitAndPredict
-
-
-
-
Constructor Detail
-
ScikitLearnWrapper
public ScikitLearnWrapper(java.lang.String constructInstruction, java.lang.String imports, boolean withoutModelDump) throws java.io.IOException
Starts a new wrapper and creates its underlying script with the given parameters.
- Parameters:
constructInstruction - String that defines which constructor to call for the classifier and with which parameters to call it.
imports - Imports that are added at the beginning of the script. Normally only the imports required by the constructor instruction need to be added here.
- Throws:
java.io.IOException - The script could not be created.
-
ScikitLearnWrapper
public ScikitLearnWrapper(java.lang.String constructInstruction, java.lang.String imports) throws java.io.IOException
Starts a new wrapper and creates its underlying script with the given parameters.
- Parameters:
constructInstruction - String that defines which constructor to call for the classifier and with which parameters to call it.
imports - Imports that are added at the beginning of the script. Normally only the imports required by the constructor instruction need to be added here.
- Throws:
java.io.IOException - The script could not be created.
-
ScikitLearnWrapper
public ScikitLearnWrapper(java.lang.String constructInstruction, java.lang.String imports, java.io.File trainedModelPath) throws java.io.IOException
- Throws:
java.io.IOException
-
-
Method Detail
-
fit
public void fit(org.api4.java.ai.ml.core.dataset.supervised.ILabeledDataset<? extends org.api4.java.ai.ml.core.dataset.supervised.ILabeledInstance> data) throws org.api4.java.ai.ml.core.exception.TrainingException, java.lang.InterruptedException
- Specified by:
fit in interface org.api4.java.ai.ml.core.learner.IFittable<org.api4.java.ai.ml.core.dataset.supervised.ILabeledInstance,org.api4.java.ai.ml.core.dataset.supervised.ILabeledDataset<? extends org.api4.java.ai.ml.core.dataset.supervised.ILabeledInstance>>
- Throws:
org.api4.java.ai.ml.core.exception.TrainingException
java.lang.InterruptedException
-
predict
public org.api4.java.ai.ml.classification.singlelabel.evaluation.ISingleLabelClassification predict(org.api4.java.ai.ml.core.dataset.supervised.ILabeledInstance instance) throws org.api4.java.ai.ml.core.exception.PredictionException, java.lang.InterruptedException
- Specified by:
predict in interface org.api4.java.ai.ml.core.learner.IPredictor<org.api4.java.ai.ml.core.dataset.supervised.ILabeledInstance,org.api4.java.ai.ml.core.dataset.supervised.ILabeledDataset<? extends org.api4.java.ai.ml.core.dataset.supervised.ILabeledInstance>>
- Specified by:
predict in class ASupervisedLearner<org.api4.java.ai.ml.core.dataset.supervised.ILabeledInstance,org.api4.java.ai.ml.core.dataset.supervised.ILabeledDataset<? extends org.api4.java.ai.ml.core.dataset.supervised.ILabeledInstance>,org.api4.java.ai.ml.classification.singlelabel.evaluation.ISingleLabelClassification,org.api4.java.ai.ml.classification.singlelabel.evaluation.ISingleLabelClassificationPredictionBatch>
- Throws:
org.api4.java.ai.ml.core.exception.PredictionException
java.lang.InterruptedException
-
predict
public org.api4.java.ai.ml.classification.singlelabel.evaluation.ISingleLabelClassificationPredictionBatch predict(org.api4.java.ai.ml.core.dataset.supervised.ILabeledInstance[] dTest) throws org.api4.java.ai.ml.core.exception.PredictionException, java.lang.InterruptedException
- Specified by:
predict in interface org.api4.java.ai.ml.core.learner.IPredictor<org.api4.java.ai.ml.core.dataset.supervised.ILabeledInstance,org.api4.java.ai.ml.core.dataset.supervised.ILabeledDataset<? extends org.api4.java.ai.ml.core.dataset.supervised.ILabeledInstance>>
- Specified by:
predict in class ASupervisedLearner<org.api4.java.ai.ml.core.dataset.supervised.ILabeledInstance,org.api4.java.ai.ml.core.dataset.supervised.ILabeledDataset<? extends org.api4.java.ai.ml.core.dataset.supervised.ILabeledInstance>,org.api4.java.ai.ml.classification.singlelabel.evaluation.ISingleLabelClassification,org.api4.java.ai.ml.classification.singlelabel.evaluation.ISingleLabelClassificationPredictionBatch>
- Throws:
org.api4.java.ai.ml.core.exception.PredictionException
java.lang.InterruptedException
-
createImportStatementFromImportFolder
public static java.lang.String createImportStatementFromImportFolder(java.io.File importsFolder, boolean keepNamespace) throws java.io.IOException
Makes the given folder a module that can be imported in Python and creates a string that adds the folder to the Python environment and then imports the folder itself as a module.
- Parameters:
importsFolder - Folder to be added as a module.
keepNamespace - If true, a class must be called by the module's name plus the class name. This only matters if multiple modules are imported and the class names are ambiguous. Keep in mind that the constructor call for the classifier must be written accordingly.
- Returns:
String that can be appended to other imports so that the folder is added as a module.
- Throws:
java.io.IOException - The __init__.py could not be created in the given folder (which is necessary to declare it as a module).
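The behaviour described above can be approximated with the following sketch. It is not the library's implementation: the class ImportFolderSketch, the method importStatementFor, and the exact shape of the generated Python lines are assumptions made for illustration. The sketch creates an empty __init__.py if none exists, adds the parent directory to sys.path, and emits one import line per .py file in the folder.

```java
import java.io.IOException;
import java.nio.file.Files;
import java.nio.file.Path;
import java.util.List;
import java.util.stream.Collectors;
import java.util.stream.Stream;

// Hypothetical re-creation for illustration; the real method may generate different Python code.
class ImportFolderSketch {

    static String importStatementFor(Path folder, boolean keepNamespace) throws IOException {
        Path abs = folder.toAbsolutePath().normalize();
        Path parent = abs.getParent();
        if (parent == null) {
            throw new IllegalArgumentException("Folder needs a parent directory: " + abs);
        }

        // An __init__.py file declares the folder as a Python package.
        Path init = abs.resolve("__init__.py");
        if (Files.notExists(init)) {
            Files.createFile(init);
        }
        String pkg = abs.getFileName().toString();

        // Collect the module names (file names without ".py"), sorted for a stable result.
        List<String> modules;
        try (Stream<Path> files = Files.list(abs)) {
            modules = files.map(p -> p.getFileName().toString())
                    .filter(n -> n.endsWith(".py") && !n.equals("__init__.py"))
                    .map(n -> n.substring(0, n.length() - 3))
                    .sorted()
                    .collect(Collectors.toList());
        }

        // Forward slashes are accepted by Python on all platforms.
        String parentDir = parent.toString().replace('\\', '/');
        StringBuilder sb = new StringBuilder("import sys\n");
        sb.append("sys.path.append('").append(parentDir).append("')\n");
        for (String m : modules) {
            // keepNamespace: classes are later addressed as <pkg>.<module>.<ClassName>.
            // Otherwise: class names are imported directly into the script's namespace.
            sb.append(keepNamespace ? "import " + pkg + "." + m : "from " + pkg + "." + m + " import *")
              .append('\n');
        }
        return sb.toString();
    }
}
```

Under these assumptions, a folder my_models containing my_model.py would yield "import my_models.my_model" with keepNamespace set to true, so the construct call would have to read "my_models.my_model.MyClassifier()" for a hypothetical class MyClassifier.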
-
getImportString
public static java.lang.String getImportString(java.util.Collection<java.lang.String> imports)
-
getRawLastClassificationResults
public java.util.List<java.util.List<java.lang.Double>> getRawLastClassificationResults()
-
setProblemType
public void setProblemType(ScikitLearnWrapper.ProblemType problemType)
-
setTargets
public void setTargets(int... targetColumns)
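The class description states that the last column is taken as the target when no targets are defined. A minimal sketch of that rule follows; the class TargetColumnsSketch, the method resolveTargets, and the use of zero-based column indices are assumptions for illustration and are not taken from the library.

```java
// Hypothetical helper illustrating the documented default; not the library's code.
class TargetColumnsSketch {

    /**
     * Returns the target column indices to use for a dataset with numColumns columns.
     * Assumes zero-based indices. Without explicit targets, only the last column is a target.
     */
    static int[] resolveTargets(int numColumns, int... targetColumns) {
        if (numColumns < 1) {
            throw new IllegalArgumentException("Dataset must have at least one column");
        }
        if (targetColumns == null || targetColumns.length == 0) {
            return new int[] { numColumns - 1 }; // default: last column
        }
        for (int c : targetColumns) {
            if (c < 0 || c >= numColumns) {
                throw new IllegalArgumentException("Target column out of range: " + c);
            }
        }
        return targetColumns.clone();
    }
}
```

For a five-column dataset, resolveTargets(5) yields the single index 4, while resolveTargets(5, 1, 3) keeps the two explicitly chosen columns, which corresponds to calling setTargets(1, 3) for a multi-target task.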
-
setModelPath
public void setModelPath(java.io.File modelFile)
-
getModelPath
public java.io.File getModelPath()
-
distributionForInstance
public double[] distributionForInstance(org.api4.java.ai.ml.core.dataset.supervised.ILabeledInstance instance)
-
toString
public java.lang.String toString()
- Overrides:
toString in class java.lang.Object
-
-