public class Prelu extends AbstractBlock
Leaky ReLUs attempt to fix the 'dying ReLU' problem. They use a small slope when the input is negative and a slope of one when the input is positive. This is defined by \(y = x\) if \(x > 0\), and \(y = \text{slope} \cdot x\) otherwise.
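The rule above can be sketched in plain Java on `float` arrays. This is a hypothetical illustration, not DJL's implementation, which operates on `NDArray`. The class name `LeakyReluSketch` is made up for this example, and the caller chooses the fixed slope.

```java
/**
 * Hypothetical plain-Java sketch of Leaky ReLU on float arrays (not DJL code).
 * The slope is a fixed constant chosen by the caller.
 */
final class LeakyReluSketch {
    private LeakyReluSketch() {}

    /** y = x if x > 0, otherwise slope * x. */
    static float apply(float x, float slope) {
        return x > 0 ? x : slope * x;
    }

    /** Applies the same rule element-wise and returns a new array. */
    static float[] apply(float[] input, float slope) {
        float[] out = new float[input.length];
        for (int i = 0; i < input.length; i++) {
            out[i] = apply(input[i], slope);
        }
        return out;
    }
}
```

For example, `LeakyReluSketch.apply(new float[] {-2f, 1f}, 0.25f)` returns `{-0.5f, 1f}`. The negative input is scaled by the slope and the positive input passes through unchanged.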
Parametric ReLU is a Leaky ReLU in which the slope is learnt during training.
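The slope can be learned because the output is differentiable with respect to it. For an element with \(x \le 0\), \(\partial y / \partial \text{slope} = x\). For \(x > 0\), this derivative is \(0\). The sketch below assumes a single scalar slope `alpha` that is updated by one plain gradient-descent step. The class `PreluSlopeSketch`, the initial value, and the learning rate are hypothetical, and this is not DJL's training loop.

```java
/**
 * Hypothetical sketch: PReLU with one scalar slope alpha, updated by plain
 * gradient descent. Not DJL's training code.
 */
final class PreluSlopeSketch {
    private float alpha;

    PreluSlopeSketch(float initialAlpha) {
        this.alpha = initialAlpha;
    }

    float alpha() {
        return alpha;
    }

    /** Forward pass: y_i = x_i if x_i > 0, otherwise alpha * x_i. */
    float[] forward(float[] x) {
        float[] y = new float[x.length];
        for (int i = 0; i < x.length; i++) {
            y[i] = x[i] > 0 ? x[i] : alpha * x[i];
        }
        return y;
    }

    /** dL/dalpha = sum over i of dL/dy_i * x_i, counting only elements with x_i <= 0. */
    float gradAlpha(float[] x, float[] gradY) {
        float g = 0f;
        for (int i = 0; i < x.length; i++) {
            if (x[i] <= 0) {
                g += gradY[i] * x[i];
            }
        }
        return g;
    }

    /** One SGD step on the slope: alpha <- alpha - learningRate * dL/dalpha. */
    void sgdStep(float[] x, float[] gradY, float learningRate) {
        alpha -= learningRate * gradAlpha(x, gradY);
    }
}
```

Suppose `alpha = 0.25f`, the input is `{-4f, 2f}`, and the upstream gradient is `{1f, 1f}`. The gradient with respect to `alpha` is `-4f`, so a step with learning rate `0.1f` moves `alpha` to roughly `0.65f`. Only the negative input contributes to that gradient.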
Fields inherited from class AbstractBlock: children, inputNames, inputShapes, parameters, version

| Constructor and Description |
|---|
| `Prelu()` Creates a Parametric ReLU Block. |
| Modifier and Type | Method and Description |
|---|---|
| `protected NDList` | `forwardInternal(ParameterStore parameterStore, NDList inputs, boolean training, ai.djl.util.PairList<java.lang.String,java.lang.Object> params)` A helper for `Block.forward(ParameterStore, NDList, boolean, PairList)` after initialization. |
| `Shape[]` | `getOutputShapes(Shape[] inputs)` Returns the expected output shapes of the block for the specified input shapes. |
| `void` | `loadMetadata(byte loadVersion, java.io.DataInputStream is)` Override this to load additional metadata with the parameter values. |
| `static NDList` | `prelu(NDArray input, NDArray alpha)` Applies a Prelu activation on the input `NDArray`. |
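The method summary can be modeled without DJL. In this model, `forward` runs initialization once and then delegates to `forwardInternal`, and `prelu` is a stateless helper. Because PReLU is applied element-wise, each output shape is expected to equal its input shape. The sketch below is a hypothetical illustration of that structure. `ToyBlock`, `ToyPrelu`, the use of `float[]` in place of `NDList`/`NDArray` and `int[]` in place of `Shape`, and the scalar `alpha` initialized to `0.25f` are all assumptions. None of them is DJL's API or default.

```java
import java.util.Arrays;

/**
 * Hypothetical, dependency-free model of the call structure in the method
 * summary. Not DJL code: float[] stands in for NDList/NDArray, int[] for Shape.
 */
abstract class ToyBlock {
    private boolean initialized;

    /** Entry point: initialize parameters once, then delegate to forwardInternal. */
    final float[] forward(float[] inputs, boolean training) {
        if (!initialized) {
            initializeParameters();
            initialized = true;
        }
        return forwardInternal(inputs, training);
    }

    protected abstract void initializeParameters();

    protected abstract float[] forwardInternal(float[] inputs, boolean training);

    abstract int[][] getOutputShapes(int[][] inputShapes);
}

final class ToyPrelu extends ToyBlock {
    private float alpha;

    @Override
    protected void initializeParameters() {
        alpha = 0.25f; // assumed initial slope, for illustration only
    }

    @Override
    protected float[] forwardInternal(float[] inputs, boolean training) {
        return prelu(inputs, alpha);
    }

    /** Element-wise activation: each output shape is a copy of its input shape. */
    @Override
    int[][] getOutputShapes(int[][] inputShapes) {
        int[][] out = new int[inputShapes.length][];
        for (int i = 0; i < inputShapes.length; i++) {
            out[i] = Arrays.copyOf(inputShapes[i], inputShapes[i].length);
        }
        return out;
    }

    /** Stateless helper mirroring the static prelu(input, alpha) entry. */
    static float[] prelu(float[] input, float alpha) {
        float[] out = new float[input.length];
        for (int i = 0; i < input.length; i++) {
            out[i] = input[i] > 0 ? input[i] : alpha * input[i];
        }
        return out;
    }
}
```

Keeping `prelu` static means callers can apply the activation to existing arrays without creating a block. The block itself only adds parameter ownership and initialization.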
Methods inherited from class AbstractBlock: addChildBlock, addParameter, beforeInitialize, cast, clear, describeInput, forward, forward, forwardInternal, getChildren, getDirectParameters, getParameters, initialize, initializeChildBlocks, isInitialized, loadParameters, prepare, readInputShapes, saveInputShapes, saveMetadata, saveParameters, setInitializer, setInitializer, setInitializer, toString

Methods inherited from class java.lang.Object: clone, equals, finalize, getClass, hashCode, notify, notifyAll, wait, wait, wait

Methods inherited from interface Block: forward, validateLayout

protected NDList forwardInternal(ParameterStore parameterStore, NDList inputs, boolean training, ai.djl.util.PairList<java.lang.String,java.lang.Object> params)
A helper for Block.forward(ParameterStore, NDList, boolean, PairList) after initialization.
Specified by: forwardInternal in class AbstractBlock
Parameters:
parameterStore - the parameter store
inputs - the input NDList
training - true for a training forward pass
params - optional parameters

public Shape[] getOutputShapes(Shape[] inputs)
Returns the expected output shapes of the block for the specified input shapes.
Parameters:
inputs - the shapes of the inputs

public void loadMetadata(byte loadVersion, java.io.DataInputStream is) throws java.io.IOException, MalformedModelException
If you override AbstractBlock.saveMetadata(DataOutputStream) or need to provide backward compatibility with older binary formats, you probably need to override this method. The default implementation checks whether the version number is compatible and throws a MalformedModelException if it is not. It then restores the input shapes.
Overrides: loadMetadata in class AbstractBlock
Parameters:
loadVersion - the version used for loading this metadata
is - the input stream we are loading from
Throws:
java.io.IOException - loading failed
MalformedModelException - the data can be loaded but has the wrong format
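The default behavior described above can be sketched with `java.io`. First the sketch checks the version, and then it reads the input shapes back. This is a hypothetical model, not DJL's code. `MetadataSketch`, `MalformedMetadataException` (a stand-in for `MalformedModelException`), `SUPPORTED_VERSION`, and the binary layout are all assumptions. The assumed layout is a shape count, followed by the rank and then the `long` dimensions for each shape.

```java
import java.io.DataInputStream;
import java.io.IOException;

/**
 * Hypothetical sketch of a versioned metadata loader: reject unsupported
 * versions, then restore input shapes. The binary layout is assumed.
 */
final class MetadataSketch {
    /** Stand-in for DJL's MalformedModelException (hypothetical name). */
    static final class MalformedMetadataException extends Exception {
        MalformedMetadataException(String message) {
            super(message);
        }
    }

    static final byte SUPPORTED_VERSION = 1; // assumed current format version

    private long[][] inputShapes;

    void loadMetadata(byte loadVersion, DataInputStream is)
            throws IOException, MalformedMetadataException {
        // Step 1: refuse data written in a format this class cannot read.
        if (loadVersion != SUPPORTED_VERSION) {
            throw new MalformedMetadataException("Unsupported version: " + loadVersion);
        }
        // Step 2: restore the input shapes (count, then rank + dims per shape).
        int count = is.readInt();
        long[][] shapes = new long[count][];
        for (int i = 0; i < count; i++) {
            int rank = is.readInt();
            shapes[i] = new long[rank];
            for (int d = 0; d < rank; d++) {
                shapes[i][d] = is.readLong();
            }
        }
        inputShapes = shapes;
    }

    long[][] inputShapes() {
        return inputShapes;
    }
}
```

A subclass that writes extra fields in its own save method would read them after the shapes in the same order. This ordering requirement is why overriding the save method usually means overriding the load method as well.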