vLLM's V1 Engine Architecture #8779

This issue describes the high-level directions to "create LLM Engine V1". We want the design to be as transparent as possible, so we created this issue to track progress and solicit feedback.

Goal:

  • The new engine will be simple and performant. We found the first iteration of the engine to be simple and the multi-step engine to be performant, but we want the best of both worlds. To be performant, the engine must minimize GPU idle time.
  • The new architecture will be extensible and modular. The current codebase has become difficult to extend with new features (both production and experimental) because different features are tightly entangled. In the new design, features should be compatible with each other.
  • Tech debt will be cleaned up. We will remove optimizations that compromise code readability, and we will redo the ad-hoc implementations added to support certain features/models.

Non-goals (the following are important but orthogonal):

  • Optimize GPU time/kernels
  • Add new features/optimizations
  • Performance in rare cases

The scope is limited to the scheduler, memory manager, and distributed architecture. We will not touch APIs, models, kernels, or most parts of the model runner.

Highlights of the new design:

  • Driver process + SPMD workers
    • When TP=n & PP=m, the vLLM engine will have n*m + 1 processes in total.
      • Corollary: even when using a single GPU, we will have 2 processes.
    • The driver process will have the scheduler, memory manager, etc.
    • The workers are stateful, maintaining most of the request states.
      • The driver will only send the “diffs” (a sketch of these messages follows this list)
        • New request: input token IDs & block tables & sampling params, etc.
        • In-flight request: scheduled request IDs, new block IDs (no token IDs, sampling params, etc.)
    • Clean up data structures like SeqGroupMetadata
  • Async single-step scheduling, instead of multi-step scheduling
    • The scheduler will schedule the (n+1)-th step while the worker is executing the n-th step (see the loop sketched after this list).
    • We will reuse the code from multi-step scheduling to incrementally update the model inputs.
    • Needs special care for PP, since the output token IDs from the last stage should be sent back to the first stage.
  • De-tokenizer moves to the driver process
    • Async de-tokenization can be regarded as part of async scheduling
  • Native support for different types of model states
    • Regular KV cache, Mamba cache, encoder cache, etc.
    • Dedicated memory manager & block table for each type of cache (a sketch of such an interface follows this list)
  • Drop beam search from vLLM engine
    • Provide a solution to emulate beam search outside vLLM engine
  • Prefix-caching as a first-class feature
    • Implement parallel sampling via prefix caching
    • Remove the concept of SequenceGroup
    • Optimize prefix caching overheads
  • Remove/minimize PyObjectCache
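
To make the “diffs” idea above concrete, here is a minimal sketch of what the two message types could look like. The class and field names (NewRequestDiff, CachedRequestDiff, SchedulerOutput) are hypothetical, not vLLM's actual data structures; the point is only that in-flight requests carry far less data than new ones.

```python
# Hypothetical sketch of the per-step "diff" messages the driver sends to
# the stateful SPMD workers. Names are illustrative, not vLLM's actual classes.
from dataclasses import dataclass, field


@dataclass
class NewRequestDiff:
    """Sent once, when a request is scheduled for the first time."""
    request_id: str
    prompt_token_ids: list[int]
    block_ids: list[int]          # block table entries for the KV cache
    sampling_params: dict         # temperature, top_p, ... (placeholder)


@dataclass
class CachedRequestDiff:
    """Sent every step for an in-flight request: only what changed."""
    request_id: str
    new_block_ids: list[int] = field(default_factory=list)
    # No token IDs or sampling params here: the worker already has them.


@dataclass
class SchedulerOutput:
    """What the driver broadcasts to the workers for one step."""
    new_requests: list[NewRequestDiff]
    cached_requests: list[CachedRequestDiff]
    num_scheduled_tokens: dict[str, int]  # request_id -> tokens this step
```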
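
Similarly, the async single-step scheduling loop can be sketched as below: the driver schedules step n+1 on the CPU while the workers execute step n on the GPU. The `scheduler` and `workers` objects and their methods are made up for illustration; this is not the planned implementation.

```python
# Simplified illustration of async single-step scheduling: the driver
# prepares step n+1 on the CPU while the workers run step n on the GPU.
# `scheduler` and `workers` are hypothetical objects with made-up methods.

def engine_loop(scheduler, workers):
    scheduler_output = scheduler.schedule()          # step 0

    while scheduler.has_unfinished_requests():
        # Launch step n; returns a future instead of blocking on the GPU.
        future = workers.execute_model_async(scheduler_output)

        # Overlap: schedule step n+1 while step n is still running.
        # This assumes each running request yields exactly one new token,
        # which is why PP needs the special care mentioned above.
        next_scheduler_output = scheduler.schedule()

        # Collect step n's sampled tokens and update the request states.
        sampled_token_ids = future.result()
        scheduler.update_from_output(scheduler_output, sampled_token_ids)

        scheduler_output = next_scheduler_output
```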
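
Finally, for the per-state-type memory managers, one possible shape is a shared allocator interface with one instance per cache type. The interface and the simplified KVCacheManager below are assumptions for illustration only.

```python
# Illustrative only: one memory manager per model-state type, behind a
# shared interface. Class and method names are assumptions, not vLLM's API.
from abc import ABC, abstractmethod


class CacheManager(ABC):
    @abstractmethod
    def allocate(self, request_id: str, total_num_tokens: int) -> list[int]:
        """Ensure space for the request's tokens; return newly added block IDs."""

    @abstractmethod
    def free(self, request_id: str) -> None:
        """Release everything held by a finished request."""


class KVCacheManager(CacheManager):
    """Paged KV blocks with a per-request block table."""

    def __init__(self, num_blocks: int, block_size: int) -> None:
        self.block_size = block_size
        self.free_blocks = list(range(num_blocks))
        self.block_tables: dict[str, list[int]] = {}

    def allocate(self, request_id: str, total_num_tokens: int) -> list[int]:
        table = self.block_tables.setdefault(request_id, [])
        needed = -(-total_num_tokens // self.block_size) - len(table)  # ceil div
        new_blocks = [self.free_blocks.pop() for _ in range(max(needed, 0))]
        table.extend(new_blocks)
        return new_blocks

    def free(self, request_id: str) -> None:
        self.free_blocks.extend(self.block_tables.pop(request_id, []))


# A Mamba cache (fixed-size per-request state) or an encoder cache
# (multimodal encoder outputs) would get its own manager with the same
# interface, so the scheduler can treat all state types uniformly.
```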

Lessons we learned from the current engine:

  • To achieve high GPU utilization, we should care about everything happening on the CPU.
    • Python is slow.
    • Fast GPUs like H100 do not necessarily come with fast CPUs. The host may have hundreds of CPU cores, but each core runs at a low clock speed.
    • Moreover, GPUs will get faster and faster, while CPUs will not.
  • Scheduling is not cheap.
    • For every step, the vLLM scheduler iterates over the whole self.running queue and performs some operations for each request (e.g., allocating a new block), and all of this is written in Python.
  • Input broadcasting is expensive.
    • Instead of sending request information from scheduler to workers every step, the workers should be stateful and maintain most of the request states.
  • Preparing the model & sampler inputs (e.g., block table) is expensive.
    • We should cache the inputs of the previous steps and **build new inputs incrementally from the cached inputs**, if possible (see the sketch after this list).
    • However, not every state should be kept in GPU memory. It’s OK to cache & incrementally build some inputs in CPU memory, and send them to GPU every step.
  • De-tokenization is expensive.
    • For every step, vLLM de-tokenizes the generated output token IDs and checks the stop criteria.
    • The overhead becomes significant for large batch sizes.
  • Sampler is expensive.
    • The GPU operations themselves are not very expensive.
    • However, “pythonizing” the sampler outputs is expensive.
    • Plus, the sampler can launch many small GPU kernels with CPU-GPU synchronizations.
  • Supporting different types of model states (e.g., KV cache, Mamba cache, encoder cache) is challenging.
    • We need native cache managers for these different types of caches.
    • We need to deal with memory fragmentation caused by the different sizes of the different states.
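
As an illustration of caching and incrementally building inputs in CPU memory (the block-table point above), the sketch below keeps a pinned CPU buffer across steps, applies only the per-step diffs, and copies the changed rows to the GPU. Buffer shapes, names, and the update pattern are assumptions, not vLLM's actual implementation.

```python
# Illustrative pattern for caching model inputs on the CPU and building
# them incrementally. Shapes, names, and the layout are assumptions.
import torch

MAX_REQS, MAX_BLOCKS = 256, 512

# Persistent buffers that survive across steps.
block_table_cpu = torch.zeros(MAX_REQS, MAX_BLOCKS, dtype=torch.int32,
                              pin_memory=True)
num_blocks = [0] * MAX_REQS
block_table_gpu = torch.zeros(MAX_REQS, MAX_BLOCKS, dtype=torch.int32,
                              device="cuda")


def append_new_blocks(row: int, new_block_ids: list[int]) -> None:
    """Apply one request's per-step diff: append its newly allocated blocks."""
    start = num_blocks[row]
    end = start + len(new_block_ids)
    block_table_cpu[row, start:end] = torch.tensor(new_block_ids,
                                                   dtype=torch.int32)
    num_blocks[row] = end


def sync_to_gpu(changed_rows: list[int]) -> None:
    """Copy only the rows that changed this step (async, from pinned memory)."""
    for row in changed_rows:
        block_table_gpu[row].copy_(block_table_cpu[row], non_blocking=True)
```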

Timeline-wise, we plan to execute the changes incrementally. Over time, we will add PRs and issues related to the new architecture here.

The design is led by the vLLM maintainers @WoosukKwon @zhuohan123 @youkaichao @simon-mo @LiuXiaoxuanPKU @comaniac @alexm-neuralmagic @njhill @robertgshaw2-neuralmagic @rkooo567 and many others!
