Skip to content

Commit

Permalink
Benchmark BatchEnumerator methods
Browse files Browse the repository at this point in the history
  • Loading branch information
georg-jung committed Dec 11, 2023
1 parent 0c6dcf0 commit f217bec
Show file tree
Hide file tree
Showing 2 changed files with 77 additions and 1 deletion.
28 changes: 28 additions & 0 deletions src/Benchmarks/NotImplementedExtensions.cs
Original file line number Diff line number Diff line change
@@ -0,0 +1,28 @@
using System.Threading.Channels;
using FastBertTokenizer;

namespace Benchmarks;

// These will be used if they are not implemented in the referenced FastBertTokenizer library.
// The real ones will be used if they are available as they take precedence over extension methods.
public static class NotImplementedExtensions
{
    /// <summary>
    /// Placeholder for the channel-based batch enumerator. Throws so the benchmark
    /// fails fast when run against a FastBertTokenizer version that does not ship
    /// the real method (real instance/library methods take precedence over these
    /// extension methods when available).
    /// </summary>
    public static IAsyncEnumerable<TokenizedBatch<TKey>> CreateAsyncBatchEnumerator<TKey>(this BertTokenizer tok, ChannelReader<(TKey Key, string Content)> sourceChannel, int tokensPerInput, int batchSize, int stride)
    {
        throw new NotImplementedException();
    }

    /// <summary>
    /// Placeholder for the <see cref="IAsyncEnumerable{T}"/>-based batch enumerator;
    /// same fail-fast rationale as the channel-based overload above.
    /// </summary>
    public static IAsyncEnumerable<TokenizedBatch<TKey>> CreateAsyncBatchEnumerator<TKey>(this BertTokenizer tok, IAsyncEnumerable<(TKey Key, string Content)> sourceChannel, int tokensPerInput, int batchSize, int stride)
    {
        throw new NotImplementedException();
    }
}


// Stand-in mirroring the library's batch result type so this file compiles against
// FastBertTokenizer versions that do not ship it (see file header comment).
// NOTE(review): assumes the real library type exposes these same members — confirm
// against the referenced FastBertTokenizer version.
public class TokenizedBatch<TKey>
{
    // Token id buffer for the whole batch.
    public ReadOnlyMemory<long> InputIds { get; internal set; }

    // Attention mask buffer matching InputIds.
    public ReadOnlyMemory<long> AttentionMask { get; internal set; }

    // Per-row correlation back to the source inputs; null entries presumably mark
    // unused rows — TODO confirm against the library's documentation.
    public ReadOnlyMemory<TokenizedRange<TKey>?> OutputCorrelation { get; set; }
}

/// <summary>
/// Stand-in matching the library's correlation record shape: the source input's
/// <paramref name="Key"/>, an <paramref name="Offset"/> into that input, and the
/// start index of the last tokenized word when the input spans multiple windows.
/// Declared here only so the benchmark compiles without the real library type.
/// </summary>
public record struct TokenizedRange<TKey>(TKey Key, int Offset, int? LastTokenizedWordStartIndex);
50 changes: 49 additions & 1 deletion src/Benchmarks/TokenizeSpeed.cs
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
// Copyright (c) Georg Jung. All rights reserved.
// Licensed under the MIT license. See LICENSE file in the project root for full license information.

using System.Threading.Channels;
using BenchmarkDotNet.Attributes;
using BenchmarkDotNet.Configs;
using BenchmarkDotNet.Diagnosers;
Expand Down Expand Up @@ -76,7 +77,7 @@ public object SingleThreadedMemReuse()
[Benchmark]
public IReadOnlyCollection<(ReadOnlyMemory<long> InputIds, ReadOnlyMemory<long> AttentionMask, ReadOnlyMemory<long> TokenTypeIds)> MultithreadedAllocating()
{
// this might be interesting to benchmark but doesn't make much sense as a real world use case
// This would produce wrong results because BertTokenizer is not thread-safe.
List<(ReadOnlyMemory<long> InputIds, ReadOnlyMemory<long> AttentionMask, ReadOnlyMemory<long> TokenTypeIds)> res = new(_corpus.Length);
foreach(var x in _corpus.AsParallel().AsOrdered().Select(x => _tokenizer.Tokenize(x, _maxSequenceLength)))
{
Expand Down Expand Up @@ -122,6 +123,53 @@ public object SingleThreadedMemReuse()
return (iids.AsMemory(), attm.AsMemory(), toktyp.AsMemory());
}

[Benchmark]
public async Task<List<object>> ParallelBatchEnumerator()
{
    // Bounded channel (capacity 10) decouples the producer from the tokenizer.
    var channel = Channel.CreateBounded<(int, string)>(10);

    // Producer: push every (index, text) pair, then signal completion so the
    // batch enumerator can drain and finish.
    async Task ProduceAsync()
    {
        var index = 0;
        foreach (var text in _corpus)
        {
            await channel.Writer.WriteAsync((index++, text));
        }

        channel.Writer.Complete();
    }

    var producer = Task.Run(ProduceAsync);

    // One correlation object per batch of 100 inputs, hence Length / 100 capacity.
    var results = new List<object>(_corpus.Length / 100);
    await foreach (var batch in _tokenizer.CreateAsyncBatchEnumerator(channel.Reader, _maxSequenceLength, 100, 0))
    {
        results.Add(batch.OutputCorrelation);
    }

    await producer;

    return results;
}

[Benchmark]
public async Task<List<object>> BatchEnumerator()
{
    // Async stream over the corpus; the initial Task.Yield forces the enumeration
    // onto an asynchronous continuation instead of completing synchronously.
    async IAsyncEnumerable<(int, string)> StreamCorpus()
    {
        await Task.Yield();
        var index = 0;
        foreach (var text in _corpus)
        {
            yield return (index++, text);
        }
    }

    // One correlation object per batch of 100 inputs, hence Length / 100 capacity.
    var results = new List<object>(_corpus.Length / 100);
    await foreach (var batch in _tokenizer.CreateAsyncBatchEnumerator(StreamCorpus(), _maxSequenceLength, 100, 0))
    {
        results.Add(batch.OutputCorrelation);
    }

    return results;
}

private sealed class Config : ManualConfig
{
public Config()
Expand Down

0 comments on commit f217bec

Please sign in to comment.