Package ai.djl.training.dataset
Class ArrayDataset
java.lang.Object
  ai.djl.training.dataset.RandomAccessDataset
    ai.djl.training.dataset.ArrayDataset

All Implemented Interfaces:
Dataset

public class ArrayDataset extends RandomAccessDataset
ArrayDataset is an implementation of RandomAccessDataset that consists entirely of large NDArrays. It is recommended only for datasets small enough to fit in memory that come in array formats. Otherwise, consider directly using the RandomAccessDataset instead.

There can be multiple data and label NDArrays within the dataset. Each sample will be retrieved by indexing each NDArray along the first dimension.

The following is an example of how to use ArrayDataset:

    ArrayDataset dataset = new ArrayDataset.Builder()
        .setData(data1, data2)
        .optLabels(labels1, labels2, labels3)
        .setSampling(20, false)
        .build();

Suppose you get a Batch from trainer.iterateDataset(dataset) or dataset.getData(manager). The data of this batch is an NDList with one NDArray for each data input; in this case, 2 arrays. Similarly, the labels would have 3 arrays.

See Also:
Dataset

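The first-dimension indexing described above can be modeled with plain Java arrays. This is only a sketch of the idea, not DJL code: `FirstDimIndexing` and `record` are hypothetical names introduced for illustration, and nested `float` arrays stand in for NDArrays.

```java
import java.util.ArrayList;
import java.util.Arrays;
import java.util.List;

/** Plain-Java model of first-dimension indexing; not part of DJL. */
public class FirstDimIndexing {

    /**
     * Returns one row per input array, each taken at the same position
     * along the first dimension (hypothetical helper for illustration).
     */
    public static List<float[]> record(int index, float[][]... arrays) {
        List<float[]> rows = new ArrayList<>();
        for (float[][] array : arrays) {
            rows.add(array[index]); // select along the first dimension
        }
        return rows;
    }

    public static void main(String[] args) {
        float[][] data1 = {{1f, 2f}, {3f, 4f}, {5f, 6f}}; // 3 samples, 2 features each
        float[][] data2 = {{10f}, {20f}, {30f}};          // 3 samples, 1 feature each

        // Sample 1 contains one row from each data array.
        List<float[]> sample = record(1, data1, data2);
        System.out.println(sample.size());
        System.out.println(Arrays.toString(sample.get(0)));
        System.out.println(Arrays.toString(sample.get(1)));
    }
}
```

The same reasoning explains the array counts in the example: two data inputs yield two entries in the data list, and three label inputs yield three entries in the label list.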
Nested Class Summary
Nested Classes
Modifier and Type | Class | Description
static class | ArrayDataset.Builder | The Builder to construct an ArrayDataset.

Nested classes/interfaces inherited from class ai.djl.training.dataset.RandomAccessDataset
RandomAccessDataset.BaseBuilder<T extends RandomAccessDataset.BaseBuilder<T>>

Nested classes/interfaces inherited from interface ai.djl.training.dataset.Dataset
Dataset.Usage

Field Summary
Fields
Modifier and Type | Field
protected NDArray[] | data
protected NDArray[] | labels

Fields inherited from class ai.djl.training.dataset.RandomAccessDataset
dataBatchifier, device, labelBatchifier, limit, pipeline, prefetchNumber, sampler, targetPipeline

Constructor Summary
Constructors
Constructor | Description
ArrayDataset(RandomAccessDataset.BaseBuilder<?> builder) | Creates a new instance of ArrayDataset with the arguments in ArrayDataset.Builder.

Method Summary
Modifier and Type | Method | Description
protected long | availableSize() | Returns the number of records available to be read in this Dataset.
Record | get(NDManager manager, long index) | Gets the Record for the given index from the dataset.
Batch | getByIndices(NDManager manager, long... indices) | Gets the Batch for the given indices from the dataset.
Batch | getByRange(NDManager manager, long fromIndex, long toIndex) | Gets the Batch for the given range from the dataset.
java.lang.Iterable<Batch> | getData(NDManager manager, Sampler sampler, java.util.concurrent.ExecutorService executorService) | Fetches an iterator that can iterate through the Dataset with a custom sampler, using multiple threads.
protected RandomAccessDataset | newSubDataset(int[] indices, int from, int to) |
protected RandomAccessDataset | newSubDataset(java.util.List<java.lang.Long> subIndices) |
void | prepare(ai.djl.util.Progress progress) | Prepares the dataset for use with tracked progress.

Methods inherited from class ai.djl.training.dataset.RandomAccessDataset
getData, getData, getData, randomSplit, size, subDataset, subDataset, subDataset, subDataset, toArray

Methods inherited from class java.lang.Object
clone, equals, finalize, getClass, hashCode, notify, notifyAll, toString, wait, wait, wait

Methods inherited from interface ai.djl.training.dataset.Dataset
matchingTranslatorOptions, prepare

Constructor Detail

ArrayDataset
public ArrayDataset(RandomAccessDataset.BaseBuilder<?> builder)
Creates a new instance of ArrayDataset with the arguments in ArrayDataset.Builder.
Parameters:
builder - a builder with the required arguments

Method Detail

availableSize
protected long availableSize()
Returns the number of records available to be read in this Dataset.
Specified by:
availableSize in class RandomAccessDataset
Returns:
the number of records available to be read in this Dataset

get
public Record get(NDManager manager, long index)
Gets the Record for the given index from the dataset.
Specified by:
get in class RandomAccessDataset
Parameters:
manager - the manager used to create the arrays
index - the index of the requested data item
Returns:
a Record that contains the data and label of the requested data item

getByIndices
public Batch getByIndices(NDManager manager, long... indices)
Gets the Batch for the given indices from the dataset.
Parameters:
manager - the manager used to create the arrays
indices - indices of the requested data items
Returns:
a Batch that contains the data and label of the requested data items

getByRange
public Batch getByRange(NDManager manager, long fromIndex, long toIndex)
Gets the Batch for the given range from the dataset.
Parameters:
manager - the manager used to create the arrays
fromIndex - low endpoint (inclusive) of the dataset
toIndex - high endpoint (exclusive) of the dataset
Returns:
a Batch that contains the data and label of the requested data items
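
The difference between selecting by explicit indices and selecting by a half-open range (fromIndex inclusive, toIndex exclusive) can be illustrated with plain Java arrays. This is a sketch under stated assumptions, not DJL's implementation: `IndexSelection`, `byIndices`, and `byRange` are hypothetical names, and nested `float` arrays stand in for an NDArray.

```java
import java.util.Arrays;

/** Plain-Java model of the two selection styles; not part of DJL. */
public class IndexSelection {

    /** Picks rows at arbitrary positions, in the order the indices are given. */
    public static float[][] byIndices(float[][] data, long... indices) {
        float[][] out = new float[indices.length][];
        for (int i = 0; i < indices.length; i++) {
            out[i] = data[(int) indices[i]];
        }
        return out;
    }

    /** Picks rows from fromIndex (inclusive) up to toIndex (exclusive). */
    public static float[][] byRange(float[][] data, long fromIndex, long toIndex) {
        return Arrays.copyOfRange(data, (int) fromIndex, (int) toIndex);
    }

    public static void main(String[] args) {
        float[][] data = {{0f}, {1f}, {2f}, {3f}, {4f}};

        // The range [1, 3) covers rows 1 and 2; row 3 is excluded.
        System.out.println(Arrays.deepToString(byRange(data, 1, 3)));

        // Explicit indices may be non-contiguous and in any order.
        System.out.println(Arrays.deepToString(byIndices(data, 4, 0)));
    }
}
```

A half-open range of length toIndex - fromIndex makes consecutive ranges such as [0, 20) and [20, 40) tile a dataset without overlap.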

newSubDataset
protected RandomAccessDataset newSubDataset(int[] indices, int from, int to)
Overrides:
newSubDataset in class RandomAccessDataset

newSubDataset
protected RandomAccessDataset newSubDataset(java.util.List<java.lang.Long> subIndices)
Overrides:
newSubDataset in class RandomAccessDataset

getData
public java.lang.Iterable<Batch> getData(NDManager manager, Sampler sampler, java.util.concurrent.ExecutorService executorService) throws java.io.IOException, TranslateException
Fetches an iterator that can iterate through the Dataset with a custom sampler, using multiple threads.
Overrides:
getData in class RandomAccessDataset
Parameters:
manager - the manager to create the arrays
sampler - the sampler to use to iterate through the dataset
executorService - the executorService to multi-thread with
Returns:
an Iterable of Batch that contains batches of data from the dataset
Throws:
java.io.IOException - for various exceptions depending on the dataset
TranslateException - if there is an error while processing input
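
As a rough illustration of iterating a dataset in batches with an executor, the sketch below splits an index space into consecutive ranges and loads each range on a thread pool, returning results in the order the ranges were produced. It is a plain-Java model written under assumptions, not DJL's implementation: `ParallelBatches`, `sequentialRanges`, and `fetch` are hypothetical names, a `float[]` stands in for the dataset, and the real method returns an Iterable of Batch driven by a Sampler.

```java
import java.util.ArrayList;
import java.util.Arrays;
import java.util.List;
import java.util.concurrent.Callable;
import java.util.concurrent.ExecutionException;
import java.util.concurrent.ExecutorService;
import java.util.concurrent.Executors;
import java.util.concurrent.Future;

/** Plain-Java model of batched, multi-threaded loading; not part of DJL. */
public class ParallelBatches {

    /** Splits [0, size) into consecutive {from, to} ranges of at most batchSize items. */
    public static List<long[]> sequentialRanges(long size, long batchSize) {
        List<long[]> ranges = new ArrayList<>();
        for (long from = 0; from < size; from += batchSize) {
            ranges.add(new long[] {from, Math.min(from + batchSize, size)});
        }
        return ranges;
    }

    /** Loads every range on the executor and returns the batches in range order. */
    public static List<float[]> fetch(float[] data, List<long[]> ranges, ExecutorService executor) {
        List<Future<float[]>> futures = new ArrayList<>();
        for (long[] range : ranges) {
            Callable<float[]> task =
                    () -> Arrays.copyOfRange(data, (int) range[0], (int) range[1]);
            futures.add(executor.submit(task));
        }
        List<float[]> batches = new ArrayList<>();
        try {
            for (Future<float[]> future : futures) {
                batches.add(future.get()); // blocks until this batch is ready
            }
        } catch (InterruptedException e) {
            Thread.currentThread().interrupt();
            throw new IllegalStateException("Interrupted while loading batches", e);
        } catch (ExecutionException e) {
            throw new IllegalStateException("Batch loading failed", e);
        }
        return batches;
    }

    public static void main(String[] args) {
        ExecutorService executor = Executors.newFixedThreadPool(2);
        try {
            float[] data = {0f, 1f, 2f, 3f, 4f};
            for (float[] batch : fetch(data, sequentialRanges(data.length, 2), executor)) {
                System.out.println(Arrays.toString(batch));
            }
        } finally {
            executor.shutdown();
        }
    }
}
```

Collecting the futures in submission order keeps the batch order deterministic even though the loads may complete out of order.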

prepare
public void prepare(ai.djl.util.Progress progress) throws java.io.IOException
Prepares the dataset for use with tracked progress.
Parameters:
progress - the progress tracker
Throws:
java.io.IOException - for various exceptions depending on the dataset