vllm.v1.sample.ops.topk_topp_triton
Combined Top-K and Top-P Triton kernels.
Based on the paper "Qrita: High-performance Top-k and Top-p Algorithm for GPUs using Pivot-based Truncation and Selection" by Park et al. (https://arxiv.org/abs/2602.01518).
_update_min_larger_stats
Update running (min, count) of values above a pivot across tiles.
Tracks the smallest value strictly above a pivot and how many times it occurs. Called once per tile per pivot; the running state is carried across tiles via min_larger / num_min_larger.
Merge rule:
- tile min < running min → replace both
- tile min == running min → accumulate count
- tile min > running min → keep running values
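The merge rule can be illustrated with a plain-Python CPU sketch. This is not the Triton kernel: the function names update_min_larger_stats and min_larger_over_tiles are hypothetical, each tile is assumed to be a plain list of floats whose minimum and count are computed directly, and the running state is assumed to start at (+inf, 0), meaning "nothing above the pivot seen yet".

```python
import math
from typing import List, Tuple


def update_min_larger_stats(
    tile: List[float],
    pivot: float,
    min_larger: float,
    num_min_larger: int,
) -> Tuple[float, int]:
    """Fold one tile into the running (min, count) of values strictly above pivot."""
    above = [x for x in tile if x > pivot]
    if not above:
        # Nothing above the pivot in this tile: running state is unchanged.
        return min_larger, num_min_larger
    tile_min = min(above)
    tile_count = sum(1 for x in above if x == tile_min)
    if tile_min < min_larger:
        # tile min < running min: replace both min and count.
        return tile_min, tile_count
    if tile_min == min_larger:
        # tile min == running min: accumulate the count.
        return min_larger, num_min_larger + tile_count
    # tile min > running min: keep the running values.
    return min_larger, num_min_larger


def min_larger_over_tiles(
    values: List[float], pivot: float, tile_size: int
) -> Tuple[float, int]:
    # Assumed initial state: +inf with count 0 ("nothing seen yet").
    min_larger, num_min_larger = math.inf, 0
    for start in range(0, len(values), tile_size):
        tile = values[start:start + tile_size]
        min_larger, num_min_larger = update_min_larger_stats(
            tile, pivot, min_larger, num_min_larger
        )
    return min_larger, num_min_larger


print(min_larger_over_tiles([0.1, 0.7, 0.5, 0.5, 0.9, 0.5, 0.2], 0.3, 3))  # (0.5, 3)
```

With a pivot of 0.3 and tiles of size 3, the first tile contributes (0.5, 1), the second tile matches the running minimum and adds two more occurrences, and the last tile has nothing above the pivot, giving (0.5, 3). Values equal to the pivot are excluded because the comparison is strict.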
Source code in vllm/v1/sample/ops/topk_topp_triton.py
apply_top_k_top_p_triton
apply_top_k_top_p_triton(
logits: Tensor,
k: Tensor | None,
p: Tensor | None,
mask_value: float = float("-inf"),
) -> Tensor
Apply combined top-k and top-p masking using Triton.
Top-k is applied first (by logit value), then top-p is applied to the remaining k values (by probability).
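The ordering can be pinned down with a single-row, plain-Python reference. It is a sketch under stated assumptions, not the Triton implementation: the function name top_k_top_p_reference is hypothetical, k is assumed to be a positive int, logits are assumed finite, ties are broken by index, and the top-p step is assumed to keep the smallest prefix of the survivors (sorted by logit) whose probability mass, renormalized over those survivors, reaches p. The kernel's exact tie and boundary behavior may differ.

```python
import math
from typing import List, Optional


def top_k_top_p_reference(
    logits: List[float],
    k: Optional[int],
    p: Optional[float],
    mask_value: float = float("-inf"),
) -> List[float]:
    """Mask one row: keep the top-k logits, then the top-p nucleus of those."""
    # Indices sorted by logit value, largest first (ties broken by index).
    order = sorted(range(len(logits)), key=lambda i: (-logits[i], i))
    # Top-k by logit value; k=None disables this step.
    keep = order if k is None else order[:k]
    if p is not None:
        # Softmax over the surviving logits only (max subtracted for stability).
        m = logits[keep[0]]
        weights = [math.exp(logits[i] - m) for i in keep]
        total = sum(weights)
        cum = 0.0
        nucleus = []
        for i, w in zip(keep, weights):
            nucleus.append(i)
            cum += w / total
            # Stop once the kept probability mass reaches p.
            if cum >= p:
                break
        keep = nucleus
    kept = set(keep)
    return [x if i in kept else mask_value for i, x in enumerate(logits)]


row = [1.0, 3.0, 2.0, 0.0]
print(top_k_top_p_reference(row, None, 0.7))  # [-inf, 3.0, 2.0, -inf]
print(top_k_top_p_reference(row, 2, 0.7))     # [-inf, 3.0, -inf, -inf]
```

The two calls show why the order matters under this sketch's assumptions. Without top-k, the largest logit holds about 64% of the mass, so the second-largest is also needed to reach 0.7. With k=2, probabilities are renormalized over the two survivors, the largest logit holds about 73%, and it is kept alone.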
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| logits | Tensor | [batch_size, vocab_size] float32 tensor. The returned tensor may alias this input or be a new contiguous tensor for unsupported layouts. | required |
| k | Tensor \| None | [batch_size] int32 tensor of top-k values per row, or None to disable top-k. | required |
| p | Tensor \| None | [batch_size] float32 tensor of top-p values per row (0 to 1), or None to disable top-p. | required |
| mask_value | float | Value written to masked positions. | float('-inf') |
Returns:
| Type | Description |
|---|---|
| Tensor | The masked logits tensor. The input logits may be modified in place, so callers should not rely on their original values after the call. |