diff --git a/intel_extension_for_pytorch/_inductor/xpu/triton_heuristics.py b/intel_extension_for_pytorch/_inductor/xpu/triton_heuristics.py index 960793820..7847fc0b7 100644 --- a/intel_extension_for_pytorch/_inductor/xpu/triton_heuristics.py +++ b/intel_extension_for_pytorch/_inductor/xpu/triton_heuristics.py @@ -212,7 +212,7 @@ def launcher({', '.join(def_args)}, grid, stream): launcher.config = cfg launcher.n_regs = getattr(binary, "n_regs", None) launcher.n_spills = getattr(binary, "n_spills", None) - launcher.shared = getattr(binary, "shared", None) + launcher.shared = scope["shared"] launcher.store_cubin = False # store this global varible to avoid the high overhead of reading it when calling run if launcher.store_cubin: