From 0175ace4c3936d65a81c46f9aa984e4b4198c8ba Mon Sep 17 00:00:00 2001 From: Alex Black Date: Mon, 9 Dec 2019 23:08:00 +1100 Subject: [PATCH] Small tweaks (#119) Signed-off-by: AlexDBlack --- .../debugging/ArraySavingListener.java | 27 +++++++++++++------ .../transforms/pairwise/arithmetic/DivOp.java | 4 +-- 2 files changed, 21 insertions(+), 10 deletions(-) diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/autodiff/listeners/debugging/ArraySavingListener.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/autodiff/listeners/debugging/ArraySavingListener.java index 9137fc831..6b64c69d8 100644 --- a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/autodiff/listeners/debugging/ArraySavingListener.java +++ b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/autodiff/listeners/debugging/ArraySavingListener.java @@ -7,7 +7,9 @@ import org.nd4j.autodiff.listeners.Operation; import org.nd4j.autodiff.samediff.SameDiff; import org.nd4j.autodiff.samediff.internal.SameDiffOp; import org.nd4j.base.Preconditions; +import org.nd4j.linalg.api.buffer.DataType; import org.nd4j.linalg.api.ndarray.INDArray; +import org.nd4j.linalg.api.ops.impl.transforms.pairwise.bool.Xor; import org.nd4j.linalg.dataset.api.MultiDataSet; import org.nd4j.linalg.factory.Nd4j; @@ -81,14 +83,23 @@ public class ArraySavingListener extends BaseListener { if(eq){ System.out.println("Equals: " + varName.replaceAll("__", "/")); } else { - INDArray sub = arr1.sub(arr2); - INDArray diff = Nd4j.math.abs(sub); - double maxDiff = diff.maxNumber().doubleValue(); - System.out.println("FAILS: " + varName.replaceAll("__", "/") + " - max difference = " + maxDiff); - System.out.println("\t" + f.getAbsolutePath()); - System.out.println("\t" + f2.getAbsolutePath()); - sub.close(); - diff.close();; + if(arr1.dataType() == DataType.BOOL){ + INDArray xor = Nd4j.exec(new Xor(arr1, arr2)); + int count = xor.castTo(DataType.INT).sumNumber().intValue(); + System.out.println("FAILS: " + varName.replaceAll("__", "/") + " - boolean, # differences = " + count); + System.out.println("\t" + f.getAbsolutePath()); + System.out.println("\t" + f2.getAbsolutePath()); + xor.close(); + } else { + INDArray sub = arr1.sub(arr2); + INDArray diff = Nd4j.math.abs(sub); + double maxDiff = diff.maxNumber().doubleValue(); + System.out.println("FAILS: " + varName.replaceAll("__", "/") + " - max difference = " + maxDiff); + System.out.println("\t" + f.getAbsolutePath()); + System.out.println("\t" + f2.getAbsolutePath()); + sub.close(); + diff.close(); + } } arr1.close(); arr2.close(); diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/transforms/pairwise/arithmetic/DivOp.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/transforms/pairwise/arithmetic/DivOp.java index 5273a2941..b76942e95 100644 --- a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/transforms/pairwise/arithmetic/DivOp.java +++ b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/transforms/pairwise/arithmetic/DivOp.java @@ -57,8 +57,8 @@ public class DivOp extends BaseDynamicTransformOp { } @Override - public String[] tensorflowNames() { - return new String[]{"Div","RealDiv"}; + public String tensorflowName() { + return "Div"; }