flute/integrations/huggingface and utils: updated
HanGuo97 committed Jan 7, 2025
1 parent ba0c3e4 · commit 2a7ff14
Showing 2 changed files with 27 additions and 10 deletions.
flute/integrations/huggingface.py: 18 additions & 1 deletion
```diff
@@ -169,6 +169,7 @@ def replace_with_flute_linear(
 
 
 def _repack_flute_linear(model: torch.nn.Module, quantization_config: FluteConfig) -> None:
+    import flute.tune
     import flute.utils
     from flute.integrations.base import FluteLinear
 
@@ -190,10 +191,11 @@ def _repack_flute_linear(model: torch.nn.Module, quantization_config: FluteConfig) -> None:
                 workspace=flute.utils.get_workspace_streamk(device),
                 num_bits=module.num_bits,
                 group_size=module.group_size,
+                template_id_packed=module.template_id,
                 num_sms_packed=quantization_config.num_sms_packed)
 
             # re-pack the tensors
-            Q_repacked = flute.utils.pack(
+            Q_repacked, tune_metadata = flute.tune.tune_and_pack(
                 Q_unpacked.T.contiguous().to(device="cpu"),
                 num_bits=module.num_bits,
                 group_size=module.group_size).to(device=module.weight.device)
@@ -219,6 +221,7 @@ def _repack_flute_linear(model: torch.nn.Module, quantization_config: FluteConfig) -> None:
 
             module.weight = Q_repacked
             module.tables2 = tables2
+            module.template_id = tune_metadata.template_id
 
         if len(list(module.children())) > 0:
             _repack_flute_linear(module, quantization_config=quantization_config)
```
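Note: the repacking path now records which kernel template the weights were packed with. `flute.tune.tune_and_pack` returns tuning metadata alongside the repacked tensor, and its `template_id` is stored on the module. A minimal sketch of that step; the `pack_for_this_gpu` helper and its tensor argument are hypothetical scaffolding, and only the `tune_and_pack` call and `tune_metadata.template_id` mirror the diff:

```python
import torch
import flute.tune

def pack_for_this_gpu(Q_unpacked: torch.Tensor, num_bits: int, group_size: int):
    # Tuning selects a kernel template for the current GPU and packs the
    # unpacked weights into that template's layout.
    Q_repacked, tune_metadata = flute.tune.tune_and_pack(
        Q_unpacked.T.contiguous().to(device="cpu"),
        num_bits=num_bits,
        group_size=group_size)
    # The template id must travel with the packed weights so a later
    # unpack can interpret the layout.
    return Q_repacked, tune_metadata.template_id
```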
```diff
@@ -284,6 +287,20 @@ def _process_model_before_weight_loading(self, model: PreTrainedModel, keep_in_f
     def _process_model_after_weight_loading(self, model: PreTrainedModel, **kwargs) -> None:
         return _repack_flute_linear(model, quantization_config=self.quantization_config)
 
+    def update_missing_keys(self, model, missing_keys: List[str], prefix: str) -> List[str]:
+        from flute.integrations.base import FluteLinear
+
+        not_missing_keys = []
+        for name, module in model.named_modules():
+            if isinstance(module, FluteLinear):
+                for missing in missing_keys:
+                    if (
+                        (name in missing or name in f"{prefix}.{missing}")
+                        and missing.endswith(torch.nn.modules.module._EXTRA_STATE_KEY_SUFFIX)
+                    ):
+                        not_missing_keys.append(missing)
+        return [k for k in missing_keys if k not in not_missing_keys]
+
     @property
     def is_trainable(self) -> bool:
         return False
```
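Note: `update_missing_keys` filters out missing-key warnings for `FluteLinear` extra-state entries, which are regenerated during repacking rather than loaded from the checkpoint. A tiny illustration of the filtering below; the module and checkpoint names are made up, and only `_EXTRA_STATE_KEY_SUFFIX` (PyTorch's `"_extra_state"` state-dict suffix) comes from the real API:

```python
import torch

suffix = torch.nn.modules.module._EXTRA_STATE_KEY_SUFFIX  # "_extra_state"
flute_module = "model.layers.0.mlp.down_proj"             # made-up name
missing_keys = [
    f"{flute_module}.{suffix}",   # extra state: regenerated, safe to skip
    f"{flute_module}.weight",     # genuinely missing
]

not_missing = [k for k in missing_keys
               if flute_module in k and k.endswith(suffix)]
print([k for k in missing_keys if k not in not_missing])
# ['model.layers.0.mlp.down_proj.weight']
```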
flute/utils.py: 9 additions & 9 deletions
```diff
@@ -352,7 +352,8 @@ def reconstruct(
     workspace: torch.Tensor,
     num_bits: int,
     group_size: int,
-    num_sms: Optional[int] = None,
+    template_id: int,
+    num_sms: int,
 ) -> torch.Tensor:
     # we reconstruct the tensor using the fact that
     # `W.T = I @ W.T` and thus using the `qgemm` routine
@@ -361,20 +362,17 @@ def reconstruct(
         dtype=scales.dtype,
         device=scales.device)
 
-    if num_sms is None:
-        _qgemm = qgemm_simple
-    else:
-        _qgemm = QGEMM_SIMPLE_DICT[num_sms]
-
-    weight_reconstructed = _qgemm(
+    weight_reconstructed = qgemm(
         inputs,
         weight,
         scales,
         tables,
         tables2,
         workspace,
         num_bits,
-        group_size)
+        group_size,
+        template_id,
+        num_sms)
     return weight_reconstructed.T
```
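Note: `reconstruct` dequantizes by exploiting the identity `W.T = I @ W.T`, so running the quantized GEMM on an identity matrix yields the dequantized transpose. A sketch of the trick under the new signature; `qgemm` and `in_features` stand in for the real kernel handle and shape logic:

```python
import torch

def reconstruct_via_identity(qgemm, in_features, weight, scales, tables,
                             tables2, workspace, num_bits, group_size,
                             template_id, num_sms):
    # With I of shape [K, K], I @ W.T == W.T, so the quantized GEMM
    # applied to the identity returns the dequantized W.T directly.
    inputs = torch.eye(in_features,
                       dtype=scales.dtype,
                       device=scales.device)
    out = qgemm(inputs, weight, scales, tables, tables2,
                workspace, num_bits, group_size, template_id, num_sms)
    return out.T  # transpose back to W
```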


```diff
@@ -384,7 +382,8 @@ def unpack(
     workspace: torch.Tensor,
     num_bits: int,
     group_size: int,
-    num_sms_packed: Optional[int] = None,
+    template_id_packed: int,
+    num_sms_packed: int,
 ) -> torch.Tensor:
 
     # the scales needs to be just ones
@@ -404,6 +403,7 @@ def unpack(
         workspace=workspace,
         num_bits=num_bits,
         group_size=group_size,
+        template_id=template_id_packed,
         num_sms=num_sms_packed)
```
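Note: callers of `unpack` must now pass the template id the weights were packed with, alongside the SM count; unpacking under the wrong template would misread the packed layout. A hedged call example; `module` is a stand-in for a `FluteLinear`-like object, and the `weight=`/`scales=` keywords are assumptions (the remaining keywords appear in the diff above):

```python
import torch
import flute.utils

def unpack_flute_weight(module, num_sms_packed: int) -> torch.Tensor:
    # num_sms_packed is the SM count of the GPU the weights were
    # originally packed for (e.g., 108 on an A100).
    return flute.utils.unpack(
        weight=module.weight,
        scales=module.scales,
        workspace=flute.utils.get_workspace_streamk(module.weight.device),
        num_bits=module.num_bits,
        group_size=module.group_size,
        template_id_packed=module.template_id,
        num_sms_packed=num_sms_packed)
```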
