Skip to content

Commit

Permalink
fixed bug and implemented mult buffersizes
Browse files Browse the repository at this point in the history
  • Loading branch information
faressc committed Oct 25, 2023
1 parent ae732d6 commit a202d0f
Show file tree
Hide file tree
Showing 3 changed files with 27 additions and 20 deletions.
6 changes: 6 additions & 0 deletions source/dsp/inference/InferenceThread.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -94,4 +94,10 @@ void InferenceThread::setBackend(InferenceBackend backend) {
void InferenceThread::testInference(InferenceBackend backend) {
currentBackend.store(backend);
run();
}

void InferenceThread::testPushSamples(int numSamples) {
for (int i = 0; i < numSamples; i++) {
rawModelInput.pushSample(-1.f + (float) (std::rand()) / ((float) (RAND_MAX/2.f)), 0);
}
}
1 change: 1 addition & 0 deletions source/dsp/inference/InferenceThread.h
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@ class InferenceThread : public juce::Thread {
RingBuffer& getModelInputBuffer();
RingBuffer& getModelOutputBuffer();
void testInference(InferenceBackend backend);
void testPushSamples(int numSamples);

private:
void run() override;
Expand Down
40 changes: 20 additions & 20 deletions test/benchmark.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -23,51 +23,51 @@ class InferenceFixture : public benchmark::Fixture
{
public:
std::unique_ptr<AudioPluginAudioProcessor> plugin;
std::unique_ptr<juce::AudioBuffer<float>> buffer;
std::unique_ptr<juce::MidiBuffer> midiBuffer;

InferenceFixture() {
}

void SetUp(const ::benchmark::State& state) {
plugin = std::make_unique<AudioPluginAudioProcessor>();
buffer = std::make_unique<juce::AudioBuffer<float>>();
midiBuffer = std::make_unique<juce::MidiBuffer>();
buffer->setSize (2, 512);
buffer->clear();
midiBuffer->clear();
plugin->prepareToPlay (44100, 1024);
plugin->processBlock (*buffer, *midiBuffer);
plugin->prepareToPlay (44100, (int) state.range(0));
}

void TearDown(const ::benchmark::State& state) {
plugin.reset();
buffer.reset();
midiBuffer.reset();
std::cout << "Buffer size: " << state.range(0) << " samples | " << state.range(0) * 1000.f/44100.f << " ms" << std::endl;
}
};

BENCHMARK_DEFINE_F(InferenceFixture, BM_ONNX_INFERENCE)(benchmark::State& st) {
for (auto _ : st) {
BENCHMARK_DEFINE_F(InferenceFixture, BM_ONNX_INFERENCE)(benchmark::State& state) {
for (auto _ : state) {
state.PauseTiming();
plugin->getInferenceThread().testPushSamples((int) state.range(0));
state.ResumeTiming();
plugin->getInferenceThread().testInference(ONNX);
}
}

BENCHMARK_DEFINE_F(InferenceFixture, BM_LIBTORCH_INFERENCE)(benchmark::State& st) {
for (auto _ : st) {
BENCHMARK_DEFINE_F(InferenceFixture, BM_LIBTORCH_INFERENCE)(benchmark::State& state) {
for (auto _ : state) {
state.PauseTiming();
plugin->getInferenceThread().testPushSamples((int) state.range(0));
state.ResumeTiming();
plugin->getInferenceThread().testInference(LIBTORCH);
}
}

BENCHMARK_DEFINE_F(InferenceFixture, BM_TFLITE_INFERENCE)(benchmark::State& st) {
for (auto _ : st) {
BENCHMARK_DEFINE_F(InferenceFixture, BM_TFLITE_INFERENCE)(benchmark::State& state) {
for (auto _ : state) {
state.PauseTiming();
plugin->getInferenceThread().testPushSamples((int) state.range(0));
state.ResumeTiming();
plugin->getInferenceThread().testInference(TFLITE);
}
}

// Register the function as a benchmark
BENCHMARK(BM_PROCESSOR)->Unit(benchmark::kMillisecond);
BENCHMARK(BM_EDITOR)->Unit(benchmark::kMillisecond);
BENCHMARK_REGISTER_F(InferenceFixture, BM_ONNX_INFERENCE)->Unit(benchmark::kMillisecond)->Iterations(10)->Repetitions(1)->Threads(1);
BENCHMARK_REGISTER_F(InferenceFixture, BM_LIBTORCH_INFERENCE)->Unit(benchmark::kMillisecond)->Iterations(10)->Repetitions(1)->Threads(1);
// BENCHMARK_REGISTER_F(InferenceFixture, BM_TFLITE_INFERENCE)->Unit(benchmark::kMillisecond)->Iterations(10)->Repetitions(1)->Threads(1);
BENCHMARK_REGISTER_F(InferenceFixture, BM_ONNX_INFERENCE)->Unit(benchmark::kMillisecond)->Iterations(32)->Repetitions(1)->RangeMultiplier(2)->Range(128, 8<<10);
BENCHMARK_REGISTER_F(InferenceFixture, BM_LIBTORCH_INFERENCE)->Unit(benchmark::kMillisecond)->Iterations(32)->Repetitions(1)->RangeMultiplier(2)->Range(128, 8<<10);
BENCHMARK_REGISTER_F(InferenceFixture, BM_TFLITE_INFERENCE)->Unit(benchmark::kMillisecond)->Iterations(32)->Repetitions(1)->RangeMultiplier(2)->Range(128, 8<<10);

0 comments on commit a202d0f

Please sign in to comment.