diff --git a/pom.xml b/pom.xml
index 120159c30..9b32f25ae 100644
--- a/pom.xml
+++ b/pom.xml
@@ -303,8 +303,8 @@
1.79.0
1.10.6
0.6.1
- 0.15.4
- 1.15.0
+ 0.15.5
+ 1.15.2
${tensorflow.version}-${javacpp-presets.version}
1.18
diff --git a/rl4j/rl4j-core/src/main/java/org/deeplearning4j/rl4j/learning/EpochStepCounter.java b/rl4j/rl4j-core/src/main/java/org/deeplearning4j/rl4j/learning/EpochStepCounter.java
new file mode 100644
index 000000000..746a71396
--- /dev/null
+++ b/rl4j/rl4j-core/src/main/java/org/deeplearning4j/rl4j/learning/EpochStepCounter.java
@@ -0,0 +1,5 @@
+package org.deeplearning4j.rl4j.learning;
+
+public interface EpochStepCounter {
+ int getCurrentEpochStep();
+}
diff --git a/rl4j/rl4j-core/src/main/java/org/deeplearning4j/rl4j/learning/IEpochTrainer.java b/rl4j/rl4j-core/src/main/java/org/deeplearning4j/rl4j/learning/IEpochTrainer.java
index 72510dcaa..f113ce157 100644
--- a/rl4j/rl4j-core/src/main/java/org/deeplearning4j/rl4j/learning/IEpochTrainer.java
+++ b/rl4j/rl4j-core/src/main/java/org/deeplearning4j/rl4j/learning/IEpochTrainer.java
@@ -21,9 +21,14 @@ import org.deeplearning4j.rl4j.mdp.MDP;
/**
* The common API between Learning and AsyncThread.
*
+ * Expresses the ability to count the number of steps in the current training.
+ * Factorization of a feature shared between the async threads and the learning
+ * process, used for web monitoring.
+ *
* @author Alexandre Boulanger
+ * @author rubenfiszel (ruben.fiszel@epfl.ch) on 8/5/16.
*/
-public interface IEpochTrainer {
+public interface IEpochTrainer extends EpochStepCounter {
int getStepCounter();
int getEpochCounter();
IHistoryProcessor getHistoryProcessor();
diff --git a/rl4j/rl4j-core/src/main/java/org/deeplearning4j/rl4j/learning/ILearning.java b/rl4j/rl4j-core/src/main/java/org/deeplearning4j/rl4j/learning/ILearning.java
index 3c4c94c6b..d151f093b 100644
--- a/rl4j/rl4j-core/src/main/java/org/deeplearning4j/rl4j/learning/ILearning.java
+++ b/rl4j/rl4j-core/src/main/java/org/deeplearning4j/rl4j/learning/ILearning.java
@@ -26,7 +26,7 @@ import org.deeplearning4j.rl4j.space.Encodable;
*
* A common interface that any training method should implement
*/
-public interface ILearning> extends StepCountable {
+public interface ILearning> {
IPolicy getPolicy();
diff --git a/rl4j/rl4j-core/src/main/java/org/deeplearning4j/rl4j/learning/Learning.java b/rl4j/rl4j-core/src/main/java/org/deeplearning4j/rl4j/learning/Learning.java
index 780a73752..833094929 100644
--- a/rl4j/rl4j-core/src/main/java/org/deeplearning4j/rl4j/learning/Learning.java
+++ b/rl4j/rl4j-core/src/main/java/org/deeplearning4j/rl4j/learning/Learning.java
@@ -21,7 +21,6 @@ import lombok.Getter;
import lombok.Setter;
import lombok.Value;
import lombok.extern.slf4j.Slf4j;
-import org.deeplearning4j.gym.StepReply;
import org.deeplearning4j.rl4j.mdp.MDP;
import org.deeplearning4j.rl4j.network.NeuralNet;
import org.deeplearning4j.rl4j.space.ActionSpace;
@@ -53,55 +52,6 @@ public abstract class Learning, NN extends Neura
return Nd4j.argMax(vector, Integer.MAX_VALUE).getInt(0);
}
- public static > INDArray getInput(MDP mdp, O obs) {
- INDArray arr = Nd4j.create(((Encodable)obs).toArray());
- int[] shape = mdp.getObservationSpace().getShape();
- if (shape.length == 1)
- return arr.reshape(new long[] {1, arr.length()});
- else
- return arr.reshape(shape);
- }
-
- public static > InitMdp initMdp(MDP mdp,
- IHistoryProcessor hp) {
-
- O obs = mdp.reset();
-
- int step = 0;
- double reward = 0;
-
- boolean isHistoryProcessor = hp != null;
-
- int skipFrame = isHistoryProcessor ? hp.getConf().getSkipFrame() : 1;
- int requiredFrame = isHistoryProcessor ? skipFrame * (hp.getConf().getHistoryLength() - 1) : 0;
-
- INDArray input = Learning.getInput(mdp, obs);
- if (isHistoryProcessor)
- hp.record(input);
-
-
- while (step < requiredFrame && !mdp.isDone()) {
-
- A action = mdp.getActionSpace().noOp(); //by convention should be the NO_OP
- if (step % skipFrame == 0 && isHistoryProcessor)
- hp.add(input);
-
- StepReply stepReply = mdp.step(action);
- reward += stepReply.getReward();
- obs = stepReply.getObservation();
-
- input = Learning.getInput(mdp, obs);
- if (isHistoryProcessor)
- hp.record(input);
-
- step++;
-
- }
-
- return new InitMdp(step, obs, reward);
-
- }
-
public static int[] makeShape(int size, int[] shape) {
int[] nshape = new int[shape.length + 1];
nshape[0] = size;
@@ -122,16 +72,16 @@ public abstract class Learning, NN extends Neura
public abstract NN getNeuralNet();
- public int incrementStep() {
- return stepCounter++;
+ public void incrementStep() {
+ stepCounter++;
}
- public int incrementEpoch() {
- return epochCounter++;
+ public void incrementEpoch() {
+ epochCounter++;
}
public void setHistoryProcessor(HistoryProcessor.Configuration conf) {
- historyProcessor = new HistoryProcessor(conf);
+ setHistoryProcessor(new HistoryProcessor(conf));
}
public void setHistoryProcessor(IHistoryProcessor historyProcessor) {
diff --git a/rl4j/rl4j-core/src/main/java/org/deeplearning4j/rl4j/learning/StepCountable.java b/rl4j/rl4j-core/src/main/java/org/deeplearning4j/rl4j/learning/StepCountable.java
deleted file mode 100644
index b9538163b..000000000
--- a/rl4j/rl4j-core/src/main/java/org/deeplearning4j/rl4j/learning/StepCountable.java
+++ /dev/null
@@ -1,30 +0,0 @@
-/*******************************************************************************
- * Copyright (c) 2015-2018 Skymind, Inc.
- *
- * This program and the accompanying materials are made available under the
- * terms of the Apache License, Version 2.0 which is available at
- * https://www.apache.org/licenses/LICENSE-2.0.
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
- * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
- * License for the specific language governing permissions and limitations
- * under the License.
- *
- * SPDX-License-Identifier: Apache-2.0
- ******************************************************************************/
-
-package org.deeplearning4j.rl4j.learning;
-
-/**
- * @author rubenfiszel (ruben.fiszel@epfl.ch) on 8/5/16.
- *
- * Express the ability to count the number of step of the current training.
- * Factorisation of a feature between threads in async and learning process
- * for the web monitoring
- */
-public interface StepCountable {
-
- int getStepCounter();
-
-}
diff --git a/rl4j/rl4j-core/src/main/java/org/deeplearning4j/rl4j/learning/async/AsyncThread.java b/rl4j/rl4j-core/src/main/java/org/deeplearning4j/rl4j/learning/async/AsyncThread.java
index 2495e74ce..e95b25337 100644
--- a/rl4j/rl4j-core/src/main/java/org/deeplearning4j/rl4j/learning/async/AsyncThread.java
+++ b/rl4j/rl4j-core/src/main/java/org/deeplearning4j/rl4j/learning/async/AsyncThread.java
@@ -30,8 +30,6 @@ import org.deeplearning4j.rl4j.network.NeuralNet;
import org.deeplearning4j.rl4j.observation.Observation;
import org.deeplearning4j.rl4j.policy.IPolicy;
import org.deeplearning4j.rl4j.space.ActionSpace;
-import org.deeplearning4j.rl4j.space.DiscreteSpace;
-import org.deeplearning4j.rl4j.space.Encodable;
import org.deeplearning4j.rl4j.util.IDataManager;
import org.deeplearning4j.rl4j.util.LegacyMDPWrapper;
import org.nd4j.linalg.factory.Nd4j;
@@ -48,7 +46,7 @@ import org.nd4j.linalg.factory.Nd4j;
*/
@Slf4j
public abstract class AsyncThread, NN extends NeuralNet>
- extends Thread implements StepCountable, IEpochTrainer {
+ extends Thread implements IEpochTrainer {
@Getter
private int threadNumber;
@@ -61,6 +59,9 @@ public abstract class AsyncThread, NN extends Ne
@Getter @Setter
private IHistoryProcessor historyProcessor;
+ @Getter
+ private int currentEpochStep = 0;
+
private boolean isEpochStarted = false;
private final LegacyMDPWrapper mdp;
@@ -138,7 +139,7 @@ public abstract class AsyncThread, NN extends Ne
handleTraining(context);
- if (context.epochElapsedSteps >= getConf().getMaxEpochStep() || getMdp().isDone()) {
+ if (currentEpochStep >= getConf().getMaxEpochStep() || getMdp().isDone()) {
boolean canContinue = finishEpoch(context);
if (!canContinue) {
break;
@@ -154,11 +155,10 @@ public abstract class AsyncThread, NN extends Ne
}
private void handleTraining(RunContext context) {
- int maxSteps = Math.min(getConf().getNstep(), getConf().getMaxEpochStep() - context.epochElapsedSteps);
+ int maxSteps = Math.min(getConf().getNstep(), getConf().getMaxEpochStep() - currentEpochStep);
SubEpochReturn subEpochReturn = trainSubEpoch(context.obs, maxSteps);
context.obs = subEpochReturn.getLastObs();
- context.epochElapsedSteps += subEpochReturn.getSteps();
context.rewards += subEpochReturn.getReward();
context.score = subEpochReturn.getScore();
}
@@ -169,7 +169,6 @@ public abstract class AsyncThread, NN extends Ne
context.obs = initMdp.getLastObs();
context.rewards = initMdp.getReward();
- context.epochElapsedSteps = initMdp.getSteps();
isEpochStarted = true;
preEpoch();
@@ -180,9 +179,9 @@ public abstract class AsyncThread, NN extends Ne
private boolean finishEpoch(RunContext context) {
isEpochStarted = false;
postEpoch();
- IDataManager.StatEntry statEntry = new AsyncStatEntry(getStepCounter(), epochCounter, context.rewards, context.epochElapsedSteps, context.score);
+ IDataManager.StatEntry statEntry = new AsyncStatEntry(getStepCounter(), epochCounter, context.rewards, currentEpochStep, context.score);
- log.info("ThreadNum-" + threadNumber + " Epoch: " + getEpochCounter() + ", reward: " + context.rewards);
+ log.info("ThreadNum-" + threadNumber + " Epoch: " + getCurrentEpochStep() + ", reward: " + context.rewards);
return listeners.notifyEpochTrainingResult(this, statEntry);
}
@@ -205,37 +204,30 @@ public abstract class AsyncThread, NN extends Ne
protected abstract SubEpochReturn trainSubEpoch(Observation obs, int nstep);
private Learning.InitMdp refacInitMdp() {
- LegacyMDPWrapper mdp = getLegacyMDPWrapper();
- IHistoryProcessor hp = getHistoryProcessor();
+ currentEpochStep = 0;
- Observation observation = mdp.reset();
-
- int step = 0;
double reward = 0;
- boolean isHistoryProcessor = hp != null;
-
- int skipFrame = isHistoryProcessor ? hp.getConf().getSkipFrame() : 1;
- int requiredFrame = isHistoryProcessor ? skipFrame * (hp.getConf().getHistoryLength() - 1) : 0;
-
- while (step < requiredFrame && !mdp.isDone()) {
-
- A action = mdp.getActionSpace().noOp(); //by convention should be the NO_OP
+ LegacyMDPWrapper mdp = getLegacyMDPWrapper();
+ Observation observation = mdp.reset();
+ A action = mdp.getActionSpace().noOp(); //by convention should be the NO_OP
+ while (observation.isSkipped() && !mdp.isDone()) {
StepReply stepReply = mdp.step(action);
+
reward += stepReply.getReward();
observation = stepReply.getObservation();
- step++;
-
+ incrementStep();
}
- return new Learning.InitMdp(step, observation, reward);
+ return new Learning.InitMdp(0, observation, reward);
}
public void incrementStep() {
++stepCounter;
+ ++currentEpochStep;
}
@AllArgsConstructor
@@ -260,7 +252,6 @@ public abstract class AsyncThread, NN extends Ne
private static class RunContext {
private Observation obs;
private double rewards;
- private int epochElapsedSteps;
private double score;
}
diff --git a/rl4j/rl4j-core/src/main/java/org/deeplearning4j/rl4j/learning/async/AsyncThreadDiscrete.java b/rl4j/rl4j-core/src/main/java/org/deeplearning4j/rl4j/learning/async/AsyncThreadDiscrete.java
index 6b0078883..0b3eb1c72 100644
--- a/rl4j/rl4j-core/src/main/java/org/deeplearning4j/rl4j/learning/async/AsyncThreadDiscrete.java
+++ b/rl4j/rl4j-core/src/main/java/org/deeplearning4j/rl4j/learning/async/AsyncThreadDiscrete.java
@@ -20,19 +20,14 @@ import lombok.Getter;
import org.deeplearning4j.gym.StepReply;
import org.deeplearning4j.nn.gradient.Gradient;
import org.deeplearning4j.rl4j.learning.IHistoryProcessor;
-import org.deeplearning4j.rl4j.learning.Learning;
import org.deeplearning4j.rl4j.learning.listener.TrainingListenerList;
-import org.deeplearning4j.rl4j.learning.sync.Transition;
import org.deeplearning4j.rl4j.mdp.MDP;
import org.deeplearning4j.rl4j.network.NeuralNet;
import org.deeplearning4j.rl4j.observation.Observation;
import org.deeplearning4j.rl4j.policy.IPolicy;
import org.deeplearning4j.rl4j.space.DiscreteSpace;
-import org.deeplearning4j.rl4j.space.Encodable;
-import org.deeplearning4j.rl4j.util.LegacyMDPWrapper;
import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.factory.Nd4j;
-import org.nd4j.linalg.util.ArrayUtil;
import java.util.Stack;
@@ -74,17 +69,18 @@ public abstract class AsyncThreadDiscrete
IPolicy policy = getPolicy(current);
Integer action;
- Integer lastAction = null;
+ Integer lastAction = getMdp().getActionSpace().noOp();
IHistoryProcessor hp = getHistoryProcessor();
int skipFrame = hp != null ? hp.getConf().getSkipFrame() : 1;
double reward = 0;
double accuReward = 0;
- int i = 0;
- while (!getMdp().isDone() && i < nstep * skipFrame) {
+ int stepAtStart = getCurrentEpochStep();
+ int lastStep = nstep * skipFrame + stepAtStart;
+ while (!getMdp().isDone() && getCurrentEpochStep() < lastStep) {
//if step of training, just repeat lastAction
- if (i % skipFrame != 0 && lastAction != null) {
+ if (obs.isSkipped()) {
action = lastAction;
} else {
action = policy.nextAction(obs);
@@ -94,7 +90,7 @@ public abstract class AsyncThreadDiscrete
accuReward += stepReply.getReward() * getConf().getRewardFactor();
//if it's not a skipped frame, you can do a step of training
- if (i % skipFrame == 0 || lastAction == null || stepReply.isDone()) {
+ if (!obs.isSkipped() || stepReply.isDone()) {
INDArray[] output = current.outputAll(obs.getData());
rewards.add(new MiniTrans(obs.getData(), action, output, accuReward));
@@ -106,7 +102,6 @@ public abstract class AsyncThreadDiscrete
reward += stepReply.getReward();
- i++;
incrementStep();
lastAction = action;
}
@@ -114,7 +109,7 @@ public abstract class AsyncThreadDiscrete
//a bit of a trick usable because of how the stack is treated to init R
// FIXME: The last element of minitrans is only used to seed the reward in calcGradient; observation, action and output are ignored.
- if (getMdp().isDone() && i < nstep * skipFrame)
+ if (getMdp().isDone() && getCurrentEpochStep() < lastStep)
rewards.add(new MiniTrans(obs.getData(), null, null, 0));
else {
INDArray[] output = null;
@@ -127,9 +122,9 @@ public abstract class AsyncThreadDiscrete
rewards.add(new MiniTrans(obs.getData(), null, output, maxQ));
}
- getAsyncGlobal().enqueue(calcGradient(current, rewards), i);
+ getAsyncGlobal().enqueue(calcGradient(current, rewards), getCurrentEpochStep());
- return new SubEpochReturn(i, obs, reward, current.getLatestScore());
+ return new SubEpochReturn(getCurrentEpochStep() - stepAtStart, obs, reward, current.getLatestScore());
}
public abstract Gradient[] calcGradient(NN nn, Stack> rewards);
diff --git a/rl4j/rl4j-core/src/main/java/org/deeplearning4j/rl4j/learning/sync/SyncLearning.java b/rl4j/rl4j-core/src/main/java/org/deeplearning4j/rl4j/learning/sync/SyncLearning.java
index ed5d73a75..22d936fcf 100644
--- a/rl4j/rl4j-core/src/main/java/org/deeplearning4j/rl4j/learning/sync/SyncLearning.java
+++ b/rl4j/rl4j-core/src/main/java/org/deeplearning4j/rl4j/learning/sync/SyncLearning.java
@@ -17,7 +17,6 @@
package org.deeplearning4j.rl4j.learning.sync;
import lombok.Getter;
-import lombok.Setter;
import lombok.extern.slf4j.Slf4j;
import org.deeplearning4j.rl4j.learning.IEpochTrainer;
import org.deeplearning4j.rl4j.learning.ILearning;
@@ -25,7 +24,6 @@ import org.deeplearning4j.rl4j.learning.Learning;
import org.deeplearning4j.rl4j.learning.listener.*;
import org.deeplearning4j.rl4j.network.NeuralNet;
import org.deeplearning4j.rl4j.space.ActionSpace;
-import org.deeplearning4j.rl4j.space.Encodable;
import org.deeplearning4j.rl4j.util.IDataManager;
/**
diff --git a/rl4j/rl4j-core/src/main/java/org/deeplearning4j/rl4j/learning/sync/Transition.java b/rl4j/rl4j-core/src/main/java/org/deeplearning4j/rl4j/learning/sync/Transition.java
index 509b28a88..3c35c47c7 100644
--- a/rl4j/rl4j-core/src/main/java/org/deeplearning4j/rl4j/learning/sync/Transition.java
+++ b/rl4j/rl4j-core/src/main/java/org/deeplearning4j/rl4j/learning/sync/Transition.java
@@ -16,6 +16,7 @@
package org.deeplearning4j.rl4j.learning.sync;
+import lombok.Data;
import lombok.Value;
import org.deeplearning4j.rl4j.observation.Observation;
import org.nd4j.linalg.api.ndarray.INDArray;
@@ -34,7 +35,7 @@ import java.util.List;
* @author Alexandre Boulanger
*
*/
-@Value
+@Data
public class Transition {
Observation observation;
@@ -43,12 +44,15 @@ public class Transition {
boolean isTerminal;
INDArray nextObservation;
- public Transition(Observation observation, A action, double reward, boolean isTerminal, Observation nextObservation) {
+ public Transition(Observation observation, A action, double reward, boolean isTerminal) {
this.observation = observation;
this.action = action;
this.reward = reward;
this.isTerminal = isTerminal;
+ this.nextObservation = null;
+ }
+ public void setNextObservation(Observation nextObservation) {
 // To conserve memory, only the most recent frame of the next observation is kept (if history is used).
 // The full nextObservation will be rebuilt from observation when needed.
long[] nextObservationShape = nextObservation.getData().shape().clone();
diff --git a/rl4j/rl4j-core/src/main/java/org/deeplearning4j/rl4j/learning/sync/qlearning/QLearning.java b/rl4j/rl4j-core/src/main/java/org/deeplearning4j/rl4j/learning/sync/qlearning/QLearning.java
index bfd23ef5b..0757043f0 100644
--- a/rl4j/rl4j-core/src/main/java/org/deeplearning4j/rl4j/learning/sync/qlearning/QLearning.java
+++ b/rl4j/rl4j-core/src/main/java/org/deeplearning4j/rl4j/learning/sync/qlearning/QLearning.java
@@ -21,8 +21,7 @@ import com.fasterxml.jackson.databind.annotation.JsonPOJOBuilder;
import lombok.*;
import lombok.extern.slf4j.Slf4j;
import org.deeplearning4j.gym.StepReply;
-import org.deeplearning4j.rl4j.learning.IHistoryProcessor;
-import org.deeplearning4j.rl4j.learning.Learning;
+import org.deeplearning4j.rl4j.learning.EpochStepCounter;
import org.deeplearning4j.rl4j.learning.sync.ExpReplay;
import org.deeplearning4j.rl4j.learning.sync.IExpReplay;
import org.deeplearning4j.rl4j.learning.sync.SyncLearning;
@@ -34,9 +33,8 @@ import org.deeplearning4j.rl4j.space.ActionSpace;
import org.deeplearning4j.rl4j.space.Encodable;
import org.deeplearning4j.rl4j.util.IDataManager.StatEntry;
import org.deeplearning4j.rl4j.util.LegacyMDPWrapper;
-import org.nd4j.linalg.api.ndarray.INDArray;
-import org.nd4j.linalg.factory.Nd4j;
import org.nd4j.linalg.api.rng.Random;
+import org.nd4j.linalg.factory.Nd4j;
import java.util.ArrayList;
import java.util.List;
@@ -49,7 +47,8 @@ import java.util.List;
*/
@Slf4j
public abstract class QLearning>
- extends SyncLearning implements TargetQNetworkSource {
+ extends SyncLearning
+ implements TargetQNetworkSource, EpochStepCounter {
// FIXME Changed for refac
// @Getter
@@ -104,18 +103,22 @@ public abstract class QLearning trainStep(Observation obs);
+ @Getter
+ private int currentEpochStep = 0;
+
protected StatEntry trainEpoch() {
+ resetNetworks();
+
InitMdp initMdp = refacInitMdp();
Observation obs = initMdp.getLastObs();
double reward = initMdp.getReward();
- int step = initMdp.getSteps();
Double startQ = Double.NaN;
double meanQ = 0;
int numQ = 0;
List scores = new ArrayList<>();
- while (step < getConfiguration().getMaxEpochStep() && !getMdp().isDone()) {
+ while (currentEpochStep < getConfiguration().getMaxEpochStep() && !getMdp().isDone()) {
if (getStepCounter() % getConfiguration().getTargetDqnUpdateFreq() == 0) {
updateTargetNetwork();
@@ -136,49 +139,53 @@ public abstract class QLearning refacInitMdp() {
- getQNetwork().reset();
- getTargetQNetwork().reset();
+ currentEpochStep = 0;
- LegacyMDPWrapper mdp = getLegacyMDPWrapper();
- IHistoryProcessor hp = getHistoryProcessor();
-
- Observation observation = mdp.reset();
-
- int step = 0;
double reward = 0;
- boolean isHistoryProcessor = hp != null;
-
- int skipFrame = isHistoryProcessor ? hp.getConf().getSkipFrame() : 1;
- int requiredFrame = isHistoryProcessor ? skipFrame * (hp.getConf().getHistoryLength() - 1) : 0;
-
- while (step < requiredFrame && !mdp.isDone()) {
-
- A action = mdp.getActionSpace().noOp(); //by convention should be the NO_OP
+ LegacyMDPWrapper mdp = getLegacyMDPWrapper();
+ Observation observation = mdp.reset();
+ A action = mdp.getActionSpace().noOp(); //by convention should be the NO_OP
+ while (observation.isSkipped() && !mdp.isDone()) {
StepReply stepReply = mdp.step(action);
+
reward += stepReply.getReward();
observation = stepReply.getObservation();
- step++;
-
+ incrementStep();
}
- return new InitMdp(step, observation, reward);
+ return new InitMdp(0, observation, reward);
}
diff --git a/rl4j/rl4j-core/src/main/java/org/deeplearning4j/rl4j/learning/sync/qlearning/discrete/QLearningDiscrete.java b/rl4j/rl4j-core/src/main/java/org/deeplearning4j/rl4j/learning/sync/qlearning/discrete/QLearningDiscrete.java
index 9da30ccef..0fa0f33e6 100644
--- a/rl4j/rl4j-core/src/main/java/org/deeplearning4j/rl4j/learning/sync/qlearning/discrete/QLearningDiscrete.java
+++ b/rl4j/rl4j-core/src/main/java/org/deeplearning4j/rl4j/learning/sync/qlearning/discrete/QLearningDiscrete.java
@@ -20,10 +20,13 @@ import lombok.AccessLevel;
import lombok.Getter;
import lombok.Setter;
import org.deeplearning4j.gym.StepReply;
+import org.deeplearning4j.rl4j.learning.IHistoryProcessor;
import org.deeplearning4j.rl4j.learning.Learning;
import org.deeplearning4j.rl4j.learning.sync.Transition;
import org.deeplearning4j.rl4j.learning.sync.qlearning.QLearning;
-import org.deeplearning4j.rl4j.learning.sync.qlearning.discrete.TDTargetAlgorithm.*;
+import org.deeplearning4j.rl4j.learning.sync.qlearning.discrete.TDTargetAlgorithm.DoubleDQN;
+import org.deeplearning4j.rl4j.learning.sync.qlearning.discrete.TDTargetAlgorithm.ITDTargetAlgorithm;
+import org.deeplearning4j.rl4j.learning.sync.qlearning.discrete.TDTargetAlgorithm.StandardDQN;
import org.deeplearning4j.rl4j.mdp.MDP;
import org.deeplearning4j.rl4j.network.dqn.IDQN;
import org.deeplearning4j.rl4j.observation.Observation;
@@ -36,7 +39,6 @@ import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.api.rng.Random;
import org.nd4j.linalg.dataset.api.DataSet;
import org.nd4j.linalg.factory.Nd4j;
-import org.nd4j.linalg.util.ArrayUtil;
import java.util.ArrayList;
@@ -68,6 +70,8 @@ public abstract class QLearningDiscrete extends QLearning getLegacyMDPWrapper() {
@@ -83,7 +87,7 @@ public abstract class QLearningDiscrete extends QLearning(mdp, this);
+ this.mdp = new LegacyMDPWrapper(mdp, null, this);
qNetwork = dqn;
targetQNetwork = dqn.clone();
policy = new DQNPolicy(getQNetwork());
@@ -108,8 +112,15 @@ public abstract class QLearningDiscrete extends QLearning extends QLearning trainStep(Observation obs) {
Integer action;
+
boolean isHistoryProcessor = getHistoryProcessor() != null;
-
-
int skipFrame = isHistoryProcessor ? getHistoryProcessor().getConf().getSkipFrame() : 1;
int historyLength = isHistoryProcessor ? getHistoryProcessor().getConf().getHistoryLength() : 1;
int updateStart = getConfiguration().getUpdateStart()
@@ -131,7 +141,7 @@ public abstract class QLearningDiscrete extends QLearning extends QLearning stepReply = mdp.step(action);
- Observation nextObservation = stepReply.getObservation();
-
accuReward += stepReply.getReward() * configuration.getRewardFactor();
//if it's not a skipped frame, you can do a step of training
- if (getStepCounter() % skipFrame == 0 || stepReply.isDone()) {
+ if (!obs.isSkipped() || stepReply.isDone()) {
- Transition trans = new Transition(obs, action, accuReward, stepReply.isDone(), nextObservation);
- getExpReplay().store(trans);
+ // Add experience
+ if(pendingTransition != null) {
+ pendingTransition.setNextObservation(obs);
+ getExpReplay().store(pendingTransition);
+ }
+ pendingTransition = new Transition(obs, action, accuReward, stepReply.isDone());
+ accuReward = 0;
+ // Update NN
+ // FIXME: maybe start updating when experience replay has reached a certain size instead of using "updateStart"?
if (getStepCounter() > updateStart) {
DataSet targets = setTarget(getExpReplay().getBatch());
getQNetwork().fit(targets.getFeatures(), targets.getLabels());
}
-
- accuReward = 0;
}
return new QLStepReturn(maxQ, getQNetwork().getLatestScore(), stepReply);
@@ -172,4 +185,12 @@ public abstract class QLearningDiscrete extends QLearning xThreshold
+ done |= x < -xThreshold || x > xThreshold
|| theta < -thetaThresholdRadians || theta > thetaThresholdRadians;
double reward;
diff --git a/rl4j/rl4j-core/src/main/java/org/deeplearning4j/rl4j/observation/Observation.java b/rl4j/rl4j-core/src/main/java/org/deeplearning4j/rl4j/observation/Observation.java
index 197a8f744..af1fe18ea 100644
--- a/rl4j/rl4j-core/src/main/java/org/deeplearning4j/rl4j/observation/Observation.java
+++ b/rl4j/rl4j-core/src/main/java/org/deeplearning4j/rl4j/observation/Observation.java
@@ -16,6 +16,8 @@
package org.deeplearning4j.rl4j.observation;
+import lombok.Getter;
+import lombok.Setter;
import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.dataset.api.DataSet;
import org.nd4j.linalg.factory.Nd4j;
@@ -28,6 +30,9 @@ public class Observation {
private final DataSet data;
+ @Getter @Setter
+ private boolean skipped;
+
public Observation(INDArray[] data) {
this(data, false);
}
diff --git a/rl4j/rl4j-core/src/main/java/org/deeplearning4j/rl4j/policy/EpsGreedy.java b/rl4j/rl4j-core/src/main/java/org/deeplearning4j/rl4j/policy/EpsGreedy.java
index 732ce1a0e..3ed375084 100644
--- a/rl4j/rl4j-core/src/main/java/org/deeplearning4j/rl4j/policy/EpsGreedy.java
+++ b/rl4j/rl4j-core/src/main/java/org/deeplearning4j/rl4j/policy/EpsGreedy.java
@@ -18,7 +18,8 @@ package org.deeplearning4j.rl4j.policy;
import lombok.AllArgsConstructor;
import lombok.extern.slf4j.Slf4j;
-import org.deeplearning4j.rl4j.learning.StepCountable;
+import org.deeplearning4j.rl4j.learning.IEpochTrainer;
+import org.deeplearning4j.rl4j.learning.ILearning;
import org.deeplearning4j.rl4j.mdp.MDP;
import org.deeplearning4j.rl4j.network.NeuralNet;
import org.deeplearning4j.rl4j.observation.Observation;
@@ -46,7 +47,7 @@ public class EpsGreedy> extends Policy {
final private int epsilonNbStep;
final private Random rnd;
final private float minEpsilon;
- final private StepCountable learning;
+ final private IEpochTrainer learning;
public NeuralNet getNeuralNet() {
return policy.getNeuralNet();
diff --git a/rl4j/rl4j-core/src/main/java/org/deeplearning4j/rl4j/policy/Policy.java b/rl4j/rl4j-core/src/main/java/org/deeplearning4j/rl4j/policy/Policy.java
index 97a11b99c..fb20a60ac 100644
--- a/rl4j/rl4j-core/src/main/java/org/deeplearning4j/rl4j/policy/Policy.java
+++ b/rl4j/rl4j-core/src/main/java/org/deeplearning4j/rl4j/policy/Policy.java
@@ -19,20 +19,15 @@ package org.deeplearning4j.rl4j.policy;
import lombok.Getter;
import lombok.Setter;
import org.deeplearning4j.gym.StepReply;
+import org.deeplearning4j.rl4j.learning.EpochStepCounter;
import org.deeplearning4j.rl4j.learning.HistoryProcessor;
import org.deeplearning4j.rl4j.learning.IHistoryProcessor;
import org.deeplearning4j.rl4j.learning.Learning;
-import org.deeplearning4j.rl4j.learning.StepCountable;
-import org.deeplearning4j.rl4j.learning.sync.Transition;
import org.deeplearning4j.rl4j.mdp.MDP;
import org.deeplearning4j.rl4j.network.NeuralNet;
import org.deeplearning4j.rl4j.observation.Observation;
import org.deeplearning4j.rl4j.space.ActionSpace;
-import org.deeplearning4j.rl4j.space.DiscreteSpace;
-import org.deeplearning4j.rl4j.space.Encodable;
import org.deeplearning4j.rl4j.util.LegacyMDPWrapper;
-import org.nd4j.linalg.api.ndarray.INDArray;
-import org.nd4j.linalg.util.ArrayUtil;
/**
* @author rubenfiszel (ruben.fiszel@epfl.ch) 7/18/16.
@@ -57,24 +52,22 @@ public abstract class Policy implements IPolicy {
@Override
public > double play(MDP mdp, IHistoryProcessor hp) {
- RefacStepCountable stepCountable = new RefacStepCountable();
- LegacyMDPWrapper mdpWrapper = new LegacyMDPWrapper(mdp, hp, stepCountable);
+ resetNetworks();
- boolean isHistoryProcessor = hp != null;
- int skipFrame = isHistoryProcessor ? hp.getConf().getSkipFrame() : 1;
+ RefacEpochStepCounter epochStepCounter = new RefacEpochStepCounter();
+ LegacyMDPWrapper mdpWrapper = new LegacyMDPWrapper(mdp, hp, epochStepCounter);
- Learning.InitMdp initMdp = refacInitMdp(mdpWrapper, hp);
+ Learning.InitMdp initMdp = refacInitMdp(mdpWrapper, hp, epochStepCounter);
Observation obs = initMdp.getLastObs();
double reward = initMdp.getReward();
A lastAction = mdpWrapper.getActionSpace().noOp();
A action;
- stepCountable.setStepCounter(initMdp.getSteps());
while (!mdpWrapper.isDone()) {
- if (stepCountable.getStepCounter() % skipFrame != 0) {
+ if (obs.isSkipped()) {
action = lastAction;
} else {
action = nextAction(obs);
@@ -86,52 +79,46 @@ public abstract class Policy implements IPolicy {
reward += stepReply.getReward();
obs = stepReply.getObservation();
- stepCountable.increment();
+ epochStepCounter.incrementEpochStep();
}
return reward;
}
- private > Learning.InitMdp refacInitMdp(LegacyMDPWrapper mdpWrapper, IHistoryProcessor hp) {
+ protected void resetNetworks() {
getNeuralNet().reset();
- Observation observation = mdpWrapper.reset();
+ }
+
+ private > Learning.InitMdp refacInitMdp(LegacyMDPWrapper mdpWrapper, IHistoryProcessor hp, RefacEpochStepCounter epochStepCounter) {
+ epochStepCounter.setCurrentEpochStep(0);
- int step = 0;
double reward = 0;
- boolean isHistoryProcessor = hp != null;
+ Observation observation = mdpWrapper.reset();
- int skipFrame = isHistoryProcessor ? hp.getConf().getSkipFrame() : 1;
- int requiredFrame = isHistoryProcessor ? skipFrame * (hp.getConf().getHistoryLength() - 1) : 0;
-
- while (step < requiredFrame && !mdpWrapper.isDone()) {
-
- A action = mdpWrapper.getActionSpace().noOp(); //by convention should be the NO_OP
+ A action = mdpWrapper.getActionSpace().noOp(); //by convention should be the NO_OP
+ while (observation.isSkipped() && !mdpWrapper.isDone()) {
StepReply stepReply = mdpWrapper.step(action);
+
reward += stepReply.getReward();
observation = stepReply.getObservation();
- step++;
-
+ epochStepCounter.incrementEpochStep();
}
- return new Learning.InitMdp(step, observation, reward);
+ return new Learning.InitMdp(0, observation, reward);
}
- private class RefacStepCountable implements StepCountable {
+ public class RefacEpochStepCounter implements EpochStepCounter {
@Getter
@Setter
- private int stepCounter = 0;
+ private int currentEpochStep = 0;
- public void increment() {
- ++stepCounter;
+ public void incrementEpochStep() {
+ ++currentEpochStep;
}
- @Override
- public int getStepCounter() {
- return 0;
- }
}
}
diff --git a/rl4j/rl4j-core/src/main/java/org/deeplearning4j/rl4j/util/LegacyMDPWrapper.java b/rl4j/rl4j-core/src/main/java/org/deeplearning4j/rl4j/util/LegacyMDPWrapper.java
index dbcd38ddc..26546a923 100644
--- a/rl4j/rl4j-core/src/main/java/org/deeplearning4j/rl4j/util/LegacyMDPWrapper.java
+++ b/rl4j/rl4j-core/src/main/java/org/deeplearning4j/rl4j/util/LegacyMDPWrapper.java
@@ -1,10 +1,11 @@
package org.deeplearning4j.rl4j.util;
+import lombok.AccessLevel;
import lombok.Getter;
+import lombok.Setter;
import org.deeplearning4j.gym.StepReply;
+import org.deeplearning4j.rl4j.learning.EpochStepCounter;
import org.deeplearning4j.rl4j.learning.IHistoryProcessor;
-import org.deeplearning4j.rl4j.learning.ILearning;
-import org.deeplearning4j.rl4j.learning.StepCountable;
import org.deeplearning4j.rl4j.mdp.MDP;
import org.deeplearning4j.rl4j.observation.Observation;
import org.deeplearning4j.rl4j.space.ActionSpace;
@@ -19,50 +20,20 @@ public class LegacyMDPWrapper> implements MDP wrappedMDP;
@Getter
private final WrapperObservationSpace observationSpace;
- private final ILearning learning;
+
+ @Getter(AccessLevel.PRIVATE) @Setter(AccessLevel.PUBLIC)
private IHistoryProcessor historyProcessor;
- private final StepCountable stepCountable;
- private int skipFrame;
- private int step = 0;
+ private final EpochStepCounter epochStepCounter;
- public LegacyMDPWrapper(MDP wrappedMDP, ILearning learning) {
- this(wrappedMDP, learning, null, null);
- }
+ private int skipFrame = 1;
+ private int requiredFrame = 0;
- public LegacyMDPWrapper(MDP wrappedMDP, IHistoryProcessor historyProcessor, StepCountable stepCountable) {
- this(wrappedMDP, null, historyProcessor, stepCountable);
- }
-
- private LegacyMDPWrapper(MDP wrappedMDP, ILearning learning, IHistoryProcessor historyProcessor, StepCountable stepCountable) {
+ public LegacyMDPWrapper(MDP wrappedMDP, IHistoryProcessor historyProcessor, EpochStepCounter epochStepCounter) {
this.wrappedMDP = wrappedMDP;
this.observationSpace = new WrapperObservationSpace(wrappedMDP.getObservationSpace().getShape());
- this.learning = learning;
this.historyProcessor = historyProcessor;
- this.stepCountable = stepCountable;
- }
-
- private IHistoryProcessor getHistoryProcessor() {
- if(historyProcessor != null) {
- return historyProcessor;
- }
-
- if (learning != null) {
- return learning.getHistoryProcessor();
- }
- return null;
- }
-
- public void setHistoryProcessor(IHistoryProcessor historyProcessor) {
- this.historyProcessor = historyProcessor;
- }
-
- private int getStep() {
- if(stepCountable != null) {
- return stepCountable.getStepCounter();
- }
-
- return learning.getStepCounter();
+ this.epochStepCounter = epochStepCounter;
}
@Override
@@ -83,9 +54,12 @@ public class LegacyMDPWrapper> implements MDP> implements MDP rawStepReply = wrappedMDP.step(a);
INDArray rawObservation = getInput(rawStepReply.getObservation());
- ++step;
+ int stepOfObservation = epochStepCounter.getCurrentEpochStep() + 1;
- int requiredFrame = 0;
if(historyProcessor != null) {
historyProcessor.record(rawObservation);
- requiredFrame = skipFrame * (historyProcessor.getConf().getHistoryLength() - 1);
- if ((getStep() % skipFrame == 0 && step >= requiredFrame)
- || (step % skipFrame == 0 && step < requiredFrame )){
+ if (stepOfObservation % skipFrame == 0) {
historyProcessor.add(rawObservation);
}
}
Observation observation;
- if(historyProcessor != null && step >= requiredFrame) {
+ if(historyProcessor != null && stepOfObservation >= requiredFrame) {
observation = new Observation(historyProcessor.getHistory(), true);
observation.getData().muli(1.0 / historyProcessor.getScale());
}
@@ -119,6 +90,10 @@ public class LegacyMDPWrapper> implements MDP(observation, rawStepReply.getReward(), rawStepReply.isDone(), rawStepReply.getInfo());
}
@@ -134,7 +109,7 @@ public class LegacyMDPWrapper> implements MDP newInstance() {
- return new LegacyMDPWrapper(wrappedMDP.newInstance(), learning);
+ return new LegacyMDPWrapper(wrappedMDP.newInstance(), historyProcessor, epochStepCounter);
}
private INDArray getInput(O obs) {
diff --git a/rl4j/rl4j-core/src/test/java/org/deeplearning4j/rl4j/learning/async/AsyncThreadDiscreteTest.java b/rl4j/rl4j-core/src/test/java/org/deeplearning4j/rl4j/learning/async/AsyncThreadDiscreteTest.java
index b94541159..8a0090b62 100644
--- a/rl4j/rl4j-core/src/test/java/org/deeplearning4j/rl4j/learning/async/AsyncThreadDiscreteTest.java
+++ b/rl4j/rl4j-core/src/test/java/org/deeplearning4j/rl4j/learning/async/AsyncThreadDiscreteTest.java
@@ -32,7 +32,7 @@ public class AsyncThreadDiscreteTest {
MockMDP mdpMock = new MockMDP(observationSpace);
TrainingListenerList listeners = new TrainingListenerList();
MockPolicy policyMock = new MockPolicy();
- MockAsyncConfiguration config = new MockAsyncConfiguration(5, 16, 0, 0, 2, 5,0, 0, 0, 0);
+ MockAsyncConfiguration config = new MockAsyncConfiguration(5, 100, 0, 0, 2, 5,0, 0, 0, 0);
TestAsyncThreadDiscrete sut = new TestAsyncThreadDiscrete(asyncGlobalMock, mdpMock, listeners, 0, 0, policyMock, config, hpMock);
// Act
@@ -41,8 +41,8 @@ public class AsyncThreadDiscreteTest {
// Assert
assertEquals(2, sut.trainSubEpochResults.size());
double[][] expectedLastObservations = new double[][] {
- new double[] { 4.0, 6.0, 8.0, 9.0, 11.0 },
- new double[] { 8.0, 9.0, 11.0, 13.0, 15.0 },
+ new double[] { 4.0, 6.0, 8.0, 10.0, 12.0 },
+ new double[] { 8.0, 10.0, 12.0, 14.0, 16.0 },
};
double[] expectedSubEpochReturnRewards = new double[] { 42.0, 58.0 };
for(int i = 0; i < 2; ++i) {
@@ -60,7 +60,7 @@ public class AsyncThreadDiscreteTest {
assertEquals(2, asyncGlobalMock.enqueueCallCount);
// HistoryProcessor
- double[] expectedAddValues = new double[] { 0.0, 2.0, 4.0, 6.0, 8.0, 9.0, 11.0, 13.0, 15.0 };
+ double[] expectedAddValues = new double[] { 0.0, 2.0, 4.0, 6.0, 8.0, 10.0, 12.0, 14.0, 16.0 };
assertEquals(expectedAddValues.length, hpMock.addCalls.size());
for(int i = 0; i < expectedAddValues.length; ++i) {
assertEquals(expectedAddValues[i], hpMock.addCalls.get(i).getDouble(0), 0.00001);
@@ -75,9 +75,9 @@ public class AsyncThreadDiscreteTest {
// Policy
double[][] expectedPolicyInputs = new double[][] {
new double[] { 0.0, 2.0, 4.0, 6.0, 8.0 },
- new double[] { 2.0, 4.0, 6.0, 8.0, 9.0 },
- new double[] { 4.0, 6.0, 8.0, 9.0, 11.0 },
- new double[] { 6.0, 8.0, 9.0, 11.0, 13.0 },
+ new double[] { 2.0, 4.0, 6.0, 8.0, 10.0 },
+ new double[] { 4.0, 6.0, 8.0, 10.0, 12.0 },
+ new double[] { 6.0, 8.0, 10.0, 12.0, 14.0 },
};
assertEquals(expectedPolicyInputs.length, policyMock.actionInputs.size());
for(int i = 0; i < expectedPolicyInputs.length; ++i) {
@@ -93,11 +93,11 @@ public class AsyncThreadDiscreteTest {
assertEquals(2, nnMock.copyCallCount);
double[][] expectedNNInputs = new double[][] {
new double[] { 0.0, 2.0, 4.0, 6.0, 8.0 },
- new double[] { 2.0, 4.0, 6.0, 8.0, 9.0 },
- new double[] { 4.0, 6.0, 8.0, 9.0, 11.0 }, // FIXME: This one comes from the computation of output of the last minitrans
- new double[] { 4.0, 6.0, 8.0, 9.0, 11.0 },
- new double[] { 6.0, 8.0, 9.0, 11.0, 13.0 },
- new double[] { 8.0, 9.0, 11.0, 13.0, 15.0 }, // FIXME: This one comes from the computation of output of the last minitrans
+ new double[] { 2.0, 4.0, 6.0, 8.0, 10.0 },
+ new double[] { 4.0, 6.0, 8.0, 10.0, 12.0 }, // FIXME: This one comes from the computation of output of the last minitrans
+ new double[] { 4.0, 6.0, 8.0, 10.0, 12.0 },
+ new double[] { 6.0, 8.0, 10.0, 12.0, 14.0 },
+ new double[] { 8.0, 10.0, 12.0, 14.0, 16.0 }, // FIXME: This one comes from the computation of output of the last minitrans
};
assertEquals(expectedNNInputs.length, nnMock.outputAllInputs.size());
for(int i = 0; i < expectedNNInputs.length; ++i) {
@@ -113,13 +113,13 @@ public class AsyncThreadDiscreteTest {
double[][][] expectedMinitransObs = new double[][][] {
new double[][] {
new double[] { 0.0, 2.0, 4.0, 6.0, 8.0 },
- new double[] { 2.0, 4.0, 6.0, 8.0, 9.0 },
- new double[] { 4.0, 6.0, 8.0, 9.0, 11.0 }, // FIXME: The last minitrans contains the next observation
+ new double[] { 2.0, 4.0, 6.0, 8.0, 10.0 },
+ new double[] { 4.0, 6.0, 8.0, 10.0, 12.0 }, // FIXME: The last minitrans contains the next observation
},
new double[][] {
- new double[] { 4.0, 6.0, 8.0, 9.0, 11.0 },
- new double[] { 6.0, 8.0, 9.0, 11.0, 13.0 },
- new double[] { 8.0, 9.0, 11.0, 13.0, 15 }, // FIXME: The last minitrans contains the next observation
+ new double[] { 4.0, 6.0, 8.0, 10.0, 12.0 },
+ new double[] { 6.0, 8.0, 10.0, 12.0, 14.0 },
+ new double[] { 8.0, 10.0, 12.0, 14.0, 16.0 }, // FIXME: The last minitrans contains the next observation
}
};
double[] expectedOutputs = new double[] { 1.0, 2.0, 3.0, 4.0, 5.0, 6.0 };
diff --git a/rl4j/rl4j-core/src/test/java/org/deeplearning4j/rl4j/learning/async/AsyncThreadTest.java b/rl4j/rl4j-core/src/test/java/org/deeplearning4j/rl4j/learning/async/AsyncThreadTest.java
index 377b32175..b01105294 100644
--- a/rl4j/rl4j-core/src/test/java/org/deeplearning4j/rl4j/learning/async/AsyncThreadTest.java
+++ b/rl4j/rl4j-core/src/test/java/org/deeplearning4j/rl4j/learning/async/AsyncThreadTest.java
@@ -5,15 +5,12 @@ import lombok.Getter;
import org.deeplearning4j.rl4j.learning.IHistoryProcessor;
import org.deeplearning4j.rl4j.learning.listener.TrainingListenerList;
import org.deeplearning4j.rl4j.mdp.MDP;
-import org.deeplearning4j.rl4j.network.NeuralNet;
import org.deeplearning4j.rl4j.observation.Observation;
import org.deeplearning4j.rl4j.policy.Policy;
import org.deeplearning4j.rl4j.space.DiscreteSpace;
-import org.deeplearning4j.rl4j.space.Encodable;
import org.deeplearning4j.rl4j.support.*;
import org.deeplearning4j.rl4j.util.IDataManager;
import org.junit.Test;
-import org.nd4j.linalg.api.ndarray.INDArray;
import java.util.ArrayList;
import java.util.List;
@@ -91,7 +88,7 @@ public class AsyncThreadTest {
// Assert
assertEquals(numberOfEpochs, context.listener.statEntries.size());
- int[] expectedStepCounter = new int[] { 2, 4, 6, 8, 10 };
+ int[] expectedStepCounter = new int[] { 10, 20, 30, 40, 50 };
double expectedReward = (1.0 + 2.0 + 3.0 + 4.0 + 5.0 + 6.0 + 7.0 + 8.0) // reward from init
+ 1.0; // Reward from trainSubEpoch()
for(int i = 0; i < numberOfEpochs; ++i) {
@@ -114,7 +111,7 @@ public class AsyncThreadTest {
// Assert
assertEquals(numberOfEpochs, context.sut.trainSubEpochParams.size());
double[] expectedObservation = new double[] { 0.0, 2.0, 4.0, 6.0, 8.0 };
- for(int i = 0; i < context.sut.getEpochCounter(); ++i) {
+ for(int i = 0; i < context.sut.trainSubEpochParams.size(); ++i) {
MockAsyncThread.TrainSubEpochParams params = context.sut.trainSubEpochParams.get(i);
assertEquals(2, params.nstep);
assertEquals(expectedObservation.length, params.obs.getData().shape()[1]);
@@ -199,7 +196,9 @@ public class AsyncThreadTest {
protected SubEpochReturn trainSubEpoch(Observation obs, int nstep) {
asyncGlobal.increaseCurrentLoop();
trainSubEpochParams.add(new TrainSubEpochParams(obs, nstep));
- setStepCounter(getStepCounter() + nstep);
+ for(int i = 0; i < nstep; ++i) {
+ incrementStep();
+ }
return new SubEpochReturn(nstep, null, 1.0, 1.0);
}
diff --git a/rl4j/rl4j-core/src/test/java/org/deeplearning4j/rl4j/learning/sync/ExpReplayTest.java b/rl4j/rl4j-core/src/test/java/org/deeplearning4j/rl4j/learning/sync/ExpReplayTest.java
index 373c4b189..73b27776a 100644
--- a/rl4j/rl4j-core/src/test/java/org/deeplearning4j/rl4j/learning/sync/ExpReplayTest.java
+++ b/rl4j/rl4j-core/src/test/java/org/deeplearning4j/rl4j/learning/sync/ExpReplayTest.java
@@ -18,8 +18,8 @@ public class ExpReplayTest {
ExpReplay sut = new ExpReplay(2, 1, randomMock);
// Act
- Transition transition = new Transition(new Observation(new INDArray[] { Nd4j.create(1) }),
- 123, 234, false, new Observation(Nd4j.create(1)));
+ Transition transition = buildTransition(new Observation(new INDArray[] { Nd4j.create(1) }),
+ 123, 234, new Observation(Nd4j.create(1)));
sut.store(transition);
List> results = sut.getBatch(1);
@@ -36,12 +36,12 @@ public class ExpReplayTest {
ExpReplay sut = new ExpReplay(2, 1, randomMock);
// Act
- Transition transition1 = new Transition(new Observation(new INDArray[] { Nd4j.create(1) }),
- 1, 2, false, new Observation(Nd4j.create(1)));
- Transition transition2 = new Transition(new Observation(new INDArray[] { Nd4j.create(1) }),
- 3, 4, false, new Observation(Nd4j.create(1)));
- Transition transition3 = new Transition(new Observation(new INDArray[] { Nd4j.create(1) }),
- 5, 6, false, new Observation(Nd4j.create(1)));
+ Transition transition1 = buildTransition(new Observation(new INDArray[] { Nd4j.create(1) }),
+ 1, 2, new Observation(Nd4j.create(1)));
+ Transition transition2 = buildTransition(new Observation(new INDArray[] { Nd4j.create(1) }),
+ 3, 4, new Observation(Nd4j.create(1)));
+ Transition transition3 = buildTransition(new Observation(new INDArray[] { Nd4j.create(1) }),
+ 5, 6, new Observation(Nd4j.create(1)));
sut.store(transition1);
sut.store(transition2);
sut.store(transition3);
@@ -78,12 +78,12 @@ public class ExpReplayTest {
ExpReplay sut = new ExpReplay(5, 1, randomMock);
// Act
- Transition transition1 = new Transition(new Observation(new INDArray[] { Nd4j.create(1) }),
- 1, 2, false, new Observation(Nd4j.create(1)));
- Transition transition2 = new Transition(new Observation(new INDArray[] { Nd4j.create(1) }),
- 3, 4, false, new Observation(Nd4j.create(1)));
- Transition transition3 = new Transition(new Observation(new INDArray[] { Nd4j.create(1) }),
- 5, 6, false, new Observation(Nd4j.create(1)));
+ Transition transition1 = buildTransition(new Observation(new INDArray[] { Nd4j.create(1) }),
+ 1, 2, new Observation(Nd4j.create(1)));
+ Transition transition2 = buildTransition(new Observation(new INDArray[] { Nd4j.create(1) }),
+ 3, 4, new Observation(Nd4j.create(1)));
+ Transition transition3 = buildTransition(new Observation(new INDArray[] { Nd4j.create(1) }),
+ 5, 6, new Observation(Nd4j.create(1)));
sut.store(transition1);
sut.store(transition2);
sut.store(transition3);
@@ -100,12 +100,12 @@ public class ExpReplayTest {
ExpReplay sut = new ExpReplay(5, 1, randomMock);
// Act
- Transition transition1 = new Transition(new Observation(new INDArray[] { Nd4j.create(1) }),
- 1, 2, false, new Observation(Nd4j.create(1)));
- Transition transition2 = new Transition(new Observation(new INDArray[] { Nd4j.create(1) }),
- 3, 4, false, new Observation(Nd4j.create(1)));
- Transition transition3 = new Transition(new Observation(new INDArray[] { Nd4j.create(1) }),
- 5, 6, false, new Observation(Nd4j.create(1)));
+ Transition transition1 = buildTransition(new Observation(new INDArray[] { Nd4j.create(1) }),
+ 1, 2, new Observation(Nd4j.create(1)));
+ Transition transition2 = buildTransition(new Observation(new INDArray[] { Nd4j.create(1) }),
+ 3, 4, new Observation(Nd4j.create(1)));
+ Transition transition3 = buildTransition(new Observation(new INDArray[] { Nd4j.create(1) }),
+ 5, 6, new Observation(Nd4j.create(1)));
sut.store(transition1);
sut.store(transition2);
sut.store(transition3);
@@ -131,16 +131,16 @@ public class ExpReplayTest {
ExpReplay sut = new ExpReplay(5, 1, randomMock);
// Act
- Transition transition1 = new Transition(new Observation(new INDArray[] { Nd4j.create(1) }),
- 1, 2, false, new Observation(Nd4j.create(1)));
- Transition transition2 = new Transition(new Observation(new INDArray[] { Nd4j.create(1) }),
- 3, 4, false, new Observation(Nd4j.create(1)));
- Transition transition3 = new Transition(new Observation(new INDArray[] { Nd4j.create(1) }),
- 5, 6, false, new Observation(Nd4j.create(1)));
- Transition transition4 = new Transition(new Observation(new INDArray[] { Nd4j.create(1) }),
- 7, 8, false, new Observation(Nd4j.create(1)));
- Transition transition5 = new Transition(new Observation(new INDArray[] { Nd4j.create(1) }),
- 9, 10, false, new Observation(Nd4j.create(1)));
+ Transition transition1 = buildTransition(new Observation(new INDArray[] { Nd4j.create(1) }),
+ 1, 2, new Observation(Nd4j.create(1)));
+ Transition transition2 = buildTransition(new Observation(new INDArray[] { Nd4j.create(1) }),
+ 3, 4, new Observation(Nd4j.create(1)));
+ Transition transition3 = buildTransition(new Observation(new INDArray[] { Nd4j.create(1) }),
+ 5, 6, new Observation(Nd4j.create(1)));
+ Transition transition4 = buildTransition(new Observation(new INDArray[] { Nd4j.create(1) }),
+ 7, 8, new Observation(Nd4j.create(1)));
+ Transition transition5 = buildTransition(new Observation(new INDArray[] { Nd4j.create(1) }),
+ 9, 10, new Observation(Nd4j.create(1)));
sut.store(transition1);
sut.store(transition2);
sut.store(transition3);
@@ -168,16 +168,16 @@ public class ExpReplayTest {
ExpReplay sut = new ExpReplay(5, 1, randomMock);
// Act
- Transition transition1 = new Transition(new Observation(new INDArray[] { Nd4j.create(1) }),
- 1, 2, false, new Observation(Nd4j.create(1)));
- Transition transition2 = new Transition(new Observation(new INDArray[] { Nd4j.create(1) }),
- 3, 4, false, new Observation(Nd4j.create(1)));
- Transition transition3 = new Transition(new Observation(new INDArray[] { Nd4j.create(1) }),
- 5, 6, false, new Observation(Nd4j.create(1)));
- Transition transition4 = new Transition(new Observation(new INDArray[] { Nd4j.create(1) }),
- 7, 8, false, new Observation(Nd4j.create(1)));
- Transition transition5 = new Transition(new Observation(new INDArray[] { Nd4j.create(1) }),
- 9, 10, false, new Observation(Nd4j.create(1)));
+ Transition transition1 = buildTransition(new Observation(new INDArray[] { Nd4j.create(1) }),
+ 1, 2, new Observation(Nd4j.create(1)));
+ Transition transition2 = buildTransition(new Observation(new INDArray[] { Nd4j.create(1) }),
+ 3, 4, new Observation(Nd4j.create(1)));
+ Transition transition3 = buildTransition(new Observation(new INDArray[] { Nd4j.create(1) }),
+ 5, 6, new Observation(Nd4j.create(1)));
+ Transition transition4 = buildTransition(new Observation(new INDArray[] { Nd4j.create(1) }),
+ 7, 8, new Observation(Nd4j.create(1)));
+ Transition transition5 = buildTransition(new Observation(new INDArray[] { Nd4j.create(1) }),
+ 9, 10, new Observation(Nd4j.create(1)));
sut.store(transition1);
sut.store(transition2);
sut.store(transition3);
@@ -196,6 +196,12 @@ public class ExpReplayTest {
assertEquals(5, (int)results.get(2).getAction());
assertEquals(6, (int)results.get(2).getReward());
+ }
+ private Transition buildTransition(Observation observation, Integer action, double reward, Observation nextObservation) {
+ Transition result = new Transition(observation, action, reward, false);
+ result.setNextObservation(nextObservation);
+
+ return result;
}
}
diff --git a/rl4j/rl4j-core/src/test/java/org/deeplearning4j/rl4j/learning/sync/SyncLearningTest.java b/rl4j/rl4j-core/src/test/java/org/deeplearning4j/rl4j/learning/sync/SyncLearningTest.java
index 9b89390d8..79be025b5 100644
--- a/rl4j/rl4j-core/src/test/java/org/deeplearning4j/rl4j/learning/sync/SyncLearningTest.java
+++ b/rl4j/rl4j-core/src/test/java/org/deeplearning4j/rl4j/learning/sync/SyncLearningTest.java
@@ -1,5 +1,6 @@
package org.deeplearning4j.rl4j.learning.sync;
+import lombok.Getter;
import org.deeplearning4j.rl4j.learning.sync.qlearning.QLearning;
import org.deeplearning4j.rl4j.learning.sync.support.MockStatEntry;
import org.deeplearning4j.rl4j.mdp.MDP;
@@ -88,12 +89,15 @@ public class SyncLearningTest {
private final LConfiguration conf;
+ @Getter
+ private int currentEpochStep = 0;
+
public MockSyncLearning(LConfiguration conf) {
this.conf = conf;
}
@Override
- protected void preEpoch() { }
+ protected void preEpoch() { currentEpochStep = 0; }
@Override
protected void postEpoch() { }
@@ -101,7 +105,7 @@ public class SyncLearningTest {
@Override
protected IDataManager.StatEntry trainEpoch() {
setStepCounter(getStepCounter() + 1);
- return new MockStatEntry(getEpochCounter(), getStepCounter(), 1.0);
+ return new MockStatEntry(getCurrentEpochStep(), getStepCounter(), 1.0);
}
@Override
diff --git a/rl4j/rl4j-core/src/test/java/org/deeplearning4j/rl4j/learning/sync/TransitionTest.java b/rl4j/rl4j-core/src/test/java/org/deeplearning4j/rl4j/learning/sync/TransitionTest.java
index b74ebac11..944e41a31 100644
--- a/rl4j/rl4j-core/src/test/java/org/deeplearning4j/rl4j/learning/sync/TransitionTest.java
+++ b/rl4j/rl4j-core/src/test/java/org/deeplearning4j/rl4j/learning/sync/TransitionTest.java
@@ -21,7 +21,7 @@ public class TransitionTest {
Observation nextObservation = buildObservation(nextObs);
// Act
- Transition transition = new Transition(observation, 123, 234.0, false, nextObservation);
+ Transition transition = buildTransition(observation, 123, 234.0, nextObservation);
// Assert
double[][] expectedObservation = new double[][] { obs };
@@ -52,7 +52,7 @@ public class TransitionTest {
Observation nextObservation = buildObservation(nextObs);
// Act
- Transition transition = new Transition(observation, 123, 234.0, false, nextObservation);
+ Transition transition = buildTransition(observation, 123, 234.0, nextObservation);
// Assert
assertExpected(obs, transition.getObservation().getData());
@@ -71,12 +71,12 @@ public class TransitionTest {
double[] obs1 = new double[] { 0.0, 1.0, 2.0 };
Observation observation1 = buildObservation(obs1);
Observation nextObservation1 = buildObservation(new double[] { 100.0, 101.0, 102.0 });
- transitions.add(new Transition(observation1,0, 0.0, false, nextObservation1));
+ transitions.add(buildTransition(observation1,0, 0.0, nextObservation1));
double[] obs2 = new double[] { 10.0, 11.0, 12.0 };
Observation observation2 = buildObservation(obs2);
Observation nextObservation2 = buildObservation(new double[] { 110.0, 111.0, 112.0 });
- transitions.add(new Transition(observation2, 0, 0.0, false, nextObservation2));
+ transitions.add(buildTransition(observation2, 0, 0.0, nextObservation2));
// Act
INDArray result = Transition.buildStackedObservations(transitions);
@@ -101,7 +101,7 @@ public class TransitionTest {
double[] nextObs1 = new double[] { 100.0, 101.0, 102.0 };
Observation nextObservation1 = buildNextObservation(obs1, nextObs1);
- transitions.add(new Transition(observation1, 0, 0.0, false, nextObservation1));
+ transitions.add(buildTransition(observation1, 0, 0.0, nextObservation1));
double[][] obs2 = new double[][] {
{ 10.0, 11.0, 12.0 },
@@ -112,7 +112,7 @@ public class TransitionTest {
double[] nextObs2 = new double[] { 110.0, 111.0, 112.0 };
Observation nextObservation2 = buildNextObservation(obs2, nextObs2);
- transitions.add(new Transition(observation2, 0, 0.0, false, nextObservation2));
+ transitions.add(buildTransition(observation2, 0, 0.0, nextObservation2));
// Act
INDArray result = Transition.buildStackedObservations(transitions);
@@ -131,13 +131,13 @@ public class TransitionTest {
double[] nextObs1 = new double[] { 100.0, 101.0, 102.0 };
Observation observation1 = buildObservation(obs1);
Observation nextObservation1 = buildObservation(nextObs1);
- transitions.add(new Transition(observation1, 0, 0.0, false, nextObservation1));
+ transitions.add(buildTransition(observation1, 0, 0.0, nextObservation1));
double[] obs2 = new double[] { 10.0, 11.0, 12.0 };
double[] nextObs2 = new double[] { 110.0, 111.0, 112.0 };
Observation observation2 = buildObservation(obs2);
Observation nextObservation2 = buildObservation(nextObs2);
- transitions.add(new Transition(observation2, 0, 0.0, false, nextObservation2));
+ transitions.add(buildTransition(observation2, 0, 0.0, nextObservation2));
// Act
INDArray result = Transition.buildStackedNextObservations(transitions);
@@ -162,7 +162,7 @@ public class TransitionTest {
double[] nextObs1 = new double[] { 100.0, 101.0, 102.0 };
Observation nextObservation1 = buildNextObservation(obs1, nextObs1);
- transitions.add(new Transition(observation1, 0, 0.0, false, nextObservation1));
+ transitions.add(buildTransition(observation1, 0, 0.0, nextObservation1));
double[][] obs2 = new double[][] {
{ 10.0, 11.0, 12.0 },
@@ -174,7 +174,7 @@ public class TransitionTest {
double[] nextObs2 = new double[] { 110.0, 111.0, 112.0 };
Observation nextObservation2 = buildNextObservation(obs2, nextObs2);
- transitions.add(new Transition(observation2, 0, 0.0, false, nextObservation2));
+ transitions.add(buildTransition(observation2, 0, 0.0, nextObservation2));
// Act
INDArray result = Transition.buildStackedNextObservations(transitions);
@@ -207,7 +207,13 @@ public class TransitionTest {
Nd4j.create(obs[1]).reshape(1, 3),
};
return new Observation(nextHistory);
+ }
+ private Transition buildTransition(Observation observation, int action, double reward, Observation nextObservation) {
+ Transition result = new Transition(observation, action, reward, false);
+ result.setNextObservation(nextObservation);
+
+ return result;
}
private void assertExpected(double[] expected, INDArray actual) {
diff --git a/rl4j/rl4j-core/src/test/java/org/deeplearning4j/rl4j/learning/sync/qlearning/discrete/QLearningDiscreteTest.java b/rl4j/rl4j-core/src/test/java/org/deeplearning4j/rl4j/learning/sync/qlearning/discrete/QLearningDiscreteTest.java
index 22e1ba49d..ee8b365f0 100644
--- a/rl4j/rl4j-core/src/test/java/org/deeplearning4j/rl4j/learning/sync/qlearning/discrete/QLearningDiscreteTest.java
+++ b/rl4j/rl4j-core/src/test/java/org/deeplearning4j/rl4j/learning/sync/qlearning/discrete/QLearningDiscreteTest.java
@@ -40,8 +40,10 @@ public class QLearningDiscreteTest {
new int[] { 1, 2, 3, 4, 1, 2, 3, 4, 1, 2, 3, 4, 1, 2, 3, 4 });
MockMDP mdp = new MockMDP(observationSpace, random);
+ int initStepCount = 8;
+
QLearning.QLConfiguration conf = new QLearning.QLConfiguration(0, 24, 0, 5, 1, 1000,
- 0, 1.0, 0, 0, 0, 0, true);
+ initStepCount, 1.0, 0, 0, 0, 0, true);
MockDataManager dataManager = new MockDataManager(false);
MockExpReplay expReplay = new MockExpReplay();
TestQLearningDiscrete sut = new TestQLearningDiscrete(mdp, dqn, conf, dataManager, expReplay, 10, random);
@@ -60,7 +62,7 @@ public class QLearningDiscreteTest {
for(int i = 0; i < expectedRecords.length; ++i) {
assertEquals(expectedRecords[i], hp.recordCalls.get(i).getDouble(0), 0.0001);
}
- double[] expectedAdds = new double[] { 0.0, 2.0, 4.0, 6.0, 8.0, 9.0, 11.0, 13.0, 15.0, 17.0, 19.0, 21.0, 23.0 };
+ double[] expectedAdds = new double[] { 0.0, 2.0, 4.0, 6.0, 8.0, 10.0, 12.0, 14.0, 16.0, 18.0, 20.0, 22.0, 24.0 };
assertEquals(expectedAdds.length, hp.addCalls.size());
for(int i = 0; i < expectedAdds.length; ++i) {
assertEquals(expectedAdds[i], hp.addCalls.get(i).getDouble(0), 0.0001);
@@ -75,19 +77,19 @@ public class QLearningDiscreteTest {
assertEquals(14, dqn.outputParams.size());
double[][] expectedDQNOutput = new double[][] {
new double[] { 0.0, 2.0, 4.0, 6.0, 8.0 },
- new double[] { 2.0, 4.0, 6.0, 8.0, 9.0 },
- new double[] { 2.0, 4.0, 6.0, 8.0, 9.0 },
- new double[] { 4.0, 6.0, 8.0, 9.0, 11.0 },
- new double[] { 6.0, 8.0, 9.0, 11.0, 13.0 },
- new double[] { 6.0, 8.0, 9.0, 11.0, 13.0 },
- new double[] { 8.0, 9.0, 11.0, 13.0, 15.0 },
- new double[] { 8.0, 9.0, 11.0, 13.0, 15.0 },
- new double[] { 9.0, 11.0, 13.0, 15.0, 17.0 },
- new double[] { 9.0, 11.0, 13.0, 15.0, 17.0 },
- new double[] { 11.0, 13.0, 15.0, 17.0, 19.0 },
- new double[] { 11.0, 13.0, 15.0, 17.0, 19.0 },
- new double[] { 13.0, 15.0, 17.0, 19.0, 21.0 },
- new double[] { 13.0, 15.0, 17.0, 19.0, 21.0 },
+ new double[] { 2.0, 4.0, 6.0, 8.0, 10.0 },
+ new double[] { 2.0, 4.0, 6.0, 8.0, 10.0 },
+ new double[] { 4.0, 6.0, 8.0, 10.0, 12.0 },
+ new double[] { 6.0, 8.0, 10.0, 12.0, 14.0 },
+ new double[] { 6.0, 8.0, 10.0, 12.0, 14.0 },
+ new double[] { 8.0, 10.0, 12.0, 14.0, 16.0 },
+ new double[] { 8.0, 10.0, 12.0, 14.0, 16.0 },
+ new double[] { 10.0, 12.0, 14.0, 16.0, 18.0 },
+ new double[] { 10.0, 12.0, 14.0, 16.0, 18.0 },
+ new double[] { 12.0, 14.0, 16.0, 18.0, 20.0 },
+ new double[] { 12.0, 14.0, 16.0, 18.0, 20.0 },
+ new double[] { 14.0, 16.0, 18.0, 20.0, 22.0 },
+ new double[] { 14.0, 16.0, 18.0, 20.0, 22.0 },
};
for(int i = 0; i < expectedDQNOutput.length; ++i) {
INDArray outputParam = dqn.outputParams.get(i);
@@ -105,19 +107,20 @@ public class QLearningDiscreteTest {
assertArrayEquals(new Integer[] {0, 0, 0, 0, 0, 0, 0, 0, 1, 1, 4, 4, 2, 2, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4 }, mdp.actions.toArray());
// ExpReplay calls
- double[] expectedTrRewards = new double[] { 9.0, 21.0, 25.0, 29.0, 33.0, 37.0, 41.0, 45.0 };
+ double[] expectedTrRewards = new double[] { 9.0, 21.0, 25.0, 29.0, 33.0, 37.0, 41.0 };
int[] expectedTrActions = new int[] { 1, 4, 2, 4, 4, 4, 4, 4 };
- double[] expectedTrNextObservation = new double[] { 2.0, 4.0, 6.0, 8.0, 9.0, 11.0, 13.0, 15.0 };
+ double[] expectedTrNextObservation = new double[] { 2.0, 4.0, 6.0, 8.0, 10.0, 12.0, 14.0, 16.0 };
double[][] expectedTrObservations = new double[][] {
new double[] { 0.0, 2.0, 4.0, 6.0, 8.0 },
- new double[] { 2.0, 4.0, 6.0, 8.0, 9.0 },
- new double[] { 4.0, 6.0, 8.0, 9.0, 11.0 },
- new double[] { 6.0, 8.0, 9.0, 11.0, 13.0 },
- new double[] { 8.0, 9.0, 11.0, 13.0, 15.0 },
- new double[] { 9.0, 11.0, 13.0, 15.0, 17.0 },
- new double[] { 11.0, 13.0, 15.0, 17.0, 19.0 },
- new double[] { 13.0, 15.0, 17.0, 19.0, 21.0 },
+ new double[] { 2.0, 4.0, 6.0, 8.0, 10.0 },
+ new double[] { 4.0, 6.0, 8.0, 10.0, 12.0 },
+ new double[] { 6.0, 8.0, 10.0, 12.0, 14.0 },
+ new double[] { 8.0, 10.0, 12.0, 14.0, 16.0 },
+ new double[] { 10.0, 12.0, 14.0, 16.0, 18.0 },
+ new double[] { 12.0, 14.0, 16.0, 18.0, 20.0 },
+ new double[] { 14.0, 16.0, 18.0, 20.0, 22.0 },
};
+ assertEquals(expectedTrObservations.length, expReplay.transitions.size());
for(int i = 0; i < expectedTrRewards.length; ++i) {
Transition tr = expReplay.transitions.get(i);
assertEquals(expectedTrRewards[i], tr.getReward(), 0.0001);
@@ -129,7 +132,7 @@ public class QLearningDiscreteTest {
}
// trainEpoch result
- assertEquals(16, result.getStepCounter());
+ assertEquals(initStepCount + 16, result.getStepCounter());
assertEquals(300.0, result.getReward(), 0.00001);
assertTrue(dqn.hasBeenReset);
assertTrue(((MockDQN)sut.getTargetQNetwork()).hasBeenReset);
diff --git a/rl4j/rl4j-core/src/test/java/org/deeplearning4j/rl4j/learning/sync/qlearning/discrete/TDTargetAlgorithm/DoubleDQNTest.java b/rl4j/rl4j-core/src/test/java/org/deeplearning4j/rl4j/learning/sync/qlearning/discrete/TDTargetAlgorithm/DoubleDQNTest.java
index bb8af1950..760666f33 100644
--- a/rl4j/rl4j-core/src/test/java/org/deeplearning4j/rl4j/learning/sync/qlearning/discrete/TDTargetAlgorithm/DoubleDQNTest.java
+++ b/rl4j/rl4j-core/src/test/java/org/deeplearning4j/rl4j/learning/sync/qlearning/discrete/TDTargetAlgorithm/DoubleDQNTest.java
@@ -26,7 +26,7 @@ public class DoubleDQNTest {
List> transitions = new ArrayList>() {
{
- add(new Transition(buildObservation(new double[]{1.1, 2.2}),
+ add(builtTransition(buildObservation(new double[]{1.1, 2.2}),
0, 1.0, true, buildObservation(new double[]{11.0, 22.0})));
}
};
@@ -52,7 +52,7 @@ public class DoubleDQNTest {
List> transitions = new ArrayList>() {
{
- add(new Transition(buildObservation(new double[]{1.1, 2.2}),
+ add(builtTransition(buildObservation(new double[]{1.1, 2.2}),
0, 1.0, false, buildObservation(new double[]{11.0, 22.0})));
}
};
@@ -78,11 +78,11 @@ public class DoubleDQNTest {
List> transitions = new ArrayList>() {
{
- add(new Transition(buildObservation(new double[]{1.1, 2.2}),
+ add(builtTransition(buildObservation(new double[]{1.1, 2.2}),
0, 1.0, false, buildObservation(new double[]{11.0, 22.0})));
- add(new Transition(buildObservation(new double[]{3.3, 4.4}),
+ add(builtTransition(buildObservation(new double[]{3.3, 4.4}),
1, 2.0, false, buildObservation(new double[]{33.0, 44.0})));
- add(new Transition(buildObservation(new double[]{5.5, 6.6}),
+ add(builtTransition(buildObservation(new double[]{5.5, 6.6}),
0, 3.0, true, buildObservation(new double[]{55.0, 66.0})));
}
};
@@ -108,4 +108,11 @@ public class DoubleDQNTest {
private Observation buildObservation(double[] data) {
return new Observation(new INDArray[]{Nd4j.create(data).reshape(1, 2)});
}
+
+ private Transition builtTransition(Observation observation, Integer action, double reward, boolean isTerminal, Observation nextObservation) {
+ Transition result = new Transition(observation, action, reward, isTerminal);
+ result.setNextObservation(nextObservation);
+
+ return result;
+ }
}
diff --git a/rl4j/rl4j-core/src/test/java/org/deeplearning4j/rl4j/learning/sync/qlearning/discrete/TDTargetAlgorithm/StandardDQNTest.java b/rl4j/rl4j-core/src/test/java/org/deeplearning4j/rl4j/learning/sync/qlearning/discrete/TDTargetAlgorithm/StandardDQNTest.java
index d2608437d..c540646a8 100644
--- a/rl4j/rl4j-core/src/test/java/org/deeplearning4j/rl4j/learning/sync/qlearning/discrete/TDTargetAlgorithm/StandardDQNTest.java
+++ b/rl4j/rl4j-core/src/test/java/org/deeplearning4j/rl4j/learning/sync/qlearning/discrete/TDTargetAlgorithm/StandardDQNTest.java
@@ -25,7 +25,7 @@ public class StandardDQNTest {
List> transitions = new ArrayList>() {
{
- add(new Transition(buildObservation(new double[]{1.1, 2.2}),
+ add(buildTransition(buildObservation(new double[]{1.1, 2.2}),
0, 1.0, true, buildObservation(new double[]{11.0, 22.0})));
}
};
@@ -51,7 +51,7 @@ public class StandardDQNTest {
List> transitions = new ArrayList>() {
{
- add(new Transition(buildObservation(new double[]{1.1, 2.2}),
+ add(buildTransition(buildObservation(new double[]{1.1, 2.2}),
0, 1.0, false, buildObservation(new double[]{11.0, 22.0})));
}
};
@@ -77,11 +77,11 @@ public class StandardDQNTest {
List> transitions = new ArrayList>() {
{
- add(new Transition(buildObservation(new double[]{1.1, 2.2}),
+ add(buildTransition(buildObservation(new double[]{1.1, 2.2}),
0, 1.0, false, buildObservation(new double[]{11.0, 22.0})));
- add(new Transition(buildObservation(new double[]{3.3, 4.4}),
+ add(buildTransition(buildObservation(new double[]{3.3, 4.4}),
1, 2.0, false, buildObservation(new double[]{33.0, 44.0})));
- add(new Transition(buildObservation(new double[]{5.5, 6.6}),
+ add(buildTransition(buildObservation(new double[]{5.5, 6.6}),
0, 3.0, true, buildObservation(new double[]{55.0, 66.0})));
}
};
@@ -108,4 +108,10 @@ public class StandardDQNTest {
return new Observation(new INDArray[]{Nd4j.create(data).reshape(1, 2)});
}
+ private Transition buildTransition(Observation observation, Integer action, double reward, boolean isTerminal, Observation nextObservation) {
+ Transition result = new Transition(observation, action, reward, isTerminal);
+ result.setNextObservation(nextObservation);
+
+ return result;
+ }
}
diff --git a/rl4j/rl4j-core/src/test/java/org/deeplearning4j/rl4j/policy/PolicyTest.java b/rl4j/rl4j-core/src/test/java/org/deeplearning4j/rl4j/policy/PolicyTest.java
index 6fa2d06d4..f97457a52 100644
--- a/rl4j/rl4j-core/src/test/java/org/deeplearning4j/rl4j/policy/PolicyTest.java
+++ b/rl4j/rl4j-core/src/test/java/org/deeplearning4j/rl4j/policy/PolicyTest.java
@@ -198,7 +198,7 @@ public class PolicyTest {
assertEquals(465.0, totalReward, 0.0001);
// HistoryProcessor
- assertEquals(27, hp.addCalls.size());
+ assertEquals(16, hp.addCalls.size());
assertEquals(31, hp.recordCalls.size());
for(int i=0; i <= 30; ++i) {
assertEquals((double)i, hp.recordCalls.get(i).getDouble(0), 0.0001);
diff --git a/rl4j/rl4j-core/src/test/java/org/deeplearning4j/rl4j/util/DataManagerTrainingListenerTest.java b/rl4j/rl4j-core/src/test/java/org/deeplearning4j/rl4j/util/DataManagerTrainingListenerTest.java
index f6da6d378..a3a5598d4 100644
--- a/rl4j/rl4j-core/src/test/java/org/deeplearning4j/rl4j/util/DataManagerTrainingListenerTest.java
+++ b/rl4j/rl4j-core/src/test/java/org/deeplearning4j/rl4j/util/DataManagerTrainingListenerTest.java
@@ -142,6 +142,11 @@ public class DataManagerTrainingListenerTest {
return 0;
}
+ @Override
+ public int getCurrentEpochStep() {
+ return 0;
+ }
+
@Getter
@Setter
private IHistoryProcessor historyProcessor;
diff --git a/rl4j/rl4j-gym/src/main/java/org/deeplearning4j/rl4j/mdp/gym/GymEnv.java b/rl4j/rl4j-gym/src/main/java/org/deeplearning4j/rl4j/mdp/gym/GymEnv.java
index 66be7698f..afeea5f7c 100644
--- a/rl4j/rl4j-gym/src/main/java/org/deeplearning4j/rl4j/mdp/gym/GymEnv.java
+++ b/rl4j/rl4j-gym/src/main/java/org/deeplearning4j/rl4j/mdp/gym/GymEnv.java
@@ -93,6 +93,9 @@ public class GymEnv> implements MDP {
private boolean done = false;
public GymEnv(String envId, boolean render, boolean monitor) {
+ this(envId, render, monitor, (Integer)null);
+ }
+ public GymEnv(String envId, boolean render, boolean monitor, Integer seed) {
this.envId = envId;
this.render = render;
this.monitor = monitor;
@@ -107,6 +110,10 @@ public class GymEnv> implements MDP {
Py_DecRef(PyRun_StringFlags("env = gym.wrappers.Monitor(env, '" + GYM_MONITOR_DIR + "')", Py_single_input, globals, locals, null));
checkPythonError();
}
+ if (seed != null) {
+ Py_DecRef(PyRun_StringFlags("env.seed(" + seed + ")", Py_single_input, globals, locals, null));
+ checkPythonError();
+ }
PyObject shapeTuple = PyRun_StringFlags("env.observation_space.shape", Py_eval_input, globals, locals, null);
int[] shape = new int[(int)PyTuple_Size(shapeTuple)];
for (int i = 0; i < shape.length; i++) {
@@ -125,7 +132,10 @@ public class GymEnv> implements MDP {
}
public GymEnv(String envId, boolean render, boolean monitor, int[] actions) {
- this(envId, render, monitor);
+ this(envId, render, monitor, null, actions);
+ }
+ public GymEnv(String envId, boolean render, boolean monitor, Integer seed, int[] actions) {
+ this(envId, render, monitor, seed);
actionTransformer = new ActionTransformer((HighLowDiscrete) getActionSpace(), actions);
}