From a2ec3dbc97a1f2aa9dc8b8c683c2301f58e87aed Mon Sep 17 00:00:00 2001 From: Andrii T <39699084+atuzhykov@users.noreply.github.com> Date: Mon, 9 Mar 2020 04:35:17 +0200 Subject: [PATCH] Image namespace (#176) * created NDImage.java and fixed constructor in AdjustContrast.java * created NDImage.java and fixed constructor in AdjustContrast.java * created NDImage.java and fixed constructor in AdjustContrast.java v2 * regenerated NDImage from cleaned Image,kt also cleaned AdjustContrast.java * draft of NDCNN * draft of NDCNN * started NDRNN * started NDRNN * looking like finished with namespace * Regenerate namespaces Signed-off-by: AlexDBlack * Add ND4J namespace methods for new namespaces Signed-off-by: AlexDBlack * Fixes, cleanup Signed-off-by: Alex Black * More fixes Signed-off-by: Alex Black * Fixes Signed-off-by: Alex Black * Fix Signed-off-by: Alex Black Co-authored-by: Andrii Tuzhykov Co-authored-by: Andrii Tuzhykov Co-authored-by: AlexDBlack --- .../linalg/api/ops/custom/AdjustContrast.java | 4 + .../api/ops/impl/image/CropAndResize.java | 9 + .../ops/impl/image/ExtractImagePatches.java | 12 + .../impl/layers/convolution/AvgPooling2D.java | 5 + .../impl/layers/convolution/AvgPooling3D.java | 10 +- .../ops/impl/layers/convolution/Col2Im.java | 10 +- .../ops/impl/layers/convolution/Conv1D.java | 8 + .../ops/impl/layers/convolution/Conv2D.java | 8 + .../ops/impl/layers/convolution/Conv3D.java | 8 + .../ops/impl/layers/convolution/DeConv2D.java | 9 + .../ops/impl/layers/convolution/DeConv3D.java | 10 +- .../impl/layers/convolution/DepthToSpace.java | 8 +- .../layers/convolution/DepthwiseConv2D.java | 17 +- .../ops/impl/layers/convolution/Im2col.java | 8 +- .../LocalResponseNormalization.java | 7 + .../impl/layers/convolution/MaxPooling2D.java | 7 +- .../impl/layers/convolution/MaxPooling3D.java | 9 +- .../impl/layers/convolution/Pooling3D.java | 4 +- .../convolution/Pooling3DDerivative.java | 5 +- .../ops/impl/layers/convolution/SConv2D.java | 10 +- 
.../impl/layers/convolution/SpaceToDepth.java | 9 + .../impl/layers/convolution/Upsampling2d.java | 19 +- .../ops/impl/layers/recurrent/GRUCell.java | 7 + .../impl/layers/recurrent/LSTMBlockCell.java | 10 + .../ops/impl/layers/recurrent/LSTMCell.java | 17 +- .../ops/impl/layers/recurrent/LSTMLayer.java | 9 + .../api/ops/impl/layers/recurrent/SRU.java | 18 +- .../ops/impl/layers/recurrent/SRUCell.java | 9 + .../layers/recurrent/weights/GRUWeights.java | 14 +- .../layers/recurrent/weights/LSTMWeights.java | 16 +- .../layers/recurrent/weights/RNNWeights.java | 37 +- .../layers/recurrent/weights/SRUWeights.java | 12 +- .../impl/transforms/custom/BatchToSpace.java | 13 + .../impl/transforms/custom/Dilation2D.java | 35 +- .../impl/transforms/custom/SpaceToBatch.java | 14 + .../java/org/nd4j/linalg/factory/Nd4j.java | 36 ++ .../nd4j/linalg/factory/enums/DataFormat.java | 27 + .../org/nd4j/linalg/factory/ops/NDCNN.java | 499 ++++++++++++++++++ .../org/nd4j/linalg/factory/ops/NDImage.java | 221 ++++++++ .../org/nd4j/linalg/factory/ops/NDRNN.java | 130 +++++ 40 files changed, 1239 insertions(+), 81 deletions(-) create mode 100644 nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/factory/enums/DataFormat.java create mode 100755 nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/factory/ops/NDCNN.java create mode 100644 nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/factory/ops/NDImage.java create mode 100755 nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/factory/ops/NDRNN.java diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/custom/AdjustContrast.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/custom/AdjustContrast.java index e196e5108..55d551369 100644 --- a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/custom/AdjustContrast.java +++ 
b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/custom/AdjustContrast.java @@ -33,6 +33,10 @@ public class AdjustContrast extends BaseAdjustContrast { super(sameDiff,new SDVariable[]{in,factor}); } + public AdjustContrast(@NonNull INDArray in, double factor) { + this(in, factor, null); + } + @Override public String opName() { return "adjust_contrast"; diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/image/CropAndResize.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/image/CropAndResize.java index 96f667690..97b826064 100644 --- a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/image/CropAndResize.java +++ b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/image/CropAndResize.java @@ -1,5 +1,6 @@ /******************************************************************************* * Copyright (c) 2015-2018 Skymind, Inc. + * Copyright (c) 2020 Konduit K.K. 
* * This program and the accompanying materials are made available under the * terms of the Apache License, Version 2.0 which is available at @@ -36,6 +37,7 @@ import java.util.*; */ @NoArgsConstructor public class CropAndResize extends DynamicCustomOp { + public enum Method {BILINEAR, NEAREST}; protected Method method = Method.BILINEAR; protected double extrapolationValue = 0.0; @@ -48,6 +50,7 @@ public class CropAndResize extends DynamicCustomOp { addArgs(); } + public CropAndResize(@NonNull INDArray image, @NonNull INDArray cropBoxes, @NonNull INDArray boxIndices, @NonNull INDArray cropOutSize, @NonNull Method method, double extrapolationValue, INDArray output){ @@ -62,6 +65,12 @@ public class CropAndResize extends DynamicCustomOp { outputArguments.add(output); } + public CropAndResize(@NonNull INDArray image, @NonNull INDArray cropBoxes, @NonNull INDArray boxIndices, + @NonNull INDArray cropOutSize, double extrapolationValue) { + this(image, cropBoxes, boxIndices, cropOutSize, Method.BILINEAR, extrapolationValue, null); + } + + @Override public String opName() { return "crop_and_resize"; diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/image/ExtractImagePatches.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/image/ExtractImagePatches.java index 18bf3cb65..71b8b1fb2 100644 --- a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/image/ExtractImagePatches.java +++ b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/image/ExtractImagePatches.java @@ -72,6 +72,18 @@ public class ExtractImagePatches extends DynamicCustomOp { addArgs(); } + public ExtractImagePatches(INDArray input, int kH, int kW, int sH, int sW, int rH, int rW, boolean sameMode) { + super(new INDArray[]{input},null); + int[] kSises = {kH,kW}; + int[] strides = {sH,sW}; + int[] rates = {rH, rW}; + this.kSizes = kSises; + 
this.strides = strides; + this.rates = rates; + this.isSameMode = sameMode; + addArgs(); + } + @Override public String opName() { diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/layers/convolution/AvgPooling2D.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/layers/convolution/AvgPooling2D.java index 2f295cc6a..17ed2e124 100644 --- a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/layers/convolution/AvgPooling2D.java +++ b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/layers/convolution/AvgPooling2D.java @@ -49,10 +49,15 @@ public class AvgPooling2D extends DynamicCustomOp { protected Pooling2DConfig config; + public enum Pooling2DType { MAX, AVG, PNORM, } + public AvgPooling2D(@NonNull INDArray input, Pooling2DConfig config) { + this(input, null, config); + } + @Builder(builderMethodName = "sameDiffBuilder") public AvgPooling2D(SameDiff sameDiff, SDVariable input, Pooling2DConfig config) { diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/layers/convolution/AvgPooling3D.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/layers/convolution/AvgPooling3D.java index 178942d18..79bcacab0 100644 --- a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/layers/convolution/AvgPooling3D.java +++ b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/layers/convolution/AvgPooling3D.java @@ -17,6 +17,8 @@ package org.nd4j.linalg.api.ops.impl.layers.convolution; import lombok.Getter; +import lombok.NoArgsConstructor; +import lombok.NonNull; import lombok.extern.slf4j.Slf4j; import onnx.Onnx; import org.nd4j.autodiff.samediff.SDVariable; @@ -38,9 +40,8 @@ import java.util.Map; */ @Slf4j @Getter +@NoArgsConstructor public class AvgPooling3D 
extends Pooling3D { - public AvgPooling3D() { - } public AvgPooling3D(SameDiff sameDiff, SDVariable input, Pooling3DConfig config) { super(sameDiff, new SDVariable[]{input}, null, null, false, config, Pooling3DType.AVG); @@ -50,6 +51,11 @@ public class AvgPooling3D extends Pooling3D { super(sameDiff, null, new INDArray[]{arrayInput}, wrapOrNull(arrayOutput), false, config, Pooling3DType.AVG); } + public AvgPooling3D(@NonNull INDArray input, Pooling3DConfig pooling3DConfig) { + super(null,null,new INDArray[]{input},null,false, pooling3DConfig, Pooling3DType.AVG); + } + + @Override public boolean isConfigProperties() { return true; diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/layers/convolution/Col2Im.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/layers/convolution/Col2Im.java index 620fd6092..4a7df20b4 100644 --- a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/layers/convolution/Col2Im.java +++ b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/layers/convolution/Col2Im.java @@ -18,6 +18,7 @@ package org.nd4j.linalg.api.ops.impl.layers.convolution; import lombok.Builder; import lombok.Getter; +import lombok.NonNull; import org.nd4j.autodiff.samediff.SDVariable; import org.nd4j.autodiff.samediff.SameDiff; import org.nd4j.base.Preconditions; @@ -53,7 +54,7 @@ public class Col2Im extends DynamicCustomOp { addArgs(); } - public Col2Im(SameDiff sd, SDVariable input, Conv2DConfig config){ + public Col2Im(@NonNull SameDiff sd, @NonNull SDVariable input, @NonNull Conv2DConfig config){ super(null, sd, new SDVariable[]{input}); this.conv2DConfig = config; addArgs(); @@ -61,6 +62,13 @@ public class Col2Im extends DynamicCustomOp { public Col2Im() {} + public Col2Im(@NonNull INDArray in, @NonNull Conv2DConfig conv2DConfig) { + super("col2Im",in,null,null,null); + this.conv2DConfig = conv2DConfig; + 
} + + + protected void addArgs() { addIArgument(conv2DConfig.getSH()); addIArgument(conv2DConfig.getSW()); diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/layers/convolution/Conv1D.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/layers/convolution/Conv1D.java index 486023331..819d1d10c 100644 --- a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/layers/convolution/Conv1D.java +++ b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/layers/convolution/Conv1D.java @@ -64,6 +64,14 @@ public class Conv1D extends DynamicCustomOp { this(wrapFilterNull(input, weights, bias), wrapOrNull(output), config); } + public Conv1D( @NonNull INDArray input, @NonNull INDArray weights, INDArray bias, Conv1DConfig conv1DConfig) { + this(wrapFilterNull(input, weights, bias), null, conv1DConfig); + } + + public Conv1D(@NonNull INDArray input, @NonNull INDArray weights, Conv1DConfig conv1DConfig) { + this(new INDArray[]{input, weights}, null, conv1DConfig); + } + private void initConfig(Conv1DConfig config){ this.config = config; Preconditions.checkState(config.getS() >= 1 && config.getP() >= 0, INVALID_CONFIGURATION, config.getS(), config.getP()); diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/layers/convolution/Conv2D.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/layers/convolution/Conv2D.java index 3794469ae..60bdfbfcc 100644 --- a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/layers/convolution/Conv2D.java +++ b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/layers/convolution/Conv2D.java @@ -75,6 +75,14 @@ public class Conv2D extends DynamicCustomOp { this(wrapFilterNull(input, weights, bias), wrapOrNull(output), config); } + public 
Conv2D(@NonNull INDArray layerInput, @NonNull INDArray weights, @NonNull Conv2DConfig conv2DConfig) { + this(new INDArray[]{layerInput, weights}, null, conv2DConfig); + } + + public Conv2D(@NonNull INDArray layerInput, @NonNull INDArray weights, INDArray bias, @NonNull Conv2DConfig conv2DConfig) { + this(wrapFilterNull(layerInput, weights,bias), null, conv2DConfig); + } + protected void initConfig(Conv2DConfig config){ this.config = config; diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/layers/convolution/Conv3D.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/layers/convolution/Conv3D.java index 0eb753932..94fb897b0 100644 --- a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/layers/convolution/Conv3D.java +++ b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/layers/convolution/Conv3D.java @@ -70,6 +70,14 @@ public class Conv3D extends DynamicCustomOp { this(wrapFilterNull(input, weights, bias), wrapOrNull(output), config); } + public Conv3D(@NonNull INDArray input,@NonNull INDArray weights, @NonNull Conv3DConfig conv3DConfig) { + this(new INDArray[]{input, weights}, null, conv3DConfig); + } + + public Conv3D(@NonNull INDArray input, @NonNull INDArray weights, INDArray bias, @NonNull Conv3DConfig conv3DConfig) { + this(wrapFilterNull(input, weights, bias) , null, conv3DConfig); + } + private void initConfig(Conv3DConfig config){ this.config = config; Preconditions.checkState(config.getSW() >= 1 && config.getPH() >= 0 && config.getDW() >= 1, diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/layers/convolution/DeConv2D.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/layers/convolution/DeConv2D.java index 34c6d2e94..f3500bec0 100644 --- 
a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/layers/convolution/DeConv2D.java +++ b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/layers/convolution/DeConv2D.java @@ -73,6 +73,15 @@ public class DeConv2D extends DynamicCustomOp { this(wrapFilterNull(input, weights, bias), wrapOrNull(output), config); } + public DeConv2D(@NonNull INDArray layerInput, @NonNull INDArray weights, DeConv2DConfig deConv2DConfig) { + this(wrapFilterNull(layerInput, weights), null, deConv2DConfig); + } + + public DeConv2D(INDArray layerInput, INDArray weights, INDArray bias, DeConv2DConfig deConv2DConfig) { + this(wrapFilterNull(layerInput, weights, bias), null, deConv2DConfig); + } + + @Override public long[] iArgs() { if (iArguments.size() == 0) diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/layers/convolution/DeConv3D.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/layers/convolution/DeConv3D.java index c77b12481..a4652850c 100644 --- a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/layers/convolution/DeConv3D.java +++ b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/layers/convolution/DeConv3D.java @@ -48,7 +48,7 @@ public class DeConv3D extends DynamicCustomOp { protected DeConv3DConfig config; - public DeConv3D(SameDiff sameDiff, @NonNull SDVariable input, @NonNull SDVariable weights, SDVariable bias, @NonNull DeConv3DConfig config) { + public DeConv3D(@NonNull SameDiff sameDiff, @NonNull SDVariable input, @NonNull SDVariable weights, SDVariable bias, @NonNull DeConv3DConfig config) { super(sameDiff, toArr(input, weights, bias)); this.config = config; addArgs(); @@ -65,6 +65,14 @@ public class DeConv3D extends DynamicCustomOp { this(wrapFilterNull(input, weights, bias), wrapOrNull(output), config); } + public 
DeConv3D(@NonNull INDArray input, @NonNull INDArray weights, @NonNull DeConv3DConfig deConv3DConfig) { + this(new INDArray[]{input, weights}, null, deConv3DConfig); + } + + public DeConv3D(@NonNull INDArray input, @NonNull INDArray weights, INDArray bias, @NonNull DeConv3DConfig deConv3DConfig) { + this(wrapFilterNull(input, weights, bias), null, deConv3DConfig); + } + private static SDVariable[] toArr(SDVariable input, SDVariable weights, SDVariable bias){ if(bias != null){ return new SDVariable[]{input, weights, bias}; diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/layers/convolution/DepthToSpace.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/layers/convolution/DepthToSpace.java index 704f8bdd4..3becef510 100644 --- a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/layers/convolution/DepthToSpace.java +++ b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/layers/convolution/DepthToSpace.java @@ -16,6 +16,7 @@ package org.nd4j.linalg.api.ops.impl.layers.convolution; +import lombok.NonNull; import lombok.val; import org.nd4j.autodiff.samediff.SDVariable; import org.nd4j.autodiff.samediff.SameDiff; @@ -24,6 +25,7 @@ import org.nd4j.imports.graphmapper.tf.TFGraphMapper; import org.nd4j.linalg.api.buffer.DataType; import org.nd4j.linalg.api.ndarray.INDArray; import org.nd4j.linalg.api.ops.DynamicCustomOp; +import org.nd4j.linalg.factory.enums.DataFormat; import org.tensorflow.framework.AttrValue; import org.tensorflow.framework.GraphDef; import org.tensorflow.framework.NodeDef; @@ -58,7 +60,7 @@ public class DepthToSpace extends DynamicCustomOp { addIArgument(blockSize, isNHWC ? 
1 : 0); } - public DepthToSpace(INDArray in, INDArray out, int blockSize, String dataFormat) { + public DepthToSpace(@NonNull INDArray in, INDArray out, int blockSize, @NonNull String dataFormat) { super(null, in, out, null, null); this.blockSize = blockSize; this.dataFormat = dataFormat; @@ -66,6 +68,10 @@ public class DepthToSpace extends DynamicCustomOp { addIArgument(blockSize, isNHWC ? 1 : 0); } + public DepthToSpace(@NonNull INDArray x, int blockSize, DataFormat dataFormat) { + this(x, null, blockSize, dataFormat.toString()); + } + @Override public List doDiff(List i_v) { diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/layers/convolution/DepthwiseConv2D.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/layers/convolution/DepthwiseConv2D.java index 4b10909a0..ab42c3c5a 100644 --- a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/layers/convolution/DepthwiseConv2D.java +++ b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/layers/convolution/DepthwiseConv2D.java @@ -16,11 +16,8 @@ package org.nd4j.linalg.api.ops.impl.layers.convolution; -import lombok.Builder; -import lombok.Getter; -import lombok.NonNull; +import lombok.*; import lombok.extern.slf4j.Slf4j; -import lombok.val; import onnx.Onnx; import org.nd4j.autodiff.samediff.SDVariable; import org.nd4j.autodiff.samediff.SameDiff; @@ -52,6 +49,7 @@ import java.util.*; */ @Slf4j @Getter +@NoArgsConstructor public class DepthwiseConv2D extends DynamicCustomOp { protected Conv2DConfig config; @@ -77,7 +75,16 @@ public class DepthwiseConv2D extends DynamicCustomOp { this(wrapFilterNull(input, weights, bias), wrapOrNull(output), config); } - public DepthwiseConv2D() { + public DepthwiseConv2D(INDArray layerInput, INDArray depthWeights, Conv2DConfig conv2DConfig) { + this(wrapFilterNull(layerInput, depthWeights), null, conv2DConfig); + } + + 
public DepthwiseConv2D(INDArray layerInput, INDArray depthWeights, INDArray bias, Conv2DConfig conv2DConfig) { + this(wrapFilterNull(layerInput, depthWeights, bias), null, conv2DConfig); + } + + public DepthwiseConv2D(INDArray inputs, Conv2DConfig conv2DConfig) { + this(wrapFilterNull(inputs), null, conv2DConfig); } @Override diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/layers/convolution/Im2col.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/layers/convolution/Im2col.java index 91bbf88ee..46f5f4e79 100644 --- a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/layers/convolution/Im2col.java +++ b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/layers/convolution/Im2col.java @@ -58,6 +58,13 @@ public class Im2col extends DynamicCustomOp { public Im2col() {} + public Im2col(INDArray in, Conv2DConfig conv2DConfig) { + super("im2Col",in,null,null,null); + this.conv2DConfig = conv2DConfig; + addArgs(); + } + + protected void addArgs() { addIArgument(conv2DConfig.getKH()); addIArgument(conv2DConfig.getKW()); @@ -68,7 +75,6 @@ public class Im2col extends DynamicCustomOp { addIArgument(conv2DConfig.getDH()); addIArgument(conv2DConfig.getDW()); addIArgument(ArrayUtil.fromBoolean(conv2DConfig.isSameMode())); - } diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/layers/convolution/LocalResponseNormalization.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/layers/convolution/LocalResponseNormalization.java index 8dfb7131a..cc5780a7c 100644 --- a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/layers/convolution/LocalResponseNormalization.java +++ 
b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/layers/convolution/LocalResponseNormalization.java @@ -65,6 +65,13 @@ public class LocalResponseNormalization extends DynamicCustomOp { addArgs(); } + public LocalResponseNormalization(@NonNull INDArray input, @NonNull LocalResponseNormalizationConfig LocalResponseNormalizationConfig){ + super(new INDArray[]{input}, null); + + this.config = config; + addArgs(); + } + @Override public Map propertiesForFunction() { diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/layers/convolution/MaxPooling2D.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/layers/convolution/MaxPooling2D.java index 844102b59..9f7c9bfb7 100644 --- a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/layers/convolution/MaxPooling2D.java +++ b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/layers/convolution/MaxPooling2D.java @@ -60,14 +60,17 @@ public class MaxPooling2D extends DynamicCustomOp { addArgs(); } - public MaxPooling2D(INDArray input, INDArray output, @NonNull Pooling2DConfig config){ + public MaxPooling2D(@NonNull INDArray input, INDArray output, @NonNull Pooling2DConfig config){ super(null, new INDArray[]{input}, wrapOrNull(output)); config.setType(Pooling2D.Pooling2DType.MAX); - this.config = config; addArgs(); } + public MaxPooling2D(@NonNull INDArray input, @NonNull Pooling2DConfig pooling2DConfig) { + this(input, null, pooling2DConfig); + } + @Override public boolean isConfigProperties() { return true; diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/layers/convolution/MaxPooling3D.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/layers/convolution/MaxPooling3D.java index 9ae930bf0..6c4ccaa9a 100644 --- 
a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/layers/convolution/MaxPooling3D.java +++ b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/layers/convolution/MaxPooling3D.java @@ -17,6 +17,7 @@ package org.nd4j.linalg.api.ops.impl.layers.convolution; import lombok.Getter; +import lombok.NoArgsConstructor; import lombok.extern.slf4j.Slf4j; import onnx.Onnx; import org.nd4j.autodiff.samediff.SDVariable; @@ -38,9 +39,9 @@ import java.util.Map; */ @Slf4j @Getter +@NoArgsConstructor public class MaxPooling3D extends Pooling3D { - public MaxPooling3D() { - } + public MaxPooling3D(SameDiff sameDiff, SDVariable input, Pooling3DConfig config) { super(sameDiff, new SDVariable[]{input}, null, null, false, config, Pooling3DType.MAX); @@ -50,6 +51,10 @@ public class MaxPooling3D extends Pooling3D { super(sameDiff, null, new INDArray[]{arrayInput}, wrapOrNull(arrayOutput), false, config, Pooling3DType.MAX); } + public MaxPooling3D(INDArray input, Pooling3DConfig pooling3DConfig) { + super(null, null, new INDArray[]{input},null, false, pooling3DConfig, Pooling3DType.MAX); + } + @Override public boolean isConfigProperties() { return true; diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/layers/convolution/Pooling3D.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/layers/convolution/Pooling3D.java index 98156596d..17305cdf6 100644 --- a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/layers/convolution/Pooling3D.java +++ b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/layers/convolution/Pooling3D.java @@ -16,6 +16,7 @@ package org.nd4j.linalg.api.ops.impl.layers.convolution; +import lombok.NoArgsConstructor; import lombok.extern.slf4j.Slf4j; import lombok.val; import org.nd4j.autodiff.samediff.SDVariable; @@ -37,6 +38,7 @@ import 
java.util.*; * Pooling3D operation */ @Slf4j +@NoArgsConstructor public abstract class Pooling3D extends DynamicCustomOp { protected Pooling3DConfig config; @@ -52,8 +54,6 @@ public abstract class Pooling3D extends DynamicCustomOp { return super.iArgs(); } - public Pooling3D() {} - public Pooling3D(SameDiff sameDiff, SDVariable[] inputs,INDArray[] inputArrays, INDArray[] outputs,boolean inPlace, Pooling3DConfig pooling3DConfig, Pooling3DType type) { super(null,sameDiff, inputs, inPlace); diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/layers/convolution/Pooling3DDerivative.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/layers/convolution/Pooling3DDerivative.java index e8e478af1..ad48f988c 100644 --- a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/layers/convolution/Pooling3DDerivative.java +++ b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/layers/convolution/Pooling3DDerivative.java @@ -17,6 +17,7 @@ package org.nd4j.linalg.api.ops.impl.layers.convolution; import lombok.Builder; +import lombok.NoArgsConstructor; import lombok.extern.slf4j.Slf4j; import org.nd4j.autodiff.samediff.SDVariable; import org.nd4j.autodiff.samediff.SameDiff; @@ -33,6 +34,7 @@ import java.util.List; * Pooling3DDerivative operation */ @Slf4j +@NoArgsConstructor public class Pooling3DDerivative extends Pooling3D { @Builder(builderMethodName = "derivativeBuilder") @@ -41,9 +43,6 @@ public class Pooling3DDerivative extends Pooling3D { super(sameDiff, inputs, inputArrays, outputs, inPlace, pooling3DConfig, type); } - public Pooling3DDerivative() {} - - @Override public String opName() { diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/layers/convolution/SConv2D.java 
b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/layers/convolution/SConv2D.java index d18470a7d..b28b9a987 100644 --- a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/layers/convolution/SConv2D.java +++ b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/layers/convolution/SConv2D.java @@ -48,12 +48,18 @@ public class SConv2D extends Conv2D { super(inputs, outputs, config); } - public SConv2D(@NonNull INDArray input, @NonNull INDArray depthWeights, INDArray pointWeights, INDArray bias, INDArray output, @NonNull Conv2DConfig config){ - this(wrapFilterNull(input, depthWeights, pointWeights, bias), wrapOrNull(output), config); + public SConv2D(@NonNull INDArray layerInput, @NonNull INDArray depthWeights, INDArray pointWeights, INDArray bias, @NonNull Conv2DConfig Conv2DConfig){ + this(wrapFilterNull(layerInput, depthWeights, pointWeights, bias), null, Conv2DConfig); + } + + public SConv2D(@NonNull INDArray layerInput, @NonNull INDArray depthWeights, INDArray pointWeights, @NonNull Conv2DConfig Conv2DConfig){ + this(wrapFilterNull(layerInput, depthWeights, pointWeights), null, Conv2DConfig); } public SConv2D() {} + + @Override public String opName() { return "sconv2d"; diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/layers/convolution/SpaceToDepth.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/layers/convolution/SpaceToDepth.java index e591c9f1c..5ae281ae2 100644 --- a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/layers/convolution/SpaceToDepth.java +++ b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/layers/convolution/SpaceToDepth.java @@ -16,6 +16,7 @@ package org.nd4j.linalg.api.ops.impl.layers.convolution; +import lombok.NonNull; import lombok.val; import 
org.nd4j.autodiff.samediff.SDVariable; import org.nd4j.autodiff.samediff.SameDiff; @@ -24,6 +25,7 @@ import org.nd4j.imports.graphmapper.tf.TFGraphMapper; import org.nd4j.linalg.api.buffer.DataType; import org.nd4j.linalg.api.ndarray.INDArray; import org.nd4j.linalg.api.ops.DynamicCustomOp; +import org.nd4j.linalg.factory.enums.DataFormat; import org.tensorflow.framework.AttrValue; import org.tensorflow.framework.GraphDef; import org.tensorflow.framework.NodeDef; @@ -65,6 +67,13 @@ public class SpaceToDepth extends DynamicCustomOp { addIArgument(blockSize, isNHWC ? 1 : 0); } + + + public SpaceToDepth(@NonNull INDArray x, int blockSize, @NonNull DataFormat dataFormat) { + this(x, null, blockSize,dataFormat.toString()); + } + + @Override public List doDiff(List i_v) { // Gradient to SpaceToDepth is just DepthToSpace of same block size and data format. diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/layers/convolution/Upsampling2d.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/layers/convolution/Upsampling2d.java index 21f3f4d5d..574682a36 100644 --- a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/layers/convolution/Upsampling2d.java +++ b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/layers/convolution/Upsampling2d.java @@ -17,12 +17,15 @@ package org.nd4j.linalg.api.ops.impl.layers.convolution; import lombok.Getter; +import lombok.NoArgsConstructor; +import lombok.NonNull; import lombok.extern.slf4j.Slf4j; import org.nd4j.autodiff.samediff.SDVariable; import org.nd4j.autodiff.samediff.SameDiff; import org.nd4j.base.Preconditions; import org.nd4j.imports.NoOpNameFoundException; import org.nd4j.linalg.api.buffer.DataType; +import org.nd4j.linalg.api.ndarray.INDArray; import org.nd4j.linalg.api.ops.DynamicCustomOp; import java.util.Collections; @@ -34,6 +37,7 @@ import java.util.List; */ 
@Slf4j @Getter +@NoArgsConstructor public class Upsampling2d extends DynamicCustomOp { @@ -53,7 +57,20 @@ public class Upsampling2d extends DynamicCustomOp { } - public Upsampling2d() {} + public Upsampling2d(INDArray input, int scale) { + this(input, scale, scale, true); + } + + public Upsampling2d(INDArray input, int scaleH, int scaleW, boolean nchw) { + super(new INDArray[]{input}, null); + this.nchw = nchw; + this.scaleH = scaleH; + this.scaleW = scaleW; + + addIArgument(scaleH); + addIArgument(scaleW); + addIArgument(nchw ? 1 : 0); + } @Override diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/layers/recurrent/GRUCell.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/layers/recurrent/GRUCell.java index cd0f83c89..d938cf8e9 100644 --- a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/layers/recurrent/GRUCell.java +++ b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/layers/recurrent/GRUCell.java @@ -21,6 +21,7 @@ import org.nd4j.autodiff.samediff.SDVariable; import org.nd4j.autodiff.samediff.SameDiff; import org.nd4j.base.Preconditions; import org.nd4j.linalg.api.buffer.DataType; +import org.nd4j.linalg.api.ndarray.INDArray; import org.nd4j.linalg.api.ops.DynamicCustomOp; import org.nd4j.linalg.api.ops.impl.layers.recurrent.weights.GRUWeights; @@ -46,6 +47,12 @@ public class GRUCell extends DynamicCustomOp { this.weights = weights; } + public GRUCell(INDArray x, INDArray hLast, GRUWeights gruWeights) { + super(null, null, gruWeights.argsWithInputs(x, hLast)); + this.weights = gruWeights; + } + + @Override public String opName() { return "gruCell"; diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/layers/recurrent/LSTMBlockCell.java 
b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/layers/recurrent/LSTMBlockCell.java index aa9ade658..bfc65ab89 100644 --- a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/layers/recurrent/LSTMBlockCell.java +++ b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/layers/recurrent/LSTMBlockCell.java @@ -21,6 +21,7 @@ import org.nd4j.autodiff.samediff.SDVariable; import org.nd4j.autodiff.samediff.SameDiff; import org.nd4j.base.Preconditions; import org.nd4j.linalg.api.buffer.DataType; +import org.nd4j.linalg.api.ndarray.INDArray; import org.nd4j.linalg.api.ops.DynamicCustomOp; import org.nd4j.linalg.api.ops.impl.layers.recurrent.config.LSTMConfiguration; import org.nd4j.linalg.api.ops.impl.layers.recurrent.weights.LSTMWeights; @@ -89,6 +90,15 @@ public class LSTMBlockCell extends DynamicCustomOp { addTArgument(configuration.tArgs()); } + public LSTMBlockCell(INDArray x, INDArray cLast, INDArray yLast, LSTMWeights lstmWeights, LSTMConfiguration lstmConfiguration) { + super(null, null, lstmWeights.argsWithInputs(x, cLast, yLast)); + this.configuration = lstmConfiguration; + this.weights = lstmWeights; + addIArgument(configuration.iArgs(false)); + addTArgument(configuration.tArgs()); + } + + @Override public List calculateOutputDataTypes(List inputDataTypes) { Preconditions.checkState(inputDataTypes != null && inputDataTypes.size() == 8, "Expected exactly 8 inputs to LSTMBlockCell, got %s", inputDataTypes); diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/layers/recurrent/LSTMCell.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/layers/recurrent/LSTMCell.java index e9d2ffd3b..5cb0700dc 100644 --- a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/layers/recurrent/LSTMCell.java +++ 
b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/layers/recurrent/LSTMCell.java @@ -16,6 +16,7 @@ package org.nd4j.linalg.api.ops.impl.layers.recurrent; +import lombok.NoArgsConstructor; import onnx.Onnx; import org.nd4j.autodiff.samediff.SameDiff; import org.nd4j.linalg.api.ops.DynamicCustomOp; @@ -31,13 +32,11 @@ import java.util.Map; * * @author Adam Gibson */ +@NoArgsConstructor public class LSTMCell extends DynamicCustomOp { private LSTMCellConfiguration configuration; - public LSTMCell() { - } - public LSTMCell(SameDiff sameDiff, LSTMCellConfiguration configuration) { super(null, sameDiff, configuration.args()); this.configuration = configuration; @@ -66,16 +65,4 @@ public class LSTMCell extends DynamicCustomOp { public String tensorflowName() { return super.tensorflowName(); } - - @Override - public void initFromTensorFlow(NodeDef nodeDef, SameDiff initWith, Map attributesForNode, GraphDef graph) { - super.initFromTensorFlow(nodeDef, initWith, attributesForNode, graph); - } - - @Override - public void initFromOnnx(Onnx.NodeProto node, SameDiff initWith, Map attributesForNode, Onnx.GraphProto graph) { - - } - - } diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/layers/recurrent/LSTMLayer.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/layers/recurrent/LSTMLayer.java index 1e1ae3c47..59b85f500 100644 --- a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/layers/recurrent/LSTMLayer.java +++ b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/layers/recurrent/LSTMLayer.java @@ -22,6 +22,7 @@ import org.nd4j.autodiff.samediff.SDVariable; import org.nd4j.autodiff.samediff.SameDiff; import org.nd4j.base.Preconditions; import org.nd4j.linalg.api.buffer.DataType; +import org.nd4j.linalg.api.ndarray.INDArray; import org.nd4j.linalg.api.ops.DynamicCustomOp; 
import org.nd4j.linalg.api.ops.impl.layers.recurrent.config.LSTMConfiguration; import org.nd4j.linalg.api.ops.impl.layers.recurrent.config.RnnDataFormat; @@ -91,6 +92,14 @@ public class LSTMLayer extends DynamicCustomOp { addTArgument(configuration.tArgs()); } + public LSTMLayer(INDArray x, INDArray cLast, INDArray yLast, INDArray maxTSLength, LSTMWeights lstmWeights, LSTMConfiguration lstmConfiguration) { + super(null, null, lstmWeights.argsWithInputs(maxTSLength, x, cLast, yLast)); + this.configuration = lstmConfiguration; + this.weights = lstmWeights; + addIArgument(configuration.iArgs(true)); + addTArgument(configuration.tArgs()); + } + @Override public List calculateOutputDataTypes(List inputDataTypes) { Preconditions.checkState(inputDataTypes != null && inputDataTypes.size() == 9, "Expected exactly 9 inputs to LSTMLayer, got %s", inputDataTypes); diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/layers/recurrent/SRU.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/layers/recurrent/SRU.java index a2de2beb8..dbeed80db 100644 --- a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/layers/recurrent/SRU.java +++ b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/layers/recurrent/SRU.java @@ -16,14 +16,14 @@ package org.nd4j.linalg.api.ops.impl.layers.recurrent; -import java.util.Arrays; -import java.util.List; import lombok.Getter; +import lombok.NoArgsConstructor; import lombok.NonNull; import onnx.Onnx; import org.nd4j.autodiff.samediff.SDVariable; import org.nd4j.autodiff.samediff.SameDiff; import org.nd4j.imports.NoOpNameFoundException; +import org.nd4j.linalg.api.ndarray.INDArray; import org.nd4j.linalg.api.ops.DynamicCustomOp; import org.nd4j.linalg.api.ops.impl.layers.recurrent.weights.SRUWeights; import org.tensorflow.framework.AttrValue; @@ -37,6 +37,7 @@ import java.util.Map; * * @author 
Adam Gibson */ +@NoArgsConstructor public class SRU extends DynamicCustomOp { @Getter @@ -45,14 +46,23 @@ public class SRU extends DynamicCustomOp { @Getter private SDVariable mask; - public SRU() { } - public SRU(@NonNull SameDiff sameDiff, @NonNull SDVariable x, @NonNull SDVariable initialC, SDVariable mask, @NonNull SRUWeights weights) { super(null, sameDiff, wrapFilterNull(x, weights.getWeights(), weights.getBias(), initialC, mask)); this.mask = mask; this.weights = weights; } + public SRU(INDArray x, INDArray initialC, INDArray mask, SRUWeights sruWeights) { + super(wrapFilterNull(x, sruWeights.getIWeights(), sruWeights.getIBias(), initialC, mask), null); + this.mask = (SDVariable) mask; + this.weights = sruWeights; + } + + public SRU(INDArray x, INDArray initialC, SRUWeights sruWeights) { + super(wrapFilterNull(x, sruWeights.getIWeights(), sruWeights.getIBias(), initialC), null); + this.weights = sruWeights; + } + @Override public String opName() { return "sru"; diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/layers/recurrent/SRUCell.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/layers/recurrent/SRUCell.java index ac3f6c07f..36a40348a 100644 --- a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/layers/recurrent/SRUCell.java +++ b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/layers/recurrent/SRUCell.java @@ -22,6 +22,7 @@ import onnx.Onnx; import org.nd4j.autodiff.samediff.SDVariable; import org.nd4j.autodiff.samediff.SameDiff; import org.nd4j.imports.NoOpNameFoundException; +import org.nd4j.linalg.api.ndarray.INDArray; import org.nd4j.linalg.api.ops.DynamicCustomOp; import org.nd4j.linalg.api.ops.impl.layers.recurrent.weights.SRUWeights; import org.tensorflow.framework.AttrValue; @@ -46,6 +47,14 @@ public class SRUCell extends DynamicCustomOp { this.weights = weights; } + + + 
public SRUCell(INDArray x, INDArray cLast, SRUWeights sruWeights) { + super(null, null, sruWeights.argsWithInputs(x, cLast)); + this.weights = sruWeights; + } + + @Override public String opName() { return "sruCell"; diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/layers/recurrent/weights/GRUWeights.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/layers/recurrent/weights/GRUWeights.java index f95438ae3..52a3c166f 100644 --- a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/layers/recurrent/weights/GRUWeights.java +++ b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/layers/recurrent/weights/GRUWeights.java @@ -5,6 +5,7 @@ import lombok.Data; import lombok.EqualsAndHashCode; import lombok.NonNull; import org.nd4j.autodiff.samediff.SDVariable; +import org.nd4j.linalg.api.ndarray.INDArray; import org.nd4j.linalg.api.ops.impl.layers.recurrent.GRUCell; /** @@ -21,31 +22,36 @@ public class GRUWeights extends RNNWeights { * * The reset weights are the [:, 0:numUnits] subset and the update weights are the [:, numUnits:2*numUnits] subset. */ - @NonNull private SDVariable ruWeight; + private INDArray iRuWeights; /** * Cell gate weights, with a shape of [inSize + numUnits, numUnits] */ - @NonNull private SDVariable cWeight; + private INDArray iCWeight; /** * Reset and Update gate bias, with a shape of [2*numUnits]. May be null. * * The reset bias is the [0:numUnits] subset and the update bias is the [numUnits:2*numUnits] subset. */ - @NonNull private SDVariable ruBias; + private INDArray iRUBias; /** * Cell gate bias, with a shape of [numUnits]. May be null. 
*/ - @NonNull private SDVariable cBias; + private INDArray iCBias; @Override public SDVariable[] args() { return filterNonNull(ruWeight, cWeight, ruBias, cBias); } + + @Override + public INDArray[] arrayArgs() { + return filterNonNull(iRuWeights, iCWeight, iRUBias, iCBias); + } } diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/layers/recurrent/weights/LSTMWeights.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/layers/recurrent/weights/LSTMWeights.java index bf401d66c..f2309b9e4 100644 --- a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/layers/recurrent/weights/LSTMWeights.java +++ b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/layers/recurrent/weights/LSTMWeights.java @@ -5,6 +5,7 @@ import lombok.Data; import lombok.EqualsAndHashCode; import lombok.NonNull; import org.nd4j.autodiff.samediff.SDVariable; +import org.nd4j.linalg.api.ndarray.INDArray; import org.nd4j.linalg.api.ops.impl.layers.recurrent.LSTMBlockCell; import org.nd4j.linalg.api.ops.impl.layers.recurrent.LSTMLayer; @@ -23,35 +24,40 @@ public class LSTMWeights extends RNNWeights { * Input to hidden and hidden to hidden are concatenated in dimension 0, * so the input to hidden weights are [:inSize, :] and the hidden to hidden weights are [inSize:, :]. */ - @NonNull private SDVariable weights; + private INDArray iWeights; /** * Cell peephole (t-1) connections to input modulation gate, with a shape of [numUnits]. */ - @NonNull private SDVariable inputPeepholeWeights; + private INDArray iInputPeepholeWeights; /** * Cell peephole (t-1) connections to forget gate, with a shape of [numUnits]. */ - @NonNull private SDVariable forgetPeepholeWeights; + private INDArray iForgetPeepholeWeights; /** * Cell peephole (t) connections to output gate, with a shape of [numUnits]. 
*/ - @NonNull private SDVariable outputPeepholeWeights; + private INDArray iOutputPeepholeWeights; /** * Input to hidden and hidden to hidden biases, with shape [1, 4*numUnits]. */ - @NonNull private SDVariable bias; + private INDArray iBias; @Override public SDVariable[] args() { return filterNonNull(weights, inputPeepholeWeights, forgetPeepholeWeights, outputPeepholeWeights, bias); } + + @Override + public INDArray[] arrayArgs() { + return filterNonNull(iWeights, iInputPeepholeWeights, iForgetPeepholeWeights, iOutputPeepholeWeights, iBias); + } } diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/layers/recurrent/weights/RNNWeights.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/layers/recurrent/weights/RNNWeights.java index 62e295d80..e0b8a5c27 100644 --- a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/layers/recurrent/weights/RNNWeights.java +++ b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/layers/recurrent/weights/RNNWeights.java @@ -1,35 +1,38 @@ package org.nd4j.linalg.api.ops.impl.layers.recurrent.weights; +import java.lang.reflect.Array; import java.util.Arrays; import org.nd4j.autodiff.samediff.SDVariable; +import org.nd4j.linalg.api.ndarray.INDArray; import org.nd4j.linalg.util.ArrayUtil; public abstract class RNNWeights { public abstract SDVariable[] args(); - protected static SDVariable[] filterNonNull(SDVariable... args){ + public abstract INDArray[] arrayArgs(); + + protected static T[] filterNonNull(T... 
args){ int count = 0; - for(SDVariable v : args){ - if(v != null){ - count++; + for( int i=0; i + * + * @param input the input to average pooling 2d operation - 4d CNN (image) activations in NCHW format + * (shape [minibatch, channels, height, width]) or NHWC format (shape [minibatch, height, width, channels]) (NUMERIC type) + * @param Pooling2DConfig Configuration Object + * @return output Result after applying average pooling on the input (NUMERIC type) + */ + public INDArray avgPooling2d(INDArray input, Pooling2DConfig Pooling2DConfig) { + NDValidation.validateNumerical("avgPooling2d", "input", input); + return Nd4j.exec(new org.nd4j.linalg.api.ops.impl.layers.convolution.AvgPooling2D(input, Pooling2DConfig))[0]; + } + + /** + * 3D convolution layer operation - average pooling 3d
+ * + * @param input the input to average pooling 3d operation - 5d activations in NCDHW format + * (shape [minibatch, channels, depth, height, width]) or NDHWC format + * (shape [minibatch, depth, height, width, channels]) (NUMERIC type) + * @param Pooling3DConfig Configuration Object + * @return output after applying average pooling on the input (NUMERIC type) + */ + public INDArray avgPooling3d(INDArray input, Pooling3DConfig Pooling3DConfig) { + NDValidation.validateNumerical("avgPooling3d", "input", input); + return Nd4j.exec(new org.nd4j.linalg.api.ops.impl.layers.convolution.AvgPooling3D(input, Pooling3DConfig))[0]; + } + + /** + * Convolution 2d layer batch to space operation on 4d input.
+ * Reduces input batch dimension by rearranging data into a larger spatial dimensions
+ * + * @param x Input variable. 4d input (NUMERIC type) + * @param blocks Block size, in the height/width dimension (Size: Exactly(count=2)) + * @param croppingTop (Size: Exactly(count=2)) + * @param croppingBottom (Size: Exactly(count=2)) + * @return output Output variable (NUMERIC type) + */ + public INDArray batchToSpace(INDArray x, int[] blocks, int[] croppingTop, int... croppingBottom) { + NDValidation.validateNumerical("batchToSpace", "x", x); + Preconditions.checkArgument(blocks.length == 2, "blocks has incorrect size/length. Expected: blocks.length == 2, got %s", blocks.length); + Preconditions.checkArgument(croppingTop.length == 2, "croppingTop has incorrect size/length. Expected: croppingTop.length == 2, got %s", croppingTop.length); + Preconditions.checkArgument(croppingBottom.length == 2, "croppingBottom has incorrect size/length. Expected: croppingBottom.length == 2, got %s", croppingBottom.length); + return Nd4j.exec(new org.nd4j.linalg.api.ops.impl.transforms.custom.BatchToSpace(x, blocks, croppingTop, croppingBottom))[0]; + } + + /** + * col2im operation for use in 2D convolution operations. Outputs a 4d array with shape
+ * [minibatch, inputChannels, height, width]
+ * + * @param in Input - rank 6 input with shape [minibatch, inputChannels, kernelHeight, kernelWidth, outputHeight, outputWidth] (NUMERIC type) + * @param Conv2DConfig Configuration Object + * @return output Col2Im output variable (NUMERIC type) + */ + public INDArray col2Im(INDArray in, Conv2DConfig Conv2DConfig) { + NDValidation.validateNumerical("col2Im", "in", in); + return Nd4j.exec(new org.nd4j.linalg.api.ops.impl.layers.convolution.Col2Im(in, Conv2DConfig))[0]; + } + + /** + * Conv1d operation.
+ * + * @param input the inputs to conv1d (NUMERIC type) + * @param weights weights for conv1d op - rank 3 array with shape [kernelSize, inputChannels, outputChannels] (NUMERIC type) + * @param bias bias for conv1d op - rank 1 array with shape [outputChannels]. May be null. (NUMERIC type) + * @param Conv1DConfig Configuration Object + * @return output result of conv1d op (NUMERIC type) + */ + public INDArray conv1d(INDArray input, INDArray weights, INDArray bias, + Conv1DConfig Conv1DConfig) { + NDValidation.validateNumerical("conv1d", "input", input); + NDValidation.validateNumerical("conv1d", "weights", weights); + NDValidation.validateNumerical("conv1d", "bias", bias); + return Nd4j.exec(new org.nd4j.linalg.api.ops.impl.layers.convolution.Conv1D(input, weights, bias, Conv1DConfig))[0]; + } + + /** + * Conv1d operation.
+ * + * @param input the inputs to conv1d (NUMERIC type) + * @param weights weights for conv1d op - rank 3 array with shape [kernelSize, inputChannels, outputChannels] (NUMERIC type) + * @param Conv1DConfig Configuration Object + * @return output result of conv1d op (NUMERIC type) + */ + public INDArray conv1d(INDArray input, INDArray weights, Conv1DConfig Conv1DConfig) { + NDValidation.validateNumerical("conv1d", "input", input); + NDValidation.validateNumerical("conv1d", "weights", weights); + return Nd4j.exec(new org.nd4j.linalg.api.ops.impl.layers.convolution.Conv1D(input, weights, null, Conv1DConfig))[0]; + } + + /** + * 2D Convolution operation with optional bias
+ * + * @param layerInput the input to max pooling 2d operation - 4d CNN (image) activations in NCHW format (NUMERIC type) + * @param weights Weights for the convolution operation. 4 dimensions with format [kernelHeight, kernelWidth, inputChannels, outputChannels] (NUMERIC type) + * @param bias Optional 1D bias array with shape [outputChannels]. May be null. (NUMERIC type) + * @param Conv2DConfig Configuration Object + * @return output result of conv2d op (NUMERIC type) + */ + public INDArray conv2d(INDArray layerInput, INDArray weights, INDArray bias, + Conv2DConfig Conv2DConfig) { + NDValidation.validateNumerical("conv2d", "layerInput", layerInput); + NDValidation.validateNumerical("conv2d", "weights", weights); + NDValidation.validateNumerical("conv2d", "bias", bias); + return Nd4j.exec(new org.nd4j.linalg.api.ops.impl.layers.convolution.Conv2D(layerInput, weights, bias, Conv2DConfig))[0]; + } + + /** + * 2D Convolution operation with optional bias
+ * + * @param layerInput the input to max pooling 2d operation - 4d CNN (image) activations in NCHW format (NUMERIC type) + * @param weights Weights for the convolution operation. 4 dimensions with format [kernelHeight, kernelWidth, inputChannels, outputChannels] (NUMERIC type) + * @param Conv2DConfig Configuration Object + * @return output result of conv2d op (NUMERIC type) + */ + public INDArray conv2d(INDArray layerInput, INDArray weights, Conv2DConfig Conv2DConfig) { + NDValidation.validateNumerical("conv2d", "layerInput", layerInput); + NDValidation.validateNumerical("conv2d", "weights", weights); + return Nd4j.exec(new org.nd4j.linalg.api.ops.impl.layers.convolution.Conv2D(layerInput, weights, null, Conv2DConfig))[0]; + } + + /** + * Convolution 3D operation with optional bias
+ * + * @param input the input to average pooling 3d operation - 5d activations in NCDHW format + * (shape [minibatch, channels, depth, height, width]) or NDHWC format + * (shape [minibatch, depth, height, width, channels]) (NUMERIC type) + * @param weights Weights for conv3d. Rank 5 with shape [kernelDepth, kernelHeight, kernelWidth, inputChannels, outputChannels]. (NUMERIC type) + * @param bias Optional 1D bias array with shape [outputChannels]. May be null. (NUMERIC type) + * @param Conv3DConfig Configuration Object + * @return output Conv3d output variable (NUMERIC type) + */ + public INDArray conv3d(INDArray input, INDArray weights, INDArray bias, + Conv3DConfig Conv3DConfig) { + NDValidation.validateNumerical("conv3d", "input", input); + NDValidation.validateNumerical("conv3d", "weights", weights); + NDValidation.validateNumerical("conv3d", "bias", bias); + return Nd4j.exec(new org.nd4j.linalg.api.ops.impl.layers.convolution.Conv3D(input, weights, bias, Conv3DConfig))[0]; + } + + /** + * Convolution 3D operation with optional bias
+ * + * @param input the input to average pooling 3d operation - 5d activations in NCDHW format + * (shape [minibatch, channels, depth, height, width]) or NDHWC format + * (shape [minibatch, depth, height, width, channels]) (NUMERIC type) + * @param weights Weights for conv3d. Rank 5 with shape [kernelDepth, kernelHeight, kernelWidth, inputChannels, outputChannels]. (NUMERIC type) + * @param Conv3DConfig Configuration Object + * @return output Conv3d output variable (NUMERIC type) + */ + public INDArray conv3d(INDArray input, INDArray weights, Conv3DConfig Conv3DConfig) { + NDValidation.validateNumerical("conv3d", "input", input); + NDValidation.validateNumerical("conv3d", "weights", weights); + return Nd4j.exec(new org.nd4j.linalg.api.ops.impl.layers.convolution.Conv3D(input, weights, null, Conv3DConfig))[0]; + } + + /** + * 2D deconvolution operation with optional bias
+ * + * @param layerInput the input to deconvolution 2d operation - 4d CNN (image) activations in NCHW format + * (shape [minibatch, channels, height, width]) or NHWC format (shape [minibatch, height, width, channels]) (NUMERIC type) + * @param weights Weights for the 2d deconvolution operation. 4 dimensions with format [inputChannels, outputChannels, kernelHeight, kernelWidth] (NUMERIC type) + * @param bias Optional 1D bias array with shape [outputChannels]. May be null. (NUMERIC type) + * @param DeConv2DConfig Configuration Object + * @return output result of deconv2d op (NUMERIC type) + */ + public INDArray deconv2d(INDArray layerInput, INDArray weights, INDArray bias, + DeConv2DConfig DeConv2DConfig) { + NDValidation.validateNumerical("deconv2d", "layerInput", layerInput); + NDValidation.validateNumerical("deconv2d", "weights", weights); + NDValidation.validateNumerical("deconv2d", "bias", bias); + return Nd4j.exec(new org.nd4j.linalg.api.ops.impl.layers.convolution.DeConv2D(layerInput, weights, bias, DeConv2DConfig))[0]; + } + + /** + * 2D deconvolution operation with optional bias
+ * + * @param layerInput the input to deconvolution 2d operation - 4d CNN (image) activations in NCHW format + * (shape [minibatch, channels, height, width]) or NHWC format (shape [minibatch, height, width, channels]) (NUMERIC type) + * @param weights Weights for the 2d deconvolution operation. 4 dimensions with format [inputChannels, outputChannels, kernelHeight, kernelWidth] (NUMERIC type) + * @param DeConv2DConfig Configuration Object + * @return output result of deconv2d op (NUMERIC type) + */ + public INDArray deconv2d(INDArray layerInput, INDArray weights, DeConv2DConfig DeConv2DConfig) { + NDValidation.validateNumerical("deconv2d", "layerInput", layerInput); + NDValidation.validateNumerical("deconv2d", "weights", weights); + return Nd4j.exec(new org.nd4j.linalg.api.ops.impl.layers.convolution.DeConv2D(layerInput, weights, null, DeConv2DConfig))[0]; + } + + /** + * 3D CNN deconvolution operation with or without optional bias
+ * + * @param input Input array - shape [bS, iD, iH, iW, iC] (NDHWC) or [bS, iC, iD, iH, iW] (NCDHW) (NUMERIC type) + * @param weights Weights array - shape [kD, kH, kW, oC, iC] (NUMERIC type) + * @param bias Bias array - optional, may be null. If non-null, must have shape [outputChannels] (NUMERIC type) + * @param DeConv3DConfig Configuration Object + * @return output result of 3D CNN deconvolution operation (NUMERIC type) + */ + public INDArray deconv3d(INDArray input, INDArray weights, INDArray bias, + DeConv3DConfig DeConv3DConfig) { + NDValidation.validateNumerical("deconv3d", "input", input); + NDValidation.validateNumerical("deconv3d", "weights", weights); + NDValidation.validateNumerical("deconv3d", "bias", bias); + return Nd4j.exec(new org.nd4j.linalg.api.ops.impl.layers.convolution.DeConv3D(input, weights, bias, DeConv3DConfig))[0]; + } + + /** + * 3D CNN deconvolution operation with or without optional bias
+ * + * @param input Input array - shape [bS, iD, iH, iW, iC] (NDHWC) or [bS, iC, iD, iH, iW] (NCDHW) (NUMERIC type) + * @param weights Weights array - shape [kD, kH, kW, oC, iC] (NUMERIC type) + * @param DeConv3DConfig Configuration Object + * @return output result of 3D CNN deconvolution operation (NUMERIC type) + */ + public INDArray deconv3d(INDArray input, INDArray weights, DeConv3DConfig DeConv3DConfig) { + NDValidation.validateNumerical("deconv3d", "input", input); + NDValidation.validateNumerical("deconv3d", "weights", weights); + return Nd4j.exec(new org.nd4j.linalg.api.ops.impl.layers.convolution.DeConv3D(input, weights, null, DeConv3DConfig))[0]; + } + + /** + * Convolution 2d layer batch to space operation on 4d input.
+ * Reduces input channels dimension by rearranging data into a larger spatial dimensions
+ * Example: if input has shape [mb, 8, 2, 2] and block size is 2, then output size is [mb, 8/(2*2), 2*2, 2*2]
+ * = [mb, 2, 4, 4]
+ * + * @param x the input to depth to space pooling 2d operation - 4d activations in NCHW format + * (shape [minibatch, channels, height, width]) or NHWC format (shape [minibatch, height, width, channels]) (NUMERIC type) + * @param blockSize Block size, in the height/width dimension + * @param dataFormat Data format: "NCHW" or "NHWC" + * @return output Output variable (NUMERIC type) + */ + public INDArray depthToSpace(INDArray x, int blockSize, DataFormat dataFormat) { + NDValidation.validateNumerical("depthToSpace", "x", x); + return Nd4j.exec(new org.nd4j.linalg.api.ops.impl.layers.convolution.DepthToSpace(x, blockSize, dataFormat))[0]; + } + + /** + * Depth-wise 2D convolution operation with optional bias
+ * + * @param layerInput the input to max pooling 2d operation - 4d CNN (image) activations in NCHW format (NUMERIC type) + * @param depthWeights Depth-wise conv2d weights. 4 dimensions with format [kernelHeight, kernelWidth, inputChannels, depthMultiplier] (NUMERIC type) + * @param bias Optional 1D bias array with shape [outputChannels]. May be null. (NUMERIC type) + * @param Conv2DConfig Configuration Object + * @return output result of depthwise conv2d op (NUMERIC type) + */ + public INDArray depthWiseConv2d(INDArray layerInput, INDArray depthWeights, INDArray bias, + Conv2DConfig Conv2DConfig) { + NDValidation.validateNumerical("depthWiseConv2d", "layerInput", layerInput); + NDValidation.validateNumerical("depthWiseConv2d", "depthWeights", depthWeights); + NDValidation.validateNumerical("depthWiseConv2d", "bias", bias); + return Nd4j.exec(new org.nd4j.linalg.api.ops.impl.layers.convolution.DepthwiseConv2D(layerInput, depthWeights, bias, Conv2DConfig))[0]; + } + + /** + * Depth-wise 2D convolution operation with optional bias
+ * + * @param layerInput the input to max pooling 2d operation - 4d CNN (image) activations in NCHW format (NUMERIC type) + * @param depthWeights Depth-wise conv2d weights. 4 dimensions with format [kernelHeight, kernelWidth, inputChannels, depthMultiplier] (NUMERIC type) + * @param Conv2DConfig Configuration Object + * @return output result of depthwise conv2d op (NUMERIC type) + */ + public INDArray depthWiseConv2d(INDArray layerInput, INDArray depthWeights, + Conv2DConfig Conv2DConfig) { + NDValidation.validateNumerical("depthWiseConv2d", "layerInput", layerInput); + NDValidation.validateNumerical("depthWiseConv2d", "depthWeights", depthWeights); + return Nd4j.exec(new org.nd4j.linalg.api.ops.impl.layers.convolution.DepthwiseConv2D(layerInput, depthWeights, null, Conv2DConfig))[0]; + } + + /** + * TODO doc string
+ * + * @param df (NUMERIC type) + * @param weights df (NUMERIC type) + * @param strides weights (Size: Exactly(count=2)) + * @param rates strides (Size: Exactly(count=2)) + * @param isSameMode isSameMode + * @return output Computed the grayscale dilation of 4-D input and 3-D filters tensors. (NUMERIC type) + */ + public INDArray dilation2D(INDArray df, INDArray weights, int[] strides, int[] rates, + boolean isSameMode) { + NDValidation.validateNumerical("dilation2D", "df", df); + NDValidation.validateNumerical("dilation2D", "weights", weights); + Preconditions.checkArgument(strides.length == 2, "strides has incorrect size/length. Expected: strides.length == 2, got %s", strides.length); + Preconditions.checkArgument(rates.length == 2, "rates has incorrect size/length. Expected: rates.length == 2, got %s", rates.length); + return Nd4j.exec(new org.nd4j.linalg.api.ops.impl.transforms.custom.Dilation2D(df, weights, strides, rates, isSameMode))[0]; + } + + /** + * Extract image patches
+ * + * @param input Input array. Must be rank 4, with shape [minibatch, height, width, channels] (NUMERIC type) + * @param kH Kernel height + * @param kW Kernel width + * @param sH Stride height + * @param sW Stride width + * @param rH Rate height + * @param rW Rate width + * @param sameMode If true: use same mode padding. If false + * @return output The result is a 4D tensor which is indexed by batch, row, and column. (NUMERIC type) + */ + public INDArray extractImagePatches(INDArray input, int kH, int kW, int sH, int sW, int rH, + int rW, boolean sameMode) { + NDValidation.validateNumerical("extractImagePatches", "input", input); + return Nd4j.exec(new org.nd4j.linalg.api.ops.impl.image.ExtractImagePatches(input, kH, kW, sH, sW, rH, rW, sameMode))[0]; + } + + /** + * im2col operation for use in 2D convolution operations. Outputs a 6d array with shape
+ * [minibatch, inputChannels, kernelHeight, kernelWidth, outputHeight, outputWidth]
+ * + * @param in Input - rank 4 input with shape [minibatch, inputChannels, height, width] (NUMERIC type) + * @param Conv2DConfig Configuration Object + * @return output Im2Col output variable (NUMERIC type) + */ + public INDArray im2Col(INDArray in, Conv2DConfig Conv2DConfig) { + NDValidation.validateNumerical("im2Col", "in", in); + return Nd4j.exec(new org.nd4j.linalg.api.ops.impl.layers.convolution.Im2col(in, Conv2DConfig))[0]; + } + + /** + * 2D convolution layer operation - local response normalization
+ * + * @param input the inputs to lrn (NUMERIC type) + * @param LocalResponseNormalizationConfig Configuration Object + * @return output Result after Local Response Normalization (NUMERIC type) + */ + public INDArray localResponseNormalization(INDArray input, + LocalResponseNormalizationConfig LocalResponseNormalizationConfig) { + NDValidation.validateNumerical("localResponseNormalization", "input", input); + return Nd4j.exec(new org.nd4j.linalg.api.ops.impl.layers.convolution.LocalResponseNormalization(input, LocalResponseNormalizationConfig))[0]; + } + + /** + * 2D Convolution layer operation - max pooling 2d
+ * + * @param input the input to max pooling 2d operation - 4d CNN (image) activations in NCHW format + * (shape [minibatch, channels, height, width]) or NHWC format (shape [minibatch, height, width, channels]) (NUMERIC type) + * @param Pooling2DConfig Configuration Object + * @return output Result after applying max pooling on the input (NUMERIC type) + */ + public INDArray maxPooling2d(INDArray input, Pooling2DConfig Pooling2DConfig) { + NDValidation.validateNumerical("maxPooling2d", "input", input); + return Nd4j.exec(new org.nd4j.linalg.api.ops.impl.layers.convolution.MaxPooling2D(input, Pooling2DConfig))[0]; + } + + /** + * 3D convolution layer operation - max pooling 3d operation.
+ * + * @param input the input to max pooling 3d operation - 5d activations in NCDHW format + * (shape [minibatch, channels, depth, height, width]) or NDHWC format + * (shape [minibatch, depth, height, width, channels]) (NUMERIC type) + * @param Pooling3DConfig Configuration Object + * @return output Result after applying max pooling on the input (NUMERIC type) + */ + public INDArray maxPooling3d(INDArray input, Pooling3DConfig Pooling3DConfig) { + NDValidation.validateNumerical("maxPooling3d", "input", input); + return Nd4j.exec(new org.nd4j.linalg.api.ops.impl.layers.convolution.MaxPooling3D(input, Pooling3DConfig))[0]; + } + + /** + * Separable 2D convolution operation with optional bias
+ * + * @param layerInput the input to separable conv2d operation - 4d CNN (image) activations in NCHW format + * (shape [minibatch, channels, height, width]) or NHWC format (shape [minibatch, height, width, channels]) (NUMERIC type) + * @param depthWeights Separable conv2d depth weights. 4 dimensions with format [kernelHeight, kernelWidth, inputChannels, depthMultiplier] (NUMERIC type) + * @param pointWeights Point weights, rank 4 with format [1, 1, inputChannels*depthMultiplier, outputChannels]. May be null (NUMERIC type) + * @param bias Optional bias, rank 1 with shape [outputChannels]. May be null. (NUMERIC type) + * @param Conv2DConfig Configuration Object + * @return output result of separable convolution 2d operation (NUMERIC type) + */ + public INDArray separableConv2d(INDArray layerInput, INDArray depthWeights, INDArray pointWeights, + INDArray bias, Conv2DConfig Conv2DConfig) { + NDValidation.validateNumerical("separableConv2d", "layerInput", layerInput); + NDValidation.validateNumerical("separableConv2d", "depthWeights", depthWeights); + NDValidation.validateNumerical("separableConv2d", "pointWeights", pointWeights); + NDValidation.validateNumerical("separableConv2d", "bias", bias); + return Nd4j.exec(new org.nd4j.linalg.api.ops.impl.layers.convolution.SConv2D(layerInput, depthWeights, pointWeights, bias, Conv2DConfig))[0]; + } + + /** + * Separable 2D convolution operation with optional bias
+ * + * @param layerInput the input to separable conv2d operation - 4d CNN (image) activations in NCHW format + * (shape [minibatch, channels, height, width]) or NHWC format (shape [minibatch, height, width, channels]) (NUMERIC type) + * @param depthWeights Separable conv2d depth weights. 4 dimensions with format [kernelHeight, kernelWidth, inputChannels, depthMultiplier] (NUMERIC type) + * @param pointWeights Point weights, rank 4 with format [1, 1, inputChannels*depthMultiplier, outputChannels]. May be null (NUMERIC type) + * @param Conv2DConfig Configuration Object + * @return output result of separable convolution 2d operation (NUMERIC type) + */ + public INDArray separableConv2d(INDArray layerInput, INDArray depthWeights, INDArray pointWeights, + Conv2DConfig Conv2DConfig) { + NDValidation.validateNumerical("separableConv2d", "layerInput", layerInput); + NDValidation.validateNumerical("separableConv2d", "depthWeights", depthWeights); + NDValidation.validateNumerical("separableConv2d", "pointWeights", pointWeights); + return Nd4j.exec(new org.nd4j.linalg.api.ops.impl.layers.convolution.SConv2D(layerInput, depthWeights, pointWeights, null, Conv2DConfig))[0]; + } + + /** + * Convolution 2d layer space to batch operation on 4d input.
+ * Increases input batch dimension by rearranging data from spatial dimensions into batch dimension
+ * + * @param x Input variable. 4d input (NUMERIC type) + * @param blocks Block size, in the height/width dimension (Size: Exactly(count=2)) + * @param paddingTop Optional 2d int[] array for padding the result: values [[pad top, pad bottom], [pad left, pad right]] (Size: Exactly(count=2)) + * @param paddingBottom Optional 2d int[] array for padding the result: values [[pad top, pad bottom], [pad left, pad right]] (Size: Exactly(count=2)) + * @return output Output variable (NUMERIC type) + */ + public INDArray spaceToBatch(INDArray x, int[] blocks, int[] paddingTop, int... paddingBottom) { + NDValidation.validateNumerical("spaceToBatch", "x", x); + Preconditions.checkArgument(blocks.length == 2, "blocks has incorrect size/length. Expected: blocks.length == 2, got %s", blocks.length); + Preconditions.checkArgument(paddingTop.length == 2, "paddingTop has incorrect size/length. Expected: paddingTop.length == 2, got %s", paddingTop.length); + Preconditions.checkArgument(paddingBottom.length == 2, "paddingBottom has incorrect size/length. Expected: paddingBottom.length == 2, got %s", paddingBottom.length); + return Nd4j.exec(new org.nd4j.linalg.api.ops.impl.transforms.custom.SpaceToBatch(x, blocks, paddingTop, paddingBottom))[0]; + } + + /** + * Convolution 2d layer space to depth operation on 4d input.
+ * Increases input channels (reduced spatial dimensions) by rearranging data into a larger channels dimension
+ * Example: if input has shape [mb, 2, 4, 4] and block size is 2, then output size is [mb, 2*(2*2), 4/2, 4/2]
+ * = [mb, 8, 2, 2]
+ * + * @param x the input to space to depth operation - 4d activations in NCHW format + * (shape [minibatch, channels, height, width]) or NHWC format (shape [minibatch, height, width, channels]) (NUMERIC type) + * @param blockSize Block size, in the height/width dimension + * @param dataFormat Data format: "NCHW" or "NHWC" + * @return output Output variable (NUMERIC type) + */ + public INDArray spaceToDepth(INDArray x, int blockSize, DataFormat dataFormat) { + NDValidation.validateNumerical("spaceToDepth", "x", x); + return Nd4j.exec(new org.nd4j.linalg.api.ops.impl.layers.convolution.SpaceToDepth(x, blockSize, dataFormat))[0]; + } + + /** + * Upsampling layer for 2D inputs.
+ * scale is used for both height and width dimensions.
+ * + * @param input Input in NCHW format (NUMERIC type) + * @param scale The scale for both height and width dimensions. + * @return output Upsampled input (NUMERIC type) + */ + public INDArray upsampling2d(INDArray input, int scale) { + NDValidation.validateNumerical("upsampling2d", "input", input); + return Nd4j.exec(new org.nd4j.linalg.api.ops.impl.layers.convolution.Upsampling2d(input, scale))[0]; + } + + /** + * 2D Convolution layer operation - Upsampling 2d
+ * + * @param input Input in NCHW format (NUMERIC type) + * @param scaleH Scale to upsample in height dimension + * @param scaleW Scale to upsample in width dimension + * @param nchw If true: input is in NCHW (minibatch, channels, height, width) format. False: NHWC format + * @return output Upsampled input (NUMERIC type) + */ + public INDArray upsampling2d(INDArray input, int scaleH, int scaleW, boolean nchw) { + NDValidation.validateNumerical("upsampling2d", "input", input); + return Nd4j.exec(new org.nd4j.linalg.api.ops.impl.layers.convolution.Upsampling2d(input, scaleH, scaleW, nchw))[0]; + } +} diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/factory/ops/NDImage.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/factory/ops/NDImage.java new file mode 100644 index 000000000..859ad43c3 --- /dev/null +++ b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/factory/ops/NDImage.java @@ -0,0 +1,221 @@ +/******************************************************************************* + * Copyright (c) 2019-2020 Konduit K.K. + * + * This program and the accompanying materials are made available under the + * terms of the Apache License, Version 2.0 which is available at + * https://www.apache.org/licenses/LICENSE-2.0. + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations + * under the License. 
+ * + * SPDX-License-Identifier: Apache-2.0 + ******************************************************************************/ + +//================== GENERATED CODE - DO NOT MODIFY THIS FILE ================== + +package org.nd4j.linalg.factory.ops; + +import static org.nd4j.linalg.factory.NDValidation.isSameType; + +import org.nd4j.base.Preconditions; +import org.nd4j.linalg.api.ndarray.INDArray; +import org.nd4j.linalg.factory.NDValidation; +import org.nd4j.linalg.factory.Nd4j; + +public class NDImage { + public NDImage() { + } + + /** + * Given an input image and some crop boxes, extract out the image subsets and resize them to the specified size.
+ * + * @param image Input image, with shape [batch, height, width, channels] (NUMERIC type) + * @param cropBoxes Float32 crop, shape [numBoxes, 4] with values in range 0 to 1 (NUMERIC type) + * @param boxIndices Indices: which image (index to dimension 0) the cropBoxes belong to. Rank 1, shape [numBoxes] (NUMERIC type) + * @param cropOutSize Output size for the images - int32, rank 1 with values [outHeight, outWidth] (INT type) + * @param extrapolationValue Used for extrapolation, when applicable. 0.0 should be used for the default + * @return output Cropped and resized images (NUMERIC type) + */ + public INDArray cropAndResize(INDArray image, INDArray cropBoxes, INDArray boxIndices, + INDArray cropOutSize, double extrapolationValue) { + NDValidation.validateNumerical("CropAndResize", "image", image); + NDValidation.validateNumerical("CropAndResize", "cropBoxes", cropBoxes); + NDValidation.validateNumerical("CropAndResize", "boxIndices", boxIndices); + NDValidation.validateInteger("CropAndResize", "cropOutSize", cropOutSize); + return Nd4j.exec(new org.nd4j.linalg.api.ops.impl.image.CropAndResize(image, cropBoxes, boxIndices, cropOutSize, extrapolationValue))[0]; + } + + /** + * Given an input image and some crop boxes, extract out the image subsets and resize them to the specified size.
+ * + * @param image Input image, with shape [batch, height, width, channels] (NUMERIC type) + * @param cropBoxes Float32 crop, shape [numBoxes, 4] with values in range 0 to 1 (NUMERIC type) + * @param boxIndices Indices: which image (index to dimension 0) the cropBoxes belong to. Rank 1, shape [numBoxes] (NUMERIC type) + * @param cropOutSize Output size for the images - int32, rank 1 with values [outHeight, outWidth] (INT type) + * @return output Cropped and resized images (NUMERIC type) + */ + public INDArray cropAndResize(INDArray image, INDArray cropBoxes, INDArray boxIndices, + INDArray cropOutSize) { + NDValidation.validateNumerical("CropAndResize", "image", image); + NDValidation.validateNumerical("CropAndResize", "cropBoxes", cropBoxes); + NDValidation.validateNumerical("CropAndResize", "boxIndices", boxIndices); + NDValidation.validateInteger("CropAndResize", "cropOutSize", cropOutSize); + return Nd4j.exec(new org.nd4j.linalg.api.ops.impl.image.CropAndResize(image, cropBoxes, boxIndices, cropOutSize, 0.0))[0]; + } + + /** + * Adjusts contrast of RGB or grayscale images.
+ * + * @param in images to adjust. 3D shape or higher (NUMERIC type) + * @param factor multiplier for adjusting contrast + * @return output Contrast-adjusted image (NUMERIC type) + */ + public INDArray adjustContrast(INDArray in, double factor) { + NDValidation.validateNumerical("adjustContrast", "in", in); + return Nd4j.exec(new org.nd4j.linalg.api.ops.custom.AdjustContrast(in, factor))[0]; + } + + /** + * Adjust hue of RGB image
+ * + * @param in image as 3D array (NUMERIC type) + * @param delta value to add to hue channel + * @return output adjusted image (NUMERIC type) + */ + public INDArray adjustHue(INDArray in, double delta) { + NDValidation.validateNumerical("adjustHue", "in", in); + return Nd4j.exec(new org.nd4j.linalg.api.ops.custom.AdjustHue(in, delta))[0]; + } + + /** + * Adjust saturation of RGB images
+ * + * @param in RGB image as 3D array (NUMERIC type) + * @param factor factor for saturation + * @return output adjusted image (NUMERIC type) + */ + public INDArray adjustSaturation(INDArray in, double factor) { + NDValidation.validateNumerical("adjustSaturation", "in", in); + return Nd4j.exec(new org.nd4j.linalg.api.ops.custom.AdjustSaturation(in, factor))[0]; + } + + /** + * Given an input image, extract out image patches (of size kSizes - h x w) and place them in the depth dimension.
+ * + * @param image Input image to extract image patches from - shape [batch, height, width, channels] (NUMERIC type) + * @param kSizes Kernel size - size of the image patches, [height, width] (Size: Exactly(count=2)) + * @param strides Stride in the input dimension for extracting image patches, [stride_height, stride_width] (Size: Exactly(count=2)) + * @param rates Usually [1,1]. Equivalent to dilation rate in dilated convolutions - how far apart the output pixels + * in the patches should be, in the input. A dilation of [a,b] means every {@code a}th pixel is taken + * along the height/rows dimension, and every {@code b}th pixel is take along the width/columns dimension (Size: AtLeast(min=0)) + * @param sameMode Padding algorithm. If true: use Same padding + * @return output The extracted image patches (NUMERIC type) + */ + public INDArray extractImagePatches(INDArray image, int[] kSizes, int[] strides, int[] rates, + boolean sameMode) { + NDValidation.validateNumerical("extractImagePatches", "image", image); + Preconditions.checkArgument(kSizes.length == 2, "kSizes has incorrect size/length. Expected: kSizes.length == 2, got %s", kSizes.length); + Preconditions.checkArgument(strides.length == 2, "strides has incorrect size/length. Expected: strides.length == 2, got %s", strides.length); + Preconditions.checkArgument(rates.length >= 0, "rates has incorrect size/length. Expected: rates.length >= 0, got %s", rates.length); + return Nd4j.exec(new org.nd4j.linalg.api.ops.impl.image.ExtractImagePatches(image, kSizes, strides, rates, sameMode))[0]; + } + + /** + * Converting image from HSV to RGB format
+ * + * @param input 3D image (NUMERIC type) + * @return output 3D image (NUMERIC type) + */ + public INDArray hsvToRgb(INDArray input) { + NDValidation.validateNumerical("hsvToRgb", "input", input); + return Nd4j.exec(new org.nd4j.linalg.api.ops.custom.HsvToRgb(input))[0]; + } + + /** + * Greedily selects a subset of bounding boxes in descending order of score
+ * + * @param boxes 2D array of shape [num_boxes, 4] with the box coordinates (NUMERIC type) + * @param scores vector of shape [num_boxes] (NUMERIC type) + * @param maxOutSize scalar representing the maximum number of boxes to be selected + * @param iouThreshold threshold for deciding whether boxes overlap too much with respect to IOU + * @param scoreThreshold threshold for deciding when to remove boxes based on score + * @return output vector of shape [M] representing the selected indices from the boxes tensor, where M <= max_output_size (NUMERIC type) + */ + public INDArray nonMaxSuppression(INDArray boxes, INDArray scores, int maxOutSize, + double iouThreshold, double scoreThreshold) { + NDValidation.validateNumerical("nonMaxSuppression", "boxes", boxes); + NDValidation.validateNumerical("nonMaxSuppression", "scores", scores); + return Nd4j.exec(new org.nd4j.linalg.api.ops.impl.image.NonMaxSuppression(boxes, scores, maxOutSize, iouThreshold, scoreThreshold))[0]; + } + + /** + * Randomly crops image
+ * + * @param input input array (NUMERIC type) + * @param shape shape for crop (INT type) + * @return output cropped array (NUMERIC type) + */ + public INDArray randomCrop(INDArray input, INDArray shape) { + NDValidation.validateNumerical("randomCrop", "input", input); + NDValidation.validateInteger("randomCrop", "shape", shape); + return Nd4j.exec(new org.nd4j.linalg.api.ops.custom.RandomCrop(input, shape))[0]; + } + + /** + * Converting array from RGB to HSV format
+ * + * @param input 3D image (NUMERIC type) + * @return output 3D image (NUMERIC type) + */ + public INDArray rgbToHsv(INDArray input) { + NDValidation.validateNumerical("rgbToHsv", "input", input); + return Nd4j.exec(new org.nd4j.linalg.api.ops.custom.RgbToHsv(input))[0]; + } + + /** + * Converting array from RGB to YIQ format
+ * + * @param input 3D image (NUMERIC type) + * @return output 3D image (NUMERIC type) + */ + public INDArray rgbToYiq(INDArray input) { + NDValidation.validateNumerical("rgbToYiq", "input", input); + return Nd4j.exec(new org.nd4j.linalg.api.ops.custom.RgbToYiq(input))[0]; + } + + /** + * Converting array from RGB to YUV format
+ * + * @param input 3D image (NUMERIC type) + * @return output 3D image (NUMERIC type) + */ + public INDArray rgbToYuv(INDArray input) { + NDValidation.validateNumerical("rgbToYuv", "input", input); + return Nd4j.exec(new org.nd4j.linalg.api.ops.custom.RgbToYuv(input))[0]; + } + + /** + * Converting image from YIQ to RGB format
+ * + * @param input 3D image (NUMERIC type) + * @return output 3D image (NUMERIC type) + */ + public INDArray yiqToRgb(INDArray input) { + NDValidation.validateNumerical("yiqToRgb", "input", input); + return Nd4j.exec(new org.nd4j.linalg.api.ops.custom.YiqToRgb(input))[0]; + } + + /** + * Converting image from YUV to RGB format
+ * + * @param input 3D image (NUMERIC type) + * @return output 3D image (NUMERIC type) + */ + public INDArray yuvToRgb(INDArray input) { + NDValidation.validateNumerical("yuvToRgb", "input", input); + return Nd4j.exec(new org.nd4j.linalg.api.ops.custom.YuvToRgb(input))[0]; + } +} diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/factory/ops/NDRNN.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/factory/ops/NDRNN.java new file mode 100755 index 000000000..0587aeda5 --- /dev/null +++ b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/factory/ops/NDRNN.java @@ -0,0 +1,130 @@ +/******************************************************************************* + * Copyright (c) 2019-2020 Konduit K.K. + * + * This program and the accompanying materials are made available under the + * terms of the Apache License, Version 2.0 which is available at + * https://www.apache.org/licenses/LICENSE-2.0. + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations + * under the License. 
+ * + * SPDX-License-Identifier: Apache-2.0 + ******************************************************************************/ + +//================== GENERATED CODE - DO NOT MODIFY THIS FILE ================== + +package org.nd4j.linalg.factory.ops; + +import static org.nd4j.linalg.factory.NDValidation.isSameType; + +import org.nd4j.linalg.api.ndarray.INDArray; +import org.nd4j.linalg.api.ops.impl.layers.recurrent.config.LSTMConfiguration; +import org.nd4j.linalg.api.ops.impl.layers.recurrent.weights.GRUWeights; +import org.nd4j.linalg.api.ops.impl.layers.recurrent.weights.LSTMWeights; +import org.nd4j.linalg.api.ops.impl.layers.recurrent.weights.SRUWeights; +import org.nd4j.linalg.factory.NDValidation; +import org.nd4j.linalg.factory.Nd4j; + +public class NDRNN { + public NDRNN() { + } + + /** + * The GRU cell. Does a single time step operation
+ * + * @param x Input, with shape [batchSize, inSize] (NUMERIC type) + * @param hLast Output of the previous cell/time step, with shape [batchSize, numUnits] (NUMERIC type) + * @param GRUWeights Configuration Object + * @return output The cell's outputs. (NUMERIC type) + */ + public INDArray gru(INDArray x, INDArray hLast, GRUWeights GRUWeights) { + NDValidation.validateNumerical("gru", "x", x); + NDValidation.validateNumerical("gru", "hLast", hLast); + return Nd4j.exec(new org.nd4j.linalg.api.ops.impl.layers.recurrent.GRUCell(x, hLast, GRUWeights))[0]; + } + + /** + * The LSTM cell. Does a single time step operation.
+ * + * @param x Input, with shape [batchSize, inSize] (NUMERIC type) + * @param cLast Previous cell state, with shape [batchSize, numUnits] (NUMERIC type) + * @param yLast Previous cell output, with shape [batchSize, numUnits] (NUMERIC type) + * @param LSTMWeights Configuration Object + * @param LSTMConfiguration Configuration Object + * @return output The cell's outputs (NUMERIC type) + */ + public INDArray lstmCell(INDArray x, INDArray cLast, INDArray yLast, LSTMWeights LSTMWeights, + LSTMConfiguration LSTMConfiguration) { + NDValidation.validateNumerical("lstmCell", "x", x); + NDValidation.validateNumerical("lstmCell", "cLast", cLast); + NDValidation.validateNumerical("lstmCell", "yLast", yLast); + return Nd4j.exec(new org.nd4j.linalg.api.ops.impl.layers.recurrent.LSTMBlockCell(x, cLast, yLast, LSTMWeights, LSTMConfiguration))[0]; + } + + /** + * The LSTM layer. Does multiple time steps.
+ * + * @param maxTSLength (NUMERIC type) + * @param x Input, with shape dependent on the data format (in config). (NUMERIC type) + * @param cLast Previous/initial cell state, with shape [batchSize, numUnits] (NUMERIC type) + * @param yLast Previous/initial cell output, with shape [batchSize, numUnits] (NUMERIC type) + * @param LSTMWeights Configuration Object + * @param LSTMConfiguration Configuration Object + * @return output The layer's outputs. (NUMERIC type) + */ + public INDArray lstmLayer(INDArray maxTSLength, INDArray x, INDArray cLast, INDArray yLast, + LSTMWeights LSTMWeights, LSTMConfiguration LSTMConfiguration) { + NDValidation.validateNumerical("lstmLayer", "maxTSLength", maxTSLength); + NDValidation.validateNumerical("lstmLayer", "x", x); + NDValidation.validateNumerical("lstmLayer", "cLast", cLast); + NDValidation.validateNumerical("lstmLayer", "yLast", yLast); + return Nd4j.exec(new org.nd4j.linalg.api.ops.impl.layers.recurrent.LSTMLayer(maxTSLength, x, cLast, yLast, LSTMWeights, LSTMConfiguration))[0]; + } + + /** + * The SRU layer. Does a single time step operation.
+ * + * @param x Input, with shape [batchSize, inSize] (NUMERIC type) + * @param initialC Initial cell state, with shape [batchSize, inSize] (NUMERIC type) + * @param mask An optional dropout mask, with shape [batchSize, inSize] (NUMERIC type) + * @param SRUWeights Configuration Object + * @return output The cell's outputs. (NUMERIC type) + */ + public INDArray sru(INDArray x, INDArray initialC, INDArray mask, SRUWeights SRUWeights) { + NDValidation.validateNumerical("sru", "x", x); + NDValidation.validateNumerical("sru", "initialC", initialC); + NDValidation.validateNumerical("sru", "mask", mask); + return Nd4j.exec(new org.nd4j.linalg.api.ops.impl.layers.recurrent.SRU(x, initialC, mask, SRUWeights))[0]; + } + + /** + * The SRU layer. Does a single time step operation.
+ * + * @param x Input, with shape [batchSize, inSize] (NUMERIC type) + * @param initialC Initial cell state, with shape [batchSize, inSize] (NUMERIC type) + * @param SRUWeights Configuration Object + * @return output The cell's outputs. (NUMERIC type) + */ + public INDArray sru(INDArray x, INDArray initialC, SRUWeights SRUWeights) { + NDValidation.validateNumerical("sru", "x", x); + NDValidation.validateNumerical("sru", "initialC", initialC); + return Nd4j.exec(new org.nd4j.linalg.api.ops.impl.layers.recurrent.SRU(x, initialC, null, SRUWeights))[0]; + } + + /** + * The SRU cell. Does a single time step operation.
+ * + * @param x Input, with shape [batchSize, inSize] (NUMERIC type) + * @param cLast Previous cell state, with shape [batchSize, inSize] (NUMERIC type) + * @param SRUWeights Configuration Object + * @return output The cell's outputs. (NUMERIC type) + */ + public INDArray sruCell(INDArray x, INDArray cLast, SRUWeights SRUWeights) { + NDValidation.validateNumerical("sruCell", "x", x); + NDValidation.validateNumerical("sruCell", "cLast", cLast); + return Nd4j.exec(new org.nd4j.linalg.api.ops.impl.layers.recurrent.SRUCell(x, cLast, SRUWeights))[0]; + } +}