diff --git a/libnd4j/blas/cpu/GraphExecutioner.cpp b/libnd4j/blas/cpu/GraphExecutioner.cpp index 6f97bc024..ef45a3e0c 100644 --- a/libnd4j/blas/cpu/GraphExecutioner.cpp +++ b/libnd4j/blas/cpu/GraphExecutioner.cpp @@ -583,7 +583,7 @@ Nd4jStatus GraphExecutioner::execute(Graph *graph, VariableSpace* variableSpace) auto fName = builder.CreateString(*(var->getName())); auto id = CreateIntPair(builder, var->id(), var->index()); - auto fv = CreateFlatVariable(builder, id, fName, static_cast(array->dataType()), 0, fArray); + auto fv = CreateFlatVariable(builder, id, fName, static_cast(array->dataType()), 0, fArray); variables_vector.push_back(fv); arrays++; diff --git a/libnd4j/include/array/DataTypeUtils.h b/libnd4j/include/array/DataTypeUtils.h index 8346442eb..2a52ba6f5 100644 --- a/libnd4j/include/array/DataTypeUtils.h +++ b/libnd4j/include/array/DataTypeUtils.h @@ -38,7 +38,7 @@ namespace nd4j { public: static int asInt(DataType type); static DataType fromInt(int dtype); - static DataType fromFlatDataType(nd4j::graph::DataType dtype); + static DataType fromFlatDataType(nd4j::graph::DType dtype); FORCEINLINE static std::string asString(DataType dataType); template diff --git a/libnd4j/include/array/impl/DataTypeUtils.cpp b/libnd4j/include/array/impl/DataTypeUtils.cpp index f0b261039..cdf688b25 100644 --- a/libnd4j/include/array/impl/DataTypeUtils.cpp +++ b/libnd4j/include/array/impl/DataTypeUtils.cpp @@ -27,7 +27,7 @@ namespace nd4j { return (DataType) val; } - DataType DataTypeUtils::fromFlatDataType(nd4j::graph::DataType dtype) { + DataType DataTypeUtils::fromFlatDataType(nd4j::graph::DType dtype) { return (DataType) dtype; } diff --git a/libnd4j/include/graph/generated/array_generated.h b/libnd4j/include/graph/generated/array_generated.h index 5848c0ac4..b581240ad 100644 --- a/libnd4j/include/graph/generated/array_generated.h +++ b/libnd4j/include/graph/generated/array_generated.h @@ -40,56 +40,56 @@ inline const char *EnumNameByteOrder(ByteOrder e) { return EnumNamesByteOrder()[index]; } -enum DataType { - DataType_INHERIT = 0, - DataType_BOOL = 1, - DataType_FLOAT8 = 2, - DataType_HALF = 3, - DataType_HALF2 = 4, - DataType_FLOAT = 5, - DataType_DOUBLE = 6, - DataType_INT8 = 7, - DataType_INT16 = 8, - DataType_INT32 = 9, - DataType_INT64 = 10, - DataType_UINT8 = 11, - DataType_UINT16 = 12, - DataType_UINT32 = 13, - DataType_UINT64 = 14, - DataType_QINT8 = 15, - DataType_QINT16 = 16, - DataType_BFLOAT16 = 17, - DataType_UTF8 = 50, - DataType_MIN = DataType_INHERIT, - DataType_MAX = DataType_UTF8 +enum DType { + DType_INHERIT = 0, + DType_BOOL = 1, + DType_FLOAT8 = 2, + DType_HALF = 3, + DType_HALF2 = 4, + DType_FLOAT = 5, + DType_DOUBLE = 6, + DType_INT8 = 7, + DType_INT16 = 8, + DType_INT32 = 9, + DType_INT64 = 10, + DType_UINT8 = 11, + DType_UINT16 = 12, + DType_UINT32 = 13, + DType_UINT64 = 14, + DType_QINT8 = 15, + DType_QINT16 = 16, + DType_BFLOAT16 = 17, + DType_UTF8 = 50, + DType_MIN = DType_INHERIT, + DType_MAX = DType_UTF8 }; -inline const DataType (&EnumValuesDataType())[19] { - static const DataType values[] = { - DataType_INHERIT, - DataType_BOOL, - DataType_FLOAT8, - DataType_HALF, - DataType_HALF2, - DataType_FLOAT, - DataType_DOUBLE, - DataType_INT8, - DataType_INT16, - DataType_INT32, - DataType_INT64, - DataType_UINT8, - DataType_UINT16, - DataType_UINT32, - DataType_UINT64, - DataType_QINT8, - DataType_QINT16, - DataType_BFLOAT16, - DataType_UTF8 +inline const DType (&EnumValuesDType())[19] { + static const DType values[] = { + DType_INHERIT, + DType_BOOL, + DType_FLOAT8, + DType_HALF, + DType_HALF2, + DType_FLOAT, + DType_DOUBLE, + DType_INT8, + DType_INT16, + DType_INT32, + DType_INT64, + DType_UINT8, + DType_UINT16, + DType_UINT32, + DType_UINT64, + DType_QINT8, + DType_QINT16, + DType_BFLOAT16, + DType_UTF8 }; return values; } -inline const char * const *EnumNamesDataType() { +inline const char * const *EnumNamesDType() { static const char * const names[] = { "INHERIT", "BOOL", @@ -147,9 +147,9 @@ inline const char * const *EnumNamesDataType() { return names; } -inline const char *EnumNameDataType(DataType e) { +inline const char *EnumNameDType(DType e) { const size_t index = static_cast(e); - return EnumNamesDataType()[index]; + return EnumNamesDType()[index]; } struct FlatArray FLATBUFFERS_FINAL_CLASS : private flatbuffers::Table { @@ -165,8 +165,8 @@ struct FlatArray FLATBUFFERS_FINAL_CLASS : private flatbuffers::Table { const flatbuffers::Vector *buffer() const { return GetPointer *>(VT_BUFFER); } - DataType dtype() const { - return static_cast(GetField(VT_DTYPE, 0)); + DType dtype() const { + return static_cast(GetField(VT_DTYPE, 0)); } ByteOrder byteOrder() const { return static_cast(GetField(VT_BYTEORDER, 0)); @@ -192,7 +192,7 @@ struct FlatArrayBuilder { void add_buffer(flatbuffers::Offset> buffer) { fbb_.AddOffset(FlatArray::VT_BUFFER, buffer); } - void add_dtype(DataType dtype) { + void add_dtype(DType dtype) { fbb_.AddElement(FlatArray::VT_DTYPE, static_cast(dtype), 0); } void add_byteOrder(ByteOrder byteOrder) { @@ -214,7 +214,7 @@ inline flatbuffers::Offset CreateFlatArray( flatbuffers::FlatBufferBuilder &_fbb, flatbuffers::Offset> shape = 0, flatbuffers::Offset> buffer = 0, - DataType dtype = DataType_INHERIT, + DType dtype = DType_INHERIT, ByteOrder byteOrder = ByteOrder_LE) { FlatArrayBuilder builder_(_fbb); builder_.add_buffer(buffer); @@ -228,7 +228,7 @@ inline flatbuffers::Offset CreateFlatArrayDirect( flatbuffers::FlatBufferBuilder &_fbb, const std::vector *shape = nullptr, const std::vector *buffer = nullptr, - DataType dtype = DataType_INHERIT, + DType dtype = DType_INHERIT, ByteOrder byteOrder = ByteOrder_LE) { return nd4j::graph::CreateFlatArray( _fbb, diff --git a/libnd4j/include/graph/generated/array_generated.js b/libnd4j/include/graph/generated/array_generated.js index 8a2b644e6..b98410a9e 100644 --- a/libnd4j/include/graph/generated/array_generated.js +++ b/libnd4j/include/graph/generated/array_generated.js @@ -23,7 +23,7 @@ nd4j.graph.ByteOrder = { /** * @enum */ -nd4j.graph.DataType = { +nd4j.graph.DType = { INHERIT: 0, BOOL: 1, FLOAT8: 2, @@ -123,11 +123,11 @@ nd4j.graph.FlatArray.prototype.bufferArray = function() { }; /** - * @returns {nd4j.graph.DataType} + * @returns {nd4j.graph.DType} */ nd4j.graph.FlatArray.prototype.dtype = function() { var offset = this.bb.__offset(this.bb_pos, 8); - return offset ? /** @type {nd4j.graph.DataType} */ (this.bb.readInt8(this.bb_pos + offset)) : nd4j.graph.DataType.INHERIT; + return offset ? /** @type {nd4j.graph.DType} */ (this.bb.readInt8(this.bb_pos + offset)) : nd4j.graph.DType.INHERIT; }; /** @@ -205,10 +205,10 @@ nd4j.graph.FlatArray.startBufferVector = function(builder, numElems) { /** * @param {flatbuffers.Builder} builder - * @param {nd4j.graph.DataType} dtype + * @param {nd4j.graph.DType} dtype */ nd4j.graph.FlatArray.addDtype = function(builder, dtype) { - builder.addFieldInt8(2, dtype, nd4j.graph.DataType.INHERIT); + builder.addFieldInt8(2, dtype, nd4j.graph.DType.INHERIT); }; /** diff --git a/libnd4j/include/graph/generated/nd4j/graph/DataType.cs b/libnd4j/include/graph/generated/nd4j/graph/DType.cs similarity index 93% rename from libnd4j/include/graph/generated/nd4j/graph/DataType.cs rename to libnd4j/include/graph/generated/nd4j/graph/DType.cs index 9cd9518c9..00e399b50 100644 --- a/libnd4j/include/graph/generated/nd4j/graph/DataType.cs +++ b/libnd4j/include/graph/generated/nd4j/graph/DType.cs @@ -5,7 +5,7 @@ namespace nd4j.graph { -public enum DataType : sbyte +public enum DType : sbyte { INHERIT = 0, BOOL = 1, diff --git a/libnd4j/include/graph/generated/nd4j/graph/DataType.java b/libnd4j/include/graph/generated/nd4j/graph/DType.java similarity index 95% rename from libnd4j/include/graph/generated/nd4j/graph/DataType.java rename to libnd4j/include/graph/generated/nd4j/graph/DType.java index 369c1b6ae..20d3d475b 100644 --- a/libnd4j/include/graph/generated/nd4j/graph/DataType.java +++ b/libnd4j/include/graph/generated/nd4j/graph/DType.java @@ -2,8 +2,8 @@ package nd4j.graph; -public final class DataType { - private DataType() { } +public final class DType { + private DType() { } public static final byte INHERIT = 0; public static final byte BOOL = 1; public static final byte FLOAT8 = 2; diff --git a/libnd4j/include/graph/generated/nd4j/graph/DataType.py b/libnd4j/include/graph/generated/nd4j/graph/DType.py similarity index 93% rename from libnd4j/include/graph/generated/nd4j/graph/DataType.py rename to libnd4j/include/graph/generated/nd4j/graph/DType.py index e07aace5d..24cadf44e 100644 --- a/libnd4j/include/graph/generated/nd4j/graph/DataType.py +++ b/libnd4j/include/graph/generated/nd4j/graph/DType.py @@ -2,7 +2,7 @@ # namespace: graph -class DataType(object): +class DType(object): INHERIT = 0 BOOL = 1 FLOAT8 = 2 diff --git a/libnd4j/include/graph/generated/nd4j/graph/FlatArray.cs b/libnd4j/include/graph/generated/nd4j/graph/FlatArray.cs index a19325fb7..60d836aeb 100644 --- a/libnd4j/include/graph/generated/nd4j/graph/FlatArray.cs +++ b/libnd4j/include/graph/generated/nd4j/graph/FlatArray.cs @@ -33,13 +33,13 @@ public struct FlatArray : IFlatbufferObject public ArraySegment? GetBufferBytes() { return __p.__vector_as_arraysegment(6); } #endif public sbyte[] GetBufferArray() { return __p.__vector_as_array(6); } - public DataType Dtype { get { int o = __p.__offset(8); return o != 0 ? (DataType)__p.bb.GetSbyte(o + __p.bb_pos) : DataType.INHERIT; } } + public DType Dtype { get { int o = __p.__offset(8); return o != 0 ? (DType)__p.bb.GetSbyte(o + __p.bb_pos) : DType.INHERIT; } } public ByteOrder ByteOrder { get { int o = __p.__offset(10); return o != 0 ? (ByteOrder)__p.bb.GetSbyte(o + __p.bb_pos) : ByteOrder.LE; } } public static Offset CreateFlatArray(FlatBufferBuilder builder, VectorOffset shapeOffset = default(VectorOffset), VectorOffset bufferOffset = default(VectorOffset), - DataType dtype = DataType.INHERIT, + DType dtype = DType.INHERIT, ByteOrder byteOrder = ByteOrder.LE) { builder.StartObject(4); FlatArray.AddBuffer(builder, bufferOffset); @@ -58,7 +58,7 @@ public struct FlatArray : IFlatbufferObject public static VectorOffset CreateBufferVector(FlatBufferBuilder builder, sbyte[] data) { builder.StartVector(1, data.Length, 1); for (int i = data.Length - 1; i >= 0; i--) builder.AddSbyte(data[i]); return builder.EndVector(); } public static VectorOffset CreateBufferVectorBlock(FlatBufferBuilder builder, sbyte[] data) { builder.StartVector(1, data.Length, 1); builder.Add(data); return builder.EndVector(); } public static void StartBufferVector(FlatBufferBuilder builder, int numElems) { builder.StartVector(1, numElems, 1); } - public static void AddDtype(FlatBufferBuilder builder, DataType dtype) { builder.AddSbyte(2, (sbyte)dtype, 0); } + public static void AddDtype(FlatBufferBuilder builder, DType dtype) { builder.AddSbyte(2, (sbyte)dtype, 0); } public static void AddByteOrder(FlatBufferBuilder builder, ByteOrder byteOrder) { builder.AddSbyte(3, (sbyte)byteOrder, 0); } public static Offset EndFlatArray(FlatBufferBuilder builder) { int o = builder.EndObject(); diff --git a/libnd4j/include/graph/generated/nd4j/graph/FlatNode.cs b/libnd4j/include/graph/generated/nd4j/graph/FlatNode.cs index c1068811d..0810d2e6e 100644 --- a/libnd4j/include/graph/generated/nd4j/graph/FlatNode.cs +++ b/libnd4j/include/graph/generated/nd4j/graph/FlatNode.cs @@ -97,14 +97,14 @@ public struct FlatNode : IFlatbufferObject public ArraySegment? GetOpNameBytes() { return __p.__vector_as_arraysegment(36); } #endif public byte[] GetOpNameArray() { return __p.__vector_as_array(36); } - public DataType OutputTypes(int j) { int o = __p.__offset(38); return o != 0 ? (DataType)__p.bb.GetSbyte(__p.__vector(o) + j * 1) : (DataType)0; } + public DType OutputTypes(int j) { int o = __p.__offset(38); return o != 0 ? (DType)__p.bb.GetSbyte(__p.__vector(o) + j * 1) : (DType)0; } public int OutputTypesLength { get { int o = __p.__offset(38); return o != 0 ? __p.__vector_len(o) : 0; } } #if ENABLE_SPAN_T public Span GetOutputTypesBytes() { return __p.__vector_as_span(38); } #else public ArraySegment? GetOutputTypesBytes() { return __p.__vector_as_arraysegment(38); } #endif - public DataType[] GetOutputTypesArray() { return __p.__vector_as_array(38); } + public DType[] GetOutputTypesArray() { return __p.__vector_as_array(38); } public FlatArray? Scalar { get { int o = __p.__offset(40); return o != 0 ? (FlatArray?)(new FlatArray()).__assign(__p.__indirect(o + __p.bb_pos), __p.bb) : null; } } public static Offset CreateFlatNode(FlatBufferBuilder builder, @@ -196,8 +196,8 @@ public struct FlatNode : IFlatbufferObject public static void StartOutputNamesVector(FlatBufferBuilder builder, int numElems) { builder.StartVector(4, numElems, 4); } public static void AddOpName(FlatBufferBuilder builder, StringOffset opNameOffset) { builder.AddOffset(16, opNameOffset.Value, 0); } public static void AddOutputTypes(FlatBufferBuilder builder, VectorOffset outputTypesOffset) { builder.AddOffset(17, outputTypesOffset.Value, 0); } - public static VectorOffset CreateOutputTypesVector(FlatBufferBuilder builder, DataType[] data) { builder.StartVector(1, data.Length, 1); for (int i = data.Length - 1; i >= 0; i--) builder.AddSbyte((sbyte)data[i]); return builder.EndVector(); } - public static VectorOffset CreateOutputTypesVectorBlock(FlatBufferBuilder builder, DataType[] data) { builder.StartVector(1, data.Length, 1); builder.Add(data); return builder.EndVector(); } + public static VectorOffset CreateOutputTypesVector(FlatBufferBuilder builder, DType[] data) { builder.StartVector(1, data.Length, 1); for (int i = data.Length - 1; i >= 0; i--) builder.AddSbyte((sbyte)data[i]); return builder.EndVector(); } + public static VectorOffset CreateOutputTypesVectorBlock(FlatBufferBuilder builder, DType[] data) { builder.StartVector(1, data.Length, 1); builder.Add(data); return builder.EndVector(); } public static void StartOutputTypesVector(FlatBufferBuilder builder, int numElems) { builder.StartVector(1, numElems, 1); } public static void AddScalar(FlatBufferBuilder builder, Offset scalarOffset) { builder.AddOffset(18, scalarOffset.Value, 0); } public static Offset EndFlatNode(FlatBufferBuilder builder) { diff --git a/libnd4j/include/graph/generated/nd4j/graph/FlatVariable.cs b/libnd4j/include/graph/generated/nd4j/graph/FlatVariable.cs index d5f8014f2..9764668a0 100644 --- a/libnd4j/include/graph/generated/nd4j/graph/FlatVariable.cs +++ b/libnd4j/include/graph/generated/nd4j/graph/FlatVariable.cs @@ -25,7 +25,7 @@ public struct FlatVariable : IFlatbufferObject public ArraySegment? GetNameBytes() { return __p.__vector_as_arraysegment(6); } #endif public byte[] GetNameArray() { return __p.__vector_as_array(6); } - public DataType Dtype { get { int o = __p.__offset(8); return o != 0 ? (DataType)__p.bb.GetSbyte(o + __p.bb_pos) : DataType.INHERIT; } } + public DType Dtype { get { int o = __p.__offset(8); return o != 0 ? (DType)__p.bb.GetSbyte(o + __p.bb_pos) : DType.INHERIT; } } public long Shape(int j) { int o = __p.__offset(10); return o != 0 ? __p.bb.GetLong(__p.__vector(o) + j * 8) : (long)0; } public int ShapeLength { get { int o = __p.__offset(10); return o != 0 ? __p.__vector_len(o) : 0; } } #if ENABLE_SPAN_T @@ -41,7 +41,7 @@ public struct FlatVariable : IFlatbufferObject public static Offset CreateFlatVariable(FlatBufferBuilder builder, Offset idOffset = default(Offset), StringOffset nameOffset = default(StringOffset), - DataType dtype = DataType.INHERIT, + DType dtype = DType.INHERIT, VectorOffset shapeOffset = default(VectorOffset), Offset ndarrayOffset = default(Offset), int device = 0, @@ -60,7 +60,7 @@ public struct FlatVariable : IFlatbufferObject public static void StartFlatVariable(FlatBufferBuilder builder) { builder.StartObject(7); } public static void AddId(FlatBufferBuilder builder, Offset idOffset) { builder.AddOffset(0, idOffset.Value, 0); } public static void AddName(FlatBufferBuilder builder, StringOffset nameOffset) { builder.AddOffset(1, nameOffset.Value, 0); } - public static void AddDtype(FlatBufferBuilder builder, DataType dtype) { builder.AddSbyte(2, (sbyte)dtype, 0); } + public static void AddDtype(FlatBufferBuilder builder, DType dtype) { builder.AddSbyte(2, (sbyte)dtype, 0); } public static void AddShape(FlatBufferBuilder builder, VectorOffset shapeOffset) { builder.AddOffset(3, shapeOffset.Value, 0); } public static VectorOffset CreateShapeVector(FlatBufferBuilder builder, long[] data) { builder.StartVector(8, data.Length, 8); for (int i = data.Length - 1; i >= 0; i--) builder.AddLong(data[i]); return builder.EndVector(); } public static VectorOffset CreateShapeVectorBlock(FlatBufferBuilder builder, long[] data) { builder.StartVector(8, data.Length, 8); builder.Add(data); return builder.EndVector(); } diff --git a/libnd4j/include/graph/generated/node_generated.js b/libnd4j/include/graph/generated/node_generated.js index a7b2e264f..bd2274dad 100644 --- a/libnd4j/include/graph/generated/node_generated.js +++ b/libnd4j/include/graph/generated/node_generated.js @@ -312,11 +312,11 @@ nd4j.graph.FlatNode.prototype.opName = function(optionalEncoding) { /** * @param {number} index - * @returns {nd4j.graph.DataType} + * @returns {nd4j.graph.DType} */ nd4j.graph.FlatNode.prototype.outputTypes = function(index) { var offset = this.bb.__offset(this.bb_pos, 38); - return offset ? /** @type {nd4j.graph.DataType} */ (this.bb.readInt8(this.bb.__vector(this.bb_pos + offset) + index)) : /** @type {nd4j.graph.DataType} */ (0); + return offset ? /** @type {nd4j.graph.DType} */ (this.bb.readInt8(this.bb.__vector(this.bb_pos + offset) + index)) : /** @type {nd4j.graph.DType} */ (0); }; /** @@ -686,7 +686,7 @@ nd4j.graph.FlatNode.addOutputTypes = function(builder, outputTypesOffset) { /** * @param {flatbuffers.Builder} builder - * @param {Array.} data + * @param {Array.} data * @returns {flatbuffers.Offset} */ nd4j.graph.FlatNode.createOutputTypesVector = function(builder, data) { diff --git a/libnd4j/include/graph/generated/variable_generated.h b/libnd4j/include/graph/generated/variable_generated.h index e441c17dc..ca1a705a0 100644 --- a/libnd4j/include/graph/generated/variable_generated.h +++ b/libnd4j/include/graph/generated/variable_generated.h @@ -65,8 +65,8 @@ struct FlatVariable FLATBUFFERS_FINAL_CLASS : private flatbuffers::Table { const flatbuffers::String *name() const { return GetPointer(VT_NAME); } - DataType dtype() const { - return static_cast(GetField(VT_DTYPE, 0)); + DType dtype() const { + return static_cast(GetField(VT_DTYPE, 0)); } const flatbuffers::Vector *shape() const { return GetPointer *>(VT_SHAPE); @@ -106,7 +106,7 @@ struct FlatVariableBuilder { void add_name(flatbuffers::Offset name) { fbb_.AddOffset(FlatVariable::VT_NAME, name); } - void add_dtype(DataType dtype) { + void add_dtype(DType dtype) { fbb_.AddElement(FlatVariable::VT_DTYPE, static_cast(dtype), 0); } void add_shape(flatbuffers::Offset> shape) { @@ -137,7 +137,7 @@ inline flatbuffers::Offset CreateFlatVariable( flatbuffers::FlatBufferBuilder &_fbb, flatbuffers::Offset id = 0, flatbuffers::Offset name = 0, - DataType dtype = DataType_INHERIT, + DType dtype = DType_INHERIT, flatbuffers::Offset> shape = 0, flatbuffers::Offset ndarray = 0, int32_t device = 0, @@ -157,7 +157,7 @@ inline flatbuffers::Offset CreateFlatVariableDirect( flatbuffers::FlatBufferBuilder &_fbb, flatbuffers::Offset id = 0, const char *name = nullptr, - DataType dtype = DataType_INHERIT, + DType dtype = DType_INHERIT, const std::vector *shape = nullptr, flatbuffers::Offset ndarray = 0, int32_t device = 0, diff --git a/libnd4j/include/graph/generated/variable_generated.js b/libnd4j/include/graph/generated/variable_generated.js index 3f128e4fc..9012af2de 100644 --- a/libnd4j/include/graph/generated/variable_generated.js +++ b/libnd4j/include/graph/generated/variable_generated.js @@ -76,11 +76,11 @@ nd4j.graph.FlatVariable.prototype.name = function(optionalEncoding) { }; /** - * @returns {nd4j.graph.DataType} + * @returns {nd4j.graph.DType} */ nd4j.graph.FlatVariable.prototype.dtype = function() { var offset = this.bb.__offset(this.bb_pos, 8); - return offset ? /** @type {nd4j.graph.DataType} */ (this.bb.readInt8(this.bb_pos + offset)) : nd4j.graph.DataType.INHERIT; + return offset ? /** @type {nd4j.graph.DType} */ (this.bb.readInt8(this.bb_pos + offset)) : nd4j.graph.DType.INHERIT; }; /** @@ -150,10 +150,10 @@ nd4j.graph.FlatVariable.addName = function(builder, nameOffset) { /** * @param {flatbuffers.Builder} builder - * @param {nd4j.graph.DataType} dtype + * @param {nd4j.graph.DType} dtype */ nd4j.graph.FlatVariable.addDtype = function(builder, dtype) { - builder.addFieldInt8(2, dtype, nd4j.graph.DataType.INHERIT); + builder.addFieldInt8(2, dtype, nd4j.graph.DType.INHERIT); }; /** diff --git a/libnd4j/include/graph/impl/FlatUtils.cpp b/libnd4j/include/graph/impl/FlatUtils.cpp index bc8ff7e33..ec76cb4d2 100644 --- a/libnd4j/include/graph/impl/FlatUtils.cpp +++ b/libnd4j/include/graph/impl/FlatUtils.cpp @@ -111,7 +111,7 @@ namespace nd4j { auto bo = static_cast(BitwiseUtils::asByteOrder()); - return CreateFlatArray(builder, fShape, fBuffer, static_cast(array.dataType()), bo); + return CreateFlatArray(builder, fShape, fBuffer, static_cast(array.dataType()), bo); } } } \ No newline at end of file diff --git a/libnd4j/include/graph/impl/Variable.cpp b/libnd4j/include/graph/impl/Variable.cpp index 6dd881f11..e54112783 100644 --- a/libnd4j/include/graph/impl/Variable.cpp +++ b/libnd4j/include/graph/impl/Variable.cpp @@ -219,7 +219,7 @@ namespace nd4j { throw std::runtime_error("CONSTANT variable must have NDArray bundled"); auto ar = flatVariable->ndarray(); - if (ar->dtype() == DataType_UTF8) { + if (ar->dtype() == DType_UTF8) { _ndarray = nd4j::graph::FlatUtils::fromFlatArray(ar); } else { _ndarray = nd4j::graph::FlatUtils::fromFlatArray(ar); @@ -320,7 +320,7 @@ namespace nd4j { auto fBuffer = builder.CreateVector(array->asByteVector()); // packing array - auto fArray = CreateFlatArray(builder, fShape, fBuffer, (nd4j::graph::DataType) array->dataType()); + auto fArray = CreateFlatArray(builder, fShape, fBuffer, (nd4j::graph::DType) array->dataType()); // packing id/index of this var auto fVid = CreateIntPair(builder, this->_id, this->_index); @@ -331,7 +331,7 @@ namespace nd4j { stringId = builder.CreateString(this->_name); // returning array - return CreateFlatVariable(builder, fVid, stringId, static_cast(array->dataType()), 0, fArray); + return CreateFlatVariable(builder, fVid, stringId, static_cast(array->dataType()), 0, fArray); } else { throw std::runtime_error("Variable::asFlatVariable isn't possible for NDArrayList"); } diff --git a/libnd4j/include/graph/scheme/array.fbs b/libnd4j/include/graph/scheme/array.fbs index f415ffb08..91e338500 100644 --- a/libnd4j/include/graph/scheme/array.fbs +++ b/libnd4j/include/graph/scheme/array.fbs @@ -23,7 +23,7 @@ enum ByteOrder:byte { } // DataType for arrays/buffers -enum DataType:byte { +enum DType:byte { INHERIT, BOOL, FLOAT8, @@ -49,7 +49,7 @@ enum DataType:byte { table FlatArray { shape:[long]; // shape in Nd4j format buffer:[byte]; // byte buffer with data - dtype:DataType; // data type of actual data within buffer + dtype:DType; // data type of actual data within buffer byteOrder:ByteOrder; // byte order of buffer } diff --git a/libnd4j/include/graph/scheme/node.fbs b/libnd4j/include/graph/scheme/node.fbs index 6117e7125..930702f6d 100644 --- a/libnd4j/include/graph/scheme/node.fbs +++ b/libnd4j/include/graph/scheme/node.fbs @@ -48,7 +48,7 @@ table FlatNode { opName:string; //Used to help resolving the class. In a few cases, multiple classes/opNames are mapped to same hash, and might have different config/properties/differentiability // output data types (optional) - outputTypes:[DataType]; + outputTypes:[DType]; //Scalar value - used for scalar ops. Should be single value only. scalar:FlatArray; diff --git a/libnd4j/include/graph/scheme/uigraphstatic.fbs b/libnd4j/include/graph/scheme/uigraphstatic.fbs index cce0da4ad..814c28fa5 100644 --- a/libnd4j/include/graph/scheme/uigraphstatic.fbs +++ b/libnd4j/include/graph/scheme/uigraphstatic.fbs @@ -51,7 +51,7 @@ table UIVariable { id:IntPair; //Existing IntPair class name:string; type:VarType; //Use existing VarType: VARIABLE, CONSTANT, ARRAY, PLACEHOLDER - datatype:DataType; + datatype:DType; shape:[long]; controlDeps:[string]; //Input control dependencies: variable x -> this outputOfOp:string; //Null for placeholders/constants. For array type SDVariables, the name of the op it's an output of diff --git a/libnd4j/include/graph/scheme/variable.fbs b/libnd4j/include/graph/scheme/variable.fbs index 43f343c7c..31eafafa7 100644 --- a/libnd4j/include/graph/scheme/variable.fbs +++ b/libnd4j/include/graph/scheme/variable.fbs @@ -30,7 +30,7 @@ enum VarType:byte { table FlatVariable { id:IntPair; // ID of the Variable, in format of IntPair.first is node Id, IntPair.second is output index of the node name:string; // symbolic ID of the Variable (if defined) - dtype:DataType; + dtype:DType; shape:[long]; // shape is absolutely optional. either shape or ndarray might be set ndarray:FlatArray; diff --git a/libnd4j/tests_cpu/layers_tests/FlatBuffersTests.cpp b/libnd4j/tests_cpu/layers_tests/FlatBuffersTests.cpp index cf9f2914e..49dd0657d 100644 --- a/libnd4j/tests_cpu/layers_tests/FlatBuffersTests.cpp +++ b/libnd4j/tests_cpu/layers_tests/FlatBuffersTests.cpp @@ -94,10 +94,10 @@ TEST_F(FlatBuffersTest, FlatGraphTest1) { auto fShape = builder.CreateVector(array->getShapeInfoAsFlatVector()); auto fBuffer = builder.CreateVector(array->asByteVector()); - auto fArray = CreateFlatArray(builder, fShape, fBuffer, nd4j::graph::DataType::DataType_FLOAT); + auto fArray = CreateFlatArray(builder, fShape, fBuffer, nd4j::graph::DType::DType_FLOAT); auto fVid = CreateIntPair(builder, -1); - auto fVar = CreateFlatVariable(builder, fVid, 0, nd4j::graph::DataType::DataType_FLOAT, 0, fArray); + auto fVar = CreateFlatVariable(builder, fVid, 0, nd4j::graph::DType::DType_FLOAT, 0, fArray); std::vector outputs1, outputs2, inputs1, inputs2; outputs1.push_back(2); @@ -265,7 +265,7 @@ TEST_F(FlatBuffersTest, ExplicitOutputTest1) { auto name1 = builder.CreateString("wow1"); - auto node1 = CreateFlatNode(builder, 1, name1, OpType_TRANSFORM, 0, in1, 0, nd4j::graph::DataType::FLOAT); + auto node1 = CreateFlatNode(builder, 1, name1, OpType_TRANSFORM, 0, in1, 0, nd4j::graph::DType::FLOAT); std::vector> variables_vector; variables_vector.push_back(fXVar); diff --git a/libnd4j/tests_cpu/layers_tests/VariableTests.cpp b/libnd4j/tests_cpu/layers_tests/VariableTests.cpp index e31347b0e..fcdd1db3c 100644 --- a/libnd4j/tests_cpu/layers_tests/VariableTests.cpp +++ b/libnd4j/tests_cpu/layers_tests/VariableTests.cpp @@ -73,9 +73,9 @@ TEST_F(VariableTests, Test_FlatVariableDataType_1) { auto fBuffer = builder.CreateVector(vec); auto fVid = CreateIntPair(builder, 1, 12); - auto fArray = CreateFlatArray(builder, fShape, fBuffer, nd4j::graph::DataType::DataType_FLOAT); + auto fArray = CreateFlatArray(builder, fShape, fBuffer, nd4j::graph::DType::DType_FLOAT); - auto flatVar = CreateFlatVariable(builder, fVid, 0, nd4j::graph::DataType::DataType_FLOAT, 0, fArray); + auto flatVar = CreateFlatVariable(builder, fVid, 0, nd4j::graph::DType::DType_FLOAT, 0, fArray); builder.Finish(flatVar); @@ -107,9 +107,9 @@ TEST_F(VariableTests, Test_FlatVariableDataType_2) { auto fBuffer = builder.CreateVector(vec); auto fVid = CreateIntPair(builder, 1, 12); - auto fArray = CreateFlatArray(builder, fShape, fBuffer, nd4j::graph::DataType::DataType_DOUBLE); + auto fArray = CreateFlatArray(builder, fShape, fBuffer, nd4j::graph::DType::DType_DOUBLE); - auto flatVar = CreateFlatVariable(builder, fVid, 0, nd4j::graph::DataType::DataType_DOUBLE, 0, fArray); + auto flatVar = CreateFlatVariable(builder, fVid, 0, nd4j::graph::DType::DType_DOUBLE, 0, fArray); builder.Finish(flatVar); @@ -144,9 +144,9 @@ TEST_F(VariableTests, Test_FlatVariableDataType_3) { auto fBuffer = builder.CreateVector(vec); auto fVid = CreateIntPair(builder, 1, 12); - auto fArray = CreateFlatArray(builder, fShape, fBuffer, nd4j::graph::DataType::DataType_DOUBLE); + auto fArray = CreateFlatArray(builder, fShape, fBuffer, nd4j::graph::DType::DType_DOUBLE); - auto flatVar = CreateFlatVariable(builder, fVid, 0, nd4j::graph::DataType::DataType_DOUBLE, 0, fArray); + auto flatVar = CreateFlatVariable(builder, fVid, 0, nd4j::graph::DType::DType_DOUBLE, 0, fArray); builder.Finish(flatVar); @@ -180,7 +180,7 @@ TEST_F(VariableTests, Test_FlatVariableDataType_4) { auto fShape = builder.CreateVector(original.getShapeAsFlatVector()); auto fVid = CreateIntPair(builder, 37, 12); - auto flatVar = CreateFlatVariable(builder, fVid, 0, nd4j::graph::DataType::DataType_FLOAT, fShape, 0, 0, VarType_PLACEHOLDER); + auto flatVar = CreateFlatVariable(builder, fVid, 0, nd4j::graph::DType::DType_FLOAT, fShape, 0, 0, VarType_PLACEHOLDER); builder.Finish(flatVar); diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/autodiff/samediff/serde/FlatBuffersMapper.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/autodiff/samediff/serde/FlatBuffersMapper.java index 6faf29bfc..cce38cf24 100644 --- a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/autodiff/samediff/serde/FlatBuffersMapper.java +++ b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/autodiff/samediff/serde/FlatBuffersMapper.java @@ -31,7 +31,7 @@ import org.nd4j.autodiff.samediff.SameDiff; import org.nd4j.autodiff.samediff.VariableType; import org.nd4j.autodiff.samediff.internal.Variable; import org.nd4j.base.Preconditions; -import org.nd4j.graph.DataType; +import org.nd4j.graph.DType; import org.nd4j.graph.FlatArray; import org.nd4j.graph.FlatNode; import org.nd4j.graph.FlatProperties; @@ -66,33 +66,33 @@ public class FlatBuffersMapper { public static byte getDataTypeAsByte(@NonNull org.nd4j.linalg.api.buffer.DataType type) { switch (type) { case FLOAT: - return DataType.FLOAT; + return DType.FLOAT; case DOUBLE: - return DataType.DOUBLE; + return DType.DOUBLE; case HALF: - return DataType.HALF; + return DType.HALF; case INT: - return DataType.INT32; + return DType.INT32; case LONG: - return DataType.INT64; + return DType.INT64; case BOOL: - return DataType.BOOL; + return DType.BOOL; case SHORT: - return DataType.INT16; + return DType.INT16; case BYTE: - return DataType.INT8; + return DType.INT8; case UBYTE: - return DataType.UINT8; + return DType.UINT8; case UTF8: - return DataType.UTF8; + return DType.UTF8; case UINT16: - return DataType.UINT16; + return DType.UINT16; case UINT32: - return DataType.UINT32; + return DType.UINT32; case UINT64: - return DataType.UINT64; + return DType.UINT64; case BFLOAT16: - return DataType.BFLOAT16; + return DType.BFLOAT16; default: throw new ND4JIllegalStateException("Unknown or unsupported DataType used: [" + type + "]"); } @@ -102,33 +102,33 @@ public class FlatBuffersMapper { * This method converts enums for DataType */ public static org.nd4j.linalg.api.buffer.DataType getDataTypeFromByte(byte val) { - if (val == DataType.FLOAT) { + if (val == DType.FLOAT) { return org.nd4j.linalg.api.buffer.DataType.FLOAT; - } else if (val == DataType.DOUBLE) { + } else if (val == DType.DOUBLE) { return org.nd4j.linalg.api.buffer.DataType.DOUBLE; - } else if (val == DataType.HALF) { + } else if (val == DType.HALF) { return org.nd4j.linalg.api.buffer.DataType.HALF; - } else if (val == DataType.INT32) { + } else if (val == DType.INT32) { return org.nd4j.linalg.api.buffer.DataType.INT; - } else if (val == DataType.INT64) { + } else if (val == DType.INT64) { return org.nd4j.linalg.api.buffer.DataType.LONG; - } else if (val == DataType.INT8) { + } else if (val == DType.INT8) { return org.nd4j.linalg.api.buffer.DataType.BYTE; - } else if (val == DataType.BOOL) { + } else if (val == DType.BOOL) { return org.nd4j.linalg.api.buffer.DataType.BOOL; - } else if (val == DataType.UINT8) { + } else if (val == DType.UINT8) { return org.nd4j.linalg.api.buffer.DataType.UBYTE; - } else if (val == DataType.INT16) { + } else if (val == DType.INT16) { return org.nd4j.linalg.api.buffer.DataType.SHORT; - } else if (val == DataType.UTF8) { + } else if (val == DType.UTF8) { return org.nd4j.linalg.api.buffer.DataType.UTF8; - } else if (val == DataType.UINT16) { + } else if (val == DType.UINT16) { return org.nd4j.linalg.api.buffer.DataType.UINT16; - } else if (val == DataType.UINT32) { + } else if (val == DType.UINT32) { return org.nd4j.linalg.api.buffer.DataType.UINT32; - } else if (val == DataType.UINT64) { + } else if (val == DType.UINT64) { return org.nd4j.linalg.api.buffer.DataType.UINT64; - } else if (val == DataType.BFLOAT16){ + } else if (val == DType.BFLOAT16){ return org.nd4j.linalg.api.buffer.DataType.BFLOAT16; } else { throw new RuntimeException("Unknown datatype: " + val); diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/graph/DataType.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/graph/DType.java similarity index 95% rename from nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/graph/DataType.java rename to nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/graph/DType.java index 17a0752f0..2617ce8f6 100644 --- a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/graph/DataType.java +++ b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/graph/DType.java @@ -2,8 +2,8 @@ package org.nd4j.graph; -public final class DataType { - private DataType() { } +public final class DType { + private DType() { } public static final byte INHERIT = 0; public static final byte BOOL = 1; public static final byte FLOAT8 = 2;