/*
 * Decompiled with CFR 0.152.
 */
package org.nd4j.linalg.dataset.callbacks;

import org.nd4j.linalg.api.concurrency.AffinityManager;
import org.nd4j.linalg.dataset.api.DataSet;
import org.nd4j.linalg.dataset.api.MultiDataSet;
import org.nd4j.linalg.dataset.callbacks.DataSetCallback;
import org.nd4j.linalg.factory.Nd4j;

/**
 * Default {@link DataSetCallback} that asks the current ND4J {@link AffinityManager} to ensure
 * every array carried by a dataset is at {@link AffinityManager.Location#DEVICE}.
 *
 * <p>Null datasets, null array groups, and null individual arrays are skipped. The callback can
 * therefore be applied to datasets that carry no labels or no masks.
 *
 * <p>This class holds no state; {@link #reset()} is a no-op.
 */
public class DefaultCallback
implements DataSetCallback {

    /**
     * Requests that the features, labels, and (when present) feature/label mask arrays of a
     * single {@link DataSet} be located on the device.
     *
     * @param dataSet dataset to process; ignored when {@code null}
     */
    @Override
    public void call(DataSet dataSet) {
        if (dataSet == null) {
            return;
        }
        AffinityManager manager = Nd4j.getAffinityManager();
        if (dataSet.getFeatures() != null) {
            manager.ensureLocation(dataSet.getFeatures(), AffinityManager.Location.DEVICE);
        }
        if (dataSet.getLabels() != null) {
            manager.ensureLocation(dataSet.getLabels(), AffinityManager.Location.DEVICE);
        }
        if (dataSet.getFeaturesMaskArray() != null) {
            manager.ensureLocation(dataSet.getFeaturesMaskArray(), AffinityManager.Location.DEVICE);
        }
        if (dataSet.getLabelsMaskArray() != null) {
            manager.ensureLocation(dataSet.getLabelsMaskArray(), AffinityManager.Location.DEVICE);
        }
    }

    /**
     * Requests that every features, labels, features-mask, and labels-mask array of a
     * {@link MultiDataSet} be located on the device.
     *
     * <p>Both the array groups and the individual entries inside each group are null-checked. An
     * entry may be null, for example for an input that has no mask. This mirrors the per-array
     * null checks in {@link #call(DataSet)}, so {@code ensureLocation} is never handed {@code null}.
     *
     * @param multiDataSet dataset to process; ignored when {@code null}
     */
    @Override
    public void call(MultiDataSet multiDataSet) {
        if (multiDataSet == null) {
            return;
        }
        AffinityManager manager = Nd4j.getAffinityManager();
        if (multiDataSet.getFeatures() != null) {
            for (int i = 0; i < multiDataSet.getFeatures().length; ++i) {
                if (multiDataSet.getFeatures()[i] != null) {
                    manager.ensureLocation(multiDataSet.getFeatures()[i], AffinityManager.Location.DEVICE);
                }
            }
        }
        if (multiDataSet.getLabels() != null) {
            for (int i = 0; i < multiDataSet.getLabels().length; ++i) {
                if (multiDataSet.getLabels()[i] != null) {
                    manager.ensureLocation(multiDataSet.getLabels()[i], AffinityManager.Location.DEVICE);
                }
            }
        }
        if (multiDataSet.getFeaturesMaskArrays() != null) {
            for (int i = 0; i < multiDataSet.getFeaturesMaskArrays().length; ++i) {
                if (multiDataSet.getFeaturesMaskArrays()[i] != null) {
                    manager.ensureLocation(multiDataSet.getFeaturesMaskArrays()[i], AffinityManager.Location.DEVICE);
                }
            }
        }
        if (multiDataSet.getLabelsMaskArrays() != null) {
            for (int i = 0; i < multiDataSet.getLabelsMaskArrays().length; ++i) {
                if (multiDataSet.getLabelsMaskArrays()[i] != null) {
                    manager.ensureLocation(multiDataSet.getLabelsMaskArrays()[i], AffinityManager.Location.DEVICE);
                }
            }
        }
    }

    /**
     * No-op: this callback keeps no state between calls, so there is nothing to reset.
     */
    @Override
    public void reset() {
    }
}

