From fb578fdecd18a183ae2914a2990fe7f674501f4e Mon Sep 17 00:00:00 2001 From: Alexandre Boulanger <44292157+aboulang2002@users.noreply.github.com> Date: Thu, 25 Jun 2020 22:23:47 -0400 Subject: [PATCH] RL4J: Use directly NeuralNet instances in DoubleDQN and StandardDQN (#499) Signed-off-by: Alexandre Boulanger --- .../agent/update/DQNNeuralNetUpdateRule.java | 14 ++-- .../sync/qlearning/TargetQNetworkSource.java | 28 -------- .../TDTargetAlgorithm/BaseDQNAlgorithm.java | 21 +++--- .../BaseTDTargetAlgorithm.java | 19 +++--- .../discrete/TDTargetAlgorithm/DoubleDQN.java | 10 +-- .../TDTargetAlgorithm/StandardDQN.java | 10 +-- .../IOutputNeuralNet.java} | 66 +++++++++++-------- .../deeplearning4j/rl4j/network/dqn/IDQN.java | 6 +- .../TDTargetAlgorithm/DoubleDQNTest.java | 40 +++++++---- .../TDTargetAlgorithm/StandardDQNTest.java | 45 ++++++++----- .../support/MockTargetQNetworkSource.java | 26 -------- 11 files changed, 126 insertions(+), 159 deletions(-) delete mode 100644 rl4j/rl4j-core/src/main/java/org/deeplearning4j/rl4j/learning/sync/qlearning/TargetQNetworkSource.java rename rl4j/rl4j-core/src/main/java/org/deeplearning4j/rl4j/{learning/sync/qlearning/QNetworkSource.java => network/IOutputNeuralNet.java} (51%) delete mode 100644 rl4j/rl4j-core/src/test/java/org/deeplearning4j/rl4j/learning/sync/support/MockTargetQNetworkSource.java diff --git a/rl4j/rl4j-core/src/main/java/org/deeplearning4j/rl4j/agent/update/DQNNeuralNetUpdateRule.java b/rl4j/rl4j-core/src/main/java/org/deeplearning4j/rl4j/agent/update/DQNNeuralNetUpdateRule.java index 46123d645..98873b827 100644 --- a/rl4j/rl4j-core/src/main/java/org/deeplearning4j/rl4j/agent/update/DQNNeuralNetUpdateRule.java +++ b/rl4j/rl4j-core/src/main/java/org/deeplearning4j/rl4j/agent/update/DQNNeuralNetUpdateRule.java @@ -17,7 +17,6 @@ package org.deeplearning4j.rl4j.agent.update; import lombok.Getter; import org.deeplearning4j.rl4j.learning.sync.Transition; -import 
org.deeplearning4j.rl4j.learning.sync.qlearning.TargetQNetworkSource; import org.deeplearning4j.rl4j.learning.sync.qlearning.discrete.TDTargetAlgorithm.DoubleDQN; import org.deeplearning4j.rl4j.learning.sync.qlearning.discrete.TDTargetAlgorithm.ITDTargetAlgorithm; import org.deeplearning4j.rl4j.learning.sync.qlearning.discrete.TDTargetAlgorithm.StandardDQN; @@ -28,13 +27,10 @@ import java.util.List; // Temporary class that will be replaced with a more generic class that delegates gradient computation // and network update to sub components. -public class DQNNeuralNetUpdateRule implements IUpdateRule>, TargetQNetworkSource { +public class DQNNeuralNetUpdateRule implements IUpdateRule> { - @Getter private final IDQN qNetwork; - - @Getter - private IDQN targetQNetwork; + private final IDQN targetQNetwork; private final int targetUpdateFrequency; private final ITDTargetAlgorithm tdTargetAlgorithm; @@ -47,8 +43,8 @@ public class DQNNeuralNetUpdateRule implements IUpdateRule>, this.targetQNetwork = qNetwork.clone(); this.targetUpdateFrequency = targetUpdateFrequency; tdTargetAlgorithm = isDoubleDQN - ? new DoubleDQN(this, gamma, errorClamp) - : new StandardDQN(this, gamma, errorClamp); + ? 
new DoubleDQN(qNetwork, targetQNetwork, gamma, errorClamp) + : new StandardDQN(qNetwork, targetQNetwork, gamma, errorClamp); } @Override @@ -56,7 +52,7 @@ public class DQNNeuralNetUpdateRule implements IUpdateRule>, DataSet targets = tdTargetAlgorithm.computeTDTargets(trainingBatch); qNetwork.fit(targets.getFeatures(), targets.getLabels()); if(++updateCount % targetUpdateFrequency == 0) { - targetQNetwork = qNetwork.clone(); + targetQNetwork.copy(qNetwork); } } } diff --git a/rl4j/rl4j-core/src/main/java/org/deeplearning4j/rl4j/learning/sync/qlearning/TargetQNetworkSource.java b/rl4j/rl4j-core/src/main/java/org/deeplearning4j/rl4j/learning/sync/qlearning/TargetQNetworkSource.java deleted file mode 100644 index 34fd9c06e..000000000 --- a/rl4j/rl4j-core/src/main/java/org/deeplearning4j/rl4j/learning/sync/qlearning/TargetQNetworkSource.java +++ /dev/null @@ -1,28 +0,0 @@ -/******************************************************************************* - * Copyright (c) 2015-2019 Skymind, Inc. - * - * This program and the accompanying materials are made available under the - * terms of the Apache License, Version 2.0 which is available at - * https://www.apache.org/licenses/LICENSE-2.0. - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT - * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the - * License for the specific language governing permissions and limitations - * under the License. 
- * - * SPDX-License-Identifier: Apache-2.0 - ******************************************************************************/ - -package org.deeplearning4j.rl4j.learning.sync.qlearning; - -import org.deeplearning4j.rl4j.network.dqn.IDQN; - -/** - * An interface that is an extension of {@link QNetworkSource} for all implementations capable of supplying a target Q-Network - * - * @author Alexandre Boulanger - */ -public interface TargetQNetworkSource extends QNetworkSource { - IDQN getTargetQNetwork(); -} diff --git a/rl4j/rl4j-core/src/main/java/org/deeplearning4j/rl4j/learning/sync/qlearning/discrete/TDTargetAlgorithm/BaseDQNAlgorithm.java b/rl4j/rl4j-core/src/main/java/org/deeplearning4j/rl4j/learning/sync/qlearning/discrete/TDTargetAlgorithm/BaseDQNAlgorithm.java index 3f27f954c..6cae384d5 100644 --- a/rl4j/rl4j-core/src/main/java/org/deeplearning4j/rl4j/learning/sync/qlearning/discrete/TDTargetAlgorithm/BaseDQNAlgorithm.java +++ b/rl4j/rl4j-core/src/main/java/org/deeplearning4j/rl4j/learning/sync/qlearning/discrete/TDTargetAlgorithm/BaseDQNAlgorithm.java @@ -16,8 +16,7 @@ package org.deeplearning4j.rl4j.learning.sync.qlearning.discrete.TDTargetAlgorithm; -import org.deeplearning4j.rl4j.learning.sync.qlearning.TargetQNetworkSource; -import org.deeplearning4j.rl4j.network.dqn.IDQN; +import org.deeplearning4j.rl4j.network.IOutputNeuralNet; import org.nd4j.linalg.api.ndarray.INDArray; /** @@ -28,7 +27,7 @@ import org.nd4j.linalg.api.ndarray.INDArray; */ public abstract class BaseDQNAlgorithm extends BaseTDTargetAlgorithm { - private final TargetQNetworkSource qTargetNetworkSource; + private final IOutputNeuralNet targetQNetwork; /** * In litterature, this corresponds to Q{net}(s(t+1), a) @@ -40,23 +39,21 @@ public abstract class BaseDQNAlgorithm extends BaseTDTargetAlgorithm { */ protected INDArray targetQNetworkNextObservation; - protected BaseDQNAlgorithm(TargetQNetworkSource qTargetNetworkSource, double gamma) { - super(qTargetNetworkSource, gamma); - 
this.qTargetNetworkSource = qTargetNetworkSource; + protected BaseDQNAlgorithm(IOutputNeuralNet qNetwork, IOutputNeuralNet targetQNetwork, double gamma) { + super(qNetwork, gamma); + this.targetQNetwork = targetQNetwork; } - protected BaseDQNAlgorithm(TargetQNetworkSource qTargetNetworkSource, double gamma, double errorClamp) { - super(qTargetNetworkSource, gamma, errorClamp); - this.qTargetNetworkSource = qTargetNetworkSource; + protected BaseDQNAlgorithm(IOutputNeuralNet qNetwork, IOutputNeuralNet targetQNetwork, double gamma, double errorClamp) { + super(qNetwork, gamma, errorClamp); + this.targetQNetwork = targetQNetwork; } @Override protected void initComputation(INDArray observations, INDArray nextObservations) { super.initComputation(observations, nextObservations); - qNetworkNextObservation = qNetworkSource.getQNetwork().output(nextObservations); - - IDQN targetQNetwork = qTargetNetworkSource.getTargetQNetwork(); + qNetworkNextObservation = qNetwork.output(nextObservations); targetQNetworkNextObservation = targetQNetwork.output(nextObservations); } } diff --git a/rl4j/rl4j-core/src/main/java/org/deeplearning4j/rl4j/learning/sync/qlearning/discrete/TDTargetAlgorithm/BaseTDTargetAlgorithm.java b/rl4j/rl4j-core/src/main/java/org/deeplearning4j/rl4j/learning/sync/qlearning/discrete/TDTargetAlgorithm/BaseTDTargetAlgorithm.java index ca4beb47e..e0ede18d7 100644 --- a/rl4j/rl4j-core/src/main/java/org/deeplearning4j/rl4j/learning/sync/qlearning/discrete/TDTargetAlgorithm/BaseTDTargetAlgorithm.java +++ b/rl4j/rl4j-core/src/main/java/org/deeplearning4j/rl4j/learning/sync/qlearning/discrete/TDTargetAlgorithm/BaseTDTargetAlgorithm.java @@ -17,7 +17,7 @@ package org.deeplearning4j.rl4j.learning.sync.qlearning.discrete.TDTargetAlgorithm; import org.deeplearning4j.rl4j.learning.sync.Transition; -import org.deeplearning4j.rl4j.learning.sync.qlearning.QNetworkSource; +import org.deeplearning4j.rl4j.network.IOutputNeuralNet; import org.nd4j.linalg.api.ndarray.INDArray; 
import org.nd4j.linalg.dataset.api.DataSet; @@ -30,7 +30,7 @@ import java.util.List; */ public abstract class BaseTDTargetAlgorithm implements ITDTargetAlgorithm { - protected final QNetworkSource qNetworkSource; + protected final IOutputNeuralNet qNetwork; protected final double gamma; private final double errorClamp; @@ -38,12 +38,12 @@ public abstract class BaseTDTargetAlgorithm implements ITDTargetAlgorithmerrorClamp away from the previous value. Double.NaN will disable the clamping. */ - protected BaseTDTargetAlgorithm(QNetworkSource qNetworkSource, double gamma, double errorClamp) { - this.qNetworkSource = qNetworkSource; + protected BaseTDTargetAlgorithm(IOutputNeuralNet qNetwork, double gamma, double errorClamp) { + this.qNetwork = qNetwork; this.gamma = gamma; this.errorClamp = errorClamp; @@ -52,12 +52,12 @@ public abstract class BaseTDTargetAlgorithm implements ITDTargetAlgorithm transition = transitions.get(i); double yTarget = computeTarget(i, transition.getReward(), transition.isTerminal()); diff --git a/rl4j/rl4j-core/src/main/java/org/deeplearning4j/rl4j/learning/sync/qlearning/discrete/TDTargetAlgorithm/DoubleDQN.java b/rl4j/rl4j-core/src/main/java/org/deeplearning4j/rl4j/learning/sync/qlearning/discrete/TDTargetAlgorithm/DoubleDQN.java index 3203af1b8..caeb85fb6 100644 --- a/rl4j/rl4j-core/src/main/java/org/deeplearning4j/rl4j/learning/sync/qlearning/discrete/TDTargetAlgorithm/DoubleDQN.java +++ b/rl4j/rl4j-core/src/main/java/org/deeplearning4j/rl4j/learning/sync/qlearning/discrete/TDTargetAlgorithm/DoubleDQN.java @@ -16,7 +16,7 @@ package org.deeplearning4j.rl4j.learning.sync.qlearning.discrete.TDTargetAlgorithm; -import org.deeplearning4j.rl4j.learning.sync.qlearning.TargetQNetworkSource; +import org.deeplearning4j.rl4j.network.IOutputNeuralNet; import org.nd4j.linalg.api.ndarray.INDArray; import org.nd4j.linalg.factory.Nd4j; @@ -32,12 +32,12 @@ public class DoubleDQN extends BaseDQNAlgorithm { // In litterature, this corresponds to: 
max_{a}Q(s_{t+1}, a) private INDArray maxActionsFromQNetworkNextObservation; - public DoubleDQN(TargetQNetworkSource qTargetNetworkSource, double gamma) { - super(qTargetNetworkSource, gamma); + public DoubleDQN(IOutputNeuralNet qNetwork, IOutputNeuralNet targetQNetwork, double gamma) { + super(qNetwork, targetQNetwork, gamma); } - public DoubleDQN(TargetQNetworkSource qTargetNetworkSource, double gamma, double errorClamp) { - super(qTargetNetworkSource, gamma, errorClamp); + public DoubleDQN(IOutputNeuralNet qNetwork, IOutputNeuralNet targetQNetwork, double gamma, double errorClamp) { + super(qNetwork, targetQNetwork, gamma, errorClamp); } @Override diff --git a/rl4j/rl4j-core/src/main/java/org/deeplearning4j/rl4j/learning/sync/qlearning/discrete/TDTargetAlgorithm/StandardDQN.java b/rl4j/rl4j-core/src/main/java/org/deeplearning4j/rl4j/learning/sync/qlearning/discrete/TDTargetAlgorithm/StandardDQN.java index 8c03c8de9..6cd047c74 100644 --- a/rl4j/rl4j-core/src/main/java/org/deeplearning4j/rl4j/learning/sync/qlearning/discrete/TDTargetAlgorithm/StandardDQN.java +++ b/rl4j/rl4j-core/src/main/java/org/deeplearning4j/rl4j/learning/sync/qlearning/discrete/TDTargetAlgorithm/StandardDQN.java @@ -16,7 +16,7 @@ package org.deeplearning4j.rl4j.learning.sync.qlearning.discrete.TDTargetAlgorithm; -import org.deeplearning4j.rl4j.learning.sync.qlearning.TargetQNetworkSource; +import org.deeplearning4j.rl4j.network.IOutputNeuralNet; import org.nd4j.linalg.api.ndarray.INDArray; import org.nd4j.linalg.factory.Nd4j; @@ -32,12 +32,12 @@ public class StandardDQN extends BaseDQNAlgorithm { // In litterature, this corresponds to: max_{a}Q_{tar}(s_{t+1}, a) private INDArray maxActionsFromQTargetNextObservation; - public StandardDQN(TargetQNetworkSource qTargetNetworkSource, double gamma) { - super(qTargetNetworkSource, gamma); + public StandardDQN(IOutputNeuralNet qNetwork, IOutputNeuralNet targetQNetwork, double gamma) { + super(qNetwork, targetQNetwork, gamma); } - public 
StandardDQN(TargetQNetworkSource qTargetNetworkSource, double gamma, double errorClamp) { - super(qTargetNetworkSource, gamma, errorClamp); + public StandardDQN(IOutputNeuralNet qNetwork, IOutputNeuralNet targetQNetwork, double gamma, double errorClamp) { + super(qNetwork, targetQNetwork, gamma, errorClamp); } @Override diff --git a/rl4j/rl4j-core/src/main/java/org/deeplearning4j/rl4j/learning/sync/qlearning/QNetworkSource.java b/rl4j/rl4j-core/src/main/java/org/deeplearning4j/rl4j/network/IOutputNeuralNet.java similarity index 51% rename from rl4j/rl4j-core/src/main/java/org/deeplearning4j/rl4j/learning/sync/qlearning/QNetworkSource.java rename to rl4j/rl4j-core/src/main/java/org/deeplearning4j/rl4j/network/IOutputNeuralNet.java index e22d368e4..58e219ea0 100644 --- a/rl4j/rl4j-core/src/main/java/org/deeplearning4j/rl4j/learning/sync/qlearning/QNetworkSource.java +++ b/rl4j/rl4j-core/src/main/java/org/deeplearning4j/rl4j/network/IOutputNeuralNet.java @@ -1,28 +1,38 @@ -/******************************************************************************* - * Copyright (c) 2015-2019 Skymind, Inc. - * - * This program and the accompanying materials are made available under the - * terms of the Apache License, Version 2.0 which is available at - * https://www.apache.org/licenses/LICENSE-2.0. - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT - * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the - * License for the specific language governing permissions and limitations - * under the License. 
- * - * SPDX-License-Identifier: Apache-2.0 - ******************************************************************************/ - -package org.deeplearning4j.rl4j.learning.sync.qlearning; - -import org.deeplearning4j.rl4j.network.dqn.IDQN; - -/** - * An interface for all implementations capable of supplying a Q-Network - * - * @author Alexandre Boulanger - */ -public interface QNetworkSource { - IDQN getQNetwork(); -} +/******************************************************************************* + * Copyright (c) 2020 Konduit K.K. + * + * This program and the accompanying materials are made available under the + * terms of the Apache License, Version 2.0 which is available at + * https://www.apache.org/licenses/LICENSE-2.0. + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations + * under the License. + * + * SPDX-License-Identifier: Apache-2.0 + ******************************************************************************/ +package org.deeplearning4j.rl4j.network; + +import org.deeplearning4j.rl4j.observation.Observation; +import org.nd4j.linalg.api.ndarray.INDArray; + +/** + * An interface defining the output aspect of a {@link NeuralNet}. + */ +public interface IOutputNeuralNet { + /** + * Compute the output for the supplied observation. + * @param observation An {@link Observation} + * @return The output of the network + */ + INDArray output(Observation observation); + + /** + * Compute the output for the supplied batch. 
+ * @param batch The input batch + * @return The output of the network + */ + INDArray output(INDArray batch); +} \ No newline at end of file diff --git a/rl4j/rl4j-core/src/main/java/org/deeplearning4j/rl4j/network/dqn/IDQN.java b/rl4j/rl4j-core/src/main/java/org/deeplearning4j/rl4j/network/dqn/IDQN.java index af295d202..daed646c5 100644 --- a/rl4j/rl4j-core/src/main/java/org/deeplearning4j/rl4j/network/dqn/IDQN.java +++ b/rl4j/rl4j-core/src/main/java/org/deeplearning4j/rl4j/network/dqn/IDQN.java @@ -17,6 +17,7 @@ package org.deeplearning4j.rl4j.network.dqn; import org.deeplearning4j.nn.gradient.Gradient; +import org.deeplearning4j.rl4j.network.IOutputNeuralNet; import org.deeplearning4j.rl4j.network.NeuralNet; import org.deeplearning4j.rl4j.observation.Observation; import org.nd4j.linalg.api.ndarray.INDArray; @@ -27,7 +28,7 @@ import org.nd4j.linalg.api.ndarray.INDArray; * This neural net quantify the value of each action given a state * */ -public interface IDQN extends NeuralNet { +public interface IDQN extends NeuralNet, IOutputNeuralNet { boolean isRecurrent(); @@ -37,9 +38,6 @@ public interface IDQN extends NeuralNet { void fit(INDArray input, INDArray[] labels); - INDArray output(INDArray batch); - INDArray output(Observation observation); - INDArray[] outputAll(INDArray batch); NN clone(); diff --git a/rl4j/rl4j-core/src/test/java/org/deeplearning4j/rl4j/learning/sync/qlearning/discrete/TDTargetAlgorithm/DoubleDQNTest.java b/rl4j/rl4j-core/src/test/java/org/deeplearning4j/rl4j/learning/sync/qlearning/discrete/TDTargetAlgorithm/DoubleDQNTest.java index 798bddf0d..0f03a5370 100644 --- a/rl4j/rl4j-core/src/test/java/org/deeplearning4j/rl4j/learning/sync/qlearning/discrete/TDTargetAlgorithm/DoubleDQNTest.java +++ b/rl4j/rl4j-core/src/test/java/org/deeplearning4j/rl4j/learning/sync/qlearning/discrete/TDTargetAlgorithm/DoubleDQNTest.java @@ -1,10 +1,13 @@ package org.deeplearning4j.rl4j.learning.sync.qlearning.discrete.TDTargetAlgorithm; import 
org.deeplearning4j.rl4j.learning.sync.Transition; -import org.deeplearning4j.rl4j.learning.sync.support.MockDQN; -import org.deeplearning4j.rl4j.learning.sync.support.MockTargetQNetworkSource; +import org.deeplearning4j.rl4j.network.IOutputNeuralNet; import org.deeplearning4j.rl4j.observation.Observation; +import org.junit.Before; import org.junit.Test; +import org.junit.runner.RunWith; +import org.mockito.Mock; +import org.mockito.junit.MockitoJUnitRunner; import org.nd4j.linalg.api.ndarray.INDArray; import org.nd4j.linalg.dataset.api.DataSet; import org.nd4j.linalg.factory.Nd4j; @@ -13,16 +16,29 @@ import java.util.ArrayList; import java.util.List; import static org.junit.Assert.assertEquals; +import static org.mockito.ArgumentMatchers.any; +import static org.mockito.Mockito.when; +@RunWith(MockitoJUnitRunner.class) public class DoubleDQNTest { + @Mock + IOutputNeuralNet qNetworkMock; + + @Mock + IOutputNeuralNet targetQNetworkMock; + + + @Before + public void setup() { + when(qNetworkMock.output(any(INDArray.class))).thenAnswer(i -> i.getArguments()[0]); + } + @Test public void when_isTerminal_expect_rewardValueAtIdx0() { // Assemble - MockDQN qNetwork = new MockDQN(); - MockDQN targetQNetwork = new MockDQN(); - MockTargetQNetworkSource targetQNetworkSource = new MockTargetQNetworkSource(qNetwork, targetQNetwork); + when(targetQNetworkMock.output(any(INDArray.class))).thenAnswer(i -> i.getArguments()[0]); List> transitions = new ArrayList>() { { @@ -31,7 +47,7 @@ public class DoubleDQNTest { } }; - DoubleDQN sut = new DoubleDQN(targetQNetworkSource, 0.5); + DoubleDQN sut = new DoubleDQN(qNetworkMock, targetQNetworkMock, 0.5); // Act DataSet result = sut.computeTDTargets(transitions); @@ -46,9 +62,7 @@ public class DoubleDQNTest { public void when_isNotTerminal_expect_rewardPlusEstimatedQValue() { // Assemble - MockDQN qNetwork = new MockDQN(); - MockDQN targetQNetwork = new MockDQN(-1.0); - MockTargetQNetworkSource targetQNetworkSource = new 
MockTargetQNetworkSource(qNetwork, targetQNetwork); + when(targetQNetworkMock.output(any(INDArray.class))).thenAnswer(i -> ((INDArray)i.getArguments()[0]).mul(-1.0)); List> transitions = new ArrayList>() { { @@ -57,7 +71,7 @@ public class DoubleDQNTest { } }; - DoubleDQN sut = new DoubleDQN(targetQNetworkSource, 0.5); + DoubleDQN sut = new DoubleDQN(qNetworkMock, targetQNetworkMock, 0.5); // Act DataSet result = sut.computeTDTargets(transitions); @@ -72,9 +86,7 @@ public class DoubleDQNTest { public void when_batchHasMoreThanOne_expect_everySampleEvaluated() { // Assemble - MockDQN qNetwork = new MockDQN(); - MockDQN targetQNetwork = new MockDQN(-1.0); - MockTargetQNetworkSource targetQNetworkSource = new MockTargetQNetworkSource(qNetwork, targetQNetwork); + when(targetQNetworkMock.output(any(INDArray.class))).thenAnswer(i -> ((INDArray)i.getArguments()[0]).mul(-1.0)); List> transitions = new ArrayList>() { { @@ -87,7 +99,7 @@ public class DoubleDQNTest { } }; - DoubleDQN sut = new DoubleDQN(targetQNetworkSource, 0.5); + DoubleDQN sut = new DoubleDQN(qNetworkMock, targetQNetworkMock, 0.5); // Act DataSet result = sut.computeTDTargets(transitions); diff --git a/rl4j/rl4j-core/src/test/java/org/deeplearning4j/rl4j/learning/sync/qlearning/discrete/TDTargetAlgorithm/StandardDQNTest.java b/rl4j/rl4j-core/src/test/java/org/deeplearning4j/rl4j/learning/sync/qlearning/discrete/TDTargetAlgorithm/StandardDQNTest.java index 3e3701669..6aead9e76 100644 --- a/rl4j/rl4j-core/src/test/java/org/deeplearning4j/rl4j/learning/sync/qlearning/discrete/TDTargetAlgorithm/StandardDQNTest.java +++ b/rl4j/rl4j-core/src/test/java/org/deeplearning4j/rl4j/learning/sync/qlearning/discrete/TDTargetAlgorithm/StandardDQNTest.java @@ -1,10 +1,13 @@ package org.deeplearning4j.rl4j.learning.sync.qlearning.discrete.TDTargetAlgorithm; import org.deeplearning4j.rl4j.learning.sync.Transition; -import org.deeplearning4j.rl4j.learning.sync.support.MockDQN; -import 
org.deeplearning4j.rl4j.learning.sync.support.MockTargetQNetworkSource; +import org.deeplearning4j.rl4j.network.IOutputNeuralNet; import org.deeplearning4j.rl4j.observation.Observation; +import org.junit.Before; import org.junit.Test; +import org.junit.runner.RunWith; +import org.mockito.Mock; +import org.mockito.junit.MockitoJUnitRunner; import org.nd4j.linalg.api.ndarray.INDArray; import org.nd4j.linalg.dataset.api.DataSet; import org.nd4j.linalg.factory.Nd4j; @@ -12,17 +15,31 @@ import org.nd4j.linalg.factory.Nd4j; import java.util.ArrayList; import java.util.List; -import static org.junit.Assert.*; +import static org.junit.Assert.assertEquals; +import static org.mockito.ArgumentMatchers.any; +import static org.mockito.Mockito.when; +@RunWith(MockitoJUnitRunner.class) public class StandardDQNTest { + + @Mock + IOutputNeuralNet qNetworkMock; + + @Mock + IOutputNeuralNet targetQNetworkMock; + + + @Before + public void setup() { + when(qNetworkMock.output(any(INDArray.class))).thenAnswer(i -> i.getArguments()[0]); + when(targetQNetworkMock.output(any(INDArray.class))).thenAnswer(i -> i.getArguments()[0]); + } + + @Test public void when_isTerminal_expect_rewardValueAtIdx0() { // Assemble - MockDQN qNetwork = new MockDQN(); - MockDQN targetQNetwork = new MockDQN(); - MockTargetQNetworkSource targetQNetworkSource = new MockTargetQNetworkSource(qNetwork, targetQNetwork); - List> transitions = new ArrayList>() { { add(buildTransition(buildObservation(new double[]{1.1, 2.2}), @@ -30,7 +47,7 @@ public class StandardDQNTest { } }; - StandardDQN sut = new StandardDQN(targetQNetworkSource, 0.5); + StandardDQN sut = new StandardDQN(qNetworkMock, targetQNetworkMock, 0.5); // Act DataSet result = sut.computeTDTargets(transitions); @@ -45,10 +62,6 @@ public class StandardDQNTest { public void when_isNotTerminal_expect_rewardPlusEstimatedQValue() { // Assemble - MockDQN qNetwork = new MockDQN(); - MockDQN targetQNetwork = new MockDQN(); - MockTargetQNetworkSource 
targetQNetworkSource = new MockTargetQNetworkSource(qNetwork, targetQNetwork); - List> transitions = new ArrayList>() { { add(buildTransition(buildObservation(new double[]{1.1, 2.2}), @@ -56,7 +69,7 @@ public class StandardDQNTest { } }; - StandardDQN sut = new StandardDQN(targetQNetworkSource, 0.5); + StandardDQN sut = new StandardDQN(qNetworkMock, targetQNetworkMock, 0.5); // Act DataSet result = sut.computeTDTargets(transitions); @@ -71,10 +84,6 @@ public class StandardDQNTest { public void when_batchHasMoreThanOne_expect_everySampleEvaluated() { // Assemble - MockDQN qNetwork = new MockDQN(); - MockDQN targetQNetwork = new MockDQN(); - MockTargetQNetworkSource targetQNetworkSource = new MockTargetQNetworkSource(qNetwork, targetQNetwork); - List> transitions = new ArrayList>() { { add(buildTransition(buildObservation(new double[]{1.1, 2.2}), @@ -86,7 +95,7 @@ public class StandardDQNTest { } }; - StandardDQN sut = new StandardDQN(targetQNetworkSource, 0.5); + StandardDQN sut = new StandardDQN(qNetworkMock, targetQNetworkMock, 0.5); // Act DataSet result = sut.computeTDTargets(transitions); diff --git a/rl4j/rl4j-core/src/test/java/org/deeplearning4j/rl4j/learning/sync/support/MockTargetQNetworkSource.java b/rl4j/rl4j-core/src/test/java/org/deeplearning4j/rl4j/learning/sync/support/MockTargetQNetworkSource.java deleted file mode 100644 index ce756aa88..000000000 --- a/rl4j/rl4j-core/src/test/java/org/deeplearning4j/rl4j/learning/sync/support/MockTargetQNetworkSource.java +++ /dev/null @@ -1,26 +0,0 @@ -package org.deeplearning4j.rl4j.learning.sync.support; - -import org.deeplearning4j.rl4j.learning.sync.qlearning.TargetQNetworkSource; -import org.deeplearning4j.rl4j.network.dqn.IDQN; - -public class MockTargetQNetworkSource implements TargetQNetworkSource { - - - private final IDQN qNetwork; - private final IDQN targetQNetwork; - - public MockTargetQNetworkSource(IDQN qNetwork, IDQN targetQNetwork) { - this.qNetwork = qNetwork; - this.targetQNetwork = 
targetQNetwork; - } - - @Override - public IDQN getTargetQNetwork() { - return targetQNetwork; - } - - @Override - public IDQN getQNetwork() { - return qNetwork; - } -}