Compare commits

...

5 Commits
b9544 ... b9549

Author SHA1 Message Date
Aman Gupta
04eb4c446d llama : add Gemma4 MTP (#23398) 2026-06-07 20:50:54 +08:00
Sigbjørn Skjæret
8a091c47ab spec : fix vocab compatibility check (#24256) 2026-06-07 14:43:52 +03:00
konradmb
465b1f0e75 arg: Skip mmproj download when user supplied mmproj (#24239) 2026-06-07 11:18:44 +02:00
Sigbjørn Skjæret
f71af352a5 convert : fix Gemma4 with no audio encoder (#24242) 2026-06-07 08:43:05 +02:00
Sigbjørn Skjæret
3f7c79d7b5 docker : bump cuda13 to 13.3.0 (#24228) 2026-06-07 08:31:58 +02:00
33 changed files with 654 additions and 151 deletions

View File

@@ -82,8 +82,8 @@ jobs:
{ "tag": "cpu", "dockerfile": ".devops/s390x.Dockerfile", "platforms": "linux/s390x", "full": true, "light": true, "server": true, "free_disk_space": false, "runs_on": "ubuntu-24.04-s390x" },
{ "tag": "cuda cuda12", "dockerfile": ".devops/cuda.Dockerfile", "cuda_version": "12.8.1", "platforms": "linux/amd64", "full": true, "light": true, "server": true, "free_disk_space": true, "runs_on": "ubuntu-24.04" },
{ "tag": "cuda cuda12", "dockerfile": ".devops/cuda.Dockerfile", "cuda_version": "12.8.1", "platforms": "linux/arm64", "full": true, "light": true, "server": true, "free_disk_space": true, "runs_on": "ubuntu-24.04-arm" },
{ "tag": "cuda13", "dockerfile": ".devops/cuda.Dockerfile", "cuda_version": "13.1.1", "platforms": "linux/amd64", "full": true, "light": true, "server": true, "free_disk_space": true, "runs_on": "ubuntu-24.04" },
{ "tag": "cuda13", "dockerfile": ".devops/cuda.Dockerfile", "cuda_version": "13.1.1", "platforms": "linux/arm64", "full": true, "light": true, "server": true, "free_disk_space": true, "runs_on": "ubuntu-24.04-arm" },
{ "tag": "cuda13", "dockerfile": ".devops/cuda.Dockerfile", "cuda_version": "13.3.0", "platforms": "linux/amd64", "full": true, "light": true, "server": true, "free_disk_space": true, "runs_on": "ubuntu-24.04" },
{ "tag": "cuda13", "dockerfile": ".devops/cuda.Dockerfile", "cuda_version": "13.3.0", "platforms": "linux/arm64", "full": true, "light": true, "server": true, "free_disk_space": true, "runs_on": "ubuntu-24.04-arm" },
{ "tag": "musa", "dockerfile": ".devops/musa.Dockerfile", "platforms": "linux/amd64", "full": true, "light": true, "server": true, "free_disk_space": true, "runs_on": "ubuntu-24.04" },
{ "tag": "intel", "dockerfile": ".devops/intel.Dockerfile", "platforms": "linux/amd64", "full": true, "light": true, "server": true, "free_disk_space": true, "runs_on": "ubuntu-24.04" },
{ "tag": "vulkan", "dockerfile": ".devops/vulkan.Dockerfile", "platforms": "linux/amd64", "full": true, "light": true, "server": true, "free_disk_space": false, "runs_on": "ubuntu-24.04" },

View File

@@ -444,7 +444,7 @@ bool common_params_handle_models(common_params & params, llama_example curr_ex)
opts.offline = params.offline;
opts.skip_download = params.skip_download;
opts.download_mtp = spec_type_draft_mtp;
opts.download_mmproj = !params.no_mmproj;
opts.download_mmproj = !params.no_mmproj && params.mmproj.path.empty() && params.mmproj.url.empty();
// sub-models (draft, mmproj, vocoder) are explicitly specified by the user,
// so we should not auto-discover mtp/mmproj siblings for them

View File

@@ -3,13 +3,14 @@
#include "common.h"
#include "ggml.h"
#include "llama.h"
#include "../src/llama-ext.h" // staging API: llama_set_embeddings_nextn / llama_get_embeddings_nextn_ith (used by MTP)
#include "log.h"
#include "ngram-cache.h"
#include "ngram-map.h"
#include "ngram-mod.h"
#include "sampling.h"
#include "../src/llama-ext.h" // staging API: llama_set_embeddings_nextn / llama_get_embeddings_nextn_ith (used by MTP)
#include <algorithm>
#include <cassert>
#include <cstring>
@@ -58,10 +59,10 @@ static bool common_speculative_are_compatible(
const llama_vocab * vocab_tgt = llama_model_get_vocab(model_tgt);
const llama_vocab * vocab_dft = llama_model_get_vocab(model_dft);
const bool vocab_type_tgt = llama_vocab_type(vocab_tgt);
const auto vocab_type_tgt = llama_vocab_type(vocab_tgt);
LOG_DBG("%s: vocab_type tgt: %d\n", __func__, vocab_type_tgt);
const bool vocab_type_dft = llama_vocab_type(vocab_dft);
const auto vocab_type_dft = llama_vocab_type(vocab_dft);
LOG_DBG("%s: vocab_type dft: %d\n", __func__, vocab_type_dft);
if (vocab_type_tgt != vocab_type_dft) {
@@ -418,6 +419,8 @@ struct common_speculative_impl_draft_mtp : public common_speculative_impl {
int32_t n_embd = 0;
bool is_mem_shared = false;
// Per-sequence cross-batch carryover: pair (h_p, x_{p+1}) at MTP pos p+1.
// The last h-row of one process() call needs the first token of the NEXT
// call to pair with, so it's stashed here until that next call fires.
@@ -444,7 +447,9 @@ struct common_speculative_impl_draft_mtp : public common_speculative_impl {
auto * ctx_dft = this->params.ctx_dft;
GGML_ASSERT(ctx_tgt && ctx_dft && "MTP requires ctx_tgt and ctx_dft to be set");
n_embd = llama_model_n_embd(llama_get_model(ctx_dft));
n_embd = llama_model_n_embd_out(llama_get_model(ctx_dft));
GGML_ASSERT(n_embd == llama_model_n_embd(llama_get_model(ctx_tgt)) &&
"MTP input row width must match the target h_nextn width");
LOG_INF("%s: adding speculative implementation 'draft-mtp'\n", __func__);
LOG_INF("%s: - n_max=%d, n_min=%d, p_min=%.2f, n_embd=%d, backend_sampling=%d\n", __func__, this->params.n_max, this->params.n_min, this->params.p_min, n_embd, (int) this->params.backend_sampling);
@@ -490,6 +495,8 @@ struct common_speculative_impl_draft_mtp : public common_speculative_impl {
llama_set_embeddings_nextn(ctx_tgt, true, /*masked*/ false);
llama_set_embeddings_nextn(ctx_dft, true, /*masked*/ true);
is_mem_shared = llama_get_ctx_other(ctx_dft) == ctx_tgt;
pending_h.assign(n_seq, std::vector<float>(n_embd, 0.0f));
i_batch_beg.assign(n_seq, -1);
@@ -526,9 +533,11 @@ struct common_speculative_impl_draft_mtp : public common_speculative_impl {
if (N <= 0) {
return;
}
auto * ctx_dft = this->params.ctx_dft;
const llama_pos pos_max = llama_memory_seq_pos_max(llama_get_memory(ctx_dft), seq_id);
if (pos_max < N - 1) {
if (pos_max < N - 1 && !is_mem_shared) {
LOG_WRN("%s: ctx_dft pos_max=%d < N-1=%d - "
"process() hook may not have run on every prefill ubatch "
"(need_embd / logits=1 on every prompt position?). "
@@ -571,48 +580,42 @@ struct common_speculative_impl_draft_mtp : public common_speculative_impl {
const size_t row_bytes = (size_t) n_embd * sizeof(float);
common_batch_clear(batch);
// if kv is shared with target (e.g Gemma4), then we can skip this catch-up decode
if (!is_mem_shared) {
common_batch_clear(batch);
for (int k = 0; k < n_tokens; ++k) {
common_batch_add(batch, batch_in.token[k], batch_in.pos[k], { batch_in.seq_id[k][0] }, 0);
}
// shift the tgt embeddings to the right by one position
// assumes that the tokens in the batch are sequential for each sequence
// i.e. we cannot have seq_id like this: [0, 0, 0, 1, 1, 0, 1, 1]
// ^--- this is a problem
// TODO:this is generally true, but would be nice to assert it
{
const float * h_tgt = llama_get_embeddings_nextn(ctx_tgt);
std::memcpy(batch.embd + (size_t) 1 * n_embd, h_tgt, row_bytes * (n_tokens-1));
//{
// // string with seq_ids in the batch
// std::stringstream ss;
// for (int i = 0; i < n_tokens; ++i) {
// ss << batch_in.seq_id[i][0] << ",";
// }
// LOG_WRN("%s: batch_in.seq_id = %s\n", __func__, ss.str().c_str());
//}
}
// fill the pending embeddings from a previous run
auto set_h = [&](int idx, const float * h_row) {
std::memcpy(batch.embd + (size_t) idx * n_embd, h_row, row_bytes);
};
for (llama_seq_id seq_id = 0; seq_id < (llama_seq_id) n_seq; ++seq_id) {
if (i_batch_beg[seq_id] < 0) {
continue;
for (int k = 0; k < n_tokens; ++k) {
common_batch_add(batch, batch_in.token[k], batch_in.pos[k], { batch_in.seq_id[k][0] }, 0);
}
set_h(i_batch_beg[seq_id], pending_h[seq_id].data());
}
// shift the tgt embeddings to the right by one position
// assumes that the tokens in the batch are sequential for each sequence
// i.e. we cannot have seq_id like this: [0, 0, 0, 1, 1, 0, 1, 1]
// ^--- this is a problem
// TODO:this is generally true, but would be nice to assert it
{
const float * h_tgt = llama_get_embeddings_nextn(ctx_tgt);
std::memcpy(batch.embd + (size_t) 1 * n_embd, h_tgt, row_bytes * (n_tokens-1));
}
const int32_t rc = llama_decode(ctx_dft, batch);
if (rc != 0) {
LOG_ERR("%s: llama_decode(ctx_dft) failed rc=%d (pos=%d)\n", __func__, (int) rc, (int) batch_in.pos[0]);
return false;
// fill the pending embeddings from a previous run
auto set_h = [&](int idx, const float * h_row) {
std::memcpy(batch.embd + (size_t) idx * n_embd, h_row, row_bytes);
};
for (llama_seq_id seq_id = 0; seq_id < (llama_seq_id) n_seq; ++seq_id) {
if (i_batch_beg[seq_id] < 0) {
continue;
}
set_h(i_batch_beg[seq_id], pending_h[seq_id].data());
}
const int32_t rc = llama_decode(ctx_dft, batch);
if (rc != 0) {
LOG_ERR("%s: llama_decode(ctx_dft) failed rc=%d (pos=%d)\n", __func__, (int) rc, (int) batch_in.pos[0]);
return false;
}
}
for (llama_seq_id seq_id = 0; seq_id < (llama_seq_id) n_seq; ++seq_id) {
@@ -721,7 +724,13 @@ struct common_speculative_impl_draft_mtp : public common_speculative_impl {
continue;
}
common_batch_add(batch, id, dp.n_past + i + 1, { seq_id }, true);
if (is_mem_shared) {
// note: with shared memory (e.g. Gemma4 assistants) we use the same position for all draft tokens
// ref: https://github.com/huggingface/transformers/blob/effde20942e3f82a1b97449f60b3a48c5ff96145/docs/source/en/model_doc/gemma4_assistant.md?plain=1#L36-L37
common_batch_add(batch, id, dp.n_past, { seq_id }, true);
} else {
common_batch_add(batch, id, dp.n_past + i + 1, { seq_id }, true);
}
std::memcpy(batch.embd + n_embd*(batch.n_tokens - 1), h_row, row_bytes);
}

View File

@@ -75,9 +75,11 @@ TEXT_MODEL_MAP: dict[str, str] = {
"Gemma3TextModel": "gemma",
"Gemma3nForCausalLM": "gemma",
"Gemma3nForConditionalGeneration": "gemma",
"Gemma4AssistantForCausalLM": "gemma",
"Gemma4ForConditionalGeneration": "gemma",
"Gemma4ForCausalLM": "gemma",
"Gemma4UnifiedForConditionalGeneration": "gemma",
"Gemma4UnifiedAssistantForCausalLM": "gemma",
"GemmaForCausalLM": "gemma",
"Glm4ForCausalLM": "glm",
"Glm4MoeForCausalLM": "glm",

View File

@@ -785,6 +785,16 @@ class Gemma4UnifiedModel(Gemma4Model):
self.gguf_writer.add_suppress_tokens(suppress_tokens)
@ModelBase.register("Gemma4AssistantForCausalLM", "Gemma4UnifiedAssistantForCausalLM")
class Gemma4AssistantModel(Gemma4Model):
model_arch = gguf.MODEL_ARCH.GEMMA4_ASSISTANT
def set_gguf_parameters(self):
super().set_gguf_parameters()
self.gguf_writer.add_embedding_length_out(self.hparams["backbone_hidden_size"])
self.gguf_writer.add_nextn_predict_layers(self.block_count)
@ModelBase.register("Gemma4ForConditionalGeneration")
class Gemma4VisionAudioModel(MmprojModel):
has_audio_encoder = True
@@ -812,10 +822,11 @@ class Gemma4VisionAudioModel(MmprojModel):
self.gguf_writer.add_vision_attention_layernorm_eps(self.hparams_vision.get("layer_norm_eps", 1e-6))
# audio params
assert self.hparams_audio is not None
self.gguf_writer.add_clip_audio_projector_type(gguf.VisionProjectorType.GEMMA4A)
self.gguf_writer.add_audio_num_mel_bins(self.hparams_audio["feat_in"])
self.gguf_writer.add_audio_attention_layernorm_eps(self.hparams_audio.get("layer_norm_eps", 1e-6))
if self.has_audio_encoder:
assert self.hparams_audio is not None
self.gguf_writer.add_clip_audio_projector_type(gguf.VisionProjectorType.GEMMA4A)
self.gguf_writer.add_audio_num_mel_bins(self.hparams_audio["feat_in"])
self.gguf_writer.add_audio_attention_layernorm_eps(self.hparams_audio.get("layer_norm_eps", 1e-6))
def is_audio_tensor(self, name: str) -> bool:
return "audio_tower" in name or "embed_audio" in name

View File

@@ -440,6 +440,7 @@ class MODEL_ARCH(IntEnum):
GEMMA3 = auto()
GEMMA3N = auto()
GEMMA4 = auto()
GEMMA4_ASSISTANT = auto()
GEMMA_EMBEDDING = auto()
STARCODER2 = auto()
RWKV6 = auto()
@@ -897,6 +898,8 @@ class MODEL_TENSOR(IntEnum):
A_PER_DIM_K_SCALE = auto() # gemma4
A_PER_DIM_SCALE = auto() # gemma4
# nextn/mtp
NEXTN_PROJ_PRE = auto()
NEXTN_PROJ_POST = auto()
NEXTN_EH_PROJ = auto()
NEXTN_EMBED_TOKENS = auto()
NEXTN_ENORM = auto()
@@ -986,6 +989,7 @@ MODEL_ARCH_NAMES: dict[MODEL_ARCH, str] = {
MODEL_ARCH.GEMMA3: "gemma3",
MODEL_ARCH.GEMMA3N: "gemma3n",
MODEL_ARCH.GEMMA4: "gemma4",
MODEL_ARCH.GEMMA4_ASSISTANT: "gemma4-assistant",
MODEL_ARCH.GEMMA_EMBEDDING: "gemma-embedding",
MODEL_ARCH.STARCODER2: "starcoder2",
MODEL_ARCH.RWKV6: "rwkv6",
@@ -1471,6 +1475,8 @@ TENSOR_NAMES: dict[MODEL_TENSOR, str] = {
MODEL_TENSOR.A_QF_FFN_DOWN: "a.proj_blk.{bid}.ffn_down",
MODEL_TENSOR.A_QF_FFN_NORM: "a.proj_blk.{bid}.ffn_norm",
# NextN/MTP
MODEL_TENSOR.NEXTN_PROJ_PRE: "nextn.pre_projection",
MODEL_TENSOR.NEXTN_PROJ_POST: "nextn.post_projection",
MODEL_TENSOR.NEXTN_EH_PROJ: "blk.{bid}.nextn.eh_proj",
MODEL_TENSOR.NEXTN_EMBED_TOKENS: "blk.{bid}.nextn.embed_tokens",
MODEL_TENSOR.NEXTN_ENORM: "blk.{bid}.nextn.enorm",
@@ -2577,6 +2583,24 @@ MODEL_TENSORS: dict[MODEL_ARCH, list[MODEL_TENSOR]] = {
MODEL_TENSOR.PER_LAYER_PROJ_NORM,
MODEL_TENSOR.PER_LAYER_POST_NORM,
],
MODEL_ARCH.GEMMA4_ASSISTANT: [
MODEL_TENSOR.ROPE_FREQS,
MODEL_TENSOR.TOKEN_EMBD,
MODEL_TENSOR.OUTPUT_NORM,
MODEL_TENSOR.NEXTN_PROJ_PRE,
MODEL_TENSOR.NEXTN_PROJ_POST,
MODEL_TENSOR.ATTN_Q,
MODEL_TENSOR.ATTN_Q_NORM,
MODEL_TENSOR.ATTN_OUT,
MODEL_TENSOR.FFN_GATE,
MODEL_TENSOR.FFN_DOWN,
MODEL_TENSOR.FFN_UP,
MODEL_TENSOR.ATTN_NORM,
MODEL_TENSOR.ATTN_POST_NORM,
MODEL_TENSOR.FFN_PRE_NORM,
MODEL_TENSOR.FFN_POST_NORM,
MODEL_TENSOR.LAYER_OUT_SCALE,
],
MODEL_ARCH.GEMMA_EMBEDDING: [
MODEL_TENSOR.TOKEN_EMBD,
MODEL_TENSOR.OUTPUT,

View File

@@ -2367,6 +2367,14 @@ class TensorNameMap:
),
# NextN/MTP tensors
MODEL_TENSOR.NEXTN_PROJ_PRE: (
"pre_projection",
),
MODEL_TENSOR.NEXTN_PROJ_POST: (
"post_projection",
),
MODEL_TENSOR.NEXTN_EH_PROJ: (
"model.layers.{bid}.eh_proj",
),

View File

@@ -388,6 +388,10 @@ extern "C" {
// note: the samplers must be sampler chains (i.e. use llama_sampler_chain_init)
struct llama_sampler_seq_config * samplers;
size_t n_samplers;
// a source/target/parent context
// can be utilized in various ways, for example by sharing results or llama_memory between 2 contexts
struct llama_context * ctx_other;
};
struct llama_model_tensor_override {

View File

@@ -57,6 +57,7 @@ static const std::map<llm_arch, const char *> LLM_ARCH_NAMES = {
{ LLM_ARCH_GEMMA3, "gemma3" },
{ LLM_ARCH_GEMMA3N, "gemma3n" },
{ LLM_ARCH_GEMMA4, "gemma4" },
{ LLM_ARCH_GEMMA4_ASSISTANT, "gemma4-assistant" },
{ LLM_ARCH_GEMMA_EMBEDDING, "gemma-embedding" },
{ LLM_ARCH_STARCODER2, "starcoder2" },
{ LLM_ARCH_MAMBA, "mamba" },
@@ -453,6 +454,8 @@ static const std::map<llm_tensor, const char *> LLM_TENSOR_NAMES = {
{ LLM_TENSOR_FFN_NORM_EXPS, "blk.%d.ffn_norm_exps" },
{ LLM_TENSOR_ATTN_K_B, "blk.%d.attn_k_b" },
{ LLM_TENSOR_ATTN_V_B, "blk.%d.attn_v_b" },
{ LLM_TENSOR_NEXTN_PROJ_PRE, "nextn.pre_projection" },
{ LLM_TENSOR_NEXTN_PROJ_POST, "nextn.post_projection" },
{ LLM_TENSOR_NEXTN_EH_PROJ, "blk.%d.nextn.eh_proj" },
{ LLM_TENSOR_NEXTN_EMBED_TOKENS, "blk.%d.nextn.embed_tokens" },
{ LLM_TENSOR_NEXTN_ENORM, "blk.%d.nextn.enorm" },
@@ -765,6 +768,8 @@ static const std::map<llm_tensor, llm_tensor_info> LLM_TENSOR_INFOS = {
{LLM_TENSOR_INDEXER_PROJ, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL_MAT}},
{LLM_TENSOR_INDEXER_ATTN_K, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL_MAT}},
{LLM_TENSOR_INDEXER_ATTN_Q_B, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL_MAT}},
{LLM_TENSOR_NEXTN_PROJ_PRE, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL_MAT}},
{LLM_TENSOR_NEXTN_PROJ_POST, {LLM_TENSOR_LAYER_OUTPUT, GGML_OP_MUL_MAT}},
// NextN/MTP tensors are stored per-block (blk.%d.nextn.*) even though only the
// last nextn_predict_layers blocks carry them. Classify as LAYER_REPEATING so
// the model loader doesn't fault on the block index.

View File

@@ -61,6 +61,7 @@ enum llm_arch {
LLM_ARCH_GEMMA3,
LLM_ARCH_GEMMA3N,
LLM_ARCH_GEMMA4,
LLM_ARCH_GEMMA4_ASSISTANT,
LLM_ARCH_GEMMA_EMBEDDING,
LLM_ARCH_STARCODER2,
LLM_ARCH_MAMBA,
@@ -557,6 +558,8 @@ enum llm_tensor {
LLM_TENSOR_INDEXER_PROJ,
LLM_TENSOR_INDEXER_ATTN_K,
LLM_TENSOR_INDEXER_ATTN_Q_B,
LLM_TENSOR_NEXTN_PROJ_PRE,
LLM_TENSOR_NEXTN_PROJ_POST,
LLM_TENSOR_NEXTN_EH_PROJ,
LLM_TENSOR_NEXTN_EMBED_TOKENS,
LLM_TENSOR_NEXTN_ENORM,

View File

@@ -69,9 +69,10 @@ llama_context::llama_context(
cparams.embeddings_nextn_masked = false;
cparams.offload_kqv = params.offload_kqv;
cparams.no_perf = params.no_perf;
cparams.pooling_type = params.pooling_type;
cparams.warmup = false;
cparams.ctx_type = params.ctx_type;
cparams.pooling_type = params.pooling_type;
cparams.n_ctx = params.n_ctx == 0 ? hparams.n_ctx_train : params.n_ctx;
cparams.rope_freq_base = params.rope_freq_base == 0.0f ? hparams.rope_freq_base_train : params.rope_freq_base;
@@ -84,7 +85,17 @@ llama_context::llama_context(
cparams.cb_eval = params.cb_eval;
cparams.cb_eval_user_data = params.cb_eval_user_data;
cparams.ctx_type = params.ctx_type;
cparams.ctx_other = nullptr;
// TODO: more generic
if (model.arch == LLM_ARCH_GEMMA4_ASSISTANT) {
if (params.ctx_other == nullptr) {
// TODO: change from runtime_error to llama_exception to avoid printing error message
throw std::runtime_error("Gemma4Assistant requires ctx_other to be set (this is normal during memory fitting)");
}
cparams.ctx_other = params.ctx_other;
}
// Initialize backend samplers here so they are part of the sampling graph
// before the reserve passes run later in this function. This avoids a later
@@ -300,10 +311,11 @@ llama_context::llama_context(
// init the memory module
if (!hparams.vocab_only) {
llama_memory_params params_mem = {
/*.type_k =*/ params.type_k,
/*.type_v =*/ params.type_v,
/*.swa_full =*/ params.swa_full,
/*.ctx_type= */ cparams.ctx_type,
/*.type_k =*/ params.type_k,
/*.type_v =*/ params.type_v,
/*.swa_full =*/ params.swa_full,
/*.ctx_type =*/ cparams.ctx_type,
/*.mem_other =*/ llama_get_memory(cparams.ctx_other),
};
memory.reset(model.create_memory(params_mem, cparams));
@@ -904,7 +916,7 @@ float * llama_context::get_embeddings_nextn_ith(int32_t i) {
throw std::runtime_error("no nextn embeddings");
}
const uint32_t n_embd = model.hparams.n_embd;
const uint32_t n_embd = model.hparams.n_embd_out();
if (!cparams.embeddings_nextn_masked) {
// unmasked: nextn rows are stored densely, indexed by raw token position.
@@ -1473,7 +1485,7 @@ int llama_context::encode(const llama_batch & batch_inp) {
ggml_backend_t backend_h = ggml_backend_sched_get_tensor_backend(sched.get(), t_h_nextn);
GGML_ASSERT(backend_h != nullptr);
const uint32_t n_embd = hparams.n_embd;
const uint32_t n_embd = hparams.n_embd_out();
GGML_ASSERT(n_tokens*n_embd <= (int64_t) embd_nextn.size);
ggml_backend_tensor_get_async(backend_h, t_h_nextn, embd_nextn.data, 0, n_tokens*n_embd*sizeof(float));
}
@@ -1924,7 +1936,7 @@ int llama_context::decode(const llama_batch & batch_inp) {
ggml_backend_t backend_h = ggml_backend_sched_get_tensor_backend(sched.get(), t_h_nextn);
GGML_ASSERT(backend_h != nullptr);
const uint32_t n_embd = hparams.n_embd;
const uint32_t n_embd = hparams.n_embd_out();
float * embd_nextn_out = embd_nextn.data + offset*n_embd;
GGML_ASSERT((offset + n_rows)*n_embd <= (int64_t) embd_nextn.size);
@@ -2017,7 +2029,6 @@ uint32_t llama_context::output_reserve(int32_t n_outputs) {
const auto n_batch = cparams.n_batch;
const auto n_vocab = vocab.n_tokens();
const auto n_embd = hparams.n_embd;
const auto n_embd_out = hparams.n_embd_out();
bool has_logits = true;
@@ -2036,12 +2047,12 @@ uint32_t llama_context::output_reserve(int32_t n_outputs) {
logits.size = has_logits ? n_vocab*n_outputs_max : 0;
embd.size = has_embd ? n_embd_out*n_outputs_max : 0;
embd_nextn.size = has_embd_nextn ? n_embd*n_outputs_max : 0;
embd_nextn.size = has_embd_nextn ? n_embd_out*n_outputs_max : 0;
if (has_embd_nextn && !cparams.embeddings_nextn_masked) {
// unmasked: nextn row exists for every token in the batch, not just
// those flagged via batch.logits[i] -> size by token count instead.
embd_nextn.size = (size_t) n_embd * n_batch;
embd_nextn.size = (size_t) n_embd_out * n_batch;
}
// Allocate backend sampling output buffers if there are backend samplers configured.
@@ -3375,6 +3386,7 @@ llama_context_params llama_context_default_params() {
/*.kv_unified =*/ false,
/*.sampler =*/ nullptr,
/*.n_sampler =*/ 0,
/*.ctx_other =*/ nullptr,
};
return result;
@@ -3454,7 +3466,6 @@ llama_context * llama_init_from_model(
return nullptr;
}
try {
auto * ctx = new llama_context(*model, params);
return ctx;
@@ -3593,6 +3604,14 @@ void llama_set_embeddings_nextn(llama_context * ctx, bool value, bool masked) {
ctx->set_embeddings_nextn(value, masked);
}
llama_memory_t llama_get_memory(const struct llama_context * ctx) {
if (!ctx) {
return nullptr;
}
return ctx->get_memory();
}
float * llama_get_embeddings_nextn(llama_context * ctx) {
ctx->synchronize();
@@ -3656,7 +3675,7 @@ struct ggml_cgraph * llama_graph_reserve(
uint32_t n_tokens,
uint32_t n_seqs,
uint32_t n_outputs) {
auto * memory = ctx->get_memory();
auto memory = ctx->get_memory();
llama_memory_context_ptr mctx;
if (memory) {
mctx = memory->init_full();
@@ -3696,10 +3715,6 @@ int32_t llama_set_adapter_cvec(
// memory
//
llama_memory_t llama_get_memory(const struct llama_context * ctx) {
return ctx->get_memory();
}
void llama_memory_clear(llama_memory_t mem, bool data) {
if (!mem) {
return;
@@ -4010,3 +4025,7 @@ void llama_opt_epoch(
llama_memory_breakdown llama_get_memory_breakdown(const struct llama_context * ctx) {
return ctx->memory_breakdown();
}
llama_context * llama_get_ctx_other(struct llama_context * ctx) {
return ctx->get_cparams().ctx_other;
}

View File

@@ -6,6 +6,7 @@
#include "llama-graph.h"
#include "llama-adapter.h"
#include "llama-impl.h"
#include "llama-memory.h"
#include "ggml-cpp.h"
#include "ggml-opt.h"
@@ -273,7 +274,7 @@ private:
llama_cross cross; // TODO: tmp for handling cross-attention - need something better probably
std::unique_ptr<llama_memory_i> memory;
llama_memory_ptr memory;
// decode output (2-dimensional array: [n_outputs][n_vocab])
buffer_view<float> logits = {nullptr, 0};

View File

@@ -49,4 +49,6 @@ struct llama_cparams {
ggml_backend_sched_eval_callback cb_eval;
void * cb_eval_user_data;
llama_context * ctx_other;
};

View File

@@ -100,3 +100,5 @@ LLAMA_API float * llama_get_embeddings_nextn(struct llama_context * ctx);
// LLAMA_API float * llama_get_embeddings_ith(struct llama_context * ctx, int32_t i);
LLAMA_API float * llama_get_embeddings_nextn_ith(struct llama_context * ctx, int32_t i);
LLAMA_API llama_context * llama_get_ctx_other(struct llama_context * ctx);

View File

@@ -397,7 +397,7 @@ static void print_mask(const T * data, int64_t n_tokens, int64_t n_kv, int64_t n
case LLAMA_SWA_TYPE_SYMMETRIC: swa_type_str = "LLAMA_SWA_TYPE_SYMMETRIC"; break;
};
LLAMA_LOG_DEBUG("%s: n_swa : %d, n_kv: %d, swq_type: %s\n", __func__, (int)n_swa, (int)n_kv, swa_type_str);
LLAMA_LOG_DEBUG("%s: n_swa : %d, n_kv: %d, swa_type: %s\n", __func__, (int)n_swa, (int)n_kv, swa_type_str);
LLAMA_LOG_DEBUG("%s: '0' = can attend, '∞' = masked\n", __func__);
LLAMA_LOG_DEBUG("%s: Rows = query tokens, Columns = key/value tokens\n\n", __func__);
@@ -565,18 +565,18 @@ void llm_graph_input_attn_kv_iswa::set_input(const llama_ubatch * ubatch) {
if (self_k_idxs && self_k_idxs->buffer) {
mctx->get_base()->set_input_k_idxs(self_k_idxs, ubatch);
mctx->get_base()->set_input_v_idxs(self_v_idxs, ubatch);
mctx->get_base()->set_input_kq_mask(self_kq_mask, ubatch, cparams.causal_attn);
}
mctx->get_base()->set_input_kq_mask(self_kq_mask, ubatch, cparams.causal_attn);
// swa tensors may not be allocated if there are no SWA attention layers
if (self_k_idxs_swa && self_k_idxs_swa->buffer) {
mctx->get_swa()->set_input_k_idxs(self_k_idxs_swa, ubatch);
mctx->get_swa()->set_input_v_idxs(self_v_idxs_swa, ubatch);
mctx->get_swa()->set_input_kq_mask(self_kq_mask_swa, ubatch, cparams.causal_attn);
}
mctx->get_swa()->set_input_kq_mask(self_kq_mask_swa, ubatch, cparams.causal_attn);
if (self_k_rot) {
mctx->get_base()->set_input_k_rot(self_k_rot);
}
@@ -605,18 +605,18 @@ bool llm_graph_input_attn_kv_iswa::can_reuse(const llm_graph_params & params) {
if (self_k_idxs && self_k_idxs->buffer) {
res &= self_k_idxs->ne[0] == params.ubatch.n_tokens;
//res &= self_v_idxs->ne[0] == params.ubatch.n_tokens; // TODO: need to move this to the unified cache and check there
res &= can_reuse_kq_mask(self_kq_mask, mctx->get_base(), params.ubatch, params.cparams);
}
res &= can_reuse_kq_mask(self_kq_mask, mctx->get_base(), params.ubatch, params.cparams);
// swa tensors may not be allocated if there are no SWA attention layers
if (self_k_idxs_swa && self_k_idxs_swa->buffer) {
res &= self_k_idxs_swa->ne[0] == params.ubatch.n_tokens;
//res &= self_v_idxs_swa->ne[0] == params.ubatch.n_tokens; // TODO: need to move this to the unified cache and check there
res &= can_reuse_kq_mask(self_kq_mask_swa, mctx->get_swa(), params.ubatch, params.cparams);
}
res &= can_reuse_kq_mask(self_kq_mask_swa, mctx->get_swa(), params.ubatch, params.cparams);
return res;
}
@@ -756,7 +756,9 @@ void llm_graph_input_mem_hybrid_iswa::set_input(const llama_ubatch * ubatch) {
if (inp_attn->self_k_idxs && inp_attn->self_k_idxs->buffer) {
attn_ctx->get_base()->set_input_k_idxs(inp_attn->self_k_idxs, ubatch);
attn_ctx->get_base()->set_input_v_idxs(inp_attn->self_v_idxs, ubatch);
}
if (inp_attn->self_kq_mask && inp_attn->self_kq_mask->buffer) {
attn_ctx->get_base()->set_input_kq_mask(inp_attn->self_kq_mask, ubatch, cparams.causal_attn);
}
@@ -764,7 +766,9 @@ void llm_graph_input_mem_hybrid_iswa::set_input(const llama_ubatch * ubatch) {
if (inp_attn->self_k_idxs_swa && inp_attn->self_k_idxs_swa->buffer) {
attn_ctx->get_swa()->set_input_k_idxs(inp_attn->self_k_idxs_swa, ubatch);
attn_ctx->get_swa()->set_input_v_idxs(inp_attn->self_v_idxs_swa, ubatch);
}
if (inp_attn->self_kq_mask_swa && inp_attn->self_kq_mask_swa->buffer) {
attn_ctx->get_swa()->set_input_kq_mask(inp_attn->self_kq_mask_swa, ubatch, cparams.causal_attn);
}
@@ -810,18 +814,18 @@ bool llm_graph_input_mem_hybrid_iswa::can_reuse(const llm_graph_params & params)
if (inp_attn->self_k_idxs && inp_attn->self_k_idxs->buffer) {
res &= inp_attn->self_k_idxs->ne[0] == params.ubatch.n_tokens;
//res &= inp_attn->self_v_idxs->ne[0] == params.ubatch.n_tokens; // TODO: need to move this to the unified cache and check there
res &= can_reuse_kq_mask(inp_attn->self_kq_mask, attn_ctx->get_base(), params.ubatch, params.cparams);
}
res &= can_reuse_kq_mask(inp_attn->self_kq_mask, attn_ctx->get_base(), params.ubatch, params.cparams);
// swa tensors may not be allocated if there are no SWA attention layers
if (inp_attn->self_k_idxs_swa && inp_attn->self_k_idxs_swa->buffer) {
res &= inp_attn->self_k_idxs_swa->ne[0] == params.ubatch.n_tokens;
//res &= inp_attn->self_v_idxs_swa->ne[0] == params.ubatch.n_tokens; // TODO: need to move this to the unified cache and check there
res &= can_reuse_kq_mask(inp_attn->self_kq_mask_swa, attn_ctx->get_swa(), params.ubatch, params.cparams);
}
res &= can_reuse_kq_mask(inp_attn->self_kq_mask_swa, attn_ctx->get_swa(), params.ubatch, params.cparams);
res &= inp_rs->s_copy->ne[0] == mctx->get_recr()->get_n_rs();
res &= inp_rs->s_copy_main->ne[0] == params.ubatch.n_seqs;
@@ -1006,6 +1010,7 @@ llm_graph_context::llm_graph_context(const llm_graph_params & params) :
ubatch (params.ubatch),
n_embd (hparams.n_embd),
n_layer (hparams.n_layer()),
n_layer_nextn (hparams.n_layer_nextn),
n_rot (hparams.n_rot()),
n_ctx (cparams.n_ctx),
n_head (hparams.n_head()),

View File

@@ -784,6 +784,7 @@ struct llm_graph_context {
const int64_t n_embd;
const int64_t n_layer;
const int64_t n_layer_nextn;
const int64_t n_rot;
const int64_t n_ctx; // user-specified context size (can be different from n_ctx_train)
const int64_t n_head;

View File

@@ -91,6 +91,10 @@ uint32_t llama_hparams::n_rot(uint32_t il) const {
}
uint32_t llama_hparams::n_embd_inp() const {
if (n_embd_inp_impl > 0) {
return n_embd_inp_impl;
}
uint32_t n_embd_inp = n_embd;
if (n_deepstack_layers > 0) {

View File

@@ -185,6 +185,9 @@ struct llama_hparams {
// for Classifiers
uint32_t n_cls_out = 1;
// input embedding dimension (0 = use n_embd)
uint32_t n_embd_inp_impl = 0;
// output embedding dimension (0 = use n_embd)
uint32_t n_embd_out_impl = 0;
@@ -224,6 +227,7 @@ struct llama_hparams {
// complex mapping. If using deepstack_mapping_arr, also make sure to set
// n_deepstack_layers to the number of unique deepstack layers so that
// n_embd_imp is accurate (see granite.cpp).
// TODO: can be expressed via the `new n_embd_inp_impl` and remove this param
uint32_t n_deepstack_layers = 0;
// deepstack layer array (Granite4 Vision)

View File

@@ -32,7 +32,7 @@ llama_kv_cache_dsa::llama_kv_cache_dsa(
kv_mla = std::make_unique<llama_kv_cache>(
model, model.hparams, type_k, type_v,
v_trans, offload, unified, kv_size, n_seq_max, n_pad,
n_swa, swa_type, filter, reuse);
n_swa, swa_type, nullptr, filter, reuse, nullptr);
// we use llama_kv_cache for caching indexer keys
// by hand-tweaking some hparams we fool it to create
@@ -49,7 +49,7 @@ llama_kv_cache_dsa::llama_kv_cache_dsa(
kv_lid = std::make_unique<llama_kv_cache>(
model, hparams_lid, type_k, type_v,
v_trans, offload, unified, kv_size, n_seq_max, n_pad,
n_swa, swa_type, filter, reuse);
n_swa, swa_type, nullptr, filter, reuse, nullptr);
}
void llama_kv_cache_dsa::clear(bool data) {

View File

@@ -23,8 +23,10 @@ llama_kv_cache_iswa::llama_kv_cache_iswa(
uint32_t n_seq_max,
uint32_t n_ubatch,
uint32_t n_pad,
llama_memory_t mem_other,
const layer_filter_cb & filter,
const layer_reuse_cb & reuse) : hparams(model.hparams), unified(unified) {
const layer_reuse_cb & reuse,
const layer_share_cb & share) : hparams(model.hparams), unified(unified) {
// chain filters
const layer_filter_cb filter_base = [&](int32_t il) {
@@ -59,17 +61,27 @@ llama_kv_cache_iswa::llama_kv_cache_iswa(
LLAMA_LOG_INFO("%s: creating non-SWA KV cache, size = %u cells\n", __func__, size_base);
llama_memory_t mem_other_base = nullptr;
if (mem_other) {
mem_other_base = static_cast<llama_kv_cache_iswa *>(mem_other)->get_base();
}
llama_memory_t mem_other_swa = nullptr;
if (mem_other) {
mem_other_swa = static_cast<llama_kv_cache_iswa *>(mem_other)->get_swa();
}
kv_base = std::make_unique<llama_kv_cache>(
model, hparams, type_k, type_v,
v_trans, offload, unified, size_base, n_seq_max, n_pad,
0, LLAMA_SWA_TYPE_NONE, filter_base, reuse);
0, LLAMA_SWA_TYPE_NONE, mem_other_base, filter_base, reuse, share);
LLAMA_LOG_INFO("%s: creating SWA KV cache, size = %u cells\n", __func__, size_swa);
kv_swa = std::make_unique<llama_kv_cache>(
model, hparams, type_k, type_v,
v_trans, offload, unified, size_swa, n_seq_max, n_pad,
hparams.n_swa, hparams.swa_type, filter_swa, reuse);
hparams.n_swa, hparams.swa_type, mem_other_swa, filter_swa, reuse, share);
}
void llama_kv_cache_iswa::clear(bool data) {

View File

@@ -25,8 +25,10 @@ public:
uint32_t n_seq_max,
uint32_t n_ubatch,
uint32_t n_pad,
llama_memory_t mem_other,
const layer_filter_cb & filter,
const layer_reuse_cb & reuse);
const layer_reuse_cb & reuse,
const layer_share_cb & share);
~llama_kv_cache_iswa() = default;

View File

@@ -90,8 +90,10 @@ llama_kv_cache::llama_kv_cache(
uint32_t n_pad,
uint32_t n_swa,
llama_swa_type swa_type,
llama_memory_t mem_other,
const layer_filter_cb & filter,
const layer_reuse_cb & reuse) :
const layer_reuse_cb & reuse,
const layer_share_cb & share) :
model(model), hparams(hparams), v_trans(v_trans),
n_seq_max(n_seq_max), n_stream(unified ? 1 : n_seq_max), n_pad(n_pad), n_swa(n_swa), swa_type(swa_type) {
@@ -160,6 +162,8 @@ llama_kv_cache::llama_kv_cache(
const bool is_mla = hparams.is_mla();
other = static_cast<llama_kv_cache *>(mem_other);
for (uint32_t il = 0; il < n_layer; il++) {
if (!hparams.has_kv(il)) {
LLAMA_LOG_DEBUG("%s: layer %3d: does not have KV cache\n", __func__, il);
@@ -171,6 +175,24 @@ llama_kv_cache::llama_kv_cache(
continue;
}
if (share && other) {
const int32_t il_share = share(il);
if (il_share >= 0) {
const auto & layer_share = other->layers[other->map_layer_ids[il_share]];
LLAMA_LOG_WARN("%s: layer %3d: sharing with layer %d. k = %p, v = %p\n", __func__, il, il_share,
layer_share.k->data, layer_share.v->data);
map_layer_ids[il] = layers.size();
layers.push_back(layer_share);
layers.back().il = il;
continue;
}
}
if (n_embd_head_k_all == 0) {
n_embd_head_k_all = (int32_t) hparams.n_embd_head_k(il);
} else if (n_embd_head_k_all > 0 && n_embd_head_k_all != (int32_t) hparams.n_embd_head_k(il)) {
@@ -282,29 +304,38 @@ llama_kv_cache::llama_kv_cache(
ggml_type_name(type_v), (float)memory_size_v / (1024.0f * 1024.0f));
}
const char * LLAMA_ATTN_ROT_DISABLE = getenv("LLAMA_ATTN_ROT_DISABLE");
const bool attn_rot_disable = LLAMA_ATTN_ROT_DISABLE ? atoi(LLAMA_ATTN_ROT_DISABLE) : false;
if (attn_rot_disable) {
LLAMA_LOG_WARN("%s: attention rotation force disabled (LLAMA_ATTN_ROT_DISABLE)\n", __func__);
// TODO: refactor [TAG_KV_CACHE_SHARE_CELLS]
if (other) {
n_embd_head_k_all = other->n_embd_head_k_all;
n_embd_head_v_all = other->n_embd_head_v_all;
attn_rot_k = other->attn_rot_k;
attn_rot_v = other->attn_rot_v;
} else {
const char * LLAMA_ATTN_ROT_DISABLE = getenv("LLAMA_ATTN_ROT_DISABLE");
const bool attn_rot_disable = LLAMA_ATTN_ROT_DISABLE ? atoi(LLAMA_ATTN_ROT_DISABLE) : false;
if (attn_rot_disable) {
LLAMA_LOG_WARN("%s: attention rotation force disabled (LLAMA_ATTN_ROT_DISABLE)\n", __func__);
}
attn_rot_k =
!attn_rot_disable &&
n_embd_head_k_all > 0 &&
ggml_is_quantized(type_k) &&
hparams.n_embd_head_k() % 64 == 0;
// always create Hadamard rotation tensors for DeepSeek V3.2 DSA lightning indexer
if (model.arch == LLM_ARCH_DEEPSEEK32 && hparams.n_embd_head_k_full == hparams.indexer_head_size) {
attn_rot_k = true;
}
attn_rot_v =
!attn_rot_disable &&
n_embd_head_v_all > 0 &&
ggml_is_quantized(type_v) &&
hparams.n_embd_head_v() % 64 == 0;
}
attn_rot_k =
!attn_rot_disable &&
n_embd_head_k_all > 0 &&
ggml_is_quantized(type_k) &&
hparams.n_embd_head_k() % 64 == 0;
// always create Hadamard rotation tensors for DeepSeek V3.2 DSA lightning indexer
if (model.arch == LLM_ARCH_DEEPSEEK32 && hparams.n_embd_head_k_full == hparams.indexer_head_size) {
attn_rot_k = true;
}
attn_rot_v =
!attn_rot_disable &&
n_embd_head_v_all > 0 &&
ggml_is_quantized(type_v) &&
hparams.n_embd_head_v() % 64 == 0;
LLAMA_LOG_INFO("%s: attn_rot_k = %d, n_embd_head_k_all = %d\n", __func__, attn_rot_k, n_embd_head_k_all);
LLAMA_LOG_INFO("%s: attn_rot_v = %d, n_embd_head_k_all = %d\n", __func__, attn_rot_v, n_embd_head_v_all);
@@ -347,6 +378,11 @@ void llama_kv_cache::clear(bool data) {
}
bool llama_kv_cache::seq_rm(llama_seq_id seq_id, llama_pos p0, llama_pos p1) {
// TODO: refactor [TAG_KV_CACHE_SHARE_CELLS]
if (other) {
return true;
}
GGML_ASSERT(seq_id == -1 || (seq_id >= 0 && (size_t) seq_id < seq_to_stream.size()));
if (p0 < 0) {
@@ -410,6 +446,11 @@ bool llama_kv_cache::seq_rm(llama_seq_id seq_id, llama_pos p0, llama_pos p1) {
}
void llama_kv_cache::seq_cp(llama_seq_id seq_id_src, llama_seq_id seq_id_dst, llama_pos p0, llama_pos p1) {
// TODO: refactor [TAG_KV_CACHE_SHARE_CELLS]
if (other) {
return;
}
GGML_ASSERT(seq_id_src >= 0 && (size_t) seq_id_src < seq_to_stream.size());
GGML_ASSERT(seq_id_dst >= 0 && (size_t) seq_id_dst < seq_to_stream.size());
@@ -497,6 +538,11 @@ void llama_kv_cache::seq_cp(llama_seq_id seq_id_src, llama_seq_id seq_id_dst, ll
}
void llama_kv_cache::seq_keep(llama_seq_id seq_id) {
// TODO: refactor [TAG_KV_CACHE_SHARE_CELLS]
if (other) {
return;
}
GGML_ASSERT(seq_id >= 0 && (size_t) seq_id < seq_to_stream.size());
auto & cells = v_cells[seq_to_stream[seq_id]];
@@ -519,6 +565,11 @@ void llama_kv_cache::seq_keep(llama_seq_id seq_id) {
}
void llama_kv_cache::seq_add(llama_seq_id seq_id, llama_pos p0, llama_pos p1, llama_pos shift) {
// TODO: refactor [TAG_KV_CACHE_SHARE_CELLS]
if (other) {
return;
}
GGML_ASSERT(seq_id >= 0 && (size_t) seq_id < seq_to_stream.size());
GGML_ASSERT(hparams.n_pos_per_embd() == 1 && "seq_add() is only supported for n_pos_per_embd() == 1");
@@ -564,6 +615,11 @@ void llama_kv_cache::seq_add(llama_seq_id seq_id, llama_pos p0, llama_pos p1, ll
}
void llama_kv_cache::seq_div(llama_seq_id seq_id, llama_pos p0, llama_pos p1, int d) {
// TODO: refactor [TAG_KV_CACHE_SHARE_CELLS]
if (other) {
return;
}
GGML_ASSERT(seq_id >= 0 && (size_t) seq_id < seq_to_stream.size());
GGML_ASSERT(hparams.n_pos_per_embd() == 1 && "seq_div() is only supported for n_pos_per_embd() == 1");
@@ -598,6 +654,11 @@ void llama_kv_cache::seq_div(llama_seq_id seq_id, llama_pos p0, llama_pos p1, in
}
llama_pos llama_kv_cache::seq_pos_min(llama_seq_id seq_id) const {
// TODO: refactor [TAG_KV_CACHE_SHARE_CELLS]
if (other) {
return other->seq_pos_min(seq_id);
}
GGML_ASSERT(seq_id >= 0 && (size_t) seq_id < seq_to_stream.size());
const auto & cells = v_cells[seq_to_stream[seq_id]];
@@ -606,6 +667,11 @@ llama_pos llama_kv_cache::seq_pos_min(llama_seq_id seq_id) const {
}
llama_pos llama_kv_cache::seq_pos_max(llama_seq_id seq_id) const {
// TODO: refactor [TAG_KV_CACHE_SHARE_CELLS]
if (other) {
return other->seq_pos_max(seq_id);
}
GGML_ASSERT(seq_id >= 0 && (size_t) seq_id < seq_to_stream.size());
const auto & cells = v_cells[seq_to_stream[seq_id]];
@@ -746,6 +812,11 @@ llama_kv_cache::slot_info_vec_t llama_kv_cache::prepare(const std::vector<llama_
}
bool llama_kv_cache::update(llama_context * lctx, bool do_shift, const stream_copy_info & sc_info) {
// TODO: refactor [TAG_KV_CACHE_SHARE_CELLS]
if (other) {
return true;
}
bool updated = false;
auto * sched = lctx->get_sched();
@@ -1021,6 +1092,12 @@ llama_kv_cache::slot_info llama_kv_cache::find_slot(const llama_ubatch & ubatch,
}
void llama_kv_cache::apply_ubatch(const slot_info & sinfo, const llama_ubatch & ubatch) {
// TODO: refactor [TAG_KV_CACHE_SHARE_CELLS]
if (other) {
v_cells = other->v_cells;
return;
}
// keep track of the max sequence position that we would overwrite with this ubatch
// for non-SWA cache, this would be always empty
llama_seq_id seq_pos_max_rm[LLAMA_MAX_SEQ];
@@ -1815,6 +1892,9 @@ void llm_graph_input_k_shift::set_input(const llama_ubatch * ubatch) {
}
ggml_cgraph * llama_kv_cache::build_graph_shift(llm_graph_result * res, llama_context * lctx) const {
// TODO: refactor [TAG_KV_CACHE_SHARE_CELLS]
GGML_ASSERT(!other);
auto * ctx = res->get_ctx();
auto * gf = res->get_gf();
@@ -1860,6 +1940,11 @@ ggml_cgraph * llama_kv_cache::build_graph_shift(llm_graph_result * res, llama_co
}
void llama_kv_cache::state_write(llama_io_write_i & io, llama_seq_id seq_id, llama_state_seq_flags flags) const {
// TODO: refactor [TAG_KV_CACHE_SHARE_CELLS]
if (other) {
return;
}
GGML_UNUSED(flags);
io.write(&n_stream, sizeof(n_stream));
@@ -1925,6 +2010,11 @@ void llama_kv_cache::state_write(llama_io_write_i & io, llama_seq_id seq_id, lla
}
void llama_kv_cache::state_read(llama_io_read_i & io, llama_seq_id seq_id, llama_state_seq_flags flags) {
// TODO: refactor [TAG_KV_CACHE_SHARE_CELLS]
if (other) {
return;
}
GGML_UNUSED(flags);
GGML_ASSERT(seq_id == -1 || (seq_id >= 0 && (size_t) seq_id < seq_to_stream.size()));

View File

@@ -98,7 +98,7 @@ public:
// likely through `struct llama_memory_params`
llama_kv_cache(
const llama_model & model,
const llama_hparams & hparams,
const llama_hparams & hparams,
ggml_type type_k,
ggml_type type_v,
bool v_trans,
@@ -109,8 +109,10 @@ public:
uint32_t n_pad,
uint32_t n_swa,
llama_swa_type swa_type,
llama_memory_t mem_other,
const layer_filter_cb & filter,
const layer_reuse_cb & reuse);
const layer_reuse_cb & reuse,
const layer_share_cb & share);
~llama_kv_cache() = default;
@@ -264,6 +266,9 @@ private:
// note: this is not part of the KV state and it's only used to speed-up the find_slot() method
std::vector<uint32_t> v_heads;
// TODO: temporary until we refactor to be able to share the same cells between 2 kv caches [TAG_KV_CACHE_SHARE_CELLS]
llama_kv_cache * other;
std::vector<llama_kv_cells> v_cells;
// maps from a sequence id to a stream id

View File

@@ -43,9 +43,11 @@ llama_memory_hybrid_iswa::llama_memory_hybrid_iswa(
n_seq_max,
n_ubatch,
n_pad,
nullptr,
filter_attn == nullptr ?
[&](int32_t il) { return !hparams.is_recr(il); }
: filter_attn,
nullptr,
nullptr
)),
mem_recr(new llama_memory_recurrent(

View File

@@ -44,9 +44,11 @@ llama_memory_hybrid::llama_memory_hybrid(
n_pad,
n_swa,
swa_type,
nullptr,
filter_attn == nullptr ?
[&](int32_t il) { return !hparams.is_recr(il); }
: filter_attn,
nullptr,
nullptr
)),
mem_recr(new llama_memory_recurrent(

View File

@@ -23,6 +23,8 @@ struct llama_memory_params {
bool swa_full;
llama_context_type ctx_type;
llama_memory_t mem_other;
};
enum llama_memory_status {
@@ -76,6 +78,8 @@ struct llama_memory_i {
// return negative value to indicate that the layer il should not reuse memory
using layer_reuse_cb = std::function<int32_t(int32_t il)>;
using layer_share_cb = std::function<int32_t(int32_t il)>;
virtual ~llama_memory_i() = default;
// split the input batch into a set of ubatches and verify that they can fit into the cache

View File

@@ -139,6 +139,8 @@ static llama_model * llama_model_mapping(llm_arch arch, const llama_model_params
return new llama_model_gemma3n(params);
case LLM_ARCH_GEMMA4:
return new llama_model_gemma4(params);
case LLM_ARCH_GEMMA4_ASSISTANT:
return new llama_model_gemma4_assistant(params);
case LLM_ARCH_GEMMA_EMBEDDING:
return new llama_model_gemma_embedding(params);
case LLM_ARCH_STARCODER2:
@@ -1717,19 +1719,21 @@ void llama_model::print_info() const {
if (!hparams.vocab_only) {
LLAMA_LOG_INFO("%s: n_ctx_train = %u\n", __func__, hparams.n_ctx_train);
LLAMA_LOG_INFO("%s: n_embd = %u\n", __func__, hparams.n_embd);
LLAMA_LOG_INFO("%s: n_embd_inp = %u\n", __func__, hparams.n_embd_inp());
LLAMA_LOG_INFO("%s: n_embd = %u\n", __func__, hparams.n_embd);
LLAMA_LOG_INFO("%s: n_embd_out = %u\n", __func__, hparams.n_embd_out());
LLAMA_LOG_INFO("%s: n_layer = %u\n", __func__, hparams.n_layer());
LLAMA_LOG_INFO("%s: n_head = %s\n", __func__, print_f([&](uint32_t il) { return hparams.n_head(il); }, hparams.n_layer()).c_str());
LLAMA_LOG_INFO("%s: n_head_kv = %s\n", __func__, print_f([&](uint32_t il) { return hparams.n_head_kv(il); }, hparams.n_layer()).c_str());
LLAMA_LOG_INFO("%s: n_layer_all = %u\n", __func__, hparams.n_layer_all);
LLAMA_LOG_INFO("%s: n_head = %s\n", __func__, print_f([&](uint32_t il) { return hparams.n_head(il); }, hparams.n_layer_all).c_str());
LLAMA_LOG_INFO("%s: n_head_kv = %s\n", __func__, print_f([&](uint32_t il) { return hparams.n_head_kv(il); }, hparams.n_layer_all).c_str());
LLAMA_LOG_INFO("%s: n_rot = %u\n", __func__, hparams.n_rot_full);
LLAMA_LOG_INFO("%s: n_swa = %u\n", __func__, hparams.n_swa);
LLAMA_LOG_INFO("%s: is_swa_any = %u\n", __func__, hparams.is_swa_any());
LLAMA_LOG_INFO("%s: n_embd_head_k = %u\n", __func__, hparams.n_embd_head_k_full);
LLAMA_LOG_INFO("%s: n_embd_head_v = %u\n", __func__, hparams.n_embd_head_v_full);
LLAMA_LOG_INFO("%s: n_gqa = %s\n", __func__, print_f([&](uint32_t il) { return hparams.n_gqa(il); }, hparams.n_layer()).c_str());
LLAMA_LOG_INFO("%s: n_embd_k_gqa = %s\n", __func__, print_f([&](uint32_t il) { return hparams.n_embd_k_gqa(il); }, hparams.n_layer()).c_str());
LLAMA_LOG_INFO("%s: n_embd_v_gqa = %s\n", __func__, print_f([&](uint32_t il) { return hparams.n_embd_v_gqa(il); }, hparams.n_layer()).c_str());
LLAMA_LOG_INFO("%s: n_gqa = %s\n", __func__, print_f([&](uint32_t il) { return hparams.n_gqa(il); }, hparams.n_layer_all).c_str());
LLAMA_LOG_INFO("%s: n_embd_k_gqa = %s\n", __func__, print_f([&](uint32_t il) { return hparams.n_embd_k_gqa(il); }, hparams.n_layer_all).c_str());
LLAMA_LOG_INFO("%s: n_embd_v_gqa = %s\n", __func__, print_f([&](uint32_t il) { return hparams.n_embd_v_gqa(il); }, hparams.n_layer_all).c_str());
LLAMA_LOG_INFO("%s: f_norm_eps = %.1e\n", __func__, hparams.f_norm_eps);
LLAMA_LOG_INFO("%s: f_norm_rms_eps = %.1e\n", __func__, hparams.f_norm_rms_eps);
LLAMA_LOG_INFO("%s: f_clamp_kqv = %.1e\n", __func__, hparams.f_clamp_kqv);
@@ -1737,7 +1741,7 @@ void llama_model::print_info() const {
LLAMA_LOG_INFO("%s: f_logit_scale = %.1e\n", __func__, hparams.f_logit_scale);
LLAMA_LOG_INFO("%s: f_attn_scale = %.1e\n", __func__, hparams.f_attention_scale);
LLAMA_LOG_INFO("%s: f_attn_value_scale = %.4f\n", __func__, hparams.f_attn_value_scale);
LLAMA_LOG_INFO("%s: n_ff = %s\n", __func__, print_f([&](uint32_t il) { return hparams.n_ff(il); }, hparams.n_layer()).c_str());
LLAMA_LOG_INFO("%s: n_ff = %s\n", __func__, print_f([&](uint32_t il) { return hparams.n_ff(il); }, hparams.n_layer_all).c_str());
LLAMA_LOG_INFO("%s: n_expert = %u\n", __func__, hparams.n_expert);
LLAMA_LOG_INFO("%s: n_expert_used = %u\n", __func__, hparams.n_expert_used);
LLAMA_LOG_INFO("%s: n_expert_groups = %d\n", __func__, hparams.n_expert_groups);
@@ -1764,7 +1768,7 @@ void llama_model::print_info() const {
[](const auto & entry) { return entry >= 0; })) {
LLAMA_LOG_INFO("%s: deepstack_mapping_arr = %s\n", __func__,
print_f([&](uint32_t il) { return hparams.deepstack_mapping_arr[il]; },
hparams.n_layer()).c_str());
hparams.n_layer_all).c_str());
}
// MRoPE (Multi-axis Rotary Position Embedding) sections
if (const auto & s = hparams.rope_sections; s[0] || s[1] || s[2] || s[3]) {
@@ -2113,8 +2117,9 @@ llama_memory_i * llama_model::create_memory(const llama_memory_params & params,
/* filter_recr */ std::move(filter_recr));
}
} else {
llama_memory_i::layer_reuse_cb reuse = nullptr;
llama_kv_cache::layer_filter_cb filter = nullptr;
llama_memory_i::layer_reuse_cb reuse = nullptr;
llama_kv_cache::layer_share_cb share = nullptr;
if (arch == LLM_ARCH_GEMMA3N || arch == LLM_ARCH_GEMMA4) {
reuse = [&](uint32_t il) {
@@ -2143,20 +2148,53 @@ llama_memory_i * llama_model::create_memory(const llama_memory_params & params,
if (hparams.swa_type != LLAMA_SWA_TYPE_NONE) {
GGML_ASSERT(hparams.is_swa_any());
res = new llama_kv_cache_iswa(
*this,
params.type_k,
params.type_v,
!cparams.flash_attn,
cparams.offload_kqv,
params.swa_full,
cparams.kv_unified,
cparams.n_ctx_seq,
cparams.n_seq_max,
cparams.n_ubatch,
1,
filter,
reuse);
if (arch == LLM_ARCH_GEMMA4_ASSISTANT) {
llama_memory_t mem_other = llama_get_memory(cparams.ctx_other);
share = [&](int32_t il) {
const llama_model * model_other = llama_get_model(cparams.ctx_other);
if (hparams.is_swa(il)) {
return llama_model_n_layer(model_other) - 2;
}
return llama_model_n_layer(model_other) - 1;
};
res = new llama_kv_cache_iswa(
*this,
params.type_k,
params.type_v,
!cparams.flash_attn,
cparams.offload_kqv,
params.swa_full,
cparams.kv_unified,
cparams.n_ctx_seq,
cparams.n_seq_max,
cparams.n_ubatch,
1,
mem_other,
filter,
reuse,
share);
} else {
res = new llama_kv_cache_iswa(
*this,
params.type_k,
params.type_v,
!cparams.flash_attn,
cparams.offload_kqv,
params.swa_full,
cparams.kv_unified,
cparams.n_ctx_seq,
cparams.n_seq_max,
cparams.n_ubatch,
1,
nullptr,
filter,
reuse,
share);
}
} else {
GGML_ASSERT(!hparams.is_swa_any());
@@ -2173,7 +2211,9 @@ llama_memory_i * llama_model::create_memory(const llama_memory_params & params,
1,
hparams.n_swa,
hparams.swa_type,
nullptr,
filter,
nullptr,
nullptr);
}
}
@@ -2406,6 +2446,7 @@ llama_rope_type llama_model_rope_type(const llama_model * model) {
case LLM_ARCH_GEMMA3:
case LLM_ARCH_GEMMA3N:
case LLM_ARCH_GEMMA4:
case LLM_ARCH_GEMMA4_ASSISTANT:
case LLM_ARCH_GEMMA_EMBEDDING:
case LLM_ARCH_STARCODER2:
case LLM_ARCH_OPENELM:

View File

@@ -548,6 +548,10 @@ struct llama_model {
struct ggml_tensor * output_s = nullptr;
struct ggml_tensor * output_in_s = nullptr;
// NextN/MTP model-level projections
struct ggml_tensor * nextn_proj_pre = nullptr;
struct ggml_tensor * nextn_proj_post = nullptr;
// classifier
struct ggml_tensor * cls = nullptr;
struct ggml_tensor * cls_b = nullptr;
@@ -702,6 +706,7 @@ const char * llm_type_name(llm_type type);
#define LLAMA_LOAD_LOCALS \
const int n_layer = hparams.n_layer(); GGML_UNUSED(n_layer); \
const int n_layer_all = hparams.n_layer_all; GGML_UNUSED(n_layer_all); \
const int n_layer_nextn = hparams.n_layer_nextn; GGML_UNUSED(n_layer_nextn); \
const int64_t n_head = hparams.n_head(); GGML_UNUSED(n_head); \
const int64_t n_head_kv = hparams.n_head_kv(); GGML_UNUSED(n_head_kv); \
const int64_t n_embd = hparams.n_embd; GGML_UNUSED(n_embd); \

View File

@@ -0,0 +1,200 @@
#include "models.h"
void llama_model_gemma4_assistant::load_arch_hparams(llama_model_loader & ml) {
hparams.n_embd_inp_impl = hparams.n_embd_out();
hparams.swa_type = LLAMA_SWA_TYPE_STANDARD;
ml.get_key_or_arr(LLM_KV_ATTENTION_SLIDING_WINDOW_PATTERN, hparams.is_swa_impl, hparams.n_layer());
uint32_t n_kv_shared_layers = 0;
ml.get_key(LLM_KV_ATTENTION_SHARED_KV_LAYERS, n_kv_shared_layers, false);
hparams.f_attention_scale = 1.0f;
ml.get_key(LLM_KV_NEXTN_PREDICT_LAYERS, hparams.n_layer_nextn, false);
GGML_ASSERT(hparams.n_layer_nextn == hparams.n_layer_all && "n_layer_nextn must be == n_layer_impl");
ml.get_key(LLM_KV_ROPE_FREQ_BASE_SWA, hparams.rope_freq_base_train_swa, false);
ml.get_key(LLM_KV_ATTENTION_SLIDING_WINDOW, hparams.n_swa);
ml.get_key(LLM_KV_ATTENTION_LAYERNORM_RMS_EPS, hparams.f_norm_rms_eps);
ml.get_key(LLM_KV_ATTENTION_KEY_LENGTH_SWA, hparams.n_embd_head_k_swa);
ml.get_key(LLM_KV_ATTENTION_VALUE_LENGTH_SWA, hparams.n_embd_head_v_swa);
}
void llama_model_gemma4_assistant::load_arch_tensors(llama_model_loader &) {
LLAMA_LOAD_LOCALS;
if (n_embd_head_k != n_embd_head_v) {
throw std::runtime_error("Gemma 4 assistant requires n_embd_head_k == n_embd_head_v");
}
if (hparams.n_embd_head_k_swa != hparams.n_embd_head_v_swa) {
throw std::runtime_error("Gemma 4 assistant requires n_embd_head_k_swa == n_embd_head_v_swa");
}
if (hparams.n_embd_out() == n_embd) {
throw std::runtime_error("Gemma 4 assistant requires embedding_length_out to carry the target hidden size");
}
tok_embd = create_tensor(tn(LLM_TENSOR_TOKEN_EMBD, "weight"), { n_embd, n_vocab }, 0);
output = create_tensor(tn(LLM_TENSOR_TOKEN_EMBD, "weight"), { n_embd, n_vocab }, TENSOR_DUPLICATED);
output_norm = create_tensor(tn(LLM_TENSOR_OUTPUT_NORM, "weight"), { n_embd }, 0);
const int64_t n_embd_backbone = hparams.n_embd_inp();
nextn_proj_post = create_tensor(tn(LLM_TENSOR_NEXTN_PROJ_POST, "weight"), { n_embd, n_embd_backbone }, 0);
int rope_freqs_flag = 0;
for (int i = 0; i < n_layer_nextn; ++i) {
auto & layer = layers[i];
const int64_t n_head = hparams.n_head(i);
const int64_t n_embd_head = hparams.n_embd_head_k(i);
const int64_t n_ff = hparams.n_ff(i);
if (i == 0) {
nextn_proj_pre = create_tensor(tn(LLM_TENSOR_NEXTN_PROJ_PRE, "weight", i), { 2*n_embd_backbone, n_embd }, 0);
}
layer.attn_norm = create_tensor(tn(LLM_TENSOR_ATTN_NORM, "weight", i), { n_embd }, 0);
layer.wq = create_tensor(tn(LLM_TENSOR_ATTN_Q, "weight", i), { n_embd, n_embd_head*n_head }, 0);
layer.wo = create_tensor(tn(LLM_TENSOR_ATTN_OUT, "weight", i), { n_embd_head*n_head, n_embd }, 0);
layer.attn_q_norm = create_tensor(tn(LLM_TENSOR_ATTN_Q_NORM, "weight", i), { n_embd_head }, 0);
layer.attn_post_norm = create_tensor(tn(LLM_TENSOR_ATTN_POST_NORM, "weight", i), { n_embd }, 0);
layer.out_scale = create_tensor(tn(LLM_TENSOR_LAYER_OUT_SCALE, "weight", i), { 1u }, 0);
if (!hparams.is_swa(i)) {
layer.rope_freqs = create_tensor(tn(LLM_TENSOR_ROPE_FREQS, "weight", i), { n_embd_head/2 }, rope_freqs_flag);
rope_freqs_flag = TENSOR_DUPLICATED;
}
layer.ffn_norm = create_tensor(tn(LLM_TENSOR_FFN_NORM, "weight", i), { n_embd }, 0);
layer.ffn_gate = create_tensor(tn(LLM_TENSOR_FFN_GATE, "weight", i), { n_embd, n_ff }, 0);
layer.ffn_up = create_tensor(tn(LLM_TENSOR_FFN_UP, "weight", i), { n_embd, n_ff }, 0);
layer.ffn_down = create_tensor(tn(LLM_TENSOR_FFN_DOWN, "weight", i), { n_ff, n_embd }, 0);
layer.ffn_post_norm = create_tensor(tn(LLM_TENSOR_FFN_POST_NORM, "weight", i), { n_embd }, 0);
}
}
std::unique_ptr<llm_graph_context> llama_model_gemma4_assistant::build_arch_graph(const llm_graph_params & params) const {
return std::make_unique<graph>(*this, params);
}
llama_model_gemma4_assistant::graph::graph(const llama_model & model, const llm_graph_params & params) :
llm_graph_context(params) {
const int64_t n_embd_backbone = hparams.n_embd_inp();
ggml_tensor * inp_tokens;
ggml_tensor * inp_h;
{
auto inp = std::make_unique<llm_graph_input_embd>(n_embd_backbone);
inp->tokens = ggml_new_tensor_1d(ctx0, GGML_TYPE_I32, ubatch.n_tokens);
cb(inp->tokens, "inp_tokens", -1);
ggml_set_input(inp->tokens);
inp_tokens = inp->tokens;
res->t_inp_tokens = inp->tokens;
inp->embd = ggml_new_tensor_2d(ctx0, GGML_TYPE_F32, n_embd_backbone, ubatch.n_tokens);
cb(inp->embd, "inp_h", -1);
ggml_set_input(inp->embd);
inp_h = inp->embd;
res->t_inp_embd = inp->embd;
res->add_input(std::move(inp));
}
GGML_ASSERT(cparams.ctx_other != nullptr);
const auto * model_other = llama_get_model(cparams.ctx_other);
ggml_tensor * x = ggml_get_rows(ctx0, model_other->tok_embd, inp_tokens);
x = ggml_scale(ctx0, x, sqrtf((float) n_embd_backbone));
cb(x, "inp_embd_target", -1);
ggml_tensor * xh = ggml_concat(ctx0, x, inp_h, 0);
cb(xh, "inp_xh", -1);
ggml_tensor * cur = ggml_mul_mat(ctx0, model.nextn_proj_pre, xh);
cb(cur, "pre_proj", -1);
auto * inp_attn = build_attn_inp_kv_iswa();
ggml_tensor * inp_pos = build_inp_pos();
ggml_tensor * inp_out_ids = build_inp_out_ids();
ggml_tensor * inpL = cur;
for (int il = 0; il < n_layer_nextn; ++il) {
const bool is_swa = hparams.is_swa(il);
const int64_t n_embd_head = hparams.n_embd_head_k(il);
const int64_t n_head = hparams.n_head(il);
const float freq_base_l = model.get_rope_freq_base(cparams, il);
const float freq_scale_l = model.get_rope_freq_scale(cparams, il);
const int n_rot_l = hparams.n_rot(il);
ggml_tensor * cur_norm = build_norm(inpL, model.layers[il].attn_norm, nullptr, LLM_NORM_RMS, il);
cb(cur_norm, "attn_norm", il);
ggml_tensor * Qcur = build_lora_mm(model.layers[il].wq, cur_norm);
Qcur = ggml_reshape_3d(ctx0, Qcur, n_embd_head, n_head, n_tokens);
Qcur = build_norm(Qcur, model.layers[il].attn_q_norm, nullptr, LLM_NORM_RMS, il);
cb(Qcur, "Qcur_normed", il);
ggml_tensor * freq_factors = is_swa ? nullptr : model.layers[il].rope_freqs;
Qcur = ggml_rope_ext(ctx0, Qcur, inp_pos, freq_factors, n_rot_l, rope_type, n_ctx_orig,
freq_base_l, freq_scale_l, ext_factor, attn_factor, beta_fast, beta_slow);
cb(Qcur, "Qcur_pos", il);
cur = build_attn(inp_attn, model.layers[il].wo, nullptr, nullptr,
Qcur, nullptr, nullptr, nullptr, nullptr, nullptr, hparams.f_attention_scale, il);
if (il == n_layer_nextn - 1 && inp_out_ids) {
cur = ggml_get_rows(ctx0, cur, inp_out_ids);
inpL = ggml_get_rows(ctx0, inpL, inp_out_ids);
}
cur = build_norm(cur, model.layers[il].attn_post_norm, nullptr, LLM_NORM_RMS, il);
cb(cur, "attn_post_norm", il);
ggml_tensor * attn_out = ggml_add(ctx0, cur, inpL);
cb(attn_out, "attn_out", il);
cur = build_norm(attn_out, model.layers[il].ffn_norm, nullptr, LLM_NORM_RMS, il);
cb(cur, "ffn_norm", il);
cur = build_ffn(cur,
model.layers[il].ffn_up, nullptr, nullptr,
model.layers[il].ffn_gate, nullptr, nullptr,
model.layers[il].ffn_down, nullptr, nullptr,
nullptr,
LLM_FFN_GELU, LLM_FFN_PAR, il);
cb(cur, "ffn_out", il);
cur = build_norm(cur, model.layers[il].ffn_post_norm, nullptr, LLM_NORM_RMS, -1);
cb(cur, "ffn_post_norm", il);
cur = ggml_add(ctx0, cur, attn_out);
cur = ggml_mul(ctx0, cur, model.layers[il].out_scale);
cb(cur, "out_scaled", il);
inpL = cur;
}
cur = inpL;
cur = build_norm(cur, model.output_norm, nullptr, LLM_NORM_RMS, -1);
cb(cur, "result_norm", -1);
ggml_tensor * logits = build_lora_mm(model.output, cur);
cb(logits, "result_output", -1);
res->t_logits = logits;
ggml_tensor * h_next = ggml_mul_mat(ctx0, model.nextn_proj_post, cur);
cb(h_next, "h_nextn", -1);
res->t_h_nextn = h_next;
ggml_build_forward_expand(gf, logits);
ggml_build_forward_expand(gf, h_next);
}

View File

@@ -155,12 +155,14 @@ public:
}
virtual ~llm_graph_input_logits_bias() = default;
void set_input(const llama_ubatch *) override {
void set_input(const llama_ubatch * /*ubatch*/) override {
const int64_t n_vocab = arr.size();
ggml_backend_tensor_set(logits_bias, arr.data(), 0, n_vocab*ggml_element_size(logits_bias));
}
// bool can_reuse(const llm_graph_params & params) override;
bool can_reuse(const llm_graph_params & /*params*/) override {
return true;
}
ggml_tensor * logits_bias = nullptr; // F32 [n_vocab]
@@ -270,7 +272,8 @@ llama_model_gemma4::graph::graph(const llama_model & model, const llm_graph_para
}
// TODO @ngxson : strip unused token right after the last KV layer to speed up prompt processing
if (il == n_layer - 1 && inp_out_ids) {
// keep all rows when extracting unmasked nextn embeddings (MTP target needs the hidden state for every token)
if (il == n_layer - 1 && inp_out_ids && cparams.embeddings_nextn_masked) {
cur = ggml_get_rows(ctx0, cur, inp_out_ids);
inpL = ggml_get_rows(ctx0, inpL, inp_out_ids);
}
@@ -370,7 +373,7 @@ llama_model_gemma4::graph::graph(const llama_model & model, const llm_graph_para
ggml_tensor * inp_this_layer = ggml_view_2d_slice(ctx0, inp_per_layer, il); // [n_embd_per_layer, n_tokens]
// TODO @ngxson : improve this
if (il == n_layer - 1 && inp_out_ids) {
if (il == n_layer - 1 && inp_out_ids && cparams.embeddings_nextn_masked) {
inp_this_layer = ggml_get_rows(ctx0, inp_this_layer, inp_out_ids);
}
@@ -401,6 +404,17 @@ llama_model_gemma4::graph::graph(const llama_model & model, const llm_graph_para
model.output_norm, nullptr,
LLM_NORM_RMS, -1);
// Expose the post-output-norm hidden state (the LM-head input feature) so that
// MTP draft contexts can read it via llama_get_embeddings_nextn_ith() as the
// recurrent h input. This matches the reference (transformers/vLLM/SGLang),
// which feeds the drafter the target's post-final-norm hidden state.
cb(cur, "h_nextn", -1);
res->t_h_nextn = cur;
if (!cparams.embeddings_nextn_masked && inp_out_ids) {
cur = ggml_get_rows(ctx0, cur, inp_out_ids);
}
cb(cur, "result_norm", -1);
res->t_embd = cur;

View File

@@ -822,6 +822,19 @@ struct llama_model_gemma4 : public llama_model_base {
};
struct llama_model_gemma4_assistant : public llama_model_base {
llama_model_gemma4_assistant(const struct llama_model_params & params) : llama_model_base(params) {}
void load_arch_hparams(llama_model_loader & ml) override;
void load_arch_tensors(llama_model_loader & ml) override;
struct graph : public llm_graph_context {
graph(const llama_model & model, const llm_graph_params & params);
};
std::unique_ptr<llm_graph_context> build_arch_graph(const llm_graph_params & params) const override;
};
struct llama_model_gemma_embedding : public llama_model_base {
llama_model_gemma_embedding(const struct llama_model_params & params) : llama_model_base(params) {}
void load_arch_hparams(llama_model_loader & ml) override;

View File

@@ -392,7 +392,7 @@ static bool arch_supported(const llm_arch arch) {
if (arch == LLM_ARCH_WAVTOKENIZER_DEC) {
return false; // FIXME CUDA backend crashes.
}
if (arch == LLM_ARCH_GEMMA4) {
if (arch == LLM_ARCH_GEMMA4 || arch == LLM_ARCH_GEMMA4_ASSISTANT) {
return false; // FIXME @ngxson
}
if (arch == LLM_ARCH_LLAMA_EMBED || arch == LLM_ARCH_GEMMA_EMBEDDING || arch == LLM_ARCH_T5ENCODER) {
@@ -447,7 +447,7 @@ static int save_models(const llm_arch target_arch, const size_t seed, const ggml
if (target_arch != LLM_ARCH_UNKNOWN && arch != target_arch) {
continue;
}
if (arch == LLM_ARCH_GEMMA4) {
if (arch == LLM_ARCH_GEMMA4 || arch == LLM_ARCH_GEMMA4_ASSISTANT) {
continue; // FIXME: ISWA KV cache initialization needs more fixture params
}
for (bool moe : {false, true}) {
@@ -550,7 +550,7 @@ static int test_backends(const llm_arch target_arch, const size_t seed, const gg
if (target_arch != LLM_ARCH_UNKNOWN && arch != target_arch) {
continue;
}
if (arch == LLM_ARCH_GEMMA4) {
if (arch == LLM_ARCH_GEMMA4 || arch == LLM_ARCH_GEMMA4_ASSISTANT) {
continue; // FIXME: ISWA KV cache initialization needs more fixture params
}

View File

@@ -1,4 +1,3 @@
#include "server-context.h"
#include "server-chat.h"
#include "server-common.h"
@@ -16,6 +15,11 @@
#include "mtmd.h"
#include "mtmd-helper.h"
#include "ggml-cpp.h"
// TODO: tmp until the mtmd draft processing is refactored [TAG_MTMD_DRAFT_PROCESSING]
#include "../../src/llama-ext.h"
#include <algorithm>
#include <cstddef>
#include <cinttypes>
@@ -884,7 +888,7 @@ private:
has_draft ? "draft model" : "MTP context",
total / (1024.0 * 1024.0));
} catch (const std::exception & e) {
SRV_ERR("[spec] failed to measure %s memory: %s\n",
SRV_WRN("[spec] failed to measure %s memory: %s\n",
has_draft ? "draft model" : "MTP context", e.what());
}
}
@@ -940,16 +944,17 @@ private:
const bool spec_mtp = std::find(params_base.speculative.types.begin(),
params_base.speculative.types.end(),
COMMON_SPECULATIVE_TYPE_DRAFT_MTP) != params_base.speculative.types.end();
if (spec_mtp) {
cparams.ctx_type = LLAMA_CONTEXT_TYPE_MTP;
}
// note: for small models maybe we can set this to the maximum possible draft from all speculative types
// the extra memory for small models is likely negligible?
cparams.n_rs_seq = 0;
ctx_dft.reset(llama_init_from_model(model_dft.get(), cparams));
cparams.n_rs_seq = 0;
cparams.ctx_other = ctx_tgt;
ctx_dft_seq_rm_type = common_context_can_seq_rm(ctx_dft.get());
ctx_dft.reset(llama_init_from_model(model_dft.get(), cparams));
params_base.speculative.draft.ctx_tgt = ctx_tgt;
params_base.speculative.draft.ctx_dft = ctx_dft.get();
@@ -964,6 +969,7 @@ private:
cparams_mtp.type_v = params_base.speculative.draft.cache_type_v;
cparams_mtp.n_rs_seq = 0;
cparams_mtp.n_outputs_max = params_base.n_parallel;
cparams_mtp.ctx_other = ctx_tgt;
ctx_dft.reset(llama_init_from_model(model_tgt, cparams_mtp));
if (ctx_dft == nullptr) {
@@ -971,8 +977,6 @@ private:
return false;
}
ctx_dft_seq_rm_type = common_context_can_seq_rm(ctx_dft.get());
params_base.speculative.draft.ctx_tgt = ctx_tgt;
params_base.speculative.draft.ctx_dft = ctx_dft.get();
}
@@ -1060,6 +1064,10 @@ private:
}
}
if (ctx_dft) {
ctx_dft_seq_rm_type = common_context_can_seq_rm(ctx_dft.get());
}
if (spec) {
SRV_INF("%s", "speculative decoding context initialized\n");
} else {
@@ -2974,10 +2982,11 @@ private:
continue;
}
if (ctx_dft) {
if (ctx_dft && llama_get_ctx_other(ctx_dft.get()) != ctx_tgt) {
// TODO: in the future, figure out how to infuse target embeddings to the images
// for now, we skip this for simplicity
// maybe we simply need to call `common_speculative_process()` on the mtmd batches in the `process_chunk` above?
// [TAG_MTMD_DRAFT_PROCESSING]
res = input_tokens.process_chunk(ctx_dft.get(), mctx, slot.prompt.n_tokens(), slot.prompt.tokens.pos_next(), slot.id, n_tokens_out);
if (res != 0) {
GGML_ABORT("failed to process multi-modal data on draft context\n");