public class SGD extends RandomizableClassifier implements UpdateableClassifier, OptionHandler, Aggregateable<SGD>
-F Set the loss function to minimize. 0 = hinge loss (SVM), 1 = log loss (logistic regression), 2 = squared loss (regression), 3 = epsilon insensitive loss (regression), 4 = Huber loss (regression). (default = 0)
-L The learning rate. If normalization is turned off (as it is automatically for streaming data), then the default learning rate will need to be reduced (try 0.0001). (default = 0.01).
-R <double> The lambda regularization constant (default = 0.0001)
-E <integer> The number of epochs to perform (batch learning only, default = 500)
-C <double> The epsilon threshold (epsilon-insensitive and Huber loss only, default = 1e-3)
-N Don't normalize the data
-M Don't replace missing values
-S <num> Random number seed. (default 1)
-output-debug-info If set, classifier is run in debug mode and may output additional info to the console
-do-not-check-capabilities If set, classifier capabilities are not checked before classifier is built (use with caution).
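
To make the flag list above concrete, the following minimal sketch assembles an option array from the documented flags. SgdOptionsSketch and build are hypothetical names introduced for illustration and are not part of the SGD class; only the flag letters and the 0 to 4 loss codes come from the list above.

```java
import java.util.ArrayList;
import java.util.List;

/** Hypothetical helper (not part of Weka): builds an option array from the documented flags. */
final class SgdOptionsSketch {

    private SgdOptionsSketch() {}

    /**
     * Assembles the flags in the order they are documented.
     *
     * @param loss               loss code for -F (0 to 4)
     * @param learningRate       value for -L
     * @param lambda             value for -R
     * @param epochs             value for -E
     * @param epsilon            value for -C
     * @param dontNormalize      adds -N when true
     * @param dontReplaceMissing adds -M when true
     * @param seed               value for -S
     */
    static String[] build(int loss, double learningRate, double lambda, int epochs,
                          double epsilon, boolean dontNormalize,
                          boolean dontReplaceMissing, int seed) {
        if (loss < 0 || loss > 4) {
            throw new IllegalArgumentException("loss code must be in 0..4, got " + loss);
        }
        List<String> opts = new ArrayList<>();
        opts.add("-F");
        opts.add(Integer.toString(loss));
        opts.add("-L");
        opts.add(Double.toString(learningRate));
        opts.add("-R");
        opts.add(Double.toString(lambda));
        opts.add("-E");
        opts.add(Integer.toString(epochs));
        opts.add("-C");
        opts.add(Double.toString(epsilon));
        if (dontNormalize) {
            opts.add("-N"); // flag without a value
        }
        if (dontReplaceMissing) {
            opts.add("-M"); // flag without a value
        }
        opts.add("-S");
        opts.add(Integer.toString(seed));
        return opts.toArray(new String[0]);
    }
}
```

An array built this way has the String[] shape that setOptions(java.lang.String[] options) takes. For example, build(1, 0.01, 0.0001, 500, 0.001, true, false, 1) selects log loss with normalization turned off.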
| Modifier and Type | Field and Description |
|---|---|
| static int | EPSILON_INSENSITIVE: The epsilon insensitive loss function. |
| static int | HINGE: The hinge loss function. |
| static int | HUBER: The Huber loss function. |
| static int | LOGLOSS: The log loss function. |
| protected Instances | m_data: Holds the header of the training data. |
| protected boolean | m_dontNormalize: Turn off normalization of the input data. |
| protected boolean | m_dontReplaceMissing: Turn off global replacement of missing values. |
| protected int | m_epochs: The number of epochs to perform (batch learning). |
| protected double | m_epsilon: The epsilon parameter for epsilon insensitive and Huber loss. |
| protected double | m_lambda: The regularization parameter. |
| protected double | m_learningRate: The learning rate. |
| protected int | m_loss: The current loss function to minimize. |
| protected Filter | m_nominalToBinary: Convert nominal attributes to numerically coded binary ones. |
| protected Normalize | m_normalize: Normalize the training data. |
| protected double | m_numInstances: The number of training instances. |
| protected int | m_numModels |
| protected ReplaceMissingValues | m_replaceMissing: Replace missing values. |
| protected double | m_t: Holds the current iteration number. |
| protected double[] | m_weights: Stores the weights (+ bias in the last element). |
| static int | SQUAREDLOSS: The squared loss function. |
| static Tag[] | TAGS_SELECTION: Loss functions to choose from. |
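
The five loss constants can be illustrated with their usual textbook definitions. This is a sketch only: the class name SgdLossSketch, the local constants (assumed to mirror the -F codes 0 to 4), the 0.5 factor on the squared loss, and the use of epsilon as the Huber threshold are assumptions, and the formulas are not taken from the SGD source.

```java
/** Textbook forms of the five selectable losses; an illustration, not Weka's implementation. */
final class SgdLossSketch {

    // Local codes assumed to mirror the documented -F values.
    static final int HINGE = 0;
    static final int LOGLOSS = 1;
    static final int SQUAREDLOSS = 2;
    static final int EPSILON_INSENSITIVE = 3;
    static final int HUBER = 4;

    private SgdLossSketch() {}

    /**
     * Evaluates one loss for a single prediction.
     *
     * @param loss    one of the codes above
     * @param y       target: -1 or +1 for hinge and log loss, a real value for the regression losses
     * @param f       raw model output (weights dot features, plus bias)
     * @param epsilon threshold used by the epsilon insensitive and Huber losses
     */
    static double loss(int loss, double y, double f, double epsilon) {
        double r = Math.abs(y - f); // absolute residual, used by the regression losses
        switch (loss) {
            case HINGE:
                return Math.max(0.0, 1.0 - y * f);
            case LOGLOSS:
                return Math.log1p(Math.exp(-y * f));
            case SQUAREDLOSS:
                return 0.5 * r * r;
            case EPSILON_INSENSITIVE:
                return Math.max(0.0, r - epsilon); // residuals inside the tube cost nothing
            case HUBER:
                // quadratic near zero, linear beyond the threshold
                return r <= epsilon ? 0.5 * r * r : epsilon * (r - 0.5 * epsilon);
            default:
                throw new IllegalArgumentException("unknown loss code: " + loss);
        }
    }
}
```

With these definitions, the epsilon insensitive loss returns zero for any residual no larger than epsilon, which is why the -C threshold only matters for loss codes 3 and 4.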
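
Fields inherited from class RandomizableClassifier: m_Seed
Fields inherited from class AbstractClassifier: BATCH_SIZE_DEFAULT, m_BatchSize, m_Debug, m_DoNotCheckCapabilities, m_numDecimalPlaces, NUM_DECIMAL_PLACES_DEFAULT

| Constructor and Description |
|---|
| SGD() |
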
| Modifier and Type | Method and Description |
|---|---|
| SGD | aggregate(SGD toAggregate): Aggregate an object with this one. |
| void | buildClassifier(Instances data): Method for building the classifier. |
| double[] | distributionForInstance(Instance inst): Computes the distribution for a given instance. |
| protected double | dloss(double z) |
| java.lang.String | dontNormalizeTipText(): Returns the tip text for this property. |
| java.lang.String | dontReplaceMissingTipText(): Returns the tip text for this property. |
| protected static double | dotProd(Instance inst1, double[] weights, int classIndex) |
| java.lang.String | epochsTipText(): Returns the tip text for this property. |
| java.lang.String | epsilonTipText(): Returns the tip text for this property. |
| void | finalizeAggregation(): Call to complete the aggregation process. |
| Capabilities | getCapabilities(): Returns default capabilities of the classifier. |
| boolean | getDontNormalize(): Get whether normalization has been turned off. |
| boolean | getDontReplaceMissing(): Get whether global replacement of missing values has been disabled. |
| int | getEpochs(): Get current number of epochs. |
| double | getEpsilon(): Get the epsilon threshold on the error for epsilon insensitive and Huber loss functions. |
| double | getLambda(): Get the current value of lambda. |
| double | getLearningRate(): Get the learning rate. |
| SelectedTag | getLossFunction(): Get the current loss function. |
| java.lang.String[] | getOptions(): Gets the current settings of the classifier. |
| java.lang.String | getRevision(): Returns the revision string. |
| double[] | getWeights() |
| java.lang.String | globalInfo(): Returns a string describing classifier. |
| java.lang.String | lambdaTipText(): Returns the tip text for this property. |
| java.lang.String | learningRateTipText(): Returns the tip text for this property. |
| java.util.Enumeration<Option> | listOptions(): Returns an enumeration describing the available options. |
| java.lang.String | lossFunctionTipText(): Returns the tip text for this property. |
| static void | main(java.lang.String[] args): Main method for testing this class. |
| void | reset(): Reset the classifier. |
| void | setDontNormalize(boolean m): Turn normalization off/on. |
| void | setDontReplaceMissing(boolean m): Turn global replacement of missing values off/on. |
| void | setEpochs(int e): Set the number of epochs to use. |
| void | setEpsilon(double e): Set the epsilon threshold on the error for epsilon insensitive and Huber loss functions. |
| void | setLambda(double lambda): Set the value of lambda to use. |
| void | setLearningRate(double lr): Set the learning rate. |
| void | setLossFunction(SelectedTag function): Set the loss function to use. |
| void | setOptions(java.lang.String[] options): Parses a given list of options. |
| java.lang.String | toString(): Prints out the classifier. |
| void | updateClassifier(Instance instance): Updates the classifier with the given instance. |
| protected void | updateClassifier(Instance instance, boolean filter): Updates the classifier with the given instance. |
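
As a rough illustration of what an incremental update on a weight vector with the bias in the last element can look like, the sketch below performs one hinge-loss step with L2 shrinkage on a dense array. SgdStepSketch, output, and step are hypothetical names; the constant learning rate, the dense double[] input in place of Instance, and the exact shrinkage rule are simplifying assumptions and may differ from what updateClassifier actually does.

```java
/** Illustrative hinge-loss SGD step on dense arrays; a simplification, not Weka's update rule. */
final class SgdStepSketch {

    private SgdStepSketch() {}

    /** Returns weights dot x plus the bias stored in the last element of weights. */
    static double output(double[] x, double[] weights) {
        double sum = weights[weights.length - 1]; // bias term
        for (int i = 0; i < x.length; i++) {
            sum += x[i] * weights[i];
        }
        return sum;
    }

    /**
     * Applies one update to weights in place.
     *
     * @param x            feature values (class attribute already removed)
     * @param y            class label coded as -1 or +1
     * @param weights      feature weights followed by the bias
     * @param learningRate step size
     * @param lambda       L2 regularization constant
     */
    static void step(double[] x, double y, double[] weights,
                     double learningRate, double lambda) {
        if (weights.length != x.length + 1) {
            throw new IllegalArgumentException("weights must hold one entry per feature plus a bias");
        }
        // The margin is computed before any weight is changed.
        double margin = y * output(x, weights);

        // L2 shrinkage on the feature weights; the bias is left unregularized.
        double shrink = 1.0 - learningRate * lambda;
        for (int i = 0; i < x.length; i++) {
            weights[i] *= shrink;
        }

        // The hinge loss has a non-zero subgradient only when the margin is below 1.
        if (margin < 1.0) {
            for (int i = 0; i < x.length; i++) {
                weights[i] += learningRate * y * x[i];
            }
            weights[x.length] += learningRate * y;
        }
    }
}
```

Starting from all-zero weights, a single call with x = {1, 2}, y = +1, a learning rate of 0.1 and lambda = 0 moves the weights to roughly {0.1, 0.2} with a bias of 0.1, because the initial margin of 0 is below 1.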

Methods inherited from class RandomizableClassifier: getSeed, seedTipText, setSeed
Methods inherited from class AbstractClassifier: batchSizeTipText, classifyInstance, debugTipText, distributionsForInstances, doNotCheckCapabilitiesTipText, forName, getBatchSize, getDebug, getDoNotCheckCapabilities, getNumDecimalPlaces, implementsMoreEfficientBatchPrediction, makeCopies, makeCopy, numDecimalPlacesTipText, postExecution, preExecution, run, runClassifier, setBatchSize, setDebug, setDoNotCheckCapabilities, setNumDecimalPlaces

protected ReplaceMissingValues m_replaceMissing
protected Filter m_nominalToBinary
protected Normalize m_normalize
protected double m_lambda
protected double m_learningRate
protected double[] m_weights
protected double m_epsilon
protected double m_t
protected double m_numInstances
protected int m_epochs
protected boolean m_dontNormalize
protected boolean m_dontReplaceMissing
protected Instances m_data
public static final int HINGE
public static final int LOGLOSS
public static final int SQUAREDLOSS
public static final int EPSILON_INSENSITIVE
public static final int HUBER
protected int m_loss
public static final Tag[] TAGS_SELECTION
protected int m_numModels
public Capabilities getCapabilities()
Specified by: getCapabilities in interface Classifier
Specified by: getCapabilities in interface CapabilitiesHandler
Overrides: getCapabilities in class AbstractClassifier
See Also: Capabilities

public java.lang.String epsilonTipText()

public void setEpsilon(double e)
Parameters: e - the value of epsilon to use

public double getEpsilon()

public java.lang.String lambdaTipText()

public void setLambda(double lambda)
Parameters: lambda - the value of lambda to use

public double getLambda()

public void setLearningRate(double lr)
Parameters: lr - the learning rate to use.

public double getLearningRate()

public java.lang.String learningRateTipText()

public java.lang.String epochsTipText()

public void setEpochs(int e)
Parameters: e - the number of epochs to use

public int getEpochs()

public void setDontNormalize(boolean m)
Parameters: m - true if normalization is to be disabled.

public boolean getDontNormalize()

public java.lang.String dontNormalizeTipText()

public void setDontReplaceMissing(boolean m)
Parameters: m - true if global replacement of missing values is to be turned off.

public boolean getDontReplaceMissing()

public java.lang.String dontReplaceMissingTipText()

public void setLossFunction(SelectedTag function)
Parameters: function - the loss function to use.

public SelectedTag getLossFunction()

public java.lang.String lossFunctionTipText()

public java.util.Enumeration<Option> listOptions()
Specified by: listOptions in interface OptionHandler
Overrides: listOptions in class RandomizableClassifier

public void setOptions(java.lang.String[] options) throws java.lang.Exception
Specified by: setOptions in interface OptionHandler
Overrides: setOptions in class RandomizableClassifier
Parameters: options - the list of options as an array of strings
Throws: java.lang.Exception - if an option is not supported

public java.lang.String[] getOptions()
Specified by: getOptions in interface OptionHandler
Overrides: getOptions in class RandomizableClassifier

public java.lang.String globalInfo()

public void reset()

public void buildClassifier(Instances data) throws java.lang.Exception
Specified by: buildClassifier in interface Classifier
Parameters: data - the set of training instances.
Throws: java.lang.Exception - if the classifier can't be built successfully.

protected double dloss(double z)

protected static double dotProd(Instance inst1, double[] weights, int classIndex) throws java.lang.InterruptedException
Throws: java.lang.InterruptedException

protected void updateClassifier(Instance instance, boolean filter) throws java.lang.Exception
Parameters:
instance - the new training instance to include in the model
filter - true if the instance should pass through any of the filters set up in buildClassifier(). When batch training, buildClassifier() already batch filters all training instances, so they do not need to be filtered again here.
Throws: java.lang.Exception - if the instance could not be incorporated in the model.

public void updateClassifier(Instance instance) throws java.lang.Exception
Specified by: updateClassifier in interface UpdateableClassifier
Parameters: instance - the new training instance to include in the model
Throws: java.lang.Exception - if the instance could not be incorporated in the model.

public double[] distributionForInstance(Instance inst) throws java.lang.Exception
Specified by: distributionForInstance in interface Classifier
Overrides: distributionForInstance in class AbstractClassifier
Parameters: inst - the instance for which distribution is computed
Throws: java.lang.Exception - if the distribution can't be computed successfully

public double[] getWeights()

public java.lang.String toString()
Overrides: toString in class java.lang.Object

public java.lang.String getRevision()
Specified by: getRevision in interface RevisionHandler
Overrides: getRevision in class AbstractClassifier

public SGD aggregate(SGD toAggregate) throws java.lang.Exception
Specified by: aggregate in interface Aggregateable<SGD>
Parameters: toAggregate - the object to aggregate
Throws: java.lang.Exception - if the supplied object can't be aggregated for some reason

public void finalizeAggregation() throws java.lang.Exception
Specified by: finalizeAggregation in interface Aggregateable<SGD>
Throws: java.lang.Exception - if the aggregation can't be finalized for some reason

public static void main(java.lang.String[] args)
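
One common way to combine linear models trained on separate data partitions is to average their weight vectors. The sketch below shows that idea with an aggregate / finalizeAggregation pair of the same shape as the two methods documented above. This page does not state which rule SGD uses, so the averaging itself, the class name WeightAverageSketch, and the plain double[] inputs are all assumptions.

```java
/** Hypothetical weight-averaging aggregator; the documented SGD rule may differ. */
final class WeightAverageSketch {

    private final double[] sum;
    private int numModels = 1; // counts the model this aggregator starts from
    private boolean finalized;

    WeightAverageSketch(double[] initialWeights) {
        this.sum = initialWeights.clone(); // defensive copy
    }

    /** Adds another model's weights to the running sum and returns this aggregator. */
    WeightAverageSketch aggregate(double[] otherWeights) {
        if (finalized) {
            throw new IllegalStateException("aggregation has already been finalized");
        }
        if (otherWeights.length != sum.length) {
            throw new IllegalArgumentException("weight vectors differ in length");
        }
        for (int i = 0; i < sum.length; i++) {
            sum[i] += otherWeights[i];
        }
        numModels++;
        return this;
    }

    /** Divides the running sum by the number of models and returns the averaged weights. */
    double[] finalizeAggregation() {
        if (finalized) {
            throw new IllegalStateException("aggregation has already been finalized");
        }
        for (int i = 0; i < sum.length; i++) {
            sum[i] /= numModels;
        }
        finalized = true;
        return sum.clone();
    }
}
```

Splitting the work into an accumulate step and a finalize step means the division happens once, after every partial model has been added, regardless of the order in which the models arrive.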