mirror of
https://github.com/ggml-org/llama.cpp.git
synced 2026-06-29 02:33:03 +02:00
Compare commits
8 Commits
| Author | SHA1 | Date | |
|---|---|---|---|
| b3fed31b99 | |||
| dbdaece23d | |||
| 7cb8576e7c | |||
| fa72bc6826 | |||
| c818263f2a | |||
| f68a788b0b | |||
| d1b34251bc | |||
| c1a1c8ee94 |
+16
-2
@@ -467,7 +467,7 @@ void common_models_handler_apply(common_models_handler & handler, common_params
|
||||
// the first part is what gets loaded, so point params.model.path at it
|
||||
if (!url_tasks.empty()) {
|
||||
std::string first_path = url_tasks.front().local_path;
|
||||
url_tasks.front().on_done = [&]() { params.model.path = first_path; };
|
||||
url_tasks.front().on_done = [&, first_path]() { params.model.path = first_path; };
|
||||
}
|
||||
for (auto & task : url_tasks) {
|
||||
tasks.push_back(std::move(task));
|
||||
@@ -3296,6 +3296,20 @@ common_params_context common_params_parser_init(common_params & params, llama_ex
|
||||
params.sampling.reasoning_budget_message = value;
|
||||
}
|
||||
).set_examples({LLAMA_EXAMPLE_SERVER, LLAMA_EXAMPLE_COMPLETION, LLAMA_EXAMPLE_CLI}).set_env("LLAMA_ARG_THINK_BUDGET_MESSAGE"));
|
||||
add_opt(common_arg(
|
||||
{"--reasoning-preserve"},
|
||||
{"--no-reasoning-preserve"},
|
||||
"preserve reasoning trace in the full history, not just the last assistant message (default: template default)\n"
|
||||
"compatible with certain templates having 'supports_preserve_reasoning' capability\n"
|
||||
"example: https://docs.z.ai/guides/capabilities/thinking-mode#preserved-thinking",
|
||||
[](common_params & params, bool value) {
|
||||
if (value) {
|
||||
params.default_template_kwargs["preserve_reasoning"] = "true";
|
||||
} else {
|
||||
params.default_template_kwargs["preserve_reasoning"] = "false";
|
||||
}
|
||||
}
|
||||
).set_examples({LLAMA_EXAMPLE_SERVER, LLAMA_EXAMPLE_COMPLETION, LLAMA_EXAMPLE_CLI}).set_env("LLAMA_ARG_REASONING_PRESERVE"));
|
||||
add_opt(common_arg(
|
||||
{"--chat-template"}, "JINJA_TEMPLATE",
|
||||
string_format(
|
||||
@@ -3471,7 +3485,7 @@ common_params_context common_params_parser_init(common_params & params, llama_ex
|
||||
[](common_params & params) {
|
||||
params.offline = true;
|
||||
}
|
||||
).set_env("LLAMA_ARG_OFFLINE"));
|
||||
).set_examples({LLAMA_EXAMPLE_COMMON, LLAMA_EXAMPLE_DOWNLOAD}).set_env("LLAMA_ARG_OFFLINE"));
|
||||
add_opt(common_arg(
|
||||
{"-lv", "--verbosity", "--log-verbosity"}, "N",
|
||||
string_format("Set the verbosity threshold. Messages with a higher verbosity will be ignored. Values:\n"
|
||||
|
||||
+155
@@ -912,6 +912,10 @@ static std::string common_chat_template_direct_apply_impl(
|
||||
if (inputs.add_generation_prompt) {
|
||||
inp["add_generation_prompt"] = true;
|
||||
}
|
||||
if (inp.contains("preserve_reasoning") && inp["preserve_reasoning"].is_boolean()) {
|
||||
bool enabled = inp["preserve_reasoning"].get<bool>();
|
||||
jinja::caps_apply_preserve_reasoning(ctx, enabled);
|
||||
}
|
||||
|
||||
jinja::global_from_json(ctx, inp, inputs.mark_input);
|
||||
|
||||
@@ -2376,6 +2380,149 @@ static void func_args_not_string(json & messages) {
|
||||
|
||||
}
|
||||
|
||||
// MiniCPM5 format:
|
||||
// - Reasoning: <think>{reasoning}</think> (optional)
|
||||
// - Tool calls: <function name="foo"><param name="bar">value</param></function>
|
||||
static common_chat_params common_chat_params_init_minicpm5(const common_chat_template & tmpl,
|
||||
const autoparser::generation_params & inputs) {
|
||||
common_chat_params data;
|
||||
|
||||
data.prompt = common_chat_template_direct_apply_impl(tmpl, inputs);
|
||||
data.generation_prompt = common_chat_template_generation_prompt_impl(tmpl, inputs);
|
||||
data.format = COMMON_CHAT_FORMAT_PEG_NATIVE;
|
||||
data.supports_thinking = true;
|
||||
data.preserved_tokens = {
|
||||
"<function",
|
||||
"<param",
|
||||
"</function>",
|
||||
"</param>",
|
||||
"<think>",
|
||||
"</think>",
|
||||
};
|
||||
|
||||
data.thinking_start_tag = "<think>";
|
||||
data.thinking_end_tag = "</think>";
|
||||
|
||||
data.message_delimiters = {
|
||||
{ COMMON_CHAT_ROLE_ASSISTANT, "<|im_start|>assistant" },
|
||||
{ COMMON_CHAT_ROLE_TOOL, "<|im_start|>user\n<tool_response>" },
|
||||
{ COMMON_CHAT_ROLE_USER, "<|im_start|>user" },
|
||||
{ COMMON_CHAT_ROLE_SYSTEM, "<|im_start|>system" },
|
||||
};
|
||||
|
||||
auto has_tools = inputs.tools.is_array() && !inputs.tools.empty();
|
||||
auto has_response_format = inputs.json_schema.is_object() && !inputs.json_schema.empty();
|
||||
auto extract_reasoning = inputs.reasoning_format != COMMON_REASONING_FORMAT_NONE;
|
||||
auto include_grammar = has_response_format || (has_tools && inputs.tool_choice != COMMON_CHAT_TOOL_CHOICE_NONE);
|
||||
|
||||
if (inputs.has_continuation()) {
|
||||
const auto & msg = inputs.continue_msg;
|
||||
|
||||
data.generation_prompt = "<|im_start|>assistant\n<think>\n" + msg.reasoning_content;
|
||||
if (inputs.continue_final_message == COMMON_CHAT_CONTINUATION_CONTENT) {
|
||||
data.generation_prompt += "\n</think>\n\n" + msg.render_content();
|
||||
}
|
||||
|
||||
data.prompt += data.generation_prompt;
|
||||
}
|
||||
|
||||
auto parser = build_chat_peg_parser([&](common_chat_peg_builder & p) {
|
||||
auto generation_prompt = p.literal("<|im_start|>assistant\n");
|
||||
|
||||
auto reasoning = p.eps();
|
||||
if (extract_reasoning) {
|
||||
reasoning = ("<think>" << p.reasoning(p.until("</think>")) << "</think>") + p.space();
|
||||
}
|
||||
|
||||
// Response format parser
|
||||
if (has_response_format) {
|
||||
return generation_prompt + reasoning + p.content(p.schema(p.json(), "response-format", inputs.json_schema));
|
||||
}
|
||||
|
||||
if (has_tools && inputs.tool_choice != COMMON_CHAT_TOOL_CHOICE_NONE) {
|
||||
// CDATA lets a value carry characters that would otherwise close the tag (e.g.
|
||||
// </param>); capture the inner text only, excluding the CDATA markers.
|
||||
auto string_value = p.choice({
|
||||
p.literal("<![CDATA[") + p.ac(p.tool_arg_string_value(p.until("]]>")) + p.literal("]]>"), "]]>") + p.tool_arg_close(p.literal("</param>")),
|
||||
p.negate(p.literal("< {
|
||||
const auto & function = tool.at("function");
|
||||
const std::string name = function.at("name");
|
||||
auto params = function.contains("parameters") ? function.at("parameters") : json::object();
|
||||
|
||||
auto args = p.eps();
|
||||
if (params.contains("properties") && params.at("properties").is_object() && !params.at("properties").empty()) {
|
||||
auto schema_info = common_schema_info();
|
||||
schema_info.resolve_refs(params);
|
||||
|
||||
auto arg_choice = p.choice();
|
||||
for (const auto & [prop_name, prop_schema] : params.at("properties").items()) {
|
||||
auto value_parser = p.eps();
|
||||
if (schema_info.resolves_to_string(prop_schema)) {
|
||||
value_parser = string_value;
|
||||
} else {
|
||||
value_parser = p.tool_arg_json_value(
|
||||
p.schema(p.json(), "tool-" + name + "-arg-" + prop_name + "-schema", prop_schema, false)
|
||||
) + p.tool_arg_close(p.literal("</param>"));
|
||||
}
|
||||
|
||||
auto arg_rule = p.tool_arg(
|
||||
p.tool_arg_open(p.literal("<param name=\"") + p.tool_arg_name(p.literal(prop_name)) + p.literal("\">")) +
|
||||
value_parser
|
||||
);
|
||||
|
||||
arg_choice |= arg_rule;
|
||||
}
|
||||
args = p.zero_or_more(arg_choice + p.space());
|
||||
}
|
||||
|
||||
auto tool_parser = p.tool(
|
||||
p.tool_open(p.literal("<function name=\"") + p.tool_name(p.literal(name)) + p.literal("\">"))
|
||||
<< p.tool_args(args)
|
||||
<< p.tool_close(p.literal("</function>")));
|
||||
|
||||
tool_choice |= p.rule("tool-" + name, tool_parser);
|
||||
});
|
||||
|
||||
auto max_calls = inputs.parallel_tool_calls ? -1 : 1;
|
||||
auto tool_calls = p.trigger_rule("tool-call", p.repeat(tool_choice + p.space(), 1, max_calls));
|
||||
|
||||
auto content = p.content(p.until("<function"));
|
||||
|
||||
return generation_prompt + reasoning + content + tool_calls + p.end();
|
||||
}
|
||||
|
||||
return generation_prompt + reasoning + p.content(p.rest()) + p.end();
|
||||
});
|
||||
|
||||
data.parser = parser.save();
|
||||
|
||||
if (include_grammar) {
|
||||
data.grammar_lazy = !(has_response_format || (has_tools && inputs.tool_choice == COMMON_CHAT_TOOL_CHOICE_REQUIRED));
|
||||
data.grammar = build_grammar([&](const common_grammar_builder & builder) {
|
||||
foreach_function(inputs.tools, [&](const json & tool) {
|
||||
const auto & function = tool.at("function");
|
||||
auto schema = function.contains("parameters") ? function.at("parameters") : json::object();
|
||||
builder.resolve_refs(schema);
|
||||
});
|
||||
if (has_response_format) {
|
||||
auto schema = inputs.json_schema;
|
||||
builder.resolve_refs(schema);
|
||||
}
|
||||
parser.build_grammar(builder, data.grammar_lazy);
|
||||
});
|
||||
|
||||
data.grammar_triggers = {
|
||||
{ COMMON_GRAMMAR_TRIGGER_TYPE_WORD, "<function" },
|
||||
};
|
||||
}
|
||||
|
||||
return data;
|
||||
}
|
||||
|
||||
static json common_chat_extra_context() {
|
||||
json ctx = json::object();
|
||||
std::chrono::system_clock::time_point now = std::chrono::system_clock::now();
|
||||
@@ -2468,6 +2615,14 @@ std::optional<common_chat_params> common_chat_try_specialized_template(
|
||||
return common_chat_params_init_gemma4(tmpl, params);
|
||||
}
|
||||
|
||||
// MiniCPM5 - XML tool calls with <function name="..."><param name="...">...</param></function>
|
||||
if (src.find("Tool usage guidelines:") != std::string::npos &&
|
||||
src.find("<function name=\"") != std::string::npos &&
|
||||
src.find("<param name=\"") != std::string::npos) {
|
||||
LOG_DBG("Using specialized template: MiniCPM5\n");
|
||||
return common_chat_params_init_minicpm5(tmpl, params);
|
||||
}
|
||||
|
||||
return std::nullopt;
|
||||
}
|
||||
|
||||
|
||||
+2
-1
@@ -169,6 +169,7 @@ enum common_speculative_type {
|
||||
COMMON_SPECULATIVE_TYPE_DRAFT_SIMPLE, // standalone draft model speculative decoding
|
||||
COMMON_SPECULATIVE_TYPE_DRAFT_EAGLE3, // Eagle3 speculative decoding
|
||||
COMMON_SPECULATIVE_TYPE_DRAFT_MTP, // Multi-token prediction
|
||||
COMMON_SPECULATIVE_TYPE_DRAFT_DFLASH, // DFlash speculative decoding
|
||||
COMMON_SPECULATIVE_TYPE_NGRAM_SIMPLE, // simple self-speculative decoding based on n-grams
|
||||
COMMON_SPECULATIVE_TYPE_NGRAM_MAP_K, // self-speculative decoding with n-gram keys only
|
||||
COMMON_SPECULATIVE_TYPE_NGRAM_MAP_K4V, // self-speculative decoding with n-gram keys and 4 m-gram values
|
||||
@@ -384,7 +385,7 @@ struct common_params_speculative {
|
||||
|
||||
uint32_t need_n_rs_seq() const {
|
||||
bool needs_rs_seq = std::any_of(types.begin(), types.end(), [&](auto t) {
|
||||
return t == COMMON_SPECULATIVE_TYPE_DRAFT_MTP || t == COMMON_SPECULATIVE_TYPE_DRAFT_EAGLE3;
|
||||
return t == COMMON_SPECULATIVE_TYPE_DRAFT_MTP || t == COMMON_SPECULATIVE_TYPE_DRAFT_EAGLE3 || t == COMMON_SPECULATIVE_TYPE_DRAFT_DFLASH;
|
||||
});
|
||||
|
||||
return needs_rs_seq ? draft.n_max : 0u;
|
||||
|
||||
+44
-23
@@ -16,22 +16,34 @@ using json = nlohmann::ordered_json;
|
||||
namespace jinja {
|
||||
|
||||
using caps_json_fn = std::function<json()>;
|
||||
using caps_analyze_fn = std::function<void(bool, value &, value &)>;
|
||||
using caps_ctx_fn = std::function<void(context &)>;
|
||||
using caps_analyze_fn = std::function<void(bool, value &, value &, const std::string &)>;
|
||||
|
||||
// Publishes the "preserve reasoning" switch into a template context.
//
// ctx:     context that receives the variables
// enabled: whether reasoning should be preserved
//
// Three variables are written: "preserve_thinking" receives `enabled` as-is,
// while "clear_thinking" and "truncate_history_thinking" receive its negation.
void caps_apply_preserve_reasoning(jinja::context & ctx, bool enabled) {
    const bool discard = !enabled;

    ctx.set_val("preserve_thinking",         mk_val<value_bool>(enabled));
    ctx.set_val("clear_thinking",            mk_val<value_bool>(discard));
    ctx.set_val("truncate_history_thinking", mk_val<value_bool>(discard));
}
|
||||
|
||||
static void caps_try_execute(jinja::program & prog,
|
||||
const caps_json_fn & messages_fn,
|
||||
const caps_ctx_fn & ctx_fn,
|
||||
const caps_json_fn & tools_fn,
|
||||
const caps_analyze_fn & analyze_fn) {
|
||||
context ctx;
|
||||
ctx.is_get_stats = true;
|
||||
jinja::global_from_json(ctx, json{
|
||||
{"messages", messages_fn()},
|
||||
{"tools", tools_fn()},
|
||||
{"tools", tools_fn ? tools_fn() : json::array()},
|
||||
{"bos_token", ""},
|
||||
{"eos_token", ""},
|
||||
{"add_generation_prompt", true}
|
||||
}, true);
|
||||
|
||||
if (ctx_fn) {
|
||||
ctx_fn(ctx);
|
||||
}
|
||||
|
||||
auto messages = ctx.get_val("messages");
|
||||
auto tools = ctx.get_val("tools");
|
||||
|
||||
@@ -49,7 +61,7 @@ static void caps_try_execute(jinja::program & prog,
|
||||
// ignore exceptions during capability analysis
|
||||
}
|
||||
|
||||
analyze_fn(success, messages, tools);
|
||||
analyze_fn(success, messages, tools, result);
|
||||
}
|
||||
|
||||
// for debugging only
|
||||
@@ -109,11 +121,9 @@ caps caps_get(jinja::program & prog) {
|
||||
}
|
||||
});
|
||||
},
|
||||
[&]() {
|
||||
// tools
|
||||
return json{nullptr};
|
||||
},
|
||||
[&](bool success, value & messages, value &) {
|
||||
nullptr, // ctx_fn
|
||||
nullptr, // tools_fn
|
||||
[&](bool success, value & messages, value &, const std::string &) {
|
||||
auto & content = messages->at(0)->at("content");
|
||||
caps_print_stats(content, "messages[0].content");
|
||||
if (has_op(content, "selectattr") || has_op(content, "array_access")) {
|
||||
@@ -145,11 +155,9 @@ caps caps_get(jinja::program & prog) {
|
||||
},
|
||||
});
|
||||
},
|
||||
[&]() {
|
||||
// tools
|
||||
return json::array();
|
||||
},
|
||||
[&](bool, value & messages, value &) {
|
||||
nullptr, // ctx_fn
|
||||
nullptr, // tools_fn
|
||||
[&](bool, value & messages, value &, const std::string &) {
|
||||
auto & content = messages->at(0)->at("content");
|
||||
caps_print_stats(content, "messages[0].content");
|
||||
if (!content->stats.used) {
|
||||
@@ -201,6 +209,7 @@ caps caps_get(jinja::program & prog) {
|
||||
},
|
||||
});
|
||||
},
|
||||
nullptr, // ctx_fn
|
||||
[&]() {
|
||||
// tools
|
||||
return json::array({
|
||||
@@ -224,7 +233,7 @@ caps caps_get(jinja::program & prog) {
|
||||
},
|
||||
});
|
||||
},
|
||||
[&](bool success, value & messages, value & tools) {
|
||||
[&](bool success, value & messages, value & tools, const std::string &) {
|
||||
if (!success) {
|
||||
return; // Nothing can be inferred
|
||||
}
|
||||
@@ -293,6 +302,7 @@ caps caps_get(jinja::program & prog) {
|
||||
},
|
||||
});
|
||||
},
|
||||
nullptr, // ctx_fn
|
||||
[&]() {
|
||||
// tools
|
||||
return json::array({
|
||||
@@ -316,7 +326,7 @@ caps caps_get(jinja::program & prog) {
|
||||
},
|
||||
});
|
||||
},
|
||||
[&](bool success, value & messages, value & tools) {
|
||||
[&](bool success, value & messages, value & tools, const std::string &) {
|
||||
if (!success) {
|
||||
result.supports_tool_calls = false;
|
||||
result.supports_tools = false;
|
||||
@@ -394,6 +404,7 @@ caps caps_get(jinja::program & prog) {
|
||||
},
|
||||
});
|
||||
},
|
||||
nullptr, // ctx_fn
|
||||
[&]() {
|
||||
// tools
|
||||
return json::array({
|
||||
@@ -417,7 +428,7 @@ caps caps_get(jinja::program & prog) {
|
||||
},
|
||||
});
|
||||
},
|
||||
[&](bool success, value & messages, value & /*tools*/) {
|
||||
[&](bool success, value & messages, value &, const std::string &) {
|
||||
if (!success) {
|
||||
result.supports_parallel_tool_calls = false;
|
||||
return;
|
||||
@@ -438,11 +449,22 @@ caps caps_get(jinja::program & prog) {
|
||||
JJ_DEBUG("%s\n", ">>> Running capability check: preserve reasoning");
|
||||
|
||||
// case: preserve reasoning content in chat history
|
||||
const std::string reasoning_placeholder = "<REASONING_CONTENT_PLACEHOLDER>";
|
||||
caps_try_execute(
|
||||
prog,
|
||||
[&]() {
|
||||
// messages
|
||||
return json::array({
|
||||
{
|
||||
{"role", "user"},
|
||||
{"content", "User message"}
|
||||
},
|
||||
{
|
||||
{"role", "assistant"},
|
||||
{"content", "Assistant message"},
|
||||
// check of reasoning_content deeper in the history, not just the last assistant message
|
||||
{"reasoning_content", reasoning_placeholder}
|
||||
},
|
||||
{
|
||||
{"role", "user"},
|
||||
{"content", "User message"}
|
||||
@@ -458,14 +480,13 @@ caps caps_get(jinja::program & prog) {
|
||||
},
|
||||
});
|
||||
},
|
||||
[&]() {
|
||||
// tools
|
||||
return json::array();
|
||||
[&](context & ctx) {
|
||||
caps_apply_preserve_reasoning(ctx, true);
|
||||
},
|
||||
[&](bool, value & messages, value &) {
|
||||
auto & content = messages->at(1)->at("reasoning_content");
|
||||
caps_print_stats(content, "messages[1].reasoning_content");
|
||||
if (content->stats.used) {
|
||||
nullptr, // tools_fn
|
||||
[&](bool, value &, value &, const std::string & output) {
|
||||
// note: we cannot use stats here because the reasoning_content may be used for "if" condition test, but not actually outputted in the final result
|
||||
if (output.find(reasoning_placeholder) != std::string::npos) {
|
||||
result.supports_preserve_reasoning = true;
|
||||
}
|
||||
}
|
||||
|
||||
+5
-1
@@ -12,7 +12,9 @@ struct caps {
|
||||
bool supports_tool_calls = true;
|
||||
bool supports_system_role = true;
|
||||
bool supports_parallel_tool_calls = true;
|
||||
bool supports_preserve_reasoning = false; // support assistant message with reasoning_content
|
||||
|
||||
// supports preserve reasoning trace in the full history, not just the last assistant message
|
||||
bool supports_preserve_reasoning = false;
|
||||
|
||||
// one of the 2 content capabilities must be true
|
||||
bool supports_string_content = true;
|
||||
@@ -29,4 +31,6 @@ struct caps {
|
||||
|
||||
caps caps_get(jinja::program & prog);
|
||||
|
||||
void caps_apply_preserve_reasoning(jinja::context & ctx, bool enabled);
|
||||
|
||||
} // namespace jinja
|
||||
|
||||
@@ -954,4 +954,50 @@ value keyword_argument_expression::execute_impl(context & ctx) {
|
||||
return mk_val<value_kwarg>(k, v);
|
||||
}
|
||||
|
||||
std::string runtime::debug_dump_program(const program & prog, const std::string & src) {
|
||||
std::ostringstream oss;
|
||||
size_t lvl = 0;
|
||||
context ctx;
|
||||
ctx.src.reset(new std::string(src));
|
||||
|
||||
auto indent = [](size_t lvl) -> std::string {
|
||||
return std::string(lvl * 2, ' ');
|
||||
};
|
||||
|
||||
ctx.visitor = [&](bool is_leaf, statement * node, std::vector<visitor_pair> children) {
|
||||
oss << indent(lvl) << node->type() << ":\n";
|
||||
lvl++;
|
||||
if (is_leaf) {
|
||||
const auto & pos = node->pos;
|
||||
oss << indent(lvl) << "(leaf) at " << get_line_col(src, pos) << " in source:\n";
|
||||
std::string snippet = peak_source(src, pos);
|
||||
string_replace_all(snippet, "\n", "\n" + indent(lvl));
|
||||
oss << indent(lvl) << snippet << "\n";
|
||||
} else {
|
||||
for (auto & [label, children_vec] : children) {
|
||||
oss << indent(lvl) << label << ":\n";
|
||||
lvl++;
|
||||
if (children_vec.empty()) {
|
||||
oss << indent(lvl) << "<empty>\n\n";
|
||||
} else {
|
||||
for (auto * child : children_vec) {
|
||||
if (!child) {
|
||||
continue;
|
||||
}
|
||||
child->visit(ctx);
|
||||
}
|
||||
}
|
||||
lvl--;
|
||||
}
|
||||
}
|
||||
lvl--;
|
||||
};
|
||||
|
||||
for (const auto & stmt : prog.body) {
|
||||
stmt->visit(ctx);
|
||||
}
|
||||
|
||||
return oss.str();
|
||||
}
|
||||
|
||||
} // namespace jinja
|
||||
|
||||
@@ -47,12 +47,19 @@ const T * cast_stmt(const statement_ptr & ptr) {
|
||||
// not thread-safe
|
||||
void enable_debug(bool enable);
|
||||
|
||||
// for visiting AST nodes
|
||||
// function signature: void(bool is_leaf, statement * node, pair of <label, children>)
|
||||
using visitor_pair = std::pair<std::string, std::vector<statement *>>;
|
||||
using visitor_fn = std::function<void(bool, statement *, std::vector<visitor_pair>)>;
|
||||
|
||||
struct context {
|
||||
std::shared_ptr<std::string> src; // for debugging; use shared_ptr to avoid copying on scope creation
|
||||
std::time_t current_time; // for functions that need current time
|
||||
|
||||
bool is_get_stats = false; // whether to collect stats
|
||||
|
||||
visitor_fn visitor;
|
||||
|
||||
// src is optional, used for error reporting
|
||||
context(std::string src = "") : src(std::make_shared<std::string>(std::move(src))) {
|
||||
env = mk_val<value_object>();
|
||||
@@ -99,6 +106,15 @@ private:
|
||||
value_object env;
|
||||
};
|
||||
|
||||
// utils for visiting AST nodes
|
||||
static std::vector<statement *> stmts_to_ptr(const statements & stmts) {
|
||||
std::vector<statement *> children;
|
||||
for (const auto & stmt : stmts) {
|
||||
children.push_back(stmt.get());
|
||||
}
|
||||
return children;
|
||||
}
|
||||
|
||||
/**
|
||||
* Base class for all nodes in the AST.
|
||||
*/
|
||||
@@ -106,6 +122,7 @@ struct statement {
|
||||
size_t pos; // position in source, for debugging
|
||||
virtual ~statement() = default;
|
||||
virtual std::string type() const { return "Statement"; }
|
||||
virtual void visit(context & ctx) { ctx.visitor(true, this, {}); }
|
||||
|
||||
// execute_impl must be overridden by derived classes
|
||||
virtual value execute_impl(context &) { throw_exec_error(); }
|
||||
@@ -166,6 +183,13 @@ struct if_statement : public statement {
|
||||
|
||||
std::string type() const override { return "If"; }
|
||||
value execute_impl(context & ctx) override;
|
||||
void visit(context & ctx) override {
|
||||
ctx.visitor(false, this, {
|
||||
{"test", {test.get()}},
|
||||
{"body", stmts_to_ptr(body)},
|
||||
{"alternate", stmts_to_ptr(alternate)}
|
||||
});
|
||||
}
|
||||
};
|
||||
|
||||
struct identifier;
|
||||
@@ -190,6 +214,14 @@ struct for_statement : public statement {
|
||||
|
||||
std::string type() const override { return "For"; }
|
||||
value execute_impl(context & ctx) override;
|
||||
void visit(context & ctx) override {
|
||||
ctx.visitor(false, this, {
|
||||
{"loopvar", {loopvar.get()}},
|
||||
{"iterable", {iterable.get()}},
|
||||
{"body", stmts_to_ptr(body)},
|
||||
{"default_block", stmts_to_ptr(default_block)}
|
||||
});
|
||||
}
|
||||
};
|
||||
|
||||
struct break_statement : public statement {
|
||||
@@ -241,6 +273,13 @@ struct set_statement : public statement {
|
||||
|
||||
std::string type() const override { return "Set"; }
|
||||
value execute_impl(context & ctx) override;
|
||||
void visit(context & ctx) override {
|
||||
ctx.visitor(false, this, {
|
||||
{"assignee", {assignee.get()}},
|
||||
{"value", {val.get()}},
|
||||
{"body", stmts_to_ptr(body)}
|
||||
});
|
||||
}
|
||||
};
|
||||
|
||||
struct macro_statement : public statement {
|
||||
@@ -256,6 +295,13 @@ struct macro_statement : public statement {
|
||||
|
||||
std::string type() const override { return "Macro"; }
|
||||
value execute_impl(context & ctx) override;
|
||||
void visit(context & ctx) override {
|
||||
ctx.visitor(false, this, {
|
||||
{"name", {name.get()}},
|
||||
{"args", stmts_to_ptr(args)},
|
||||
{"body", stmts_to_ptr(body)}
|
||||
});
|
||||
}
|
||||
};
|
||||
|
||||
struct comment_statement : public statement {
|
||||
@@ -289,6 +335,12 @@ struct member_expression : public expression {
|
||||
}
|
||||
std::string type() const override { return "MemberExpression"; }
|
||||
value execute_impl(context & ctx) override;
|
||||
void visit(context & ctx) override {
|
||||
ctx.visitor(false, this, {
|
||||
{"object", {object.get()}},
|
||||
{"property", {property.get()}}
|
||||
});
|
||||
}
|
||||
};
|
||||
|
||||
struct call_expression : public expression {
|
||||
@@ -302,6 +354,12 @@ struct call_expression : public expression {
|
||||
}
|
||||
std::string type() const override { return "CallExpression"; }
|
||||
value execute_impl(context & ctx) override;
|
||||
void visit(context & ctx) override {
|
||||
ctx.visitor(false, this, {
|
||||
{"callee", {callee.get()}},
|
||||
{"args", stmts_to_ptr(args)}
|
||||
});
|
||||
}
|
||||
};
|
||||
|
||||
/**
|
||||
@@ -405,6 +463,12 @@ struct binary_expression : public expression {
|
||||
}
|
||||
std::string type() const override { return "BinaryExpression"; }
|
||||
value execute_impl(context & ctx) override;
|
||||
void visit(context & ctx) override {
|
||||
ctx.visitor(false, this, {
|
||||
{"left", {left.get()}},
|
||||
{"right", {right.get()}}
|
||||
});
|
||||
}
|
||||
};
|
||||
|
||||
/**
|
||||
@@ -431,6 +495,12 @@ struct filter_expression : public expression {
|
||||
|
||||
std::string type() const override { return "FilterExpression"; }
|
||||
value execute_impl(context & ctx) override;
|
||||
void visit(context & ctx) override {
|
||||
ctx.visitor(false, this, {
|
||||
{"operand", {operand.get()}},
|
||||
{"filter", {filter.get()}}
|
||||
});
|
||||
}
|
||||
};
|
||||
|
||||
struct filter_statement : public statement {
|
||||
@@ -443,6 +513,12 @@ struct filter_statement : public statement {
|
||||
}
|
||||
std::string type() const override { return "FilterStatement"; }
|
||||
value execute_impl(context & ctx) override;
|
||||
void visit(context & ctx) override {
|
||||
ctx.visitor(false, this, {
|
||||
{"filter", {filter.get()}},
|
||||
{"body", stmts_to_ptr(body)}
|
||||
});
|
||||
}
|
||||
};
|
||||
|
||||
/**
|
||||
@@ -468,6 +544,12 @@ struct select_expression : public expression {
|
||||
}
|
||||
return lhs->execute_impl(ctx);
|
||||
}
|
||||
void visit(context & ctx) override {
|
||||
ctx.visitor(false, this, {
|
||||
{"lhs", {lhs.get()}},
|
||||
{"test", {test.get()}}
|
||||
});
|
||||
}
|
||||
};
|
||||
|
||||
/**
|
||||
@@ -486,6 +568,12 @@ struct test_expression : public expression {
|
||||
}
|
||||
std::string type() const override { return "TestExpression"; }
|
||||
value execute_impl(context & ctx) override;
|
||||
void visit(context & ctx) override {
|
||||
ctx.visitor(false, this, {
|
||||
{"operand", {operand.get()}},
|
||||
{"test", {test.get()}}
|
||||
});
|
||||
}
|
||||
};
|
||||
|
||||
/**
|
||||
@@ -501,6 +589,11 @@ struct unary_expression : public expression {
|
||||
}
|
||||
std::string type() const override { return "UnaryExpression"; }
|
||||
value execute_impl(context & ctx) override;
|
||||
void visit(context & ctx) override {
|
||||
ctx.visitor(false, this, {
|
||||
{"argument", {argument.get()}}
|
||||
});
|
||||
}
|
||||
};
|
||||
|
||||
struct slice_expression : public expression {
|
||||
@@ -518,6 +611,13 @@ struct slice_expression : public expression {
|
||||
[[noreturn]] value execute_impl(context &) override {
|
||||
throw std::runtime_error("must be handled by MemberExpression");
|
||||
}
|
||||
void visit(context & ctx) override {
|
||||
ctx.visitor(false, this, {
|
||||
{"start_expr", {start_expr.get()}},
|
||||
{"stop_expr", {stop_expr.get()}},
|
||||
{"step_expr", {step_expr.get()}}
|
||||
});
|
||||
}
|
||||
};
|
||||
|
||||
struct keyword_argument_expression : public expression {
|
||||
@@ -531,6 +631,12 @@ struct keyword_argument_expression : public expression {
|
||||
}
|
||||
std::string type() const override { return "KeywordArgumentExpression"; }
|
||||
value execute_impl(context & ctx) override;
|
||||
void visit(context & ctx) override {
|
||||
ctx.visitor(false, this, {
|
||||
{"key", {key.get()}},
|
||||
{"val", {val.get()}}
|
||||
});
|
||||
}
|
||||
};
|
||||
|
||||
struct spread_expression : public expression {
|
||||
@@ -539,6 +645,11 @@ struct spread_expression : public expression {
|
||||
chk_type<expression>(this->argument);
|
||||
}
|
||||
std::string type() const override { return "SpreadExpression"; }
|
||||
void visit(context & ctx) override {
|
||||
ctx.visitor(false, this, {
|
||||
{"argument", {argument.get()}}
|
||||
});
|
||||
}
|
||||
};
|
||||
|
||||
struct call_statement : public statement {
|
||||
@@ -553,6 +664,13 @@ struct call_statement : public statement {
|
||||
}
|
||||
std::string type() const override { return "CallStatement"; }
|
||||
value execute_impl(context & ctx) override;
|
||||
void visit(context & ctx) override {
|
||||
ctx.visitor(false, this, {
|
||||
{"call", {call.get()}},
|
||||
{"caller_args", stmts_to_ptr(caller_args)},
|
||||
{"body", stmts_to_ptr(body)}
|
||||
});
|
||||
}
|
||||
};
|
||||
|
||||
struct ternary_expression : public expression {
|
||||
@@ -575,6 +693,13 @@ struct ternary_expression : public expression {
|
||||
return false_expr->execute(ctx);
|
||||
}
|
||||
}
|
||||
void visit(context & ctx) override {
|
||||
ctx.visitor(false, this, {
|
||||
{"condition", {condition.get()}},
|
||||
{"true_expr", {true_expr.get()}},
|
||||
{"false_expr", {false_expr.get()}}
|
||||
});
|
||||
}
|
||||
};
|
||||
|
||||
struct raised_exception : public std::exception {
|
||||
@@ -648,6 +773,8 @@ struct runtime {
|
||||
}
|
||||
return parts;
|
||||
}
|
||||
|
||||
static std::string debug_dump_program(const program & prog, const std::string & src);
|
||||
};
|
||||
|
||||
} // namespace jinja
|
||||
|
||||
@@ -1108,6 +1108,50 @@ const func_builtins & value_array_t::get_builtins() const {
|
||||
std::reverse(arr.begin(), arr.end());
|
||||
return is_val<value_tuple>(val) ? mk_val<value_tuple>(std::move(arr)) : mk_val<value_array>(std::move(arr));
|
||||
}},
|
||||
{"min", [](const func_args & args) -> value {
    // min(array[, case_sensitive[, attribute]]): returns the smallest element.
    // - an empty array yields an undefined value
    // - passing `attribute` throws not_implemented_exception
    // - on ties the earliest element wins (strict less-than comparison)
    args.ensure_count(1, 4);
    args.ensure_vals<value_array>();

    value val_case = args.get_kwarg_or_pos("case_sensitive", 1);
    value attribute = args.get_kwarg_or_pos("attribute", 2);
    if (!attribute->is_undefined()) {
        throw not_implemented_exception("min: attribute not implemented");
    }
    // FIXME: min is currently always case sensitive
    (void) val_case;

    const auto & items = args.get_pos(0)->as_array();
    if (items.empty()) {
        return mk_val<value_undefined>();
    }

    size_t best = 0; // index of the smallest element seen so far
    for (size_t idx = 1; idx < items.size(); ++idx) {
        if (value_compare(items[idx], items[best], value_compare_op::lt)) {
            best = idx;
        }
    }
    return items[best];
}},
|
||||
{"max", [](const func_args & args) -> value {
    // max(array[, case_sensitive[, attribute]]): returns the largest element.
    // - an empty array yields an undefined value
    // - passing `attribute` throws not_implemented_exception
    // - on ties the earliest element wins (strict greater-than comparison)
    args.ensure_count(1, 4);
    args.ensure_vals<value_array>();

    value val_case = args.get_kwarg_or_pos("case_sensitive", 1);
    value attribute = args.get_kwarg_or_pos("attribute", 2);
    if (!attribute->is_undefined()) {
        throw not_implemented_exception("max: attribute not implemented");
    }
    // FIXME: max is currently always case sensitive
    (void) val_case;

    const auto & items = args.get_pos(0)->as_array();
    if (items.empty()) {
        return mk_val<value_undefined>();
    }

    size_t best = 0; // index of the largest element seen so far
    for (size_t idx = 1; idx < items.size(); ++idx) {
        if (value_compare(items[idx], items[best], value_compare_op::gt)) {
            best = idx;
        }
    }
    return items[best];
}},
|
||||
{"unique", array_unique_not_implemented},
|
||||
};
|
||||
return builtins;
|
||||
|
||||
+302
-1
@@ -33,6 +33,7 @@ const std::map<std::string, common_speculative_type> common_speculative_type_fro
|
||||
{"draft-simple", COMMON_SPECULATIVE_TYPE_DRAFT_SIMPLE},
|
||||
{"draft-eagle3", COMMON_SPECULATIVE_TYPE_DRAFT_EAGLE3},
|
||||
{"draft-mtp", COMMON_SPECULATIVE_TYPE_DRAFT_MTP},
|
||||
{"draft-dflash", COMMON_SPECULATIVE_TYPE_DRAFT_DFLASH},
|
||||
{"ngram-simple", COMMON_SPECULATIVE_TYPE_NGRAM_SIMPLE},
|
||||
{"ngram-map-k", COMMON_SPECULATIVE_TYPE_NGRAM_MAP_K},
|
||||
{"ngram-map-k4v", COMMON_SPECULATIVE_TYPE_NGRAM_MAP_K4V},
|
||||
@@ -898,6 +899,296 @@ struct common_speculative_impl_draft_eagle3 : public common_speculative_impl {
|
||||
}
|
||||
};
|
||||
|
||||
// DFlash: block-diffusion drafting with a draft-side KV cache injection
|
||||
struct common_speculative_impl_draft_dflash : public common_speculative_impl {
|
||||
common_params_speculative_draft params;
|
||||
|
||||
llama_batch batch; // noise tokens
|
||||
llama_batch batch_inject; // target features for KV cache injection
|
||||
|
||||
std::vector<common_sampler_ptr> smpls;
|
||||
|
||||
int32_t n_embd_dec = 0; // draft hidden size
|
||||
int32_t n_embd_enc = 0; // target_layer_ids_n * target_hidden_size
|
||||
int32_t n_embd_tgt = 0; // target model hidden size
|
||||
|
||||
int32_t block_size = 0;
|
||||
llama_token mask_token_id = 0;
|
||||
|
||||
const int32_t * target_layer_ids = nullptr; // model_dft's extract layer indices
|
||||
uint32_t target_layer_ids_n = 0;
|
||||
|
||||
// scratch buffer for concatenated target features [n_tokens, n_embd_enc]
|
||||
std::vector<float> features_buf;
|
||||
|
||||
common_speculative_impl_draft_dflash(const common_params_speculative & params, uint32_t n_seq)
|
||||
: common_speculative_impl(COMMON_SPECULATIVE_TYPE_DRAFT_DFLASH, n_seq)
|
||||
, params(params.draft)
|
||||
{
|
||||
auto * ctx_tgt = this->params.ctx_tgt;
|
||||
auto * ctx_dft = this->params.ctx_dft;
|
||||
GGML_ASSERT(ctx_tgt && ctx_dft && "DFlash requires ctx_tgt and ctx_dft to be set");
|
||||
|
||||
const llama_model * model_dft = llama_get_model(ctx_dft);
|
||||
const llama_model * model_tgt = llama_get_model(ctx_tgt);
|
||||
|
||||
target_layer_ids = llama_model_target_layer_ids (model_dft);
|
||||
target_layer_ids_n = llama_model_target_layer_ids_n(model_dft);
|
||||
GGML_ASSERT(target_layer_ids_n > 0 && "DFlash model has no target_layer_ids");
|
||||
|
||||
n_embd_tgt = llama_model_n_embd(model_tgt);
|
||||
n_embd_dec = llama_model_n_embd(model_dft);
|
||||
n_embd_enc = (int32_t) target_layer_ids_n * n_embd_tgt;
|
||||
|
||||
// read the trained block size from the dflash.block_size metadata key
|
||||
block_size = 16;
|
||||
{
|
||||
char buf[32] = {};
|
||||
if (llama_model_meta_val_str(model_dft, "dflash.block_size", buf, sizeof(buf)) >= 0) {
|
||||
block_size = std::atoi(buf);
|
||||
}
|
||||
}
|
||||
mask_token_id = llama_vocab_mask(llama_model_get_vocab(model_dft));
|
||||
|
||||
LOG_INF("%s: adding speculative implementation 'draft-dflash'\n", __func__);
|
||||
LOG_INF("%s: - n_max=%d, n_min=%d, p_min=%.2f\n", __func__, this->params.n_max, this->params.n_min, this->params.p_min);
|
||||
LOG_INF("%s: - block_size=%d, mask_token_id=%d, n_extract=%u\n", __func__, block_size, mask_token_id, target_layer_ids_n);
|
||||
|
||||
// DFlash input is [id_last, <mask> * (block_size-1)], so it can draft at most block_size-1 tokens per step
|
||||
if (this->params.n_max > block_size - 1) {
|
||||
LOG_WRN("%s: requested draft size %d exceeds the trained DFlash block size %d -- clamping to %d draft tokens per step\n",
|
||||
__func__, this->params.n_max, block_size - 1, block_size - 1);
|
||||
this->params.n_max = block_size - 1;
|
||||
}
|
||||
|
||||
batch = llama_batch_init(llama_n_batch(ctx_dft), 0, n_seq);
|
||||
batch_inject = llama_batch_init(llama_n_batch(ctx_dft), n_embd_dec, n_seq);
|
||||
|
||||
smpls.resize(n_seq);
|
||||
for (auto & s : smpls) {
|
||||
common_params_sampling sparams;
|
||||
sparams.no_perf = false;
|
||||
sparams.top_k = 1;
|
||||
sparams.samplers = { COMMON_SAMPLER_TYPE_TOP_K };
|
||||
s.reset(common_sampler_init(model_dft, sparams));
|
||||
}
|
||||
|
||||
// turn on extraction of the target layers' input embeddings
|
||||
for (uint32_t k = 0; k < target_layer_ids_n; ++k) {
|
||||
llama_set_embeddings_layer_inp(ctx_tgt, (uint32_t) target_layer_ids[k], true);
|
||||
}
|
||||
|
||||
llama_set_embeddings_nextn(ctx_dft, true, /*masked*/ true);
|
||||
llama_set_causal_attn(ctx_dft, false); // DFlash needs non-causal attention
|
||||
}
|
||||
|
||||
~common_speculative_impl_draft_dflash() override {
|
||||
llama_batch_free(batch);
|
||||
llama_batch_free(batch_inject);
|
||||
}
|
||||
|
||||
void begin(llama_seq_id seq_id, const llama_tokens & prompt) override {
|
||||
if (seq_id < 0 || seq_id >= (llama_seq_id) n_seq) {
|
||||
return;
|
||||
}
|
||||
|
||||
const int32_t N = (int32_t) prompt.size();
|
||||
if (N <= 0) {
|
||||
return;
|
||||
}
|
||||
|
||||
const llama_pos pos_max = llama_memory_seq_pos_max(llama_get_memory(params.ctx_dft), seq_id);
|
||||
if (pos_max < N - 1) {
|
||||
LOG_WRN("%s: ctx_dft pos_max=%d < N-1=%d - process() did not run on every prefill ubatch. "
|
||||
"Drafts may degrade.\n",
|
||||
__func__, (int) pos_max, N - 1);
|
||||
}
|
||||
}
|
||||
|
||||
bool process(const llama_batch & batch_in) override {
|
||||
if (batch_in.n_tokens <= 0) {
|
||||
return true;
|
||||
}
|
||||
|
||||
if (batch_in.token == nullptr || batch_in.embd != nullptr) {
|
||||
return true;
|
||||
}
|
||||
|
||||
const int32_t n_tokens = batch_in.n_tokens;
|
||||
|
||||
// per-seq inclusive batch range (assumes each seq's tokens are contiguous in the batch)
|
||||
std::vector<int32_t> i_batch_beg(n_seq, -1);
|
||||
std::vector<int32_t> i_batch_end(n_seq, -1);
|
||||
for (int32_t k = 0; k < n_tokens; ++k) {
|
||||
GGML_ASSERT(batch_in.n_seq_id[k] == 1);
|
||||
const llama_seq_id seq_id = batch_in.seq_id[k][0];
|
||||
if (seq_id < 0 || seq_id >= (llama_seq_id) n_seq) {
|
||||
continue;
|
||||
}
|
||||
i_batch_end[seq_id] = k;
|
||||
if (i_batch_beg[seq_id] < 0) {
|
||||
i_batch_beg[seq_id] = k;
|
||||
}
|
||||
}
|
||||
|
||||
auto * ctx_tgt = this->params.ctx_tgt;
|
||||
auto * ctx_dft = this->params.ctx_dft;
|
||||
|
||||
const int32_t n_ubatch = (int32_t) llama_n_ubatch(ctx_dft);
|
||||
|
||||
for (llama_seq_id seq_id = 0; seq_id < (llama_seq_id) n_seq; ++seq_id) {
|
||||
if (i_batch_beg[seq_id] < 0) {
|
||||
continue;
|
||||
}
|
||||
const int32_t n_rows = i_batch_end[seq_id] - i_batch_beg[seq_id] + 1;
|
||||
|
||||
for (int32_t offset = 0; offset < n_rows; offset += n_ubatch) {
|
||||
const int32_t n_chunk = std::min(n_ubatch, n_rows - offset);
|
||||
|
||||
// gather this chunk's target features, interleaved by extract layer
|
||||
features_buf.resize((size_t) n_chunk * n_embd_enc);
|
||||
for (uint32_t k = 0; k < target_layer_ids_n; ++k) {
|
||||
const float * layer = llama_get_embeddings_layer_inp(ctx_tgt, (uint32_t) target_layer_ids[k]);
|
||||
if (!layer) {
|
||||
GGML_ABORT("DFlash: target layer %d input not extracted.", target_layer_ids[k]);
|
||||
}
|
||||
for (int32_t i = 0; i < n_chunk; ++i) {
|
||||
float * dst = features_buf.data() + (size_t) i * n_embd_enc + k * (size_t) n_embd_tgt;
|
||||
const float * src = layer + (size_t) (i_batch_beg[seq_id] + offset + i) * n_embd_tgt;
|
||||
std::memcpy(dst, src, (size_t) n_embd_tgt * sizeof(float));
|
||||
}
|
||||
}
|
||||
|
||||
// fuse extracted features through DFlash encoder
|
||||
llama_batch enc_batch = {
|
||||
/*.n_tokens =*/ n_chunk,
|
||||
/*.token =*/ nullptr,
|
||||
/*.embd =*/ features_buf.data(),
|
||||
/*.pos =*/ nullptr,
|
||||
/*.n_seq_id =*/ nullptr,
|
||||
/*.seq_id =*/ nullptr,
|
||||
/*.logits =*/ nullptr,
|
||||
};
|
||||
|
||||
int32_t rc = llama_encode(ctx_dft, enc_batch);
|
||||
if (rc != 0) {
|
||||
LOG_ERR("%s: llama_encode(ctx_dft) failed rc=%d (n_tokens=%d, offset=%d)\n",
|
||||
__func__, rc, (int) n_chunk, (int) offset);
|
||||
return false;
|
||||
}
|
||||
|
||||
const float * inp_g = llama_get_embeddings_nextn(ctx_dft);
|
||||
GGML_ASSERT(inp_g && "DFlash encoder produced no output.");
|
||||
|
||||
// inject the DFlash decoder K/V cache at the tokens' target positions
|
||||
batch_inject.n_tokens = n_chunk;
|
||||
std::memcpy(batch_inject.embd, inp_g, (size_t) n_chunk * n_embd_dec * sizeof(float));
|
||||
|
||||
for (int32_t i = 0; i < n_chunk; ++i) {
|
||||
batch_inject.pos[i] = batch_in.pos[i_batch_beg[seq_id] + offset + i];
|
||||
batch_inject.n_seq_id[i] = 1;
|
||||
batch_inject.seq_id[i][0] = seq_id;
|
||||
batch_inject.logits[i] = false;
|
||||
}
|
||||
rc = llama_decode(ctx_dft, batch_inject);
|
||||
if (rc != 0) {
|
||||
LOG_ERR("%s: llama_decode(ctx_dft) failed rc=%d (n_tokens=%d, offset=%d)\n",
|
||||
__func__, rc, (int) n_chunk, (int) offset);
|
||||
return false;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
return true;
|
||||
}
|
||||
|
||||
void draft(common_speculative_draft_params_vec & dparams) override {
|
||||
auto & ctx_dft = params.ctx_dft;
|
||||
|
||||
common_batch_clear(batch);
|
||||
|
||||
// build one batch holding every drafting sequence's noise block into a single decode)
|
||||
// record where each block starts and its size
|
||||
std::vector<int32_t> i_block_beg(n_seq, -1);
|
||||
std::vector<int32_t> n_block (n_seq, 0);
|
||||
|
||||
for (llama_seq_id seq_id = 0; seq_id < (llama_seq_id) n_seq; ++seq_id) {
|
||||
auto & dp = dparams[seq_id];
|
||||
if (!dp.drafting) {
|
||||
continue;
|
||||
}
|
||||
|
||||
common_sampler_reset(smpls[seq_id].get());
|
||||
|
||||
const int32_t n = (int32_t) dp.n_past;
|
||||
|
||||
int32_t n_draft = params.n_max;
|
||||
if (dp.n_max > 0) {
|
||||
n_draft = std::min(n_draft, dp.n_max);
|
||||
}
|
||||
|
||||
const int32_t n_block_tokens = n_draft + 1; // id_last + n_draft * <mask>
|
||||
i_block_beg[seq_id] = batch.n_tokens;
|
||||
n_block [seq_id] = n_block_tokens;
|
||||
for (int32_t i = 0; i < n_block_tokens; ++i) {
|
||||
common_batch_add(batch, i == 0 ? dp.id_last : mask_token_id, n + i, { seq_id }, true);
|
||||
}
|
||||
}
|
||||
|
||||
if (batch.n_tokens == 0) {
|
||||
return;
|
||||
}
|
||||
|
||||
// decode all sequence's noise block in a single batch
|
||||
int ret = llama_decode(ctx_dft, batch);
|
||||
if (ret != 0) {
|
||||
LOG_WRN("%s: llama_decode returned %d\n", __func__, ret);
|
||||
return;
|
||||
}
|
||||
|
||||
for (llama_seq_id seq_id = 0; seq_id < (llama_seq_id) n_seq; ++seq_id) {
|
||||
if (i_block_beg[seq_id] < 0) {
|
||||
continue;
|
||||
}
|
||||
auto & dp = dparams[seq_id];
|
||||
|
||||
const int32_t beg = i_block_beg[seq_id];
|
||||
const int32_t n_block_tokens = n_block[seq_id];
|
||||
|
||||
auto * smpl = smpls[seq_id].get();
|
||||
|
||||
auto & result = *dp.result;
|
||||
|
||||
// greedily read the predicted block at this sequence's noise positions 1..n_block_tokens-1
|
||||
for (int32_t i = 1; i < n_block_tokens; ++i) {
|
||||
common_sampler_sample(smpl, ctx_dft, beg + i, true);
|
||||
|
||||
const auto * cur_p = common_sampler_get_candidates(smpl, true);
|
||||
|
||||
for (int k = 0; k < std::min(3, (int) cur_p->size); ++k) {
|
||||
LOG_DBG(" - seq_id %d, draft candidate %3d, pos %3d: %6d (%8.3f) '%s'\n",
|
||||
seq_id, k, i - 1, cur_p->data[k].id, cur_p->data[k].p,
|
||||
common_token_to_piece(ctx_dft, cur_p->data[k].id).c_str());
|
||||
}
|
||||
|
||||
const llama_token id = cur_p->data[0].id;
|
||||
|
||||
common_sampler_accept(smpl, id, true);
|
||||
|
||||
result.push_back(id);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
void accept(llama_seq_id /*seq_id*/, uint16_t /*n_accepted*/, bool /*is_other*/) override {
|
||||
// noop
|
||||
}
|
||||
|
||||
bool need_embd() const override {
|
||||
return false;
|
||||
}
|
||||
};
|
||||
|
||||
struct common_speculative_impl_draft_mtp : public common_speculative_impl {
|
||||
common_params_speculative_draft params; // reuses the draft-model params slot (ctx_tgt/ctx_dft)
|
||||
|
||||
@@ -1841,6 +2132,7 @@ std::string common_speculative_type_to_str(common_speculative_type type) {
|
||||
case COMMON_SPECULATIVE_TYPE_DRAFT_SIMPLE: return "draft-simple";
|
||||
case COMMON_SPECULATIVE_TYPE_DRAFT_EAGLE3: return "draft-eagle3";
|
||||
case COMMON_SPECULATIVE_TYPE_DRAFT_MTP: return "draft-mtp";
|
||||
case COMMON_SPECULATIVE_TYPE_DRAFT_DFLASH: return "draft-dflash";
|
||||
case COMMON_SPECULATIVE_TYPE_NGRAM_SIMPLE: return "ngram-simple";
|
||||
case COMMON_SPECULATIVE_TYPE_NGRAM_MAP_K: return "ngram-map-k";
|
||||
case COMMON_SPECULATIVE_TYPE_NGRAM_MAP_K4V: return "ngram-map-k4v";
|
||||
@@ -1893,6 +2185,7 @@ int32_t common_speculative_n_max(const common_params_speculative * spec) {
|
||||
case COMMON_SPECULATIVE_TYPE_DRAFT_SIMPLE:
|
||||
case COMMON_SPECULATIVE_TYPE_DRAFT_EAGLE3:
|
||||
case COMMON_SPECULATIVE_TYPE_DRAFT_MTP:
|
||||
case COMMON_SPECULATIVE_TYPE_DRAFT_DFLASH:
|
||||
n_max = std::max(n_max, std::max(0, spec->draft.n_max));
|
||||
break;
|
||||
case COMMON_SPECULATIVE_TYPE_NGRAM_SIMPLE:
|
||||
@@ -1930,6 +2223,7 @@ common_speculative * common_speculative_init(common_params_speculative & params,
|
||||
bool has_draft_simple = (enabled_configs & (1u << COMMON_SPECULATIVE_TYPE_DRAFT_SIMPLE));
|
||||
bool has_draft_eagle3 = (enabled_configs & (1u << COMMON_SPECULATIVE_TYPE_DRAFT_EAGLE3)) && params.draft.ctx_dft != nullptr;
|
||||
bool has_draft_mtp = (enabled_configs & (1u << COMMON_SPECULATIVE_TYPE_DRAFT_MTP)) && params.draft.ctx_dft != nullptr;
|
||||
bool has_draft_dflash = (enabled_configs & (1u << COMMON_SPECULATIVE_TYPE_DRAFT_DFLASH)) && params.draft.ctx_dft != nullptr;
|
||||
|
||||
|
||||
|
||||
@@ -1940,7 +2234,7 @@ common_speculative * common_speculative_init(common_params_speculative & params,
|
||||
bool has_ngram_mod = (enabled_configs & (1u << COMMON_SPECULATIVE_TYPE_NGRAM_MOD));
|
||||
|
||||
// when adding a new type - update here the logic above
|
||||
static_assert(COMMON_SPECULATIVE_TYPE_COUNT == 9);
|
||||
static_assert(COMMON_SPECULATIVE_TYPE_COUNT == 10);
|
||||
|
||||
// this list here defines the priority of the speculators
|
||||
// the one with highest priority are listed first
|
||||
@@ -1970,6 +2264,9 @@ common_speculative * common_speculative_init(common_params_speculative & params,
|
||||
if (has_draft_mtp) {
|
||||
configs.push_back(common_speculative_config(COMMON_SPECULATIVE_TYPE_DRAFT_MTP, params));
|
||||
}
|
||||
if (has_draft_dflash) {
|
||||
configs.push_back(common_speculative_config(COMMON_SPECULATIVE_TYPE_DRAFT_DFLASH, params));
|
||||
}
|
||||
}
|
||||
|
||||
std::vector<std::unique_ptr<common_speculative_impl>> impls = {};
|
||||
@@ -1990,6 +2287,10 @@ common_speculative * common_speculative_init(common_params_speculative & params,
|
||||
impls.push_back(std::make_unique<common_speculative_impl_draft_mtp>(config.params, n_seq));
|
||||
break;
|
||||
}
|
||||
case COMMON_SPECULATIVE_TYPE_DRAFT_DFLASH: {
|
||||
impls.push_back(std::make_unique<common_speculative_impl_draft_dflash>(config.params, n_seq));
|
||||
break;
|
||||
}
|
||||
case COMMON_SPECULATIVE_TYPE_NGRAM_SIMPLE: {
|
||||
common_ngram_map ngram_map = get_common_ngram_map(config.type, config.params.ngram_simple);
|
||||
|
||||
|
||||
@@ -50,6 +50,7 @@ TEXT_MODEL_MAP: dict[str, str] = {
|
||||
"DeepseekV2ForCausalLM": "deepseek",
|
||||
"DeepseekV3ForCausalLM": "deepseek",
|
||||
"DeepseekV32ForCausalLM": "deepseek",
|
||||
"DFlashDraftModel": "qwen",
|
||||
"DistilBertForMaskedLM": "bert",
|
||||
"DistilBertForSequenceClassification": "bert",
|
||||
"DistilBertModel": "bert",
|
||||
|
||||
+3
-3
@@ -73,7 +73,7 @@ class LlamaModel(TextModel):
|
||||
target_num_layers = target_config["num_hidden_layers"]
|
||||
target_layers = [2, target_num_layers // 2, target_num_layers - 3]
|
||||
logger.info(f"EAGLE-3: target_layers = {target_layers} (target model has {target_num_layers} layers)")
|
||||
self.gguf_writer.add_array(f"{self.gguf_writer.arch}.target_layers", target_layers)
|
||||
self.gguf_writer.add_target_layers(target_layers)
|
||||
|
||||
# target_hidden_size: prefer eagle3 config, fallback to target config
|
||||
if eagle3_raw_config.get("target_hidden_size") is not None:
|
||||
@@ -83,12 +83,12 @@ class LlamaModel(TextModel):
|
||||
target_hidden_size = target_config["hidden_size"]
|
||||
src = "target model config"
|
||||
logger.info(f"EAGLE-3: target_hidden_size = {target_hidden_size} (from {src})")
|
||||
self.gguf_writer.add_uint32(f"{self.gguf_writer.arch}.target_hidden_size", target_hidden_size)
|
||||
self.gguf_writer.add_target_hidden_size(target_hidden_size)
|
||||
|
||||
# norm_before_residual (RedHat-style eagle3 specific)
|
||||
norm_before_residual = eagle3_raw_config.get("norm_before_residual", False)
|
||||
logger.info(f"EAGLE-3: norm_before_residual = {norm_before_residual}")
|
||||
self.gguf_writer.add_bool(f"{self.gguf_writer.arch}.norm_before_residual", norm_before_residual)
|
||||
self.gguf_writer.add_norm_before_residual(norm_before_residual)
|
||||
|
||||
def set_vocab(self):
|
||||
# eagle3: use tokenizer from target model if provided
|
||||
|
||||
@@ -625,3 +625,51 @@ class Qwen3_5TextModel(_Qwen35MtpMixin, _Qwen35MRopeMixin, _LinearAttentionVReor
|
||||
@ModelBase.register("Qwen3_5MoeForConditionalGeneration", "Qwen3_5MoeForCausalLM")
|
||||
class Qwen3_5MoeTextModel(_Qwen35MtpMixin, _Qwen35MRopeMixin, _LinearAttentionVReorderBase):
|
||||
model_arch = gguf.MODEL_ARCH.QWEN35MOE
|
||||
|
||||
|
||||
@ModelBase.register("DFlashDraftModel")
class DFlashModel(Qwen3Model):
    """Converter for DFlash draft models (HF architecture ``DFlashDraftModel``).

    The tokenizer is taken from the target model directory, and the DFlash-specific
    hyperparameters (block size, target layers, optional sliding-window pattern) are
    written to the GGUF metadata.
    """

    model_arch = gguf.MODEL_ARCH.DFLASH

    def set_vocab(self):
        """Write the vocabulary using the target model's tokenizer.

        Raises:
            ValueError: if no target model directory was provided.
        """
        if self.target_model_dir is None:
            raise ValueError(
                "DFlash draft model requires --target-model-dir to be specified. "
                "Please provide the path to the target model directory containing the tokenizer."
            )
        logger.info(f"DFlash: Using tokenizer from target model: {self.target_model_dir}")
        original_dir = self.dir_model
        self.dir_model = self.target_model_dir
        try:
            super().set_vocab()
        finally:
            # always restore the draft model directory, even if vocab loading fails
            self.dir_model = original_dir

        # `or {}` also covers a config that carries an explicit null for dflash_config
        mask_token_id = (self.hparams.get("dflash_config") or {}).get("mask_token_id")
        if mask_token_id is not None:
            self.gguf_writer.add_mask_token_id(mask_token_id)

    def set_gguf_parameters(self):
        """Write the base parameters plus the DFlash-specific metadata."""
        super().set_gguf_parameters()

        block_size = self.hparams.get("block_size", 16)
        self.gguf_writer.add_block_size(block_size)
        dflash_config = self.hparams.get("dflash_config") or {}

        target_layer_ids = dflash_config.get("target_layer_ids", [])
        if target_layer_ids:
            # stored ids are shifted by one relative to the HF config values
            extract_layer_ids = [i + 1 for i in target_layer_ids]
            self.gguf_writer.add_target_layers(extract_layer_ids)

        use_sliding_window = self.hparams.get("use_sliding_window", False)
        sliding_window = self.hparams.get("sliding_window")
        layer_types = self.hparams.get("layer_types")
        if use_sliding_window and sliding_window and layer_types:
            # one boolean per layer: True where the layer uses sliding attention
            is_swa = [lt == "sliding_attention" for lt in layer_types]
            self.gguf_writer.add_sliding_window(sliding_window)
            self.gguf_writer.add_sliding_window_pattern(is_swa)

    @classmethod
    def filter_tensors(cls, item: tuple[str, Callable[[], Tensor]]) -> tuple[str, Callable[[], Tensor]] | None:
        """Prefix tensor names with ``model.`` when missing, then defer to the parent filter."""
        name, gen = item
        if not name.startswith("model."):
            name = "model." + name
        return super().filter_tensors((name, gen))
|
||||
|
||||
+28
-1
@@ -52,6 +52,32 @@ Supported EAGLE-3 draft models include:
|
||||
|
||||
For the full and up-to-date list of supported models, see #18039.
|
||||
|
||||
### DFlash (`draft-dflash`)
|
||||
|
||||
DFlash produces an entire block of draft tokens in a single forward pass (block diffusion) and
|
||||
injects the target model's hidden states into the draft model's attention, instead of drafting one
|
||||
token at a time. This keeps the draft model small while making drafting GPU-friendly. Unlike EAGLE-3
|
||||
(a single-layer autoregressive draft), the DFlash draft uses several transformer layers but emits a
|
||||
whole block per draft step.
|
||||
|
||||
The draft is a small block-diffusion model trained for a specific target (for example
|
||||
`z-lab/Qwen3-4B-DFlash` for `Qwen/Qwen3-4B`). Convert it with `--target-model-dir` so it inherits the
|
||||
target's tokenizer and token embeddings:
|
||||
|
||||
```bash
|
||||
python convert_hf_to_gguf.py z-lab/Qwen3-4B-DFlash \
|
||||
--target-model-dir Qwen/Qwen3-4B --outtype bf16 --outfile Qwen3-4B-DFlash.gguf
|
||||
|
||||
llama-server -m Qwen3-4B.gguf -md Qwen3-4B-DFlash.gguf \
|
||||
--spec-type draft-dflash --spec-draft-n-max 15 -fa on --jinja
|
||||
```
|
||||
|
||||
`--spec-draft-n-max` is clamped to the draft model's trained block size minus one, because the first slot of each block holds the last accepted token.
|
||||
|
||||
See:
|
||||
|
||||
- #22105
|
||||
|
||||
### n-gram Cache (`ngram-cache`)
|
||||
|
||||
An n-gram is a sequence of n tokens. The n-gram cache implementation maintains statistics about short n-gram sequences.
|
||||
@@ -147,7 +173,7 @@ If a draft model is combined with a draftless decoding the draftless decoding ha
|
||||
### General Speculative Parameters
|
||||
|
||||
```
|
||||
--spec-type [none|draft-simple|draft-eagle3|draft-mtp|ngram-cache|ngram-simple|ngram-map-k|ngram-map-k4v|ngram-mod]
|
||||
--spec-type [none|draft-simple|draft-eagle3|draft-dflash|draft-mtp|ngram-cache|ngram-simple|ngram-map-k|ngram-map-k4v|ngram-mod]
|
||||
comma-separated list of types of speculative decoding to use
|
||||
(default: none)
|
||||
(env: LLAMA_ARG_SPEC_TYPE)
|
||||
@@ -287,6 +313,7 @@ Specifies a comma-separated list of speculative decoding types to use.
|
||||
| `none` | No speculative decoding (default) |
|
||||
| `draft-simple` | Use a simple draft model for speculation |
|
||||
| `draft-eagle3` | Use an EAGLE-3 draft model that reads the target's hidden states |
|
||||
| `draft-dflash` | Use a DFlash block-diffusion draft model that emits a block per step |
|
||||
| `draft-mtp` | Use Multi Token Prediction (MTP) heads from the main model |
|
||||
| `ngram-cache` | Use n-gram cache lookup |
|
||||
| `ngram-simple` | Use simple n-gram pattern matching |
|
||||
|
||||
@@ -156,6 +156,7 @@ class Keys:
|
||||
DENSE_FEAT_OUT_SIZE = "{arch}.{dense}_feat_out"
|
||||
TARGET_LAYERS = "{arch}.target_layers"
|
||||
TARGET_HIDDEN_SIZE = "{arch}.target_hidden_size"
|
||||
BLOCK_SIZE = "{arch}.block_size"
|
||||
NORM_BEFORE_RESIDUAL = "{arch}.norm_before_residual"
|
||||
|
||||
class Attention:
|
||||
@@ -517,6 +518,7 @@ class MODEL_ARCH(IntEnum):
|
||||
PANGU_EMBED = auto()
|
||||
MISTRAL3 = auto()
|
||||
EAGLE3 = auto()
|
||||
DFLASH = auto()
|
||||
MISTRAL4 = auto()
|
||||
PADDLEOCR = auto()
|
||||
MIMO2 = auto()
|
||||
@@ -1074,6 +1076,7 @@ MODEL_ARCH_NAMES: dict[MODEL_ARCH, str] = {
|
||||
MODEL_ARCH.PANGU_EMBED: "pangu-embedded",
|
||||
MODEL_ARCH.MISTRAL3: "mistral3",
|
||||
MODEL_ARCH.EAGLE3: "eagle3",
|
||||
MODEL_ARCH.DFLASH: "dflash",
|
||||
MODEL_ARCH.MISTRAL4: "mistral4",
|
||||
MODEL_ARCH.PADDLEOCR: "paddleocr",
|
||||
MODEL_ARCH.MIMO2: "mimo2",
|
||||
@@ -4086,6 +4089,22 @@ MODEL_TENSORS: dict[MODEL_ARCH, list[MODEL_TENSOR]] = {
|
||||
MODEL_TENSOR.FC,
|
||||
MODEL_TENSOR.D2T,
|
||||
],
|
||||
MODEL_ARCH.DFLASH: [
|
||||
MODEL_TENSOR.OUTPUT_NORM,
|
||||
MODEL_TENSOR.ATTN_NORM,
|
||||
MODEL_TENSOR.ATTN_Q,
|
||||
MODEL_TENSOR.ATTN_K,
|
||||
MODEL_TENSOR.ATTN_V,
|
||||
MODEL_TENSOR.ATTN_OUT,
|
||||
MODEL_TENSOR.ATTN_Q_NORM,
|
||||
MODEL_TENSOR.ATTN_K_NORM,
|
||||
MODEL_TENSOR.FFN_NORM,
|
||||
MODEL_TENSOR.FFN_GATE,
|
||||
MODEL_TENSOR.FFN_DOWN,
|
||||
MODEL_TENSOR.FFN_UP,
|
||||
MODEL_TENSOR.FC,
|
||||
MODEL_TENSOR.ENC_OUTPUT_NORM,
|
||||
],
|
||||
MODEL_ARCH.MISTRAL4: [
|
||||
MODEL_TENSOR.TOKEN_EMBD,
|
||||
MODEL_TENSOR.OUTPUT_NORM,
|
||||
|
||||
@@ -940,6 +940,18 @@ class GGUFWriter:
|
||||
def add_sliding_window(self, value: int) -> None:
    """Store `value` as a uint32 under the architecture's sliding-window key."""
    key = Keys.Attention.SLIDING_WINDOW.format(arch=self.arch)
    self.add_uint32(key, value)
|
||||
|
||||
def add_block_size(self, value: int) -> None:
    """Store `value` as a uint32 under the architecture's block-size key."""
    key = Keys.LLM.BLOCK_SIZE.format(arch=self.arch)
    self.add_uint32(key, value)
|
||||
|
||||
def add_target_layers(self, value: Sequence[int]) -> None:
    """Store `value` as an array under the architecture's target-layers key."""
    key = Keys.LLM.TARGET_LAYERS.format(arch=self.arch)
    self.add_array(key, value)
|
||||
|
||||
def add_target_hidden_size(self, value: int) -> None:
    """Store `value` as a uint32 under the architecture's target-hidden-size key."""
    key = Keys.LLM.TARGET_HIDDEN_SIZE.format(arch=self.arch)
    self.add_uint32(key, value)
|
||||
|
||||
def add_norm_before_residual(self, value: bool) -> None:
    """Store `value` as a bool under the architecture's norm-before-residual key."""
    key = Keys.LLM.NORM_BEFORE_RESIDUAL.format(arch=self.arch)
    self.add_bool(key, value)
|
||||
|
||||
def add_attention_scale(self, value: float) -> None:
    """Store `value` as a float32 under the architecture's attention-scale key."""
    key = Keys.Attention.SCALE.format(arch=self.arch)
    self.add_float32(key, value)
|
||||
|
||||
|
||||
@@ -1283,6 +1283,11 @@ class TensorNameMap:
|
||||
MODEL_TENSOR.ENC_OUTPUT_NORM: (
|
||||
"encoder.final_layer_norm", # t5
|
||||
"layer_norm", # neobert
|
||||
"model.hidden_norm", # dflash
|
||||
),
|
||||
|
||||
MODEL_TENSOR.FC: (
|
||||
"model.fc", # dflash
|
||||
),
|
||||
|
||||
MODEL_TENSOR.CLS: (
|
||||
|
||||
@@ -0,0 +1,179 @@
|
||||
{{- bos_token }}{%- if tools %}
|
||||
{%- set tool_definitions %}
|
||||
{{- "# Tools\n\nYou are provided with function signatures within <tools></tools> XML tags:\n<tools>" }}
|
||||
{%- for tool in tools %}
|
||||
{{- "\n" }}
|
||||
{{- tool | tojson(ensure_ascii=False) }}
|
||||
{%- endfor %}
|
||||
{{- '\n</tools>\n\nTool usage guidelines:\n- You may call zero or more functions. If no function calls are needed, just answer normally and do not include any <function ... </function>.\n- When calling a function, return an XML object within <function ... </function> using:\n<function name="function-name"><param name="param-name">param-value</param></function>\n- param-value may be multi-line. If it contains <, & or newline characters, wrap it in a CDATA block: <param name="param-name"><![CDATA[...multi-line value...]]></param>' }}
|
||||
{%- endset %}
|
||||
|
||||
{{- '<|im_start|>system\n' }}
|
||||
{%- if messages[0].role == 'system' %}
|
||||
{%- if '<tool_def_sep>' in messages[0].content %}
|
||||
{{- messages[0].content.replace('<tool_def_sep>', tool_definitions) }}
|
||||
{%- else %}
|
||||
{{- messages[0].content + '\n\n' + tool_definitions }}
|
||||
{%- endif %}
|
||||
{%- else %}
|
||||
{{- tool_definitions.lstrip() }}
|
||||
{%- endif %}
|
||||
{{- '<|im_end|>\n' }}
|
||||
{%- else %}
|
||||
{%- if messages[0].role == 'system' %}
|
||||
{{- '<|im_start|>system\n' + messages[0].content + '<|im_end|>\n' }}
|
||||
{%- endif %}
|
||||
{%- endif %}
|
||||
{%- set ns = namespace(multi_step_tool=true, last_query_index=messages|length - 1) %}
|
||||
{%- for message in messages[::-1] %}
|
||||
{%- set index = (messages|length - 1) - loop.index0 %}
|
||||
{%- if ns.multi_step_tool and message.role == "user" and message.content is string and not(message.content.startswith('<tool_response>') and message.content.endswith('</tool_response>')) %}
|
||||
{%- set ns.multi_step_tool = false %}
|
||||
{%- set ns.last_query_index = index %}
|
||||
{%- endif %}
|
||||
{%- endfor %}
|
||||
{%- for message in messages %}
|
||||
{%- if message.content is string %}
|
||||
{%- set content = message.content %}
|
||||
{%- else %}
|
||||
{%- set content = '' %}
|
||||
{%- endif %}
|
||||
{%- if (message.role == "user") or (message.role == "system" and not loop.first) %}
|
||||
{{- '<|im_start|>' + message.role + '\n' + content + '<|im_end|>' + '\n' }}
|
||||
{%- elif message.role == "assistant" %}
|
||||
{%- set reasoning_content = '' %}
|
||||
{%- if message.reasoning_content is string %}
|
||||
{%- set reasoning_content = message.reasoning_content %}
|
||||
{%- else %}
|
||||
{%- if '</think>' in content %}
|
||||
{%- set reasoning_content = content.split('</think>')[0].rstrip('\n').split('<think>')[-1].lstrip('\n') %}
|
||||
{%- set content = content.split('</think>')[-1].lstrip('\n') %}
|
||||
{%- endif %}
|
||||
{%- endif %}
|
||||
|
||||
{%- if message.tool_calls %}
|
||||
{%- set content_parts = content.split('<tool_sep>') %}
|
||||
{%- set processed_content = content_parts[0] %}
|
||||
{%- set tool_calls_count = message.tool_calls|length %}
|
||||
{%- set tool_sep_count = content_parts|length - 1 %}
|
||||
{%- set min_count = [tool_calls_count, tool_sep_count]|min %}
|
||||
|
||||
{%- for i in range(1, content_parts|length) %}
|
||||
{%- set tool_index = i - 1 %}
|
||||
{%- if tool_index < tool_calls_count %}
|
||||
{%- set tool_call = message.tool_calls[tool_index] %}
|
||||
{%- if tool_call.function %}
|
||||
{%- set tool_call = tool_call.function %}
|
||||
{%- endif %}
|
||||
{%- set single_tool_xml %}
|
||||
{{- '<function name="' ~ tool_call.name ~ '">' }}
|
||||
{%- if tool_call.arguments %}
|
||||
{%- set args_dict = tool_call.arguments %}
|
||||
{%- for param_name, param_value in args_dict.items() %}
|
||||
{{- '<param name="' ~ param_name ~ '">' }}
|
||||
{%- if param_value is string and ('<' in param_value or '&' in param_value or '\n' in param_value) %}
|
||||
{{- '<![CDATA[' + param_value + ']]>' }}
|
||||
{%- else %}
|
||||
{{- param_value }}
|
||||
{%- endif %}
|
||||
{{- '</param>' }}
|
||||
{%- endfor %}
|
||||
{%- endif %}
|
||||
{{- '</function>' }}
|
||||
{%- endset %}
|
||||
{%- set processed_content = processed_content + single_tool_xml + content_parts[i] %}
|
||||
{%- else %}
|
||||
{%- set processed_content = processed_content + content_parts[i] %}
|
||||
{%- endif %}
|
||||
{%- endfor %}
|
||||
|
||||
{%- if tool_calls_count > tool_sep_count %}
|
||||
{%- for remaining_index in range(tool_sep_count, tool_calls_count) %}
|
||||
{%- set tool_call = message.tool_calls[remaining_index] %}
|
||||
{%- if tool_call.function %}
|
||||
{%- set tool_call = tool_call.function %}
|
||||
{%- endif %}
|
||||
{%- set remaining_tool_xml %}
|
||||
{{- '<function name="' ~ tool_call.name ~ '">' }}
|
||||
{%- if tool_call.arguments %}
|
||||
{%- set args_dict = tool_call.arguments %}
|
||||
{%- for param_name, param_value in args_dict.items() %}
|
||||
{{- '<param name="' ~ param_name ~ '">' }}
|
||||
{%- if param_value is string and ('<' in param_value or '&' in param_value or '\n' in param_value) %}
|
||||
{{- '<![CDATA[' + param_value + ']]>' }}
|
||||
{%- else %}
|
||||
{{- param_value }}
|
||||
{%- endif %}
|
||||
{{- '</param>' }}
|
||||
{%- endfor %}
|
||||
{%- endif %}
|
||||
{{- '</function>' }}
|
||||
{%- endset %}
|
||||
{%- set processed_content = processed_content + remaining_tool_xml %}
|
||||
{%- endfor %}
|
||||
{%- endif %}
|
||||
|
||||
{%- set content = processed_content %}
|
||||
{%- endif %}
|
||||
|
||||
{%- if loop.index0 > ns.last_query_index %}
|
||||
{%- if reasoning_content %}
|
||||
{{- '<|im_start|>' + message.role + '\n<think>\n' + reasoning_content.strip('\n') + '\n</think>\n\n' + content.lstrip('\n') }}
|
||||
{%- else %}
|
||||
{{- '<|im_start|>' + message.role + '\n' + content }}
|
||||
{%- endif %}
|
||||
{%- else %}
|
||||
{{- '<|im_start|>' + message.role + '\n' + content }}
|
||||
{%- endif %}
|
||||
|
||||
{%- if message.tool_calls and not has_tool_sep %}
|
||||
{%- for tool_call in message.tool_calls %}
|
||||
{%- if (loop.first and content) or (not loop.first) %}
|
||||
{{- '\n' }}
|
||||
{%- endif %}
|
||||
{%- if tool_call.function %}
|
||||
{%- set tool_call = tool_call.function %}
|
||||
{%- endif %}
|
||||
{{- '<function name="' ~ tool_call.name ~ '">' }}
|
||||
{%- if tool_call.arguments %}
|
||||
{%- set args_dict = tool_call.arguments %}
|
||||
{%- for param_name, param_value in args_dict.items() %}
|
||||
{{- '<param name="' ~ param_name ~ '">' }}
|
||||
{%- if param_value is string and ('<' in param_value or '&' in param_value or '\n' in param_value) %}
|
||||
{{- '<![CDATA[' + param_value + ']]>' }}
|
||||
{%- else %}
|
||||
{{- param_value }}
|
||||
{%- endif %}
|
||||
{{- '</param>' }}
|
||||
{%- endfor %}
|
||||
{%- endif %}
|
||||
{{- '</function>' }}
|
||||
{%- endfor %}
|
||||
{%- endif %}
|
||||
{{- '<|im_end|>\n' }}
|
||||
{%- elif message.role == "tool" %}
|
||||
{%- if loop.first or (messages[loop.index0 - 1].role != "tool") %}
|
||||
{{- '<|im_start|>user' }}
|
||||
{%- endif %}
|
||||
{{- '\n<tool_response>\n' }}
|
||||
{%- if message.content is string %}
|
||||
{{- content }}
|
||||
{%- else %}
|
||||
{{- message.content | tojson(ensure_ascii=False) }}
|
||||
{%- endif %}
|
||||
{{- '\n</tool_response>' }}
|
||||
{%- if loop.last or (messages[loop.index0 + 1].role != "tool") %}
|
||||
{{- '<|im_end|>\n' }}
|
||||
{%- endif %}
|
||||
{%- endif %}
|
||||
{%- endfor %}
|
||||
{%- if add_generation_prompt %}
|
||||
{{- '<|im_start|>assistant\n' }}
|
||||
{%- if enable_thinking is defined %}
|
||||
{%- if enable_thinking is false %}
|
||||
{{- '<think>\n\n</think>\n\n' }}
|
||||
{%- elif enable_thinking is true %}
|
||||
{{- '<think>\n' }}
|
||||
{%- endif %}
|
||||
{%- endif %}
|
||||
{%- endif %}
|
||||
@@ -129,6 +129,7 @@ static const std::map<llm_arch, const char *> LLM_ARCH_NAMES = {
|
||||
{ LLM_ARCH_PANGU_EMBED, "pangu-embedded" },
|
||||
{ LLM_ARCH_MISTRAL3, "mistral3" },
|
||||
{ LLM_ARCH_EAGLE3, "eagle3" },
|
||||
{ LLM_ARCH_DFLASH, "dflash" },
|
||||
{ LLM_ARCH_MISTRAL4, "mistral4" },
|
||||
{ LLM_ARCH_PADDLEOCR, "paddleocr" },
|
||||
{ LLM_ARCH_MIMO2, "mimo2" },
|
||||
|
||||
@@ -143,6 +143,7 @@ enum llm_arch {
|
||||
LLM_ARCH_TALKIE,
|
||||
LLM_ARCH_MELLUM,
|
||||
LLM_ARCH_EAGLE3,
|
||||
LLM_ARCH_DFLASH,
|
||||
LLM_ARCH_UNKNOWN,
|
||||
};
|
||||
|
||||
|
||||
@@ -100,10 +100,10 @@ llama_context::llama_context(
|
||||
cparams.ctx_other = params.ctx_other;
|
||||
}
|
||||
|
||||
if (model.arch == LLM_ARCH_EAGLE3) {
|
||||
if (model.arch == LLM_ARCH_EAGLE3 || model.arch == LLM_ARCH_DFLASH) {
|
||||
if (model.tok_embd == nullptr || model.output == nullptr) {
|
||||
if (params.ctx_other == nullptr) {
|
||||
throw std::runtime_error("EAGLE3 requires ctx_other to be set (this warning is normal during memory fitting)");
|
||||
throw std::runtime_error(model.arch_name() + " requires ctx_other to be set (this warning is normal during memory fitting)");
|
||||
}
|
||||
cparams.ctx_other = params.ctx_other;
|
||||
}
|
||||
|
||||
+6
-1
@@ -486,7 +486,11 @@ void llm_graph_input_attn_kv::set_input(const llama_ubatch * ubatch) {
|
||||
mctx->set_input_k_idxs(self_k_idxs, ubatch);
|
||||
mctx->set_input_v_idxs(self_v_idxs, ubatch);
|
||||
|
||||
mctx->set_input_kq_mask(self_kq_mask, ubatch, cparams.causal_attn);
|
||||
// the mask is left unallocated when the graph only stores K/V without attending
|
||||
// (e.g. DFlash's KV-injection pass)
|
||||
if (self_kq_mask && self_kq_mask->buffer) {
|
||||
mctx->set_input_kq_mask(self_kq_mask, ubatch, cparams.causal_attn);
|
||||
}
|
||||
|
||||
if (self_k_rot) {
|
||||
mctx->set_input_k_rot(self_k_rot);
|
||||
@@ -904,6 +908,7 @@ void llm_graph_result::reset() {
|
||||
t_logits = nullptr;
|
||||
t_embd = nullptr;
|
||||
t_embd_pooled = nullptr;
|
||||
t_h_nextn = nullptr;
|
||||
|
||||
t_layer_inp.resize(LLAMA_MAX_LAYERS);
|
||||
std::fill(t_layer_inp.begin(), t_layer_inp.end(), nullptr);
|
||||
|
||||
+5
-1
@@ -291,6 +291,8 @@ static llama_model * llama_model_mapping(llm_arch arch, const llama_model_params
|
||||
return new llama_model_mistral3(params);
|
||||
case LLM_ARCH_EAGLE3:
|
||||
return new llama_model_eagle3(params);
|
||||
case LLM_ARCH_DFLASH:
|
||||
return new llama_model_dflash(params);
|
||||
case LLM_ARCH_MIMO2:
|
||||
return new llama_model_mimo2(params);
|
||||
case LLM_ARCH_KIMI_LINEAR:
|
||||
@@ -2494,6 +2496,7 @@ llama_rope_type llama_model_rope_type(const llama_model * model) {
|
||||
case LLM_ARCH_STEP35:
|
||||
case LLM_ARCH_TALKIE:
|
||||
case LLM_ARCH_MELLUM:
|
||||
case LLM_ARCH_DFLASH:
|
||||
return LLAMA_ROPE_TYPE_NEOX;
|
||||
|
||||
case LLM_ARCH_QWEN2VL:
|
||||
@@ -2617,7 +2620,8 @@ bool llama_model_has_encoder(const llama_model * model) {
|
||||
switch (model->arch) {
|
||||
case LLM_ARCH_T5:
|
||||
case LLM_ARCH_T5ENCODER:
|
||||
case LLM_ARCH_EAGLE3: return true;
|
||||
case LLM_ARCH_EAGLE3:
|
||||
case LLM_ARCH_DFLASH: return true;
|
||||
default: return false;
|
||||
}
|
||||
}
|
||||
|
||||
@@ -0,0 +1,276 @@
|
||||
#include "models.h"
|
||||
|
||||
#include "llama-kv-cache.h"
|
||||
#include "llama-kv-cache-iswa.h"
|
||||
|
||||
// Reads the DFlash-specific hyperparameters from the model metadata.
//
// - RMS-norm epsilon (required key)
// - the list of target-model layer ids whose features are fused ('target_layers', required and non-empty)
// - optional sliding-window attention settings
//
// Throws std::runtime_error when 'target_layers' is missing or empty.
void llama_model_dflash::load_arch_hparams(llama_model_loader & ml) {
    ml.get_key(LLM_KV_ATTENTION_LAYERNORM_RMS_EPS, hparams.f_norm_rms_eps);

    if (!ml.get_arr(LLM_KV_TARGET_LAYERS, target_layer_ids, false)) {
        throw std::runtime_error("DFlash model requires 'target_layers' in GGUF metadata");
    }

    // an empty list would make the encoder input width below 0 and produce a degenerate fc tensor shape,
    // so fail early with an explicit message instead
    if (target_layer_ids.empty()) {
        throw std::runtime_error("DFlash model requires a non-empty 'target_layers' in GGUF metadata");
    }

    // encoder input width = one n_embd-sized feature vector per extracted target layer
    hparams.n_embd_inp_enc_impl = (uint32_t) target_layer_ids.size() * hparams.n_embd;

    LLAMA_LOG_INFO("%s: DFlash extract_layers = [", __func__);
    for (size_t i = 0; i < target_layer_ids.size(); ++i) {
        LLAMA_LOG_INFO("%d%s", target_layer_ids[i], i + 1 < target_layer_ids.size() ? ", " : "");
    }
    LLAMA_LOG_INFO("]\n");

    // optional interleaved sliding-window attention with per-layer pattern array.
    // DFlash has a single rope, so the SWA rope == main rope.
    if (ml.get_key(LLM_KV_ATTENTION_SLIDING_WINDOW, hparams.n_swa, false) && hparams.n_swa > 0) {
        hparams.swa_type = LLAMA_SWA_TYPE_STANDARD;
        ml.get_key_or_arr(LLM_KV_ATTENTION_SLIDING_WINDOW_PATTERN, hparams.is_swa_impl, hparams.n_layer());
        hparams.rope_freq_base_train_swa  = hparams.rope_freq_base_train;
        hparams.rope_freq_scale_train_swa = hparams.rope_freq_scale_train;
    }

    type = LLM_TYPE_UNKNOWN;
}
|
||||
|
||||
// Creates the DFlash weight tensors: the feature-fusion projection (fc), the two output norms,
// and for every layer the attention (q/k/v/o + q/k norms) and gated FFN weights.
void llama_model_dflash::load_arch_tensors(llama_model_loader &) {
    LLAMA_LOAD_LOCALS;

    // input width of the fc projection
    const int64_t fc_in_dim = hparams.n_embd_inp_enc();

    fc              = create_tensor(tn(LLM_TENSOR_FC,              "weight"), { fc_in_dim, n_embd }, 0);
    // encoder hidden_norm (after fc)
    output_norm_enc = create_tensor(tn(LLM_TENSOR_ENC_OUTPUT_NORM, "weight"), { n_embd }, 0);
    // decoder final norm
    output_norm     = create_tensor(tn(LLM_TENSOR_OUTPUT_NORM,     "weight"), { n_embd }, 0);

    // width of the query projection, shared by wq (output) and wo (input)
    const auto q_proj_dim = n_embd_head_k * n_head;

    for (int il = 0; il < n_layer; ++il) {
        auto & l = layers[il];

        l.attn_norm   = create_tensor(tn(LLM_TENSOR_ATTN_NORM,   "weight", il), { n_embd }, 0);

        l.wq          = create_tensor(tn(LLM_TENSOR_ATTN_Q,      "weight", il), { n_embd, q_proj_dim }, 0);
        l.wk          = create_tensor(tn(LLM_TENSOR_ATTN_K,      "weight", il), { n_embd, n_embd_k_gqa }, 0);
        l.wv          = create_tensor(tn(LLM_TENSOR_ATTN_V,      "weight", il), { n_embd, n_embd_v_gqa }, 0);
        l.wo          = create_tensor(tn(LLM_TENSOR_ATTN_OUT,    "weight", il), { q_proj_dim, n_embd }, 0);

        l.attn_q_norm = create_tensor(tn(LLM_TENSOR_ATTN_Q_NORM, "weight", il), { n_embd_head_k }, 0);
        l.attn_k_norm = create_tensor(tn(LLM_TENSOR_ATTN_K_NORM, "weight", il), { n_embd_head_k }, 0);

        l.ffn_norm    = create_tensor(tn(LLM_TENSOR_FFN_NORM,    "weight", il), { n_embd }, 0);
        l.ffn_gate    = create_tensor(tn(LLM_TENSOR_FFN_GATE,    "weight", il), { n_embd, n_ff }, 0);
        l.ffn_down    = create_tensor(tn(LLM_TENSOR_FFN_DOWN,    "weight", il), { n_ff, n_embd }, 0);
        l.ffn_up      = create_tensor(tn(LLM_TENSOR_FFN_UP,      "weight", il), { n_embd, n_ff }, 0);
    }
}
|
||||
|
||||
std::unique_ptr<llm_graph_context> llama_model_dflash::build_arch_graph(const llm_graph_params & params) const {
|
||||
switch (params.gtype) {
|
||||
case LLM_GRAPH_TYPE_ENCODER:
|
||||
return std::make_unique<graph<true>>(*this, params);
|
||||
case LLM_GRAPH_TYPE_DEFAULT:
|
||||
case LLM_GRAPH_TYPE_DECODER:
|
||||
return std::make_unique<graph<false>>(*this, params);
|
||||
default:
|
||||
GGML_ABORT("invalid graph type");
|
||||
};
|
||||
}
|
||||
|
||||
template <>
|
||||
ggml_tensor * llama_model_dflash::graph<true>::build_inp_embd_enc() const {
|
||||
auto inp_target = std::make_unique<llm_graph_input_embd>(hparams.n_embd_inp_enc());
|
||||
|
||||
inp_target->embd = ggml_new_tensor_2d(ctx0, GGML_TYPE_F32, hparams.n_embd_inp_enc(), n_tokens);
|
||||
ggml_set_input(inp_target->embd);
|
||||
|
||||
ggml_tensor * cur = inp_target->embd;
|
||||
cb(cur, "inp_embd", -1);
|
||||
|
||||
res->add_input(std::move(inp_target));
|
||||
|
||||
return cur;
|
||||
}
|
||||
|
||||
// DFlash Encoder: processes target model features through feature fusion layer
|
||||
template <>
|
||||
llama_model_dflash::graph<true>::graph(const llama_model & model, const llm_graph_params & params) : llm_graph_context(params) {
|
||||
ggml_tensor * cur = build_inp_embd_enc();
|
||||
|
||||
cur = build_lora_mm(model.fc, cur);
|
||||
cb(cur, "fc_out", -1);
|
||||
|
||||
cur = build_norm(cur, model.output_norm_enc, NULL, LLM_NORM_RMS, -1);
|
||||
cb(cur, "enc_norm_out", -1);
|
||||
|
||||
ggml_set_output(cur);
|
||||
res->t_h_nextn = cur;
|
||||
|
||||
ggml_build_forward_expand(gf, cur);
|
||||
}
|
||||
|
||||
// DFlash decoder, dual-mode by batch type:
|
||||
// * embd batch -> fused target features: project + inject K/V into the cache.
|
||||
// * token batch -> noise-block diffusion: attend over [committed, MASK...] to generate draft tokens
|
||||
template <>
|
||||
llama_model_dflash::graph<false>::graph(const llama_model & model, const llm_graph_params & params) : llm_graph_context(params) {
|
||||
const int64_t n_embd_head = hparams.n_embd_head_v();
|
||||
|
||||
GGML_ASSERT(n_embd_head == hparams.n_embd_head_k());
|
||||
|
||||
ggml_tensor * inp_pos = build_inp_pos();
|
||||
|
||||
// optional iSWA: pick the matching attention input
|
||||
const bool use_iswa = hparams.swa_type != LLAMA_SWA_TYPE_NONE;
|
||||
|
||||
llm_graph_input_attn_kv * inp_attn = nullptr;
|
||||
llm_graph_input_attn_kv_iswa * inp_attn_iswa = nullptr;
|
||||
if (use_iswa) {
|
||||
inp_attn_iswa = build_attn_inp_kv_iswa();
|
||||
} else {
|
||||
inp_attn = build_attn_inp_kv();
|
||||
}
|
||||
|
||||
const float kq_scale = 1.0f/sqrtf(float(n_embd_head));
|
||||
|
||||
// KV cache injection
|
||||
if (ubatch.embd) {
|
||||
auto inp = std::make_unique<llm_graph_input_embd>(n_embd);
|
||||
|
||||
inp->embd = ggml_new_tensor_2d(ctx0, GGML_TYPE_F32, n_embd, n_tokens);
|
||||
ggml_set_input(inp->embd);
|
||||
|
||||
ggml_tensor * inp_g = inp->embd;
|
||||
cb(inp_g, "inp_g_embeddings", -1);
|
||||
|
||||
res->add_input(std::move(inp));
|
||||
|
||||
for (int il = 0; il < n_layer; ++il) {
|
||||
const auto & layer = model.layers[il];
|
||||
|
||||
ggml_tensor * Kcur = build_lora_mm(layer.wk, inp_g);
|
||||
ggml_tensor * Vcur = build_lora_mm(layer.wv, inp_g);
|
||||
|
||||
Kcur = ggml_reshape_3d(ctx0, Kcur, n_embd_head, n_head_kv, n_tokens);
|
||||
Vcur = ggml_reshape_3d(ctx0, Vcur, n_embd_head, n_head_kv, n_tokens);
|
||||
|
||||
Kcur = build_norm(Kcur, layer.attn_k_norm, NULL, LLM_NORM_RMS, il);
|
||||
Kcur = ggml_rope_ext(
|
||||
ctx0, Kcur, inp_pos, nullptr,
|
||||
n_rot, rope_type, n_ctx_orig, freq_base, freq_scale,
|
||||
ext_factor, attn_factor, beta_fast, beta_slow
|
||||
);
|
||||
cb(Kcur, "Kcur_injected", il);
|
||||
cb(Vcur, "Vcur_injected", il);
|
||||
|
||||
if (use_iswa) {
|
||||
// route each layer's K/V to its sub-cache: SWA layers -> sliding cache, full -> dense
|
||||
const bool is_swa = hparams.is_swa(il);
|
||||
const auto * kv = is_swa ? inp_attn_iswa->mctx->get_swa() : inp_attn_iswa->mctx->get_base();
|
||||
ggml_tensor * k_idxs = is_swa ? inp_attn_iswa->get_k_idxs_swa() : inp_attn_iswa->get_k_idxs();
|
||||
ggml_tensor * v_idxs = is_swa ? inp_attn_iswa->get_v_idxs_swa() : inp_attn_iswa->get_v_idxs();
|
||||
ggml_build_forward_expand(gf, kv->cpy_k(ctx0, Kcur, k_idxs, il));
|
||||
ggml_build_forward_expand(gf, kv->cpy_v(ctx0, Vcur, v_idxs, il));
|
||||
} else {
|
||||
ggml_build_forward_expand(gf, inp_attn->mctx->cpy_k(ctx0, Kcur, inp_attn->get_k_idxs(), il));
|
||||
ggml_build_forward_expand(gf, inp_attn->mctx->cpy_v(ctx0, Vcur, inp_attn->get_v_idxs(), il));
|
||||
}
|
||||
}
|
||||
|
||||
res->t_embd = inp_g;
|
||||
|
||||
ggml_build_forward_expand(gf, inp_g);
|
||||
return;
|
||||
}
|
||||
|
||||
// tok_embd from the target model (shared via ctx_other)
|
||||
auto * tok_embd = model.tok_embd;
|
||||
if (tok_embd == nullptr) {
|
||||
GGML_ASSERT(cparams.ctx_other != nullptr);
|
||||
const auto * model_other = llama_get_model(cparams.ctx_other);
|
||||
|
||||
GGML_ASSERT(model_other->tok_embd != nullptr && "DFlash decoder requires the target model's token embeddings");
|
||||
tok_embd = model_other->tok_embd;
|
||||
}
|
||||
|
||||
auto inp = std::make_unique<llm_graph_input_embd>(n_embd);
|
||||
|
||||
inp->tokens = ggml_new_tensor_1d(ctx0, GGML_TYPE_I32, n_tokens);
|
||||
ggml_set_input(inp->tokens);
|
||||
|
||||
ggml_tensor * inpL = ggml_get_rows(ctx0, tok_embd, inp->tokens);
|
||||
cb(inpL, "inp_noise_embd", -1);
|
||||
|
||||
res->add_input(std::move(inp));
|
||||
|
||||
for (int il = 0; il < n_layer; ++il) {
|
||||
const auto & layer = model.layers[il];
|
||||
|
||||
ggml_tensor * noise_norm = build_norm(inpL, layer.attn_norm, NULL, LLM_NORM_RMS, il);
|
||||
cb(noise_norm, "noise_norm", il);
|
||||
|
||||
ggml_tensor * Qcur = build_lora_mm(layer.wq, noise_norm);
|
||||
ggml_tensor * Kcur = build_lora_mm(layer.wk, noise_norm);
|
||||
ggml_tensor * Vcur = build_lora_mm(layer.wv, noise_norm);
|
||||
|
||||
Qcur = ggml_reshape_3d(ctx0, Qcur, n_embd_head, n_head, n_tokens);
|
||||
Kcur = ggml_reshape_3d(ctx0, Kcur, n_embd_head, n_head_kv, n_tokens);
|
||||
Vcur = ggml_reshape_3d(ctx0, Vcur, n_embd_head, n_head_kv, n_tokens);
|
||||
|
||||
Qcur = build_norm(Qcur, layer.attn_q_norm, NULL, LLM_NORM_RMS, il);
|
||||
Kcur = build_norm(Kcur, layer.attn_k_norm, NULL, LLM_NORM_RMS, il);
|
||||
|
||||
Qcur = ggml_rope_ext(
|
||||
ctx0, Qcur, inp_pos, nullptr,
|
||||
n_rot, rope_type, n_ctx_orig, freq_base, freq_scale,
|
||||
ext_factor, attn_factor, beta_fast, beta_slow
|
||||
);
|
||||
Kcur = ggml_rope_ext(
|
||||
ctx0, Kcur, inp_pos, nullptr,
|
||||
n_rot, rope_type, n_ctx_orig, freq_base, freq_scale,
|
||||
ext_factor, attn_factor, beta_fast, beta_slow
|
||||
);
|
||||
cb(Qcur, "Qcur", il);
|
||||
cb(Kcur, "Kcur", il);
|
||||
cb(Vcur, "Vcur", il);
|
||||
|
||||
// cache-aware, non-causal attention
|
||||
ggml_tensor * cur = use_iswa
|
||||
? build_attn(inp_attn_iswa, layer.wo, NULL, NULL, Qcur, Kcur, Vcur, nullptr, nullptr, nullptr, kq_scale, il)
|
||||
: build_attn(inp_attn, layer.wo, NULL, NULL, Qcur, Kcur, Vcur, nullptr, nullptr, nullptr, kq_scale, il);
|
||||
|
||||
ggml_tensor * ffn_inp = ggml_add(ctx0, cur, inpL);
|
||||
cb(ffn_inp, "ffn_inp", il);
|
||||
|
||||
cur = build_norm(ffn_inp, layer.ffn_norm, NULL, LLM_NORM_RMS, il);
|
||||
cb(cur, "ffn_norm", il);
|
||||
|
||||
cur = build_ffn(cur,
|
||||
layer.ffn_up, NULL, NULL,
|
||||
layer.ffn_gate, NULL, NULL,
|
||||
layer.ffn_down, NULL, NULL,
|
||||
NULL,
|
||||
LLM_FFN_SILU, LLM_FFN_PAR, il);
|
||||
cb(cur, "ffn_out", il);
|
||||
|
||||
cur = ggml_add(ctx0, cur, ffn_inp);
|
||||
cb(cur, "l_out", il);
|
||||
|
||||
inpL = cur;
|
||||
}
|
||||
|
||||
ggml_tensor * cur = build_norm(inpL, model.output_norm, NULL, LLM_NORM_RMS, -1);
|
||||
cb(cur, "result_norm", -1);
|
||||
|
||||
res->t_embd = cur;
|
||||
|
||||
// lm_head from the target model (shared via ctx_other)
|
||||
auto * output = model.output;
|
||||
if (output == nullptr) {
|
||||
GGML_ASSERT(cparams.ctx_other != nullptr);
|
||||
const auto * model_other = llama_get_model(cparams.ctx_other);
|
||||
GGML_ASSERT(model_other->output != nullptr && "DFlash decoder requires the target model's output projection");
|
||||
output = model_other->output;
|
||||
}
|
||||
|
||||
cur = build_lora_mm(output, cur);
|
||||
cb(cur, "result_output", -1);
|
||||
res->t_logits = cur;
|
||||
|
||||
ggml_build_forward_expand(gf, cur);
|
||||
}
|
||||
@@ -1122,6 +1122,22 @@ struct llama_model_eagle3 : public llama_model_base {
|
||||
};
|
||||
|
||||
|
||||
struct llama_model_dflash : public llama_model_base {
|
||||
llama_model_dflash(const struct llama_model_params & params) : llama_model_base(params) {}
|
||||
void load_arch_hparams(llama_model_loader & ml) override;
|
||||
void load_arch_tensors(llama_model_loader & ml) override;
|
||||
|
||||
template <bool is_enc>
|
||||
struct graph : public llm_graph_context {
|
||||
graph(const llama_model & model, const llm_graph_params & params);
|
||||
|
||||
ggml_tensor * build_inp_embd_enc() const;
|
||||
};
|
||||
|
||||
std::unique_ptr<llm_graph_context> build_arch_graph(const llm_graph_params & params) const override;
|
||||
};
|
||||
|
||||
|
||||
struct llama_model_mistral4 : public llama_model_deepseek2 {
|
||||
llama_model_mistral4(const struct llama_model_params & params) : llama_model_deepseek2(params) {}
|
||||
// reuse load_arch_hparams and load_arch_tensors from llama_model_deepseek2
|
||||
|
||||
@@ -25,7 +25,7 @@ using json = nlohmann::ordered_json;
|
||||
static int main_automated_tests(void);
|
||||
|
||||
static void run_multiple(const std::string& dir_path, bool stop_on_first_failure, const json& input, bool use_common = false);
|
||||
static void run_single(const std::string& contents, json input, bool use_common = false, const std::string & output_path = "");
|
||||
static void run_single(const std::string& contents, json input, bool use_common = false, bool dump_prog = false, const std::string & output_path = "");
|
||||
|
||||
static std::string HELP = R"(
|
||||
Usage: test-chat-template [OPTIONS] PATH_TO_TEMPLATE
|
||||
@@ -35,6 +35,7 @@ Options:
|
||||
--json <path> Path to the JSON input file.
|
||||
--stop-on-first-fail Stop testing on the first failure (default: false).
|
||||
--no-common Use direct Jinja engine instead of common chat templates (default: use common).
|
||||
--dump-prog Dump the parsed program for debugging (only for single template runs).
|
||||
--output <path> Path to output results (only for single template runs).
|
||||
If PATH_TO_TEMPLATE is a file, runs that single template.
|
||||
If PATH_TO_TEMPLATE is a directory, runs all .jinja files in that directory.
|
||||
@@ -118,6 +119,7 @@ int main(int argc, char ** argv) {
|
||||
std::string & json_to_use = DEFAULT_JSON;
|
||||
bool stop_on_first_fail = false;
|
||||
bool use_common = true;
|
||||
bool dump_prog = false;
|
||||
|
||||
for (size_t i = 1; i < args.size(); i++) {
|
||||
if (args[i] == "--help" || args[i] == "-h") {
|
||||
@@ -136,6 +138,8 @@ int main(int argc, char ** argv) {
|
||||
i++;
|
||||
} else if (args[i] == "--no-common") {
|
||||
use_common = false;
|
||||
} else if (args[i] == "--dump-prog") {
|
||||
dump_prog = true;
|
||||
} else if (tmpl_path.empty()) {
|
||||
tmpl_path = args[i];
|
||||
} else {
|
||||
@@ -172,7 +176,7 @@ int main(int argc, char ** argv) {
|
||||
std::string contents = std::string(
|
||||
std::istreambuf_iterator<char>(infile),
|
||||
std::istreambuf_iterator<char>());
|
||||
run_single(contents, input_json, use_common, output_path);
|
||||
run_single(contents, input_json, use_common, dump_prog, output_path);
|
||||
} else {
|
||||
std::cerr << "Error: PATH_TO_TEMPLATE is not a valid file or directory: " << tmpl_path << "\n";
|
||||
return 1;
|
||||
@@ -276,11 +280,21 @@ static jinja::value_string format_using_direct_engine(
|
||||
}
|
||||
|
||||
|
||||
void run_single(const std::string& contents, json input, bool use_common, const std::string & output_path) {
|
||||
void run_single(const std::string& contents, json input, bool use_common, bool dump_prog, const std::string & output_path) {
|
||||
jinja::enable_debug(true);
|
||||
|
||||
jinja::value_string output_parts;
|
||||
|
||||
if (dump_prog) {
|
||||
jinja::lexer lexer;
|
||||
auto lexer_res = lexer.tokenize(contents);
|
||||
jinja::program ast = jinja::parse_from_tokens(lexer_res);
|
||||
std::string prog_dump = jinja::runtime::debug_dump_program(ast, contents);
|
||||
std::cout << "\n=== DUMPED PROGRAM ===\n";
|
||||
std::cout << prog_dump << "\n";
|
||||
return;
|
||||
}
|
||||
|
||||
if (use_common) {
|
||||
std::string bos_token = "<s>";
|
||||
std::string eos_token = "</s>";
|
||||
|
||||
@@ -5593,6 +5593,77 @@ static void test_template_output_peg_parsers(bool detailed_debug) {
|
||||
.expect_content("Hello, world!\nWhat's up?")
|
||||
.run();
|
||||
}
|
||||
|
||||
// MiniCPM5 - XML tool calls with <function name="..."><param name="...">...</param></function>
|
||||
{
|
||||
auto tst = peg_tester("models/templates/openbmb-MiniCPM5-1B.jinja", detailed_debug);
|
||||
|
||||
tst.test("Hello, world!\nWhat's up?")
|
||||
.enable_thinking(false)
|
||||
.reasoning_format(COMMON_REASONING_FORMAT_AUTO)
|
||||
.expect(message_assist)
|
||||
.run();
|
||||
|
||||
tst.test(R"(<function name="python"><param name="code">print('Hello, World!')</param></function>)")
|
||||
.enable_thinking(false)
|
||||
.reasoning_format(COMMON_REASONING_FORMAT_AUTO)
|
||||
.tools({ python_tool })
|
||||
.expect_tool_calls({ { "python", R"#({"code": "print('Hello, World!')"})#", {} } })
|
||||
.run();
|
||||
|
||||
tst.test(R"(<function name="empty_args"></function>)")
|
||||
.enable_thinking(false)
|
||||
.reasoning_format(COMMON_REASONING_FORMAT_AUTO)
|
||||
.tools({ empty_args_tool })
|
||||
.expect(simple_assist_msg("", "", "empty_args", "{}"))
|
||||
.run();
|
||||
|
||||
tst.test(R"(<function name="python"><param name="code">print('x')</param></function>)")
|
||||
.enable_thinking(false)
|
||||
.reasoning_format(COMMON_REASONING_FORMAT_AUTO)
|
||||
.parallel_tool_calls(true)
|
||||
.tools({ python_tool })
|
||||
.expect_tool_calls({ { "python", R"#({"code": "print('x')"})#", {} } })
|
||||
.run();
|
||||
|
||||
// CDATA lets a string value carry characters that would otherwise close the tag.
|
||||
tst.test(R"(<function name="html"><param name="markup"><![CDATA[<a href="/x">hi</a> </param>]]></param></function>)")
|
||||
.enable_thinking(false)
|
||||
.reasoning_format(COMMON_REASONING_FORMAT_AUTO)
|
||||
.tools({ html_tool })
|
||||
.expect_tool_calls({ { "html", R"#({"markup": "<a href=\"/x\">hi</a> </param>"})#", {} } })
|
||||
.run();
|
||||
|
||||
tst.test(R"(I'm thinking</think><function name="python"><param name="code">print('hey')</param></function>)")
|
||||
.enable_thinking(true)
|
||||
.reasoning_format(COMMON_REASONING_FORMAT_AUTO)
|
||||
.tools({ python_tool })
|
||||
.expect_reasoning("I'm thinking")
|
||||
.expect_tool_calls({ { "python", R"#({"code": "print('hey')"})#", {} } })
|
||||
.run();
|
||||
|
||||
tst.test(R"(<function name="python"><param name="code">print('x')</param></function>
|
||||
<function name="python"><param name="code">print('y')</param></function>)")
|
||||
.enable_thinking(false)
|
||||
.reasoning_format(COMMON_REASONING_FORMAT_AUTO)
|
||||
.parallel_tool_calls(true)
|
||||
.tools({ python_tool })
|
||||
.expect_tool_calls({
|
||||
{ "python", R"#({"code": "print('x')"})#", {} },
|
||||
{ "python", R"#({"code": "print('y')"})#", {} },
|
||||
})
|
||||
.run();
|
||||
|
||||
tst.test(" thinking</think>Hello, world!\nWhat's up?")
|
||||
.enable_thinking(true)
|
||||
.reasoning_format(COMMON_REASONING_FORMAT_AUTO)
|
||||
.messages({ message_user, message_assist_prefill_reasoning })
|
||||
.add_generation_prompt(false)
|
||||
.continue_final_message(COMMON_CHAT_CONTINUATION_REASONING)
|
||||
.expect_reasoning("I'm thinking")
|
||||
.expect_content("Hello, world!\nWhat's up?")
|
||||
.run();
|
||||
}
|
||||
}
|
||||
|
||||
static void test_template_generation_prompt() {
|
||||
@@ -5740,6 +5811,13 @@ static void test_template_generation_prompt() {
|
||||
check(tmpls, continuation_content(), "<|Assistant|><think>I'm thinking</think>Hello, ");
|
||||
check(tmpls, continuation_reasoning(), "<|Assistant|><think>I'm");
|
||||
}
|
||||
|
||||
{
|
||||
auto tmpls = read_templates("models/templates/openbmb-MiniCPM5-1B.jinja");
|
||||
check(tmpls, basic(), "<|im_start|>assistant\n<think>\n");
|
||||
check(tmpls, continuation_content(), "<|im_start|>assistant\n<think>\nI'm thinking\n</think>\n\nHello, ");
|
||||
check(tmpls, continuation_reasoning(), "<|im_start|>assistant\n<think>\nI'm");
|
||||
}
|
||||
}
|
||||
|
||||
// Test the developer role to system workaround with a simple mock template
|
||||
|
||||
@@ -1584,6 +1584,36 @@ static void test_array_methods(testing & t) {
|
||||
"6"
|
||||
);
|
||||
|
||||
test_template(t, "array|min",
|
||||
"{{ [tool_calls_count, tool_sep_count]|min }}",
|
||||
{{"tool_calls_count", 2}, {"tool_sep_count", 1}},
|
||||
"1"
|
||||
);
|
||||
|
||||
test_template(t, "array|max",
|
||||
"{{ [tool_calls_count, tool_sep_count]|max }}",
|
||||
{{"tool_calls_count", 2}, {"tool_sep_count", 1}},
|
||||
"2"
|
||||
);
|
||||
|
||||
test_template(t, "array|min attribute",
|
||||
"{{ items|min(attribute='x') }}",
|
||||
{{"items", json::array({
|
||||
json({{"x", 2}}),
|
||||
json({{"x", 1}}),
|
||||
})}},
|
||||
"{'x': 1}"
|
||||
);
|
||||
|
||||
test_template(t, "array|max attribute",
|
||||
"{{ items|max(attribute='x') }}",
|
||||
{{"items", json::array({
|
||||
json({{"x", 2}}),
|
||||
json({{"x", 1}}),
|
||||
})}},
|
||||
"{'x': 2}"
|
||||
);
|
||||
|
||||
// not used by any chat templates
|
||||
// test_template(t, "array.insert()",
|
||||
// "{% set _ = arr.insert(1, 'x') %}{{ arr|join(',') }}",
|
||||
|
||||
@@ -451,7 +451,7 @@ static int save_models(const llm_arch target_arch, const size_t seed, const ggml
|
||||
if (arch == LLM_ARCH_GEMMA4 || arch == LLM_ARCH_GEMMA4_ASSISTANT) {
|
||||
continue; // FIXME: ISWA KV cache initialization needs more fixture params
|
||||
}
|
||||
if (arch == LLM_ARCH_EAGLE3) {
|
||||
if (arch == LLM_ARCH_EAGLE3 || arch == LLM_ARCH_DFLASH) {
|
||||
continue;
|
||||
}
|
||||
for (bool moe : {false, true}) {
|
||||
@@ -557,7 +557,7 @@ static int test_backends(const llm_arch target_arch, const size_t seed, const gg
|
||||
if (arch == LLM_ARCH_GEMMA4 || arch == LLM_ARCH_GEMMA4_ASSISTANT) {
|
||||
continue; // FIXME: ISWA KV cache initialization needs more fixture params
|
||||
}
|
||||
if (arch == LLM_ARCH_EAGLE3) {
|
||||
if (arch == LLM_ARCH_EAGLE3 || arch == LLM_ARCH_DFLASH) {
|
||||
continue;
|
||||
}
|
||||
|
||||
|
||||
@@ -1538,6 +1538,19 @@ private:
|
||||
/* media_path */ params_base.media_path,
|
||||
/* force_pure_content */ params_base.force_pure_content_parser
|
||||
};
|
||||
|
||||
{
|
||||
auto caps = common_chat_templates_get_caps(chat_params.tmpls.get());
|
||||
auto it = params_base.default_template_kwargs.find("preserve_reasoning");
|
||||
bool supported = caps.at("supports_preserve_reasoning");
|
||||
bool enabled = it != params_base.default_template_kwargs.end();
|
||||
if (supported && !enabled) {
|
||||
SRV_INF("%s", "chat template supports preserving reasoning, consider enabling it via --reasoning-preserve\n");
|
||||
}
|
||||
if (!supported && enabled) {
|
||||
SRV_WRN("%s", "chat template does NOT support preserving reasoning, --reasoning-preserve has no effect\n");
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
return true;
|
||||
@@ -2450,6 +2463,8 @@ private:
|
||||
|
||||
server_slot * slot = get_slot_by_cmpl_id(task.params.control_cmpl_id);
|
||||
if (slot == nullptr) {
|
||||
SRV_WRN("control %s on unknown completion id=%s, no live slot\n",
|
||||
task.params.control_action.c_str(), task.params.control_cmpl_id.c_str());
|
||||
res->success = false;
|
||||
res->message = "no active completion for this id";
|
||||
queue_results.send(std::move(res));
|
||||
|
||||
@@ -1983,7 +1983,10 @@ void server_models_routes::init_routes() {
|
||||
cli.set_read_timeout(0, STREAM_LOOKUP_TIMEOUT_MS * 1000);
|
||||
cli.set_write_timeout(0, STREAM_LOOKUP_TIMEOUT_MS * 1000);
|
||||
auto resp = cli.Delete(child_path.c_str());
|
||||
(void) resp; // best effort, 404 and network errors are equivalent to no op
|
||||
(void) resp; // the child logs its own miss when the session is unknown there
|
||||
} else {
|
||||
SRV_WRN("router stop for unknown conv_id=%s, no owning child in the conv map\n",
|
||||
conv_id.c_str());
|
||||
}
|
||||
// drop the tracking entry, the session is being torn down
|
||||
models.conv_models.forget(conv_id);
|
||||
|
||||
@@ -218,6 +218,13 @@ void stream_session_manager::evict_and_cancel(const std::string & conversation_i
|
||||
std::unique_lock<std::shared_mutex> lock(map_mu);
|
||||
auto it = sessions.find(conversation_id);
|
||||
if (it == sessions.end()) {
|
||||
std::string live;
|
||||
for (const auto & kv : sessions) {
|
||||
if (!live.empty()) live += ", ";
|
||||
live += kv.first;
|
||||
}
|
||||
SRV_WRN("stop on unknown stream session, conv_id=%s matched nothing, %zu live: [%s]\n",
|
||||
conversation_id.c_str(), sessions.size(), live.c_str());
|
||||
return;
|
||||
}
|
||||
s = it->second;
|
||||
|
||||
+1
-1
@@ -33,7 +33,7 @@
|
||||
|
||||
{#if !readonly && onRemove}
|
||||
<div
|
||||
class="absolute top-10 right-2 flex items-center justify-center opacity-0 transition-opacity group-focus-within:opacity-100 group-hover:opacity-100"
|
||||
class="absolute top-10 right-2 flex items-center justify-center opacity-0 transition-opacity group-hover:opacity-100"
|
||||
>
|
||||
<ActionIcon icon={X} tooltip="Remove" stopPropagationOnClick onclick={() => onRemove?.()} />
|
||||
</div>
|
||||
|
||||
+1
-1
@@ -56,7 +56,7 @@
|
||||
<div class="relative flex h-6 items-center justify-between">
|
||||
<div class="right-0 flex items-center gap-2 opacity-100 transition-opacity">
|
||||
<div
|
||||
class="pointer-events-auto inset-0 flex items-center gap-1 opacity-0 transition-all duration-150 group-focus-within:opacity-100 group-hover:opacity-100"
|
||||
class="pointer-events-auto inset-0 flex items-center gap-1 opacity-0 transition-all duration-150 group-hover:opacity-100"
|
||||
>
|
||||
<ActionIcon icon={Edit} tooltip="Edit" onclick={editCtx.handleEdit} />
|
||||
<ActionIcon icon={Trash2} tooltip="Delete" onclick={onDelete} />
|
||||
|
||||
+81
-56
@@ -39,6 +39,7 @@
|
||||
depth = 0
|
||||
}: Props = $props();
|
||||
|
||||
let renderActionsDropdown = $state(false);
|
||||
let dropdownOpen = $state(false);
|
||||
|
||||
let isLoading = $derived(getAllLoadingChats().includes(conversation.id));
|
||||
@@ -70,10 +71,26 @@
|
||||
}
|
||||
}
|
||||
|
||||
// Stop rendering the actions dropdown when the pointer leaves, unless its menu is currently open.
function handleMouseLeave() {
	if (dropdownOpen) return;

	renderActionsDropdown = false;
}

// Start rendering the actions dropdown.
function handleMouseOver() {
	renderActionsDropdown = true;
}

// Forward the selection of this conversation to the optional onSelect callback.
function handleSelect() {
	const { id } = conversation;

	onSelect?.(id);
}
|
||||
|
||||
// Whenever the dropdown is not open, make sure the actions dropdown is not rendered.
$effect(() => {
	if (dropdownOpen) return;

	renderActionsDropdown = false;
});
|
||||
|
||||
onMount(() => {
|
||||
document.addEventListener('edit-active-conversation', handleGlobalEditEvent as EventListener);
|
||||
|
||||
@@ -86,19 +103,23 @@
|
||||
});
|
||||
</script>
|
||||
|
||||
<div
|
||||
class="conversation-item group relative flex min-h-9 w-full items-center justify-between space-x-3 rounded-lg py-1.5 transition-colors hover:bg-foreground/10 {isActive
|
||||
<!-- svelte-ignore a11y_mouse_events_have_key_events -->
|
||||
<button
|
||||
class="group flex min-h-9 w-full cursor-pointer items-center justify-between space-x-3 rounded-lg py-1.5 text-left transition-colors hover:bg-foreground/10 {isActive
|
||||
? 'bg-foreground/5 text-accent-foreground'
|
||||
: ''} px-3"
|
||||
onclick={handleSelect}
|
||||
onmouseover={handleMouseOver}
|
||||
onmouseleave={handleMouseLeave}
|
||||
onfocusin={handleMouseOver}
|
||||
onfocusout={(e) => {
|
||||
if (!e.currentTarget.contains(e.relatedTarget as Node | null)) {
|
||||
handleMouseLeave();
|
||||
}
|
||||
}}
|
||||
>
|
||||
<button
|
||||
class="absolute inset-0 z-0 cursor-pointer rounded-lg focus:outline-none focus-visible:ring-2 focus-visible:ring-ring"
|
||||
onclick={handleSelect}
|
||||
aria-label={conversation.name}
|
||||
>
|
||||
</button>
|
||||
<div
|
||||
class="pointer-events-none relative z-10 flex min-w-0 flex-1 items-center gap-2"
|
||||
class="flex min-w-0 flex-1 items-center gap-2"
|
||||
style:padding-left="{depth * FORK_TREE_DEPTH_PADDING}px"
|
||||
>
|
||||
{#if depth > 0}
|
||||
@@ -109,7 +130,7 @@
|
||||
<a
|
||||
{...props}
|
||||
href={RouterService.chat(conversation.forkedFromConversationId)}
|
||||
class="pointer-events-auto flex shrink-0 items-center text-muted-foreground transition-colors hover:text-foreground"
|
||||
class="flex shrink-0 items-center text-muted-foreground transition-colors hover:text-foreground"
|
||||
>
|
||||
<GitBranch class="h-3.5 w-3.5" />
|
||||
</a>
|
||||
@@ -125,15 +146,18 @@
|
||||
{#if isLoading}
|
||||
<Tooltip.Root>
|
||||
<Tooltip.Trigger>
|
||||
<button
|
||||
class="stop-button pointer-events-auto flex h-4 w-4 shrink-0 cursor-pointer items-center justify-center rounded text-muted-foreground transition-colors hover:text-foreground"
|
||||
<div
|
||||
class="stop-button flex h-4 w-4 shrink-0 cursor-pointer items-center justify-center rounded text-muted-foreground transition-colors hover:text-foreground"
|
||||
onclick={handleStop}
|
||||
onkeydown={(e) => e.key === 'Enter' && handleStop(e)}
|
||||
role="button"
|
||||
tabindex="0"
|
||||
aria-label="Stop generation"
|
||||
>
|
||||
<Loader2 class="loading-icon h-3.5 w-3.5 animate-spin" />
|
||||
|
||||
<Square class="stop-icon hidden h-3 w-3 fill-current text-destructive" />
|
||||
</button>
|
||||
</div>
|
||||
</Tooltip.Trigger>
|
||||
|
||||
<Tooltip.Content>
|
||||
@@ -145,50 +169,52 @@
|
||||
<TruncatedText text={conversation.name} class="text-sm font-medium" showTooltip={false} />
|
||||
</div>
|
||||
|
||||
<div class="actions pointer-events-auto relative z-20 flex items-center">
|
||||
<DropdownMenuActions
|
||||
triggerIcon={MoreHorizontal}
|
||||
triggerTooltip="More actions"
|
||||
bind:open={dropdownOpen}
|
||||
actions={[
|
||||
{
|
||||
icon: conversation.pinned ? PinOff : Pin,
|
||||
label: conversation.pinned ? 'Unpin' : 'Pin',
|
||||
onclick: (e: Event) => {
|
||||
e.stopPropagation();
|
||||
handleTogglePin();
|
||||
}
|
||||
},
|
||||
{
|
||||
icon: Pencil,
|
||||
label: 'Edit',
|
||||
onclick: handleEdit,
|
||||
shortcut: ['shift', 'cmd', 'e']
|
||||
},
|
||||
{
|
||||
icon: Download,
|
||||
label: 'Export',
|
||||
onclick: (e: Event) => {
|
||||
e.stopPropagation();
|
||||
conversationsStore.downloadConversation(conversation.id);
|
||||
{#if renderActionsDropdown}
|
||||
<div class="actions flex items-center">
|
||||
<DropdownMenuActions
|
||||
triggerIcon={MoreHorizontal}
|
||||
triggerTooltip="More actions"
|
||||
bind:open={dropdownOpen}
|
||||
actions={[
|
||||
{
|
||||
icon: conversation.pinned ? PinOff : Pin,
|
||||
label: conversation.pinned ? 'Unpin' : 'Pin',
|
||||
onclick: (e: Event) => {
|
||||
e.stopPropagation();
|
||||
handleTogglePin();
|
||||
}
|
||||
},
|
||||
shortcut: ['shift', 'cmd', 's']
|
||||
},
|
||||
{
|
||||
icon: Trash2,
|
||||
label: 'Delete',
|
||||
onclick: handleDelete,
|
||||
variant: 'destructive',
|
||||
shortcut: ['shift', 'cmd', 'd'],
|
||||
separator: true
|
||||
}
|
||||
]}
|
||||
/>
|
||||
</div>
|
||||
</div>
|
||||
{
|
||||
icon: Pencil,
|
||||
label: 'Edit',
|
||||
onclick: handleEdit,
|
||||
shortcut: ['shift', 'cmd', 'e']
|
||||
},
|
||||
{
|
||||
icon: Download,
|
||||
label: 'Export',
|
||||
onclick: (e: Event) => {
|
||||
e.stopPropagation();
|
||||
conversationsStore.downloadConversation(conversation.id);
|
||||
},
|
||||
shortcut: ['shift', 'cmd', 's']
|
||||
},
|
||||
{
|
||||
icon: Trash2,
|
||||
label: 'Delete',
|
||||
onclick: handleDelete,
|
||||
variant: 'destructive',
|
||||
shortcut: ['shift', 'cmd', 'd'],
|
||||
separator: true
|
||||
}
|
||||
]}
|
||||
/>
|
||||
</div>
|
||||
{/if}
|
||||
</button>
|
||||
|
||||
<style>
|
||||
.conversation-item {
|
||||
button {
|
||||
:global([data-slot='dropdown-menu-trigger']:not([data-state='open'])) {
|
||||
opacity: 0;
|
||||
}
|
||||
@@ -213,8 +239,7 @@
|
||||
}
|
||||
}
|
||||
|
||||
&:is(:hover) .stop-button,
|
||||
&:focus-within .stop-button {
|
||||
&:is(:hover) .stop-button {
|
||||
:global(.stop-icon) {
|
||||
display: block;
|
||||
}
|
||||
|
||||
@@ -154,7 +154,13 @@ class ChatStore {
|
||||
});
|
||||
if (convId === conversationsStore.activeConversation?.id) this.currentResponse = response;
|
||||
}
|
||||
private clearChatStreaming(convId: string): void {
|
||||
private clearChatStreaming(convId: string, messageId?: string): void {
|
||||
// session aware: a stale generation must not wipe a newer one's streaming state on the
|
||||
// same conversation, that would drop the frozen stop identity and stop the wrong session
|
||||
if (messageId !== undefined) {
|
||||
const cur = this.chatStreamingStates.get(convId);
|
||||
if (cur && cur.messageId !== messageId) return;
|
||||
}
|
||||
this.chatStreamingStates.delete(convId);
|
||||
if (convId === conversationsStore.activeConversation?.id) this.currentResponse = '';
|
||||
}
|
||||
@@ -1055,11 +1061,14 @@ class ChatStore {
|
||||
modelOverride?: string | null,
|
||||
firstUserMessageContent?: string
|
||||
): Promise<void> {
|
||||
let effectiveModel = modelOverride;
|
||||
// the ::model suffix in the stream identity is only for router mode, where it routes to the
|
||||
// owning child. in single-model mode the identity stays the bare conv id so that attach, stop
|
||||
// and reattach all agree, regardless of fresh send vs regenerate passing a resolved model
|
||||
let effectiveModel: string | null | undefined = undefined;
|
||||
|
||||
if (isRouterMode() && !effectiveModel) {
|
||||
if (isRouterMode()) {
|
||||
const conversationModel = this.getConversationModel(allMessages);
|
||||
effectiveModel = selectedModelName() || conversationModel;
|
||||
effectiveModel = modelOverride || selectedModelName() || conversationModel;
|
||||
}
|
||||
|
||||
if (isRouterMode() && effectiveModel) {
|
||||
@@ -1074,6 +1083,9 @@ class ChatStore {
|
||||
let resolvedModel: string | null = null;
|
||||
let modelPersisted = false;
|
||||
const convId = assistantMessage.convId;
|
||||
// freeze the POST identity from t0 so a stop cancels with the exact session key,
|
||||
// never a stale or empty model resolved later
|
||||
this.setChatStreaming(convId, streamedContent, currentMessageId, effectiveModel);
|
||||
|
||||
const recordModel = (modelName: string | null | undefined, persistImmediately = true): void => {
|
||||
if (!modelName) return;
|
||||
@@ -1103,7 +1115,7 @@ class ChatStore {
|
||||
};
|
||||
|
||||
const updateStreamingUI = () => {
|
||||
this.setChatStreaming(convId, streamedContent, currentMessageId);
|
||||
this.setChatStreaming(convId, streamedContent, currentMessageId, effectiveModel);
|
||||
const idx = conversationsStore.findMessageIndex(currentMessageId);
|
||||
conversationsStore.updateMessageAtIndex(idx, { content: streamedContent });
|
||||
};
|
||||
@@ -1111,7 +1123,7 @@ class ChatStore {
|
||||
const cleanupStreamingState = () => {
|
||||
this.setStreamingActive(false);
|
||||
this.setChatLoading(convId, false);
|
||||
this.clearChatStreaming(convId);
|
||||
this.clearChatStreaming(convId, currentMessageId);
|
||||
this.setProcessingState(convId, null);
|
||||
};
|
||||
|
||||
@@ -1128,7 +1140,7 @@ class ChatStore {
|
||||
onReasoningChunk: (chunk: string) => {
|
||||
streamedReasoningContent += chunk;
|
||||
// mark streaming state so a stop mid-thinking can persist the partial reasoning
|
||||
this.setChatStreaming(convId, streamedContent, currentMessageId);
|
||||
this.setChatStreaming(convId, streamedContent, currentMessageId, effectiveModel);
|
||||
const idx = conversationsStore.findMessageIndex(currentMessageId);
|
||||
conversationsStore.updateMessageAtIndex(idx, {
|
||||
reasoningContent: streamedReasoningContent
|
||||
@@ -1405,7 +1417,7 @@ class ChatStore {
|
||||
// detached drain keeps producing tokens until eos or max_tokens. use the frozen identity
|
||||
// captured when the session started, not the live dropdown
|
||||
const streamStateForStop = this.chatStreamingStates.get(convId);
|
||||
const modelForStop = streamStateForStop?.model ?? selectedModelName();
|
||||
const modelForStop = streamStateForStop?.model;
|
||||
void ChatService.cancelServerStream(convId, modelForStop);
|
||||
this.abortRequest(convId);
|
||||
this.setChatLoading(convId, false);
|
||||
@@ -1846,6 +1858,14 @@ class ChatStore {
|
||||
updateStreamingContent(originalContent + appendedContent);
|
||||
this.setChatReasoning(msg.convId, false);
|
||||
},
|
||||
onCompletionId: (id: string) => {
|
||||
if (!id) return;
|
||||
// refresh the message id so a later skip targets the live slot after a continue
|
||||
conversationsStore.updateMessageAtIndex(conversationsStore.findMessageIndex(msg.id), {
|
||||
completionId: id
|
||||
});
|
||||
DatabaseService.updateMessage(msg.id, { completionId: id }).catch(() => {});
|
||||
},
|
||||
onReasoningChunk: (chunk: string) => {
|
||||
appendedReasoning += chunk;
|
||||
hasReceivedContent = true;
|
||||
|
||||
Reference in New Issue
Block a user