Compare commits

..

9 Commits

Author SHA1 Message Date
Pascal 7cb8576e7c ui: fix stop and reasoning skip in single-model mode (#25084) 2026-06-28 21:06:43 +02:00
Ruixiang Wang fa72bc6826 dflash: refactor draft model conversion (#25110)
* dflash: refactor draft model conversion

* apply fix for eagle3 convert
2026-06-28 20:31:48 +02:00
Aldehir Rojas c818263f2a chat : implement minicpm5 parser (#24889)
* Add minicpm5 tool call parser

* Refactor MiniCPM5 PEG parser per review feedback

* Fix jinja min/max API to match Jinja2

* modify by review

* MiniCPM5: use autoparser for XML tool calls and fix grammar preserved-token triggers

* MiniCPM5: fix streaming tool-arg placeholder and remove alt XML markers

* skip min/max attribute tests in -py mode

* test-jinja: use real expected output for min/max attribute tests

* MiniCPM5: revert shared mapper and history fallbacks per review

Drop streaming tool-arg placeholder workarounds from the generic PEG
mapper and restore strict tool-call argument JSON parsing so MiniCPM5
support stays limited to autoparser/diff-analyzer changes.

* chat : refactor minicpm5 back to dedicated parser

* cont : simplify grammar

* cont : refactor

* cont : fixes

* cont : rename template to openbmb-MiniCPM5-1B.jinja

* cont : add message delimiters

* cont : fix tests

---------

Co-authored-by: zhangtao <zhangtao2@modelbest.cn>
Co-authored-by: 张涛 <>
2026-06-28 16:53:32 +02:00
Xuan-Son Nguyen f68a788b0b jinja: add --dump-prog for debugging (#25086)
* jinja: add --dump-prog for debugging

* Update common/jinja/runtime.cpp

Co-authored-by: Sigbjørn Skjæret <1629204+CISC@users.noreply.github.com>

---------

Co-authored-by: Sigbjørn Skjæret <1629204+CISC@users.noreply.github.com>
2026-06-28 15:50:31 +02:00
Ruixiang Wang d1b34251bc spec : add DFlash support (#22105)
* spec: add DFlash v2 support

* dflash: support sliding window attention per layer_types

* docs: add dflash section

---------

Co-authored-by: Kashif Rasul <kashif.rasul@gmail.com>
2026-06-28 16:01:34 +03:00
Adrien Gallouët c1a1c8ee94 common : allow --offline in llama download (#25091)
Expose the existing --offline flag to `llama download` so a script can
run it to check whether a model is already cached and ready to be served
without touching the network.

Also fix a latent use-after-free in the URL-task on_done callback:
first_path is block-scoped and was captured by reference, but invoked
after the block ends.

Signed-off-by: Adrien Gallouët <angt@huggingface.co>
2026-06-28 12:34:11 +02:00
Georgi Gerganov 27c8bb4f63 logs : reduce v2 (#25078)
* server : reduce logs

* cont : common

* cont : spec

* cont : CMN_ -> COM_
2026-06-28 08:52:15 +03:00
Hongqiang Wang ebd048fc5e opencl: flash attention improvement (#25069)
* opencl: rework FA kernel for f16 and f32

* opencl: flash-attention prefill prepass kernels

- flash_attn_kv_pad_f16    pads the tail KV tile to a BLOCK_N multiple
- flash_attn_mask_pad_f16  pads the matching mask tile
- flash_attn_blk_f16       classifies each KV tile per query block as
                           fully masked / mixed / fully unmasked, so
                           the main kernel can skip fully-masked tiles
                           and the mask lookup for fully-unmasked ones

* opencl: FA kernels for q4_0 and q8_0

* opencl: `set_rows` for f32 to q8_0/q4_0

* opencl: dequant kernels for q4_0 and q8_0

* opencl: add FA tile tuning table with override

* opencl: wire host side for FA

* opencl: q4_0 MoE tensors are also SOA'ed

* opencl: cosmetic fix

* opencl: refactor, also clarify some code paths in comments

* opencl: fix infinity for `-cl-finite-math-only`

---------

Co-authored-by: Li He <lih@qti.qualcomm.com>
2026-06-27 15:36:06 -07:00
Gaurav Garg 0ed235ea2c [CUDA] Added a cudaMemcpy2DAsync fast path to ggml_cuda_cpy (#25057)
* [CUDA] Added a cudaMemcpy2DAsync fast path to ggml_cuda_cpy

Add a CUDA ggml_cpy fast path for same-type, same-shape strided copies that are just 2D pitched block copies.
When tensors are not fully contiguous but each row is contiguous, it now uses cudaMemcpy2DAsync instead of the slow element-wise scalar copy kernel.

This fixes the GDN recurrent snapshot update with -np 4, where rollback slots are separated by cache stride gaps.

* Add new tests that execute the new optimized strided copy path

* Return unsupported for strided copy in OpenVINO, as new tests are failing
2026-06-27 17:46:21 +05:30
78 changed files with 18276 additions and 11973 deletions
+2 -2
View File
@@ -467,7 +467,7 @@ void common_models_handler_apply(common_models_handler & handler, common_params
// the first part is what gets loaded, so point params.model.path at it
if (!url_tasks.empty()) {
std::string first_path = url_tasks.front().local_path;
url_tasks.front().on_done = [&]() { params.model.path = first_path; };
url_tasks.front().on_done = [&, first_path]() { params.model.path = first_path; };
}
for (auto & task : url_tasks) {
tasks.push_back(std::move(task));
@@ -3471,7 +3471,7 @@ common_params_context common_params_parser_init(common_params & params, llama_ex
[](common_params & params) {
params.offline = true;
}
).set_env("LLAMA_ARG_OFFLINE"));
).set_examples({LLAMA_EXAMPLE_COMMON, LLAMA_EXAMPLE_DOWNLOAD}).set_env("LLAMA_ARG_OFFLINE"));
add_opt(common_arg(
{"-lv", "--verbosity", "--log-verbosity"}, "N",
string_format("Set the verbosity threshold. Messages with a higher verbosity will be ignored. Values:\n"
+151
View File
@@ -2376,6 +2376,149 @@ static void func_args_not_string(json & messages) {
}
// MiniCPM5 format:
// - Reasoning: <think>{reasoning}</think> (optional)
// - Tool calls: <function name="foo"><param name="bar">value</param></function>
//
// Builds the chat params for a MiniCPM5 template:
// - renders the prompt and generation prompt from `tmpl` and `inputs`
// - registers the preserved tokens, thinking tags and per-role message delimiters
// - builds the PEG parser for the model output (reasoning, then either a
//   schema-constrained response, tool calls, or free-form content)
// - builds a grammar from that parser when a response format or tools are in use
static common_chat_params common_chat_params_init_minicpm5(const common_chat_template & tmpl,
const autoparser::generation_params & inputs) {
common_chat_params data;
data.prompt = common_chat_template_direct_apply_impl(tmpl, inputs);
data.generation_prompt = common_chat_template_generation_prompt_impl(tmpl, inputs);
data.format = COMMON_CHAT_FORMAT_PEG_NATIVE;
data.supports_thinking = true;
data.preserved_tokens = {
"<function",
"<param",
"</function>",
"</param>",
"<think>",
"</think>",
};
data.thinking_start_tag = "<think>";
data.thinking_end_tag = "</think>";
// the tool role is mapped to a user turn that starts with <tool_response>
data.message_delimiters = {
{ COMMON_CHAT_ROLE_ASSISTANT, "<|im_start|>assistant" },
{ COMMON_CHAT_ROLE_TOOL, "<|im_start|>user\n<tool_response>" },
{ COMMON_CHAT_ROLE_USER, "<|im_start|>user" },
{ COMMON_CHAT_ROLE_SYSTEM, "<|im_start|>system" },
};
auto has_tools = inputs.tools.is_array() && !inputs.tools.empty();
auto has_response_format = inputs.json_schema.is_object() && !inputs.json_schema.empty();
auto extract_reasoning = inputs.reasoning_format != COMMON_REASONING_FORMAT_NONE;
// a grammar is only needed for a JSON response format or when tools may be called
auto include_grammar = has_response_format || (has_tools && inputs.tool_choice != COMMON_CHAT_TOOL_CHOICE_NONE);
if (inputs.has_continuation()) {
// continuing a partial assistant message: replace the rendered generation prompt with an
// assistant turn that re-opens the think block and replays the reasoning produced so far
const auto & msg = inputs.continue_msg;
data.generation_prompt = "<|im_start|>assistant\n<think>\n" + msg.reasoning_content;
if (inputs.continue_final_message == COMMON_CHAT_CONTINUATION_CONTENT) {
// continuing inside the content: close the think block and replay the content as well
data.generation_prompt += "\n</think>\n\n" + msg.render_content();
}
data.prompt += data.generation_prompt;
}
auto parser = build_chat_peg_parser([&](common_chat_peg_builder & p) {
auto generation_prompt = p.literal("<|im_start|>assistant\n");
// reasoning is only split out when a reasoning format is requested; otherwise it matches nothing
auto reasoning = p.eps();
if (extract_reasoning) {
reasoning = ("<think>" << p.reasoning(p.until("</think>")) << "</think>") + p.space();
}
// Response format parser
if (has_response_format) {
return generation_prompt + reasoning + p.content(p.schema(p.json(), "response-format", inputs.json_schema));
}
if (has_tools && inputs.tool_choice != COMMON_CHAT_TOOL_CHOICE_NONE) {
// CDATA lets a value carry characters that would otherwise close the tag (e.g.
// </param>); capture the inner text only, excluding the CDATA markers.
auto string_value = p.choice({
p.literal("<![CDATA[") + p.ac(p.tool_arg_string_value(p.until("]]>")) + p.literal("]]>"), "]]>") + p.tool_arg_close(p.literal("</param>")),
p.negate(p.literal("<![CDATA[")) + p.ac(p.tool_arg_string_value(p.until("</param>")) + p.tool_arg_close(p.literal("</param>")), "</param>")
});
// one alternative per declared function
auto tool_choice = p.choice();
foreach_function(inputs.tools, [&](const json & tool) {
const auto & function = tool.at("function");
const std::string name = function.at("name");
auto params = function.contains("parameters") ? function.at("parameters") : json::object();
// functions without declared properties take no <param> elements
auto args = p.eps();
if (params.contains("properties") && params.at("properties").is_object() && !params.at("properties").empty()) {
auto schema_info = common_schema_info();
schema_info.resolve_refs(params);
auto arg_choice = p.choice();
for (const auto & [prop_name, prop_schema] : params.at("properties").items()) {
// string-typed properties are captured as raw text; everything else must be
// JSON matching the property's schema
auto value_parser = p.eps();
if (schema_info.resolves_to_string(prop_schema)) {
value_parser = string_value;
} else {
value_parser = p.tool_arg_json_value(
p.schema(p.json(), "tool-" + name + "-arg-" + prop_name + "-schema", prop_schema, false)
) + p.tool_arg_close(p.literal("</param>"));
}
auto arg_rule = p.tool_arg(
p.tool_arg_open(p.literal("<param name=\"") + p.tool_arg_name(p.literal(prop_name)) + p.literal("\">")) +
value_parser
);
arg_choice |= arg_rule;
}
// any declared argument, in any order, any number of times
args = p.zero_or_more(arg_choice + p.space());
}
auto tool_parser = p.tool(
p.tool_open(p.literal("<function name=\"") + p.tool_name(p.literal(name)) + p.literal("\">"))
<< p.tool_args(args)
<< p.tool_close(p.literal("</function>")));
tool_choice |= p.rule("tool-" + name, tool_parser);
});
// at least one call; the upper bound is 1 unless parallel tool calls are enabled
// NOTE(review): -1 presumably means "unbounded" for p.repeat - confirm
auto max_calls = inputs.parallel_tool_calls ? -1 : 1;
auto tool_calls = p.trigger_rule("tool-call", p.repeat(tool_choice + p.space(), 1, max_calls));
// everything before the first "<function" is regular content
auto content = p.content(p.until("<function"));
return generation_prompt + reasoning + content + tool_calls + p.end();
}
// no tools and no response format: the remainder of the output is plain content
return generation_prompt + reasoning + p.content(p.rest()) + p.end();
});
data.parser = parser.save();
if (include_grammar) {
// the grammar is lazy unless a response format is set or a tool call is required
data.grammar_lazy = !(has_response_format || (has_tools && inputs.tool_choice == COMMON_CHAT_TOOL_CHOICE_REQUIRED));
data.grammar = build_grammar([&](const common_grammar_builder & builder) {
// resolve schema references on copies of the tool parameter schemas
foreach_function(inputs.tools, [&](const json & tool) {
const auto & function = tool.at("function");
auto schema = function.contains("parameters") ? function.at("parameters") : json::object();
builder.resolve_refs(schema);
});
if (has_response_format) {
auto schema = inputs.json_schema;
builder.resolve_refs(schema);
}
parser.build_grammar(builder, data.grammar_lazy);
});
// the word "<function" is registered as the grammar trigger
data.grammar_triggers = {
{ COMMON_GRAMMAR_TRIGGER_TYPE_WORD, "<function" },
};
}
return data;
}
static json common_chat_extra_context() {
json ctx = json::object();
std::chrono::system_clock::time_point now = std::chrono::system_clock::now();
@@ -2468,6 +2611,14 @@ std::optional<common_chat_params> common_chat_try_specialized_template(
return common_chat_params_init_gemma4(tmpl, params);
}
// MiniCPM5 - XML tool calls with <function name="..."><param name="...">...</param></function>
if (src.find("Tool usage guidelines:") != std::string::npos &&
src.find("<function name=\"") != std::string::npos &&
src.find("<param name=\"") != std::string::npos) {
LOG_DBG("Using specialized template: MiniCPM5\n");
return common_chat_params_init_minicpm5(tmpl, params);
}
return std::nullopt;
}
+47 -47
View File
@@ -225,7 +225,7 @@ bool set_process_priority(enum ggml_sched_priority prio) {
}
if (!SetPriorityClass(GetCurrentProcess(), p)) {
LOG_WRN("failed to set process priority class %d : (%d)\n", prio, (int) GetLastError());
COM_WRN("failed to set process priority class %d : (%d)\n", prio, (int) GetLastError());
return false;
}
@@ -251,7 +251,7 @@ bool set_process_priority(enum ggml_sched_priority prio) {
}
if (setpriority(PRIO_PROCESS, 0, p) != 0) {
LOG_WRN("failed to set process priority %d : %s (%d)\n", prio, strerror(errno), errno);
COM_WRN("failed to set process priority %d : %s (%d)\n", prio, strerror(errno), errno);
return false;
}
return true;
@@ -284,14 +284,14 @@ void postprocess_cpu_params(common_cpu_params & cpuparams, const common_cpu_para
if (n_set && n_set < cpuparams.n_threads) {
// Not enough set bits, may experience performance issues.
LOG_WRN("Not enough set bits in CPU mask (%d) to satisfy requested thread count: %d\n", n_set, cpuparams.n_threads);
COM_WRN("Not enough set bits in CPU mask (%d) to satisfy requested thread count: %d\n", n_set, cpuparams.n_threads);
}
}
bool parse_cpu_range(const std::string & range, bool (&boolmask)[GGML_MAX_N_THREADS]) {
size_t dash_loc = range.find('-');
if (dash_loc == std::string::npos) {
LOG_ERR("Format of CPU range is invalid! Expected [<start>]-[<end>].\n");
COM_ERR("%s", "Format of CPU range is invalid! Expected [<start>]-[<end>].\n");
return false;
}
@@ -303,7 +303,7 @@ bool parse_cpu_range(const std::string & range, bool (&boolmask)[GGML_MAX_N_THRE
} else {
start_i = std::stoull(range.substr(0, dash_loc));
if (start_i >= GGML_MAX_N_THREADS) {
LOG_ERR("Start index out of bounds!\n");
COM_ERR("%s", "Start index out of bounds!\n");
return false;
}
}
@@ -313,7 +313,7 @@ bool parse_cpu_range(const std::string & range, bool (&boolmask)[GGML_MAX_N_THRE
} else {
end_i = std::stoull(range.substr(dash_loc + 1));
if (end_i >= GGML_MAX_N_THREADS) {
LOG_ERR("End index out of bounds!\n");
COM_ERR("%s", "End index out of bounds!\n");
return false;
}
}
@@ -333,7 +333,7 @@ bool parse_cpu_mask(const std::string & mask, bool (&boolmask)[GGML_MAX_N_THREAD
}
size_t num_digits = mask.length() - start_i;
if (num_digits > 128) num_digits = 128;
num_digits = std::min<size_t>(num_digits, 128);
size_t end_i = num_digits + start_i;
@@ -348,7 +348,7 @@ bool parse_cpu_mask(const std::string & mask, bool (&boolmask)[GGML_MAX_N_THREAD
} else if (c >= 'A' && c <= 'F') {
id -= 'A' - 10;
} else {
LOG_ERR("Invalid hex character '%c' at position %d\n", c, int32_t(i));
COM_ERR("Invalid hex character '%c' at position %d\n", c, int32_t(i));
return false;
}
@@ -379,21 +379,21 @@ void common_params_print_info(const common_params & params, bool print_devices)
#else
const char * build_type = " (debug)";
#endif
LOG_TRC("%s: build %d (%s) with %s for %s%s\n", __func__, llama_build_number(), llama_commit(), llama_compiler(), llama_build_target(), build_type);
COM_TRC("%s: build %d (%s) with %s for %s%s\n", __func__, llama_build_number(), llama_commit(), llama_compiler(), llama_build_target(), build_type);
LOG_INF("log_info: verbosity = %d (adjust with the `-lv N` CLI arg)\n", common_log_get_verbosity_thold());
COM_INF("%s: verbosity = %d (adjust with the `-lv N` CLI arg)\n", __func__, common_log_get_verbosity_thold());
// device enumeration creates a primary context on CUDA backends, skip it when the caller does not own any device
if (print_devices) {
LOG_INF("device_info:\n");
COM_TRC("%s", "device_info:\n");
for (size_t i = 0; i < ggml_backend_dev_count(); ++i) {
auto * dev = ggml_backend_dev_get(i);
size_t free, total;
ggml_backend_dev_memory(dev, &free, &total);
LOG_INF(" - %-8s: %s (%zu MiB, %zu MiB free)\n", ggml_backend_dev_name(dev), ggml_backend_dev_description(dev), total / 1024 / 1024, free / 1024 / 1024);
COM_TRC(" - %-8s: %s (%zu MiB, %zu MiB free)\n", ggml_backend_dev_name(dev), ggml_backend_dev_description(dev), total / 1024 / 1024, free / 1024 / 1024);
}
}
LOG_INF("%s\n", common_params_get_system_info(params).c_str());
COM_TRC("%s\n", common_params_get_system_info(params).c_str());
}
std::string common_params_get_system_info(const common_params & params) {
@@ -660,7 +660,7 @@ void string_process_escapes(std::string & input) {
bool string_parse_kv_override(const char * data, std::vector<llama_model_kv_override> & overrides) {
const char * sep = strchr(data, '=');
if (sep == nullptr || sep - data >= 128) {
LOG_ERR("%s: malformed KV override '%s'\n", __func__, data);
COM_ERR("%s: malformed KV override '%s'\n", __func__, data);
return false;
}
llama_model_kv_override kvo;
@@ -683,20 +683,20 @@ bool string_parse_kv_override(const char * data, std::vector<llama_model_kv_over
} else if (std::strcmp(sep, "false") == 0) {
kvo.val_bool = false;
} else {
LOG_ERR("%s: invalid boolean value for KV override '%s'\n", __func__, data);
COM_ERR("%s: invalid boolean value for KV override '%s'\n", __func__, data);
return false;
}
} else if (strncmp(sep, "str:", 4) == 0) {
sep += 4;
kvo.tag = LLAMA_KV_OVERRIDE_TYPE_STR;
if (strlen(sep) > 127) {
LOG_ERR("%s: malformed KV override '%s', value cannot exceed 127 chars\n", __func__, data);
COM_ERR("%s: malformed KV override '%s', value cannot exceed 127 chars\n", __func__, data);
return false;
}
strncpy(kvo.val_str, sep, 127);
kvo.val_str[127] = '\0';
} else {
LOG_ERR("%s: invalid type for KV override '%s'\n", __func__, data);
COM_ERR("%s: invalid type for KV override '%s'\n", __func__, data);
return false;
}
overrides.emplace_back(std::move(kvo));
@@ -1199,8 +1199,8 @@ common_init_result::common_init_result(common_params & params, bool model_only)
auto cparams = common_context_params_to_llama(params);
if (params.fit_params) {
LOG_INF("%s: fitting params to device memory ...\n", __func__);
LOG_INF("%s: (for bugs during this step try to reproduce them with -fit off, or provide --verbose logs if the bug only occurs with -fit on)\n", __func__);
COM_TRC("%s", "fitting params to device memory ...\n");
COM_TRC("%s", "(for bugs during this step try to reproduce them with -fit off, or provide --verbose logs if the bug only occurs with -fit on)\n");
common_fit_params(params.model.path.c_str(), &mparams, &cparams,
params.tensor_split,
params.tensor_buft_overrides.data(),
@@ -1227,7 +1227,7 @@ common_init_result::common_init_result(common_params & params, bool model_only)
llama_adapter_lora_ptr lora;
lora.reset(llama_adapter_lora_init(model, la.path.c_str()));
if (lora == nullptr) {
LOG_ERR("%s: failed to load lora adapter '%s'\n", __func__, la.path.c_str());
COM_ERR("failed to load lora adapter '%s'\n", la.path.c_str());
pimpl->model.reset(model);
return;
}
@@ -1246,14 +1246,14 @@ common_init_result::common_init_result(common_params & params, bool model_only)
common_init_sampler_from_model(model, params.sampling);
if (params.sampling.ignore_eos && llama_vocab_eos(vocab) == LLAMA_TOKEN_NULL) {
LOG_WRN("%s: warning: vocab does not have an EOS token, ignoring --ignore-eos\n", __func__);
COM_WRN("%s", "vocab does not have an EOS token, ignoring --ignore-eos\n");
params.sampling.ignore_eos = false;
}
// initialize once
for (llama_token i = 0; i < llama_vocab_n_tokens(vocab); i++) {
if (llama_vocab_is_eog(vocab, i)) {
LOG_TRC("%s: added %s logit bias = %f\n", __func__, common_token_to_piece(vocab, i).c_str(), -INFINITY);
COM_TRC("added %s logit bias = %f\n", common_token_to_piece(vocab, i).c_str(), -INFINITY);
params.sampling.logit_bias_eog.push_back({i, -INFINITY});
}
}
@@ -1291,7 +1291,7 @@ common_init_result::common_init_result(common_params & params, bool model_only)
llama_context * lctx = llama_init_from_model(model, cparams);
if (lctx == NULL) {
LOG_ERR("%s: failed to create context with model '%s'\n", __func__, params.model.path.c_str());
COM_ERR("failed to create context with model '%s'\n", params.model.path.c_str());
return;
}
@@ -1328,7 +1328,7 @@ common_init_result_ptr common_init_from_params(common_params & params, bool mode
llama_model * model = res->model();
if (model == NULL) {
LOG_ERR("%s: failed to load model '%s'\n", __func__, params.model.path.c_str());
COM_ERR("failed to load model '%s'\n", params.model.path.c_str());
return res;
}
@@ -1338,14 +1338,14 @@ common_init_result_ptr common_init_from_params(common_params & params, bool mode
llama_context * lctx = res->context();
if (lctx == NULL) {
LOG_ERR("%s: failed to create context with model '%s'\n", __func__, params.model.path.c_str());
COM_ERR("failed to create context with model '%s'\n", params.model.path.c_str());
return res;
}
const llama_vocab * vocab = llama_model_get_vocab(model);
if (params.ctx_shift && !llama_memory_can_shift(llama_get_memory(lctx))) {
LOG_WRN("%s: KV cache shifting is not supported for this context, disabling KV cache shifting\n", __func__);
COM_WRN("%s", "KV cache shifting is not supported for this context, disabling KV cache shifting\n");
params.ctx_shift = false;
}
@@ -1374,7 +1374,7 @@ common_init_result_ptr common_init_from_params(common_params & params, bool mode
bool ok = true;
if (llama_vocab_bos(vocab) == LLAMA_TOKEN_NULL) {
LOG_WRN("%s: warning: vocab does not have a BOS token, reranking will not work\n", __func__);
COM_WRN("%s", "vocab does not have a BOS token, reranking will not work\n");
ok = false;
}
@@ -1383,10 +1383,10 @@ common_init_result_ptr common_init_from_params(common_params & params, bool mode
bool has_rerank_prompt = llama_model_chat_template(model, "rerank") != NULL;
if (!has_eos && !has_sep && !has_rerank_prompt) {
LOG_WRN("%s: warning: vocab does not have an EOS token, SEP token, or rerank prompt. Reranking will not work\n", __func__);
COM_WRN("%s", "vocab does not have an EOS token, SEP token, or rerank prompt. Reranking will not work\n");
ok = false;
} else if (!has_eos) {
LOG_WRN("%s: warning: vocab does not have an EOS token, using SEP token as fallback\n", __func__);
COM_WRN("%s", "vocab does not have an EOS token, using SEP token as fallback\n");
}
if (!ok) {
@@ -1399,7 +1399,7 @@ common_init_result_ptr common_init_from_params(common_params & params, bool mode
}
if (params.warmup) {
LOG_INF("%s: warming up the model with an empty run - please wait ... (--no-warmup to disable)\n", __func__);
COM_TRC("%s", "warming up the model with an empty run - please wait ... (--no-warmup to disable)\n");
std::vector<llama_token> tmp;
llama_token bos = llama_vocab_bos(vocab);
@@ -1473,20 +1473,20 @@ common_context_seq_rm_type common_context_can_seq_rm(llama_context * ctx) {
int ret = llama_decode(ctx, llama_batch_get_one(tmp.data(), tmp.size()));
if (ret != 0) {
LOG_ERR("%s: llama_decode() failed: %d\n", __func__, ret);
COM_ERR("llama_decode() failed: %d\n", ret);
res = COMMON_CONTEXT_SEQ_RM_TYPE_NO;
goto done;
}
if (llama_n_rs_seq(ctx) > 0) {
LOG_INF("%s: the context supports bounded partial sequence removal\n", __func__);
COM_TRC("%s", "the context supports bounded partial sequence removal\n");
res = COMMON_CONTEXT_SEQ_RM_TYPE_RS;
goto done;
}
// try to remove the last tokens
if (!llama_memory_seq_rm(mem, 0, 1, -1)) {
LOG_TRC("%s: the context does not support partial sequence removal\n", __func__);
COM_TRC("%s", "the context does not support partial sequence removal\n");
res = COMMON_CONTEXT_SEQ_RM_TYPE_FULL;
goto done;
}
@@ -1803,13 +1803,13 @@ static common_control_vector_data common_control_vector_load_one(const common_co
};
struct gguf_context * ctx_gguf = gguf_init_from_file(load_info.fname.c_str(), meta_gguf_params);
if (!ctx_gguf) {
LOG_ERR("%s: failed to load control vector file from %s\n", __func__, load_info.fname.c_str());
COM_ERR("failed to load control vector file from %s\n", load_info.fname.c_str());
return result;
}
int32_t n_tensors = gguf_get_n_tensors(ctx_gguf);
if (n_tensors == 0) {
LOG_WRN("%s: no direction tensors found in %s\n", __func__, load_info.fname.c_str());
COM_WRN("no direction tensors found in %s\n", load_info.fname.c_str());
}
for (int i = 0; i < n_tensors; i++) {
@@ -1827,23 +1827,23 @@ static common_control_vector_data common_control_vector_load_one(const common_co
}
}
if (layer_idx < 0) {
LOG_ERR("%s: invalid/unparsable direction tensor layer index in %s\n", __func__, load_info.fname.c_str());
COM_ERR("invalid/unparsable direction tensor layer index in %s\n", load_info.fname.c_str());
result.n_embd = -1;
break;
} else if (layer_idx == 0) {
LOG_ERR("%s: invalid (zero) direction tensor layer index in %s\n", __func__, load_info.fname.c_str());
COM_ERR("invalid (zero) direction tensor layer index in %s\n", load_info.fname.c_str());
result.n_embd = -1;
break;
}
struct ggml_tensor * tensor = ggml_get_tensor(ctx, name.c_str());
if (tensor->type != GGML_TYPE_F32) {
LOG_ERR("%s: invalid (non-F32) direction tensor type in %s\n", __func__, load_info.fname.c_str());
COM_ERR("invalid (non-F32) direction tensor type in %s\n", load_info.fname.c_str());
result.n_embd = -1;
break;
}
if (ggml_n_dims(tensor) != 1) {
LOG_ERR("%s: invalid (non-1D) direction tensor shape in %s\n", __func__, load_info.fname.c_str());
COM_ERR("invalid (non-1D) direction tensor shape in %s\n", load_info.fname.c_str());
result.n_embd = -1;
break;
}
@@ -1851,7 +1851,7 @@ static common_control_vector_data common_control_vector_load_one(const common_co
if (result.n_embd == -1) {
result.n_embd = ggml_nelements(tensor);
} else if (ggml_nelements(tensor) != result.n_embd) {
LOG_ERR("%s: direction tensor in %s does not match previous dimensions\n", __func__, load_info.fname.c_str());
COM_ERR("direction tensor in %s does not match previous dimensions\n", load_info.fname.c_str());
result.n_embd = -1;
break;
}
@@ -1868,7 +1868,7 @@ static common_control_vector_data common_control_vector_load_one(const common_co
}
if (result.n_embd == -1) {
LOG_WRN("%s: skipping %s due to invalid direction tensors\n", __func__, load_info.fname.c_str());
COM_WRN("skipping %s due to invalid direction tensors\n", load_info.fname.c_str());
result.data.clear();
}
@@ -1889,7 +1889,7 @@ common_control_vector_data common_control_vector_load(const std::vector<common_c
break;
}
if (result.n_embd != -1 && result.n_embd != cur.n_embd) {
LOG_ERR("%s: control vectors in %s does not match previous dimensions\n", __func__, info.fname.c_str());
COM_ERR("control vectors in %s does not match previous dimensions\n", info.fname.c_str());
result.n_embd = -1;
break;
}
@@ -1905,7 +1905,7 @@ common_control_vector_data common_control_vector_load(const std::vector<common_c
}
if (result.n_embd == -1) {
LOG_ERR("%s: no valid control vector files passed\n", __func__);
COM_ERR("%s", "no valid control vector files passed\n");
result.data.clear();
}
@@ -2016,13 +2016,13 @@ bool common_prompt_batch_decode(
// memory, so we can't just remove the last token from the memory and replay the last token which
// is the reason for this logic.
if (llama_decode(ctx, llama_batch_get_one(const_cast<llama_token*>(all_tokens.data() + offset), n_tokens_before_last))) {
LOG_ERR("%s : failed to eval\n", __func__);
COM_ERR("%s", "failed to eval\n");
return false;
}
n_past += n_tokens_before_last;
llama_state_save_file(ctx, state_path.data(), all_tokens.data(), all_tokens.size());
LOG_INF("saved session before last token to %s, n_new = %zu\n", state_path.data(), all_tokens.size());
COM_INF("saved session before last token to %s, n_new = %zu\n", state_path.data(), all_tokens.size());
llama_token last_token = all_tokens.back();
llama_batch batch = llama_batch_get_one(&last_token, 1);
@@ -2030,13 +2030,13 @@ bool common_prompt_batch_decode(
batch.pos = &pos;
if (llama_decode(ctx, batch)) {
LOG_ERR("%s : failed to eval last token\n", __func__);
COM_ERR("%s", "failed to eval last token\n");
return false;
}
n_past++;
} else {
if (llama_decode(ctx, llama_batch_get_one(const_cast<llama_token*>(all_tokens.data() + offset), n_new))) {
LOG_ERR("%s : failed to eval\n", __func__);
COM_ERR("%s", "failed to eval\n");
return false;
}
n_past += n_new;
+9 -1
View File
@@ -25,6 +25,13 @@
#define DIRECTORY_SEPARATOR '/'
#endif // _WIN32
// Logging helpers for the common library.
// Each COM_<level> macro forwards to the matching LOG_<level> macro and prefixes the
// message with "cmn " followed by the calling function's name; "%12.*s" with a
// precision argument of 12 pads the name to 12 columns and truncates longer names.
// __VA_ARGS__ follows a comma, so at least one variadic argument has to be passed;
// constant messages are written as COM_XXX("%s", "message\n").
#define COM_DBG(fmt, ...) LOG_DBG("cmn %12.*s: " fmt, 12, __func__, __VA_ARGS__)
#define COM_TRC(fmt, ...) LOG_TRC("cmn %12.*s: " fmt, 12, __func__, __VA_ARGS__)
#define COM_INF(fmt, ...) LOG_INF("cmn %12.*s: " fmt, 12, __func__, __VA_ARGS__)
#define COM_WRN(fmt, ...) LOG_WRN("cmn %12.*s: " fmt, 12, __func__, __VA_ARGS__)
#define COM_ERR(fmt, ...) LOG_ERR("cmn %12.*s: " fmt, 12, __func__, __VA_ARGS__)
// COM_CNT forwards to LOG_CNT without adding the prefix
#define COM_CNT(fmt, ...) LOG_CNT("" fmt, __VA_ARGS__)
#define die(msg) do { fputs("error: " msg "\n", stderr); exit(1); } while (0)
#define die_fmt(fmt, ...) do { fprintf(stderr, "error: " fmt "\n", __VA_ARGS__); exit(1); } while (0)
@@ -162,6 +169,7 @@ enum common_speculative_type {
COMMON_SPECULATIVE_TYPE_DRAFT_SIMPLE, // standalone draft model speculative decoding
COMMON_SPECULATIVE_TYPE_DRAFT_EAGLE3, // Eagle3 speculative decoding
COMMON_SPECULATIVE_TYPE_DRAFT_MTP, // Multi-token prediction
COMMON_SPECULATIVE_TYPE_DRAFT_DFLASH, // DFlash speculative decoding
COMMON_SPECULATIVE_TYPE_NGRAM_SIMPLE, // simple self-speculative decoding based on n-grams
COMMON_SPECULATIVE_TYPE_NGRAM_MAP_K, // self-speculative decoding with n-gram keys only
COMMON_SPECULATIVE_TYPE_NGRAM_MAP_K4V, // self-speculative decoding with n-gram keys and 4 m-gram values
@@ -377,7 +385,7 @@ struct common_params_speculative {
uint32_t need_n_rs_seq() const {
bool needs_rs_seq = std::any_of(types.begin(), types.end(), [&](auto t) {
return t == COMMON_SPECULATIVE_TYPE_DRAFT_MTP || t == COMMON_SPECULATIVE_TYPE_DRAFT_EAGLE3;
return t == COMMON_SPECULATIVE_TYPE_DRAFT_MTP || t == COMMON_SPECULATIVE_TYPE_DRAFT_EAGLE3 || t == COMMON_SPECULATIVE_TYPE_DRAFT_DFLASH;
});
return needs_rs_seq ? draft.n_max : 0u;
+1 -1
View File
@@ -233,7 +233,7 @@ static void common_params_fit_impl(
sum_projected_used = dmds_full.back().mb.total();
sum_free = dmds_full.back().total;
sum_projected_free = sum_free - sum_projected_used;
LOG_INF("%s: projected to use %" PRId64 " MiB of host memory vs. %" PRId64 " MiB of total host memory\n",
LOG_TRC("%s: projected to use %" PRId64 " MiB of host memory vs. %" PRId64 " MiB of total host memory\n",
__func__, sum_projected_used/MiB, sum_free/MiB);
if (sum_projected_free >= margins[0]) {
LOG_TRC("%s: will leave %" PRId64 " >= %" PRId64 " MiB of system memory, no changes needed\n",
+46
View File
@@ -954,4 +954,50 @@ value keyword_argument_expression::execute_impl(context & ctx) {
return mk_val<value_kwarg>(k, v);
}
// Renders the AST of `prog` as an indented text tree, for debugging.
// Every node prints its type; leaf nodes additionally print their position and a
// snippet of `src`, non-leaf nodes print each labelled group of children
// (or "<empty>" when a group has no entries). Two spaces are used per nesting level.
std::string runtime::debug_dump_program(const program & prog, const std::string & src) {
    std::ostringstream out;
    size_t depth = 0; // current nesting level, shared by all visitor invocations

    const auto pad = [](size_t n) -> std::string {
        return std::string(n * 2, ' ');
    };

    context ctx;
    ctx.src.reset(new std::string(src));

    ctx.visitor = [&](bool is_leaf, statement * node, std::vector<visitor_pair> children) {
        out << pad(depth) << node->type() << ":\n";
        ++depth;

        if (is_leaf) {
            const auto & pos = node->pos;
            out << pad(depth) << "(leaf) at " << get_line_col(src, pos) << " in source:\n";

            // keep multi-line snippets aligned with the current nesting level
            std::string excerpt = peak_source(src, pos);
            string_replace_all(excerpt, "\n", "\n" + pad(depth));
            out << pad(depth) << excerpt << "\n";

            --depth;
            return;
        }

        for (auto & group : children) {
            auto & kids = group.second;

            out << pad(depth) << group.first << ":\n";
            ++depth;

            if (kids.empty()) {
                out << pad(depth) << "<empty>\n\n";
            }
            for (auto * kid : kids) {
                // null entries are skipped
                if (kid) {
                    kid->visit(ctx);
                }
            }

            --depth;
        }

        --depth;
    };

    for (const auto & stmt : prog.body) {
        stmt->visit(ctx);
    }

    return out.str();
}
} // namespace jinja
+127
View File
@@ -47,12 +47,19 @@ const T * cast_stmt(const statement_ptr & ptr) {
// not thread-safe
void enable_debug(bool enable);
// for visiting AST nodes
// visitor_pair: a label together with the child nodes grouped under that label
// visitor_fn signature: void(bool is_leaf, statement * node, list of <label, children> pairs)
using visitor_pair = std::pair<std::string, std::vector<statement *>>;
using visitor_fn = std::function<void(bool, statement *, std::vector<visitor_pair>)>;
struct context {
std::shared_ptr<std::string> src; // for debugging; use shared_ptr to avoid copying on scope creation
std::time_t current_time; // for functions that need current time
bool is_get_stats = false; // whether to collect stats
visitor_fn visitor;
// src is optional, used for error reporting
context(std::string src = "") : src(std::make_shared<std::string>(std::move(src))) {
env = mk_val<value_object>();
@@ -99,6 +106,15 @@ private:
value_object env;
};
// utils for visiting AST nodes
static std::vector<statement *> stmts_to_ptr(const statements & stmts) {
std::vector<statement *> children;
for (const auto & stmt : stmts) {
children.push_back(stmt.get());
}
return children;
}
/**
* Base class for all nodes in the AST.
*/
@@ -106,6 +122,7 @@ struct statement {
size_t pos; // position in source, for debugging
virtual ~statement() = default;
virtual std::string type() const { return "Statement"; }
// Default visit: reports this node to the context's visitor as a leaf with no children.
// NOTE(review): ctx.visitor is invoked unconditionally - presumably callers assign it before visiting; confirm.
virtual void visit(context & ctx) { ctx.visitor(true, this, {}); }
// execute_impl must be overridden by derived classes
virtual value execute_impl(context &) { throw_exec_error(); }
@@ -166,6 +183,13 @@ struct if_statement : public statement {
std::string type() const override { return "If"; }
value execute_impl(context & ctx) override;
// Reports this node to the visitor as a non-leaf, with its children grouped
// under the labels "test", "body" and "alternate" (in that order).
void visit(context & ctx) override {
    std::vector<visitor_pair> groups;
    groups.push_back({"test", {test.get()}});
    groups.push_back({"body", stmts_to_ptr(body)});
    groups.push_back({"alternate", stmts_to_ptr(alternate)});
    ctx.visitor(false, this, std::move(groups));
}
};
struct identifier;
@@ -190,6 +214,14 @@ struct for_statement : public statement {
std::string type() const override { return "For"; }
value execute_impl(context & ctx) override;
// Reports this node to the visitor as a non-leaf, with its children grouped
// under "loopvar", "iterable", "body" and "default_block" (in that order).
void visit(context & ctx) override {
    std::vector<visitor_pair> groups;
    groups.push_back({"loopvar", {loopvar.get()}});
    groups.push_back({"iterable", {iterable.get()}});
    groups.push_back({"body", stmts_to_ptr(body)});
    groups.push_back({"default_block", stmts_to_ptr(default_block)});
    ctx.visitor(false, this, std::move(groups));
}
};
struct break_statement : public statement {
@@ -241,6 +273,13 @@ struct set_statement : public statement {
std::string type() const override { return "Set"; }
value execute_impl(context & ctx) override;
// Reports this node to the visitor as a non-leaf, with its children grouped
// under "assignee", "value" and "body" (in that order).
void visit(context & ctx) override {
    std::vector<visitor_pair> groups;
    groups.push_back({"assignee", {assignee.get()}});
    groups.push_back({"value", {val.get()}});
    groups.push_back({"body", stmts_to_ptr(body)});
    ctx.visitor(false, this, std::move(groups));
}
};
struct macro_statement : public statement {
@@ -256,6 +295,13 @@ struct macro_statement : public statement {
std::string type() const override { return "Macro"; }
value execute_impl(context & ctx) override;
// Reports this node to the visitor as a non-leaf, with its children grouped
// under "name", "args" and "body" (in that order).
void visit(context & ctx) override {
    std::vector<visitor_pair> groups;
    groups.push_back({"name", {name.get()}});
    groups.push_back({"args", stmts_to_ptr(args)});
    groups.push_back({"body", stmts_to_ptr(body)});
    ctx.visitor(false, this, std::move(groups));
}
};
struct comment_statement : public statement {
@@ -289,6 +335,12 @@ struct member_expression : public expression {
}
std::string type() const override { return "MemberExpression"; }
value execute_impl(context & ctx) override;
// Reports this node to the visitor as a non-leaf, with its children grouped
// under "object" and "property" (in that order).
void visit(context & ctx) override {
    std::vector<visitor_pair> groups;
    groups.push_back({"object", {object.get()}});
    groups.push_back({"property", {property.get()}});
    ctx.visitor(false, this, std::move(groups));
}
};
struct call_expression : public expression {
@@ -302,6 +354,12 @@ struct call_expression : public expression {
}
std::string type() const override { return "CallExpression"; }
value execute_impl(context & ctx) override;
// Reports this node to the visitor as a non-leaf, with its children grouped
// under "callee" and "args" (in that order).
void visit(context & ctx) override {
    std::vector<visitor_pair> groups;
    groups.push_back({"callee", {callee.get()}});
    groups.push_back({"args", stmts_to_ptr(args)});
    ctx.visitor(false, this, std::move(groups));
}
};
/**
@@ -405,6 +463,12 @@ struct binary_expression : public expression {
}
std::string type() const override { return "BinaryExpression"; }
value execute_impl(context & ctx) override;
// Reports this node to the visitor as a non-leaf, with its operands grouped
// under "left" and "right" (in that order).
void visit(context & ctx) override {
    std::vector<visitor_pair> groups;
    groups.push_back({"left", {left.get()}});
    groups.push_back({"right", {right.get()}});
    ctx.visitor(false, this, std::move(groups));
}
};
/**
@@ -431,6 +495,12 @@ struct filter_expression : public expression {
std::string type() const override { return "FilterExpression"; }
value execute_impl(context & ctx) override;
// Reports this node to the visitor as a non-leaf, with its children grouped
// under "operand" and "filter" (in that order).
void visit(context & ctx) override {
    std::vector<visitor_pair> groups;
    groups.push_back({"operand", {operand.get()}});
    groups.push_back({"filter", {filter.get()}});
    ctx.visitor(false, this, std::move(groups));
}
};
struct filter_statement : public statement {
@@ -443,6 +513,12 @@ struct filter_statement : public statement {
}
std::string type() const override { return "FilterStatement"; }
value execute_impl(context & ctx) override;
// Reports this node to the visitor as a non-leaf, with its children grouped
// under "filter" and "body" (in that order).
void visit(context & ctx) override {
    std::vector<visitor_pair> groups;
    groups.push_back({"filter", {filter.get()}});
    groups.push_back({"body", stmts_to_ptr(body)});
    ctx.visitor(false, this, std::move(groups));
}
};
/**
@@ -468,6 +544,12 @@ struct select_expression : public expression {
}
return lhs->execute_impl(ctx);
}
// Reports this node to the visitor as a non-leaf, with its children grouped
// under "lhs" and "test" (in that order).
void visit(context & ctx) override {
    std::vector<visitor_pair> groups;
    groups.push_back({"lhs", {lhs.get()}});
    groups.push_back({"test", {test.get()}});
    ctx.visitor(false, this, std::move(groups));
}
};
/**
@@ -486,6 +568,12 @@ struct test_expression : public expression {
}
std::string type() const override { return "TestExpression"; }
value execute_impl(context & ctx) override;
// Reports this node to the visitor as a non-leaf, with its children grouped
// under "operand" and "test" (in that order).
void visit(context & ctx) override {
    std::vector<visitor_pair> groups;
    groups.push_back({"operand", {operand.get()}});
    groups.push_back({"test", {test.get()}});
    ctx.visitor(false, this, std::move(groups));
}
};
/**
@@ -501,6 +589,11 @@ struct unary_expression : public expression {
}
std::string type() const override { return "UnaryExpression"; }
value execute_impl(context & ctx) override;
// Reports this node to the visitor as a non-leaf whose only child group is "argument".
void visit(context & ctx) override {
    std::vector<visitor_pair> groups;
    groups.push_back({"argument", {argument.get()}});
    ctx.visitor(false, this, std::move(groups));
}
};
struct slice_expression : public expression {
@@ -518,6 +611,13 @@ struct slice_expression : public expression {
[[noreturn]] value execute_impl(context &) override {
throw std::runtime_error("must be handled by MemberExpression");
}
// Reports this node to the visitor as a non-leaf, with its bounds grouped
// under "start_expr", "stop_expr" and "step_expr" (in that order).
void visit(context & ctx) override {
    std::vector<visitor_pair> groups;
    groups.push_back({"start_expr", {start_expr.get()}});
    groups.push_back({"stop_expr", {stop_expr.get()}});
    groups.push_back({"step_expr", {step_expr.get()}});
    ctx.visitor(false, this, std::move(groups));
}
};
struct keyword_argument_expression : public expression {
@@ -531,6 +631,12 @@ struct keyword_argument_expression : public expression {
}
std::string type() const override { return "KeywordArgumentExpression"; }
value execute_impl(context & ctx) override;
// Reports this node to the visitor as a non-leaf, with its children grouped
// under "key" and "val" (in that order).
void visit(context & ctx) override {
    std::vector<visitor_pair> groups;
    groups.push_back({"key", {key.get()}});
    groups.push_back({"val", {val.get()}});
    ctx.visitor(false, this, std::move(groups));
}
};
struct spread_expression : public expression {
@@ -539,6 +645,11 @@ struct spread_expression : public expression {
chk_type<expression>(this->argument);
}
std::string type() const override { return "SpreadExpression"; }
// Reports this node to the visitor as a non-leaf whose only child group is "argument".
void visit(context & ctx) override {
    std::vector<visitor_pair> groups;
    groups.push_back({"argument", {argument.get()}});
    ctx.visitor(false, this, std::move(groups));
}
};
struct call_statement : public statement {
@@ -553,6 +664,13 @@ struct call_statement : public statement {
}
std::string type() const override { return "CallStatement"; }
value execute_impl(context & ctx) override;
// Reports this node to the visitor as a non-leaf, with its children grouped
// under "call", "caller_args" and "body" (in that order).
void visit(context & ctx) override {
    std::vector<visitor_pair> groups;
    groups.push_back({"call", {call.get()}});
    groups.push_back({"caller_args", stmts_to_ptr(caller_args)});
    groups.push_back({"body", stmts_to_ptr(body)});
    ctx.visitor(false, this, std::move(groups));
}
};
struct ternary_expression : public expression {
@@ -575,6 +693,13 @@ struct ternary_expression : public expression {
return false_expr->execute(ctx);
}
}
// Reports this node to the visitor as a non-leaf, with its children grouped
// under "condition", "true_expr" and "false_expr" (in that order).
void visit(context & ctx) override {
    std::vector<visitor_pair> groups;
    groups.push_back({"condition", {condition.get()}});
    groups.push_back({"true_expr", {true_expr.get()}});
    groups.push_back({"false_expr", {false_expr.get()}});
    ctx.visitor(false, this, std::move(groups));
}
};
struct raised_exception : public std::exception {
@@ -648,6 +773,8 @@ struct runtime {
}
return parts;
}
static std::string debug_dump_program(const program & prog, const std::string & src);
};
} // namespace jinja
+44
View File
@@ -1108,6 +1108,50 @@ const func_builtins & value_array_t::get_builtins() const {
std::reverse(arr.begin(), arr.end());
return is_val<value_tuple>(val) ? mk_val<value_tuple>(std::move(arr)) : mk_val<value_array>(std::move(arr));
}},
{"min", [](const func_args & args) -> value {
    // min(array, case_sensitive, attribute): smallest element of the array,
    // or an undefined value when the array is empty
    args.ensure_count(1, 4);
    args.ensure_vals<value_array>();

    value val_case  = args.get_kwarg_or_pos("case_sensitive", 1);
    value attribute = args.get_kwarg_or_pos("attribute", 2);
    if (!attribute->is_undefined()) {
        throw not_implemented_exception("min: attribute not implemented");
    }
    // FIXME: min is currently always case sensitive
    (void) val_case;

    const auto & items = args.get_pos(0)->as_array();
    if (items.empty()) {
        return mk_val<value_undefined>();
    }

    // start from the first element and move on only when a later element compares strictly lower,
    // so the earliest of several equal minima is the one returned
    auto best = items.begin();
    for (auto it = best + 1; it != items.end(); ++it) {
        if (value_compare(*it, *best, value_compare_op::lt)) {
            best = it;
        }
    }
    return *best;
}},
{"max", [](const func_args & args) -> value {
    // max(array, case_sensitive, attribute): largest element of the array,
    // or an undefined value when the array is empty
    args.ensure_count(1, 4);
    args.ensure_vals<value_array>();

    value val_case  = args.get_kwarg_or_pos("case_sensitive", 1);
    value attribute = args.get_kwarg_or_pos("attribute", 2);
    if (!attribute->is_undefined()) {
        throw not_implemented_exception("max: attribute not implemented");
    }
    // FIXME: max is currently always case sensitive
    (void) val_case;

    const auto & items = args.get_pos(0)->as_array();
    if (items.empty()) {
        return mk_val<value_undefined>();
    }

    // start from the first element and move on only when a later element compares strictly greater,
    // so the earliest of several equal maxima is the one returned
    auto best = items.begin();
    for (auto it = best + 1; it != items.end(); ++it) {
        if (value_compare(*it, *best, value_compare_op::gt)) {
            best = it;
        }
    }
    return *best;
}},
{"unique", array_unique_not_implemented},
};
return builtins;
+10 -10
View File
@@ -65,12 +65,12 @@ static void common_reasoning_budget_accept(struct llama_sampler * smpl, llama_to
if (ctx->start_matcher.advance(token)) {
ctx->state = REASONING_BUDGET_COUNTING;
ctx->remaining = ctx->budget;
LOG_INF("reasoning-budget: activated, budget=%d tokens\n", ctx->budget);
COM_TRC("activated, budget=%d tokens\n", ctx->budget);
if (ctx->remaining <= 0) {
ctx->state = REASONING_BUDGET_FORCING;
ctx->force_pos = 0;
LOG_INF("reasoning-budget: budget=0, forcing immediately\n");
COM_TRC("%s", "budget=0, forcing immediately\n");
}
}
break;
@@ -80,7 +80,7 @@ static void common_reasoning_budget_accept(struct llama_sampler * smpl, llama_to
{
if (ctx->end_matcher.advance(token)) {
ctx->state = REASONING_BUDGET_DONE;
LOG_INF("reasoning-budget: deactivated (natural end)\n");
COM_TRC("%s", "deactivated (natural end)\n");
break;
}
@@ -95,7 +95,7 @@ static void common_reasoning_budget_accept(struct llama_sampler * smpl, llama_to
ctx->state = REASONING_BUDGET_FORCING;
ctx->force_pos = 0;
ctx->end_matcher.reset();
LOG_INF("reasoning-budget: UTF-8 complete, now forcing end sequence\n");
COM_TRC("%s", "UTF-8 complete, now forcing end sequence\n");
}
} else if (ctx->state == REASONING_BUDGET_COUNTING) {
ctx->remaining--;
@@ -104,11 +104,11 @@ static void common_reasoning_budget_accept(struct llama_sampler * smpl, llama_to
ctx->state = REASONING_BUDGET_FORCING;
ctx->force_pos = 0;
ctx->end_matcher.reset();
LOG_INF("reasoning-budget: budget exhausted, forcing end sequence\n");
COM_TRC("%s", "budget exhausted, forcing end sequence\n");
} else {
ctx->state = REASONING_BUDGET_WAITING_UTF8;
ctx->end_matcher.reset();
LOG_INF("reasoning-budget: budget exhausted, waiting for UTF-8 completion\n");
COM_TRC("%s", "budget exhausted, waiting for UTF-8 completion\n");
}
}
}
@@ -118,7 +118,7 @@ static void common_reasoning_budget_accept(struct llama_sampler * smpl, llama_to
ctx->force_pos++;
if (ctx->force_pos >= ctx->forced_tokens.size()) {
ctx->state = REASONING_BUDGET_DONE;
LOG_INF("reasoning-budget: forced sequence complete, done\n");
COM_TRC("%s", "forced sequence complete, done\n");
}
break;
case REASONING_BUDGET_DONE:
@@ -128,12 +128,12 @@ static void common_reasoning_budget_accept(struct llama_sampler * smpl, llama_to
ctx->state = REASONING_BUDGET_COUNTING;
ctx->remaining = ctx->budget;
ctx->end_matcher.reset();
LOG_INF("reasoning-budget: re-activated on new start tag, budget=%d tokens\n", ctx->budget);
COM_TRC("re-activated on new start tag, budget=%d tokens\n", ctx->budget);
if (ctx->remaining <= 0) {
ctx->state = REASONING_BUDGET_FORCING;
ctx->force_pos = 0;
LOG_INF("reasoning-budget: budget=0, forcing immediately\n");
COM_TRC("%s", "budget=0, forcing immediately\n");
}
}
break;
@@ -264,7 +264,7 @@ bool common_reasoning_budget_force(struct llama_sampler * smpl) {
ctx->state = REASONING_BUDGET_FORCING;
ctx->force_pos = 0;
ctx->end_matcher.reset();
LOG_INF("reasoning-budget: forced into forcing state (manual transition)\n");
COM_TRC("%s", "forced into forcing state (manual transition)\n");
return true;
}
+371 -65
View File
@@ -18,6 +18,13 @@
#include <map>
#include <cinttypes>
// Logging helpers for the speculative decoding code.
// Each SPC_DBG/TRC/INF/WRN/ERR message is prefixed with "spec " followed by the calling function's
// name (__func__) printed in a fixed 12-character column (field width 12, precision 12).
// NOTE: __VA_ARGS__ is expanded directly after a comma, so at least one argument must follow fmt;
//       a constant message is passed as ("%s", "message\n").
#define SPC_DBG(fmt, ...) LOG_DBG("spec %12.*s: " fmt, 12, __func__, __VA_ARGS__)
#define SPC_TRC(fmt, ...) LOG_TRC("spec %12.*s: " fmt, 12, __func__, __VA_ARGS__)
#define SPC_INF(fmt, ...) LOG_INF("spec %12.*s: " fmt, 12, __func__, __VA_ARGS__)
#define SPC_WRN(fmt, ...) LOG_WRN("spec %12.*s: " fmt, 12, __func__, __VA_ARGS__)
#define SPC_ERR(fmt, ...) LOG_ERR("spec %12.*s: " fmt, 12, __func__, __VA_ARGS__)
// SPC_CNT forwards to LOG_CNT without adding the "spec <func>: " prefix
#define SPC_CNT(fmt, ...) LOG_CNT("" fmt, __VA_ARGS__)
#define SPEC_VOCAB_MAX_SIZE_DIFFERENCE 128
#define SPEC_VOCAB_CHECK_START_TOKEN_ID 5
@@ -26,6 +33,7 @@ const std::map<std::string, common_speculative_type> common_speculative_type_fro
{"draft-simple", COMMON_SPECULATIVE_TYPE_DRAFT_SIMPLE},
{"draft-eagle3", COMMON_SPECULATIVE_TYPE_DRAFT_EAGLE3},
{"draft-mtp", COMMON_SPECULATIVE_TYPE_DRAFT_MTP},
{"draft-dflash", COMMON_SPECULATIVE_TYPE_DRAFT_DFLASH},
{"ngram-simple", COMMON_SPECULATIVE_TYPE_NGRAM_SIMPLE},
{"ngram-map-k", COMMON_SPECULATIVE_TYPE_NGRAM_MAP_K},
{"ngram-map-k4v", COMMON_SPECULATIVE_TYPE_NGRAM_MAP_K4V},
@@ -60,21 +68,20 @@ static bool common_speculative_are_compatible(
const llama_vocab * vocab_dft = llama_model_get_vocab(model_dft);
const auto vocab_type_tgt = llama_vocab_type(vocab_tgt);
LOG_DBG("%s: vocab_type tgt: %d\n", __func__, vocab_type_tgt);
SPC_DBG("vocab_type tgt: %d\n", vocab_type_tgt);
const auto vocab_type_dft = llama_vocab_type(vocab_dft);
LOG_DBG("%s: vocab_type dft: %d\n", __func__, vocab_type_dft);
SPC_DBG("vocab_type dft: %d\n", vocab_type_dft);
if (vocab_type_tgt != vocab_type_dft) {
LOG_WRN("%s: draft model vocab type must match target model to use speculation but "
"vocab_type_dft = %d while vocab_type_tgt = %d\n", __func__, vocab_type_dft, vocab_type_tgt);
SPC_WRN("draft model vocab type must match target model to use speculation but "
"vocab_type_dft = %d while vocab_type_tgt = %d\n", vocab_type_dft, vocab_type_tgt);
return false;
}
if (llama_vocab_get_add_bos(vocab_tgt) != llama_vocab_get_add_bos(vocab_dft) ||
(llama_vocab_get_add_bos(vocab_tgt) && llama_vocab_bos(vocab_tgt) != llama_vocab_bos(vocab_dft))) {
LOG_WRN("%s: draft model bos tokens must match target model to use speculation. add: %d - %d, id: %d - %d)\n",
__func__,
SPC_WRN("draft model bos tokens must match target model to use speculation. add: %d - %d, id: %d - %d)\n",
llama_vocab_get_add_bos(vocab_tgt), llama_vocab_get_add_bos(vocab_dft),
llama_vocab_bos(vocab_tgt), llama_vocab_bos(vocab_dft));
return false;
@@ -82,8 +89,7 @@ static bool common_speculative_are_compatible(
if (llama_vocab_get_add_eos(vocab_tgt) != llama_vocab_get_add_eos(vocab_dft) ||
(llama_vocab_get_add_eos(vocab_tgt) && llama_vocab_eos(vocab_tgt) != llama_vocab_eos(vocab_dft))) {
LOG_WRN("%s: draft model eos tokens must match target model to use speculation. add: %d - %d, id: %d - %d)\n",
__func__,
SPC_WRN("draft model eos tokens must match target model to use speculation. add: %d - %d, id: %d - %d)\n",
llama_vocab_get_add_eos(vocab_tgt), llama_vocab_get_add_eos(vocab_dft),
llama_vocab_eos(vocab_tgt), llama_vocab_eos(vocab_dft));
return false;
@@ -97,8 +103,8 @@ static bool common_speculative_are_compatible(
: n_vocab_dft - n_vocab_tgt;
if (vocab_diff > SPEC_VOCAB_MAX_SIZE_DIFFERENCE) {
LOG_DBG("%s: draft model vocab must closely match target model to use speculation but ", __func__);
LOG_DBG("target vocab size %d does not match draft vocab size %d - difference %d, max allowed %d\n",
SPC_DBG("draft model vocab must closely match target model to use speculation but "
"target vocab size %d does not match draft vocab size %d - difference %d, max allowed %d\n",
n_vocab_tgt, llama_vocab_n_tokens(vocab_dft), vocab_diff, SPEC_VOCAB_MAX_SIZE_DIFFERENCE);
return false;
}
@@ -108,8 +114,8 @@ static bool common_speculative_are_compatible(
const char * token_text_dft = llama_vocab_get_text(vocab_dft, i);
if (std::strcmp(token_text_tgt, token_text_dft) != 0) {
LOG_DBG("%s: draft model vocab must match target model to use speculation but ", __func__);
LOG_DBG("token %d content differs - target '%s', draft '%s'\n", i,
SPC_DBG("draft model vocab must match target model to use speculation but "
"token %d content differs - target '%s', draft '%s'\n", i,
common_token_to_piece(vocab_tgt, i).c_str(),
common_token_to_piece(vocab_dft, i).c_str());
return false;
@@ -186,9 +192,9 @@ struct common_speculative_impl_draft_simple : public common_speculative_impl {
auto * ctx_dft = this->params.ctx_dft;
auto * ctx_tgt = this->params.ctx_tgt;
LOG_INF("%s: adding speculative implementation 'draft-simple'\n", __func__);
LOG_INF("%s: - n_max=%d, n_min=%d, p_min=%f\n", __func__, this->params.n_max, this->params.n_min, this->params.p_min);
LOG_INF("%s: - gpu_layers=%d, cache_k=%s, cache_v=%s, ctx_tgt=%s, ctx_dft=%s, devices=[%s]\n", __func__,
SPC_TRC("%s", "adding speculative implementation 'draft-simple'\n");
SPC_TRC("- n_max=%d, n_min=%d, p_min=%f\n", this->params.n_max, this->params.n_min, this->params.p_min);
SPC_TRC("- gpu_layers=%d, cache_k=%s, cache_v=%s, ctx_tgt=%s, ctx_dft=%s, devices=[%s]\n",
this->params.n_gpu_layers,
ggml_type_name(this->params.cache_type_k),
ggml_type_name(this->params.cache_type_v),
@@ -228,16 +234,16 @@ struct common_speculative_impl_draft_simple : public common_speculative_impl {
}
const bool vocab_cmpt = common_speculative_are_compatible(llama_get_model(ctx_tgt), llama_get_model(ctx_dft));
LOG_DBG("%s: vocab_cmpt = %d\n", __func__, vocab_cmpt);
SPC_DBG("vocab_cmpt = %d\n", vocab_cmpt);
if (!vocab_cmpt) {
LOG_ERR("%s: the target and draft vocabs are not compatible\n", __func__);
SPC_ERR("%s", "the target and draft vocabs are not compatible\n");
throw std::runtime_error("draft model vocab type must match target model to use speculation");
}
if (n_seq != llama_n_seq_max(ctx_dft)) {
LOG_ERR("%s: n_seq mismatch: %d != %d\n", __func__, n_seq, llama_n_seq_max(ctx_dft));
SPC_ERR("n_seq mismatch: %d != %d\n", n_seq, llama_n_seq_max(ctx_dft));
throw std::runtime_error("the draft model number of sequences is incompatible with the speculative n_seq");
}
@@ -257,7 +263,7 @@ struct common_speculative_impl_draft_simple : public common_speculative_impl {
const int ret = llama_decode(ctx_dft, batch);
if (ret != 0) {
LOG_ERR("%s: failed to decode draft batch, ret = %d\n", __func__, ret);
SPC_ERR("failed to decode draft batch, ret = %d\n", ret);
return false;
}
@@ -290,7 +296,7 @@ struct common_speculative_impl_draft_simple : public common_speculative_impl {
int ret = llama_decode(ctx_dft, batch);
if (ret != 0) {
LOG_WRN("%s: llama_decode returned %d\n", __func__, ret);
SPC_ERR("llama_decode returned %d\n", ret);
return;
}
@@ -314,7 +320,7 @@ struct common_speculative_impl_draft_simple : public common_speculative_impl {
const auto * cur_p = common_sampler_get_candidates(smpl, true);
for (int k = 0; k < std::min(3, (int) cur_p->size); ++k) {
LOG_DBG(" - seq_id %d, draft candidate %3d, pos %3d: %6d (%8.3f) '%s'\n",
SPC_DBG(" - seq_id %d, draft candidate %3d, pos %3d: %6d (%8.3f) '%s'\n",
seq_id, k, i, cur_p->data[k].id, cur_p->data[k].p,
common_token_to_piece(ctx_dft, cur_p->data[k].id).c_str());
}
@@ -354,7 +360,7 @@ struct common_speculative_impl_draft_simple : public common_speculative_impl {
// evaluate the drafted tokens on the draft model
ret = llama_decode(ctx_dft, batch);
if (ret != 0) {
LOG_WRN("%s: llama_decode[%d] returned %d\n", __func__, i, ret);
SPC_ERR("llama_decode[%d] returned %d\n", i, ret);
break;
}
@@ -449,8 +455,8 @@ struct common_speculative_impl_draft_eagle3 : public common_speculative_impl {
: common_speculative_impl(COMMON_SPECULATIVE_TYPE_DRAFT_EAGLE3, n_seq)
, params(params.draft)
{
LOG_INF("%s: adding speculative implementation 'draft-eagle3'\n", __func__);
LOG_INF("%s: - n_max=%d, n_min=%d, p_min=%f, backend_sampling=%d\n", __func__, params.draft.n_max, params.draft.n_min, params.draft.p_min, (int) params.draft.backend_sampling);
SPC_TRC("%s", "adding speculative implementation 'draft-eagle3'\n");
SPC_TRC("- n_max=%d, n_min=%d, p_min=%f, backend_sampling=%d\n", params.draft.n_max, params.draft.n_min, params.draft.p_min, (int) params.draft.backend_sampling);
auto * ctx_tgt = this->params.ctx_tgt;
auto * ctx_dft = this->params.ctx_dft;
@@ -493,7 +499,7 @@ struct common_speculative_impl_draft_eagle3 : public common_speculative_impl {
llama_sampler_chain_add(chain, llama_sampler_init_top_k(10));
if (!llama_set_sampler(ctx_dft, seq_id, chain)) {
LOG_WRN("%s: backend offload failed for seq_id=%d; using CPU sampler\n", __func__, (int) seq_id);
SPC_WRN("backend offload failed for seq_id=%d; using CPU sampler\n", (int) seq_id);
llama_sampler_free(chain);
chain = nullptr;
}
@@ -548,9 +554,9 @@ struct common_speculative_impl_draft_eagle3 : public common_speculative_impl {
auto * ctx_dft = this->params.ctx_dft;
const llama_pos pos_max = llama_memory_seq_pos_max(llama_get_memory(ctx_dft), seq_id);
if (pos_max < N - 2) {
LOG_WRN("%s: ctx_dft pos_max=%d < N-2=%d — process() did not run on every prefill ubatch. "
SPC_WRN("ctx_dft pos_max=%d < N-2=%d — process() did not run on every prefill ubatch. "
"Drafts may degrade.\n",
__func__, (int) pos_max, N - 2);
(int) pos_max, N - 2);
}
}
@@ -621,8 +627,8 @@ struct common_speculative_impl_draft_eagle3 : public common_speculative_impl {
};
const int32_t rc = llama_encode(ctx_dft, enc_batch);
if (rc != 0) {
LOG_ERR("%s: llama_encode(ctx_dft) failed rc=%d (n_tokens=%d, offset=%d)\n",
__func__, rc, (int) n_chunk, (int) i);
SPC_ERR("llama_encode(ctx_dft) failed rc=%d (n_tokens=%d, offset=%d)\n",
rc, (int) n_chunk, (int) i);
return false;
}
@@ -692,8 +698,8 @@ struct common_speculative_impl_draft_eagle3 : public common_speculative_impl {
if (batch.n_tokens > 0) {
const int32_t rc = llama_decode(ctx_dft, batch);
if (rc != 0) {
LOG_ERR("%s: llama_decode(ctx_dft) failed rc=%d (n_tokens=%d, ubatch_pos[0]=%d)\n",
__func__, rc, (int) batch.n_tokens, (int) batch_in.pos[0]);
SPC_ERR("llama_decode(ctx_dft) failed rc=%d (n_tokens=%d, ubatch_pos[0]=%d)\n",
rc, (int) batch.n_tokens, (int) batch_in.pos[0]);
return false;
}
}
@@ -744,7 +750,7 @@ struct common_speculative_impl_draft_eagle3 : public common_speculative_impl {
int ret = llama_decode(ctx_dft, batch);
if (ret != 0) {
LOG_WRN("%s: llama_decode returned %d\n", __func__, ret);
SPC_ERR("llama_decode returned %d\n", ret);
return;
}
@@ -770,7 +776,7 @@ struct common_speculative_impl_draft_eagle3 : public common_speculative_impl {
const auto * cur_p = common_sampler_get_candidates(smpl, true);
for (int k = 0; k < std::min(3, (int) cur_p->size); ++k) {
LOG_DBG(" - seq_id %d, draft candidate %3d, pos %3d: %6d (%8.3f) '%s'\n",
SPC_DBG(" - seq_id %d, draft candidate %3d, pos %3d: %6d (%8.3f) '%s'\n",
seq_id, k, i, cur_p->data[k].id, cur_p->data[k].p,
common_token_to_piece(ctx_dft, cur_p->data[k].id).c_str());
}
@@ -809,7 +815,7 @@ struct common_speculative_impl_draft_eagle3 : public common_speculative_impl {
ret = llama_decode(ctx_dft, batch);
if (ret != 0) {
LOG_WRN("%s: llama_decode[%d] returned %d\n", __func__, i, ret);
SPC_ERR("llama_decode[%d] returned %d\n", i, ret);
break;
}
@@ -893,6 +899,296 @@ struct common_speculative_impl_draft_eagle3 : public common_speculative_impl {
}
};
// DFlash: block-diffusion drafting with a draft-side KV cache injection
struct common_speculative_impl_draft_dflash : public common_speculative_impl {
common_params_speculative_draft params;
llama_batch batch; // noise tokens
llama_batch batch_inject; // target features for KV cache injection
std::vector<common_sampler_ptr> smpls;
int32_t n_embd_dec = 0; // draft hidden size
int32_t n_embd_enc = 0; // target_layer_ids_n * target_hidden_size
int32_t n_embd_tgt = 0; // target model hidden size
int32_t block_size = 0;
llama_token mask_token_id = 0;
const int32_t * target_layer_ids = nullptr; // model_dft's extract layer indices
uint32_t target_layer_ids_n = 0;
// scratch buffer for concatenated target features [n_tokens, n_embd_enc]
std::vector<float> features_buf;
common_speculative_impl_draft_dflash(const common_params_speculative & params, uint32_t n_seq)
: common_speculative_impl(COMMON_SPECULATIVE_TYPE_DRAFT_DFLASH, n_seq)
, params(params.draft)
{
auto * ctx_tgt = this->params.ctx_tgt;
auto * ctx_dft = this->params.ctx_dft;
GGML_ASSERT(ctx_tgt && ctx_dft && "DFlash requires ctx_tgt and ctx_dft to be set");
const llama_model * model_dft = llama_get_model(ctx_dft);
const llama_model * model_tgt = llama_get_model(ctx_tgt);
target_layer_ids = llama_model_target_layer_ids (model_dft);
target_layer_ids_n = llama_model_target_layer_ids_n(model_dft);
GGML_ASSERT(target_layer_ids_n > 0 && "DFlash model has no target_layer_ids");
n_embd_tgt = llama_model_n_embd(model_tgt);
n_embd_dec = llama_model_n_embd(model_dft);
n_embd_enc = (int32_t) target_layer_ids_n * n_embd_tgt;
// read the trained block size from the dflash.block_size metadata key
block_size = 16;
{
char buf[32] = {};
if (llama_model_meta_val_str(model_dft, "dflash.block_size", buf, sizeof(buf)) >= 0) {
block_size = std::atoi(buf);
}
}
mask_token_id = llama_vocab_mask(llama_model_get_vocab(model_dft));
LOG_INF("%s: adding speculative implementation 'draft-dflash'\n", __func__);
LOG_INF("%s: - n_max=%d, n_min=%d, p_min=%.2f\n", __func__, this->params.n_max, this->params.n_min, this->params.p_min);
LOG_INF("%s: - block_size=%d, mask_token_id=%d, n_extract=%u\n", __func__, block_size, mask_token_id, target_layer_ids_n);
// DFlash input is [id_last, <mask> * (block_size-1)], so it can draft at most block_size-1 tokens per step
if (this->params.n_max > block_size - 1) {
LOG_WRN("%s: requested draft size %d exceeds the trained DFlash block size %d -- clamping to %d draft tokens per step\n",
__func__, this->params.n_max, block_size - 1, block_size - 1);
this->params.n_max = block_size - 1;
}
batch = llama_batch_init(llama_n_batch(ctx_dft), 0, n_seq);
batch_inject = llama_batch_init(llama_n_batch(ctx_dft), n_embd_dec, n_seq);
smpls.resize(n_seq);
for (auto & s : smpls) {
common_params_sampling sparams;
sparams.no_perf = false;
sparams.top_k = 1;
sparams.samplers = { COMMON_SAMPLER_TYPE_TOP_K };
s.reset(common_sampler_init(model_dft, sparams));
}
// turn on extraction of the target layers' input embeddings
for (uint32_t k = 0; k < target_layer_ids_n; ++k) {
llama_set_embeddings_layer_inp(ctx_tgt, (uint32_t) target_layer_ids[k], true);
}
llama_set_embeddings_nextn(ctx_dft, true, /*masked*/ true);
llama_set_causal_attn(ctx_dft, false); // DFlash needs non-causal attention
}
~common_speculative_impl_draft_dflash() override {
llama_batch_free(batch);
llama_batch_free(batch_inject);
}
void begin(llama_seq_id seq_id, const llama_tokens & prompt) override {
if (seq_id < 0 || seq_id >= (llama_seq_id) n_seq) {
return;
}
const int32_t N = (int32_t) prompt.size();
if (N <= 0) {
return;
}
const llama_pos pos_max = llama_memory_seq_pos_max(llama_get_memory(params.ctx_dft), seq_id);
if (pos_max < N - 1) {
LOG_WRN("%s: ctx_dft pos_max=%d < N-1=%d - process() did not run on every prefill ubatch. "
"Drafts may degrade.\n",
__func__, (int) pos_max, N - 1);
}
}
bool process(const llama_batch & batch_in) override {
if (batch_in.n_tokens <= 0) {
return true;
}
if (batch_in.token == nullptr || batch_in.embd != nullptr) {
return true;
}
const int32_t n_tokens = batch_in.n_tokens;
// per-seq inclusive batch range (assumes each seq's tokens are contiguous in the batch)
std::vector<int32_t> i_batch_beg(n_seq, -1);
std::vector<int32_t> i_batch_end(n_seq, -1);
for (int32_t k = 0; k < n_tokens; ++k) {
GGML_ASSERT(batch_in.n_seq_id[k] == 1);
const llama_seq_id seq_id = batch_in.seq_id[k][0];
if (seq_id < 0 || seq_id >= (llama_seq_id) n_seq) {
continue;
}
i_batch_end[seq_id] = k;
if (i_batch_beg[seq_id] < 0) {
i_batch_beg[seq_id] = k;
}
}
auto * ctx_tgt = this->params.ctx_tgt;
auto * ctx_dft = this->params.ctx_dft;
const int32_t n_ubatch = (int32_t) llama_n_ubatch(ctx_dft);
for (llama_seq_id seq_id = 0; seq_id < (llama_seq_id) n_seq; ++seq_id) {
if (i_batch_beg[seq_id] < 0) {
continue;
}
const int32_t n_rows = i_batch_end[seq_id] - i_batch_beg[seq_id] + 1;
for (int32_t offset = 0; offset < n_rows; offset += n_ubatch) {
const int32_t n_chunk = std::min(n_ubatch, n_rows - offset);
// gather this chunk's target features, interleaved by extract layer
features_buf.resize((size_t) n_chunk * n_embd_enc);
for (uint32_t k = 0; k < target_layer_ids_n; ++k) {
const float * layer = llama_get_embeddings_layer_inp(ctx_tgt, (uint32_t) target_layer_ids[k]);
if (!layer) {
GGML_ABORT("DFlash: target layer %d input not extracted.", target_layer_ids[k]);
}
for (int32_t i = 0; i < n_chunk; ++i) {
float * dst = features_buf.data() + (size_t) i * n_embd_enc + k * (size_t) n_embd_tgt;
const float * src = layer + (size_t) (i_batch_beg[seq_id] + offset + i) * n_embd_tgt;
std::memcpy(dst, src, (size_t) n_embd_tgt * sizeof(float));
}
}
// fuse extracted features through DFlash encoder
llama_batch enc_batch = {
/*.n_tokens =*/ n_chunk,
/*.token =*/ nullptr,
/*.embd =*/ features_buf.data(),
/*.pos =*/ nullptr,
/*.n_seq_id =*/ nullptr,
/*.seq_id =*/ nullptr,
/*.logits =*/ nullptr,
};
int32_t rc = llama_encode(ctx_dft, enc_batch);
if (rc != 0) {
LOG_ERR("%s: llama_encode(ctx_dft) failed rc=%d (n_tokens=%d, offset=%d)\n",
__func__, rc, (int) n_chunk, (int) offset);
return false;
}
const float * inp_g = llama_get_embeddings_nextn(ctx_dft);
GGML_ASSERT(inp_g && "DFlash encoder produced no output.");
// inject the DFlash decoder K/V cache at the tokens' target positions
batch_inject.n_tokens = n_chunk;
std::memcpy(batch_inject.embd, inp_g, (size_t) n_chunk * n_embd_dec * sizeof(float));
for (int32_t i = 0; i < n_chunk; ++i) {
batch_inject.pos[i] = batch_in.pos[i_batch_beg[seq_id] + offset + i];
batch_inject.n_seq_id[i] = 1;
batch_inject.seq_id[i][0] = seq_id;
batch_inject.logits[i] = false;
}
rc = llama_decode(ctx_dft, batch_inject);
if (rc != 0) {
LOG_ERR("%s: llama_decode(ctx_dft) failed rc=%d (n_tokens=%d, offset=%d)\n",
__func__, rc, (int) n_chunk, (int) offset);
return false;
}
}
}
return true;
}
void draft(common_speculative_draft_params_vec & dparams) override {
auto & ctx_dft = params.ctx_dft;
common_batch_clear(batch);
// build one batch holding every drafting sequence's noise block into a single decode)
// record where each block starts and its size
std::vector<int32_t> i_block_beg(n_seq, -1);
std::vector<int32_t> n_block (n_seq, 0);
for (llama_seq_id seq_id = 0; seq_id < (llama_seq_id) n_seq; ++seq_id) {
auto & dp = dparams[seq_id];
if (!dp.drafting) {
continue;
}
common_sampler_reset(smpls[seq_id].get());
const int32_t n = (int32_t) dp.n_past;
int32_t n_draft = params.n_max;
if (dp.n_max > 0) {
n_draft = std::min(n_draft, dp.n_max);
}
const int32_t n_block_tokens = n_draft + 1; // id_last + n_draft * <mask>
i_block_beg[seq_id] = batch.n_tokens;
n_block [seq_id] = n_block_tokens;
for (int32_t i = 0; i < n_block_tokens; ++i) {
common_batch_add(batch, i == 0 ? dp.id_last : mask_token_id, n + i, { seq_id }, true);
}
}
if (batch.n_tokens == 0) {
return;
}
// decode all sequence's noise block in a single batch
int ret = llama_decode(ctx_dft, batch);
if (ret != 0) {
LOG_WRN("%s: llama_decode returned %d\n", __func__, ret);
return;
}
for (llama_seq_id seq_id = 0; seq_id < (llama_seq_id) n_seq; ++seq_id) {
if (i_block_beg[seq_id] < 0) {
continue;
}
auto & dp = dparams[seq_id];
const int32_t beg = i_block_beg[seq_id];
const int32_t n_block_tokens = n_block[seq_id];
auto * smpl = smpls[seq_id].get();
auto & result = *dp.result;
// greedily read the predicted block at this sequence's noise positions 1..n_block_tokens-1
for (int32_t i = 1; i < n_block_tokens; ++i) {
common_sampler_sample(smpl, ctx_dft, beg + i, true);
const auto * cur_p = common_sampler_get_candidates(smpl, true);
for (int k = 0; k < std::min(3, (int) cur_p->size); ++k) {
LOG_DBG(" - seq_id %d, draft candidate %3d, pos %3d: %6d (%8.3f) '%s'\n",
seq_id, k, i - 1, cur_p->data[k].id, cur_p->data[k].p,
common_token_to_piece(ctx_dft, cur_p->data[k].id).c_str());
}
const llama_token id = cur_p->data[0].id;
common_sampler_accept(smpl, id, true);
result.push_back(id);
}
}
}
void accept(llama_seq_id /*seq_id*/, uint16_t /*n_accepted*/, bool /*is_other*/) override {
// noop
}
bool need_embd() const override {
return false;
}
};
struct common_speculative_impl_draft_mtp : public common_speculative_impl {
common_params_speculative_draft params; // reuses the draft-model params slot (ctx_tgt/ctx_dft)
@@ -942,9 +1238,9 @@ struct common_speculative_impl_draft_mtp : public common_speculative_impl {
"MTP input row width must match the target h_nextn width");
n_mtp_layers = std::max(1, (int) llama_model_n_layer_nextn(llama_get_model(ctx_dft)));
LOG_INF("%s: adding speculative implementation 'draft-mtp'\n", __func__);
LOG_INF("%s: - n_max=%d, n_min=%d, p_min=%.2f, n_embd=%d, backend_sampling=%d\n", __func__, this->params.n_max, this->params.n_min, this->params.p_min, n_embd, (int) this->params.backend_sampling);
LOG_INF("%s: - gpu_layers=%d, cache_k=%s, cache_v=%s, ctx_tgt=%s, ctx_dft=%s, devices=[%s]\n", __func__,
SPC_TRC("%s", "adding speculative implementation 'draft-mtp'\n");
SPC_TRC("- n_max=%d, n_min=%d, p_min=%.2f, n_embd=%d, backend_sampling=%d\n", this->params.n_max, this->params.n_min, this->params.p_min, n_embd, (int) this->params.backend_sampling);
SPC_TRC("- gpu_layers=%d, cache_k=%s, cache_v=%s, ctx_tgt=%s, ctx_dft=%s, devices=[%s]\n",
this->params.n_gpu_layers,
ggml_type_name(this->params.cache_type_k),
ggml_type_name(this->params.cache_type_v),
@@ -975,7 +1271,7 @@ struct common_speculative_impl_draft_mtp : public common_speculative_impl {
llama_sampler_chain_add(chain, llama_sampler_init_top_k(10));
if (!llama_set_sampler(ctx_dft, seq_id, chain)) {
LOG_WRN("%s: backend offload failed for seq_id=%d; using CPU sampler\n", __func__, (int) seq_id);
SPC_WRN("backend offload failed for seq_id=%d; using CPU sampler\n", (int) seq_id);
llama_sampler_free(chain);
chain = nullptr;
}
@@ -1038,11 +1334,11 @@ struct common_speculative_impl_draft_mtp : public common_speculative_impl {
const llama_pos pos_max = llama_memory_seq_pos_max(llama_get_memory(ctx_dft), seq_id);
if (pos_max < N - 1 && !is_mem_shared) {
LOG_WRN("%s: ctx_dft pos_max=%d < N-1=%d - "
SPC_WRN("ctx_dft pos_max=%d < N-1=%d - "
"process() hook may not have run on every prefill ubatch "
"(need_embd / logits=1 on every prompt position?). "
"Drafts may degrade.\n",
__func__, (int) pos_max, N - 1);
(int) pos_max, N - 1);
}
}
@@ -1128,8 +1424,8 @@ struct common_speculative_impl_draft_mtp : public common_speculative_impl {
const int32_t rc = llama_decode(ctx_dft, batch);
if (rc != 0) {
LOG_ERR("%s: llama_decode(ctx_dft) head=%d failed rc=%d (pos=%d)\n",
__func__, head, (int) rc, (int) batch_in.pos[0]);
SPC_ERR("llama_decode(ctx_dft) head=%d failed rc=%d (pos=%d)\n",
head, (int) rc, (int) batch_in.pos[0]);
ok = false;
break;
}
@@ -1217,7 +1513,7 @@ struct common_speculative_impl_draft_mtp : public common_speculative_impl {
int ret = llama_decode(ctx_dft, batch);
if (ret != 0) {
LOG_WRN("%s: llama_decode[%d] returned %d\n", __func__, i, ret);
SPC_ERR("llama_decode[%d] returned %d\n", i, ret);
break;
}
@@ -1239,7 +1535,7 @@ struct common_speculative_impl_draft_mtp : public common_speculative_impl {
const auto * cur_p = common_sampler_get_candidates(smpl, true);
for (int k = 0; k < std::min(3, (int) cur_p->size); ++k) {
LOG_DBG(" - seq_id %d, draft candidate %3d, pos %3d: %6d (%8.3f) '%s'\n",
SPC_DBG(" - seq_id %d, draft candidate %3d, pos %3d: %6d (%8.3f) '%s'\n",
seq_id, k, i, cur_p->data[k].id, cur_p->data[k].p,
common_token_to_piece(ctx_dft, cur_p->data[k].id).c_str());
}
@@ -1353,8 +1649,8 @@ struct common_speculative_impl_ngram_simple : public common_speculative_impl {
, params(params.ngram_simple)
, config(config)
{
LOG_INF("%s: adding speculative implementation 'ngram-simple'\n", __func__);
LOG_INF("%s: - size_n=%d, size_m=%d, min_hits=%d\n", __func__,
SPC_TRC("%s", "adding speculative implementation 'ngram-simple'\n");
SPC_TRC("- size_n=%d, size_m=%d, min_hits=%d\n",
this->params.size_n, this->params.size_m, this->params.min_hits);
}
@@ -1403,8 +1699,8 @@ struct common_speculative_impl_ngram_map_k : public common_speculative_impl {
this->config.push_back(config);
}
LOG_INF("%s: adding speculative implementation '%s'\n", __func__, common_speculative_type_to_str(this->type).c_str());
LOG_INF("%s: - size_key=%d, size_value=%d, key_only=%d, min_hits=%d\n", __func__,
SPC_TRC("adding speculative implementation '%s'\n", common_speculative_type_to_str(this->type).c_str());
SPC_TRC("- size_key=%d, size_value=%d, key_only=%d, min_hits=%d\n",
config.size_key, config.size_value, config.key_only, config.min_hits);
}
@@ -1478,15 +1774,15 @@ struct common_speculative_impl_ngram_mod : public common_speculative_impl {
, verbose(std::getenv("LLAMA_TRACE") != nullptr) {
static_assert(sizeof(llama_token) == sizeof(common_ngram_mod::entry_t));
LOG_INF("%s: adding speculative implementation 'ngram-mod'\n", __func__);
LOG_INF("%s: - n_match=%d, n_max=%d, n_min=%d\n", __func__,
SPC_TRC("%s", "adding speculative implementation 'ngram-mod'\n");
SPC_TRC("- n_match=%d, n_max=%d, n_min=%d\n",
this->params.n_match, this->params.n_max, this->params.n_min);
LOG_INF("%s: - mod size=%zu (%.3f MB)\n", __func__,
SPC_TRC("- mod size=%zu (%.3f MB)\n",
mod.size(), (float)(mod.size_bytes())/1024/1024);
if (this->params.n_match < 16) {
LOG_WRN("%s: ngram_mod n_match=%d is too small - poor quality is possible, "
"see: https://github.com/ggml-org/llama.cpp/pull/19164\n", __func__, this->params.n_match);
SPC_WRN("ngram_mod n_match=%d is too small - poor quality is possible, "
"see: https://github.com/ggml-org/llama.cpp/pull/19164\n", this->params.n_match);
}
sinfos.resize(n_seq);
@@ -1510,11 +1806,11 @@ struct common_speculative_impl_ngram_mod : public common_speculative_impl {
sinfo.i_last = prompt.size() - n;
const double f = (double)mod.get_used() / (double)mod.size();
LOG_INF("%s: ngram_mod occupancy = %zu/%zu (%.2f)\n", __func__, mod.get_used(), mod.size(), f);
SPC_TRC("ngram_mod occupancy = %zu/%zu (%.2f)\n", mod.get_used(), mod.size(), f);
constexpr double f_thold = 0.25;
if (f > f_thold) {
LOG_WRN("%s: ngram_mod occupancy %.2f exceeds threshold (%.2f) - resetting\n", __func__, f, f_thold);
SPC_WRN("ngram_mod occupancy %.2f exceeds threshold (%.2f) - resetting\n", f, f_thold);
mod.reset();
}
@@ -1608,7 +1904,7 @@ struct common_speculative_impl_ngram_mod : public common_speculative_impl {
sinfo.n_low++;
if (sinfo.n_low >= 5) {
if (verbose) {
LOG_WRN("%s: low acceptance streak (%d) - resetting ngram_mod\n", __func__, sinfo.n_low);
SPC_TRC("low acceptance streak (%d) - resetting ngram_mod\n", sinfo.n_low);
}
mod.reset();
@@ -1658,8 +1954,8 @@ struct common_speculative_impl_ngram_cache : public common_speculative_impl {
, save_dynamic(save_dynamic)
, save_static(save_static)
{
LOG_INF("%s: adding speculative implementation 'ngram-cache'\n", __func__);
LOG_INF("%s: - n_draft=%d, cache_static=%s, cache_dynamic=%s\n", __func__,
SPC_TRC("%s", "adding speculative implementation 'ngram-cache'\n");
SPC_TRC("- n_draft=%d, cache_static=%s, cache_dynamic=%s\n",
n_draft,
path_static.empty() ? "none" : path_static.c_str(),
path_dynamic.empty() ? "none" : path_dynamic.c_str());
@@ -1674,7 +1970,7 @@ struct common_speculative_impl_ngram_cache : public common_speculative_impl {
sinfo.ngram_cache_static = ngram_cache_static;
}
} catch (...) {
LOG_ERR("failed to open static lookup cache: %s", path_static.c_str());
SPC_ERR("failed to open static lookup cache: %s", path_static.c_str());
GGML_ABORT("Couldn't read static lookup cache");
}
}
@@ -1687,7 +1983,7 @@ struct common_speculative_impl_ngram_cache : public common_speculative_impl {
sinfo.ngram_cache_dynamic = ngram_cache_dynamic;
}
} catch (...) {
LOG_ERR("failed to open dynamic lookup cache: %s", path_dynamic.c_str());
SPC_ERR("failed to open dynamic lookup cache: %s", path_dynamic.c_str());
GGML_ABORT("Couldn't read dynamic lookup cache");
}
}
@@ -1836,6 +2132,7 @@ std::string common_speculative_type_to_str(common_speculative_type type) {
case COMMON_SPECULATIVE_TYPE_DRAFT_SIMPLE: return "draft-simple";
case COMMON_SPECULATIVE_TYPE_DRAFT_EAGLE3: return "draft-eagle3";
case COMMON_SPECULATIVE_TYPE_DRAFT_MTP: return "draft-mtp";
case COMMON_SPECULATIVE_TYPE_DRAFT_DFLASH: return "draft-dflash";
case COMMON_SPECULATIVE_TYPE_NGRAM_SIMPLE: return "ngram-simple";
case COMMON_SPECULATIVE_TYPE_NGRAM_MAP_K: return "ngram-map-k";
case COMMON_SPECULATIVE_TYPE_NGRAM_MAP_K4V: return "ngram-map-k4v";
@@ -1888,6 +2185,7 @@ int32_t common_speculative_n_max(const common_params_speculative * spec) {
case COMMON_SPECULATIVE_TYPE_DRAFT_SIMPLE:
case COMMON_SPECULATIVE_TYPE_DRAFT_EAGLE3:
case COMMON_SPECULATIVE_TYPE_DRAFT_MTP:
case COMMON_SPECULATIVE_TYPE_DRAFT_DFLASH:
n_max = std::max(n_max, std::max(0, spec->draft.n_max));
break;
case COMMON_SPECULATIVE_TYPE_NGRAM_SIMPLE:
@@ -1925,6 +2223,7 @@ common_speculative * common_speculative_init(common_params_speculative & params,
bool has_draft_simple = (enabled_configs & (1u << COMMON_SPECULATIVE_TYPE_DRAFT_SIMPLE));
bool has_draft_eagle3 = (enabled_configs & (1u << COMMON_SPECULATIVE_TYPE_DRAFT_EAGLE3)) && params.draft.ctx_dft != nullptr;
bool has_draft_mtp = (enabled_configs & (1u << COMMON_SPECULATIVE_TYPE_DRAFT_MTP)) && params.draft.ctx_dft != nullptr;
bool has_draft_dflash = (enabled_configs & (1u << COMMON_SPECULATIVE_TYPE_DRAFT_DFLASH)) && params.draft.ctx_dft != nullptr;
@@ -1935,7 +2234,7 @@ common_speculative * common_speculative_init(common_params_speculative & params,
bool has_ngram_mod = (enabled_configs & (1u << COMMON_SPECULATIVE_TYPE_NGRAM_MOD));
// when adding a new type - update here the logic above
static_assert(COMMON_SPECULATIVE_TYPE_COUNT == 9);
static_assert(COMMON_SPECULATIVE_TYPE_COUNT == 10);
// this list here defines the priority of the speculators
// the one with highest priority are listed first
@@ -1965,6 +2264,9 @@ common_speculative * common_speculative_init(common_params_speculative & params,
if (has_draft_mtp) {
configs.push_back(common_speculative_config(COMMON_SPECULATIVE_TYPE_DRAFT_MTP, params));
}
if (has_draft_dflash) {
configs.push_back(common_speculative_config(COMMON_SPECULATIVE_TYPE_DRAFT_DFLASH, params));
}
}
std::vector<std::unique_ptr<common_speculative_impl>> impls = {};
@@ -1985,6 +2287,10 @@ common_speculative * common_speculative_init(common_params_speculative & params,
impls.push_back(std::make_unique<common_speculative_impl_draft_mtp>(config.params, n_seq));
break;
}
case COMMON_SPECULATIVE_TYPE_DRAFT_DFLASH: {
impls.push_back(std::make_unique<common_speculative_impl_draft_dflash>(config.params, n_seq));
break;
}
case COMMON_SPECULATIVE_TYPE_NGRAM_SIMPLE: {
common_ngram_map ngram_map = get_common_ngram_map(config.type, config.params.ngram_simple);
@@ -2034,7 +2340,7 @@ common_speculative * common_speculative_init(common_params_speculative & params,
}
if (impls.empty()) {
LOG_WRN("%s: no implementations specified for speculative decoding\n", __func__);
SPC_TRC("%s", "no implementations specified for speculative decoding\n");
return nullptr;
}
@@ -2161,13 +2467,13 @@ void common_speculative_draft(common_speculative * spec) {
if (dp.n_max > 0) {
if (!result.empty() && (int) result.size() > dp.n_max) {
LOG_DBG("%s: truncating draft to %d tokens\n", __func__, dp.n_max);
SPC_DBG("truncating draft to %d tokens\n", dp.n_max);
result.resize(dp.n_max);
}
}
if (!result.empty()) {
LOG_DBG("%s: called impl %s, hist size = %zu, call_count = %zu, gen = %zu\n", __func__,
SPC_DBG("called impl %s, hist size = %zu, call_count = %zu, gen = %zu\n",
common_speculative_type_to_str(impl.get()->type).c_str(), dp.prompt->size(),
impl.get()->n_call_draft, result.size());
@@ -2291,7 +2597,7 @@ void common_speculative_print_stats(const common_speculative * spec) {
str_stats = ", #mean acc len = " + oss.str() + ", #acc rate/pos = (" + tmp.str() + ")";
}
LOG_INF("statistics %16s: #calls(b,g,a) = %4zu %6zu %6zu, #gen drafts = %6zu, #acc drafts = %5zu, #gen tokens = %6zu, #acc tokens = %5zu%s%s\n",
SPC_TRC("statistics %16s: #calls(b,g,a) = %4zu %6zu %6zu, #gen drafts = %6zu, #acc drafts = %5zu, #gen tokens = %6zu, #acc tokens = %5zu%s%s\n",
common_speculative_type_to_str(impl->type).c_str(),
impl->n_call_begin, impl->n_call_draft, impl->n_call_accept,
impl->n_gen_drafts,
+1
View File
@@ -50,6 +50,7 @@ TEXT_MODEL_MAP: dict[str, str] = {
"DeepseekV2ForCausalLM": "deepseek",
"DeepseekV3ForCausalLM": "deepseek",
"DeepseekV32ForCausalLM": "deepseek",
"DFlashDraftModel": "qwen",
"DistilBertForMaskedLM": "bert",
"DistilBertForSequenceClassification": "bert",
"DistilBertModel": "bert",
+3 -3
View File
@@ -73,7 +73,7 @@ class LlamaModel(TextModel):
target_num_layers = target_config["num_hidden_layers"]
target_layers = [2, target_num_layers // 2, target_num_layers - 3]
logger.info(f"EAGLE-3: target_layers = {target_layers} (target model has {target_num_layers} layers)")
self.gguf_writer.add_array(f"{self.gguf_writer.arch}.target_layers", target_layers)
self.gguf_writer.add_target_layers(target_layers)
# target_hidden_size: prefer eagle3 config, fallback to target config
if eagle3_raw_config.get("target_hidden_size") is not None:
@@ -83,12 +83,12 @@ class LlamaModel(TextModel):
target_hidden_size = target_config["hidden_size"]
src = "target model config"
logger.info(f"EAGLE-3: target_hidden_size = {target_hidden_size} (from {src})")
self.gguf_writer.add_uint32(f"{self.gguf_writer.arch}.target_hidden_size", target_hidden_size)
self.gguf_writer.add_target_hidden_size(target_hidden_size)
# norm_before_residual (RedHat-style eagle3 specific)
norm_before_residual = eagle3_raw_config.get("norm_before_residual", False)
logger.info(f"EAGLE-3: norm_before_residual = {norm_before_residual}")
self.gguf_writer.add_bool(f"{self.gguf_writer.arch}.norm_before_residual", norm_before_residual)
self.gguf_writer.add_norm_before_residual(norm_before_residual)
def set_vocab(self):
# eagle3: use tokenizer from target model if provided
+48
View File
@@ -625,3 +625,51 @@ class Qwen3_5TextModel(_Qwen35MtpMixin, _Qwen35MRopeMixin, _LinearAttentionVReor
@ModelBase.register("Qwen3_5MoeForConditionalGeneration", "Qwen3_5MoeForCausalLM")
class Qwen3_5MoeTextModel(_Qwen35MtpMixin, _Qwen35MRopeMixin, _LinearAttentionVReorderBase):
model_arch = gguf.MODEL_ARCH.QWEN35MOE
@ModelBase.register("DFlashDraftModel")
class DFlashModel(Qwen3Model):
    """Converter for DFlash draft models, built on top of the Qwen3 converter.

    The vocabulary is taken from the target model directory rather than from the
    draft checkpoint, and DFlash-specific metadata (mask token id, block size,
    target layers, sliding-window pattern) is written to the GGUF file.
    """

    model_arch = gguf.MODEL_ARCH.DFLASH

    def set_vocab(self):
        """Write the vocabulary using the tokenizer found in ``target_model_dir``.

        Raises:
            ValueError: if no target model directory was provided.
        """
        if self.target_model_dir is None:
            raise ValueError(
                "DFlash draft model requires --target-model-dir to be specified. "
                "Please provide the path to the target model directory containing the tokenizer."
            )

        logger.info(f"DFlash: Using tokenizer from target model: {self.target_model_dir}")

        # Temporarily point dir_model at the target model so the parent loads its
        # tokenizer; always restore it, even if the parent raises.
        original_dir = self.dir_model
        self.dir_model = self.target_model_dir
        try:
            super().set_vocab()
        finally:
            self.dir_model = original_dir

        # `or {}` also covers an explicit `"dflash_config": null` in the config.
        dflash_config = self.hparams.get("dflash_config") or {}
        mask_token_id = dflash_config.get("mask_token_id")
        if mask_token_id is not None:
            self.gguf_writer.add_mask_token_id(mask_token_id)

    def set_gguf_parameters(self):
        """Write the parent's parameters plus the DFlash-specific metadata."""
        super().set_gguf_parameters()

        # block size falls back to 16 when the config does not specify one
        block_size = self.hparams.get("block_size", 16)
        self.gguf_writer.add_block_size(block_size)

        dflash_config = self.hparams.get("dflash_config") or {}
        target_layer_ids = dflash_config.get("target_layer_ids") or []
        if target_layer_ids:
            # ids are shifted by one before being stored
            # NOTE(review): presumably converts layer indices to the index of the
            # hidden state produced after that layer - confirm against the loader.
            extract_layer_ids = [i + 1 for i in target_layer_ids]
            self.gguf_writer.add_target_layers(extract_layer_ids)

        use_sliding_window = self.hparams.get("use_sliding_window", False)
        sliding_window = self.hparams.get("sliding_window")
        layer_types = self.hparams.get("layer_types")
        # only emit SWA metadata when it is enabled and fully described
        if use_sliding_window and sliding_window and layer_types:
            is_swa = [lt == "sliding_attention" for lt in layer_types]
            self.gguf_writer.add_sliding_window(sliding_window)
            self.gguf_writer.add_sliding_window_pattern(is_swa)

    @classmethod
    def filter_tensors(cls, item: tuple[str, Callable[[], Tensor]]) -> tuple[str, Callable[[], Tensor]] | None:
        """Prefix tensor names with ``model.`` when missing, then defer to the parent filter."""
        name, gen = item
        if not name.startswith("model."):
            name = "model." + name
        return super().filter_tensors((name, gen))
+28 -1
View File
@@ -52,6 +52,32 @@ Supported EAGLE-3 draft models include:
For the full and up-to-date list of supported models, see #18039.
### DFlash (`draft-dflash`)
DFlash produces an entire block of draft tokens in a single forward pass (block diffusion) instead
of drafting one token at a time, and injects the target model's hidden states into the draft model's
attention. This keeps the draft model small while making drafting GPU-friendly. Unlike EAGLE-3
(a single-layer autoregressive draft), the DFlash draft uses several transformer layers but emits a
whole block per draft step.
The draft is a small block-diffusion model trained for a specific target (for example
`z-lab/Qwen3-4B-DFlash` for `Qwen/Qwen3-4B`). Convert it with `--target-model-dir` so it inherits the
target's tokenizer and token embeddings:
```bash
python convert_hf_to_gguf.py z-lab/Qwen3-4B-DFlash \
--target-model-dir Qwen/Qwen3-4B --outtype bf16 --outfile Qwen3-4B-DFlash.gguf
llama-server -m Qwen3-4B.gguf -md Qwen3-4B-DFlash.gguf \
--spec-type draft-dflash --spec-draft-n-max 15 -fa on --jinja
```
`--spec-draft-n-max` is clamped to the draft model's trained block size.
See:
- #22105
### n-gram Cache (`ngram-cache`)
An n-gram is a sequence of n tokens. The n-gram cache implementation maintains statistics about short n-gram sequences.
@@ -147,7 +173,7 @@ If a draft model is combined with a draftless decoding the draftless decoding ha
### General Speculative Parameters
```
--spec-type [none|draft-simple|draft-eagle3|draft-mtp|ngram-cache|ngram-simple|ngram-map-k|ngram-map-k4v|ngram-mod]
--spec-type [none|draft-simple|draft-eagle3|draft-dflash|draft-mtp|ngram-cache|ngram-simple|ngram-map-k|ngram-map-k4v|ngram-mod]
comma-separated list of types of speculative decoding to use
(default: none)
(env: LLAMA_ARG_SPEC_TYPE)
@@ -287,6 +313,7 @@ Specifies a comma-separated list of speculative decoding types to use.
| `none` | No speculative decoding (default) |
| `draft-simple` | Use a simple draft model for speculation |
| `draft-eagle3` | Use an EAGLE-3 draft model that reads the target's hidden states |
| `draft-dflash` | Use a DFlash block-diffusion draft model that emits a block per step |
| `draft-mtp` | Use Multi Token Prediction (MTP) heads from the main model |
| `ngram-cache` | Use n-gram cache lookup |
| `ngram-simple` | Use simple n-gram pattern matching |
+3
View File
@@ -340,6 +340,9 @@ set(GGML_PUBLIC_HEADERS
include/gguf.h)
set_target_properties(ggml PROPERTIES PUBLIC_HEADER "${GGML_PUBLIC_HEADERS}")
#if (GGML_METAL)
# set_target_properties(ggml PROPERTIES RESOURCE "${CMAKE_CURRENT_SOURCE_DIR}/src/ggml-metal.metal")
#endif()
install(TARGETS ggml LIBRARY PUBLIC_HEADER)
install(TARGETS ggml-base LIBRARY)
+45
View File
@@ -386,6 +386,46 @@ static void ggml_cpy_f32_iq4_nl_cuda(
(cx, cdst, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13);
}
// check if a same-type copy reduces to a 2D strided copy (height rows of width
// contiguous bytes), so it can use cudaMemcpy2DAsync instead of the scalar kernel
static bool ggml_cuda_cpy_as_memcpy_2d(const ggml_tensor * src0, const ggml_tensor * src1,
size_t & width, size_t & height, size_t & spitch, size_t & dpitch) {
// require matching shape: a reshaped copy maps elements by flat order, which the
// prefix walk below does not handle
if (src0->type != src1->type || !ggml_are_same_shape(src0, src1)) {
return false;
}
// grow the contiguous prefix block shared by both tensors
size_t block_nb = ggml_element_size(src0);
int d = 0;
for (; d < GGML_MAX_DIMS; ++d) {
if (src0->nb[d] != block_nb || src1->nb[d] != block_nb) {
break;
}
block_nb *= src0->ne[d];
}
// d == 0: nothing contiguous; d == GGML_MAX_DIMS: fully contiguous (handled by memcpy)
if (d == 0 || d == GGML_MAX_DIMS) {
return false;
}
// dim d carries the rows; everything above it must be a single element
for (int i = d + 1; i < GGML_MAX_DIMS; ++i) {
if (src0->ne[i] != 1) {
return false;
}
}
width = block_nb;
height = src0->ne[d];
spitch = src0->nb[d];
dpitch = src1->nb[d];
return spitch >= width && dpitch >= width;
}
void ggml_cuda_cpy(ggml_backend_cuda_context & ctx, const ggml_tensor * src0, ggml_tensor * src1) {
const int64_t ne = ggml_nelements(src0);
GGML_ASSERT(ne == ggml_nelements(src1));
@@ -421,6 +461,8 @@ void ggml_cuda_cpy(ggml_backend_cuda_context & ctx, const ggml_tensor * src0, gg
const bool can_be_transposed = nb01 == (int64_t)ggml_element_size(src0) &&
src0->ne[3] == 1 && nb02 == ne00 * ne01 * (int64_t)ggml_element_size(src0);
size_t mc_width = 0, mc_height = 0, mc_spitch = 0, mc_dpitch = 0;
if (src0->type == src1->type && contiguous_srcs) {
GGML_ASSERT(ggml_nbytes(src0) == ggml_nbytes(src1));
#if defined(GGML_USE_MUSA) && defined(GGML_MUSA_MUDNN_COPY)
@@ -431,6 +473,9 @@ void ggml_cuda_cpy(ggml_backend_cuda_context & ctx, const ggml_tensor * src0, gg
{
CUDA_CHECK(cudaMemcpyAsync(src1_ddc, src0_ddc, ggml_nbytes(src0), cudaMemcpyDeviceToDevice, main_stream));
}
} else if (ggml_cuda_cpy_as_memcpy_2d(src0, src1, mc_width, mc_height, mc_spitch, mc_dpitch)) {
CUDA_CHECK(cudaMemcpy2DAsync(src1_ddc, mc_dpitch, src0_ddc, mc_spitch,
mc_width, mc_height, cudaMemcpyDeviceToDevice, main_stream));
} else if (src0->type == GGML_TYPE_F32 && src1->type == GGML_TYPE_F32) {
if (can_be_transposed) {
ggml_cpy_scalar_cuda<float, float, true>
+51 -119
View File
@@ -24,119 +24,62 @@ if (GGML_METAL_NDEBUG)
endif()
set(METALLIB_COMMON "${CMAKE_CURRENT_SOURCE_DIR}/../ggml-common.h")
set(METALLIB_KERNELS_COMMON "${CMAKE_CURRENT_SOURCE_DIR}/kernels/common.h")
set(METALLIB_KERNELS_DEQUANTIZE "${CMAKE_CURRENT_SOURCE_DIR}/kernels/dequantize.h")
set(METALLIB_KERNELS_QUANTIZE "${CMAKE_CURRENT_SOURCE_DIR}/kernels/quantize.h")
set(METALLIB_KERNEL_SOURCES
kernels/fa.metal
kernels/mul_mv.metal
kernels/mul_mm.metal
kernels/quantize.metal
kernels/softmax.metal
kernels/norm.metal
kernels/unary.metal
kernels/binbcast.metal
kernels/reduce.metal
kernels/tri.metal
kernels/ssm.metal
kernels/wkv.metal
kernels/gated_delta_net.metal
kernels/solve_tri.metal
kernels/rope.metal
kernels/conv.metal
kernels/upscale.metal
kernels/argsort.metal
kernels/pool.metal
kernels/misc.metal
)
if (GGML_METAL_EMBED_LIBRARY)
enable_language(ASM)
add_compile_definitions(GGML_METAL_EMBED_LIBRARY)
set(METALLIB_IMPL "${CMAKE_CURRENT_SOURCE_DIR}/ggml-metal-impl.h")
set(METALLIB_SOURCE "${CMAKE_CURRENT_SOURCE_DIR}/ggml-metal.metal")
set(METALLIB_IMPL "${CMAKE_CURRENT_SOURCE_DIR}/ggml-metal-impl.h")
file(MAKE_DIRECTORY "${CMAKE_CURRENT_BINARY_DIR}/autogenerated")
set(METALLIB_EMBED_ASM_FILES "")
foreach(src ${METALLIB_KERNEL_SOURCES})
get_filename_component(kind ${src} NAME_WE)
# symbol names must be valid C identifiers ('-' is not allowed)
string(REPLACE "-" "_" kind_sym ${kind})
# merge ggml-common.h and ggml-metal.metal into a single file
set(METALLIB_EMBED_ASM "${CMAKE_CURRENT_BINARY_DIR}/autogenerated/ggml-metal-embed.s")
set(METALLIB_SOURCE_EMBED "${CMAKE_CURRENT_BINARY_DIR}/autogenerated/ggml-metal-embed.metal")
set(METALLIB_SOURCE_EMBED_TMP "${CMAKE_CURRENT_BINARY_DIR}/autogenerated/ggml-metal-embed.metal.tmp")
set(SRC "${CMAKE_CURRENT_SOURCE_DIR}/kernels/${kind}.metal")
set(EMBED "${CMAKE_CURRENT_BINARY_DIR}/autogenerated/ggml-metal-embed-${kind}.metal")
set(ASM "${CMAKE_CURRENT_BINARY_DIR}/autogenerated/ggml-metal-embed-${kind}.s")
add_custom_command(
OUTPUT "${METALLIB_EMBED_ASM}"
COMMAND echo "Embedding Metal library"
COMMAND sed -e "/__embed_ggml-common.h__/r ${METALLIB_COMMON}" -e "/__embed_ggml-common.h__/d" < "${METALLIB_SOURCE}" > "${METALLIB_SOURCE_EMBED_TMP}"
COMMAND sed -e "/\#include \"ggml-metal-impl.h\"/r ${METALLIB_IMPL}" -e "/\#include \"ggml-metal-impl.h\"/d" < "${METALLIB_SOURCE_EMBED_TMP}" > "${METALLIB_SOURCE_EMBED}"
COMMAND echo ".section __DATA,__ggml_metallib" > "${METALLIB_EMBED_ASM}"
COMMAND echo ".globl _ggml_metallib_start" >> "${METALLIB_EMBED_ASM}"
COMMAND echo "_ggml_metallib_start:" >> "${METALLIB_EMBED_ASM}"
COMMAND echo .incbin "\"${METALLIB_SOURCE_EMBED}\"" >> "${METALLIB_EMBED_ASM}"
COMMAND echo ".globl _ggml_metallib_end" >> "${METALLIB_EMBED_ASM}"
COMMAND echo "_ggml_metallib_end:" >> "${METALLIB_EMBED_ASM}"
DEPENDS ../ggml-common.h ggml-metal.metal ggml-metal-impl.h
COMMENT "Generate assembly for embedded Metal library"
VERBATIM
)
# only prepend headers that this source actually includes
set(HEADERS_FOR_SRC ${METALLIB_KERNELS_COMMON})
file(STRINGS ${SRC} _has_dequantize REGEX "#include \"dequantize\\.h\"")
file(STRINGS ${SRC} _has_quantize REGEX "#include \"quantize\\.h\"")
if(_has_dequantize)
list(APPEND HEADERS_FOR_SRC ${METALLIB_KERNELS_DEQUANTIZE})
endif()
if(_has_quantize)
list(APPEND HEADERS_FOR_SRC ${METALLIB_KERNELS_QUANTIZE})
endif()
add_custom_command(
OUTPUT "${ASM}"
# Step 1: concatenate shared headers + this kernel source
COMMAND cat ${HEADERS_FOR_SRC} ${SRC} > "${EMBED}.tmp1"
# Step 2: remove internal #include and #pragma once
COMMAND sed -e "/\#include \"common.h\"/d" -e "/\#include \"dequantize.h\"/d" -e "/\#include \"quantize.h\"/d" -e "/\#pragma once/d" < "${EMBED}.tmp1" > "${EMBED}.tmp2"
# Step 3: inline ggml-common.h (replacing __embed_ggml-common.h__ sentinel)
COMMAND sed -e "/__embed_ggml-common.h__/r ${METALLIB_COMMON}" -e "/__embed_ggml-common.h__/d" < "${EMBED}.tmp2" > "${EMBED}.tmp3"
# Step 4: inline ggml-metal-impl.h
COMMAND sed -e "/\#include \"ggml-metal-impl.h\"/r ${METALLIB_IMPL}" -e "/\#include \"ggml-metal-impl.h\"/d" < "${EMBED}.tmp3" > "${EMBED}"
# Step 5: emit an asm chunk with kind-specific start/end symbols
# note: '-' is illegal in C symbols, so we use kind_sym; the macOS
# section name is limited to 16 chars so we keep it shared
# across kinds (__ggml_metallib) and only vary the global symbols.
COMMAND echo ".section __DATA,__ggml_metallib" > "${ASM}"
COMMAND echo ".globl _ggml_metallib_${kind_sym}_start" >> "${ASM}"
COMMAND echo "_ggml_metallib_${kind_sym}_start:" >> "${ASM}"
COMMAND echo .incbin "\"${EMBED}\"" >> "${ASM}"
COMMAND echo ".globl _ggml_metallib_${kind_sym}_end" >> "${ASM}"
COMMAND echo "_ggml_metallib_${kind_sym}_end:" >> "${ASM}"
DEPENDS ../ggml-common.h ggml-metal-impl.h
kernels/common.h kernels/dequantize.h kernels/quantize.h
kernels/${kind}.metal
COMMENT "Generate embedded Metal library for ${kind}"
VERBATIM
)
list(APPEND METALLIB_EMBED_ASM_FILES "${ASM}")
endforeach()
target_sources(ggml-metal PRIVATE ${METALLIB_EMBED_ASM_FILES})
target_sources(ggml-metal PRIVATE "${METALLIB_EMBED_ASM}")
else()
# copy header files to bin directory
# copy metal files to bin directory
configure_file(../ggml-common.h ${CMAKE_RUNTIME_OUTPUT_DIRECTORY}/ggml-common.h COPYONLY)
configure_file(ggml-metal.metal ${CMAKE_RUNTIME_OUTPUT_DIRECTORY}/ggml-metal.metal COPYONLY)
configure_file(ggml-metal-impl.h ${CMAKE_RUNTIME_OUTPUT_DIRECTORY}/ggml-metal-impl.h COPYONLY)
file(MAKE_DIRECTORY "${CMAKE_RUNTIME_OUTPUT_DIRECTORY}/kernels")
configure_file(kernels/common.h ${CMAKE_RUNTIME_OUTPUT_DIRECTORY}/kernels/common.h COPYONLY)
configure_file(kernels/dequantize.h ${CMAKE_RUNTIME_OUTPUT_DIRECTORY}/kernels/dequantize.h COPYONLY)
configure_file(kernels/quantize.h ${CMAKE_RUNTIME_OUTPUT_DIRECTORY}/kernels/quantize.h COPYONLY)
foreach(src ${METALLIB_KERNEL_SOURCES})
configure_file(${src} ${CMAKE_RUNTIME_OUTPUT_DIRECTORY}/${src} COPYONLY)
endforeach()
if (GGML_METAL_SHADER_DEBUG)
# note: disabling fast math is needed in order to pass tests/test-backend-ops
# custom command to do the following:
# xcrun -sdk macosx metal -fno-fast-math -c ggml-metal.metal -o ggml-metal.air
# xcrun -sdk macosx metallib ggml-metal.air -o default.metallib
#
# note: this is the only way I found to disable fast-math in Metal. it's ugly, but at least it works
# disabling fast math is needed in order to pass tests/test-backend-ops
# note: adding -fno-inline fixes the tests when using MTL_SHADER_VALIDATION=1
# note: unfortunately, we have to call it default.metallib instead of ggml.metallib
# ref: https://github.com/ggml-org/whisper.cpp/issues/1720
# note: adding -g causes segmentation fault during compile
#set(XC_FLAGS -fno-fast-math -fno-inline -g)
set(XC_FLAGS -fno-fast-math -fno-inline)
else()
set(XC_FLAGS -O3)
endif()
# Append macOS metal versioning flags
if (GGML_METAL_MACOSX_VERSION_MIN)
message(STATUS "Adding -mmacosx-version-min=${GGML_METAL_MACOSX_VERSION_MIN} flag to metal compilation")
list (APPEND XC_FLAGS -mmacosx-version-min=${GGML_METAL_MACOSX_VERSION_MIN})
@@ -147,46 +90,35 @@ else()
list (APPEND XC_FLAGS -std=${GGML_METAL_STD})
endif()
# Compile each kernel source to .air, then link into default.metallib
set(AIR_FILES "")
foreach(src ${METALLIB_KERNEL_SOURCES})
get_filename_component(name ${src} NAME_WE)
set(AIR "${CMAKE_RUNTIME_OUTPUT_DIRECTORY}/${name}.air")
list(APPEND AIR_FILES ${AIR})
add_custom_command(
OUTPUT ${AIR}
COMMAND xcrun -sdk macosx metal ${XC_FLAGS} -I ${CMAKE_RUNTIME_OUTPUT_DIRECTORY} -c ${CMAKE_RUNTIME_OUTPUT_DIRECTORY}/${src} -o ${AIR}
DEPENDS ${src} kernels/common.h kernels/dequantize.h kernels/quantize.h ${METALLIB_COMMON} ggml-metal-impl.h
COMMENT "Compiling ${src}"
VERBATIM
)
endforeach()
add_custom_command(
OUTPUT ${CMAKE_RUNTIME_OUTPUT_DIRECTORY}/default.metallib
COMMAND xcrun -sdk macosx metallib ${AIR_FILES} -o ${CMAKE_RUNTIME_OUTPUT_DIRECTORY}/default.metallib
COMMAND xcrun -sdk macosx metal ${XC_FLAGS} -c ${CMAKE_RUNTIME_OUTPUT_DIRECTORY}/ggml-metal.metal -o - |
xcrun -sdk macosx metallib - -o ${CMAKE_RUNTIME_OUTPUT_DIRECTORY}/default.metallib
COMMAND rm -f ${CMAKE_RUNTIME_OUTPUT_DIRECTORY}/ggml-common.h
COMMAND rm -f ${CMAKE_RUNTIME_OUTPUT_DIRECTORY}/ggml-metal-impl.h
COMMAND rm -rf ${CMAKE_RUNTIME_OUTPUT_DIRECTORY}/kernels
DEPENDS ${AIR_FILES}
COMMENT "Linking Metal kernels into default.metallib"
)
COMMAND rm -f ${CMAKE_RUNTIME_OUTPUT_DIRECTORY}/ggml-metal.metal
DEPENDS ggml-metal.metal ${METALLIB_COMMON}
COMMENT "Compiling Metal kernels"
)
# FIXME: only add to the ggml-metal target?
add_custom_target(
ggml-metal-lib ALL
DEPENDS ${CMAKE_RUNTIME_OUTPUT_DIRECTORY}/default.metallib
)
)
endif() # GGML_METAL_EMBED_LIBRARY
if (NOT GGML_METAL_EMBED_LIBRARY)
install(
DIRECTORY ${CMAKE_CURRENT_SOURCE_DIR}/kernels/
DESTINATION ${CMAKE_INSTALL_BINDIR}/kernels
FILES_MATCHING PATTERN "*.metal" PATTERN "*.h"
)
FILES src/ggml-metal/ggml-metal.metal
PERMISSIONS
OWNER_READ
OWNER_WRITE
GROUP_READ
WORLD_READ
DESTINATION ${CMAKE_INSTALL_BINDIR})
install(
FILES ${CMAKE_RUNTIME_OUTPUT_DIRECTORY}/default.metallib
DESTINATION ${CMAKE_INSTALL_BINDIR}
)
install(
FILES ${CMAKE_RUNTIME_OUTPUT_DIRECTORY}/default.metallib
DESTINATION ${CMAKE_INSTALL_BINDIR}
)
endif()
+127 -422
View File
@@ -94,63 +94,8 @@ int ggml_metal_pipeline_max_theads_per_threadgroup(struct ggml_metal_pipeline_wi
return pipeline.pipeline->obj.maxTotalThreadsPerThreadgroup;
}
//
// MTLLibrary collection (one library per op-source, compiled separately)
//
// Single source of truth for the per-kind metal libraries. The order here
// defines the enum values and every per-kind table below, so adding a library
// is a one-line change here (plus adding its source to CMakeLists.txt).
// X(suffix, name): name is both the kernels/<name>.metal basename and the
// ggml_metallib_<name>_{start,end} embed-symbol stem.
#define GGML_METAL_LIBS \
X(FA, fa) \
X(MUL_MV, mul_mv) \
X(MUL_MM, mul_mm) \
X(QUANTIZE, quantize) \
X(SOFTMAX, softmax) \
X(NORM, norm) \
X(UNARY, unary) \
X(BINBCAST, binbcast) \
X(REDUCE, reduce) \
X(TRI, tri) \
X(SSM, ssm) \
X(WKV, wkv) \
X(GATED_DELTA_NET, gated_delta_net)\
X(SOLVE_TRI, solve_tri) \
X(ROPE, rope) \
X(CONV, conv) \
X(UPSCALE, upscale) \
X(ARGSORT, argsort) \
X(POOL, pool) \
X(MISC, misc)
enum ggml_metal_lib_kind {
#define X(e, s) GGML_METAL_LIB_##e,
GGML_METAL_LIBS
#undef X
GGML_METAL_LIB_COUNT,
};
static const char * const k_lib_names[GGML_METAL_LIB_COUNT] = {
#define X(e, s) [GGML_METAL_LIB_##e] = #s,
GGML_METAL_LIBS
#undef X
};
struct ggml_metal_library {
// Per-kind compiled libraries. When single_library is true, the whole library
// (e.g. a pre-compiled default.metallib or a from-source build) lives at
// objs[0] and the remaining slots are nil.
id<MTLLibrary> objs[GGML_METAL_LIB_COUNT];
bool single_library; // true: combined library at objs[0]; false: per-kind libs in objs[*]
// Routing table: kernel function name -> objs[] index, populated from each
// compiled library's -[MTLLibrary functionNames]. The actual compiled
// libraries are the single source of truth for which library owns a kernel,
// so adding kernels later requires no manual routing maintenance.
// nil in single_library mode (everything resolves to objs[0]).
NSMutableDictionary<NSString *, NSNumber *> * fn_to_lib;
id<MTLLibrary> obj;
ggml_metal_device_t dev;
ggml_metal_pipelines_t pipelines; // cache of compiled pipelines
@@ -158,376 +103,160 @@ struct ggml_metal_library {
NSLock * lock;
};
// Build the fn_to_lib routing table by querying each compiled library's public
// function names. Call once after all per-kind libraries have been compiled.
static void ggml_metal_library_build_index(ggml_metal_library_t lib) {
@autoreleasepool {
NSMutableDictionary<NSString *, NSNumber *> * index = [[NSMutableDictionary alloc] init];
for (int kind = 0; kind < GGML_METAL_LIB_COUNT; ++kind) {
for (NSString * fname in [lib->objs[kind] functionNames]) {
index[fname] = @(kind);
}
}
lib->fn_to_lib = index;
}
}
ggml_metal_library_t ggml_metal_library_init(ggml_metal_device_t dev) {
id<MTLLibrary> library = nil;
id<MTLDevice> device = ggml_metal_device_get_obj(dev);
// Parse a `#include "name"` line. Returns the quoted name in *include_name on
// success. Whitespace-tolerant; ignores `#include <...>` (system headers).
static bool ggml_metal_library_parse_quoted_include(NSString * line, NSString ** include_name) {
NSScanner * scanner = [NSScanner scannerWithString:line];
scanner.charactersToBeSkipped = [NSCharacterSet whitespaceCharacterSet];
// load library
//
// - first check if the library is embedded
// - then check if the library is in the bundle
// - if not found, load the source and compile it
// - if that fails, return NULL
//
// TODO: move to a function
{
const int64_t t_start = ggml_time_us();
if (![scanner scanString:@"#" intoString:NULL] ||
![scanner scanString:@"include" intoString:NULL] ||
![scanner scanString:@"\"" intoString:NULL]) {
return false;
}
NSError * error = nil;
NSString * src = nil;
NSString * name = nil;
if (![scanner scanUpToString:@"\"" intoString:&name]) {
return false;
}
#if GGML_METAL_EMBED_LIBRARY
GGML_LOG_INFO("%s: using embedded metal library\n", __func__);
if (include_name) {
*include_name = name;
}
return true;
}
extern const char ggml_metallib_start[];
extern const char ggml_metallib_end[];
// Recursively inline `#include "name"` directives. System includes (<...>),
// `#if/#else/#endif`, and other preprocessor lines are passed through to the
// Metal compiler unchanged. `#pragma once` is dropped since `seen` already
// guards against double-inclusion.
static bool ggml_metal_library_flatten_file(NSMutableString * dst, NSString * path,
NSArray<NSString *> * search_paths,
NSMutableSet<NSString *> * seen, NSError ** error) {
NSString * key = [path stringByStandardizingPath];
if ([seen containsObject:key]) {
return true;
}
[seen addObject:key];
src = [[NSString alloc] initWithBytes:ggml_metallib_start length:(ggml_metallib_end-ggml_metallib_start) encoding:NSUTF8StringEncoding];
#else
NSString * src = [NSString stringWithContentsOfFile:path encoding:NSUTF8StringEncoding error:error];
if (!src) {
return false;
}
#ifdef SWIFT_PACKAGE
NSBundle * bundle = SWIFTPM_MODULE_BUNDLE;
#else
NSBundle * bundle = [NSBundle bundleForClass:[GGMLMetalClass class]];
#endif
NSFileManager * fm = [NSFileManager defaultManager];
for (NSString * line in [src componentsSeparatedByString:@"\n"]) {
NSString * trimmed = [line stringByTrimmingCharactersInSet:[NSCharacterSet whitespaceCharacterSet]];
if ([trimmed isEqualToString:@"#pragma once"]) {
continue;
}
NSString * path_lib = [bundle pathForResource:@"default" ofType:@"metallib"];
if (path_lib == nil) {
// Try to find the resource in the directory where the current binary located.
NSString * bin_cur = [[NSProcessInfo processInfo] arguments][0];
NSString * bin_dir = [bin_cur stringByDeletingLastPathComponent];
NSString * include_name = nil;
if (ggml_metal_library_parse_quoted_include(line, &include_name)) {
NSString * resolved = nil;
for (NSString * dir in search_paths) {
NSString * candidate = [dir stringByAppendingPathComponent:include_name];
if ([fm isReadableFileAtPath:candidate]) {
resolved = candidate;
break;
NSString * path_lib_default = [NSString pathWithComponents:@[bin_dir, @"default.metallib"]];
if ([[NSFileManager defaultManager] isReadableFileAtPath:path_lib_default]) {
GGML_LOG_INFO("%s: found '%s'\n", __func__, [path_lib_default UTF8String]);
NSDictionary * atts = [[NSFileManager defaultManager] attributesOfItemAtPath:path_lib_default error:&error];
if (atts && atts[NSFileType] == NSFileTypeSymbolicLink) {
// Optionally, if this is a symlink, try to resolve it.
path_lib_default = [[NSFileManager defaultManager] destinationOfSymbolicLinkAtPath:path_lib_default error:&error];
if (path_lib_default && [path_lib_default length] > 0 && ![[path_lib_default substringToIndex:1] isEqualToString:@"/"]) {
// It is a relative path, adding the binary directory as directory prefix.
path_lib_default = [NSString pathWithComponents:@[bin_dir, path_lib_default]];
}
if (!path_lib_default || ![[NSFileManager defaultManager] isReadableFileAtPath:path_lib_default]) {
// Link to the resource could not be resolved.
path_lib_default = nil;
} else {
GGML_LOG_INFO("%s: symlink resolved '%s'\n", __func__, [path_lib_default UTF8String]);
}
}
} else {
// The resource couldn't be found in the binary's directory.
path_lib_default = nil;
}
if (!resolved) {
if (error) {
NSString * msg = [NSString stringWithFormat:@"could not resolve include \"%@\" from '%@'", include_name, path];
*error = [NSError errorWithDomain:@"ggml-metal-source-flatten" code:1
userInfo:@{NSLocalizedDescriptionKey: msg}];
}
return false;
}
if (!ggml_metal_library_flatten_file(dst, resolved, search_paths, seen, error)) {
return false;
}
continue;
path_lib = path_lib_default;
}
[dst appendString:line];
[dst appendString:@"\n"];
}
if (path_lib != nil) {
// pre-compiled library found
NSURL * libURL = [NSURL fileURLWithPath:path_lib];
GGML_LOG_INFO("%s: loading '%s'\n", __func__, [path_lib UTF8String]);
return true;
}
library = [device newLibraryWithURL:libURL error:&error];
if (error) {
GGML_LOG_ERROR("%s: error: %s\n", __func__, [[error description] UTF8String]);
return nil;
}
} else {
GGML_LOG_INFO("%s: default.metallib not found, loading from source\n", __func__);
static NSString * ggml_metal_library_flatten_source(NSString * path_source, NSError ** error) {
// Search paths cover both runtime layout (build/bin/kernels + build/bin)
// and source-tree layout (ggml/src/ggml-metal/kernels + ggml/src/ggml-metal + ggml/src).
NSString * path_kernels = [path_source stringByDeletingLastPathComponent];
NSString * path_base = [path_kernels stringByDeletingLastPathComponent];
NSArray<NSString *> * search_paths = @[
path_kernels,
path_base,
[path_base stringByDeletingLastPathComponent],
];
NSString * path_source;
NSString * path_resource = [[NSProcessInfo processInfo].environment objectForKey:@"GGML_METAL_PATH_RESOURCES"];
NSMutableString * src = [[NSMutableString alloc] init];
NSMutableSet<NSString *> * seen = [NSMutableSet set];
GGML_LOG_INFO("%s: GGML_METAL_PATH_RESOURCES = %s\n", __func__, path_resource ? [path_resource UTF8String] : "nil");
if (!ggml_metal_library_flatten_file(src, path_source, search_paths, seen, error)) {
[src release];
return nil;
}
return src;
}
// Compile all per-kind libraries in parallel. `source_for_kind` returns the MSL
// source for a kind (the helper takes ownership and releases it), or nil with
// *err set on failure. On success the objs[] slots are populated and the routing
// index is built; on any failure every error is logged and false is returned
// (the caller is responsible for freeing `res`).
static bool ggml_metal_library_compile_all(
ggml_metal_library_t res,
id<MTLDevice> device,
NSDictionary * prep,
NSString * (^source_for_kind)(int kind, NSError ** err),
const char * origin) {
const int64_t t_start = ggml_time_us();
int64_t * t_per_lib = calloc(GGML_METAL_LIB_COUNT, sizeof(int64_t));
NSError ** err_per_lib = calloc(GGML_METAL_LIB_COUNT, sizeof(NSError *));
__block atomic_bool any_failure = false;
dispatch_group_t group = dispatch_group_create();
dispatch_queue_t queue = dispatch_get_global_queue(QOS_CLASS_USER_INITIATED, 0);
for (int kind = 0; kind < GGML_METAL_LIB_COUNT; ++kind) {
dispatch_group_async(group, queue, ^{
const int64_t t0 = ggml_time_us();
NSError * error = nil;
NSString * src = source_for_kind(kind, &error);
if (!src) {
err_per_lib[kind] = [error retain];
atomic_store(&any_failure, true);
return;
if (path_resource) {
path_source = [path_resource stringByAppendingPathComponent:@"ggml-metal.metal"];
} else {
path_source = [bundle pathForResource:@"ggml-metal" ofType:@"metal"];
}
id<MTLLibrary> lib = nil;
if (path_source == nil) {
GGML_LOG_WARN("%s: error: could not use bundle path to find ggml-metal.metal, falling back to trying cwd\n", __func__);
path_source = @"ggml-metal.metal";
}
GGML_LOG_INFO("%s: loading '%s'\n", __func__, [path_source UTF8String]);
src = [NSString stringWithContentsOfFile:path_source encoding:NSUTF8StringEncoding error:&error];
if (error) {
GGML_LOG_ERROR("%s: error: %s\n", __func__, [[error description] UTF8String]);
return nil;
}
}
#endif
if (!library) {
@autoreleasepool {
// dictionary of preprocessor macros
NSMutableDictionary * prep = [NSMutableDictionary dictionary];
if (ggml_metal_device_get_props(dev)->has_bfloat) {
[prep setObject:@"1" forKey:@"GGML_METAL_HAS_BF16"];
}
if (ggml_metal_device_get_props(dev)->has_tensor) {
[prep setObject:@"1" forKey:@"GGML_METAL_HAS_TENSOR"];
}
#if GGML_METAL_EMBED_LIBRARY
[prep setObject:@"1" forKey:@"GGML_METAL_EMBED_LIBRARY"];
#endif
MTLCompileOptions * options = [MTLCompileOptions new];
options.preprocessorMacros = prep;
lib = [device newLibraryWithSource:src options:options error:&error];
//[options setFastMathEnabled:false];
[options release];
// retain the error before the autorelease pool drains it
if (!lib) {
err_per_lib[kind] = [error retain];
library = [device newLibraryWithSource:src options:options error:&error];
if (error) {
GGML_LOG_ERROR("%s: error: %s\n", __func__, [[error description] UTF8String]);
return nil;
}
}
[src release];
t_per_lib[kind] = ggml_time_us() - t0;
if (!lib) {
atomic_store(&any_failure, true);
return;
}
res->objs[kind] = lib;
});
}
dispatch_group_wait(group, DISPATCH_TIME_FOREVER);
dispatch_release(group);
const bool ok = !atomic_load(&any_failure);
if (ok) {
const int64_t t_total = ggml_time_us() - t_start;
int64_t t_max = 0;
for (int kind = 0; kind < GGML_METAL_LIB_COUNT; ++kind) {
GGML_LOG_DEBUG("%s: compiled '%s' library in %.3f sec\n",
__func__, k_lib_names[kind], t_per_lib[kind] / 1e6);
if (t_per_lib[kind] > t_max) t_max = t_per_lib[kind];
}
GGML_LOG_INFO("%s: loaded %d libraries from %s in %.3f sec (max single = %.3f sec)\n",
__func__, GGML_METAL_LIB_COUNT, origin, t_total / 1e6, t_max / 1e6);
ggml_metal_library_build_index(res);
} else {
for (int kind = 0; kind < GGML_METAL_LIB_COUNT; ++kind) {
if (err_per_lib[kind]) {
GGML_LOG_ERROR("%s: failed to build '%s' library: %s\n", __func__,
k_lib_names[kind], [[err_per_lib[kind] description] UTF8String]);
[err_per_lib[kind] release];
#if !__has_feature(objc_arc)
[options release];
#endif
}
}
#if GGML_METAL_EMBED_LIBRARY
[src release];
#endif // GGML_METAL_EMBED_LIBRARY
GGML_LOG_INFO("%s: loaded in %.3f sec\n", __func__, (ggml_time_us() - t_start) / 1e6);
}
free(err_per_lib);
free(t_per_lib);
return ok;
}
ggml_metal_library_t ggml_metal_library_init(ggml_metal_device_t dev) {
id<MTLDevice> device = ggml_metal_device_get_obj(dev);
ggml_metal_library_t res = calloc(1, sizeof(struct ggml_metal_library));
res->obj = library;
res->dev = dev;
res->pipelines = ggml_metal_pipelines_init();
res->lock = [NSLock new];
// shared MTLCompileOptions preprocessor macros (matches the build-time defines)
NSMutableDictionary * prep = [NSMutableDictionary dictionary];
if (ggml_metal_device_get_props(dev)->has_bfloat) {
[prep setObject:@"1" forKey:@"GGML_METAL_HAS_BF16"];
}
if (ggml_metal_device_get_props(dev)->has_tensor) {
[prep setObject:@"1" forKey:@"GGML_METAL_HAS_TENSOR"];
}
#if GGML_METAL_EMBED_LIBRARY
[prep setObject:@"1" forKey:@"GGML_METAL_EMBED_LIBRARY"];
#endif
#if GGML_METAL_EMBED_LIBRARY
GGML_LOG_INFO("%s: using embedded metal library\n", __func__);
// start/end symbols emitted by CMake (see CMakeLists.txt), one pair per kind
#define X(e, s) extern const char ggml_metallib_##s##_start[]; extern const char ggml_metallib_##s##_end[];
GGML_METAL_LIBS
#undef X
static const char * const lib_start[GGML_METAL_LIB_COUNT] = {
#define X(e, s) [GGML_METAL_LIB_##e] = ggml_metallib_##s##_start,
GGML_METAL_LIBS
#undef X
};
static const char * const lib_end[GGML_METAL_LIB_COUNT] = {
#define X(e, s) [GGML_METAL_LIB_##e] = ggml_metallib_##s##_end,
GGML_METAL_LIBS
#undef X
};
const bool ok = ggml_metal_library_compile_all(res, device, prep,
^NSString * (int kind, NSError ** err) {
(void) err;
return [[NSString alloc] initWithBytes:lib_start[kind]
length:(lib_end[kind] - lib_start[kind])
encoding:NSUTF8StringEncoding];
}, "embedded data");
if (!ok) {
ggml_metal_library_free(res);
return NULL;
}
return res;
#else
#ifdef SWIFT_PACKAGE
NSBundle * bundle = SWIFTPM_MODULE_BUNDLE;
#else
NSBundle * bundle = [NSBundle bundleForClass:[GGMLMetalClass class]];
#endif
const int64_t t_start = ggml_time_us();
NSError * error = nil;
NSString * path_lib = [bundle pathForResource:@"default" ofType:@"metallib"];
if (path_lib == nil) {
// Try to find the resource in the directory where the current binary located.
NSString * bin_cur = [[NSProcessInfo processInfo] arguments][0];
NSString * bin_dir = [bin_cur stringByDeletingLastPathComponent];
NSString * path_lib_default = [NSString pathWithComponents:@[bin_dir, @"default.metallib"]];
if ([[NSFileManager defaultManager] isReadableFileAtPath:path_lib_default]) {
GGML_LOG_INFO("%s: found '%s'\n", __func__, [path_lib_default UTF8String]);
NSDictionary * atts = [[NSFileManager defaultManager] attributesOfItemAtPath:path_lib_default error:&error];
if (atts && atts[NSFileType] == NSFileTypeSymbolicLink) {
// Optionally, if this is a symlink, try to resolve it.
path_lib_default = [[NSFileManager defaultManager] destinationOfSymbolicLinkAtPath:path_lib_default error:&error];
if (path_lib_default && [path_lib_default length] > 0 && ![[path_lib_default substringToIndex:1] isEqualToString:@"/"]) {
// It is a relative path, adding the binary directory as directory prefix.
path_lib_default = [NSString pathWithComponents:@[bin_dir, path_lib_default]];
}
if (!path_lib_default || ![[NSFileManager defaultManager] isReadableFileAtPath:path_lib_default]) {
// Link to the resource could not be resolved.
path_lib_default = nil;
} else {
GGML_LOG_INFO("%s: symlink resolved '%s'\n", __func__, [path_lib_default UTF8String]);
}
}
} else {
// The resource couldn't be found in the binary's directory.
path_lib_default = nil;
}
path_lib = path_lib_default;
}
if (path_lib != nil) {
// pre-compiled library found: a single combined default.metallib
NSURL * libURL = [NSURL fileURLWithPath:path_lib];
GGML_LOG_INFO("%s: loading '%s'\n", __func__, [path_lib UTF8String]);
res->objs[0] = [device newLibraryWithURL:libURL error:&error];
res->single_library = true;
if (!res->objs[0]) {
GGML_LOG_ERROR("%s: error: %s\n", __func__, [[error description] UTF8String]);
ggml_metal_library_free(res);
return NULL;
}
GGML_LOG_INFO("%s: loaded in %.3f sec\n", __func__, (ggml_time_us() - t_start) / 1e6);
return res;
}
// no pre-compiled metallib: fall back to compiling each kernel source separately
GGML_LOG_INFO("%s: default.metallib not found, loading kernel sources\n", __func__);
NSString * path_resource = [[NSProcessInfo processInfo].environment objectForKey:@"GGML_METAL_PATH_RESOURCES"];
if (path_resource) {
GGML_LOG_INFO("%s: GGML_METAL_PATH_RESOURCES = %s\n", __func__, [path_resource UTF8String]);
}
// resolve each kind's source path up front (file lookup/logging stays on the calling thread)
NSString ** path_per_kind = calloc(GGML_METAL_LIB_COUNT, sizeof(NSString *));
for (int kind = 0; kind < GGML_METAL_LIB_COUNT; ++kind) {
NSString * rel = [NSString stringWithFormat:@"kernels/%s.metal", k_lib_names[kind]];
NSString * path_source = nil;
if (path_resource) {
path_source = [path_resource stringByAppendingPathComponent:rel];
} else {
NSString * stem = [NSString stringWithFormat:@"kernels/%s", k_lib_names[kind]];
path_source = [bundle pathForResource:stem ofType:@"metal"];
}
if (path_source == nil || ![[NSFileManager defaultManager] isReadableFileAtPath:path_source]) {
GGML_LOG_WARN("%s: could not locate %s in bundle, falling back to cwd\n", __func__, [rel UTF8String]);
path_source = rel;
}
GGML_LOG_DEBUG("%s: loading '%s'\n", __func__, [path_source UTF8String]);
path_per_kind[kind] = [path_source retain];
}
const bool ok = ggml_metal_library_compile_all(res, device, prep,
^NSString * (int kind, NSError ** err) {
return ggml_metal_library_flatten_source(path_per_kind[kind], err);
}, "source");
for (int kind = 0; kind < GGML_METAL_LIB_COUNT; ++kind) {
[path_per_kind[kind] release];
}
free(path_per_kind);
if (!ok) {
ggml_metal_library_free(res);
return NULL;
}
return res;
#endif
}
ggml_metal_library_t ggml_metal_library_init_from_source(ggml_metal_device_t dev, const char * source, bool verbose) {
@@ -589,11 +318,10 @@ ggml_metal_library_t ggml_metal_library_init_from_source(ggml_metal_device_t dev
return NULL;
}
res->objs[0] = library;
res->single_library = true;
res->dev = dev;
res->pipelines = ggml_metal_pipelines_init();
res->lock = [NSLock new];
res->obj = library;
res->dev = dev;
res->pipelines = ggml_metal_pipelines_init();
res->lock = [NSLock new];
return res;
}
@@ -603,14 +331,8 @@ void ggml_metal_library_free(ggml_metal_library_t lib) {
return;
}
for (int kind = 0; kind < GGML_METAL_LIB_COUNT; ++kind) {
if (lib->objs[kind]) {
[lib->objs[kind] release];
}
}
if (lib->fn_to_lib) {
[lib->fn_to_lib release];
if (lib->obj) {
[lib->obj release];
}
ggml_metal_pipelines_free(lib->pipelines);
@@ -671,28 +393,11 @@ struct ggml_metal_pipeline_with_params ggml_metal_library_compile_pipeline(ggml_
GGML_LOG_DEBUG("%s: compiling pipeline: base = '%s', name = '%s'\n", __func__, base, name);
// route to the library that actually defines this kernel; fn_to_lib is
// built from -[MTLLibrary functionNames] so it's always in sync
int lib_idx = 0;
if (!lib->single_library) {
NSNumber * idx = lib->fn_to_lib[base_func];
if (!idx) {
[lib->lock unlock];
GGML_LOG_ERROR("%s: kernel not found in any metal library: base = '%s', name = '%s'\n", __func__, base, name);
return res;
}
lib_idx = [idx intValue];
}
id<MTLLibrary> mtl_lib = lib->objs[lib_idx];
id<MTLFunction> mtl_function;
if (!cv) {
mtl_function = [mtl_lib newFunctionWithName:base_func];
mtl_function = [lib->obj newFunctionWithName:base_func];
} else {
mtl_function = [mtl_lib newFunctionWithName:base_func constantValues:cv->obj error:&error];
mtl_function = [lib->obj newFunctionWithName:base_func constantValues:cv->obj error:&error];
}
if (!mtl_function) {
[lib->lock unlock];
File diff suppressed because it is too large Load Diff
-232
View File
@@ -1,232 +0,0 @@
#include "common.h"
// bitonic sort implementation following the CUDA kernels as reference
//
// kernel_argsort_f32_i32<order>
//   Each threadgroup sorts one block of ntg.x consecutive column indices of one
//   f32 row of src0 and writes the first args.top_k sorted indices of that block
//   to dst. tgpig[0] packs (block index, row): block = tgpig[0] / ne01 and
//   row = tgpig[0] % ne01; tgpig[1] and tgpig[2] select the two outer dims.
//   shmem_i32 must provide one int32 slot per thread of the threadgroup.
//   Indices >= args.ne00 are padding and are always ordered after valid ones.
//   NOTE(review): the k/j loops implement a bitonic network, which presumably
//   requires ntg.x to be a power of two - confirm against the host dispatch.
typedef void (argsort_t)(
constant ggml_metal_kargs_argsort & args,
device const char * src0,
device int32_t * dst,
threadgroup int32_t * shmem_i32 [[threadgroup(0)]],
uint3 tgpig[[threadgroup_position_in_grid]],
ushort3 tpitg[[thread_position_in_threadgroup]],
ushort3 ntg[[threads_per_threadgroup]]);
template<ggml_sort_order order>
kernel void kernel_argsort_f32_i32(
constant ggml_metal_kargs_argsort & args,
device const char * src0,
device int32_t * dst,
threadgroup int32_t * shmem_i32 [[threadgroup(0)]],
uint3 tgpig[[threadgroup_position_in_grid]],
ushort3 tpitg[[thread_position_in_threadgroup]],
ushort3 ntg[[threads_per_threadgroup]]) {
// bitonic sort
// col: this thread's slot inside the block
const int col = tpitg[0];
// ib: block index along the row, i00: first column index covered by this block
const int ib = tgpig[0] / args.ne01;
const int i00 = ib*ntg.x;
const int i01 = tgpig[0] % args.ne01;
const int i02 = tgpig[1];
const int i03 = tgpig[2];
// the row is addressed with the byte strides nb01/nb02/nb03
device const float * src0_row = (device const float *) (src0 + args.nb01*i01 + args.nb02*i02 + args.nb03*i03);
// initialize indices
// (slot col starts out holding the global column index i00 + col)
shmem_i32[col] = i00 + col;
threadgroup_barrier(mem_flags::mem_threadgroup);
// k: size of the bitonic sequences being merged, j: compare-exchange distance
for (int k = 2; k <= ntg.x; k *= 2) {
for (int j = k / 2; j > 0; j /= 2) {
// partner slot; only the lower slot of each pair performs the exchange
int ixj = col ^ j;
if (ixj > col) {
if ((col & k) == 0) {
// this half is ordered in the requested direction:
// swap when col holds padding, or when both are valid and out of order
if (shmem_i32[col] >= args.ne00 ||
(shmem_i32[ixj] < args.ne00 && (order == GGML_SORT_ORDER_ASC ?
src0_row[shmem_i32[col]] > src0_row[shmem_i32[ixj]] :
src0_row[shmem_i32[col]] < src0_row[shmem_i32[ixj]]))
) {
SWAP(shmem_i32[col], shmem_i32[ixj]);
}
} else {
// this half is ordered in the opposite direction (padding goes to col)
if (shmem_i32[ixj] >= args.ne00 ||
(shmem_i32[col] < args.ne00 && (order == GGML_SORT_ORDER_ASC ?
src0_row[shmem_i32[col]] < src0_row[shmem_i32[ixj]] :
src0_row[shmem_i32[col]] > src0_row[shmem_i32[ixj]]))
) {
SWAP(shmem_i32[col], shmem_i32[ixj]);
}
}
}
// every thread must reach this barrier once per (k, j) step, including
// the threads that did not take part in an exchange
threadgroup_barrier(mem_flags::mem_threadgroup);
}
}
// each block contributes top_k output slots, so block ib starts at ib*top_k
const int64_t i0 = ib*args.top_k;
// copy the result to dst without the padding
if (i0 + col < args.ne0 && col < args.top_k) {
dst += i0 + args.ne0*i01 + args.ne0*args.ne1*i02 + args.ne0*args.ne1*args.ne2*i03;
dst[col] = shmem_i32[col];
}
}
// host-visible instantiations, one per sort direction
template [[host_name("kernel_argsort_f32_i32_asc")]] kernel argsort_t kernel_argsort_f32_i32<GGML_SORT_ORDER_ASC>;
template [[host_name("kernel_argsort_f32_i32_desc")]] kernel argsort_t kernel_argsort_f32_i32<GGML_SORT_ORDER_DESC>;
// kernel_argsort_merge_f32_i32<order>
//   Merges two adjacent sorted runs of indices (each up to args.len long) read
//   from tmp into one sorted run written to dst, keeping at most args.top_k
//   outputs per merged run. Indices are compared through the f32 values they
//   reference in the src0 row. tgpig[0] packs (pair index, row): pair =
//   tgpig[0] / ne01, row = tgpig[0] % ne01. The threads of a threadgroup split
//   the output range into equal chunks and each merges its own chunk.
typedef void (argsort_merge_t)(
constant ggml_metal_kargs_argsort_merge & args,
device const char * src0,
device const int32_t * tmp,
device int32_t * dst,
uint3 tgpig[[threadgroup_position_in_grid]],
ushort3 tpitg[[thread_position_in_threadgroup]],
ushort3 ntg[[threads_per_threadgroup]]);
template<ggml_sort_order order>
kernel void kernel_argsort_merge_f32_i32(
constant ggml_metal_kargs_argsort_merge & args,
device const char * src0,
device const int32_t * tmp,
device int32_t * dst,
uint3 tgpig[[threadgroup_position_in_grid]],
ushort3 tpitg[[thread_position_in_threadgroup]],
ushort3 ntg[[threads_per_threadgroup]]) {
// im: index of the pair of runs handled by this threadgroup
const int im = tgpig[0] / args.ne01;
const int i01 = tgpig[0] % args.ne01;
const int i02 = tgpig[1];
const int i03 = tgpig[2];
// the pair starts at element `start`; the runs are clamped to the row length
// ne0, so the second run (or both) may be shorter than len or empty
const int start = im * (2 * args.len);
const int len0 = MIN(args.len, MAX(0, args.ne0 - (int)(start)));
const int len1 = MIN(args.len, MAX(0, args.ne0 - (int)(start + args.len)));
const int total = len0 + len1;
// input runs: tmp rows are strided by ne0 elements
device const int32_t * tmp0 = tmp + start
+ i01*args.ne0
+ i02*args.ne0*args.ne01
+ i03*args.ne0*args.ne01*args.ne02;
device const int32_t * tmp1 = tmp0 + args.len;
// output run: dst rows are strided by top_k elements
dst += start
+ i01*args.top_k
+ i02*args.top_k*args.ne01
+ i03*args.top_k*args.ne01*args.ne02;
device const float * src0_row = (device const float *)(src0
+ args.nb01*i01
+ args.nb02*i02
+ args.nb03*i03);
if (total == 0) {
return;
}
// this thread produces outputs [k0, k1), capped by both total and top_k
const int chunk = (total + ntg.x - 1) / ntg.x;
const int k0 = tpitg.x * chunk;
const int k1 = MIN(MIN(k0 + chunk, total), args.top_k);
if (k0 >= args.top_k) {
return;
}
if (k0 >= total) {
return;
}
// search range for i, the number of elements taken from run 0 before output k0
int low = k0 > len1 ? k0 - len1 : 0;
int high = MIN(k0, len0);
// binary-search partition (i, j) such that i + j = k
while (low < high) {
const int mid = (low + high) >> 1;
const int32_t idx0 = tmp0[mid];
const int32_t idx1 = tmp1[k0 - mid - 1];
const float val0 = src0_row[idx0];
const float val1 = src0_row[idx1];
// ties take the element from run 0 (<= / >=)
bool take_left;
if (order == GGML_SORT_ORDER_ASC) {
take_left = (val0 <= val1);
} else {
take_left = (val0 >= val1);
}
if (take_left) {
low = mid + 1;
} else {
high = mid;
}
}
// i / j: read positions in run 0 / run 1 at which this thread starts merging
int i = low;
int j = k0 - i;
// keep the merge fronts into registers
int32_t idx0 = 0;
float val0 = 0.0f;
if (i < len0) {
idx0 = tmp0[i];
val0 = src0_row[idx0];
}
int32_t idx1 = 0;
float val1 = 0.0f;
if (j < len1) {
idx1 = tmp1[j];
val1 = src0_row[idx1];
}
for (int k = k0; k < k1; ++k) {
int32_t out_idx;
if (i >= len0) {
// run 0 exhausted: copy the rest of this chunk from run 1
while (k < k1) {
dst[k++] = tmp1[j++];
}
break;
} else if (j >= len1) {
// run 1 exhausted: copy the rest of this chunk from run 0
while (k < k1) {
dst[k++] = tmp0[i++];
}
break;
} else {
// same comparison (and tie rule) as in the partition search above
bool take_left;
if (order == GGML_SORT_ORDER_ASC) {
take_left = (val0 <= val1);
} else {
take_left = (val0 >= val1);
}
if (take_left) {
out_idx = idx0;
++i;
// refill the run-0 front only while it is still in range
if (i < len0) {
idx0 = tmp0[i];
val0 = src0_row[idx0];
}
} else {
out_idx = idx1;
++j;
// refill the run-1 front only while it is still in range
if (j < len1) {
idx1 = tmp1[j];
val1 = src0_row[idx1];
}
}
}
dst[k] = out_idx;
}
}
// host-visible instantiations, one per sort direction
template [[host_name("kernel_argsort_merge_f32_i32_asc")]] kernel argsort_merge_t kernel_argsort_merge_f32_i32<GGML_SORT_ORDER_ASC>;
template [[host_name("kernel_argsort_merge_f32_i32_desc")]] kernel argsort_merge_t kernel_argsort_merge_f32_i32<GGML_SORT_ORDER_DESC>;
-226
View File
@@ -1,226 +0,0 @@
#include "common.h"
// Function constants that specialize kernel_bin_fuse_impl at pipeline creation:
//   FC_bin_op - OP: 0 - add, 1 - sub, 2 - mul, 3 - div
//   FC_bin_f  - number of src1 operands applied in sequence; operand j is read
//               at byte offset args.o1[j] from src1
//   FC_bin_rb - take the "row broadcast" path (flat indexing, one element per
//               grid position)
//   FC_bin_cb - wrap the src1 element index modulo args.ne10
// OP: 0 - add, 1 - sub, 2 - mul, 3 - div
constant short FC_bin_op [[function_constant(FC_BIN + 0)]];
constant short FC_bin_f [[function_constant(FC_BIN + 1)]];
constant bool FC_bin_rb [[function_constant(FC_BIN + 2)]];
constant bool FC_bin_cb [[function_constant(FC_BIN + 3)]];
// Element-wise binary op dst = src0 OP src1[0] OP src1[1] ... (FC_F operands).
//   T0 / T1 / T: element types of src0 / src1 / dst.
template <typename T0, typename T1, typename T>
kernel void kernel_bin_fuse_impl(
constant ggml_metal_kargs_bin & args,
device const char * src0,
device const char * src1,
device char * dst,
uint3 tgpig[[threadgroup_position_in_grid]],
ushort3 tpitg[[thread_position_in_threadgroup]],
ushort3 ntg[[threads_per_threadgroup]]) {
// short local aliases, removed again at the end of the kernel body
#define FC_OP FC_bin_op
#define FC_F FC_bin_f
#define FC_RB FC_bin_rb
#define FC_CB FC_bin_cb
if (FC_RB) {
// row broadcast
// i0: flat element index into src0/dst, i1: element index into the src1 row
const uint i0 = tgpig.y*args.ne00 + tgpig.x;
const uint i1 = FC_CB ? tgpig.x%args.ne10 : tgpig.x;
device const T0 * src0_row = (device const T0 *) (src0);
device T * dst_row = (device T *) (dst);
if (FC_F == 1) {
device const T1 * src1_row = (device const T1 *) (src1 + args.o1[0]);
// exactly one of the following matches the (constant) FC_OP
if (FC_OP == 0) {
dst_row[i0] = src0_row[i0] + src1_row[i1];
}
if (FC_OP == 1) {
dst_row[i0] = src0_row[i0] - src1_row[i1];
}
if (FC_OP == 2) {
dst_row[i0] = src0_row[i0] * src1_row[i1];
}
if (FC_OP == 3) {
dst_row[i0] = src0_row[i0] / src1_row[i1];
}
} else {
// NOTE(review): the accumulator is T0 here but T in the non-broadcast
// path below; identical for the instantiations in this file
T0 res = src0_row[i0];
if (FC_OP == 0) {
FOR_UNROLL (short j = 0; j < FC_F; ++j) {
res += ((device const T1 *) (src1 + args.o1[j]))[i1];
}
}
if (FC_OP == 1) {
FOR_UNROLL (short j = 0; j < FC_F; ++j) {
res -= ((device const T1 *) (src1 + args.o1[j]))[i1];
}
}
if (FC_OP == 2) {
FOR_UNROLL (short j = 0; j < FC_F; ++j) {
res *= ((device const T1 *) (src1 + args.o1[j]))[i1];
}
}
if (FC_OP == 3) {
FOR_UNROLL (short j = 0; j < FC_F; ++j) {
res /= ((device const T1 *) (src1 + args.o1[j]))[i1];
}
}
dst_row[i0] = res;
}
} else {
// general path: one threadgroup per (i01, i02, i03) row, threads stride over i0
const int i03 = tgpig.z;
const int i02 = tgpig.y;
const int i01 = tgpig.x;
if (i01 >= args.ne01) {
return;
}
// src1 is broadcast along dims 1..3 by wrapping the indices
const int i13 = i03%args.ne13;
const int i12 = i02%args.ne12;
const int i11 = i01%args.ne11;
// args.offs is added to both the src0 and the dst row address
device const T0 * src0_ptr = (device const T0 *) (src0 + i03*args.nb03 + i02*args.nb02 + i01*args.nb01 + args.offs);
device T * dst_ptr = (device T *) (dst + i03*args.nb3 + i02*args.nb2 + i01*args.nb1 + args.offs);
if (FC_F == 1) {
device const T1 * src1_ptr = (device const T1 *) (src1 + args.o1[0] + i13*args.nb13 + i12*args.nb12 + i11*args.nb11);
for (int i0 = tpitg.x; i0 < args.ne0; i0 += ntg.x) {
const int i10 = FC_CB ? i0%args.ne10 : i0;
if (FC_OP == 0) {
dst_ptr[i0] = src0_ptr[i0] + src1_ptr[i10];
}
if (FC_OP == 1) {
dst_ptr[i0] = src0_ptr[i0] - src1_ptr[i10];
}
if (FC_OP == 2) {
dst_ptr[i0] = src0_ptr[i0] * src1_ptr[i10];
}
if (FC_OP == 3) {
dst_ptr[i0] = src0_ptr[i0] / src1_ptr[i10];
}
}
} else {
// row pointers of all fused operands
// NOTE(review): fixed capacity of 8 - assumes FC_F <= 8, confirm on the host side
device const T1 * src1_ptr[8];
FOR_UNROLL (short j = 0; j < FC_F; ++j) {
src1_ptr[j] = (device const T1 *) (src1 + args.o1[j] + i13*args.nb13 + i12*args.nb12 + i11*args.nb11);
}
for (int i0 = tpitg.x; i0 < args.ne0; i0 += ntg.x) {
const int i10 = FC_CB ? i0%args.ne10 : i0;
T res = src0_ptr[i0];
if (FC_OP == 0) {
FOR_UNROLL (short j = 0; j < FC_F; ++j) {
res += src1_ptr[j][i10];
}
}
if (FC_OP == 1) {
FOR_UNROLL (short j = 0; j < FC_F; ++j) {
res -= src1_ptr[j][i10];
}
}
if (FC_OP == 2) {
FOR_UNROLL (short j = 0; j < FC_F; ++j) {
res *= src1_ptr[j][i10];
}
}
if (FC_OP == 3) {
FOR_UNROLL (short j = 0; j < FC_F; ++j) {
res /= src1_ptr[j][i10];
}
}
dst_ptr[i0] = res;
}
}
}
#undef FC_OP
#undef FC_F
#undef FC_RB
#undef FC_CB
}
// host-visible instantiations: scalar f32 and 4-wide f32
typedef decltype(kernel_bin_fuse_impl<float, float, float>) kernel_bin_fuse_t;
template [[host_name("kernel_bin_fuse_f32_f32_f32")]] kernel kernel_bin_fuse_t kernel_bin_fuse_impl<float, float, float>;
template [[host_name("kernel_bin_fuse_f32_f32_f32_4")]] kernel kernel_bin_fuse_t kernel_bin_fuse_impl<float4, float4, float4>;
// kernel_add_id: for the (row, plane) selected by tgpig.x / tgpig.y, reads an
// int32 id from src2 and writes dst_row = src0_row + src1_row[id] as f32.
// dst is addressed as a densely packed f32 tensor (row stride ne0 floats,
// plane stride ne1 rows); src0/src1/src2 use the byte strides from args.
kernel void kernel_add_id(
        constant ggml_metal_kargs_add_id & args,
        device const char * src0,
        device const char * src1,
        device const char * src2,
        device       char * dst,
        uint3   tgpig[[threadgroup_position_in_grid]],
        ushort3 tpitg[[thread_position_in_threadgroup]],
        ushort3   ntg[[threads_per_threadgroup]]) {
    const int row   = tgpig.x;
    const int plane = tgpig.y;

    // one int32 id per row; planes of ids are nb21 bytes apart
    device const int32_t * ids = (device const int32_t *) (src2 + plane*args.nb21);
    const int sel = ids[row];

    // dst byte strides are derived from its extents
    const size_t dst_nb1 = args.ne0 * sizeof(float);
    const size_t dst_nb2 = args.ne1 * dst_nb1;

    device       float * out = (device float *)((device char *)dst + row*dst_nb1 + plane*dst_nb2);
    device const float * lhs = (device const float *)((device char *)src0 + row*args.nb01 + plane*args.nb02);
    device const float * rhs = (device const float *)((device char *)src1 + sel*args.nb11);

    // the threads of the threadgroup stride over the row
    for (int col = tpitg.x; col < args.ne0; col += ntg.x) {
        out[col] = lhs[col] + rhs[col];
    }
}
// kernel_repeat<T>: tiles src0 into dst. Every dst coordinate is wrapped modulo
// the matching src0 extent (ne00..ne03) to find the source element. One
// threadgroup handles one dst row (tgpig.x/y/z = dims 1/2/3) and its threads
// stride over dim 0. All addressing uses the byte strides from args.
template<typename T>
kernel void kernel_repeat(
        constant ggml_metal_kargs_repeat & args,
        device const char * src0,
        device       char * dst,
        uint3   tgpig[[threadgroup_position_in_grid]],
        ushort3 tpitg[[thread_position_in_threadgroup]],
        ushort3   ntg[[threads_per_threadgroup]]) {
    // destination row coordinates
    const int d3 = tgpig.z;
    const int d2 = tgpig.y;
    const int d1 = tgpig.x;

    // wrapped source row coordinates
    const int s3 = d3%args.ne03;
    const int s2 = d2%args.ne02;
    const int s1 = d1%args.ne01;

    device const char * src_row = src0 + s3*args.nb03 + s2*args.nb02 + s1*args.nb01;
    device       char * dst_row = dst  + d3*args.nb3  + d2*args.nb2  + d1*args.nb1;

    for (int d0 = tpitg.x; d0 < args.ne0; d0 += ntg.x) {
        const int s0 = d0%args.ne00;

        device T * out = (device T *)(dst_row + d0*args.nb0);
        device T * in  = (device T *)(src_row + s0*args.nb00);

        *out = *in;
    }
}

// host-visible instantiations, one per supported element type
typedef decltype(kernel_repeat<float>) kernel_repeat_t;

template [[host_name("kernel_repeat_f32")]] kernel kernel_repeat_t kernel_repeat<float>;
template [[host_name("kernel_repeat_f16")]] kernel kernel_repeat_t kernel_repeat<half>;
#if defined(GGML_METAL_HAS_BF16)
template [[host_name("kernel_repeat_bf16")]] kernel kernel_repeat_t kernel_repeat<bfloat>;
#endif
template [[host_name("kernel_repeat_i32")]] kernel kernel_repeat_t kernel_repeat<int>;
template [[host_name("kernel_repeat_i16")]] kernel kernel_repeat_t kernel_repeat<short>;
-126
View File
@@ -1,126 +0,0 @@
#pragma once
#include "ggml-metal-impl.h"
#include <metal_stdlib>
#ifdef GGML_METAL_HAS_TENSOR
#include <metal_tensor>
#include <MetalPerformancePrimitives/MetalPerformancePrimitives.h>
#endif
using namespace metal;
// Helper macros shared by the kernels.
// MAX / MIN: function-like macros - each argument may be evaluated twice, so
// avoid arguments with side effects.
#define MAX(x, y) ((x) > (y) ? (x) : (y))
#define MIN(x, y) ((x) < (y) ? (x) : (y))
// SWAP: exchanges two lvalues through a temporary of deduced type
#define SWAP(x, y) { auto tmp = (x); (x) = (y); (y) = tmp; }
// PAD2: rounds x up to a multiple of n; the mask form is only correct when n is
// a power of two
#define PAD2(x, n) (((x) + (n) - 1) & ~((n) - 1))
// FOR_UNROLL: a `for` loop preceded by a full-unroll pragma
#define FOR_UNROLL(x) _Pragma("clang loop unroll(full)") for (x)
#define N_SIMDWIDTH 32 // assuming SIMD group size is 32
// ref: https://developer.apple.com/metal/Metal-Shading-Language-Specification.pdf
//
// cmd:
// .../usr/bin/metal -dM -E -c ggml/src/ggml-metal/kernels/<src>.metal
// .../usr/bin/metal -dM -E -c -target air64-apple-ios14.0 ggml/src/ggml-metal/kernels/<src>.metal
//
#if __METAL_VERSION__ < 310 && defined(GGML_METAL_HAS_BF16)
#undef GGML_METAL_HAS_BF16
#endif
#if defined(GGML_METAL_HAS_BF16)
typedef matrix<bfloat, 4, 4> bfloat4x4;
typedef matrix<bfloat, 2, 4> bfloat2x4;
#endif
// 16-entry f32 lookup tables indexed by a 4-bit code.
// presumably the dequantization values of the IQ4_NL format - confirm against
// the quantization reference
constexpr constant static float kvalues_iq4nl_f[16] = {
-127.f, -104.f, -83.f, -65.f, -49.f, -35.f, -22.f, -10.f, 1.f, 13.f, 25.f, 38.f, 53.f, 69.f, 89.f, 113.f
};
// presumably the MXFP4 element values: entries 8..15 mirror entries 0..7 with
// the sign flipped.
// NOTE(review): `-0` is an integer expression, so entry 8 is stored as +0.0f,
// not -0.0f - confirm that the sign of zero is irrelevant to the users.
constexpr constant static float kvalues_mxfp4_f[16] = {
0, .5f, 1.f, 1.5f, 2.f, 3.f, 4.f, 6.f, -0, -.5f, -1.f, -1.5f, -2.f, -3.f, -4.f, -6.f
};
// Returns the index of the entry of val[0..n-1] that is closest to x.
// The bisection assumes val is sorted in ascending order. Values at or beyond
// either end clamp to the first / last index; when x is exactly halfway
// between two neighbours the upper index is returned.
static inline int best_index_int8(int n, constant float * val, float x) {
    const int last = n - 1;

    if (x <= val[0]) {
        return 0;
    }
    if (x >= val[last]) {
        return last;
    }

    // invariant: val[lo] <= x < val[hi]
    int lo = 0;
    int hi = last;
    for (; hi - lo > 1; ) {
        const int mid = (lo + hi)/2;
        if (x < val[mid]) {
            hi = mid;
        } else {
            lo = mid;
        }
    }

    // pick whichever neighbour is nearer
    if (x - val[hi - 1] < val[hi] - x) {
        return hi - 1;
    }
    return hi;
}
// Converts an 8-bit exponent-only value (E8M0) to f32 by placing x in the
// exponent field of an IEEE-754 single. x == 0 cannot be expressed that way
// (it would yield 0.0f), so it maps to the bit pattern 0x00400000, a subnormal.
static inline float e8m0_to_fp32(uint8_t x) {
    const uint32_t bits = (x == 0) ? 0x00400000 : ((uint32_t) x << 23);
    return as_type<float>(bits);
}
// Scalar overload of dot(): the dot product of two scalars is their product.
static inline float dot(float x, float y) {
    const float prod = x*y;
    return prod;
}

// sum() of a scalar is the scalar itself.
static inline float sum(float x) {
    return x;
}

// Horizontal sum of the four components, accumulated left to right.
static inline float sum(float4 x) {
    float acc = x[0] + x[1];
    acc = acc + x[2];
    acc = acc + x[3];
    return acc;
}
// Sort direction selector. The enumerator values are 0 (ASC) and 1 (DESC).
// NOTE(review): presumably has to stay in sync with the host-side enum of the
// same name - confirm.
enum ggml_sort_order {
GGML_SORT_ORDER_ASC,
GGML_SORT_ORDER_DESC,
};
// Constants for the GELU family of activations:
//   GELU_COEF_A     - cubic coefficient 0.044715 of the tanh-based approximation
//   GELU_QUICK_COEF - scale -1.702 (presumably for the sigmoid-based "quick" variant)
//   SQRT_2_OVER_PI  - sqrt(2/pi)
//   SQRT_2_INV      - 1/sqrt(2)
constant float GELU_COEF_A = 0.044715f;
constant float GELU_QUICK_COEF = -1.702f;
constant float SQRT_2_OVER_PI = 0.79788456080286535587989211986876f;
constant float SQRT_2_INV = 0.70710678118654752440084436210484f;
// based on Abramowitz and Stegun formula 7.1.26 or similar Hastings' approximation
// ref: https://www.johndcook.com/blog/python_erf/
// p_erf scales |x| in t = 1/(1 + p*|x|); a1..a5 are the polynomial coefficients
// consumed by erf_approx()
constant float p_erf = 0.3275911f;
constant float a1_erf = 0.254829592f;
constant float a2_erf = -0.284496736f;
constant float a3_erf = 1.421413741f;
constant float a4_erf = -1.453152027f;
constant float a5_erf = 1.061405429f;
// Polynomial approximation of erf(x) built from the *_erf constants:
//   erf(|x|) ~= 1 - (a1*t + a2*t^2 + a3*t^3 + a4*t^4 + a5*t^5) * exp(-x^2),  t = 1/(1 + p*|x|)
// The sign of the input is re-applied at the end, since erf is odd.
template<typename T>
inline T erf_approx(T x) {
    const T sgn = sign(x);
    x = fabs(x);

    const T t = 1.0f / (1.0f + p_erf * x);

    // Horner evaluation; `auto` keeps the intermediate type the expression
    // naturally has, so no extra rounding to T is introduced.
    auto poly = a5_erf * t + a4_erf;
    poly = poly * t + a3_erf;
    poly = poly * t + a2_erf;
    poly = poly * t + a1_erf;

    const T y = 1.0f - poly * t * exp(-x * x);

    return sgn * y;
}
// ELU activation: identity for positive inputs, exp(x) - 1 otherwise.
// Only the float and float4 specializations are provided.
template<typename T> T elu_approx(T x);

template<> inline float elu_approx<float>(float x) {
    if (x > 0.f) {
        return x;
    }
    return exp(x) - 1;
}

template<> inline float4 elu_approx<float4>(float4 x) {
    float4 out;
    // applied per component
    for (short c = 0; c < 4; ++c) {
        out[c] = (x[c] > 0.0f) ? x[c] : (exp(x[c]) - 1.0f);
    }
    return out;
}
-485
View File
@@ -1,485 +0,0 @@
#include "common.h"
// Function type shared by the explicit kernel_im2col<T> instantiations.
typedef void (im2col_t)(
constant ggml_metal_kargs_im2col & args,
device const float * x,
device char * dst,
uint3 tgpig[[threadgroup_position_in_grid]],
uint3 tgpg[[threadgroups_per_grid]],
uint3 tpitg[[thread_position_in_threadgroup]],
uint3 ntg[[threads_per_threadgroup]]);
// im2col: copies input patches into a matrix whose rows have CHW elements.
//
// Launch geometry as consumed here:
//   threadgroup position = (input channel, output row, output column)
//   thread position      = (batch lane, kernel row, kernel column)
// Each thread handles one kernel tap of one output pixel and walks the batch
// dimension in steps of ntg[0]. Taps that fall outside the input write 0.
template <typename T>
kernel void kernel_im2col(
        constant ggml_metal_kargs_im2col & args,
        device const float * x,
        device        char * dst,
        uint3   tgpig[[threadgroup_position_in_grid]],
        uint3    tgpg[[threadgroups_per_grid]],
        uint3   tpitg[[thread_position_in_threadgroup]],
        uint3     ntg[[threads_per_threadgroup]]) {
    // output and kernel extents come from the launch geometry
    const int64_t OH = tgpg[1];
    const int64_t OW = tgpg[2];
    const int64_t KH = ntg[1];
    const int64_t KW = ntg[2];

    const int64_t iic = tgpig[0];
    const int64_t ioh = tgpig[1];
    const int64_t iow = tgpig[2];

    const int64_t ikh = tpitg[1];
    const int64_t ikw = tpitg[2];

    int64_t in = tpitg[0];

    // input coordinates sampled by this tap
    const int64_t iih = ioh*args.s1 + ikh*args.d1 - args.p1;
    const int64_t iiw = iow*args.s0 + ikw*args.d0 - args.p0;

    device T * pdst = (device T *) (dst);

    int64_t offset_dst = (in*OH*OW + ioh*OW + iow)*args.CHW + (iic*(KH*KW) + ikh*KW + ikw);

    // distance in dst between two consecutive batch entries handled by this thread
    const int64_t step_dst = ntg[0]*args.CHW*OH*OW;

    const bool inside = iih >= 0 && iih < args.IH && iiw >= 0 && iiw < args.IW;

    if (inside) {
        int64_t offset_src = in*args.ofs0 + iic*args.ofs1 + iih*args.IW + iiw;

        for (; in < args.N; in += ntg[0]) {
            pdst[offset_dst] = x[offset_src];

            offset_dst += step_dst;
            offset_src += ntg[0]*args.ofs0;
        }
    } else {
        // padding region
        for (; in < args.N; in += ntg[0]) {
            pdst[offset_dst] = 0.0f;

            offset_dst += step_dst;
        }
    }
}
// Host-visible entry points: T is the element type written to dst (float or half).
template [[host_name("kernel_im2col_f32")]] kernel im2col_t kernel_im2col<float>;
template [[host_name("kernel_im2col_f16")]] kernel im2col_t kernel_im2col<half>;
// TODO: optimize
// Function type shared by the explicit kernel_im2col_ext<T> instantiations.
typedef void (im2col_ext_t)(
constant ggml_metal_kargs_im2col & args,
device const float * x,
device char * dst,
uint3 tgpig[[threadgroup_position_in_grid]],
uint3 tgpg[[threadgroups_per_grid]],
uint3 tpitg[[thread_position_in_threadgroup]],
uint3 ntg[[threads_per_threadgroup]]);
// im2col variant in which the x dimension of the grid enumerates
// (batch chunk, input channel, kernel row, kernel column) and the threads of a
// threadgroup enumerate the batch entries inside that chunk. The y and z grid
// dimensions are the output row and column. Each thread writes exactly one
// destination element; taps outside the input write 0.
template <typename T>
kernel void kernel_im2col_ext(
        constant ggml_metal_kargs_im2col & args,
        device const float * x,
        device        char * dst,
        uint3   tgpig[[threadgroup_position_in_grid]],
        uint3    tgpg[[threadgroups_per_grid]],      // tgpg[0] = D x IC x KH x KW, CHW = IC x KH x KW
        uint3   tpitg[[thread_position_in_threadgroup]],
        uint3     ntg[[threads_per_threadgroup]]) {  // [M, 1, 1]
    const int64_t KHW = (int64_t)args.KHW;

    // split the x grid coordinate into batch chunk / channel / kernel tap
    const int64_t chunk = tgpig[0] / args.CHW;
    const int64_t rem   = tgpig[0] % args.CHW;
    const int64_t ic    = rem / KHW;           // 0 ~ (IC - 1)
    const int64_t tap   = tgpig[0] % KHW;

    // batch entry handled by this thread; surplus threads of the last chunk exit
    const int64_t n = (chunk * ntg[0]) + tpitg[0];
    if (n >= args.N) {
        return;
    }

    const int64_t kh = tap / args.KW;
    const int64_t kw = tap % args.KW;

    // input coordinates sampled by this tap
    const int64_t iih = tgpig[1] * args.s1 + kh * args.d1 - args.p1;
    const int64_t iiw = tgpig[2] * args.s0 + kw * args.d0 - args.p0;

    const int64_t offset_dst =
        (n * tgpg[1] * tgpg[2] + tgpig[1] * tgpg[2] + tgpig[2]) * args.CHW +
        (ic * KHW + kh * args.KW + kw);

    device T * pdst = (device T *) (dst);

    const bool inside = iih >= 0 && iih < args.IH && iiw >= 0 && iiw < args.IW;

    if (inside) {
        const int64_t offset_src = n * args.ofs0 + ic * args.ofs1;
        pdst[offset_dst] = x[offset_src + iih * args.IW + iiw];
    } else {
        pdst[offset_dst] = 0.0f;
    }
}
// Host-visible entry points: T is the element type written to dst (float or half).
template [[host_name("kernel_im2col_ext_f32")]] kernel im2col_ext_t kernel_im2col_ext<float>;
template [[host_name("kernel_im2col_ext_f16")]] kernel im2col_ext_t kernel_im2col_ext<half>;
// Direct 2D convolution. TK is the element type of `weights`; `src` is read as
// float and `dst` is written as float. All three buffers are addressed with
// byte strides taken from args (nb0x for weights, nb1x for src, nbx for dst).
//
// Every thread starts at its flat global index and advances by the total
// number of launched threads until all N*OC*OH*OW outputs are covered, so the
// kernel produces every output regardless of the launch size.
template <typename TK>
kernel void kernel_conv_2d(
constant ggml_metal_kargs_conv_2d & args,
device const char * weights,
device const char * src,
device char * dst,
uint3 tgpig[[threadgroup_position_in_grid]],
uint3 tgpg[[threadgroups_per_grid]],
uint3 tpitg[[thread_position_in_threadgroup]],
uint3 ntg[[threads_per_threadgroup]]) {
// flatten the 3D launch geometry into a single global thread index
const uint threads_per_tg = ntg.x * ntg.y * ntg.z;
const uint tg_index = (tgpig.z * tgpg.y + tgpig.y) * tgpg.x + tgpig.x;
const uint local_thread = tpitg.z * (ntg.x * ntg.y) + tpitg.y * ntg.x + tpitg.x;
const uint thread_index = tg_index * threads_per_tg + local_thread;
const uint64_t total_threads = (uint64_t) threads_per_tg * tgpg.x * tgpg.y * tgpg.z;
const uint64_t total_outputs = (uint64_t) args.N * args.OC * args.OH * args.OW;
for (uint64_t index = thread_index; index < total_outputs; index += total_threads) {
// decompose the flat output index; ow varies fastest, then oh, oc, n
uint64_t tmp = index;
const int32_t ow = tmp % args.OW; tmp /= args.OW;
const int32_t oh = tmp % args.OH; tmp /= args.OH;
const int32_t oc = tmp % args.OC; tmp /= args.OC;
const int32_t n = tmp;
float acc = 0.0f;
// input coordinate hit by kernel tap (0, 0); negative inside the padding
const int32_t base_x = ow*args.s0 - args.p0;
const int32_t base_y = oh*args.s1 - args.p1;
// first kernel row whose input row is >= 0: ceil(-base_y / d1)
int32_t ky_start = 0;
if (base_y < 0) {
ky_start = (-base_y + args.d1 - 1)/args.d1;
}
// one past the last kernel row whose input row is <= IH - 1
int32_t ky_end = args.KH;
const int32_t y_max = args.IH - 1 - base_y;
if (y_max < 0) {
// even tap 0 lies past the last input row: make the range empty
ky_end = ky_start;
} else if (base_y + (args.KH - 1)*args.d1 >= args.IH) {
ky_end = min(ky_end, y_max/args.d1 + 1);
}
// same clipping along x
int32_t kx_start = 0;
if (base_x < 0) {
kx_start = (-base_x + args.d0 - 1)/args.d0;
}
int32_t kx_end = args.KW;
const int32_t x_max = args.IW - 1 - base_x;
if (x_max < 0) {
kx_end = kx_start;
} else if (base_x + (args.KW - 1)*args.d0 >= args.IW) {
kx_end = min(kx_end, x_max/args.d0 + 1);
}
// accumulate only over in-bounds taps; out-of-bounds taps contribute zero
if (ky_start < ky_end && kx_start < kx_end) {
const uint64_t src_base_n = (uint64_t) n * args.nb13;
const uint64_t w_base_oc = (uint64_t) oc * args.nb03;
for (int32_t ic = 0; ic < args.IC; ++ic) {
const uint64_t src_base_nc = src_base_n + (uint64_t) ic * args.nb12;
const uint64_t w_base_ocic = w_base_oc + (uint64_t) ic * args.nb02;
for (int32_t ky = ky_start; ky < ky_end; ++ky) {
const int32_t iy = base_y + ky*args.d1;
const uint64_t src_base_row = src_base_nc + (uint64_t) iy * args.nb11;
const uint64_t w_base_row = w_base_ocic + (uint64_t) ky * args.nb01;
for (int32_t kx = kx_start; kx < kx_end; ++kx) {
const int32_t ix = base_x + kx*args.d0;
const uint64_t src_offs = src_base_row + (uint64_t) ix * args.nb10;
const uint64_t w_offs = w_base_row + (uint64_t) kx * args.nb00;
const float x = *(device const float *)(src + src_offs);
const float w = (float) (*(device const TK *)(weights + w_offs));
acc += x * w;
}
}
}
}
// the output is always written, including when no tap was in bounds (acc == 0)
const uint64_t dst_offs =
(uint64_t) n * args.nb3 +
(uint64_t) oc * args.nb2 +
(uint64_t) oh * args.nb1 +
(uint64_t) ow * args.nb0;
*(device float *)(dst + dst_offs) = acc;
}
}
// Host-visible entry point with float weights (input and output are float).
template [[host_name("kernel_conv_2d_f32_f32")]]
kernel void kernel_conv_2d<float>(
constant ggml_metal_kargs_conv_2d & args,
device const char * weights,
device const char * src,
device char * dst,
uint3 tgpig[[threadgroup_position_in_grid]],
uint3 tgpg[[threadgroups_per_grid]],
uint3 tpitg[[thread_position_in_threadgroup]],
uint3 ntg[[threads_per_threadgroup]]);
// Host-visible entry point with half weights (input and output are float).
template [[host_name("kernel_conv_2d_f16_f32")]]
kernel void kernel_conv_2d<half>(
constant ggml_metal_kargs_conv_2d & args,
device const char * weights,
device const char * src,
device char * dst,
uint3 tgpig[[threadgroup_position_in_grid]],
uint3 tgpg[[threadgroups_per_grid]],
uint3 tpitg[[thread_position_in_threadgroup]],
uint3 ntg[[threads_per_threadgroup]]);
// Function type for the 1D transposed-convolution kernel with float weights.
// NOTE(review): the explicit instantiations of kernel_conv_transpose_1d spell
// out their full signatures and do not reference this typedef, and src0 here
// is float regardless of the kernel's template parameter.
typedef void (conv_transpose_1d_t)(
constant ggml_metal_kargs_conv_transpose_1d & args,
device const float * src0,
device const float * src1,
device char * dst,
uint3 tgpig[[threadgroup_position_in_grid]],
uint3 tgpg[[threadgroups_per_grid]]);
// 1D transposed convolution. One threadgroup produces one output element:
// tgpig[0] is the output position j and tgpig[1] selects the output channel.
// src0 holds the kernel (element type T), src1 the float input.
template <typename T>
kernel void kernel_conv_transpose_1d(
        constant ggml_metal_kargs_conv_transpose_1d & args,
        device const     T * src0,
        device const float * src1,
        device        char * dst,
        uint3   tgpig[[threadgroup_position_in_grid]],
        uint3    tgpg[[threadgroups_per_grid]]) {
    // An input position i contributes to output position j only when
    //   i*s0 <= j < i*s0 + K
    // i.e. i in [ceil((j - K + 1)/s0), floor(j/s0)], clipped to [0, IL-1].
    // That is at most ceil(K/s0) positions (typically 2 when s0 == K/2).
    const int32_t j  = tgpig[0];
    const int32_t s0 = args.s0;
    const int32_t K  = args.K;
    const int32_t IL = args.IL;

    // lower bound: ceil((j - K + 1)/s0), but never below 0
    const int32_t lo_num = j - K + 1;
    const int32_t i_min  = lo_num > 0 ? (lo_num + s0 - 1) / s0 : 0;

    // upper bound: floor(j/s0), but never above IL - 1
    int32_t i_max = j / s0;
    if (i_max > IL - 1) {
        i_max = IL - 1;
    }

    float acc = 0.0f;

    if (i_min <= i_max) {
        for (int64_t c = 0; c < args.IC; c++) {
            const int32_t kernel_offset = c * tgpg[1] * K + K * tgpig[1];
            const int32_t input_offset  = c * IL;

            for (int32_t i = i_min; i <= i_max; i++) {
                const float w = float(src0[kernel_offset + j - i * s0]);
                acc += w * src1[input_offset + i];
            }
        }
    }

    // dst is addressed with byte strides nb0 (position) and nb1 (channel)
    device float * out = (device float *) (dst + tgpig[0] * args.nb0 + tgpig[1] * args.nb1);
    out[0] = acc;
}
// Host-visible entry point with float kernel weights.
template [[host_name("kernel_conv_transpose_1d_f32_f32")]]
kernel void kernel_conv_transpose_1d<float>(
constant ggml_metal_kargs_conv_transpose_1d & args,
device const float * src0,
device const float * src1,
device char * dst,
uint3 tgpig[[threadgroup_position_in_grid]],
uint3 tgpg[[threadgroups_per_grid]]);
// Host-visible entry point with half kernel weights.
template [[host_name("kernel_conv_transpose_1d_f16_f32")]]
kernel void kernel_conv_transpose_1d<half>(
constant ggml_metal_kargs_conv_transpose_1d & args,
device const half * src0,
device const float * src1,
device char * dst,
uint3 tgpig[[threadgroup_position_in_grid]],
uint3 tgpg[[threadgroups_per_grid]]);
// Function type named after the 2D transposed-convolution kernel.
// NOTE(review): this parameter list (tgpig, tgpg) does not match
// kernel_conv_transpose_2d, which takes shared_sum, tgpig, tpitg and ntg, and
// the explicit instantiations do not reference this typedef - it looks stale.
typedef void (conv_transpose_2d_t)(
constant ggml_metal_kargs_conv_transpose_2d & args,
device const float * src0,
device const float * src1,
device char * dst,
uint3 tgpig[[threadgroup_position_in_grid]],
uint3 tgpg[[threadgroups_per_grid]]);
// 2D transposed convolution.
//
// One threadgroup produces one output element (out_x, out_y, out_c) taken from
// the threadgroup position. Inside the threadgroup, thread (kw, kh) handles one
// kernel tap: it accumulates that tap's contribution over all input channels,
// publishes the partial sum in `shared_sum`, and thread 0 adds the partial
// sums up and writes the result.
//
// args.s0 is used as the stride along both x and y.
// `shared_sum` is indexed with tid in [0, ntg.x*ntg.y), so the threadgroup
// buffer bound at index 0 has to hold at least that many floats.
template <typename T>
kernel void kernel_conv_transpose_2d(
        constant ggml_metal_kargs_conv_transpose_2d & args,
        device const     T * src0,
        device const float * src1,
        device        char * dst,
        threadgroup  float * shared_sum [[threadgroup(0)]],
        uint3   tgpig[[threadgroup_position_in_grid]],
        uint3   tpitg[[thread_position_in_threadgroup]],
        uint3     ntg[[threads_per_threadgroup]]) {
    const int64_t out_x = tgpig[0];
    const int64_t out_y = tgpig[1];
    const int64_t out_c = tgpig[2];

    const int64_t kw = tpitg[0];
    const int64_t kh = tpitg[1];

    // The input pixel addressed by this (output pixel, kernel tap) pair does
    // not depend on the input channel, so resolve and validate it once instead
    // of repeating the modulo/divide/bounds checks on every channel iteration.
    int64_t in_y = out_y - kh;
    int64_t in_x = out_x - kw;

    // the tap contributes only if it lands exactly on a strided input sample...
    bool tap_valid = in_y >= 0 && in_x >= 0 &&
                     (in_y % args.s0) == 0 &&
                     (in_x % args.s0) == 0;

    if (tap_valid) {
        in_y /= args.s0;
        in_x /= args.s0;

        // ...that lies inside the input
        tap_valid = in_y < args.IH && in_x < args.IW;
    }

    float v = 0.0f;

    if (tap_valid) {
        // channel-independent parts of the flat indices
        const int64_t input_base  = (args.IW) * in_y + in_x;
        const int64_t kernel_base = (args.KH * args.KW) * out_c + (args.KW) * kh + kw;

        for (int64_t in_c = 0; in_c < args.IC; in_c++) {
            const int64_t input_idx  = (args.IW * args.IH) * in_c + input_base;
            const int64_t kernel_idx = (args.KH * args.KW * args.OC) * in_c + kernel_base;

            v += (float)src0[kernel_idx] * src1[input_idx];
        }
    }

    // Every thread - including those whose tap is invalid - must publish its
    // partial sum and reach the barrier, so there is no early return above.
    const uint tid = tpitg.y * ntg.x + tpitg.x;
    shared_sum[tid] = v;

    threadgroup_barrier(mem_flags::mem_threadgroup);

    if (tid == 0) {
        // serial reduction over all kernel taps, in thread-index order
        float total = 0.0f;
        const uint num_threads = ntg.x * ntg.y;
        for (uint i = 0; i < num_threads; i++) {
            total += shared_sum[i];
        }

        device float * dst_ptr = (device float *) (dst + out_x*args.nb0 + out_y * args.nb1 + out_c*args.nb2);
        dst_ptr[0] = total;
    }
}
// Host-visible entry point with float kernel weights.
template [[host_name("kernel_conv_transpose_2d_f32_f32")]]
kernel void kernel_conv_transpose_2d<float>(
constant ggml_metal_kargs_conv_transpose_2d & args,
device const float * src0,
device const float * src1,
device char * dst,
threadgroup float * shared_sum [[threadgroup(0)]],
uint3 tgpig[[threadgroup_position_in_grid]],
uint3 tpitg[[thread_position_in_threadgroup]],
uint3 ntg[[threads_per_threadgroup]]);
// Host-visible entry point with half kernel weights.
template [[host_name("kernel_conv_transpose_2d_f16_f32")]]
kernel void kernel_conv_transpose_2d<half>(
constant ggml_metal_kargs_conv_transpose_2d & args,
device const half * src0,
device const float * src1,
device char * dst,
threadgroup float * shared_sum [[threadgroup(0)]],
uint3 tgpig[[threadgroup_position_in_grid]],
uint3 tpitg[[thread_position_in_threadgroup]],
uint3 ntg[[threads_per_threadgroup]]);
// Direct 3D convolution, one output voxel per thread.
//
//   grid x * 32 + thread x -> flat (od, oh, ow) index (32 threads per group assumed)
//   grid y                 -> output channel
//   grid z                 -> batch entry
//
// T is the weight element type; inputs and outputs are float. All buffers are
// addressed with the byte strides carried in args.
template <typename T>
kernel void kernel_conv_3d(
        constant ggml_metal_kargs_conv_3d & args,
        device const char * src0, // Weights [IC * OC, KD, KH, KW]
        device const char * src1, // Inputs  [IC * N,  ID, IH, IW]
        device       char * dst,  // Outputs [OC * N,  OD, OH, OW]
        uint3   tgpig[[threadgroup_position_in_grid]],
        uint3   tpitg[[thread_position_in_threadgroup]]) {
    // flat spatial index of the voxel owned by this thread
    const int64_t voxel = tgpig.x * 32 + tpitg.x;

    // threads past the end of the output volume have nothing to do
    if (voxel >= args.OW * args.OH * args.OD) {
        return;
    }

    const int64_t ow = voxel % args.OW;
    const int64_t oh = (voxel / args.OW) % args.OH;
    const int64_t od = voxel / (args.OW * args.OH);

    const int64_t oc        = tgpig.y;
    const int64_t batch_idx = tgpig.z;

    // input coordinates addressed by kernel tap (0, 0, 0)
    const int64_t w0 = ow * args.s0 - args.p0;
    const int64_t h0 = oh * args.s1 - args.p1;
    const int64_t d0 = od * args.s2 - args.p2;

    float acc = 0.0f;

    for (int64_t ic = 0; ic < args.IC; ++ic) {
        // batch and channel share the outermost dimension of each tensor
        const int64_t src_cn = batch_idx * args.IC + ic;
        const int64_t w_cn   = oc * args.IC + ic;

        for (int64_t kz = 0; kz < args.KD; ++kz) {
            const int64_t id = d0 + kz * args.d2;
            if (id >= 0 && id < args.ID) {
                for (int64_t ky = 0; ky < args.KH; ++ky) {
                    const int64_t ih = h0 + ky * args.d1;
                    if (ih >= 0 && ih < args.IH) {
                        for (int64_t kx = 0; kx < args.KW; ++kx) {
                            const int64_t iw = w0 + kx * args.d0;
                            if (iw >= 0 && iw < args.IW) {
                                // byte offsets of the weight and the input sample
                                const int64_t w_off = kx*args.nb00 + ky*args.nb01 + kz*args.nb02 + w_cn*args.nb03;
                                const int64_t i_off = iw*args.nb10 + ih*args.nb11 + id*args.nb12 + src_cn*args.nb13;

                                const float w = (float)*(device const T *)(src0 + w_off);
                                const float v = *(device const float *)(src1 + i_off);

                                acc += w * v;
                            }
                        }
                    }
                }
            }
        }
    }

    const int64_t dst_cn = batch_idx * args.OC + oc;
    const int64_t d_off  = ow*args.nb0 + oh*args.nb1 + od*args.nb2 + dst_cn*args.nb3;

    *(device float *)(dst + d_off) = acc;
}
// Explicit instantiation for f32 weights; host_name is the name the host code
// uses to look the function up.
template [[host_name("kernel_conv_3d_f32_f32")]]
kernel void kernel_conv_3d<float>(
constant ggml_metal_kargs_conv_3d & args,
device const char * src0,
device const char * src1,
device char * dst,
uint3 tgpig[[threadgroup_position_in_grid]],
uint3 tpitg[[thread_position_in_threadgroup]]);
// Explicit instantiation for f16 weights (inputs and outputs stay f32)
template [[host_name("kernel_conv_3d_f16_f32")]]
kernel void kernel_conv_3d<half>(
constant ggml_metal_kargs_conv_3d & args,
device const char * src0,
device const char * src1,
device char * dst,
uint3 tgpig[[threadgroup_position_in_grid]],
uint3 tpitg[[thread_position_in_threadgroup]]);
-686
View File
@@ -1,686 +0,0 @@
#pragma once
#include "common.h"
#define GGML_COMMON_DECL_METAL
#define GGML_COMMON_IMPL_METAL
#if defined(GGML_METAL_EMBED_LIBRARY)
__embed_ggml-common.h__
#else
#include "ggml-common.h"
#endif
#define QK_NL 16 // shared by mul_mm and get_rows_q instantiations
// NOTE: nothing is dequantized here - this only adapts plain f32 data to the
//       dequantize_* template interface. `il` is unused.
template <typename type4x4>
void dequantize_f32(device const float4x4 * src, short il, thread type4x4 & reg) {
    const float4x4 tile = *src;
    reg = (type4x4)(tile);
}
// 4-element counterpart of dequantize_f32: loads one float4 and converts it to
// the requested vector type. `il` is unused.
template <typename type4>
void dequantize_f32_t4(device const float4 * src, short il, thread type4 & reg) {
    const float4 vals = *src;
    reg = (type4)(vals);
}
// Adapts plain f16 data to the dequantize_* template interface: loads one
// half4x4 tile and converts it to the requested matrix type. `il` is unused.
template <typename type4x4>
void dequantize_f16(device const half4x4 * src, short il, thread type4x4 & reg) {
    const half4x4 tile = *src;
    reg = (type4x4)(tile);
}
// 4-element counterpart of dequantize_f16: loads one half4 and converts it to
// the requested vector type. `il` is unused.
template <typename type4>
void dequantize_f16_t4(device const half4 * src, short il, thread type4 & reg) {
    const half4 vals = *src;
    reg = (type4)(vals);
}
#if defined(GGML_METAL_HAS_BF16)
// Adapts plain bf16 data to the dequantize_* template interface: loads one
// bfloat4x4 tile and converts it to the requested matrix type. `il` is unused.
template <typename type4x4>
void dequantize_bf16(device const bfloat4x4 * src, short il, thread type4x4 & reg) {
    const bfloat4x4 tile = *src;
    reg = (type4x4)(tile);
}

// 4-element counterpart of dequantize_bf16.
template <typename type4>
void dequantize_bf16_t4(device const bfloat4 * src, short il, thread type4 & reg) {
    const bfloat4 vals = *src;
    reg = (type4)(vals);
}
#endif
// Expands 16 one-bit weights of a q1_0 block into a 4x4 tile.
// `il` selects the group of 16 bits (2 bytes). A set bit becomes +d and a
// clear bit becomes -d. Bits are consumed LSB first: the first byte fills
// rows 0-1, the second byte fills rows 2-3.
template <typename type4x4>
void dequantize_q1_0(device const block_q1_0 * xb, short il, thread type4x4 & reg) {
    device const uint8_t * qs = xb->qs;

    const float d     = xb->d;
    const float neg_d = -d;

    const int byte_offset = il * 2; // il*16 bits = il*2 bytes

    const uint8_t b0 = qs[byte_offset];
    const uint8_t b1 = qs[byte_offset + 1];

    float4x4 tile;
    for (short k = 0; k < 8; ++k) {
        const uint8_t bit = (uint8_t)(1 << k);

        tile[    k/4][k%4] = select(neg_d, d, bool(b0 & bit));
        tile[2 + k/4][k%4] = select(neg_d, d, bool(b1 & bit));
    }

    reg = (type4x4) tile;
}
// Expands 4 one-bit weights of a q1_0 block into a 4-vector.
// `il` selects the group of 4 bits: bit index il*4 is located as byte
// (il*4)/8 and shift (il*4)%8. A set bit becomes +d, a clear bit -d.
template <typename type4>
void dequantize_q1_0_t4(device const block_q1_0 * xb, short il, thread type4 & reg) {
    const float d     = xb->d;
    const float neg_d = -d;

    const int first_bit = il * 4;

    const uint8_t byte  = xb->qs[first_bit / 8];
    const int     shift = first_bit % 8;

    float4 vals;
    for (short k = 0; k < 4; ++k) {
        vals[k] = select(neg_d, d, bool((byte >> (shift + k)) & 1));
    }

    reg = (type4) vals;
}
// Dequantizes 16 values of a q4_0 block into a 4x4 tile: value = d*q - 8*d.
// il == 0 takes the low nibble of every byte, il != 0 the high nibble.
// The nibbles are masked in place inside 16-bit words rather than shifted
// down, so the scale is pre-divided by 16 (high nibble) and by 256 (upper
// byte of the word) to compensate.
template <typename type4x4>
void dequantize_q4_0(device const block_q4_0 * xb, short il, thread type4x4 & reg) {
    // quants viewed as 16-bit words; the first word of the block is skipped
    device const uint16_t * words = ((device const uint16_t *)xb + 1);

    const float  scale_lo = il ? (xb->d / 16.h) : xb->d;
    const float  scale_hi = scale_lo / 256.f;
    const float  bias     = -8.h * xb->d;

    const ushort mask_lo  = il ? 0x00F0 : 0x000F;
    const ushort mask_hi  = mask_lo << 8;

    float4x4 tile;
    for (int row = 0; row < 4; row++) {
        for (int pair = 0; pair < 2; pair++) {
            const ushort w = words[2*row + pair];

            tile[row][2*pair + 0] = scale_lo * (w & mask_lo) + bias;
            tile[row][2*pair + 1] = scale_hi * (w & mask_hi) + bias;
        }
    }

    reg = (type4x4) tile;
}
// Dequantizes 4 values of a q4_0 block into a 4-vector: value = d*q - 8*d.
// il/4 selects the nibble (0 = low, otherwise high) and il%4 selects which
// pair of 16-bit words is read. As in dequantize_q4_0 the nibbles are masked
// in place and the scale is pre-divided to compensate.
template <typename type4>
void dequantize_q4_0_t4(device const block_q4_0 * xb, short il, thread type4 & reg) {
    // quants viewed as 16-bit words; the first word of the block is skipped
    device const uint16_t * words = ((device const uint16_t *)xb + 1) + 2*(il%4);

    const float  scale_lo = (il/4) ? (xb->d / 16.h) : xb->d;
    const float  scale_hi = scale_lo / 256.f;
    const float  bias     = -8.h * xb->d;

    const ushort mask_lo  = (il/4) ? 0x00F0 : 0x000F;
    const ushort mask_hi  = mask_lo << 8;

    for (int pair = 0; pair < 2; pair++) {
        const ushort w = words[pair];

        reg[2*pair + 0] = scale_lo * (w & mask_lo) + bias;
        reg[2*pair + 1] = scale_hi * (w & mask_hi) + bias;
    }
}
// Dequantizes 16 values of a q4_1 block into a 4x4 tile: value = d*q + m.
// il == 0 takes the low nibble of every byte, il != 0 the high nibble.
// Nibbles are masked in place inside 16-bit words, so the scale is pre-divided
// by 16 (high nibble) and by 256 (upper byte of the word) to compensate.
template <typename type4x4>
void dequantize_q4_1(device const block_q4_1 * xb, short il, thread type4x4 & reg) {
    // quants viewed as 16-bit words; the first two words of the block are skipped
    device const uint16_t * words = ((device const uint16_t *)xb + 2);

    const float  scale_lo = il ? (xb->d / 16.h) : xb->d;
    const float  scale_hi = scale_lo / 256.f;
    const float  offset   = xb->m;

    const ushort mask_lo  = il ? 0x00F0 : 0x000F;
    const ushort mask_hi  = mask_lo << 8;

    float4x4 tile;
    for (int row = 0; row < 4; row++) {
        for (int pair = 0; pair < 2; pair++) {
            const ushort w = words[2*row + pair];

            tile[row][2*pair + 0] = ((w & mask_lo) * scale_lo) + offset;
            tile[row][2*pair + 1] = ((w & mask_hi) * scale_hi) + offset;
        }
    }

    reg = (type4x4) tile;
}
// Dequantizes 4 values of a q4_1 block into a 4-vector: value = d*q + m.
// il/4 selects the nibble (0 = low, otherwise high) and il%4 selects which
// pair of 16-bit words is read.
template <typename type4>
void dequantize_q4_1_t4(device const block_q4_1 * xb, short il, thread type4 & reg) {
    // quants viewed as 16-bit words; the first two words of the block are skipped
    device const uint16_t * words = ((device const uint16_t *)xb + 2) + 2*(il%4);

    const float  scale_lo = (il/4) ? (xb->d / 16.h) : xb->d;
    const float  scale_hi = scale_lo / 256.f;
    const float  offset   = xb->m;

    const ushort mask_lo  = (il/4) ? 0x00F0 : 0x000F;
    const ushort mask_hi  = mask_lo << 8;

    for (int pair = 0; pair < 2; pair++) {
        const ushort w = words[pair];

        reg[2*pair + 0] = scale_lo * (w & mask_lo) + offset;
        reg[2*pair + 1] = scale_hi * (w & mask_hi) + offset;
    }
}
// Dequantizes 16 values of a q5_0 block into a 4x4 tile: value = d*q - 16*d,
// where q is a 5-bit code built from a 4-bit nibble in qs plus one extra bit
// taken from the 32-bit qh field.
// il == 0 uses the low nibbles and qh bits 0..15; il != 0 uses the high
// nibbles and qh bits 16..31.
template <typename type4x4>
void dequantize_q5_0(device const block_q5_0 * xb, short il, thread type4x4 & reg) {
// nibbles viewed as 16-bit words; the first three words of the block are skipped
device const uint16_t * qs = ((device const uint16_t *)xb + 3);
const float d = xb->d;
const float md = -16.h * xb->d;
const ushort mask = il ? 0x00F0 : 0x000F;
const uint32_t qh = *((device const uint32_t *)xb->qh);
// x_mv: shift that brings the selected nibble down to bits 0..3
const int x_mv = il ? 4 : 0;
// gh_mv / gh_bk: right/left shifts that place the wanted qh bit at bit 4 (0x10)
const int gh_mv = il ? 12 : 0;
const int gh_bk = il ? 0 : 4;
float4x4 reg_f;
for (int i = 0; i < 8; i++) {
// extract the 5-th bits for x0 and x1
const uint8_t xh_0 = ((qh >> (gh_mv + 2*i )) << gh_bk) & 0x10;
const uint8_t xh_1 = ((qh >> (gh_mv + 2*i+1)) << gh_bk) & 0x10;
// combine the 4-bits from qs with the 5th bit
const int32_t x0 = ((((qs[i] ) & mask) >> x_mv) | xh_0);
const int32_t x1 = ((((qs[i] >> 8) & mask) >> x_mv) | xh_1);
reg_f[i/2][2*(i%2) + 0] = d * x0 + md;
reg_f[i/2][2*(i%2) + 1] = d * x1 + md;
}
reg = (type4x4) reg_f;
}
// Dequantizes 4 values of a q5_0 block into a 4-vector: value = d*q - 16*d.
// il/4 selects the nibble (0 = low nibbles with qh bits 0..15, otherwise high
// nibbles with qh bits 16..31); il%4 selects which pair of 16-bit words is read.
template <typename type4>
void dequantize_q5_0_t4(device const block_q5_0 * xb, short il, thread type4 & reg) {
// nibbles viewed as 16-bit words; the first three words of the block are skipped
device const uint16_t * qs = ((device const uint16_t *)xb + 3);
const float d = xb->d;
const float md = -16.h * xb->d;
const ushort mask = (il/4) ? 0x00F0 : 0x000F;
const uint32_t qh = *((device const uint32_t *)xb->qh);
// x_mv: shift that brings the selected nibble down to bits 0..3
const int x_mv = (il/4) ? 4 : 0;
// gh_mv / gh_bk: right/left shifts that place the wanted qh bit at bit 4 (0x10)
const int gh_mv = (il/4) ? 12 : 0;
const int gh_bk = (il/4) ? 0 : 4;
for (int ii = 0; ii < 2; ii++) {
// index of the 16-bit word (and of the qh bit pair) inside the block
int i = 2*(il%4) + ii;
// extract the 5-th bits for x0 and x1
const uint8_t xh_0 = ((qh >> (gh_mv + 2*i )) << gh_bk) & 0x10;
const uint8_t xh_1 = ((qh >> (gh_mv + 2*i+1)) << gh_bk) & 0x10;
// combine the 4-bits from qs with the 5th bit
const int32_t x0 = ((((qs[i] ) & mask) >> x_mv) | xh_0);
const int32_t x1 = ((((qs[i] >> 8) & mask) >> x_mv) | xh_1);
reg[2*ii + 0] = d * x0 + md;
reg[2*ii + 1] = d * x1 + md;
}
}
// Dequantizes 16 values of a q5_1 block into a 4x4 tile: value = d*q + m,
// where q is a 5-bit code built from a 4-bit nibble in qs plus one extra bit
// taken from the 32-bit qh field.
// il == 0 uses the low nibbles and qh bits 0..15; il != 0 uses the high
// nibbles and qh bits 16..31.
template <typename type4x4>
void dequantize_q5_1(device const block_q5_1 * xb, short il, thread type4x4 & reg) {
// nibbles viewed as 16-bit words; the first four words of the block are skipped
device const uint16_t * qs = ((device const uint16_t *)xb + 4);
const float d = xb->d;
const float m = xb->m;
const ushort mask = il ? 0x00F0 : 0x000F;
const uint32_t qh = *((device const uint32_t *)xb->qh);
// x_mv: shift that brings the selected nibble down to bits 0..3
const int x_mv = il ? 4 : 0;
// gh_mv / gh_bk: right/left shifts that place the wanted qh bit at bit 4 (0x10)
const int gh_mv = il ? 12 : 0;
const int gh_bk = il ? 0 : 4;
float4x4 reg_f;
for (int i = 0; i < 8; i++) {
// extract the 5-th bits for x0 and x1
const uint8_t xh_0 = ((qh >> (gh_mv + 2*i )) << gh_bk) & 0x10;
const uint8_t xh_1 = ((qh >> (gh_mv + 2*i+1)) << gh_bk) & 0x10;
// combine the 4-bits from qs with the 5th bit
const int32_t x0 = ((((qs[i] ) & mask) >> x_mv) | xh_0);
const int32_t x1 = ((((qs[i] >> 8) & mask) >> x_mv) | xh_1);
reg_f[i/2][2*(i%2) + 0] = d * x0 + m;
reg_f[i/2][2*(i%2) + 1] = d * x1 + m;
}
reg = (type4x4) reg_f;
}
// Dequantizes 4 values of a q5_1 block into a 4-vector: value = d*q + m.
// il/4 selects the nibble (0 = low nibbles with qh bits 0..15, otherwise high
// nibbles with qh bits 16..31); il%4 selects which pair of 16-bit words is read.
template <typename type4>
void dequantize_q5_1_t4(device const block_q5_1 * xb, short il, thread type4 & reg) {
// nibbles viewed as 16-bit words; the first four words of the block are skipped
device const uint16_t * qs = ((device const uint16_t *)xb + 4);
const float d = xb->d;
const float m = xb->m;
const ushort mask = (il/4) ? 0x00F0 : 0x000F;
const uint32_t qh = *((device const uint32_t *)xb->qh);
// x_mv: shift that brings the selected nibble down to bits 0..3
const int x_mv = (il/4) ? 4 : 0;
// gh_mv / gh_bk: right/left shifts that place the wanted qh bit at bit 4 (0x10)
const int gh_mv = (il/4) ? 12 : 0;
const int gh_bk = (il/4) ? 0 : 4;
for (int ii = 0; ii < 2; ii++) {
// index of the 16-bit word (and of the qh bit pair) inside the block
int i = 2*(il%4) + ii;
// extract the 5-th bits for x0 and x1
const uint8_t xh_0 = ((qh >> (gh_mv + 2*i )) << gh_bk) & 0x10;
const uint8_t xh_1 = ((qh >> (gh_mv + 2*i+1)) << gh_bk) & 0x10;
// combine the 4-bits from qs with the 5th bit
const int32_t x0 = ((((qs[i] ) & mask) >> x_mv) | xh_0);
const int32_t x1 = ((((qs[i] >> 8) & mask) >> x_mv) | xh_1);
reg[2*ii + 0] = d * x0 + m;
reg[2*ii + 1] = d * x1 + m;
}
}
// Dequantizes 16 consecutive int8 values of a q8_0 block into a 4x4 tile:
// value = q * d. `il` selects the group of 16 values inside the block.
template <typename type4x4>
void dequantize_q8_0(device const block_q8_0 *xb, short il, thread type4x4 & reg) {
    device const int8_t * q = ((device const int8_t *)xb->qs) + 16*il;

    const float d = xb->d;

    float4x4 tile;
    for (int row = 0; row < 4; row++) {
        for (int col = 0; col < 4; col++) {
            tile[row][col] = (q[4*row + col] * d);
        }
    }

    reg = (type4x4) tile;
}
// Dequantizes 4 consecutive int8 values of a q8_0 block into a 4-vector:
// value = q * d. The values start at element 16*(il/4) + 4*(il%4).
template <typename type4>
void dequantize_q8_0_t4(device const block_q8_0 *xb, short il, thread type4 & reg) {
    device const int8_t * qs = ((device const int8_t *)xb->qs);

    const float d = xb->d;

    const int first = 16*(il/4) + 4*(il%4);

    for (int k = 0; k < 4; k++) {
        reg[k] = (qs[first + k] * d);
    }
}
// Dequantizes 16 values of an mxfp4 block into a 4x4 tile.
// Every byte of qs carries two 4-bit codes; il == 0 decodes the low nibble of
// bytes 0..15 and il >= 1 the high nibble. A code is looked up in
// kvalues_mxfp4_f and scaled by the block scale decoded from xb->e.
template <typename type4x4>
void dequantize_mxfp4(device const block_mxfp4 * xb, short il, thread type4x4 & reg) {
    device const uint8_t * q2 = (device const uint8_t *)xb->qs;

    const float d = e8m0_to_fp32(xb->e);

    const uint8_t shr = il >= 1 ? 4 : 0;

    for (int row = 0; row < 4; ++row) {
        for (int col = 0; col < 4; ++col) {
            reg[row][col] = d * kvalues_mxfp4_f[(q2[4*row + col] >> shr) & 0x0F];
        }
    }
}
// Dequantizes 4 values of an mxfp4 block into a 4-vector.
// il%4 selects which group of 4 bytes is read; il < 4 decodes their low
// nibbles and il >= 4 their high nibbles. Codes are looked up in
// kvalues_mxfp4_f and scaled by the block scale decoded from xb->e.
template <typename type4>
void dequantize_mxfp4_t4(device const block_mxfp4 * xb, short il, thread type4 & reg) {
    device const uint8_t * q2 = (device const uint8_t *)xb->qs;

    const float d = e8m0_to_fp32(xb->e);

    const short   il4 = il%4;
    const uint8_t shr = il >= 4 ? 4 : 0;

    for (int k = 0; k < 4; ++k) {
        reg[k] = d * kvalues_mxfp4_f[(q2[4*il4 + k] >> shr) & 0x0F];
    }
}
// Dequantizes 16 values of a q2_K block into a 4x4 tile:
//   value = d * (scale nibble) * q - dmin * (min nibble)
// `il` selects the group of 16 values: scales[il] carries the 4-bit scale (low
// nibble) and the 4-bit min (high nibble) of that group.
template <typename type4x4>
void dequantize_q2_K(device const block_q2_K *xb, short il, thread type4x4 & reg) {
const float d = xb->d;
const float min = xb->dmin;
device const uint8_t * q = (device const uint8_t *)xb->qs;
float dl, ml;
uint8_t sc = xb->scales[il];
// first of the 16 bytes that hold this group's 2-bit fields
q = q + 32*(il/8) + 16*(il&1);
// from here on il is the index (0..3) of the 2-bit field inside each byte
il = (il/2)%4;
// the field is masked in place, not shifted down; coef = 1, 1/4, 1/16, 1/64
// undoes the implied factor of 4^il
half coef = il>1 ? (il>2 ? 1/64.h : 1/16.h) : (il>0 ? 1/4.h : 1.h);
uchar mask = il>1 ? (il>2 ? 192 : 48) : (il>0 ? 12 : 3);
dl = d * (sc & 0xF) * coef, ml = min * (sc >> 4);
for (int i = 0; i < 16; ++i) {
reg[i/4][i%4] = dl * (q[i] & mask) - ml;
}
}
// Dequantizes 16 values of a q3_K block into a 4x4 tile.
// Each value combines a 2-bit field from qs with one bit from hmask: when the
// hmask bit is clear, 4*dl is subtracted from dl*(2-bit value).
// `il` selects the group of 16 values.
template <typename type4x4>
void dequantize_q3_K(device const block_q3_K *xb, short il, thread type4x4 & reg) {
const half d_all = xb->d;
device const uint8_t * q = (device const uint8_t *)xb->qs;
device const uint8_t * h = (device const uint8_t *)xb->hmask;
device const int8_t * scales = (device const int8_t *)xb->scales;
// first of the 16 bytes holding this group's 2-bit fields / high-bit masks
q = q + 32 * (il/8) + 16 * (il&1);
h = h + 16 * (il&1);
// bit of each hmask byte that belongs to this group
uint8_t m = 1 << (il/2);
// The group's scale is assembled from two bytes of `scales`: four bits selected
// by kmask2 from scales[il%8] and two bits selected by kmask1 from scales[8 + il%4].
// NOTE(review): presumably the packed 6-bit q3_K scales - confirm against the block layout.
uint16_t kmask1 = (il/4)>1 ? ((il/4)>2 ? 192 : 48) : \
((il/4)>0 ? 12 : 3);
uint16_t kmask2 = il/8 ? 0xF0 : 0x0F;
uint16_t scale_2 = scales[il%8], scale_1 = scales[8 + il%4];
int16_t dl_int = (il/4)&1 ? (scale_2&kmask2) | ((scale_1&kmask1) << 2)
: (scale_2&kmask2) | ((scale_1&kmask1) << 4);
// for il >= 8 the assembled value is scaled by 16 (bits masked in place), hence the /16;
// 32 is subtracted from the scale in both cases
float dl = il<8 ? d_all * (dl_int - 32.f) : d_all * (dl_int / 16.f - 32.f);
const float ml = 4.f * dl;
// from here on il is the index (0..3) of the 2-bit field inside each byte
il = (il/2) & 3;
// the field is masked in place; coef = 1, 1/4, 1/16, 1/64 undoes the factor of 4^il
const half coef = il>1 ? (il>2 ? 1/64.h : 1/16.h) : (il>0 ? 1/4.h : 1.h);
const uint8_t mask = il>1 ? (il>2 ? 192 : 48) : (il>0 ? 12 : 3);
dl *= coef;
for (int i = 0; i < 16; ++i) {
reg[i/4][i%4] = dl * (q[i] & mask) - (h[i] & m ? 0 : ml);
}
}
// Unpacks one pair of 6-bit values from the packed byte array `q`, addressed by j + k.
// For j < 4 both values are the low 6 bits of q[j+k] and q[j+4+k].
// Otherwise each value takes its low 4 bits from a nibble of q[j+4+k] (low
// nibble for the first, high nibble for the second) and its upper 2 bits from
// the top two bits of q[j-4+k] and q[j+k] respectively.
// Callers in this file use element 0 as the scale and element 1 as the min.
static inline uchar2 get_scale_min_k4_just2(int j, int k, device const uchar * q) {
return j < 4 ? uchar2{uchar(q[j+0+k] & 63), uchar(q[j+4+k] & 63)}
: uchar2{uchar((q[j+4+k] & 0xF) | ((q[j-4+k] & 0xc0) >> 2)), uchar((q[j+4+k] >> 4) | ((q[j-0+k] & 0xc0) >> 2))};
}
// Dequantizes 16 values of a q4_K block into a 4x4 tile:
//   value = d * scale * q - dmin * min
// `il` selects the group of 16 values. il/4 picks the 32-byte chunk of qs,
// il&1 its lower or upper 16 bytes, and (il&3) < 2 the low nibble (otherwise
// the high nibble, masked in place and compensated by d/16).
template <typename type4x4>
void dequantize_q4_K(device const block_q4_K * xb, short il, thread type4x4 & reg) {
    const short chunk = il/4;
    const short sub   = il & 3;

    device const uchar * q = xb->qs + chunk * 32 + 16 * (il&1);

    const bool low_nibble = sub < 2;

    // 6-bit scale (sc[0]) and min (sc[1]) of this group
    const uchar2 sc = get_scale_min_k4_just2(chunk * 2, sub/2, xb->scales);

    const float d   = low_nibble ? xb->d : xb->d / 16.h;
    const float min = xb->dmin;

    const float dl = d * sc[0];
    const float ml = min * sc[1];

    const ushort mask = low_nibble ? 0x0F : 0xF0;

    for (int row = 0; row < 4; ++row) {
        for (int col = 0; col < 4; ++col) {
            reg[row][col] = dl * (q[4*row + col] & mask) - ml;
        }
    }
}
// Dequantizes 16 values of a q5_K block into a 4x4 tile:
//   value = d * scale * q - dmin * min
// where q is a 4-bit nibble from qs plus one extra high bit from qh.
// `il` selects the group of 16 values. il/4 picks the 32-byte chunk of qs,
// il&1 the lower or upper 16 bytes of qs and qh, (il&3) < 2 the low nibble
// (otherwise the high nibble, masked in place and compensated by d/16), and
// bit il/2 of each qh byte is the group's high bit.
template <typename type4x4>
void dequantize_q5_K(device const block_q5_K *xb, short il, thread type4x4 & reg) {
    const short chunk = il/4;
    const short sub   = il & 3;

    device const uint8_t * q  = xb->qs + 32 * chunk + 16 * (il&1);
    device const uint8_t * qh = xb->qh + 16 * (il&1);

    // bit of each qh byte that belongs to this group
    const uint8_t ul = 1 << (il/2);

    const bool low_nibble = sub < 2;

    // 6-bit scale (sc[0]) and min (sc[1]) of this group
    const uchar2 sc = get_scale_min_k4_just2(chunk * 2, sub/2, xb->scales);

    const float d   = low_nibble ? xb->d : xb->d / 16.f;
    const float min = xb->dmin;

    const float dl = d * sc[0];
    const float ml = min * sc[1];

    const ushort mask = low_nibble ? 0x0F : 0xF0;

    // weight of the high bit, expressed in the same (unshifted) units as the nibble
    const float qh_val = low_nibble ? 16.f : 256.f;

    for (int row = 0; row < 4; ++row) {
        for (int col = 0; col < 4; ++col) {
            const int i = 4*row + col;

            reg[row][col] = dl * ((q[i] & mask) + (qh[i] & ul ? qh_val : 0)) - ml;
        }
    }
}
// Dequantizes 16 values of a q6_K block into a 4x4 tile:
//   value = d * scale * (q - 32)
// where q is a 6-bit code made of a 4-bit nibble from ql and a 2-bit field
// from qh. Four codes are processed at a time, packed in one 32-bit word.
// `il` selects the group of 16 values.
template <typename type4x4>
void dequantize_q6_K(device const block_q6_K *xb, short il, thread type4x4 & reg) {
const half d_all = xb->d;
device const uint16_t * ql = (device const uint16_t *)xb->ql;
device const uint16_t * qh = (device const uint16_t *)xb->qh;
device const int8_t * scales = (device const int8_t *)xb->scales;
// offsets below are in 16-bit words
ql = ql + 32*(il/8) + 16*((il/2)&1) + 8*(il&1);
qh = qh + 16*(il/8) + 8*(il&1);
// (il%2) + 2*(il/2) equals il for non-negative il, i.e. one int8 scale per group
float sc = scales[(il%2) + 2 * ((il/2))];
// from here on il (0..3) selects the 2-bit field of qh and the nibble of ql
il = (il/2) & 3;
// kmask1 keeps the chosen 2-bit field in every byte of the qh word,
// kmask2 keeps the chosen nibble in every byte of the ql word
const uint32_t kmask1 = il>1 ? (il>2 ? 0xC0C0C0C0 : 0x30303030) : (il>0 ? 0x0C0C0C0C : 0x03030303);
const uint32_t kmask2 = il>1 ? 0xF0F0F0F0 : 0x0F0F0F0F;
const float ml = d_all * sc * 32.f;
// dl1..dl3 divide by 256, 256^2, 256^3 because bytes 1..3 of the packed word
// are masked in place below instead of being shifted down
const float dl0 = d_all * sc;
const float dl1 = dl0 / 256.f;
const float dl2 = dl0 / (256.f * 256.f);
const float dl3 = dl0 / (256.f * 256.f * 256.f);
// shifts that move the qh field to bits 4..5 and the ql nibble to bits 0..3 of each byte
const uint8_t shr_h = il>2 ? 2 : 0;
const uint8_t shl_h = il>1 ? 0 : (il>0 ? 2 : 4);
const uint8_t shr_l = il>1 ? 4 : 0;
for (int i = 0; i < 4; ++i) {
// assemble two 16-bit words into one 32-bit word holding four bytes
const uint32_t low = (ql[2*i] | (uint32_t)(ql[2*i+1] << 16)) & kmask2;
const uint32_t high = (qh[2*i] | (uint32_t)(qh[2*i+1] << 16)) & kmask1;
const uint32_t q = ((high << shl_h) >> shr_h) | (low >> shr_l);
reg[i][0] = dl0 * ((half)(q & 0xFF)) - ml;
reg[i][1] = dl1 * ((float)(q & 0xFF00)) - ml;
reg[i][2] = dl2 * ((float)(q & 0xFF0000)) - ml;
reg[i][3] = dl3 * ((float)(q & 0xFF000000)) - ml;
}
}
// Dequantizes 16 values of an iq2_xxs block into a 4x4 tile.
// Each group of 8 values is one entry of the iq2xxs_grid codebook (8 bytes,
// selected by an 8-bit index) with per-value signs from ksigns_iq2xs (selected
// by a 7-bit index), all scaled by dl.
template <typename type4x4>
void dequantize_iq2_xxs(device const block_iq2_xxs * xb, short il, thread type4x4 & reg) {
// il is 0...15 for QK_K = 256 => index of block of 32 is il/2
const float d = xb->d;
const int ib32 = il/2;
il = il%2;
// il = 0 or 1. il = 0 processes the first 16 quants in a block of 32, il = 1 the second 16
// each block of 32 needs 2 uint32_t's for the quants & scale, so 4 uint16_t's.
device const uint16_t * q2 = xb->qs + 4*ib32;
// aux32_g: four 8-bit grid indices; aux32_s: four 7-bit sign indices plus a 4-bit scale in the top bits
const uint32_t aux32_g = q2[0] | (q2[1] << 16);
const uint32_t aux32_s = q2[2] | (q2[3] << 16);
thread const uint8_t * aux8 = (thread const uint8_t *)&aux32_g;
const float dl = d * (0.5f + (aux32_s >> 28)) * 0.25f;
// first 8 values -> rows 0 and 1
constant uint8_t * grid = (constant uint8_t *)(iq2xxs_grid + aux8[2*il+0]);
uint8_t signs = ksigns_iq2xs[(aux32_s >> 14*il) & 127];
for (int i = 0; i < 8; ++i) {
reg[i/4][i%4] = dl * grid[i] * (signs & kmask_iq2xs[i] ? -1.f : 1.f);
}
// second 8 values -> rows 2 and 3
grid = (constant uint8_t *)(iq2xxs_grid + aux8[2*il+1]);
signs = ksigns_iq2xs[(aux32_s >> (14*il+7)) & 127];
for (int i = 0; i < 8; ++i) {
reg[2+i/4][i%4] = dl * grid[i] * (signs & kmask_iq2xs[i] ? -1.f : 1.f);
}
}
// Dequantizes 16 values of an iq2_xs block into a 4x4 tile.
// Every 16-bit entry of qs describes 8 values: its low 9 bits index the
// iq2xs_grid codebook and its top 7 bits index the ksigns_iq2xs sign table.
// The 4-bit scale of each half-block of 16 comes from a nibble of xb->scales.
template <typename type4x4>
void dequantize_iq2_xs(device const block_iq2_xs * xb, short il, thread type4x4 & reg) {
// il is 0...15 for QK_K = 256 => index of block of 32 is il/2
const float d = xb->d;
const int ib32 = il/2;
il = il%2;
// il = 0 or 1. il = 0 processes the first 16 quants in a block of 32, il = 1 the second 16
device const uint16_t * q2 = xb->qs + 4*ib32;
// low nibble of scales[ib32] for il = 0, high nibble for il = 1
const float dl = d * (0.5f + ((xb->scales[ib32] >> 4*il) & 0xf)) * 0.25f;
// first 8 values -> rows 0 and 1
constant uint8_t * grid = (constant uint8_t *)(iq2xs_grid + (q2[2*il+0] & 511));
uint8_t signs = ksigns_iq2xs[q2[2*il+0] >> 9];
for (int i = 0; i < 8; ++i) {
reg[i/4][i%4] = dl * grid[i] * (signs & kmask_iq2xs[i] ? -1.f : 1.f);
}
// second 8 values -> rows 2 and 3
grid = (constant uint8_t *)(iq2xs_grid + (q2[2*il+1] & 511));
signs = ksigns_iq2xs[q2[2*il+1] >> 9];
for (int i = 0; i < 8; ++i) {
reg[2+i/4][i%4] = dl * grid[i] * (signs & kmask_iq2xs[i] ? -1.f : 1.f);
}
}
// Dequantizes 16 values of an iq3_xxs block into a 4x4 tile.
// Every byte of qs indexes the iq3xxs_grid codebook and yields 4 values (one
// row of the tile). Signs and the 4-bit scale of a block of 32 are packed in a
// 32-bit word stored after the first QK_K/4 bytes of qs.
template <typename type4x4>
void dequantize_iq3_xxs(device const block_iq3_xxs * xb, short il, thread type4x4 & reg) {
// il is 0...15 for QK_K = 256 => index of block of 32 is il/2
const float d = xb->d;
const int ib32 = il/2;
il = il%2;
// il = 0 or 1. il = 0 processes the first 16 quants in a block of 32, il = 1 the second 16
device const uint8_t * q3 = xb->qs + 8*ib32;
device const uint16_t * gas = (device const uint16_t *)(xb->qs + QK_K/4) + 2*ib32;
// aux32: four 7-bit sign indices in bits 0..27 and the scale in bits 28..31
const uint32_t aux32 = gas[0] | (gas[1] << 16);
const float dl = d * (0.5f + (aux32 >> 28)) * 0.5f;
// rows 0 and 1 share one 8-bit sign mask: bits 0..3 for row 0, bits 4..7 for row 1
constant uint8_t * grid1 = (constant uint8_t *)(iq3xxs_grid + q3[4*il+0]);
constant uint8_t * grid2 = (constant uint8_t *)(iq3xxs_grid + q3[4*il+1]);
uint8_t signs = ksigns_iq2xs[(aux32 >> 14*il) & 127];
for (int i = 0; i < 4; ++i) {
reg[0][i] = dl * grid1[i] * (signs & kmask_iq2xs[i+0] ? -1.f : 1.f);
reg[1][i] = dl * grid2[i] * (signs & kmask_iq2xs[i+4] ? -1.f : 1.f);
}
// rows 2 and 3 use the next 7-bit sign index
grid1 = (constant uint8_t *)(iq3xxs_grid + q3[4*il+2]);
grid2 = (constant uint8_t *)(iq3xxs_grid + q3[4*il+3]);
signs = ksigns_iq2xs[(aux32 >> (14*il+7)) & 127];
for (int i = 0; i < 4; ++i) {
reg[2][i] = dl * grid1[i] * (signs & kmask_iq2xs[i+0] ? -1.f : 1.f);
reg[3][i] = dl * grid2[i] * (signs & kmask_iq2xs[i+4] ? -1.f : 1.f);
}
}
// Decodes 16 consecutive IQ3_S weights from a super-block into `reg`.
//   xb  - super-block to read from
//   il  - 16-weight chunk index (0...15 for QK_K = 256); il/2 selects the block of 32,
//         il%2 selects its first or second half
//   reg - 4x4 output; each row holds the 4 values of one grid entry
template <typename type4x4>
void dequantize_iq3_s(device const block_iq3_s * xb, short il, thread type4x4 & reg) {
    const int   blk32  = il/2;
    const short half16 = il%2;

    const float d = xb->d;

    // 8 low index bytes per block of 32 (4 per half) and 4 sign bytes per block (2 per half)
    device const uint8_t * idx   = xb->qs    + 8*blk32 + 4*half16;
    device const uint8_t * signs = xb->signs + 4*blk32 + 2*half16;

    // after the shift, bits 0..3 of qh are the 9th index bit of the four grid entries of this half
    const uint8_t qh = xb->qh[blk32] >> (4*half16);

    // 4-bit scale, two blocks of 32 share one byte
    const float dl = d * (1 + 2*((xb->scales[blk32/2] >> (4*(blk32%2))) & 0xf));

    for (short p = 0; p < 2; ++p) {
        constant uint8_t * grid_a = (constant uint8_t *)(iq3s_grid + (idx[2*p + 0] | ((qh << (8 - 2*p)) & 256)));
        constant uint8_t * grid_b = (constant uint8_t *)(iq3s_grid + (idx[2*p + 1] | ((qh << (7 - 2*p)) & 256)));

        for (int i = 0; i < 4; ++i) {
            reg[2*p + 0][i] = dl * grid_a[i] * select(1, -1, signs[p] & kmask_iq2xs[i+0]);
            reg[2*p + 1][i] = dl * grid_b[i] * select(1, -1, signs[p] & kmask_iq2xs[i+4]);
        }
    }
}
// Decodes 16 consecutive IQ2_S weights from a super-block into `reg`.
//   xb  - super-block to read from
//   il  - 16-weight chunk index (0...15 for QK_K = 256); il/2 selects the block of 32,
//         il%2 selects its first or second half
//   reg - 4x4 output; rows 0-1 come from the first grid entry, rows 2-3 from the second
template <typename type4x4>
void dequantize_iq2_s(device const block_iq2_s * xb, short il, thread type4x4 & reg) {
    const int   blk32  = il/2;
    const short half16 = il%2;

    const float d = xb->d;

    // two low index bytes per half; the matching sign bytes sit QK_K/8 bytes further on
    device const uint8_t * idx   = xb->qs + 4*blk32 + 2*half16;
    device const uint8_t * signs = idx + QK_K/8;

    // after the shift, bits 0..1 and 2..3 of qh are the high index bits of the two grid entries
    const uint8_t qh = xb->qh[blk32] >> (4*half16);

    const float dl = d * (0.5f + ((xb->scales[blk32] >> (4*half16)) & 0xf)) * 0.25f;

    constant uint8_t * grid_a = (constant uint8_t *)(iq2s_grid + (idx[0] | ((qh << 8) & 0x300)));
    constant uint8_t * grid_b = (constant uint8_t *)(iq2s_grid + (idx[1] | ((qh << 6) & 0x300)));

    for (short r = 0; r < 2; ++r) {
        for (short c = 0; c < 4; ++c) {
            const short i = 4*r + c;

            reg[r + 0][c] = dl * grid_a[i] * select(1, -1, signs[0] & kmask_iq2xs[i]);
            reg[r + 2][c] = dl * grid_b[i] * select(1, -1, signs[1] & kmask_iq2xs[i]);
        }
    }
}
// Decodes 16 consecutive IQ1_S weights from a super-block into `reg`.
//   xb  - super-block to read from
//   il  - 16-weight chunk index (0...15 for QK_K = 256); il/2 selects the block of 32,
//         il%2 selects its first or second half
//   reg - 4x4 output; rows 0-1 come from the first grid entry (low / high nibbles),
//         rows 2-3 from the second
template <typename type4x4>
void dequantize_iq1_s(device const block_iq1_s * xb, short il, thread type4x4 & reg) {
    const int   blk32  = il/2;
    const short half16 = il%2;

    const float d = xb->d;

    device const uint8_t * idx = xb->qs + 4*blk32 + 2*half16;

    // one 16-bit word per block of 32: bit 15 = delta sign, bits 12..14 = scale,
    // low 12 bits = 3 high index bits for each of the four grid entries
    const uint16_t qh = xb->qh[blk32];

    const float dl = d * (2*((qh >> 12) & 7) + 1);
    const float ml = dl * (qh & 0x8000 ? -1 - IQ1S_DELTA : -1 + IQ1S_DELTA);

    // move the 6 high-index bits of this half down to bit 0
    const uint16_t hi = qh >> (6*half16);

    constant uint8_t * grid_a = (constant uint8_t *)(iq1s_grid_gpu + (idx[0] | ((hi << 8) & 0x700)));
    constant uint8_t * grid_b = (constant uint8_t *)(iq1s_grid_gpu + (idx[1] | ((hi << 5) & 0x700)));

    for (int i = 0; i < 4; ++i) {
        const uint8_t a = grid_a[i];
        const uint8_t b = grid_b[i];

        reg[0][i] = dl * (a & 0xf) + ml;
        reg[1][i] = dl * (a >> 4)  + ml;
        reg[2][i] = dl * (b & 0xf) + ml;
        reg[3][i] = dl * (b >> 4)  + ml;
    }
}
// Decodes 16 consecutive IQ1_M weights from a super-block into `reg`.
//   xb  - super-block to read from
//   il  - 16-weight chunk index (0...15 for QK_K = 256); il/2 selects the block of 32,
//         il%2 selects its first or second half
//   reg - 4x4 output; rows 0-1 come from the first grid entry (low / high nibbles),
//         rows 2-3 from the second
template <typename type4x4>
void dequantize_iq1_m(device const block_iq1_m * xb, short il, thread type4x4 & reg) {
    const int   blk32  = il/2;
    const short half16 = il%2;

    device const uint16_t * sc = (device const uint16_t *)xb->scales;

    // the super-block scale is a 16-bit value spread over the top nibbles of the four scale words
    iq1m_scale_t scale;
    scale.u16 = (sc[0] >> 12) | ((sc[1] >> 8) & 0x00f0) | ((sc[2] >> 4) & 0x0f00) | (sc[3] & 0xf000);
    const float d = scale.f16;

    device const uint8_t * idx = xb->qs + 4*blk32 + 2*half16;

    // one byte per half: low nibble belongs to the first grid entry, high nibble to the second;
    // within a nibble, bits 0..2 are the high index bits and bit 3 selects the delta sign
    const uint8_t hb = xb->qh[2*blk32 + half16];

    // 3-bit sub-scale for this half
    const float dl = d * (2*((sc[blk32/2] >> (6*(blk32%2) + 3*half16)) & 7) + 1);

    const float ml_a = dl * (hb & 0x08 ? -1 - IQ1M_DELTA : -1 + IQ1M_DELTA);
    const float ml_b = dl * (hb & 0x80 ? -1 - IQ1M_DELTA : -1 + IQ1M_DELTA);

    constant uint8_t * grid_a = (constant uint8_t *)(iq1s_grid_gpu + (idx[0] | ((hb << 8) & 0x700)));
    constant uint8_t * grid_b = (constant uint8_t *)(iq1s_grid_gpu + (idx[1] | ((hb << 4) & 0x700)));

    for (int i = 0; i < 4; ++i) {
        const uint8_t a = grid_a[i];
        const uint8_t b = grid_b[i];

        reg[0][i] = dl * (a & 0xf) + ml_a;
        reg[1][i] = dl * (a >> 4)  + ml_a;
        reg[2][i] = dl * (b & 0xf) + ml_b;
        reg[3][i] = dl * (b >> 4)  + ml_b;
    }
}
// Decodes 16 IQ4_NL weights from a block into `reg`.
//   xb  - block to read from
//   il  - nibble selector: the packed words are shifted right by 4*il before masking,
//         so il = 0 takes the low nibble of every byte and il = 1 the high nibble
//   reg - 4x4 output; row i holds the 4 values taken from the i-th 32-bit word of qs
template <typename type4x4>
void dequantize_iq4_nl(device const block_iq4_nl * xb, short il, thread type4x4 & reg) {
    const float d = xb->d;

    // qs is read as uint16 pairs and reassembled into 32-bit words
    device const uint16_t * packed = (device const uint16_t *)xb->qs;

    uint32_t nibbles;
    thread const uint8_t * lut_idx = (thread const uint8_t *)&nibbles; // byte view of `nibbles`

    for (int row = 0; row < 4; ++row) {
        nibbles = ((packed[2*row] | (packed[2*row+1] << 16)) >> (4*il)) & 0x0f0f0f0f;

        for (int col = 0; col < 4; ++col) {
            reg[row][col] = d * kvalues_iq4nl_f[lut_idx[col]];
        }
    }
}
// Decodes 4 IQ4_NL weights from a block into `reg`.
//   xb  - block to read from
//   il  - il%4 selects which 32-bit word of qs is read, il/4 selects the nibble
//         (the word is shifted right by 4*(il/4) before masking)
//   reg - 4-component output
template <typename type4>
void dequantize_iq4_nl_t4(device const block_iq4_nl * xb, short il, thread type4 & reg) {
    const float d = xb->d;

    device const uint16_t * packed = (device const uint16_t *)xb->qs;

    const short word   = il%4;
    const short nibble = il/4;

    uint32_t nibbles;
    thread const uint8_t * lut_idx = (thread const uint8_t *)&nibbles; // byte view of `nibbles`

    nibbles = ((packed[2*word] | (packed[2*word+1] << 16)) >> (4*nibble)) & 0x0f0f0f0f;

    for (int c = 0; c < 4; ++c) {
        reg[c] = d * kvalues_iq4nl_f[lut_idx[c]];
    }
}
// Decodes 16 consecutive IQ4_XS weights from a super-block into `reg`.
//   xb  - super-block to read from
//   il  - 16-weight chunk index (0...15 for QK_K = 256); il/2 selects the block of 32,
//         il%2 selects the low (0) or high (1) nibble of every byte in that block
//   reg - 4x4 output; row i holds the 4 values taken from the i-th 32-bit word of the block
template <typename type4x4>
void dequantize_iq4_xs(device const block_iq4_xs * xb, short il, thread type4x4 & reg) {
    const int   blk32  = il/2;
    const short nibble = il%2;

    // 4 x uint32 of packed nibbles per block of 32
    device const uint32_t * packed = (device const uint32_t *)xb->qs + 4*blk32;

    // 6-bit block scale: low 4 bits from scales_l (two blocks per byte), high 2 bits from scales_h
    const int ls_lo = (xb->scales_l[blk32/2] >> (4*(blk32%2))) & 0xf;
    const int ls_hi = ((xb->scales_h >> (2*blk32)) & 3) << 4;
    const int ls    = ls_lo | ls_hi;

    // the stored scale is offset by 32
    const float d = (float)xb->d * (ls - 32);

    uint32_t nibbles;
    thread const uint8_t * lut_idx = (thread const uint8_t *)&nibbles; // byte view of `nibbles`

    for (int row = 0; row < 4; ++row) {
        nibbles = (packed[row] >> (4*nibble)) & 0x0f0f0f0f;

        for (int col = 0; col < 4; ++col) {
            reg[row][col] = d * kvalues_iq4nl_f[lut_idx[col]];
        }
    }
}
File diff suppressed because it is too large Load Diff
@@ -1,250 +0,0 @@
#include "common.h"
// Metal function constants that specialize kernel_gated_delta_net_impl.
// Inside the kernel they are aliased via #define as:
//   FC_gated_delta_net_ne20 -> S_v (row length of the per-head state, also used for the 1/sqrt(S_v) output scale)
//   FC_gated_delta_net_ne30 -> G   (G == 1: one gate value per head and token; otherwise one gate value per state column)
//   FC_gated_delta_net_K    -> K   (number of state snapshot slots written to dst)
constant short FC_gated_delta_net_ne20 [[function_constant(FC_GATED_DELTA_NET + 0)]];
constant short FC_gated_delta_net_ne30 [[function_constant(FC_GATED_DELTA_NET + 1)]];
constant short FC_gated_delta_net_K [[function_constant(FC_GATED_DELTA_NET + 2)]];
#if 1
// Gated delta-net recurrence over a sequence of tokens (f32).
//
// Each SIMD group (selected by tgpig.x*NSG + ty) owns one row i20 of the per-head state;
// every thread of the group keeps NSG consecutive elements of that row in registers (ls[]).
// For every token t the kernel:
//   1. decays the state by exp(g)         (G == 1: one gate per head, otherwise one gate per column)
//   2. computes s_k = <state row, k>      (reduced across the SIMD group)
//   3. applies the delta update           state += k * (v[i20] - s_k) * beta
//   4. emits y = <state row, q> * scale   to the attention part of dst
//
// dst layout: [attention outputs | state snapshot slot 0 | slot 1 | ... ]
//
// Template parameter:
//   NSG - state elements held per thread; also the number of state rows handled per threadgroup
//
// Arguments:
//   args         - kernel arguments (shapes and strides)
//   q, k, v      - query / key / value tensors
//   g, b         - gate and beta tensors
//   s            - input state
//   dst          - output buffer (attention outputs followed by state snapshots)
//   tgpig, tpitg - threadgroup position in grid / thread position in threadgroup
//   ntg          - threads per threadgroup (unused)
template<short NSG>
kernel void kernel_gated_delta_net_impl(
        constant ggml_metal_kargs_gated_delta_net & args,
        device const char * q,
        device const char * k,
        device const char * v,
        device const char * g,
        device const char * b,
        device const char * s,
        device       char * dst,
        uint3 tgpig[[threadgroup_position_in_grid]],
        uint3 tpitg[[thread_position_in_threadgroup]],
        uint3   ntg[[threads_per_threadgroup]]) {
#define S_v FC_gated_delta_net_ne20
#define G   FC_gated_delta_net_ne30
#define K   FC_gated_delta_net_K

    const uint tx = tpitg.x;
    const uint ty = tpitg.y;

    const uint i23 = tgpig.z; // B (n_seqs)
    const uint i21 = tgpig.y; // H (head)
    const uint i20 = tgpig.x*NSG + ty; // row within S_v

    // q and k may have fewer heads than v: wrap the head index
    const uint i01 = i21 % args.ne01;
    const uint i11 = i21 % args.ne11;

    const float scale = 1.0f / sqrt((float)S_v);

    // input state layout [S_v, S_v, H, n_seqs] (s0 only): per-seq stride is H*D.
    // state is stored transposed: M[i20][is] = S[is][i20], so row i20 is contiguous
    const uint state_in_base = (i23*args.ne21 + i21)*S_v*S_v + i20*S_v;
    device const float * s_ptr = (device const float *) (s) + state_in_base;

    // per-thread slice of the state row
    float ls[NSG];

    FOR_UNROLL (short j = 0; j < NSG; j++) {
        const short is = tx*NSG + j;
        ls[j] = s_ptr[is];
    }

    device float * dst_attn = (device float *) (dst) + (i23*args.ne22*args.ne21 + i21)*S_v + i20;

    device const float * q_ptr = (device const float *) (q + i23*args.nb03 + i01*args.nb01);
    device const float * k_ptr = (device const float *) (k + i23*args.nb13 + i11*args.nb11);
    device const float * v_ptr = (device const float *) (v + i23*args.nb23 + i21*args.nb21);

    device const float * b_ptr = (device const float *) (b) + (i23*args.ne22*args.ne21 + i21);
    device const float * g_ptr = (device const float *) (g) + (i23*args.ne22*args.ne21 + i21)*G;

    // snapshot slot mapping: slot 0 = most recent state, slot s = s tokens back.
    // When n_tokens < K, only slots 0..n_tokens-1 are written; older slots are caller-owned.

    // output state base offset: after attention scores
    const uint attn_size = args.ne22 * args.ne21 * S_v * args.ne23;
    // output state per-slot size: S_v * S_v * H * n_seqs
    const uint state_size_per_snap = S_v * S_v * args.ne21 * args.ne23;
    // per-(seq,head) offset within a slot
    const uint state_out_base = (i23*args.ne21 + i21)*S_v*S_v + i20*S_v;

    // NOTE: the token counter is an int (not a short) so that it cannot wrap around
    //       when the number of tokens (args.ne22) exceeds 32767
    for (int t = 0; t < args.ne22; t++) {
        float s_k = 0.0f;

        if (G == 1) {
            const float g_exp = exp(g_ptr[0]);

            FOR_UNROLL (short j = 0; j < NSG; j++) {
                const short is = tx*NSG + j;
                ls[j] *= g_exp;

                s_k += ls[j]*k_ptr[is];
            }
        } else {
            // KDA
            FOR_UNROLL (short j = 0; j < NSG; j++) {
                const short is = tx*NSG + j;
                ls[j] *= exp(g_ptr[is]);

                s_k += ls[j]*k_ptr[is];
            }
        }

        s_k = simd_sum(s_k);

        const float d = (v_ptr[i20] - s_k)*b_ptr[0];

        float y = 0.0f;

        FOR_UNROLL (short j = 0; j < NSG; j++) {
            const short is = tx*NSG + j;
            ls[j] += k_ptr[is]*d;

            y += ls[j]*q_ptr[is];
        }

        y = simd_sum(y);

        // one thread per SIMD group writes the attention output of row i20 for token t
        if (tx == 0) {
            dst_attn[t*args.ne21*S_v] = y*scale;
        }

        // advance to the next token
        q_ptr += args.ns02;
        k_ptr += args.ns12;
        v_ptr += args.ns22;

        b_ptr += args.ne21;
        g_ptr += args.ne21*G;

        if (K > 1) {
            // the state after token t goes to slot (n_tokens - 1 - t), if that slot exists
            const int target_slot = (int)args.ne22 - 1 - t;
            if (target_slot >= 0 && target_slot < (int)K) {
                device float * dst_state = (device float *) (dst) + attn_size + (uint)target_slot * state_size_per_snap + state_out_base;

                FOR_UNROLL (short j = 0; j < NSG; j++) {
                    const short is = tx*NSG + j;
                    dst_state[is] = ls[j];
                }
            }
        }
    }

    // single-slot case: only the final state is written
    if (K == 1) {
        device float * dst_state = (device float *) (dst) + attn_size + state_out_base;

        FOR_UNROLL (short j = 0; j < NSG; j++) {
            const short is = tx*NSG + j;
            dst_state[is] = ls[j];
        }
    }

#undef S_v
#undef G
#undef K
}

typedef decltype(kernel_gated_delta_net_impl<4>) kernel_gated_delta_net_t;

template [[host_name("kernel_gated_delta_net_f32_1")]] kernel kernel_gated_delta_net_t kernel_gated_delta_net_impl<1>;
template [[host_name("kernel_gated_delta_net_f32_2")]] kernel kernel_gated_delta_net_t kernel_gated_delta_net_impl<2>;
template [[host_name("kernel_gated_delta_net_f32_4")]] kernel kernel_gated_delta_net_t kernel_gated_delta_net_impl<4>;
#else
// a simplified version of the above
// no performance improvement, so keep the above version for now
//
// This variant is compiled out (it lives in the #else branch of an `#if 1`).
// It views the per-thread state slice `lsf[NSG]` as a single vector of type T
// (float / float2 / float4) so that the per-token update is written with vector ops and dot().
// NOTE(review): unlike the enabled version, this variant reads and writes the state with a
//               stride of S_v (s_ptr[is*S_v]) and only stores one final state - it does not
//               use the K snapshot-slot constant. Confirm before re-enabling.
template<typename T, short NSG>
kernel void kernel_gated_delta_net_impl(
constant ggml_metal_kargs_gated_delta_net & args,
device const char * q,
device const char * k,
device const char * v,
device const char * g,
device const char * b,
device const char * s,
device char * dst,
uint3 tgpig[[threadgroup_position_in_grid]],
uint3 tpitg[[thread_position_in_threadgroup]],
uint3 ntg[[threads_per_threadgroup]]) {
#define S_v FC_gated_delta_net_ne20
#define G FC_gated_delta_net_ne30
const uint tx = tpitg.x;
const uint ty = tpitg.y;
const uint i23 = tgpig.z; // B
const uint i21 = tgpig.y; // H
const uint i20 = tgpig.x*NSG + ty;
// q and k head indices wrap around their own head counts
const uint i01 = i21 % args.ne01;
const uint i11 = i21 % args.ne11;
const float scale = 1.0f / sqrt((float)S_v);
device const float * s_ptr = (device const float *) (s) + (i23*args.ne21 + i21)*S_v*S_v + i20;
// per-thread slice of the state, loaded with stride S_v
float lsf[NSG];
FOR_UNROLL (short j = 0; j < NSG; j++) {
const short is = tx*NSG + j;
lsf[j] = s_ptr[is*S_v];
}
// vector view of the per-thread state slice
thread T * ls = (thread T *) (lsf);
device float * dst_attn = (device float *) (dst) + (i23*args.ne22*args.ne21 + i21)*S_v + i20;
device const float * q_ptr = (device const float *) (q + i23*args.nb03 + i01*args.nb01);
device const float * k_ptr = (device const float *) (k + i23*args.nb13 + i11*args.nb11);
device const float * v_ptr = (device const float *) (v + i23*args.nb23 + i21*args.nb21);
device const float * b_ptr = (device const float *) (b) + (i23*args.ne22*args.ne21 + i21);
device const float * g_ptr = (device const float *) (g) + (i23*args.ne22*args.ne21 + i21)*G;
for (short t = 0; t < args.ne22; t++) {
// vector views of the current token's q, k and g
device const T * qt_ptr = (device const T *) (q_ptr);
device const T * kt_ptr = (device const T *) (k_ptr);
device const T * gt_ptr = (device const T *) (g_ptr);
// decay the state
if (G == 1) {
*ls *= exp(g_ptr[0]);
} else {
// KDA
*ls *= exp(gt_ptr[tx]);
}
// delta update followed by the output dot product, both reduced across the SIMD group
const float s_k = simd_sum(dot(*ls, kt_ptr[tx]));
const float d = (v_ptr[i20] - s_k)*b_ptr[0];
*ls += kt_ptr[tx]*d;
const float y = simd_sum(dot(*ls, qt_ptr[tx]));
if (tx == 0) {
*dst_attn = y*scale;
}
// advance to the next token
q_ptr += args.ns02;
k_ptr += args.ns12;
v_ptr += args.ns22;
b_ptr += args.ne21;
g_ptr += args.ne21*G;
dst_attn += args.ne21*S_v;
}
// the final state is stored after the attention outputs
device float * dst_state = (device float *) (dst) + args.ne23*args.ne22*args.ne21*S_v + (i23*args.ne21 + i21)*S_v*S_v + i20;
// NOTE(review): dstt_state is never used
device T * dstt_state = (device T *) (dst_state);
FOR_UNROLL (short j = 0; j < NSG; j++) {
const short is = tx*NSG + j;
dst_state[is*S_v] = lsf[j];
}
#undef S_v
#undef G
}
typedef decltype(kernel_gated_delta_net_impl<float4, 4>) kernel_gated_delta_net_t;
template [[host_name("kernel_gated_delta_net_f32_1")]] kernel kernel_gated_delta_net_t kernel_gated_delta_net_impl<float, 1>;
template [[host_name("kernel_gated_delta_net_f32_2")]] kernel kernel_gated_delta_net_t kernel_gated_delta_net_impl<float2, 2>;
template [[host_name("kernel_gated_delta_net_f32_4")]] kernel kernel_gated_delta_net_t kernel_gated_delta_net_impl<float4, 4>;
#endif
-347
View File
@@ -1,347 +0,0 @@
#include "common.h"
// Writes, for every row of src0, the index of its maximum f32 element to dst (as int32).
// One threadgroup handles one row (row index = tgpig).
//   args  - kernel arguments (ne00 = row length, nb01 = row stride in bytes)
//   src0  - input rows
//   dst   - int32 output, one entry per row
//   shmem - threadgroup scratch, used only when the threadgroup has more than one SIMD group:
//           N_SIMDWIDTH floats (per-SIMD-group maxima) followed by N_SIMDWIDTH int32 (their indices)
// If no element of the row compares greater than -INFINITY the result is -1.
kernel void kernel_argmax_f32(
        constant ggml_metal_kargs_argmax & args,
        device const char * src0,
        device       char * dst,
        threadgroup  char * shmem [[threadgroup(0)]],
        uint  tgpig[[threadgroup_position_in_grid]],
        uint  tpitg[[thread_position_in_threadgroup]],
        uint  sgitg[[simdgroup_index_in_threadgroup]],
        uint  tiisg[[thread_index_in_simdgroup]],
        uint    ntg[[threads_per_threadgroup]]) {
    device const float * row = (device const float *) ((device const char *) src0 + tgpig * args.nb01);

    // per-thread scan over a strided subset of the row
    float   best_val = -INFINITY;
    int32_t best_idx = -1;

    for (int i00 = tpitg; i00 < args.ne00; i00 += ntg) {
        if (row[i00] > best_val) {
            best_val = row[i00];
            best_idx = i00;
        }
    }

    // reduce within the SIMD group: threads that do not hold the maximum contribute -1
    const float   sg_max = simd_max(best_val);
    const int32_t sg_arg = simd_max(select(-1, best_idx, best_val == sg_max));

    device int32_t * out = (device int32_t *) dst;

    // a single SIMD group already has the final answer
    if (ntg <= N_SIMDWIDTH) {
        out[tgpig] = sg_arg;
        return;
    }

    threadgroup float   * sg_max_buf = (threadgroup float   *) shmem;
    threadgroup int32_t * sg_arg_buf = (threadgroup int32_t *) shmem + N_SIMDWIDTH;

    // reset all slots so that the slots of absent SIMD groups cannot win
    if (sgitg == 0) {
        sg_max_buf[tiisg] = -INFINITY;
        sg_arg_buf[tiisg] = -1;
    }

    threadgroup_barrier(mem_flags::mem_threadgroup);

    // lane 0 of every SIMD group publishes its partial result
    if (tiisg == 0) {
        sg_max_buf[sgitg] = sg_max;
        sg_arg_buf[sgitg] = sg_arg;
    }

    threadgroup_barrier(mem_flags::mem_threadgroup);

    // every SIMD group reduces the published partials
    const float   part_max = sg_max_buf[tiisg];
    const int32_t part_arg = sg_arg_buf[tiisg];

    const float   tg_max = simd_max(part_max);
    const int32_t tg_arg = simd_max(select(-1, part_arg, part_max == tg_max));

    out[tgpig] = tg_arg;
}
// Builds a diagonal matrix: output row i1 is all zeros except for element i1,
// which receives the i1-th element of the source row.
// One threadgroup per output row; threads walk the row with a stride of N_SIMDWIDTH.
//   args  - kernel arguments (shapes and byte strides)
//   src0  - input (one row per (i2, i3))
//   dst   - output
// NOTE(review): the dst row offset uses args.nb01 (a src0 stride) while the other dst
//               offsets use args.nb2/args.nb3 - confirm this is intended.
kernel void kernel_diag_f32(
        constant ggml_metal_kargs_diag & args,
        device const char * src0,
        device       char * dst,
        uint3   tgpig[[threadgroup_position_in_grid]],
        ushort  tiitg[[thread_index_in_threadgroup]]) {
    constexpr short NW = N_SIMDWIDTH;

    const int32_t row   = tgpig.x; // i1
    const int32_t dim2  = tgpig.y; // i2
    const int32_t dim3  = tgpig.z; // i3

    device const float * src_row = (device const float *)(src0 + dim2*args.nb02 + dim3*args.nb03);
    device       float * dst_row = (device       float *)(dst  + row*args.nb01 + dim2*args.nb2 + dim3*args.nb3);

    for (int col = tiitg; col < args.ne0; col += NW) {
        if (col == row) {
            dst_row[col] = src_row[col];
        } else {
            dst_row[col] = 0.0f;
        }
    }
}
// Circularly shifts an f32 tensor along all four dimensions:
// dst[i3][i2][i1][i0] = src0[i3 - s3][i2 - s2][i1 - s1][i0 - s0], with indices wrapped once into range.
// One threadgroup per (i1, i2, i3); threads stride over i0.
// Both tensors are indexed as densely packed f32 arrays (element counts only, no byte strides).
//   args  - kernel arguments (shapes ne00..ne03 / ne0..ne3 and shifts s0..s3)
//   src0  - input
//   dst   - output
kernel void kernel_roll_f32(
        constant ggml_metal_kargs_roll & args,
        device const char * src0,
        device       char * dst,
        uint3 tgpig[[threadgroup_position_in_grid]],
        uint3 tpitg[[thread_position_in_threadgroup]],
        uint3   ntg[[threads_per_threadgroup]]) {
    const int64_t i1 = tgpig.x;
    const int64_t i2 = tgpig.y;
    const int64_t i3 = tgpig.z;

    device const float * src = (device const float *) src0;
    device       float * out = (device       float *) dst;

    // the source coordinates of the three outer dimensions do not depend on i0
    int64_t i01 = i1 - args.s1;
    int64_t i02 = i2 - args.s2;
    int64_t i03 = i3 - args.s3;

    i01 = i01 < 0 ? i01 + args.ne01 : (i01 >= args.ne01 ? i01 - args.ne01 : i01);
    i02 = i02 < 0 ? i02 + args.ne02 : (i02 >= args.ne02 ? i02 - args.ne02 : i02);
    i03 = i03 < 0 ? i03 + args.ne03 : (i03 >= args.ne03 ? i03 - args.ne03 : i03);

    // flat offsets of the source / destination rows
    const int64_t src_row = ((i03*args.ne02 + i02)*args.ne01 + i01)*args.ne00;
    const int64_t dst_row = ((i3 *args.ne2  + i2 )*args.ne1  + i1 )*args.ne0;

    for (int i0 = tpitg.x; i0 < args.ne0; i0 += ntg.x) {
        // apply the shift and wrap around
        int64_t i00 = i0 - args.s0;
        i00 = i00 < 0 ? i00 + args.ne00 : (i00 >= args.ne00 ? i00 - args.ne00 : i00);

        out[dst_row + i0] = src[src_row + i00];
    }
}
// Zero-pads a tensor: elements inside the source extents are copied, everything else is set to 0.
// tgpig.x encodes both the output row i1 and a 1024-element chunk index k0 along dimension 0
// (tgpig.x = k0*ne1 + i1); the threads of a threadgroup stride over that chunk.
//   T     - element type (instantiated below for float and float4)
//   args  - kernel arguments (shapes and byte strides)
//   src0  - input
//   dst   - output
template <typename T>
kernel void kernel_pad_impl(
        constant ggml_metal_kargs_pad & args,
        device const char * src0,
        device       char * dst,
        uint3 tgpig[[threadgroup_position_in_grid]],
        uint3 tpitg[[thread_position_in_threadgroup]],
        uint3   ntg[[threads_per_threadgroup]]) {
    const int32_t i3 = tgpig.z;
    const int32_t i2 = tgpig.y;

    const int32_t k0 = tgpig.x/args.ne1;
    const int32_t i1 = tgpig.x - k0*args.ne1;

    // source and destination rows share the same coordinates
    device const T * src_row = (device const T *) (src0 + i3*args.nb03 + i2*args.nb02 + i1*args.nb01);
    device       T * dst_row = (device       T *) (dst  + i3*args.nb3  + i2*args.nb2  + i1*args.nb1);

    // whether this output row exists in the source at all
    const bool row_in_src = i1 < args.ne01 && i2 < args.ne02 && i3 < args.ne03;

    for (int32_t l0 = 0; l0 < 1024; l0 += ntg.x) {
        const int32_t i0 = k0*1024 + tpitg.x + l0;

        // past the end of the output row
        if (i0 >= args.ne0) {
            break;
        }

        if (row_in_src && i0 < args.ne00) {
            dst_row[i0] = src_row[i0];
        } else {
            dst_row[i0] = 0.0f;
        }
    }
}

typedef decltype(kernel_pad_impl<float>) kernel_pad_t;

template [[host_name("kernel_pad_f32")]]   kernel kernel_pad_t kernel_pad_impl<float>;
template [[host_name("kernel_pad_f32_4")]] kernel kernel_pad_t kernel_pad_impl<float4>;
// TODO: this is slow - optimize
// Reflection padding along dimension 0: p0 elements are mirrored in front of each row
// and p1 elements behind it (the edge element itself is not repeated).
// One threadgroup per row; threads stride over the output row.
// Rows outside the source extents are left untouched.
//   args  - kernel arguments (shapes, byte strides, paddings p0 / p1)
//   src0  - input
//   dst   - output
//   tgpg  - threadgroups per grid (unused)
kernel void kernel_pad_reflect_1d_f32(
        constant ggml_metal_kargs_pad_reflect_1d & args,
        device const char * src0,
        device       char * dst,
        uint3 tgpig[[threadgroup_position_in_grid]],
        uint3  tgpg[[threadgroups_per_grid]],
        uint3 tpitg[[thread_position_in_threadgroup]],
        uint3   ntg[[threads_per_threadgroup]]) {
    const int64_t i1 = tgpig.x;
    const int64_t i2 = tgpig.y;
    const int64_t i3 = tgpig.z;

    // nothing to do for rows that do not exist in the source
    if (!(i1 < args.ne01 && i2 < args.ne02 && i3 < args.ne03)) {
        return;
    }

    device const float * src_row = (device const float *) (src0 + i3*args.nb03 + i2*args.nb02 + i1*args.nb01);
    device       float * dst_row = (device       float *) (dst  + i3*args.nb3  + i2*args.nb2  + i1*args.nb1);

    for (int i0 = tpitg.x; i0 < args.ne0; i0 += ntg.x) {
        if (i0 >= args.ne0 - args.p1) {
            // right padding: mirror around the last source element
            dst_row[i0] = src_row[(args.ne0 - args.p1 - args.p0) - (args.p1 + 1 - (args.ne0 - i0)) - 1];
        } else if (i0 >= args.p0) {
            // interior: plain copy
            dst_row[i0] = src_row[i0 - args.p0];
        } else {
            // left padding: mirror around the first source element
            dst_row[i0] = src_row[args.p0 - i0];
        }
    }
}
// Fills dst with the arithmetic sequence dst[i] = start + step*i for i in [0, ne0).
// The threads of the threadgroup stride over the output.
//   args  - kernel arguments (ne0, start, step)
//   dst   - f32 output
//   tgpig - threadgroup position in grid (unused)
kernel void kernel_arange_f32(
        constant ggml_metal_kargs_arange & args,
        device char * dst,
        uint3 tgpig[[threadgroup_position_in_grid]],
        uint3 tpitg[[thread_position_in_threadgroup]],
        uint3   ntg[[threads_per_threadgroup]]) {
    device float * out = (device float *) dst;

    int idx = tpitg.x;
    while (idx < args.ne0) {
        out[idx] = args.start + args.step * idx;
        idx += ntg.x;
    }
}
// Sinusoidal timestep embedding.
// One threadgroup per timestep (row = tgpig.x). For j in [0, dim/2):
//   out[j]         = cos(t * freq_j)
//   out[j + dim/2] = sin(t * freq_j),   freq_j = exp(-log(max_period) * j / (dim/2))
// When dim is odd, the element at index 2*(dim/2) is set to 0.
//   args  - kernel arguments (dim, max_period, nb1 = output row stride in bytes)
//   src0  - f32 timesteps
//   dst   - output embeddings
kernel void kernel_timestep_embedding_f32(
        constant ggml_metal_kargs_timestep_embedding & args,
        device const char * src0,
        device       char * dst,
        uint3 tgpig[[threadgroup_position_in_grid]],
        uint3 tpitg[[thread_position_in_threadgroup]],
        uint3   ntg[[threads_per_threadgroup]]) {
    const int row = tgpig.x;

    device float * out = (device float *)(dst + row*args.nb1);

    const int n_half = args.dim / 2;

    for (int j = tpitg.x; j < n_half; j += ntg.x) {
        const float t     = ((device float *)src0)[row];
        const float freq  = (float)exp(-log((float)args.max_period) * j / n_half);
        const float phase = t * freq;

        out[j         ] = cos(phase);
        out[j + n_half] = sin(phase);
    }

    // zero the leftover element of an odd-sized embedding
    if (tpitg.x == 0 && args.dim % 2 != 0) {
        out[2 * n_half] = 0.f;
    }
}
// One AdamW-style optimizer step, one thread per parameter.
//   args - kernel arguments (np = number of parameters)
//   x    - parameters, updated in place
//   g    - gradients
//   g_m  - first-moment accumulator, updated in place
//   g_v  - second-moment accumulator, updated in place
//   pars - 7 floats, in order: alpha, beta1, beta2, eps, wd, beta1h, beta2h
//          (beta1h / beta2h multiply the moments before use - presumably the
//           bias-correction factors; confirm against the host code)
kernel void kernel_opt_step_adamw_f32(
        constant ggml_metal_kargs_opt_step_adamw & args,
        device       float * x,
        device const float * g,
        device       float * g_m,
        device       float * g_v,
        device const float * pars,
        uint gid[[thread_position_in_grid]]) {
    // the grid may be larger than the parameter count
    if (gid >= args.np) {
        return;
    }

    const float lr      = pars[0];
    const float b1      = pars[1];
    const float b2      = pars[2];
    const float epsilon = pars[3];
    const float decay   = pars[4];
    const float b1_corr = pars[5];
    const float b2_corr = pars[6];

    const float grad = g[gid];

    // exponential moving averages of the gradient and the squared gradient
    const float m = g_m[gid] * b1 + grad * (1.0f - b1);
    const float v = g_v[gid] * b2 + grad * grad * (1.0f - b2);

    g_m[gid] = m;
    g_v[gid] = v;

    const float m_hat = m * b1_corr;
    const float denom = sqrt(v * b2_corr) + epsilon;

    // decoupled weight decay followed by the Adam step
    x[gid] = x[gid] * (1.0f - lr * decay) - lr * m_hat / denom;
}
// One SGD step with multiplicative decay, one thread per parameter:
//   x = x * (1 - pars[0]*pars[1]) - pars[0] * g
//   args - kernel arguments (np = number of parameters)
//   x    - parameters, updated in place
//   g    - gradients
//   pars - pars[0] scales the gradient, pars[1] is the decay factor
//          (presumably learning rate and weight decay; confirm against the host code)
kernel void kernel_opt_step_sgd_f32(
        constant ggml_metal_kargs_opt_step_sgd & args,
        device       float * x,
        device const float * g,
        device const float * pars,
        uint gid[[thread_position_in_grid]]) {
    // the grid may be larger than the parameter count
    if (gid < args.np) {
        x[gid] = x[gid] * (1.0f - pars[0] * pars[1]) - pars[0] * g[gid];
    }
}
// Stores args.val into dst, one element per thread (index = thread position in grid).
// No bounds check is performed on the thread index.
//   T    - element type (instantiated below for int64_t)
//   args - kernel arguments (val = fill value)
//   dst  - output buffer
template<typename T>
kernel void kernel_memset(
        constant ggml_metal_kargs_memset & args,
        device T * dst,
        uint tpig[[thread_position_in_grid]]) {
    device T * slot = dst + tpig;

    *slot = args.val;
}

typedef decltype(kernel_memset<int64_t>) kernel_memset_t;

template [[host_name("kernel_memset_i64")]] kernel kernel_memset_t kernel_memset<int64_t>;
// number of SIMD groups per threadgroup used by kernel_count_equal
constant short FC_count_equal_nsg [[function_constant(FC_COUNT_EQUAL + 0)]];

// Counts the positions at which src0 and src1 hold equal values and atomically adds
// the per-row count to *dst.
// One threadgroup per row (i1, i2, i3); threads stride over the row, the partial counts are
// reduced per SIMD group, exchanged through threadgroup memory and reduced once more.
//   T         - element type (instantiated below for int32_t)
//   args      - kernel arguments (shapes and byte strides of src0 / src1)
//   src0/src1 - inputs to compare
//   dst       - single atomic int accumulator
//   shmem_i32 - threadgroup scratch, one int32 per SIMD group
template<typename T>
kernel void kernel_count_equal(
        constant ggml_metal_kargs_count_equal & args,
        device   const char * src0,
        device   const char * src1,
        device   atomic_int * dst,
        threadgroup int32_t * shmem_i32 [[threadgroup(0)]],
        uint3   tgpig[[threadgroup_position_in_grid]],
        ushort3 tpitg[[thread_position_in_threadgroup]],
        ushort  sgitg[[simdgroup_index_in_threadgroup]],
        ushort  tiisg[[thread_index_in_simdgroup]],
        ushort3   ntg[[threads_per_threadgroup]]) {
    const short NSG = FC_count_equal_nsg;

    const int i3 = tgpig.z;
    const int i2 = tgpig.y;
    const int i1 = tgpig.x;

    // the whole threadgroup exits together, so the barrier below stays uniform
    if (i3 >= args.ne03 || i2 >= args.ne02 || i1 >= args.ne01) {
        return;
    }

    int sum = 0;

    device const char * base0 = src0 + i1*args.nb01 + i2*args.nb02 + i3*args.nb03;
    device const char * base1 = src1 + i1*args.nb11 + i2*args.nb12 + i3*args.nb13;

    for (int64_t i0 = tpitg.x; i0 < args.ne00; i0 += ntg.x) {
        const T v0 = *(device const T *)(base0 + i0*args.nb00);
        const T v1 = *(device const T *)(base1 + i0*args.nb10);
        sum += (v0 == v1);
    }

    // per-SIMD-group partial count
    sum = simd_sum(sum);

    if (tiisg == 0) {
        shmem_i32[sgitg] = sum;
    }

    threadgroup_barrier(mem_flags::mem_threadgroup);

    if (sgitg == 0) {
        // keep the reduction in integers: a float accumulator cannot represent
        // every count above 2^24 exactly
        int v = 0;
        if (tpitg.x < NSG) {
            v = shmem_i32[tpitg.x];
        }

        const int total = simd_sum(v);
        if (tpitg.x == 0) {
            atomic_fetch_add_explicit(dst, total, memory_order_relaxed);
        }
    }
}

typedef decltype(kernel_count_equal<int32_t>) kernel_count_equal_t;

template [[host_name("kernel_count_equal_i32")]] kernel kernel_count_equal_t kernel_count_equal<int32_t>;
-838
View File
@@ -1,838 +0,0 @@
#include "common.h"
#include "dequantize.h"
// Metal function constants that specialize kernel_mul_mm:
//   bc_inp - take the element-wise, bounds-checked input path (needed when K is not aligned)
//   bc_out - bounds-check the output tile before writing it
//   ne12   - used to split the batch index: i12 = im % ne12, i13 = im / ne12
//   ne13   - declared alongside ne12; not referenced in the code shown here
//   r2, r3 - divisors mapping the src1 batch indices (i12, i13) onto src0 batches
constant bool FC_mul_mm_bc_inp [[function_constant(FC_MUL_MM + 0)]];
constant bool FC_mul_mm_bc_out [[function_constant(FC_MUL_MM + 1)]];
constant short FC_mul_mm_ne12 [[function_constant(FC_MUL_MM + 2)]];
constant short FC_mul_mm_ne13 [[function_constant(FC_MUL_MM + 3)]];
constant short FC_mul_mm_r2 [[function_constant(FC_MUL_MM + 4)]];
constant short FC_mul_mm_r3 [[function_constant(FC_MUL_MM + 5)]];
// each block_q contains 16*nl weights
#ifdef GGML_METAL_HAS_TENSOR
// Matrix multiplication using the Metal tensor API (compiled when GGML_METAL_HAS_TENSOR is defined).
//
// Per batch `im` computes C(M,N) = A(M,K) x B(K,N), where A (srcA) may be quantized and B (srcB)
// is read directly from device memory. Each threadgroup produces one NRA x NRB output tile:
//   - phase 1: the threads cooperatively dequantize an NRA x N_MM_NK_TOTAL slice of A
//              into threadgroup memory (zero-padded outside M and K)
//   - phase 2: mpp::tensor_ops::matmul2d multiplies that slice with the matching slice of B
//              and accumulates into a cooperative tensor
// The loop repeats over K in steps of N_MM_NK_TOTAL; the result is stored as f32.
//
// Template parameters used in the body:
//   SA, SA_4x4      - element / 4x4 types of the dequantized A tile in threadgroup memory
//   block_q, nl     - quantization block type; each block_q contains 16*nl weights
//   dequantize_func - decodes 16 weights (chunk `il` of a block) into an SA_4x4
//   T0, T0_4x4      - element / 4x4 types of srcA when it is not quantized
//   T1              - element type of srcB
// The remaining template parameters are not referenced in this variant.
template<
typename SA, typename SA_4x4, typename SA_8x8,
typename SB, typename SB_2x4, typename SB_8x8,
typename block_q, short nl, void (*dequantize_func)(device const block_q *, short, thread SA_4x4 &),
typename T0, typename T0_4x4, typename T1, typename T1_2x4>
kernel void kernel_mul_mm(
constant ggml_metal_kargs_mul_mm & args,
device const char * srcA,
device const char * srcB,
device char * dst,
threadgroup char * shmem [[threadgroup(0)]],
uint3 tgpig [[threadgroup_position_in_grid]],
ushort tiitg [[thread_index_in_threadgroup]],
ushort sgitg [[simdgroup_index_in_threadgroup]]) {
// sgitg is not needed by this variant
(void) sgitg;
// Matrix dimensions: A(M,K) x B(K,N) -> C(M,N)
const int K = args.ne00;
const int M = args.ne0;
const int N = args.ne1;
// Batch dimension handling
const int im = tgpig.z;
const int i12 = im % FC_mul_mm_ne12;
const int i13 = im / FC_mul_mm_ne12;
// Batch offsets for srcA and srcB
const uint64_t offset0 = (i12/FC_mul_mm_r2)*args.nb02 + (i13/FC_mul_mm_r3)*args.nb03;
// Tile dimensions
constexpr int NRB = SZ_SIMDGROUP * N_MM_BLOCK_X * N_MM_SIMD_GROUP_X;
constexpr int NRA = SZ_SIMDGROUP * N_MM_BLOCK_Y * N_MM_SIMD_GROUP_Y;
// Tile offsets in output matrix
const int ra = tgpig.y * NRA;
const int rb = tgpig.x * NRB;
// Threadgroup memory for dequantized A tile only
threadgroup SA * sa = (threadgroup SA *)(shmem);
// Work-item count for A loading
// (one work item = 16 consecutive K values of one row)
constexpr int A_WORK_ITEMS = NRA * N_MM_NK;
constexpr int NUM_THREADS = N_SIMDWIDTH * N_MM_SIMD_GROUP_X * N_MM_SIMD_GROUP_Y;
// tA wraps threadgroup memory
auto tA = tensor(sa, dextents<int32_t, 2>(N_MM_NK_TOTAL, NRA));
// tB wraps device memory directly
device T1 * ptrB = (device T1 *)(srcB + args.nb12*i12 + args.nb13*i13);
const int strideB = args.nb11 / sizeof(T1);
auto tB = tensor(ptrB, dextents<int32_t, 2>(K, N), array<int, 2>({1, strideB}));
// Configure matmul operation
mpp::tensor_ops::matmul2d<
mpp::tensor_ops::matmul2d_descriptor(
NRB, NRA, N_MM_NK_TOTAL, false, true, true,
mpp::tensor_ops::matmul2d_descriptor::mode::multiply_accumulate),
execution_simdgroups<N_MM_SIMD_GROUP_X * N_MM_SIMD_GROUP_Y>> mm;
auto cT = mm.get_destination_cooperative_tensor<decltype(tB), decltype(tA), float>();
// Accumulate partial results over K dimension
for (int loop_k = 0; loop_k < K; loop_k += N_MM_NK_TOTAL) {
// === PHASE 1: Dequantization of A into threadgroup memory ===
for (int work = tiitg; work < A_WORK_ITEMS; work += NUM_THREADS) {
const int row = work / N_MM_NK;
const int k_chunk = work % N_MM_NK;
// k_pos is the absolute K position of this chunk, k_base its position inside the tile
const int k_pos = loop_k + k_chunk * 16;
const short k_base = k_chunk * 16;
// Bounds check: skip device read if row is out of matrix bounds
if (ra + row < M) {
if (is_same<T0_4x4, block_q>::value && FC_mul_mm_bc_inp) {
// Element-wise reads when K is not aligned (nb01 not aligned for half4x4/float4x4).
// MSL spec Table 2.5: half4x4 requires 8-byte alignment. When K is odd,
// nb01 = K*2 is not 8-byte aligned, so odd-row pointers are misaligned.
// Mirrors the legacy kernel's existing guard.
device const T0 * row_ptr = (device const T0 *)(srcA + args.nb01 * (ra + row) + offset0);
FOR_UNROLL (short i = 0; i < 16; i++) {
sa[row * N_MM_NK_TOTAL + (k_base + i)] = (k_pos + i < K) ? (SA) row_ptr[k_pos + i] : (SA)0;
}
} else {
// locate the quantization block and the 16-weight chunk inside it
const int block_idx = k_pos / (16 * nl);
const short il = (k_pos / 16) % nl;
device const block_q * row_ptr = (device const block_q *)(srcA + args.nb01 * (ra + row) + offset0);
SA_4x4 temp_a;
dequantize_func(row_ptr + block_idx, il, temp_a);
FOR_UNROLL (short i = 0; i < 16; i++) {
// Zero-pad A for K positions beyond valid range (handles partial K iterations)
sa[row * N_MM_NK_TOTAL + (k_base + i)] = (k_pos + i < K) ? temp_a[i/4][i%4] : (SA)0;
}
}
} else {
// Zero-pad rows beyond matrix bounds
FOR_UNROLL (short i = 0; i < 16; i++) {
sa[row * N_MM_NK_TOTAL + (k_base + i)] = (SA)0;
}
}
}
// make the A tile visible to all threads before the matmul reads it
threadgroup_barrier(mem_flags::mem_threadgroup);
// === PHASE 2: Tensor matmul ===
auto mA = tA.slice(0, 0);
auto mB = tB.slice(loop_k, rb);
mm.run(mB, mA, cT);
// the A tile is overwritten in the next iteration: wait until everyone is done with it
threadgroup_barrier(mem_flags::mem_threadgroup);
}
// Store result tile to output matrix (with batch offset)
// cT.store handles bounds checking via tD's extents (M, N)
// NOTE(review): im * N * M is evaluated in int - confirm it cannot overflow for large batched outputs
device float * dstBatch = (device float *)dst + im * N * M;
auto tD = tensor(dstBatch, dextents<int32_t, 2>(M, N), array<int, 2>({1, M}));
cT.store(tD.slice(ra, rb));
}
#else
template<
typename S0, typename S0_4x4, typename S0_8x8,
typename S1, typename S1_2x4, typename S1_8x8,
typename block_q, short nl, void (*dequantize_func)(device const block_q *, short, thread S0_4x4 &),
typename T0, typename T0_4x4, typename T1, typename T1_2x4>
kernel void kernel_mul_mm(
constant ggml_metal_kargs_mul_mm & args,
device const char * src0,
device const char * src1,
device char * dst,
threadgroup char * shmem [[threadgroup(0)]],
uint3 tgpig[[threadgroup_position_in_grid]],
ushort tiitg[[thread_index_in_threadgroup]],
ushort sgitg[[simdgroup_index_in_threadgroup]]) {
threadgroup S0 * sa = (threadgroup S0 *)(shmem);
threadgroup S1 * sb = (threadgroup S1 *)(shmem + 4096);
constexpr int NR0 = 64;
constexpr int NR1 = 32;
constexpr int NK = 32;
constexpr int NL0 = NK/16;
constexpr int NL1 = NK/8;
const int im = tgpig.z;
const int r0 = tgpig.y*NR0;
const int r1 = tgpig.x*NR1;
// if this block is of 64x32 shape or smaller
const short nr0 = (args.ne0 - r0 < NR0) ? (args.ne0 - r0) : NR0;
const short nr1 = (args.ne1 - r1 < NR1) ? (args.ne1 - r1) : NR1;
// a thread shouldn't load data outside of the matrix
const short lr0 = ((short)tiitg/NL0) < nr0 ? ((short)tiitg/NL0) : nr0 - 1; // 0 .. 63
const short lr1 = ((short)tiitg/NL1) < nr1 ? ((short)tiitg/NL1) : nr1 - 1; // 0 .. 31
const short il0 = (tiitg % NL0);
short il = il0;
const int i12 = im % FC_mul_mm_ne12;
const int i13 = im / FC_mul_mm_ne12;
const uint64_t offset0 = (i12/FC_mul_mm_r2)*args.nb02 + (i13/FC_mul_mm_r3)*args.nb03;
const short offset1 = il0/nl;
device const block_q * x = (device const block_q *)(src0 + args.nb01*(r0 + lr0) + offset0) + offset1;
const short iy = 8*(tiitg % NL1);
device const T1 * y = (device const T1 *)(src1
+ args.nb13*i13
+ args.nb12*i12
+ args.nb11*(r1 + lr1)
+ args.nb10*iy);
S0_8x8 ma[4];
S1_8x8 mb[2];
simdgroup_float8x8 mc[8];
for (short i = 0; i < 8; i++){
mc[i] = make_filled_simdgroup_matrix<float, 8>(0.f);
}
for (int loop_k = 0; loop_k < args.ne00; loop_k += NK) {
// load data and store to threadgroup memory
if (is_same<T0_4x4, block_q>::value && FC_mul_mm_bc_inp) {
threadgroup_barrier(mem_flags::mem_threadgroup);
// no need for dequantization
for (short i = 0; i < 16; i++) {
const short sx = 2*il0 + i/8;
const short sy = (tiitg/NL0)/8;
//const short lx = i%8;
//const short ly = (tiitg/NL0)%8;
const short lx = (tiitg/NL0)%8;
const short ly = i%8;
const short ib = 8*sx + sy;
*(sa + 64*ib + 8*ly + lx) = loop_k + 16*il + i < args.ne00 ? *((device T0 *) x + i) : 0;
}
} else {
S0_4x4 temp_a;
dequantize_func(x, il, temp_a);
threadgroup_barrier(mem_flags::mem_threadgroup);
FOR_UNROLL (short i = 0; i < 16; i++) {
const short sx = 2*il0 + i/8;
const short sy = (tiitg/NL0)/8;
//const short lx = i%8;
//const short ly = (tiitg/NL0)%8;
const short lx = (tiitg/NL0)%8;
const short ly = i%8;
const short ib = 8*sx + sy;
// NOTE: this is massively slower.. WTF?
//sa[64*ib + 8*ly + lx] = temp_a[i/4][i%4];
*(sa + 64*ib + 8*ly + lx) = temp_a[i/4][i%4];
}
}
if (FC_mul_mm_bc_inp) {
for (short i = 0; i < 8; ++i) {
const short sx = (tiitg%NL1);
const short sy = (tiitg/NL1)/8;
const short lx = i;
const short ly = (tiitg/NL1)%8;
//const short lx = (tiitg/NL1)%8;
//const short ly = i;
const short ib = 4*sx + sy;
*(sb + 64*ib + 8*ly + lx) = loop_k + iy + i < args.ne00 ? (S1) *((device T1 *) y + i) : 0;
}
} else {
const short sx = (tiitg%NL1);
const short sy = (tiitg/NL1)/8;
//const short dx = sx;
//const short dy = sy;
const short ly = (tiitg/NL1)%8;
const short ib = 4*sx + sy;
*(threadgroup S1_2x4 *)(sb + 64*ib + 8*ly) = (S1_2x4)(*((device T1_2x4 *) y));
}
il = (il + 2 < nl) ? il + 2 : il % 2;
x = (il < 2) ? x + (2 + nl - 1)/nl : x;
y += NK;
threadgroup_barrier(mem_flags::mem_threadgroup);
// load matrices from threadgroup memory and conduct outer products
threadgroup const S0 * lsma = (sa + 4*64*(sgitg%2));
threadgroup const S1 * lsmb = (sb + 2*64*(sgitg/2));
FOR_UNROLL (short ik = 0; ik < NK/8; ik++) {
simdgroup_barrier(mem_flags::mem_none);
FOR_UNROLL (short i = 0; i < 4; i++) {
simdgroup_load(ma[i], lsma + 64*i, 8, 0, false);
}
simdgroup_barrier(mem_flags::mem_none);
FOR_UNROLL (short i = 0; i < 2; i++) {
simdgroup_load(mb[i], lsmb + 64*i, 8, 0, false);
}
simdgroup_barrier(mem_flags::mem_none);
FOR_UNROLL (short i = 0; i < 8; i++){
simdgroup_multiply_accumulate(mc[i], mb[i/4], ma[i%4], mc[i]);
}
lsma += 8*64;
lsmb += 4*64;
}
}
if (!FC_mul_mm_bc_out || (r0 + NR0 <= args.ne0 && r1 + NR1 <= args.ne1)) {
// if no bounds checks on the output are needed, we can directly write to device memory
device float * C = (device float *) dst +
(r0 + 32*(sgitg & 1)) + \
(r1 + 16*(sgitg >> 1)) * args.ne0 + im*args.ne1*args.ne0;
for (short i = 0; i < 8; i++) {
simdgroup_store(mc[i], C + 8*(i%4) + 8*args.ne0*(i/4), args.ne0, 0, false);
}
} else {
// block is smaller than 64x32, we should avoid writing data outside of the matrix
threadgroup_barrier(mem_flags::mem_threadgroup);
threadgroup float * temp_str = ((threadgroup float *) shmem) + 32*(sgitg&1) + (16*(sgitg >> 1))*NR0;
for (short i = 0; i < 8; i++) {
simdgroup_store(mc[i], temp_str + 8*(i%4) + 8*NR0*(i/4), NR0, 0, false);
}
threadgroup_barrier(mem_flags::mem_threadgroup);
if (sgitg == 0) {
for (int j = tiitg; j < nr1; j += NR1) {
device float * D = (device float *) dst + r0 + (r1 + j)*args.ne0 + im*args.ne1*args.ne0;
device float4 * D4 = (device float4 *) D;
threadgroup float * C = temp_str + (j*NR0);
threadgroup float4 * C4 = (threadgroup float4 *) C;
int i = 0;
for (; i < nr0/4; i++) {
*(D4 + i) = *(C4 + i);
}
i *= 4;
for (; i < nr0; i++) {
*(D + i) = *(C + i);
}
}
}
}
}
#endif // GGML_METAL_HAS_TENSOR
// Builds the per-expert routing tables consumed by kernel_mul_mm_id.
//
// One thread handles one expert (expert id == thread position in the threadgroup).
// For every token, the ne20 expert ids selected for that token are read from src2 and
// staged in threadgroup memory as uint16_t; each thread then scans the staged tokens
// and records the ones that selected its expert.
//
//   src2 : int32 expert ids, ne20 per token, token stride args.nb21 bytes
//   hids : int32 output, args.ne21 entries per expert; entry k of an expert is
//          token*ne20 + slot, where slot is the position of the expert in the token's list
//   htpe : uint32 output, number of entries recorded for each expert
//   shmem: staging area, ne20 uint16_t values per thread of the threadgroup
template<short ne20> // n_expert_used
kernel void kernel_mul_mm_id_map0(
        constant ggml_metal_kargs_mul_mm_id_map0 & args,
        device  const char * src2,
        device        char * htpe,
        device        char * hids,
        threadgroup   char * shmem [[threadgroup(0)]],
        ushort tpitg[[thread_position_in_threadgroup]],
        ushort   ntg[[threads_per_threadgroup]]) {
    const short expert = tpitg; // this thread's expert id

    // slice of the id table that belongs to this expert
    device int32_t * rows = (device int32_t *) hids + expert*args.ne21;

    uint32_t n_rows = 0;

    // walk over the tokens in chunks of ntg
    for (int tok0 = 0; tok0 < args.ne21; tok0 += ntg) {
        // cooperative load: thread tpitg stages the expert list of token tok0 + tpitg
        const int tok = tok0 + tpitg;

        if (tok < args.ne21) {
            device const int32_t * src = (device const int32_t *) (src2 + tok*args.nb21);

            threadgroup uint16_t * staged = (threadgroup uint16_t *) shmem + tpitg*ne20;

#pragma unroll(ne20)
            for (short k = 0; k < ne20; k++) {
                staged[k] = src[k];
            }
        }

        threadgroup_barrier(mem_flags::mem_threadgroup);

        // scan every staged token of this chunk, stopping at the end of the batch
        for (short t = 0; t < ntg && tok0 + t < args.ne21; t++) {
            threadgroup const uint16_t * staged = (threadgroup const uint16_t *) shmem + t*ne20;

            // slot + 1 when the token selected this expert, 0 otherwise
            short sel = 0;

#pragma unroll(ne20)
            for (short k = 0; k < ne20; k++) {
                sel += (staged[k] == expert) ? (k + 1) : 0;
            }

            // the entry is always stored; it is only kept (counter advanced) on a match,
            // otherwise the next candidate overwrites it
            rows[n_rows] = (tok0 + t)*ne20 + sel - 1;

            n_rows += (sel > 0) ? 1 : 0;
        }

        // the staging area is rewritten by the next chunk
        threadgroup_barrier(mem_flags::mem_threadgroup);
    }

    ((device uint32_t *) (htpe))[expert] = n_rows;
}
// Explicit instantiations of kernel_mul_mm_id_map0, one per supported value of the
// ne20 template parameter (n_expert_used). Each instantiation is given a host_name
// that ends in the ne20 value it was built for.
typedef decltype(kernel_mul_mm_id_map0<1>) kernel_mul_mm_id_map0_t;
template [[host_name("kernel_mul_mm_id_map0_ne20_1" )]] kernel kernel_mul_mm_id_map0_t kernel_mul_mm_id_map0<1>;
template [[host_name("kernel_mul_mm_id_map0_ne20_2" )]] kernel kernel_mul_mm_id_map0_t kernel_mul_mm_id_map0<2>;
template [[host_name("kernel_mul_mm_id_map0_ne20_4" )]] kernel kernel_mul_mm_id_map0_t kernel_mul_mm_id_map0<4>;
template [[host_name("kernel_mul_mm_id_map0_ne20_5" )]] kernel kernel_mul_mm_id_map0_t kernel_mul_mm_id_map0<5>;
template [[host_name("kernel_mul_mm_id_map0_ne20_6" )]] kernel kernel_mul_mm_id_map0_t kernel_mul_mm_id_map0<6>;
template [[host_name("kernel_mul_mm_id_map0_ne20_8" )]] kernel kernel_mul_mm_id_map0_t kernel_mul_mm_id_map0<8>;
template [[host_name("kernel_mul_mm_id_map0_ne20_10")]] kernel kernel_mul_mm_id_map0_t kernel_mul_mm_id_map0<10>;
template [[host_name("kernel_mul_mm_id_map0_ne20_16")]] kernel kernel_mul_mm_id_map0_t kernel_mul_mm_id_map0<16>;
template [[host_name("kernel_mul_mm_id_map0_ne20_22")]] kernel kernel_mul_mm_id_map0_t kernel_mul_mm_id_map0<22>;
// Indirect (expert-routed) matrix-matrix multiplication.
//
// Each threadgroup computes one NR0 x NR1 (64 x 32) tile for expert `im` (tgpig.z):
// NR0 rows of the expert's src0 matrix against NR1 routed rows of src1. The routed
// rows are described by the tables produced by kernel_mul_mm_id_map0:
//
//   htpe : uint32 per expert, number of rows routed to that expert
//   hids : int32, args.ne21 entries per expert; an entry is token*ne20 + slot
//
// Template parameters:
//   S0, S0_4x4, S0_8x8 : element / 4x4 / simdgroup 8x8 types used for the src0 tile in threadgroup memory
//   S1, S1_2x4, S1_8x8 : the same for the src1 tile
//   block_q, nl        : device block type of src0 and the number of 16-value chunks dequantize_func
//                        yields per block (il cycles in [0, nl))
//   dequantize_func    : expands chunk `il` of a block into a S0_4x4
//   T0, T0_4x4         : device element types of src0; when T0_4x4 == block_q and input bounds
//                        checking is enabled, src0 is loaded directly without dequantize_func
//   T1, T1_2x4         : device element types of src1
//
// Threadgroup memory layout: the src0 tile starts at shmem, the src1 tile at shmem + 4096 bytes.
// After the K loop the same memory is reused (from offset 0, as float) to stage the result tile.
//
// The tile decomposition and the write-back (rows strided by 4 simdgroups, lanes strided by 32)
// are written for 4 simdgroups of 32 threads per threadgroup.
template<typename S0, typename S0_4x4, typename S0_8x8, typename S1, typename S1_2x4, typename S1_8x8, typename block_q, short nl, void (*dequantize_func)(device const block_q *, short, thread S0_4x4 &), typename T0, typename T0_4x4, typename T1, typename T1_2x4>
kernel void kernel_mul_mm_id(
        constant ggml_metal_kargs_mul_mm_id & args,
        device const char * src0,
        device const char * src1,
        device const char * htpe,
        device const char * hids,
        device       char * dst,
        threadgroup  char * shmem [[threadgroup(0)]],
        uint3  tgpig[[threadgroup_position_in_grid]],
        ushort tiitg[[thread_index_in_threadgroup]],
        ushort tiisg[[thread_index_in_simdgroup]],
        ushort sgitg[[simdgroup_index_in_threadgroup]]) {
    threadgroup S0 * sa = (threadgroup S0 *)(shmem);
    threadgroup S1 * sb = (threadgroup S1 *)(shmem + 4096);

#ifdef GGML_METAL_HAS_TENSOR
    // result staging area, aliases the input tiles
    threadgroup float * sc = (threadgroup float *)(shmem);
#endif

    constexpr int NR0 = 64; // tile rows    (src0 rows)
    constexpr int NR1 = 32; // tile columns (routed src1 rows)

    constexpr int NK  = 32; // K elements consumed per loop iteration
    constexpr int NL0 = NK/16; // threads per src0 row, 16 values each
    constexpr int NL1 = NK/8;  // threads per src1 row,  8 values each

    const int im = tgpig.z; // expert
    const int r0 = tgpig.y*NR0;
    const int r1 = tgpig.x*NR1;

    device const uint32_t * tpe_u32 = (device const uint32_t *) (htpe);
    device const int32_t  * ids_i32 = (device const int32_t  *) (hids);

    // number of rows routed to this expert
    const int32_t neh1 = tpe_u32[im];

    // this tile lies entirely past the expert's rows - nothing to do
    if (r1 >= neh1) {
        return;
    }

    // if this block is of 64x32 shape or smaller
    const short nr0 = (args.ne0 - r0 < NR0) ? (args.ne0 - r0) : NR0;
    const short nr1 = (    neh1 - r1 < NR1) ? (    neh1 - r1) : NR1;

    // a thread shouldn't load data outside of the matrix
    const short lr0 = ((short)tiitg/NL0) < nr0 ? ((short)tiitg/NL0) : nr0 - 1; // 0 .. 63
    const short lr1 = ((short)tiitg/NL1) < nr1 ? ((short)tiitg/NL1) : nr1 - 1; // 0 .. 31

    const short il0 = (tiitg % NL0);

    short il = il0;

    // routing entry of the src1 row this thread loads: token*ne20 + slot
    const int id = ids_i32[im*args.ne21 + r1 + lr1];

    const short i11 = (id % args.ne20) % args.ne11;
    // token index: ranges up to ne21 - 1, so it must not be narrowed to a 16-bit short
    // (that would wrap for batches with more than 32767 tokens)
    const int   i12 = (id / args.ne20);
    const short i13 = 0;

    const uint64_t offset0 = im*args.nb02 + i13*args.nb03;
    const short    offset1 = il0/nl;

    device const block_q * x = (device const block_q *)(src0 + args.nb01*(r0 + lr0) + offset0) + offset1;

    const short iy = 8*(tiitg % NL1);

    device const T1 * y = (device const T1 *)(src1
        + args.nb13*i13
        + args.nb12*i12
        + args.nb11*i11
        + args.nb10*iy);

#ifndef GGML_METAL_HAS_TENSOR
    S0_8x8 ma[4];
    S1_8x8 mb[2];

    // per-simdgroup accumulators: 8 8x8 blocks
    simdgroup_float8x8 mc[8];

    for (short i = 0; i < 8; i++){
        mc[i] = make_filled_simdgroup_matrix<float, 8>(0.f);
    }
#else
    auto tA = tensor<threadgroup S0, dextents<int32_t, 2>, tensor_inline>(sa, dextents<int32_t, 2>(NK,  NR0));
    auto tB = tensor<threadgroup S1, dextents<int32_t, 2>, tensor_inline>(sb, dextents<int32_t, 2>(NR1, NK ));

    mpp::tensor_ops::matmul2d<
        mpp::tensor_ops::matmul2d_descriptor(NR1, NR0, NK, false, true, false, mpp::tensor_ops::matmul2d_descriptor::mode::multiply_accumulate),
        execution_simdgroups<4>> mm;

    auto cT = mm.get_destination_cooperative_tensor<decltype(tA), decltype(tB), float>();
#endif

    for (int loop_k = 0; loop_k < args.ne00; loop_k += NK) {
#ifndef GGML_METAL_HAS_TENSOR
        // load data and store to threadgroup memory
        if (is_same<T0_4x4, block_q>::value && FC_mul_mm_bc_inp) {
            threadgroup_barrier(mem_flags::mem_threadgroup);

            // no need for dequantization
            for (short i = 0; i < 16; i++) {
                const short sx = 2*il0 + i/8;
                const short sy = (tiitg/NL0)/8;

                //const short lx = i%8;
                //const short ly = (tiitg/NL0)%8;
                const short lx = (tiitg/NL0)%8;
                const short ly = i%8;

                const short ib = 8*sx + sy;

                // zero-fill past the end of the row
                *(sa + 64*ib + 8*ly + lx) = loop_k + 16*il + i < args.ne00 ? (S0) *((device T0 *) x + i) : (S0) 0;
            }
        } else {
            S0_4x4 temp_a;
            dequantize_func(x, il, temp_a);

            threadgroup_barrier(mem_flags::mem_threadgroup);

            FOR_UNROLL (short i = 0; i < 16; i++) {
                const short sx = 2*il0 + i/8;
                const short sy = (tiitg/NL0)/8;

                //const short lx = i%8;
                //const short ly = (tiitg/NL0)%8;
                const short lx = (tiitg/NL0)%8;
                const short ly = i%8;

                const short ib = 8*sx + sy;

                // NOTE: this is massively slower.. WTF?
                //sa[64*ib + 8*ly + lx] = temp_a[i/4][i%4];

                *(sa + 64*ib + 8*ly + lx) = temp_a[i/4][i%4];
            }
        }

        if (FC_mul_mm_bc_inp) {
            for (short i = 0; i < 8; ++i) {
                const short sx = (tiitg%NL1);
                const short sy = (tiitg/NL1)/8;

                const short lx = i;
                const short ly = (tiitg/NL1)%8;
                //const short lx = (tiitg/NL1)%8;
                //const short ly = i;

                const short ib = 4*sx + sy;

                // zero-fill past the end of the row
                *(sb + 64*ib + 8*ly + lx) = loop_k + iy + i < args.ne00 ? (S1) *((device T1 *) y + i) : 0;
            }
        } else {
            const short sx = (tiitg%NL1);
            const short sy = (tiitg/NL1)/8;

            //const short dx = sx;
            //const short dy = sy;

            const short ly = (tiitg/NL1)%8;

            const short ib = 4*sx + sy;

            // no bounds check needed: copy 8 values at once
            *(threadgroup S1_2x4 *)(sb + 64*ib + 8*ly) = (S1_2x4)(*((device T1_2x4 *) y));
        }
#else
        // load data and store to threadgroup memory
        if (is_same<T0_4x4, block_q>::value && FC_mul_mm_bc_inp) {
            threadgroup_barrier(mem_flags::mem_threadgroup);

            // no need for dequantization
            for (short i = 0; i < 16; i++) {
                const short sx = 2*il0 + i/8;
                const short sy = (tiitg/NL0)/8;

                const short lx = i%8;
                const short ly = (tiitg/NL0)%8;
                //const short lx = (tiitg/NL0)%8;
                //const short ly = i%8;

                *(sa + NK*(8*sy + ly) + 8*sx + lx) = loop_k + 16*il + i < args.ne00 ? *((device T0 *) x + i) : 0;
            }
        } else {
            S0_4x4 temp_a;
            dequantize_func(x, il, temp_a);

            threadgroup_barrier(mem_flags::mem_threadgroup);

            FOR_UNROLL (short i = 0; i < 16; i++) {
                const short sx = 2*il0 + i/8;
                const short sy = (tiitg/NL0)/8;

                const short lx = i%8;
                const short ly = (tiitg/NL0)%8;
                //const short lx = (tiitg/NL0)%8;
                //const short ly = i%8;

                *(sa + NK*(8*sy + ly) + 8*sx + lx) = temp_a[i/4][i%4];
            }
        }

        if (FC_mul_mm_bc_inp) {
            for (short i = 0; i < 8; ++i) {
                const short sx = (tiitg%NL1);
                const short sy = (tiitg/NL1)/8;

                const short lx = i;
                const short ly = (tiitg/NL1)%8;
                //const short lx = (tiitg/NL1)%8;
                //const short ly = i;

                *(sb + NK*(8*sy + ly) + 8*sx + lx) = loop_k + iy + i < args.ne00 ? (S1) *((device T1 *) y + i) : 0;
            }
        } else {
            const short sx = (tiitg%NL1);
            const short sy = (tiitg/NL1)/8;

            //const short lx = i;
            const short ly = (tiitg/NL1)%8;
            //const short lx = (tiitg/NL1)%8;
            //const short ly = i;

            *(threadgroup S1_2x4 *)(sb + NK*(8*sy + ly) + 8*sx) = (S1_2x4)(*((device T1_2x4 *) y));
        }
#endif

        // advance by NK elements along K: the chunk index steps by 2 (one per thread of the row)
        // and, when it wraps, x moves forward by ceil(2/nl) blocks
        il = (il + 2 < nl) ? il + 2 : il % 2;
        x  = (il < 2) ? x + (2 + nl - 1)/nl : x;

        y += NK;

        threadgroup_barrier(mem_flags::mem_threadgroup);

#ifndef GGML_METAL_HAS_TENSOR
        // load matrices from threadgroup memory and conduct outer products
        threadgroup const S0 * lsma = (sa + 4*64*(sgitg%2));
        threadgroup const S1 * lsmb = (sb + 2*64*(sgitg/2));

        FOR_UNROLL (short ik = 0; ik < NK/8; ik++) {
            simdgroup_barrier(mem_flags::mem_none);

            FOR_UNROLL (short i = 0; i < 4; i++) {
                simdgroup_load(ma[i], lsma + 64*i, 8, 0, false);
            }

            simdgroup_barrier(mem_flags::mem_none);

            FOR_UNROLL (short i = 0; i < 2; i++) {
                simdgroup_load(mb[i], lsmb + 64*i, 8, 0, false);
            }

            simdgroup_barrier(mem_flags::mem_none);

            FOR_UNROLL (short i = 0; i < 8; i++){
                simdgroup_multiply_accumulate(mc[i], mb[i/4], ma[i%4], mc[i]);
            }

            lsma += 8*64;
            lsmb += 4*64;
        }
#else
        auto sA = tA.slice(0, 0);
        auto sB = tB.slice(0, 0);

        mm.run(sB, sA, cT);
#endif
    }

    // The result tile is always staged through threadgroup memory: every row is scattered to a
    // destination chosen by its routing id, and only the nr0 x nr1 valid part is copied out.
    // The staging area aliases the input tiles, hence the barrier before it is written.
    threadgroup_barrier(mem_flags::mem_threadgroup);

#ifdef GGML_METAL_HAS_TENSOR
    auto tC = tensor<threadgroup float, dextents<int32_t, 2>, tensor_inline>(sc, dextents<int32_t, 2>(NR0, NR1));
    cT.store(tC);
#else
    // each simdgroup owns a 32 x 16 sub-tile of the NR0-wide staging area
    threadgroup float * temp_str = ((threadgroup float *) shmem) + 32*(sgitg&1) + (16*(sgitg >> 1))*NR0;

    for (short i = 0; i < 8; i++) {
        simdgroup_store(mc[i], temp_str + 8*(i%4) + 8*NR0*(i/4), NR0, 0, false);
    }
#endif

    threadgroup_barrier(mem_flags::mem_threadgroup);

    // scatter: one staged row per simdgroup at a time, lanes cover the row
    for (short j = sgitg; j < nr1; j += 4) {
        const int id = ids_i32[im*args.ne21 + r1 + j];

        const short ide = id % args.ne20;
        // token index: kept in int, see i12 above
        const int   idt = id / args.ne20;

        device float  * D  = (device float  *) dst + r0 + ide*args.ne0 + idt*args.ne1*args.ne0;
        device float4 * D4 = (device float4 *) D;

        threadgroup float  * C  = (threadgroup float  *) shmem + j*NR0;
        threadgroup float4 * C4 = (threadgroup float4 *) C;

        // vectorized part: nr0/4 float4 values
        int i = tiisg;
        for (; i < nr0/4; i += 32) {
            *(D4 + i) = *(C4 + i);
        }

        // scalar tail: the remaining nr0 % 4 values
        i = (4*(nr0/4)) + tiisg;
        for (; i < nr0; i += 32) {
            *(D + i) = *(C + i);
        }
    }
}
//
// matrix-matrix multiplication
//
// Explicit instantiations of kernel_mul_mm. mul_mm_t is the function type taken from one
// instantiation and reused to declare all of them.
//
// Host names follow the pattern kernel_mul_mm_<src0 type>_<src1 type>. The instantiations are
// grouped by src1 type: the first group passes float/float2x4 as the last two template
// arguments (_f32 suffix), the second group passes half/half2x4 (_f16 suffix).
// NOTE(review): the template header of kernel_mul_mm is not shown here; the argument order
// presumably matches kernel_mul_mm_id - confirm against its declaration.
typedef decltype(kernel_mul_mm<half, half4x4, simdgroup_half8x8, half, half2x4, simdgroup_half8x8, float4x4, 1, dequantize_f32, float, float4x4, float, float2x4>) mul_mm_t;
template [[host_name("kernel_mul_mm_f32_f32")]] kernel mul_mm_t kernel_mul_mm<half, half4x4, simdgroup_half8x8, half, half2x4, simdgroup_half8x8, float4x4, 1, dequantize_f32, float, float4x4, float, float2x4>;
template [[host_name("kernel_mul_mm_f16_f32")]] kernel mul_mm_t kernel_mul_mm<half, half4x4, simdgroup_half8x8, half, half2x4, simdgroup_half8x8, half4x4, 1, dequantize_f16, half, half4x4, float, float2x4>;
#if defined(GGML_METAL_HAS_BF16)
template [[host_name("kernel_mul_mm_bf16_f32")]] kernel mul_mm_t kernel_mul_mm<bfloat, bfloat4x4, simdgroup_bfloat8x8, bfloat, bfloat2x4, simdgroup_bfloat8x8, bfloat4x4, 1, dequantize_bf16, bfloat, bfloat4x4, float, float2x4>;
#endif
template [[host_name("kernel_mul_mm_q1_0_f32")]] kernel mul_mm_t kernel_mul_mm<half, half4x4, simdgroup_half8x8, half, half2x4, simdgroup_half8x8, block_q1_0, 8, dequantize_q1_0, float, float4x4, float, float2x4>;
template [[host_name("kernel_mul_mm_q4_0_f32")]] kernel mul_mm_t kernel_mul_mm<half, half4x4, simdgroup_half8x8, half, half2x4, simdgroup_half8x8, block_q4_0, 2, dequantize_q4_0, float, float4x4, float, float2x4>;
template [[host_name("kernel_mul_mm_q4_1_f32")]] kernel mul_mm_t kernel_mul_mm<half, half4x4, simdgroup_half8x8, half, half2x4, simdgroup_half8x8, block_q4_1, 2, dequantize_q4_1, float, float4x4, float, float2x4>;
template [[host_name("kernel_mul_mm_q5_0_f32")]] kernel mul_mm_t kernel_mul_mm<half, half4x4, simdgroup_half8x8, half, half2x4, simdgroup_half8x8, block_q5_0, 2, dequantize_q5_0, float, float4x4, float, float2x4>;
template [[host_name("kernel_mul_mm_q5_1_f32")]] kernel mul_mm_t kernel_mul_mm<half, half4x4, simdgroup_half8x8, half, half2x4, simdgroup_half8x8, block_q5_1, 2, dequantize_q5_1, float, float4x4, float, float2x4>;
template [[host_name("kernel_mul_mm_q8_0_f32")]] kernel mul_mm_t kernel_mul_mm<half, half4x4, simdgroup_half8x8, half, half2x4, simdgroup_half8x8, block_q8_0, 2, dequantize_q8_0, float, float4x4, float, float2x4>;
template [[host_name("kernel_mul_mm_mxfp4_f32")]] kernel mul_mm_t kernel_mul_mm<half, half4x4, simdgroup_half8x8, half, half2x4, simdgroup_half8x8, block_mxfp4, 2, dequantize_mxfp4, float, float4x4, float, float2x4>;
template [[host_name("kernel_mul_mm_q2_K_f32")]] kernel mul_mm_t kernel_mul_mm<half, half4x4, simdgroup_half8x8, half, half2x4, simdgroup_half8x8, block_q2_K, QK_NL, dequantize_q2_K, float, float4x4, float, float2x4>;
template [[host_name("kernel_mul_mm_q3_K_f32")]] kernel mul_mm_t kernel_mul_mm<half, half4x4, simdgroup_half8x8, half, half2x4, simdgroup_half8x8, block_q3_K, QK_NL, dequantize_q3_K, float, float4x4, float, float2x4>;
template [[host_name("kernel_mul_mm_q4_K_f32")]] kernel mul_mm_t kernel_mul_mm<half, half4x4, simdgroup_half8x8, half, half2x4, simdgroup_half8x8, block_q4_K, QK_NL, dequantize_q4_K, float, float4x4, float, float2x4>;
template [[host_name("kernel_mul_mm_q5_K_f32")]] kernel mul_mm_t kernel_mul_mm<half, half4x4, simdgroup_half8x8, half, half2x4, simdgroup_half8x8, block_q5_K, QK_NL, dequantize_q5_K, float, float4x4, float, float2x4>;
template [[host_name("kernel_mul_mm_q6_K_f32")]] kernel mul_mm_t kernel_mul_mm<half, half4x4, simdgroup_half8x8, half, half2x4, simdgroup_half8x8, block_q6_K, QK_NL, dequantize_q6_K, float, float4x4, float, float2x4>;
template [[host_name("kernel_mul_mm_iq2_xxs_f32")]] kernel mul_mm_t kernel_mul_mm<half, half4x4, simdgroup_half8x8, half, half2x4, simdgroup_half8x8, block_iq2_xxs, QK_NL, dequantize_iq2_xxs, float, float4x4, float, float2x4>;
template [[host_name("kernel_mul_mm_iq2_xs_f32")]] kernel mul_mm_t kernel_mul_mm<half, half4x4, simdgroup_half8x8, half, half2x4, simdgroup_half8x8, block_iq2_xs, QK_NL, dequantize_iq2_xs, float, float4x4, float, float2x4>;
template [[host_name("kernel_mul_mm_iq3_xxs_f32")]] kernel mul_mm_t kernel_mul_mm<half, half4x4, simdgroup_half8x8, half, half2x4, simdgroup_half8x8, block_iq3_xxs, QK_NL, dequantize_iq3_xxs, float, float4x4, float, float2x4>;
template [[host_name("kernel_mul_mm_iq3_s_f32")]] kernel mul_mm_t kernel_mul_mm<half, half4x4, simdgroup_half8x8, half, half2x4, simdgroup_half8x8, block_iq3_s, QK_NL, dequantize_iq3_s, float, float4x4, float, float2x4>;
template [[host_name("kernel_mul_mm_iq2_s_f32")]] kernel mul_mm_t kernel_mul_mm<half, half4x4, simdgroup_half8x8, half, half2x4, simdgroup_half8x8, block_iq2_s, QK_NL, dequantize_iq2_s, float, float4x4, float, float2x4>;
template [[host_name("kernel_mul_mm_iq1_s_f32")]] kernel mul_mm_t kernel_mul_mm<half, half4x4, simdgroup_half8x8, half, half2x4, simdgroup_half8x8, block_iq1_s, QK_NL, dequantize_iq1_s, float, float4x4, float, float2x4>;
template [[host_name("kernel_mul_mm_iq1_m_f32")]] kernel mul_mm_t kernel_mul_mm<half, half4x4, simdgroup_half8x8, half, half2x4, simdgroup_half8x8, block_iq1_m, QK_NL, dequantize_iq1_m, float, float4x4, float, float2x4>;
template [[host_name("kernel_mul_mm_iq4_nl_f32")]] kernel mul_mm_t kernel_mul_mm<half, half4x4, simdgroup_half8x8, half, half2x4, simdgroup_half8x8, block_iq4_nl, 2, dequantize_iq4_nl, float, float4x4, float, float2x4>;
template [[host_name("kernel_mul_mm_iq4_xs_f32")]] kernel mul_mm_t kernel_mul_mm<half, half4x4, simdgroup_half8x8, half, half2x4, simdgroup_half8x8, block_iq4_xs, QK_NL, dequantize_iq4_xs, float, float4x4, float, float2x4>;
template [[host_name("kernel_mul_mm_f32_f16")]] kernel mul_mm_t kernel_mul_mm<half, half4x4, simdgroup_half8x8, half, half2x4, simdgroup_half8x8, float4x4, 1, dequantize_f32, float, float4x4, half, half2x4>;
template [[host_name("kernel_mul_mm_f16_f16")]] kernel mul_mm_t kernel_mul_mm<half, half4x4, simdgroup_half8x8, half, half2x4, simdgroup_half8x8, half4x4, 1, dequantize_f16, half, half4x4, half, half2x4>;
template [[host_name("kernel_mul_mm_q1_0_f16")]] kernel mul_mm_t kernel_mul_mm<half, half4x4, simdgroup_half8x8, half, half2x4, simdgroup_half8x8, block_q1_0, 8, dequantize_q1_0, float, float4x4, half, half2x4>;
template [[host_name("kernel_mul_mm_q4_0_f16")]] kernel mul_mm_t kernel_mul_mm<half, half4x4, simdgroup_half8x8, half, half2x4, simdgroup_half8x8, block_q4_0, 2, dequantize_q4_0, float, float4x4, half, half2x4>;
template [[host_name("kernel_mul_mm_q4_1_f16")]] kernel mul_mm_t kernel_mul_mm<half, half4x4, simdgroup_half8x8, half, half2x4, simdgroup_half8x8, block_q4_1, 2, dequantize_q4_1, float, float4x4, half, half2x4>;
template [[host_name("kernel_mul_mm_q5_0_f16")]] kernel mul_mm_t kernel_mul_mm<half, half4x4, simdgroup_half8x8, half, half2x4, simdgroup_half8x8, block_q5_0, 2, dequantize_q5_0, float, float4x4, half, half2x4>;
template [[host_name("kernel_mul_mm_q5_1_f16")]] kernel mul_mm_t kernel_mul_mm<half, half4x4, simdgroup_half8x8, half, half2x4, simdgroup_half8x8, block_q5_1, 2, dequantize_q5_1, float, float4x4, half, half2x4>;
template [[host_name("kernel_mul_mm_q8_0_f16")]] kernel mul_mm_t kernel_mul_mm<half, half4x4, simdgroup_half8x8, half, half2x4, simdgroup_half8x8, block_q8_0, 2, dequantize_q8_0, float, float4x4, half, half2x4>;
template [[host_name("kernel_mul_mm_mxfp4_f16")]] kernel mul_mm_t kernel_mul_mm<half, half4x4, simdgroup_half8x8, half, half2x4, simdgroup_half8x8, block_mxfp4, 2, dequantize_mxfp4, float, float4x4, half, half2x4>;
template [[host_name("kernel_mul_mm_q2_K_f16")]] kernel mul_mm_t kernel_mul_mm<half, half4x4, simdgroup_half8x8, half, half2x4, simdgroup_half8x8, block_q2_K, QK_NL, dequantize_q2_K, float, float4x4, half, half2x4>;
template [[host_name("kernel_mul_mm_q3_K_f16")]] kernel mul_mm_t kernel_mul_mm<half, half4x4, simdgroup_half8x8, half, half2x4, simdgroup_half8x8, block_q3_K, QK_NL, dequantize_q3_K, float, float4x4, half, half2x4>;
template [[host_name("kernel_mul_mm_q4_K_f16")]] kernel mul_mm_t kernel_mul_mm<half, half4x4, simdgroup_half8x8, half, half2x4, simdgroup_half8x8, block_q4_K, QK_NL, dequantize_q4_K, float, float4x4, half, half2x4>;
template [[host_name("kernel_mul_mm_q5_K_f16")]] kernel mul_mm_t kernel_mul_mm<half, half4x4, simdgroup_half8x8, half, half2x4, simdgroup_half8x8, block_q5_K, QK_NL, dequantize_q5_K, float, float4x4, half, half2x4>;
template [[host_name("kernel_mul_mm_q6_K_f16")]] kernel mul_mm_t kernel_mul_mm<half, half4x4, simdgroup_half8x8, half, half2x4, simdgroup_half8x8, block_q6_K, QK_NL, dequantize_q6_K, float, float4x4, half, half2x4>;
template [[host_name("kernel_mul_mm_iq2_xxs_f16")]] kernel mul_mm_t kernel_mul_mm<half, half4x4, simdgroup_half8x8, half, half2x4, simdgroup_half8x8, block_iq2_xxs, QK_NL, dequantize_iq2_xxs, float, float4x4, half, half2x4>;
template [[host_name("kernel_mul_mm_iq2_xs_f16")]] kernel mul_mm_t kernel_mul_mm<half, half4x4, simdgroup_half8x8, half, half2x4, simdgroup_half8x8, block_iq2_xs, QK_NL, dequantize_iq2_xs, float, float4x4, half, half2x4>;
template [[host_name("kernel_mul_mm_iq3_xxs_f16")]] kernel mul_mm_t kernel_mul_mm<half, half4x4, simdgroup_half8x8, half, half2x4, simdgroup_half8x8, block_iq3_xxs, QK_NL, dequantize_iq3_xxs, float, float4x4, half, half2x4>;
template [[host_name("kernel_mul_mm_iq3_s_f16")]] kernel mul_mm_t kernel_mul_mm<half, half4x4, simdgroup_half8x8, half, half2x4, simdgroup_half8x8, block_iq3_s, QK_NL, dequantize_iq3_s, float, float4x4, half, half2x4>;
template [[host_name("kernel_mul_mm_iq2_s_f16")]] kernel mul_mm_t kernel_mul_mm<half, half4x4, simdgroup_half8x8, half, half2x4, simdgroup_half8x8, block_iq2_s, QK_NL, dequantize_iq2_s, float, float4x4, half, half2x4>;
template [[host_name("kernel_mul_mm_iq1_s_f16")]] kernel mul_mm_t kernel_mul_mm<half, half4x4, simdgroup_half8x8, half, half2x4, simdgroup_half8x8, block_iq1_s, QK_NL, dequantize_iq1_s, float, float4x4, half, half2x4>;
template [[host_name("kernel_mul_mm_iq1_m_f16")]] kernel mul_mm_t kernel_mul_mm<half, half4x4, simdgroup_half8x8, half, half2x4, simdgroup_half8x8, block_iq1_m, QK_NL, dequantize_iq1_m, float, float4x4, half, half2x4>;
template [[host_name("kernel_mul_mm_iq4_nl_f16")]] kernel mul_mm_t kernel_mul_mm<half, half4x4, simdgroup_half8x8, half, half2x4, simdgroup_half8x8, block_iq4_nl, 2, dequantize_iq4_nl, float, float4x4, half, half2x4>;
template [[host_name("kernel_mul_mm_iq4_xs_f16")]] kernel mul_mm_t kernel_mul_mm<half, half4x4, simdgroup_half8x8, half, half2x4, simdgroup_half8x8, block_iq4_xs, QK_NL, dequantize_iq4_xs, float, float4x4, half, half2x4>;
//
// indirect matrix-matrix multiplication
//
// Explicit instantiations of kernel_mul_mm_id. The kernel's parameter list does not depend on
// its template arguments, so the function type of one instantiation (mul_mm_id) is reused to
// declare all of them.
//
// Template argument order: S0, S0_4x4, S0_8x8, S1, S1_2x4, S1_8x8, block_q, nl, dequantize_func,
// T0, T0_4x4, T1, T1_2x4. Host names follow the pattern kernel_mul_mm_id_<src0 type>_<src1 type>;
// the _f32 group passes float/float2x4 for T1/T1_2x4, the _f16 group passes half/half2x4.
typedef decltype(kernel_mul_mm_id<half, half4x4, simdgroup_half8x8, half, half2x4, simdgroup_half8x8, float4x4, 1, dequantize_f32, float, float4x4, float, float2x4>) mul_mm_id;
template [[host_name("kernel_mul_mm_id_f32_f32")]] kernel mul_mm_id kernel_mul_mm_id<half, half4x4, simdgroup_half8x8, half, half2x4, simdgroup_half8x8, float4x4, 1, dequantize_f32, float, float4x4, float, float2x4>;
template [[host_name("kernel_mul_mm_id_f16_f32")]] kernel mul_mm_id kernel_mul_mm_id<half, half4x4, simdgroup_half8x8, half, half2x4, simdgroup_half8x8, half4x4, 1, dequantize_f16, half, half4x4, float, float2x4>;
#if defined(GGML_METAL_HAS_BF16)
template [[host_name("kernel_mul_mm_id_bf16_f32")]] kernel mul_mm_id kernel_mul_mm_id<bfloat, bfloat4x4, simdgroup_bfloat8x8, bfloat, bfloat2x4, simdgroup_bfloat8x8, bfloat4x4, 1, dequantize_bf16, bfloat, bfloat4x4, float, float2x4>;
#endif
template [[host_name("kernel_mul_mm_id_q1_0_f32")]] kernel mul_mm_id kernel_mul_mm_id<half, half4x4, simdgroup_half8x8, half, half2x4, simdgroup_half8x8, block_q1_0, 8, dequantize_q1_0, float, float4x4, float, float2x4>;
template [[host_name("kernel_mul_mm_id_q4_0_f32")]] kernel mul_mm_id kernel_mul_mm_id<half, half4x4, simdgroup_half8x8, half, half2x4, simdgroup_half8x8, block_q4_0, 2, dequantize_q4_0, float, float4x4, float, float2x4>;
template [[host_name("kernel_mul_mm_id_q4_1_f32")]] kernel mul_mm_id kernel_mul_mm_id<half, half4x4, simdgroup_half8x8, half, half2x4, simdgroup_half8x8, block_q4_1, 2, dequantize_q4_1, float, float4x4, float, float2x4>;
template [[host_name("kernel_mul_mm_id_q5_0_f32")]] kernel mul_mm_id kernel_mul_mm_id<half, half4x4, simdgroup_half8x8, half, half2x4, simdgroup_half8x8, block_q5_0, 2, dequantize_q5_0, float, float4x4, float, float2x4>;
template [[host_name("kernel_mul_mm_id_q5_1_f32")]] kernel mul_mm_id kernel_mul_mm_id<half, half4x4, simdgroup_half8x8, half, half2x4, simdgroup_half8x8, block_q5_1, 2, dequantize_q5_1, float, float4x4, float, float2x4>;
template [[host_name("kernel_mul_mm_id_q8_0_f32")]] kernel mul_mm_id kernel_mul_mm_id<half, half4x4, simdgroup_half8x8, half, half2x4, simdgroup_half8x8, block_q8_0, 2, dequantize_q8_0, float, float4x4, float, float2x4>;
template [[host_name("kernel_mul_mm_id_mxfp4_f32")]] kernel mul_mm_id kernel_mul_mm_id<half, half4x4, simdgroup_half8x8, half, half2x4, simdgroup_half8x8, block_mxfp4, 2, dequantize_mxfp4, float, float4x4, float, float2x4>;
template [[host_name("kernel_mul_mm_id_q2_K_f32")]] kernel mul_mm_id kernel_mul_mm_id<half, half4x4, simdgroup_half8x8, half, half2x4, simdgroup_half8x8, block_q2_K, QK_NL, dequantize_q2_K, float, float4x4, float, float2x4>;
template [[host_name("kernel_mul_mm_id_q3_K_f32")]] kernel mul_mm_id kernel_mul_mm_id<half, half4x4, simdgroup_half8x8, half, half2x4, simdgroup_half8x8, block_q3_K, QK_NL, dequantize_q3_K, float, float4x4, float, float2x4>;
template [[host_name("kernel_mul_mm_id_q4_K_f32")]] kernel mul_mm_id kernel_mul_mm_id<half, half4x4, simdgroup_half8x8, half, half2x4, simdgroup_half8x8, block_q4_K, QK_NL, dequantize_q4_K, float, float4x4, float, float2x4>;
template [[host_name("kernel_mul_mm_id_q5_K_f32")]] kernel mul_mm_id kernel_mul_mm_id<half, half4x4, simdgroup_half8x8, half, half2x4, simdgroup_half8x8, block_q5_K, QK_NL, dequantize_q5_K, float, float4x4, float, float2x4>;
template [[host_name("kernel_mul_mm_id_q6_K_f32")]] kernel mul_mm_id kernel_mul_mm_id<half, half4x4, simdgroup_half8x8, half, half2x4, simdgroup_half8x8, block_q6_K, QK_NL, dequantize_q6_K, float, float4x4, float, float2x4>;
template [[host_name("kernel_mul_mm_id_iq2_xxs_f32")]] kernel mul_mm_id kernel_mul_mm_id<half, half4x4, simdgroup_half8x8, half, half2x4, simdgroup_half8x8, block_iq2_xxs, QK_NL, dequantize_iq2_xxs, float, float4x4, float, float2x4>;
template [[host_name("kernel_mul_mm_id_iq2_xs_f32")]] kernel mul_mm_id kernel_mul_mm_id<half, half4x4, simdgroup_half8x8, half, half2x4, simdgroup_half8x8, block_iq2_xs, QK_NL, dequantize_iq2_xs, float, float4x4, float, float2x4>;
template [[host_name("kernel_mul_mm_id_iq3_xxs_f32")]] kernel mul_mm_id kernel_mul_mm_id<half, half4x4, simdgroup_half8x8, half, half2x4, simdgroup_half8x8, block_iq3_xxs, QK_NL, dequantize_iq3_xxs, float, float4x4, float, float2x4>;
template [[host_name("kernel_mul_mm_id_iq3_s_f32")]] kernel mul_mm_id kernel_mul_mm_id<half, half4x4, simdgroup_half8x8, half, half2x4, simdgroup_half8x8, block_iq3_s, QK_NL, dequantize_iq3_s, float, float4x4, float, float2x4>;
template [[host_name("kernel_mul_mm_id_iq2_s_f32")]] kernel mul_mm_id kernel_mul_mm_id<half, half4x4, simdgroup_half8x8, half, half2x4, simdgroup_half8x8, block_iq2_s, QK_NL, dequantize_iq2_s, float, float4x4, float, float2x4>;
template [[host_name("kernel_mul_mm_id_iq1_s_f32")]] kernel mul_mm_id kernel_mul_mm_id<half, half4x4, simdgroup_half8x8, half, half2x4, simdgroup_half8x8, block_iq1_s, QK_NL, dequantize_iq1_s, float, float4x4, float, float2x4>;
template [[host_name("kernel_mul_mm_id_iq1_m_f32")]] kernel mul_mm_id kernel_mul_mm_id<half, half4x4, simdgroup_half8x8, half, half2x4, simdgroup_half8x8, block_iq1_m, QK_NL, dequantize_iq1_m, float, float4x4, float, float2x4>;
template [[host_name("kernel_mul_mm_id_iq4_nl_f32")]] kernel mul_mm_id kernel_mul_mm_id<half, half4x4, simdgroup_half8x8, half, half2x4, simdgroup_half8x8, block_iq4_nl, 2, dequantize_iq4_nl, float, float4x4, float, float2x4>;
template [[host_name("kernel_mul_mm_id_iq4_xs_f32")]] kernel mul_mm_id kernel_mul_mm_id<half, half4x4, simdgroup_half8x8, half, half2x4, simdgroup_half8x8, block_iq4_xs, QK_NL, dequantize_iq4_xs, float, float4x4, float, float2x4>;
template [[host_name("kernel_mul_mm_id_f32_f16")]] kernel mul_mm_id kernel_mul_mm_id<half, half4x4, simdgroup_half8x8, half, half2x4, simdgroup_half8x8, float4x4, 1, dequantize_f32, float, float4x4, half, half2x4>;
template [[host_name("kernel_mul_mm_id_f16_f16")]] kernel mul_mm_id kernel_mul_mm_id<half, half4x4, simdgroup_half8x8, half, half2x4, simdgroup_half8x8, half4x4, 1, dequantize_f16, half, half4x4, half, half2x4>;
template [[host_name("kernel_mul_mm_id_q1_0_f16")]] kernel mul_mm_id kernel_mul_mm_id<half, half4x4, simdgroup_half8x8, half, half2x4, simdgroup_half8x8, block_q1_0, 8, dequantize_q1_0, float, float4x4, half, half2x4>;
template [[host_name("kernel_mul_mm_id_q4_0_f16")]] kernel mul_mm_id kernel_mul_mm_id<half, half4x4, simdgroup_half8x8, half, half2x4, simdgroup_half8x8, block_q4_0, 2, dequantize_q4_0, float, float4x4, half, half2x4>;
template [[host_name("kernel_mul_mm_id_q4_1_f16")]] kernel mul_mm_id kernel_mul_mm_id<half, half4x4, simdgroup_half8x8, half, half2x4, simdgroup_half8x8, block_q4_1, 2, dequantize_q4_1, float, float4x4, half, half2x4>;
template [[host_name("kernel_mul_mm_id_q5_0_f16")]] kernel mul_mm_id kernel_mul_mm_id<half, half4x4, simdgroup_half8x8, half, half2x4, simdgroup_half8x8, block_q5_0, 2, dequantize_q5_0, float, float4x4, half, half2x4>;
template [[host_name("kernel_mul_mm_id_q5_1_f16")]] kernel mul_mm_id kernel_mul_mm_id<half, half4x4, simdgroup_half8x8, half, half2x4, simdgroup_half8x8, block_q5_1, 2, dequantize_q5_1, float, float4x4, half, half2x4>;
template [[host_name("kernel_mul_mm_id_q8_0_f16")]] kernel mul_mm_id kernel_mul_mm_id<half, half4x4, simdgroup_half8x8, half, half2x4, simdgroup_half8x8, block_q8_0, 2, dequantize_q8_0, float, float4x4, half, half2x4>;
template [[host_name("kernel_mul_mm_id_mxfp4_f16")]] kernel mul_mm_id kernel_mul_mm_id<half, half4x4, simdgroup_half8x8, half, half2x4, simdgroup_half8x8, block_mxfp4, 2, dequantize_mxfp4, float, float4x4, half, half2x4>;
template [[host_name("kernel_mul_mm_id_q2_K_f16")]] kernel mul_mm_id kernel_mul_mm_id<half, half4x4, simdgroup_half8x8, half, half2x4, simdgroup_half8x8, block_q2_K, QK_NL, dequantize_q2_K, float, float4x4, half, half2x4>;
template [[host_name("kernel_mul_mm_id_q3_K_f16")]] kernel mul_mm_id kernel_mul_mm_id<half, half4x4, simdgroup_half8x8, half, half2x4, simdgroup_half8x8, block_q3_K, QK_NL, dequantize_q3_K, float, float4x4, half, half2x4>;
template [[host_name("kernel_mul_mm_id_q4_K_f16")]] kernel mul_mm_id kernel_mul_mm_id<half, half4x4, simdgroup_half8x8, half, half2x4, simdgroup_half8x8, block_q4_K, QK_NL, dequantize_q4_K, float, float4x4, half, half2x4>;
template [[host_name("kernel_mul_mm_id_q5_K_f16")]] kernel mul_mm_id kernel_mul_mm_id<half, half4x4, simdgroup_half8x8, half, half2x4, simdgroup_half8x8, block_q5_K, QK_NL, dequantize_q5_K, float, float4x4, half, half2x4>;
template [[host_name("kernel_mul_mm_id_q6_K_f16")]] kernel mul_mm_id kernel_mul_mm_id<half, half4x4, simdgroup_half8x8, half, half2x4, simdgroup_half8x8, block_q6_K, QK_NL, dequantize_q6_K, float, float4x4, half, half2x4>;
template [[host_name("kernel_mul_mm_id_iq2_xxs_f16")]] kernel mul_mm_id kernel_mul_mm_id<half, half4x4, simdgroup_half8x8, half, half2x4, simdgroup_half8x8, block_iq2_xxs, QK_NL, dequantize_iq2_xxs, float, float4x4, half, half2x4>;
template [[host_name("kernel_mul_mm_id_iq2_xs_f16")]] kernel mul_mm_id kernel_mul_mm_id<half, half4x4, simdgroup_half8x8, half, half2x4, simdgroup_half8x8, block_iq2_xs, QK_NL, dequantize_iq2_xs, float, float4x4, half, half2x4>;
template [[host_name("kernel_mul_mm_id_iq3_xxs_f16")]] kernel mul_mm_id kernel_mul_mm_id<half, half4x4, simdgroup_half8x8, half, half2x4, simdgroup_half8x8, block_iq3_xxs, QK_NL, dequantize_iq3_xxs, float, float4x4, half, half2x4>;
template [[host_name("kernel_mul_mm_id_iq3_s_f16")]] kernel mul_mm_id kernel_mul_mm_id<half, half4x4, simdgroup_half8x8, half, half2x4, simdgroup_half8x8, block_iq3_s, QK_NL, dequantize_iq3_s, float, float4x4, half, half2x4>;
template [[host_name("kernel_mul_mm_id_iq2_s_f16")]] kernel mul_mm_id kernel_mul_mm_id<half, half4x4, simdgroup_half8x8, half, half2x4, simdgroup_half8x8, block_iq2_s, QK_NL, dequantize_iq2_s, float, float4x4, half, half2x4>;
template [[host_name("kernel_mul_mm_id_iq1_s_f16")]] kernel mul_mm_id kernel_mul_mm_id<half, half4x4, simdgroup_half8x8, half, half2x4, simdgroup_half8x8, block_iq1_s, QK_NL, dequantize_iq1_s, float, float4x4, half, half2x4>;
template [[host_name("kernel_mul_mm_id_iq1_m_f16")]] kernel mul_mm_id kernel_mul_mm_id<half, half4x4, simdgroup_half8x8, half, half2x4, simdgroup_half8x8, block_iq1_m, QK_NL, dequantize_iq1_m, float, float4x4, half, half2x4>;
template [[host_name("kernel_mul_mm_id_iq4_nl_f16")]] kernel mul_mm_id kernel_mul_mm_id<half, half4x4, simdgroup_half8x8, half, half2x4, simdgroup_half8x8, block_iq4_nl, 2, dequantize_iq4_nl, float, float4x4, half, half2x4>;
template [[host_name("kernel_mul_mm_id_iq4_xs_f16")]] kernel mul_mm_id kernel_mul_mm_id<half, half4x4, simdgroup_half8x8, half, half2x4, simdgroup_half8x8, block_iq4_xs, QK_NL, dequantize_iq4_xs, float, float4x4, half, half2x4>;
File diff suppressed because it is too large Load Diff
-308
View File
@@ -1,308 +0,0 @@
#include "common.h"
// F == 1 : norm (no fuse)
// F == 2 : norm + mul
// F == 3 : norm + mul + add
//
// Normalizes one row of src0 per threadgroup: the row mean is subtracted and the
// result is scaled by 1/sqrt(variance + eps). Depending on F, the normalized value
// is then multiplied by the matching element of src1_0 and offset by the matching
// element of src1_1.
//
// T is the element type used to walk the row (float or float4 in the instantiations
// below). The loops run to args.ne00_t while the mean/variance divide by args.ne00.
// NOTE(review): presumably ne00_t is the row length in units of T and ne00 the
// scalar element count - confirm against the host code.
template <typename T, short F>
kernel void kernel_norm_fuse_impl(
constant ggml_metal_kargs_norm & args,
device const char * src0,
device const char * src1_0,
device const char * src1_1,
device char * dst,
threadgroup float * shmem_f32 [[threadgroup(0)]],
uint3 tgpig[[threadgroup_position_in_grid]],
ushort3 tpitg[[thread_position_in_threadgroup]],
ushort sgitg[[simdgroup_index_in_threadgroup]],
ushort tiisg[[thread_index_in_simdgroup]],
ushort3 ntg[[threads_per_threadgroup]]) {
// clear the scratch slots: every lane later reads shmem_f32[tiisg], but only one
// slot per simdgroup gets written, so the remaining slots must contribute zero
if (sgitg == 0) {
shmem_f32[tiisg] = 0.0f;
}
// the threadgroup position selects the row (i01, i02, i03) to process
const int i01 = tgpig.x;
const int i02 = tgpig.y;
const int i03 = tgpig.z;
// x: source row; f0/f1: fused operand rows, selected modulo their own extents
// (args.nef*) so they may be smaller than src0 along the outer dimensions
device const T * x = (device const T *) (src0 + i03*args.nbf3[0] + i02*args.nbf2[0] + i01*args.nbf1[0]);
device const T * f0 = (device const T *) (src1_0 + (i03%args.nef3[1])*args.nbf3[1] + (i02%args.nef2[1])*args.nbf2[1] + (i01%args.nef1[1])*args.nbf1[1]);
device const T * f1 = (device const T *) (src1_1 + (i03%args.nef3[2])*args.nbf3[2] + (i02%args.nef2[2])*args.nbf2[2] + (i01%args.nef1[2])*args.nbf1[2]);
// pass 1: per-thread partial sum, accumulated in T and then collapsed to a scalar
T sumft(0.0f);
float sumf = 0.0f;
for (int i00 = tpitg.x; i00 < args.ne00_t; i00 += ntg.x) {
sumft += x[i00];
}
// dot with all-ones adds up the components of sumft
sumf = dot(sumft, T(1.0f));
// reduce within the simdgroup, then across simdgroups via threadgroup memory
sumf = simd_sum(sumf);
// makes the zero-initialization above visible before the partial sums are stored
threadgroup_barrier(mem_flags::mem_threadgroup);
if (tiisg == 0) {
shmem_f32[sgitg] = sumf;
}
threadgroup_barrier(mem_flags::mem_threadgroup);
sumf = shmem_f32[tiisg];
sumf = simd_sum(sumf);
const float mean = sumf/args.ne00;
device T * y = (device T *) (dst + i03*args.nb3 + i02*args.nb2 + i01*args.nb1);
// pass 2: store the centered values in dst and accumulate their squares
sumf = 0.0f;
for (int i00 = tpitg.x; i00 < args.ne00_t; i00 += ntg.x) {
y[i00] = x[i00] - mean;
sumf += dot(y[i00], y[i00]);
}
sumf = simd_sum(sumf);
// this barrier also keeps the scratch reads of the first reduction ahead of the
// writes of the second one
threadgroup_barrier(mem_flags::mem_threadgroup);
if (tiisg == 0) {
shmem_f32[sgitg] = sumf;
}
threadgroup_barrier(mem_flags::mem_threadgroup);
sumf = shmem_f32[tiisg];
sumf = simd_sum(sumf);
const float variance = sumf/args.ne00;
const float scale = 1.0f/sqrt(variance + args.eps);
// pass 3: scale the centered values in place; each thread revisits exactly the
// indices it wrote in pass 2. F is a template constant, so only one branch applies.
for (int i00 = tpitg.x; i00 < args.ne00_t; i00 += ntg.x) {
if (F == 1) {
y[i00] = (y[i00]*scale);
}
if (F == 2) {
y[i00] = (y[i00]*scale)*f0[i00];
}
if (F == 3) {
y[i00] = (y[i00]*scale)*f0[i00] + f1[i00];
}
}
}
// common function type shared by all explicit instantiations below
typedef decltype(kernel_norm_fuse_impl<float4, 1>) kernel_norm_fuse_t;
template [[host_name("kernel_norm_f32")]] kernel kernel_norm_fuse_t kernel_norm_fuse_impl<float, 1>;
template [[host_name("kernel_norm_mul_f32")]] kernel kernel_norm_fuse_t kernel_norm_fuse_impl<float, 2>;
template [[host_name("kernel_norm_mul_add_f32")]] kernel kernel_norm_fuse_t kernel_norm_fuse_impl<float, 3>;
template [[host_name("kernel_norm_f32_4")]] kernel kernel_norm_fuse_t kernel_norm_fuse_impl<float4, 1>;
template [[host_name("kernel_norm_mul_f32_4")]] kernel kernel_norm_fuse_t kernel_norm_fuse_impl<float4, 2>;
template [[host_name("kernel_norm_mul_add_f32_4")]] kernel kernel_norm_fuse_t kernel_norm_fuse_impl<float4, 3>;
// F == 1 : rms_norm (no fuse)
// F == 2 : rms_norm + mul
// F == 3 : rms_norm + mul + add
//
// RMS-normalizes one row of src0 per threadgroup: every element is multiplied by
// 1/sqrt(sum(x^2)/ne00 + eps). Depending on F the result is additionally multiplied
// by the matching element of src1_0 and offset by the matching element of src1_1.
// The src1_* rows are selected modulo their own extents (args.nef*), so they may be
// smaller than src0 along the outer dimensions.
template <typename T, short F>
kernel void kernel_rms_norm_fuse_impl(
        constant ggml_metal_kargs_norm & args,
        device const char * src0,
        device const char * src1_0,
        device const char * src1_1,
        device       char * dst,
        threadgroup float * shmem_f32 [[threadgroup(0)]],
        uint3   tgpig[[threadgroup_position_in_grid]],
        ushort3 tpitg[[thread_position_in_threadgroup]],
        ushort  sgitg[[simdgroup_index_in_threadgroup]],
        ushort  tiisg[[thread_index_in_simdgroup]],
        ushort3   ntg[[threads_per_threadgroup]]) {
    // zero the scratch slots so lanes without a matching simdgroup add nothing later
    if (sgitg == 0) {
        shmem_f32[tiisg] = 0.0f;
    }

    // the threadgroup position selects the row
    const int i01 = tgpig.x;
    const int i02 = tgpig.y;
    const int i03 = tgpig.z;

    device const T * src_row = (device const T *) (src0   + i03*args.nbf3[0] + i02*args.nbf2[0] + i01*args.nbf1[0]);
    device const T * mul_row = (device const T *) (src1_0 + (i03%args.nef3[1])*args.nbf3[1] + (i02%args.nef2[1])*args.nbf2[1] + (i01%args.nef1[1])*args.nbf1[1]);
    device const T * add_row = (device const T *) (src1_1 + (i03%args.nef3[2])*args.nbf3[2] + (i02%args.nef2[2])*args.nbf2[2] + (i01%args.nef1[2])*args.nbf1[2]);
    device       T * dst_row = (device       T *) (dst    + i03*args.nb3 + i02*args.nb2 + i01*args.nb1);

    // per-thread partial sum of squares
    float acc = 0.0f;
    for (int i = tpitg.x; i < args.ne00_t; i += ntg.x) {
        const T v = src_row[i];
        acc += dot(v, v);
    }

    // reduce inside the simdgroup, then across simdgroups through threadgroup memory
    acc = simd_sum(acc);

    threadgroup_barrier(mem_flags::mem_threadgroup);

    if (tiisg == 0) {
        shmem_f32[sgitg] = acc;
    }

    threadgroup_barrier(mem_flags::mem_threadgroup);

    acc = shmem_f32[tiisg];
    acc = simd_sum(acc);

    const float mean_sq = acc/args.ne00;
    const float scale   = 1.0f/sqrt(mean_sq + args.eps);

    // F is a template constant, so exactly one of the branches survives compilation
    for (int i = tpitg.x; i < args.ne00_t; i += ntg.x) {
        if (F == 1) {
            dst_row[i] = (src_row[i]*scale);
        } else if (F == 2) {
            dst_row[i] = (src_row[i]*scale)*mul_row[i];
        } else if (F == 3) {
            dst_row[i] = (src_row[i]*scale)*mul_row[i] + add_row[i];
        }
    }
}

// common function type shared by all explicit instantiations below
typedef decltype(kernel_rms_norm_fuse_impl<float4, 1>) kernel_rms_norm_fuse_t;

template [[host_name("kernel_rms_norm_f32")]] kernel kernel_rms_norm_fuse_t kernel_rms_norm_fuse_impl<float, 1>;
template [[host_name("kernel_rms_norm_mul_f32")]] kernel kernel_rms_norm_fuse_t kernel_rms_norm_fuse_impl<float, 2>;
template [[host_name("kernel_rms_norm_mul_add_f32")]] kernel kernel_rms_norm_fuse_t kernel_rms_norm_fuse_impl<float, 3>;
template [[host_name("kernel_rms_norm_f32_4")]] kernel kernel_rms_norm_fuse_t kernel_rms_norm_fuse_impl<float4, 1>;
template [[host_name("kernel_rms_norm_mul_f32_4")]] kernel kernel_rms_norm_fuse_t kernel_rms_norm_fuse_impl<float4, 2>;
template [[host_name("kernel_rms_norm_mul_add_f32_4")]] kernel kernel_rms_norm_fuse_t kernel_rms_norm_fuse_impl<float4, 3>;
// L2-normalizes one row per threadgroup: every element is multiplied by
// 1/max(sqrt(sum(x^2)), eps).
//
// T0 is the type used to read the source row, T the type used to write the
// destination row. Both loops run to args.ne00.
// NOTE(review): for the float4 instantiation ne00 is presumably expressed in
// units of T0 - confirm against the host code.
template <typename T0, typename T>
kernel void kernel_l2_norm_impl(
        constant ggml_metal_kargs_l2_norm & args,
        device const char * src0,
        device       char * dst,
        threadgroup float * shmem_f32 [[threadgroup(0)]],
        uint3   tgpig[[threadgroup_position_in_grid]],
        ushort3 tpitg[[thread_position_in_threadgroup]],
        ushort  sgitg[[simdgroup_index_in_threadgroup]],
        ushort  tiisg[[thread_index_in_simdgroup]],
        ushort3   ntg[[threads_per_threadgroup]]) {
    // the threadgroup position selects the row
    const int i01 = tgpig.x;
    const int i02 = tgpig.y;
    const int i03 = tgpig.z;

    // zero the scratch slots so lanes without a matching simdgroup add nothing later
    if (sgitg == 0) {
        shmem_f32[tiisg] = 0.0f;
    }

    device const T0 * src_row = (device const T0 *) (src0 + i03*args.nb03 + i02*args.nb02 + i01*args.nb01);
    device       T  * dst_row = (device       T  *) (dst  + i03*args.nb3  + i02*args.nb2  + i01*args.nb1);

    // per-thread partial sum of squares
    float acc = 0.0f;
    for (int i = tpitg.x; i < args.ne00; i += ntg.x) {
        const T0 v = src_row[i];
        acc += dot(v, v);
    }

    // reduce inside the simdgroup, then across simdgroups through threadgroup memory
    acc = simd_sum(acc);

    threadgroup_barrier(mem_flags::mem_threadgroup);

    if (tiisg == 0) {
        shmem_f32[sgitg] = acc;
    }

    threadgroup_barrier(mem_flags::mem_threadgroup);

    acc = shmem_f32[tiisg];
    acc = simd_sum(acc);

    // eps bounds the norm from below, which keeps the division finite for all-zero rows
    const float scale = 1.0f/max(sqrt(acc), args.eps);

    for (int i = tpitg.x; i < args.ne00; i += ntg.x) {
        dst_row[i] = src_row[i] * scale;
    }
}

// common function type shared by the explicit instantiations below
typedef decltype(kernel_l2_norm_impl<float, float>) kernel_l2_norm_t;

template [[host_name("kernel_l2_norm_f32_f32")]] kernel kernel_l2_norm_t kernel_l2_norm_impl<float, float>;
template [[host_name("kernel_l2_norm_f32_f32_4")]] kernel kernel_l2_norm_t kernel_l2_norm_impl<float4, float4>;
// Group normalization over f32 data.
//
// The tensor is treated as a flat array of ne00*ne01*ne02 floats that is split into
// consecutive groups of gs = ne00*ne01*ceil(ne02/ngrp) elements. Threadgroup `tgpig`
// normalizes group `tgpig`: it subtracts the group mean and scales by
// 1/sqrt(variance + eps). The last group is clipped to the end of the tensor.
//
// buf is threadgroup scratch used for the cross-simdgroup reductions; every lane
// reads buf[tiisg], so it has to hold one float per simdgroup lane.
kernel void kernel_group_norm_f32(
        constant ggml_metal_kargs_group_norm & args,
        device const float * src0,
        device       float * dst,
        threadgroup float  * buf [[threadgroup(0)]],
        uint tgpig[[threadgroup_position_in_grid]],
        uint tpitg[[thread_position_in_threadgroup]],
        uint sgitg[[simdgroup_index_in_threadgroup]],
        uint tiisg[[thread_index_in_simdgroup]],
        uint   ntg[[threads_per_threadgroup]]) {
    const int64_t ne = args.ne00*args.ne01*args.ne02;
    const int64_t gs = args.ne00*args.ne01*((args.ne02 + args.ngrp - 1) / args.ngrp);

    int start = tgpig * gs;
    int end   = start + gs;

    if (end >= ne) {
        end = ne;
    }

    // Number of elements that actually belong to this group. When ne02 is not a
    // multiple of ngrp the last group is shorter than gs, and the statistics must be
    // averaged over the real element count rather than the nominal group size.
    const int cnt = end - start;

    // nothing to do for a group that lies entirely past the end of the tensor;
    // the condition is uniform across the threadgroup, so no barrier is skipped
    // by only a subset of the threads
    if (cnt <= 0) {
        return;
    }

    start += tpitg;

    // pass 1: mean
    float tmp = 0.0f; // partial sum for thread in warp

    for (int j = start; j < end; j += ntg) {
        tmp += src0[j];
    }

    threadgroup_barrier(mem_flags::mem_threadgroup);
    tmp = simd_sum(tmp);
    if (ntg > N_SIMDWIDTH) {
        // unused slots must read as zero in the final simd_sum
        if (sgitg == 0) {
            buf[tiisg] = 0.0f;
        }

        threadgroup_barrier(mem_flags::mem_threadgroup);

        if (tiisg == 0) {
            buf[sgitg] = tmp;
        }

        threadgroup_barrier(mem_flags::mem_threadgroup);

        tmp = buf[tiisg];
        tmp = simd_sum(tmp);
    }

    const float mean = tmp / cnt;

    // pass 2: store the centered values and accumulate their squares
    tmp = 0.0f;

    for (int j = start; j < end; j += ntg) {
        float xi = src0[j] - mean;
        dst[j] = xi;
        tmp += xi * xi;
    }

    tmp = simd_sum(tmp);
    if (ntg > N_SIMDWIDTH) {
        // buf is reused: wait until every simdgroup has read the partial sums of the
        // first reduction before simdgroup 0 clears the slots again, otherwise a
        // slower simdgroup could read zeros and compute a wrong mean
        threadgroup_barrier(mem_flags::mem_threadgroup);

        if (sgitg == 0) {
            buf[tiisg] = 0.0f;
        }

        threadgroup_barrier(mem_flags::mem_threadgroup);

        if (tiisg == 0) {
            buf[sgitg] = tmp;
        }

        threadgroup_barrier(mem_flags::mem_threadgroup);

        tmp = buf[tiisg];
        tmp = simd_sum(tmp);
    }

    const float variance = tmp / cnt;

    // pass 3: scale the centered values in place; each thread revisits exactly the
    // indices it wrote in pass 2
    const float scale = 1.0f/sqrt(variance + args.eps);
    for (int j = start; j < end; j += ntg) {
        dst[j] *= scale;
    }
}
-148
View File
@@ -1,148 +0,0 @@
#include "common.h"
// 2D max pooling over f32 data, one thread per output element.
//
// The flat thread index is decomposed into (plane, output row, output column).
// The pooling window is clipped to the input bounds, so padded positions are
// ignored; an empty window yields -INFINITY.
kernel void kernel_pool_2d_max_f32(
        constant ggml_metal_kargs_pool_2d & args,
        device const float * src0,
        device       float * dst,
        uint gid[[thread_position_in_grid]]) {
    // surplus threads of the last threadgroup have no output element
    if (gid >= args.np) {
        return;
    }

    const int idx = gid;

    const int out_plane = args.OH * args.OW;
    const int in_plane  = args.IH * args.IW;

    const int plane = idx / out_plane;
    const int pos   = idx % out_plane;
    const int oy    = pos / args.OW;
    const int ox    = pos % args.OW;

    device const float * in  = src0 + plane * in_plane;
    device       float * out = dst  + plane * out_plane;

    // window origin in input coordinates; negative when it starts inside the padding
    const int y0 = oy * args.s1 - args.p1;
    const int x0 = ox * args.s0 - args.p0;

    // clip the window to the valid input area
    const int y_beg = MAX(0, y0);
    const int y_end = MIN(args.IH, y0 + args.k1);
    const int x_beg = MAX(0, x0);
    const int x_end = MIN(args.IW, x0 + args.k0);

    float best = -INFINITY;

    for (int y = y_beg; y < y_end; ++y) {
        for (int x = x_beg; x < x_end; ++x) {
            best = MAX(best, in[y * args.IW + x]);
        }
    }

    out[oy * args.OW + ox] = best;
}
// 2D average pooling over f32 data, one thread per output element.
//
// The flat thread index is decomposed into (plane, output row, output column).
// Only in-bounds input values are summed, but the divisor is always the full
// kernel area k0*k1, i.e. padded positions count as zeros.
kernel void kernel_pool_2d_avg_f32(
        constant ggml_metal_kargs_pool_2d & args,
        device const float * src0,
        device       float * dst,
        uint gid[[thread_position_in_grid]]) {
    // surplus threads of the last threadgroup have no output element
    if (gid >= args.np) {
        return;
    }

    const int idx = gid;

    const int out_plane = args.OH * args.OW;
    const int in_plane  = args.IH * args.IW;

    const int plane = idx / out_plane;
    const int pos   = idx % out_plane;
    const int oy    = pos / args.OW;
    const int ox    = pos % args.OW;

    device const float * in  = src0 + plane * in_plane;
    device       float * out = dst  + plane * out_plane;

    // window origin in input coordinates; negative when it starts inside the padding
    const int y0 = oy * args.s1 - args.p1;
    const int x0 = ox * args.s0 - args.p0;

    // clip the window to the valid input area
    const int y_beg = MAX(0, y0);
    const int y_end = MIN(args.IH, y0 + args.k1);
    const int x_beg = MAX(0, x0);
    const int x_end = MIN(args.IW, x0 + args.k0);

    // the weight uses the nominal window size, not the clipped one
    const float scale = 1. / (args.k0 * args.k1);

    float acc = 0;

    for (int y = y_beg; y < y_end; ++y) {
        for (int x = x_beg; x < x_end; ++x) {
            const float v = in[y * args.IW + x];
            acc += v * scale;
        }
    }

    out[oy * args.OW + ox] = acc;
}
// 1D max pooling over f32 data, one thread per output element.
//
// The flat thread index is split into (row, output column). Window positions that
// fall outside [0, IW) are skipped; a window without any valid position yields
// -INFINITY.
kernel void kernel_pool_1d_max_f32(
        constant ggml_metal_kargs_pool_1d & args,
        device const float * src,
        device       float * dst,
        uint gid [[thread_position_in_grid]]
    ) {
    // surplus threads of the last threadgroup have no output element
    if (gid >= args.np) {
        return;
    }

    const int ow  = (int)gid % args.OW;
    const int row = (int)gid / args.OW;

    const int src_off = row * args.IW;
    const int dst_off = row * args.OW;

    // first input column covered by the window; negative inside the left padding
    const int first = ow * args.s0 - args.p0;

    float best = -INFINITY;

    for (int k = 0; k < args.k0; ++k) {
        const int col = first + k;

        if (col >= 0 && col < args.IW) {
            best = max(best, src[src_off + col]);
        }
    }

    dst[dst_off + ow] = best;
}
// 1D average pooling over f32 data, one thread per output element.
//
// The flat thread index is split into (row, output column). Window positions that
// fall outside [0, IW) are skipped and the sum is divided by the number of valid
// positions; a window without any valid position yields 0.
kernel void kernel_pool_1d_avg_f32(
        constant ggml_metal_kargs_pool_1d & args,
        device const float * src,
        device       float * dst,
        uint gid [[thread_position_in_grid]]
    ) {
    // surplus threads of the last threadgroup have no output element
    if (gid >= args.np) {
        return;
    }

    const int ow  = (int)gid % args.OW;
    const int row = (int)gid / args.OW;

    const int src_off = row * args.IW;
    const int dst_off = row * args.OW;

    // first input column covered by the window; negative inside the left padding
    const int first = ow * args.s0 - args.p0;

    float total = 0.0f;
    int   valid = 0;

    for (int k = 0; k < args.k0; ++k) {
        const int col = first + k;

        if (col >= 0 && col < args.IW) {
            total += src[src_off + col];
            valid += 1;
        }
    }

    if (valid > 0) {
        dst[dst_off + ow] = total / (float)valid;
    } else {
        dst[dst_off + ow] = 0.0f;
    }
}
-213
View File
@@ -1,213 +0,0 @@
#pragma once
#include "common.h"
// Quantizes QK1_0 floats into one q1_0 block.
//
// The block scale is the mean absolute value of the inputs. Each value keeps a
// single bit, packed LSB-first into dst.qs: the bit is set when the value is >= 0.
// The bytes are assembled in a register and stored once, which assumes QK1_0 is a
// multiple of 8.
void quantize_q1_0(device const float * src, device block_q1_0 & dst) {
    float abs_total = 0.0f;

    for (int i = 0; i < QK1_0; ++i) {
        abs_total += fabs(src[i]);
    }

    dst.d = abs_total / QK1_0;

    for (int ib = 0; ib < QK1_0 / 8; ++ib) {
        uint8_t bits = 0;

        for (int k = 0; k < 8; ++k) {
            if (src[8*ib + k] >= 0.0f) {
                bits |= (1 << k);
            }
        }

        dst.qs[ib] = bits;
    }
}
// Quantizes QK4_0 floats into one q4_0 block.
//
// The scale d maps the value with the largest magnitude to -8. Every value is then
// stored as an unsigned 4-bit code (offset by 8, clamped to 15). Byte j holds
// element j in its low nibble and element j + QK4_0/2 in its high nibble.
void quantize_q4_0(device const float * src, device block_q4_0 & dst) {
#pragma METAL fp math_mode(safe)
    // find the signed value with the largest magnitude
    float best_abs = 0.0f;
    float best_val = 0.0f;

    for (int i = 0; i < QK4_0; ++i) {
        const float v = src[i];
        const float a = fabs(v);
        if (best_abs < a) {
            best_abs = a;
            best_val = v;
        }
    }

    const float d  = best_val / -8;
    const float id = d ? 1.0f/d : 0.0f; // an all-zero block has no inverse scale

    dst.d = d;

    for (int i = 0; i < QK4_0/2; ++i) {
        const float x_lo = src[0       + i]*id;
        const float x_hi = src[QK4_0/2 + i]*id;

        // +8 shifts into the unsigned range, +0.5 rounds to nearest on truncation
        const uint8_t q_lo = MIN(15, (int8_t)(x_lo + 8.5f));
        const uint8_t q_hi = MIN(15, (int8_t)(x_hi + 8.5f));

        dst.qs[i] = q_lo | (q_hi << 4);
    }
}
// Quantizes QK4_1 floats into one q4_1 block.
//
// The block stores its minimum m and a scale d = (max - min)/15. Every value is
// stored as the 4-bit code round((x - m)/d), clamped to 15. Byte j holds element j
// in its low nibble and element j + QK4_1/2 in its high nibble.
void quantize_q4_1(device const float * src, device block_q4_1 & dst) {
#pragma METAL fp math_mode(safe)
    float lo =  FLT_MAX;
    float hi = -FLT_MAX;

    for (int i = 0; i < QK4_1; ++i) {
        const float v = src[i];
        if (lo > v) {
            lo = v;
        }
        if (hi < v) {
            hi = v;
        }
    }

    const float d  = (hi - lo) / ((1 << 4) - 1);
    const float id = d ? 1.0f/d : 0.0f; // a constant block has no inverse scale

    dst.d = d;
    dst.m = lo;

    for (int i = 0; i < QK4_1/2; ++i) {
        const float x_lo = (src[0       + i] - lo)*id;
        const float x_hi = (src[QK4_1/2 + i] - lo)*id;

        // +0.5 rounds to nearest on truncation
        const uint8_t q_lo = MIN(15, (int8_t)(x_lo + 0.5f));
        const uint8_t q_hi = MIN(15, (int8_t)(x_hi + 0.5f));

        dst.qs[i] = q_lo | (q_hi << 4);
    }
}
// Quantizes QK5_0 floats into one q5_0 block.
//
// The scale d maps the value with the largest magnitude to -16. Every value becomes
// an unsigned 5-bit code (offset by 16, clamped to 31). The low 4 bits are packed
// two per byte into dst.qs (element j low nibble, element j + QK5_0/2 high nibble);
// the fifth bits are collected in a 32-bit mask that is copied byte-wise to dst.qh.
void quantize_q5_0(device const float * src, device block_q5_0 & dst) {
#pragma METAL fp math_mode(safe)
    // find the signed value with the largest magnitude
    float best_abs = 0.0f;
    float best_val = 0.0f;

    for (int i = 0; i < QK5_0; ++i) {
        const float v = src[i];
        const float a = fabs(v);
        if (best_abs < a) {
            best_abs = a;
            best_val = v;
        }
    }

    const float d  = best_val / -16;
    const float id = d ? 1.0f/d : 0.0f; // an all-zero block has no inverse scale

    dst.d = d;

    uint32_t high_bits = 0;

    for (int i = 0; i < QK5_0/2; ++i) {
        const float x_lo = src[0       + i]*id;
        const float x_hi = src[QK5_0/2 + i]*id;

        // +16 shifts into the unsigned range, +0.5 rounds to nearest on truncation
        const uint8_t q_lo = MIN(31, (int8_t)(x_lo + 16.5f));
        const uint8_t q_hi = MIN(31, (int8_t)(x_hi + 16.5f));

        dst.qs[i] = (q_lo & 0xf) | ((q_hi & 0xf) << 4);

        // bit i belongs to the first half of the block, bit i + QK5_0/2 to the second
        high_bits |= ((q_lo & 0x10u) >> 4) << (i + 0);
        high_bits |= ((q_hi & 0x10u) >> 4) << (i + QK5_0/2);
    }

    // dst.qh is written one byte at a time, in the in-memory byte order of the mask
    thread const uint8_t * mask_bytes = (thread const uint8_t *)&high_bits;

    for (int b = 0; b < 4; ++b) {
        dst.qh[b] = mask_bytes[b];
    }
}
// Quantizes QK5_1 floats into one q5_1 block.
//
// The block stores its minimum m and a scale d = (max - min)/31. Every value
// becomes the 5-bit code round((x - m)/d). The low 4 bits are packed two per byte
// into dst.qs (element j low nibble, element j + QK5_1/2 high nibble); the fifth
// bits are collected in a 32-bit mask that is copied byte-wise to dst.qh.
void quantize_q5_1(device const float * src, device block_q5_1 & dst) {
#pragma METAL fp math_mode(safe)
    float hi = src[0];
    float lo = src[0];

    for (int i = 1; i < QK5_1; ++i) {
        const float v = src[i];
        if (v < lo) {
            lo = v;
        }
        if (v > hi) {
            hi = v;
        }
    }

    const float d  = (hi - lo) / 31;
    const float id = d ? 1.0f/d : 0.0f; // a constant block has no inverse scale

    dst.d = d;
    dst.m = lo;

    uint32_t high_bits = 0;

    for (int i = 0; i < QK5_1/2; ++i) {
        const float x_lo = (src[0       + i] - lo)*id;
        const float x_hi = (src[QK5_1/2 + i] - lo)*id;

        // +0.5 rounds to nearest on truncation
        const uint8_t q_lo = (uint8_t)(x_lo + 0.5f);
        const uint8_t q_hi = (uint8_t)(x_hi + 0.5f);

        dst.qs[i] = (q_lo & 0xf) | ((q_hi & 0xf) << 4);

        // bit i belongs to the first half of the block, bit i + QK5_1/2 to the second
        high_bits |= ((q_lo & 0x10u) >> 4) << (i + 0);
        high_bits |= ((q_hi & 0x10u) >> 4) << (i + QK5_1/2);
    }

    // dst.qh is written one byte at a time, in the in-memory byte order of the mask
    thread const uint8_t * mask_bytes = (thread const uint8_t *)&high_bits;

    for (int b = 0; b < 4; ++b) {
        dst.qh[b] = mask_bytes[b];
    }
}
// Quantizes QK8_0 floats into one q8_0 block.
//
// The scale d maps the largest magnitude to 127; every value is stored as
// round(x/d) in dst.qs.
void quantize_q8_0(device const float * src, device block_q8_0 & dst) {
#pragma METAL fp math_mode(safe)
    float peak = 0.0f; // largest absolute value in the block

    for (int i = 0; i < QK8_0; ++i) {
        peak = MAX(peak, fabs(src[i]));
    }

    const float d  = peak / ((1 << 7) - 1);
    const float id = d ? 1.0f/d : 0.0f; // an all-zero block has no inverse scale

    dst.d = d;

    for (int i = 0; i < QK8_0; ++i) {
        const float scaled = src[i]*id;

        dst.qs[i] = round(scaled);
    }
}
// Quantizes QK4_NL floats into one iq4_nl block.
//
// An initial scale maps the value with the largest magnitude onto the first entry
// of the kvalues_iq4nl_f table. Every value is then replaced by the index that
// best_index_int8 picks from that table; byte j holds element j in its low nibble
// and element j + QK4_NL/2 in its high nibble. The stored scale is refined as
// sum(w*q*x)/sum(w*q*q) with weights w = x^2, falling back to the initial scale
// when the denominator is not positive.
void quantize_iq4_nl(device const float * src, device block_iq4_nl & dst) {
#pragma METAL fp math_mode(safe)
    // find the signed value with the largest magnitude
    float best_abs = 0.0f;
    float best_val = 0.0f;

    for (int i = 0; i < QK4_NL; ++i) {
        const float v = src[i];
        const float a = fabs(v);
        if (best_abs < a) {
            best_abs = a;
            best_val = v;
        }
    }

    const float d  = best_val / kvalues_iq4nl_f[0];
    const float id = d ? 1.0f/d : 0.0f; // an all-zero block has no inverse scale

    float num = 0;
    float den = 0;

    for (int i = 0; i < QK4_NL/2; ++i) {
        const float a = src[0        + i];
        const float b = src[QK4_NL/2 + i];

        const uint8_t qa = best_index_int8(16, kvalues_iq4nl_f, a*id);
        const uint8_t qb = best_index_int8(16, kvalues_iq4nl_f, b*id);

        dst.qs[i] = qa | (qb << 4);

        // table values actually chosen and the per-element weights
        const float va = kvalues_iq4nl_f[qa];
        const float vb = kvalues_iq4nl_f[qb];

        const float wa = a*a;
        const float wb = b*b;

        num += wa*va*a + wb*vb*b;
        den += wa*va*va + wb*vb*vb;
    }

    dst.d = den > 0 ? num/den : d;
}
-389
View File
@@ -1,389 +0,0 @@
#include "common.h"
#include "dequantize.h"
#include "quantize.h"
// Element-wise copy with type conversion from T0 to T1.
//
// Each thread converts at most one element. The source row is addressed through
// the src strides (nb00..nb03). Its position in the destination is derived from the
// flat element index of the row start, decomposed with the destination extents
// (ne0..ne2), so source and destination may have different shapes.
//
// When the threadgroup has a single row of threads (ntg[1] == 1), tgpig[0] encodes
// both the source row (modulo ne01) and a chunk index within that row; otherwise
// tpitg.y selects the row inside the threadgroup.
template<typename T0, typename T1>
kernel void kernel_cpy_t_t(
        constant ggml_metal_kargs_cpy & args,
        device const char * src0,
        device       char * dst,
        uint3   tgpig[[threadgroup_position_in_grid]],
        ushort3 tpitg[[thread_position_in_threadgroup]],
        ushort3   ntg[[threads_per_threadgroup]]) {
    const int32_t i03 = tgpig[2];
    const int32_t i02 = tgpig[1];

    int32_t i01; // source row
    int32_t iw0; // chunk of the row handled by this threadgroup
    if (ntg[1] == 1) {
        i01 = tgpig[0]%args.ne01;
        iw0 = tgpig[0]/args.ne01;
    } else {
        i01 = tgpig[0]*ntg[1] + tpitg.y;
        iw0 = 0;
    }

    if (i01 >= args.ne01) {
        return;
    }

    // flat index of the first element of the source row
    const int64_t n = i03*args.ne02*args.ne01*args.ne00 + i02*args.ne01*args.ne00 + i01*args.ne00;

    // peel the destination coordinates off the flat index, outermost first
    const auto plane  = args.ne1*args.ne0;
    const auto volume = args.ne2*plane;

    const int32_t i3 = n/volume;
    const int64_t r3 = n - i3*volume;
    const int32_t i2 = r3/plane;
    const int64_t r2 = r3 - i2*plane;
    const int32_t i1 = r2/args.ne0;
    const int32_t i0 = r2 - i1*args.ne0;

    device T1 * dst_data = (device T1 *) (dst + i3*args.nb3 + i2*args.nb2 + i1*args.nb1 + i0*args.nb0);

    // element of the row assigned to this thread, if any
    const int32_t i00 = iw0*ntg[0] + tpitg.x;
    if (i00 < args.ne00) {
        device const T0 * src = (device const T0 *)(src0 + i03*args.nb03 + i02*args.nb02 + i01*args.nb01 + i00*args.nb00);
        dst_data[i00] = (T1) src[0];
    }
}

// common function type shared by all explicit instantiations below
typedef decltype(kernel_cpy_t_t<float, float>) kernel_cpy_t;

template [[host_name("kernel_cpy_f32_f32")]] kernel kernel_cpy_t kernel_cpy_t_t<float, float>;
template [[host_name("kernel_cpy_f32_f16")]] kernel kernel_cpy_t kernel_cpy_t_t<float, half>;
template [[host_name("kernel_cpy_f32_i32")]] kernel kernel_cpy_t kernel_cpy_t_t<float, int32_t>;
template [[host_name("kernel_cpy_i32_f32")]] kernel kernel_cpy_t kernel_cpy_t_t<int32_t, float>;
template [[host_name("kernel_cpy_i32_i32")]] kernel kernel_cpy_t kernel_cpy_t_t<int32_t, int32_t>;
#if defined(GGML_METAL_HAS_BF16)
template [[host_name("kernel_cpy_f32_bf16")]] kernel kernel_cpy_t kernel_cpy_t_t<float, bfloat>;
#endif
template [[host_name("kernel_cpy_f16_f32")]] kernel kernel_cpy_t kernel_cpy_t_t<half, float>;
template [[host_name("kernel_cpy_f16_f16")]] kernel kernel_cpy_t kernel_cpy_t_t<half, half>;
#if defined(GGML_METAL_HAS_BF16)
template [[host_name("kernel_cpy_bf16_f32")]] kernel kernel_cpy_t kernel_cpy_t_t<bfloat, float>;
template [[host_name("kernel_cpy_bf16_bf16")]] kernel kernel_cpy_t kernel_cpy_t_t<bfloat, bfloat>;
#endif
// Copies f32 data into a quantized destination.
//
// Each thread quantizes at most one block of QK consecutive source floats with
// quantize_func. The destination position is derived from the flat element index
// of the source row start, decomposed with the destination extents; the innermost
// coordinate is converted from elements to blocks by dividing by QK.
//
// When the threadgroup has a single row of threads (ntg[1] == 1), tgpig[0] encodes
// both the source row (modulo ne01) and a chunk index within that row; otherwise
// tpitg.y selects the row inside the threadgroup.
template<short QK,
         typename block_q,
         void (*quantize_func)(device const float *, device block_q &)>
kernel void kernel_cpy_f32_q(
        constant ggml_metal_kargs_cpy & args,
        device const char * src0,
        device       char * dst,
        uint3   tgpig[[threadgroup_position_in_grid]],
        ushort3 tpitg[[thread_position_in_threadgroup]],
        ushort3   ntg[[threads_per_threadgroup]]) {
    const int32_t i03 = tgpig[2];
    const int32_t i02 = tgpig[1];

    int32_t i01; // source row
    int32_t iw0; // chunk of the row handled by this threadgroup
    if (ntg[1] == 1) {
        i01 = tgpig[0]%args.ne01;
        iw0 = tgpig[0]/args.ne01;
    } else {
        i01 = tgpig[0]*ntg[1] + tpitg.y;
        iw0 = 0;
    }

    if (i01 >= args.ne01) {
        return;
    }

    // flat index of the first element of the source row
    const int64_t n = i03*args.ne02*args.ne01*args.ne00 + i02*args.ne01*args.ne00 + i01*args.ne00;

    // peel the destination coordinates off the flat index, outermost first
    const auto plane  = args.ne1*args.ne0;
    const auto volume = args.ne2*plane;

    const int32_t i3 = n/volume;
    const int64_t r3 = n - i3*volume;
    const int32_t i2 = r3/plane;
    const int64_t r2 = r3 - i2*plane;
    const int32_t i1 = r2/args.ne0;
    const int32_t i0 = (r2 - i1*args.ne0)/QK; // in blocks, not elements

    device block_q * dst_data = (device block_q *)(dst + i3*args.nb3 + i2*args.nb2 + i1*args.nb1 + i0*args.nb0);

    // block of the row assigned to this thread, if any (args.nk0 bounds the block index)
    const int32_t i00 = iw0*ntg[0] + tpitg.x;
    if (i00 < args.nk0) {
        device const float * src = (device const float *)(src0 + i03*args.nb03 + i02*args.nb02 + i01*args.nb01 + (i00*QK)*args.nb00);

        quantize_func(src, dst_data[i00]);
    }
}

// common function type shared by all explicit instantiations below
typedef decltype(kernel_cpy_f32_q<QK8_0, block_q8_0, quantize_q8_0>) cpy_f_q_t;

template [[host_name("kernel_cpy_f32_q8_0")]] kernel cpy_f_q_t kernel_cpy_f32_q<QK8_0, block_q8_0, quantize_q8_0>;
template [[host_name("kernel_cpy_f32_q1_0")]] kernel cpy_f_q_t kernel_cpy_f32_q<QK1_0, block_q1_0, quantize_q1_0>;
template [[host_name("kernel_cpy_f32_q4_0")]] kernel cpy_f_q_t kernel_cpy_f32_q<QK4_0, block_q4_0, quantize_q4_0>;
template [[host_name("kernel_cpy_f32_q4_1")]] kernel cpy_f_q_t kernel_cpy_f32_q<QK4_1, block_q4_1, quantize_q4_1>;
template [[host_name("kernel_cpy_f32_q5_0")]] kernel cpy_f_q_t kernel_cpy_f32_q<QK5_0, block_q5_0, quantize_q5_0>;
template [[host_name("kernel_cpy_f32_q5_1")]] kernel cpy_f_q_t kernel_cpy_f32_q<QK5_1, block_q5_1, quantize_q5_1>;
template [[host_name("kernel_cpy_f32_iq4_nl")]] kernel cpy_f_q_t kernel_cpy_f32_q<QK4_NL, block_iq4_nl, quantize_iq4_nl>;
// Copies quantized data into a floating-point destination.
//
// Each thread dequantizes at most one 4x4 tile: tile index i00 maps to source block
// i00/nl and sub-block i00%nl, i.e. nl tiles are produced per quantized block. The
// destination position is derived from the flat element index of the source row
// start, decomposed with the destination extents.
//
// When the threadgroup has a single row of threads (ntg[1] == 1), tgpig[0] encodes
// both the source row (modulo ne01) and a chunk index within that row; otherwise
// tpitg.y selects the row inside the threadgroup.
template<typename T4x4, typename block_q, short nl, void (*dequantize_func)(device const block_q *, short, thread T4x4 &)>
kernel void kernel_cpy_q_f32(
        constant ggml_metal_kargs_cpy & args,
        device const char * src0,
        device       char * dst,
        uint3   tgpig[[threadgroup_position_in_grid]],
        ushort3 tpitg[[thread_position_in_threadgroup]],
        ushort3   ntg[[threads_per_threadgroup]]) {
    const int32_t i03 = tgpig[2];
    const int32_t i02 = tgpig[1];

    int32_t i01; // source row
    int32_t iw0; // chunk of the row handled by this threadgroup
    if (ntg[1] == 1) {
        i01 = tgpig[0]%args.ne01;
        iw0 = tgpig[0]/args.ne01;
    } else {
        i01 = tgpig[0]*ntg[1] + tpitg.y;
        iw0 = 0;
    }

    if (i01 >= args.ne01) {
        return;
    }

    // flat index of the first element of the source row
    const int64_t n = i03*args.ne02*args.ne01*args.ne00 + i02*args.ne01*args.ne00 + i01*args.ne00;

    // peel the destination coordinates off the flat index, outermost first
    const auto plane  = args.ne1*args.ne0;
    const auto volume = args.ne2*plane;

    const int32_t i3 = n/volume;
    const int64_t r3 = n - i3*volume;
    const int32_t i2 = r3/plane;
    const int64_t r2 = r3 - i2*plane;
    const int32_t i1 = r2/args.ne0;
    const int32_t i0 = r2 - i1*args.ne0;

    device const block_q * src_data = (device const block_q *)(src0 + i03*args.nb03 + i02*args.nb02 + i01*args.nb01);
    device       T4x4    * dst_data = (device       T4x4    *)(dst + i3*args.nb3 + i2*args.nb2 + i1*args.nb1 + i0*args.nb0);

    // tile of the row assigned to this thread, if any (args.nk0 bounds the tile index)
    const int32_t i00 = iw0*ntg[0] + tpitg.x;
    if (i00 < args.nk0) {
        T4x4 tile;
        dequantize_func(src_data + i00/nl, i00%nl, tile);
        dst_data[i00] = tile;
    }
}

// common function type shared by all explicit instantiations below
typedef decltype(kernel_cpy_q_f32<float4x4, block_q4_0, 2, dequantize_q4_0>) cpy_q_f_t;

template [[host_name("kernel_cpy_q1_0_f32")]] kernel cpy_q_f_t kernel_cpy_q_f32<float4x4, block_q1_0, 8, dequantize_q1_0>;
template [[host_name("kernel_cpy_q4_0_f32")]] kernel cpy_q_f_t kernel_cpy_q_f32<float4x4, block_q4_0, 2, dequantize_q4_0>;
template [[host_name("kernel_cpy_q4_1_f32")]] kernel cpy_q_f_t kernel_cpy_q_f32<float4x4, block_q4_1, 2, dequantize_q4_1>;
template [[host_name("kernel_cpy_q5_0_f32")]] kernel cpy_q_f_t kernel_cpy_q_f32<float4x4, block_q5_0, 2, dequantize_q5_0>;
template [[host_name("kernel_cpy_q5_1_f32")]] kernel cpy_q_f_t kernel_cpy_q_f32<float4x4, block_q5_1, 2, dequantize_q5_1>;
template [[host_name("kernel_cpy_q8_0_f32")]] kernel cpy_q_f_t kernel_cpy_q_f32<float4x4, block_q8_0, 2, dequantize_q8_0>;
template [[host_name("kernel_cpy_q1_0_f16")]] kernel cpy_q_f_t kernel_cpy_q_f32<half4x4, block_q1_0, 8, dequantize_q1_0>;
template [[host_name("kernel_cpy_q4_0_f16")]] kernel cpy_q_f_t kernel_cpy_q_f32<half4x4, block_q4_0, 2, dequantize_q4_0>;
template [[host_name("kernel_cpy_q4_1_f16")]] kernel cpy_q_f_t kernel_cpy_q_f32<half4x4, block_q4_1, 2, dequantize_q4_1>;
template [[host_name("kernel_cpy_q5_0_f16")]] kernel cpy_q_f_t kernel_cpy_q_f32<half4x4, block_q5_0, 2, dequantize_q5_0>;
template [[host_name("kernel_cpy_q5_1_f16")]] kernel cpy_q_f_t kernel_cpy_q_f32<half4x4, block_q5_1, 2, dequantize_q5_1>;
template [[host_name("kernel_cpy_q8_0_f16")]] kernel cpy_q_f_t kernel_cpy_q_f32<half4x4, block_q8_0, 2, dequantize_q8_0>;
// Concatenates src0 and src1 along dimension args.dim into dst.
//
// Every destination element whose coordinates lie inside src0's extents is taken
// from src0; all others are taken from src1, with the coordinate along args.dim
// shifted back by src0's extent in that dimension. Threads stride over the
// innermost dimension; the row is selected by tgpig.x (and tpitg.y when the
// threadgroup has more than one row of threads).
template<typename T>
kernel void kernel_concat(
        constant ggml_metal_kargs_concat & args,
        device const char * src0,
        device const char * src1,
        device       char * dst,
        uint3   tgpig[[threadgroup_position_in_grid]],
        ushort3 tpitg[[thread_position_in_threadgroup]],
        ushort3   ntg[[threads_per_threadgroup]]) {
    const int i3 = tgpig.z;
    const int i2 = tgpig.y;
    const int i1 = ntg.y == 1 ? tgpig.x : tgpig.x*ntg.y + tpitg.y;

    if (i1 >= args.ne1) {
        return;
    }

    // offset of src1 inside dst: src0's extent along the concat dimension, zero elsewhere
    int o[4] = {0, 0, 0, 0};
    if (args.dim == 0) {
        o[args.dim] = args.ne00;
    } else if (args.dim == 1) {
        o[args.dim] = args.ne01;
    } else if (args.dim == 2) {
        o[args.dim] = args.ne02;
    } else {
        o[args.dim] = args.ne03;
    }

    for (int i0 = tpitg.x; i0 < args.ne0; i0 += ntg.x) {
        const bool from_src0 = i0 < args.ne00 && i1 < args.ne01 && i2 < args.ne02 && i3 < args.ne03;

        device const T * x = from_src0
            ? (device const T *)(src0 + (i3       )*args.nb03 + (i2       )*args.nb02 + (i1       )*args.nb01 + (i0       )*args.nb00)
            : (device const T *)(src1 + (i3 - o[3])*args.nb13 + (i2 - o[2])*args.nb12 + (i1 - o[1])*args.nb11 + (i0 - o[0])*args.nb10);

        device T * y = (device T *)(dst + i3*args.nb3 + i2*args.nb2 + i1*args.nb1 + i0*args.nb0);

        *y = *x;
    }
}

// common function type shared by all explicit instantiations below
typedef decltype(kernel_concat<float>) kernel_concat_t;

template [[host_name("kernel_concat_f32")]] kernel kernel_concat_t kernel_concat<float>;
template [[host_name("kernel_concat_f16")]] kernel kernel_concat_t kernel_concat<half>;
#if defined(GGML_METAL_HAS_BF16)
template [[host_name("kernel_concat_bf16")]] kernel kernel_concat_t kernel_concat<bfloat>;
#endif
template [[host_name("kernel_concat_i8")]] kernel kernel_concat_t kernel_concat<char>;
template [[host_name("kernel_concat_i16")]] kernel kernel_concat_t kernel_concat<short>;
template [[host_name("kernel_concat_i32")]] kernel kernel_concat_t kernel_concat<int>;
template [[host_name("kernel_concat_i64")]] kernel kernel_concat_t kernel_concat<long>;
// Gathers rows from a quantized tensor and writes them dequantized as float4x4 tiles.
//
// The row index r is read as int32 from src1 at (i10, i11, i12). Each thread
// dequantizes at most one tile: tile index `ind` maps to source block ind/nl and
// sub-block ind%nl. tgpig.x encodes both the index position (modulo ne10) and the
// chunk of the row handled by the threadgroup.
template<typename block_q, short nl, void (*dequantize_func)(device const block_q *, short, thread float4x4 &)>
kernel void kernel_get_rows_q(
        constant ggml_metal_kargs_get_rows & args,
        device const void * src0,
        device const void * src1,
        device       void * dst,
        uint3   tgpig[[threadgroup_position_in_grid]],
        ushort  tiitg[[thread_index_in_threadgroup]],
        ushort3 ntg  [[threads_per_threadgroup]]) {
    const int32_t iw0 = tgpig.x/args.ne10;
    const int32_t i10 = tgpig.x%args.ne10;
    const int32_t i11 = tgpig.y;
    const int32_t i12 = tgpig.z;

    // row to fetch from src0
    const int32_t r = ((const device int32_t *) ((const device char *) src1 + i12*args.nb12 + i11*args.nb11 + i10*args.nb10))[0];

    // the two outer source coordinates follow the index tensor
    const int32_t i02 = i11;
    const int32_t i03 = i12;

    auto src_row = (device const block_q *) ((const device char *) src0 + i03*args.nb03 + i02*args.nb02 + r*args.nb01);
    auto dst_row = (device      float4x4 *) ((      device char *) dst  + i12*args.nb3  + i11*args.nb2  + i10*args.nb1);

    // tile assigned to this thread, if any
    const int ind = iw0*ntg.x + tiitg;
    if (ind < args.ne00t) {
        float4x4 tile;
        dequantize_func(src_row + ind/nl, ind%nl, tile);
        dst_row[ind] = tile;
    }
}
// Gathers rows from a non-quantized tensor, converting elements from T0 to T.
//
// The row index r is read as int32 from src1 at (i10, i11, i12). Each thread copies
// at most one element. tgpig.x encodes both the index position (modulo ne10) and
// the chunk of the row handled by the threadgroup.
template<typename T0, typename T>
kernel void kernel_get_rows_f(
        constant ggml_metal_kargs_get_rows & args,
        device const void * src0,
        device const void * src1,
        device       void * dst,
        uint3   tgpig[[threadgroup_position_in_grid]],
        ushort  tiitg[[thread_index_in_threadgroup]],
        ushort3 ntg  [[threads_per_threadgroup]]) {
    const int32_t iw0 = tgpig.x/args.ne10;
    const int32_t i10 = tgpig.x%args.ne10;
    const int32_t i11 = tgpig.y;
    const int32_t i12 = tgpig.z;

    // row to fetch from src0
    const int32_t r = ((const device int32_t *) ((const device char *) src1 + i12*args.nb12 + i11*args.nb11 + i10*args.nb10))[0];

    // the two outer source coordinates follow the index tensor
    const int32_t i02 = i11;
    const int32_t i03 = i12;

    auto src_row = (const device T0 *) ((const device char *) src0 + i03*args.nb03 + i02*args.nb02 + r*args.nb01);
    auto dst_row = (      device T  *) ((      device char *) dst  + i12*args.nb3  + i11*args.nb2  + i10*args.nb1);

    // element assigned to this thread, if any
    const int ind = iw0*ntg.x + tiitg;
    if (ind < args.ne00t) {
        dst_row[ind] = src_row[ind];
    }
}
// Scatters f32 source rows into a quantized destination, for block types that
// cover 32 source floats per block.
//
// Source row i01 is written to destination row i1, where i1 is read (as TI) from
// src1. The index tensor is addressed modulo its outer extents (ne11, ne12). A
// threadgroup is laid out as tptg.y rows of tptg.x threads; the threads of one row
// stride over the args.nk0 blocks of a single source row.
template<typename TI, typename block_q, void (*quantize_func)(device const float *, device block_q &)>
kernel void kernel_set_rows_q32(
        constant ggml_metal_kargs_set_rows & args,
        device const void  * src0,
        device const void  * src1,
        device       float * dst,
        uint3 tgpig[[threadgroup_position_in_grid]],
        uint  tiitg[[thread_index_in_threadgroup]],
        uint3 tptg [[threads_per_threadgroup]]) {
    const int32_t i03 = tgpig.z;
    const int32_t i02 = tgpig.y;

    // source row handled by this thread
    const int32_t i01 = tgpig.x*tptg.y + tiitg/tptg.x;
    if (i01 >= args.ne01) {
        return;
    }

    // coordinates into the index tensor
    const int32_t i12 = i03%args.ne12;
    const int32_t i11 = i02%args.ne11;
    const int32_t i10 = i01;

    // destination row
    const TI i1 = ((const device TI *) ((const device char *) src1 + i10*args.nb10 + i11*args.nb11 + i12*args.nb12))[0];

    device       block_q * out = (      device block_q *) ((      device char *) dst  +  i1*args.nb1  + i02*args.nb2  + i03*args.nb3);
    const device float   * in  = (const device float   *) ((const device char *) src0 + i01*args.nb01 + i02*args.nb02 + i03*args.nb03);

    for (int ib = tiitg%tptg.x; ib < args.nk0; ib += tptg.x) {
        // block ib covers source floats [32*ib, 32*ib + 32)
        quantize_func(in + 32*ib, out[ib]);
    }
}
// Scatters f32 source rows into a destination of element type T.
//
// Source row i01 is written to destination row i1, where i1 is read (as TI) from
// src1. The index tensor is addressed modulo its outer extents (ne11, ne12). A
// threadgroup is laid out as tptg.y rows of tptg.x threads; the threads of one row
// stride over the args.nk0 elements of a single source row.
template<typename T, typename TI>
kernel void kernel_set_rows_f(
        constant ggml_metal_kargs_set_rows & args,
        device const void  * src0,
        device const void  * src1,
        device       float * dst,
        uint3 tgpig[[threadgroup_position_in_grid]],
        uint  tiitg[[thread_index_in_threadgroup]],
        uint3 tptg [[threads_per_threadgroup]]) {
    const int32_t i03 = tgpig.z;
    const int32_t i02 = tgpig.y;

    // source row handled by this thread
    const int32_t i01 = tgpig.x*tptg.y + tiitg/tptg.x;
    if (i01 >= args.ne01) {
        return;
    }

    // coordinates into the index tensor
    const int32_t i12 = i03%args.ne12;
    const int32_t i11 = i02%args.ne11;
    const int32_t i10 = i01;

    // destination row
    const TI i1 = ((const device TI *) ((const device char *) src1 + i10*args.nb10 + i11*args.nb11 + i12*args.nb12))[0];

    device       T     * out = (      device T     *) ((      device char *) dst  +  i1*args.nb1  + i02*args.nb2  + i03*args.nb3);
    const device float * in  = (const device float *) ((const device char *) src0 + i01*args.nb01 + i02*args.nb02 + i03*args.nb03);

    for (int i = tiitg%tptg.x; i < args.nk0; i += tptg.x) {
        out[i] = (T) in[i];
    }
}
//
// get rows
//
// Host-visible entry points for the get-rows kernels. Every line below is an explicit
// template instantiation that is exported under the name given in [[host_name(...)]].
//
// The function type of one representative instantiation (get_rows_f_t / get_rows_q_t)
// is reused for all other instantiations of the same template, so the kernel signature
// must not depend on the template arguments.
typedef decltype(kernel_get_rows_f<float, float>) get_rows_f_t;
template [[host_name("kernel_get_rows_f32")]] kernel get_rows_f_t kernel_get_rows_f<float, float>;
template [[host_name("kernel_get_rows_f16")]] kernel get_rows_f_t kernel_get_rows_f<half, float>;
template [[host_name("kernel_get_rows_i32")]] kernel get_rows_f_t kernel_get_rows_f<int32_t, int32_t>;
// bfloat variant is only built when the toolchain/device advertises bf16 support
#if defined(GGML_METAL_HAS_BF16)
template [[host_name("kernel_get_rows_bf16")]] kernel get_rows_f_t kernel_get_rows_f<bfloat, float>;
#endif
// Quantized variants: <block type, integer constant, dequantize function>.
// NOTE(review): the integer constant is 8 for q1_0, 2 for q4_0..q8_0/mxfp4/iq4_nl and
// QK_NL for the K- and remaining IQ-formats; its exact meaning is defined by
// kernel_get_rows_q (not visible here) - confirm before adding a new type.
typedef decltype(kernel_get_rows_q<block_q4_0, 2, dequantize_q4_0>) get_rows_q_t;
template [[host_name("kernel_get_rows_q1_0")]] kernel get_rows_q_t kernel_get_rows_q<block_q1_0, 8, dequantize_q1_0>;
template [[host_name("kernel_get_rows_q4_0")]] kernel get_rows_q_t kernel_get_rows_q<block_q4_0, 2, dequantize_q4_0>;
template [[host_name("kernel_get_rows_q4_1")]] kernel get_rows_q_t kernel_get_rows_q<block_q4_1, 2, dequantize_q4_1>;
template [[host_name("kernel_get_rows_q5_0")]] kernel get_rows_q_t kernel_get_rows_q<block_q5_0, 2, dequantize_q5_0>;
template [[host_name("kernel_get_rows_q5_1")]] kernel get_rows_q_t kernel_get_rows_q<block_q5_1, 2, dequantize_q5_1>;
template [[host_name("kernel_get_rows_q8_0")]] kernel get_rows_q_t kernel_get_rows_q<block_q8_0, 2, dequantize_q8_0>;
template [[host_name("kernel_get_rows_mxfp4")]] kernel get_rows_q_t kernel_get_rows_q<block_mxfp4, 2, dequantize_mxfp4>;
template [[host_name("kernel_get_rows_q2_K")]] kernel get_rows_q_t kernel_get_rows_q<block_q2_K, QK_NL, dequantize_q2_K>;
template [[host_name("kernel_get_rows_q3_K")]] kernel get_rows_q_t kernel_get_rows_q<block_q3_K, QK_NL, dequantize_q3_K>;
template [[host_name("kernel_get_rows_q4_K")]] kernel get_rows_q_t kernel_get_rows_q<block_q4_K, QK_NL, dequantize_q4_K>;
template [[host_name("kernel_get_rows_q5_K")]] kernel get_rows_q_t kernel_get_rows_q<block_q5_K, QK_NL, dequantize_q5_K>;
template [[host_name("kernel_get_rows_q6_K")]] kernel get_rows_q_t kernel_get_rows_q<block_q6_K, QK_NL, dequantize_q6_K>;
template [[host_name("kernel_get_rows_iq2_xxs")]] kernel get_rows_q_t kernel_get_rows_q<block_iq2_xxs, QK_NL, dequantize_iq2_xxs>;
template [[host_name("kernel_get_rows_iq2_xs")]] kernel get_rows_q_t kernel_get_rows_q<block_iq2_xs, QK_NL, dequantize_iq2_xs>;
template [[host_name("kernel_get_rows_iq3_xxs")]] kernel get_rows_q_t kernel_get_rows_q<block_iq3_xxs, QK_NL, dequantize_iq3_xxs>;
template [[host_name("kernel_get_rows_iq3_s")]] kernel get_rows_q_t kernel_get_rows_q<block_iq3_s, QK_NL, dequantize_iq3_s>;
template [[host_name("kernel_get_rows_iq2_s")]] kernel get_rows_q_t kernel_get_rows_q<block_iq2_s, QK_NL, dequantize_iq2_s>;
template [[host_name("kernel_get_rows_iq1_s")]] kernel get_rows_q_t kernel_get_rows_q<block_iq1_s, QK_NL, dequantize_iq1_s>;
template [[host_name("kernel_get_rows_iq1_m")]] kernel get_rows_q_t kernel_get_rows_q<block_iq1_m, QK_NL, dequantize_iq1_m>;
template [[host_name("kernel_get_rows_iq4_nl")]] kernel get_rows_q_t kernel_get_rows_q<block_iq4_nl, 2, dequantize_iq4_nl>;
template [[host_name("kernel_get_rows_iq4_xs")]] kernel get_rows_q_t kernel_get_rows_q<block_iq4_xs, QK_NL, dequantize_iq4_xs>;
//
// set rows
//
// Host-visible entry points for the set-rows kernels. Names follow the pattern
// kernel_set_rows_<dst type>_<index type>; each destination type is instantiated for
// both int64_t and int32_t row indices.
//
// The function type of one representative instantiation (set_rows_f_t / set_rows_q32_t)
// is reused for all others, so the kernel signature must not depend on the template
// arguments.
typedef decltype(kernel_set_rows_f<float, int64_t>) set_rows_f_t;
template [[host_name("kernel_set_rows_f32_i64")]] kernel set_rows_f_t kernel_set_rows_f<float, int64_t>;
template [[host_name("kernel_set_rows_f32_i32")]] kernel set_rows_f_t kernel_set_rows_f<float, int32_t>;
template [[host_name("kernel_set_rows_f16_i64")]] kernel set_rows_f_t kernel_set_rows_f<half, int64_t>;
template [[host_name("kernel_set_rows_f16_i32")]] kernel set_rows_f_t kernel_set_rows_f<half, int32_t>;
// bfloat variants are only built when the toolchain/device advertises bf16 support
#if defined(GGML_METAL_HAS_BF16)
template [[host_name("kernel_set_rows_bf16_i64")]] kernel set_rows_f_t kernel_set_rows_f<bfloat, int64_t>;
template [[host_name("kernel_set_rows_bf16_i32")]] kernel set_rows_f_t kernel_set_rows_f<bfloat, int32_t>;
#endif
// Quantizing variants: <index type, output block type, quantize function>.
// kernel_set_rows_q32 feeds the quantize function 32 source floats per output block.
typedef decltype(kernel_set_rows_q32<int64_t, block_q8_0, quantize_q8_0>) set_rows_q32_t;
template [[host_name("kernel_set_rows_q8_0_i64")]] kernel set_rows_q32_t kernel_set_rows_q32<int64_t, block_q8_0, quantize_q8_0>;
template [[host_name("kernel_set_rows_q8_0_i32")]] kernel set_rows_q32_t kernel_set_rows_q32<int32_t, block_q8_0, quantize_q8_0>;
template [[host_name("kernel_set_rows_q4_0_i64")]] kernel set_rows_q32_t kernel_set_rows_q32<int64_t, block_q4_0, quantize_q4_0>;
template [[host_name("kernel_set_rows_q4_0_i32")]] kernel set_rows_q32_t kernel_set_rows_q32<int32_t, block_q4_0, quantize_q4_0>;
template [[host_name("kernel_set_rows_q4_1_i64")]] kernel set_rows_q32_t kernel_set_rows_q32<int64_t, block_q4_1, quantize_q4_1>;
template [[host_name("kernel_set_rows_q4_1_i32")]] kernel set_rows_q32_t kernel_set_rows_q32<int32_t, block_q4_1, quantize_q4_1>;
template [[host_name("kernel_set_rows_q5_0_i64")]] kernel set_rows_q32_t kernel_set_rows_q32<int64_t, block_q5_0, quantize_q5_0>;
template [[host_name("kernel_set_rows_q5_0_i32")]] kernel set_rows_q32_t kernel_set_rows_q32<int32_t, block_q5_0, quantize_q5_0>;
template [[host_name("kernel_set_rows_q5_1_i64")]] kernel set_rows_q32_t kernel_set_rows_q32<int64_t, block_q5_1, quantize_q5_1>;
template [[host_name("kernel_set_rows_q5_1_i32")]] kernel set_rows_q32_t kernel_set_rows_q32<int32_t, block_q5_1, quantize_q5_1>;
template [[host_name("kernel_set_rows_iq4_nl_i64")]] kernel set_rows_q32_t kernel_set_rows_q32<int64_t, block_iq4_nl, quantize_iq4_nl>;
template [[host_name("kernel_set_rows_iq4_nl_i32")]] kernel set_rows_q32_t kernel_set_rows_q32<int32_t, block_iq4_nl, quantize_iq4_nl>;
-228
View File
@@ -1,228 +0,0 @@
#include "common.h"
// Sums the first args.np floats of src0 and stores the result in dst[0].
//
// Two-level reduction:
//   1. each thread accumulates the elements tpitg.x, tpitg.x + ntg.x, ...
//   2. simd_sum() combines the per-thread sums inside each simdgroup; lane 0 of every
//      simdgroup publishes its partial in shmem_f32[sgitg]
//   3. after the barrier, simdgroup 0 loads the nsg partials (one per lane) and
//      reduces them with a second simd_sum(); thread 0 writes the total
//
// shmem_f32 must hold at least one float per simdgroup.
// NOTE(review): tgpig is unused and only dst[0] is written, so this looks like it
// expects to be dispatched with a single threadgroup - confirm against the host code.
// NOTE(review): the second stage reads one partial per lane of simdgroup 0, so the
// number of simdgroups must not exceed the simdgroup width - confirm.
kernel void kernel_op_sum_f32(
constant ggml_metal_kargs_sum & args,
device const float * src0,
device float * dst,
threadgroup float * shmem_f32 [[threadgroup(0)]],
uint3 tgpig[[threadgroup_position_in_grid]],
ushort3 tpitg[[thread_position_in_threadgroup]],
ushort sgitg[[simdgroup_index_in_threadgroup]],
ushort tiisg[[thread_index_in_simdgroup]],
ushort3 ntg[[threads_per_threadgroup]]) {
// nothing to sum: dst is left untouched
if (args.np == 0) {
return;
}
// number of simdgroups in the threadgroup, assuming 32 threads per simdgroup
// TODO: become function constant
const uint nsg = (ntg.x + 31) / 32;
// per-thread partial sum over a strided slice of the input
float sumf = 0;
for (uint64_t i0 = tpitg.x; i0 < args.np; i0 += ntg.x) {
sumf += src0[i0];
}
// reduce within the simdgroup
sumf = simd_sum(sumf);
// one partial per simdgroup goes to threadgroup memory
if (tiisg == 0) {
shmem_f32[sgitg] = sumf;
}
// make all partials visible before simdgroup 0 reads them
threadgroup_barrier(mem_flags::mem_threadgroup);
float total = 0;
if (sgitg == 0) {
// lanes beyond the number of simdgroups contribute 0
float v = 0;
if (tpitg.x < nsg) {
v = shmem_f32[tpitg.x];
}
total = simd_sum(v);
if (tpitg.x == 0) {
dst[0] = total;
}
}
}
// Function constant selecting the operation performed by kernel_sum_rows_impl:
// OP_SUM_ROWS_NUM_MEAN divides the row sum by the row length, any other value
// stores the plain sum.
constant short FC_sum_rows_op [[function_constant(FC_SUM_ROWS + 0)]];
// Reduces one row of src0 to a single value in dst (sum or mean, see FC_sum_rows_op).
//
// One threadgroup handles one row: (tgpig.x, tgpig.y, tgpig.z) -> (i1, i2, i3).
//   T0 - element type used to read and accumulate the row (float or float4)
//   T  - element type of the output
//
// Reduction scheme: per-thread strided accumulation, simd_sum() inside each
// simdgroup, one partial per simdgroup in threadgroup memory, then a second
// simd_sum() over those partials. shmem must hold one T0 per simdgroup lane.
template <typename T0, typename T>
kernel void kernel_sum_rows_impl(
constant ggml_metal_kargs_sum_rows & args,
device const char * src0,
device char * dst,
threadgroup char * shmem [[threadgroup(0)]],
uint3 tgpig[[threadgroup_position_in_grid]],
ushort3 tpitg[[thread_position_in_threadgroup]],
ushort sgitg[[simdgroup_index_in_threadgroup]],
ushort tiisg[[thread_index_in_simdgroup]],
ushort3 ntg[[threads_per_threadgroup]]) {
#define FC_OP FC_sum_rows_op
const int i3 = tgpig.z;
const int i2 = tgpig.y;
const int i1 = tgpig.x;
threadgroup T0 * shmem_t = (threadgroup T0 *) shmem;
// clear one slot per lane so that slots of simdgroups that do not exist
// read as zero in the final reduction
if (sgitg == 0) {
shmem_t[tiisg] = 0.0f;
}
device const T0 * src_row = (device const T0 *) (src0 + i1*args.nb01 + i2*args.nb02 + i3*args.nb03);
device T * dst_row = (device T *) (dst + i1*args.nb1 + i2*args.nb2 + i3*args.nb3);
// per-thread partial sum over a strided slice of the row
T0 sumf = T0(0.0f);
for (int64_t i0 = tpitg.x; i0 < args.ne00; i0 += ntg.x) {
sumf += src_row[i0];
}
sumf = simd_sum(sumf);
// the zeroing above must be complete before the partials overwrite the slots
threadgroup_barrier(mem_flags::mem_threadgroup);
if (tiisg == 0) {
shmem_t[sgitg] = sumf;
}
threadgroup_barrier(mem_flags::mem_threadgroup);
// every simdgroup reduces the full set of partials
sumf = shmem_t[tiisg];
sumf = simd_sum(sumf);
if (tpitg.x == 0) {
// sum() is a helper defined elsewhere in the file
// NOTE(review): presumably it folds a float4 into a scalar and is the identity for float - confirm
if (FC_OP == OP_SUM_ROWS_NUM_MEAN) {
if (is_same<float4, T0>::value) {
// with T0 == float4 each of the args.ne00 iterations consumes 4 scalars
dst_row[0] = sum(sumf) / (4*args.ne00);
} else {
dst_row[0] = sum(sumf) / args.ne00;
}
} else {
dst_row[0] = sum(sumf);
}
}
#undef FC_OP
}
// Host-visible entry points; the "_4" variant reads the row as float4.
typedef decltype(kernel_sum_rows_impl<float, float>) kernel_sum_rows_t;
template [[host_name("kernel_sum_rows_f32_f32")]] kernel kernel_sum_rows_t kernel_sum_rows_impl<float, float>;
template [[host_name("kernel_sum_rows_f32_f32_4")]] kernel kernel_sum_rows_t kernel_sum_rows_impl<float4, float>;
// Block-wise inclusive prefix sum along dimension 0.
//
// tgpig[0] encodes two indices: the block within the row (ib = tgpig[0]/ne01) and
// the row itself (i01 = tgpig[0]%ne01). Each threadgroup scans the ntg.x consecutive
// elements starting at i00 = ib*ntg.x; elements past args.ne00 are treated as 0.
//
// The result of each block is independent of the other blocks of the same row. When
// args.outb is set, the last thread additionally stores the block total in
// tmp[ib] so that a later pass can add the offsets of the preceding blocks.
//
// Note that src0 is addressed with byte strides (nb01..nb03) while dst and tmp are
// addressed as densely packed float arrays.
// NOTE(review): the template parameter T is not used in the body; all data is
// handled as float.
template<typename T>
kernel void kernel_cumsum_blk(
constant ggml_metal_kargs_cumsum_blk & args,
device const char * src0,
device char * tmp,
device char * dst,
threadgroup char * shmem [[threadgroup(0)]],
uint3 tgpig[[threadgroup_position_in_grid]],
ushort3 tpitg[[thread_position_in_threadgroup]],
ushort sgitg[[simdgroup_index_in_threadgroup]],
ushort tiisg[[thread_index_in_simdgroup]],
ushort3 ntg[[threads_per_threadgroup]]) {
const int ib = tgpig[0]/args.ne01;
const int i00 = ib*ntg.x;
const int i01 = tgpig[0]%args.ne01;
const int i02 = tgpig[1];
const int i03 = tgpig[2];
device const float * src0_row = (device const float *) (src0 +
args.nb01*i01 +
args.nb02*i02 +
args.nb03*i03);
threadgroup float * shmem_f32 = (threadgroup float *) shmem;
// out-of-range threads still take part in the scan, contributing 0
float v = 0.0f;
if (i00 + tpitg.x < args.ne00) {
v = src0_row[i00 + tpitg.x];
}
// scan within the simdgroup
float s = simd_prefix_inclusive_sum(v);
// the last lane holds the simdgroup total
if (tiisg == N_SIMDWIDTH - 1) {
shmem_f32[sgitg] = s;
}
threadgroup_barrier(mem_flags::mem_threadgroup);
// turn the simdgroup totals into per-simdgroup offsets (exclusive scan, in place)
if (sgitg == 0) {
shmem_f32[tiisg] = simd_prefix_exclusive_sum(shmem_f32[tiisg]);
}
threadgroup_barrier(mem_flags::mem_threadgroup);
// add the sum of all preceding simdgroups of this threadgroup
s += shmem_f32[sgitg];
device float * dst_row = (device float *) dst +
args.ne00*i01 +
args.ne00*args.ne01*i02 +
args.ne00*args.ne01*args.ne02*i03;
if (i00 + tpitg.x < args.ne00) {
dst_row[i00 + tpitg.x] = s;
}
// the last thread of the threadgroup holds the block total
if (args.outb && tpitg.x == ntg.x - 1) {
device float * tmp_row = (device float *) tmp +
args.net0*i01 +
args.net0*args.net1*i02 +
args.net0*args.net1*args.net2*i03;
tmp_row[ib] = s;
}
}
// Host-visible entry point
typedef decltype(kernel_cumsum_blk<float>) kernel_cumsum_blk_t;
template [[host_name("kernel_cumsum_blk_f32")]] kernel kernel_cumsum_blk_t kernel_cumsum_blk<float>;
// Adds a per-block offset to every element of a block of dst.
//
// tgpig[0] encodes the block within the row (ib = tgpig[0]/ne01) and the row
// (i01 = tgpig[0]%ne01). Each thread of block ib > 0 adds tmp_row[ib - 1] to its
// element of dst; block 0 is left unchanged.
//
// tmp is addressed with byte strides (nbt1..nbt3), dst as a densely packed float array.
template<typename T>
kernel void kernel_cumsum_add(
        constant ggml_metal_kargs_cumsum_add & args,
        device const char * tmp,
        device       char * dst,
        uint3   tgpig[[threadgroup_position_in_grid]],
        ushort3 tpitg[[thread_position_in_threadgroup]],
        ushort  sgitg[[simdgroup_index_in_threadgroup]],
        ushort  tiisg[[thread_index_in_simdgroup]],
        ushort3   ntg[[threads_per_threadgroup]]) {
    const int ib = tgpig[0]/args.ne01;

    // the first block of a row has no offset to add
    if (ib == 0) {
        return;
    }

    const int i01 = tgpig[0]%args.ne01;
    const int i02 = tgpig[1];
    const int i03 = tgpig[2];

    // element of the row handled by this thread
    const int i0 = ib*ntg.x + tpitg.x;

    device const float * offsets = (device const float *) (tmp +
            args.nbt1*i01 +
            args.nbt2*i02 +
            args.nbt3*i03);

    device float * row = (device float *) dst +
        args.ne00*i01 +
        args.ne00*args.ne01*i02 +
        args.ne00*args.ne01*args.ne02*i03;

    if (i0 < args.ne00) {
        row[i0] += offsets[ib - 1];
    }
}

typedef decltype(kernel_cumsum_add<float>) kernel_cumsum_add_t;

template [[host_name("kernel_cumsum_add_f32")]] kernel kernel_cumsum_add_t kernel_cumsum_add<float>;
-318
View File
@@ -1,318 +0,0 @@
#include "common.h"
// Function constants specializing the rope kernels.
//   FC_rope_is_imrope - kernel_rope_multi picks the position component from
//                       `sector % 3` instead of from contiguous section ranges
//   FC_rope_is_back   - rope_yarn negates sin_theta
//                       NOTE(review): presumably used for the backward pass - confirm
constant bool FC_rope_is_imrope [[function_constant(FC_ROPE + 0)]];
constant bool FC_rope_is_back [[function_constant(FC_ROPE + 1)]];
// Linear ramp over the rotation pair index i0/2: 1 at or below `low`, 0 at or
// above `high`. The ramp width is clamped to at least 0.001 to avoid a division by zero.
static float rope_yarn_ramp(const float low, const float high, const int i0) {
    const int   pair  = i0 / 2; // integer division: both elements of a pair share one value
    const float width = max(0.001f, high - low);
    const float y     = (pair - low) / width;

    const float y_clamped = min(1.0f, max(0.0f, y));

    return 1.0f - y_clamped;
}
// YaRN algorithm based on LlamaYaRNScaledRotaryEmbedding.py from https://github.com/jquesnelle/yarn
// MIT licensed. Copyright (c) 2023 Jeffrey Quesnelle and Bowen Peng.
//
// Computes the scaled cosine/sine of the rotation angle for element i0.
//   theta_extrap - unscaled (extrapolated) angle
//   freq_scale   - factor producing the interpolated angle
//   corr_dims    - [low, high] bounds of the ramp that blends both angles
//   ext_factor   - blend weight; 0 disables the blending and the magnitude correction
//   mscale       - magnitude applied to both outputs
// When FC_rope_is_back is set, the sign of the sine is flipped.
static void rope_yarn(
    float theta_extrap, float freq_scale, float corr_dims[2], int i0, float ext_factor, float mscale,
    thread float * cos_theta, thread float * sin_theta) {
    // Get n-d rotational scaling corrected for extrapolation
    const float theta_interp = freq_scale * theta_extrap;

    float theta = theta_interp;

    if (ext_factor != 0.0f) {
        const float ramp_mix = rope_yarn_ramp(corr_dims[0], corr_dims[1], i0) * ext_factor;

        theta = theta_interp * (1 - ramp_mix) + theta_extrap * ramp_mix;

        // Get n-d magnitude scaling corrected for interpolation
        mscale *= 1.0f + 0.1f * log(1.0f / freq_scale);
    }

    const float c = cos(theta) * mscale;
    float       s = sin(theta) * mscale;

    if (FC_rope_is_back) {
        s *= -1.0f;
    }

    *cos_theta = c;
    *sin_theta = s;
}
// Apparently solving `n_rot = 2pi * x * base^((2 * max_pos_emb) / n_dims)` for x, we get
// `corr_fac(n_rot) = n_dims * log(max_pos_emb / (n_rot * 2pi)) / (2 * log(base))`
static float rope_yarn_corr_factor(int n_dims, int n_ctx_orig, float n_rot, float base) {
    const float numer = n_dims * log(n_ctx_orig / (n_rot * 2 * M_PI_F));
    const float denom = 2 * log(base);

    return numer / denom;
}
// Computes the [start, end] correction dims for the YaRN ramp:
//   dims[0] = floor of the correction factor for beta_fast, not below 0
//   dims[1] = ceil  of the correction factor for beta_slow, not above n_dims - 1
static void rope_yarn_corr_dims(
    int n_dims, int n_ctx_orig, float freq_base, float beta_fast, float beta_slow, float dims[2]
) {
    const float start = floor(rope_yarn_corr_factor(n_dims, n_ctx_orig, beta_fast, freq_base));
    const float end   = ceil (rope_yarn_corr_factor(n_dims, n_ctx_orig, beta_slow, freq_base));

    dims[0] = max(0.0f, start);
    dims[1] = min(n_dims - 1.0f, end);
}
// Rotary position embedding that rotates adjacent element pairs (i0, i0 + 1).
//
// One threadgroup handles one row (i1, i2, i3) = tgpig; its threads stride over the
// row two elements at a time. Pairs with i0 < n_dims are rotated by the angle
// derived from the position pos[i2]; the remaining pairs are copied unchanged.
//   src1 - int32 positions
//   src2 - optional per-pair frequency factors, read only when args.src2 is set
template<typename T>
kernel void kernel_rope_norm(
        constant ggml_metal_kargs_rope & args,
        device const char * src0,
        device const char * src1,
        device const char * src2,
        device       char * dst,
        ushort  tiitg[[thread_index_in_threadgroup]],
        ushort3 tptg [[threads_per_threadgroup]],
        uint3   tgpig[[threadgroup_position_in_grid]]) {
    const int i1 = tgpig[0];
    const int i2 = tgpig[1];
    const int i3 = tgpig[2];

    float corr_dims[2];
    rope_yarn_corr_dims(args.n_dims, args.n_ctx_orig, args.freq_base, args.beta_fast, args.beta_slow, corr_dims);

    device const int32_t * pos = (device const int32_t *) src1;

    const float theta_base = (float) pos[i2];
    const float inv_ndims  = -1.f/args.n_dims;

    // start of the row in bytes, shared by all elements handled below
    device const char * src_row = src0 + i3*args.nb03 + i2*args.nb02 + i1*args.nb01;
    device       char * dst_row = dst  + i3*args.nb3  + i2*args.nb2  + i1*args.nb1;

    for (int i0 = 2*tiitg; i0 < args.ne0; i0 += 2*tptg.x) {
        device const T * const src      = (device T *)(src_row + i0*args.nb00);
        device       T *       dst_data = (device T *)(dst_row + i0*args.nb0);

        // outside the rotated range: plain copy
        if (i0 >= args.n_dims) {
            dst_data[0] = src[0];
            dst_data[1] = src[1];

            continue;
        }

        const int ic = i0/2;

        const float theta       = theta_base * pow(args.freq_base, inv_ndims*i0);
        const float freq_factor = args.src2 ? ((device const float *) src2)[ic] : 1.0f;

        float cos_theta;
        float sin_theta;

        rope_yarn(theta/freq_factor, args.freq_scale, corr_dims, i0, args.ext_factor, args.attn_factor, &cos_theta, &sin_theta);

        const float x0 = src[0];
        const float x1 = src[1];

        dst_data[0] = x0*cos_theta - x1*sin_theta;
        dst_data[1] = x0*sin_theta + x1*cos_theta;
    }
}
// Rotary position embedding that rotates the pairs (ic, ic + n_dims/2).
//
// One threadgroup handles one row (i1, i2, i3) = tgpig; its threads stride over i0
// in steps of two. For i0 < n_dims the pair with first element ic = i0/2 is rotated
// by the angle derived from the position pos[i2]; for i0 >= n_dims the two adjacent
// elements (i0, i0 + 1) are copied unchanged.
//   src1 - int32 positions
//   src2 - optional per-pair frequency factors, read only when args.src2 is set
template<typename T>
kernel void kernel_rope_neox(
        constant ggml_metal_kargs_rope & args,
        device const char * src0,
        device const char * src1,
        device const char * src2,
        device       char * dst,
        ushort  tiitg[[thread_index_in_threadgroup]],
        ushort3 tptg [[threads_per_threadgroup]],
        uint3   tgpig[[threadgroup_position_in_grid]]) {
    const int i1 = tgpig[0];
    const int i2 = tgpig[1];
    const int i3 = tgpig[2];

    float corr_dims[2];
    rope_yarn_corr_dims(args.n_dims, args.n_ctx_orig, args.freq_base, args.beta_fast, args.beta_slow, corr_dims);

    device const int32_t * pos = (device const int32_t *) src1;

    const float theta_base = (float) pos[i2];
    const float inv_ndims  = -1.f/args.n_dims;

    // start of the row in bytes, shared by all elements handled below
    device const char * src_row = src0 + i3*args.nb03 + i2*args.nb02 + i1*args.nb01;
    device       char * dst_row = dst  + i3*args.nb3  + i2*args.nb2  + i1*args.nb1;

    for (int i0 = 2*tiitg; i0 < args.ne0; i0 += 2*tptg.x) {
        // outside the rotated range: plain copy of two adjacent elements
        if (i0 >= args.n_dims) {
            device const T * const src      = (device T *)(src_row + i0*args.nb00);
            device       T *       dst_data = (device T *)(dst_row + i0*args.nb0);

            dst_data[0] = src[0];
            dst_data[1] = src[1];

            continue;
        }

        const int ic   = i0/2;
        const int half_dims = args.n_dims/2;

        const float theta       = theta_base * pow(args.freq_base, inv_ndims*i0);
        const float freq_factor = args.src2 ? ((device const float *) src2)[ic] : 1.0f;

        float cos_theta;
        float sin_theta;

        rope_yarn(theta/freq_factor, args.freq_scale, corr_dims, i0, args.ext_factor, args.attn_factor, &cos_theta, &sin_theta);

        device const T * const src      = (device T *)(src_row + ic*args.nb00);
        device       T *       dst_data = (device T *)(dst_row + ic*args.nb0);

        const float x0 = src[0];
        const float x1 = src[half_dims];

        dst_data[0]         = x0*cos_theta - x1*sin_theta;
        dst_data[half_dims] = x0*sin_theta + x1*cos_theta;
    }
}
// Multi-section rotary position embedding.
//
// Same pairing as kernel_rope_neox (ic, ic + n_dims/2), but the position used for a
// pair depends on the section the pair falls into. src1 holds up to four position
// components per token, component k of token i2 being pos[i2 + args.ne02*k].
//
// Section selection for sector = ic % (sect_0 + sect_1 + sect_2 + sect_3):
//   FC_rope_is_imrope set:   component chosen from sector % 3 (1, 2, 0) as long as
//                            sector < 3*sect_k for that component, otherwise 3
//   FC_rope_is_imrope unset: contiguous ranges [0, sect_0) -> 0, then 1, 2 and 3
//
// Pairs with i0 >= n_dims are copied unchanged as two adjacent elements.
template<typename T>
kernel void kernel_rope_multi(
        constant ggml_metal_kargs_rope & args,
        device const char * src0,
        device const char * src1,
        device const char * src2,
        device       char * dst,
        ushort  tiitg[[thread_index_in_threadgroup]],
        ushort3 tptg [[threads_per_threadgroup]],
        uint3   tgpig[[threadgroup_position_in_grid]]) {
    const int i1 = tgpig[0];
    const int i2 = tgpig[1];
    const int i3 = tgpig[2];

    float corr_dims[2];
    rope_yarn_corr_dims(args.n_dims, args.n_ctx_orig, args.freq_base, args.beta_fast, args.beta_slow, corr_dims);

    device const int32_t * pos = (device const int32_t *) src1;

    const float inv_ndims = -1.f/args.n_dims;

    // section boundaries
    const int sect_dims = args.sect_0 + args.sect_1 + args.sect_2 + args.sect_3;
    const int sec_w01   = args.sect_0 + args.sect_1;               // end of section 1
    const int sec_w012  = args.sect_0 + args.sect_1 + args.sect_2; // end of section 2

    // start of the row in bytes, shared by all elements handled below
    device const char * src_row = src0 + i3*args.nb03 + i2*args.nb02 + i1*args.nb01;
    device       char * dst_row = dst  + i3*args.nb3  + i2*args.nb2  + i1*args.nb1;

    for (int i0 = 2*tiitg; i0 < args.ne0; i0 += 2*tptg.x) {
        // outside the rotated range: plain copy of two adjacent elements
        if (i0 >= args.n_dims) {
            device const T * const src      = (device T *)(src_row + i0*args.nb00);
            device       T *       dst_data = (device T *)(dst_row + i0*args.nb0);

            dst_data[0] = src[0];
            dst_data[1] = src[1];

            continue;
        }

        const int ic = i0/2;

        // mrope theta calculations
        // note: the rest is the same as kernel_rope_neox
        const int sector = ic % sect_dims;

        // which of the four position components applies to this pair
        int pos_dim;
        if (FC_rope_is_imrope) {
            const int m = sector % 3;

            if (m == 1 && sector < 3 * args.sect_1) {        // h
                pos_dim = 1;
            } else if (m == 2 && sector < 3 * args.sect_2) { // w
                pos_dim = 2;
            } else if (m == 0 && sector < 3 * args.sect_0) { // t
                pos_dim = 0;
            } else {                                         // e
                pos_dim = 3;
            }
        } else {
            if (sector < args.sect_0) {
                pos_dim = 0;
            } else if (sector < sec_w01) {
                pos_dim = 1;
            } else if (sector < sec_w012) {
                pos_dim = 2;
            } else {
                pos_dim = 3;
            }
        }

        const float theta_base = (float) pos[i2 + args.ne02 * pos_dim];
        // end of mrope

        const float theta       = theta_base * pow(args.freq_base, inv_ndims*i0);
        const float freq_factor = args.src2 ? ((device const float *) src2)[ic] : 1.0f;

        float cos_theta;
        float sin_theta;

        rope_yarn(theta/freq_factor, args.freq_scale, corr_dims, i0, args.ext_factor, args.attn_factor, &cos_theta, &sin_theta);

        device const T * const src      = (device T *)(src_row + ic*args.nb00);
        device       T *       dst_data = (device T *)(dst_row + ic*args.nb0);

        const int half_dims = args.n_dims/2;

        const float x0 = src[0];
        const float x1 = src[half_dims];

        dst_data[0]         = x0*cos_theta - x1*sin_theta;
        dst_data[half_dims] = x0*sin_theta + x1*cos_theta;
    }
}
// Two-section rotary position embedding for vision inputs.
//
// One threadgroup handles one row (i1, i2, i3) = tgpig; its threads stride over i0
// in steps of two. For i0 < 2*n_dims the pair (ic, ic + n_dims) with ic = i0/2 is
// rotated; the remaining elements are copied unchanged as two adjacent elements.
//
// Only two position components are supported. With sector = ic % (sect_0 + sect_1):
//   sector in [0, sect_0)              -> position pos[i2],             exponent index p = sector
//   sector in [sect_0, sect_0 + sect_1) -> position pos[i2 + args.ne02], exponent index p = sector - sect_0
//
//   src1 - int32 positions
//   src2 - optional per-pair frequency factors, read only when args.src2 is set
template<typename T>
kernel void kernel_rope_vision(
        constant ggml_metal_kargs_rope & args,
        device const char * src0,
        device const char * src1,
        device const char * src2,
        device       char * dst,
        ushort  tiitg[[thread_index_in_threadgroup]],
        ushort3 tptg [[threads_per_threadgroup]],
        uint3   tgpig[[threadgroup_position_in_grid]]) {
    const int i3 = tgpig[2];
    const int i2 = tgpig[1];
    const int i1 = tgpig[0];

    float corr_dims[2];
    rope_yarn_corr_dims(args.n_dims, args.n_ctx_orig, args.freq_base, args.beta_fast, args.beta_slow, corr_dims);

    device const int32_t * pos = (device const int32_t *) src1;

    const float inv_ndims = -1.f/args.n_dims;

    float cos_theta;
    float sin_theta;

    for (int i0 = 2*tiitg; i0 < args.ne0; i0 += 2*tptg.x) {
        if (i0 < 2*args.n_dims) { // different from kernel_rope_multi
            const int ic = i0/2;

            // mrope theta calculations (only support 2 dimensions)
            const int sect_dims = args.sect_0 + args.sect_1;
            const int sector = ic % sect_dims;

            float p;
            float theta_base;
            // the first section spans sect_0 sectors; the boundary has to match the
            // offset subtracted in the else-branch, otherwise p becomes negative
            // (or skips values) whenever sect_0 != sect_1
            if (sector < args.sect_0) {
                p = (float) sector;
                theta_base = (float) pos[i2];
            } else {
                p = (float) sector - args.sect_0;
                theta_base = (float) pos[i2 + args.ne02];
            }

            const float theta = theta_base * pow(args.freq_base, 2.0f * inv_ndims * p);
            // end of mrope

            const float freq_factor = args.src2 ? ((device const float *) src2)[ic] : 1.0f;

            rope_yarn(theta/freq_factor, args.freq_scale, corr_dims, i0, args.ext_factor, args.attn_factor, &cos_theta, &sin_theta);

            device const T * const src = (device T *)(src0 + i3*args.nb03 + i2*args.nb02 + i1*args.nb01 + ic*args.nb00);
            device       T * dst_data  = (device T *)( dst + i3*args.nb3  + i2*args.nb2  + i1*args.nb1  + ic*args.nb0);

            const float x0 = src[0];
            const float x1 = src[args.n_dims]; // different from kernel_rope_multi

            dst_data[0]           = x0*cos_theta - x1*sin_theta;
            dst_data[args.n_dims] = x0*sin_theta + x1*cos_theta; // different from kernel_rope_multi
        } else {
            device const T * const src = (device T *)(src0 + i3*args.nb03 + i2*args.nb02 + i1*args.nb01 + i0*args.nb00);
            device       T * dst_data  = (device T *)( dst + i3*args.nb3  + i2*args.nb2  + i1*args.nb1  + i0*args.nb0);

            dst_data[0] = src[0];
            dst_data[1] = src[1];
        }
    }
}
// Host-visible entry points for the rope kernels: every variant is instantiated for
// float and half data and exported as kernel_rope_<variant>_<f32|f16>.
// The function type is taken from the float instantiation and reused for the half
// one, so the kernel signatures must not depend on T.
typedef decltype(kernel_rope_norm<float>) kernel_rope_norm_t;
typedef decltype(kernel_rope_neox<float>) kernel_rope_neox_t;
typedef decltype(kernel_rope_multi<float>) kernel_rope_multi_t;
typedef decltype(kernel_rope_vision<float>) kernel_rope_vision_t;
template [[host_name("kernel_rope_norm_f32")]] kernel kernel_rope_norm_t kernel_rope_norm<float>;
template [[host_name("kernel_rope_norm_f16")]] kernel kernel_rope_norm_t kernel_rope_norm<half>;
template [[host_name("kernel_rope_neox_f32")]] kernel kernel_rope_neox_t kernel_rope_neox<float>;
template [[host_name("kernel_rope_neox_f16")]] kernel kernel_rope_neox_t kernel_rope_neox<half>;
template [[host_name("kernel_rope_multi_f32")]] kernel kernel_rope_multi_t kernel_rope_multi<float>;
template [[host_name("kernel_rope_multi_f16")]] kernel kernel_rope_multi_t kernel_rope_multi<half>;
template [[host_name("kernel_rope_vision_f32")]] kernel kernel_rope_vision_t kernel_rope_vision<float>;
template [[host_name("kernel_rope_vision_f16")]] kernel kernel_rope_vision_t kernel_rope_vision<half>;
-223
View File
@@ -1,223 +0,0 @@
#include "common.h"
// Row-wise softmax of src0*scale (+ slope*mask), written to dst as float.
//
// One threadgroup handles one row (i01, i02, i03) = tgpig.
//   src0 - float input
//   src1 - optional mask with element type T; treated as absent when src1 == src0.
//          It is addressed modulo its own extents in dims 2 and 3.
//   src2 - optional float array with one extra value per i02; treated as absent when
//          src2 == src0. The value takes part in the max and in the normalizing sum
//          but has no output element of its own.
//   buf  - threadgroup scratch with one float per simdgroup lane
//
// Three passes over the row: max, exp + sum (un-normalized values are stored in
// dst), normalization. Each thread touches the same elements in all passes.
template<typename T>
kernel void kernel_soft_max(
constant ggml_metal_kargs_soft_max & args,
device const char * src0,
device const char * src1,
device const char * src2,
device char * dst,
threadgroup float * buf [[threadgroup(0)]],
uint3 tgpig[[threadgroup_position_in_grid]],
uint3 tpitg[[thread_position_in_threadgroup]],
uint sgitg[[simdgroup_index_in_threadgroup]],
uint tiisg[[thread_index_in_simdgroup]],
uint3 tptg[[threads_per_threadgroup]]) {
const int32_t i03 = tgpig.z;
const int32_t i02 = tgpig.y;
const int32_t i01 = tgpig.x;
const int32_t i13 = i03%args.ne13;
const int32_t i12 = i02%args.ne12;
const int32_t i11 = i01;
device const float * psrc0 = (device const float *) (src0 + i01*args.nb01 + i02*args.nb02 + i03*args.nb03);
device const T * pmask = src1 != src0 ? (device const T * ) (src1 + i11*args.nb11 + i12*args.nb12 + i13*args.nb13) : nullptr;
device const float * psrc2 = src2 != src0 ? (device const float *) (src2) : nullptr;
device float * pdst = (device float *) (dst + i01*args.nb1 + i02*args.nb2 + i03*args.nb3);
float slope = 1.0f;
// ALiBi
// per-head slope: m0^(h+1) for the first n_head_log2 heads, m1^(2*(h-n_head_log2)+1) after that
if (args.max_bias > 0.0f) {
const int32_t h = i02;
const float base = h < args.n_head_log2 ? args.m0 : args.m1;
const int exp = h < args.n_head_log2 ? h + 1 : 2*(h - args.n_head_log2) + 1;
slope = pow(base, exp);
}
// parallel max
float lmax = psrc2 ? psrc2[i02] : -INFINITY;
for (int i00 = tpitg.x; i00 < args.ne00; i00 += tptg.x) {
lmax = MAX(lmax, psrc0[i00]*args.scale + (pmask ? slope*pmask[i00] : 0.0f));
}
// find the max value in the block
float max_val = simd_max(lmax);
// more than one simdgroup: combine the per-simdgroup results through threadgroup memory
if (tptg.x > N_SIMDWIDTH) {
// neutral element for the slots of simdgroups that do not exist
if (sgitg == 0) {
buf[tiisg] = -INFINITY;
}
threadgroup_barrier(mem_flags::mem_threadgroup);
if (tiisg == 0) {
buf[sgitg] = max_val;
}
threadgroup_barrier(mem_flags::mem_threadgroup);
max_val = buf[tiisg];
max_val = simd_max(max_val);
}
// parallel sum
float lsum = 0.0f;
for (int i00 = tpitg.x; i00 < args.ne00; i00 += tptg.x) {
const float exp_psrc0 = exp((psrc0[i00]*args.scale + (pmask ? slope*pmask[i00] : 0.0f)) - max_val);
lsum += exp_psrc0;
// store the un-normalized value; it is scaled in the last pass
pdst[i00] = exp_psrc0;
}
// This barrier fixes a failing test
// ref: https://github.com/ggml-org/ggml/pull/621#discussion_r1425156335
threadgroup_barrier(mem_flags::mem_none);
float sum = simd_sum(lsum);
// same cross-simdgroup combination as for the max, with 0 as neutral element
if (tptg.x > N_SIMDWIDTH) {
if (sgitg == 0) {
buf[tiisg] = 0.0f;
}
threadgroup_barrier(mem_flags::mem_threadgroup);
if (tiisg == 0) {
buf[sgitg] = sum;
}
threadgroup_barrier(mem_flags::mem_threadgroup);
sum = buf[tiisg];
sum = simd_sum(sum);
}
// the extra src2 value contributes to the denominator only
if (psrc2) {
sum += exp(psrc2[i02] - max_val);
}
const float inv_sum = 1.0f/sum;
for (int i00 = tpitg.x; i00 < args.ne00; i00 += tptg.x) {
pdst[i00] *= inv_sum;
}
}
// Row-wise softmax processing four elements per step (float4); otherwise the same
// scheme as the scalar softmax kernel: max pass, exp + sum pass with un-normalized
// values stored in dst, normalization pass.
//
// One threadgroup handles one row (i01, i02, i03) = tgpig.
//   src0 - float input, read as float4
//   src1 - optional mask with (vector) element type T; treated as absent when
//          src1 == src0. It is addressed modulo its own extents in dims 2 and 3.
//   src2 - optional float array with one extra value per i02; treated as absent when
//          src2 == src0
//   buf  - threadgroup scratch with one float per simdgroup lane
//
// The loops run over args.ne00/4 vectors, so a remainder of ne00 % 4 elements would
// not be processed.
// NOTE(review): presumably the host only selects this kernel when ne00 is a multiple
// of 4 - confirm.
template<typename T>
kernel void kernel_soft_max_4(
constant ggml_metal_kargs_soft_max & args,
device const char * src0,
device const char * src1,
device const char * src2,
device char * dst,
threadgroup float * buf [[threadgroup(0)]],
uint3 tgpig[[threadgroup_position_in_grid]],
uint3 tpitg[[thread_position_in_threadgroup]],
uint sgitg[[simdgroup_index_in_threadgroup]],
uint tiisg[[thread_index_in_simdgroup]],
uint3 tptg[[threads_per_threadgroup]]) {
const int32_t i03 = tgpig.z;
const int32_t i02 = tgpig.y;
const int32_t i01 = tgpig.x;
const int32_t i13 = i03%args.ne13;
const int32_t i12 = i02%args.ne12;
const int32_t i11 = i01;
device const float4 * psrc4 = (device const float4 *) (src0 + i01*args.nb01 + i02*args.nb02 + i03*args.nb03);
device const T * pmask = src1 != src0 ? (device const T * ) (src1 + i11*args.nb11 + i12*args.nb12 + i13*args.nb13) : nullptr;
device const float * psrc2 = src2 != src0 ? (device const float * ) (src2) : nullptr;
device float4 * pdst4 = (device float4 *) (dst + i01*args.nb1 + i02*args.nb2 + i03*args.nb3);
float slope = 1.0f;
// ALiBi: per-head slope, m0^(h+1) for the first n_head_log2 heads, m1^(2*(h-n_head_log2)+1) after that
if (args.max_bias > 0.0f) {
const int32_t h = i02;
const float base = h < args.n_head_log2 ? args.m0 : args.m1;
const int exp = h < args.n_head_log2 ? h + 1 : 2*(h - args.n_head_log2) + 1;
slope = pow(base, exp);
}
// parallel max
float4 lmax4 = psrc2 ? psrc2[i02] : -INFINITY;
for (int i00 = tpitg.x; i00 < args.ne00/4; i00 += tptg.x) {
lmax4 = fmax(lmax4, psrc4[i00]*args.scale + (float4)((pmask ? slope*pmask[i00] : 0.0f)));
}
// fold the four lanes of the vector into a scalar
const float lmax = MAX(MAX(lmax4[0], lmax4[1]), MAX(lmax4[2], lmax4[3]));
float max_val = simd_max(lmax);
// more than one simdgroup: combine the per-simdgroup results through threadgroup memory
if (tptg.x > N_SIMDWIDTH) {
// neutral element for the slots of simdgroups that do not exist
if (sgitg == 0) {
buf[tiisg] = -INFINITY;
}
threadgroup_barrier(mem_flags::mem_threadgroup);
if (tiisg == 0) {
buf[sgitg] = max_val;
}
threadgroup_barrier(mem_flags::mem_threadgroup);
max_val = buf[tiisg];
max_val = simd_max(max_val);
}
// parallel sum
float4 lsum4 = 0.0f;
for (int i00 = tpitg.x; i00 < args.ne00/4; i00 += tptg.x) {
const float4 exp_psrc4 = exp((psrc4[i00]*args.scale + (float4)((pmask ? slope*pmask[i00] : 0.0f))) - max_val);
lsum4 += exp_psrc4;
// store the un-normalized values; they are scaled in the last pass
pdst4[i00] = exp_psrc4;
}
const float lsum = lsum4[0] + lsum4[1] + lsum4[2] + lsum4[3];
// This barrier fixes a failing test
// ref: https://github.com/ggml-org/ggml/pull/621#discussion_r1425156335
threadgroup_barrier(mem_flags::mem_none);
float sum = simd_sum(lsum);
// same cross-simdgroup combination as for the max, with 0 as neutral element
if (tptg.x > N_SIMDWIDTH) {
if (sgitg == 0) {
buf[tiisg] = 0.0f;
}
threadgroup_barrier(mem_flags::mem_threadgroup);
if (tiisg == 0) {
buf[sgitg] = sum;
}
threadgroup_barrier(mem_flags::mem_threadgroup);
sum = buf[tiisg];
sum = simd_sum(sum);
}
// the extra src2 value contributes to the denominator only
if (psrc2) {
sum += exp(psrc2[i02] - max_val);
}
const float inv_sum = 1.0f/sum;
for (int i00 = tpitg.x; i00 < args.ne00/4; i00 += tptg.x) {
pdst4[i00] *= inv_sum;
}
}
// Host-visible entry points for the softmax kernels. The template argument is the
// element type of the mask (src1): half/float for the scalar kernel and half4/float4
// for the vectorized one. The function type of one instantiation is reused for the
// other, so the kernel signatures must not depend on T.
typedef decltype(kernel_soft_max<float>) kernel_soft_max_t;
typedef decltype(kernel_soft_max_4<float4>) kernel_soft_max_4_t;
template [[host_name("kernel_soft_max_f16")]] kernel kernel_soft_max_t kernel_soft_max<half>;
template [[host_name("kernel_soft_max_f32")]] kernel kernel_soft_max_t kernel_soft_max<float>;
template [[host_name("kernel_soft_max_f16_4")]] kernel kernel_soft_max_4_t kernel_soft_max_4<half4>;
template [[host_name("kernel_soft_max_f32_4")]] kernel kernel_soft_max_4_t kernel_soft_max_4<float4>;
@@ -1,75 +0,0 @@
#include "common.h"
// Function constants of kernel_solve_tri_f32:
//   nsg - number of simdgroups per threadgroup
//   n   - order of the triangular system (rows/columns of src0)
//   k   - distance in floats between consecutive rows of src1/dst
constant short FC_solve_tri_nsg [[function_constant(FC_SOLVE_TRI + 0)]];
constant short FC_solve_tri_n   [[function_constant(FC_SOLVE_TRI + 1)]];
constant short FC_solve_tri_k   [[function_constant(FC_SOLVE_TRI + 2)]];

// Solves A*X = B by forward substitution, using only the elements of A on and left
// of the diagonal:
//
//   X[r] = (B[r] - sum_{idx < r} A[r][idx]*X[idx]) / A[r][r]
//
//   src0 - A, rows of N contiguous floats
//   src1 - B, element (r, i01) at offset r*K + i01
//   dst  - X, same layout as src1
//   shmem - NSG rows of NP floats (NP = N padded to the simdgroup width)
//
// Each simdgroup solves one column i01 = tgpig.x*NSG + sgitg of the right-hand side.
// The rows of A are processed in batches of NSG: simdgroup sgitg stages row
// rr + sgitg in threadgroup memory, then every simdgroup walks over the staged rows
// in order, since row r depends on all previous results of the same column.
//
// NOTE(review): lane 0 writes dst_ptr[r*K] and the other lanes of the same simdgroup
// read it back in later rows without an explicit barrier - this relies on the
// simdgroup executing the row loop in lock-step; confirm if the loop is restructured.
kernel void kernel_solve_tri_f32(
        constant ggml_metal_kargs_solve_tri & args,
        device   const char * src0,
        device   const char * src1,
        device         char * dst,
        threadgroup    char * shmem [[threadgroup(0)]],
        ushort3 tgpig[[threadgroup_position_in_grid]],
        ushort  sgitg[[simdgroup_index_in_threadgroup]],
        ushort  tiisg[[thread_index_in_simdgroup]],
        ushort3   ntg[[threads_per_threadgroup]]) {
    constexpr short NW = N_SIMDWIDTH;

    const short NSG = FC_solve_tri_nsg;
    const short N   = FC_solve_tri_n;
    const short K   = FC_solve_tri_k;
    const short NP  = PAD2(N, NW);

    const int32_t i03 = tgpig.z;
    const int32_t i02 = tgpig.y;
    const int32_t i01 = tgpig.x*NSG + sgitg;

    threadgroup float * sh0 = (threadgroup float *) shmem;

    device const float * src0_ptr = (device const float *)(src0 + i02 * args.nb02 + i03 * args.nb03) + sgitg*N;
    device const float * src1_ptr = (device const float *)(src1 + i02 * args.nb12 + i03 * args.nb13) + i01;
    device       float * dst_ptr  = (device       float *)(dst  + i02 * args.nb2  + i03 * args.nb3)  + i01;

    for (short rr = 0; rr < N; rr += NSG) {
        // the previous batch must be fully consumed before it is overwritten
        threadgroup_barrier(mem_flags::mem_threadgroup);

        {
            // stage row rr + sgitg of A
            threadgroup float * sh0_cur = sh0 + sgitg*NP;

            // the last batch may contain fewer than NSG valid rows
            const bool row_valid = rr + sgitg < N;

            for (short t = 0; t*NW < N; ++t) {
                const short idx = t*NW + tiisg;

                // do not read past the end of the row or past the last row of A;
                // the padding is filled with zeros and never used
                if (row_valid && idx < N) {
                    sh0_cur[idx] = src0_ptr[idx];
                } else {
                    sh0_cur[idx] = 0.0f;
                }
            }

            src0_ptr += NSG*N;
        }

        threadgroup_barrier(mem_flags::mem_threadgroup);

        // simdgroups without a column still have to take part in the staging and
        // the barriers above, so they only skip the substitution
        if (i01 >= args.ne10) {
            continue;
        }

        for (short ir = 0; ir < NSG && rr + ir < N; ++ir) {
            const short r = rr + ir;

            threadgroup float * sh0_cur = sh0 + ir*NP;

            float sum = 0.0f;

            for (short t = 0; t*NW < r; ++t) {
                const short idx = t*NW + tiisg;

                // only already solved entries (idx < r) may contribute. A real
                // branch is required here: masking with `* (idx < r)` would still
                // read dst entries that have not been written yet (or lie past the
                // last row) and any NaN/Inf found there would poison the sum
                if (idx < r) {
                    sum += sh0_cur[idx] * dst_ptr[idx*K];
                }
            }

            sum = simd_sum(sum);

            if (tiisg == 0) {
                const float diag = sh0_cur[r];

                dst_ptr[r*K] = (src1_ptr[r*K] - sum) / diag;
            }
        }
    }
}
-279
View File
@@ -1,279 +0,0 @@
#include "common.h"
// ref: ggml.c:ggml_compute_forward_ssm_conv_f32
//
// One thread position per output element: (tgpig.x, tgpig.y, tgpig.z) -> (ir, i2, i3).
// The output is the dot product of args.ne10 consecutive floats of src0, starting at
// byte offset ir*nb01 + i2*nb00 + i3*nb02, with row ir of src1.
kernel void kernel_ssm_conv_f32_f32(
        constant ggml_metal_kargs_ssm_conv & args,
        device const  void * src0,
        device const  void * src1,
        device       float * dst,
        uint3 tgpig[[threadgroup_position_in_grid]],
        uint3 tpitg[[thread_position_in_threadgroup]],
        uint3   ntg[[threads_per_threadgroup]]) {
    const int64_t ir = tgpig.x;
    const int64_t i2 = tgpig.y;
    const int64_t i3 = tgpig.z;

    // number of taps of the convolution
    const int64_t n_taps = args.ne10;

    device const char * in_bytes  = (device const char *) src0 + ir*args.nb01 + i2*args.nb00 + i3*args.nb02;
    device const char * w_bytes   = (device const char *) src1 + ir*args.nb11;
    device       char * out_bytes = (device       char *) dst  + ir*args.nb0  + i2*args.nb1  + i3*args.nb2;

    device const float * in  = (device const float *) in_bytes;
    device const float * w   = (device const float *) w_bytes;
    device       float * out = (device       float *) out_bytes;

    float acc = 0.0f;

    for (int64_t k = 0; k < n_taps; ++k) {
        acc += in[k] * w[k];
    }

    out[0] = acc;
}
// float4 variant of the ssm conv kernel.
//
// One thread position per output element: (tgpig.x, tgpig.y, tgpig.z) -> (ir, i2, i3).
// The output is the dot product of args.ne10/4 consecutive float4 of src0, starting
// at byte offset ir*nb01 + i2*nb00 + i3*nb02, with row ir of src1.
kernel void kernel_ssm_conv_f32_f32_4(
        constant ggml_metal_kargs_ssm_conv & args,
        device const  void * src0,
        device const  void * src1,
        device       float * dst,
        uint3 tgpig[[threadgroup_position_in_grid]],
        uint3 tpitg[[thread_position_in_threadgroup]],
        uint3   ntg[[threads_per_threadgroup]]) {
    const int64_t ir = tgpig.x;
    const int64_t i2 = tgpig.y;
    const int64_t i3 = tgpig.z;

    // number of taps of the convolution (scalar count)
    const int64_t n_taps = args.ne10;

    device const char * in_bytes  = (device const char *) src0 + ir*args.nb01 + i2*args.nb00 + i3*args.nb02;
    device const char * w_bytes   = (device const char *) src1 + ir*args.nb11;
    device       char * out_bytes = (device       char *) dst  + ir*args.nb0  + i2*args.nb1  + i3*args.nb2;

    device const float4 * in4 = (device const float4 *) in_bytes;
    device const float4 * w4  = (device const float4 *) w_bytes;
    device       float  * out = (device       float  *) out_bytes;

    float acc = 0.0f;

    for (int64_t k = 0; k < n_taps/4; ++k) {
        acc += dot(in4[k], w4[k]);
    }

    out[0] = acc;
}
constant short FC_ssm_conv_bs [[function_constant(FC_SSM_CONV + 0)]];

// Batched SSM convolution: a threadgroup covers FC_ssm_conv_bs consecutive
// tokens of one row, one thread per token.
//
//   tgpig.x - row
//   tgpig.y - token batch (first token = tgpig.y * FC_ssm_conv_bs)
//   tgpig.z - sequence
//   tpitg.x - token offset inside the batch
kernel void kernel_ssm_conv_f32_f32_batched(
        constant ggml_metal_kargs_ssm_conv & args,
        device const void * src0,
        device const void * src1,
        device float * dst,
        uint3 tgpig[[threadgroup_position_in_grid]],
        uint3 tpitg[[thread_position_in_threadgroup]],
        uint3 ntg[[threads_per_threadgroup]]) {
    const short tokens_per_tg = FC_ssm_conv_bs;

    const int64_t row = tgpig.x;
    const int64_t seq = tgpig.z;

    const int64_t tok_first = tgpig.y * tokens_per_tg;
    const int64_t tok_local = tpitg.x;
    const int64_t tok       = tok_first + tok_local;

    const int64_t n_taps   = args.ne10; // conv kernel size (typically 4)
    const int64_t n_tokens = args.ne1;  // number of tokens

    // the last batch may be partial: threads past the final token have nothing to do
    if (tok >= n_tokens) {
        return;
    }

    // weights depend only on the row, so every token of the batch reads the same ones
    device const float * weights = (device const float *) ((device const char *) src1 + row*args.nb11);

    // input window and output slot of this thread's token
    device const float * window = (device const float *) ((device const char *) src0 + row*args.nb01 + tok*args.nb00 + seq*args.nb02);
    device       float * out    = (device float *) ((device char *) dst + row*args.nb0 + tok*args.nb1 + seq*args.nb2);

    float acc = 0.0f;

    int64_t k = 0;
    while (k < n_taps) {
        acc += window[k] * weights[k];
        ++k;
    }

    *out = acc;
}
// float4 variant of the batched SSM convolution: same thread layout
// (one thread per token, FC_ssm_conv_bs tokens per threadgroup), but the
// window and the weights are read as float4 and accumulated with dot().
//
//   tgpig.x - row
//   tgpig.y - token batch (first token = tgpig.y * FC_ssm_conv_bs)
//   tgpig.z - sequence
//   tpitg.x - token offset inside the batch
kernel void kernel_ssm_conv_f32_f32_batched_4(
        constant ggml_metal_kargs_ssm_conv & args,
        device const void * src0,
        device const void * src1,
        device float * dst,
        uint3 tgpig[[threadgroup_position_in_grid]],
        uint3 tpitg[[thread_position_in_threadgroup]],
        uint3 ntg[[threads_per_threadgroup]]) {
    const short tokens_per_tg = FC_ssm_conv_bs;

    const int64_t row = tgpig.x;
    const int64_t seq = tgpig.z;

    const int64_t tok_first = tgpig.y * tokens_per_tg;
    const int64_t tok_local = tpitg.x;
    const int64_t tok       = tok_first + tok_local;

    const int64_t n_taps   = args.ne10; // conv kernel size (typically 4)
    const int64_t n_tokens = args.ne1;  // number of tokens

    // the last batch may be partial: threads past the final token have nothing to do
    if (tok >= n_tokens) {
        return;
    }

    // weights depend only on the row, so every token of the batch reads the same ones
    device const float4 * weights4 = (device const float4 *) ((device const char *) src1 + row*args.nb11);

    // input window and output slot of this thread's token
    device const float4 * window4 = (device const float4 *) ((device const char *) src0 + row*args.nb01 + tok*args.nb00 + seq*args.nb02);
    device       float  * out     = (device float *) ((device char *) dst + row*args.nb0 + tok*args.nb1 + seq*args.nb2);

    float acc = 0.0f;

    int64_t k = 0;
    while (k < n_taps/4) {
        acc += dot(window4[k], weights4[k]);
        ++k;
    }

    *out = acc;
}
// ref: ggml.c:ggml_compute_forward_ssm_scan_f32, Mamba-2 part
// Optimized version: reduces redundant memory loads by having one thread load shared values
//
// Indexing used below:
//   tpitg.x -> i0 : position along the state dimension (B, C and A are read at i0)
//   tgpig.x -> i1, tgpig.y -> ir (head), tgpig.z -> i3 (sequence)
//
// Tokens are consumed in batches of sgptg. For every token each thread updates its state element
//   s = s0*exp(dtsp*A0) + B[i0]*(x*dtsp),   dtsp = softplus(dt) (dt itself when dt > 20)
// and the products s*C[i0] are summed over the threadgroup to give one value of y.
// After the last token the per-thread state is written to dst at byte offset s_off.
//
// NOTE(review): the two-level reduction stores one partial per simdgroup in a row of NW floats and lets
// simdgroup sgitg emit token sgitg of the batch - this looks like it requires sgptg <= NW and
// sgptg*NW threads per threadgroup; confirm against the host-side dispatch.
kernel void kernel_ssm_scan_f32(
constant ggml_metal_kargs_ssm_scan & args,
device const void * src0,
device const void * src1,
device const void * src2,
device const void * src3,
device const void * src4,
device const void * src5,
device const void * src6,
device float * dst,
threadgroup float * shared [[threadgroup(0)]],
uint3 tgpig[[threadgroup_position_in_grid]],
ushort3 tpitg[[thread_position_in_threadgroup]],
ushort sgitg[[simdgroup_index_in_threadgroup]],
ushort tiisg[[thread_index_in_simdgroup]],
ushort sgptg[[simdgroups_per_threadgroup]],
uint3 tgpg[[threadgroups_per_grid]]) {
constexpr short NW = N_SIMDWIDTH;
// Shared memory layout:
// [0..sgptg*NW-1]: partial sums for reduction (existing)
// [sgptg*NW..sgptg*NW+sgptg-1]: pre-computed x_dt values for each token in batch
// [sgptg*NW+sgptg..sgptg*NW+2*sgptg-1]: pre-computed dA values for each token in batch
threadgroup float * shared_sums = shared;
threadgroup float * shared_x_dt = shared + sgptg * NW;
threadgroup float * shared_dA = shared + sgptg * NW + sgptg;
// every thread clears its own slot; slots that are never written below stay zero for the final simd_sum
shared_sums[tpitg.x] = 0.0f;
const int32_t i0 = tpitg.x;
const int32_t i1 = tgpig.x;
const int32_t ir = tgpig.y; // current head
const int32_t i3 = tgpig.z; // current seq
const int32_t nc = args.d_state;
const int32_t nr = args.d_inner;
const int32_t nh = args.n_head;
const int32_t ng = args.n_group;
const int32_t n_t = args.n_seq_tokens;
const int32_t s_off = args.s_off;
// src6 holds, per sequence, the index of the state in src0 to start from
device const int32_t * ids = (device const int32_t *) src6;
device const float * s0_buff = (device const float *) ((device const char *) src0 + ir*args.nb02 + ids[i3]*args.nb03);
// updated states are written into dst, s_off bytes past its start
device float * s_buff = (device float *) ((device char *) dst + ir*args.nb02 + i3*args.nb03 + s_off);
const int32_t i = i0 + i1*nc;
const int32_t g = ir / (nh / ng); // repeat_interleave
float s0 = s0_buff[i];
float s = 0.0f;
device const float * A = (device const float *) ((device const char *) src3 + ir*args.nb31); // {ne30, nh}
// the modulo lets A be indexed when ne30 is smaller than the number of threads (e.g. a single value per head)
const float A0 = A[i0%args.ne30];
device const float * x = (device const float *)((device const char *) src1 + i1*args.nb10 + ir*args.nb11 + i3*args.nb13); // {dim, nh, nt, ns}
device const float * dt = (device const float *)((device const char *) src2 + ir*args.nb20 + i3*args.nb22); // {nh, nt, ns}
device const float * B = (device const float *)((device const char *) src4 + g*args.nb41 + i3*args.nb43); // {d_state, ng, nt, ns}
device const float * C = (device const float *)((device const char *) src5 + g*args.nb51 + i3*args.nb53); // {d_state, ng, nt, ns}
device float * y = dst + (i1 + ir*(nr) + i3*(n_t*nh*nr)); // {dim, nh, nt, ns}
for (int i2 = 0; i2 < n_t; i2 += sgptg) {
// make sure the previous batch has finished reading the shared arrays before they are overwritten
threadgroup_barrier(mem_flags::mem_threadgroup);
// Pre-compute x_dt and dA for this batch of tokens
// Only first sgptg threads do the loads and expensive math
if (i0 < sgptg && i2 + i0 < n_t) {
// ns12 and ns21 are element strides (nb12/nb10, nb21/nb20)
device const float * x_t = x + i0 * args.ns12;
device const float * dt_t = dt + i0 * args.ns21;
const float dt0 = dt_t[0];
// softplus, with the identity used for large inputs
const float dtsp = dt0 <= 20.0f ? log(1.0f + exp(dt0)) : dt0;
shared_x_dt[i0] = x_t[0] * dtsp;
shared_dA[i0] = dtsp; // Store dtsp, compute exp(dtsp * A0) per-thread since A0 varies
}
threadgroup_barrier(mem_flags::mem_threadgroup);
// sequential recurrence over the tokens of this batch
for (int t = 0; t < sgptg && i2 + t < n_t; t++) {
const float x_dt = shared_x_dt[t];
const float dA = exp(shared_dA[t] * A0);
s = (s0 * dA) + (B[i0] * x_dt);
// first reduction level: sum within the simdgroup, one partial per simdgroup stored in row t
const float sumf = simd_sum(s * C[i0]);
if (tiisg == 0) {
shared_sums[t*NW + sgitg] = sumf;
}
// recurse
s0 = s;
B += args.ns42;
C += args.ns52;
}
// Advance pointers for next batch
x += sgptg * args.ns12;
dt += sgptg * args.ns21;
threadgroup_barrier(mem_flags::mem_threadgroup);
// second reduction level: simdgroup sgitg sums row sgitg, i.e. the partials of token i2 + sgitg
const float sumf = simd_sum(shared_sums[sgitg*NW + tiisg]);
if (tiisg == 0 && i2 + sgitg < n_t) {
y[sgitg*nh*nr] = sumf;
}
y += sgptg*nh*nr;
}
// store the state after the last token (0.0f if the token loop did not run)
s_buff[i] = s;
}
-69
View File
@@ -1,69 +0,0 @@
#include "common.h"
// Triangular-mask predicate: tells whether column i of row r is kept for the
// given triangle type. Only the four explicit specializations below exist.
template<uint32_t ttype>
bool _ggml_vec_tri_cmp(const int i, const int r);

// strictly below the diagonal
template<>
bool _ggml_vec_tri_cmp</* GGML_TRI_TYPE_LOWER */ 3>(const int i, const int r) {
    return r > i;
}

// below the diagonal, diagonal included
template<>
bool _ggml_vec_tri_cmp</* GGML_TRI_TYPE_LOWER_DIAG */ 2>(const int i, const int r) {
    return !(i > r);
}

// strictly above the diagonal
template<>
bool _ggml_vec_tri_cmp</* GGML_TRI_TYPE_UPPER */ 1>(const int i, const int r) {
    return r < i;
}

// above the diagonal, diagonal included
template<>
bool _ggml_vec_tri_cmp</* GGML_TRI_TYPE_UPPER_DIAG */ 0>(const int i, const int r) {
    return !(i < r);
}
// Triangular masking of a tensor: every element (i0, i1) of src0 is copied to dst
// when _ggml_vec_tri_cmp<ttype>(i0, i1) holds and written as zero otherwise.
//
//   T      - element type of src0 and dst
//   ttype  - triangle type, forwarded to _ggml_vec_tri_cmp
//
//   tgpig.x -> i1 (row), tgpig.y -> i2, tgpig.z -> i3
//   the threads of a threadgroup stride over the columns of one row
template<typename T, int ttype>
kernel void kernel_tri(
        constant ggml_metal_kargs_tri & args,
        device const char * src0,
        device       char * dst, // written by this kernel, hence not const
        uint3 tgpig[[threadgroup_position_in_grid]],
        ushort3 tpitg[[thread_position_in_threadgroup]],
        ushort3 ntg[[threads_per_threadgroup]]) {
    const int i3 = tgpig.z;
    const int i2 = tgpig.y;
    const int i1 = tgpig.x;

    if (i3 >= args.ne03 || i2 >= args.ne02 || i1 >= args.ne01) {
        return;
    }

    device const T * src_row = (device const T *) ((device const char *) src0 + i1*args.nb01 + i2*args.nb02 + i3*args.nb03);
    device       T * dst_row = (device       T *) ((device       char *) dst  + i1*args.nb1  + i2*args.nb2  + i3*args.nb3);

    // Each thread is a single element of the row if ne00 < max threads per
    // threadgroup, so this will loop once for each index that this thread is
    // responsible for
    for (int64_t i0 = tpitg.x; i0 < args.ne00; i0 += ntg.x) {
        const bool keep = _ggml_vec_tri_cmp<ttype>(i0, i1);

        // Select instead of multiplying by a 0/1 mask: a product would turn
        // Inf/NaN source values in the masked-out region into NaN, whereas the
        // masked-out part of the result must be exactly zero.
        dst_row[i0] = keep ? src_row[i0] : static_cast<T>(0);
    }
}
// Explicit instantiations of kernel_tri, exported to the host under "kernel_tri_<type>_<ttype>".
// The trailing digit is the ttype template argument (0 = UPPER_DIAG, 1 = UPPER, 2 = LOWER_DIAG, 3 = LOWER,
// matching the _ggml_vec_tri_cmp specializations).
// The kernel's parameter list does not depend on T or ttype, so the function type of the
// <float, 0> instantiation is reused for all of them.
typedef decltype(kernel_tri<float, 0>) kernel_tri_t;
template [[host_name("kernel_tri_f32_0")]] kernel kernel_tri_t kernel_tri<float, 0>;
template [[host_name("kernel_tri_f32_1")]] kernel kernel_tri_t kernel_tri<float, 1>;
template [[host_name("kernel_tri_f32_2")]] kernel kernel_tri_t kernel_tri<float, 2>;
template [[host_name("kernel_tri_f32_3")]] kernel kernel_tri_t kernel_tri<float, 3>;
template [[host_name("kernel_tri_f16_0")]] kernel kernel_tri_t kernel_tri<half, 0>;
template [[host_name("kernel_tri_f16_1")]] kernel kernel_tri_t kernel_tri<half, 1>;
template [[host_name("kernel_tri_f16_2")]] kernel kernel_tri_t kernel_tri<half, 2>;
template [[host_name("kernel_tri_f16_3")]] kernel kernel_tri_t kernel_tri<half, 3>;
// bfloat instantiations are only compiled when the build defines GGML_METAL_HAS_BF16
#if defined(GGML_METAL_HAS_BF16)
template [[host_name("kernel_tri_bf16_0")]] kernel kernel_tri_t kernel_tri<bfloat, 0>;
template [[host_name("kernel_tri_bf16_1")]] kernel kernel_tri_t kernel_tri<bfloat, 1>;
template [[host_name("kernel_tri_bf16_2")]] kernel kernel_tri_t kernel_tri<bfloat, 2>;
template [[host_name("kernel_tri_bf16_3")]] kernel kernel_tri_t kernel_tri<bfloat, 3>;
#endif
-360
View File
@@ -1,360 +0,0 @@
#include "common.h"
// Function constants that specialize the unary kernel when the pipeline is built:
//   FC_unary_op  - which operation to apply (one of the OP_UNARY_NUM_* values)
//   FC_unary_cnt - selects the flat indexing mode (see kernel_unary_impl)
constant short FC_unary_op [[function_constant(FC_UNARY + 0)]];
constant bool FC_unary_cnt[[function_constant(FC_UNARY + 1)]];
// Generic element-wise unary kernel.
//
//   T0 - element type read from src0
//   T  - element type written to dst
//   TC - type in which the operation is evaluated
//
// Every thread handles exactly one element (scalar or vector, depending on the instantiation).
// Since the operation is chosen by a function constant, each `if (FC_OP == ...)` below compares
// against a value fixed at pipeline creation, so only one of them applies for a given pipeline.
template <typename T0, typename T, typename TC>
kernel void kernel_unary_impl(
constant ggml_metal_kargs_unary & args,
device const char * src0,
device char * dst,
uint3 tgpig[[threadgroup_position_in_grid]],
ushort3 tpitg[[thread_position_in_threadgroup]],
ushort3 ntg[[threads_per_threadgroup]]) {
#define FC_OP FC_unary_op
#define FC_CNT FC_unary_cnt
device const T0 * src0_ptr;
device T * dst_ptr;
int i0;
if (FC_CNT) {
// flat mode: the threadgroup x position is used directly as the element index and no strides are applied
// NOTE(review): presumably used for contiguous tensors with one thread per threadgroup - confirm on the host side
i0 = tgpig.x;
src0_ptr = (device const T0 *) (src0);
dst_ptr = (device T *) (dst);
} else {
// strided mode: tgpig.x encodes both the row (i01) and the chunk of ntg.x elements (k0) inside that row
const int i03 = tgpig.z;
const int i02 = tgpig.y;
const int k0 = tgpig.x/args.ne01;
const int i01 = tgpig.x - k0*args.ne01;
i0 = k0*ntg.x + tpitg.x;
src0_ptr = (device const T0 *) (src0 + i03*args.nb03 + i02*args.nb02 + i01*args.nb01);
dst_ptr = (device T *) (dst + i03*args.nb3 + i02*args.nb2 + i01*args.nb1 );
}
{
//threadgroup_barrier(mem_flags::mem_none);
// the last chunk of a row may be partial; the flat mode performs no bounds check
if (!FC_CNT) {
if (i0 >= args.ne0) {
return;
}
}
const TC x = (TC) src0_ptr[i0];
if (FC_OP == OP_UNARY_NUM_SCALE) {
dst_ptr[i0] = (T) (args.scale * x + args.bias);
}
if (FC_OP == OP_UNARY_NUM_FILL) {
dst_ptr[i0] = (T) args.val;
}
if (FC_OP == OP_UNARY_NUM_CLAMP) {
dst_ptr[i0] = (T) clamp(x, args.min, args.max);
}
if (FC_OP == OP_UNARY_NUM_SQR) {
dst_ptr[i0] = (T) (x * x);
}
if (FC_OP == OP_UNARY_NUM_SQRT) {
dst_ptr[i0] = (T) sqrt(x);
}
if (FC_OP == OP_UNARY_NUM_SIN) {
dst_ptr[i0] = (T) sin(x);
}
if (FC_OP == OP_UNARY_NUM_COS) {
dst_ptr[i0] = (T) cos(x);
}
if (FC_OP == OP_UNARY_NUM_LOG) {
dst_ptr[i0] = (T) log(x);
}
if (FC_OP == OP_UNARY_NUM_LEAKY_RELU) {
// x for x > 0, slope*x otherwise, written with 0/1 masks
dst_ptr[i0] = (T) (TC(x > 0)*x + TC(x <= 0)*(x * args.slope));
}
if (FC_OP == OP_UNARY_NUM_TANH) {
dst_ptr[i0] = (T) precise::tanh(x);
}
if (FC_OP == OP_UNARY_NUM_RELU) {
dst_ptr[i0] = (T) fmax(0, x);
}
if (FC_OP == OP_UNARY_NUM_SIGMOID) {
dst_ptr[i0] = (T) (1 / (1 + exp(-x)));
}
if (FC_OP == OP_UNARY_NUM_GELU) {
// tanh approximation of GELU
dst_ptr[i0] = (T) (0.5*x*(1 + precise::tanh(SQRT_2_OVER_PI*x*(1 + GELU_COEF_A*x*x))));
}
if (FC_OP == OP_UNARY_NUM_GELU_ERF) {
dst_ptr[i0] = (T) (0.5*x*(1 + erf_approx(SQRT_2_INV*x)));
}
if (FC_OP == OP_UNARY_NUM_GELU_QUICK) {
dst_ptr[i0] = (T) (x * (1/(1 + exp(GELU_QUICK_COEF*x))));
}
if (FC_OP == OP_UNARY_NUM_SILU) {
dst_ptr[i0] = (T) (x / (1 + exp(-x)));
}
if (FC_OP == OP_UNARY_NUM_ELU) {
dst_ptr[i0] = (T) elu_approx(x);
}
if (FC_OP == OP_UNARY_NUM_NEG) {
dst_ptr[i0] = (T) -x;
}
if (FC_OP == OP_UNARY_NUM_ABS) {
dst_ptr[i0] = (T) fabs(x);
}
if (FC_OP == OP_UNARY_NUM_SGN) {
// +1, 0 or -1
dst_ptr[i0] = T(x > 0) - T(x < 0);
}
if (FC_OP == OP_UNARY_NUM_STEP) {
dst_ptr[i0] = T(x > 0);
}
if (FC_OP == OP_UNARY_NUM_HARDSWISH) {
dst_ptr[i0] = (T) (x * fmax(0, fmin(1, x/6 + 0.5)));
}
if (FC_OP == OP_UNARY_NUM_HARDSIGMOID) {
dst_ptr[i0] = (T) fmax(0, fmin(1, x/6 + 0.5));
}
if (FC_OP == OP_UNARY_NUM_EXP) {
dst_ptr[i0] = (T) exp(x);
}
if (FC_OP == OP_UNARY_NUM_SOFTPLUS) {
// log(1 + exp(x)), with x itself returned where x > 20
dst_ptr[i0] = (T) select(log(1 + exp(x)), x, x > 20);
}
if (FC_OP == OP_UNARY_NUM_EXPM1) {
// TODO: precise implementation
dst_ptr[i0] = (T) (exp(x) - 1);
}
if (FC_OP == OP_UNARY_NUM_FLOOR) {
dst_ptr[i0] = (T) floor(x);
}
if (FC_OP == OP_UNARY_NUM_CEIL) {
dst_ptr[i0] = (T) ceil(x);
}
if (FC_OP == OP_UNARY_NUM_ROUND) {
dst_ptr[i0] = (T) round(x);
}
if (FC_OP == OP_UNARY_NUM_TRUNC) {
dst_ptr[i0] = (T) trunc(x);
}
if (FC_OP == OP_UNARY_NUM_XIELU) {
// piecewise activation: y_pos is used where x > 0, y_neg elsewhere; the branches are blended with a 0/1 gate.
// args.scale, args.bias, args.slope and args.val carry the op parameters (args.val caps the exp() argument).
const TC xi = x;
const TC gate = TC(xi > TC(0.0f));
const TC clamped = fmin(xi, TC(args.val));
const TC y_pos = TC(args.scale) * xi * xi + TC(args.bias) * xi;
const TC y_neg = (exp(clamped) - TC(1.0f) - xi) * TC(args.slope) + TC(args.bias) * xi;
dst_ptr[i0] = (T) (gate * y_pos + (TC(1.0f) - gate) * y_neg);
}
}
#undef FC_OP
#undef FC_CNT
}
// Host-visible instantiations of kernel_unary_impl<T0, T, TC>.
// The half variants read and write half but evaluate the operation in float (TC = float / float4);
// the "_4" variants process four packed values per thread.
typedef decltype(kernel_unary_impl<float, float, float>) kernel_unary_t;
template [[host_name("kernel_unary_f32_f32")]] kernel kernel_unary_t kernel_unary_impl<float, float, float>;
template [[host_name("kernel_unary_f32_f32_4")]] kernel kernel_unary_t kernel_unary_impl<float4, float4, float4>;
template [[host_name("kernel_unary_f16_f16")]] kernel kernel_unary_t kernel_unary_impl<half, half, float>;
template [[host_name("kernel_unary_f16_f16_4")]] kernel kernel_unary_t kernel_unary_impl<half4, half4, float4>;
// ReGLU: dst[i] = x0*x1 where x0 > 0 and x0*x1*0 elsewhere, with x0 taken from
// src0 and x1 from src1. One threadgroup per row (tgpig); the threads of the
// group stride over the args.ne0 columns. Values are converted to float for
// the arithmetic and cast back to T on store.
template<typename T>
kernel void kernel_reglu(
        constant ggml_metal_kargs_glu & args,
        device const char * src0,
        device const char * src1,
        device char * dst,
        uint tgpig[[threadgroup_position_in_grid]],
        uint tpitg[[thread_position_in_threadgroup]],
        uint ntg[[threads_per_threadgroup]]) {
    // rows are located by byte stride; args.i00 / args.i10 are element offsets inside the row
    device const T * a_row   = (device const T *) ((device const char *) src0 + tgpig*args.nb01) + args.i00;
    device const T * b_row   = (device const T *) ((device const char *) src1 + tgpig*args.nb11) + args.i10;
    device       T * out_row = (device       T *) ((device       char *) dst  + tgpig*args.nb1);

    int col = tpitg;
    while (col < args.ne0) {
        const float a = a_row[col];
        const float b = b_row[col];

        out_row[col] = (T)(a*b*(a > 0.0f));

        col += ntg;
    }
}

typedef decltype(kernel_reglu<float>) kernel_reglu_t;

template [[host_name("kernel_reglu_f32")]] kernel kernel_reglu_t kernel_reglu<float>;
template [[host_name("kernel_reglu_f16")]] kernel kernel_reglu_t kernel_reglu<half>;
// GEGLU (tanh approximation): dst[i] = gelu(x0) * x1, with x0 taken from src0
// and x1 from src1. One threadgroup per row (tgpig); the threads of the group
// stride over the args.ne0 columns. The arithmetic is done in float.
template<typename T>
kernel void kernel_geglu(
        constant ggml_metal_kargs_glu & args,
        device const char * src0,
        device const char * src1,
        device char * dst,
        uint tgpig[[threadgroup_position_in_grid]],
        uint tpitg[[thread_position_in_threadgroup]],
        uint ntg[[threads_per_threadgroup]]) {
    // rows are located by byte stride; args.i00 / args.i10 are element offsets inside the row
    device const T * a_row   = (device const T *) ((device const char *) src0 + tgpig*args.nb01) + args.i00;
    device const T * b_row   = (device const T *) ((device const char *) src1 + tgpig*args.nb11) + args.i10;
    device       T * out_row = (device       T *) ((device       char *) dst  + tgpig*args.nb1);

    int col = tpitg;
    while (col < args.ne0) {
        const float a = a_row[col];
        const float b = b_row[col];

        const float act = 0.5f*a*(1.0f + precise::tanh(SQRT_2_OVER_PI*a*(1.0f + GELU_COEF_A*a*a)));

        out_row[col] = (T)(act*b);

        col += ntg;
    }
}

typedef decltype(kernel_geglu<float>) kernel_geglu_t;

template [[host_name("kernel_geglu_f32")]] kernel kernel_geglu_t kernel_geglu<float>;
template [[host_name("kernel_geglu_f16")]] kernel kernel_geglu_t kernel_geglu<half>;
// SwiGLU: dst[i] = silu(x0) * x1 with silu(x) = x / (1 + exp(-x)), x0 taken
// from src0 and x1 from src1. One threadgroup per row (tgpig); the threads of
// the group stride over the args.ne0 columns. The arithmetic is done in float.
template<typename T>
kernel void kernel_swiglu(
        constant ggml_metal_kargs_glu & args,
        device const char * src0,
        device const char * src1,
        device char * dst,
        uint tgpig[[threadgroup_position_in_grid]],
        uint tpitg[[thread_position_in_threadgroup]],
        uint ntg[[threads_per_threadgroup]]) {
    // rows are located by byte stride; args.i00 / args.i10 are element offsets inside the row
    device const T * a_row   = (device const T *) ((device const char *) src0 + tgpig*args.nb01) + args.i00;
    device const T * b_row   = (device const T *) ((device const char *) src1 + tgpig*args.nb11) + args.i10;
    device       T * out_row = (device       T *) ((device       char *) dst  + tgpig*args.nb1);

    int col = tpitg;
    while (col < args.ne0) {
        const float a = a_row[col];
        const float b = b_row[col];

        const float act = a / (1.0f + exp(-a));

        out_row[col] = (T)(act*b);

        col += ntg;
    }
}

typedef decltype(kernel_swiglu<float>) kernel_swiglu_t;

template [[host_name("kernel_swiglu_f32")]] kernel kernel_swiglu_t kernel_swiglu<float>;
template [[host_name("kernel_swiglu_f16")]] kernel kernel_swiglu_t kernel_swiglu<half>;
// Clamped SwiGLU variant:
//   x0 is limited from above by args.limit, x1 is limited to [-args.limit, args.limit],
//   dst[i] = x0 / (1 + exp(-x0 * args.alpha)) * (1 + x1)
// with x0 taken from src0 and x1 from src1. One threadgroup per row (tgpig);
// the threads of the group stride over the args.ne0 columns.
template<typename T>
kernel void kernel_swiglu_oai(
        constant ggml_metal_kargs_glu & args,
        device const char * src0,
        device const char * src1,
        device char * dst,
        uint tgpig[[threadgroup_position_in_grid]],
        uint tpitg[[thread_position_in_threadgroup]],
        uint ntg[[threads_per_threadgroup]]) {
    // rows are located by byte stride; args.i00 / args.i10 are element offsets inside the row
    device const T * a_row   = (device const T *) ((device const char *) src0 + tgpig*args.nb01) + args.i00;
    device const T * b_row   = (device const T *) ((device const char *) src1 + tgpig*args.nb11) + args.i10;
    device       T * out_row = (device       T *) ((device       char *) dst  + tgpig*args.nb1);

    int col = tpitg;
    while (col < args.ne0) {
        float a = a_row[col];
        float b = b_row[col];

        // one-sided clamp for a, symmetric clamp for b
        a = min(a, args.limit);
        b = max(min(b, args.limit), -args.limit);

        float res = a / (1.0f + exp(-a * args.alpha));
        res = res * (1.0f + b);

        out_row[col] = (T)res;

        col += ntg;
    }
}

typedef decltype(kernel_swiglu_oai<float>) kernel_swiglu_oai_t;

template [[host_name("kernel_swiglu_oai_f32")]] kernel kernel_swiglu_oai_t kernel_swiglu_oai<float>;
template [[host_name("kernel_swiglu_oai_f16")]] kernel kernel_swiglu_oai_t kernel_swiglu_oai<half>;
// GEGLU (erf form): dst[i] = 0.5*x0*(1 + erf(x0/sqrt(2))) * x1, using
// erf_approx, with x0 taken from src0 and x1 from src1. One threadgroup per row
// (tgpig); the threads of the group stride over the args.ne0 columns.
template<typename T>
kernel void kernel_geglu_erf(
        constant ggml_metal_kargs_glu & args,
        device const char * src0,
        device const char * src1,
        device char * dst,
        uint tgpig[[threadgroup_position_in_grid]],
        uint tpitg[[thread_position_in_threadgroup]],
        uint ntg[[threads_per_threadgroup]]) {
    // rows are located by byte stride; args.i00 / args.i10 are element offsets inside the row
    device const T * a_row   = (device const T *) ((device const char *) src0 + tgpig*args.nb01) + args.i00;
    device const T * b_row   = (device const T *) ((device const char *) src1 + tgpig*args.nb11) + args.i10;
    device       T * out_row = (device       T *) ((device       char *) dst  + tgpig*args.nb1);

    int col = tpitg;
    while (col < args.ne0) {
        const float a = a_row[col];
        const float b = b_row[col];

        const float act = 0.5f*a*(1.0f+erf_approx<float>(a*SQRT_2_INV));

        out_row[col] = (T)(act*b);

        col += ntg;
    }
}

typedef decltype(kernel_geglu_erf<float>) kernel_geglu_erf_t;

template [[host_name("kernel_geglu_erf_f32")]] kernel kernel_geglu_erf_t kernel_geglu_erf<float>;
template [[host_name("kernel_geglu_erf_f16")]] kernel kernel_geglu_erf_t kernel_geglu_erf<half>;
// GEGLU (quick form): dst[i] = x0 / (1 + exp(GELU_QUICK_COEF*x0)) * x1, with x0
// taken from src0 and x1 from src1. One threadgroup per row (tgpig); the
// threads of the group stride over the args.ne0 columns.
template<typename T>
kernel void kernel_geglu_quick(
        constant ggml_metal_kargs_glu & args,
        device const char * src0,
        device const char * src1,
        device char * dst,
        uint tgpig[[threadgroup_position_in_grid]],
        uint tpitg[[thread_position_in_threadgroup]],
        uint ntg[[threads_per_threadgroup]]) {
    // rows are located by byte stride; args.i00 / args.i10 are element offsets inside the row
    device const T * a_row   = (device const T *) ((device const char *) src0 + tgpig*args.nb01) + args.i00;
    device const T * b_row   = (device const T *) ((device const char *) src1 + tgpig*args.nb11) + args.i10;
    device       T * out_row = (device       T *) ((device       char *) dst  + tgpig*args.nb1);

    int col = tpitg;
    while (col < args.ne0) {
        const float a = a_row[col];
        const float b = b_row[col];

        const float act = a*(1.0f/(1.0f+exp(GELU_QUICK_COEF*a)));

        out_row[col] = (T)(act*b);

        col += ntg;
    }
}

typedef decltype(kernel_geglu_quick<float>) kernel_geglu_quick_t;

template [[host_name("kernel_geglu_quick_f32")]] kernel kernel_geglu_quick_t kernel_geglu_quick<float>;
template [[host_name("kernel_geglu_quick_f16")]] kernel kernel_geglu_quick_t kernel_geglu_quick<half>;
-179
View File
@@ -1,179 +0,0 @@
#include "common.h"
// Function constant: when true, kernel_upscale_bilinear_f32 takes its weighted-average
// path (triangle filter over a support window) instead of the plain 4-tap interpolation.
constant bool FC_upscale_aa [[function_constant(FC_UPSCALE + 0)]];
// Nearest-neighbour upscale: every destination element (i0, i1, i2, i3) copies
// the source element whose index along each dimension is the destination index
// divided by that dimension's scale factor (args.sf0..sf3), truncated to an integer.
// One threadgroup per destination row; its threads stride over the columns.
kernel void kernel_upscale_nearest_f32(
        constant ggml_metal_kargs_upscale & args,
        device const char * src0,
        device char * dst,
        uint3 tgpig[[threadgroup_position_in_grid]],
        uint3 tpitg[[thread_position_in_threadgroup]],
        uint3 ntg[[threads_per_threadgroup]]) {
    // destination coordinates of this row
    const int64_t i1 = tgpig.x;
    const int64_t i2 = tgpig.y;
    const int64_t i3 = tgpig.z;

    // matching source coordinates
    const int64_t i01 = i1/args.sf1;
    const int64_t i02 = i2/args.sf2;
    const int64_t i03 = i3/args.sf3;

    // the row parts of both addresses do not change inside the column loop
    device const char * src_row = src0 + i03*args.nb03 + i02*args.nb02 + i01*args.nb01;
    device       char * dst_row = dst  + i3*args.nb3   + i2*args.nb2   + i1*args.nb1;

    int i0 = tpitg.x;
    while (i0 < args.ne0) {
        const int64_t i00 = i0/args.sf0;

        device const float * from = (device const float *) (src_row + i00*args.nb00);
        device       float * to   = (device       float *) (dst_row + i0*args.nb0);

        *to = *from;

        i0 += ntg.x;
    }
}
// Triangle ("tent") weight: 1 - |x|, floored at 0.
static inline float bilinear_tri(float x) {
    const float w = 1.0f - fabs(x);
    return MAX(0.0f, w);
}
// Bilinear resize of the two innermost dimensions; dimensions 2 and 3 are mapped by dividing
// the destination index by args.sf2 / args.sf3.
// One threadgroup per destination row (tgpig = (i1, i2, i3)); its threads stride over the columns.
// A destination coordinate i is mapped to the source coordinate (i + poffs)/sf - poffs.
// With FC_upscale_aa set, each output is a normalized triangle-weighted average over a support
// window of the source; otherwise it is the usual blend of the 4 neighbouring source samples.
kernel void kernel_upscale_bilinear_f32(
constant ggml_metal_kargs_upscale & args,
device const char * src0,
device char * dst,
uint3 tgpig[[threadgroup_position_in_grid]],
uint3 tpitg[[thread_position_in_threadgroup]],
uint3 ntg[[threads_per_threadgroup]]) {
const int64_t i3 = tgpig.z;
const int64_t i2 = tgpig.y;
const int64_t i1 = tgpig.x;
const int64_t i03 = i3 / args.sf3;
const int64_t i02 = i2 / args.sf2;
// source-space row coordinate, the two clamped rows around it and the blend factor between them
const float f01 = ((float)i1 + args.poffs) / args.sf1 - args.poffs;
const int64_t i01 = MAX(0, MIN(args.ne01 - 1, (int64_t)floor(f01)));
const int64_t i01p = MAX(0, MIN(args.ne01 - 1, i01 + 1));
const float fd1 = MAX(0.0f, MIN(1.0f, f01 - (float)i01));
// from here on src0 points at the selected 2D slice
src0 += i03*args.nb03 + i02*args.nb02;
device float * dst_ptr = (device float *)(dst + i3*args.nb3 + i2*args.nb2 + i1*args.nb1);
if (FC_upscale_aa) {
// filter support is widened to 1/sf when sf < 1 and is never narrower than one source sample
const float support0 = MAX(1.0f, 1.0f / args.sf0);
const float invscale0 = 1.0f / support0;
const float support1 = MAX(1.0f, 1.0f / args.sf1);
const float invscale1 = 1.0f / support1;
for (int i0 = tpitg.x; i0 < args.ne0; i0 += ntg.x) {
const float f00 = ((float)i0 + args.poffs) / args.sf0 - args.poffs;
// half-open source window [min, max) covered by the filter, clipped to the image
int64_t x_min = MAX((int64_t)0, (int64_t)floor(f00 - support0 + args.poffs));
int64_t x_max = MIN(args.ne00, (int64_t)ceil (f00 + support0 + args.poffs));
int64_t y_min = MAX((int64_t)0, (int64_t)floor(f01 - support1 + args.poffs));
int64_t y_max = MIN(args.ne01, (int64_t)ceil (f01 + support1 + args.poffs));
float sum = 0.0f;
float wsum = 0.0f;
for (int64_t sy = y_min; sy < y_max; ++sy) {
const float wy = MAX(0.0f, 1.0f - fabs((float)sy - f01) * invscale1);
for (int64_t sx = x_min; sx < x_max; ++sx) {
const float wx = MAX(0.0f, 1.0f - fabs((float)sx - f00) * invscale0);
const float w = wx * wy;
device const float * src_ptr = (device const float *)(src0 + sy*args.nb01 + sx*args.nb00);
sum += (*src_ptr) * w;
wsum += w;
}
}
// normalize by the accumulated weight; an empty or zero-weight window yields 0
const float v = (wsum > 0.0f) ? (sum / wsum) : 0.0f;
dst_ptr[i0] = v;
}
} else {
for (int i0 = tpitg.x; i0 < args.ne0; i0 += ntg.x) {
// same mapping as for the rows: source column coordinate, clamped neighbours and blend factor
const float f00 = ((float)i0 + args.poffs) / args.sf0 - args.poffs;
const int64_t i00 = MAX(0, MIN(args.ne00 - 1, (int64_t)floor(f00)));
const int64_t i00p = MAX(0, MIN(args.ne00 - 1, i00 + 1));
const float fd0 = MAX(0.0f, MIN(1.0f, f00 - (float)i00));
// the four neighbouring samples: srcXY with X = column offset, Y = row offset
device const float * src00 = (device const float *)(src0 + i01*args.nb01 + i00*args.nb00);
device const float * src10 = (device const float *)(src0 + i01*args.nb01 + i00p*args.nb00);
device const float * src01 = (device const float *)(src0 + i01p*args.nb01 + i00*args.nb00);
device const float * src11 = (device const float *)(src0 + i01p*args.nb01 + i00p*args.nb00);
const float v =
(*src00) * (1.0f - fd0) * (1.0f - fd1) +
(*src10) * fd0 * (1.0f - fd1) +
(*src01) * (1.0f - fd0) * fd1 +
(*src11) * fd0 * fd1;
dst_ptr[i0] = v;
}
}
}
// Cubic convolution weight with a = -0.75 for the two inner taps:
//   w(x) = ((a + 2)*x - (a + 3))*x*x + 1
// (callers in this file pass a distance in [0, 1]).
static inline float bicubic_weight1(float x) {
    const float a = -0.75f;

    const float inner = (a + 2) * x - (a + 3);

    return inner * x * x + 1;
}
// Cubic convolution weight with a = -0.75 for the two outer taps:
//   w(x) = ((a*x - 5a)*x + 8a)*x - 4a
// (callers in this file pass a distance in [1, 2]).
static inline float bicubic_weight2(float x) {
    const float a = -0.75f;

    const float p1 = a * x - 5 * a;
    const float p2 = p1 * x + 8 * a;

    return p2 * x - 4 * a;
}
// Bicubic resize of the two innermost dimensions; dimensions 2 and 3 are mapped by dividing
// the destination index by args.sf2 / args.sf3.
// One threadgroup per destination row (tgpig = (i1, i2, i3)); its threads stride over the columns.
// A destination coordinate i is mapped to the source coordinate (i + poffs)/sf - poffs and the
// output is the weighted sum of the 4x4 source samples around it. Sample coordinates that fall
// outside the image are clamped to the nearest edge.
kernel void kernel_upscale_bicubic_f32(
        constant ggml_metal_kargs_upscale & args,
        device const char * src0,
        device char * dst,
        uint3 tgpig[[threadgroup_position_in_grid]],
        uint3 tpitg[[thread_position_in_threadgroup]],
        uint3 ntg[[threads_per_threadgroup]]) {
    const int64_t i3 = tgpig.z;
    const int64_t i2 = tgpig.y;
    const int64_t i1 = tgpig.x;

    const int64_t i03 = i3 / args.sf3;
    const int64_t i02 = i2 / args.sf2;

    // source-space row coordinate: integer part and fractional distance
    const float   f01 = ((float)i1 + args.poffs) / args.sf1 - args.poffs;
    const int64_t i01 = (int64_t)floor(f01);
    const float   fd1 = f01 - (float)i01;

    // vertical weights of the rows i01-1 .. i01+2; they are the same for the whole destination row
    const float w_y0 = bicubic_weight2(fd1 + 1.0f);
    const float w_y1 = bicubic_weight1(fd1);
    const float w_y2 = bicubic_weight1(1.0f - fd1);
    const float w_y3 = bicubic_weight2(2.0f - fd1);

    // the pointee is declared const exactly once (a repeated qualifier is redundant and triggers a compiler warning)
    device const char * src_slice = src0 + i03 * args.nb03 + i02 * args.nb02;
    device float * dst_ptr = (device float *)(dst + i3 * args.nb3 + i2 * args.nb2 + i1 * args.nb1);

    for (int i0 = tpitg.x; i0 < args.ne0; i0 += ntg.x) {
        // source-space column coordinate: integer part and fractional distance
        const float   f00 = ((float)i0 + args.poffs) / args.sf0 - args.poffs;
        const int64_t i00 = (int64_t)floor(f00);
        const float   fd0 = f00 - (float)i00;

        // horizontal weights of the columns i00-1 .. i00+2
        const float w_x0 = bicubic_weight2(fd0 + 1.0f);
        const float w_x1 = bicubic_weight1(fd0);
        const float w_x2 = bicubic_weight1(1.0f - fd0);
        const float w_x3 = bicubic_weight2(2.0f - fd0);

        float sum = 0.0f;

        for (int dy = -1; dy <= 2; ++dy) {
            // clamp to the image so that border pixels are replicated
            const int64_t iy = MAX(0, MIN(args.ne01 - 1, i01 + dy));
            const float wy = (dy == -1) ? w_y0 : (dy == 0) ? w_y1 : (dy == 1) ? w_y2 : w_y3;

            for (int dx = -1; dx <= 2; ++dx) {
                const int64_t ix = MAX(0, MIN(args.ne00 - 1, i00 + dx));
                const float wx = (dx == -1) ? w_x0 : (dx == 0) ? w_x1 : (dx == 1) ? w_x2 : w_x3;

                device const float * src_ptr = (device const float *)(src_slice + iy * args.nb01 + ix * args.nb00);
                sum += (*src_ptr) * wx * wy;
            }
        }

        dst_ptr[i0] = sum;
    }
}
-179
View File
@@ -1,179 +0,0 @@
#include "common.h"
// RWKV v6 "wkv" recurrence.
//
// Each threadgroup handles one (batch, head) pair (tgpig.x = batch_id*H + head_id) and each
// thread one channel `tid` of that head. A thread keeps one column of the head's
// head_size x head_size state (elements i*head_size + tid) in private memory and walks over
// the tokens of its sequence:
//   y       = sum_j r[j] * (tf[j]*k[j]*v + state[j])
//   state[j] = state[j]*td[j] + k[j]*v
// dst receives the T*C outputs first, followed by the final states.
// NOTE(review): the shared arrays are indexed by tid without a bound check, so the kernel
// presumably expects exactly head_size (64) threads per threadgroup - confirm on the host side.
kernel void kernel_rwkv_wkv6_f32(
device const float * k,
device const float * v,
device const float * r,
device const float * tf,
device const float * td,
device const float * state_in,
device float * dst,
constant uint & B,
constant uint & T,
constant uint & C,
constant uint & H,
uint3 tgpig[[threadgroup_position_in_grid]],
uint3 tpitg[[thread_position_in_threadgroup]],
uint3 ntg[[threads_per_threadgroup]]) {
const uint head_size = 64; // TODO: support head_size = 128
const uint batch_id = tgpig.x / H;
const uint head_id = tgpig.x % H;
const uint tid = tpitg.x;
// uniform per threadgroup, so either all threads leave here or none does
if (batch_id >= B || head_id >= H) {
return;
}
const uint state_size = C * head_size;
const uint n_seq_tokens = T / B;
// per-token values of the current head, shared by all threads of the group
threadgroup float _k[head_size];
threadgroup float _r[head_size];
threadgroup float _tf[head_size];
threadgroup float _td[head_size];
// private copy of state column `tid`
float state[head_size];
for (uint i = 0; i < head_size; i++) {
state[i] = state_in[batch_id * state_size + head_id * head_size * head_size
+ i * head_size + tid];
}
// tf depends only on the head, so it is loaded once, before the token loop
threadgroup_barrier(mem_flags::mem_threadgroup);
_tf[tid] = tf[head_id * head_size + tid];
threadgroup_barrier(mem_flags::mem_threadgroup);
// flat index of this thread's channel in the first token of the sequence and one past its last token;
// consecutive tokens are C elements apart
const uint start_t = batch_id * n_seq_tokens * C + head_id * head_size + tid;
const uint end_t = (batch_id + 1) * n_seq_tokens * C + head_id * head_size + tid;
for (uint t = start_t; t < end_t; t += C) {
// wait until everyone has finished reading the previous token's shared values
threadgroup_barrier(mem_flags::mem_threadgroup);
_k[tid] = k[t];
_r[tid] = r[t];
_td[tid] = td[t];
threadgroup_barrier(mem_flags::mem_threadgroup);
const float v_val = v[t];
float y = 0.0;
// process the head dimension four elements at a time
for (uint j = 0; j < head_size; j += 4) {
float4 k_vec = float4(_k[j], _k[j+1], _k[j+2], _k[j+3]);
float4 r_vec = float4(_r[j], _r[j+1], _r[j+2], _r[j+3]);
float4 tf_vec = float4(_tf[j], _tf[j+1], _tf[j+2], _tf[j+3]);
float4 td_vec = float4(_td[j], _td[j+1], _td[j+2], _td[j+3]);
float4 s_vec = float4(state[j], state[j+1], state[j+2], state[j+3]);
float4 kv = k_vec * v_val;
// the output uses the state from before this token's update
float4 temp = tf_vec * kv + s_vec;
y += dot(r_vec, temp);
s_vec = s_vec * td_vec + kv;
state[j] = s_vec[0];
state[j+1] = s_vec[1];
state[j+2] = s_vec[2];
state[j+3] = s_vec[3];
}
dst[t] = y;
}
// final states are stored after the T*C output values, in the same layout as state_in
for (uint i = 0; i < head_size; i++) {
dst[T * C + batch_id * state_size + head_id * head_size * head_size
+ i * head_size + tid] = state[i];
}
}
// RWKV v7 "wkv" recurrence.
//
// Each threadgroup handles one (batch, head) pair (tgpig.x = batch_id*H + head_id) and each
// thread one channel `tid` of that head. A thread keeps one row of the head's
// head_size x head_size state (elements tid*head_size + i) in private memory and walks over
// the tokens of its sequence:
//   sa       = sum_j a[j]*state[j]
//   state[j] = state[j]*w[j] + k[j]*v + sa*b[j]
//   y        = sum_j state[j]*r[j]          (using the updated state)
// dst receives the T*C outputs first, followed by the final states.
// NOTE(review): the shared arrays are indexed by tid without a bound check, so the kernel
// presumably expects exactly head_size (64) threads per threadgroup - confirm on the host side.
kernel void kernel_rwkv_wkv7_f32(
device const float * r,
device const float * w,
device const float * k,
device const float * v,
device const float * a,
device const float * b,
device const float * state_in,
device float * dst,
constant uint & B,
constant uint & T,
constant uint & C,
constant uint & H,
uint3 tgpig[[threadgroup_position_in_grid]],
uint3 tpitg[[thread_position_in_threadgroup]],
uint3 ntg[[threads_per_threadgroup]]) {
const uint head_size = 64; // TODO: support head_size = 128
const uint batch_id = tgpig.x / H;
const uint head_id = tgpig.x % H;
const uint tid = tpitg.x;
// uniform per threadgroup, so either all threads leave here or none does
if (batch_id >= B || head_id >= H) {
return;
}
const uint state_size = C * head_size;
const uint n_seq_tokens = T / B;
// per-token values of the current head, shared by all threads of the group
threadgroup float _r[head_size];
threadgroup float _w[head_size];
threadgroup float _k[head_size];
threadgroup float _a[head_size];
threadgroup float _b[head_size];
// private copy of state row `tid`
float state[head_size];
for (uint i = 0; i < head_size; i++) {
state[i] = state_in[batch_id * state_size + head_id * head_size * head_size
+ tid * head_size + i];
}
// flat index of this thread's channel in the first token of the sequence and one past its last token;
// consecutive tokens are C elements apart
const uint start_t = batch_id * n_seq_tokens * C + head_id * head_size + tid;
const uint end_t = (batch_id + 1) * n_seq_tokens * C + head_id * head_size + tid;
for (uint t = start_t; t < end_t; t += C) {
// wait until everyone has finished reading the previous token's shared values
threadgroup_barrier(mem_flags::mem_threadgroup);
_r[tid] = r[t];
_w[tid] = w[t];
_k[tid] = k[t];
_a[tid] = a[t];
_b[tid] = b[t];
threadgroup_barrier(mem_flags::mem_threadgroup);
const float v_val = v[t];
float y = 0.0, sa = 0.0;
// first pass: sa = dot(a, state), accumulated in four lanes and summed afterwards
float4 sa_vec(0.0);
for (uint j = 0; j < head_size; j += 4) {
float4 a_vec = float4(_a[j], _a[j+1], _a[j+2], _a[j+3]);
float4 s_vec = float4(state[j], state[j+1], state[j+2], state[j+3]);
sa_vec += a_vec * s_vec;
}
sa = sa_vec[0] + sa_vec[1] + sa_vec[2] + sa_vec[3];
// second pass: update the state and accumulate the output from the updated values
for (uint j = 0; j < head_size; j += 4) {
float4 r_vec = float4(_r[j], _r[j+1], _r[j+2], _r[j+3]);
float4 w_vec = float4(_w[j], _w[j+1], _w[j+2], _w[j+3]);
float4 k_vec = float4(_k[j], _k[j+1], _k[j+2], _k[j+3]);
float4 b_vec = float4(_b[j], _b[j+1], _b[j+2], _b[j+3]);
float4 s_vec = float4(state[j], state[j+1], state[j+2], state[j+3]);
float4 kv = k_vec * v_val;
s_vec = s_vec * w_vec + kv + sa * b_vec;
y += dot(s_vec, r_vec);
state[j] = s_vec[0];
state[j+1] = s_vec[1];
state[j+2] = s_vec[2];
state[j+3] = s_vec[3];
}
dst[t] = y;
}
// final states are stored after the T*C output values, in the same layout as state_in
for (uint i = 0; i < head_size; i++) {
dst[T * C + batch_id * state_size + head_id * head_size * head_size
+ tid * head_size + i] = state[i];
}
}
+3
View File
@@ -192,7 +192,10 @@ set(GGML_OPENCL_KERNELS
mul_mm_f16_f32_kq_kqv
conv2d
conv2d_f16_f32
flash_attn_pre_f16
flash_attn_f32_f16
flash_attn_f32_q8_0
flash_attn_f32_q4_0
flash_attn_f16
flash_attn_f32
)
+91
View File
@@ -0,0 +1,91 @@
#pragma once
// Flash-attention per-(dk,dv) tile tuning for the Adreno OpenCL backend.
// Isolated from ggml-opencl.cpp so the tuning numbers are easy to find and
// edit; the FA dispatch and kernel-compile logic stay in the main file.
// This header is a file section — it is #included exactly once, at the point
// in ggml-opencl.cpp where the ggml logging macros are already in scope.
// Flash-attention configuration for one (dk, dv) pair; used by both the
// dispatch code and supports_op.
struct ggml_opencl_fa_dim {
    int dk;
    int dv;
    int bm;
    int bn;
    int n_split;
    int nkv_split_threshold; // the split variant is used when n_kv >= this value (0 = always)
};
// Split variant fires when n_kv >= threshold (threshold=0 -> always split).
// Default tuning covers Adreno 7xx/8xx mobile and X1-series laptop GPUs.
// Each row is a positional ggml_opencl_fa_dim initializer, i.e. the columns are
// {dk, dv, bm, bn, n_split, nkv_split_threshold}.
// This table is never modified; overrides are applied to a separate mutable copy.
static const ggml_opencl_fa_dim g_fa_dims_adreno_default[] = {
{ 40, 40, 64, 32, 1, 0}, { 64, 64, 64, 32, 2, 64},
{ 80, 80, 64, 32, 2, 64}, { 96, 96, 64, 32, 2, 64},
{112, 112, 64, 32, 2, 64}, {128, 128, 64, 32, 2, 64},
{192, 128, 16, 16, 1, 0},
{192, 192, 16, 16, 1, 0},
{256, 256, 16, 16, 16, 0},
};
// Non-owning (pointer, length) view over an array of FA tuning records.
// begin()/end() expose the half-open range [data, data + count) so the view
// can be walked with a range-based for loop. Kept as a plain aggregate so it
// can be brace-initialized.
struct ggml_opencl_fa_dim_table {
    const ggml_opencl_fa_dim * data;   // first record of the table
    size_t                     count;  // number of records reachable from data

    const ggml_opencl_fa_dim * begin() const {
        return data;
    }

    const ggml_opencl_fa_dim * end() const {
        return begin() + count;
    }
};
// Mutable copy of the active table; GGML_OPENCL_FA_TUNE patches entries here
// at backend init without touching the const source table.
// Sized to hold exactly as many records as g_fa_dims_adreno_default.
static ggml_opencl_fa_dim g_fa_dims_runtime[
sizeof(g_fa_dims_adreno_default) / sizeof(g_fa_dims_adreno_default[0])];
// Active table view. It initially aliases the const default table;
// ggml_cl_init_fa_dims_table() re-points it at g_fa_dims_runtime.
static ggml_opencl_fa_dim_table g_opencl_fa_dims = {
g_fa_dims_adreno_default,
sizeof(g_fa_dims_adreno_default) / sizeof(g_fa_dims_adreno_default[0]),
};
// GGML_OPENCL_FA_TUNE=dk:dv:bm:bn:nsplit:thr[,…] — patches matching entries
// in the active table at backend init, before the first FA kernel compiles.
//
// Each comma-separated entry must consist of exactly six ':'-separated
// integers. Entries are handled independently:
//   - entries that do not parse completely (too few fields, trailing garbage,
//     extra fields) are warned about and skipped;
//   - entries with bm, bn or nsplit <= 0 are warned about and skipped, so an
//     override cannot install a zero/negative tile size or split factor;
//   - unmatched (dk,dv) pairs are warned and ignored.
//
// Patches are written to g_fa_dims_runtime, so this must run after the runtime
// copy has been populated (ggml_cl_init_fa_dims_table does that).
//
// NOTE(review): the split FA kernel computes DK_VEC / N_SPLIT and
// DV_VEC / N_SPLIT with integer division; if n_split feeds N_SPLIT, a value
// that does not divide dk/4 and dv/4 would silently drop elements. That is not
// validated here — confirm against the dispatch code before adding the check.
static void ggml_opencl_fa_apply_env_overrides() {
    const char * env = std::getenv("GGML_OPENCL_FA_TUNE");
    if (!env || !env[0]) {
        return;
    }

    const std::string spec = env;
    size_t pos = 0;
    while (pos < spec.size()) {
        const size_t comma = spec.find(',', pos);
        const std::string entry = spec.substr(pos, comma == std::string::npos ? std::string::npos : comma - pos);

        int dk = 0, dv = 0, bm = 0, bn = 0, nsplit = 0, thr = 0;
        // " %n" skips optional trailing whitespace and records how much of the
        // entry was consumed, so leftovers such as "…:64xyz" or a 7th field are
        // detected instead of being silently accepted.
        int consumed = -1;
        const int n_parsed = std::sscanf(entry.c_str(), "%d:%d:%d:%d:%d:%d %n",
                                         &dk, &dv, &bm, &bn, &nsplit, &thr, &consumed);
        const bool well_formed = n_parsed == 6 && consumed >= 0 && (size_t) consumed == entry.size();

        if (!well_formed) {
            GGML_LOG_WARN("ggml_opencl: FA tune override entry malformed: '%s'\n", entry.c_str());
        } else if (bm <= 0 || bn <= 0 || nsplit <= 0) {
            GGML_LOG_WARN("ggml_opencl: FA tune override DK=%d DV=%d ignored (bm, bn and n_split must be > 0): '%s'\n",
                          dk, dv, entry.c_str());
        } else {
            bool patched = false;
            for (size_t i = 0; i < g_opencl_fa_dims.count; ++i) {
                ggml_opencl_fa_dim & d = g_fa_dims_runtime[i];
                if (d.dk == dk && d.dv == dv) {
                    d.bm = bm; d.bn = bn; d.n_split = nsplit; d.nkv_split_threshold = thr;
                    GGML_LOG_INFO("ggml_opencl: FA tune override DK=%d DV=%d -> bm=%d bn=%d n_split=%d thr=%d\n",
                                  dk, dv, bm, bn, nsplit, thr);
                    patched = true;
                    break;
                }
            }
            if (!patched) {
                GGML_LOG_WARN("ggml_opencl: FA tune override DK=%d DV=%d ignored (no matching dim)\n", dk, dv);
            }
        }

        if (comma == std::string::npos) break;
        pos = comma + 1;
    }
}
// Copy the default table into the mutable runtime buffer and apply any
// GGML_OPENCL_FA_TUNE overrides. A per-generation table can be added here
// once it has been tuned on hardware.
static void ggml_cl_init_fa_dims_table() {
const size_t count = sizeof(g_fa_dims_adreno_default) / sizeof(g_fa_dims_adreno_default[0]);
for (size_t i = 0; i < count; ++i) {
g_fa_dims_runtime[i] = g_fa_dims_adreno_default[i];
}
g_opencl_fa_dims = { g_fa_dims_runtime, count };
ggml_opencl_fa_apply_env_overrides();
}
File diff suppressed because it is too large Load Diff
+152
View File
@@ -1582,6 +1582,158 @@ kernel void kernel_restore_block_q8_0(
}
}
// View-aware AoS q8_0 -> f32 dequant (f32/f32 FA path).
// Each work-item expands one q8_0 block (a half scale followed by QK8_0 signed
// byte quants) into QK8_0 floats. Global id 0 selects the block within a row,
// id 1 the row (dim 1) and id 2 the flattened (i3, i2) index.
// The source is addressed through byte strides src_nb1..3; the destination is
// written densely in (i3, i2, i1, block) order.
kernel void kernel_dequant_q8_0_f32_view_aos(
        global char * src,
        ulong src_offset,
        ulong src_nb1,
        ulong src_nb2,
        ulong src_nb3,
        int nblk0,
        int ne1,
        int ne2,
        int ne3,
        global float * dst
) {
    const int blk   = get_global_id(0);
    const int row   = get_global_id(1);
    const int plane = get_global_id(2);

    // Work-items outside the block/row range have nothing to do.
    if (blk >= nblk0 || row >= ne1) {
        return;
    }

    // Unflatten the third dispatch dimension into (i2, i3).
    const int i2 = plane % ne2;
    const int i3 = plane / ne2;
    if (i3 >= ne3) {
        return;
    }

    // Byte position of this block inside the (possibly strided) source view.
    const ulong src_byte = src_offset
                         + (ulong)i3  * src_nb3
                         + (ulong)i2  * src_nb2
                         + (ulong)row * src_nb1
                         + (ulong)blk * (2 + QK8_0);
    global char * blk_ptr = src + src_byte;

    const float scale = vload_half(0, (global half *)blk_ptr);
    global char * quants = blk_ptr + 2;

    // Dense destination: block index counted over all rows, QK8_0 values each.
    const ulong dst_blk = ((ulong)i3 * ne2 * ne1 + (ulong)i2 * ne1 + (ulong)row) * nblk0 + blk;
    global float * out = dst + dst_blk * QK8_0;

    for (int q = 0; q < QK8_0; ++q) {
        out[q] = scale * (float)quants[q];
    }
}
// View-aware AoS q8_0 -> f16 dequant. Rows tight, batch strides may be gapped.
// Each work-item expands one q8_0 block (a half scale followed by QK8_0 signed
// byte quants) into QK8_0 halves. Global id 0 selects the block within a row,
// id 1 the row (dim 1) and id 2 the flattened (i3, i2) index.
// The product is computed in float and narrowed to half on store.
kernel void kernel_dequant_q8_0_f16_view_aos(
        global char * src,
        ulong src_offset,
        ulong src_nb1,
        ulong src_nb2,
        ulong src_nb3,
        int nblk0,
        int ne1,
        int ne2,
        int ne3,
        global half * dst
) {
    const int blk   = get_global_id(0);
    const int row   = get_global_id(1);
    const int plane = get_global_id(2);

    // Work-items outside the block/row range have nothing to do.
    if (blk >= nblk0 || row >= ne1) {
        return;
    }

    // Unflatten the third dispatch dimension into (i2, i3).
    const int i2 = plane % ne2;
    const int i3 = plane / ne2;
    if (i3 >= ne3) {
        return;
    }

    // Byte position of this block inside the (possibly strided) source view.
    const ulong src_byte = src_offset
                         + (ulong)i3  * src_nb3
                         + (ulong)i2  * src_nb2
                         + (ulong)row * src_nb1
                         + (ulong)blk * (2 + QK8_0);
    global char * blk_ptr = src + src_byte;

    const float scale = vload_half(0, (global half *)blk_ptr);
    global char * quants = blk_ptr + 2;

    // Dense destination: block index counted over all rows, QK8_0 values each.
    const ulong dst_blk = ((ulong)i3 * ne2 * ne1 + (ulong)i2 * ne1 + (ulong)row) * nblk0 + blk;
    global half * out = dst + dst_blk * QK8_0;

    for (int q = 0; q < QK8_0; ++q) {
        out[q] = (half)(scale * (float)quants[q]);
    }
}
// View-aware AoS q4_0 -> f32 dequant (mirrors the q8_0 view variant).
// Each work-item expands one q4_0 block (a half scale followed by QK4_0/2
// packed bytes) into QK4_0 floats. Byte b supplies two values: its low nibble
// goes to output b, its high nibble to output b + QK4_0/2; both are re-centred
// by subtracting 8 before scaling.
// Global id 0 selects the block within a row, id 1 the row (dim 1) and id 2
// the flattened (i3, i2) index.
kernel void kernel_dequant_q4_0_f32_view_aos(
        global char * src,
        ulong src_offset,
        ulong src_nb1,
        ulong src_nb2,
        ulong src_nb3,
        int nblk0,
        int ne1,
        int ne2,
        int ne3,
        global float * dst
) {
    const int blk   = get_global_id(0);
    const int row   = get_global_id(1);
    const int plane = get_global_id(2);

    // Work-items outside the block/row range have nothing to do.
    if (blk >= nblk0 || row >= ne1) {
        return;
    }

    // Unflatten the third dispatch dimension into (i2, i3).
    const int i2 = plane % ne2;
    const int i3 = plane / ne2;
    if (i3 >= ne3) {
        return;
    }

    // Byte position of this block inside the (possibly strided) source view.
    const ulong src_byte = src_offset
                         + (ulong)i3  * src_nb3
                         + (ulong)i2  * src_nb2
                         + (ulong)row * src_nb1
                         + (ulong)blk * (2 + QK4_0/2);
    global char * blk_ptr = src + src_byte;

    const float scale = vload_half(0, (global half *)blk_ptr);
    global uchar * packed = (global uchar *)(blk_ptr + 2);

    // Dense destination: block index counted over all rows, QK4_0 values each.
    const ulong dst_blk = ((ulong)i3 * ne2 * ne1 + (ulong)i2 * ne1 + (ulong)row) * nblk0 + blk;
    global float * out = dst + dst_blk * QK4_0;

    for (int b = 0; b < QK4_0/2; ++b) {
        const uchar pair = packed[b];
        const int lo = (int)(pair & 0x0F) - 8;
        const int hi = (int)(pair >> 4) - 8;
        out[b]           = scale * (float)lo;
        out[b + QK4_0/2] = scale * (float)hi;
    }
}
// View-aware AoS q4_0 -> f16 dequant (mirrors the q8_0 view variant).
// Each work-item expands one q4_0 block (a half scale followed by QK4_0/2
// packed bytes) into QK4_0 halves. Byte b supplies two values: its low nibble
// goes to output b, its high nibble to output b + QK4_0/2; both are re-centred
// by subtracting 8, scaled in float and narrowed to half on store.
// Global id 0 selects the block within a row, id 1 the row (dim 1) and id 2
// the flattened (i3, i2) index.
kernel void kernel_dequant_q4_0_f16_view_aos(
        global char * src,
        ulong src_offset,
        ulong src_nb1,
        ulong src_nb2,
        ulong src_nb3,
        int nblk0,
        int ne1,
        int ne2,
        int ne3,
        global half * dst
) {
    const int blk   = get_global_id(0);
    const int row   = get_global_id(1);
    const int plane = get_global_id(2);

    // Work-items outside the block/row range have nothing to do.
    if (blk >= nblk0 || row >= ne1) {
        return;
    }

    // Unflatten the third dispatch dimension into (i2, i3).
    const int i2 = plane % ne2;
    const int i3 = plane / ne2;
    if (i3 >= ne3) {
        return;
    }

    // Byte position of this block inside the (possibly strided) source view.
    const ulong src_byte = src_offset
                         + (ulong)i3  * src_nb3
                         + (ulong)i2  * src_nb2
                         + (ulong)row * src_nb1
                         + (ulong)blk * (2 + QK4_0/2);
    global char * blk_ptr = src + src_byte;

    const float scale = vload_half(0, (global half *)blk_ptr);
    global uchar * packed = (global uchar *)(blk_ptr + 2);

    // Dense destination: block index counted over all rows, QK4_0 values each.
    const ulong dst_blk = ((ulong)i3 * ne2 * ne1 + (ulong)i2 * ne1 + (ulong)row) * nblk0 + blk;
    global half * out = dst + dst_blk * QK4_0;

    for (int b = 0; b < QK4_0/2; ++b) {
        const uchar pair = packed[b];
        const int lo = (int)(pair & 0x0F) - 8;
        const int hi = (int)(pair >> 4) - 8;
        out[b]           = (half)(scale * (float)lo);
        out[b + QK4_0/2] = (half)(scale * (float)hi);
    }
}
kernel void kernel_restore_block_q8_0_trans(
global uchar * src_q,
global half * src_d,
+75 -40
View File
@@ -4,14 +4,26 @@
#define ACC_TYPE4 float4
#define DATA_TYPE half
#define DATA_TYPE4 half4
#define CONVERT_ACC4(x) convert_float4(x)
#define CONVERT_DATA4(x) convert_half4(x)
#define CONVERT_ACC4(x) ((float4)((float)(x).s0, (float)(x).s1, (float)(x).s2, (float)(x).s3))
#define CONVERT_DATA4(x) ((half4)((half)(x).s0, (half)(x).s1, (half)(x).s2, (half)(x).s3))
#define DK_VEC (DK/4)
#define DV_VEC (DV/4)
#define WG_SIZE (BLOCK_M)
#define Q1_WG_SIZE 64
// The kernels are built with -cl-finite-math-only. On some older Adreno GPUs,
// infinite operand can cause undefined behavior and miscompilation for exp.
// Therefore, a large negative value is used instead.
#define FA_M_INIT (-3.0e38f)
// Skip full loop unrolling when DK >= 192 to stay within the Adreno compiler's host-memory budget.
#if DK >= 192
#define FA_UNROLL
#else
#define FA_UNROLL _Pragma("unroll")
#endif
inline float get_alibi_slope(
const float max_bias, const uint h, const uint n_head_log2, const float m0, const float m1
) {
@@ -81,18 +93,18 @@ __kernel void flash_attn_f16(
if (my_query_row < n_q) {
const ulong q_row_offset = batch_idx * q_nb3 + head_idx * q_nb2 + my_query_row * q_nb1;
const global DATA_TYPE4* q_ptr = (const global DATA_TYPE4*)(q_base + q_row_offset);
#pragma unroll
FA_UNROLL
for (int i = 0; i < DK_VEC; ++i) {
q_priv[i] = CONVERT_ACC4(q_ptr[i]);
}
}
ACC_TYPE4 o_acc[DV_VEC];
#pragma unroll
FA_UNROLL
for (int i = 0; i < DV_VEC; ++i) {
o_acc[i] = (ACC_TYPE4)(0.0f);
}
ACC_TYPE m_i = -INFINITY;
ACC_TYPE m_i = FA_M_INIT;
ACC_TYPE l_i = 0.0f;
float slope = get_alibi_slope(max_bias, head_idx, n_head_log2, m0, m1);
@@ -125,49 +137,72 @@ __kernel void flash_attn_f16(
continue;
}
for (int j = 0; j < BLOCK_N; j += 2) {
for (int j = 0; j < BLOCK_N; j += 4) {
const int k_row0 = k_start + j;
const int k_row1 = k_start + j + 1;
const int k_row2 = k_start + j + 2;
const int k_row3 = k_start + j + 3;
ACC_TYPE4 dot_acc0 = (ACC_TYPE4)(0.0f);
ACC_TYPE4 dot_acc1 = (ACC_TYPE4)(0.0f);
#pragma unroll
ACC_TYPE4 dot_acc2 = (ACC_TYPE4)(0.0f);
ACC_TYPE4 dot_acc3 = (ACC_TYPE4)(0.0f);
FA_UNROLL
for (int k = 0; k < DK_VEC; k++) {
dot_acc0 = mad(q_priv[k], CONVERT_ACC4(l_k[j][k]), dot_acc0);
dot_acc1 = mad(q_priv[k], CONVERT_ACC4(l_k[j+1][k]), dot_acc1);
const ACC_TYPE4 qk = q_priv[k];
dot_acc0 = mad(qk, CONVERT_ACC4(l_k[j][k]), dot_acc0);
dot_acc1 = mad(qk, CONVERT_ACC4(l_k[j+1][k]), dot_acc1);
dot_acc2 = mad(qk, CONVERT_ACC4(l_k[j+2][k]), dot_acc2);
dot_acc3 = mad(qk, CONVERT_ACC4(l_k[j+3][k]), dot_acc3);
}
ACC_TYPE score0 = (dot_acc0.s0 + dot_acc0.s1 + dot_acc0.s2 + dot_acc0.s3) * scale;
ACC_TYPE score1 = (dot_acc1.s0 + dot_acc1.s1 + dot_acc1.s2 + dot_acc1.s3) * scale;
ACC_TYPE s0 = (dot_acc0.s0 + dot_acc0.s1 + dot_acc0.s2 + dot_acc0.s3) * scale;
ACC_TYPE s1 = (dot_acc1.s0 + dot_acc1.s1 + dot_acc1.s2 + dot_acc1.s3) * scale;
ACC_TYPE s2 = (dot_acc2.s0 + dot_acc2.s1 + dot_acc2.s2 + dot_acc2.s3) * scale;
ACC_TYPE s3 = (dot_acc3.s0 + dot_acc3.s1 + dot_acc3.s2 + dot_acc3.s3) * scale;
if (is_causal) {
if (k_row0 > (n_kv - n_q + my_query_row)) score0 = -INFINITY;
if (k_row1 > (n_kv - n_q + my_query_row)) score1 = -INFINITY;
const int causal_limit = n_kv - n_q + my_query_row;
if (k_row0 > causal_limit) s0 = FA_M_INIT;
if (k_row1 > causal_limit) s1 = FA_M_INIT;
if (k_row2 > causal_limit) s2 = FA_M_INIT;
if (k_row3 > causal_limit) s3 = FA_M_INIT;
}
if (k_row0 >= n_kv) score0 = -INFINITY;
if (k_row1 >= n_kv) score1 = -INFINITY;
if (k_row0 >= n_kv) s0 = FA_M_INIT;
if (k_row1 >= n_kv) s1 = FA_M_INIT;
if (k_row2 >= n_kv) s2 = FA_M_INIT;
if (k_row3 >= n_kv) s3 = FA_M_INIT;
if (mask_base != NULL) {
const global DATA_TYPE* mask_ptr = (const global DATA_TYPE*)(mask_base + my_query_row * mask_nb1);
if (k_row0 < n_kv) score0 += slope * (ACC_TYPE)mask_ptr[k_row0];
if (k_row1 < n_kv) score1 += slope * (ACC_TYPE)mask_ptr[k_row1];
if (k_row0 < n_kv) s0 += slope * (ACC_TYPE)mask_ptr[k_row0];
if (k_row1 < n_kv) s1 += slope * (ACC_TYPE)mask_ptr[k_row1];
if (k_row2 < n_kv) s2 += slope * (ACC_TYPE)mask_ptr[k_row2];
if (k_row3 < n_kv) s3 += slope * (ACC_TYPE)mask_ptr[k_row3];
}
if (logit_softcap > 0.0f) {
score0 = logit_softcap * tanh(score0 / logit_softcap);
score1 = logit_softcap * tanh(score1 / logit_softcap);
s0 = logit_softcap * tanh(s0 / logit_softcap);
s1 = logit_softcap * tanh(s1 / logit_softcap);
s2 = logit_softcap * tanh(s2 / logit_softcap);
s3 = logit_softcap * tanh(s3 / logit_softcap);
}
const ACC_TYPE m_new = max(m_i, max(score0, score1));
const ACC_TYPE p0 = exp(score0 - m_new);
const ACC_TYPE p1 = exp(score1 - m_new);
const ACC_TYPE scale_prev = exp(m_i - m_new);
const ACC_TYPE m_new = max(m_i, max(max(s0, s1), max(s2, s3)));
const ACC_TYPE scale_prev = native_exp(m_i - m_new);
const ACC_TYPE p0 = native_exp(s0 - m_new);
const ACC_TYPE p1 = native_exp(s1 - m_new);
const ACC_TYPE p2 = native_exp(s2 - m_new);
const ACC_TYPE p3 = native_exp(s3 - m_new);
#pragma unroll
FA_UNROLL
for (int i = 0; i < DV_VEC; ++i) {
o_acc[i] = o_acc[i] * scale_prev + p0 * CONVERT_ACC4(l_v[j][i]) + p1 * CONVERT_ACC4(l_v[j+1][i]);
o_acc[i] = mad(p3, CONVERT_ACC4(l_v[j+3][i]),
mad(p2, CONVERT_ACC4(l_v[j+2][i]),
mad(p1, CONVERT_ACC4(l_v[j+1][i]),
mad(p0, CONVERT_ACC4(l_v[j][i]),
o_acc[i] * scale_prev))));
}
l_i = l_i * scale_prev + p0 + p1;
l_i = l_i * scale_prev + p0 + p1 + p2 + p3;
m_i = m_new;
}
}
@@ -179,7 +214,7 @@ __kernel void flash_attn_f16(
const ACC_TYPE m_final = max(m_i, m_sink);
const ACC_TYPE scale_o = exp(m_i - m_final);
#pragma unroll
FA_UNROLL
for (int i = 0; i < DV_VEC; ++i) {
o_acc[i] *= scale_o;
}
@@ -191,12 +226,12 @@ __kernel void flash_attn_f16(
global DATA_TYPE4 *o_row = (global DATA_TYPE4 *)(o_base + o_row_offset);
if (l_i > 0.0f) {
const ACC_TYPE l_inv = 1.0f / l_i;
#pragma unroll
FA_UNROLL
for (int i = 0; i < DV_VEC; ++i) {
o_row[i] = CONVERT_DATA4(o_acc[i] * l_inv);
}
} else {
#pragma unroll
FA_UNROLL
for (int i = 0; i < DV_VEC; ++i) {
o_row[i] = (DATA_TYPE4)(0.0f);
}
@@ -258,7 +293,7 @@ __kernel void flash_attn_f16_q1(
ACC_TYPE4 q_priv[DK_VEC];
const ulong q_row_offset = batch_idx * q_nb3 + head_idx * q_nb2;
const global DATA_TYPE4* q_ptr = (const global DATA_TYPE4*)(q_base + q_row_offset);
#pragma unroll
FA_UNROLL
for (int i = 0; i < DK_VEC; ++i) {
q_priv[i] = CONVERT_ACC4(q_ptr[i]);
}
@@ -270,12 +305,12 @@ __kernel void flash_attn_f16_q1(
sinks_ptr = (const global ACC_TYPE*)((const global char*)sinks_void + sinks_offset);
}
ACC_TYPE m_i = (sinks_ptr != NULL) ? sinks_ptr[head_idx] : -INFINITY;
ACC_TYPE m_i = (sinks_ptr != NULL) ? sinks_ptr[head_idx] : FA_M_INIT;
for (int k_idx = tid; k_idx < n_kv; k_idx += Q1_WG_SIZE) {
const ulong k_row_offset = batch_idx * k_nb3 + head_kv_idx * k_nb2 + k_idx * k_nb1;
const global DATA_TYPE4* k_ptr = (const global DATA_TYPE4*)(k_base + k_row_offset);
ACC_TYPE4 dot_acc = (ACC_TYPE4)(0.0f);
#pragma unroll
FA_UNROLL
for (int k = 0; k < DK_VEC; k++) {
dot_acc = mad(q_priv[k], CONVERT_ACC4(k_ptr[k]), dot_acc);
}
@@ -293,7 +328,7 @@ __kernel void flash_attn_f16_q1(
__local ACC_TYPE local_m[Q1_WG_SIZE];
local_m[tid] = m_i;
barrier(CLK_LOCAL_MEM_FENCE);
#pragma unroll
FA_UNROLL
for (int s = Q1_WG_SIZE / 2; s > 0; s >>= 1) {
if (tid < s) local_m[tid] = max(local_m[tid], local_m[tid + s]);
barrier(CLK_LOCAL_MEM_FENCE);
@@ -301,7 +336,7 @@ __kernel void flash_attn_f16_q1(
const ACC_TYPE m_final = local_m[0];
ACC_TYPE4 o_acc[DV_VEC];
#pragma unroll
FA_UNROLL
for (int i = 0; i < DV_VEC; ++i) o_acc[i] = (ACC_TYPE4)(0.0f);
ACC_TYPE l_i = 0.0f;
@@ -311,7 +346,7 @@ __kernel void flash_attn_f16_q1(
const global DATA_TYPE4* k_ptr = (const global DATA_TYPE4*)(k_base + k_row_offset);
const global DATA_TYPE4* v_ptr = (const global DATA_TYPE4*)(v_base + v_row_offset);
ACC_TYPE4 dot_acc = (ACC_TYPE4)(0.0f);
#pragma unroll
FA_UNROLL
for (int k = 0; k < DK_VEC; k++) {
dot_acc = mad(q_priv[k], CONVERT_ACC4(k_ptr[k]), dot_acc);
}
@@ -325,7 +360,7 @@ __kernel void flash_attn_f16_q1(
}
const ACC_TYPE p = exp(score - m_final);
l_i += p;
#pragma unroll
FA_UNROLL
for (int i = 0; i < DV_VEC; i++) {
o_acc[i] = mad(p, CONVERT_ACC4(v_ptr[i]), o_acc[i]);
}
@@ -335,7 +370,7 @@ __kernel void flash_attn_f16_q1(
__local ACC_TYPE4 local_o_comp[Q1_WG_SIZE];
local_l[tid] = l_i;
barrier(CLK_LOCAL_MEM_FENCE);
#pragma unroll
FA_UNROLL
for (int s = Q1_WG_SIZE / 2; s > 0; s >>= 1) {
if (tid < s) local_l[tid] += local_l[tid + s];
barrier(CLK_LOCAL_MEM_FENCE);
@@ -354,7 +389,7 @@ __kernel void flash_attn_f16_q1(
for (int i = 0; i < DV_VEC; i++) {
local_o_comp[tid] = o_acc[i];
barrier(CLK_LOCAL_MEM_FENCE);
#pragma unroll
FA_UNROLL
for (int s = Q1_WG_SIZE / 2; s > 0; s >>= 1) {
if (tid < s) local_o_comp[tid] += local_o_comp[tid + s];
barrier(CLK_LOCAL_MEM_FENCE);
@@ -364,7 +399,7 @@ __kernel void flash_attn_f16_q1(
}
}
} else if (tid == 0) {
#pragma unroll
FA_UNROLL
for (int i = 0; i < DV_VEC; ++i) o_row[i] = (DATA_TYPE4)(0.0f);
}
}
+73 -38
View File
@@ -13,6 +13,18 @@
#define WG_SIZE (BLOCK_M)
#define Q1_WG_SIZE 64
// The kernels are built with -cl-finite-math-only. On some older Adreno GPUs,
// infinite operand can cause undefined behavior and miscompilation for exp.
// Therefore, a large negative value is used instead.
#define FA_M_INIT (-3.0e38f)
// Skip full loop unrolling when DK >= 192 to stay within the Adreno compiler's host-memory budget.
#if DK >= 192
#define FA_UNROLL
#else
#define FA_UNROLL _Pragma("unroll")
#endif
inline float get_alibi_slope(
const float max_bias, const uint h, const uint n_head_log2, const float m0, const float m1
) {
@@ -82,18 +94,18 @@ __kernel void flash_attn_f32(
if (my_query_row < n_q) {
const ulong q_row_offset = batch_idx * q_nb3 + head_idx * q_nb2 + my_query_row * q_nb1;
const global DATA_TYPE4* q_ptr = (const global DATA_TYPE4*)(q_base + q_row_offset);
#pragma unroll
FA_UNROLL
for (int i = 0; i < DK_VEC; ++i) {
q_priv[i] = CONVERT_ACC4(q_ptr[i]);
}
}
ACC_TYPE4 o_acc[DV_VEC];
#pragma unroll
FA_UNROLL
for (int i = 0; i < DV_VEC; ++i) {
o_acc[i] = (ACC_TYPE4)(0.0f);
}
ACC_TYPE m_i = -INFINITY;
ACC_TYPE m_i = FA_M_INIT;
ACC_TYPE l_i = 0.0f;
float slope = get_alibi_slope(max_bias, head_idx, n_head_log2, m0, m1);
@@ -126,49 +138,72 @@ __kernel void flash_attn_f32(
continue;
}
for (int j = 0; j < BLOCK_N; j += 2) {
for (int j = 0; j < BLOCK_N; j += 4) {
const int k_row0 = k_start + j;
const int k_row1 = k_start + j + 1;
const int k_row2 = k_start + j + 2;
const int k_row3 = k_start + j + 3;
ACC_TYPE4 dot_acc0 = (ACC_TYPE4)(0.0f);
ACC_TYPE4 dot_acc1 = (ACC_TYPE4)(0.0f);
#pragma unroll
ACC_TYPE4 dot_acc2 = (ACC_TYPE4)(0.0f);
ACC_TYPE4 dot_acc3 = (ACC_TYPE4)(0.0f);
FA_UNROLL
for (int k = 0; k < DK_VEC; k++) {
dot_acc0 = mad(q_priv[k], CONVERT_ACC4(l_k[j][k]), dot_acc0);
dot_acc1 = mad(q_priv[k], CONVERT_ACC4(l_k[j+1][k]), dot_acc1);
const ACC_TYPE4 qk = q_priv[k];
dot_acc0 = mad(qk, CONVERT_ACC4(l_k[j][k]), dot_acc0);
dot_acc1 = mad(qk, CONVERT_ACC4(l_k[j+1][k]), dot_acc1);
dot_acc2 = mad(qk, CONVERT_ACC4(l_k[j+2][k]), dot_acc2);
dot_acc3 = mad(qk, CONVERT_ACC4(l_k[j+3][k]), dot_acc3);
}
ACC_TYPE score0 = (dot_acc0.s0 + dot_acc0.s1 + dot_acc0.s2 + dot_acc0.s3) * scale;
ACC_TYPE score1 = (dot_acc1.s0 + dot_acc1.s1 + dot_acc1.s2 + dot_acc1.s3) * scale;
ACC_TYPE s0 = (dot_acc0.s0 + dot_acc0.s1 + dot_acc0.s2 + dot_acc0.s3) * scale;
ACC_TYPE s1 = (dot_acc1.s0 + dot_acc1.s1 + dot_acc1.s2 + dot_acc1.s3) * scale;
ACC_TYPE s2 = (dot_acc2.s0 + dot_acc2.s1 + dot_acc2.s2 + dot_acc2.s3) * scale;
ACC_TYPE s3 = (dot_acc3.s0 + dot_acc3.s1 + dot_acc3.s2 + dot_acc3.s3) * scale;
if (is_causal) {
if (k_row0 > (n_kv - n_q + my_query_row)) score0 = -INFINITY;
if (k_row1 > (n_kv - n_q + my_query_row)) score1 = -INFINITY;
const int causal_limit = n_kv - n_q + my_query_row;
if (k_row0 > causal_limit) s0 = FA_M_INIT;
if (k_row1 > causal_limit) s1 = FA_M_INIT;
if (k_row2 > causal_limit) s2 = FA_M_INIT;
if (k_row3 > causal_limit) s3 = FA_M_INIT;
}
if (k_row0 >= n_kv) score0 = -INFINITY;
if (k_row1 >= n_kv) score1 = -INFINITY;
if (k_row0 >= n_kv) s0 = FA_M_INIT;
if (k_row1 >= n_kv) s1 = FA_M_INIT;
if (k_row2 >= n_kv) s2 = FA_M_INIT;
if (k_row3 >= n_kv) s3 = FA_M_INIT;
if (mask_base != NULL) {
const global MASK_DATA_TYPE* mask_ptr = (const global MASK_DATA_TYPE*)(mask_base + my_query_row * mask_nb1);
if (k_row0 < n_kv) score0 += slope * (ACC_TYPE)mask_ptr[k_row0];
if (k_row1 < n_kv) score1 += slope * (ACC_TYPE)mask_ptr[k_row1];
if (k_row0 < n_kv) s0 += slope * (ACC_TYPE)mask_ptr[k_row0];
if (k_row1 < n_kv) s1 += slope * (ACC_TYPE)mask_ptr[k_row1];
if (k_row2 < n_kv) s2 += slope * (ACC_TYPE)mask_ptr[k_row2];
if (k_row3 < n_kv) s3 += slope * (ACC_TYPE)mask_ptr[k_row3];
}
if (logit_softcap > 0.0f) {
score0 = logit_softcap * tanh(score0 / logit_softcap);
score1 = logit_softcap * tanh(score1 / logit_softcap);
s0 = logit_softcap * tanh(s0 / logit_softcap);
s1 = logit_softcap * tanh(s1 / logit_softcap);
s2 = logit_softcap * tanh(s2 / logit_softcap);
s3 = logit_softcap * tanh(s3 / logit_softcap);
}
const ACC_TYPE m_new = max(m_i, max(score0, score1));
const ACC_TYPE p0 = exp(score0 - m_new);
const ACC_TYPE p1 = exp(score1 - m_new);
const ACC_TYPE scale_prev = exp(m_i - m_new);
const ACC_TYPE m_new = max(m_i, max(max(s0, s1), max(s2, s3)));
const ACC_TYPE scale_prev = native_exp(m_i - m_new);
const ACC_TYPE p0 = native_exp(s0 - m_new);
const ACC_TYPE p1 = native_exp(s1 - m_new);
const ACC_TYPE p2 = native_exp(s2 - m_new);
const ACC_TYPE p3 = native_exp(s3 - m_new);
#pragma unroll
FA_UNROLL
for (int i = 0; i < DV_VEC; ++i) {
o_acc[i] = o_acc[i] * scale_prev + p0 * CONVERT_ACC4(l_v[j][i]) + p1 * CONVERT_ACC4(l_v[j+1][i]);
o_acc[i] = mad(p3, CONVERT_ACC4(l_v[j+3][i]),
mad(p2, CONVERT_ACC4(l_v[j+2][i]),
mad(p1, CONVERT_ACC4(l_v[j+1][i]),
mad(p0, CONVERT_ACC4(l_v[j][i]),
o_acc[i] * scale_prev))));
}
l_i = l_i * scale_prev + p0 + p1;
l_i = l_i * scale_prev + p0 + p1 + p2 + p3;
m_i = m_new;
}
}
@@ -180,7 +215,7 @@ __kernel void flash_attn_f32(
const ACC_TYPE m_final = max(m_i, m_sink);
const ACC_TYPE scale_o = exp(m_i - m_final);
#pragma unroll
FA_UNROLL
for (int i = 0; i < DV_VEC; ++i) {
o_acc[i] *= scale_o;
}
@@ -192,12 +227,12 @@ __kernel void flash_attn_f32(
global DATA_TYPE4 *o_row = (global DATA_TYPE4 *)(o_base + o_row_offset);
if (l_i > 0.0f) {
const ACC_TYPE l_inv = 1.0f / l_i;
#pragma unroll
FA_UNROLL
for (int i = 0; i < DV_VEC; ++i) {
o_row[i] = CONVERT_DATA4(o_acc[i] * l_inv);
}
} else {
#pragma unroll
FA_UNROLL
for (int i = 0; i < DV_VEC; ++i) {
o_row[i] = (DATA_TYPE4)(0.0f);
}
@@ -259,7 +294,7 @@ __kernel void flash_attn_f32_q1(
ACC_TYPE4 q_priv[DK_VEC];
const ulong q_row_offset = batch_idx * q_nb3 + head_idx * q_nb2;
const global DATA_TYPE4* q_ptr = (const global DATA_TYPE4*)(q_base + q_row_offset);
#pragma unroll
FA_UNROLL
for (int i = 0; i < DK_VEC; ++i) {
q_priv[i] = CONVERT_ACC4(q_ptr[i]);
}
@@ -271,12 +306,12 @@ __kernel void flash_attn_f32_q1(
sinks_ptr = (const global ACC_TYPE*)((const global char*)sinks_void + sinks_offset);
}
ACC_TYPE m_i = (sinks_ptr != NULL) ? sinks_ptr[head_idx] : -INFINITY;
ACC_TYPE m_i = (sinks_ptr != NULL) ? sinks_ptr[head_idx] : FA_M_INIT;
for (int k_idx = tid; k_idx < n_kv; k_idx += Q1_WG_SIZE) {
const ulong k_row_offset = batch_idx * k_nb3 + head_kv_idx * k_nb2 + k_idx * k_nb1;
const global DATA_TYPE4* k_ptr = (const global DATA_TYPE4*)(k_base + k_row_offset);
ACC_TYPE4 dot_acc = (ACC_TYPE4)(0.0f);
#pragma unroll
FA_UNROLL
for (int k = 0; k < DK_VEC; k++) {
dot_acc = mad(q_priv[k], CONVERT_ACC4(k_ptr[k]), dot_acc);
}
@@ -294,7 +329,7 @@ __kernel void flash_attn_f32_q1(
__local ACC_TYPE local_m[Q1_WG_SIZE];
local_m[tid] = m_i;
barrier(CLK_LOCAL_MEM_FENCE);
#pragma unroll
FA_UNROLL
for (int s = Q1_WG_SIZE / 2; s > 0; s >>= 1) {
if (tid < s) local_m[tid] = max(local_m[tid], local_m[tid + s]);
barrier(CLK_LOCAL_MEM_FENCE);
@@ -302,7 +337,7 @@ __kernel void flash_attn_f32_q1(
const ACC_TYPE m_final = local_m[0];
ACC_TYPE4 o_acc[DV_VEC];
#pragma unroll
FA_UNROLL
for (int i = 0; i < DV_VEC; ++i) o_acc[i] = (ACC_TYPE4)(0.0f);
ACC_TYPE l_i = 0.0f;
@@ -312,7 +347,7 @@ __kernel void flash_attn_f32_q1(
const global DATA_TYPE4* k_ptr = (const global DATA_TYPE4*)(k_base + k_row_offset);
const global DATA_TYPE4* v_ptr = (const global DATA_TYPE4*)(v_base + v_row_offset);
ACC_TYPE4 dot_acc = (ACC_TYPE4)(0.0f);
#pragma unroll
FA_UNROLL
for (int k = 0; k < DK_VEC; k++) {
dot_acc = mad(q_priv[k], CONVERT_ACC4(k_ptr[k]), dot_acc);
}
@@ -326,7 +361,7 @@ __kernel void flash_attn_f32_q1(
}
const ACC_TYPE p = exp(score - m_final);
l_i += p;
#pragma unroll
FA_UNROLL
for (int i = 0; i < DV_VEC; i++) {
o_acc[i] = mad(p, CONVERT_ACC4(v_ptr[i]), o_acc[i]);
}
@@ -336,7 +371,7 @@ __kernel void flash_attn_f32_q1(
__local ACC_TYPE4 local_o_comp[Q1_WG_SIZE];
local_l[tid] = l_i;
barrier(CLK_LOCAL_MEM_FENCE);
#pragma unroll
FA_UNROLL
for (int s = Q1_WG_SIZE / 2; s > 0; s >>= 1) {
if (tid < s) local_l[tid] += local_l[tid + s];
barrier(CLK_LOCAL_MEM_FENCE);
@@ -355,7 +390,7 @@ __kernel void flash_attn_f32_q1(
for (int i = 0; i < DV_VEC; i++) {
local_o_comp[tid] = o_acc[i];
barrier(CLK_LOCAL_MEM_FENCE);
#pragma unroll
FA_UNROLL
for (int s = Q1_WG_SIZE / 2; s > 0; s >>= 1) {
if (tid < s) local_o_comp[tid] += local_o_comp[tid + s];
barrier(CLK_LOCAL_MEM_FENCE);
@@ -365,7 +400,7 @@ __kernel void flash_attn_f32_q1(
}
}
} else if (tid == 0) {
#pragma unroll
FA_UNROLL
for (int i = 0; i < DV_VEC; ++i) o_row[i] = (DATA_TYPE4)(0.0f);
}
}
@@ -1,5 +1,13 @@
#pragma OPENCL EXTENSION cl_khr_fp16 : enable
#ifdef cl_khr_subgroup_shuffle
#pragma OPENCL EXTENSION cl_khr_subgroup_shuffle : enable
#define HAS_SUBGROUP_SHUFFLE 1
#elif defined(cl_qcom_subgroup_shuffle)
#pragma OPENCL EXTENSION cl_qcom_subgroup_shuffle : enable
#define HAS_SUBGROUP_SHUFFLE 1
#endif
#define ACC_TYPE float
#define ACC_TYPE4 float4
#define Q_DATA_TYPE4 float4
@@ -12,9 +20,34 @@
#define DK_VEC (DK/4)
#define DV_VEC (DV/4)
#define WG_SIZE (BLOCK_M)
#define Q1_WG_SIZE 64
// The kernels are built with -cl-finite-math-only. On some older Adreno GPUs,
// infinite operand can cause undefined behavior and miscompilation for exp.
// Therefore, a large negative value is used instead.
#define FA_M_INIT (-3.0e38f)
// Skip full loop unrolling when DK >= 192 to stay within the Adreno compiler's host-memory budget.
#if DK >= 192
#define FA_UNROLL
#else
#define FA_UNROLL _Pragma("unroll")
#endif
// N_SPLIT>1 splits DK/DV across threads to cut per-thread register use.
#ifndef N_SPLIT
#define N_SPLIT 1
#endif
#define SPLIT_DK_VEC (DK_VEC / N_SPLIT)
#define SPLIT_DV_VEC (DV_VEC / N_SPLIT)
#if N_SPLIT > 1
#define WG_SIZE (BLOCK_M * N_SPLIT)
#else
#define WG_SIZE (BLOCK_M)
#endif
inline float get_alibi_slope(
const float max_bias, const uint h, const uint n_head_log2, const float m0, const float m1
) {
@@ -54,19 +87,38 @@ __kernel void flash_attn_f32_f16(
const int mask_ne2,
const int mask_ne3,
const global void* sinks_void,
const ulong sinks_offset
const ulong sinks_offset,
const global void * k_pad_void,
const global void * v_pad_void,
const global void * mask_pad_void,
const global char * blk,
const int n_kv_blocks,
const ulong mask_pad_nb1,
const ulong mask_pad_nb2,
const ulong mask_pad_nb3
) {
const int tid = get_local_id(0);
const int block_q_idx = get_group_id(0);
const int head_batch_idx = get_global_id(1);
const int my_query_row = block_q_idx * BLOCK_M + tid;
#if N_SPLIT > 1
const int q_lane = tid / N_SPLIT;
const int split_idx = tid % N_SPLIT;
#else
const int q_lane = tid;
const int split_idx = 0;
#endif
const int my_query_row = block_q_idx * BLOCK_M + q_lane;
const int query_valid = my_query_row < n_q;
const int batch_idx = head_batch_idx / n_head;
const int head_idx = head_batch_idx % n_head;
const int gqa_ratio = n_head / n_head_kv;
const int head_kv_idx = head_idx / gqa_ratio;
const int mask_head_idx = mask_void != NULL ? head_idx % mask_ne2 : 0;
const int mask_batch_idx = mask_void != NULL ? batch_idx % mask_ne3 : 0;
const global char* q_base = (const global char*)q_void + q_offset;
const global char* k_base = (const global char*)k_void + k_offset;
@@ -75,27 +127,41 @@ __kernel void flash_attn_f32_f16(
const global char* mask_base = NULL;
if (mask_void != NULL) {
const int mask_head_idx = head_idx % mask_ne2;
const int mask_batch_idx = batch_idx % mask_ne3;
mask_base = (const global char*)mask_void + mask_offset + mask_batch_idx * mask_nb3 + mask_head_idx * mask_nb2;
}
const global char* mask_pad_base = NULL;
if (mask_pad_void != NULL) {
mask_pad_base = (const global char*)mask_pad_void + mask_batch_idx * mask_pad_nb3 + mask_head_idx * mask_pad_nb2;
}
const global char* blk_base = NULL;
if (blk != NULL) {
const int n_q_blocks = (n_q + BLOCK_M - 1) / BLOCK_M;
blk_base = blk + (((mask_batch_idx * mask_ne2) + mask_head_idx) * n_q_blocks + block_q_idx) * n_kv_blocks;
}
ACC_TYPE4 q_priv[DK_VEC];
if (my_query_row < n_q) {
ACC_TYPE4 q_priv[SPLIT_DK_VEC];
const int dk_off = split_idx * SPLIT_DK_VEC;
if (query_valid) {
const ulong q_row_offset = batch_idx * q_nb3 + head_idx * q_nb2 + my_query_row * q_nb1;
const global Q_DATA_TYPE4* q_ptr = (const global Q_DATA_TYPE4*)(q_base + q_row_offset);
#pragma unroll
for (int i = 0; i < DK_VEC; ++i) {
q_priv[i] = CONVERT_Q_ACC4(q_ptr[i]);
FA_UNROLL
for (int i = 0; i < SPLIT_DK_VEC; ++i) {
q_priv[i] = CONVERT_Q_ACC4(q_ptr[dk_off + i]);
}
} else {
FA_UNROLL
for (int i = 0; i < SPLIT_DK_VEC; ++i) {
q_priv[i] = (ACC_TYPE4)(0.0f);
}
}
ACC_TYPE4 o_acc[DV_VEC];
#pragma unroll
for (int i = 0; i < DV_VEC; ++i) {
ACC_TYPE4 o_acc[SPLIT_DV_VEC];
FA_UNROLL
for (int i = 0; i < SPLIT_DV_VEC; ++i) {
o_acc[i] = (ACC_TYPE4)(0.0f);
}
ACC_TYPE m_i = -INFINITY;
ACC_TYPE m_i = FA_M_INIT;
ACC_TYPE l_i = 0.0f;
float slope = get_alibi_slope(max_bias, head_idx, n_head_log2, m0, m1);
@@ -103,86 +169,369 @@ __kernel void flash_attn_f32_f16(
__local KV_DATA_TYPE4 l_k[BLOCK_N][DK_VEC];
__local KV_DATA_TYPE4 l_v[BLOCK_N][DV_VEC];
#if N_SPLIT > 1 && !defined(HAS_SUBGROUP_SHUFFLE)
__local ACC_TYPE local_partial[BLOCK_N][WG_SIZE];
__local ACC_TYPE local_p[BLOCK_M][BLOCK_N];
__local ACC_TYPE local_softmax_scale[BLOCK_M];
__local ACC_TYPE local_l_inv[BLOCK_M];
#endif
for (int k_start = 0; k_start < n_kv; k_start += BLOCK_N) {
char blk_cur = 1;
if (blk_base != NULL) {
blk_cur = blk_base[k_start / BLOCK_N];
if (blk_cur == 0) continue;
}
const int use_kv_pad = k_pad_void != NULL && k_start + BLOCK_N > n_kv;
const int k_tile_start = use_kv_pad ? 0 : k_start;
const ulong k_tile_nb2 = use_kv_pad ? (ulong) BLOCK_N * k_nb1 : k_nb2;
const ulong k_tile_nb3 = use_kv_pad ? (ulong) n_head_kv * k_tile_nb2 : k_nb3;
const ulong v_tile_nb2 = use_kv_pad ? (ulong) BLOCK_N * v_nb1 : v_nb2;
const ulong v_tile_nb3 = use_kv_pad ? (ulong) n_head_kv * v_tile_nb2 : v_nb3;
const global char* k_tile_base = use_kv_pad ? (const global char*) k_pad_void : k_base;
const global char* v_tile_base = use_kv_pad ? (const global char*) v_pad_void : v_base;
for (int i = tid; i < BLOCK_N * DK_VEC; i += WG_SIZE) {
const int row = i / DK_VEC;
const int col = i % DK_VEC;
const int k_row_idx = k_start + row;
if (k_row_idx < n_kv) {
const ulong k_row_offset = batch_idx * k_nb3 + head_kv_idx * k_nb2 + k_row_idx * k_nb1;
l_k[row][col] = ((__global KV_DATA_TYPE4*)(k_base + k_row_offset))[col];
const int k_row_idx = k_tile_start + row;
if (use_kv_pad || k_row_idx < n_kv) {
const ulong k_row_offset = batch_idx * k_tile_nb3 + head_kv_idx * k_tile_nb2 + k_row_idx * k_nb1;
l_k[row][col] = ((__global KV_DATA_TYPE4*)(k_tile_base + k_row_offset))[col];
} else {
l_k[row][col] = (KV_DATA_TYPE4)(0.0h);
}
}
for (int i = tid; i < BLOCK_N * DV_VEC; i += WG_SIZE) {
const int row = i / DV_VEC;
const int col = i % DV_VEC;
const int v_row_idx = k_start + row;
if (v_row_idx < n_kv) {
const ulong v_row_offset = batch_idx * v_nb3 + head_kv_idx * v_nb2 + v_row_idx * v_nb1;
l_v[row][col] = ((__global KV_DATA_TYPE4*)(v_base + v_row_offset))[col];
const int v_row_idx = k_tile_start + row;
if (use_kv_pad || v_row_idx < n_kv) {
const ulong v_row_offset = batch_idx * v_tile_nb3 + head_kv_idx * v_tile_nb2 + v_row_idx * v_nb1;
l_v[row][col] = ((__global KV_DATA_TYPE4*)(v_tile_base + v_row_offset))[col];
} else {
l_v[row][col] = (KV_DATA_TYPE4)(0.0h);
}
}
barrier(CLK_LOCAL_MEM_FENCE);
if (my_query_row >= n_q) {
continue;
#if N_SPLIT > 1 && defined(HAS_SUBGROUP_SHUFFLE)
{
const int dv_off = split_idx * SPLIT_DV_VEC;
for (int j = 0; j < BLOCK_N; j += 2) {
const int k_row0 = k_start + j;
const int k_row1 = k_start + j + 1;
ACC_TYPE partial0 = 0.0f;
ACC_TYPE partial1 = 0.0f;
FA_UNROLL
for (int k = 0; k < SPLIT_DK_VEC; k++) {
const ACC_TYPE4 qk = q_priv[k];
ACC_TYPE4 dot0 = qk * CONVERT_KV_ACC4(l_k[j ][dk_off + k]);
ACC_TYPE4 dot1 = qk * CONVERT_KV_ACC4(l_k[j+1][dk_off + k]);
partial0 += dot0.s0 + dot0.s1 + dot0.s2 + dot0.s3;
partial1 += dot1.s0 + dot1.s1 + dot1.s2 + dot1.s3;
}
FA_UNROLL
for (int step = 1; step < N_SPLIT; step <<= 1) {
partial0 += sub_group_shuffle_xor(partial0, step);
partial1 += sub_group_shuffle_xor(partial1, step);
}
ACC_TYPE score0 = partial0 * scale;
ACC_TYPE score1 = partial1 * scale;
if (!query_valid) { score0 = FA_M_INIT; score1 = FA_M_INIT; }
if (is_causal) {
if (k_row0 > (n_kv - n_q + my_query_row)) score0 = FA_M_INIT;
if (k_row1 > (n_kv - n_q + my_query_row)) score1 = FA_M_INIT;
}
if (k_row0 >= n_kv) score0 = FA_M_INIT;
if (k_row1 >= n_kv) score1 = FA_M_INIT;
if (query_valid && mask_base != NULL && blk_cur != 2) {
if (use_kv_pad && mask_pad_base != NULL) {
const global MASK_DATA_TYPE* mask_ptr =
(const global MASK_DATA_TYPE*)(mask_pad_base + my_query_row * mask_pad_nb1);
score0 += slope * (ACC_TYPE)mask_ptr[j];
score1 += slope * (ACC_TYPE)mask_ptr[j + 1];
} else {
const global MASK_DATA_TYPE* mask_ptr =
(const global MASK_DATA_TYPE*)(mask_base + my_query_row * mask_nb1);
if (k_row0 < n_kv) score0 += slope * (ACC_TYPE)mask_ptr[k_row0];
if (k_row1 < n_kv) score1 += slope * (ACC_TYPE)mask_ptr[k_row1];
}
}
if (logit_softcap > 0.0f) {
score0 = logit_softcap * tanh(score0 / logit_softcap);
score1 = logit_softcap * tanh(score1 / logit_softcap);
}
const ACC_TYPE m_new = max(m_i, max(score0, score1));
// Whole tile masked (m_new == FA_M_INIT): force the exp() args
// far negative so the tile contributes 0, not exp(0)=1.
const ACC_TYPE m_exp = (m_new == FA_M_INIT) ? 0.0f : m_new;
const ACC_TYPE sp = native_exp(m_i - m_exp);
const ACC_TYPE p0 = native_exp(score0 - m_exp);
const ACC_TYPE p1 = native_exp(score1 - m_exp);
FA_UNROLL
for (int i = 0; i < SPLIT_DV_VEC; ++i) {
o_acc[i] = o_acc[i] * sp
+ p0 * CONVERT_KV_ACC4(l_v[j ][dv_off + i])
+ p1 * CONVERT_KV_ACC4(l_v[j+1][dv_off + i]);
}
l_i = l_i * sp + p0 + p1;
m_i = m_new;
}
}
for (int j = 0; j < BLOCK_N; j += 2) {
const int k_row0 = k_start + j;
const int k_row1 = k_start + j + 1;
ACC_TYPE4 dot_acc0 = (ACC_TYPE4)(0.0f);
ACC_TYPE4 dot_acc1 = (ACC_TYPE4)(0.0f);
#pragma unroll
for (int k = 0; k < DK_VEC; k++) {
dot_acc0 = mad(q_priv[k], CONVERT_KV_ACC4(l_k[j][k]), dot_acc0);
dot_acc1 = mad(q_priv[k], CONVERT_KV_ACC4(l_k[j+1][k]), dot_acc1);
#elif N_SPLIT > 1
// N_SPLIT>1 fallback (no shuffle): 3-phase local-memory reduction.
// Phase 1 partial dots for all BLOCK_N tokens.
for (int j = 0; j < BLOCK_N; ++j) {
ACC_TYPE4 dot_acc = (ACC_TYPE4)(0.0f);
FA_UNROLL
for (int k = 0; k < SPLIT_DK_VEC; k++) {
dot_acc = mad(q_priv[k], CONVERT_KV_ACC4(l_k[j][dk_off + k]), dot_acc);
}
ACC_TYPE score0 = (dot_acc0.s0 + dot_acc0.s1 + dot_acc0.s2 + dot_acc0.s3) * scale;
ACC_TYPE score1 = (dot_acc1.s0 + dot_acc1.s1 + dot_acc1.s2 + dot_acc1.s3) * scale;
if (is_causal) {
if (k_row0 > (n_kv - n_q + my_query_row)) score0 = -INFINITY;
if (k_row1 > (n_kv - n_q + my_query_row)) score1 = -INFINITY;
}
if (k_row0 >= n_kv) score0 = -INFINITY;
if (k_row1 >= n_kv) score1 = -INFINITY;
if (mask_base != NULL) {
const global MASK_DATA_TYPE* mask_ptr = (const global MASK_DATA_TYPE*)(mask_base + my_query_row * mask_nb1);
if (k_row0 < n_kv) score0 += slope * (ACC_TYPE)mask_ptr[k_row0];
if (k_row1 < n_kv) score1 += slope * (ACC_TYPE)mask_ptr[k_row1];
}
if (logit_softcap > 0.0f) {
score0 = logit_softcap * tanh(score0 / logit_softcap);
score1 = logit_softcap * tanh(score1 / logit_softcap);
}
const ACC_TYPE m_new = max(m_i, max(score0, score1));
const ACC_TYPE p0 = exp(score0 - m_new);
const ACC_TYPE p1 = exp(score1 - m_new);
const ACC_TYPE scale_prev = exp(m_i - m_new);
#pragma unroll
for (int i = 0; i < DV_VEC; ++i) {
o_acc[i] = o_acc[i] * scale_prev + p0 * CONVERT_KV_ACC4(l_v[j][i]) + p1 * CONVERT_KV_ACC4(l_v[j+1][i]);
}
l_i = l_i * scale_prev + p0 + p1;
m_i = m_new;
local_partial[j][tid] =
dot_acc.s0 + dot_acc.s1 + dot_acc.s2 + dot_acc.s3;
}
barrier(CLK_LOCAL_MEM_FENCE); // 1 barrier: partial dots visible
// Phase 2 split_idx==0 reduces partial sums and computes block softmax.
if (split_idx == 0) {
if (query_valid) {
ACC_TYPE m_new = m_i;
for (int j = 0; j < BLOCK_N; ++j) {
const int k_row = k_start + j;
ACC_TYPE score = 0.0f;
FA_UNROLL
for (int s = 0; s < N_SPLIT; s++) {
score += local_partial[j][q_lane * N_SPLIT + s];
}
score *= scale;
if (is_causal && k_row > (n_kv - n_q + my_query_row)) score = FA_M_INIT;
if (k_row >= n_kv) score = FA_M_INIT;
if (mask_base != NULL && blk_cur != 2) {
if (use_kv_pad && mask_pad_base != NULL) {
const global MASK_DATA_TYPE* mask_ptr =
(const global MASK_DATA_TYPE*)(mask_pad_base + my_query_row * mask_pad_nb1);
score += slope * (ACC_TYPE)mask_ptr[j];
} else {
const global MASK_DATA_TYPE* mask_ptr =
(const global MASK_DATA_TYPE*)(mask_base + my_query_row * mask_nb1);
if (k_row < n_kv) score += slope * (ACC_TYPE)mask_ptr[k_row];
}
}
if (logit_softcap > 0.0f) {
score = logit_softcap * tanh(score / logit_softcap);
}
m_new = max(m_new, score);
local_p[q_lane][j] = score;
}
const ACC_TYPE m_exp = (m_new == FA_M_INIT) ? 0.0f : m_new;
const ACC_TYPE sp = native_exp(m_i - m_exp);
ACC_TYPE l_new = l_i * sp;
for (int j = 0; j < BLOCK_N; ++j) {
const ACC_TYPE p = native_exp(local_p[q_lane][j] - m_exp);
local_p[q_lane][j] = p;
l_new += p;
}
local_softmax_scale[q_lane] = sp;
l_i = l_new;
m_i = m_new;
} else {
local_softmax_scale[q_lane] = 1.0f;
for (int j = 0; j < BLOCK_N; ++j) local_p[q_lane][j] = 0.0f;
}
}
barrier(CLK_LOCAL_MEM_FENCE);
// Phase 3 V accumulate using broadcast probabilities.
{
const ACC_TYPE sp_block = local_softmax_scale[q_lane];
const int dv_off = split_idx * SPLIT_DV_VEC;
FA_UNROLL
for (int i = 0; i < SPLIT_DV_VEC; ++i) {
o_acc[i] *= sp_block;
}
for (int j = 0; j < BLOCK_N; ++j) {
const ACC_TYPE p = local_p[q_lane][j];
FA_UNROLL
for (int i = 0; i < SPLIT_DV_VEC; ++i) {
o_acc[i] = mad(p, CONVERT_KV_ACC4(l_v[j][dv_off + i]), o_acc[i]);
}
}
}
#else
// N_SPLIT==1: j+=4 unroll. Requires BLOCK_N % 4 == 0.
if (query_valid) {
for (int j = 0; j < BLOCK_N; j += 4) {
const int k_row0 = k_start + j;
const int k_row1 = k_start + j + 1;
const int k_row2 = k_start + j + 2;
const int k_row3 = k_start + j + 3;
ACC_TYPE4 dot_acc0 = (ACC_TYPE4)(0.0f);
ACC_TYPE4 dot_acc1 = (ACC_TYPE4)(0.0f);
ACC_TYPE4 dot_acc2 = (ACC_TYPE4)(0.0f);
ACC_TYPE4 dot_acc3 = (ACC_TYPE4)(0.0f);
FA_UNROLL
for (int k = 0; k < DK_VEC; k++) {
const ACC_TYPE4 qk = q_priv[k];
dot_acc0 = mad(qk, CONVERT_KV_ACC4(l_k[j][k]), dot_acc0);
dot_acc1 = mad(qk, CONVERT_KV_ACC4(l_k[j+1][k]), dot_acc1);
dot_acc2 = mad(qk, CONVERT_KV_ACC4(l_k[j+2][k]), dot_acc2);
dot_acc3 = mad(qk, CONVERT_KV_ACC4(l_k[j+3][k]), dot_acc3);
}
ACC_TYPE s0 = (dot_acc0.s0 + dot_acc0.s1 + dot_acc0.s2 + dot_acc0.s3) * scale;
ACC_TYPE s1 = (dot_acc1.s0 + dot_acc1.s1 + dot_acc1.s2 + dot_acc1.s3) * scale;
ACC_TYPE s2 = (dot_acc2.s0 + dot_acc2.s1 + dot_acc2.s2 + dot_acc2.s3) * scale;
ACC_TYPE s3 = (dot_acc3.s0 + dot_acc3.s1 + dot_acc3.s2 + dot_acc3.s3) * scale;
if (is_causal) {
const int causal_limit = n_kv - n_q + my_query_row;
if (k_row0 > causal_limit) s0 = FA_M_INIT;
if (k_row1 > causal_limit) s1 = FA_M_INIT;
if (k_row2 > causal_limit) s2 = FA_M_INIT;
if (k_row3 > causal_limit) s3 = FA_M_INIT;
}
if (k_row0 >= n_kv) s0 = FA_M_INIT;
if (k_row1 >= n_kv) s1 = FA_M_INIT;
if (k_row2 >= n_kv) s2 = FA_M_INIT;
if (k_row3 >= n_kv) s3 = FA_M_INIT;
if (mask_base != NULL && blk_cur != 2) {
if (use_kv_pad && mask_pad_base != NULL) {
const global MASK_DATA_TYPE* mask_ptr = (const global MASK_DATA_TYPE*)(mask_pad_base + my_query_row * mask_pad_nb1);
s0 += slope * (ACC_TYPE)mask_ptr[j];
s1 += slope * (ACC_TYPE)mask_ptr[j + 1];
s2 += slope * (ACC_TYPE)mask_ptr[j + 2];
s3 += slope * (ACC_TYPE)mask_ptr[j + 3];
} else {
const global MASK_DATA_TYPE* mask_ptr = (const global MASK_DATA_TYPE*)(mask_base + my_query_row * mask_nb1);
if (k_row0 < n_kv) s0 += slope * (ACC_TYPE)mask_ptr[k_row0];
if (k_row1 < n_kv) s1 += slope * (ACC_TYPE)mask_ptr[k_row1];
if (k_row2 < n_kv) s2 += slope * (ACC_TYPE)mask_ptr[k_row2];
if (k_row3 < n_kv) s3 += slope * (ACC_TYPE)mask_ptr[k_row3];
}
}
if (logit_softcap > 0.0f) {
s0 = logit_softcap * tanh(s0 / logit_softcap);
s1 = logit_softcap * tanh(s1 / logit_softcap);
s2 = logit_softcap * tanh(s2 / logit_softcap);
s3 = logit_softcap * tanh(s3 / logit_softcap);
}
const ACC_TYPE m_new = max(m_i, max(max(s0, s1), max(s2, s3)));
// Whole tile masked (m_new == FA_M_INIT): force the exp() args
// far negative so the tile contributes 0, not exp(0)=1.
const ACC_TYPE m_exp = (m_new == FA_M_INIT) ? 0.0f : m_new;
const ACC_TYPE scale_prev = native_exp(m_i - m_exp);
const ACC_TYPE p0 = native_exp(s0 - m_exp);
const ACC_TYPE p1 = native_exp(s1 - m_exp);
const ACC_TYPE p2 = native_exp(s2 - m_exp);
const ACC_TYPE p3 = native_exp(s3 - m_exp);
FA_UNROLL
for (int i = 0; i < DV_VEC; ++i) {
o_acc[i] = mad(p3, CONVERT_KV_ACC4(l_v[j+3][i]),
mad(p2, CONVERT_KV_ACC4(l_v[j+2][i]),
mad(p1, CONVERT_KV_ACC4(l_v[j+1][i]),
mad(p0, CONVERT_KV_ACC4(l_v[j][i]),
o_acc[i] * scale_prev))));
}
l_i = l_i * scale_prev + p0 + p1 + p2 + p3;
m_i = m_new;
}
}
#endif
// End of tile: every thread must finish reading l_k/l_v before the
// next iteration's load overwrites them (WAR hazard on local memory).
barrier(CLK_LOCAL_MEM_FENCE);
}
if (my_query_row < n_q) {
// Write output.
#if N_SPLIT > 1 && defined(HAS_SUBGROUP_SHUFFLE)
if (query_valid) {
ACC_TYPE sinks_sp = 1.0f;
if (sinks_void != NULL) {
const global ACC_TYPE* sinks_ptr = (const global ACC_TYPE*)((const global char*)sinks_void + sinks_offset);
const ACC_TYPE m_sink = sinks_ptr[head_idx];
const ACC_TYPE m_final = max(m_i, m_sink);
sinks_sp = exp(m_i - m_final);
l_i = l_i * sinks_sp + exp(m_sink - m_final);
m_i = m_final;
}
const ACC_TYPE l_inv = (l_i > 0.0f) ? (1.0f / l_i) : 0.0f;
const int dv_off = split_idx * SPLIT_DV_VEC;
const ulong o_row_offset = batch_idx * o_nb3 + my_query_row * o_nb2 + head_idx * o_nb1;
global O_DATA_TYPE4 *o_row = (global O_DATA_TYPE4 *)(o_base + o_row_offset);
if (l_inv > 0.0f) {
FA_UNROLL
for (int i = 0; i < SPLIT_DV_VEC; ++i) {
o_row[dv_off + i] = CONVERT_O_DATA4(o_acc[i] * sinks_sp * l_inv);
}
} else {
FA_UNROLL
for (int i = 0; i < SPLIT_DV_VEC; ++i) {
o_row[dv_off + i] = (O_DATA_TYPE4)(0.0f);
}
}
}
#elif N_SPLIT > 1
if (split_idx == 0) {
ACC_TYPE sinks_sp = 1.0f;
if (query_valid && sinks_void != NULL) {
const global ACC_TYPE* sinks_ptr = (const global ACC_TYPE*)((const global char*)sinks_void + sinks_offset);
const ACC_TYPE m_sink = sinks_ptr[head_idx];
const ACC_TYPE m_final = max(m_i, m_sink);
sinks_sp = exp(m_i - m_final);
l_i = l_i * sinks_sp + exp(m_sink - m_final);
m_i = m_final;
}
local_softmax_scale[q_lane] = sinks_sp;
local_l_inv[q_lane] = (query_valid && l_i > 0.0f) ? (1.0f / l_i) : 0.0f;
}
barrier(CLK_LOCAL_MEM_FENCE);
if (query_valid) {
const ACC_TYPE sinks_sp = local_softmax_scale[q_lane];
const ACC_TYPE l_inv = local_l_inv[q_lane];
const int dv_off = split_idx * SPLIT_DV_VEC;
const ulong o_row_offset = batch_idx * o_nb3 + my_query_row * o_nb2 + head_idx * o_nb1;
global O_DATA_TYPE4 *o_row = (global O_DATA_TYPE4 *)(o_base + o_row_offset);
if (l_inv > 0.0f) {
FA_UNROLL
for (int i = 0; i < SPLIT_DV_VEC; ++i) {
o_row[dv_off + i] = CONVERT_O_DATA4(o_acc[i] * sinks_sp * l_inv);
}
} else {
FA_UNROLL
for (int i = 0; i < SPLIT_DV_VEC; ++i) {
o_row[dv_off + i] = (O_DATA_TYPE4)(0.0f);
}
}
}
#else
if (query_valid) {
if (sinks_void != NULL) {
const global ACC_TYPE* sinks_ptr = (const global ACC_TYPE*)((const global char*)sinks_void + sinks_offset);
const ACC_TYPE m_sink = sinks_ptr[head_idx];
const ACC_TYPE m_final = max(m_i, m_sink);
const ACC_TYPE scale_o = exp(m_i - m_final);
#pragma unroll
FA_UNROLL
for (int i = 0; i < DV_VEC; ++i) {
o_acc[i] *= scale_o;
}
@@ -194,17 +543,18 @@ __kernel void flash_attn_f32_f16(
global O_DATA_TYPE4 *o_row = (global O_DATA_TYPE4 *)(o_base + o_row_offset);
if (l_i > 0.0f) {
const ACC_TYPE l_inv = 1.0f / l_i;
#pragma unroll
FA_UNROLL
for (int i = 0; i < DV_VEC; ++i) {
o_row[i] = CONVERT_O_DATA4(o_acc[i] * l_inv);
}
} else {
#pragma unroll
FA_UNROLL
for (int i = 0; i < DV_VEC; ++i) {
o_row[i] = (O_DATA_TYPE4)(0.0f);
}
}
}
#endif
}
__kernel void flash_attn_f32_f16_q1(
@@ -258,13 +608,16 @@ __kernel void flash_attn_f32_f16_q1(
mask_base = (const global char*)mask_void + mask_offset + mask_batch_idx * mask_nb3 + mask_head_idx * mask_nb2;
}
ACC_TYPE4 q_priv[DK_VEC];
// Q is uniform across WG threads (n_q=1). Share via local memory to
// avoid per-thread q_priv[DK_VEC] dynamic-indexed private array that
// spills to DDR on Adreno.
__local ACC_TYPE4 q_shared[DK_VEC];
const ulong q_row_offset = batch_idx * q_nb3 + head_idx * q_nb2;
const global Q_DATA_TYPE4* q_ptr = (const global Q_DATA_TYPE4*)(q_base + q_row_offset);
#pragma unroll
for (int i = 0; i < DK_VEC; ++i) {
q_priv[i] = CONVERT_Q_ACC4(q_ptr[i]);
for (int i = tid; i < DK_VEC; i += Q1_WG_SIZE) {
q_shared[i] = CONVERT_Q_ACC4(q_ptr[i]);
}
barrier(CLK_LOCAL_MEM_FENCE);
float slope = get_alibi_slope(max_bias, head_idx, n_head_log2, m0, m1);
@@ -273,14 +626,14 @@ __kernel void flash_attn_f32_f16_q1(
sinks_ptr = (const global ACC_TYPE*)((const global char*)sinks_void + sinks_offset);
}
ACC_TYPE m_i = (sinks_ptr != NULL) ? sinks_ptr[head_idx] : -INFINITY;
ACC_TYPE m_i = (sinks_ptr != NULL) ? sinks_ptr[head_idx] : FA_M_INIT;
for (int k_idx = tid; k_idx < n_kv; k_idx += Q1_WG_SIZE) {
const ulong k_row_offset = batch_idx * k_nb3 + head_kv_idx * k_nb2 + k_idx * k_nb1;
const global KV_DATA_TYPE4* k_ptr = (const global KV_DATA_TYPE4*)(k_base + k_row_offset);
ACC_TYPE4 dot_acc = (ACC_TYPE4)(0.0f);
#pragma unroll
FA_UNROLL
for (int k = 0; k < DK_VEC; k++) {
dot_acc = mad(q_priv[k], CONVERT_KV_ACC4(k_ptr[k]), dot_acc);
dot_acc = mad(q_shared[k], CONVERT_KV_ACC4(k_ptr[k]), dot_acc);
}
ACC_TYPE score = (dot_acc.s0 + dot_acc.s1 + dot_acc.s2 + dot_acc.s3) * scale;
if (mask_base != NULL) {
@@ -296,7 +649,7 @@ __kernel void flash_attn_f32_f16_q1(
__local ACC_TYPE local_m[Q1_WG_SIZE];
local_m[tid] = m_i;
barrier(CLK_LOCAL_MEM_FENCE);
#pragma unroll
FA_UNROLL
for (int s = Q1_WG_SIZE / 2; s > 0; s >>= 1) {
if (tid < s) local_m[tid] = max(local_m[tid], local_m[tid + s]);
barrier(CLK_LOCAL_MEM_FENCE);
@@ -304,7 +657,7 @@ __kernel void flash_attn_f32_f16_q1(
const ACC_TYPE m_final = local_m[0];
ACC_TYPE4 o_acc[DV_VEC];
#pragma unroll
FA_UNROLL
for (int i = 0; i < DV_VEC; ++i) o_acc[i] = (ACC_TYPE4)(0.0f);
ACC_TYPE l_i = 0.0f;
@@ -314,9 +667,9 @@ __kernel void flash_attn_f32_f16_q1(
const global KV_DATA_TYPE4* k_ptr = (const global KV_DATA_TYPE4*)(k_base + k_row_offset);
const global KV_DATA_TYPE4* v_ptr = (const global KV_DATA_TYPE4*)(v_base + v_row_offset);
ACC_TYPE4 dot_acc = (ACC_TYPE4)(0.0f);
#pragma unroll
FA_UNROLL
for (int k = 0; k < DK_VEC; k++) {
dot_acc = mad(q_priv[k], CONVERT_KV_ACC4(k_ptr[k]), dot_acc);
dot_acc = mad(q_shared[k], CONVERT_KV_ACC4(k_ptr[k]), dot_acc);
}
ACC_TYPE score = (dot_acc.s0 + dot_acc.s1 + dot_acc.s2 + dot_acc.s3) * scale;
if (mask_base != NULL) {
@@ -328,7 +681,7 @@ __kernel void flash_attn_f32_f16_q1(
}
const ACC_TYPE p = exp(score - m_final);
l_i += p;
#pragma unroll
FA_UNROLL
for (int i = 0; i < DV_VEC; i++) {
o_acc[i] = mad(p, CONVERT_KV_ACC4(v_ptr[i]), o_acc[i]);
}
@@ -338,7 +691,7 @@ __kernel void flash_attn_f32_f16_q1(
__local ACC_TYPE4 local_o_comp[Q1_WG_SIZE];
local_l[tid] = l_i;
barrier(CLK_LOCAL_MEM_FENCE);
#pragma unroll
FA_UNROLL
for (int s = Q1_WG_SIZE / 2; s > 0; s >>= 1) {
if (tid < s) local_l[tid] += local_l[tid + s];
barrier(CLK_LOCAL_MEM_FENCE);
@@ -357,7 +710,7 @@ __kernel void flash_attn_f32_f16_q1(
for (int i = 0; i < DV_VEC; i++) {
local_o_comp[tid] = o_acc[i];
barrier(CLK_LOCAL_MEM_FENCE);
#pragma unroll
FA_UNROLL
for (int s = Q1_WG_SIZE / 2; s > 0; s >>= 1) {
if (tid < s) local_o_comp[tid] += local_o_comp[tid + s];
barrier(CLK_LOCAL_MEM_FENCE);
@@ -367,7 +720,257 @@ __kernel void flash_attn_f32_f16_q1(
}
}
} else if (tid == 0) {
#pragma unroll
FA_UNROLL
for (int i = 0; i < DV_VEC; ++i) o_row[i] = (O_DATA_TYPE4)(0.0f);
}
}
// Flash-decoding split pass. gid(2) = q_idx * n_splits + split_idx.
// Partial record per split: [m, l, O[DV]]. Merge kernel applies sink + norm.
#define FA_PARTIAL_FLOATS (2 + DV)
// Flash-decoding split pass. Each work-group processes KV rows
// [kv_start, kv_end) for one (batch, head, query row) and writes one partial
// record of FA_PARTIAL_FLOATS floats into partial_void:
//   rec[0]    split-local maximum score m
//   rec[1]    sum over the split of exp(score - m)
//   rec[2..]  exp(score - m)-weighted sum of V rows, DV_VEC groups of 4 floats
// The record is left un-normalised.
//
// get_global_id(1) = batch_idx * n_head + head_idx
// get_global_id(2) = q_idx * n_splits + split_idx
//
// NOTE(review): the code assumes the local size in dimension 0 equals
// Q1_WG_SIZE and that dimensions 1 and 2 have local size 1 (so the early
// return below is uniform across the work-group) - confirm against the host
// launch code.
__kernel void flash_attn_f32_f16_q1_split(
    const global void * q_void, ulong q_offset,
    const global void * k_void, ulong k_offset,
    const global void * v_void, ulong v_offset,
    const float scale,
    const int n_q,
    const int n_kv,
    const int n_head,
    const ulong q_nb1, const ulong q_nb2, const ulong q_nb3,
    const ulong k_nb1, const ulong k_nb2, const ulong k_nb3,
    const ulong v_nb1, const ulong v_nb2, const ulong v_nb3,
    const float max_bias,
    const float m0,
    const float m1,
    const int n_head_log2,
    const float logit_softcap,
    const int n_head_kv,
    const global void * mask_void,
    const ulong mask_offset,
    const ulong mask_nb1,
    const ulong mask_nb2,
    const ulong mask_nb3,
    const int mask_ne2,
    const int mask_ne3,
    global float * partial_void,
    const int n_splits,
    const int kv_per_split
) {
    const int tid = get_local_id(0);
    const int head_batch_idx = get_global_id(1);
    const int split_q_idx = get_global_id(2);
    const int split_idx = split_q_idx % n_splits;
    const int q_idx = split_q_idx / n_splits;

    const int batch_idx = head_batch_idx / n_head;
    const int head_idx = head_batch_idx % n_head;

    // Grouped-query attention: several Q heads share one KV head.
    const int gqa_ratio = n_head / n_head_kv;
    const int head_kv_idx = head_idx / gqa_ratio;

    // KV range owned by this split; the last split is clamped to n_kv.
    const int kv_start = split_idx * kv_per_split;
    const int kv_end = min(kv_start + kv_per_split, n_kv);

    const ulong record_stride = (ulong) FA_PARTIAL_FLOATS;
    const ulong record_idx = ((((ulong) batch_idx * n_head + head_idx) * n_q + q_idx)
                              * n_splits + split_idx);
    global float * rec = partial_void + record_idx * record_stride;
    // O part of the record. Kept as a float pointer on purpose: records are
    // (2 + DV) floats long, so rec + 2 is in general only 4-byte aligned and
    // must not be dereferenced as a float4 pointer. vstore4 below only needs
    // element alignment.
    global float * rec_o = rec + 2;

    if (kv_start >= kv_end) {
        // Empty split: leave sentinel partial for merge. The O part is left
        // unwritten.
        if (tid == 0) {
            rec[0] = FA_M_INIT;
            rec[1] = 0.0f;
        }
        return;
    }

    const global char * q_base = (const global char *) q_void + q_offset;
    const global char * k_base = (const global char *) k_void + k_offset;
    const global char * v_base = (const global char *) v_void + v_offset;

    // Mask row for this query; head/batch indices wrap with mask_ne2/mask_ne3.
    const global char * mask_base = NULL;
    if (mask_void != NULL) {
        const int mask_head_idx = head_idx % mask_ne2;
        const int mask_batch_idx = batch_idx % mask_ne3;
        mask_base = (const global char *) mask_void + mask_offset +
                    mask_batch_idx * mask_nb3 + mask_head_idx * mask_nb2 +
                    (ulong) q_idx * mask_nb1;
    }

    // Share Q via local memory (n_q=1 per split -> uniform across WG).
    __local ACC_TYPE4 q_shared[DK_VEC];
    const ulong q_row_offset = batch_idx * q_nb3 + head_idx * q_nb2 + (ulong) q_idx * q_nb1;
    const global Q_DATA_TYPE4 * q_ptr = (const global Q_DATA_TYPE4 *) (q_base + q_row_offset);
    for (int i = tid; i < DK_VEC; i += Q1_WG_SIZE) {
        q_shared[i] = CONVERT_Q_ACC4(q_ptr[i]);
    }
    barrier(CLK_LOCAL_MEM_FENCE);

    const float slope = get_alibi_slope(max_bias, head_idx, n_head_log2, m0, m1);

    // Pass 1a split-local max.
    ACC_TYPE m_i = FA_M_INIT;
    for (int k_idx = kv_start + tid; k_idx < kv_end; k_idx += Q1_WG_SIZE) {
        const ulong k_row_offset = batch_idx * k_nb3 + head_kv_idx * k_nb2 + k_idx * k_nb1;
        const global KV_DATA_TYPE4 * k_ptr = (const global KV_DATA_TYPE4 *) (k_base + k_row_offset);
        ACC_TYPE4 dot_acc = (ACC_TYPE4)(0.0f);
        #pragma unroll
        for (int k = 0; k < DK_VEC; ++k) {
            dot_acc = mad(q_shared[k], CONVERT_KV_ACC4(k_ptr[k]), dot_acc);
        }
        ACC_TYPE score = (dot_acc.s0 + dot_acc.s1 + dot_acc.s2 + dot_acc.s3) * scale;
        if (mask_base != NULL) {
            const global MASK_DATA_TYPE * mask_ptr = (const global MASK_DATA_TYPE *) (mask_base);
            score += slope * (ACC_TYPE) mask_ptr[k_idx];
        }
        if (logit_softcap > 0.0f) {
            score = logit_softcap * tanh(score / logit_softcap);
        }
        m_i = max(m_i, score);
    }

    // Tree reduction of the per-thread maxima through local memory.
    __local ACC_TYPE local_m[Q1_WG_SIZE];
    local_m[tid] = m_i;
    barrier(CLK_LOCAL_MEM_FENCE);
    #pragma unroll
    for (int s = Q1_WG_SIZE / 2; s > 0; s >>= 1) {
        if (tid < s) local_m[tid] = max(local_m[tid], local_m[tid + s]);
        barrier(CLK_LOCAL_MEM_FENCE);
    }
    const ACC_TYPE m_c = local_m[0];

    // Pass 1b softmax-weighted V accumulate. Scores are recomputed rather
    // than stored from pass 1a.
    ACC_TYPE4 o_acc[DV_VEC];
    #pragma unroll
    for (int i = 0; i < DV_VEC; ++i) o_acc[i] = (ACC_TYPE4)(0.0f);
    ACC_TYPE l_i = 0.0f;

    for (int k_idx = kv_start + tid; k_idx < kv_end; k_idx += Q1_WG_SIZE) {
        const ulong k_row_offset = batch_idx * k_nb3 + head_kv_idx * k_nb2 + k_idx * k_nb1;
        const ulong v_row_offset = batch_idx * v_nb3 + head_kv_idx * v_nb2 + k_idx * v_nb1;
        const global KV_DATA_TYPE4 * k_ptr = (const global KV_DATA_TYPE4 *) (k_base + k_row_offset);
        const global KV_DATA_TYPE4 * v_ptr = (const global KV_DATA_TYPE4 *) (v_base + v_row_offset);
        ACC_TYPE4 dot_acc = (ACC_TYPE4)(0.0f);
        #pragma unroll
        for (int k = 0; k < DK_VEC; ++k) {
            dot_acc = mad(q_shared[k], CONVERT_KV_ACC4(k_ptr[k]), dot_acc);
        }
        ACC_TYPE score = (dot_acc.s0 + dot_acc.s1 + dot_acc.s2 + dot_acc.s3) * scale;
        if (mask_base != NULL) {
            const global MASK_DATA_TYPE * mask_ptr = (const global MASK_DATA_TYPE *) (mask_base);
            score += slope * (ACC_TYPE) mask_ptr[k_idx];
        }
        if (logit_softcap > 0.0f) {
            score = logit_softcap * tanh(score / logit_softcap);
        }
        const ACC_TYPE p = exp(score - m_c);
        l_i += p;
        #pragma unroll
        for (int i = 0; i < DV_VEC; ++i) {
            o_acc[i] = mad(p, CONVERT_KV_ACC4(v_ptr[i]), o_acc[i]);
        }
    }

    // Sum the per-thread exp-sums across the work-group.
    __local ACC_TYPE local_l[Q1_WG_SIZE];
    __local ACC_TYPE4 local_o[Q1_WG_SIZE];
    local_l[tid] = l_i;
    barrier(CLK_LOCAL_MEM_FENCE);
    #pragma unroll
    for (int s = Q1_WG_SIZE / 2; s > 0; s >>= 1) {
        if (tid < s) local_l[tid] += local_l[tid + s];
        barrier(CLK_LOCAL_MEM_FENCE);
    }
    const ACC_TYPE l_c = local_l[0];

    if (tid == 0) {
        rec[0] = (float) m_c;
        rec[1] = (float) l_c;
    }

    // Reduce the O accumulator one 4-wide component at a time, reusing the
    // same local buffer. The barrier at the end of each inner reduction also
    // orders thread 0's read of local_o[0] before the next component's
    // overwrite.
    for (int i = 0; i < DV_VEC; ++i) {
        local_o[tid] = o_acc[i];
        barrier(CLK_LOCAL_MEM_FENCE);
        #pragma unroll
        for (int s = Q1_WG_SIZE / 2; s > 0; s >>= 1) {
            if (tid < s) local_o[tid] += local_o[tid + s];
            barrier(CLK_LOCAL_MEM_FENCE);
        }
        if (tid == 0) {
            // Writes floats [4*i, 4*i + 3] of the O part; same addresses as a
            // float4 store would use, but valid for 4-byte-aligned records.
            vstore4(local_o[0], i, rec_o);
        }
    }
}
// FD Pass 2: merge per-split partials into final O.
//
// Each partial record is FA_PARTIAL_FLOATS floats: [m, l, O...]. Records whose
// m is still <= FA_M_INIT (the sentinel written for empty splits) are skipped
// by an explicit test, so their l and O fields are never read.
//
// get_local_id(0)  = lane, selects which group of 4 output floats to produce
// get_global_id(1) = batch_idx * n_head + head_idx
// get_global_id(2) = q_idx
//
// If sinks_void is non-NULL, the per-head sink value takes part in the global
// max and contributes exp(m_sink - m) to the denominator only.
//
// NOTE(review): assumes the local size in dimension 0 is DV_VEC and is 1 in
// dimensions 1 and 2 - confirm against the host launch code.
__kernel void flash_attn_f32_merge(
    const global float * partial_void,
    global void * o_void,
    const ulong o_offset,
    const int n_head,
    const int n_splits,
    const ulong o_nb1, const ulong o_nb2, const ulong o_nb3,
    const global void * sinks_void,
    const ulong sinks_offset,
    const int n_q
) {
    const int lane = get_local_id(0); // 0..DV_VEC-1
    const int head_batch_idx = get_global_id(1);
    const int q_idx = get_global_id(2);

    const int batch_idx = head_batch_idx / n_head;
    const int head_idx = head_batch_idx % n_head;

    // First record of this (batch, head, query); the n_splits records follow
    // contiguously.
    const ulong record_stride = (ulong) FA_PARTIAL_FLOATS;
    const ulong record_idx_0 = (((ulong) batch_idx * n_head + head_idx) * n_q + q_idx) * n_splits;
    const global float * rec0 = partial_void + record_idx_0 * record_stride;

    __local ACC_TYPE m_final_shared;
    __local ACC_TYPE l_final_shared;

    // Lane 0 computes the global max and denominator once and broadcasts them
    // through local memory.
    if (lane == 0) {
        ACC_TYPE m = FA_M_INIT;
        for (int c = 0; c < n_splits; ++c) {
            const ACC_TYPE m_c = rec0[c * record_stride + 0];
            m = max(m, m_c);
        }

        ACC_TYPE m_sink = 0.0f;
        bool has_sink = false;
        if (sinks_void != NULL) {
            const global ACC_TYPE * sinks_ptr =
                (const global ACC_TYPE *) ((const global char *) sinks_void + sinks_offset);
            m_sink = sinks_ptr[head_idx];
            has_sink = true;
            m = max(m, m_sink);
        }

        // Rescale each split's exp-sum from its own max to the global max.
        ACC_TYPE l = 0.0f;
        for (int c = 0; c < n_splits; ++c) {
            const ACC_TYPE m_c = rec0[c * record_stride + 0];
            const ACC_TYPE l_c = rec0[c * record_stride + 1];
            if (m_c > FA_M_INIT) {
                l += l_c * exp(m_c - m);
            }
        }
        if (has_sink) {
            l += exp(m_sink - m);
        }

        m_final_shared = m;
        l_final_shared = l;
    }
    barrier(CLK_LOCAL_MEM_FENCE);

    const ACC_TYPE m_final = m_final_shared;
    const ACC_TYPE l_final = l_final_shared;
    // A zero denominator yields an all-zero output row instead of inf/NaN.
    const ACC_TYPE l_inv = (l_final > 0.0f) ? (1.0f / l_final) : 0.0f;

    ACC_TYPE4 o = (ACC_TYPE4)(0.0f);
    for (int c = 0; c < n_splits; ++c) {
        const global float * rec_c = rec0 + c * record_stride;
        const ACC_TYPE m_c = rec_c[0];
        if (m_c <= FA_M_INIT) continue;
        const ACC_TYPE scale_c = exp(m_c - m_final);
        // Records are (2 + DV) floats long, so rec_c + 2 is in general only
        // 4-byte aligned. vload4 reads floats [4*lane, 4*lane + 3] of the O
        // part without requiring float4 alignment.
        o = mad((ACC_TYPE4)(scale_c), vload4(lane, rec_c + 2), o);
    }
    o = o * l_inv;

    const ulong o_row_offset = (ulong) batch_idx * o_nb3 + (ulong) q_idx * o_nb2 + (ulong) head_idx * o_nb1;
    global O_DATA_TYPE4 * o_row = (global O_DATA_TYPE4 *) ((global char *) o_void + o_offset + o_row_offset);
    o_row[lane] = CONVERT_O_DATA4(o);
}
File diff suppressed because it is too large Load Diff
File diff suppressed because it is too large Load Diff
@@ -0,0 +1,156 @@
#pragma OPENCL EXTENSION cl_khr_fp16 : enable
// Builds a BLOCK_N-row copy of the last (partial) KV tile for every
// (batch, kv head). Rows that exist in the source are copied byte for byte
// (k_nb1 / v_nb1 bytes each); rows past n_kv are zero-filled.
//
// One work-item per destination row:
//   get_global_id(0) = row inside the padded tile
//   get_global_id(1) = kv head
//   get_global_id(2) = batch
// The padded buffers are packed as [batch][head][BLOCK_N rows of nb1 bytes].
__kernel void flash_attn_kv_pad_f16(
    const global void * k_void, ulong k_offset,
    const global void * v_void, ulong v_offset,
    global void * k_pad_void,
    global void * v_pad_void,
    const int n_kv,
    const int n_head_kv,
    const int n_batch,
    const ulong k_nb1, const ulong k_nb2, const ulong k_nb3,
    const ulong v_nb1, const ulong v_nb2, const ulong v_nb3
) {
    const int row_idx     = get_global_id(0);
    const int head_kv_idx = get_global_id(1);
    const int batch_idx   = get_global_id(2);

    // Drop work-items outside the padded tile / head / batch range.
    if (row_idx >= BLOCK_N) {
        return;
    }
    if (head_kv_idx >= n_head_kv || batch_idx >= n_batch) {
        return;
    }

    // First row of the trailing tile, and the source row this item mirrors.
    const int tail_start  = n_kv - (n_kv % BLOCK_N);
    const int src_row_idx = tail_start + row_idx;

    const global char * k_in  = (const global char *) k_void + k_offset;
    const global char * v_in  = (const global char *) v_void + v_offset;
    global char *       k_out = (global char *) k_pad_void;
    global char *       v_out = (global char *) v_pad_void;

    // Destination row start inside the packed padded buffers.
    const ulong head_slot = (ulong) batch_idx * (ulong) n_head_kv + (ulong) head_kv_idx;
    const ulong k_dst_offset = head_slot * ((ulong) BLOCK_N * k_nb1) + (ulong) row_idx * k_nb1;
    const ulong v_dst_offset = head_slot * ((ulong) BLOCK_N * v_nb1) + (ulong) row_idx * v_nb1;

    if (src_row_idx >= n_kv) {
        // Past the end of the cache: emit an all-zero row.
        for (ulong b = 0; b < k_nb1; ++b) {
            k_out[k_dst_offset + b] = 0;
        }
        for (ulong b = 0; b < v_nb1; ++b) {
            v_out[v_dst_offset + b] = 0;
        }
        return;
    }

    // Valid source row: plain byte copy.
    const ulong k_src_offset = (ulong) batch_idx * k_nb3 + (ulong) head_kv_idx * k_nb2 + (ulong) src_row_idx * k_nb1;
    const ulong v_src_offset = (ulong) batch_idx * v_nb3 + (ulong) head_kv_idx * v_nb2 + (ulong) src_row_idx * v_nb1;
    for (ulong b = 0; b < k_nb1; ++b) {
        k_out[k_dst_offset + b] = k_in[k_src_offset + b];
    }
    for (ulong b = 0; b < v_nb1; ++b) {
        v_out[v_dst_offset + b] = v_in[v_src_offset + b];
    }
}
// Builds a BLOCK_N-column copy of the mask for the last (partial) KV tile.
// Columns that exist in the source mask are copied; columns at or beyond n_kv
// are written as -INFINITY.
//
// One work-item per destination element:
//   get_global_id(0) = column inside the padded tile
//   get_global_id(1) = query row
//   get_global_id(2) = mask slice = mask_batch_idx * mask_ne2 + mask_head_idx
// The padded mask is packed as [batch][head][n_q][BLOCK_N] halfs.
__kernel void flash_attn_mask_pad_f16(
    const global void * mask_void, ulong mask_offset,
    global void * mask_pad_void,
    const int n_q,
    const int n_kv,
    const ulong mask_nb1,
    const ulong mask_nb2,
    const ulong mask_nb3,
    const int mask_ne2,
    const int mask_ne3
) {
    const int col_idx    = get_global_id(0);
    const int q_row      = get_global_id(1);
    const int mask_slice = get_global_id(2);

    if (col_idx >= BLOCK_N) {
        return;
    }
    if (q_row >= n_q || mask_slice >= mask_ne2 * mask_ne3) {
        return;
    }

    // Split the flat slice index back into (batch, head).
    const int mask_batch_idx = mask_slice / mask_ne2;
    const int mask_head_idx  = mask_slice % mask_ne2;

    // Source column in the original mask that this element mirrors.
    const int tail_start  = n_kv - (n_kv % BLOCK_N);
    const int src_col_idx = tail_start + col_idx;

    // Start of the source mask row for this (batch, head, query row).
    const global char * src_row_bytes = (const global char *) mask_void + mask_offset +
                                        (ulong) mask_batch_idx * mask_nb3 +
                                        (ulong) mask_head_idx * mask_nb2 +
                                        (ulong) q_row * mask_nb1;
    const global half * src_row = (const global half *) src_row_bytes;

    global half * out = (global half *) mask_pad_void;
    const ulong slice_row =
        ((ulong) mask_batch_idx * (ulong) mask_ne2 + (ulong) mask_head_idx) * (ulong) n_q + (ulong) q_row;
    const ulong dst_idx = slice_row * (ulong) BLOCK_N + (ulong) col_idx;

    if (src_col_idx < n_kv) {
        out[dst_idx] = src_row[src_col_idx];
    } else {
        out[dst_idx] = (half) (-INFINITY);
    }
}
// Classifies every (q block, kv block) tile of the mask into one byte of blk:
//   0 = every element is masked (<= -65504, the most negative finite half)
//   1 = mixed, or unmasked with at least one non-zero bias value
//   2 = every element is exactly zero
// Per the original note: class 0 tiles can be skipped, class 2 tiles need no
// mask lookup, and causal diagonal tiles come out as class 1.
//
//   get_global_id(0) = kv block index
//   get_global_id(1) = q block index
//   get_global_id(2) = mask slice = mask_batch_idx * mask_ne2 + mask_head_idx
__kernel void flash_attn_blk_f16(
    const global void * mask_void, ulong mask_offset,
    global char * blk,
    const int n_q,
    const int n_kv,
    const ulong mask_nb1,
    const ulong mask_nb2,
    const ulong mask_nb3,
    const int mask_ne2,
    const int mask_ne3
) {
    const int kv_block_idx = get_global_id(0);
    const int q_block_idx  = get_global_id(1);
    const int mask_slice   = get_global_id(2);

    // Tile counts, rounding up so partial edge tiles are included.
    const int n_q_blocks  = (n_q + BLOCK_M - 1) / BLOCK_M;
    const int n_kv_blocks = (n_kv + BLOCK_N - 1) / BLOCK_N;

    if (kv_block_idx >= n_kv_blocks) {
        return;
    }
    if (q_block_idx >= n_q_blocks || mask_slice >= mask_ne2 * mask_ne3) {
        return;
    }

    const int mask_batch_idx = mask_slice / mask_ne2;
    const int mask_head_idx  = mask_slice % mask_ne2;

    // Tile origin and its extent, clipped at the mask edges.
    const int q_start = q_block_idx * BLOCK_M;
    const int k_start = kv_block_idx * BLOCK_N;
    const int q_count = min(BLOCK_M, n_q - q_start);
    const int k_count = min(BLOCK_N, n_kv - k_start);

    const half neg_max_half = (half) (-65504.0f);

    const global char * mask_base = (const global char *) mask_void + mask_offset +
                                    (ulong) mask_batch_idx * mask_nb3 +
                                    (ulong) mask_head_idx * mask_nb2;

    char saw_blocked = 0;   // at least one element <= neg_max_half
    char saw_open    = 0;   // at least one element above neg_max_half
    char saw_bias    = 0;   // at least one open element that is not 0

    for (int r = 0; r < q_count; ++r) {
        const global half * mask_row =
            (const global half *) (mask_base + (ulong) (q_start + r) * mask_nb1) + k_start;

        for (int c = 0; c < k_count; ++c) {
            const half v = mask_row[c];
            if (v <= neg_max_half) {
                saw_blocked = 1;
                continue;
            }
            saw_open = 1;
            if (v != (half) 0.0f) {
                saw_bias = 1;
            }
        }

        // Once both kinds were seen the tile is class 1 whatever follows, so
        // the remaining rows need not be read.
        if (saw_blocked && saw_open) {
            break;
        }
    }

    char res = 2;
    if (saw_open == 0) {
        res = 0;
    } else if (saw_blocked || saw_bias) {
        res = 1;
    }

    const ulong out_idx =
        ((ulong) mask_slice * (ulong) n_q_blocks + (ulong) q_block_idx) * (ulong) n_kv_blocks + (ulong) kv_block_idx;
    blk[out_idx] = res;
}
+500
View File
@@ -158,6 +158,239 @@ kernel void kernel_set_rows_f32_i32(
}
}
// f32 -> q8_0 quantize set_rows. Block = half d + char qs[32].
#define QK8_0 32
// Quantizes QK8_0 consecutive floats starting at x into one q8_0 block.
//   x     : source values (QK8_0 floats are read)
//   qs    : destination for the QK8_0 signed 8-bit quants
//   d_out : destination for the block scale, stored as half
// The scale is max|x| / 127; each quant is round(x * 127 / max|x|). An
// all-zero block stores scale 0 and all-zero quants.
inline void quantize_q8_0_block(global float * x, global char * qs, global half * d_out) {
    // Largest magnitude in the block determines the scale.
    float max_abs = 0.0f;
    for (int idx = 0; idx < QK8_0; ++idx) {
        max_abs = fmax(max_abs, fabs(x[idx]));
    }

    const float block_scale = max_abs / 127.0f;

    // Inverse scale; left at 0 when the scale is 0 to avoid dividing by zero.
    float inv_scale = 0.0f;
    if (block_scale != 0.0f) {
        inv_scale = 127.0f / max_abs;
    }

    vstore_half(block_scale, 0, d_out);

    for (int idx = 0; idx < QK8_0; ++idx) {
        const int q = (int) round(x[idx] * inv_scale);
        qs[idx] = (char) q;
    }
}
// set_rows with f32 source and q8_0 destination, 64-bit row indices.
// For each source row, the destination row index is read from src1 and the
// row is quantized block by block into dst. Each destination block is
// 2 + QK8_0 bytes: a half scale followed by QK8_0 quants.
//
//   get_group_id(0), get_local_id(1) -> source row i01
//   get_group_id(1)                  -> i02
//   get_group_id(2)                  -> i03
//   get_local_id(0)                  -> strides over the nblk0 blocks of a row
kernel void kernel_set_rows_q8_0_i64(
        global char * src0,
        ulong         offset0,
        global char * src1,
        ulong         offset1,
        global char * dst,
        ulong         offsetd,
        int           ne01,
        ulong         nb01,
        ulong         nb02,
        ulong         nb03,
        uint4         ne11,
        uint4         ne12,
        ulong         nb10,
        ulong         nb11,
        ulong         nb12,
        int           nblk0,
        ulong         nb1,
        ulong         nb2,
        ulong         nb3
) {
    global char * src_base = src0 + offset0;
    global char * idx_base = src1 + offset1;
    global char * dst_base = dst + offsetd;

    const int i03 = get_group_id(2);
    const int i02 = get_group_id(1);
    const int i01 = get_group_id(0)*get_local_size(1) + get_local_id(1);

    // Rows beyond the source extent have nothing to do.
    if (i01 >= ne01) {
        return;
    }

    // Index tensor coordinates; dims 1 and 2 wrap via fastmod.
    const int i12 = fastmod(i03, ne12);
    const int i11 = fastmod(i02, ne11);

    // Destination row number for this source row.
    const long i1 = ((global long *)(idx_base + i01*nb10 + i11*nb11 + i12*nb12))[0];

    global char  * dst_row = (global char  *) (dst_base + i1*nb1 + i02*nb2 + i03*nb3);
    global float * src_row = (global float *) (src_base + i01*nb01 + i02*nb02 + i03*nb03);

    for (int b = get_local_id(0); b < nblk0; b += get_local_size(0)) {
        // Scale occupies the first 2 bytes of the block, quants follow.
        global char * blk_out = dst_row + b * (2 + QK8_0);
        quantize_q8_0_block(src_row + b * QK8_0, blk_out + 2, (global half *)blk_out);
    }
}
// set_rows with f32 source and q8_0 destination, 32-bit row indices.
// For each source row, the destination row index is read from src1 and the
// row is quantized block by block into dst. Each destination block is
// 2 + QK8_0 bytes: a half scale followed by QK8_0 quants.
//
//   get_group_id(0), get_local_id(1) -> source row i01
//   get_group_id(1)                  -> i02
//   get_group_id(2)                  -> i03
//   get_local_id(0)                  -> strides over the nblk0 blocks of a row
kernel void kernel_set_rows_q8_0_i32(
        global char * src0,
        ulong         offset0,
        global char * src1,
        ulong         offset1,
        global char * dst,
        ulong         offsetd,
        int           ne01,
        ulong         nb01,
        ulong         nb02,
        ulong         nb03,
        uint4         ne11,
        uint4         ne12,
        ulong         nb10,
        ulong         nb11,
        ulong         nb12,
        int           nblk0,
        ulong         nb1,
        ulong         nb2,
        ulong         nb3
) {
    global char * src_base = src0 + offset0;
    global char * idx_base = src1 + offset1;
    global char * dst_base = dst + offsetd;

    const int i03 = get_group_id(2);
    const int i02 = get_group_id(1);
    const int i01 = get_group_id(0)*get_local_size(1) + get_local_id(1);

    // Rows beyond the source extent have nothing to do.
    if (i01 >= ne01) {
        return;
    }

    // Index tensor coordinates; dims 1 and 2 wrap via fastmod.
    const int i12 = fastmod(i03, ne12);
    const int i11 = fastmod(i02, ne11);

    // Destination row number for this source row.
    const int i1 = ((global int *)(idx_base + i01*nb10 + i11*nb11 + i12*nb12))[0];

    global char  * dst_row = (global char  *) (dst_base + i1*nb1 + i02*nb2 + i03*nb3);
    global float * src_row = (global float *) (src_base + i01*nb01 + i02*nb02 + i03*nb03);

    for (int b = get_local_id(0); b < nblk0; b += get_local_size(0)) {
        // Scale occupies the first 2 bytes of the block, quants follow.
        global char * blk_out = dst_row + b * (2 + QK8_0);
        quantize_q8_0_block(src_row + b * QK8_0, blk_out + 2, (global half *)blk_out);
    }
}
// SoA q8_0 variants. dst_q: int8[QK8_0] per block; dst_d: fp16 scale per block.
// Layout matches kernel_convert_block_q8_0; block index follows dst element order.
// set_rows with f32 source and split (SoA) q8_0 destination, 64-bit row
// indices. Quants go to dst_q (QK8_0 bytes per block) and scales go to dst_d
// (one half per block). Both are addressed by the same linear block index:
//   ((i03 * ne2_dst + i02) * ne1_dst + i1) * nblk0 + blk
// where i1 is the destination row read from src1.
//
//   get_group_id(0), get_local_id(1) -> source row i01
//   get_group_id(1)                  -> i02
//   get_group_id(2)                  -> i03
//   get_local_id(0)                  -> strides over the nblk0 blocks of a row
// ne3_dst is accepted but not used by this kernel.
kernel void kernel_set_rows_q8_0_soa_i64(
        global char * src0,
        ulong         offset0,
        global char * src1,
        ulong         offset1,
        global char * dst_q,
        ulong         offset_q,
        global char * dst_d,
        ulong         offset_d,
        int           ne01,
        ulong         nb01,
        ulong         nb02,
        ulong         nb03,
        uint4         ne11,
        uint4         ne12,
        ulong         nb10,
        ulong         nb11,
        ulong         nb12,
        int           nblk0,
        int           ne1_dst,
        int           ne2_dst,
        int           ne3_dst
) {
    global char * src_base   = src0 + offset0;
    global char * idx_base   = src1 + offset1;
    global char * quant_base = dst_q + offset_q;
    global char * scale_base = dst_d + offset_d;

    const int i03 = get_group_id(2);
    const int i02 = get_group_id(1);
    const int i01 = get_group_id(0)*get_local_size(1) + get_local_id(1);

    // Rows beyond the source extent have nothing to do.
    if (i01 >= ne01) {
        return;
    }

    // Index tensor coordinates; dims 1 and 2 wrap via fastmod.
    const int i12 = fastmod(i03, ne12);
    const int i11 = fastmod(i02, ne11);

    // Destination row number for this source row.
    const long i1 = ((global long *)(idx_base + i01*nb10 + i11*nb11 + i12*nb12))[0];

    // Linear index of the first block of the destination row.
    const long row_blk_base = ((long)i03 * ne2_dst * ne1_dst + (long)i02 * ne1_dst + i1) * nblk0;

    global half  * d_row   = (global half *)(scale_base) + row_blk_base;
    global char  * q_row   = (global char *)(quant_base) + row_blk_base * QK8_0;
    global float * src_row = (global float *)(src_base + i01*nb01 + i02*nb02 + i03*nb03);

    for (int b = get_local_id(0); b < nblk0; b += get_local_size(0)) {
        quantize_q8_0_block(src_row + b * QK8_0, q_row + b * QK8_0, d_row + b);
    }
}
// set_rows, f32 -> q8_0 with split (SoA) destination, 32-bit row indices.
//
// Quants go to dst_q (QK8_0 bytes per block) and scales to dst_d (one half per
// block). The linear block index of a destination row is
//   ((i03 * ne2_dst + i02) * ne1_dst + row_index) * nblk0
// ne3_dst is accepted for signature parity but not read here.
kernel void kernel_set_rows_q8_0_soa_i32(
        global char * src0,
        ulong         offset0,
        global char * src1,
        ulong         offset1,
        global char * dst_q,
        ulong         offset_q,
        global char * dst_d,
        ulong         offset_d,
        int           ne01,
        ulong         nb01,
        ulong         nb02,
        ulong         nb03,
        uint4         ne11,
        uint4         ne12,
        ulong         nb10,
        ulong         nb11,
        ulong         nb12,
        int           nblk0,
        int           ne1_dst,
        int           ne2_dst,
        int           ne3_dst
) {
    src0  += offset0;
    src1  += offset1;
    dst_q += offset_q;
    dst_d += offset_d;

    const int batch3 = get_group_id(2);
    const int batch2 = get_group_id(1);
    const int row    = get_group_id(0)*get_local_size(1) + get_local_id(1);

    // surplus work-items past the last source row have nothing to do
    if (row >= ne01) {
        return;
    }

    // the index tensor is addressed modulo its own dims 1 and 2
    const int idx3 = fastmod(batch3, ne12);
    const int idx2 = fastmod(batch2, ne11);

    // destination row number for this source row
    const int target = *(global int *)(src1 + row*nb10 + idx2*nb11 + idx3*nb12);

    // first block of the destination row, counted in blocks from the buffer start
    const long plane_rows = (long)batch3 * ne2_dst * ne1_dst;
    const long slice_rows = (long)batch2 * ne1_dst;
    const long first_blk  = (plane_rows + slice_rows + target) * nblk0;

    global half  * scales = (global half *)(dst_d) + first_blk;
    global char  * quants = (global char *)(dst_q) + first_blk * QK8_0;
    global float * in_row = (global float *)(src0 + row*nb01 + batch2*nb02 + batch3*nb03);

    const size_t n_lanes = get_local_size(0);
    for (int b = get_local_id(0); b < nblk0; b += n_lanes) {
        quantize_q8_0_block(in_row + b * QK8_0, quants + b * QK8_0, scales + b);
    }
}
kernel void kernel_set_rows_f16_i32(
global char * src0,
ulong offset0,
@@ -206,3 +439,270 @@ kernel void kernel_set_rows_f16_i32(
dst_row[ind] = src_row[ind];
}
}
// f32 -> q4_0 quantize set_rows. Block = half d + uchar qs[16] (shuffled
// nibbles: qs[j] low/high = elem j / j+16).
// Dequant: val[i] = d * (nibble_i - 8)
// nblk0 = number of q4_0 blocks per row = ne00 / 32.
#define QK4_0 32
#define Q4_0_BLOCK_SIZE 18
// Quantize one block of QK4_0 floats to q4_0.
//   x     : QK4_0 input values
//   qs    : QK4_0/2 output bytes; byte j holds element j in its low nibble
//           and element j + QK4_0/2 in its high nibble
//   d_out : receives the block scale as a half
// The scale is (signed value of largest magnitude) / -8; an all-zero block
// gets scale 0 and every nibble becomes 8.
inline void quantize_q4_0_block(global float * x, global uchar * qs, global half * d_out) {
    // track the element with the largest |value|, keeping its sign
    float peak     = 0.0f;
    float peak_abs = 0.0f;
    for (int k = 0; k < QK4_0; ++k) {
        const float cur     = x[k];
        const float cur_abs = fabs(cur);
        if (cur_abs > peak_abs) {
            peak_abs = cur_abs;
            peak     = cur;
        }
    }

    const float scale = peak / -8.0f;

    // inverse scale; stays 0 for an all-zero block to avoid dividing by zero
    float inv_scale = 0.0f;
    if (scale != 0.0f) {
        inv_scale = 1.0f / scale;
    }

    vstore_half(scale, 0, d_out);

    const int half_blk = QK4_0/2;
    for (int k = 0; k < half_blk; ++k) {
        // shift by 8.5 so that truncation rounds into the 0..15 nibble range
        const int lo = clamp((int)(x[k]            * inv_scale + 8.5f), 0, 15);
        const int hi = clamp((int)(x[k + half_blk] * inv_scale + 8.5f), 0, 15);
        qs[k] = (uchar)(lo | (hi << 4));
    }
}
// set_rows, f32 -> q4_0 (interleaved blocks), 64-bit row indices.
//
// Each work-item row (local dim 1) handles one source row; the lanes of local
// dim 0 stride over the nblk0 blocks of that row. The destination row is read
// from src1 as a long. Every destination block occupies Q4_0_BLOCK_SIZE bytes:
// the half scale at the block start and the packed nibbles 2 bytes after it.
kernel void kernel_set_rows_q4_0_i64(
        global char * src0,
        ulong         offset0,
        global char * src1,
        ulong         offset1,
        global char * dst,
        ulong         offsetd,
        int           ne01,
        ulong         nb01,
        ulong         nb02,
        ulong         nb03,
        uint4         ne11,
        uint4         ne12,
        ulong         nb10,
        ulong         nb11,
        ulong         nb12,
        int           nblk0,
        ulong         nb1,
        ulong         nb2,
        ulong         nb3
) {
    src0 += offset0;
    src1 += offset1;
    dst  += offsetd;

    const int batch3 = get_group_id(2);
    const int batch2 = get_group_id(1);
    const int row    = get_group_id(0)*get_local_size(1) + get_local_id(1);

    // surplus work-items past the last source row have nothing to do
    if (row >= ne01) {
        return;
    }

    // the index tensor is addressed modulo its own dims 1 and 2
    const int idx3 = fastmod(batch3, ne12);
    const int idx2 = fastmod(batch2, ne11);

    // destination row number for this source row
    const long target = *(global long *)(src1 + row*nb10 + idx2*nb11 + idx3*nb12);

    global char  * out_row = dst + target*nb1 + batch2*nb2 + batch3*nb3;
    global float * in_row  = (global float *)(src0 + row*nb01 + batch2*nb02 + batch3*nb03);

    const size_t n_lanes = get_local_size(0);
    for (int b = get_local_id(0); b < nblk0; b += n_lanes) {
        global char * blk_out = out_row + b * Q4_0_BLOCK_SIZE;
        quantize_q4_0_block(in_row + b * QK4_0,
                            (global uchar *)(blk_out + 2),
                            (global half  *)(blk_out));
    }
}
// set_rows, f32 -> q4_0 (interleaved blocks), 32-bit row indices.
//
// Each work-item row (local dim 1) handles one source row; the lanes of local
// dim 0 stride over the nblk0 blocks of that row. The destination row is read
// from src1 as an int. Every destination block occupies Q4_0_BLOCK_SIZE bytes:
// the half scale at the block start and the packed nibbles 2 bytes after it.
kernel void kernel_set_rows_q4_0_i32(
        global char * src0,
        ulong         offset0,
        global char * src1,
        ulong         offset1,
        global char * dst,
        ulong         offsetd,
        int           ne01,
        ulong         nb01,
        ulong         nb02,
        ulong         nb03,
        uint4         ne11,
        uint4         ne12,
        ulong         nb10,
        ulong         nb11,
        ulong         nb12,
        int           nblk0,
        ulong         nb1,
        ulong         nb2,
        ulong         nb3
) {
    src0 += offset0;
    src1 += offset1;
    dst  += offsetd;

    const int batch3 = get_group_id(2);
    const int batch2 = get_group_id(1);
    const int row    = get_group_id(0)*get_local_size(1) + get_local_id(1);

    // surplus work-items past the last source row have nothing to do
    if (row >= ne01) {
        return;
    }

    // the index tensor is addressed modulo its own dims 1 and 2
    const int idx3 = fastmod(batch3, ne12);
    const int idx2 = fastmod(batch2, ne11);

    // destination row number for this source row
    const int target = *(global int *)(src1 + row*nb10 + idx2*nb11 + idx3*nb12);

    global char  * out_row = dst + target*nb1 + batch2*nb2 + batch3*nb3;
    global float * in_row  = (global float *)(src0 + row*nb01 + batch2*nb02 + batch3*nb03);

    const size_t n_lanes = get_local_size(0);
    for (int b = get_local_id(0); b < nblk0; b += n_lanes) {
        global char * blk_out = out_row + b * Q4_0_BLOCK_SIZE;
        quantize_q4_0_block(in_row + b * QK4_0,
                            (global uchar *)(blk_out + 2),
                            (global half  *)(blk_out));
    }
}
// SoA variants for q4_0 dst. Used when the backend has split block_q4_0 records
// into separate quant (dst_q) and scale (dst_d) sub-buffers -- the same pattern
// as the q8_0 SoA variants above.
//
// Layout (matches kernel_convert_block_q4_0, the "shuffled" variant):
// dst_q: contiguous 16 packed nibbles per block, block i at offset i * 16 bytes.
// dst_d: contiguous fp16 scales, block i at offset i * 2 bytes.
// Nibble layout inside each byte is unchanged from AoS: qs[j] low nibble = element j,
// qs[j] high nibble = element j+16. kernel_restore_block_q4_0 copies bytes as-is.
// set_rows, f32 -> q4_0 with split (SoA) destination, 64-bit row indices.
//
// Packed nibbles go to dst_q (QK4_0/2 bytes per block) and scales to dst_d
// (one half per block). The linear block index of a destination row is
//   ((i03 * ne2_dst + i02) * ne1_dst + row_index) * nblk0
// ne3_dst is accepted for signature parity but not read here.
kernel void kernel_set_rows_q4_0_soa_i64(
        global char * src0,
        ulong         offset0,
        global char * src1,
        ulong         offset1,
        global char * dst_q,
        ulong         offset_q,
        global char * dst_d,
        ulong         offset_d,
        int           ne01,
        ulong         nb01,
        ulong         nb02,
        ulong         nb03,
        uint4         ne11,
        uint4         ne12,
        ulong         nb10,
        ulong         nb11,
        ulong         nb12,
        int           nblk0,
        int           ne1_dst,
        int           ne2_dst,
        int           ne3_dst
) {
    src0  += offset0;
    src1  += offset1;
    dst_q += offset_q;
    dst_d += offset_d;

    const int batch3 = get_group_id(2);
    const int batch2 = get_group_id(1);
    const int row    = get_group_id(0)*get_local_size(1) + get_local_id(1);

    // surplus work-items past the last source row have nothing to do
    if (row >= ne01) {
        return;
    }

    // the index tensor is addressed modulo its own dims 1 and 2
    const int idx3 = fastmod(batch3, ne12);
    const int idx2 = fastmod(batch2, ne11);

    // destination row number for this source row
    const long target = *(global long *)(src1 + row*nb10 + idx2*nb11 + idx3*nb12);

    // first block of the destination row, counted in blocks from the buffer start
    const long plane_rows = (long)batch3 * ne2_dst * ne1_dst;
    const long slice_rows = (long)batch2 * ne1_dst;
    const long first_blk  = (plane_rows + slice_rows + target) * nblk0;

    global half  * scales  = (global half  *)(dst_d) + first_blk;
    global uchar * nibbles = (global uchar *)(dst_q) + first_blk * (QK4_0/2);
    global float * in_row  = (global float *)(src0 + row*nb01 + batch2*nb02 + batch3*nb03);

    const size_t n_lanes = get_local_size(0);
    for (int b = get_local_id(0); b < nblk0; b += n_lanes) {
        quantize_q4_0_block(in_row + b * QK4_0, nibbles + b * (QK4_0/2), scales + b);
    }
}
// set_rows, f32 -> q4_0 with split (SoA) destination, 32-bit row indices.
//
// Packed nibbles go to dst_q (QK4_0/2 bytes per block) and scales to dst_d
// (one half per block). The linear block index of a destination row is
//   ((i03 * ne2_dst + i02) * ne1_dst + row_index) * nblk0
// ne3_dst is accepted for signature parity but not read here.
kernel void kernel_set_rows_q4_0_soa_i32(
        global char * src0,
        ulong         offset0,
        global char * src1,
        ulong         offset1,
        global char * dst_q,
        ulong         offset_q,
        global char * dst_d,
        ulong         offset_d,
        int           ne01,
        ulong         nb01,
        ulong         nb02,
        ulong         nb03,
        uint4         ne11,
        uint4         ne12,
        ulong         nb10,
        ulong         nb11,
        ulong         nb12,
        int           nblk0,
        int           ne1_dst,
        int           ne2_dst,
        int           ne3_dst
) {
    src0  += offset0;
    src1  += offset1;
    dst_q += offset_q;
    dst_d += offset_d;

    const int batch3 = get_group_id(2);
    const int batch2 = get_group_id(1);
    const int row    = get_group_id(0)*get_local_size(1) + get_local_id(1);

    // surplus work-items past the last source row have nothing to do
    if (row >= ne01) {
        return;
    }

    // the index tensor is addressed modulo its own dims 1 and 2
    const int idx3 = fastmod(batch3, ne12);
    const int idx2 = fastmod(batch2, ne11);

    // destination row number for this source row
    const int target = *(global int *)(src1 + row*nb10 + idx2*nb11 + idx3*nb12);

    // first block of the destination row, counted in blocks from the buffer start
    const long plane_rows = (long)batch3 * ne2_dst * ne1_dst;
    const long slice_rows = (long)batch2 * ne1_dst;
    const long first_blk  = (plane_rows + slice_rows + target) * nblk0;

    global half  * scales  = (global half  *)(dst_d) + first_blk;
    global uchar * nibbles = (global uchar *)(dst_q) + first_blk * (QK4_0/2);
    global float * in_row  = (global float *)(src0 + row*nb01 + batch2*nb02 + batch3*nb03);

    const size_t n_lanes = get_local_size(0);
    for (int b = get_local_id(0); b < nblk0; b += n_lanes) {
        quantize_q4_0_block(in_row + b * QK4_0, nibbles + b * (QK4_0/2), scales + b);
    }
}
+4
View File
@@ -1053,6 +1053,10 @@ static bool is_op_unsupported_case(const ggml_tensor * op) {
(op->ne[0] == 2 && op->ne[1] == 4 && op->ne[2] == 3 && op->ne[3] == 2)) {
return true;
}
// CPY into a strided view of a larger buffer (recurrent-state snapshots) not supported
if (op->view_src && ggml_nbytes(op) != ggml_nbytes(op->view_src)) {
return true;
}
break;
}
case GGML_OP_MUL_MAT: {
+19
View File
@@ -156,6 +156,7 @@ class Keys:
DENSE_FEAT_OUT_SIZE = "{arch}.{dense}_feat_out"
TARGET_LAYERS = "{arch}.target_layers"
TARGET_HIDDEN_SIZE = "{arch}.target_hidden_size"
BLOCK_SIZE = "{arch}.block_size"
NORM_BEFORE_RESIDUAL = "{arch}.norm_before_residual"
class Attention:
@@ -517,6 +518,7 @@ class MODEL_ARCH(IntEnum):
PANGU_EMBED = auto()
MISTRAL3 = auto()
EAGLE3 = auto()
DFLASH = auto()
MISTRAL4 = auto()
PADDLEOCR = auto()
MIMO2 = auto()
@@ -1074,6 +1076,7 @@ MODEL_ARCH_NAMES: dict[MODEL_ARCH, str] = {
MODEL_ARCH.PANGU_EMBED: "pangu-embedded",
MODEL_ARCH.MISTRAL3: "mistral3",
MODEL_ARCH.EAGLE3: "eagle3",
MODEL_ARCH.DFLASH: "dflash",
MODEL_ARCH.MISTRAL4: "mistral4",
MODEL_ARCH.PADDLEOCR: "paddleocr",
MODEL_ARCH.MIMO2: "mimo2",
@@ -4086,6 +4089,22 @@ MODEL_TENSORS: dict[MODEL_ARCH, list[MODEL_TENSOR]] = {
MODEL_TENSOR.FC,
MODEL_TENSOR.D2T,
],
MODEL_ARCH.DFLASH: [
MODEL_TENSOR.OUTPUT_NORM,
MODEL_TENSOR.ATTN_NORM,
MODEL_TENSOR.ATTN_Q,
MODEL_TENSOR.ATTN_K,
MODEL_TENSOR.ATTN_V,
MODEL_TENSOR.ATTN_OUT,
MODEL_TENSOR.ATTN_Q_NORM,
MODEL_TENSOR.ATTN_K_NORM,
MODEL_TENSOR.FFN_NORM,
MODEL_TENSOR.FFN_GATE,
MODEL_TENSOR.FFN_DOWN,
MODEL_TENSOR.FFN_UP,
MODEL_TENSOR.FC,
MODEL_TENSOR.ENC_OUTPUT_NORM,
],
MODEL_ARCH.MISTRAL4: [
MODEL_TENSOR.TOKEN_EMBD,
MODEL_TENSOR.OUTPUT_NORM,
+12
View File
@@ -940,6 +940,18 @@ class GGUFWriter:
def add_sliding_window(self, value: int) -> None:
    """Store the attention sliding-window value as a uint32 under the per-arch key."""
    key = Keys.Attention.SLIDING_WINDOW.format(arch=self.arch)
    self.add_uint32(key, value)
def add_block_size(self, value: int) -> None:
    """Store the block size as a uint32 under the per-arch key."""
    key = Keys.LLM.BLOCK_SIZE.format(arch=self.arch)
    self.add_uint32(key, value)
def add_target_layers(self, value: Sequence[int]) -> None:
    """Store the list of target layer ids as an array under the per-arch key."""
    key = Keys.LLM.TARGET_LAYERS.format(arch=self.arch)
    self.add_array(key, value)
def add_target_hidden_size(self, value: int) -> None:
    """Store the target hidden size as a uint32 under the per-arch key."""
    key = Keys.LLM.TARGET_HIDDEN_SIZE.format(arch=self.arch)
    self.add_uint32(key, value)
def add_norm_before_residual(self, value: bool) -> None:
    """Store the norm-before-residual flag as a bool under the per-arch key."""
    key = Keys.LLM.NORM_BEFORE_RESIDUAL.format(arch=self.arch)
    self.add_bool(key, value)
def add_attention_scale(self, value: float) -> None:
    """Store the attention scale as a float32 under the per-arch key."""
    key = Keys.Attention.SCALE.format(arch=self.arch)
    self.add_float32(key, value)
+5
View File
@@ -1283,6 +1283,11 @@ class TensorNameMap:
MODEL_TENSOR.ENC_OUTPUT_NORM: (
"encoder.final_layer_norm", # t5
"layer_norm", # neobert
"model.hidden_norm", # dflash
),
MODEL_TENSOR.FC: (
"model.fc", # dflash
),
MODEL_TENSOR.CLS: (
+179
View File
@@ -0,0 +1,179 @@
{#- Header: emit BOS, then the optional system turn. With `tools` set, the tool definitions replace a `<tool_def_sep>` marker in the first system message if present, are appended to that message otherwise, or form the whole system turn when the first message is not a system message. Without tools, only an existing leading system message is emitted. NOTE(review): `messages[0]` is read unconditionally - assumes `messages` is non-empty; confirm. #}
{{- bos_token }}{%- if tools %}
{%- set tool_definitions %}
{{- "# Tools\n\nYou are provided with function signatures within <tools></tools> XML tags:\n<tools>" }}
{%- for tool in tools %}
{{- "\n" }}
{{- tool | tojson(ensure_ascii=False) }}
{%- endfor %}
{{- '\n</tools>\n\nTool usage guidelines:\n- You may call zero or more functions. If no function calls are needed, just answer normally and do not include any <function ... </function>.\n- When calling a function, return an XML object within <function ... </function> using:\n<function name="function-name"><param name="param-name">param-value</param></function>\n- param-value may be multi-line. If it contains <, & or newline characters, wrap it in a CDATA block: <param name="param-name"><![CDATA[...multi-line value...]]></param>' }}
{%- endset %}
{{- '<|im_start|>system\n' }}
{%- if messages[0].role == 'system' %}
{%- if '<tool_def_sep>' in messages[0].content %}
{{- messages[0].content.replace('<tool_def_sep>', tool_definitions) }}
{%- else %}
{{- messages[0].content + '\n\n' + tool_definitions }}
{%- endif %}
{%- else %}
{{- tool_definitions.lstrip() }}
{%- endif %}
{{- '<|im_end|>\n' }}
{%- else %}
{%- if messages[0].role == 'system' %}
{{- '<|im_start|>system\n' + messages[0].content + '<|im_end|>\n' }}
{%- endif %}
{%- endif %}
{#- Scan the messages from last to first and record in ns.last_query_index the index of the first user message found whose content is a string not wrapped in <tool_response>...</tool_response>. If none matches, ns.last_query_index keeps its initial value (the last message index). A namespace is used so the assignments inside the loop are visible after it. #}
{%- set ns = namespace(multi_step_tool=true, last_query_index=messages|length - 1) %}
{%- for message in messages[::-1] %}
{%- set index = (messages|length - 1) - loop.index0 %}
{%- if ns.multi_step_tool and message.role == "user" and message.content is string and not(message.content.startswith('<tool_response>') and message.content.endswith('</tool_response>')) %}
{%- set ns.multi_step_tool = false %}
{%- set ns.last_query_index = index %}
{%- endif %}
{%- endfor %}
{#- Main loop: render each message as an <|im_start|>role ... <|im_end|> turn. User and non-leading system messages are emitted verbatim; assistant messages get optional <think> reasoning and <function> tool-call XML; consecutive tool messages are grouped into a single user turn of <tool_response> blocks. Non-string content is treated as '' except for tool messages, where it is serialized with tojson. #}
{%- for message in messages %}
{%- if message.content is string %}
{%- set content = message.content %}
{%- else %}
{%- set content = '' %}
{%- endif %}
{%- if (message.role == "user") or (message.role == "system" and not loop.first) %}
{{- '<|im_start|>' + message.role + '\n' + content + '<|im_end|>' + '\n' }}
{%- elif message.role == "assistant" %}
{#- Reasoning comes from message.reasoning_content when it is a string; otherwise it is split out of the content around </think>. #}
{%- set reasoning_content = '' %}
{%- if message.reasoning_content is string %}
{%- set reasoning_content = message.reasoning_content %}
{%- else %}
{%- if '</think>' in content %}
{%- set reasoning_content = content.split('</think>')[0].rstrip('\n').split('<think>')[-1].lstrip('\n') %}
{%- set content = content.split('</think>')[-1].lstrip('\n') %}
{%- endif %}
{%- endif %}
{%- if message.tool_calls %}
{#- Splice tool-call XML into the content at each <tool_sep> marker, then append calls that have no marker. NOTE(review): in Jinja2 a plain `set` inside a `for` body does not persist after the loop, so the processed_content updates below may not be visible at `set content = processed_content` - confirm against the target template engine. #}
{%- set content_parts = content.split('<tool_sep>') %}
{%- set processed_content = content_parts[0] %}
{%- set tool_calls_count = message.tool_calls|length %}
{%- set tool_sep_count = content_parts|length - 1 %}
{%- set min_count = [tool_calls_count, tool_sep_count]|min %}
{%- for i in range(1, content_parts|length) %}
{%- set tool_index = i - 1 %}
{%- if tool_index < tool_calls_count %}
{%- set tool_call = message.tool_calls[tool_index] %}
{%- if tool_call.function %}
{%- set tool_call = tool_call.function %}
{%- endif %}
{%- set single_tool_xml %}
{{- '<function name="' ~ tool_call.name ~ '">' }}
{%- if tool_call.arguments %}
{%- set args_dict = tool_call.arguments %}
{%- for param_name, param_value in args_dict.items() %}
{{- '<param name="' ~ param_name ~ '">' }}
{%- if param_value is string and ('<' in param_value or '&' in param_value or '\n' in param_value) %}
{{- '<![CDATA[' + param_value + ']]>' }}
{%- else %}
{{- param_value }}
{%- endif %}
{{- '</param>' }}
{%- endfor %}
{%- endif %}
{{- '</function>' }}
{%- endset %}
{%- set processed_content = processed_content + single_tool_xml + content_parts[i] %}
{%- else %}
{%- set processed_content = processed_content + content_parts[i] %}
{%- endif %}
{%- endfor %}
{%- if tool_calls_count > tool_sep_count %}
{%- for remaining_index in range(tool_sep_count, tool_calls_count) %}
{%- set tool_call = message.tool_calls[remaining_index] %}
{%- if tool_call.function %}
{%- set tool_call = tool_call.function %}
{%- endif %}
{%- set remaining_tool_xml %}
{{- '<function name="' ~ tool_call.name ~ '">' }}
{%- if tool_call.arguments %}
{%- set args_dict = tool_call.arguments %}
{%- for param_name, param_value in args_dict.items() %}
{{- '<param name="' ~ param_name ~ '">' }}
{%- if param_value is string and ('<' in param_value or '&' in param_value or '\n' in param_value) %}
{{- '<![CDATA[' + param_value + ']]>' }}
{%- else %}
{{- param_value }}
{%- endif %}
{{- '</param>' }}
{%- endfor %}
{%- endif %}
{{- '</function>' }}
{%- endset %}
{%- set processed_content = processed_content + remaining_tool_xml %}
{%- endfor %}
{%- endif %}
{%- set content = processed_content %}
{%- endif %}
{#- The <think> block is only re-emitted for assistant turns located after ns.last_query_index. #}
{%- if loop.index0 > ns.last_query_index %}
{%- if reasoning_content %}
{{- '<|im_start|>' + message.role + '\n<think>\n' + reasoning_content.strip('\n') + '\n</think>\n\n' + content.lstrip('\n') }}
{%- else %}
{{- '<|im_start|>' + message.role + '\n' + content }}
{%- endif %}
{%- else %}
{{- '<|im_start|>' + message.role + '\n' + content }}
{%- endif %}
{#- NOTE(review): `has_tool_sep` is not assigned anywhere in this template; unless the caller supplies it, `not has_tool_sep` is presumably always true and every tool call is emitted here - confirm this is intended. #}
{%- if message.tool_calls and not has_tool_sep %}
{%- for tool_call in message.tool_calls %}
{%- if (loop.first and content) or (not loop.first) %}
{{- '\n' }}
{%- endif %}
{%- if tool_call.function %}
{%- set tool_call = tool_call.function %}
{%- endif %}
{{- '<function name="' ~ tool_call.name ~ '">' }}
{%- if tool_call.arguments %}
{%- set args_dict = tool_call.arguments %}
{%- for param_name, param_value in args_dict.items() %}
{{- '<param name="' ~ param_name ~ '">' }}
{%- if param_value is string and ('<' in param_value or '&' in param_value or '\n' in param_value) %}
{{- '<![CDATA[' + param_value + ']]>' }}
{%- else %}
{{- param_value }}
{%- endif %}
{{- '</param>' }}
{%- endfor %}
{%- endif %}
{{- '</function>' }}
{%- endfor %}
{%- endif %}
{{- '<|im_end|>\n' }}
{%- elif message.role == "tool" %}
{#- A run of consecutive tool messages shares one user turn: open it on the first of the run, close it on the last. #}
{%- if loop.first or (messages[loop.index0 - 1].role != "tool") %}
{{- '<|im_start|>user' }}
{%- endif %}
{{- '\n<tool_response>\n' }}
{%- if message.content is string %}
{{- content }}
{%- else %}
{{- message.content | tojson(ensure_ascii=False) }}
{%- endif %}
{{- '\n</tool_response>' }}
{%- if loop.last or (messages[loop.index0 + 1].role != "tool") %}
{{- '<|im_end|>\n' }}
{%- endif %}
{%- endif %}
{%- endfor %}
{#- Generation prompt: open the assistant turn. When enable_thinking is defined, false pre-fills an empty <think></think> block and true opens a <think> block; when it is undefined nothing extra is emitted. #}
{%- if add_generation_prompt %}
{{- '<|im_start|>assistant\n' }}
{%- if enable_thinking is defined %}
{%- if enable_thinking is false %}
{{- '<think>\n\n</think>\n\n' }}
{%- elif enable_thinking is true %}
{{- '<think>\n' }}
{%- endif %}
{%- endif %}
{%- endif %}
+1
View File
@@ -129,6 +129,7 @@ static const std::map<llm_arch, const char *> LLM_ARCH_NAMES = {
{ LLM_ARCH_PANGU_EMBED, "pangu-embedded" },
{ LLM_ARCH_MISTRAL3, "mistral3" },
{ LLM_ARCH_EAGLE3, "eagle3" },
{ LLM_ARCH_DFLASH, "dflash" },
{ LLM_ARCH_MISTRAL4, "mistral4" },
{ LLM_ARCH_PADDLEOCR, "paddleocr" },
{ LLM_ARCH_MIMO2, "mimo2" },
+1
View File
@@ -143,6 +143,7 @@ enum llm_arch {
LLM_ARCH_TALKIE,
LLM_ARCH_MELLUM,
LLM_ARCH_EAGLE3,
LLM_ARCH_DFLASH,
LLM_ARCH_UNKNOWN,
};
+3 -3
View File
@@ -100,10 +100,10 @@ llama_context::llama_context(
cparams.ctx_other = params.ctx_other;
}
if (model.arch == LLM_ARCH_EAGLE3) {
if (model.arch == LLM_ARCH_EAGLE3 || model.arch == LLM_ARCH_DFLASH) {
if (model.tok_embd == nullptr || model.output == nullptr) {
if (params.ctx_other == nullptr) {
throw std::runtime_error("EAGLE3 requires ctx_other to be set (this warning is normal during memory fitting)");
throw std::runtime_error(model.arch_name() + " requires ctx_other to be set (this warning is normal during memory fitting)");
}
cparams.ctx_other = params.ctx_other;
}
@@ -256,7 +256,7 @@ llama_context::llama_context(
LLAMA_LOG_INFO("%s: n_outputs_max = %u\n", __func__, cparams.n_outputs_max);
if (cparams.n_ctx_seq < hparams.n_ctx_train) {
LLAMA_LOG_WARN("%s: n_ctx_seq (%u) < n_ctx_train (%u) -- the full capacity of the model will not be utilized\n",
LLAMA_LOG_INFO("%s: n_ctx_seq (%u) < n_ctx_train (%u) -- the full capacity of the model will not be utilized\n",
__func__, cparams.n_ctx_seq, hparams.n_ctx_train);
}
+6 -1
View File
@@ -486,7 +486,11 @@ void llm_graph_input_attn_kv::set_input(const llama_ubatch * ubatch) {
mctx->set_input_k_idxs(self_k_idxs, ubatch);
mctx->set_input_v_idxs(self_v_idxs, ubatch);
mctx->set_input_kq_mask(self_kq_mask, ubatch, cparams.causal_attn);
// the mask is left unallocated when the graph only stores K/V without attending
// (e.g. DFlash's KV-injection pass)
if (self_kq_mask && self_kq_mask->buffer) {
mctx->set_input_kq_mask(self_kq_mask, ubatch, cparams.causal_attn);
}
if (self_k_rot) {
mctx->set_input_k_rot(self_k_rot);
@@ -904,6 +908,7 @@ void llm_graph_result::reset() {
t_logits = nullptr;
t_embd = nullptr;
t_embd_pooled = nullptr;
t_h_nextn = nullptr;
t_layer_inp.resize(LLAMA_MAX_LAYERS);
std::fill(t_layer_inp.begin(), t_layer_inp.end(), nullptr);
+5 -1
View File
@@ -291,6 +291,8 @@ static llama_model * llama_model_mapping(llm_arch arch, const llama_model_params
return new llama_model_mistral3(params);
case LLM_ARCH_EAGLE3:
return new llama_model_eagle3(params);
case LLM_ARCH_DFLASH:
return new llama_model_dflash(params);
case LLM_ARCH_MIMO2:
return new llama_model_mimo2(params);
case LLM_ARCH_KIMI_LINEAR:
@@ -2494,6 +2496,7 @@ llama_rope_type llama_model_rope_type(const llama_model * model) {
case LLM_ARCH_STEP35:
case LLM_ARCH_TALKIE:
case LLM_ARCH_MELLUM:
case LLM_ARCH_DFLASH:
return LLAMA_ROPE_TYPE_NEOX;
case LLM_ARCH_QWEN2VL:
@@ -2617,7 +2620,8 @@ bool llama_model_has_encoder(const llama_model * model) {
switch (model->arch) {
case LLM_ARCH_T5:
case LLM_ARCH_T5ENCODER:
case LLM_ARCH_EAGLE3: return true;
case LLM_ARCH_EAGLE3:
case LLM_ARCH_DFLASH: return true;
default: return false;
}
}
+276
View File
@@ -0,0 +1,276 @@
#include "models.h"
#include "llama-kv-cache.h"
#include "llama-kv-cache-iswa.h"

#include <string>
// Load the DFlash-specific hyperparameters from GGUF metadata.
//
// Reads the RMS-norm epsilon and the mandatory list of target layers, derives
// the encoder input width as (number of target layers) * n_embd, and
// optionally enables sliding-window attention with a per-layer pattern.
//
// Throws std::runtime_error when the target layer list is missing or empty.
void llama_model_dflash::load_arch_hparams(llama_model_loader & ml) {
    ml.get_key(LLM_KV_ATTENTION_LAYERNORM_RMS_EPS, hparams.f_norm_rms_eps);

    // an empty list would produce a zero-width encoder input, so reject it like a missing key
    if (!ml.get_arr(LLM_KV_TARGET_LAYERS, target_layer_ids, false) || target_layer_ids.size() == 0) {
        throw std::runtime_error("DFlash model requires 'target_layers' in GGUF metadata");
    }
    hparams.n_embd_inp_enc_impl = (uint32_t) target_layer_ids.size() * hparams.n_embd;

    // format the list first so it is emitted as a single log record; logging each
    // fragment separately lets log sinks treat every piece as its own message
    std::string layers_str;
    for (size_t i = 0; i < target_layer_ids.size(); ++i) {
        if (i > 0) {
            layers_str += ", ";
        }
        layers_str += std::to_string(target_layer_ids[i]);
    }
    LLAMA_LOG_INFO("%s: DFlash extract_layers = [%s]\n", __func__, layers_str.c_str());

    // optional interleaved sliding-window attention with per-layer pattern array.
    // DFlash has a single rope, so the SWA rope == main rope.
    if (ml.get_key(LLM_KV_ATTENTION_SLIDING_WINDOW, hparams.n_swa, false) && hparams.n_swa > 0) {
        hparams.swa_type = LLAMA_SWA_TYPE_STANDARD;
        ml.get_key_or_arr(LLM_KV_ATTENTION_SLIDING_WINDOW_PATTERN, hparams.is_swa_impl, hparams.n_layer());
        hparams.rope_freq_base_train_swa  = hparams.rope_freq_base_train;
        hparams.rope_freq_scale_train_swa = hparams.rope_freq_scale_train;
    }

    type = LLM_TYPE_UNKNOWN;
}
// Create all DFlash weight tensors: the feature-fusion projection, the two
// model-level norms, and the per-layer attention / FFN weights.
void llama_model_dflash::load_arch_tensors(llama_model_loader &) {
    LLAMA_LOAD_LOCALS;

    // model-level tensors
    const int64_t n_fused = hparams.n_embd_inp_enc();

    fc              = create_tensor(tn(LLM_TENSOR_FC,              "weight"), { n_fused, n_embd }, 0);
    output_norm_enc = create_tensor(tn(LLM_TENSOR_ENC_OUTPUT_NORM, "weight"), { n_embd }, 0); // applied after fc in the encoder graph
    output_norm     = create_tensor(tn(LLM_TENSOR_OUTPUT_NORM,     "weight"), { n_embd }, 0); // applied before the output projection

    // per-layer tensors
    for (int il = 0; il < n_layer; ++il) {
        auto & l = layers[il];

        l.attn_norm   = create_tensor(tn(LLM_TENSOR_ATTN_NORM,   "weight", il), { n_embd }, 0);
        l.wq          = create_tensor(tn(LLM_TENSOR_ATTN_Q,      "weight", il), { n_embd, n_embd_head_k * n_head }, 0);
        l.wk          = create_tensor(tn(LLM_TENSOR_ATTN_K,      "weight", il), { n_embd, n_embd_k_gqa }, 0);
        l.wv          = create_tensor(tn(LLM_TENSOR_ATTN_V,      "weight", il), { n_embd, n_embd_v_gqa }, 0);
        l.wo          = create_tensor(tn(LLM_TENSOR_ATTN_OUT,    "weight", il), { n_embd_head_k * n_head, n_embd }, 0);
        l.attn_q_norm = create_tensor(tn(LLM_TENSOR_ATTN_Q_NORM, "weight", il), { n_embd_head_k }, 0);
        l.attn_k_norm = create_tensor(tn(LLM_TENSOR_ATTN_K_NORM, "weight", il), { n_embd_head_k }, 0);

        l.ffn_norm    = create_tensor(tn(LLM_TENSOR_FFN_NORM,    "weight", il), { n_embd }, 0);
        l.ffn_gate    = create_tensor(tn(LLM_TENSOR_FFN_GATE,    "weight", il), { n_embd, n_ff }, 0);
        l.ffn_down    = create_tensor(tn(LLM_TENSOR_FFN_DOWN,    "weight", il), { n_ff, n_embd }, 0);
        l.ffn_up      = create_tensor(tn(LLM_TENSOR_FFN_UP,      "weight", il), { n_embd, n_ff }, 0);
    }
}
// Pick the graph builder for the requested graph type: the encoder graph for
// LLM_GRAPH_TYPE_ENCODER, the decoder graph for DEFAULT / DECODER. Any other
// graph type aborts.
std::unique_ptr<llm_graph_context> llama_model_dflash::build_arch_graph(const llm_graph_params & params) const {
    const auto gtype = params.gtype;

    if (gtype == LLM_GRAPH_TYPE_ENCODER) {
        return std::make_unique<graph<true>>(*this, params);
    }

    if (gtype == LLM_GRAPH_TYPE_DEFAULT || gtype == LLM_GRAPH_TYPE_DECODER) {
        return std::make_unique<graph<false>>(*this, params);
    }

    GGML_ABORT("invalid graph type");
}
// Create the encoder's input: an F32 [n_embd_inp_enc, n_tokens] tensor marked
// as a graph input, registered with the result, and returned to the caller.
template <>
ggml_tensor * llama_model_dflash::graph<true>::build_inp_embd_enc() const {
    const auto n_feat = hparams.n_embd_inp_enc();

    auto input = std::make_unique<llm_graph_input_embd>(n_feat);

    ggml_tensor * feats = ggml_new_tensor_2d(ctx0, GGML_TYPE_F32, n_feat, n_tokens);
    input->embd = feats;
    ggml_set_input(feats);

    cb(feats, "inp_embd", -1);

    // ownership moves to the result; the raw tensor pointer stays valid for the caller
    res->add_input(std::move(input));

    return feats;
}
// DFlash encoder graph: input features -> fc projection -> RMS norm.
// The normalized result is marked as a graph output and exposed via res->t_h_nextn.
template <>
llama_model_dflash::graph<true>::graph(const llama_model & model, const llm_graph_params & params) : llm_graph_context(params) {
    ggml_tensor * feats = build_inp_embd_enc();

    ggml_tensor * fused = build_lora_mm(model.fc, feats);
    cb(fused, "fc_out", -1);

    ggml_tensor * normed = build_norm(fused, model.output_norm_enc, NULL, LLM_NORM_RMS, -1);
    cb(normed, "enc_norm_out", -1);

    ggml_set_output(normed);
    res->t_h_nextn = normed;

    ggml_build_forward_expand(gf, normed);
}
// DFlash decoder, dual-mode by batch type:
// * embd batch -> fused target features: project + inject K/V into the cache.
// * token batch -> noise-block diffusion: attend over [committed, MASK...] to generate draft tokens
//
// In the embd path only K/V copies are added to the graph and res->t_embd is the
// input tensor itself; no logits are produced. In the token path the full layer
// stack runs and res->t_embd / res->t_logits are set from the final norm and the
// output projection.
template <>
llama_model_dflash::graph<false>::graph(const llama_model & model, const llm_graph_params & params) : llm_graph_context(params) {
// K and V heads must have the same width; the code below uses one value for both
const int64_t n_embd_head = hparams.n_embd_head_v();
GGML_ASSERT(n_embd_head == hparams.n_embd_head_k());
ggml_tensor * inp_pos = build_inp_pos();
// optional iSWA: pick the matching attention input
// exactly one of inp_attn / inp_attn_iswa is non-null afterwards
const bool use_iswa = hparams.swa_type != LLAMA_SWA_TYPE_NONE;
llm_graph_input_attn_kv * inp_attn = nullptr;
llm_graph_input_attn_kv_iswa * inp_attn_iswa = nullptr;
if (use_iswa) {
inp_attn_iswa = build_attn_inp_kv_iswa();
} else {
inp_attn = build_attn_inp_kv();
}
const float kq_scale = 1.0f/sqrtf(float(n_embd_head));
// KV cache injection
// taken when the ubatch carries embeddings instead of tokens
if (ubatch.embd) {
auto inp = std::make_unique<llm_graph_input_embd>(n_embd);
inp->embd = ggml_new_tensor_2d(ctx0, GGML_TYPE_F32, n_embd, n_tokens);
ggml_set_input(inp->embd);
ggml_tensor * inp_g = inp->embd;
cb(inp_g, "inp_g_embeddings", -1);
res->add_input(std::move(inp));
for (int il = 0; il < n_layer; ++il) {
const auto & layer = model.layers[il];
// project the same input through every layer's K/V weights (no residual stream here)
ggml_tensor * Kcur = build_lora_mm(layer.wk, inp_g);
ggml_tensor * Vcur = build_lora_mm(layer.wv, inp_g);
Kcur = ggml_reshape_3d(ctx0, Kcur, n_embd_head, n_head_kv, n_tokens);
Vcur = ggml_reshape_3d(ctx0, Vcur, n_embd_head, n_head_kv, n_tokens);
// K gets the same norm + rope treatment as in the token path below; V is stored as projected
Kcur = build_norm(Kcur, layer.attn_k_norm, NULL, LLM_NORM_RMS, il);
Kcur = ggml_rope_ext(
ctx0, Kcur, inp_pos, nullptr,
n_rot, rope_type, n_ctx_orig, freq_base, freq_scale,
ext_factor, attn_factor, beta_fast, beta_slow
);
cb(Kcur, "Kcur_injected", il);
cb(Vcur, "Vcur_injected", il);
if (use_iswa) {
// route each layer's K/V to its sub-cache: SWA layers -> sliding cache, full -> dense
const bool is_swa = hparams.is_swa(il);
const auto * kv = is_swa ? inp_attn_iswa->mctx->get_swa() : inp_attn_iswa->mctx->get_base();
ggml_tensor * k_idxs = is_swa ? inp_attn_iswa->get_k_idxs_swa() : inp_attn_iswa->get_k_idxs();
ggml_tensor * v_idxs = is_swa ? inp_attn_iswa->get_v_idxs_swa() : inp_attn_iswa->get_v_idxs();
ggml_build_forward_expand(gf, kv->cpy_k(ctx0, Kcur, k_idxs, il));
ggml_build_forward_expand(gf, kv->cpy_v(ctx0, Vcur, v_idxs, il));
} else {
ggml_build_forward_expand(gf, inp_attn->mctx->cpy_k(ctx0, Kcur, inp_attn->get_k_idxs(), il));
ggml_build_forward_expand(gf, inp_attn->mctx->cpy_v(ctx0, Vcur, inp_attn->get_v_idxs(), il));
}
}
// no attention or logits in this mode: the result embedding is just the input
res->t_embd = inp_g;
ggml_build_forward_expand(gf, inp_g);
return;
}
// tok_embd from the target model (shared via ctx_other)
// only used as a fallback when this model has no token embedding tensor of its own
auto * tok_embd = model.tok_embd;
if (tok_embd == nullptr) {
GGML_ASSERT(cparams.ctx_other != nullptr);
const auto * model_other = llama_get_model(cparams.ctx_other);
GGML_ASSERT(model_other->tok_embd != nullptr && "DFlash decoder requires the target model's token embeddings");
tok_embd = model_other->tok_embd;
}
auto inp = std::make_unique<llm_graph_input_embd>(n_embd);
inp->tokens = ggml_new_tensor_1d(ctx0, GGML_TYPE_I32, n_tokens);
ggml_set_input(inp->tokens);
ggml_tensor * inpL = ggml_get_rows(ctx0, tok_embd, inp->tokens);
cb(inpL, "inp_noise_embd", -1);
res->add_input(std::move(inp));
// layer stack: pre-norm attention + residual, then pre-norm gated FFN + residual
for (int il = 0; il < n_layer; ++il) {
const auto & layer = model.layers[il];
ggml_tensor * noise_norm = build_norm(inpL, layer.attn_norm, NULL, LLM_NORM_RMS, il);
cb(noise_norm, "noise_norm", il);
ggml_tensor * Qcur = build_lora_mm(layer.wq, noise_norm);
ggml_tensor * Kcur = build_lora_mm(layer.wk, noise_norm);
ggml_tensor * Vcur = build_lora_mm(layer.wv, noise_norm);
Qcur = ggml_reshape_3d(ctx0, Qcur, n_embd_head, n_head, n_tokens);
Kcur = ggml_reshape_3d(ctx0, Kcur, n_embd_head, n_head_kv, n_tokens);
Vcur = ggml_reshape_3d(ctx0, Vcur, n_embd_head, n_head_kv, n_tokens);
// per-head RMS norm on Q and K before rope
Qcur = build_norm(Qcur, layer.attn_q_norm, NULL, LLM_NORM_RMS, il);
Kcur = build_norm(Kcur, layer.attn_k_norm, NULL, LLM_NORM_RMS, il);
Qcur = ggml_rope_ext(
ctx0, Qcur, inp_pos, nullptr,
n_rot, rope_type, n_ctx_orig, freq_base, freq_scale,
ext_factor, attn_factor, beta_fast, beta_slow
);
Kcur = ggml_rope_ext(
ctx0, Kcur, inp_pos, nullptr,
n_rot, rope_type, n_ctx_orig, freq_base, freq_scale,
ext_factor, attn_factor, beta_fast, beta_slow
);
cb(Qcur, "Qcur", il);
cb(Kcur, "Kcur", il);
cb(Vcur, "Vcur", il);
// cache-aware, non-causal attention
ggml_tensor * cur = use_iswa
? build_attn(inp_attn_iswa, layer.wo, NULL, NULL, Qcur, Kcur, Vcur, nullptr, nullptr, nullptr, kq_scale, il)
: build_attn(inp_attn, layer.wo, NULL, NULL, Qcur, Kcur, Vcur, nullptr, nullptr, nullptr, kq_scale, il);
// first residual connection: attention output + layer input
ggml_tensor * ffn_inp = ggml_add(ctx0, cur, inpL);
cb(ffn_inp, "ffn_inp", il);
cur = build_norm(ffn_inp, layer.ffn_norm, NULL, LLM_NORM_RMS, il);
cb(cur, "ffn_norm", il);
cur = build_ffn(cur,
layer.ffn_up, NULL, NULL,
layer.ffn_gate, NULL, NULL,
layer.ffn_down, NULL, NULL,
NULL,
LLM_FFN_SILU, LLM_FFN_PAR, il);
cb(cur, "ffn_out", il);
// second residual connection: FFN output + FFN input
cur = ggml_add(ctx0, cur, ffn_inp);
cb(cur, "l_out", il);
inpL = cur;
}
ggml_tensor * cur = build_norm(inpL, model.output_norm, NULL, LLM_NORM_RMS, -1);
cb(cur, "result_norm", -1);
res->t_embd = cur;
// lm_head from the target model (shared via ctx_other)
// only used as a fallback when this model has no output tensor of its own
auto * output = model.output;
if (output == nullptr) {
GGML_ASSERT(cparams.ctx_other != nullptr);
const auto * model_other = llama_get_model(cparams.ctx_other);
GGML_ASSERT(model_other->output != nullptr && "DFlash decoder requires the target model's output projection");
output = model_other->output;
}
cur = build_lora_mm(output, cur);
cb(cur, "result_output", -1);
res->t_logits = cur;
ggml_build_forward_expand(gf, cur);
}
+16
View File
@@ -1122,6 +1122,22 @@ struct llama_model_eagle3 : public llama_model_base {
};
// DFlash model: declares the hparam/tensor loaders and a graph builder that is
// specialized on whether the encoder (is_enc == true) or the decoder graph is built.
struct llama_model_dflash : public llama_model_base {
llama_model_dflash(const struct llama_model_params & params) : llama_model_base(params) {}
// reads the architecture-specific GGUF metadata into hparams
void load_arch_hparams(llama_model_loader & ml) override;
// creates the architecture-specific weight tensors
void load_arch_tensors(llama_model_loader & ml) override;
// graph<true> is the encoder graph, graph<false> the decoder graph
template <bool is_enc>
struct graph : public llm_graph_context {
graph(const llama_model & model, const llm_graph_params & params);
// builds the encoder's input embedding tensor
ggml_tensor * build_inp_embd_enc() const;
};
// instantiates graph<true> or graph<false> depending on params.gtype
std::unique_ptr<llm_graph_context> build_arch_graph(const llm_graph_params & params) const override;
};
struct llama_model_mistral4 : public llama_model_deepseek2 {
llama_model_mistral4(const struct llama_model_params & params) : llama_model_deepseek2(params) {}
// reuse load_arch_hparams and load_arch_tensors from llama_model_deepseek2
+28 -7
View File
@@ -2890,12 +2890,17 @@ struct test_cpy : public test_case {
const std::array<int64_t, 4> ne_dst;
const std::array<int64_t, 4> permute_src;
const std::array<int64_t, 4> permute_dst;
const std::array<int64_t, 4> dst_alloc; // if set, dst is a view into a larger buffer (strided)
bool _src_use_permute;
bool _dst_use_permute;
bool _src_transpose;
bool _use_dst_shape;
bool _use_dst_alloc;
std::string vars() override {
if (_use_dst_alloc) {
return VARS_TO_STR8(type_src, type_dst, ne_src, ne_dst, permute_src, permute_dst, _src_transpose, dst_alloc);
}
if (_use_dst_shape) {
return VARS_TO_STR7(type_src, type_dst, ne_src, ne_dst, permute_src, permute_dst, _src_transpose);
}
@@ -2943,12 +2948,15 @@ struct test_cpy : public test_case {
std::array<int64_t, 4> ne_dst = {-1, -1, -1, -1},
std::array<int64_t, 4> permute_src = {0, 0, 0, 0},
std::array<int64_t, 4> permute_dst = {0, 0, 0, 0},
bool transpose_src = false)
bool transpose_src = false,
std::array<int64_t, 4> dst_alloc = {0, 0, 0, 0})
: type_src(type_src), type_dst(type_dst), ne_src(ne_src), ne_dst(ne_dst), permute_src(permute_src), permute_dst(permute_dst),
dst_alloc(dst_alloc),
_src_use_permute(permute_src[0] + permute_src[1] + permute_src[2] + permute_src[3] > 0),
_dst_use_permute(permute_dst[0] + permute_dst[1] + permute_dst[2] + permute_dst[3] > 0),
_src_transpose(transpose_src),
_use_dst_shape(ne_dst[0] >= 0 && ne_dst[1] >= 0 && ne_dst[2] >= 0 && ne_dst[3] >= 0){}
_use_dst_shape(ne_dst[0] >= 0 && ne_dst[1] >= 0 && ne_dst[2] >= 0 && ne_dst[3] >= 0),
_use_dst_alloc(dst_alloc[0] > 0){}
ggml_tensor * build_graph(ggml_context * ctx) override {
ggml_tensor * src = ggml_new_tensor(ctx, type_src, 4, ne_src.data());
@@ -2966,12 +2974,23 @@ struct test_cpy : public test_case {
}
std::array<int64_t, 4> dst_ne = _use_dst_shape ? ne_dst : std::array<int64_t, 4>{src->ne[0], src->ne[1], src->ne[2], src->ne[3]};
ggml_tensor * dst = ggml_new_tensor(ctx, type_dst, 4, dst_ne.data());
ggml_set_name(dst, "dst");
ggml_tensor * dst;
if (_dst_use_permute) {
dst = ggml_permute(ctx, dst, permute_dst[0], permute_dst[1], permute_dst[2], permute_dst[3]);
ggml_set_name(dst, "dst_permuted");
if (_use_dst_alloc) {
// view a sub-block of a larger buffer -> strided dst
ggml_tensor * dst_buf = ggml_new_tensor(ctx, type_dst, 4, dst_alloc.data());
ggml_set_name(dst_buf, "dst_buf");
dst = ggml_view_4d(ctx, dst_buf, dst_ne[0], dst_ne[1], dst_ne[2], dst_ne[3],
dst_buf->nb[1], dst_buf->nb[2], dst_buf->nb[3], 0);
ggml_set_name(dst, "dst_view");
} else {
dst = ggml_new_tensor(ctx, type_dst, 4, dst_ne.data());
ggml_set_name(dst, "dst");
if (_dst_use_permute) {
dst = ggml_permute(ctx, dst, permute_dst[0], permute_dst[1], permute_dst[2], permute_dst[3]);
ggml_set_name(dst, "dst_permuted");
}
}
ggml_tensor * out = ggml_cpy(ctx, src, dst);
@@ -8181,6 +8200,8 @@ static std::vector<std::unique_ptr<test_case>> make_test_cases_eval() {
test_cases.emplace_back(new test_cpy(GGML_TYPE_F32, GGML_TYPE_F32, {256, 1, 4, 1}, {-1,-1,-1,-1}, {1, 2, 0, 3}, {0, 0, 0, 0}));
test_cases.emplace_back(new test_cpy(GGML_TYPE_F32, GGML_TYPE_F32, {2, 2097121, 1, 1}, {-1,-1,-1,-1}, {1, 0, 2, 3}));
test_cases.emplace_back(new test_cpy(GGML_TYPE_F32, GGML_TYPE_F32, {2, 2, 524281, 1}, {-1,-1,-1,-1}, {1, 0, 2, 3}));
test_cases.emplace_back(new test_cpy(GGML_TYPE_F32, GGML_TYPE_F32, {128, 2, 3, 1}, {128, 2, 3, 1}, {0, 0, 0, 0}, {0, 0, 0, 0}, false, {128, 4, 3, 1})); // strided dst
test_cases.emplace_back(new test_cpy(GGML_TYPE_F16, GGML_TYPE_F16, {128, 2, 3, 1}, {128, 2, 3, 1}, {0, 0, 0, 0}, {0, 0, 0, 0}, false, {128, 4, 3, 1})); // strided dst
// CPY - different src/dst shapes (reshaping via CPY)
// Use permutations of {3, 5, 7, 32}. Total elements: 3*5*7*32 = 3360.
+17 -3
View File
@@ -25,7 +25,7 @@ using json = nlohmann::ordered_json;
static int main_automated_tests(void);
static void run_multiple(const std::string& dir_path, bool stop_on_first_failure, const json& input, bool use_common = false);
static void run_single(const std::string& contents, json input, bool use_common = false, const std::string & output_path = "");
static void run_single(const std::string& contents, json input, bool use_common = false, bool dump_prog = false, const std::string & output_path = "");
static std::string HELP = R"(
Usage: test-chat-template [OPTIONS] PATH_TO_TEMPLATE
@@ -35,6 +35,7 @@ Options:
--json <path> Path to the JSON input file.
--stop-on-first-fail Stop testing on the first failure (default: false).
--no-common Use direct Jinja engine instead of common chat templates (default: use common).
--dump-prog Dump the parsed program for debugging (only for single template runs).
--output <path> Path to output results (only for single template runs).
If PATH_TO_TEMPLATE is a file, runs that single template.
If PATH_TO_TEMPLATE is a directory, runs all .jinja files in that directory.
@@ -118,6 +119,7 @@ int main(int argc, char ** argv) {
std::string & json_to_use = DEFAULT_JSON;
bool stop_on_first_fail = false;
bool use_common = true;
bool dump_prog = false;
for (size_t i = 1; i < args.size(); i++) {
if (args[i] == "--help" || args[i] == "-h") {
@@ -136,6 +138,8 @@ int main(int argc, char ** argv) {
i++;
} else if (args[i] == "--no-common") {
use_common = false;
} else if (args[i] == "--dump-prog") {
dump_prog = true;
} else if (tmpl_path.empty()) {
tmpl_path = args[i];
} else {
@@ -172,7 +176,7 @@ int main(int argc, char ** argv) {
std::string contents = std::string(
std::istreambuf_iterator<char>(infile),
std::istreambuf_iterator<char>());
run_single(contents, input_json, use_common, output_path);
run_single(contents, input_json, use_common, dump_prog, output_path);
} else {
std::cerr << "Error: PATH_TO_TEMPLATE is not a valid file or directory: " << tmpl_path << "\n";
return 1;
@@ -276,11 +280,21 @@ static jinja::value_string format_using_direct_engine(
}
void run_single(const std::string& contents, json input, bool use_common, const std::string & output_path) {
void run_single(const std::string& contents, json input, bool use_common, bool dump_prog, const std::string & output_path) {
jinja::enable_debug(true);
jinja::value_string output_parts;
// --dump-prog: tokenize and parse the template, print the resulting program to stdout
// and return early - the template is not rendered and output_path is not used on this path
if (dump_prog) {
jinja::lexer lexer;
auto lexer_res = lexer.tokenize(contents);
jinja::program ast = jinja::parse_from_tokens(lexer_res);
// the original template text is passed along with the parsed program
std::string prog_dump = jinja::runtime::debug_dump_program(ast, contents);
std::cout << "\n=== DUMPED PROGRAM ===\n";
std::cout << prog_dump << "\n";
return;
}
if (use_common) {
std::string bos_token = "<s>";
std::string eos_token = "</s>";
+78
View File
@@ -5593,6 +5593,77 @@ static void test_template_output_peg_parsers(bool detailed_debug) {
.expect_content("Hello, world!\nWhat's up?")
.run();
}
// MiniCPM5 - XML tool calls with <function name="..."><param name="...">...</param></function>
{
auto tst = peg_tester("models/templates/openbmb-MiniCPM5-1B.jinja", detailed_debug);
// plain text without tools, thinking disabled: expected to parse as message_assist
tst.test("Hello, world!\nWhat's up?")
.enable_thinking(false)
.reasoning_format(COMMON_REASONING_FORMAT_AUTO)
.expect(message_assist)
.run();
// single tool call with one string parameter -> JSON arguments object
tst.test(R"(<function name="python"><param name="code">print('Hello, World!')</param></function>)")
.enable_thinking(false)
.reasoning_format(COMMON_REASONING_FORMAT_AUTO)
.tools({ python_tool })
.expect_tool_calls({ { "python", R"#({"code": "print('Hello, World!')"})#", {} } })
.run();
// function element without any <param> children -> empty arguments object "{}"
tst.test(R"(<function name="empty_args"></function>)")
.enable_thinking(false)
.reasoning_format(COMMON_REASONING_FORMAT_AUTO)
.tools({ empty_args_tool })
.expect(simple_assist_msg("", "", "empty_args", "{}"))
.run();
// a single call must still parse when parallel tool calls are enabled
tst.test(R"(<function name="python"><param name="code">print('x')</param></function>)")
.enable_thinking(false)
.reasoning_format(COMMON_REASONING_FORMAT_AUTO)
.parallel_tool_calls(true)
.tools({ python_tool })
.expect_tool_calls({ { "python", R"#({"code": "print('x')"})#", {} } })
.run();
// CDATA lets a string value carry characters that would otherwise close the tag.
// (here the payload itself contains a literal "</param>", which must end up in the argument value)
tst.test(R"(<function name="html"><param name="markup"><![CDATA[<a href="/x">hi</a> </param>]]></param></function>)")
.enable_thinking(false)
.reasoning_format(COMMON_REASONING_FORMAT_AUTO)
.tools({ html_tool })
.expect_tool_calls({ { "html", R"#({"markup": "<a href=\"/x\">hi</a> </param>"})#", {} } })
.run();
// thinking enabled: text before </think> is the reasoning, the tool call follows it
tst.test(R"(I'm thinking</think><function name="python"><param name="code">print('hey')</param></function>)")
.enable_thinking(true)
.reasoning_format(COMMON_REASONING_FORMAT_AUTO)
.tools({ python_tool })
.expect_reasoning("I'm thinking")
.expect_tool_calls({ { "python", R"#({"code": "print('hey')"})#", {} } })
.run();
// two calls separated by a newline, parallel enabled: both are expected, in input order
tst.test(R"(<function name="python"><param name="code">print('x')</param></function>
<function name="python"><param name="code">print('y')</param></function>)")
.enable_thinking(false)
.reasoning_format(COMMON_REASONING_FORMAT_AUTO)
.parallel_tool_calls(true)
.tools({ python_tool })
.expect_tool_calls({
{ "python", R"#({"code": "print('x')"})#", {} },
{ "python", R"#({"code": "print('y')"})#", {} },
})
.run();
// continuation of a prefilled reasoning message: the input starts mid-reasoning (" thinking"),
// while the expected reasoning is the full "I'm thinking"
// NOTE(review): the "I'm" prefix presumably comes from message_assist_prefill_reasoning - confirm
tst.test(" thinking</think>Hello, world!\nWhat's up?")
.enable_thinking(true)
.reasoning_format(COMMON_REASONING_FORMAT_AUTO)
.messages({ message_user, message_assist_prefill_reasoning })
.add_generation_prompt(false)
.continue_final_message(COMMON_CHAT_CONTINUATION_REASONING)
.expect_reasoning("I'm thinking")
.expect_content("Hello, world!\nWhat's up?")
.run();
}
}
static void test_template_generation_prompt() {
@@ -5740,6 +5811,13 @@ static void test_template_generation_prompt() {
check(tmpls, continuation_content(), "<Assistant><think>I'm thinking</think>Hello, ");
check(tmpls, continuation_reasoning(), "<Assistant><think>I'm");
}
// MiniCPM5: expected strings for the basic and the two continuation cases
{
auto tmpls = read_templates("models/templates/openbmb-MiniCPM5-1B.jinja");
// basic case: the expected prompt ends with an opened "<think>\n" block
check(tmpls, basic(), "<|im_start|>assistant\n<think>\n");
// content continuation: the reasoning is closed with "\n</think>\n\n" before the partial content
check(tmpls, continuation_content(), "<|im_start|>assistant\n<think>\nI'm thinking\n</think>\n\nHello, ");
// reasoning continuation: the think block is left open after the partial reasoning
check(tmpls, continuation_reasoning(), "<|im_start|>assistant\n<think>\nI'm");
}
}
// Test the developer role to system workaround with a simple mock template
+30
View File
@@ -1584,6 +1584,36 @@ static void test_array_methods(testing & t) {
"6"
);
// min/max filters applied to an array literal built from two integer context variables
test_template(t, "array|min",
"{{ [tool_calls_count, tool_sep_count]|min }}",
{{"tool_calls_count", 2}, {"tool_sep_count", 1}},
"1"
);
test_template(t, "array|max",
"{{ [tool_calls_count, tool_sep_count]|max }}",
{{"tool_calls_count", 2}, {"tool_sep_count", 1}},
"2"
);
// with attribute='x' the elements are compared by their 'x' value;
// the expected output is the whole selected element, not just the attribute value
test_template(t, "array|min attribute",
"{{ items|min(attribute='x') }}",
{{"items", json::array({
json({{"x", 2}}),
json({{"x", 1}}),
})}},
"{'x': 1}"
);
test_template(t, "array|max attribute",
"{{ items|max(attribute='x') }}",
{{"items", json::array({
json({{"x", 2}}),
json({{"x", 1}}),
})}},
"{'x': 2}"
);
// not used by any chat templates
// test_template(t, "array.insert()",
// "{% set _ = arr.insert(1, 'x') %}{{ arr|join(',') }}",
+2 -2
View File
@@ -451,7 +451,7 @@ static int save_models(const llm_arch target_arch, const size_t seed, const ggml
if (arch == LLM_ARCH_GEMMA4 || arch == LLM_ARCH_GEMMA4_ASSISTANT) {
continue; // FIXME: ISWA KV cache initialization needs more fixture params
}
if (arch == LLM_ARCH_EAGLE3) {
if (arch == LLM_ARCH_EAGLE3 || arch == LLM_ARCH_DFLASH) {
continue;
}
for (bool moe : {false, true}) {
@@ -557,7 +557,7 @@ static int test_backends(const llm_arch target_arch, const size_t seed, const gg
if (arch == LLM_ARCH_GEMMA4 || arch == LLM_ARCH_GEMMA4_ASSISTANT) {
continue; // FIXME: ISWA KV cache initialization needs more fixture params
}
if (arch == LLM_ARCH_EAGLE3) {
if (arch == LLM_ARCH_EAGLE3 || arch == LLM_ARCH_DFLASH) {
continue;
}
+46 -42
View File
@@ -106,7 +106,6 @@ struct server_batch {
if ((int32_t)tokens.size() >= n_tokens_alloc) {
return false;
}
// LOG_INF("adding token to batch: slot=%d, token=%d, pos=%d, output=%d\n", id_slot, token, pos, output);
tokens.push_back({ id_slot, token, pos, output });
return true;
}
@@ -228,7 +227,7 @@ struct server_slot {
const size_t cur_size = cur_size_tgt + cur_size_dft;
SRV_WRN(" - saving prompt with length %d, total state size = %.3f MiB (draft: %.3f MiB)\n",
SRV_TRC(" - saving prompt with length %d, total state size = %.3f MiB (draft: %.3f MiB)\n",
(int) prompt.tokens.size(), cur_size / (1024.0 * 1024.0), cur_size_dft / (1024.0 * 1024.0));
auto * cur = prompt_cache.alloc(prompt, cur_size_tgt, cur_size_dft);
@@ -258,7 +257,7 @@ struct server_slot {
GGML_ASSERT(!is_processing());
}
SLT_INF(*this, "clearing prompt with %zu tokens\n", prompt.tokens.size());
SLT_TRC(*this, "clearing prompt with %zu tokens\n", prompt.tokens.size());
common_context_seq_rm(ctx_tgt, id, -1, -1);
if (ctx_dft) {
@@ -627,8 +626,10 @@ struct server_slot {
}
SLT_INF(*this,
"draft acceptance = %0.5f (%5d accepted / %5d generated), mean acceptance length = %5.2f, acceptance rate per position = (%s)\n",
draft_ratio, n_draft_accepted, n_draft_total, mean_acc_len, acceptance_rates_per_pos.c_str());
"draft acceptance = %0.5f (%5d accepted / %5d generated), mean len = %5.2f\n",
draft_ratio, n_draft_accepted, n_draft_total, mean_acc_len);
SLT_TRC(*this,
" acc per pos = (%s)\n", acceptance_rates_per_pos.c_str());
}
common_speculative_print_stats(spec);
@@ -771,7 +772,7 @@ struct server_slot {
}
// TODO @ngxson : move this log line to debug when it become more stable
SLT_INF(*this, "encoding mtmd batch from idx = %zu, n_chunks = %d\n", idx, n_added);
SLT_TRC(*this, "encoding mtmd batch from idx = %zu, n_chunks = %d\n", idx, n_added);
res = mtmd_batch_encode(mbatch.get());
if (res != 0) {
@@ -1032,7 +1033,8 @@ private:
}
SRV_INF("loading model '%s'\n", params.model.path.c_str());
SRV_INF("loading model '%s'\n", params.model.get_name().c_str());
SRV_TRC("local path '%s'\n", params.model.path.c_str());
std::string & mmproj_path = params_base.mmproj.path;
mtmd_context_params mparams = mtmd_context_params_default();
@@ -1061,7 +1063,7 @@ private:
for (auto & [dev, size] : mmproj_mem) {
total += size;
}
SRV_INF("[mtmd] estimated worst-case memory usage of mmproj is %.2f MiB (took %.2f ms)\n", total / (1024.0 * 1024.0), t_elapsed / 1000.0);
SRV_TRC("[mtmd] estimated worst-case memory usage of mmproj is %.2f MiB (took %.2f ms)\n", total / (1024.0 * 1024.0), t_elapsed / 1000.0);
GGML_ASSERT(!params_base.fit_params_target.empty());
for (auto & [dev, size] : mmproj_mem) {
for (size_t i = 0; i < ggml_backend_dev_count(); i++) {
@@ -1141,7 +1143,7 @@ private:
}
}
}
SRV_INF("[spec] estimated memory usage of %s is %.2f MiB\n",
SRV_TRC("[spec] estimated memory usage of %s is %.2f MiB\n",
has_draft ? "draft model" : "MTP context",
total / (1024.0 * 1024.0));
} catch (const std::exception & e) {
@@ -1177,7 +1179,7 @@ private:
// TODO speculative: move to common/speculative.cpp?
const auto & params_spec = params_base.speculative.draft;
SRV_INF("loading draft model '%s'\n", params_spec.mparams.path.c_str());
SRV_TRC("loading draft model '%s'\n", params_spec.mparams.path.c_str());
auto params_dft = params_base;
@@ -1229,7 +1231,7 @@ private:
// no new model load, so we simply report 0.0 and 1.0 progress
load_progress_callback(0.0f, &load_progress_spec);
SRV_INF("creating MTP draft context against the target model '%s'\n",
SRV_TRC("creating MTP draft context against the target model '%s'\n",
params_base.model.path.c_str());
auto cparams_mtp = common_context_params_to_llama(params_base);
@@ -1303,9 +1305,6 @@ private:
// Necessary similarity of prompt for slot selection
slot_prompt_similarity = params_base.slot_prompt_similarity;
// setup slots
SRV_INF("initializing slots, n_slots = %d\n", params_base.n_parallel);
const int n_ctx_train = llama_model_n_ctx_train(model_tgt);
int n_ctx_slot = llama_n_ctx_seq(ctx_tgt);
@@ -1322,9 +1321,13 @@ private:
}
if (ctx_tgt_seq_rm_type == COMMON_CONTEXT_SEQ_RM_TYPE_FULL) {
SRV_WRN("%s", "speculative decoding will use checkpoints\n");
SRV_TRC("%s", "speculative decoding will use checkpoints\n");
}
// setup slots
SRV_INF("initializing, n_slots = %d, n_ctx_slot = %d, kv_unified = '%s'\n",
params_base.n_parallel, n_ctx_slot, params_base.kv_unified ? "true" : "false");
// initialize slots
for (int i = 0; i < params_base.n_parallel; i++) {
slots.emplace_back();
@@ -1344,7 +1347,7 @@ private:
}
if (spec) {
SRV_INF("%s", "speculative decoding context initialized\n");
SRV_TRC("%s", "speculative decoding context initialized\n");
} else {
ctx_dft.reset();
}
@@ -1361,7 +1364,7 @@ private:
slot.mctx = mctx;
slot.prompt.tokens.has_mtmd = mctx != nullptr;
SLT_INF(slot, "new slot, n_ctx = %d\n", slot.n_ctx);
SLT_TRC(slot, "new slot, n_ctx = %d\n", slot.n_ctx);
slot.callback_on_release = [this](int id_slot) {
queue_tasks.pop_deferred_task(id_slot);
@@ -1397,23 +1400,23 @@ private:
if (params_base.cache_ram_mib != 0) {
if (params_base.cache_ram_mib < 0) {
SRV_INF("prompt cache is enabled, size limit: %s\n", "no limit");
SRV_TRC("prompt cache is enabled, size limit: %s\n", "no limit");
} else {
SRV_INF("prompt cache is enabled, size limit: %d MiB\n", params_base.cache_ram_mib);
SRV_TRC("prompt cache is enabled, size limit: %d MiB\n", params_base.cache_ram_mib);
}
SRV_INF("%s", "use `--cache-ram 0` to disable the prompt cache\n");
SRV_TRC("%s", "use `--cache-ram 0` to disable the prompt cache\n");
prompt_cache = std::make_unique<server_prompt_cache>(params_base.cache_ram_mib, n_ctx);
} else {
SRV_INF("%s", "prompt cache is disabled - use `--cache-ram N` to enable it\n");
SRV_TRC("%s", "prompt cache is disabled - use `--cache-ram N` to enable it\n");
}
SRV_INF("%s", "for more info see https://github.com/ggml-org/llama.cpp/pull/16391\n");
SRV_TRC("%s", "for more info see https://github.com/ggml-org/llama.cpp/pull/16391\n");
if (params_base.n_ctx_checkpoints > 0) {
SRV_INF("context checkpoints enabled, max = %d, min spacing = %d\n",
SRV_TRC("context checkpoints enabled, max = %d, min spacing = %d\n",
params_base.n_ctx_checkpoints, params_base.checkpoint_min_step);
} else {
SRV_INF("%s", "context checkpoints disabled\n");
SRV_TRC("%s", "context checkpoints disabled\n");
}
if (!params_base.model_alias.empty()) {
@@ -1470,11 +1473,11 @@ private:
params_base.cache_idle_slots = false;
} else {
if (params_base.kv_unified) {
SRV_INF("%s", "idle slots will be saved to prompt cache and cleared upon starting a new task\n");
SRV_TRC("%s", "idle slots will be saved to prompt cache and cleared upon starting a new task\n");
} else {
// without a unified KV cache, clearing a slot frees no reusable room, so we only
// publish a RAM-cache copy of idle slots (their KV stays in VRAM) [TAG_IDLE_SLOT_CLEAR]
SRV_INF("%s", "idle slots will be saved to prompt cache upon starting a new task\n");
SRV_TRC("%s", "idle slots will be saved to prompt cache upon starting a new task\n");
}
SRV_DBG("%s", "__TEST_TAG_CACHE_IDLE_SLOTS_ENABLED__\n");
}
@@ -1500,7 +1503,7 @@ private:
try {
chat_templates = common_chat_templates_init(model_tgt, params_base.chat_template);
LOG_INF("%s: chat template, example_format: '%s'\n", __func__,
SRV_TRC("%s: chat template, example_format: '%s'\n", __func__,
common_chat_format_example(chat_templates.get(), params_base.use_jinja, params_base.default_template_kwargs).c_str());
} catch (const std::exception & e) {
@@ -1515,7 +1518,7 @@ private:
// 2. The chat template supports it
const bool template_supports_thinking = params_base.use_jinja && common_chat_templates_support_enable_thinking(chat_templates.get());
const bool enable_thinking = params_base.enable_reasoning != 0 && template_supports_thinking;
SRV_INF("%s: chat template, thinking = %d\n", __func__, enable_thinking);
SRV_TRC("%s: chat template, thinking = %d\n", __func__, enable_thinking);
// IMPORTANT: chat_params is reused across sleeping / resuming states,
// never store llama_context/llama_model pointers in chat_params,
@@ -1658,7 +1661,7 @@ private:
update_cache = update_cache && task.type == SERVER_TASK_TYPE_COMPLETION;
if (update_cache) {
SRV_INF("%s", "updating prompt cache\n");
SRV_TRC("%s", "updating prompt cache\n");
const int64_t t_start = ggml_time_us();
@@ -1670,7 +1673,7 @@ private:
prompt_cache->update();
SRV_INF("prompt cache update took %.2f ms\n", (ggml_time_us() - t_start) / 1000.0);
SRV_TRC("prompt cache update took %.2f ms\n", (ggml_time_us() - t_start) / 1000.0);
}
}
@@ -2290,7 +2293,7 @@ private:
int id_parent = parent_task.id;
SRV_INF("launching slots for parent task id_task = %d with %zu child tasks\n", id_parent, parent_task.child_tasks.size());
SRV_TRC("launching slots for parent task id_task = %d with %zu child tasks\n", id_parent, parent_task.child_tasks.size());
// to be called in case of failure to release all launched slots
auto release_slots = [this, id_parent]() {
@@ -2351,7 +2354,7 @@ private:
// stash the draft's speculative state with the checkpoint
common_speculative_get_state(spec.get(), slot.id, cur.data_spec);
SLT_INF(slot,
SLT_TRC(slot,
"created context checkpoint %d of %d (pos_min = %d, pos_max = %d, n_tokens = %" PRId64 ", size = %.3f MiB)\n",
(int) slot.prompt.checkpoints.size(), params_base.n_ctx_checkpoints, cur.pos_min,
cur.pos_max, cur.n_tokens, (float) cur.size() / 1024 / 1024);
@@ -2415,7 +2418,7 @@ private:
if (params_base.cache_idle_slots) {
for (auto & slot : slots) {
if (!slot.is_processing()) {
SLT_INF(slot, "%s", "saving idle slot to prompt cache\n");
SLT_TRC(slot, "%s", "saving idle slot to prompt cache\n");
if (slot.prompt_save(*prompt_cache)) {
SLT_DBG(slot, "%s", "__TEST_TAG_CACHE_IDLE_SLOT__\n");
@@ -2447,6 +2450,8 @@ private:
server_slot * slot = get_slot_by_cmpl_id(task.params.control_cmpl_id);
if (slot == nullptr) {
SRV_WRN("control %s on unknown completion id=%s, no live slot\n",
task.params.control_action.c_str(), task.params.control_cmpl_id.c_str());
res->success = false;
res->message = "no active completion for this id";
queue_results.send(std::move(res));
@@ -2671,7 +2676,7 @@ private:
auto new_loras = construct_lora_list(task.set_lora);
// logging
for (size_t i = 0; i < new_loras.size(); ++i) {
SRV_INF("set lora adapter idx=%zu scale=%f\n", i, new_loras[i].scale);
SRV_TRC("set lora adapter idx=%zu scale=%f\n", i, new_loras[i].scale);
}
// TODO @ngxson : make lora_adapters a dedicated member of server_context
params_base.lora_adapters = new_loras;
@@ -2771,7 +2776,7 @@ private:
}
if (all_idle) {
SRV_INF("%s", "all slots are idle\n");
SRV_TRC("%s", "all slots are idle\n");
return; // skip further processing
} else {
@@ -3287,10 +3292,9 @@ private:
const auto it = std::find_if(
slot.prompt.checkpoints.rbegin(),
slot.prompt.checkpoints.rend(),
[&, func_name = __func__](const auto & cur) {
[&](const auto & cur) {
// guarantee that a checkpoint will result in at least one token being processed [TAG_PROMPT_LOGITS]
LOG_INF("slot %12.*s: id %2d | task %d | Checking checkpoint with [%d, %d] against %d...\n", 12,
func_name, (slot).id, ((slot).task ? (slot).task->id : -1), cur.pos_min, cur.pos_max, pos_min_thold);
SLT_TRC(slot, "checking checkpoint with [%d, %d] against %d...\n", cur.pos_min, cur.pos_max, pos_min_thold);
// workaround for [TAG_CHECKPOINTS_FIX_POS_MIN]
if (cur.pos_max > pos_next) {
return false;
@@ -3310,11 +3314,11 @@ private:
pos_next = std::min(pos_next, std::max(it->pos_min + 1, it->pos_max));
n_past = std::min(slot.prompt.tokens.size_up_to_pos(pos_next), (size_t) it->n_tokens);
SLT_WRN(slot, "restored context checkpoint (pos_min = %d, pos_max = %d, n_tokens = %" PRId64 ", n_past = %d, size = %.3f MiB)\n", it->pos_min, it->pos_max, it->n_tokens, n_past, (float) it->size() / 1024 / 1024);
SLT_TRC(slot, "restored context checkpoint (pos_min = %d, pos_max = %d, n_tokens = %" PRId64 ", n_past = %d, size = %.3f MiB)\n", it->pos_min, it->pos_max, it->n_tokens, n_past, (float) it->size() / 1024 / 1024);
}
if (do_reset) {
SLT_WRN(slot, "forcing full prompt re-processing due to lack of cache data (likely due to SWA or hybrid/recurrent memory, see %s)\n",
SLT_TRC(slot, "forcing full prompt re-processing due to lack of cache data (likely due to SWA or hybrid/recurrent memory, see %s)\n",
"https://github.com/ggml-org/llama.cpp/pull/13194#issuecomment-2868343055");
pos_next = 0;
n_past = 0;
@@ -3327,7 +3331,7 @@ private:
for (auto it = slot.prompt.checkpoints.begin(); it != slot.prompt.checkpoints.end();) {
const auto & cur = *it;
if (cur.pos_max > pos_next) {
SLT_WRN(slot, "erased invalidated context checkpoint (pos_min = %d, pos_max = %d, n_tokens = %" PRId64 ", n_swa = %d, pos_next = %d, size = %.3f MiB)\n", cur.pos_min, cur.pos_max, cur.n_tokens, n_swa, pos_next, (float) cur.size() / 1024 / 1024);
SLT_TRC(slot, "erased invalidated context checkpoint (pos_min = %d, pos_max = %d, n_tokens = %" PRId64 ", n_swa = %d, pos_next = %d, size = %.3f MiB)\n", cur.pos_min, cur.pos_max, cur.n_tokens, n_swa, pos_next, (float) cur.size() / 1024 / 1024);
it = slot.prompt.checkpoints.erase(it);
} else {
++it;
@@ -3674,7 +3678,7 @@ private:
// all children slots should already launched by launch_slots_with_parent_task()
// copy state to the child slots
for (auto & child : children) {
SLT_INF(slot, " - copying state to child %d\n", child->id);
SLT_TRC(slot, " - copying state to child %d\n", child->id);
GGML_ASSERT(child->state == SLOT_STATE_WAIT_OTHER);
+8 -8
View File
@@ -83,7 +83,7 @@ bool server_http_context::init(const common_params & params) {
hostname = params.hostname;
if (gcp.enabled) {
SRV_INF("Google Cloud Platform compat: health route = %s, predict route = %s, port = %d\n", gcp.path_health.c_str(), gcp.path_predict.c_str(), gcp.port);
SRV_TRC("Google Cloud Platform compat: health route = %s, predict route = %s, port = %d\n", gcp.path_health.c_str(), gcp.path_predict.c_str(), gcp.port);
if (port != gcp.port) {
SRV_WRN("Google Cloud Platform compat: overriding server port %d with AIP_HTTP_PORT %d\n", port, gcp.port);
@@ -96,13 +96,13 @@ bool server_http_context::init(const common_params & params) {
#ifdef CPPHTTPLIB_OPENSSL_SUPPORT
if (!params.ssl_file_key.empty() && !params.ssl_file_cert.empty()) {
SRV_INF("running with SSL: key = %s, cert = %s\n", params.ssl_file_key.c_str(), params.ssl_file_cert.c_str());
SRV_TRC("running with SSL: key = %s, cert = %s\n", params.ssl_file_key.c_str(), params.ssl_file_cert.c_str());
srv = std::make_unique<httplib::SSLServer>(
params.ssl_file_cert.c_str(), params.ssl_file_key.c_str()
);
is_ssl = true;
} else {
SRV_INF("%s", "running without SSL\n");
SRV_TRC("%s", "running without SSL\n");
srv = std::make_unique<httplib::Server>();
}
#else
@@ -165,9 +165,9 @@ bool server_http_context::init(const common_params & params) {
if (params.api_keys.size() == 1) {
const auto key = params.api_keys[0];
const std::string substr = key.substr(std::max(static_cast<int>(key.length() - 4), 0));
SRV_INF("api_keys: ****%s\n", substr.c_str());
SRV_TRC("api_keys: ****%s\n", substr.c_str());
} else if (params.api_keys.size() > 1) {
SRV_INF("api_keys: %zu keys loaded\n", params.api_keys.size());
SRV_TRC("api_keys: %zu keys loaded\n", params.api_keys.size());
}
//
@@ -293,7 +293,7 @@ bool server_http_context::init(const common_params & params) {
// +4 threads for monitoring, health and some threads reserved for MCP and other tasks in the future
n_threads_http = std::max(params.n_parallel + 4, static_cast<int32_t>(std::thread::hardware_concurrency() - 1));
}
SRV_INF("using %d threads for HTTP server\n", n_threads_http);
SRV_TRC("using %d threads for HTTP server\n", n_threads_http);
srv->new_task_queue = [n_threads_http] {
// spawn n_threads_http fixed thread (always alive), while allow up to 1024 max possible additional threads
// when n_threads_http is used, server will create new "dynamic" threads that will be destroyed after processing each request
@@ -412,13 +412,13 @@ bool server_http_context::start() {
auto is_sock = false;
if (string_ends_with(std::string(hostname), ".sock")) {
is_sock = true;
SRV_INF("%s", "setting address family to AF_UNIX\n");
SRV_TRC("%s", "setting address family to AF_UNIX\n");
srv->set_address_family(AF_UNIX);
// bind_to_port requires a second arg, any value other than 0 should
// simply get ignored
was_bound = srv->bind_to_port(hostname, 8080);
} else {
SRV_INF("%s", "binding port with default address family\n");
SRV_TRC("%s", "binding port with default address family\n");
// bind HTTP listen port
if (port == 0) {
const auto bound_port = srv->bind_to_any_port(hostname);
+4 -1
View File
@@ -1983,7 +1983,10 @@ void server_models_routes::init_routes() {
cli.set_read_timeout(0, STREAM_LOOKUP_TIMEOUT_MS * 1000);
cli.set_write_timeout(0, STREAM_LOOKUP_TIMEOUT_MS * 1000);
auto resp = cli.Delete(child_path.c_str());
(void) resp; // best effort, 404 and network errors are equivalent to no op
(void) resp; // the child logs its own miss when the session is unknown there
} else {
SRV_WRN("router stop for unknown conv_id=%s, no owning child in the conv map\n",
conv_id.c_str());
}
// drop the tracking entry, the session is being torn down
models.conv_models.forget(conv_id);
+1 -1
View File
@@ -287,7 +287,7 @@ std::vector<std::unique_ptr<field>> make_llama_cmpl_schema(const common_params &
->set_desc("Chat format used internally by the server")
->set_handler([&](field_eval_context & ctx, const json & data) {
ctx.params.chat_parser_params.format = static_cast<common_chat_format>(data.at("chat_format").get<int>());
SRV_INF("Chat format: %s\n", common_chat_format_name(ctx.params.chat_parser_params.format));
SRV_TRC("chat format: %s\n", common_chat_format_name(ctx.params.chat_parser_params.format));
}));
add((new field_str("reasoning_format"))
+12 -6
View File
@@ -218,6 +218,13 @@ void stream_session_manager::evict_and_cancel(const std::string & conversation_i
std::unique_lock<std::shared_mutex> lock(map_mu);
auto it = sessions.find(conversation_id);
// unknown conversation id: nothing to cancel, log a warning with the currently live ids and bail out
if (it == sessions.end()) {
// build a comma-separated list of the keys of all live sessions for the warning below
// NOTE(review): this runs while the unique_lock on map_mu taken above is still held
std::string live;
for (const auto & kv : sessions) {
if (!live.empty()) live += ", ";
live += kv.first;
}
SRV_WRN("stop on unknown stream session, conv_id=%s matched nothing, %zu live: [%s]\n",
conversation_id.c_str(), sessions.size(), live.c_str());
return;
}
s = it->second;
@@ -339,11 +346,11 @@ void stream_pipe_producer::close() {
// httplib bails its content provider the moment is_peer_alive() goes false, so pump the rest
// of the generation into the ring buffer here. a DELETE flips is_cancelled and cuts it short
if (done_ || session_->is_cancelled()) {
SRV_INF("stream_pipe close: skip drain (done=%d cancelled=%d) conv=%s\n",
SRV_TRC("stream_pipe close: skip drain (done=%d cancelled=%d) conv=%s\n",
done_ ? 1 : 0, session_->is_cancelled() ? 1 : 0, session_->conversation_id.c_str());
return;
}
SRV_INF("stream_pipe close: draining conv=%s\n", session_->conversation_id.c_str());
SRV_TRC("stream_pipe close: draining conv=%s\n", session_->conversation_id.c_str());
size_t drained = 0;
std::string chunk;
while (true) {
@@ -357,7 +364,7 @@ void stream_pipe_producer::close() {
break;
}
}
SRV_INF("stream_pipe close: drain ended conv=%s bytes=%zu\n", session_->conversation_id.c_str(), drained);
SRV_TRC("stream_pipe close: drain ended conv=%s bytes=%zu\n", session_->conversation_id.c_str(), drained);
}
std::shared_ptr<stream_pipe_producer> stream_pipe_producer::create(stream_session_ptr session,
@@ -520,7 +527,7 @@ server_http_context::handler_t make_stream_delete_handler() {
if (conv_id.empty()) {
return make_error_response(400, "Missing conversation id in path", ERROR_TYPE_INVALID_REQUEST);
}
SRV_INF("DELETE /v1/stream/%s -> evict_and_cancel\n", conv_id.c_str());
SRV_TRC("DELETE /v1/stream/%s -> evict_and_cancel\n", conv_id.c_str());
g_stream_sessions.evict_and_cancel(conv_id);
auto res = std::make_unique<server_http_res>();
res->status = 204;
@@ -550,8 +557,7 @@ std::string stream_conv_id_from_headers(const std::map<std::string, std::string>
void stream_session_attach_pipe(server_http_res & res, const std::map<std::string, std::string> & headers) {
std::string conversation_id = stream_conv_id_from_headers(headers);
SRV_INF("stream_session_attach_pipe: conv_id=%s (empty=%d)\n",
conversation_id.c_str(), conversation_id.empty() ? 1 : 0);
SRV_TRC("conv_id=%s (empty=%d)\n", conversation_id.c_str(), conversation_id.empty() ? 1 : 0);
if (conversation_id.empty()) {
return;
}
+6 -6
View File
@@ -1626,7 +1626,7 @@ server_prompt * server_prompt_cache::alloc(const server_prompt & prompt, size_t
const int cur_lcp_len = it->tokens.get_common_prefix(prompt.tokens);
if (cur_lcp_len == (int) prompt.tokens.size()) {
SRV_INF("%s", " - prompt is already in the cache, skipping\n");
SRV_TRC("%s", " - prompt is already in the cache, skipping\n");
return nullptr;
}
}
@@ -1636,7 +1636,7 @@ server_prompt * server_prompt_cache::alloc(const server_prompt & prompt, size_t
const int len = it->tokens.get_common_prefix(prompt.tokens);
if (len == (int) it->tokens.size()) {
SRV_WRN(" - removing obsolete cached prompt with length %d\n", len);
SRV_TRC(" - removing obsolete cached prompt with length %d\n", len);
it = states.erase(it);
} else {
@@ -1681,7 +1681,7 @@ bool server_prompt_cache::load(server_prompt & prompt, const server_tokens & tok
float f_keep_best = prompt.tokens.size() > 0 ? float(lcp_best) / prompt.tokens.size() : -1.0f; // empty slot: any cache entry wins
float sim_best = float(lcp_best) / tokens_new.size();
SRV_INF(" - looking for better prompt, base f_keep = %.3f, sim = %.3f\n", f_keep_best, sim_best);
SRV_TRC(" - looking for better prompt, base f_keep = %.3f, sim = %.3f\n", f_keep_best, sim_best);
auto it_best = states.end();
@@ -1706,7 +1706,7 @@ bool server_prompt_cache::load(server_prompt & prompt, const server_tokens & tok
}
if (it_best != states.end()) {
SRV_INF(" - found better prompt with f_keep = %.3f, sim = %.3f\n", f_keep_best, sim_best);
SRV_TRC(" - found better prompt with f_keep = %.3f, sim = %.3f\n", f_keep_best, sim_best);
{
auto & data = it_best->data.main;
@@ -1783,11 +1783,11 @@ void server_prompt_cache::update() {
}
}
SRV_INF(" - cache state: %zu prompts, %.3f MiB (limits: %.3f MiB, %zu tokens, %zu est)\n",
SRV_TRC(" - cache state: %zu prompts, %.3f MiB (limits: %.3f MiB, %zu tokens, %zu est)\n",
states.size(), size() / (1024.0 * 1024.0), limit_size / (1024.0 * 1024.0), limit_tokens, limit_tokens_cur);
for (const auto & state : states) {
SRV_INF(" - prompt %p: %7d tokens, checkpoints: %2zu, %9.3f MiB\n",
SRV_TRC(" - prompt %p: %7d tokens, checkpoints: %2zu, %9.3f MiB\n",
(const void *)&state, state.n_tokens(), state.checkpoints.size(), state.size() / (1024.0 * 1024.0));
}
}
+4 -8
View File
@@ -124,7 +124,7 @@ int llama_server(int argc, char ** argv) {
}
if (params.n_parallel < 0) {
SRV_INF("%s", "n_parallel is set to auto, using n_parallel = 4 and kv_unified = true\n");
SRV_TRC("%s", "n_parallel is set to auto, using n_parallel = 4 and kv_unified = true\n");
params.n_parallel = 4;
params.kv_unified = true;
@@ -338,7 +338,7 @@ int llama_server(int argc, char ** argv) {
std::function<void()> clean_up;
if (is_router_server) {
SRV_INF("%s", "starting router server, no model will be loaded in this process\n");
SRV_INF("%s", "starting server in router mode. models will be automatically loaded on-demand\n");
clean_up = [&models_routes]() {
SRV_INF("%s: cleaning up before exit...\n", __func__);
@@ -391,9 +391,6 @@ int llama_server(int argc, char ** argv) {
});
}
// load the model
SRV_INF("%s", "loading model\n");
if (!ctx_server.load_model(params)) {
clean_up();
if (ctx_http.thread.joinable()) {
@@ -429,8 +426,9 @@ int llama_server(int argc, char ** argv) {
SetConsoleCtrlHandler(reinterpret_cast<PHANDLER_ROUTINE>(console_ctrl_handler), true);
#endif
SRV_INF("listening on %s\n", ctx_http.listening_address.c_str());
if (is_router_server) {
SRV_INF("router server is listening on %s\n", ctx_http.listening_address.c_str());
SRV_WRN("%s", "NOTE: router mode is experimental\n");
SRV_WRN("%s", " it is not recommended to use this mode in untrusted environments\n");
@@ -446,8 +444,6 @@ int llama_server(int argc, char ** argv) {
// when the HTTP server stops, clean up and exit
clean_up();
} else {
SRV_INF("server is listening on %s\n", ctx_http.listening_address.c_str());
// optionally, notify router server that this instance is ready
std::thread monitor_thread;
if (child.is_child()) {
+28 -8
View File
@@ -154,7 +154,13 @@ class ChatStore {
});
if (convId === conversationsStore.activeConversation?.id) this.currentResponse = response;
}
private clearChatStreaming(convId: string): void {
/**
 * Drops the streaming state kept for a conversation.
 *
 * @param convId - conversation whose entry in `chatStreamingStates` is removed
 * @param messageId - optional guard; when given, the entry is only removed if it is
 *   absent or its `messageId` equals this value
 */
private clearChatStreaming(convId: string, messageId?: string): void {
	// session guard: when the caller names its message, an entry recorded for a different
	// message is left untouched, so a stale generation cannot wipe a newer one's streaming
	// state (and with it the frozen stop identity) on the same conversation
	const tracked = messageId === undefined ? undefined : this.chatStreamingStates.get(convId);
	if (tracked && tracked.messageId !== messageId) return;

	this.chatStreamingStates.delete(convId);

	// only the active conversation drives the visible response buffer
	const activeId = conversationsStore.activeConversation?.id;
	if (activeId === convId) this.currentResponse = '';
}
@@ -1055,11 +1061,14 @@ class ChatStore {
modelOverride?: string | null,
firstUserMessageContent?: string
): Promise<void> {
let effectiveModel = modelOverride;
// the ::model suffix in the stream identity is only for router mode, where it routes to the
// owning child. in single-model mode the identity stays the bare conv id so that attach, stop
// and reattach all agree, regardless of fresh send vs regenerate passing a resolved model
let effectiveModel: string | null | undefined = undefined;
if (isRouterMode() && !effectiveModel) {
if (isRouterMode()) {
const conversationModel = this.getConversationModel(allMessages);
effectiveModel = selectedModelName() || conversationModel;
effectiveModel = modelOverride || selectedModelName() || conversationModel;
}
if (isRouterMode() && effectiveModel) {
@@ -1074,6 +1083,9 @@ class ChatStore {
let resolvedModel: string | null = null;
let modelPersisted = false;
const convId = assistantMessage.convId;
// freeze the POST identity from t0 so a stop cancels with the exact session key,
// never a stale or empty model resolved later
this.setChatStreaming(convId, streamedContent, currentMessageId, effectiveModel);
const recordModel = (modelName: string | null | undefined, persistImmediately = true): void => {
if (!modelName) return;
@@ -1103,7 +1115,7 @@ class ChatStore {
};
const updateStreamingUI = () => {
this.setChatStreaming(convId, streamedContent, currentMessageId);
this.setChatStreaming(convId, streamedContent, currentMessageId, effectiveModel);
const idx = conversationsStore.findMessageIndex(currentMessageId);
conversationsStore.updateMessageAtIndex(idx, { content: streamedContent });
};
@@ -1111,7 +1123,7 @@ class ChatStore {
const cleanupStreamingState = () => {
this.setStreamingActive(false);
this.setChatLoading(convId, false);
this.clearChatStreaming(convId);
this.clearChatStreaming(convId, currentMessageId);
this.setProcessingState(convId, null);
};
@@ -1128,7 +1140,7 @@ class ChatStore {
onReasoningChunk: (chunk: string) => {
streamedReasoningContent += chunk;
// mark streaming state so a stop mid-thinking can persist the partial reasoning
this.setChatStreaming(convId, streamedContent, currentMessageId);
this.setChatStreaming(convId, streamedContent, currentMessageId, effectiveModel);
const idx = conversationsStore.findMessageIndex(currentMessageId);
conversationsStore.updateMessageAtIndex(idx, {
reasoningContent: streamedReasoningContent
@@ -1405,7 +1417,7 @@ class ChatStore {
// detached drain keeps producing tokens until eos or max_tokens. use the frozen identity
// captured when the session started, not the live dropdown
const streamStateForStop = this.chatStreamingStates.get(convId);
const modelForStop = streamStateForStop?.model ?? selectedModelName();
const modelForStop = streamStateForStop?.model;
void ChatService.cancelServerStream(convId, modelForStop);
this.abortRequest(convId);
this.setChatLoading(convId, false);
@@ -1846,6 +1858,14 @@ class ChatStore {
updateStreamingContent(originalContent + appendedContent);
this.setChatReasoning(msg.convId, false);
},
onCompletionId: (id: string) => {
	// an empty id carries nothing to record
	if (!id) return;

	// refresh the completion id stored on the message so a later skip targets the live slot
	// after a continue: first on the conversations store entry, then via DatabaseService
	const completionPatch = { completionId: id };
	const messageIndex = conversationsStore.findMessageIndex(msg.id);
	conversationsStore.updateMessageAtIndex(messageIndex, completionPatch);

	// a rejected database write is deliberately ignored here
	DatabaseService.updateMessage(msg.id, { completionId: id }).catch(() => {});
},
onReasoningChunk: (chunk: string) => {
appendedReasoning += chunk;
hasReceivedContent = true;