diff --git a/crates/burn-wgpu/src/runtime.rs b/crates/burn-wgpu/src/runtime.rs index f1a7418534..4267b9faf4 100644 --- a/crates/burn-wgpu/src/runtime.rs +++ b/crates/burn-wgpu/src/runtime.rs @@ -43,7 +43,7 @@ impl Runtime for WgpuRuntime ComputeClient { RUNTIME.client(device, move || { - pollster::block_on(create_client::(device)) + pollster::block_on(create_client::(device, RuntimeOptions::default())) }) } @@ -52,16 +52,52 @@ impl Runtime for WgpuRuntime Self { + let max_tasks = match std::env::var("BURN_WGPU_MAX_TASKS") { + Ok(value) => value + .parse::() + .expect("BURN_WGPU_MAX_TASKS should be a positive integer."), + Err(_) => 64, // 64 tasks by default + }; + + Self { + dealloc_strategy: DeallocStrategy::new_period_tick(max_tasks * 2), + slice_strategy: SliceStrategy::Ratio(0.8), + max_tasks, + } + } +} + +/// Init the client sync, useful to configure the runtime options. +pub fn init_sync(device: &WgpuDevice, options: RuntimeOptions) { + let device = Arc::new(device); + let client = pollster::block_on(create_client::(&device, options)); + + RUNTIME.register(&device, client) +} + /// Init the client async, necessary for wasm. -pub async fn init_async(device: &WgpuDevice) { +pub async fn init_async(device: &WgpuDevice, options: RuntimeOptions) { let device = Arc::new(device); - let client = create_client::(&device).await; + let client = create_client::(&device, options).await; RUNTIME.register(&device, client) } async fn create_client( device: &WgpuDevice, + options: RuntimeOptions, ) -> ComputeClient< WgpuServer>, MutexComputeChannel>>, @@ -74,22 +110,11 @@ async fn create_client( info ); - // TODO: Support a way to modify max_tasks without std. - let max_tasks = match std::env::var("BURN_WGPU_MAX_TASKS") { - Ok(value) => value - .parse::() - .expect("BURN_WGPU_MAX_TASKS should be a positive integer."), - Err(_) => 64, // 64 tasks by default - }; - let device = Arc::new(device_wgpu); let storage = WgpuStorage::new(device.clone()); - let memory_management = SimpleMemoryManagement::new( - storage, - DeallocStrategy::new_period_tick(max_tasks * 2), - SliceStrategy::Ratio(0.8), - ); - let server = WgpuServer::new(memory_management, device, queue, max_tasks); + let memory_management = + SimpleMemoryManagement::new(storage, options.dealloc_strategy, options.slice_strategy); + let server = WgpuServer::new(memory_management, device, queue, options.max_tasks); let channel = MutexComputeChannel::new(server); let tuner_device_id = tuner_device_id(info); diff --git a/examples/image-classification-web/src/web.rs b/examples/image-classification-web/src/web.rs index 3238f2b4ea..6fc7d3f8f5 100644 --- a/examples/image-classification-web/src/web.rs +++ b/examples/image-classification-web/src/web.rs @@ -106,7 +106,7 @@ impl ImageClassifier { log::info!("Loading the model to the Wgpu backend"); let start = Instant::now(); let device = WgpuDevice::default(); - init_async::(&device).await; + init_async::(&device, Default::default()).await; self.model = ModelType::WithWgpuBackend(Model::new(&device)); let duration = start.elapsed(); log::debug!("Model is loaded to the Wgpu backend in {:?}", duration);