Ubicloud is an open source alternative to AWS. We offer managed cloud services that build on top of PostgreSQL, Kubernetes, vLLM, and others.
vLLM is an open-source inference engine that serves large language models. We deploy multiple vLLM instances across GPUs and load open-weight models like Llama 4 into them. We then load balance traffic across vLLM instances, run health checks, and perform upgrades. Our customers consume our managed service by sending their prompts to our API endpoint, which also determines the vLLM instance that serves each prompt.
vLLM sits at the intersection of AI and systems programming, so we thought that diving into its details might interest some of our readers. In this blog post, we describe how an inference request travels through vLLM’s OpenAI-compatible API server and core engine. We also provide key code pointers.
We assume readers are already familiar with the transformer architecture and large language models. If you're not, we highly recommend this video by OpenAI co-founder Andrej Karpathy. We will focus on the new V1 architecture of vLLM and how it achieves state-of-the-art text generation performance. If you're looking for the V0 behavior or multi-modal inference, please refer to other vLLM documentation.
We use the following terms throughout this blog. These terms also align with what is used in vLLM’s codebase and documentation:
Before diving into the request flow, let’s first take a look at the main components of vLLM V1’s serving architecture:
The journey begins when an HTTP request arrives at the vLLM server (e.g., a POST to /v1/chat/completions). The OpenAI-compatible API server handles the HTTP communication and authentication (via the API key, if one is configured). This server is typically launched by running the vllm serve command, which is defined in vllm/entrypoints/cli/serve.py.
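For reference, here is a minimal sketch of such a request, assuming a server started locally with vllm serve on the default port 8000 and no API key configured; the model name is illustrative.

```python
# Minimal client sketch against a locally running vLLM server
# (assumes `vllm serve <model>` on the default port 8000, no API key).
from openai import OpenAI

client = OpenAI(
    base_url="http://localhost:8000/v1",  # vLLM's OpenAI-compatible endpoint
    api_key="EMPTY",                      # placeholder; only checked if the server enforces a key
)

response = client.chat.completions.create(
    model="deepseek-ai/DeepSeek-R1-Distill-Qwen-32B",  # illustrative model name
    messages=[{"role": "user", "content": "What is continuous batching?"}],
)
print(response.choices[0].message.content)
```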
After validation, the server invokes the AsyncLLM engine to handle the request. In code, this is done by calling the AsyncLLM's generate() method, which is given the prompt and a freshly generated request_id to track the request. When generate() is called, the AsyncLLM performs tokenization with the model's tokenizer (often from Hugging Face) to convert the text into token IDs. Some multimodal input processing also happens at this stage, but for simplicity, we'll focus on text prompts here. The AsyncLLM then sends the request to the EngineCore via the AsyncMPClient using asynchronous IPC calls.
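As a standalone illustration of the tokenization step (not vLLM's internal code), here is what converting text into token IDs looks like with a Hugging Face tokenizer; the model name matches the DeepSeek R1 Distill Qwen 32B model mentioned later, but any tokenizer works the same way.

```python
# Standalone sketch of the tokenization step: text in, token IDs out.
# Requires the transformers library; the model name is illustrative.
from transformers import AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained("deepseek-ai/DeepSeek-R1-Distill-Qwen-32B")
token_ids = tokenizer.encode("Explain continuous batching in one sentence.")
print(token_ids)  # a list of integer token IDs

# From here, the AsyncLLM attaches a request_id to the token IDs and
# forwards the request to the EngineCore process over IPC.
```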
Then, a background loop at the EngineCore picks up the request from the IPC channel and places it into an internal input queue for scheduling.
Note that the AsyncLLM and EngineCore run in different processes. This bypasses Python's Global Interpreter Lock (GIL) and allows CPU-intensive tasks, like tokenization and HTTP handling, to run alongside the GPU-intensive model execution, maximizing overall throughput.
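vLLM implements this split with its own IPC layer (the AsyncMPClient mentioned above). The snippet below is only a generic illustration of the pattern using Python's multiprocessing, not vLLM's actual code: CPU-bound work stays in the front-end process while a separate process runs the engine loop.

```python
# Generic illustration of the two-process pattern (not vLLM's actual code):
# the front-end process keeps handling HTTP and tokenization while a
# separate engine process does the heavy lifting, connected via IPC queues.
import multiprocessing as mp

def engine_core(inputs, outputs):
    # Busy loop: pick up requests, pretend to run the model, send results back.
    while True:
        request = inputs.get()
        if request is None:          # shutdown signal
            break
        request_id, token_ids = request
        outputs.put((request_id, f"processed {len(token_ids)} tokens"))

if __name__ == "__main__":
    inputs, outputs = mp.Queue(), mp.Queue()
    engine = mp.Process(target=engine_core, args=(inputs, outputs))
    engine.start()

    inputs.put(("req-1", [101, 2023, 2003, 102]))   # (request_id, token IDs)
    print(outputs.get())                            # ('req-1', 'processed 4 tokens')

    inputs.put(None)                                # tell the engine to exit
    engine.join()
```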
The Scheduler module is central to vLLM’s ability to achieve high throughput. It keeps track of all requests and orchestrates their progress. The scheduler maintains a waiting deque for requests that are ready to be processed (new or resumed ones) and a running list for requests currently in active generation. In each iteration (cycle) of the engine, it first examines the internal input queue and adds new requests into the scheduler’s waiting deque. It then picks a set of requests to advance, looking at the requests from the running list first, and then the requests from the waiting deque.
vLLM uses a continuous batching algorithm. This maximizes GPU utilization by keeping the GPU fed with a full workload while staying within a fixed token budget (a.k.a. max_num_batched_tokens). It also ensures fairness in request ordering and limits each batch to a single forward pass of the model.
Here’s a continuous batching example:
Suppose we have three requests with 3, 5, and 12 prompt tokens, respectively, and a token budget of 10. In a traditional server, these might be handled sequentially or in fixed-size batches. In vLLM, however, the scheduler could decide in one iteration (Step 0) to feed, for example, 3 prompt tokens from R1, 5 from R2, and 2 from R3 in a single forward pass. Note that it only picks two prompt tokens from R3 because of the token budget of 10. In the next iteration (Step 1), it can continue with 8 more prompt tokens from R3 (the last 2 carry over to the following step). At the same time, it will start generating 1 token each for R1 and R2, which have finished processing their prompt tokens and are now in the decoding phase.
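The walkthrough above can be reproduced with a simplified token-budget loop. The sketch below is not vLLM's actual Scheduler; it only mimics the budget accounting: prefill requests consume as many prompt tokens as fit, while decoding requests consume exactly one token per step.

```python
# Simplified sketch of token-budget scheduling (not vLLM's actual Scheduler).
# Each request tracks how many of its prompt tokens are still unprocessed;
# once prefill is done, a request needs exactly one token of budget per step.
TOKEN_BUDGET = 10

requests = [
    {"id": "R1", "prompt_left": 3},
    {"id": "R2", "prompt_left": 5},
    {"id": "R3", "prompt_left": 12},
]

def schedule_step(requests):
    budget = TOKEN_BUDGET
    batch = []
    for req in requests:
        if budget == 0:
            break
        if req["prompt_left"] > 0:          # prefill: take as many prompt tokens as fit
            n = min(req["prompt_left"], budget)
            req["prompt_left"] -= n
        else:                               # decode: one new token per step
            n = 1
        budget -= n
        batch.append((req["id"], n))
    return batch

print(schedule_step(requests))  # Step 0: [('R1', 3), ('R2', 5), ('R3', 2)]
print(schedule_step(requests))  # Step 1: [('R1', 1), ('R2', 1), ('R3', 8)]
```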
Note that during the prefill phase, all prompt tokens from a request can be processed in one batch. This is possible because every prompt token is known up front, so the query (Q) tensors for all prompt positions can be computed in a single pass. In the decoding phase, however, each new token depends on the token generated immediately before it, so we can only process one token at a time per request. The scheduler ensures that each request completes the prefill phase before entering the decoding phase.
During the scheduling phase, the KV cache blocks are also determined. The KV Cache Manager groups tokens into fixed-size chunks and, for each chunk, either allocates a new KV block or retrieves an existing one. Note that the actual GPU memory allocation occurs during vLLM initialization. The KV Cache Manager determines the block IDs for each request's KV cache and attaches these IDs to the request, so they can later be used by the ModelRunners when executing the model.
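To make the bookkeeping concrete, here is a toy sketch of mapping a request's tokens to block IDs with a fixed block size. The names and the 16-token block size are illustrative; this is not vLLM's actual KVCacheManager, and it omits prefix caching (reusing existing blocks for shared prefixes).

```python
# Toy sketch of block allocation: token positions are grouped into
# fixed-size chunks, and each chunk maps to one pre-allocated block ID.
BLOCK_SIZE = 16                      # illustrative; vLLM commonly uses 16 tokens per block

free_blocks = list(range(1000))      # block IDs backed by GPU memory allocated at startup
request_block_table = {}             # request_id -> list of block IDs

def allocate_blocks(request_id, num_tokens):
    num_blocks = -(-num_tokens // BLOCK_SIZE)            # ceiling division
    blocks = [free_blocks.pop() for _ in range(num_blocks)]
    request_block_table[request_id] = blocks
    return blocks

print(allocate_blocks("req-1", 40))  # 40 tokens -> 3 blocks: [999, 998, 997]
```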
Once the scheduler has determined which requests to advance, it invokes the ModelRunners via the Ray-based ModelExecutor. This executes the model's forward pass, either processing prompt tokens or generating new tokens.
All tokens from the selected requests are combined into a single large tensor, a.k.a. a batch, and are processed layer by layer through the transformer's weight matrices. At each layer, the three attention tensors, Query (Q), Key (K), and Value (V), are computed for each attention head, and a final attention output tensor for the layer is computed from them. The K and V tensors are stored in the KV cache for future use, and the output tensor becomes the input to the next layer. This is where GPUs come into play, as they excel at large matrix computations. Each tensor includes data from all batched requests, which allows the matrix operations of these requests to be processed on the GPU in parallel using SIMD/SIMT (single instruction, multiple data / multiple threads) execution across all CUDA cores.
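In spirit, the per-layer computation looks like the single-head PyTorch sketch below. Causal masking, multi-head splitting, and the KV cache itself are omitted for brevity, and all shapes are made up; this is an illustration of the Q/K/V matrix math, not vLLM's attention implementation.

```python
# Minimal single-head attention sketch in PyTorch (no causal mask, no KV cache),
# just to show the Q/K/V matrix computations that run on the GPU.
import torch

batch_tokens, d_model = 10, 64            # e.g. 10 batched tokens across requests
x = torch.randn(batch_tokens, d_model)    # hidden states entering the layer

W_q = torch.randn(d_model, d_model)
W_k = torch.randn(d_model, d_model)
W_v = torch.randn(d_model, d_model)

Q, K, V = x @ W_q, x @ W_k, x @ W_v       # in vLLM, K and V would be written to the KV cache
scores = (Q @ K.T) / d_model ** 0.5       # attention scores (causal masking omitted)
weights = torch.softmax(scores, dim=-1)
output = weights @ V                      # becomes the input to the next layer
print(output.shape)                       # torch.Size([10, 64])
```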
For requests in the decoding phase, the output tensor from the final transformer layer produces the logits, which are scores over the vocabulary for the next token. vLLM then applies the sampling (decoding) strategy: for each sequence, it looks at the logits and either selects the top token (greedy, deterministic decoding) or samples according to the provided parameters, e.g. temperature. This process produces the next token for each batched request. These tokens are placed into an internal output queue, where a background loop picks them up and sends them back to the AsyncLLM via IPC for output processing.
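The two most common strategies can be sketched in a few lines; the vocabulary size and random logits below are made up for illustration.

```python
# Sketch of next-token selection from logits: greedy vs. temperature sampling.
import torch

logits = torch.randn(32000)               # scores over a toy 32k-token vocabulary

# Greedy / deterministic: pick the highest-scoring token.
greedy_token = torch.argmax(logits).item()

# Temperature sampling: soften or sharpen the distribution, then sample.
temperature = 0.8
probs = torch.softmax(logits / temperature, dim=-1)
sampled_token = torch.multinomial(probs, num_samples=1).item()

print(greedy_token, sampled_token)
```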
This timeline diagram shows a portion of the forward pass for the DeepSeek R1 Distilled Qwen 32B model that we serve.
Within the busy loop of the EngineCore, the model's forward method runs. This method calls the forward method of each transformer layer, which in turn calls the FlashAttention module, a highly optimized function for computing transformer attention that uses the FlashAttention-3 algorithm. This model contains 64 transformer layers in total, and the timeline diagram shows 3 of them at the bottom. The model runner repeats the forward pass through all 64 transformer layers at each engine step.
The AsyncLLM receives the new tokens from the IPC channel, detokenizes them, and places them into an internal output queue. The original caller, the AsyncLLM's generate() function, then picks them up and passes them back to the API server.
In non-streaming mode, the API handler accumulates tokens internally until the request finishes and then returns the final output. In streaming mode, the handler sends each partial output to the client immediately as server-sent event chunks like data: {...}.
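With the OpenAI client from earlier, streaming only requires setting stream=True; the client then yields the partial chunks as they arrive (again assuming a local server and an illustrative model name).

```python
# Streaming sketch: the server sends partial outputs as server-sent event
# chunks ("data: {...}"), which the OpenAI client yields incrementally.
from openai import OpenAI

client = OpenAI(base_url="http://localhost:8000/v1", api_key="EMPTY")

stream = client.chat.completions.create(
    model="deepseek-ai/DeepSeek-R1-Distill-Qwen-32B",   # illustrative model name
    messages=[{"role": "user", "content": "Write a haiku about GPUs."}],
    stream=True,
)
for chunk in stream:
    delta = chunk.choices[0].delta.content
    if delta:
        print(delta, end="", flush=True)
```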
In this blog post, we’ve gone through the entire life of a vLLM inference request. The journey starts with the API server handing the request to the AsyncLLM, which tokenizes it and sends it to the EngineCore. At the EngineCore, the scheduler batches requests, and the executor works with the ModelRunners to run forward passes through the transformer layers on GPUs. Finally, tokens are streamed back to clients. Each component plays a distinct role in this lifecycle, and together they allow vLLM to achieve state-of-the-art serving performance.
For AI engineers looking to deploy LLMs, we hope understanding this lifecycle helps in tuning and customizing vLLM. For others, we hope this breakdown provides insight into vLLM's inner workings. By doing so, we aim to demystify how large language models are served efficiently and at scale.