From 0e0ad4228cc91150e9c469157ee510e7c31f9701 Mon Sep 17 00:00:00 2001 From: R-Tars Date: Sun, 27 Oct 2024 20:56:03 +0000 Subject: [PATCH] Improve Implementation --- examples/BuddyStableDiffusion/CMakeLists.txt | 22 +- .../buddy-stablediffusion-main.cpp | 339 +++++++++++++----- .../import-stable-diffusion.py | 12 +- frontend/Interfaces/buddy/LLM/TextContainer.h | 4 +- 4 files changed, 272 insertions(+), 105 deletions(-) diff --git a/examples/BuddyStableDiffusion/CMakeLists.txt b/examples/BuddyStableDiffusion/CMakeLists.txt index 44ae01bf76..767eaa069b 100644 --- a/examples/BuddyStableDiffusion/CMakeLists.txt +++ b/examples/BuddyStableDiffusion/CMakeLists.txt @@ -182,11 +182,25 @@ add_custom_command( COMMENT "Building subgraph0_vae.o" VERBATIM) -add_library(STABLEDIFFUSION STATIC subgraph0_text_encoder.o forward_text_encoder.o subgraph0_unet.o forward_unet.o subgraph0_vae.o forward_vae.o) -SET_TARGET_PROPERTIES(STABLEDIFFUSION PROPERTIES LINKER_LANGUAGE C) +add_library(TEXTENCODER STATIC subgraph0_text_encoder.o forward_text_encoder.o) +add_library(UNET STATIC subgraph0_unet.o forward_unet.o) +add_library(VAE STATIC subgraph0_vae.o forward_vae.o) + +target_compile_options(TEXTENCODER PRIVATE -ffunction-sections -fdata-sections -mcmodel=large -mlong-calls) +target_compile_options(UNET PRIVATE -ffunction-sections -fdata-sections -mcmodel=large -mlong-calls) +target_compile_options(VAE PRIVATE -ffunction-sections -fdata-sections -mcmodel=large -mlong-calls) + +SET_TARGET_PROPERTIES(TEXTENCODER PROPERTIES LINKER_LANGUAGE C) +SET_TARGET_PROPERTIES(UNET PROPERTIES LINKER_LANGUAGE C) +SET_TARGET_PROPERTIES(VAE PROPERTIES LINKER_LANGUAGE C) add_executable(buddy-stablediffusion-run buddy-stablediffusion-main.cpp) target_link_directories(buddy-stablediffusion-run PRIVATE ${LLVM_LIBRARY_DIR}) -set(BUDDY_STABLEDIFFUSION_LIBS STABLEDIFFUSION mlir_c_runner_utils) -target_link_libraries(buddy-stablediffusion-run ${BUDDY_STABLEDIFFUSION_LIBS}) +target_link_options(buddy-stablediffusion-run PRIVATE -Wl,--no-relax) + +set(BUDDY_STABLEDIFFUSION_LIBS TEXTENCODER UNET VAE mlir_c_runner_utils) + +target_link_libraries(buddy-stablediffusion-run ${BUDDY_STABLEDIFFUSION_LIBS} ${JPEG_LIBRARIES} ${PNG_LIBRARIES}) +include_directories(${JPEG_INCLUDE_DIRS}) +include_directories(${PNG_INCLUDE_DIRS}) diff --git a/examples/BuddyStableDiffusion/buddy-stablediffusion-main.cpp b/examples/BuddyStableDiffusion/buddy-stablediffusion-main.cpp index 31fde1507a..1c49e79bbe 100644 --- a/examples/BuddyStableDiffusion/buddy-stablediffusion-main.cpp +++ b/examples/BuddyStableDiffusion/buddy-stablediffusion-main.cpp @@ -1,5 +1,8 @@ #include #include +#include +#include +#include #include #include #include @@ -9,7 +12,7 @@ using namespace buddy; -/// Capture input message. +// Capture input message. void getUserInput(std::string &inputStr) { std::cout << "\nPlease input your prompt:" << std::endl; std::cout << ">>> "; @@ -17,8 +20,15 @@ void getUserInput(std::string &inputStr) { std::cout << std::endl; } -void getTimeSteps(float &input) { - std::cout << "\nPlease input timesteps:" << std::endl; +void getInferenceSteps(int &input) { + std::cout << "\nPlease enter the number of inference steps:" << std::endl; + std::cout << ">>> "; + std::cin >> input; + std::cout << std::endl; +} + +void getFileName(std::string &input) { + std::cout << "\nPlease enter the file name of the generated image:" << std::endl; std::cout << ">>> "; std::cin >> input; std::cout << std::endl; @@ -29,11 +39,11 @@ void printIterInfo(size_t iterIdx, double time) { std::cout << "Time: " << time << "s" << std::endl; } -/// Print [Log] label in bold blue format. +// Print [Log] label in bold blue format. void printLogLabel() { std::cout << "\033[34;1m[Log] \033[0m"; } -/// Load parameters into data container. +// Load parameters into data container. void loadParametersInt64(const std::string &int64ParamPath, MemRef &int64Param) { @@ -51,7 +61,7 @@ void loadParametersInt64(const std::string &int64ParamPath, int64ParamFile.close(); } -/// Load parameters into data container. +// Load parameters into data container. void loadParametersFloat(const std::string &floatParamPath, MemRef &floatParam) { std::ifstream floatParamFile(floatParamPath, std::ios::in | std::ios::binary); @@ -72,30 +82,181 @@ void fill_random_normal(MemRef &input, size_t size, unsigned seed) { std::mt19937 generator(seed); std::normal_distribution distribution(0.0f, 1.0f); for (size_t i = 0 ; i < size ; i ++ ) { - input.getData()[i] = distribution(generator); + input.getData()[i] = distribution(generator); + } +} + + +// SchedulerConfig structure, which contains the necessary configuration information +struct SchedulerConfig { + std::string prediction_type; // Prediction type: 'epsilon' or 'v_prediction' + MemRef alphas_cumprod; // Store alpha t + MemRef cur_sample; + float final_alpha_cumprod; // Final alpha + + SchedulerConfig(const MemRef& alphas, const MemRef& sample) + : alphas_cumprod(alphas), cur_sample(sample) {} +}; + +MemRef generate_betas(float beta_start, float beta_end, size_t num_train_timesteps) { + MemRef betas({num_train_timesteps}); + + // Calculate the square root range + float start_sqrt = std::sqrt(beta_start); + float end_sqrt = std::sqrt(beta_end); + for (size_t i = 0; i < num_train_timesteps; i++) { + float t = static_cast(i) / (num_train_timesteps - 1); // Calculate the scale + float value = start_sqrt + t * (end_sqrt - start_sqrt); // Linear interpolation + betas[i] = value * value; // square + } + return betas; +} + +// Auxiliary function: scalar multiplication of multidimensional arrays +MemRef memref_mul_scalar(const MemRef& memref, float scalar) { + MemRef result({1, 4, 64, 64}); + for (int i = 0; i < 1 * 4 * 64 * 64; ++i) { + result[i] = memref[i] * scalar; + } + return result; +} + +// Auxiliary function: Addition of multi-dimensional arrays +MemRef memref_add(const MemRef& a, const MemRef& b) { + MemRef result({1, 4, 64, 64}); + for (int i = 0; i < 1 * 4 * 64 * 64; ++i) { + result[i] = a[i] + b[i]; + } + return result; +} + +// get_prev_sample 函数实现 +MemRef get_prev_sample( + const MemRef& sample, + int timestep, + int prev_timestep, + const MemRef& model_output, + SchedulerConfig& config +) { + float alpha_prod_t = config.alphas_cumprod.getData()[timestep]; + float alpha_prod_t_prev = (prev_timestep >= 0) ? config.alphas_cumprod.getData()[prev_timestep] : config.final_alpha_cumprod; + float beta_prod_t = 1.0f - alpha_prod_t; + float beta_prod_t_prev = 1.0f - alpha_prod_t_prev; + MemRef prev_sample({1, 4, 64, 64}); + + // Processing prediction type + if (config.prediction_type == "v_prediction") { + // v_prediction formula + for (int i = 0; i < 1 * 4 * 64 * 64; ++i) { + prev_sample[i] = + std::sqrt(alpha_prod_t) * model_output[i] + + std::sqrt(beta_prod_t) * sample[i]; + } + } else if (config.prediction_type != "epsilon") { + throw std::invalid_argument("prediction_type must be one of `epsilon` or `v_prediction`"); } + + // Calculate sample_coeff + float sample_coeff = std::sqrt(alpha_prod_t_prev / alpha_prod_t); + + // Calculate model_output_denom_coeff + float model_output_denom_coeff = + alpha_prod_t * std::sqrt(beta_prod_t_prev) + + std::sqrt(alpha_prod_t * beta_prod_t * alpha_prod_t_prev); + // Apply formula (9) to calculate prev_sample + for (int i = 0; i < 1 * 4 * 64 * 64; ++i) { + prev_sample[i] = + sample_coeff * sample[i] - + (alpha_prod_t_prev - alpha_prod_t) * model_output[i] / model_output_denom_coeff; + } + + return prev_sample; +} + +// The core function step_plms performs the inference step +MemRef step_plms( + const MemRef model_output, + int timestep, + MemRef sample, + int num_inference_steps, + SchedulerConfig &config, + int& counter, + std::vector>& ets +) { + int prev_timestep = timestep - 1000 / num_inference_steps; + if (counter != 1) { + if (ets.size() > 3) ets.erase(ets.begin(), ets.begin() + ets.size() - 3); + ets.push_back(model_output); + } else { + prev_timestep = timestep; + timestep += 1000 / num_inference_steps; + } + + MemRef updated_model_output({1, 4, 64, 64}); + + if (ets.size() == 1 && counter == 0) { + updated_model_output = model_output; + config.cur_sample = sample; + } else if (ets.size() == 1 && counter == 1) { + updated_model_output = memref_mul_scalar(memref_add(model_output, ets.back()), 0.5); + sample = config.cur_sample; + } else if (ets.size() == 2) { + updated_model_output = memref_mul_scalar(memref_add(memref_mul_scalar(ets.back(), 3.0), memref_mul_scalar(ets[ets.size() - 2], -1.0)), 0.5); + } else if (ets.size() == 3) { + updated_model_output = memref_mul_scalar(memref_add(memref_add(memref_mul_scalar(ets.back(), 23.0), memref_mul_scalar(ets[ets.size() - 2], -16.0)), memref_mul_scalar(ets[ets.size() - 3], 5.0)), 1.0 / 12.0); + } else { + updated_model_output = memref_mul_scalar(memref_add(memref_add(memref_add(memref_mul_scalar(ets.back(), 55.0), memref_mul_scalar(ets[ets.size() - 2], -59.0)), memref_mul_scalar(ets[ets.size() - 3], 37.0)), memref_mul_scalar(ets[ets.size() - 4], -9.0)), 1.0 / 24.0); + } + + MemRef prev_sample = get_prev_sample(sample, timestep, prev_timestep, updated_model_output, config); + + return prev_sample; +} + +std::vector set_timesteps(int num_inference_steps) { + std::vector timesteps; + std::vector prk_timesteps; + std::vector plms_timesteps; + int step_ratio = 1000 / num_inference_steps; + timesteps.resize(num_inference_steps); + for (int i = 0; i < num_inference_steps; ++i) { + timesteps[i] = static_cast(round(i * step_ratio)) + 1; + } + + prk_timesteps.clear(); + plms_timesteps.resize(timesteps.size() - 1 + 2); + std::copy(timesteps.begin(), timesteps.end() - 1, plms_timesteps.begin()); + if (num_inference_steps > 1) + plms_timesteps[plms_timesteps.size() - 2] = timesteps[timesteps.size() - 2]; + plms_timesteps[plms_timesteps.size() - 1] = timesteps.back(); + std::reverse(plms_timesteps.begin(), plms_timesteps.end()); + + timesteps = prk_timesteps; // Adjust as needed + timesteps.insert(timesteps.end(), plms_timesteps.begin(), plms_timesteps.end()); + if (num_inference_steps == 1) + timesteps.resize(1); + return timesteps; } struct MemRefContainer { MemRef memRef3D; MemRef memRef2D; - MemRefContainer(MemRef m1, MemRef m2) + MemRefContainer(MemRef m1, MemRef m2) : memRef3D(m1), memRef2D(m2) {} }; extern "C" void -_mlir_ciface_forward_text_encoder(MemRefContainer *result, - MemRef *arg0, MemRef *arg1, - MemRef *arg2, MemRef *arg3); +_mlir_ciface_forward_text_encoder(MemRefContainer *result, MemRef *arg0, + MemRef *arg1, MemRef *arg2); extern "C" void -_mlir_ciface_forward_unet(MemRef *result1, +_mlir_ciface_forward_unet(MemRef *result, MemRef *arg0, MemRef *arg1, MemRef *arg2, MemRef *arg3); extern "C" void -_mlir_ciface_forward_vae(MemRef *result1, +_mlir_ciface_forward_vae(MemRef *result, MemRef *arg0, MemRef *arg1); @@ -103,20 +264,21 @@ int main() { const std::string title = "StableDiffusion Inference Powered by Buddy Compiler"; std::cout << "\033[33;1m" << title << "\033[0m" << std::endl; - /// Define directories of vacabulary and parameter file. + // Define directories of vacabulary and parameter file. const std::string vocabDir = "../../examples/BuddyStableDiffusion/vocab_sd.txt"; const std::string TextEncoderParamsDir1 = "../../examples/BuddyStableDiffusion/arg0_text_encoder.data"; const std::string TextEncoderParamsDir2 = "../../examples/BuddyStableDiffusion/arg1_text_encoder.data"; const std::string UnetParamsDir = "../../examples/BuddyStableDiffusion/arg0_unet.data"; const std::string VaeParamsDir = "../../examples/BuddyStableDiffusion/arg0_vae.data"; - /// Get user message. + // Get user message. std::string inputStr; - float timesteps; + std::string image_name; + int InferenceSteps; getUserInput(inputStr); - getTimeSteps(timesteps); - - + getInferenceSteps(InferenceSteps); + getFileName(image_name); + // Define the text_encoder parameter MemRef myMemRef1({1, 77, 1024}); MemRef myMemRef2({1, 1024}); MemRefContainer resultTextEncoderPos(myMemRef1, myMemRef2); @@ -130,26 +292,16 @@ int main() { TextEncoderInputIDsPos.tokenizeStableDiffusion(vocabDir, 77); Text TextEncoderInputIDsNeg(""); TextEncoderInputIDsNeg.tokenizeStableDiffusion(vocabDir, 77); - MemRef attention_mask_pos({1, 77}, 0LL); - MemRef attention_mask_neg({1, 77}, 0LL); - for (int i = 0; i < 77; i++) { - attention_mask_pos.getData()[i] = 1; - if (TextEncoderInputIDsPos.getData()[i] == 49407) break; - } - for (int i = 0; i < 77; i++) { - attention_mask_neg.getData()[i] = 1; - if (TextEncoderInputIDsNeg.getData()[i] == 49407) break; - } - - MemRef resultUnet({1, 4, 96, 96}); + // Define unet parameters + MemRef resultUnet({1, 4, 64, 64}); MemRef arg0_unet({865910724}); - MemRef latents({1, 4, 96, 96}); - MemRef timestep({999}); - - MemRef resultVae({1, 3, 768, 768}); + MemRef latents({1, 4, 64, 64}); + MemRef timestep({1}); + // Define vae parameters + MemRef resultVae({1, 3, 512, 512}); MemRef arg0_vae({49490179}); - + // Loading model parameters printLogLabel(); std::cout << "Loading params..." << std::endl; const auto loadStart = std::chrono::high_resolution_clock::now(); @@ -164,78 +316,72 @@ int main() { << "s\n" << std::endl; - + // encode prompt printLogLabel(); std::cout << "Encoding prompt..." << std::endl; const auto encodeStart = std::chrono::high_resolution_clock::now(); - _mlir_ciface_forward_text_encoder(ptrPos, &arg0_text_encoder, &arg1_text_encoder, &TextEncoderInputIDsPos, &attention_mask_pos); - _mlir_ciface_forward_text_encoder(ptrNeg, &arg0_text_encoder, &arg1_text_encoder, &TextEncoderInputIDsNeg, &attention_mask_neg); + _mlir_ciface_forward_text_encoder(ptrPos, &arg0_text_encoder, &arg1_text_encoder, &TextEncoderInputIDsPos); + _mlir_ciface_forward_text_encoder(ptrNeg, &arg0_text_encoder, &arg1_text_encoder, &TextEncoderInputIDsNeg); const auto encodeEnd = std::chrono::high_resolution_clock::now(); const std::chrono::duration encodeTime = encodeEnd - encodeStart; printLogLabel(); - std::cout << "Encode prompt time: " << (double)(encodeTime.count()) / 1000 + std::cout << "Prompt encode time: " << (double)(encodeTime.count()) / 1000 << "s\n" << std::endl; + // Concatenation of Positive and Negative embeddings auto TextEncoderOutPos = ptrPos->memRef3D; auto TextEncoderOutNeg = ptrNeg->memRef3D; for (int i = 0; i < 2 * 77 * 1024 ; i ++ ){ if (i < 1 * 77 * 1024) - TextEncoderOut.getData()[i] = TextEncoderOutPos.getData()[i]; + TextEncoderOut.getData()[i] = TextEncoderOutNeg.getData()[i]; else - TextEncoderOut.getData()[i] = TextEncoderOutNeg.getData()[i % (1 * 77 * 1024)]; - } - std::ofstream outFileTextEncoder("../../examples/BuddyStableDiffusion/output_text_encoder.txt"); - if (outFileTextEncoder.is_open()) { - for (int i = 0 ; i < 2 * 77 * 1024 ; i ++ ) { - outFileTextEncoder << TextEncoderOut.getData()[i] << std::endl; - } - outFileTextEncoder.close(); - } else { - std::cerr << "Unable to open file" << std::endl; + TextEncoderOut.getData()[i] = TextEncoderOutPos.getData()[i % (1 * 77 * 1024)]; } - - fill_random_normal(latents, 1 * 4 * 96 * 96, 12345); + // Generate initial noise + fill_random_normal(latents, 1 * 4 * 64 * 64, 42); printLogLabel(); std::cout << "Start denoising..." << std::endl; - - for (int i = 1; i <= timesteps ; i ++){ - MemRef noise({2, 4, 96, 96}); - for (int j = 0 ; j < 2 * 4 * 96 * 96 ; j ++ ) - noise.getData()[j] = latents.getData()[j % (1 * 4 * 96 * 96)]; - + // Set config + MemRef alphas_cumprod({1000}); + MemRef cur_sample({1, 4, 64, 64}); + SchedulerConfig config(alphas_cumprod, cur_sample); + alphas_cumprod = generate_betas(0.00085, 0.012, 1000); + for (int i = 0 ; i < 1000 ; i ++ ){ + alphas_cumprod.getData()[i] = 1.0 - alphas_cumprod.getData()[i]; + if (i >= 1) + alphas_cumprod.getData()[i] = alphas_cumprod.getData()[i] * alphas_cumprod.getData()[i - 1]; + config.alphas_cumprod.getData()[i] = alphas_cumprod.getData()[i]; + } + config.final_alpha_cumprod = config.alphas_cumprod.getData()[0]; + config.prediction_type = "epsilon"; + std::vector> ets; + auto timesteps = set_timesteps(InferenceSteps); + + // Denoising loop + for (int i = 0; i < (int)timesteps.size() ; i ++){ + MemRef noise({2, 4, 64, 64}); + for (int j = 0 ; j < 2 * 4 * 64 * 64 ; j ++ ) + noise.getData()[j] = latents.getData()[j % (1 * 4 * 64 * 64)]; + + timestep.getData()[0] = timesteps[i]; const auto inferenceStart = std::chrono::high_resolution_clock::now(); _mlir_ciface_forward_unet(&resultUnet, &arg0_unet, &noise, ×tep, &TextEncoderOut); const auto inferenceEnd = std::chrono::high_resolution_clock::now(); const std::chrono::duration inferenceTime = inferenceEnd - inferenceStart; printIterInfo(i, inferenceTime.count() / 1000); - - - std::ofstream outFileUnet("../../examples/BuddyStableDiffusion/output_unet.txt"); - if (outFileUnet.is_open()) { - for (int j = 0 ; j < 1 * 4 * 96 * 96 ; j ++ ) { - outFileUnet << resultUnet.getData()[j] << std::endl; - } - outFileUnet.close(); - } else { - std::cerr << "Unable to open file" << std::endl; + MemRef pred_noise({2, 4, 64, 64}); + for (int j = 0 ; j < 1 * 4 * 64 * 64; j ++){ + pred_noise.getData()[j] = resultUnet.getData()[j] + 7.5 * (resultUnet.getData()[j + 1 * 4 * 64 * 64] - resultUnet.getData()[j]); } - - MemRef pred_noise({1, 4, 96, 96}); - for (int j = 0 ; j < 1 * 4 * 96 * 96; j ++){ - pred_noise.getData()[i] = resultUnet.getData()[i] + 7.5 * (resultUnet.getData()[i + 1 * 4 * 96 * 96] - resultUnet.getData()[i]); - } - - //There is a scheduler that still needs to be implemented. - // latents = schedulerStep(pred_noise, timestep, latents) - + latents = step_plms(pred_noise, timesteps[i], latents, InferenceSteps, config, i, ets); } - for (int i = 0 ; i < 1 * 4 * 96 * 96 ; i ++){ + for (int i = 0 ; i < 1 * 4 * 64 * 64 ; i ++){ latents.getData()[i] = latents.getData()[i] / 0.18215; } - + // decode std::cout << std::endl; printLogLabel(); std::cout << "Start decoding..." << std::endl; @@ -248,27 +394,34 @@ int main() { << "s\n" << std::endl; - std::ofstream outFileVae("../../examples/BuddyStableDiffusion/output_vae.txt"); - if (outFileVae.is_open()) { - for (int i = 0 ; i < 1 * 3 * 768 * 768 ; i ++ ) { - outFileVae << resultVae.getData()[i] << std::endl; - } - outFileVae.close(); - } else { - std::cerr << "Unable to open file" << std::endl; - } - - for (int i = 0 ; i < 1 * 3 * 768 * 768 ; i ++ ){ + for (int i = 0 ; i < 1 * 3 * 512 * 512 ; i ++ ){ resultVae.getData()[i] = (resultVae.getData()[i] + 1) / 2; - //clamp(0, 1) + // clamp(0, 1) if (resultVae.getData()[i] < 0) resultVae.getData()[i] = 0; if (resultVae.getData()[i] > 1) resultVae.getData()[i] = 1; resultVae.getData()[i] = resultVae.getData()[i] * 255; } + intptr_t sizes[4] = {512, 512, 3, 1}; + Img img(sizes); + + // Rearrange the images + for (int i = 0 ; i < 3 * 512 * 512 ; i += 3 ){ + img.getData()[i] = resultVae.getData()[i / 3 + 512 * 512 * 2]; + img.getData()[i + 1] = resultVae.getData()[i / 3 + 512 * 512 * 1]; + img.getData()[i + 2] = resultVae.getData()[i / 3 + 512 * 512 * 0]; + } + + String filename = "../../examples/BuddyStableDiffusion/" + image_name + ".png"; + // Call the imwrite function + bool success = imwrite(filename, img); - // The conversion of data to the image part still needs to be implemented. + if (success) { + std::cout << "Image saved successfully to " << filename << std::endl; + } else { + std::cerr << "Failed to save the image." << std::endl; + } return 0; } \ No newline at end of file diff --git a/examples/BuddyStableDiffusion/import-stable-diffusion.py b/examples/BuddyStableDiffusion/import-stable-diffusion.py index c38dde3dfe..a48056816d 100644 --- a/examples/BuddyStableDiffusion/import-stable-diffusion.py +++ b/examples/BuddyStableDiffusion/import-stable-diffusion.py @@ -13,7 +13,7 @@ from diffusers import StableDiffusionPipeline device = torch.device("cuda") -model_id = "stabilityai/stable-diffusion-2-1" +model_id = "stabilityai/stable-diffusion-2-1-base" pipe = StableDiffusionPipeline.from_pretrained( model_id, torch_dtype=torch.float32 ) @@ -24,7 +24,7 @@ text_encoder = pipe.text_encoder.forward unet = pipe.unet.forward -vae = pipe.vae.decoder +vae = pipe.vae.decode # Initialize Dynamo Compiler with specific configurations as an importer. @@ -50,19 +50,19 @@ prompt, return_tensors="pt", padding="max_length" ).to(device) data_unet = { - "sample": torch.ones((1, 4, 96, 96), dtype=torch.float32).to(device), + "sample": torch.ones((2, 4, 64, 64), dtype=torch.float32).to(device), "timestep": torch.tensor([1], dtype=torch.float32).to(device), - "encoder_hidden_states": torch.ones((1, 77, 1024), dtype=torch.float32).to( + "encoder_hidden_states": torch.ones((2, 77, 1024), dtype=torch.float32).to( device ), } -data_vae = torch.ones((1, 4, 96, 96), dtype=torch.float32).to(device) +data_vae = torch.ones((1, 4, 64, 64), dtype=torch.float32).to(device) # Import the model into MLIR module and parameters. with torch.no_grad(): graphs_text_encoder = dynamo_compiler_text_encoder.importer( - text_encoder, **data_text_encoder + text_encoder, data_text_encoder["input_ids"].to(device), None ) graphs_unet = dynamo_compiler_unet.importer(unet, **data_unet) graphs_vae = dynamo_compiler_vae.importer(vae, data_vae) diff --git a/frontend/Interfaces/buddy/LLM/TextContainer.h b/frontend/Interfaces/buddy/LLM/TextContainer.h index a02bf8b67f..ffa5c03446 100644 --- a/frontend/Interfaces/buddy/LLM/TextContainer.h +++ b/frontend/Interfaces/buddy/LLM/TextContainer.h @@ -335,7 +335,7 @@ void Text::tokenizeStableDiffusion(const std::string &vocab, size_t length size_t size = this->product(this->sizes); this->allocated = (T *)malloc(sizeof(T) * size); this->aligned = this->allocated; - this->pad = 49407; + this->pad = 0; this->unk = 49407; this->bos = 49406; this->eos = 49407; @@ -347,7 +347,7 @@ void Text::tokenizeStableDiffusion(const std::string &vocab, size_t length std::string token; for (size_t i = 0; i < str.size(); ++i) { - char c = str[i]; + char c = tolower(str[i]); // Special match cases if (str.substr(i, 15) == "<|startoftext|>" || str.substr(i, 13) == "<|endoftext|>") { if (!token.empty()) {