Compare commits

..

2 Commits

Author SHA1 Message Date
tc-mb
2496f9c149 mtmd : support MiniCPM-V 4.6 (#22529)
* Support MiniCPM-V 4.6 in new branch

Signed-off-by: tc-mb <tianchi_cai@icloud.com>

* fix code bug

Signed-off-by: tc-mb <tianchi_cai@icloud.com>

* fix pre-commit

Signed-off-by: tc-mb <tianchi_cai@icloud.com>

* fix convert

Signed-off-by: tc-mb <tianchi_cai@icloud.com>

* rename clip_graph_minicpmv4_6

Signed-off-by: tc-mb <tianchi_cai@icloud.com>

* use new TYPE_MINICPMV4_6

Signed-off-by: tc-mb <tianchi_cai@icloud.com>

* use build_attn to allow flash attention support

Signed-off-by: tc-mb <tianchi_cai@icloud.com>

* no use legacy code, restored here.

Signed-off-by: tc-mb <tianchi_cai@icloud.com>

* use the existing tensors name

Signed-off-by: tc-mb <tianchi_cai@icloud.com>

* unused ctx->model.hparams.minicpmv_version

Signed-off-by: tc-mb <tianchi_cai@icloud.com>

* use n_merge for slice alignment

Signed-off-by: tc-mb <tianchi_cai@icloud.com>

* borrow wa_layer_indexes for vit_merger insertion point

Signed-off-by: tc-mb <tianchi_cai@icloud.com>

* fix code style

Signed-off-by: tc-mb <tianchi_cai@icloud.com>

* Update convert_hf_to_gguf.py

Co-authored-by: Sigbjørn Skjæret <sigbjorn.skjaeret@scala.com>

* use filter_tensors and add model.vision_tower

Signed-off-by: tc-mb <tianchi_cai@icloud.com>

* fix chkhsh

Signed-off-by: tc-mb <tianchi_cai@icloud.com>

* fix type check

Signed-off-by: tc-mb <tianchi_cai@icloud.com>

---------

Signed-off-by: tc-mb <tianchi_cai@icloud.com>
Co-authored-by: Sigbjørn Skjæret <sigbjorn.skjaeret@scala.com>
2026-05-06 21:54:09 +02:00
Gilad S.
5207d120ea model : don't crash on unsupported architecture (#22742)
* model: don't crash on unsupported architecture

* Update src/llama-model.cpp

Co-authored-by: Sigbjørn Skjæret <sigbjorn.skjaeret@scala.com>

---------

Co-authored-by: Sigbjørn Skjæret <sigbjorn.skjaeret@scala.com>
2026-05-06 18:51:21 +02:00
14 changed files with 702 additions and 4 deletions

View File

@@ -1360,6 +1360,9 @@ class TextModel(ModelBase):
if chkhsh == "d4540891389ea895b53b399da6ac824becc30f2fba0e9ddbb98f92e55ca0e97c":
# ref: https://huggingface.co/Qwen/Qwen3-Embedding-0.6B
res = "qwen2"
if chkhsh == "1444df51289cfa8063b96f0e62b1125440111bc79a52003ea14b6eac7016fd5f":
# ref: https://huggingface.co/openbmb/MiniCPM-V-4_6
res = "qwen35"
if chkhsh == "66b8d4e19ab16c3bfd89bce5d785fb7e0155e8648708a1f42077cb9fe002c273":
# ref: https://huggingface.co/alvarobartt/grok-2-tokenizer
res = "grok-2"
@@ -5499,16 +5502,101 @@ class _LinearAttentionVReorderBase(Qwen3NextModel):
yield from super().modify_tensors(data_torch, name, bid)
class _Qwen35MRopeMixin:
# Qwen3.5 always applies interleaved MRoPE (see Qwen3_5RotaryEmbedding in transformers);
# the upstream default mrope_section is [11, 11, 10] and llama.cpp's QWEN35 / QWEN35MOE
# loaders treat qwen35.rope.dimension_sections as required, so make sure it is always
# written even when a particular checkpoint omits the field in `rope_parameters`.
_QWEN35_DEFAULT_MROPE_SECTION = [11, 11, 10, 0]
gguf_writer: gguf.GGUFWriter
rope_parameters: dict
def set_gguf_parameters(self):
super().set_gguf_parameters() # ty: ignore[unresolved-attribute]
if "mrope_section" not in self.rope_parameters:
self.gguf_writer.add_rope_dimension_sections(self._QWEN35_DEFAULT_MROPE_SECTION)
@ModelBase.register("Qwen3_5ForConditionalGeneration", "Qwen3_5ForCausalLM")
class Qwen3_5TextModel(_LinearAttentionVReorderBase):
class Qwen3_5TextModel(_Qwen35MRopeMixin, _LinearAttentionVReorderBase):
model_arch = gguf.MODEL_ARCH.QWEN35
@ModelBase.register("Qwen3_5MoeForConditionalGeneration", "Qwen3_5MoeForCausalLM")
class Qwen3_5MoeTextModel(_LinearAttentionVReorderBase):
class Qwen3_5MoeTextModel(_Qwen35MRopeMixin, _LinearAttentionVReorderBase):
model_arch = gguf.MODEL_ARCH.QWEN35MOE
# MiniCPM-V 4.6: text tower is Qwen3.5 (linear+full hybrid attention) wrapped under
# `model.language_model.*`; vision tower is SigLIP + a window-attention ViT merger
# + a final DownsampleMLP merger. The same HF arch is registered twice below: once as
# the LM (text mode) and once as the mmproj (vision mode), mirroring the Qwen3-VL setup.
@ModelBase.register("MiniCPMV4_6ForConditionalGeneration")
class MiniCPMV4_6TextModel(Qwen3_5TextModel):
model_arch = gguf.MODEL_ARCH.QWEN35
@classmethod
def filter_tensors(cls, item: tuple[str, Callable[[], Tensor]]) -> tuple[str, Callable[[], Tensor]] | None:
name, gen = item
if name.startswith("model.merger."):
return None
# MTP tensors are not used at inference yet; align with Qwen3Next behaviour
if name.startswith("mtp"):
return None
return super().filter_tensors(item)
@ModelBase.register("MiniCPMV4_6ForConditionalGeneration")
class MiniCPMV4_6VisionModel(MmprojModel):
def __init__(self, *args, **kwargs):
super().__init__(*args, **kwargs)
if self.hparams_vision is not None:
# In MiniCPM-V 4.6 `vision_config.image_size` (980) describes the SigLIP
# positional embedding bucket grid (70 x 70), while the per-slice processing
# resolution is the preprocessor's `scale_resolution` (typically 448).
# The CLIP loader in tools/mtmd/clip.cpp consumes `clip.vision.image_size`
# as the slice size and warmup resolution, so report `scale_resolution` there
# to match the upstream MiniCPMV4_6ImageProcessorPil slicing rules.
scale_resolution = self.preprocessor_config.get("scale_resolution")
if scale_resolution is not None:
self.hparams_vision["image_size"] = int(scale_resolution)
def set_gguf_parameters(self):
super().set_gguf_parameters()
assert self.hparams_vision is not None
# projector type string is consumed by clip_projector_type_from_string() in clip.cpp
# (mapped to PROJECTOR_TYPE_MINICPMV4_6).
self.gguf_writer.add_clip_projector_type(gguf.VisionProjectorType.MINICPMV4_6)
# ViT merger 2x2 + final merger 2x2 = 4x spatial merge per dimension; used for slice alignment
self.gguf_writer.add_vision_projector_scale_factor(4)
# borrow wa_layer_indexes for vit_merger insertion point
insert_layer_id = int(self.global_config.get(
"insert_layer_id", self.hparams_vision.get("insert_layer_id", 6)))
self.gguf_writer.add_vision_wa_layer_indexes([insert_layer_id])
# SigLIP vision body uses gelu_pytorch_tanh, which matches ggml_gelu (tanh approx).
self.gguf_writer.add_vision_use_gelu(True)
self.gguf_writer.add_vision_attention_layernorm_eps(
self.hparams_vision.get("layer_norm_eps", 1e-6))
@classmethod
def filter_tensors(cls, item: tuple[str, Callable[[], Tensor]]) -> tuple[str, Callable[[], Tensor]] | None:
name, gen = item
# lm_head / MTP -> belong to the LM file
if name.startswith(("lm_head.", "mtp")):
return None
return super().filter_tensors(item)
@ModelBase.register("GPT2LMHeadModel")
class GPT2Model(TextModel):
model_arch = gguf.MODEL_ARCH.GPT2

View File

@@ -175,6 +175,7 @@ pre_computed_hashes = [
{"name": "falcon-h1", "tokt": TOKENIZER_TYPE.BPE, "repo": "https://huggingface.co/tiiuae/Falcon-H1-34B-Base", "chkhsh": "48f8e02c0359c0bbdd82f26909171fac1c18a457bb47573ed1fe3bbb2c1cfd4b"},
{"name": "kimi-k2", "tokt": TOKENIZER_TYPE.BPE, "repo": "https://huggingface.co/moonshotai/Kimi-K2-Base", "chkhsh": "81212dc7cdb7e0c1074ca62c5aeab0d43c9f52b8a737be7b12a777c953027890"},
{"name": "qwen2", "tokt": TOKENIZER_TYPE.BPE, "repo": "https://huggingface.co/Qwen/Qwen3-Embedding-0.6B", "chkhsh": "d4540891389ea895b53b399da6ac824becc30f2fba0e9ddbb98f92e55ca0e97c"},
{"name": "qwen35", "tokt": TOKENIZER_TYPE.BPE, "repo": "https://huggingface.co/openbmb/MiniCPM-V-4_6", "chkhsh": "1444df51289cfa8063b96f0e62b1125440111bc79a52003ea14b6eac7016fd5f"},
{"name": "grok-2", "tokt": TOKENIZER_TYPE.BPE, "repo": "https://huggingface.co/alvarobartt/grok-2-tokenizer", "chkhsh": "66b8d4e19ab16c3bfd89bce5d785fb7e0155e8648708a1f42077cb9fe002c273"},
# jina-v2-de variants
{"name": "jina-v2-de", "tokt": TOKENIZER_TYPE.BPE, "repo": "https://huggingface.co/aari1995/German_Semantic_V3", "chkhsh": "b3d1dd861f1d4c5c0d2569ce36baf3f90fe8a102db3de50dd71ff860d91be3df"},

View File

@@ -0,0 +1,49 @@
## MiniCPM-V 4.6
### Prepare models and code
Download [MiniCPM-V-4_6](https://huggingface.co/openbmb/MiniCPM-V-4_6) PyTorch model from huggingface to "MiniCPM-V-4_6" folder.
The model must be the standard `transformers` v5.7.0+ checkpoint (no `trust_remote_code`); the architecture in `config.json` is `MiniCPMV4_6ForConditionalGeneration` with a `qwen3_5_text` text model and a SigLIP-based vision tower plus a window-attention `vit_merger`.
### Build llama.cpp
If there are differences in usage, please refer to the official build [documentation](https://github.com/ggml-org/llama.cpp/blob/master/docs/build.md)
Clone llama.cpp:
```bash
git clone https://github.com/ggml-org/llama.cpp
cd llama.cpp
```
Build llama.cpp using `CMake`:
```bash
cmake -B build
cmake --build build --config Release
```
### Usage of MiniCPM-V 4.6
Unlike older MiniCPM-V variants, MiniCPM-V 4.6 is converted directly through `convert_hf_to_gguf.py`. The same script is invoked twice on the original Hugging Face directory: once to produce the language-model GGUF and once with `--mmproj` to produce the multimodal projector GGUF.
```bash
# language model
python ./convert_hf_to_gguf.py ../MiniCPM-V-4_6 --outfile ../MiniCPM-V-4_6/ggml-model-f16.gguf
# multimodal projector (vision tower + window-attention vit_merger + DownsampleMLP merger)
python ./convert_hf_to_gguf.py ../MiniCPM-V-4_6 --mmproj --outfile ../MiniCPM-V-4_6/mmproj-model-f16.gguf
# optional: quantize to Q4_K_M
./build/bin/llama-quantize ../MiniCPM-V-4_6/ggml-model-f16.gguf ../MiniCPM-V-4_6/ggml-model-Q4_K_M.gguf Q4_K_M
```
Inference on Linux or Mac
```bash
# run in single-turn mode
./build/bin/llama-mtmd-cli -m ../MiniCPM-V-4_6/ggml-model-f16.gguf --mmproj ../MiniCPM-V-4_6/mmproj-model-f16.gguf -c 4096 --temp 0.7 --top-p 0.8 --top-k 100 --repeat-penalty 1.05 --image xx.jpg -p "What is in the image?"
# run in conversation mode
./build/bin/llama-mtmd-cli -m ../MiniCPM-V-4_6/ggml-model-Q4_K_M.gguf --mmproj ../MiniCPM-V-4_6/mmproj-model-f16.gguf
```

View File

@@ -773,6 +773,14 @@ class MODEL_TENSOR(IntEnum):
V_DS_NORM = auto() # qwen3vl
V_DS_FC1 = auto() # qwen3vl
V_DS_FC2 = auto() # qwen3vl
V_MERGER_LN1 = auto() # minicpmv4_6
V_MERGER_ATTN_Q = auto() # minicpmv4_6
V_MERGER_ATTN_K = auto() # minicpmv4_6
V_MERGER_ATTN_V = auto() # minicpmv4_6
V_MERGER_ATTN_O = auto() # minicpmv4_6
V_MERGER_DS_LN = auto() # minicpmv4_6
V_MERGER_DS_UP = auto() # minicpmv4_6
V_MERGER_DS_DOWN = auto() # minicpmv4_6
V_MM_POST_FC_NORM = auto() # cogvlm
V_MM_UP = auto() # cogvlm
V_MM_DOWN = auto() # cogvlm
@@ -1277,6 +1285,14 @@ TENSOR_NAMES: dict[MODEL_TENSOR, str] = {
MODEL_TENSOR.V_DS_NORM: "v.deepstack.{bid}.norm",
MODEL_TENSOR.V_DS_FC1: "v.deepstack.{bid}.fc1",
MODEL_TENSOR.V_DS_FC2: "v.deepstack.{bid}.fc2",
MODEL_TENSOR.V_MERGER_LN1: "v.vit_merger.ln1",
MODEL_TENSOR.V_MERGER_ATTN_Q: "v.vit_merger.attn_q",
MODEL_TENSOR.V_MERGER_ATTN_K: "v.vit_merger.attn_k",
MODEL_TENSOR.V_MERGER_ATTN_V: "v.vit_merger.attn_v",
MODEL_TENSOR.V_MERGER_ATTN_O: "v.vit_merger.attn_out",
MODEL_TENSOR.V_MERGER_DS_LN: "v.vit_merger.ds_ln",
MODEL_TENSOR.V_MERGER_DS_UP: "v.vit_merger.ds_ffn_up",
MODEL_TENSOR.V_MERGER_DS_DOWN: "v.vit_merger.ds_ffn_down",
MODEL_TENSOR.V_MM_POST_FC_NORM: "mm.post_fc_norm", # cogvlm
MODEL_TENSOR.V_MM_UP: "mm.up",
MODEL_TENSOR.V_MM_DOWN: "mm.down",
@@ -1449,6 +1465,14 @@ MODEL_TENSORS: dict[MODEL_ARCH, list[MODEL_TENSOR]] = {
MODEL_TENSOR.V_DS_NORM,
MODEL_TENSOR.V_DS_FC1,
MODEL_TENSOR.V_DS_FC2,
MODEL_TENSOR.V_MERGER_LN1,
MODEL_TENSOR.V_MERGER_ATTN_Q,
MODEL_TENSOR.V_MERGER_ATTN_K,
MODEL_TENSOR.V_MERGER_ATTN_V,
MODEL_TENSOR.V_MERGER_ATTN_O,
MODEL_TENSOR.V_MERGER_DS_LN,
MODEL_TENSOR.V_MERGER_DS_UP,
MODEL_TENSOR.V_MERGER_DS_DOWN,
MODEL_TENSOR.V_MM_POST_FC_NORM,
MODEL_TENSOR.V_MM_UP,
MODEL_TENSOR.V_MM_DOWN,
@@ -4224,6 +4248,7 @@ class VisionProjectorType:
NEMOTRON_V2_VL = "nemotron_v2_vl"
HUNYUANOCR = "hunyuanocr"
HUNYUANVL = "hunyuanvl"
MINICPMV4_6 = "minicpmv4_6"
GRANITE_SPEECH = "granite_speech" # audio

View File

@@ -1399,6 +1399,7 @@ class TensorNameMap:
MODEL_TENSOR.V_ENC_EMBD_PATCH: (
"vision_tower.vision_model.embeddings.patch_embedding",
"model.vision_tower.embeddings.patch_embedding", # minicpmv4_6
"model.vision_tower.embeddings.patch_embeddings.projection", # Intern-S1
"vpm.embeddings.patch_embedding",
"model.vision_model.embeddings.patch_embedding", # SmolVLM
@@ -1424,6 +1425,7 @@ class TensorNameMap:
MODEL_TENSOR.V_ENC_EMBD_POS: (
"vision_tower.vision_model.embeddings.position_embedding",
"model.vision_tower.embeddings.position_embedding", # minicpmv4_6
"model.vision_tower.embeddings.position_embeddings", # Intern-S1
"vpm.embeddings.position_embedding",
"model.vision_model.embeddings.position_embedding", # SmolVLM
@@ -1460,6 +1462,7 @@ class TensorNameMap:
MODEL_TENSOR.V_ENC_ATTN_Q: (
"vision_tower.vision_model.encoder.layers.{bid}.self_attn.q_proj",
"model.vision_tower.encoder.layers.{bid}.self_attn.q_proj", # minicpmv4_6
"model.vision_tower.encoder.layer.{bid}.attention.q_proj", # Intern-S1
"vpm.encoder.layers.{bid}.self_attn.q_proj",
"model.vision_model.encoder.layers.{bid}.self_attn.q_proj", # SmolVLM
@@ -1483,6 +1486,7 @@ class TensorNameMap:
MODEL_TENSOR.V_ENC_ATTN_K: (
"vision_tower.vision_model.encoder.layers.{bid}.self_attn.k_proj",
"model.vision_tower.encoder.layers.{bid}.self_attn.k_proj", # minicpmv4_6
"model.vision_tower.encoder.layer.{bid}.attention.k_proj", # Intern-S1
"vpm.encoder.layers.{bid}.self_attn.k_proj",
"model.vision_model.encoder.layers.{bid}.self_attn.k_proj", # SmolVLM
@@ -1506,6 +1510,7 @@ class TensorNameMap:
MODEL_TENSOR.V_ENC_ATTN_V: (
"vision_tower.vision_model.encoder.layers.{bid}.self_attn.v_proj",
"model.vision_tower.encoder.layers.{bid}.self_attn.v_proj", # minicpmv4_6
"model.vision_tower.encoder.layer.{bid}.attention.v_proj", # Intern-S1
"vpm.encoder.layers.{bid}.self_attn.v_proj",
"model.vision_model.encoder.layers.{bid}.self_attn.v_proj", # SmolVLM
@@ -1522,6 +1527,7 @@ class TensorNameMap:
MODEL_TENSOR.V_ENC_INPUT_NORM: (
"vision_tower.vision_model.encoder.layers.{bid}.layer_norm1",
"model.vision_tower.encoder.layers.{bid}.layer_norm1", # minicpmv4_6
"vision_tower.vision_model.encoder.layers.{bid}.norm1", # InternVL
"model.vision_tower.encoder.layer.{bid}.layernorm_before", # Intern-S1
"vpm.encoder.layers.{bid}.layer_norm1",
@@ -1542,6 +1548,7 @@ class TensorNameMap:
MODEL_TENSOR.V_ENC_ATTN_O: (
"vision_tower.vision_model.encoder.layers.{bid}.self_attn.out_proj",
"model.vision_tower.encoder.layers.{bid}.self_attn.out_proj", # minicpmv4_6
"vision_tower.vision_model.encoder.layers.{bid}.attn.proj", # InternVL
"model.vision_tower.encoder.layer.{bid}.attention.projection_layer", # Intern-S1
"vpm.encoder.layers.{bid}.self_attn.out_proj",
@@ -1564,6 +1571,7 @@ class TensorNameMap:
MODEL_TENSOR.V_ENC_POST_ATTN_NORM: (
"vision_tower.vision_model.encoder.layers.{bid}.layer_norm2",
"model.vision_tower.encoder.layers.{bid}.layer_norm2", # minicpmv4_6
"vision_tower.vision_model.encoder.layers.{bid}.norm2", # InternVL
"model.vision_tower.encoder.layer.{bid}.layernorm_after", # Intern-S1
"vpm.encoder.layers.{bid}.layer_norm2",
@@ -1585,6 +1593,7 @@ class TensorNameMap:
MODEL_TENSOR.V_ENC_FFN_UP: (
"vision_tower.vision_model.encoder.layers.{bid}.mlp.fc1",
"model.vision_tower.encoder.layers.{bid}.mlp.fc1", # minicpmv4_6
"model.vision_tower.encoder.layer.{bid}.mlp.fc1", # Intern-S1
"vpm.encoder.layers.{bid}.mlp.fc1",
"model.vision_model.encoder.layers.{bid}.mlp.fc1", # SmolVLM, gemma3
@@ -1613,6 +1622,7 @@ class TensorNameMap:
MODEL_TENSOR.V_ENC_FFN_DOWN: (
"vision_tower.vision_model.encoder.layers.{bid}.mlp.fc2",
"model.vision_tower.encoder.layers.{bid}.mlp.fc2", # minicpmv4_6
"model.vision_tower.encoder.layer.{bid}.mlp.fc2", # Intern-S1
"vpm.encoder.layers.{bid}.mlp.fc2",
"model.vision_model.encoder.layers.{bid}.mlp.fc2", # SmolVLM, gemma3
@@ -1668,6 +1678,7 @@ class TensorNameMap:
MODEL_TENSOR.V_POST_NORM: (
"vision_tower.vision_model.post_layernorm",
"model.vision_tower.post_layernorm", # minicpmv4_6
"model.vision_model.post_layernorm", # SmolVLM
"vision_model.layernorm_post", # llama4
"visual.merger.ln_q", # qwen2vl
@@ -1696,6 +1707,7 @@ class TensorNameMap:
"mlp_AR.pre_norm", # PaddleOCR-VL
"merger.ln_q",
"vision_tower.merger.ln_q", # dots.ocr
"model.merger.mlp.0.pre_norm", # minicpmv4_6
),
MODEL_TENSOR.V_MM_SOFT_EMB_NORM: (
@@ -1769,6 +1781,38 @@ class TensorNameMap:
"model.visual.deepstack_merger_list.{bid}.linear_fc2", # deepstack in qwen3vl
),
MODEL_TENSOR.V_MERGER_LN1: (
"model.vision_tower.vit_merger.layer_norm1", # minicpmv4_6
),
MODEL_TENSOR.V_MERGER_ATTN_Q: (
"model.vision_tower.vit_merger.self_attn.q_proj", # minicpmv4_6
),
MODEL_TENSOR.V_MERGER_ATTN_K: (
"model.vision_tower.vit_merger.self_attn.k_proj", # minicpmv4_6
),
MODEL_TENSOR.V_MERGER_ATTN_V: (
"model.vision_tower.vit_merger.self_attn.v_proj", # minicpmv4_6
),
MODEL_TENSOR.V_MERGER_ATTN_O: (
"model.vision_tower.vit_merger.self_attn.out_proj", # minicpmv4_6
),
MODEL_TENSOR.V_MERGER_DS_LN: (
"model.vision_tower.vit_merger.pre_norm", # minicpmv4_6
),
MODEL_TENSOR.V_MERGER_DS_UP: (
"model.vision_tower.vit_merger.linear_1", # minicpmv4_6
),
MODEL_TENSOR.V_MERGER_DS_DOWN: (
"model.vision_tower.vit_merger.linear_2", # minicpmv4_6
),
MODEL_TENSOR.V_SAM_POS_EMBD: (
"model.sam_model.pos_embed",
),
@@ -1828,11 +1872,13 @@ class TensorNameMap:
MODEL_TENSOR.V_MM_UP: (
"model.vision.linear_proj.dense_h_to_4h", # cogvlm
"visual.merger.up_proj", # glm4v
"model.merger.mlp.0.linear_1", # minicpmv4_6
),
MODEL_TENSOR.V_MM_DOWN: (
"model.vision.linear_proj.dense_4h_to_h", # cogvlm
"visual.merger.down_proj", # glm4v
"model.merger.mlp.0.linear_2", # minicpmv4_6
),
MODEL_TENSOR.V_MM_GATE: (

View File

@@ -285,7 +285,7 @@ static llama_model * llama_model_mapping(llm_arch arch, const llama_model_params
case LLM_ARCH_STEP35:
return new llama_model_step35(params);
default:
GGML_ABORT("unimplemented model class");
throw std::runtime_error(std::string("unsupported model architecture: '") + llm_arch_name(arch) + "'");
}
}

View File

@@ -49,6 +49,7 @@ For the following models, you can use `convert_hf_to_gguf.py` with `--mmproj` fl
- Qwen 2 VL and Qwen 2.5 VL (from [Qwen](https://huggingface.co/Qwen))
- [Mistral Small 3.1 24B](https://huggingface.co/mistralai/Mistral-Small-3.1-24B-Instruct-2503)
- InternVL 2.5 and InternVL 3 from [OpenGVLab](https://huggingface.co/OpenGVLab) (note: we don't support conversion of `InternVL3-*-hf` model, only non-HF version is supported ; `InternLM2Model` **text** model is not supported)
- [MiniCPM-V 4.6](https://huggingface.co/openbmb/MiniCPM-V-4_6) ; See the guide [here](../../docs/multimodal/minicpmv4.6.md) - requires the standard `transformers` v5.7.0+ checkpoint
For older models, please refer to the relevant guide for instructions on how to obtain or create them:
@@ -60,4 +61,7 @@ NOTE: conversion scripts are located under `tools/mtmd/legacy-models`
- [MiniCPM-V 2.5](../../docs/multimodal/minicpmv2.5.md)
- [MiniCPM-V 2.6](../../docs/multimodal/minicpmv2.6.md)
- [MiniCPM-o 2.6](../../docs/multimodal/minicpmo2.6.md)
- [MiniCPM-V 4.0](../../docs/multimodal/minicpmv4.0.md)
- [MiniCPM-o 4.0](../../docs/multimodal/minicpmo4.0.md)
- [MiniCPM-V 4.5](../../docs/multimodal/minicpmv4.5.md)
- [IBM Granite Vision](../../docs/multimodal/granitevision.md)

View File

@@ -132,6 +132,17 @@
#define TN_MINICPMV_ATTN "resampler.attn.%s.%s"
#define TN_MINICPMV_LN "resampler.ln_%s.%s"
// MiniCPM-V 4.6 ViT merger (window attention + MLP downsample),
// matching the upstream `vit_merger` module name in transformers.
#define TN_VIT_MERGER_LN1 "v.vit_merger.ln1.%s"
#define TN_VIT_MERGER_ATTN_Q "v.vit_merger.attn_q.%s"
#define TN_VIT_MERGER_ATTN_K "v.vit_merger.attn_k.%s"
#define TN_VIT_MERGER_ATTN_V "v.vit_merger.attn_v.%s"
#define TN_VIT_MERGER_ATTN_O "v.vit_merger.attn_out.%s"
#define TN_VIT_MERGER_DS_LN "v.vit_merger.ds_ln.%s"
#define TN_VIT_MERGER_DS_UP "v.vit_merger.ds_ffn_up.%s"
#define TN_VIT_MERGER_DS_DOWN "v.vit_merger.ds_ffn_down.%s"
#define TN_GLM_ADAPER_CONV "adapter.conv.%s"
#define TN_GLM_ADAPTER_LINEAR "adapter.linear.linear.%s"
#define TN_GLM_ADAPTER_NORM_1 "adapter.linear.norm1.%s"
@@ -331,6 +342,7 @@ enum projector_type {
PROJECTOR_TYPE_NEMOTRON_V2_VL,
PROJECTOR_TYPE_HUNYUANOCR,
PROJECTOR_TYPE_HUNYUANVL,
PROJECTOR_TYPE_MINICPMV4_6,
PROJECTOR_TYPE_GRANITE_SPEECH,
PROJECTOR_TYPE_UNKNOWN,
};
@@ -379,6 +391,7 @@ static std::map<projector_type, std::string> PROJECTOR_TYPE_NAMES = {
{ PROJECTOR_TYPE_NEMOTRON_V2_VL, "nemotron_v2_vl"},
{ PROJECTOR_TYPE_HUNYUANOCR, "hunyuanocr"},
{ PROJECTOR_TYPE_HUNYUANVL, "hunyuanvl"},
{ PROJECTOR_TYPE_MINICPMV4_6, "minicpmv4_6"},
{ PROJECTOR_TYPE_GRANITE_SPEECH, "granite_speech"},
};

View File

@@ -110,6 +110,7 @@ struct clip_hparams {
bool has_llava_projector = false;
int minicpmv_version = 0;
int32_t minicpmv_query_num = 0; // MiniCPM-V query number
int32_t insert_layer_id = 0; // MiniCPM-V 4.6 ViT merger insertion layer
// custom value provided by user, can be undefined if not set
int32_t custom_image_min_tokens = -1;
@@ -424,6 +425,24 @@ struct clip_model {
ggml_tensor * mm_model_ln_post_w = nullptr;
ggml_tensor * mm_model_ln_post_b = nullptr;
// MiniCPM-V 4.6 ViT merger (window self-attention + ViT MLP downsample)
ggml_tensor * vit_merger_ln1_w = nullptr;
ggml_tensor * vit_merger_ln1_b = nullptr;
ggml_tensor * vit_merger_attn_q_w = nullptr;
ggml_tensor * vit_merger_attn_q_b = nullptr;
ggml_tensor * vit_merger_attn_k_w = nullptr;
ggml_tensor * vit_merger_attn_k_b = nullptr;
ggml_tensor * vit_merger_attn_v_w = nullptr;
ggml_tensor * vit_merger_attn_v_b = nullptr;
ggml_tensor * vit_merger_attn_o_w = nullptr;
ggml_tensor * vit_merger_attn_o_b = nullptr;
ggml_tensor * vit_merger_ds_ln_w = nullptr;
ggml_tensor * vit_merger_ds_ln_b = nullptr;
ggml_tensor * vit_merger_ds_up_w = nullptr;
ggml_tensor * vit_merger_ds_up_b = nullptr;
ggml_tensor * vit_merger_ds_down_w = nullptr;
ggml_tensor * vit_merger_ds_down_b = nullptr;
// gemma3
ggml_tensor * mm_input_proj_w = nullptr;
ggml_tensor * mm_soft_emb_norm_w = nullptr;

View File

@@ -874,6 +874,10 @@ static ggml_cgraph * clip_image_build_graph(clip_ctx * ctx, const clip_image_f32
{
builder = std::make_unique<clip_graph_minicpmv>(ctx, img);
} break;
case PROJECTOR_TYPE_MINICPMV4_6:
{
builder = std::make_unique<clip_graph_minicpmv4_6>(ctx, img);
} break;
case PROJECTOR_TYPE_INTERNVL:
{
builder = std::make_unique<clip_graph_internvl>(ctx, img);
@@ -1231,6 +1235,20 @@ struct clip_model_loader {
hparams.minicpmv_version = 2; // default to 2 if not set
}
} break;
case PROJECTOR_TYPE_MINICPMV4_6:
{
// MiniCPM-V 4.6 unified merger projector
// ViT merger 2x2 + final merger 2x2 = 4x spatial merge per dimension
hparams.n_merge = 4;
get_u32(KEY_PROJ_SCALE_FACTOR, hparams.n_merge, false);
// borrow wa_layer_indexes for vit_merger insertion point
std::vector<int> wa_layer_indexes_vec;
get_arr_int(KEY_WIN_ATTN_LAYER_INDEXES, wa_layer_indexes_vec, false);
if (!wa_layer_indexes_vec.empty()) {
hparams.insert_layer_id = wa_layer_indexes_vec[0];
}
} break;
case PROJECTOR_TYPE_INTERNVL:
{
// use default llava-uhd preprocessing params
@@ -1737,6 +1755,7 @@ struct clip_model_loader {
|| model.proj_type == PROJECTOR_TYPE_GEMMA3
|| model.proj_type == PROJECTOR_TYPE_IDEFICS3
|| model.proj_type == PROJECTOR_TYPE_MINICPMV
|| model.proj_type == PROJECTOR_TYPE_MINICPMV4_6
) && layer.ff_up_w && layer.ff_down_w && layer.ff_down_w->ne[0] == hparams.n_embd;
if (is_ffn_swapped) {
// swap up and down weights
@@ -1838,6 +1857,34 @@ struct clip_model_loader {
model.mm_model_ln_post_w = get_tensor(string_format(TN_MINICPMV_LN, "post", "weight"));
model.mm_model_ln_post_b = get_tensor(string_format(TN_MINICPMV_LN, "post", "bias"));
} break;
case PROJECTOR_TYPE_MINICPMV4_6:
{
// ViT merger: window self-attention
model.vit_merger_ln1_w = get_tensor(string_format(TN_VIT_MERGER_LN1, "weight"));
model.vit_merger_ln1_b = get_tensor(string_format(TN_VIT_MERGER_LN1, "bias"));
model.vit_merger_attn_q_w = get_tensor(string_format(TN_VIT_MERGER_ATTN_Q, "weight"));
model.vit_merger_attn_q_b = get_tensor(string_format(TN_VIT_MERGER_ATTN_Q, "bias"), false);
model.vit_merger_attn_k_w = get_tensor(string_format(TN_VIT_MERGER_ATTN_K, "weight"));
model.vit_merger_attn_k_b = get_tensor(string_format(TN_VIT_MERGER_ATTN_K, "bias"), false);
model.vit_merger_attn_v_w = get_tensor(string_format(TN_VIT_MERGER_ATTN_V, "weight"));
model.vit_merger_attn_v_b = get_tensor(string_format(TN_VIT_MERGER_ATTN_V, "bias"), false);
model.vit_merger_attn_o_w = get_tensor(string_format(TN_VIT_MERGER_ATTN_O, "weight"));
model.vit_merger_attn_o_b = get_tensor(string_format(TN_VIT_MERGER_ATTN_O, "bias"), false);
// ViT merger: MLP downsample
model.vit_merger_ds_ln_w = get_tensor(string_format(TN_VIT_MERGER_DS_LN, "weight"));
model.vit_merger_ds_ln_b = get_tensor(string_format(TN_VIT_MERGER_DS_LN, "bias"));
model.vit_merger_ds_up_w = get_tensor(string_format(TN_VIT_MERGER_DS_UP, "weight"));
model.vit_merger_ds_up_b = get_tensor(string_format(TN_VIT_MERGER_DS_UP, "bias"), false);
model.vit_merger_ds_down_w = get_tensor(string_format(TN_VIT_MERGER_DS_DOWN, "weight"));
model.vit_merger_ds_down_b = get_tensor(string_format(TN_VIT_MERGER_DS_DOWN, "bias"), false);
// Final Merger (DownsampleMLP)
model.mm_input_norm_w = get_tensor(TN_MM_INP_NORM);
model.mm_input_norm_b = get_tensor(TN_MM_INP_NORM_B, false);
model.mm_ffn_up_w = get_tensor(string_format(TN_MM_UP, "weight"));
model.mm_ffn_up_b = get_tensor(string_format(TN_MM_UP, "bias"), false);
model.mm_ffn_down_w = get_tensor(string_format(TN_MM_DOWN, "weight"));
model.mm_ffn_down_b = get_tensor(string_format(TN_MM_DOWN, "bias"), false);
} break;
case PROJECTOR_TYPE_GLM_EDGE:
{
model.mm_model_adapter_conv_w = get_tensor(string_format(TN_GLM_ADAPER_CONV, "weight"));
@@ -3055,6 +3102,11 @@ int clip_n_output_tokens(const struct clip_ctx * ctx, struct clip_image_f32 * im
}
}
} break;
case PROJECTOR_TYPE_MINICPMV4_6:
{
// ViT merger 4x + final merger 4x = 16x total spatial downsample
n_patches = n_patches / 16;
} break;
case PROJECTOR_TYPE_QWEN2VL:
case PROJECTOR_TYPE_QWEN25VL:
case PROJECTOR_TYPE_QWEN3VL:
@@ -3377,6 +3429,92 @@ bool clip_image_batch_encode(clip_ctx * ctx, const int n_threads, const clip_ima
}
set_input_f32("omega", omega);
} break;
case PROJECTOR_TYPE_MINICPMV4_6:
{
// SigLIP position buckets (same as resampler path)
std::vector<int32_t> positions(pos_h * pos_w);
int bucket_coords_h[1024];
int bucket_coords_w[1024];
for (int i = 0; i < pos_h; i++){
bucket_coords_h[i] = std::floor(70.0*i/pos_h);
}
for (int i = 0; i < pos_w; i++){
bucket_coords_w[i] = std::floor(70.0*i/pos_w);
}
for (int i = 0, id = 0; i < pos_h; i++){
for (int j = 0; j < pos_w; j++){
positions[id++] = bucket_coords_h[i]*70 + bucket_coords_w[j];
}
}
set_input_i32("positions", positions);
const int half_h = pos_h / 2;
const int half_w = pos_w / 2;
// window reorder indices for 2x2 windows
std::vector<int32_t> window_idx(n_pos);
std::vector<int32_t> inv_window_idx(n_pos);
{
int k = 0;
for (int wi = 0; wi < half_h; wi++) {
for (int wj = 0; wj < half_w; wj++) {
window_idx[k++] = (2*wi ) * pos_w + (2*wj );
window_idx[k++] = (2*wi ) * pos_w + (2*wj + 1);
window_idx[k++] = (2*wi + 1) * pos_w + (2*wj );
window_idx[k++] = (2*wi + 1) * pos_w + (2*wj + 1);
}
}
for (int i = 0; i < n_pos; i++) {
inv_window_idx[window_idx[i]] = i;
}
}
set_input_i32("vit_merger_window_idx", window_idx);
set_input_i32("vit_merger_inv_window_idx", inv_window_idx);
// block-diagonal attention mask: tokens in the same 4-token
// window attend to each other (mask = 0), all other positions
// are masked out (-inf). matches the window-major reorder above.
std::vector<float> window_mask_data(n_pos * n_pos, std::numeric_limits<float>::lowest());
for (int wi = 0; wi < n_pos / 4; wi++) {
for (int i = 0; i < 4; i++) {
for (int j = 0; j < 4; j++) {
window_mask_data[(wi*4 + i) * n_pos + (wi*4 + j)] = 0.0f;
}
}
}
set_input_f32("vit_merger_window_mask", window_mask_data);
// ViT merger 2x2 downsample indices
auto make_ds_idx = [](int off_r, int off_c, int ds_h, int ds_w, int stride_w) {
std::vector<int32_t> idx(ds_h * ds_w);
for (int i = 0; i < ds_h; i++) {
for (int j = 0; j < ds_w; j++) {
idx[i * ds_w + j] = (2*i + off_r) * stride_w + (2*j + off_c);
}
}
return idx;
};
auto vit_merger_ds_0 = make_ds_idx(0, 0, half_h, half_w, pos_w);
auto vit_merger_ds_1 = make_ds_idx(0, 1, half_h, half_w, pos_w);
auto vit_merger_ds_2 = make_ds_idx(1, 0, half_h, half_w, pos_w);
auto vit_merger_ds_3 = make_ds_idx(1, 1, half_h, half_w, pos_w);
set_input_i32("vit_merger_ds_idx_0", vit_merger_ds_0);
set_input_i32("vit_merger_ds_idx_1", vit_merger_ds_1);
set_input_i32("vit_merger_ds_idx_2", vit_merger_ds_2);
set_input_i32("vit_merger_ds_idx_3", vit_merger_ds_3);
// final merger 2x2 downsample indices (operates on half_h x half_w grid)
const int qh = half_h / 2;
const int qw = half_w / 2;
auto m_ds_0 = make_ds_idx(0, 0, qh, qw, half_w);
auto m_ds_1 = make_ds_idx(0, 1, qh, qw, half_w);
auto m_ds_2 = make_ds_idx(1, 0, qh, qw, half_w);
auto m_ds_3 = make_ds_idx(1, 1, qh, qw, half_w);
set_input_i32("merger_ds_idx_0", m_ds_0);
set_input_i32("merger_ds_idx_1", m_ds_1);
set_input_i32("merger_ds_idx_2", m_ds_2);
set_input_i32("merger_ds_idx_3", m_ds_3);
} break;
case PROJECTOR_TYPE_QWEN2VL:
case PROJECTOR_TYPE_QWEN3VL:
case PROJECTOR_TYPE_GLM4V:
@@ -3931,6 +4069,8 @@ int clip_n_mmproj_embd(const struct clip_ctx * ctx) {
return ctx->model.mm_3_b->ne[0];
case PROJECTOR_TYPE_MINICPMV:
return ctx->model.mm_model_proj->ne[0];
case PROJECTOR_TYPE_MINICPMV4_6:
return ctx->model.mm_ffn_down_w->ne[1];
case PROJECTOR_TYPE_GLM_EDGE:
return ctx->model.mm_model_mlp_3_w->ne[1];
case PROJECTOR_TYPE_QWEN2VL:
@@ -3997,6 +4137,9 @@ int clip_is_minicpmv(const struct clip_ctx * ctx) {
if (ctx->proj_type() == PROJECTOR_TYPE_MINICPMV) {
return ctx->model.hparams.minicpmv_version;
}
if (ctx->proj_type() == PROJECTOR_TYPE_MINICPMV4_6) {
return 46;
}
return 0;
}

View File

@@ -112,3 +112,294 @@ ggml_cgraph * clip_graph_minicpmv::build() {
return gf;
}
ggml_cgraph * clip_graph_minicpmv4_6::build() {
const int insert_lid = hparams.insert_layer_id;
const int n_pos = n_patches;
const int half_h = n_patches_y / 2;
const int half_w = n_patches_x / 2;
const int n_ds = half_h * half_w; // after ViT merger 2x2 downsample
const int qh = half_h / 2;
const int qw = half_w / 2;
const int n_ds2 = qh * qw; // after final merger 2x2 downsample
auto add_i32_input = [&](const char * name, int n) {
ggml_tensor * t = ggml_new_tensor_1d(ctx0, GGML_TYPE_I32, n);
ggml_set_name(t, name);
ggml_set_input(t);
return t;
};
// position indices for ViT learned positional embeddings
ggml_tensor * positions = add_i32_input("positions", n_pos);
ggml_tensor * learned_pos_embd = ggml_get_rows(ctx0, model.position_embeddings, positions);
// ViT merger window reorder indices + block-diagonal mask
// (mask layout follows qwen2vl: -inf except for 4x4 blocks on the diagonal,
// so each window-major group of 4 tokens only attends to itself)
ggml_tensor * vit_merger_window_idx = add_i32_input("vit_merger_window_idx", n_pos);
ggml_tensor * vit_merger_inv_window_idx = add_i32_input("vit_merger_inv_window_idx", n_pos);
ggml_tensor * vit_merger_window_mask = ggml_new_tensor_2d(ctx0, GGML_TYPE_F32, n_pos, n_pos);
ggml_set_name(vit_merger_window_mask, "vit_merger_window_mask");
ggml_set_input(vit_merger_window_mask);
if (flash_attn_type == CLIP_FLASH_ATTN_TYPE_ENABLED) {
vit_merger_window_mask = ggml_cast(ctx0, vit_merger_window_mask, GGML_TYPE_F16);
}
// ViT merger 2x2 downsample gather indices
ggml_tensor * vit_merger_ds_idx_0 = add_i32_input("vit_merger_ds_idx_0", n_ds);
ggml_tensor * vit_merger_ds_idx_1 = add_i32_input("vit_merger_ds_idx_1", n_ds);
ggml_tensor * vit_merger_ds_idx_2 = add_i32_input("vit_merger_ds_idx_2", n_ds);
ggml_tensor * vit_merger_ds_idx_3 = add_i32_input("vit_merger_ds_idx_3", n_ds);
// final merger 2x2 downsample gather indices
ggml_tensor * merger_ds_idx_0 = add_i32_input("merger_ds_idx_0", n_ds2);
ggml_tensor * merger_ds_idx_1 = add_i32_input("merger_ds_idx_1", n_ds2);
ggml_tensor * merger_ds_idx_2 = add_i32_input("merger_ds_idx_2", n_ds2);
ggml_tensor * merger_ds_idx_3 = add_i32_input("merger_ds_idx_3", n_ds2);
// patch embedding + positional embedding
ggml_tensor * inp = build_inp();
inp = ggml_add(ctx0, inp, learned_pos_embd);
cb(inp, "pos_embed", -1);
ggml_tensor * inpL = inp;
if (model.pre_ln_w) {
inpL = build_norm(inpL, model.pre_ln_w, model.pre_ln_b, NORM_TYPE_NORMAL, eps, -1);
cb(inpL, "pre_ln", -1);
}
// ViT layers 0..insert_layer_id (inclusive)
// Mirrors the separate-qkv path of clip_graph::build_vit so the two manually
// unrolled segments around the ViT merger read like build_vit() expansions.
for (int il = 0; il <= insert_lid; il++) {
auto & layer = model.layers[il];
ggml_tensor * cur = inpL;
cur = build_norm(cur, layer.ln_1_w, layer.ln_1_b, NORM_TYPE_NORMAL, eps, il);
cb(cur, "layer_inp_normed", il);
{
ggml_tensor * Qcur = build_mm(layer.q_w, cur);
if (layer.q_b) {
Qcur = ggml_add(ctx0, Qcur, layer.q_b);
}
ggml_tensor * Kcur = build_mm(layer.k_w, cur);
if (layer.k_b) {
Kcur = ggml_add(ctx0, Kcur, layer.k_b);
}
ggml_tensor * Vcur = build_mm(layer.v_w, cur);
if (layer.v_b) {
Vcur = ggml_add(ctx0, Vcur, layer.v_b);
}
Qcur = ggml_reshape_3d(ctx0, Qcur, d_head, n_head, n_pos);
Kcur = ggml_reshape_3d(ctx0, Kcur, d_head, n_head, n_pos);
Vcur = ggml_reshape_3d(ctx0, Vcur, d_head, n_head, n_pos);
cb(Qcur, "Qcur", il);
cb(Kcur, "Kcur", il);
cb(Vcur, "Vcur", il);
cur = build_attn(layer.o_w, layer.o_b, Qcur, Kcur, Vcur, nullptr, kq_scale, il);
cb(cur, "attn_out", il);
}
if (layer.ls_1_w) {
cur = ggml_mul(ctx0, cur, layer.ls_1_w);
cb(cur, "attn_out_scaled", il);
}
cur = ggml_add(ctx0, cur, inpL);
inpL = cur;
cb(cur, "ffn_inp", il);
cur = build_norm(cur, layer.ln_2_w, layer.ln_2_b, NORM_TYPE_NORMAL, eps, il);
cb(cur, "ffn_inp_normed", il);
cur = build_ffn(cur, layer.ff_up_w, layer.ff_up_b, layer.ff_gate_w, layer.ff_gate_b,
layer.ff_down_w, layer.ff_down_b, hparams.ffn_op, il);
cb(cur, "ffn_out", il);
if (layer.ls_2_w) {
cur = ggml_mul(ctx0, cur, layer.ls_2_w);
cb(cur, "ffn_out_scaled", il);
}
cur = ggml_add(ctx0, inpL, cur);
cb(cur, "layer_out", il);
inpL = cur;
}
// ViT merger: window self-attention
// Tokens are reordered to window-major (4 tokens per window are contiguous),
// and a block-diagonal mask restricts attention to within each window. This
// mirrors the qwen2vl windowed-attention pattern so build_attn() can pick the
// flash-attention path when available.
{
ggml_tensor * residual = inpL;
ggml_tensor * cur = build_norm(inpL,
model.vit_merger_ln1_w, model.vit_merger_ln1_b,
NORM_TYPE_NORMAL, eps, -1);
cb(cur, "vit_merger_attn_inp_normed", -1);
cur = ggml_get_rows(ctx0, cur, vit_merger_window_idx);
cb(cur, "vit_merger_window_reorder", -1);
ggml_tensor * Qcur = build_mm(model.vit_merger_attn_q_w, cur);
if (model.vit_merger_attn_q_b) {
Qcur = ggml_add(ctx0, Qcur, model.vit_merger_attn_q_b);
}
ggml_tensor * Kcur = build_mm(model.vit_merger_attn_k_w, cur);
if (model.vit_merger_attn_k_b) {
Kcur = ggml_add(ctx0, Kcur, model.vit_merger_attn_k_b);
}
ggml_tensor * Vcur = build_mm(model.vit_merger_attn_v_w, cur);
if (model.vit_merger_attn_v_b) {
Vcur = ggml_add(ctx0, Vcur, model.vit_merger_attn_v_b);
}
Qcur = ggml_reshape_3d(ctx0, Qcur, d_head, n_head, n_pos);
Kcur = ggml_reshape_3d(ctx0, Kcur, d_head, n_head, n_pos);
Vcur = ggml_reshape_3d(ctx0, Vcur, d_head, n_head, n_pos);
cb(Qcur, "vit_merger_Qcur", -1);
cb(Kcur, "vit_merger_Kcur", -1);
cb(Vcur, "vit_merger_Vcur", -1);
cur = build_attn(model.vit_merger_attn_o_w, model.vit_merger_attn_o_b,
Qcur, Kcur, Vcur, vit_merger_window_mask, kq_scale, -1);
cb(cur, "vit_merger_attn_out", -1);
cur = ggml_get_rows(ctx0, cur, vit_merger_inv_window_idx);
inpL = ggml_add(ctx0, cur, residual);
cb(inpL, "vit_merger_attn_residual", -1);
}
// ViT merger: 2x2 spatial downsample + MLP (4 tokens -> 1)
{
ggml_tensor * p0 = ggml_get_rows(ctx0, inpL, vit_merger_ds_idx_0);
ggml_tensor * p1 = ggml_get_rows(ctx0, inpL, vit_merger_ds_idx_1);
ggml_tensor * p2 = ggml_get_rows(ctx0, inpL, vit_merger_ds_idx_2);
ggml_tensor * p3 = ggml_get_rows(ctx0, inpL, vit_merger_ds_idx_3);
ggml_tensor * mean_res = ggml_add(ctx0, p0, p1);
mean_res = ggml_add(ctx0, mean_res, p2);
mean_res = ggml_add(ctx0, mean_res, p3);
mean_res = ggml_scale(ctx0, mean_res, 0.25f);
cb(mean_res, "vit_merger_ds_mean_res", -1);
ggml_tensor * cat = ggml_concat(ctx0, p0, p1, 0);
cat = ggml_concat(ctx0, cat, p2, 0);
cat = ggml_concat(ctx0, cat, p3, 0);
ggml_tensor * cur = build_norm(cat,
model.vit_merger_ds_ln_w, model.vit_merger_ds_ln_b,
NORM_TYPE_NORMAL, eps, -1);
cb(cur, "vit_merger_ds_normed", -1);
// ViTWindowAttentionMerger downsample MLP uses gelu_pytorch_tanh (FFN_GELU)
cur = build_ffn(cur,
model.vit_merger_ds_up_w, model.vit_merger_ds_up_b,
nullptr, nullptr,
model.vit_merger_ds_down_w, model.vit_merger_ds_down_b,
FFN_GELU, -1);
cb(cur, "vit_merger_ds_mlp_out", -1);
inpL = ggml_add(ctx0, cur, mean_res);
cb(inpL, "vit_merger_ds_out", -1);
}
// ViT layers (insert_layer_id+1)..n_layer-1, operating on the downsampled tokens
{
const int64_t n_pos_ds = n_ds;
for (int il = insert_lid + 1; il < n_layer; il++) {
auto & layer = model.layers[il];
ggml_tensor * cur = inpL;
cur = build_norm(cur, layer.ln_1_w, layer.ln_1_b, NORM_TYPE_NORMAL, eps, il);
cb(cur, "layer_inp_normed", il);
{
ggml_tensor * Qcur = build_mm(layer.q_w, cur);
if (layer.q_b) {
Qcur = ggml_add(ctx0, Qcur, layer.q_b);
}
ggml_tensor * Kcur = build_mm(layer.k_w, cur);
if (layer.k_b) {
Kcur = ggml_add(ctx0, Kcur, layer.k_b);
}
ggml_tensor * Vcur = build_mm(layer.v_w, cur);
if (layer.v_b) {
Vcur = ggml_add(ctx0, Vcur, layer.v_b);
}
Qcur = ggml_reshape_3d(ctx0, Qcur, d_head, n_head, n_pos_ds);
Kcur = ggml_reshape_3d(ctx0, Kcur, d_head, n_head, n_pos_ds);
Vcur = ggml_reshape_3d(ctx0, Vcur, d_head, n_head, n_pos_ds);
cb(Qcur, "Qcur", il);
cb(Kcur, "Kcur", il);
cb(Vcur, "Vcur", il);
cur = build_attn(layer.o_w, layer.o_b, Qcur, Kcur, Vcur, nullptr, kq_scale, il);
cb(cur, "attn_out", il);
}
if (layer.ls_1_w) {
cur = ggml_mul(ctx0, cur, layer.ls_1_w);
cb(cur, "attn_out_scaled", il);
}
cur = ggml_add(ctx0, cur, inpL);
inpL = cur;
cb(cur, "ffn_inp", il);
cur = build_norm(cur, layer.ln_2_w, layer.ln_2_b, NORM_TYPE_NORMAL, eps, il);
cb(cur, "ffn_inp_normed", il);
cur = build_ffn(cur, layer.ff_up_w, layer.ff_up_b, layer.ff_gate_w, layer.ff_gate_b,
layer.ff_down_w, layer.ff_down_b, hparams.ffn_op, il);
cb(cur, "ffn_out", il);
if (layer.ls_2_w) {
cur = ggml_mul(ctx0, cur, layer.ls_2_w);
cb(cur, "ffn_out_scaled", il);
}
cur = ggml_add(ctx0, inpL, cur);
cb(cur, "layer_out", il);
inpL = cur;
}
}
if (model.post_ln_w) {
inpL = build_norm(inpL, model.post_ln_w, model.post_ln_b, NORM_TYPE_NORMAL, eps, -1);
cb(inpL, "post_ln", -1);
}
// Final Merger (DownsampleMLP): another 2x2 spatial merge -> projector embedding
{
ggml_tensor * p0 = ggml_get_rows(ctx0, inpL, merger_ds_idx_0);
ggml_tensor * p1 = ggml_get_rows(ctx0, inpL, merger_ds_idx_1);
ggml_tensor * p2 = ggml_get_rows(ctx0, inpL, merger_ds_idx_2);
ggml_tensor * p3 = ggml_get_rows(ctx0, inpL, merger_ds_idx_3);
ggml_tensor * cat = ggml_concat(ctx0, p0, p1, 0);
cat = ggml_concat(ctx0, cat, p2, 0);
cat = ggml_concat(ctx0, cat, p3, 0);
ggml_tensor * cur = build_norm(cat,
model.mm_input_norm_w, model.mm_input_norm_b,
NORM_TYPE_NORMAL, eps, -1);
cb(cur, "merger_normed", -1);
// MiniCPMV4_6DownsampleMLP uses nn.GELU() (erf-based, FFN_GELU_ERF)
cur = build_ffn(cur,
model.mm_ffn_up_w, model.mm_ffn_up_b,
nullptr, nullptr,
model.mm_ffn_down_w, model.mm_ffn_down_b,
FFN_GELU_ERF, -1);
cb(cur, "merger_out", -1);
inpL = cur;
}
ggml_build_forward_expand(gf, inpL);
return gf;
}

View File

@@ -56,6 +56,11 @@ struct clip_graph_minicpmv : clip_graph {
ggml_cgraph * build() override;
};
struct clip_graph_minicpmv4_6 : clip_graph {
clip_graph_minicpmv4_6(clip_ctx * ctx, const clip_image_f32 & img) : clip_graph(ctx, img) {}
ggml_cgraph * build() override;
};
struct clip_graph_internvl : clip_graph {
clip_graph_internvl(clip_ctx * ctx, const clip_image_f32 & img) : clip_graph(ctx, img) {}
ggml_cgraph * build() override;

View File

@@ -584,7 +584,9 @@ bool mtmd_image_preprocessor_llava_uhd::preprocess(const clip_image_u8 & img, cl
mtmd_image_preprocessor_llava_uhd::slice_instructions mtmd_image_preprocessor_llava_uhd::get_slice_instructions(const clip_image_size & original_size) {
mtmd_image_preprocessor_llava_uhd::slice_instructions res;
const int patch_size = hparams.patch_size;
// align slices by patch_size * n_merge so an integer number of merger output tokens fits per slice
const int n_merge = hparams.n_merge > 0 ? hparams.n_merge : 1;
const int patch_size = hparams.patch_size * n_merge;
const int slice_size = hparams.image_size;
const int original_width = original_size.width;
const int original_height = original_size.height;

View File

@@ -310,6 +310,18 @@ struct mtmd_context {
}
image_preproc = std::make_unique<mtmd_image_preprocessor_llava_uhd>(ctx_v);
} break;
case PROJECTOR_TYPE_MINICPMV4_6:
{
slice_tmpl = MTMD_SLICE_TMPL_MINICPMV_2_6;
tok_ov_img_start = {lookup_token("<image>")};
tok_ov_img_end = {lookup_token("</image>")};
tok_sli_img_start = {lookup_token("<slice>")};
tok_sli_img_end = {lookup_token("</slice>")};
tok_row_end = {lookup_token("\n")};
tok_row_end_trail = false; // no trailing end-of-row token
ov_img_first = true;
image_preproc = std::make_unique<mtmd_image_preprocessor_llava_uhd>(ctx_v);
} break;
case PROJECTOR_TYPE_QWEN2VL:
case PROJECTOR_TYPE_QWEN25VL:
case PROJECTOR_TYPE_QWEN3VL: