Skip to content

vllm.v1.worker.gpu.spec_decode.eagle.cudagraph

DecodeEagleCudaGraphManager

Bases: CudaGraphManager

Eagle CudaGraphManager for decode draft generation, building its own attention metadata from scratch.

Source code in vllm/v1/worker/gpu/spec_decode/eagle/cudagraph.py
class DecodeEagleCudaGraphManager(CudaGraphManager):
    """CUDA graph manager for Eagle draft generation during decode.

    For every batch descriptor it captures, this manager calls
    ``prepare_inputs_to_capture`` to obtain the attention metadata and slot
    mappings itself rather than receiving them pre-built.
    """

    def capture(
        self,
        forward_fn: Callable,
        model_state: ModelState,
        input_buffers: InputBuffers,
        block_tables: BlockTables,
        attn_groups: list[list[AttentionGroup]],
        kv_cache_config: KVCacheConfig,
        progress_bar_desc: str = "Capturing CUDA graphs",
    ) -> None:
        """Capture CUDA graphs for ``forward_fn``.

        A per-descriptor factory is handed to the base class ``capture``
        together with ``progress_bar_desc``; the base class decides which
        descriptors to capture and when the factory is invoked.

        Args:
            forward_fn: Called positionally as ``forward_fn(num_reqs,
                num_tokens, attn_metadata, slot_mappings,
                num_tokens_across_dp, cg_mode)``.
            model_state: Forwarded to ``prepare_inputs_to_capture``.
            input_buffers: Forwarded to ``prepare_inputs_to_capture``.
            block_tables: Forwarded to ``prepare_inputs_to_capture``.
            attn_groups: Forwarded to ``prepare_inputs_to_capture``.
            kv_cache_config: Forwarded to ``prepare_inputs_to_capture``.
            progress_bar_desc: Forwarded unchanged to the base class.
        """

        def create_forward_fn(
            desc: BatchExecutionDescriptor,
        ) -> tuple[Callable[[CUDAGraphMode], None], CapturedAttentionState]:
            # Returns the bound forward callable plus the attention state
            # that was prepared for this descriptor.
            token_count = desc.num_tokens

            # A falsy ``desc.num_reqs`` means "not specified": fall back to
            # the token count, capped at the manager's request limit.
            req_count = desc.num_reqs
            if not req_count:
                req_count = min(token_count, self.max_num_reqs)

            # Only built when there is more than one DP rank; every entry is
            # set to this descriptor's token count.
            if self.dp_size > 1:
                tokens_per_dp_rank = torch.full(
                    (self.dp_size,), token_count, dtype=torch.int32, device="cpu"
                )
            else:
                tokens_per_dp_rank = None

            # Piecewise graphs ask the helper to skip attention preparation.
            piecewise = desc.cg_mode == CUDAGraphMode.PIECEWISE
            captured_state = prepare_inputs_to_capture(
                req_count,
                token_count,
                model_state,
                input_buffers,
                block_tables,
                attn_groups,
                kv_cache_config,
                skip_attn=piecewise,
            )
            metadata, slots = captured_state

            def run_forward(cg_mode):
                return forward_fn(
                    req_count,
                    token_count,
                    metadata,
                    slots,
                    tokens_per_dp_rank,
                    cg_mode,
                )

            return run_forward, captured_state

        super().capture(create_forward_fn, progress_bar_desc)

PrefillEagleCudaGraphManager

Bases: CudaGraphManager

Eagle CudaGraphManager for prefill, using pre-built attention states from the target model's capture.

Source code in vllm/v1/worker/gpu/spec_decode/eagle/cudagraph.py
class PrefillEagleCudaGraphManager(CudaGraphManager):
    """CUDA graph manager for the Eagle prefill pass.

    Attention state is not built here; it is looked up per batch descriptor
    in a mapping supplied by the caller (per the class summary, produced by
    the target model's capture).
    """

    def capture(
        self,
        forward_fn: Callable,
        full_cg_attn_states: dict[BatchExecutionDescriptor, CapturedAttentionState],
        progress_bar_desc: str = "Capturing CUDA graphs",
    ) -> None:
        """Capture CUDA graphs for ``forward_fn`` using pre-built state.

        A per-descriptor factory is handed to the base class ``capture``
        together with ``progress_bar_desc``.

        Args:
            forward_fn: Called positionally as ``forward_fn(num_reqs,
                num_tokens, attn_metadata, slot_mappings,
                num_tokens_across_dp, cg_mode)``.
            full_cg_attn_states: Attention state keyed by batch descriptor;
                indexed directly, so each descriptor the base class passes
                to the factory is expected to be present.
            progress_bar_desc: Forwarded unchanged to the base class.
        """

        def create_forward_fn(
            desc: BatchExecutionDescriptor,
        ) -> tuple[Callable[[CUDAGraphMode], None], CapturedAttentionState]:
            # Returns the bound forward callable plus the looked-up state.
            token_count = desc.num_tokens

            # A falsy ``desc.num_reqs`` means "not specified": fall back to
            # the token count, capped at the manager's request limit.
            req_count = desc.num_reqs
            if not req_count:
                req_count = min(token_count, self.max_num_reqs)

            # Only built when there is more than one DP rank; every entry is
            # set to this descriptor's token count.
            if self.dp_size > 1:
                tokens_per_dp_rank = torch.full(
                    (self.dp_size,), token_count, dtype=torch.int32, device="cpu"
                )
            else:
                tokens_per_dp_rank = None

            prebuilt_state = full_cg_attn_states[desc]
            metadata, slots = prebuilt_state

            def run_forward(cg_mode):
                return forward_fn(
                    req_count,
                    token_count,
                    metadata,
                    slots,
                    tokens_per_dp_rank,
                    cg_mode,
                )

            return run_forward, prebuilt_state

        super().capture(create_forward_fn, progress_bar_desc)