From 9bf8d0534d847fb72dbaf9c05bdb6c323ce79806 Mon Sep 17 00:00:00 2001
From: Junxian Guo
Date: Sat, 16 Nov 2024 13:54:05 -0500
Subject: [PATCH 1/2] [fix] unexpected behaviour for args

---
 tinychat/models/llama.py       | 4 ++--
 tinychat/modules/fused_attn.py | 2 +-
 2 files changed, 3 insertions(+), 3 deletions(-)

diff --git a/tinychat/models/llama.py b/tinychat/models/llama.py
index edc3d4e..185d721 100644
--- a/tinychat/models/llama.py
+++ b/tinychat/models/llama.py
@@ -85,7 +85,7 @@ def __init__(self, args):
         self.num_key_value_groups = self.num_heads // self.num_key_value_heads
         self.max_position_embeddings = args.max_position_embeddings
         self.rope_theta = args.rope_theta
-        self.rope_scaling = args.rope_scaling
+        self.rope_scaling = getattr(args, 'rope_scaling', None)
         if self.rope_scaling is None:
             self.rope_scaling = 1.0
         else:
@@ -304,7 +304,7 @@ def __init__(self, params):
         self.norm = RMSNorm(params.hidden_size, eps=params.rms_norm_eps)
 
         # Note (Haotian): rope_theta has to be defined here, otherwise context stage is wrong.
-        rope_scale = self.params.rope_scaling
+        rope_scale = getattr(self.params, 'rope_scaling', None)
         if rope_scale is None:
             rope_scale = 1.0
         else:
diff --git a/tinychat/modules/fused_attn.py b/tinychat/modules/fused_attn.py
index 48151c9..32c266b 100644
--- a/tinychat/modules/fused_attn.py
+++ b/tinychat/modules/fused_attn.py
@@ -340,7 +340,7 @@ def __init__(self, hidden_size, num_heads, qkv_layer, o_proj, dev, args):
         self.num_key_value_groups = self.num_heads // self.num_key_value_heads
         self.max_position_embeddings = args.max_position_embeddings
         self.rope_theta = args.rope_theta
-        self.rope_scaling = args.rope_scaling
+        self.rope_scaling = getattr(args, 'rope_scaling', None)
         if self.rope_scaling is None:
             self.rope_scaling = 1.0
 

From 987417f4a449f3ea18c9a89b1e4248e00848e390 Mon Sep 17 00:00:00 2001
From: Junxian Guo
Date: Sat, 16 Nov 2024 15:09:07 -0500
Subject: [PATCH 2/2] [update] replace fastertransformer like kv cache and
 kernel with flash attention for support of longer sequence

---
 tinychat/modules/fused_attn.py | 104 +++++++++++----------------------
 1 file changed, 34 insertions(+), 70 deletions(-)

diff --git a/tinychat/modules/fused_attn.py b/tinychat/modules/fused_attn.py
index 32c266b..a43b133 100644
--- a/tinychat/modules/fused_attn.py
+++ b/tinychat/modules/fused_attn.py
@@ -354,29 +354,26 @@ def __init__(self, hidden_size, num_heads, qkv_layer, o_proj, dev, args):
             torch.zeros(
                 (
                     max_batch_size,
-                    self.num_key_value_heads,
-                    # args.max_position_embeddings,
                     kv_max_seq_len,
+                    self.num_key_value_heads,
                     self.head_dim,
-                )
+                ),
+                device=dev,
+                dtype=torch.float16,
             )
-            .to(dev)
-            .half()
         )  # added to half
         # 8: pack 8 fp16 in FT, if fp32 then use 4
         self.cache_k = (
             torch.zeros(
                 (
                     max_batch_size,
-                    self.num_key_value_heads,
-                    self.head_dim // 8,
-                    # args.max_position_embeddings,
                     kv_max_seq_len,
-                    8,
-                )
+                    self.num_key_value_heads,
+                    self.head_dim,
+                ),
+                device=dev,
+                dtype=torch.float16,
             )
-            .to(dev)
-            .half()
         )  # added to half
 
         # dummy
@@ -400,82 +397,49 @@ def forward(
             self.n_local_heads + self.num_key_value_heads * 2,
             self.head_dim,
         )
-        xq = xqkv[:, :, 0 : self.n_local_heads]
+        xq = xqkv[:, :, 0 : self.n_local_heads].contiguous()
         xk = xqkv[
             :, :, self.n_local_heads : (self.n_local_heads + self.num_key_value_heads)
-        ]
-        xv = xqkv[:, :, -self.num_key_value_heads :]
+        ].contiguous()
+        xv = xqkv[:, :, -self.num_key_value_heads :].contiguous()
 
         if seqlen > 1:
             xq, xk = apply_rotary_emb(xq, xk, freqs_cis=freqs_cis)
 
-            self.cache_k = self.cache_k.to(xq)
-            self.cache_v = self.cache_v.to(xq)
+            self.cache_v[:bsz, start_pos : start_pos + seqlen] = xv
+            self.cache_k[:bsz, start_pos : start_pos + seqlen] = xk
 
-            values_store = xv.transpose(2, 1)
-            keys_store = (
-                xk.reshape(bsz, seqlen, self.num_key_value_heads, self.head_dim // 8, 8)
-                .permute(0, 2, 3, 1, 4)
-                .contiguous()
-            )
+            # if chunk_prefilling:
+            keys = self.cache_k[:bsz:, 0 : start_pos + seqlen]
+            values = self.cache_v[:bsz:, 0 : start_pos + seqlen]
 
-            self.cache_v[:bsz, :, start_pos : start_pos + seqlen, :] = values_store
-            self.cache_k[:bsz, :, :, start_pos : start_pos + seqlen, :] = keys_store
+            # else:
+            #     keys = xk
+            #     values = xv
 
-            if chunk_prefilling:
-                keys = self.cache_k[:, :, :, 0 : start_pos + seqlen, :]
-                keys = (
-                    keys.permute(0, 3, 1, 2, 4)
-                    .reshape(
-                        bsz, start_pos + seqlen, self.num_key_value_heads, self.head_dim
-                    )
-                    .contiguous()
-                )
-                values = self.cache_v[:, :, 0 : start_pos + seqlen, :]
-                values = (
-                    values.transpose(2, 1)
-                    .reshape(
-                        bsz, start_pos + seqlen, self.num_key_value_heads, self.head_dim
-                    )
-                    .contiguous()
-                )
-            else:
-                keys = xk
-                values = xv
-
-            keys = torch.repeat_interleave(
-                keys, dim=2, repeats=self.num_key_value_groups
-            )
-            values = torch.repeat_interleave(
-                values, dim=2, repeats=self.num_key_value_groups
-            )
             output = flash_attn_func(
                 q=xq,
                 k=keys,
                 v=values,
                 causal=True,
             )
-            output = output.contiguous().view(bsz, seqlen, -1)
+            output = output.view(bsz, seqlen, -1)
         else:
-            xq = xq.view(bsz, self.n_local_heads, self.head_dim)
-            xk = xk.view(bsz, self.num_key_value_heads, self.head_dim)
-            xv = xv.view(bsz, self.num_key_value_heads, self.head_dim)
+            xq, xk = apply_rotary_emb(xq, xk, freqs_cis=freqs_cis)
 
-            output = awq_inference_engine.single_query_attention(
-                xq,
-                xk,
-                xv,
-                self.cache_k,
-                self.cache_v,
-                None,
-                None,
-                start_pos,
-                self.head_dim,
-                self.rope_theta,
-                self.rope_scaling,
-                True,
+            self.cache_v[:bsz, start_pos : start_pos + seqlen] = xv
+            self.cache_k[:bsz, start_pos : start_pos + seqlen] = xk
+
+            keys = self.cache_k[:bsz, 0 : start_pos + seqlen]
+            values = self.cache_v[:bsz, 0 : start_pos + seqlen]
+
+            output = flash_attn_func(
+                q=xq,
+                k=keys,
+                v=values,
+                causal=True,
             )
-            output = output.reshape(bsz, 1, -1)
+            output = output.view(bsz, seqlen, -1)
 
         return self.o_proj(output)
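
Notes on the two patches (illustrative sketches, not part of the diffs above):

Patch 1: `args.rope_scaling` raises `AttributeError` for model configs that never define the field, so the lookup becomes `getattr(args, 'rope_scaling', None)` with a fallback scale of 1.0. The sketch below is a minimal, standalone version of that behaviour; the helper name `resolve_rope_scale` is made up here, and the reciprocal-of-`"factor"` branch is an assumption based on the surrounding tinychat code (the `else:` body sits outside the hunks).

```python
# Hedged sketch of the rope_scaling fallback introduced by patch 1.
# `resolve_rope_scale` is an illustrative helper, not a tinychat function.
from types import SimpleNamespace


def resolve_rope_scale(args) -> float:
    """Return the rotary-embedding scale implied by `args`."""
    rope_scaling = getattr(args, "rope_scaling", None)  # missing attribute -> None, no AttributeError
    if rope_scaling is None:
        return 1.0
    # Assumption: as in the surrounding tinychat code, a present value is a dict
    # whose "factor" entry is converted into a reciprocal scale.
    return 1.0 / rope_scaling["factor"]


print(resolve_rope_scale(SimpleNamespace(rope_theta=10000.0)))            # 1.0 (field absent)
print(resolve_rope_scale(SimpleNamespace(rope_scaling={"factor": 4.0})))  # 0.25
```

Patch 2: the FasterTransformer-style caches (`[bsz, kv_heads, head_dim // 8, seq, 8]` packing plus the `awq_inference_engine.single_query_attention` decode kernel) are replaced by plain fp16 caches of shape `[bsz, kv_max_seq_len, kv_heads, head_dim]`, with `flash_attn_func` handling both the context stage and single-token decode. The following is a hedged, self-contained sketch of that write-then-attend pattern; the tensor sizes and `start_pos` bookkeeping are illustrative, and it needs a CUDA GPU with flash-attn installed.

```python
# Sketch of the cache layout and attention call adopted in patch 2 (assumed sizes).
import torch
from flash_attn import flash_attn_func  # same dependency the patch relies on

bsz, seqlen, start_pos = 1, 16, 0
n_heads, n_kv_heads, head_dim, kv_max_seq_len = 32, 8, 128, 4096
dev = "cuda"

# Caches are ordinary fp16 tensors: [batch, max_seq_len, kv_heads, head_dim].
cache_k = torch.zeros(bsz, kv_max_seq_len, n_kv_heads, head_dim, device=dev, dtype=torch.float16)
cache_v = torch.zeros_like(cache_k)

# Current chunk of (already rotary-embedded) queries/keys/values.
xq = torch.randn(bsz, seqlen, n_heads, head_dim, device=dev, dtype=torch.float16)
xk = torch.randn(bsz, seqlen, n_kv_heads, head_dim, device=dev, dtype=torch.float16)
xv = torch.randn_like(xk)

# Write the new tokens into the cache, then attend over the whole cached prefix.
cache_k[:bsz, start_pos : start_pos + seqlen] = xk
cache_v[:bsz, start_pos : start_pos + seqlen] = xv
keys = cache_k[:bsz, 0 : start_pos + seqlen]
values = cache_v[:bsz, 0 : start_pos + seqlen]

out = flash_attn_func(q=xq, k=keys, v=values, causal=True)  # [bsz, seqlen, n_heads, head_dim]
out = out.view(bsz, seqlen, -1)                             # [bsz, seqlen, n_heads * head_dim]
```

`flash_attn_func` accepts `k`/`v` with fewer heads than `q` (grouped-query attention) and, in flash-attn 2.1+, aligns the causal mask to the bottom-right of the score matrix, so neither the old `torch.repeat_interleave` expansion nor a hand-built mask is needed for chunked prefilling.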