diff --git a/datavec/datavec-spark-inference-parent/datavec-spark-inference-client/pom.xml b/datavec/datavec-spark-inference-parent/datavec-spark-inference-client/pom.xml
index db110703b..076c22ab9 100644
--- a/datavec/datavec-spark-inference-parent/datavec-spark-inference-client/pom.xml
+++ b/datavec/datavec-spark-inference-parent/datavec-spark-inference-client/pom.xml
@@ -38,7 +38,7 @@
org.datavec
datavec-spark-inference-server_2.11
- 1.0.0_spark_2-SNAPSHOT
+ 1.0.0-SNAPSHOT
test
diff --git a/datavec/datavec-spark-inference-parent/datavec-spark-inference-server/pom.xml b/datavec/datavec-spark-inference-parent/datavec-spark-inference-server/pom.xml
index 605b13b70..8bef216a7 100644
--- a/datavec/datavec-spark-inference-parent/datavec-spark-inference-server/pom.xml
+++ b/datavec/datavec-spark-inference-parent/datavec-spark-inference-server/pom.xml
@@ -25,7 +25,7 @@
datavec-spark-inference-server_2.11
jar
- 1.0.0_spark_2-SNAPSHOT
+ 1.0.0-SNAPSHOT
datavec-spark-inference-server
diff --git a/datavec/datavec-spark/pom.xml b/datavec/datavec-spark/pom.xml
index 05c505cac..f7143c6ea 100644
--- a/datavec/datavec-spark/pom.xml
+++ b/datavec/datavec-spark/pom.xml
@@ -24,7 +24,7 @@
4.0.0
- 1.0.0_spark_2-SNAPSHOT
+ 1.0.0-SNAPSHOT
datavec-spark_2.11
diff --git a/deeplearning4j/deeplearning4j-core/src/test/java/org/deeplearning4j/RandomTests.java b/deeplearning4j/deeplearning4j-core/src/test/java/org/deeplearning4j/RandomTests.java
new file mode 100644
index 000000000..8f727fdf9
--- /dev/null
+++ b/deeplearning4j/deeplearning4j-core/src/test/java/org/deeplearning4j/RandomTests.java
@@ -0,0 +1,63 @@
+package org.deeplearning4j;
+
+import org.deeplearning4j.datasets.iterator.EarlyTerminationDataSetIterator;
+import org.deeplearning4j.datasets.iterator.impl.MnistDataSetIterator;
+import org.deeplearning4j.nn.api.OptimizationAlgorithm;
+import org.deeplearning4j.nn.conf.MultiLayerConfiguration;
+import org.deeplearning4j.nn.conf.NeuralNetConfiguration;
+import org.deeplearning4j.nn.multilayer.MultiLayerNetwork;
+import org.junit.Ignore;
+import org.junit.Test;
+import org.nd4j.linalg.activations.Activation;
+import org.nd4j.linalg.dataset.api.iterator.DataSetIterator;
+import org.nd4j.linalg.factory.Nd4j;
+import org.nd4j.linalg.learning.config.RmsProp;
+import org.nd4j.linalg.lossfunctions.LossFunctions;
+
+import java.util.concurrent.CountDownLatch;
+
+@Ignore
+public class RandomTests {
+
+ @Test
+ public void testReproduce() throws Exception {
+
+ final MultiLayerConfiguration conf = new NeuralNetConfiguration.Builder().updater(new RmsProp())
+ .optimizationAlgo(OptimizationAlgorithm.STOCHASTIC_GRADIENT_DESCENT).list()
+ .layer(0, new org.deeplearning4j.nn.conf.layers.DenseLayer.Builder().nIn(28 * 28).nOut(10)
+ .activation(Activation.TANH).build())
+ .layer(1, new org.deeplearning4j.nn.conf.layers.OutputLayer.Builder(
+ LossFunctions.LossFunction.MCXENT).nIn(10).nOut(10)
+ .activation(Activation.SOFTMAX).build())
+ .build();
+
+ for (int e = 0; e < 3; e++) {
+
+ int nThreads = 10;
+ final CountDownLatch l = new CountDownLatch(nThreads);
+ for (int i = 0; i < nThreads; i++) {
+ final int j = i;
+ Thread t = new Thread(new Runnable() {
+ @Override
+ public void run() {
+ try {
+ MultiLayerNetwork net = new MultiLayerNetwork(conf.clone());
+ net.init();
+ DataSetIterator iter = new EarlyTerminationDataSetIterator(new MnistDataSetIterator(10, false, 12345), 100);
+ net.fit(iter);
+ } catch (Throwable t) {
+ System.out.println("Thread failed: " + j);
+ t.printStackTrace();
+ } finally {
+ l.countDown();
+ }
+ }
+ });
+ t.start();
+ }
+
+ l.await();
+ System.out.println("DONE " + e + "\n");
+ }
+ }
+}
diff --git a/deeplearning4j/deeplearning4j-nlp-parent/deeplearning4j-nlp-uima/src/test/java/org/deeplearning4j/models/WordVectorSerializerTest.java b/deeplearning4j/deeplearning4j-nlp-parent/deeplearning4j-nlp-uima/src/test/java/org/deeplearning4j/models/WordVectorSerializerTest.java
index f4dd1a6c5..69eae7307 100755
--- a/deeplearning4j/deeplearning4j-nlp-parent/deeplearning4j-nlp-uima/src/test/java/org/deeplearning4j/models/WordVectorSerializerTest.java
+++ b/deeplearning4j/deeplearning4j-nlp-parent/deeplearning4j-nlp-uima/src/test/java/org/deeplearning4j/models/WordVectorSerializerTest.java
@@ -833,14 +833,14 @@ public class WordVectorSerializerTest extends BaseDL4JTest {
public void testB64_1() throws Exception {
String wordA = "night";
String wordB = "night day";
- String encA = WordVectorSerializer.encodeB64(wordA);
- String encB = WordVectorSerializer.encodeB64(wordB);
+ String encA = WordVectorSerializer.ReadHelper.encodeB64(wordA);
+ String encB = WordVectorSerializer.ReadHelper.encodeB64(wordB);
- assertEquals(wordA, WordVectorSerializer.decodeB64(encA));
- assertEquals(wordB, WordVectorSerializer.decodeB64(encB));
+ assertEquals(wordA, WordVectorSerializer.ReadHelper.decodeB64(encA));
+ assertEquals(wordB, WordVectorSerializer.ReadHelper.decodeB64(encB));
- assertEquals(wordA, WordVectorSerializer.decodeB64(wordA));
- assertEquals(wordB, WordVectorSerializer.decodeB64(wordB));
+ assertEquals(wordA, WordVectorSerializer.ReadHelper.decodeB64(wordA));
+ assertEquals(wordB, WordVectorSerializer.ReadHelper.decodeB64(wordB));
}
diff --git a/deeplearning4j/deeplearning4j-nlp-parent/deeplearning4j-nlp/src/main/java/org/deeplearning4j/models/embeddings/loader/WordVectorSerializer.java b/deeplearning4j/deeplearning4j-nlp-parent/deeplearning4j-nlp/src/main/java/org/deeplearning4j/models/embeddings/loader/WordVectorSerializer.java
index cce6a740a..80ce0bf34 100755
--- a/deeplearning4j/deeplearning4j-nlp-parent/deeplearning4j-nlp/src/main/java/org/deeplearning4j/models/embeddings/loader/WordVectorSerializer.java
+++ b/deeplearning4j/deeplearning4j-nlp-parent/deeplearning4j-nlp/src/main/java/org/deeplearning4j/models/embeddings/loader/WordVectorSerializer.java
@@ -24,7 +24,6 @@ import org.apache.commons.io.FileUtils;
import org.apache.commons.io.IOUtils;
import org.apache.commons.io.LineIterator;
import org.apache.commons.io.output.CloseShieldOutputStream;
-import org.deeplearning4j.exception.DL4JException;
import org.deeplearning4j.exception.DL4JInvalidInputException;
import org.deeplearning4j.models.embeddings.WeightLookupTable;
import org.deeplearning4j.models.embeddings.inmemory.InMemoryLookupTable;
@@ -52,7 +51,6 @@ import org.deeplearning4j.text.sentenceiterator.BasicLineIterator;
import org.deeplearning4j.text.tokenization.tokenizer.TokenPreProcess;
import org.deeplearning4j.text.tokenization.tokenizerfactory.TokenizerFactory;
import org.deeplearning4j.util.DL4JFileUtils;
-import org.nd4j.base.Preconditions;
import org.nd4j.compression.impl.NoOp;
import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.exception.ND4JIllegalStateException;
@@ -68,8 +66,6 @@ import org.nd4j.util.OneTimeLogger;
import java.io.*;
import java.nio.charset.StandardCharsets;
-import java.nio.file.Files;
-import java.nio.file.Paths;
import java.util.ArrayList;
import java.util.List;
import java.util.concurrent.atomic.AtomicInteger;
@@ -78,6 +74,80 @@ import java.util.zip.*;
/**
* This is utility class, providing various methods for WordVectors serialization
*
+ * List of available serialization methods (please keep this list consistent with source code):
+ *
+ *
+ * - Serializers for Word2Vec:
+ * {@link #writeWordVectors(WeightLookupTable, File)}
+ * {@link #writeWordVectors(WeightLookupTable, OutputStream)}
+ * {@link #writeWord2VecModel(Word2Vec, File)}
+ * {@link #writeWord2VecModel(Word2Vec, String)}
+ * {@link #writeWord2VecModel(Word2Vec, OutputStream)}
+ *
+ * - Deserializers for Word2Vec:
+ * {@link #readWord2VecModel(File)}
+ * {@link #readWord2VecModel(String)}
+ * {@link #readWord2VecModel(File, boolean)}
+ * {@link #readWord2VecModel(String, boolean)}
+ * {@link #readAsBinaryNoLineBreaks(File)}
+ * {@link #readAsBinary(File)}
+ * {@link #readAsCsv(File)}
+ * {@link #readBinaryModel(File, boolean, boolean)}
+ * {@link #readWord2VecFromText(File, File, File, File, VectorsConfiguration)}
+ * {@link #readWord2Vec(String, boolean)}
+ * {@link #readWord2Vec(File, boolean)}
+ * {@link #readWord2Vec(InputStream, boolean)}
+ *
+ * - Serializers for ParaVec:
+ * {@link #writeParagraphVectors(ParagraphVectors, File)}
+ * {@link #writeParagraphVectors(ParagraphVectors, String)}
+ * {@link #writeParagraphVectors(ParagraphVectors, OutputStream)}
+ *
+ * - Deserializers for ParaVec:
+ * {@link #readParagraphVectors(File)}
+ * {@link #readParagraphVectors(String)}
+ * {@link #readParagraphVectors(InputStream)}
+ *
+ * - Serializers for GloVe:
+ * {@link #writeWordVectors(Glove, File)}
+ * {@link #writeWordVectors(Glove, String)}
+ * {@link #writeWordVectors(Glove, OutputStream)}
+ *
+ * - Adapters
+ * {@link #fromTableAndVocab(WeightLookupTable, VocabCache)}
+ * {@link #fromPair(Pair)}
+ * {@link #loadTxt(File)}
+ *
+ * - Serializers to tSNE format
+ * {@link #writeTsneFormat(Glove, INDArray, File)}
+ * {@link #writeTsneFormat(Word2Vec, INDArray, File)}
+ *
+ * - FastText serializer:
+ * {@link #writeWordVectors(FastText, File)}
+ *
+ * - FastText deserializer:
+ * {@link #readWordVectors(File)}
+ *
+ * - SequenceVectors serializers:
+ * {@link #writeSequenceVectors(SequenceVectors, OutputStream)}
+ * {@link #writeSequenceVectors(SequenceVectors, SequenceElementFactory, File)}
+ * {@link #writeSequenceVectors(SequenceVectors, SequenceElementFactory, String)}
+ * {@link #writeSequenceVectors(SequenceVectors, SequenceElementFactory, OutputStream)}
+ * {@link #writeLookupTable(WeightLookupTable, File)}
+ * {@link #writeVocabCache(VocabCache, File)}
+ * {@link #writeVocabCache(VocabCache, OutputStream)}
+ *
+ * - SequenceVectors deserializers:
+ * {@link #readSequenceVectors(File, boolean)}
+ * {@link #readSequenceVectors(String, boolean)}
+ * {@link #readSequenceVectors(SequenceElementFactory, File)}
+ * {@link #readSequenceVectors(InputStream, boolean)}
+ * {@link #readSequenceVectors(SequenceElementFactory, InputStream)}
+ * {@link #readLookupTable(File)}
+ * {@link #readLookupTable(InputStream)}
+ *
+ *
+ *
* @author Adam Gibson
* @author raver119
* @author alexander@skymind.io
@@ -97,7 +167,7 @@ public class WordVectorSerializer {
* @throws IOException
* @throws NumberFormatException
*/
- private static Word2Vec readTextModel(File modelFile) throws IOException, NumberFormatException {
+ /*private static Word2Vec readTextModel(File modelFile) throws IOException, NumberFormatException {
InMemoryLookupTable lookupTable;
VocabCache cache;
INDArray syn0;
@@ -142,7 +212,7 @@ public class WordVectorSerializer {
ret.setLookupTable(lookupTable);
}
return ret;
- }
+ }*/
/**
* Read a binary word2vec file.
@@ -173,8 +243,8 @@ public class WordVectorSerializer {
try (BufferedInputStream bis = new BufferedInputStream(GzipUtils.isCompressedFilename(modelFile.getName())
? new GZIPInputStream(new FileInputStream(modelFile)) : new FileInputStream(modelFile));
DataInputStream dis = new DataInputStream(bis)) {
- words = Integer.parseInt(readString(dis));
- size = Integer.parseInt(readString(dis));
+ words = Integer.parseInt(ReadHelper.readString(dis));
+ size = Integer.parseInt(ReadHelper.readString(dis));
syn0 = Nd4j.create(words, size);
cache = new AbstractCache<>();
@@ -188,11 +258,11 @@ public class WordVectorSerializer {
float[] vector = new float[size];
for (int i = 0; i < words; i++) {
- word = readString(dis);
+ word = ReadHelper.readString(dis);
log.trace("Loading " + word + " with word " + i);
for (int j = 0; j < size; j++) {
- vector[j] = readFloat(dis);
+ vector[j] = ReadHelper.readFloat(dis);
}
if (cache.containsWord(word))
@@ -236,64 +306,6 @@ public class WordVectorSerializer {
}
- /**
- * Read a float from a data input stream Credit to:
- * https://github.com/NLPchina/Word2VEC_java/blob/master/src/com/ansj/vec/Word2VEC.java
- *
- * @param is
- * @return
- * @throws IOException
- */
- public static float readFloat(InputStream is) throws IOException {
- byte[] bytes = new byte[4];
- is.read(bytes);
- return getFloat(bytes);
- }
-
- /**
- * Read a string from a data input stream Credit to:
- * https://github.com/NLPchina/Word2VEC_java/blob/master/src/com/ansj/vec/Word2VEC.java
- *
- * @param b
- * @return
- * @throws IOException
- */
- public static float getFloat(byte[] b) {
- int accum = 0;
- accum = accum | (b[0] & 0xff) << 0;
- accum = accum | (b[1] & 0xff) << 8;
- accum = accum | (b[2] & 0xff) << 16;
- accum = accum | (b[3] & 0xff) << 24;
- return Float.intBitsToFloat(accum);
- }
-
- /**
- * Read a string from a data input stream Credit to:
- * https://github.com/NLPchina/Word2VEC_java/blob/master/src/com/ansj/vec/Word2VEC.java
- *
- * @param dis
- * @return
- * @throws IOException
- */
- public static String readString(DataInputStream dis) throws IOException {
- byte[] bytes = new byte[MAX_SIZE];
- byte b = dis.readByte();
- int i = -1;
- StringBuilder sb = new StringBuilder();
- while (b != 32 && b != 10) {
- i++;
- bytes[i] = b;
- b = dis.readByte();
- if (i == 49) {
- sb.append(new String(bytes, "UTF-8"));
- i = -1;
- bytes = new byte[MAX_SIZE];
- }
- }
- sb.append(new String(bytes, 0, i + 1, "UTF-8"));
- return sb.toString();
- }
-
/**
* This method writes word vectors to the given path.
* Please note: this method doesn't load whole vocab/lookupTable into memory, so it's able to process large vocabularies served over network.
@@ -355,7 +367,7 @@ public class WordVectorSerializer {
val builder = new StringBuilder();
val l = element.getLabel();
- builder.append(encodeB64(l)).append(" ");
+ builder.append(ReadHelper.encodeB64(l)).append(" ");
val vec = lookupTable.vector(element.getLabel());
for (int i = 0; i < vec.length(); i++) {
builder.append(vec.getDouble(i));
@@ -518,7 +530,7 @@ public class WordVectorSerializer {
try (PrintWriter writer = new PrintWriter(new FileWriter(tempFileCodes))) {
for (int i = 0; i < vectors.getVocab().numWords(); i++) {
VocabWord word = vectors.getVocab().elementAtIndex(i);
- StringBuilder builder = new StringBuilder(encodeB64(word.getLabel())).append(" ");
+ StringBuilder builder = new StringBuilder(ReadHelper.encodeB64(word.getLabel())).append(" ");
for (int code : word.getCodes()) {
builder.append(code).append(" ");
}
@@ -536,7 +548,7 @@ public class WordVectorSerializer {
try (PrintWriter writer = new PrintWriter(new FileWriter(tempFileHuffman))) {
for (int i = 0; i < vectors.getVocab().numWords(); i++) {
VocabWord word = vectors.getVocab().elementAtIndex(i);
- StringBuilder builder = new StringBuilder(encodeB64(word.getLabel())).append(" ");
+ StringBuilder builder = new StringBuilder(ReadHelper.encodeB64(word.getLabel())).append(" ");
for (int point : word.getPoints()) {
builder.append(point).append(" ");
}
@@ -554,7 +566,7 @@ public class WordVectorSerializer {
try (PrintWriter writer = new PrintWriter(new FileWriter(tempFileFreqs))) {
for (int i = 0; i < vectors.getVocab().numWords(); i++) {
VocabWord word = vectors.getVocab().elementAtIndex(i);
- StringBuilder builder = new StringBuilder(encodeB64(word.getLabel())).append(" ")
+ StringBuilder builder = new StringBuilder(ReadHelper.encodeB64(word.getLabel())).append(" ")
.append(word.getElementFrequency()).append(" ")
.append(vectors.getVocab().docAppearedIn(word.getLabel()));
@@ -638,7 +650,7 @@ public class WordVectorSerializer {
try (PrintWriter writer = new PrintWriter(new FileWriter(tempFileCodes))) {
for (int i = 0; i < vectors.getVocab().numWords(); i++) {
VocabWord word = vectors.getVocab().elementAtIndex(i);
- StringBuilder builder = new StringBuilder(encodeB64(word.getLabel())).append(" ");
+ StringBuilder builder = new StringBuilder(ReadHelper.encodeB64(word.getLabel())).append(" ");
for (int code : word.getCodes()) {
builder.append(code).append(" ");
}
@@ -656,7 +668,7 @@ public class WordVectorSerializer {
try (PrintWriter writer = new PrintWriter(new FileWriter(tempFileHuffman))) {
for (int i = 0; i < vectors.getVocab().numWords(); i++) {
VocabWord word = vectors.getVocab().elementAtIndex(i);
- StringBuilder builder = new StringBuilder(encodeB64(word.getLabel())).append(" ");
+ StringBuilder builder = new StringBuilder(ReadHelper.encodeB64(word.getLabel())).append(" ");
for (int point : word.getPoints()) {
builder.append(point).append(" ");
}
@@ -677,7 +689,7 @@ public class WordVectorSerializer {
StringBuilder builder = new StringBuilder();
for (VocabWord word : vectors.getVocab().tokens()) {
if (word.isLabel())
- builder.append(encodeB64(word.getLabel())).append("\n");
+ builder.append(ReadHelper.encodeB64(word.getLabel())).append("\n");
}
IOUtils.write(builder.toString().trim(), zipfile, StandardCharsets.UTF_8);
@@ -688,7 +700,7 @@ public class WordVectorSerializer {
try (PrintWriter writer = new PrintWriter(new FileWriter(tempFileFreqs))) {
for (int i = 0; i < vectors.getVocab().numWords(); i++) {
VocabWord word = vectors.getVocab().elementAtIndex(i);
- builder = new StringBuilder(encodeB64(word.getLabel())).append(" ").append(word.getElementFrequency())
+ builder = new StringBuilder(ReadHelper.encodeB64(word.getLabel())).append(" ").append(word.getElementFrequency())
.append(" ").append(vectors.getVocab().docAppearedIn(word.getLabel()));
writer.println(builder.toString().trim());
@@ -744,7 +756,7 @@ public class WordVectorSerializer {
try (BufferedReader reader = new BufferedReader(new InputStreamReader(stream, StandardCharsets.UTF_8))) {
String line;
while ((line = reader.readLine()) != null) {
- VocabWord word = vectors.getVocab().tokenFor(decodeB64(line.trim()));
+ VocabWord word = vectors.getVocab().tokenFor(ReadHelper.decodeB64(line.trim()));
if (word != null) {
word.markAsLabel(true);
}
@@ -836,7 +848,7 @@ public class WordVectorSerializer {
String line;
while ((line = reader.readLine()) != null) {
String[] split = line.split(" ");
- VocabWord word = w2v.getVocab().tokenFor(decodeB64(split[0]));
+ VocabWord word = w2v.getVocab().tokenFor(ReadHelper.decodeB64(split[0]));
word.setElementFrequency((long) Double.parseDouble(split[1]));
word.setSequencesCount((long) Double.parseDouble(split[2]));
}
@@ -946,7 +958,7 @@ public class WordVectorSerializer {
reader = new BufferedReader(new FileReader(h_points));
while ((line = reader.readLine()) != null) {
String[] split = line.split(" ");
- VocabWord word = vocab.wordFor(decodeB64(split[0]));
+ VocabWord word = vocab.wordFor(ReadHelper.decodeB64(split[0]));
List points = new ArrayList<>();
for (int i = 1; i < split.length; i++) {
points.add(Integer.parseInt(split[i]));
@@ -960,7 +972,7 @@ public class WordVectorSerializer {
reader = new BufferedReader(new FileReader(h_codes));
while ((line = reader.readLine()) != null) {
String[] split = line.split(" ");
- VocabWord word = vocab.wordFor(decodeB64(split[0]));
+ VocabWord word = vocab.wordFor(ReadHelper.decodeB64(split[0]));
List codes = new ArrayList<>();
for (int i = 1; i < split.length; i++) {
codes.add(Byte.parseByte(split[i]));
@@ -1704,7 +1716,7 @@ public class WordVectorSerializer {
if (line.isEmpty())
line = iter.nextLine();
String[] split = line.split(" ");
- String word = decodeB64(split[0]); //split[0].replaceAll(whitespaceReplacement, " ");
+ String word = ReadHelper.decodeB64(split[0]); //split[0].replaceAll(whitespaceReplacement, " ");
VocabWord word1 = new VocabWord(1.0, word);
word1.setIndex(cache.numWords());
@@ -1994,7 +2006,13 @@ public class WordVectorSerializer {
private static final String SYN1_ENTRY = "syn1.bin";
private static final String SYN1_NEG_ENTRY = "syn1neg.bin";
-
+ /**
+ * This method saves specified SequenceVectors model to target OutputStream
+ *
+ * @param vectors SequenceVectors model
+ * @param stream Target output stream
+ * @param
+ */
public static void writeSequenceVectors(@NonNull SequenceVectors vectors,
@NonNull OutputStream stream)
throws IOException {
@@ -2040,7 +2058,13 @@ public class WordVectorSerializer {
}
}
-
+ /**
+ * This method loads SequenceVectors from specified file path
+ *
+ * @param path String
+ * @param readExtendedTables boolean
+ * @param
+ */
public static SequenceVectors readSequenceVectors(@NonNull String path,
boolean readExtendedTables)
throws IOException {
@@ -2050,6 +2074,14 @@ public class WordVectorSerializer {
return vectors;
}
+ /**
+ * This method loads SequenceVectors from specified file path
+ *
+ * @param file File
+ * @param readExtendedTables boolean
+ * @param
+ */
+
public static SequenceVectors readSequenceVectors(@NonNull File file,
boolean readExtendedTables)
throws IOException {
@@ -2058,6 +2090,13 @@ public class WordVectorSerializer {
return vectors;
}
+ /**
+ * This method loads SequenceVectors from specified input stream
+ *
+ * @param stream InputStream
+ * @param readExtendedTables boolean
+ * @param
+ */
public static SequenceVectors readSequenceVectors(@NonNull InputStream stream,
boolean readExtendedTables)
throws IOException {
@@ -2381,6 +2420,12 @@ public class WordVectorSerializer {
}
}
+ /**
+ * This method loads Word2Vec model from binary file
+ *
+ * @param file File
+ * @return Word2Vec
+ */
public static Word2Vec readAsBinary(@NonNull File file) {
boolean originalPeriodic = Nd4j.getMemoryManager().isPeriodicGcActive();
int originalFreq = Nd4j.getMemoryManager().getOccasionalGcFrequency();
@@ -2403,6 +2448,12 @@ public class WordVectorSerializer {
}
}
+ /**
+ * This method loads Word2Vec model from csv file
+ *
+ * @param file File
+ * @return Word2Vec
+ */
public static Word2Vec readAsCsv(@NonNull File file) {
Word2Vec vec;
@@ -2491,7 +2542,7 @@ public class WordVectorSerializer {
String line;
while ((line = reader.readLine()) != null) {
String[] split = line.split(" ");
- VocabWord word = new VocabWord(Double.valueOf(split[1]), decodeB64(split[0]));
+ VocabWord word = new VocabWord(Double.valueOf(split[1]), ReadHelper.decodeB64(split[0]));
word.setIndex(cnt.getAndIncrement());
word.incrementSequencesCount(Long.valueOf(split[2]));
@@ -2669,7 +2720,7 @@ public class WordVectorSerializer {
*
* In return you get StaticWord2Vec model, which might be used as lookup table only in multi-gpu environment.
*
- * @param file File should point to previously saved w2v model
+ * @param inputStream InputStream should point to previously saved w2v model
* @return
*/
public static WordVectors loadStaticModel(InputStream inputStream) throws IOException {
@@ -2685,6 +2736,17 @@ public class WordVectorSerializer {
}
// TODO: this method needs better name :)
+ /**
+ * This method restores previously saved w2v model. File can be in one of the following formats:
+ * 1) Binary model, either compressed or not. Like well-known Google Model
+ * 2) Popular CSV word2vec text format
+ * 3) DL4j compressed format
+ *
+ * In return you get StaticWord2Vec model, which might be used as lookup table only in multi-gpu environment.
+ *
+ * @param file File
+ * @return
+ */
public static WordVectors loadStaticModel(@NonNull File file) {
if (!file.exists() || file.isDirectory())
throw new RuntimeException(
@@ -2843,8 +2905,8 @@ public class WordVectorSerializer {
throw new RuntimeException(e);
}
try {
- numWords = Integer.parseInt(readString(stream));
- vectorLength = Integer.parseInt(readString(stream));
+ numWords = Integer.parseInt(ReadHelper.readString(stream));
+ vectorLength = Integer.parseInt(ReadHelper.readString(stream));
} catch (IOException e) {
throw new RuntimeException(e);
}
@@ -2858,13 +2920,13 @@ public class WordVectorSerializer {
@Override
public Pair next() {
try {
- String word = readString(stream);
+ String word = ReadHelper.readString(stream);
VocabWord element = new VocabWord(1.0, word);
element.setIndex(idxCounter.getAndIncrement());
float[] vector = new float[vectorLength];
for (int i = 0; i < vectorLength; i++) {
- vector[i] = readFloat(stream);
+ vector[i] = ReadHelper.readFloat(stream);
}
return Pair.makePair(element, vector);
@@ -2913,7 +2975,7 @@ public class WordVectorSerializer {
String[] split = nextLine.split(" ");
- VocabWord word = new VocabWord(1.0, decodeB64(split[0]));
+ VocabWord word = new VocabWord(1.0, ReadHelper.decodeB64(split[0]));
word.setIndex(idxCounter.getAndIncrement());
float[] vector = new float[split.length - 1];
@@ -2937,26 +2999,12 @@ public class WordVectorSerializer {
}
}
- public static String encodeB64(String word) {
- try {
- return "B64:" + Base64.encodeBase64String(word.getBytes("UTF-8")).replaceAll("(\r|\n)", "");
- } catch (Exception e) {
- throw new RuntimeException(e);
- }
- }
-
- public static String decodeB64(String word) {
- if (word.startsWith("B64:")) {
- String arp = word.replaceFirst("B64:", "");
- try {
- return new String(Base64.decodeBase64(arp), "UTF-8");
- } catch (Exception e) {
- throw new RuntimeException(e);
- }
- } else
- return word;
- }
-
+ /**
+ * This method saves Word2Vec model to output stream
+ *
+ * @param word2Vec Word2Vec
+ * @param stream OutputStream
+ */
public static void writeWord2Vec(@NonNull Word2Vec word2Vec, @NonNull OutputStream stream)
throws IOException {
@@ -2968,6 +3016,13 @@ public class WordVectorSerializer {
writeSequenceVectors(vectors, stream);
}
+ /**
+ * This method restores Word2Vec model from file
+ *
+ * @param path String
+ * @param readExtendedTables booleab
+ * @return Word2Vec
+ */
public static Word2Vec readWord2Vec(@NonNull String path, boolean readExtendedTables)
throws IOException {
@@ -2976,6 +3031,12 @@ public class WordVectorSerializer {
return word2Vec;
}
+ /**
+ * This method saves table of weights to file
+ *
+ * @param weightLookupTable WeightLookupTable
+ * @param file File
+ */
public static void writeLookupTable(WeightLookupTable weightLookupTable,
@NonNull File file) throws IOException {
try (BufferedWriter writer = new BufferedWriter(new OutputStreamWriter(new FileOutputStream(file),
@@ -3038,7 +3099,7 @@ public class WordVectorSerializer {
headerRead = true;
weightLookupTable = new InMemoryLookupTable.Builder().cache(vocabCache).vectorLength(layerSize).build();
} else {
- String label = decodeB64(tokens[0]);
+ String label = ReadHelper.decodeB64(tokens[0]);
int freq = Integer.parseInt(tokens[1]);
int rows = Integer.parseInt(tokens[2]);
int cols = Integer.parseInt(tokens[3]);
@@ -3071,6 +3132,13 @@ public class WordVectorSerializer {
return weightLookupTable;
}
+ /**
+ * This method loads Word2Vec model from file
+ *
+ * @param file File
+ * @param readExtendedTables boolean
+ * @return Word2Vec
+ */
public static Word2Vec readWord2Vec(@NonNull File file, boolean readExtendedTables)
throws IOException {
@@ -3078,6 +3146,13 @@ public class WordVectorSerializer {
return word2Vec;
}
+ /**
+ * This method loads Word2Vec model from input stream
+ *
+ * @param stream InputStream
+ * @param readExtendedTable boolean
+ * @return Word2Vec
+ */
public static Word2Vec readWord2Vec(@NonNull InputStream stream,
boolean readExtendedTable) throws IOException {
SequenceVectors vectors = readSequenceVectors(stream, readExtendedTable);
@@ -3087,7 +3162,13 @@ public class WordVectorSerializer {
word2Vec.setModelUtils(vectors.getModelUtils());
return word2Vec;
}
-
+
+ /**
+ * This method loads FastText model to file
+ *
+ * @param vectors FastText
+ * @param path File
+ */
public static void writeWordVectors(@NonNull FastText vectors, @NonNull File path) throws IOException {
ObjectOutputStream outputStream = null;
try {
@@ -3106,6 +3187,11 @@ public class WordVectorSerializer {
}
}
+ /**
+ * This method unloads FastText model from file
+ *
+ * @param path File
+ */
public static FastText readWordVectors(File path) {
FastText result = null;
try {
@@ -3124,6 +3210,13 @@ public class WordVectorSerializer {
return result;
}
+ /**
+ * This method prints memory usage to log
+ *
+ * @param numWords
+ * @param vectorLength
+ * @param numTables
+ */
public static void printOutProjectedMemoryUse(long numWords, int vectorLength, int numTables) {
double memSize = numWords * vectorLength * Nd4j.sizeOfDataType() * numTables;
@@ -3144,4 +3237,102 @@ public class WordVectorSerializer {
OneTimeLogger.info(log, "Projected memory use for model: [{} {}]", String.format("%.2f", value), sfx);
}
+
+ /**
+ * Helper static methods to read data from input stream.
+ */
+ public static class ReadHelper {
+ /**
+ * Read a float from a data input stream Credit to:
+ * https://github.com/NLPchina/Word2VEC_java/blob/master/src/com/ansj/vec/Word2VEC.java
+ *
+ * @param is
+ * @return
+ * @throws IOException
+ */
+ private static float readFloat(InputStream is) throws IOException {
+ byte[] bytes = new byte[4];
+ is.read(bytes);
+ return getFloat(bytes);
+ }
+
+ /**
+ * Read a string from a data input stream Credit to:
+ * https://github.com/NLPchina/Word2VEC_java/blob/master/src/com/ansj/vec/Word2VEC.java
+ *
+ * @param b
+ * @return
+ * @throws IOException
+ */
+ private static float getFloat(byte[] b) {
+ int accum = 0;
+ accum = accum | (b[0] & 0xff) << 0;
+ accum = accum | (b[1] & 0xff) << 8;
+ accum = accum | (b[2] & 0xff) << 16;
+ accum = accum | (b[3] & 0xff) << 24;
+ return Float.intBitsToFloat(accum);
+ }
+
+ /**
+ * Read a string from a data input stream Credit to:
+ * https://github.com/NLPchina/Word2VEC_java/blob/master/src/com/ansj/vec/Word2VEC.java
+ *
+ * @param dis
+ * @return
+ * @throws IOException
+ */
+ private static String readString(DataInputStream dis) throws IOException {
+ byte[] bytes = new byte[MAX_SIZE];
+ byte b = dis.readByte();
+ int i = -1;
+ StringBuilder sb = new StringBuilder();
+ while (b != 32 && b != 10) {
+ i++;
+ bytes[i] = b;
+ b = dis.readByte();
+ if (i == 49) {
+ sb.append(new String(bytes, "UTF-8"));
+ i = -1;
+ bytes = new byte[MAX_SIZE];
+ }
+ }
+ sb.append(new String(bytes, 0, i + 1, "UTF-8"));
+ return sb.toString();
+ }
+
+ private static final String B64 = "B64:";
+
+ /**
+ * Encode input string
+ *
+ * @param word String
+ * @return String
+ */
+ public static String encodeB64(String word) {
+ try {
+ return B64 + Base64.encodeBase64String(word.getBytes("UTF-8")).replaceAll("(\r|\n)", "");
+ } catch (Exception e) {
+ throw new RuntimeException(e);
+ }
+ }
+
+ /**
+ * Encode input string
+ *
+ * @param word String
+ * @return String
+ */
+
+ public static String decodeB64(String word) {
+ if (word.startsWith(B64)) {
+ String arp = word.replaceFirst(B64, "");
+ try {
+ return new String(Base64.decodeBase64(arp), "UTF-8");
+ } catch (Exception e) {
+ throw new RuntimeException(e);
+ }
+ } else
+ return word;
+ }
+ }
}
diff --git a/deeplearning4j/deeplearning4j-scaleout/deeplearning4j-aws/pom.xml b/deeplearning4j/deeplearning4j-scaleout/deeplearning4j-aws/pom.xml
index 0b6b05c26..7c9967ef8 100644
--- a/deeplearning4j/deeplearning4j-scaleout/deeplearning4j-aws/pom.xml
+++ b/deeplearning4j/deeplearning4j-scaleout/deeplearning4j-aws/pom.xml
@@ -24,7 +24,7 @@
deeplearning4j-aws_2.11
DeepLearning4j-AWS
- 1.0.0_spark_2-SNAPSHOT
+ 1.0.0-SNAPSHOT
diff --git a/deeplearning4j/deeplearning4j-scaleout/spark/dl4j-spark-nlp-java8/pom.xml b/deeplearning4j/deeplearning4j-scaleout/spark/dl4j-spark-nlp-java8/pom.xml
index 3fded3e4a..8a19b3b68 100644
--- a/deeplearning4j/deeplearning4j-scaleout/spark/dl4j-spark-nlp-java8/pom.xml
+++ b/deeplearning4j/deeplearning4j-scaleout/spark/dl4j-spark-nlp-java8/pom.xml
@@ -18,7 +18,7 @@
spark_2.11
org.deeplearning4j
- 1.0.0_spark_2-SNAPSHOT
+ 1.0.0-SNAPSHOT
4.0.0
diff --git a/deeplearning4j/deeplearning4j-scaleout/spark/dl4j-spark-nlp/pom.xml b/deeplearning4j/deeplearning4j-scaleout/spark/dl4j-spark-nlp/pom.xml
index a5aff014e..16c4ac298 100644
--- a/deeplearning4j/deeplearning4j-scaleout/spark/dl4j-spark-nlp/pom.xml
+++ b/deeplearning4j/deeplearning4j-scaleout/spark/dl4j-spark-nlp/pom.xml
@@ -18,7 +18,7 @@
spark_2.11
org.deeplearning4j
- 1.0.0_spark_2-SNAPSHOT
+ 1.0.0-SNAPSHOT
4.0.0
dl4j-spark-nlp_2.11
diff --git a/deeplearning4j/deeplearning4j-scaleout/spark/dl4j-spark-parameterserver/pom.xml b/deeplearning4j/deeplearning4j-scaleout/spark/dl4j-spark-parameterserver/pom.xml
index d8f425286..9192bb877 100644
--- a/deeplearning4j/deeplearning4j-scaleout/spark/dl4j-spark-parameterserver/pom.xml
+++ b/deeplearning4j/deeplearning4j-scaleout/spark/dl4j-spark-parameterserver/pom.xml
@@ -19,7 +19,7 @@
spark_2.11
org.deeplearning4j
- 1.0.0_spark_2-SNAPSHOT
+ 1.0.0-SNAPSHOT
4.0.0
diff --git a/deeplearning4j/deeplearning4j-scaleout/spark/dl4j-spark/pom.xml b/deeplearning4j/deeplearning4j-scaleout/spark/dl4j-spark/pom.xml
index 8b31872c5..d84947f1e 100644
--- a/deeplearning4j/deeplearning4j-scaleout/spark/dl4j-spark/pom.xml
+++ b/deeplearning4j/deeplearning4j-scaleout/spark/dl4j-spark/pom.xml
@@ -18,7 +18,7 @@
spark_2.11
org.deeplearning4j
- 1.0.0_spark_2-SNAPSHOT
+ 1.0.0-SNAPSHOT
4.0.0
dl4j-spark_2.11
diff --git a/deeplearning4j/deeplearning4j-scaleout/spark/dl4j-spark/src/test/java/org/deeplearning4j/spark/TestKryo.java b/deeplearning4j/deeplearning4j-scaleout/spark/dl4j-spark/src/test/java/org/deeplearning4j/spark/TestKryo.java
index e6688a215..8c5188b70 100644
--- a/deeplearning4j/deeplearning4j-scaleout/spark/dl4j-spark/src/test/java/org/deeplearning4j/spark/TestKryo.java
+++ b/deeplearning4j/deeplearning4j-scaleout/spark/dl4j-spark/src/test/java/org/deeplearning4j/spark/TestKryo.java
@@ -17,7 +17,6 @@
package org.deeplearning4j.spark;
import org.apache.spark.serializer.SerializerInstance;
-import org.deeplearning4j.eval.*;
import org.deeplearning4j.nn.conf.ComputationGraphConfiguration;
import org.deeplearning4j.nn.conf.MultiLayerConfiguration;
import org.deeplearning4j.nn.conf.NeuralNetConfiguration;
@@ -28,6 +27,9 @@ import org.deeplearning4j.nn.conf.graph.rnn.LastTimeStepVertex;
import org.deeplearning4j.nn.conf.layers.*;
import org.deeplearning4j.nn.conf.preprocessor.CnnToFeedForwardPreProcessor;
import org.junit.Test;
+import org.nd4j.evaluation.IEvaluation;
+import org.nd4j.evaluation.classification.*;
+import org.nd4j.evaluation.regression.RegressionEvaluation;
import org.nd4j.linalg.activations.Activation;
import org.nd4j.linalg.factory.Nd4j;
import org.nd4j.linalg.learning.config.Adam;
diff --git a/deeplearning4j/deeplearning4j-scaleout/spark/dl4j-spark/src/test/java/org/deeplearning4j/spark/impl/multilayer/TestSparkDl4jMultiLayer.java b/deeplearning4j/deeplearning4j-scaleout/spark/dl4j-spark/src/test/java/org/deeplearning4j/spark/impl/multilayer/TestSparkDl4jMultiLayer.java
index f8fe1f4f0..ecf9b937b 100644
--- a/deeplearning4j/deeplearning4j-scaleout/spark/dl4j-spark/src/test/java/org/deeplearning4j/spark/impl/multilayer/TestSparkDl4jMultiLayer.java
+++ b/deeplearning4j/deeplearning4j-scaleout/spark/dl4j-spark/src/test/java/org/deeplearning4j/spark/impl/multilayer/TestSparkDl4jMultiLayer.java
@@ -19,7 +19,6 @@ package org.deeplearning4j.spark.impl.multilayer;
import lombok.extern.slf4j.Slf4j;
import org.apache.spark.api.java.JavaRDD;
import org.deeplearning4j.datasets.iterator.impl.MnistDataSetIterator;
-import org.deeplearning4j.eval.Evaluation;
import org.deeplearning4j.nn.api.OptimizationAlgorithm;
import org.deeplearning4j.nn.conf.MultiLayerConfiguration;
import org.deeplearning4j.nn.conf.NeuralNetConfiguration;
@@ -30,6 +29,7 @@ import org.deeplearning4j.spark.BaseSparkTest;
import org.deeplearning4j.spark.api.TrainingMaster;
import org.deeplearning4j.spark.impl.paramavg.ParameterAveragingTrainingMaster;
import org.junit.Test;
+import org.nd4j.evaluation.classification.Evaluation;
import org.nd4j.linalg.activations.Activation;
import org.nd4j.linalg.dataset.DataSet;
import org.nd4j.linalg.dataset.api.iterator.DataSetIterator;
diff --git a/deeplearning4j/deeplearning4j-scaleout/spark/dl4j-spark/src/test/java/org/deeplearning4j/spark/impl/paramavg/TestSparkMultiLayerParameterAveraging.java b/deeplearning4j/deeplearning4j-scaleout/spark/dl4j-spark/src/test/java/org/deeplearning4j/spark/impl/paramavg/TestSparkMultiLayerParameterAveraging.java
index ed56af9ee..abfd39060 100644
--- a/deeplearning4j/deeplearning4j-scaleout/spark/dl4j-spark/src/test/java/org/deeplearning4j/spark/impl/paramavg/TestSparkMultiLayerParameterAveraging.java
+++ b/deeplearning4j/deeplearning4j-scaleout/spark/dl4j-spark/src/test/java/org/deeplearning4j/spark/impl/paramavg/TestSparkMultiLayerParameterAveraging.java
@@ -29,15 +29,13 @@ import org.apache.spark.mllib.regression.LabeledPoint;
import org.apache.spark.mllib.util.MLUtils;
import org.deeplearning4j.datasets.iterator.impl.IrisDataSetIterator;
import org.deeplearning4j.datasets.iterator.impl.MnistDataSetIterator;
-import org.deeplearning4j.eval.Evaluation;
-import org.deeplearning4j.eval.ROC;
-import org.deeplearning4j.eval.ROCMultiClass;
import org.deeplearning4j.nn.api.Layer;
import org.deeplearning4j.nn.api.OptimizationAlgorithm;
import org.deeplearning4j.nn.conf.ComputationGraphConfiguration;
import org.deeplearning4j.nn.conf.MultiLayerConfiguration;
import org.deeplearning4j.nn.conf.NeuralNetConfiguration;
import org.deeplearning4j.nn.conf.layers.BaseLayer;
+import org.deeplearning4j.nn.conf.layers.BatchNormalization;
import org.deeplearning4j.nn.conf.layers.DenseLayer;
import org.deeplearning4j.nn.conf.layers.OutputLayer;
import org.deeplearning4j.nn.conf.layers.variational.GaussianReconstructionDistribution;
@@ -56,6 +54,9 @@ import org.junit.Ignore;
import org.junit.Rule;
import org.junit.Test;
import org.junit.rules.TemporaryFolder;
+import org.nd4j.evaluation.classification.Evaluation;
+import org.nd4j.evaluation.classification.ROC;
+import org.nd4j.evaluation.classification.ROCMultiClass;
import org.nd4j.linalg.activations.Activation;
import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.dataset.DataSet;
@@ -63,6 +64,7 @@ import org.nd4j.linalg.dataset.MultiDataSet;
import org.nd4j.linalg.dataset.api.iterator.DataSetIterator;
import org.nd4j.linalg.factory.Nd4j;
import org.nd4j.linalg.io.ClassPathResource;
+import org.nd4j.linalg.learning.config.Adam;
import org.nd4j.linalg.learning.config.IUpdater;
import org.nd4j.linalg.learning.config.Nesterovs;
import org.nd4j.linalg.learning.config.RmsProp;
@@ -70,7 +72,6 @@ import org.nd4j.linalg.lossfunctions.LossFunctions;
import scala.Tuple2;
import java.io.File;
-import java.nio.file.Files;
import java.nio.file.Path;
import java.util.*;
@@ -121,11 +122,6 @@ public class TestSparkMultiLayerParameterAveraging extends BaseSparkTest {
new ParameterAveragingTrainingMaster(true, numExecutors(), 1, 5, 1, 0));
MultiLayerNetwork network2 = master.fitLabeledPoint(data);
- Evaluation evaluation = new Evaluation();
- evaluation.eval(d.getLabels(), network2.output(d.getFeatures()));
- System.out.println(evaluation.stats());
-
-
}
@@ -137,20 +133,15 @@ public class TestSparkMultiLayerParameterAveraging extends BaseSparkTest {
.getAbsolutePath())
.toJavaRDD().map(new TestFn());
- DataSet d = new IrisDataSetIterator(150, 150).next();
MultiLayerConfiguration conf =
new NeuralNetConfiguration.Builder().seed(123)
- .optimizationAlgo(OptimizationAlgorithm.LINE_GRADIENT_DESCENT)
- .miniBatch(true).maxNumLineSearchIterations(10)
- .list().layer(0,
- new DenseLayer.Builder().nIn(4).nOut(100)
- .weightInit(WeightInit.XAVIER)
- .activation(Activation.RELU)
- .build())
- .layer(1, new org.deeplearning4j.nn.conf.layers.OutputLayer.Builder(
- LossFunctions.LossFunction.MCXENT).nIn(100).nOut(3)
- .activation(Activation.SOFTMAX)
- .weightInit(WeightInit.XAVIER).build())
+ .updater(new Adam(1e-6))
+ .weightInit(WeightInit.XAVIER)
+ .list()
+ .layer(new BatchNormalization.Builder().nIn(4).nOut(4).build())
+ .layer(new DenseLayer.Builder().nIn(4).nOut(32).activation(Activation.RELU).build())
+ .layer(new org.deeplearning4j.nn.conf.layers.OutputLayer.Builder(LossFunctions.LossFunction.MCXENT).nIn(32).nOut(3)
+ .activation(Activation.SOFTMAX).build())
.build();
@@ -161,10 +152,7 @@ public class TestSparkMultiLayerParameterAveraging extends BaseSparkTest {
SparkDl4jMultiLayer master = new SparkDl4jMultiLayer(sc, getBasicConf(),
new ParameterAveragingTrainingMaster(true, numExecutors(), 1, 5, 1, 0));
- MultiLayerNetwork network2 = master.fitLabeledPoint(data);
- Evaluation evaluation = new Evaluation();
- evaluation.eval(d.getLabels(), network2.output(d.getFeatures()));
- System.out.println(evaluation.stats());
+ master.fitLabeledPoint(data);
}
@Test(timeout = 120000L)
@@ -465,8 +453,8 @@ public class TestSparkMultiLayerParameterAveraging extends BaseSparkTest {
tempDirF.deleteOnExit();
int dataSetObjSize = 1;
- int batchSizePerExecutor = 25;
- int numSplits = 10;
+ int batchSizePerExecutor = 16;
+ int numSplits = 5;
int averagingFrequency = 3;
int totalExamples = numExecutors() * batchSizePerExecutor * numSplits * averagingFrequency;
DataSetIterator iter = new MnistDataSetIterator(dataSetObjSize, totalExamples, false);
diff --git a/deeplearning4j/deeplearning4j-scaleout/spark/pom.xml b/deeplearning4j/deeplearning4j-scaleout/spark/pom.xml
index f753fefae..bd7226b0e 100644
--- a/deeplearning4j/deeplearning4j-scaleout/spark/pom.xml
+++ b/deeplearning4j/deeplearning4j-scaleout/spark/pom.xml
@@ -22,7 +22,7 @@
4.0.0
spark_2.11
- 1.0.0_spark_2-SNAPSHOT
+ 1.0.0-SNAPSHOT
pom
Spark parent
@@ -36,7 +36,7 @@
UTF-8
UTF-8
- 1.0.0_spark_2-SNAPSHOT
+ 1.0.0-SNAPSHOT
2.1.0
diff --git a/libnd4j/include/array/ConstantHolder.h b/libnd4j/include/array/ConstantHolder.h
index 89be279e4..137d26f29 100644
--- a/libnd4j/include/array/ConstantHolder.h
+++ b/libnd4j/include/array/ConstantHolder.h
@@ -24,11 +24,13 @@
#include