diff --git a/libnd4j/include/ops/declarable/generic/nn/pooling/maxpool_with_argmax.cpp b/libnd4j/include/ops/declarable/generic/nn/pooling/maxpool_with_argmax.cpp index bf5a3eb6e..5fe7455fc 100644 --- a/libnd4j/include/ops/declarable/generic/nn/pooling/maxpool_with_argmax.cpp +++ b/libnd4j/include/ops/declarable/generic/nn/pooling/maxpool_with_argmax.cpp @@ -46,7 +46,7 @@ namespace nd4j { getOpDescriptor() ->setAllowedInputTypes(nd4j::DataType::ANY) ->setAllowedOutputTypes(0, DataType::INHERIT) - ->setAllowedOutputTypes(1, DataType::INT64); + ->setAllowedOutputTypes(1, {ALL_INTS}); } diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/imports/converters/ImportClassMapping.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/imports/converters/ImportClassMapping.java index 5b60ac0b4..3c6b969b8 100644 --- a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/imports/converters/ImportClassMapping.java +++ b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/imports/converters/ImportClassMapping.java @@ -103,6 +103,7 @@ public class ImportClassMapping { org.nd4j.linalg.api.ops.impl.layers.convolution.BatchNormDerivative.class, org.nd4j.linalg.api.ops.impl.layers.convolution.Col2Im.class, org.nd4j.linalg.api.ops.impl.layers.convolution.Conv1D.class, + org.nd4j.linalg.api.ops.impl.layers.convolution.Conv1DDerivative.class, org.nd4j.linalg.api.ops.impl.layers.convolution.Conv2D.class, org.nd4j.linalg.api.ops.impl.layers.convolution.Conv2DDerivative.class, org.nd4j.linalg.api.ops.impl.layers.convolution.Conv3D.class, diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/image/NonMaxSuppression.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/image/NonMaxSuppression.java index 75b82dc29..f8763c41a 100644 --- a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/image/NonMaxSuppression.java +++ b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/image/NonMaxSuppression.java @@ -60,7 +60,7 @@ public class NonMaxSuppression extends DynamicCustomOp { @Override public String[] tensorflowNames() { - return new String[]{"NonMaxSuppression", "NonMaxSuppressionV2","NonMaxSuppressionV3","NonMaxSuppressionV4"}; + return new String[]{"NonMaxSuppression", "NonMaxSuppressionV2"}; } @Override diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/layers/convolution/MaxPoolWithArgmax.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/layers/convolution/MaxPoolWithArgmax.java index 58602d85e..b966d4389 100644 --- a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/layers/convolution/MaxPoolWithArgmax.java +++ b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/layers/convolution/MaxPoolWithArgmax.java @@ -204,7 +204,7 @@ public class MaxPoolWithArgmax extends DynamicCustomOp { if(attributesForNode.containsKey("argmax")) { outputType = TFGraphMapper.convertType(attributesForNode.get("argmax").getType()); } else { - outputType = DataType.UINT32; + outputType = DataType.LONG; } } @@ -278,7 +278,7 @@ public class MaxPoolWithArgmax extends DynamicCustomOp { Preconditions.checkState(inputDataTypes != null && inputDataTypes.size() == 1, "Expected 1 input data type for %s, got %s", getClass(), inputDataTypes); List result = new ArrayList<>(); result.add(inputDataTypes.get(0)); - result.add(outputType == null ? DataType.UINT32 : outputType); + result.add(outputType == null ? DataType.INT : outputType); return result; } } diff --git a/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/autodiff/opvalidation/LayerOpValidation.java b/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/autodiff/opvalidation/LayerOpValidation.java index 9dd529399..6a32d9ea9 100644 --- a/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/autodiff/opvalidation/LayerOpValidation.java +++ b/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/autodiff/opvalidation/LayerOpValidation.java @@ -760,7 +760,7 @@ public class LayerOpValidation extends BaseOpValidation { .isSameMode(true) .build(); - SDVariable[] results = sd.nn().maxPoolWithArgmax(new String[]{"",""}, in, pooling2DConfig); + SDVariable[] results = sd.nn().maxPoolWithArgmax(new String[]{"out","idx"}, in, pooling2DConfig); assertArrayEquals(inArr.shape(), results[0].eval().shape()); assertArrayEquals(inArr.shape(), results[1].eval().shape()); } @@ -1050,7 +1050,7 @@ public class LayerOpValidation extends BaseOpValidation { SDVariable in = sd.var("in", inArr); SDVariable w = sd.var("w", wArr); - SDVariable res = sd.cnn.conv1d(in, w, Conv1DConfig.builder().k(kernel).build()); + SDVariable res = sd.cnn.conv1d(in, w, Conv1DConfig.builder().k(kernel).paddingMode(PaddingMode.VALID).build()); INDArray expected = Nd4j.createFromArray( new double[][][]{ diff --git a/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/autodiff/samediff/ConvConfigTests.java b/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/autodiff/samediff/ConvConfigTests.java index 996ccff7f..a6f7b6bea 100644 --- a/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/autodiff/samediff/ConvConfigTests.java +++ b/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/autodiff/samediff/ConvConfigTests.java @@ -23,13 +23,7 @@ import static org.junit.Assert.fail; import org.junit.Assert; import org.junit.Test; import org.nd4j.linalg.api.ops.impl.layers.convolution.DeConv2D; -import org.nd4j.linalg.api.ops.impl.layers.convolution.config.Conv1DConfig; -import org.nd4j.linalg.api.ops.impl.layers.convolution.config.Conv2DConfig; -import org.nd4j.linalg.api.ops.impl.layers.convolution.config.Conv3DConfig; -import org.nd4j.linalg.api.ops.impl.layers.convolution.config.DeConv2DConfig; -import org.nd4j.linalg.api.ops.impl.layers.convolution.config.DeConv3DConfig; -import org.nd4j.linalg.api.ops.impl.layers.convolution.config.Pooling2DConfig; -import org.nd4j.linalg.api.ops.impl.layers.convolution.config.Pooling3DConfig; +import org.nd4j.linalg.api.ops.impl.layers.convolution.config.*; public class ConvConfigTests { @@ -489,24 +483,24 @@ public class ConvConfigTests { @Test public void testConv1D(){ - Conv1DConfig.builder().k(2).build(); + Conv1DConfig.builder().k(2).paddingMode(PaddingMode.SAME).build(); try{ - Conv1DConfig.builder().k(0).build(); + Conv1DConfig.builder().k(0).paddingMode(PaddingMode.SAME).build(); fail(); } catch (IllegalArgumentException e){ assertTrue(e.getMessage().contains("Kernel")); } try{ - Conv1DConfig.builder().k(4).s(-2).build(); + Conv1DConfig.builder().k(4).s(-2).paddingMode(PaddingMode.SAME).build(); fail(); } catch (IllegalArgumentException e){ assertTrue(e.getMessage().contains("Stride")); } try{ - Conv1DConfig.builder().k(3).p(-2).build(); + Conv1DConfig.builder().k(3).p(-2).paddingMode(PaddingMode.SAME).build(); fail(); } catch (IllegalArgumentException e){ assertTrue(e.getMessage().contains("Padding"));