diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/autodiff/functions/DifferentialFunctionFactory.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/autodiff/functions/DifferentialFunctionFactory.java
index 0d58f024d..32d2d1474 100644
--- a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/autodiff/functions/DifferentialFunctionFactory.java
+++ b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/autodiff/functions/DifferentialFunctionFactory.java
@@ -1559,7 +1559,7 @@ public class DifferentialFunctionFactory {
public SDVariable elu(SDVariable iX) {
- return new ELU(sameDiff(), iX, false).outputVariable();
+ return new ELU(sameDiff(), iX).outputVariable();
}
diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/activations/impl/ActivationELU.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/activations/impl/ActivationELU.java
index 665c84096..56fd84676 100644
--- a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/activations/impl/ActivationELU.java
+++ b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/activations/impl/ActivationELU.java
@@ -58,7 +58,7 @@ public class ActivationELU extends BaseActivationFunction {
public INDArray getActivation(INDArray in, boolean training) {
// no support in ELU native to override alpha
if (this.alpha != 1.00) {
- INDArray alphaMultiple = Nd4j.getExecutioner().exec(new ELU(in.dup()));
+ INDArray alphaMultiple = Nd4j.getExecutioner().exec(new ELU(in.dup()))[0];
alphaMultiple.muli(alpha);
BooleanIndexing.replaceWhere(in, alphaMultiple, Conditions.lessThan(0));
} else {
diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/transforms/strict/ELU.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/transforms/strict/ELU.java
index 74d258fb1..7c85bfb1d 100644
--- a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/transforms/strict/ELU.java
+++ b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/transforms/strict/ELU.java
@@ -16,17 +16,20 @@
package org.nd4j.linalg.api.ops.impl.transforms.strict;
-import java.util.Collections;
import org.nd4j.autodiff.samediff.SDVariable;
import org.nd4j.autodiff.samediff.SameDiff;
+import org.nd4j.base.Preconditions;
import org.nd4j.imports.NoOpNameFoundException;
+import org.nd4j.linalg.api.buffer.DataType;
import org.nd4j.linalg.api.ndarray.INDArray;
-import org.nd4j.linalg.api.ops.BaseTransformFloatOp;
-import org.nd4j.linalg.api.ops.BaseTransformOp;
-import org.nd4j.linalg.api.ops.BaseTransformStrictOp;
+import org.nd4j.linalg.api.ops.DynamicCustomOp;
+import org.tensorflow.framework.AttrValue;
+import org.tensorflow.framework.GraphDef;
+import org.tensorflow.framework.NodeDef;
-import java.util.Arrays;
+import java.util.Collections;
import java.util.List;
+import java.util.Map;
/**
* ELU: Exponential Linear Unit (alpha=1.0)
@@ -37,25 +40,20 @@ import java.util.List;
*
* @author Alex Black
*/
-public class ELU extends BaseTransformStrictOp {
- public ELU(SameDiff sameDiff, SDVariable i_v, boolean inPlace) {
- super(sameDiff, i_v, inPlace);
+public class ELU extends DynamicCustomOp {
+ public ELU(SameDiff sameDiff, SDVariable i_v) {
+ super(sameDiff, new SDVariable[]{i_v});
}
public ELU() {
}
public ELU(INDArray x, INDArray z) {
- super(x, z);
+ super(null, wrapOrNull(x), wrapOrNull(z));
}
public ELU(INDArray x) {
- super(x);
- }
-
- @Override
- public int opNum() {
- return 35;
+ this(x, null);
}
@Override
@@ -73,6 +71,11 @@ public class ELU extends BaseTransformStrictOp {
return "Elu";
}
+ @Override
+ public void initFromTensorFlow(NodeDef nodeDef, SameDiff initWith, Map attributesForNode, GraphDef graph) {
+ super.initFromTensorFlow(nodeDef, initWith, attributesForNode, graph);
+ }
+
@Override
public List doDiff(List i_v) {
//ELU: e^x-1 if x<0, x otherwise
@@ -80,4 +83,11 @@ public class ELU extends BaseTransformStrictOp {
return Collections.singletonList(f().eluBp(arg(), i_v.get(0)));
}
+ @Override
+ public List calculateOutputDataTypes(List dataTypes) {
+ Preconditions.checkState(dataTypes != null && dataTypes.size() == 1, "Expected exactly 1 datatype for ELU, got %s", dataTypes);
+ Preconditions.checkState(dataTypes.get(0).isFPType(), "Expected floating point input type for ELU, got %s", dataTypes);
+
+ return dataTypes;
+ }
}
diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/ops/transforms/Transforms.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/ops/transforms/Transforms.java
index 1f4004cb2..af95c73f2 100644
--- a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/ops/transforms/Transforms.java
+++ b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/ops/transforms/Transforms.java
@@ -438,7 +438,7 @@ public class Transforms {
public static INDArray elu(INDArray in, boolean copy) {
- return Nd4j.getExecutioner().exec(new ELU(in, (copy ? in.ulike() : in)));
+ return Nd4j.getExecutioner().exec(new ELU(in, (copy ? in.ulike() : in)))[0];
}
public static INDArray eluDerivative(INDArray arr) {