package org.deeplearning4j.rl4j.experience;

import java.util.List;
import org.deeplearning4j.rl4j.learning.sync.ExpReplay;
import org.deeplearning4j.rl4j.learning.sync.IExpReplay;
import org.deeplearning4j.rl4j.learning.sync.Transition;
import org.deeplearning4j.rl4j.observation.Observation;
import org.nd4j.linalg.api.rng.Random;
import org.nd4j.linalg.factory.Nd4j;

/* loaded from: input_file:org/deeplearning4j/rl4j/experience/ReplayMemoryExperienceHandler.class */
/**
 * An {@link ExperienceHandler} that turns a stream of experiences into {@link Transition}s and
 * stores them in an {@link IExpReplay}, from which training batches are drawn.
 *
 * <p>A transition is only stored once the observation that follows it is known. The most
 * recently added experience is therefore kept as a "pending" transition until either the next
 * call to {@link #addExperience} or a call to {@link #setFinalObservation} supplies that
 * observation.
 *
 * @param <A> the action type
 */
public class ReplayMemoryExperienceHandler<A> implements ExperienceHandler<A, Transition<A>> {
    private static final int DEFAULT_MAX_REPLAY_MEMORY_SIZE = 150000;
    private static final int DEFAULT_BATCH_SIZE = 32;

    /** Backing store for completed transitions; assigned once at construction. */
    private final IExpReplay<A> expReplay;

    /** The last added experience, still waiting for its next observation; {@code null} if none. */
    private Transition<A> pendingTransition;

    /**
     * Fluent builder producing handlers backed by a new {@link ExpReplay}.
     *
     * <p>NOTE(review): this is a non-static inner class, so a builder can only be created from an
     * existing handler instance ({@code handler.new Builder()}). The builder itself never reads
     * the enclosing instance. It is left non-static here to stay compatible with existing callers.
     */
    public class Builder {
        private int maxReplayMemorySize = DEFAULT_MAX_REPLAY_MEMORY_SIZE;
        private int batchSize = DEFAULT_BATCH_SIZE;
        private Random random = Nd4j.getRandom();

        public Builder() {
        }

        /**
         * Sets the value passed as the first argument of the {@link ExpReplay} constructor.
         *
         * @param i the maximum replay memory size (defaults to 150000)
         * @return this builder
         */
        public Builder maxReplayMemorySize(int i) {
            this.maxReplayMemorySize = i;
            return this;
        }

        /**
         * Sets the value passed as the second argument of the {@link ExpReplay} constructor.
         *
         * @param i the batch size (defaults to 32)
         * @return this builder
         */
        public Builder batchSize(int i) {
            this.batchSize = i;
            return this;
        }

        /**
         * Sets the random generator handed to the {@link ExpReplay}.
         *
         * @param random the generator to use (defaults to {@code Nd4j.getRandom()})
         * @return this builder
         */
        public Builder random(Random random) {
            this.random = random;
            return this;
        }

        /**
         * Creates a new handler from the values configured so far.
         *
         * @return a new, independent handler
         */
        public ReplayMemoryExperienceHandler<A> build() {
            return new ReplayMemoryExperienceHandler<>(maxReplayMemorySize, batchSize, random);
        }
    }

    /**
     * Creates a handler that stores its transitions in the supplied replay memory.
     *
     * @param iExpReplay the replay memory to store into and draw batches from
     */
    public ReplayMemoryExperienceHandler(IExpReplay<A> iExpReplay) {
        this.expReplay = iExpReplay;
    }

    /**
     * Creates a handler backed by a new {@link ExpReplay} built from the three arguments.
     *
     * @param i      first {@link ExpReplay} constructor argument; the {@link Builder} passes its
     *               maximum replay memory size here
     * @param i2     second {@link ExpReplay} constructor argument; the {@link Builder} passes its
     *               batch size here
     * @param random random generator handed to the {@link ExpReplay}
     */
    public ReplayMemoryExperienceHandler(int i, int i2, Random random) {
        // Diamond instead of the raw type so the conversion to IExpReplay<A> is checked.
        this(new ExpReplay<>(i, i2, random));
    }

    /**
     * Records a new experience. If a transition is pending, {@code observation} becomes its next
     * observation and it is stored; the new experience then becomes the pending transition.
     *
     * @param observation the observation for this experience (also completes the pending one)
     * @param a           the action, forwarded to the {@link Transition} constructor
     * @param d           forwarded to the {@link Transition} constructor; presumably the reward
     *                    (TODO confirm against {@code Transition})
     * @param z           forwarded to the {@link Transition} constructor; presumably the terminal
     *                    flag (TODO confirm against {@code Transition})
     */
    @Override // org.deeplearning4j.rl4j.experience.ExperienceHandler
    public void addExperience(Observation observation, A a, double d, boolean z) {
        setNextObservationOnPending(observation);
        this.pendingTransition = new Transition<>(observation, a, d, z);
    }

    /**
     * Completes and stores the pending transition (if any) using {@code observation} as its next
     * observation, then clears the pending slot.
     *
     * @param observation the observation that follows the last added experience
     */
    @Override // org.deeplearning4j.rl4j.experience.ExperienceHandler
    public void setFinalObservation(Observation observation) {
        setNextObservationOnPending(observation);
        this.pendingTransition = null;
    }

    /**
     * @return the batch size reported by the underlying replay memory
     */
    @Override // org.deeplearning4j.rl4j.experience.ExperienceHandler
    public int getTrainingBatchSize() {
        return this.expReplay.getBatchSize();
    }

    /**
     * @return the batch returned by the underlying replay memory's {@code getBatch()}
     */
    @Override // org.deeplearning4j.rl4j.experience.ExperienceHandler
    public List<Transition<A>> generateTrainingBatch() {
        return this.expReplay.getBatch();
    }

    /**
     * Discards the pending transition without storing it. Transitions already stored in the
     * replay memory are not touched by this method.
     */
    @Override // org.deeplearning4j.rl4j.experience.ExperienceHandler
    public void reset() {
        this.pendingTransition = null;
    }

    /** Sets the next observation on the pending transition, if there is one, and stores it. */
    private void setNextObservationOnPending(Observation observation) {
        if (this.pendingTransition == null) {
            return;
        }
        this.pendingTransition.setNextObservation(observation);
        this.expReplay.store(this.pendingTransition);
    }

    /**
     * Two handlers are equal when their replay memories and their pending transitions are equal
     * (both compared null-safely) and the other object accepts the comparison via
     * {@link #canEqual(Object)}.
     */
    @Override
    public boolean equals(Object obj) {
        if (obj == this) {
            return true;
        }
        if (!(obj instanceof ReplayMemoryExperienceHandler)) {
            return false;
        }
        // Wildcard rather than a raw type: no unchecked assignments are needed below.
        ReplayMemoryExperienceHandler<?> other = (ReplayMemoryExperienceHandler<?>) obj;
        if (!other.canEqual(this)) {
            return false;
        }
        if (this.expReplay == null ? other.expReplay != null : !this.expReplay.equals(other.expReplay)) {
            return false;
        }
        return this.pendingTransition == null
                ? other.pendingTransition == null
                : this.pendingTransition.equals(other.pendingTransition);
    }

    /**
     * Hook that lets subclasses refuse equality with instances of this class.
     *
     * @param obj the object attempting the comparison
     * @return {@code true} if {@code obj} is a {@code ReplayMemoryExperienceHandler}
     */
    protected boolean canEqual(Object obj) {
        return obj instanceof ReplayMemoryExperienceHandler;
    }

    /**
     * Hash over the replay memory and the pending transition, consistent with
     * {@link #equals(Object)}. A {@code null} field contributes the constant 43.
     */
    @Override
    public int hashCode() {
        int result = 59 + (this.expReplay == null ? 43 : this.expReplay.hashCode());
        return (result * 59) + (this.pendingTransition == null ? 43 : this.pendingTransition.hashCode());
    }
}
