
[Feature] Support Deepseek-VL2 #2798

Open · wants to merge 21 commits into base: main
Conversation


@ccw1996 ccw1996 commented Jan 8, 2025

Motivation

Add Deepseek-VL2 model to SGLang, as requested in #2653

Modifications

  1. Add new model Deepseek-VL2

Checklist

  • Format your code according to the Contributor Guide.
  • Add unit tests as outlined in the Contributor Guide.
  • Update documentation as needed, including docstrings or example tutorials.

@zhyncs added the enhancement (New feature or request) label on Jan 8, 2025
@@ -0,0 +1,127 @@
from typing import List, Optional, Tuple, Union
Collaborator

rename the file to deepseek_vl2?

Author

@ccw1996 ccw1996 Jan 12, 2025

rename done


self.layers = modules

def forward(self, x):
Author

I have not yet implemented the forward part of DeepseekV2ForCausalLM. I will finish the full implementation and add the unit tests this weekend.

@yizhang2077
Collaborator

@ccw1996 Do you need our help?

@SashvDave

SashvDave commented Jan 29, 2025

Has support for deepseek vl2 been implemented?

Comment on lines 220 to 229
if config.projector_type == "downsample_mlp_gelu":
    mlp_depth = config.depth
    mlp_ratio = config.mlp_ratio
    modules = [
        nn.Linear(
            config.input_dim * config.downsample_ratio * config.downsample_ratio,
            config.n_embed * mlp_ratio,
        )
    ]
    for _ in range(1, mlp_depth - 1):
        modules.append(nn.GELU())
        modules.append(nn.Linear(config.n_embed * mlp_ratio, config.n_embed * mlp_ratio))
    modules.append(nn.GELU())
    modules.append(nn.Linear(config.n_embed * mlp_ratio, config.n_embed))
    modules = nn.Sequential(*modules)
Contributor

@Edenzzzz Edenzzzz Jan 30, 2025

@ccw1996 I'm happy to take over the rest of the work and parallelize the remaining functions. Could you give me access to your branch?

@Edenzzzz
Contributor

@ccw1996 Apologies for the delay. Would you like me to help with the rest of it?

@ccw1996
Author

ccw1996 commented Jan 31, 2025

@ccw1996 Do you need our help?

@ccw1996 Apologies for the delay. Would you like me to help with the rest of it?

Sure. I'm having some trouble adapting the preprocessing; I need help adapting the SigLIP implementation without timm.

I will update the rest of my implementation code later.

@Edenzzzz
Contributor

Edenzzzz commented Jan 31, 2025

@ccw1996 I see. I think you can copy those layers from timm into python/sglang/srt/models/deepseekvl2.py and then replace them with sgl classes. I'm interested in helping if you can give me access.

@Edenzzzz
Contributor

Edenzzzz commented Feb 4, 2025

@yizhang2077 @ispobock Looks like we'll have to copy lots of code from timm. Now it's mostly just the linear layers with variable depth left to parallelize; I will finish soon.

@ccw1996
Author

ccw1996 commented Feb 6, 2025

@ccw1996 Apologies for the delay. Would you like me to help with the rest of it?

@Edenzzzz Hello, I have run DeepSeek-VL2 successfully with the timm preprocessing, but I'm confused because the results contain some unexpected values. Can you help me find the reason?

@Edenzzzz
Contributor

Edenzzzz commented Feb 6, 2025

@Edenzzzz Hello, I have run DeepSeek-VL2 successfully with the timm preprocessing, but I'm confused because the results contain some unexpected values. Can you help me find the reason?

Sure, can you mark the problematic part?

Comment on lines 640 to 659
if config.projector_type == "downsample_mlp_gelu":
    mlp_depth = config.depth
    mlp_ratio = config.mlp_ratio
    modules = [
        nn.Linear(
            config.input_dim * config.downsample_ratio * config.downsample_ratio,
            config.n_embed * mlp_ratio,
        )
    ]
    for _ in range(1, mlp_depth - 1):
        modules.append(nn.GELU())
        modules.append(
            nn.Linear(config.n_embed * mlp_ratio, config.n_embed * mlp_ratio)
        )
    modules.append(nn.GELU())
    modules.append(nn.Linear(config.n_embed * mlp_ratio, config.n_embed))
    modules = nn.Sequential(*modules)

Contributor

This part needs to be parallelized with column- and row-parallel linear layers.

Contributor

@yizhang2077 Actually, with GELU we'll have to gather the output of each TP linear. Should we use replicated linear instead?
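
For context on the gather question above: an elementwise nonlinearity such as GELU cannot be applied directly to the partial sums produced by a row-parallel linear; the partial outputs must be reduced (gathered/all-reduced) first, which is what makes replicated linears attractive here. A minimal numpy sketch of why (illustrative only, not SGLang code):

```python
import math
import numpy as np

def gelu(x):
    # exact GELU: 0.5 * x * (1 + erf(x / sqrt(2)))
    return 0.5 * x * (1.0 + np.vectorize(math.erf)(x / math.sqrt(2.0)))

rng = np.random.default_rng(0)
x = rng.normal(size=4)       # input row
w = rng.normal(size=(4, 4))  # full weight of one linear layer

# Row-parallel TP: the input dim is split across 2 ranks; each rank computes
# a partial product, and the true output is the SUM of the partials.
partial0 = x[:2] @ w[:2]
partial1 = x[2:] @ w[2:]

full = gelu(partial0 + partial1)         # reduce first, then GELU: correct
wrong = gelu(partial0) + gelu(partial1)  # GELU on partial sums: incorrect

assert np.allclose(partial0 + partial1, x @ w)
assert not np.allclose(full, wrong)
```

Since GELU does not distribute over the sum, every row-parallel layer in the projector would need an all-reduce before its activation, eroding the benefit of sharding such a small MLP.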

@ccw1996
Author

ccw1996 commented Feb 9, 2025

@Edenzzzz Hello, I have run DeepSeek-VL2 successfully with the timm preprocessing, but I'm confused because the results contain some unexpected values. Can you help me find the reason?
Sure, can you mark the problematic part?

Two problems. One is that the radix cache corrupts the input; I will try to fix it. The second is that the output does not seem to use the image embeddings. Can you help me debug it?

@Edenzzzz
Contributor

Edenzzzz commented Feb 9, 2025

Let me try tomorrow

Comment on lines 1128 to 1130
input_embeds[idx].masked_scatter_(
    image_seq_mask[idx].unsqueeze(-1), images_in_this_batch
)
Contributor

@Edenzzzz Edenzzzz Feb 10, 2025

@ccw1996 The image embedding (images_in_this_batch) is indeed applied to the text embedding here.
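
For readers following the thread: `masked_scatter_` writes the image embeddings into the masked token positions of the text embedding, in order. A rough numpy equivalent with hypothetical shapes (not the PR's code):

```python
import numpy as np

# Hypothetical shapes for illustration: 6 token positions, hidden size 3.
input_embeds = np.zeros((6, 3))  # text embeddings (zeros for clarity)
image_seq_mask = np.array([False, True, True, False, True, False])
images_in_this_batch = np.arange(9.0).reshape(3, 3)  # 3 image-token embeddings

# numpy analogue of Tensor.masked_scatter_: fill the masked positions,
# in order, with the image-token embeddings.
input_embeds[image_seq_mask] = images_in_this_batch

assert np.allclose(input_embeds[1], [0.0, 1.0, 2.0])  # first image token
assert np.allclose(input_embeds[0], 0.0)              # text position untouched
```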

Author

@Edenzzzz Thanks a lot. Now it outputs the right answer. I will finish the CUDA graph support and clean up the code this weekend.

Comment on lines +157 to +161
logger.info(
    "Automatically turn off --chunked-prefill-size and disable radix cache for deepseek-vl2."
)
server_args.chunked_prefill_size = -1
server_args.disable_radix_cache = True
Contributor

The language part still supports radix cache.

Author

@ccw1996 ccw1996 Feb 11, 2025

The language part relies on the input embeddings. If the radix cache is used, the input embeddings are wrong. I will try to debug it.

Contributor

@Edenzzzz Edenzzzz Feb 11, 2025

I see, I think you're right. Llava and qwen_vl also don't use radix attention.

],
)
cls.base_url += "/v1"

if __name__ == "__main__":
unittest.main()
Contributor

@ccw1996 This seems mostly ready. Did you encounter a 400 Bad Request when running Qwen-VL?
[screenshot]

Author

I'm not sure about qwen-vl; I tested qwen2-vl and it passed.

Contributor

The tests have not passed. We should test deepseek-vl2, not qwen-vl. There's a dim mismatch when capturing the CUDA graph. You can try to fix it, and then it should be ready.

Author

Sorry, I fixed these errors in the latest commit. Now the tests pass.

@ccw1996 ccw1996 changed the title [WIP] [Feature] Support Deepseek-VL2 [Feature] Support Deepseek-VL2 Feb 14, 2025
@ccw1996 ccw1996 marked this pull request as ready for review February 14, 2025 12:01
@ccw1996
Author

ccw1996 commented Feb 14, 2025

@Edenzzzz Can you help me merge all the commits? It's ready now. Thanks a lot.

Comment on lines 887 to 909
    modules = ReplicatedLinear(
        config.input_dim,
        config.n_embed,
        quant_config=quant_config,
    )

elif config.projector_type == "mlp_gelu":
    mlp_depth = config.depth
    modules = [
        ReplicatedLinear(
            config.input_dim,
            config.n_embed,
            quant_config=quant_config,
        )
    ]
    for _ in range(1, mlp_depth):
        modules.append(nn.GELU())
        modules.append(
            ReplicatedLinear(
                config.n_embed,
                config.n_embed,
                quant_config=quant_config,
            )
        )
    modules = nn.Sequential(*modules)
Contributor

There are still bugs when running the test. For the replaced linear layers, we need to take the first element of the output tuple.
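
The issue is that the parallel linear layers return an `(output, bias)` tuple, which `nn.Sequential` cannot chain directly. One way to handle it, sketched here with a stand-in class (`FakeReplicatedLinear` is hypothetical, standing in for the real `ReplicatedLinear`):

```python
# Sketch: chain layers whose forward returns an (output, extra) tuple.
# FakeReplicatedLinear is a stand-in for sglang's ReplicatedLinear here.

class FakeReplicatedLinear:
    def __init__(self, scale):
        self.scale = scale

    def __call__(self, x):
        # Returns (output, bias), like the parallel linear layers.
        return x * self.scale, None


class TakeFirst:
    """Unwraps the first tuple element so the next layer sees a plain value."""
    def __init__(self, layer):
        self.layer = layer

    def __call__(self, x):
        out, _ = self.layer(x)
        return out


layers = [TakeFirst(FakeReplicatedLinear(2)), TakeFirst(FakeReplicatedLinear(3))]
x = 5
for layer in layers:
    x = layer(x)

assert x == 30  # 5 * 2 * 3
```

In the real model, the same unwrapping could be done with a small `nn.Module` wrapper around each `ReplicatedLinear`, or by replacing `nn.Sequential` with an explicit forward loop that discards the bias element.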
