package org.deeplearning4j.rl4j.network;

import java.util.HashMap;
import java.util.Map;
import org.deeplearning4j.rl4j.agent.learning.update.FeaturesLabels;
import org.deeplearning4j.rl4j.agent.learning.update.Gradients;
import org.deeplearning4j.rl4j.network.BaseNetwork;
import org.deeplearning4j.rl4j.observation.Observation;
import org.nd4j.linalg.api.ndarray.INDArray;

/* loaded from: input_file:org/deeplearning4j/rl4j/network/BaseNetwork.class */
/**
 * Skeleton of a trainable network that forwards every model operation to an
 * {@link INetworkHandler} and memoizes the outputs computed for individual
 * {@link Observation}s.
 *
 * <p>The memoized outputs are dropped by {@link #fit}, {@link #applyGradients},
 * {@link #reset} and (through {@code reset}) {@link #copyFrom}. The cache is a plain
 * {@link HashMap}; this class adds no synchronization around it.
 *
 * @param <NET_TYPE> concrete network type, used as the return type of {@link #mo26clone()}
 */
public abstract class BaseNetwork<NET_TYPE extends BaseNetwork> implements ITrainableNeuralNet<NET_TYPE> {
    /** Handler that every model operation of this class is delegated to. */
    private final INetworkHandler networkHandler;

    /**
     * Outputs already computed by {@link #output(Observation)}, keyed by the observation
     * they were computed for. Lookups therefore rely on {@code Observation}'s
     * {@code equals}/{@code hashCode}.
     */
    private final Map<Observation, NeuralNetOutput> neuralNetOutputCache = new HashMap<>();

    /**
     * Immutable pair of an iteration count and an epoch count, with value-based
     * {@code equals}, {@code hashCode} and {@code toString}.
     */
    protected static final class ModelCounters {
        private final int iterationCount;
        private final int epochCount;

        /**
         * @param iterationCount value returned by {@link #getIterationCount()}
         * @param epochCount     value returned by {@link #getEpochCount()}
         */
        public ModelCounters(int iterationCount, int epochCount) {
            this.iterationCount = iterationCount;
            this.epochCount = epochCount;
        }

        /** @return the iteration count given at construction */
        public int getIterationCount() {
            return this.iterationCount;
        }

        /** @return the epoch count given at construction */
        public int getEpochCount() {
            return this.epochCount;
        }

        /**
         * Two instances are equal when both counters match.
         *
         * @param obj object to compare with; may be {@code null}
         * @return {@code true} if {@code obj} is a {@code ModelCounters} with the same counters
         */
        @Override
        public boolean equals(Object obj) {
            if (obj == this) {
                return true;
            }
            if (!(obj instanceof ModelCounters)) {
                return false;
            }
            ModelCounters other = (ModelCounters) obj;
            return this.iterationCount == other.iterationCount && this.epochCount == other.epochCount;
        }

        /** @return a hash derived from both counters, consistent with {@link #equals(Object)} */
        @Override
        public int hashCode() {
            // Same value as the original (((1 * 59) + iterationCount) * 59) + epochCount.
            int result = 59 + this.iterationCount;
            return result * 59 + this.epochCount;
        }

        @Override
        public String toString() {
            return "BaseNetwork.ModelCounters(iterationCount=" + this.iterationCount + ", epochCount=" + this.epochCount + ")";
        }
    }

    /**
     * Stores the handler used by all other methods of this class.
     *
     * <p>Note: the decompiler reported that this constructor's access modifier was changed
     * from {@code protected}; it is kept {@code public} here so existing callers compile.
     *
     * @param iNetworkHandler handler to delegate to; stored as-is, without a null check
     */
    public BaseNetwork(INetworkHandler iNetworkHandler) {
        this.networkHandler = iNetworkHandler;
    }

    /** @return whatever the handler's {@code isRecurrent()} reports */
    @Override // org.deeplearning4j.rl4j.network.IOutputNeuralNet
    public boolean isRecurrent() {
        return this.networkHandler.isRecurrent();
    }

    /**
     * Clears the output cache, then asks the handler to fit on the given data.
     *
     * @param featuresLabels training data passed unchanged to the handler
     */
    @Override // org.deeplearning4j.rl4j.network.ITrainableNeuralNet
    public void fit(FeaturesLabels featuresLabels) {
        // Cached outputs are discarded before the handler performs the fit.
        invalidateCache();
        this.networkHandler.performFit(featuresLabels);
    }

    /**
     * Has the handler compute gradients for the given data and returns them in a new
     * {@link Gradients} instance sized with {@code featuresLabels.getBatchSize()}.
     * The output cache is not cleared by this method.
     *
     * @param featuresLabels data to compute the gradients from
     * @return a new {@code Gradients} object filled in by the handler
     */
    @Override // org.deeplearning4j.rl4j.network.ITrainableNeuralNet
    public Gradients computeGradients(FeaturesLabels featuresLabels) {
        this.networkHandler.performGradientsComputation(featuresLabels);
        this.networkHandler.notifyGradientCalculation();
        Gradients gradients = new Gradients(featuresLabels.getBatchSize());
        this.networkHandler.fillGradientsResponse(gradients);
        return gradients;
    }

    /**
     * Clears the output cache, has the handler apply the gradients using the batch size
     * they carry, then notifies the handler that an iteration is done.
     *
     * @param gradients gradients to apply
     */
    @Override // org.deeplearning4j.rl4j.network.ITrainableNeuralNet
    public void applyGradients(Gradients gradients) {
        invalidateCache();
        this.networkHandler.applyGradient(gradients, gradients.getBatchSize());
        this.networkHandler.notifyIterationDone();
    }

    /**
     * Returns the output for a single observation, reusing a cached result when one exists.
     *
     * <p>On a cache miss the result comes from the handler's {@code recurrentStepOutput}
     * when {@link #isRecurrent()} is true, and from {@link #output(INDArray)} applied to
     * {@code observation.getData()} otherwise. The result is then stored in the cache.
     *
     * @param observation observation to evaluate; also used as the cache key
     * @return the cached or freshly computed output
     */
    @Override // org.deeplearning4j.rl4j.network.IOutputNeuralNet
    public NeuralNetOutput output(Observation observation) {
        NeuralNetOutput neuralNetOutput = this.neuralNetOutputCache.get(observation);
        if (neuralNetOutput == null) {
            if (isRecurrent()) {
                neuralNetOutput = packageResult(this.networkHandler.recurrentStepOutput(observation));
            } else {
                neuralNetOutput = output(observation.getData());
            }
            this.neuralNetOutputCache.put(observation, neuralNetOutput);
        }
        return neuralNetOutput;
    }

    /**
     * Converts the arrays returned by the handler into a {@link NeuralNetOutput}.
     *
     * @param iNDArrayArr arrays returned by the handler's {@code recurrentStepOutput} or
     *                    {@code batchOutput}
     * @return the packaged output
     */
    protected abstract NeuralNetOutput packageResult(INDArray[] iNDArrayArr);

    /**
     * Passes the array to the handler's {@code batchOutput} and packages the result.
     * This path neither reads nor writes the output cache.
     *
     * @param iNDArray input handed unchanged to the handler
     * @return the packaged handler output
     */
    @Override // org.deeplearning4j.rl4j.network.IOutputNeuralNet
    public NeuralNetOutput output(INDArray iNDArray) {
        return packageResult(this.networkHandler.batchOutput(iNDArray));
    }

    /**
     * Clears the output cache and, when the network is recurrent, resets the handler's state.
     */
    @Override // org.deeplearning4j.rl4j.network.IOutputNeuralNet
    public void reset() {
        invalidateCache();
        if (isRecurrent()) {
            this.networkHandler.resetState();
        }
    }

    /** Removes every entry from the output cache. */
    protected void invalidateCache() {
        this.neuralNetOutputCache.clear();
    }

    /**
     * Resets this network (see {@link #reset()}), then has this network's handler copy
     * from the handler of {@code baseNetwork}.
     *
     * @param baseNetwork network whose handler is copied from
     */
    @Override // org.deeplearning4j.rl4j.network.ITrainableNeuralNet
    public void copyFrom(BaseNetwork baseNetwork) {
        reset();
        this.networkHandler.copyFrom(baseNetwork.networkHandler);
    }

    /**
     * Creates a copy of this network.
     *
     * <p>Per the decompiler note, this method was originally named {@code clone} and was
     * renamed because it was merged with a bridge method.
     *
     * @return the copy, as produced by the concrete subclass
     */
    @Override // org.deeplearning4j.rl4j.network.ITrainableNeuralNet
    public abstract NET_TYPE mo26clone();

    /**
     * Gives access to the handler supplied at construction.
     *
     * <p>Note: the decompiler reported that this accessor's access modifier was changed
     * from {@code protected}.
     *
     * @return the handler this network delegates to
     */
    public INetworkHandler getNetworkHandler() {
        return this.networkHandler;
    }
}
