Skip to content

Commit

Permalink
Add Float16 and BFloat16 support to C# API (microsoft#5775)
Browse files Browse the repository at this point in the history
Add Float16 and BFloat16 support.
  • Loading branch information
yuslepukhin authored Nov 13, 2020
1 parent 4d517c6 commit 2f35e65
Show file tree
Hide file tree
Showing 10 changed files with 451 additions and 177 deletions.
24 changes: 21 additions & 3 deletions csharp/src/Microsoft.ML.OnnxRuntime/DisposableNamedOnnxValue.cs
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,6 @@
using System;
using System.Buffers;
using System.Collections.Generic;
using System.Diagnostics;

namespace Microsoft.ML.OnnxRuntime
{
Expand Down Expand Up @@ -55,15 +54,28 @@ public void Dispose()
/// tensors, sequences of tensors, sequences and maps
/// It extends NamedOnnxValue, exposes the OnnxValueType and Tensor type
/// The class must be disposed of.
/// It disposes of _ortValueHolder that owns the underlying Ort output value or
/// anything that the class that implements that interfaces needs to dispose.
/// It disposes of _ortValueHolder that owns the underlying Ort output value and
/// anything else that would need to be disposed by the instance of the class.
/// Use factory method CreateFromOrtValue to obtain an instance of the class.
/// </summary>
public class DisposableNamedOnnxValue : NamedOnnxValue, IDisposable
{
private IOrtValueOwner _ortValueHolder;
private bool _disposed = false;

/// <summary>
/// Ctor
/// </summary>
/// <param name="name">Name of the output value</param>
/// <param name="value">Managed object created to represent output value, such as DenseTensor<T>
/// List or Dictionary
/// </param>
/// <param name="onnxValueType">Use this to decide what you want to call to fetch data, AsTensor(), AsDictionary()
/// or AsEnumerable()</param>
/// <param name="elementType">Tensor element type if value type is a Tensor</param>
/// <param name="ortValueHolder">Object that holds native resources.
/// Typically, this is an output OrtValue that holds native memory where Tensor is mapped but may also be
/// other things that would need to be disposed by this instance depending on how IOrtValueOwner is implemented.</param>
private DisposableNamedOnnxValue(string name, Object value, OnnxValueType onnxValueType, TensorElementType elementType, IOrtValueOwner ortValueHolder)
: base(name, value)
{
Expand Down Expand Up @@ -169,6 +181,12 @@ internal static DisposableNamedOnnxValue CreateTensorFromOnnxValue(string name,
case TensorElementType.Bool:
result = DisposableNamedOnnxValueFromNativeTensor<bool>(name, ortValue);
break;
case TensorElementType.Float16:
result = DisposableNamedOnnxValueFromNativeTensor<Float16>(name, ortValue);
break;
case TensorElementType.BFloat16:
result = DisposableNamedOnnxValueFromNativeTensor<BFloat16>(name, ortValue);
break;
default:
throw new NotSupportedException("Tensor of element type: " + elemType + " is not supported");

Expand Down
62 changes: 9 additions & 53 deletions csharp/src/Microsoft.ML.OnnxRuntime/NativeOnnxValueHelper.cs
Original file line number Diff line number Diff line change
Expand Up @@ -83,60 +83,16 @@ internal static class TensorElementTypeConverter
{
public static void GetTypeAndWidth(TensorElementType elemType, out Type type, out int width)
{
switch (elemType)
TensorElementTypeInfo result = TensorBase.GetElementTypeInfo(elemType);
if(result != null)
{
case TensorElementType.Float:
type = typeof(float);
width = sizeof(float);
break;
case TensorElementType.Double:
type = typeof(double);
width = sizeof(double);
break;
case TensorElementType.Int16:
type = typeof(short);
width = sizeof(short);
break;
case TensorElementType.UInt16:
type = typeof(ushort);
width = sizeof(ushort);
break;
case TensorElementType.Int32:
type = typeof(int);
width = sizeof(int);
break;
case TensorElementType.UInt32:
type = typeof(uint);
width = sizeof(uint);
break;
case TensorElementType.Int64:
type = typeof(long);
width = sizeof(long);
break;
case TensorElementType.UInt64:
type = typeof(ulong);
width = sizeof(ulong);
break;
case TensorElementType.UInt8:
type = typeof(byte);
width = sizeof(byte);
break;
case TensorElementType.Int8:
type = typeof(sbyte);
width = sizeof(sbyte);
break;
case TensorElementType.String:
type = typeof(string);
width = sizeof(byte);
break;
case TensorElementType.Bool:
type = typeof(bool);
width = sizeof(bool);
break;
default:
type = null;
width = 0;
break;
type = result.TensorType;
width = result.TypeSize;
}
else
{
type = null;
width = 0;
}
}
}
Expand Down
10 changes: 10 additions & 0 deletions csharp/src/Microsoft.ML.OnnxRuntime/OrtValue.cs
Original file line number Diff line number Diff line change
Expand Up @@ -208,6 +208,16 @@ public static OrtValue CreateFromTensorObject(Object value, out MemoryHandle? me
out memHandle, out dataBufferLength,
out shape, out rank);
break;
case TensorElementType.Float16:
PinAsTensor(value as Tensor<Float16>, typeSize,
out memHandle, out dataBufferLength,
out shape, out rank);
break;
case TensorElementType.BFloat16:
PinAsTensor(value as Tensor<BFloat16>, typeSize,
out memHandle, out dataBufferLength,
out shape, out rank);
break;
default:
throw new NotSupportedException("Element type: " + elType + " is not of a supported type");
}
Expand Down
Loading

0 comments on commit 2f35e65

Please sign in to comment.