Package ai.djl.training.optimizer
Class Optimizer
- java.lang.Object
  - ai.djl.training.optimizer.Optimizer
public abstract class Optimizer
extends java.lang.Object

An Optimizer updates the weight parameters to minimize the loss function. Optimizer is an abstract class that provides the base implementation for optimizers.
-
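Concrete optimizers are obtained through the static factory methods listed below, each of which returns a builder. As a rough illustration of that pattern (the ToySgd class, its builder, and the float[] arrays are stand-ins invented here, not DJL's NDArray-based API), a minimal builder-constructed optimizer might look like:

```java
// Illustrative sketch only -- ToySgd is a stand-in, not a DJL class.
// It mirrors the shape of Optimizer: a private builder-based constructor,
// a static factory for the builder, and an update(parameterId, weight, grad)
// method, with float[] standing in for NDArray.
class ToySgd {
    private final float learningRate;

    private ToySgd(Builder builder) {
        this.learningRate = builder.learningRate;
    }

    // Analogous to Optimizer.sgd(): returns a builder for this optimizer.
    static Builder builder() {
        return new Builder();
    }

    // Analogous to the abstract update(parameterId, weight, grad):
    // performs weight -= learningRate * grad, element-wise.
    void update(String parameterId, float[] weight, float[] grad) {
        for (int i = 0; i < weight.length; i++) {
            weight[i] -= learningRate * grad[i];
        }
    }

    static final class Builder {
        private float learningRate = 0.01f; // invented default, for the sketch

        Builder setLearningRate(float learningRate) {
            this.learningRate = learningRate;
            return this;
        }

        ToySgd build() {
            return new ToySgd(this);
        }
    }
}
```

Construction then follows the same fluent shape as the real factory methods, e.g. `ToySgd.builder().setLearningRate(0.1f).build()`.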
Nested Class Summary
Nested Classes

static class Optimizer.OptimizerBuilder<T extends Optimizer.OptimizerBuilder>
    The Builder to construct an Optimizer.
-
Field Summary
Fields

protected float clipGrad
protected float rescaleGrad
-
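The clipGrad and rescaleGrad fields are available to subclasses. One common interpretation of such fields (a sketch under assumptions; this page does not spell out the exact semantics, and the GradPrep helper is invented here) is to rescale each raw gradient, e.g. by 1/batchSize, and then clamp it to [-clipGrad, clipGrad] when clipGrad is positive:

```java
// Hypothetical helper showing one common reading of rescaleGrad and
// clipGrad; DJL's exact behavior may differ.
class GradPrep {
    static float prepare(float rawGrad, float rescaleGrad, float clipGrad) {
        float g = rawGrad * rescaleGrad;   // e.g. rescaleGrad = 1f / batchSize
        if (clipGrad > 0) {                // assume non-positive clipGrad disables clipping
            g = Math.max(-clipGrad, Math.min(clipGrad, g));
        }
        return g;
    }
}
```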
Constructor Summary
Constructors

Optimizer(Optimizer.OptimizerBuilder<?> builder)
    Creates a new instance of Optimizer.
-
Method Summary
static Adadelta.Builder adadelta()
    Returns a new instance of Adadelta.Builder that can build an Adadelta optimizer.
static Adagrad.Builder adagrad()
    Returns a new instance of Adagrad.Builder that can build an Adagrad optimizer.
static Adam.Builder adam()
    Returns a new instance of Adam.Builder that can build an Adam optimizer.
static AdamW.Builder adamW()
    Returns a new instance of AdamW.Builder that can build an AdamW optimizer.
protected float getWeightDecay()
    Gets the value of weight decay.
static Nag.Builder nag()
    Returns a new instance of Nag.Builder that can build a Nag optimizer.
static RmsProp.Builder rmsprop()
    Returns a new instance of RmsProp.Builder that can build an RmsProp optimizer.
static Sgd.Builder sgd()
    Returns a new instance of Sgd.Builder that can build an Sgd optimizer.
abstract void update(java.lang.String parameterId, NDArray weight, NDArray grad)
    Updates the parameters according to the gradients.
protected int updateCount(java.lang.String parameterId)
protected NDArray withDefaultState(java.util.Map<java.lang.String,java.util.Map<Device,NDArray>> state, java.lang.String key, Device device, java.util.function.Function<java.lang.String,NDArray> defaultFunction)
-
Constructor Detail
-
Optimizer
public Optimizer(Optimizer.OptimizerBuilder<?> builder)
Creates a new instance of Optimizer.
Parameters:
    builder - the builder used to create an instance of Optimizer
-
Method Detail
-
sgd
public static Sgd.Builder sgd()
Returns a new instance of Sgd.Builder that can build an Sgd optimizer.
Returns:
    the Sgd.Builder
-
nag
public static Nag.Builder nag()
Returns a new instance of Nag.Builder that can build a Nag optimizer.
Returns:
    the Nag.Builder
-
adam
public static Adam.Builder adam()
Returns a new instance of Adam.Builder that can build an Adam optimizer.
Returns:
    the Adam.Builder
-
adamW
public static AdamW.Builder adamW()
Returns a new instance of AdamW.Builder that can build an AdamW optimizer.
Returns:
    the AdamW.Builder
-
rmsprop
public static RmsProp.Builder rmsprop()
Returns a new instance of RmsProp.Builder that can build an RmsProp optimizer.
Returns:
    the RmsProp.Builder
-
adagrad
public static Adagrad.Builder adagrad()
Returns a new instance of Adagrad.Builder that can build an Adagrad optimizer.
Returns:
    the Adagrad.Builder
-
adadelta
public static Adadelta.Builder adadelta()
Returns a new instance of Adadelta.Builder that can build an Adadelta optimizer.
Returns:
    the Adadelta.Builder
-
getWeightDecay
protected float getWeightDecay()
Gets the value of weight decay.
Returns:
    the value of weight decay
-
updateCount
protected int updateCount(java.lang.String parameterId)
-
update
public abstract void update(java.lang.String parameterId, NDArray weight, NDArray grad)
Updates the parameters according to the gradients.
Parameters:
    parameterId - the parameter to be updated
    weight - the weights of the parameter
    grad - the gradients
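Subclasses implement update(...) per parameter, and helpers such as updateCount(parameterId) and getWeightDecay() suggest per-parameter state keyed by the id. The sketch below (invented names, float[] standing in for NDArray; not DJL code) shows an SGD-style update that applies weight decay as L2 regularization and tracks an update count per parameterId:

```java
import java.util.HashMap;
import java.util.Map;

// Illustrative stand-in for a concrete Optimizer subclass (not DJL code):
// keeps a per-parameterId update count and applies weight decay before an
// SGD step.
class ToyUpdater {
    private final Map<String, Integer> updateCounts = new HashMap<>();
    private final float learningRate;
    private final float weightDecay;

    ToyUpdater(float learningRate, float weightDecay) {
        this.learningRate = learningRate;
        this.weightDecay = weightDecay;
    }

    // In the spirit of updateCount(parameterId): increments and returns
    // how many times this parameter has been updated.
    int updateCount(String parameterId) {
        return updateCounts.merge(parameterId, 1, Integer::sum);
    }

    // In the spirit of update(parameterId, weight, grad): weight decay
    // adds weightDecay * weight to the gradient (L2 regularization),
    // then an SGD step is taken.
    void update(String parameterId, float[] weight, float[] grad) {
        updateCount(parameterId);
        for (int i = 0; i < weight.length; i++) {
            float g = grad[i] + weightDecay * weight[i];
            weight[i] -= learningRate * g;
        }
    }
}
```

Keying state by parameterId is what lets one optimizer instance serve every parameter in a model while still maintaining independent per-parameter state.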