diff --git a/nd4s/src/test/scala/org/nd4s/samediff/MathTest.scala b/nd4s/src/test/scala/org/nd4s/samediff/MathTest.scala index 890f70853..dc41b31f6 100644 --- a/nd4s/src/test/scala/org/nd4s/samediff/MathTest.scala +++ b/nd4s/src/test/scala/org/nd4s/samediff/MathTest.scala @@ -232,6 +232,15 @@ class MathTest extends FlatSpec with Matchers { val arr = Nd4j.linspace(DataType.FLOAT, 1.0, 1.0, 18).reshape(3, 3, 2) val x = sd.bind(arr) - println(x(0,0,SDIndex.all()).eval) + x.get(SDIndex.all(), SDIndex.all(), SDIndex.all()).eval shouldBe x(---, ---, ---).eval + x.get(SDIndex.point(0), SDIndex.all(), SDIndex.all()).eval shouldBe x(0, ---, ---).eval + x.get(SDIndex.point(0), SDIndex.point(0), SDIndex.all()).eval shouldBe x(0, 0, ---).eval + x.get(SDIndex.point(0), SDIndex.point(0), SDIndex.point(0)).eval shouldBe x(0, 0, 0).eval + + x.get(SDIndex.interval(0, 2), SDIndex.point(0), SDIndex.point(0)).eval shouldBe x(0 :: 2, 0, 0).eval + x.get(SDIndex.interval(0, 2), SDIndex.interval(0, 1), SDIndex.interval(0, 2)).eval shouldBe x(0 :: 2, + 0 :: 1, + 0 :: 2).eval + x.get(SDIndex.interval(0, 2), SDIndex.interval(0, 1), SDIndex.all()).eval shouldBe x(0 :: 2, 0 :: 1, ---).eval } }