🎉

LLM

一个脚本

CUDA Version: 12.4

# torch
 
 
一个脚本直接干完
# One-shot environment setup: CUDA 11.8 toolkit plus matching PyTorch/cu118 wheels.
# Stop at the first failing step instead of continuing with a half-built env.
set -eo pipefail

# Use NVIDIA's versioned label channel so every CUDA sub-package stays on 11.8.
conda install -y cuda -c nvidia/label/cuda-11.8.0
# Confirm nvcc 11.8 is on PATH before anything compiles CUDA extensions.
nvcc --version

pip install torch torchvision torchaudio --index-url https://download.pytorch.org/whl/cu118
pip install liger-kernel
# flash-attn builds against the already-installed torch, so disable build isolation.
pip install flash-attn --no-build-isolation
pip install deepspeed transformers --extra-index-url https://download.pytorch.org/whl/cu118

# NOTE(review): the default vllm wheel may target a newer CUDA and pin its own torch,
# replacing the cu118 build above; see the vllm section for the +cu118 wheel -- confirm.
pip install vllm --extra-index-url https://download.pytorch.org/whl/cu118

cuda

# Install ONE CUDA toolkit and keep it consistent with pytorch-cuda below (11.8 preferred).
conda install -y cuda -c nvidia/label/cuda-11.8.0
# Alternative toolkit -- do not install both. With 12.x, use pytorch-cuda=12.1 below.
# conda install -y cuda -c nvidia/label/cuda-12.2.0
conda install -y pytorch==2.1.2 torchvision==0.16.2 torchaudio==2.1.2 pytorch-cuda=11.8 -c pytorch -c nvidia

torch

# Remove any existing torch build before switching CUDA variants.
pip uninstall -y torch torchvision torchaudio
# CUDA 12.1 wheels (alternative):
# pip install torch torchvision torchaudio --index-url https://download.pytorch.org/whl/cu121
# CUDA 11.8 wheels (preferred):
pip install torch torchvision torchaudio --index-url https://download.pytorch.org/whl/cu118

flash attn

For CUDA 11, we only compile for CUDA 11.8, and for CUDA 12 we only compile for CUDA 12.2
尽量选11.8
# Build flash-attention v2.3.6 from source.
git clone git@github.com:Dao-AILab/flash-attention.git
cd flash-attention
git checkout tags/v2.3.6
python setup.py install
# On A100 you can additionally install the layer_norm extension. The transformers patch
# below imports flash_attn.ops.rms_norm, and its warning points to csrc/layer_norm.
cd csrc/layer_norm && pip install .

deepspeed

pip install deepspeed
# Print DeepSpeed's environment report to confirm it sees the torch/CUDA build
# and which of its ops are compatible with this machine.
ds_report

transformers

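patch (local changes to transformers, applied with `git apply` inside the checkout from the section below; the diff lost its line breaks in this copy, so restore it from the original patch file first). NOTE(review), known issues in this diff:
(1) In `LlamaDynamicNTKScalingRotaryEmbedding._set_cos_sin_cache`, the non-`mixed` branch has three problems. It calls `torch.pow` with two Python floats, which raises a TypeError. It raises `self.base` itself to `dim / (dim - 2)`, whereas the removed upstream code raises only the scaling term. It also drops the `seq_len > max_position_embeddings` guard.
(2) In `_single_tensor_sophiag` with `capturable=True`, `step_size = lr` is a float, so `step_size.neg()` raises AttributeError.
(3) The context lines (`padding_mask=padding_mask`, `_set_gradient_checkpointing(self, module, value=False)`) appear to come from an older transformers release than the v4.44.2 tag checked out below. Confirm the base version before applying.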
diff --git a/src/transformers/dynamic_module_utils.py b/src/transformers/dynamic_module_utils.py index cd875b7b4..e2b7c49f5 100644 --- a/src/transformers/dynamic_module_utils.py +++ b/src/transformers/dynamic_module_utils.py @@ -200,6 +200,24 @@ def get_class_in_module(class_name: str, module_path: Union[str, os.PathLike]) - return getattr(module, class_name) +def safe_copy(src, dst, max_time=30): + """写入有延迟,确认写入完成再返回""" + shutil.copy(src, dst) + import time + + with open(src) as fr: + s0 = fr.read() + + for _ in range(max_time): + time.sleep(1) + n0 = os.path.getsize(src) + if os.path.getsize(dst) == n0: + with open(dst) as fr: + s1 = fr.read() + if s1 == s0: + return + + def get_cached_module_file( pretrained_model_name_or_path: Union[str, os.PathLike], module_file: str, @@ -322,7 +340,7 @@ def get_cached_module_file( if not (submodule_path / module_file).exists() or not filecmp.cmp( resolved_module_file, str(submodule_path / module_file) ): - shutil.copy(resolved_module_file, submodule_path / module_file) + safe_copy(resolved_module_file, submodule_path / module_file) importlib.invalidate_caches() for module_needed in modules_needed: module_needed = f"{module_needed}.py" @@ -330,7 +348,7 @@ def get_cached_module_file( if not (submodule_path / module_needed).exists() or not filecmp.cmp( module_needed_file, str(submodule_path / module_needed) ): - shutil.copy(module_needed_file, submodule_path / module_needed) + safe_copy(module_needed_file, submodule_path / module_needed) importlib.invalidate_caches() else: # Get the commit hash @@ -343,7 +361,7 @@ def get_cached_module_file( create_dynamic_module(full_submodule) if not (submodule_path / module_file).exists(): - shutil.copy(resolved_module_file, submodule_path / module_file) + safe_copy(resolved_module_file, submodule_path / module_file) importlib.invalidate_caches() # Make sure we also have every file with relative for module_needed in modules_needed: diff --git a/src/transformers/modeling_utils.py 
b/src/transformers/modeling_utils.py index 5c3a12183..49b06a6f7 100644 --- a/src/transformers/modeling_utils.py +++ b/src/transformers/modeling_utils.py @@ -2469,7 +2469,13 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin, GenerationMixin, PushToHubMix variant = kwargs.pop("variant", None) adapter_kwargs = kwargs.pop("adapter_kwargs", {}) adapter_name = kwargs.pop("adapter_name", "default") - use_flash_attention_2 = kwargs.pop("use_flash_attention_2", False) + use_flash_attention_2 = kwargs.pop("use_flash_attention_2", None) + + if use_flash_attention_2 is None: + if cls._supports_flash_attn_2: + use_flash_attention_2 = True + else: + use_flash_attention_2 = False if is_fsdp_enabled(): low_cpu_mem_usage = True diff --git a/src/transformers/models/llama/modeling_llama.py b/src/transformers/models/llama/modeling_llama.py index 3d9e31775..a187b9d0e 100644 --- a/src/transformers/models/llama/modeling_llama.py +++ b/src/transformers/models/llama/modeling_llama.py @@ -48,6 +48,29 @@ if is_flash_attn_available(): logger = logging.get_logger(__name__) + +try: + from flash_attn.layers.rotary import apply_rotary_emb_func +except ImportError: + apply_rotary_emb_func = None + +try: + from flash_attn.ops.rms_norm import dropout_add_rms_norm +except ImportError: + dropout_add_rms_norm = None + +try: + from flash_attn.ops.activations import swiglu +except ImportError: + swiglu = None + +if torch.cuda.is_available(): + try: + from flash_attn.losses.cross_entropy import CrossEntropyLoss + except ImportError: + pass + + _CONFIG_FOR_DOC = "LlamaConfig" @@ -105,12 +128,33 @@ class LlamaRMSNorm(nn.Module): self.weight = nn.Parameter(torch.ones(hidden_size)) self.variance_epsilon = eps - def forward(self, hidden_states): + def forward(self, hidden_states, residual=None): + if dropout_add_rms_norm is not None and hidden_states.is_cuda: + out, res = dropout_add_rms_norm( + hidden_states, + residual, + self.weight, + None, # bias + 0., # dropout_p + self.variance_epsilon, + 
prenorm=True, + residual_in_fp32=False, + return_dropout_mask=False, + ) + return out if residual is None else (out, res) + else: + logger.warning_once("dropout_add_rms_norm is not installed. If you want to accelerate training please install: " + "https://github.com/Dao-AILab/flash-attention/tree/main/csrc/layer_norm") + + if residual is not None: + hidden_states = residual + hidden_states + residual = hidden_states input_dtype = hidden_states.dtype hidden_states = hidden_states.to(torch.float32) variance = hidden_states.pow(2).mean(-1, keepdim=True) hidden_states = hidden_states * torch.rsqrt(variance + self.variance_epsilon) - return self.weight * hidden_states.to(input_dtype) + hidden_states = self.weight * hidden_states.to(input_dtype) + return hidden_states if residual is None else (hidden_states, residual) ALL_LAYERNORM_LAYERS.append(LlamaRMSNorm) @@ -123,19 +167,21 @@ class LlamaRotaryEmbedding(nn.Module): self.dim = dim self.max_position_embeddings = max_position_embeddings self.base = base - inv_freq = 1.0 / (self.base ** (torch.arange(0, self.dim, 2).float().to(device) / self.dim)) - self.register_buffer("inv_freq", inv_freq, persistent=False) # Build here to make `torch.jit.trace` work. self._set_cos_sin_cache( - seq_len=max_position_embeddings, device=self.inv_freq.device, dtype=torch.get_default_dtype() + seq_len=max_position_embeddings, device=device, dtype=torch.get_default_dtype() ) def _set_cos_sin_cache(self, seq_len, device, dtype): self.max_seq_len_cached = seq_len - t = torch.arange(self.max_seq_len_cached, device=device, dtype=self.inv_freq.dtype) - - freqs = torch.einsum("i,j->ij", t, self.inv_freq) + inv_freq = 1.0 / torch.pow(self.base, (torch.arange(0, self.dim, 2, device=device).to(torch.float32) / self.dim)) + t = torch.arange(self.max_seq_len_cached, device=device).to(torch.float32) # use float32 due to limited precision of bf16 (e.g. 
1995.0 is rounded to 2000.0) + freqs = torch.outer(t, inv_freq) + if not all(_t.dtype == torch.float32 for _t in [t, inv_freq, freqs]): + logger.warn( + f'LlamaRotaryEmbedding t.dtype: {t.dtype} inv_freq.dtype: {inv_freq.dtype} freqs.dtype: {freqs.dtype}' + ) # Different from paper, but it uses a different permutation in order to obtain the same calculation emb = torch.cat((freqs, freqs), dim=-1) self.register_buffer("cos_cached", emb.cos()[None, None, :, :].to(dtype), persistent=False) @@ -161,10 +207,15 @@ class LlamaLinearScalingRotaryEmbedding(LlamaRotaryEmbedding): def _set_cos_sin_cache(self, seq_len, device, dtype): self.max_seq_len_cached = seq_len - t = torch.arange(self.max_seq_len_cached, device=device, dtype=self.inv_freq.dtype) + inv_freq = 1.0 / torch.pow(self.base, (torch.arange(0, self.dim, 2, device=device).to(torch.float32) / self.dim)) + t = torch.arange(self.max_seq_len_cached, device=device).to(torch.float32) t = t / self.scaling_factor - - freqs = torch.einsum("i,j->ij", t, self.inv_freq) + freqs = torch.outer(t, inv_freq) + if not all(_t.dtype == torch.float32 for _t in [t, inv_freq, freqs]): + logger.warn( + f'LlamaLinearScalingRotaryEmbedding t.dtype: {t.dtype} inv_freq.dtype: {inv_freq.dtype} ' + f'freqs.dtype: {freqs.dtype}' + ) # Different from paper, but it uses a different permutation in order to obtain the same calculation emb = torch.cat((freqs, freqs), dim=-1) self.register_buffer("cos_cached", emb.cos()[None, None, :, :].to(dtype), persistent=False) @@ -174,23 +225,40 @@ class LlamaLinearScalingRotaryEmbedding(LlamaRotaryEmbedding): class LlamaDynamicNTKScalingRotaryEmbedding(LlamaRotaryEmbedding): """LlamaRotaryEmbedding extended with Dynamic NTK scaling. 
Credits to the Reddit users /u/bloc97 and /u/emozilla""" - def __init__(self, dim, max_position_embeddings=2048, base=10000, device=None, scaling_factor=1.0): + def __init__(self, dim, max_position_embeddings=2048, base=10000, device=None, scaling_factor=1.0, mixed=True): self.scaling_factor = scaling_factor + self.mixed = mixed super().__init__(dim, max_position_embeddings, base, device) def _set_cos_sin_cache(self, seq_len, device, dtype): self.max_seq_len_cached = seq_len - if seq_len > self.max_position_embeddings: - base = self.base * ( - (self.scaling_factor * seq_len / self.max_position_embeddings) - (self.scaling_factor - 1) - ) ** (self.dim / (self.dim - 2)) - inv_freq = 1.0 / (base ** (torch.arange(0, self.dim, 2).float().to(device) / self.dim)) - self.register_buffer("inv_freq", inv_freq, persistent=False) - - t = torch.arange(self.max_seq_len_cached, device=device, dtype=self.inv_freq.dtype) - - freqs = torch.einsum("i,j->ij", t, self.inv_freq) + if self.mixed: + """ + Reference: + https://spaces.ac.cn/archives/9706 + https://github.com/bojone/rerope/blob/main/ntk_patch.py + """ + b = 0.75 + a = math.log(self.scaling_factor) / (self.dim / 2) ** b + inv_freq = torch.pow(self.base, (-torch.arange(0, self.dim, 2, device=device).to(torch.float32) / self.dim)) + inv_freq *= (-a * torch.arange(1, self.dim // 2 + 1, device=device).to(torch.float32) ** b).exp() + else: + base = torch.pow( + self.base * ( + (self.scaling_factor * seq_len / self.max_position_embeddings) - (self.scaling_factor - 1) + ), + self.dim / (self.dim - 2) + ) + inv_freq = 1.0 / (base ** (torch.arange(0, self.dim, 2, device=device).to(torch.float32) / self.dim)) + + t = torch.arange(self.max_seq_len_cached, device=device).to(torch.float32) + freqs = torch.outer(t, inv_freq) + if not all(_t.dtype == torch.float32 for _t in [t, inv_freq, freqs]): + logger.warn( + f'LlamaDynamicNTKScalingRotaryEmbedding t.dtype: {t.dtype} inv_freq.dtype: {inv_freq.dtype} ' + f'freqs.dtype: {freqs.dtype}' + ) 
# Different from paper, but it uses a different permutation in order to obtain the same calculation emb = torch.cat((freqs, freqs), dim=-1) self.register_buffer("cos_cached", emb.cos()[None, None, :, :].to(dtype), persistent=False) @@ -204,8 +272,22 @@ def rotate_half(x): return torch.cat((-x2, x1), dim=-1) -def apply_rotary_pos_emb(q, k, cos, sin, position_ids): +def apply_rotary_pos_emb(q, k, cos, sin, position_ids, training=True): # The first two dimensions of cos and sin are always 1, so we can `squeeze` them. + if position_ids is None and training: + if apply_rotary_emb_func is not None: + rot_dim = cos.shape[-1] // 2 + cos = cos[0, 0, :, :rot_dim] + sin = sin[0, 0, :, :rot_dim] + q_embed = apply_rotary_emb_func(q.transpose(1, 2), cos, sin, False, False) # interleaved=False, inplace=False + k_embed = apply_rotary_emb_func(k.transpose(1, 2), cos, sin, False, False) # interleaved=False, inplace=False + q_embed = q_embed.transpose(1, 2) + k_embed = k_embed.transpose(1, 2) + return q_embed, k_embed + else: + logger.warning_once("rotary_emb is not installed. 
If you want to accelerate training please install: " + "https://github.com/Dao-AILab/flash-attention/tree/main/csrc/rotary") + cos = cos.squeeze(1).squeeze(0) # [seq_len, dim] sin = sin.squeeze(1).squeeze(0) # [seq_len, dim] cos = cos[position_ids].unsqueeze(1) # [bs, 1, seq_len, dim] @@ -215,6 +297,18 @@ def apply_rotary_pos_emb(q, k, cos, sin, position_ids): return q_embed, k_embed +def apply_logn_factor(q, position_ids, pretrained_length): + """ + Reference: + https://spaces.ac.cn/archives/9706 + https://github.com/bojone/rerope/blob/main/ntk_patch.py + """ + if position_ids[0, -1] + 1 > pretrained_length: + scale = ((position_ids + 1)[:, None, :, None].log() / math.log(pretrained_length)).clip(1) + q *= scale.to(q.dtype) + return q + + class LlamaMLP(nn.Module): def __init__(self, config): super().__init__() @@ -244,7 +338,12 @@ class LlamaMLP(nn.Module): ] down_proj = sum(down_proj) else: - down_proj = self.down_proj(self.act_fn(self.gate_proj(x)) * self.up_proj(x)) + if swiglu is not None: + down_proj = self.down_proj(swiglu(self.gate_proj(x), self.up_proj(x))) + else: + logger.warning_once("swiglu is not installed. 
If you want to accelerate training please install: " + "https://github.com/Dao-AILab/flash-attention") + down_proj = self.down_proj(self.act_fn(self.gate_proj(x)) * self.up_proj(x)) return down_proj @@ -296,6 +395,8 @@ class LlamaAttention(nn.Module): else: scaling_type = self.config.rope_scaling["type"] scaling_factor = self.config.rope_scaling["factor"] + self.pretrained_length = int(self.max_position_embeddings / scaling_factor) + assert self.pretrained_length * scaling_factor == self.max_position_embeddings if scaling_type == "linear": self.rotary_emb = LlamaLinearScalingRotaryEmbedding( self.head_dim, @@ -358,7 +459,10 @@ class LlamaAttention(nn.Module): if past_key_value is not None: kv_seq_len += past_key_value[0].shape[-2] cos, sin = self.rotary_emb(value_states, seq_len=kv_seq_len) - query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin, position_ids) + query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin, position_ids, self.training) + + #if self.config.rope_scaling: + # query_states = apply_logn_factor(query_states, position_ids, self.pretrained_length) if past_key_value is not None: # reuse k, v, self_attention @@ -451,7 +555,7 @@ class LlamaFlashAttention2(LlamaAttention): cos, sin = self.rotary_emb(value_states, seq_len=kv_seq_len) - query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin, position_ids) + query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin, position_ids, self.training) if past_key_value is not None: # reuse k, v, self_attention @@ -641,11 +745,9 @@ class LlamaDecoderLayer(nn.Module): use_cache=use_cache, padding_mask=padding_mask, ) - hidden_states = residual + hidden_states # Fully Connected - residual = hidden_states - hidden_states = self.post_attention_layernorm(hidden_states) + hidden_states, residual = self.post_attention_layernorm(hidden_states, residual) hidden_states = self.mlp(hidden_states) hidden_states = 
residual + hidden_states @@ -702,6 +804,8 @@ class LlamaPreTrainedModel(PreTrainedModel): def _set_gradient_checkpointing(self, module, value=False): if isinstance(module, LlamaModel): + if value: + value = getattr(self, 'gradient_checkpoint_disable_layers', None) or value module.gradient_checkpointing = value @@ -862,7 +966,9 @@ class LlamaModel(LlamaPreTrainedModel): past_key_values_length = past_key_values[0][0].shape[2] seq_length_with_past = seq_length_with_past + past_key_values_length - if position_ids is None: + if self.training and past_key_values_length == 0 and apply_rotary_emb_func is not None: + position_ids = None + elif position_ids is None: device = input_ids.device if input_ids is not None else inputs_embeds.device position_ids = torch.arange( past_key_values_length, seq_length + past_key_values_length, dtype=torch.long, device=device @@ -874,20 +980,16 @@ class LlamaModel(LlamaPreTrainedModel): if inputs_embeds is None: inputs_embeds = self.embed_tokens(input_ids) # embed positions - if attention_mask is None: - attention_mask = torch.ones( - (batch_size, seq_length_with_past), dtype=torch.bool, device=inputs_embeds.device - ) - padding_mask = None - else: - if 0 in attention_mask: - padding_mask = attention_mask - else: - padding_mask = None + padding_mask = None + if attention_mask is not None and 0 in attention_mask: + padding_mask = attention_mask - attention_mask = self._prepare_decoder_attention_mask( - attention_mask, (batch_size, seq_length), inputs_embeds, past_key_values_length - ) + if self.training and getattr(self.config, "_flash_attn_2_enabled", False): + attention_mask = None + else: + attention_mask = self._prepare_decoder_attention_mask( + attention_mask, (batch_size, seq_length), inputs_embeds, past_key_values_length + ) hidden_states = inputs_embeds @@ -903,13 +1005,19 @@ class LlamaModel(LlamaPreTrainedModel): all_self_attns = () if output_attentions else None next_decoder_cache = () if use_cache else None + if 
self.gradient_checkpointing > 1: + gc_disable_interval = len(self.layers) // self.gradient_checkpointing + gc_disable_layers_max = self.gradient_checkpointing * gc_disable_interval + for idx, decoder_layer in enumerate(self.layers): if output_hidden_states: all_hidden_states += (hidden_states,) past_key_value = past_key_values[idx] if past_key_values is not None else None - if self.gradient_checkpointing and self.training: + if self.gradient_checkpointing and self.training and not ( + self.gradient_checkpointing > 1 and idx % gc_disable_interval == 0 and idx < gc_disable_layers_max + ): def create_custom_forward(module): def custom_forward(*inputs): @@ -919,7 +1027,9 @@ class LlamaModel(LlamaPreTrainedModel): return custom_forward layer_outputs = torch.utils.checkpoint.checkpoint( - create_custom_forward(decoder_layer), hidden_states, attention_mask, position_ids + create_custom_forward(decoder_layer), hidden_states, attention_mask, position_ids, + use_reentrant=True, + preserve_rng_state=False # attn_dropout = 0.0 ) else: layer_outputs = decoder_layer( @@ -958,6 +1068,7 @@ class LlamaModel(LlamaPreTrainedModel): class LlamaForCausalLM(LlamaPreTrainedModel): + _keys_to_ignore_on_load_unexpected = [r"self_attn\.rotary_emb\.inv_freq"] _tied_weights_keys = ["lm_head.weight"] def __init__(self, config): @@ -969,6 +1080,12 @@ class LlamaForCausalLM(LlamaPreTrainedModel): # Initialize weights and apply final processing self.post_init() + if CrossEntropyLoss.__module__.startswith('flash_attn'): + self.z_loss = 0.0 + else: + logger.warn("CrossEntropyLoss is not installed. 
If you want to accelerate training please install: " + "https://github.com/Dao-AILab/flash-attention") + def get_input_embeddings(self): return self.model.embed_tokens @@ -1046,7 +1163,7 @@ class LlamaForCausalLM(LlamaPreTrainedModel): output_hidden_states=output_hidden_states, return_dict=return_dict, ) - + is_train_flash = self.training and CrossEntropyLoss.__module__.startswith('flash_attn') hidden_states = outputs[0] if self.config.pretraining_tp > 1: lm_head_slices = self.lm_head.weight.split(self.vocab_size // self.config.pretraining_tp, dim=0) @@ -1054,7 +1171,8 @@ class LlamaForCausalLM(LlamaPreTrainedModel): logits = torch.cat(logits, dim=-1) else: logits = self.lm_head(hidden_states) - logits = logits.float() + if not is_train_flash: + logits = logits.float() loss = None if labels is not None: @@ -1062,7 +1180,10 @@ class LlamaForCausalLM(LlamaPreTrainedModel): shift_logits = logits[..., :-1, :].contiguous() shift_labels = labels[..., 1:].contiguous() # Flatten the tokens - loss_fct = CrossEntropyLoss() + if is_train_flash: + loss_fct = CrossEntropyLoss(inplace_backward=True, lse_square_scale=self.z_loss) # inplace_backward - saves memory + else: + loss_fct = CrossEntropyLoss() shift_logits = shift_logits.view(-1, self.config.vocab_size) shift_labels = shift_labels.view(-1) # Enable model parallelism diff --git a/src/transformers/optimization.py b/src/transformers/optimization.py index 2b8f5d2a8..b94c0479c 100644 --- a/src/transformers/optimization.py +++ b/src/transformers/optimization.py @@ -776,3 +776,214 @@ def get_adafactor_schedule(optimizer, initial_lr=0.0): """ return AdafactorSchedule(optimizer, initial_lr) + + +from torch import Tensor +from typing import List + + +class SophiaG(Optimizer): + def __init__(self, params, lr=3e-4, betas=(0.9, 0.95), rho=0.05, + weight_decay=1e-1, batch_size=5120, hess_interval=11, *, maximize: bool = False, + capturable: bool = False): + if not 0.0 <= lr: + raise ValueError("Invalid learning rate: {}".format(lr)) + 
if not 0.0 <= betas[0] < 1.0: + raise ValueError("Invalid beta parameter at index 0: {}".format(betas[0])) + if not 0.0 <= betas[1] < 1.0: + raise ValueError("Invalid beta parameter at index 1: {}".format(betas[1])) + if not 0.0 <= rho: + raise ValueError("Invalid rho parameter at index 1: {}".format(rho)) + if not 0.0 <= weight_decay: + raise ValueError("Invalid weight_decay value: {}".format(weight_decay)) + defaults = dict(lr=lr, betas=betas, rho=rho, + weight_decay=weight_decay, + maximize=maximize, capturable=capturable) + logger.warn('SophiaG optimizer: ' + str(defaults)) + super(SophiaG, self).__init__(params, defaults) + self.batch_size = batch_size + self.hess_interval = hess_interval + self.need_update_hessian = None + + def __setstate__(self, state): + super().__setstate__(state) + for group in self.param_groups: + group.setdefault('maximize', False) + group.setdefault('capturable', False) + state_values = list(self.state.values()) + step_is_tensor = (len(state_values) != 0) and torch.is_tensor(state_values[0]['step']) + if not step_is_tensor: + for s in state_values: + s['step'] = torch.tensor(float(s['step'])) + + def step(self, *args, **kwargs): + if self.need_update_hessian: + self.update_hessian() + self.need_update_hessian = False + else: + self._step(*args, **kwargs) + + @torch.no_grad() + def update_hessian(self): + for group in self.param_groups: + beta1, beta2 = group['betas'] + for p in group['params']: + if p.grad is None: + continue + state = self.state[p] + + if len(state) == 0: + state['step'] = torch.zeros((1,), dtype=torch.float, device=p.device) \ + if self.defaults['capturable'] else torch.tensor(0.) 
+ state['exp_avg'] = torch.zeros_like(p, memory_format=torch.preserve_format) + state['hessian'] = torch.zeros_like(p, memory_format=torch.preserve_format) + + if 'hessian' not in state.keys(): + state['hessian'] = torch.zeros_like(p, memory_format=torch.preserve_format) + + state['hessian'].mul_(beta2).addcmul_(p.grad, p.grad, value=1 - beta2) + + @torch.no_grad() + def _step(self, closure=None, bs=None): + loss = None + if closure is not None: + with torch.enable_grad(): + loss = closure() + if bs is None: + bs = self.batch_size + + for group in self.param_groups: + params_with_grad = [] + grads = [] + exp_avgs = [] + state_steps = [] + hessian = [] + beta1, beta2 = group['betas'] + + for p in group['params']: + if p.grad is None: + continue + params_with_grad.append(p) + + if p.grad.is_sparse: + raise RuntimeError('Hero does not support sparse gradients') + grads.append(p.grad) + state = self.state[p] + # State initialization + if len(state) == 0: + state['step'] = torch.zeros((1,), dtype=torch.float, device=p.device) \ + if self.defaults['capturable'] else torch.tensor(0.) 
+ state['exp_avg'] = torch.zeros_like(p, memory_format=torch.preserve_format) + state['hessian'] = torch.zeros_like(p, memory_format=torch.preserve_format) + + if 'hessian' not in state.keys(): + state['hessian'] = torch.zeros_like(p, memory_format=torch.preserve_format) + + exp_avgs.append(state['exp_avg']) + state_steps.append(state['step']) + hessian.append(state['hessian']) + + if self.defaults['capturable']: + bs = torch.ones((1,), dtype=torch.float, device=p.device) * bs + + sophiag(params_with_grad, + grads, + exp_avgs, + hessian, + state_steps, + bs=bs, + beta1=beta1, + beta2=beta2, + rho=group['rho'], + lr=group['lr'], + weight_decay=group['weight_decay'], + maximize=group['maximize'], + capturable=group['capturable']) + + return loss + + +def sophiag(params: List[Tensor], + grads: List[Tensor], + exp_avgs: List[Tensor], + hessian: List[Tensor], + state_steps: List[Tensor], + capturable: bool = False, + *, + bs: int, + beta1: float, + beta2: float, + rho: float, + lr: float, + weight_decay: float, + maximize: bool): + if not all(isinstance(t, torch.Tensor) for t in state_steps): + raise RuntimeError("API has changed, `state_steps` argument must contain a list of singleton tensors") + + func = _single_tensor_sophiag + + func(params, + grads, + exp_avgs, + hessian, + state_steps, + bs=bs, + beta1=beta1, + beta2=beta2, + rho=rho, + lr=lr, + weight_decay=weight_decay, + maximize=maximize, + capturable=capturable) + + +def _single_tensor_sophiag(params: List[Tensor], + grads: List[Tensor], + exp_avgs: List[Tensor], + hessian: List[Tensor], + state_steps: List[Tensor], + *, + bs: int, + beta1: float, + beta2: float, + rho: float, + lr: float, + weight_decay: float, + maximize: bool, + capturable: bool): + for i, param in enumerate(params): + grad = grads[i] if not maximize else -grads[i] + exp_avg = exp_avgs[i] + hess = hessian[i] + step_t = state_steps[i] + + if capturable: + assert param.is_cuda and step_t.is_cuda and bs.is_cuda + + if torch.is_complex(param): 
+ grad = torch.view_as_real(grad) + exp_avg = torch.view_as_real(exp_avg) + hess = torch.view_as_real(hess) + param = torch.view_as_real(param) + + # update step + step_t += 1 + + # Perform stepweight decay + param.mul_(1 - lr * weight_decay) + + # Decay the first and second moment running average coefficient + exp_avg.mul_(beta1).add_(grad, alpha=1 - beta1) + + if capturable: + step = step_t + step_size = lr + step_size_neg = step_size.neg() + else: + step = step_t.item() + step_size_neg = - lr + + ratio = rho * bs * hess + 1e-15 + torch.div(exp_avg, ratio, out=ratio) + torch.clamp_(ratio, -1, 1) + param.add_(ratio, alpha=step_size_neg) diff --git a/src/transformers/trainer.py b/src/transformers/trainer.py index b9e103761..0ec00aac2 100755 --- a/src/transformers/trainer.py +++ b/src/transformers/trainer.py @@ -64,7 +64,7 @@ from .integrations.deepspeed import deepspeed_init, deepspeed_load_checkpoint, i from .modelcard import TrainingSummary from .modeling_utils import PreTrainedModel, load_sharded_checkpoint, unwrap_model from .models.auto.modeling_auto import MODEL_FOR_CAUSAL_LM_MAPPING_NAMES, MODEL_MAPPING_NAMES -from .optimization import Adafactor, get_scheduler +from .optimization import Adafactor, get_scheduler, SophiaG from .pytorch_utils import ALL_LAYERNORM_LAYERS, is_torch_less_than_1_11 from .tokenization_utils_base import PreTrainedTokenizerBase from .trainer_callback import ( @@ -1141,6 +1141,10 @@ class Trainer: optimizer_cls = torch.optim.Adagrad elif args.optim == OptimizerNames.RMSPROP: optimizer_cls = torch.optim.RMSprop + elif args.optim == OptimizerNames.SOPHIA: + optimizer_cls = SophiaG + total_train_batch_size = args.train_batch_size * args.gradient_accumulation_steps * args.world_size + optimizer_kwargs.update(batch_size=total_train_batch_size * args.model_max_length) else: raise ValueError(f"Trainer cannot instantiate unsupported optimizer: {args.optim}") return optimizer_cls, optimizer_kwargs @@ -1751,6 +1755,14 @@ class Trainer: # Check 
if saved optimizer or scheduler states exist self._load_optimizer_and_scheduler(resume_from_checkpoint) + if isinstance(self.lr_scheduler, DeepSpeedSchedulerWrapper): + if self.args.learning_rate != self.lr_scheduler.scheduler.base_lrs[0]: + logger.warning("changed learning rate from {} to {}".format( + self.lr_scheduler.scheduler.base_lrs[0], self.args.learning_rate + )) + self.lr_scheduler.scheduler.base_lrs = \ + [self.args.learning_rate] * len(self.lr_scheduler.scheduler.base_lrs) + # important: at this point: # self.model is the Transformers Model # self.model_wrapped is DDP(Transformers Model), Deepspeed(Transformers Model), @@ -1786,6 +1798,13 @@ class Trainer: else: steps_trained_in_current_epoch = 0 + if args.save_steps is not None and args.save_steps != self.state.save_steps: + logger.warning("changed save_steps from {} to {}".format(self.state.save_steps, args.save_steps)) + if args.save_steps < 1: + self.state.save_steps = math.ceil(max_steps * args.save_steps) + else: + self.state.save_steps = args.save_steps + logger.info(" Continuing training from checkpoint, will skip to saved global_step") logger.info(f" Continuing training from epoch {epochs_trained}") logger.info(f" Continuing training from global step {self.state.global_step}") @@ -2307,6 +2326,11 @@ class Trainer: logs["loss"] = round(tr_loss_scalar / (self.state.global_step - self._globalstep_last_logged), 4) logs["learning_rate"] = self._get_learning_rate() + if self.is_deepspeed_enabled: + logs['grad_norm'] = self.optimizer.optimizer._global_grad_norm + if isinstance(logs['grad_norm'], torch.Tensor): + logs['grad_norm'] = logs['grad_norm'].cpu().item() + self._total_loss_scalar += tr_loss_scalar self._globalstep_last_logged = self.state.global_step self.store_flos() @@ -2361,8 +2385,11 @@ class Trainer: "fashion, reproducibility is not guaranteed." 
) return - - checkpoint_rng_state = torch.load(rng_file) + try: + checkpoint_rng_state = torch.load(rng_file) + except: + logger.warn(f'Load error: {rng_file}') + return random.setstate(checkpoint_rng_state["python"]) np.random.set_state(checkpoint_rng_state["numpy"]) torch.random.set_rng_state(checkpoint_rng_state["cpu"]) @@ -2794,17 +2821,32 @@ class Trainer: Subclass and override for custom behavior. """ - if self.label_smoother is not None and "labels" in inputs: + # Sophia optimizer + optimizer = self.optimizer + if hasattr(optimizer, 'optimizer'): + optimizer = optimizer.optimizer # deepspeed zero optimizer -> sophia + need_update_hessian = ( + isinstance(optimizer, SophiaG) and self.state.global_step % optimizer.hess_interval == 0 + ) + if need_update_hessian or (self.label_smoother is not None and "labels" in inputs): labels = inputs.pop("labels") else: labels = None + outputs = model(**inputs) # Save past state if it exists # TODO: this needs to be fixed and made cleaner later. 
if self.args.past_index >= 0: self._past = outputs[self.args.past_index] - if labels is not None: + if need_update_hessian: + optimizer.need_update_hessian = True + logits = outputs["logits"] if isinstance(outputs, dict) else outputs[0] + samp_dist = torch.distributions.Categorical(logits=logits) + y_sample = samp_dist.sample() + y_sample.masked_fill_(labels == -100, -100) + loss = torch.nn.functional.cross_entropy(logits.view(-1, logits.size(-1)), y_sample.view(-1)) + elif labels is not None: if is_peft_available() and isinstance(model, PeftModel): model_name = unwrap_model(model.base_model)._get_name() else: @@ -2884,6 +2926,19 @@ class Trainer: raise ValueError("Install Accelerate from main branch") try: state_dict = self.accelerator.get_state_dict(self.deepspeed) + + from .integrations import is_deepspeed_zero3_enabled + + if is_deepspeed_zero3_enabled(): + def _call_state_dict_hooks(module, prefix=''): + for name, child in module._modules.items(): + if child is not None: + _call_state_dict_hooks(child, prefix=prefix + name + '.') + for hook in module._state_dict_hooks.values(): + hook(module, state_dict, prefix, local_metadata={}) + + _call_state_dict_hooks(self.model) + if self.args.should_save: self._save(output_dir, state_dict=state_dict) except ValueError: diff --git a/src/transformers/training_args.py b/src/transformers/training_args.py index 07e3d04ef..e302dd1b6 100644 --- a/src/transformers/training_args.py +++ b/src/transformers/training_args.py @@ -154,6 +154,7 @@ class OptimizerNames(ExplicitEnum): PAGED_LION = "paged_lion_32bit" PAGED_LION_8BIT = "paged_lion_8bit" RMSPROP = "rmsprop" + SOPHIA = "sophia" # TODO: `TrainingArguments` users rely on it being fully mutable. In the future see if we can narrow this to a few keys: https://github.com/huggingface/transformers/pull/25903
git clone git@github.com:huggingface/transformers.git
cd transformers
git checkout tags/v4.44.2
# NOTE(review): the patch above references `padding_mask` and
# `_set_gradient_checkpointing(self, module, value=False)`, which look like an older
# transformers than v4.44.2 -- check out the tag the patch was made against before `git apply`.
pip install -e .

vllm

# --- Serve a model with the OpenAI-compatible server ---
pip install vllm
# Runs in the foreground; run the client check below from another shell.
vllm serve LLMs/Meta-Llama-3-8B-Instruct --dtype bfloat16 --api-key hello --port 8999

# --- Client check: list the served models via the OpenAI SDK ---
python - <<'EOF'
from openai import OpenAI
client = OpenAI(api_key="hello", base_url="http://127.0.0.1:8999/v1")
print(client.models.list())
EOF

# --- Install vLLM from source ---
git clone https://github.com/vllm-project/vllm.git
cd vllm
pip install -e . --extra-index-url https://download.pytorch.org/whl/cu121
# If the build fails because nvcc cannot be found, point CUDA_HOME at the toolkit and retry.
# NOTE(review): if the toolkit came from conda (see the cuda section), CUDA_HOME is
# presumably $CONDA_PREFIX rather than /usr/local/cuda -- confirm.
export CUDA_HOME=/usr/local/cuda
export PATH="${CUDA_HOME}/bin:$PATH"
nvcc --version                   # verify that nvcc is in your PATH
${CUDA_HOME}/bin/nvcc --version  # verify that nvcc is in your CUDA_HOME

# --- Install a prebuilt vLLM wheel for CUDA 11.8 ---
# NOTE(review): `vllm serve` above looks like a newer CLI than 0.4.0 provides -- confirm
# the entrypoint available in this version before using the serve command with it.
export VLLM_VERSION=0.4.0
export PYTHON_VERSION=310
pip install https://github.com/vllm-project/vllm/releases/download/v${VLLM_VERSION}/vllm-${VLLM_VERSION}+cu118-cp${PYTHON_VERSION}-cp${PYTHON_VERSION}-manylinux1_x86_64.whl --extra-index-url https://download.pytorch.org/whl/cu118