Skip to content

vllm.model_executor.model_loader

Modules:

Name Description
base_loader
default_loader
dummy_loader
ep_weight_filter

Filter out non-local expert weights during loading to avoid redundant I/O.

gguf_loader
modelexpress_loader
reload

Layerwise weight reloading utilities for vLLM.

runai_streamer_loader
sharded_state_loader
tensorizer
tensorizer_loader
utils

Utilities for selecting and loading models.

weight_utils

Utilities for downloading and initializing model weights.

BaseModelLoader

Bases: ABC

Base class for model loaders.

Source code in vllm/model_executor/model_loader/base_loader.py
class BaseModelLoader(ABC):
    """Abstract base for all model loaders.

    Subclasses supply ``download_model`` and ``load_weights``; ``load_model``
    drives the construct / load / post-process sequence on top of them.
    """

    def __init__(self, load_config: LoadConfig):
        self.load_config = load_config

    @abstractmethod
    def download_model(self, model_config: ModelConfig) -> None:
        """Fetch the model ahead of time so that it can be loaded at once."""
        raise NotImplementedError()

    @abstractmethod
    def load_weights(self, model: nn.Module, model_config: ModelConfig) -> None:
        """Fill the weights of an existing model in place.

        Kept as a standalone entry point so that weights can be loaded into
        a model that has already been initialized.
        """
        raise NotImplementedError()

    @instrument(span_name="Load model")
    def load_model(
        self, vllm_config: VllmConfig, model_config: ModelConfig, prefix: str = ""
    ) -> nn.Module:
        """Construct the model, load its weights and return it in eval mode.

        Args:
            vllm_config: Supplies the device and load configuration.
            model_config: Supplies the dtype and is forwarded to every helper.
            prefix: Forwarded unchanged to ``initialize_model``.

        Returns:
            The value of ``model.eval()``.
        """
        device_cfg = vllm_config.device_config
        load_cfg = vllm_config.load_config
        # A device set on the load config takes precedence over the one on
        # the device config.
        if load_cfg.device is None:
            load_device = device_cfg.device
        else:
            load_device = load_cfg.device
        target_device = torch.device(load_device)

        with set_default_torch_dtype(model_config.dtype):
            # Only model construction runs inside the device context.
            with target_device:
                model = initialize_model(
                    vllm_config=vllm_config,
                    model_config=model_config,
                    prefix=prefix,
                )

            log_model_inspection(model)

            logger.debug("Loading weights on %s ...", load_device)
            self.load_weights(model, model_config)

            # Report the peak GPU memory once weights are in; this gives
            # test coverage on peak memory for online quantization.
            if current_platform.is_cuda_alike():
                logger.debug_once(
                    "Peak GPU memory after loading weights: %s GiB",
                    format_gib(torch.accelerator.max_memory_allocated()),
                )

            # Convert weights into kernel format. With online quantization
            # the weights are (typically) quantized while being loaded, so
            # the layerwise processing has to be finalized first.
            if _has_online_quant(model):
                finalize_layerwise_processing(model, model_config)
            process_weights_after_loading(model, model_config, target_device)

        return model.eval()

download_model abstractmethod

download_model(model_config: ModelConfig) -> None

Download a model so that it can be immediately loaded.

Source code in vllm/model_executor/model_loader/base_loader.py
@abstractmethod
def download_model(self, model_config: ModelConfig) -> None:
    """Fetch the model ahead of time so that it can be loaded at once.

    Abstract: the base implementation always raises ``NotImplementedError``.
    """
    raise NotImplementedError()

load_model

load_model(
    vllm_config: VllmConfig,
    model_config: ModelConfig,
    prefix: str = "",
) -> Module

Load a model with the given configurations.

Source code in vllm/model_executor/model_loader/base_loader.py
@instrument(span_name="Load model")
def load_model(
    self, vllm_config: VllmConfig, model_config: ModelConfig, prefix: str = ""
) -> nn.Module:
    """Construct the model, load its weights and return it in eval mode.

    Args:
        vllm_config: Supplies the device and load configuration.
        model_config: Supplies the dtype and is forwarded to every helper.
        prefix: Forwarded unchanged to ``initialize_model``.

    Returns:
        The value of ``model.eval()``.
    """
    device_cfg = vllm_config.device_config
    load_cfg = vllm_config.load_config
    # A device set on the load config takes precedence over the one on
    # the device config.
    if load_cfg.device is None:
        load_device = device_cfg.device
    else:
        load_device = load_cfg.device
    target_device = torch.device(load_device)

    with set_default_torch_dtype(model_config.dtype):
        # Only model construction runs inside the device context.
        with target_device:
            model = initialize_model(
                vllm_config=vllm_config,
                model_config=model_config,
                prefix=prefix,
            )

        log_model_inspection(model)

        logger.debug("Loading weights on %s ...", load_device)
        self.load_weights(model, model_config)

        # Report the peak GPU memory once weights are in; this gives
        # test coverage on peak memory for online quantization.
        if current_platform.is_cuda_alike():
            logger.debug_once(
                "Peak GPU memory after loading weights: %s GiB",
                format_gib(torch.accelerator.max_memory_allocated()),
            )

        # Convert weights into kernel format. With online quantization
        # the weights are (typically) quantized while being loaded, so
        # the layerwise processing has to be finalized first.
        if _has_online_quant(model):
            finalize_layerwise_processing(model, model_config)
        process_weights_after_loading(model, model_config, target_device)

    return model.eval()

load_weights abstractmethod

load_weights(
    model: Module, model_config: ModelConfig
) -> None

Load weights into a model. This standalone API allows in-place weight loading for an already-initialized model.

Source code in vllm/model_executor/model_loader/base_loader.py
@abstractmethod
def load_weights(self, model: nn.Module, model_config: ModelConfig) -> None:
    """Fill the weights of an existing model in place.

    Kept as a standalone entry point so that weights can be loaded into a
    model that has already been initialized. Abstract: the base
    implementation always raises ``NotImplementedError``.
    """
    raise NotImplementedError()

DefaultModelLoader

Bases: BaseModelLoader

Model loader that can load different file types from disk.

Source code in vllm/model_executor/model_loader/default_loader.py
 43
 44
 45
 46
 47
 48
 49
 50
 51
 52
 53
 54
 55
 56
 57
 58
 59
 60
 61
 62
 63
 64
 65
 66
 67
 68
 69
 70
 71
 72
 73
 74
 75
 76
 77
 78
 79
 80
 81
 82
 83
 84
 85
 86
 87
 88
 89
 90
 91
 92
 93
 94
 95
 96
 97
 98
 99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
class DefaultModelLoader(BaseModelLoader):
    """Model loader that can load different file types from disk."""

    # default number of thread when enable multithread weight loading
    DEFAULT_NUM_THREADS = 8

    @dataclasses.dataclass
    class Source:
        """A source for weights.

        Each instance describes one checkpoint location. The loader hands
        ``model_or_path``, ``subfolder``, ``revision``, ``fall_back_to_pt``
        and ``allow_patterns_overrides`` to ``_prepare_weights`` and prepends
        ``prefix`` to the name of every tensor it yields for this source.
        """

        model_or_path: str
        """The model ID or path."""

        revision: str | None
        """The optional model revision."""

        subfolder: str | None = None
        """The subfolder inside the model repo."""

        prefix: str = ""
        """A prefix to prepend to all weights."""

        fall_back_to_pt: bool = True
        """Whether .pt weights can be used."""

        allow_patterns_overrides: list[str] | None = None
        """If defined, weights will load exclusively using these patterns."""

    counter_before_loading_weights: float = 0.0
    counter_after_loading_weights: float = 0.0

    def __init__(self, load_config: LoadConfig):
        """Store the load config and validate loader-specific extra options.

        Raises:
            ValueError: If ``model_loader_extra_config`` holds a key other
                than ``enable_multithread_load``, ``num_threads`` or
                ``enable_weights_track``.
        """
        super().__init__(load_config)
        # Stays None unless `_init_ep_weight_filter` computes a set.
        self.local_expert_ids: set[int] | None = None

        options = load_config.model_loader_extra_config
        supported = {"enable_multithread_load", "num_threads", "enable_weights_track"}
        unknown = set(options.keys()) - supported
        if unknown:
            raise ValueError(
                f"Unexpected extra config keys for load format "
                f"{load_config.load_format}: "
                f"{unknown}"
            )

        # None means "not specified"; `load_weights` then picks the default.
        tracking = options.get("enable_weights_track", None)
        self.enable_weights_track: bool | None = tracking

    def _prepare_weights(
        self,
        model_name_or_path: str,
        subfolder: str | None,
        revision: str | None,
        fall_back_to_pt: bool,
        allow_patterns_overrides: list[str] | None,
    ) -> tuple[str, list[str], bool]:
        """Locate the checkpoint files to load, downloading them if needed.

        A model that is not an existing local directory is downloaded.

        Returns:
            A tuple ``(folder, weight_files, use_safetensors)``: the folder
            the files live in, the weight file paths that remain after
            filtering, and whether they are safetensors files.

        Raises:
            ValueError: If the load format is not recognised.
            RuntimeError: If no weight file is left after filtering.
        """
        modelscope_path = maybe_download_from_modelscope(model_name_or_path, revision)
        if modelscope_path:
            model_name_or_path = modelscope_path

        is_local = os.path.isdir(model_name_or_path)
        load_format = self.load_config.load_format
        use_safetensors = False
        index_file = SAFE_WEIGHTS_INDEX_NAME

        # For the 'auto' format, first check whether files in the official
        # mistral layout are present, so that mistral models are loaded
        # in that format by default.
        if load_format == "auto":
            consolidated_files = list_filtered_repo_files(
                model_name_or_path=model_name_or_path,
                allow_patterns=["consolidated*.safetensors"],
                revision=revision,
            )
            load_format = "mistral" if len(consolidated_files) > 0 else "hf"

        if load_format == "hf":
            allow_patterns = ["*.safetensors", "*.bin"]
        elif load_format in ("safetensors", "fastsafetensors", "instanttensor"):
            use_safetensors = True
            allow_patterns = ["*.safetensors"]
        elif load_format == "mistral":
            use_safetensors = True
            allow_patterns = ["consolidated*.safetensors"]
            index_file = "consolidated.safetensors.index.json"
        elif load_format == "pt":
            allow_patterns = ["*.pt"]
        elif load_format == "npcache":
            allow_patterns = ["*.bin"]
        else:
            raise ValueError(f"Unknown load_format: {load_format}")

        # Some quantized models use .pt files for storing the weights.
        if fall_back_to_pt:
            allow_patterns.append("*.pt")

        # Overrides replace the computed patterns entirely.
        if allow_patterns_overrides is not None:
            allow_patterns = allow_patterns_overrides

        if is_local:
            hf_folder = model_name_or_path
        else:
            hf_folder = download_weights_from_hf(
                model_name_or_path,
                self.load_config.download_dir,
                allow_patterns,
                revision,
                subfolder=subfolder,
                ignore_patterns=self.load_config.ignore_patterns,
            )

        if subfolder is not None:
            hf_folder = os.path.join(hf_folder, subfolder)

        # Patterns are tried in order; the first one that matches anything
        # decides which files are used.
        hf_weights_files: list[str] = []
        for pattern in allow_patterns:
            hf_weights_files.extend(glob.glob(os.path.join(hf_folder, pattern)))
            if not hf_weights_files:
                continue
            if pattern.endswith(".safetensors"):
                use_safetensors = True
            break

        if use_safetensors:
            # Models like Mistral-7B-Instruct-v0.3 ship both sharded
            # safetensors files and a consolidated safetensors file, and
            # using both breaks. Fetch the index file and drop every file
            # that the index does not list.
            if not is_local:
                download_safetensors_index_file_from_hf(
                    model_name_or_path,
                    index_file,
                    cache_dir=self.load_config.download_dir,
                    subfolder=subfolder,
                    revision=revision,
                )
            hf_weights_files = filter_duplicate_safetensors_files(
                hf_weights_files, hf_folder, index_file
            )
        else:
            hf_weights_files = filter_files_not_needed_for_inference(hf_weights_files)

        if not hf_weights_files:
            raise RuntimeError(
                f"Cannot find any model weights with `{model_name_or_path}`"
            )

        return hf_folder, hf_weights_files, use_safetensors

    def _get_weights_iterator(
        self, source: "Source"
    ) -> Generator[tuple[str, torch.Tensor], None, None]:
        """Build the weight iterator for ``source`` according to the load format.

        The files are resolved through ``_prepare_weights``; every yielded
        name is prefixed with ``source.prefix``.
        """
        load_config = self.load_config
        extra_config = load_config.model_loader_extra_config
        hf_folder, hf_weights_files, use_safetensors = self._prepare_weights(
            source.model_or_path,
            source.subfolder,
            source.revision,
            source.fall_back_to_pt,
            source.allow_patterns_overrides,
        )
        load_format = load_config.load_format
        show_progress = load_config.use_tqdm_on_load

        if load_format == "npcache":
            # Currently np_cache only support *.bin checkpoints
            assert use_safetensors is False
            weights_iterator = np_cache_weights_iterator(
                source.model_or_path,
                load_config.download_dir,
                hf_folder,
                hf_weights_files,
                show_progress,
            )
        elif not use_safetensors:
            # .bin / .pt checkpoints.
            if extra_config.get("enable_multithread_load"):
                weights_iterator = multi_thread_pt_weights_iterator(
                    hf_weights_files,
                    show_progress,
                    load_config.pt_load_map_location,
                    max_workers=extra_config.get(
                        "num_threads", self.DEFAULT_NUM_THREADS
                    ),
                )
            else:
                weights_iterator = pt_weights_iterator(
                    hf_weights_files,
                    show_progress,
                    load_config.pt_load_map_location,
                )
        elif load_format == "fastsafetensors":
            weights_iterator = fastsafetensors_weights_iterator(
                hf_weights_files,
                show_progress,
            )
        elif load_format == "instanttensor":
            weights_iterator = instanttensor_weights_iterator(
                hf_weights_files,
                show_progress,
            )
        elif extra_config.get("enable_multithread_load"):
            weights_iterator = multi_thread_safetensors_weights_iterator(
                hf_weights_files,
                show_progress,
                max_workers=extra_config.get("num_threads", self.DEFAULT_NUM_THREADS),
            )
        else:
            weights_iterator = safetensors_weights_iterator(
                hf_weights_files,
                show_progress,
                load_config.safetensors_load_strategy,
                local_expert_ids=self.local_expert_ids,
                safetensors_prefetch_num_threads=(
                    load_config.safetensors_prefetch_num_threads
                ),
                safetensors_prefetch_block_size=(
                    load_config.safetensors_prefetch_block_size
                ),
            )

        # Record the start time only once, i.e. for the first source.
        if self.counter_before_loading_weights == 0.0:
            self.counter_before_loading_weights = time.perf_counter()
        # Apply the prefix.
        return ((source.prefix + key, value) for (key, value) in weights_iterator)

    def get_all_weights(
        self,
        model_config: ModelConfig,
        model: nn.Module,
    ) -> Generator[tuple[str, torch.Tensor], None, None]:
        primary_weights = DefaultModelLoader.Source(
            model_config.model,
            model_config.revision,
            prefix="",
            fall_back_to_pt=getattr(model, "fall_back_to_pt_during_load", True),
            allow_patterns_overrides=getattr(model, "allow_patterns_overrides", None),
        )
        yield from self._get_weights_iterator(primary_weights)

        secondary_weights = cast(
            Iterable[DefaultModelLoader.Source],
            getattr(model, "secondary_weights", ()),
        )
        for source in secondary_weights:
            yield from self._get_weights_iterator(source)

    def download_model(self, model_config: ModelConfig) -> None:
        self._prepare_weights(
            model_name_or_path=model_config.model,
            subfolder=None,
            revision=model_config.revision,
            fall_back_to_pt=True,
            allow_patterns_overrides=None,
        )

    def _init_ep_weight_filter(self, model_config: ModelConfig) -> None:
        """Work out which expert ids this rank should load.

        With expert parallelism each rank only needs a subset of the expert
        weights. Computing that subset up front lets the loader skip
        non-local expert tensors *before* they are read from disk. The
        result is stored in ``self.local_expert_ids``; the attribute is left
        untouched when the filter does not apply.
        """
        from vllm.config import get_current_vllm_config

        parallel_config = get_current_vllm_config().parallel_config

        # The filter needs an MoE model, expert parallelism and the filter
        # switch itself.
        if (
            not model_config.is_moe
            or not parallel_config.enable_expert_parallel
            or not parallel_config.enable_ep_weight_filter
        ):
            return

        # With EPLB, redundant physical expert slots may map to logical
        # experts owned by other ranks in the default partition. The weight
        # loader must then see ALL logical expert weights to populate those
        # slots, so no filtering is done at all.
        if parallel_config.enable_eplb:
            return

        num_experts = model_config.get_num_experts()
        if num_experts <= 0:
            return

        # EP size/rank computation mirrors FusedMoEParallelConfig.make():
        #   ep_size = dp_size * pcp_size * tp_size (flattened)
        #   ep_rank = dp_rank * pcp_size * tp_size + pcp_rank * tp_size + tp_rank
        from vllm.distributed import (
            get_dp_group,
            get_pcp_group,
            get_tensor_model_parallel_rank,
        )

        dp_size = parallel_config.data_parallel_size
        tp_size = parallel_config.tensor_parallel_size
        pcp_size = parallel_config.prefill_context_parallel_size
        # A group is only queried when its dimension is actually parallel.
        dp_rank = get_dp_group().rank_in_group if dp_size > 1 else 0
        tp_rank = get_tensor_model_parallel_rank() if tp_size > 1 else 0
        pcp_rank = get_pcp_group().rank_in_group if pcp_size > 1 else 0
        ep_size = dp_size * pcp_size * tp_size
        ep_rank = (dp_rank * pcp_size + pcp_rank) * tp_size + tp_rank

        local_ids = compute_local_expert_ids(
            num_experts,
            ep_size,
            ep_rank,
            placement=parallel_config.expert_placement_strategy,
        )
        self.local_expert_ids = local_ids
        if local_ids is not None:
            logger.info_once(
                "EP weight filter: ep_size=%d, ep_rank=%d, loading %d/%d experts",
                ep_size,
                ep_rank,
                len(local_ids),
                num_experts,
            )

    @instrument(span_name="Load weights")
    def load_weights(self, model: nn.Module, model_config: ModelConfig) -> None:
        """Load checkpoint weights into ``model`` and optionally verify them.

        Steps:
          1. For torchao-serialized checkpoints (torchao >= 0.15.0), switch
             the safetensors load strategy to ``"torchao"``.
          2. Initialise the expert-parallel weight filter.
          3. Feed ``get_all_weights`` to ``model.load_weights`` and log the
             elapsed time of *this* call.
          4. If weight tracking is enabled, check that every model parameter
             was loaded.

        Raises:
            ValueError: From ``track_weights_loading`` when tracking is
                enabled and some parameters were not loaded.
        """
        if model_config.quantization == "torchao":
            quant_config = get_quant_config(model_config, self.load_config)
            if (
                hasattr(quant_config, "is_checkpoint_torchao_serialized")
                and quant_config.is_checkpoint_torchao_serialized
                and torchao_version_at_least("0.15.0")
            ):
                self.load_config.safetensors_load_strategy = "torchao"

        self._init_ep_weight_filter(model_config)

        # `_get_weights_iterator` only stamps the start time while the marker
        # is 0.0. Reset it so a repeated call on the same loader measures
        # this load rather than the time since the very first one.
        self.counter_before_loading_weights = 0.0
        fallback_start = time.perf_counter()

        loaded_weights = model.load_weights(self.get_all_weights(model_config, model))

        self.counter_after_loading_weights = time.perf_counter()
        if self.counter_before_loading_weights == 0.0:
            # The model never pulled from the iterator, so no start time was
            # recorded; fall back to the time taken just before the call.
            self.counter_before_loading_weights = fallback_start
        logger.info_once(
            "Loading weights took %.2f seconds",
            self.counter_after_loading_weights - self.counter_before_loading_weights,
        )
        # We only enable strict check for non-quantized models
        # that have loaded weights tracking by default.
        default_enable_weights_track = (
            model_config.quantization is None and loaded_weights is not None
        )
        # An explicit `enable_weights_track` extra-config value overrides
        # the default.
        enable_weights_track = (
            self.enable_weights_track
            if self.enable_weights_track is not None
            else default_enable_weights_track
        )
        if enable_weights_track:
            self.track_weights_loading(model, loaded_weights)

    def track_weights_loading(
        self, model: nn.Module, loaded_weights: set[str] | None
    ) -> None:
        """Verify that every parameter of ``model`` was loaded.

        Does nothing when ``loaded_weights`` is None. Otherwise the
        parameters of any module whose ``quant_method`` uses the meta device
        or defines ``process_weights_after_loading`` are added to
        ``loaded_weights`` (the set is modified in place) before the check.

        Raises:
            ValueError: If a model parameter is missing from
                ``loaded_weights`` after that adjustment.
        """
        expected = {param_name for param_name, _ in model.named_parameters()}
        if loaded_weights is None:
            return

        for module_name, module in model.named_modules():
            quant_method = getattr(module, "quant_method", None)
            online_quant = getattr(quant_method, "uses_meta_device", False)
            postprocess_quant = getattr(
                quant_method, "process_weights_after_loading", None
            )
            # kv_cache scales and online quantization scales can be missing
            # from checkpoints, so only such modules are exempted.
            if not (online_quant or postprocess_quant):
                continue
            for param_name, _ in module.named_parameters():
                qualified = f"{module_name}.{param_name}" if module_name else param_name
                loaded_weights.add(qualified)

        missing = expected - loaded_weights
        if missing:
            raise ValueError(
                "Following weights were not initialized from "
                f"checkpoint: {missing}"
            )

Source dataclass

A source for weights.

Source code in vllm/model_executor/model_loader/default_loader.py
@dataclasses.dataclass
class Source:
    """A source for weights.

    Each instance describes one checkpoint location. The loader hands
    ``model_or_path``, ``subfolder``, ``revision``, ``fall_back_to_pt``
    and ``allow_patterns_overrides`` to ``_prepare_weights`` and prepends
    ``prefix`` to the name of every tensor it yields for this source.
    """

    model_or_path: str
    """The model ID or path."""

    revision: str | None
    """The optional model revision."""

    subfolder: str | None = None
    """The subfolder inside the model repo."""

    prefix: str = ""
    """A prefix to prepend to all weights."""

    fall_back_to_pt: bool = True
    """Whether .pt weights can be used."""

    allow_patterns_overrides: list[str] | None = None
    """If defined, weights will load exclusively using these patterns."""

allow_patterns_overrides class-attribute instance-attribute

allow_patterns_overrides: list[str] | None = None

If defined, weights will load exclusively using these patterns.

fall_back_to_pt class-attribute instance-attribute

fall_back_to_pt: bool = True

Whether .pt weights can be used.

model_or_path instance-attribute

model_or_path: str

The model ID or path.

prefix class-attribute instance-attribute

prefix: str = ''

A prefix to prepend to all weights.

revision instance-attribute

revision: str | None

The optional model revision.

subfolder class-attribute instance-attribute

subfolder: str | None = None

The subfolder inside the model repo.

_get_weights_iterator

_get_weights_iterator(
    source: Source,
) -> Generator[tuple[str, Tensor], None, None]

Get an iterator for the model weights based on the load format.

Source code in vllm/model_executor/model_loader/default_loader.py
def _get_weights_iterator(
    self, source: "Source"
) -> Generator[tuple[str, torch.Tensor], None, None]:
    """Build the weight iterator for ``source`` according to the load format.

    The files are resolved through ``_prepare_weights``; every yielded
    name is prefixed with ``source.prefix``.
    """
    load_config = self.load_config
    extra_config = load_config.model_loader_extra_config
    hf_folder, hf_weights_files, use_safetensors = self._prepare_weights(
        source.model_or_path,
        source.subfolder,
        source.revision,
        source.fall_back_to_pt,
        source.allow_patterns_overrides,
    )
    load_format = load_config.load_format
    show_progress = load_config.use_tqdm_on_load

    if load_format == "npcache":
        # Currently np_cache only support *.bin checkpoints
        assert use_safetensors is False
        weights_iterator = np_cache_weights_iterator(
            source.model_or_path,
            load_config.download_dir,
            hf_folder,
            hf_weights_files,
            show_progress,
        )
    elif not use_safetensors:
        # .bin / .pt checkpoints.
        if extra_config.get("enable_multithread_load"):
            weights_iterator = multi_thread_pt_weights_iterator(
                hf_weights_files,
                show_progress,
                load_config.pt_load_map_location,
                max_workers=extra_config.get(
                    "num_threads", self.DEFAULT_NUM_THREADS
                ),
            )
        else:
            weights_iterator = pt_weights_iterator(
                hf_weights_files,
                show_progress,
                load_config.pt_load_map_location,
            )
    elif load_format == "fastsafetensors":
        weights_iterator = fastsafetensors_weights_iterator(
            hf_weights_files,
            show_progress,
        )
    elif load_format == "instanttensor":
        weights_iterator = instanttensor_weights_iterator(
            hf_weights_files,
            show_progress,
        )
    elif extra_config.get("enable_multithread_load"):
        weights_iterator = multi_thread_safetensors_weights_iterator(
            hf_weights_files,
            show_progress,
            max_workers=extra_config.get("num_threads", self.DEFAULT_NUM_THREADS),
        )
    else:
        weights_iterator = safetensors_weights_iterator(
            hf_weights_files,
            show_progress,
            load_config.safetensors_load_strategy,
            local_expert_ids=self.local_expert_ids,
            safetensors_prefetch_num_threads=(
                load_config.safetensors_prefetch_num_threads
            ),
            safetensors_prefetch_block_size=(
                load_config.safetensors_prefetch_block_size
            ),
        )

    # Record the start time only once, i.e. for the first source.
    if self.counter_before_loading_weights == 0.0:
        self.counter_before_loading_weights = time.perf_counter()
    # Apply the prefix.
    return ((source.prefix + key, value) for (key, value) in weights_iterator)

_init_ep_weight_filter

_init_ep_weight_filter(model_config: ModelConfig) -> None

Compute local expert ids for EP weight filtering.

When expert parallelism is active, each rank only needs a subset of expert weights. By computing the set upfront we can skip non-local expert tensors before reading them from disk.

Source code in vllm/model_executor/model_loader/default_loader.py
def _init_ep_weight_filter(self, model_config: ModelConfig) -> None:
    """Work out which expert ids this rank should load.

    With expert parallelism each rank only needs a subset of the expert
    weights. Computing that subset up front lets the loader skip
    non-local expert tensors *before* they are read from disk. The
    result is stored in ``self.local_expert_ids``; the attribute is left
    untouched when the filter does not apply.
    """
    from vllm.config import get_current_vllm_config

    parallel_config = get_current_vllm_config().parallel_config

    # The filter needs an MoE model, expert parallelism and the filter
    # switch itself.
    if (
        not model_config.is_moe
        or not parallel_config.enable_expert_parallel
        or not parallel_config.enable_ep_weight_filter
    ):
        return

    # With EPLB, redundant physical expert slots may map to logical
    # experts owned by other ranks in the default partition. The weight
    # loader must then see ALL logical expert weights to populate those
    # slots, so no filtering is done at all.
    if parallel_config.enable_eplb:
        return

    num_experts = model_config.get_num_experts()
    if num_experts <= 0:
        return

    # EP size/rank computation mirrors FusedMoEParallelConfig.make():
    #   ep_size = dp_size * pcp_size * tp_size (flattened)
    #   ep_rank = dp_rank * pcp_size * tp_size + pcp_rank * tp_size + tp_rank
    from vllm.distributed import (
        get_dp_group,
        get_pcp_group,
        get_tensor_model_parallel_rank,
    )

    dp_size = parallel_config.data_parallel_size
    tp_size = parallel_config.tensor_parallel_size
    pcp_size = parallel_config.prefill_context_parallel_size
    # A group is only queried when its dimension is actually parallel.
    dp_rank = get_dp_group().rank_in_group if dp_size > 1 else 0
    tp_rank = get_tensor_model_parallel_rank() if tp_size > 1 else 0
    pcp_rank = get_pcp_group().rank_in_group if pcp_size > 1 else 0
    ep_size = dp_size * pcp_size * tp_size
    ep_rank = (dp_rank * pcp_size + pcp_rank) * tp_size + tp_rank

    local_ids = compute_local_expert_ids(
        num_experts,
        ep_size,
        ep_rank,
        placement=parallel_config.expert_placement_strategy,
    )
    self.local_expert_ids = local_ids
    if local_ids is not None:
        logger.info_once(
            "EP weight filter: ep_size=%d, ep_rank=%d, loading %d/%d experts",
            ep_size,
            ep_rank,
            len(local_ids),
            num_experts,
        )

_prepare_weights

_prepare_weights(
    model_name_or_path: str,
    subfolder: str | None,
    revision: str | None,
    fall_back_to_pt: bool,
    allow_patterns_overrides: list[str] | None,
) -> tuple[str, list[str], bool]

Prepare weights for the model.

If the model is not local, it will be downloaded.

Source code in vllm/model_executor/model_loader/default_loader.py
def _prepare_weights(
    self,
    model_name_or_path: str,
    subfolder: str | None,
    revision: str | None,
    fall_back_to_pt: bool,
    allow_patterns_overrides: list[str] | None,
) -> tuple[str, list[str], bool]:
    """Resolve the checkpoint files to load, downloading them when needed.

    Args:
        model_name_or_path: Local directory or remote model identifier.
        subfolder: Optional sub-directory (inside the model folder) that
            holds the weights.
        revision: Optional revision forwarded to the download helpers.
        fall_back_to_pt: When true, ``*.pt`` is appended to the candidate
            file patterns.
        allow_patterns_overrides: When not ``None``, replaces the computed
            file patterns entirely.

    Returns:
        ``(hf_folder, hf_weights_files, use_safetensors)``: the folder that
        was searched, the weight files selected from it, and whether those
        files are treated as safetensors.

    Raises:
        ValueError: If the load format is not one of the handled formats.
        RuntimeError: If no weight file is left after filtering.
    """
    modelscope_path = maybe_download_from_modelscope(model_name_or_path, revision)
    if modelscope_path:
        model_name_or_path = modelscope_path

    is_local = os.path.isdir(model_name_or_path)
    load_format = self.load_config.load_format
    index_file = SAFE_WEIGHTS_INDEX_NAME

    # For the 'auto' format, first check whether Mistral-format files are
    # present, so that Mistral models load in their official format by
    # default.
    if load_format == "auto":
        consolidated_files = list_filtered_repo_files(
            model_name_or_path=model_name_or_path,
            allow_patterns=["consolidated*.safetensors"],
            revision=revision,
        )
        load_format = "mistral" if len(consolidated_files) > 0 else "hf"

    # Formats that are always read as safetensors, whatever the glob finds.
    use_safetensors = load_format in (
        "safetensors",
        "fastsafetensors",
        "instanttensor",
        "mistral",
    )

    if load_format == "hf":
        allow_patterns = ["*.safetensors", "*.bin"]
    elif load_format == "mistral":
        allow_patterns = ["consolidated*.safetensors"]
        index_file = "consolidated.safetensors.index.json"
    elif use_safetensors:
        allow_patterns = ["*.safetensors"]
    elif load_format == "pt":
        allow_patterns = ["*.pt"]
    elif load_format == "npcache":
        allow_patterns = ["*.bin"]
    else:
        raise ValueError(f"Unknown load_format: {load_format}")

    # Some quantized models use .pt files for storing the weights.
    if fall_back_to_pt:
        allow_patterns += ["*.pt"]

    # Explicit overrides win over everything computed above.
    if allow_patterns_overrides is not None:
        allow_patterns = allow_patterns_overrides

    if is_local:
        hf_folder = model_name_or_path
    else:
        hf_folder = download_weights_from_hf(
            model_name_or_path,
            self.load_config.download_dir,
            allow_patterns,
            revision,
            subfolder=subfolder,
            ignore_patterns=self.load_config.ignore_patterns,
        )

    if subfolder is not None:
        hf_folder = os.path.join(hf_folder, subfolder)

    # Patterns are tried in order; the first one that matches anything wins.
    hf_weights_files: list[str] = []
    for pattern in allow_patterns:
        hf_weights_files.extend(glob.glob(os.path.join(hf_folder, pattern)))
        if len(hf_weights_files) == 0:
            continue
        if pattern.endswith(".safetensors"):
            use_safetensors = True
        break

    if not use_safetensors:
        hf_weights_files = filter_files_not_needed_for_inference(hf_weights_files)
    else:
        # For models like Mistral-7B-Instruct-v0.3
        # there are both sharded safetensors files and a consolidated
        # safetensors file. Using both breaks.
        # Here, we download the safetensors index file and filter out
        # any files not found in the index.
        if not is_local:
            download_safetensors_index_file_from_hf(
                model_name_or_path,
                index_file,
                cache_dir=self.load_config.download_dir,
                subfolder=subfolder,
                revision=revision,
            )
        hf_weights_files = filter_duplicate_safetensors_files(
            hf_weights_files, hf_folder, index_file
        )

    if len(hf_weights_files) == 0:
        raise RuntimeError(
            f"Cannot find any model weights with `{model_name_or_path}`"
        )

    return hf_folder, hf_weights_files, use_safetensors

DummyModelLoader

Bases: BaseModelLoader

Model loader that will set model weights to random values.

Source code in vllm/model_executor/model_loader/dummy_loader.py
class DummyModelLoader(BaseModelLoader):
    """Model loader that fills model weights with random values."""

    def __init__(self, load_config: LoadConfig):
        """Store the load config; reject any loader-specific extra config."""
        super().__init__(load_config)
        if not load_config.model_loader_extra_config:
            return
        raise ValueError(
            "Model loader extra config is not supported for "
            f"load format {load_config.load_format}"
        )

    def download_model(self, model_config: ModelConfig) -> None:
        """No-op: dummy weights need no download."""
        return None

    def load_weights(self, model: nn.Module, model_config: ModelConfig) -> None:
        """Give every module of ``model`` dummy weights.

        Modules whose layerwise info reports ``can_load()`` go through
        ``_process_online_quant_layer``; all others are initialized directly.
        """
        for module in model.modules():
            reload_info = get_layerwise_info(module)
            if not reload_info.can_load():
                # NOTE(woosuk): For accurate performance evaluation, we assign
                # random values to the weights.
                initialize_dummy_weights(module, model_config)
                continue
            self._process_online_quant_layer(module, reload_info)

    def _process_online_quant_layer(
        self,
        layer: nn.Module,
        info: LayerReloadingInfo,
    ) -> None:
        """Materialize, apply dummy weights, and run quantization processing."""
        materialize_layer(layer, info)

        # First pass: fill every tensor of the layer with dummy values.
        for layer_tensor in get_layer_tensors(layer).values():
            initialize_single_dummy_weight(layer_tensor)

        # Second pass: put back each tensor's original weight loader.
        for layer_param in get_layer_tensors(layer).values():
            layer_param.weight_loader = _get_original_loader(layer_param)

        # Only real quantization methods get a post-load hook.
        method = getattr(layer, "quant_method", None)
        if isinstance(method, QuantizeMethodBase):
            method.process_weights_after_loading(layer)

        info.reset()

_process_online_quant_layer

_process_online_quant_layer(
    layer: Module, info: LayerReloadingInfo
) -> None

Materialize, apply dummy weights, and run quantization processing.

Source code in vllm/model_executor/model_loader/dummy_loader.py
def _process_online_quant_layer(
    self,
    layer: nn.Module,
    info: LayerReloadingInfo,
) -> None:
    """Materialize a layer, fill it with dummy weights, then post-process it.

    Args:
        layer: Module to materialize and initialize.
        info: Layerwise reloading state for ``layer``; it is reset at the end.
    """
    materialize_layer(layer, info)

    # Pass 1: assign dummy values to each tensor of the layer.
    for layer_tensor in get_layer_tensors(layer).values():
        initialize_single_dummy_weight(layer_tensor)

    # Pass 2: re-attach the original weight loader to each tensor.
    for layer_param in get_layer_tensors(layer).values():
        layer_param.weight_loader = _get_original_loader(layer_param)

    # Run the quantization post-load hook only for real quant methods.
    method = getattr(layer, "quant_method", None)
    if isinstance(method, QuantizeMethodBase):
        method.process_weights_after_loading(layer)

    info.reset()

GGUFModelLoader

Bases: BaseModelLoader

Model loader that can load GGUF files. This is useful for loading models that are quantized with GGUF and saved in the GGUF format. This loader supports loading both full models and sharded models.

Source code in vllm/model_executor/model_loader/gguf_loader.py
 38
 39
 40
 41
 42
 43
 44
 45
 46
 47
 48
 49
 50
 51
 52
 53
 54
 55
 56
 57
 58
 59
 60
 61
 62
 63
 64
 65
 66
 67
 68
 69
 70
 71
 72
 73
 74
 75
 76
 77
 78
 79
 80
 81
 82
 83
 84
 85
 86
 87
 88
 89
 90
 91
 92
 93
 94
 95
 96
 97
 98
 99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441
442
443
444
445
446
447
448
449
450
451
452
453
class GGUFModelLoader(BaseModelLoader):
    """
    Model loader that can load GGUF files. This is useful for loading models
    that are quantized with GGUF and saved in the GGUF format. This loader
    supports loading both full models and sharded models.
    """

    def __init__(self, load_config: LoadConfig):
        """Store the load config; reject any loader-specific extra config."""
        super().__init__(load_config)
        if not load_config.model_loader_extra_config:
            return
        raise ValueError(
            "Model loader extra config is not supported for "
            f"load format {load_config.load_format}"
        )

    def _prepare_weights(self, model_config: ModelConfig):
        model_name_or_path = model_config.model
        if os.path.isfile(model_name_or_path):
            return model_name_or_path
        # repo id/filename.gguf
        if "/" in model_name_or_path and model_name_or_path.endswith(".gguf"):
            repo_id, filename = model_name_or_path.rsplit("/", 1)
            return hf_api().hf_hub_download(
                repo_id=repo_id,
                filename=filename,
                revision=model_config.revision,
                cache_dir=self.load_config.download_dir,
            )
        # repo_id:quant_type
        elif "/" in model_name_or_path and ":" in model_name_or_path:
            repo_id, quant_type = model_name_or_path.rsplit(":", 1)
            return download_gguf(
                repo_id,
                quant_type,
                cache_dir=self.load_config.download_dir,
                revision=model_config.revision,
                ignore_patterns=self.load_config.ignore_patterns,
            )

        raise ValueError(
            f"Unrecognised GGUF reference: {model_name_or_path} "
            "(expected local file, <repo_id>/<filename>.gguf, "
            "or <repo_id>:<quant_type>)"
        )

    @staticmethod
    def _get_all_gguf_files(model_path: str) -> list[str]:
        """Discover all GGUF shard files from a single shard path.

        Supports variable-width shard indices by dynamically detecting
        the padding from the original filename.
        E.g. ``*-00001-of-00005.gguf`` → all 5 shards,
             ``*-01-of-15.gguf`` → all 15 shards.
        """
        match = re.search(r"-(\d+)-of-(\d+)\.gguf$", model_path)
        if not match:
            return [model_path]
        total = int(match.group(2))
        num_digits = len(match.group(1))
        prefix = model_path[: match.start(1)]
        suffix = model_path[match.end(2) :]
        files = []
        for i in range(1, total + 1):
            shard_path = f"{prefix}{i:0{num_digits}d}-of-{total:0{num_digits}d}{suffix}"
            if os.path.isfile(shard_path):
                files.append(shard_path)
        if files:
            logger.info("Discovered %d GGUF shard files", len(files))
        return files if files else [model_path]

    def _get_gguf_weights_map(self, model_config: ModelConfig):
        """
        Build the mapping from GGUF tensor names to HF parameter names.

        GGUF uses this naming convention for their tensors from HF checkpoint:
        `blk.N.BB.weight` and `blk.N.BB.bias`
        where N signifies the block number of a layer, and BB signifies the
        attention/mlp layer components.
        See "Standardized tensor names" in
        https://github.com/ggerganov/ggml/blob/master/docs/gguf.md for details.

        Args:
            model_config: Model configuration; its ``hf_config`` and
                ``trust_remote_code`` are read.

        Returns:
            Dict keyed by GGUF tensor name (including the ``.weight`` /
            ``.bias`` suffix) whose values are the matching HF parameter
            names.

        Raises:
            RuntimeError: If the model type has no entry in
                ``gguf.MODEL_ARCH_NAMES``, or if some parameter of the dummy
                HF model could not be mapped and is not covered by a
                sideload pattern.
        """
        config = model_config.hf_config
        # Get text config to handle both nested (multimodal) and flat
        # (text-only) config structures. For multimodal models like
        # Gemma3Config, this returns config.text_config. For text-only
        # models, this returns config itself.
        text_config = config.get_text_config()
        model_type = config.model_type
        is_multimodal = (
            hasattr(config, "vision_config") and config.vision_config is not None
        )
        gguf_to_hf_name_map = {}
        # Patterns for HF parameter names that are allowed to stay unmapped
        # (see the filtering of `unmapped_params` at the end).
        sideload_params: list[re.Pattern] = []
        # hack: ggufs have a different name than transformers
        if model_type == "cohere":
            model_type = "command-r"
        if model_type == "gemma3_text":
            # Gemma3 models use "gemma3_text" in HuggingFace but
            # "gemma3" in GGUF architecture naming
            model_type = "gemma3"
        if model_type in ("deepseek_v3", "deepseek_v2"):
            model_type = "deepseek2"
            # GGUF layer map assumes that we will have a merged expert weights
            # so we need to map them manually
            for idx in range(config.num_hidden_layers):
                gguf_to_hf_name_map[f"blk.{idx}.exp_probs_b.bias"] = (
                    f"model.layers.{idx}.mlp.gate.e_score_correction_bias"
                )
                # Each merged GGUF expert tensor is keyed to the expert-0 HF
                # name; the other experts are covered by the sideload pattern.
                gguf_to_hf_name_map[f"blk.{idx}.ffn_down_exps.weight"] = (
                    f"model.layers.{idx}.mlp.experts.0.down_proj.weight"
                )
                gguf_to_hf_name_map[f"blk.{idx}.ffn_gate_exps.weight"] = (
                    f"model.layers.{idx}.mlp.experts.0.gate_proj.weight"
                )
                gguf_to_hf_name_map[f"blk.{idx}.ffn_up_exps.weight"] = (
                    f"model.layers.{idx}.mlp.experts.0.up_proj.weight"
                )
                sideload_params.append(
                    re.compile(
                        f"model\\.layers\\.{idx}"
                        r"\.mlp\.experts\.[0-9]+\.(gate|up|down)_proj\.weight"
                    )
                )
        if model_type in ("qwen2_moe", "qwen3_moe"):
            # "qwen2_moe" -> "qwen2moe", "qwen3_moe" -> "qwen3moe"
            model_type = model_type.replace("_", "")
            # GGUF layer map assumes that we will have a merged expert weights
            # so we need to map them manually
            for idx in range(config.num_hidden_layers):
                gguf_to_hf_name_map[f"blk.{idx}.ffn_down_exps.weight"] = (
                    f"model.layers.{idx}.mlp.experts.0.down_proj.weight"
                )
                gguf_to_hf_name_map[f"blk.{idx}.ffn_gate_exps.weight"] = (
                    f"model.layers.{idx}.mlp.experts.0.gate_proj.weight"
                )
                gguf_to_hf_name_map[f"blk.{idx}.ffn_up_exps.weight"] = (
                    f"model.layers.{idx}.mlp.experts.0.up_proj.weight"
                )
                sideload_params.append(
                    re.compile(
                        f"model\\.layers\\.{idx}"
                        r"\.mlp\.experts\.[0-9]+\.(gate|up|down)_proj\.weight"
                    )
                )
        if model_type == "minimax_m2":
            model_type = "minimax-m2"
            # GGUF layer map assumes merged expert weights
            # map them manually like deepseek2
            for idx in range(config.num_hidden_layers):
                gguf_to_hf_name_map[f"blk.{idx}.exp_probs_b.bias"] = (
                    f"model.layers.{idx}.block_sparse_moe.e_score_correction_bias"
                )
                gguf_to_hf_name_map[f"blk.{idx}.ffn_down_exps.weight"] = (
                    f"model.layers.{idx}.block_sparse_moe.experts.0.w2.weight"
                )
                gguf_to_hf_name_map[f"blk.{idx}.ffn_gate_exps.weight"] = (
                    f"model.layers.{idx}.block_sparse_moe.experts.0.w1.weight"
                )
                gguf_to_hf_name_map[f"blk.{idx}.ffn_up_exps.weight"] = (
                    f"model.layers.{idx}.block_sparse_moe.experts.0.w3.weight"
                )
                sideload_params.append(
                    re.compile(
                        f"model\\.layers\\.{idx}"
                        r"\.block_sparse_moe\.experts\.(gate_up_proj|down_proj)"
                    )
                )

        # Reverse lookup: find the gguf architecture key whose registered
        # name equals the (possibly renamed) model_type.
        arch = None
        for key, value in gguf.MODEL_ARCH_NAMES.items():
            if value == model_type:
                arch = key
                break
        if arch is None:
            raise RuntimeError(f"Unknown gguf model_type: {model_type}")
        text_num_layers = text_config.num_hidden_layers
        text_name_map = gguf.get_tensor_name_map(arch, text_num_layers)

        if is_multimodal:
            mm_proj_arch = gguf.MODEL_ARCH.MMPROJ
            vision_num_layers = config.vision_config.num_hidden_layers
            vision_name_map = gguf.get_tensor_name_map(mm_proj_arch, vision_num_layers)
        else:
            vision_name_map = None

        # Create dummy model to extract parameter names
        # For multimodal: use AutoModelForImageTextToText to get
        # language + vision + projector params
        # For text-only: use AutoModelForCausalLM to get language model params
        auto_cls = (
            AutoModelForImageTextToText if is_multimodal else AutoModelForCausalLM
        )
        # Instantiate on the "meta" device: only the parameter names from the
        # state dict are used below.
        with torch.device("meta"):
            dummy_model = auto_cls.from_config(
                config, trust_remote_code=model_config.trust_remote_code
            )

        state_dict = dummy_model.state_dict()
        if hf_checkpoint_map := getattr(
            dummy_model, "_checkpoint_conversion_mapping", None
        ):

            def revert_hf_rename(name: str) -> str:
                # Replace each renamed fragment (mapping value) with the
                # original fragment (mapping key).
                # NOTE(review): lstrip("^") presumably drops a regex anchor
                # carried over from the mapping key — confirm.
                for original_name, hf_name in hf_checkpoint_map.items():
                    if hf_name in name:
                        name = name.replace(hf_name, original_name).lstrip("^")
                return name

            state_dict = {
                revert_hf_rename(name): tensor for name, tensor in state_dict.items()
            }

        if model_type == "minimax-m2" and not hf_checkpoint_map:
            # Reverse HF convention: mlp -> block_sparse_moe
            state_dict = {
                name.replace(".mlp.", ".block_sparse_moe."): tensor
                for name, tensor in state_dict.items()
            }

        def find_hf_name_in_tensor_map(hf_name: str) -> str | None:
            """
            Map HuggingFace parameter name to GGUF tensor name.

            This function handles the mismatch between HF parameter naming
            conventions and gguf-py's expected format:
            1. Strips 'model.' prefix (common in multimodal models)
            2. Converts '_weight' suffix to '.weight' (Gemma3 compatibility)
            3. Searches vision_name_map for multimodal parameters
            4. Falls back to text_name_map for language model parameters

            Args:
                hf_name: Full HuggingFace parameter name (e.g.,
                        'model.multi_modal_projector.mm_soft_emb_norm.weight')

            Returns:
                GGUF tensor name with suffix (e.g., 'mm.soft_emb_norm.weight')
                or None if no mapping found
            """
            # In transformers v5, multimodal models (e.g. Gemma3) wrap
            # all sub-models under an outer 'model.' attribute, producing
            # state_dict keys like 'model.language_model.layers.0...' and
            # 'model.vision_tower.vision_model...'.  Strip this outer
            # prefix so the keys match what gguf-py expects.
            if is_multimodal and hf_name.startswith("model."):
                hf_name = hf_name[6:]  # Remove outer 'model.'

            # Strip 'language_model.' prefix for multimodal models - gguf-py
            # tensor mappings expect parameter names without this prefix.
            # Note: 'model.' prefix should be KEPT for text-only models as
            # gguf-py expects it.
            if hf_name.startswith("language_model."):
                hf_name = hf_name[15:]  # Remove 'language_model.'
                # Re-add 'model.' prefix because gguf-py text tensor maps
                # expect 'model.layers...' format.
                if is_multimodal:
                    hf_name = "model." + hf_name

            # Parse parameter name and suffix
            if hf_name.endswith((".weight", ".bias")):
                base_name, suffix = hf_name.rsplit(".", 1)
            else:
                base_name, suffix = hf_name, ""
                # Handle '_weight' suffix (Gemma3 naming: parameter ends with
                # '_weight' instead of '.weight')
                if base_name.endswith("_weight"):
                    base_name = base_name[:-7]  # Remove '_weight'
                    suffix = "weight"

            gguf_name = None
            # Priority 1: Search vision/projector parameters for multimodal models
            if vision_name_map is not None:
                gguf_name = vision_name_map.get_name(base_name)

            # Priority 2: Search text backbone parameters
            if gguf_name is None:
                gguf_name = text_name_map.get_name(base_name)

            if gguf_name is None:
                return None

            # NOTE(review): when the name has no recognised suffix, `suffix`
            # is "" here and the returned name ends with a trailing ".".
            return gguf_name + "." + suffix

        # Build mapping and track unmapped parameters
        unmapped_params = []
        for hf_name in state_dict:
            gguf_name_with_suffix = find_hf_name_in_tensor_map(hf_name)

            # Track mapping success
            if gguf_name_with_suffix is not None:
                gguf_to_hf_name_map[gguf_name_with_suffix] = hf_name
                logger.debug("Mapped GGUF %s → HF %s", gguf_name_with_suffix, hf_name)
            elif hf_name not in gguf_to_hf_name_map.values():
                # Parameter not in manual overrides either
                unmapped_params.append(hf_name)

        # All parameters (except those initialized by other means) must be mapped:
        # both vision/projector and backbone
        if unmapped_params:
            # Drop every name that fully matches one of the sideload patterns.
            unmapped_params = list(
                filter(
                    lambda x: not any(re.fullmatch(p, x) for p in sideload_params),
                    unmapped_params,
                )
            )
        if unmapped_params:
            raise RuntimeError(
                f"Failed to map GGUF parameters "
                f"({len(unmapped_params)}): "
                f"{unmapped_params}"
            )
        return gguf_to_hf_name_map

    def _get_gguf_weight_type(
        self,
        model_config: ModelConfig,
        model_name_or_path: str,
        gguf_to_hf_name_map: dict[str, str],
    ) -> dict[str, str]:
        """Collect the weight-type map across all GGUF files of the model.

        The maps of every discovered shard are merged; for multimodal models
        the map of the companion mm_proj file is merged on top.

        Args:
            model_config: Model configuration; only ``hf_config`` is read.
            model_name_or_path: Path to the (first) GGUF file.
            gguf_to_hf_name_map: GGUF tensor name → HF parameter name map,
                forwarded to ``get_gguf_weight_type_map``.

        Returns:
            The merged weight-type map.
        """
        gguf_files = self._get_all_gguf_files(model_name_or_path)
        weight_type_map = {}
        for f in gguf_files:
            weight_type_map.update(get_gguf_weight_type_map(f, gguf_to_hf_name_map))
        # Use the same multimodal test as _get_gguf_weights_map: a config that
        # carries `vision_config = None` is text-only and has no mm_proj file.
        is_multimodal = (
            getattr(model_config.hf_config, "vision_config", None) is not None
        )
        if is_multimodal:
            mmproj_file = detect_gguf_multimodal(model_name_or_path)
            assert mmproj_file is not None, (
                "Could not find mm_proj file for multimodal GGUF model"
            )
            logger.info("Loading extra mm_proj weights from %s...", mmproj_file)
            mm_proj_weight_type_map = get_gguf_weight_type_map(
                mmproj_file, gguf_to_hf_name_map
            )
            weight_type_map.update(mm_proj_weight_type_map)
        return weight_type_map

    def _get_weights_iterator(
        self,
        model_config: ModelConfig,
        model_name_or_path: str,
        gguf_to_hf_name_map: dict[str, str],
    ) -> Generator[tuple[str, torch.Tensor], None, None]:
        """
        Iterate over GGUF model weights, loading from both main model file and
        mmproj.gguf for multimodal Gemma3 models.

        For Gemma3 multimodal GGUF models:
        - Main file (gemma-3-*.gguf): Language model weights (model.*)
        - mmproj file (mmproj*.gguf): Vision tower + projector weights (v.*, mm.*)

        The mm_proj weights, when present, are yielded before the main
        model weights. Sharded models are read through the multi-file
        iterator, single files through the single-file iterator.

        Args:
            model_config: Model configuration; only ``hf_config`` is read.
            model_name_or_path: Path to the (first) GGUF file.
            gguf_to_hf_name_map: GGUF tensor name → HF parameter name map,
                forwarded to the weight iterators.

        Yields:
            Tuples of (parameter_name, tensor) for all model weights
        """
        hf_config = model_config.hf_config
        # Use the same multimodal test as _get_gguf_weights_map: a config that
        # carries `vision_config = None` is text-only and has no mm_proj file.
        is_multimodal = getattr(hf_config, "vision_config", None) is not None

        if is_multimodal:
            # Load mm_proj (mm_encoder + projector) for multimodal weights
            mmproj_file = detect_gguf_multimodal(model_name_or_path)
            assert mmproj_file is not None, (
                "Could not find mm_proj file for multimodal GGUF model"
            )
            yield from gguf_quant_weights_iterator(mmproj_file, gguf_to_hf_name_map)

        gguf_files = self._get_all_gguf_files(model_name_or_path)
        if len(gguf_files) > 1:
            yield from gguf_quant_weights_iterator_multi(
                gguf_files, gguf_to_hf_name_map
            )
        else:
            yield from gguf_quant_weights_iterator(
                model_name_or_path, gguf_to_hf_name_map
            )

    def download_model(self, model_config: ModelConfig) -> None:
        self._prepare_weights(model_config)

    def load_weights(self, model: nn.Module, model_config: ModelConfig) -> None:
        """Load GGUF weights into an already-initialized ``model``.

        Resolves the GGUF file, builds the GGUF → HF name map, and hands the
        resulting weights iterator to ``model.load_weights``.
        """
        gguf_path = self._prepare_weights(model_config)
        name_map = self._get_gguf_weights_map(model_config)
        weights = self._get_weights_iterator(model_config, gguf_path, name_map)
        model.load_weights(weights)

    def load_model(
        self, vllm_config: VllmConfig, model_config: ModelConfig, prefix: str = ""
    ) -> nn.Module:
        """Initialize a model and load its GGUF weights.

        Before the model is built, the GGUF files are inspected to
        (a) set ``tie_word_embeddings`` on the HF config when
        ``lm_head.weight`` is reported among the extra tensor names, and
        (b) register F32/F16/BF16 ``.weight`` tensors as unquantized modules
        on the quantization config.

        Args:
            vllm_config: Engine configuration (device and quant config are
                read; the quant config is updated in place).
            model_config: Model configuration; its HF config may be updated.
            prefix: Prefix forwarded to ``initialize_model``.

        Returns:
            The initialized model with weights loaded and post-processed.
        """
        device_config = vllm_config.device_config
        local_model_path = self._prepare_weights(model_config)
        gguf_weights_map = self._get_gguf_weights_map(model_config)

        # we can only know if tie word embeddings after mapping weights
        extra_tensor_names = [
            tensor_name
            for shard_file in self._get_all_gguf_files(local_model_path)
            for tensor_name in get_gguf_extra_tensor_names(
                shard_file, gguf_weights_map
            )
        ]
        if "lm_head.weight" in extra_tensor_names:
            model_config.hf_config.update({"tie_word_embeddings": True})

        weight_type_map = self._get_gguf_weight_type(
            model_config, local_model_path, gguf_weights_map
        )

        # filter out unquantized modules to skip
        unquantized_types = ("F32", "F16", "BF16")
        unquant_names = []
        for name, weight_type in weight_type_map.items():
            if weight_type in unquantized_types and name.endswith(".weight"):
                unquant_names.append(name.removesuffix(".weight"))
        logger.debug("GGUF unquantized modules: %s", unquant_names)

        # Type-checker-only narrowing; has no effect at runtime.
        if TYPE_CHECKING:
            vllm_config.quant_config = cast(GGUFConfig, vllm_config.quant_config)
        vllm_config.quant_config.unquantized_modules.extend(unquant_names)

        target_device = torch.device(device_config.device)
        with set_default_torch_dtype(model_config.dtype):
            with target_device:
                model = initialize_model(vllm_config=vllm_config, prefix=prefix)
            self.load_weights(model, model_config)

            process_weights_after_loading(model, model_config, target_device)
        return model

_get_all_gguf_files staticmethod

_get_all_gguf_files(model_path: str) -> list[str]

Discover all GGUF shard files from a single shard path.

Supports variable-width shard indices by dynamically detecting the padding from the original filename. E.g. *-00001-of-00005.gguf → all 5 shards, *-01-of-15.gguf → all 15 shards.

Source code in vllm/model_executor/model_loader/gguf_loader.py
@staticmethod
def _get_all_gguf_files(model_path: str) -> list[str]:
    """Discover all GGUF shard files from a single shard path.

    Supports variable-width shard indices by dynamically detecting
    the padding from the original filename.
    E.g. ``*-00001-of-00005.gguf`` → all 5 shards,
         ``*-01-of-15.gguf`` → all 15 shards.
    """
    match = re.search(r"-(\d+)-of-(\d+)\.gguf$", model_path)
    if not match:
        return [model_path]
    total = int(match.group(2))
    num_digits = len(match.group(1))
    prefix = model_path[: match.start(1)]
    suffix = model_path[match.end(2) :]
    files = []
    for i in range(1, total + 1):
        shard_path = f"{prefix}{i:0{num_digits}d}-of-{total:0{num_digits}d}{suffix}"
        if os.path.isfile(shard_path):
            files.append(shard_path)
    if files:
        logger.info("Discovered %d GGUF shard files", len(files))
    return files if files else [model_path]

_get_gguf_weights_map

_get_gguf_weights_map(model_config: ModelConfig)

GGUF uses the following naming convention for tensors converted from an HF checkpoint: blk.N.BB.weight and blk.N.BB.bias, where N is the block number of a layer and BB identifies the attention/MLP component of that layer. See "Standardized tensor names" in https://github.com/ggerganov/ggml/blob/master/docs/gguf.md for details.

Source code in vllm/model_executor/model_loader/gguf_loader.py
def _get_gguf_weights_map(self, model_config: ModelConfig):
    """
    Build the mapping from GGUF tensor names to HuggingFace parameter names.

    GGUF uses this naming convention for their tensors from HF checkpoint:
    `blk.N.BB.weight` and `blk.N.BB.bias`
    where N signifies the block number of a layer, and BB signifies the
    attention/mlp layer components.
    See "Standardized tensor names" in
    https://github.com/ggerganov/ggml/blob/master/docs/gguf.md for details.

    The HF parameter names are obtained by instantiating the HF model class
    from ``model_config.hf_config`` on the ``meta`` device and reading its
    ``state_dict()`` keys; each key is then translated through gguf-py's
    tensor name maps. A few MoE architectures get manual entries for their
    merged expert tensors before the automatic pass runs.

    Args:
        model_config: Supplies ``hf_config`` and ``trust_remote_code``.

    Returns:
        Dict whose keys are GGUF tensor names (with suffix) and whose values
        are HF parameter names.

    Raises:
        RuntimeError: If the (possibly renamed) model type is not present in
            ``gguf.MODEL_ARCH_NAMES``, or if any HF parameter is left
            unmapped and is not covered by a sideload pattern.
    """
    config = model_config.hf_config
    # Get text config to handle both nested (multimodal) and flat
    # (text-only) config structures. For multimodal models like
    # Gemma3Config, this returns config.text_config. For text-only
    # models, this returns config itself.
    text_config = config.get_text_config()
    model_type = config.model_type
    is_multimodal = (
        hasattr(config, "vision_config") and config.vision_config is not None
    )
    gguf_to_hf_name_map = {}
    # Patterns of HF parameter names that are allowed to stay unmapped;
    # they are filtered out of the "unmapped" error check at the end.
    sideload_params: list[re.Pattern] = []
    # hack: ggufs have a different name than transformers
    if model_type == "cohere":
        model_type = "command-r"
    if model_type == "gemma3_text":
        # Gemma3 models use "gemma3_text" in HuggingFace but
        # "gemma3" in GGUF architecture naming
        model_type = "gemma3"
    if model_type in ("deepseek_v3", "deepseek_v2"):
        model_type = "deepseek2"
        # GGUF layer map assumes that we will have a merged expert weights
        # so we need to map them manually
        for idx in range(config.num_hidden_layers):
            gguf_to_hf_name_map[f"blk.{idx}.exp_probs_b.bias"] = (
                f"model.layers.{idx}.mlp.gate.e_score_correction_bias"
            )
            # The merged expert tensors are mapped onto the expert-0 names.
            gguf_to_hf_name_map[f"blk.{idx}.ffn_down_exps.weight"] = (
                f"model.layers.{idx}.mlp.experts.0.down_proj.weight"
            )
            gguf_to_hf_name_map[f"blk.{idx}.ffn_gate_exps.weight"] = (
                f"model.layers.{idx}.mlp.experts.0.gate_proj.weight"
            )
            gguf_to_hf_name_map[f"blk.{idx}.ffn_up_exps.weight"] = (
                f"model.layers.{idx}.mlp.experts.0.up_proj.weight"
            )
            # Every per-expert projection weight of this layer is exempt
            # from the unmapped-parameter check.
            sideload_params.append(
                re.compile(
                    f"model\\.layers\\.{idx}"
                    r"\.mlp\.experts\.[0-9]+\.(gate|up|down)_proj\.weight"
                )
            )
    if model_type in ("qwen2_moe", "qwen3_moe"):
        # "qwen2_moe" -> "qwen2moe", "qwen3_moe" -> "qwen3moe"
        model_type = model_type.replace("_", "")
        # GGUF layer map assumes that we will have a merged expert weights
        # so we need to map them manually
        for idx in range(config.num_hidden_layers):
            gguf_to_hf_name_map[f"blk.{idx}.ffn_down_exps.weight"] = (
                f"model.layers.{idx}.mlp.experts.0.down_proj.weight"
            )
            gguf_to_hf_name_map[f"blk.{idx}.ffn_gate_exps.weight"] = (
                f"model.layers.{idx}.mlp.experts.0.gate_proj.weight"
            )
            gguf_to_hf_name_map[f"blk.{idx}.ffn_up_exps.weight"] = (
                f"model.layers.{idx}.mlp.experts.0.up_proj.weight"
            )
            sideload_params.append(
                re.compile(
                    f"model\\.layers\\.{idx}"
                    r"\.mlp\.experts\.[0-9]+\.(gate|up|down)_proj\.weight"
                )
            )
    if model_type == "minimax_m2":
        model_type = "minimax-m2"
        # GGUF layer map assumes merged expert weights
        # map them manually like deepseek2
        for idx in range(config.num_hidden_layers):
            gguf_to_hf_name_map[f"blk.{idx}.exp_probs_b.bias"] = (
                f"model.layers.{idx}.block_sparse_moe.e_score_correction_bias"
            )
            gguf_to_hf_name_map[f"blk.{idx}.ffn_down_exps.weight"] = (
                f"model.layers.{idx}.block_sparse_moe.experts.0.w2.weight"
            )
            gguf_to_hf_name_map[f"blk.{idx}.ffn_gate_exps.weight"] = (
                f"model.layers.{idx}.block_sparse_moe.experts.0.w1.weight"
            )
            gguf_to_hf_name_map[f"blk.{idx}.ffn_up_exps.weight"] = (
                f"model.layers.{idx}.block_sparse_moe.experts.0.w3.weight"
            )
            sideload_params.append(
                re.compile(
                    f"model\\.layers\\.{idx}"
                    r"\.block_sparse_moe\.experts\.(gate_up_proj|down_proj)"
                )
            )

    # Reverse lookup: find the gguf architecture key whose registered name
    # equals the (possibly renamed) model_type.
    arch = None
    for key, value in gguf.MODEL_ARCH_NAMES.items():
        if value == model_type:
            arch = key
            break
    if arch is None:
        raise RuntimeError(f"Unknown gguf model_type: {model_type}")
    text_num_layers = text_config.num_hidden_layers
    text_name_map = gguf.get_tensor_name_map(arch, text_num_layers)

    if is_multimodal:
        mm_proj_arch = gguf.MODEL_ARCH.MMPROJ
        vision_num_layers = config.vision_config.num_hidden_layers
        vision_name_map = gguf.get_tensor_name_map(mm_proj_arch, vision_num_layers)
    else:
        vision_name_map = None

    # Create dummy model to extract parameter names
    # For multimodal: use AutoModelForImageTextToText to get
    # language + vision + projector params
    # For text-only: use AutoModelForCausalLM to get language model params
    auto_cls = (
        AutoModelForImageTextToText if is_multimodal else AutoModelForCausalLM
    )
    # Built under the "meta" device; only the state_dict keys are used below.
    with torch.device("meta"):
        dummy_model = auto_cls.from_config(
            config, trust_remote_code=model_config.trust_remote_code
        )

    state_dict = dummy_model.state_dict()
    if hf_checkpoint_map := getattr(
        dummy_model, "_checkpoint_conversion_mapping", None
    ):

        def revert_hf_rename(name: str) -> str:
            # Undo the model's checkpoint renaming: wherever a renamed
            # fragment occurs in the key, substitute the original fragment
            # and drop any leading "^" left over from the mapping key.
            for original_name, hf_name in hf_checkpoint_map.items():
                if hf_name in name:
                    name = name.replace(hf_name, original_name).lstrip("^")
            return name

        state_dict = {
            revert_hf_rename(name): tensor for name, tensor in state_dict.items()
        }

    if model_type == "minimax-m2" and not hf_checkpoint_map:
        # Reverse HF convention: mlp -> block_sparse_moe
        state_dict = {
            name.replace(".mlp.", ".block_sparse_moe."): tensor
            for name, tensor in state_dict.items()
        }

    def find_hf_name_in_tensor_map(hf_name: str) -> str | None:
        """
        Map HuggingFace parameter name to GGUF tensor name.

        This function handles the mismatch between HF parameter naming
        conventions and gguf-py's expected format:
        1. Strips 'model.' prefix (common in multimodal models)
        2. Converts '_weight' suffix to '.weight' (Gemma3 compatibility)
        3. Searches vision_name_map for multimodal parameters
        4. Falls back to text_name_map for language model parameters

        Args:
            hf_name: Full HuggingFace parameter name (e.g.,
                    'model.multi_modal_projector.mm_soft_emb_norm.weight')

        Returns:
            GGUF tensor name with suffix (e.g., 'mm.soft_emb_norm.weight')
            or None if no mapping found
        """
        # In transformers v5, multimodal models (e.g. Gemma3) wrap
        # all sub-models under an outer 'model.' attribute, producing
        # state_dict keys like 'model.language_model.layers.0...' and
        # 'model.vision_tower.vision_model...'.  Strip this outer
        # prefix so the keys match what gguf-py expects.
        if is_multimodal and hf_name.startswith("model."):
            hf_name = hf_name[6:]  # Remove outer 'model.'

        # Strip 'language_model.' prefix for multimodal models - gguf-py
        # tensor mappings expect parameter names without this prefix.
        # Note: 'model.' prefix should be KEPT for text-only models as
        # gguf-py expects it.
        if hf_name.startswith("language_model."):
            hf_name = hf_name[15:]  # Remove 'language_model.'
            # Re-add 'model.' prefix because gguf-py text tensor maps
            # expect 'model.layers...' format.
            if is_multimodal:
                hf_name = "model." + hf_name

        # Parse parameter name and suffix
        if hf_name.endswith((".weight", ".bias")):
            base_name, suffix = hf_name.rsplit(".", 1)
        else:
            base_name, suffix = hf_name, ""
            # Handle '_weight' suffix (Gemma3 naming: parameter ends with
            # '_weight' instead of '.weight')
            if base_name.endswith("_weight"):
                base_name = base_name[:-7]  # Remove '_weight'
                suffix = "weight"

        gguf_name = None
        # Priority 1: Search vision/projector parameters for multimodal models
        if vision_name_map is not None:
            gguf_name = vision_name_map.get_name(base_name)

        # Priority 2: Search text backbone parameters
        if gguf_name is None:
            gguf_name = text_name_map.get_name(base_name)

        if gguf_name is None:
            return None

        # NOTE(review): when suffix is "" this yields a name ending in "." —
        # confirm whether suffix-less parameters can reach this point.
        return gguf_name + "." + suffix

    # Build mapping and track unmapped parameters
    unmapped_params = []
    for hf_name in state_dict:
        gguf_name_with_suffix = find_hf_name_in_tensor_map(hf_name)

        # Track mapping success
        if gguf_name_with_suffix is not None:
            # May overwrite a manual entry registered above for the same key.
            gguf_to_hf_name_map[gguf_name_with_suffix] = hf_name
            logger.debug("Mapped GGUF %s → HF %s", gguf_name_with_suffix, hf_name)
        elif hf_name not in gguf_to_hf_name_map.values():
            # Parameter not in manual overrides either
            unmapped_params.append(hf_name)

    # All parameters (except those initialized by other means) must be mapped:
    # both vision/projector and backbone
    if unmapped_params:
        # Drop names that fully match one of the sideload patterns.
        unmapped_params = list(
            filter(
                lambda x: not any(re.fullmatch(p, x) for p in sideload_params),
                unmapped_params,
            )
        )
    if unmapped_params:
        raise RuntimeError(
            f"Failed to map GGUF parameters "
            f"({len(unmapped_params)}): "
            f"{unmapped_params}"
        )
    return gguf_to_hf_name_map

_get_weights_iterator

_get_weights_iterator(
    model_config: ModelConfig,
    model_name_or_path: str,
    gguf_to_hf_name_map: dict[str, str],
) -> Generator[tuple[str, Tensor], None, None]

Iterate over GGUF model weights, loading from both main model file and mmproj.gguf for multimodal Gemma3 models.

For Gemma3 multimodal GGUF models: - Main file (gemma-3-*.gguf): Language model weights (model.*) - mmproj file (mmproj*.gguf): Vision tower + projector weights (v.*, mm.*)

Yields:

Type Description
tuple[str, Tensor]

Tuples of (parameter_name, tensor) for all model weights

Source code in vllm/model_executor/model_loader/gguf_loader.py
def _get_weights_iterator(
    self,
    model_config: ModelConfig,
    model_name_or_path: str,
    gguf_to_hf_name_map: dict[str, str],
) -> Generator[tuple[str, torch.Tensor], None, None]:
    """
    Iterate over GGUF model weights, loading from both main model file and
    mmproj.gguf for multimodal Gemma3 models.

    For Gemma3 multimodal GGUF models:
    - Main file (gemma-3-*.gguf): Language model weights (model.*)
    - mmproj file (mmproj*.gguf): Vision tower + projector weights (v.*, mm.*)

    The mmproj weights, when present, are yielded first; the main model
    weights follow. If the main path names one shard of a split model, all
    shards discovered by ``_get_all_gguf_files`` are read.

    Args:
        model_config: Supplies ``hf_config``, used to decide whether the
            model is multimodal.
        model_name_or_path: Path of the main GGUF file (or of one shard).
        gguf_to_hf_name_map: GGUF tensor name -> HF parameter name mapping,
            passed through to the underlying weight iterators.

    Yields:
        Tuples of (parameter_name, tensor) for all model weights

    Raises:
        AssertionError: If the model is multimodal but no mmproj file could
            be detected next to ``model_name_or_path``.
    """
    hf_config = model_config.hf_config
    # A config may carry `vision_config = None`; treat that as text-only so
    # this check agrees with the one in `_get_gguf_weights_map`, which only
    # builds vision/projector mappings when vision_config is not None.
    is_multimodal = getattr(hf_config, "vision_config", None) is not None

    if is_multimodal:
        # Load mm_proj (mm_encoder + projector) for multimodal weights
        mmproj_file = detect_gguf_multimodal(model_name_or_path)
        assert mmproj_file is not None, (
            "Could not find mm_proj file for multimodal GGUF model"
        )
        yield from gguf_quant_weights_iterator(mmproj_file, gguf_to_hf_name_map)

    gguf_files = self._get_all_gguf_files(model_name_or_path)
    if len(gguf_files) > 1:
        # Split checkpoint: read every discovered shard.
        yield from gguf_quant_weights_iterator_multi(
            gguf_files, gguf_to_hf_name_map
        )
    else:
        yield from gguf_quant_weights_iterator(
            model_name_or_path, gguf_to_hf_name_map
        )

ModelExpressModelLoader

Bases: BaseModelLoader

Thin vLLM loader wrapper for ModelExpress.

Source code in vllm/model_executor/model_loader/modelexpress_loader.py
class ModelExpressModelLoader(BaseModelLoader):
    """Delegating vLLM loader that forwards every call to ModelExpress."""

    def __init__(self, load_config: LoadConfig):
        super().__init__(load_config)
        # The loader object that does the real work; imported lazily.
        self._loader = self._load_modelexpress_loader(load_config)

    @staticmethod
    def _load_modelexpress_loader(load_config: LoadConfig) -> BaseModelLoader:
        """Import the ModelExpress loader module and instantiate its loader.

        A ``ModuleNotFoundError`` whose missing module is one of
        ``_MISSING_MODELEXPRESS_MODULES`` is replaced by the error built by
        ``_missing_modelexpress_error()``; any other missing module is
        re-raised untouched.
        """
        try:
            mx_module = importlib.import_module(_MODELEXPRESS_LOADER_MODULE)
        except ModuleNotFoundError as import_err:
            if import_err.name in _MISSING_MODELEXPRESS_MODULES:
                raise _missing_modelexpress_error() from import_err
            raise

        return mx_module.MxModelLoader(load_config)

    def download_model(self, model_config: ModelConfig) -> None:
        """Forward the download request to the wrapped loader."""
        self._loader.download_model(model_config)

    def load_weights(self, model: nn.Module, model_config: ModelConfig) -> None:
        """Forward in-place weight loading to the wrapped loader."""
        self._loader.load_weights(model, model_config)

    @instrument(span_name="Load model")
    def load_model(
        self,
        vllm_config: VllmConfig,
        model_config: ModelConfig,
        prefix: str = "",
    ) -> nn.Module:
        """Load the model through the wrapped loader and return it in eval mode."""
        return self._loader.load_model(
            vllm_config=vllm_config,
            model_config=model_config,
            prefix=prefix,
        ).eval()

RunaiModelStreamerLoader

Bases: BaseModelLoader

Model loader that can load safetensors files from local FS, S3, GCS, or Azure Blob Storage.

Source code in vllm/model_executor/model_loader/runai_streamer_loader.py
class RunaiModelStreamerLoader(BaseModelLoader):
    """
    Model loader that can load safetensors
    files from local FS, S3, GCS, or Azure Blob Storage.
    """

    def __init__(self, load_config: LoadConfig):
        """Read streamer options from ``model_loader_extra_config``.

        Recognized keys: ``distributed`` (bool), ``concurrency`` (int) and
        ``memory_limit`` (int). The two integer options are exported as
        ``RUNAI_STREAMER_*`` environment variables. When extra config is
        given and ``RUNAI_STREAMER_S3_ENDPOINT`` is unset, it is seeded
        from ``AWS_ENDPOINT_URL`` if that variable exists.
        """
        super().__init__(load_config)

        self._is_distributed: bool = False
        extra_config = load_config.model_loader_extra_config
        if not extra_config:
            return

        distributed = extra_config.get("distributed")
        if isinstance(distributed, bool):
            self._is_distributed = distributed

        # Integer-valued options are handed to the streamer via env vars.
        env_backed_options = (
            ("concurrency", "RUNAI_STREAMER_CONCURRENCY"),
            ("memory_limit", "RUNAI_STREAMER_MEMORY_LIMIT"),
        )
        for option, env_name in env_backed_options:
            value = extra_config.get(option)
            if isinstance(value, int):
                os.environ[env_name] = str(value)

        streamer_endpoint = os.getenv("RUNAI_STREAMER_S3_ENDPOINT")
        aws_endpoint = os.getenv("AWS_ENDPOINT_URL")
        if streamer_endpoint is None and aws_endpoint is not None:
            os.environ["RUNAI_STREAMER_S3_ENDPOINT"] = aws_endpoint

    def _prepare_weights(
        self, model_name_or_path: str, revision: str | None
    ) -> list[str]:
        """Resolve the safetensors files for a model.

        Local directories and object-storage URIs are used as-is; anything
        else is downloaded from the HF hub first.

        Raises:
            RuntimeError: If no safetensors file is found.
        """
        on_object_storage = is_runai_obj_uri(model_name_or_path)
        on_local_disk = os.path.isdir(model_name_or_path)
        needs_hub_download = not (on_local_disk or on_object_storage)

        if needs_hub_download:
            weights_dir = download_weights_from_hf(
                model_name_or_path,
                self.load_config.download_dir,
                ["*.safetensors"],
                revision,
                ignore_patterns=self.load_config.ignore_patterns,
            )
        else:
            weights_dir = model_name_or_path
        weight_files = list_safetensors(path=weights_dir)

        # The index file is fetched before the emptiness check below.
        if needs_hub_download:
            download_safetensors_index_file_from_hf(
                model_name_or_path,
                SAFE_WEIGHTS_INDEX_NAME,
                self.load_config.download_dir,
                revision,
            )

        if weight_files:
            return weight_files
        raise RuntimeError(
            f"Cannot find any safetensors model weights with `{model_name_or_path}`"
        )

    def _get_weights_iterator(
        self, model_or_path: str, revision: str | None
    ) -> Generator[tuple[str, torch.Tensor], None, None]:
        """Return the streaming iterator over the resolved weight files."""
        weight_files = self._prepare_weights(model_or_path, revision)
        show_progress = self.load_config.use_tqdm_on_load
        return runai_safetensors_weights_iterator(
            weight_files, show_progress, self._is_distributed
        )

    def download_model(self, model_config: ModelConfig) -> None:
        """Make sure the model's weight files are available."""
        self._prepare_weights(model_config.model, model_config.revision)

    def load_weights(self, model: nn.Module, model_config: ModelConfig) -> None:
        """Stream weights into ``model``.

        ``model_config.model_weights`` takes precedence over
        ``model_config.model`` when it is truthy.
        """
        default_source = model_config.model
        override = model_config.model_weights
        weights_source = override if override else default_source
        weights = self._get_weights_iterator(weights_source, model_config.revision)
        model.load_weights(weights)

_get_weights_iterator

_get_weights_iterator(
    model_or_path: str, revision: str | None
) -> Generator[tuple[str, Tensor], None, None]

Get an iterator for the model weights based on the load format.

Source code in vllm/model_executor/model_loader/runai_streamer_loader.py
def _get_weights_iterator(
    self, model_or_path: str, revision: str | None
) -> Generator[tuple[str, torch.Tensor], None, None]:
    """Resolve the weight files and return the streaming iterator over them."""
    weight_files = self._prepare_weights(model_or_path, revision)
    show_progress = self.load_config.use_tqdm_on_load
    return runai_safetensors_weights_iterator(
        weight_files, show_progress, self._is_distributed
    )

_prepare_weights

_prepare_weights(
    model_name_or_path: str, revision: str | None
) -> list[str]

Prepare weights for the model.

If the model is not local, it will be downloaded.

Source code in vllm/model_executor/model_loader/runai_streamer_loader.py
def _prepare_weights(
    self, model_name_or_path: str, revision: str | None
) -> list[str]:
    """Resolve the safetensors files for a model.

    Local directories and object-storage URIs are used as-is; anything
    else is downloaded from the HF hub first.

    Raises:
        RuntimeError: If no safetensors file is found.
    """
    on_object_storage = is_runai_obj_uri(model_name_or_path)
    on_local_disk = os.path.isdir(model_name_or_path)
    needs_hub_download = not (on_local_disk or on_object_storage)

    if needs_hub_download:
        weights_dir = download_weights_from_hf(
            model_name_or_path,
            self.load_config.download_dir,
            ["*.safetensors"],
            revision,
            ignore_patterns=self.load_config.ignore_patterns,
        )
    else:
        weights_dir = model_name_or_path
    weight_files = list_safetensors(path=weights_dir)

    # The index file is fetched before the emptiness check below.
    if needs_hub_download:
        download_safetensors_index_file_from_hf(
            model_name_or_path,
            SAFE_WEIGHTS_INDEX_NAME,
            self.load_config.download_dir,
            revision,
        )

    if weight_files:
        return weight_files
    raise RuntimeError(
        f"Cannot find any safetensors model weights with `{model_name_or_path}`"
    )

download_model

download_model(model_config: ModelConfig) -> None

Download model if necessary

Source code in vllm/model_executor/model_loader/runai_streamer_loader.py
def download_model(self, model_config: ModelConfig) -> None:
    """Make sure the model's weight files are available."""
    source, revision = model_config.model, model_config.revision
    self._prepare_weights(source, revision)

load_weights

load_weights(
    model: Module, model_config: ModelConfig
) -> None

Load weights into a model.

Source code in vllm/model_executor/model_loader/runai_streamer_loader.py
def load_weights(self, model: nn.Module, model_config: ModelConfig) -> None:
    """Stream weights into ``model``.

    ``model_config.model_weights`` takes precedence over
    ``model_config.model`` when it is truthy.
    """
    default_source = model_config.model
    override = model_config.model_weights
    weights_source = override if override else default_source
    weights = self._get_weights_iterator(weights_source, model_config.revision)
    model.load_weights(weights)

ShardedStateLoader

Bases: BaseModelLoader

Model loader that directly loads each worker's model state dict, which enables a fast load path for large tensor-parallel models where each worker only needs to read its own shard rather than the entire checkpoint. See examples/features/sharded_state/save_sharded_state_offline.py for creating a sharded checkpoint.

Source code in vllm/model_executor/model_loader/sharded_state_loader.py
class ShardedStateLoader(BaseModelLoader):
    """
    Model loader that directly loads each worker's model state dict, which
    enables a fast load path for large tensor-parallel models where each worker
    only needs to read its own shard rather than the entire checkpoint. See
    `examples/features/sharded_state/save_sharded_state_offline.py` for creating
    a sharded checkpoint.
    """

    # File-name template; `{rank}` is the tensor-parallel rank and `{part}`
    # the running part index written by `save_model`.
    DEFAULT_PATTERN = "model-rank-{rank}-part-{part}.safetensors"

    def __init__(self, load_config: LoadConfig):
        """Read the optional ``pattern`` key from ``model_loader_extra_config``.

        Raises:
            ValueError: If the extra config contains any key besides
                ``pattern``.
        """
        super().__init__(load_config)

        # Work on a copy so popping "pattern" does not mutate the config.
        extra_config = (
            {}
            if load_config.model_loader_extra_config is None
            else copy(load_config.model_loader_extra_config)
        )
        self.pattern = extra_config.pop("pattern", self.DEFAULT_PATTERN)
        if extra_config:
            # Report only the leftover keys: "pattern" was consumed above
            # and is valid, so it must not be listed as unexpected.
            raise ValueError(
                f"Unexpected extra config keys for load format "
                f"{load_config.load_format}: "
                f"{extra_config.keys()}"
            )

    @staticmethod
    def _filter_subtensors(
        tensors: dict[str, torch.Tensor],
    ) -> dict[str, torch.Tensor]:
        """
        Filter out all tensors that share the same memory or a subset of the
        memory of another tensor.

        Tensors with zero elements are dropped. When two entries cover
        exactly the same memory, the one with the smaller key is kept.
        """
        # Group by (device, storage base pointer): only tensors backed by
        # the same storage can overlap.
        same_storage_groups: dict[Any, list[tuple[str, torch.Tensor]]] = (
            collections.defaultdict(list)
        )
        for key, tensor in tensors.items():
            if tensor.numel():
                ptr = tensor.untyped_storage().data_ptr()
                same_storage_groups[tensor.device, ptr].append((key, tensor))

        def get_end_ptr(tensor: torch.Tensor) -> int:
            # Address one past the last element of the flattened view.
            return tensor.view(-1)[-1].data_ptr() + tensor.element_size()

        result: dict[str, torch.Tensor] = {}
        for group in same_storage_groups.values():
            for k, t in group:
                a, b = t.data_ptr(), get_end_ptr(t)
                for k2, t2 in group:
                    # Only contiguous tensors are considered as "covering".
                    if not t2.is_contiguous():
                        continue
                    a2, b2 = t2.data_ptr(), get_end_ptr(t2)
                    # t2 does not span all of t's range.
                    if a < a2 or b2 < b:
                        continue
                    if a2 < a or b < b2 or not t.is_contiguous():
                        break  # t2 covers strictly more memory than t.
                    if k2 < k:
                        # Same tensors, keep the one with the smaller key.
                        break
                else:
                    # No other tensor shadows t.
                    result[k] = t
        return result

    def _prepare_weights(self, model_name_or_path: str, revision: str | None):
        """Return a location holding the weights, downloading from HF if needed.

        S3 paths and existing local directories are returned unchanged.
        """
        if is_s3(model_name_or_path) or os.path.isdir(model_name_or_path):
            return model_name_or_path
        else:
            allow_patterns = ["*.safetensors"]
            return download_weights_from_hf(
                model_name_or_path,
                self.load_config.download_dir,
                allow_patterns,
                revision,
                ignore_patterns=self.load_config.ignore_patterns,
            )

    def download_model(self, model_config: ModelConfig) -> None:
        """Make the checkpoint for ``model_config`` available."""
        self._prepare_weights(model_config.model, model_config.revision)

    def load_weights(self, model: nn.Module, model_config: ModelConfig) -> None:
        """Copy this rank's pre-sharded checkpoint tensors into ``model``.

        ``model_config.model_weights`` takes precedence over
        ``model_config.model`` when it is truthy.

        Raises:
            ValueError: If no checkpoint file matches this rank's pattern,
                or if some model tensors are absent from the checkpoint.
        """
        from vllm.distributed import get_tensor_model_parallel_rank

        model_weights = model_config.model
        if model_weights_override := model_config.model_weights:
            model_weights = model_weights_override
        local_model_path = model_weights

        rank = get_tensor_model_parallel_rank()
        # Match every part written for this rank.
        pattern = os.path.join(
            local_model_path,
            self.pattern.format(rank=rank, part="*"),
        )

        filepaths = []
        if is_s3(local_model_path):
            file_pattern = f"*{self.pattern.format(rank=rank, part='*')}"
            filepaths = s3_glob(path=local_model_path, allow_pattern=[file_pattern])
        else:
            filepaths = glob.glob(pattern)
        if not filepaths:
            # TODO: support un-sharded checkpoints too
            raise ValueError(
                f"Could not find checkpoint files '{pattern}', only "
                f"pre-sharded checkpoints are currently supported!"
            )
        state_dict = self._filter_subtensors(model.state_dict())
        counter_before_loading_weights = time.perf_counter()
        for key, tensor in self.iterate_over_files(filepaths):
            # If loading with LoRA enabled, additional padding may
            # be added to certain parameters. We only load into a
            # narrowed view of the parameter data.
            param_data = state_dict[key].data
            param_shape = state_dict[key].shape
            for dim, size in enumerate(tensor.shape):
                if size < param_shape[dim]:
                    param_data = param_data.narrow(dim, 0, size)
            if tensor.shape != param_shape:
                logger.warning(
                    "loading tensor of shape %s into parameter '%s' of shape %s",
                    tensor.shape,
                    key,
                    param_shape,
                )
            param_data.copy_(tensor)
            # Whatever remains in state_dict afterwards was never loaded.
            state_dict.pop(key)
        counter_after_loading_weights = time.perf_counter()
        logger.info_once(
            "Loading weights took %.2f seconds",
            counter_after_loading_weights - counter_before_loading_weights,
        )
        if state_dict:
            raise ValueError(f"Missing keys {tuple(state_dict)} in loaded state!")

    def iterate_over_files(
        self, paths
    ) -> Generator[tuple[str, torch.Tensor], None, None]:
        """Yield ``(key, tensor)`` pairs from every checkpoint file in ``paths``.

        Uses the Run:ai streamer iterator when the load format is
        ``runai_streamer_sharded``; otherwise reads each file with
        ``safetensors``.
        """
        if self.load_config.load_format == "runai_streamer_sharded":
            yield from runai_safetensors_weights_iterator(paths, True)
        else:
            from safetensors.torch import safe_open

            for path in paths:
                with safe_open(path, framework="pt") as f:
                    for key in f.keys():  # noqa: SIM118
                        tensor = f.get_tensor(key)
                        yield key, tensor

    @staticmethod
    def save_model(
        model: torch.nn.Module,
        path: str,
        pattern: str | None = None,
        max_size: int | None = None,
    ) -> None:
        """Write this rank's state dict as one or more safetensors parts.

        Args:
            model: Module whose (sub-tensor-filtered) state dict is saved.
            path: Directory the part files are written into.
            pattern: File-name template with ``{rank}`` and ``{part}``
                placeholders; defaults to ``DEFAULT_PATTERN``.
            max_size: Optional byte budget per part. A new part is started
                when adding the next tensor would exceed it. A single tensor
                larger than the budget still gets a part of its own.
        """
        from safetensors.torch import save_file

        from vllm.distributed import get_tensor_model_parallel_rank

        if pattern is None:
            pattern = ShardedStateLoader.DEFAULT_PATTERN
        rank = get_tensor_model_parallel_rank()
        part_idx = 0
        total_size = 0
        state_dict = ShardedStateLoader._filter_subtensors(model.state_dict())
        state_dict_part: dict[str, torch.Tensor] = {}
        for key, tensor in state_dict.items():
            param_size = tensor.nelement() * tensor.element_size()
            # Flush only a non-empty part: if the very first tensor (or the
            # first one after a flush) is itself larger than max_size there
            # is nothing to write yet, and flushing would create an empty
            # shard file.
            if (
                max_size is not None
                and state_dict_part
                and total_size + param_size > max_size
            ):
                filename = pattern.format(rank=rank, part=part_idx)
                save_file(
                    state_dict_part,
                    os.path.join(path, filename),
                )
                part_idx += 1
                total_size = 0
                state_dict_part = {}
            state_dict_part[key] = tensor
            total_size += param_size
        # Write whatever is left in the final, partially filled part.
        if len(state_dict_part) > 0:
            filename = pattern.format(rank=rank, part=part_idx)
            save_file(
                state_dict_part,
                os.path.join(path, filename),
            )

_filter_subtensors staticmethod

_filter_subtensors(
    tensors: dict[str, Tensor],
) -> dict[str, Tensor]

Filter out all tensors that share the same memory or a subset of the memory of another tensor.

Source code in vllm/model_executor/model_loader/sharded_state_loader.py
@staticmethod
def _filter_subtensors(
    tensors: dict[str, torch.Tensor],
) -> dict[str, torch.Tensor]:
    """
    Drop every tensor whose memory is fully covered by another tensor.

    Tensors are bucketed by (device, storage base pointer). Within a bucket
    a tensor is discarded when some contiguous tensor spans strictly more
    memory, or spans the same memory and has a smaller key. Tensors with
    zero elements are never returned.
    """
    buckets: dict[Any, list[tuple[str, torch.Tensor]]] = collections.defaultdict(
        list
    )
    for name, candidate in tensors.items():
        if not candidate.numel():
            continue
        storage_ptr = candidate.untyped_storage().data_ptr()
        buckets[candidate.device, storage_ptr].append((name, candidate))

    def byte_span(tensor: torch.Tensor) -> tuple[int, int]:
        # [start, stop) addresses; stop is one element past the last one.
        start = tensor.data_ptr()
        stop = tensor.view(-1)[-1].data_ptr() + tensor.element_size()
        return start, stop

    def is_shadowed(
        name: str, tensor: torch.Tensor, bucket: list[tuple[str, torch.Tensor]]
    ) -> bool:
        lo, hi = byte_span(tensor)
        for other_name, other in bucket:
            # Only contiguous tensors can act as the covering tensor.
            if not other.is_contiguous():
                continue
            other_lo, other_hi = byte_span(other)
            if not (other_lo <= lo and hi <= other_hi):
                continue
            covers_more = other_lo < lo or hi < other_hi
            if covers_more or not tensor.is_contiguous():
                return True
            # Identical span: the entry with the smaller key wins.
            if other_name < name:
                return True
        return False

    kept: dict[str, torch.Tensor] = {}
    for bucket in buckets.values():
        for name, tensor in bucket:
            if not is_shadowed(name, tensor, bucket):
                kept[name] = tensor
    return kept

TensorizerLoader

Bases: BaseModelLoader

Model loader using CoreWeave's tensorizer library.

Source code in vllm/model_executor/model_loader/tensorizer_loader.py
class TensorizerLoader(BaseModelLoader):
    """Model loader backed by CoreWeave's tensorizer library."""

    def __init__(self, load_config: LoadConfig):
        super().__init__(load_config)
        extra_config = load_config.model_loader_extra_config
        if not isinstance(extra_config, TensorizerConfig):
            # Raw extra config: validate it, then build the config object
            # from its "tensorizer_config" entry.
            validate_config(extra_config)
            extra_config = TensorizerConfig(**extra_config["tensorizer_config"])
        self.tensorizer_config = extra_config

    def _verify_config(
        self, model_config: ModelConfig, parallel_config: ParallelConfig
    ):
        """Check the tensorizer config against the model and parallel configs."""
        config = self.tensorizer_config
        config.verify_with_model_config(model_config)
        config.verify_with_parallel_config(parallel_config)

    def _get_weights_iterator(
        self,
    ) -> Generator[tuple[str, torch.Tensor], None, None]:
        """Return the tensorizer weights iterator for the current config."""
        return tensorizer_weights_iterator(
            self.tensorizer_config._construct_tensorizer_args()
        )

    def _load_model_serialized_cpu(
        self,
        vllm_config: VllmConfig,
        prefix: str = "",
    ) -> nn.Module:
        """Initialize the model and load its weights through tensorizer.

        Used when the model isn't vLLM-tensorized (see
        examples/features/tensorize_vllm_model.py). This is expected to be
        faster than default HuggingFace loading, but slower than loading a
        vLLM-tensorized model.

        Returns:
            The initialized model, switched to eval mode.
        """
        dtype = vllm_config.model_config.dtype
        device = vllm_config.device_config.device
        with set_default_torch_dtype(dtype):
            with torch.device(device):
                model = initialize_model(vllm_config=vllm_config, prefix=prefix)

            # Weight loading runs under the default-dtype context only,
            # after the device context has been left.
            weights = self._get_weights_iterator()
            model.load_weights(weights)
        return model.eval()

    def download_model(self, model_config: ModelConfig) -> None:
        """Verify the config against the model, then open and close the stream."""
        config = self.tensorizer_config
        config.verify_with_model_config(model_config)

        with config.open_stream():
            pass

    def _patch_tensorizer_config(self, model_config: ModelConfig) -> TensorizerConfig:
        """Return a shallow copy of the config filled in from `model_config`.

        The copy receives the model class, HF config and dtype;
        `self.tensorizer_config` itself is not reassigned.
        """
        architecture = get_model_architecture(model_config)
        patched = copy.copy(self.tensorizer_config)
        patched.model_class = architecture[0]
        patched.hf_config = model_config.hf_config
        patched.dtype = model_config.dtype
        return patched

    def load_weights(self, model: nn.Module, model_config: ModelConfig) -> None:
        """Load serialized model weights with tensorizer.

        A vLLM-tensorized model is deserialized with a patched config (see
        the examples/features/tensorize_vllm_model.py example script for
        serializing vLLM models); otherwise the weights iterator is handed
        to `model.load_weights`.
        """
        if not is_vllm_tensorized(self.tensorizer_config):
            model.load_weights(self._get_weights_iterator())
            return
        deserialize_tensorizer_model(
            model, self._patch_tensorizer_config(model_config)
        )

    def load_model(
        self, vllm_config: VllmConfig, model_config: ModelConfig, prefix: str = ""
    ) -> nn.Module:
        """Verify the configuration, then build the model and load its weights."""
        parallel_config = vllm_config.parallel_config
        self._verify_config(model_config, parallel_config)

        if parallel_config.tensor_parallel_size > 1:
            from vllm.distributed import get_tensor_model_parallel_rank

            # With tensor parallelism the URI is %-formatted with this
            # worker's tensor-parallel rank and stored back on the config.
            uri_template = self.tensorizer_config.tensorizer_uri
            assert uri_template is not None
            self.tensorizer_config.tensorizer_uri = (
                uri_template % get_tensor_model_parallel_rank()
            )

        if not is_vllm_tensorized(self.tensorizer_config):
            return self._load_model_serialized_cpu(
                vllm_config=vllm_config, prefix=prefix
            )

        tensorizer_config = self._patch_tensorizer_config(model_config)
        target_device = vllm_config.device_config.device
        with (
            set_default_torch_dtype(model_config.dtype),
            torch.device(target_device),
        ):
            model = init_tensorizer_model(
                tensorizer_config=tensorizer_config, vllm_config=vllm_config
            )
        self.load_weights(model, model_config)
        return model

    @staticmethod
    def save_model(
        model: torch.nn.Module,
        tensorizer_config: TensorizerConfig | dict,
        model_config: ModelConfig,
    ) -> None:
        """Serialize `model`; a dict config is first turned into a TensorizerConfig."""
        config = (
            TensorizerConfig(**tensorizer_config)
            if isinstance(tensorizer_config, dict)
            else tensorizer_config
        )
        serialize_vllm_model(
            model=model,
            tensorizer_config=config,
            model_config=model_config,
        )

_load_model_serialized_cpu

_load_model_serialized_cpu(
    vllm_config: VllmConfig, prefix: str = ""
) -> Module

Load a serialized model with tensorizer to the CPU.

This is only necessary when the model isn't vLLM-tensorized (see examples/features/tensorize_vllm_model.py). This should still be faster than default HuggingFace loading, but will be slower than loading a vLLM-tensorized model.

Source code in vllm/model_executor/model_loader/tensorizer_loader.py
def _load_model_serialized_cpu(
    self,
    vllm_config: VllmConfig,
    prefix: str = "",
) -> nn.Module:
    """Initialize the model and load its weights through tensorizer.

    Intended for models that aren't vLLM-tensorized (see
    examples/features/tensorize_vllm_model.py). This is expected to be
    faster than default HuggingFace loading, but slower than loading a
    vLLM-tensorized model.

    Returns:
        The initialized model, switched to eval mode.
    """
    dtype = vllm_config.model_config.dtype
    device = vllm_config.device_config.device
    with set_default_torch_dtype(dtype):
        with torch.device(device):
            model = initialize_model(vllm_config=vllm_config, prefix=prefix)

        # Weight loading runs under the default-dtype context only, after
        # the device context has been left.
        weights = self._get_weights_iterator()
        model.load_weights(weights)
    return model.eval()

load_weights

load_weights(
    model: Module, model_config: ModelConfig
) -> None

Load serialized model weights with tensorizer.

Expects a vLLM-tensorized model. See the examples/features/tensorize_vllm_model.py example script for serializing vLLM models.

Source code in vllm/model_executor/model_loader/tensorizer_loader.py
def load_weights(self, model: nn.Module, model_config: ModelConfig) -> None:
    """Load serialized model weights with tensorizer.

    A vLLM-tensorized model is deserialized with a patched config (see the
    examples/features/tensorize_vllm_model.py example script for
    serializing vLLM models); otherwise the weights iterator is handed to
    `model.load_weights`.
    """
    if not is_vllm_tensorized(self.tensorizer_config):
        model.load_weights(self._get_weights_iterator())
        return
    patched_config = self._patch_tensorizer_config(model_config)
    deserialize_tensorizer_model(model, patched_config)

get_model_loader

get_model_loader(
    load_config: LoadConfig,
) -> BaseModelLoader

Get a model loader based on the load format.

Source code in vllm/model_executor/model_loader/__init__.py
def get_model_loader(load_config: LoadConfig) -> BaseModelLoader:
    """Instantiate the loader class registered for `load_config.load_format`.

    Raises:
        ValueError: If no loader is registered for the load format.
    """
    load_format = load_config.load_format
    if load_format in _LOAD_FORMAT_TO_MODEL_LOADER:
        loader_cls = _LOAD_FORMAT_TO_MODEL_LOADER[load_format]
        return loader_cls(load_config)
    raise ValueError(f"Load format `{load_format}` is not supported")

register_model_loader

register_model_loader(load_format: str)

Register a customized vllm model loader.

When a load format is not supported by vllm, you can register a customized model loader to support it.

Parameters:

Name Type Description Default
load_format str

The model loader format name.

required

Examples:

>>> from vllm.config.load import LoadConfig
>>> from vllm.model_executor.model_loader import (
...     get_model_loader,
...     register_model_loader,
... )
>>> from vllm.model_executor.model_loader.base_loader import BaseModelLoader
>>>
>>> @register_model_loader("my_loader")
... class MyModelLoader(BaseModelLoader):
...     def download_model(self):
...         pass
...
...     def load_weights(self):
...         pass
>>>
>>> load_config = LoadConfig(load_format="my_loader")
>>> type(get_model_loader(load_config))
<class 'MyModelLoader'>
Source code in vllm/model_executor/model_loader/__init__.py
def register_model_loader(load_format: str):
    """Register a customized vllm model loader.

    When a load format is not supported by vllm, you can register a customized
    model loader to support it.

    Args:
        load_format (str): The model loader format name.

    Returns:
        A class decorator that stores the decorated class under
        `load_format` and returns that class unchanged.

    Raises:
        ValueError: Raised by the returned decorator when the decorated
            class is not a subclass of `BaseModelLoader`.

    Examples:
        >>> from vllm.config.load import LoadConfig
        >>> from vllm.model_executor.model_loader import (
        ...     get_model_loader,
        ...     register_model_loader,
        ... )
        >>> from vllm.model_executor.model_loader.base_loader import BaseModelLoader
        >>>
        >>> @register_model_loader("my_loader")
        ... class MyModelLoader(BaseModelLoader):
        ...     def download_model(self, model_config):
        ...         pass
        ...
        ...     def load_weights(self, model, model_config):
        ...         pass
        >>>
        >>> load_config = LoadConfig(load_format="my_loader")
        >>> type(get_model_loader(load_config))
        <class 'MyModelLoader'>
    """  # noqa: E501

    def _wrapper(model_loader_cls):
        # Validate before anything is logged: a rejected class must not
        # produce a "will be overwritten" warning for an overwrite that
        # never takes place.
        if not issubclass(model_loader_cls, BaseModelLoader):
            raise ValueError(
                "The model loader must be a subclass of `BaseModelLoader`."
            )
        if load_format in _LOAD_FORMAT_TO_MODEL_LOADER:
            logger.warning(
                "Load format `%s` is already registered, and will be "
                "overwritten by the new loader class `%s`.",
                load_format,
                model_loader_cls,
            )
        _LOAD_FORMAT_TO_MODEL_LOADER[load_format] = model_loader_cls
        logger.info(
            "Registered model loader `%s` with load format `%s`",
            model_loader_cls,
            load_format,
        )
        # Hand the class back so the decorated name stays bound to it.
        return model_loader_cls

    return _wrapper