
Commit d0ac2ce: fix conflict

ceci3 committed Jan 24, 2025
2 parents: e1125e1 + c914e4c

Showing 22 changed files with 528 additions and 87 deletions.
14 changes: 5 additions & 9 deletions examples/emu/conf/compress/compress_emu3.yaml
@@ -3,20 +3,16 @@ defaults:
- _self_

data:
data_path: /share/project/lms/emu3_testdata/3_text_document
num_calibration_steps: 16
max_seq_length: 9216
data_path: null
max_calib_data: null
max_seq_len: null
tokenzier_args:
tokenizer_path: /share/project/lms/Emu3-Gen/
special_tokens_file: /share/project/lms/Emu3-Gen/emu3_vision_tokens.txt
tokenizer_path: BAAI/Emu3-Gen/
special_tokens_file: BAAI/Emu3-Gen/emu3_vision_tokens.txt
trust_remote_code: true

compress_args:
quantization:
- algo:
smoothquant:
smoothing_strength: 0.5
ignore: ["lm_head"]
- algo:
targets: ["Linear"]
ignore: ["lm_head"]
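For context, this commit swaps the hard-coded /share/project/... paths for the public BAAI/Emu3-Gen identifiers and leaves the data fields as null placeholders. A minimal sketch of filling those placeholders when loading the config, assuming an OmegaConf/Hydra-style loader (the loader call and the concrete values below are illustrative assumptions, not part of the commit):

```python
# Sketch only: load the compression config and fill the null data fields.
# Paths and sizes here are placeholders chosen for illustration.
from omegaconf import OmegaConf

cfg = OmegaConf.load("examples/emu/conf/compress/compress_emu3.yaml")

cfg.data.data_path = "/path/to/calibration_data"  # user-supplied calibration set
cfg.data.max_calib_data = 16                      # number of calibration samples
cfg.data.max_seq_len = 9216                       # matches the old max_seq_length value

print(OmegaConf.to_yaml(cfg.data))
```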
12 changes: 4 additions & 8 deletions examples/emu/conf/compress/compress_emu3_w4a16.yaml
@@ -3,20 +3,16 @@ defaults:
- _self_

data:
data_path: /share/project/lms/emu3_testdata/3_text_document
num_calibration_steps: 16
data_path:
num_calibration_samples: 16
max_seq_length: 9216
tokenzier_args:
tokenizer_path: /share/project/lms/Emu3-Gen/
special_tokens_file: /share/project/lms/Emu3-Gen/emu3_vision_tokens.txt
tokenizer_path: BAAI/Emu3-Gen/
special_tokens_file: BAAI/Emu3-Gen/emu3_vision_tokens.txt
trust_remote_code: true

compress_args:
quantization:
- algo:
smoothquant:
smoothing_strength: 0.5
ignore: ["lm_head"]
- algo:
gptq:
blocksize: 128
5 changes: 3 additions & 2 deletions examples/emu/conf/compress/emu3_model.yaml
@@ -8,11 +8,12 @@ system:
tensorboard_log_interval: 1
wandb_project: "compress-emu3-7B"
wandb_exp_name: "compress-test-7B"
save_dir: outputs/emu3/inference_model
save_dir:


model:
model_cls: AutoModelForCausalLM
model_path: /share/project/lms/Emu3-Gen/
model_path: BAAI/Emu3-Gen/
device_map: cuda:0
trust_remote_code: true
torch_dtype: bfloat16
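The model section above is consumed by Compressor (see flagscale/compress/compressor.py further down), which forwards the remaining keys to from_pretrained. A rough equivalent, assuming the YAML keys are passed through unchanged:

```python
# Sketch of what the model: block resolves to at load time; assumes the config
# keys below are forwarded verbatim to from_pretrained by the compressor.
import torch
from transformers import AutoModelForCausalLM

model = AutoModelForCausalLM.from_pretrained(
    "BAAI/Emu3-Gen",            # model_path
    device_map="cuda:0",
    trust_remote_code=True,
    torch_dtype=torch.bfloat16,
)
```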
10 changes: 1 addition & 9 deletions flagscale/compress/adapter.py
@@ -22,7 +22,6 @@
QuantizationScheme,
disable_quantization,
enable_quantization,
is_attention_module,
)
from llmcompressor.modifiers.quantization.calibration import (
apply_calibration_status,
@@ -40,7 +39,6 @@
from flagscale.compress.blockwise_compressor import BlockCompressor

from flagscale.runner.runner_utils import logger
import pdb

__all__ = ["LLMCompressorAdapter"]

@@ -55,8 +53,6 @@
class LLMCompressorAdapter:
def __init__(self, model, scheme=None, targets=None, algo=None, ignore=None, dataset=None, num_calibration_steps=384):
self.model = model
# print("model: ", model)
# modify_save_pretrained(self.model)
if algo is not None:
assert len(algo) == 1
for k, v in algo.items():
@@ -91,7 +87,6 @@ def __init__(self, model, scheme=None, targets=None, algo=None, ignore=None, dat
self.wrapper_cls = RTNWrapper
self.compress_granularity = LayerCompressor
quant_config = self.init_quant_config()
print(quant_config)

if quant_config is not None:
### find ignore and target to quant, initialize module for quant
@@ -101,7 +96,6 @@ def __init__(self, model, scheme=None, targets=None, algo=None, ignore=None, dat

self.init_compressor()
if self.require_calib:
# self.insert_observer()
if model.training == False: ### Post Training
assert self.dataset is not None, f"The algorithm {self.algo} you selected requires a calibration process, please provide the calibration data"
self.run_blockwise_calib_forward()
@@ -112,11 +106,11 @@ def __init__(self, model, scheme=None, targets=None, algo=None, ignore=None, dat
self.layer_compressors_[0].clear_early_stop()
for idx, layer_compressor in enumerate(self.layer_compressors_):
layer_compressor.pre_compress()
# import pdb;pdb.set_trace()
layer_compressor.compress()
layer_compressor.post_compress()
layer_compressor.revert_layer_wrappers()


def init_quant_config(self):
if self.scheme is not None:
# takes precedence over config_groups
@@ -182,7 +176,6 @@ def run_blockwise_calib_forward(self):
for idx, layer_compressor in enumerate(self.layer_compressors_):
logger.info(f"start calibration layer {layer_compressor.name}")
layer_compressor.pre_compress()
# print("idx: ", idx, intermediates)
unquantized_outputs = layer_compressor.calibrate_layer(intermediates)
layer_compressor.compress()
layer_compressor.post_compress()
@@ -192,4 +185,3 @@ def run_blockwise_calib_forward(self):
logger.info(f"Mean output error from quantization: {error:.3f}")
intermediates = quantized_outputs
self.model.apply(enable_quantization)

23 changes: 0 additions & 23 deletions flagscale/compress/algo/algo_base.py

This file was deleted.

16 changes: 3 additions & 13 deletions flagscale/compress/blockwise_compressor.py
@@ -14,16 +14,10 @@ def replace_block(target: str, model: Module, target_module: Module):

class BlockCompressor(LayerCompressor):
def pre_compress(self):
# full_name = self._get_full_submodule_name(self.name)
full_name = self.name
# import pdb; pdb.set_trace()
with summon_full_params_context(self.layer):
wrapper = self.module_compressor_class(full_name, self.layer)
if len(full_name) == 0: # special case if layer has no children (i.e. lm_head)
with summon_full_params_context(self.model):
replace_block(full_name, self.model, wrapper)
else:
replace_block(full_name, self.model, wrapper)
replace_block(full_name, self.model, wrapper)
self.modules[full_name] = wrapper

self.layer = operator.attrgetter(self.name)(self.model)
@@ -43,10 +37,6 @@ def revert_layer_wrappers(self):
"""
for name, module_wrapper in self.modules.items():
full_name = self.name
if len(full_name) == 0: # special case if layer has no children (i.e. lm_head)
with summon_full_params_context(self.model):
replace_block(full_name, self.model, module_wrapper.layer)
else:
replace_block(full_name, self.model, module_wrapper.layer)
replace_block(full_name, self.model, module_wrapper.layer)
torch.cuda.empty_cache()
self.modules = None
self.modules = None
3 changes: 1 addition & 2 deletions flagscale/compress/compressor.py
@@ -75,7 +75,6 @@ def compress(self):
if self.model is None:
model_cls = eval(self.cfg.model.pop("model_cls"))
self.model = model_cls.from_pretrained(self.model_path, **self.cfg.model)
# import pdb; pdb.set_trace()
assert isinstance(self.model, torch.nn.Module), f"model type {type(self.model)} error, please check it"
compress_args = self.cfg.compress_args
recipes = prepare_compress_methods(compress_args)
@@ -115,4 +114,4 @@ def convert(self, model):
args = parser.parse_args()
cfg = prepare_config(args.config_path)

Compressor(cfg)
Compressor(cfg)
12 changes: 0 additions & 12 deletions flagscale/compress/compressor_emu3.py
@@ -61,16 +61,4 @@ def prepare_dataset(cfg):
dataset = prepare_dataset(cfg)
cmp = Compressor(cfg, dataset=dataset)
cmp.compress()
model = cmp.convert(cmp.model)

### test code
with torch.no_grad():
from llmcompressor.pytorch.utils import tensors_to_device
model_device = next(model.parameters()).device
for idx, data in enumerate(dataset):
data = tensors_to_device(data, model_device)
if idx < 2:
model(**data)
else:
break

2 changes: 0 additions & 2 deletions flagscale/compress/compressor_llava_ov.py
@@ -100,8 +100,6 @@ def prepare_dataset(cfg, model, tokenizer):
elif isinstance(data_args.image_grid_pinpoints, str):
data_args.image_grid_pinpoints = ast.literal_eval(data_args.image_grid_pinpoints)
dataset = make_supervised_data_module(tokenizer=tokenizer, data_args=data_args)
import pdb
pdb.set_trace()

ds = CusDataset(dataset["train_dataset"])
return ds
2 changes: 2 additions & 0 deletions flagscale/train/__init__.py
@@ -4,4 +4,6 @@
from .global_vars import set_extra_input_tensor
from .global_vars import get_parallel_context
from .global_vars import set_parallel_context
from .global_vars import get_spiky_loss_detector
from .global_vars import set_get_spiky_loss_detector
from .arguments import FSTrainArguments
15 changes: 14 additions & 1 deletion flagscale/train/global_vars.py
@@ -1,11 +1,12 @@
import torch

from flagscale.train.hetero.parallel_context import ParallelContext
from flagscale.train.spiky_loss import SpikyLossDetector

_GLOBAL_EXTRA_VALID_DATASETS = None
_GLOBAL_EXATRA_INPUT_TENSOR = None
_GLOBAL_PARALLEL_CONTEXT = None

_GLOBAL_SPIKY_LOSS_DETECTOR = None

def _ensure_var_is_initialized(var, name):
"""Make sure the input variable is not None."""
@@ -49,3 +50,15 @@ def set_parallel_context(args):
global _GLOBAL_PARALLEL_CONTEXT
_ensure_var_is_not_initialized(_GLOBAL_PARALLEL_CONTEXT, 'parallel context')
_GLOBAL_PARALLEL_CONTEXT = ParallelContext(args)

def get_spiky_loss_detector():
"""Return spiky loss detector."""
_ensure_var_is_initialized(_GLOBAL_SPIKY_LOSS_DETECTOR, 'spiky loss detector')
return _GLOBAL_SPIKY_LOSS_DETECTOR


def set_get_spiky_loss_detector(args):
"""Initialize spiky loss detector."""
global _GLOBAL_SPIKY_LOSS_DETECTOR
_ensure_var_is_not_initialized(_GLOBAL_SPIKY_LOSS_DETECTOR, 'spiky loss detector')
_GLOBAL_SPIKY_LOSS_DETECTOR = SpikyLossDetector(args.spiky_loss_threshold)
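The new detector is exposed through the same set-once/get-anywhere globals used for the parallel context. A small usage sketch; the args object is a stand-in assumption for Megatron-style parsed arguments carrying the new spiky_loss_threshold field:

```python
# Sketch of the intended call order for the new globals; `args` is a hypothetical
# namespace standing in for Megatron-style training arguments.
from types import SimpleNamespace

from flagscale.train.global_vars import (
    get_spiky_loss_detector,
    set_get_spiky_loss_detector,
)

args = SimpleNamespace(spiky_loss_threshold=0.2)

set_get_spiky_loss_detector(args)      # initialize exactly once during setup
detector = get_spiky_loss_detector()   # later calls return the same instance
```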
52 changes: 52 additions & 0 deletions flagscale/train/spiky_loss.py
@@ -0,0 +1,52 @@
import math
import torch

class SpikyLossDetector:
'''This class represents a Spiky Loss Detector.
It is used to detect spikes in loss values during training.
'''
def __init__(self, threshold=0.2, loss = None):
self.last_loss = loss
self.threshold = threshold

def reduce_losses(self, losses_reduced):
loss_reduced = {}
from megatron.core import mpu
if mpu.is_pipeline_last_stage(ignore_virtual=True):
# Average loss across microbatches.
for key in losses_reduced[0].keys():
numerator = 0
denominator = 0
for x in losses_reduced:
val = x[key]
# there is one dict per microbatch. in new reporting, we average
# over the total number of tokens across the global batch.
if isinstance(val, tuple) or isinstance(val, list):
numerator += val[0]
denominator += val[1]
else:
# legacy behavior. we average over the number of microbatches,
# and so the denominator is 1.
numerator += val
denominator += 1
loss_reduced[key] = numerator / denominator
return loss_reduced.get('lm loss')

def is_spkiy_loss(self, loss):
if loss is None:
return False
if self.last_loss is not None:
if math.isnan(loss) or math.isnan(self.last_loss):
self.last_loss = loss
elif math.isinf(loss) or math.isinf(self.last_loss):
return True
else:
result = (loss - self.last_loss) / self.last_loss >= self.threshold
if result:
return True
else:
self.last_loss = loss
else:
self.last_loss = loss
return False
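The detector flags a step when the loss grows by at least `threshold` relative to the last accepted loss; NaNs reset the baseline, while an Inf on either side is always treated as a spike. A standalone sketch with made-up loss values:

```python
# Sketch: relative-increase check against the last accepted loss value.
from flagscale.train.spiky_loss import SpikyLossDetector

detector = SpikyLossDetector(threshold=0.2)

print(detector.is_spkiy_loss(2.0))  # False: no baseline yet, 2.0 becomes last_loss
print(detector.is_spkiy_loss(2.1))  # False: (2.1 - 2.0) / 2.0 = 0.05 < 0.2
print(detector.is_spkiy_loss(2.6))  # True:  (2.6 - 2.1) / 2.1 ≈ 0.24 >= 0.2, baseline stays 2.1
```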

43 changes: 42 additions & 1 deletion flagscale/train/train.py
@@ -97,7 +97,7 @@
from flagscale.train.extra_valid import extra_evaluate_and_print_results
from flagscale.train.extra_valid import build_extra_valid_data_iterators
from flagscale.train.stablelm2_scheduler import StableLM2SchedulerConfig
from flagscale.train.global_vars import get_parallel_context
from flagscale.train.global_vars import get_parallel_context, get_spiky_loss_detector
from flagscale.train.hetero.p2p_communication import get_device_type_for_comm

stimer = StragglerDetector()
@@ -832,6 +832,18 @@ def train_step(forward_step_func, data_iterator,
if should_exit:
return {}, True, should_checkpoint, should_exit, exit_code, None, None

########## FlagScale Begin ##########
if args.auto_skip_spiky_loss and (args.consumed_train_samples > args.lr_warmup_samples and args.curr_iteration > args.lr_warmup_iters):
spiky_loss_detector = get_spiky_loss_detector()
loss_ = spiky_loss_detector.reduce_losses(losses_reduced)
is_spiky_loss = spiky_loss_detector.is_spkiy_loss(loss_)
is_spiky_loss_tensor = torch.tensor(is_spiky_loss, dtype=torch.int, device="cuda")
torch.distributed.all_reduce(is_spiky_loss_tensor, op=torch.distributed.ReduceOp.MAX)
is_spiky_loss = is_spiky_loss_tensor.item()
if is_spiky_loss > 0:
return {}, True, should_checkpoint, should_exit, exit_code, None, None
########## FlagScale Begin ##########

# Empty unused memory.
if args.empty_unused_memory_level >= 1:
torch.cuda.empty_cache()
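In the hunk above, the per-rank spike flag is combined with an all-reduce using ReduceOp.MAX, so a spike seen on any rank makes every rank skip the step together. A standalone sketch of that reduction semantics, using a single-process gloo group purely so the call is runnable outside a real training job:

```python
# Sketch: MAX-reduction turns per-rank boolean flags into a global OR.
import torch
import torch.distributed as dist

dist.init_process_group(
    backend="gloo", init_method="tcp://127.0.0.1:29500", rank=0, world_size=1
)

is_spiky_local = True                                     # this rank saw a spike
flag = torch.tensor(int(is_spiky_local), dtype=torch.int)
dist.all_reduce(flag, op=dist.ReduceOp.MAX)               # 1 if any rank flagged
print(flag.item())                                        # 1 -> skip this step everywhere

dist.destroy_process_group()
```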
@@ -1573,6 +1585,35 @@ def get_e2e_base_metrics():

# Run training step.
args.curr_iteration = iteration

########## FlagScale Begin ##########
if args.skip_samples_range or args.skip_iters_range:
current_global_batch_size = get_current_global_batch_size()
start_skip_iteration = 0
end_skip_iteration = 0
if args.skip_samples_range:
if args.consumed_train_samples + current_global_batch_size > args.skip_samples_range[0] and args.consumed_train_samples < args.skip_samples_range[1]:
num_skipped_iters = (args.skip_samples_range[1] - args.consumed_train_samples + current_global_batch_size - 1) // current_global_batch_size
args.skip_samples_range[1] = args.consumed_train_samples + num_skipped_iters * current_global_batch_size
start_skip_iteration = iteration
end_skip_iteration = iteration + num_skipped_iters
else:
if iteration >= args.skip_iters_range[0] and iteration < args.skip_iters_range[1]:
start_skip_iteration = iteration
end_skip_iteration = args.skip_iters_range[1]
while iteration >= start_skip_iteration and iteration < end_skip_iteration:
if mpu.is_pipeline_first_stage() or mpu.is_pipeline_last_stage():
for _ in range(get_num_microbatches()):
_ = next(train_data_iterator)
args.consumed_train_samples += mpu.get_data_parallel_world_size() * \
args.micro_batch_size * \
get_num_microbatches()
update_num_microbatches(args.consumed_train_samples, consistency_check=True, verbose=True)
iteration += 1

args.curr_iteration = iteration
########## FlagScale Begin ##########

loss_dict, skipped_iter, should_checkpoint, should_exit, exit_code, grad_norm, num_zeros_in_grad = \
train_step(forward_step_func,
train_data_iterator,
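The skip window above is sized with a ceiling division over the global batch size and then snapped to a whole number of global batches. A small worked sketch with made-up numbers (none of these values come from the commit):

```python
# Sketch: sizing the skip window for a hypothetical skip_samples_range.
consumed_train_samples = 10_000
current_global_batch_size = 512
skip_samples_end = 12_000  # stand-in for args.skip_samples_range[1]

# Same ceiling-division expression as in the diff above.
num_skipped_iters = (
    skip_samples_end - consumed_train_samples + current_global_batch_size - 1
) // current_global_batch_size
print(num_skipped_iters)  # 4 iterations are fast-forwarded

# The end of the range is then aligned to a whole number of global batches.
aligned_end = consumed_train_samples + num_skipped_iters * current_global_batch_size
print(aligned_end)  # 12048
```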