Skip to content

Commit

Permalink
wip
Browse files Browse the repository at this point in the history
  • Loading branch information
sachaarbonel committed Nov 5, 2024
1 parent 3475175 commit b0b95ab
Show file tree
Hide file tree
Showing 4 changed files with 268 additions and 175 deletions.
2 changes: 1 addition & 1 deletion examples/server/script.js
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@ export const options = {
scenarios: {
burst: {
executor: "shared-iterations",
vus: 4,
vus: 10,
iterations: 100,
maxDuration: "1m",
},
Expand Down
209 changes: 98 additions & 111 deletions examples/server/server.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -713,138 +713,125 @@ int main(int argc, char ** argv) {
// POST /inference handler
svr.Post(sparams.request_path + sparams.inference_path,
[&](const httplib::Request& req_, httplib::Response& res_) {
auto total_start = std::chrono::steady_clock::now();
const auto request_start = std::chrono::steady_clock::now();

std::cout << "[" << current_timestamp() << "] POST " << sparams.inference_path << "\n";
// Log initial state before processing
if (params.debug_mode) {
std::cout << "[" << current_timestamp() << "] POST " << sparams.inference_path << "\n"
<< thread_pool.get_stats() << "\n";
}

try {
// 1. Request validation and parsing
auto parse_start = std::chrono::steady_clock::now();
// Parse request and prepare data before queuing
if (!req_.has_file("file")) {
res_.status = 400;
res_.set_content(R"({"error":"Missing 'file' in request"})", "application/json");
return;
}

if (!req_.has_file("file")) {
res_.status = 400;
res_.set_content(R"({"error":"Missing 'file' in request"})", "application/json");
return;
}
const auto audio_file = req_.get_file_value("file");
if (audio_file.content.size() > MAX_UPLOAD_SIZE) {
res_.status = 413;
res_.set_content(R"({"error":"File too large"})", "application/json");
return;
}

auto audio_file = req_.get_file_value("file");
if (audio_file.content.size() > MAX_UPLOAD_SIZE) {
res_.status = 413;
res_.set_content(R"({"error":"File too large"})", "application/json");
return;
}
// Parse WAV data
std::vector<float> pcmf32;
std::vector<std::vector<float>> pcmf32s;
if (!::read_wav(audio_file.content, pcmf32, pcmf32s, params.diarize)) {
res_.status = 400;
res_.set_content(R"({"error":"Failed to read audio"})", "application/json");
return;
}

auto parse_end = std::chrono::steady_clock::now();
auto parse_duration = std::chrono::duration_cast<std::chrono::milliseconds>(
parse_end - parse_start).count();
// Create promise/future for the result
std::promise<json> result_promise;
auto result_future = result_promise.get_future();

if (params.debug_mode) {
std::cout << "[" << current_timestamp() << "] Request parsing completed in "
<< parse_duration << "ms, file size: "
<< (audio_file.content.size() / 1024) << "KB\n";
}
// Now enqueue the actual processing work
const auto queue_start = std::chrono::steady_clock::now();

// 2. WAV parsing
auto wav_start = std::chrono::steady_clock::now();
std::vector<float> pcmf32;
std::vector<std::vector<float>> pcmf32s;
thread_pool.enqueue([queue_start,
request_start,
pcmf32 = std::move(pcmf32),
pcmf32s = std::move(pcmf32s),
&pool,
&params,
result_promise = std::move(result_promise)]() mutable {
try {
// Try to acquire a whisper instance
auto instance = pool->get_instance();
if (!instance) {
json error_response = {
{"error", "No available instances"},
{"status", 503}
};
result_promise.set_value(error_response);
return;
}

if (!::read_wav(audio_file.content, pcmf32, pcmf32s, params.diarize)) {
res_.status = 400;
res_.set_content(R"({"error":"Failed to read audio"})", "application/json");
return;
}
const auto processing_start = std::chrono::steady_clock::now();
auto queue_time = std::chrono::duration_cast<std::chrono::milliseconds>(
processing_start - queue_start).count();

auto wav_end = std::chrono::steady_clock::now();
auto wav_duration = std::chrono::duration_cast<std::chrono::milliseconds>(
wav_end - wav_start).count();
// Process audio
auto process_result = process_audio(instance->ctx->get(), params, pcmf32, pcmf32s);

if (params.debug_mode) {
std::cout << "[" << current_timestamp() << "] WAV parsing completed in "
<< wav_duration << "ms\n";
}
// Release the instance
pool->release_instance(instance);

// 3. Instance acquisition
auto acquire_start = std::chrono::steady_clock::now();
auto instance = pool->get_instance();
auto acquire_end = std::chrono::steady_clock::now();
const auto processing_end = std::chrono::steady_clock::now();
auto processing_time = std::chrono::duration_cast<std::chrono::milliseconds>(
processing_end - processing_start).count();

if (!instance) {
res_.status = 503;
res_.set_content(R"({"error":"No available instances"})", "application/json");
return;
}

auto acquire_duration = std::chrono::duration_cast<std::chrono::milliseconds>(
acquire_end - acquire_start).count();

if (params.debug_mode) {
std::cout << "[" << current_timestamp() << "] Instance " << instance->id
<< " acquired by thread " << std::this_thread::get_id()
<< " after waiting " << acquire_duration << "ms\n";
}
// Prepare response with timing information
json response;
if (params.response_format == "json") {
response = json::parse(process_result);
} else {
response["text"] = process_result;
}

// 4. Audio processing
auto process_start = std::chrono::steady_clock::now();
auto result = process_audio(instance->ctx->get(), params, pcmf32, pcmf32s);
auto process_end = std::chrono::steady_clock::now();

auto process_duration = std::chrono::duration_cast<std::chrono::milliseconds>(
process_end - process_start).count();

if (process_duration > PROCESSING_TIMEOUT_MS) {
if (params.debug_mode) {
std::cout << "[" << current_timestamp() << "] Processing timeout after "
<< process_duration << "ms\n";
}
pool->release_instance(instance);
res_.status = 504; // Gateway Timeout
res_.set_content(R"({"error":"Processing timeout"})", "application/json");
return;
}
response["queue_time_ms"] = queue_time;
response["processing_time_ms"] = processing_time;
response["total_time_ms"] = std::chrono::duration_cast<std::chrono::milliseconds>(
processing_end - request_start).count();
response["status"] = 200;

// 5. Instance release
auto release_start = std::chrono::steady_clock::now();
pool->release_instance(instance);
auto release_end = std::chrono::steady_clock::now();

if (params.debug_mode) {
std::cout << "[" << current_timestamp() << "] Instance " << instance->id
<< " released by thread " << std::this_thread::get_id()
<< ", release took "
<< std::chrono::duration_cast<std::chrono::milliseconds>(
release_end - release_start).count() << "ms\n";
}
result_promise.set_value(response);

// 6. Response preparation and sending
json response;
if (params.response_format == "json") {
response = json::parse(result);
} else {
response["text"] = result;
} catch (const std::exception& e) {
result_promise.set_value({
{"error", "Internal server error"},
{"status", 500}
});
}
});

// Wait for the result with timeout
if (result_future.wait_for(std::chrono::milliseconds(PROCESSING_TIMEOUT_MS))
== std::future_status::timeout) {
res_.status = 504;
res_.set_content(R"({"error":"Processing timeout"})", "application/json");
return;
}

// Get the result and send response
json result = result_future.get();
int status = result["status"].get<int>();
result.erase("status");

res_.set_content(response.dump(), "application/json");
res_.status = status;
res_.set_content(result.dump(), "application/json");

auto total_end = std::chrono::steady_clock::now();
if (params.debug_mode) {
std::cout << "[" << current_timestamp() << "] Request completed in "
<< std::chrono::duration_cast<std::chrono::milliseconds>(
total_end - total_start).count() << "ms\n";
}
// Final stats after processing
if (params.debug_mode) {
const auto total_duration = std::chrono::duration_cast<std::chrono::milliseconds>(
std::chrono::steady_clock::now() - request_start).count();

} catch (const std::exception& e) {
auto error_time = std::chrono::steady_clock::now();
if (params.debug_mode) {
std::cout << "[" << current_timestamp() << "] Error processing request after "
<< std::chrono::duration_cast<std::chrono::milliseconds>(
error_time - total_start).count()
<< "ms: " << e.what() << "\n";
}
res_.status = 500;
res_.set_content(R"({"error":"Internal server error"})", "application/json");
std::cout << "[" << current_timestamp() << "] Request completed. Final stats:\n"
<< thread_pool.get_stats() << "\n"
<< "Total request time: " << total_duration << "ms\n\n";
}
});

Expand Down
111 changes: 82 additions & 29 deletions examples/server/thread_pool.cpp
Original file line number Diff line number Diff line change
@@ -1,43 +1,96 @@
#include "thread_pool.h"

// NOTE(review): this span appears to interleave the removed and the added lines of a
// diff for thread_pool.cpp (two constructor headers, two worker loops, and the old
// destructor are all present), so it is not compilable as shown. The comments below
// describe the individual statements; read them against whichever version each line
// belongs to.
//
// Constructor: starts `threads` worker threads and stores them in `workers`. Each worker
// repeatedly waits on `condition` until `stop` is set or `tasks` is non-empty, returns
// once `stop` is set and the queue is empty, and otherwise pops and runs the front task.
ThreadPool::ThreadPool(size_t threads)
: stop(false)
{
for (size_t i = 0; i < threads; ++i)
workers.emplace_back(
[this] {
for (;;) {
std::function<void()> task;

{
std::unique_lock<std::mutex> lock(this->queue_mutex);
this->condition.wait(
lock, [this] { return this->stop || !this->tasks.empty(); });
if (this->stop && this->tasks.empty())
return;
task = std::move(this->tasks.front());
this->tasks.pop();
ThreadPool::ThreadPool(size_t threads) : stop(false) {
for(size_t i = 0; i < threads; ++i) {
workers.emplace_back([this] {
const auto thread_id = std::this_thread::get_id();
std::cout << "[" << get_current_time() << "] Worker " << thread_id
<< " started\n";

while(true) {
std::function<void()> task;
// Timestamp stored alongside the task in the queue; used below to measure queue wait.
std::chrono::steady_clock::time_point enqueue_time;
{
std::unique_lock<std::mutex> lock(this->queue_mutex);
// Sleep until there is work to take or shutdown has been requested.
this->condition.wait(lock, [this] {
return this->stop || !this->tasks.empty();
});

// Only exit once the queue is empty, so tasks queued before `stop` was set still run.
if(this->stop && this->tasks.empty()) {
std::cout << "[" << get_current_time() << "] Worker " << thread_id
<< " shutting down\n";
return;
}

task();
// Queue entries are read as a pair: `.first` is the callable, `.second` its enqueue time.
task = std::move(this->tasks.front().first);
enqueue_time = this->tasks.front().second;
this->tasks.pop();

// active_tasks is only incremented after this locked scope, hence the "+ 1" in the
// log line (presumably to count this worker as already active).
std::cout << "[" << get_current_time() << "] Worker " << thread_id
<< " dequeued task. Queue size now: " << this->tasks.size()
<< ", Active workers: " << this->active_tasks.load() + 1 << "\n";
}
});
}

// The destructor joins all threads
ThreadPool::~ThreadPool()
{
shutdown();
active_tasks++;
// Wait time = time between enqueue and the moment this worker starts the task, in ms.
auto start = std::chrono::steady_clock::now();
auto wait_time = std::chrono::duration_cast<std::chrono::milliseconds>(
start - enqueue_time).count();

total_wait_time.fetch_add(wait_time);

std::cout << "[" << get_current_time() << "] Worker " << thread_id
<< " starting task. Waited: " << wait_time << "ms\n";

// NOTE(review): no try/catch is visible around this call. If the task throws, the
// counters below are not updated and the exception leaves the thread function.
task();

auto process_time = std::chrono::duration_cast<std::chrono::milliseconds>(
std::chrono::steady_clock::now() - start).count();

// Accumulate per-task timing and counts; get_stats() reports values derived from these.
total_processing_time.fetch_add(process_time);
total_tasks_processed++;
active_tasks--;

std::cout << "[" << get_current_time() << "] Worker " << thread_id
<< " completed task. Processing time: " << process_time << "ms\n";
}
});
}

std::cout << "[" << get_current_time() << "] Thread pool initialized with "
<< threads << " workers\n";
}

// NOTE(review): this span appears to interleave the removed and the added lines of a
// diff (both signatures, both lock types and both join loops are present), so it is
// not compilable as shown.
//
// Stops the pool: sets `stop` while holding queue_mutex, wakes every worker waiting on
// `condition`, then joins each joinable thread in `workers`. The start and the end of
// the shutdown are logged to std::cout.
void ThreadPool::shutdown()
{
void ThreadPool::shutdown() {
std::cout << "[" << get_current_time() << "] Initiating thread pool shutdown\n";
{
// `stop` is written under the same mutex the workers hold while evaluating their
// wait predicate; the lock is released at the end of this scope, before notifying.
std::unique_lock<std::mutex> lock(queue_mutex);
std::lock_guard<std::mutex> lock(queue_mutex);
stop = true;
}
// Wake all workers so each one re-checks `stop`.
condition.notify_all();
for (std::thread& worker : workers)
if (worker.joinable())
for(auto &worker: workers) {
// Threads that are not joinable are skipped.
if(worker.joinable()) {
worker.join();
}
}
std::cout << "[" << get_current_time() << "] Thread pool shutdown complete\n";
}

// Add the missing get_stats implementation
std::string ThreadPool::get_stats() const {
std::ostringstream oss;
size_t queued;
{
std::lock_guard<std::mutex> lock(queue_mutex);
queued = tasks.size();
}

oss << "ThreadPool Stats:\n"
<< " Total Workers: " << get_total_workers() << "\n"
<< " Active Workers: " << get_active_workers() << "\n"
<< " Tasks Queued: " << queued << "\n"
<< " Tasks Processed: " << get_total_tasks_processed() << "\n"
<< " Avg Wait Time: " << get_average_wait_time() << "ms\n"
<< " Avg Processing Time: " << get_average_processing_time() << "ms\n";

return oss.str();
}
Loading

0 comments on commit b0b95ab

Please sign in to comment.