diff --git a/deeplearning4j/deeplearning4j-core/src/test/java/org/deeplearning4j/regressiontest/RegressionTest050.java b/deeplearning4j/deeplearning4j-core/src/test/java/org/deeplearning4j/regressiontest/RegressionTest050.java
index da2671817..a5408f100 100644
--- a/deeplearning4j/deeplearning4j-core/src/test/java/org/deeplearning4j/regressiontest/RegressionTest050.java
+++ b/deeplearning4j/deeplearning4j-core/src/test/java/org/deeplearning4j/regressiontest/RegressionTest050.java
@@ -56,8 +56,10 @@ import static org.junit.Assert.*;
*/
public class RegressionTest050 extends BaseDL4JTest {
- @Rule
- public Timeout timeout = Timeout.seconds(300);
+ @Override
+ public long getTimeoutMilliseconds() {
+ return 180000L; //180 seconds; most tests should be fast, but downloads can be slow on poor connections
+ }
@Override
public DataType getDataType(){
diff --git a/deeplearning4j/deeplearning4j-core/src/test/java/org/deeplearning4j/regressiontest/RegressionTest060.java b/deeplearning4j/deeplearning4j-core/src/test/java/org/deeplearning4j/regressiontest/RegressionTest060.java
index 0f6005884..10a15919d 100644
--- a/deeplearning4j/deeplearning4j-core/src/test/java/org/deeplearning4j/regressiontest/RegressionTest060.java
+++ b/deeplearning4j/deeplearning4j-core/src/test/java/org/deeplearning4j/regressiontest/RegressionTest060.java
@@ -64,6 +64,11 @@ public class RegressionTest060 extends BaseDL4JTest {
return DataType.FLOAT;
}
+ @Override
+ public long getTimeoutMilliseconds() {
+ return 180000L; //180 seconds; most tests should be fast, but downloads can be slow on poor connections
+ }
+
@Test
public void regressionTestMLP1() throws Exception {
diff --git a/deeplearning4j/deeplearning4j-core/src/test/java/org/deeplearning4j/regressiontest/RegressionTest071.java b/deeplearning4j/deeplearning4j-core/src/test/java/org/deeplearning4j/regressiontest/RegressionTest071.java
index 76027ca59..3d34b3fd2 100644
--- a/deeplearning4j/deeplearning4j-core/src/test/java/org/deeplearning4j/regressiontest/RegressionTest071.java
+++ b/deeplearning4j/deeplearning4j-core/src/test/java/org/deeplearning4j/regressiontest/RegressionTest071.java
@@ -64,6 +64,12 @@ public class RegressionTest071 extends BaseDL4JTest {
public DataType getDataType(){
return DataType.FLOAT;
}
+
+ @Override
+ public long getTimeoutMilliseconds() {
+ return 180000L; //180 seconds; most tests should be fast, but downloads can be slow on poor connections
+ }
+
@Test
public void regressionTestMLP1() throws Exception {
diff --git a/deeplearning4j/deeplearning4j-core/src/test/java/org/deeplearning4j/regressiontest/RegressionTest080.java b/deeplearning4j/deeplearning4j-core/src/test/java/org/deeplearning4j/regressiontest/RegressionTest080.java
index b4aa72712..f963c2f59 100644
--- a/deeplearning4j/deeplearning4j-core/src/test/java/org/deeplearning4j/regressiontest/RegressionTest080.java
+++ b/deeplearning4j/deeplearning4j-core/src/test/java/org/deeplearning4j/regressiontest/RegressionTest080.java
@@ -64,6 +64,11 @@ public class RegressionTest080 extends BaseDL4JTest {
return DataType.FLOAT;
}
+ @Override
+ public long getTimeoutMilliseconds() {
+ return 180000L; //180 seconds; most tests should be fast, but downloads can be slow on poor connections
+ }
+
@Test
public void regressionTestMLP1() throws Exception {
diff --git a/deeplearning4j/deeplearning4j-core/src/test/java/org/deeplearning4j/regressiontest/RegressionTest100a.java b/deeplearning4j/deeplearning4j-core/src/test/java/org/deeplearning4j/regressiontest/RegressionTest100a.java
index dda910b0e..3214a80ef 100644
--- a/deeplearning4j/deeplearning4j-core/src/test/java/org/deeplearning4j/regressiontest/RegressionTest100a.java
+++ b/deeplearning4j/deeplearning4j-core/src/test/java/org/deeplearning4j/regressiontest/RegressionTest100a.java
@@ -56,7 +56,7 @@ public class RegressionTest100a extends BaseDL4JTest {
@Override
public long getTimeoutMilliseconds() {
- return 90000L; //Most tests should be fast, but slow download may cause timeout on slow connections
+ return 180000L; //180 seconds; most tests should be fast, but downloads can be slow on poor connections
}
@Override
diff --git a/deeplearning4j/deeplearning4j-core/src/test/java/org/deeplearning4j/regressiontest/RegressionTest100b3.java b/deeplearning4j/deeplearning4j-core/src/test/java/org/deeplearning4j/regressiontest/RegressionTest100b3.java
index a96d4cc30..3a1bafd95 100644
--- a/deeplearning4j/deeplearning4j-core/src/test/java/org/deeplearning4j/regressiontest/RegressionTest100b3.java
+++ b/deeplearning4j/deeplearning4j-core/src/test/java/org/deeplearning4j/regressiontest/RegressionTest100b3.java
@@ -52,7 +52,7 @@ public class RegressionTest100b3 extends BaseDL4JTest {
@Override
public long getTimeoutMilliseconds() {
- return 90000L; //Most tests should be fast, but slow download may cause timeout on slow connections
+ return 180000L; //180 seconds; most tests should be fast, but downloads can be slow on poor connections
}
@Override
diff --git a/deeplearning4j/deeplearning4j-core/src/test/java/org/deeplearning4j/regressiontest/RegressionTest100b4.java b/deeplearning4j/deeplearning4j-core/src/test/java/org/deeplearning4j/regressiontest/RegressionTest100b4.java
index 2be2970a7..49dd8f34a 100644
--- a/deeplearning4j/deeplearning4j-core/src/test/java/org/deeplearning4j/regressiontest/RegressionTest100b4.java
+++ b/deeplearning4j/deeplearning4j-core/src/test/java/org/deeplearning4j/regressiontest/RegressionTest100b4.java
@@ -71,7 +71,7 @@ public class RegressionTest100b4 extends BaseDL4JTest {
@Override
public long getTimeoutMilliseconds() {
- return 90000L; //Most tests should be fast, but slow download may cause timeout on slow connections
+ return 180000L; //180 seconds; most tests should be fast, but downloads can be slow on poor connections
}
@Override
diff --git a/deeplearning4j/deeplearning4j-core/src/test/java/org/deeplearning4j/regressiontest/RegressionTest100b6.java b/deeplearning4j/deeplearning4j-core/src/test/java/org/deeplearning4j/regressiontest/RegressionTest100b6.java
index 34bca6cc2..a87a522c7 100644
--- a/deeplearning4j/deeplearning4j-core/src/test/java/org/deeplearning4j/regressiontest/RegressionTest100b6.java
+++ b/deeplearning4j/deeplearning4j-core/src/test/java/org/deeplearning4j/regressiontest/RegressionTest100b6.java
@@ -58,7 +58,7 @@ public class RegressionTest100b6 extends BaseDL4JTest {
@Override
public long getTimeoutMilliseconds() {
- return 90000L; //Most tests should be fast, but slow download may cause timeout on slow connections
+ return 180000L; //180 seconds; most tests should be fast, but downloads can be slow on poor connections
}
@Test
diff --git a/deeplearning4j/deeplearning4j-core/src/test/java/org/deeplearning4j/regressiontest/TestDistributionDeserializer.java b/deeplearning4j/deeplearning4j-core/src/test/java/org/deeplearning4j/regressiontest/TestDistributionDeserializer.java
index 530d8f3c6..9bee792b3 100644
--- a/deeplearning4j/deeplearning4j-core/src/test/java/org/deeplearning4j/regressiontest/TestDistributionDeserializer.java
+++ b/deeplearning4j/deeplearning4j-core/src/test/java/org/deeplearning4j/regressiontest/TestDistributionDeserializer.java
@@ -30,6 +30,11 @@ import static org.junit.Assert.assertTrue;
*/
public class TestDistributionDeserializer extends BaseDL4JTest {
+ @Override
+ public long getTimeoutMilliseconds() {
+ return 180000L; //180 seconds; most tests should be fast, but downloads can be slow on poor connections
+ }
+
@Test
public void testDistributionDeserializer() throws Exception {
//Test current format:
diff --git a/deeplearning4j/deeplearning4j-graph/src/test/java/org/deeplearning4j/graph/models/deepwalk/TestDeepWalk.java b/deeplearning4j/deeplearning4j-graph/src/test/java/org/deeplearning4j/graph/models/deepwalk/TestDeepWalk.java
index 91939430f..d589c9a8a 100644
--- a/deeplearning4j/deeplearning4j-graph/src/test/java/org/deeplearning4j/graph/models/deepwalk/TestDeepWalk.java
+++ b/deeplearning4j/deeplearning4j-graph/src/test/java/org/deeplearning4j/graph/models/deepwalk/TestDeepWalk.java
@@ -46,6 +46,11 @@ public class TestDeepWalk extends BaseDL4JTest {
@Rule
public TemporaryFolder testDir = new TemporaryFolder();
+ @Override
+ public long getTimeoutMilliseconds() {
+ return 120_000L; //120 seconds; increased due to intermittently slow CI machines. NOTE(review): testBasic still declares @Test(timeout = 60000L) - confirm which limit is intended
+ }
+
@Test(timeout = 60000L)
public void testBasic() throws IOException {
//Very basic test. Load graph, build tree, call fit, make sure it doesn't throw any exceptions
diff --git a/deeplearning4j/deeplearning4j-nlp-parent/deeplearning4j-nlp/pom.xml b/deeplearning4j/deeplearning4j-nlp-parent/deeplearning4j-nlp/pom.xml
index 8a7eacada..668c728ae 100644
--- a/deeplearning4j/deeplearning4j-nlp-parent/deeplearning4j-nlp/pom.xml
+++ b/deeplearning4j/deeplearning4j-nlp-parent/deeplearning4j-nlp/pom.xml
@@ -84,12 +84,6 @@
${project.version}
test
-
- org.awaitility
- awaitility
- 4.0.2
- test
-
diff --git a/deeplearning4j/deeplearning4j-nlp-parent/deeplearning4j-nlp/src/main/java/org/deeplearning4j/models/embeddings/learning/ElementsLearningAlgorithm.java b/deeplearning4j/deeplearning4j-nlp-parent/deeplearning4j-nlp/src/main/java/org/deeplearning4j/models/embeddings/learning/ElementsLearningAlgorithm.java
index b684a40a3..f856fc6ea 100644
--- a/deeplearning4j/deeplearning4j-nlp-parent/deeplearning4j-nlp/src/main/java/org/deeplearning4j/models/embeddings/learning/ElementsLearningAlgorithm.java
+++ b/deeplearning4j/deeplearning4j-nlp-parent/deeplearning4j-nlp/src/main/java/org/deeplearning4j/models/embeddings/learning/ElementsLearningAlgorithm.java
@@ -27,7 +27,7 @@ import org.deeplearning4j.models.word2vec.wordstore.VocabCache;
import java.util.concurrent.atomic.AtomicLong;
/**
- * Implementations of this interface should contain element-related learning algorithms. Like skip-gram, cbow or glove
+ * Implementations of this interface should contain element-related learning algorithms, such as skip-gram or CBOW.
*
* @author raver119@gmail.com
*/
diff --git a/deeplearning4j/deeplearning4j-nlp-parent/deeplearning4j-nlp/src/main/java/org/deeplearning4j/models/embeddings/learning/impl/elements/GloVe.java b/deeplearning4j/deeplearning4j-nlp-parent/deeplearning4j-nlp/src/main/java/org/deeplearning4j/models/embeddings/learning/impl/elements/GloVe.java
deleted file mode 100644
index a655cfbde..000000000
--- a/deeplearning4j/deeplearning4j-nlp-parent/deeplearning4j-nlp/src/main/java/org/deeplearning4j/models/embeddings/learning/impl/elements/GloVe.java
+++ /dev/null
@@ -1,427 +0,0 @@
-/*******************************************************************************
- * Copyright (c) 2015-2018 Skymind, Inc.
- *
- * This program and the accompanying materials are made available under the
- * terms of the Apache License, Version 2.0 which is available at
- * https://www.apache.org/licenses/LICENSE-2.0.
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
- * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
- * License for the specific language governing permissions and limitations
- * under the License.
- *
- * SPDX-License-Identifier: Apache-2.0
- ******************************************************************************/
-
-package org.deeplearning4j.models.embeddings.learning.impl.elements;
-
-import lombok.NonNull;
-import org.deeplearning4j.models.embeddings.WeightLookupTable;
-import org.deeplearning4j.models.embeddings.inmemory.InMemoryLookupTable;
-import org.deeplearning4j.models.embeddings.learning.ElementsLearningAlgorithm;
-import org.deeplearning4j.models.embeddings.loader.VectorsConfiguration;
-import org.deeplearning4j.models.glove.AbstractCoOccurrences;
-import org.deeplearning4j.models.sequencevectors.interfaces.SequenceIterator;
-import org.deeplearning4j.models.sequencevectors.sequence.Sequence;
-import org.deeplearning4j.models.sequencevectors.sequence.SequenceElement;
-import org.deeplearning4j.models.word2vec.wordstore.VocabCache;
-import org.nd4j.linalg.api.ndarray.INDArray;
-import org.nd4j.linalg.factory.Nd4j;
-import org.nd4j.linalg.learning.legacy.AdaGrad;
-import org.nd4j.common.primitives.Counter;
-import org.nd4j.common.primitives.Pair;
-import org.slf4j.Logger;
-import org.slf4j.LoggerFactory;
-
-import java.util.ArrayList;
-import java.util.Collections;
-import java.util.Iterator;
-import java.util.List;
-import java.util.concurrent.atomic.AtomicBoolean;
-import java.util.concurrent.atomic.AtomicLong;
-
-/**
- * GloVe LearningAlgorithm implementation for SequenceVectors
- *
- *
- * @author raver119@gmail.com
- */
-public class GloVe implements ElementsLearningAlgorithm {
-
- private VocabCache vocabCache;
- private AbstractCoOccurrences coOccurrences;
- private WeightLookupTable lookupTable;
- private VectorsConfiguration configuration;
-
- private AtomicBoolean isTerminate = new AtomicBoolean(false);
-
- private INDArray syn0;
-
- private double xMax;
- private boolean shuffle;
- private boolean symmetric;
- protected double alpha = 0.75d;
- protected double learningRate = 0.0d;
- protected int maxmemory = 0;
- protected int batchSize = 1000;
-
- private AdaGrad weightAdaGrad;
- private AdaGrad biasAdaGrad;
- private INDArray bias;
-
- private int workers = Runtime.getRuntime().availableProcessors();
-
- private int vectorLength;
-
- private static final Logger log = LoggerFactory.getLogger(GloVe.class);
-
- @Override
- public String getCodeName() {
- return "GloVe";
- }
-
- @Override
- public void finish() {
- log.info("GloVe finalizer...");
- }
-
- @Override
- public void configure(@NonNull VocabCache vocabCache, @NonNull WeightLookupTable lookupTable,
- @NonNull VectorsConfiguration configuration) {
- this.vocabCache = vocabCache;
- this.lookupTable = lookupTable;
- this.configuration = configuration;
-
- this.syn0 = ((InMemoryLookupTable) lookupTable).getSyn0();
-
-
- this.vectorLength = configuration.getLayersSize();
-
- if (this.learningRate == 0.0d)
- this.learningRate = configuration.getLearningRate();
-
-
-
- weightAdaGrad = new AdaGrad(new long[] {this.vocabCache.numWords() + 1, vectorLength}, learningRate);
- bias = Nd4j.create(syn0.rows());
-
- biasAdaGrad = new AdaGrad(bias.shape(), this.learningRate);
-
- // maxmemory = Runtime.getRuntime().maxMemory() - (vocabCache.numWords() * vectorLength * 2 * 8);
-
- log.info("GloVe params: {Max Memory: [" + maxmemory + "], Learning rate: [" + this.learningRate + "], Alpha: ["
- + alpha + "], xMax: [" + xMax + "], Symmetric: [" + symmetric + "], Shuffle: [" + shuffle
- + "]}");
- }
-
- /**
- * pretrain is used to build CoOccurrence matrix for GloVe algorithm
- * @param iterator
- */
- @Override
- public void pretrain(@NonNull SequenceIterator iterator) {
- // CoOccurence table should be built here
- coOccurrences = new AbstractCoOccurrences.Builder()
- // TODO: symmetric should be handled via VectorsConfiguration
- .symmetric(this.symmetric).windowSize(configuration.getWindow()).iterate(iterator)
- .workers(workers).vocabCache(vocabCache).maxMemory(maxmemory).build();
-
- coOccurrences.fit();
- }
-
- public double learnSequence(Sequence sequence, AtomicLong nextRandom, double learningRate,
- BatchSequences batchSequences) {
- throw new UnsupportedOperationException();
- }
- /**
- * Learns sequence using GloVe algorithm
- *
- * @param sequence
- * @param nextRandom
- * @param learningRate
- */
- @Override
- public synchronized double learnSequence(@NonNull Sequence sequence, @NonNull AtomicLong nextRandom,
- double learningRate) {
- /*
- GloVe learning algorithm is implemented like a hack over settled ElementsLearningAlgorithm mechanics. It's called in SequenceVectors context, but actually only for the first call.
- All subsequent calls will met early termination condition, and will be successfully ignored. But since elements vectors will be updated within first call,
- this will allow compatibility with everything beyond this implementaton
- */
- if (isTerminate.get())
- return 0;
-
- final AtomicLong pairsCount = new AtomicLong(0);
- final Counter errorCounter = new Counter<>();
-
- //List> coList = coOccurrences.coOccurrenceList();
-
- for (int i = 0; i < configuration.getEpochs(); i++) {
-
- // TODO: shuffle should be built in another way.
- //if (shuffle)
- //Collections.shuffle(coList);
-
- Iterator, Double>> pairs = coOccurrences.iterator();
-
- List threads = new ArrayList<>();
- for (int x = 0; x < workers; x++) {
- threads.add(x, new GloveCalculationsThread(i, x, pairs, pairsCount, errorCounter));
- threads.get(x).start();
- }
-
-
-
- for (int x = 0; x < workers; x++) {
- try {
- threads.get(x).join();
- } catch (Exception e) {
- throw new RuntimeException(e);
- }
- }
-
- log.info("Processed [" + pairsCount.get() + "] pairs, Error was [" + errorCounter.getCount(i) + "]");
- }
-
- isTerminate.set(true);
- return 0;
- }
-
- /**
- * Since GloVe is learning representations using elements CoOccurences, all training is done in GloVe class internally, so only first thread will execute learning process,
- * and the rest of parent threads will just exit learning process
- *
- * @return True, if training should stop, False otherwise.
- */
- @Override
- public synchronized boolean isEarlyTerminationHit() {
- return isTerminate.get();
- }
-
- private double iterateSample(T element1, T element2, double score) {
- //prediction: input + bias
- if (element1.getIndex() < 0 || element1.getIndex() >= syn0.rows())
- throw new IllegalArgumentException("Illegal index for word " + element1.getLabel());
- if (element2.getIndex() < 0 || element2.getIndex() >= syn0.rows())
- throw new IllegalArgumentException("Illegal index for word " + element2.getLabel());
-
- INDArray w1Vector = syn0.slice(element1.getIndex());
- INDArray w2Vector = syn0.slice(element2.getIndex());
-
-
- //w1 * w2 + bias
- double prediction = Nd4j.getBlasWrapper().dot(w1Vector, w2Vector);
- prediction += bias.getDouble(element1.getIndex()) + bias.getDouble(element2.getIndex()) - Math.log(score);
-
- double fDiff = (score > xMax) ? prediction : Math.pow(score / xMax, alpha) * prediction; // Math.pow(Math.min(1.0,(score / maxCount)),xMax);
-
- // double fDiff = score > xMax ? prediction : weight * (prediction - Math.log(score));
-
- if (Double.isNaN(fDiff))
- fDiff = Nd4j.EPS_THRESHOLD;
- //amount of change
- double gradient = fDiff * learningRate;
-
- //note the update step here: the gradient is
- //the gradient of the OPPOSITE word
- //for adagrad we will use the index of the word passed in
- //for the gradient calculation we will use the context vector
- update(element1, w1Vector, w2Vector, gradient);
- update(element2, w2Vector, w1Vector, gradient);
- return 0.5 * fDiff * prediction;
- }
-
- private void update(T element1, INDArray wordVector, INDArray contextVector, double gradient) {
- //gradient for word vectors
- INDArray grad1 = contextVector.mul(gradient);
- INDArray update = weightAdaGrad.getGradient(grad1, element1.getIndex(), syn0.shape());
-
- //update vector
- wordVector.subi(update);
-
- double w1Bias = bias.getDouble(element1.getIndex());
- double biasGradient = biasAdaGrad.getGradient(gradient, element1.getIndex(), bias.shape());
- double update2 = w1Bias - biasGradient;
- bias.putScalar(element1.getIndex(), update2);
- }
-
- private class GloveCalculationsThread extends Thread implements Runnable {
- private final int threadId;
- private final int epochId;
- // private final AbstractCoOccurrences coOccurrences;
- private final Iterator, Double>> coList;
-
- private final AtomicLong pairsCounter;
- private final Counter errorCounter;
-
- public GloveCalculationsThread(int epochId, int threadId, @NonNull Iterator, Double>> pairs,
- @NonNull AtomicLong pairsCounter, @NonNull Counter errorCounter) {
- this.epochId = epochId;
- this.threadId = threadId;
- // this.coOccurrences = coOccurrences;
-
- this.pairsCounter = pairsCounter;
- this.errorCounter = errorCounter;
-
- coList = pairs;
-
- this.setName("GloVe ELA t." + this.threadId);
- }
-
- @Override
- public void run() {
- // int startPosition = threadId * (coList.size() / workers);
- // int stopPosition = (threadId + 1) * (coList.size() / workers);
- // log.info("Total size: [" + coList.size() + "], thread start: [" + startPosition + "], thread stop: [" + stopPosition + "]");
- while (coList.hasNext()) {
-
- // now we fetch pairs into batch
- List, Double>> pairs = new ArrayList<>();
- int cnt = 0;
- while (coList.hasNext() && cnt < batchSize) {
- pairs.add(coList.next());
- cnt++;
- }
-
- if (shuffle)
- Collections.shuffle(pairs);
-
- Iterator, Double>> iterator = pairs.iterator();
-
- while (iterator.hasNext()) {
- // now for each pair do appropriate training
- Pair, Double> pairDoublePair = iterator.next();
-
- // That's probably ugly and probably should be improved somehow
-
- T element1 = pairDoublePair.getFirst().getFirst();
- T element2 = pairDoublePair.getFirst().getSecond();
- double weight = pairDoublePair.getSecond(); //coOccurrences.getCoOccurrenceCount(element1, element2);
- if (weight <= 0) {
- // log.warn("Skipping pair ("+ element1.getLabel()+", " + element2.getLabel()+")");
- pairsCounter.incrementAndGet();
- continue;
- }
-
- errorCounter.incrementCount(epochId, iterateSample(element1, element2, weight));
- if (pairsCounter.incrementAndGet() % 1000000 == 0) {
- log.info("Processed [" + pairsCounter.get() + "] word pairs so far...");
- }
- }
-
- }
- }
- }
-
- public static class Builder {
-
- protected double xMax = 100.0d;
- protected double alpha = 0.75d;
- protected double learningRate = 0.0d;
-
- protected boolean shuffle = false;
- protected boolean symmetric = false;
- protected int maxmemory = 0;
-
- protected int batchSize = 1000;
-
- public Builder() {
-
- }
-
- /**
- * This parameter specifies batch size for each thread. Also, if shuffle == TRUE, this batch will be shuffled before processing. Default value: 1000;
- *
- * @param batchSize
- * @return
- */
- public Builder batchSize(int batchSize) {
- this.batchSize = batchSize;
- return this;
- }
-
-
- /**
- * Initial learning rate; default 0.05
- *
- * @param eta
- * @return
- */
- public Builder learningRate(double eta) {
- this.learningRate = eta;
- return this;
- }
-
- /**
- * Parameter in exponent of weighting function; default 0.75
- *
- * @param alpha
- * @return
- */
- public Builder alpha(double alpha) {
- this.alpha = alpha;
- return this;
- }
-
- /**
- * This method allows you to specify maximum memory available for CoOccurrence map builder.
- *
- * Please note: this option can be considered a debugging method. In most cases setting proper -Xmx argument set to JVM is enough to limit this algorithm.
- * Please note: this option won't override -Xmx JVM value.
- *
- * @param gbytes memory limit, in gigabytes
- * @return
- */
- public Builder maxMemory(int gbytes) {
- this.maxmemory = gbytes;
- return this;
- }
-
- /**
- * Parameter specifying cutoff in weighting function; default 100.0
- *
- * @param xMax
- * @return
- */
- public Builder xMax(double xMax) {
- this.xMax = xMax;
- return this;
- }
-
- /**
- * Parameter specifying, if cooccurrences list should be shuffled between training epochs
- *
- * @param reallyShuffle
- * @return
- */
- public Builder shuffle(boolean reallyShuffle) {
- this.shuffle = reallyShuffle;
- return this;
- }
-
- /**
- * Parameters specifying, if cooccurrences list should be build into both directions from any current word.
- *
- * @param reallySymmetric
- * @return
- */
- public Builder symmetric(boolean reallySymmetric) {
- this.symmetric = reallySymmetric;
- return this;
- }
-
- public GloVe build() {
- GloVe ret = new GloVe<>();
- ret.symmetric = this.symmetric;
- ret.shuffle = this.shuffle;
- ret.xMax = this.xMax;
- ret.alpha = this.alpha;
- ret.learningRate = this.learningRate;
- ret.maxmemory = this.maxmemory;
- ret.batchSize = this.batchSize;
-
- return ret;
- }
- }
-}
diff --git a/deeplearning4j/deeplearning4j-nlp-parent/deeplearning4j-nlp/src/main/java/org/deeplearning4j/models/embeddings/loader/WordVectorSerializer.java b/deeplearning4j/deeplearning4j-nlp-parent/deeplearning4j-nlp/src/main/java/org/deeplearning4j/models/embeddings/loader/WordVectorSerializer.java
index a77bdf0de..3598bb62a 100755
--- a/deeplearning4j/deeplearning4j-nlp-parent/deeplearning4j-nlp/src/main/java/org/deeplearning4j/models/embeddings/loader/WordVectorSerializer.java
+++ b/deeplearning4j/deeplearning4j-nlp-parent/deeplearning4j-nlp/src/main/java/org/deeplearning4j/models/embeddings/loader/WordVectorSerializer.java
@@ -64,7 +64,6 @@ import org.deeplearning4j.models.embeddings.reader.impl.BasicModelUtils;
import org.deeplearning4j.models.embeddings.wordvectors.WordVectors;
import org.deeplearning4j.models.embeddings.wordvectors.WordVectorsImpl;
import org.deeplearning4j.models.fasttext.FastText;
-import org.deeplearning4j.models.glove.Glove;
import org.deeplearning4j.models.paragraphvectors.ParagraphVectors;
import org.deeplearning4j.models.sequencevectors.SequenceVectors;
import org.deeplearning4j.models.sequencevectors.interfaces.SequenceElementFactory;
@@ -142,10 +141,6 @@ import lombok.val;
* {@link #readParagraphVectors(String)}
* {@link #readParagraphVectors(InputStream)}
*
- * Serializers for GloVe:
- * {@link #writeWordVectors(Glove, File)}
- * {@link #writeWordVectors(Glove, String)}
- * {@link #writeWordVectors(Glove, OutputStream)}
*
* Adapters
* {@link #fromTableAndVocab(WeightLookupTable, VocabCache)}
@@ -154,7 +149,6 @@ import lombok.val;
* {@link #loadTxt(InputStream)}
*
* Serializers to tSNE format
- * {@link #writeTsneFormat(Glove, INDArray, File)}
* {@link #writeTsneFormat(Word2Vec, INDArray, File)}
*
* FastText serializer:
@@ -974,7 +968,7 @@ public class WordVectorSerializer {
public static Word2Vec readWord2VecFromText(@NonNull File vectors, @NonNull File hs, @NonNull File h_codes,
@NonNull File h_points, @NonNull VectorsConfiguration configuration) throws IOException {
// first we load syn0
- Pair pair = loadTxt(new FileInputStream(vectors));
+ Pair pair = loadTxt(new FileInputStream(vectors)); //Note stream is closed in loadTxt
InMemoryLookupTable lookupTable = pair.getFirst();
lookupTable.setNegative(configuration.getNegative());
if (configuration.getNegative() > 0)
@@ -1161,48 +1155,6 @@ public class WordVectorSerializer {
}
}
- /**
- * This method saves GloVe model to the given output stream.
- *
- * @param vectors GloVe model to be saved
- * @param file path where model should be saved to
- */
- public static void writeWordVectors(@NonNull Glove vectors, @NonNull File file) {
- try (BufferedOutputStream fos = new BufferedOutputStream(new FileOutputStream(file))) {
- writeWordVectors(vectors, fos);
- } catch (Exception e) {
- throw new RuntimeException(e);
- }
- }
-
- /**
- * This method saves GloVe model to the given output stream.
- *
- * @param vectors GloVe model to be saved
- * @param path path where model should be saved to
- */
- public static void writeWordVectors(@NonNull Glove vectors, @NonNull String path) {
- try (BufferedOutputStream fos = new BufferedOutputStream(new FileOutputStream(path))) {
- writeWordVectors(vectors, fos);
- } catch (Exception e) {
- throw new RuntimeException(e);
- }
- }
-
- /**
- * This method saves GloVe model to the given OutputStream
- *
- * @param vectors GloVe model to be saved
- * @param stream OutputStream where model should be saved to
- */
- public static void writeWordVectors(@NonNull Glove vectors, @NonNull OutputStream stream) {
- try {
- writeWordVectors(vectors.lookupTable(), stream);
- } catch (Exception e) {
- throw new RuntimeException(e);
- }
- }
-
/**
* This method saves paragraph vectors to the given output stream.
*
@@ -1655,7 +1607,7 @@ public class WordVectorSerializer {
*/
@Deprecated
public static WordVectors loadTxtVectors(File vectorsFile) throws IOException {
- FileInputStream fileInputStream = new FileInputStream(vectorsFile);
+ FileInputStream fileInputStream = new FileInputStream(vectorsFile); //Note stream is closed in loadTxt
Pair pair = loadTxt(fileInputStream);
return fromPair(pair);
}
@@ -1877,43 +1829,6 @@ public class WordVectorSerializer {
return fromPair(Pair.makePair((InMemoryLookupTable) lookupTable, (VocabCache) cache));
}
- /**
- * Write the tsne format
- *
- * @param vec the word vectors to use for labeling
- * @param tsne the tsne array to write
- * @param csv the file to use
- * @throws Exception
- */
- public static void writeTsneFormat(Glove vec, INDArray tsne, File csv) throws Exception {
- try (BufferedWriter write = new BufferedWriter(new OutputStreamWriter(new FileOutputStream(csv), StandardCharsets.UTF_8))) {
- int words = 0;
- InMemoryLookupCache l = (InMemoryLookupCache) vec.vocab();
- for (String word : vec.vocab().words()) {
- if (word == null) {
- continue;
- }
- StringBuilder sb = new StringBuilder();
- INDArray wordVector = tsne.getRow(l.wordFor(word).getIndex());
- for (int j = 0; j < wordVector.length(); j++) {
- sb.append(wordVector.getDouble(j));
- if (j < wordVector.length() - 1) {
- sb.append(",");
- }
- }
- sb.append(",");
- sb.append(word.replaceAll(" ", WHITESPACE_REPLACEMENT));
- sb.append(" ");
-
- sb.append("\n");
- write.write(sb.toString());
-
- }
-
- log.info("Wrote " + words + " with size of " + vec.lookupTable().layerSize());
- }
- }
-
/**
* Write the tsne format
*
diff --git a/deeplearning4j/deeplearning4j-nlp-parent/deeplearning4j-nlp/src/main/java/org/deeplearning4j/models/glove/AbstractCoOccurrences.java b/deeplearning4j/deeplearning4j-nlp-parent/deeplearning4j-nlp/src/main/java/org/deeplearning4j/models/glove/AbstractCoOccurrences.java
deleted file mode 100644
index 969dbaeb9..000000000
--- a/deeplearning4j/deeplearning4j-nlp-parent/deeplearning4j-nlp/src/main/java/org/deeplearning4j/models/glove/AbstractCoOccurrences.java
+++ /dev/null
@@ -1,652 +0,0 @@
-/*******************************************************************************
- * Copyright (c) 2015-2018 Skymind, Inc.
- *
- * This program and the accompanying materials are made available under the
- * terms of the Apache License, Version 2.0 which is available at
- * https://www.apache.org/licenses/LICENSE-2.0.
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
- * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
- * License for the specific language governing permissions and limitations
- * under the License.
- *
- * SPDX-License-Identifier: Apache-2.0
- ******************************************************************************/
-
-package org.deeplearning4j.models.glove;
-
-import lombok.NonNull;
-import org.deeplearning4j.models.glove.count.*;
-import org.deeplearning4j.models.sequencevectors.interfaces.SequenceIterator;
-import org.deeplearning4j.models.sequencevectors.iterators.FilteredSequenceIterator;
-import org.deeplearning4j.models.sequencevectors.iterators.SynchronizedSequenceIterator;
-import org.deeplearning4j.models.sequencevectors.sequence.Sequence;
-import org.deeplearning4j.models.sequencevectors.sequence.SequenceElement;
-import org.deeplearning4j.models.word2vec.wordstore.VocabCache;
-import org.deeplearning4j.text.sentenceiterator.BasicLineIterator;
-import org.deeplearning4j.text.sentenceiterator.PrefetchingSentenceIterator;
-import org.deeplearning4j.text.sentenceiterator.SentenceIterator;
-import org.deeplearning4j.text.sentenceiterator.SynchronizedSentenceIterator;
-import org.deeplearning4j.common.util.DL4JFileUtils;
-import org.nd4j.common.util.ThreadUtils;
-import org.nd4j.linalg.factory.Nd4j;
-import org.nd4j.common.primitives.Pair;
-import org.slf4j.Logger;
-import org.slf4j.LoggerFactory;
-
-import java.io.File;
-import java.io.Serializable;
-import java.util.ArrayList;
-import java.util.Iterator;
-import java.util.List;
-import java.util.concurrent.atomic.AtomicBoolean;
-import java.util.concurrent.atomic.AtomicLong;
-import java.util.concurrent.locks.ReentrantReadWriteLock;
-
-/**
- * This class implements building cooccurrence map for abstract training corpus.
- * However it's performance rather low, due to exsessive IO that happens in ShadowCopyThread
- *
- * PLEASE NOTE: Current implementation involves massive IO, and it should be rewritter as soon as ND4j gets sparse arrays support
- *
- * @author raver119@gmail.com
- */
-public class AbstractCoOccurrences implements Serializable {
-
- protected boolean symmetric;
- protected int windowSize;
- protected VocabCache vocabCache;
- protected SequenceIterator sequenceIterator;
-
- // please note, we need enough room for ShadowCopy thread, that's why -1 there
- protected int workers = Math.max(Runtime.getRuntime().availableProcessors() - 1, 1);
-
- // target file, where text with cooccurrencies should be saved
- protected File targetFile;
-
- protected ReentrantReadWriteLock lock = new ReentrantReadWriteLock();
-
- protected long memory_threshold = 0;
-
- private ShadowCopyThread shadowThread;
-
- // private Counter sentenceOccurrences = Util.parallelCounter();
- //private CounterMap coOccurrenceCounts = Util.parallelCounterMap();
- private volatile CountMap coOccurrenceCounts = new CountMap<>();
- //private Counter occurrenceAllocations = Util.parallelCounter();
- //private List> coOccurrences;
- private AtomicLong processedSequences = new AtomicLong(0);
-
-
- protected static final Logger logger = LoggerFactory.getLogger(AbstractCoOccurrences.class);
-
- // this method should be private, to avoid non-configured instantiation
- private AbstractCoOccurrences() {}
-
- /**
- * This method returns cooccurrence distance weights for two SequenceElements
- *
- * @param element1
- * @param element2
- * @return distance weight
- */
- public double getCoOccurrenceCount(@NonNull T element1, @NonNull T element2) {
- return coOccurrenceCounts.getCount(element1, element2);
- }
-
- /**
- * This method returns estimated memory footrpint, based on current CountMap content
- * @return
- */
- protected long getMemoryFootprint() {
- // TODO: implement this method. It should return approx. memory used by appropriate CountMap
- try {
- lock.readLock().lock();
- return ((long) coOccurrenceCounts.size()) * 24L * 5L;
- } finally {
- lock.readLock().unlock();
- }
- }
-
- /**
- * This memory returns memory threshold, defined as 1/2 of memory allowed for allocation
- * @return
- */
- protected long getMemoryThreshold() {
- return memory_threshold / 2L;
- }
-
- public void fit() {
- shadowThread = new ShadowCopyThread();
- shadowThread.start();
-
- // we should reset iterator before counting cooccurrences
- sequenceIterator.reset();
-
- List threads = new ArrayList<>();
- for (int x = 0; x < workers; x++) {
- threads.add(x, new CoOccurrencesCalculatorThread(x, new FilteredSequenceIterator<>(
- new SynchronizedSequenceIterator<>(sequenceIterator), vocabCache), processedSequences));
- threads.get(x).start();
- }
-
- for (int x = 0; x < workers; x++) {
- try {
- threads.get(x).join();
- } catch (Exception e) {
- throw new RuntimeException(e);
- }
- }
-
- shadowThread.finish();
- logger.info("CoOccurrences map was built.");
- }
-
- /**
- *
- * This method returns iterator with elements pairs and their weights. Resulting iterator is safe to use in multi-threaded environment.
- *
- * Developer's note: thread safety on received iterator is delegated to PrefetchedSentenceIterator
- * @return
- */
- public Iterator, Double>> iterator() {
- final SentenceIterator iterator;
-
- try {
- iterator = new SynchronizedSentenceIterator(
- new PrefetchingSentenceIterator.Builder(new BasicLineIterator(targetFile))
- .setFetchSize(500000).build());
-
- } catch (Exception e) {
- logger.error("Target file was not found on last stage!");
- throw new RuntimeException(e);
- }
- return new Iterator, Double>>() {
- /*
- iterator should be built on top of current text file with all pairs
- */
-
- @Override
- public boolean hasNext() {
- return iterator.hasNext();
- }
-
- @Override
- public Pair, Double> next() {
- String line = iterator.nextSentence();
- String[] strings = line.split(" ");
-
- T element1 = vocabCache.elementAtIndex(Integer.valueOf(strings[0]));
- T element2 = vocabCache.elementAtIndex(Integer.valueOf(strings[1]));
- Double weight = Double.valueOf(strings[2]);
-
- return new Pair<>(new Pair<>(element1, element2), weight);
- }
-
- @Override
- public void remove() {
- throw new UnsupportedOperationException("remove() method can't be supported on read-only interface");
- }
- };
- }
-
- public static class Builder {
-
- protected boolean symmetric;
- protected int windowSize = 5;
- protected VocabCache vocabCache;
- protected SequenceIterator sequenceIterator;
- protected int workers = Runtime.getRuntime().availableProcessors();
- protected File target;
- protected long maxmemory = Runtime.getRuntime().maxMemory();
-
- public Builder() {
-
- }
-
- public Builder symmetric(boolean reallySymmetric) {
- this.symmetric = reallySymmetric;
- return this;
- }
-
- public Builder windowSize(int windowSize) {
- this.windowSize = windowSize;
- return this;
- }
-
- public Builder vocabCache(@NonNull VocabCache cache) {
- this.vocabCache = cache;
- return this;
- }
-
- public Builder iterate(@NonNull SequenceIterator iterator) {
- this.sequenceIterator = new SynchronizedSequenceIterator<>(iterator);
- return this;
- }
-
- public Builder workers(int numWorkers) {
- this.workers = numWorkers;
- return this;
- }
-
- /**
- * This method allows you to specify maximum memory available for CoOccurrence map builder.
- *
- * Please note: this option can be considered a debugging method. In most cases setting proper -Xmx argument set to JVM is enough to limit this algorithm.
- * Please note: this option won't override -Xmx JVM value.
- *
- * @param gbytes memory available, in GigaBytes
- * @return
- */
- public Builder maxMemory(int gbytes) {
- if (gbytes > 0) {
- this.maxmemory = Math.max(gbytes - 1, 1) * 1024 * 1024 * 1024L;
- }
-
- return this;
- }
-
- /**
- * Path to save cooccurrence map after construction.
- * If targetFile is not specified, temporary file will be used.
- *
- * @param path
- * @return
- */
- public Builder targetFile(@NonNull String path) {
- this.targetFile(new File(path));
- return this;
- }
-
- /**
- * Path to save cooccurrence map after construction.
- * If targetFile is not specified, temporary file will be used.
- *
- * @param file
- * @return
- */
- public Builder targetFile(@NonNull File file) {
- this.target = file;
- return this;
- }
-
- public AbstractCoOccurrences build() {
- AbstractCoOccurrences ret = new AbstractCoOccurrences<>();
- ret.sequenceIterator = this.sequenceIterator;
- ret.windowSize = this.windowSize;
- ret.vocabCache = this.vocabCache;
- ret.symmetric = this.symmetric;
- ret.workers = this.workers;
-
- if (this.maxmemory < 1) {
- this.maxmemory = Runtime.getRuntime().maxMemory();
- }
- ret.memory_threshold = this.maxmemory;
-
-
- logger.info("Actual memory limit: [" + this.maxmemory + "]");
-
- // use temp file, if no target file was specified
- try {
- if (this.target == null) {
- this.target = DL4JFileUtils.createTempFile("cooccurrence", "map");
- }
- this.target.deleteOnExit();
- } catch (Exception e) {
- throw new RuntimeException(e);
- }
-
- ret.targetFile = this.target;
-
- return ret;
- }
- }
-
- private class CoOccurrencesCalculatorThread extends Thread implements Runnable {
-
- private final SequenceIterator iterator;
- private final AtomicLong sequenceCounter;
- private int threadId;
-
- public CoOccurrencesCalculatorThread(int threadId, @NonNull SequenceIterator iterator,
- @NonNull AtomicLong sequenceCounter) {
- this.iterator = iterator;
- this.sequenceCounter = sequenceCounter;
- this.threadId = threadId;
-
- this.setName("CoOccurrencesCalculatorThread " + threadId);
- }
-
- @Override
- public void run() {
- while (iterator.hasMoreSequences()) {
- Sequence sequence = iterator.nextSequence();
-
- List tokens = new ArrayList<>(sequence.asLabels());
- // logger.info("Tokens size: " + tokens.size());
- for (int x = 0; x < sequence.getElements().size(); x++) {
- int wordIdx = vocabCache.indexOf(tokens.get(x));
- if (wordIdx < 0) {
- continue;
- }
- String w1 = vocabCache.wordFor(tokens.get(x)).getLabel();
-
- // THIS iS SAFE TO REMOVE, NO CHANCE WE'll HAVE UNK WORD INSIDE SEQUENCE
- /*if(w1.equals(Glove.UNK))
- continue;
- */
-
- int windowStop = Math.min(x + windowSize + 1, tokens.size());
- for (int j = x; j < windowStop; j++) {
- int otherWord = vocabCache.indexOf(tokens.get(j));
- if (otherWord < 0) {
- continue;
- }
- String w2 = vocabCache.wordFor(tokens.get(j)).getLabel();
-
- if (w2.equals(Glove.DEFAULT_UNK) || otherWord == wordIdx) {
- continue;
- }
-
-
- T tokenX = vocabCache.wordFor(tokens.get(x));
- T tokenJ = vocabCache.wordFor(tokens.get(j));
- double nWeight = 1.0 / (j - x + Nd4j.EPS_THRESHOLD);
-
- while (getMemoryFootprint() >= getMemoryThreshold()) {
- shadowThread.invoke();
- /*lock.readLock().lock();
- int size = coOccurrenceCounts.size();
- lock.readLock().unlock();
- */
- if (threadId == 0) {
- logger.debug("Memory consuimption > threshold: {footrpint: [" + getMemoryFootprint()
- + "], threshold: [" + getMemoryThreshold() + "] }");
- }
- ThreadUtils.uncheckedSleep(10000);
- }
- /*
- if (getMemoryFootprint() == 0) {
- logger.info("Zero size!");
- }
- */
-
- try {
- lock.readLock().lock();
- if (wordIdx < otherWord) {
- coOccurrenceCounts.incrementCount(tokenX, tokenJ, nWeight);
- if (symmetric) {
- coOccurrenceCounts.incrementCount(tokenJ, tokenX, nWeight);
- }
- } else {
- coOccurrenceCounts.incrementCount(tokenJ, tokenX, nWeight);
-
- if (symmetric) {
- coOccurrenceCounts.incrementCount(tokenX, tokenJ, nWeight);
- }
- }
- } finally {
- lock.readLock().unlock();
- }
- }
- }
-
- sequenceCounter.incrementAndGet();
- }
- }
- }
-
- /**
- * This class is designed to provide shadow copy functionality for CoOccurence maps, since with proper corpus size you can't fit such a map into memory
- *
- */
- private class ShadowCopyThread extends Thread implements Runnable {
-
- private AtomicBoolean isFinished = new AtomicBoolean(false);
- private AtomicBoolean isTerminate = new AtomicBoolean(false);
- private AtomicBoolean isInvoked = new AtomicBoolean(false);
- private AtomicBoolean shouldInvoke = new AtomicBoolean(false);
-
- // file that contains resuts from previous runs
- private File[] tempFiles;
- private RoundCount counter;
-
- public ShadowCopyThread() {
- try {
-
- counter = new RoundCount(1);
- tempFiles = new File[2];
-
- tempFiles[0] = DL4JFileUtils.createTempFile("aco", "tmp");
- tempFiles[1] = DL4JFileUtils.createTempFile("aco", "tmp");
-
- tempFiles[0].deleteOnExit();
- tempFiles[1].deleteOnExit();
- } catch (Exception e) {
- throw new RuntimeException(e);
- }
-
- this.setName("ACO ShadowCopy thread");
- }
-
- @Override
- public void run() {
- /*
- Basic idea is pretty simple: run quetly, untill memory gets filled up to some high volume.
- As soon as this happens - execute shadow copy.
- */
- while (!isFinished.get() && !isTerminate.get()) {
- // check used memory. if memory use below threshold - sleep for a while. if above threshold - invoke copier
-
- if (getMemoryFootprint() > getMemoryThreshold() || (shouldInvoke.get() && !isInvoked.get())) {
- // we'll just invoke copier, nothing else
- shouldInvoke.compareAndSet(true, false);
- invokeBlocking();
- } else {
- /*
- commented and left here for future debugging purposes, if needed
-
- //lock.readLock().lock();
- //int size = coOccurrenceCounts.size();
- //lock.readLock().unlock();
- //logger.info("Current memory situation: {size: [" +size+ "], footprint: [" + getMemoryFootprint()+"], threshold: ["+ getMemoryThreshold() +"]}");
- */
- ThreadUtils.uncheckedSleep(1000);
- }
- }
- }
-
- /**
- * This methods advises shadow copy process to start
- */
- public void invoke() {
- shouldInvoke.compareAndSet(false, true);
- }
-
- /**
- * This methods dumps cooccurrence map into save file.
- * Please note: this method is synchronized and will block, until complete
- */
- public synchronized void invokeBlocking() {
- if (getMemoryFootprint() < getMemoryThreshold() && !isFinished.get()) {
- return;
- }
-
- int numberOfLinesSaved = 0;
-
- isInvoked.set(true);
-
- logger.debug("Memory purge started.");
-
- /*
- Basic plan:
- 1. Open temp file
- 2. Read that file line by line
- 3. For each read line do synchronization in memory > new file direction
- */
-
- counter.tick();
-
- CountMap localMap;
- try {
- // in any given moment there's going to be only 1 WriteLock, due to invokeBlocking() being synchronized call
- lock.writeLock().lock();
-
-
-
- // obtain local copy of CountMap
- localMap = coOccurrenceCounts;
-
- // set new CountMap, and release write lock
- coOccurrenceCounts = new CountMap<>();
- } finally {
- lock.writeLock().unlock();
- }
-
- try {
-
- File file = null;
- if (!isFinished.get()) {
- file = tempFiles[counter.previous()];
- } else
- file = targetFile;
-
-
- // PrintWriter pw = new PrintWriter(file);
-
- int linesRead = 0;
-
- logger.debug("Saving to: [" + counter.get() + "], Reading from: [" + counter.previous() + "]");
- CoOccurenceReader reader =
- new BinaryCoOccurrenceReader<>(tempFiles[counter.previous()], vocabCache, localMap);
- CoOccurrenceWriter writer = (isFinished.get()) ? new ASCIICoOccurrenceWriter(targetFile)
- : new BinaryCoOccurrenceWriter(tempFiles[counter.get()]);
- while (reader.hasMoreObjects()) {
- CoOccurrenceWeight line = reader.nextObject();
-
- if (line != null) {
- writer.writeObject(line);
- numberOfLinesSaved++;
- linesRead++;
- }
- }
- reader.finish();
-
- logger.debug("Lines read: [" + linesRead + "]");
-
- //now, we can dump the rest of elements, which were not presented in existing dump
- Iterator> iterator = localMap.getPairIterator();
- while (iterator.hasNext()) {
- Pair pair = iterator.next();
- double mWeight = localMap.getCount(pair);
- CoOccurrenceWeight object = new CoOccurrenceWeight<>();
- object.setElement1(pair.getFirst());
- object.setElement2(pair.getSecond());
- object.setWeight(mWeight);
-
- writer.writeObject(object);
-
- numberOfLinesSaved++;
- // if (numberOfLinesSaved % 100000 == 0) logger.info("Lines saved: [" + numberOfLinesSaved +"]");
- }
-
- writer.finish();
-
- /*
- SentenceIterator sIterator = new PrefetchingSentenceIterator.Builder(new BasicLineIterator(tempFiles[counter.get()]))
- .setFetchSize(500000)
- .build();
-
-
- int linesRead = 0;
- while (sIterator.hasNext()) {
- //List list = new ArrayList<>(reader.next());
- String sentence = sIterator.nextSentence();
- if (sentence == null || sentence.isEmpty()) continue;
- String[] strings = sentence.split(" ");
-
-
- // first two elements are integers - vocab indexes
- //T element1 = vocabCache.wordFor(vocabCache.wordAtIndex(list.get(0).toInt()));
- //T element2 = vocabCache.wordFor(vocabCache.wordAtIndex(list.get(1).toInt()));
- T element1 = vocabCache.elementAtIndex(Integer.valueOf(strings[0]));
- T element2 = vocabCache.elementAtIndex(Integer.valueOf(strings[1]));
-
- // getting third element, previously stored weight
- double sWeight = Double.valueOf(strings[2]); // list.get(2).toDouble();
-
- // now, since we have both elements ready, we can check this pair against inmemory map
- double mWeight = localMap.getCount(element1, element2);
- if (mWeight <= 0) {
- // this means we have no such pair in memory, so we'll do nothing to sWeight
- } else {
- // since we have new weight value in memory, we should update sWeight value before moving it off memory
- sWeight += mWeight;
-
- // original pair can be safely removed from CountMap
- localMap.removePair(element1,element2);
- }
-
- StringBuilder builder = new StringBuilder().append(element1.getIndex()).append(" ").append(element2.getIndex()).append(" ").append(sWeight);
- pw.println(builder.toString());
- numberOfLinesSaved++;
- linesRead++;
-
- // if (numberOfLinesSaved % 100000 == 0) logger.info("Lines saved: [" + numberOfLinesSaved +"]");
- // if (linesRead % 100000 == 0) logger.info("Lines read: [" + linesRead +"]");
- }
- */
- /*
- logger.info("Lines read: [" + linesRead + "]");
-
- //now, we can dump the rest of elements, which were not presented in existing dump
- Iterator> iterator = localMap.getPairIterator();
- while (iterator.hasNext()) {
- Pair pair = iterator.next();
- double mWeight = localMap.getCount(pair);
-
- StringBuilder builder = new StringBuilder().append(pair.getFirst().getIndex()).append(" ").append(pair.getFirst().getIndex()).append(" ").append(mWeight);
- pw.println(builder.toString());
- numberOfLinesSaved++;
-
- // if (numberOfLinesSaved % 100000 == 0) logger.info("Lines saved: [" + numberOfLinesSaved +"]");
- }
-
- pw.flush();
- pw.close();
-
- */
-
- // just a hint for gc
- localMap = null;
- //sIterator.finish();
- } catch (Exception e) {
- throw new RuntimeException(e);
- }
-
- logger.info("Number of word pairs saved so far: [" + numberOfLinesSaved + "]");
- isInvoked.set(false);
- }
-
- /**
- * This method provides soft finish ability for shadow copy process.
- * Please note: it's blocking call, since it requires for final merge.
- */
- public void finish() {
- if (this.isFinished.get()) {
- return;
- }
-
- this.isFinished.set(true);
- invokeBlocking();
- }
-
- /**
- * This method provides hard fiinish ability for shadow copy process
- */
- public void terminate() {
- this.isTerminate.set(true);
- }
- }
-}
diff --git a/deeplearning4j/deeplearning4j-nlp-parent/deeplearning4j-nlp/src/main/java/org/deeplearning4j/models/glove/Glove.java b/deeplearning4j/deeplearning4j-nlp-parent/deeplearning4j-nlp/src/main/java/org/deeplearning4j/models/glove/Glove.java
deleted file mode 100644
index 44887e8fe..000000000
--- a/deeplearning4j/deeplearning4j-nlp-parent/deeplearning4j-nlp/src/main/java/org/deeplearning4j/models/glove/Glove.java
+++ /dev/null
@@ -1,444 +0,0 @@
-/*******************************************************************************
- * Copyright (c) 2015-2018 Skymind, Inc.
- *
- * This program and the accompanying materials are made available under the
- * terms of the Apache License, Version 2.0 which is available at
- * https://www.apache.org/licenses/LICENSE-2.0.
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
- * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
- * License for the specific language governing permissions and limitations
- * under the License.
- *
- * SPDX-License-Identifier: Apache-2.0
- ******************************************************************************/
-
-package org.deeplearning4j.models.glove;
-
-import lombok.NonNull;
-import org.deeplearning4j.models.embeddings.WeightLookupTable;
-import org.deeplearning4j.models.embeddings.learning.impl.elements.GloVe;
-import org.deeplearning4j.models.embeddings.loader.VectorsConfiguration;
-import org.deeplearning4j.models.embeddings.reader.ModelUtils;
-import org.deeplearning4j.models.embeddings.wordvectors.WordVectors;
-import org.deeplearning4j.models.sequencevectors.SequenceVectors;
-import org.deeplearning4j.models.sequencevectors.interfaces.SequenceIterator;
-import org.deeplearning4j.models.sequencevectors.interfaces.VectorsListener;
-import org.deeplearning4j.models.sequencevectors.iterators.AbstractSequenceIterator;
-import org.deeplearning4j.models.sequencevectors.transformers.impl.SentenceTransformer;
-import org.deeplearning4j.models.word2vec.VocabWord;
-import org.deeplearning4j.models.word2vec.wordstore.VocabCache;
-import org.deeplearning4j.text.documentiterator.DocumentIterator;
-import org.deeplearning4j.text.sentenceiterator.SentenceIterator;
-import org.deeplearning4j.text.sentenceiterator.StreamLineIterator;
-import org.deeplearning4j.text.tokenization.tokenizerfactory.TokenizerFactory;
-
-import java.util.Collection;
-import java.util.List;
-
-/**
- * GlobalVectors standalone implementation for DL4j.
- * Based on original Stanford GloVe http://www-nlp.stanford.edu/pubs/glove.pdf
- *
- * @author raver119@gmail.com
- */
-public class Glove extends SequenceVectors {
-
- protected Glove() {
-
- }
-
- public static class Builder extends SequenceVectors.Builder {
- private double xMax;
- private boolean shuffle;
- private boolean symmetric;
- protected double alpha = 0.75d;
- private int maxmemory = (int) (Runtime.getRuntime().totalMemory() / 1024 / 1024 / 1024);
-
- protected TokenizerFactory tokenFactory;
- protected SentenceIterator sentenceIterator;
- protected DocumentIterator documentIterator;
-
- public Builder() {
- super();
- }
-
-
- public Builder(@NonNull VectorsConfiguration configuration) {
- super(configuration);
- }
-
-
- /**
- * This method has no effect for GloVe
- *
- * @param vec existing WordVectors model
- * @return
- */
- @Override
- public Builder useExistingWordVectors(@NonNull WordVectors vec) {
- return this;
- }
-
- @Override
- public Builder iterate(@NonNull SequenceIterator iterator) {
- super.iterate(iterator);
- return this;
- }
-
- /**
- * Specifies minibatch size for training process.
- *
- * @param batchSize
- * @return
- */
- @Override
- public Builder batchSize(int batchSize) {
- super.batchSize(batchSize);
- return this;
- }
-
- /**
- * Ierations and epochs are the same in GloVe implementation.
- *
- * @param iterations
- * @return
- */
- @Override
- public Builder iterations(int iterations) {
- super.epochs(iterations);
- return this;
- }
-
- /**
- * Sets the number of iteration over training corpus during training
- *
- * @param numEpochs
- * @return
- */
- @Override
- public Builder epochs(int numEpochs) {
- super.epochs(numEpochs);
- return this;
- }
-
- @Override
- public Builder useAdaGrad(boolean reallyUse) {
- super.useAdaGrad(true);
- return this;
- }
-
- @Override
- public Builder layerSize(int layerSize) {
- super.layerSize(layerSize);
- return this;
- }
-
- @Override
- public Builder learningRate(double learningRate) {
- super.learningRate(learningRate);
- return this;
- }
-
- /**
- * Sets minimum word frequency during vocabulary mastering.
- * Please note: this option is ignored, if vocabulary is built outside of GloVe
- *
- * @param minWordFrequency
- * @return
- */
- @Override
- public Builder minWordFrequency(int minWordFrequency) {
- super.minWordFrequency(minWordFrequency);
- return this;
- }
-
- @Override
- public Builder minLearningRate(double minLearningRate) {
- super.minLearningRate(minLearningRate);
- return this;
- }
-
- @Override
- public Builder resetModel(boolean reallyReset) {
- super.resetModel(reallyReset);
- return this;
- }
-
- @Override
- public Builder vocabCache(@NonNull VocabCache vocabCache) {
- super.vocabCache(vocabCache);
- return this;
- }
-
- @Override
- public Builder lookupTable(@NonNull WeightLookupTable lookupTable) {
- super.lookupTable(lookupTable);
- return this;
- }
-
- @Override
- @Deprecated
- public Builder sampling(double sampling) {
- super.sampling(sampling);
- return this;
- }
-
- @Override
- @Deprecated
- public Builder negativeSample(double negative) {
- super.negativeSample(negative);
- return this;
- }
-
- @Override
- public Builder stopWords(@NonNull List stopList) {
- super.stopWords(stopList);
- return this;
- }
-
- @Override
- public Builder trainElementsRepresentation(boolean trainElements) {
- super.trainElementsRepresentation(true);
- return this;
- }
-
- @Override
- @Deprecated
- public Builder trainSequencesRepresentation(boolean trainSequences) {
- super.trainSequencesRepresentation(false);
- return this;
- }
-
- @Override
- public Builder stopWords(@NonNull Collection stopList) {
- super.stopWords(stopList);
- return this;
- }
-
- @Override
- public Builder windowSize(int windowSize) {
- super.windowSize(windowSize);
- return this;
- }
-
- @Override
- public Builder seed(long randomSeed) {
- super.seed(randomSeed);
- return this;
- }
-
- @Override
- public Builder workers(int numWorkers) {
- super.workers(numWorkers);
- return this;
- }
-
- /**
- * Sets TokenizerFactory to be used for training
- *
- * @param tokenizerFactory
- * @return
- */
- public Builder tokenizerFactory(@NonNull TokenizerFactory tokenizerFactory) {
- this.tokenFactory = tokenizerFactory;
- return this;
- }
-
- /**
- * Parameter specifying cutoff in weighting function; default 100.0
- *
- * @param xMax
- * @return
- */
- public Builder xMax(double xMax) {
- this.xMax = xMax;
- return this;
- }
-
- /**
- * Parameters specifying, if cooccurrences list should be build into both directions from any current word.
- *
- * @param reallySymmetric
- * @return
- */
- public Builder symmetric(boolean reallySymmetric) {
- this.symmetric = reallySymmetric;
- return this;
- }
-
- /**
- * Parameter specifying, if cooccurrences list should be shuffled between training epochs
- *
- * @param reallyShuffle
- * @return
- */
- public Builder shuffle(boolean reallyShuffle) {
- this.shuffle = reallyShuffle;
- return this;
- }
-
- /**
- * This method has no effect for ParagraphVectors
- *
- * @param windows
- * @return
- */
- @Override
- public Builder useVariableWindow(int... windows) {
- // no-op
- return this;
- }
-
- /**
- * Parameter in exponent of weighting function; default 0.75
- *
- * @param alpha
- * @return
- */
- public Builder alpha(double alpha) {
- this.alpha = alpha;
- return this;
- }
-
- public Builder iterate(@NonNull SentenceIterator iterator) {
- this.sentenceIterator = iterator;
- return this;
- }
-
- public Builder iterate(@NonNull DocumentIterator iterator) {
- this.sentenceIterator = new StreamLineIterator.Builder(iterator).setFetchSize(100).build();
- return this;
- }
-
- /**
- * Sets ModelUtils that gonna be used as provider for utility methods: similarity(), wordsNearest(), accuracy(), etc
- *
- * @param modelUtils model utils to be used
- * @return
- */
- @Override
- public Builder modelUtils(@NonNull ModelUtils modelUtils) {
- super.modelUtils(modelUtils);
- return this;
- }
-
- /**
- * This method sets VectorsListeners for this SequenceVectors model
- *
- * @param vectorsListeners
- * @return
- */
- @Override
- public Builder setVectorsListeners(@NonNull Collection> vectorsListeners) {
- super.setVectorsListeners(vectorsListeners);
- return this;
- }
-
- /**
- * This method allows you to specify maximum memory available for CoOccurrence map builder.
- *
- * Please note: this option can be considered a debugging method. In most cases setting proper -Xmx argument set to JVM is enough to limit this algorithm.
- * Please note: this option won't override -Xmx JVM value.
- *
- * @param gbytes memory limit, in gigabytes
- * @return
- */
- public Builder maxMemory(int gbytes) {
- this.maxmemory = gbytes;
- return this;
- }
-
- /**
- * This method allows you to specify SequenceElement that will be used as UNK element, if UNK is used
- *
- * @param element
- * @return
- */
- @Override
- public Builder unknownElement(VocabWord element) {
- super.unknownElement(element);
- return this;
- }
-
- /**
- * This method allows you to specify, if UNK word should be used internally
- *
- * @param reallyUse
- * @return
- */
- @Override
- public Builder useUnknown(boolean reallyUse) {
- super.useUnknown(reallyUse);
- if (this.unknownElement == null) {
- this.unknownElement(new VocabWord(1.0, Glove.DEFAULT_UNK));
- }
- return this;
- }
-
- public Glove build() {
- presetTables();
-
- Glove ret = new Glove();
-
-
- // hardcoded value for glove
-
- if (sentenceIterator != null) {
- SentenceTransformer transformer = new SentenceTransformer.Builder().iterator(sentenceIterator)
- .tokenizerFactory(tokenFactory).build();
- this.iterator = new AbstractSequenceIterator.Builder<>(transformer).build();
- }
-
-
- ret.trainElementsVectors = true;
- ret.trainSequenceVectors = false;
- ret.useAdeGrad = true;
- this.useAdaGrad = true;
-
- ret.learningRate.set(this.learningRate);
- ret.resetModel = this.resetModel;
- ret.batchSize = this.batchSize;
- ret.iterator = this.iterator;
- ret.numEpochs = this.numEpochs;
- ret.numIterations = this.iterations;
- ret.layerSize = this.layerSize;
-
- ret.useUnknown = this.useUnknown;
- ret.unknownElement = this.unknownElement;
-
-
-
- this.configuration.setLearningRate(this.learningRate);
- this.configuration.setLayersSize(layerSize);
- this.configuration.setHugeModelExpected(hugeModelExpected);
- this.configuration.setWindow(window);
- this.configuration.setMinWordFrequency(minWordFrequency);
- this.configuration.setIterations(iterations);
- this.configuration.setSeed(seed);
- this.configuration.setBatchSize(batchSize);
- this.configuration.setLearningRateDecayWords(learningRateDecayWords);
- this.configuration.setMinLearningRate(minLearningRate);
- this.configuration.setSampling(this.sampling);
- this.configuration.setUseAdaGrad(useAdaGrad);
- this.configuration.setNegative(negative);
- this.configuration.setEpochs(this.numEpochs);
-
-
- ret.configuration = this.configuration;
-
- ret.lookupTable = this.lookupTable;
- ret.vocab = this.vocabCache;
- ret.modelUtils = this.modelUtils;
- ret.eventListeners = this.vectorsListeners;
-
-
- ret.elementsLearningAlgorithm = new GloVe.Builder().learningRate(this.learningRate)
- .shuffle(this.shuffle).symmetric(this.symmetric).xMax(this.xMax).alpha(this.alpha)
- .maxMemory(maxmemory).build();
-
- return ret;
- }
- }
-}
diff --git a/deeplearning4j/deeplearning4j-nlp-parent/deeplearning4j-nlp/src/main/java/org/deeplearning4j/models/glove/GloveWeightLookupTable.java b/deeplearning4j/deeplearning4j-nlp-parent/deeplearning4j-nlp/src/main/java/org/deeplearning4j/models/glove/GloveWeightLookupTable.java
deleted file mode 100644
index bc52a1422..000000000
--- a/deeplearning4j/deeplearning4j-nlp-parent/deeplearning4j-nlp/src/main/java/org/deeplearning4j/models/glove/GloveWeightLookupTable.java
+++ /dev/null
@@ -1,334 +0,0 @@
-/*******************************************************************************
- * Copyright (c) 2015-2018 Skymind, Inc.
- *
- * This program and the accompanying materials are made available under the
- * terms of the Apache License, Version 2.0 which is available at
- * https://www.apache.org/licenses/LICENSE-2.0.
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
- * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
- * License for the specific language governing permissions and limitations
- * under the License.
- *
- * SPDX-License-Identifier: Apache-2.0
- ******************************************************************************/
-
-package org.deeplearning4j.models.glove;
-
-
-import org.apache.commons.io.IOUtils;
-import org.apache.commons.io.LineIterator;
-import org.deeplearning4j.models.embeddings.inmemory.InMemoryLookupTable;
-import org.deeplearning4j.models.sequencevectors.sequence.SequenceElement;
-import org.deeplearning4j.models.word2vec.Word2Vec;
-import org.deeplearning4j.models.word2vec.wordstore.VocabCache;
-import org.nd4j.linalg.api.ndarray.INDArray;
-import org.nd4j.linalg.api.rng.Random;
-import org.nd4j.linalg.factory.Nd4j;
-import org.nd4j.linalg.learning.legacy.AdaGrad;
-
-import java.io.IOException;
-import java.io.InputStream;
-import java.util.HashMap;
-import java.util.Map;
-import java.util.concurrent.atomic.AtomicLong;
-
-/**
- * Glove lookup table
- *
- * @author Adam Gibson
- */
-// Deprecated due to logic being pulled off WeightLookupTable classes into LearningAlgorithm interfaces for better code.
-@Deprecated
-public class GloveWeightLookupTable extends InMemoryLookupTable {
-
-
- private AdaGrad weightAdaGrad;
- private AdaGrad biasAdaGrad;
- private INDArray bias;
- //also known as alpha
- private double xMax = 0.75;
- private double maxCount = 100;
-
-
- public GloveWeightLookupTable(VocabCache vocab, int vectorLength, boolean useAdaGrad, double lr, Random gen,
- double negative, double xMax, double maxCount) {
- super(vocab, vectorLength, useAdaGrad, lr, gen, negative);
- this.xMax = xMax;
- this.maxCount = maxCount;
- }
-
- @Override
- public void resetWeights(boolean reset) {
- if (rng == null)
- this.rng = Nd4j.getRandom();
-
- //note the +2 which is the unk vocab word and the bias
- if (syn0 == null || reset) {
- syn0 = Nd4j.rand(new int[] {vocab.numWords() + 1, vectorLength}, rng).subi(0.5).divi((double) vectorLength);
- INDArray randUnk = Nd4j.rand(1, vectorLength, rng).subi(0.5).divi(vectorLength);
- putVector(Word2Vec.DEFAULT_UNK, randUnk);
- }
- if (weightAdaGrad == null || reset) {
- weightAdaGrad = new AdaGrad(new long[]{vocab.numWords() + 1, vectorLength}, lr.get());
- }
-
-
- //right after unknown
- if (bias == null || reset)
- bias = Nd4j.create(syn0.rows());
-
- if (biasAdaGrad == null || reset) {
- biasAdaGrad = new AdaGrad(bias.shape(), lr.get());
- }
-
-
- }
-
- /**
- * Reset the weights of the cache
- */
- @Override
- public void resetWeights() {
- resetWeights(true);
-
- }
-
- /**
- * glove iteration
- * @param w1 the first word
- * @param w2 the second word
- * @param score the weight learned for the particular co occurrences
- */
- public double iterateSample(T w1, T w2, double score) {
- INDArray w1Vector = syn0.slice(w1.getIndex());
- INDArray w2Vector = syn0.slice(w2.getIndex());
- //prediction: input + bias
- if (w1.getIndex() < 0 || w1.getIndex() >= syn0.rows())
- throw new IllegalArgumentException("Illegal index for word " + w1.getLabel());
- if (w2.getIndex() < 0 || w2.getIndex() >= syn0.rows())
- throw new IllegalArgumentException("Illegal index for word " + w2.getLabel());
-
-
- //w1 * w2 + bias
- double prediction = Nd4j.getBlasWrapper().dot(w1Vector, w2Vector);
- prediction += bias.getDouble(w1.getIndex()) + bias.getDouble(w2.getIndex());
-
- double weight = Math.pow(Math.min(1.0, (score / maxCount)), xMax);
-
- double fDiff = score > xMax ? prediction : weight * (prediction - Math.log(score));
- if (Double.isNaN(fDiff))
- fDiff = Nd4j.EPS_THRESHOLD;
- //amount of change
- double gradient = fDiff;
-
- //note the update step here: the gradient is
- //the gradient of the OPPOSITE word
- //for adagrad we will use the index of the word passed in
- //for the gradient calculation we will use the context vector
- update(w1, w1Vector, w2Vector, gradient);
- update(w2, w2Vector, w1Vector, gradient);
- return fDiff;
-
-
-
- }
-
-
- private void update(T w1, INDArray wordVector, INDArray contextVector, double gradient) {
- //gradient for word vectors
- INDArray grad1 = contextVector.mul(gradient);
- INDArray update = weightAdaGrad.getGradient(grad1, w1.getIndex(), syn0.shape());
-
- //update vector
- wordVector.subi(update);
-
- double w1Bias = bias.getDouble(w1.getIndex());
- double biasGradient = biasAdaGrad.getGradient(gradient, w1.getIndex(), bias.shape());
- double update2 = w1Bias - biasGradient;
- bias.putScalar(w1.getIndex(), update2);
- }
-
- public AdaGrad getWeightAdaGrad() {
- return weightAdaGrad;
- }
-
-
- public AdaGrad getBiasAdaGrad() {
- return biasAdaGrad;
- }
-
-
-
- /**
- * Load a glove model from an input stream.
- * The format is:
- * word num1 num2....
- * @param is the input stream to read from for the weights
- * @param vocab the vocab for the lookuptable
- * @return the loaded model
- * @throws java.io.IOException if one occurs
- */
- public static GloveWeightLookupTable load(InputStream is, VocabCache extends SequenceElement> vocab)
- throws IOException {
- LineIterator iter = IOUtils.lineIterator(is, "UTF-8");
- GloveWeightLookupTable glove = null;
- Map wordVectors = new HashMap<>();
- while (iter.hasNext()) {
- String line = iter.nextLine().trim();
- if (line.isEmpty())
- continue;
- String[] split = line.split(" ");
- String word = split[0];
- if (glove == null)
- glove = new GloveWeightLookupTable.Builder().cache(vocab).vectorLength(split.length - 1).build();
-
-
-
- if (word.isEmpty())
- continue;
- float[] read = read(split, glove.layerSize());
- if (read.length < 1)
- continue;
-
-
- wordVectors.put(word, read);
-
-
-
- }
-
- glove.setSyn0(weights(glove, wordVectors, vocab));
- glove.resetWeights(false);
-
-
- iter.close();
-
-
- return glove;
-
- }
-
- private static INDArray weights(GloveWeightLookupTable glove, Map data, VocabCache vocab) {
- INDArray ret = Nd4j.create(data.size(), glove.layerSize());
-
- for (Map.Entry entry : data.entrySet()) {
- String key = entry.getKey();
- INDArray row = Nd4j.create(Nd4j.createBuffer(entry.getValue()));
- if (row.length() != glove.layerSize())
- continue;
- if (vocab.indexOf(key) >= data.size())
- continue;
- if (vocab.indexOf(key) < 0)
- continue;
- ret.putRow(vocab.indexOf(key), row);
- }
- return ret;
- }
-
-
- private static float[] read(String[] split, int length) {
- float[] ret = new float[length];
- for (int i = 1; i < split.length; i++) {
- ret[i - 1] = Float.parseFloat(split[i]);
- }
- return ret;
- }
-
-
- @Override
- public void iterateSample(T w1, T w2, AtomicLong nextRandom, double alpha) {
- throw new UnsupportedOperationException();
-
- }
-
- public double getxMax() {
- return xMax;
- }
-
- public void setxMax(double xMax) {
- this.xMax = xMax;
- }
-
- public double getMaxCount() {
- return maxCount;
- }
-
- public void setMaxCount(double maxCount) {
- this.maxCount = maxCount;
- }
-
- public INDArray getBias() {
- return bias;
- }
-
- public void setBias(INDArray bias) {
- this.bias = bias;
- }
-
- public static class Builder extends InMemoryLookupTable.Builder {
- private double xMax = 0.75;
- private double maxCount = 100;
-
-
- public Builder maxCount(double maxCount) {
- this.maxCount = maxCount;
- return this;
- }
-
-
- public Builder xMax(double xMax) {
- this.xMax = xMax;
- return this;
- }
-
- @Override
- public Builder cache(VocabCache vocab) {
- super.cache(vocab);
- return this;
- }
-
- @Override
- public Builder negative(double negative) {
- super.negative(negative);
- return this;
- }
-
- @Override
- public Builder vectorLength(int vectorLength) {
- super.vectorLength(vectorLength);
- return this;
- }
-
- @Override
- public Builder useAdaGrad(boolean useAdaGrad) {
- super.useAdaGrad(useAdaGrad);
- return this;
- }
-
- @Override
- public Builder lr(double lr) {
- super.lr(lr);
- return this;
- }
-
- @Override
- public Builder gen(Random gen) {
- super.gen(gen);
- return this;
- }
-
- @Override
- public Builder seed(long seed) {
- super.seed(seed);
- return this;
- }
-
- public GloveWeightLookupTable build() {
- return new GloveWeightLookupTable<>(vocabCache, vectorLength, useAdaGrad, lr, gen, negative, xMax,
- maxCount);
- }
- }
-
-}
diff --git a/deeplearning4j/deeplearning4j-nlp-parent/deeplearning4j-nlp/src/main/java/org/deeplearning4j/models/glove/count/ASCIICoOccurrenceReader.java b/deeplearning4j/deeplearning4j-nlp-parent/deeplearning4j-nlp/src/main/java/org/deeplearning4j/models/glove/count/ASCIICoOccurrenceReader.java
deleted file mode 100644
index 8dd2fe85f..000000000
--- a/deeplearning4j/deeplearning4j-nlp-parent/deeplearning4j-nlp/src/main/java/org/deeplearning4j/models/glove/count/ASCIICoOccurrenceReader.java
+++ /dev/null
@@ -1,91 +0,0 @@
-/*******************************************************************************
- * Copyright (c) 2015-2018 Skymind, Inc.
- *
- * This program and the accompanying materials are made available under the
- * terms of the Apache License, Version 2.0 which is available at
- * https://www.apache.org/licenses/LICENSE-2.0.
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
- * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
- * License for the specific language governing permissions and limitations
- * under the License.
- *
- * SPDX-License-Identifier: Apache-2.0
- ******************************************************************************/
-
-package org.deeplearning4j.models.glove.count;
-
-import lombok.NonNull;
-import org.deeplearning4j.models.sequencevectors.sequence.SequenceElement;
-import org.deeplearning4j.models.word2vec.wordstore.VocabCache;
-import org.deeplearning4j.text.sentenceiterator.BasicLineIterator;
-import org.deeplearning4j.text.sentenceiterator.PrefetchingSentenceIterator;
-import org.deeplearning4j.text.sentenceiterator.SentenceIterator;
-
-import java.io.File;
-import java.io.PrintWriter;
-
-/**
- * @author raver119@gmail.com
- */
-public class ASCIICoOccurrenceReader implements CoOccurenceReader {
- private File file;
- private PrintWriter writer;
- private SentenceIterator iterator;
- private VocabCache vocabCache;
-
- public ASCIICoOccurrenceReader(@NonNull File file, @NonNull VocabCache vocabCache) {
- this.vocabCache = vocabCache;
- this.file = file;
- try {
- iterator = new PrefetchingSentenceIterator.Builder(new BasicLineIterator(file)).build();
- } catch (Exception e) {
- throw new RuntimeException(e);
- }
- }
-
-
- @Override
- public boolean hasMoreObjects() {
- return iterator.hasNext();
- }
-
-
- /**
- * Returns next CoOccurrenceWeight object
- *
- * PLEASE NOTE: This method can return null value.
- * @return
- */
- @Override
- public CoOccurrenceWeight nextObject() {
- String line = iterator.nextSentence();
- if (line == null || line.isEmpty()) {
- return null;
- }
- String[] strings = line.split(" ");
-
- CoOccurrenceWeight object = new CoOccurrenceWeight<>();
- object.setElement1(vocabCache.elementAtIndex(Integer.valueOf(strings[0])));
- object.setElement2(vocabCache.elementAtIndex(Integer.valueOf(strings[1])));
- object.setWeight(Double.parseDouble(strings[2]));
-
- return object;
- }
-
-
-
- @Override
- public void finish() {
- try {
- if (writer != null) {
- writer.flush();
- writer.close();
- }
- } catch (Exception e) {
- throw new RuntimeException(e);
- }
-
- }
-}
diff --git a/deeplearning4j/deeplearning4j-nlp-parent/deeplearning4j-nlp/src/main/java/org/deeplearning4j/models/glove/count/ASCIICoOccurrenceWriter.java b/deeplearning4j/deeplearning4j-nlp-parent/deeplearning4j-nlp/src/main/java/org/deeplearning4j/models/glove/count/ASCIICoOccurrenceWriter.java
deleted file mode 100644
index 4ef4aada5..000000000
--- a/deeplearning4j/deeplearning4j-nlp-parent/deeplearning4j-nlp/src/main/java/org/deeplearning4j/models/glove/count/ASCIICoOccurrenceWriter.java
+++ /dev/null
@@ -1,69 +0,0 @@
-/*******************************************************************************
- * Copyright (c) 2015-2018 Skymind, Inc.
- *
- * This program and the accompanying materials are made available under the
- * terms of the Apache License, Version 2.0 which is available at
- * https://www.apache.org/licenses/LICENSE-2.0.
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
- * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
- * License for the specific language governing permissions and limitations
- * under the License.
- *
- * SPDX-License-Identifier: Apache-2.0
- ******************************************************************************/
-
-package org.deeplearning4j.models.glove.count;
-
-import lombok.NonNull;
-import org.deeplearning4j.models.sequencevectors.sequence.SequenceElement;
-
-import java.io.BufferedOutputStream;
-import java.io.File;
-import java.io.FileOutputStream;
-import java.io.PrintWriter;
-
-/**
- * @author raver119@gmail.com
- */
-public class ASCIICoOccurrenceWriter implements CoOccurrenceWriter {
-
- private File file;
- private PrintWriter writer;
-
- public ASCIICoOccurrenceWriter(@NonNull File file) {
- this.file = file;
- try {
- this.writer = new PrintWriter(new BufferedOutputStream(new FileOutputStream(file), 10 * 1024 * 1024));
- } catch (Exception e) {
- throw new RuntimeException(e);
- }
- }
-
- @Override
- public void writeObject(CoOccurrenceWeight object) {
- StringBuilder builder = new StringBuilder(String.valueOf(object.getElement1().getIndex())).append(" ")
- .append(String.valueOf(object.getElement2().getIndex())).append(" ")
- .append(String.valueOf(object.getWeight()));
- writer.println(builder.toString());
- }
-
- @Override
- public void queueObject(CoOccurrenceWeight object) {
- throw new UnsupportedOperationException();
- }
-
- @Override
- public void finish() {
- try {
- writer.flush();
- } catch (Exception e) {
- }
-
- try {
- writer.close();
- } catch (Exception e) {
- }
- }
-}
diff --git a/deeplearning4j/deeplearning4j-nlp-parent/deeplearning4j-nlp/src/main/java/org/deeplearning4j/models/glove/count/BinaryCoOccurrenceReader.java b/deeplearning4j/deeplearning4j-nlp-parent/deeplearning4j-nlp/src/main/java/org/deeplearning4j/models/glove/count/BinaryCoOccurrenceReader.java
deleted file mode 100644
index 549cb9aca..000000000
--- a/deeplearning4j/deeplearning4j-nlp-parent/deeplearning4j-nlp/src/main/java/org/deeplearning4j/models/glove/count/BinaryCoOccurrenceReader.java
+++ /dev/null
@@ -1,245 +0,0 @@
-/*******************************************************************************
- * Copyright (c) 2015-2018 Skymind, Inc.
- *
- * This program and the accompanying materials are made available under the
- * terms of the Apache License, Version 2.0 which is available at
- * https://www.apache.org/licenses/LICENSE-2.0.
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
- * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
- * License for the specific language governing permissions and limitations
- * under the License.
- *
- * SPDX-License-Identifier: Apache-2.0
- ******************************************************************************/
-
-package org.deeplearning4j.models.glove.count;
-
-import lombok.NonNull;
-import org.deeplearning4j.models.sequencevectors.sequence.SequenceElement;
-import org.deeplearning4j.models.word2vec.wordstore.VocabCache;
-import org.slf4j.Logger;
-import org.slf4j.LoggerFactory;
-
-import java.io.BufferedInputStream;
-import java.io.File;
-import java.io.FileInputStream;
-import java.io.InputStream;
-import java.nio.ByteBuffer;
-import java.util.ArrayList;
-import java.util.List;
-import java.util.concurrent.ArrayBlockingQueue;
-import java.util.concurrent.TimeUnit;
-import java.util.concurrent.atomic.AtomicBoolean;
-import java.util.concurrent.atomic.AtomicInteger;
-
-/**
- * Binary implementation of CoOccurenceReader interface, used to provide off-memory storage for cooccurrence maps generated for GloVe
- *
- * @author raver119@gmail.com
- */
-public class BinaryCoOccurrenceReader implements CoOccurenceReader {
- private VocabCache vocabCache;
- private InputStream inputStream;
- private File file;
- private ArrayBlockingQueue> buffer;
- int workers = Math.max(Runtime.getRuntime().availableProcessors() - 1, 1);
- private StreamReaderThread readerThread;
- private CountMap countMap;
-
-
- protected static final Logger logger = LoggerFactory.getLogger(BinaryCoOccurrenceReader.class);
-
- public BinaryCoOccurrenceReader(@NonNull File file, @NonNull VocabCache vocabCache, CountMap map) {
- this.vocabCache = vocabCache;
- this.file = file;
- this.countMap = map;
- buffer = new ArrayBlockingQueue<>(200000);
-
- try {
- inputStream = new BufferedInputStream(new FileInputStream(this.file), 100 * 1024 * 1024);
- //inputStream = new BufferedInputStream(new FileInputStream(file), 1024 * 1024);
- readerThread = new StreamReaderThread(inputStream);
- readerThread.start();
- } catch (Exception e) {
- throw new RuntimeException(e);
- }
- }
-
- @Override
- public boolean hasMoreObjects() {
-
- if (!buffer.isEmpty())
- return true;
-
- try {
- return readerThread.hasMoreObjects() || !buffer.isEmpty();
- } catch (Exception e) {
- throw new RuntimeException(e);
- //return false;
- }
- }
-
- @Override
- public CoOccurrenceWeight nextObject() {
- if (!buffer.isEmpty()) {
- return buffer.poll();
- } else {
- // buffer can be starved, or we're already at the end of file.
- if (readerThread.hasMoreObjects()) {
- try {
- return buffer.poll(3, TimeUnit.SECONDS);
- } catch (Exception e) {
- return null;
- }
- }
- }
-
-
- return null;
- /*
- try {
- CoOccurrenceWeight ret = new CoOccurrenceWeight<>();
- ret.setElement1(vocabCache.elementAtIndex(inputStream.readInt()));
- ret.setElement2(vocabCache.elementAtIndex(inputStream.readInt()));
- ret.setWeight(inputStream.readDouble());
-
- return ret;
- } catch (Exception e) {
- return null;
- }
- */
- }
-
- @Override
- public void finish() {
- try {
- if (inputStream != null)
- inputStream.close();
- } catch (Exception e) {
- //
- }
- }
-
- private class StreamReaderThread extends Thread implements Runnable {
- private InputStream stream;
- private AtomicBoolean isReading = new AtomicBoolean(false);
-
- public StreamReaderThread(@NonNull InputStream stream) {
- this.stream = stream;
- isReading.set(false);
- }
-
- @Override
- public void run() {
- try {
- // we read pre-defined number of objects as byte array
- byte[] array = new byte[16 * 500000];
- while (true) {
- int count = stream.read(array);
-
- isReading.set(true);
- if (count == 0)
- break;
-
- // now we deserialize them in separate threads to gain some speedup, if possible
- List threads = new ArrayList<>();
- AtomicInteger internalPosition = new AtomicInteger(0);
-
- for (int t = 0; t < workers; t++) {
- threads.add(t, new AsyncDeserializationThread(t, array, buffer, internalPosition, count));
- threads.get(t).start();
- }
-
- // we'll block this cycle untill all objects are fit into queue
- for (int t = 0; t < workers; t++) {
- try {
- threads.get(t).join();
- } catch (Exception e) {
- throw new RuntimeException(e);
- }
- }
-
- isReading.set(false);
- if (count < array.length)
- break;
- }
-
- } catch (Exception e) {
- isReading.set(false);
- throw new RuntimeException(e);
- }
- }
-
- public boolean hasMoreObjects() {
- try {
- return stream.available() > 0 || isReading.get();
- } catch (Exception e) {
- return false;
- } finally {
- }
- }
- }
-
- /**
- * Utility class that accepts byte array as input, and deserialize it into set of CoOccurrenceWeight objects
- */
- private class AsyncDeserializationThread extends Thread implements Runnable {
- private int threadId;
- private byte[] arrayReference;
- private ArrayBlockingQueue> targetBuffer;
- private AtomicInteger pointer;
- private int limit;
-
- public AsyncDeserializationThread(int threadId, @NonNull byte[] array,
- @NonNull ArrayBlockingQueue> targetBuffer,
- @NonNull AtomicInteger sharedPointer, int limit) {
- this.threadId = threadId;
- this.arrayReference = array;
- this.targetBuffer = targetBuffer;
- this.pointer = sharedPointer;
- this.limit = limit;
-
-
- setName("AsynDeserialization thread " + this.threadId);
- }
-
- @Override
- public void run() {
- ByteBuffer bB = ByteBuffer.wrap(arrayReference);
- int position = 0;
- while ((position = pointer.getAndAdd(16)) < this.limit) {
- if (position >= limit) {
- continue;
- }
-
-
- int e1idx = bB.getInt(position);
- int e2idx = bB.getInt(position + 4);
- double eW = bB.getDouble(position + 8);
-
-
- CoOccurrenceWeight object = new CoOccurrenceWeight<>();
- object.setElement1(vocabCache.elementAtIndex(e1idx));
- object.setElement2(vocabCache.elementAtIndex(e2idx));
-
- if (countMap != null) {
- double mW = countMap.getCount(object.getElement1(), object.getElement2());
-
- if (mW > 0) {
- eW += mW;
- countMap.removePair(object.getElement1(), object.getElement2());
- }
- }
- object.setWeight(eW);
-
- try {
- targetBuffer.put(object);
- } catch (Exception e) {
- throw new RuntimeException(e);
- }
- }
- }
- }
-}
diff --git a/deeplearning4j/deeplearning4j-nlp-parent/deeplearning4j-nlp/src/main/java/org/deeplearning4j/models/glove/count/BinaryCoOccurrenceWriter.java b/deeplearning4j/deeplearning4j-nlp-parent/deeplearning4j-nlp/src/main/java/org/deeplearning4j/models/glove/count/BinaryCoOccurrenceWriter.java
deleted file mode 100644
index 81230802e..000000000
--- a/deeplearning4j/deeplearning4j-nlp-parent/deeplearning4j-nlp/src/main/java/org/deeplearning4j/models/glove/count/BinaryCoOccurrenceWriter.java
+++ /dev/null
@@ -1,78 +0,0 @@
-/*******************************************************************************
- * Copyright (c) 2015-2018 Skymind, Inc.
- *
- * This program and the accompanying materials are made available under the
- * terms of the Apache License, Version 2.0 which is available at
- * https://www.apache.org/licenses/LICENSE-2.0.
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
- * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
- * License for the specific language governing permissions and limitations
- * under the License.
- *
- * SPDX-License-Identifier: Apache-2.0
- ******************************************************************************/
-
-package org.deeplearning4j.models.glove.count;
-
-import lombok.NonNull;
-import org.deeplearning4j.models.sequencevectors.sequence.SequenceElement;
-import org.slf4j.Logger;
-import org.slf4j.LoggerFactory;
-
-import java.io.BufferedOutputStream;
-import java.io.DataOutputStream;
-import java.io.File;
-import java.io.FileOutputStream;
-
-/**
- * @author raver119@gmail.com
- */
-public class BinaryCoOccurrenceWriter implements CoOccurrenceWriter {
- private File file;
- private DataOutputStream outputStream;
-
- private static final Logger log = LoggerFactory.getLogger(BinaryCoOccurrenceWriter.class);
-
- public BinaryCoOccurrenceWriter(@NonNull File file) {
- this.file = file;
-
- try {
- outputStream = new DataOutputStream(
- new BufferedOutputStream(new FileOutputStream(file), 100 * 1024 * 1024));
- } catch (Exception e) {
- throw new RuntimeException(e);
- }
- }
-
- @Override
- public void writeObject(@NonNull CoOccurrenceWeight object) {
- try {
- // log.info("Saving objects: { [" +object.getElement1().getIndex() +"], [" + object.getElement2().getIndex() + "] }");
- outputStream.writeInt(object.getElement1().getIndex());
- outputStream.writeInt(object.getElement2().getIndex());
- outputStream.writeDouble(object.getWeight());
- } catch (Exception e) {
- throw new RuntimeException(e);
- }
- }
-
- @Override
- public void queueObject(CoOccurrenceWeight object) {
- throw new UnsupportedOperationException();
- }
-
- @Override
- public void finish() {
- try {
- outputStream.flush();
- } catch (Exception e) {
- }
-
- try {
- outputStream.close();
- } catch (Exception e) {
- }
- }
-}
diff --git a/deeplearning4j/deeplearning4j-nlp-parent/deeplearning4j-nlp/src/main/java/org/deeplearning4j/models/glove/count/CoOccurenceReader.java b/deeplearning4j/deeplearning4j-nlp-parent/deeplearning4j-nlp/src/main/java/org/deeplearning4j/models/glove/count/CoOccurenceReader.java
deleted file mode 100644
index 0eaecc00b..000000000
--- a/deeplearning4j/deeplearning4j-nlp-parent/deeplearning4j-nlp/src/main/java/org/deeplearning4j/models/glove/count/CoOccurenceReader.java
+++ /dev/null
@@ -1,34 +0,0 @@
-/*******************************************************************************
- * Copyright (c) 2015-2018 Skymind, Inc.
- *
- * This program and the accompanying materials are made available under the
- * terms of the Apache License, Version 2.0 which is available at
- * https://www.apache.org/licenses/LICENSE-2.0.
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
- * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
- * License for the specific language governing permissions and limitations
- * under the License.
- *
- * SPDX-License-Identifier: Apache-2.0
- ******************************************************************************/
-
-package org.deeplearning4j.models.glove.count;
-
-import org.deeplearning4j.models.sequencevectors.sequence.SequenceElement;
-
-/**
- * Created by raver on 24.12.2015.
- */
-public interface CoOccurenceReader {
- /*
- Storage->Memory merging part
- */
- boolean hasMoreObjects();
-
-
- CoOccurrenceWeight nextObject();
-
- void finish();
-}
diff --git a/deeplearning4j/deeplearning4j-nlp-parent/deeplearning4j-nlp/src/main/java/org/deeplearning4j/models/glove/count/CoOccurrenceWeight.java b/deeplearning4j/deeplearning4j-nlp-parent/deeplearning4j-nlp/src/main/java/org/deeplearning4j/models/glove/count/CoOccurrenceWeight.java
deleted file mode 100644
index 251163e0a..000000000
--- a/deeplearning4j/deeplearning4j-nlp-parent/deeplearning4j-nlp/src/main/java/org/deeplearning4j/models/glove/count/CoOccurrenceWeight.java
+++ /dev/null
@@ -1,54 +0,0 @@
-/*******************************************************************************
- * Copyright (c) 2015-2018 Skymind, Inc.
- *
- * This program and the accompanying materials are made available under the
- * terms of the Apache License, Version 2.0 which is available at
- * https://www.apache.org/licenses/LICENSE-2.0.
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
- * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
- * License for the specific language governing permissions and limitations
- * under the License.
- *
- * SPDX-License-Identifier: Apache-2.0
- ******************************************************************************/
-
-package org.deeplearning4j.models.glove.count;
-
-import lombok.Data;
-import org.deeplearning4j.models.sequencevectors.sequence.SequenceElement;
-
-/**
- * Simple POJO holding pairs of elements and their respective weights, used in GloVe -> CoOccurrence
- *
- * @author raver119@gmail.com
- */
-@Data
-public class CoOccurrenceWeight {
- private T element1;
- private T element2;
- private double weight;
-
- @Override
- public boolean equals(Object o) {
- if (this == o)
- return true;
- if (o == null || getClass() != o.getClass())
- return false;
-
- CoOccurrenceWeight> that = (CoOccurrenceWeight>) o;
-
- if (element1 != null ? !element1.equals(that.element1) : that.element1 != null)
- return false;
- return element2 != null ? element2.equals(that.element2) : that.element2 == null;
-
- }
-
- @Override
- public int hashCode() {
- int result = element1 != null ? element1.hashCode() : 0;
- result = 31 * result + (element2 != null ? element2.hashCode() : 0);
- return result;
- }
-}
diff --git a/deeplearning4j/deeplearning4j-nlp-parent/deeplearning4j-nlp/src/main/java/org/deeplearning4j/models/glove/count/CoOccurrenceWriter.java b/deeplearning4j/deeplearning4j-nlp-parent/deeplearning4j-nlp/src/main/java/org/deeplearning4j/models/glove/count/CoOccurrenceWriter.java
deleted file mode 100644
index b7f7a21ea..000000000
--- a/deeplearning4j/deeplearning4j-nlp-parent/deeplearning4j-nlp/src/main/java/org/deeplearning4j/models/glove/count/CoOccurrenceWriter.java
+++ /dev/null
@@ -1,43 +0,0 @@
-/*******************************************************************************
- * Copyright (c) 2015-2018 Skymind, Inc.
- *
- * This program and the accompanying materials are made available under the
- * terms of the Apache License, Version 2.0 which is available at
- * https://www.apache.org/licenses/LICENSE-2.0.
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
- * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
- * License for the specific language governing permissions and limitations
- * under the License.
- *
- * SPDX-License-Identifier: Apache-2.0
- ******************************************************************************/
-
-package org.deeplearning4j.models.glove.count;
-
-import org.deeplearning4j.models.sequencevectors.sequence.SequenceElement;
-
-/**
- * Created by fartovii on 25.12.15.
- */
-public interface CoOccurrenceWriter {
-
- /**
- * This method implementations should write out objects immediately
- * @param object
- */
- void writeObject(CoOccurrenceWeight object);
-
- /**
- * This method implementations should queue objects for writing out.
- *
- * @param object
- */
- void queueObject(CoOccurrenceWeight object);
-
- /**
- * Implementations of this method should close everything they use, before eradication
- */
- void finish();
-}
diff --git a/deeplearning4j/deeplearning4j-nlp-parent/deeplearning4j-nlp/src/main/java/org/deeplearning4j/models/glove/count/CountMap.java b/deeplearning4j/deeplearning4j-nlp-parent/deeplearning4j-nlp/src/main/java/org/deeplearning4j/models/glove/count/CountMap.java
deleted file mode 100644
index 274551ebf..000000000
--- a/deeplearning4j/deeplearning4j-nlp-parent/deeplearning4j-nlp/src/main/java/org/deeplearning4j/models/glove/count/CountMap.java
+++ /dev/null
@@ -1,99 +0,0 @@
-/*******************************************************************************
- * Copyright (c) 2015-2018 Skymind, Inc.
- *
- * This program and the accompanying materials are made available under the
- * terms of the Apache License, Version 2.0 which is available at
- * https://www.apache.org/licenses/LICENSE-2.0.
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
- * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
- * License for the specific language governing permissions and limitations
- * under the License.
- *
- * SPDX-License-Identifier: Apache-2.0
- ******************************************************************************/
-
-package org.deeplearning4j.models.glove.count;
-
-import org.nd4j.shade.guava.util.concurrent.AtomicDouble;
-import org.deeplearning4j.models.sequencevectors.sequence.SequenceElement;
-import org.nd4j.common.primitives.Pair;
-
-import java.util.Iterator;
-import java.util.Map;
-import java.util.concurrent.ConcurrentHashMap;
-
-/**
- * Drop-in replacement for CounterMap
- *
- * WORK IN PROGRESS, PLEASE DO NOT USE
- *
- * @author raver119@gmail.com
- */
-public class CountMap {
- private volatile Map, AtomicDouble> backingMap = new ConcurrentHashMap<>();
-
- public CountMap() {
- // placeholder
- }
-
- public void incrementCount(T element1, T element2, double weight) {
- Pair tempEntry = new Pair<>(element1, element2);
- if (backingMap.containsKey(tempEntry)) {
- backingMap.get(tempEntry).addAndGet(weight);
- } else {
- backingMap.put(tempEntry, new AtomicDouble(weight));
- }
- }
-
- public void removePair(T element1, T element2) {
- Pair tempEntry = new Pair<>(element1, element2);
- backingMap.remove(tempEntry);
- }
-
- public void removePair(Pair pair) {
- backingMap.remove(pair);
- }
-
- public double getCount(T element1, T element2) {
- Pair tempEntry = new Pair<>(element1, element2);
- if (backingMap.containsKey(tempEntry)) {
- return backingMap.get(tempEntry).get();
- } else
- return 0;
- }
-
- public double getCount(Pair pair) {
- if (backingMap.containsKey(pair)) {
- return backingMap.get(pair).get();
- } else
- return 0;
- }
-
- public Iterator> getPairIterator() {
- return new Iterator>() {
- private Iterator> iterator = backingMap.keySet().iterator();
-
- @Override
- public boolean hasNext() {
- return iterator.hasNext();
- }
-
- @Override
- public Pair next() {
- //MapEntry entry = iterator.next();
- return iterator.next(); //new Pair<>(entry.getElement1(), entry.getElement2());
- }
-
- @Override
- public void remove() {
- throw new UnsupportedOperationException("remove() isn't supported here");
- }
- };
- }
-
- public int size() {
- return backingMap.size();
- }
-}
diff --git a/deeplearning4j/deeplearning4j-nlp-parent/deeplearning4j-nlp/src/main/java/org/deeplearning4j/models/glove/count/RoundCount.java b/deeplearning4j/deeplearning4j-nlp-parent/deeplearning4j-nlp/src/main/java/org/deeplearning4j/models/glove/count/RoundCount.java
deleted file mode 100644
index d9f729dba..000000000
--- a/deeplearning4j/deeplearning4j-nlp-parent/deeplearning4j-nlp/src/main/java/org/deeplearning4j/models/glove/count/RoundCount.java
+++ /dev/null
@@ -1,86 +0,0 @@
-/*******************************************************************************
- * Copyright (c) 2015-2018 Skymind, Inc.
- *
- * This program and the accompanying materials are made available under the
- * terms of the Apache License, Version 2.0 which is available at
- * https://www.apache.org/licenses/LICENSE-2.0.
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
- * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
- * License for the specific language governing permissions and limitations
- * under the License.
- *
- * SPDX-License-Identifier: Apache-2.0
- ******************************************************************************/
-
-package org.deeplearning4j.models.glove.count;
-
-import java.util.concurrent.locks.ReentrantReadWriteLock;
-
-/**
- * Simple circular counter, that circulates within 0...Limit, both inclusive
- *
- * @author raver119@gmail.com
- */
-public class RoundCount {
-
- private int limit = 0;
- private int lower = 0;
- private int value = 0;
-
- private ReentrantReadWriteLock lock = new ReentrantReadWriteLock();
-
- /**
- * Creates new RoundCount instance.
- *
- * @param limit Maximum top value for this counter. Inclusive.
- */
- public RoundCount(int limit) {
- this.limit = limit;
- }
-
- /**
- * Creates new RoundCount instance.
- *
- * @param lower - Minimum value for this counter. Inclusive
- * @param top - Maximum value for this counter. Inclusive.
- */
- public RoundCount(int lower, int top) {
- this.limit = top;
- this.lower = lower;
- }
-
- public int previous() {
- try {
- lock.readLock().lock();
- if (value == lower)
- return limit;
- else
- return value - 1;
- } finally {
- lock.readLock().unlock();
- }
- }
-
- public int get() {
- try {
- lock.readLock().lock();
- return value;
- } finally {
- lock.readLock().unlock();
- }
- }
-
- public void tick() {
- try {
- lock.writeLock().lock();
- if (value == limit)
- value = lower;
- else
- value++;
- } finally {
- lock.writeLock().unlock();
- }
- }
-}
diff --git a/deeplearning4j/deeplearning4j-nlp-parent/deeplearning4j-nlp/src/main/java/org/deeplearning4j/models/paragraphvectors/ParagraphVectors.java b/deeplearning4j/deeplearning4j-nlp-parent/deeplearning4j-nlp/src/main/java/org/deeplearning4j/models/paragraphvectors/ParagraphVectors.java
index c007d4b96..e27debd50 100644
--- a/deeplearning4j/deeplearning4j-nlp-parent/deeplearning4j-nlp/src/main/java/org/deeplearning4j/models/paragraphvectors/ParagraphVectors.java
+++ b/deeplearning4j/deeplearning4j-nlp-parent/deeplearning4j-nlp/src/main/java/org/deeplearning4j/models/paragraphvectors/ParagraphVectors.java
@@ -763,7 +763,7 @@ public class ParagraphVectors extends Word2Vec {
/**
- * This method allows you to use pre-built WordVectors model (Word2Vec or GloVe) for ParagraphVectors.
+ * This method allows you to use pre-built WordVectors model (e.g. Word2Vec) for ParagraphVectors.
* Existing model will be transferred into new model before training starts.
*
* PLEASE NOTE: Non-normalized model is recommended to use here.
diff --git a/deeplearning4j/deeplearning4j-nlp-parent/deeplearning4j-nlp/src/main/java/org/deeplearning4j/models/sequencevectors/SequenceVectors.java b/deeplearning4j/deeplearning4j-nlp-parent/deeplearning4j-nlp/src/main/java/org/deeplearning4j/models/sequencevectors/SequenceVectors.java
index d31cc51b0..0e104bb20 100644
--- a/deeplearning4j/deeplearning4j-nlp-parent/deeplearning4j-nlp/src/main/java/org/deeplearning4j/models/sequencevectors/SequenceVectors.java
+++ b/deeplearning4j/deeplearning4j-nlp-parent/deeplearning4j-nlp/src/main/java/org/deeplearning4j/models/sequencevectors/SequenceVectors.java
@@ -520,7 +520,7 @@ public class SequenceVectors extends WordVectorsImpl<
}
/**
- * This method allows you to use pre-built WordVectors model (SkipGram or GloVe) for DBOW sequence learning.
+ * This method allows you to use pre-built WordVectors model (e.g. SkipGram) for DBOW sequence learning.
* Existing model will be transferred into new model before training starts.
*
* PLEASE NOTE: This model has no effect for elements learning algorithms. Only sequence learning is affected.
diff --git a/deeplearning4j/deeplearning4j-nlp-parent/deeplearning4j-nlp/src/test/java/org/deeplearning4j/models/glove/AbstractCoOccurrencesTest.java b/deeplearning4j/deeplearning4j-nlp-parent/deeplearning4j-nlp/src/test/java/org/deeplearning4j/models/glove/AbstractCoOccurrencesTest.java
deleted file mode 100644
index 8d59b2a5a..000000000
--- a/deeplearning4j/deeplearning4j-nlp-parent/deeplearning4j-nlp/src/test/java/org/deeplearning4j/models/glove/AbstractCoOccurrencesTest.java
+++ /dev/null
@@ -1,101 +0,0 @@
-/*******************************************************************************
- * Copyright (c) 2015-2018 Skymind, Inc.
- *
- * This program and the accompanying materials are made available under the
- * terms of the Apache License, Version 2.0 which is available at
- * https://www.apache.org/licenses/LICENSE-2.0.
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
- * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
- * License for the specific language governing permissions and limitations
- * under the License.
- *
- * SPDX-License-Identifier: Apache-2.0
- ******************************************************************************/
-
-package org.deeplearning4j.models.glove;
-
-import org.deeplearning4j.BaseDL4JTest;
-import org.nd4j.common.io.ClassPathResource;
-import org.deeplearning4j.models.sequencevectors.iterators.AbstractSequenceIterator;
-import org.deeplearning4j.models.sequencevectors.transformers.impl.SentenceTransformer;
-import org.deeplearning4j.models.word2vec.VocabWord;
-import org.deeplearning4j.models.word2vec.wordstore.VocabConstructor;
-import org.deeplearning4j.models.word2vec.wordstore.inmemory.AbstractCache;
-import org.deeplearning4j.text.sentenceiterator.BasicLineIterator;
-import org.deeplearning4j.text.tokenization.tokenizer.preprocessor.CommonPreprocessor;
-import org.deeplearning4j.text.tokenization.tokenizerfactory.DefaultTokenizerFactory;
-import org.deeplearning4j.text.tokenization.tokenizerfactory.TokenizerFactory;
-import org.junit.Before;
-import org.junit.Test;
-import org.nd4j.common.primitives.Pair;
-import org.slf4j.Logger;
-import org.slf4j.LoggerFactory;
-
-import java.io.File;
-import java.util.ArrayList;
-import java.util.Iterator;
-import java.util.List;
-
-import static org.junit.Assert.assertEquals;
-import static org.junit.Assert.assertNotEquals;
-
-/**
- * @author raver119@gmail.com
- */
-public class AbstractCoOccurrencesTest extends BaseDL4JTest {
-
- private static final Logger log = LoggerFactory.getLogger(AbstractCoOccurrencesTest.class);
-
- @Before
- public void setUp() throws Exception {
-
- }
-
- @Test
- public void testFit1() throws Exception {
- ClassPathResource resource = new ClassPathResource("other/oneline.txt");
- File file = resource.getFile();
-
- AbstractCache vocabCache = new AbstractCache.Builder().build();
- BasicLineIterator underlyingIterator = new BasicLineIterator(file);
-
- TokenizerFactory t = new DefaultTokenizerFactory();
- t.setTokenPreProcessor(new CommonPreprocessor());
-
- SentenceTransformer transformer =
- new SentenceTransformer.Builder().iterator(underlyingIterator).tokenizerFactory(t).build();
-
- AbstractSequenceIterator sequenceIterator =
- new AbstractSequenceIterator.Builder<>(transformer).build();
-
- VocabConstructor constructor = new VocabConstructor.Builder()
- .addSource(sequenceIterator, 1).setTargetVocabCache(vocabCache).build();
-
- constructor.buildJointVocabulary(false, true);
-
- AbstractCoOccurrences coOccurrences = new AbstractCoOccurrences.Builder()
- .iterate(sequenceIterator).vocabCache(vocabCache).symmetric(false).windowSize(15).build();
-
- coOccurrences.fit();
-
- //List> list = coOccurrences.i();
- Iterator, Double>> iterator = coOccurrences.iterator();
- assertNotEquals(null, iterator);
- int cnt = 0;
-
- List> list = new ArrayList<>();
- while (iterator.hasNext()) {
- Pair, Double> pair = iterator.next();
- list.add(pair.getFirst());
- cnt++;
- }
-
-
- log.info("CoOccurrences: " + list);
-
- assertEquals(16, list.size());
- assertEquals(16, cnt);
- }
-}
diff --git a/deeplearning4j/deeplearning4j-nlp-parent/deeplearning4j-nlp/src/test/java/org/deeplearning4j/models/glove/GloveTest.java b/deeplearning4j/deeplearning4j-nlp-parent/deeplearning4j-nlp/src/test/java/org/deeplearning4j/models/glove/GloveTest.java
deleted file mode 100644
index 39aa40d10..000000000
--- a/deeplearning4j/deeplearning4j-nlp-parent/deeplearning4j-nlp/src/test/java/org/deeplearning4j/models/glove/GloveTest.java
+++ /dev/null
@@ -1,137 +0,0 @@
-/*******************************************************************************
- * Copyright (c) 2015-2018 Skymind, Inc.
- *
- * This program and the accompanying materials are made available under the
- * terms of the Apache License, Version 2.0 which is available at
- * https://www.apache.org/licenses/LICENSE-2.0.
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
- * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
- * License for the specific language governing permissions and limitations
- * under the License.
- *
- * SPDX-License-Identifier: Apache-2.0
- ******************************************************************************/
-
-package org.deeplearning4j.models.glove;
-
-import org.deeplearning4j.BaseDL4JTest;
-import org.nd4j.common.io.ClassPathResource;
-import org.deeplearning4j.models.embeddings.loader.WordVectorSerializer;
-import org.deeplearning4j.models.embeddings.wordvectors.WordVectors;
-import org.deeplearning4j.text.sentenceiterator.BasicLineIterator;
-import org.deeplearning4j.text.sentenceiterator.LineSentenceIterator;
-import org.deeplearning4j.text.sentenceiterator.SentenceIterator;
-import org.deeplearning4j.text.sentenceiterator.SentencePreProcessor;
-import org.deeplearning4j.text.tokenization.tokenizer.preprocessor.CommonPreprocessor;
-import org.deeplearning4j.text.tokenization.tokenizerfactory.DefaultTokenizerFactory;
-import org.deeplearning4j.text.tokenization.tokenizerfactory.TokenizerFactory;
-import org.junit.Before;
-import org.junit.Ignore;
-import org.junit.Test;
-import org.nd4j.linalg.api.ndarray.INDArray;
-import org.nd4j.common.resources.Resources;
-import org.slf4j.Logger;
-import org.slf4j.LoggerFactory;
-
-import java.io.File;
-import java.util.Collection;
-
-import static org.junit.Assert.assertEquals;
-import static org.junit.Assert.assertTrue;
-
-/**
- * Created by agibsonccc on 12/3/14.
- */
-public class GloveTest extends BaseDL4JTest {
- private static final Logger log = LoggerFactory.getLogger(GloveTest.class);
- private Glove glove;
- private SentenceIterator iter;
-
- @Before
- public void before() throws Exception {
-
- ClassPathResource resource = new ClassPathResource("/raw_sentences.txt");
- File file = resource.getFile();
- iter = new LineSentenceIterator(file);
- iter.setPreProcessor(new SentencePreProcessor() {
- @Override
- public String preProcess(String sentence) {
- return sentence.toLowerCase();
- }
- });
-
- }
-
-
- @Ignore
- @Test
- public void testGlove() throws Exception {
- /*
- glove = new Glove.Builder().iterate(iter).symmetric(true).shuffle(true)
- .minWordFrequency(1).iterations(10).learningRate(0.1)
- .layerSize(300)
- .build();
-
- glove.fit();
- Collection words = glove.wordsNearest("day", 20);
- log.info("Nearest words to 'day': " + words);
- assertTrue(words.contains("week"));
-
- */
-
- }
-
- @Ignore
- @Test
- public void testGloVe1() throws Exception {
- File inputFile = Resources.asFile("big/raw_sentences.txt");
-
- SentenceIterator iter = new BasicLineIterator(inputFile.getAbsolutePath());
- // Split on white spaces in the line to get words
- TokenizerFactory t = new DefaultTokenizerFactory();
- t.setTokenPreProcessor(new CommonPreprocessor());
-
- Glove glove = new Glove.Builder().iterate(iter).tokenizerFactory(t).alpha(0.75).learningRate(0.1).epochs(45)
- .xMax(100).shuffle(true).symmetric(true).build();
-
- glove.fit();
-
- double simD = glove.similarity("day", "night");
- double simP = glove.similarity("best", "police");
-
-
-
- log.info("Day/night similarity: " + simD);
- log.info("Best/police similarity: " + simP);
-
- Collection words = glove.wordsNearest("day", 10);
- log.info("Nearest words to 'day': " + words);
-
-
- assertTrue(simD > 0.7);
-
- // actually simP should be somewhere at 0
- assertTrue(simP < 0.5);
-
- assertTrue(words.contains("night"));
- assertTrue(words.contains("year"));
- assertTrue(words.contains("week"));
-
- File tempFile = File.createTempFile("glove", "temp");
- tempFile.deleteOnExit();
-
- INDArray day1 = glove.getWordVectorMatrix("day").dup();
-
- WordVectorSerializer.writeWordVectors(glove, tempFile);
-
- WordVectors vectors = WordVectorSerializer.loadTxtVectors(tempFile);
-
- INDArray day2 = vectors.getWordVectorMatrix("day").dup();
-
- assertEquals(day1, day2);
-
- tempFile.delete();
- }
-}
diff --git a/deeplearning4j/deeplearning4j-nlp-parent/deeplearning4j-nlp/src/test/java/org/deeplearning4j/models/glove/count/BinaryCoOccurrenceReaderTest.java b/deeplearning4j/deeplearning4j-nlp-parent/deeplearning4j-nlp/src/test/java/org/deeplearning4j/models/glove/count/BinaryCoOccurrenceReaderTest.java
deleted file mode 100644
index 7f357a901..000000000
--- a/deeplearning4j/deeplearning4j-nlp-parent/deeplearning4j-nlp/src/test/java/org/deeplearning4j/models/glove/count/BinaryCoOccurrenceReaderTest.java
+++ /dev/null
@@ -1,156 +0,0 @@
-/*******************************************************************************
- * Copyright (c) 2015-2018 Skymind, Inc.
- *
- * This program and the accompanying materials are made available under the
- * terms of the Apache License, Version 2.0 which is available at
- * https://www.apache.org/licenses/LICENSE-2.0.
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
- * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
- * License for the specific language governing permissions and limitations
- * under the License.
- *
- * SPDX-License-Identifier: Apache-2.0
- ******************************************************************************/
-
-package org.deeplearning4j.models.glove.count;
-
-import org.deeplearning4j.BaseDL4JTest;
-import org.deeplearning4j.models.word2vec.Huffman;
-import org.deeplearning4j.models.word2vec.VocabWord;
-import org.deeplearning4j.models.word2vec.wordstore.VocabCache;
-import org.deeplearning4j.models.word2vec.wordstore.inmemory.AbstractCache;
-import org.junit.Before;
-import org.junit.Test;
-import org.slf4j.Logger;
-import org.slf4j.LoggerFactory;
-
-import java.io.File;
-
-import static org.junit.Assert.assertNotEquals;
-
-/**
- * Created by fartovii on 25.12.15.
- */
-public class BinaryCoOccurrenceReaderTest extends BaseDL4JTest {
-
- private static final Logger log = LoggerFactory.getLogger(BinaryCoOccurrenceReaderTest.class);
-
- @Before
- public void setUp() throws Exception {
-
- }
-
- @Test
- public void testHasMoreObjects1() throws Exception {
- File tempFile = File.createTempFile("tmp", "tmp");
- tempFile.deleteOnExit();
-
- VocabCache vocabCache = new AbstractCache.Builder().build();
-
- VocabWord word1 = new VocabWord(1.0, "human");
- VocabWord word2 = new VocabWord(2.0, "animal");
- VocabWord word3 = new VocabWord(3.0, "unknown");
-
- vocabCache.addToken(word1);
- vocabCache.addToken(word2);
- vocabCache.addToken(word3);
-
- Huffman huffman = new Huffman(vocabCache.vocabWords());
- huffman.build();
- huffman.applyIndexes(vocabCache);
-
-
- BinaryCoOccurrenceWriter writer = new BinaryCoOccurrenceWriter<>(tempFile);
-
- CoOccurrenceWeight object1 = new CoOccurrenceWeight<>();
- object1.setElement1(word1);
- object1.setElement2(word2);
- object1.setWeight(3.14159265);
-
- writer.writeObject(object1);
-
- CoOccurrenceWeight object2 = new CoOccurrenceWeight<>();
- object2.setElement1(word2);
- object2.setElement2(word3);
- object2.setWeight(0.197);
-
- writer.writeObject(object2);
-
- writer.finish();
-
- BinaryCoOccurrenceReader reader = new BinaryCoOccurrenceReader<>(tempFile, vocabCache, null);
-
-
- CoOccurrenceWeight r1 = reader.nextObject();
- log.info("Object received: " + r1);
- assertNotEquals(null, r1);
-
- r1 = reader.nextObject();
- log.info("Object received: " + r1);
- assertNotEquals(null, r1);
- }
-
- @Test
- public void testHasMoreObjects2() throws Exception {
- File tempFile = File.createTempFile("tmp", "tmp");
- tempFile.deleteOnExit();
-
- VocabCache vocabCache = new AbstractCache.Builder().build();
-
- VocabWord word1 = new VocabWord(1.0, "human");
- VocabWord word2 = new VocabWord(2.0, "animal");
- VocabWord word3 = new VocabWord(3.0, "unknown");
-
- vocabCache.addToken(word1);
- vocabCache.addToken(word2);
- vocabCache.addToken(word3);
-
- Huffman huffman = new Huffman(vocabCache.vocabWords());
- huffman.build();
- huffman.applyIndexes(vocabCache);
-
-
- BinaryCoOccurrenceWriter writer = new BinaryCoOccurrenceWriter<>(tempFile);
-
- CoOccurrenceWeight object1 = new CoOccurrenceWeight<>();
- object1.setElement1(word1);
- object1.setElement2(word2);
- object1.setWeight(3.14159265);
-
- writer.writeObject(object1);
-
- CoOccurrenceWeight object2 = new CoOccurrenceWeight<>();
- object2.setElement1(word2);
- object2.setElement2(word3);
- object2.setWeight(0.197);
-
- writer.writeObject(object2);
-
- CoOccurrenceWeight object3 = new CoOccurrenceWeight<>();
- object3.setElement1(word1);
- object3.setElement2(word3);
- object3.setWeight(0.001);
-
- writer.writeObject(object3);
-
- writer.finish();
-
- BinaryCoOccurrenceReader reader = new BinaryCoOccurrenceReader<>(tempFile, vocabCache, null);
-
-
- CoOccurrenceWeight r1 = reader.nextObject();
- log.info("Object received: " + r1);
- assertNotEquals(null, r1);
-
- r1 = reader.nextObject();
- log.info("Object received: " + r1);
- assertNotEquals(null, r1);
-
- r1 = reader.nextObject();
- log.info("Object received: " + r1);
- assertNotEquals(null, r1);
-
- }
-}
diff --git a/deeplearning4j/deeplearning4j-nlp-parent/deeplearning4j-nlp/src/test/java/org/deeplearning4j/models/glove/count/RoundCountTest.java b/deeplearning4j/deeplearning4j-nlp-parent/deeplearning4j-nlp/src/test/java/org/deeplearning4j/models/glove/count/RoundCountTest.java
deleted file mode 100644
index 737533648..000000000
--- a/deeplearning4j/deeplearning4j-nlp-parent/deeplearning4j-nlp/src/test/java/org/deeplearning4j/models/glove/count/RoundCountTest.java
+++ /dev/null
@@ -1,90 +0,0 @@
-/*******************************************************************************
- * Copyright (c) 2015-2018 Skymind, Inc.
- *
- * This program and the accompanying materials are made available under the
- * terms of the Apache License, Version 2.0 which is available at
- * https://www.apache.org/licenses/LICENSE-2.0.
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
- * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
- * License for the specific language governing permissions and limitations
- * under the License.
- *
- * SPDX-License-Identifier: Apache-2.0
- ******************************************************************************/
-
-package org.deeplearning4j.models.glove.count;
-
-import org.deeplearning4j.BaseDL4JTest;
-import org.junit.Before;
-import org.junit.Test;
-
-import static org.junit.Assert.assertEquals;
-
-/**
- * Created by fartovii on 23.12.15.
- */
-public class RoundCountTest extends BaseDL4JTest {
-
- @Before
- public void setUp() throws Exception {
-
- }
-
- @Test
- public void testGet1() throws Exception {
- RoundCount count = new RoundCount(1);
-
- assertEquals(0, count.get());
-
- count.tick();
- assertEquals(1, count.get());
-
- count.tick();
- assertEquals(0, count.get());
- }
-
- @Test
- public void testGet2() throws Exception {
- RoundCount count = new RoundCount(3);
-
- assertEquals(0, count.get());
-
- count.tick();
- assertEquals(1, count.get());
-
- count.tick();
- assertEquals(2, count.get());
-
- count.tick();
- assertEquals(3, count.get());
-
- count.tick();
- assertEquals(0, count.get());
- }
-
- @Test
- public void testPrevious1() throws Exception {
- RoundCount count = new RoundCount(3);
-
- assertEquals(0, count.get());
- assertEquals(3, count.previous());
-
- count.tick();
- assertEquals(1, count.get());
- assertEquals(0, count.previous());
-
- count.tick();
- assertEquals(2, count.get());
- assertEquals(1, count.previous());
-
- count.tick();
- assertEquals(3, count.get());
- assertEquals(2, count.previous());
-
- count.tick();
- assertEquals(0, count.get());
- assertEquals(3, count.previous());
- }
-}
diff --git a/deeplearning4j/deeplearning4j-nlp-parent/deeplearning4j-nlp/src/test/java/org/deeplearning4j/models/sequencevectors/SequenceVectorsTest.java b/deeplearning4j/deeplearning4j-nlp-parent/deeplearning4j-nlp/src/test/java/org/deeplearning4j/models/sequencevectors/SequenceVectorsTest.java
index b3ba6e198..6ec46bb7d 100644
--- a/deeplearning4j/deeplearning4j-nlp-parent/deeplearning4j-nlp/src/test/java/org/deeplearning4j/models/sequencevectors/SequenceVectorsTest.java
+++ b/deeplearning4j/deeplearning4j-nlp-parent/deeplearning4j-nlp/src/test/java/org/deeplearning4j/models/sequencevectors/SequenceVectorsTest.java
@@ -21,12 +21,10 @@ import lombok.Getter;
import lombok.Setter;
import org.datavec.api.records.reader.impl.csv.CSVRecordReader;
import org.datavec.api.split.FileSplit;
-import org.deeplearning4j.BaseDL4JTest;
-import org.nd4j.common.io.ClassPathResource;
import org.datavec.api.writable.Writable;
+import org.deeplearning4j.BaseDL4JTest;
import org.deeplearning4j.models.embeddings.WeightLookupTable;
import org.deeplearning4j.models.embeddings.inmemory.InMemoryLookupTable;
-import org.deeplearning4j.models.embeddings.learning.impl.elements.GloVe;
import org.deeplearning4j.models.embeddings.learning.impl.elements.SkipGram;
import org.deeplearning4j.models.embeddings.loader.VectorsConfiguration;
import org.deeplearning4j.models.embeddings.loader.WordVectorSerializer;
@@ -55,6 +53,7 @@ import org.deeplearning4j.text.tokenization.tokenizerfactory.TokenizerFactory;
import org.junit.Before;
import org.junit.Ignore;
import org.junit.Test;
+import org.nd4j.common.io.ClassPathResource;
import org.nd4j.linalg.factory.Nd4j;
import org.nd4j.linalg.heartbeat.Heartbeat;
import org.slf4j.Logger;
@@ -270,65 +269,6 @@ public class SequenceVectorsTest extends BaseDL4JTest {
.epochs(1).resetModel(false).trainElementsRepresentation(false).build();
}
- @Ignore
- @Test
- public void testGlove1() throws Exception {
- logger.info("Max available memory: " + Runtime.getRuntime().maxMemory());
- ClassPathResource resource = new ClassPathResource("big/raw_sentences.txt");
- File file = resource.getFile();
-
- BasicLineIterator underlyingIterator = new BasicLineIterator(file);
-
- TokenizerFactory t = new DefaultTokenizerFactory();
- t.setTokenPreProcessor(new CommonPreprocessor());
-
- SentenceTransformer transformer =
- new SentenceTransformer.Builder().iterator(underlyingIterator).tokenizerFactory(t).build();
-
- AbstractSequenceIterator sequenceIterator =
- new AbstractSequenceIterator.Builder<>(transformer).build();
-
- VectorsConfiguration configuration = new VectorsConfiguration();
- configuration.setWindow(5);
- configuration.setLearningRate(0.06);
- configuration.setLayersSize(100);
-
-
- SequenceVectors vectors = new SequenceVectors.Builder(configuration)
- .iterate(sequenceIterator).iterations(1).epochs(45)
- .elementsLearningAlgorithm(new GloVe.Builder().shuffle(true).symmetric(true)
- .learningRate(0.05).alpha(0.75).xMax(100.0).build())
- .resetModel(true).trainElementsRepresentation(true).trainSequencesRepresentation(false).build();
-
- vectors.fit();
-
- double sim = vectors.similarity("day", "night");
- logger.info("Day/night similarity: " + sim);
-
-
- sim = vectors.similarity("day", "another");
- logger.info("Day/another similarity: " + sim);
-
- sim = vectors.similarity("night", "year");
- logger.info("Night/year similarity: " + sim);
-
- sim = vectors.similarity("night", "me");
- logger.info("Night/me similarity: " + sim);
-
- sim = vectors.similarity("day", "know");
- logger.info("Day/know similarity: " + sim);
-
- sim = vectors.similarity("best", "police");
- logger.info("Best/police similarity: " + sim);
-
- Collection labels = vectors.wordsNearest("day", 10);
- logger.info("Nearest labels to 'day': " + labels);
-
-
- sim = vectors.similarity("day", "night");
- assertTrue(sim > 0.6d);
- }
-
@Test
@Ignore
public void testDeepWalk() throws Exception {
diff --git a/deeplearning4j/deeplearning4j-nlp-parent/deeplearning4j-nlp/src/test/java/org/deeplearning4j/models/word2vec/Word2VecTestsSmall.java b/deeplearning4j/deeplearning4j-nlp-parent/deeplearning4j-nlp/src/test/java/org/deeplearning4j/models/word2vec/Word2VecTestsSmall.java
index 38b44d1ff..b8b30c6c9 100644
--- a/deeplearning4j/deeplearning4j-nlp-parent/deeplearning4j-nlp/src/test/java/org/deeplearning4j/models/word2vec/Word2VecTestsSmall.java
+++ b/deeplearning4j/deeplearning4j-nlp-parent/deeplearning4j-nlp/src/test/java/org/deeplearning4j/models/word2vec/Word2VecTestsSmall.java
@@ -49,8 +49,8 @@ import java.io.File;
import java.util.Collection;
import java.util.concurrent.Callable;
-import static org.awaitility.Awaitility.await;
import static org.junit.Assert.assertEquals;
+import static org.junit.Assert.assertTrue;
@Slf4j
@@ -206,12 +206,6 @@ public class Word2VecTestsSmall extends BaseDL4JTest {
final MultiLayerNetwork restored = ModelSerializer.restoreMultiLayerNetwork(bais, true);
assertEquals(net.getLayerWiseConfigurations(), restored.getLayerWiseConfigurations());
- await()
- .until(new Callable() {
- @Override
- public Boolean call() {
- return net.params().equalsWithEps(restored.params(), 2e-3);
- }
- });
+ assertTrue(net.params().equalsWithEps(restored.params(), 2e-3));
}
}
diff --git a/deeplearning4j/deeplearning4j-nn/src/main/java/org/deeplearning4j/nn/layers/BaseOutputLayer.java b/deeplearning4j/deeplearning4j-nn/src/main/java/org/deeplearning4j/nn/layers/BaseOutputLayer.java
index b5c68c884..a10bb33f3 100644
--- a/deeplearning4j/deeplearning4j-nn/src/main/java/org/deeplearning4j/nn/layers/BaseOutputLayer.java
+++ b/deeplearning4j/deeplearning4j-nn/src/main/java/org/deeplearning4j/nn/layers/BaseOutputLayer.java
@@ -25,6 +25,7 @@ import org.deeplearning4j.nn.params.DefaultParamInitializer;
import org.deeplearning4j.nn.workspace.ArrayType;
import org.deeplearning4j.nn.workspace.LayerWorkspaceMgr;
import org.deeplearning4j.optimize.Solver;
+import org.nd4j.common.base.Preconditions;
import org.nd4j.evaluation.classification.Evaluation;
import org.nd4j.linalg.api.buffer.DataType;
import org.nd4j.linalg.api.ndarray.INDArray;
@@ -247,10 +248,8 @@ public abstract class BaseOutputLayer Integer.MAX_VALUE)
throw new ND4JArraySizeException();
- int[] ret = new int[(int) d.size(0)];
- if (d.isRowVectorOrScalar())
- ret[0] = Nd4j.getBlasWrapper().iamax(output);
- else {
- for (int i = 0; i < ret.length; i++)
- ret[i] = Nd4j.getBlasWrapper().iamax(output.getRow(i));
- }
- return ret;
+ Preconditions.checkState(output.rank() == 2, "predict(INDArray) method can only be used on rank 2 output - got array with rank %s", output.rank());
+ return output.argMax(1).toIntVector();
}
/**
diff --git a/deeplearning4j/deeplearning4j-scaleout/spark/dl4j-spark-nlp/src/main/java/org/deeplearning4j/spark/models/embeddings/glove/Glove.java b/deeplearning4j/deeplearning4j-scaleout/spark/dl4j-spark-nlp/src/main/java/org/deeplearning4j/spark/models/embeddings/glove/Glove.java
deleted file mode 100644
index 81de8effb..000000000
--- a/deeplearning4j/deeplearning4j-scaleout/spark/dl4j-spark-nlp/src/main/java/org/deeplearning4j/spark/models/embeddings/glove/Glove.java
+++ /dev/null
@@ -1,280 +0,0 @@
-/*******************************************************************************
- * Copyright (c) 2015-2018 Skymind, Inc.
- *
- * This program and the accompanying materials are made available under the
- * terms of the Apache License, Version 2.0 which is available at
- * https://www.apache.org/licenses/LICENSE-2.0.
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
- * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
- * License for the specific language governing permissions and limitations
- * under the License.
- *
- * SPDX-License-Identifier: Apache-2.0
- ******************************************************************************/
-
-package org.deeplearning4j.spark.models.embeddings.glove;
-
-import org.apache.commons.math3.util.FastMath;
-import org.apache.spark.SparkConf;
-import org.apache.spark.api.java.JavaPairRDD;
-import org.apache.spark.api.java.JavaRDD;
-import org.apache.spark.api.java.JavaSparkContext;
-import org.apache.spark.api.java.function.Function;
-import org.apache.spark.api.java.function.PairFunction;
-import org.apache.spark.broadcast.Broadcast;
-import org.deeplearning4j.models.glove.GloveWeightLookupTable;
-import org.deeplearning4j.models.word2vec.VocabWord;
-import org.deeplearning4j.models.word2vec.wordstore.VocabCache;
-import org.deeplearning4j.spark.models.embeddings.glove.cooccurrences.CoOccurrenceCalculator;
-import org.deeplearning4j.spark.models.embeddings.glove.cooccurrences.CoOccurrenceCounts;
-import org.deeplearning4j.spark.text.functions.TextPipeline;
-import org.deeplearning4j.text.tokenization.tokenizerfactory.DefaultTokenizerFactory;
-import org.nd4j.linalg.api.ndarray.INDArray;
-import org.nd4j.linalg.factory.Nd4j;
-import org.nd4j.linalg.learning.legacy.AdaGrad;
-import org.nd4j.common.primitives.CounterMap;
-import org.nd4j.common.primitives.Pair;
-import org.nd4j.common.primitives.Triple;
-import org.slf4j.Logger;
-import org.slf4j.LoggerFactory;
-import scala.Tuple2;
-
-import java.io.Serializable;
-import java.util.*;
-import java.util.concurrent.atomic.AtomicLong;
-
-import static org.deeplearning4j.spark.models.embeddings.word2vec.Word2VecVariables.*;
-
-/**
- * Spark glove
- *
- * @author Adam Gibson
- */
-public class Glove implements Serializable {
-
- private Broadcast> vocabCacheBroadcast;
- private String tokenizerFactoryClazz = DefaultTokenizerFactory.class.getName();
- private boolean symmetric = true;
- private int windowSize = 15;
- private int iterations = 300;
- private static Logger log = LoggerFactory.getLogger(Glove.class);
-
- /**
- *
- * @param tokenizerFactoryClazz the fully qualified class name of the tokenizer
- * @param symmetric whether the co occurrence counts should be symmetric
- * @param windowSize the window size for co occurrence
- * @param iterations the number of iterations
- */
- public Glove(String tokenizerFactoryClazz, boolean symmetric, int windowSize, int iterations) {
- this.tokenizerFactoryClazz = tokenizerFactoryClazz;
- this.symmetric = symmetric;
- this.windowSize = windowSize;
- this.iterations = iterations;
- }
-
- /**
- *
- * @param symmetric whether the co occurrence counts should be symmetric
- * @param windowSize the window size for co occurrence
- * @param iterations the number of iterations
- */
- public Glove(boolean symmetric, int windowSize, int iterations) {
- this.symmetric = symmetric;
- this.windowSize = windowSize;
- this.iterations = iterations;
- }
-
-
- private Pair update(AdaGrad weightAdaGrad, AdaGrad biasAdaGrad, INDArray syn0, INDArray bias,
- VocabWord w1, INDArray wordVector, INDArray contextVector, double gradient) {
- //gradient for word vectors
- INDArray grad1 = contextVector.mul(gradient);
- INDArray update = weightAdaGrad.getGradient(grad1, w1.getIndex(), syn0.shape());
- wordVector.subi(update);
-
- double w1Bias = bias.getDouble(w1.getIndex());
- double biasGradient = biasAdaGrad.getGradient(gradient, w1.getIndex(), bias.shape());
- double update2 = w1Bias - biasGradient;
- bias.putScalar(w1.getIndex(), bias.getDouble(w1.getIndex()) - update2);
- return new Pair<>(update, (float) update2);
- }
-
- /**
- * Train on the corpus
- * @param rdd the rdd to train
- * @return the vocab and weights
- */
- public Pair, GloveWeightLookupTable> train(JavaRDD rdd) throws Exception {
- // Each `train()` can use different parameters
- final JavaSparkContext sc = new JavaSparkContext(rdd.context());
- final SparkConf conf = sc.getConf();
- final int vectorLength = assignVar(VECTOR_LENGTH, conf, Integer.class);
- final boolean useAdaGrad = assignVar(ADAGRAD, conf, Boolean.class);
- final double negative = assignVar(NEGATIVE, conf, Double.class);
- final int numWords = assignVar(NUM_WORDS, conf, Integer.class);
- final int window = assignVar(WINDOW, conf, Integer.class);
- final double alpha = assignVar(ALPHA, conf, Double.class);
- final double minAlpha = assignVar(MIN_ALPHA, conf, Double.class);
- final int iterations = assignVar(ITERATIONS, conf, Integer.class);
- final int nGrams = assignVar(N_GRAMS, conf, Integer.class);
- final String tokenizer = assignVar(TOKENIZER, conf, String.class);
- final String tokenPreprocessor = assignVar(TOKEN_PREPROCESSOR, conf, String.class);
- final boolean removeStop = assignVar(REMOVE_STOPWORDS, conf, Boolean.class);
-
- Map tokenizerVarMap = new HashMap() {
- {
- put("numWords", numWords);
- put("nGrams", nGrams);
- put("tokenizer", tokenizer);
- put("tokenPreprocessor", tokenPreprocessor);
- put("removeStop", removeStop);
- }
- };
- Broadcast