
Deconstructing nano-vllm: Key Techniques for LLM Inference

Posted by Xuanyi on January 4, 2026

Back in 2016, I had my first major deep dive into deep learning when I was a visiting student at SUTD (Singapore University of Technology and Design) for summer research. I worked on co-training NLP tasks using CNNs and Gated Units, eventually publishing a paper in IEEE DSAA 2017. While that work was a precious learning experience for me, the field has evolved exponentially since then.

To bridge the gap between those formative years and the current state of the art, I spent time dissecting nano-vllm. It is a minimalist implementation of vLLM that strips away the complexity to reveal the core “Operating System” concepts required to serve modern LLMs.

Here is what I learned about the five pillars of modern LLM inference: Scheduling, Memory, Parallelism, Communication, and Compute Optimization.

Scheduling: Continuous Batching

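In traditional deep learning, batching was static: you waited for inputs, stacked them, and ran them together. In LLM generation this is inefficient because requests finish at different times. With a static batch, the slots freed by short requests sit idle until the longest request in the batch finishes. Continuous batching instead lets finished requests leave, and new ones join, at every decoding step.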

nano-vllm implements Continuous Batching. The Scheduler treats the generation process as an OS scheduler would, prioritizing the compute-heavy Prefill phase over the memory-bound Decode phase.

Code Insight: In the scheduler loop, we manage two queues: waiting (prefill) and running (decode). Note how we prioritize new requests but are ready to “preempt” (pause) running requests if memory is tight.

def schedule(self):
    # 1. Prioritize PREFILL (new requests)
    # Admit waiting requests for as long as there are free KV blocks for their prompts
    scheduled_seqs = []
    while self.waiting:
        seq = self.waiting[0]
        if not self.block_manager.can_allocate(seq):
            break
        self.block_manager.allocate(seq)
        self.waiting.popleft()
        self.running.append(seq)
        scheduled_seqs.append(seq)
    if scheduled_seqs:
        return scheduled_seqs, True  # Run this step as a Prefill batch

    # 2. Schedule DECODE (running requests)
    while self.running:
        seq = self.running.popleft()

        # 3. Preemption Logic
        # If there is no space for the next token's KV slot, evict requests
        # from the tail of the running queue until there is room.
        while not self.block_manager.can_append(seq):
            if self.running:
                victim = self.running.pop()  # Take from the tail
                self.free_memory(victim)     # Preempt: release the victim's KV blocks
            else:
                self.free_memory(seq)        # Nothing left to evict: preempt seq itself
                break
        else:
            self.block_manager.append_slot(seq)
            scheduled_seqs.append(seq)

    # Scheduled sequences stay in the running queue, in their original order
    self.running.extendleft(reversed(scheduled_seqs))
    return scheduled_seqs, False
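To see why continuous batching pays off, here is a toy model of my own (not nano-vllm code). It ignores prefill, memory limits, and preemption, and reduces each request to the number of decode steps it needs; static_batching_steps and continuous_batching_steps are hypothetical helpers that count how many batched decode steps the GPU runs.

```python
from collections import deque

def static_batching_steps(output_lens, max_batch):
    # Each static batch runs until its longest request finishes
    steps = 0
    for i in range(0, len(output_lens), max_batch):
        steps += max(output_lens[i:i + max_batch])
    return steps

def continuous_batching_steps(output_lens, max_batch):
    waiting = deque(output_lens)  # tokens still to generate, per request
    running = []
    steps = 0
    while waiting or running:
        # Admit new requests into any free slots before each step
        while waiting and len(running) < max_batch:
            running.append(waiting.popleft())
        # One decode step: every running request emits one token
        running = [left - 1 for left in running]
        steps += 1
        # Finished requests leave immediately, freeing their slots
        running = [left for left in running if left > 0]
    return steps
```

With output lengths [1, 8, 2, 2] and room for two requests, the static schedule needs 8 + 2 = 10 steps because the two short requests wait behind the long one, while the continuous schedule finishes in 8 by slotting them in as soon as the first request completes.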

Memory

Paged Attention

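The biggest bottleneck in LLM inference isn’t just compute; it’s memory. Decoding is limited by memory bandwidth, and the KV cache, which grows with every generated token, limits how many requests fit on the GPU at once. Instead of reserving a contiguous region per request sized for its maximum length (which wastes memory through fragmentation), nano-vllm uses PagedAttention. It allocates one large pool of GPU memory, divides it into fixed-size blocks, and maps each sequence’s “logical” tokens to “physical” blocks.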

  • Pre-allocation: Upon initialization, the Model Runner calculates how many KV blocks fit in the available GPU memory and allocates one large, empty KV tensor.
  • Slicing: This tensor is sliced per layer, and each attention layer receives its own K and V views into it.
  • Virtual Mapping: The scheduler assigns a “block table” to every incoming Sequence. Just as virtual memory in Linux maps virtual pages to physical memory pages, the block table maps a sequence’s logical blocks to physical KV blocks.
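The address translation behind a block table fits in a few lines. This is a minimal sketch, assuming a flat cache indexed by “slot” where slot = physical_block * block_size + offset; slot_for_token and slot_mapping are hypothetical helpers, and block_size = 4 is chosen only for readability.

```python
def slot_for_token(block_table, pos, block_size):
    """Map a token's logical position to a physical KV-cache slot."""
    logical_block, offset = divmod(pos, block_size)
    physical_block = block_table[logical_block]  # like a page-table lookup
    return physical_block * block_size + offset

def slot_mapping(block_table, num_tokens, block_size, start=0):
    # One physical slot per token in [start, num_tokens)
    return [slot_for_token(block_table, p, block_size) for p in range(start, num_tokens)]
```

For a sequence whose block table is [7, 2, 9], the first four tokens land in slots 28 to 31 of physical block 7, and the next tokens continue at slot 8 in block 2: contiguous in logical space, scattered in physical memory.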

Prefix Caching

Because we are using block-based memory, we can implement Prefix Caching.

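  • The token sequence is split into fixed-size blocks.
  • Once a block is full, it gets a hash computed from its tokens and the hash of the previous block, so a matching hash implies a matching prefix, not just a matching block.
  • If two requests share a prefix (e.g., a long system prompt), the second one does not compute or store new KV blocks for it; its block table simply points to the existing physical blocks, and a reference count tracks how many sequences share each block.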
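Here is a minimal sketch of that lookup. PrefixCache and block_hash are hypothetical names, hashlib.sha256 stands in for whatever fast hash a real engine would use, and freeing/eviction of blocks is left out.

```python
import hashlib

def block_hash(token_ids, prev_hash):
    # Chain in the previous block's hash so equal hashes imply equal prefixes
    h = hashlib.sha256(prev_hash.encode())
    h.update(str(list(token_ids)).encode())
    return h.hexdigest()

class PrefixCache:
    def __init__(self, block_size):
        self.block_size = block_size
        self.hash_to_block = {}  # content hash -> physical block id
        self.ref_count = {}      # physical block id -> number of sequences using it
        self.next_block = 0

    def allocate(self, token_ids):
        """Build a block table for token_ids, reusing cached full blocks."""
        table, prev, hits = [], "", 0
        for i in range(0, len(token_ids), self.block_size):
            chunk = token_ids[i:i + self.block_size]
            # Only full blocks are hashed; a partial tail block is never shared
            h = block_hash(chunk, prev) if len(chunk) == self.block_size else None
            if h is not None and h in self.hash_to_block:
                block = self.hash_to_block[h]  # cache hit: reuse the KV block
                hits += 1
            else:
                block = self.next_block        # cache miss: take a fresh block
                self.next_block += 1
                if h is not None:
                    self.hash_to_block[h] = block
            self.ref_count[block] = self.ref_count.get(block, 0) + 1
            table.append(block)
            if h is not None:
                prev = h
        return table, hits
```

Two requests that share an 8-token system prompt with block_size = 4 end up sharing their first two physical blocks, and only their diverging tails get new blocks.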

Code Insight: I learned that a Triton JIT kernel is used to write into the KV cache efficiently. The kernel handles the non-contiguous (scattered) writes by computing each token’s destination offset from a per-token slot mapping derived from the block table.

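import triton
import triton.language as tl

@triton.jit
def store_kvcache_kernel(
    key_ptr, key_stride,          # Inputs from the model: [num_tokens, num_kv_heads, head_dim]
    value_ptr, value_stride,
    k_cache_ptr, v_cache_ptr,     # The big preallocated KV cache
    slot_mapping_ptr,             # Per-token virtual -> physical slot mapping
    D: tl.constexpr,              # num_kv_heads * head_dim (elements per token)
):
    idx = tl.program_id(0)        # One program instance per token
    # Load the physical slot assigned to this token
    slot = tl.load(slot_mapping_ptr + idx)

    # Load this token's K and V rows (D elements each)
    key = tl.load(key_ptr + idx * key_stride + tl.arange(0, D))
    value = tl.load(value_ptr + idx * value_stride + tl.arange(0, D))

    # The cache is laid out as [num_blocks, block_size, num_kv_heads, head_dim],
    # so physical slot s starts at element s * D
    cache_offsets = slot * D + tl.arange(0, D)

    # Store data directly to the scattered physical location
    tl.store(k_cache_ptr + cache_offsets, key)
    tl.store(v_cache_ptr + cache_offsets, value)

def store_kvcache(key, value, k_cache, v_cache, slot_mapping):
    num_tokens, num_heads, head_dim = key.shape
    D = num_heads * head_dim      # tl.arange requires D to be a power of two
    store_kvcache_kernel[(num_tokens,)](
        key, key.stride(0), value, value.stride(0),
        k_cache, v_cache, slot_mapping, D)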

Communication: Shared Memory

In nano-vllm, orchestrating multiple GPUs (Tensor Parallelism) requires sending identical control commands—like “run the forward pass with these tokens”—to all worker processes simultaneously. Instead of using standard distributed RPC (which can introduce overhead), the engine implements a lightweight Master-Worker protocol using Python’s multiprocessing.shared_memory.

Rank 0 (the master) writes serialized commands into a shared buffer and sets a synchronization event for each worker. All other ranks (workers) block on their event, read the command, and execute it locally. This ensures that when the main engine steps forward, all GPUs move in lockstep.

Code Insight: The ModelRunner class implements read_shm and write_shm to handle this mechanism.


from multiprocessing.shared_memory import SharedMemory
import torch.distributed as dist

class ModelRunner:
    def __init__(self, config, rank, event):
        # ... initialization ...
        # On rank 0, `event` is a list of Events (one per worker);
        # on each worker, it is that worker's own Event.

        # 1. Setup Shared Memory
        if self.world_size > 1:
            if rank == 0:
                # Master creates the memory block
                self.shm = SharedMemory(name="nanovllm", create=True, size=2**20)
                dist.barrier()
            else:
                # Workers wait until it exists, attach to it, and enter the listen loop
                dist.barrier()
                self.shm = SharedMemory(name="nanovllm")
                self.loop()

    # Worker Loop: block until a command arrives, run it, repeat
    def loop(self):
        while True:
            method_name, args = self.read_shm()
            self.call(method_name, *args)
            if method_name == "exit":
                break

    # 2. Master writes commands
    def write_shm(self, method_name, *args):
        # Serialize
        ...
        # Wake up all workers
        for event in self.event:
            event.set()

    # 3. Workers read commands
    def read_shm(self):
        # Wait for signal from Master
        self.event.wait()
        # Deserialize
        ...
        self.event.clear()
        ...
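The serialization step is elided above. One simple way to fill it in, shown as a hedged sketch rather than nano-vllm’s exact code, is to pickle the command and prefix it with a 4-byte length so the reader knows how many bytes to unpickle. write_command and read_command are hypothetical stand-alone versions of write_shm/read_shm; only SharedMemory, Event, and pickle are real standard-library APIs.

```python
import pickle
from multiprocessing import Event
from multiprocessing.shared_memory import SharedMemory

HEADER = 4  # bytes used for the little-endian payload length

def write_command(shm, events, method_name, *args):
    """Master side: serialize a command into shared memory and wake every worker."""
    data = pickle.dumps([method_name, *args])
    n = len(data)
    if HEADER + n > shm.size:
        raise ValueError("command does not fit in the shared buffer")
    shm.buf[0:HEADER] = n.to_bytes(HEADER, "little")
    shm.buf[HEADER:HEADER + n] = data
    for event in events:
        event.set()

def read_command(shm, event):
    """Worker side: block until signaled, then deserialize the command."""
    event.wait()
    n = int.from_bytes(bytes(shm.buf[0:HEADER]), "little")
    method_name, *args = pickle.loads(bytes(shm.buf[HEADER:HEADER + n]))
    event.clear()  # re-arm for the next command
    return method_name, args
```

The sketch assumes, as the lockstep design implies, that the master never overwrites the buffer before every worker has read the previous command; it does no locking of its own.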

Parallelism: Tensor Parallelism (TP)

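When models are too large for a single GPU, we slice the weight tensors across GPUs. nano-vllm demonstrates the standard Megatron-LM style Tensor Parallelism:

  • Attention Layers: The QKV projection uses ColumnParallelLinear, so each GPU owns a subset of the attention heads and computes them independently. The output projection is a RowParallelLinear that combines the per-GPU results.
  • FFN (Feed Forward Network): Uses a sandwich pattern. First, a ColumnParallelLinear expands the hidden state, with each GPU holding a slice of the intermediate dimension, and the activation is applied locally to that slice. A RowParallelLinear then projects it back, so the whole block needs only one all_reduce.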

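Code Insight: The critical component of row parallelism is the all_reduce. In a Row Parallel layer, the input is split along the feature dimension, so every GPU computes only a partial sum of the output. An all_reduce over NCCL sums these partial results so that every GPU ends up with the final output.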

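import torch
import torch.distributed as dist
import torch.nn.functional as F
from torch import nn

class RowParallelLinear(nn.Module):
    def __init__(self, input_size: int, output_size: int):
        super().__init__()
        self.world_size = dist.get_world_size() if dist.is_initialized() else 1
        # This rank's slice of W: [output_size, input_size // world_size]
        self.weight = nn.Parameter(torch.empty(output_size, input_size // self.world_size))

    def forward(self, input_):
        # 1. Local Computation
        # input_ is this rank's slice of the features, so the result is a partial sum
        output_parallel = F.linear(input_, self.weight)

        # 2. Synchronization
        # all_reduce (NCCL on GPUs) sums the partial results from all ranks in place
        if self.world_size > 1:
            dist.all_reduce(output_parallel, op=dist.ReduceOp.SUM)

        return output_parallel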
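The FFN sandwich can be checked numerically without any GPUs. This sketch simulates world_size ranks inside one process with NumPy (ffn_reference and ffn_tensor_parallel are hypothetical helpers, and ReLU stands in for the model’s real activation): each rank applies its column slice of w1, the elementwise activation, and the matching row slice of w2, and summing the partial outputs, which is what all_reduce does, reproduces the single-device result.

```python
import numpy as np

def relu(x):
    return np.maximum(x, 0.0)

def ffn_reference(x, w1, w2):
    # Single-device FFN: expand with w1, elementwise activation, project back with w2
    return relu(x @ w1) @ w2

def ffn_tensor_parallel(x, w1, w2, world_size):
    hidden = w1.shape[1]
    assert hidden % world_size == 0
    shard = hidden // world_size
    partials = []
    for rank in range(world_size):
        cols = slice(rank * shard, (rank + 1) * shard)
        # ColumnParallelLinear: this rank owns some output columns of w1, so it
        # produces its own slice of the hidden activation with no communication
        h_local = relu(x @ w1[:, cols])
        # RowParallelLinear: the matching rows of w2 turn that slice into a
        # partial sum of the full output
        partials.append(h_local @ w2[cols, :])
    # all_reduce(SUM): every rank would end up holding this same total
    return np.sum(partials, axis=0)
```

The activation can be applied before any communication only because it is elementwise; that is why the column-then-row ordering needs a single all_reduce at the end.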

Compute

Modern GPUs are so fast that, during decoding, the CPU time spent launching many small kernels often becomes the bottleneck. nano-vllm’s compute optimizations come down to three techniques:

CUDA Graphs

To avoid the cost of launching individual kernels for every small operation, nano-vllm uses CUDA Graphs in the Decode phase: the sequence of kernels for one decode step is captured once and then launched as a single graph. CUDA Graphs require static tensor shapes, so graphs are captured for a fixed set of batch-size buckets, and each decode batch is padded up to the nearest bucket. Graphs are captured from the largest batch size down to the smallest: the largest graph allocates GPU memory first, and because all graphs share one memory pool (graph_pool), the smaller graphs can reuse that memory.

# Capture
@torch.inference_mode()
def capture_cudagraph(self):
    max_bs = min(self.config.max_num_seqs, 512)
    ...
    
    # Preallocate static input buffers: a replayed graph always reads the same addresses
    input_ids = torch.zeros(max_bs, dtype=torch.int64, device="cuda")
    self.graph_bs = [1, 2, 4, 8] + list(range(16, max_bs + 1, 16))
    
    for bs in reversed(self.graph_bs):
        graph = torch.cuda.CUDAGraph()
        # warmup
        outputs[:bs] = self.model(input_ids[:bs], ...)
        # capture
        with torch.cuda.graph(graph, self.graph_pool):
            outputs[:bs] = self.model(input_ids[:bs], ...)
        if self.graph_pool is None:
            self.graph_pool = graph.pool()
        self.graphs[bs] = graph
        torch.cuda.synchronize()
    ...

# Replay
@torch.inference_mode()
def run_model(self, input_ids: torch.Tensor, positions: torch.Tensor, is_prefill: bool):
    # Prefill, eager mode, and batches larger than the biggest graph run eagerly
    if is_prefill or self.enforce_eager or input_ids.size(0) > 512:
        return self.model.compute_logits(self.model(input_ids, positions))
    else:
        bs = input_ids.size(0)
        # Pick the smallest captured batch size that fits; the batch is padded up to it
        graph = self.graphs[next(x for x in self.graph_bs if x >= bs)]
        ...
        graph.replay()
        ...
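The bucketing logic can be pulled out and tested on its own. make_graph_buckets and pick_bucket are hypothetical helpers that mirror the bucket list in capture_cudagraph and the next(...) lookup in run_model; pad_batch and pad_id are my own illustration of the padding idea (the real code copies inputs into the preallocated graph buffers instead of building new lists).

```python
def make_graph_buckets(max_num_seqs: int, cap: int = 512) -> list[int]:
    # Same scheme as capture_cudagraph: 1, 2, 4, 8, then every multiple of 16
    max_bs = min(max_num_seqs, cap)
    return [1, 2, 4, 8] + list(range(16, max_bs + 1, 16))

def pick_bucket(graph_bs: list[int], bs: int) -> int:
    # Smallest captured batch size that can hold `bs` sequences.
    # Raises StopIteration if bs exceeds every bucket; run_model runs eagerly there.
    return next(x for x in graph_bs if x >= bs)

def pad_batch(token_ids: list[int], bucket: int, pad_id: int = 0) -> list[int]:
    # Pad the decode batch up to the bucket size so it matches the graph's static shape
    return token_ids + [pad_id] * (bucket - len(token_ids))
```

A batch of 17 sequences replays the 32-wide graph, so 15 rows are padding; the extra small buckets (1, 2, 4, 8) keep that waste low for tiny batches.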

torch.compile

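nano-vllm uses torch.compile to generate fused kernels automatically. This matters most for memory-bound layers like Softmax or RMSNorm, which consist of several cheap elementwise and reduction steps. Without fusion, each step writes its intermediate result back to HBM (High Bandwidth Memory) and the next step reads it again; with fusion, the intermediates stay in the faster on-chip SRAM/registers.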
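To make the fusion argument concrete, here is a plain PyTorch RMSNorm, written as a sketch rather than nano-vllm’s exact layer. In eager mode on a GPU, roughly every op below (square, mean, rsqrt, multiply, scale) runs as its own kernel with its output stored in HBM, which is exactly the traffic a fused kernel avoids.

```python
import torch

def rms_norm(x: torch.Tensor, weight: torch.Tensor, eps: float = 1e-6) -> torch.Tensor:
    # Compute in float32 for stability, then cast back to the input dtype
    orig_dtype = x.dtype
    x = x.float()
    var = x.pow(2).mean(dim=-1, keepdim=True)  # mean of squares over the hidden dim
    x = x * torch.rsqrt(var + eps)             # normalize by the root mean square
    return x.to(orig_dtype) * weight           # learned per-channel scale
```

Wrapping it as `torch.compile(rms_norm)` (or decorating it with `@torch.compile`) leaves the math unchanged; compilation happens lazily on the first call, and the compiler is then free to emit one fused kernel for the whole chain.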

Warm Up

Before serving real traffic, the system performs a “Warm Up” routine: it runs one forward pass with a “worst-case” input so that PyTorch completes its memory allocations and compilation up front. The purpose is to measure a stable peak GPU memory usage, which is then used to calculate how many kv_cache blocks fit in the remaining memory.

class ModelRunner:
    def warmup_model(self):
        torch.cuda.empty_cache()
        ...
        # Worst-case dummy batch; no KV cache has been allocated yet
        seqs = [Sequence(...)]
        self.run(seqs, ...)
        torch.cuda.empty_cache()

    def prepare_prefill(self, seqs: list[Sequence]):
        ...
        for seq in seqs:
            ...
            if not seq.block_table:    # warmup: no KV blocks yet, so no slot mapping
                continue
            ...

class Attention:
    def forward(self, q: torch.Tensor, k: torch.Tensor, v: torch.Tensor):
        context = get_context()
        k_cache, v_cache = self.k_cache, self.v_cache
        if k_cache.numel() and v_cache.numel():  # cache is still empty during warmup: skip the write
            store_kvcache(k, v, k_cache, v_cache, context.slot_mapping)
        ...
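Once warmup has revealed the peak, sizing the cache is plain arithmetic. This is a hedged sketch of that calculation: the helper names are hypothetical, and the inputs would come from calls such as torch.cuda.mem_get_info() (free and total bytes) and torch.cuda.memory_stats() (for example its "allocated_bytes.all.peak" and "allocated_bytes.all.current" entries). Treat the exact combination of terms as an assumption.

```python
def kv_block_bytes(num_layers, block_size, num_kv_heads, head_dim, dtype_bytes):
    # K and V (factor of 2) for every layer and every token slot in one block
    return 2 * num_layers * block_size * num_kv_heads * head_dim * dtype_bytes

def num_kv_blocks(total, used, peak, current, gpu_memory_utilization, block_bytes):
    # used           : memory already in use on the device (weights, CUDA context, ...)
    # peak - current : extra activation memory the warmup pass needed at its peak
    budget = total * gpu_memory_utilization - used - (peak - current)
    return max(int(budget) // block_bytes, 0)
```

For a hypothetical model with 28 layers, 8 KV heads, head_dim 128, bf16 (2 bytes), and 256-token blocks, one block costs 2 × 28 × 256 × 8 × 128 × 2 = 29,360,128 bytes (28 MiB), so every gigabyte left after warmup buys roughly 36 blocks, about 9,000 tokens of context.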

Theoretical Refreshers

To fully grasp the code, I also refreshed my knowledge of the mathematical underpinnings that have evolved since 2017:

Lecture 5 of Stanford’s CS336 provided an excellent refresher on the Roofline model and GPU optimizations, bringing back memories of the high-quality parallel computing course I took years ago with Professor P. Sadayappan at OSU, a prominent figure in the HPC community.
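The Roofline model also explains the scheduler’s Prefill/Decode split. A quick sketch: the arithmetic intensity (FLOPs per byte moved) of a [tokens, d] × [d, d] matmul in bf16, compared against the ridge point peak_flops / bandwidth. The GPU figures used below (1e15 FLOP/s, 2e12 B/s) are round hypothetical numbers, not a specific card.

```python
def matmul_intensity(tokens: int, d: int, dtype_bytes: int = 2) -> float:
    """FLOPs per byte for [tokens, d] @ [d, d], counting weight, input, and output traffic."""
    flops = 2 * tokens * d * d
    bytes_moved = dtype_bytes * (d * d + 2 * tokens * d)
    return flops / bytes_moved

def is_memory_bound(intensity: float, peak_flops: float, bandwidth: float) -> bool:
    # Below the roofline's ridge point, memory bandwidth limits throughput
    return intensity < peak_flops / bandwidth
```

With d = 4096, decoding one token gives an intensity of about 1 FLOP/byte, far below a ridge of 500, so the weights are streamed from HBM for almost no math. A 2048-token prefill reaches 1024 FLOP/byte and becomes compute-bound. Batching many decode requests together is what pushes decode back toward the ridge.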

What’s Next?

The infrastructure logic is clear. My next step is to get hands-on with the model artifacts:

  • Understand sampling strategies (Temperature, Top-P).
  • Run the Qwen model using this infra.
  • Perform profiling to visualize the difference between eager mode and CUDA graph execution.