Gradient lost when running hidden_states = hidden_states.to(torch.float32) #6675

Closed
1 task done
hanlinxuy opened this issue Jan 16, 2025 · 2 comments
Labels
solved This problem has been already solved

Comments

@hanlinxuy

Reminder

  • I have read the above rules and searched the existing issues.

System Info

  • llamafactory version: 0.9.2.dev0
  • Platform: Linux-5.15.0-1039-nvidia-lowlatency-x86_64-with-glibc2.35
  • Python version: 3.10.12
  • PyTorch version: 2.4.0+cu121 (GPU)
  • Transformers version: 4.47.1
  • Datasets version: 3.1.0
  • Accelerate version: 1.0.1
  • PEFT version: 0.12.0
  • TRL version: 0.9.6
  • GPU type: NVIDIA A800 80GB PCIe
  • DeepSpeed version: 0.15.0

Reproduction

[2025-01-16 12:29:33,117] [INFO] [config.py:1003:print]   zero_config .................. stage=1 contiguous_gradients=True reduce_scatter=True reduce_bucket_size=500000000 use_multi_rank_bucket_allreduce=True allgather_partitions=True allgather_bucket_size=500000000 overlap_comm=True load_from_fp32_weights=True elastic_checkpoint=False offload_param=None offload_optimizer=None sub_group_size=1000000000 cpu_offload_param=None cpu_offload_use_pin_memory=None cpu_offload=None prefetch_bucket_size=50000000 param_persistence_threshold=100000 model_persistence_threshold=9223372036854775807 max_live_parameters=1000000000 max_reuse_distance=1000000000 gather_16bit_weights_on_model_save=False use_all_reduce_for_fetch_params=False stage3_gather_fp16_weights_on_model_save=False ignore_unused_parameters=True legacy_stage1=False round_robin_gradients=True zero_hpz_partition_size=1 zero_quantized_weights=False zero_quantized_nontrainable_weights=False zero_quantized_gradients=False mics_shard_size=-1 mics_hierarchical_params_gather=False memory_efficient_linear=True pipeline_loading_checkpoint=False override_module_apply=True
[2025-01-16 12:29:33,117] [INFO] [config.py:1003:print]   zero_enabled ................. True
[2025-01-16 12:29:33,117] [INFO] [config.py:1003:print]   zero_force_ds_cpu_optimizer .. True
[2025-01-16 12:29:33,117] [INFO] [config.py:1003:print]   zero_optimization_stage ...... 1
[2025-01-16 12:29:33,117] [INFO] [config.py:989:print_user_config]   json = {
    "train_batch_size": 128,
    "train_micro_batch_size_per_gpu": 4,
    "gradient_accumulation_steps": 32,
    "gradient_clipping": 1.0,
    "zero_allow_untested_optimizer": true,
    "fp16": {
        "enabled": false,
        "loss_scale": 0,
        "loss_scale_window": 1000,
        "initial_scale_power": 16,
        "hysteresis": 2,
        "min_loss_scale": 1
    },
    "bf16": {
        "enabled": false
    },
    "zero_optimization": {
        "stage": 1,
        "allgather_partitions": true,
        "allgather_bucket_size": 5.000000e+08,
        "overlap_comm": true,
        "reduce_scatter": true,
        "reduce_bucket_size": 5.000000e+08,
        "contiguous_gradients": true,
        "round_robin_gradients": true
    },
    "comms_logger": {
        "enabled": true,
        "verbose": false,
        "prof_all": true,
        "debug": false
    },
    "flops_profiler": {
        "enabled": true,
        "profile_step": 10,
        "module_depth": -1,
        "top_modules": 1,
        "detailed": false,
        "output_file": null
    },
    "steps_per_print": inf
}
[INFO|trainer.py:2362] 2025-01-16 12:29:33,118 >> ***** Running training *****
[INFO|trainer.py:2363] 2025-01-16 12:29:33,118 >>   Num examples = 103
[INFO|trainer.py:2364] 2025-01-16 12:29:33,118 >>   Num Epochs = 10,000
[INFO|trainer.py:2365] 2025-01-16 12:29:33,118 >>   Instantaneous batch size per device = 4
[INFO|trainer.py:2368] 2025-01-16 12:29:33,118 >>   Total train batch size (w. parallel, distributed & accumulation) = 128
[INFO|trainer.py:2369] 2025-01-16 12:29:33,118 >>   Gradient Accumulation steps = 32
[INFO|trainer.py:2370] 2025-01-16 12:29:33,118 >>   Total optimization steps = 10,000
[INFO|trainer.py:2371] 2025-01-16 12:29:33,118 >>   Number of trainable parameters = 290,920,960
  0%|                                                     | 0/10000 [00:00<?, ?it/s]False --------------------------------
##################Go to decode layer 0##################
decode0 True tensor(-9.5963e-06, device='cuda:0', dtype=torch.bfloat16)
norm0 True tensor(-9.5963e-06, device='cuda:0', dtype=torch.bfloat16)
norm1 False tensor(-9.6218e-06, device='cuda:0')
norm2 False tensor(0.0006, device='cuda:0')
decode1 False tensor(0.0057, device='cuda:0', dtype=torch.bfloat16)
decode1 True tensor(0.0057, device='cuda:0', dtype=torch.bfloat16)
norm0 False tensor(0.0057, device='cuda:0', dtype=torch.bfloat16)
norm1 False tensor(0.0057, device='cuda:0')
norm2 False tensor(0.0908, device='cuda:0')
0 False tensor(0.5228, device='cuda:0')
norm0 True tensor(0.0061, device='cuda:0', dtype=torch.bfloat16, grad_fn=<MeanBackward0>)
norm1 True tensor(0.0061, device='cuda:0', grad_fn=<MeanBackward0>)
norm2 True tensor(0.2087, device='cuda:0', grad_fn=<MeanBackward0>)
loss True False
predictor_loss: tensor(0.5228, device='cuda:0')
[rank0]: Traceback (most recent call last):
[rank0]:   File "/home/llama_factory/src/llamafactory/launcher.py", line 24, in <module>
[rank0]:     launch()
[rank0]:   File "/home/llama_factory/src/llamafactory/launcher.py", line 20, in launch
[rank0]:     run_exp()
[rank0]:   File "/home/llama_factory/src/llamafactory/train/tuner.py", line 56, in run_exp
[rank0]:     func(model_args, data_args, training_args, finetuning_args, callbacks)
[rank0]:   File "/home/llama_factory/src/llamafactory/train/pt/workflow.py", line 67, in run_pt
[rank0]:     train_result = trainer.train(resume_from_checkpoint=training_args.resume_from_checkpoint)
[rank0]:   File "/usr/local/lib/python3.10/dist-packages/transformers/trainer.py", line 2164, in train
[rank0]:     return inner_training_loop(
[rank0]:   File "/usr/local/lib/python3.10/dist-packages/transformers/trainer.py", line 2524, in _inner_training_loop
[rank0]:     tr_loss_step = self.training_step(model, inputs, num_items_in_batch)
[rank0]:   File "/usr/local/lib/python3.10/dist-packages/transformers/trainer.py", line 3687, in training_step
[rank0]:     self.accelerator.backward(loss, **kwargs)
[rank0]:   File "/usr/local/lib/python3.10/dist-packages/accelerate/accelerator.py", line 2238, in backward
[rank0]:     self.deepspeed_engine_wrapped.backward(loss, **kwargs)
[rank0]:   File "/usr/local/lib/python3.10/dist-packages/accelerate/utils/deepspeed.py", line 186, in backward
[rank0]:     self.engine.backward(loss, **kwargs)
[rank0]:   File "/usr/local/lib/python3.10/dist-packages/deepspeed/utils/nvtx.py", line 18, in wrapped_fn
[rank0]:     ret_val = func(*args, **kwargs)
[rank0]:   File "/usr/local/lib/python3.10/dist-packages/deepspeed/runtime/engine.py", line 2020, in backward
[rank0]:     self.optimizer.backward(loss, retain_graph=retain_graph)
[rank0]:   File "/usr/local/lib/python3.10/dist-packages/deepspeed/runtime/zero/stage_1_and_2.py", line 2064, in backward
[rank0]:     self.loss_scaler.backward(loss.float(), retain_graph=retain_graph)
[rank0]:   File "/usr/local/lib/python3.10/dist-packages/deepspeed/runtime/fp16/loss_scaler.py", line 63, in backward
[rank0]:     scaled_loss.backward(retain_graph=retain_graph)
[rank0]:   File "/usr/local/lib/python3.10/dist-packages/torch/_tensor.py", line 521, in backward
[rank0]:     torch.autograd.backward(
[rank0]:   File "/usr/local/lib/python3.10/dist-packages/torch/autograd/__init__.py", line 289, in backward
[rank0]:     _engine_run_backward(
[rank0]:   File "/usr/local/lib/python3.10/dist-packages/torch/autograd/graph.py", line 768, in _engine_run_backward
[rank0]:     return Variable._execution_engine.run_backward(  # Calls into the C++ engine to run the backward pass
[rank0]: RuntimeError: element 0 of tensors does not require grad and does not have a grad_fn

I modified Qwen to run some experiments in the style of PowerInfer v2 (or Turbo Sparse), but oddly, the gradient through input_layernorm is broken.
I cannot understand why hidden_states = hidden_states.to(torch.float32) would break the gradient property. Can anyone help?
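On its own, a dtype cast is differentiable: .to(torch.float32) records a backward node, keeps requires_grad=True, and lets gradients flow back into the bfloat16 input. It only yields a tensor without a graph when it runs while autograd is disabled, for example under torch.no_grad(). A minimal CPU sketch with arbitrary toy shapes:

```python
import torch

# A bfloat16 leaf tensor standing in for the RMSNorm input.
x = torch.randn(4, 8, dtype=torch.bfloat16, requires_grad=True)

# Normal grad mode: the cast is recorded by autograd.
h = x.to(torch.float32)
print(h.requires_grad, type(h.grad_fn).__name__)

# Gradients flow back through the cast into the bfloat16 input.
h.pow(2).mean().backward()
print(x.grad is not None, x.grad.dtype)

# The same cast with autograd disabled: no graph is recorded.
with torch.no_grad():
    h_nograd = x.to(torch.float32)
print(h_nograd.requires_grad, h_nograd.grad_fn)
```

So if requires_grad flips to False right after the cast, the more likely cause is that the surrounding code is running with grad mode disabled, not the cast itself.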

import torch
from torch import nn
from transformers.activations import ACT2FN

class Qwen2RMSNorm(nn.Module):
    def __init__(self, hidden_size, eps=1e-6):
        """
        Qwen2RMSNorm is equivalent to T5LayerNorm
        """
        super().__init__()
        self.weight = nn.Parameter(torch.ones(hidden_size))
        self.variance_epsilon = eps

    def forward(self, hidden_states):
        input_dtype = hidden_states.dtype
        print("norm0", hidden_states.requires_grad, hidden_states.mean())
        hidden_states = hidden_states.to(torch.float32)
        
        print("norm1", hidden_states.requires_grad, hidden_states.mean())
        variance = hidden_states.pow(2).mean(-1, keepdim=True)
        
        print("norm2", variance.requires_grad, variance.mean())
        hidden_states = hidden_states * torch.rsqrt(variance + self.variance_epsilon)
        return self.weight * hidden_states.to(input_dtype)

class Qwen2MLP(nn.Module):
    def __init__(self, config, layer_idx):
        super().__init__()
        self.config = config
        self.layer_idx = layer_idx
        self.hidden_size = config.hidden_size
        self.intermediate_size = config.intermediate_size
        self.gate_proj = nn.Linear(self.hidden_size, self.intermediate_size, bias=False)
        self.up_proj = nn.Linear(self.hidden_size, self.intermediate_size, bias=False)
        self.down_proj = nn.Linear(self.intermediate_size, self.hidden_size, bias=False)
        self.act_fn = ACT2FN[config.hidden_act]
        self.predictor = nn.Sequential(
            nn.Linear(self.hidden_size, 1024, bias=False),
            nn.ReLU(),
            nn.Linear(1024, self.intermediate_size, bias=False),
            nn.Sigmoid()
        )
        self.is_sparse = False

    def forward(self, x):
        hidden_states = self.act_fn(self.gate_proj(x)) * self.up_proj(x)
        logits = self.predictor(x)
        probs = logits.float()
        y = hidden_states.clone().detach()
        y = y.float()
        y = (y > 0).to(y.dtype)
        weight = (y.sum() / y.numel()) + 0.005
        loss_weight = y * (1 - weight) + weight
        loss = torch.nn.functional.binary_cross_entropy(probs, y, weight=loss_weight)       
        # evaluate the predictor: hard predictions at threshold 0.5, recall and sparsity statistics
        _d_probs = probs.clone().detach()
        d_probs = _d_probs >= 0.5

        dif = y.int() - d_probs.int()
        miss = (dif > 0.0).float()  
        _y = y.sum(dim=1)
        _miss = miss.sum(dim=1)
        recall = (_y - _miss).mean()
        true_sparsity = _y.float().mean()
        classifier_sparsity = d_probs.sum(dim=1).float().mean()
        
        predictor_loss = (loss.float(), recall, true_sparsity, classifier_sparsity)
        print(self.layer_idx, loss.requires_grad, loss)
        self._mask = None
        if self.is_sparse:
            hidden_states = hidden_states * self._mask
        down_proj = self.down_proj(hidden_states)
        return down_proj, predictor_loss
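The RuntimeError in the traceback is consistent with the auxiliary predictor loss being computed inside a checkpointed layer (the MLP print in the log shows "0 False" for it). This is an interpretation, not something stated in the thread. Under PyTorch's reentrant checkpointing, the wrapped function first runs under no-grad, and only tensors returned at the top level of its output are reattached to the graph. A tensor passed out any other way, here nested inside a tuple the way predictor_loss is, has no grad_fn, and calling backward() on it raises exactly this error. The toy reproduction below uses hypothetical stand-ins (block, proj, predictor, aux), not the issue's actual model code:

```python
import torch
from torch.utils.checkpoint import checkpoint

proj = torch.nn.Linear(8, 8)       # hypothetical stand-in for the layer weights
predictor = torch.nn.Linear(8, 8)  # hypothetical stand-in for the sparsity predictor

def block(x):
    out = proj(x)
    # Auxiliary loss returned nested in a tuple, like (down_proj, predictor_loss).
    aux = torch.sigmoid(predictor(x)).mean()
    return out, (aux,)

x = torch.randn(2, 8, requires_grad=True)

# Plain call: the auxiliary loss is part of the graph.
_, (aux_plain,) = block(x)
print("plain aux requires_grad:", aux_plain.requires_grad)

# Reentrant checkpoint: the top-level output is reattached to the graph,
# but the nested auxiliary loss was computed under no-grad and stays detached.
out, (aux_ckpt,) = checkpoint(block, x, use_reentrant=True)
print("ckpt out requires_grad:", out.requires_grad)
print("ckpt aux requires_grad:", aux_ckpt.requires_grad)

try:
    aux_ckpt.backward()
    error = None
except RuntimeError as e:
    error = str(e)
print("backward on aux:", error)
```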

Others

No response

@hanlinxuy hanlinxuy added bug Something isn't working pending This problem is yet to be addressed labels Jan 16, 2025
@hanlinxuy
Author

I reproduced this on the llamafactory main branch without any modification.

System Info

  • llamafactory version: 0.9.2.dev0
  • Platform: Linux-5.15.0-1039-nvidia-lowlatency-x86_64-with-glibc2.35
  • Python version: 3.10.12
  • PyTorch version: 2.4.0+cu121 (GPU)
  • Transformers version: 4.46.1
  • Datasets version: 3.1.0
  • Accelerate version: 1.0.1
  • PEFT version: 0.12.0
  • TRL version: 0.9.6
  • GPU type: NVIDIA A800 80GB PCIe
  • DeepSpeed version: 0.14.4

Reproduction

Just modify the code in the transformers library to add the prints, in:
/usr/local/lib/python3.10/dist-packages/transformers/models/qwen2/modeling_qwen2.py

class Qwen2RMSNorm(nn.Module):
    def __init__(self, hidden_size, eps=1e-6):
        """
        Qwen2RMSNorm is equivalent to T5LayerNorm
        """
        super().__init__()
        self.weight = nn.Parameter(torch.ones(hidden_size))
        self.variance_epsilon = eps

    def forward(self, hidden_states):
        input_dtype = hidden_states.dtype
        print("norm0", hidden_states.requires_grad, hidden_states.mean())
        hidden_states = hidden_states.to(torch.float32)
        print("norm1", hidden_states.requires_grad, hidden_states.mean())
        variance = hidden_states.pow(2).mean(-1, keepdim=True)
        print("norm2", hidden_states.requires_grad, hidden_states.mean())
        hidden_states = hidden_states * torch.rsqrt(variance + self.variance_epsilon)
        import ipdb;ipdb.set_trace()
        return self.weight * hidden_states.to(input_dtype)

WANDB_DISABLED=true NCCL_SOCKET_IFNAME=eth0 FORCE_TORCHRUN=1 CUDA_VISIBLE_DEVICES=0 llamafactory-cli train qwen2_full_pt.yaml

qwen2_full_pt.yaml:
### model
model_name_or_path: ./Qwen2.5-1.5B-Instruct
trust_remote_code: true
flash_attn: disabled
### method
stage: pt
finetuning_type: full
do_train: true
deepspeed: examples/deepspeed/ds_z1_config.json  # choices: [ds_z0_config.json, ds_z2_config.json, ds_z3_config.json]

### dataset
dataset: wiki_demo
template: qwen
cutoff_len: 2048
max_steps: 10000
overwrite_cache: true
preprocessing_num_workers: 16
print_param_status: true
### output
output_dir: saves/qwen2-1b5/full/pt
logging_steps: 10
save_steps: 500
plot_loss: true
overwrite_output_dir: true
pure_bf16: true
### train
per_device_train_batch_size: 4
gradient_accumulation_steps: 32
learning_rate: 1.0e-5
num_train_epochs: 3.0
lr_scheduler_type: cosine
warmup_ratio: 0.1
ddp_timeout: 180000000

### eval
val_size: 1
per_device_eval_batch_size: 4
eval_strategy: steps
eval_steps: 100

include_tokens_per_second: true

Others

The log shows that after hidden_states = hidden_states.to(torch.float32), the gradient is lost (requires_grad becomes False).

  0%|                                                     | 0/10000 [00:00<?, ?it/s]
norm0 True tensor(-9.5963e-06, device='cuda:0', dtype=torch.bfloat16)
norm1 False tensor(-9.6218e-06, device='cuda:0')
norm2 False tensor(-9.6218e-06, device='cuda:0')
> /usr/local/lib/python3.10/dist-packages/transformers/models/qwen2/modeling_qwen2.py(85)forward()
     84         import ipdb;ipdb.set_trace()
---> 85         return self.weight * hidden_states.to(input_dtype)
     86
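The norm0 True / norm1 False pattern in this log matches what PyTorch's reentrant activation checkpointing produces. A plausible explanation, not spelled out in the thread but consistent with the maintainer's fix of disabling gradient checkpointing, is this: with torch.utils.checkpoint.checkpoint(..., use_reentrant=True), the first forward pass of a checkpointed block runs with autograd disabled so that no activations are stored. The block is then recomputed with autograd enabled during backward. The toy sketch assumes each decoder layer is wrapped this way; norm_like and seen are illustrative names, not transformers code.

```python
import torch
from torch.utils.checkpoint import checkpoint

seen = []  # (tag, requires_grad) pairs recorded inside the wrapped function

def norm_like(x):
    # Mimics the first lines of Qwen2RMSNorm.forward (illustrative only).
    seen.append(("norm0", x.requires_grad))
    h = x.to(torch.float32)
    seen.append(("norm1", h.requires_grad))
    variance = h.pow(2).mean(-1, keepdim=True)
    return h * torch.rsqrt(variance + 1e-6)

x = torch.randn(2, 8, dtype=torch.bfloat16, requires_grad=True)

# Without checkpointing, the cast stays in the graph.
seen.clear()
norm_like(x)
plain = list(seen)

# With reentrant checkpointing, the first forward runs under no-grad.
seen.clear()
out = checkpoint(norm_like, x, use_reentrant=True)
ckpt_forward = list(seen)

# The checkpoint output itself is attached to the graph, so backward works;
# it re-runs norm_like with grad enabled to rebuild the activations.
out.sum().backward()
ckpt_recompute = seen[len(ckpt_forward):]

print("plain:         ", plain)
print("ckpt forward:  ", ckpt_forward)
print("ckpt recompute:", ckpt_recompute)
print("x.grad is set: ", x.grad is not None)
```

Because the wrapped function runs twice, any print inside a checkpointed layer also fires twice: once during the forward with requires_grad False after the first op, and once during backward with requires_grad True.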

@hiyouga
Owner

hiyouga commented Jan 17, 2025

set disable_gradient_checkpointing: true
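Disabling gradient checkpointing removes the no-grad first pass entirely, so both the cast and the predictor loss stay in the graph. For comparison only: PyTorch's non-reentrant checkpoint (use_reentrant=False) runs the forward with autograd enabled and only discards saved activations. This alternative is not suggested in the thread, and whether LLaMA-Factory exposes it is not verified here. The sketch uses a hypothetical toy block that returns an auxiliary loss nested in a tuple; block, proj and predictor are illustrative names.

```python
import torch
from torch.utils.checkpoint import checkpoint

proj = torch.nn.Linear(8, 8)       # hypothetical stand-in for the layer weights
predictor = torch.nn.Linear(8, 8)  # hypothetical stand-in for the sparsity predictor
seen = []  # requires_grad of the float32 cast on each run of block

def block(x):
    h = x.to(torch.float32)
    seen.append(h.requires_grad)
    out = proj(h)
    aux = torch.sigmoid(predictor(h)).mean()  # auxiliary loss, nested below
    return out, (aux,)

x = torch.randn(2, 8, dtype=torch.bfloat16, requires_grad=True)

# Non-reentrant checkpoint: forward runs with grad enabled.
out, (aux,) = checkpoint(block, x, use_reentrant=False)
print("cast requires_grad on first pass:", seen[0])
print("aux requires_grad:", aux.requires_grad)

# Backward through both the main output and the auxiliary loss succeeds.
(out.sum() + aux).backward()
print("predictor weight grad set:", predictor.weight.grad is not None)
```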

@hiyouga hiyouga closed this as completed Jan 17, 2025
@hiyouga hiyouga added solved This problem has been already solved and removed bug Something isn't working pending This problem is yet to be addressed labels Jan 17, 2025