class PPHandler:
    """Overlap the pipeline-parallel sampled-token exchange with compute.

    The broadcast (last rank) and the matching receive (other ranks) are
    issued on a dedicated side stream, so the default stream does not have
    to wait for the peer's call. Whatever a rank receives at step T is
    handed back at step T+pp_size through `get_prev_sampled_outputs`.

    The exchange runs on its own NCCL communicator, a sibling of the PP
    `device_group`, to keep it from serializing on the wire with the
    inter-stage hidden-state p2p send/recv traffic.
    """

    def __init__(
        self, max_num_reqs: int, num_speculative_steps: int, device: torch.device
    ):
        pp_group = get_pp_group()
        self.is_last_rank = pp_group.is_last_rank
        self.last_rank = pp_group.last_rank
        # One sampled token per speculative step, plus one more.
        self.max_sample_len = num_speculative_steps + 1
        self.device = device
        self.main_stream = torch.cuda.current_stream(device)
        self.broadcast_stream = torch.cuda.Stream(device)

        # FIFO of in-flight receives. The last rank never receives, so its
        # queue stays empty. Every other rank starts with pp_size `None`
        # placeholders, which turns the first pp_size consumes into no-ops.
        # A `None` entry means "nothing to post-process for that step"
        # (the broadcast was skipped).
        if self.is_last_rank:
            self.queue: deque[PendingRecv | None] = deque()
        else:
            self.queue = deque([None] * pp_group.world_size)

        # Generation counter per request index, bumped by `on_req_idx_freed`.
        # Comparing against a snapshot taken in `receive` reveals request
        # indices that were freed while their sampled output was in flight.
        self.req_idx_gen_np = np.zeros(max_num_reqs, dtype=np.int32)

        # Separate subgroup used only for the sampled-token broadcast.
        self.broadcast_group = pp_group.make_sibling_device_group(
            group_desc="pp_broadcast"
        )

    def on_req_idx_freed(self, req_idx: int) -> None:
        """Advance the generation counter of a freed request index."""
        generations = self.req_idx_gen_np
        generations[req_idx] += 1

    def get_prev_sampled_outputs(self) -> dict[str, torch.Tensor] | None:
        """Pop the oldest queued receive and return its tensors, if any.

        Also reserves a fresh (empty) slot at the back of the queue for the
        current step. Returns `None` when the queue is empty, when the
        popped slot is a placeholder, or when every request in it must be
        skipped. Otherwise the main stream is made to wait on the receive
        event and a dict with `sampled_tokens`, `num_sampled`,
        `num_rejected` and `idx_mapping` is returned; request indices that
        must be skipped are replaced by -1 in `idx_mapping`.
        """
        pending = self.queue
        if len(pending) == 0:
            return None

        entry = pending.popleft()
        # Placeholder for this step; `receive` replaces it when it has data.
        pending.append(None)
        if entry is None:
            return None

        # A request is skipped when its index was freed after `receive`
        # (generation moved on) or when it never needed the sampled output.
        req_indices_np = entry.idx_mapping_np
        stale = self.req_idx_gen_np[req_indices_np] != entry.gen_at_receive_np
        skip = stale | ~entry.need_sampled_mask

        gpu_indices = entry.idx_mapping
        if skip.any():
            if skip.all():
                # Nothing left to update for this step.
                return None
            # Mark skipped rows with -1; the post_update kernel ignores them.
            masked_np = np.where(skip, -1, req_indices_np)
            gpu_indices = async_copy_to_gpu(masked_np, device=self.device)

        # The tensors were filled on the side stream; order the main stream
        # after that work before handing them out.
        self.main_stream.wait_event(entry.event)
        return {
            "sampled_tokens": entry.sampled_tokens,
            "num_sampled": entry.num_sampled,
            "num_rejected": entry.num_rejected,
            "idx_mapping": gpu_indices,
        }

    def receive(self, input_batch: InputBatch) -> bool:
        """Post this step's receive of the sampled outputs on the side stream.

        Must not be called on the last rank. When no request needs sampled
        outputs, nothing is posted, the reserved queue slot stays `None`,
        and False is returned. Otherwise the pending receive overwrites the
        newest queue slot and the result is True iff sampled tokens need to
        be gathered from *all* requests in the batch.
        """
        assert not self.is_last_rank

        mask = compute_need_sampled_mask(input_batch)
        if mask is None:
            return False

        # Copy of the current generations for the batch's request indices,
        # so a free that happens before consumption can be detected.
        req_indices_np = input_batch.idx_mapping_np
        gen_snapshot_np = self.req_idx_gen_np[req_indices_np]
        batch_size = input_batch.num_reqs

        side_stream = self.broadcast_stream
        with torch.cuda.stream(side_stream):
            side_stream.wait_stream(self.main_stream)
            tokens = torch.empty(
                batch_size,
                self.max_sample_len,
                dtype=torch.int64,
                device=self.device,
            )
            # Row 0 receives num_sampled, row 1 receives num_rejected.
            counts = torch.empty(
                2, batch_size, dtype=torch.int32, device=self.device
            )
            # Same order as the sender: tokens first, then the counts.
            for buffer in (tokens, counts):
                torch.distributed.broadcast(
                    buffer, src=self.last_rank, group=self.broadcast_group
                )
            done_event = side_stream.record_event()

        num_sampled, num_rejected = torch.unbind(counts, dim=0)

        # Allocated on the side stream but read later on the main stream,
        # so the allocator has to be told about the second stream.
        for buffer in (tokens, counts):
            buffer.record_stream(self.main_stream)

        self.queue[-1] = PendingRecv(
            done_event,
            tokens,
            num_sampled,
            num_rejected,
            input_batch.idx_mapping,
            req_indices_np,
            mask,
            gen_snapshot_np,
        )
        return bool(mask.all())

    def broadcast(
        self,
        sampled_token_ids: torch.Tensor,
        num_sampled: torch.Tensor,
        num_rejected: torch.Tensor,
        input_batch: InputBatch,
    ) -> None:
        """Broadcast the sampled outputs from the last rank on the side stream.

        Must only be called on the last rank. Returns without sending when
        no request needs sampled outputs for a subsequent decode step.
        `sampled_token_ids` must be int64.
        """
        assert self.is_last_rank

        if compute_need_sampled_mask(input_batch) is None:
            return

        assert sampled_token_ids.dtype == torch.int64

        side_stream = self.broadcast_stream
        with torch.cuda.stream(side_stream):
            side_stream.wait_stream(self.main_stream)
            token_payload = sampled_token_ids.contiguous()
            torch.distributed.broadcast(
                token_payload, src=self.last_rank, group=self.broadcast_group
            )
            # Pack both per-request counters into one tensor so they go out
            # in a single broadcast.
            counts = torch.stack([num_sampled, num_rejected], dim=0)
            torch.distributed.broadcast(
                counts, src=self.last_rank, group=self.broadcast_group
            )

        # The inputs are being read on the side stream; register that use.
        sampled_token_ids.record_stream(side_stream)
        num_sampled.record_stream(side_stream)
        num_rejected.record_stream(side_stream)