From 3f3b676ce5f5207bb72e6cdd4df0b2f39a11393e Mon Sep 17 00:00:00 2001 From: Alex Black Date: Fri, 30 Aug 2019 23:00:53 +1000 Subject: [PATCH] DL4J Fixes (#204) * Fix issue with recently introduced exception handling system in MultiLayerNetwork/ComputationGraph Signed-off-by: AlexDBlack * Fix for SpaceToBatch layer Signed-off-by: AlexDBlack * #8133 DL4J SpaceToBatch gradient fix Signed-off-by: AlexDBlack --- .../exceptions/TestInvalidInput.java | 55 ++++++++++++++----- .../gradientcheck/CNNGradientCheckTest.java | 7 +++ .../nn/graph/ComputationGraph.java | 19 ++++++- .../nn/layers/convolution/SpaceToBatch.java | 4 +- .../nn/layers/recurrent/LSTMHelpers.java | 10 ++-- .../nn/layers/recurrent/SimpleRnn.java | 3 + .../nn/multilayer/MultiLayerNetwork.java | 19 ++++++- 7 files changed, 92 insertions(+), 25 deletions(-) diff --git a/deeplearning4j/deeplearning4j-core/src/test/java/org/deeplearning4j/exceptions/TestInvalidInput.java b/deeplearning4j/deeplearning4j-core/src/test/java/org/deeplearning4j/exceptions/TestInvalidInput.java index 30ef68183..096c7ac69 100644 --- a/deeplearning4j/deeplearning4j-core/src/test/java/org/deeplearning4j/exceptions/TestInvalidInput.java +++ b/deeplearning4j/deeplearning4j-core/src/test/java/org/deeplearning4j/exceptions/TestInvalidInput.java @@ -22,12 +22,16 @@ import org.deeplearning4j.nn.conf.MultiLayerConfiguration; import org.deeplearning4j.nn.conf.NeuralNetConfiguration; import org.deeplearning4j.nn.conf.inputs.InputType; import org.deeplearning4j.nn.conf.layers.*; +import org.deeplearning4j.nn.conf.layers.recurrent.SimpleRnn; import org.deeplearning4j.nn.multilayer.MultiLayerNetwork; import org.junit.Test; import org.nd4j.linalg.activations.Activation; +import org.nd4j.linalg.api.ndarray.INDArray; import org.nd4j.linalg.factory.Nd4j; -import static org.junit.Assert.fail; +import java.util.Map; + +import static org.junit.Assert.*; /** * A set of tests to ensure that useful exceptions are thrown on invalid input @@ -267,23 +271,44 @@ public class TestInvalidInput extends BaseDL4JTest { //Idea: Using rnnTimeStep with a different number of examples between calls //(i.e., not calling reset between time steps) - MultiLayerConfiguration conf = new NeuralNetConfiguration.Builder().list() - .layer(0, new GravesLSTM.Builder().nIn(5).nOut(5).build()) - .layer(1, new RnnOutputLayer.Builder().nIn(5).nOut(5).activation(Activation.SOFTMAX).build()).build(); + for(String layerType : new String[]{"simple", "lstm", "graves"}) { - MultiLayerNetwork net = new MultiLayerNetwork(conf); - net.init(); + Layer l; + switch (layerType){ + case "simple": + l = new SimpleRnn.Builder().nIn(5).nOut(5).build(); + break; + case "lstm": + l = new LSTM.Builder().nIn(5).nOut(5).build(); + break; + case "graves": + l = new GravesLSTM.Builder().nIn(5).nOut(5).build(); + break; + default: + throw new RuntimeException(); + } - net.rnnTimeStep(Nd4j.create(3, 5, 10)); + MultiLayerConfiguration conf = new NeuralNetConfiguration.Builder().list() + .layer(l) + .layer(new RnnOutputLayer.Builder().nIn(5).nOut(5).activation(Activation.SOFTMAX).build()).build(); - try { - net.rnnTimeStep(Nd4j.create(5, 5, 10)); - fail("Expected DL4JException"); - } catch (DL4JException e) { - System.out.println("testInvalidRnnTimeStep(): " + e.getMessage()); - } catch (Exception e) { - e.printStackTrace(); - fail("Expected DL4JException"); + MultiLayerNetwork net = new MultiLayerNetwork(conf); + net.init(); + + net.rnnTimeStep(Nd4j.create(3, 5, 10)); + + Map m = net.rnnGetPreviousState(0); + assertNotNull(m); + assertFalse(m.isEmpty()); + + try { + net.rnnTimeStep(Nd4j.create(5, 5, 10)); + fail("Expected Exception - " + layerType); + } catch (Exception e) { +// e.printStackTrace(); + String msg = e.getMessage(); + assertTrue(msg, msg != null && msg.contains("rnn") && msg.contains("batch")); + } } } } diff --git a/deeplearning4j/deeplearning4j-core/src/test/java/org/deeplearning4j/gradientcheck/CNNGradientCheckTest.java b/deeplearning4j/deeplearning4j-core/src/test/java/org/deeplearning4j/gradientcheck/CNNGradientCheckTest.java index c1a873dc8..decb81bb0 100644 --- a/deeplearning4j/deeplearning4j-core/src/test/java/org/deeplearning4j/gradientcheck/CNNGradientCheckTest.java +++ b/deeplearning4j/deeplearning4j-core/src/test/java/org/deeplearning4j/gradientcheck/CNNGradientCheckTest.java @@ -28,6 +28,7 @@ import org.deeplearning4j.nn.conf.distribution.NormalDistribution; import org.deeplearning4j.nn.conf.inputs.InputType; import org.deeplearning4j.nn.conf.layers.*; import org.deeplearning4j.nn.conf.layers.convolutional.Cropping2D; +import org.deeplearning4j.nn.graph.ComputationGraph; import org.deeplearning4j.nn.multilayer.MultiLayerNetwork; import org.deeplearning4j.nn.weights.WeightInit; import org.deeplearning4j.nn.workspace.LayerWorkspaceMgr; @@ -343,6 +344,12 @@ public class CNNGradientCheckTest extends BaseDL4JTest { assertTrue(msg, gradOK); + //Also check compgraph: + ComputationGraph cg = net.toComputationGraph(); + gradOK = GradientCheckUtil.checkGradients(cg, DEFAULT_EPS, DEFAULT_MAX_REL_ERROR, + DEFAULT_MIN_ABS_ERROR, PRINT_RESULTS, RETURN_ON_FIRST_FAILURE, new INDArray[]{input}, new INDArray[]{labels}); + assertTrue(msg + " - compgraph", gradOK); + TestUtils.testModelSerialization(net); } } diff --git a/deeplearning4j/deeplearning4j-nn/src/main/java/org/deeplearning4j/nn/graph/ComputationGraph.java b/deeplearning4j/deeplearning4j-nn/src/main/java/org/deeplearning4j/nn/graph/ComputationGraph.java index 00c0cf7d6..7c292bafa 100755 --- a/deeplearning4j/deeplearning4j-nn/src/main/java/org/deeplearning4j/nn/graph/ComputationGraph.java +++ b/deeplearning4j/deeplearning4j-nn/src/main/java/org/deeplearning4j/nn/graph/ComputationGraph.java @@ -2460,6 +2460,13 @@ public class ComputationGraph implements Serializable, Model, NeuralNetwork { } Nd4j.getMemoryManager().setCurrentWorkspace(initialWorkspace); + if(t != null){ + if(t instanceof RuntimeException){ + throw ((RuntimeException)t); + } + throw new RuntimeException("Error during neural network forward pass", t); + } + if(outputWorkspace == null || outputWorkspace instanceof DummyWorkspace) { WorkspaceUtils.assertNoWorkspacesOpen("Expected no workspace active at the end of outputOfLayerDetached"); } else { @@ -2780,6 +2787,13 @@ public class ComputationGraph implements Serializable, Model, NeuralNetwork { } } Nd4j.getMemoryManager().setCurrentWorkspace(initialWorkspace); + + if(t != null){ + if(t instanceof RuntimeException){ + throw ((RuntimeException)t); + } + throw new RuntimeException("Error during neural network backpropagation calculation", t); + } } //Now, add the gradients in the order we need them in for flattening (same as params order) @@ -3312,8 +3326,11 @@ public class ComputationGraph implements Serializable, Model, NeuralNetwork { @Override public int batchSize() { + //In 99+% of cases, the input and labels dimension 0 size should be identical + //The only real exceptions: space to batch, and batch to space layers + //In those cases, we should base it on the labels size, as this impacts gradient calculation // FIXME: int cast - return (int) inputs[0].size(0); + return labels == null || labels[0] == null ? (int) inputs[0].size(0) : (int)labels[0].size(0); } @Override diff --git a/deeplearning4j/deeplearning4j-nn/src/main/java/org/deeplearning4j/nn/layers/convolution/SpaceToBatch.java b/deeplearning4j/deeplearning4j-nn/src/main/java/org/deeplearning4j/nn/layers/convolution/SpaceToBatch.java index eef8fcee7..586464716 100644 --- a/deeplearning4j/deeplearning4j-nn/src/main/java/org/deeplearning4j/nn/layers/convolution/SpaceToBatch.java +++ b/deeplearning4j/deeplearning4j-nn/src/main/java/org/deeplearning4j/nn/layers/convolution/SpaceToBatch.java @@ -70,12 +70,12 @@ public class SpaceToBatch extends AbstractLayer