package org.deeplearning4j.rl4j.observation.preprocessor;

import org.deeplearning4j.rl4j.observation.preprocessor.pooling.ChannelStackPoolContentAssembler;
import org.deeplearning4j.rl4j.observation.preprocessor.pooling.CircularFifoObservationPool;
import org.deeplearning4j.rl4j.observation.preprocessor.pooling.ObservationPool;
import org.deeplearning4j.rl4j.observation.preprocessor.pooling.PoolContentAssembler;
import org.nd4j.base.Preconditions;
import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.dataset.api.DataSet;

/* loaded from: input_file:org/deeplearning4j/rl4j/observation/preprocessor/PoolingDataSetPreProcessor.class */
/**
 * A resettable pre-processor that accumulates single-example observations in an
 * {@link ObservationPool} and, once that pool reports it is at full capacity, replaces the
 * features of the incoming {@link DataSet} with the pool's content as combined by a
 * {@link PoolContentAssembler}.
 *
 * <p>While the pool is still filling up, the features of the processed data set are set to
 * {@code null}.
 *
 * <p>Instances are created through {@link #builder()}.
 */
public class PoolingDataSetPreProcessor extends ResettableDataSetPreProcessor {
    private final ObservationPool observationPool;
    private final PoolContentAssembler poolContentAssembler;

    /**
     * Builder for {@link PoolingDataSetPreProcessor}. Any collaborator left unset is replaced
     * by a default instance when {@link #build()} is called.
     */
    public static class Builder {
        private ObservationPool observationPool;
        private PoolContentAssembler poolContentAssembler;

        /**
         * Sets the pool that will store the incoming observations.
         *
         * @param observationPool the pool to use; if {@code null}, {@link #build()} falls back
         *                        to a new {@link CircularFifoObservationPool}
         * @return this builder
         */
        public Builder observablePool(ObservationPool observationPool) {
            this.observationPool = observationPool;
            return this;
        }

        /**
         * Sets the assembler that combines the pooled observations into one array.
         *
         * @param poolContentAssembler the assembler to use; if {@code null}, {@link #build()}
         *                             falls back to a new {@link ChannelStackPoolContentAssembler}
         * @return this builder
         */
        public Builder poolContentAssembler(PoolContentAssembler poolContentAssembler) {
            this.poolContentAssembler = poolContentAssembler;
            return this;
        }

        /**
         * Creates the pre-processor, filling in defaults for any collaborator that was not set.
         *
         * @return a new {@link PoolingDataSetPreProcessor}
         */
        public PoolingDataSetPreProcessor build() {
            if (this.observationPool == null) {
                this.observationPool = new CircularFifoObservationPool();
            }
            if (this.poolContentAssembler == null) {
                this.poolContentAssembler = new ChannelStackPoolContentAssembler();
            }
            return new PoolingDataSetPreProcessor(this);
        }
    }

    /**
     * Copies the pool and the assembler from the given builder.
     *
     * @param builder the builder holding the collaborators to use
     */
    protected PoolingDataSetPreProcessor(Builder builder) {
        this.observationPool = builder.observationPool;
        this.poolContentAssembler = builder.poolContentAssembler;
    }

    /**
     * Adds the example contained in {@code dataSet} to the pool and updates the data set's
     * features accordingly.
     *
     * <ul>
     *   <li>An empty data set is left untouched and nothing is added to the pool.</li>
     *   <li>If the pool is not yet at full capacity after the addition, the features are set
     *       to {@code null}.</li>
     *   <li>Otherwise the features are replaced by the assembled pool content, reshaped so
     *       that a leading dimension of size 1 precedes the assembled shape.</li>
     * </ul>
     *
     * <p>A {@code null} data set, or one holding more than one example, is rejected by the
     * precondition checks.
     *
     * @param dataSet the data set to process; must not be {@code null} and must contain at
     *                most one example
     */
    public void preProcess(DataSet dataSet) {
        Preconditions.checkNotNull(dataSet, "Encountered null dataSet");
        if (dataSet.isEmpty()) {
            return;
        }
        Preconditions.checkArgument(dataSet.numExamples() == 1, "Pooling datasets containing more than one example is not supported");

        // Store a copy of the first slice along dimension 0, so the pooled observation does
        // not share its buffer with the data set's features.
        INDArray observation = dataSet.getFeatures().slice(0L, 0).dup();
        this.observationPool.add(observation);

        if (!this.observationPool.isAtFullCapacity()) {
            // Not enough observations pooled yet: there is nothing to hand downstream.
            dataSet.setFeatures(null);
            return;
        }

        INDArray assembled = this.poolContentAssembler.assemble(this.observationPool.get());
        dataSet.setFeatures(assembled.reshape(prependUnitDimension(assembled.shape())));
    }

    /**
     * Returns a copy of {@code shape} with an additional leading dimension of size 1.
     */
    private static long[] prependUnitDimension(long[] shape) {
        long[] extendedShape = new long[shape.length + 1];
        extendedShape[0] = 1;
        System.arraycopy(shape, 0, extendedShape, 1, shape.length);
        return extendedShape;
    }

    /**
     * Creates a new {@link Builder}.
     *
     * @return a builder with no collaborators set
     */
    public static Builder builder() {
        return new Builder();
    }

    /**
     * Resets the underlying observation pool.
     */
    @Override // org.deeplearning4j.rl4j.observation.preprocessor.ResettableDataSetPreProcessor
    public void reset() {
        this.observationPool.reset();
    }
}
