diff --git a/deeplearning4j/deeplearning4j-common/pom.xml b/deeplearning4j/deeplearning4j-common/pom.xml
index 7e2e0ebb6..4f44582f6 100644
--- a/deeplearning4j/deeplearning4j-common/pom.xml
+++ b/deeplearning4j/deeplearning4j-common/pom.xml
@@ -33,6 +33,12 @@
nd4j-common
${nd4j.version}
+
+
+ junit
+ junit
+ test
+
@@ -43,5 +49,4 @@
test-nd4j-cuda-11.0
-
diff --git a/deeplearning4j/deeplearning4j-common/src/main/java/org/deeplearning4j/common/config/DL4JClassLoading.java b/deeplearning4j/deeplearning4j-common/src/main/java/org/deeplearning4j/common/config/DL4JClassLoading.java
new file mode 100644
index 000000000..ae8e6da3f
--- /dev/null
+++ b/deeplearning4j/deeplearning4j-common/src/main/java/org/deeplearning4j/common/config/DL4JClassLoading.java
@@ -0,0 +1,133 @@
+/*******************************************************************************
+ * Copyright (c) Eclipse Deeplearning4j Contributors 2020
+ *
+ * This program and the accompanying materials are made available under the
+ * terms of the Apache License, Version 2.0 which is available at
+ * https://www.apache.org/licenses/LICENSE-2.0.
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
+ * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
+ * License for the specific language governing permissions and limitations
+ * under the License.
+ *
+ * SPDX-License-Identifier: Apache-2.0
+ ******************************************************************************/
+
+package org.deeplearning4j.common.config;
+
+import lombok.extern.slf4j.Slf4j;
+import org.nd4j.common.config.ND4JClassLoading;
+
+import java.lang.reflect.InvocationTargetException;
+import java.util.Objects;
+import java.util.ServiceLoader;
+
+/**
+ * Global context for class-loading in DL4J.
+ *
Use {@code DL4JClassLoading} to define classloader for Deeplearning4j only! To define classloader used by
+ * {@code ND4J} use class {@link org.nd4j.common.config.ND4JClassLoading}.
+ *
+ *
Usage:
+ *
{@code
+ * public class Application {
+ * static {
+ * DL4JClassLoading.setDl4jClassloaderFromClass(Application.class);
+ * }
+ *
+ * public static void main(String[] args) {
+ * }
+ * }
+ * }
+ *
+ * @see org.nd4j.common.config.ND4JClassLoading
+ *
+ * @author Alexei KLENIN
+ */
+@Slf4j
+public class DL4JClassLoading {
+ private static ClassLoader dl4jClassloader = ND4JClassLoading.getNd4jClassloader();
+
+ private DL4JClassLoading() {
+ }
+
+ public static ClassLoader getDl4jClassloader() {
+ return DL4JClassLoading.dl4jClassloader;
+ }
+
+ public static void setDl4jClassloaderFromClass(Class> clazz) {
+ setDl4jClassloader(clazz.getClassLoader());
+ }
+
+ public static void setDl4jClassloader(ClassLoader dl4jClassloader) {
+ DL4JClassLoading.dl4jClassloader = dl4jClassloader;
+ log.debug("Global class-loader for DL4J was changed.");
+ }
+
+ public static boolean classPresentOnClasspath(String className) {
+ return classPresentOnClasspath(className, dl4jClassloader);
+ }
+
+ public static boolean classPresentOnClasspath(String className, ClassLoader classLoader) {
+ return loadClassByName(className, false, classLoader) != null;
+ }
+
+ public static Class loadClassByName(String className) {
+ return loadClassByName(className, true, dl4jClassloader);
+ }
+
+ @SuppressWarnings("unchecked")
+ public static Class loadClassByName(String className, boolean initialize, ClassLoader classLoader) {
+ try {
+ return (Class) Class.forName(className, initialize, classLoader);
+ } catch (ClassNotFoundException classNotFoundException) {
+ log.error(String.format("Cannot find class [%s] of provided class-loader.", className));
+ return null;
+ }
+ }
+
+ public static T createNewInstance(String className, Object... args) {
+ return createNewInstance(className, Object.class, args);
+ }
+
+ public static T createNewInstance(String className, Class super T> superclass) {
+ return createNewInstance(className, superclass, new Class>[]{}, new Object[]{});
+ }
+
+ public static T createNewInstance(String className, Class super T> superclass, Object... args) {
+ Class>[] parameterTypes = new Class>[args.length];
+ for (int i = 0; i < args.length; i++) {
+ Object arg = args[i];
+ Objects.requireNonNull(arg);
+ parameterTypes[i] = arg.getClass();
+ }
+
+ return createNewInstance(className, superclass, parameterTypes, args);
+ }
+
+ public static T createNewInstance(
+ String className,
+ Class super T> superclass,
+ Class>[] parameterTypes,
+ Object... args) {
+ try {
+ return (T) DL4JClassLoading
+ .loadClassByName(className)
+ .asSubclass(superclass)
+ .getDeclaredConstructor(parameterTypes)
+ .newInstance(args);
+ } catch (InstantiationException | IllegalAccessException | InvocationTargetException
+ | NoSuchMethodException instantiationException) {
+ log.error(String.format("Cannot create instance of class '%s'.", className), instantiationException);
+ throw new RuntimeException(instantiationException);
+ }
+ }
+
+ public static ServiceLoader loadService(Class serviceClass) {
+ return loadService(serviceClass, dl4jClassloader);
+ }
+
+ public static ServiceLoader loadService(Class serviceClass, ClassLoader classLoader) {
+ return ServiceLoader.load(serviceClass, classLoader);
+ }
+}
diff --git a/deeplearning4j/deeplearning4j-common/src/test/java/org/deeplearning4j/common/config/DL4JClassLoadingTest.java b/deeplearning4j/deeplearning4j-common/src/test/java/org/deeplearning4j/common/config/DL4JClassLoadingTest.java
new file mode 100644
index 000000000..9f0ff3bda
--- /dev/null
+++ b/deeplearning4j/deeplearning4j-common/src/test/java/org/deeplearning4j/common/config/DL4JClassLoadingTest.java
@@ -0,0 +1,67 @@
+package org.deeplearning4j.common.config;
+
+import static org.junit.Assert.assertEquals;
+import static org.junit.Assert.assertNotNull;
+
+import org.deeplearning4j.common.config.dummies.TestAbstract;
+import org.junit.Test;
+
+public class DL4JClassLoadingTest {
+ private static final String PACKAGE_PREFIX = "org.deeplearning4j.common.config.dummies.";
+
+ @Test
+ public void testCreateNewInstance_constructorWithoutArguments() {
+
+ /* Given */
+ String className = PACKAGE_PREFIX + "TestDummy";
+
+ /* When */
+ Object instance = DL4JClassLoading.createNewInstance(className);
+
+ /* Then */
+ assertNotNull(instance);
+ assertEquals(className, instance.getClass().getName());
+ }
+
+ @Test
+ public void testCreateNewInstance_constructorWithArgument_implicitArgumentTypes() {
+
+ /* Given */
+ String className = PACKAGE_PREFIX + "TestColor";
+
+ /* When */
+ TestAbstract instance = DL4JClassLoading.createNewInstance(className, TestAbstract.class, "white");
+
+ /* Then */
+ assertNotNull(instance);
+ assertEquals(className, instance.getClass().getName());
+ }
+
+ @Test
+ public void testCreateNewInstance_constructorWithArgument_explicitArgumentTypes() {
+
+ /* Given */
+ String colorClassName = PACKAGE_PREFIX + "TestColor";
+ String rectangleClassName = PACKAGE_PREFIX + "TestRectangle";
+
+ /* When */
+ TestAbstract color = DL4JClassLoading.createNewInstance(
+ colorClassName,
+ Object.class,
+ new Class>[]{ int.class, int.class, int.class },
+ 45, 175, 200);
+
+ TestAbstract rectangle = DL4JClassLoading.createNewInstance(
+ rectangleClassName,
+ Object.class,
+ new Class>[]{ int.class, int.class, TestAbstract.class },
+ 10, 15, color);
+
+ /* Then */
+ assertNotNull(color);
+ assertEquals(colorClassName, color.getClass().getName());
+
+ assertNotNull(rectangle);
+ assertEquals(rectangleClassName, rectangle.getClass().getName());
+ }
+}
diff --git a/deeplearning4j/deeplearning4j-common/src/test/java/org/deeplearning4j/common/config/dummies/TestAbstract.java b/deeplearning4j/deeplearning4j-common/src/test/java/org/deeplearning4j/common/config/dummies/TestAbstract.java
new file mode 100644
index 000000000..a833a4aa0
--- /dev/null
+++ b/deeplearning4j/deeplearning4j-common/src/test/java/org/deeplearning4j/common/config/dummies/TestAbstract.java
@@ -0,0 +1,4 @@
+package org.deeplearning4j.common.config.dummies;
+
+public abstract class TestAbstract {
+}
diff --git a/deeplearning4j/deeplearning4j-common/src/test/java/org/deeplearning4j/common/config/dummies/TestColor.java b/deeplearning4j/deeplearning4j-common/src/test/java/org/deeplearning4j/common/config/dummies/TestColor.java
new file mode 100644
index 000000000..02b1e09ca
--- /dev/null
+++ b/deeplearning4j/deeplearning4j-common/src/test/java/org/deeplearning4j/common/config/dummies/TestColor.java
@@ -0,0 +1,9 @@
+package org.deeplearning4j.common.config.dummies;
+
+public class TestColor extends TestAbstract {
+ public TestColor(String color) {
+ }
+
+ public TestColor(int r, int g, int b) {
+ }
+}
diff --git a/deeplearning4j/deeplearning4j-common/src/test/java/org/deeplearning4j/common/config/dummies/TestDummy.java b/deeplearning4j/deeplearning4j-common/src/test/java/org/deeplearning4j/common/config/dummies/TestDummy.java
new file mode 100644
index 000000000..682044d95
--- /dev/null
+++ b/deeplearning4j/deeplearning4j-common/src/test/java/org/deeplearning4j/common/config/dummies/TestDummy.java
@@ -0,0 +1,6 @@
+package org.deeplearning4j.common.config.dummies;
+
+public class TestDummy {
+ public TestDummy() {
+ }
+}
diff --git a/deeplearning4j/deeplearning4j-common/src/test/java/org/deeplearning4j/common/config/dummies/TestRectangle.java b/deeplearning4j/deeplearning4j-common/src/test/java/org/deeplearning4j/common/config/dummies/TestRectangle.java
new file mode 100644
index 000000000..3b2b9cc9f
--- /dev/null
+++ b/deeplearning4j/deeplearning4j-common/src/test/java/org/deeplearning4j/common/config/dummies/TestRectangle.java
@@ -0,0 +1,6 @@
+package org.deeplearning4j.common.config.dummies;
+
+public class TestRectangle extends TestAbstract {
+ public TestRectangle(int width, int height, TestAbstract color) {
+ }
+}
diff --git a/deeplearning4j/deeplearning4j-core/src/main/java/org/deeplearning4j/core/parallelism/AsyncIterator.java b/deeplearning4j/deeplearning4j-core/src/main/java/org/deeplearning4j/core/parallelism/AsyncIterator.java
index b0e31d787..f45cc4036 100644
--- a/deeplearning4j/deeplearning4j-core/src/main/java/org/deeplearning4j/core/parallelism/AsyncIterator.java
+++ b/deeplearning4j/deeplearning4j-core/src/main/java/org/deeplearning4j/core/parallelism/AsyncIterator.java
@@ -102,7 +102,6 @@ public class AsyncIterator implements Iterator {
}
}
-
private class ReaderThread extends Thread implements Runnable {
private BlockingQueue buffer;
private Iterator iterator;
diff --git a/deeplearning4j/deeplearning4j-core/src/test/java/org/deeplearning4j/LayerHelperValidationUtil.java b/deeplearning4j/deeplearning4j-core/src/test/java/org/deeplearning4j/LayerHelperValidationUtil.java
index b06db266d..1279be530 100644
--- a/deeplearning4j/deeplearning4j-core/src/test/java/org/deeplearning4j/LayerHelperValidationUtil.java
+++ b/deeplearning4j/deeplearning4j-core/src/test/java/org/deeplearning4j/LayerHelperValidationUtil.java
@@ -19,6 +19,7 @@ package org.deeplearning4j;
import it.unimi.dsi.fastutil.doubles.DoubleArrayList;
import lombok.*;
import lombok.extern.slf4j.Slf4j;
+import org.deeplearning4j.common.config.DL4JClassLoading;
import org.deeplearning4j.nn.api.Layer;
import org.deeplearning4j.nn.conf.layers.ConvolutionLayer;
import org.deeplearning4j.nn.conf.layers.SubsamplingLayer;
@@ -66,23 +67,23 @@ public class LayerHelperValidationUtil {
public static void disableCppHelpers(){
try {
- Class> c = Class.forName("org.nd4j.nativeblas.Nd4jCpu$Environment");
- Method m = c.getMethod("getInstance");
- Object instance = m.invoke(null);
- Method m2 = c.getMethod("allowHelpers", boolean.class);
- m2.invoke(instance, false);
+ Class> clazz = DL4JClassLoading.loadClassByName("org.nd4j.nativeblas.Nd4jCpu$Environment");
+ Method getInstance = clazz.getMethod("getInstance");
+ Object instance = getInstance.invoke(null);
+ Method allowHelpers = clazz.getMethod("allowHelpers", boolean.class);
+ allowHelpers.invoke(instance, false);
} catch (Throwable t){
throw new RuntimeException(t);
}
}
public static void enableCppHelpers(){
- try{
- Class> c = Class.forName("org.nd4j.nativeblas.Nd4jCpu$Environment");
- Method m = c.getMethod("getInstance");
- Object instance = m.invoke(null);
- Method m2 = c.getMethod("allowHelpers", boolean.class);
- m2.invoke(instance, true);
+ try {
+ Class> clazz = DL4JClassLoading.loadClassByName("org.nd4j.nativeblas.Nd4jCpu$Environment");
+ Method getInstance = clazz.getMethod("getInstance");
+ Object instance = getInstance.invoke(null);
+ Method allowHelpers = clazz.getMethod("allowHelpers", boolean.class);
+ allowHelpers.invoke(instance, true);
} catch (Throwable t){
throw new RuntimeException(t);
}
diff --git a/deeplearning4j/deeplearning4j-core/src/test/java/org/deeplearning4j/nn/dtypes/DTypeTests.java b/deeplearning4j/deeplearning4j-core/src/test/java/org/deeplearning4j/nn/dtypes/DTypeTests.java
index 96f641f47..da4754e0b 100644
--- a/deeplearning4j/deeplearning4j-core/src/test/java/org/deeplearning4j/nn/dtypes/DTypeTests.java
+++ b/deeplearning4j/deeplearning4j-core/src/test/java/org/deeplearning4j/nn/dtypes/DTypeTests.java
@@ -16,28 +16,96 @@
package org.deeplearning4j.nn.dtypes;
-import org.deeplearning4j.nn.conf.layers.recurrent.TimeDistributed;
-import org.deeplearning4j.nn.conf.preprocessor.*;
-import org.deeplearning4j.nn.modelimport.keras.layers.TFOpLayer;
-import org.deeplearning4j.nn.modelimport.keras.preprocessors.TensorFlowCnnToFeedForwardPreProcessor;
-import org.nd4j.linalg.profiler.ProfilerConfig;
-import org.nd4j.shade.guava.collect.ImmutableSet;
-import org.nd4j.shade.guava.reflect.ClassPath;
+import static org.junit.Assert.assertEquals;
+import static org.junit.Assert.assertTrue;
+
import lombok.extern.slf4j.Slf4j;
import org.deeplearning4j.BaseDL4JTest;
import org.deeplearning4j.TestUtils;
-import org.deeplearning4j.nn.conf.*;
+import org.deeplearning4j.common.config.DL4JClassLoading;
+import org.deeplearning4j.nn.conf.ComputationGraphConfiguration;
+import org.deeplearning4j.nn.conf.ConvolutionMode;
+import org.deeplearning4j.nn.conf.InputPreProcessor;
+import org.deeplearning4j.nn.conf.MultiLayerConfiguration;
+import org.deeplearning4j.nn.conf.NeuralNetConfiguration;
+import org.deeplearning4j.nn.conf.RNNFormat;
+import org.deeplearning4j.nn.conf.WorkspaceMode;
import org.deeplearning4j.nn.conf.distribution.UniformDistribution;
import org.deeplearning4j.nn.conf.dropout.AlphaDropout;
import org.deeplearning4j.nn.conf.dropout.GaussianDropout;
import org.deeplearning4j.nn.conf.dropout.GaussianNoise;
import org.deeplearning4j.nn.conf.dropout.SpatialDropout;
-import org.deeplearning4j.nn.conf.graph.*;
+import org.deeplearning4j.nn.conf.graph.AttentionVertex;
+import org.deeplearning4j.nn.conf.graph.ElementWiseVertex;
+import org.deeplearning4j.nn.conf.graph.FrozenVertex;
+import org.deeplearning4j.nn.conf.graph.GraphVertex;
+import org.deeplearning4j.nn.conf.graph.L2NormalizeVertex;
+import org.deeplearning4j.nn.conf.graph.L2Vertex;
+import org.deeplearning4j.nn.conf.graph.LayerVertex;
+import org.deeplearning4j.nn.conf.graph.MergeVertex;
+import org.deeplearning4j.nn.conf.graph.PoolHelperVertex;
+import org.deeplearning4j.nn.conf.graph.PreprocessorVertex;
+import org.deeplearning4j.nn.conf.graph.ReshapeVertex;
+import org.deeplearning4j.nn.conf.graph.ScaleVertex;
+import org.deeplearning4j.nn.conf.graph.ShiftVertex;
+import org.deeplearning4j.nn.conf.graph.StackVertex;
+import org.deeplearning4j.nn.conf.graph.SubsetVertex;
+import org.deeplearning4j.nn.conf.graph.UnstackVertex;
import org.deeplearning4j.nn.conf.graph.rnn.DuplicateToTimeSeriesVertex;
import org.deeplearning4j.nn.conf.graph.rnn.LastTimeStepVertex;
import org.deeplearning4j.nn.conf.graph.rnn.ReverseTimeSeriesVertex;
import org.deeplearning4j.nn.conf.inputs.InputType;
-import org.deeplearning4j.nn.conf.layers.*;
+import org.deeplearning4j.nn.conf.layers.ActivationLayer;
+import org.deeplearning4j.nn.conf.layers.AutoEncoder;
+import org.deeplearning4j.nn.conf.layers.BatchNormalization;
+import org.deeplearning4j.nn.conf.layers.CapsuleLayer;
+import org.deeplearning4j.nn.conf.layers.CapsuleStrengthLayer;
+import org.deeplearning4j.nn.conf.layers.CenterLossOutputLayer;
+import org.deeplearning4j.nn.conf.layers.Cnn3DLossLayer;
+import org.deeplearning4j.nn.conf.layers.CnnLossLayer;
+import org.deeplearning4j.nn.conf.layers.Convolution1D;
+import org.deeplearning4j.nn.conf.layers.Convolution2D;
+import org.deeplearning4j.nn.conf.layers.Convolution3D;
+import org.deeplearning4j.nn.conf.layers.ConvolutionLayer;
+import org.deeplearning4j.nn.conf.layers.Deconvolution2D;
+import org.deeplearning4j.nn.conf.layers.Deconvolution3D;
+import org.deeplearning4j.nn.conf.layers.DenseLayer;
+import org.deeplearning4j.nn.conf.layers.DepthwiseConvolution2D;
+import org.deeplearning4j.nn.conf.layers.DropoutLayer;
+import org.deeplearning4j.nn.conf.layers.EmbeddingLayer;
+import org.deeplearning4j.nn.conf.layers.EmbeddingSequenceLayer;
+import org.deeplearning4j.nn.conf.layers.GlobalPoolingLayer;
+import org.deeplearning4j.nn.conf.layers.GravesBidirectionalLSTM;
+import org.deeplearning4j.nn.conf.layers.GravesLSTM;
+import org.deeplearning4j.nn.conf.layers.LSTM;
+import org.deeplearning4j.nn.conf.layers.Layer;
+import org.deeplearning4j.nn.conf.layers.LearnedSelfAttentionLayer;
+import org.deeplearning4j.nn.conf.layers.LocalResponseNormalization;
+import org.deeplearning4j.nn.conf.layers.LocallyConnected1D;
+import org.deeplearning4j.nn.conf.layers.LocallyConnected2D;
+import org.deeplearning4j.nn.conf.layers.LossLayer;
+import org.deeplearning4j.nn.conf.layers.OutputLayer;
+import org.deeplearning4j.nn.conf.layers.PReLULayer;
+import org.deeplearning4j.nn.conf.layers.Pooling1D;
+import org.deeplearning4j.nn.conf.layers.Pooling2D;
+import org.deeplearning4j.nn.conf.layers.PoolingType;
+import org.deeplearning4j.nn.conf.layers.PrimaryCapsules;
+import org.deeplearning4j.nn.conf.layers.RecurrentAttentionLayer;
+import org.deeplearning4j.nn.conf.layers.RnnLossLayer;
+import org.deeplearning4j.nn.conf.layers.RnnOutputLayer;
+import org.deeplearning4j.nn.conf.layers.SelfAttentionLayer;
+import org.deeplearning4j.nn.conf.layers.SeparableConvolution2D;
+import org.deeplearning4j.nn.conf.layers.SpaceToBatchLayer;
+import org.deeplearning4j.nn.conf.layers.SpaceToDepthLayer;
+import org.deeplearning4j.nn.conf.layers.Subsampling1DLayer;
+import org.deeplearning4j.nn.conf.layers.Subsampling3DLayer;
+import org.deeplearning4j.nn.conf.layers.SubsamplingLayer;
+import org.deeplearning4j.nn.conf.layers.Upsampling1D;
+import org.deeplearning4j.nn.conf.layers.Upsampling2D;
+import org.deeplearning4j.nn.conf.layers.Upsampling3D;
+import org.deeplearning4j.nn.conf.layers.ZeroPadding1DLayer;
+import org.deeplearning4j.nn.conf.layers.ZeroPadding3DLayer;
+import org.deeplearning4j.nn.conf.layers.ZeroPaddingLayer;
import org.deeplearning4j.nn.conf.layers.convolutional.Cropping1D;
import org.deeplearning4j.nn.conf.layers.convolutional.Cropping2D;
import org.deeplearning4j.nn.conf.layers.convolutional.Cropping3D;
@@ -49,16 +117,24 @@ import org.deeplearning4j.nn.conf.layers.objdetect.Yolo2OutputLayer;
import org.deeplearning4j.nn.conf.layers.recurrent.Bidirectional;
import org.deeplearning4j.nn.conf.layers.recurrent.LastTimeStep;
import org.deeplearning4j.nn.conf.layers.recurrent.SimpleRnn;
+import org.deeplearning4j.nn.conf.layers.recurrent.TimeDistributed;
import org.deeplearning4j.nn.conf.layers.util.MaskLayer;
import org.deeplearning4j.nn.conf.layers.util.MaskZeroLayer;
import org.deeplearning4j.nn.conf.layers.variational.VariationalAutoencoder;
import org.deeplearning4j.nn.conf.layers.wrapper.BaseWrapperLayer;
import org.deeplearning4j.nn.conf.ocnn.OCNNOutputLayer;
+import org.deeplearning4j.nn.conf.preprocessor.CnnToFeedForwardPreProcessor;
+import org.deeplearning4j.nn.conf.preprocessor.CnnToRnnPreProcessor;
+import org.deeplearning4j.nn.conf.preprocessor.ComposableInputPreProcessor;
+import org.deeplearning4j.nn.conf.preprocessor.FeedForwardToCnn3DPreProcessor;
+import org.deeplearning4j.nn.conf.preprocessor.RnnToCnnPreProcessor;
import org.deeplearning4j.nn.graph.ComputationGraph;
import org.deeplearning4j.nn.layers.util.IdentityLayer;
+import org.deeplearning4j.nn.modelimport.keras.layers.TFOpLayer;
import org.deeplearning4j.nn.modelimport.keras.preprocessors.KerasFlattenRnnPreprocessor;
import org.deeplearning4j.nn.modelimport.keras.preprocessors.PermutePreprocessor;
import org.deeplearning4j.nn.modelimport.keras.preprocessors.ReshapePreprocessor;
+import org.deeplearning4j.nn.modelimport.keras.preprocessors.TensorFlowCnnToFeedForwardPreProcessor;
import org.deeplearning4j.nn.multilayer.MultiLayerNetwork;
import org.deeplearning4j.nn.weights.WeightInit;
import org.deeplearning4j.nn.weights.WeightInitDistribution;
@@ -77,12 +153,17 @@ import org.nd4j.linalg.learning.config.Nesterovs;
import org.nd4j.linalg.learning.config.NoOp;
import org.nd4j.linalg.lossfunctions.LossFunctions;
import org.nd4j.linalg.lossfunctions.impl.LossNegativeLogLikelihood;
+import org.nd4j.linalg.profiler.ProfilerConfig;
+import org.nd4j.shade.guava.collect.ImmutableSet;
+import org.nd4j.shade.guava.reflect.ClassPath;
import java.io.IOException;
import java.lang.reflect.Modifier;
-import java.util.*;
-
-import static org.junit.Assert.*;
+import java.util.Arrays;
+import java.util.HashSet;
+import java.util.List;
+import java.util.Map;
+import java.util.Set;
@Slf4j
public class DTypeTests extends BaseDL4JTest {
@@ -120,20 +201,17 @@ public class DTypeTests extends BaseDL4JTest {
Set> preprocClasses = new HashSet<>();
Set> vertexClasses = new HashSet<>();
for (ClassPath.ClassInfo ci : info) {
- Class> clazz;
- try {
- clazz = Class.forName(ci.getName());
- } catch (ClassNotFoundException e) {
- //Should never happen as this was found on the classpath
- throw new RuntimeException(e);
- }
+ Class> clazz = DL4JClassLoading.loadClassByName(ci.getName());
- if (Modifier.isAbstract(clazz.getModifiers()) || clazz.isInterface() || TFOpLayer.class == clazz) { //Skip TFOpLayer here - dtype depends on imported model dtype
+ if (Modifier.isAbstract(clazz.getModifiers()) || clazz.isInterface() || TFOpLayer.class == clazz) {
+ // Skip TFOpLayer here - dtype depends on imported model dtype
continue;
}
- if (clazz.getName().toLowerCase().contains("custom") || clazz.getName().contains("samediff.testlayers")
- || clazz.getName().toLowerCase().contains("test") || ignoreClasses.contains(clazz)) {
+ if (clazz.getName().toLowerCase().contains("custom")
+ || clazz.getName().contains("samediff.testlayers")
+ || clazz.getName().toLowerCase().contains("test")
+ || ignoreClasses.contains(clazz)) {
continue;
}
diff --git a/deeplearning4j/deeplearning4j-core/src/test/java/org/deeplearning4j/nn/layers/recurrent/GravesLSTMTest.java b/deeplearning4j/deeplearning4j-core/src/test/java/org/deeplearning4j/nn/layers/recurrent/GravesLSTMTest.java
index 9c0f16f5e..f9e17cc50 100644
--- a/deeplearning4j/deeplearning4j-core/src/test/java/org/deeplearning4j/nn/layers/recurrent/GravesLSTMTest.java
+++ b/deeplearning4j/deeplearning4j-core/src/test/java/org/deeplearning4j/nn/layers/recurrent/GravesLSTMTest.java
@@ -18,6 +18,7 @@ package org.deeplearning4j.nn.layers.recurrent;
import lombok.val;
import org.deeplearning4j.BaseDL4JTest;
+import org.deeplearning4j.common.config.DL4JClassLoading;
import org.deeplearning4j.nn.api.OptimizationAlgorithm;
import org.deeplearning4j.nn.conf.MultiLayerConfiguration;
import org.deeplearning4j.nn.conf.NeuralNetConfiguration;
@@ -167,7 +168,7 @@ public class GravesLSTMTest extends BaseDL4JTest {
actHelper.setAccessible(true);
//Call activateHelper with both forBackprop == true, and forBackprop == false and compare
- Class> innerClass = Class.forName("org.deeplearning4j.nn.layers.recurrent.FwdPassReturn");
+ Class> innerClass = DL4JClassLoading.loadClassByName("org.deeplearning4j.nn.layers.recurrent.FwdPassReturn");
Object oFalse = actHelper.invoke(lstm, false, null, null, false, LayerWorkspaceMgr.noWorkspacesImmutable()); //GravesLSTM.FwdPassReturn object; want fwdPassOutput INDArray
Object oTrue = actHelper.invoke(lstm, false, null, null, true, LayerWorkspaceMgr.noWorkspacesImmutable()); //want fwdPassOutputAsArrays object
diff --git a/deeplearning4j/deeplearning4j-modelimport/src/main/java/org/deeplearning4j/nn/modelimport/keras/layers/TFOpLayerImpl.java b/deeplearning4j/deeplearning4j-modelimport/src/main/java/org/deeplearning4j/nn/modelimport/keras/layers/TFOpLayerImpl.java
index 47d063826..26c369d8b 100644
--- a/deeplearning4j/deeplearning4j-modelimport/src/main/java/org/deeplearning4j/nn/modelimport/keras/layers/TFOpLayerImpl.java
+++ b/deeplearning4j/deeplearning4j-modelimport/src/main/java/org/deeplearning4j/nn/modelimport/keras/layers/TFOpLayerImpl.java
@@ -19,6 +19,7 @@ package org.deeplearning4j.nn.modelimport.keras.layers;
import lombok.Data;
import lombok.extern.slf4j.Slf4j;
import org.apache.commons.lang3.ArrayUtils;
+import org.deeplearning4j.common.config.DL4JClassLoading;
import org.deeplearning4j.nn.conf.NeuralNetConfiguration;
import org.deeplearning4j.nn.gradient.Gradient;
import org.deeplearning4j.nn.layers.AbstractLayer;
@@ -110,7 +111,7 @@ public class TFOpLayerImpl extends AbstractLayer {
org.nd4j.shade.protobuf.ByteString serialized = graphDef.toByteString();
byte[] graphBytes = serialized.toByteArray();
- ServiceLoader sl = ServiceLoader.load(TFGraphRunnerService.class);
+ ServiceLoader sl = DL4JClassLoading.loadService(TFGraphRunnerService.class);
Iterator iter = sl.iterator();
if (!iter.hasNext()){
throw new RuntimeException("The model contains a Tensorflow Op, which requires the nd4j-tensorflow dependency to execute.");
diff --git a/deeplearning4j/deeplearning4j-nlp-parent/deeplearning4j-nlp-chinese/src/main/java/org/ansj/app/crf/Model.java b/deeplearning4j/deeplearning4j-nlp-parent/deeplearning4j-nlp-chinese/src/main/java/org/ansj/app/crf/Model.java
index 410b623cc..aedacee25 100644
--- a/deeplearning4j/deeplearning4j-nlp-parent/deeplearning4j-nlp-chinese/src/main/java/org/ansj/app/crf/Model.java
+++ b/deeplearning4j/deeplearning4j-nlp-parent/deeplearning4j-nlp-chinese/src/main/java/org/ansj/app/crf/Model.java
@@ -60,15 +60,13 @@ public abstract class Model {
/**
* 模型读取
- *
- * @param path
- * @return
- * @return
- * @throws Exception
+ *
*/
- public static Model load(Class extends Model> c, InputStream is) throws Exception {
- Model model = c.newInstance();
- return model.loadModel(is);
+ public static Model load(Class extends Model> modelClass, InputStream inputStream) throws Exception {
+ return modelClass
+ .getDeclaredConstructor()
+ .newInstance()
+ .loadModel(inputStream);
}
/**
diff --git a/deeplearning4j/deeplearning4j-nlp-parent/deeplearning4j-nlp-chinese/src/main/java/org/ansj/dic/PathToStream.java b/deeplearning4j/deeplearning4j-nlp-parent/deeplearning4j-nlp-chinese/src/main/java/org/ansj/dic/PathToStream.java
index c1267be13..3b1d88546 100644
--- a/deeplearning4j/deeplearning4j-nlp-parent/deeplearning4j-nlp-chinese/src/main/java/org/ansj/dic/PathToStream.java
+++ b/deeplearning4j/deeplearning4j-nlp-parent/deeplearning4j-nlp-chinese/src/main/java/org/ansj/dic/PathToStream.java
@@ -5,6 +5,7 @@ import org.ansj.dic.impl.Jar2Stream;
import org.ansj.dic.impl.Jdbc2Stream;
import org.ansj.dic.impl.Url2Stream;
import org.ansj.exception.LibraryException;
+import org.deeplearning4j.common.config.DL4JClassLoading;
import java.io.InputStream;
@@ -25,7 +26,8 @@ public abstract class PathToStream {
} else if (path.startsWith("jar://")) {
return new Jar2Stream().toStream(path);
} else if (path.startsWith("class://")) {
- ((PathToStream) Class.forName(path.substring(8).split("\\|")[0]).newInstance()).toStream(path);
+ // Probably unused
+ return loadClass(path);
} else if (path.startsWith("http://") || path.startsWith("https://")) {
return new Url2Stream().toStream(path);
} else {
@@ -34,9 +36,17 @@ public abstract class PathToStream {
} catch (Exception e) {
throw new LibraryException(e);
}
- throw new LibraryException("not find method type in path " + path);
}
public abstract InputStream toStream(String path);
+ static InputStream loadClass(String path) {
+ String className = path
+ .substring("class://".length())
+ .split("\\|")[0];
+
+ return DL4JClassLoading
+ .createNewInstance(className, PathToStream.class)
+ .toStream(path);
+ }
}
diff --git a/deeplearning4j/deeplearning4j-nlp-parent/deeplearning4j-nlp-chinese/src/main/java/org/ansj/dic/impl/Jar2Stream.java b/deeplearning4j/deeplearning4j-nlp-parent/deeplearning4j-nlp-chinese/src/main/java/org/ansj/dic/impl/Jar2Stream.java
index 021e56aae..6d423acd5 100644
--- a/deeplearning4j/deeplearning4j-nlp-parent/deeplearning4j-nlp-chinese/src/main/java/org/ansj/dic/impl/Jar2Stream.java
+++ b/deeplearning4j/deeplearning4j-nlp-parent/deeplearning4j-nlp-chinese/src/main/java/org/ansj/dic/impl/Jar2Stream.java
@@ -3,6 +3,7 @@ package org.ansj.dic.impl;
import org.ansj.dic.DicReader;
import org.ansj.dic.PathToStream;
import org.ansj.exception.LibraryException;
+import org.deeplearning4j.common.config.DL4JClassLoading;
import java.io.InputStream;
@@ -17,12 +18,16 @@ public class Jar2Stream extends PathToStream {
@Override
public InputStream toStream(String path) {
if (path.contains("|")) {
- String[] split = path.split("\\|");
- try {
- return Class.forName(split[0].substring(6)).getResourceAsStream(split[1].trim());
- } catch (ClassNotFoundException e) {
- throw new LibraryException(e);
+ String[] tokens = path.split("\\|");
+ String className = tokens[0].substring(6);
+ String resourceName = tokens[1].trim();
+
+ Class resourceClass = DL4JClassLoading.loadClassByName(className);
+ if (resourceClass == null) {
+ throw new LibraryException(String.format("Class '%s' was not found.", className));
}
+
+ return resourceClass.getResourceAsStream(resourceName);
} else {
return DicReader.getInputStream(path.substring(6));
}
diff --git a/deeplearning4j/deeplearning4j-nlp-parent/deeplearning4j-nlp-chinese/src/main/java/org/ansj/dic/impl/Jdbc2Stream.java b/deeplearning4j/deeplearning4j-nlp-parent/deeplearning4j-nlp-chinese/src/main/java/org/ansj/dic/impl/Jdbc2Stream.java
index 52cc85899..f8dba2159 100644
--- a/deeplearning4j/deeplearning4j-nlp-parent/deeplearning4j-nlp-chinese/src/main/java/org/ansj/dic/impl/Jdbc2Stream.java
+++ b/deeplearning4j/deeplearning4j-nlp-parent/deeplearning4j-nlp-chinese/src/main/java/org/ansj/dic/impl/Jdbc2Stream.java
@@ -2,6 +2,7 @@ package org.ansj.dic.impl;
import org.ansj.dic.PathToStream;
import org.ansj.exception.LibraryException;
+import org.deeplearning4j.common.config.DL4JClassLoading;
import java.io.ByteArrayInputStream;
import java.io.ByteArrayOutputStream;
@@ -22,20 +23,26 @@ public class Jdbc2Stream extends PathToStream {
private static final byte[] LINE = "\n".getBytes();
+ private static final String[] JDBC_DRIVERS = {
+ "org.h2.Driver",
+ "com.ibm.db2.jcc.DB2Driver",
+ "org.hsqldb.jdbcDriver",
+ "org.gjt.mm.mysql.Driver",
+ "oracle.jdbc.OracleDriver",
+ "org.postgresql.Driver",
+ "net.sourceforge.jtds.jdbc.Driver",
+ "com.microsoft.sqlserver.jdbc.SQLServerDriver",
+ "org.sqlite.JDBC",
+ "com.mysql.jdbc.Driver"
+ };
+
static {
- String[] drivers = {"org.h2.Driver", "com.ibm.db2.jcc.DB2Driver", "org.hsqldb.jdbcDriver",
- "org.gjt.mm.mysql.Driver", "oracle.jdbc.OracleDriver", "org.postgresql.Driver",
- "net.sourceforge.jtds.jdbc.Driver", "com.microsoft.sqlserver.jdbc.SQLServerDriver",
- "org.sqlite.JDBC", "com.mysql.jdbc.Driver"};
- for (String driverClassName : drivers) {
- try {
- try {
- Thread.currentThread().getContextClassLoader().loadClass(driverClassName);
- } catch (ClassNotFoundException e) {
- Class.forName(driverClassName);
- }
- } catch (Throwable e) {
- }
+ loadJdbcDrivers();
+ }
+
+ static void loadJdbcDrivers() {
+ for (String driverClassName : JDBC_DRIVERS) {
+ DL4JClassLoading.loadClassByName(driverClassName);
}
}
diff --git a/deeplearning4j/deeplearning4j-nlp-parent/deeplearning4j-nlp/src/main/java/org/deeplearning4j/models/embeddings/loader/WordVectorSerializer.java b/deeplearning4j/deeplearning4j-nlp-parent/deeplearning4j-nlp/src/main/java/org/deeplearning4j/models/embeddings/loader/WordVectorSerializer.java
index 3598bb62a..4a946256e 100755
--- a/deeplearning4j/deeplearning4j-nlp-parent/deeplearning4j-nlp/src/main/java/org/deeplearning4j/models/embeddings/loader/WordVectorSerializer.java
+++ b/deeplearning4j/deeplearning4j-nlp-parent/deeplearning4j-nlp/src/main/java/org/deeplearning4j/models/embeddings/loader/WordVectorSerializer.java
@@ -17,44 +17,20 @@
package org.deeplearning4j.models.embeddings.loader;
-import java.io.BufferedInputStream;
-import java.io.BufferedOutputStream;
-import java.io.BufferedReader;
-import java.io.BufferedWriter;
-import java.io.ByteArrayInputStream;
-import java.io.DataInputStream;
-import java.io.DataOutputStream;
-import java.io.File;
-import java.io.FileInputStream;
-import java.io.FileNotFoundException;
-import java.io.FileOutputStream;
-import java.io.FileReader;
-import java.io.FileWriter;
-import java.io.IOException;
-import java.io.InputStream;
-import java.io.InputStreamReader;
-import java.io.ObjectInputStream;
-import java.io.ObjectOutputStream;
-import java.io.OutputStream;
-import java.io.OutputStreamWriter;
-import java.io.PrintWriter;
-import java.io.UnsupportedEncodingException;
-import java.nio.charset.StandardCharsets;
-import java.util.ArrayList;
-import java.util.List;
-import java.util.concurrent.atomic.AtomicInteger;
-import java.util.zip.GZIPInputStream;
-import java.util.zip.ZipEntry;
-import java.util.zip.ZipFile;
-import java.util.zip.ZipInputStream;
-import java.util.zip.ZipOutputStream;
-
+import lombok.AllArgsConstructor;
+import lombok.Data;
+import lombok.NoArgsConstructor;
+import lombok.NonNull;
+import lombok.extern.slf4j.Slf4j;
+import lombok.val;
import org.apache.commons.codec.binary.Base64;
import org.apache.commons.compress.compressors.gzip.GzipUtils;
import org.apache.commons.io.FileUtils;
import org.apache.commons.io.IOUtils;
import org.apache.commons.io.LineIterator;
import org.apache.commons.io.output.CloseShieldOutputStream;
+import org.apache.commons.lang3.StringUtils;
+import org.deeplearning4j.common.config.DL4JClassLoading;
import org.deeplearning4j.common.util.DL4JFileUtils;
import org.deeplearning4j.exception.DL4JInvalidInputException;
import org.deeplearning4j.models.embeddings.WeightLookupTable;
@@ -94,12 +70,37 @@ import org.nd4j.shade.jackson.databind.ObjectMapper;
import org.nd4j.shade.jackson.databind.SerializationFeature;
import org.nd4j.storage.CompressedRamStorage;
-import lombok.AllArgsConstructor;
-import lombok.Data;
-import lombok.NoArgsConstructor;
-import lombok.NonNull;
-import lombok.extern.slf4j.Slf4j;
-import lombok.val;
+import java.io.BufferedInputStream;
+import java.io.BufferedOutputStream;
+import java.io.BufferedReader;
+import java.io.BufferedWriter;
+import java.io.ByteArrayInputStream;
+import java.io.DataInputStream;
+import java.io.DataOutputStream;
+import java.io.File;
+import java.io.FileInputStream;
+import java.io.FileNotFoundException;
+import java.io.FileOutputStream;
+import java.io.FileReader;
+import java.io.FileWriter;
+import java.io.IOException;
+import java.io.InputStream;
+import java.io.InputStreamReader;
+import java.io.ObjectInputStream;
+import java.io.ObjectOutputStream;
+import java.io.OutputStream;
+import java.io.OutputStreamWriter;
+import java.io.PrintWriter;
+import java.io.UnsupportedEncodingException;
+import java.nio.charset.StandardCharsets;
+import java.util.ArrayList;
+import java.util.List;
+import java.util.concurrent.atomic.AtomicInteger;
+import java.util.zip.GZIPInputStream;
+import java.util.zip.ZipEntry;
+import java.util.zip.ZipFile;
+import java.util.zip.ZipInputStream;
+import java.util.zip.ZipOutputStream;
/**
* This is utility class, providing various methods for WordVectors serialization
@@ -2676,26 +2677,23 @@ public class WordVectorSerializer {
}
protected static TokenizerFactory getTokenizerFactory(VectorsConfiguration configuration) {
- if (configuration == null)
+ if (configuration == null) {
return null;
-
- if (configuration.getTokenizerFactory() != null && !configuration.getTokenizerFactory().isEmpty()) {
- try {
- TokenizerFactory factory =
- (TokenizerFactory) Class.forName(configuration.getTokenizerFactory()).newInstance();
-
- if (configuration.getTokenPreProcessor() != null && !configuration.getTokenPreProcessor().isEmpty()) {
- TokenPreProcess preProcessor =
- (TokenPreProcess) Class.forName(configuration.getTokenPreProcessor()).newInstance();
- factory.setTokenPreProcessor(preProcessor);
- }
-
- return factory;
-
- } catch (Exception e) {
- log.error("Can't instantiate saved TokenizerFactory: {}", configuration.getTokenizerFactory());
- }
}
+
+ String tokenizerFactoryClassName = configuration.getTokenizerFactory();
+ if (StringUtils.isNotEmpty(tokenizerFactoryClassName)) {
+ TokenizerFactory factory = DL4JClassLoading.createNewInstance(tokenizerFactoryClassName);
+
+ String tokenPreProcessorClassName = configuration.getTokenPreProcessor();
+ if (StringUtils.isNotEmpty(tokenPreProcessorClassName)) {
+ TokenPreProcess preProcessor = DL4JClassLoading.createNewInstance(tokenizerFactoryClassName);
+ factory.setTokenPreProcessor(preProcessor);
+ }
+
+ return factory;
+ }
+
return null;
}
diff --git a/deeplearning4j/deeplearning4j-nlp-parent/deeplearning4j-nlp/src/main/java/org/deeplearning4j/models/sequencevectors/SequenceVectors.java b/deeplearning4j/deeplearning4j-nlp-parent/deeplearning4j-nlp/src/main/java/org/deeplearning4j/models/sequencevectors/SequenceVectors.java
index 0e104bb20..33b77f658 100644
--- a/deeplearning4j/deeplearning4j-nlp-parent/deeplearning4j-nlp/src/main/java/org/deeplearning4j/models/sequencevectors/SequenceVectors.java
+++ b/deeplearning4j/deeplearning4j-nlp-parent/deeplearning4j-nlp/src/main/java/org/deeplearning4j/models/sequencevectors/SequenceVectors.java
@@ -16,6 +16,8 @@
package org.deeplearning4j.models.sequencevectors;
+import org.apache.commons.lang3.StringUtils;
+import org.deeplearning4j.common.config.DL4JClassLoading;
import org.nd4j.shade.guava.primitives.Ints;
import org.nd4j.shade.guava.util.concurrent.AtomicDouble;
import lombok.Getter;
@@ -494,15 +496,17 @@ public class SequenceVectors extends WordVectorsImpl<
this.useHierarchicSoftmax = configuration.isUseHierarchicSoftmax();
this.preciseMode = configuration.isPreciseMode();
- if (configuration.getModelUtils() != null && !configuration.getModelUtils().isEmpty()) {
-
+ String modelUtilsClassName = configuration.getModelUtils();
+ if (StringUtils.isNotEmpty(modelUtilsClassName)) {
try {
- this.modelUtils = (ModelUtils) Class.forName(configuration.getModelUtils()).newInstance();
- } catch (Exception e) {
- log.error("Got {} trying to instantiate ModelUtils, falling back to BasicModelUtils instead");
+ this.modelUtils = DL4JClassLoading.createNewInstance(modelUtilsClassName);
+ } catch (Exception instantiationException) {
+ log.error(
+ "Got '{}' trying to instantiate ModelUtils, falling back to BasicModelUtils instead",
+ instantiationException.getMessage(),
+ instantiationException);
this.modelUtils = new BasicModelUtils<>();
}
-
}
if (configuration.getElementsLearningAlgorithm() != null
@@ -551,12 +555,7 @@ public class SequenceVectors extends WordVectorsImpl<
* @return
*/
public Builder sequenceLearningAlgorithm(@NonNull String algoName) {
- try {
- Class clazz = Class.forName(algoName);
- sequenceLearningAlgorithm = (SequenceLearningAlgorithm) clazz.newInstance();
- } catch (Exception e) {
- throw new RuntimeException(e);
- }
+ this.sequenceLearningAlgorithm = DL4JClassLoading.createNewInstance(algoName);
return this;
}
@@ -578,13 +577,9 @@ public class SequenceVectors extends WordVectorsImpl<
* @return
*/
public Builder elementsLearningAlgorithm(@NonNull String algoName) {
- try {
- Class clazz = Class.forName(algoName);
- elementsLearningAlgorithm = (ElementsLearningAlgorithm) clazz.newInstance();
- this.configuration.setElementsLearningAlgorithm(elementsLearningAlgorithm.getClass().getCanonicalName());
- } catch (Exception e) {
- throw new RuntimeException(e);
- }
+ this.elementsLearningAlgorithm = DL4JClassLoading.createNewInstance(algoName);
+ this.configuration.setElementsLearningAlgorithm(elementsLearningAlgorithm.getClass().getCanonicalName());
+
return this;
}
@@ -943,31 +938,23 @@ public class SequenceVectors extends WordVectorsImpl<
.lr(learningRate).seed(seed).build();
}
- if (this.configuration.getElementsLearningAlgorithm() != null) {
- try {
- elementsLearningAlgorithm = (ElementsLearningAlgorithm) Class
- .forName(this.configuration.getElementsLearningAlgorithm()).newInstance();
- } catch (Exception e) {
- throw new RuntimeException(e);
- }
+ String elementsLearningAlgorithm = this.configuration.getElementsLearningAlgorithm();
+ if (StringUtils.isNotEmpty(elementsLearningAlgorithm)) {
+ this.elementsLearningAlgorithm = DL4JClassLoading.createNewInstance(elementsLearningAlgorithm);
}
- if (this.configuration.getSequenceLearningAlgorithm() != null) {
- try {
- sequenceLearningAlgorithm = (SequenceLearningAlgorithm) Class
- .forName(this.configuration.getSequenceLearningAlgorithm()).newInstance();
- } catch (Exception e) {
- throw new RuntimeException(e);
- }
+ String sequenceLearningAlgorithm = this.configuration.getSequenceLearningAlgorithm();
+ if (StringUtils.isNotEmpty(sequenceLearningAlgorithm)) {
+ this.sequenceLearningAlgorithm = DL4JClassLoading.createNewInstance(sequenceLearningAlgorithm);
}
- if (trainElementsVectors && elementsLearningAlgorithm == null) {
+ if (trainElementsVectors && this.elementsLearningAlgorithm == null) {
// create default implementation of ElementsLearningAlgorithm
- elementsLearningAlgorithm = new SkipGram<>();
+ this.elementsLearningAlgorithm = new SkipGram<>();
}
- if (trainSequenceVectors && sequenceLearningAlgorithm == null) {
- sequenceLearningAlgorithm = new DBOW<>();
+ if (trainSequenceVectors && this.sequenceLearningAlgorithm == null) {
+ this.sequenceLearningAlgorithm = new DBOW<>();
}
this.modelUtils.init(lookupTable);
diff --git a/deeplearning4j/deeplearning4j-nn/src/main/java/org/deeplearning4j/nn/conf/dropout/Dropout.java b/deeplearning4j/deeplearning4j-nn/src/main/java/org/deeplearning4j/nn/conf/dropout/Dropout.java
index acb6afa2c..faded5e58 100644
--- a/deeplearning4j/deeplearning4j-nn/src/main/java/org/deeplearning4j/nn/conf/dropout/Dropout.java
+++ b/deeplearning4j/deeplearning4j-nn/src/main/java/org/deeplearning4j/nn/conf/dropout/Dropout.java
@@ -21,6 +21,7 @@ import lombok.EqualsAndHashCode;
import lombok.Getter;
import lombok.Setter;
import lombok.extern.slf4j.Slf4j;
+import org.deeplearning4j.common.config.DL4JClassLoading;
import org.deeplearning4j.nn.workspace.ArrayType;
import org.deeplearning4j.nn.workspace.LayerWorkspaceMgr;
import org.nd4j.common.base.Preconditions;
@@ -132,26 +133,21 @@ public class Dropout implements IDropout {
*/
protected void initializeHelper(DataType dataType){
String backend = Nd4j.getExecutioner().getEnvironmentInformation().getProperty("backend");
+
if("CUDA".equalsIgnoreCase(backend)) {
- try {
- helper = Class.forName("org.deeplearning4j.cuda.dropout.CudnnDropoutHelper")
- .asSubclass(DropoutHelper.class).getConstructor(DataType.class).newInstance(dataType);
- log.debug("CudnnDropoutHelper successfully initialized");
- if (!helper.checkSupported()) {
- helper = null;
- }
- } catch (Throwable t) {
- if (!(t instanceof ClassNotFoundException)) {
- log.warn("Could not initialize CudnnDropoutHelper", t);
- }
- //Unlike other layers, don't warn here about CuDNN not found - if the user has any other layers that can
- // benefit from them cudnn, they will get a warning from those
+ helper = DL4JClassLoading.createNewInstance(
+ "org.deeplearning4j.cuda.dropout.CudnnDropoutHelper",
+ DropoutHelper.class,
+ dataType);
+ log.debug("CudnnDropoutHelper successfully initialized");
+ if (!helper.checkSupported()) {
+ helper = null;
}
}
+
initializedHelper = true;
}
-
@Override
public INDArray applyDropout(INDArray inputActivations, INDArray output, int iteration, int epoch, LayerWorkspaceMgr workspaceMgr) {
Preconditions.checkState(output.dataType().isFPType(), "Output array must be a floating point type, got %s for array of shape %ndShape",
diff --git a/deeplearning4j/deeplearning4j-nn/src/main/java/org/deeplearning4j/nn/conf/serde/BaseNetConfigDeserializer.java b/deeplearning4j/deeplearning4j-nn/src/main/java/org/deeplearning4j/nn/conf/serde/BaseNetConfigDeserializer.java
index a90218946..c4f594ece 100644
--- a/deeplearning4j/deeplearning4j-nn/src/main/java/org/deeplearning4j/nn/conf/serde/BaseNetConfigDeserializer.java
+++ b/deeplearning4j/deeplearning4j-nn/src/main/java/org/deeplearning4j/nn/conf/serde/BaseNetConfigDeserializer.java
@@ -43,6 +43,7 @@ import org.nd4j.shade.jackson.databind.deser.std.StdDeserializer;
import org.nd4j.shade.jackson.databind.node.ObjectNode;
import java.io.IOException;
+import java.lang.reflect.InvocationTargetException;
import java.util.ArrayList;
import java.util.HashMap;
import java.util.Map;
@@ -268,14 +269,17 @@ public abstract class BaseNetConfigDeserializer extends StdDeserializer im
//Changed after 0.7.1 from "activationFunction" : "softmax" to "activationFn" :
protected void handleActivationBackwardCompatibility(BaseLayer baseLayer, ObjectNode on){
-
if(baseLayer.getActivationFn() == null && on.has("activationFunction")){
String afn = on.get("activationFunction").asText();
IActivation a = null;
try {
- a = getMap().get(afn.toLowerCase()).newInstance();
- } catch (InstantiationException | IllegalAccessException e){
- //Ignore
+ a = getMap()
+ .get(afn.toLowerCase())
+ .getDeclaredConstructor()
+ .newInstance();
+ } catch (InstantiationException | IllegalAccessException | NoSuchMethodException
+ | InvocationTargetException instantiationException){
+ log.error(instantiationException.getMessage());
}
baseLayer.setActivationFn(a);
}
diff --git a/deeplearning4j/deeplearning4j-nn/src/main/java/org/deeplearning4j/nn/layers/convolution/ConvolutionLayer.java b/deeplearning4j/deeplearning4j-nn/src/main/java/org/deeplearning4j/nn/layers/convolution/ConvolutionLayer.java
index 6d9f8b534..826df70da 100644
--- a/deeplearning4j/deeplearning4j-nn/src/main/java/org/deeplearning4j/nn/layers/convolution/ConvolutionLayer.java
+++ b/deeplearning4j/deeplearning4j-nn/src/main/java/org/deeplearning4j/nn/layers/convolution/ConvolutionLayer.java
@@ -18,6 +18,7 @@ package org.deeplearning4j.nn.layers.convolution;
import lombok.extern.slf4j.Slf4j;
+import org.deeplearning4j.common.config.DL4JClassLoading;
import org.deeplearning4j.exception.DL4JInvalidInputException;
import org.deeplearning4j.nn.api.MaskState;
import org.deeplearning4j.nn.conf.CNN2DFormat;
@@ -73,26 +74,19 @@ public class ConvolutionLayer extends BaseLayer c = Class.forName("org.nd4j.nativeblas.Nd4jCpu$Environment");
- Method m = c.getMethod("getInstance");
- Object instance = m.invoke(null);
- Method m2 = c.getMethod("isUseMKLDNN");
- boolean b = (Boolean)m2.invoke(instance);
- return b;
+ Class> clazz = DL4JClassLoading.loadClassByName("org.nd4j.nativeblas.Nd4jCpu$Environment");
+ Method getInstance = clazz.getMethod("getInstance");
+ Object instance = getInstance.invoke(null);
+ Method isUseMKLDNNMethod = clazz.getMethod("isUseMKLDNN");
+ return (boolean) isUseMKLDNNMethod.invoke(instance);
} catch (Throwable t ){
FAILED_CHECK = new AtomicBoolean(true);
return false;
diff --git a/deeplearning4j/deeplearning4j-nn/src/main/java/org/deeplearning4j/nn/layers/normalization/BatchNormalization.java b/deeplearning4j/deeplearning4j-nn/src/main/java/org/deeplearning4j/nn/layers/normalization/BatchNormalization.java
index 8e4260845..c2a94a75b 100644
--- a/deeplearning4j/deeplearning4j-nn/src/main/java/org/deeplearning4j/nn/layers/normalization/BatchNormalization.java
+++ b/deeplearning4j/deeplearning4j-nn/src/main/java/org/deeplearning4j/nn/layers/normalization/BatchNormalization.java
@@ -18,6 +18,7 @@ package org.deeplearning4j.nn.layers.normalization;
import lombok.extern.slf4j.Slf4j;
import lombok.val;
+import org.deeplearning4j.common.config.DL4JClassLoading;
import org.deeplearning4j.nn.conf.CNN2DFormat;
import org.deeplearning4j.nn.conf.NeuralNetConfiguration;
import org.deeplearning4j.nn.gradient.DefaultGradient;
@@ -75,24 +76,18 @@ public class BatchNormalization extends BaseLayer
* Created by nyghtowl on 10/29/15.
*/
+@Slf4j
public class LocalResponseNormalization
- extends AbstractLayer {
- protected static final Logger log =
- LoggerFactory.getLogger(org.deeplearning4j.nn.conf.layers.LocalResponseNormalization.class);
+ extends AbstractLayer {
protected LocalResponseNormalizationHelper helper = null;
protected int helperCountFail = 0;
@@ -86,19 +84,11 @@ public class LocalResponseNormalization
void initializeHelper() {
String backend = Nd4j.getExecutioner().getEnvironmentInformation().getProperty("backend");
if("CUDA".equalsIgnoreCase(backend)) {
- try {
- helper = Class.forName("org.deeplearning4j.cuda.normalization.CudnnLocalResponseNormalizationHelper")
- .asSubclass(LocalResponseNormalizationHelper.class).getConstructor(DataType.class).newInstance(dataType);
- log.debug("CudnnLocalResponseNormalizationHelper successfully initialized");
- } catch (Throwable t) {
- if (!(t instanceof ClassNotFoundException)) {
- log.warn("Could not initialize CudnnLocalResponseNormalizationHelper", t);
- } else {
- OneTimeLogger.info(log, "cuDNN not found: "
- + "use cuDNN for better GPU performance by including the deeplearning4j-cuda module. "
- + "For more information, please refer to: https://deeplearning4j.konduit.ai/config/backends/config-cudnn", t);
- }
- }
+ helper = DL4JClassLoading.createNewInstance(
+ "org.deeplearning4j.cuda.normalization.CudnnLocalResponseNormalizationHelper",
+ LocalResponseNormalizationHelper.class,
+ dataType);
+ log.debug("CudnnLocalResponseNormalizationHelper successfully initialized");
}
//2019-03-09 AB - MKL-DNN helper disabled: https://github.com/deeplearning4j/deeplearning4j/issues/7272
// else if("CPU".equalsIgnoreCase(backend)){
diff --git a/deeplearning4j/deeplearning4j-nn/src/main/java/org/deeplearning4j/nn/layers/recurrent/LSTM.java b/deeplearning4j/deeplearning4j-nn/src/main/java/org/deeplearning4j/nn/layers/recurrent/LSTM.java
index 2a2d125d5..b284e6bcb 100644
--- a/deeplearning4j/deeplearning4j-nn/src/main/java/org/deeplearning4j/nn/layers/recurrent/LSTM.java
+++ b/deeplearning4j/deeplearning4j-nn/src/main/java/org/deeplearning4j/nn/layers/recurrent/LSTM.java
@@ -17,6 +17,7 @@
package org.deeplearning4j.nn.layers.recurrent;
import lombok.extern.slf4j.Slf4j;
+import org.deeplearning4j.common.config.DL4JClassLoading;
import org.deeplearning4j.nn.api.MaskState;
import org.deeplearning4j.nn.conf.CacheMode;
import org.deeplearning4j.nn.conf.NeuralNetConfiguration;
@@ -58,21 +59,13 @@ public class LSTM extends BaseRecurrentLayer extends SequenceVec
validateConfiguration();
if (ela == null) {
- try {
- ela = (SparkElementsLearningAlgorithm) Class.forName(configuration.getElementsLearningAlgorithm())
- .newInstance();
- } catch (Exception e) {
- throw new RuntimeException(e);
- }
+ String className = configuration.getElementsLearningAlgorithm();
+ ela = DL4JClassLoading.createNewInstance(className);
}
-
if (workers > 1) {
log.info("Repartitioning corpus to {} parts...", workers);
corpus.repartition(workers);
diff --git a/deeplearning4j/deeplearning4j-scaleout/spark/dl4j-spark-nlp-java8/src/main/java/org/deeplearning4j/spark/models/sequencevectors/functions/BaseTokenizerFunction.java b/deeplearning4j/deeplearning4j-scaleout/spark/dl4j-spark-nlp-java8/src/main/java/org/deeplearning4j/spark/models/sequencevectors/functions/BaseTokenizerFunction.java
index bf6625164..e0e100c88 100644
--- a/deeplearning4j/deeplearning4j-scaleout/spark/dl4j-spark-nlp-java8/src/main/java/org/deeplearning4j/spark/models/sequencevectors/functions/BaseTokenizerFunction.java
+++ b/deeplearning4j/deeplearning4j-scaleout/spark/dl4j-spark-nlp-java8/src/main/java/org/deeplearning4j/spark/models/sequencevectors/functions/BaseTokenizerFunction.java
@@ -18,6 +18,7 @@ package org.deeplearning4j.spark.models.sequencevectors.functions;
import lombok.NonNull;
import org.apache.spark.broadcast.Broadcast;
+import org.deeplearning4j.common.config.DL4JClassLoading;
import org.deeplearning4j.models.embeddings.loader.VectorsConfiguration;
import org.deeplearning4j.text.tokenization.tokenizer.TokenPreProcess;
import org.deeplearning4j.text.tokenization.tokenizerfactory.TokenizerFactory;
@@ -42,24 +43,17 @@ public abstract class BaseTokenizerFunction implements Serializable {
String tpClassName = this.configurationBroadcast.getValue().getTokenPreProcessor();
if (tfClassName != null && !tfClassName.isEmpty()) {
- try {
- tokenizerFactory = (TokenizerFactory) Class.forName(tfClassName).newInstance();
+ tokenizerFactory = DL4JClassLoading.createNewInstance(tfClassName);
- if (tpClassName != null && !tpClassName.isEmpty()) {
- try {
- tokenPreprocessor = (TokenPreProcess) Class.forName(tpClassName).newInstance();
- } catch (Exception e) {
- throw new RuntimeException("Unable to instantiate TokenPreProcessor.", e);
- }
- }
-
- if (tokenPreprocessor != null) {
- tokenizerFactory.setTokenPreProcessor(tokenPreprocessor);
- }
- } catch (Exception e) {
- throw new RuntimeException("Unable to instantiate TokenizerFactory.", e);
+ if (tpClassName != null && !tpClassName.isEmpty()) {
+ tokenPreprocessor = DL4JClassLoading.createNewInstance(tpClassName);
}
- } else
+
+ if (tokenPreprocessor != null) {
+ tokenizerFactory.setTokenPreProcessor(tokenPreprocessor);
+ }
+ } else {
throw new RuntimeException("TokenizerFactory wasn't defined.");
+ }
}
}
diff --git a/deeplearning4j/deeplearning4j-scaleout/spark/dl4j-spark-nlp-java8/src/main/java/org/deeplearning4j/spark/models/sequencevectors/functions/CountFunction.java b/deeplearning4j/deeplearning4j-scaleout/spark/dl4j-spark-nlp-java8/src/main/java/org/deeplearning4j/spark/models/sequencevectors/functions/CountFunction.java
index 15eaa7162..75fc057f4 100644
--- a/deeplearning4j/deeplearning4j-scaleout/spark/dl4j-spark-nlp-java8/src/main/java/org/deeplearning4j/spark/models/sequencevectors/functions/CountFunction.java
+++ b/deeplearning4j/deeplearning4j-scaleout/spark/dl4j-spark-nlp-java8/src/main/java/org/deeplearning4j/spark/models/sequencevectors/functions/CountFunction.java
@@ -21,6 +21,7 @@ import lombok.extern.slf4j.Slf4j;
import org.apache.spark.Accumulator;
import org.apache.spark.api.java.function.Function;
import org.apache.spark.broadcast.Broadcast;
+import org.deeplearning4j.common.config.DL4JClassLoading;
import org.deeplearning4j.models.embeddings.loader.VectorsConfiguration;
import org.deeplearning4j.models.sequencevectors.sequence.Sequence;
import org.deeplearning4j.models.sequencevectors.sequence.SequenceElement;
@@ -65,13 +66,8 @@ public class CountFunction implements Function implements Voi
if (vectorsConfiguration == null)
vectorsConfiguration = configurationBroadcast.getValue();
+ String elementsLearningAlgorithm = vectorsConfiguration.getElementsLearningAlgorithm();
if (paramServer == null) {
paramServer = VoidParameterServer.getInstance();
- if (elementsLearningAlgorithm == null) {
- try {
- elementsLearningAlgorithm = (SparkElementsLearningAlgorithm) Class
- .forName(vectorsConfiguration.getElementsLearningAlgorithm()).newInstance();
- } catch (Exception e) {
- throw new RuntimeException(e);
- }
+ if (this.elementsLearningAlgorithm == null) {
+ this.elementsLearningAlgorithm = DL4JClassLoading.createNewInstance(elementsLearningAlgorithm);
}
- driver = elementsLearningAlgorithm.getTrainingDriver();
+ driver = this.elementsLearningAlgorithm.getTrainingDriver();
// FIXME: init line should probably be removed, basically init happens in VocabRddFunction
paramServer.init(paramServerConfigurationBroadcast.getValue(), new RoutedTransport(), driver);
@@ -95,33 +92,24 @@ public class PartitionTrainingFunction implements Voi
if (shallowVocabCache == null)
shallowVocabCache = vocabCacheBroadcast.getValue();
- if (elementsLearningAlgorithm == null && vectorsConfiguration.getElementsLearningAlgorithm() != null) {
+ if (this.elementsLearningAlgorithm == null && elementsLearningAlgorithm != null) {
// TODO: do ELA initialization
- try {
- elementsLearningAlgorithm = (SparkElementsLearningAlgorithm) Class
- .forName(vectorsConfiguration.getElementsLearningAlgorithm()).newInstance();
- } catch (Exception e) {
- throw new RuntimeException(e);
- }
+ this.elementsLearningAlgorithm = DL4JClassLoading.createNewInstance(elementsLearningAlgorithm);
}
- if (elementsLearningAlgorithm != null)
- elementsLearningAlgorithm.configure(shallowVocabCache, null, vectorsConfiguration);
+ if (this.elementsLearningAlgorithm != null)
+ this.elementsLearningAlgorithm.configure(shallowVocabCache, null, vectorsConfiguration);
- if (sequenceLearningAlgorithm == null && vectorsConfiguration.getSequenceLearningAlgorithm() != null) {
+ String sequenceLearningAlgorithm = vectorsConfiguration.getSequenceLearningAlgorithm();
+ if (this.sequenceLearningAlgorithm == null && sequenceLearningAlgorithm != null) {
// TODO: do SLA initialization
- try {
- sequenceLearningAlgorithm = (SparkSequenceLearningAlgorithm) Class
- .forName(vectorsConfiguration.getSequenceLearningAlgorithm()).newInstance();
- sequenceLearningAlgorithm.configure(shallowVocabCache, null, vectorsConfiguration);
- } catch (Exception e) {
- throw new RuntimeException(e);
- }
+ this.sequenceLearningAlgorithm = DL4JClassLoading.createNewInstance(sequenceLearningAlgorithm);
+ this.sequenceLearningAlgorithm.configure(shallowVocabCache, null, vectorsConfiguration);
}
- if (sequenceLearningAlgorithm != null)
- sequenceLearningAlgorithm.configure(shallowVocabCache, null, vectorsConfiguration);
+ if (this.sequenceLearningAlgorithm != null)
+ this.sequenceLearningAlgorithm.configure(shallowVocabCache, null, vectorsConfiguration);
- if (elementsLearningAlgorithm == null && sequenceLearningAlgorithm == null) {
+ if (this.elementsLearningAlgorithm == null && this.sequenceLearningAlgorithm == null) {
throw new ND4JIllegalStateException("No LearningAlgorithms specified!");
}
@@ -142,7 +130,7 @@ public class PartitionTrainingFunction implements Voi
}
// do the same with labels, transfer them, if any
- if (sequenceLearningAlgorithm != null && vectorsConfiguration.isTrainSequenceVectors()) {
+ if (this.sequenceLearningAlgorithm != null && vectorsConfiguration.isTrainSequenceVectors()) {
for (T label : sequence.getSequenceLabels()) {
ShallowSequenceElement reduced = shallowVocabCache.tokenFor(label.getStorageId());
diff --git a/deeplearning4j/deeplearning4j-scaleout/spark/dl4j-spark-nlp-java8/src/main/java/org/deeplearning4j/spark/models/sequencevectors/functions/TrainingFunction.java b/deeplearning4j/deeplearning4j-scaleout/spark/dl4j-spark-nlp-java8/src/main/java/org/deeplearning4j/spark/models/sequencevectors/functions/TrainingFunction.java
index 8be785c6a..bb80d51c2 100644
--- a/deeplearning4j/deeplearning4j-scaleout/spark/dl4j-spark-nlp-java8/src/main/java/org/deeplearning4j/spark/models/sequencevectors/functions/TrainingFunction.java
+++ b/deeplearning4j/deeplearning4j-scaleout/spark/dl4j-spark-nlp-java8/src/main/java/org/deeplearning4j/spark/models/sequencevectors/functions/TrainingFunction.java
@@ -20,6 +20,7 @@ import lombok.NonNull;
import lombok.extern.slf4j.Slf4j;
import org.apache.spark.api.java.function.VoidFunction;
import org.apache.spark.broadcast.Broadcast;
+import org.deeplearning4j.common.config.DL4JClassLoading;
import org.deeplearning4j.models.embeddings.loader.VectorsConfiguration;
import org.deeplearning4j.models.sequencevectors.sequence.Sequence;
import org.deeplearning4j.models.sequencevectors.sequence.SequenceElement;
@@ -73,19 +74,15 @@ public class TrainingFunction implements VoidFunction
if (vectorsConfiguration == null)
vectorsConfiguration = configurationBroadcast.getValue();
+ String elementsLearningAlgorithm = vectorsConfiguration.getElementsLearningAlgorithm();
if (paramServer == null) {
paramServer = VoidParameterServer.getInstance();
- if (elementsLearningAlgorithm == null) {
- try {
- elementsLearningAlgorithm = (SparkElementsLearningAlgorithm) Class
- .forName(vectorsConfiguration.getElementsLearningAlgorithm()).newInstance();
- } catch (Exception e) {
- throw new RuntimeException(e);
- }
+ if (this.elementsLearningAlgorithm == null) {
+ this.elementsLearningAlgorithm = DL4JClassLoading.createNewInstance(elementsLearningAlgorithm);
}
- driver = elementsLearningAlgorithm.getTrainingDriver();
+ driver = this.elementsLearningAlgorithm.getTrainingDriver();
// FIXME: init line should probably be removed, basically init happens in VocabRddFunction
paramServer.init(paramServerConfigurationBroadcast.getValue(), new RoutedTransport(), driver);
@@ -98,33 +95,23 @@ public class TrainingFunction implements VoidFunction
shallowVocabCache = vocabCacheBroadcast.getValue();
- if (elementsLearningAlgorithm == null && vectorsConfiguration.getElementsLearningAlgorithm() != null) {
+ if (this.elementsLearningAlgorithm == null && elementsLearningAlgorithm != null) {
// TODO: do ELA initialization
- try {
- elementsLearningAlgorithm = (SparkElementsLearningAlgorithm) Class
- .forName(vectorsConfiguration.getElementsLearningAlgorithm()).newInstance();
- elementsLearningAlgorithm.configure(shallowVocabCache, null, vectorsConfiguration);
- } catch (Exception e) {
- throw new RuntimeException(e);
- }
+ this.elementsLearningAlgorithm = DL4JClassLoading.createNewInstance(elementsLearningAlgorithm);
+ this.elementsLearningAlgorithm.configure(shallowVocabCache, null, vectorsConfiguration);
}
- if (sequenceLearningAlgorithm == null && vectorsConfiguration.getSequenceLearningAlgorithm() != null) {
+ String sequenceLearningAlgorithm = vectorsConfiguration.getSequenceLearningAlgorithm();
+ if (this.sequenceLearningAlgorithm == null && sequenceLearningAlgorithm != null) {
// TODO: do SLA initialization
- try {
- sequenceLearningAlgorithm = (SparkSequenceLearningAlgorithm) Class
- .forName(vectorsConfiguration.getSequenceLearningAlgorithm()).newInstance();
- sequenceLearningAlgorithm.configure(shallowVocabCache, null, vectorsConfiguration);
- } catch (Exception e) {
- throw new RuntimeException(e);
- }
+ this.sequenceLearningAlgorithm = DL4JClassLoading.createNewInstance(sequenceLearningAlgorithm);
+ this.sequenceLearningAlgorithm.configure(shallowVocabCache, null, vectorsConfiguration);
}
- if (elementsLearningAlgorithm == null && sequenceLearningAlgorithm == null) {
+ if (this.elementsLearningAlgorithm == null && this.sequenceLearningAlgorithm == null) {
throw new ND4JIllegalStateException("No LearningAlgorithms specified!");
}
-
/*
at this moment we should have everything ready for actual initialization
the only limitation we have - our sequence is detached from actual vocabulary, so we need to merge it back virtually
@@ -139,7 +126,7 @@ public class TrainingFunction implements VoidFunction
}
// do the same with labels, transfer them, if any
- if (sequenceLearningAlgorithm != null && vectorsConfiguration.isTrainSequenceVectors()) {
+ if (this.sequenceLearningAlgorithm != null && vectorsConfiguration.isTrainSequenceVectors()) {
for (T label : sequence.getSequenceLabels()) {
ShallowSequenceElement reduced = shallowVocabCache.tokenFor(label.getStorageId());
@@ -157,7 +144,7 @@ public class TrainingFunction implements VoidFunction
// FIXME: temporary hook
if (sequence.size() > 0)
paramServer.execDistributed(
- elementsLearningAlgorithm.frameSequence(mergedSequence, new AtomicLong(119), 25e-3));
+ this.elementsLearningAlgorithm.frameSequence(mergedSequence, new AtomicLong(119), 25e-3));
else
log.warn("Skipping empty sequence...");
diff --git a/deeplearning4j/deeplearning4j-scaleout/spark/dl4j-spark-nlp-java8/src/main/java/org/deeplearning4j/spark/models/sequencevectors/functions/VocabRddFunctionFlat.java b/deeplearning4j/deeplearning4j-scaleout/spark/dl4j-spark-nlp-java8/src/main/java/org/deeplearning4j/spark/models/sequencevectors/functions/VocabRddFunctionFlat.java
index d14668154..6a6cdad6a 100644
--- a/deeplearning4j/deeplearning4j-scaleout/spark/dl4j-spark-nlp-java8/src/main/java/org/deeplearning4j/spark/models/sequencevectors/functions/VocabRddFunctionFlat.java
+++ b/deeplearning4j/deeplearning4j-scaleout/spark/dl4j-spark-nlp-java8/src/main/java/org/deeplearning4j/spark/models/sequencevectors/functions/VocabRddFunctionFlat.java
@@ -19,6 +19,7 @@ package org.deeplearning4j.spark.models.sequencevectors.functions;
import lombok.NonNull;
import org.apache.spark.api.java.function.FlatMapFunction;
import org.apache.spark.broadcast.Broadcast;
+import org.deeplearning4j.common.config.DL4JClassLoading;
import org.deeplearning4j.models.embeddings.loader.VectorsConfiguration;
import org.deeplearning4j.models.sequencevectors.sequence.Sequence;
import org.deeplearning4j.models.sequencevectors.sequence.SequenceElement;
@@ -56,12 +57,8 @@ public class VocabRddFunctionFlat implements FlatMapF
configuration = vectorsConfigurationBroadcast.getValue();
if (ela == null) {
- try {
- ela = (SparkElementsLearningAlgorithm) Class.forName(configuration.getElementsLearningAlgorithm())
- .newInstance();
- } catch (Exception e) {
- throw new RuntimeException(e);
- }
+ String className = configuration.getElementsLearningAlgorithm();
+ ela = DL4JClassLoading.createNewInstance(className);
}
driver = ela.getTrainingDriver();
diff --git a/deeplearning4j/deeplearning4j-scaleout/spark/dl4j-spark-nlp/src/main/java/org/deeplearning4j/spark/text/functions/TokenizerFunction.java b/deeplearning4j/deeplearning4j-scaleout/spark/dl4j-spark-nlp/src/main/java/org/deeplearning4j/spark/text/functions/TokenizerFunction.java
index 41c9163e1..5049619ee 100644
--- a/deeplearning4j/deeplearning4j-scaleout/spark/dl4j-spark-nlp/src/main/java/org/deeplearning4j/spark/text/functions/TokenizerFunction.java
+++ b/deeplearning4j/deeplearning4j-scaleout/spark/dl4j-spark-nlp/src/main/java/org/deeplearning4j/spark/text/functions/TokenizerFunction.java
@@ -17,12 +17,14 @@
package org.deeplearning4j.spark.text.functions;
import lombok.extern.slf4j.Slf4j;
+import org.apache.commons.lang3.StringUtils;
import org.apache.spark.api.java.function.Function;
+import org.deeplearning4j.common.config.DL4JClassLoading;
import org.deeplearning4j.text.tokenization.tokenizer.TokenPreProcess;
import org.deeplearning4j.text.tokenization.tokenizerfactory.NGramTokenizerFactory;
import org.deeplearning4j.text.tokenization.tokenizerfactory.TokenizerFactory;
-import java.util.Arrays;
+import java.util.Collections;
import java.util.List;
/**
@@ -44,35 +46,33 @@ public class TokenizerFunction implements Function> {
}
@Override
- public List call(String v1) throws Exception {
- if (tokenizerFactory == null)
+ public List call(String str) {
+ if (tokenizerFactory == null) {
tokenizerFactory = getTokenizerFactory();
- if (v1.isEmpty())
- return Arrays.asList("");
- return tokenizerFactory.create(v1).getTokens();
+ }
+
+ if (str.isEmpty()) {
+ return Collections.singletonList("");
+ }
+
+ return tokenizerFactory.create(str).getTokens();
}
private TokenizerFactory getTokenizerFactory() {
- try {
- TokenPreProcess tokenPreProcessInst = null;
- // token preprocess CAN be undefined
- if (tokenizerPreprocessorClazz != null && !tokenizerPreprocessorClazz.isEmpty()) {
- Class extends TokenPreProcess> clazz =
- (Class extends TokenPreProcess>) Class.forName(tokenizerPreprocessorClazz);
- tokenPreProcessInst = clazz.newInstance();
- }
+ TokenPreProcess tokenPreProcessInst = null;
- Class extends TokenizerFactory> clazz2 =
- (Class extends TokenizerFactory>) Class.forName(tokenizerFactoryClazz);
- tokenizerFactory = clazz2.newInstance();
- if (tokenPreProcessInst != null)
- tokenizerFactory.setTokenPreProcessor(tokenPreProcessInst);
- if (nGrams > 1) {
- tokenizerFactory = new NGramTokenizerFactory(tokenizerFactory, nGrams, nGrams);
- }
- } catch (Exception e) {
- log.error("",e);
+ if (StringUtils.isNotEmpty(tokenizerPreprocessorClazz)) {
+ tokenPreProcessInst = DL4JClassLoading.createNewInstance(tokenizerPreprocessorClazz);
}
+
+ tokenizerFactory = DL4JClassLoading.createNewInstance(tokenizerFactoryClazz);
+
+ if (tokenPreProcessInst != null)
+ tokenizerFactory.setTokenPreProcessor(tokenPreProcessInst);
+ if (nGrams > 1) {
+ tokenizerFactory = new NGramTokenizerFactory(tokenizerFactory, nGrams, nGrams);
+ }
+
return tokenizerFactory;
}
diff --git a/deeplearning4j/deeplearning4j-scaleout/spark/dl4j-spark/src/main/java/org/deeplearning4j/spark/time/TimeSourceProvider.java b/deeplearning4j/deeplearning4j-scaleout/spark/dl4j-spark/src/main/java/org/deeplearning4j/spark/time/TimeSourceProvider.java
index 880c6d7fa..b699cf6fc 100644
--- a/deeplearning4j/deeplearning4j-scaleout/spark/dl4j-spark/src/main/java/org/deeplearning4j/spark/time/TimeSourceProvider.java
+++ b/deeplearning4j/deeplearning4j-scaleout/spark/dl4j-spark/src/main/java/org/deeplearning4j/spark/time/TimeSourceProvider.java
@@ -16,6 +16,7 @@
package org.deeplearning4j.spark.time;
+import org.deeplearning4j.common.config.DL4JClassLoading;
import org.deeplearning4j.common.config.DL4JSystemProperties;
import java.lang.reflect.Method;
@@ -62,9 +63,9 @@ public class TimeSourceProvider {
*/
public static TimeSource getInstance(String className) {
try {
- Class> c = Class.forName(className);
- Method m = c.getMethod("getInstance");
- return (TimeSource) m.invoke(null);
+ Class> clazz = DL4JClassLoading.loadClassByName(className);
+ Method getInstance = clazz.getMethod("getInstance");
+ return (TimeSource) getInstance.invoke(null);
} catch (Exception e) {
throw new RuntimeException("Error getting TimeSource instance for class \"" + className + "\"", e);
}
diff --git a/deeplearning4j/deeplearning4j-ui-parent/deeplearning4j-ui-model/src/main/java/org/deeplearning4j/ui/model/stats/BaseStatsListener.java b/deeplearning4j/deeplearning4j-ui-parent/deeplearning4j-ui-model/src/main/java/org/deeplearning4j/ui/model/stats/BaseStatsListener.java
index 34d5dbb12..7e4518d16 100644
--- a/deeplearning4j/deeplearning4j-ui-parent/deeplearning4j-ui-model/src/main/java/org/deeplearning4j/ui/model/stats/BaseStatsListener.java
+++ b/deeplearning4j/deeplearning4j-ui-parent/deeplearning4j-ui-model/src/main/java/org/deeplearning4j/ui/model/stats/BaseStatsListener.java
@@ -19,6 +19,7 @@ package org.deeplearning4j.ui.model.stats;
import lombok.extern.slf4j.Slf4j;
import org.apache.commons.io.IOUtils;
import org.bytedeco.javacpp.Pointer;
+import org.deeplearning4j.common.config.DL4JClassLoading;
import org.deeplearning4j.core.storage.StatsStorageRouter;
import org.deeplearning4j.core.storage.StorageMetaData;
import org.deeplearning4j.core.storage.listener.RoutingIterationListener;
@@ -696,11 +697,14 @@ public abstract class BaseStatsListener implements RoutingIterationListener {
return devPointers.get(device);
}
try {
- Class> c = Class.forName("org.nd4j.jita.allocator.pointers.CudaPointer");
- Constructor> constructor = c.getConstructor(long.class);
- Pointer p = (Pointer) constructor.newInstance((long) device);
- devPointers.put(device, p);
- return p;
+ Pointer pointer = DL4JClassLoading.createNewInstance(
+ "org.nd4j.jita.allocator.pointers.CudaPointer",
+ Pointer.class,
+ new Class[] { long.class },
+ (long) device);
+
+ devPointers.put(device, pointer);
+ return pointer;
} catch (Throwable t) {
devPointers.put(device, null); //Stops attempting the failure again later...
return null;
@@ -711,9 +715,9 @@ public abstract class BaseStatsListener implements RoutingIterationListener {
ModelInfo modelInfo = getModelInfo(model);
int examplesThisMinibatch = 0;
if (model instanceof MultiLayerNetwork) {
- examplesThisMinibatch = ((MultiLayerNetwork) model).batchSize();
+ examplesThisMinibatch = model.batchSize();
} else if (model instanceof ComputationGraph) {
- examplesThisMinibatch = ((ComputationGraph) model).batchSize();
+ examplesThisMinibatch = model.batchSize();
} else if (model instanceof Layer) {
examplesThisMinibatch = ((Layer) model).getInputMiniBatchSize();
}
diff --git a/deeplearning4j/deeplearning4j-ui-parent/deeplearning4j-ui-model/src/main/java/org/deeplearning4j/ui/model/storage/mapdb/MapDBStatsStorage.java b/deeplearning4j/deeplearning4j-ui-parent/deeplearning4j-ui-model/src/main/java/org/deeplearning4j/ui/model/storage/mapdb/MapDBStatsStorage.java
index 307fe6a9f..4ac703501 100644
--- a/deeplearning4j/deeplearning4j-ui-parent/deeplearning4j-ui-model/src/main/java/org/deeplearning4j/ui/model/storage/mapdb/MapDBStatsStorage.java
+++ b/deeplearning4j/deeplearning4j-ui-parent/deeplearning4j-ui-model/src/main/java/org/deeplearning4j/ui/model/storage/mapdb/MapDBStatsStorage.java
@@ -18,6 +18,7 @@ package org.deeplearning4j.ui.model.storage.mapdb;
import lombok.Data;
import lombok.NonNull;
+import org.deeplearning4j.common.config.DL4JClassLoading;
import org.deeplearning4j.core.storage.*;
import org.deeplearning4j.ui.model.storage.FileStatsStorage;
import org.deeplearning4j.ui.model.storage.InMemoryStatsStorage;
@@ -318,26 +319,18 @@ public class MapDBStatsStorage extends BaseCollectionStatsStorage {
}
@Override
+ @SuppressWarnings("unchecked")
public T deserialize(@NonNull DataInput2 input, int available) throws IOException {
int classIdx = input.readInt();
String className = getClassForInt(classIdx);
- Class> clazz;
- try {
- clazz = Class.forName(className);
- } catch (ClassNotFoundException e) {
- throw new RuntimeException(e); //Shouldn't normally happen...
- }
- Persistable p;
- try {
- p = (Persistable) clazz.newInstance();
- } catch (InstantiationException | IllegalAccessException e) {
- throw new RuntimeException(e);
- }
- int remainingLength = available - 4; //-4 for int class index
+
+ Persistable persistable = DL4JClassLoading.createNewInstance(className);
+
+ int remainingLength = available - 4; // -4 for int class index
byte[] temp = new byte[remainingLength];
input.readFully(temp);
- p.decode(temp);
- return (T) p;
+ persistable.decode(temp);
+ return (T) persistable;
}
@Override
diff --git a/deeplearning4j/deeplearning4j-ui-parent/deeplearning4j-vertx/src/main/java/org/deeplearning4j/ui/VertxUIServer.java b/deeplearning4j/deeplearning4j-ui-parent/deeplearning4j-vertx/src/main/java/org/deeplearning4j/ui/VertxUIServer.java
index ff59dd2d8..56a3f2e63 100644
--- a/deeplearning4j/deeplearning4j-ui-parent/deeplearning4j-vertx/src/main/java/org/deeplearning4j/ui/VertxUIServer.java
+++ b/deeplearning4j/deeplearning4j-ui-parent/deeplearning4j-vertx/src/main/java/org/deeplearning4j/ui/VertxUIServer.java
@@ -32,32 +32,42 @@ import lombok.Getter;
import lombok.Setter;
import lombok.extern.slf4j.Slf4j;
import org.apache.commons.io.FilenameUtils;
+import org.deeplearning4j.common.config.DL4JClassLoading;
+import org.deeplearning4j.common.config.DL4JSystemProperties;
+import org.deeplearning4j.common.util.DL4JFileUtils;
import org.deeplearning4j.core.storage.StatsStorage;
import org.deeplearning4j.core.storage.StatsStorageEvent;
import org.deeplearning4j.core.storage.StatsStorageListener;
import org.deeplearning4j.core.storage.StatsStorageRouter;
-import org.deeplearning4j.common.config.DL4JSystemProperties;
import org.deeplearning4j.exception.DL4JException;
import org.deeplearning4j.ui.api.Route;
import org.deeplearning4j.ui.api.UIModule;
import org.deeplearning4j.ui.api.UIServer;
import org.deeplearning4j.ui.i18n.I18NProvider;
+import org.deeplearning4j.ui.model.storage.FileStatsStorage;
+import org.deeplearning4j.ui.model.storage.InMemoryStatsStorage;
+import org.deeplearning4j.ui.model.storage.impl.QueueStatsStorageListener;
import org.deeplearning4j.ui.module.SameDiffModule;
import org.deeplearning4j.ui.module.convolutional.ConvolutionalListenerModule;
import org.deeplearning4j.ui.module.defaultModule.DefaultModule;
import org.deeplearning4j.ui.module.remote.RemoteReceiverModule;
import org.deeplearning4j.ui.module.train.TrainModule;
import org.deeplearning4j.ui.module.tsne.TsneModule;
-import org.deeplearning4j.ui.model.storage.FileStatsStorage;
-import org.deeplearning4j.ui.model.storage.InMemoryStatsStorage;
-import org.deeplearning4j.ui.model.storage.impl.QueueStatsStorageListener;
-import org.deeplearning4j.common.util.DL4JFileUtils;
import org.nd4j.common.function.Function;
import org.nd4j.common.primitives.Pair;
import java.io.File;
-import java.util.*;
-import java.util.concurrent.*;
+import java.util.ArrayList;
+import java.util.Collections;
+import java.util.Iterator;
+import java.util.List;
+import java.util.Map;
+import java.util.ServiceLoader;
+import java.util.concurrent.BlockingQueue;
+import java.util.concurrent.ConcurrentHashMap;
+import java.util.concurrent.CopyOnWriteArrayList;
+import java.util.concurrent.CountDownLatch;
+import java.util.concurrent.LinkedBlockingQueue;
import java.util.concurrent.atomic.AtomicBoolean;
@Slf4j
@@ -402,8 +412,7 @@ public class VertxUIServer extends AbstractVerticle implements UIServer {
}
private void modulesViaServiceLoader(List uiModules) {
-
- ServiceLoader sl = ServiceLoader.load(UIModule.class);
+ ServiceLoader sl = DL4JClassLoading.loadService(UIModule.class);
Iterator iter = sl.iterator();
if (!iter.hasNext()) {
@@ -411,19 +420,19 @@ public class VertxUIServer extends AbstractVerticle implements UIServer {
}
while (iter.hasNext()) {
- UIModule m = iter.next();
- Class> c = m.getClass();
+ UIModule module = iter.next();
+ Class> moduleClass = module.getClass();
boolean foundExisting = false;
for (UIModule mExisting : uiModules) {
- if (mExisting.getClass() == c) {
+ if (mExisting.getClass() == moduleClass) {
foundExisting = true;
break;
}
}
if (!foundExisting) {
- log.debug("Loaded UI module via service loader: {}", m.getClass());
- uiModules.add(m);
+ log.debug("Loaded UI module via service loader: {}", module.getClass());
+ uiModules.add(module);
}
}
}
diff --git a/deeplearning4j/deeplearning4j-ui-parent/deeplearning4j-vertx/src/main/java/org/deeplearning4j/ui/i18n/DefaultI18N.java b/deeplearning4j/deeplearning4j-ui-parent/deeplearning4j-vertx/src/main/java/org/deeplearning4j/ui/i18n/DefaultI18N.java
index 523dcd48e..2dd768b07 100644
--- a/deeplearning4j/deeplearning4j-ui-parent/deeplearning4j-vertx/src/main/java/org/deeplearning4j/ui/i18n/DefaultI18N.java
+++ b/deeplearning4j/deeplearning4j-ui-parent/deeplearning4j-vertx/src/main/java/org/deeplearning4j/ui/i18n/DefaultI18N.java
@@ -19,6 +19,7 @@ package org.deeplearning4j.ui.i18n;
import lombok.extern.slf4j.Slf4j;
import org.apache.commons.io.IOUtils;
+import org.deeplearning4j.common.config.DL4JClassLoading;
import org.deeplearning4j.ui.api.I18N;
import org.deeplearning4j.ui.api.UIModule;
@@ -100,13 +101,13 @@ public class DefaultI18N implements I18N {
}
private synchronized void loadLanguages(){
- ServiceLoader sl = ServiceLoader.load(UIModule.class);
+ ServiceLoader loadedModules = DL4JClassLoading.loadService(UIModule.class);
- for(UIModule m : sl){
- List resources = m.getInternationalizationResources();
- for(I18NResource r : resources){
+ for (UIModule module : loadedModules){
+ List resources = module.getInternationalizationResources();
+ for(I18NResource resource : resources){
try {
- String path = r.getResource();
+ String path = resource.getResource();
int idxLast = path.lastIndexOf('.');
if (idxLast < 0) {
log.warn("Skipping language resource file: cannot infer language: {}", path);
@@ -116,9 +117,9 @@ public class DefaultI18N implements I18N {
String langCode = path.substring(idxLast + 1).toLowerCase();
Map map = messagesByLanguage.computeIfAbsent(langCode, k -> new HashMap<>());
- parseFile(r, map);
+ parseFile(resource, map);
} catch (Throwable t){
- log.warn("Error parsing UI I18N content file; skipping: {}", r.getResource(), t);
+ log.warn("Error parsing UI I18N content file; skipping: {}", resource.getResource(), t);
languageLoadingException = t;
}
}
diff --git a/deeplearning4j/deeplearning4j-ui-parent/deeplearning4j-vertx/src/main/java/org/deeplearning4j/ui/module/remote/RemoteReceiverModule.java b/deeplearning4j/deeplearning4j-ui-parent/deeplearning4j-vertx/src/main/java/org/deeplearning4j/ui/module/remote/RemoteReceiverModule.java
index baee006fe..1f5040991 100644
--- a/deeplearning4j/deeplearning4j-ui-parent/deeplearning4j-vertx/src/main/java/org/deeplearning4j/ui/module/remote/RemoteReceiverModule.java
+++ b/deeplearning4j/deeplearning4j-ui-parent/deeplearning4j-vertx/src/main/java/org/deeplearning4j/ui/module/remote/RemoteReceiverModule.java
@@ -22,6 +22,7 @@ import io.netty.handler.codec.http.HttpResponseStatus;
import io.vertx.core.json.JsonObject;
import io.vertx.ext.web.RoutingContext;
import lombok.extern.slf4j.Slf4j;
+import org.deeplearning4j.common.config.DL4JClassLoading;
import org.deeplearning4j.core.storage.*;
import org.deeplearning4j.core.storage.impl.RemoteUIStatsStorageRouter;
import org.deeplearning4j.ui.api.HttpMethod;
@@ -154,9 +155,12 @@ public class RemoteReceiverModule implements UIModule {
private StorageMetaData getMetaData(String dataClass, String content) {
StorageMetaData meta;
try {
- Class> c = Class.forName(dataClass);
- if (StorageMetaData.class.isAssignableFrom(c)) {
- meta = (StorageMetaData) c.newInstance();
+ Class> clazz = DL4JClassLoading.loadClassByName(dataClass);
+ if (StorageMetaData.class.isAssignableFrom(clazz)) {
+ meta = clazz
+ .asSubclass(StorageMetaData.class)
+ .getDeclaredConstructor()
+ .newInstance();
} else {
log.warn("Skipping invalid remote data: class {} in not an instance of {}", dataClass,
StorageMetaData.class.getName());
@@ -179,11 +183,14 @@ public class RemoteReceiverModule implements UIModule {
}
private Persistable getPersistable(String dataClass, String content) {
- Persistable p;
+ Persistable persistable;
try {
- Class> c = Class.forName(dataClass);
- if (Persistable.class.isAssignableFrom(c)) {
- p = (Persistable) c.newInstance();
+ Class> clazz = DL4JClassLoading.loadClassByName(dataClass);
+ if (Persistable.class.isAssignableFrom(clazz)) {
+ persistable = clazz
+ .asSubclass(Persistable.class)
+ .getDeclaredConstructor()
+ .newInstance();
} else {
log.warn("Skipping invalid remote data: class {} in not an instance of {}", dataClass,
Persistable.class.getName());
@@ -196,12 +203,12 @@ public class RemoteReceiverModule implements UIModule {
try {
byte[] bytes = DatatypeConverter.parseBase64Binary(content);
- p.decode(bytes);
+ persistable.decode(bytes);
} catch (Exception e) {
log.warn("Skipping invalid remote data: exception encountered when deserializing data", e);
return null;
}
- return p;
+ return persistable;
}
}
\ No newline at end of file
diff --git a/deeplearning4j/dl4j-integration-tests/src/test/java/org/deeplearning4j/integration/IntegrationTestRunner.java b/deeplearning4j/dl4j-integration-tests/src/test/java/org/deeplearning4j/integration/IntegrationTestRunner.java
index 618184824..5d35fdeb6 100644
--- a/deeplearning4j/dl4j-integration-tests/src/test/java/org/deeplearning4j/integration/IntegrationTestRunner.java
+++ b/deeplearning4j/dl4j-integration-tests/src/test/java/org/deeplearning4j/integration/IntegrationTestRunner.java
@@ -22,6 +22,7 @@ import lombok.NonNull;
import lombok.extern.slf4j.Slf4j;
import org.apache.commons.io.FileUtils;
import org.deeplearning4j.BaseDL4JTest;
+import org.deeplearning4j.common.config.DL4JClassLoading;
import org.deeplearning4j.datasets.iterator.MultiDataSetWrapperIterator;
import org.deeplearning4j.integration.util.CountingMultiDataSetIterator;
import org.deeplearning4j.nn.api.Model;
@@ -127,7 +128,7 @@ public class IntegrationTestRunner {
}
for (ClassPath.ClassInfo c : info) {
- Class> clazz = Class.forName(c.getName());
+ Class> clazz = DL4JClassLoading.loadClassByName(c.getName());
if (Modifier.isAbstract(clazz.getModifiers()) || clazz.isInterface())
continue;
diff --git a/nd4j/nd4j-common/src/main/java/org/nd4j/common/config/ND4JClassLoading.java b/nd4j/nd4j-common/src/main/java/org/nd4j/common/config/ND4JClassLoading.java
index 56d95642d..444010788 100644
--- a/nd4j/nd4j-common/src/main/java/org/nd4j/common/config/ND4JClassLoading.java
+++ b/nd4j/nd4j-common/src/main/java/org/nd4j/common/config/ND4JClassLoading.java
@@ -1,3 +1,19 @@
+/*******************************************************************************
+ * Copyright (c) Eclipse Deeplearning4j Contributors 2020
+ *
+ * This program and the accompanying materials are made available under the
+ * terms of the Apache License, Version 2.0 which is available at
+ * https://www.apache.org/licenses/LICENSE-2.0.
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
+ * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
+ * License for the specific language governing permissions and limitations
+ * under the License.
+ *
+ * SPDX-License-Identifier: Apache-2.0
+ ******************************************************************************/
+
package org.nd4j.common.config;
import lombok.extern.slf4j.Slf4j;
@@ -6,6 +22,22 @@ import java.util.ServiceLoader;
/**
* Global context for class-loading in ND4J.
+ * Use {@code ND4JClassLoading} to define classloader for ND4J only! To define classloader used by
+ * {@code Deeplearning4j} use class {@link org.deeplearning4j.common.config.DL4JClassLoading}.
+ *
+ *
Usage:
+ *
{@code
+ * public class Application {
+ * static {
+ * ND4JClassLoading.setNd4jClassloaderFromClass(Application.class);
+ * }
+ *
+ * public static void main(String[] args) {
+ * }
+ * }
+ * }
+ *
+ * @see org.deeplearning4j.common.config.DL4JClassLoading
*
* @author Alexei KLENIN
*/