mirror of
https://github.com/ggml-org/llama.cpp.git
synced 2026-06-08 20:12:57 +02:00
Compare commits
1 Commits
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
04eb4c446d |
@@ -3,13 +3,14 @@
|
||||
#include "common.h"
|
||||
#include "ggml.h"
|
||||
#include "llama.h"
|
||||
#include "../src/llama-ext.h" // staging API: llama_set_embeddings_nextn / llama_get_embeddings_nextn_ith (used by MTP)
|
||||
#include "log.h"
|
||||
#include "ngram-cache.h"
|
||||
#include "ngram-map.h"
|
||||
#include "ngram-mod.h"
|
||||
#include "sampling.h"
|
||||
|
||||
#include "../src/llama-ext.h" // staging API: llama_set_embeddings_nextn / llama_get_embeddings_nextn_ith (used by MTP)
|
||||
|
||||
#include <algorithm>
|
||||
#include <cassert>
|
||||
#include <cstring>
|
||||
@@ -418,6 +419,8 @@ struct common_speculative_impl_draft_mtp : public common_speculative_impl {
|
||||
|
||||
int32_t n_embd = 0;
|
||||
|
||||
bool is_mem_shared = false;
|
||||
|
||||
// Per-sequence cross-batch carryover: pair (h_p, x_{p+1}) at MTP pos p+1.
|
||||
// The last h-row of one process() call needs the first token of the NEXT
|
||||
// call to pair with, so it's stashed here until that next call fires.
|
||||
@@ -444,7 +447,9 @@ struct common_speculative_impl_draft_mtp : public common_speculative_impl {
|
||||
auto * ctx_dft = this->params.ctx_dft;
|
||||
GGML_ASSERT(ctx_tgt && ctx_dft && "MTP requires ctx_tgt and ctx_dft to be set");
|
||||
|
||||
n_embd = llama_model_n_embd(llama_get_model(ctx_dft));
|
||||
n_embd = llama_model_n_embd_out(llama_get_model(ctx_dft));
|
||||
GGML_ASSERT(n_embd == llama_model_n_embd(llama_get_model(ctx_tgt)) &&
|
||||
"MTP input row width must match the target h_nextn width");
|
||||
|
||||
LOG_INF("%s: adding speculative implementation 'draft-mtp'\n", __func__);
|
||||
LOG_INF("%s: - n_max=%d, n_min=%d, p_min=%.2f, n_embd=%d, backend_sampling=%d\n", __func__, this->params.n_max, this->params.n_min, this->params.p_min, n_embd, (int) this->params.backend_sampling);
|
||||
@@ -490,6 +495,8 @@ struct common_speculative_impl_draft_mtp : public common_speculative_impl {
|
||||
llama_set_embeddings_nextn(ctx_tgt, true, /*masked*/ false);
|
||||
llama_set_embeddings_nextn(ctx_dft, true, /*masked*/ true);
|
||||
|
||||
is_mem_shared = llama_get_ctx_other(ctx_dft) == ctx_tgt;
|
||||
|
||||
pending_h.assign(n_seq, std::vector<float>(n_embd, 0.0f));
|
||||
|
||||
i_batch_beg.assign(n_seq, -1);
|
||||
@@ -526,9 +533,11 @@ struct common_speculative_impl_draft_mtp : public common_speculative_impl {
|
||||
if (N <= 0) {
|
||||
return;
|
||||
}
|
||||
|
||||
auto * ctx_dft = this->params.ctx_dft;
|
||||
const llama_pos pos_max = llama_memory_seq_pos_max(llama_get_memory(ctx_dft), seq_id);
|
||||
if (pos_max < N - 1) {
|
||||
|
||||
if (pos_max < N - 1 && !is_mem_shared) {
|
||||
LOG_WRN("%s: ctx_dft pos_max=%d < N-1=%d - "
|
||||
"process() hook may not have run on every prefill ubatch "
|
||||
"(need_embd / logits=1 on every prompt position?). "
|
||||
@@ -571,48 +580,42 @@ struct common_speculative_impl_draft_mtp : public common_speculative_impl {
|
||||
|
||||
const size_t row_bytes = (size_t) n_embd * sizeof(float);
|
||||
|
||||
common_batch_clear(batch);
|
||||
// if kv is shared with target (e.g Gemma4), then we can skip this catch-up decode
|
||||
if (!is_mem_shared) {
|
||||
common_batch_clear(batch);
|
||||
|
||||
for (int k = 0; k < n_tokens; ++k) {
|
||||
common_batch_add(batch, batch_in.token[k], batch_in.pos[k], { batch_in.seq_id[k][0] }, 0);
|
||||
}
|
||||
|
||||
// shift the tgt embeddings to the right by one position
|
||||
// assumes that the tokens in the batch are sequential for each sequence
|
||||
// i.e. we cannot have seq_id like this: [0, 0, 0, 1, 1, 0, 1, 1]
|
||||
// ^--- this is a problem
|
||||
// TODO:this is generally true, but would be nice to assert it
|
||||
{
|
||||
const float * h_tgt = llama_get_embeddings_nextn(ctx_tgt);
|
||||
std::memcpy(batch.embd + (size_t) 1 * n_embd, h_tgt, row_bytes * (n_tokens-1));
|
||||
|
||||
//{
|
||||
// // string with seq_ids in the batch
|
||||
// std::stringstream ss;
|
||||
// for (int i = 0; i < n_tokens; ++i) {
|
||||
// ss << batch_in.seq_id[i][0] << ",";
|
||||
// }
|
||||
// LOG_WRN("%s: batch_in.seq_id = %s\n", __func__, ss.str().c_str());
|
||||
//}
|
||||
}
|
||||
|
||||
// fill the pending embeddings from a previous run
|
||||
auto set_h = [&](int idx, const float * h_row) {
|
||||
std::memcpy(batch.embd + (size_t) idx * n_embd, h_row, row_bytes);
|
||||
};
|
||||
|
||||
for (llama_seq_id seq_id = 0; seq_id < (llama_seq_id) n_seq; ++seq_id) {
|
||||
if (i_batch_beg[seq_id] < 0) {
|
||||
continue;
|
||||
for (int k = 0; k < n_tokens; ++k) {
|
||||
common_batch_add(batch, batch_in.token[k], batch_in.pos[k], { batch_in.seq_id[k][0] }, 0);
|
||||
}
|
||||
|
||||
set_h(i_batch_beg[seq_id], pending_h[seq_id].data());
|
||||
}
|
||||
// shift the tgt embeddings to the right by one position
|
||||
// assumes that the tokens in the batch are sequential for each sequence
|
||||
// i.e. we cannot have seq_id like this: [0, 0, 0, 1, 1, 0, 1, 1]
|
||||
// ^--- this is a problem
|
||||
// TODO:this is generally true, but would be nice to assert it
|
||||
{
|
||||
const float * h_tgt = llama_get_embeddings_nextn(ctx_tgt);
|
||||
std::memcpy(batch.embd + (size_t) 1 * n_embd, h_tgt, row_bytes * (n_tokens-1));
|
||||
}
|
||||
|
||||
const int32_t rc = llama_decode(ctx_dft, batch);
|
||||
if (rc != 0) {
|
||||
LOG_ERR("%s: llama_decode(ctx_dft) failed rc=%d (pos=%d)\n", __func__, (int) rc, (int) batch_in.pos[0]);
|
||||
return false;
|
||||
// fill the pending embeddings from a previous run
|
||||
auto set_h = [&](int idx, const float * h_row) {
|
||||
std::memcpy(batch.embd + (size_t) idx * n_embd, h_row, row_bytes);
|
||||
};
|
||||
|
||||
for (llama_seq_id seq_id = 0; seq_id < (llama_seq_id) n_seq; ++seq_id) {
|
||||
if (i_batch_beg[seq_id] < 0) {
|
||||
continue;
|
||||
}
|
||||
|
||||
set_h(i_batch_beg[seq_id], pending_h[seq_id].data());
|
||||
}
|
||||
|
||||
const int32_t rc = llama_decode(ctx_dft, batch);
|
||||
if (rc != 0) {
|
||||
LOG_ERR("%s: llama_decode(ctx_dft) failed rc=%d (pos=%d)\n", __func__, (int) rc, (int) batch_in.pos[0]);
|
||||
return false;
|
||||
}
|
||||
}
|
||||
|
||||
for (llama_seq_id seq_id = 0; seq_id < (llama_seq_id) n_seq; ++seq_id) {
|
||||
@@ -721,7 +724,13 @@ struct common_speculative_impl_draft_mtp : public common_speculative_impl {
|
||||
continue;
|
||||
}
|
||||
|
||||
common_batch_add(batch, id, dp.n_past + i + 1, { seq_id }, true);
|
||||
if (is_mem_shared) {
|
||||
// note: with shared memory (e.g. Gemma4 assistants) we use the same position for all draft tokens
|
||||
// ref: https://github.com/huggingface/transformers/blob/effde20942e3f82a1b97449f60b3a48c5ff96145/docs/source/en/model_doc/gemma4_assistant.md?plain=1#L36-L37
|
||||
common_batch_add(batch, id, dp.n_past, { seq_id }, true);
|
||||
} else {
|
||||
common_batch_add(batch, id, dp.n_past + i + 1, { seq_id }, true);
|
||||
}
|
||||
std::memcpy(batch.embd + n_embd*(batch.n_tokens - 1), h_row, row_bytes);
|
||||
}
|
||||
|
||||
|
||||
@@ -75,9 +75,11 @@ TEXT_MODEL_MAP: dict[str, str] = {
|
||||
"Gemma3TextModel": "gemma",
|
||||
"Gemma3nForCausalLM": "gemma",
|
||||
"Gemma3nForConditionalGeneration": "gemma",
|
||||
"Gemma4AssistantForCausalLM": "gemma",
|
||||
"Gemma4ForConditionalGeneration": "gemma",
|
||||
"Gemma4ForCausalLM": "gemma",
|
||||
"Gemma4UnifiedForConditionalGeneration": "gemma",
|
||||
"Gemma4UnifiedAssistantForCausalLM": "gemma",
|
||||
"GemmaForCausalLM": "gemma",
|
||||
"Glm4ForCausalLM": "glm",
|
||||
"Glm4MoeForCausalLM": "glm",
|
||||
|
||||
@@ -785,6 +785,16 @@ class Gemma4UnifiedModel(Gemma4Model):
|
||||
self.gguf_writer.add_suppress_tokens(suppress_tokens)
|
||||
|
||||
|
||||
@ModelBase.register("Gemma4AssistantForCausalLM", "Gemma4UnifiedAssistantForCausalLM")
|
||||
class Gemma4AssistantModel(Gemma4Model):
|
||||
model_arch = gguf.MODEL_ARCH.GEMMA4_ASSISTANT
|
||||
|
||||
def set_gguf_parameters(self):
|
||||
super().set_gguf_parameters()
|
||||
self.gguf_writer.add_embedding_length_out(self.hparams["backbone_hidden_size"])
|
||||
self.gguf_writer.add_nextn_predict_layers(self.block_count)
|
||||
|
||||
|
||||
@ModelBase.register("Gemma4ForConditionalGeneration")
|
||||
class Gemma4VisionAudioModel(MmprojModel):
|
||||
has_audio_encoder = True
|
||||
|
||||
@@ -440,6 +440,7 @@ class MODEL_ARCH(IntEnum):
|
||||
GEMMA3 = auto()
|
||||
GEMMA3N = auto()
|
||||
GEMMA4 = auto()
|
||||
GEMMA4_ASSISTANT = auto()
|
||||
GEMMA_EMBEDDING = auto()
|
||||
STARCODER2 = auto()
|
||||
RWKV6 = auto()
|
||||
@@ -897,6 +898,8 @@ class MODEL_TENSOR(IntEnum):
|
||||
A_PER_DIM_K_SCALE = auto() # gemma4
|
||||
A_PER_DIM_SCALE = auto() # gemma4
|
||||
# nextn/mtp
|
||||
NEXTN_PROJ_PRE = auto()
|
||||
NEXTN_PROJ_POST = auto()
|
||||
NEXTN_EH_PROJ = auto()
|
||||
NEXTN_EMBED_TOKENS = auto()
|
||||
NEXTN_ENORM = auto()
|
||||
@@ -986,6 +989,7 @@ MODEL_ARCH_NAMES: dict[MODEL_ARCH, str] = {
|
||||
MODEL_ARCH.GEMMA3: "gemma3",
|
||||
MODEL_ARCH.GEMMA3N: "gemma3n",
|
||||
MODEL_ARCH.GEMMA4: "gemma4",
|
||||
MODEL_ARCH.GEMMA4_ASSISTANT: "gemma4-assistant",
|
||||
MODEL_ARCH.GEMMA_EMBEDDING: "gemma-embedding",
|
||||
MODEL_ARCH.STARCODER2: "starcoder2",
|
||||
MODEL_ARCH.RWKV6: "rwkv6",
|
||||
@@ -1471,6 +1475,8 @@ TENSOR_NAMES: dict[MODEL_TENSOR, str] = {
|
||||
MODEL_TENSOR.A_QF_FFN_DOWN: "a.proj_blk.{bid}.ffn_down",
|
||||
MODEL_TENSOR.A_QF_FFN_NORM: "a.proj_blk.{bid}.ffn_norm",
|
||||
# NextN/MTP
|
||||
MODEL_TENSOR.NEXTN_PROJ_PRE: "nextn.pre_projection",
|
||||
MODEL_TENSOR.NEXTN_PROJ_POST: "nextn.post_projection",
|
||||
MODEL_TENSOR.NEXTN_EH_PROJ: "blk.{bid}.nextn.eh_proj",
|
||||
MODEL_TENSOR.NEXTN_EMBED_TOKENS: "blk.{bid}.nextn.embed_tokens",
|
||||
MODEL_TENSOR.NEXTN_ENORM: "blk.{bid}.nextn.enorm",
|
||||
@@ -2577,6 +2583,24 @@ MODEL_TENSORS: dict[MODEL_ARCH, list[MODEL_TENSOR]] = {
|
||||
MODEL_TENSOR.PER_LAYER_PROJ_NORM,
|
||||
MODEL_TENSOR.PER_LAYER_POST_NORM,
|
||||
],
|
||||
MODEL_ARCH.GEMMA4_ASSISTANT: [
|
||||
MODEL_TENSOR.ROPE_FREQS,
|
||||
MODEL_TENSOR.TOKEN_EMBD,
|
||||
MODEL_TENSOR.OUTPUT_NORM,
|
||||
MODEL_TENSOR.NEXTN_PROJ_PRE,
|
||||
MODEL_TENSOR.NEXTN_PROJ_POST,
|
||||
MODEL_TENSOR.ATTN_Q,
|
||||
MODEL_TENSOR.ATTN_Q_NORM,
|
||||
MODEL_TENSOR.ATTN_OUT,
|
||||
MODEL_TENSOR.FFN_GATE,
|
||||
MODEL_TENSOR.FFN_DOWN,
|
||||
MODEL_TENSOR.FFN_UP,
|
||||
MODEL_TENSOR.ATTN_NORM,
|
||||
MODEL_TENSOR.ATTN_POST_NORM,
|
||||
MODEL_TENSOR.FFN_PRE_NORM,
|
||||
MODEL_TENSOR.FFN_POST_NORM,
|
||||
MODEL_TENSOR.LAYER_OUT_SCALE,
|
||||
],
|
||||
MODEL_ARCH.GEMMA_EMBEDDING: [
|
||||
MODEL_TENSOR.TOKEN_EMBD,
|
||||
MODEL_TENSOR.OUTPUT,
|
||||
|
||||
@@ -2367,6 +2367,14 @@ class TensorNameMap:
|
||||
),
|
||||
|
||||
# NextN/MTP tensors
|
||||
MODEL_TENSOR.NEXTN_PROJ_PRE: (
|
||||
"pre_projection",
|
||||
),
|
||||
|
||||
MODEL_TENSOR.NEXTN_PROJ_POST: (
|
||||
"post_projection",
|
||||
),
|
||||
|
||||
MODEL_TENSOR.NEXTN_EH_PROJ: (
|
||||
"model.layers.{bid}.eh_proj",
|
||||
),
|
||||
|
||||
@@ -388,6 +388,10 @@ extern "C" {
|
||||
// note: the samplers must be sampler chains (i.e. use llama_sampler_chain_init)
|
||||
struct llama_sampler_seq_config * samplers;
|
||||
size_t n_samplers;
|
||||
|
||||
// a source/target/parent context
|
||||
// can be utilized in various ways, for example by sharing results or llama_memory between 2 contexts
|
||||
struct llama_context * ctx_other;
|
||||
};
|
||||
|
||||
struct llama_model_tensor_override {
|
||||
|
||||
@@ -57,6 +57,7 @@ static const std::map<llm_arch, const char *> LLM_ARCH_NAMES = {
|
||||
{ LLM_ARCH_GEMMA3, "gemma3" },
|
||||
{ LLM_ARCH_GEMMA3N, "gemma3n" },
|
||||
{ LLM_ARCH_GEMMA4, "gemma4" },
|
||||
{ LLM_ARCH_GEMMA4_ASSISTANT, "gemma4-assistant" },
|
||||
{ LLM_ARCH_GEMMA_EMBEDDING, "gemma-embedding" },
|
||||
{ LLM_ARCH_STARCODER2, "starcoder2" },
|
||||
{ LLM_ARCH_MAMBA, "mamba" },
|
||||
@@ -453,6 +454,8 @@ static const std::map<llm_tensor, const char *> LLM_TENSOR_NAMES = {
|
||||
{ LLM_TENSOR_FFN_NORM_EXPS, "blk.%d.ffn_norm_exps" },
|
||||
{ LLM_TENSOR_ATTN_K_B, "blk.%d.attn_k_b" },
|
||||
{ LLM_TENSOR_ATTN_V_B, "blk.%d.attn_v_b" },
|
||||
{ LLM_TENSOR_NEXTN_PROJ_PRE, "nextn.pre_projection" },
|
||||
{ LLM_TENSOR_NEXTN_PROJ_POST, "nextn.post_projection" },
|
||||
{ LLM_TENSOR_NEXTN_EH_PROJ, "blk.%d.nextn.eh_proj" },
|
||||
{ LLM_TENSOR_NEXTN_EMBED_TOKENS, "blk.%d.nextn.embed_tokens" },
|
||||
{ LLM_TENSOR_NEXTN_ENORM, "blk.%d.nextn.enorm" },
|
||||
@@ -765,6 +768,8 @@ static const std::map<llm_tensor, llm_tensor_info> LLM_TENSOR_INFOS = {
|
||||
{LLM_TENSOR_INDEXER_PROJ, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL_MAT}},
|
||||
{LLM_TENSOR_INDEXER_ATTN_K, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL_MAT}},
|
||||
{LLM_TENSOR_INDEXER_ATTN_Q_B, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL_MAT}},
|
||||
{LLM_TENSOR_NEXTN_PROJ_PRE, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL_MAT}},
|
||||
{LLM_TENSOR_NEXTN_PROJ_POST, {LLM_TENSOR_LAYER_OUTPUT, GGML_OP_MUL_MAT}},
|
||||
// NextN/MTP tensors are stored per-block (blk.%d.nextn.*) even though only the
|
||||
// last nextn_predict_layers blocks carry them. Classify as LAYER_REPEATING so
|
||||
// the model loader doesn't fault on the block index.
|
||||
|
||||
@@ -61,6 +61,7 @@ enum llm_arch {
|
||||
LLM_ARCH_GEMMA3,
|
||||
LLM_ARCH_GEMMA3N,
|
||||
LLM_ARCH_GEMMA4,
|
||||
LLM_ARCH_GEMMA4_ASSISTANT,
|
||||
LLM_ARCH_GEMMA_EMBEDDING,
|
||||
LLM_ARCH_STARCODER2,
|
||||
LLM_ARCH_MAMBA,
|
||||
@@ -557,6 +558,8 @@ enum llm_tensor {
|
||||
LLM_TENSOR_INDEXER_PROJ,
|
||||
LLM_TENSOR_INDEXER_ATTN_K,
|
||||
LLM_TENSOR_INDEXER_ATTN_Q_B,
|
||||
LLM_TENSOR_NEXTN_PROJ_PRE,
|
||||
LLM_TENSOR_NEXTN_PROJ_POST,
|
||||
LLM_TENSOR_NEXTN_EH_PROJ,
|
||||
LLM_TENSOR_NEXTN_EMBED_TOKENS,
|
||||
LLM_TENSOR_NEXTN_ENORM,
|
||||
|
||||
@@ -69,9 +69,10 @@ llama_context::llama_context(
|
||||
cparams.embeddings_nextn_masked = false;
|
||||
cparams.offload_kqv = params.offload_kqv;
|
||||
cparams.no_perf = params.no_perf;
|
||||
cparams.pooling_type = params.pooling_type;
|
||||
cparams.warmup = false;
|
||||
|
||||
cparams.ctx_type = params.ctx_type;
|
||||
cparams.pooling_type = params.pooling_type;
|
||||
|
||||
cparams.n_ctx = params.n_ctx == 0 ? hparams.n_ctx_train : params.n_ctx;
|
||||
cparams.rope_freq_base = params.rope_freq_base == 0.0f ? hparams.rope_freq_base_train : params.rope_freq_base;
|
||||
@@ -84,7 +85,17 @@ llama_context::llama_context(
|
||||
cparams.cb_eval = params.cb_eval;
|
||||
cparams.cb_eval_user_data = params.cb_eval_user_data;
|
||||
|
||||
cparams.ctx_type = params.ctx_type;
|
||||
cparams.ctx_other = nullptr;
|
||||
|
||||
// TODO: more generic
|
||||
if (model.arch == LLM_ARCH_GEMMA4_ASSISTANT) {
|
||||
if (params.ctx_other == nullptr) {
|
||||
// TODO: change from runtime_error to llama_exception to avoid printing error message
|
||||
throw std::runtime_error("Gemma4Assistant requires ctx_other to be set (this is normal during memory fitting)");
|
||||
}
|
||||
|
||||
cparams.ctx_other = params.ctx_other;
|
||||
}
|
||||
|
||||
// Initialize backend samplers here so they are part of the sampling graph
|
||||
// before the reserve passes run later in this function. This avoids a later
|
||||
@@ -300,10 +311,11 @@ llama_context::llama_context(
|
||||
// init the memory module
|
||||
if (!hparams.vocab_only) {
|
||||
llama_memory_params params_mem = {
|
||||
/*.type_k =*/ params.type_k,
|
||||
/*.type_v =*/ params.type_v,
|
||||
/*.swa_full =*/ params.swa_full,
|
||||
/*.ctx_type= */ cparams.ctx_type,
|
||||
/*.type_k =*/ params.type_k,
|
||||
/*.type_v =*/ params.type_v,
|
||||
/*.swa_full =*/ params.swa_full,
|
||||
/*.ctx_type =*/ cparams.ctx_type,
|
||||
/*.mem_other =*/ llama_get_memory(cparams.ctx_other),
|
||||
};
|
||||
|
||||
memory.reset(model.create_memory(params_mem, cparams));
|
||||
@@ -904,7 +916,7 @@ float * llama_context::get_embeddings_nextn_ith(int32_t i) {
|
||||
throw std::runtime_error("no nextn embeddings");
|
||||
}
|
||||
|
||||
const uint32_t n_embd = model.hparams.n_embd;
|
||||
const uint32_t n_embd = model.hparams.n_embd_out();
|
||||
|
||||
if (!cparams.embeddings_nextn_masked) {
|
||||
// unmasked: nextn rows are stored densely, indexed by raw token position.
|
||||
@@ -1473,7 +1485,7 @@ int llama_context::encode(const llama_batch & batch_inp) {
|
||||
ggml_backend_t backend_h = ggml_backend_sched_get_tensor_backend(sched.get(), t_h_nextn);
|
||||
GGML_ASSERT(backend_h != nullptr);
|
||||
|
||||
const uint32_t n_embd = hparams.n_embd;
|
||||
const uint32_t n_embd = hparams.n_embd_out();
|
||||
GGML_ASSERT(n_tokens*n_embd <= (int64_t) embd_nextn.size);
|
||||
ggml_backend_tensor_get_async(backend_h, t_h_nextn, embd_nextn.data, 0, n_tokens*n_embd*sizeof(float));
|
||||
}
|
||||
@@ -1924,7 +1936,7 @@ int llama_context::decode(const llama_batch & batch_inp) {
|
||||
ggml_backend_t backend_h = ggml_backend_sched_get_tensor_backend(sched.get(), t_h_nextn);
|
||||
GGML_ASSERT(backend_h != nullptr);
|
||||
|
||||
const uint32_t n_embd = hparams.n_embd;
|
||||
const uint32_t n_embd = hparams.n_embd_out();
|
||||
float * embd_nextn_out = embd_nextn.data + offset*n_embd;
|
||||
|
||||
GGML_ASSERT((offset + n_rows)*n_embd <= (int64_t) embd_nextn.size);
|
||||
@@ -2017,7 +2029,6 @@ uint32_t llama_context::output_reserve(int32_t n_outputs) {
|
||||
|
||||
const auto n_batch = cparams.n_batch;
|
||||
const auto n_vocab = vocab.n_tokens();
|
||||
const auto n_embd = hparams.n_embd;
|
||||
const auto n_embd_out = hparams.n_embd_out();
|
||||
|
||||
bool has_logits = true;
|
||||
@@ -2036,12 +2047,12 @@ uint32_t llama_context::output_reserve(int32_t n_outputs) {
|
||||
|
||||
logits.size = has_logits ? n_vocab*n_outputs_max : 0;
|
||||
embd.size = has_embd ? n_embd_out*n_outputs_max : 0;
|
||||
embd_nextn.size = has_embd_nextn ? n_embd*n_outputs_max : 0;
|
||||
embd_nextn.size = has_embd_nextn ? n_embd_out*n_outputs_max : 0;
|
||||
|
||||
if (has_embd_nextn && !cparams.embeddings_nextn_masked) {
|
||||
// unmasked: nextn row exists for every token in the batch, not just
|
||||
// those flagged via batch.logits[i] -> size by token count instead.
|
||||
embd_nextn.size = (size_t) n_embd * n_batch;
|
||||
embd_nextn.size = (size_t) n_embd_out * n_batch;
|
||||
}
|
||||
|
||||
// Allocate backend sampling output buffers if there are backend samplers configured.
|
||||
@@ -3375,6 +3386,7 @@ llama_context_params llama_context_default_params() {
|
||||
/*.kv_unified =*/ false,
|
||||
/*.sampler =*/ nullptr,
|
||||
/*.n_sampler =*/ 0,
|
||||
/*.ctx_other =*/ nullptr,
|
||||
};
|
||||
|
||||
return result;
|
||||
@@ -3454,7 +3466,6 @@ llama_context * llama_init_from_model(
|
||||
return nullptr;
|
||||
}
|
||||
|
||||
|
||||
try {
|
||||
auto * ctx = new llama_context(*model, params);
|
||||
return ctx;
|
||||
@@ -3593,6 +3604,14 @@ void llama_set_embeddings_nextn(llama_context * ctx, bool value, bool masked) {
|
||||
ctx->set_embeddings_nextn(value, masked);
|
||||
}
|
||||
|
||||
llama_memory_t llama_get_memory(const struct llama_context * ctx) {
|
||||
if (!ctx) {
|
||||
return nullptr;
|
||||
}
|
||||
|
||||
return ctx->get_memory();
|
||||
}
|
||||
|
||||
float * llama_get_embeddings_nextn(llama_context * ctx) {
|
||||
ctx->synchronize();
|
||||
|
||||
@@ -3656,7 +3675,7 @@ struct ggml_cgraph * llama_graph_reserve(
|
||||
uint32_t n_tokens,
|
||||
uint32_t n_seqs,
|
||||
uint32_t n_outputs) {
|
||||
auto * memory = ctx->get_memory();
|
||||
auto memory = ctx->get_memory();
|
||||
llama_memory_context_ptr mctx;
|
||||
if (memory) {
|
||||
mctx = memory->init_full();
|
||||
@@ -3696,10 +3715,6 @@ int32_t llama_set_adapter_cvec(
|
||||
// memory
|
||||
//
|
||||
|
||||
llama_memory_t llama_get_memory(const struct llama_context * ctx) {
|
||||
return ctx->get_memory();
|
||||
}
|
||||
|
||||
void llama_memory_clear(llama_memory_t mem, bool data) {
|
||||
if (!mem) {
|
||||
return;
|
||||
@@ -4010,3 +4025,7 @@ void llama_opt_epoch(
|
||||
llama_memory_breakdown llama_get_memory_breakdown(const struct llama_context * ctx) {
|
||||
return ctx->memory_breakdown();
|
||||
}
|
||||
|
||||
llama_context * llama_get_ctx_other(struct llama_context * ctx) {
|
||||
return ctx->get_cparams().ctx_other;
|
||||
}
|
||||
|
||||
@@ -6,6 +6,7 @@
|
||||
#include "llama-graph.h"
|
||||
#include "llama-adapter.h"
|
||||
#include "llama-impl.h"
|
||||
#include "llama-memory.h"
|
||||
|
||||
#include "ggml-cpp.h"
|
||||
#include "ggml-opt.h"
|
||||
@@ -273,7 +274,7 @@ private:
|
||||
|
||||
llama_cross cross; // TODO: tmp for handling cross-attention - need something better probably
|
||||
|
||||
std::unique_ptr<llama_memory_i> memory;
|
||||
llama_memory_ptr memory;
|
||||
|
||||
// decode output (2-dimensional array: [n_outputs][n_vocab])
|
||||
buffer_view<float> logits = {nullptr, 0};
|
||||
|
||||
@@ -49,4 +49,6 @@ struct llama_cparams {
|
||||
|
||||
ggml_backend_sched_eval_callback cb_eval;
|
||||
void * cb_eval_user_data;
|
||||
|
||||
llama_context * ctx_other;
|
||||
};
|
||||
|
||||
@@ -100,3 +100,5 @@ LLAMA_API float * llama_get_embeddings_nextn(struct llama_context * ctx);
|
||||
|
||||
// LLAMA_API float * llama_get_embeddings_ith(struct llama_context * ctx, int32_t i);
|
||||
LLAMA_API float * llama_get_embeddings_nextn_ith(struct llama_context * ctx, int32_t i);
|
||||
|
||||
LLAMA_API llama_context * llama_get_ctx_other(struct llama_context * ctx);
|
||||
|
||||
@@ -397,7 +397,7 @@ static void print_mask(const T * data, int64_t n_tokens, int64_t n_kv, int64_t n
|
||||
case LLAMA_SWA_TYPE_SYMMETRIC: swa_type_str = "LLAMA_SWA_TYPE_SYMMETRIC"; break;
|
||||
};
|
||||
|
||||
LLAMA_LOG_DEBUG("%s: n_swa : %d, n_kv: %d, swq_type: %s\n", __func__, (int)n_swa, (int)n_kv, swa_type_str);
|
||||
LLAMA_LOG_DEBUG("%s: n_swa : %d, n_kv: %d, swa_type: %s\n", __func__, (int)n_swa, (int)n_kv, swa_type_str);
|
||||
LLAMA_LOG_DEBUG("%s: '0' = can attend, '∞' = masked\n", __func__);
|
||||
LLAMA_LOG_DEBUG("%s: Rows = query tokens, Columns = key/value tokens\n\n", __func__);
|
||||
|
||||
@@ -565,18 +565,18 @@ void llm_graph_input_attn_kv_iswa::set_input(const llama_ubatch * ubatch) {
|
||||
if (self_k_idxs && self_k_idxs->buffer) {
|
||||
mctx->get_base()->set_input_k_idxs(self_k_idxs, ubatch);
|
||||
mctx->get_base()->set_input_v_idxs(self_v_idxs, ubatch);
|
||||
|
||||
mctx->get_base()->set_input_kq_mask(self_kq_mask, ubatch, cparams.causal_attn);
|
||||
}
|
||||
|
||||
mctx->get_base()->set_input_kq_mask(self_kq_mask, ubatch, cparams.causal_attn);
|
||||
|
||||
// swa tensors may not be allocated if there are no SWA attention layers
|
||||
if (self_k_idxs_swa && self_k_idxs_swa->buffer) {
|
||||
mctx->get_swa()->set_input_k_idxs(self_k_idxs_swa, ubatch);
|
||||
mctx->get_swa()->set_input_v_idxs(self_v_idxs_swa, ubatch);
|
||||
|
||||
mctx->get_swa()->set_input_kq_mask(self_kq_mask_swa, ubatch, cparams.causal_attn);
|
||||
}
|
||||
|
||||
mctx->get_swa()->set_input_kq_mask(self_kq_mask_swa, ubatch, cparams.causal_attn);
|
||||
|
||||
if (self_k_rot) {
|
||||
mctx->get_base()->set_input_k_rot(self_k_rot);
|
||||
}
|
||||
@@ -605,18 +605,18 @@ bool llm_graph_input_attn_kv_iswa::can_reuse(const llm_graph_params & params) {
|
||||
if (self_k_idxs && self_k_idxs->buffer) {
|
||||
res &= self_k_idxs->ne[0] == params.ubatch.n_tokens;
|
||||
//res &= self_v_idxs->ne[0] == params.ubatch.n_tokens; // TODO: need to move this to the unified cache and check there
|
||||
|
||||
res &= can_reuse_kq_mask(self_kq_mask, mctx->get_base(), params.ubatch, params.cparams);
|
||||
}
|
||||
|
||||
res &= can_reuse_kq_mask(self_kq_mask, mctx->get_base(), params.ubatch, params.cparams);
|
||||
|
||||
// swa tensors may not be allocated if there are no SWA attention layers
|
||||
if (self_k_idxs_swa && self_k_idxs_swa->buffer) {
|
||||
res &= self_k_idxs_swa->ne[0] == params.ubatch.n_tokens;
|
||||
//res &= self_v_idxs_swa->ne[0] == params.ubatch.n_tokens; // TODO: need to move this to the unified cache and check there
|
||||
|
||||
res &= can_reuse_kq_mask(self_kq_mask_swa, mctx->get_swa(), params.ubatch, params.cparams);
|
||||
}
|
||||
|
||||
res &= can_reuse_kq_mask(self_kq_mask_swa, mctx->get_swa(), params.ubatch, params.cparams);
|
||||
|
||||
return res;
|
||||
}
|
||||
|
||||
@@ -756,7 +756,9 @@ void llm_graph_input_mem_hybrid_iswa::set_input(const llama_ubatch * ubatch) {
|
||||
if (inp_attn->self_k_idxs && inp_attn->self_k_idxs->buffer) {
|
||||
attn_ctx->get_base()->set_input_k_idxs(inp_attn->self_k_idxs, ubatch);
|
||||
attn_ctx->get_base()->set_input_v_idxs(inp_attn->self_v_idxs, ubatch);
|
||||
}
|
||||
|
||||
if (inp_attn->self_kq_mask && inp_attn->self_kq_mask->buffer) {
|
||||
attn_ctx->get_base()->set_input_kq_mask(inp_attn->self_kq_mask, ubatch, cparams.causal_attn);
|
||||
}
|
||||
|
||||
@@ -764,7 +766,9 @@ void llm_graph_input_mem_hybrid_iswa::set_input(const llama_ubatch * ubatch) {
|
||||
if (inp_attn->self_k_idxs_swa && inp_attn->self_k_idxs_swa->buffer) {
|
||||
attn_ctx->get_swa()->set_input_k_idxs(inp_attn->self_k_idxs_swa, ubatch);
|
||||
attn_ctx->get_swa()->set_input_v_idxs(inp_attn->self_v_idxs_swa, ubatch);
|
||||
}
|
||||
|
||||
if (inp_attn->self_kq_mask_swa && inp_attn->self_kq_mask_swa->buffer) {
|
||||
attn_ctx->get_swa()->set_input_kq_mask(inp_attn->self_kq_mask_swa, ubatch, cparams.causal_attn);
|
||||
}
|
||||
|
||||
@@ -810,18 +814,18 @@ bool llm_graph_input_mem_hybrid_iswa::can_reuse(const llm_graph_params & params)
|
||||
if (inp_attn->self_k_idxs && inp_attn->self_k_idxs->buffer) {
|
||||
res &= inp_attn->self_k_idxs->ne[0] == params.ubatch.n_tokens;
|
||||
//res &= inp_attn->self_v_idxs->ne[0] == params.ubatch.n_tokens; // TODO: need to move this to the unified cache and check there
|
||||
|
||||
res &= can_reuse_kq_mask(inp_attn->self_kq_mask, attn_ctx->get_base(), params.ubatch, params.cparams);
|
||||
}
|
||||
|
||||
res &= can_reuse_kq_mask(inp_attn->self_kq_mask, attn_ctx->get_base(), params.ubatch, params.cparams);
|
||||
|
||||
// swa tensors may not be allocated if there are no SWA attention layers
|
||||
if (inp_attn->self_k_idxs_swa && inp_attn->self_k_idxs_swa->buffer) {
|
||||
res &= inp_attn->self_k_idxs_swa->ne[0] == params.ubatch.n_tokens;
|
||||
//res &= inp_attn->self_v_idxs_swa->ne[0] == params.ubatch.n_tokens; // TODO: need to move this to the unified cache and check there
|
||||
|
||||
res &= can_reuse_kq_mask(inp_attn->self_kq_mask_swa, attn_ctx->get_swa(), params.ubatch, params.cparams);
|
||||
}
|
||||
|
||||
res &= can_reuse_kq_mask(inp_attn->self_kq_mask_swa, attn_ctx->get_swa(), params.ubatch, params.cparams);
|
||||
|
||||
res &= inp_rs->s_copy->ne[0] == mctx->get_recr()->get_n_rs();
|
||||
|
||||
res &= inp_rs->s_copy_main->ne[0] == params.ubatch.n_seqs;
|
||||
@@ -1006,6 +1010,7 @@ llm_graph_context::llm_graph_context(const llm_graph_params & params) :
|
||||
ubatch (params.ubatch),
|
||||
n_embd (hparams.n_embd),
|
||||
n_layer (hparams.n_layer()),
|
||||
n_layer_nextn (hparams.n_layer_nextn),
|
||||
n_rot (hparams.n_rot()),
|
||||
n_ctx (cparams.n_ctx),
|
||||
n_head (hparams.n_head()),
|
||||
|
||||
@@ -784,6 +784,7 @@ struct llm_graph_context {
|
||||
|
||||
const int64_t n_embd;
|
||||
const int64_t n_layer;
|
||||
const int64_t n_layer_nextn;
|
||||
const int64_t n_rot;
|
||||
const int64_t n_ctx; // user-specified context size (can be different from n_ctx_train)
|
||||
const int64_t n_head;
|
||||
|
||||
@@ -91,6 +91,10 @@ uint32_t llama_hparams::n_rot(uint32_t il) const {
|
||||
}
|
||||
|
||||
uint32_t llama_hparams::n_embd_inp() const {
|
||||
if (n_embd_inp_impl > 0) {
|
||||
return n_embd_inp_impl;
|
||||
}
|
||||
|
||||
uint32_t n_embd_inp = n_embd;
|
||||
|
||||
if (n_deepstack_layers > 0) {
|
||||
|
||||
@@ -185,6 +185,9 @@ struct llama_hparams {
|
||||
// for Classifiers
|
||||
uint32_t n_cls_out = 1;
|
||||
|
||||
// input embedding dimension (0 = use n_embd)
|
||||
uint32_t n_embd_inp_impl = 0;
|
||||
|
||||
// output embedding dimension (0 = use n_embd)
|
||||
uint32_t n_embd_out_impl = 0;
|
||||
|
||||
@@ -224,6 +227,7 @@ struct llama_hparams {
|
||||
// complex mapping. If using deepstack_mapping_arr, also make sure to set
|
||||
// n_deepstack_layers to the number of unique deepstack layers so that
|
||||
// n_embd_imp is accurate (see granite.cpp).
|
||||
// TODO: can be expressed via the `new n_embd_inp_impl` and remove this param
|
||||
uint32_t n_deepstack_layers = 0;
|
||||
|
||||
// deepstack layer array (Granite4 Vision)
|
||||
|
||||
@@ -32,7 +32,7 @@ llama_kv_cache_dsa::llama_kv_cache_dsa(
|
||||
kv_mla = std::make_unique<llama_kv_cache>(
|
||||
model, model.hparams, type_k, type_v,
|
||||
v_trans, offload, unified, kv_size, n_seq_max, n_pad,
|
||||
n_swa, swa_type, filter, reuse);
|
||||
n_swa, swa_type, nullptr, filter, reuse, nullptr);
|
||||
|
||||
// we use llama_kv_cache for caching indexer keys
|
||||
// by hand-tweaking some hparams we fool it to create
|
||||
@@ -49,7 +49,7 @@ llama_kv_cache_dsa::llama_kv_cache_dsa(
|
||||
kv_lid = std::make_unique<llama_kv_cache>(
|
||||
model, hparams_lid, type_k, type_v,
|
||||
v_trans, offload, unified, kv_size, n_seq_max, n_pad,
|
||||
n_swa, swa_type, filter, reuse);
|
||||
n_swa, swa_type, nullptr, filter, reuse, nullptr);
|
||||
}
|
||||
|
||||
void llama_kv_cache_dsa::clear(bool data) {
|
||||
|
||||
@@ -23,8 +23,10 @@ llama_kv_cache_iswa::llama_kv_cache_iswa(
|
||||
uint32_t n_seq_max,
|
||||
uint32_t n_ubatch,
|
||||
uint32_t n_pad,
|
||||
llama_memory_t mem_other,
|
||||
const layer_filter_cb & filter,
|
||||
const layer_reuse_cb & reuse) : hparams(model.hparams), unified(unified) {
|
||||
const layer_reuse_cb & reuse,
|
||||
const layer_share_cb & share) : hparams(model.hparams), unified(unified) {
|
||||
|
||||
// chain filters
|
||||
const layer_filter_cb filter_base = [&](int32_t il) {
|
||||
@@ -59,17 +61,27 @@ llama_kv_cache_iswa::llama_kv_cache_iswa(
|
||||
|
||||
LLAMA_LOG_INFO("%s: creating non-SWA KV cache, size = %u cells\n", __func__, size_base);
|
||||
|
||||
llama_memory_t mem_other_base = nullptr;
|
||||
if (mem_other) {
|
||||
mem_other_base = static_cast<llama_kv_cache_iswa *>(mem_other)->get_base();
|
||||
}
|
||||
|
||||
llama_memory_t mem_other_swa = nullptr;
|
||||
if (mem_other) {
|
||||
mem_other_swa = static_cast<llama_kv_cache_iswa *>(mem_other)->get_swa();
|
||||
}
|
||||
|
||||
kv_base = std::make_unique<llama_kv_cache>(
|
||||
model, hparams, type_k, type_v,
|
||||
v_trans, offload, unified, size_base, n_seq_max, n_pad,
|
||||
0, LLAMA_SWA_TYPE_NONE, filter_base, reuse);
|
||||
0, LLAMA_SWA_TYPE_NONE, mem_other_base, filter_base, reuse, share);
|
||||
|
||||
LLAMA_LOG_INFO("%s: creating SWA KV cache, size = %u cells\n", __func__, size_swa);
|
||||
|
||||
kv_swa = std::make_unique<llama_kv_cache>(
|
||||
model, hparams, type_k, type_v,
|
||||
v_trans, offload, unified, size_swa, n_seq_max, n_pad,
|
||||
hparams.n_swa, hparams.swa_type, filter_swa, reuse);
|
||||
hparams.n_swa, hparams.swa_type, mem_other_swa, filter_swa, reuse, share);
|
||||
}
|
||||
|
||||
void llama_kv_cache_iswa::clear(bool data) {
|
||||
|
||||
@@ -25,8 +25,10 @@ public:
|
||||
uint32_t n_seq_max,
|
||||
uint32_t n_ubatch,
|
||||
uint32_t n_pad,
|
||||
llama_memory_t mem_other,
|
||||
const layer_filter_cb & filter,
|
||||
const layer_reuse_cb & reuse);
|
||||
const layer_reuse_cb & reuse,
|
||||
const layer_share_cb & share);
|
||||
|
||||
~llama_kv_cache_iswa() = default;
|
||||
|
||||
|
||||
@@ -90,8 +90,10 @@ llama_kv_cache::llama_kv_cache(
|
||||
uint32_t n_pad,
|
||||
uint32_t n_swa,
|
||||
llama_swa_type swa_type,
|
||||
llama_memory_t mem_other,
|
||||
const layer_filter_cb & filter,
|
||||
const layer_reuse_cb & reuse) :
|
||||
const layer_reuse_cb & reuse,
|
||||
const layer_share_cb & share) :
|
||||
model(model), hparams(hparams), v_trans(v_trans),
|
||||
n_seq_max(n_seq_max), n_stream(unified ? 1 : n_seq_max), n_pad(n_pad), n_swa(n_swa), swa_type(swa_type) {
|
||||
|
||||
@@ -160,6 +162,8 @@ llama_kv_cache::llama_kv_cache(
|
||||
|
||||
const bool is_mla = hparams.is_mla();
|
||||
|
||||
other = static_cast<llama_kv_cache *>(mem_other);
|
||||
|
||||
for (uint32_t il = 0; il < n_layer; il++) {
|
||||
if (!hparams.has_kv(il)) {
|
||||
LLAMA_LOG_DEBUG("%s: layer %3d: does not have KV cache\n", __func__, il);
|
||||
@@ -171,6 +175,24 @@ llama_kv_cache::llama_kv_cache(
|
||||
continue;
|
||||
}
|
||||
|
||||
if (share && other) {
|
||||
const int32_t il_share = share(il);
|
||||
|
||||
if (il_share >= 0) {
|
||||
const auto & layer_share = other->layers[other->map_layer_ids[il_share]];
|
||||
|
||||
LLAMA_LOG_WARN("%s: layer %3d: sharing with layer %d. k = %p, v = %p\n", __func__, il, il_share,
|
||||
layer_share.k->data, layer_share.v->data);
|
||||
|
||||
map_layer_ids[il] = layers.size();
|
||||
|
||||
layers.push_back(layer_share);
|
||||
layers.back().il = il;
|
||||
|
||||
continue;
|
||||
}
|
||||
}
|
||||
|
||||
if (n_embd_head_k_all == 0) {
|
||||
n_embd_head_k_all = (int32_t) hparams.n_embd_head_k(il);
|
||||
} else if (n_embd_head_k_all > 0 && n_embd_head_k_all != (int32_t) hparams.n_embd_head_k(il)) {
|
||||
@@ -282,29 +304,38 @@ llama_kv_cache::llama_kv_cache(
|
||||
ggml_type_name(type_v), (float)memory_size_v / (1024.0f * 1024.0f));
|
||||
}
|
||||
|
||||
const char * LLAMA_ATTN_ROT_DISABLE = getenv("LLAMA_ATTN_ROT_DISABLE");
|
||||
const bool attn_rot_disable = LLAMA_ATTN_ROT_DISABLE ? atoi(LLAMA_ATTN_ROT_DISABLE) : false;
|
||||
if (attn_rot_disable) {
|
||||
LLAMA_LOG_WARN("%s: attention rotation force disabled (LLAMA_ATTN_ROT_DISABLE)\n", __func__);
|
||||
// TODO: refactor [TAG_KV_CACHE_SHARE_CELLS]
|
||||
if (other) {
|
||||
n_embd_head_k_all = other->n_embd_head_k_all;
|
||||
n_embd_head_v_all = other->n_embd_head_v_all;
|
||||
|
||||
attn_rot_k = other->attn_rot_k;
|
||||
attn_rot_v = other->attn_rot_v;
|
||||
} else {
|
||||
const char * LLAMA_ATTN_ROT_DISABLE = getenv("LLAMA_ATTN_ROT_DISABLE");
|
||||
const bool attn_rot_disable = LLAMA_ATTN_ROT_DISABLE ? atoi(LLAMA_ATTN_ROT_DISABLE) : false;
|
||||
if (attn_rot_disable) {
|
||||
LLAMA_LOG_WARN("%s: attention rotation force disabled (LLAMA_ATTN_ROT_DISABLE)\n", __func__);
|
||||
}
|
||||
|
||||
attn_rot_k =
|
||||
!attn_rot_disable &&
|
||||
n_embd_head_k_all > 0 &&
|
||||
ggml_is_quantized(type_k) &&
|
||||
hparams.n_embd_head_k() % 64 == 0;
|
||||
|
||||
// always create Hadamard rotation tensors for DeepSeek V3.2 DSA lightning indexer
|
||||
if (model.arch == LLM_ARCH_DEEPSEEK32 && hparams.n_embd_head_k_full == hparams.indexer_head_size) {
|
||||
attn_rot_k = true;
|
||||
}
|
||||
|
||||
attn_rot_v =
|
||||
!attn_rot_disable &&
|
||||
n_embd_head_v_all > 0 &&
|
||||
ggml_is_quantized(type_v) &&
|
||||
hparams.n_embd_head_v() % 64 == 0;
|
||||
}
|
||||
|
||||
attn_rot_k =
|
||||
!attn_rot_disable &&
|
||||
n_embd_head_k_all > 0 &&
|
||||
ggml_is_quantized(type_k) &&
|
||||
hparams.n_embd_head_k() % 64 == 0;
|
||||
|
||||
// always create Hadamard rotation tensors for DeepSeek V3.2 DSA lightning indexer
|
||||
if (model.arch == LLM_ARCH_DEEPSEEK32 && hparams.n_embd_head_k_full == hparams.indexer_head_size) {
|
||||
attn_rot_k = true;
|
||||
}
|
||||
|
||||
attn_rot_v =
|
||||
!attn_rot_disable &&
|
||||
n_embd_head_v_all > 0 &&
|
||||
ggml_is_quantized(type_v) &&
|
||||
hparams.n_embd_head_v() % 64 == 0;
|
||||
|
||||
LLAMA_LOG_INFO("%s: attn_rot_k = %d, n_embd_head_k_all = %d\n", __func__, attn_rot_k, n_embd_head_k_all);
|
||||
LLAMA_LOG_INFO("%s: attn_rot_v = %d, n_embd_head_k_all = %d\n", __func__, attn_rot_v, n_embd_head_v_all);
|
||||
|
||||
@@ -347,6 +378,11 @@ void llama_kv_cache::clear(bool data) {
|
||||
}
|
||||
|
||||
bool llama_kv_cache::seq_rm(llama_seq_id seq_id, llama_pos p0, llama_pos p1) {
|
||||
// TODO: refactor [TAG_KV_CACHE_SHARE_CELLS]
|
||||
if (other) {
|
||||
return true;
|
||||
}
|
||||
|
||||
GGML_ASSERT(seq_id == -1 || (seq_id >= 0 && (size_t) seq_id < seq_to_stream.size()));
|
||||
|
||||
if (p0 < 0) {
|
||||
@@ -410,6 +446,11 @@ bool llama_kv_cache::seq_rm(llama_seq_id seq_id, llama_pos p0, llama_pos p1) {
|
||||
}
|
||||
|
||||
void llama_kv_cache::seq_cp(llama_seq_id seq_id_src, llama_seq_id seq_id_dst, llama_pos p0, llama_pos p1) {
|
||||
// TODO: refactor [TAG_KV_CACHE_SHARE_CELLS]
|
||||
if (other) {
|
||||
return;
|
||||
}
|
||||
|
||||
GGML_ASSERT(seq_id_src >= 0 && (size_t) seq_id_src < seq_to_stream.size());
|
||||
GGML_ASSERT(seq_id_dst >= 0 && (size_t) seq_id_dst < seq_to_stream.size());
|
||||
|
||||
@@ -497,6 +538,11 @@ void llama_kv_cache::seq_cp(llama_seq_id seq_id_src, llama_seq_id seq_id_dst, ll
|
||||
}
|
||||
|
||||
void llama_kv_cache::seq_keep(llama_seq_id seq_id) {
|
||||
// TODO: refactor [TAG_KV_CACHE_SHARE_CELLS]
|
||||
if (other) {
|
||||
return;
|
||||
}
|
||||
|
||||
GGML_ASSERT(seq_id >= 0 && (size_t) seq_id < seq_to_stream.size());
|
||||
|
||||
auto & cells = v_cells[seq_to_stream[seq_id]];
|
||||
@@ -519,6 +565,11 @@ void llama_kv_cache::seq_keep(llama_seq_id seq_id) {
|
||||
}
|
||||
|
||||
void llama_kv_cache::seq_add(llama_seq_id seq_id, llama_pos p0, llama_pos p1, llama_pos shift) {
|
||||
// TODO: refactor [TAG_KV_CACHE_SHARE_CELLS]
|
||||
if (other) {
|
||||
return;
|
||||
}
|
||||
|
||||
GGML_ASSERT(seq_id >= 0 && (size_t) seq_id < seq_to_stream.size());
|
||||
GGML_ASSERT(hparams.n_pos_per_embd() == 1 && "seq_add() is only supported for n_pos_per_embd() == 1");
|
||||
|
||||
@@ -564,6 +615,11 @@ void llama_kv_cache::seq_add(llama_seq_id seq_id, llama_pos p0, llama_pos p1, ll
|
||||
}
|
||||
|
||||
void llama_kv_cache::seq_div(llama_seq_id seq_id, llama_pos p0, llama_pos p1, int d) {
|
||||
// TODO: refactor [TAG_KV_CACHE_SHARE_CELLS]
|
||||
if (other) {
|
||||
return;
|
||||
}
|
||||
|
||||
GGML_ASSERT(seq_id >= 0 && (size_t) seq_id < seq_to_stream.size());
|
||||
GGML_ASSERT(hparams.n_pos_per_embd() == 1 && "seq_div() is only supported for n_pos_per_embd() == 1");
|
||||
|
||||
@@ -598,6 +654,11 @@ void llama_kv_cache::seq_div(llama_seq_id seq_id, llama_pos p0, llama_pos p1, in
|
||||
}
|
||||
|
||||
llama_pos llama_kv_cache::seq_pos_min(llama_seq_id seq_id) const {
|
||||
// TODO: refactor [TAG_KV_CACHE_SHARE_CELLS]
|
||||
if (other) {
|
||||
return other->seq_pos_min(seq_id);
|
||||
}
|
||||
|
||||
GGML_ASSERT(seq_id >= 0 && (size_t) seq_id < seq_to_stream.size());
|
||||
|
||||
const auto & cells = v_cells[seq_to_stream[seq_id]];
|
||||
@@ -606,6 +667,11 @@ llama_pos llama_kv_cache::seq_pos_min(llama_seq_id seq_id) const {
|
||||
}
|
||||
|
||||
llama_pos llama_kv_cache::seq_pos_max(llama_seq_id seq_id) const {
|
||||
// TODO: refactor [TAG_KV_CACHE_SHARE_CELLS]
|
||||
if (other) {
|
||||
return other->seq_pos_max(seq_id);
|
||||
}
|
||||
|
||||
GGML_ASSERT(seq_id >= 0 && (size_t) seq_id < seq_to_stream.size());
|
||||
|
||||
const auto & cells = v_cells[seq_to_stream[seq_id]];
|
||||
@@ -746,6 +812,11 @@ llama_kv_cache::slot_info_vec_t llama_kv_cache::prepare(const std::vector<llama_
|
||||
}
|
||||
|
||||
bool llama_kv_cache::update(llama_context * lctx, bool do_shift, const stream_copy_info & sc_info) {
|
||||
// TODO: refactor [TAG_KV_CACHE_SHARE_CELLS]
|
||||
if (other) {
|
||||
return true;
|
||||
}
|
||||
|
||||
bool updated = false;
|
||||
|
||||
auto * sched = lctx->get_sched();
|
||||
@@ -1021,6 +1092,12 @@ llama_kv_cache::slot_info llama_kv_cache::find_slot(const llama_ubatch & ubatch,
|
||||
}
|
||||
|
||||
void llama_kv_cache::apply_ubatch(const slot_info & sinfo, const llama_ubatch & ubatch) {
|
||||
// TODO: refactor [TAG_KV_CACHE_SHARE_CELLS]
|
||||
if (other) {
|
||||
v_cells = other->v_cells;
|
||||
return;
|
||||
}
|
||||
|
||||
// keep track of the max sequence position that we would overwrite with this ubatch
|
||||
// for non-SWA cache, this would be always empty
|
||||
llama_seq_id seq_pos_max_rm[LLAMA_MAX_SEQ];
|
||||
@@ -1815,6 +1892,9 @@ void llm_graph_input_k_shift::set_input(const llama_ubatch * ubatch) {
|
||||
}
|
||||
|
||||
ggml_cgraph * llama_kv_cache::build_graph_shift(llm_graph_result * res, llama_context * lctx) const {
|
||||
// TODO: refactor [TAG_KV_CACHE_SHARE_CELLS]
|
||||
GGML_ASSERT(!other);
|
||||
|
||||
auto * ctx = res->get_ctx();
|
||||
auto * gf = res->get_gf();
|
||||
|
||||
@@ -1860,6 +1940,11 @@ ggml_cgraph * llama_kv_cache::build_graph_shift(llm_graph_result * res, llama_co
|
||||
}
|
||||
|
||||
void llama_kv_cache::state_write(llama_io_write_i & io, llama_seq_id seq_id, llama_state_seq_flags flags) const {
|
||||
// TODO: refactor [TAG_KV_CACHE_SHARE_CELLS]
|
||||
if (other) {
|
||||
return;
|
||||
}
|
||||
|
||||
GGML_UNUSED(flags);
|
||||
|
||||
io.write(&n_stream, sizeof(n_stream));
|
||||
@@ -1925,6 +2010,11 @@ void llama_kv_cache::state_write(llama_io_write_i & io, llama_seq_id seq_id, lla
|
||||
}
|
||||
|
||||
void llama_kv_cache::state_read(llama_io_read_i & io, llama_seq_id seq_id, llama_state_seq_flags flags) {
|
||||
// TODO: refactor [TAG_KV_CACHE_SHARE_CELLS]
|
||||
if (other) {
|
||||
return;
|
||||
}
|
||||
|
||||
GGML_UNUSED(flags);
|
||||
|
||||
GGML_ASSERT(seq_id == -1 || (seq_id >= 0 && (size_t) seq_id < seq_to_stream.size()));
|
||||
|
||||
@@ -98,7 +98,7 @@ public:
|
||||
// likely through `struct llama_memory_params`
|
||||
llama_kv_cache(
|
||||
const llama_model & model,
|
||||
const llama_hparams & hparams,
|
||||
const llama_hparams & hparams,
|
||||
ggml_type type_k,
|
||||
ggml_type type_v,
|
||||
bool v_trans,
|
||||
@@ -109,8 +109,10 @@ public:
|
||||
uint32_t n_pad,
|
||||
uint32_t n_swa,
|
||||
llama_swa_type swa_type,
|
||||
llama_memory_t mem_other,
|
||||
const layer_filter_cb & filter,
|
||||
const layer_reuse_cb & reuse);
|
||||
const layer_reuse_cb & reuse,
|
||||
const layer_share_cb & share);
|
||||
|
||||
~llama_kv_cache() = default;
|
||||
|
||||
@@ -264,6 +266,9 @@ private:
|
||||
// note: this is not part of the KV state and it's only used to speed-up the find_slot() method
|
||||
std::vector<uint32_t> v_heads;
|
||||
|
||||
// TODO: temporary until we refactor to be able to share the same cells between 2 kv caches [TAG_KV_CACHE_SHARE_CELLS]
|
||||
llama_kv_cache * other;
|
||||
|
||||
std::vector<llama_kv_cells> v_cells;
|
||||
|
||||
// maps from a sequence id to a stream id
|
||||
|
||||
@@ -43,9 +43,11 @@ llama_memory_hybrid_iswa::llama_memory_hybrid_iswa(
|
||||
n_seq_max,
|
||||
n_ubatch,
|
||||
n_pad,
|
||||
nullptr,
|
||||
filter_attn == nullptr ?
|
||||
[&](int32_t il) { return !hparams.is_recr(il); }
|
||||
: filter_attn,
|
||||
nullptr,
|
||||
nullptr
|
||||
)),
|
||||
mem_recr(new llama_memory_recurrent(
|
||||
|
||||
@@ -44,9 +44,11 @@ llama_memory_hybrid::llama_memory_hybrid(
|
||||
n_pad,
|
||||
n_swa,
|
||||
swa_type,
|
||||
nullptr,
|
||||
filter_attn == nullptr ?
|
||||
[&](int32_t il) { return !hparams.is_recr(il); }
|
||||
: filter_attn,
|
||||
nullptr,
|
||||
nullptr
|
||||
)),
|
||||
mem_recr(new llama_memory_recurrent(
|
||||
|
||||
@@ -23,6 +23,8 @@ struct llama_memory_params {
|
||||
bool swa_full;
|
||||
|
||||
llama_context_type ctx_type;
|
||||
|
||||
llama_memory_t mem_other;
|
||||
};
|
||||
|
||||
enum llama_memory_status {
|
||||
@@ -76,6 +78,8 @@ struct llama_memory_i {
|
||||
// return negative value to indicate that the layer il should not reuse memory
|
||||
using layer_reuse_cb = std::function<int32_t(int32_t il)>;
|
||||
|
||||
using layer_share_cb = std::function<int32_t(int32_t il)>;
|
||||
|
||||
virtual ~llama_memory_i() = default;
|
||||
|
||||
// split the input batch into a set of ubatches and verify that they can fit into the cache
|
||||
|
||||
@@ -139,6 +139,8 @@ static llama_model * llama_model_mapping(llm_arch arch, const llama_model_params
|
||||
return new llama_model_gemma3n(params);
|
||||
case LLM_ARCH_GEMMA4:
|
||||
return new llama_model_gemma4(params);
|
||||
case LLM_ARCH_GEMMA4_ASSISTANT:
|
||||
return new llama_model_gemma4_assistant(params);
|
||||
case LLM_ARCH_GEMMA_EMBEDDING:
|
||||
return new llama_model_gemma_embedding(params);
|
||||
case LLM_ARCH_STARCODER2:
|
||||
@@ -1717,19 +1719,21 @@ void llama_model::print_info() const {
|
||||
|
||||
if (!hparams.vocab_only) {
|
||||
LLAMA_LOG_INFO("%s: n_ctx_train = %u\n", __func__, hparams.n_ctx_train);
|
||||
LLAMA_LOG_INFO("%s: n_embd = %u\n", __func__, hparams.n_embd);
|
||||
LLAMA_LOG_INFO("%s: n_embd_inp = %u\n", __func__, hparams.n_embd_inp());
|
||||
LLAMA_LOG_INFO("%s: n_embd = %u\n", __func__, hparams.n_embd);
|
||||
LLAMA_LOG_INFO("%s: n_embd_out = %u\n", __func__, hparams.n_embd_out());
|
||||
LLAMA_LOG_INFO("%s: n_layer = %u\n", __func__, hparams.n_layer());
|
||||
LLAMA_LOG_INFO("%s: n_head = %s\n", __func__, print_f([&](uint32_t il) { return hparams.n_head(il); }, hparams.n_layer()).c_str());
|
||||
LLAMA_LOG_INFO("%s: n_head_kv = %s\n", __func__, print_f([&](uint32_t il) { return hparams.n_head_kv(il); }, hparams.n_layer()).c_str());
|
||||
LLAMA_LOG_INFO("%s: n_layer_all = %u\n", __func__, hparams.n_layer_all);
|
||||
LLAMA_LOG_INFO("%s: n_head = %s\n", __func__, print_f([&](uint32_t il) { return hparams.n_head(il); }, hparams.n_layer_all).c_str());
|
||||
LLAMA_LOG_INFO("%s: n_head_kv = %s\n", __func__, print_f([&](uint32_t il) { return hparams.n_head_kv(il); }, hparams.n_layer_all).c_str());
|
||||
LLAMA_LOG_INFO("%s: n_rot = %u\n", __func__, hparams.n_rot_full);
|
||||
LLAMA_LOG_INFO("%s: n_swa = %u\n", __func__, hparams.n_swa);
|
||||
LLAMA_LOG_INFO("%s: is_swa_any = %u\n", __func__, hparams.is_swa_any());
|
||||
LLAMA_LOG_INFO("%s: n_embd_head_k = %u\n", __func__, hparams.n_embd_head_k_full);
|
||||
LLAMA_LOG_INFO("%s: n_embd_head_v = %u\n", __func__, hparams.n_embd_head_v_full);
|
||||
LLAMA_LOG_INFO("%s: n_gqa = %s\n", __func__, print_f([&](uint32_t il) { return hparams.n_gqa(il); }, hparams.n_layer()).c_str());
|
||||
LLAMA_LOG_INFO("%s: n_embd_k_gqa = %s\n", __func__, print_f([&](uint32_t il) { return hparams.n_embd_k_gqa(il); }, hparams.n_layer()).c_str());
|
||||
LLAMA_LOG_INFO("%s: n_embd_v_gqa = %s\n", __func__, print_f([&](uint32_t il) { return hparams.n_embd_v_gqa(il); }, hparams.n_layer()).c_str());
|
||||
LLAMA_LOG_INFO("%s: n_gqa = %s\n", __func__, print_f([&](uint32_t il) { return hparams.n_gqa(il); }, hparams.n_layer_all).c_str());
|
||||
LLAMA_LOG_INFO("%s: n_embd_k_gqa = %s\n", __func__, print_f([&](uint32_t il) { return hparams.n_embd_k_gqa(il); }, hparams.n_layer_all).c_str());
|
||||
LLAMA_LOG_INFO("%s: n_embd_v_gqa = %s\n", __func__, print_f([&](uint32_t il) { return hparams.n_embd_v_gqa(il); }, hparams.n_layer_all).c_str());
|
||||
LLAMA_LOG_INFO("%s: f_norm_eps = %.1e\n", __func__, hparams.f_norm_eps);
|
||||
LLAMA_LOG_INFO("%s: f_norm_rms_eps = %.1e\n", __func__, hparams.f_norm_rms_eps);
|
||||
LLAMA_LOG_INFO("%s: f_clamp_kqv = %.1e\n", __func__, hparams.f_clamp_kqv);
|
||||
@@ -1737,7 +1741,7 @@ void llama_model::print_info() const {
|
||||
LLAMA_LOG_INFO("%s: f_logit_scale = %.1e\n", __func__, hparams.f_logit_scale);
|
||||
LLAMA_LOG_INFO("%s: f_attn_scale = %.1e\n", __func__, hparams.f_attention_scale);
|
||||
LLAMA_LOG_INFO("%s: f_attn_value_scale = %.4f\n", __func__, hparams.f_attn_value_scale);
|
||||
LLAMA_LOG_INFO("%s: n_ff = %s\n", __func__, print_f([&](uint32_t il) { return hparams.n_ff(il); }, hparams.n_layer()).c_str());
|
||||
LLAMA_LOG_INFO("%s: n_ff = %s\n", __func__, print_f([&](uint32_t il) { return hparams.n_ff(il); }, hparams.n_layer_all).c_str());
|
||||
LLAMA_LOG_INFO("%s: n_expert = %u\n", __func__, hparams.n_expert);
|
||||
LLAMA_LOG_INFO("%s: n_expert_used = %u\n", __func__, hparams.n_expert_used);
|
||||
LLAMA_LOG_INFO("%s: n_expert_groups = %d\n", __func__, hparams.n_expert_groups);
|
||||
@@ -1764,7 +1768,7 @@ void llama_model::print_info() const {
|
||||
[](const auto & entry) { return entry >= 0; })) {
|
||||
LLAMA_LOG_INFO("%s: deepstack_mapping_arr = %s\n", __func__,
|
||||
print_f([&](uint32_t il) { return hparams.deepstack_mapping_arr[il]; },
|
||||
hparams.n_layer()).c_str());
|
||||
hparams.n_layer_all).c_str());
|
||||
}
|
||||
// MRoPE (Multi-axis Rotary Position Embedding) sections
|
||||
if (const auto & s = hparams.rope_sections; s[0] || s[1] || s[2] || s[3]) {
|
||||
@@ -2113,8 +2117,9 @@ llama_memory_i * llama_model::create_memory(const llama_memory_params & params,
|
||||
/* filter_recr */ std::move(filter_recr));
|
||||
}
|
||||
} else {
|
||||
llama_memory_i::layer_reuse_cb reuse = nullptr;
|
||||
llama_kv_cache::layer_filter_cb filter = nullptr;
|
||||
llama_memory_i::layer_reuse_cb reuse = nullptr;
|
||||
llama_kv_cache::layer_share_cb share = nullptr;
|
||||
|
||||
if (arch == LLM_ARCH_GEMMA3N || arch == LLM_ARCH_GEMMA4) {
|
||||
reuse = [&](uint32_t il) {
|
||||
@@ -2143,20 +2148,53 @@ llama_memory_i * llama_model::create_memory(const llama_memory_params & params,
|
||||
if (hparams.swa_type != LLAMA_SWA_TYPE_NONE) {
|
||||
GGML_ASSERT(hparams.is_swa_any());
|
||||
|
||||
res = new llama_kv_cache_iswa(
|
||||
*this,
|
||||
params.type_k,
|
||||
params.type_v,
|
||||
!cparams.flash_attn,
|
||||
cparams.offload_kqv,
|
||||
params.swa_full,
|
||||
cparams.kv_unified,
|
||||
cparams.n_ctx_seq,
|
||||
cparams.n_seq_max,
|
||||
cparams.n_ubatch,
|
||||
1,
|
||||
filter,
|
||||
reuse);
|
||||
if (arch == LLM_ARCH_GEMMA4_ASSISTANT) {
|
||||
llama_memory_t mem_other = llama_get_memory(cparams.ctx_other);
|
||||
|
||||
share = [&](int32_t il) {
|
||||
const llama_model * model_other = llama_get_model(cparams.ctx_other);
|
||||
|
||||
if (hparams.is_swa(il)) {
|
||||
return llama_model_n_layer(model_other) - 2;
|
||||
}
|
||||
|
||||
return llama_model_n_layer(model_other) - 1;
|
||||
};
|
||||
|
||||
res = new llama_kv_cache_iswa(
|
||||
*this,
|
||||
params.type_k,
|
||||
params.type_v,
|
||||
!cparams.flash_attn,
|
||||
cparams.offload_kqv,
|
||||
params.swa_full,
|
||||
cparams.kv_unified,
|
||||
cparams.n_ctx_seq,
|
||||
cparams.n_seq_max,
|
||||
cparams.n_ubatch,
|
||||
1,
|
||||
mem_other,
|
||||
filter,
|
||||
reuse,
|
||||
share);
|
||||
} else {
|
||||
res = new llama_kv_cache_iswa(
|
||||
*this,
|
||||
params.type_k,
|
||||
params.type_v,
|
||||
!cparams.flash_attn,
|
||||
cparams.offload_kqv,
|
||||
params.swa_full,
|
||||
cparams.kv_unified,
|
||||
cparams.n_ctx_seq,
|
||||
cparams.n_seq_max,
|
||||
cparams.n_ubatch,
|
||||
1,
|
||||
nullptr,
|
||||
filter,
|
||||
reuse,
|
||||
share);
|
||||
}
|
||||
} else {
|
||||
GGML_ASSERT(!hparams.is_swa_any());
|
||||
|
||||
@@ -2173,7 +2211,9 @@ llama_memory_i * llama_model::create_memory(const llama_memory_params & params,
|
||||
1,
|
||||
hparams.n_swa,
|
||||
hparams.swa_type,
|
||||
nullptr,
|
||||
filter,
|
||||
nullptr,
|
||||
nullptr);
|
||||
}
|
||||
}
|
||||
@@ -2406,6 +2446,7 @@ llama_rope_type llama_model_rope_type(const llama_model * model) {
|
||||
case LLM_ARCH_GEMMA3:
|
||||
case LLM_ARCH_GEMMA3N:
|
||||
case LLM_ARCH_GEMMA4:
|
||||
case LLM_ARCH_GEMMA4_ASSISTANT:
|
||||
case LLM_ARCH_GEMMA_EMBEDDING:
|
||||
case LLM_ARCH_STARCODER2:
|
||||
case LLM_ARCH_OPENELM:
|
||||
|
||||
@@ -548,6 +548,10 @@ struct llama_model {
|
||||
struct ggml_tensor * output_s = nullptr;
|
||||
struct ggml_tensor * output_in_s = nullptr;
|
||||
|
||||
// NextN/MTP model-level projections
|
||||
struct ggml_tensor * nextn_proj_pre = nullptr;
|
||||
struct ggml_tensor * nextn_proj_post = nullptr;
|
||||
|
||||
// classifier
|
||||
struct ggml_tensor * cls = nullptr;
|
||||
struct ggml_tensor * cls_b = nullptr;
|
||||
@@ -702,6 +706,7 @@ const char * llm_type_name(llm_type type);
|
||||
#define LLAMA_LOAD_LOCALS \
|
||||
const int n_layer = hparams.n_layer(); GGML_UNUSED(n_layer); \
|
||||
const int n_layer_all = hparams.n_layer_all; GGML_UNUSED(n_layer_all); \
|
||||
const int n_layer_nextn = hparams.n_layer_nextn; GGML_UNUSED(n_layer_nextn); \
|
||||
const int64_t n_head = hparams.n_head(); GGML_UNUSED(n_head); \
|
||||
const int64_t n_head_kv = hparams.n_head_kv(); GGML_UNUSED(n_head_kv); \
|
||||
const int64_t n_embd = hparams.n_embd; GGML_UNUSED(n_embd); \
|
||||
|
||||
200
src/models/gemma4-assistant.cpp
Normal file
200
src/models/gemma4-assistant.cpp
Normal file
@@ -0,0 +1,200 @@
|
||||
#include "models.h"
|
||||
|
||||
void llama_model_gemma4_assistant::load_arch_hparams(llama_model_loader & ml) {
|
||||
hparams.n_embd_inp_impl = hparams.n_embd_out();
|
||||
|
||||
hparams.swa_type = LLAMA_SWA_TYPE_STANDARD;
|
||||
ml.get_key_or_arr(LLM_KV_ATTENTION_SLIDING_WINDOW_PATTERN, hparams.is_swa_impl, hparams.n_layer());
|
||||
|
||||
uint32_t n_kv_shared_layers = 0;
|
||||
ml.get_key(LLM_KV_ATTENTION_SHARED_KV_LAYERS, n_kv_shared_layers, false);
|
||||
|
||||
hparams.f_attention_scale = 1.0f;
|
||||
|
||||
ml.get_key(LLM_KV_NEXTN_PREDICT_LAYERS, hparams.n_layer_nextn, false);
|
||||
GGML_ASSERT(hparams.n_layer_nextn == hparams.n_layer_all && "n_layer_nextn must be == n_layer_impl");
|
||||
|
||||
ml.get_key(LLM_KV_ROPE_FREQ_BASE_SWA, hparams.rope_freq_base_train_swa, false);
|
||||
ml.get_key(LLM_KV_ATTENTION_SLIDING_WINDOW, hparams.n_swa);
|
||||
ml.get_key(LLM_KV_ATTENTION_LAYERNORM_RMS_EPS, hparams.f_norm_rms_eps);
|
||||
ml.get_key(LLM_KV_ATTENTION_KEY_LENGTH_SWA, hparams.n_embd_head_k_swa);
|
||||
ml.get_key(LLM_KV_ATTENTION_VALUE_LENGTH_SWA, hparams.n_embd_head_v_swa);
|
||||
}
|
||||
|
||||
void llama_model_gemma4_assistant::load_arch_tensors(llama_model_loader &) {
|
||||
LLAMA_LOAD_LOCALS;
|
||||
|
||||
if (n_embd_head_k != n_embd_head_v) {
|
||||
throw std::runtime_error("Gemma 4 assistant requires n_embd_head_k == n_embd_head_v");
|
||||
}
|
||||
if (hparams.n_embd_head_k_swa != hparams.n_embd_head_v_swa) {
|
||||
throw std::runtime_error("Gemma 4 assistant requires n_embd_head_k_swa == n_embd_head_v_swa");
|
||||
}
|
||||
if (hparams.n_embd_out() == n_embd) {
|
||||
throw std::runtime_error("Gemma 4 assistant requires embedding_length_out to carry the target hidden size");
|
||||
}
|
||||
|
||||
tok_embd = create_tensor(tn(LLM_TENSOR_TOKEN_EMBD, "weight"), { n_embd, n_vocab }, 0);
|
||||
output = create_tensor(tn(LLM_TENSOR_TOKEN_EMBD, "weight"), { n_embd, n_vocab }, TENSOR_DUPLICATED);
|
||||
|
||||
output_norm = create_tensor(tn(LLM_TENSOR_OUTPUT_NORM, "weight"), { n_embd }, 0);
|
||||
|
||||
const int64_t n_embd_backbone = hparams.n_embd_inp();
|
||||
nextn_proj_post = create_tensor(tn(LLM_TENSOR_NEXTN_PROJ_POST, "weight"), { n_embd, n_embd_backbone }, 0);
|
||||
|
||||
int rope_freqs_flag = 0;
|
||||
|
||||
for (int i = 0; i < n_layer_nextn; ++i) {
|
||||
auto & layer = layers[i];
|
||||
|
||||
const int64_t n_head = hparams.n_head(i);
|
||||
const int64_t n_embd_head = hparams.n_embd_head_k(i);
|
||||
const int64_t n_ff = hparams.n_ff(i);
|
||||
|
||||
if (i == 0) {
|
||||
nextn_proj_pre = create_tensor(tn(LLM_TENSOR_NEXTN_PROJ_PRE, "weight", i), { 2*n_embd_backbone, n_embd }, 0);
|
||||
}
|
||||
|
||||
layer.attn_norm = create_tensor(tn(LLM_TENSOR_ATTN_NORM, "weight", i), { n_embd }, 0);
|
||||
layer.wq = create_tensor(tn(LLM_TENSOR_ATTN_Q, "weight", i), { n_embd, n_embd_head*n_head }, 0);
|
||||
layer.wo = create_tensor(tn(LLM_TENSOR_ATTN_OUT, "weight", i), { n_embd_head*n_head, n_embd }, 0);
|
||||
|
||||
layer.attn_q_norm = create_tensor(tn(LLM_TENSOR_ATTN_Q_NORM, "weight", i), { n_embd_head }, 0);
|
||||
layer.attn_post_norm = create_tensor(tn(LLM_TENSOR_ATTN_POST_NORM, "weight", i), { n_embd }, 0);
|
||||
|
||||
layer.out_scale = create_tensor(tn(LLM_TENSOR_LAYER_OUT_SCALE, "weight", i), { 1u }, 0);
|
||||
|
||||
if (!hparams.is_swa(i)) {
|
||||
layer.rope_freqs = create_tensor(tn(LLM_TENSOR_ROPE_FREQS, "weight", i), { n_embd_head/2 }, rope_freqs_flag);
|
||||
rope_freqs_flag = TENSOR_DUPLICATED;
|
||||
}
|
||||
|
||||
layer.ffn_norm = create_tensor(tn(LLM_TENSOR_FFN_NORM, "weight", i), { n_embd }, 0);
|
||||
layer.ffn_gate = create_tensor(tn(LLM_TENSOR_FFN_GATE, "weight", i), { n_embd, n_ff }, 0);
|
||||
layer.ffn_up = create_tensor(tn(LLM_TENSOR_FFN_UP, "weight", i), { n_embd, n_ff }, 0);
|
||||
layer.ffn_down = create_tensor(tn(LLM_TENSOR_FFN_DOWN, "weight", i), { n_ff, n_embd }, 0);
|
||||
layer.ffn_post_norm = create_tensor(tn(LLM_TENSOR_FFN_POST_NORM, "weight", i), { n_embd }, 0);
|
||||
}
|
||||
}
|
||||
|
||||
std::unique_ptr<llm_graph_context> llama_model_gemma4_assistant::build_arch_graph(const llm_graph_params & params) const {
|
||||
return std::make_unique<graph>(*this, params);
|
||||
}
|
||||
|
||||
llama_model_gemma4_assistant::graph::graph(const llama_model & model, const llm_graph_params & params) :
|
||||
llm_graph_context(params) {
|
||||
const int64_t n_embd_backbone = hparams.n_embd_inp();
|
||||
|
||||
ggml_tensor * inp_tokens;
|
||||
ggml_tensor * inp_h;
|
||||
{
|
||||
auto inp = std::make_unique<llm_graph_input_embd>(n_embd_backbone);
|
||||
|
||||
inp->tokens = ggml_new_tensor_1d(ctx0, GGML_TYPE_I32, ubatch.n_tokens);
|
||||
cb(inp->tokens, "inp_tokens", -1);
|
||||
ggml_set_input(inp->tokens);
|
||||
inp_tokens = inp->tokens;
|
||||
res->t_inp_tokens = inp->tokens;
|
||||
|
||||
inp->embd = ggml_new_tensor_2d(ctx0, GGML_TYPE_F32, n_embd_backbone, ubatch.n_tokens);
|
||||
cb(inp->embd, "inp_h", -1);
|
||||
ggml_set_input(inp->embd);
|
||||
inp_h = inp->embd;
|
||||
res->t_inp_embd = inp->embd;
|
||||
|
||||
res->add_input(std::move(inp));
|
||||
}
|
||||
|
||||
GGML_ASSERT(cparams.ctx_other != nullptr);
|
||||
const auto * model_other = llama_get_model(cparams.ctx_other);
|
||||
|
||||
ggml_tensor * x = ggml_get_rows(ctx0, model_other->tok_embd, inp_tokens);
|
||||
x = ggml_scale(ctx0, x, sqrtf((float) n_embd_backbone));
|
||||
cb(x, "inp_embd_target", -1);
|
||||
|
||||
ggml_tensor * xh = ggml_concat(ctx0, x, inp_h, 0);
|
||||
cb(xh, "inp_xh", -1);
|
||||
|
||||
ggml_tensor * cur = ggml_mul_mat(ctx0, model.nextn_proj_pre, xh);
|
||||
cb(cur, "pre_proj", -1);
|
||||
|
||||
auto * inp_attn = build_attn_inp_kv_iswa();
|
||||
ggml_tensor * inp_pos = build_inp_pos();
|
||||
ggml_tensor * inp_out_ids = build_inp_out_ids();
|
||||
|
||||
ggml_tensor * inpL = cur;
|
||||
|
||||
for (int il = 0; il < n_layer_nextn; ++il) {
|
||||
const bool is_swa = hparams.is_swa(il);
|
||||
|
||||
const int64_t n_embd_head = hparams.n_embd_head_k(il);
|
||||
const int64_t n_head = hparams.n_head(il);
|
||||
|
||||
const float freq_base_l = model.get_rope_freq_base(cparams, il);
|
||||
const float freq_scale_l = model.get_rope_freq_scale(cparams, il);
|
||||
const int n_rot_l = hparams.n_rot(il);
|
||||
|
||||
ggml_tensor * cur_norm = build_norm(inpL, model.layers[il].attn_norm, nullptr, LLM_NORM_RMS, il);
|
||||
cb(cur_norm, "attn_norm", il);
|
||||
|
||||
ggml_tensor * Qcur = build_lora_mm(model.layers[il].wq, cur_norm);
|
||||
Qcur = ggml_reshape_3d(ctx0, Qcur, n_embd_head, n_head, n_tokens);
|
||||
Qcur = build_norm(Qcur, model.layers[il].attn_q_norm, nullptr, LLM_NORM_RMS, il);
|
||||
cb(Qcur, "Qcur_normed", il);
|
||||
|
||||
ggml_tensor * freq_factors = is_swa ? nullptr : model.layers[il].rope_freqs;
|
||||
Qcur = ggml_rope_ext(ctx0, Qcur, inp_pos, freq_factors, n_rot_l, rope_type, n_ctx_orig,
|
||||
freq_base_l, freq_scale_l, ext_factor, attn_factor, beta_fast, beta_slow);
|
||||
cb(Qcur, "Qcur_pos", il);
|
||||
|
||||
cur = build_attn(inp_attn, model.layers[il].wo, nullptr, nullptr,
|
||||
Qcur, nullptr, nullptr, nullptr, nullptr, nullptr, hparams.f_attention_scale, il);
|
||||
|
||||
if (il == n_layer_nextn - 1 && inp_out_ids) {
|
||||
cur = ggml_get_rows(ctx0, cur, inp_out_ids);
|
||||
inpL = ggml_get_rows(ctx0, inpL, inp_out_ids);
|
||||
}
|
||||
|
||||
cur = build_norm(cur, model.layers[il].attn_post_norm, nullptr, LLM_NORM_RMS, il);
|
||||
cb(cur, "attn_post_norm", il);
|
||||
|
||||
ggml_tensor * attn_out = ggml_add(ctx0, cur, inpL);
|
||||
cb(attn_out, "attn_out", il);
|
||||
|
||||
cur = build_norm(attn_out, model.layers[il].ffn_norm, nullptr, LLM_NORM_RMS, il);
|
||||
cb(cur, "ffn_norm", il);
|
||||
|
||||
cur = build_ffn(cur,
|
||||
model.layers[il].ffn_up, nullptr, nullptr,
|
||||
model.layers[il].ffn_gate, nullptr, nullptr,
|
||||
model.layers[il].ffn_down, nullptr, nullptr,
|
||||
nullptr,
|
||||
LLM_FFN_GELU, LLM_FFN_PAR, il);
|
||||
cb(cur, "ffn_out", il);
|
||||
|
||||
cur = build_norm(cur, model.layers[il].ffn_post_norm, nullptr, LLM_NORM_RMS, -1);
|
||||
cb(cur, "ffn_post_norm", il);
|
||||
|
||||
cur = ggml_add(ctx0, cur, attn_out);
|
||||
|
||||
cur = ggml_mul(ctx0, cur, model.layers[il].out_scale);
|
||||
cb(cur, "out_scaled", il);
|
||||
|
||||
inpL = cur;
|
||||
}
|
||||
cur = inpL;
|
||||
|
||||
cur = build_norm(cur, model.output_norm, nullptr, LLM_NORM_RMS, -1);
|
||||
cb(cur, "result_norm", -1);
|
||||
|
||||
ggml_tensor * logits = build_lora_mm(model.output, cur);
|
||||
cb(logits, "result_output", -1);
|
||||
res->t_logits = logits;
|
||||
|
||||
ggml_tensor * h_next = ggml_mul_mat(ctx0, model.nextn_proj_post, cur);
|
||||
cb(h_next, "h_nextn", -1);
|
||||
res->t_h_nextn = h_next;
|
||||
|
||||
ggml_build_forward_expand(gf, logits);
|
||||
ggml_build_forward_expand(gf, h_next);
|
||||
}
|
||||
@@ -155,12 +155,14 @@ public:
|
||||
}
|
||||
virtual ~llm_graph_input_logits_bias() = default;
|
||||
|
||||
void set_input(const llama_ubatch *) override {
|
||||
void set_input(const llama_ubatch * /*ubatch*/) override {
|
||||
const int64_t n_vocab = arr.size();
|
||||
ggml_backend_tensor_set(logits_bias, arr.data(), 0, n_vocab*ggml_element_size(logits_bias));
|
||||
}
|
||||
|
||||
// bool can_reuse(const llm_graph_params & params) override;
|
||||
bool can_reuse(const llm_graph_params & /*params*/) override {
|
||||
return true;
|
||||
}
|
||||
|
||||
ggml_tensor * logits_bias = nullptr; // F32 [n_vocab]
|
||||
|
||||
@@ -270,7 +272,8 @@ llama_model_gemma4::graph::graph(const llama_model & model, const llm_graph_para
|
||||
}
|
||||
|
||||
// TODO @ngxson : strip unused token right after the last KV layer to speed up prompt processing
|
||||
if (il == n_layer - 1 && inp_out_ids) {
|
||||
// keep all rows when extracting unmasked nextn embeddings (MTP target needs the hidden state for every token)
|
||||
if (il == n_layer - 1 && inp_out_ids && cparams.embeddings_nextn_masked) {
|
||||
cur = ggml_get_rows(ctx0, cur, inp_out_ids);
|
||||
inpL = ggml_get_rows(ctx0, inpL, inp_out_ids);
|
||||
}
|
||||
@@ -370,7 +373,7 @@ llama_model_gemma4::graph::graph(const llama_model & model, const llm_graph_para
|
||||
ggml_tensor * inp_this_layer = ggml_view_2d_slice(ctx0, inp_per_layer, il); // [n_embd_per_layer, n_tokens]
|
||||
|
||||
// TODO @ngxson : improve this
|
||||
if (il == n_layer - 1 && inp_out_ids) {
|
||||
if (il == n_layer - 1 && inp_out_ids && cparams.embeddings_nextn_masked) {
|
||||
inp_this_layer = ggml_get_rows(ctx0, inp_this_layer, inp_out_ids);
|
||||
}
|
||||
|
||||
@@ -401,6 +404,17 @@ llama_model_gemma4::graph::graph(const llama_model & model, const llm_graph_para
|
||||
model.output_norm, nullptr,
|
||||
LLM_NORM_RMS, -1);
|
||||
|
||||
// Expose the post-output-norm hidden state (the LM-head input feature) so that
|
||||
// MTP draft contexts can read it via llama_get_embeddings_nextn_ith() as the
|
||||
// recurrent h input. This matches the reference (transformers/vLLM/SGLang),
|
||||
// which feeds the drafter the target's post-final-norm hidden state.
|
||||
cb(cur, "h_nextn", -1);
|
||||
res->t_h_nextn = cur;
|
||||
|
||||
if (!cparams.embeddings_nextn_masked && inp_out_ids) {
|
||||
cur = ggml_get_rows(ctx0, cur, inp_out_ids);
|
||||
}
|
||||
|
||||
cb(cur, "result_norm", -1);
|
||||
res->t_embd = cur;
|
||||
|
||||
|
||||
@@ -822,6 +822,19 @@ struct llama_model_gemma4 : public llama_model_base {
|
||||
};
|
||||
|
||||
|
||||
struct llama_model_gemma4_assistant : public llama_model_base {
|
||||
llama_model_gemma4_assistant(const struct llama_model_params & params) : llama_model_base(params) {}
|
||||
void load_arch_hparams(llama_model_loader & ml) override;
|
||||
void load_arch_tensors(llama_model_loader & ml) override;
|
||||
|
||||
struct graph : public llm_graph_context {
|
||||
graph(const llama_model & model, const llm_graph_params & params);
|
||||
};
|
||||
|
||||
std::unique_ptr<llm_graph_context> build_arch_graph(const llm_graph_params & params) const override;
|
||||
};
|
||||
|
||||
|
||||
struct llama_model_gemma_embedding : public llama_model_base {
|
||||
llama_model_gemma_embedding(const struct llama_model_params & params) : llama_model_base(params) {}
|
||||
void load_arch_hparams(llama_model_loader & ml) override;
|
||||
|
||||
@@ -392,7 +392,7 @@ static bool arch_supported(const llm_arch arch) {
|
||||
if (arch == LLM_ARCH_WAVTOKENIZER_DEC) {
|
||||
return false; // FIXME CUDA backend crashes.
|
||||
}
|
||||
if (arch == LLM_ARCH_GEMMA4) {
|
||||
if (arch == LLM_ARCH_GEMMA4 || arch == LLM_ARCH_GEMMA4_ASSISTANT) {
|
||||
return false; // FIXME @ngxson
|
||||
}
|
||||
if (arch == LLM_ARCH_LLAMA_EMBED || arch == LLM_ARCH_GEMMA_EMBEDDING || arch == LLM_ARCH_T5ENCODER) {
|
||||
@@ -447,7 +447,7 @@ static int save_models(const llm_arch target_arch, const size_t seed, const ggml
|
||||
if (target_arch != LLM_ARCH_UNKNOWN && arch != target_arch) {
|
||||
continue;
|
||||
}
|
||||
if (arch == LLM_ARCH_GEMMA4) {
|
||||
if (arch == LLM_ARCH_GEMMA4 || arch == LLM_ARCH_GEMMA4_ASSISTANT) {
|
||||
continue; // FIXME: ISWA KV cache initialization needs more fixture params
|
||||
}
|
||||
for (bool moe : {false, true}) {
|
||||
@@ -550,7 +550,7 @@ static int test_backends(const llm_arch target_arch, const size_t seed, const gg
|
||||
if (target_arch != LLM_ARCH_UNKNOWN && arch != target_arch) {
|
||||
continue;
|
||||
}
|
||||
if (arch == LLM_ARCH_GEMMA4) {
|
||||
if (arch == LLM_ARCH_GEMMA4 || arch == LLM_ARCH_GEMMA4_ASSISTANT) {
|
||||
continue; // FIXME: ISWA KV cache initialization needs more fixture params
|
||||
}
|
||||
|
||||
|
||||
@@ -1,4 +1,3 @@
|
||||
|
||||
#include "server-context.h"
|
||||
#include "server-chat.h"
|
||||
#include "server-common.h"
|
||||
@@ -16,6 +15,11 @@
|
||||
#include "mtmd.h"
|
||||
#include "mtmd-helper.h"
|
||||
|
||||
#include "ggml-cpp.h"
|
||||
|
||||
// TODO: tmp until the mtmd draft processing is refactored [TAG_MTMD_DRAFT_PROCESSING]
|
||||
#include "../../src/llama-ext.h"
|
||||
|
||||
#include <algorithm>
|
||||
#include <cstddef>
|
||||
#include <cinttypes>
|
||||
@@ -884,7 +888,7 @@ private:
|
||||
has_draft ? "draft model" : "MTP context",
|
||||
total / (1024.0 * 1024.0));
|
||||
} catch (const std::exception & e) {
|
||||
SRV_ERR("[spec] failed to measure %s memory: %s\n",
|
||||
SRV_WRN("[spec] failed to measure %s memory: %s\n",
|
||||
has_draft ? "draft model" : "MTP context", e.what());
|
||||
}
|
||||
}
|
||||
@@ -940,16 +944,17 @@ private:
|
||||
const bool spec_mtp = std::find(params_base.speculative.types.begin(),
|
||||
params_base.speculative.types.end(),
|
||||
COMMON_SPECULATIVE_TYPE_DRAFT_MTP) != params_base.speculative.types.end();
|
||||
|
||||
if (spec_mtp) {
|
||||
cparams.ctx_type = LLAMA_CONTEXT_TYPE_MTP;
|
||||
}
|
||||
|
||||
// note: for small models maybe we can set this to the maximum possible draft from all speculative types
|
||||
// the extra memory for small models is likely negligible?
|
||||
cparams.n_rs_seq = 0;
|
||||
ctx_dft.reset(llama_init_from_model(model_dft.get(), cparams));
|
||||
cparams.n_rs_seq = 0;
|
||||
cparams.ctx_other = ctx_tgt;
|
||||
|
||||
ctx_dft_seq_rm_type = common_context_can_seq_rm(ctx_dft.get());
|
||||
ctx_dft.reset(llama_init_from_model(model_dft.get(), cparams));
|
||||
|
||||
params_base.speculative.draft.ctx_tgt = ctx_tgt;
|
||||
params_base.speculative.draft.ctx_dft = ctx_dft.get();
|
||||
@@ -964,6 +969,7 @@ private:
|
||||
cparams_mtp.type_v = params_base.speculative.draft.cache_type_v;
|
||||
cparams_mtp.n_rs_seq = 0;
|
||||
cparams_mtp.n_outputs_max = params_base.n_parallel;
|
||||
cparams_mtp.ctx_other = ctx_tgt;
|
||||
|
||||
ctx_dft.reset(llama_init_from_model(model_tgt, cparams_mtp));
|
||||
if (ctx_dft == nullptr) {
|
||||
@@ -971,8 +977,6 @@ private:
|
||||
return false;
|
||||
}
|
||||
|
||||
ctx_dft_seq_rm_type = common_context_can_seq_rm(ctx_dft.get());
|
||||
|
||||
params_base.speculative.draft.ctx_tgt = ctx_tgt;
|
||||
params_base.speculative.draft.ctx_dft = ctx_dft.get();
|
||||
}
|
||||
@@ -1060,6 +1064,10 @@ private:
|
||||
}
|
||||
}
|
||||
|
||||
if (ctx_dft) {
|
||||
ctx_dft_seq_rm_type = common_context_can_seq_rm(ctx_dft.get());
|
||||
}
|
||||
|
||||
if (spec) {
|
||||
SRV_INF("%s", "speculative decoding context initialized\n");
|
||||
} else {
|
||||
@@ -2974,10 +2982,11 @@ private:
|
||||
continue;
|
||||
}
|
||||
|
||||
if (ctx_dft) {
|
||||
if (ctx_dft && llama_get_ctx_other(ctx_dft.get()) != ctx_tgt) {
|
||||
// TODO: in the future, figure out how to infuse target embeddings to the images
|
||||
// for now, we skip this for simplicity
|
||||
// maybe we simply need to call `common_speculative_process()` on the mtmd batches in the `process_chunk` above?
|
||||
// [TAG_MTMD_DRAFT_PROCESSING]
|
||||
res = input_tokens.process_chunk(ctx_dft.get(), mctx, slot.prompt.n_tokens(), slot.prompt.tokens.pos_next(), slot.id, n_tokens_out);
|
||||
if (res != 0) {
|
||||
GGML_ABORT("failed to process multi-modal data on draft context\n");
|
||||
|
||||
Reference in New Issue
Block a user