From c4307384f35b9e21b31fc6334c86fa8e10ed70d2 Mon Sep 17 00:00:00 2001 From: Alexander Stoyakin Date: Wed, 16 Oct 2019 12:59:25 +0300 Subject: [PATCH] Fixed shape for muli --- .../models/word2vec/Word2VecTests.java | 32 +++++++++++++++++++ .../reader/impl/BasicModelUtils.java | 3 +- 2 files changed, 34 insertions(+), 1 deletion(-) diff --git a/deeplearning4j/deeplearning4j-nlp-parent/deeplearning4j-nlp-uima/src/test/java/org/deeplearning4j/models/word2vec/Word2VecTests.java b/deeplearning4j/deeplearning4j-nlp-parent/deeplearning4j-nlp-uima/src/test/java/org/deeplearning4j/models/word2vec/Word2VecTests.java index 01b38a644..736998484 100755 --- a/deeplearning4j/deeplearning4j-nlp-parent/deeplearning4j-nlp-uima/src/test/java/org/deeplearning4j/models/word2vec/Word2VecTests.java +++ b/deeplearning4j/deeplearning4j-nlp-parent/deeplearning4j-nlp-uima/src/test/java/org/deeplearning4j/models/word2vec/Word2VecTests.java @@ -50,6 +50,7 @@ import org.slf4j.Logger; import org.slf4j.LoggerFactory; import java.io.File; +import java.io.IOException; import java.util.*; import static org.junit.Assert.*; @@ -816,6 +817,37 @@ public class Word2VecTests extends BaseDL4JTest { assertEquals(vec1.getWordVectorMatrix("money"), vec2.getWordVectorMatrix("money")); } + @Test + public void testWordsNearestSum() throws IOException { + log.info("Load & Vectorize Sentences...."); + SentenceIterator iter = new BasicLineIterator(inputFile); + TokenizerFactory t = new DefaultTokenizerFactory(); + t.setTokenPreProcessor(new CommonPreprocessor()); + + log.info("Building model...."); + Word2Vec vec = new Word2Vec.Builder() + .minWordFrequency(5) + .iterations(1) + .layerSize(100) + .seed(42) + .windowSize(5) + .iterate(iter) + .tokenizerFactory(t) + .build(); + + log.info("Fitting Word2Vec model...."); + vec.fit(); + log.info("Writing word vectors to text file...."); + log.info("Closest Words:"); + Collection lst = vec.wordsNearestSum("day", 10); + log.info("10 Words closest to 'day': {}", lst); + assertTrue(lst.contains("week")); + assertTrue(lst.contains("night")); + assertTrue(lst.contains("year")); + assertTrue(lst.contains("years")); + assertTrue(lst.contains("time")); + } + private static void printWords(String target, Collection list, Word2Vec vec) { System.out.println("Words close to [" + target + "]:"); for (String word : list) { diff --git a/deeplearning4j/deeplearning4j-nlp-parent/deeplearning4j-nlp/src/main/java/org/deeplearning4j/models/embeddings/reader/impl/BasicModelUtils.java b/deeplearning4j/deeplearning4j-nlp-parent/deeplearning4j-nlp/src/main/java/org/deeplearning4j/models/embeddings/reader/impl/BasicModelUtils.java index 84fc17b7e..4912a3c47 100644 --- a/deeplearning4j/deeplearning4j-nlp-parent/deeplearning4j-nlp/src/main/java/org/deeplearning4j/models/embeddings/reader/impl/BasicModelUtils.java +++ b/deeplearning4j/deeplearning4j-nlp-parent/deeplearning4j-nlp/src/main/java/org/deeplearning4j/models/embeddings/reader/impl/BasicModelUtils.java @@ -351,7 +351,8 @@ public class BasicModelUtils implements ModelUtils if (lookupTable instanceof InMemoryLookupTable) { InMemoryLookupTable l = (InMemoryLookupTable) lookupTable; INDArray syn0 = l.getSyn0(); - INDArray weights = syn0.norm2(0).rdivi(1).muli(words); + INDArray temp = syn0.norm2(0).rdivi(1).reshape(words.shape()); + INDArray weights = temp.muli(words); INDArray distances = syn0.mulRowVector(weights).sum(1); INDArray[] sorted = Nd4j.sortWithIndices(distances, 0, false); INDArray sort = sorted[0];