diff --git a/deeplearning4j/deeplearning4j-nearestneighbors-parent/nearestneighbor-core/src/main/java/org/deeplearning4j/clustering/kdtree/HyperRect.java b/deeplearning4j/deeplearning4j-nearestneighbors-parent/nearestneighbor-core/src/main/java/org/deeplearning4j/clustering/kdtree/HyperRect.java index 74574ffb5..b9523f30b 100644 --- a/deeplearning4j/deeplearning4j-nearestneighbors-parent/nearestneighbor-core/src/main/java/org/deeplearning4j/clustering/kdtree/HyperRect.java +++ b/deeplearning4j/deeplearning4j-nearestneighbors-parent/nearestneighbor-core/src/main/java/org/deeplearning4j/clustering/kdtree/HyperRect.java @@ -18,6 +18,9 @@ package org.deeplearning4j.clustering.kdtree; import lombok.val; import org.nd4j.linalg.api.ndarray.INDArray; +import org.nd4j.linalg.api.ops.custom.KnnMinDistance; +import org.nd4j.linalg.factory.Nd4j; +import org.nd4j.linalg.primitives.Pair; import java.io.Serializable; import java.util.ArrayList; @@ -28,79 +31,103 @@ import java.util.List; */ public class HyperRect implements Serializable { - private List points; + //private List points; + private float[] lowerEnds; + private float[] higherEnds; + private INDArray lowerEndsIND; + private INDArray higherEndsIND; - public HyperRect(List points) { - //this.points = points; - this.points = new ArrayList<>(points.size()); - for (int i = 0; i < points.size(); ++i) { - Interval newInterval = new Interval(points.get(i).lower, points.get(i).higher); - this.points.add(newInterval); - } + public HyperRect(float[] lowerEndsIn, float[] higherEndsIn) { + this.lowerEnds = new float[lowerEndsIn.length]; + this.higherEnds = new float[lowerEndsIn.length]; + System.arraycopy(lowerEndsIn, 0 , this.lowerEnds, 0, lowerEndsIn.length); + System.arraycopy(higherEndsIn, 0 , this.higherEnds, 0, higherEndsIn.length); + lowerEndsIND = Nd4j.createFromArray(lowerEnds); + higherEndsIND = Nd4j.createFromArray(higherEnds); + } + + public HyperRect(float[] point) { + this(point, point); + } + + public HyperRect(Pair ends) { + this(ends.getFirst(), ends.getSecond()); } public void enlargeTo(INDArray point) { - for (int i = 0; i < points.size(); i++) - points.get(i).enlarge(point.getDouble(i)); + float[] pointAsArray = point.toFloatVector(); + for (int i = 0; i < lowerEnds.length; i++) { + float p = pointAsArray[i]; + if (lowerEnds[i] > p) + lowerEnds[i] = p; + else if (higherEnds[i] < p) + higherEnds[i] = p; + } } - - public static List point(INDArray vector) { - List ret = new ArrayList<>(); + public static Pair point(INDArray vector) { + Pair ret = new Pair<>(); + float[] curr = new float[(int)vector.length()]; for (int i = 0; i < vector.length(); i++) { - double curr = vector.getDouble(i); - ret.add(new Interval(curr, curr)); + curr[i] = vector.getFloat(i); } + ret.setFirst(curr); + ret.setSecond(curr); return ret; } - public List contains(INDArray hPoint) { + /*public List contains(INDArray hPoint) { List ret = new ArrayList<>(); - for (int i = 0; i < hPoint.length(); i++) - ret.add(points.get(i).contains(hPoint.getDouble(i))); - return ret; - } - - public double minDistance(INDArray hPoint) { - double ret = 0.0; for (int i = 0; i < hPoint.length(); i++) { - double p = hPoint.getDouble(i); - Interval interval = points.get(i); - if (!interval.contains(p)) { - if (p < interval.lower) - ret += Math.pow((p - interval.lower), 2); - else - ret += Math.pow((p - interval.higher), 2); - } + ret.add(lowerEnds[i] <= hPoint.getDouble(i) && + higherEnds[i] >= hPoint.getDouble(i)); } - - ret = Math.pow(ret, 0.5); return ret; + }*/ + + public double minDistance(INDArray hPoint, INDArray output) { + Nd4j.exec(new KnnMinDistance(hPoint, lowerEndsIND, higherEndsIND, output)); + return output.getFloat(0); + + /*double ret = 0.0; + double[] pointAsArray = hPoint.toDoubleVector(); + for (int i = 0; i < pointAsArray.length; i++) { + double p = pointAsArray[i]; + if (!(lowerEnds[i] <= p || higherEnds[i] <= p)) { + if (p < lowerEnds[i]) + ret += Math.pow((p - lowerEnds[i]), 2); + else + ret += Math.pow((p - higherEnds[i]), 2); + } + } + ret = Math.pow(ret, 0.5); + return ret;*/ } public HyperRect getUpper(INDArray hPoint, int desc) { - Interval interval = points.get(desc); - double d = hPoint.getDouble(desc); - if (interval.higher < d) + //Interval interval = points.get(desc); + float higher = higherEnds[desc]; + float d = hPoint.getFloat(desc); + if (higher < d) return null; - HyperRect ret = new HyperRect(new ArrayList<>(points)); - Interval i2 = ret.points.get(desc); - if (i2.lower < d) - i2.lower = d; + HyperRect ret = new HyperRect(lowerEnds,higherEnds); + if (ret.lowerEnds[desc] < d) + ret.lowerEnds[desc] = d; return ret; } public HyperRect getLower(INDArray hPoint, int desc) { - Interval interval = points.get(desc); - double d = hPoint.getDouble(desc); - if (interval.lower > d) + //Interval interval = points.get(desc); + float lower = lowerEnds[desc]; + float d = hPoint.getFloat(desc); + if (lower > d) return null; - HyperRect ret = new HyperRect(new ArrayList<>(points)); - Interval i2 = ret.points.get(desc); - if (i2.higher > d) - i2.higher = d; + HyperRect ret = new HyperRect(lowerEnds,higherEnds); + //Interval i2 = ret.points.get(desc); + if (ret.higherEnds[desc] > d) + ret.higherEnds[desc] = d; return ret; } @@ -108,33 +135,10 @@ public class HyperRect implements Serializable { public String toString() { String retVal = ""; retVal += "["; - for (val point : points) { - retVal += "(" + point.lower + " - " + point.higher + ") "; + for (int i = 0; i < lowerEnds.length; ++i) { + retVal += "(" + lowerEnds[i] + " - " + higherEnds[i] + ") "; } retVal += "]"; return retVal; } - - public static class Interval { - private double lower, higher; - - public Interval(double lower, double higher) { - this.lower = lower; - this.higher = higher; - } - - public boolean contains(double point) { - return lower <= point || point <= higher; - - } - - public void enlarge(double p) { - if (lower > p) - lower = p; - else if (higher < p) - higher = p; - } - - } - } diff --git a/deeplearning4j/deeplearning4j-nearestneighbors-parent/nearestneighbor-core/src/main/java/org/deeplearning4j/clustering/kdtree/KDTree.java b/deeplearning4j/deeplearning4j-nearestneighbors-parent/nearestneighbor-core/src/main/java/org/deeplearning4j/clustering/kdtree/KDTree.java index c5e2452f3..3e0b90119 100644 --- a/deeplearning4j/deeplearning4j-nearestneighbors-parent/nearestneighbor-core/src/main/java/org/deeplearning4j/clustering/kdtree/KDTree.java +++ b/deeplearning4j/deeplearning4j-nearestneighbors-parent/nearestneighbor-core/src/main/java/org/deeplearning4j/clustering/kdtree/KDTree.java @@ -56,7 +56,7 @@ public class KDTree implements Serializable { if (root == null) { root = new KDNode(point); - rect = new HyperRect(HyperRect.point(point)); + rect = new HyperRect(/*HyperRect.point(point)*/ point.toFloatVector()); } else { int disc = 0; KDNode node = root; @@ -125,15 +125,21 @@ public class KDTree implements Serializable { return node.getPoint(); } + // Share this data for recursive calls of "knn" + private float currentDistance; + private INDArray currentPoint; + private INDArray minDistance = Nd4j.scalar(0.f); - public List> knn(INDArray point, double distance) { - List> best = new ArrayList<>(); - knn(root, point, rect, distance, best, 0); - Collections.sort(best, new Comparator>() { + public List> knn(INDArray point, float distance) { + List> best = new ArrayList<>(); + currentDistance = distance; + currentPoint = point; + knn(root, rect, best, 0); + Collections.sort(best, new Comparator>() { @Override - public int compare(Pair o1, Pair o2) { - return Double.compare(o1.getKey(), o2.getKey()); + public int compare(Pair o1, Pair o2) { + return Float.compare(o1.getKey(), o2.getKey()); } }); @@ -141,22 +147,21 @@ public class KDTree implements Serializable { } - private void knn(KDNode node, INDArray point, HyperRect rect, double dist, List> best, - int _disc) { - if (node == null || rect == null || rect.minDistance(point) > dist) + private void knn(KDNode node, HyperRect rect, List> best, int _disc) { + if (node == null || rect == null || rect.minDistance(currentPoint, minDistance) > currentDistance) return; int _discNext = (_disc + 1) % dims; - double distance = Nd4j.getExecutioner().execAndReturn(new EuclideanDistance(point,node.point)).getFinalResult() - .doubleValue(); + float distance = Nd4j.getExecutioner().execAndReturn(new EuclideanDistance(currentPoint,node.point, minDistance)).getFinalResult() + .floatValue(); - if (distance <= dist) { + if (distance <= currentDistance) { best.add(Pair.of(distance, node.getPoint())); } HyperRect lower = rect.getLower(node.point, _disc); HyperRect upper = rect.getUpper(node.point, _disc); - knn(node.getLeft(), point, lower, dist, best, _discNext); - knn(node.getRight(), point, upper, dist, best, _discNext); + knn(node.getLeft(), lower, best, _discNext); + knn(node.getRight(), upper, best, _discNext); } /** @@ -171,7 +176,7 @@ public class KDTree implements Serializable { private Pair nn(KDNode node, INDArray point, HyperRect rect, double dist, INDArray best, int _disc) { - if (node == null || rect.minDistance(point) > dist) + if (node == null || rect.minDistance(point, minDistance) > dist) return Pair.of(Double.POSITIVE_INFINITY, null); int _discNext = (_disc + 1) % dims; diff --git a/deeplearning4j/deeplearning4j-nearestneighbors-parent/nearestneighbor-core/src/test/java/org/deeplearning4j/clustering/kdtree/KDTreeTest.java b/deeplearning4j/deeplearning4j-nearestneighbors-parent/nearestneighbor-core/src/test/java/org/deeplearning4j/clustering/kdtree/KDTreeTest.java index 1de7a379b..618ee0c94 100644 --- a/deeplearning4j/deeplearning4j-nearestneighbors-parent/nearestneighbor-core/src/test/java/org/deeplearning4j/clustering/kdtree/KDTreeTest.java +++ b/deeplearning4j/deeplearning4j-nearestneighbors-parent/nearestneighbor-core/src/test/java/org/deeplearning4j/clustering/kdtree/KDTreeTest.java @@ -16,6 +16,8 @@ package org.deeplearning4j.clustering.kdtree; +import org.joda.time.Instant; +import org.nd4j.shade.guava.base.Stopwatch; import org.nd4j.shade.guava.primitives.Doubles; import lombok.val; import org.deeplearning4j.clustering.BaseDL4JTest; @@ -28,6 +30,7 @@ import org.nd4j.linalg.api.buffer.DataType; import org.nd4j.linalg.api.ndarray.INDArray; import org.nd4j.linalg.factory.Nd4j; import org.nd4j.linalg.primitives.Pair; +import org.nd4j.shade.guava.primitives.Floats; import org.opencv.ml.KNearest; import java.util.ArrayList; @@ -35,6 +38,8 @@ import java.util.Arrays; import java.util.List; import java.util.Random; +import static java.util.concurrent.TimeUnit.MILLISECONDS; +import static java.util.concurrent.TimeUnit.SECONDS; import static org.junit.Assert.assertEquals; import static org.junit.Assert.assertTrue; @@ -53,17 +58,17 @@ public class KDTreeTest extends BaseDL4JTest { @Before public void setUp() { kdTree = new KDTree(2); - double[] data = new double[]{7,2}; + float[] data = new float[]{7,2}; kdTree.insert(Nd4j.createFromArray(data)); - data = new double[]{5,4}; + data = new float[]{5,4}; kdTree.insert(Nd4j.createFromArray(data)); - data = new double[]{2,3}; + data = new float[]{2,3}; kdTree.insert(Nd4j.createFromArray(data)); - data = new double[]{4,7}; + data = new float[]{4,7}; kdTree.insert(Nd4j.createFromArray(data)); - data = new double[]{9,6}; + data = new float[]{9,6}; kdTree.insert(Nd4j.createFromArray(data)); - data = new double[]{8,1}; + data = new float[]{8,1}; kdTree.insert(Nd4j.createFromArray(data)); } @@ -168,26 +173,30 @@ public class KDTreeTest extends BaseDL4JTest { @Test public void testKNN() { - int n = 10; - // make a KD-tree of dimension {#n} - KDTree kdTree = new KDTree(n); - for (int i = -1; i < n; i++) { + int dimensions = 512; + int vectorsNo = 50000; + // make a KD-tree of dimension {#dimensions} + Stopwatch stopwatch = Stopwatch.createStarted(); + KDTree kdTree = new KDTree(dimensions); + for (int i = -1; i < vectorsNo; i++) { // Insert a unit vector along each dimension - List vec = new ArrayList<>(n); - // i = -1 ensures the origin is in the Tree - for (int k = 0; k < n; k++) { - vec.add((k == i) ? 1.0 : 0.0); - } - INDArray indVec = Nd4j.create(Nd4j.createBuffer(Doubles.toArray(vec))); + INDArray indVec = Nd4j.rand(DataType.FLOAT, 1,dimensions); kdTree.insert(indVec); } + stopwatch.stop(); + System.out.println("Time elapsed for " + kdTree.size() + " nodes construction is "+ stopwatch.elapsed(SECONDS)); + Random rand = new Random(); // random point in the Hypercube - List pt = new ArrayList(n); - for (int k = 0; k < n; k++) { - pt.add(rand.nextDouble() * 10.0); + List pt = new ArrayList(dimensions); + for (int k = 0; k < dimensions; k++) { + pt.add(rand.nextFloat() * 10.0); } - List> list = kdTree.knn(Nd4j.create(Nd4j.createBuffer(Doubles.toArray(pt))), 20.0); + stopwatch.reset(); + stopwatch.start(); + List> list = kdTree.knn(Nd4j.create(Nd4j.createBuffer(Floats.toArray(pt))), 20.0f); + stopwatch.stop(); + System.out.println("Time elapsed for Search is "+ stopwatch.elapsed(MILLISECONDS)); } @Test @@ -195,15 +204,15 @@ public class KDTreeTest extends BaseDL4JTest { int n = 2; KDTree kdTree = new KDTree(n); - double[] data = new double[]{3,3}; + float[] data = new float[]{3,3}; kdTree.insert(Nd4j.createFromArray(data)); - data = new double[]{1,1}; + data = new float[]{1,1}; kdTree.insert(Nd4j.createFromArray(data)); - data = new double[]{2,2}; + data = new float[]{2,2}; kdTree.insert(Nd4j.createFromArray(data)); - data = new double[]{0,0}; - List> result = kdTree.knn(Nd4j.createFromArray(data), 4.5); + data = new float[]{0,0}; + List> result = kdTree.knn(Nd4j.createFromArray(data), 4.5f); assertEquals(1.0, result.get(0).getSecond().getDouble(0), 1e-5); assertEquals(1.0, result.get(0).getSecond().getDouble(1), 1e-5); @@ -220,88 +229,88 @@ public class KDTreeTest extends BaseDL4JTest { assertEquals(6, kdTree.size()); - double[] data = new double[]{8,1}; - List> result = kdTree.knn(Nd4j.createFromArray(data), 10.0); - assertEquals(8.0, result.get(0).getSecond().getDouble(0), 1e-5); - assertEquals(1.0, result.get(0).getSecond().getDouble(1), 1e-5); - assertEquals(7.0, result.get(1).getSecond().getDouble(0), 1e-5); - assertEquals(2.0, result.get(1).getSecond().getDouble(1), 1e-5); - assertEquals(5.0, result.get(2).getSecond().getDouble(0), 1e-5); - assertEquals(4.0, result.get(2).getSecond().getDouble(1), 1e-5); - assertEquals(9.0, result.get(3).getSecond().getDouble(0), 1e-5); - assertEquals(6.0, result.get(3).getSecond().getDouble(1), 1e-5); - assertEquals(2.0, result.get(4).getSecond().getDouble(0), 1e-5); - assertEquals(3.0, result.get(4).getSecond().getDouble(1), 1e-5); - assertEquals(4.0, result.get(5).getSecond().getDouble(0), 1e-5); - assertEquals(7.0, result.get(5).getSecond().getDouble(1), 1e-5); + float[] data = new float[]{8,1}; + List> result = kdTree.knn(Nd4j.createFromArray(data), 10.0f); + assertEquals(8.0, result.get(0).getSecond().getFloat(0), 1e-5); + assertEquals(1.0, result.get(0).getSecond().getFloat(1), 1e-5); + assertEquals(7.0, result.get(1).getSecond().getFloat(0), 1e-5); + assertEquals(2.0, result.get(1).getSecond().getFloat(1), 1e-5); + assertEquals(5.0, result.get(2).getSecond().getFloat(0), 1e-5); + assertEquals(4.0, result.get(2).getSecond().getFloat(1), 1e-5); + assertEquals(9.0, result.get(3).getSecond().getFloat(0), 1e-5); + assertEquals(6.0, result.get(3).getSecond().getFloat(1), 1e-5); + assertEquals(2.0, result.get(4).getSecond().getFloat(0), 1e-5); + assertEquals(3.0, result.get(4).getSecond().getFloat(1), 1e-5); + assertEquals(4.0, result.get(5).getSecond().getFloat(0), 1e-5); + assertEquals(7.0, result.get(5).getSecond().getFloat(1), 1e-5); } @Test public void testKNN_2() { - double[] data = new double[]{8, 1}; - List> result = kdTree.knn(Nd4j.createFromArray(data), 5.0); - assertEquals(8.0, result.get(0).getSecond().getDouble(0), 1e-5); - assertEquals(1.0, result.get(0).getSecond().getDouble(1), 1e-5); - assertEquals(7.0, result.get(1).getSecond().getDouble(0), 1e-5); - assertEquals(2.0, result.get(1).getSecond().getDouble(1), 1e-5); - assertEquals(5.0, result.get(2).getSecond().getDouble(0), 1e-5); - assertEquals(4.0, result.get(2).getSecond().getDouble(1), 1e-5); + float[] data = new float[]{8, 1}; + List> result = kdTree.knn(Nd4j.createFromArray(data), 5.0f); + assertEquals(8.0, result.get(0).getSecond().getFloat(0), 1e-5); + assertEquals(1.0, result.get(0).getSecond().getFloat(1), 1e-5); + assertEquals(7.0, result.get(1).getSecond().getFloat(0), 1e-5); + assertEquals(2.0, result.get(1).getSecond().getFloat(1), 1e-5); + assertEquals(5.0, result.get(2).getSecond().getFloat(0), 1e-5); + assertEquals(4.0, result.get(2).getSecond().getFloat(1), 1e-5); } @Test public void testKNN_3() { - double[] data = new double[]{2, 3}; - val result = kdTree.knn(Nd4j.createFromArray(data), 10.0); - assertEquals(2.0, result.get(0).getSecond().getDouble(0), 1e-5); - assertEquals(3.0, result.get(0).getSecond().getDouble(1), 1e-5); - assertEquals(5.0, result.get(1).getSecond().getDouble(0), 1e-5); - assertEquals(4.0, result.get(1).getSecond().getDouble(1), 1e-5); - assertEquals(4.0, result.get(2).getSecond().getDouble(0), 1e-5); - assertEquals(7.0, result.get(2).getSecond().getDouble(1), 1e-5); - assertEquals(7.0, result.get(3).getSecond().getDouble(0), 1e-5); - assertEquals(2.0, result.get(3).getSecond().getDouble(1), 1e-5); - assertEquals(8.0, result.get(4).getSecond().getDouble(0), 1e-5); - assertEquals(1.0, result.get(4).getSecond().getDouble(1), 1e-5); - assertEquals(9.0, result.get(5).getSecond().getDouble(0), 1e-5); - assertEquals(6.0, result.get(5).getSecond().getDouble(1), 1e-5); + float[] data = new float[]{2, 3}; + List> result = kdTree.knn(Nd4j.createFromArray(data), 10.0f); + assertEquals(2.0, result.get(0).getSecond().getFloat(0), 1e-5); + assertEquals(3.0, result.get(0).getSecond().getFloat(1), 1e-5); + assertEquals(5.0, result.get(1).getSecond().getFloat(0), 1e-5); + assertEquals(4.0, result.get(1).getSecond().getFloat(1), 1e-5); + assertEquals(4.0, result.get(2).getSecond().getFloat(0), 1e-5); + assertEquals(7.0, result.get(2).getSecond().getFloat(1), 1e-5); + assertEquals(7.0, result.get(3).getSecond().getFloat(0), 1e-5); + assertEquals(2.0, result.get(3).getSecond().getFloat(1), 1e-5); + assertEquals(8.0, result.get(4).getSecond().getFloat(0), 1e-5); + assertEquals(1.0, result.get(4).getSecond().getFloat(1), 1e-5); + assertEquals(9.0, result.get(5).getSecond().getFloat(0), 1e-5); + assertEquals(6.0, result.get(5).getSecond().getFloat(1), 1e-5); } @Test public void testKNN_4() { - double[] data = new double[]{2, 3}; - val result = kdTree.knn(Nd4j.createFromArray(data), 5.0); - assertEquals(2.0, result.get(0).getSecond().getDouble(0), 1e-5); - assertEquals(3.0, result.get(0).getSecond().getDouble(1), 1e-5); - assertEquals(5.0, result.get(1).getSecond().getDouble(0), 1e-5); - assertEquals(4.0, result.get(1).getSecond().getDouble(1), 1e-5); - assertEquals(4.0, result.get(2).getSecond().getDouble(0), 1e-5); - assertEquals(7.0, result.get(2).getSecond().getDouble(1), 1e-5); + float[] data = new float[]{2, 3}; + List> result = kdTree.knn(Nd4j.createFromArray(data), 5.0f); + assertEquals(2.0, result.get(0).getSecond().getFloat(0), 1e-5); + assertEquals(3.0, result.get(0).getSecond().getFloat(1), 1e-5); + assertEquals(5.0, result.get(1).getSecond().getFloat(0), 1e-5); + assertEquals(4.0, result.get(1).getSecond().getFloat(1), 1e-5); + assertEquals(4.0, result.get(2).getSecond().getFloat(0), 1e-5); + assertEquals(7.0, result.get(2).getSecond().getFloat(1), 1e-5); } @Test public void testKNN_5() { - double[] data = new double[]{2, 3}; - val result = kdTree.knn(Nd4j.createFromArray(data), 20.0); - assertEquals(2.0, result.get(0).getSecond().getDouble(0), 1e-5); - assertEquals(3.0, result.get(0).getSecond().getDouble(1), 1e-5); - assertEquals(5.0, result.get(1).getSecond().getDouble(0), 1e-5); - assertEquals(4.0, result.get(1).getSecond().getDouble(1), 1e-5); - assertEquals(4.0, result.get(2).getSecond().getDouble(0), 1e-5); - assertEquals(7.0, result.get(2).getSecond().getDouble(1), 1e-5); - assertEquals(7.0, result.get(3).getSecond().getDouble(0), 1e-5); - assertEquals(2.0, result.get(3).getSecond().getDouble(1), 1e-5); - assertEquals(8.0, result.get(4).getSecond().getDouble(0), 1e-5); - assertEquals(1.0, result.get(4).getSecond().getDouble(1), 1e-5); - assertEquals(9.0, result.get(5).getSecond().getDouble(0), 1e-5); - assertEquals(6.0, result.get(5).getSecond().getDouble(1), 1e-5); + float[] data = new float[]{2, 3}; + List> result = kdTree.knn(Nd4j.createFromArray(data), 20.0f); + assertEquals(2.0, result.get(0).getSecond().getFloat(0), 1e-5); + assertEquals(3.0, result.get(0).getSecond().getFloat(1), 1e-5); + assertEquals(5.0, result.get(1).getSecond().getFloat(0), 1e-5); + assertEquals(4.0, result.get(1).getSecond().getFloat(1), 1e-5); + assertEquals(4.0, result.get(2).getSecond().getFloat(0), 1e-5); + assertEquals(7.0, result.get(2).getSecond().getFloat(1), 1e-5); + assertEquals(7.0, result.get(3).getSecond().getFloat(0), 1e-5); + assertEquals(2.0, result.get(3).getSecond().getFloat(1), 1e-5); + assertEquals(8.0, result.get(4).getSecond().getFloat(0), 1e-5); + assertEquals(1.0, result.get(4).getSecond().getFloat(1), 1e-5); + assertEquals(9.0, result.get(5).getSecond().getFloat(0), 1e-5); + assertEquals(6.0, result.get(5).getSecond().getFloat(1), 1e-5); } @Test public void test_KNN_6() { - double[] data = new double[]{4, 6}; - val result = kdTree.knn(Nd4j.createFromArray(data), 10.0); + float[] data = new float[]{4, 6}; + List> result = kdTree.knn(Nd4j.createFromArray(data), 10.0f); assertEquals(4.0, result.get(0).getSecond().getDouble(0), 1e-5); assertEquals(7.0, result.get(0).getSecond().getDouble(1), 1e-5); assertEquals(5.0, result.get(1).getSecond().getDouble(0), 1e-5); @@ -318,8 +327,8 @@ public class KDTreeTest extends BaseDL4JTest { @Test public void test_KNN_7() { - double[] data = new double[]{4, 6}; - val result = kdTree.knn(Nd4j.createFromArray(data), 5.0); + float[] data = new float[]{4, 6}; + List> result = kdTree.knn(Nd4j.createFromArray(data), 5.0f); assertEquals(4.0, result.get(0).getSecond().getDouble(0), 1e-5); assertEquals(7.0, result.get(0).getSecond().getDouble(1), 1e-5); assertEquals(5.0, result.get(1).getSecond().getDouble(0), 1e-5); @@ -334,8 +343,8 @@ public class KDTreeTest extends BaseDL4JTest { @Test public void test_KNN_8() { - double[] data = new double[]{4, 6}; - val result = kdTree.knn(Nd4j.createFromArray(data), 20.0); + float[] data = new float[]{4, 6}; + List> result = kdTree.knn(Nd4j.createFromArray(data), 20.0f); assertEquals(4.0, result.get(0).getSecond().getDouble(0), 1e-5); assertEquals(7.0, result.get(0).getSecond().getDouble(1), 1e-5); assertEquals(5.0, result.get(1).getSecond().getDouble(0), 1e-5); @@ -392,12 +401,12 @@ public class KDTreeTest extends BaseDL4JTest { Duration duration = new Duration(start, end); System.out.println("Elapsed time for tree construction " + duration.getStandardSeconds() + " " + duration.getMillis()); - List pt = new ArrayList(num); + List pt = new ArrayList(num); for (int k = 0; k < n; k++) { - pt.add((double)(num / 2)); + pt.add((float)(num / 2)); } start = System.currentTimeMillis(); - List> list = kdTree.knn(Nd4j.create(Nd4j.createBuffer(Doubles.toArray(pt))), 20.0); + List> list = kdTree.knn(Nd4j.create(Nd4j.createBuffer(Doubles.toArray(pt))), 20.0f); end = System.currentTimeMillis(); duration = new Duration(start, end); long elapsed = end - start; diff --git a/libnd4j/include/ops/declarable/CustomOperations.h b/libnd4j/include/ops/declarable/CustomOperations.h index 9162d89bf..5aea215c1 100644 --- a/libnd4j/include/ops/declarable/CustomOperations.h +++ b/libnd4j/include/ops/declarable/CustomOperations.h @@ -39,6 +39,7 @@ #include #include #include +#include #include #include #include diff --git a/libnd4j/include/ops/declarable/generic/kernels/knn_mindistance.cpp b/libnd4j/include/ops/declarable/generic/kernels/knn_mindistance.cpp new file mode 100644 index 000000000..a7e825a9c --- /dev/null +++ b/libnd4j/include/ops/declarable/generic/kernels/knn_mindistance.cpp @@ -0,0 +1,59 @@ +/******************************************************************************* + * Copyright (c) 2015-2018 Skymind, Inc. + * + * This program and the accompanying materials are made available under the + * terms of the Apache License, Version 2.0 which is available at + * https://www.apache.org/licenses/LICENSE-2.0. + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations + * under the License. + * + * SPDX-License-Identifier: Apache-2.0 + ******************************************************************************/ + +// +// @author raver119@gmail.com +// + +#include +#if NOT_EXCLUDED(OP_knn_mindistance) + +#include +#include + +namespace nd4j { + namespace ops { + CUSTOM_OP_IMPL(knn_mindistance, 3, 1, false, 0, 0) { + auto input = INPUT_VARIABLE(0); + auto lowest = INPUT_VARIABLE(1); + auto highest = INPUT_VARIABLE(2); + + auto output = OUTPUT_VARIABLE(0); + + REQUIRE_TRUE(input->lengthOf() == lowest->lengthOf() && input->lengthOf() == highest->lengthOf(), 0, "knn_mindistance: all input arrays must have same length"); + REQUIRE_TRUE(input->dataType() == lowest->dataType() && input->dataType() == highest->dataType() && input->dataType() == output->dataType(), 0, "knn_mindistance: all inputs must have the same data type"); + + helpers::knn_mindistance(*input, *lowest, *highest, *output); + + return Status::OK(); + } + + DECLARE_SHAPE_FN(knn_mindistance) { + auto input = inputShape->at(0); + + // always return scalar here + return SHAPELIST(ConstantShapeHelper::getInstance()->scalarShapeInfo(ArrayOptions::dataType(input))); + } + + DECLARE_TYPES(knn_mindistance) { + getOpDescriptor() + ->setAllowedInputTypes({ALL_FLOATS}) + ->setAllowedOutputTypes({ALL_FLOATS}); + } + } +} + +#endif \ No newline at end of file diff --git a/libnd4j/include/ops/declarable/headers/kernels.h b/libnd4j/include/ops/declarable/headers/kernels.h new file mode 100644 index 000000000..8fb2bab62 --- /dev/null +++ b/libnd4j/include/ops/declarable/headers/kernels.h @@ -0,0 +1,34 @@ +/******************************************************************************* + * Copyright (c) 2015-2018 Skymind, Inc. + * + * This program and the accompanying materials are made available under the + * terms of the Apache License, Version 2.0 which is available at + * https://www.apache.org/licenses/LICENSE-2.0. + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations + * under the License. + * + * SPDX-License-Identifier: Apache-2.0 + ******************************************************************************/ + +// +// @author raver119@gmail.com +// + +#ifndef LIBND4J_KERNELS_H +#define LIBND4J_KERNELS_H + +#include + +namespace nd4j { + namespace ops { + #if NOT_EXCLUDED(OP_knn_mindistance) + DECLARE_CUSTOM_OP(knn_mindistance, 3, 1, false, 0, 0); + #endif + } +} + +#endif //LIBND4J_KERNELS_H diff --git a/libnd4j/include/ops/declarable/helpers/impl/knn_mindistance.cpp b/libnd4j/include/ops/declarable/helpers/impl/knn_mindistance.cpp new file mode 100644 index 000000000..71711832d --- /dev/null +++ b/libnd4j/include/ops/declarable/helpers/impl/knn_mindistance.cpp @@ -0,0 +1,62 @@ +/******************************************************************************* + * Copyright (c) 2015-2018 Skymind, Inc. + * + * This program and the accompanying materials are made available under the + * terms of the Apache License, Version 2.0 which is available at + * https://www.apache.org/licenses/LICENSE-2.0. + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations + * under the License. + * + * SPDX-License-Identifier: Apache-2.0 + ******************************************************************************/ + +// +// @author raver119@gmail.com +// + +#include + +namespace nd4j { + namespace ops { + namespace helpers { + template + void mindistance_(const void* vinput, const void *vlow, const void *vhigh, int32_t length, void *vout) { + auto input = reinterpret_cast(vinput); + auto low = reinterpret_cast(vlow); + auto high = reinterpret_cast(vhigh); + auto output = reinterpret_cast(vout); + + T res = 0.0f; + T po = 2.f; + T o = 1.f; + +#pragma omp simd reduction(sumT:res) + for (auto e = 0; e < length; e++) { + T p = input[e]; + T l = low[e]; + T h = high[e]; + if (!(l <= p || h <= p)) { + if (p < l) + res += nd4j::math::nd4j_pow((p - o), po); + else + res += nd4j::math::nd4j_pow((p - h), po); + } + } + + output[0] = nd4j::math::nd4j_pow(res, (T) 0.5f); + } + + void knn_mindistance(const NDArray &input, const NDArray &lowest, const NDArray &highest, NDArray &output) { + NDArray::preparePrimaryUse({&output}, {&input, &lowest, &highest}); + + BUILD_SINGLE_SELECTOR(input.dataType(), mindistance_, (input.getBuffer(), lowest.getBuffer(), highest.getBuffer(), input.lengthOf(), output.buffer()), FLOAT_TYPES); + + NDArray::registerPrimaryUse({&output}, {&input, &lowest, &highest}); + } + } + } +} \ No newline at end of file diff --git a/libnd4j/include/ops/declarable/helpers/knn.h b/libnd4j/include/ops/declarable/helpers/knn.h new file mode 100644 index 000000000..a2de9c71c --- /dev/null +++ b/libnd4j/include/ops/declarable/helpers/knn.h @@ -0,0 +1,34 @@ +/******************************************************************************* + * Copyright (c) 2015-2018 Skymind, Inc. + * + * This program and the accompanying materials are made available under the + * terms of the Apache License, Version 2.0 which is available at + * https://www.apache.org/licenses/LICENSE-2.0. + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations + * under the License. + * + * SPDX-License-Identifier: Apache-2.0 + ******************************************************************************/ + +// +// @author raver119@gmail.com +// + +#ifndef SAMEDIFF_KNN_H +#define SAMEDIFF_KNN_H + +#include + +namespace nd4j { + namespace ops { + namespace helpers { + void knn_mindistance(const NDArray &input, const NDArray &lowest, const NDArray &highest, NDArray &output); + } + } +} + +#endif //SAMEDIFF_KNN_H diff --git a/libnd4j/include/ops/declarable/impl/DeclarableOp.cpp b/libnd4j/include/ops/declarable/impl/DeclarableOp.cpp index 3dea41a18..fe1574ea1 100644 --- a/libnd4j/include/ops/declarable/impl/DeclarableOp.cpp +++ b/libnd4j/include/ops/declarable/impl/DeclarableOp.cpp @@ -217,7 +217,7 @@ namespace nd4j { auto var = ctx.variable(pair); auto shape = var->getNDArray()->shapeInfo(); - if (!shape::equalsSoft(out, shape)) { + if (!shape::equalsSoft(out, shape) || shape::isEmpty(out) != shape::isEmpty(shape)) { auto eShape = ShapeUtils::shapeAsString(out); auto aShape = ShapeUtils::shapeAsString(shape); @@ -237,7 +237,7 @@ namespace nd4j { ctx.setOutputArray(idx, outArr, true); } else { auto array = fout[idx]; - if (!shape::equalsSoft(out, array->shapeInfo())) { + if (!shape::equalsSoft(out, array->shapeInfo()) || shape::isEmpty(out) != array->isEmpty()) { auto eShape = ShapeUtils::shapeAsString(out); auto aShape = ShapeUtils::shapeAsString(array->shapeInfo()); diff --git a/libnd4j/tests_cpu/layers_tests/DeclarableOpsTests16.cpp b/libnd4j/tests_cpu/layers_tests/DeclarableOpsTests16.cpp index adbff7f83..d95e86b1c 100644 --- a/libnd4j/tests_cpu/layers_tests/DeclarableOpsTests16.cpp +++ b/libnd4j/tests_cpu/layers_tests/DeclarableOpsTests16.cpp @@ -133,4 +133,21 @@ TEST_F(DeclarableOpsTests16, test_hamming_distance_1) { ASSERT_EQ(e, *z); delete result; +} + +TEST_F(DeclarableOpsTests16, test_knn_mindistance_1) { + auto input = NDArrayFactory::create('c', {512}); + auto low = NDArrayFactory::create('c', {512}); + auto high = NDArrayFactory::create('c', {512}); + + auto output = NDArrayFactory::create(0.0f); + + input.linspace(1.0); + low.linspace(1.0); + high.linspace(1.0); + + nd4j::ops::knn_mindistance op; + auto result = op.execute({&input, &low, &high}, {&output}, {}, {}, {}); + ASSERT_EQ(Status::OK(), result); + } \ No newline at end of file diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/custom/KnnMinDistance.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/custom/KnnMinDistance.java new file mode 100644 index 000000000..16656766f --- /dev/null +++ b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/custom/KnnMinDistance.java @@ -0,0 +1,23 @@ +package org.nd4j.linalg.api.ops.custom; + +import org.nd4j.linalg.api.ndarray.INDArray; +import org.nd4j.linalg.api.ops.DynamicCustomOp; + +public class KnnMinDistance extends DynamicCustomOp { + + public KnnMinDistance() { + } + + public KnnMinDistance(INDArray point, INDArray lowest, INDArray highest, INDArray distance) { + inputArguments.add(point); + inputArguments.add(lowest); + inputArguments.add(highest); + + outputArguments.add(distance); + } + + @Override + public String opName() { + return "knn_mindistance"; + } +} diff --git a/nd4j/nd4j-backends/nd4j-backend-impls/nd4j-cuda/src/main/java/org/nd4j/linalg/jcublas/ops/executioner/CudaOpContext.java b/nd4j/nd4j-backends/nd4j-backend-impls/nd4j-cuda/src/main/java/org/nd4j/linalg/jcublas/ops/executioner/CudaOpContext.java index cf779f537..32f1b0a10 100644 --- a/nd4j/nd4j-backends/nd4j-backend-impls/nd4j-cuda/src/main/java/org/nd4j/linalg/jcublas/ops/executioner/CudaOpContext.java +++ b/nd4j/nd4j-backends/nd4j-backend-impls/nd4j-cuda/src/main/java/org/nd4j/linalg/jcublas/ops/executioner/CudaOpContext.java @@ -52,20 +52,26 @@ public class CudaOpContext extends BaseOpContext implements OpContext { @Override public void setIArguments(long... arguments) { - super.setIArguments(arguments); - nativeOps.setGraphContextIArguments(context, new LongPointer(arguments), arguments.length); + if (arguments.length > 0) { + super.setIArguments(arguments); + nativeOps.setGraphContextIArguments(context, new LongPointer(arguments), arguments.length); + } } @Override public void setBArguments(boolean... arguments) { - super.setBArguments(arguments); - nativeOps.setGraphContextBArguments(context, new BooleanPointer(arguments), arguments.length); + if (arguments.length > 0) { + super.setBArguments(arguments); + nativeOps.setGraphContextBArguments(context, new BooleanPointer(arguments), arguments.length); + } } @Override public void setTArguments(double... arguments) { - super.setTArguments(arguments); - nativeOps.setGraphContextTArguments(context, new DoublePointer(arguments), arguments.length); + if (arguments.length > 0) { + super.setTArguments(arguments); + nativeOps.setGraphContextTArguments(context, new DoublePointer(arguments), arguments.length); + } } @Override diff --git a/nd4j/nd4j-backends/nd4j-backend-impls/nd4j-native/src/main/java/org/nd4j/linalg/cpu/nativecpu/ops/CpuOpContext.java b/nd4j/nd4j-backends/nd4j-backend-impls/nd4j-native/src/main/java/org/nd4j/linalg/cpu/nativecpu/ops/CpuOpContext.java index 8db359d01..9431a3453 100644 --- a/nd4j/nd4j-backends/nd4j-backend-impls/nd4j-native/src/main/java/org/nd4j/linalg/cpu/nativecpu/ops/CpuOpContext.java +++ b/nd4j/nd4j-backends/nd4j-backend-impls/nd4j-native/src/main/java/org/nd4j/linalg/cpu/nativecpu/ops/CpuOpContext.java @@ -49,20 +49,26 @@ public class CpuOpContext extends BaseOpContext implements OpContext { @Override public void setIArguments(long... arguments) { - super.setIArguments(arguments); - nativeOps.setGraphContextIArguments(context, new LongPointer(arguments), arguments.length); + if (arguments.length > 0) { + super.setIArguments(arguments); + nativeOps.setGraphContextIArguments(context, new LongPointer(arguments), arguments.length); + } } @Override public void setBArguments(boolean... arguments) { - super.setBArguments(arguments); - nativeOps.setGraphContextBArguments(context, new BooleanPointer(arguments), arguments.length); + if (arguments.length > 0) { + super.setBArguments(arguments); + nativeOps.setGraphContextBArguments(context, new BooleanPointer(arguments), arguments.length); + } } @Override public void setTArguments(double... arguments) { - super.setTArguments(arguments); - nativeOps.setGraphContextTArguments(context, new DoublePointer(arguments), arguments.length); + if (arguments.length > 0) { + super.setTArguments(arguments); + nativeOps.setGraphContextTArguments(context, new DoublePointer(arguments), arguments.length); + }; } @Override diff --git a/nd4j/nd4j-backends/nd4j-backend-impls/nd4j-native/src/main/java/org/nd4j/nativeblas/Nd4jCpu.java b/nd4j/nd4j-backends/nd4j-backend-impls/nd4j-native/src/main/java/org/nd4j/nativeblas/Nd4jCpu.java index 8d92e09ad..047cb4021 100644 --- a/nd4j/nd4j-backends/nd4j-backend-impls/nd4j-native/src/main/java/org/nd4j/nativeblas/Nd4jCpu.java +++ b/nd4j/nd4j-backends/nd4j-backend-impls/nd4j-native/src/main/java/org/nd4j/nativeblas/Nd4jCpu.java @@ -1,4 +1,4 @@ -// Targeted by JavaCPP version 1.5.2-SNAPSHOT: DO NOT EDIT THIS FILE +// Targeted by JavaCPP version 1.5.1-1: DO NOT EDIT THIS FILE package org.nd4j.nativeblas; @@ -4703,6 +4703,7 @@ public native @Cast("bool") boolean isOptimalRequirementsMet(); * k - depth * value - scalar value to assign */ + public native void p(@Cast("const Nd4jLong") long i, @Cast("const Nd4jLong") long j, @Cast("const Nd4jLong") long k, @Cast("const Nd4jLong") long l, @Const @ByRef NDArray value); /** * creates array which points on certain sub-range of this array, sub-range is defined by given indices @@ -4931,7 +4932,7 @@ NDArray NDArray::operator()(const Nd4jLong i) const { } else { Nd4jLong idx[MAX_RANK]; shape::ind2subC(rankOf(), shapeOf(), i, idx); - auto xOffset = shape::getOffset(0, shapeOf(), stridesOf(), idx, rankOf()); + auto xOffset = shape::getOffset(getShapeInfo(), idx); auto cast = reinterpret_cast(_buffer) + (xOffset * this->sizeOfT()); NDArray result(cast, nd4j::ShapeBuilders::createScalarShapeInfo(this->dataType(), this->getWorkspace())); @@ -4962,7 +4963,7 @@ NDArray& NDArray::operator()(const Nd4jLong i) { } else { Nd4jLong idx[MAX_RANK]; shape::ind2subC(rankOf(), shapeOf(), i, idx); - auto xOffset = shape::getOffset(0, shapeOf(), stridesOf(), idx, rankOf()); + auto xOffset = shape::getOffset(getShapeInfo(), idx); auto cast = reinterpret_cast(_buffer) + (xOffset * this->sizeOfT()); NDArray result(cast, nd4j::ShapeBuilders::createScalarShapeInfo(this->dataType(), this->getWorkspace())); @@ -4979,7 +4980,7 @@ NDArray NDArray::operator()(const Nd4jLong i, const Nd4jLong j) const { throw std::invalid_argument("NDArray::operator(i,j): one of input indexes is out of array length or rank!=2 !"); Nd4jLong coords[2] = {i, j}; - auto xOffset = shape::getOffset(0, shapeOf(), stridesOf(), coords, rankOf()); + auto xOffset = shape::getOffset(getShapeInfo(), coords); // TODO: do we really want a view here? auto cast = reinterpret_cast(_buffer) + (xOffset * this->sizeOfT()); @@ -4995,7 +4996,7 @@ NDArray& NDArray::operator()(const Nd4jLong i, const Nd4jLong j) { throw std::invalid_argument("NDArray::operator(i,j): one of input indexes is out of array length or rank!=2 !"); Nd4jLong coords[2] = {i, j}; - auto xOffset = shape::getOffset(0, shapeOf(), stridesOf(), coords, rankOf()); + auto xOffset = shape::getOffset(getShapeInfo(), coords); auto cast = reinterpret_cast(_buffer) + (xOffset * this->sizeOfT()); NDArray result(cast, nd4j::ShapeBuilders::createScalarShapeInfo(this->dataType(), this->getWorkspace())); @@ -5014,7 +5015,7 @@ NDArray NDArray::operator()(const Nd4jLong i, const Nd4jLong j, const Nd4jLong k throw std::invalid_argument("NDArray::operator(i,j,k): one of input indexes is out of array length or rank!=3 !"); Nd4jLong coords[3] = {i, j, k}; - auto xOffset = shape::getOffset(0, shapeOf(), stridesOf(), coords, rankOf()); + auto xOffset = shape::getOffset(getShapeInfo(), coords); auto cast = reinterpret_cast(_buffer) + (xOffset * this->sizeOfT()); NDArray result(cast, nd4j::ShapeBuilders::createScalarShapeInfo(this->dataType(), this->getWorkspace())); @@ -5031,7 +5032,7 @@ NDArray& NDArray::operator()(const Nd4jLong i, const Nd4jLong j, const Nd4jLong throw std::invalid_argument("NDArray::operator(i,j,k): one of input indexes is out of array length or rank!=3 !"); Nd4jLong coords[3] = {i, j, k}; - auto xOffset = shape::getOffset(0, shapeOf(), stridesOf(), coords, rankOf()); + auto xOffset = shape::getOffset(getShapeInfo(), coords); auto cast = reinterpret_cast(_buffer) + (xOffset * this->sizeOfT()); NDArray result(cast, nd4j::ShapeBuilders::createScalarShapeInfo(this->dataType(), this->getWorkspace())); @@ -5047,7 +5048,7 @@ NDArray NDArray::operator()(const Nd4jLong t, const Nd4jLong u, const Nd4jLong v throw std::invalid_argument("NDArray::operator(t,u,v,w): one of input indexes is out of array length or rank!=4 !"); Nd4jLong coords[4] = {t, u, v, w}; - auto xOffset = shape::getOffset(0, shapeOf(), stridesOf(), coords, rankOf()); + auto xOffset = shape::getOffset(getShapeInfo(), coords); auto cast = reinterpret_cast(_buffer) + (xOffset * this->sizeOfT()); NDArray result(cast, nd4j::ShapeBuilders::createScalarShapeInfo(this->dataType(), this->getWorkspace())); @@ -5061,7 +5062,7 @@ NDArray& NDArray::operator()(const Nd4jLong t, const Nd4jLong u, const Nd4jLong throw std::invalid_argument("NDArray::operator(t,u,v,w): one of input indexes is out of array length or rank!=4 !"); Nd4jLong coords[4] = {t, u, v, w}; - auto xOffset = shape::getOffset(0, shapeOf(), stridesOf(), coords, rankOf()); + auto xOffset = shape::getOffset(getShapeInfo(), coords); // FIXME auto cast = reinterpret_cast(_buffer) + (xOffset * this->sizeOfT()); @@ -5077,7 +5078,7 @@ NDArray NDArray::operator()(const Nd4jLong* idx) const { if (idx[i] >= sizeAt(i)) throw std::invalid_argument("NDArray::operator(const Nd4jLong* idx): input index is out of dimension length !"); - auto xOffset = shape::getOffset(0, shapeOf(), stridesOf(), idx, rankOf()); + auto xOffset = shape::getOffset(getShapeInfo(), idx); auto cast = reinterpret_cast(_buffer) + (xOffset * this->sizeOfT()); NDArray result(cast, nd4j::ShapeBuilders::createScalarShapeInfo(this->dataType(), this->getWorkspace())); @@ -5092,7 +5093,7 @@ NDArray& NDArray::operator()(const Nd4jLong* idx) { if (idx[i] >= sizeAt(i)) throw std::invalid_argument("NDArray::operator(const Nd4jLong* idx): input index is out of dimension length !"); - auto xOffset = shape::getOffset(0, shapeOf(), stridesOf(), idx, rankOf()); + auto xOffset = shape::getOffset(getShapeInfo(), idx); auto cast = reinterpret_cast(_buffer) + (xOffset * this->sizeOfT()); NDArray result(cast, nd4j::ShapeBuilders::createScalarShapeInfo(this->dataType(), this->getWorkspace())); @@ -7958,9 +7959,7 @@ public static final int PREALLOC_SIZE = 33554432; * @param indices the indices to iterate over * @return the double at the specified index */ - @Namespace("shape") public static native @Cast("Nd4jLong") long getOffset(@Cast("Nd4jLong") long baseOffset, @Cast("const Nd4jLong*") LongPointer shape, @Cast("const Nd4jLong*") LongPointer stride, @Cast("const Nd4jLong*") LongPointer indices, int rank); - @Namespace("shape") public static native @Cast("Nd4jLong") long getOffset(@Cast("Nd4jLong") long baseOffset, @Cast("const Nd4jLong*") LongBuffer shape, @Cast("const Nd4jLong*") LongBuffer stride, @Cast("const Nd4jLong*") LongBuffer indices, int rank); - @Namespace("shape") public static native @Cast("Nd4jLong") long getOffset(@Cast("Nd4jLong") long baseOffset, @Cast("const Nd4jLong*") long[] shape, @Cast("const Nd4jLong*") long[] stride, @Cast("const Nd4jLong*") long[] indices, int rank); + @Namespace("shape") public static native @Cast("Nd4jLong") long getOffset(@Cast("const Nd4jLong*") LongPointer shapeInfo, @Cast("const Nd4jLong*") LongPointer indices, @Cast("Nd4jLong") long baseOffset/*=0*/); @Namespace("shape") public static native @Cast("Nd4jLong") long getOffset(@Cast("const Nd4jLong*") LongPointer shapeInfo, @Cast("const Nd4jLong*") LongPointer indices); @Namespace("shape") public static native @Cast("Nd4jLong") long getOffset(@Cast("const Nd4jLong*") LongBuffer shapeInfo, @Cast("const Nd4jLong*") LongBuffer indices, @Cast("Nd4jLong") long baseOffset/*=0*/); @@ -7981,34 +7980,26 @@ public static final int PREALLOC_SIZE = 33554432; /** * Convert a linear index to the corresponding coordinates - * for example if shape is {2, 4}, then index 5 corresponds to following coordinates - * -> [1, 1] in case of c order - * -> [1, 2] in case of f order + * for example if shape is {2, 4}, then index 5 corresponds to coordinates [1, 1] */ - @Namespace("shape") public static native void index2coords(int rank, @Cast("const Nd4jLong*") LongPointer shape, @Cast("Nd4jLong") long index, @Cast("Nd4jLong") long arrLen, @Cast("Nd4jLong*") LongPointer coords, byte order/*='c'*/); - @Namespace("shape") public static native void index2coords(int rank, @Cast("const Nd4jLong*") LongPointer shape, @Cast("Nd4jLong") long index, @Cast("Nd4jLong") long arrLen, @Cast("Nd4jLong*") LongPointer coords); - @Namespace("shape") public static native void index2coords(int rank, @Cast("const Nd4jLong*") LongBuffer shape, @Cast("Nd4jLong") long index, @Cast("Nd4jLong") long arrLen, @Cast("Nd4jLong*") LongBuffer coords, byte order/*='c'*/); - @Namespace("shape") public static native void index2coords(int rank, @Cast("const Nd4jLong*") LongBuffer shape, @Cast("Nd4jLong") long index, @Cast("Nd4jLong") long arrLen, @Cast("Nd4jLong*") LongBuffer coords); - @Namespace("shape") public static native void index2coords(int rank, @Cast("const Nd4jLong*") long[] shape, @Cast("Nd4jLong") long index, @Cast("Nd4jLong") long arrLen, @Cast("Nd4jLong*") long[] coords, byte order/*='c'*/); - @Namespace("shape") public static native void index2coords(int rank, @Cast("const Nd4jLong*") long[] shape, @Cast("Nd4jLong") long index, @Cast("Nd4jLong") long arrLen, @Cast("Nd4jLong*") long[] coords); - @Namespace("shape") public static native void index2coords(int rank, @Cast("const Nd4jLong*") LongPointer shape, @Cast("Nd4jLong") long index, @Cast("Nd4jLong*") LongPointer coords, byte order/*='c'*/); - @Namespace("shape") public static native void index2coords(int rank, @Cast("const Nd4jLong*") LongPointer shape, @Cast("Nd4jLong") long index, @Cast("Nd4jLong*") LongPointer coords); - @Namespace("shape") public static native void index2coords(int rank, @Cast("const Nd4jLong*") LongBuffer shape, @Cast("Nd4jLong") long index, @Cast("Nd4jLong*") LongBuffer coords, byte order/*='c'*/); - @Namespace("shape") public static native void index2coords(int rank, @Cast("const Nd4jLong*") LongBuffer shape, @Cast("Nd4jLong") long index, @Cast("Nd4jLong*") LongBuffer coords); - @Namespace("shape") public static native void index2coords(int rank, @Cast("const Nd4jLong*") long[] shape, @Cast("Nd4jLong") long index, @Cast("Nd4jLong*") long[] coords, byte order/*='c'*/); - @Namespace("shape") public static native void index2coords(int rank, @Cast("const Nd4jLong*") long[] shape, @Cast("Nd4jLong") long index, @Cast("Nd4jLong*") long[] coords); + @Namespace("shape") public static native void index2coords(@Cast("Nd4jLong") long index, @Cast("const Nd4jLong*") LongPointer shapeInfo, @Cast("Nd4jLong*") LongPointer coords); + @Namespace("shape") public static native void index2coords(@Cast("Nd4jLong") long index, @Cast("const Nd4jLong*") LongBuffer shapeInfo, @Cast("Nd4jLong*") LongBuffer coords); + @Namespace("shape") public static native void index2coords(@Cast("Nd4jLong") long index, @Cast("const Nd4jLong*") long[] shapeInfo, @Cast("Nd4jLong*") long[] coords); + @Namespace("shape") public static native void index2coords(@Cast("Nd4jLong") long index, int rank, @Cast("const Nd4jLong*") LongPointer shape, @Cast("Nd4jLong*") LongPointer coords); + @Namespace("shape") public static native void index2coords(@Cast("Nd4jLong") long index, int rank, @Cast("const Nd4jLong*") LongBuffer shape, @Cast("Nd4jLong*") LongBuffer coords); + @Namespace("shape") public static native void index2coords(@Cast("Nd4jLong") long index, int rank, @Cast("const Nd4jLong*") long[] shape, @Cast("Nd4jLong*") long[] coords); + + /** * Convert coordinates to the corresponding linear index (sequence number in other words) - * for example if shape is {2, 4}, then: - * in case of c order and coordinates [1, 1] index 5 is returned - * in case of f order and coordinates [1, 2] index 5 is returned + * for example if shape is {2, 4} and coordinates [1, 1] then index 5 is returned */ - @Namespace("shape") public static native @Cast("Nd4jLong") long coords2index(int rank, @Cast("const Nd4jLong*") LongPointer shape, @Cast("const Nd4jLong*") LongPointer coords, byte order/*='c'*/); + @Namespace("shape") public static native @Cast("Nd4jLong") long coords2index(@Cast("const Nd4jLong*") LongPointer shapeInfo, @Cast("const Nd4jLong*") LongPointer coords); + @Namespace("shape") public static native @Cast("Nd4jLong") long coords2index(@Cast("const Nd4jLong*") LongBuffer shapeInfo, @Cast("const Nd4jLong*") LongBuffer coords); + @Namespace("shape") public static native @Cast("Nd4jLong") long coords2index(@Cast("const Nd4jLong*") long[] shapeInfo, @Cast("const Nd4jLong*") long[] coords); @Namespace("shape") public static native @Cast("Nd4jLong") long coords2index(int rank, @Cast("const Nd4jLong*") LongPointer shape, @Cast("const Nd4jLong*") LongPointer coords); - @Namespace("shape") public static native @Cast("Nd4jLong") long coords2index(int rank, @Cast("const Nd4jLong*") LongBuffer shape, @Cast("const Nd4jLong*") LongBuffer coords, byte order/*='c'*/); @Namespace("shape") public static native @Cast("Nd4jLong") long coords2index(int rank, @Cast("const Nd4jLong*") LongBuffer shape, @Cast("const Nd4jLong*") LongBuffer coords); - @Namespace("shape") public static native @Cast("Nd4jLong") long coords2index(int rank, @Cast("const Nd4jLong*") long[] shape, @Cast("const Nd4jLong*") long[] coords, byte order/*='c'*/); @Namespace("shape") public static native @Cast("Nd4jLong") long coords2index(int rank, @Cast("const Nd4jLong*") long[] shape, @Cast("const Nd4jLong*") long[] coords); /** @@ -8020,36 +8011,16 @@ public static final int PREALLOC_SIZE = 33554432; */ /* calculates an array buffer offset for given "index" using following formula: offset = coord_0*stride_0 + coord_1*stride_1 + ... + coord_{rank-1}*stride_{rank-1} - * arrLen - array length */ - @Namespace("shape") public static native @Cast("uint") int getIndexOffset(@Cast("uint") int index, @Cast("const uint*") IntPointer shapeInfo, @Cast("uint") int arrLen); - @Namespace("shape") public static native @Cast("uint") int getIndexOffset(@Cast("uint") int index, @Cast("const uint*") IntBuffer shapeInfo, @Cast("uint") int arrLen); - @Namespace("shape") public static native @Cast("uint") int getIndexOffset(@Cast("uint") int index, @Cast("const uint*") int[] shapeInfo, @Cast("uint") int arrLen); - @Namespace("shape") public static native @Cast("Nd4jLong") long getIndexOffset(@Cast("Nd4jLong") long index, @Cast("const Nd4jLong*") LongPointer shapeInfo, @Cast("Nd4jLong") long arrLen); - @Namespace("shape") public static native @Cast("Nd4jLong") long getIndexOffset(@Cast("Nd4jLong") long index, @Cast("const Nd4jLong*") LongBuffer shapeInfo, @Cast("Nd4jLong") long arrLen); - @Namespace("shape") public static native @Cast("Nd4jLong") long getIndexOffset(@Cast("Nd4jLong") long index, @Cast("const Nd4jLong*") long[] shapeInfo, @Cast("Nd4jLong") long arrLen); - @Namespace("shape") public static native @Cast("Nd4jLong") long getIndexOrderOffset(@Cast("Nd4jLong") long index, @Cast("const Nd4jLong*") LongPointer shapeInfo, @Cast("Nd4jLong") long arrLen, byte order); - @Namespace("shape") public static native @Cast("Nd4jLong") long getIndexOrderOffset(@Cast("Nd4jLong") long index, @Cast("const Nd4jLong*") LongBuffer shapeInfo, @Cast("Nd4jLong") long arrLen, byte order); - @Namespace("shape") public static native @Cast("Nd4jLong") long getIndexOrderOffset(@Cast("Nd4jLong") long index, @Cast("const Nd4jLong*") long[] shapeInfo, @Cast("Nd4jLong") long arrLen, byte order); - @Namespace("shape") public static native @Cast("Nd4jLong") long indexOffset(@Cast("Nd4jLong") long index, @Cast("const Nd4jLong*") LongPointer lShapeInfo, @Cast("const uint*") IntPointer uShapeInfo, @Cast("Nd4jLong") long arrLen, @Cast("const bool") boolean useUnsigned); - @Namespace("shape") public static native @Cast("Nd4jLong") long indexOffset(@Cast("Nd4jLong") long index, @Cast("const Nd4jLong*") LongBuffer lShapeInfo, @Cast("const uint*") IntBuffer uShapeInfo, @Cast("Nd4jLong") long arrLen, @Cast("const bool") boolean useUnsigned); - @Namespace("shape") public static native @Cast("Nd4jLong") long indexOffset(@Cast("Nd4jLong") long index, @Cast("const Nd4jLong*") long[] lShapeInfo, @Cast("const uint*") int[] uShapeInfo, @Cast("Nd4jLong") long arrLen, @Cast("const bool") boolean useUnsigned); - - /** - * Compute the real linear indices for the given shape and stride - */ - @Namespace("shape") public static native @Cast("Nd4jLong*") LongPointer computeIndices(int rank, @Cast("Nd4jLong*") LongPointer shape, @Cast("Nd4jLong*") LongPointer stride); - @Namespace("shape") public static native @Cast("Nd4jLong*") LongBuffer computeIndices(int rank, @Cast("Nd4jLong*") LongBuffer shape, @Cast("Nd4jLong*") LongBuffer stride); - @Namespace("shape") public static native @Cast("Nd4jLong*") long[] computeIndices(int rank, @Cast("Nd4jLong*") long[] shape, @Cast("Nd4jLong*") long[] stride); - - /** - * Compute the real linear indices for the - * given shape buffer. Shape,stride and rank are derived - * from the buffer - */ - @Namespace("shape") public static native @Cast("Nd4jLong*") LongPointer computeIndices( @Cast("Nd4jLong*") LongPointer shapeBuffer); - @Namespace("shape") public static native @Cast("Nd4jLong*") LongBuffer computeIndices( @Cast("Nd4jLong*") LongBuffer shapeBuffer); - @Namespace("shape") public static native @Cast("Nd4jLong*") long[] computeIndices( @Cast("Nd4jLong*") long[] shapeBuffer); + @Namespace("shape") public static native @Cast("uint") int getIndexOffset(@Cast("uint") int index, @Cast("const uint*") IntPointer shapeInfo); + @Namespace("shape") public static native @Cast("uint") int getIndexOffset(@Cast("uint") int index, @Cast("const uint*") IntBuffer shapeInfo); + @Namespace("shape") public static native @Cast("uint") int getIndexOffset(@Cast("uint") int index, @Cast("const uint*") int[] shapeInfo); + @Namespace("shape") public static native @Cast("Nd4jLong") long getIndexOffset(@Cast("Nd4jLong") long index, @Cast("const Nd4jLong*") LongPointer shapeInfo); + @Namespace("shape") public static native @Cast("Nd4jLong") long getIndexOffset(@Cast("Nd4jLong") long index, @Cast("const Nd4jLong*") LongBuffer shapeInfo); + @Namespace("shape") public static native @Cast("Nd4jLong") long getIndexOffset(@Cast("Nd4jLong") long index, @Cast("const Nd4jLong*") long[] shapeInfo); + @Namespace("shape") public static native @Cast("Nd4jLong") long indexOffset(@Cast("Nd4jLong") long index, @Cast("const Nd4jLong*") LongPointer lShapeInfo, @Cast("const uint*") IntPointer uShapeInfo, @Cast("const bool") boolean useUnsigned); + @Namespace("shape") public static native @Cast("Nd4jLong") long indexOffset(@Cast("Nd4jLong") long index, @Cast("const Nd4jLong*") LongBuffer lShapeInfo, @Cast("const uint*") IntBuffer uShapeInfo, @Cast("const bool") boolean useUnsigned); + @Namespace("shape") public static native @Cast("Nd4jLong") long indexOffset(@Cast("Nd4jLong") long index, @Cast("const Nd4jLong*") long[] lShapeInfo, @Cast("const uint*") int[] uShapeInfo, @Cast("const bool") boolean useUnsigned); @Namespace("shape") public static native void printShapeInfo(@Cast("Nd4jLong*") LongPointer shapeInfo); @Namespace("shape") public static native void printShapeInfo(@Cast("Nd4jLong*") LongBuffer shapeInfo); @@ -8328,20 +8299,62 @@ public static final int PREALLOC_SIZE = 33554432; * for the given rank and shape. */ -/** - * Compute the real linear indices for the given shape and stride - */ - -/** -* Compute the real linear indices for the given shape and stride -*/ - +////////////////////////////////////////////////////////////////////// ////////////////////////////////////////////////////////////////////// +// ////////////////////////////////////////////////////////////////////// +// INLINEDEF _CUDA_HD Nd4jLong getIndexOffset(Nd4jLong index, const Nd4jLong *shapeInfo, Nd4jLong arrLen) { + +// const Nd4jLong ews = shapeInfo[shapeInfo[0] + shapeInfo[0] + 2]; + +// if(ews > 0 && order(shapeInfo) == 'c') +// if (ews == 1) +// return index; +// else +// return ews * index; + +// Nd4jLong offset = 0; +// Nd4jLong rank = shapeInfo[0]; +// for(int i = 1; i <= shapeInfo[0]; ++i) { +// arrLen /= shapeInfo[i]; +// if(arrLen > 0 && shapeInfo[i] > 1) { +// offset += (index / arrLen) * shapeInfo[i + rank]; +// index %= arrLen; +// } +// } +// return offset; +// } + +// INLINEDEF _CUDA_HD uint getIndexOffset(uint index, const uint *shapeInfo, uint arrLen) { + +// const uint rank = shapeInfo[0]; +// const uint ews = shapeInfo[rank + rank + 2]; + +// if(ews > 0 && shapeInfo[rank + rank + 3] == 99) +// if (ews == 1) +// return index; +// else +// return ews * index; + +// uint offset = 0; + +// for(uint i = 1; i <= rank; ++i) { +// arrLen /= shapeInfo[i]; +// if(arrLen > 0 && shapeInfo[i] > 1) { +// offset += (index / arrLen) * shapeInfo[i + rank]; +// index %= arrLen; +// } +// } +// return offset; +// } + ////////////////////////////////////////////////////////////////////// +////////////////////////////////////////////////////////////////////// + + ////////////////////////////////////////////////////////////////////// /** @@ -8710,6 +8723,10 @@ public static final int PREALLOC_SIZE = 33554432; * @return the double at the specified index */ +////////////////////////////////////////////////////////////////////////// + +////////////////////////////////////////////////////////////////////////// + @@ -9038,6 +9055,8 @@ public static final int PREALLOC_SIZE = 33554432; ////////////////////////////////////////////////////////////////////// +////////////////////////////////////////////////////////////////////// + @@ -13850,6 +13869,31 @@ public static final int TAD_THRESHOLD = TAD_THRESHOLD(); } // #endif + /** + * This is one of auto-broadcastable operations. It accepts 2 operands, and operation is applied based on their shapes: + * 1) if shapes are equal that's pairwise operation, result will have the same shape. + * 2) if shape X is scalar and shape Y is array - result will have shape equal to Y. + * 3) if shape X is array and shape Y is scalar - result will have shape equal to X. + * 4) if shape X and Y are both arrays, but shapes aren't equal - result shape will be broadcast result. + * + * This operation returns Z = Divide(X, Y) with exception, 0 if Y = 0 + */ +// #if NOT_EXCLUDED(OP_divide_no_nan) + @Namespace("nd4j::ops") public static class divide_no_nan extends BroadcastableOp { + static { Loader.load(); } + /** Pointer cast constructor. Invokes {@link Pointer#Pointer(Pointer)}. */ + public divide_no_nan(Pointer p) { super(p); } + /** Native array allocator. Access with {@link Pointer#position(long)}. */ + public divide_no_nan(long size) { super((Pointer)null); allocateArray(size); } + private native void allocateArray(long size); + @Override public divide_no_nan position(long position) { + return (divide_no_nan)super.position(position); + } + + public divide_no_nan() { super((Pointer)null); allocate(); } + private native void allocate(); + } +// #endif /** * This is one of auto-broadcastable operations. It accepts 2 operands, and operation is applied based on their shapes: * 1) if shapes are equal that's pairwise operation, result will have the same shape. @@ -14385,6 +14429,54 @@ public static final int TAD_THRESHOLD = TAD_THRESHOLD(); private native void allocate(); } // #endif + + /** + * Broadcastable igamma implementation + * + * igamma(a, x) = gamma(а, x) / Gamma(a) - Gamma distribution function P(a,x) + * Gamma(a) = int from 0 to infinity { t ^ {a - 1} e^{-t}dt } + * gamma(a, x) = int from 0 to x { t ^ {a - 1} e^{-t}dt } + * \tparam T + */ +// #if NOT_EXCLUDED(OP_igamma) + @Namespace("nd4j::ops") public static class igamma extends BroadcastableOp { + static { Loader.load(); } + /** Pointer cast constructor. Invokes {@link Pointer#Pointer(Pointer)}. */ + public igamma(Pointer p) { super(p); } + /** Native array allocator. Access with {@link Pointer#position(long)}. */ + public igamma(long size) { super((Pointer)null); allocateArray(size); } + private native void allocateArray(long size); + @Override public igamma position(long position) { + return (igamma)super.position(position); + } + + public igamma() { super((Pointer)null); allocate(); } + private native void allocate(); + } +// #endif + /** + * Broadcastable igammac implementation + * igammac(a, x) = Gamma(a,x)/Gamma(а) - Gamma distribution function Q(a,x) + * Gamma(a) = int from 0 to infinity { t ^ {a - 1} e^{-t}dt } + * Gamma(a, x) = int from x to infinity { t ^ {a - 1} e^{-t}dt } + * \tparam T + */ +// #if NOT_EXCLUDED(OP_igammac) + @Namespace("nd4j::ops") public static class igammac extends BroadcastableOp { + static { Loader.load(); } + /** Pointer cast constructor. Invokes {@link Pointer#Pointer(Pointer)}. */ + public igammac(Pointer p) { super(p); } + /** Native array allocator. Access with {@link Pointer#position(long)}. */ + public igammac(long size) { super((Pointer)null); allocateArray(size); } + private native void allocateArray(long size); + @Override public igammac position(long position) { + return (igammac)super.position(position); + } + + public igammac() { super((Pointer)null); allocate(); } + private native void allocate(); + } +// #endif @@ -15842,6 +15934,26 @@ public static final int TAD_THRESHOLD = TAD_THRESHOLD(); } // #endif + ////////////////////////////////////////////////////////////////////////// +// #if NOT_EXCLUDED(OP_lstmLayer) + @Namespace("nd4j::ops") public static class lstmLayer extends DeclarableCustomOp { + static { Loader.load(); } + /** Pointer cast constructor. Invokes {@link Pointer#Pointer(Pointer)}. */ + public lstmLayer(Pointer p) { super(p); } + /** Native array allocator. Access with {@link Pointer#position(long)}. */ + public lstmLayer(long size) { super((Pointer)null); allocateArray(size); } + private native void allocateArray(long size); + @Override public lstmLayer position(long position) { + return (lstmLayer)super.position(position); + } + + public lstmLayer() { super((Pointer)null); allocate(); } + private native void allocate(); + public native ShapeList calculateOutputShape(ShapeList inputShape, @ByRef Context block); + } +// #endif + + ////////////////////////////////////////////////////////////////////////// /** * Implementation of operations for Simple Recurrent Unit cell: "Training RNNs as Fast as CNNs" Tao Lei, Yu Zhang, Yoav Artzi @@ -17079,16 +17191,16 @@ public static final int TAD_THRESHOLD = TAD_THRESHOLD(); * Inserts elements provided by diagonal array into the main diagonal of innermost matrices of input array * * Input arrays: - * input: input array, considered as batch of matrices - * diagonal: array containing elements to be inserted into input array, - * following rank condition should be satisfied: diagonal_rank = input_rank - 1, - * the shapes of diagonal and input arrays must be equal except last dimension of input array, - * for example if input_shape = [A,B,C,D] then diagonal_shape = [A,B,C], - * also last dimension of diagonal array should be equal to smaller of last and last but one input dimensions - * that is: diagonal_shape[-1] = min(input_shape[-1], input_shape[-2]) + * 0: input array, considered as batch of matrices + * 1: diagonal array containing elements to be inserted into input array, + * following rank condition should be satisfied: diagonal_rank = input_rank - 1, + * the shapes of diagonal and input arrays must be equal except last dimension of input array, + * for example if input_shape = [A,B,C,D] then diagonal_shape = [A,B,C], + * also last dimension of diagonal array should be equal to smaller of last and last but one input dimensions + * that is: diagonal_shape[-1] = min(input_shape[-1], input_shape[-2]) * * Output array: - * has the same shape as input, corresponding diagonal elements are substituted + * 0: has the same shape as input, corresponding diagonal elements are substituted */ // #if NOT_EXCLUDED(OP_matrix_set_diag) @Namespace("nd4j::ops") public static class matrix_set_diag extends DeclarableOp { @@ -17109,8 +17221,16 @@ public static final int TAD_THRESHOLD = TAD_THRESHOLD(); // #endif /** - * Returns a batched matrix tensor with diagonal values given (as TF.matrix_diag). - */ + * Inserts elements provided by diagonal array into the main diagonal of innermost matrices of output array, + * rest output elements are set to zeros + * + * Input array: + * diagonal: array containing elements to be inserted into output array, + * following rank condition is present: diagonal_rank = ouput_rank - 1 + * + * Output array: + * 0: is considered as batch of matrices, if for example diagonal array has shape [A,B,C] then output array has shape [A,B,C,C] + */ @Namespace("nd4j::ops") public static class matrix_diag extends DeclarableCustomOp { static { Loader.load(); } /** Pointer cast constructor. Invokes {@link Pointer#Pointer(Pointer)}. */ @@ -17130,13 +17250,13 @@ public static final int TAD_THRESHOLD = TAD_THRESHOLD(); /** * This op calculates regularized incomplete beta integral Ix(a, b). * Implementation is based on two algorithms depending on input values of a and b: - * - when a and b are both > maxValue (3000.), then apply Gauss-Legendre quadrature method - * - when a and b are both <= maxValue (3000.), then apply modified Lentz’s algorithm for continued fractions + * - when a and b are both > maxValue (3000.), then Gauss-Legendre quadrature method is applied + * - when a and b are both <= maxValue (3000.), then modified Lentz’s algorithm for continued fractions is applied * * Input arrays: - * a: define power t^{a-1}, must be > 0, type float. - * b: define power (1-t)^{b-1}, must be > 0, type float. - * x: define upper limit of integration, must be within (0 <= x <= 1) range, type float. + * a: defines power t^{a-1}, must be > 0, type float. + * b: defines power (1-t)^{b-1}, must be > 0, type float. + * x: defines upper limit of integration, must be within (0 <= x <= 1) range, type float. * * Output array: * 0: values of regularized incomplete beta integral that corresponds to variable upper limit x, type float @@ -18250,6 +18370,50 @@ public static final int TAD_THRESHOLD = TAD_THRESHOLD(); } // #endif + /** + * This operation adjusts image contrast by given factor ( z = (x - mean) * factor + mean ) + * Input arrays: + * 0 - input array with rank >= 3, must have last one dimension equal 3, that is dimension containing channels. + * + * T arguments: + * 0 - contrast factor + * + */ +// #if NOT_EXCLUDED(OP_adjust_contrast) + @Namespace("nd4j::ops") public static class adjust_contrast extends DeclarableOp { + static { Loader.load(); } + /** Pointer cast constructor. Invokes {@link Pointer#Pointer(Pointer)}. */ + public adjust_contrast(Pointer p) { super(p); } + /** Native array allocator. Access with {@link Pointer#position(long)}. */ + public adjust_contrast(long size) { super((Pointer)null); allocateArray(size); } + private native void allocateArray(long size); + @Override public adjust_contrast position(long position) { + return (adjust_contrast)super.position(position); + } + + public adjust_contrast() { super((Pointer)null); allocate(); } + private native void allocate(); + public native ShapeList calculateOutputShape(ShapeList inputShape, @ByRef Context block); + } + @Namespace("nd4j::ops") public static class adjust_contrast_v2 extends DeclarableOp { + static { Loader.load(); } + /** Pointer cast constructor. Invokes {@link Pointer#Pointer(Pointer)}. */ + public adjust_contrast_v2(Pointer p) { super(p); } + /** Native array allocator. Access with {@link Pointer#position(long)}. */ + public adjust_contrast_v2(long size) { super((Pointer)null); allocateArray(size); } + private native void allocateArray(long size); + @Override public adjust_contrast_v2 position(long position) { + return (adjust_contrast_v2)super.position(position); + } + + public adjust_contrast_v2() { super((Pointer)null); allocate(); } + private native void allocate(); + public native ShapeList calculateOutputShape(ShapeList inputShape, @ByRef Context block); + } +// #endif + + + /** * This operation rearranges data from depth into blocks of spatial data. This is the reverse transformation @@ -19634,6 +19798,37 @@ public static final int TAD_THRESHOLD = TAD_THRESHOLD(); } // #endif + /** + * draw_bounding_boxes op - modified input image with given colors exept given boxes. + * + * input params: + * 0 - images tensor (4D) with shape {batch, width, height, channels}, where channes is 1 (BW image), + * 3 (RGB) or 4 (RGBA) + * 1 - boxes tensor (3D) with shape {batch, number_of_boxes, 4} where last dimension encoded as + * (y_min, x_min, y_max, x_max), all values in between 0. and 1. + * 2 - colours tensor (2D) with shape {number_of_boxes, channels} -- bordering color set (palette) + * + * output: + * 0 - 4D tensor with same shape as images (input 0) + */ +// #if NOT_EXCLUDED(OP_draw_bounding_boxes) + @Namespace("nd4j::ops") public static class draw_bounding_boxes extends DeclarableOp { + static { Loader.load(); } + /** Pointer cast constructor. Invokes {@link Pointer#Pointer(Pointer)}. */ + public draw_bounding_boxes(Pointer p) { super(p); } + /** Native array allocator. Access with {@link Pointer#position(long)}. */ + public draw_bounding_boxes(long size) { super((Pointer)null); allocateArray(size); } + private native void allocateArray(long size); + @Override public draw_bounding_boxes position(long position) { + return (draw_bounding_boxes)super.position(position); + } + + public draw_bounding_boxes() { super((Pointer)null); allocate(); } + private native void allocate(); + public native ShapeList calculateOutputShape(ShapeList inputShape, @ByRef Context block); + } +// #endif + /** * roll - op porting from numpy (https://docs.scipy.org/doc/numpy-1.14.0/reference/generated/numpy.roll.html) * @@ -20623,6 +20818,67 @@ public static final int TAD_THRESHOLD = TAD_THRESHOLD(); } // #endif +/** + * fake_quant_with_min_max_vals_per_channel - tf.quantization.fake_quant_with_min_max_vars_per_channel + * + * input params: + * 0 - NDArray (input) - at least 2D. + * 1 - 1D Tensor - min values (min length equals to last dim of input) + * 2 - 1D Tensor - max value (length equals to min) + * + * int params (optional): + * 0 - num_bits (allowed interval [2, 16], default 8) + * 1 - narrow_range (default False) + * + * output: + * 0 - NDArray with the same shape as input + */ +// #if NOT_EXCLUDED(OP_fake_quant_with_min_max_vars_per_channel) + @Namespace("nd4j::ops") public static class fake_quant_with_min_max_vars_per_channel extends DeclarableOp { + static { Loader.load(); } + /** Pointer cast constructor. Invokes {@link Pointer#Pointer(Pointer)}. */ + public fake_quant_with_min_max_vars_per_channel(Pointer p) { super(p); } + /** Native array allocator. Access with {@link Pointer#position(long)}. */ + public fake_quant_with_min_max_vars_per_channel(long size) { super((Pointer)null); allocateArray(size); } + private native void allocateArray(long size); + @Override public fake_quant_with_min_max_vars_per_channel position(long position) { + return (fake_quant_with_min_max_vars_per_channel)super.position(position); + } + + public fake_quant_with_min_max_vars_per_channel() { super((Pointer)null); allocate(); } + private native void allocate(); + public native ShapeList calculateOutputShape(ShapeList inputShape, @ByRef Context block); + } +// #endif + + /** + * compare_and_bitpack - compare with greater and pack result with uint8 + * + * input params: + * 0 - NDArray (input) + * 1 - 0D Tensor - threshold + * + * + * output: + * 0 - NDArray with the same shape as input and type uint8 + */ +// #if NOT_EXCLUDED(OP_compare_and_bitpack) + @Namespace("nd4j::ops") public static class compare_and_bitpack extends DeclarableCustomOp { + static { Loader.load(); } + /** Pointer cast constructor. Invokes {@link Pointer#Pointer(Pointer)}. */ + public compare_and_bitpack(Pointer p) { super(p); } + /** Native array allocator. Access with {@link Pointer#position(long)}. */ + public compare_and_bitpack(long size) { super((Pointer)null); allocateArray(size); } + private native void allocateArray(long size); + @Override public compare_and_bitpack position(long position) { + return (compare_and_bitpack)super.position(position); + } + + public compare_and_bitpack() { super((Pointer)null); allocate(); } + private native void allocate(); + public native ShapeList calculateOutputShape(ShapeList inputShape, @ByRef Context block); + } +// #endif @@ -23102,6 +23358,28 @@ public static final int TAD_THRESHOLD = TAD_THRESHOLD(); public native ShapeList calculateOutputShape(ShapeList inputShape, @ByRef Context block); } // #endif + /** + * This operation change type of input and modified shape of output to conform with given data type + * + * all as above op + * */ +// #if NOT_EXCLUDED(OP_bitcast) + @Namespace("nd4j::ops") public static class bitcast extends DeclarableCustomOp { + static { Loader.load(); } + /** Pointer cast constructor. Invokes {@link Pointer#Pointer(Pointer)}. */ + public bitcast(Pointer p) { super(p); } + /** Native array allocator. Access with {@link Pointer#position(long)}. */ + public bitcast(long size) { super((Pointer)null); allocateArray(size); } + private native void allocateArray(long size); + @Override public bitcast position(long position) { + return (bitcast)super.position(position); + } + + public bitcast() { super((Pointer)null); allocate(); } + private native void allocate(); + public native ShapeList calculateOutputShape(ShapeList inputShape, @ByRef Context block); + } +// #endif diff --git a/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/linalg/custom/CustomOpsTests.java b/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/linalg/custom/CustomOpsTests.java index d93c934f0..ad38f39d7 100644 --- a/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/linalg/custom/CustomOpsTests.java +++ b/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/linalg/custom/CustomOpsTests.java @@ -810,7 +810,7 @@ public class CustomOpsTests extends BaseNd4jTest { @Test public void testAdjustContrast() { INDArray in = Nd4j.linspace(DataType.DOUBLE, 1.0, 1.0, 4*4*3).reshape(4,4,3); - INDArray out = Nd4j.zeros(4,4,3); + INDArray out = Nd4j.zeros(DataType.DOUBLE,4, 4, 3); INDArray expected = Nd4j.createFromArray(new double[]{-21.5, -20.5, -19.5, -15.5, -14.5, -13.5, -9.5, -8.5, -7.5, -3.5, -2.5, -1.5, 2.5, 3.5, 4.5, 8.5, 9.5, 10.5, 14.5, 15.5, 16.5, 20.5, 21.5, 22.5, @@ -920,4 +920,15 @@ public class CustomOpsTests extends BaseNd4jTest { assertEquals(expected, output); } + + @Test + public void testKnnMinDistance() { + INDArray point = Nd4j.rand(DataType.FLOAT, 1, 20); + INDArray lowest = Nd4j.rand(DataType.FLOAT, 1, 20); + INDArray highest = Nd4j.rand(DataType.FLOAT, 1, 20); + INDArray distance = Nd4j.scalar(0.f); + + Nd4j.exec(new KnnMinDistance(point, lowest, highest, distance)); + System.out.println(distance); + } }