mirror of
https://github.com/ggml-org/llama.cpp.git
synced 2026-06-30 11:13:02 +02:00
Compare commits
13 Commits
| Author | SHA1 | Date | |
|---|---|---|---|
| 6c5de1cc83 | |||
| 86b94708f2 | |||
| 6f4f53f2b7 | |||
| 25a1d63f43 | |||
| 8c146a8366 | |||
| 6cb18b2f2e | |||
| 277a105dc8 | |||
| b3fed31b99 | |||
| dbdaece23d | |||
| 7cb8576e7c | |||
| fa72bc6826 | |||
| c818263f2a | |||
| f68a788b0b |
@@ -94,10 +94,8 @@ add_library(${TARGET}
|
||||
peg-parser.h
|
||||
preset.cpp
|
||||
preset.h
|
||||
regex-partial.cpp
|
||||
reasoning-budget.cpp
|
||||
reasoning-budget.h
|
||||
regex-partial.h
|
||||
sampling.cpp
|
||||
sampling.h
|
||||
speculative.cpp
|
||||
|
||||
@@ -3296,6 +3296,20 @@ common_params_context common_params_parser_init(common_params & params, llama_ex
|
||||
params.sampling.reasoning_budget_message = value;
|
||||
}
|
||||
).set_examples({LLAMA_EXAMPLE_SERVER, LLAMA_EXAMPLE_COMPLETION, LLAMA_EXAMPLE_CLI}).set_env("LLAMA_ARG_THINK_BUDGET_MESSAGE"));
|
||||
add_opt(common_arg(
|
||||
{"--reasoning-preserve"},
|
||||
{"--no-reasoning-preserve"},
|
||||
"preserve reasoning trace in the full history, not just the last assistant message (default: template default)\n"
|
||||
"compatible with certain templates having 'supports_preserve_reasoning' capability\n"
|
||||
"example: https://docs.z.ai/guides/capabilities/thinking-mode#preserved-thinking",
|
||||
[](common_params & params, bool value) {
|
||||
if (value) {
|
||||
params.default_template_kwargs["preserve_reasoning"] = "true";
|
||||
} else {
|
||||
params.default_template_kwargs["preserve_reasoning"] = "false";
|
||||
}
|
||||
}
|
||||
).set_examples({LLAMA_EXAMPLE_SERVER, LLAMA_EXAMPLE_COMPLETION, LLAMA_EXAMPLE_CLI}).set_env("LLAMA_ARG_REASONING_PRESERVE"));
|
||||
add_opt(common_arg(
|
||||
{"--chat-template"}, "JINJA_TEMPLATE",
|
||||
string_format(
|
||||
|
||||
+155
@@ -912,6 +912,10 @@ static std::string common_chat_template_direct_apply_impl(
|
||||
if (inputs.add_generation_prompt) {
|
||||
inp["add_generation_prompt"] = true;
|
||||
}
|
||||
if (inp.contains("preserve_reasoning") && inp["preserve_reasoning"].is_boolean()) {
|
||||
bool enabled = inp["preserve_reasoning"].get<bool>();
|
||||
jinja::caps_apply_preserve_reasoning(ctx, enabled);
|
||||
}
|
||||
|
||||
jinja::global_from_json(ctx, inp, inputs.mark_input);
|
||||
|
||||
@@ -2376,6 +2380,149 @@ static void func_args_not_string(json & messages) {
|
||||
|
||||
}
|
||||
|
||||
// MiniCPM5 format:
|
||||
// - Reasoning: <think>{reasoning}</think> (optional)
|
||||
// - Tool calls: <function name="foo"><param name="bar">value</param></function>
|
||||
static common_chat_params common_chat_params_init_minicpm5(const common_chat_template & tmpl,
|
||||
const autoparser::generation_params & inputs) {
|
||||
common_chat_params data;
|
||||
|
||||
data.prompt = common_chat_template_direct_apply_impl(tmpl, inputs);
|
||||
data.generation_prompt = common_chat_template_generation_prompt_impl(tmpl, inputs);
|
||||
data.format = COMMON_CHAT_FORMAT_PEG_NATIVE;
|
||||
data.supports_thinking = true;
|
||||
data.preserved_tokens = {
|
||||
"<function",
|
||||
"<param",
|
||||
"</function>",
|
||||
"</param>",
|
||||
"<think>",
|
||||
"</think>",
|
||||
};
|
||||
|
||||
data.thinking_start_tag = "<think>";
|
||||
data.thinking_end_tag = "</think>";
|
||||
|
||||
data.message_delimiters = {
|
||||
{ COMMON_CHAT_ROLE_ASSISTANT, "<|im_start|>assistant" },
|
||||
{ COMMON_CHAT_ROLE_TOOL, "<|im_start|>user\n<tool_response>" },
|
||||
{ COMMON_CHAT_ROLE_USER, "<|im_start|>user" },
|
||||
{ COMMON_CHAT_ROLE_SYSTEM, "<|im_start|>system" },
|
||||
};
|
||||
|
||||
auto has_tools = inputs.tools.is_array() && !inputs.tools.empty();
|
||||
auto has_response_format = inputs.json_schema.is_object() && !inputs.json_schema.empty();
|
||||
auto extract_reasoning = inputs.reasoning_format != COMMON_REASONING_FORMAT_NONE;
|
||||
auto include_grammar = has_response_format || (has_tools && inputs.tool_choice != COMMON_CHAT_TOOL_CHOICE_NONE);
|
||||
|
||||
if (inputs.has_continuation()) {
|
||||
const auto & msg = inputs.continue_msg;
|
||||
|
||||
data.generation_prompt = "<|im_start|>assistant\n<think>\n" + msg.reasoning_content;
|
||||
if (inputs.continue_final_message == COMMON_CHAT_CONTINUATION_CONTENT) {
|
||||
data.generation_prompt += "\n</think>\n\n" + msg.render_content();
|
||||
}
|
||||
|
||||
data.prompt += data.generation_prompt;
|
||||
}
|
||||
|
||||
auto parser = build_chat_peg_parser([&](common_chat_peg_builder & p) {
|
||||
auto generation_prompt = p.literal("<|im_start|>assistant\n");
|
||||
|
||||
auto reasoning = p.eps();
|
||||
if (extract_reasoning) {
|
||||
reasoning = ("<think>" << p.reasoning(p.until("</think>")) << "</think>") + p.space();
|
||||
}
|
||||
|
||||
// Response format parser
|
||||
if (has_response_format) {
|
||||
return generation_prompt + reasoning + p.content(p.schema(p.json(), "response-format", inputs.json_schema));
|
||||
}
|
||||
|
||||
if (has_tools && inputs.tool_choice != COMMON_CHAT_TOOL_CHOICE_NONE) {
|
||||
// CDATA lets a value carry characters that would otherwise close the tag (e.g.
|
||||
// </param>); capture the inner text only, excluding the CDATA markers.
|
||||
auto string_value = p.choice({
|
||||
p.literal("<![CDATA[") + p.ac(p.tool_arg_string_value(p.until("]]>")) + p.literal("]]>"), "]]>") + p.tool_arg_close(p.literal("</param>")),
|
||||
p.negate(p.literal("< {
|
||||
const auto & function = tool.at("function");
|
||||
const std::string name = function.at("name");
|
||||
auto params = function.contains("parameters") ? function.at("parameters") : json::object();
|
||||
|
||||
auto args = p.eps();
|
||||
if (params.contains("properties") && params.at("properties").is_object() && !params.at("properties").empty()) {
|
||||
auto schema_info = common_schema_info();
|
||||
schema_info.resolve_refs(params);
|
||||
|
||||
auto arg_choice = p.choice();
|
||||
for (const auto & [prop_name, prop_schema] : params.at("properties").items()) {
|
||||
auto value_parser = p.eps();
|
||||
if (schema_info.resolves_to_string(prop_schema)) {
|
||||
value_parser = string_value;
|
||||
} else {
|
||||
value_parser = p.tool_arg_json_value(
|
||||
p.schema(p.json(), "tool-" + name + "-arg-" + prop_name + "-schema", prop_schema, false)
|
||||
) + p.tool_arg_close(p.literal("</param>"));
|
||||
}
|
||||
|
||||
auto arg_rule = p.tool_arg(
|
||||
p.tool_arg_open(p.literal("<param name=\"") + p.tool_arg_name(p.literal(prop_name)) + p.literal("\">")) +
|
||||
value_parser
|
||||
);
|
||||
|
||||
arg_choice |= arg_rule;
|
||||
}
|
||||
args = p.zero_or_more(arg_choice + p.space());
|
||||
}
|
||||
|
||||
auto tool_parser = p.tool(
|
||||
p.tool_open(p.literal("<function name=\"") + p.tool_name(p.literal(name)) + p.literal("\">"))
|
||||
<< p.tool_args(args)
|
||||
<< p.tool_close(p.literal("</function>")));
|
||||
|
||||
tool_choice |= p.rule("tool-" + name, tool_parser);
|
||||
});
|
||||
|
||||
auto max_calls = inputs.parallel_tool_calls ? -1 : 1;
|
||||
auto tool_calls = p.trigger_rule("tool-call", p.repeat(tool_choice + p.space(), 1, max_calls));
|
||||
|
||||
auto content = p.content(p.until("<function"));
|
||||
|
||||
return generation_prompt + reasoning + content + tool_calls + p.end();
|
||||
}
|
||||
|
||||
return generation_prompt + reasoning + p.content(p.rest()) + p.end();
|
||||
});
|
||||
|
||||
data.parser = parser.save();
|
||||
|
||||
if (include_grammar) {
|
||||
data.grammar_lazy = !(has_response_format || (has_tools && inputs.tool_choice == COMMON_CHAT_TOOL_CHOICE_REQUIRED));
|
||||
data.grammar = build_grammar([&](const common_grammar_builder & builder) {
|
||||
foreach_function(inputs.tools, [&](const json & tool) {
|
||||
const auto & function = tool.at("function");
|
||||
auto schema = function.contains("parameters") ? function.at("parameters") : json::object();
|
||||
builder.resolve_refs(schema);
|
||||
});
|
||||
if (has_response_format) {
|
||||
auto schema = inputs.json_schema;
|
||||
builder.resolve_refs(schema);
|
||||
}
|
||||
parser.build_grammar(builder, data.grammar_lazy);
|
||||
});
|
||||
|
||||
data.grammar_triggers = {
|
||||
{ COMMON_GRAMMAR_TRIGGER_TYPE_WORD, "<function" },
|
||||
};
|
||||
}
|
||||
|
||||
return data;
|
||||
}
|
||||
|
||||
static json common_chat_extra_context() {
|
||||
json ctx = json::object();
|
||||
std::chrono::system_clock::time_point now = std::chrono::system_clock::now();
|
||||
@@ -2468,6 +2615,14 @@ std::optional<common_chat_params> common_chat_try_specialized_template(
|
||||
return common_chat_params_init_gemma4(tmpl, params);
|
||||
}
|
||||
|
||||
// MiniCPM5 - XML tool calls with <function name="..."><param name="...">...</param></function>
|
||||
if (src.find("Tool usage guidelines:") != std::string::npos &&
|
||||
src.find("<function name=\"") != std::string::npos &&
|
||||
src.find("<param name=\"") != std::string::npos) {
|
||||
LOG_DBG("Using specialized template: MiniCPM5\n");
|
||||
return common_chat_params_init_minicpm5(tmpl, params);
|
||||
}
|
||||
|
||||
return std::nullopt;
|
||||
}
|
||||
|
||||
|
||||
+44
-23
@@ -16,22 +16,34 @@ using json = nlohmann::ordered_json;
|
||||
namespace jinja {
|
||||
|
||||
using caps_json_fn = std::function<json()>;
|
||||
using caps_analyze_fn = std::function<void(bool, value &, value &)>;
|
||||
using caps_ctx_fn = std::function<void(context &)>;
|
||||
using caps_analyze_fn = std::function<void(bool, value &, value &, const std::string &)>;
|
||||
|
||||
void caps_apply_preserve_reasoning(jinja::context & ctx, bool enabled) {
|
||||
ctx.set_val("preserve_thinking", mk_val<value_bool>(enabled));
|
||||
ctx.set_val("clear_thinking", mk_val<value_bool>(!enabled));
|
||||
ctx.set_val("truncate_history_thinking", mk_val<value_bool>(!enabled));
|
||||
}
|
||||
|
||||
static void caps_try_execute(jinja::program & prog,
|
||||
const caps_json_fn & messages_fn,
|
||||
const caps_ctx_fn & ctx_fn,
|
||||
const caps_json_fn & tools_fn,
|
||||
const caps_analyze_fn & analyze_fn) {
|
||||
context ctx;
|
||||
ctx.is_get_stats = true;
|
||||
jinja::global_from_json(ctx, json{
|
||||
{"messages", messages_fn()},
|
||||
{"tools", tools_fn()},
|
||||
{"tools", tools_fn ? tools_fn() : json::array()},
|
||||
{"bos_token", ""},
|
||||
{"eos_token", ""},
|
||||
{"add_generation_prompt", true}
|
||||
}, true);
|
||||
|
||||
if (ctx_fn) {
|
||||
ctx_fn(ctx);
|
||||
}
|
||||
|
||||
auto messages = ctx.get_val("messages");
|
||||
auto tools = ctx.get_val("tools");
|
||||
|
||||
@@ -49,7 +61,7 @@ static void caps_try_execute(jinja::program & prog,
|
||||
// ignore exceptions during capability analysis
|
||||
}
|
||||
|
||||
analyze_fn(success, messages, tools);
|
||||
analyze_fn(success, messages, tools, result);
|
||||
}
|
||||
|
||||
// for debugging only
|
||||
@@ -109,11 +121,9 @@ caps caps_get(jinja::program & prog) {
|
||||
}
|
||||
});
|
||||
},
|
||||
[&]() {
|
||||
// tools
|
||||
return json{nullptr};
|
||||
},
|
||||
[&](bool success, value & messages, value &) {
|
||||
nullptr, // ctx_fn
|
||||
nullptr, // tools_fn
|
||||
[&](bool success, value & messages, value &, const std::string &) {
|
||||
auto & content = messages->at(0)->at("content");
|
||||
caps_print_stats(content, "messages[0].content");
|
||||
if (has_op(content, "selectattr") || has_op(content, "array_access")) {
|
||||
@@ -145,11 +155,9 @@ caps caps_get(jinja::program & prog) {
|
||||
},
|
||||
});
|
||||
},
|
||||
[&]() {
|
||||
// tools
|
||||
return json::array();
|
||||
},
|
||||
[&](bool, value & messages, value &) {
|
||||
nullptr, // ctx_fn
|
||||
nullptr, // tools_fn
|
||||
[&](bool, value & messages, value &, const std::string &) {
|
||||
auto & content = messages->at(0)->at("content");
|
||||
caps_print_stats(content, "messages[0].content");
|
||||
if (!content->stats.used) {
|
||||
@@ -201,6 +209,7 @@ caps caps_get(jinja::program & prog) {
|
||||
},
|
||||
});
|
||||
},
|
||||
nullptr, // ctx_fn
|
||||
[&]() {
|
||||
// tools
|
||||
return json::array({
|
||||
@@ -224,7 +233,7 @@ caps caps_get(jinja::program & prog) {
|
||||
},
|
||||
});
|
||||
},
|
||||
[&](bool success, value & messages, value & tools) {
|
||||
[&](bool success, value & messages, value & tools, const std::string &) {
|
||||
if (!success) {
|
||||
return; // Nothing can be inferred
|
||||
}
|
||||
@@ -293,6 +302,7 @@ caps caps_get(jinja::program & prog) {
|
||||
},
|
||||
});
|
||||
},
|
||||
nullptr, // ctx_fn
|
||||
[&]() {
|
||||
// tools
|
||||
return json::array({
|
||||
@@ -316,7 +326,7 @@ caps caps_get(jinja::program & prog) {
|
||||
},
|
||||
});
|
||||
},
|
||||
[&](bool success, value & messages, value & tools) {
|
||||
[&](bool success, value & messages, value & tools, const std::string &) {
|
||||
if (!success) {
|
||||
result.supports_tool_calls = false;
|
||||
result.supports_tools = false;
|
||||
@@ -394,6 +404,7 @@ caps caps_get(jinja::program & prog) {
|
||||
},
|
||||
});
|
||||
},
|
||||
nullptr, // ctx_fn
|
||||
[&]() {
|
||||
// tools
|
||||
return json::array({
|
||||
@@ -417,7 +428,7 @@ caps caps_get(jinja::program & prog) {
|
||||
},
|
||||
});
|
||||
},
|
||||
[&](bool success, value & messages, value & /*tools*/) {
|
||||
[&](bool success, value & messages, value &, const std::string &) {
|
||||
if (!success) {
|
||||
result.supports_parallel_tool_calls = false;
|
||||
return;
|
||||
@@ -438,11 +449,22 @@ caps caps_get(jinja::program & prog) {
|
||||
JJ_DEBUG("%s\n", ">>> Running capability check: preserve reasoning");
|
||||
|
||||
// case: preserve reasoning content in chat history
|
||||
const std::string reasoning_placeholder = "<REASONING_CONTENT_PLACEHOLDER>";
|
||||
caps_try_execute(
|
||||
prog,
|
||||
[&]() {
|
||||
// messages
|
||||
return json::array({
|
||||
{
|
||||
{"role", "user"},
|
||||
{"content", "User message"}
|
||||
},
|
||||
{
|
||||
{"role", "assistant"},
|
||||
{"content", "Assistant message"},
|
||||
// check of reasoning_content deeper in the history, not just the last assistant message
|
||||
{"reasoning_content", reasoning_placeholder}
|
||||
},
|
||||
{
|
||||
{"role", "user"},
|
||||
{"content", "User message"}
|
||||
@@ -458,14 +480,13 @@ caps caps_get(jinja::program & prog) {
|
||||
},
|
||||
});
|
||||
},
|
||||
[&]() {
|
||||
// tools
|
||||
return json::array();
|
||||
[&](context & ctx) {
|
||||
caps_apply_preserve_reasoning(ctx, true);
|
||||
},
|
||||
[&](bool, value & messages, value &) {
|
||||
auto & content = messages->at(1)->at("reasoning_content");
|
||||
caps_print_stats(content, "messages[1].reasoning_content");
|
||||
if (content->stats.used) {
|
||||
nullptr, // tools_fn
|
||||
[&](bool, value &, value &, const std::string & output) {
|
||||
// note: we cannot use stats here because the reasoning_content may be used for "if" condition test, but not actually outputted in the final result
|
||||
if (output.find(reasoning_placeholder) != std::string::npos) {
|
||||
result.supports_preserve_reasoning = true;
|
||||
}
|
||||
}
|
||||
|
||||
+5
-1
@@ -12,7 +12,9 @@ struct caps {
|
||||
bool supports_tool_calls = true;
|
||||
bool supports_system_role = true;
|
||||
bool supports_parallel_tool_calls = true;
|
||||
bool supports_preserve_reasoning = false; // support assistant message with reasoning_content
|
||||
|
||||
// supports preserve reasoning trace in the full history, not just the last assistant message
|
||||
bool supports_preserve_reasoning = false;
|
||||
|
||||
// one of the 2 content capabilities must be true
|
||||
bool supports_string_content = true;
|
||||
@@ -29,4 +31,6 @@ struct caps {
|
||||
|
||||
caps caps_get(jinja::program & prog);
|
||||
|
||||
void caps_apply_preserve_reasoning(jinja::context & ctx, bool enabled);
|
||||
|
||||
} // namespace jinja
|
||||
|
||||
@@ -954,4 +954,50 @@ value keyword_argument_expression::execute_impl(context & ctx) {
|
||||
return mk_val<value_kwarg>(k, v);
|
||||
}
|
||||
|
||||
std::string runtime::debug_dump_program(const program & prog, const std::string & src) {
|
||||
std::ostringstream oss;
|
||||
size_t lvl = 0;
|
||||
context ctx;
|
||||
ctx.src.reset(new std::string(src));
|
||||
|
||||
auto indent = [](size_t lvl) -> std::string {
|
||||
return std::string(lvl * 2, ' ');
|
||||
};
|
||||
|
||||
ctx.visitor = [&](bool is_leaf, statement * node, std::vector<visitor_pair> children) {
|
||||
oss << indent(lvl) << node->type() << ":\n";
|
||||
lvl++;
|
||||
if (is_leaf) {
|
||||
const auto & pos = node->pos;
|
||||
oss << indent(lvl) << "(leaf) at " << get_line_col(src, pos) << " in source:\n";
|
||||
std::string snippet = peak_source(src, pos);
|
||||
string_replace_all(snippet, "\n", "\n" + indent(lvl));
|
||||
oss << indent(lvl) << snippet << "\n";
|
||||
} else {
|
||||
for (auto & [label, children_vec] : children) {
|
||||
oss << indent(lvl) << label << ":\n";
|
||||
lvl++;
|
||||
if (children_vec.empty()) {
|
||||
oss << indent(lvl) << "<empty>\n\n";
|
||||
} else {
|
||||
for (auto * child : children_vec) {
|
||||
if (!child) {
|
||||
continue;
|
||||
}
|
||||
child->visit(ctx);
|
||||
}
|
||||
}
|
||||
lvl--;
|
||||
}
|
||||
}
|
||||
lvl--;
|
||||
};
|
||||
|
||||
for (const auto & stmt : prog.body) {
|
||||
stmt->visit(ctx);
|
||||
}
|
||||
|
||||
return oss.str();
|
||||
}
|
||||
|
||||
} // namespace jinja
|
||||
|
||||
@@ -47,12 +47,19 @@ const T * cast_stmt(const statement_ptr & ptr) {
|
||||
// not thread-safe
|
||||
void enable_debug(bool enable);
|
||||
|
||||
// for visiting AST nodes
|
||||
// function signature: void(bool is_leaf, statement * node, pair of <label, children>)
|
||||
using visitor_pair = std::pair<std::string, std::vector<statement *>>;
|
||||
using visitor_fn = std::function<void(bool, statement *, std::vector<visitor_pair>)>;
|
||||
|
||||
struct context {
|
||||
std::shared_ptr<std::string> src; // for debugging; use shared_ptr to avoid copying on scope creation
|
||||
std::time_t current_time; // for functions that need current time
|
||||
|
||||
bool is_get_stats = false; // whether to collect stats
|
||||
|
||||
visitor_fn visitor;
|
||||
|
||||
// src is optional, used for error reporting
|
||||
context(std::string src = "") : src(std::make_shared<std::string>(std::move(src))) {
|
||||
env = mk_val<value_object>();
|
||||
@@ -99,6 +106,15 @@ private:
|
||||
value_object env;
|
||||
};
|
||||
|
||||
// utils for visiting AST nodes
|
||||
static std::vector<statement *> stmts_to_ptr(const statements & stmts) {
|
||||
std::vector<statement *> children;
|
||||
for (const auto & stmt : stmts) {
|
||||
children.push_back(stmt.get());
|
||||
}
|
||||
return children;
|
||||
}
|
||||
|
||||
/**
|
||||
* Base class for all nodes in the AST.
|
||||
*/
|
||||
@@ -106,6 +122,7 @@ struct statement {
|
||||
size_t pos; // position in source, for debugging
|
||||
virtual ~statement() = default;
|
||||
virtual std::string type() const { return "Statement"; }
|
||||
virtual void visit(context & ctx) { ctx.visitor(true, this, {}); }
|
||||
|
||||
// execute_impl must be overridden by derived classes
|
||||
virtual value execute_impl(context &) { throw_exec_error(); }
|
||||
@@ -166,6 +183,13 @@ struct if_statement : public statement {
|
||||
|
||||
std::string type() const override { return "If"; }
|
||||
value execute_impl(context & ctx) override;
|
||||
void visit(context & ctx) override {
|
||||
ctx.visitor(false, this, {
|
||||
{"test", {test.get()}},
|
||||
{"body", stmts_to_ptr(body)},
|
||||
{"alternate", stmts_to_ptr(alternate)}
|
||||
});
|
||||
}
|
||||
};
|
||||
|
||||
struct identifier;
|
||||
@@ -190,6 +214,14 @@ struct for_statement : public statement {
|
||||
|
||||
std::string type() const override { return "For"; }
|
||||
value execute_impl(context & ctx) override;
|
||||
void visit(context & ctx) override {
|
||||
ctx.visitor(false, this, {
|
||||
{"loopvar", {loopvar.get()}},
|
||||
{"iterable", {iterable.get()}},
|
||||
{"body", stmts_to_ptr(body)},
|
||||
{"default_block", stmts_to_ptr(default_block)}
|
||||
});
|
||||
}
|
||||
};
|
||||
|
||||
struct break_statement : public statement {
|
||||
@@ -241,6 +273,13 @@ struct set_statement : public statement {
|
||||
|
||||
std::string type() const override { return "Set"; }
|
||||
value execute_impl(context & ctx) override;
|
||||
void visit(context & ctx) override {
|
||||
ctx.visitor(false, this, {
|
||||
{"assignee", {assignee.get()}},
|
||||
{"value", {val.get()}},
|
||||
{"body", stmts_to_ptr(body)}
|
||||
});
|
||||
}
|
||||
};
|
||||
|
||||
struct macro_statement : public statement {
|
||||
@@ -256,6 +295,13 @@ struct macro_statement : public statement {
|
||||
|
||||
std::string type() const override { return "Macro"; }
|
||||
value execute_impl(context & ctx) override;
|
||||
void visit(context & ctx) override {
|
||||
ctx.visitor(false, this, {
|
||||
{"name", {name.get()}},
|
||||
{"args", stmts_to_ptr(args)},
|
||||
{"body", stmts_to_ptr(body)}
|
||||
});
|
||||
}
|
||||
};
|
||||
|
||||
struct comment_statement : public statement {
|
||||
@@ -289,6 +335,12 @@ struct member_expression : public expression {
|
||||
}
|
||||
std::string type() const override { return "MemberExpression"; }
|
||||
value execute_impl(context & ctx) override;
|
||||
void visit(context & ctx) override {
|
||||
ctx.visitor(false, this, {
|
||||
{"object", {object.get()}},
|
||||
{"property", {property.get()}}
|
||||
});
|
||||
}
|
||||
};
|
||||
|
||||
struct call_expression : public expression {
|
||||
@@ -302,6 +354,12 @@ struct call_expression : public expression {
|
||||
}
|
||||
std::string type() const override { return "CallExpression"; }
|
||||
value execute_impl(context & ctx) override;
|
||||
void visit(context & ctx) override {
|
||||
ctx.visitor(false, this, {
|
||||
{"callee", {callee.get()}},
|
||||
{"args", stmts_to_ptr(args)}
|
||||
});
|
||||
}
|
||||
};
|
||||
|
||||
/**
|
||||
@@ -405,6 +463,12 @@ struct binary_expression : public expression {
|
||||
}
|
||||
std::string type() const override { return "BinaryExpression"; }
|
||||
value execute_impl(context & ctx) override;
|
||||
void visit(context & ctx) override {
|
||||
ctx.visitor(false, this, {
|
||||
{"left", {left.get()}},
|
||||
{"right", {right.get()}}
|
||||
});
|
||||
}
|
||||
};
|
||||
|
||||
/**
|
||||
@@ -431,6 +495,12 @@ struct filter_expression : public expression {
|
||||
|
||||
std::string type() const override { return "FilterExpression"; }
|
||||
value execute_impl(context & ctx) override;
|
||||
void visit(context & ctx) override {
|
||||
ctx.visitor(false, this, {
|
||||
{"operand", {operand.get()}},
|
||||
{"filter", {filter.get()}}
|
||||
});
|
||||
}
|
||||
};
|
||||
|
||||
struct filter_statement : public statement {
|
||||
@@ -443,6 +513,12 @@ struct filter_statement : public statement {
|
||||
}
|
||||
std::string type() const override { return "FilterStatement"; }
|
||||
value execute_impl(context & ctx) override;
|
||||
void visit(context & ctx) override {
|
||||
ctx.visitor(false, this, {
|
||||
{"filter", {filter.get()}},
|
||||
{"body", stmts_to_ptr(body)}
|
||||
});
|
||||
}
|
||||
};
|
||||
|
||||
/**
|
||||
@@ -468,6 +544,12 @@ struct select_expression : public expression {
|
||||
}
|
||||
return lhs->execute_impl(ctx);
|
||||
}
|
||||
void visit(context & ctx) override {
|
||||
ctx.visitor(false, this, {
|
||||
{"lhs", {lhs.get()}},
|
||||
{"test", {test.get()}}
|
||||
});
|
||||
}
|
||||
};
|
||||
|
||||
/**
|
||||
@@ -486,6 +568,12 @@ struct test_expression : public expression {
|
||||
}
|
||||
std::string type() const override { return "TestExpression"; }
|
||||
value execute_impl(context & ctx) override;
|
||||
void visit(context & ctx) override {
|
||||
ctx.visitor(false, this, {
|
||||
{"operand", {operand.get()}},
|
||||
{"test", {test.get()}}
|
||||
});
|
||||
}
|
||||
};
|
||||
|
||||
/**
|
||||
@@ -501,6 +589,11 @@ struct unary_expression : public expression {
|
||||
}
|
||||
std::string type() const override { return "UnaryExpression"; }
|
||||
value execute_impl(context & ctx) override;
|
||||
void visit(context & ctx) override {
|
||||
ctx.visitor(false, this, {
|
||||
{"argument", {argument.get()}}
|
||||
});
|
||||
}
|
||||
};
|
||||
|
||||
struct slice_expression : public expression {
|
||||
@@ -518,6 +611,13 @@ struct slice_expression : public expression {
|
||||
[[noreturn]] value execute_impl(context &) override {
|
||||
throw std::runtime_error("must be handled by MemberExpression");
|
||||
}
|
||||
void visit(context & ctx) override {
|
||||
ctx.visitor(false, this, {
|
||||
{"start_expr", {start_expr.get()}},
|
||||
{"stop_expr", {stop_expr.get()}},
|
||||
{"step_expr", {step_expr.get()}}
|
||||
});
|
||||
}
|
||||
};
|
||||
|
||||
struct keyword_argument_expression : public expression {
|
||||
@@ -531,6 +631,12 @@ struct keyword_argument_expression : public expression {
|
||||
}
|
||||
std::string type() const override { return "KeywordArgumentExpression"; }
|
||||
value execute_impl(context & ctx) override;
|
||||
void visit(context & ctx) override {
|
||||
ctx.visitor(false, this, {
|
||||
{"key", {key.get()}},
|
||||
{"val", {val.get()}}
|
||||
});
|
||||
}
|
||||
};
|
||||
|
||||
struct spread_expression : public expression {
|
||||
@@ -539,6 +645,11 @@ struct spread_expression : public expression {
|
||||
chk_type<expression>(this->argument);
|
||||
}
|
||||
std::string type() const override { return "SpreadExpression"; }
|
||||
void visit(context & ctx) override {
|
||||
ctx.visitor(false, this, {
|
||||
{"argument", {argument.get()}}
|
||||
});
|
||||
}
|
||||
};
|
||||
|
||||
struct call_statement : public statement {
|
||||
@@ -553,6 +664,13 @@ struct call_statement : public statement {
|
||||
}
|
||||
std::string type() const override { return "CallStatement"; }
|
||||
value execute_impl(context & ctx) override;
|
||||
void visit(context & ctx) override {
|
||||
ctx.visitor(false, this, {
|
||||
{"call", {call.get()}},
|
||||
{"caller_args", stmts_to_ptr(caller_args)},
|
||||
{"body", stmts_to_ptr(body)}
|
||||
});
|
||||
}
|
||||
};
|
||||
|
||||
struct ternary_expression : public expression {
|
||||
@@ -575,6 +693,13 @@ struct ternary_expression : public expression {
|
||||
return false_expr->execute(ctx);
|
||||
}
|
||||
}
|
||||
void visit(context & ctx) override {
|
||||
ctx.visitor(false, this, {
|
||||
{"condition", {condition.get()}},
|
||||
{"true_expr", {true_expr.get()}},
|
||||
{"false_expr", {false_expr.get()}}
|
||||
});
|
||||
}
|
||||
};
|
||||
|
||||
struct raised_exception : public std::exception {
|
||||
@@ -648,6 +773,8 @@ struct runtime {
|
||||
}
|
||||
return parts;
|
||||
}
|
||||
|
||||
static std::string debug_dump_program(const program & prog, const std::string & src);
|
||||
};
|
||||
|
||||
} // namespace jinja
|
||||
|
||||
@@ -1108,6 +1108,50 @@ const func_builtins & value_array_t::get_builtins() const {
|
||||
std::reverse(arr.begin(), arr.end());
|
||||
return is_val<value_tuple>(val) ? mk_val<value_tuple>(std::move(arr)) : mk_val<value_array>(std::move(arr));
|
||||
}},
|
||||
{"min", [](const func_args & args) -> value {
|
||||
args.ensure_count(1, 4);
|
||||
args.ensure_vals<value_array>();
|
||||
value val_case = args.get_kwarg_or_pos("case_sensitive", 1);
|
||||
value attribute = args.get_kwarg_or_pos("attribute", 2);
|
||||
if (!attribute->is_undefined()) {
|
||||
throw not_implemented_exception("min: attribute not implemented");
|
||||
}
|
||||
// FIXME: min is currently always case sensitive
|
||||
(void) val_case;
|
||||
const auto & arr = args.get_pos(0)->as_array();
|
||||
if (arr.empty()) {
|
||||
return mk_val<value_undefined>();
|
||||
}
|
||||
value result = arr[0];
|
||||
for (size_t i = 1; i < arr.size(); ++i) {
|
||||
if (value_compare(arr[i], result, value_compare_op::lt)) {
|
||||
result = arr[i];
|
||||
}
|
||||
}
|
||||
return result;
|
||||
}},
|
||||
{"max", [](const func_args & args) -> value {
|
||||
args.ensure_count(1, 4);
|
||||
args.ensure_vals<value_array>();
|
||||
value val_case = args.get_kwarg_or_pos("case_sensitive", 1);
|
||||
value attribute = args.get_kwarg_or_pos("attribute", 2);
|
||||
if (!attribute->is_undefined()) {
|
||||
throw not_implemented_exception("max: attribute not implemented");
|
||||
}
|
||||
// FIXME: max is currently always case sensitive
|
||||
(void) val_case;
|
||||
const auto & arr = args.get_pos(0)->as_array();
|
||||
if (arr.empty()) {
|
||||
return mk_val<value_undefined>();
|
||||
}
|
||||
value result = arr[0];
|
||||
for (size_t i = 1; i < arr.size(); ++i) {
|
||||
if (value_compare(arr[i], result, value_compare_op::gt)) {
|
||||
result = arr[i];
|
||||
}
|
||||
}
|
||||
return result;
|
||||
}},
|
||||
{"unique", array_unique_not_implemented},
|
||||
};
|
||||
return builtins;
|
||||
|
||||
+29
-4
@@ -7,6 +7,7 @@
|
||||
#include <fstream>
|
||||
#include <sstream>
|
||||
#include <filesystem>
|
||||
#include <regex>
|
||||
|
||||
static std::string rm_leading_dashes(const std::string & str) {
|
||||
size_t pos = 0;
|
||||
@@ -16,6 +17,23 @@ static std::string rm_leading_dashes(const std::string & str) {
|
||||
return str.substr(pos);
|
||||
}
|
||||
|
||||
static std::string canonical_tag(const std::string & tag) {
|
||||
static const std::regex re_tag("[-.]([A-Z0-9_]+)$", std::regex::icase);
|
||||
std::smatch m;
|
||||
if (std::regex_search(tag, m, re_tag)) {
|
||||
std::string canon = m[1].str();
|
||||
for (char & c : canon) {
|
||||
c = (char) std::toupper((unsigned char) c);
|
||||
}
|
||||
return canon;
|
||||
}
|
||||
std::string upper = tag;
|
||||
for (char & c : upper) {
|
||||
c = (char) std::toupper((unsigned char) c);
|
||||
}
|
||||
return upper;
|
||||
}
|
||||
|
||||
std::vector<std::string> common_preset::to_args(const std::string & bin_path) const {
|
||||
std::vector<std::string> args;
|
||||
|
||||
@@ -270,11 +288,18 @@ common_presets common_preset_context::load_from_ini(const std::string & path, co
|
||||
|
||||
for (auto section : ini_data) {
|
||||
common_preset preset;
|
||||
if (section.first.empty()) {
|
||||
preset.name = COMMON_PRESET_DEFAULT_NAME;
|
||||
} else {
|
||||
preset.name = section.first;
|
||||
std::string section_name = section.first.empty() ? std::string(COMMON_PRESET_DEFAULT_NAME) : section.first;
|
||||
if (section_name != "*" && section_name != COMMON_PRESET_DEFAULT_NAME) {
|
||||
auto colon_idx = section_name.rfind(':');
|
||||
if (colon_idx != std::string::npos) {
|
||||
std::string tag = section_name.substr(colon_idx + 1);
|
||||
std::string canon_tag = canonical_tag(tag);
|
||||
if (canon_tag != tag) {
|
||||
section_name = section_name.substr(0, colon_idx + 1) + canon_tag;
|
||||
}
|
||||
}
|
||||
}
|
||||
preset.name = section_name;
|
||||
LOG_DBG("loading preset: %s\n", preset.name.c_str());
|
||||
for (const auto & [key, value] : section.second) {
|
||||
if (key == "version") {
|
||||
|
||||
@@ -1,204 +0,0 @@
|
||||
#include "regex-partial.h"
|
||||
#include "common.h"
|
||||
#include <functional>
|
||||
#include <optional>
|
||||
|
||||
common_regex::common_regex(const std::string & pattern) :
|
||||
pattern(pattern),
|
||||
rx(pattern),
|
||||
rx_reversed_partial(regex_to_reversed_partial_regex(pattern)) {}
|
||||
|
||||
common_regex_match common_regex::search(const std::string & input, size_t pos, bool as_match) const {
|
||||
std::smatch match;
|
||||
if (pos > input.size()) {
|
||||
throw std::runtime_error("Position out of bounds");
|
||||
}
|
||||
auto start = input.begin() + pos;
|
||||
auto found = as_match
|
||||
? std::regex_match(start, input.end(), match, rx)
|
||||
: std::regex_search(start, input.end(), match, rx);
|
||||
if (found) {
|
||||
common_regex_match res;
|
||||
res.type = COMMON_REGEX_MATCH_TYPE_FULL;
|
||||
for (size_t i = 0; i < match.size(); ++i) {
|
||||
auto begin = pos + match.position(i);
|
||||
res.groups.emplace_back(begin, begin + match.length(i));
|
||||
}
|
||||
return res;
|
||||
}
|
||||
std::match_results<std::string::const_reverse_iterator> srmatch;
|
||||
if (std::regex_search(input.rbegin(), input.rend() - pos, srmatch, rx_reversed_partial, std::regex_constants::match_continuous)) {
|
||||
auto group = srmatch[1].str();
|
||||
if (group.length() != 0) {
|
||||
auto it = srmatch[1].second.base();
|
||||
// auto position = static_cast<size_t>(std::distance(input.begin(), it));
|
||||
if ((!as_match) || it == input.begin()) {
|
||||
common_regex_match res;
|
||||
res.type = COMMON_REGEX_MATCH_TYPE_PARTIAL;
|
||||
const size_t begin = std::distance(input.begin(), it);
|
||||
const size_t end = input.size();
|
||||
if (begin == std::string::npos || end == std::string::npos || begin > end) {
|
||||
throw std::runtime_error("Invalid range");
|
||||
}
|
||||
res.groups.push_back({begin, end});
|
||||
return res;
|
||||
}
|
||||
}
|
||||
}
|
||||
return {};
|
||||
}
|
||||
|
||||
/*
|
||||
Transforms a regex pattern to a partial match pattern that operates on a reversed input string to find partial final matches of the original pattern.
|
||||
|
||||
Ideally we'd like to use boost::match_partial (https://beta.boost.org/doc/libs/1_59_0/libs/regex/doc/html/boost_regex/partial_matches.html)
|
||||
to see if a string ends with a partial regex match, but but it's not in std::regex yet.
|
||||
Instead, we'll the regex into a partial match regex operating as a full match on the reverse iterators of the input.
|
||||
|
||||
- /abcd/ -> ^(dcba|cba|ba|a) -> ^((?:(?:(?:(?:d)?c)?b)?a)
|
||||
- /a|b/ -> ^(a|b)
|
||||
- /a*?/ -> error, could match ""
|
||||
- /a*b/ -> ^((?:b)?a*+) (final repetitions become eager)
|
||||
- /.*?ab/ -> ^((?:b)?a) (omit .*)
|
||||
- /a.*?b/ -> ^((?:b)?.*?a) (keep reluctant matches)
|
||||
- /a(bc)d/ -> ^((?:(?:d)?(?:(?:c)?b))?a)
|
||||
- /a(bc|de)/ -> ^((?:(?:(?:e)?d)?|(?:(?:c)?b)?)?a)
|
||||
- /ab{2,4}c/ -> ^cbbb?b?a -> ^((?:(?:(?:(?:(?:c)?b)?b)?b?)?b?)?a)
|
||||
|
||||
The regex will match a reversed string fully, and the end of the first (And only) capturing group will indicate the reversed start of the original partial pattern.
|
||||
All other groups are turned into non-capturing groups, and reluctant quantifiers are ignored.
|
||||
*/
|
||||
std::string regex_to_reversed_partial_regex(const std::string & pattern) {
|
||||
auto it = pattern.begin();
|
||||
const auto end = pattern.end();
|
||||
|
||||
std::function<std::string()> process = [&]() {
|
||||
std::vector<std::vector<std::string>> alternatives(1);
|
||||
std::vector<std::string> * sequence = &alternatives.back();
|
||||
|
||||
while (it != end) {
|
||||
if (*it == '[') {
|
||||
auto start = it;
|
||||
++it;
|
||||
while (it != end) {
|
||||
if ((*it == '\\') && (++it != end)) {
|
||||
++it;
|
||||
} else if ((it != end) && (*it == ']')) {
|
||||
break;
|
||||
} else {
|
||||
++it;
|
||||
}
|
||||
}
|
||||
if (it == end) {
|
||||
throw std::runtime_error("Unmatched '[' in pattern");
|
||||
}
|
||||
++it;
|
||||
sequence->push_back(std::string(start, it));
|
||||
} else if (*it == '*' || *it == '?' || *it == '+') {
|
||||
if (sequence->empty()) {
|
||||
throw std::runtime_error("Quantifier without preceding element");
|
||||
}
|
||||
sequence->back() += *it;
|
||||
auto is_star = *it == '*';
|
||||
++it;
|
||||
if (is_star) {
|
||||
if (it != end && *it == '?') {
|
||||
++it;
|
||||
}
|
||||
}
|
||||
} else if (*it == '{') {
|
||||
if (sequence->empty()) {
|
||||
throw std::runtime_error("Repetition without preceding element");
|
||||
}
|
||||
++it;
|
||||
auto start = it;
|
||||
while (it != end && *it != '}') {
|
||||
++it;
|
||||
}
|
||||
if (it == end) {
|
||||
throw std::runtime_error("Unmatched '{' in pattern");
|
||||
}
|
||||
auto parts = string_split(std::string(start, it), ",");
|
||||
++it;
|
||||
if (parts.size() > 2) {
|
||||
throw std::runtime_error("Invalid repetition range in pattern");
|
||||
}
|
||||
|
||||
auto parseOptInt = [&](const std::string & s, const std::optional<int> & def = std::nullopt) -> std::optional<int> {
|
||||
if (s.empty()) {
|
||||
return def;
|
||||
}
|
||||
return std::stoi(s);
|
||||
};
|
||||
auto min = parseOptInt(parts[0], 0);
|
||||
auto max = parts.size() == 1 ? min : parseOptInt(parts[1]);
|
||||
if (min && max && *max < *min) {
|
||||
throw std::runtime_error("Invalid repetition range in pattern");
|
||||
}
|
||||
// Brutal but... let's repeat at least min times, then ? for the delta between min & max (or * for unbounded)
|
||||
auto part = sequence->back();
|
||||
sequence->pop_back();
|
||||
for (int i = 0; i < *min; i++) {
|
||||
sequence->push_back(part);
|
||||
}
|
||||
if (max) {
|
||||
for (int i = *min; i < *max; i++) {
|
||||
sequence->push_back(part + "?");
|
||||
}
|
||||
} else {
|
||||
sequence->push_back(part + "*");
|
||||
}
|
||||
} else if (*it == '(') {
|
||||
++it;
|
||||
if (it != end && *it == '?' && (it + 1 != end) && *(it + 1) == ':') {
|
||||
it += 2;
|
||||
}
|
||||
auto sub = process();
|
||||
if (*it != ')') {
|
||||
throw std::runtime_error("Unmatched '(' in pattern");
|
||||
}
|
||||
++it;
|
||||
auto & part = sequence->emplace_back("(?:");
|
||||
part += sub;
|
||||
part += ")";
|
||||
} else if (*it == ')') {
|
||||
break;
|
||||
} else if (*it == '|') {
|
||||
++it;
|
||||
alternatives.emplace_back();
|
||||
sequence = &alternatives.back();
|
||||
} else if (*it == '\\' && (++it != end)) {
|
||||
auto str = std::string("\\") + *it;
|
||||
sequence->push_back(str);
|
||||
++it;
|
||||
} else if (it != end) {
|
||||
sequence->push_back(std::string(1, *it));
|
||||
++it;
|
||||
}
|
||||
}
|
||||
|
||||
// /abcd/ -> ^(dcba|cba|ba|a) -> ^((?:(?:(?:d)?c)?b)?a)
|
||||
// if n(=4) parts, opening n-1(=3) non-capturing groups after the 1 capturing group
|
||||
// We'll do the outermost capturing group and final .* in the enclosing function.
|
||||
std::vector<std::string> res_alts;
|
||||
for (const auto & parts : alternatives) {
|
||||
auto & res = res_alts.emplace_back();
|
||||
for (size_t i = 0; i < parts.size() - 1; i++) {
|
||||
res += "(?:";
|
||||
}
|
||||
for (auto it = parts.rbegin(); it != parts.rend(); ++it) {
|
||||
res += *it;
|
||||
if (it != parts.rend() - 1) {
|
||||
res += ")?";
|
||||
}
|
||||
}
|
||||
}
|
||||
return string_join(res_alts, "|");
|
||||
};
|
||||
auto res = process();
|
||||
if (it != end) {
|
||||
throw std::runtime_error("Unmatched '(' in pattern");
|
||||
}
|
||||
|
||||
return "^(" + res + ")";
|
||||
}
|
||||
@@ -1,56 +0,0 @@
|
||||
#pragma once
|
||||
|
||||
#include <regex>
|
||||
#include <string>
|
||||
|
||||
enum common_regex_match_type {
|
||||
COMMON_REGEX_MATCH_TYPE_NONE,
|
||||
COMMON_REGEX_MATCH_TYPE_PARTIAL,
|
||||
COMMON_REGEX_MATCH_TYPE_FULL,
|
||||
};
|
||||
|
||||
struct common_string_range {
|
||||
size_t begin;
|
||||
size_t end;
|
||||
common_string_range(size_t begin, size_t end) : begin(begin), end(end) {
|
||||
if (begin > end) {
|
||||
throw std::runtime_error("Invalid range");
|
||||
}
|
||||
}
|
||||
// prevent default ctor
|
||||
common_string_range() = delete;
|
||||
bool empty() const {
|
||||
return begin == end;
|
||||
}
|
||||
bool operator==(const common_string_range & other) const {
|
||||
return begin == other.begin && end == other.end;
|
||||
}
|
||||
};
|
||||
|
||||
struct common_regex_match {
|
||||
common_regex_match_type type = COMMON_REGEX_MATCH_TYPE_NONE;
|
||||
std::vector<common_string_range> groups;
|
||||
|
||||
bool operator==(const common_regex_match & other) const {
|
||||
return type == other.type && groups == other.groups;
|
||||
}
|
||||
bool operator!=(const common_regex_match & other) const {
|
||||
return !(*this == other);
|
||||
}
|
||||
};
|
||||
|
||||
class common_regex {
|
||||
std::string pattern;
|
||||
std::regex rx;
|
||||
std::regex rx_reversed_partial;
|
||||
|
||||
public:
|
||||
explicit common_regex(const std::string & pattern);
|
||||
|
||||
common_regex_match search(const std::string & input, size_t pos, bool as_match = false) const;
|
||||
|
||||
const std::string & str() const { return pattern; }
|
||||
};
|
||||
|
||||
// For testing only (pretty print of failures).
|
||||
std::string regex_to_reversed_partial_regex(const std::string & pattern);
|
||||
@@ -51,6 +51,7 @@ TEXT_MODEL_MAP: dict[str, str] = {
|
||||
"DeepseekV3ForCausalLM": "deepseek",
|
||||
"DeepseekV32ForCausalLM": "deepseek",
|
||||
"DFlashDraftModel": "qwen",
|
||||
"DeepseekV4ForCausalLM": "deepseek",
|
||||
"DistilBertForMaskedLM": "bert",
|
||||
"DistilBertForSequenceClassification": "bert",
|
||||
"DistilBertModel": "bert",
|
||||
|
||||
+14
-1
@@ -1273,7 +1273,7 @@ class TextModel(ModelBase):
|
||||
if (f_norm_eps := self.find_hparam(["layer_norm_eps", "layer_norm_epsilon", "norm_epsilon"], optional=True)) is not None:
|
||||
self.gguf_writer.add_layer_norm_eps(f_norm_eps)
|
||||
logger.info(f"gguf: layer norm epsilon = {f_norm_eps}")
|
||||
if (n_experts := self.find_hparam(["num_local_experts", "num_experts"], optional=True)) is not None:
|
||||
if (n_experts := self.find_hparam(["num_local_experts", "num_experts", "n_routed_experts"], optional=True)) is not None:
|
||||
self.gguf_writer.add_expert_count(n_experts)
|
||||
logger.info(f"gguf: expert count = {n_experts}")
|
||||
if (n_experts_used := self.find_hparam(["num_experts_per_tok", "num_experts_per_token", "top_k_experts"], optional=True)) is not None:
|
||||
@@ -1291,6 +1291,8 @@ class TextModel(ModelBase):
|
||||
self.gguf_writer.add_expert_gating_func(gguf.ExpertGatingFuncType.SIGMOID)
|
||||
elif score_func == "softmax":
|
||||
self.gguf_writer.add_expert_gating_func(gguf.ExpertGatingFuncType.SOFTMAX)
|
||||
elif score_func == "sqrtsoftplus":
|
||||
self.gguf_writer.add_expert_gating_func(gguf.ExpertGatingFuncType.SQRTSOFTPLUS)
|
||||
else:
|
||||
raise ValueError(f"Unsupported expert score gating function value: {score_func}")
|
||||
logger.info(f"gguf: expert score gating function = {score_func}")
|
||||
@@ -2600,6 +2602,17 @@ class LazyTorchTensor(gguf.LazyBase):
|
||||
return cls._wrap_fn(func)(*args, **kwargs)
|
||||
|
||||
|
||||
if hasattr(torch, "float8_e8m0fnu"):
|
||||
_torch_float8_e8m0 = torch.float8_e8m0fnu
|
||||
LazyTorchTensor._dtype_map[_torch_float8_e8m0] = np.uint8
|
||||
LazyTorchTensor._dtype_byteswap_map[_torch_float8_e8m0] = np.uint8
|
||||
LazyTorchTensor._dtype_str_map["F8_E8M0"] = _torch_float8_e8m0
|
||||
else:
|
||||
# Older torch builds do not expose F8_E8M0. Keep the raw bytes so callers
|
||||
# that know the format can decode them explicitly.
|
||||
LazyTorchTensor._dtype_str_map["F8_E8M0"] = torch.uint8
|
||||
|
||||
|
||||
def get_model_architecture(hparams: dict[str, Any], model_type: ModelType) -> str:
|
||||
# TODO @ngxson : this won't work correctly if the model has both audio & vision encoders
|
||||
# maybe we should fallback to text model's arch in that case, since not many models have both
|
||||
|
||||
+308
-1
@@ -1,15 +1,18 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import json
|
||||
import re
|
||||
from pathlib import Path
|
||||
|
||||
from typing import Any, Callable, Iterable, TYPE_CHECKING
|
||||
|
||||
import numpy as np
|
||||
import torch
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from torch import Tensor
|
||||
|
||||
from .base import MmprojModel, ModelBase, TextModel, gguf, logger
|
||||
from .base import LazyTorchTensor, MmprojModel, ModelBase, TextModel, gguf, logger
|
||||
|
||||
from .qwen import QwenModel
|
||||
|
||||
@@ -467,3 +470,307 @@ class DeepseekV32Model(DeepseekV2Model):
|
||||
self.gguf_writer.add_indexer_head_count(self.hparams["index_n_heads"])
|
||||
self.gguf_writer.add_indexer_key_length(self.hparams["index_head_dim"])
|
||||
self.gguf_writer.add_indexer_top_k(self.hparams["index_topk"])
|
||||
|
||||
|
||||
@ModelBase.register("DeepseekV4ForCausalLM")
|
||||
class DeepseekV4Model(TextModel):
|
||||
model_arch = gguf.MODEL_ARCH.DEEPSEEK4
|
||||
_skipped_mtp_tensors = 0
|
||||
|
||||
def __init__(self, *args, **kwargs):
|
||||
type(self)._skipped_mtp_tensors = 0
|
||||
super().__init__(*args, **kwargs)
|
||||
|
||||
with open(self.dir_model / "config.json", "r", encoding="utf-8") as f:
|
||||
raw_hparams = json.load(f)
|
||||
for key, value in raw_hparams.items():
|
||||
self.hparams.setdefault(key, value)
|
||||
|
||||
self.block_count = self.hparams["num_hidden_layers"]
|
||||
self.tensor_map = gguf.get_tensor_name_map(self.model_arch, self.block_count)
|
||||
|
||||
self._dsv4_fp8_dequantized: set[str] = set()
|
||||
self._dsv4_bf16_tensors: set[str] = set()
|
||||
self._dsv4_f32_tensors: set[str] = set()
|
||||
self._dsv4_mxfp4_generated = False
|
||||
self._collect_source_dtypes()
|
||||
|
||||
if type(self)._skipped_mtp_tensors:
|
||||
logger.info("Skipping %d DeepSeek-V4 MTP tensor(s) for conversion v0", type(self)._skipped_mtp_tensors)
|
||||
|
||||
# add a default chat template; if the model has a built-in template, it will be overridden later
|
||||
template_path = Path(__file__).parent.parent / "models" / "templates" / "deepseek-ai-DeepSeek-V4.jinja"
|
||||
if template_path.is_file():
|
||||
with open(template_path, "r", encoding="utf-8") as f:
|
||||
self.gguf_writer.add_chat_template(f.read())
|
||||
|
||||
@classmethod
|
||||
def filter_tensors(cls, item: tuple[str, Callable[[], Tensor]]) -> tuple[str, Callable[[], Tensor]] | None:
|
||||
name, _ = item
|
||||
if name.startswith("mtp."):
|
||||
cls._skipped_mtp_tensors += 1
|
||||
return None
|
||||
return super().filter_tensors(item)
|
||||
|
||||
@staticmethod
|
||||
def _float8_dtypes() -> tuple[torch.dtype, ...]:
|
||||
return tuple(
|
||||
dtype for dtype in (
|
||||
getattr(torch, "float8_e4m3fn", None),
|
||||
getattr(torch, "float8_e5m2", None),
|
||||
) if dtype is not None
|
||||
)
|
||||
|
||||
@staticmethod
|
||||
def _e8m0_to_float(scale: Tensor) -> Tensor:
|
||||
torch_float8_e8m0 = getattr(torch, "float8_e8m0fnu", None)
|
||||
if torch_float8_e8m0 is not None and scale.dtype == torch_float8_e8m0:
|
||||
return scale.float()
|
||||
|
||||
bits = scale.view(torch.uint8).float()
|
||||
return torch.exp2(bits - 127.0)
|
||||
|
||||
def _collect_source_dtypes(self) -> None:
|
||||
for name, gen in self.model_tensors.items():
|
||||
dtype = gen().dtype
|
||||
if dtype == torch.bfloat16:
|
||||
self._dsv4_bf16_tensors.add(name)
|
||||
elif dtype == torch.float32:
|
||||
self._dsv4_f32_tensors.add(name)
|
||||
|
||||
def set_gguf_parameters(self):
|
||||
super().set_gguf_parameters()
|
||||
hparams = self.hparams
|
||||
|
||||
self.gguf_writer.add_rope_dimension_count(hparams["qk_rope_head_dim"])
|
||||
self.gguf_writer.add_q_lora_rank(hparams["q_lora_rank"])
|
||||
self.gguf_writer.add_sliding_window(hparams["sliding_window"])
|
||||
|
||||
self.gguf_writer.add_expert_feed_forward_length(hparams["moe_intermediate_size"])
|
||||
self.gguf_writer.add_expert_shared_count(hparams["n_shared_experts"])
|
||||
self.gguf_writer.add_expert_weights_scale(hparams["routed_scaling_factor"])
|
||||
self.gguf_writer.add_expert_weights_norm(hparams["norm_topk_prob"])
|
||||
self.gguf_writer.add_swiglu_clamp_exp([hparams["swiglu_limit"]] * self.block_count)
|
||||
self.gguf_writer.add_swiglu_clamp_shexp([hparams["swiglu_limit"]] * self.block_count)
|
||||
|
||||
self.gguf_writer.add_indexer_head_count(hparams["index_n_heads"])
|
||||
self.gguf_writer.add_indexer_key_length(hparams["index_head_dim"])
|
||||
self.gguf_writer.add_indexer_top_k(hparams["index_topk"])
|
||||
|
||||
self.gguf_writer.add_attention_output_group_count(hparams["o_groups"])
|
||||
self.gguf_writer.add_attention_output_lora_rank(hparams["o_lora_rank"])
|
||||
self.gguf_writer.add_attention_compress_ratios(hparams["compress_ratios"])
|
||||
self.gguf_writer.add_attention_compress_rope_freq_base(hparams["compress_rope_theta"])
|
||||
self.gguf_writer.add_hyper_connection_count(hparams["hc_mult"])
|
||||
self.gguf_writer.add_hyper_connection_sinkhorn_iterations(hparams["hc_sinkhorn_iters"])
|
||||
self.gguf_writer.add_hyper_connection_epsilon(hparams["hc_eps"])
|
||||
self.gguf_writer.add_hash_layer_count(hparams["num_hash_layers"])
|
||||
|
||||
def dequant_model(self):
|
||||
fp8_dtypes = self._float8_dtypes()
|
||||
tensors_to_remove: list[str] = []
|
||||
|
||||
def dequant_fp8_weight(weight: Tensor, scale: Tensor) -> Tensor:
|
||||
out_features, in_features = weight.shape
|
||||
scale_f = self._e8m0_to_float(scale)
|
||||
scale_f = scale_f.repeat_interleave(128, 0)[:out_features]
|
||||
scale_f = scale_f.repeat_interleave(128, 1)[:, :in_features]
|
||||
return weight.float() * scale_f
|
||||
|
||||
for name in list(self.model_tensors.keys()):
|
||||
if not name.endswith(".scale"):
|
||||
continue
|
||||
weight_name = name.removesuffix(".scale") + ".weight"
|
||||
if weight_name not in self.model_tensors:
|
||||
continue
|
||||
|
||||
weight = self.model_tensors[weight_name]
|
||||
scale = self.model_tensors[name]
|
||||
if weight().dtype not in fp8_dtypes:
|
||||
continue
|
||||
|
||||
self.model_tensors[weight_name] = lambda w=weight, s=scale: dequant_fp8_weight(w(), s())
|
||||
self._dsv4_fp8_dequantized.add(weight_name)
|
||||
tensors_to_remove.append(name)
|
||||
|
||||
for name in tensors_to_remove:
|
||||
del self.model_tensors[name]
|
||||
|
||||
@staticmethod
|
||||
def _pack_mxfp4_blocks(weight: Tensor, scale: Tensor) -> np.ndarray:
|
||||
packed = weight.contiguous().view(torch.uint8)
|
||||
scale_u8 = scale.contiguous().view(torch.uint8)
|
||||
|
||||
out_features, packed_cols = packed.shape
|
||||
logical_cols = packed_cols * 2
|
||||
if logical_cols % 32 != 0:
|
||||
raise ValueError(f"MXFP4 source row has {logical_cols} values, expected a multiple of 32")
|
||||
|
||||
n_blocks = logical_cols // 32
|
||||
if tuple(scale_u8.shape) != (out_features, n_blocks):
|
||||
raise ValueError(f"MXFP4 scale shape {tuple(scale_u8.shape)} does not match {(out_features, n_blocks)}")
|
||||
|
||||
src = packed.reshape(out_features, n_blocks, 16)
|
||||
low = src & 0x0F
|
||||
high = (src >> 4) & 0x0F
|
||||
|
||||
# The safetensors bytes store adjacent values as low/high nibbles.
|
||||
# ggml MXFP4 blocks store values 0..15 in low nibbles and 16..31 in high nibbles.
|
||||
vals = torch.stack((low, high), dim=-1).reshape(out_features, n_blocks, 32)
|
||||
qs = vals[:, :, :16] | (vals[:, :, 16:] << 4)
|
||||
raw = torch.cat((scale_u8.unsqueeze(-1), qs.to(torch.uint8)), dim=-1)
|
||||
return raw.reshape(out_features, n_blocks * 17).cpu().numpy()
|
||||
|
||||
def _write_mxfp4_expert_tensor(self, bid: int, proj: str, tensor_key: gguf.MODEL_TENSOR) -> list[str]:
|
||||
n_experts = self.hparams["n_routed_experts"]
|
||||
data: np.ndarray | None = None
|
||||
consumed: list[str] = []
|
||||
|
||||
for eid in range(n_experts):
|
||||
weight_name = f"layers.{bid}.ffn.experts.{eid}.{proj}.weight"
|
||||
scale_name = f"layers.{bid}.ffn.experts.{eid}.{proj}.scale"
|
||||
if weight_name not in self.model_tensors or scale_name not in self.model_tensors:
|
||||
raise KeyError(f"Missing routed expert tensors for {weight_name}")
|
||||
|
||||
weight = LazyTorchTensor.to_eager(self.model_tensors[weight_name]())
|
||||
scale = LazyTorchTensor.to_eager(self.model_tensors[scale_name]())
|
||||
packed = self._pack_mxfp4_blocks(weight, scale)
|
||||
if data is None:
|
||||
data = np.empty((n_experts, *packed.shape), dtype=packed.dtype)
|
||||
data[eid] = packed
|
||||
consumed.extend((weight_name, scale_name))
|
||||
|
||||
assert data is not None
|
||||
new_name = self.format_tensor_name(tensor_key, bid)
|
||||
shape = gguf.quant_shape_from_byte_shape(data.shape, gguf.GGMLQuantizationType.MXFP4)
|
||||
logger.info(f"{new_name}: repacked routed experts to MXFP4, shape = {{{', '.join(str(n) for n in reversed(shape))}}}")
|
||||
self.gguf_writer.add_tensor(new_name, data, raw_dtype=gguf.GGMLQuantizationType.MXFP4)
|
||||
|
||||
return consumed
|
||||
|
||||
def _write_hash_routing_tensors(self) -> list[str]:
|
||||
consumed: list[str] = []
|
||||
|
||||
for bid in range(self.hparams["num_hash_layers"]):
|
||||
name = f"layers.{bid}.ffn.gate.tid2eid"
|
||||
if name not in self.model_tensors:
|
||||
raise KeyError(f"Missing hash routing tensor {name}")
|
||||
|
||||
data_torch = LazyTorchTensor.to_eager(self.model_tensors[name]())
|
||||
data = data_torch.to(torch.int32).cpu().numpy()
|
||||
new_name = self.format_tensor_name(gguf.MODEL_TENSOR.FFN_GATE_TID2EID, bid, ".weight")
|
||||
logger.info(f"{new_name}: converted hash routing table to I32, shape = {{{', '.join(str(n) for n in reversed(data.shape))}}}")
|
||||
self.gguf_writer.add_tensor(new_name, data)
|
||||
consumed.append(name)
|
||||
|
||||
return consumed
|
||||
|
||||
def generate_extra_tensors(self) -> Iterable[tuple[str, Tensor]]:
|
||||
if self._dsv4_mxfp4_generated:
|
||||
return ()
|
||||
|
||||
consumed: list[str] = self._write_hash_routing_tensors()
|
||||
for bid in range(self.block_count):
|
||||
consumed.extend(self._write_mxfp4_expert_tensor(bid, "w1", gguf.MODEL_TENSOR.FFN_GATE_EXP))
|
||||
consumed.extend(self._write_mxfp4_expert_tensor(bid, "w2", gguf.MODEL_TENSOR.FFN_DOWN_EXP))
|
||||
consumed.extend(self._write_mxfp4_expert_tensor(bid, "w3", gguf.MODEL_TENSOR.FFN_UP_EXP))
|
||||
|
||||
for name in consumed:
|
||||
del self.model_tensors[name]
|
||||
|
||||
self._dsv4_mxfp4_generated = True
|
||||
return ()
|
||||
|
||||
def _format_dsv4_tensor_name(self, key: gguf.MODEL_TENSOR, bid: int | None, suffix: str = ".weight") -> str:
|
||||
return self.format_tensor_name(key, bid, suffix)
|
||||
|
||||
def _map_dsv4_tensor_name(self, name: str, bid: int | None) -> tuple[gguf.MODEL_TENSOR, str]:
|
||||
root_map: dict[str, tuple[gguf.MODEL_TENSOR, str]] = {
|
||||
"embed.weight": (gguf.MODEL_TENSOR.TOKEN_EMBD, ".weight"),
|
||||
"norm.weight": (gguf.MODEL_TENSOR.OUTPUT_NORM, ".weight"),
|
||||
"head.weight": (gguf.MODEL_TENSOR.OUTPUT, ".weight"),
|
||||
"hc_head_fn": (gguf.MODEL_TENSOR.HC_HEAD_FN, ".weight"),
|
||||
"hc_head_base": (gguf.MODEL_TENSOR.HC_HEAD_BASE, ".weight"),
|
||||
"hc_head_scale": (gguf.MODEL_TENSOR.HC_HEAD_SCALE, ".weight"),
|
||||
}
|
||||
if name in root_map:
|
||||
return root_map[name]
|
||||
|
||||
match = re.match(r"layers\.(\d+)\.(.+)$", name)
|
||||
if match is None:
|
||||
raise ValueError(f"Unsupported DeepSeek-V4 tensor {name!r}")
|
||||
|
||||
layer = int(match.group(1))
|
||||
if bid != layer:
|
||||
raise ValueError(f"Tensor {name!r} parsed bid {bid} but layer name has {layer}")
|
||||
|
||||
layer_map: dict[str, tuple[gguf.MODEL_TENSOR, str]] = {
|
||||
"hc_attn_fn": (gguf.MODEL_TENSOR.HC_ATTN_FN, ".weight"),
|
||||
"hc_attn_base": (gguf.MODEL_TENSOR.HC_ATTN_BASE, ".weight"),
|
||||
"hc_attn_scale": (gguf.MODEL_TENSOR.HC_ATTN_SCALE, ".weight"),
|
||||
"hc_ffn_fn": (gguf.MODEL_TENSOR.HC_FFN_FN, ".weight"),
|
||||
"hc_ffn_base": (gguf.MODEL_TENSOR.HC_FFN_BASE, ".weight"),
|
||||
"hc_ffn_scale": (gguf.MODEL_TENSOR.HC_FFN_SCALE, ".weight"),
|
||||
"attn.attn_sink": (gguf.MODEL_TENSOR.ATTN_SINKS, ".weight"),
|
||||
"attn.wq_a.weight": (gguf.MODEL_TENSOR.ATTN_Q_A, ".weight"),
|
||||
"attn.wq_b.weight": (gguf.MODEL_TENSOR.ATTN_Q_B, ".weight"),
|
||||
"attn.q_norm.weight": (gguf.MODEL_TENSOR.ATTN_Q_A_NORM, ".weight"),
|
||||
"attn.wkv.weight": (gguf.MODEL_TENSOR.ATTN_KV, ".weight"),
|
||||
"attn.kv_norm.weight": (gguf.MODEL_TENSOR.ATTN_KV_NORM, ".weight"),
|
||||
"attn.wo_a.weight": (gguf.MODEL_TENSOR.ATTN_OUT_A, ".weight"),
|
||||
"attn.wo_b.weight": (gguf.MODEL_TENSOR.ATTN_OUT_B, ".weight"),
|
||||
"attn.compressor.ape": (gguf.MODEL_TENSOR.ATTN_COMPRESSOR_APE, ".weight"),
|
||||
"attn.compressor.wkv.weight": (gguf.MODEL_TENSOR.ATTN_COMPRESSOR_WKV, ".weight"),
|
||||
"attn.compressor.wgate.weight": (gguf.MODEL_TENSOR.ATTN_COMPRESSOR_WGATE, ".weight"),
|
||||
"attn.compressor.norm.weight": (gguf.MODEL_TENSOR.ATTN_COMPRESSOR_NORM, ".weight"),
|
||||
"attn.indexer.wq_b.weight": (gguf.MODEL_TENSOR.INDEXER_ATTN_Q_B, ".weight"),
|
||||
"attn.indexer.weights_proj.weight": (gguf.MODEL_TENSOR.INDEXER_PROJ, ".weight"),
|
||||
"attn.indexer.compressor.ape": (gguf.MODEL_TENSOR.INDEXER_COMPRESSOR_APE, ".weight"),
|
||||
"attn.indexer.compressor.wkv.weight": (gguf.MODEL_TENSOR.INDEXER_COMPRESSOR_WKV, ".weight"),
|
||||
"attn.indexer.compressor.wgate.weight": (gguf.MODEL_TENSOR.INDEXER_COMPRESSOR_WGATE, ".weight"),
|
||||
"attn.indexer.compressor.norm.weight": (gguf.MODEL_TENSOR.INDEXER_COMPRESSOR_NORM, ".weight"),
|
||||
"attn_norm.weight": (gguf.MODEL_TENSOR.ATTN_NORM, ".weight"),
|
||||
"ffn_norm.weight": (gguf.MODEL_TENSOR.FFN_NORM, ".weight"),
|
||||
"ffn.gate.weight": (gguf.MODEL_TENSOR.FFN_GATE_INP, ".weight"),
|
||||
"ffn.gate.bias": (gguf.MODEL_TENSOR.FFN_EXP_PROBS_B, ".bias"),
|
||||
"ffn.gate.tid2eid": (gguf.MODEL_TENSOR.FFN_GATE_TID2EID, ".weight"),
|
||||
"ffn.shared_experts.w1.weight": (gguf.MODEL_TENSOR.FFN_GATE_SHEXP, ".weight"),
|
||||
"ffn.shared_experts.w2.weight": (gguf.MODEL_TENSOR.FFN_DOWN_SHEXP, ".weight"),
|
||||
"ffn.shared_experts.w3.weight": (gguf.MODEL_TENSOR.FFN_UP_SHEXP, ".weight"),
|
||||
}
|
||||
|
||||
tensor_name = match.group(2)
|
||||
if tensor_name in layer_map:
|
||||
return layer_map[tensor_name]
|
||||
|
||||
if re.match(r"ffn\.experts\.\d+\.w[123]\.(weight|scale)$", tensor_name):
|
||||
return gguf.MODEL_TENSOR.FFN_GATE_EXP, ".weight"
|
||||
|
||||
raise ValueError(f"Unsupported DeepSeek-V4 tensor {name!r}")
|
||||
|
||||
def modify_tensors(self, data_torch: Tensor, name: str, bid: int | None) -> Iterable[tuple[str, Tensor]]:
|
||||
if re.match(r"layers\.\d+\.ffn\.experts\.\d+\.w[123]\.(weight|scale)$", name):
|
||||
return []
|
||||
|
||||
tensor_key, suffix = self._map_dsv4_tensor_name(name, bid)
|
||||
if tensor_key == gguf.MODEL_TENSOR.FFN_GATE_TID2EID:
|
||||
return []
|
||||
|
||||
return [(self._format_dsv4_tensor_name(tensor_key, bid, suffix), data_torch)]
|
||||
|
||||
def tensor_force_quant(self, name: str, new_name: str, bid: int | None, n_dims: int) -> gguf.GGMLQuantizationType | bool:
|
||||
del new_name, bid # unused
|
||||
|
||||
if name in self._dsv4_fp8_dequantized and n_dims >= 2:
|
||||
return gguf.GGMLQuantizationType.Q8_0
|
||||
if name in self._dsv4_f32_tensors:
|
||||
return gguf.GGMLQuantizationType.F32
|
||||
if name in self._dsv4_bf16_tensors and n_dims >= 2:
|
||||
return gguf.GGMLQuantizationType.BF16
|
||||
|
||||
return False
|
||||
|
||||
def prepare_tensors(self):
|
||||
super().prepare_tensors()
|
||||
self._is_mxfp4 = True
|
||||
self.ftype = gguf.LlamaFileType.MOSTLY_MXFP4_MOE
|
||||
|
||||
+3
-3
@@ -73,7 +73,7 @@ class LlamaModel(TextModel):
|
||||
target_num_layers = target_config["num_hidden_layers"]
|
||||
target_layers = [2, target_num_layers // 2, target_num_layers - 3]
|
||||
logger.info(f"EAGLE-3: target_layers = {target_layers} (target model has {target_num_layers} layers)")
|
||||
self.gguf_writer.add_array(f"{self.gguf_writer.arch}.target_layers", target_layers)
|
||||
self.gguf_writer.add_target_layers(target_layers)
|
||||
|
||||
# target_hidden_size: prefer eagle3 config, fallback to target config
|
||||
if eagle3_raw_config.get("target_hidden_size") is not None:
|
||||
@@ -83,12 +83,12 @@ class LlamaModel(TextModel):
|
||||
target_hidden_size = target_config["hidden_size"]
|
||||
src = "target model config"
|
||||
logger.info(f"EAGLE-3: target_hidden_size = {target_hidden_size} (from {src})")
|
||||
self.gguf_writer.add_uint32(f"{self.gguf_writer.arch}.target_hidden_size", target_hidden_size)
|
||||
self.gguf_writer.add_target_hidden_size(target_hidden_size)
|
||||
|
||||
# norm_before_residual (RedHat-style eagle3 specific)
|
||||
norm_before_residual = eagle3_raw_config.get("norm_before_residual", False)
|
||||
logger.info(f"EAGLE-3: norm_before_residual = {norm_before_residual}")
|
||||
self.gguf_writer.add_bool(f"{self.gguf_writer.arch}.norm_before_residual", norm_before_residual)
|
||||
self.gguf_writer.add_norm_before_residual(norm_before_residual)
|
||||
|
||||
def set_vocab(self):
|
||||
# eagle3: use tokenizer from target model if provided
|
||||
|
||||
+10
-14
@@ -643,21 +643,21 @@ class DFlashModel(Qwen3Model):
|
||||
super().set_vocab()
|
||||
self.dir_model = original_dir
|
||||
|
||||
mask_token_id = self.hparams.get("dflash_config", {}).get("mask_token_id")
|
||||
if mask_token_id is not None:
|
||||
self.gguf_writer.add_mask_token_id(mask_token_id)
|
||||
|
||||
def set_gguf_parameters(self):
|
||||
super().set_gguf_parameters()
|
||||
|
||||
block_size = self.hparams.get("block_size", 16)
|
||||
self.gguf_writer.add_uint32(f"{self.gguf_writer.arch}.block_size", block_size)
|
||||
self.gguf_writer.add_block_size(block_size)
|
||||
dflash_config = self.hparams.get("dflash_config", {})
|
||||
|
||||
target_layer_ids = dflash_config.get("target_layer_ids", [])
|
||||
if target_layer_ids:
|
||||
extract_layer_ids = [i + 1 for i in target_layer_ids]
|
||||
self.gguf_writer.add_array(f"{self.gguf_writer.arch}.target_layers", extract_layer_ids)
|
||||
|
||||
mask_token_id = dflash_config.get("mask_token_id", None)
|
||||
if mask_token_id is not None:
|
||||
self.gguf_writer.add_mask_token_id(mask_token_id)
|
||||
self.gguf_writer.add_target_layers(extract_layer_ids)
|
||||
|
||||
use_sliding_window = self.hparams.get("use_sliding_window", False)
|
||||
sliding_window = self.hparams.get("sliding_window")
|
||||
@@ -667,13 +667,9 @@ class DFlashModel(Qwen3Model):
|
||||
self.gguf_writer.add_sliding_window(sliding_window)
|
||||
self.gguf_writer.add_sliding_window_pattern(is_swa)
|
||||
|
||||
def modify_tensors(self, data_torch: Tensor, name: str, bid: int | None) -> Iterable[tuple[str, Tensor]]:
|
||||
if name == "fc.weight":
|
||||
yield (name, data_torch)
|
||||
return
|
||||
if name == "hidden_norm.weight":
|
||||
yield (self.format_tensor_name(gguf.MODEL_TENSOR.ENC_OUTPUT_NORM), data_torch)
|
||||
return
|
||||
@classmethod
|
||||
def filter_tensors(cls, item: tuple[str, Callable[[], Tensor]]) -> tuple[str, Callable[[], Tensor]] | None:
|
||||
name, gen = item
|
||||
if not name.startswith("model."):
|
||||
name = "model." + name
|
||||
yield from super().modify_tensors(data_torch, name, bid)
|
||||
return super().filter_tensors((name, gen))
|
||||
|
||||
@@ -1551,8 +1551,6 @@ static enum ggml_status ggml_backend_sched_compute_splits(ggml_backend_sched_t s
|
||||
int split_backend_id = split->backend_id;
|
||||
ggml_backend_t split_backend = sched->backends[split_backend_id];
|
||||
|
||||
ggml_backend_synchronize(split_backend);
|
||||
|
||||
// copy the input tensors to the split backend
|
||||
for (int input_id = 0; input_id < split->n_inputs; input_id++) {
|
||||
ggml_backend_t input_backend = ggml_backend_sched_get_tensor_backend(sched, split->inputs[input_id]);
|
||||
@@ -1563,15 +1561,15 @@ static enum ggml_status ggml_backend_sched_compute_splits(ggml_backend_sched_t s
|
||||
// inputs from the user must be copied immediately to prevent the user overwriting the data before the copy is done
|
||||
if (sched->events[split_backend_id][sched->cur_copy] != NULL) {
|
||||
ggml_backend_event_synchronize(sched->events[split_backend_id][sched->cur_copy]);
|
||||
} else if (!split_backend->iface.cpy_tensor_async) {
|
||||
} else {
|
||||
ggml_backend_synchronize(split_backend);
|
||||
}
|
||||
ggml_backend_tensor_copy_async(input_backend, split_backend, input, input_cpy);
|
||||
ggml_backend_tensor_copy(input, input_cpy);
|
||||
} else {
|
||||
// wait for the split backend to finish using the input before overwriting it
|
||||
if (sched->events[split_backend_id][sched->cur_copy] != NULL) {
|
||||
ggml_backend_event_wait(split_backend, sched->events[split_backend_id][sched->cur_copy]);
|
||||
} else if (!split_backend->iface.cpy_tensor_async) {
|
||||
} else {
|
||||
ggml_backend_synchronize(split_backend);
|
||||
}
|
||||
|
||||
@@ -1676,8 +1674,6 @@ static enum ggml_status ggml_backend_sched_compute_splits(ggml_backend_sched_t s
|
||||
}
|
||||
}
|
||||
|
||||
ggml_backend_synchronize(split_backend);
|
||||
|
||||
if (!sched->callback_eval) {
|
||||
enum ggml_status ec = ggml_backend_graph_compute_async(split_backend, &split->graph);
|
||||
if (ec != GGML_STATUS_SUCCESS) {
|
||||
|
||||
@@ -3192,24 +3192,11 @@ static bool ggml_backend_cuda_cpy_tensor_async(ggml_backend_t backend_src, ggml_
|
||||
ggml_backend_buffer_t buf_src = src->view_src ? src->view_src->buffer : src->buffer;
|
||||
ggml_backend_buffer_t buf_dst = dst->view_src ? dst->view_src->buffer : dst->buffer;
|
||||
|
||||
// Enables async copies from CPU to CUDA, instead of only CUDA-to-CUDA
|
||||
// Excluding this path for HIP and MUSA as a precaution.
|
||||
// According to the summary in https://github.com/ggml-org/llama.cpp/pull/20793#issuecomment-4275794315, this change is not beneficial for hip anyways.
|
||||
// Additionally, there is a lot of anectodal evidence that hip/musa stream behavior might not always 1:1 match CUDA behavior.
|
||||
// e.g. https://github.com/ROCm/rocm-systems/issues/5109
|
||||
// It thus makes sense to exclude this path for HIP and MUSA. This PR was not aimed these backends, the majority of testing happened on CUDA.
|
||||
// This can be revisited in the future if enabling copy_from_host benefits hip/MUSA, and if the PR author can extensively test on these backends.
|
||||
#if defined(GGML_USE_HIP) || defined(GGML_USE_MUSA)
|
||||
const bool copy_from_host = false;
|
||||
#else
|
||||
const bool copy_from_host = ggml_backend_buffer_is_host(buf_src) && ggml_backend_dev_type(backend_src->device) == GGML_BACKEND_DEVICE_TYPE_CPU;
|
||||
#endif
|
||||
|
||||
if (!(copy_from_host || ggml_backend_is_cuda(backend_src)) || !ggml_backend_is_cuda(backend_dst)) {
|
||||
if (!ggml_backend_is_cuda(backend_src) || !ggml_backend_is_cuda(backend_dst)) {
|
||||
return false;
|
||||
}
|
||||
|
||||
if (!(copy_from_host || ggml_backend_buffer_is_cuda(buf_src)) || !ggml_backend_buffer_is_cuda(buf_dst)) {
|
||||
if (!ggml_backend_buffer_is_cuda(buf_src) || !ggml_backend_buffer_is_cuda(buf_dst)) {
|
||||
return false;
|
||||
}
|
||||
|
||||
@@ -3220,17 +3207,14 @@ static bool ggml_backend_cuda_cpy_tensor_async(ggml_backend_t backend_src, ggml_
|
||||
ggml_backend_cuda_buffer_context * buf_ctx_src = (ggml_backend_cuda_buffer_context *) buf_src->context;
|
||||
ggml_backend_cuda_buffer_context * buf_ctx_dst = (ggml_backend_cuda_buffer_context *) buf_dst->context;
|
||||
|
||||
if ((copy_from_host && cuda_ctx_dst->device != buf_ctx_dst->device) ||
|
||||
!copy_from_host && (cuda_ctx_src->device != buf_ctx_src->device || cuda_ctx_dst->device != buf_ctx_dst->device)) {
|
||||
if (cuda_ctx_src->device != buf_ctx_src->device || cuda_ctx_dst->device != buf_ctx_dst->device) {
|
||||
#ifndef NDEBUG
|
||||
GGML_LOG_DEBUG("%s: backend and buffer devices do not match\n", __func__);
|
||||
#endif // NDEBUG
|
||||
return false;
|
||||
}
|
||||
|
||||
if (copy_from_host) {
|
||||
CUDA_CHECK(cudaMemcpyAsync(dst->data, src->data, ggml_nbytes(dst), cudaMemcpyHostToDevice, cuda_ctx_dst->stream()));
|
||||
} else if (backend_src != backend_dst) {
|
||||
if (backend_src != backend_dst) {
|
||||
// copy on src stream
|
||||
if (cuda_ctx_src->device == cuda_ctx_dst->device) {
|
||||
CUDA_CHECK(cudaMemcpyAsync(dst->data, src->data, ggml_nbytes(dst), cudaMemcpyDeviceToDevice, cuda_ctx_src->stream()));
|
||||
|
||||
@@ -1907,6 +1907,38 @@ static bool vk_enable_sync_logger = false;
|
||||
static uint32_t vk_perf_logger_frequency = 1;
|
||||
static std::string vk_pipeline_stats_filter;
|
||||
|
||||
static uint64_t ggml_vk_get_node_flops(const ggml_tensor * node) {
|
||||
if (node->op == GGML_OP_MUL_MAT || node->op == GGML_OP_MUL_MAT_ID) {
|
||||
const uint64_t m = node->ne[0];
|
||||
const uint64_t n = node->ne[1];
|
||||
const uint64_t k = node->src[1]->ne[0];
|
||||
const uint64_t batch = node->ne[2] * node->ne[3];
|
||||
return m * n * (k + (k - 1)) * batch;
|
||||
}
|
||||
if (node->op == GGML_OP_CONV_2D || node->op == GGML_OP_CONV_TRANSPOSE_2D) {
|
||||
const ggml_tensor * knl = node->src[0];
|
||||
const uint64_t Cout = node->ne[2];
|
||||
const uint64_t size_K = node->src[1]->ne[2] * knl->ne[0] * knl->ne[1];
|
||||
const uint64_t size_N = node->ne[3] * node->ne[0] * node->ne[1];
|
||||
return Cout * size_N * (size_K + (size_K - 1));
|
||||
}
|
||||
if (node->op == GGML_OP_CONV_3D) {
|
||||
const ggml_tensor * knl = node->src[0];
|
||||
const uint64_t OC = ggml_get_op_params_i32(node, 11);
|
||||
const uint64_t IC = ggml_get_op_params_i32(node, 9);
|
||||
const uint64_t size_K = IC * knl->ne[0] * knl->ne[1] * knl->ne[2];
|
||||
const uint64_t size_N = node->ne[3] / OC * node->ne[0] * node->ne[1] * node->ne[2];
|
||||
return OC * size_N * (size_K + (size_K - 1));
|
||||
}
|
||||
if (node->op == GGML_OP_FLASH_ATTN_EXT) {
|
||||
const ggml_tensor * q = node->src[0];
|
||||
const ggml_tensor * k = node->src[1];
|
||||
const ggml_tensor * v = node->src[2];
|
||||
return 2ull * q->ne[1] * q->ne[2] * (k->ne[0] + v->ne[0]) * k->ne[1] * q->ne[3];
|
||||
}
|
||||
return 0;
|
||||
}
|
||||
|
||||
class vk_perf_logger {
|
||||
public:
|
||||
void print_timings(bool force = false) {
|
||||
@@ -1955,7 +1987,7 @@ class vk_perf_logger {
|
||||
}
|
||||
|
||||
std::string get_node_fusion_name(const ggml_tensor * node, const char *fusion_name, uint64_t *n_flops) {
|
||||
*n_flops = 0;
|
||||
*n_flops = ggml_vk_get_node_flops(node);
|
||||
std::string fusion_str;
|
||||
if (fusion_name) {
|
||||
fusion_str = fusion_name + std::string(" ");
|
||||
@@ -1982,35 +2014,22 @@ class vk_perf_logger {
|
||||
if (batch > 1) {
|
||||
name += " batch=" + std::to_string(batch);
|
||||
}
|
||||
name = fusion_str + name;
|
||||
*n_flops = m * n * (k + (k - 1)) * batch;
|
||||
return name;
|
||||
return fusion_str + name;
|
||||
}
|
||||
if (node->op == GGML_OP_CONV_2D || node->op == GGML_OP_CONV_TRANSPOSE_2D) {
|
||||
std::string name = ggml_op_name(node->op);
|
||||
ggml_tensor * knl = node->src[0];
|
||||
uint64_t OW = node->ne[0];
|
||||
uint64_t OH = node->ne[1];
|
||||
uint64_t N = node->ne[3];
|
||||
const ggml_tensor * knl = node->src[0];
|
||||
uint64_t Cout = node->ne[2];
|
||||
uint64_t KW = knl->ne[0];
|
||||
uint64_t KH = knl->ne[1];
|
||||
uint64_t Cin = node->src[1]->ne[2];
|
||||
// KxCRS @ CRSxNPQ = KxNPQ -> M=K, K=CRS, N=NPQ
|
||||
uint64_t size_M = Cout;
|
||||
uint64_t size_K = Cin * KW * KH;
|
||||
uint64_t size_N = N * OW * OH;
|
||||
*n_flops = size_M * size_N * (size_K + (size_K - 1));
|
||||
name += " M=Cout=" + std::to_string(size_M) + ", K=Cin*KW*KH=" + std::to_string(size_K) +
|
||||
uint64_t size_K = node->src[1]->ne[2] * knl->ne[0] * knl->ne[1];
|
||||
uint64_t size_N = node->ne[3] * node->ne[0] * node->ne[1];
|
||||
name += " M=Cout=" + std::to_string(Cout) + ", K=Cin*KW*KH=" + std::to_string(size_K) +
|
||||
", N=N*OW*OH=" + std::to_string(size_N);
|
||||
name = fusion_str + name;
|
||||
return name;
|
||||
return fusion_str + name;
|
||||
}
|
||||
if (node->op == GGML_OP_RMS_NORM) {
|
||||
std::string name = ggml_op_name(node->op);
|
||||
name += "(" + std::to_string(node->ne[0]) + "," + std::to_string(node->ne[1]) + "," + std::to_string(node->ne[2]) + "," + std::to_string(node->ne[3]) + ")";
|
||||
name = fusion_str + name;
|
||||
return name;
|
||||
return fusion_str + name;
|
||||
}
|
||||
if (node->op == GGML_OP_FLASH_ATTN_EXT) {
|
||||
const ggml_tensor * dst = node;
|
||||
@@ -2026,7 +2045,6 @@ class vk_perf_logger {
|
||||
" k(" << k->ne[0] << "," << k->ne[1] << "," << k->ne[2] << "," << k->ne[3] << "), " <<
|
||||
" v(" << v->ne[0] << "," << v->ne[1] << "," << v->ne[2] << "," << v->ne[3] << "), " <<
|
||||
" m(" << (m?m->ne[0]:0) << "," << (m?m->ne[1]:0) << "," << (m?m->ne[2]:0) << "," << (m?m->ne[3]:0) << ")";
|
||||
*n_flops = 2ull * q->ne[1] * q->ne[2] * (k->ne[0] + v->ne[0]) * k->ne[1] * q->ne[3];
|
||||
return name.str();
|
||||
}
|
||||
if (node->op == GGML_OP_TOP_K) {
|
||||
@@ -2090,7 +2108,7 @@ struct ggml_backend_vk_context {
|
||||
bool do_add_rms_partials_offset_calculation;
|
||||
bool do_add_rms_partials;
|
||||
|
||||
uint64_t last_total_mul_mat_bytes {};
|
||||
uint64_t last_total_flops {UINT64_MAX};
|
||||
|
||||
// Cache most recent tensor that was converted into prealloc_y, and what pipeline it used to convert.
|
||||
vk_pipeline_struct * prealloc_y_last_pipeline_used {};
|
||||
@@ -16188,22 +16206,23 @@ static ggml_status ggml_backend_vk_graph_compute(ggml_backend_t backend, ggml_cg
|
||||
}
|
||||
|
||||
// Submit after enough work has accumulated, to overlap CPU cmdbuffer generation with GPU execution.
|
||||
// Estimate the amount of matmul work by looking at the weight matrix size, and submit every 100MB
|
||||
// (and scaled down based on model size, so smaller models submit earlier).
|
||||
int submitted_nodes = 0;
|
||||
int submit_count = 0;
|
||||
uint64_t mul_mat_bytes = 0;
|
||||
uint64_t total_mul_mat_bytes = 0;
|
||||
uint64_t mul_mat_bytes_per_submit = std::min(uint64_t(100*1000*1000), ctx->last_total_mul_mat_bytes / 40u);
|
||||
// Estimate the amount of compute work using flops, and submit every 200 GFLOP
|
||||
// (and scaled down based on total graph flops, so smaller models submit earlier).
|
||||
// Also submit at least every 100 nodes, in case there are workloads without heavy compute.
|
||||
uint32_t submitted_nodes = 0;
|
||||
uint32_t submit_count = 0;
|
||||
uint64_t batch_flops = 0;
|
||||
uint64_t total_flops = 0;
|
||||
uint64_t flops_per_submit = std::min(uint64_t(200'000'000'000), ctx->last_total_flops / 40u);
|
||||
for (int i = 0; i < cgraph->n_nodes; i++) {
|
||||
if (first_node_in_batch) {
|
||||
submit_node_idx = i;
|
||||
}
|
||||
|
||||
if (cgraph->nodes[i]->op == GGML_OP_MUL_MAT || cgraph->nodes[i]->op == GGML_OP_MUL_MAT_ID) {
|
||||
auto bytes = ggml_nbytes(cgraph->nodes[i]->src[0]);
|
||||
mul_mat_bytes += bytes;
|
||||
total_mul_mat_bytes += bytes;
|
||||
{
|
||||
auto node_flops = ggml_vk_get_node_flops(cgraph->nodes[i]);
|
||||
batch_flops += node_flops;
|
||||
total_flops += node_flops;
|
||||
}
|
||||
|
||||
// op_srcs_fused_elementwise indicates whether an op's srcs all contribute to
|
||||
@@ -16415,8 +16434,8 @@ static ggml_status ggml_backend_vk_graph_compute(ggml_backend_t backend, ggml_cg
|
||||
|
||||
// Signal the almost_ready fence when the graph is mostly complete (< 20% remaining)
|
||||
bool almost_ready = (cgraph->n_nodes - i) < cgraph->n_nodes / 5;
|
||||
bool submit = ((uint32_t)submitted_nodes >= ctx->device->max_nodes_per_submit) ||
|
||||
(mul_mat_bytes_per_submit != 0 && mul_mat_bytes >= mul_mat_bytes_per_submit) ||
|
||||
bool submit = (submitted_nodes >= ctx->device->max_nodes_per_submit) ||
|
||||
(flops_per_submit != 0 && batch_flops >= flops_per_submit) ||
|
||||
(i + ctx->num_additional_fused_ops >= last_node) ||
|
||||
(almost_ready && !ctx->almost_ready_fence_pending);
|
||||
|
||||
@@ -16450,9 +16469,9 @@ static ggml_status ggml_backend_vk_graph_compute(ggml_backend_t backend, ggml_cg
|
||||
if (submit && enqueued) {
|
||||
first_node_in_batch = true;
|
||||
submitted_nodes = 0;
|
||||
mul_mat_bytes = 0;
|
||||
batch_flops = 0;
|
||||
if (submit_count < 3) {
|
||||
mul_mat_bytes_per_submit *= 2;
|
||||
flops_per_submit *= 2;
|
||||
}
|
||||
submit_count++;
|
||||
}
|
||||
@@ -16461,7 +16480,7 @@ static ggml_status ggml_backend_vk_graph_compute(ggml_backend_t backend, ggml_cg
|
||||
ctx->fused_ops_write_mask = 0;
|
||||
}
|
||||
|
||||
ctx->last_total_mul_mat_bytes = total_mul_mat_bytes;
|
||||
ctx->last_total_flops = total_flops;
|
||||
|
||||
if (vk_perf_logger_enabled) {
|
||||
// End the command buffer and submit/wait
|
||||
|
||||
@@ -1563,6 +1563,7 @@ class ggml_webgpu_shader_lib {
|
||||
case GGML_TYPE_IQ1_S:
|
||||
case GGML_TYPE_IQ4_NL:
|
||||
case GGML_TYPE_MXFP4:
|
||||
case GGML_TYPE_NVFP4:
|
||||
{
|
||||
// Quantized types using u32 buffers for portability.
|
||||
defines.push_back("SRC_TYPE=u32");
|
||||
@@ -1593,6 +1594,8 @@ class ggml_webgpu_shader_lib {
|
||||
} else if ((key.src_type >= GGML_TYPE_Q4_0 && key.src_type <= GGML_TYPE_Q8_1) ||
|
||||
key.src_type == GGML_TYPE_IQ4_NL || key.src_type == GGML_TYPE_MXFP4) {
|
||||
defines.push_back("BLOCK_SIZE=32u");
|
||||
} else if (key.src_type == GGML_TYPE_NVFP4) {
|
||||
defines.push_back("BLOCK_SIZE=64u");
|
||||
} else if (key.src_type >= GGML_TYPE_Q2_K) {
|
||||
defines.push_back("BLOCK_SIZE=256u");
|
||||
} else {
|
||||
@@ -1960,6 +1963,7 @@ class ggml_webgpu_shader_lib {
|
||||
defines.push_back(type_upper + "_TABLES");
|
||||
break;
|
||||
case GGML_TYPE_MXFP4:
|
||||
case GGML_TYPE_NVFP4:
|
||||
defines.push_back(type_upper + "_LUT");
|
||||
break;
|
||||
default:
|
||||
@@ -2103,6 +2107,7 @@ class ggml_webgpu_shader_lib {
|
||||
defines.push_back(type_upper + "_TABLES");
|
||||
break;
|
||||
case GGML_TYPE_MXFP4:
|
||||
case GGML_TYPE_NVFP4:
|
||||
defines.push_back(type_upper + "_LUT");
|
||||
break;
|
||||
default:
|
||||
@@ -2274,6 +2279,7 @@ class ggml_webgpu_shader_lib {
|
||||
defines.push_back(type_upper + "_TABLES");
|
||||
break;
|
||||
case GGML_TYPE_MXFP4:
|
||||
case GGML_TYPE_NVFP4:
|
||||
defines.push_back(type_upper + "_LUT");
|
||||
break;
|
||||
default:
|
||||
@@ -2394,6 +2400,7 @@ class ggml_webgpu_shader_lib {
|
||||
defines.push_back(type_upper + "_TABLES");
|
||||
break;
|
||||
case GGML_TYPE_MXFP4:
|
||||
case GGML_TYPE_NVFP4:
|
||||
defines.push_back(type_upper + "_LUT");
|
||||
break;
|
||||
default:
|
||||
|
||||
@@ -4056,6 +4056,7 @@ static bool ggml_webgpu_supported_qtype(ggml_type type) {
|
||||
case GGML_TYPE_IQ4_NL:
|
||||
case GGML_TYPE_IQ4_XS:
|
||||
case GGML_TYPE_MXFP4:
|
||||
case GGML_TYPE_NVFP4:
|
||||
return true;
|
||||
default:
|
||||
return false;
|
||||
@@ -4156,6 +4157,7 @@ static bool ggml_backend_webgpu_device_supports_op(ggml_backend_dev_t dev, const
|
||||
case GGML_TYPE_IQ4_NL:
|
||||
case GGML_TYPE_IQ4_XS:
|
||||
case GGML_TYPE_MXFP4:
|
||||
case GGML_TYPE_NVFP4:
|
||||
supports_op = true;
|
||||
break;
|
||||
default:
|
||||
@@ -4196,6 +4198,7 @@ static bool ggml_backend_webgpu_device_supports_op(ggml_backend_dev_t dev, const
|
||||
case GGML_TYPE_IQ4_NL:
|
||||
case GGML_TYPE_IQ4_XS:
|
||||
case GGML_TYPE_MXFP4:
|
||||
case GGML_TYPE_NVFP4:
|
||||
supports_op = true;
|
||||
break;
|
||||
default:
|
||||
|
||||
@@ -896,9 +896,23 @@ const kvalues_iq4nl = array<i32, 16>(
|
||||
|
||||
#endif
|
||||
|
||||
#ifdef MXFP4_LUT
|
||||
#if defined(MXFP4_LUT) || defined(NVFP4_LUT)
|
||||
const kvalues_mxfp4 = array<i32, 16>(
|
||||
0, 1, 2, 3, 4, 6, 8, 12, 0, -1, -2, -3, -4, -6, -8, -12
|
||||
);
|
||||
#endif
|
||||
#endif // MXFP4_LUT || NVFP4_LUT
|
||||
|
||||
#ifdef NVFP4_LUT
|
||||
fn ue4m3_to_fp32(u: u32) -> f32 {
|
||||
if (u == 0u || u == 127u) {
|
||||
return 0.0;
|
||||
}
|
||||
let exp = (u >> 3u) & 15u;
|
||||
let man = u & 7u;
|
||||
if (exp == 0u) {
|
||||
return f32(man) * (1.0 / 512.0);
|
||||
}
|
||||
let bits = ((exp + 120u) << 23u) | (man << 20u);
|
||||
return bitcast<f32>(bits);
|
||||
}
|
||||
#endif // NVFP4_LUT
|
||||
|
||||
@@ -672,6 +672,27 @@ fn copy_elements(src_base: u32, dst_base: u32, offset: u32) {
|
||||
}
|
||||
#endif
|
||||
|
||||
#ifdef NVFP4
|
||||
fn copy_elements(src_base: u32, dst_base: u32, offset: u32) {
|
||||
let block_byte_base = (src_base + offset) * 36;
|
||||
let d_word = load_u32_at_src(block_byte_base);
|
||||
for (var sub: u32 = 0u; sub < 4; sub++) {
|
||||
let d = ue4m3_to_fp32(get_byte(d_word, sub)) * 0.5;
|
||||
for (var j: u32 = 0u; j < 2; j++) {
|
||||
let q_packed = load_u32_at_src(block_byte_base + 4 + sub * 8 + j * 4);
|
||||
for (var k: u32 = 0; k < 4; k++) {
|
||||
let q_byte = get_byte(q_packed, k);
|
||||
let q_lo = f32(kvalues_mxfp4[q_byte & 0xFu]) * d;
|
||||
let q_hi = f32(kvalues_mxfp4[(q_byte >> 4) & 0xF]) * d;
|
||||
let dst_offset = dst_base + offset * 64 + sub * 16 + j * 4 + k;
|
||||
dst[dst_offset] = q_lo;
|
||||
dst[dst_offset + 8u] = q_hi;
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
#endif
|
||||
|
||||
|
||||
@group(0) @binding(0)
|
||||
var<storage, read_write> src: array<SRC_TYPE>;
|
||||
|
||||
@@ -241,7 +241,7 @@ fn init_shmem_src0(thread_id: u32, batch_offset: u32, offset_m: u32, k_outer: u3
|
||||
#endif // INIT_SRC0_SHMEM_Q8_1
|
||||
|
||||
#if defined(INIT_SRC0_SHMEM_MXFP4)
|
||||
let block_byte_base = src0_idx * 17u;
|
||||
let block_byte_base = src0_idx * 17u; // BLOCK_SIZE_BYTES = 17u;
|
||||
let eu8 = get_byte(load_u32_at_src0_aligned(block_byte_base), block_byte_base & 3u);
|
||||
let e = ldexp(1.0, i32(eu8) - 128);
|
||||
|
||||
@@ -263,6 +263,47 @@ fn init_shmem_src0(thread_id: u32, batch_offset: u32, offset_m: u32, k_outer: u3
|
||||
}
|
||||
#endif // legacy-quants
|
||||
|
||||
#if defined(INIT_SRC0_SHMEM_NVFP4)
|
||||
const BLOCK_SIZE = 64u;
|
||||
const BLOCK_SIZE_BYTES = 36u;
|
||||
const SUB_BLOCK_SIZE = 16u; // elements sharing one UE4M3 scale
|
||||
const NQ = 16u;
|
||||
const BYTES_PER_THREAD = 8u;
|
||||
const BYTES_PER_INNER_LOOP = 4u;
|
||||
|
||||
fn init_shmem_src0(thread_id: u32, batch_offset: u32, offset_m: u32, k_outer: u32) {
|
||||
for (var i = thread_id * NQ; i < TILE_SRC0_SHMEM; i += TOTAL_WORKGROUP_SIZE * NQ) {
|
||||
let tile_m = i / TILE_K;
|
||||
let tile_k_start = i % TILE_K;
|
||||
let global_m = offset_m + tile_m;
|
||||
let global_k_start = k_outer + tile_k_start;
|
||||
|
||||
if (global_m >= params.m) {
|
||||
break;
|
||||
}
|
||||
|
||||
let block_k = global_k_start / BLOCK_SIZE;
|
||||
let sub_block = (global_k_start % BLOCK_SIZE) / SUB_BLOCK_SIZE;
|
||||
let src0_idx = batch_offset + global_m * params.stride_01 + block_k;
|
||||
|
||||
let block_byte_base = src0_idx * BLOCK_SIZE_BYTES;
|
||||
let d_byte_base = block_byte_base;
|
||||
let qs_byte_base = block_byte_base + 4u;
|
||||
|
||||
let d = ue4m3_to_fp32(get_byte(load_u32_at_src0_aligned(d_byte_base), sub_block)) * 0.5;
|
||||
|
||||
for (var j = 0u; j < BYTES_PER_THREAD / BYTES_PER_INNER_LOOP; j++) {
|
||||
let q_packed = load_u32_at_src0_aligned(qs_byte_base + sub_block * 8u + j * 4u);
|
||||
for (var k = 0u; k < BYTES_PER_INNER_LOOP; k++) {
|
||||
let q_byte = get_byte(q_packed, k);
|
||||
shmem[i + j * BYTES_PER_INNER_LOOP + k] = f16(f32(kvalues_mxfp4[q_byte & 0xF]) * d);
|
||||
shmem[i + j * BYTES_PER_INNER_LOOP + k + 8u] = f16(f32(kvalues_mxfp4[(q_byte >> 4) & 0xF]) * d);
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
#endif // INIT_SRC0_SHMEM_NVFP4
|
||||
|
||||
// k-quants
|
||||
#if defined(INIT_SRC0_SHMEM_Q2_K) || defined(INIT_SRC0_SHMEM_Q3_K) || defined(INIT_SRC0_SHMEM_Q4_K) || defined(INIT_SRC0_SHMEM_Q5_K) || defined(INIT_SRC0_SHMEM_Q6_K)
|
||||
const BLOCK_SIZE = 256u;
|
||||
|
||||
@@ -1505,3 +1505,49 @@ fn accumulate_vec_dot(thread_id: u32, row_base: u32, src0_batch_offset: u32, src
|
||||
return acc;
|
||||
}
|
||||
#endif
|
||||
|
||||
#ifdef MUL_ACC_NVFP4
|
||||
#define BLOCK_SIZE 64
|
||||
#define BLOCK_SIZE_BYTES 36
|
||||
#define THREADS_PER_BLOCK 4
|
||||
#define ELEMS_PER_THREAD (BLOCK_SIZE/THREADS_PER_BLOCK)
|
||||
fn accumulate_vec_dot(thread_id: u32, row_base: u32, src0_batch_offset: u32, src1_idx_base: u32) -> array<array<f32, OUTPUTS_PER_WG>, NUM_COLS> {
|
||||
var acc: array<array<f32, OUTPUTS_PER_WG>, NUM_COLS>;
|
||||
|
||||
let num_blocks = params.k / BLOCK_SIZE;
|
||||
let sub = thread_id % THREADS_PER_BLOCK;
|
||||
for (var block = thread_id/THREADS_PER_BLOCK; block < num_blocks; block += WG_SIZE/THREADS_PER_BLOCK) {
|
||||
let x_base = src1_idx_base + block * BLOCK_SIZE + sub * ELEMS_PER_THREAD;
|
||||
var x_block: array<array<f32, ELEMS_PER_THREAD>, NUM_COLS>;
|
||||
for (var col = 0u; col < NUM_COLS;col += 1) {
|
||||
for (var i = 0u; i < ELEMS_PER_THREAD / 2; i++) {
|
||||
x_block[col][i] = f32(src1[x_base + col * params.stride_11 + i]);
|
||||
x_block[col][i + 8] = f32(src1[x_base + col * params.stride_11 + i + 8]);
|
||||
}
|
||||
}
|
||||
for (var row = 0u; row < OUTPUTS_PER_WG; row++) {
|
||||
let output_row = row_base + row;
|
||||
if (output_row < params.m) {
|
||||
let block_byte_base = (src0_batch_offset + output_row * params.stride_01 + block) * BLOCK_SIZE_BYTES;
|
||||
let d = ue4m3_to_fp32(get_byte(load_u32_at_src0_aligned(block_byte_base), sub)) * 0.5;
|
||||
let q_w0 = load_u32_at_src0_aligned(block_byte_base + 4u + 8u * sub);
|
||||
let q_w1 = load_u32_at_src0_aligned(block_byte_base + 8u + 8u * sub);
|
||||
for (var col = 0u;col < NUM_COLS;col += 1) {
|
||||
var row_sum = 0.0;
|
||||
for (var l = 0u; l < 8u; l++) {
|
||||
let q_word = select(q_w0, q_w1, l >= 4u);
|
||||
let q_byte = get_byte(q_word, l % 4u);
|
||||
let q_lo = f32(kvalues_mxfp4[q_byte & 0xFu]) * d;
|
||||
let q_hi = f32(kvalues_mxfp4[(q_byte >> 4u) & 0xFu]) * d;
|
||||
row_sum += q_lo * x_block[col][l];
|
||||
row_sum += q_hi * x_block[col][l + 8u];
|
||||
}
|
||||
acc[col][row] += row_sum;
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
return acc;
|
||||
}
|
||||
#endif
|
||||
|
||||
+103
-2
@@ -145,6 +145,7 @@ class Keys:
|
||||
TOKEN_SHIFT_COUNT = "{arch}.token_shift_count"
|
||||
INTERLEAVE_MOE_LAYER_STEP = "{arch}.interleave_moe_layer_step"
|
||||
FULL_ATTENTION_INTERVAL = "{arch}.full_attention_interval"
|
||||
HASH_LAYER_COUNT = "{arch}.hash_layer_count"
|
||||
ACTIVATION_SPARSITY_SCALE = "{arch}.activation_sparsity_scale"
|
||||
ALTUP_ACTIVE_IDX = "{arch}.altup.active_idx"
|
||||
ALTUP_NUM_INPUTS = "{arch}.altup.num_inputs"
|
||||
@@ -156,6 +157,7 @@ class Keys:
|
||||
DENSE_FEAT_OUT_SIZE = "{arch}.{dense}_feat_out"
|
||||
TARGET_LAYERS = "{arch}.target_layers"
|
||||
TARGET_HIDDEN_SIZE = "{arch}.target_hidden_size"
|
||||
BLOCK_SIZE = "{arch}.block_size"
|
||||
NORM_BEFORE_RESIDUAL = "{arch}.norm_before_residual"
|
||||
|
||||
class Attention:
|
||||
@@ -179,8 +181,12 @@ class Keys:
|
||||
REL_BUCKETS_COUNT = "{arch}.attention.relative_buckets_count"
|
||||
SLIDING_WINDOW = "{arch}.attention.sliding_window"
|
||||
SCALE = "{arch}.attention.scale"
|
||||
OUTPUT_GROUP_COUNT = "{arch}.attention.output_group_count"
|
||||
OUTPUT_LORA_RANK = "{arch}.attention.output_lora_rank"
|
||||
OUTPUT_SCALE = "{arch}.attention.output_scale"
|
||||
VALUE_SCALE = "{arch}.attention.value_scale"
|
||||
COMPRESS_RATIOS = "{arch}.attention.compress_ratios"
|
||||
COMPRESS_ROPE_FREQ_BASE = "{arch}.attention.compress_rope_freq_base"
|
||||
TEMPERATURE_LENGTH = "{arch}.attention.temperature_length"
|
||||
KEY_LENGTH_MLA = "{arch}.attention.key_length_mla"
|
||||
VALUE_LENGTH_MLA = "{arch}.attention.value_length_mla"
|
||||
@@ -195,6 +201,11 @@ class Keys:
|
||||
KEY_LENGTH = "{arch}.attention.indexer.key_length"
|
||||
TOP_K = "{arch}.attention.indexer.top_k"
|
||||
|
||||
class HyperConnection:
|
||||
COUNT = "{arch}.hyper_connection.count"
|
||||
SINKHORN_ITERATIONS = "{arch}.hyper_connection.sinkhorn_iterations"
|
||||
EPSILON = "{arch}.hyper_connection.epsilon"
|
||||
|
||||
class Rope:
|
||||
DIMENSION_COUNT = "{arch}.rope.dimension_count"
|
||||
DIMENSION_COUNT_SWA = "{arch}.rope.dimension_count_swa"
|
||||
@@ -469,6 +480,7 @@ class MODEL_ARCH(IntEnum):
|
||||
DEEPSEEK2 = auto()
|
||||
DEEPSEEK2OCR = auto()
|
||||
DEEPSEEK32 = auto()
|
||||
DEEPSEEK4 = auto()
|
||||
CHATGLM = auto()
|
||||
GLM4 = auto()
|
||||
GLM4_MOE = auto()
|
||||
@@ -554,6 +566,9 @@ class MODEL_TENSOR(IntEnum):
|
||||
DENSE_2_OUT = auto() # embeddinggemma 2_Dense
|
||||
DENSE_3_OUT = auto() # embeddinggemma 3_Dense
|
||||
OUTPUT_NORM = auto()
|
||||
HC_HEAD_FN = auto()
|
||||
HC_HEAD_BASE = auto()
|
||||
HC_HEAD_SCALE = auto()
|
||||
ROPE_FREQS = auto()
|
||||
ROPE_FACTORS_LONG = auto()
|
||||
ROPE_FACTORS_SHORT = auto()
|
||||
@@ -593,6 +608,7 @@ class MODEL_TENSOR(IntEnum):
|
||||
FFN_DOWN_CHEXP = auto()
|
||||
FFN_UP_CHEXP = auto()
|
||||
FFN_EXP_PROBS_B = auto()
|
||||
FFN_GATE_TID2EID = auto()
|
||||
MOE_LATENT_DOWN = auto() # nemotron 3 super
|
||||
MOE_LATENT_UP = auto() # nemotron 3 super
|
||||
ATTN_Q_NORM = auto()
|
||||
@@ -680,6 +696,20 @@ class MODEL_TENSOR(IntEnum):
|
||||
ATTN_V_B = auto()
|
||||
ATTN_Q_A_NORM = auto()
|
||||
ATTN_KV_A_NORM = auto()
|
||||
ATTN_KV = auto()
|
||||
ATTN_KV_NORM = auto()
|
||||
ATTN_OUT_A = auto()
|
||||
ATTN_OUT_B = auto()
|
||||
HC_ATTN_FN = auto()
|
||||
HC_ATTN_BASE = auto()
|
||||
HC_ATTN_SCALE = auto()
|
||||
HC_FFN_FN = auto()
|
||||
HC_FFN_BASE = auto()
|
||||
HC_FFN_SCALE = auto()
|
||||
ATTN_COMPRESSOR_WKV = auto()
|
||||
ATTN_COMPRESSOR_WGATE = auto()
|
||||
ATTN_COMPRESSOR_APE = auto()
|
||||
ATTN_COMPRESSOR_NORM = auto()
|
||||
FFN_SUB_NORM = auto()
|
||||
ATTN_SUB_NORM = auto()
|
||||
DEC_ATTN_NORM = auto()
|
||||
@@ -741,6 +771,10 @@ class MODEL_TENSOR(IntEnum):
|
||||
INDEXER_PROJ = auto()
|
||||
INDEXER_ATTN_K = auto()
|
||||
INDEXER_ATTN_Q_B = auto()
|
||||
INDEXER_COMPRESSOR_WKV = auto()
|
||||
INDEXER_COMPRESSOR_WGATE = auto()
|
||||
INDEXER_COMPRESSOR_APE = auto()
|
||||
INDEXER_COMPRESSOR_NORM = auto()
|
||||
# vision
|
||||
V_MMPROJ = auto()
|
||||
V_MMPROJ_FC = auto()
|
||||
@@ -1026,6 +1060,7 @@ MODEL_ARCH_NAMES: dict[MODEL_ARCH, str] = {
|
||||
MODEL_ARCH.DEEPSEEK2: "deepseek2",
|
||||
MODEL_ARCH.DEEPSEEK2OCR: "deepseek2-ocr",
|
||||
MODEL_ARCH.DEEPSEEK32: "deepseek32",
|
||||
MODEL_ARCH.DEEPSEEK4: "deepseek4",
|
||||
MODEL_ARCH.CHATGLM: "chatglm",
|
||||
MODEL_ARCH.GLM4: "glm4",
|
||||
MODEL_ARCH.GLM4_MOE: "glm4moe",
|
||||
@@ -1110,6 +1145,9 @@ TENSOR_NAMES: dict[MODEL_TENSOR, str] = {
|
||||
MODEL_TENSOR.OUTPUT: "output",
|
||||
MODEL_TENSOR.DENSE_2_OUT: "dense_2", # embeddinggemma 2_Dense
|
||||
MODEL_TENSOR.DENSE_3_OUT: "dense_3", # embeddinggemma 2_Dense
|
||||
MODEL_TENSOR.HC_HEAD_FN: "output_hc_fn",
|
||||
MODEL_TENSOR.HC_HEAD_BASE: "output_hc_base",
|
||||
MODEL_TENSOR.HC_HEAD_SCALE: "output_hc_scale",
|
||||
MODEL_TENSOR.ROPE_FREQS: "rope_freqs",
|
||||
MODEL_TENSOR.ROPE_FACTORS_LONG: "rope_factors_long",
|
||||
MODEL_TENSOR.ROPE_FACTORS_SHORT: "rope_factors_short",
|
||||
@@ -1151,6 +1189,7 @@ TENSOR_NAMES: dict[MODEL_TENSOR, str] = {
|
||||
MODEL_TENSOR.FFN_UP_EXP: "blk.{bid}.ffn_up_exps",
|
||||
MODEL_TENSOR.FFN_GATE_UP_EXP: "blk.{bid}.ffn_gate_up_exps",
|
||||
MODEL_TENSOR.FFN_EXP_PROBS_B: "blk.{bid}.exp_probs_b",
|
||||
MODEL_TENSOR.FFN_GATE_TID2EID: "blk.{bid}.ffn_gate_tid2eid",
|
||||
MODEL_TENSOR.MOE_LATENT_DOWN: "blk.{bid}.ffn_latent_down", # nemotron 3 super
|
||||
MODEL_TENSOR.MOE_LATENT_UP: "blk.{bid}.ffn_latent_up", # nemotron 3 super
|
||||
MODEL_TENSOR.LAYER_OUT_NORM: "blk.{bid}.layer_output_norm",
|
||||
@@ -1236,6 +1275,20 @@ TENSOR_NAMES: dict[MODEL_TENSOR, str] = {
|
||||
MODEL_TENSOR.ATTN_V_B: "blk.{bid}.attn_v_b",
|
||||
MODEL_TENSOR.ATTN_Q_A_NORM: "blk.{bid}.attn_q_a_norm",
|
||||
MODEL_TENSOR.ATTN_KV_A_NORM: "blk.{bid}.attn_kv_a_norm",
|
||||
MODEL_TENSOR.ATTN_KV: "blk.{bid}.attn_kv",
|
||||
MODEL_TENSOR.ATTN_KV_NORM: "blk.{bid}.attn_kv_a_norm",
|
||||
MODEL_TENSOR.ATTN_OUT_A: "blk.{bid}.attn_output_a",
|
||||
MODEL_TENSOR.ATTN_OUT_B: "blk.{bid}.attn_output_b",
|
||||
MODEL_TENSOR.HC_ATTN_FN: "blk.{bid}.hc_attn_fn",
|
||||
MODEL_TENSOR.HC_ATTN_BASE: "blk.{bid}.hc_attn_base",
|
||||
MODEL_TENSOR.HC_ATTN_SCALE: "blk.{bid}.hc_attn_scale",
|
||||
MODEL_TENSOR.HC_FFN_FN: "blk.{bid}.hc_ffn_fn",
|
||||
MODEL_TENSOR.HC_FFN_BASE: "blk.{bid}.hc_ffn_base",
|
||||
MODEL_TENSOR.HC_FFN_SCALE: "blk.{bid}.hc_ffn_scale",
|
||||
MODEL_TENSOR.ATTN_COMPRESSOR_WKV: "blk.{bid}.attn_compressor_kv",
|
||||
MODEL_TENSOR.ATTN_COMPRESSOR_WGATE: "blk.{bid}.attn_compressor_gate",
|
||||
MODEL_TENSOR.ATTN_COMPRESSOR_APE: "blk.{bid}.attn_compressor_ape",
|
||||
MODEL_TENSOR.ATTN_COMPRESSOR_NORM: "blk.{bid}.attn_compressor_norm",
|
||||
MODEL_TENSOR.ATTN_SUB_NORM: "blk.{bid}.attn_sub_norm",
|
||||
MODEL_TENSOR.FFN_SUB_NORM: "blk.{bid}.ffn_sub_norm",
|
||||
MODEL_TENSOR.DEC_ATTN_NORM: "dec.blk.{bid}.attn_norm",
|
||||
@@ -1297,6 +1350,10 @@ TENSOR_NAMES: dict[MODEL_TENSOR, str] = {
|
||||
MODEL_TENSOR.INDEXER_PROJ: "blk.{bid}.indexer.proj",
|
||||
MODEL_TENSOR.INDEXER_ATTN_K: "blk.{bid}.indexer.attn_k",
|
||||
MODEL_TENSOR.INDEXER_ATTN_Q_B: "blk.{bid}.indexer.attn_q_b",
|
||||
MODEL_TENSOR.INDEXER_COMPRESSOR_WKV: "blk.{bid}.indexer_compressor_kv",
|
||||
MODEL_TENSOR.INDEXER_COMPRESSOR_WGATE: "blk.{bid}.indexer_compressor_gate",
|
||||
MODEL_TENSOR.INDEXER_COMPRESSOR_APE: "blk.{bid}.indexer_compressor_ape",
|
||||
MODEL_TENSOR.INDEXER_COMPRESSOR_NORM: "blk.{bid}.indexer_compressor_norm",
|
||||
# vision
|
||||
MODEL_TENSOR.V_MMPROJ: "mm.{bid}",
|
||||
MODEL_TENSOR.V_MMPROJ_FC: "mm.model.fc",
|
||||
@@ -3137,6 +3194,49 @@ MODEL_TENSORS: dict[MODEL_ARCH, list[MODEL_TENSOR]] = {
|
||||
MODEL_TENSOR.NEXTN_SHARED_HEAD_HEAD,
|
||||
MODEL_TENSOR.NEXTN_SHARED_HEAD_NORM,
|
||||
],
|
||||
MODEL_ARCH.DEEPSEEK4: [
|
||||
MODEL_TENSOR.TOKEN_EMBD,
|
||||
MODEL_TENSOR.OUTPUT_NORM,
|
||||
MODEL_TENSOR.OUTPUT,
|
||||
MODEL_TENSOR.HC_HEAD_FN,
|
||||
MODEL_TENSOR.HC_HEAD_BASE,
|
||||
MODEL_TENSOR.HC_HEAD_SCALE,
|
||||
MODEL_TENSOR.ATTN_NORM,
|
||||
MODEL_TENSOR.ATTN_SINKS,
|
||||
MODEL_TENSOR.ATTN_Q_A,
|
||||
MODEL_TENSOR.ATTN_Q_B,
|
||||
MODEL_TENSOR.ATTN_Q_A_NORM,
|
||||
MODEL_TENSOR.ATTN_KV,
|
||||
MODEL_TENSOR.ATTN_KV_NORM,
|
||||
MODEL_TENSOR.ATTN_OUT_A,
|
||||
MODEL_TENSOR.ATTN_OUT_B,
|
||||
MODEL_TENSOR.HC_ATTN_FN,
|
||||
MODEL_TENSOR.HC_ATTN_BASE,
|
||||
MODEL_TENSOR.HC_ATTN_SCALE,
|
||||
MODEL_TENSOR.HC_FFN_FN,
|
||||
MODEL_TENSOR.HC_FFN_BASE,
|
||||
MODEL_TENSOR.HC_FFN_SCALE,
|
||||
MODEL_TENSOR.ATTN_COMPRESSOR_WKV,
|
||||
MODEL_TENSOR.ATTN_COMPRESSOR_WGATE,
|
||||
MODEL_TENSOR.ATTN_COMPRESSOR_APE,
|
||||
MODEL_TENSOR.ATTN_COMPRESSOR_NORM,
|
||||
MODEL_TENSOR.INDEXER_PROJ,
|
||||
MODEL_TENSOR.INDEXER_ATTN_Q_B,
|
||||
MODEL_TENSOR.INDEXER_COMPRESSOR_WKV,
|
||||
MODEL_TENSOR.INDEXER_COMPRESSOR_WGATE,
|
||||
MODEL_TENSOR.INDEXER_COMPRESSOR_APE,
|
||||
MODEL_TENSOR.INDEXER_COMPRESSOR_NORM,
|
||||
MODEL_TENSOR.FFN_GATE_INP,
|
||||
MODEL_TENSOR.FFN_GATE_TID2EID,
|
||||
MODEL_TENSOR.FFN_EXP_PROBS_B,
|
||||
MODEL_TENSOR.FFN_NORM,
|
||||
MODEL_TENSOR.FFN_GATE_EXP,
|
||||
MODEL_TENSOR.FFN_DOWN_EXP,
|
||||
MODEL_TENSOR.FFN_UP_EXP,
|
||||
MODEL_TENSOR.FFN_GATE_SHEXP,
|
||||
MODEL_TENSOR.FFN_DOWN_SHEXP,
|
||||
MODEL_TENSOR.FFN_UP_SHEXP,
|
||||
],
|
||||
MODEL_ARCH.ERNIE4_5_MOE: [
|
||||
MODEL_TENSOR.TOKEN_EMBD,
|
||||
MODEL_TENSOR.OUTPUT_NORM,
|
||||
@@ -4436,8 +4536,9 @@ class GGMLQuantizationType(IntEnum):
|
||||
|
||||
|
||||
class ExpertGatingFuncType(IntEnum):
|
||||
SOFTMAX = 1
|
||||
SIGMOID = 2
|
||||
SOFTMAX = 1
|
||||
SIGMOID = 2
|
||||
SQRTSOFTPLUS = 4
|
||||
|
||||
|
||||
# TODO: add GGMLFileType from ggml_ftype in ggml.h
|
||||
|
||||
@@ -715,6 +715,9 @@ class GGUFWriter:
|
||||
def add_full_attention_interval(self, interval: int) -> None:
|
||||
self.add_uint32(Keys.LLM.FULL_ATTENTION_INTERVAL.format(arch=self.arch), interval)
|
||||
|
||||
def add_hash_layer_count(self, count: int) -> None:
|
||||
self.add_uint32(Keys.LLM.HASH_LAYER_COUNT.format(arch=self.arch), count)
|
||||
|
||||
def add_feed_forward_length(self, length: int | Sequence[int]) -> None:
|
||||
if isinstance(length, int):
|
||||
self.add_uint32(Keys.LLM.FEED_FORWARD_LENGTH.format(arch=self.arch), length)
|
||||
@@ -940,6 +943,39 @@ class GGUFWriter:
|
||||
def add_sliding_window(self, value: int) -> None:
|
||||
self.add_uint32(Keys.Attention.SLIDING_WINDOW.format(arch=self.arch), value)
|
||||
|
||||
def add_block_size(self, value: int) -> None:
|
||||
self.add_uint32(Keys.LLM.BLOCK_SIZE.format(arch=self.arch), value)
|
||||
|
||||
def add_target_layers(self, value: Sequence[int]) -> None:
|
||||
self.add_array(Keys.LLM.TARGET_LAYERS.format(arch=self.arch), value)
|
||||
|
||||
def add_target_hidden_size(self, value: int) -> None:
|
||||
self.add_uint32(Keys.LLM.TARGET_HIDDEN_SIZE.format(arch=self.arch), value)
|
||||
|
||||
def add_norm_before_residual(self, value: bool) -> None:
|
||||
self.add_bool(Keys.LLM.NORM_BEFORE_RESIDUAL.format(arch=self.arch), value)
|
||||
|
||||
def add_attention_output_group_count(self, count: int) -> None:
|
||||
self.add_uint32(Keys.Attention.OUTPUT_GROUP_COUNT.format(arch=self.arch), count)
|
||||
|
||||
def add_attention_output_lora_rank(self, length: int) -> None:
|
||||
self.add_uint32(Keys.Attention.OUTPUT_LORA_RANK.format(arch=self.arch), length)
|
||||
|
||||
def add_attention_compress_ratios(self, values: Sequence[int]) -> None:
|
||||
self.add_array(Keys.Attention.COMPRESS_RATIOS.format(arch=self.arch), values)
|
||||
|
||||
def add_attention_compress_rope_freq_base(self, value: float) -> None:
|
||||
self.add_float32(Keys.Attention.COMPRESS_ROPE_FREQ_BASE.format(arch=self.arch), value)
|
||||
|
||||
def add_hyper_connection_count(self, count: int) -> None:
|
||||
self.add_uint32(Keys.HyperConnection.COUNT.format(arch=self.arch), count)
|
||||
|
||||
def add_hyper_connection_sinkhorn_iterations(self, count: int) -> None:
|
||||
self.add_uint32(Keys.HyperConnection.SINKHORN_ITERATIONS.format(arch=self.arch), count)
|
||||
|
||||
def add_hyper_connection_epsilon(self, value: float) -> None:
|
||||
self.add_float32(Keys.HyperConnection.EPSILON.format(arch=self.arch), value)
|
||||
|
||||
def add_attention_scale(self, value: float) -> None:
|
||||
self.add_float32(Keys.Attention.SCALE.format(arch=self.arch), value)
|
||||
|
||||
|
||||
@@ -1283,6 +1283,11 @@ class TensorNameMap:
|
||||
MODEL_TENSOR.ENC_OUTPUT_NORM: (
|
||||
"encoder.final_layer_norm", # t5
|
||||
"layer_norm", # neobert
|
||||
"model.hidden_norm", # dflash
|
||||
),
|
||||
|
||||
MODEL_TENSOR.FC: (
|
||||
"model.fc", # dflash
|
||||
),
|
||||
|
||||
MODEL_TENSOR.CLS: (
|
||||
|
||||
@@ -0,0 +1,112 @@
|
||||
{%- if not add_generation_prompt is defined -%}
|
||||
{%- set add_generation_prompt = false -%}
|
||||
{%- endif -%}
|
||||
{%- if not thinking is defined -%}
|
||||
{%- if enable_thinking is defined -%}
|
||||
{%- set thinking = enable_thinking -%}
|
||||
{%- else -%}
|
||||
{%- set thinking = false -%}
|
||||
{%- endif -%}
|
||||
{%- endif -%}
|
||||
{%- set dsml_token = '|DSML|' -%}
|
||||
{%- set thinking_start_token = '<think>' -%}
|
||||
{%- set thinking_end_token = '</think>' -%}
|
||||
{%- set tools_header = '## Tools\n\nYou have access to a set of tools to help answer the user\'s question. You can invoke tools by writing a "<' + dsml_token + 'tool_calls>" block like the following:\n\n<' + dsml_token + 'tool_calls>\n<' + dsml_token + 'invoke name="$TOOL_NAME">\n<' + dsml_token + 'parameter name="$PARAMETER_NAME" string="true|false">$PARAMETER_VALUE</' + dsml_token + 'parameter>\n...\n</' + dsml_token + 'invoke>\n<' + dsml_token + 'invoke name="$TOOL_NAME2">\n...\n</' + dsml_token + 'invoke>\n</' + dsml_token + 'tool_calls>\n\nString parameters should be specified as is and set `string="true"`. For all other types (numbers, booleans, arrays, objects), pass the value in JSON format and set `string="false"`.\n\nIf thinking_mode is enabled (triggered by ' + thinking_start_token + '), you MUST output your complete reasoning inside ' + thinking_start_token + '...' + thinking_end_token + ' BEFORE any tool calls or final response.\n\nOtherwise, output directly after ' + thinking_end_token + ' with tool calls or final response.\n\n### Available Tool Schemas\n\n' -%}
|
||||
{%- set tools_footer = '\nYou MUST strictly follow the above defined tool name and parameter schemas to invoke tool calls.\n' -%}
|
||||
{%- set ns = namespace(system_prompt='', is_first_sp=true) -%}
|
||||
{%- for message in messages -%}
|
||||
{%- if message['role'] == 'system' -%}
|
||||
{%- if ns.is_first_sp -%}
|
||||
{%- set ns.system_prompt = ns.system_prompt + (message['content'] or '') -%}
|
||||
{%- set ns.is_first_sp = false -%}
|
||||
{%- else -%}
|
||||
{%- set ns.system_prompt = ns.system_prompt + '\n\n' + (message['content'] or '') -%}
|
||||
{%- endif -%}
|
||||
{%- endif -%}
|
||||
{%- endfor -%}
|
||||
{%- if tools is defined and tools -%}
|
||||
{%- set ts = namespace(schemas='') -%}
|
||||
{%- for tool in tools -%}
|
||||
{%- if tool['type'] == 'function' -%}
|
||||
{%- set ts.schemas = ts.schemas + (tool['function'] | tojson) + '\n' -%}
|
||||
{%- endif -%}
|
||||
{%- endfor -%}
|
||||
{%- if ns.system_prompt -%}
|
||||
{%- set ns.system_prompt = ns.system_prompt + '\n\n' + tools_header + ts.schemas + tools_footer -%}
|
||||
{%- else -%}
|
||||
{%- set ns.system_prompt = tools_header + ts.schemas + tools_footer -%}
|
||||
{%- endif -%}
|
||||
{%- endif -%}
|
||||
{{- bos_token -}}
|
||||
{{- ns.system_prompt -}}
|
||||
{%- set last_user_idx = namespace(value=-1) -%}
|
||||
{%- for message in messages -%}
|
||||
{%- if message['role'] == 'user' or message['role'] == 'developer' or message['role'] == 'tool' -%}
|
||||
{%- set last_user_idx.value = loop.index0 -%}
|
||||
{%- endif -%}
|
||||
{%- endfor -%}
|
||||
{%- set state = namespace(in_user=false) -%}
|
||||
{%- for message in messages -%}
|
||||
{%- if message['role'] == 'user' or message['role'] == 'developer' -%}
|
||||
{%- if state.in_user -%}
|
||||
{{- '\n\n' -}}
|
||||
{%- else -%}
|
||||
{{- '<|User|>' -}}
|
||||
{%- set state.in_user = true -%}
|
||||
{%- endif -%}
|
||||
{{- message['content'] or '' -}}
|
||||
{%- elif message['role'] == 'tool' -%}
|
||||
{%- if state.in_user -%}
|
||||
{{- '\n\n' -}}
|
||||
{%- else -%}
|
||||
{{- '<|User|>' -}}
|
||||
{%- set state.in_user = true -%}
|
||||
{%- endif -%}
|
||||
{{- '<tool_result>' + (message['content'] or '') + '</tool_result>' -}}
|
||||
{%- elif message['role'] == 'assistant' -%}
|
||||
{%- set state.in_user = false -%}
|
||||
{{- '<|Assistant|>' -}}
|
||||
{%- set is_after_last_user = loop.index0 > last_user_idx.value -%}
|
||||
{%- if is_after_last_user and thinking -%}
|
||||
{{- thinking_start_token -}}
|
||||
{%- if message['reasoning_content'] is defined and message['reasoning_content'] -%}
|
||||
{{- message['reasoning_content'] -}}
|
||||
{%- endif -%}
|
||||
{{- thinking_end_token -}}
|
||||
{%- else -%}
|
||||
{{- thinking_end_token -}}
|
||||
{%- endif -%}
|
||||
{%- if message['content'] is defined and message['content'] -%}
|
||||
{{- message['content'] -}}
|
||||
{%- endif -%}
|
||||
{%- if message['tool_calls'] -%}
|
||||
{{- '\n\n<' + dsml_token + 'tool_calls>\n' -}}
|
||||
{%- for tool in message['tool_calls'] -%}
|
||||
{%- set func = tool['function'] -%}
|
||||
{{- '<' + dsml_token + 'invoke name="' + func['name'] + '">\n' -}}
|
||||
{%- set args = func['arguments'] -%}
|
||||
{%- if args is string -%}
|
||||
{%- set args = args | from_json -%}
|
||||
{%- endif -%}
|
||||
{%- for key, val in args.items() -%}
|
||||
{%- if val is string -%}
|
||||
{{- '<' + dsml_token + 'parameter name="' + key + '" string="true">' + val + '</' + dsml_token + 'parameter>\n' -}}
|
||||
{%- else -%}
|
||||
{{- '<' + dsml_token + 'parameter name="' + key + '" string="false">' + (val | tojson) + '</' + dsml_token + 'parameter>\n' -}}
|
||||
{%- endif -%}
|
||||
{%- endfor -%}
|
||||
{{- '</' + dsml_token + 'invoke>\n' -}}
|
||||
{%- endfor -%}
|
||||
{{- '</' + dsml_token + 'tool_calls>' -}}
|
||||
{%- endif -%}
|
||||
{{- '<|end▁of▁sentence|>' -}}
|
||||
{%- endif -%}
|
||||
{%- endfor -%}
|
||||
{%- if add_generation_prompt -%}
|
||||
{{- '<|Assistant|>' -}}
|
||||
{%- if thinking -%}
|
||||
{{- thinking_start_token -}}
|
||||
{%- else -%}
|
||||
{{- thinking_end_token -}}
|
||||
{%- endif -%}
|
||||
{%- endif -%}
|
||||
@@ -0,0 +1,179 @@
|
||||
{{- bos_token }}{%- if tools %}
|
||||
{%- set tool_definitions %}
|
||||
{{- "# Tools\n\nYou are provided with function signatures within <tools></tools> XML tags:\n<tools>" }}
|
||||
{%- for tool in tools %}
|
||||
{{- "\n" }}
|
||||
{{- tool | tojson(ensure_ascii=False) }}
|
||||
{%- endfor %}
|
||||
{{- '\n</tools>\n\nTool usage guidelines:\n- You may call zero or more functions. If no function calls are needed, just answer normally and do not include any <function ... </function>.\n- When calling a function, return an XML object within <function ... </function> using:\n<function name="function-name"><param name="param-name">param-value</param></function>\n- param-value may be multi-line. If it contains <, & or newline characters, wrap it in a CDATA block: <param name="param-name"><![CDATA[...multi-line value...]]></param>' }}
|
||||
{%- endset %}
|
||||
|
||||
{{- '<|im_start|>system\n' }}
|
||||
{%- if messages[0].role == 'system' %}
|
||||
{%- if '<tool_def_sep>' in messages[0].content %}
|
||||
{{- messages[0].content.replace('<tool_def_sep>', tool_definitions) }}
|
||||
{%- else %}
|
||||
{{- messages[0].content + '\n\n' + tool_definitions }}
|
||||
{%- endif %}
|
||||
{%- else %}
|
||||
{{- tool_definitions.lstrip() }}
|
||||
{%- endif %}
|
||||
{{- '<|im_end|>\n' }}
|
||||
{%- else %}
|
||||
{%- if messages[0].role == 'system' %}
|
||||
{{- '<|im_start|>system\n' + messages[0].content + '<|im_end|>\n' }}
|
||||
{%- endif %}
|
||||
{%- endif %}
|
||||
{%- set ns = namespace(multi_step_tool=true, last_query_index=messages|length - 1) %}
|
||||
{%- for message in messages[::-1] %}
|
||||
{%- set index = (messages|length - 1) - loop.index0 %}
|
||||
{%- if ns.multi_step_tool and message.role == "user" and message.content is string and not(message.content.startswith('<tool_response>') and message.content.endswith('</tool_response>')) %}
|
||||
{%- set ns.multi_step_tool = false %}
|
||||
{%- set ns.last_query_index = index %}
|
||||
{%- endif %}
|
||||
{%- endfor %}
|
||||
{%- for message in messages %}
|
||||
{%- if message.content is string %}
|
||||
{%- set content = message.content %}
|
||||
{%- else %}
|
||||
{%- set content = '' %}
|
||||
{%- endif %}
|
||||
{%- if (message.role == "user") or (message.role == "system" and not loop.first) %}
|
||||
{{- '<|im_start|>' + message.role + '\n' + content + '<|im_end|>' + '\n' }}
|
||||
{%- elif message.role == "assistant" %}
|
||||
{%- set reasoning_content = '' %}
|
||||
{%- if message.reasoning_content is string %}
|
||||
{%- set reasoning_content = message.reasoning_content %}
|
||||
{%- else %}
|
||||
{%- if '</think>' in content %}
|
||||
{%- set reasoning_content = content.split('</think>')[0].rstrip('\n').split('<think>')[-1].lstrip('\n') %}
|
||||
{%- set content = content.split('</think>')[-1].lstrip('\n') %}
|
||||
{%- endif %}
|
||||
{%- endif %}
|
||||
|
||||
{%- if message.tool_calls %}
|
||||
{%- set content_parts = content.split('<tool_sep>') %}
|
||||
{%- set processed_content = content_parts[0] %}
|
||||
{%- set tool_calls_count = message.tool_calls|length %}
|
||||
{%- set tool_sep_count = content_parts|length - 1 %}
|
||||
{%- set min_count = [tool_calls_count, tool_sep_count]|min %}
|
||||
|
||||
{%- for i in range(1, content_parts|length) %}
|
||||
{%- set tool_index = i - 1 %}
|
||||
{%- if tool_index < tool_calls_count %}
|
||||
{%- set tool_call = message.tool_calls[tool_index] %}
|
||||
{%- if tool_call.function %}
|
||||
{%- set tool_call = tool_call.function %}
|
||||
{%- endif %}
|
||||
{%- set single_tool_xml %}
|
||||
{{- '<function name="' ~ tool_call.name ~ '">' }}
|
||||
{%- if tool_call.arguments %}
|
||||
{%- set args_dict = tool_call.arguments %}
|
||||
{%- for param_name, param_value in args_dict.items() %}
|
||||
{{- '<param name="' ~ param_name ~ '">' }}
|
||||
{%- if param_value is string and ('<' in param_value or '&' in param_value or '\n' in param_value) %}
|
||||
{{- '<![CDATA[' + param_value + ']]>' }}
|
||||
{%- else %}
|
||||
{{- param_value }}
|
||||
{%- endif %}
|
||||
{{- '</param>' }}
|
||||
{%- endfor %}
|
||||
{%- endif %}
|
||||
{{- '</function>' }}
|
||||
{%- endset %}
|
||||
{%- set processed_content = processed_content + single_tool_xml + content_parts[i] %}
|
||||
{%- else %}
|
||||
{%- set processed_content = processed_content + content_parts[i] %}
|
||||
{%- endif %}
|
||||
{%- endfor %}
|
||||
|
||||
{%- if tool_calls_count > tool_sep_count %}
|
||||
{%- for remaining_index in range(tool_sep_count, tool_calls_count) %}
|
||||
{%- set tool_call = message.tool_calls[remaining_index] %}
|
||||
{%- if tool_call.function %}
|
||||
{%- set tool_call = tool_call.function %}
|
||||
{%- endif %}
|
||||
{%- set remaining_tool_xml %}
|
||||
{{- '<function name="' ~ tool_call.name ~ '">' }}
|
||||
{%- if tool_call.arguments %}
|
||||
{%- set args_dict = tool_call.arguments %}
|
||||
{%- for param_name, param_value in args_dict.items() %}
|
||||
{{- '<param name="' ~ param_name ~ '">' }}
|
||||
{%- if param_value is string and ('<' in param_value or '&' in param_value or '\n' in param_value) %}
|
||||
{{- '<![CDATA[' + param_value + ']]>' }}
|
||||
{%- else %}
|
||||
{{- param_value }}
|
||||
{%- endif %}
|
||||
{{- '</param>' }}
|
||||
{%- endfor %}
|
||||
{%- endif %}
|
||||
{{- '</function>' }}
|
||||
{%- endset %}
|
||||
{%- set processed_content = processed_content + remaining_tool_xml %}
|
||||
{%- endfor %}
|
||||
{%- endif %}
|
||||
|
||||
{%- set content = processed_content %}
|
||||
{%- endif %}
|
||||
|
||||
{%- if loop.index0 > ns.last_query_index %}
|
||||
{%- if reasoning_content %}
|
||||
{{- '<|im_start|>' + message.role + '\n<think>\n' + reasoning_content.strip('\n') + '\n</think>\n\n' + content.lstrip('\n') }}
|
||||
{%- else %}
|
||||
{{- '<|im_start|>' + message.role + '\n' + content }}
|
||||
{%- endif %}
|
||||
{%- else %}
|
||||
{{- '<|im_start|>' + message.role + '\n' + content }}
|
||||
{%- endif %}
|
||||
|
||||
{%- if message.tool_calls and not has_tool_sep %}
|
||||
{%- for tool_call in message.tool_calls %}
|
||||
{%- if (loop.first and content) or (not loop.first) %}
|
||||
{{- '\n' }}
|
||||
{%- endif %}
|
||||
{%- if tool_call.function %}
|
||||
{%- set tool_call = tool_call.function %}
|
||||
{%- endif %}
|
||||
{{- '<function name="' ~ tool_call.name ~ '">' }}
|
||||
{%- if tool_call.arguments %}
|
||||
{%- set args_dict = tool_call.arguments %}
|
||||
{%- for param_name, param_value in args_dict.items() %}
|
||||
{{- '<param name="' ~ param_name ~ '">' }}
|
||||
{%- if param_value is string and ('<' in param_value or '&' in param_value or '\n' in param_value) %}
|
||||
{{- '<![CDATA[' + param_value + ']]>' }}
|
||||
{%- else %}
|
||||
{{- param_value }}
|
||||
{%- endif %}
|
||||
{{- '</param>' }}
|
||||
{%- endfor %}
|
||||
{%- endif %}
|
||||
{{- '</function>' }}
|
||||
{%- endfor %}
|
||||
{%- endif %}
|
||||
{{- '<|im_end|>\n' }}
|
||||
{%- elif message.role == "tool" %}
|
||||
{%- if loop.first or (messages[loop.index0 - 1].role != "tool") %}
|
||||
{{- '<|im_start|>user' }}
|
||||
{%- endif %}
|
||||
{{- '\n<tool_response>\n' }}
|
||||
{%- if message.content is string %}
|
||||
{{- content }}
|
||||
{%- else %}
|
||||
{{- message.content | tojson(ensure_ascii=False) }}
|
||||
{%- endif %}
|
||||
{{- '\n</tool_response>' }}
|
||||
{%- if loop.last or (messages[loop.index0 + 1].role != "tool") %}
|
||||
{{- '<|im_end|>\n' }}
|
||||
{%- endif %}
|
||||
{%- endif %}
|
||||
{%- endfor %}
|
||||
{%- if add_generation_prompt %}
|
||||
{{- '<|im_start|>assistant\n' }}
|
||||
{%- if enable_thinking is defined %}
|
||||
{%- if enable_thinking is false %}
|
||||
{{- '<think>\n\n</think>\n\n' }}
|
||||
{%- elif enable_thinking is true %}
|
||||
{{- '<think>\n' }}
|
||||
{%- endif %}
|
||||
{%- endif %}
|
||||
{%- endif %}
|
||||
@@ -25,6 +25,7 @@ add_library(llama
|
||||
llama-kv-cache.cpp
|
||||
llama-kv-cache-iswa.cpp
|
||||
llama-kv-cache-dsa.cpp
|
||||
llama-kv-cache-dsv4.cpp
|
||||
llama-memory.cpp
|
||||
llama-memory-hybrid.cpp
|
||||
llama-memory-hybrid-iswa.cpp
|
||||
|
||||
@@ -77,6 +77,7 @@ static const std::map<llm_arch, const char *> LLM_ARCH_NAMES = {
|
||||
{ LLM_ARCH_DEEPSEEK2, "deepseek2" },
|
||||
{ LLM_ARCH_DEEPSEEK2OCR, "deepseek2-ocr" },
|
||||
{ LLM_ARCH_DEEPSEEK32, "deepseek32" },
|
||||
{ LLM_ARCH_DEEPSEEK4, "deepseek4" },
|
||||
{ LLM_ARCH_CHATGLM, "chatglm" },
|
||||
{ LLM_ARCH_GLM4, "glm4" },
|
||||
{ LLM_ARCH_GLM4_MOE, "glm4moe" },
|
||||
@@ -250,9 +251,19 @@ static const std::map<llm_kv, const char *> LLM_KV_NAMES = {
|
||||
{ LLM_KV_ATTENTION_INDEXER_HEAD_COUNT, "%s.attention.indexer.head_count" },
|
||||
{ LLM_KV_ATTENTION_INDEXER_KEY_LENGTH, "%s.attention.indexer.key_length" },
|
||||
{ LLM_KV_ATTENTION_INDEXER_TOP_K, "%s.attention.indexer.top_k" },
|
||||
{ LLM_KV_ATTENTION_OUTPUT_GROUP_COUNT, "%s.attention.output_group_count" },
|
||||
{ LLM_KV_ATTENTION_OUTPUT_LORA_RANK, "%s.attention.output_lora_rank" },
|
||||
{ LLM_KV_ATTENTION_COMPRESS_ROPE_FREQ_BASE, "%s.attention.compress_rope_freq_base" },
|
||||
{ LLM_KV_ATTENTION_COMPRESS_RATIOS, "%s.attention.compress_ratios" },
|
||||
{ LLM_KV_ATTENTION_SHARED_KV_LAYERS, "%s.attention.shared_kv_layers" },
|
||||
{ LLM_KV_ATTENTION_RECURRENT_LAYERS, "%s.attention.recurrent_layers" },
|
||||
|
||||
{ LLM_KV_HYPER_CONNECTION_COUNT, "%s.hyper_connection.count" },
|
||||
{ LLM_KV_HYPER_CONNECTION_SINKHORN_ITERATIONS, "%s.hyper_connection.sinkhorn_iterations" },
|
||||
{ LLM_KV_HYPER_CONNECTION_EPSILON, "%s.hyper_connection.epsilon" },
|
||||
|
||||
{ LLM_KV_HASH_LAYER_COUNT, "%s.hash_layer_count" },
|
||||
|
||||
{ LLM_KV_ROPE_DIMENSION_COUNT, "%s.rope.dimension_count" },
|
||||
{ LLM_KV_ROPE_DIMENSION_COUNT_SWA, "%s.rope.dimension_count_swa" },
|
||||
{ LLM_KV_ROPE_DIMENSION_SECTIONS, "%s.rope.dimension_sections" },
|
||||
@@ -440,6 +451,23 @@ static const std::map<llm_tensor, const char *> LLM_TENSOR_NAMES = {
|
||||
{ LLM_TENSOR_ATTN_Q_B, "blk.%d.attn_q_b" },
|
||||
{ LLM_TENSOR_ATTN_KV_A_MQA, "blk.%d.attn_kv_a_mqa" },
|
||||
{ LLM_TENSOR_ATTN_KV_B, "blk.%d.attn_kv_b" },
|
||||
{ LLM_TENSOR_ATTN_KV, "blk.%d.attn_kv" },
|
||||
{ LLM_TENSOR_ATTN_KV_NORM, "blk.%d.attn_kv_a_norm" },
|
||||
{ LLM_TENSOR_ATTN_OUT_A, "blk.%d.attn_output_a" },
|
||||
{ LLM_TENSOR_ATTN_OUT_B, "blk.%d.attn_output_b" },
|
||||
{ LLM_TENSOR_HC_HEAD_FN, "output_hc_fn" },
|
||||
{ LLM_TENSOR_HC_HEAD_BASE, "output_hc_base" },
|
||||
{ LLM_TENSOR_HC_HEAD_SCALE, "output_hc_scale" },
|
||||
{ LLM_TENSOR_HC_ATTN_FN, "blk.%d.hc_attn_fn" },
|
||||
{ LLM_TENSOR_HC_ATTN_BASE, "blk.%d.hc_attn_base" },
|
||||
{ LLM_TENSOR_HC_ATTN_SCALE, "blk.%d.hc_attn_scale" },
|
||||
{ LLM_TENSOR_HC_FFN_FN, "blk.%d.hc_ffn_fn" },
|
||||
{ LLM_TENSOR_HC_FFN_BASE, "blk.%d.hc_ffn_base" },
|
||||
{ LLM_TENSOR_HC_FFN_SCALE, "blk.%d.hc_ffn_scale" },
|
||||
{ LLM_TENSOR_ATTN_COMPRESSOR_WKV, "blk.%d.attn_compressor_kv" },
|
||||
{ LLM_TENSOR_ATTN_COMPRESSOR_WGATE, "blk.%d.attn_compressor_gate" },
|
||||
{ LLM_TENSOR_ATTN_COMPRESSOR_APE, "blk.%d.attn_compressor_ape" },
|
||||
{ LLM_TENSOR_ATTN_COMPRESSOR_NORM, "blk.%d.attn_compressor_norm" },
|
||||
{ LLM_TENSOR_PER_LAYER_TOKEN_EMBD, "per_layer_token_embd" },
|
||||
{ LLM_TENSOR_PER_LAYER_MODEL_PROJ, "per_layer_model_proj" },
|
||||
{ LLM_TENSOR_PER_LAYER_PROJ_NORM, "per_layer_proj_norm" },
|
||||
@@ -566,6 +594,11 @@ static const std::map<llm_tensor, const char *> LLM_TENSOR_NAMES = {
|
||||
{ LLM_TENSOR_INDEXER_PROJ, "blk.%d.indexer.proj" },
|
||||
{ LLM_TENSOR_INDEXER_ATTN_K, "blk.%d.indexer.attn_k" },
|
||||
{ LLM_TENSOR_INDEXER_ATTN_Q_B, "blk.%d.indexer.attn_q_b" },
|
||||
{ LLM_TENSOR_INDEXER_COMPRESSOR_WKV, "blk.%d.indexer_compressor_kv" },
|
||||
{ LLM_TENSOR_INDEXER_COMPRESSOR_WGATE, "blk.%d.indexer_compressor_gate" },
|
||||
{ LLM_TENSOR_INDEXER_COMPRESSOR_APE, "blk.%d.indexer_compressor_ape" },
|
||||
{ LLM_TENSOR_INDEXER_COMPRESSOR_NORM, "blk.%d.indexer_compressor_norm" },
|
||||
{ LLM_TENSOR_FFN_GATE_TID2EID, "blk.%d.ffn_gate_tid2eid" },
|
||||
{ LLM_TENSOR_MASKED_EMBD_CENTROIDS, "masked_embd_centroids" },
|
||||
{ LLM_TENSOR_MASKED_EMBD_ORDERING, "masked_embd_ordering" },
|
||||
{ LLM_TENSOR_FC, "fc" },
|
||||
@@ -616,6 +649,23 @@ static const std::map<llm_tensor, llm_tensor_info> LLM_TENSOR_INFOS = {
|
||||
{LLM_TENSOR_ATTN_Q_B, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL_MAT}},
|
||||
{LLM_TENSOR_ATTN_KV_A_MQA, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL_MAT}},
|
||||
{LLM_TENSOR_ATTN_KV_B, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL_MAT}},
|
||||
{LLM_TENSOR_ATTN_KV, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL_MAT}},
|
||||
{LLM_TENSOR_ATTN_KV_NORM, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL}},
|
||||
{LLM_TENSOR_ATTN_OUT_A, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL_MAT}},
|
||||
{LLM_TENSOR_ATTN_OUT_B, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL_MAT}},
|
||||
{LLM_TENSOR_HC_HEAD_FN, {LLM_TENSOR_LAYER_OUTPUT, GGML_OP_MUL_MAT}},
|
||||
{LLM_TENSOR_HC_HEAD_BASE, {LLM_TENSOR_LAYER_OUTPUT, GGML_OP_ADD}},
|
||||
{LLM_TENSOR_HC_HEAD_SCALE, {LLM_TENSOR_LAYER_OUTPUT, GGML_OP_MUL}},
|
||||
{LLM_TENSOR_HC_ATTN_FN, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL_MAT}},
|
||||
{LLM_TENSOR_HC_ATTN_BASE, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_ADD}},
|
||||
{LLM_TENSOR_HC_ATTN_SCALE, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL}},
|
||||
{LLM_TENSOR_HC_FFN_FN, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL_MAT}},
|
||||
{LLM_TENSOR_HC_FFN_BASE, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_ADD}},
|
||||
{LLM_TENSOR_HC_FFN_SCALE, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL}},
|
||||
{LLM_TENSOR_ATTN_COMPRESSOR_WKV, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL_MAT}},
|
||||
{LLM_TENSOR_ATTN_COMPRESSOR_WGATE, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL_MAT}},
|
||||
{LLM_TENSOR_ATTN_COMPRESSOR_APE, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_ADD}},
|
||||
{LLM_TENSOR_ATTN_COMPRESSOR_NORM, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL}},
|
||||
{LLM_TENSOR_ATTN_K_B, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL_MAT}},
|
||||
{LLM_TENSOR_ATTN_V_B, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL_MAT}},
|
||||
{LLM_TENSOR_ATTN_SINKS, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_SCALE}},
|
||||
@@ -779,6 +829,11 @@ static const std::map<llm_tensor, llm_tensor_info> LLM_TENSOR_INFOS = {
|
||||
{LLM_TENSOR_INDEXER_PROJ, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL_MAT}},
|
||||
{LLM_TENSOR_INDEXER_ATTN_K, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL_MAT}},
|
||||
{LLM_TENSOR_INDEXER_ATTN_Q_B, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL_MAT}},
|
||||
{LLM_TENSOR_INDEXER_COMPRESSOR_WKV, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL_MAT}},
|
||||
{LLM_TENSOR_INDEXER_COMPRESSOR_WGATE, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL_MAT}},
|
||||
{LLM_TENSOR_INDEXER_COMPRESSOR_APE, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_ADD}},
|
||||
{LLM_TENSOR_INDEXER_COMPRESSOR_NORM, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL}},
|
||||
{LLM_TENSOR_FFN_GATE_TID2EID, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_GET_ROWS}},
|
||||
{LLM_TENSOR_NEXTN_PROJ_PRE, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL_MAT}},
|
||||
{LLM_TENSOR_NEXTN_PROJ_POST, {LLM_TENSOR_LAYER_OUTPUT, GGML_OP_MUL_MAT}},
|
||||
// NextN/MTP tensors are stored per-block (blk.%d.nextn.*) even though only the
|
||||
@@ -933,6 +988,7 @@ bool llm_arch_supports_sm_tensor(const llm_arch & arch) {
|
||||
case LLM_ARCH_OLMOE:
|
||||
case LLM_ARCH_DEEPSEEK2:
|
||||
case LLM_ARCH_DEEPSEEK32:
|
||||
case LLM_ARCH_DEEPSEEK4:
|
||||
case LLM_ARCH_GLM_DSA:
|
||||
case LLM_ARCH_BITNET:
|
||||
case LLM_ARCH_T5:
|
||||
|
||||
@@ -82,6 +82,7 @@ enum llm_arch {
|
||||
LLM_ARCH_DEEPSEEK2,
|
||||
LLM_ARCH_DEEPSEEK2OCR,
|
||||
LLM_ARCH_DEEPSEEK32,
|
||||
LLM_ARCH_DEEPSEEK4,
|
||||
LLM_ARCH_CHATGLM,
|
||||
LLM_ARCH_GLM4,
|
||||
LLM_ARCH_GLM4_MOE,
|
||||
@@ -255,9 +256,19 @@ enum llm_kv {
|
||||
LLM_KV_ATTENTION_INDEXER_HEAD_COUNT,
|
||||
LLM_KV_ATTENTION_INDEXER_KEY_LENGTH,
|
||||
LLM_KV_ATTENTION_INDEXER_TOP_K,
|
||||
LLM_KV_ATTENTION_OUTPUT_GROUP_COUNT,
|
||||
LLM_KV_ATTENTION_OUTPUT_LORA_RANK,
|
||||
LLM_KV_ATTENTION_COMPRESS_ROPE_FREQ_BASE,
|
||||
LLM_KV_ATTENTION_COMPRESS_RATIOS,
|
||||
LLM_KV_ATTENTION_SHARED_KV_LAYERS,
|
||||
LLM_KV_ATTENTION_RECURRENT_LAYERS,
|
||||
|
||||
LLM_KV_HYPER_CONNECTION_COUNT,
|
||||
LLM_KV_HYPER_CONNECTION_SINKHORN_ITERATIONS,
|
||||
LLM_KV_HYPER_CONNECTION_EPSILON,
|
||||
|
||||
LLM_KV_HASH_LAYER_COUNT,
|
||||
|
||||
LLM_KV_ROPE_DIMENSION_COUNT,
|
||||
LLM_KV_ROPE_DIMENSION_COUNT_SWA,
|
||||
LLM_KV_ROPE_DIMENSION_SECTIONS,
|
||||
@@ -501,10 +512,27 @@ enum llm_tensor {
|
||||
LLM_TENSOR_ATTN_Q_B,
|
||||
LLM_TENSOR_ATTN_KV_A_MQA,
|
||||
LLM_TENSOR_ATTN_KV_B,
|
||||
LLM_TENSOR_ATTN_KV,
|
||||
LLM_TENSOR_ATTN_KV_NORM,
|
||||
LLM_TENSOR_ATTN_OUT_A,
|
||||
LLM_TENSOR_ATTN_OUT_B,
|
||||
LLM_TENSOR_ATTN_K_B,
|
||||
LLM_TENSOR_ATTN_V_B,
|
||||
LLM_TENSOR_ATTN_Q_A_NORM,
|
||||
LLM_TENSOR_ATTN_KV_A_NORM,
|
||||
LLM_TENSOR_HC_HEAD_FN,
|
||||
LLM_TENSOR_HC_HEAD_BASE,
|
||||
LLM_TENSOR_HC_HEAD_SCALE,
|
||||
LLM_TENSOR_HC_ATTN_FN,
|
||||
LLM_TENSOR_HC_ATTN_BASE,
|
||||
LLM_TENSOR_HC_ATTN_SCALE,
|
||||
LLM_TENSOR_HC_FFN_FN,
|
||||
LLM_TENSOR_HC_FFN_BASE,
|
||||
LLM_TENSOR_HC_FFN_SCALE,
|
||||
LLM_TENSOR_ATTN_COMPRESSOR_WKV,
|
||||
LLM_TENSOR_ATTN_COMPRESSOR_WGATE,
|
||||
LLM_TENSOR_ATTN_COMPRESSOR_APE,
|
||||
LLM_TENSOR_ATTN_COMPRESSOR_NORM,
|
||||
LLM_TENSOR_ATTN_SUB_NORM,
|
||||
LLM_TENSOR_FFN_SUB_NORM,
|
||||
LLM_TENSOR_DEC_ATTN_NORM,
|
||||
@@ -566,6 +594,11 @@ enum llm_tensor {
|
||||
LLM_TENSOR_INDEXER_PROJ,
|
||||
LLM_TENSOR_INDEXER_ATTN_K,
|
||||
LLM_TENSOR_INDEXER_ATTN_Q_B,
|
||||
LLM_TENSOR_INDEXER_COMPRESSOR_WKV,
|
||||
LLM_TENSOR_INDEXER_COMPRESSOR_WGATE,
|
||||
LLM_TENSOR_INDEXER_COMPRESSOR_APE,
|
||||
LLM_TENSOR_INDEXER_COMPRESSOR_NORM,
|
||||
LLM_TENSOR_FFN_GATE_TID2EID,
|
||||
LLM_TENSOR_NEXTN_PROJ_PRE,
|
||||
LLM_TENSOR_NEXTN_PROJ_POST,
|
||||
LLM_TENSOR_NEXTN_EH_PROJ,
|
||||
|
||||
@@ -2321,7 +2321,11 @@ void llama_context::output_reorder() {
|
||||
//
|
||||
|
||||
uint32_t llama_context::graph_max_nodes(uint32_t n_tokens) const {
|
||||
if (model.arch == LLM_ARCH_QWEN3NEXT || model.arch == LLM_ARCH_KIMI_LINEAR || model.arch == LLM_ARCH_QWEN35 || model.arch == LLM_ARCH_QWEN35MOE) {
|
||||
if (model.arch == LLM_ARCH_QWEN3NEXT ||
|
||||
model.arch == LLM_ARCH_KIMI_LINEAR ||
|
||||
model.arch == LLM_ARCH_QWEN35 ||
|
||||
model.arch == LLM_ARCH_QWEN35MOE ||
|
||||
model.arch == LLM_ARCH_DEEPSEEK4) {
|
||||
return std::max<uint32_t>(n_tokens * 40, 32u * model.n_tensors());
|
||||
}
|
||||
uint32_t res = std::max<uint32_t>(1024u, 8u*model.n_tensors());
|
||||
|
||||
+352
-23
@@ -8,6 +8,7 @@
|
||||
#include "llama-kv-cache.h"
|
||||
#include "llama-kv-cache-iswa.h"
|
||||
#include "llama-kv-cache-dsa.h"
|
||||
#include "llama-kv-cache-dsv4.h"
|
||||
#include "llama-memory-hybrid.h"
|
||||
#include "llama-memory-hybrid-iswa.h"
|
||||
#include "llama-memory-recurrent.h"
|
||||
@@ -17,6 +18,7 @@
|
||||
#include <cstring>
|
||||
#include <numeric>
|
||||
#include <sstream>
|
||||
#include <string>
|
||||
#include <unordered_set>
|
||||
|
||||
// dedup helpers
|
||||
@@ -568,7 +570,9 @@ void llm_graph_input_attn_kv_iswa::set_input(const llama_ubatch * ubatch) {
|
||||
// base tensors may not be allocated if there are no non-SWA attention layers
|
||||
if (self_k_idxs && self_k_idxs->buffer) {
|
||||
mctx->get_base()->set_input_k_idxs(self_k_idxs, ubatch);
|
||||
mctx->get_base()->set_input_v_idxs(self_v_idxs, ubatch);
|
||||
if (self_v_idxs) {
|
||||
mctx->get_base()->set_input_v_idxs(self_v_idxs, ubatch);
|
||||
}
|
||||
}
|
||||
|
||||
// the kq mask guards on its own buffer: shared cells leave idxs unbacked while the mask stays live
|
||||
@@ -579,7 +583,9 @@ void llm_graph_input_attn_kv_iswa::set_input(const llama_ubatch * ubatch) {
|
||||
// swa tensors may not be allocated if there are no SWA attention layers
|
||||
if (self_k_idxs_swa && self_k_idxs_swa->buffer) {
|
||||
mctx->get_swa()->set_input_k_idxs(self_k_idxs_swa, ubatch);
|
||||
mctx->get_swa()->set_input_v_idxs(self_v_idxs_swa, ubatch);
|
||||
if (self_v_idxs_swa) {
|
||||
mctx->get_swa()->set_input_v_idxs(self_v_idxs_swa, ubatch);
|
||||
}
|
||||
}
|
||||
|
||||
if (self_kq_mask_swa && self_kq_mask_swa->buffer) {
|
||||
@@ -633,6 +639,283 @@ bool llm_graph_input_attn_kv_iswa::can_reuse(const llm_graph_params & params) {
|
||||
return res;
|
||||
}
|
||||
|
||||
static void dsv4_set_i64(ggml_tensor * dst, const std::vector<int64_t> & src) {
|
||||
if (!dst || !dst->buffer) {
|
||||
return;
|
||||
}
|
||||
|
||||
GGML_ASSERT(dst->ne[0] == (int64_t) src.size());
|
||||
ggml_backend_tensor_set(dst, src.data(), 0, src.size()*ggml_element_size(dst));
|
||||
}
|
||||
|
||||
static void dsv4_set_i32(ggml_tensor * dst, const std::vector<int32_t> & src) {
|
||||
if (!dst || !dst->buffer) {
|
||||
return;
|
||||
}
|
||||
|
||||
GGML_ASSERT(dst->ne[0] == (int64_t) src.size());
|
||||
ggml_backend_tensor_set(dst, src.data(), 0, src.size()*ggml_element_size(dst));
|
||||
}
|
||||
|
||||
static void dsv4_set_kq_mask(
|
||||
ggml_tensor * dst,
|
||||
const llama_kv_cache_dsv4_context::comp_plan & plan,
|
||||
uint32_t n_tokens,
|
||||
int64_t n_stream) {
|
||||
if (!dst || !dst->buffer) {
|
||||
return;
|
||||
}
|
||||
|
||||
GGML_ASSERT(dst->type == GGML_TYPE_F32);
|
||||
GGML_ASSERT(n_stream > 0);
|
||||
GGML_ASSERT(n_tokens%n_stream == 0);
|
||||
GGML_ASSERT(dst->ne[0] == plan.n_kv);
|
||||
GGML_ASSERT(dst->ne[1] == (int64_t) n_tokens/n_stream);
|
||||
GGML_ASSERT(dst->ne[2] == 1);
|
||||
GGML_ASSERT(dst->ne[3] == n_stream);
|
||||
GGML_ASSERT((int64_t) plan.n_visible.size() == (int64_t) n_tokens);
|
||||
GGML_ASSERT(ggml_backend_buffer_is_host(dst->buffer));
|
||||
|
||||
float * data = (float *) dst->data;
|
||||
|
||||
for (int64_t i = 0; i < (int64_t) n_tokens; ++i) {
|
||||
const int32_t n_visible = plan.n_visible[i];
|
||||
|
||||
for (int64_t j = 0; j < dst->ne[0]; ++j) {
|
||||
data[i*dst->ne[0] + j] = j < n_visible ? 0.0f : -INFINITY;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
static ggml_tensor * dsv4_build_raw_kq_mask(
|
||||
ggml_context * ctx,
|
||||
const llama_kv_cache_dsv4_raw_context * mctx,
|
||||
const llama_ubatch & ubatch,
|
||||
const llama_cparams & cparams,
|
||||
int64_t n_stream) {
|
||||
const auto n_kv = mctx->get_n_kv();
|
||||
const auto n_tokens = ubatch.n_tokens;
|
||||
|
||||
GGML_ASSERT(n_stream > 0);
|
||||
GGML_ASSERT(n_tokens%n_stream == 0);
|
||||
|
||||
const bool use_fattn = cparams.flash_attn && (!cparams.kv_unified || n_stream == 1);
|
||||
const auto type = use_fattn ? GGML_TYPE_F16 : GGML_TYPE_F32;
|
||||
|
||||
ggml_tensor * res = ggml_new_tensor_4d(ctx, type, n_kv, n_tokens/n_stream, 1, n_stream);
|
||||
ggml_set_input(res);
|
||||
ggml_set_name(res, "attn_inp_kq_mask");
|
||||
|
||||
return res;
|
||||
}
|
||||
|
||||
static bool dsv4_can_reuse_raw_kq_mask(
|
||||
ggml_tensor * kq_mask,
|
||||
const llama_kv_cache_dsv4_raw_context * mctx,
|
||||
const llama_ubatch & ubatch,
|
||||
int64_t n_stream) {
|
||||
const auto n_kv = mctx->get_n_kv();
|
||||
const auto n_tokens = ubatch.n_tokens;
|
||||
|
||||
GGML_ASSERT(n_stream > 0);
|
||||
|
||||
bool res = true;
|
||||
|
||||
res &= (kq_mask->ne[0] == n_kv);
|
||||
res &= (kq_mask->ne[1] == n_tokens/n_stream);
|
||||
res &= (kq_mask->ne[2] == 1);
|
||||
res &= (kq_mask->ne[3] == n_stream);
|
||||
|
||||
return res;
|
||||
}
|
||||
|
||||
static std::string dsv4_plan_positions(const std::vector<int32_t> & values) {
|
||||
std::ostringstream ss;
|
||||
ss << "[";
|
||||
for (size_t i = 0; i < values.size(); ++i) {
|
||||
if (i > 0) {
|
||||
ss << ", ";
|
||||
}
|
||||
ss << values[i];
|
||||
}
|
||||
ss << "]";
|
||||
return ss.str();
|
||||
}
|
||||
|
||||
static bool dsv4_compress_debug() {
|
||||
static const bool debug = []() {
|
||||
const char * env = getenv("LLAMA_DSV4_COMPRESS_DEBUG");
|
||||
return env && atoi(env) > 0;
|
||||
}();
|
||||
|
||||
return debug;
|
||||
}
|
||||
|
||||
static void dsv4_set_comp_inputs(
|
||||
const llm_graph_input_dsv4::comp_input & inp,
|
||||
const llama_kv_cache_dsv4_context::comp_plan & plan,
|
||||
const char * name,
|
||||
bool debug,
|
||||
uint32_t n_tokens,
|
||||
int64_t n_stream) {
|
||||
dsv4_set_i32(inp.state_pos, plan.state_pos);
|
||||
dsv4_set_i32(inp.state_persist_src_idxs, plan.state_persist_src_idxs);
|
||||
dsv4_set_i32(inp.state_persist_dst_idxs, plan.state_persist_dst_idxs);
|
||||
dsv4_set_i32(inp.state_read_idxs, plan.state_read_idxs);
|
||||
dsv4_set_i64(inp.state_write_idxs, plan.state_write_idxs);
|
||||
dsv4_set_i32(inp.state_write_pos, plan.state_write_pos);
|
||||
dsv4_set_kq_mask(inp.kq_mask, plan, n_tokens, n_stream);
|
||||
|
||||
if (debug || dsv4_compress_debug()) {
|
||||
LLAMA_LOG_INFO("%s: %s n_tokens=%u, n_stream=%d, state_persist_dst=%s, state_write_pos=%s\n",
|
||||
__func__, name, n_tokens, (int) n_stream,
|
||||
dsv4_plan_positions(plan.state_persist_dst_idxs).c_str(),
|
||||
dsv4_plan_positions(plan.state_write_pos).c_str());
|
||||
}
|
||||
}
|
||||
|
||||
static bool dsv4_can_reuse_tensor_1d(ggml_tensor * t, int64_t ne0) {
|
||||
return (t == nullptr && ne0 == 0) || (t != nullptr && t->ne[0] == ne0);
|
||||
}
|
||||
|
||||
static bool dsv4_can_reuse_kq_mask(
|
||||
ggml_tensor * t,
|
||||
const llama_kv_cache_dsv4_context::comp_plan & plan,
|
||||
uint32_t n_tokens,
|
||||
int64_t n_stream) {
|
||||
if (plan.n_kv == 0) {
|
||||
return t == nullptr;
|
||||
}
|
||||
|
||||
GGML_ASSERT(n_stream > 0);
|
||||
|
||||
return t != nullptr &&
|
||||
t->ne[0] == plan.n_kv &&
|
||||
t->ne[1] == (int64_t) n_tokens/n_stream &&
|
||||
t->ne[2] == 1 &&
|
||||
t->ne[3] == n_stream;
|
||||
}
|
||||
|
||||
static bool dsv4_can_reuse_comp_input(
|
||||
const llm_graph_input_dsv4::comp_input & inp,
|
||||
const llama_kv_cache_dsv4_context::comp_plan & plan,
|
||||
uint32_t n_tokens,
|
||||
int64_t n_stream) {
|
||||
bool res = true;
|
||||
res &= dsv4_can_reuse_tensor_1d(inp.state_pos, plan.state_pos.size());
|
||||
res &= dsv4_can_reuse_tensor_1d(inp.state_persist_src_idxs, plan.state_persist_src_idxs.size());
|
||||
res &= dsv4_can_reuse_tensor_1d(inp.state_persist_dst_idxs, plan.state_persist_dst_idxs.size());
|
||||
res &= dsv4_can_reuse_tensor_1d(inp.state_read_idxs, plan.state_read_idxs.size());
|
||||
res &= dsv4_can_reuse_tensor_1d(inp.state_write_idxs, plan.state_write_idxs.size());
|
||||
res &= dsv4_can_reuse_tensor_1d(inp.state_write_pos, plan.state_write_pos.size());
|
||||
res &= dsv4_can_reuse_kq_mask(inp.kq_mask, plan, n_tokens, n_stream);
|
||||
|
||||
return res;
|
||||
}
|
||||
|
||||
static ggml_tensor * dsv4_build_input_1d(
|
||||
ggml_context * ctx,
|
||||
ggml_type type,
|
||||
int64_t ne0,
|
||||
const std::string & name) {
|
||||
if (ne0 == 0) {
|
||||
return nullptr;
|
||||
}
|
||||
|
||||
ggml_tensor * res = ggml_new_tensor_1d(ctx, type, ne0);
|
||||
ggml_set_input(res);
|
||||
ggml_set_name(res, name.c_str());
|
||||
|
||||
return res;
|
||||
}
|
||||
|
||||
static void dsv4_build_comp_inputs(
|
||||
ggml_context * ctx,
|
||||
llm_graph_input_dsv4::comp_input & inp,
|
||||
const llama_kv_cache_dsv4_context::comp_plan & plan,
|
||||
const char * name,
|
||||
int64_t n_stream) {
|
||||
inp.state_pos = dsv4_build_input_1d(ctx, GGML_TYPE_I32, plan.state_pos.size(), std::string("dsv4_") + name + "_state_pos");
|
||||
inp.state_persist_src_idxs = dsv4_build_input_1d(ctx, GGML_TYPE_I32, plan.state_persist_src_idxs.size(), std::string("dsv4_") + name + "_state_persist_src_idxs");
|
||||
inp.state_persist_dst_idxs = dsv4_build_input_1d(ctx, GGML_TYPE_I32, plan.state_persist_dst_idxs.size(), std::string("dsv4_") + name + "_state_persist_dst_idxs");
|
||||
inp.state_read_idxs = dsv4_build_input_1d(ctx, GGML_TYPE_I32, plan.state_read_idxs.size(), std::string("dsv4_") + name + "_state_read_idxs");
|
||||
inp.state_write_idxs = dsv4_build_input_1d(ctx, GGML_TYPE_I64, plan.state_write_idxs.size(), std::string("dsv4_") + name + "_state_write_idxs");
|
||||
inp.state_write_pos = dsv4_build_input_1d(ctx, GGML_TYPE_I32, plan.state_write_pos.size(), std::string("dsv4_") + name + "_state_write_pos");
|
||||
|
||||
if (plan.n_kv > 0) {
|
||||
const int64_t n_tokens = (int64_t) plan.n_visible.size();
|
||||
|
||||
GGML_ASSERT(n_stream > 0);
|
||||
GGML_ASSERT(n_tokens%n_stream == 0);
|
||||
|
||||
inp.kq_mask = ggml_new_tensor_4d(ctx, GGML_TYPE_F32, plan.n_kv, n_tokens/n_stream, 1, n_stream);
|
||||
ggml_set_input(inp.kq_mask);
|
||||
ggml_set_name(inp.kq_mask, (std::string("dsv4_") + name + "_kq_mask").c_str());
|
||||
}
|
||||
}
|
||||
|
||||
void llm_graph_input_dsv4_raw::set_input(const llama_ubatch * ubatch) {
|
||||
if (self_k_idxs && self_k_idxs->buffer) {
|
||||
mctx->set_input_k_idxs(self_k_idxs);
|
||||
}
|
||||
|
||||
if (self_kq_mask && self_kq_mask->buffer) {
|
||||
mctx->set_input_kq_mask(self_kq_mask, ubatch, cparams.causal_attn);
|
||||
}
|
||||
|
||||
if (self_k_rot) {
|
||||
mctx->set_input_k_rot(self_k_rot);
|
||||
}
|
||||
}
|
||||
|
||||
void llm_graph_input_dsv4::set_input(const llama_ubatch * ubatch) {
|
||||
const auto & plan_csa = mctx->get_csa_plan(*ubatch);
|
||||
const auto & plan_hca = mctx->get_hca_plan(*ubatch);
|
||||
const auto & plan_lid = mctx->get_lid_plan(*ubatch);
|
||||
const int64_t n_stream = plan_csa.n_stream;
|
||||
|
||||
inp_raw->mctx = mctx->get_raw();
|
||||
inp_raw->set_input(ubatch);
|
||||
|
||||
dsv4_set_comp_inputs(inp_csa, plan_csa, "csa", debug > 0, ubatch->n_tokens, n_stream);
|
||||
dsv4_set_comp_inputs(inp_hca, plan_hca, "hca", debug > 0, ubatch->n_tokens, n_stream);
|
||||
dsv4_set_comp_inputs(inp_lid, plan_lid, "lid", debug > 0, ubatch->n_tokens, n_stream);
|
||||
|
||||
if (inp_lid.k_rot && inp_lid.k_rot->buffer) {
|
||||
mctx->get_lid()->set_input_k_rot(inp_lid.k_rot);
|
||||
}
|
||||
}
|
||||
|
||||
bool llm_graph_input_dsv4::can_reuse(const llm_graph_params & params) {
|
||||
const auto * mctx = static_cast<const llama_kv_cache_dsv4_context *>(params.mctx);
|
||||
|
||||
this->mctx = mctx;
|
||||
inp_raw->mctx = mctx->get_raw();
|
||||
|
||||
bool res = true;
|
||||
|
||||
const auto & plan_csa = mctx->get_csa_plan(params.ubatch);
|
||||
const auto & plan_hca = mctx->get_hca_plan(params.ubatch);
|
||||
const auto & plan_lid = mctx->get_lid_plan(params.ubatch);
|
||||
const int64_t n_stream = plan_csa.n_stream;
|
||||
|
||||
const auto * raw_ctx = mctx->get_raw();
|
||||
inp_raw->mctx = raw_ctx;
|
||||
|
||||
if (inp_raw->self_k_idxs && inp_raw->self_k_idxs->buffer) {
|
||||
res &= inp_raw->self_k_idxs->ne[0] == raw_ctx->get_n_write();
|
||||
}
|
||||
if (inp_raw->self_kq_mask && inp_raw->self_kq_mask->buffer) {
|
||||
res &= dsv4_can_reuse_raw_kq_mask(inp_raw->self_kq_mask, raw_ctx, params.ubatch, n_stream);
|
||||
}
|
||||
|
||||
res &= dsv4_can_reuse_comp_input(inp_csa, plan_csa, params.ubatch.n_tokens, n_stream);
|
||||
res &= dsv4_can_reuse_comp_input(inp_hca, plan_hca, params.ubatch.n_tokens, n_stream);
|
||||
res &= dsv4_can_reuse_comp_input(inp_lid, plan_lid, params.ubatch.n_tokens, n_stream);
|
||||
|
||||
return res;
|
||||
}
|
||||
|
||||
void llm_graph_input_attn_cross::set_input(const llama_ubatch * ubatch) {
|
||||
GGML_ASSERT(cross_kq_mask);
|
||||
|
||||
@@ -1351,20 +1634,24 @@ ggml_tensor * llm_graph_context::build_ffn(
|
||||
switch (type_op) {
|
||||
case LLM_FFN_SILU:
|
||||
if (gate && type_gate == LLM_FFN_PAR) {
|
||||
// Step35: HF clamps gate (after SiLU) and up before multiplication
|
||||
if (arch == LLM_ARCH_STEP35 && il >= 0) {
|
||||
if (il >= 0) {
|
||||
const float limit = hparams.swiglu_clamp_shexp[il];
|
||||
constexpr float eps = 1e-6f;
|
||||
if (limit > eps) {
|
||||
ggml_tensor * gate_act = ggml_silu(ctx0, cur);
|
||||
cb(gate_act, "ffn_silu", il);
|
||||
gate_act = ggml_clamp(ctx0, gate_act, -INFINITY, limit);
|
||||
cb(gate_act, "ffn_silu_clamped", il);
|
||||
|
||||
tmp = ggml_clamp(ctx0, tmp, -limit, limit);
|
||||
cb(tmp, "ffn_up_clamped", il);
|
||||
|
||||
cur = ggml_mul(ctx0, gate_act, tmp);
|
||||
if (arch == LLM_ARCH_DEEPSEEK4) {
|
||||
cur = ggml_clamp(ctx0, cur, -INFINITY, limit);
|
||||
cb(cur, "ffn_gate_clamped", il);
|
||||
cur = ggml_swiglu_split(ctx0, cur, tmp);
|
||||
} else {
|
||||
ggml_tensor * gate_act = ggml_silu(ctx0, cur);
|
||||
cb(gate_act, "ffn_silu", il);
|
||||
gate_act = ggml_clamp(ctx0, gate_act, -INFINITY, limit);
|
||||
cb(gate_act, "ffn_silu_clamped", il);
|
||||
cur = ggml_mul(ctx0, gate_act, tmp);
|
||||
}
|
||||
cb(cur, "ffn_swiglu_limited", il);
|
||||
type_gate = LLM_FFN_SEQ;
|
||||
break;
|
||||
@@ -1474,7 +1761,8 @@ ggml_tensor * llm_graph_context::build_moe_ffn(
|
||||
ggml_tensor * gate_up_exps,
|
||||
ggml_tensor * up_exps_s,
|
||||
ggml_tensor * gate_exps_s,
|
||||
ggml_tensor * down_exps_s) const {
|
||||
ggml_tensor * down_exps_s,
|
||||
ggml_tensor * selected_experts_in) const {
|
||||
return build_moe_ffn(
|
||||
cur,
|
||||
gate_inp, /* gate_inp_b */ nullptr,
|
||||
@@ -1494,7 +1782,8 @@ ggml_tensor * llm_graph_context::build_moe_ffn(
|
||||
/* gate_up_exps_b */ nullptr,
|
||||
up_exps_s,
|
||||
gate_exps_s,
|
||||
down_exps_s
|
||||
down_exps_s,
|
||||
selected_experts_in
|
||||
);
|
||||
}
|
||||
|
||||
@@ -1521,7 +1810,8 @@ ggml_tensor * llm_graph_context::build_moe_ffn(
|
||||
ggml_tensor * gate_up_exps_b,
|
||||
ggml_tensor * up_exps_s,
|
||||
ggml_tensor * gate_exps_s,
|
||||
ggml_tensor * down_exps_s) const {
|
||||
ggml_tensor * down_exps_s,
|
||||
ggml_tensor * selected_experts_in) const {
|
||||
const int64_t n_embd = cur->ne[0];
|
||||
const int64_t n_tokens = cur->ne[1];
|
||||
const bool weight_before_ffn = arch == LLM_ARCH_LLAMA4; // for llama4, we apply the sigmoid-ed weights before the FFN
|
||||
@@ -1530,6 +1820,9 @@ ggml_tensor * llm_graph_context::build_moe_ffn(
|
||||
|
||||
if (probs_in == nullptr) {
|
||||
logits = build_lora_mm(gate_inp, cur); // [n_expert, n_tokens]
|
||||
if (gating_op == LLAMA_EXPERT_GATING_FUNC_TYPE_SQRT_SOFTPLUS) {
|
||||
ggml_mul_mat_set_prec(logits, GGML_PREC_F32);
|
||||
}
|
||||
cb(logits, "ffn_moe_logits", il);
|
||||
} else {
|
||||
logits = probs_in;
|
||||
@@ -1554,6 +1847,10 @@ ggml_tensor * llm_graph_context::build_moe_ffn(
|
||||
{
|
||||
probs = logits; // [n_expert, n_tokens]
|
||||
} break;
|
||||
case LLAMA_EXPERT_GATING_FUNC_TYPE_SQRT_SOFTPLUS:
|
||||
{
|
||||
probs = ggml_sqrt(ctx0, ggml_softplus(ctx0, logits)); // [n_expert, n_tokens]
|
||||
} break;
|
||||
default:
|
||||
GGML_ABORT("fatal error");
|
||||
}
|
||||
@@ -1604,8 +1901,11 @@ ggml_tensor * llm_graph_context::build_moe_ffn(
|
||||
}
|
||||
|
||||
// select experts
|
||||
ggml_tensor * selected_experts = ggml_argsort_top_k(ctx0, selection_probs, n_expert_used); // [n_expert_used, n_tokens]
|
||||
cb(selected_experts->src[0], "ffn_moe_argsort", il);
|
||||
ggml_tensor * selected_experts = selected_experts_in;
|
||||
if (selected_experts == nullptr) {
|
||||
selected_experts = ggml_argsort_top_k(ctx0, selection_probs, n_expert_used); // [n_expert_used, n_tokens]
|
||||
cb(selected_experts->src[0], "ffn_moe_argsort", il);
|
||||
}
|
||||
cb(selected_experts, "ffn_moe_topk", il);
|
||||
|
||||
if (arch == LLM_ARCH_GROVEMOE && n_expert != hparams.n_expert) {
|
||||
@@ -1718,20 +2018,24 @@ ggml_tensor * llm_graph_context::build_moe_ffn(
|
||||
switch (type_op) {
|
||||
case LLM_FFN_SILU:
|
||||
if (gate_exps) {
|
||||
// Step35: per-layer clamp for routed experts
|
||||
if (arch == LLM_ARCH_STEP35 && il >= 0) {
|
||||
if (il >= 0) {
|
||||
const float limit = hparams.swiglu_clamp_exp[il];
|
||||
constexpr float eps = 1e-6f;
|
||||
if (limit > eps) {
|
||||
ggml_tensor * gate_act = ggml_silu(ctx0, cur);
|
||||
cb(gate_act, "ffn_moe_silu", il);
|
||||
gate_act = ggml_clamp(ctx0, gate_act, -INFINITY, limit);
|
||||
cb(gate_act, "ffn_moe_silu_clamped", il);
|
||||
|
||||
up = ggml_clamp(ctx0, up, -limit, limit);
|
||||
cb(up, "ffn_moe_up_clamped", il);
|
||||
|
||||
cur = ggml_mul(ctx0, gate_act, up);
|
||||
if (arch == LLM_ARCH_DEEPSEEK4) {
|
||||
cur = ggml_clamp(ctx0, cur, -INFINITY, limit);
|
||||
cb(cur, "ffn_moe_gate_clamped", il);
|
||||
cur = ggml_swiglu_split(ctx0, cur, up);
|
||||
} else {
|
||||
ggml_tensor * gate_act = ggml_silu(ctx0, cur);
|
||||
cb(gate_act, "ffn_moe_silu", il);
|
||||
gate_act = ggml_clamp(ctx0, gate_act, -INFINITY, limit);
|
||||
cb(gate_act, "ffn_moe_silu_clamped", il);
|
||||
cur = ggml_mul(ctx0, gate_act, up);
|
||||
}
|
||||
cb(cur, "ffn_moe_swiglu_limited", il);
|
||||
break;
|
||||
}
|
||||
@@ -2760,6 +3064,31 @@ llm_graph_input_attn_kv_iswa * llm_graph_context::build_attn_inp_kv_iswa() const
|
||||
return (llm_graph_input_attn_kv_iswa *) res->add_input(std::move(inp));
|
||||
}
|
||||
|
||||
llm_graph_input_dsv4 * llm_graph_context::build_inp_dsv4() const {
|
||||
const auto * mctx_cur = static_cast<const llama_kv_cache_dsv4_context *>(mctx);
|
||||
const auto * raw_ctx = mctx_cur->get_raw();
|
||||
|
||||
auto inp_raw = std::make_unique<llm_graph_input_dsv4_raw>(cparams, raw_ctx);
|
||||
|
||||
const int64_t n_stream = mctx_cur->get_csa_plan(ubatch).n_stream;
|
||||
|
||||
GGML_ASSERT(hparams.swa_type != LLAMA_SWA_TYPE_NONE && "DSV4 expects SWA raw cache");
|
||||
|
||||
inp_raw->self_k_idxs = raw_ctx->build_input_k_idxs(ctx0, ubatch);
|
||||
inp_raw->self_kq_mask = dsv4_build_raw_kq_mask(ctx0, raw_ctx, ubatch, cparams, n_stream);
|
||||
inp_raw->self_kq_mask_cnv = inp_raw->self_kq_mask;
|
||||
|
||||
inp_raw->self_k_rot = raw_ctx->build_input_k_rot(ctx0);
|
||||
auto inp = std::make_unique<llm_graph_input_dsv4>(cparams, std::move(inp_raw), mctx_cur);
|
||||
|
||||
dsv4_build_comp_inputs(ctx0, inp->inp_csa, mctx_cur->get_csa_plan(ubatch), "csa", n_stream);
|
||||
dsv4_build_comp_inputs(ctx0, inp->inp_hca, mctx_cur->get_hca_plan(ubatch), "hca", n_stream);
|
||||
dsv4_build_comp_inputs(ctx0, inp->inp_lid, mctx_cur->get_lid_plan(ubatch), "lid", n_stream);
|
||||
inp->inp_lid.k_rot = mctx_cur->get_lid()->build_input_k_rot(ctx0);
|
||||
|
||||
return (llm_graph_input_dsv4 *) res->add_input(std::move(inp));
|
||||
}
|
||||
|
||||
ggml_tensor * llm_graph_context::build_rs(
|
||||
ggml_tensor * s,
|
||||
ggml_tensor * state_copy_main,
|
||||
|
||||
+81
-2
@@ -23,6 +23,8 @@ struct llama_memory_context_i;
|
||||
|
||||
class llama_kv_cache_context;
|
||||
class llama_kv_cache_dsa_context;
|
||||
class llama_kv_cache_dsv4_raw_context;
|
||||
class llama_kv_cache_dsv4_context;
|
||||
class llama_kv_cache_iswa_context;
|
||||
class llama_memory_recurrent_context;
|
||||
class llama_memory_hybrid_context;
|
||||
@@ -459,6 +461,79 @@ public:
|
||||
const llama_kv_cache_iswa_context * mctx;
|
||||
};
|
||||
|
||||
// DSV4 raw graph inputs are SWA-only, but their mask may be stream-shaped
|
||||
// so raw K can be concatenated with DSV4 compressed K in one attention op.
|
||||
class llm_graph_input_dsv4_raw {
|
||||
public:
|
||||
llm_graph_input_dsv4_raw(
|
||||
const llama_cparams & cparams,
|
||||
const llama_kv_cache_dsv4_raw_context * mctx) :
|
||||
cparams(cparams),
|
||||
mctx(mctx) {
|
||||
}
|
||||
|
||||
void set_input(const llama_ubatch * ubatch);
|
||||
|
||||
ggml_tensor * get_k_idxs() const { return self_k_idxs; }
|
||||
ggml_tensor * get_kq_mask() const { return self_kq_mask_cnv; }
|
||||
|
||||
ggml_tensor * self_k_idxs = nullptr; // I64 [n_batch]
|
||||
|
||||
ggml_tensor * self_kq_mask = nullptr; // F32/F16 [n_kv, n_batch/n_stream, 1, n_stream]
|
||||
ggml_tensor * self_kq_mask_cnv = nullptr; // [n_kv, n_batch/n_stream, 1, n_stream]
|
||||
|
||||
ggml_tensor * self_k_rot = nullptr;
|
||||
|
||||
const llama_cparams cparams;
|
||||
|
||||
const llama_kv_cache_dsv4_raw_context * mctx;
|
||||
};
|
||||
|
||||
class llm_graph_input_dsv4 : public llm_graph_input_i {
|
||||
public:
|
||||
struct comp_input {
|
||||
ggml_tensor * state_pos = nullptr; // I32 [n_state]
|
||||
ggml_tensor * state_persist_src_idxs = nullptr; // I32 [n_state_persist]
|
||||
ggml_tensor * state_persist_dst_idxs = nullptr; // I32 [n_state_persist]
|
||||
ggml_tensor * state_read_idxs = nullptr; // I32 [ratio*n_state_write]
|
||||
ggml_tensor * state_write_idxs = nullptr; // I64 [n_state_write]
|
||||
ggml_tensor * state_write_pos = nullptr; // I32 [n_state_write]
|
||||
|
||||
ggml_tensor * kq_mask = nullptr; // F32 [n_kv, n_batch/n_stream, 1, n_stream]
|
||||
|
||||
ggml_tensor * k_rot = nullptr;
|
||||
};
|
||||
|
||||
llm_graph_input_dsv4(
|
||||
const llama_cparams & cparams,
|
||||
std::unique_ptr<llm_graph_input_dsv4_raw> inp_raw,
|
||||
const llama_kv_cache_dsv4_context * mctx) :
|
||||
inp_raw(std::move(inp_raw)),
|
||||
cparams(cparams),
|
||||
mctx(mctx) {
|
||||
}
|
||||
~llm_graph_input_dsv4() = default;
|
||||
|
||||
void set_input(const llama_ubatch * ubatch) override;
|
||||
|
||||
bool can_reuse(const llm_graph_params & params) override;
|
||||
|
||||
llm_graph_input_dsv4_raw * get_raw() const { return inp_raw.get(); }
|
||||
const comp_input & get_csa() const { return inp_csa; }
|
||||
const comp_input & get_hca() const { return inp_hca; }
|
||||
const comp_input & get_lid() const { return inp_lid; }
|
||||
|
||||
std::unique_ptr<llm_graph_input_dsv4_raw> inp_raw;
|
||||
|
||||
comp_input inp_csa;
|
||||
comp_input inp_hca;
|
||||
comp_input inp_lid;
|
||||
|
||||
const llama_cparams cparams;
|
||||
|
||||
const llama_kv_cache_dsv4_context * mctx;
|
||||
};
|
||||
|
||||
class llm_graph_input_attn_cross : public llm_graph_input_i {
|
||||
public:
|
||||
llm_graph_input_attn_cross(const llama_cross * cross) : cross(cross) {}
|
||||
@@ -920,7 +995,8 @@ struct llm_graph_context {
|
||||
ggml_tensor * gate_up_exps = nullptr,
|
||||
ggml_tensor * up_exps_s = nullptr,
|
||||
ggml_tensor * gate_exps_s = nullptr,
|
||||
ggml_tensor * down_exps_s = nullptr) const;
|
||||
ggml_tensor * down_exps_s = nullptr,
|
||||
ggml_tensor * selected_experts_in = nullptr) const;
|
||||
|
||||
ggml_tensor * build_moe_ffn(
|
||||
ggml_tensor * cur,
|
||||
@@ -945,7 +1021,8 @@ struct llm_graph_context {
|
||||
ggml_tensor * gate_up_exps_b = nullptr,
|
||||
ggml_tensor * up_exps_s = nullptr,
|
||||
ggml_tensor * gate_exps_s = nullptr,
|
||||
ggml_tensor * down_exps_s = nullptr) const;
|
||||
ggml_tensor * down_exps_s = nullptr,
|
||||
ggml_tensor * selected_experts_in = nullptr) const;
|
||||
|
||||
//
|
||||
// inputs
|
||||
@@ -1045,6 +1122,8 @@ struct llm_graph_context {
|
||||
|
||||
llm_graph_input_attn_kv_iswa * build_attn_inp_kv_iswa() const;
|
||||
|
||||
llm_graph_input_dsv4 * build_inp_dsv4() const;
|
||||
|
||||
// note: if k_cur or v_cur are not provided, they will not be stored in the memory
|
||||
ggml_tensor * build_attn(
|
||||
llm_graph_input_attn_kv_iswa * inp,
|
||||
|
||||
@@ -14,6 +14,7 @@ enum llama_expert_gating_func_type {
|
||||
LLAMA_EXPERT_GATING_FUNC_TYPE_SOFTMAX = 1,
|
||||
LLAMA_EXPERT_GATING_FUNC_TYPE_SIGMOID = 2,
|
||||
LLAMA_EXPERT_GATING_FUNC_TYPE_SOFTMAX_WEIGHT = 3, // applied to the router weights instead of the logits
|
||||
LLAMA_EXPERT_GATING_FUNC_TYPE_SQRT_SOFTPLUS = 4,
|
||||
};
|
||||
|
||||
enum llama_swa_type {
|
||||
@@ -226,6 +227,16 @@ struct llama_hparams {
|
||||
uint32_t indexer_head_size = 0;
|
||||
uint32_t indexer_top_k = 0;
|
||||
|
||||
// DeepSeek-V4
|
||||
uint32_t dsv4_o_group_count = 0;
|
||||
uint32_t dsv4_o_lora_rank = 0;
|
||||
uint32_t dsv4_hc_mult = 0;
|
||||
uint32_t dsv4_hc_sinkhorn_iters = 0;
|
||||
uint32_t dsv4_hash_layer_count = 0;
|
||||
float dsv4_compress_rope_base = 0.0f;
|
||||
float dsv4_hc_eps = 0.0f;
|
||||
std::array<uint32_t, LLAMA_MAX_LAYERS> dsv4_compress_ratios;
|
||||
|
||||
// qwen3vl deepstack
|
||||
// When parsed from GGUF, this implies the first N layers consume the first
|
||||
// N deepstack embeddings. Use deepstack_mapping_arr if you need a more
|
||||
|
||||
File diff suppressed because it is too large
Load Diff
@@ -0,0 +1,362 @@
|
||||
#pragma once
|
||||
|
||||
#include "llama-kv-cache.h"
|
||||
#include "llama-kv-cache-iswa.h"
|
||||
|
||||
#include <map>
|
||||
#include <memory>
|
||||
#include <unordered_map>
|
||||
#include <vector>
|
||||
|
||||
class llama_dsv4_comp_state {
|
||||
public:
|
||||
llama_dsv4_comp_state(
|
||||
const llama_model & model,
|
||||
bool offload,
|
||||
bool unified,
|
||||
uint32_t n_seq_max,
|
||||
uint32_t ratio,
|
||||
uint32_t state_size,
|
||||
uint32_t n_embd_state,
|
||||
const char * name,
|
||||
const llama_memory_i::layer_filter_cb & filter);
|
||||
|
||||
void clear(bool data);
|
||||
|
||||
uint32_t get_ratio() const;
|
||||
uint32_t get_state_size() const;
|
||||
uint32_t get_n_stream() const;
|
||||
|
||||
std::map<ggml_backend_buffer_type_t, size_t> memory_breakdown() const;
|
||||
|
||||
void state_write(llama_io_write_i & io, llama_seq_id seq_id, llama_state_seq_flags flags) const;
|
||||
void state_read (llama_io_read_i & io, llama_seq_id seq_id, llama_state_seq_flags flags);
|
||||
|
||||
ggml_tensor * get_kv (ggml_context * ctx, int32_t il) const;
|
||||
ggml_tensor * get_score(ggml_context * ctx, int32_t il) const;
|
||||
|
||||
ggml_tensor * cpy_kv (ggml_context * ctx, ggml_tensor * cur, ggml_tensor * idxs, int32_t il) const;
|
||||
ggml_tensor * cpy_score(ggml_context * ctx, ggml_tensor * cur, ggml_tensor * idxs, int32_t il) const;
|
||||
|
||||
private:
|
||||
struct layer {
|
||||
uint32_t il;
|
||||
|
||||
ggml_tensor * kv;
|
||||
ggml_tensor * score;
|
||||
};
|
||||
|
||||
const uint32_t ratio;
|
||||
const uint32_t state_size;
|
||||
const uint32_t n_embd_state;
|
||||
const uint32_t n_stream;
|
||||
|
||||
std::vector<std::pair<ggml_context_ptr, ggml_backend_buffer_ptr>> ctxs_bufs;
|
||||
|
||||
std::vector<layer> layers;
|
||||
|
||||
std::unordered_map<int32_t, int32_t> map_layer_ids;
|
||||
|
||||
size_t total_size() const;
|
||||
};
|
||||
|
||||
//
|
||||
// llama_kv_cache_dsv4
|
||||
//
|
||||
|
||||
// DSV4 uses a normal raw/SWA token cache plus compressed K-only block caches.
|
||||
// The compressed caches are storage only; DSV4-specific visibility and block
|
||||
// planning are handled by llama_kv_cache_dsv4_context / llm_graph_input_dsv4.
|
||||
|
||||
class llama_kv_cache_dsv4 : public llama_memory_i {
|
||||
public:
|
||||
llama_kv_cache_dsv4(
|
||||
const llama_model & model,
|
||||
ggml_type type_k,
|
||||
ggml_type type_v,
|
||||
bool v_trans,
|
||||
bool offload,
|
||||
bool swa_full,
|
||||
bool unified,
|
||||
uint32_t kv_size,
|
||||
uint32_t n_seq_max,
|
||||
uint32_t n_ubatch,
|
||||
uint32_t n_pad,
|
||||
const layer_filter_cb & filter,
|
||||
const layer_reuse_cb & reuse);
|
||||
|
||||
~llama_kv_cache_dsv4() = default;
|
||||
|
||||
//
|
||||
// llama_memory_i
|
||||
//
|
||||
|
||||
llama_memory_context_ptr init_batch(
|
||||
llama_batch_allocr & balloc,
|
||||
uint32_t n_ubatch,
|
||||
bool embd_all) override;
|
||||
|
||||
llama_memory_context_ptr init_full() override;
|
||||
|
||||
llama_memory_context_ptr init_update(llama_context * lctx, bool optimize) override;
|
||||
|
||||
bool get_can_shift() const override;
|
||||
|
||||
void clear(bool data) override;
|
||||
|
||||
bool seq_rm (llama_seq_id seq_id, llama_pos p0, llama_pos p1) override;
|
||||
void seq_cp (llama_seq_id seq_id_src, llama_seq_id seq_id_dst, llama_pos p0, llama_pos p1) override;
|
||||
void seq_keep(llama_seq_id seq_id) override;
|
||||
void seq_add (llama_seq_id seq_id, llama_pos p0, llama_pos p1, llama_pos shift) override;
|
||||
void seq_div (llama_seq_id seq_id, llama_pos p0, llama_pos p1, int d) override;
|
||||
|
||||
llama_pos seq_pos_min(llama_seq_id seq_id) const override;
|
||||
llama_pos seq_pos_max(llama_seq_id seq_id) const override;
|
||||
|
||||
std::map<ggml_backend_buffer_type_t, size_t> memory_breakdown() const override;
|
||||
|
||||
void state_write(llama_io_write_i & io, llama_seq_id seq_id = -1, llama_state_seq_flags flags = 0) const override;
|
||||
void state_read (llama_io_read_i & io, llama_seq_id seq_id = -1, llama_state_seq_flags flags = 0) override;
|
||||
|
||||
//
|
||||
// llama_kv_cache_dsv4 specific API
|
||||
//
|
||||
|
||||
llama_kv_cache_iswa * get_raw() const;
|
||||
llama_kv_cache * get_csa() const;
|
||||
llama_kv_cache * get_hca() const;
|
||||
llama_kv_cache * get_lid() const;
|
||||
llama_dsv4_comp_state * get_csa_state() const;
|
||||
llama_dsv4_comp_state * get_hca_state() const;
|
||||
llama_dsv4_comp_state * get_lid_state() const;
|
||||
|
||||
private:
|
||||
llama_hparams hparams_raw;
|
||||
llama_hparams hparams_csa;
|
||||
llama_hparams hparams_hca;
|
||||
llama_hparams hparams_lid;
|
||||
|
||||
const uint32_t n_seq_max;
|
||||
|
||||
std::unique_ptr<llama_kv_cache_iswa> kv_raw;
|
||||
std::unique_ptr<llama_kv_cache> kv_csa;
|
||||
std::unique_ptr<llama_kv_cache> kv_hca;
|
||||
std::unique_ptr<llama_kv_cache> kv_lid;
|
||||
std::unique_ptr<llama_dsv4_comp_state> csa_state;
|
||||
std::unique_ptr<llama_dsv4_comp_state> hca_state;
|
||||
std::unique_ptr<llama_dsv4_comp_state> lid_state;
|
||||
|
||||
void clear_compressed(bool data);
|
||||
};
|
||||
|
||||
// DSV4 raw attention only uses the SWA half of kv_raw. The base half is kept
|
||||
// for generic ISWA bookkeeping, but it has no DSV4 layers to expose here.
|
||||
class llama_kv_cache_dsv4_raw_context : public llama_memory_context_i {
|
||||
public:
|
||||
using slot_info_vec_t = llama_kv_cache::slot_info_vec_t;
|
||||
|
||||
llama_kv_cache_dsv4_raw_context(llama_kv_cache_iswa * kv);
|
||||
|
||||
llama_kv_cache_dsv4_raw_context(
|
||||
llama_kv_cache_iswa * kv,
|
||||
llama_context * lctx,
|
||||
bool optimize);
|
||||
|
||||
llama_kv_cache_dsv4_raw_context(
|
||||
llama_kv_cache_iswa * kv,
|
||||
slot_info_vec_t sinfos_base_write,
|
||||
slot_info_vec_t sinfos_swa_write,
|
||||
slot_info_vec_t sinfos_swa_read,
|
||||
std::vector<llama_ubatch> ubatches,
|
||||
std::vector<llama_ubatch> ubatches_write);
|
||||
|
||||
bool next() override;
|
||||
bool apply() override;
|
||||
|
||||
llama_memory_status get_status() const override;
|
||||
const llama_ubatch & get_ubatch() const override;
|
||||
|
||||
uint32_t get_n_kv() const;
|
||||
uint32_t get_n_write() const;
|
||||
|
||||
ggml_tensor * get_k(ggml_context * ctx, int32_t il) const;
|
||||
ggml_tensor * cpy_k(ggml_context * ctx, ggml_tensor * k_cur, ggml_tensor * k_idxs, int32_t il) const;
|
||||
|
||||
ggml_tensor * build_input_k_idxs(ggml_context * ctx, const llama_ubatch & ubatch) const;
|
||||
ggml_tensor * build_input_k_rot(ggml_context * ctx) const;
|
||||
|
||||
void set_input_k_idxs(ggml_tensor * dst) const;
|
||||
void set_input_kq_mask(ggml_tensor * dst, const llama_ubatch * ubatch, bool causal_attn) const;
|
||||
void set_input_k_rot(ggml_tensor * dst) const;
|
||||
|
||||
private:
|
||||
size_t i_next = 0;
|
||||
|
||||
llama_kv_cache * kv_swa = nullptr;
|
||||
|
||||
slot_info_vec_t sinfos_write;
|
||||
slot_info_vec_t sinfos_read;
|
||||
std::vector<llama_ubatch> ubatches;
|
||||
std::vector<llama_ubatch> ubatches_write;
|
||||
|
||||
const llama_memory_context_ptr ctx_base_mem;
|
||||
const llama_memory_context_ptr ctx_swa_mem;
|
||||
|
||||
uint32_t n_kv = 0;
|
||||
|
||||
const llama_memory_status status;
|
||||
};
|
||||
|
||||
// DSV4 compressed KV rows are graph outputs, not normal token KV writes.
|
||||
// Keep a small context that exposes K tensors without generic apply() semantics.
|
||||
class llama_kv_cache_dsv4_comp_context {
|
||||
public:
|
||||
using slot_info_vec_t = llama_kv_cache::slot_info_vec_t;
|
||||
|
||||
llama_kv_cache_dsv4_comp_context(llama_kv_cache * kv);
|
||||
|
||||
llama_kv_cache_dsv4_comp_context(
|
||||
llama_kv_cache * kv,
|
||||
slot_info_vec_t sinfos,
|
||||
std::vector<llama_ubatch> ubatches);
|
||||
|
||||
bool next();
|
||||
|
||||
uint32_t get_n_kv() const;
|
||||
|
||||
ggml_tensor * get_k(ggml_context * ctx, int32_t il) const;
|
||||
ggml_tensor * cpy_k(ggml_context * ctx, ggml_tensor * k_cur, ggml_tensor * k_idxs, int32_t il) const;
|
||||
|
||||
ggml_tensor * build_input_k_rot(ggml_context * ctx) const;
|
||||
void set_input_k_rot(ggml_tensor * dst) const;
|
||||
|
||||
private:
|
||||
llama_kv_cache * kv;
|
||||
|
||||
size_t i_cur = 0;
|
||||
slot_info_vec_t sinfos;
|
||||
std::vector<llama_ubatch> ubatches;
|
||||
|
||||
uint32_t n_kv;
|
||||
};
|
||||
|
||||
class llama_kv_cache_dsv4_context : public llama_memory_context_i {
|
||||
public:
|
||||
using slot_info_vec_t = llama_kv_cache::slot_info_vec_t;
|
||||
|
||||
struct comp_plan {
|
||||
// Per-ubatch recipe for updating compressor state, committing completed
|
||||
// compressed rows, and masking the compressed attention source.
|
||||
|
||||
// APE row ids, i.e. pos % ratio, for the compressor-state updates.
|
||||
std::vector<int32_t> state_pos;
|
||||
|
||||
// Current-ubatch source row ids and unique persistent-state
|
||||
// destination row ids for deterministic ring-state updates.
|
||||
std::vector<int32_t> state_persist_src_idxs;
|
||||
std::vector<int32_t> state_persist_dst_idxs;
|
||||
|
||||
// Flattened source row ids used for state-backed commits. Source rows
|
||||
// index the graph-local [persistent_state | current_ubatch_scratch]
|
||||
// tensor. For overlapped compression the first half is previous rows
|
||||
// and the second half is current rows; a final synthetic zero/-inf row
|
||||
// may be addressed for the first block's previous half.
|
||||
std::vector<int32_t> state_read_idxs;
|
||||
|
||||
// Final compressed-cache row ids written by state-backed commits.
|
||||
// A non-boundary CSA/LID decode step can target a masked scratch row.
|
||||
std::vector<int64_t> state_write_idxs;
|
||||
|
||||
// RoPE positions for state-backed commits.
|
||||
std::vector<int32_t> state_write_pos;
|
||||
|
||||
// Number of completed compressed rows visible for each query token.
|
||||
std::vector<int32_t> n_visible;
|
||||
|
||||
// Number of streams used by the attention graph for this ubatch.
|
||||
int64_t n_stream = 1;
|
||||
|
||||
// Graph-width for compressed rows. This can be larger than n_visible
|
||||
// so masked padding rows do not force a new graph at every CSA block.
|
||||
int64_t n_kv = 0;
|
||||
};
|
||||
|
||||
llama_kv_cache_dsv4_context(llama_memory_status status);
|
||||
|
||||
llama_kv_cache_dsv4_context(
|
||||
llama_kv_cache_dsv4 * kv);
|
||||
|
||||
llama_kv_cache_dsv4_context(
|
||||
llama_kv_cache_dsv4 * kv,
|
||||
llama_context * lctx,
|
||||
bool optimize);
|
||||
|
||||
llama_kv_cache_dsv4_context(
|
||||
llama_kv_cache_dsv4 * kv,
|
||||
slot_info_vec_t sinfos_raw_base_write,
|
||||
slot_info_vec_t sinfos_raw_swa_write,
|
||||
slot_info_vec_t sinfos_raw_swa_read,
|
||||
std::vector<llama_ubatch> ubatches,
|
||||
std::vector<llama_ubatch> ubatches_raw);
|
||||
|
||||
virtual ~llama_kv_cache_dsv4_context();
|
||||
|
||||
//
|
||||
// llama_memory_context_i
|
||||
//
|
||||
|
||||
bool next() override;
|
||||
bool apply() override;
|
||||
|
||||
llama_memory_status get_status() const override;
|
||||
const llama_ubatch & get_ubatch() const override;
|
||||
|
||||
//
|
||||
// llama_kv_cache_dsv4_context specific API
|
||||
//
|
||||
|
||||
const llama_kv_cache_dsv4_raw_context * get_raw() const;
|
||||
const llama_kv_cache_dsv4_comp_context * get_csa() const;
|
||||
const llama_kv_cache_dsv4_comp_context * get_hca() const;
|
||||
const llama_kv_cache_dsv4_comp_context * get_lid() const;
|
||||
const llama_dsv4_comp_state * get_csa_state() const;
|
||||
const llama_dsv4_comp_state * get_hca_state() const;
|
||||
const llama_dsv4_comp_state * get_lid_state() const;
|
||||
|
||||
const comp_plan & get_csa_plan() const;
|
||||
const comp_plan & get_hca_plan() const;
|
||||
const comp_plan & get_lid_plan() const;
|
||||
|
||||
const comp_plan & get_csa_plan(const llama_ubatch & ubatch) const;
|
||||
const comp_plan & get_hca_plan(const llama_ubatch & ubatch) const;
|
||||
const comp_plan & get_lid_plan(const llama_ubatch & ubatch) const;
|
||||
|
||||
private:
|
||||
size_t i_next = 0;
|
||||
|
||||
std::vector<llama_ubatch> ubatches;
|
||||
|
||||
std::vector<comp_plan> plans_csa;
|
||||
std::vector<comp_plan> plans_hca;
|
||||
std::vector<comp_plan> plans_lid;
|
||||
|
||||
const std::unique_ptr<llama_kv_cache_dsv4_raw_context> ctx_raw;
|
||||
const llama_memory_context_ptr ctx_csa_mem;
|
||||
const llama_memory_context_ptr ctx_hca_mem;
|
||||
const llama_memory_context_ptr ctx_lid_mem;
|
||||
|
||||
const std::unique_ptr<llama_kv_cache_dsv4_comp_context> ctx_csa;
|
||||
const std::unique_ptr<llama_kv_cache_dsv4_comp_context> ctx_hca;
|
||||
const std::unique_ptr<llama_kv_cache_dsv4_comp_context> ctx_lid;
|
||||
|
||||
const llama_dsv4_comp_state * csa_state = nullptr;
|
||||
const llama_dsv4_comp_state * hca_state = nullptr;
|
||||
const llama_dsv4_comp_state * lid_state = nullptr;
|
||||
|
||||
bool reserve_plans = false;
|
||||
mutable comp_plan reserve_plan_csa;
|
||||
mutable comp_plan reserve_plan_hca;
|
||||
mutable comp_plan reserve_plan_lid;
|
||||
|
||||
const llama_memory_status status;
|
||||
};
|
||||
@@ -26,7 +26,28 @@ llama_kv_cache_iswa::llama_kv_cache_iswa(
|
||||
llama_memory_t mem_other,
|
||||
const layer_filter_cb & filter,
|
||||
const layer_reuse_cb & reuse,
|
||||
const layer_share_cb & share) : hparams(model.hparams), unified(unified) {
|
||||
const layer_share_cb & share) :
|
||||
llama_kv_cache_iswa(model, model.hparams, type_k, type_v, v_trans, offload, swa_full, unified,
|
||||
kv_size, n_seq_max, n_ubatch, n_pad, mem_other, filter, reuse, share) {
|
||||
}
|
||||
|
||||
llama_kv_cache_iswa::llama_kv_cache_iswa(
|
||||
const llama_model & model,
|
||||
const llama_hparams & hparams,
|
||||
ggml_type type_k,
|
||||
ggml_type type_v,
|
||||
bool v_trans,
|
||||
bool offload,
|
||||
bool swa_full,
|
||||
bool unified,
|
||||
uint32_t kv_size,
|
||||
uint32_t n_seq_max,
|
||||
uint32_t n_ubatch,
|
||||
uint32_t n_pad,
|
||||
llama_memory_t mem_other,
|
||||
const layer_filter_cb & filter,
|
||||
const layer_reuse_cb & reuse,
|
||||
const layer_share_cb & share) : unified(unified) {
|
||||
|
||||
// chain filters
|
||||
const layer_filter_cb filter_base = [&](int32_t il) {
|
||||
|
||||
@@ -30,6 +30,24 @@ public:
|
||||
const layer_reuse_cb & reuse,
|
||||
const layer_share_cb & share);
|
||||
|
||||
llama_kv_cache_iswa(
|
||||
const llama_model & model,
|
||||
const llama_hparams & hparams,
|
||||
ggml_type type_k,
|
||||
ggml_type type_v,
|
||||
bool v_trans,
|
||||
bool offload,
|
||||
bool swa_full,
|
||||
bool unified,
|
||||
uint32_t kv_size,
|
||||
uint32_t n_seq_max,
|
||||
uint32_t n_ubatch,
|
||||
uint32_t n_pad,
|
||||
llama_memory_t mem_other,
|
||||
const layer_filter_cb & filter,
|
||||
const layer_reuse_cb & reuse,
|
||||
const layer_share_cb & share);
|
||||
|
||||
~llama_kv_cache_iswa() = default;
|
||||
|
||||
//
|
||||
@@ -73,8 +91,6 @@ public:
|
||||
llama_kv_cache * get_swa () const;
|
||||
|
||||
private:
|
||||
const llama_hparams & hparams;
|
||||
|
||||
const bool unified;
|
||||
|
||||
std::unique_ptr<llama_kv_cache> kv_base;
|
||||
|
||||
+26
-6
@@ -211,10 +211,12 @@ llama_kv_cache::llama_kv_cache(
|
||||
n_embd_head_k_all = -1;
|
||||
}
|
||||
|
||||
if (n_embd_head_v_all == 0) {
|
||||
n_embd_head_v_all = (int32_t) hparams.n_embd_head_v(il);
|
||||
} else if (n_embd_head_v_all > 0 && n_embd_head_v_all != (int32_t) hparams.n_embd_head_v(il)) {
|
||||
n_embd_head_v_all = -1;
|
||||
if (!is_mla) {
|
||||
if (n_embd_head_v_all == 0) {
|
||||
n_embd_head_v_all = (int32_t) hparams.n_embd_head_v(il);
|
||||
} else if (n_embd_head_v_all > 0 && n_embd_head_v_all != (int32_t) hparams.n_embd_head_v(il)) {
|
||||
n_embd_head_v_all = -1;
|
||||
}
|
||||
}
|
||||
|
||||
// [TAG_V_CACHE_VARIABLE]
|
||||
@@ -336,8 +338,9 @@ llama_kv_cache::llama_kv_cache(
|
||||
ggml_is_quantized(type_k) &&
|
||||
hparams.n_embd_head_k() % 64 == 0;
|
||||
|
||||
// always create Hadamard rotation tensors for DeepSeek V3.2 DSA lightning indexer
|
||||
if (model.arch == LLM_ARCH_DEEPSEEK32 && hparams.n_embd_head_k_full == hparams.indexer_head_size) {
|
||||
// always create Hadamard rotation tensors for DeepSeek lightning indexers
|
||||
if ((model.arch == LLM_ARCH_DEEPSEEK32 || model.arch == LLM_ARCH_DEEPSEEK4) &&
|
||||
hparams.n_embd_head_k_full == hparams.indexer_head_size) {
|
||||
attn_rot_k = true;
|
||||
}
|
||||
|
||||
@@ -1220,6 +1223,23 @@ ggml_type llama_kv_cache::type_v() const {
|
||||
return layers[0].v->type;
|
||||
}
|
||||
|
||||
std::vector<uint32_t> llama_kv_cache::get_layer_ids() const {
|
||||
std::vector<uint32_t> res;
|
||||
res.reserve(layers.size());
|
||||
|
||||
for (const auto & layer : layers) {
|
||||
res.push_back(layer.il);
|
||||
}
|
||||
|
||||
return res;
|
||||
}
|
||||
|
||||
ggml_tensor * llama_kv_cache::get_k_storage(int32_t il) const {
|
||||
const int32_t ikv = map_layer_ids.at(il);
|
||||
|
||||
return layers[ikv].k;
|
||||
}
|
||||
|
||||
uint32_t llama_kv_cache::get_n_kv(const slot_info & sinfo) const {
|
||||
uint32_t result = 0;
|
||||
|
||||
|
||||
@@ -161,6 +161,9 @@ public:
|
||||
ggml_type type_k() const;
|
||||
ggml_type type_v() const;
|
||||
|
||||
std::vector<uint32_t> get_layer_ids() const;
|
||||
ggml_tensor * get_k_storage(int32_t il) const;
|
||||
|
||||
//
|
||||
// graph_build API
|
||||
//
|
||||
|
||||
@@ -294,6 +294,8 @@ namespace GGUFMeta {
|
||||
}
|
||||
|
||||
template bool llama_model_loader::get_arr_n(enum llm_kv kid, uint32_t & result, bool required);
|
||||
template std::enable_if<std::is_integral<uint32_t>::value, bool>::type
|
||||
llama_model_loader::get_arr_n<uint32_t>(const std::string & key, uint32_t & result, bool required);
|
||||
|
||||
template<typename T>
|
||||
bool llama_model_loader::get_arr(const std::string & key, std::vector<T> & result, bool required) {
|
||||
@@ -395,6 +397,7 @@ namespace GGUFMeta {
|
||||
template bool llama_model_loader::get_arr<std::vector<std::string>>(enum llm_kv kid, std::vector<std::string> & result, bool required);
|
||||
template bool llama_model_loader::get_arr<std::array<int32_t, 512>>(enum llm_kv kid, std::array<int32_t, 512> & result, bool required);
|
||||
template bool llama_model_loader::get_arr<std::vector<int32_t>>(enum llm_kv kid, std::vector<int32_t> & result, bool required);
|
||||
template bool llama_model_loader::get_arr<std::array<uint32_t, LLAMA_MAX_LAYERS>>(enum llm_kv kid, std::array<uint32_t, LLAMA_MAX_LAYERS> & result, bool required);
|
||||
|
||||
template<typename T>
|
||||
bool llama_model_loader::get_key(const std::string & key, T & result, bool required) {
|
||||
|
||||
+28
-1
@@ -11,6 +11,7 @@
|
||||
#include "llama-kv-cache.h"
|
||||
#include "llama-kv-cache-iswa.h"
|
||||
#include "llama-kv-cache-dsa.h"
|
||||
#include "llama-kv-cache-dsv4.h"
|
||||
#include "llama-memory-hybrid.h"
|
||||
#include "llama-memory-hybrid-iswa.h"
|
||||
#include "llama-memory-recurrent.h"
|
||||
@@ -181,6 +182,8 @@ static llama_model * llama_model_mapping(llm_arch arch, const llama_model_params
|
||||
return new llama_model_deepseek2ocr(params);
|
||||
case LLM_ARCH_DEEPSEEK32:
|
||||
return new llama_model_deepseek32(params);
|
||||
case LLM_ARCH_DEEPSEEK4:
|
||||
return new llama_model_deepseek4(params);
|
||||
case LLM_ARCH_GLM_DSA:
|
||||
return new llama_model_glm_dsa(params);
|
||||
case LLM_ARCH_MISTRAL4:
|
||||
@@ -817,6 +820,7 @@ static const char * llama_expert_gating_func_name(llama_expert_gating_func_type
|
||||
switch (type) {
|
||||
case LLAMA_EXPERT_GATING_FUNC_TYPE_SOFTMAX: return "softmax";
|
||||
case LLAMA_EXPERT_GATING_FUNC_TYPE_SIGMOID: return "sigmoid";
|
||||
case LLAMA_EXPERT_GATING_FUNC_TYPE_SQRT_SOFTPLUS: return "sqrtsoftplus";
|
||||
default: return "unknown";
|
||||
}
|
||||
}
|
||||
@@ -2156,7 +2160,24 @@ llama_memory_i * llama_model::create_memory(const llama_memory_params & params,
|
||||
}
|
||||
}
|
||||
|
||||
if (hparams.swa_type != LLAMA_SWA_TYPE_NONE) {
|
||||
if (arch == LLM_ARCH_DEEPSEEK4) {
|
||||
GGML_ASSERT(hparams.swa_type != LLAMA_SWA_TYPE_NONE);
|
||||
|
||||
res = new llama_kv_cache_dsv4(
|
||||
*this,
|
||||
params.type_k,
|
||||
params.type_v,
|
||||
!cparams.flash_attn,
|
||||
cparams.offload_kqv,
|
||||
params.swa_full,
|
||||
cparams.kv_unified,
|
||||
cparams.n_ctx_seq,
|
||||
cparams.n_seq_max,
|
||||
cparams.n_ubatch,
|
||||
1,
|
||||
filter,
|
||||
reuse);
|
||||
} else if (hparams.swa_type != LLAMA_SWA_TYPE_NONE) {
|
||||
GGML_ASSERT(hparams.is_swa_any());
|
||||
|
||||
if (arch == LLM_ARCH_GEMMA4_ASSISTANT) {
|
||||
@@ -2328,6 +2349,11 @@ int32_t llama_model_n_head_kv(const llama_model * model) {
|
||||
}
|
||||
|
||||
int32_t llama_model_n_swa(const llama_model * model) {
|
||||
// dsv4 kv-cache has SWA but it cannot be used as a rollback because of
|
||||
// other compression ratios, so we return 0 here
|
||||
if (model->arch == LLM_ARCH_DEEPSEEK4) {
|
||||
return 0;
|
||||
}
|
||||
return model->hparams.n_swa;
|
||||
}
|
||||
|
||||
@@ -2409,6 +2435,7 @@ llama_rope_type llama_model_rope_type(const llama_model * model) {
|
||||
case LLM_ARCH_DEEPSEEK2:
|
||||
case LLM_ARCH_DEEPSEEK2OCR:
|
||||
case LLM_ARCH_DEEPSEEK32:
|
||||
case LLM_ARCH_DEEPSEEK4:
|
||||
case LLM_ARCH_PLM:
|
||||
case LLM_ARCH_CHATGLM:
|
||||
case LLM_ARCH_GRANITE:
|
||||
|
||||
@@ -255,9 +255,11 @@ struct llama_layer {
|
||||
struct ggml_tensor * wq_b = nullptr;
|
||||
struct ggml_tensor * wkv_a_mqa = nullptr;
|
||||
struct ggml_tensor * wkv_b = nullptr;
|
||||
struct ggml_tensor * wkv = nullptr;
|
||||
struct ggml_tensor * wk_b = nullptr;
|
||||
struct ggml_tensor * wv_b = nullptr;
|
||||
struct ggml_tensor * wqkv_b = nullptr;
|
||||
struct ggml_tensor * wo_a = nullptr;
|
||||
struct ggml_tensor * wo_b = nullptr;
|
||||
struct ggml_tensor * wq_cross = nullptr;
|
||||
struct ggml_tensor * wk_cross = nullptr;
|
||||
@@ -333,6 +335,7 @@ struct llama_layer {
|
||||
struct ggml_tensor * ffn_up_b = nullptr; // b3
|
||||
struct ggml_tensor * ffn_act = nullptr;
|
||||
struct ggml_tensor * ffn_exp_probs_b = nullptr;
|
||||
struct ggml_tensor * ffn_gate_tid2eid = nullptr;
|
||||
|
||||
// mamba proj
|
||||
struct ggml_tensor * ssm_in = nullptr;
|
||||
@@ -463,6 +466,23 @@ struct llama_layer {
|
||||
// openai-moe
|
||||
struct ggml_tensor * attn_sinks = nullptr;
|
||||
|
||||
// DeepSeek-V4
|
||||
struct ggml_tensor * attn_kv_norm = nullptr;
|
||||
struct ggml_tensor * hc_attn_fn = nullptr;
|
||||
struct ggml_tensor * hc_attn_base = nullptr;
|
||||
struct ggml_tensor * hc_attn_scale = nullptr;
|
||||
struct ggml_tensor * hc_ffn_fn = nullptr;
|
||||
struct ggml_tensor * hc_ffn_base = nullptr;
|
||||
struct ggml_tensor * hc_ffn_scale = nullptr;
|
||||
struct ggml_tensor * attn_comp_wkv = nullptr;
|
||||
struct ggml_tensor * attn_comp_wgate = nullptr;
|
||||
struct ggml_tensor * attn_comp_ape = nullptr;
|
||||
struct ggml_tensor * attn_comp_norm = nullptr;
|
||||
struct ggml_tensor * indexer_comp_wkv = nullptr;
|
||||
struct ggml_tensor * indexer_comp_wgate = nullptr;
|
||||
struct ggml_tensor * indexer_comp_ape = nullptr;
|
||||
struct ggml_tensor * indexer_comp_norm = nullptr;
|
||||
|
||||
// cogvlm
|
||||
struct ggml_tensor * visexp_attn_wqkv = nullptr;
|
||||
struct ggml_tensor * visexp_attn_wo = nullptr;
|
||||
@@ -553,6 +573,11 @@ struct llama_model {
|
||||
struct ggml_tensor * nextn_proj_pre = nullptr;
|
||||
struct ggml_tensor * nextn_proj_post = nullptr;
|
||||
|
||||
// DeepSeek-V4
|
||||
struct ggml_tensor * hc_head_fn = nullptr;
|
||||
struct ggml_tensor * hc_head_base = nullptr;
|
||||
struct ggml_tensor * hc_head_scale = nullptr;
|
||||
|
||||
// classifier
|
||||
struct ggml_tensor * cls = nullptr;
|
||||
struct ggml_tensor * cls_b = nullptr;
|
||||
|
||||
File diff suppressed because it is too large
Load Diff
@@ -1085,6 +1085,121 @@ struct llama_model_deepseek32 : public llama_model_base {
|
||||
};
|
||||
|
||||
|
||||
struct llama_model_deepseek4 : public llama_model_base {
|
||||
llama_model_deepseek4(const struct llama_model_params & params) : llama_model_base(params) {}
|
||||
void load_arch_hparams(llama_model_loader & ml) override;
|
||||
void load_arch_tensors(llama_model_loader & ml) override;
|
||||
|
||||
struct graph : public llm_graph_context {
|
||||
graph(const llama_model & model, const llm_graph_params & params);
|
||||
|
||||
ggml_tensor * build_hc_pre(
|
||||
ggml_tensor * x,
|
||||
ggml_tensor * hc_fn,
|
||||
ggml_tensor * hc_scale,
|
||||
ggml_tensor * hc_base,
|
||||
ggml_tensor ** post,
|
||||
ggml_tensor ** comb,
|
||||
int il) const;
|
||||
|
||||
ggml_tensor * build_hc_post(
|
||||
ggml_tensor * x,
|
||||
ggml_tensor * residual,
|
||||
ggml_tensor * post,
|
||||
ggml_tensor * comb,
|
||||
int il) const;
|
||||
|
||||
ggml_tensor * build_hc_head(
|
||||
ggml_tensor * x,
|
||||
ggml_tensor * hc_fn,
|
||||
ggml_tensor * hc_scale,
|
||||
ggml_tensor * hc_base) const;
|
||||
|
||||
ggml_tensor * build_attention(
|
||||
const llama_model & model,
|
||||
llm_graph_input_dsv4 * inp_dsv4,
|
||||
ggml_tensor * cur,
|
||||
ggml_tensor * inp_pos,
|
||||
int il) const;
|
||||
|
||||
ggml_tensor * build_hca_compressed_kv_from_state(
|
||||
ggml_tensor * kv_state,
|
||||
ggml_tensor * score_state,
|
||||
ggml_tensor * state_read_idxs,
|
||||
ggml_tensor * comp_pos,
|
||||
ggml_tensor * norm,
|
||||
int64_t n_embd_head,
|
||||
const char * name,
|
||||
int il) const;
|
||||
|
||||
ggml_tensor * build_overlap_compressed_kv_from_state(
|
||||
ggml_tensor * kv_state,
|
||||
ggml_tensor * score_state,
|
||||
ggml_tensor * state_read_idxs,
|
||||
ggml_tensor * comp_pos,
|
||||
ggml_tensor * norm,
|
||||
int64_t ratio,
|
||||
int64_t n_embd_head,
|
||||
const char * name,
|
||||
int il) const;
|
||||
|
||||
ggml_tensor * build_lid_top_k(
|
||||
const llama_model & model,
|
||||
llm_graph_input_dsv4 * inp_dsv4,
|
||||
ggml_tensor * qr,
|
||||
ggml_tensor * cur,
|
||||
ggml_tensor * inp_pos,
|
||||
int il) const;
|
||||
|
||||
ggml_tensor * build_top_k_mask(
|
||||
ggml_tensor * kq_mask,
|
||||
ggml_tensor * top_k,
|
||||
const char * name,
|
||||
int il) const;
|
||||
|
||||
ggml_tensor * build_csa_lid_attention(
|
||||
const llama_model & model,
|
||||
llm_graph_input_dsv4 * inp_dsv4,
|
||||
llm_graph_input_dsv4_raw * inp_attn,
|
||||
ggml_tensor * q,
|
||||
ggml_tensor * kv,
|
||||
ggml_tensor * qr,
|
||||
ggml_tensor * cur,
|
||||
ggml_tensor * inp_pos,
|
||||
ggml_tensor * sinks,
|
||||
float kq_scale,
|
||||
int il) const;
|
||||
|
||||
ggml_tensor * build_hca_attention(
|
||||
llm_graph_input_dsv4 * inp_dsv4,
|
||||
llm_graph_input_dsv4_raw * inp_attn,
|
||||
ggml_tensor * q,
|
||||
ggml_tensor * kv,
|
||||
ggml_tensor * sinks,
|
||||
float kq_scale,
|
||||
int il) const;
|
||||
|
||||
ggml_tensor * build_raw_attention(
|
||||
llm_graph_input_dsv4_raw * inp_attn,
|
||||
ggml_tensor * q,
|
||||
ggml_tensor * kv,
|
||||
ggml_tensor * sinks,
|
||||
float kq_scale,
|
||||
int il) const;
|
||||
|
||||
ggml_tensor * build_hc_weighted_sum(
|
||||
ggml_tensor * x,
|
||||
ggml_tensor * weights) const;
|
||||
|
||||
ggml_tensor * build_hc_sinkhorn(
|
||||
ggml_tensor * comb,
|
||||
int il) const;
|
||||
};
|
||||
|
||||
std::unique_ptr<llm_graph_context> build_arch_graph(const llm_graph_params & params) const override;
|
||||
};
|
||||
|
||||
|
||||
struct llama_model_deepseek2ocr : public llama_model_base {
|
||||
llama_model_deepseek2ocr(const struct llama_model_params & params) : llama_model_base(params) {}
|
||||
void load_arch_hparams(llama_model_loader & ml) override;
|
||||
|
||||
@@ -211,7 +211,6 @@ llama_build_and_test(
|
||||
peg-parser/test-unicode.cpp
|
||||
peg-parser/tests.h
|
||||
)
|
||||
llama_build_and_test(test-regex-partial.cpp)
|
||||
|
||||
if (NOT ${CMAKE_SYSTEM_PROCESSOR} MATCHES "s390x")
|
||||
set(MODEL_NAME "tinyllamas/stories15M-q4_0.gguf")
|
||||
|
||||
@@ -25,7 +25,7 @@ using json = nlohmann::ordered_json;
|
||||
static int main_automated_tests(void);
|
||||
|
||||
static void run_multiple(const std::string& dir_path, bool stop_on_first_failure, const json& input, bool use_common = false);
|
||||
static void run_single(const std::string& contents, json input, bool use_common = false, const std::string & output_path = "");
|
||||
static void run_single(const std::string& contents, json input, bool use_common = false, bool dump_prog = false, const std::string & output_path = "");
|
||||
|
||||
static std::string HELP = R"(
|
||||
Usage: test-chat-template [OPTIONS] PATH_TO_TEMPLATE
|
||||
@@ -35,6 +35,7 @@ Options:
|
||||
--json <path> Path to the JSON input file.
|
||||
--stop-on-first-fail Stop testing on the first failure (default: false).
|
||||
--no-common Use direct Jinja engine instead of common chat templates (default: use common).
|
||||
--dump-prog Dump the parsed program for debugging (only for single template runs).
|
||||
--output <path> Path to output results (only for single template runs).
|
||||
If PATH_TO_TEMPLATE is a file, runs that single template.
|
||||
If PATH_TO_TEMPLATE is a directory, runs all .jinja files in that directory.
|
||||
@@ -118,6 +119,7 @@ int main(int argc, char ** argv) {
|
||||
std::string & json_to_use = DEFAULT_JSON;
|
||||
bool stop_on_first_fail = false;
|
||||
bool use_common = true;
|
||||
bool dump_prog = false;
|
||||
|
||||
for (size_t i = 1; i < args.size(); i++) {
|
||||
if (args[i] == "--help" || args[i] == "-h") {
|
||||
@@ -136,6 +138,8 @@ int main(int argc, char ** argv) {
|
||||
i++;
|
||||
} else if (args[i] == "--no-common") {
|
||||
use_common = false;
|
||||
} else if (args[i] == "--dump-prog") {
|
||||
dump_prog = true;
|
||||
} else if (tmpl_path.empty()) {
|
||||
tmpl_path = args[i];
|
||||
} else {
|
||||
@@ -172,7 +176,7 @@ int main(int argc, char ** argv) {
|
||||
std::string contents = std::string(
|
||||
std::istreambuf_iterator<char>(infile),
|
||||
std::istreambuf_iterator<char>());
|
||||
run_single(contents, input_json, use_common, output_path);
|
||||
run_single(contents, input_json, use_common, dump_prog, output_path);
|
||||
} else {
|
||||
std::cerr << "Error: PATH_TO_TEMPLATE is not a valid file or directory: " << tmpl_path << "\n";
|
||||
return 1;
|
||||
@@ -276,11 +280,21 @@ static jinja::value_string format_using_direct_engine(
|
||||
}
|
||||
|
||||
|
||||
void run_single(const std::string& contents, json input, bool use_common, const std::string & output_path) {
|
||||
void run_single(const std::string& contents, json input, bool use_common, bool dump_prog, const std::string & output_path) {
|
||||
jinja::enable_debug(true);
|
||||
|
||||
jinja::value_string output_parts;
|
||||
|
||||
if (dump_prog) {
|
||||
jinja::lexer lexer;
|
||||
auto lexer_res = lexer.tokenize(contents);
|
||||
jinja::program ast = jinja::parse_from_tokens(lexer_res);
|
||||
std::string prog_dump = jinja::runtime::debug_dump_program(ast, contents);
|
||||
std::cout << "\n=== DUMPED PROGRAM ===\n";
|
||||
std::cout << prog_dump << "\n";
|
||||
return;
|
||||
}
|
||||
|
||||
if (use_common) {
|
||||
std::string bos_token = "<s>";
|
||||
std::string eos_token = "</s>";
|
||||
|
||||
@@ -5593,6 +5593,77 @@ static void test_template_output_peg_parsers(bool detailed_debug) {
|
||||
.expect_content("Hello, world!\nWhat's up?")
|
||||
.run();
|
||||
}
|
||||
|
||||
// MiniCPM5 - XML tool calls with <function name="..."><param name="...">...</param></function>
|
||||
{
|
||||
auto tst = peg_tester("models/templates/openbmb-MiniCPM5-1B.jinja", detailed_debug);
|
||||
|
||||
tst.test("Hello, world!\nWhat's up?")
|
||||
.enable_thinking(false)
|
||||
.reasoning_format(COMMON_REASONING_FORMAT_AUTO)
|
||||
.expect(message_assist)
|
||||
.run();
|
||||
|
||||
tst.test(R"(<function name="python"><param name="code">print('Hello, World!')</param></function>)")
|
||||
.enable_thinking(false)
|
||||
.reasoning_format(COMMON_REASONING_FORMAT_AUTO)
|
||||
.tools({ python_tool })
|
||||
.expect_tool_calls({ { "python", R"#({"code": "print('Hello, World!')"})#", {} } })
|
||||
.run();
|
||||
|
||||
tst.test(R"(<function name="empty_args"></function>)")
|
||||
.enable_thinking(false)
|
||||
.reasoning_format(COMMON_REASONING_FORMAT_AUTO)
|
||||
.tools({ empty_args_tool })
|
||||
.expect(simple_assist_msg("", "", "empty_args", "{}"))
|
||||
.run();
|
||||
|
||||
tst.test(R"(<function name="python"><param name="code">print('x')</param></function>)")
|
||||
.enable_thinking(false)
|
||||
.reasoning_format(COMMON_REASONING_FORMAT_AUTO)
|
||||
.parallel_tool_calls(true)
|
||||
.tools({ python_tool })
|
||||
.expect_tool_calls({ { "python", R"#({"code": "print('x')"})#", {} } })
|
||||
.run();
|
||||
|
||||
// CDATA lets a string value carry characters that would otherwise close the tag.
|
||||
tst.test(R"(<function name="html"><param name="markup"><![CDATA[<a href="/x">hi</a> </param>]]></param></function>)")
|
||||
.enable_thinking(false)
|
||||
.reasoning_format(COMMON_REASONING_FORMAT_AUTO)
|
||||
.tools({ html_tool })
|
||||
.expect_tool_calls({ { "html", R"#({"markup": "<a href=\"/x\">hi</a> </param>"})#", {} } })
|
||||
.run();
|
||||
|
||||
tst.test(R"(I'm thinking</think><function name="python"><param name="code">print('hey')</param></function>)")
|
||||
.enable_thinking(true)
|
||||
.reasoning_format(COMMON_REASONING_FORMAT_AUTO)
|
||||
.tools({ python_tool })
|
||||
.expect_reasoning("I'm thinking")
|
||||
.expect_tool_calls({ { "python", R"#({"code": "print('hey')"})#", {} } })
|
||||
.run();
|
||||
|
||||
tst.test(R"(<function name="python"><param name="code">print('x')</param></function>
|
||||
<function name="python"><param name="code">print('y')</param></function>)")
|
||||
.enable_thinking(false)
|
||||
.reasoning_format(COMMON_REASONING_FORMAT_AUTO)
|
||||
.parallel_tool_calls(true)
|
||||
.tools({ python_tool })
|
||||
.expect_tool_calls({
|
||||
{ "python", R"#({"code": "print('x')"})#", {} },
|
||||
{ "python", R"#({"code": "print('y')"})#", {} },
|
||||
})
|
||||
.run();
|
||||
|
||||
tst.test(" thinking</think>Hello, world!\nWhat's up?")
|
||||
.enable_thinking(true)
|
||||
.reasoning_format(COMMON_REASONING_FORMAT_AUTO)
|
||||
.messages({ message_user, message_assist_prefill_reasoning })
|
||||
.add_generation_prompt(false)
|
||||
.continue_final_message(COMMON_CHAT_CONTINUATION_REASONING)
|
||||
.expect_reasoning("I'm thinking")
|
||||
.expect_content("Hello, world!\nWhat's up?")
|
||||
.run();
|
||||
}
|
||||
}
|
||||
|
||||
static void test_template_generation_prompt() {
|
||||
@@ -5740,6 +5811,13 @@ static void test_template_generation_prompt() {
|
||||
check(tmpls, continuation_content(), "<|Assistant|><think>I'm thinking</think>Hello, ");
|
||||
check(tmpls, continuation_reasoning(), "<|Assistant|><think>I'm");
|
||||
}
|
||||
|
||||
{
|
||||
auto tmpls = read_templates("models/templates/openbmb-MiniCPM5-1B.jinja");
|
||||
check(tmpls, basic(), "<|im_start|>assistant\n<think>\n");
|
||||
check(tmpls, continuation_content(), "<|im_start|>assistant\n<think>\nI'm thinking\n</think>\n\nHello, ");
|
||||
check(tmpls, continuation_reasoning(), "<|im_start|>assistant\n<think>\nI'm");
|
||||
}
|
||||
}
|
||||
|
||||
// Test the developer role to system workaround with a simple mock template
|
||||
|
||||
@@ -1584,6 +1584,36 @@ static void test_array_methods(testing & t) {
|
||||
"6"
|
||||
);
|
||||
|
||||
test_template(t, "array|min",
|
||||
"{{ [tool_calls_count, tool_sep_count]|min }}",
|
||||
{{"tool_calls_count", 2}, {"tool_sep_count", 1}},
|
||||
"1"
|
||||
);
|
||||
|
||||
test_template(t, "array|max",
|
||||
"{{ [tool_calls_count, tool_sep_count]|max }}",
|
||||
{{"tool_calls_count", 2}, {"tool_sep_count", 1}},
|
||||
"2"
|
||||
);
|
||||
|
||||
test_template(t, "array|min attribute",
|
||||
"{{ items|min(attribute='x') }}",
|
||||
{{"items", json::array({
|
||||
json({{"x", 2}}),
|
||||
json({{"x", 1}}),
|
||||
})}},
|
||||
"{'x': 1}"
|
||||
);
|
||||
|
||||
test_template(t, "array|max attribute",
|
||||
"{{ items|max(attribute='x') }}",
|
||||
{{"items", json::array({
|
||||
json({{"x", 2}}),
|
||||
json({{"x", 1}}),
|
||||
})}},
|
||||
"{'x': 2}"
|
||||
);
|
||||
|
||||
// not used by any chat templates
|
||||
// test_template(t, "array.insert()",
|
||||
// "{% set _ = arr.insert(1, 'x') %}{{ arr|join(',') }}",
|
||||
|
||||
@@ -412,6 +412,9 @@ static bool arch_supported(const llm_arch arch) {
|
||||
if (arch == LLM_ARCH_DEEPSEEK2OCR) {
|
||||
return false;
|
||||
}
|
||||
if (arch == LLM_ARCH_DEEPSEEK4) {
|
||||
return false;
|
||||
}
|
||||
|
||||
// FIXME some models are segfaulting with WebGPU:
|
||||
#ifdef GGML_USE_WEBGPU
|
||||
|
||||
@@ -1,288 +0,0 @@
|
||||
// Tests common_regex (esp. its partial final matches support).
|
||||
|
||||
#include "common.h"
|
||||
#include "regex-partial.h"
|
||||
|
||||
#include <sstream>
|
||||
#include <iostream>
|
||||
#include <optional>
|
||||
|
||||
template <class T> static void assert_equals(const T & expected, const T & actual) {
|
||||
if (expected != actual) {
|
||||
std::cerr << "Expected: " << expected << std::endl;
|
||||
std::cerr << " Actual: " << actual << std::endl;
|
||||
std::cerr << std::flush;
|
||||
throw std::runtime_error("Test failed");
|
||||
}
|
||||
}
|
||||
|
||||
struct test_case {
|
||||
std::string pattern;
|
||||
struct input_output {
|
||||
std::string input;
|
||||
common_regex_match output;
|
||||
};
|
||||
std::vector<input_output> inputs_outputs;
|
||||
};
|
||||
|
||||
static std::string common_regex_match_type_name(common_regex_match_type type) {
|
||||
switch (type) {
|
||||
case COMMON_REGEX_MATCH_TYPE_NONE:
|
||||
return "COMMON_REGEX_MATCH_TYPE_NONE";
|
||||
case COMMON_REGEX_MATCH_TYPE_PARTIAL:
|
||||
return "COMMON_REGEX_MATCH_TYPE_PARTIAL";
|
||||
case COMMON_REGEX_MATCH_TYPE_FULL:
|
||||
return "COMMON_REGEX_MATCH_TYPE_FULL";
|
||||
}
|
||||
return "?";
|
||||
}
|
||||
|
||||
static void test_regex() {
|
||||
printf("[%s]\n", __func__);
|
||||
auto test = [](const test_case & test_case) {
|
||||
common_regex cr(test_case.pattern);
|
||||
std::cout << "Testing pattern: /" << test_case.pattern << "/\n";
|
||||
// std::cout << " partial rev: " << cr.reversed_partial_pattern.str() << '\n';
|
||||
for (const auto & input_output : test_case.inputs_outputs) {
|
||||
std::cout << " Input: " << input_output.input << '\n';
|
||||
auto m = cr.search(input_output.input, 0);
|
||||
if (m != input_output.output) {
|
||||
auto match_to_str = [&](const std::optional<common_regex_match> & m) {
|
||||
std::ostringstream ss;
|
||||
if (m->type == COMMON_REGEX_MATCH_TYPE_NONE) {
|
||||
ss << "<no match>";
|
||||
} else {
|
||||
GGML_ASSERT(!input_output.output.groups.empty());
|
||||
std::vector<std::string> parts;
|
||||
for (const auto & g : m->groups) {
|
||||
parts.push_back("{" + std::to_string(g.begin) + ", " + std::to_string(g.end) + "}");
|
||||
}
|
||||
ss << "{" << common_regex_match_type_name(m->type) << ", {" << string_join(parts, ", ") << "}}";
|
||||
}
|
||||
return ss.str();
|
||||
};
|
||||
std::cout << " Expected: " << match_to_str(input_output.output) << '\n';
|
||||
std::cout << " Got: " << match_to_str(m) << '\n';
|
||||
std::cout << " Inverted pattern: /" << regex_to_reversed_partial_regex(test_case.pattern) << "/\n";
|
||||
|
||||
throw std::runtime_error("Test failed");
|
||||
}
|
||||
}
|
||||
};
|
||||
test({
|
||||
"a",
|
||||
{
|
||||
{"a", {COMMON_REGEX_MATCH_TYPE_FULL, {{0, 1}}}},
|
||||
{"b", {COMMON_REGEX_MATCH_TYPE_NONE, {}}},
|
||||
{"ab", {COMMON_REGEX_MATCH_TYPE_FULL, {{0, 1}}}},
|
||||
{"ba", {COMMON_REGEX_MATCH_TYPE_FULL, {{1, 2}}}},
|
||||
}
|
||||
});
|
||||
test({
|
||||
"abcd",
|
||||
{
|
||||
{"abcd", {COMMON_REGEX_MATCH_TYPE_FULL, {{0, 4}}}},
|
||||
{"abcde", {COMMON_REGEX_MATCH_TYPE_FULL, {{0, 4}}}},
|
||||
{"abc", {COMMON_REGEX_MATCH_TYPE_PARTIAL, {{0, 3}}}},
|
||||
{"ab", {COMMON_REGEX_MATCH_TYPE_PARTIAL, {{0, 2}}}},
|
||||
{"a", {COMMON_REGEX_MATCH_TYPE_PARTIAL, {{0, 1}}}},
|
||||
{"d", {}},
|
||||
{"bcd", {}},
|
||||
{"cde", {}},
|
||||
{"cd", {}},
|
||||
{"yeah ab", {COMMON_REGEX_MATCH_TYPE_PARTIAL, {{5, 7}}}},
|
||||
{"abbie", {}},
|
||||
{"", {}},
|
||||
}
|
||||
});
|
||||
test({
|
||||
".*?ab",
|
||||
{
|
||||
{"ab", {COMMON_REGEX_MATCH_TYPE_FULL, {{0, 2}}}},
|
||||
{"abc", {COMMON_REGEX_MATCH_TYPE_FULL, {{0, 2}}}},
|
||||
{"dab", {COMMON_REGEX_MATCH_TYPE_FULL, {{0, 3}}}},
|
||||
{"dabc", {COMMON_REGEX_MATCH_TYPE_FULL, {{0, 3}}}},
|
||||
{"da", {COMMON_REGEX_MATCH_TYPE_PARTIAL, {{0, 2}}}},
|
||||
{"d", {COMMON_REGEX_MATCH_TYPE_PARTIAL, {{0, 1}}}},
|
||||
}
|
||||
});
|
||||
test({
|
||||
"a.*?b",
|
||||
{
|
||||
{"ab", {COMMON_REGEX_MATCH_TYPE_FULL, {{0, 2}}}},
|
||||
{"abc", {COMMON_REGEX_MATCH_TYPE_FULL, {{0, 2}}}},
|
||||
{"a b", {COMMON_REGEX_MATCH_TYPE_FULL, {{0, 3}}}},
|
||||
{"a", {COMMON_REGEX_MATCH_TYPE_PARTIAL, {{0, 1}}}},
|
||||
{"argh", {COMMON_REGEX_MATCH_TYPE_PARTIAL, {{0, 4}}}},
|
||||
{"d", {}},
|
||||
{"b", {}},
|
||||
}
|
||||
});
|
||||
test({
|
||||
"ab(?:cd){2,4}ef",
|
||||
{
|
||||
// {"ab", {COMMON_REGEX_MATCH_TYPE_PARTIAL, 0, {}}},
|
||||
{"ab", {COMMON_REGEX_MATCH_TYPE_PARTIAL, {{0, 2}}}},
|
||||
{"abcd", {COMMON_REGEX_MATCH_TYPE_PARTIAL, {{0, 4}}}},
|
||||
{"abcde", {}},
|
||||
{"abcdef", {}},
|
||||
{"abcdcd", {COMMON_REGEX_MATCH_TYPE_PARTIAL, {{0, 6}}}},
|
||||
{"abcdcde", {COMMON_REGEX_MATCH_TYPE_PARTIAL, {{0, 7}}}},
|
||||
{"abcdcdef", {COMMON_REGEX_MATCH_TYPE_FULL, {{0, 8}}}},
|
||||
{"abcdcdcdcdef", {COMMON_REGEX_MATCH_TYPE_FULL, {{0, 12}}}},
|
||||
{"abcdcdcdcdcdef", {}},
|
||||
{"abcde", {}},
|
||||
{"yea", {COMMON_REGEX_MATCH_TYPE_PARTIAL, {{2, 3}}}},
|
||||
}
|
||||
});
|
||||
test({
|
||||
"a(?:rte| pure )fact",
|
||||
{
|
||||
{"a", {COMMON_REGEX_MATCH_TYPE_PARTIAL, {{0, 1}}}},
|
||||
{"art", {COMMON_REGEX_MATCH_TYPE_PARTIAL, {{0, 3}}}},
|
||||
{"artefa", {COMMON_REGEX_MATCH_TYPE_PARTIAL, {{0, 6}}}},
|
||||
{"fact", {}},
|
||||
{"an arte", {COMMON_REGEX_MATCH_TYPE_PARTIAL, {{3, 7}}}},
|
||||
{"artefact", {COMMON_REGEX_MATCH_TYPE_FULL, {{0, 8}}}},
|
||||
{"an artefact", {COMMON_REGEX_MATCH_TYPE_FULL, {{3, 11}}}},
|
||||
{"a pure", {COMMON_REGEX_MATCH_TYPE_PARTIAL, {{0, 6}}}},
|
||||
{"a pure fact", {COMMON_REGEX_MATCH_TYPE_FULL, {{0, 11}}}},
|
||||
{"it's a pure fact", {COMMON_REGEX_MATCH_TYPE_FULL, {{5, 16}}}},
|
||||
{"" , {}},
|
||||
{"pure", {}},
|
||||
{"pure fact", {}},
|
||||
}
|
||||
});
|
||||
test({
|
||||
"abc",
|
||||
{
|
||||
{" abcc", {COMMON_REGEX_MATCH_TYPE_FULL, {{1, 4}}}},
|
||||
{"ab", {COMMON_REGEX_MATCH_TYPE_PARTIAL, {{0, 2}}}},
|
||||
{"abc", {COMMON_REGEX_MATCH_TYPE_FULL, {{0, 3}}}},
|
||||
{" ab", {COMMON_REGEX_MATCH_TYPE_PARTIAL, {{1, 3}}}},
|
||||
{"a", {COMMON_REGEX_MATCH_TYPE_PARTIAL, {{0, 1}}}},
|
||||
{"b", {}},
|
||||
{"c", {}},
|
||||
{"", {}},
|
||||
}
|
||||
});
|
||||
|
||||
test({
|
||||
"(?:abc)?\\s*def",
|
||||
{
|
||||
{"ab", {COMMON_REGEX_MATCH_TYPE_PARTIAL, {{0, 2}}}},
|
||||
{"abc", {COMMON_REGEX_MATCH_TYPE_PARTIAL, {{0, 3}}}},
|
||||
{"abc ", {COMMON_REGEX_MATCH_TYPE_PARTIAL, {{0, 4}}}},
|
||||
{"abc d", {COMMON_REGEX_MATCH_TYPE_PARTIAL, {{0, 5}}}},
|
||||
{"abc de", {COMMON_REGEX_MATCH_TYPE_PARTIAL, {{0, 6}}}},
|
||||
{"abc def", {COMMON_REGEX_MATCH_TYPE_FULL, {{0, 7}}}},
|
||||
{"abc defg", {COMMON_REGEX_MATCH_TYPE_FULL, {{0, 7}}}},
|
||||
{"abc defgh", {COMMON_REGEX_MATCH_TYPE_FULL, {{0, 7}}}},
|
||||
{"abcde", {COMMON_REGEX_MATCH_TYPE_PARTIAL, {{0, 5}}}},
|
||||
{"abcdefgh", {COMMON_REGEX_MATCH_TYPE_FULL, {{0, 6}}}},
|
||||
{" d", {COMMON_REGEX_MATCH_TYPE_PARTIAL, {{0, 2}}}},
|
||||
{"def", {COMMON_REGEX_MATCH_TYPE_FULL, {{0, 3}}}},
|
||||
}
|
||||
});
|
||||
|
||||
test({
|
||||
"a+b",
|
||||
{
|
||||
{"aaab", {COMMON_REGEX_MATCH_TYPE_FULL, {{0, 4}}}},
|
||||
{"aaa", {COMMON_REGEX_MATCH_TYPE_PARTIAL, {{0, 3}}}},
|
||||
{"ab", {COMMON_REGEX_MATCH_TYPE_FULL, {{0, 2}}}},
|
||||
}
|
||||
});
|
||||
|
||||
test({
|
||||
"(?:"
|
||||
"(```(?:xml|json)?\\n\\s*)?" // match 1 (block_start)
|
||||
"(" // match 2 (open_tag)
|
||||
"<tool_call>"
|
||||
"|<function_call>"
|
||||
"|<tool>"
|
||||
"|<tools>"
|
||||
"|<response>"
|
||||
"|<json>"
|
||||
"|<xml>"
|
||||
"|<JSON>"
|
||||
")?"
|
||||
"(\\s*\\{\\s*\"name\"\\s*:)" // match 3 (named tool call)
|
||||
")"
|
||||
"|<function=([^>]+)>" // match 4 (function name)
|
||||
"|<function name=\"([^\"]+)\">", // match 5 (function name again)
|
||||
{
|
||||
{"{\"name\": \"special_function\", \"arguments\": {\"arg1\": 1}}", {COMMON_REGEX_MATCH_TYPE_FULL, {{0, 8}, {54, 54}, {54, 54}, {0, 8}, {54, 54}, {54, 54}}}},
|
||||
{"<tool_call> {\"name", {COMMON_REGEX_MATCH_TYPE_PARTIAL, {{0, 18}}}},
|
||||
{"<tool_call>{\"name", {COMMON_REGEX_MATCH_TYPE_PARTIAL, {{0, 17}}}},
|
||||
{"Let's call something\n<tool_call>{\"name", {COMMON_REGEX_MATCH_TYPE_PARTIAL, {{21, 38}}}},
|
||||
{"Ok then<tool_call>{\"name", {COMMON_REGEX_MATCH_TYPE_PARTIAL, {{7, 24}}}},
|
||||
{"{\"name", {COMMON_REGEX_MATCH_TYPE_PARTIAL, {{0, 6}}}},
|
||||
{"Ok then{\"name", {COMMON_REGEX_MATCH_TYPE_PARTIAL, {{7, 13}}}},
|
||||
{"<tool_call> {\"name\": \"special_function\", \"arguments\": {\"arg1\": 1}}", {COMMON_REGEX_MATCH_TYPE_FULL, {{0, 20}, {66, 66}, {0, 11}, {11, 20}, {66, 66}, {66, 66}}}},
|
||||
{"<function_call> {\"name\": \"special_function\", \"arguments\": {\"arg1\": 1}}", {COMMON_REGEX_MATCH_TYPE_FULL, {{0, 24}, {70, 70}, {0, 15}, {15, 24}, {70, 70}, {70, 70}}}},
|
||||
{"<function name=\"special_function\"> {\"name\": \"special_function\", \"arguments\": {\"arg1\": 1}}", {COMMON_REGEX_MATCH_TYPE_FULL, {{0, 34}, {89, 89}, {89, 89}, {89, 89}, {89, 89}, {16, 32}}}},
|
||||
{"<function=all>", {COMMON_REGEX_MATCH_TYPE_FULL, {{0, 14}, {14, 14}, {14, 14}, {14, 14}, {10, 13}, {14, 14}}}},
|
||||
|
||||
}
|
||||
});
|
||||
}
|
||||
|
||||
static void test_regex_to_reversed_partial_regex() {
|
||||
printf("[%s]\n", __func__);
|
||||
|
||||
assert_equals<std::string>(
|
||||
"^((?:(?:c)?b)?a)",
|
||||
regex_to_reversed_partial_regex("abc"));
|
||||
|
||||
assert_equals<std::string>(
|
||||
"^(a+)",
|
||||
regex_to_reversed_partial_regex("a+"));
|
||||
|
||||
assert_equals<std::string>(
|
||||
"^(a*)",
|
||||
regex_to_reversed_partial_regex("a*"));
|
||||
|
||||
assert_equals<std::string>(
|
||||
"^(a?)",
|
||||
regex_to_reversed_partial_regex("a?"));
|
||||
|
||||
assert_equals<std::string>(
|
||||
"^([a-z])",
|
||||
regex_to_reversed_partial_regex("[a-z]"));
|
||||
|
||||
assert_equals<std::string>(
|
||||
"^((?:\\w+)?[a-z])",
|
||||
regex_to_reversed_partial_regex("[a-z]\\w+"));
|
||||
|
||||
assert_equals<std::string>(
|
||||
"^((?:a|b))",
|
||||
regex_to_reversed_partial_regex("(?:a|b)"));
|
||||
assert_equals<std::string>(
|
||||
"^((?:(?:(?:d)?c)?b)?a)",
|
||||
regex_to_reversed_partial_regex("abcd"));
|
||||
assert_equals<std::string>(
|
||||
"^((?:b)?a*)", // TODO: ((?:b)?a*+).* ??
|
||||
regex_to_reversed_partial_regex("a*b"));
|
||||
assert_equals<std::string>(
|
||||
"^((?:(?:b)?a)?.*)",
|
||||
regex_to_reversed_partial_regex(".*?ab"));
|
||||
assert_equals<std::string>(
|
||||
"^((?:(?:b)?.*)?a)",
|
||||
regex_to_reversed_partial_regex("a.*?b"));
|
||||
assert_equals<std::string>(
|
||||
"^((?:(?:d)?(?:(?:c)?b))?a)",
|
||||
regex_to_reversed_partial_regex("a(bc)d"));
|
||||
assert_equals<std::string>(
|
||||
"^((?:(?:(?:c)?b|(?:e)?d))?a)",
|
||||
regex_to_reversed_partial_regex("a(bc|de)"));
|
||||
assert_equals<std::string>(
|
||||
"^((?:(?:(?:(?:(?:c)?b?)?b?)?b)?b)?a)",
|
||||
regex_to_reversed_partial_regex("ab{2,4}c"));
|
||||
}
|
||||
|
||||
int main() {
|
||||
test_regex_to_reversed_partial_regex();
|
||||
test_regex();
|
||||
std::cout << "All tests passed.\n";
|
||||
}
|
||||
@@ -1538,6 +1538,19 @@ private:
|
||||
/* media_path */ params_base.media_path,
|
||||
/* force_pure_content */ params_base.force_pure_content_parser
|
||||
};
|
||||
|
||||
{
|
||||
auto caps = common_chat_templates_get_caps(chat_params.tmpls.get());
|
||||
auto it = params_base.default_template_kwargs.find("preserve_reasoning");
|
||||
bool supported = caps.at("supports_preserve_reasoning");
|
||||
bool enabled = it != params_base.default_template_kwargs.end();
|
||||
if (supported && !enabled) {
|
||||
SRV_INF("%s", "chat template supports preserving reasoning, consider enabling it via --reasoning-preserve\n");
|
||||
}
|
||||
if (!supported && enabled) {
|
||||
SRV_WRN("%s", "chat template does NOT support preserving reasoning, --reasoning-preserve has no effect\n");
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
return true;
|
||||
@@ -2450,6 +2463,8 @@ private:
|
||||
|
||||
server_slot * slot = get_slot_by_cmpl_id(task.params.control_cmpl_id);
|
||||
if (slot == nullptr) {
|
||||
SRV_WRN("control %s on unknown completion id=%s, no live slot\n",
|
||||
task.params.control_action.c_str(), task.params.control_cmpl_id.c_str());
|
||||
res->success = false;
|
||||
res->message = "no active completion for this id";
|
||||
queue_results.send(std::move(res));
|
||||
|
||||
@@ -1983,7 +1983,10 @@ void server_models_routes::init_routes() {
|
||||
cli.set_read_timeout(0, STREAM_LOOKUP_TIMEOUT_MS * 1000);
|
||||
cli.set_write_timeout(0, STREAM_LOOKUP_TIMEOUT_MS * 1000);
|
||||
auto resp = cli.Delete(child_path.c_str());
|
||||
(void) resp; // best effort, 404 and network errors are equivalent to no op
|
||||
(void) resp; // the child logs its own miss when the session is unknown there
|
||||
} else {
|
||||
SRV_WRN("router stop for unknown conv_id=%s, no owning child in the conv map\n",
|
||||
conv_id.c_str());
|
||||
}
|
||||
// drop the tracking entry, the session is being torn down
|
||||
models.conv_models.forget(conv_id);
|
||||
|
||||
@@ -218,6 +218,13 @@ void stream_session_manager::evict_and_cancel(const std::string & conversation_i
|
||||
std::unique_lock<std::shared_mutex> lock(map_mu);
|
||||
auto it = sessions.find(conversation_id);
|
||||
if (it == sessions.end()) {
|
||||
std::string live;
|
||||
for (const auto & kv : sessions) {
|
||||
if (!live.empty()) live += ", ";
|
||||
live += kv.first;
|
||||
}
|
||||
SRV_WRN("stop on unknown stream session, conv_id=%s matched nothing, %zu live: [%s]\n",
|
||||
conversation_id.c_str(), sessions.size(), live.c_str());
|
||||
return;
|
||||
}
|
||||
s = it->second;
|
||||
|
||||
@@ -8,6 +8,7 @@ set(UI_SOURCE_GLOBS
|
||||
set(UI_SOURCE_FILES
|
||||
package.json
|
||||
package-lock.json
|
||||
src/.gitignore
|
||||
vite.config.ts
|
||||
svelte.config.js
|
||||
tsconfig.json
|
||||
|
||||
@@ -0,0 +1 @@
|
||||
!*
|
||||
+1
-1
@@ -33,7 +33,7 @@
|
||||
|
||||
{#if !readonly && onRemove}
|
||||
<div
|
||||
class="absolute top-10 right-2 flex items-center justify-center opacity-0 transition-opacity group-focus-within:opacity-100 group-hover:opacity-100"
|
||||
class="absolute top-10 right-2 flex items-center justify-center opacity-0 transition-opacity group-hover:opacity-100"
|
||||
>
|
||||
<ActionIcon icon={X} tooltip="Remove" stopPropagationOnClick onclick={() => onRemove?.()} />
|
||||
</div>
|
||||
|
||||
+1
-1
@@ -56,7 +56,7 @@
|
||||
<div class="relative flex h-6 items-center justify-between">
|
||||
<div class="right-0 flex items-center gap-2 opacity-100 transition-opacity">
|
||||
<div
|
||||
class="pointer-events-auto inset-0 flex items-center gap-1 opacity-0 transition-all duration-150 group-focus-within:opacity-100 group-hover:opacity-100"
|
||||
class="pointer-events-auto inset-0 flex items-center gap-1 opacity-0 transition-all duration-150 group-hover:opacity-100"
|
||||
>
|
||||
<ActionIcon icon={Edit} tooltip="Edit" onclick={editCtx.handleEdit} />
|
||||
<ActionIcon icon={Trash2} tooltip="Delete" onclick={onDelete} />
|
||||
|
||||
+81
-56
@@ -39,6 +39,7 @@
|
||||
depth = 0
|
||||
}: Props = $props();
|
||||
|
||||
let renderActionsDropdown = $state(false);
|
||||
let dropdownOpen = $state(false);
|
||||
|
||||
let isLoading = $derived(getAllLoadingChats().includes(conversation.id));
|
||||
@@ -70,10 +71,26 @@
|
||||
}
|
||||
}
|
||||
|
||||
function handleMouseLeave() {
|
||||
if (!dropdownOpen) {
|
||||
renderActionsDropdown = false;
|
||||
}
|
||||
}
|
||||
|
||||
function handleMouseOver() {
|
||||
renderActionsDropdown = true;
|
||||
}
|
||||
|
||||
function handleSelect() {
|
||||
onSelect?.(conversation.id);
|
||||
}
|
||||
|
||||
$effect(() => {
|
||||
if (!dropdownOpen) {
|
||||
renderActionsDropdown = false;
|
||||
}
|
||||
});
|
||||
|
||||
onMount(() => {
|
||||
document.addEventListener('edit-active-conversation', handleGlobalEditEvent as EventListener);
|
||||
|
||||
@@ -86,19 +103,23 @@
|
||||
});
|
||||
</script>
|
||||
|
||||
<div
|
||||
class="conversation-item group relative flex min-h-9 w-full items-center justify-between space-x-3 rounded-lg py-1.5 transition-colors hover:bg-foreground/10 {isActive
|
||||
<!-- svelte-ignore a11y_mouse_events_have_key_events -->
|
||||
<button
|
||||
class="group flex min-h-9 w-full cursor-pointer items-center justify-between space-x-3 rounded-lg py-1.5 text-left transition-colors hover:bg-foreground/10 {isActive
|
||||
? 'bg-foreground/5 text-accent-foreground'
|
||||
: ''} px-3"
|
||||
onclick={handleSelect}
|
||||
onmouseover={handleMouseOver}
|
||||
onmouseleave={handleMouseLeave}
|
||||
onfocusin={handleMouseOver}
|
||||
onfocusout={(e) => {
|
||||
if (!e.currentTarget.contains(e.relatedTarget as Node | null)) {
|
||||
handleMouseLeave();
|
||||
}
|
||||
}}
|
||||
>
|
||||
<button
|
||||
class="absolute inset-0 z-0 cursor-pointer rounded-lg focus:outline-none focus-visible:ring-2 focus-visible:ring-ring"
|
||||
onclick={handleSelect}
|
||||
aria-label={conversation.name}
|
||||
>
|
||||
</button>
|
||||
<div
|
||||
class="pointer-events-none relative z-10 flex min-w-0 flex-1 items-center gap-2"
|
||||
class="flex min-w-0 flex-1 items-center gap-2"
|
||||
style:padding-left="{depth * FORK_TREE_DEPTH_PADDING}px"
|
||||
>
|
||||
{#if depth > 0}
|
||||
@@ -109,7 +130,7 @@
|
||||
<a
|
||||
{...props}
|
||||
href={RouterService.chat(conversation.forkedFromConversationId)}
|
||||
class="pointer-events-auto flex shrink-0 items-center text-muted-foreground transition-colors hover:text-foreground"
|
||||
class="flex shrink-0 items-center text-muted-foreground transition-colors hover:text-foreground"
|
||||
>
|
||||
<GitBranch class="h-3.5 w-3.5" />
|
||||
</a>
|
||||
@@ -125,15 +146,18 @@
|
||||
{#if isLoading}
|
||||
<Tooltip.Root>
|
||||
<Tooltip.Trigger>
|
||||
<button
|
||||
class="stop-button pointer-events-auto flex h-4 w-4 shrink-0 cursor-pointer items-center justify-center rounded text-muted-foreground transition-colors hover:text-foreground"
|
||||
<div
|
||||
class="stop-button flex h-4 w-4 shrink-0 cursor-pointer items-center justify-center rounded text-muted-foreground transition-colors hover:text-foreground"
|
||||
onclick={handleStop}
|
||||
onkeydown={(e) => e.key === 'Enter' && handleStop(e)}
|
||||
role="button"
|
||||
tabindex="0"
|
||||
aria-label="Stop generation"
|
||||
>
|
||||
<Loader2 class="loading-icon h-3.5 w-3.5 animate-spin" />
|
||||
|
||||
<Square class="stop-icon hidden h-3 w-3 fill-current text-destructive" />
|
||||
</button>
|
||||
</div>
|
||||
</Tooltip.Trigger>
|
||||
|
||||
<Tooltip.Content>
|
||||
@@ -145,50 +169,52 @@
|
||||
<TruncatedText text={conversation.name} class="text-sm font-medium" showTooltip={false} />
|
||||
</div>
|
||||
|
||||
<div class="actions pointer-events-auto relative z-20 flex items-center">
|
||||
<DropdownMenuActions
|
||||
triggerIcon={MoreHorizontal}
|
||||
triggerTooltip="More actions"
|
||||
bind:open={dropdownOpen}
|
||||
actions={[
|
||||
{
|
||||
icon: conversation.pinned ? PinOff : Pin,
|
||||
label: conversation.pinned ? 'Unpin' : 'Pin',
|
||||
onclick: (e: Event) => {
|
||||
e.stopPropagation();
|
||||
handleTogglePin();
|
||||
}
|
||||
},
|
||||
{
|
||||
icon: Pencil,
|
||||
label: 'Edit',
|
||||
onclick: handleEdit,
|
||||
shortcut: ['shift', 'cmd', 'e']
|
||||
},
|
||||
{
|
||||
icon: Download,
|
||||
label: 'Export',
|
||||
onclick: (e: Event) => {
|
||||
e.stopPropagation();
|
||||
conversationsStore.downloadConversation(conversation.id);
|
||||
{#if renderActionsDropdown}
|
||||
<div class="actions flex items-center">
|
||||
<DropdownMenuActions
|
||||
triggerIcon={MoreHorizontal}
|
||||
triggerTooltip="More actions"
|
||||
bind:open={dropdownOpen}
|
||||
actions={[
|
||||
{
|
||||
icon: conversation.pinned ? PinOff : Pin,
|
||||
label: conversation.pinned ? 'Unpin' : 'Pin',
|
||||
onclick: (e: Event) => {
|
||||
e.stopPropagation();
|
||||
handleTogglePin();
|
||||
}
|
||||
},
|
||||
shortcut: ['shift', 'cmd', 's']
|
||||
},
|
||||
{
|
||||
icon: Trash2,
|
||||
label: 'Delete',
|
||||
onclick: handleDelete,
|
||||
variant: 'destructive',
|
||||
shortcut: ['shift', 'cmd', 'd'],
|
||||
separator: true
|
||||
}
|
||||
]}
|
||||
/>
|
||||
</div>
|
||||
</div>
|
||||
{
|
||||
icon: Pencil,
|
||||
label: 'Edit',
|
||||
onclick: handleEdit,
|
||||
shortcut: ['shift', 'cmd', 'e']
|
||||
},
|
||||
{
|
||||
icon: Download,
|
||||
label: 'Export',
|
||||
onclick: (e: Event) => {
|
||||
e.stopPropagation();
|
||||
conversationsStore.downloadConversation(conversation.id);
|
||||
},
|
||||
shortcut: ['shift', 'cmd', 's']
|
||||
},
|
||||
{
|
||||
icon: Trash2,
|
||||
label: 'Delete',
|
||||
onclick: handleDelete,
|
||||
variant: 'destructive',
|
||||
shortcut: ['shift', 'cmd', 'd'],
|
||||
separator: true
|
||||
}
|
||||
]}
|
||||
/>
|
||||
</div>
|
||||
{/if}
|
||||
</button>
|
||||
|
||||
<style>
|
||||
.conversation-item {
|
||||
button {
|
||||
:global([data-slot='dropdown-menu-trigger']:not([data-state='open'])) {
|
||||
opacity: 0;
|
||||
}
|
||||
@@ -213,8 +239,7 @@
|
||||
}
|
||||
}
|
||||
|
||||
&:is(:hover) .stop-button,
|
||||
&:focus-within .stop-button {
|
||||
&:is(:hover) .stop-button {
|
||||
:global(.stop-icon) {
|
||||
display: block;
|
||||
}
|
||||
|
||||
@@ -154,7 +154,13 @@ class ChatStore {
|
||||
});
|
||||
if (convId === conversationsStore.activeConversation?.id) this.currentResponse = response;
|
||||
}
|
||||
private clearChatStreaming(convId: string): void {
|
||||
private clearChatStreaming(convId: string, messageId?: string): void {
|
||||
// session aware: a stale generation must not wipe a newer one's streaming state on the
|
||||
// same conversation, that would drop the frozen stop identity and stop the wrong session
|
||||
if (messageId !== undefined) {
|
||||
const cur = this.chatStreamingStates.get(convId);
|
||||
if (cur && cur.messageId !== messageId) return;
|
||||
}
|
||||
this.chatStreamingStates.delete(convId);
|
||||
if (convId === conversationsStore.activeConversation?.id) this.currentResponse = '';
|
||||
}
|
||||
@@ -1055,11 +1061,14 @@ class ChatStore {
|
||||
modelOverride?: string | null,
|
||||
firstUserMessageContent?: string
|
||||
): Promise<void> {
|
||||
let effectiveModel = modelOverride;
|
||||
// the ::model suffix in the stream identity is only for router mode, where it routes to the
|
||||
// owning child. in single-model mode the identity stays the bare conv id so that attach, stop
|
||||
// and reattach all agree, regardless of fresh send vs regenerate passing a resolved model
|
||||
let effectiveModel: string | null | undefined = undefined;
|
||||
|
||||
if (isRouterMode() && !effectiveModel) {
|
||||
if (isRouterMode()) {
|
||||
const conversationModel = this.getConversationModel(allMessages);
|
||||
effectiveModel = selectedModelName() || conversationModel;
|
||||
effectiveModel = modelOverride || selectedModelName() || conversationModel;
|
||||
}
|
||||
|
||||
if (isRouterMode() && effectiveModel) {
|
||||
@@ -1074,6 +1083,9 @@ class ChatStore {
|
||||
let resolvedModel: string | null = null;
|
||||
let modelPersisted = false;
|
||||
const convId = assistantMessage.convId;
|
||||
// freeze the POST identity from t0 so a stop cancels with the exact session key,
|
||||
// never a stale or empty model resolved later
|
||||
this.setChatStreaming(convId, streamedContent, currentMessageId, effectiveModel);
|
||||
|
||||
const recordModel = (modelName: string | null | undefined, persistImmediately = true): void => {
|
||||
if (!modelName) return;
|
||||
@@ -1103,7 +1115,7 @@ class ChatStore {
|
||||
};
|
||||
|
||||
const updateStreamingUI = () => {
|
||||
this.setChatStreaming(convId, streamedContent, currentMessageId);
|
||||
this.setChatStreaming(convId, streamedContent, currentMessageId, effectiveModel);
|
||||
const idx = conversationsStore.findMessageIndex(currentMessageId);
|
||||
conversationsStore.updateMessageAtIndex(idx, { content: streamedContent });
|
||||
};
|
||||
@@ -1111,7 +1123,7 @@ class ChatStore {
|
||||
const cleanupStreamingState = () => {
|
||||
this.setStreamingActive(false);
|
||||
this.setChatLoading(convId, false);
|
||||
this.clearChatStreaming(convId);
|
||||
this.clearChatStreaming(convId, currentMessageId);
|
||||
this.setProcessingState(convId, null);
|
||||
};
|
||||
|
||||
@@ -1128,7 +1140,7 @@ class ChatStore {
|
||||
onReasoningChunk: (chunk: string) => {
|
||||
streamedReasoningContent += chunk;
|
||||
// mark streaming state so a stop mid-thinking can persist the partial reasoning
|
||||
this.setChatStreaming(convId, streamedContent, currentMessageId);
|
||||
this.setChatStreaming(convId, streamedContent, currentMessageId, effectiveModel);
|
||||
const idx = conversationsStore.findMessageIndex(currentMessageId);
|
||||
conversationsStore.updateMessageAtIndex(idx, {
|
||||
reasoningContent: streamedReasoningContent
|
||||
@@ -1405,7 +1417,7 @@ class ChatStore {
|
||||
// detached drain keeps producing tokens until eos or max_tokens. use the frozen identity
|
||||
// captured when the session started, not the live dropdown
|
||||
const streamStateForStop = this.chatStreamingStates.get(convId);
|
||||
const modelForStop = streamStateForStop?.model ?? selectedModelName();
|
||||
const modelForStop = streamStateForStop?.model;
|
||||
void ChatService.cancelServerStream(convId, modelForStop);
|
||||
this.abortRequest(convId);
|
||||
this.setChatLoading(convId, false);
|
||||
@@ -1846,6 +1858,14 @@ class ChatStore {
|
||||
updateStreamingContent(originalContent + appendedContent);
|
||||
this.setChatReasoning(msg.convId, false);
|
||||
},
|
||||
onCompletionId: (id: string) => {
|
||||
if (!id) return;
|
||||
// refresh the message id so a later skip targets the live slot after a continue
|
||||
conversationsStore.updateMessageAtIndex(conversationsStore.findMessageIndex(msg.id), {
|
||||
completionId: id
|
||||
});
|
||||
DatabaseService.updateMessage(msg.id, { completionId: id }).catch(() => {});
|
||||
},
|
||||
onReasoningChunk: (chunk: string) => {
|
||||
appendedReasoning += chunk;
|
||||
hasReceivedContent = true;
|
||||
|
||||
Reference in New Issue
Block a user