package org.deeplearning4j.rl4j.agent;

import lombok.NonNull;
import org.deeplearning4j.rl4j.agent.listener.AgentListener;
import org.deeplearning4j.rl4j.agent.listener.AgentListenerList;
import org.deeplearning4j.rl4j.environment.Environment;
import org.deeplearning4j.rl4j.environment.StepResult;
import org.deeplearning4j.rl4j.observation.Observation;
import org.deeplearning4j.rl4j.observation.transform.TransformProcess;
import org.deeplearning4j.rl4j.policy.IPolicy;
import org.nd4j.common.base.Preconditions;

/* loaded from: input_file:org/deeplearning4j/rl4j/agent/Agent.class */
/**
 * Runs an episode in an {@link Environment} by repeatedly asking an {@link IPolicy} for an action.
 *
 * <p>Raw environment data is converted to {@link Observation}s by a {@link TransformProcess}.
 * Registered listeners are notified before the episode and before/after every step; any of those
 * notifications can stop the current episode by returning {@code false}. The protected
 * {@code on*} methods are empty hooks invoked at the corresponding points of the episode loop.
 *
 * @param <ACTION> the type of action accepted by the environment and produced by the policy
 */
public class Agent<ACTION> {
    private final String id;
    private final Environment<ACTION> environment;
    private final IPolicy<ACTION> policy;
    private final TransformProcess transformProcess;
    protected final AgentListenerList<ACTION> listeners;
    // null means episodes are not capped by a step count.
    private final Integer maxEpisodeSteps;
    // Observation the policy will see on the next step.
    private Observation observation;
    // Action returned by the policy for the most recent non-skipped observation.
    private ACTION lastAction;
    private int episodeStepNumber;
    // Sum of computeReward(...) over the steps handled since the last reset().
    private double reward;
    // Cleared when a listener asks to stop the current episode.
    protected boolean canContinue;

    /**
     * Builder for {@link Agent}. The environment, transform process and policy are mandatory;
     * {@code maxEpisodeSteps} and {@code id} are optional.
     *
     * @param <ACTION> the action type of the agent being built
     */
    public static class Builder<ACTION> {
        private final Environment<ACTION> environment;
        private final TransformProcess transformProcess;
        private final IPolicy<ACTION> policy;
        private Integer maxEpisodeSteps = null;
        private String id;

        /**
         * Creates a builder with the mandatory collaborators.
         *
         * @param environment      the environment the agent interacts with
         * @param transformProcess converts raw environment data into {@link Observation}s
         * @param policy           chooses the action for each observation
         * @throws NullPointerException if any argument is {@code null}
         */
        public Builder(@NonNull Environment<ACTION> environment, @NonNull TransformProcess transformProcess, @NonNull IPolicy<ACTION> policy) {
            if (environment == null) {
                throw new NullPointerException("environment is marked non-null but is null");
            }
            if (transformProcess == null) {
                throw new NullPointerException("transformProcess is marked non-null but is null");
            }
            if (policy == null) {
                throw new NullPointerException("policy is marked non-null but is null");
            }
            this.environment = environment;
            this.transformProcess = transformProcess;
            this.policy = policy;
        }

        /**
         * Caps the number of steps an episode may run.
         *
         * @param maxEpisodeSteps the step cap; must be strictly positive
         * @return this builder
         * @throws IllegalArgumentException if {@code maxEpisodeSteps <= 0}
         */
        public Builder<ACTION> maxEpisodeSteps(int maxEpisodeSteps) {
            // %s placeholder so the rejected value actually appears in the message.
            Preconditions.checkArgument(maxEpisodeSteps > 0, "maxEpisodeSteps must be greater than 0, got %s", maxEpisodeSteps);
            this.maxEpisodeSteps = maxEpisodeSteps;
            return this;
        }

        /**
         * Sets an identifier for the agent; {@code null} is accepted.
         *
         * @param id the identifier returned by {@link Agent#getId()}
         * @return this builder
         */
        public Builder<ACTION> id(String id) {
            this.id = id;
            return this;
        }

        /**
         * Builds the agent from the current builder state.
         *
         * @return a new, properly parameterized {@link Agent}
         */
        public Agent<ACTION> build() {
            return new Agent<>(this);
        }
    }

    private Agent(Builder<ACTION> builder) {
        this.environment = builder.environment;
        this.transformProcess = builder.transformProcess;
        this.policy = builder.policy;
        this.maxEpisodeSteps = builder.maxEpisodeSteps;
        this.id = builder.id;
        // Overridable factory called from the constructor: an override runs before this
        // object is fully constructed, so it must not rely on other instance state.
        this.listeners = buildListenerList();
    }

    /**
     * Creates the listener list used by this agent. Called once from the constructor.
     *
     * @return a new, empty listener list
     */
    protected AgentListenerList<ACTION> buildListenerList() {
        return new AgentListenerList<>();
    }

    /**
     * Registers a listener notified during episodes.
     *
     * @param agentListener the listener to add
     */
    public void addListener(AgentListener agentListener) {
        // Raw parameter type kept for source compatibility with existing callers.
        this.listeners.add(agentListener);
    }

    /** Runs a single episode. */
    public void run() {
        runEpisode();
    }

    /** Hook called after {@link #reset()} and before listeners are notified of the new episode. */
    protected void onBeforeEpisode() {
    }

    /** Hook called at the end of an episode, only if no listener stopped it. */
    protected void onAfterEpisode() {
    }

    /**
     * Resets the agent, then performs steps until a listener returns {@code false}, the environment
     * reports the episode finished, or {@code maxEpisodeSteps} (when set) is reached.
     * {@link #onAfterEpisode()} is skipped when a listener stopped the episode.
     */
    protected void runEpisode() {
        reset();
        onBeforeEpisode();
        this.canContinue = this.listeners.notifyBeforeEpisode(this);
        while (this.canContinue && !this.environment.isEpisodeFinished() && (this.maxEpisodeSteps == null || this.episodeStepNumber < this.maxEpisodeSteps)) {
            performStep();
        }
        if (this.canContinue) {
            onAfterEpisode();
        }
    }

    /** Resets environment and policy, clears the accumulated reward and restores the initial action. */
    protected void reset() {
        resetEnvironment();
        resetPolicy();
        this.reward = 0.0d;
        this.lastAction = getInitialAction();
        this.canContinue = true;
    }

    /** Sets the step counter to 0 and transforms the environment's reset data into the first observation. */
    protected void resetEnvironment() {
        this.episodeStepNumber = 0;
        this.observation = this.transformProcess.transform(this.environment.reset(), this.episodeStepNumber, false);
    }

    /** Resets the policy's state. */
    protected void resetPolicy() {
        this.policy.reset();
    }

    /**
     * Returns the action used before the policy has produced one.
     *
     * @return the no-op action of the environment's action schema
     */
    protected ACTION getInitialAction() {
        return this.environment.getSchema().getActionSchema().getNoOp();
    }

    /**
     * Performs one step: decide an action, ask listeners, act, handle the result and notify
     * listeners again. The step counter only advances when the after-step notification allows it.
     */
    protected void performStep() {
        onBeforeStep();
        ACTION action = decideAction(this.observation);
        this.canContinue = this.listeners.notifyBeforeStep(this, this.observation, action);
        if (this.canContinue) {
            StepResult stepResult = act(action);
            handleStepResult(stepResult);
            onAfterStep(stepResult);
            this.canContinue = this.listeners.notifyAfterStep(this, stepResult);
            if (this.canContinue) {
                incrementEpisodeStepNumber();
            }
        }
    }

    /** Advances the episode step counter by one. */
    protected void incrementEpisodeStepNumber() {
        this.episodeStepNumber++;
    }

    /**
     * Chooses the action for an observation. For a skipped observation the policy is not consulted
     * and the previous action is repeated.
     *
     * @param observation the current observation
     * @return the action to perform
     */
    protected ACTION decideAction(Observation observation) {
        if (!observation.isSkipped()) {
            this.lastAction = this.policy.nextAction(observation);
        }
        return this.lastAction;
    }

    /**
     * Applies an action to the environment.
     *
     * @param action the action to perform
     * @return the environment's step result
     */
    protected StepResult act(ACTION action) {
        return this.environment.step(action);
    }

    /**
     * Stores the next observation and accumulates the step reward.
     *
     * @param stepResult the result returned by {@link #act(Object)}
     */
    protected void handleStepResult(StepResult stepResult) {
        // +1 because the counter is only incremented after the after-step notification.
        this.observation = convertChannelDataToObservation(stepResult, this.episodeStepNumber + 1);
        this.reward += computeReward(stepResult);
    }

    /**
     * Transforms a step result's channel data into an observation.
     *
     * @param stepResult the step result to convert
     * @param i          the step index passed to the transform process
     * @return the transformed observation
     */
    protected Observation convertChannelDataToObservation(StepResult stepResult, int i) {
        return this.transformProcess.transform(stepResult.getChannelsData(), i, stepResult.isTerminal());
    }

    /**
     * Computes the reward credited for a step.
     *
     * @param stepResult the step result
     * @return the step result's reward
     */
    protected double computeReward(StepResult stepResult) {
        return stepResult.getReward();
    }

    /** Hook called after a step's result has been handled, before listeners are notified. */
    protected void onAfterStep(StepResult stepResult) {
    }

    /** Hook called at the start of every step, before an action is decided. */
    protected void onBeforeStep() {
    }

    /**
     * Creates a {@link Builder} with the mandatory collaborators.
     *
     * @param environment      the environment the agent interacts with
     * @param transformProcess converts raw environment data into {@link Observation}s
     * @param policy           chooses the action for each observation
     * @param <ACTION>         the action type
     * @return a new builder
     * @throws NullPointerException if any argument is {@code null}
     */
    public static <ACTION> Builder<ACTION> builder(@NonNull Environment<ACTION> environment, @NonNull TransformProcess transformProcess, @NonNull IPolicy<ACTION> policy) {
        // The Builder constructor performs the same null checks, in the same order, with the same messages.
        return new Builder<>(environment, transformProcess, policy);
    }

    /** @return the identifier given to the builder, possibly {@code null} */
    public String getId() {
        return this.id;
    }

    /** @return the environment this agent interacts with */
    public Environment<ACTION> getEnvironment() {
        return this.environment;
    }

    /** @return the policy used to choose actions */
    public IPolicy<ACTION> getPolicy() {
        return this.policy;
    }

    /** @return the observation the policy will see on the next step */
    protected Observation getObservation() {
        return this.observation;
    }

    /** @return the most recently decided action, or the initial action after a reset */
    protected ACTION getLastAction() {
        return this.lastAction;
    }

    /** @return the number of steps counted in the current episode */
    public int getEpisodeStepNumber() {
        return this.episodeStepNumber;
    }

    /** @return the reward accumulated since the last reset */
    public double getReward() {
        return this.reward;
    }
}
