m1 mac gpu #42
Is there a way to specify gpu use for m1 macs? It's using my cpu when generating, but I have more than 8 GB of memory. I'm running the command:

```sh
cargo run --example stable-diffusion --features clap -- --prompt "A very rusty robot holding a fire torch." --cpu all --sd-version v1-5
```

Comments
Hi, I think that just setting the device to MPS should do the trick, together with these environment variables so the build picks up the libtorch that ships with the Python torch package:

```sh
export LIBTORCH=$(python -c 'import torch; from pathlib import Path; print(Path(torch.__file__).parent)')
export DYLD_LIBRARY_PATH=${LIBTORCH}/lib
export LIBTORCH_CXX11_ABI=0
```

I don't have an M1 mac so I can't try it myself, but I hope it works. Alternatively, here you can find a colab notebook to use diffusers-rs with cuda.
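Once it compiles, a quick sanity check is to allocate a tensor directly on the MPS device. This is a minimal sketch, assuming a tch release that exposes `Device::Mps` (not every version does):

```rust
use tch::{Device, Kind, Tensor};

fn main() {
    // If the LIBTORCH build supports MPS, this allocates on the
    // Apple GPU; otherwise it fails with a libtorch error.
    let t = Tensor::ones(&[2, 2], (Kind::Float, Device::Mps));
    println!("{:?}", t);
}
```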
I got this working, but it took a few more steps. Setting those exports is enough to get it to compile and execute, but only on the CPU, and changing the device to `Device::Mps` on its own then fails while loading the weights.
But there is a workaround suggested upstream (for a bug further upstream), and that works:

```diff
diff --git a/examples/stable-diffusion/main.rs b/examples/stable-diffusion/main.rs
index 98d99e5..477518f 100644
--- a/examples/stable-diffusion/main.rs
+++ b/examples/stable-diffusion/main.rs
@@ -230,7 +230,7 @@ fn run(args: Args) -> anyhow::Result<()> {
             stable_diffusion::StableDiffusionConfig::v2_1(sliced_attention_size)
         }
     };
-    let cuda_device = Device::cuda_if_available();
+    let cuda_device = Device::Mps;
     let cpu_or_cuda = |name: &str| {
         if cpu.iter().any(|c| c == "all" || c == name) {
             Device::Cpu
diff --git a/src/pipelines/stable_diffusion.rs b/src/pipelines/stable_diffusion.rs
index e5a5813..bb0c65c 100644
--- a/src/pipelines/stable_diffusion.rs
+++ b/src/pipelines/stable_diffusion.rs
@@ -97,10 +97,12 @@ impl StableDiffusionConfig {
         vae_weights: &str,
         device: Device,
     ) -> anyhow::Result<vae::AutoEncoderKL> {
-        let mut vs_ae = nn::VarStore::new(device);
+        let mut vs_ae = nn::VarStore::new(tch::Device::Mps);
         // https://huggingface.co/runwayml/stable-diffusion-v1-5/blob/main/vae/config.json
         let autoencoder = vae::AutoEncoderKL::new(vs_ae.root(), 3, 3, self.autoencoder.clone());
+        vs_ae.set_device(tch::Device::Cpu);
         vs_ae.load(vae_weights)?;
+        vs_ae.set_device(tch::Device::Mps);
         Ok(autoencoder)
     }
@@ -110,10 +112,12 @@ impl StableDiffusionConfig {
         device: Device,
         in_channels: i64,
     ) -> anyhow::Result<unet_2d::UNet2DConditionModel> {
-        let mut vs_unet = nn::VarStore::new(device);
+        let mut vs_unet = nn::VarStore::new(tch::Device::Mps);
         let unet =
             unet_2d::UNet2DConditionModel::new(vs_unet.root(), in_channels, 4, self.unet.clone());
+        vs_unet.set_device(tch::Device::Cpu);
         vs_unet.load(unet_weights)?;
+        vs_unet.set_device(tch::Device::Mps);
         Ok(unet)
     }
@@ -126,9 +130,11 @@ impl StableDiffusionConfig {
         clip_weights: &str,
         device: tch::Device,
     ) -> anyhow::Result<clip::ClipTextTransformer> {
-        let mut vs = tch::nn::VarStore::new(device);
+        let mut vs = tch::nn::VarStore::new(tch::Device::Mps);
         let text_model = clip::ClipTextTransformer::new(vs.root(), &self.clip);
+        vs.set_device(tch::Device::Cpu);
         vs.load(clip_weights)?;
+        vs.set_device(tch::Device::Mps);
         Ok(text_model)
     }
 }
```

Obviously hardcoding the device isn't what you'd want to do in the actual project, but it works if you just want to get something running locally. It looks like tch-rs might provide this workaround in that crate itself, so you may want to just wait for that to land and get released.
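The pattern in that diff can be factored into a small helper: create the variables on the target device, hop to the CPU to load the weights (since loading directly on MPS is what fails), then move the store back. A sketch under that assumption; the helper name `load_weights_via_cpu` is made up here and is not part of tch-rs:

```rust
use tch::{nn, Device};

/// Load `weights` into `vs` so its variables ultimately live on `target`.
/// Works around weight loading failing directly on the MPS device by
/// loading on the CPU first and moving the store back afterwards.
fn load_weights_via_cpu(
    vs: &mut nn::VarStore,
    weights: &str,
    target: Device,
) -> anyhow::Result<()> {
    vs.set_device(Device::Cpu);
    vs.load(weights)?;
    vs.set_device(target);
    Ok(())
}
```

Each of the three load functions in the diff could then call this instead of repeating the set_device dance.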
The mps changes on the tch-rs side have been released (PR-623), and I've published a new version of the crate.
@LaurentMazare Sweet! Would it be possible to update the logic to default to the MPS device when it's available?
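For reference, a sketch of what such a default could look like, assuming tch exposes an MPS availability check (`tch::utils::has_mps` is used below; whether the released version ships that exact helper is not confirmed in this thread):

```rust
use tch::Device;

// Hypothetical default: prefer CUDA, then the Apple GPU, then the CPU.
fn default_device() -> Device {
    if tch::Cuda::is_available() {
        Device::cuda_if_available()
    } else if tch::utils::has_mps() {
        Device::Mps
    } else {
        Device::Cpu
    }
}
```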
Closing this as the related PR has been merged for a while; feel free to re-open if it's still an issue (I don't have a mac at hand to test).