From 0e8a4f77bc0365afb07f99482b1314536370bf85 Mon Sep 17 00:00:00 2001 From: Fariz Rahman Date: Thu, 5 Dec 2019 17:57:32 +0530 Subject: [PATCH] datavec python ensure host (#113) * ensure host * one more host ensure * info->debug --- .../src/main/java/org/datavec/python/NumpyArray.java | 8 ++++++-- .../main/java/org/datavec/python/PythonExecutioner.java | 2 +- 2 files changed, 7 insertions(+), 3 deletions(-) diff --git a/datavec/datavec-python/src/main/java/org/datavec/python/NumpyArray.java b/datavec/datavec-python/src/main/java/org/datavec/python/NumpyArray.java index ab49cf5ea..24a2c2e09 100644 --- a/datavec/datavec-python/src/main/java/org/datavec/python/NumpyArray.java +++ b/datavec/datavec-python/src/main/java/org/datavec/python/NumpyArray.java @@ -21,6 +21,7 @@ import lombok.Getter; import lombok.NoArgsConstructor; import org.bytedeco.javacpp.Pointer; import org.nd4j.linalg.api.buffer.DataBuffer; +import org.nd4j.linalg.api.concurrency.AffinityManager; import org.nd4j.linalg.api.ndarray.INDArray; import org.nd4j.linalg.api.shape.Shape; import org.nd4j.linalg.factory.Nd4j; @@ -60,6 +61,7 @@ public class NumpyArray { setND4JArray(); if (copy){ nd4jArray = nd4jArray.dup(); + Nd4j.getAffinityManager().ensureLocation(nd4jArray, AffinityManager.Location.HOST); this.address = nd4jArray.data().address(); } @@ -85,6 +87,7 @@ public class NumpyArray { setND4JArray(); if (copy){ nd4jArray = nd4jArray.dup(); + Nd4j.getAffinityManager().ensureLocation(nd4jArray, AffinityManager.Location.HOST); this.address = nd4jArray.data().address(); } } @@ -104,11 +107,12 @@ public class NumpyArray { nd4jStrides[i] = strides[i] / elemSize; } - this.nd4jArray = Nd4j.create(buff, shape, nd4jStrides, 0, Shape.getOrder(shape,nd4jStrides,1), dtype); - + nd4jArray = Nd4j.create(buff, shape, nd4jStrides, 0, Shape.getOrder(shape,nd4jStrides,1), dtype); + Nd4j.getAffinityManager().ensureLocation(nd4jArray, AffinityManager.Location.HOST); } public NumpyArray(INDArray nd4jArray){ + Nd4j.getAffinityManager().ensureLocation(nd4jArray, AffinityManager.Location.HOST); DataBuffer buff = nd4jArray.data(); address = buff.pointer().address(); shape = nd4jArray.shape(); diff --git a/datavec/datavec-python/src/main/java/org/datavec/python/PythonExecutioner.java b/datavec/datavec-python/src/main/java/org/datavec/python/PythonExecutioner.java index c6272e7ad..0f926b9ad 100644 --- a/datavec/datavec-python/src/main/java/org/datavec/python/PythonExecutioner.java +++ b/datavec/datavec-python/src/main/java/org/datavec/python/PythonExecutioner.java @@ -605,7 +605,7 @@ public class PythonExecutioner { private static synchronized void _exec(String code) { - log.info(code); + log.debug(code); log.info("CPython: PyRun_SimpleStringFlag()"); int result = PyRun_SimpleStringFlags(code, null);