diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/autodiff/samediff/internal/AbstractSession.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/autodiff/samediff/internal/AbstractSession.java index 5e2e7a4de..650bdeab2 100644 --- a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/autodiff/samediff/internal/AbstractSession.java +++ b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/autodiff/samediff/internal/AbstractSession.java @@ -248,16 +248,17 @@ public abstract class AbstractSession { */ Map out = new HashMap<>(); //Outputs, returned to the user + Set allExecuted = new HashSet<>(); int step = 0; //Number of execution steps //Next 3: current execution frame String currentFrame = OUTER_FRAME; int currentFrameIter = 0; FrameIter currParentFrame = null; ExecStepPredicate predicate = new ExecStepPredicate(); - while (out.size() < userRequestedUnique.size()) { + while (allExecuted.size() < allRequired.size()) { if (!dt.hasNewAllSatisfied()) { //Haven't got all of the outputs the user requested, but there's nothing left that we can execute. Should not happen. - execFailed(userRequestedUnique, out, step); + execFailed(userRequestedUnique, out, allRequired, allExecuted, step); } //Get variable in the current frame/iteration and execute it's corresponding op @@ -289,10 +290,13 @@ public abstract class AbstractSession { Preconditions.checkNotNull(arr, "Encountered null placeholder array for constant: %s", vid); nodeOutputs.put(vid, arr); outFrameIter = new FrameIter(OUTER_FRAME, 0, null); - if (allRequired.contains(es.getName())) { + if (userRequestedUnique.contains(es.getName())) { //User requested const/variable as one of the outputs out.put(es.getName(), arr); } + if(allRequired.contains(es.getName())){ + allExecuted.add(es.getName()); + } } else if (es.getType() == ExecType.PLACEHOLDER) { VarId vid = new VarId(es.getName(), OUTER_FRAME, 0, null); T phVal = placeholderValues == null ? null : placeholderValues.get(es.getName()); @@ -305,6 +309,9 @@ public abstract class AbstractSession { //User requested placeholder value as one of the outputs out.put(es.getName(), placeholderValues.get(es.getName())); } + if(allRequired.contains(es.getName())){ + allExecuted.add(es.getName()); + } } else if (es.getType() == ExecType.OP) { String opName = es.getName(); SameDiffOp op = sameDiff.getOps().get(opName); @@ -399,9 +406,12 @@ public abstract class AbstractSession { VarId vid = new VarId(n, outFrameIter.getFrame(), outFrameIter.getIteration(), outFrameIter.getParentFrame()); nodeOutputs.put(vid, opOutputValues[i]); - if (allRequired.contains(n)) { + if (userRequestedUnique.contains(n)) { out.put(n, opOutputValues[i]); } + if(allRequired.contains(n)){ + allExecuted.add(n); + } } //Post execution: update dependency tracker so we know what is available to execute next, given we now @@ -508,17 +518,19 @@ public abstract class AbstractSession { * @param out Current outputs * @param step Execution step */ - protected void execFailed(Set userRequestedUnique, Map out, int step) { + protected void execFailed(Set userRequestedUnique, Map out, Set allRequired, Set allExecuted, int step) { int missingCount = userRequestedUnique.size() - out.size(); StringBuilder sb = new StringBuilder(); sb.append("No variable are available for execution at step ") - .append(step).append(": ").append(missingCount).append(" values remaining"); + .append(step).append(": ").append(missingCount).append(" requested output values remaining, ") + .append(allExecuted.size() - allRequired.size()).append(" variables required to be executed remaining"); Set missing = new HashSet<>(); for (String s : userRequestedUnique) { if (!out.containsKey(s)) { missing.add(s); } } + if (missingCount <= 10) { sb.append(". Missing variables: "); sb.append(missing); diff --git a/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/autodiff/samediff/listeners/ListenerTest.java b/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/autodiff/samediff/listeners/ListenerTest.java index 4f105aecc..099804f09 100644 --- a/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/autodiff/samediff/listeners/ListenerTest.java +++ b/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/autodiff/samediff/listeners/ListenerTest.java @@ -16,11 +16,6 @@ package org.nd4j.autodiff.samediff.listeners; -import static org.junit.Assert.assertEquals; -import static org.junit.Assert.assertTrue; - -import java.util.*; - import org.junit.Test; import org.nd4j.autodiff.listeners.*; import org.nd4j.autodiff.listeners.impl.ScoreListener; @@ -49,6 +44,13 @@ import org.nd4j.linalg.learning.config.Adam; import org.nd4j.linalg.learning.config.IUpdater; import org.nd4j.weightinit.impl.XavierInitScheme; +import java.util.Arrays; +import java.util.Collections; +import java.util.HashMap; +import java.util.Map; + +import static org.junit.Assert.*; + public class ListenerTest extends BaseNd4jTest { public ListenerTest(Nd4jBackend backend) { @@ -260,6 +262,42 @@ public class ListenerTest extends BaseNd4jTest { } } + @Test + public void testCustomListener() { + SameDiff sd = SameDiff.create(); + SDVariable in = sd.placeHolder("input", DataType.FLOAT, -1, 4); + SDVariable label = sd.placeHolder("label", DataType.FLOAT, -1, 3); + SDVariable w = sd.var("w", Nd4j.rand(DataType.FLOAT, 4, 3)); + SDVariable b = sd.var("b", Nd4j.rand(DataType.FLOAT, 3)); + SDVariable z = sd.nn().linear("z", in, w, b); + SDVariable out = sd.nn().softmax("out", z, 1); + SDVariable loss = sd.loss().softmaxCrossEntropy("loss", label, out, null); + + //Create and set the training configuration + double learningRate = 1e-3; + TrainingConfig config = new TrainingConfig.Builder() + .l2(1e-4) //L2 regularization + .updater(new Adam(learningRate)) //Adam optimizer with specified learning rate + .dataSetFeatureMapping("input") //DataSet features array should be associated with variable "input" + .dataSetLabelMapping("label") //DataSet label array should be associated with variable "label + .addEvaluations(false,"out",0,new Evaluation()) + .build(); + sd.setTrainingConfig(config); + + CustomListener listener = new CustomListener(); + Map m = sd.output() + .data(new IrisDataSetIterator(150, 150)) + .output("out") + .listeners(listener) + .exec(); + + assertEquals(1, m.size()); + assertTrue(m.containsKey("out")); + assertNotNull(listener.z); + assertNotNull(listener.out); + + } + private static class TestListener implements Listener { public TestListener(Operation operation){ @@ -356,4 +394,38 @@ public class ListenerTest extends BaseNd4jTest { preUpdateCount++; } } + + private static class CustomListener extends BaseListener { + + public INDArray z; + public INDArray out; + + // Specify that this listener is active during inference operations + @Override + public boolean isActive(Operation operation) { + return operation == Operation.INFERENCE; + } + + // Specify that this listener requires the activations of "z" and "out" + @Override + public ListenerVariables requiredVariables(SameDiff sd) { + return new ListenerVariables.Builder().inferenceVariables("z", "out").build(); + } + + // Called when the activation of a variable becomes available + @Override + public void activationAvailable(SameDiff sd, At at, + MultiDataSet batch, SameDiffOp op, + String varName, INDArray activation) { + System.out.println("activation:" + varName); + + // if the variable is z or out, store its activation + if (varName.equals("z")) { + z = activation.detach().dup(); + } else if (varName.equals("out")) { + out = activation.detach().dup(); + } + } + + } }