From 82e3275cbf79b00ff526eddff5af6458a221ddfd Mon Sep 17 00:00:00 2001 From: Ross Halliday Date: Fri, 23 May 2025 11:37:22 +0100 Subject: [PATCH 1/3] Updated Tensorflow.Net to 0.70.2 with Tensorflow 2.7.0. NumSharp replaced with Tensorflow.NumPy. TensorShape replaced with Shape, Shape object has dimensions as 64 bit long, check added for casting to 32 bit int also. Tensor constructor using SafeTensorHandle/DangerousGetHandle and TF_DataType not required when casting. Added StringTensorFactory to wrap additional tensorflow.dll methods required to create Tensors from string based input. --- eng/Versions.props | 4 +- src/Microsoft.ML.DataView/VectorType.cs | 39 +++++ .../TensorTypeExtensions.cs | 4 +- .../TensorflowTransform.cs | 62 ++++---- .../TensorflowUtils.cs | 75 +++++---- .../DnnRetrainTransform.cs | 73 +++++---- .../ImageClassificationTrainer.cs | 146 ++++++++++++++---- .../TensorflowTests.cs | 9 +- 8 files changed, 275 insertions(+), 137 deletions(-) diff --git a/eng/Versions.props b/eng/Versions.props index f7f879a844..6071d3f53f 100644 --- a/eng/Versions.props +++ b/eng/Versions.props @@ -67,9 +67,9 @@ 2.1.3 0.11.1 1.4.2 - 0.20.1 + 0.70.2 2 - 2.3.1 + 2.7.0 1.4.1 0.2.3 1.48.0 diff --git a/src/Microsoft.ML.DataView/VectorType.cs b/src/Microsoft.ML.DataView/VectorType.cs index 574a473f1e..3ea42e2545 100644 --- a/src/Microsoft.ML.DataView/VectorType.cs +++ b/src/Microsoft.ML.DataView/VectorType.cs @@ -67,6 +67,24 @@ public VectorDataViewType(PrimitiveDataViewType itemType, params int[] dimension Size = ComputeSize(Dimensions); } + /// + /// Constructs a potentially multi-dimensional vector type. + /// + /// The type of the items contained in the vector. + /// The dimensions. Note that, like , must be non-empty, with all + /// non-negative values. Also, because is the product of , the result of + /// multiplying all these values together must not overflow . 
+ public VectorDataViewType(PrimitiveDataViewType itemType, params long[] dimensions) + : base(GetRawType(itemType)) + { + Contracts.CheckParam(ArrayUtils.Size(dimensions) > 0, nameof(dimensions)); + Contracts.CheckParam(dimensions.All(d => d >= 0), nameof(dimensions)); + + ItemType = itemType; + Dimensions = CastLongArrayToIntArray(dimensions).ToImmutableArray(); + Size = ComputeSize(Dimensions); + } + /// /// Constructs a potentially multi-dimensional vector type. /// @@ -99,6 +117,27 @@ private static int ComputeSize(ImmutableArray dims) return size; } + private static int[] CastLongArrayToIntArray(long[] source) + { + if (source == null) + throw new ArgumentNullException(nameof(source)); + + int[] result = new int[source.Length]; + + for (int i = 0; i < source.Length; i++) + { + long value = source[i]; + if (value > int.MaxValue || value < int.MinValue) + { + throw new OverflowException($"Value at index {i} ({value}) cannot be safely cast from long to int."); + } + + result[i] = (int)value; + } + + return result; + } + /// /// Whether this is a vector type with known size. /// Equivalent to > 0. 
diff --git a/src/Microsoft.ML.TensorFlow/TensorTypeExtensions.cs b/src/Microsoft.ML.TensorFlow/TensorTypeExtensions.cs index 330c398133..5742a5f8cb 100644 --- a/src/Microsoft.ML.TensorFlow/TensorTypeExtensions.cs +++ b/src/Microsoft.ML.TensorFlow/TensorTypeExtensions.cs @@ -25,7 +25,7 @@ public static void ToScalar(this Tensor tensor, ref T dst) where T : unmanage return; } - if (typeof(T).as_dtype() != tensor.dtype) + if (typeof(T).as_tf_dtype() != tensor.dtype) throw new NotSupportedException(); unsafe @@ -37,7 +37,7 @@ public static void ToScalar(this Tensor tensor, ref T dst) where T : unmanage public static void CopyTo(this Tensor tensor, Span values) where T : unmanaged { - if (typeof(T).as_dtype() != tensor.dtype) + if (typeof(T).as_tf_dtype() != tensor.dtype) throw new NotSupportedException(); unsafe diff --git a/src/Microsoft.ML.TensorFlow/TensorflowTransform.cs b/src/Microsoft.ML.TensorFlow/TensorflowTransform.cs index fd556a175f..ae22454ae9 100644 --- a/src/Microsoft.ML.TensorFlow/TensorflowTransform.cs +++ b/src/Microsoft.ML.TensorFlow/TensorflowTransform.cs @@ -16,8 +16,8 @@ using Microsoft.ML.Runtime; using Microsoft.ML.TensorFlow; using Microsoft.ML.Transforms; -using NumSharp; using Tensorflow; +using Tensorflow.NumPy; using static Microsoft.ML.TensorFlow.TensorFlowUtils; using static Tensorflow.Binding; using Utils = Microsoft.ML.Internal.Utilities.Utils; @@ -51,7 +51,7 @@ public sealed class TensorFlowTransformer : RowToRowTransformerBase, IDisposable internal readonly DataViewType[] OutputTypes; internal readonly TF_DataType[] TFOutputTypes; internal readonly TF_DataType[] TFInputTypes; - internal readonly TensorShape[] TFInputShapes; + internal readonly Shape[] TFInputShapes; internal readonly (Operation, int)[] TFInputOperations; internal readonly (Operation, int)[] TFOutputOperations; internal TF_Output[] TFInputNodes; @@ -212,14 +212,14 @@ internal TensorFlowTransformer(IHostEnvironment env, TensorFlowEstimator.Options env.CheckValue(options, 
nameof(options)); } - private static ITensorValueGetter CreateTensorValueGetter(DataViewRow input, bool isVector, int colIndex, TensorShape tfShape) + private static ITensorValueGetter CreateTensorValueGetter(DataViewRow input, bool isVector, int colIndex, Shape tfShape) { if (isVector) return new TensorValueGetterVec(input, colIndex, tfShape); return new TensorValueGetter(input, colIndex, tfShape); } - private static ITensorValueGetter CreateTensorValueGetter(DataViewRow input, TF_DataType tfType, bool isVector, int colIndex, TensorShape tfShape) + private static ITensorValueGetter CreateTensorValueGetter(DataViewRow input, TF_DataType tfType, bool isVector, int colIndex, Shape tfShape) { var type = Tf2MlNetType(tfType); return Utils.MarshalInvoke(CreateTensorValueGetter, type.RawType, input, isVector, colIndex, tfShape); @@ -230,7 +230,7 @@ private static ITensorValueGetter[] GetTensorValueGetters( int[] inputColIndices, bool[] isInputVector, TF_DataType[] tfInputTypes, - TensorShape[] tfInputShapes) + Shape[] tfInputShapes) { var srcTensorGetters = new ITensorValueGetter[inputColIndices.Length]; for (int i = 0; i < inputColIndices.Length; i++) @@ -331,10 +331,10 @@ private static (Operation, int) GetOperationFromName(string operation, Session s return (session.graph.OperationByName(operation), 0); } - internal static (TF_DataType[] tfInputTypes, TensorShape[] tfInputShapes, (Operation, int)[]) GetInputInfo(IHost host, Session session, string[] inputs, int batchSize = 1) + internal static (TF_DataType[] tfInputTypes, Shape[] tfInputShapes, (Operation, int)[]) GetInputInfo(IHost host, Session session, string[] inputs, int batchSize = 1) { var tfInputTypes = new TF_DataType[inputs.Length]; - var tfInputShapes = new TensorShape[inputs.Length]; + var tfInputShapes = new Shape[inputs.Length]; var tfInputOperations = new (Operation, int)[inputs.Length]; int index = 0; @@ -351,7 +351,7 @@ internal static (TF_DataType[] tfInputTypes, TensorShape[] tfInputShapes, (Opera 
throw host.ExceptParam(nameof(session), $"Input type '{tfInputType}' of input column '{input}' is not supported in TensorFlow"); tfInputTypes[index] = tfInputType; - tfInputShapes[index] = ((Tensor)inputTensor).TensorShape; + tfInputShapes[index] = ((Tensor)inputTensor).shape; tfInputOperations[index] = (inputTensor, inputTensorIndex); index++; } @@ -359,7 +359,7 @@ internal static (TF_DataType[] tfInputTypes, TensorShape[] tfInputShapes, (Opera return (tfInputTypes, tfInputShapes, tfInputOperations); } - internal static TensorShape GetTensorShape(TF_Output output, Graph graph, Status status = null) + internal static Shape GetTensorShape(TF_Output output, Graph graph, Status status = null) { if (graph == IntPtr.Zero) throw new ObjectDisposedException(nameof(graph)); @@ -370,12 +370,12 @@ internal static TensorShape GetTensorShape(TF_Output output, Graph graph, Status cstatus.Check(); if (n == -1) - return new TensorShape(new int[0]); + return new Shape(new int[0]); var dims = new long[n]; c_api.TF_GraphGetTensorShape(graph, output, dims, dims.Length, cstatus.Handle); cstatus.Check(); - return new TensorShape(dims.Select(x => (int)x).ToArray()); + return new Shape(dims.Select(x => (int)x).ToArray()); } internal static (TF_DataType[] tfOutputTypes, DataViewType[] outputTypes, (Operation, int)[]) GetOutputInfo(IHost host, Session session, string[] outputs, bool treatOutputAsBatched) @@ -404,10 +404,10 @@ internal static (TF_DataType[] tfOutputTypes, DataViewType[] outputTypes, (Opera // This is the work around in absence of reshape transformer. var idims = shape.dims; - int[] dims = idims; + long[] dims = idims; if (treatOutputAsBatched) { - dims = shape.ndim > 0 ? idims.Skip(idims[0] == -1 ? 1 : 0).ToArray() : new int[0]; + dims = shape.ndim > 0 ? idims.Skip(idims[0] == -1 ? 1 : 0).ToArray() : new long[0]; } for (int j = 0; j < dims.Length; j++) dims[j] = dims[j] == -1 ? 
0 : dims[j]; @@ -517,7 +517,7 @@ public void Dispose() if (Session != null && Session != IntPtr.Zero) { - Session.close(); // invoked Dispose() + Session.Dispose(); } } finally @@ -536,7 +536,7 @@ private sealed class Mapper : MapperBase private readonly TensorFlowTransformer _parent; private readonly int[] _inputColIndices; private readonly bool[] _isInputVector; - private readonly TensorShape[] _fullySpecifiedShapes; + private readonly Shape[] _fullySpecifiedShapes; private readonly ConcurrentBag _runners; public Mapper(TensorFlowTransformer parent, DataViewSchema inputSchema) : @@ -546,7 +546,7 @@ public Mapper(TensorFlowTransformer parent, DataViewSchema inputSchema) : _parent = parent; _inputColIndices = new int[_parent.Inputs.Length]; _isInputVector = new bool[_parent.Inputs.Length]; - _fullySpecifiedShapes = new TensorShape[_parent.Inputs.Length]; + _fullySpecifiedShapes = new Shape[_parent.Inputs.Length]; for (int i = 0; i < _parent.Inputs.Length; i++) { if (!inputSchema.TryGetColumnIndex(_parent.Inputs[i], out _inputColIndices[i])) @@ -570,11 +570,11 @@ public Mapper(TensorFlowTransformer parent, DataViewSchema inputSchema) : { vecType = (VectorDataViewType)type; var colTypeDims = vecType.Dimensions.Select(dim => (int)dim).ToArray(); - _fullySpecifiedShapes[i] = new TensorShape(colTypeDims); + _fullySpecifiedShapes[i] = new Shape(colTypeDims); } else // for primitive type use default TensorShape - _fullySpecifiedShapes[i] = new TensorShape(); + _fullySpecifiedShapes[i] = new Shape(Array.Empty()); } else { @@ -582,7 +582,7 @@ public Mapper(TensorFlowTransformer parent, DataViewSchema inputSchema) : var colTypeDims = vecType.Dimensions.Select(dim => (int)dim).ToArray(); // If the column is one dimension we make sure that the total size of the TF shape matches. // Compute the total size of the known dimensions of the shape. 
- int valCount = 1; + long valCount = 1; int numOfUnkDim = 0; foreach (var s in shape) { @@ -592,7 +592,7 @@ public Mapper(TensorFlowTransformer parent, DataViewSchema inputSchema) : numOfUnkDim++; } // The column length should be divisible by this, so that the other dimensions can be integral. - int typeValueCount = type.GetValueCount(); + long typeValueCount = type.GetValueCount(); if (typeValueCount % valCount != 0) throw Contracts.Except($"Input shape mismatch: Input '{_parent.Inputs[i]}' has shape {originalShape.ToString()}, but input data is of length {typeValueCount}."); @@ -616,10 +616,10 @@ public Mapper(TensorFlowTransformer parent, DataViewSchema inputSchema) : throw Contracts.Except($"Input shape mismatch: Input '{_parent.Inputs[i]}' has shape {originalShape.ToString()}, but input data is of length {typeValueCount}."); // Fill in the unknown dimensions. - var l = new int[originalShapeNdim]; + var l = new long[originalShapeNdim]; for (int ishape = 0; ishape < originalShapeNdim; ishape++) - l[ishape] = originalShapeDims[ishape] == -1 ? (int)d : originalShapeDims[ishape]; - _fullySpecifiedShapes[i] = new TensorShape(l); + l[ishape] = originalShapeDims[ishape] == -1 ? (long)d : originalShapeDims[ishape]; + _fullySpecifiedShapes[i] = new Shape(l); } if (_parent._addBatchDimensionInput) @@ -627,11 +627,11 @@ public Mapper(TensorFlowTransformer parent, DataViewSchema inputSchema) : // ndim of default TensorShape is -1, make originDim to 0 in this case. // after addBatchDimension, input column will be changed: type -> type[] var originDim = _fullySpecifiedShapes[i].ndim < 0 ? 
0 : _fullySpecifiedShapes[i].ndim; - var l = new int[originDim + 1]; + var l = new long[originDim + 1]; l[0] = 1; for (int ishape = 1; ishape < l.Length; ishape++) l[ishape] = _fullySpecifiedShapes[i].dims[ishape - 1]; - _fullySpecifiedShapes[i] = new TensorShape(l); + _fullySpecifiedShapes[i] = new Shape(l); } } @@ -720,7 +720,7 @@ private Delegate MakeGetter(DataViewRow input, int iinfo, ITensorValueGetter[ UpdateCacheIfNeeded(input.Position, srcTensorGetters, activeOutputColNames, outputCache); var tensor = outputCache.Outputs[_parent.Outputs[iinfo]]; - var tensorSize = tensor.TensorShape.dims.Where(x => x > 0).Aggregate((x, y) => x * y); + var tensorSize = tensor.shape.dims.Where(x => x > 0).Aggregate((x, y) => x * y); var editor = VBufferEditor.Create(ref dst, (int)tensorSize); FetchStringData(tensor, editor.Values); @@ -735,7 +735,7 @@ private Delegate MakeGetter(DataViewRow input, int iinfo, ITensorValueGetter[ UpdateCacheIfNeeded(input.Position, srcTensorGetters, activeOutputColNames, outputCache); var tensor = outputCache.Outputs[_parent.Outputs[iinfo]]; - var tensorSize = tensor.TensorShape.dims.Where(x => x > 0).Aggregate((x, y) => x * y); + var tensorSize = tensor.shape.dims.Where(x => x > 0).Aggregate((x, y) => x * y); var editor = VBufferEditor.Create(ref dst, (int)tensorSize); @@ -821,10 +821,10 @@ private class TensorValueGetter : ITensorValueGetter { private readonly ValueGetter _srcgetter; private readonly T[] _bufferedData; - private readonly TensorShape _tfShape; + private readonly Shape _tfShape; private int _position; - public TensorValueGetter(DataViewRow input, int colIndex, TensorShape tfShape) + public TensorValueGetter(DataViewRow input, int colIndex, Shape tfShape) { _srcgetter = input.GetGetter(input.Schema[colIndex]); _tfShape = tfShape; @@ -864,7 +864,7 @@ public Tensor GetBufferedBatchTensor() private class TensorValueGetterVec : ITensorValueGetter { private readonly ValueGetter> _srcgetter; - private readonly TensorShape _tfShape; + 
private readonly Shape _tfShape; private VBuffer _vBuffer; private T[] _denseData; private T[] _bufferedData; @@ -872,7 +872,7 @@ private class TensorValueGetterVec : ITensorValueGetter private readonly long[] _dims; private readonly long _bufferedDataSize; - public TensorValueGetterVec(DataViewRow input, int colIndex, TensorShape tfShape) + public TensorValueGetterVec(DataViewRow input, int colIndex, Shape tfShape) { _srcgetter = input.GetGetter>(input.Schema[colIndex]); _tfShape = tfShape; diff --git a/src/Microsoft.ML.TensorFlow/TensorflowUtils.cs b/src/Microsoft.ML.TensorFlow/TensorflowUtils.cs index faec243057..18cc4ab52c 100644 --- a/src/Microsoft.ML.TensorFlow/TensorflowUtils.cs +++ b/src/Microsoft.ML.TensorFlow/TensorflowUtils.cs @@ -3,6 +3,7 @@ // See the LICENSE file in the project root for more information. using System; +using System.ComponentModel; using System.IO; using System.Linq; using System.Security.AccessControl; @@ -12,8 +13,8 @@ using Microsoft.ML.Runtime; using Microsoft.ML.TensorFlow; using Microsoft.ML.Transforms; -using NumSharp; using Tensorflow; +using Tensorflow.NumPy; using static Tensorflow.Binding; using Utils = Microsoft.ML.Internal.Utilities.Utils; @@ -77,9 +78,9 @@ internal static DataViewSchema GetModelSchema(IExceptionContext ectx, Graph grap } // Construct the final ML.NET type of a Tensorflow variable. 
- var tensorShape = op.output.TensorShape.dims; + var dimensions = op.output.shape.dims; - if (tensorShape == null) + if (dimensions == null) { // primitive column type schemaBuilder.AddColumn(op.name, mlType, metadataBuilder.ToAnnotations()); @@ -88,15 +89,15 @@ internal static DataViewSchema GetModelSchema(IExceptionContext ectx, Graph grap { // vector column type DataViewType columnType = new VectorDataViewType(mlType); - if (!(Utils.Size(tensorShape) == 1 && tensorShape[0] <= 0) && - (Utils.Size(tensorShape) > 0 && tensorShape.Skip(1).All(x => x > 0))) + if (!(Utils.Size(dimensions) == 1 && dimensions[0] <= 0) && + (Utils.Size(dimensions) > 0 && dimensions.Skip(1).All(x => x > 0))) // treatOutputAsBatched == true means that if the first dimension is greater // than 0 we take the tensor shape as is. If the first value is less then 0, we treat it as the batch input so we can // ignore it for the shape of the ML.NET vector. I.E. if the input dimensions are [-1, 5], ML.NET will read the -1 as // batch input, and so the ML.NET data type will be a vector of length 5. if (treatOutputAsBatched) { - columnType = new VectorDataViewType(mlType, tensorShape[0] > 0 ? tensorShape : tensorShape.Skip(1).ToArray()); + columnType = new VectorDataViewType(mlType, dimensions[0] > 0 ? dimensions : dimensions.Skip(1).ToArray()); } // When treatOutputAsBatched is false, if the first value is less than 0 we want to set it to 0. TensorFlow // represents an unknown size as -1, but ML.NET represents it as 0 so we need to convert it. @@ -104,9 +105,9 @@ internal static DataViewSchema GetModelSchema(IExceptionContext ectx, Graph grap // data type will be a vector of 2 dimensions, where the first dimension is unknown and the second has a length of 5. 
else { - if (tensorShape[0] < 0) - tensorShape[0] = 0; - columnType = new VectorDataViewType(mlType, tensorShape); + if (dimensions[0] < 0) + dimensions[0] = 0; + columnType = new VectorDataViewType(mlType, dimensions); } schemaBuilder.AddColumn(op.name, columnType, metadataBuilder.ToAnnotations()); @@ -441,32 +442,32 @@ internal static bool IsTypeSupported(TF_DataType tfoutput) } } - internal static Tensor CastDataAndReturnAsTensor(T[] data, TensorShape tfShape) + internal static Tensor CastDataAndReturnAsTensor(T[] data, Shape tfShape) { var dims = tfShape.dims.Select(x => (long)x).ToArray(); if (typeof(T) == typeof(sbyte)) - return new Tensor((sbyte[])(object)data, dims, TF_DataType.TF_INT8); + return new Tensor((sbyte[])(object)data, dims); else if (typeof(T) == typeof(long)) - return new Tensor((long[])(object)data, dims, TF_DataType.TF_INT64); + return new Tensor((long[])(object)data, dims); else if (typeof(T) == typeof(Int32)) - return new Tensor((Int32[])(object)data, dims, TF_DataType.TF_INT32); + return new Tensor((Int32[])(object)data, dims); else if (typeof(T) == typeof(Int16)) - return new Tensor((Int16[])(object)data, dims, TF_DataType.TF_INT16); + return new Tensor((Int16[])(object)data, dims); else if (typeof(T) == typeof(byte)) - return new Tensor((byte[])(object)data, dims, TF_DataType.TF_UINT8); + return new Tensor((byte[])(object)data, dims); else if (typeof(T) == typeof(ulong)) - return new Tensor((ulong[])(object)data, dims, TF_DataType.TF_UINT64); + return new Tensor((ulong[])(object)data, dims); else if (typeof(T) == typeof(UInt32)) - return new Tensor((UInt32[])(object)data, dims, TF_DataType.TF_UINT32); + return new Tensor((UInt32[])(object)data, dims); else if (typeof(T) == typeof(UInt16)) - return new Tensor((UInt16[])(object)data, dims, TF_DataType.TF_UINT16); + return new Tensor((UInt16[])(object)data, dims); else if (typeof(T) == typeof(bool)) - return new Tensor((bool[])(object)data, dims, TF_DataType.TF_BOOL); + return new 
Tensor((bool[])(object)data, dims); else if (typeof(T) == typeof(float)) - return new Tensor((float[])(object)data, dims, TF_DataType.TF_FLOAT); + return new Tensor((float[])(object)data, dims); else if (typeof(T) == typeof(double)) - return new Tensor((double[])(object)data, dims, TF_DataType.TF_DOUBLE); + return new Tensor((double[])(object)data, dims); else if (typeof(T) == typeof(ReadOnlyMemory)) { string[] strings = new string[data.Length]; @@ -484,27 +485,30 @@ internal static Tensor CastDataAndReturnAsTensor(T[] data, TensorShape tfShap internal static Tensor CastDataAndReturnAsTensor(T data) { if (typeof(T) == typeof(sbyte)) - return new Tensor((sbyte)(object)data, TF_DataType.TF_INT8); + return new Tensor((sbyte)(object)data); else if (typeof(T) == typeof(long)) - return new Tensor((long)(object)data, TF_DataType.TF_INT64); + return new Tensor((long)(object)data); else if (typeof(T) == typeof(Int32)) - return new Tensor((Int32)(object)data, TF_DataType.TF_INT32); + return new Tensor((Int32)(object)data); else if (typeof(T) == typeof(Int16)) - return new Tensor((Int16)(object)data, TF_DataType.TF_INT16); + return new Tensor((Int16)(object)data); else if (typeof(T) == typeof(byte)) - return new Tensor((byte)(object)data, TF_DataType.TF_UINT8); + return new Tensor((byte)(object)data); else if (typeof(T) == typeof(ulong)) - return new Tensor((ulong)(object)data, TF_DataType.TF_UINT64); + return new Tensor((ulong)(object)data); else if (typeof(T) == typeof(UInt32)) - return new Tensor((UInt32)(object)data, TF_DataType.TF_UINT32); + return new Tensor((UInt32)(object)data); else if (typeof(T) == typeof(UInt16)) - return new Tensor((UInt16)(object)data, TF_DataType.TF_UINT16); +#pragma warning disable IDE0055 + // Tensorflow.NET v2.7 has no constructor for UInt16 so using the array version + return new Tensor(new UInt16[]{(UInt16)(object)data}); +#pragma warning restore IDE0055 else if (typeof(T) == typeof(bool)) - return new Tensor((bool)(object)data, 
TF_DataType.TF_BOOL); + return new Tensor((bool)(object)data); else if (typeof(T) == typeof(float)) - return new Tensor((float)(object)data, TF_DataType.TF_FLOAT); + return new Tensor((float)(object)data); else if (typeof(T) == typeof(double)) - return new Tensor((double)(object)data, TF_DataType.TF_DOUBLE); + return new Tensor((double)(object)data); else if (typeof(T) == typeof(ReadOnlyMemory)) return new Tensor(data.ToString()); @@ -556,7 +560,8 @@ public Runner AddInput(Tensor value, int index) { _inputTensors[index]?.Dispose(); _inputTensors[index] = value; - _inputValues[index] = value; + _inputValues[index] = value.Handle.DangerousGetHandle(); + return this; } @@ -613,7 +618,9 @@ public Tensor[] Run() _status.Check(true); for (int i = 0; i < _outputs.Length; i++) - _outputTensors[i] = new Tensor(_outputValues[i]); + { + _outputTensors[i] = new Tensor(new SafeTensorHandle(_outputValues[i])); + } return _outputTensors; } diff --git a/src/Microsoft.ML.Vision/DnnRetrainTransform.cs b/src/Microsoft.ML.Vision/DnnRetrainTransform.cs index d172633057..8075dd1d03 100644 --- a/src/Microsoft.ML.Vision/DnnRetrainTransform.cs +++ b/src/Microsoft.ML.Vision/DnnRetrainTransform.cs @@ -15,7 +15,6 @@ using Microsoft.ML.Runtime; using Microsoft.ML.TensorFlow; using Microsoft.ML.Transforms; -using NumSharp; using Tensorflow; using static Microsoft.ML.TensorFlow.TensorFlowUtils; using static Tensorflow.Binding; @@ -50,7 +49,7 @@ internal sealed class DnnRetrainTransformer : RowToRowTransformerBase, IDisposab private readonly DataViewType[] _outputTypes; private readonly TF_DataType[] _tfOutputTypes; private readonly TF_DataType[] _tfInputTypes; - private readonly TensorShape[] _tfInputShapes; + private readonly Shape[] _tfInputShapes; private readonly (Operation, int)[] _tfInputOperations; private readonly (Operation, int)[] _tfOutputOperations; private readonly TF_Output[] _tfInputNodes; @@ -225,7 +224,7 @@ private void CheckTrainingParameters(DnnRetrainEstimator.Options 
options) } } - private (int, bool, TF_DataType, TensorShape) GetTrainingInputInfo(DataViewSchema inputSchema, string columnName, string tfNodeName, int batchSize) + private (int, bool, TF_DataType, Shape) GetTrainingInputInfo(DataViewSchema inputSchema, string columnName, string tfNodeName, int batchSize) { if (!inputSchema.TryGetColumnIndex(columnName, out int inputColIndex)) throw Host.Except($"Column {columnName} doesn't exist"); @@ -237,7 +236,7 @@ private void CheckTrainingParameters(DnnRetrainEstimator.Options options) var tfInput = new TF_Input(inputTensor, index); var tfInputType = inputTensor.OpType == "Placeholder" ? inputTensor.OutputType(index) : inputTensor.InputType(index); - var tfInputShape = ((Tensor)inputTensor).TensorShape; + var tfInputShape = ((Tensor)inputTensor).shape; var numInputDims = tfInputShape != null ? tfInputShape.ndim : -1; if (isInputVector && (tfInputShape == null || (numInputDims == 0))) @@ -248,17 +247,17 @@ private void CheckTrainingParameters(DnnRetrainEstimator.Options options) for (int indexLocal = 0; indexLocal < vecType.Dimensions.Length; indexLocal += 1) colTypeDims[indexLocal + 1] = vecType.Dimensions[indexLocal]; - tfInputShape = new TensorShape(colTypeDims); + tfInputShape = new Shape(colTypeDims); } if (numInputDims != -1) { - var newShape = new int[numInputDims]; + var newShape = new long[numInputDims]; var dims = tfInputShape.dims; newShape[0] = dims[0] == 0 || dims[0] == -1 ? 
batchSize : dims[0]; for (int j = 1; j < numInputDims; j++) newShape[j] = dims[j]; - tfInputShape = new TensorShape(newShape); + tfInputShape = new Shape(newShape); } var expectedType = Tf2MlNetType(tfInputType); @@ -278,7 +277,7 @@ private void TrainCore(DnnRetrainEstimator.Options options, IDataView input, IDa var inputColIndices = new int[inputsForTraining.Length]; var isInputVector = new bool[inputsForTraining.Length]; var tfInputTypes = new TF_DataType[inputsForTraining.Length]; - var tfInputShapes = new TensorShape[inputsForTraining.Length]; + var tfInputShapes = new Shape[inputsForTraining.Length]; for (int i = 0; i < _inputs.Length; i++) inputsForTraining[i] = _idvToTfMapping[_inputs[i]]; @@ -382,13 +381,13 @@ private void TrainCore(DnnRetrainEstimator.Options options, IDataView input, IDa runner.AddInput(srcTensorGetters[i].GetBufferedBatchTensor(), i + 1); Tensor[] tensor = runner.Run(); - if (tensor.Length > 0 && tensor[0] != IntPtr.Zero) + if (tensor.Length > 0 && tensor[0].TensorDataPointer != IntPtr.Zero) { tensor[0].ToScalar(ref loss); tensor[0].Dispose(); } - if (tensor.Length > 1 && tensor[1] != IntPtr.Zero) + if (tensor.Length > 1 && tensor[1].TensorDataPointer != IntPtr.Zero) { tensor[1].ToScalar(ref metric); tensor[1].Dispose(); @@ -460,14 +459,14 @@ private void UpdateModelOnDisk(string modelDir, DnnRetrainEstimator.Options opti } } - private static ITensorValueGetter CreateTensorValueGetter(DataViewRow input, bool isVector, int colIndex, TensorShape tfShape, bool keyType = false) + private static ITensorValueGetter CreateTensorValueGetter(DataViewRow input, bool isVector, int colIndex, Shape tfShape, bool keyType = false) { if (isVector) return new TensorValueGetterVec(input, colIndex, tfShape); return new TensorValueGetter(input, colIndex, tfShape, keyType); } - private static ITensorValueGetter CreateTensorValueGetter(DataViewRow input, TF_DataType tfType, bool isVector, int colIndex, TensorShape tfShape) + private static ITensorValueGetter 
CreateTensorValueGetter(DataViewRow input, TF_DataType tfType, bool isVector, int colIndex, Shape tfShape) { var type = Tf2MlNetType(tfType); if (input.Schema[colIndex].Type is KeyDataViewType && type.RawType == typeof(Int64)) @@ -481,7 +480,7 @@ private static ITensorValueGetter[] GetTensorValueGetters( int[] inputColIndices, bool[] isInputVector, TF_DataType[] tfInputTypes, - TensorShape[] tfInputShapes) + Shape[] tfInputShapes) { var srcTensorGetters = new ITensorValueGetter[inputColIndices.Length]; for (int i = 0; i < inputColIndices.Length; i++) @@ -574,10 +573,10 @@ private static (Operation, int) GetOperationFromName(string operation, Session s return (session.graph.OperationByName(operation), 0); } - internal static (TF_DataType[] tfInputTypes, TensorShape[] tfInputShapes, (Operation, int)[]) GetInputInfo(IHost host, Session session, string[] inputs, int batchSize = 1) + internal static (TF_DataType[] tfInputTypes, Shape[] tfInputShapes, (Operation, int)[]) GetInputInfo(IHost host, Session session, string[] inputs, int batchSize = 1) { var tfInputTypes = new TF_DataType[inputs.Length]; - var tfInputShapes = new TensorShape[inputs.Length]; + var tfInputShapes = new Shape[inputs.Length]; var tfInputOperations = new (Operation, int)[inputs.Length]; int index = 0; @@ -594,7 +593,7 @@ internal static (TF_DataType[] tfInputTypes, TensorShape[] tfInputShapes, (Opera throw host.ExceptParam(nameof(session), $"Input type '{tfInputType}' of input column '{input}' is not supported in TensorFlow"); tfInputTypes[index] = tfInputType; - tfInputShapes[index] = ((Tensor)inputTensor).TensorShape; + tfInputShapes[index] = ((Tensor)inputTensor).shape; tfInputOperations[index] = (inputTensor, inputTensorIndex); index++; } @@ -602,7 +601,7 @@ internal static (TF_DataType[] tfInputTypes, TensorShape[] tfInputShapes, (Opera return (tfInputTypes, tfInputShapes, tfInputOperations); } - internal static TensorShape GetTensorShape(TF_Output output, Graph graph, Status status = null) + 
internal static Shape GetTensorShape(TF_Output output, Graph graph, Status status = null) { if (graph == IntPtr.Zero) throw new ObjectDisposedException(nameof(graph)); @@ -613,12 +612,12 @@ internal static TensorShape GetTensorShape(TF_Output output, Graph graph, Status cstatus.Check(); if (n == -1) - return new TensorShape(new int[0]); + return new Shape(new int[0]); var dims = new long[n]; c_api.TF_GraphGetTensorShape(graph, output, dims, dims.Length, cstatus.Handle); cstatus.Check(); - return new TensorShape(dims.Select(x => (int)x).ToArray()); + return new Shape(dims.Select(x => (int)x).ToArray()); } internal static (TF_DataType[] tfOutputTypes, DataViewType[] outputTypes, (Operation, int)[]) GetOutputInfo(IHost host, Session session, string[] outputs) @@ -645,12 +644,12 @@ internal static (TF_DataType[] tfOutputTypes, DataViewType[] outputTypes, (Opera // i.e. the first dimension (if unknown) is assumed to be batch dimension. // If there are other dimension that are unknown the transformer will return a variable length vector. // This is the work around in absence of reshape transformer. - int[] dims = shape.ndim > 0 ? shape.dims.Skip(shape.dims[0] == -1 ? 1 : 0).ToArray() : new[] { 0 }; + long[] dims = shape.ndim > 0 ? shape.dims.Skip(shape.dims[0] == -1 ? 1 : 0).ToArray() : new long[] { 0 }; for (int j = 0; j < dims.Length; j++) dims[j] = dims[j] == -1 ? 
0 : dims[j]; if (dims == null || dims.Length == 0) { - dims = new[] { 1 }; + dims = new long[] { 1 }; outputTypes[i] = Tf2MlNetType(tfOutputType); } else @@ -741,7 +740,7 @@ public void Dispose() { if (_session.graph != null) _session.graph.Dispose(); - _session.close(); + _session.Dispose(); } } finally @@ -760,7 +759,7 @@ private sealed class Mapper : MapperBase private readonly DnnRetrainTransformer _parent; private readonly int[] _inputColIndices; private readonly bool[] _isInputVector; - private readonly TensorShape[] _fullySpecifiedShapes; + private readonly Shape[] _fullySpecifiedShapes; private readonly ConcurrentBag _runners; public Mapper(DnnRetrainTransformer parent, DataViewSchema inputSchema) : @@ -770,7 +769,7 @@ public Mapper(DnnRetrainTransformer parent, DataViewSchema inputSchema) : _parent = parent; _inputColIndices = new int[_parent._inputs.Length]; _isInputVector = new bool[_parent._inputs.Length]; - _fullySpecifiedShapes = new TensorShape[_parent._inputs.Length]; + _fullySpecifiedShapes = new Shape[_parent._inputs.Length]; for (int i = 0; i < _parent._inputs.Length; i++) { if (!inputSchema.TryGetColumnIndex(_parent._inputs[i], out _inputColIndices[i])) @@ -792,12 +791,12 @@ public Mapper(DnnRetrainTransformer parent, DataViewSchema inputSchema) : var colTypeDims = vecType.Dimensions.Select(dim => (int)dim).ToArray(); if (shape == null || (shape.Length == 0)) - _fullySpecifiedShapes[i] = new TensorShape(colTypeDims); + _fullySpecifiedShapes[i] = new Shape(colTypeDims); else { // If the column is one dimension we make sure that the total size of the TF shape matches. // Compute the total size of the known dimensions of the shape. - int valCount = 1; + long valCount = 1; int numOfUnkDim = 0; foreach (var s in shape) { @@ -821,19 +820,19 @@ public Mapper(DnnRetrainTransformer parent, DataViewSchema inputSchema) : // Fill in the unknown dimensions. 
var originalShapeDims = originalShape.dims; var originalShapeNdim = originalShape.ndim; - var l = new int[originalShapeNdim]; + var l = new long[originalShapeNdim]; for (int ishape = 0; ishape < originalShapeNdim; ishape++) l[ishape] = originalShapeDims[ishape] == -1 ? (int)d : originalShapeDims[ishape]; - _fullySpecifiedShapes[i] = new TensorShape(l); + _fullySpecifiedShapes[i] = new Shape(l); } if (_parent._addBatchDimensionInput) { - var l = new int[_fullySpecifiedShapes[i].ndim + 1]; + var l = new long[_fullySpecifiedShapes[i].ndim + 1]; l[0] = 1; for (int ishape = 1; ishape < l.Length; ishape++) l[ishape] = _fullySpecifiedShapes[i].dims[ishape - 1]; - _fullySpecifiedShapes[i] = new TensorShape(l); + _fullySpecifiedShapes[i] = new Shape(l); } } @@ -891,7 +890,7 @@ private Delegate MakeGetter(DataViewRow input, int iinfo, ITensorValueGetter[ UpdateCacheIfNeeded(input.Position, srcTensorGetters, activeOutputColNames, outputCache); var tensor = outputCache.Outputs[_parent._outputs[iinfo]]; - var tensorSize = tensor.TensorShape.dims.Where(x => x > 0).Aggregate((x, y) => x * y); + var tensorSize = tensor.shape.dims.Where(x => x > 0).Aggregate((x, y) => x * y); var editor = VBufferEditor.Create(ref dst, (int)tensorSize); FetchStringData(tensor, editor.Values); @@ -906,7 +905,7 @@ private Delegate MakeGetter(DataViewRow input, int iinfo, ITensorValueGetter[ UpdateCacheIfNeeded(input.Position, srcTensorGetters, activeOutputColNames, outputCache); var tensor = outputCache.Outputs[_parent._outputs[iinfo]]; - var tensorSize = tensor.TensorShape.dims.Where(x => x > 0).Aggregate((x, y) => x * y); + var tensorSize = tensor.shape.dims.Where(x => x > 0).Aggregate((x, y) => x * y); var editor = VBufferEditor.Create(ref dst, (int)tensorSize); @@ -972,12 +971,12 @@ private class TensorValueGetter : ITensorValueGetter private readonly ValueGetter _srcgetter; private readonly T[] _bufferedData; private readonly Int64[] _bufferedDataLong; - private readonly TensorShape _tfShape; + 
private readonly Shape _tfShape; private int _position; private readonly bool _keyType; private readonly long[] _dims; - public TensorValueGetter(DataViewRow input, int colIndex, TensorShape tfShape, bool keyType = false) + public TensorValueGetter(DataViewRow input, int colIndex, Shape tfShape, bool keyType = false) { _srcgetter = input.GetGetter(input.Schema[colIndex]); _tfShape = tfShape; @@ -1035,7 +1034,7 @@ public Tensor GetBufferedBatchTensor() { if (_keyType) { - var tensor = new Tensor(_bufferedDataLong, _dims, TF_DataType.TF_INT64); + var tensor = new Tensor(_bufferedDataLong, _dims); _position = 0; return tensor; } @@ -1051,7 +1050,7 @@ public Tensor GetBufferedBatchTensor() private class TensorValueGetterVec : ITensorValueGetter { private readonly ValueGetter> _srcgetter; - private readonly TensorShape _tfShape; + private readonly Shape _tfShape; private VBuffer _vBuffer; private T[] _denseData; private T[] _bufferedData; @@ -1059,7 +1058,7 @@ private class TensorValueGetterVec : ITensorValueGetter private readonly long[] _dims; private readonly long _bufferedDataSize; - public TensorValueGetterVec(DataViewRow input, int colIndex, TensorShape tfShape) + public TensorValueGetterVec(DataViewRow input, int colIndex, Shape tfShape) { _srcgetter = input.GetGetter>(input.Schema[colIndex]); _tfShape = tfShape; diff --git a/src/Microsoft.ML.Vision/ImageClassificationTrainer.cs b/src/Microsoft.ML.Vision/ImageClassificationTrainer.cs index 846de00518..ff6617e034 100644 --- a/src/Microsoft.ML.Vision/ImageClassificationTrainer.cs +++ b/src/Microsoft.ML.Vision/ImageClassificationTrainer.cs @@ -26,6 +26,7 @@ using Tensorflow.Summaries; using static Microsoft.ML.Data.TextLoader; using static Microsoft.ML.TensorFlow.TensorFlowUtils; +using static Microsoft.ML.Vision.StringTensorFactory; using static Tensorflow.Binding; using Column = Microsoft.ML.Data.TextLoader.Column; @@ -763,23 +764,7 @@ private void CheckTrainingParameters(Options options) private static Tensor 
EncodeByteAsString(VBuffer buffer) { - int length = buffer.Length; - var size = c_api.TF_StringEncodedSize((ulong)length); - var handle = c_api.TF_AllocateTensor(TF_DataType.TF_STRING, Array.Empty(), 0, ((ulong)size + 8)); - - IntPtr tensor = c_api.TF_TensorData(handle); - Marshal.WriteInt64(tensor, 0); - - var status = new Status(); - unsafe - { - fixed (byte* src = buffer.GetValues()) - c_api.TF_StringEncode(src, (ulong)length, (byte*)(tensor + sizeof(Int64)), size, status.Handle); - } - - status.Check(true); - status.Dispose(); - return new Tensor(handle); + return StringTensorFactory.CreateStringTensor(buffer.DenseValues().ToArray()); } internal sealed class ImageProcessor @@ -976,8 +961,8 @@ private void TrainAndEvaluateClassificationLayer(string trainBottleneckFilePath, metrics.Train = new TrainMetrics(); float accuracy = 0; float crossentropy = 0; - var labelTensorShape = _labelTensor.TensorShape.dims.Select(x => (long)x).ToArray(); - var featureTensorShape = _bottleneckInput.TensorShape.dims.Select(x => (long)x).ToArray(); + var labelTensorShape = _labelTensor.shape.dims.Select(x => (long)x).ToArray(); + var featureTensorShape = _bottleneckInput.shape.dims.Select(x => (long)x).ToArray(); byte[] buffer = new byte[sizeof(int)]; trainSetFeatureReader.ReadExactly(buffer, 0, 4); int trainingExamples = BitConverter.ToInt32(buffer, 0); @@ -1119,12 +1104,12 @@ private void TrainAndEvaluateClassificationLayerCore(int epoch, float learningRa { // Add learning rate as a placeholder only when learning rate scheduling is used. 
metrics.Train.LearningRate = learningRateScheduler.GetLearningRate(trainState); - runner.AddInput(new Tensor(metrics.Train.LearningRate, TF_DataType.TF_FLOAT), 2); + runner.AddInput(new Tensor(metrics.Train.LearningRate), 2); } - var outputTensors = runner.AddInput(new Tensor(featureBufferPtr, featureTensorShape, TF_DataType.TF_FLOAT, featuresFileBytesRead), 0) - .AddInput(new Tensor(labelBufferPtr, labelTensorShape, TF_DataType.TF_INT64, labelFileBytesRead), 1) - .Run(); + var outputTensors = runner.AddInput(new Tensor(featureBufferPtr, featureTensorShape, TF_DataType.TF_FLOAT), 0) + .AddInput(new Tensor(labelBufferPtr, labelTensorShape, TF_DataType.TF_INT64), 1) + .Run(); metrics.Train.BatchProcessedCount += 1; metricsAggregator(outputTensors, metrics); @@ -1186,7 +1171,7 @@ private void TryCleanupTemporaryWorkspace() { tf_with(tf.name_scope("correct_prediction"), delegate { - _prediction = tf.argmax(resultTensor, 1); + _prediction = tf.math.argmax(resultTensor, 1); correctPrediction = tf.equal(_prediction, groundTruthTensor); }); @@ -1240,7 +1225,7 @@ private void VariableSummaries(ResourceVariable var) string scoreColumnName, Tensor bottleneckTensor, bool isTraining, bool useLearningRateScheduler, float learningRate) { - var bottleneckTensorDims = bottleneckTensor.TensorShape.dims; + var bottleneckTensorDims = CastLongArrayToIntArray(bottleneckTensor.shape.dims); var (batch_size, bottleneck_tensor_size) = (bottleneckTensorDims[0], bottleneckTensorDims[1]); tf_with(tf.name_scope("input"), scope => { @@ -1254,7 +1239,7 @@ private void VariableSummaries(ResourceVariable var) _learningRateInput = tf.placeholder(tf.float32, null, name: "learningRateInputPlaceholder"); } - _labelTensor = tf.placeholder(tf.int64, new TensorShape(batch_size), name: labelColumn); + _labelTensor = tf.placeholder(tf.int64, new Shape(batch_size), name: labelColumn); }); string layerName = "final_retrain_ops"; @@ -1274,7 +1259,7 @@ private void VariableSummaries(ResourceVariable var) 
ResourceVariable layerBiases = null; tf_with(tf.name_scope("biases"), delegate { - TensorShape shape = new TensorShape(classCount); + Shape shape = new Shape(classCount); layerBiases = tf.Variable(tf.zeros(shape), name: "final_biases"); VariableSummaries(layerBiases); }); @@ -1313,6 +1298,27 @@ private void VariableSummaries(ResourceVariable var) return (_trainStep, crossEntropyMean, _labelTensor, _softMaxTensor); } + private static int[] CastLongArrayToIntArray(long[] source) + { + if (source == null) + throw new ArgumentNullException(nameof(source)); + + int[] result = new int[source.Length]; + + for (int i = 0; i < source.Length; i++) + { + long value = source[i]; + if (value > int.MaxValue || value < int.MinValue) + { + throw new OverflowException($"Value at index {i} ({value}) cannot be safely cast to int."); + } + + result[i] = (int)value; + } + + return result; + } + private void AddTransferLearningLayer(string labelColumn, string scoreColumnName, float learningRate, bool useLearningRateScheduling, int classCount) { @@ -1514,10 +1520,94 @@ public void Dispose() if (_session != null && _session != IntPtr.Zero) { - _session.close(); + _session.Dispose(); } _isDisposed = true; } } + +#pragma warning disable MSML_GeneralName +#pragma warning disable MSML_ParameterLocalVarName +#pragma warning disable IDE0055 + public class StringTensorFactory + { + // Define TF_TString struct + [StructLayout(LayoutKind.Sequential)] + struct TF_TString + { + public IntPtr data; + public UIntPtr length; + public UIntPtr capacity; + public int memory_type; + } + + // Import TF_TString methods from TensorFlow C API + [DllImport("tensorflow", CallingConvention = CallingConvention.Cdecl)] + private static extern unsafe void TF_StringInit(TF_TString* tstring); + + [DllImport("tensorflow", CallingConvention = CallingConvention.Cdecl)] + private static extern unsafe void TF_StringCopy(TF_TString* dst, byte* src, UIntPtr size); + + [DllImport("tensorflow", CallingConvention = 
CallingConvention.Cdecl)] + private static extern unsafe void TF_StringDealloc(TF_TString* tstring); + + private static readonly TF_Deallocator _deallocatorInstance = new StringTensorFactory.TF_Deallocator(Deallocator); + + // Delegate for TensorFlow deallocator + [UnmanagedFunctionPointer(CallingConvention.Cdecl)] + public delegate void TF_Deallocator(IntPtr data, UIntPtr length, IntPtr arg); + + // Deallocator function + public static void Deallocator(IntPtr data, UIntPtr length, IntPtr arg) + { + unsafe + { + TF_StringDealloc((TF_TString*)data); + } + Marshal.FreeHGlobal(data); + } + + public static Tensor CreateStringTensor(byte[] data) + { + int sizeOfTString = Marshal.SizeOf(); + + // Allocate memory for TF_TString + IntPtr tstringPtr = Marshal.AllocHGlobal(sizeOfTString); + unsafe + { + TF_TString* tstring = (TF_TString*)tstringPtr; + TF_StringInit(tstring); + + fixed (byte* src = data) + { + TF_StringCopy(tstring, src, (UIntPtr)data.Length); + } + } + + // Create a scalar tensor (rank 0, so no shape dims) + Tensor tensor = new Tensor(new SafeTensorHandle(TF_NewTensor( + TF_DataType.TF_STRING, + Array.Empty(), + 0, + tstringPtr, + (UIntPtr)sizeOfTString, + _deallocatorInstance, + IntPtr.Zero + ))); + + return tensor; + } + + [DllImport("tensorflow", CallingConvention = CallingConvention.Cdecl)] + private static extern IntPtr TF_NewTensor( + TF_DataType dtype, + long[] dims, int num_dims, + IntPtr data, UIntPtr len, + TF_Deallocator deallocator, + IntPtr deallocator_arg); + } +#pragma warning restore MSML_GeneralName +#pragma warning restore MSML_ParameterLocalVarName +#pragma warning restore IDE0055 } diff --git a/test/Microsoft.ML.TensorFlow.Tests/TensorflowTests.cs b/test/Microsoft.ML.TensorFlow.Tests/TensorflowTests.cs index 16bc4a6b74..d9d362d020 100644 --- a/test/Microsoft.ML.TensorFlow.Tests/TensorflowTests.cs +++ b/test/Microsoft.ML.TensorFlow.Tests/TensorflowTests.cs @@ -18,6 +18,7 @@ using Microsoft.ML.Transforms; using 
Microsoft.ML.Transforms.Image; using Microsoft.ML.Vision; +using Tensorflow; using Xunit; using Xunit.Abstractions; using static Microsoft.ML.DataOperationsCatalog; @@ -1187,7 +1188,9 @@ public void TensorFlowSaveAndLoadSavedModel() predictFunction.Dispose(); // Reload the model and check the output schema consistency +#pragma warning disable IDE0055 DataViewSchema loadedInputschema; +#pragma warning restore IDE0055 var testTransformer = _mlContext.Model.Load(mlModelLocation, out loadedInputschema); var testOutputSchema = transformer.GetOutputSchema(data.Schema); Assert.True(TestCommon.CheckSameSchemas(outputSchema, testOutputSchema)); @@ -2055,7 +2058,7 @@ public void TensorflowPlaceholderShapeInferenceTest() new TextLoader.Column("name", DataKind.String, 1) }); - Tensorflow.TensorShape[] tfInputShape; + Tensorflow.Shape[] tfInputShape; using (var tfModel = _mlContext.Model.LoadTensorFlowModel(modelLocation)) { @@ -2070,8 +2073,8 @@ public void TensorflowPlaceholderShapeInferenceTest() transformer.Dispose(); } - Assert.Equal(imageHeight, tfInputShape.ElementAt(0)[1].dims[0]); - Assert.Equal(imageWidth, tfInputShape.ElementAt(0)[2].dims[0]); + Assert.Equal(imageHeight, tfInputShape.ElementAt(0)[Slice.Index(1)].dims[0]); + Assert.Equal(imageWidth, tfInputShape.ElementAt(0)[Slice.Index(2)].dims[0]); } } } From 16782f6ce7670c40f11549093902be2eb7a85e05 Mon Sep 17 00:00:00 2001 From: Ross Halliday Date: Mon, 26 May 2025 09:57:37 +0100 Subject: [PATCH 2/3] Common code to ArrayUtils, dotnet added to Sdk install to resolve CI issues. 
--- eng/Versions.props | 1 + global.json | 6 +++-- src/Microsoft.ML.Core/Utilities/ArrayUtils.cs | 21 +++++++++++++++++ src/Microsoft.ML.DataView/VectorType.cs | 23 +------------------ .../ImageClassificationTrainer.cs | 22 +----------------- 5 files changed, 28 insertions(+), 45 deletions(-) diff --git a/eng/Versions.props b/eng/Versions.props index 6071d3f53f..d264ca5d0f 100644 --- a/eng/Versions.props +++ b/eng/Versions.props @@ -90,6 +90,7 @@ 0.13.12 6.0.36 8.0.16 + 9.0.5 8.1.0 1.1.2 9.0.0-beta.24212.4 diff --git a/global.json b/global.json index 768272c719..7b05feb25b 100644 --- a/global.json +++ b/global.json @@ -4,11 +4,13 @@ "runtimes": { "dotnet": [ "$(DotNetRuntime60Version)", - "$(DotNetRuntime80Version)" + "$(DotNetRuntime80Version)", + "$(DotNetRuntime90Version)" ], "dotnet/x86": [ "$(DotNetRuntime60Version)", - "$(DotNetRuntime80Version)" + "$(DotNetRuntime80Version)", + "$(DotNetRuntime90Version)" ] } }, diff --git a/src/Microsoft.ML.Core/Utilities/ArrayUtils.cs b/src/Microsoft.ML.Core/Utilities/ArrayUtils.cs index 4c23831917..5ffb15fc41 100644 --- a/src/Microsoft.ML.Core/Utilities/ArrayUtils.cs +++ b/src/Microsoft.ML.Core/Utilities/ArrayUtils.cs @@ -100,5 +100,26 @@ public static int EnsureSize(ref T[] array, int min, int max, bool keepOld, o resized = true; return newSize; } + + public static int[] CastLongArrayToIntArray(long[] source) + { + if (source == null) + throw new ArgumentNullException(nameof(source)); + + int[] result = new int[source.Length]; + + for (int i = 0; i < source.Length; i++) + { + long value = source[i]; + if (value > int.MaxValue || value < int.MinValue) + { + throw new OverflowException($"Value at index {i} ({value}) cannot be safely cast to int."); + } + + result[i] = (int)value; + } + + return result; + } } } diff --git a/src/Microsoft.ML.DataView/VectorType.cs b/src/Microsoft.ML.DataView/VectorType.cs index 3ea42e2545..4423f1e8b0 100644 --- a/src/Microsoft.ML.DataView/VectorType.cs +++ 
b/src/Microsoft.ML.DataView/VectorType.cs @@ -81,7 +81,7 @@ public VectorDataViewType(PrimitiveDataViewType itemType, params long[] dimensio Contracts.CheckParam(dimensions.All(d => d >= 0), nameof(dimensions)); ItemType = itemType; - Dimensions = CastLongArrayToIntArray(dimensions).ToImmutableArray(); + Dimensions = ArrayUtils.CastLongArrayToIntArray(dimensions).ToImmutableArray(); Size = ComputeSize(Dimensions); } @@ -117,27 +117,6 @@ private static int ComputeSize(ImmutableArray dims) return size; } - private static int[] CastLongArrayToIntArray(long[] source) - { - if (source == null) - throw new ArgumentNullException(nameof(source)); - - int[] result = new int[source.Length]; - - for (int i = 0; i < source.Length; i++) - { - long value = source[i]; - if (value > int.MaxValue || value < int.MinValue) - { - throw new OverflowException($"Value at index {i} ({value}) cannot be safely cast from long to int."); - } - - result[i] = (int)value; - } - - return result; - } - /// /// Whether this is a vector type with known size. /// Equivalent to > 0. 
diff --git a/src/Microsoft.ML.Vision/ImageClassificationTrainer.cs b/src/Microsoft.ML.Vision/ImageClassificationTrainer.cs index ff6617e034..67381a87fe 100644 --- a/src/Microsoft.ML.Vision/ImageClassificationTrainer.cs +++ b/src/Microsoft.ML.Vision/ImageClassificationTrainer.cs @@ -25,6 +25,7 @@ using Tensorflow; using Tensorflow.Summaries; using static Microsoft.ML.Data.TextLoader; +using static Microsoft.ML.Internal.Utilities.ArrayUtils; using static Microsoft.ML.TensorFlow.TensorFlowUtils; using static Microsoft.ML.Vision.StringTensorFactory; using static Tensorflow.Binding; @@ -1298,27 +1299,6 @@ private void VariableSummaries(ResourceVariable var) return (_trainStep, crossEntropyMean, _labelTensor, _softMaxTensor); } - private static int[] CastLongArrayToIntArray(long[] source) - { - if (source == null) - throw new ArgumentNullException(nameof(source)); - - int[] result = new int[source.Length]; - - for (int i = 0; i < source.Length; i++) - { - long value = source[i]; - if (value > int.MaxValue || value < int.MinValue) - { - throw new OverflowException($"Value at index {i} ({value}) cannot be safely cast to int."); - } - - result[i] = (int)value; - } - - return result; - } - private void AddTransferLearningLayer(string labelColumn, string scoreColumnName, float learningRate, bool useLearningRateScheduling, int classCount) { From 85f7ad497b5b71b1d47c089a4f9370d60db60867 Mon Sep 17 00:00:00 2001 From: Eric StJohn Date: Tue, 17 Jun 2025 08:57:30 -0700 Subject: [PATCH 3/3] Temporarily omit GPU packages --- .../Microsoft.ML.Samples.GPU/Microsoft.ML.Samples.GPU.csproj | 5 ++++- 1 file changed, 4 insertions(+), 1 deletion(-) diff --git a/docs/samples/Microsoft.ML.Samples.GPU/Microsoft.ML.Samples.GPU.csproj b/docs/samples/Microsoft.ML.Samples.GPU/Microsoft.ML.Samples.GPU.csproj index 6f33d4ea53..fe2ecb9cb9 100644 --- a/docs/samples/Microsoft.ML.Samples.GPU/Microsoft.ML.Samples.GPU.csproj +++ b/docs/samples/Microsoft.ML.Samples.GPU/Microsoft.ML.Samples.GPU.csproj @@ 
-46,14 +46,17 @@ + + - + DnnImageModels\ResNet18Onnx\ResNet18.onnx PreserveNewest