package org.deeplearning4j.datasets.iterator.callbacks;

import org.nd4j.linalg.api.concurrency.AffinityManager;
import org.nd4j.linalg.dataset.api.DataSet;
import org.nd4j.linalg.dataset.api.MultiDataSet;
import org.nd4j.linalg.factory.Nd4j;

/* loaded from: input_file:org/deeplearning4j/datasets/iterator/callbacks/DefaultCallback.class */
/**
 * Default {@link DataSetCallback} implementation.
 *
 * <p>For every array exposed by an incoming {@code DataSet} or {@code MultiDataSet}
 * (features, labels, feature masks, label masks) this callback asks the ND4J
 * affinity manager to ensure the array is located at
 * {@link AffinityManager.Location#DEVICE}.
 *
 * <p>Anything that is absent is silently skipped: a null data set, a null array
 * (or array group), and a null entry inside a {@code MultiDataSet} array group.
 * The class declares no fields, so {@link #reset()} has nothing to do.
 */
public class DefaultCallback implements DataSetCallback {

    /**
     * Ensures each non-null array of the given data set is located at
     * {@link AffinityManager.Location#DEVICE}.
     *
     * @param dataSet data set to process; may be null, in which case nothing happens
     */
    @Override
    public void call(DataSet dataSet) {
        if (dataSet == null) {
            return;
        }

        if (dataSet.getFeatures() != null) {
            Nd4j.getAffinityManager().ensureLocation(dataSet.getFeatures(), AffinityManager.Location.DEVICE);
        }
        if (dataSet.getLabels() != null) {
            Nd4j.getAffinityManager().ensureLocation(dataSet.getLabels(), AffinityManager.Location.DEVICE);
        }
        if (dataSet.getFeaturesMaskArray() != null) {
            Nd4j.getAffinityManager().ensureLocation(dataSet.getFeaturesMaskArray(), AffinityManager.Location.DEVICE);
        }
        if (dataSet.getLabelsMaskArray() != null) {
            Nd4j.getAffinityManager().ensureLocation(dataSet.getLabelsMaskArray(), AffinityManager.Location.DEVICE);
        }
    }

    /**
     * Ensures each non-null array in each non-null array group of the given
     * multi data set is located at {@link AffinityManager.Location#DEVICE}.
     *
     * <p>Individual entries are null-checked before being handed to the affinity
     * manager, mirroring the per-array checks in {@link #call(DataSet)}; a group
     * may legitimately contain null entries (e.g. no mask for one input).
     *
     * @param multiDataSet multi data set to process; may be null, in which case nothing happens
     */
    @Override
    public void call(MultiDataSet multiDataSet) {
        if (multiDataSet == null) {
            return;
        }

        if (multiDataSet.getFeatures() != null) {
            for (int i = 0; i < multiDataSet.getFeatures().length; i++) {
                // Skip missing entries instead of passing null to the affinity manager.
                if (multiDataSet.getFeatures()[i] != null) {
                    Nd4j.getAffinityManager().ensureLocation(
                            multiDataSet.getFeatures()[i], AffinityManager.Location.DEVICE);
                }
            }
        }

        if (multiDataSet.getLabels() != null) {
            for (int i = 0; i < multiDataSet.getLabels().length; i++) {
                if (multiDataSet.getLabels()[i] != null) {
                    Nd4j.getAffinityManager().ensureLocation(
                            multiDataSet.getLabels()[i], AffinityManager.Location.DEVICE);
                }
            }
        }

        if (multiDataSet.getFeaturesMaskArrays() != null) {
            for (int i = 0; i < multiDataSet.getFeaturesMaskArrays().length; i++) {
                if (multiDataSet.getFeaturesMaskArrays()[i] != null) {
                    Nd4j.getAffinityManager().ensureLocation(
                            multiDataSet.getFeaturesMaskArrays()[i], AffinityManager.Location.DEVICE);
                }
            }
        }

        if (multiDataSet.getLabelsMaskArrays() != null) {
            for (int i = 0; i < multiDataSet.getLabelsMaskArrays().length; i++) {
                if (multiDataSet.getLabelsMaskArrays()[i] != null) {
                    Nd4j.getAffinityManager().ensureLocation(
                            multiDataSet.getLabelsMaskArrays()[i], AffinityManager.Location.DEVICE);
                }
            }
        }
    }

    /**
     * No-op: this callback holds no state to reset.
     */
    @Override
    public void reset() {
        // Intentionally empty.
    }
}
