diff --git a/pom.xml b/pom.xml index 120159c30..9b32f25ae 100644 --- a/pom.xml +++ b/pom.xml @@ -303,8 +303,8 @@ 1.79.0 1.10.6 0.6.1 - 0.15.4 - 1.15.0 + 0.15.5 + 1.15.2 ${tensorflow.version}-${javacpp-presets.version} 1.18 diff --git a/rl4j/rl4j-core/src/main/java/org/deeplearning4j/rl4j/learning/EpochStepCounter.java b/rl4j/rl4j-core/src/main/java/org/deeplearning4j/rl4j/learning/EpochStepCounter.java new file mode 100644 index 000000000..746a71396 --- /dev/null +++ b/rl4j/rl4j-core/src/main/java/org/deeplearning4j/rl4j/learning/EpochStepCounter.java @@ -0,0 +1,5 @@ +package org.deeplearning4j.rl4j.learning; + +public interface EpochStepCounter { + int getCurrentEpochStep(); +} diff --git a/rl4j/rl4j-core/src/main/java/org/deeplearning4j/rl4j/learning/IEpochTrainer.java b/rl4j/rl4j-core/src/main/java/org/deeplearning4j/rl4j/learning/IEpochTrainer.java index 72510dcaa..f113ce157 100644 --- a/rl4j/rl4j-core/src/main/java/org/deeplearning4j/rl4j/learning/IEpochTrainer.java +++ b/rl4j/rl4j-core/src/main/java/org/deeplearning4j/rl4j/learning/IEpochTrainer.java @@ -21,9 +21,14 @@ import org.deeplearning4j.rl4j.mdp.MDP; /** * The common API between Learning and AsyncThread. * + * Express the ability to count the number of step of the current training. + * Factorisation of a feature between threads in async and learning process + * for the web monitoring + * * @author Alexandre Boulanger + * @author rubenfiszel (ruben.fiszel@epfl.ch) on 8/5/16. 
*/ -public interface IEpochTrainer { +public interface IEpochTrainer extends EpochStepCounter { int getStepCounter(); int getEpochCounter(); IHistoryProcessor getHistoryProcessor(); diff --git a/rl4j/rl4j-core/src/main/java/org/deeplearning4j/rl4j/learning/ILearning.java b/rl4j/rl4j-core/src/main/java/org/deeplearning4j/rl4j/learning/ILearning.java index 3c4c94c6b..d151f093b 100644 --- a/rl4j/rl4j-core/src/main/java/org/deeplearning4j/rl4j/learning/ILearning.java +++ b/rl4j/rl4j-core/src/main/java/org/deeplearning4j/rl4j/learning/ILearning.java @@ -26,7 +26,7 @@ import org.deeplearning4j.rl4j.space.Encodable; * * A common interface that any training method should implement */ -public interface ILearning> extends StepCountable { +public interface ILearning> { IPolicy getPolicy(); diff --git a/rl4j/rl4j-core/src/main/java/org/deeplearning4j/rl4j/learning/Learning.java b/rl4j/rl4j-core/src/main/java/org/deeplearning4j/rl4j/learning/Learning.java index 780a73752..833094929 100644 --- a/rl4j/rl4j-core/src/main/java/org/deeplearning4j/rl4j/learning/Learning.java +++ b/rl4j/rl4j-core/src/main/java/org/deeplearning4j/rl4j/learning/Learning.java @@ -21,7 +21,6 @@ import lombok.Getter; import lombok.Setter; import lombok.Value; import lombok.extern.slf4j.Slf4j; -import org.deeplearning4j.gym.StepReply; import org.deeplearning4j.rl4j.mdp.MDP; import org.deeplearning4j.rl4j.network.NeuralNet; import org.deeplearning4j.rl4j.space.ActionSpace; @@ -53,55 +52,6 @@ public abstract class Learning, NN extends Neura return Nd4j.argMax(vector, Integer.MAX_VALUE).getInt(0); } - public static > INDArray getInput(MDP mdp, O obs) { - INDArray arr = Nd4j.create(((Encodable)obs).toArray()); - int[] shape = mdp.getObservationSpace().getShape(); - if (shape.length == 1) - return arr.reshape(new long[] {1, arr.length()}); - else - return arr.reshape(shape); - } - - public static > InitMdp initMdp(MDP mdp, - IHistoryProcessor hp) { - - O obs = mdp.reset(); - - int step = 0; - double reward = 0; 
- - boolean isHistoryProcessor = hp != null; - - int skipFrame = isHistoryProcessor ? hp.getConf().getSkipFrame() : 1; - int requiredFrame = isHistoryProcessor ? skipFrame * (hp.getConf().getHistoryLength() - 1) : 0; - - INDArray input = Learning.getInput(mdp, obs); - if (isHistoryProcessor) - hp.record(input); - - - while (step < requiredFrame && !mdp.isDone()) { - - A action = mdp.getActionSpace().noOp(); //by convention should be the NO_OP - if (step % skipFrame == 0 && isHistoryProcessor) - hp.add(input); - - StepReply stepReply = mdp.step(action); - reward += stepReply.getReward(); - obs = stepReply.getObservation(); - - input = Learning.getInput(mdp, obs); - if (isHistoryProcessor) - hp.record(input); - - step++; - - } - - return new InitMdp(step, obs, reward); - - } - public static int[] makeShape(int size, int[] shape) { int[] nshape = new int[shape.length + 1]; nshape[0] = size; @@ -122,16 +72,16 @@ public abstract class Learning, NN extends Neura public abstract NN getNeuralNet(); - public int incrementStep() { - return stepCounter++; + public void incrementStep() { + stepCounter++; } - public int incrementEpoch() { - return epochCounter++; + public void incrementEpoch() { + epochCounter++; } public void setHistoryProcessor(HistoryProcessor.Configuration conf) { - historyProcessor = new HistoryProcessor(conf); + setHistoryProcessor(new HistoryProcessor(conf)); } public void setHistoryProcessor(IHistoryProcessor historyProcessor) { diff --git a/rl4j/rl4j-core/src/main/java/org/deeplearning4j/rl4j/learning/StepCountable.java b/rl4j/rl4j-core/src/main/java/org/deeplearning4j/rl4j/learning/StepCountable.java deleted file mode 100644 index b9538163b..000000000 --- a/rl4j/rl4j-core/src/main/java/org/deeplearning4j/rl4j/learning/StepCountable.java +++ /dev/null @@ -1,30 +0,0 @@ -/******************************************************************************* - * Copyright (c) 2015-2018 Skymind, Inc. 
- * - * This program and the accompanying materials are made available under the - * terms of the Apache License, Version 2.0 which is available at - * https://www.apache.org/licenses/LICENSE-2.0. - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT - * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the - * License for the specific language governing permissions and limitations - * under the License. - * - * SPDX-License-Identifier: Apache-2.0 - ******************************************************************************/ - -package org.deeplearning4j.rl4j.learning; - -/** - * @author rubenfiszel (ruben.fiszel@epfl.ch) on 8/5/16. - * - * Express the ability to count the number of step of the current training. - * Factorisation of a feature between threads in async and learning process - * for the web monitoring - */ -public interface StepCountable { - - int getStepCounter(); - -} diff --git a/rl4j/rl4j-core/src/main/java/org/deeplearning4j/rl4j/learning/async/AsyncThread.java b/rl4j/rl4j-core/src/main/java/org/deeplearning4j/rl4j/learning/async/AsyncThread.java index 2495e74ce..e95b25337 100644 --- a/rl4j/rl4j-core/src/main/java/org/deeplearning4j/rl4j/learning/async/AsyncThread.java +++ b/rl4j/rl4j-core/src/main/java/org/deeplearning4j/rl4j/learning/async/AsyncThread.java @@ -30,8 +30,6 @@ import org.deeplearning4j.rl4j.network.NeuralNet; import org.deeplearning4j.rl4j.observation.Observation; import org.deeplearning4j.rl4j.policy.IPolicy; import org.deeplearning4j.rl4j.space.ActionSpace; -import org.deeplearning4j.rl4j.space.DiscreteSpace; -import org.deeplearning4j.rl4j.space.Encodable; import org.deeplearning4j.rl4j.util.IDataManager; import org.deeplearning4j.rl4j.util.LegacyMDPWrapper; import org.nd4j.linalg.factory.Nd4j; @@ -48,7 +46,7 @@ import org.nd4j.linalg.factory.Nd4j; */ @Slf4j public abstract class AsyncThread, NN extends 
NeuralNet> - extends Thread implements StepCountable, IEpochTrainer { + extends Thread implements IEpochTrainer { @Getter private int threadNumber; @@ -61,6 +59,9 @@ public abstract class AsyncThread, NN extends Ne @Getter @Setter private IHistoryProcessor historyProcessor; + @Getter + private int currentEpochStep = 0; + private boolean isEpochStarted = false; private final LegacyMDPWrapper mdp; @@ -138,7 +139,7 @@ public abstract class AsyncThread, NN extends Ne handleTraining(context); - if (context.epochElapsedSteps >= getConf().getMaxEpochStep() || getMdp().isDone()) { + if (currentEpochStep >= getConf().getMaxEpochStep() || getMdp().isDone()) { boolean canContinue = finishEpoch(context); if (!canContinue) { break; @@ -154,11 +155,10 @@ public abstract class AsyncThread, NN extends Ne } private void handleTraining(RunContext context) { - int maxSteps = Math.min(getConf().getNstep(), getConf().getMaxEpochStep() - context.epochElapsedSteps); + int maxSteps = Math.min(getConf().getNstep(), getConf().getMaxEpochStep() - currentEpochStep); SubEpochReturn subEpochReturn = trainSubEpoch(context.obs, maxSteps); context.obs = subEpochReturn.getLastObs(); - context.epochElapsedSteps += subEpochReturn.getSteps(); context.rewards += subEpochReturn.getReward(); context.score = subEpochReturn.getScore(); } @@ -169,7 +169,6 @@ public abstract class AsyncThread, NN extends Ne context.obs = initMdp.getLastObs(); context.rewards = initMdp.getReward(); - context.epochElapsedSteps = initMdp.getSteps(); isEpochStarted = true; preEpoch(); @@ -180,9 +179,9 @@ public abstract class AsyncThread, NN extends Ne private boolean finishEpoch(RunContext context) { isEpochStarted = false; postEpoch(); - IDataManager.StatEntry statEntry = new AsyncStatEntry(getStepCounter(), epochCounter, context.rewards, context.epochElapsedSteps, context.score); + IDataManager.StatEntry statEntry = new AsyncStatEntry(getStepCounter(), epochCounter, context.rewards, currentEpochStep, context.score); - 
log.info("ThreadNum-" + threadNumber + " Epoch: " + getEpochCounter() + ", reward: " + context.rewards); + log.info("ThreadNum-" + threadNumber + " Epoch: " + getEpochCounter() + ", reward: " + context.rewards); return listeners.notifyEpochTrainingResult(this, statEntry); } @@ -205,37 +204,30 @@ public abstract class AsyncThread, NN extends Ne protected abstract SubEpochReturn trainSubEpoch(Observation obs, int nstep); private Learning.InitMdp refacInitMdp() { - LegacyMDPWrapper mdp = getLegacyMDPWrapper(); - IHistoryProcessor hp = getHistoryProcessor(); + currentEpochStep = 0; - Observation observation = mdp.reset(); - - int step = 0; double reward = 0; - boolean isHistoryProcessor = hp != null; - - int skipFrame = isHistoryProcessor ? hp.getConf().getSkipFrame() : 1; - int requiredFrame = isHistoryProcessor ? skipFrame * (hp.getConf().getHistoryLength() - 1) : 0; - - while (step < requiredFrame && !mdp.isDone()) { - - A action = mdp.getActionSpace().noOp(); //by convention should be the NO_OP + LegacyMDPWrapper mdp = getLegacyMDPWrapper(); + Observation observation = mdp.reset(); + A action = mdp.getActionSpace().noOp(); //by convention should be the NO_OP + while (observation.isSkipped() && !mdp.isDone()) { StepReply stepReply = mdp.step(action); + reward += stepReply.getReward(); observation = stepReply.getObservation(); - step++; - + incrementStep(); } - return new Learning.InitMdp(step, observation, reward); + return new Learning.InitMdp(0, observation, reward); } public void incrementStep() { ++stepCounter; + ++currentEpochStep; } @AllArgsConstructor @@ -260,7 +252,6 @@ public abstract class AsyncThread, NN extends Ne private static class RunContext { private Observation obs; private double rewards; - private int epochElapsedSteps; private double score; } diff --git a/rl4j/rl4j-core/src/main/java/org/deeplearning4j/rl4j/learning/async/AsyncThreadDiscrete.java b/rl4j/rl4j-core/src/main/java/org/deeplearning4j/rl4j/learning/async/AsyncThreadDiscrete.java 
index 6b0078883..0b3eb1c72 100644 --- a/rl4j/rl4j-core/src/main/java/org/deeplearning4j/rl4j/learning/async/AsyncThreadDiscrete.java +++ b/rl4j/rl4j-core/src/main/java/org/deeplearning4j/rl4j/learning/async/AsyncThreadDiscrete.java @@ -20,19 +20,14 @@ import lombok.Getter; import org.deeplearning4j.gym.StepReply; import org.deeplearning4j.nn.gradient.Gradient; import org.deeplearning4j.rl4j.learning.IHistoryProcessor; -import org.deeplearning4j.rl4j.learning.Learning; import org.deeplearning4j.rl4j.learning.listener.TrainingListenerList; -import org.deeplearning4j.rl4j.learning.sync.Transition; import org.deeplearning4j.rl4j.mdp.MDP; import org.deeplearning4j.rl4j.network.NeuralNet; import org.deeplearning4j.rl4j.observation.Observation; import org.deeplearning4j.rl4j.policy.IPolicy; import org.deeplearning4j.rl4j.space.DiscreteSpace; -import org.deeplearning4j.rl4j.space.Encodable; -import org.deeplearning4j.rl4j.util.LegacyMDPWrapper; import org.nd4j.linalg.api.ndarray.INDArray; import org.nd4j.linalg.factory.Nd4j; -import org.nd4j.linalg.util.ArrayUtil; import java.util.Stack; @@ -74,17 +69,18 @@ public abstract class AsyncThreadDiscrete IPolicy policy = getPolicy(current); Integer action; - Integer lastAction = null; + Integer lastAction = getMdp().getActionSpace().noOp(); IHistoryProcessor hp = getHistoryProcessor(); int skipFrame = hp != null ? 
hp.getConf().getSkipFrame() : 1; double reward = 0; double accuReward = 0; - int i = 0; - while (!getMdp().isDone() && i < nstep * skipFrame) { + int stepAtStart = getCurrentEpochStep(); + int lastStep = nstep * skipFrame + stepAtStart; + while (!getMdp().isDone() && getCurrentEpochStep() < lastStep) { //if step of training, just repeat lastAction - if (i % skipFrame != 0 && lastAction != null) { + if (obs.isSkipped()) { action = lastAction; } else { action = policy.nextAction(obs); @@ -94,7 +90,7 @@ public abstract class AsyncThreadDiscrete accuReward += stepReply.getReward() * getConf().getRewardFactor(); //if it's not a skipped frame, you can do a step of training - if (i % skipFrame == 0 || lastAction == null || stepReply.isDone()) { + if (!obs.isSkipped() || stepReply.isDone()) { INDArray[] output = current.outputAll(obs.getData()); rewards.add(new MiniTrans(obs.getData(), action, output, accuReward)); @@ -106,7 +102,6 @@ public abstract class AsyncThreadDiscrete reward += stepReply.getReward(); - i++; incrementStep(); lastAction = action; } @@ -114,7 +109,7 @@ public abstract class AsyncThreadDiscrete //a bit of a trick usable because of how the stack is treated to init R // FIXME: The last element of minitrans is only used to seed the reward in calcGradient; observation, action and output are ignored. 
- if (getMdp().isDone() && i < nstep * skipFrame) + if (getMdp().isDone() && getCurrentEpochStep() < lastStep) rewards.add(new MiniTrans(obs.getData(), null, null, 0)); else { INDArray[] output = null; @@ -127,9 +122,9 @@ public abstract class AsyncThreadDiscrete rewards.add(new MiniTrans(obs.getData(), null, output, maxQ)); } - getAsyncGlobal().enqueue(calcGradient(current, rewards), i); + getAsyncGlobal().enqueue(calcGradient(current, rewards), getCurrentEpochStep()); - return new SubEpochReturn(i, obs, reward, current.getLatestScore()); + return new SubEpochReturn(getCurrentEpochStep() - stepAtStart, obs, reward, current.getLatestScore()); } public abstract Gradient[] calcGradient(NN nn, Stack> rewards); diff --git a/rl4j/rl4j-core/src/main/java/org/deeplearning4j/rl4j/learning/sync/SyncLearning.java b/rl4j/rl4j-core/src/main/java/org/deeplearning4j/rl4j/learning/sync/SyncLearning.java index ed5d73a75..22d936fcf 100644 --- a/rl4j/rl4j-core/src/main/java/org/deeplearning4j/rl4j/learning/sync/SyncLearning.java +++ b/rl4j/rl4j-core/src/main/java/org/deeplearning4j/rl4j/learning/sync/SyncLearning.java @@ -17,7 +17,6 @@ package org.deeplearning4j.rl4j.learning.sync; import lombok.Getter; -import lombok.Setter; import lombok.extern.slf4j.Slf4j; import org.deeplearning4j.rl4j.learning.IEpochTrainer; import org.deeplearning4j.rl4j.learning.ILearning; @@ -25,7 +24,6 @@ import org.deeplearning4j.rl4j.learning.Learning; import org.deeplearning4j.rl4j.learning.listener.*; import org.deeplearning4j.rl4j.network.NeuralNet; import org.deeplearning4j.rl4j.space.ActionSpace; -import org.deeplearning4j.rl4j.space.Encodable; import org.deeplearning4j.rl4j.util.IDataManager; /** diff --git a/rl4j/rl4j-core/src/main/java/org/deeplearning4j/rl4j/learning/sync/Transition.java b/rl4j/rl4j-core/src/main/java/org/deeplearning4j/rl4j/learning/sync/Transition.java index 509b28a88..3c35c47c7 100644 --- a/rl4j/rl4j-core/src/main/java/org/deeplearning4j/rl4j/learning/sync/Transition.java +++ 
b/rl4j/rl4j-core/src/main/java/org/deeplearning4j/rl4j/learning/sync/Transition.java @@ -16,6 +16,7 @@ package org.deeplearning4j.rl4j.learning.sync; +import lombok.Data; import lombok.Value; import org.deeplearning4j.rl4j.observation.Observation; import org.nd4j.linalg.api.ndarray.INDArray; @@ -34,7 +35,7 @@ import java.util.List; * @author Alexandre Boulanger * */ -@Value +@Data public class Transition { Observation observation; @@ -43,12 +44,15 @@ public class Transition { boolean isTerminal; INDArray nextObservation; - public Transition(Observation observation, A action, double reward, boolean isTerminal, Observation nextObservation) { + public Transition(Observation observation, A action, double reward, boolean isTerminal) { this.observation = observation; this.action = action; this.reward = reward; this.isTerminal = isTerminal; + this.nextObservation = null; + } + public void setNextObservation(Observation nextObservation) { // To conserve memory, only the most recent frame of the next observation is kept (if history is used). // The full nextObservation will be re-build from observation when needed. 
long[] nextObservationShape = nextObservation.getData().shape().clone(); diff --git a/rl4j/rl4j-core/src/main/java/org/deeplearning4j/rl4j/learning/sync/qlearning/QLearning.java b/rl4j/rl4j-core/src/main/java/org/deeplearning4j/rl4j/learning/sync/qlearning/QLearning.java index bfd23ef5b..0757043f0 100644 --- a/rl4j/rl4j-core/src/main/java/org/deeplearning4j/rl4j/learning/sync/qlearning/QLearning.java +++ b/rl4j/rl4j-core/src/main/java/org/deeplearning4j/rl4j/learning/sync/qlearning/QLearning.java @@ -21,8 +21,7 @@ import com.fasterxml.jackson.databind.annotation.JsonPOJOBuilder; import lombok.*; import lombok.extern.slf4j.Slf4j; import org.deeplearning4j.gym.StepReply; -import org.deeplearning4j.rl4j.learning.IHistoryProcessor; -import org.deeplearning4j.rl4j.learning.Learning; +import org.deeplearning4j.rl4j.learning.EpochStepCounter; import org.deeplearning4j.rl4j.learning.sync.ExpReplay; import org.deeplearning4j.rl4j.learning.sync.IExpReplay; import org.deeplearning4j.rl4j.learning.sync.SyncLearning; @@ -34,9 +33,8 @@ import org.deeplearning4j.rl4j.space.ActionSpace; import org.deeplearning4j.rl4j.space.Encodable; import org.deeplearning4j.rl4j.util.IDataManager.StatEntry; import org.deeplearning4j.rl4j.util.LegacyMDPWrapper; -import org.nd4j.linalg.api.ndarray.INDArray; -import org.nd4j.linalg.factory.Nd4j; import org.nd4j.linalg.api.rng.Random; +import org.nd4j.linalg.factory.Nd4j; import java.util.ArrayList; import java.util.List; @@ -49,7 +47,8 @@ import java.util.List; */ @Slf4j public abstract class QLearning> - extends SyncLearning implements TargetQNetworkSource { + extends SyncLearning + implements TargetQNetworkSource, EpochStepCounter { // FIXME Changed for refac // @Getter @@ -104,18 +103,22 @@ public abstract class QLearning trainStep(Observation obs); + @Getter + private int currentEpochStep = 0; + protected StatEntry trainEpoch() { + resetNetworks(); + InitMdp initMdp = refacInitMdp(); Observation obs = initMdp.getLastObs(); double reward = 
initMdp.getReward(); - int step = initMdp.getSteps(); Double startQ = Double.NaN; double meanQ = 0; int numQ = 0; List scores = new ArrayList<>(); - while (step < getConfiguration().getMaxEpochStep() && !getMdp().isDone()) { + while (currentEpochStep < getConfiguration().getMaxEpochStep() && !getMdp().isDone()) { if (getStepCounter() % getConfiguration().getTargetDqnUpdateFreq() == 0) { updateTargetNetwork(); @@ -136,49 +139,53 @@ public abstract class QLearning refacInitMdp() { - getQNetwork().reset(); - getTargetQNetwork().reset(); + currentEpochStep = 0; - LegacyMDPWrapper mdp = getLegacyMDPWrapper(); - IHistoryProcessor hp = getHistoryProcessor(); - - Observation observation = mdp.reset(); - - int step = 0; double reward = 0; - boolean isHistoryProcessor = hp != null; - - int skipFrame = isHistoryProcessor ? hp.getConf().getSkipFrame() : 1; - int requiredFrame = isHistoryProcessor ? skipFrame * (hp.getConf().getHistoryLength() - 1) : 0; - - while (step < requiredFrame && !mdp.isDone()) { - - A action = mdp.getActionSpace().noOp(); //by convention should be the NO_OP + LegacyMDPWrapper mdp = getLegacyMDPWrapper(); + Observation observation = mdp.reset(); + A action = mdp.getActionSpace().noOp(); //by convention should be the NO_OP + while (observation.isSkipped() && !mdp.isDone()) { StepReply stepReply = mdp.step(action); + reward += stepReply.getReward(); observation = stepReply.getObservation(); - step++; - + incrementStep(); } - return new InitMdp(step, observation, reward); + return new InitMdp(0, observation, reward); } diff --git a/rl4j/rl4j-core/src/main/java/org/deeplearning4j/rl4j/learning/sync/qlearning/discrete/QLearningDiscrete.java b/rl4j/rl4j-core/src/main/java/org/deeplearning4j/rl4j/learning/sync/qlearning/discrete/QLearningDiscrete.java index 9da30ccef..0fa0f33e6 100644 --- a/rl4j/rl4j-core/src/main/java/org/deeplearning4j/rl4j/learning/sync/qlearning/discrete/QLearningDiscrete.java +++ 
b/rl4j/rl4j-core/src/main/java/org/deeplearning4j/rl4j/learning/sync/qlearning/discrete/QLearningDiscrete.java @@ -20,10 +20,13 @@ import lombok.AccessLevel; import lombok.Getter; import lombok.Setter; import org.deeplearning4j.gym.StepReply; +import org.deeplearning4j.rl4j.learning.IHistoryProcessor; import org.deeplearning4j.rl4j.learning.Learning; import org.deeplearning4j.rl4j.learning.sync.Transition; import org.deeplearning4j.rl4j.learning.sync.qlearning.QLearning; -import org.deeplearning4j.rl4j.learning.sync.qlearning.discrete.TDTargetAlgorithm.*; +import org.deeplearning4j.rl4j.learning.sync.qlearning.discrete.TDTargetAlgorithm.DoubleDQN; +import org.deeplearning4j.rl4j.learning.sync.qlearning.discrete.TDTargetAlgorithm.ITDTargetAlgorithm; +import org.deeplearning4j.rl4j.learning.sync.qlearning.discrete.TDTargetAlgorithm.StandardDQN; import org.deeplearning4j.rl4j.mdp.MDP; import org.deeplearning4j.rl4j.network.dqn.IDQN; import org.deeplearning4j.rl4j.observation.Observation; @@ -36,7 +39,6 @@ import org.nd4j.linalg.api.ndarray.INDArray; import org.nd4j.linalg.api.rng.Random; import org.nd4j.linalg.dataset.api.DataSet; import org.nd4j.linalg.factory.Nd4j; -import org.nd4j.linalg.util.ArrayUtil; import java.util.ArrayList; @@ -68,6 +70,8 @@ public abstract class QLearningDiscrete extends QLearning getLegacyMDPWrapper() { @@ -83,7 +87,7 @@ public abstract class QLearningDiscrete extends QLearning(mdp, this); + this.mdp = new LegacyMDPWrapper(mdp, null, this); qNetwork = dqn; targetQNetwork = dqn.clone(); policy = new DQNPolicy(getQNetwork()); @@ -108,8 +112,15 @@ public abstract class QLearningDiscrete extends QLearning extends QLearning trainStep(Observation obs) { Integer action; + boolean isHistoryProcessor = getHistoryProcessor() != null; - - int skipFrame = isHistoryProcessor ? getHistoryProcessor().getConf().getSkipFrame() : 1; int historyLength = isHistoryProcessor ? 
getHistoryProcessor().getConf().getHistoryLength() : 1; int updateStart = getConfiguration().getUpdateStart() @@ -131,7 +141,7 @@ public abstract class QLearningDiscrete extends QLearning extends QLearning stepReply = mdp.step(action); - Observation nextObservation = stepReply.getObservation(); - accuReward += stepReply.getReward() * configuration.getRewardFactor(); //if it's not a skipped frame, you can do a step of training - if (getStepCounter() % skipFrame == 0 || stepReply.isDone()) { + if (!obs.isSkipped() || stepReply.isDone()) { - Transition trans = new Transition(obs, action, accuReward, stepReply.isDone(), nextObservation); - getExpReplay().store(trans); + // Add experience + if(pendingTransition != null) { + pendingTransition.setNextObservation(obs); + getExpReplay().store(pendingTransition); + } + pendingTransition = new Transition(obs, action, accuReward, stepReply.isDone()); + accuReward = 0; + // Update NN + // FIXME: maybe start updating when experience replay has reached a certain size instead of using "updateStart"? 
if (getStepCounter() > updateStart) { DataSet targets = setTarget(getExpReplay().getBatch()); getQNetwork().fit(targets.getFeatures(), targets.getLabels()); } - - accuReward = 0; } return new QLStepReturn(maxQ, getQNetwork().getLatestScore(), stepReply); @@ -172,4 +185,12 @@ public abstract class QLearningDiscrete extends QLearning xThreshold + done |= x < -xThreshold || x > xThreshold || theta < -thetaThresholdRadians || theta > thetaThresholdRadians; double reward; diff --git a/rl4j/rl4j-core/src/main/java/org/deeplearning4j/rl4j/observation/Observation.java b/rl4j/rl4j-core/src/main/java/org/deeplearning4j/rl4j/observation/Observation.java index 197a8f744..af1fe18ea 100644 --- a/rl4j/rl4j-core/src/main/java/org/deeplearning4j/rl4j/observation/Observation.java +++ b/rl4j/rl4j-core/src/main/java/org/deeplearning4j/rl4j/observation/Observation.java @@ -16,6 +16,8 @@ package org.deeplearning4j.rl4j.observation; +import lombok.Getter; +import lombok.Setter; import org.nd4j.linalg.api.ndarray.INDArray; import org.nd4j.linalg.dataset.api.DataSet; import org.nd4j.linalg.factory.Nd4j; @@ -28,6 +30,9 @@ public class Observation { private final DataSet data; + @Getter @Setter + private boolean skipped; + public Observation(INDArray[] data) { this(data, false); } diff --git a/rl4j/rl4j-core/src/main/java/org/deeplearning4j/rl4j/policy/EpsGreedy.java b/rl4j/rl4j-core/src/main/java/org/deeplearning4j/rl4j/policy/EpsGreedy.java index 732ce1a0e..3ed375084 100644 --- a/rl4j/rl4j-core/src/main/java/org/deeplearning4j/rl4j/policy/EpsGreedy.java +++ b/rl4j/rl4j-core/src/main/java/org/deeplearning4j/rl4j/policy/EpsGreedy.java @@ -18,7 +18,8 @@ package org.deeplearning4j.rl4j.policy; import lombok.AllArgsConstructor; import lombok.extern.slf4j.Slf4j; -import org.deeplearning4j.rl4j.learning.StepCountable; +import org.deeplearning4j.rl4j.learning.IEpochTrainer; +import org.deeplearning4j.rl4j.learning.ILearning; import org.deeplearning4j.rl4j.mdp.MDP; import 
org.deeplearning4j.rl4j.network.NeuralNet; import org.deeplearning4j.rl4j.observation.Observation; @@ -46,7 +47,7 @@ public class EpsGreedy> extends Policy { final private int epsilonNbStep; final private Random rnd; final private float minEpsilon; - final private StepCountable learning; + final private IEpochTrainer learning; public NeuralNet getNeuralNet() { return policy.getNeuralNet(); diff --git a/rl4j/rl4j-core/src/main/java/org/deeplearning4j/rl4j/policy/Policy.java b/rl4j/rl4j-core/src/main/java/org/deeplearning4j/rl4j/policy/Policy.java index 97a11b99c..fb20a60ac 100644 --- a/rl4j/rl4j-core/src/main/java/org/deeplearning4j/rl4j/policy/Policy.java +++ b/rl4j/rl4j-core/src/main/java/org/deeplearning4j/rl4j/policy/Policy.java @@ -19,20 +19,15 @@ package org.deeplearning4j.rl4j.policy; import lombok.Getter; import lombok.Setter; import org.deeplearning4j.gym.StepReply; +import org.deeplearning4j.rl4j.learning.EpochStepCounter; import org.deeplearning4j.rl4j.learning.HistoryProcessor; import org.deeplearning4j.rl4j.learning.IHistoryProcessor; import org.deeplearning4j.rl4j.learning.Learning; -import org.deeplearning4j.rl4j.learning.StepCountable; -import org.deeplearning4j.rl4j.learning.sync.Transition; import org.deeplearning4j.rl4j.mdp.MDP; import org.deeplearning4j.rl4j.network.NeuralNet; import org.deeplearning4j.rl4j.observation.Observation; import org.deeplearning4j.rl4j.space.ActionSpace; -import org.deeplearning4j.rl4j.space.DiscreteSpace; -import org.deeplearning4j.rl4j.space.Encodable; import org.deeplearning4j.rl4j.util.LegacyMDPWrapper; -import org.nd4j.linalg.api.ndarray.INDArray; -import org.nd4j.linalg.util.ArrayUtil; /** * @author rubenfiszel (ruben.fiszel@epfl.ch) 7/18/16. 
@@ -57,24 +52,22 @@ public abstract class Policy implements IPolicy { @Override public > double play(MDP mdp, IHistoryProcessor hp) { - RefacStepCountable stepCountable = new RefacStepCountable(); - LegacyMDPWrapper mdpWrapper = new LegacyMDPWrapper(mdp, hp, stepCountable); + resetNetworks(); - boolean isHistoryProcessor = hp != null; - int skipFrame = isHistoryProcessor ? hp.getConf().getSkipFrame() : 1; + RefacEpochStepCounter epochStepCounter = new RefacEpochStepCounter(); + LegacyMDPWrapper mdpWrapper = new LegacyMDPWrapper(mdp, hp, epochStepCounter); - Learning.InitMdp initMdp = refacInitMdp(mdpWrapper, hp); + Learning.InitMdp initMdp = refacInitMdp(mdpWrapper, hp, epochStepCounter); Observation obs = initMdp.getLastObs(); double reward = initMdp.getReward(); A lastAction = mdpWrapper.getActionSpace().noOp(); A action; - stepCountable.setStepCounter(initMdp.getSteps()); while (!mdpWrapper.isDone()) { - if (stepCountable.getStepCounter() % skipFrame != 0) { + if (obs.isSkipped()) { action = lastAction; } else { action = nextAction(obs); @@ -86,52 +79,46 @@ public abstract class Policy implements IPolicy { reward += stepReply.getReward(); obs = stepReply.getObservation(); - stepCountable.increment(); + epochStepCounter.incrementEpochStep(); } return reward; } - private > Learning.InitMdp refacInitMdp(LegacyMDPWrapper mdpWrapper, IHistoryProcessor hp) { + protected void resetNetworks() { getNeuralNet().reset(); - Observation observation = mdpWrapper.reset(); + } + + private > Learning.InitMdp refacInitMdp(LegacyMDPWrapper mdpWrapper, IHistoryProcessor hp, RefacEpochStepCounter epochStepCounter) { + epochStepCounter.setCurrentEpochStep(0); - int step = 0; double reward = 0; - boolean isHistoryProcessor = hp != null; + Observation observation = mdpWrapper.reset(); - int skipFrame = isHistoryProcessor ? hp.getConf().getSkipFrame() : 1; - int requiredFrame = isHistoryProcessor ? 
skipFrame * (hp.getConf().getHistoryLength() - 1) : 0; - - while (step < requiredFrame && !mdpWrapper.isDone()) { - - A action = mdpWrapper.getActionSpace().noOp(); //by convention should be the NO_OP + A action = mdpWrapper.getActionSpace().noOp(); //by convention should be the NO_OP + while (observation.isSkipped() && !mdpWrapper.isDone()) { StepReply stepReply = mdpWrapper.step(action); + reward += stepReply.getReward(); observation = stepReply.getObservation(); - step++; - + epochStepCounter.incrementEpochStep(); } - return new Learning.InitMdp(step, observation, reward); + return new Learning.InitMdp(0, observation, reward); } - private class RefacStepCountable implements StepCountable { + public class RefacEpochStepCounter implements EpochStepCounter { @Getter @Setter - private int stepCounter = 0; + private int currentEpochStep = 0; - public void increment() { - ++stepCounter; + public void incrementEpochStep() { + ++currentEpochStep; } - @Override - public int getStepCounter() { - return 0; - } } } diff --git a/rl4j/rl4j-core/src/main/java/org/deeplearning4j/rl4j/util/LegacyMDPWrapper.java b/rl4j/rl4j-core/src/main/java/org/deeplearning4j/rl4j/util/LegacyMDPWrapper.java index dbcd38ddc..26546a923 100644 --- a/rl4j/rl4j-core/src/main/java/org/deeplearning4j/rl4j/util/LegacyMDPWrapper.java +++ b/rl4j/rl4j-core/src/main/java/org/deeplearning4j/rl4j/util/LegacyMDPWrapper.java @@ -1,10 +1,11 @@ package org.deeplearning4j.rl4j.util; +import lombok.AccessLevel; import lombok.Getter; +import lombok.Setter; import org.deeplearning4j.gym.StepReply; +import org.deeplearning4j.rl4j.learning.EpochStepCounter; import org.deeplearning4j.rl4j.learning.IHistoryProcessor; -import org.deeplearning4j.rl4j.learning.ILearning; -import org.deeplearning4j.rl4j.learning.StepCountable; import org.deeplearning4j.rl4j.mdp.MDP; import org.deeplearning4j.rl4j.observation.Observation; import org.deeplearning4j.rl4j.space.ActionSpace; @@ -19,50 +20,20 @@ public class LegacyMDPWrapper> 
implements MDP wrappedMDP; @Getter private final WrapperObservationSpace observationSpace; - private final ILearning learning; + + @Getter(AccessLevel.PRIVATE) @Setter(AccessLevel.PUBLIC) private IHistoryProcessor historyProcessor; - private final StepCountable stepCountable; - private int skipFrame; - private int step = 0; + private final EpochStepCounter epochStepCounter; - public LegacyMDPWrapper(MDP wrappedMDP, ILearning learning) { - this(wrappedMDP, learning, null, null); - } + private int skipFrame = 1; + private int requiredFrame = 0; - public LegacyMDPWrapper(MDP wrappedMDP, IHistoryProcessor historyProcessor, StepCountable stepCountable) { - this(wrappedMDP, null, historyProcessor, stepCountable); - } - - private LegacyMDPWrapper(MDP wrappedMDP, ILearning learning, IHistoryProcessor historyProcessor, StepCountable stepCountable) { + public LegacyMDPWrapper(MDP wrappedMDP, IHistoryProcessor historyProcessor, EpochStepCounter epochStepCounter) { this.wrappedMDP = wrappedMDP; this.observationSpace = new WrapperObservationSpace(wrappedMDP.getObservationSpace().getShape()); - this.learning = learning; this.historyProcessor = historyProcessor; - this.stepCountable = stepCountable; - } - - private IHistoryProcessor getHistoryProcessor() { - if(historyProcessor != null) { - return historyProcessor; - } - - if (learning != null) { - return learning.getHistoryProcessor(); - } - return null; - } - - public void setHistoryProcessor(IHistoryProcessor historyProcessor) { - this.historyProcessor = historyProcessor; - } - - private int getStep() { - if(stepCountable != null) { - return stepCountable.getStepCounter(); - } - - return learning.getStepCounter(); + this.epochStepCounter = epochStepCounter; } @Override @@ -83,9 +54,12 @@ public class LegacyMDPWrapper> implements MDP> implements MDP rawStepReply = wrappedMDP.step(a); INDArray rawObservation = getInput(rawStepReply.getObservation()); - ++step; + int stepOfObservation = epochStepCounter.getCurrentEpochStep() + 1; 
- int requiredFrame = 0; if(historyProcessor != null) { historyProcessor.record(rawObservation); - requiredFrame = skipFrame * (historyProcessor.getConf().getHistoryLength() - 1); - if ((getStep() % skipFrame == 0 && step >= requiredFrame) - || (step % skipFrame == 0 && step < requiredFrame )){ + if (stepOfObservation % skipFrame == 0) { historyProcessor.add(rawObservation); } } Observation observation; - if(historyProcessor != null && step >= requiredFrame) { + if(historyProcessor != null && stepOfObservation >= requiredFrame) { observation = new Observation(historyProcessor.getHistory(), true); observation.getData().muli(1.0 / historyProcessor.getScale()); } @@ -119,6 +90,10 @@ public class LegacyMDPWrapper> implements MDP(observation, rawStepReply.getReward(), rawStepReply.isDone(), rawStepReply.getInfo()); } @@ -134,7 +109,7 @@ public class LegacyMDPWrapper> implements MDP newInstance() { - return new LegacyMDPWrapper(wrappedMDP.newInstance(), learning); + return new LegacyMDPWrapper(wrappedMDP.newInstance(), historyProcessor, epochStepCounter); } private INDArray getInput(O obs) { diff --git a/rl4j/rl4j-core/src/test/java/org/deeplearning4j/rl4j/learning/async/AsyncThreadDiscreteTest.java b/rl4j/rl4j-core/src/test/java/org/deeplearning4j/rl4j/learning/async/AsyncThreadDiscreteTest.java index b94541159..8a0090b62 100644 --- a/rl4j/rl4j-core/src/test/java/org/deeplearning4j/rl4j/learning/async/AsyncThreadDiscreteTest.java +++ b/rl4j/rl4j-core/src/test/java/org/deeplearning4j/rl4j/learning/async/AsyncThreadDiscreteTest.java @@ -32,7 +32,7 @@ public class AsyncThreadDiscreteTest { MockMDP mdpMock = new MockMDP(observationSpace); TrainingListenerList listeners = new TrainingListenerList(); MockPolicy policyMock = new MockPolicy(); - MockAsyncConfiguration config = new MockAsyncConfiguration(5, 16, 0, 0, 2, 5,0, 0, 0, 0); + MockAsyncConfiguration config = new MockAsyncConfiguration(5, 100, 0, 0, 2, 5,0, 0, 0, 0); TestAsyncThreadDiscrete sut = new 
TestAsyncThreadDiscrete(asyncGlobalMock, mdpMock, listeners, 0, 0, policyMock, config, hpMock); // Act @@ -41,8 +41,8 @@ public class AsyncThreadDiscreteTest { // Assert assertEquals(2, sut.trainSubEpochResults.size()); double[][] expectedLastObservations = new double[][] { - new double[] { 4.0, 6.0, 8.0, 9.0, 11.0 }, - new double[] { 8.0, 9.0, 11.0, 13.0, 15.0 }, + new double[] { 4.0, 6.0, 8.0, 10.0, 12.0 }, + new double[] { 8.0, 10.0, 12.0, 14.0, 16.0 }, }; double[] expectedSubEpochReturnRewards = new double[] { 42.0, 58.0 }; for(int i = 0; i < 2; ++i) { @@ -60,7 +60,7 @@ public class AsyncThreadDiscreteTest { assertEquals(2, asyncGlobalMock.enqueueCallCount); // HistoryProcessor - double[] expectedAddValues = new double[] { 0.0, 2.0, 4.0, 6.0, 8.0, 9.0, 11.0, 13.0, 15.0 }; + double[] expectedAddValues = new double[] { 0.0, 2.0, 4.0, 6.0, 8.0, 10.0, 12.0, 14.0, 16.0 }; assertEquals(expectedAddValues.length, hpMock.addCalls.size()); for(int i = 0; i < expectedAddValues.length; ++i) { assertEquals(expectedAddValues[i], hpMock.addCalls.get(i).getDouble(0), 0.00001); @@ -75,9 +75,9 @@ public class AsyncThreadDiscreteTest { // Policy double[][] expectedPolicyInputs = new double[][] { new double[] { 0.0, 2.0, 4.0, 6.0, 8.0 }, - new double[] { 2.0, 4.0, 6.0, 8.0, 9.0 }, - new double[] { 4.0, 6.0, 8.0, 9.0, 11.0 }, - new double[] { 6.0, 8.0, 9.0, 11.0, 13.0 }, + new double[] { 2.0, 4.0, 6.0, 8.0, 10.0 }, + new double[] { 4.0, 6.0, 8.0, 10.0, 12.0 }, + new double[] { 6.0, 8.0, 10.0, 12.0, 14.0 }, }; assertEquals(expectedPolicyInputs.length, policyMock.actionInputs.size()); for(int i = 0; i < expectedPolicyInputs.length; ++i) { @@ -93,11 +93,11 @@ public class AsyncThreadDiscreteTest { assertEquals(2, nnMock.copyCallCount); double[][] expectedNNInputs = new double[][] { new double[] { 0.0, 2.0, 4.0, 6.0, 8.0 }, - new double[] { 2.0, 4.0, 6.0, 8.0, 9.0 }, - new double[] { 4.0, 6.0, 8.0, 9.0, 11.0 }, // FIXME: This one comes from the computation of output of the last 
minitrans - new double[] { 4.0, 6.0, 8.0, 9.0, 11.0 }, - new double[] { 6.0, 8.0, 9.0, 11.0, 13.0 }, - new double[] { 8.0, 9.0, 11.0, 13.0, 15.0 }, // FIXME: This one comes from the computation of output of the last minitrans + new double[] { 2.0, 4.0, 6.0, 8.0, 10.0 }, + new double[] { 4.0, 6.0, 8.0, 10.0, 12.0 }, // FIXME: This one comes from the computation of output of the last minitrans + new double[] { 4.0, 6.0, 8.0, 10.0, 12.0 }, + new double[] { 6.0, 8.0, 10.0, 12.0, 14.0 }, + new double[] { 8.0, 10.0, 12.0, 14.0, 16.0 }, // FIXME: This one comes from the computation of output of the last minitrans }; assertEquals(expectedNNInputs.length, nnMock.outputAllInputs.size()); for(int i = 0; i < expectedNNInputs.length; ++i) { @@ -113,13 +113,13 @@ public class AsyncThreadDiscreteTest { double[][][] expectedMinitransObs = new double[][][] { new double[][] { new double[] { 0.0, 2.0, 4.0, 6.0, 8.0 }, - new double[] { 2.0, 4.0, 6.0, 8.0, 9.0 }, - new double[] { 4.0, 6.0, 8.0, 9.0, 11.0 }, // FIXME: The last minitrans contains the next observation + new double[] { 2.0, 4.0, 6.0, 8.0, 10.0 }, + new double[] { 4.0, 6.0, 8.0, 10.0, 12.0 }, // FIXME: The last minitrans contains the next observation }, new double[][] { - new double[] { 4.0, 6.0, 8.0, 9.0, 11.0 }, - new double[] { 6.0, 8.0, 9.0, 11.0, 13.0 }, - new double[] { 8.0, 9.0, 11.0, 13.0, 15 }, // FIXME: The last minitrans contains the next observation + new double[] { 4.0, 6.0, 8.0, 10.0, 12.0 }, + new double[] { 6.0, 8.0, 10.0, 12.0, 14.0 }, + new double[] { 8.0, 10.0, 12.0, 14.0, 16.0 }, // FIXME: The last minitrans contains the next observation } }; double[] expectedOutputs = new double[] { 1.0, 2.0, 3.0, 4.0, 5.0, 6.0 }; diff --git a/rl4j/rl4j-core/src/test/java/org/deeplearning4j/rl4j/learning/async/AsyncThreadTest.java b/rl4j/rl4j-core/src/test/java/org/deeplearning4j/rl4j/learning/async/AsyncThreadTest.java index 377b32175..b01105294 100644 --- 
a/rl4j/rl4j-core/src/test/java/org/deeplearning4j/rl4j/learning/async/AsyncThreadTest.java +++ b/rl4j/rl4j-core/src/test/java/org/deeplearning4j/rl4j/learning/async/AsyncThreadTest.java @@ -5,15 +5,12 @@ import lombok.Getter; import org.deeplearning4j.rl4j.learning.IHistoryProcessor; import org.deeplearning4j.rl4j.learning.listener.TrainingListenerList; import org.deeplearning4j.rl4j.mdp.MDP; -import org.deeplearning4j.rl4j.network.NeuralNet; import org.deeplearning4j.rl4j.observation.Observation; import org.deeplearning4j.rl4j.policy.Policy; import org.deeplearning4j.rl4j.space.DiscreteSpace; -import org.deeplearning4j.rl4j.space.Encodable; import org.deeplearning4j.rl4j.support.*; import org.deeplearning4j.rl4j.util.IDataManager; import org.junit.Test; -import org.nd4j.linalg.api.ndarray.INDArray; import java.util.ArrayList; import java.util.List; @@ -91,7 +88,7 @@ public class AsyncThreadTest { // Assert assertEquals(numberOfEpochs, context.listener.statEntries.size()); - int[] expectedStepCounter = new int[] { 2, 4, 6, 8, 10 }; + int[] expectedStepCounter = new int[] { 10, 20, 30, 40, 50 }; double expectedReward = (1.0 + 2.0 + 3.0 + 4.0 + 5.0 + 6.0 + 7.0 + 8.0) // reward from init + 1.0; // Reward from trainSubEpoch() for(int i = 0; i < numberOfEpochs; ++i) { @@ -114,7 +111,7 @@ public class AsyncThreadTest { // Assert assertEquals(numberOfEpochs, context.sut.trainSubEpochParams.size()); double[] expectedObservation = new double[] { 0.0, 2.0, 4.0, 6.0, 8.0 }; - for(int i = 0; i < context.sut.getEpochCounter(); ++i) { + for(int i = 0; i < context.sut.trainSubEpochParams.size(); ++i) { MockAsyncThread.TrainSubEpochParams params = context.sut.trainSubEpochParams.get(i); assertEquals(2, params.nstep); assertEquals(expectedObservation.length, params.obs.getData().shape()[1]); @@ -199,7 +196,9 @@ public class AsyncThreadTest { protected SubEpochReturn trainSubEpoch(Observation obs, int nstep) { asyncGlobal.increaseCurrentLoop(); trainSubEpochParams.add(new 
TrainSubEpochParams(obs, nstep)); - setStepCounter(getStepCounter() + nstep); + for(int i = 0; i < nstep; ++i) { + incrementStep(); + } return new SubEpochReturn(nstep, null, 1.0, 1.0); } diff --git a/rl4j/rl4j-core/src/test/java/org/deeplearning4j/rl4j/learning/sync/ExpReplayTest.java b/rl4j/rl4j-core/src/test/java/org/deeplearning4j/rl4j/learning/sync/ExpReplayTest.java index 373c4b189..73b27776a 100644 --- a/rl4j/rl4j-core/src/test/java/org/deeplearning4j/rl4j/learning/sync/ExpReplayTest.java +++ b/rl4j/rl4j-core/src/test/java/org/deeplearning4j/rl4j/learning/sync/ExpReplayTest.java @@ -18,8 +18,8 @@ public class ExpReplayTest { ExpReplay sut = new ExpReplay(2, 1, randomMock); // Act - Transition transition = new Transition(new Observation(new INDArray[] { Nd4j.create(1) }), - 123, 234, false, new Observation(Nd4j.create(1))); + Transition transition = buildTransition(new Observation(new INDArray[] { Nd4j.create(1) }), + 123, 234, new Observation(Nd4j.create(1))); sut.store(transition); List> results = sut.getBatch(1); @@ -36,12 +36,12 @@ public class ExpReplayTest { ExpReplay sut = new ExpReplay(2, 1, randomMock); // Act - Transition transition1 = new Transition(new Observation(new INDArray[] { Nd4j.create(1) }), - 1, 2, false, new Observation(Nd4j.create(1))); - Transition transition2 = new Transition(new Observation(new INDArray[] { Nd4j.create(1) }), - 3, 4, false, new Observation(Nd4j.create(1))); - Transition transition3 = new Transition(new Observation(new INDArray[] { Nd4j.create(1) }), - 5, 6, false, new Observation(Nd4j.create(1))); + Transition transition1 = buildTransition(new Observation(new INDArray[] { Nd4j.create(1) }), + 1, 2, new Observation(Nd4j.create(1))); + Transition transition2 = buildTransition(new Observation(new INDArray[] { Nd4j.create(1) }), + 3, 4, new Observation(Nd4j.create(1))); + Transition transition3 = buildTransition(new Observation(new INDArray[] { Nd4j.create(1) }), + 5, 6, new Observation(Nd4j.create(1))); 
sut.store(transition1); sut.store(transition2); sut.store(transition3); @@ -78,12 +78,12 @@ public class ExpReplayTest { ExpReplay sut = new ExpReplay(5, 1, randomMock); // Act - Transition transition1 = new Transition(new Observation(new INDArray[] { Nd4j.create(1) }), - 1, 2, false, new Observation(Nd4j.create(1))); - Transition transition2 = new Transition(new Observation(new INDArray[] { Nd4j.create(1) }), - 3, 4, false, new Observation(Nd4j.create(1))); - Transition transition3 = new Transition(new Observation(new INDArray[] { Nd4j.create(1) }), - 5, 6, false, new Observation(Nd4j.create(1))); + Transition transition1 = buildTransition(new Observation(new INDArray[] { Nd4j.create(1) }), + 1, 2, new Observation(Nd4j.create(1))); + Transition transition2 = buildTransition(new Observation(new INDArray[] { Nd4j.create(1) }), + 3, 4, new Observation(Nd4j.create(1))); + Transition transition3 = buildTransition(new Observation(new INDArray[] { Nd4j.create(1) }), + 5, 6, new Observation(Nd4j.create(1))); sut.store(transition1); sut.store(transition2); sut.store(transition3); @@ -100,12 +100,12 @@ public class ExpReplayTest { ExpReplay sut = new ExpReplay(5, 1, randomMock); // Act - Transition transition1 = new Transition(new Observation(new INDArray[] { Nd4j.create(1) }), - 1, 2, false, new Observation(Nd4j.create(1))); - Transition transition2 = new Transition(new Observation(new INDArray[] { Nd4j.create(1) }), - 3, 4, false, new Observation(Nd4j.create(1))); - Transition transition3 = new Transition(new Observation(new INDArray[] { Nd4j.create(1) }), - 5, 6, false, new Observation(Nd4j.create(1))); + Transition transition1 = buildTransition(new Observation(new INDArray[] { Nd4j.create(1) }), + 1, 2, new Observation(Nd4j.create(1))); + Transition transition2 = buildTransition(new Observation(new INDArray[] { Nd4j.create(1) }), + 3, 4, new Observation(Nd4j.create(1))); + Transition transition3 = buildTransition(new Observation(new INDArray[] { Nd4j.create(1) }), + 5, 
6, new Observation(Nd4j.create(1))); sut.store(transition1); sut.store(transition2); sut.store(transition3); @@ -131,16 +131,16 @@ public class ExpReplayTest { ExpReplay sut = new ExpReplay(5, 1, randomMock); // Act - Transition transition1 = new Transition(new Observation(new INDArray[] { Nd4j.create(1) }), - 1, 2, false, new Observation(Nd4j.create(1))); - Transition transition2 = new Transition(new Observation(new INDArray[] { Nd4j.create(1) }), - 3, 4, false, new Observation(Nd4j.create(1))); - Transition transition3 = new Transition(new Observation(new INDArray[] { Nd4j.create(1) }), - 5, 6, false, new Observation(Nd4j.create(1))); - Transition transition4 = new Transition(new Observation(new INDArray[] { Nd4j.create(1) }), - 7, 8, false, new Observation(Nd4j.create(1))); - Transition transition5 = new Transition(new Observation(new INDArray[] { Nd4j.create(1) }), - 9, 10, false, new Observation(Nd4j.create(1))); + Transition transition1 = buildTransition(new Observation(new INDArray[] { Nd4j.create(1) }), + 1, 2, new Observation(Nd4j.create(1))); + Transition transition2 = buildTransition(new Observation(new INDArray[] { Nd4j.create(1) }), + 3, 4, new Observation(Nd4j.create(1))); + Transition transition3 = buildTransition(new Observation(new INDArray[] { Nd4j.create(1) }), + 5, 6, new Observation(Nd4j.create(1))); + Transition transition4 = buildTransition(new Observation(new INDArray[] { Nd4j.create(1) }), + 7, 8, new Observation(Nd4j.create(1))); + Transition transition5 = buildTransition(new Observation(new INDArray[] { Nd4j.create(1) }), + 9, 10, new Observation(Nd4j.create(1))); sut.store(transition1); sut.store(transition2); sut.store(transition3); @@ -168,16 +168,16 @@ public class ExpReplayTest { ExpReplay sut = new ExpReplay(5, 1, randomMock); // Act - Transition transition1 = new Transition(new Observation(new INDArray[] { Nd4j.create(1) }), - 1, 2, false, new Observation(Nd4j.create(1))); - Transition transition2 = new Transition(new 
Observation(new INDArray[] { Nd4j.create(1) }), - 3, 4, false, new Observation(Nd4j.create(1))); - Transition transition3 = new Transition(new Observation(new INDArray[] { Nd4j.create(1) }), - 5, 6, false, new Observation(Nd4j.create(1))); - Transition transition4 = new Transition(new Observation(new INDArray[] { Nd4j.create(1) }), - 7, 8, false, new Observation(Nd4j.create(1))); - Transition transition5 = new Transition(new Observation(new INDArray[] { Nd4j.create(1) }), - 9, 10, false, new Observation(Nd4j.create(1))); + Transition transition1 = buildTransition(new Observation(new INDArray[] { Nd4j.create(1) }), + 1, 2, new Observation(Nd4j.create(1))); + Transition transition2 = buildTransition(new Observation(new INDArray[] { Nd4j.create(1) }), + 3, 4, new Observation(Nd4j.create(1))); + Transition transition3 = buildTransition(new Observation(new INDArray[] { Nd4j.create(1) }), + 5, 6, new Observation(Nd4j.create(1))); + Transition transition4 = buildTransition(new Observation(new INDArray[] { Nd4j.create(1) }), + 7, 8, new Observation(Nd4j.create(1))); + Transition transition5 = buildTransition(new Observation(new INDArray[] { Nd4j.create(1) }), + 9, 10, new Observation(Nd4j.create(1))); sut.store(transition1); sut.store(transition2); sut.store(transition3); @@ -196,6 +196,12 @@ public class ExpReplayTest { assertEquals(5, (int)results.get(2).getAction()); assertEquals(6, (int)results.get(2).getReward()); + } + private Transition buildTransition(Observation observation, Integer action, double reward, Observation nextObservation) { + Transition result = new Transition(observation, action, reward, false); + result.setNextObservation(nextObservation); + + return result; } } diff --git a/rl4j/rl4j-core/src/test/java/org/deeplearning4j/rl4j/learning/sync/SyncLearningTest.java b/rl4j/rl4j-core/src/test/java/org/deeplearning4j/rl4j/learning/sync/SyncLearningTest.java index 9b89390d8..79be025b5 100644 --- 
a/rl4j/rl4j-core/src/test/java/org/deeplearning4j/rl4j/learning/sync/SyncLearningTest.java +++ b/rl4j/rl4j-core/src/test/java/org/deeplearning4j/rl4j/learning/sync/SyncLearningTest.java @@ -1,5 +1,6 @@ package org.deeplearning4j.rl4j.learning.sync; +import lombok.Getter; import org.deeplearning4j.rl4j.learning.sync.qlearning.QLearning; import org.deeplearning4j.rl4j.learning.sync.support.MockStatEntry; import org.deeplearning4j.rl4j.mdp.MDP; @@ -88,12 +89,15 @@ public class SyncLearningTest { private final LConfiguration conf; + @Getter + private int currentEpochStep = 0; + public MockSyncLearning(LConfiguration conf) { this.conf = conf; } @Override - protected void preEpoch() { } + protected void preEpoch() { currentEpochStep = 0; } @Override protected void postEpoch() { } @@ -101,7 +105,7 @@ public class SyncLearningTest { @Override protected IDataManager.StatEntry trainEpoch() { setStepCounter(getStepCounter() + 1); - return new MockStatEntry(getEpochCounter(), getStepCounter(), 1.0); + return new MockStatEntry(getCurrentEpochStep(), getStepCounter(), 1.0); } @Override diff --git a/rl4j/rl4j-core/src/test/java/org/deeplearning4j/rl4j/learning/sync/TransitionTest.java b/rl4j/rl4j-core/src/test/java/org/deeplearning4j/rl4j/learning/sync/TransitionTest.java index b74ebac11..944e41a31 100644 --- a/rl4j/rl4j-core/src/test/java/org/deeplearning4j/rl4j/learning/sync/TransitionTest.java +++ b/rl4j/rl4j-core/src/test/java/org/deeplearning4j/rl4j/learning/sync/TransitionTest.java @@ -21,7 +21,7 @@ public class TransitionTest { Observation nextObservation = buildObservation(nextObs); // Act - Transition transition = new Transition(observation, 123, 234.0, false, nextObservation); + Transition transition = buildTransition(observation, 123, 234.0, nextObservation); // Assert double[][] expectedObservation = new double[][] { obs }; @@ -52,7 +52,7 @@ public class TransitionTest { Observation nextObservation = buildObservation(nextObs); // Act - Transition transition = new 
Transition(observation, 123, 234.0, false, nextObservation); + Transition transition = buildTransition(observation, 123, 234.0, nextObservation); // Assert assertExpected(obs, transition.getObservation().getData()); @@ -71,12 +71,12 @@ public class TransitionTest { double[] obs1 = new double[] { 0.0, 1.0, 2.0 }; Observation observation1 = buildObservation(obs1); Observation nextObservation1 = buildObservation(new double[] { 100.0, 101.0, 102.0 }); - transitions.add(new Transition(observation1,0, 0.0, false, nextObservation1)); + transitions.add(buildTransition(observation1,0, 0.0, nextObservation1)); double[] obs2 = new double[] { 10.0, 11.0, 12.0 }; Observation observation2 = buildObservation(obs2); Observation nextObservation2 = buildObservation(new double[] { 110.0, 111.0, 112.0 }); - transitions.add(new Transition(observation2, 0, 0.0, false, nextObservation2)); + transitions.add(buildTransition(observation2, 0, 0.0, nextObservation2)); // Act INDArray result = Transition.buildStackedObservations(transitions); @@ -101,7 +101,7 @@ public class TransitionTest { double[] nextObs1 = new double[] { 100.0, 101.0, 102.0 }; Observation nextObservation1 = buildNextObservation(obs1, nextObs1); - transitions.add(new Transition(observation1, 0, 0.0, false, nextObservation1)); + transitions.add(buildTransition(observation1, 0, 0.0, nextObservation1)); double[][] obs2 = new double[][] { { 10.0, 11.0, 12.0 }, @@ -112,7 +112,7 @@ public class TransitionTest { double[] nextObs2 = new double[] { 110.0, 111.0, 112.0 }; Observation nextObservation2 = buildNextObservation(obs2, nextObs2); - transitions.add(new Transition(observation2, 0, 0.0, false, nextObservation2)); + transitions.add(buildTransition(observation2, 0, 0.0, nextObservation2)); // Act INDArray result = Transition.buildStackedObservations(transitions); @@ -131,13 +131,13 @@ public class TransitionTest { double[] nextObs1 = new double[] { 100.0, 101.0, 102.0 }; Observation observation1 = buildObservation(obs1); 
Observation nextObservation1 = buildObservation(nextObs1); - transitions.add(new Transition(observation1, 0, 0.0, false, nextObservation1)); + transitions.add(buildTransition(observation1, 0, 0.0, nextObservation1)); double[] obs2 = new double[] { 10.0, 11.0, 12.0 }; double[] nextObs2 = new double[] { 110.0, 111.0, 112.0 }; Observation observation2 = buildObservation(obs2); Observation nextObservation2 = buildObservation(nextObs2); - transitions.add(new Transition(observation2, 0, 0.0, false, nextObservation2)); + transitions.add(buildTransition(observation2, 0, 0.0, nextObservation2)); // Act INDArray result = Transition.buildStackedNextObservations(transitions); @@ -162,7 +162,7 @@ public class TransitionTest { double[] nextObs1 = new double[] { 100.0, 101.0, 102.0 }; Observation nextObservation1 = buildNextObservation(obs1, nextObs1); - transitions.add(new Transition(observation1, 0, 0.0, false, nextObservation1)); + transitions.add(buildTransition(observation1, 0, 0.0, nextObservation1)); double[][] obs2 = new double[][] { { 10.0, 11.0, 12.0 }, @@ -174,7 +174,7 @@ public class TransitionTest { double[] nextObs2 = new double[] { 110.0, 111.0, 112.0 }; Observation nextObservation2 = buildNextObservation(obs2, nextObs2); - transitions.add(new Transition(observation2, 0, 0.0, false, nextObservation2)); + transitions.add(buildTransition(observation2, 0, 0.0, nextObservation2)); // Act INDArray result = Transition.buildStackedNextObservations(transitions); @@ -207,7 +207,13 @@ public class TransitionTest { Nd4j.create(obs[1]).reshape(1, 3), }; return new Observation(nextHistory); + } + private Transition buildTransition(Observation observation, int action, double reward, Observation nextObservation) { + Transition result = new Transition(observation, action, reward, false); + result.setNextObservation(nextObservation); + + return result; } private void assertExpected(double[] expected, INDArray actual) { diff --git 
a/rl4j/rl4j-core/src/test/java/org/deeplearning4j/rl4j/learning/sync/qlearning/discrete/QLearningDiscreteTest.java b/rl4j/rl4j-core/src/test/java/org/deeplearning4j/rl4j/learning/sync/qlearning/discrete/QLearningDiscreteTest.java index 22e1ba49d..ee8b365f0 100644 --- a/rl4j/rl4j-core/src/test/java/org/deeplearning4j/rl4j/learning/sync/qlearning/discrete/QLearningDiscreteTest.java +++ b/rl4j/rl4j-core/src/test/java/org/deeplearning4j/rl4j/learning/sync/qlearning/discrete/QLearningDiscreteTest.java @@ -40,8 +40,10 @@ public class QLearningDiscreteTest { new int[] { 1, 2, 3, 4, 1, 2, 3, 4, 1, 2, 3, 4, 1, 2, 3, 4 }); MockMDP mdp = new MockMDP(observationSpace, random); + int initStepCount = 8; + QLearning.QLConfiguration conf = new QLearning.QLConfiguration(0, 24, 0, 5, 1, 1000, - 0, 1.0, 0, 0, 0, 0, true); + initStepCount, 1.0, 0, 0, 0, 0, true); MockDataManager dataManager = new MockDataManager(false); MockExpReplay expReplay = new MockExpReplay(); TestQLearningDiscrete sut = new TestQLearningDiscrete(mdp, dqn, conf, dataManager, expReplay, 10, random); @@ -60,7 +62,7 @@ public class QLearningDiscreteTest { for(int i = 0; i < expectedRecords.length; ++i) { assertEquals(expectedRecords[i], hp.recordCalls.get(i).getDouble(0), 0.0001); } - double[] expectedAdds = new double[] { 0.0, 2.0, 4.0, 6.0, 8.0, 9.0, 11.0, 13.0, 15.0, 17.0, 19.0, 21.0, 23.0 }; + double[] expectedAdds = new double[] { 0.0, 2.0, 4.0, 6.0, 8.0, 10.0, 12.0, 14.0, 16.0, 18.0, 20.0, 22.0, 24.0 }; assertEquals(expectedAdds.length, hp.addCalls.size()); for(int i = 0; i < expectedAdds.length; ++i) { assertEquals(expectedAdds[i], hp.addCalls.get(i).getDouble(0), 0.0001); @@ -75,19 +77,19 @@ public class QLearningDiscreteTest { assertEquals(14, dqn.outputParams.size()); double[][] expectedDQNOutput = new double[][] { new double[] { 0.0, 2.0, 4.0, 6.0, 8.0 }, - new double[] { 2.0, 4.0, 6.0, 8.0, 9.0 }, - new double[] { 2.0, 4.0, 6.0, 8.0, 9.0 }, - new double[] { 4.0, 6.0, 8.0, 9.0, 11.0 }, - new double[] { 
6.0, 8.0, 9.0, 11.0, 13.0 }, - new double[] { 6.0, 8.0, 9.0, 11.0, 13.0 }, - new double[] { 8.0, 9.0, 11.0, 13.0, 15.0 }, - new double[] { 8.0, 9.0, 11.0, 13.0, 15.0 }, - new double[] { 9.0, 11.0, 13.0, 15.0, 17.0 }, - new double[] { 9.0, 11.0, 13.0, 15.0, 17.0 }, - new double[] { 11.0, 13.0, 15.0, 17.0, 19.0 }, - new double[] { 11.0, 13.0, 15.0, 17.0, 19.0 }, - new double[] { 13.0, 15.0, 17.0, 19.0, 21.0 }, - new double[] { 13.0, 15.0, 17.0, 19.0, 21.0 }, + new double[] { 2.0, 4.0, 6.0, 8.0, 10.0 }, + new double[] { 2.0, 4.0, 6.0, 8.0, 10.0 }, + new double[] { 4.0, 6.0, 8.0, 10.0, 12.0 }, + new double[] { 6.0, 8.0, 10.0, 12.0, 14.0 }, + new double[] { 6.0, 8.0, 10.0, 12.0, 14.0 }, + new double[] { 8.0, 10.0, 12.0, 14.0, 16.0 }, + new double[] { 8.0, 10.0, 12.0, 14.0, 16.0 }, + new double[] { 10.0, 12.0, 14.0, 16.0, 18.0 }, + new double[] { 10.0, 12.0, 14.0, 16.0, 18.0 }, + new double[] { 12.0, 14.0, 16.0, 18.0, 20.0 }, + new double[] { 12.0, 14.0, 16.0, 18.0, 20.0 }, + new double[] { 14.0, 16.0, 18.0, 20.0, 22.0 }, + new double[] { 14.0, 16.0, 18.0, 20.0, 22.0 }, }; for(int i = 0; i < expectedDQNOutput.length; ++i) { INDArray outputParam = dqn.outputParams.get(i); @@ -105,19 +107,20 @@ public class QLearningDiscreteTest { assertArrayEquals(new Integer[] {0, 0, 0, 0, 0, 0, 0, 0, 1, 1, 4, 4, 2, 2, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4 }, mdp.actions.toArray()); // ExpReplay calls - double[] expectedTrRewards = new double[] { 9.0, 21.0, 25.0, 29.0, 33.0, 37.0, 41.0, 45.0 }; + double[] expectedTrRewards = new double[] { 9.0, 21.0, 25.0, 29.0, 33.0, 37.0, 41.0 }; int[] expectedTrActions = new int[] { 1, 4, 2, 4, 4, 4, 4, 4 }; - double[] expectedTrNextObservation = new double[] { 2.0, 4.0, 6.0, 8.0, 9.0, 11.0, 13.0, 15.0 }; + double[] expectedTrNextObservation = new double[] { 2.0, 4.0, 6.0, 8.0, 10.0, 12.0, 14.0, 16.0 }; double[][] expectedTrObservations = new double[][] { new double[] { 0.0, 2.0, 4.0, 6.0, 8.0 }, - new double[] { 2.0, 4.0, 6.0, 8.0, 9.0 }, - new double[] { 
4.0, 6.0, 8.0, 9.0, 11.0 }, - new double[] { 6.0, 8.0, 9.0, 11.0, 13.0 }, - new double[] { 8.0, 9.0, 11.0, 13.0, 15.0 }, - new double[] { 9.0, 11.0, 13.0, 15.0, 17.0 }, - new double[] { 11.0, 13.0, 15.0, 17.0, 19.0 }, - new double[] { 13.0, 15.0, 17.0, 19.0, 21.0 }, + new double[] { 2.0, 4.0, 6.0, 8.0, 10.0 }, + new double[] { 4.0, 6.0, 8.0, 10.0, 12.0 }, + new double[] { 6.0, 8.0, 10.0, 12.0, 14.0 }, + new double[] { 8.0, 10.0, 12.0, 14.0, 16.0 }, + new double[] { 10.0, 12.0, 14.0, 16.0, 18.0 }, + new double[] { 12.0, 14.0, 16.0, 18.0, 20.0 }, + new double[] { 14.0, 16.0, 18.0, 20.0, 22.0 }, }; + assertEquals(expectedTrObservations.length, expReplay.transitions.size()); for(int i = 0; i < expectedTrRewards.length; ++i) { Transition tr = expReplay.transitions.get(i); assertEquals(expectedTrRewards[i], tr.getReward(), 0.0001); @@ -129,7 +132,7 @@ public class QLearningDiscreteTest { } // trainEpoch result - assertEquals(16, result.getStepCounter()); + assertEquals(initStepCount + 16, result.getStepCounter()); assertEquals(300.0, result.getReward(), 0.00001); assertTrue(dqn.hasBeenReset); assertTrue(((MockDQN)sut.getTargetQNetwork()).hasBeenReset); diff --git a/rl4j/rl4j-core/src/test/java/org/deeplearning4j/rl4j/learning/sync/qlearning/discrete/TDTargetAlgorithm/DoubleDQNTest.java b/rl4j/rl4j-core/src/test/java/org/deeplearning4j/rl4j/learning/sync/qlearning/discrete/TDTargetAlgorithm/DoubleDQNTest.java index bb8af1950..760666f33 100644 --- a/rl4j/rl4j-core/src/test/java/org/deeplearning4j/rl4j/learning/sync/qlearning/discrete/TDTargetAlgorithm/DoubleDQNTest.java +++ b/rl4j/rl4j-core/src/test/java/org/deeplearning4j/rl4j/learning/sync/qlearning/discrete/TDTargetAlgorithm/DoubleDQNTest.java @@ -26,7 +26,7 @@ public class DoubleDQNTest { List> transitions = new ArrayList>() { { - add(new Transition(buildObservation(new double[]{1.1, 2.2}), + add(builtTransition(buildObservation(new double[]{1.1, 2.2}), 0, 1.0, true, buildObservation(new double[]{11.0, 22.0}))); } }; 
@@ -52,7 +52,7 @@ public class DoubleDQNTest { List> transitions = new ArrayList>() { { - add(new Transition(buildObservation(new double[]{1.1, 2.2}), + add(builtTransition(buildObservation(new double[]{1.1, 2.2}), 0, 1.0, false, buildObservation(new double[]{11.0, 22.0}))); } }; @@ -78,11 +78,11 @@ public class DoubleDQNTest { List> transitions = new ArrayList>() { { - add(new Transition(buildObservation(new double[]{1.1, 2.2}), + add(builtTransition(buildObservation(new double[]{1.1, 2.2}), 0, 1.0, false, buildObservation(new double[]{11.0, 22.0}))); - add(new Transition(buildObservation(new double[]{3.3, 4.4}), + add(builtTransition(buildObservation(new double[]{3.3, 4.4}), 1, 2.0, false, buildObservation(new double[]{33.0, 44.0}))); - add(new Transition(buildObservation(new double[]{5.5, 6.6}), + add(builtTransition(buildObservation(new double[]{5.5, 6.6}), 0, 3.0, true, buildObservation(new double[]{55.0, 66.0}))); } }; @@ -108,4 +108,11 @@ public class DoubleDQNTest { private Observation buildObservation(double[] data) { return new Observation(new INDArray[]{Nd4j.create(data).reshape(1, 2)}); } + + private Transition builtTransition(Observation observation, Integer action, double reward, boolean isTerminal, Observation nextObservation) { + Transition result = new Transition(observation, action, reward, isTerminal); + result.setNextObservation(nextObservation); + + return result; + } } diff --git a/rl4j/rl4j-core/src/test/java/org/deeplearning4j/rl4j/learning/sync/qlearning/discrete/TDTargetAlgorithm/StandardDQNTest.java b/rl4j/rl4j-core/src/test/java/org/deeplearning4j/rl4j/learning/sync/qlearning/discrete/TDTargetAlgorithm/StandardDQNTest.java index d2608437d..c540646a8 100644 --- a/rl4j/rl4j-core/src/test/java/org/deeplearning4j/rl4j/learning/sync/qlearning/discrete/TDTargetAlgorithm/StandardDQNTest.java +++ b/rl4j/rl4j-core/src/test/java/org/deeplearning4j/rl4j/learning/sync/qlearning/discrete/TDTargetAlgorithm/StandardDQNTest.java @@ -25,7 +25,7 @@ public 
class StandardDQNTest { List> transitions = new ArrayList>() { { - add(new Transition(buildObservation(new double[]{1.1, 2.2}), + add(buildTransition(buildObservation(new double[]{1.1, 2.2}), 0, 1.0, true, buildObservation(new double[]{11.0, 22.0}))); } }; @@ -51,7 +51,7 @@ public class StandardDQNTest { List> transitions = new ArrayList>() { { - add(new Transition(buildObservation(new double[]{1.1, 2.2}), + add(buildTransition(buildObservation(new double[]{1.1, 2.2}), 0, 1.0, false, buildObservation(new double[]{11.0, 22.0}))); } }; @@ -77,11 +77,11 @@ public class StandardDQNTest { List> transitions = new ArrayList>() { { - add(new Transition(buildObservation(new double[]{1.1, 2.2}), + add(buildTransition(buildObservation(new double[]{1.1, 2.2}), 0, 1.0, false, buildObservation(new double[]{11.0, 22.0}))); - add(new Transition(buildObservation(new double[]{3.3, 4.4}), + add(buildTransition(buildObservation(new double[]{3.3, 4.4}), 1, 2.0, false, buildObservation(new double[]{33.0, 44.0}))); - add(new Transition(buildObservation(new double[]{5.5, 6.6}), + add(buildTransition(buildObservation(new double[]{5.5, 6.6}), 0, 3.0, true, buildObservation(new double[]{55.0, 66.0}))); } }; @@ -108,4 +108,10 @@ public class StandardDQNTest { return new Observation(new INDArray[]{Nd4j.create(data).reshape(1, 2)}); } + private Transition buildTransition(Observation observation, Integer action, double reward, boolean isTerminal, Observation nextObservation) { + Transition result = new Transition(observation, action, reward, isTerminal); + result.setNextObservation(nextObservation); + + return result; + } } diff --git a/rl4j/rl4j-core/src/test/java/org/deeplearning4j/rl4j/policy/PolicyTest.java b/rl4j/rl4j-core/src/test/java/org/deeplearning4j/rl4j/policy/PolicyTest.java index 6fa2d06d4..f97457a52 100644 --- a/rl4j/rl4j-core/src/test/java/org/deeplearning4j/rl4j/policy/PolicyTest.java +++ b/rl4j/rl4j-core/src/test/java/org/deeplearning4j/rl4j/policy/PolicyTest.java @@ -198,7 
+198,7 @@ public class PolicyTest { assertEquals(465.0, totalReward, 0.0001); // HistoryProcessor - assertEquals(27, hp.addCalls.size()); + assertEquals(16, hp.addCalls.size()); assertEquals(31, hp.recordCalls.size()); for(int i=0; i <= 30; ++i) { assertEquals((double)i, hp.recordCalls.get(i).getDouble(0), 0.0001); diff --git a/rl4j/rl4j-core/src/test/java/org/deeplearning4j/rl4j/util/DataManagerTrainingListenerTest.java b/rl4j/rl4j-core/src/test/java/org/deeplearning4j/rl4j/util/DataManagerTrainingListenerTest.java index f6da6d378..a3a5598d4 100644 --- a/rl4j/rl4j-core/src/test/java/org/deeplearning4j/rl4j/util/DataManagerTrainingListenerTest.java +++ b/rl4j/rl4j-core/src/test/java/org/deeplearning4j/rl4j/util/DataManagerTrainingListenerTest.java @@ -142,6 +142,11 @@ public class DataManagerTrainingListenerTest { return 0; } + @Override + public int getCurrentEpochStep() { + return 0; + } + @Getter @Setter private IHistoryProcessor historyProcessor; diff --git a/rl4j/rl4j-gym/src/main/java/org/deeplearning4j/rl4j/mdp/gym/GymEnv.java b/rl4j/rl4j-gym/src/main/java/org/deeplearning4j/rl4j/mdp/gym/GymEnv.java index 66be7698f..afeea5f7c 100644 --- a/rl4j/rl4j-gym/src/main/java/org/deeplearning4j/rl4j/mdp/gym/GymEnv.java +++ b/rl4j/rl4j-gym/src/main/java/org/deeplearning4j/rl4j/mdp/gym/GymEnv.java @@ -93,6 +93,9 @@ public class GymEnv> implements MDP { private boolean done = false; public GymEnv(String envId, boolean render, boolean monitor) { + this(envId, render, monitor, (Integer)null); + } + public GymEnv(String envId, boolean render, boolean monitor, Integer seed) { this.envId = envId; this.render = render; this.monitor = monitor; @@ -107,6 +110,10 @@ public class GymEnv> implements MDP { Py_DecRef(PyRun_StringFlags("env = gym.wrappers.Monitor(env, '" + GYM_MONITOR_DIR + "')", Py_single_input, globals, locals, null)); checkPythonError(); } + if (seed != null) { + Py_DecRef(PyRun_StringFlags("env.seed(" + seed + ")", Py_single_input, globals, locals, null)); + 
checkPythonError(); + } PyObject shapeTuple = PyRun_StringFlags("env.observation_space.shape", Py_eval_input, globals, locals, null); int[] shape = new int[(int)PyTuple_Size(shapeTuple)]; for (int i = 0; i < shape.length; i++) { @@ -125,7 +132,10 @@ public class GymEnv> implements MDP { } public GymEnv(String envId, boolean render, boolean monitor, int[] actions) { - this(envId, render, monitor); + this(envId, render, monitor, null, actions); + } + public GymEnv(String envId, boolean render, boolean monitor, Integer seed, int[] actions) { + this(envId, render, monitor, seed); actionTransformer = new ActionTransformer((HighLowDiscrete) getActionSpace(), actions); }