Compare commits

...

33 Commits

Author SHA1 Message Date
lhez 4fc4ec5541 opencl: allow loading precompiled binary kernels from library (#23042)
* opencl: allow loading binary kernel

* opencl: add libdl.h

* ggml-backend-dl is in ggml, which depends backend libs, thus
  ggml-opencl cannot depend on ggml-backend-dl
* add libdl.h to break cyclic dep

* opencl: allow loading bin kernel lib

* opencl: load `gemm_moe_mxfp4_f32_ns` from kernel lib if available

* opencl: load q8_0 gemm from kernel lib

* opencl: load q4_0 moe gemm from kernel lib

* opencl: load q4_1 moe gemm from kernel lib

* opencl: load q4_k moe gemm from kernel lib

* opencl: always declare `get_adreno_bin_kernel_func_t`

* opencl: rephrase message

* opencl: fix for rebase

* opencl: update doc
2026-07-01 10:29:22 -07:00
Adrien Gallouët a6647b1a32 common : use hf primary split as model path (#25194)
Fixes #25181
2026-07-01 18:33:00 +02:00
Max Krasnyansky 13e673863b hexagon: flash attention rework (optimizations, accuracy improvements, etc) (#25085)
* hex-mm: fold mm quant tasks into the main matmul threads

* hex-mm: minor formatting fixes

* hex-mm: cleanup is_quant checks in dma dispatch

* hex-mm: fix dst-spad alignment

* hex-mm: move fp kernels in the hvx-mm-kernels header

* hex-mm: fuse with ADD

* hex-fa: factor out ukernels into separate headers and unify the rest

* hex-fa: move kernel-params compute into the host

* hex-fa: refactor vtcm alloc for consistency

* hex-fa: add support for FA_SELECT

* hex-fa: update tracing insrumentation to cover all functions

* hex-fa: update hvx fallback thresholds to recover t/g regressions

* hex-fa: update tracing instrumentation

* hex-fa: improved tracing with additional events

* hex-fa: optimize mask processing (fastdiv, etc)

* hex-fa: improve mask dma caching

* hmx-fa: change loop order to maximize mask cache hits

* hex-fa: remove over instrumentation

* hex-fa: breakdown QKV prep trace events

* hmx-fa: further mask proc optimizations

* hex-fa: mask broadcast is the common case, optimize for that

* hex-fa: use aligned loads where possible

* hex-fa: update loops to use uint32_t indices

* hmx-fa: fold vtcm init into q prep task

* hex-fa: update rest of the hmx funcs to use uint32_t

* hmx-fa: fold build_d into the main softmax loop

* hmx-fa: start kv dmas earlier

* hmx-fa: start mask dma a bit earlier

* hex-fa: precompute rows per task to avoid divs

* hmx-fa: specialize fa_o_store for f16 and f32

* hmx-fa: prelim support for Sinks

* hmx-fa: keep softmax accumulators in fp32

* hex-fa: add tanh_f16 and exp2_f16 and use that in FA

* hex-fa: use fp16 math in the hvx kernel

* hex-fa: avoid expensive float -> __fp16 cast for slopes and softcap

* hex-fa: replace most vec_exp_f32 with vec_exp2_f16

* hmx-fa: vectorize sinks update

* hex-fa: minor formatting

* hmx-fa: fold softcap loop into the tile load

* hmx-fa: use vectoralias to populate sinks

* hex-fa: remove redudant check

* hex-fa: fix vtcm size compute to use fp32 for accumulators

* hex-mm: fix trailing spaces

* hmx-fa: dont use -inf to init mask to avoid conversion overflows

* hex-fa: no need to explicitly guard -inf in the f16->f32 converter now

* hmx-fa: cleanup fa sinks handling

* hex-mm: fixed src2 stride handling when mm is fused with add

* hex-fa: make lto happy
2026-07-01 06:59:19 -07:00
Johannes Gäßler b820cc8e6f CUDA: consistent use of __restrict__ + PDL for FA (#25185) 2026-07-01 10:55:14 +02:00
ragz4125 6dbc1174b8 ggml-cpu: add AVX2 optimization for nvfp4 dot product and use UE4M3 LUT (#23961) 2026-07-01 15:31:20 +08:00
Aleksander Grygier 9d88e7cedd ui Prevent tool messages from incorrectly appending to other conversations (#25177)
* fix: Prevent tool messages from incorrectly appending to other conversations

* ui: prevent agentic loop from poisoning another conv's currNode

* ui: make editedContent a  so background recompute does not wipe in-progress edits

---------

Co-authored-by: Pascal <admin@serveurperso.com>
2026-07-01 09:25:18 +02:00
Aleksander Grygier 7af4279f45 ui: Remove PWA navigate fallback to prevent caching API endpoint requests (#25174) 2026-07-01 07:32:55 +02:00
lhez fd1a05791d opencl: initial q1_0 support (#25160)
* opencl: general q1_0 support

* opencl: add Adreno GEMM/GEMV for q1_0
2026-06-30 21:43:20 -07:00
fairydreaming 0eca4d490e cuda : prevent integer truncation and overflow errors when using KQ mask strides in flash_attn_mask_to_KV_max kernel (#24945)
Co-authored-by: Stanisław Szymczyk <sszymczy@gmail.com>
2026-06-30 20:47:05 +02:00
Jürgen Schmied 4f31eedb0c model : register t_layer_inp for qwen3next (#25141)
* Fix input assignment in layer processing loop

Fix DFLASH for qwen-coder-next

* add line break

Added tensor for attention normalization in Qwen3 model.
2026-06-30 17:57:14 +02:00
Pascal 799fcc04a5 common,server: handle bracketed IPv6 literals in URL authority (#25140)
* common,server: handle bracketed IPv6 literals in URL authority

Parse the [host]:port form (RFC 3986) and bracket IPv6 hosts when
formatting a URL authority: listening log, proxy Host header, proxy
log, client rebuild. The per-request remote_addr stays bare.

* common: restore unsupported scheme throw in url parser

Address @ngxson review: keep the explicit reject in port resolution so
the block stays self-contained. Non-http(s) schemes still throw (also
gated at the top of common_http_parse_url).
2026-06-30 16:16:44 +02:00
Matt Jallo 931eb37f8c CUDA: fix get_rows_back for tables with more than 65535 rows (grid-y clamp + stride) (#25103) 2026-06-30 14:16:24 +02:00
Johannes Gäßler e495d1e748 CUDA: fix Gemma E4B MTP FlashAttention (#25148)
* CUDA: fix Gemma E4B MTP FlashAttention

* remove unused template declaration
2026-06-30 14:06:54 +02:00
Kevin Liu f708a5b2ca vulkan: roll bk loop in matmul for asahi linux (#24663)
* vulkan: roll bk loop in matmul for asahi linux

* vulkan: fix inline comment

* vulkan: revert BK-loop unroll change

* vulkan: edit spirv directly for asahi roll bk loop

* vulkan: remove trailing whitespace at the end of comments
2026-06-30 12:27:38 +02:00
zduford d9df11006f HIP: use hipBLAS for dense prefill on gfx900, keep MMQ for MoE (#24588)
* HIP: keep MMQ for gfx900 MoE and Q8_0, use hipBLAS for dense K-quants

Assisted-by: GitHub Copilot CLI

* HIP: tighten conditional block to be explicitly for gfx900

* HIP: Further simplified gfx900 conditional block

* removed unnecessary comment
2026-06-30 11:51:38 +02:00
Masashi Yoshimura 6c5de1cc83 ggml-webgpu: add support for NVFP4 (#25143) 2026-06-30 17:20:04 +09:00
Oliver Simons 86b94708f2 Revert "sched : reintroduce less synchronizations during split compute (#20793)" (#25138) 2026-06-30 08:41:45 +08:00
Adrien Gallouët 6f4f53f2b7 common : dedup preset and cached model entries in /v1/models (#25131)
Signed-off-by: Adrien Gallouët <angt@huggingface.co>
2026-06-29 17:37:23 +02:00
Ruben Ortlam 25a1d63f43 vulkan: use flops instead of weight tensor size for submission heuristic (#25005)
* vulkan: extract flops calculation into function

* use flops instead of matmul src0 tensor size for submission threshold

* use unsigned ints
2026-06-29 15:24:44 +02:00
Aman Gupta 8c146a8366 DeepSeek V4 (#24162)
* convert: add dsv4 conversion

* add basic setup

* add llm_graph_input_dsv4

* add save-load state

* add sinkhorn eps - correction by @fairydreaming

* add rope fix

* cleanup dead code

* fix bugs

* support pro model: added by @fairydreaming

* remove redundant V cache

* Chat template

* remove debugging leftovers

* Add mechanism for inlining templates based on architecture

* s/deepseek-v4-flash/deepseek4/g

* s/deepseek-v4-flash/deepseek4/g continued

* enable graph reuse

* enable FA

* fix test llama archs

* rename

* compatibility with antirez ds4 GGUFs

* simplified set_gguf_parameters() by calling super class method, replaced moe.score_func with expert_gating_func.

* reserve worst-case kv-cache

* revert max split inputs

* address review comments

* add padding to enable FA

* pad only the final value of plan.n_kv to 256

* remove built-in cpp chat template

* cont: remove cpp built-in template

* rm outdated test

* replace ggml_view_3d() with ggml_reshape_3d()

Co-authored-by: Georgi Gerganov <ggerganov@gmail.com>

* only support n_seq=1 for now

* remove unused var

* cont: remove unused var

* use scale bias

* use correct ptr for can_reuse

* remove gen-chat-inline-templates.py

* simplify graph reuse

* cont: cleanup

* remove unused inputs

* enable partial checkpointing

* add correct shape for kq_mask + set llama_model_n_swa to 0 for dsv4

* precompute source_idx + add comment about dummy write

* support multi-seq

* remove restored_trim_pos

* use split_equal when possible

* fix indent

* address review comments

* use LLM_KV

* fix ci

---------

Co-authored-by: Piotr Wilkin <piotr.wilkin@syndatis.com>
Co-authored-by: Stanisław Szymczyk <sszymczy@gmail.com>
Co-authored-by: Xuan Son Nguyen <son@huggingface.co>
Co-authored-by: fairydreaming <166155368+fairydreaming@users.noreply.github.com>
Co-authored-by: Georgi Gerganov <ggerganov@gmail.com>
2026-06-29 16:58:51 +08:00
seryogakovalyov 6cb18b2f2e tools/ui: restore Tailwind scanning in ignored worktrees (#24879) 2026-06-29 10:55:52 +02:00
o7si 277a105dc8 common : remove unused regex-partial (#25118) 2026-06-29 08:48:39 +02:00
Xuan-Son Nguyen b3fed31b99 jinja, chat: add --reasoning-preserve flag (#25105)
* jinja, chat: add --reasoning-preserve flag

* correct help message
2026-06-28 23:33:51 +02:00
Aleksander Grygier dbdaece23d Revert "ui: fix accessibility for hover-gated interactive elements assisted by claude(in debugging and tests) (#24727)" (#25098) 2026-06-28 21:30:03 +02:00
Pascal 7cb8576e7c ui: fix stop and reasoning skip in single-model mode (#25084) 2026-06-28 21:06:43 +02:00
Ruixiang Wang fa72bc6826 dflash: refactor draft model conversion (#25110)
* dflash: refactor draft model conversion

* apply fix for eagle3 convert
2026-06-28 20:31:48 +02:00
Aldehir Rojas c818263f2a chat : implement minicpm5 parser (#24889)
* Add minicpm5 tool call parser

* Refactor MiniCPM5 PEG parser per review feedback

* Fix jinja min/max API to match Jinja2

* modify by review

* MiniCPM5: use autoparser for XML tool calls and fix grammar preserved-token triggers

* MiniCPM5: fix streaming tool-arg placeholder and remove alt XML markers

* skip min/max attribute tests in -py mode

* test-jinja: use real expected output for min/max attribute tests

* MiniCPM5: revert shared mapper and history fallbacks per review

Drop streaming tool-arg placeholder workarounds from the generic PEG
mapper and restore strict tool-call argument JSON parsing so MiniCPM5
support stays limited to autoparser/diff-analyzer changes.

* chat : refactor minicpm5 back to dedicated parser

* cont : simplify grammar

* cont : refactor

* cont : fixes

* cont : rename template to openbmb-MiniCPM5-1B.jinja

* cont : add message delimiters

* cont : fix tests

---------

Co-authored-by: zhangtao <zhangtao2@modelbest.cn>
Co-authored-by: 张涛 <>
2026-06-28 16:53:32 +02:00
Xuan-Son Nguyen f68a788b0b jinja: add --dump-prog for debugging (#25086)
* jinja: add --dump-prog for debugging

* Update common/jinja/runtime.cpp

Co-authored-by: Sigbjørn Skjæret <1629204+CISC@users.noreply.github.com>

---------

Co-authored-by: Sigbjørn Skjæret <1629204+CISC@users.noreply.github.com>
2026-06-28 15:50:31 +02:00
Ruixiang Wang d1b34251bc spec : add DFlash support (#22105)
* spec: add DFlash v2 support

* dflash: support sliding window attention per layer_types

* docs: add dflash section

---------

Co-authored-by: Kashif Rasul <kashif.rasul@gmail.com>
2026-06-28 16:01:34 +03:00
Adrien Gallouët c1a1c8ee94 common : allow --offline in llama download (#25091)
Expose the existing --offline flag to `llama download` so a script can
run it to check whether a model is already cached and ready to be served
without touching the network.

Also fix a latent use-after-free in the URL-task on_done callback:
first_path is block-scoped and was captured by reference, but invoked
after the block ends.

Signed-off-by: Adrien Gallouët <angt@huggingface.co>
2026-06-28 12:34:11 +02:00
Georgi Gerganov 27c8bb4f63 logs : reduce v2 (#25078)
* server : reduce logs

* cont : common

* cont : spec

* cont : CMN_ -> COM_
2026-06-28 08:52:15 +03:00
Hongqiang Wang ebd048fc5e opencl: flash attention improvement (#25069)
* opencl: rework FA kernel for f16 and f32

* opencl: flash-attention prefill prepass kernels

- flash_attn_kv_pad_f16    pads the tail KV tile to a BLOCK_N multiple
- flash_attn_mask_pad_f16  pads the matching mask tile
- flash_attn_blk_f16       classifies each KV tile per query block as
                           fully masked / mixed / fully unmasked, so
                           the main kernel can skip fully-masked tiles
                           and the mask lookup for fully-unmasked ones

* opencl: FA kernels for q4_0 and q8_0

* opencl: `set_rows` for f32 to q8_0/q4_0

* opencl: dequant kernels for q4_0 and q8_0

* opencl: add FA tile tuning table with override

* opencl: wire host side for FA

* opencl: q4_0 MoE tensors are also SOA'ed

* opencl: cosmetic fix

* opencl: refactor, also clarify some code paths in comments

* opencl: fix inifity for `-cl-finite-math-only`

---------

Co-authored-by: Li He <lih@qti.qualcomm.com>
2026-06-27 15:36:06 -07:00
Gaurav Garg 0ed235ea2c [CUDA] Added a cudaMemcpy2DAsync fast path to ggml_cuda_cpy (#25057)
* [CUDA] Added a cudaMemcpy2DAsync fast path to ggml_cuda_cpy

Add a CUDA ggml_cpy fast path for same-type, same-shape strided copies that are just 2D pitched block copies.
When tensors are not fully contiguous but each row is contiguous, it now uses cudaMemcpy2DAsync instead of the slow element-wise scalar copy kernel.

This fixes the GDN recurrent snapshot update with -np 4, where rollback slots are separated by cache stride gaps.

* Add new tests that execute the new optimized strided copy path

* Return unsupported for strided copy in OpenVINO, as new tests are failing
2026-06-27 17:46:21 +05:30
142 changed files with 18430 additions and 4527 deletions
-2
View File
@@ -94,10 +94,8 @@ add_library(${TARGET}
peg-parser.h
preset.cpp
preset.h
regex-partial.cpp
reasoning-budget.cpp
reasoning-budget.h
regex-partial.h
sampling.cpp
sampling.h
speculative.cpp
+26 -10
View File
@@ -467,7 +467,7 @@ void common_models_handler_apply(common_models_handler & handler, common_params
// the first part is what gets loaded, so point params.model.path at it
if (!url_tasks.empty()) {
std::string first_path = url_tasks.front().local_path;
url_tasks.front().on_done = [&]() { params.model.path = first_path; };
url_tasks.front().on_done = [&, first_path]() { params.model.path = first_path; };
}
for (auto & task : url_tasks) {
tasks.push_back(std::move(task));
@@ -496,13 +496,15 @@ void common_models_handler_apply(common_models_handler & handler, common_params
}
// handle hf_plan tasks
auto add_tasks = [&opts, &tasks](const hf_cache::hf_files & model_files, common_params_model & model) {
auto add_tasks = [&opts, &tasks](const hf_cache::hf_files & model_files,
const hf_cache::hf_file & primary,
common_params_model & model) {
for (size_t i = 0; i < model_files.size(); ++i) {
auto & model_file = model_files[i];
bool is_first = (i == 0);
tasks.emplace_back(model_file, opts, [&, is_first]() {
if (is_first) {
// only use first part as model path
bool is_primary = (model_file.path == primary.path);
tasks.emplace_back(model_file, opts, [&, is_primary]() {
if (is_primary) {
// the primary file is the first split (00001-of), use it as model path
model.path = hf_cache::finalize_file(model_file);
} else {
hf_cache::finalize_file(model_file);
@@ -511,7 +513,7 @@ void common_models_handler_apply(common_models_handler & handler, common_params
}
};
if (!plan.model_files.empty()) {
add_tasks(plan.model_files, params.model);
add_tasks(plan.model_files, plan.primary, params.model);
}
if (!plan.mmproj.local_path.empty()) {
tasks.emplace_back(plan.mmproj, opts, [&]() {
@@ -539,12 +541,12 @@ void common_models_handler_apply(common_models_handler & handler, common_params
// handle plan_spec (e.g. --spec-draft-hf)
if (!plan_spec.model_files.empty()) {
add_tasks(plan_spec.model_files, params.speculative.draft.mparams);
add_tasks(plan_spec.model_files, plan_spec.primary, params.speculative.draft.mparams);
}
// handle vocoder plan (e.g. --hf-repo-v)
if (!plan_voc.model_files.empty()) {
add_tasks(plan_voc.model_files, params.vocoder.model);
add_tasks(plan_voc.model_files, plan_voc.primary, params.vocoder.model);
}
// run all tasks in parallel
@@ -3296,6 +3298,20 @@ common_params_context common_params_parser_init(common_params & params, llama_ex
params.sampling.reasoning_budget_message = value;
}
).set_examples({LLAMA_EXAMPLE_SERVER, LLAMA_EXAMPLE_COMPLETION, LLAMA_EXAMPLE_CLI}).set_env("LLAMA_ARG_THINK_BUDGET_MESSAGE"));
add_opt(common_arg(
{"--reasoning-preserve"},
{"--no-reasoning-preserve"},
"preserve reasoning trace in the full history, not just the last assistant message (default: template default)\n"
"compatible with certain templates having 'supports_preserve_reasoning' capability\n"
"example: https://docs.z.ai/guides/capabilities/thinking-mode#preserved-thinking",
[](common_params & params, bool value) {
if (value) {
params.default_template_kwargs["preserve_reasoning"] = "true";
} else {
params.default_template_kwargs["preserve_reasoning"] = "false";
}
}
).set_examples({LLAMA_EXAMPLE_SERVER, LLAMA_EXAMPLE_COMPLETION, LLAMA_EXAMPLE_CLI}).set_env("LLAMA_ARG_REASONING_PRESERVE"));
add_opt(common_arg(
{"--chat-template"}, "JINJA_TEMPLATE",
string_format(
@@ -3471,7 +3487,7 @@ common_params_context common_params_parser_init(common_params & params, llama_ex
[](common_params & params) {
params.offline = true;
}
).set_env("LLAMA_ARG_OFFLINE"));
).set_examples({LLAMA_EXAMPLE_COMMON, LLAMA_EXAMPLE_DOWNLOAD}).set_env("LLAMA_ARG_OFFLINE"));
add_opt(common_arg(
{"-lv", "--verbosity", "--log-verbosity"}, "N",
string_format("Set the verbosity threshold. Messages with a higher verbosity will be ignored. Values:\n"
+155
View File
@@ -912,6 +912,10 @@ static std::string common_chat_template_direct_apply_impl(
if (inputs.add_generation_prompt) {
inp["add_generation_prompt"] = true;
}
if (inp.contains("preserve_reasoning") && inp["preserve_reasoning"].is_boolean()) {
bool enabled = inp["preserve_reasoning"].get<bool>();
jinja::caps_apply_preserve_reasoning(ctx, enabled);
}
jinja::global_from_json(ctx, inp, inputs.mark_input);
@@ -2376,6 +2380,149 @@ static void func_args_not_string(json & messages) {
}
// MiniCPM5 format:
// - Reasoning: <think>{reasoning}</think> (optional)
// - Tool calls: <function name="foo"><param name="bar">value</param></function>
static common_chat_params common_chat_params_init_minicpm5(const common_chat_template & tmpl,
const autoparser::generation_params & inputs) {
common_chat_params data;
data.prompt = common_chat_template_direct_apply_impl(tmpl, inputs);
data.generation_prompt = common_chat_template_generation_prompt_impl(tmpl, inputs);
data.format = COMMON_CHAT_FORMAT_PEG_NATIVE;
data.supports_thinking = true;
data.preserved_tokens = {
"<function",
"<param",
"</function>",
"</param>",
"<think>",
"</think>",
};
data.thinking_start_tag = "<think>";
data.thinking_end_tag = "</think>";
data.message_delimiters = {
{ COMMON_CHAT_ROLE_ASSISTANT, "<|im_start|>assistant" },
{ COMMON_CHAT_ROLE_TOOL, "<|im_start|>user\n<tool_response>" },
{ COMMON_CHAT_ROLE_USER, "<|im_start|>user" },
{ COMMON_CHAT_ROLE_SYSTEM, "<|im_start|>system" },
};
auto has_tools = inputs.tools.is_array() && !inputs.tools.empty();
auto has_response_format = inputs.json_schema.is_object() && !inputs.json_schema.empty();
auto extract_reasoning = inputs.reasoning_format != COMMON_REASONING_FORMAT_NONE;
auto include_grammar = has_response_format || (has_tools && inputs.tool_choice != COMMON_CHAT_TOOL_CHOICE_NONE);
if (inputs.has_continuation()) {
const auto & msg = inputs.continue_msg;
data.generation_prompt = "<|im_start|>assistant\n<think>\n" + msg.reasoning_content;
if (inputs.continue_final_message == COMMON_CHAT_CONTINUATION_CONTENT) {
data.generation_prompt += "\n</think>\n\n" + msg.render_content();
}
data.prompt += data.generation_prompt;
}
auto parser = build_chat_peg_parser([&](common_chat_peg_builder & p) {
auto generation_prompt = p.literal("<|im_start|>assistant\n");
auto reasoning = p.eps();
if (extract_reasoning) {
reasoning = ("<think>" << p.reasoning(p.until("</think>")) << "</think>") + p.space();
}
// Response format parser
if (has_response_format) {
return generation_prompt + reasoning + p.content(p.schema(p.json(), "response-format", inputs.json_schema));
}
if (has_tools && inputs.tool_choice != COMMON_CHAT_TOOL_CHOICE_NONE) {
// CDATA lets a value carry characters that would otherwise close the tag (e.g.
// </param>); capture the inner text only, excluding the CDATA markers.
auto string_value = p.choice({
p.literal("<![CDATA[") + p.ac(p.tool_arg_string_value(p.until("]]>")) + p.literal("]]>"), "]]>") + p.tool_arg_close(p.literal("</param>")),
p.negate(p.literal("<![CDATA[")) + p.ac(p.tool_arg_string_value(p.until("</param>")) + p.tool_arg_close(p.literal("</param>")), "</param>")
});
auto tool_choice = p.choice();
foreach_function(inputs.tools, [&](const json & tool) {
const auto & function = tool.at("function");
const std::string name = function.at("name");
auto params = function.contains("parameters") ? function.at("parameters") : json::object();
auto args = p.eps();
if (params.contains("properties") && params.at("properties").is_object() && !params.at("properties").empty()) {
auto schema_info = common_schema_info();
schema_info.resolve_refs(params);
auto arg_choice = p.choice();
for (const auto & [prop_name, prop_schema] : params.at("properties").items()) {
auto value_parser = p.eps();
if (schema_info.resolves_to_string(prop_schema)) {
value_parser = string_value;
} else {
value_parser = p.tool_arg_json_value(
p.schema(p.json(), "tool-" + name + "-arg-" + prop_name + "-schema", prop_schema, false)
) + p.tool_arg_close(p.literal("</param>"));
}
auto arg_rule = p.tool_arg(
p.tool_arg_open(p.literal("<param name=\"") + p.tool_arg_name(p.literal(prop_name)) + p.literal("\">")) +
value_parser
);
arg_choice |= arg_rule;
}
args = p.zero_or_more(arg_choice + p.space());
}
auto tool_parser = p.tool(
p.tool_open(p.literal("<function name=\"") + p.tool_name(p.literal(name)) + p.literal("\">"))
<< p.tool_args(args)
<< p.tool_close(p.literal("</function>")));
tool_choice |= p.rule("tool-" + name, tool_parser);
});
auto max_calls = inputs.parallel_tool_calls ? -1 : 1;
auto tool_calls = p.trigger_rule("tool-call", p.repeat(tool_choice + p.space(), 1, max_calls));
auto content = p.content(p.until("<function"));
return generation_prompt + reasoning + content + tool_calls + p.end();
}
return generation_prompt + reasoning + p.content(p.rest()) + p.end();
});
data.parser = parser.save();
if (include_grammar) {
data.grammar_lazy = !(has_response_format || (has_tools && inputs.tool_choice == COMMON_CHAT_TOOL_CHOICE_REQUIRED));
data.grammar = build_grammar([&](const common_grammar_builder & builder) {
foreach_function(inputs.tools, [&](const json & tool) {
const auto & function = tool.at("function");
auto schema = function.contains("parameters") ? function.at("parameters") : json::object();
builder.resolve_refs(schema);
});
if (has_response_format) {
auto schema = inputs.json_schema;
builder.resolve_refs(schema);
}
parser.build_grammar(builder, data.grammar_lazy);
});
data.grammar_triggers = {
{ COMMON_GRAMMAR_TRIGGER_TYPE_WORD, "<function" },
};
}
return data;
}
static json common_chat_extra_context() {
json ctx = json::object();
std::chrono::system_clock::time_point now = std::chrono::system_clock::now();
@@ -2468,6 +2615,14 @@ std::optional<common_chat_params> common_chat_try_specialized_template(
return common_chat_params_init_gemma4(tmpl, params);
}
// MiniCPM5 - XML tool calls with <function name="..."><param name="...">...</param></function>
if (src.find("Tool usage guidelines:") != std::string::npos &&
src.find("<function name=\"") != std::string::npos &&
src.find("<param name=\"") != std::string::npos) {
LOG_DBG("Using specialized template: MiniCPM5\n");
return common_chat_params_init_minicpm5(tmpl, params);
}
return std::nullopt;
}
+47 -47
View File
@@ -225,7 +225,7 @@ bool set_process_priority(enum ggml_sched_priority prio) {
}
if (!SetPriorityClass(GetCurrentProcess(), p)) {
LOG_WRN("failed to set process priority class %d : (%d)\n", prio, (int) GetLastError());
COM_WRN("failed to set process priority class %d : (%d)\n", prio, (int) GetLastError());
return false;
}
@@ -251,7 +251,7 @@ bool set_process_priority(enum ggml_sched_priority prio) {
}
if (setpriority(PRIO_PROCESS, 0, p) != 0) {
LOG_WRN("failed to set process priority %d : %s (%d)\n", prio, strerror(errno), errno);
COM_WRN("failed to set process priority %d : %s (%d)\n", prio, strerror(errno), errno);
return false;
}
return true;
@@ -284,14 +284,14 @@ void postprocess_cpu_params(common_cpu_params & cpuparams, const common_cpu_para
if (n_set && n_set < cpuparams.n_threads) {
// Not enough set bits, may experience performance issues.
LOG_WRN("Not enough set bits in CPU mask (%d) to satisfy requested thread count: %d\n", n_set, cpuparams.n_threads);
COM_WRN("Not enough set bits in CPU mask (%d) to satisfy requested thread count: %d\n", n_set, cpuparams.n_threads);
}
}
bool parse_cpu_range(const std::string & range, bool (&boolmask)[GGML_MAX_N_THREADS]) {
size_t dash_loc = range.find('-');
if (dash_loc == std::string::npos) {
LOG_ERR("Format of CPU range is invalid! Expected [<start>]-[<end>].\n");
COM_ERR("%s", "Format of CPU range is invalid! Expected [<start>]-[<end>].\n");
return false;
}
@@ -303,7 +303,7 @@ bool parse_cpu_range(const std::string & range, bool (&boolmask)[GGML_MAX_N_THRE
} else {
start_i = std::stoull(range.substr(0, dash_loc));
if (start_i >= GGML_MAX_N_THREADS) {
LOG_ERR("Start index out of bounds!\n");
COM_ERR("%s", "Start index out of bounds!\n");
return false;
}
}
@@ -313,7 +313,7 @@ bool parse_cpu_range(const std::string & range, bool (&boolmask)[GGML_MAX_N_THRE
} else {
end_i = std::stoull(range.substr(dash_loc + 1));
if (end_i >= GGML_MAX_N_THREADS) {
LOG_ERR("End index out of bounds!\n");
COM_ERR("%s", "End index out of bounds!\n");
return false;
}
}
@@ -333,7 +333,7 @@ bool parse_cpu_mask(const std::string & mask, bool (&boolmask)[GGML_MAX_N_THREAD
}
size_t num_digits = mask.length() - start_i;
if (num_digits > 128) num_digits = 128;
num_digits = std::min<size_t>(num_digits, 128);
size_t end_i = num_digits + start_i;
@@ -348,7 +348,7 @@ bool parse_cpu_mask(const std::string & mask, bool (&boolmask)[GGML_MAX_N_THREAD
} else if (c >= 'A' && c <= 'F') {
id -= 'A' - 10;
} else {
LOG_ERR("Invalid hex character '%c' at position %d\n", c, int32_t(i));
COM_ERR("Invalid hex character '%c' at position %d\n", c, int32_t(i));
return false;
}
@@ -379,21 +379,21 @@ void common_params_print_info(const common_params & params, bool print_devices)
#else
const char * build_type = " (debug)";
#endif
LOG_TRC("%s: build %d (%s) with %s for %s%s\n", __func__, llama_build_number(), llama_commit(), llama_compiler(), llama_build_target(), build_type);
COM_TRC("%s: build %d (%s) with %s for %s%s\n", __func__, llama_build_number(), llama_commit(), llama_compiler(), llama_build_target(), build_type);
LOG_INF("log_info: verbosity = %d (adjust with the `-lv N` CLI arg)\n", common_log_get_verbosity_thold());
COM_INF("%s: verbosity = %d (adjust with the `-lv N` CLI arg)\n", __func__, common_log_get_verbosity_thold());
// device enumeration creates a primary context on CUDA backends, skip it when the caller does not own any device
if (print_devices) {
LOG_INF("device_info:\n");
COM_TRC("%s", "device_info:\n");
for (size_t i = 0; i < ggml_backend_dev_count(); ++i) {
auto * dev = ggml_backend_dev_get(i);
size_t free, total;
ggml_backend_dev_memory(dev, &free, &total);
LOG_INF(" - %-8s: %s (%zu MiB, %zu MiB free)\n", ggml_backend_dev_name(dev), ggml_backend_dev_description(dev), total / 1024 / 1024, free / 1024 / 1024);
COM_TRC(" - %-8s: %s (%zu MiB, %zu MiB free)\n", ggml_backend_dev_name(dev), ggml_backend_dev_description(dev), total / 1024 / 1024, free / 1024 / 1024);
}
}
LOG_INF("%s\n", common_params_get_system_info(params).c_str());
COM_TRC("%s\n", common_params_get_system_info(params).c_str());
}
std::string common_params_get_system_info(const common_params & params) {
@@ -660,7 +660,7 @@ void string_process_escapes(std::string & input) {
bool string_parse_kv_override(const char * data, std::vector<llama_model_kv_override> & overrides) {
const char * sep = strchr(data, '=');
if (sep == nullptr || sep - data >= 128) {
LOG_ERR("%s: malformed KV override '%s'\n", __func__, data);
COM_ERR("%s: malformed KV override '%s'\n", __func__, data);
return false;
}
llama_model_kv_override kvo;
@@ -683,20 +683,20 @@ bool string_parse_kv_override(const char * data, std::vector<llama_model_kv_over
} else if (std::strcmp(sep, "false") == 0) {
kvo.val_bool = false;
} else {
LOG_ERR("%s: invalid boolean value for KV override '%s'\n", __func__, data);
COM_ERR("%s: invalid boolean value for KV override '%s'\n", __func__, data);
return false;
}
} else if (strncmp(sep, "str:", 4) == 0) {
sep += 4;
kvo.tag = LLAMA_KV_OVERRIDE_TYPE_STR;
if (strlen(sep) > 127) {
LOG_ERR("%s: malformed KV override '%s', value cannot exceed 127 chars\n", __func__, data);
COM_ERR("%s: malformed KV override '%s', value cannot exceed 127 chars\n", __func__, data);
return false;
}
strncpy(kvo.val_str, sep, 127);
kvo.val_str[127] = '\0';
} else {
LOG_ERR("%s: invalid type for KV override '%s'\n", __func__, data);
COM_ERR("%s: invalid type for KV override '%s'\n", __func__, data);
return false;
}
overrides.emplace_back(std::move(kvo));
@@ -1199,8 +1199,8 @@ common_init_result::common_init_result(common_params & params, bool model_only)
auto cparams = common_context_params_to_llama(params);
if (params.fit_params) {
LOG_INF("%s: fitting params to device memory ...\n", __func__);
LOG_INF("%s: (for bugs during this step try to reproduce them with -fit off, or provide --verbose logs if the bug only occurs with -fit on)\n", __func__);
COM_TRC("%s", "fitting params to device memory ...\n");
COM_TRC("%s", "(for bugs during this step try to reproduce them with -fit off, or provide --verbose logs if the bug only occurs with -fit on)\n");
common_fit_params(params.model.path.c_str(), &mparams, &cparams,
params.tensor_split,
params.tensor_buft_overrides.data(),
@@ -1227,7 +1227,7 @@ common_init_result::common_init_result(common_params & params, bool model_only)
llama_adapter_lora_ptr lora;
lora.reset(llama_adapter_lora_init(model, la.path.c_str()));
if (lora == nullptr) {
LOG_ERR("%s: failed to load lora adapter '%s'\n", __func__, la.path.c_str());
COM_ERR("failed to load lora adapter '%s'\n", la.path.c_str());
pimpl->model.reset(model);
return;
}
@@ -1246,14 +1246,14 @@ common_init_result::common_init_result(common_params & params, bool model_only)
common_init_sampler_from_model(model, params.sampling);
if (params.sampling.ignore_eos && llama_vocab_eos(vocab) == LLAMA_TOKEN_NULL) {
LOG_WRN("%s: warning: vocab does not have an EOS token, ignoring --ignore-eos\n", __func__);
COM_WRN("%s", "vocab does not have an EOS token, ignoring --ignore-eos\n");
params.sampling.ignore_eos = false;
}
// initialize once
for (llama_token i = 0; i < llama_vocab_n_tokens(vocab); i++) {
if (llama_vocab_is_eog(vocab, i)) {
LOG_TRC("%s: added %s logit bias = %f\n", __func__, common_token_to_piece(vocab, i).c_str(), -INFINITY);
COM_TRC("added %s logit bias = %f\n", common_token_to_piece(vocab, i).c_str(), -INFINITY);
params.sampling.logit_bias_eog.push_back({i, -INFINITY});
}
}
@@ -1291,7 +1291,7 @@ common_init_result::common_init_result(common_params & params, bool model_only)
llama_context * lctx = llama_init_from_model(model, cparams);
if (lctx == NULL) {
LOG_ERR("%s: failed to create context with model '%s'\n", __func__, params.model.path.c_str());
COM_ERR("failed to create context with model '%s'\n", params.model.path.c_str());
return;
}
@@ -1328,7 +1328,7 @@ common_init_result_ptr common_init_from_params(common_params & params, bool mode
llama_model * model = res->model();
if (model == NULL) {
LOG_ERR("%s: failed to load model '%s'\n", __func__, params.model.path.c_str());
COM_ERR("failed to load model '%s'\n", params.model.path.c_str());
return res;
}
@@ -1338,14 +1338,14 @@ common_init_result_ptr common_init_from_params(common_params & params, bool mode
llama_context * lctx = res->context();
if (lctx == NULL) {
LOG_ERR("%s: failed to create context with model '%s'\n", __func__, params.model.path.c_str());
COM_ERR("failed to create context with model '%s'\n", params.model.path.c_str());
return res;
}
const llama_vocab * vocab = llama_model_get_vocab(model);
if (params.ctx_shift && !llama_memory_can_shift(llama_get_memory(lctx))) {
LOG_WRN("%s: KV cache shifting is not supported for this context, disabling KV cache shifting\n", __func__);
COM_WRN("%s", "KV cache shifting is not supported for this context, disabling KV cache shifting\n");
params.ctx_shift = false;
}
@@ -1374,7 +1374,7 @@ common_init_result_ptr common_init_from_params(common_params & params, bool mode
bool ok = true;
if (llama_vocab_bos(vocab) == LLAMA_TOKEN_NULL) {
LOG_WRN("%s: warning: vocab does not have a BOS token, reranking will not work\n", __func__);
COM_WRN("%s", "vocab does not have a BOS token, reranking will not work\n");
ok = false;
}
@@ -1383,10 +1383,10 @@ common_init_result_ptr common_init_from_params(common_params & params, bool mode
bool has_rerank_prompt = llama_model_chat_template(model, "rerank") != NULL;
if (!has_eos && !has_sep && !has_rerank_prompt) {
LOG_WRN("%s: warning: vocab does not have an EOS token, SEP token, or rerank prompt. Reranking will not work\n", __func__);
COM_WRN("%s", "vocab does not have an EOS token, SEP token, or rerank prompt. Reranking will not work\n");
ok = false;
} else if (!has_eos) {
LOG_WRN("%s: warning: vocab does not have an EOS token, using SEP token as fallback\n", __func__);
COM_WRN("%s", "vocab does not have an EOS token, using SEP token as fallback\n");
}
if (!ok) {
@@ -1399,7 +1399,7 @@ common_init_result_ptr common_init_from_params(common_params & params, bool mode
}
if (params.warmup) {
LOG_INF("%s: warming up the model with an empty run - please wait ... (--no-warmup to disable)\n", __func__);
COM_TRC("%s", "warming up the model with an empty run - please wait ... (--no-warmup to disable)\n");
std::vector<llama_token> tmp;
llama_token bos = llama_vocab_bos(vocab);
@@ -1473,20 +1473,20 @@ common_context_seq_rm_type common_context_can_seq_rm(llama_context * ctx) {
int ret = llama_decode(ctx, llama_batch_get_one(tmp.data(), tmp.size()));
if (ret != 0) {
LOG_ERR("%s: llama_decode() failed: %d\n", __func__, ret);
COM_ERR("llama_decode() failed: %d\n", ret);
res = COMMON_CONTEXT_SEQ_RM_TYPE_NO;
goto done;
}
if (llama_n_rs_seq(ctx) > 0) {
LOG_INF("%s: the context supports bounded partial sequence removal\n", __func__);
COM_TRC("%s", "the context supports bounded partial sequence removal\n");
res = COMMON_CONTEXT_SEQ_RM_TYPE_RS;
goto done;
}
// try to remove the last tokens
if (!llama_memory_seq_rm(mem, 0, 1, -1)) {
LOG_TRC("%s: the context does not support partial sequence removal\n", __func__);
COM_TRC("%s", "the context does not support partial sequence removal\n");
res = COMMON_CONTEXT_SEQ_RM_TYPE_FULL;
goto done;
}
@@ -1803,13 +1803,13 @@ static common_control_vector_data common_control_vector_load_one(const common_co
};
struct gguf_context * ctx_gguf = gguf_init_from_file(load_info.fname.c_str(), meta_gguf_params);
if (!ctx_gguf) {
LOG_ERR("%s: failed to load control vector file from %s\n", __func__, load_info.fname.c_str());
COM_ERR("failed to load control vector file from %s\n", load_info.fname.c_str());
return result;
}
int32_t n_tensors = gguf_get_n_tensors(ctx_gguf);
if (n_tensors == 0) {
LOG_WRN("%s: no direction tensors found in %s\n", __func__, load_info.fname.c_str());
COM_WRN("no direction tensors found in %s\n", load_info.fname.c_str());
}
for (int i = 0; i < n_tensors; i++) {
@@ -1827,23 +1827,23 @@ static common_control_vector_data common_control_vector_load_one(const common_co
}
}
if (layer_idx < 0) {
LOG_ERR("%s: invalid/unparsable direction tensor layer index in %s\n", __func__, load_info.fname.c_str());
COM_ERR("invalid/unparsable direction tensor layer index in %s\n", load_info.fname.c_str());
result.n_embd = -1;
break;
} else if (layer_idx == 0) {
LOG_ERR("%s: invalid (zero) direction tensor layer index in %s\n", __func__, load_info.fname.c_str());
COM_ERR("invalid (zero) direction tensor layer index in %s\n", load_info.fname.c_str());
result.n_embd = -1;
break;
}
struct ggml_tensor * tensor = ggml_get_tensor(ctx, name.c_str());
if (tensor->type != GGML_TYPE_F32) {
LOG_ERR("%s: invalid (non-F32) direction tensor type in %s\n", __func__, load_info.fname.c_str());
COM_ERR("invalid (non-F32) direction tensor type in %s\n", load_info.fname.c_str());
result.n_embd = -1;
break;
}
if (ggml_n_dims(tensor) != 1) {
LOG_ERR("%s: invalid (non-1D) direction tensor shape in %s\n", __func__, load_info.fname.c_str());
COM_ERR("invalid (non-1D) direction tensor shape in %s\n", load_info.fname.c_str());
result.n_embd = -1;
break;
}
@@ -1851,7 +1851,7 @@ static common_control_vector_data common_control_vector_load_one(const common_co
if (result.n_embd == -1) {
result.n_embd = ggml_nelements(tensor);
} else if (ggml_nelements(tensor) != result.n_embd) {
LOG_ERR("%s: direction tensor in %s does not match previous dimensions\n", __func__, load_info.fname.c_str());
COM_ERR("direction tensor in %s does not match previous dimensions\n", load_info.fname.c_str());
result.n_embd = -1;
break;
}
@@ -1868,7 +1868,7 @@ static common_control_vector_data common_control_vector_load_one(const common_co
}
if (result.n_embd == -1) {
LOG_WRN("%s: skipping %s due to invalid direction tensors\n", __func__, load_info.fname.c_str());
COM_WRN("skipping %s due to invalid direction tensors\n", load_info.fname.c_str());
result.data.clear();
}
@@ -1889,7 +1889,7 @@ common_control_vector_data common_control_vector_load(const std::vector<common_c
break;
}
if (result.n_embd != -1 && result.n_embd != cur.n_embd) {
LOG_ERR("%s: control vectors in %s does not match previous dimensions\n", __func__, info.fname.c_str());
COM_ERR("control vectors in %s does not match previous dimensions\n", info.fname.c_str());
result.n_embd = -1;
break;
}
@@ -1905,7 +1905,7 @@ common_control_vector_data common_control_vector_load(const std::vector<common_c
}
if (result.n_embd == -1) {
LOG_ERR("%s: no valid control vector files passed\n", __func__);
COM_ERR("%s", "no valid control vector files passed\n");
result.data.clear();
}
@@ -2016,13 +2016,13 @@ bool common_prompt_batch_decode(
// memory, so we can't just remove the last token from the memory and replay the last token which
// is the reason for this logic.
if (llama_decode(ctx, llama_batch_get_one(const_cast<llama_token*>(all_tokens.data() + offset), n_tokens_before_last))) {
LOG_ERR("%s : failed to eval\n", __func__);
COM_ERR("%s", "failed to eval\n");
return false;
}
n_past += n_tokens_before_last;
llama_state_save_file(ctx, state_path.data(), all_tokens.data(), all_tokens.size());
LOG_INF("saved session before last token to %s, n_new = %zu\n", state_path.data(), all_tokens.size());
COM_INF("saved session before last token to %s, n_new = %zu\n", state_path.data(), all_tokens.size());
llama_token last_token = all_tokens.back();
llama_batch batch = llama_batch_get_one(&last_token, 1);
@@ -2030,13 +2030,13 @@ bool common_prompt_batch_decode(
batch.pos = &pos;
if (llama_decode(ctx, batch)) {
LOG_ERR("%s : failed to eval last token\n", __func__);
COM_ERR("%s", "failed to eval last token\n");
return false;
}
n_past++;
} else {
if (llama_decode(ctx, llama_batch_get_one(const_cast<llama_token*>(all_tokens.data() + offset), n_new))) {
LOG_ERR("%s : failed to eval\n", __func__);
COM_ERR("%s", "failed to eval\n");
return false;
}
n_past += n_new;
+9 -1
View File
@@ -25,6 +25,13 @@
#define DIRECTORY_SEPARATOR '/'
#endif // _WIN32
#define COM_DBG(fmt, ...) LOG_DBG("cmn %12.*s: " fmt, 12, __func__, __VA_ARGS__)
#define COM_TRC(fmt, ...) LOG_TRC("cmn %12.*s: " fmt, 12, __func__, __VA_ARGS__)
#define COM_INF(fmt, ...) LOG_INF("cmn %12.*s: " fmt, 12, __func__, __VA_ARGS__)
#define COM_WRN(fmt, ...) LOG_WRN("cmn %12.*s: " fmt, 12, __func__, __VA_ARGS__)
#define COM_ERR(fmt, ...) LOG_ERR("cmn %12.*s: " fmt, 12, __func__, __VA_ARGS__)
#define COM_CNT(fmt, ...) LOG_CNT("" fmt, __VA_ARGS__)
#define die(msg) do { fputs("error: " msg "\n", stderr); exit(1); } while (0)
#define die_fmt(fmt, ...) do { fprintf(stderr, "error: " fmt "\n", __VA_ARGS__); exit(1); } while (0)
@@ -162,6 +169,7 @@ enum common_speculative_type {
COMMON_SPECULATIVE_TYPE_DRAFT_SIMPLE, // standalone draft model speculative decoding
COMMON_SPECULATIVE_TYPE_DRAFT_EAGLE3, // Eagle3 speculative decoding
COMMON_SPECULATIVE_TYPE_DRAFT_MTP, // Multi-token prediction
COMMON_SPECULATIVE_TYPE_DRAFT_DFLASH, // DFlash speculative decoding
COMMON_SPECULATIVE_TYPE_NGRAM_SIMPLE, // simple self-speculative decoding based on n-grams
COMMON_SPECULATIVE_TYPE_NGRAM_MAP_K, // self-speculative decoding with n-gram keys only
COMMON_SPECULATIVE_TYPE_NGRAM_MAP_K4V, // self-speculative decoding with n-gram keys and 4 m-gram values
@@ -377,7 +385,7 @@ struct common_params_speculative {
uint32_t need_n_rs_seq() const {
bool needs_rs_seq = std::any_of(types.begin(), types.end(), [&](auto t) {
return t == COMMON_SPECULATIVE_TYPE_DRAFT_MTP || t == COMMON_SPECULATIVE_TYPE_DRAFT_EAGLE3;
return t == COMMON_SPECULATIVE_TYPE_DRAFT_MTP || t == COMMON_SPECULATIVE_TYPE_DRAFT_EAGLE3 || t == COMMON_SPECULATIVE_TYPE_DRAFT_DFLASH;
});
return needs_rs_seq ? draft.n_max : 0u;
+1 -1
View File
@@ -233,7 +233,7 @@ static void common_params_fit_impl(
sum_projected_used = dmds_full.back().mb.total();
sum_free = dmds_full.back().total;
sum_projected_free = sum_free - sum_projected_used;
LOG_INF("%s: projected to use %" PRId64 " MiB of host memory vs. %" PRId64 " MiB of total host memory\n",
LOG_TRC("%s: projected to use %" PRId64 " MiB of host memory vs. %" PRId64 " MiB of total host memory\n",
__func__, sum_projected_used/MiB, sum_free/MiB);
if (sum_projected_free >= margins[0]) {
LOG_TRC("%s: will leave %" PRId64 " >= %" PRId64 " MiB of system memory, no changes needed\n",
+28 -6
View File
@@ -11,6 +11,11 @@ struct common_http_url {
std::string path;
};
// bracket an IPv6 literal host for a URL authority (RFC 3986)
static std::string common_http_format_host(const std::string & host) {
return host.find(':') != std::string::npos ? "[" + host + "]" : host;
}
static common_http_url common_http_parse_url(const std::string & url) {
common_http_url parts;
auto scheme_end = url.find("://");
@@ -49,11 +54,28 @@ static common_http_url common_http_parse_url(const std::string & url) {
parts.path = "/";
}
auto colon_pos = parts.host.find(':');
// split the authority into host and optional port, a bracketed IPv6 literal keeps its inner colons (RFC 3986)
std::string port_str;
if (!parts.host.empty() && parts.host.front() == '[') {
auto close = parts.host.find(']');
if (close == std::string::npos) {
throw std::runtime_error("invalid IPv6 URL authority: " + parts.host);
}
auto after = parts.host.substr(close + 1);
if (!after.empty() && after.front() == ':') {
port_str = after.substr(1);
}
parts.host = parts.host.substr(1, close - 1);
} else {
auto colon_pos = parts.host.find(':');
if (colon_pos != std::string::npos) {
port_str = parts.host.substr(colon_pos + 1);
parts.host = parts.host.substr(0, colon_pos);
}
}
if (colon_pos != std::string::npos) {
parts.port = std::stoi(parts.host.substr(colon_pos + 1));
parts.host = parts.host.substr(0, colon_pos);
if (!port_str.empty()) {
parts.port = std::stoi(port_str);
} else if (parts.scheme == "http") {
parts.port = 80;
} else if (parts.scheme == "https") {
@@ -83,7 +105,7 @@ static std::pair<httplib::Client, common_http_url> common_http_client(const std:
}
#endif
httplib::Client cli(parts.scheme + "://" + parts.host + ":" + std::to_string(parts.port));
httplib::Client cli(parts.scheme + "://" + common_http_format_host(parts.host) + ":" + std::to_string(parts.port));
if (!parts.user.empty()) {
cli.set_basic_auth(parts.user, parts.password);
@@ -95,5 +117,5 @@ static std::pair<httplib::Client, common_http_url> common_http_client(const std:
}
static std::string common_http_show_masked_url(const common_http_url & parts) {
return parts.scheme + "://" + (parts.user.empty() ? "" : "****:****@") + parts.host + parts.path;
return parts.scheme + "://" + (parts.user.empty() ? "" : "****:****@") + common_http_format_host(parts.host) + parts.path;
}
+44 -23
View File
@@ -16,22 +16,34 @@ using json = nlohmann::ordered_json;
namespace jinja {
using caps_json_fn = std::function<json()>;
using caps_analyze_fn = std::function<void(bool, value &, value &)>;
using caps_ctx_fn = std::function<void(context &)>;
using caps_analyze_fn = std::function<void(bool, value &, value &, const std::string &)>;
void caps_apply_preserve_reasoning(jinja::context & ctx, bool enabled) {
ctx.set_val("preserve_thinking", mk_val<value_bool>(enabled));
ctx.set_val("clear_thinking", mk_val<value_bool>(!enabled));
ctx.set_val("truncate_history_thinking", mk_val<value_bool>(!enabled));
}
static void caps_try_execute(jinja::program & prog,
const caps_json_fn & messages_fn,
const caps_ctx_fn & ctx_fn,
const caps_json_fn & tools_fn,
const caps_analyze_fn & analyze_fn) {
context ctx;
ctx.is_get_stats = true;
jinja::global_from_json(ctx, json{
{"messages", messages_fn()},
{"tools", tools_fn()},
{"tools", tools_fn ? tools_fn() : json::array()},
{"bos_token", ""},
{"eos_token", ""},
{"add_generation_prompt", true}
}, true);
if (ctx_fn) {
ctx_fn(ctx);
}
auto messages = ctx.get_val("messages");
auto tools = ctx.get_val("tools");
@@ -49,7 +61,7 @@ static void caps_try_execute(jinja::program & prog,
// ignore exceptions during capability analysis
}
analyze_fn(success, messages, tools);
analyze_fn(success, messages, tools, result);
}
// for debugging only
@@ -109,11 +121,9 @@ caps caps_get(jinja::program & prog) {
}
});
},
[&]() {
// tools
return json{nullptr};
},
[&](bool success, value & messages, value &) {
nullptr, // ctx_fn
nullptr, // tools_fn
[&](bool success, value & messages, value &, const std::string &) {
auto & content = messages->at(0)->at("content");
caps_print_stats(content, "messages[0].content");
if (has_op(content, "selectattr") || has_op(content, "array_access")) {
@@ -145,11 +155,9 @@ caps caps_get(jinja::program & prog) {
},
});
},
[&]() {
// tools
return json::array();
},
[&](bool, value & messages, value &) {
nullptr, // ctx_fn
nullptr, // tools_fn
[&](bool, value & messages, value &, const std::string &) {
auto & content = messages->at(0)->at("content");
caps_print_stats(content, "messages[0].content");
if (!content->stats.used) {
@@ -201,6 +209,7 @@ caps caps_get(jinja::program & prog) {
},
});
},
nullptr, // ctx_fn
[&]() {
// tools
return json::array({
@@ -224,7 +233,7 @@ caps caps_get(jinja::program & prog) {
},
});
},
[&](bool success, value & messages, value & tools) {
[&](bool success, value & messages, value & tools, const std::string &) {
if (!success) {
return; // Nothing can be inferred
}
@@ -293,6 +302,7 @@ caps caps_get(jinja::program & prog) {
},
});
},
nullptr, // ctx_fn
[&]() {
// tools
return json::array({
@@ -316,7 +326,7 @@ caps caps_get(jinja::program & prog) {
},
});
},
[&](bool success, value & messages, value & tools) {
[&](bool success, value & messages, value & tools, const std::string &) {
if (!success) {
result.supports_tool_calls = false;
result.supports_tools = false;
@@ -394,6 +404,7 @@ caps caps_get(jinja::program & prog) {
},
});
},
nullptr, // ctx_fn
[&]() {
// tools
return json::array({
@@ -417,7 +428,7 @@ caps caps_get(jinja::program & prog) {
},
});
},
[&](bool success, value & messages, value & /*tools*/) {
[&](bool success, value & messages, value &, const std::string &) {
if (!success) {
result.supports_parallel_tool_calls = false;
return;
@@ -438,11 +449,22 @@ caps caps_get(jinja::program & prog) {
JJ_DEBUG("%s\n", ">>> Running capability check: preserve reasoning");
// case: preserve reasoning content in chat history
const std::string reasoning_placeholder = "<REASONING_CONTENT_PLACEHOLDER>";
caps_try_execute(
prog,
[&]() {
// messages
return json::array({
{
{"role", "user"},
{"content", "User message"}
},
{
{"role", "assistant"},
{"content", "Assistant message"},
// check of reasoning_content deeper in the history, not just the last assistant message
{"reasoning_content", reasoning_placeholder}
},
{
{"role", "user"},
{"content", "User message"}
@@ -458,14 +480,13 @@ caps caps_get(jinja::program & prog) {
},
});
},
[&]() {
// tools
return json::array();
[&](context & ctx) {
caps_apply_preserve_reasoning(ctx, true);
},
[&](bool, value & messages, value &) {
auto & content = messages->at(1)->at("reasoning_content");
caps_print_stats(content, "messages[1].reasoning_content");
if (content->stats.used) {
nullptr, // tools_fn
[&](bool, value &, value &, const std::string & output) {
// note: we cannot use stats here because the reasoning_content may be used for "if" condition test, but not actually outputted in the final result
if (output.find(reasoning_placeholder) != std::string::npos) {
result.supports_preserve_reasoning = true;
}
}
+5 -1
View File
@@ -12,7 +12,9 @@ struct caps {
bool supports_tool_calls = true;
bool supports_system_role = true;
bool supports_parallel_tool_calls = true;
bool supports_preserve_reasoning = false; // support assistant message with reasoning_content
// supports preserve reasoning trace in the full history, not just the last assistant message
bool supports_preserve_reasoning = false;
// one of the 2 content capabilities must be true
bool supports_string_content = true;
@@ -29,4 +31,6 @@ struct caps {
caps caps_get(jinja::program & prog);
void caps_apply_preserve_reasoning(jinja::context & ctx, bool enabled);
} // namespace jinja
+46
View File
@@ -954,4 +954,50 @@ value keyword_argument_expression::execute_impl(context & ctx) {
return mk_val<value_kwarg>(k, v);
}
std::string runtime::debug_dump_program(const program & prog, const std::string & src) {
std::ostringstream oss;
size_t lvl = 0;
context ctx;
ctx.src.reset(new std::string(src));
auto indent = [](size_t lvl) -> std::string {
return std::string(lvl * 2, ' ');
};
ctx.visitor = [&](bool is_leaf, statement * node, std::vector<visitor_pair> children) {
oss << indent(lvl) << node->type() << ":\n";
lvl++;
if (is_leaf) {
const auto & pos = node->pos;
oss << indent(lvl) << "(leaf) at " << get_line_col(src, pos) << " in source:\n";
std::string snippet = peak_source(src, pos);
string_replace_all(snippet, "\n", "\n" + indent(lvl));
oss << indent(lvl) << snippet << "\n";
} else {
for (auto & [label, children_vec] : children) {
oss << indent(lvl) << label << ":\n";
lvl++;
if (children_vec.empty()) {
oss << indent(lvl) << "<empty>\n\n";
} else {
for (auto * child : children_vec) {
if (!child) {
continue;
}
child->visit(ctx);
}
}
lvl--;
}
}
lvl--;
};
for (const auto & stmt : prog.body) {
stmt->visit(ctx);
}
return oss.str();
}
} // namespace jinja
+127
View File
@@ -47,12 +47,19 @@ const T * cast_stmt(const statement_ptr & ptr) {
// not thread-safe
void enable_debug(bool enable);
// for visiting AST nodes
// function signature: void(bool is_leaf, statement * node, pair of <label, children>)
using visitor_pair = std::pair<std::string, std::vector<statement *>>;
using visitor_fn = std::function<void(bool, statement *, std::vector<visitor_pair>)>;
struct context {
std::shared_ptr<std::string> src; // for debugging; use shared_ptr to avoid copying on scope creation
std::time_t current_time; // for functions that need current time
bool is_get_stats = false; // whether to collect stats
visitor_fn visitor;
// src is optional, used for error reporting
context(std::string src = "") : src(std::make_shared<std::string>(std::move(src))) {
env = mk_val<value_object>();
@@ -99,6 +106,15 @@ private:
value_object env;
};
// utils for visiting AST nodes
static std::vector<statement *> stmts_to_ptr(const statements & stmts) {
std::vector<statement *> children;
for (const auto & stmt : stmts) {
children.push_back(stmt.get());
}
return children;
}
/**
* Base class for all nodes in the AST.
*/
@@ -106,6 +122,7 @@ struct statement {
size_t pos; // position in source, for debugging
virtual ~statement() = default;
virtual std::string type() const { return "Statement"; }
virtual void visit(context & ctx) { ctx.visitor(true, this, {}); }
// execute_impl must be overridden by derived classes
virtual value execute_impl(context &) { throw_exec_error(); }
@@ -166,6 +183,13 @@ struct if_statement : public statement {
std::string type() const override { return "If"; }
value execute_impl(context & ctx) override;
void visit(context & ctx) override {
ctx.visitor(false, this, {
{"test", {test.get()}},
{"body", stmts_to_ptr(body)},
{"alternate", stmts_to_ptr(alternate)}
});
}
};
struct identifier;
@@ -190,6 +214,14 @@ struct for_statement : public statement {
std::string type() const override { return "For"; }
value execute_impl(context & ctx) override;
void visit(context & ctx) override {
ctx.visitor(false, this, {
{"loopvar", {loopvar.get()}},
{"iterable", {iterable.get()}},
{"body", stmts_to_ptr(body)},
{"default_block", stmts_to_ptr(default_block)}
});
}
};
struct break_statement : public statement {
@@ -241,6 +273,13 @@ struct set_statement : public statement {
std::string type() const override { return "Set"; }
value execute_impl(context & ctx) override;
void visit(context & ctx) override {
ctx.visitor(false, this, {
{"assignee", {assignee.get()}},
{"value", {val.get()}},
{"body", stmts_to_ptr(body)}
});
}
};
struct macro_statement : public statement {
@@ -256,6 +295,13 @@ struct macro_statement : public statement {
std::string type() const override { return "Macro"; }
value execute_impl(context & ctx) override;
void visit(context & ctx) override {
ctx.visitor(false, this, {
{"name", {name.get()}},
{"args", stmts_to_ptr(args)},
{"body", stmts_to_ptr(body)}
});
}
};
struct comment_statement : public statement {
@@ -289,6 +335,12 @@ struct member_expression : public expression {
}
std::string type() const override { return "MemberExpression"; }
value execute_impl(context & ctx) override;
void visit(context & ctx) override {
ctx.visitor(false, this, {
{"object", {object.get()}},
{"property", {property.get()}}
});
}
};
struct call_expression : public expression {
@@ -302,6 +354,12 @@ struct call_expression : public expression {
}
std::string type() const override { return "CallExpression"; }
value execute_impl(context & ctx) override;
void visit(context & ctx) override {
ctx.visitor(false, this, {
{"callee", {callee.get()}},
{"args", stmts_to_ptr(args)}
});
}
};
/**
@@ -405,6 +463,12 @@ struct binary_expression : public expression {
}
std::string type() const override { return "BinaryExpression"; }
value execute_impl(context & ctx) override;
void visit(context & ctx) override {
ctx.visitor(false, this, {
{"left", {left.get()}},
{"right", {right.get()}}
});
}
};
/**
@@ -431,6 +495,12 @@ struct filter_expression : public expression {
std::string type() const override { return "FilterExpression"; }
value execute_impl(context & ctx) override;
void visit(context & ctx) override {
ctx.visitor(false, this, {
{"operand", {operand.get()}},
{"filter", {filter.get()}}
});
}
};
struct filter_statement : public statement {
@@ -443,6 +513,12 @@ struct filter_statement : public statement {
}
std::string type() const override { return "FilterStatement"; }
value execute_impl(context & ctx) override;
void visit(context & ctx) override {
ctx.visitor(false, this, {
{"filter", {filter.get()}},
{"body", stmts_to_ptr(body)}
});
}
};
/**
@@ -468,6 +544,12 @@ struct select_expression : public expression {
}
return lhs->execute_impl(ctx);
}
void visit(context & ctx) override {
ctx.visitor(false, this, {
{"lhs", {lhs.get()}},
{"test", {test.get()}}
});
}
};
/**
@@ -486,6 +568,12 @@ struct test_expression : public expression {
}
std::string type() const override { return "TestExpression"; }
value execute_impl(context & ctx) override;
void visit(context & ctx) override {
ctx.visitor(false, this, {
{"operand", {operand.get()}},
{"test", {test.get()}}
});
}
};
/**
@@ -501,6 +589,11 @@ struct unary_expression : public expression {
}
std::string type() const override { return "UnaryExpression"; }
value execute_impl(context & ctx) override;
void visit(context & ctx) override {
ctx.visitor(false, this, {
{"argument", {argument.get()}}
});
}
};
struct slice_expression : public expression {
@@ -518,6 +611,13 @@ struct slice_expression : public expression {
[[noreturn]] value execute_impl(context &) override {
throw std::runtime_error("must be handled by MemberExpression");
}
void visit(context & ctx) override {
ctx.visitor(false, this, {
{"start_expr", {start_expr.get()}},
{"stop_expr", {stop_expr.get()}},
{"step_expr", {step_expr.get()}}
});
}
};
struct keyword_argument_expression : public expression {
@@ -531,6 +631,12 @@ struct keyword_argument_expression : public expression {
}
std::string type() const override { return "KeywordArgumentExpression"; }
value execute_impl(context & ctx) override;
void visit(context & ctx) override {
ctx.visitor(false, this, {
{"key", {key.get()}},
{"val", {val.get()}}
});
}
};
struct spread_expression : public expression {
@@ -539,6 +645,11 @@ struct spread_expression : public expression {
chk_type<expression>(this->argument);
}
std::string type() const override { return "SpreadExpression"; }
void visit(context & ctx) override {
ctx.visitor(false, this, {
{"argument", {argument.get()}}
});
}
};
struct call_statement : public statement {
@@ -553,6 +664,13 @@ struct call_statement : public statement {
}
std::string type() const override { return "CallStatement"; }
value execute_impl(context & ctx) override;
void visit(context & ctx) override {
ctx.visitor(false, this, {
{"call", {call.get()}},
{"caller_args", stmts_to_ptr(caller_args)},
{"body", stmts_to_ptr(body)}
});
}
};
struct ternary_expression : public expression {
@@ -575,6 +693,13 @@ struct ternary_expression : public expression {
return false_expr->execute(ctx);
}
}
void visit(context & ctx) override {
ctx.visitor(false, this, {
{"condition", {condition.get()}},
{"true_expr", {true_expr.get()}},
{"false_expr", {false_expr.get()}}
});
}
};
struct raised_exception : public std::exception {
@@ -648,6 +773,8 @@ struct runtime {
}
return parts;
}
static std::string debug_dump_program(const program & prog, const std::string & src);
};
} // namespace jinja
+44
View File
@@ -1108,6 +1108,50 @@ const func_builtins & value_array_t::get_builtins() const {
std::reverse(arr.begin(), arr.end());
return is_val<value_tuple>(val) ? mk_val<value_tuple>(std::move(arr)) : mk_val<value_array>(std::move(arr));
}},
{"min", [](const func_args & args) -> value {
args.ensure_count(1, 4);
args.ensure_vals<value_array>();
value val_case = args.get_kwarg_or_pos("case_sensitive", 1);
value attribute = args.get_kwarg_or_pos("attribute", 2);
if (!attribute->is_undefined()) {
throw not_implemented_exception("min: attribute not implemented");
}
// FIXME: min is currently always case sensitive
(void) val_case;
const auto & arr = args.get_pos(0)->as_array();
if (arr.empty()) {
return mk_val<value_undefined>();
}
value result = arr[0];
for (size_t i = 1; i < arr.size(); ++i) {
if (value_compare(arr[i], result, value_compare_op::lt)) {
result = arr[i];
}
}
return result;
}},
{"max", [](const func_args & args) -> value {
args.ensure_count(1, 4);
args.ensure_vals<value_array>();
value val_case = args.get_kwarg_or_pos("case_sensitive", 1);
value attribute = args.get_kwarg_or_pos("attribute", 2);
if (!attribute->is_undefined()) {
throw not_implemented_exception("max: attribute not implemented");
}
// FIXME: max is currently always case sensitive
(void) val_case;
const auto & arr = args.get_pos(0)->as_array();
if (arr.empty()) {
return mk_val<value_undefined>();
}
value result = arr[0];
for (size_t i = 1; i < arr.size(); ++i) {
if (value_compare(arr[i], result, value_compare_op::gt)) {
result = arr[i];
}
}
return result;
}},
{"unique", array_unique_not_implemented},
};
return builtins;
+29 -4
View File
@@ -7,6 +7,7 @@
#include <fstream>
#include <sstream>
#include <filesystem>
#include <regex>
static std::string rm_leading_dashes(const std::string & str) {
size_t pos = 0;
@@ -16,6 +17,23 @@ static std::string rm_leading_dashes(const std::string & str) {
return str.substr(pos);
}
static std::string canonical_tag(const std::string & tag) {
static const std::regex re_tag("[-.]([A-Z0-9_]+)$", std::regex::icase);
std::smatch m;
if (std::regex_search(tag, m, re_tag)) {
std::string canon = m[1].str();
for (char & c : canon) {
c = (char) std::toupper((unsigned char) c);
}
return canon;
}
std::string upper = tag;
for (char & c : upper) {
c = (char) std::toupper((unsigned char) c);
}
return upper;
}
std::vector<std::string> common_preset::to_args(const std::string & bin_path) const {
std::vector<std::string> args;
@@ -270,11 +288,18 @@ common_presets common_preset_context::load_from_ini(const std::string & path, co
for (auto section : ini_data) {
common_preset preset;
if (section.first.empty()) {
preset.name = COMMON_PRESET_DEFAULT_NAME;
} else {
preset.name = section.first;
std::string section_name = section.first.empty() ? std::string(COMMON_PRESET_DEFAULT_NAME) : section.first;
if (section_name != "*" && section_name != COMMON_PRESET_DEFAULT_NAME) {
auto colon_idx = section_name.rfind(':');
if (colon_idx != std::string::npos) {
std::string tag = section_name.substr(colon_idx + 1);
std::string canon_tag = canonical_tag(tag);
if (canon_tag != tag) {
section_name = section_name.substr(0, colon_idx + 1) + canon_tag;
}
}
}
preset.name = section_name;
LOG_DBG("loading preset: %s\n", preset.name.c_str());
for (const auto & [key, value] : section.second) {
if (key == "version") {
+10 -10
View File
@@ -65,12 +65,12 @@ static void common_reasoning_budget_accept(struct llama_sampler * smpl, llama_to
if (ctx->start_matcher.advance(token)) {
ctx->state = REASONING_BUDGET_COUNTING;
ctx->remaining = ctx->budget;
LOG_INF("reasoning-budget: activated, budget=%d tokens\n", ctx->budget);
COM_TRC("activated, budget=%d tokens\n", ctx->budget);
if (ctx->remaining <= 0) {
ctx->state = REASONING_BUDGET_FORCING;
ctx->force_pos = 0;
LOG_INF("reasoning-budget: budget=0, forcing immediately\n");
COM_TRC("%s", "budget=0, forcing immediately\n");
}
}
break;
@@ -80,7 +80,7 @@ static void common_reasoning_budget_accept(struct llama_sampler * smpl, llama_to
{
if (ctx->end_matcher.advance(token)) {
ctx->state = REASONING_BUDGET_DONE;
LOG_INF("reasoning-budget: deactivated (natural end)\n");
COM_TRC("%s", "deactivated (natural end)\n");
break;
}
@@ -95,7 +95,7 @@ static void common_reasoning_budget_accept(struct llama_sampler * smpl, llama_to
ctx->state = REASONING_BUDGET_FORCING;
ctx->force_pos = 0;
ctx->end_matcher.reset();
LOG_INF("reasoning-budget: UTF-8 complete, now forcing end sequence\n");
COM_TRC("%s", "UTF-8 complete, now forcing end sequence\n");
}
} else if (ctx->state == REASONING_BUDGET_COUNTING) {
ctx->remaining--;
@@ -104,11 +104,11 @@ static void common_reasoning_budget_accept(struct llama_sampler * smpl, llama_to
ctx->state = REASONING_BUDGET_FORCING;
ctx->force_pos = 0;
ctx->end_matcher.reset();
LOG_INF("reasoning-budget: budget exhausted, forcing end sequence\n");
COM_TRC("%s", "budget exhausted, forcing end sequence\n");
} else {
ctx->state = REASONING_BUDGET_WAITING_UTF8;
ctx->end_matcher.reset();
LOG_INF("reasoning-budget: budget exhausted, waiting for UTF-8 completion\n");
COM_TRC("%s", "budget exhausted, waiting for UTF-8 completion\n");
}
}
}
@@ -118,7 +118,7 @@ static void common_reasoning_budget_accept(struct llama_sampler * smpl, llama_to
ctx->force_pos++;
if (ctx->force_pos >= ctx->forced_tokens.size()) {
ctx->state = REASONING_BUDGET_DONE;
LOG_INF("reasoning-budget: forced sequence complete, done\n");
COM_TRC("%s", "forced sequence complete, done\n");
}
break;
case REASONING_BUDGET_DONE:
@@ -128,12 +128,12 @@ static void common_reasoning_budget_accept(struct llama_sampler * smpl, llama_to
ctx->state = REASONING_BUDGET_COUNTING;
ctx->remaining = ctx->budget;
ctx->end_matcher.reset();
LOG_INF("reasoning-budget: re-activated on new start tag, budget=%d tokens\n", ctx->budget);
COM_TRC("re-activated on new start tag, budget=%d tokens\n", ctx->budget);
if (ctx->remaining <= 0) {
ctx->state = REASONING_BUDGET_FORCING;
ctx->force_pos = 0;
LOG_INF("reasoning-budget: budget=0, forcing immediately\n");
COM_TRC("%s", "budget=0, forcing immediately\n");
}
}
break;
@@ -264,7 +264,7 @@ bool common_reasoning_budget_force(struct llama_sampler * smpl) {
ctx->state = REASONING_BUDGET_FORCING;
ctx->force_pos = 0;
ctx->end_matcher.reset();
LOG_INF("reasoning-budget: forced into forcing state (manual transition)\n");
COM_TRC("%s", "forced into forcing state (manual transition)\n");
return true;
}
-204
View File
@@ -1,204 +0,0 @@
#include "regex-partial.h"
#include "common.h"
#include <functional>
#include <optional>
common_regex::common_regex(const std::string & pattern) :
pattern(pattern),
rx(pattern),
rx_reversed_partial(regex_to_reversed_partial_regex(pattern)) {}
common_regex_match common_regex::search(const std::string & input, size_t pos, bool as_match) const {
std::smatch match;
if (pos > input.size()) {
throw std::runtime_error("Position out of bounds");
}
auto start = input.begin() + pos;
auto found = as_match
? std::regex_match(start, input.end(), match, rx)
: std::regex_search(start, input.end(), match, rx);
if (found) {
common_regex_match res;
res.type = COMMON_REGEX_MATCH_TYPE_FULL;
for (size_t i = 0; i < match.size(); ++i) {
auto begin = pos + match.position(i);
res.groups.emplace_back(begin, begin + match.length(i));
}
return res;
}
std::match_results<std::string::const_reverse_iterator> srmatch;
if (std::regex_search(input.rbegin(), input.rend() - pos, srmatch, rx_reversed_partial, std::regex_constants::match_continuous)) {
auto group = srmatch[1].str();
if (group.length() != 0) {
auto it = srmatch[1].second.base();
// auto position = static_cast<size_t>(std::distance(input.begin(), it));
if ((!as_match) || it == input.begin()) {
common_regex_match res;
res.type = COMMON_REGEX_MATCH_TYPE_PARTIAL;
const size_t begin = std::distance(input.begin(), it);
const size_t end = input.size();
if (begin == std::string::npos || end == std::string::npos || begin > end) {
throw std::runtime_error("Invalid range");
}
res.groups.push_back({begin, end});
return res;
}
}
}
return {};
}
/*
Transforms a regex pattern to a partial match pattern that operates on a reversed input string to find partial final matches of the original pattern.
Ideally we'd like to use boost::match_partial (https://beta.boost.org/doc/libs/1_59_0/libs/regex/doc/html/boost_regex/partial_matches.html)
to see if a string ends with a partial regex match, but but it's not in std::regex yet.
Instead, we'll the regex into a partial match regex operating as a full match on the reverse iterators of the input.
- /abcd/ -> ^(dcba|cba|ba|a) -> ^((?:(?:(?:(?:d)?c)?b)?a)
- /a|b/ -> ^(a|b)
- /a*?/ -> error, could match ""
- /a*b/ -> ^((?:b)?a*+) (final repetitions become eager)
- /.*?ab/ -> ^((?:b)?a) (omit .*)
- /a.*?b/ -> ^((?:b)?.*?a) (keep reluctant matches)
- /a(bc)d/ -> ^((?:(?:d)?(?:(?:c)?b))?a)
- /a(bc|de)/ -> ^((?:(?:(?:e)?d)?|(?:(?:c)?b)?)?a)
- /ab{2,4}c/ -> ^cbbb?b?a -> ^((?:(?:(?:(?:(?:c)?b)?b)?b?)?b?)?a)
The regex will match a reversed string fully, and the end of the first (And only) capturing group will indicate the reversed start of the original partial pattern.
All other groups are turned into non-capturing groups, and reluctant quantifiers are ignored.
*/
std::string regex_to_reversed_partial_regex(const std::string & pattern) {
auto it = pattern.begin();
const auto end = pattern.end();
std::function<std::string()> process = [&]() {
std::vector<std::vector<std::string>> alternatives(1);
std::vector<std::string> * sequence = &alternatives.back();
while (it != end) {
if (*it == '[') {
auto start = it;
++it;
while (it != end) {
if ((*it == '\\') && (++it != end)) {
++it;
} else if ((it != end) && (*it == ']')) {
break;
} else {
++it;
}
}
if (it == end) {
throw std::runtime_error("Unmatched '[' in pattern");
}
++it;
sequence->push_back(std::string(start, it));
} else if (*it == '*' || *it == '?' || *it == '+') {
if (sequence->empty()) {
throw std::runtime_error("Quantifier without preceding element");
}
sequence->back() += *it;
auto is_star = *it == '*';
++it;
if (is_star) {
if (it != end && *it == '?') {
++it;
}
}
} else if (*it == '{') {
if (sequence->empty()) {
throw std::runtime_error("Repetition without preceding element");
}
++it;
auto start = it;
while (it != end && *it != '}') {
++it;
}
if (it == end) {
throw std::runtime_error("Unmatched '{' in pattern");
}
auto parts = string_split(std::string(start, it), ",");
++it;
if (parts.size() > 2) {
throw std::runtime_error("Invalid repetition range in pattern");
}
auto parseOptInt = [&](const std::string & s, const std::optional<int> & def = std::nullopt) -> std::optional<int> {
if (s.empty()) {
return def;
}
return std::stoi(s);
};
auto min = parseOptInt(parts[0], 0);
auto max = parts.size() == 1 ? min : parseOptInt(parts[1]);
if (min && max && *max < *min) {
throw std::runtime_error("Invalid repetition range in pattern");
}
// Brutal but... let's repeat at least min times, then ? for the delta between min & max (or * for unbounded)
auto part = sequence->back();
sequence->pop_back();
for (int i = 0; i < *min; i++) {
sequence->push_back(part);
}
if (max) {
for (int i = *min; i < *max; i++) {
sequence->push_back(part + "?");
}
} else {
sequence->push_back(part + "*");
}
} else if (*it == '(') {
++it;
if (it != end && *it == '?' && (it + 1 != end) && *(it + 1) == ':') {
it += 2;
}
auto sub = process();
if (*it != ')') {
throw std::runtime_error("Unmatched '(' in pattern");
}
++it;
auto & part = sequence->emplace_back("(?:");
part += sub;
part += ")";
} else if (*it == ')') {
break;
} else if (*it == '|') {
++it;
alternatives.emplace_back();
sequence = &alternatives.back();
} else if (*it == '\\' && (++it != end)) {
auto str = std::string("\\") + *it;
sequence->push_back(str);
++it;
} else if (it != end) {
sequence->push_back(std::string(1, *it));
++it;
}
}
// /abcd/ -> ^(dcba|cba|ba|a) -> ^((?:(?:(?:d)?c)?b)?a)
// if n(=4) parts, opening n-1(=3) non-capturing groups after the 1 capturing group
// We'll do the outermost capturing group and final .* in the enclosing function.
std::vector<std::string> res_alts;
for (const auto & parts : alternatives) {
auto & res = res_alts.emplace_back();
for (size_t i = 0; i < parts.size() - 1; i++) {
res += "(?:";
}
for (auto it = parts.rbegin(); it != parts.rend(); ++it) {
res += *it;
if (it != parts.rend() - 1) {
res += ")?";
}
}
}
return string_join(res_alts, "|");
};
auto res = process();
if (it != end) {
throw std::runtime_error("Unmatched '(' in pattern");
}
return "^(" + res + ")";
}
-56
View File
@@ -1,56 +0,0 @@
#pragma once
#include <regex>
#include <string>
enum common_regex_match_type {
COMMON_REGEX_MATCH_TYPE_NONE,
COMMON_REGEX_MATCH_TYPE_PARTIAL,
COMMON_REGEX_MATCH_TYPE_FULL,
};
struct common_string_range {
size_t begin;
size_t end;
common_string_range(size_t begin, size_t end) : begin(begin), end(end) {
if (begin > end) {
throw std::runtime_error("Invalid range");
}
}
// prevent default ctor
common_string_range() = delete;
bool empty() const {
return begin == end;
}
bool operator==(const common_string_range & other) const {
return begin == other.begin && end == other.end;
}
};
struct common_regex_match {
common_regex_match_type type = COMMON_REGEX_MATCH_TYPE_NONE;
std::vector<common_string_range> groups;
bool operator==(const common_regex_match & other) const {
return type == other.type && groups == other.groups;
}
bool operator!=(const common_regex_match & other) const {
return !(*this == other);
}
};
class common_regex {
std::string pattern;
std::regex rx;
std::regex rx_reversed_partial;
public:
explicit common_regex(const std::string & pattern);
common_regex_match search(const std::string & input, size_t pos, bool as_match = false) const;
const std::string & str() const { return pattern; }
};
// For testing only (pretty print of failures).
std::string regex_to_reversed_partial_regex(const std::string & pattern);
+371 -65
View File
@@ -18,6 +18,13 @@
#include <map>
#include <cinttypes>
#define SPC_DBG(fmt, ...) LOG_DBG("spec %12.*s: " fmt, 12, __func__, __VA_ARGS__)
#define SPC_TRC(fmt, ...) LOG_TRC("spec %12.*s: " fmt, 12, __func__, __VA_ARGS__)
#define SPC_INF(fmt, ...) LOG_INF("spec %12.*s: " fmt, 12, __func__, __VA_ARGS__)
#define SPC_WRN(fmt, ...) LOG_WRN("spec %12.*s: " fmt, 12, __func__, __VA_ARGS__)
#define SPC_ERR(fmt, ...) LOG_ERR("spec %12.*s: " fmt, 12, __func__, __VA_ARGS__)
#define SPC_CNT(fmt, ...) LOG_CNT("" fmt, __VA_ARGS__)
#define SPEC_VOCAB_MAX_SIZE_DIFFERENCE 128
#define SPEC_VOCAB_CHECK_START_TOKEN_ID 5
@@ -26,6 +33,7 @@ const std::map<std::string, common_speculative_type> common_speculative_type_fro
{"draft-simple", COMMON_SPECULATIVE_TYPE_DRAFT_SIMPLE},
{"draft-eagle3", COMMON_SPECULATIVE_TYPE_DRAFT_EAGLE3},
{"draft-mtp", COMMON_SPECULATIVE_TYPE_DRAFT_MTP},
{"draft-dflash", COMMON_SPECULATIVE_TYPE_DRAFT_DFLASH},
{"ngram-simple", COMMON_SPECULATIVE_TYPE_NGRAM_SIMPLE},
{"ngram-map-k", COMMON_SPECULATIVE_TYPE_NGRAM_MAP_K},
{"ngram-map-k4v", COMMON_SPECULATIVE_TYPE_NGRAM_MAP_K4V},
@@ -60,21 +68,20 @@ static bool common_speculative_are_compatible(
const llama_vocab * vocab_dft = llama_model_get_vocab(model_dft);
const auto vocab_type_tgt = llama_vocab_type(vocab_tgt);
LOG_DBG("%s: vocab_type tgt: %d\n", __func__, vocab_type_tgt);
SPC_DBG("vocab_type tgt: %d\n", vocab_type_tgt);
const auto vocab_type_dft = llama_vocab_type(vocab_dft);
LOG_DBG("%s: vocab_type dft: %d\n", __func__, vocab_type_dft);
SPC_DBG("vocab_type dft: %d\n", vocab_type_dft);
if (vocab_type_tgt != vocab_type_dft) {
LOG_WRN("%s: draft model vocab type must match target model to use speculation but "
"vocab_type_dft = %d while vocab_type_tgt = %d\n", __func__, vocab_type_dft, vocab_type_tgt);
SPC_WRN("draft model vocab type must match target model to use speculation but "
"vocab_type_dft = %d while vocab_type_tgt = %d\n", vocab_type_dft, vocab_type_tgt);
return false;
}
if (llama_vocab_get_add_bos(vocab_tgt) != llama_vocab_get_add_bos(vocab_dft) ||
(llama_vocab_get_add_bos(vocab_tgt) && llama_vocab_bos(vocab_tgt) != llama_vocab_bos(vocab_dft))) {
LOG_WRN("%s: draft model bos tokens must match target model to use speculation. add: %d - %d, id: %d - %d)\n",
__func__,
SPC_WRN("draft model bos tokens must match target model to use speculation. add: %d - %d, id: %d - %d)\n",
llama_vocab_get_add_bos(vocab_tgt), llama_vocab_get_add_bos(vocab_dft),
llama_vocab_bos(vocab_tgt), llama_vocab_bos(vocab_dft));
return false;
@@ -82,8 +89,7 @@ static bool common_speculative_are_compatible(
if (llama_vocab_get_add_eos(vocab_tgt) != llama_vocab_get_add_eos(vocab_dft) ||
(llama_vocab_get_add_eos(vocab_tgt) && llama_vocab_eos(vocab_tgt) != llama_vocab_eos(vocab_dft))) {
LOG_WRN("%s: draft model eos tokens must match target model to use speculation. add: %d - %d, id: %d - %d)\n",
__func__,
SPC_WRN("draft model eos tokens must match target model to use speculation. add: %d - %d, id: %d - %d)\n",
llama_vocab_get_add_eos(vocab_tgt), llama_vocab_get_add_eos(vocab_dft),
llama_vocab_eos(vocab_tgt), llama_vocab_eos(vocab_dft));
return false;
@@ -97,8 +103,8 @@ static bool common_speculative_are_compatible(
: n_vocab_dft - n_vocab_tgt;
if (vocab_diff > SPEC_VOCAB_MAX_SIZE_DIFFERENCE) {
LOG_DBG("%s: draft model vocab must closely match target model to use speculation but ", __func__);
LOG_DBG("target vocab size %d does not match draft vocab size %d - difference %d, max allowed %d\n",
SPC_DBG("draft model vocab must closely match target model to use speculation but "
"target vocab size %d does not match draft vocab size %d - difference %d, max allowed %d\n",
n_vocab_tgt, llama_vocab_n_tokens(vocab_dft), vocab_diff, SPEC_VOCAB_MAX_SIZE_DIFFERENCE);
return false;
}
@@ -108,8 +114,8 @@ static bool common_speculative_are_compatible(
const char * token_text_dft = llama_vocab_get_text(vocab_dft, i);
if (std::strcmp(token_text_tgt, token_text_dft) != 0) {
LOG_DBG("%s: draft model vocab must match target model to use speculation but ", __func__);
LOG_DBG("token %d content differs - target '%s', draft '%s'\n", i,
SPC_DBG("draft model vocab must match target model to use speculation but "
"token %d content differs - target '%s', draft '%s'\n", i,
common_token_to_piece(vocab_tgt, i).c_str(),
common_token_to_piece(vocab_dft, i).c_str());
return false;
@@ -186,9 +192,9 @@ struct common_speculative_impl_draft_simple : public common_speculative_impl {
auto * ctx_dft = this->params.ctx_dft;
auto * ctx_tgt = this->params.ctx_tgt;
LOG_INF("%s: adding speculative implementation 'draft-simple'\n", __func__);
LOG_INF("%s: - n_max=%d, n_min=%d, p_min=%f\n", __func__, this->params.n_max, this->params.n_min, this->params.p_min);
LOG_INF("%s: - gpu_layers=%d, cache_k=%s, cache_v=%s, ctx_tgt=%s, ctx_dft=%s, devices=[%s]\n", __func__,
SPC_TRC("%s", "adding speculative implementation 'draft-simple'\n");
SPC_TRC("- n_max=%d, n_min=%d, p_min=%f\n", this->params.n_max, this->params.n_min, this->params.p_min);
SPC_TRC("- gpu_layers=%d, cache_k=%s, cache_v=%s, ctx_tgt=%s, ctx_dft=%s, devices=[%s]\n",
this->params.n_gpu_layers,
ggml_type_name(this->params.cache_type_k),
ggml_type_name(this->params.cache_type_v),
@@ -228,16 +234,16 @@ struct common_speculative_impl_draft_simple : public common_speculative_impl {
}
const bool vocab_cmpt = common_speculative_are_compatible(llama_get_model(ctx_tgt), llama_get_model(ctx_dft));
LOG_DBG("%s: vocab_cmpt = %d\n", __func__, vocab_cmpt);
SPC_DBG("vocab_cmpt = %d\n", vocab_cmpt);
if (!vocab_cmpt) {
LOG_ERR("%s: the target and draft vocabs are not compatible\n", __func__);
SPC_ERR("%s", "the target and draft vocabs are not compatible\n");
throw std::runtime_error("draft model vocab type must match target model to use speculation");
}
if (n_seq != llama_n_seq_max(ctx_dft)) {
LOG_ERR("%s: n_seq mismatch: %d != %d\n", __func__, n_seq, llama_n_seq_max(ctx_dft));
SPC_ERR("n_seq mismatch: %d != %d\n", n_seq, llama_n_seq_max(ctx_dft));
throw std::runtime_error("the draft model number of sequences is incompatible with the speculative n_seq");
}
@@ -257,7 +263,7 @@ struct common_speculative_impl_draft_simple : public common_speculative_impl {
const int ret = llama_decode(ctx_dft, batch);
if (ret != 0) {
LOG_ERR("%s: failed to decode draft batch, ret = %d\n", __func__, ret);
SPC_ERR("failed to decode draft batch, ret = %d\n", ret);
return false;
}
@@ -290,7 +296,7 @@ struct common_speculative_impl_draft_simple : public common_speculative_impl {
int ret = llama_decode(ctx_dft, batch);
if (ret != 0) {
LOG_WRN("%s: llama_decode returned %d\n", __func__, ret);
SPC_ERR("llama_decode returned %d\n", ret);
return;
}
@@ -314,7 +320,7 @@ struct common_speculative_impl_draft_simple : public common_speculative_impl {
const auto * cur_p = common_sampler_get_candidates(smpl, true);
for (int k = 0; k < std::min(3, (int) cur_p->size); ++k) {
LOG_DBG(" - seq_id %d, draft candidate %3d, pos %3d: %6d (%8.3f) '%s'\n",
SPC_DBG(" - seq_id %d, draft candidate %3d, pos %3d: %6d (%8.3f) '%s'\n",
seq_id, k, i, cur_p->data[k].id, cur_p->data[k].p,
common_token_to_piece(ctx_dft, cur_p->data[k].id).c_str());
}
@@ -354,7 +360,7 @@ struct common_speculative_impl_draft_simple : public common_speculative_impl {
// evaluate the drafted tokens on the draft model
ret = llama_decode(ctx_dft, batch);
if (ret != 0) {
LOG_WRN("%s: llama_decode[%d] returned %d\n", __func__, i, ret);
SPC_ERR("llama_decode[%d] returned %d\n", i, ret);
break;
}
@@ -449,8 +455,8 @@ struct common_speculative_impl_draft_eagle3 : public common_speculative_impl {
: common_speculative_impl(COMMON_SPECULATIVE_TYPE_DRAFT_EAGLE3, n_seq)
, params(params.draft)
{
LOG_INF("%s: adding speculative implementation 'draft-eagle3'\n", __func__);
LOG_INF("%s: - n_max=%d, n_min=%d, p_min=%f, backend_sampling=%d\n", __func__, params.draft.n_max, params.draft.n_min, params.draft.p_min, (int) params.draft.backend_sampling);
SPC_TRC("%s", "adding speculative implementation 'draft-eagle3'\n");
SPC_TRC("- n_max=%d, n_min=%d, p_min=%f, backend_sampling=%d\n", params.draft.n_max, params.draft.n_min, params.draft.p_min, (int) params.draft.backend_sampling);
auto * ctx_tgt = this->params.ctx_tgt;
auto * ctx_dft = this->params.ctx_dft;
@@ -493,7 +499,7 @@ struct common_speculative_impl_draft_eagle3 : public common_speculative_impl {
llama_sampler_chain_add(chain, llama_sampler_init_top_k(10));
if (!llama_set_sampler(ctx_dft, seq_id, chain)) {
LOG_WRN("%s: backend offload failed for seq_id=%d; using CPU sampler\n", __func__, (int) seq_id);
SPC_WRN("backend offload failed for seq_id=%d; using CPU sampler\n", (int) seq_id);
llama_sampler_free(chain);
chain = nullptr;
}
@@ -548,9 +554,9 @@ struct common_speculative_impl_draft_eagle3 : public common_speculative_impl {
auto * ctx_dft = this->params.ctx_dft;
const llama_pos pos_max = llama_memory_seq_pos_max(llama_get_memory(ctx_dft), seq_id);
if (pos_max < N - 2) {
LOG_WRN("%s: ctx_dft pos_max=%d < N-2=%d — process() did not run on every prefill ubatch. "
SPC_WRN("ctx_dft pos_max=%d < N-2=%d — process() did not run on every prefill ubatch. "
"Drafts may degrade.\n",
__func__, (int) pos_max, N - 2);
(int) pos_max, N - 2);
}
}
@@ -621,8 +627,8 @@ struct common_speculative_impl_draft_eagle3 : public common_speculative_impl {
};
const int32_t rc = llama_encode(ctx_dft, enc_batch);
if (rc != 0) {
LOG_ERR("%s: llama_encode(ctx_dft) failed rc=%d (n_tokens=%d, offset=%d)\n",
__func__, rc, (int) n_chunk, (int) i);
SPC_ERR("llama_encode(ctx_dft) failed rc=%d (n_tokens=%d, offset=%d)\n",
rc, (int) n_chunk, (int) i);
return false;
}
@@ -692,8 +698,8 @@ struct common_speculative_impl_draft_eagle3 : public common_speculative_impl {
if (batch.n_tokens > 0) {
const int32_t rc = llama_decode(ctx_dft, batch);
if (rc != 0) {
LOG_ERR("%s: llama_decode(ctx_dft) failed rc=%d (n_tokens=%d, ubatch_pos[0]=%d)\n",
__func__, rc, (int) batch.n_tokens, (int) batch_in.pos[0]);
SPC_ERR("llama_decode(ctx_dft) failed rc=%d (n_tokens=%d, ubatch_pos[0]=%d)\n",
rc, (int) batch.n_tokens, (int) batch_in.pos[0]);
return false;
}
}
@@ -744,7 +750,7 @@ struct common_speculative_impl_draft_eagle3 : public common_speculative_impl {
int ret = llama_decode(ctx_dft, batch);
if (ret != 0) {
LOG_WRN("%s: llama_decode returned %d\n", __func__, ret);
SPC_ERR("llama_decode returned %d\n", ret);
return;
}
@@ -770,7 +776,7 @@ struct common_speculative_impl_draft_eagle3 : public common_speculative_impl {
const auto * cur_p = common_sampler_get_candidates(smpl, true);
for (int k = 0; k < std::min(3, (int) cur_p->size); ++k) {
LOG_DBG(" - seq_id %d, draft candidate %3d, pos %3d: %6d (%8.3f) '%s'\n",
SPC_DBG(" - seq_id %d, draft candidate %3d, pos %3d: %6d (%8.3f) '%s'\n",
seq_id, k, i, cur_p->data[k].id, cur_p->data[k].p,
common_token_to_piece(ctx_dft, cur_p->data[k].id).c_str());
}
@@ -809,7 +815,7 @@ struct common_speculative_impl_draft_eagle3 : public common_speculative_impl {
ret = llama_decode(ctx_dft, batch);
if (ret != 0) {
LOG_WRN("%s: llama_decode[%d] returned %d\n", __func__, i, ret);
SPC_ERR("llama_decode[%d] returned %d\n", i, ret);
break;
}
@@ -893,6 +899,296 @@ struct common_speculative_impl_draft_eagle3 : public common_speculative_impl {
}
};
// DFlash: block-diffusion drafting with a draft-side KV cache injection
struct common_speculative_impl_draft_dflash : public common_speculative_impl {
common_params_speculative_draft params;
llama_batch batch; // noise tokens
llama_batch batch_inject; // target features for KV cache injection
std::vector<common_sampler_ptr> smpls;
int32_t n_embd_dec = 0; // draft hidden size
int32_t n_embd_enc = 0; // target_layer_ids_n * target_hidden_size
int32_t n_embd_tgt = 0; // target model hidden size
int32_t block_size = 0;
llama_token mask_token_id = 0;
const int32_t * target_layer_ids = nullptr; // model_dft's extract layer indices
uint32_t target_layer_ids_n = 0;
// scratch buffer for concatenated target features [n_tokens, n_embd_enc]
std::vector<float> features_buf;
common_speculative_impl_draft_dflash(const common_params_speculative & params, uint32_t n_seq)
: common_speculative_impl(COMMON_SPECULATIVE_TYPE_DRAFT_DFLASH, n_seq)
, params(params.draft)
{
auto * ctx_tgt = this->params.ctx_tgt;
auto * ctx_dft = this->params.ctx_dft;
GGML_ASSERT(ctx_tgt && ctx_dft && "DFlash requires ctx_tgt and ctx_dft to be set");
const llama_model * model_dft = llama_get_model(ctx_dft);
const llama_model * model_tgt = llama_get_model(ctx_tgt);
target_layer_ids = llama_model_target_layer_ids (model_dft);
target_layer_ids_n = llama_model_target_layer_ids_n(model_dft);
GGML_ASSERT(target_layer_ids_n > 0 && "DFlash model has no target_layer_ids");
n_embd_tgt = llama_model_n_embd(model_tgt);
n_embd_dec = llama_model_n_embd(model_dft);
n_embd_enc = (int32_t) target_layer_ids_n * n_embd_tgt;
// read the trained block size from the dflash.block_size metadata key
block_size = 16;
{
char buf[32] = {};
if (llama_model_meta_val_str(model_dft, "dflash.block_size", buf, sizeof(buf)) >= 0) {
block_size = std::atoi(buf);
}
}
mask_token_id = llama_vocab_mask(llama_model_get_vocab(model_dft));
LOG_INF("%s: adding speculative implementation 'draft-dflash'\n", __func__);
LOG_INF("%s: - n_max=%d, n_min=%d, p_min=%.2f\n", __func__, this->params.n_max, this->params.n_min, this->params.p_min);
LOG_INF("%s: - block_size=%d, mask_token_id=%d, n_extract=%u\n", __func__, block_size, mask_token_id, target_layer_ids_n);
// DFlash input is [id_last, <mask> * (block_size-1)], so it can draft at most block_size-1 tokens per step
if (this->params.n_max > block_size - 1) {
LOG_WRN("%s: requested draft size %d exceeds the trained DFlash block size %d -- clamping to %d draft tokens per step\n",
__func__, this->params.n_max, block_size - 1, block_size - 1);
this->params.n_max = block_size - 1;
}
batch = llama_batch_init(llama_n_batch(ctx_dft), 0, n_seq);
batch_inject = llama_batch_init(llama_n_batch(ctx_dft), n_embd_dec, n_seq);
smpls.resize(n_seq);
for (auto & s : smpls) {
common_params_sampling sparams;
sparams.no_perf = false;
sparams.top_k = 1;
sparams.samplers = { COMMON_SAMPLER_TYPE_TOP_K };
s.reset(common_sampler_init(model_dft, sparams));
}
// turn on extraction of the target layers' input embeddings
for (uint32_t k = 0; k < target_layer_ids_n; ++k) {
llama_set_embeddings_layer_inp(ctx_tgt, (uint32_t) target_layer_ids[k], true);
}
llama_set_embeddings_nextn(ctx_dft, true, /*masked*/ true);
llama_set_causal_attn(ctx_dft, false); // DFlash needs non-causal attention
}
~common_speculative_impl_draft_dflash() override {
llama_batch_free(batch);
llama_batch_free(batch_inject);
}
void begin(llama_seq_id seq_id, const llama_tokens & prompt) override {
if (seq_id < 0 || seq_id >= (llama_seq_id) n_seq) {
return;
}
const int32_t N = (int32_t) prompt.size();
if (N <= 0) {
return;
}
const llama_pos pos_max = llama_memory_seq_pos_max(llama_get_memory(params.ctx_dft), seq_id);
if (pos_max < N - 1) {
LOG_WRN("%s: ctx_dft pos_max=%d < N-1=%d - process() did not run on every prefill ubatch. "
"Drafts may degrade.\n",
__func__, (int) pos_max, N - 1);
}
}
bool process(const llama_batch & batch_in) override {
if (batch_in.n_tokens <= 0) {
return true;
}
if (batch_in.token == nullptr || batch_in.embd != nullptr) {
return true;
}
const int32_t n_tokens = batch_in.n_tokens;
// per-seq inclusive batch range (assumes each seq's tokens are contiguous in the batch)
std::vector<int32_t> i_batch_beg(n_seq, -1);
std::vector<int32_t> i_batch_end(n_seq, -1);
for (int32_t k = 0; k < n_tokens; ++k) {
GGML_ASSERT(batch_in.n_seq_id[k] == 1);
const llama_seq_id seq_id = batch_in.seq_id[k][0];
if (seq_id < 0 || seq_id >= (llama_seq_id) n_seq) {
continue;
}
i_batch_end[seq_id] = k;
if (i_batch_beg[seq_id] < 0) {
i_batch_beg[seq_id] = k;
}
}
auto * ctx_tgt = this->params.ctx_tgt;
auto * ctx_dft = this->params.ctx_dft;
const int32_t n_ubatch = (int32_t) llama_n_ubatch(ctx_dft);
for (llama_seq_id seq_id = 0; seq_id < (llama_seq_id) n_seq; ++seq_id) {
if (i_batch_beg[seq_id] < 0) {
continue;
}
const int32_t n_rows = i_batch_end[seq_id] - i_batch_beg[seq_id] + 1;
for (int32_t offset = 0; offset < n_rows; offset += n_ubatch) {
const int32_t n_chunk = std::min(n_ubatch, n_rows - offset);
// gather this chunk's target features, interleaved by extract layer
features_buf.resize((size_t) n_chunk * n_embd_enc);
for (uint32_t k = 0; k < target_layer_ids_n; ++k) {
const float * layer = llama_get_embeddings_layer_inp(ctx_tgt, (uint32_t) target_layer_ids[k]);
if (!layer) {
GGML_ABORT("DFlash: target layer %d input not extracted.", target_layer_ids[k]);
}
for (int32_t i = 0; i < n_chunk; ++i) {
float * dst = features_buf.data() + (size_t) i * n_embd_enc + k * (size_t) n_embd_tgt;
const float * src = layer + (size_t) (i_batch_beg[seq_id] + offset + i) * n_embd_tgt;
std::memcpy(dst, src, (size_t) n_embd_tgt * sizeof(float));
}
}
// fuse extracted features through DFlash encoder
llama_batch enc_batch = {
/*.n_tokens =*/ n_chunk,
/*.token =*/ nullptr,
/*.embd =*/ features_buf.data(),
/*.pos =*/ nullptr,
/*.n_seq_id =*/ nullptr,
/*.seq_id =*/ nullptr,
/*.logits =*/ nullptr,
};
int32_t rc = llama_encode(ctx_dft, enc_batch);
if (rc != 0) {
LOG_ERR("%s: llama_encode(ctx_dft) failed rc=%d (n_tokens=%d, offset=%d)\n",
__func__, rc, (int) n_chunk, (int) offset);
return false;
}
const float * inp_g = llama_get_embeddings_nextn(ctx_dft);
GGML_ASSERT(inp_g && "DFlash encoder produced no output.");
// inject the DFlash decoder K/V cache at the tokens' target positions
batch_inject.n_tokens = n_chunk;
std::memcpy(batch_inject.embd, inp_g, (size_t) n_chunk * n_embd_dec * sizeof(float));
for (int32_t i = 0; i < n_chunk; ++i) {
batch_inject.pos[i] = batch_in.pos[i_batch_beg[seq_id] + offset + i];
batch_inject.n_seq_id[i] = 1;
batch_inject.seq_id[i][0] = seq_id;
batch_inject.logits[i] = false;
}
rc = llama_decode(ctx_dft, batch_inject);
if (rc != 0) {
LOG_ERR("%s: llama_decode(ctx_dft) failed rc=%d (n_tokens=%d, offset=%d)\n",
__func__, rc, (int) n_chunk, (int) offset);
return false;
}
}
}
return true;
}
void draft(common_speculative_draft_params_vec & dparams) override {
auto & ctx_dft = params.ctx_dft;
common_batch_clear(batch);
// build one batch holding every drafting sequence's noise block into a single decode)
// record where each block starts and its size
std::vector<int32_t> i_block_beg(n_seq, -1);
std::vector<int32_t> n_block (n_seq, 0);
for (llama_seq_id seq_id = 0; seq_id < (llama_seq_id) n_seq; ++seq_id) {
auto & dp = dparams[seq_id];
if (!dp.drafting) {
continue;
}
common_sampler_reset(smpls[seq_id].get());
const int32_t n = (int32_t) dp.n_past;
int32_t n_draft = params.n_max;
if (dp.n_max > 0) {
n_draft = std::min(n_draft, dp.n_max);
}
const int32_t n_block_tokens = n_draft + 1; // id_last + n_draft * <mask>
i_block_beg[seq_id] = batch.n_tokens;
n_block [seq_id] = n_block_tokens;
for (int32_t i = 0; i < n_block_tokens; ++i) {
common_batch_add(batch, i == 0 ? dp.id_last : mask_token_id, n + i, { seq_id }, true);
}
}
if (batch.n_tokens == 0) {
return;
}
// decode all sequence's noise block in a single batch
int ret = llama_decode(ctx_dft, batch);
if (ret != 0) {
LOG_WRN("%s: llama_decode returned %d\n", __func__, ret);
return;
}
for (llama_seq_id seq_id = 0; seq_id < (llama_seq_id) n_seq; ++seq_id) {
if (i_block_beg[seq_id] < 0) {
continue;
}
auto & dp = dparams[seq_id];
const int32_t beg = i_block_beg[seq_id];
const int32_t n_block_tokens = n_block[seq_id];
auto * smpl = smpls[seq_id].get();
auto & result = *dp.result;
// greedily read the predicted block at this sequence's noise positions 1..n_block_tokens-1
for (int32_t i = 1; i < n_block_tokens; ++i) {
common_sampler_sample(smpl, ctx_dft, beg + i, true);
const auto * cur_p = common_sampler_get_candidates(smpl, true);
for (int k = 0; k < std::min(3, (int) cur_p->size); ++k) {
LOG_DBG(" - seq_id %d, draft candidate %3d, pos %3d: %6d (%8.3f) '%s'\n",
seq_id, k, i - 1, cur_p->data[k].id, cur_p->data[k].p,
common_token_to_piece(ctx_dft, cur_p->data[k].id).c_str());
}
const llama_token id = cur_p->data[0].id;
common_sampler_accept(smpl, id, true);
result.push_back(id);
}
}
}
void accept(llama_seq_id /*seq_id*/, uint16_t /*n_accepted*/, bool /*is_other*/) override {
// noop
}
bool need_embd() const override {
return false;
}
};
struct common_speculative_impl_draft_mtp : public common_speculative_impl {
common_params_speculative_draft params; // reuses the draft-model params slot (ctx_tgt/ctx_dft)
@@ -942,9 +1238,9 @@ struct common_speculative_impl_draft_mtp : public common_speculative_impl {
"MTP input row width must match the target h_nextn width");
n_mtp_layers = std::max(1, (int) llama_model_n_layer_nextn(llama_get_model(ctx_dft)));
LOG_INF("%s: adding speculative implementation 'draft-mtp'\n", __func__);
LOG_INF("%s: - n_max=%d, n_min=%d, p_min=%.2f, n_embd=%d, backend_sampling=%d\n", __func__, this->params.n_max, this->params.n_min, this->params.p_min, n_embd, (int) this->params.backend_sampling);
LOG_INF("%s: - gpu_layers=%d, cache_k=%s, cache_v=%s, ctx_tgt=%s, ctx_dft=%s, devices=[%s]\n", __func__,
SPC_TRC("%s", "adding speculative implementation 'draft-mtp'\n");
SPC_TRC("- n_max=%d, n_min=%d, p_min=%.2f, n_embd=%d, backend_sampling=%d\n", this->params.n_max, this->params.n_min, this->params.p_min, n_embd, (int) this->params.backend_sampling);
SPC_TRC("- gpu_layers=%d, cache_k=%s, cache_v=%s, ctx_tgt=%s, ctx_dft=%s, devices=[%s]\n",
this->params.n_gpu_layers,
ggml_type_name(this->params.cache_type_k),
ggml_type_name(this->params.cache_type_v),
@@ -975,7 +1271,7 @@ struct common_speculative_impl_draft_mtp : public common_speculative_impl {
llama_sampler_chain_add(chain, llama_sampler_init_top_k(10));
if (!llama_set_sampler(ctx_dft, seq_id, chain)) {
LOG_WRN("%s: backend offload failed for seq_id=%d; using CPU sampler\n", __func__, (int) seq_id);
SPC_WRN("backend offload failed for seq_id=%d; using CPU sampler\n", (int) seq_id);
llama_sampler_free(chain);
chain = nullptr;
}
@@ -1038,11 +1334,11 @@ struct common_speculative_impl_draft_mtp : public common_speculative_impl {
const llama_pos pos_max = llama_memory_seq_pos_max(llama_get_memory(ctx_dft), seq_id);
if (pos_max < N - 1 && !is_mem_shared) {
LOG_WRN("%s: ctx_dft pos_max=%d < N-1=%d - "
SPC_WRN("ctx_dft pos_max=%d < N-1=%d - "
"process() hook may not have run on every prefill ubatch "
"(need_embd / logits=1 on every prompt position?). "
"Drafts may degrade.\n",
__func__, (int) pos_max, N - 1);
(int) pos_max, N - 1);
}
}
@@ -1128,8 +1424,8 @@ struct common_speculative_impl_draft_mtp : public common_speculative_impl {
const int32_t rc = llama_decode(ctx_dft, batch);
if (rc != 0) {
LOG_ERR("%s: llama_decode(ctx_dft) head=%d failed rc=%d (pos=%d)\n",
__func__, head, (int) rc, (int) batch_in.pos[0]);
SPC_ERR("llama_decode(ctx_dft) head=%d failed rc=%d (pos=%d)\n",
head, (int) rc, (int) batch_in.pos[0]);
ok = false;
break;
}
@@ -1217,7 +1513,7 @@ struct common_speculative_impl_draft_mtp : public common_speculative_impl {
int ret = llama_decode(ctx_dft, batch);
if (ret != 0) {
LOG_WRN("%s: llama_decode[%d] returned %d\n", __func__, i, ret);
SPC_ERR("llama_decode[%d] returned %d\n", i, ret);
break;
}
@@ -1239,7 +1535,7 @@ struct common_speculative_impl_draft_mtp : public common_speculative_impl {
const auto * cur_p = common_sampler_get_candidates(smpl, true);
for (int k = 0; k < std::min(3, (int) cur_p->size); ++k) {
LOG_DBG(" - seq_id %d, draft candidate %3d, pos %3d: %6d (%8.3f) '%s'\n",
SPC_DBG(" - seq_id %d, draft candidate %3d, pos %3d: %6d (%8.3f) '%s'\n",
seq_id, k, i, cur_p->data[k].id, cur_p->data[k].p,
common_token_to_piece(ctx_dft, cur_p->data[k].id).c_str());
}
@@ -1353,8 +1649,8 @@ struct common_speculative_impl_ngram_simple : public common_speculative_impl {
, params(params.ngram_simple)
, config(config)
{
LOG_INF("%s: adding speculative implementation 'ngram-simple'\n", __func__);
LOG_INF("%s: - size_n=%d, size_m=%d, min_hits=%d\n", __func__,
SPC_TRC("%s", "adding speculative implementation 'ngram-simple'\n");
SPC_TRC("- size_n=%d, size_m=%d, min_hits=%d\n",
this->params.size_n, this->params.size_m, this->params.min_hits);
}
@@ -1403,8 +1699,8 @@ struct common_speculative_impl_ngram_map_k : public common_speculative_impl {
this->config.push_back(config);
}
LOG_INF("%s: adding speculative implementation '%s'\n", __func__, common_speculative_type_to_str(this->type).c_str());
LOG_INF("%s: - size_key=%d, size_value=%d, key_only=%d, min_hits=%d\n", __func__,
SPC_TRC("adding speculative implementation '%s'\n", common_speculative_type_to_str(this->type).c_str());
SPC_TRC("- size_key=%d, size_value=%d, key_only=%d, min_hits=%d\n",
config.size_key, config.size_value, config.key_only, config.min_hits);
}
@@ -1478,15 +1774,15 @@ struct common_speculative_impl_ngram_mod : public common_speculative_impl {
, verbose(std::getenv("LLAMA_TRACE") != nullptr) {
static_assert(sizeof(llama_token) == sizeof(common_ngram_mod::entry_t));
LOG_INF("%s: adding speculative implementation 'ngram-mod'\n", __func__);
LOG_INF("%s: - n_match=%d, n_max=%d, n_min=%d\n", __func__,
SPC_TRC("%s", "adding speculative implementation 'ngram-mod'\n");
SPC_TRC("- n_match=%d, n_max=%d, n_min=%d\n",
this->params.n_match, this->params.n_max, this->params.n_min);
LOG_INF("%s: - mod size=%zu (%.3f MB)\n", __func__,
SPC_TRC("- mod size=%zu (%.3f MB)\n",
mod.size(), (float)(mod.size_bytes())/1024/1024);
if (this->params.n_match < 16) {
LOG_WRN("%s: ngram_mod n_match=%d is too small - poor quality is possible, "
"see: https://github.com/ggml-org/llama.cpp/pull/19164\n", __func__, this->params.n_match);
SPC_WRN("ngram_mod n_match=%d is too small - poor quality is possible, "
"see: https://github.com/ggml-org/llama.cpp/pull/19164\n", this->params.n_match);
}
sinfos.resize(n_seq);
@@ -1510,11 +1806,11 @@ struct common_speculative_impl_ngram_mod : public common_speculative_impl {
sinfo.i_last = prompt.size() - n;
const double f = (double)mod.get_used() / (double)mod.size();
LOG_INF("%s: ngram_mod occupancy = %zu/%zu (%.2f)\n", __func__, mod.get_used(), mod.size(), f);
SPC_TRC("ngram_mod occupancy = %zu/%zu (%.2f)\n", mod.get_used(), mod.size(), f);
constexpr double f_thold = 0.25;
if (f > f_thold) {
LOG_WRN("%s: ngram_mod occupancy %.2f exceeds threshold (%.2f) - resetting\n", __func__, f, f_thold);
SPC_WRN("ngram_mod occupancy %.2f exceeds threshold (%.2f) - resetting\n", f, f_thold);
mod.reset();
}
@@ -1608,7 +1904,7 @@ struct common_speculative_impl_ngram_mod : public common_speculative_impl {
sinfo.n_low++;
if (sinfo.n_low >= 5) {
if (verbose) {
LOG_WRN("%s: low acceptance streak (%d) - resetting ngram_mod\n", __func__, sinfo.n_low);
SPC_TRC("low acceptance streak (%d) - resetting ngram_mod\n", sinfo.n_low);
}
mod.reset();
@@ -1658,8 +1954,8 @@ struct common_speculative_impl_ngram_cache : public common_speculative_impl {
, save_dynamic(save_dynamic)
, save_static(save_static)
{
LOG_INF("%s: adding speculative implementation 'ngram-cache'\n", __func__);
LOG_INF("%s: - n_draft=%d, cache_static=%s, cache_dynamic=%s\n", __func__,
SPC_TRC("%s", "adding speculative implementation 'ngram-cache'\n");
SPC_TRC("- n_draft=%d, cache_static=%s, cache_dynamic=%s\n",
n_draft,
path_static.empty() ? "none" : path_static.c_str(),
path_dynamic.empty() ? "none" : path_dynamic.c_str());
@@ -1674,7 +1970,7 @@ struct common_speculative_impl_ngram_cache : public common_speculative_impl {
sinfo.ngram_cache_static = ngram_cache_static;
}
} catch (...) {
LOG_ERR("failed to open static lookup cache: %s", path_static.c_str());
SPC_ERR("failed to open static lookup cache: %s", path_static.c_str());
GGML_ABORT("Couldn't read static lookup cache");
}
}
@@ -1687,7 +1983,7 @@ struct common_speculative_impl_ngram_cache : public common_speculative_impl {
sinfo.ngram_cache_dynamic = ngram_cache_dynamic;
}
} catch (...) {
LOG_ERR("failed to open dynamic lookup cache: %s", path_dynamic.c_str());
SPC_ERR("failed to open dynamic lookup cache: %s", path_dynamic.c_str());
GGML_ABORT("Couldn't read dynamic lookup cache");
}
}
@@ -1836,6 +2132,7 @@ std::string common_speculative_type_to_str(common_speculative_type type) {
case COMMON_SPECULATIVE_TYPE_DRAFT_SIMPLE: return "draft-simple";
case COMMON_SPECULATIVE_TYPE_DRAFT_EAGLE3: return "draft-eagle3";
case COMMON_SPECULATIVE_TYPE_DRAFT_MTP: return "draft-mtp";
case COMMON_SPECULATIVE_TYPE_DRAFT_DFLASH: return "draft-dflash";
case COMMON_SPECULATIVE_TYPE_NGRAM_SIMPLE: return "ngram-simple";
case COMMON_SPECULATIVE_TYPE_NGRAM_MAP_K: return "ngram-map-k";
case COMMON_SPECULATIVE_TYPE_NGRAM_MAP_K4V: return "ngram-map-k4v";
@@ -1888,6 +2185,7 @@ int32_t common_speculative_n_max(const common_params_speculative * spec) {
case COMMON_SPECULATIVE_TYPE_DRAFT_SIMPLE:
case COMMON_SPECULATIVE_TYPE_DRAFT_EAGLE3:
case COMMON_SPECULATIVE_TYPE_DRAFT_MTP:
case COMMON_SPECULATIVE_TYPE_DRAFT_DFLASH:
n_max = std::max(n_max, std::max(0, spec->draft.n_max));
break;
case COMMON_SPECULATIVE_TYPE_NGRAM_SIMPLE:
@@ -1925,6 +2223,7 @@ common_speculative * common_speculative_init(common_params_speculative & params,
bool has_draft_simple = (enabled_configs & (1u << COMMON_SPECULATIVE_TYPE_DRAFT_SIMPLE));
bool has_draft_eagle3 = (enabled_configs & (1u << COMMON_SPECULATIVE_TYPE_DRAFT_EAGLE3)) && params.draft.ctx_dft != nullptr;
bool has_draft_mtp = (enabled_configs & (1u << COMMON_SPECULATIVE_TYPE_DRAFT_MTP)) && params.draft.ctx_dft != nullptr;
bool has_draft_dflash = (enabled_configs & (1u << COMMON_SPECULATIVE_TYPE_DRAFT_DFLASH)) && params.draft.ctx_dft != nullptr;
@@ -1935,7 +2234,7 @@ common_speculative * common_speculative_init(common_params_speculative & params,
bool has_ngram_mod = (enabled_configs & (1u << COMMON_SPECULATIVE_TYPE_NGRAM_MOD));
// when adding a new type - update here the logic above
static_assert(COMMON_SPECULATIVE_TYPE_COUNT == 9);
static_assert(COMMON_SPECULATIVE_TYPE_COUNT == 10);
// this list here defines the priority of the speculators
// the one with highest priority are listed first
@@ -1965,6 +2264,9 @@ common_speculative * common_speculative_init(common_params_speculative & params,
if (has_draft_mtp) {
configs.push_back(common_speculative_config(COMMON_SPECULATIVE_TYPE_DRAFT_MTP, params));
}
if (has_draft_dflash) {
configs.push_back(common_speculative_config(COMMON_SPECULATIVE_TYPE_DRAFT_DFLASH, params));
}
}
std::vector<std::unique_ptr<common_speculative_impl>> impls = {};
@@ -1985,6 +2287,10 @@ common_speculative * common_speculative_init(common_params_speculative & params,
impls.push_back(std::make_unique<common_speculative_impl_draft_mtp>(config.params, n_seq));
break;
}
case COMMON_SPECULATIVE_TYPE_DRAFT_DFLASH: {
impls.push_back(std::make_unique<common_speculative_impl_draft_dflash>(config.params, n_seq));
break;
}
case COMMON_SPECULATIVE_TYPE_NGRAM_SIMPLE: {
common_ngram_map ngram_map = get_common_ngram_map(config.type, config.params.ngram_simple);
@@ -2034,7 +2340,7 @@ common_speculative * common_speculative_init(common_params_speculative & params,
}
if (impls.empty()) {
LOG_WRN("%s: no implementations specified for speculative decoding\n", __func__);
SPC_TRC("%s", "no implementations specified for speculative decoding\n");
return nullptr;
}
@@ -2161,13 +2467,13 @@ void common_speculative_draft(common_speculative * spec) {
if (dp.n_max > 0) {
if (!result.empty() && (int) result.size() > dp.n_max) {
LOG_DBG("%s: truncating draft to %d tokens\n", __func__, dp.n_max);
SPC_DBG("truncating draft to %d tokens\n", dp.n_max);
result.resize(dp.n_max);
}
}
if (!result.empty()) {
LOG_DBG("%s: called impl %s, hist size = %zu, call_count = %zu, gen = %zu\n", __func__,
SPC_DBG("called impl %s, hist size = %zu, call_count = %zu, gen = %zu\n",
common_speculative_type_to_str(impl.get()->type).c_str(), dp.prompt->size(),
impl.get()->n_call_draft, result.size());
@@ -2291,7 +2597,7 @@ void common_speculative_print_stats(const common_speculative * spec) {
str_stats = ", #mean acc len = " + oss.str() + ", #acc rate/pos = (" + tmp.str() + ")";
}
LOG_INF("statistics %16s: #calls(b,g,a) = %4zu %6zu %6zu, #gen drafts = %6zu, #acc drafts = %5zu, #gen tokens = %6zu, #acc tokens = %5zu%s%s\n",
SPC_TRC("statistics %16s: #calls(b,g,a) = %4zu %6zu %6zu, #gen drafts = %6zu, #acc drafts = %5zu, #gen tokens = %6zu, #acc tokens = %5zu%s%s\n",
common_speculative_type_to_str(impl->type).c_str(),
impl->n_call_begin, impl->n_call_draft, impl->n_call_accept,
impl->n_gen_drafts,
+2
View File
@@ -50,6 +50,8 @@ TEXT_MODEL_MAP: dict[str, str] = {
"DeepseekV2ForCausalLM": "deepseek",
"DeepseekV3ForCausalLM": "deepseek",
"DeepseekV32ForCausalLM": "deepseek",
"DFlashDraftModel": "qwen",
"DeepseekV4ForCausalLM": "deepseek",
"DistilBertForMaskedLM": "bert",
"DistilBertForSequenceClassification": "bert",
"DistilBertModel": "bert",
+14 -1
View File
@@ -1273,7 +1273,7 @@ class TextModel(ModelBase):
if (f_norm_eps := self.find_hparam(["layer_norm_eps", "layer_norm_epsilon", "norm_epsilon"], optional=True)) is not None:
self.gguf_writer.add_layer_norm_eps(f_norm_eps)
logger.info(f"gguf: layer norm epsilon = {f_norm_eps}")
if (n_experts := self.find_hparam(["num_local_experts", "num_experts"], optional=True)) is not None:
if (n_experts := self.find_hparam(["num_local_experts", "num_experts", "n_routed_experts"], optional=True)) is not None:
self.gguf_writer.add_expert_count(n_experts)
logger.info(f"gguf: expert count = {n_experts}")
if (n_experts_used := self.find_hparam(["num_experts_per_tok", "num_experts_per_token", "top_k_experts"], optional=True)) is not None:
@@ -1291,6 +1291,8 @@ class TextModel(ModelBase):
self.gguf_writer.add_expert_gating_func(gguf.ExpertGatingFuncType.SIGMOID)
elif score_func == "softmax":
self.gguf_writer.add_expert_gating_func(gguf.ExpertGatingFuncType.SOFTMAX)
elif score_func == "sqrtsoftplus":
self.gguf_writer.add_expert_gating_func(gguf.ExpertGatingFuncType.SQRTSOFTPLUS)
else:
raise ValueError(f"Unsupported expert score gating function value: {score_func}")
logger.info(f"gguf: expert score gating function = {score_func}")
@@ -2600,6 +2602,17 @@ class LazyTorchTensor(gguf.LazyBase):
return cls._wrap_fn(func)(*args, **kwargs)
if hasattr(torch, "float8_e8m0fnu"):
_torch_float8_e8m0 = torch.float8_e8m0fnu
LazyTorchTensor._dtype_map[_torch_float8_e8m0] = np.uint8
LazyTorchTensor._dtype_byteswap_map[_torch_float8_e8m0] = np.uint8
LazyTorchTensor._dtype_str_map["F8_E8M0"] = _torch_float8_e8m0
else:
# Older torch builds do not expose F8_E8M0. Keep the raw bytes so callers
# that know the format can decode them explicitly.
LazyTorchTensor._dtype_str_map["F8_E8M0"] = torch.uint8
def get_model_architecture(hparams: dict[str, Any], model_type: ModelType) -> str:
# TODO @ngxson : this won't work correctly if the model has both audio & vision encoders
# maybe we should fallback to text model's arch in that case, since not many models have both
+308 -1
View File
@@ -1,15 +1,18 @@
from __future__ import annotations
import json
import re
from pathlib import Path
from typing import Any, Callable, Iterable, TYPE_CHECKING
import numpy as np
import torch
if TYPE_CHECKING:
from torch import Tensor
from .base import MmprojModel, ModelBase, TextModel, gguf, logger
from .base import LazyTorchTensor, MmprojModel, ModelBase, TextModel, gguf, logger
from .qwen import QwenModel
@@ -467,3 +470,307 @@ class DeepseekV32Model(DeepseekV2Model):
self.gguf_writer.add_indexer_head_count(self.hparams["index_n_heads"])
self.gguf_writer.add_indexer_key_length(self.hparams["index_head_dim"])
self.gguf_writer.add_indexer_top_k(self.hparams["index_topk"])
@ModelBase.register("DeepseekV4ForCausalLM")
class DeepseekV4Model(TextModel):
model_arch = gguf.MODEL_ARCH.DEEPSEEK4
_skipped_mtp_tensors = 0
def __init__(self, *args, **kwargs):
type(self)._skipped_mtp_tensors = 0
super().__init__(*args, **kwargs)
with open(self.dir_model / "config.json", "r", encoding="utf-8") as f:
raw_hparams = json.load(f)
for key, value in raw_hparams.items():
self.hparams.setdefault(key, value)
self.block_count = self.hparams["num_hidden_layers"]
self.tensor_map = gguf.get_tensor_name_map(self.model_arch, self.block_count)
self._dsv4_fp8_dequantized: set[str] = set()
self._dsv4_bf16_tensors: set[str] = set()
self._dsv4_f32_tensors: set[str] = set()
self._dsv4_mxfp4_generated = False
self._collect_source_dtypes()
if type(self)._skipped_mtp_tensors:
logger.info("Skipping %d DeepSeek-V4 MTP tensor(s) for conversion v0", type(self)._skipped_mtp_tensors)
# add a default chat template; if the model has a built-in template, it will be overridden later
template_path = Path(__file__).parent.parent / "models" / "templates" / "deepseek-ai-DeepSeek-V4.jinja"
if template_path.is_file():
with open(template_path, "r", encoding="utf-8") as f:
self.gguf_writer.add_chat_template(f.read())
@classmethod
def filter_tensors(cls, item: tuple[str, Callable[[], Tensor]]) -> tuple[str, Callable[[], Tensor]] | None:
name, _ = item
if name.startswith("mtp."):
cls._skipped_mtp_tensors += 1
return None
return super().filter_tensors(item)
@staticmethod
def _float8_dtypes() -> tuple[torch.dtype, ...]:
return tuple(
dtype for dtype in (
getattr(torch, "float8_e4m3fn", None),
getattr(torch, "float8_e5m2", None),
) if dtype is not None
)
@staticmethod
def _e8m0_to_float(scale: Tensor) -> Tensor:
torch_float8_e8m0 = getattr(torch, "float8_e8m0fnu", None)
if torch_float8_e8m0 is not None and scale.dtype == torch_float8_e8m0:
return scale.float()
bits = scale.view(torch.uint8).float()
return torch.exp2(bits - 127.0)
def _collect_source_dtypes(self) -> None:
for name, gen in self.model_tensors.items():
dtype = gen().dtype
if dtype == torch.bfloat16:
self._dsv4_bf16_tensors.add(name)
elif dtype == torch.float32:
self._dsv4_f32_tensors.add(name)
def set_gguf_parameters(self):
super().set_gguf_parameters()
hparams = self.hparams
self.gguf_writer.add_rope_dimension_count(hparams["qk_rope_head_dim"])
self.gguf_writer.add_q_lora_rank(hparams["q_lora_rank"])
self.gguf_writer.add_sliding_window(hparams["sliding_window"])
self.gguf_writer.add_expert_feed_forward_length(hparams["moe_intermediate_size"])
self.gguf_writer.add_expert_shared_count(hparams["n_shared_experts"])
self.gguf_writer.add_expert_weights_scale(hparams["routed_scaling_factor"])
self.gguf_writer.add_expert_weights_norm(hparams["norm_topk_prob"])
self.gguf_writer.add_swiglu_clamp_exp([hparams["swiglu_limit"]] * self.block_count)
self.gguf_writer.add_swiglu_clamp_shexp([hparams["swiglu_limit"]] * self.block_count)
self.gguf_writer.add_indexer_head_count(hparams["index_n_heads"])
self.gguf_writer.add_indexer_key_length(hparams["index_head_dim"])
self.gguf_writer.add_indexer_top_k(hparams["index_topk"])
self.gguf_writer.add_attention_output_group_count(hparams["o_groups"])
self.gguf_writer.add_attention_output_lora_rank(hparams["o_lora_rank"])
self.gguf_writer.add_attention_compress_ratios(hparams["compress_ratios"])
self.gguf_writer.add_attention_compress_rope_freq_base(hparams["compress_rope_theta"])
self.gguf_writer.add_hyper_connection_count(hparams["hc_mult"])
self.gguf_writer.add_hyper_connection_sinkhorn_iterations(hparams["hc_sinkhorn_iters"])
self.gguf_writer.add_hyper_connection_epsilon(hparams["hc_eps"])
self.gguf_writer.add_hash_layer_count(hparams["num_hash_layers"])
def dequant_model(self):
fp8_dtypes = self._float8_dtypes()
tensors_to_remove: list[str] = []
def dequant_fp8_weight(weight: Tensor, scale: Tensor) -> Tensor:
out_features, in_features = weight.shape
scale_f = self._e8m0_to_float(scale)
scale_f = scale_f.repeat_interleave(128, 0)[:out_features]
scale_f = scale_f.repeat_interleave(128, 1)[:, :in_features]
return weight.float() * scale_f
for name in list(self.model_tensors.keys()):
if not name.endswith(".scale"):
continue
weight_name = name.removesuffix(".scale") + ".weight"
if weight_name not in self.model_tensors:
continue
weight = self.model_tensors[weight_name]
scale = self.model_tensors[name]
if weight().dtype not in fp8_dtypes:
continue
self.model_tensors[weight_name] = lambda w=weight, s=scale: dequant_fp8_weight(w(), s())
self._dsv4_fp8_dequantized.add(weight_name)
tensors_to_remove.append(name)
for name in tensors_to_remove:
del self.model_tensors[name]
@staticmethod
def _pack_mxfp4_blocks(weight: Tensor, scale: Tensor) -> np.ndarray:
packed = weight.contiguous().view(torch.uint8)
scale_u8 = scale.contiguous().view(torch.uint8)
out_features, packed_cols = packed.shape
logical_cols = packed_cols * 2
if logical_cols % 32 != 0:
raise ValueError(f"MXFP4 source row has {logical_cols} values, expected a multiple of 32")
n_blocks = logical_cols // 32
if tuple(scale_u8.shape) != (out_features, n_blocks):
raise ValueError(f"MXFP4 scale shape {tuple(scale_u8.shape)} does not match {(out_features, n_blocks)}")
src = packed.reshape(out_features, n_blocks, 16)
low = src & 0x0F
high = (src >> 4) & 0x0F
# The safetensors bytes store adjacent values as low/high nibbles.
# ggml MXFP4 blocks store values 0..15 in low nibbles and 16..31 in high nibbles.
vals = torch.stack((low, high), dim=-1).reshape(out_features, n_blocks, 32)
qs = vals[:, :, :16] | (vals[:, :, 16:] << 4)
raw = torch.cat((scale_u8.unsqueeze(-1), qs.to(torch.uint8)), dim=-1)
return raw.reshape(out_features, n_blocks * 17).cpu().numpy()
def _write_mxfp4_expert_tensor(self, bid: int, proj: str, tensor_key: gguf.MODEL_TENSOR) -> list[str]:
n_experts = self.hparams["n_routed_experts"]
data: np.ndarray | None = None
consumed: list[str] = []
for eid in range(n_experts):
weight_name = f"layers.{bid}.ffn.experts.{eid}.{proj}.weight"
scale_name = f"layers.{bid}.ffn.experts.{eid}.{proj}.scale"
if weight_name not in self.model_tensors or scale_name not in self.model_tensors:
raise KeyError(f"Missing routed expert tensors for {weight_name}")
weight = LazyTorchTensor.to_eager(self.model_tensors[weight_name]())
scale = LazyTorchTensor.to_eager(self.model_tensors[scale_name]())
packed = self._pack_mxfp4_blocks(weight, scale)
if data is None:
data = np.empty((n_experts, *packed.shape), dtype=packed.dtype)
data[eid] = packed
consumed.extend((weight_name, scale_name))
assert data is not None
new_name = self.format_tensor_name(tensor_key, bid)
shape = gguf.quant_shape_from_byte_shape(data.shape, gguf.GGMLQuantizationType.MXFP4)
logger.info(f"{new_name}: repacked routed experts to MXFP4, shape = {{{', '.join(str(n) for n in reversed(shape))}}}")
self.gguf_writer.add_tensor(new_name, data, raw_dtype=gguf.GGMLQuantizationType.MXFP4)
return consumed
def _write_hash_routing_tensors(self) -> list[str]:
consumed: list[str] = []
for bid in range(self.hparams["num_hash_layers"]):
name = f"layers.{bid}.ffn.gate.tid2eid"
if name not in self.model_tensors:
raise KeyError(f"Missing hash routing tensor {name}")
data_torch = LazyTorchTensor.to_eager(self.model_tensors[name]())
data = data_torch.to(torch.int32).cpu().numpy()
new_name = self.format_tensor_name(gguf.MODEL_TENSOR.FFN_GATE_TID2EID, bid, ".weight")
logger.info(f"{new_name}: converted hash routing table to I32, shape = {{{', '.join(str(n) for n in reversed(data.shape))}}}")
self.gguf_writer.add_tensor(new_name, data)
consumed.append(name)
return consumed
def generate_extra_tensors(self) -> Iterable[tuple[str, Tensor]]:
if self._dsv4_mxfp4_generated:
return ()
consumed: list[str] = self._write_hash_routing_tensors()
for bid in range(self.block_count):
consumed.extend(self._write_mxfp4_expert_tensor(bid, "w1", gguf.MODEL_TENSOR.FFN_GATE_EXP))
consumed.extend(self._write_mxfp4_expert_tensor(bid, "w2", gguf.MODEL_TENSOR.FFN_DOWN_EXP))
consumed.extend(self._write_mxfp4_expert_tensor(bid, "w3", gguf.MODEL_TENSOR.FFN_UP_EXP))
for name in consumed:
del self.model_tensors[name]
self._dsv4_mxfp4_generated = True
return ()
def _format_dsv4_tensor_name(self, key: gguf.MODEL_TENSOR, bid: int | None, suffix: str = ".weight") -> str:
return self.format_tensor_name(key, bid, suffix)
def _map_dsv4_tensor_name(self, name: str, bid: int | None) -> tuple[gguf.MODEL_TENSOR, str]:
root_map: dict[str, tuple[gguf.MODEL_TENSOR, str]] = {
"embed.weight": (gguf.MODEL_TENSOR.TOKEN_EMBD, ".weight"),
"norm.weight": (gguf.MODEL_TENSOR.OUTPUT_NORM, ".weight"),
"head.weight": (gguf.MODEL_TENSOR.OUTPUT, ".weight"),
"hc_head_fn": (gguf.MODEL_TENSOR.HC_HEAD_FN, ".weight"),
"hc_head_base": (gguf.MODEL_TENSOR.HC_HEAD_BASE, ".weight"),
"hc_head_scale": (gguf.MODEL_TENSOR.HC_HEAD_SCALE, ".weight"),
}
if name in root_map:
return root_map[name]
match = re.match(r"layers\.(\d+)\.(.+)$", name)
if match is None:
raise ValueError(f"Unsupported DeepSeek-V4 tensor {name!r}")
layer = int(match.group(1))
if bid != layer:
raise ValueError(f"Tensor {name!r} parsed bid {bid} but layer name has {layer}")
layer_map: dict[str, tuple[gguf.MODEL_TENSOR, str]] = {
"hc_attn_fn": (gguf.MODEL_TENSOR.HC_ATTN_FN, ".weight"),
"hc_attn_base": (gguf.MODEL_TENSOR.HC_ATTN_BASE, ".weight"),
"hc_attn_scale": (gguf.MODEL_TENSOR.HC_ATTN_SCALE, ".weight"),
"hc_ffn_fn": (gguf.MODEL_TENSOR.HC_FFN_FN, ".weight"),
"hc_ffn_base": (gguf.MODEL_TENSOR.HC_FFN_BASE, ".weight"),
"hc_ffn_scale": (gguf.MODEL_TENSOR.HC_FFN_SCALE, ".weight"),
"attn.attn_sink": (gguf.MODEL_TENSOR.ATTN_SINKS, ".weight"),
"attn.wq_a.weight": (gguf.MODEL_TENSOR.ATTN_Q_A, ".weight"),
"attn.wq_b.weight": (gguf.MODEL_TENSOR.ATTN_Q_B, ".weight"),
"attn.q_norm.weight": (gguf.MODEL_TENSOR.ATTN_Q_A_NORM, ".weight"),
"attn.wkv.weight": (gguf.MODEL_TENSOR.ATTN_KV, ".weight"),
"attn.kv_norm.weight": (gguf.MODEL_TENSOR.ATTN_KV_NORM, ".weight"),
"attn.wo_a.weight": (gguf.MODEL_TENSOR.ATTN_OUT_A, ".weight"),
"attn.wo_b.weight": (gguf.MODEL_TENSOR.ATTN_OUT_B, ".weight"),
"attn.compressor.ape": (gguf.MODEL_TENSOR.ATTN_COMPRESSOR_APE, ".weight"),
"attn.compressor.wkv.weight": (gguf.MODEL_TENSOR.ATTN_COMPRESSOR_WKV, ".weight"),
"attn.compressor.wgate.weight": (gguf.MODEL_TENSOR.ATTN_COMPRESSOR_WGATE, ".weight"),
"attn.compressor.norm.weight": (gguf.MODEL_TENSOR.ATTN_COMPRESSOR_NORM, ".weight"),
"attn.indexer.wq_b.weight": (gguf.MODEL_TENSOR.INDEXER_ATTN_Q_B, ".weight"),
"attn.indexer.weights_proj.weight": (gguf.MODEL_TENSOR.INDEXER_PROJ, ".weight"),
"attn.indexer.compressor.ape": (gguf.MODEL_TENSOR.INDEXER_COMPRESSOR_APE, ".weight"),
"attn.indexer.compressor.wkv.weight": (gguf.MODEL_TENSOR.INDEXER_COMPRESSOR_WKV, ".weight"),
"attn.indexer.compressor.wgate.weight": (gguf.MODEL_TENSOR.INDEXER_COMPRESSOR_WGATE, ".weight"),
"attn.indexer.compressor.norm.weight": (gguf.MODEL_TENSOR.INDEXER_COMPRESSOR_NORM, ".weight"),
"attn_norm.weight": (gguf.MODEL_TENSOR.ATTN_NORM, ".weight"),
"ffn_norm.weight": (gguf.MODEL_TENSOR.FFN_NORM, ".weight"),
"ffn.gate.weight": (gguf.MODEL_TENSOR.FFN_GATE_INP, ".weight"),
"ffn.gate.bias": (gguf.MODEL_TENSOR.FFN_EXP_PROBS_B, ".bias"),
"ffn.gate.tid2eid": (gguf.MODEL_TENSOR.FFN_GATE_TID2EID, ".weight"),
"ffn.shared_experts.w1.weight": (gguf.MODEL_TENSOR.FFN_GATE_SHEXP, ".weight"),
"ffn.shared_experts.w2.weight": (gguf.MODEL_TENSOR.FFN_DOWN_SHEXP, ".weight"),
"ffn.shared_experts.w3.weight": (gguf.MODEL_TENSOR.FFN_UP_SHEXP, ".weight"),
}
tensor_name = match.group(2)
if tensor_name in layer_map:
return layer_map[tensor_name]
if re.match(r"ffn\.experts\.\d+\.w[123]\.(weight|scale)$", tensor_name):
return gguf.MODEL_TENSOR.FFN_GATE_EXP, ".weight"
raise ValueError(f"Unsupported DeepSeek-V4 tensor {name!r}")
def modify_tensors(self, data_torch: Tensor, name: str, bid: int | None) -> Iterable[tuple[str, Tensor]]:
if re.match(r"layers\.\d+\.ffn\.experts\.\d+\.w[123]\.(weight|scale)$", name):
return []
tensor_key, suffix = self._map_dsv4_tensor_name(name, bid)
if tensor_key == gguf.MODEL_TENSOR.FFN_GATE_TID2EID:
return []
return [(self._format_dsv4_tensor_name(tensor_key, bid, suffix), data_torch)]
def tensor_force_quant(self, name: str, new_name: str, bid: int | None, n_dims: int) -> gguf.GGMLQuantizationType | bool:
del new_name, bid # unused
if name in self._dsv4_fp8_dequantized and n_dims >= 2:
return gguf.GGMLQuantizationType.Q8_0
if name in self._dsv4_f32_tensors:
return gguf.GGMLQuantizationType.F32
if name in self._dsv4_bf16_tensors and n_dims >= 2:
return gguf.GGMLQuantizationType.BF16
return False
def prepare_tensors(self):
super().prepare_tensors()
self._is_mxfp4 = True
self.ftype = gguf.LlamaFileType.MOSTLY_MXFP4_MOE
+3 -3
View File
@@ -73,7 +73,7 @@ class LlamaModel(TextModel):
target_num_layers = target_config["num_hidden_layers"]
target_layers = [2, target_num_layers // 2, target_num_layers - 3]
logger.info(f"EAGLE-3: target_layers = {target_layers} (target model has {target_num_layers} layers)")
self.gguf_writer.add_array(f"{self.gguf_writer.arch}.target_layers", target_layers)
self.gguf_writer.add_target_layers(target_layers)
# target_hidden_size: prefer eagle3 config, fallback to target config
if eagle3_raw_config.get("target_hidden_size") is not None:
@@ -83,12 +83,12 @@ class LlamaModel(TextModel):
target_hidden_size = target_config["hidden_size"]
src = "target model config"
logger.info(f"EAGLE-3: target_hidden_size = {target_hidden_size} (from {src})")
self.gguf_writer.add_uint32(f"{self.gguf_writer.arch}.target_hidden_size", target_hidden_size)
self.gguf_writer.add_target_hidden_size(target_hidden_size)
# norm_before_residual (RedHat-style eagle3 specific)
norm_before_residual = eagle3_raw_config.get("norm_before_residual", False)
logger.info(f"EAGLE-3: norm_before_residual = {norm_before_residual}")
self.gguf_writer.add_bool(f"{self.gguf_writer.arch}.norm_before_residual", norm_before_residual)
self.gguf_writer.add_norm_before_residual(norm_before_residual)
def set_vocab(self):
# eagle3: use tokenizer from target model if provided
+48
View File
@@ -625,3 +625,51 @@ class Qwen3_5TextModel(_Qwen35MtpMixin, _Qwen35MRopeMixin, _LinearAttentionVReor
@ModelBase.register("Qwen3_5MoeForConditionalGeneration", "Qwen3_5MoeForCausalLM")
class Qwen3_5MoeTextModel(_Qwen35MtpMixin, _Qwen35MRopeMixin, _LinearAttentionVReorderBase):
model_arch = gguf.MODEL_ARCH.QWEN35MOE
@ModelBase.register("DFlashDraftModel")
class DFlashModel(Qwen3Model):
model_arch = gguf.MODEL_ARCH.DFLASH
def set_vocab(self):
if self.target_model_dir is None:
raise ValueError(
"DFlash draft model requires --target-model-dir to be specified. "
"Please provide the path to the target model directory containing the tokenizer."
)
logger.info(f"DFlash: Using tokenizer from target model: {self.target_model_dir}")
original_dir = self.dir_model
self.dir_model = self.target_model_dir
super().set_vocab()
self.dir_model = original_dir
mask_token_id = self.hparams.get("dflash_config", {}).get("mask_token_id")
if mask_token_id is not None:
self.gguf_writer.add_mask_token_id(mask_token_id)
def set_gguf_parameters(self):
super().set_gguf_parameters()
block_size = self.hparams.get("block_size", 16)
self.gguf_writer.add_block_size(block_size)
dflash_config = self.hparams.get("dflash_config", {})
target_layer_ids = dflash_config.get("target_layer_ids", [])
if target_layer_ids:
extract_layer_ids = [i + 1 for i in target_layer_ids]
self.gguf_writer.add_target_layers(extract_layer_ids)
use_sliding_window = self.hparams.get("use_sliding_window", False)
sliding_window = self.hparams.get("sliding_window")
layer_types = self.hparams.get("layer_types")
if use_sliding_window and sliding_window and layer_types:
is_swa = [lt == "sliding_attention" for lt in layer_types]
self.gguf_writer.add_sliding_window(sliding_window)
self.gguf_writer.add_sliding_window_pattern(is_swa)
@classmethod
def filter_tensors(cls, item: tuple[str, Callable[[], Tensor]]) -> tuple[str, Callable[[], Tensor]] | None:
name, gen = item
if not name.startswith("model."):
name = "model." + name
return super().filter_tensors((name, gen))
+51 -39
View File
@@ -1,16 +1,26 @@
# llama.cpp for OpenCL
- [Background](#background)
- [OS](#os)
- [Hardware](#hardware)
- [DataType Supports](#datatype-supports)
- [Model Preparation](#model-preparation)
- [CMake Options](#cmake-options)
- [Android](#android)
- [Windows 11 Arm64](#windows-11-arm64)
- [Linux](#Linux)
- [Known Issue](#known-issues)
- [TODO](#todo)
- [llama.cpp for OpenCL](#llamacpp-for-opencl)
- [Background](#background)
- [Llama.cpp + OpenCL](#llamacpp--opencl)
- [OS](#os)
- [Hardware](#hardware)
- [Adreno GPU](#adreno-gpu)
- [DataType Supports](#datatype-supports)
- [Model Preparation](#model-preparation)
- [Binary Kernel Library](#binary-kernel-library)
- [CMake Options](#cmake-options)
- [Android](#android)
- [I. Setup Environment](#i-setup-environment)
- [II. Build llama.cpp](#ii-build-llamacpp)
- [Windows 11 Arm64](#windows-11-arm64)
- [I. Setup Environment](#i-setup-environment-1)
- [II. Build llama.cpp](#ii-build-llamacpp-1)
- [Linux](#linux)
- [I. Setup Environment](#i-setup-environment-2)
- [II. Build llama.cpp](#ii-build-llamacpp-2)
- [Known Issues](#known-issues)
- [TODO](#todo)
## Background
@@ -34,11 +44,13 @@ The llama.cpp OpenCL backend is designed to enable llama.cpp on **Qualcomm Adren
**Verified devices**
| Adreno GPU | Status |
|:------------------------------------:|:-------:|
| Adreno 750 (Snapdragon 8 Gen 3) | Support |
| Adreno 830 (Snapdragon 8 Elite) | Support |
| Adreno X85 (Snapdragon X Elite) | Support |
| Adreno GPU | Status |
|:-------------------------------------:|:-------:|
| Adreno 750 (Snapdragon 8 Gen 3) | Support |
| Adreno 830 (Snapdragon 8 Elite) | Support |
| Adreno 840 (Snapdragon 8 Elite Gen 5) | Support |
| Adreno X1-85 (Snapdragon X Elite) | Support |
| Adreno X2-90 (Snapdragon X2 Elite) | Support |
> A6x GPUs with a recent driver and compiler are supported; they are usually found in IoT platforms.
However, A6x GPUs in phones are likely not supported due to the outdated driver and compiler.
@@ -47,42 +59,43 @@ However, A6x GPUs in phones are likely not supported due to the outdated driver
| DataType | Status |
|:----------------------:|:--------------------------:|
| Q1_0 | Support |
| Q4_0 | Support |
| Q6_K | Support, but not optimized |
| Q4_1 | Support |
| Q5_0 | Support |
| Q5_1 | Support |
| Q8_0 | Support |
| Q4_K | Support |
| Q5_K | Support |
| Q6_K | Support |
| MXFP4 | Support |
| IQ4_NL | Support |
## Model Preparation
You can refer to the general [llama-quantize tool](/tools/quantize/README.md) for steps to convert a model in Hugging Face safetensor format to GGUF with quantization.
Since common quantizations are supported now, it is recommanded to download GGUF models directly from Huggingface.
Currently we support `Q4_0` quantization and have optimized for it. To achieve best performance on Adreno GPU, add `--pure` to `llama-quantize` (i.e., make all weights in `Q4_0`). For example,
## Binary Kernel Library
```sh
./llama-quantize --pure ggml-model-qwen2.5-3b-f16.gguf ggml-model-qwen-3b-Q4_0.gguf Q4_0
```
A prebuilt binary kernel library has been introduced for Adreno GPUs.
It currently targets X2 GPUs (X2-90, X2-85 and X2-45) found in Snapdragon X2 SoC.
The library currently contains kernels for MUL_MAT_ID with Q4_0, Q4_1, Q4_K, MXFP4.
The library must be manually downloaded from https://softwarecenter.qualcomm.com/catalog/item/Adreno_Kernel_Library_GGML.
Since `Q6_K` is also supported, `Q4_0` quantization without `--pure` will also work. However, the performance will be worse compared to pure `Q4_0` quantization.
To allow using the kernel library, add `-DGGML_OPENCL_USE_ADRENO_BIN_KERNELS=ON` when configuring with CMake.
Then, extract `adreno-opencl-kernels.dll` from the zip file downloaded from the above URL and put it alongside the executables.
If kernels compatible with the current GPU are found in the library, they will be loaded and used.
### `MXFP4` MoE Models
OpenAI gpt-oss models are MoE models in `MXFP4`. The quantized model will be in `MXFP4_MOE`, a mixture of `MXFP4` and `Q8_0`.
For this quantization, there is no need to specify `--pure`.
For gpt-oss-20b model, you can directly [download](https://huggingface.co/ggml-org/gpt-oss-20b-GGUF) the quantized GGUF file in `MXFP4_MOE` from Hugging Face.
Although it is possible to quantize gpt-oss-20b model in pure `Q4_0` (all weights in `Q4_0`), it is not recommended since `MXFP4` has been optimized for MoE while `Q4_0` is not. In addition, accuracy should degrade with such pure `Q4_0` quantization.
Hence, using the default `MXFP4_MOE` quantization (see the link above) is recommended for this model.
> Note that the `Q4_0` model found [here](https://huggingface.co/unsloth/gpt-oss-20b-GGUF/blob/main/gpt-oss-20b-Q4_0.gguf) is a mixture of `Q4_0`, `Q8_0` and `MXFP4` and gives better performance than `MXFP4_MOE` quantization.
## CMake Options
The OpenCL backend has the following CMake options that control the behavior of the backend.
| CMake options | Default value | Description |
|:---------------------------------:|:--------------:|:------------------------------------------|
| `GGML_OPENCL_EMBED_KERNELS` | `ON` | Embed OpenCL kernels into the executable. |
| `GGML_OPENCL_USE_ADRENO_KERNELS` | `ON` | Use kernels optimized for Adreno. |
| CMake options | Default value | Description |
|:------------------------------------:|:--------------:|:------------------------------------------|
| `GGML_OPENCL_EMBED_KERNELS` | `ON` | Embed OpenCL kernels into the executable. |
| `GGML_OPENCL_USE_ADRENO_KERNELS` | `ON` | Use kernels optimized for Adreno. |
| `GGML_OPENCL_USE_ADRENO_BIN_KERNELS` | `OFF` | Allow using binary kernel lib for Adreno. |
## Android
@@ -277,6 +290,5 @@ ninja
## TODO
- Optimization for Q6_K
- Support and optimization for Q4_K
- Improve flash attention
- Improve OpenCL C kernels performance
+28 -1
View File
@@ -52,6 +52,32 @@ Supported EAGLE-3 draft models include:
For the full and up-to-date list of supported models, see #18039.
### DFlash (`draft-dflash`)
DFlash produces an entire block of draft tokens in a single forward pass (block diffusion) and
injects the target model's hidden states into the draft model's attention, instead of drafting one
token at a time. This keeps the draft model small while making drafting GPU-friendly. Unlike EAGLE-3
(a single-layer autoregressive draft), the DFlash draft uses several transformer layers but emits a
whole block per draft step.
The draft is a small block-diffusion model trained for a specific target (for example
`z-lab/Qwen3-4B-DFlash` for `Qwen/Qwen3-4B`). Convert it with `--target-model-dir` so it inherits the
target's tokenizer and token embeddings:
```bash
python convert_hf_to_gguf.py z-lab/Qwen3-4B-DFlash \
--target-model-dir Qwen/Qwen3-4B --outtype bf16 --outfile Qwen3-4B-DFlash.gguf
llama-server -m Qwen3-4B.gguf -md Qwen3-4B-DFlash.gguf \
--spec-type draft-dflash --spec-draft-n-max 15 -fa on --jinja
```
`--spec-draft-n-max` is clamped to the draft model's trained block size.
See:
- #22105
### n-gram Cache (`ngram-cache`)
An n-gram is a sequence of n tokens. The n-gram cache implementation maintains statistics about short n-gram sequences.
@@ -147,7 +173,7 @@ If a draft model is combined with a draftless decoding the draftless decoding ha
### General Speculative Parameters
```
--spec-type [none|draft-simple|draft-eagle3|draft-mtp|ngram-cache|ngram-simple|ngram-map-k|ngram-map-k4v|ngram-mod]
--spec-type [none|draft-simple|draft-eagle3|draft-dflash|draft-mtp|ngram-cache|ngram-simple|ngram-map-k|ngram-map-k4v|ngram-mod]
comma-separated list of types of speculative decoding to use
(default: none)
(env: LLAMA_ARG_SPEC_TYPE)
@@ -287,6 +313,7 @@ Specifies a comma-separated list of speculative decoding types to use.
| `none` | No speculative decoding (default) |
| `draft-simple` | Use a simple draft model for speculation |
| `draft-eagle3` | Use an EAGLE-3 draft model that reads the target's hidden states |
| `draft-dflash` | Use a DFlash block-diffusion draft model that emits a block per step |
| `draft-mtp` | Use Multi Token Prediction (MTP) heads from the main model |
| `ngram-cache` | Use n-gram cache lookup |
| `ngram-simple` | Use simple n-gram pattern matching |
+3 -7
View File
@@ -1551,8 +1551,6 @@ static enum ggml_status ggml_backend_sched_compute_splits(ggml_backend_sched_t s
int split_backend_id = split->backend_id;
ggml_backend_t split_backend = sched->backends[split_backend_id];
ggml_backend_synchronize(split_backend);
// copy the input tensors to the split backend
for (int input_id = 0; input_id < split->n_inputs; input_id++) {
ggml_backend_t input_backend = ggml_backend_sched_get_tensor_backend(sched, split->inputs[input_id]);
@@ -1563,15 +1561,15 @@ static enum ggml_status ggml_backend_sched_compute_splits(ggml_backend_sched_t s
// inputs from the user must be copied immediately to prevent the user overwriting the data before the copy is done
if (sched->events[split_backend_id][sched->cur_copy] != NULL) {
ggml_backend_event_synchronize(sched->events[split_backend_id][sched->cur_copy]);
} else if (!split_backend->iface.cpy_tensor_async) {
} else {
ggml_backend_synchronize(split_backend);
}
ggml_backend_tensor_copy_async(input_backend, split_backend, input, input_cpy);
ggml_backend_tensor_copy(input, input_cpy);
} else {
// wait for the split backend to finish using the input before overwriting it
if (sched->events[split_backend_id][sched->cur_copy] != NULL) {
ggml_backend_event_wait(split_backend, sched->events[split_backend_id][sched->cur_copy]);
} else if (!split_backend->iface.cpy_tensor_async) {
} else {
ggml_backend_synchronize(split_backend);
}
@@ -1676,8 +1674,6 @@ static enum ggml_status ggml_backend_sched_compute_splits(ggml_backend_sched_t s
}
}
ggml_backend_synchronize(split_backend);
if (!sched->callback_eval) {
enum ggml_status ec = ggml_backend_graph_compute_async(split_backend, &split->graph);
if (ec != GGML_STATUS_SUCCESS) {
+3 -2
View File
@@ -1111,11 +1111,12 @@ GGML_TABLE_BEGIN(int8_t, kvalues_iq4nl, 16)
-127, -104, -83, -65, -49, -35, -22, -10, 1, 13, 25, 38, 53, 69, 89, 113,
GGML_TABLE_END()
// e2m1 values (doubled)
// e2m1 values (doubled), shared by MXFP4 and NVFP4
// ref: https://www.opencompute.org/documents/ocp-microscaling-formats-mx-v1-0-spec-final-pdf
GGML_TABLE_BEGIN(int8_t, kvalues_mxfp4, 16)
GGML_TABLE_BEGIN(int8_t, kvalues_fp4, 16)
0, 1, 2, 3, 4, 6, 8, 12, 0, -1, -2, -3, -4, -6, -8, -12,
GGML_TABLE_END()
#define kvalues_mxfp4 kvalues_fp4
#define NGRID_IQ1S 2048
#define IQ1S_DELTA 0.125f
-1
View File
@@ -82,7 +82,6 @@
#define ggml_gemm_q2_K_8x8_q8_K_generic ggml_gemm_q2_K_8x8_q8_K
#elif defined(__x86_64__) || defined(__i386__) || defined(_M_IX86) || defined(_M_X64)
// quants.c
#define ggml_vec_dot_nvfp4_q8_0_generic ggml_vec_dot_nvfp4_q8_0
// repack.cpp
#define ggml_quantize_mat_q8_0_4x4_generic ggml_quantize_mat_q8_0_4x4
#define ggml_quantize_mat_q8_K_4x4_generic ggml_quantize_mat_q8_K_4x4
+142 -4
View File
@@ -934,7 +934,7 @@ void ggml_vec_dot_mxfp4_q8_0(int n, float * GGML_RESTRICT s, size_t bs, const vo
#if defined __AVX2__
const __m128i values128 = _mm_loadu_si128((const __m128i*)kvalues_mxfp4);
const __m128i values128 = _mm_loadu_si128((const __m128i*)kvalues_fp4);
const __m128i m4b = _mm_set1_epi8(0x0f);
const __m256i mone = _mm256_set1_epi16(1);
@@ -963,7 +963,7 @@ void ggml_vec_dot_mxfp4_q8_0(int n, float * GGML_RESTRICT s, size_t bs, const vo
sumf = hsum_float_8(_mm256_add_ps(accum1, accum2));
#elif defined __AVX__
const __m128i values128 = _mm_loadu_si128((const __m128i*)kvalues_mxfp4);
const __m128i values128 = _mm_loadu_si128((const __m128i*)kvalues_fp4);
const __m128i m4b = _mm_set1_epi8(0x0f);
__m256 accum = _mm256_setzero_ps();
@@ -993,14 +993,152 @@ void ggml_vec_dot_mxfp4_q8_0(int n, float * GGML_RESTRICT s, size_t bs, const vo
int sumi1 = 0;
int sumi2 = 0;
for (int j = 0; j < QK_MXFP4/2; ++j) {
sumi1 += y[ib].qs[j + 0] * kvalues_mxfp4[x[ib].qs[j] & 0xf];
sumi2 += y[ib].qs[j + QK_MXFP4/2] * kvalues_mxfp4[x[ib].qs[j] >> 4];
sumi1 += y[ib].qs[j + 0] * kvalues_fp4[x[ib].qs[j] & 0xf];
sumi2 += y[ib].qs[j + QK_MXFP4/2] * kvalues_fp4[x[ib].qs[j] >> 4];
}
sumf += d * (sumi1 + sumi2);
}
*s = sumf;
}
void ggml_vec_dot_nvfp4_q8_0(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, size_t bx, const void * GGML_RESTRICT vy, size_t by, int nrc) {
assert(nrc == 1);
UNUSED(nrc);
UNUSED(bx);
UNUSED(by);
UNUSED(bs);
assert(n % QK_NVFP4 == 0);
const block_nvfp4 * GGML_RESTRICT x = vx;
const block_q8_0 * GGML_RESTRICT y = vy;
const int nb = n / QK_NVFP4;
int ib = 0;
float sumf = 0;
#if defined(__AVX2__)
const __m128i values128 = _mm_loadu_si128((const __m128i*)kvalues_fp4);
const __m128i m4b = _mm_set1_epi8(0x0f);
const __m256i mone = _mm256_set1_epi16(1);
__m256 accum = _mm256_setzero_ps();
for(; ib < nb; ib++){
const __m128i q4bits_01 = _mm_loadu_si128((const __m128i *)(x[ib].qs + 0));
const __m128i q4bits_23 = _mm_loadu_si128((const __m128i *)(x[ib].qs + 16));
const __m256i q8_01 = _mm256_loadu_si256((const __m256i *)y[2*ib + 0].qs);
const __m256i q8_23 = _mm256_loadu_si256((const __m256i *)y[2*ib + 1].qs);
const __m128i q4_01_lo = _mm_shuffle_epi8(values128, _mm_and_si128(q4bits_01, m4b));
const __m128i q4_01_hi = _mm_shuffle_epi8(values128, _mm_and_si128(_mm_srli_epi16(q4bits_01, 4), m4b));
const __m128i q4_23_lo = _mm_shuffle_epi8(values128, _mm_and_si128(q4bits_23, m4b));
const __m128i q4_23_hi = _mm_shuffle_epi8(values128, _mm_and_si128(_mm_srli_epi16(q4bits_23, 4), m4b));
//reordering
const __m256i q4_01 = MM256_SET_M128I(_mm_unpackhi_epi64(q4_01_lo,q4_01_hi), _mm_unpacklo_epi64(q4_01_lo,q4_01_hi));
const __m256i q4_23 = MM256_SET_M128I(_mm_unpackhi_epi64(q4_23_lo,q4_23_hi),_mm_unpacklo_epi64(q4_23_lo,q4_23_hi));
const __m256i p01 = mul_add_epi8(q4_01,q8_01);
const __m256i p_1 = _mm256_madd_epi16(p01, mone);
const __m256i p23 = mul_add_epi8(q4_23,q8_23);
const __m256i p_2 = _mm256_madd_epi16(p23, mone);
const float dy0 = GGML_CPU_FP16_TO_FP32(y[2*ib].d);
const float dy1 = GGML_CPU_FP16_TO_FP32(y[2*ib+1].d);
const float s0 = GGML_CPU_UE4M3_TO_FP32(x[ib].d[0]) * dy0;
const float s1 = GGML_CPU_UE4M3_TO_FP32(x[ib].d[1]) * dy0;
const float s2 = GGML_CPU_UE4M3_TO_FP32(x[ib].d[2]) * dy1;
const float s3 = GGML_CPU_UE4M3_TO_FP32(x[ib].d[3]) * dy1;
const __m256 scales01 = _mm256_set_m128(_mm_set1_ps(s1), _mm_set1_ps(s0));
const __m256 scales23 = _mm256_set_m128(_mm_set1_ps(s3), _mm_set1_ps(s2));
accum = _mm256_fmadd_ps(scales01, _mm256_cvtepi32_ps(p_1), accum);
accum = _mm256_fmadd_ps(scales23, _mm256_cvtepi32_ps(p_2), accum);
}
sumf = hsum_float_8(accum);
#elif defined(__AVX__)
const __m128i values128 = _mm_loadu_si128((const __m128i*)kvalues_fp4);
const __m128i m4b = _mm_set1_epi8(0x0f);
__m256 accum = _mm256_setzero_ps();
for(; ib < nb; ib++){
const __m128i q4bits_01 = _mm_loadu_si128((const __m128i *)(x[ib].qs + 0));
const __m128i q4bits_23 = _mm_loadu_si128((const __m128i *)(x[ib].qs + 16));
const __m128i q8_0 = _mm_loadu_si128((const __m128i *)(y[2*ib + 0].qs + 0));
const __m128i q8_1 = _mm_loadu_si128((const __m128i *)(y[2*ib + 0].qs + 16));
const __m128i q8_2 = _mm_loadu_si128((const __m128i *)(y[2*ib + 1].qs + 0));
const __m128i q8_3 = _mm_loadu_si128((const __m128i *)(y[2*ib + 1].qs + 16));
const __m128i q4_01_lo = _mm_shuffle_epi8(values128, _mm_and_si128(q4bits_01, m4b));
const __m128i q4_01_hi = _mm_shuffle_epi8(values128, _mm_and_si128(_mm_srli_epi16(q4bits_01, 4), m4b));
const __m128i q4_23_lo = _mm_shuffle_epi8(values128, _mm_and_si128(q4bits_23, m4b));
const __m128i q4_23_hi = _mm_shuffle_epi8(values128, _mm_and_si128(_mm_srli_epi16(q4bits_23, 4), m4b));
const __m128i q4_0 = _mm_unpacklo_epi64(q4_01_lo, q4_01_hi);
const __m128i q4_1 = _mm_unpackhi_epi64(q4_01_lo, q4_01_hi);
const __m128i q4_2 = _mm_unpacklo_epi64(q4_23_lo, q4_23_hi);
const __m128i q4_3 = _mm_unpackhi_epi64(q4_23_lo, q4_23_hi);
const __m128i p0_i32 = mul_sum_i8_pairs(q4_0, q8_0);
const __m128i p1_i32 = mul_sum_i8_pairs(q4_1, q8_1);
const __m128i p2_i32 = mul_sum_i8_pairs(q4_2, q8_2);
const __m128i p3_i32 = mul_sum_i8_pairs(q4_3, q8_3);
const __m128 p0 = _mm_cvtepi32_ps(p0_i32);
const __m128 p1 = _mm_cvtepi32_ps(p1_i32);
const __m128 p2 = _mm_cvtepi32_ps(p2_i32);
const __m128 p3 = _mm_cvtepi32_ps(p3_i32);
const __m256 p01 = _mm256_set_m128(p1, p0);
const __m256 p23 = _mm256_set_m128(p3, p2);
const float dy0 = GGML_CPU_FP16_TO_FP32(y[2*ib].d);
const float dy1 = GGML_CPU_FP16_TO_FP32(y[2*ib+1].d);
const float s0 = GGML_CPU_UE4M3_TO_FP32(x[ib].d[0]) * dy0;
const float s1 = GGML_CPU_UE4M3_TO_FP32(x[ib].d[1]) * dy0;
const float s2 = GGML_CPU_UE4M3_TO_FP32(x[ib].d[2]) * dy1;
const float s3 = GGML_CPU_UE4M3_TO_FP32(x[ib].d[3]) * dy1;
const __m256 scales01 = _mm256_set_m128(_mm_set1_ps(s1), _mm_set1_ps(s0));
const __m256 scales23 = _mm256_set_m128(_mm_set1_ps(s3), _mm_set1_ps(s2));
accum = _mm256_add_ps(accum, _mm256_mul_ps(p01, scales01));
accum = _mm256_add_ps(accum, _mm256_mul_ps(p23, scales23));
}
sumf = hsum_float_8(accum);
#endif
for (;ib < nb; ++ib) {
for (int s_idx = 0; s_idx < 4; ++s_idx) {
const float d = GGML_CPU_UE4M3_TO_FP32(x[ib].d[s_idx]);
const int q8_block = s_idx / 2;
const int q8_off = (s_idx % 2) * QK_NVFP4_SUB;
const float dy = GGML_CPU_FP16_TO_FP32(y[2*ib + q8_block].d);
int sumi_lo = 0, sumi_hi = 0;
for (int j = 0; j < QK_NVFP4_SUB/2; ++j) {
const uint8_t qv = x[ib].qs[s_idx*(QK_NVFP4_SUB/2) + j];
sumi_lo += y[2*ib + q8_block].qs[q8_off + j + 0] * kvalues_fp4[qv & 0xf];
sumi_hi += y[2*ib + q8_block].qs[q8_off + j + QK_NVFP4_SUB/2] * kvalues_fp4[qv >> 4];
}
sumf += dy * d * (sumi_lo + sumi_hi);
}
}
*s = sumf;
}
void ggml_vec_dot_q5_0_q8_0(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, size_t bx, const void * GGML_RESTRICT vy, size_t by, int nrc) {
const int qk = QK8_0;
const int nb = n / qk;
+8
View File
@@ -82,6 +82,9 @@ float ggml_table_f32_f16[1 << 16];
// precomputed f32 table for e8m0 half (1 KB) (simd-mappings.h)
float ggml_table_f32_e8m0_half[1 << 8];
// precomputed f32 table for ue4m3 (1 KB) (simd-mappings.h)
float ggml_table_f32_ue4m3[1 << 8];
#if defined(__ARM_ARCH)
struct ggml_arm_arch_features_type {
int sve_cnt;
@@ -3798,6 +3801,11 @@ void ggml_cpu_init(void) {
ggml_table_f32_e8m0_half[i] = GGML_E8M0_TO_FP32_HALF(i);
}
// initialize UE4M3 table (256 entries)
for (int i = 0; i < (1 << 8); ++i) {
ggml_table_f32_ue4m3[i] = ggml_ue4m3_to_fp32(i);
}
const uint64_t t_end = ggml_time_us(); UNUSED(t_end);
GGML_PRINT_DEBUG("%s: GELU, Quick GELU, SILU and EXP tables initialized in %f ms\n", __func__, (t_end - t_start)/1000.0);
+11
View File
@@ -120,6 +120,10 @@ extern float ggml_table_f32_f16[1 << 16];
// defined in ggml-cpu.c, initialized in ggml_cpu_init()
extern float ggml_table_f32_e8m0_half[1 << 8];
// precomputed f32 table for ue4m3 (1 KB)
// defined in ggml-cpu.c, initialized in ggml_cpu_init()
extern float ggml_table_f32_ue4m3[1 << 8];
// Use lookup table for E8M0 on x86 (faster than bit manipulation)
#if defined(__AVX__) || defined(__AVX2__) || defined(__AVX512F__)
#define GGML_CPU_E8M0_TO_FP32_HALF(x) ggml_table_f32_e8m0_half[(uint8_t)(x)]
@@ -127,6 +131,13 @@ extern float ggml_table_f32_e8m0_half[1 << 8];
#define GGML_CPU_E8M0_TO_FP32_HALF(x) GGML_E8M0_TO_FP32_HALF(x)
#endif
// Use lookup table for UE4M3 on x86 (faster than bit manipulation)
#if defined(__AVX__) || defined(__AVX2__) || defined(__AVX512F__)
#define GGML_CPU_UE4M3_TO_FP32(x) ggml_table_f32_ue4m3[(uint8_t)(x)]
#else
#define GGML_CPU_UE4M3_TO_FP32(x) ggml_ue4m3_to_fp32(x)
#endif
// On ARM NEON, it's quicker to directly convert x -> x instead of calling into ggml_lookup_fp16_to_fp32,
// so we define GGML_CPU_FP16_TO_FP32 and GGML_CPU_FP32_TO_FP16 elsewhere for NEON.
// This is also true for POWER9.
+45
View File
@@ -386,6 +386,46 @@ static void ggml_cpy_f32_iq4_nl_cuda(
(cx, cdst, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13);
}
// check if a same-type copy reduces to a 2D strided copy (height rows of width
// contiguous bytes), so it can use cudaMemcpy2DAsync instead of the scalar kernel
static bool ggml_cuda_cpy_as_memcpy_2d(const ggml_tensor * src0, const ggml_tensor * src1,
size_t & width, size_t & height, size_t & spitch, size_t & dpitch) {
// require matching shape: a reshaped copy maps elements by flat order, which the
// prefix walk below does not handle
if (src0->type != src1->type || !ggml_are_same_shape(src0, src1)) {
return false;
}
// grow the contiguous prefix block shared by both tensors
size_t block_nb = ggml_element_size(src0);
int d = 0;
for (; d < GGML_MAX_DIMS; ++d) {
if (src0->nb[d] != block_nb || src1->nb[d] != block_nb) {
break;
}
block_nb *= src0->ne[d];
}
// d == 0: nothing contiguous; d == GGML_MAX_DIMS: fully contiguous (handled by memcpy)
if (d == 0 || d == GGML_MAX_DIMS) {
return false;
}
// dim d carries the rows; everything above it must be a single element
for (int i = d + 1; i < GGML_MAX_DIMS; ++i) {
if (src0->ne[i] != 1) {
return false;
}
}
width = block_nb;
height = src0->ne[d];
spitch = src0->nb[d];
dpitch = src1->nb[d];
return spitch >= width && dpitch >= width;
}
void ggml_cuda_cpy(ggml_backend_cuda_context & ctx, const ggml_tensor * src0, ggml_tensor * src1) {
const int64_t ne = ggml_nelements(src0);
GGML_ASSERT(ne == ggml_nelements(src1));
@@ -421,6 +461,8 @@ void ggml_cuda_cpy(ggml_backend_cuda_context & ctx, const ggml_tensor * src0, gg
const bool can_be_transposed = nb01 == (int64_t)ggml_element_size(src0) &&
src0->ne[3] == 1 && nb02 == ne00 * ne01 * (int64_t)ggml_element_size(src0);
size_t mc_width = 0, mc_height = 0, mc_spitch = 0, mc_dpitch = 0;
if (src0->type == src1->type && contiguous_srcs) {
GGML_ASSERT(ggml_nbytes(src0) == ggml_nbytes(src1));
#if defined(GGML_USE_MUSA) && defined(GGML_MUSA_MUDNN_COPY)
@@ -431,6 +473,9 @@ void ggml_cuda_cpy(ggml_backend_cuda_context & ctx, const ggml_tensor * src0, gg
{
CUDA_CHECK(cudaMemcpyAsync(src1_ddc, src0_ddc, ggml_nbytes(src0), cudaMemcpyDeviceToDevice, main_stream));
}
} else if (ggml_cuda_cpy_as_memcpy_2d(src0, src1, mc_width, mc_height, mc_spitch, mc_dpitch)) {
CUDA_CHECK(cudaMemcpy2DAsync(src1_ddc, mc_dpitch, src0_ddc, mc_spitch,
mc_width, mc_height, cudaMemcpyDeviceToDevice, main_stream));
} else if (src0->type == GGML_TYPE_F32 && src1->type == GGML_TYPE_F32) {
if (can_be_transposed) {
ggml_cpy_scalar_cuda<float, float, true>
+9 -5
View File
@@ -664,7 +664,10 @@ constexpr __device__ dequantize_V_t get_dequantize_V() {
template <int ncols1>
__launch_bounds__(FATTN_KQ_STRIDE/2, 1)
static __global__ void flash_attn_mask_to_KV_max(
const half2 * __restrict__ mask, int * __restrict__ KV_max, const int ne30, const int s31, const int s33) {
const half2 * mask_ptr, int * KV_max_ptr, const int ne30, const int64_t s31, const int64_t s33) {
const half2 * GGML_CUDA_RESTRICT mask = mask_ptr;
int * GGML_CUDA_RESTRICT KV_max = KV_max_ptr;
const int ne31 = gridDim.x;
const int tid = threadIdx.x;
const int sequence = blockIdx.y;
@@ -1089,8 +1092,8 @@ void launch_fattn(
// Only worth the overhead if there is at lease one FATTN_KQ_STRIDE x FATTN_KQ_STRIDE square to be skipped or
// multiple sequences of possibly different lengths.
if (mask && K->ne[1] % FATTN_KQ_STRIDE == 0 && (Q->ne[1] >= 1024 || Q->ne[3] > 1)) {
const int s31 = mask->nb[1] / sizeof(half2);
const int s33 = mask->nb[3] / sizeof(half2);
const int64_t s31 = mask->nb[1] / sizeof(half2);
const int64_t s33 = mask->nb[3] / sizeof(half2);
const dim3 blocks_num_KV_max(ntiles_x, Q->ne[3], 1);
const dim3 block_dim_KV_max(FATTN_KQ_STRIDE/2, 1, 1);
@@ -1099,8 +1102,9 @@ void launch_fattn(
const int iter_k = K->ne[1] / FATTN_KQ_STRIDE;
KV_max.alloc(ne_KV_max);
flash_attn_mask_to_KV_max<ncols1><<<blocks_num_KV_max, block_dim_KV_max, 0, main_stream>>>
((const half2 *) mask->data, KV_max.ptr, iter_k, s31, s33);
ggml_cuda_kernel_launch_params launch_params = ggml_cuda_kernel_launch_params(blocks_num_KV_max, block_dim_KV_max, 0, main_stream);
ggml_cuda_kernel_launch(flash_attn_mask_to_KV_max<ncols1>, launch_params,
(const half2 *) mask->data, KV_max.ptr, iter_k, s31, s33);
CUDA_CHECK(cudaGetLastError());
}
+4
View File
@@ -2003,6 +2003,10 @@ DECL_FATTN_MMA_F16_CASE_ALL_NCOLS2(112, 112, 64)
DECL_FATTN_MMA_F16_CASE_ALL_NCOLS2(128, 128, 64)
DECL_FATTN_MMA_F16_CASE_ALL_NCOLS2(256, 256, 64)
extern DECL_FATTN_MMA_F16_CASE(512, 512, 4, 2);
extern DECL_FATTN_MMA_F16_CASE(512, 512, 8, 2);
extern DECL_FATTN_MMA_F16_CASE(512, 512, 16, 2);
extern DECL_FATTN_MMA_F16_CASE(512, 512, 32, 2);
extern DECL_FATTN_MMA_F16_CASE(512, 512, 2, 4);
extern DECL_FATTN_MMA_F16_CASE(512, 512, 4, 4);
extern DECL_FATTN_MMA_F16_CASE(512, 512, 8, 4);
+9 -5
View File
@@ -76,6 +76,7 @@ static constexpr __host__ __device__ uint32_t ggml_cuda_fattn_tile_get_config_nv
GGML_CUDA_FATTN_TILE_CONFIG_CASE(320, 256, 16, 256, 2, 64, 64)
GGML_CUDA_FATTN_TILE_CONFIG_CASE(512, 512, 2, 64, 2, 64, 64)
GGML_CUDA_FATTN_TILE_CONFIG_CASE(512, 512, 4, 128, 2, 64, 64)
GGML_CUDA_FATTN_TILE_CONFIG_CASE(512, 512, 8, 256, 2, 64, 64)
GGML_CUDA_FATTN_TILE_CONFIG_CASE(512, 512, 16, 256, 2, 64, 64)
@@ -144,6 +145,7 @@ static constexpr __host__ __device__ uint32_t ggml_cuda_fattn_tile_get_config_nv
GGML_CUDA_FATTN_TILE_CONFIG_CASE(320, 256, 16, 256, 2, 32, 64)
GGML_CUDA_FATTN_TILE_CONFIG_CASE(512, 512, 2, 64, 2, 32, 64)
GGML_CUDA_FATTN_TILE_CONFIG_CASE(512, 512, 4, 128, 2, 32, 64)
GGML_CUDA_FATTN_TILE_CONFIG_CASE(512, 512, 8, 256, 2, 32, 64)
GGML_CUDA_FATTN_TILE_CONFIG_CASE(512, 512, 16, 256, 2, 32, 64)
@@ -219,6 +221,7 @@ static constexpr __host__ __device__ uint32_t ggml_cuda_fattn_tile_get_config_am
GGML_CUDA_FATTN_TILE_CONFIG_CASE(320, 256, 32, 512, 1, 128, 64)
GGML_CUDA_FATTN_TILE_CONFIG_CASE(512, 512, 2, 64, 2, 64, 64)
GGML_CUDA_FATTN_TILE_CONFIG_CASE(512, 512, 4, 128, 2, 64, 64)
GGML_CUDA_FATTN_TILE_CONFIG_CASE(512, 512, 8, 256, 2, 64, 64)
GGML_CUDA_FATTN_TILE_CONFIG_CASE(512, 512, 16, 256, 2, 64, 64)
@@ -296,6 +299,7 @@ static constexpr __host__ __device__ uint32_t ggml_cuda_fattn_tile_get_config_am
GGML_CUDA_FATTN_TILE_CONFIG_CASE(320, 256, 32, 256, 2, 128, 64)
GGML_CUDA_FATTN_TILE_CONFIG_CASE(512, 512, 2, 64, 2, 64, 64)
GGML_CUDA_FATTN_TILE_CONFIG_CASE(512, 512, 4, 128, 2, 64, 64)
GGML_CUDA_FATTN_TILE_CONFIG_CASE(512, 512, 8, 256, 2, 64, 64)
GGML_CUDA_FATTN_TILE_CONFIG_CASE(512, 512, 16, 256, 4, 64, 64)
@@ -1308,12 +1312,12 @@ static void launch_fattn_tile_switch_ncols2(ggml_backend_cuda_context & ctx, ggm
return;
}
if constexpr (DV <= 256) {
if (use_gqa_opt && gqa_ratio % 2 == 0) {
launch_fattn_tile_switch_ncols1<DKQ, DV, 2, use_logit_softcap>(ctx, dst);
return;
}
if (use_gqa_opt && gqa_ratio % 2 == 0) {
launch_fattn_tile_switch_ncols1<DKQ, DV, 2, use_logit_softcap>(ctx, dst);
return;
}
if constexpr (DV <= 256) {
launch_fattn_tile_switch_ncols1<DKQ, DV, 1, use_logit_softcap>(ctx, dst);
return;
}
+5 -5
View File
@@ -99,12 +99,12 @@ static void ggml_cuda_flash_attn_ext_mma_f16_switch_ncols2(ggml_backend_cuda_con
return;
}
if constexpr (DKQ <= 256) {
if (use_gqa_opt && gqa_ratio > 1) {
ggml_cuda_flash_attn_ext_mma_f16_switch_ncols1<DKQ, DV, 2>(ctx, dst);
return;
}
if (use_gqa_opt && gqa_ratio > 1) {
ggml_cuda_flash_attn_ext_mma_f16_switch_ncols1<DKQ, DV, 2>(ctx, dst);
return;
}
if constexpr (DKQ <= 256) {
ggml_cuda_flash_attn_ext_mma_f16_switch_ncols1<DKQ, DV, 1>(ctx, dst);
} else {
GGML_ABORT("fatal error");
+17 -14
View File
@@ -78,26 +78,29 @@ static __global__ void k_get_rows_float(
template<typename grad_t, typename dst_t>
static __global__ void k_get_rows_back_float(
const grad_t * __restrict__ grad, const int32_t * __restrict__ rows, dst_t * __restrict__ dst, const int64_t ncols, const int64_t nrows_grad) {
const grad_t * __restrict__ grad, const int32_t * __restrict__ rows, dst_t * __restrict__ dst,
const int64_t ncols, const int64_t nrows_grad, const int64_t nrows_dst) {
const int col = blockIdx.x*blockDim.x + threadIdx.x;
if (col >= ncols) {
return;
}
const int dst_row = blockIdx.y*blockDim.y + threadIdx.y;
float sum = 0.0f;
ggml_cuda_pdl_sync();
for (int64_t i = 0; i < nrows_grad; ++i) {
if (rows[i] != dst_row) {
continue;
}
sum += grad[i*ncols + col];
}
dst[dst_row*ncols + col] = sum;
// grid.y is clamped to the CUDA grid limit, so stride over the destination rows
for (int64_t dst_row = blockIdx.y; dst_row < nrows_dst; dst_row += gridDim.y) {
float sum = 0.0f;
for (int64_t i = 0; i < nrows_grad; ++i) {
if (rows[i] != dst_row) {
continue;
}
sum += grad[i*ncols + col];
}
dst[dst_row*ncols + col] = sum;
}
}
template<int qk, int qr, dequantize_kernel_t dq, typename dst_t>
@@ -302,7 +305,7 @@ void ggml_cuda_op_get_rows_back(ggml_backend_cuda_context & ctx, ggml_tensor * d
const dim3 block_dims(CUDA_GET_ROWS_BACK_BLOCK_SIZE, 1, 1);
const int block_num_x = (ne00 + CUDA_GET_ROWS_BACK_BLOCK_SIZE - 1) / CUDA_GET_ROWS_BACK_BLOCK_SIZE;
const dim3 block_nums(block_num_x, ne1, 1);
const dim3 block_nums(block_num_x, MIN(ne1, (int64_t)UINT16_MAX), 1);
k_get_rows_back_float<<<block_nums, block_dims, 0, stream>>>(src0_d, src1_d, dst_d, ne00, ne10);
k_get_rows_back_float<<<block_nums, block_dims, 0, stream>>>(src0_d, src1_d, dst_d, ne00, ne10, ne1);
}
+4 -20
View File
@@ -3192,24 +3192,11 @@ static bool ggml_backend_cuda_cpy_tensor_async(ggml_backend_t backend_src, ggml_
ggml_backend_buffer_t buf_src = src->view_src ? src->view_src->buffer : src->buffer;
ggml_backend_buffer_t buf_dst = dst->view_src ? dst->view_src->buffer : dst->buffer;
// Enables async copies from CPU to CUDA, instead of only CUDA-to-CUDA
// Excluding this path for HIP and MUSA as a precaution.
// According to the summary in https://github.com/ggml-org/llama.cpp/pull/20793#issuecomment-4275794315, this change is not beneficial for hip anyways.
// Additionally, there is a lot of anectodal evidence that hip/musa stream behavior might not always 1:1 match CUDA behavior.
// e.g. https://github.com/ROCm/rocm-systems/issues/5109
// It thus makes sense to exclude this path for HIP and MUSA. This PR was not aimed these backends, the majority of testing happened on CUDA.
// This can be revisited in the future if enabling copy_from_host benefits hip/MUSA, and if the PR author can extensively test on these backends.
#if defined(GGML_USE_HIP) || defined(GGML_USE_MUSA)
const bool copy_from_host = false;
#else
const bool copy_from_host = ggml_backend_buffer_is_host(buf_src) && ggml_backend_dev_type(backend_src->device) == GGML_BACKEND_DEVICE_TYPE_CPU;
#endif
if (!(copy_from_host || ggml_backend_is_cuda(backend_src)) || !ggml_backend_is_cuda(backend_dst)) {
if (!ggml_backend_is_cuda(backend_src) || !ggml_backend_is_cuda(backend_dst)) {
return false;
}
if (!(copy_from_host || ggml_backend_buffer_is_cuda(buf_src)) || !ggml_backend_buffer_is_cuda(buf_dst)) {
if (!ggml_backend_buffer_is_cuda(buf_src) || !ggml_backend_buffer_is_cuda(buf_dst)) {
return false;
}
@@ -3220,17 +3207,14 @@ static bool ggml_backend_cuda_cpy_tensor_async(ggml_backend_t backend_src, ggml_
ggml_backend_cuda_buffer_context * buf_ctx_src = (ggml_backend_cuda_buffer_context *) buf_src->context;
ggml_backend_cuda_buffer_context * buf_ctx_dst = (ggml_backend_cuda_buffer_context *) buf_dst->context;
if ((copy_from_host && cuda_ctx_dst->device != buf_ctx_dst->device) ||
!copy_from_host && (cuda_ctx_src->device != buf_ctx_src->device || cuda_ctx_dst->device != buf_ctx_dst->device)) {
if (cuda_ctx_src->device != buf_ctx_src->device || cuda_ctx_dst->device != buf_ctx_dst->device) {
#ifndef NDEBUG
GGML_LOG_DEBUG("%s: backend and buffer devices do not match\n", __func__);
#endif // NDEBUG
return false;
}
if (copy_from_host) {
CUDA_CHECK(cudaMemcpyAsync(dst->data, src->data, ggml_nbytes(dst), cudaMemcpyHostToDevice, cuda_ctx_dst->stream()));
} else if (backend_src != backend_dst) {
if (backend_src != backend_dst) {
// copy on src stream
if (cuda_ctx_src->device == cuda_ctx_dst->device) {
CUDA_CHECK(cudaMemcpyAsync(dst->data, src->data, ggml_nbytes(dst), cudaMemcpyDeviceToDevice, cuda_ctx_src->stream()));
+7
View File
@@ -368,5 +368,12 @@ bool ggml_cuda_should_use_mmq(enum ggml_type type, int cc, int64_t ne11, int64_t
return true;
}
// gfx900 (Vega 10) lacks native dp4a, loses to dequant + hipBLAS
// for dense matrices; keep MMQ only for MoE, where the
// hipBLAS path is much slower.
if (cc == GGML_CUDA_CC_VEGA) {
return n_experts > 0;
}
return (!GGML_CUDA_CC_IS_CDNA(cc)) || ne11 < MMQ_DP4A_MAX_BATCH_SIZE;
}
@@ -8,3 +8,4 @@ DECL_FATTN_MMA_F16_CASE(96, 96, 16, 2);
DECL_FATTN_MMA_F16_CASE(112, 112, 16, 2);
DECL_FATTN_MMA_F16_CASE(128, 128, 16, 2);
DECL_FATTN_MMA_F16_CASE(256, 256, 16, 2);
DECL_FATTN_MMA_F16_CASE(512, 512, 16, 2);
@@ -8,3 +8,4 @@ DECL_FATTN_MMA_F16_CASE(96, 96, 32, 2);
DECL_FATTN_MMA_F16_CASE(112, 112, 32, 2);
DECL_FATTN_MMA_F16_CASE(128, 128, 32, 2);
DECL_FATTN_MMA_F16_CASE(256, 256, 32, 2);
DECL_FATTN_MMA_F16_CASE(512, 512, 32, 2);
@@ -8,3 +8,4 @@ DECL_FATTN_MMA_F16_CASE(96, 96, 4, 2);
DECL_FATTN_MMA_F16_CASE(112, 112, 4, 2);
DECL_FATTN_MMA_F16_CASE(128, 128, 4, 2);
DECL_FATTN_MMA_F16_CASE(256, 256, 4, 2);
DECL_FATTN_MMA_F16_CASE(512, 512, 4, 2);
@@ -8,3 +8,4 @@ DECL_FATTN_MMA_F16_CASE(96, 96, 8, 2);
DECL_FATTN_MMA_F16_CASE(112, 112, 8, 2);
DECL_FATTN_MMA_F16_CASE(128, 128, 8, 2);
DECL_FATTN_MMA_F16_CASE(256, 256, 8, 2);
DECL_FATTN_MMA_F16_CASE(512, 512, 8, 2);
@@ -92,7 +92,7 @@ for ncols in [8, 16, 32, 64]:
continue
if head_size_kq == 320 and ncols2 != 32: # Mistral Small 4
continue
if head_size_kq == 512 and ncols2 not in (4, 8): # Gemma 4
if head_size_kq == 512 and ncols2 not in (2, 4, 8): # Gemma 4 (+ MTP)
continue
if head_size_kq == 576 and ncols2 not in (4, 16, 32): # Deepseek, GLM 4.7 Flash
continue
-1
View File
@@ -23,7 +23,6 @@ include(${HEXAGON_SDK_ROOT}/build/cmake/hexagon_fun.cmake)
include(ExternalProject)
option(GGML_HEXAGON_HTP_DEBUG "ggml-hexagon: enable HTP debug output" OFF)
option(GGML_HEXAGON_FA_EXP2_HF "ggml-hexagon: use FP16 exp2 polynomial in FA softmax instead of F32 exp round-trip" OFF)
set(GGML_HEXAGON_HTP_CERT "$ENV{HEXAGON_HTP_CERT}" CACHE PATH "ggml-hexagon: enable HTP library signing using certificate")
add_library(htp_iface OBJECT
+221 -26
View File
@@ -43,6 +43,7 @@
#include "htp-opnode.h"
#include "htp-ops.h"
#include "htp/matmul-ops.h"
#include "htp/flash-attn-ops.h"
#include "htp_iface.h"
#include "htp-drv.h"
@@ -62,6 +63,7 @@ static int opt_profile = 0; // profiling mode (0-disabled, 1-basic, 2-pmu)
static int opt_hostbuf = 1; // hostbuf ON by default
static int opt_mm_select = 3; // 3 = HMX -> Tiled -> Flat -> CPU, 2 = Tiled -> Flat -> CPU, 1 = Flat -> CPU
static int opt_fa_select = 2; // 2 = HMX -> HVX -> CPU, 1 = HVX -> CPU, 0 = CPU (unsupported)
// Default PMU events, if profiling with PMU (mode=2) is enabled
// See https://docs.qualcomm.com/doc/80-N2040-60/topic/pmu-events.html
@@ -125,6 +127,11 @@ static const char * htp_event_name(uint16_t id) {
case HTP_TRACE_EVT_HVX_W_DEQUANT: return "HVX_W_DEQUANT";
case HTP_TRACE_EVT_HVX_W_PREP: return "HVX_W_PREP";
case HTP_TRACE_EVT_HVX_O_PROC: return "HVX_O_PROC";
case HTP_TRACE_EVT_HVX_FA_QK: return "HVX_QK_FA";
case HTP_TRACE_EVT_HVX_FA_SFM: return "HVX_SFM_FA";
case HTP_TRACE_EVT_HVX_FA_Q_PREP: return "HVX_Q_PREP";
case HTP_TRACE_EVT_HVX_FA_K_PREP: return "HVX_K_PREP";
case HTP_TRACE_EVT_HVX_FA_V_PREP: return "HVX_V_PREP";
case HTP_TRACE_EVT_HMX_COMP: return "HMX_COMP";
default: return "UNKNOWN";
}
@@ -1879,6 +1886,162 @@ ggml_hexagon_session::~ggml_hexagon_session() noexcept(true) {
// ** backend interface
static bool ggml_hexagon_flash_attn_is_hmx_eligible(
const struct ggml_hexagon_session * sess,
const struct ggml_tensor * q,
const struct ggml_tensor * k,
const struct ggml_tensor * v,
const struct ggml_tensor * sinks
) {
if (sess->n_hmx == 0) {
return false;
}
if (opt_fa_select < 2) {
return false;
}
if (k->type != GGML_TYPE_F16 || v->type != GGML_TYPE_F16) {
return false;
}
const uint32_t DK = q->ne[0];
const uint32_t DV = v->ne[0];
if (DK % 64 != 0 || DV % 64 != 0) {
return false;
}
// Fall back to HVX for small token counts if head dimension is small (DK <= 128)
const uint32_t neq1 = q->ne[1];
if (DK <= 128 && neq1 < 5) {
return false;
}
return true;
}
static bool ggml_hexagon_precompute_flash_attn_params(
const struct ggml_hexagon_session * sess,
const struct ggml_tensor * op,
struct htp_fa_kernel_params * kparams
) {
if (opt_fa_select < 1) {
return false;
}
memset(kparams, 0, sizeof(*kparams));
const struct ggml_tensor * q = op->src[0];
const struct ggml_tensor * k = op->src[1];
const struct ggml_tensor * v = op->src[2];
const struct ggml_tensor * mask = op->src[3];
const struct ggml_tensor * dst = op;
const uint32_t neq0 = q->ne[0]; // head_dim (DK)
const uint32_t neq1 = q->ne[1]; // n_tokens
const uint32_t neq2 = q->ne[2]; // n_heads
const uint32_t nek1 = k->ne[1]; // kv_len
const uint32_t nev0 = v->ne[0]; // head_dim (DV)
const uint32_t DK = neq0;
const uint32_t DV = nev0;
const uint32_t n_kv_heads = k->ne[2];
const uint32_t G = neq2 / n_kv_heads;
float scale = 1.0f;
float max_bias = 0.0f;
float logit_softcap = 0.0f;
memcpy(&scale, &op->op_params[0], sizeof(float));
memcpy(&max_bias, &op->op_params[1], sizeof(float));
memcpy(&logit_softcap, &op->op_params[2], sizeof(float));
if (logit_softcap != 0.0f) {
scale /= logit_softcap;
}
kparams->scale = scale;
kparams->max_bias = max_bias;
kparams->logit_softcap = logit_softcap;
kparams->is_q_fp32 = (q->type == GGML_TYPE_F32) ? 1 : 0;
kparams->is_dst_fp32 = (dst->type == GGML_TYPE_F32) ? 1 : 0;
kparams->G = G;
const uint32_t n_head = q->ne[2];
kparams->n_head_log2 = 1u << (uint32_t) std::floor(std::log2(n_head));
kparams->m0 = std::pow(2.0f, -(max_bias) / kparams->n_head_log2);
kparams->m1 = std::pow(2.0f, -(max_bias / 2.0f) / kparams->n_head_log2);
// Check HMX eligibility
const struct ggml_tensor * sinks = op->src[4];
if (ggml_hexagon_flash_attn_is_hmx_eligible(sess, q, k, v, sinks)) {
size_t Br = 0, Bc = 0;
int ret = hmx_fa_find_chunk_size(&Br, &Bc, G, DK, DV, neq1, nek1, sess->vtcm_size, sess->n_threads);
if (ret == 0) {
kparams->kernel_type = HTP_FA_KERNEL_HMX;
kparams->Br = Br;
kparams->Bc = Bc;
kparams->n_kv_blocks = (nek1 + Bc - 1) / Bc;
kparams->n_threads = (kparams->n_kv_blocks >= 3 && sess->n_threads >= 2) ? sess->n_threads : 1;
kparams->u.hmx.g_br = hex_align_up(G * Br, 32);
kparams->u.hmx.pipeline = (kparams->n_kv_blocks >= 3 && sess->n_threads >= 2) ? 1 : 0;
kparams->vtcm_size = hmx_fa_compute_vtcm_usage(G, DK, DV, Br, Bc, kparams->n_threads, kparams->u.hmx.pipeline != 0);
const size_t row_vec_bytes = hex_align_up(Bc * sizeof(uint16_t), 256);
kparams->u.hmx.row_buf_stride = row_vec_bytes / 128; // HVX vector is 128 bytes
const size_t m_line_bytes = hex_align_up(Bc * sizeof(uint16_t), 128);
kparams->u.hmx.mask_buf_row_stride = m_line_bytes / sizeof(uint16_t);
kparams->u.hmx.mask_broadcast = (mask != nullptr && mask->ne[2] == 1) ? 1 : 0;
kparams->u.hmx.div_G = init_fastdiv_values(G);
if (mask) {
kparams->src3_div2 = init_fastdiv_values(mask->ne[2]);
kparams->src3_div3 = init_fastdiv_values(mask->ne[3]);
}
kparams->qrows = 0;
kparams->qrows_per_thread = 0;
return true;
}
}
// Fallback to HVX
kparams->kernel_type = HTP_FA_KERNEL_HVX;
kparams->Br = 1;
kparams->Bc = 64; // FLASH_ATTN_BLOCK_SIZE
kparams->n_kv_blocks = (k->ne[1] + 64 - 1) / 64;
kparams->n_threads = sess->n_threads;
const size_t size_q_row_padded = hex_round_up(q->ne[0] * (kparams->is_q_fp32 ? 4 : 2), 128);
const size_t size_k_row_padded = hex_round_up(k->ne[0] * 2, 128);
const size_t size_v_row_padded = hex_round_up(v->ne[0] * 2, 128);
kparams->vtcm_size = hvx_fa_compute_vtcm_usage(DK, DV, kparams->is_q_fp32 != 0, mask != nullptr, sess->n_threads);
kparams->u.hvx.size_q_row_padded = size_q_row_padded;
kparams->u.hvx.size_k_row_padded = size_k_row_padded;
kparams->u.hvx.size_v_row_padded = size_v_row_padded;
kparams->u.hvx.src0_div21 = init_fastdiv_values(q->ne[2] * q->ne[1]);
kparams->u.hvx.src0_div1 = init_fastdiv_values(q->ne[1]);
kparams->u.hvx.broadcast_rk2 = init_fastdiv_values(q->ne[2]/k->ne[2]);
kparams->u.hvx.broadcast_rk3 = init_fastdiv_values(q->ne[3]/k->ne[3]);
kparams->u.hvx.broadcast_rv2 = init_fastdiv_values(q->ne[2]/v->ne[2]);
kparams->u.hvx.broadcast_rv3 = init_fastdiv_values(q->ne[3]/v->ne[3]);
if (mask) {
kparams->src3_div2 = init_fastdiv_values(mask->ne[2]);
kparams->src3_div3 = init_fastdiv_values(mask->ne[3]);
}
kparams->qrows = q->ne[1] * q->ne[2] * q->ne[3];
kparams->qrows_per_thread = (kparams->qrows + sess->n_threads - 1) / sess->n_threads;
return true;
}
static bool ggml_hexagon_supported_flash_attn_ext(const struct ggml_hexagon_session * sess, const struct ggml_tensor * op) {
const struct ggml_tensor * src0 = op->src[0];
@@ -1912,6 +2075,17 @@ static bool ggml_hexagon_supported_flash_attn_ext(const struct ggml_hexagon_sess
return false;
}
struct htp_fa_kernel_params kparams;
if (!ggml_hexagon_precompute_flash_attn_params(sess, op, &kparams)) {
return false;
}
if ((size_t) kparams.vtcm_size > sess->vtcm_size) {
HEX_VERBOSE("ggml-hex: skip flash_attn_ext because VTCM needed (%d) > budget (%zu)\n",
kparams.vtcm_size, sess->vtcm_size);
return false;
}
return true;
}
@@ -2211,14 +2385,14 @@ static void ggml_hexagon_precompute_hvx_mm_params(
kparams->kernel_type = (src1_nrows < (int) sess->n_threads) ? HTP_MM_KERNEL_HVX_QUANT_BLOCK : HTP_MM_KERNEL_HVX_QUANT_ROW;
kparams->src1_row_size = (wtype == GGML_TYPE_Q4_1) ? htp_mm_q8_1_tiled_row_size(ne10) : htp_mm_q8_0_tiled_row_size(ne10);
size_t vtcm_src0_size = 0, vtcm_src1_size = 0;
size_t vtcm_src0_size = 0, vtcm_src1_size = 0, vtcm_dst_size = 0;
uint32_t max_prefetch = (src1_nrows > HTP_MM_HMX_MIN_NROWS) ? 2 : 16;
uint32_t best_n_prefetch = 2;
size_t total_size = 0;
for (uint32_t d = max_prefetch; d >= 2; d /= 2) {
total_size = htp_mm_hvx_id_get_vtcm_sizes(
wtype, ne10, src1_nrows, sess->n_threads, src0->nb[1], d,
&vtcm_src0_size, &vtcm_src1_size
&vtcm_src0_size, &vtcm_src1_size, &vtcm_dst_size
);
if (total_size <= vtcm_budget) {
best_n_prefetch = d;
@@ -2228,14 +2402,14 @@ static void ggml_hexagon_precompute_hvx_mm_params(
if (best_n_prefetch == 2 && total_size > vtcm_budget) {
total_size = htp_mm_hvx_id_get_vtcm_sizes(
wtype, ne10, src1_nrows, sess->n_threads, src0->nb[1], 2,
&vtcm_src0_size, &vtcm_src1_size
&vtcm_src0_size, &vtcm_src1_size, &vtcm_dst_size
);
}
kparams->n_prefetch = best_n_prefetch;
kparams->vtcm_size = total_size;
kparams->vtcm_src0_size = vtcm_src0_size;
kparams->vtcm_src1_size = vtcm_src1_size;
kparams->vtcm_dst_size = 0;
kparams->vtcm_dst_size = vtcm_dst_size;
} else {
bool try_tiled = (k_align && opt_mm_select >= 2);
if (try_tiled) {
@@ -2441,11 +2615,12 @@ static void ggml_hexagon_precompute_fused_qkv_params(
size_t src3_sz_per_thread = 0;
uint32_t best_n_prefetch = 16;
size_t quant_scratch_size = hex_round_up(ne10 * sizeof(float), QK_Q8_0_TILED * sizeof(float)) * sess->n_threads;
if (is_repack) {
uint32_t aligned_tile_size = htp_mm_get_weight_aligned_tile_size(wtype);
uint32_t n_k_tiles = hex_round_up(ne10, 32) / 32;
uint32_t tile_row_size = n_k_tiles * aligned_tile_size;
size_t src1_row_size_padded = hex_round_up(src1_row_size, QK_Q8_0_TILED * sizeof(float));
size_t src1_sz_per_thread = hex_round_up(src1_row_size * src1_nrows, 128);
size_t src1_sz = src1_sz_per_thread;
@@ -2453,13 +2628,10 @@ static void ggml_hexagon_precompute_fused_qkv_params(
best_n_prefetch = 2;
for (uint32_t d = max_prefetch; d >= 2; d /= 2) {
size_t repacked_vtcm_size = hex_round_up(d * tile_row_size, 128);
if (repacked_vtcm_size < src1_row_size_padded) {
repacked_vtcm_size = src1_row_size_padded;
}
size_t src0_sz = repacked_vtcm_size * sess->n_threads;
size_t src2_sz = hex_round_up(d * tile_row_size, 128) * sess->n_threads;
size_t src3_sz = hex_round_up(d * tile_row_size, 128) * sess->n_threads;
size_t tiled_vtcm_size = src0_sz + src1_sz + src2_sz + src3_sz;
size_t tiled_vtcm_size = src0_sz + src1_sz + src2_sz + src3_sz + quant_scratch_size;
if (tiled_vtcm_size <= sess->vtcm_size) {
best_n_prefetch = d;
@@ -2471,9 +2643,6 @@ static void ggml_hexagon_precompute_fused_qkv_params(
}
if (best_n_prefetch == 2 && src0_sz_per_thread == 0) {
size_t repacked_vtcm_size = hex_round_up(2 * tile_row_size, 128);
if (repacked_vtcm_size < src1_row_size_padded) {
repacked_vtcm_size = src1_row_size_padded;
}
src0_sz_per_thread = repacked_vtcm_size;
src2_sz_per_thread = hex_round_up(2 * tile_row_size, 128);
src3_sz_per_thread = hex_round_up(2 * tile_row_size, 128);
@@ -2492,7 +2661,7 @@ static void ggml_hexagon_precompute_fused_qkv_params(
size_t src2_sz = src2_sz_per_thread * sess->n_threads;
size_t src3_sz = src3_sz_per_thread * sess->n_threads;
size_t tiled_vtcm_size = src0_sz + src1_sz + src2_sz + src3_sz;
size_t tiled_vtcm_size = src0_sz + src1_sz + src2_sz + src3_sz + quant_scratch_size;
bool try_tiled = (opt_mm_select >= 2);
if (try_tiled && tiled_vtcm_size <= sess->vtcm_size) {
kparams->kernel_type = HTP_MM_KERNEL_HVX_QUANT_ROW;
@@ -2500,6 +2669,7 @@ static void ggml_hexagon_precompute_fused_qkv_params(
kparams->vtcm_src1_size = src1_sz;
kparams->vtcm_src2_size = src2_sz;
kparams->vtcm_src3_size = src3_sz;
kparams->vtcm_dst_size = quant_scratch_size;
kparams->vtcm_size = tiled_vtcm_size;
kparams->n_prefetch = best_n_prefetch;
} else {
@@ -2510,7 +2680,8 @@ static void ggml_hexagon_precompute_fused_qkv_params(
kparams->vtcm_src1_size = flat_src1_sz;
kparams->vtcm_src2_size = src2_sz;
kparams->vtcm_src3_size = src3_sz;
kparams->vtcm_size = src0_sz + flat_src1_sz + src2_sz + src3_sz;
kparams->vtcm_dst_size = quant_scratch_size;
kparams->vtcm_size = src0_sz + flat_src1_sz + src2_sz + src3_sz + quant_scratch_size;
kparams->n_prefetch = best_n_prefetch;
}
}
@@ -2536,11 +2707,12 @@ static void ggml_hexagon_precompute_fused_ffn_params(
size_t src2_sz_per_thread = 0;
uint32_t best_n_prefetch = 16;
size_t quant_scratch_size = hex_round_up(ne10 * sizeof(float), QK_Q8_0_TILED * sizeof(float)) * sess->n_threads;
if (is_repack) {
uint32_t aligned_tile_size = htp_mm_get_weight_aligned_tile_size(wtype);
uint32_t n_k_tiles = hex_round_up(ne10, 32) / 32;
uint32_t tile_row_size = n_k_tiles * aligned_tile_size;
size_t src1_row_size_padded = hex_round_up(src1_row_size, QK_Q8_0_TILED * sizeof(float));
size_t src1_sz_per_thread = hex_round_up(src1_row_size * src1_nrows, 128);
size_t src1_sz = src1_sz_per_thread;
@@ -2548,12 +2720,9 @@ static void ggml_hexagon_precompute_fused_ffn_params(
best_n_prefetch = 2;
for (uint32_t d = max_prefetch; d >= 2; d /= 2) {
size_t repacked_vtcm_size = hex_round_up(d * tile_row_size, 128);
if (repacked_vtcm_size < src1_row_size_padded) {
repacked_vtcm_size = src1_row_size_padded;
}
size_t src0_sz = repacked_vtcm_size * sess->n_threads;
size_t src2_sz = hex_round_up(d * tile_row_size, 128) * sess->n_threads;
size_t tiled_vtcm_size = src0_sz + src1_sz + src2_sz;
size_t tiled_vtcm_size = src0_sz + src1_sz + src2_sz + quant_scratch_size;
if (tiled_vtcm_size <= sess->vtcm_size) {
best_n_prefetch = d;
@@ -2564,9 +2733,6 @@ static void ggml_hexagon_precompute_fused_ffn_params(
}
if (best_n_prefetch == 2 && src0_sz_per_thread == 0) {
size_t repacked_vtcm_size = hex_round_up(2 * tile_row_size, 128);
if (repacked_vtcm_size < src1_row_size_padded) {
repacked_vtcm_size = src1_row_size_padded;
}
src0_sz_per_thread = repacked_vtcm_size;
src2_sz_per_thread = hex_round_up(2 * tile_row_size, 128);
}
@@ -2582,13 +2748,14 @@ static void ggml_hexagon_precompute_fused_ffn_params(
size_t src1_sz = src1_sz_per_thread;
size_t src2_sz = src2_sz_per_thread * sess->n_threads;
size_t tiled_vtcm_size = src0_sz + src1_sz + src2_sz;
size_t tiled_vtcm_size = src0_sz + src1_sz + src2_sz + quant_scratch_size;
bool try_tiled = (opt_mm_select >= 2);
if (try_tiled && tiled_vtcm_size <= sess->vtcm_size) {
kparams->kernel_type = HTP_MM_KERNEL_HVX_QUANT_ROW;
kparams->vtcm_src0_size = src0_sz;
kparams->vtcm_src1_size = src1_sz;
kparams->vtcm_src2_size = src2_sz;
kparams->vtcm_dst_size = quant_scratch_size;
kparams->vtcm_size = tiled_vtcm_size;
kparams->n_prefetch = best_n_prefetch;
} else {
@@ -2598,7 +2765,8 @@ static void ggml_hexagon_precompute_fused_ffn_params(
kparams->vtcm_src0_size = src0_sz;
kparams->vtcm_src1_size = flat_src1_sz;
kparams->vtcm_src2_size = src2_sz;
kparams->vtcm_size = src0_sz + flat_src1_sz + src2_sz;
kparams->vtcm_dst_size = quant_scratch_size;
kparams->vtcm_size = src0_sz + flat_src1_sz + src2_sz + quant_scratch_size;
kparams->n_prefetch = best_n_prefetch;
}
}
@@ -3243,7 +3411,7 @@ static inline bool op_is_compute(ggml_tensor *node)
return !ggml_op_is_empty(node->op) && !ggml_is_empty(node) && (node->flags & GGML_TENSOR_FLAG_COMPUTE);
}
static bool is_hmx_eligible(const ggml_tensor * t) {
static bool mm_is_hmx_eligible(const ggml_tensor * t) {
if (opt_nhmx == 0) { return false; }
const ggml_tensor * src0 = t->src[0];
@@ -3262,7 +3430,7 @@ static bool is_hmx_eligible(const ggml_tensor * t) {
static bool is_mergeable_mul_mat(const ggml_tensor * t) {
if (!t || t->op != GGML_OP_MUL_MAT) return false;
if (t->src[1]->type != GGML_TYPE_F32) return false;
return ggml_is_quantized(t->src[0]->type) && !is_hmx_eligible(t);
return ggml_is_quantized(t->src[0]->type) && !mm_is_hmx_eligible(t);
}
static bool is_mergeable_mul_mat_pair(const ggml_tensor * n1, const ggml_tensor * n2) {
@@ -3357,6 +3525,26 @@ static bool try_fuse_node(const ggml_hexagon_session * sess, const ggml_cgraph *
}
}
if (n->op == GGML_OP_MUL_MAT && next_node) {
if (next_node->op == GGML_OP_ADD && op_is_compute(next_node) && ggml_can_fuse(graph, i, { GGML_OP_MUL_MAT, GGML_OP_ADD })) {
if (next_node->src[0] == n || next_node->src[1] == n) {
struct htp_mm_kernel_params kparams;
ggml_hexagon_precompute_matmul_params(sess, n->src[0], n->src[1], next_node, &kparams);
if ((size_t)kparams.vtcm_size <= sess->vtcm_size) {
htp_opnode node(n, {}, HTP_OP_MUL_MAT_ADD);
node.add_fused(next_node);
memcpy(node.kernel_params, &kparams, sizeof(kparams));
nodes.push_back(std::move(node));
i += 1;
return true;
} else {
HEX_VERBOSE("ggml-hex: skip MUL_MAT_ADD fusion because VTCM needed (%d) > budget (%zu)\n",
kparams.vtcm_size, sess->vtcm_size);
}
}
}
}
return false;
}
@@ -3393,6 +3581,11 @@ static ggml_status ggml_backend_hexagon_graph_compute(ggml_backend_t backend, gg
node.node->src[0], node.node->src[1], node.node,
(struct htp_mm_kernel_params *)node.kernel_params
);
} else if (node.opcode == HTP_OP_FLASH_ATTN_EXT) {
ggml_hexagon_precompute_flash_attn_params(sess,
node.node,
(struct htp_fa_kernel_params *)node.kernel_params
);
}
computed_nodes.push_back(std::move(node));
}
@@ -4079,6 +4272,7 @@ static void ggml_hexagon_init(ggml_backend_reg * reg) {
const char * str_use_hmx = getenv("GGML_HEXAGON_USE_HMX");
const char * str_nhmx = getenv("GGML_HEXAGON_NHMX");
const char * str_mm_select = getenv("GGML_HEXAGON_MM_SELECT");
const char * str_fa_select = getenv("GGML_HEXAGON_FA_SELECT");
const char * str_ndev = getenv("GGML_HEXAGON_NDEV");
const char * str_arch = getenv("GGML_HEXAGON_ARCH");
const char * str_vmem = getenv("GGML_HEXAGON_VMEM");
@@ -4120,6 +4314,7 @@ static void ggml_hexagon_init(ggml_backend_reg * reg) {
opt_nhvx = str_nhvx ? strtoul(str_nhvx, NULL, 0) : opt_nhvx;
opt_nhmx = str_nhmx ? atoi(str_nhmx) : (str_use_hmx ? atoi(str_use_hmx) : opt_nhmx);
opt_mm_select = str_mm_select ? atoi(str_mm_select) : opt_mm_select;
opt_fa_select = str_fa_select ? atoi(str_fa_select) : opt_fa_select;
opt_ndev = str_ndev ? strtoul(str_ndev, NULL, 0) : opt_ndev;
opt_hostbuf = str_hostbuf ? atoi(str_hostbuf) : opt_hostbuf;
opt_mbuf = str_mbuf ? strtoul(str_mbuf, NULL, 0) * MiB : opt_mbuf;
+13 -1
View File
@@ -11,6 +11,7 @@
#include <stdio.h>
#include "htp-ops.h"
#include "htp/matmul-ops.h"
#include "htp/flash-attn-ops.h"
struct htp_opnode {
ggml_tensor * node = nullptr;
@@ -335,7 +336,8 @@ struct htp_opformat {
}
void format_kernel_params(char * str, size_t max_size, const htp_opnode & node) {
if (node.opcode == HTP_OP_MUL_MAT || node.opcode == HTP_OP_MUL_MAT_ID ||
node.opcode == HTP_OP_MUL_MAT_QKV || node.opcode == HTP_OP_MUL_MAT_FFN) {
node.opcode == HTP_OP_MUL_MAT_QKV || node.opcode == HTP_OP_MUL_MAT_FFN ||
node.opcode == HTP_OP_MUL_MAT_ADD) {
const auto * kparams = (const struct htp_mm_kernel_params *) node.kernel_params;
const char * path = "unknown";
int32_t type = kparams->kernel_type;
@@ -350,6 +352,16 @@ struct htp_opformat {
path = "hvx-flat";
}
snprintf(str, max_size, "%s vtcm %d", path, (int) kparams->vtcm_size);
} else if (node.opcode == HTP_OP_FLASH_ATTN_EXT) {
const auto * kparams = (const struct htp_fa_kernel_params *) node.kernel_params;
const char * path = "unknown";
int32_t type = kparams->kernel_type;
if (type == HTP_FA_KERNEL_HMX) {
path = kparams->u.hmx.pipeline ? "hmx-pipe" : "hmx-seq";
} else if (type == HTP_FA_KERNEL_HVX) {
path = "hvx";
}
snprintf(str, max_size, "%s vtcm %d", path, (int) kparams->vtcm_size);
} else {
snprintf(str, max_size, "----");
}
+2 -7
View File
@@ -20,9 +20,6 @@ add_library(${HTP_LIB} SHARED
worker-pool.c
hex-dma.c
hmx-queue.c
flash-attn-ops.c
hmx-flash-attn-ops.c
matmul-ops.c
binary-ops.c
unary-ops.c
sum-rows-ops.c
@@ -42,16 +39,14 @@ add_library(${HTP_LIB} SHARED
solve-tri-ops.c
gated-delta-net-ops.c
pad-ops.c
matmul-ops.c
flash-attn-ops.c
)
target_compile_definitions(${HTP_LIB} PRIVATE
$<IF:$<BOOL:${HEXAGON_HTP_DEBUG}>,HTP_DEBUG=1,NDEBUG=1>
$<IF:$<BOOL:${HEXAGON_HTP_DEBUG}>,FARF_HIGH=1,>)
if (GGML_HEXAGON_FA_EXP2_HF)
message(STATUS "ggml-htp: HMX_FA_USE_EXP2_HF=1 (use FP16 exp2 polynomial in FA softmax)")
target_compile_definitions(${HTP_LIB} PRIVATE HMX_FA_USE_EXP2_HF=1)
endif()
build_idl(htp_iface.idl ${HTP_LIB})
File diff suppressed because it is too large Load Diff
+253
View File
@@ -0,0 +1,253 @@
#ifndef HTP_FLASH_ATTN_OPS_H
#define HTP_FLASH_ATTN_OPS_H
#include <stdint.h>
#include <stddef.h>
#include <stdbool.h>
#include "hex-fastdiv.h"
#include "hex-common.h"
#ifdef __cplusplus
extern "C" {
#endif
// Tile constants (mirrored from hmx-utils.h for use on host side if needed)
#define HMX_FP16_TILE_N_ROWS 32
#define HMX_FP16_TILE_N_COLS 32
#define HMX_FP16_TILE_N_ELMS 1024
#define HMX_FP16_TILE_SIZE 2048
#define HVX_FA_DMA_CACHE_SIZE 128
#define HMX_FA_DMA_CACHE_SIZE 4
#define HTP_FA_M_INITIAL_VAL -10000.0f
enum htp_fa_kernel_type {
HTP_FA_KERNEL_UNSUPPORTED = 0,
HTP_FA_KERNEL_HVX,
HTP_FA_KERNEL_HMX
};
struct htp_fa_kernel_params {
uint8_t kernel_type; // enum htp_fa_kernel_type
uint8_t is_q_fp32; // 1 = Q type is F32, 0 = F16
uint8_t is_dst_fp32; // 1 = dst type is F32, 0 = F16
uint8_t n_threads; // Number of threads to run
// Common parameters
uint16_t Br;
uint16_t Bc;
uint16_t n_kv_blocks; // also HVX's n_blocks
uint16_t G; // GQA factor (n_heads / n_kv_heads)
float scale;
float max_bias;
float logit_softcap;
uint32_t vtcm_size;
uint32_t qrows;
uint32_t qrows_per_thread;
float m0;
float m1;
uint32_t n_head_log2;
struct fastdiv_values src3_div2;
struct fastdiv_values src3_div3;
union {
struct {
uint32_t g_br;
uint32_t row_buf_stride;
uint32_t mask_buf_row_stride;
int32_t mask_broadcast;
int32_t pipeline;
struct fastdiv_values div_G;
} hmx;
struct {
uint32_t size_q_row_padded;
uint32_t size_k_row_padded;
uint32_t size_v_row_padded;
struct fastdiv_values src0_div21;
struct fastdiv_values src0_div1;
struct fastdiv_values broadcast_rk2;
struct fastdiv_values broadcast_rk3;
struct fastdiv_values broadcast_rv2;
struct fastdiv_values broadcast_rv3;
} hvx;
} u;
};
#if defined(__cplusplus)
static_assert(sizeof(struct htp_fa_kernel_params) <= 128, "htp_fa_kernel_params is too large for kernel_params blob");
#endif
// Exact VTCM usage for a given (gqa_factor, DK, DV, Br, Bc) configuration.
// g_br = hex_align_up(gqa_factor * Br, 32) replaces Br for all Q/O/S/P/D dimensions.
// Layout: Q + O_ping + O_pong + K_dma*2 + V_dma*2 + K_tile + V_tile + S + P + D + vectors + scales
// Mask is DMA'd into a VTCM buffer (Br rows per KV block) to avoid DDR reads in softmax.
static inline size_t hmx_fa_compute_vtcm_usage(size_t gqa_factor, size_t DK, size_t DV, size_t Br, size_t Bc, size_t n_threads, bool pipeline) {
const size_t g_br = hex_align_up(gqa_factor * Br, HMX_FP16_TILE_N_ROWS);
const size_t q_tile_size = hex_align_up(g_br * DK * sizeof(__fp16), 4096); // Q: [g_br, DK]
const size_t o_tile_size = hex_align_up(g_br * DV * sizeof(__fp16), 4096); // O: [g_br, DV] x2 ping-pong
const size_t k_dma_size = hex_align_up(Bc * hex_round_up(DK * sizeof(__fp16), 128), 4096); // K DMA: [Bc, DK] x2 double-buf
const size_t v_dma_size = hex_align_up(Bc * hex_round_up(DV * sizeof(__fp16), 128), 4096); // V DMA: [Bc, DV] x2 double-buf
const size_t k_tile_size = hex_align_up(Bc * DK * sizeof(__fp16), 4096); // K tiles: [Bc, DK] interleaved
const size_t v_tile_size = hex_align_up(Bc * DV * sizeof(__fp16), 4096); // V tiles: [Bc, DV] interleaved
const size_t s_tile_size = hex_align_up(g_br * Bc * sizeof(__fp16), 4096); // S/P:[g_br, Bc]
const size_t d_tile_size = hex_align_up(g_br * g_br * sizeof(__fp16), 4096); // D: [g_br, g_br]
const size_t col_vec_size = hex_align_up(g_br * sizeof(float), 256); // m, l, etc.
const size_t row_vec_size = hex_align_up(Bc * sizeof(__fp16), 256);
const size_t m_line_size = hex_align_up(Bc * sizeof(__fp16), 128);
const size_t m_buf_size = hex_align_up(Br * m_line_size, 4096) * HMX_FA_DMA_CACHE_SIZE;
const size_t slopes_size = hex_align_up(g_br * sizeof(__fp16), 128);
return q_tile_size * 1 // Q tiles
+ o_tile_size * 2 // O ping-pong
+ k_dma_size * 2 // K DMA x2
+ v_dma_size * 2 // V DMA x2
+ k_tile_size * 1 // K tiles
+ v_tile_size * (pipeline ? 2 : 1) // V tiles (double-buffered if pipelining)
+ s_tile_size * 2 // S + P
+ d_tile_size * 1 // D (diagonal matrix)
+ col_vec_size * 4 // m_vec, l_vec, s_rowmax, p_rowsum
+ row_vec_size * 2 * n_threads // per-thread softmax row scratch
+ m_buf_size * 1 // mask VTCM buffer [Br rows]
+ slopes_size // Slopes
+ 256 * 2; // HMX scales (id + qk)
}
#define FA_HVX_BLOCK_SIZE 64
static inline size_t hvx_fa_compute_vtcm_usage(size_t DK, size_t DV, bool is_q_fp32, bool has_mask, size_t n_threads) {
const size_t size_q_row_padded = hex_round_up(DK * (is_q_fp32 ? 4 : 2), 128);
const size_t size_k_row_padded = hex_round_up(DK * sizeof(__fp16), 128);
const size_t size_v_row_padded = hex_round_up(DV * sizeof(__fp16), 128);
const size_t size_q_block = size_q_row_padded * 1;
const size_t size_k_block = size_k_row_padded * FA_HVX_BLOCK_SIZE;
const size_t size_v_block = size_v_row_padded * FA_HVX_BLOCK_SIZE;
const size_t size_m_block = hex_round_up(FA_HVX_BLOCK_SIZE * sizeof(__fp16), 128);
const size_t size_vkq_acc = hex_round_up(DV * sizeof(float), 128);
const size_t size_per_thread = size_q_block * 1
+ size_k_block * 2
+ size_v_block * 2
+ (has_mask ? size_m_block * HVX_FA_DMA_CACHE_SIZE : 0)
+ size_vkq_acc;
return size_per_thread * n_threads;
}
#define FA_MIN_KV_BLOCKS 3
// Cost-based (Br, Bc) search for flash attention with pipeline constraint.
static inline int hmx_fa_find_chunk_size(size_t * Br_out,
size_t * Bc_out,
size_t gqa_factor,
size_t DK,
size_t DV,
size_t qo_len,
size_t kv_len,
size_t vtcm_budget,
size_t n_threads) {
const size_t T = HMX_FP16_TILE_N_ROWS; // 32
const size_t br_unit = hmx_ceil_div(T, gqa_factor);
const size_t bc_unit = HMX_FP16_TILE_N_COLS * 2; // 64
const size_t fp16 = sizeof(__fp16);
const bool can_pipeline = (kv_len >= FA_MIN_KV_BLOCKS * bc_unit && n_threads >= 2);
// Approximate per-unit VTCM costs (without per-buffer alignment padding).
const size_t per_gbr = (DK + 2 * DV) * fp16 + 4 * sizeof(float); // Q + O*2 + 4 col vectors
const size_t per_gbr2 = fp16; // D diagonal matrix
const size_t per_bc =
3 * DK * fp16 + (can_pipeline ? 4 : 3) * DV * fp16 + 2 * n_threads * fp16; // K/V DMA x2 + tiles + row bufs
const size_t per_gbr_bc = 2 * fp16; // S + P
const size_t overhead = 256 * 2 + 13 * 4096;
if (vtcm_budget <= overhead) {
return -1;
}
const size_t usable = vtcm_budget - overhead;
// Br_max: largest Br aligned to br_unit that does not exceed qo_len.
const size_t Br_max = qo_len >= br_unit ? hex_align_down(qo_len, br_unit) : br_unit;
// Pipeline constraint: cap Bc so n_kv_blocks >= FA_MIN_KV_BLOCKS.
// Only relax when kv_len is too short to form enough blocks.
const size_t Bc_limit = can_pipeline ? hex_align_down(kv_len / FA_MIN_KV_BLOCKS, bc_unit) :
(kv_len >= bc_unit ? hex_align_down(kv_len, bc_unit) : bc_unit);
// Cost coefficients calibrated from profiling
const size_t c_q_fixed = 1400; // per-Q-block: q_load + epilogue o_update + o_norm + o_store
const size_t c_iter_fixed = 200; // per-KV-iter: HMX queue push/pop + DMA pop + barriers
size_t best_cost = SIZE_MAX, best_mn = 0;
size_t best_Br = 0, best_Bc = 0;
for (size_t Br = Br_max; Br >= br_unit; Br -= br_unit) {
const size_t g_br = hex_align_up(gqa_factor * Br, T);
// g_br-dependent VTCM cost: g_br * per_gbr + g_br*g_br * per_gbr2
const size_t gbr_cost = g_br * per_gbr + g_br * g_br * per_gbr2;
if (gbr_cost >= usable) {
if (Br == br_unit) {
break;
}
continue;
}
// Analytically solve for max Bc:
// remain >= Bc * (per_bc + g_br * per_gbr_bc + Br * fp16 * HMX_FA_DMA_CACHE_SIZE)
// The Br * fp16 term accounts for the VTCM mask buffer [Br * Bc].
const size_t remain = usable - gbr_cost;
const size_t bc_denom = per_bc + g_br * per_gbr_bc + Br * fp16 * HMX_FA_DMA_CACHE_SIZE;
size_t Bc = hex_smin(hex_align_down(remain / bc_denom, bc_unit), Bc_limit);
if (Bc < bc_unit) {
if (Br == br_unit) {
break;
}
continue;
}
// Exact VTCM verification (alignment padding may push over budget)
while (Bc >= bc_unit && hmx_fa_compute_vtcm_usage(gqa_factor, DK, DV, Br, Bc, n_threads, can_pipeline) > vtcm_budget) {
Bc -= bc_unit;
}
if (Bc < bc_unit) {
if (Br == br_unit) {
break;
}
continue;
}
const size_t q_blocks = (qo_len + Br - 1) / Br;
const size_t kv_blocks = (kv_len + Bc - 1) / Bc;
const size_t cost = q_blocks * (c_q_fixed + kv_blocks * c_iter_fixed);
const size_t mn = Br * Bc;
if (cost < best_cost || (cost == best_cost && mn > best_mn)) {
best_cost = cost;
best_mn = mn;
best_Br = Br;
best_Bc = Bc;
}
if (Br == br_unit) {
break;
}
}
if (best_Br == 0) {
return -1;
}
*Br_out = best_Br;
*Bc_out = best_Bc;
return 0;
}
#ifdef __cplusplus
}
#endif
#endif /* HTP_FLASH_ATTN_OPS_H */
+15 -15
View File
@@ -138,27 +138,28 @@ static inline bool dma_queue_push_single_1d(dma_queue * q, dma_ptr dptr, size_t
}
dma_descriptor_1d * desc = (dma_descriptor_1d *) &q->desc[q->push_idx];
desc->next = NULL;
desc->desc_size = 0; // 1D mode
desc->src_bypass = dma_src_l2_bypass_on;
desc->dst_bypass = dma_dst_l2_bypass_on;
desc->order = 0;
desc->done = 0;
desc->src = (void *) dptr.src;
desc->dst = (void *) dptr.dst;
desc->size = size;
desc->src = (void *) dptr.src;
desc->dst = (void *) dptr.dst;
desc->size = size;
q->dptr[q->push_idx] = dptr;
if (size) {
desc->next = NULL;
desc->desc_size = 0; // 1D mode
desc->src_bypass = dma_src_l2_bypass_on;
desc->dst_bypass = dma_dst_l2_bypass_on;
desc->order = 0;
desc->done = 0;
htp_trace_event_start(q->trace, HTP_TRACE_EVT_DMA, q->push_idx);
dmlink(q->tail, desc);
q->tail = (dma_descriptor_2d *) desc;
} else {
desc->done = 1;
desc->desc_size = 0;
desc->done = 1;
}
// FARF(ERROR, "dma-push: i %u row-size %u nrows %d dst %p src %p\n", q->push_idx, row_size, nrows, dptr.dst, dptr.src);
q->push_idx = (q->push_idx + 1) & q->idx_mask;
return true;
}
@@ -320,7 +321,7 @@ static inline bool dma_queue_push_vtcm_to_ddr(dma_queue * q, dma_ptr dptr, size_
return dma_queue_push(q, dptr, dst_row_size, src_row_size, dst_row_size, nrows);
}
#define DMA_CACHE_MAX_SIZE 64U
#define DMA_CACHE_MAX_SIZE 256U
typedef struct {
uint8_t *base;
@@ -352,20 +353,19 @@ static inline bool dma_cache_push(dma_queue *q, dma_cache *c, const uint8_t * sr
if (c->src[i] == (uint32_t) src) {
c->age[i] = 0;
dst = c->base + (i * c->line_size); nrows = 0; // dummy dma
// FARF(ERROR, "dma-cache: found %p", src);
} else {
c->age[i]++;
if (c->age[i] > o_age) { o_age = c->age[i]; o_idx = i; }
}
}
if (!dst) {
// FARF(ERROR, "dma-cache: replacing #%u : age %u %p -> %p", o_idx, c->age[o_idx], (void *) c->src[o_idx], src);
c->age[o_idx] = 0;
c->src[o_idx] = (uint32_t) src;
dst = c->base + o_idx * c->line_size; // normal nrows dma
return dma_queue_push(q, dma_make_ptr(dst, src), dst_stride, src_stride, row_size, nrows);
}
return dma_queue_push(q, dma_make_ptr(dst, src), dst_stride, src_stride, row_size, nrows);
return dma_queue_push_single_1d(q, dma_make_ptr(dst, src), 0);
}
#ifdef __cplusplus
@@ -0,0 +1,96 @@
#ifndef HMX_FA_KERNELS_H
#define HMX_FA_KERNELS_H
#include <stdint.h>
#include <stddef.h>
#include <stdbool.h>
#include "hvx-utils.h"
#include "hmx-utils.h"
// HMX-specific parameters, offsets and inner kernels for Flash Attention
// Scatter offsets for diagonal tile: entry[2i] = i*136, entry[2i+1] = i*136+6
// 136 = 4 * 32 + 8 = byte offset to diagonal in a 32x32 fp16 interleaved tile
static const int16_t d_tile_scatter_offsets[64] __attribute__((aligned(128))) = {
0 * 136, 0 * 136 + 6,
1 * 136, 1 * 136 + 6,
2 * 136, 2 * 136 + 6,
3 * 136, 3 * 136 + 6,
4 * 136, 4 * 136 + 6,
5 * 136, 5 * 136 + 6,
6 * 136, 6 * 136 + 6,
7 * 136, 7 * 136 + 6,
8 * 136, 8 * 136 + 6,
9 * 136, 9 * 136 + 6,
10 * 136, 10 * 136 + 6,
11 * 136, 11 * 136 + 6,
12 * 136, 12 * 136 + 6,
13 * 136, 13 * 136 + 6,
14 * 136, 14 * 136 + 6,
15 * 136, 15 * 136 + 6,
0, 0,
0, 0,
0, 0,
0, 0,
0, 0,
0, 0,
0, 0,
0, 0,
0, 0,
0, 0,
0, 0,
0, 0,
0, 0,
0, 0,
0, 0,
0, 0,
};
// Inner HMX tile computation kernels
static inline void hmx_fa_qk_dot_tile(
const __fp16 * row_tiles,
const __fp16 * col_tiles,
__fp16 * out_tile,
size_t n_dot_tiles
) {
for (size_t k = 0; k < n_dot_tiles; ++k) {
Q6_activation_hf_mxmem_RR((unsigned int) row_tiles, 2047);
Q6_weight_hf_mxmem_RR((unsigned int) col_tiles, 2047);
row_tiles += HMX_FP16_TILE_N_ELMS;
col_tiles += HMX_FP16_TILE_N_ELMS;
}
Q6_mxmem_AR_after_hf(out_tile, 0);
}
static inline void hmx_fa_o_update_tile(
const __fp16 * d_diag,
const __fp16 * o_rc,
const __fp16 * p_tile_in,
const __fp16 * v_tile_in,
__fp16 * o_tile_out,
size_t n_col_tiles
) {
Q6_activation_hf_mxmem_RR((unsigned int) d_diag, 2047);
Q6_weight_hf_mxmem_RR((unsigned int) o_rc, 2047);
for (size_t k = 0; k < n_col_tiles; ++k) {
Q6_activation_hf_mxmem_RR((unsigned int) p_tile_in, 2047);
Q6_weight_hf_mxmem_RR((unsigned int) v_tile_in, 2047);
p_tile_in += HMX_FP16_TILE_N_ELMS;
v_tile_in += HMX_FP16_TILE_N_ELMS;
}
Q6_mxmem_AR_after_hf(o_tile_out, 0);
}
static inline void hmx_fa_o_norm_tile(
const __fp16 * d_diag,
const __fp16 * o_rc,
__fp16 * o_out
) {
Q6_activation_hf_mxmem_RR((unsigned int) d_diag, 2047);
Q6_weight_hf_mxmem_RR((unsigned int) o_rc, 2047);
Q6_mxmem_AR_after_hf(o_out, 0);
}
#endif /* HMX_FA_KERNELS_H */
File diff suppressed because it is too large Load Diff
@@ -712,7 +712,17 @@ static inline void hmx_matmul_job_init(hmx_matmul_job_t * job,
// output : fp16 -> f32p
static void transfer_output_chunk_fp16_to_fp32(float *restrict dst, const __fp16 *restrict vtcm_src, uint32_t start_row, uint32_t n_rows, uint32_t n_cols, uint32_t dst_stride, uint32_t dst_cols) {
static void transfer_output_chunk_fp16_to_fp32(
float *restrict dst,
const float *restrict src2,
const __fp16 *restrict vtcm_src,
uint32_t start_row,
uint32_t n_rows,
uint32_t n_cols,
uint32_t dst_stride,
uint32_t src2_stride,
uint32_t dst_cols
) {
assert(n_cols % HTP_MM_HMX_TILE_N_COLS == 0);
const size_t tile_row_stride = (n_cols / HTP_MM_HMX_TILE_N_COLS) * HTP_MM_HMX_TILE_N_ELMS;
@@ -727,6 +737,7 @@ static void transfer_output_chunk_fp16_to_fp32(float *restrict dst, const __fp16
const size_t r1 = (r_idx0 % HTP_MM_HMX_TILE_N_ROWS) / 2; // index of the row pair within the tile
const __fp16 *row_base = vtcm_src + r0 * tile_row_stride;
float *output_row_base = dst + r * dst_stride; // global memory row base for row r (and r+1)
const float *src2_row_base = src2 ? (src2 + r * src2_stride) : NULL;
#pragma unroll(4)
for (size_t c = 0; c < limit_c_aligned; c += HTP_MM_HMX_TILE_N_COLS) {
@@ -738,9 +749,20 @@ static void transfer_output_chunk_fp16_to_fp32(float *restrict dst, const __fp16
HVX_Vector *pv_out0 = (HVX_Vector *) (output_row_base + c + 0);
HVX_Vector *pv_out1 = (HVX_Vector *) (output_row_base + c + dst_stride);
*pv_out0 = Q6_Vsf_equals_Vqf32(Q6_V_lo_W(vp));
HVX_Vector v_out0 = Q6_Vsf_equals_Vqf32(Q6_V_lo_W(vp));
if (src2_row_base) {
HVX_Vector v_src2_0 = hvx_vmemu(src2_row_base + c + 0);
v_out0 = hvx_vec_add_f32_f32(v_out0, v_src2_0);
}
*pv_out0 = v_out0;
if (r + 1 < n_rows) {
*pv_out1 = Q6_Vsf_equals_Vqf32(Q6_V_hi_W(vp));
HVX_Vector v_out1 = Q6_Vsf_equals_Vqf32(Q6_V_hi_W(vp));
if (src2_row_base) {
HVX_Vector v_src2_1 = hvx_vmemu(src2_row_base + c + src2_stride);
v_out1 = hvx_vec_add_f32_f32(v_out1, v_src2_1);
}
*pv_out1 = v_out1;
}
}
@@ -752,9 +774,20 @@ static void transfer_output_chunk_fp16_to_fp32(float *restrict dst, const __fp16
HVX_Vector v = ((const HVX_Vector *) tile)[r1];
HVX_VectorPair vp = Q6_Wqf32_vmpy_VhfVhf(v, one);
hvx_vec_store_u(output_row_base + c, valid_c * sizeof(float), Q6_Vsf_equals_Vqf32(Q6_V_lo_W(vp)));
HVX_Vector v_out0 = Q6_Vsf_equals_Vqf32(Q6_V_lo_W(vp));
if (src2_row_base) {
HVX_Vector v_src2_0 = hvx_vmemu(src2_row_base + c + 0);
v_out0 = hvx_vec_add_f32_f32(v_out0, v_src2_0);
}
hvx_vec_store_u(output_row_base + c, valid_c * sizeof(float), v_out0);
if (r + 1 < n_rows) {
hvx_vec_store_u(output_row_base + c + dst_stride, valid_c * sizeof(float), Q6_Vsf_equals_Vqf32(Q6_V_hi_W(vp)));
HVX_Vector v_out1 = Q6_Vsf_equals_Vqf32(Q6_V_hi_W(vp));
if (src2_row_base) {
HVX_Vector v_src2_1 = hvx_vmemu(src2_row_base + c + src2_stride);
v_out1 = hvx_vec_add_f32_f32(v_out1, v_src2_1);
}
hvx_vec_store_u(output_row_base + c + dst_stride, valid_c * sizeof(float), v_out1);
}
}
}
@@ -763,11 +796,13 @@ static void transfer_output_chunk_fp16_to_fp32(float *restrict dst, const __fp16
typedef struct {
const __fp16 *vtcm_src;
float *dst;
const float *src2;
uint32_t n_tasks;
uint32_t n_tot_chunks;
uint32_t n_chunks_per_task;
uint32_t n_cols;
uint32_t dst_stride; // DDR row stride
uint32_t src2_stride; // DDR row stride for residual
uint32_t dst_cols; // Actual output columns
struct htp_thread_trace * traces;
} output_transfer_task_state_t;
+35 -35
View File
@@ -42,14 +42,14 @@ static const int32_t hmx_transpose_scatter_offsets[32] __attribute__((aligned(VL
// Full range: start_row=0, end_row=n_cols.
static inline void hmx_interleave_rows_to_tiles(__fp16 * restrict vtcm_dst,
const __fp16 * restrict vtcm_src,
int n_cols,
int k,
int src_stride,
int start_row,
int end_row) {
uint32_t n_cols,
uint32_t k,
size_t src_stride,
uint32_t start_row,
uint32_t end_row) {
assert(k % HMX_FP16_TILE_N_COLS == 0);
const int n_k_tiles = k / HMX_FP16_TILE_N_COLS;
const uint32_t n_k_tiles = k / HMX_FP16_TILE_N_COLS;
const HVX_Vector v_scat_base = hvx_vmem(hmx_transpose_scatter_offsets);
const HVX_Vector v_scat_step = Q6_V_vsplat_R(4);
const HVX_VectorPred q_mask64 = Q6_Q_vsetq_R(64);
@@ -65,14 +65,14 @@ static inline void hmx_interleave_rows_to_tiles(__fp16 * restrict vtcm_dst,
if (pair_scatter) {
// Step c by 64 fp16 (two K-tiles per scatter), advance dst by 2 tiles per iter.
const int c_step = 2 * HMX_FP16_TILE_N_COLS;
const size_t c_byte_step = (size_t) c_step * sizeof(__fp16);
const size_t dst_step = 2 * (size_t) HMX_FP16_TILE_N_ELMS;
const int n_c_iters = k / c_step;
const uint32_t c_step = 2 * HMX_FP16_TILE_N_COLS;
const size_t c_byte_step = (size_t) c_step * sizeof(__fp16);
const size_t dst_step = 2 * (size_t) HMX_FP16_TILE_N_ELMS;
const uint32_t n_c_iters = k / c_step;
for (int r = start_row; r < end_row; r += 2) {
const int ct = r / HMX_FP16_TILE_N_ROWS;
const int local_r = r % HMX_FP16_TILE_N_ROWS;
for (uint32_t r = start_row; r < end_row; r += 2) {
const uint32_t ct = r / HMX_FP16_TILE_N_ROWS;
const uint32_t local_r = r % HMX_FP16_TILE_N_ROWS;
const bool next_row_valid = (r + 1) < end_row && (r + 1) < n_cols;
const HVX_Vector v_off0 = Q6_Vw_vadd_VwVw(v_scat_base, Q6_V_vsplat_R(local_r * 4));
const HVX_Vector v_off1 = Q6_Vw_vadd_VwVw(v_off0, v_scat_step);
@@ -86,7 +86,7 @@ static inline void hmx_interleave_rows_to_tiles(__fp16 * restrict vtcm_dst,
assert(c_byte_step % 128 == 0);
if (p1) {
for (int i = 0; i < n_c_iters; ++i) {
for (uint32_t i = 0; i < n_c_iters; ++i) {
HVX_Vector v0 = hvx_vmem(p0); p0 += c_byte_step;
HVX_Vector v1 = hvx_vmem(p1); p1 += c_byte_step;
Q6_vscatter_RMVwV((size_t) tile_base, pair_region, v_off0, v0);
@@ -95,7 +95,7 @@ static inline void hmx_interleave_rows_to_tiles(__fp16 * restrict vtcm_dst,
}
} else {
const HVX_Vector vzero = Q6_V_vzero();
for (int i = 0; i < n_c_iters; ++i) {
for (uint32_t i = 0; i < n_c_iters; ++i) {
HVX_Vector v0 = hvx_vmem(p0); p0 += c_byte_step;
Q6_vscatter_RMVwV((size_t) tile_base, pair_region, v_off0, v0);
Q6_vscatter_RMVwV((size_t) tile_base, pair_region, v_off1, vzero);
@@ -105,14 +105,14 @@ static inline void hmx_interleave_rows_to_tiles(__fp16 * restrict vtcm_dst,
}
} else {
// Fallback: scatter one K-tile per call (region 2047, masked).
const int c_step = HMX_FP16_TILE_N_COLS;
const size_t c_byte_step = (size_t) c_step * sizeof(__fp16);
const size_t dst_step = (size_t) HMX_FP16_TILE_N_ELMS;
const int n_c_iters = k / c_step;
const uint32_t c_step = HMX_FP16_TILE_N_COLS;
const size_t c_byte_step = (size_t) c_step * sizeof(__fp16);
const size_t dst_step = (size_t) HMX_FP16_TILE_N_ELMS;
const uint32_t n_c_iters = k / c_step;
for (int r = start_row; r < end_row; r += 2) {
const int ct = r / HMX_FP16_TILE_N_ROWS;
const int local_r = r % HMX_FP16_TILE_N_ROWS;
for (uint32_t r = start_row; r < end_row; r += 2) {
const uint32_t ct = r / HMX_FP16_TILE_N_ROWS;
const uint32_t local_r = r % HMX_FP16_TILE_N_ROWS;
const bool next_row_valid = (r + 1) < end_row && (r + 1) < n_cols;
const HVX_Vector v_off0 = Q6_Vw_vadd_VwVw(v_scat_base, Q6_V_vsplat_R(local_r * 4));
const HVX_Vector v_off1 = Q6_Vw_vadd_VwVw(v_off0, v_scat_step);
@@ -122,7 +122,7 @@ static inline void hmx_interleave_rows_to_tiles(__fp16 * restrict vtcm_dst,
const uint8_t * p1 = next_row_valid ? (const uint8_t *) (vtcm_src + (r + 1) * src_stride) : NULL;
if (p1) {
for (int i = 0; i < n_c_iters; ++i) {
for (uint32_t i = 0; i < n_c_iters; ++i) {
HVX_Vector v0 = hvx_vmemu(p0); p0 += c_byte_step;
HVX_Vector v1 = hvx_vmemu(p1); p1 += c_byte_step;
Q6_vscatter_QRMVwV(q_mask64, (size_t) tile_base, single_region, v_off0, v0);
@@ -131,7 +131,7 @@ static inline void hmx_interleave_rows_to_tiles(__fp16 * restrict vtcm_dst,
}
} else {
const HVX_Vector vzero = Q6_V_vzero();
for (int i = 0; i < n_c_iters; ++i) {
for (uint32_t i = 0; i < n_c_iters; ++i) {
HVX_Vector v0 = hvx_vmemu(p0); p0 += c_byte_step;
Q6_vscatter_QRMVwV(q_mask64, (size_t) tile_base, single_region, v_off0, v0);
Q6_vscatter_QRMVwV(q_mask64, (size_t) tile_base, single_region, v_off1, vzero);
@@ -148,24 +148,24 @@ static inline void hmx_interleave_rows_to_tiles(__fp16 * restrict vtcm_dst,
// Full range: start_row=0, end_row=n_rows.
static inline void hmx_interleave_cols_to_tiles(__fp16 * restrict tiles_out,
const __fp16 * restrict src,
int n_rows,
int head_dim,
int src_stride,
int n_row_tiles,
int start_row,
int end_row) {
uint32_t n_rows,
uint32_t head_dim,
size_t src_stride,
uint32_t n_row_tiles,
uint32_t start_row,
uint32_t end_row) {
__builtin_assume(head_dim > 0);
const size_t tile_stride_elms = (size_t) n_row_tiles * HMX_FP16_TILE_N_ELMS;
for (int r = start_row; r < end_row; r += 2) {
for (uint32_t r = start_row; r < end_row; r += 2) {
const bool next_row_valid = (r + 1) < end_row && (r + 1) < n_rows;
const HVX_Vector * pv_in0 = (const HVX_Vector *) (src + r * src_stride);
const HVX_Vector * pv_in1 = next_row_valid ? (const HVX_Vector *) (src + (r + 1) * src_stride) : NULL;
// Row-pair invariants hoisted out of the c loop.
const int r0 = r / HMX_FP16_TILE_N_ROWS;
const int r1_half = (r % HMX_FP16_TILE_N_ROWS) / 2;
const uint32_t r0 = r / HMX_FP16_TILE_N_ROWS;
const uint32_t r1_half = (r % HMX_FP16_TILE_N_ROWS) / 2;
// tb0 starts at tile (c0=0, r0); tb1 at the adjacent dim-tile (c0=1, r0).
// Each c step (+= 64) advances both by 2 dim-tiles worth of fp16.
@@ -174,7 +174,7 @@ static inline void hmx_interleave_cols_to_tiles(__fp16 * restrict tiles_out,
const size_t tb_step = 2 * tile_stride_elms;
if (pv_in1) {
for (int c = 0; c < head_dim; c += 64) {
for (uint32_t c = 0; c < head_dim; c += 64) {
HVX_Vector v0 = *pv_in0++;
HVX_Vector v1 = *pv_in1++;
HVX_VectorPair vp = Q6_W_vshuff_VVR(v1, v0, -2);
@@ -185,7 +185,7 @@ static inline void hmx_interleave_cols_to_tiles(__fp16 * restrict tiles_out,
}
} else {
const HVX_Vector vzero = Q6_V_vzero();
for (int c = 0; c < head_dim; c += 64) {
for (uint32_t c = 0; c < head_dim; c += 64) {
HVX_Vector v0 = *pv_in0++;
HVX_VectorPair vp = Q6_W_vshuff_VVR(vzero, v0, -2);
((HVX_Vector *) tb0)[r1_half] = Q6_V_lo_W(vp);
+6
View File
@@ -60,6 +60,7 @@ enum htp_op_code {
HTP_OP_MUL_MAT_ID,
HTP_OP_MUL_MAT_QKV,
HTP_OP_MUL_MAT_FFN,
HTP_OP_MUL_MAT_ADD,
HTP_OP_RMS_NORM,
HTP_OP_RMS_NORM_MUL,
HTP_OP_UNARY_SILU,
@@ -175,6 +176,11 @@ enum htp_trace_event_id {
HTP_TRACE_EVT_HVX_W_DEQUANT = 23,
HTP_TRACE_EVT_HVX_W_PREP = 24,
HTP_TRACE_EVT_HVX_O_PROC = 25,
HTP_TRACE_EVT_HVX_FA_QK = 26,
HTP_TRACE_EVT_HVX_FA_SFM = 27,
HTP_TRACE_EVT_HVX_FA_Q_PREP = 28,
HTP_TRACE_EVT_HVX_FA_K_PREP = 29,
HTP_TRACE_EVT_HVX_FA_V_PREP = 30,
HTP_TRACE_EVT_HMX_COMP = 40,
};
+1 -12
View File
@@ -134,16 +134,7 @@ static inline HVX_Vector hvx_vec_f32_to_f16_shuff(HVX_Vector v0, HVX_Vector v1)
}
static inline HVX_Vector hvx_vec_f32_to_f16(HVX_Vector v0, HVX_Vector v1) {
HVX_Vector v = Q6_Vh_vdeal_Vh(hvx_vec_f32_to_f16_shuff(v0, v1));
#if __HVX_ARCH__ < 79
// replace NaNs with -INF, older arches produce NaNs for (-INF + 0.0)
const HVX_Vector neg_inf = hvx_vec_splat_f16(-INFINITY);
HVX_VectorPred nan = hvx_vec_is_nan_f16(v);
v = Q6_V_vmux_QVV(nan, neg_inf, v);
#endif
return v;
return Q6_Vh_vdeal_Vh(hvx_vec_f32_to_f16_shuff(v0, v1));
}
#if __HVX_ARCH__ >= 79
@@ -170,8 +161,6 @@ static inline HVX_VectorPair hvx_vec_f16_to_f32(HVX_Vector v) {
}
#endif
static inline HVX_Vector hvx_vec_i16_from_hf_rnd_sat(HVX_Vector vin) {
// This looks complicated.
// Ideally should just be Q6_Vh_equals_Vhf(vin)
+39
View File
@@ -16,6 +16,7 @@
#define EXP_COEFF_0 (0x3F000000) // 0.5 = 1/(2!)
#define EXP_LOGN2 (0x3F317218) // ln(2) = 0.6931471805
#define EXP_LOG2E (0x3FB8AA3B) // log2(e) = 1/ln(2) = 1.4426950408
#define EXP_LOG2E_F 1.44269504f
#define EXP_ONE (0x3f800000) // 1.0
#define EXP_RANGE_R (0x42B17218) // ln(FLT_MAX) approx = 88.7228
#define EXP_RANGE_L (0xC2B00000) // -88.0 (approx log(FLT_MIN))
@@ -213,4 +214,42 @@ static inline void hvx_exp_f32(uint8_t * restrict dst, const uint8_t * restrict
}
}
static inline HVX_Vector hvx_vec_exp2_f16(HVX_Vector x_v) {
const HVX_Vector zero_v = Q6_V_vzero();
const HVX_Vector half_hf_v = Q6_Vh_vsplat_R(0x3800); // fp16 0.5
// Clamp input to prevent integer underflow in FP16-to-INT16 conversion
const HVX_Vector v_clamp_min = hvx_vec_splat_f16(-24.0f);
x_v = Q6_Vhf_vmax_VhfVhf(v_clamp_min, x_v);
// k = round_toward_neg_inf(x); f = (float)k; frac = x - f
HVX_Vector x_minus_half = Q6_Vhf_equals_Vqf16(Q6_Vqf16_vsub_VhfVhf(x_v, half_hf_v));
HVX_Vector k_v = Q6_Vh_equals_Vhf(x_minus_half); // truncate to int16
HVX_Vector f_v = Q6_Vhf_equals_Vh(k_v); // back to fp16
HVX_Vector x_qf16 = Q6_Vqf16_vsub_VhfVhf(x_v, f_v); // fractional part in qf16
// Horner: y = ((((E5*x + E4)*x + E3)*x + E2)*x + E1)*x + E0
HVX_Vector y = Q6_Vqf16_vmpy_Vqf16Vqf16(Q6_Vh_vsplat_R(0x5082), x_qf16); // E5*x
y = Q6_Vqf16_vadd_Vqf16Vhf(y, Q6_Vh_vsplat_R(0x157d)); // + E4
y = Q6_Vqf16_vmpy_Vqf16Vqf16(y, x_qf16);
y = Q6_Vqf16_vadd_Vqf16Vhf(y, Q6_Vh_vsplat_R(0x20ed)); // + E3
y = Q6_Vqf16_vmpy_Vqf16Vqf16(y, x_qf16);
y = Q6_Vqf16_vadd_Vqf16Vhf(y, Q6_Vh_vsplat_R(0x2b1b)); // + E2
y = Q6_Vqf16_vmpy_Vqf16Vqf16(y, x_qf16);
y = Q6_Vqf16_vadd_Vqf16Vhf(y, Q6_Vh_vsplat_R(0x33b0)); // + E1
y = Q6_Vqf16_vmpy_Vqf16Vqf16(y, x_qf16);
y = Q6_Vqf16_vadd_Vqf16Vhf(y, Q6_Vh_vsplat_R(0x398c)); // + E0
y = Q6_Vqf16_vmpy_Vqf16Vqf16(y, x_qf16); // y = y * x
y = Q6_Vqf16_vadd_Vqf16Vhf(y, Q6_Vh_vsplat_R(0x3c00)); // + 1.0
// Combine polynomial (mantissa) with integer part (exponent): result = y * 2^k
y = Q6_Vhf_equals_Vqf16(y);
HVX_Vector y_exp = Q6_Vuh_vlsr_VuhR(Q6_Vh_vasl_VhR(y, 1), 11);
y_exp = Q6_Vh_vadd_VhVh(k_v, y_exp);
HVX_VectorPred q_underflow = Q6_Q_vcmp_gt_VhVh(zero_v, y_exp);
y = Q6_Vh_vaslacc_VhVhR(y, k_v, 10);
return Q6_V_vmux_QVV(q_underflow, zero_v, y);
}
#endif /* HVX_EXP_H */
+232
View File
@@ -0,0 +1,232 @@
#ifndef HVX_FA_KERNELS_H
#define HVX_FA_KERNELS_H
#include <assert.h>
#include <math.h>
#include "hvx-utils.h"
// Little inner kernels for HVX
#if __HVX_ARCH__ < 79
#define HVX_OP_ADD_F32(a, b) Q6_Vsf_equals_Vqf32(Q6_Vqf32_vadd_VsfVsf(a, b))
#define HVX_OP_SUB_F32(a, b) Q6_Vsf_equals_Vqf32(Q6_Vqf32_vsub_VsfVsf(a, b))
#define HVX_OP_MUL_F32(a, b) Q6_Vsf_equals_Vqf32(Q6_Vqf32_vmpy_VsfVsf(a, b))
#else
#define HVX_OP_ADD_F32(a, b) Q6_Vsf_vadd_VsfVsf(a, b)
#define HVX_OP_SUB_F32(a, b) Q6_Vsf_vsub_VsfVsf(a, b)
#define HVX_OP_MUL_F32(a, b) Q6_Vsf_vmpy_VsfVsf(a, b)
#endif
// This is a bit of a hack because the compiler is struggling to properly inline
// the default hvx_vec_f32_to_f16 with output into the local array.
static __attribute__((unused)) __attribute__((noinline)) void hvx_vec_f32_to_f16_a(void *ptr, HVX_Vector v0, HVX_Vector v1)
{
*(HVX_Vector *) ptr = hvx_vec_f32_to_f16(v0, v1);
}
// Dot product of two F16 vectors, accumulating to float
static inline void hvx_dot_f16_f16_aa(float * restrict r, const void * restrict x, const void * restrict y, unsigned int n, float s) {
const HVX_Vector * restrict vx = (const HVX_Vector * restrict) x; // fp16
const HVX_Vector * restrict vy = (const HVX_Vector * restrict) y; // fp16
uint32_t nvec = n / VLEN_FP16; // num full fp16 hvx vectors
uint32_t nloe = n % VLEN_FP16; // leftover elements
HVX_VectorPair rsum_p = Q6_W_vcombine_VV(Q6_V_vsplat_R(0), Q6_V_vsplat_R(0));
uint32_t i = 0;
#pragma unroll(4)
for (i = 0; i < nvec; i++) {
rsum_p = hvx_vec_mpyacc_f32_f16(rsum_p, vx[i], vy[i]);
}
if (nloe) {
HVX_VectorPred bmask = Q6_Q_vsetq_R(nloe * 2);
HVX_Vector y_hf = Q6_V_vand_QV(bmask, vy[i]);
HVX_Vector x_hf = Q6_V_vand_QV(bmask, vx[i]);
rsum_p = hvx_vec_mpyacc_f32_f16(rsum_p, x_hf, y_hf);
}
HVX_Vector rsum = HVX_OP_ADD_F32(Q6_V_lo_W(rsum_p), Q6_V_hi_W(rsum_p));
rsum = HVX_OP_MUL_F32(hvx_vec_splat_f32(s), hvx_vec_reduce_sum_f32(rsum));
hvx_vec_store_u(r, 4, rsum);
}
static inline HVX_Vector hvx_dot_f16_f16_aa_rx4(const void * restrict y,
const uint8_t * restrict x,
const size_t stride_x,
const size_t nvec,
const size_t nloe) {
const HVX_Vector * restrict vx0 = (const HVX_Vector * restrict) x; // fp16
const HVX_Vector * restrict vx1 = (const HVX_Vector * restrict) (x + stride_x); // fp16
const HVX_Vector * restrict vx2 = (const HVX_Vector * restrict) (x + stride_x * 2); // fp16
const HVX_Vector * restrict vx3 = (const HVX_Vector * restrict) (x + stride_x * 3); // fp16
const HVX_Vector * restrict vy = (const HVX_Vector * restrict) y; // fp16
HVX_VectorPair rsum0_p = Q6_W_vcombine_VV(Q6_V_vsplat_R(0), Q6_V_vsplat_R(0));
HVX_VectorPair rsum1_p = Q6_W_vcombine_VV(Q6_V_vsplat_R(0), Q6_V_vsplat_R(0));
HVX_VectorPair rsum2_p = Q6_W_vcombine_VV(Q6_V_vsplat_R(0), Q6_V_vsplat_R(0));
HVX_VectorPair rsum3_p = Q6_W_vcombine_VV(Q6_V_vsplat_R(0), Q6_V_vsplat_R(0));
uint32_t i = 0;
for (i = 0; i < nvec; i++) {
HVX_Vector y_hf = vy[i];
HVX_Vector x0_hf = vx0[i];
HVX_Vector x1_hf = vx1[i];
HVX_Vector x2_hf = vx2[i];
HVX_Vector x3_hf = vx3[i];
rsum0_p = hvx_vec_mpyacc_f32_f16(rsum0_p, x0_hf, y_hf);
rsum1_p = hvx_vec_mpyacc_f32_f16(rsum1_p, x1_hf, y_hf);
rsum2_p = hvx_vec_mpyacc_f32_f16(rsum2_p, x2_hf, y_hf);
rsum3_p = hvx_vec_mpyacc_f32_f16(rsum3_p, x3_hf, y_hf);
}
if (nloe) {
// Load x (fp16) and zero-out unused elements
HVX_VectorPred bmask = Q6_Q_vsetq_R(nloe * 2);
HVX_Vector y_hf = Q6_V_vand_QV(bmask, vy[i]);
HVX_Vector x0_hf = Q6_V_vand_QV(bmask, vx0[i]);
HVX_Vector x1_hf = Q6_V_vand_QV(bmask, vx1[i]);
HVX_Vector x2_hf = Q6_V_vand_QV(bmask, vx2[i]);
HVX_Vector x3_hf = Q6_V_vand_QV(bmask, vx3[i]);
rsum0_p = hvx_vec_mpyacc_f32_f16(rsum0_p, x0_hf, y_hf);
rsum1_p = hvx_vec_mpyacc_f32_f16(rsum1_p, x1_hf, y_hf);
rsum2_p = hvx_vec_mpyacc_f32_f16(rsum2_p, x2_hf, y_hf);
rsum3_p = hvx_vec_mpyacc_f32_f16(rsum3_p, x3_hf, y_hf);
}
HVX_Vector rsum0 = HVX_OP_ADD_F32(Q6_V_lo_W(rsum0_p), Q6_V_hi_W(rsum0_p));
HVX_Vector rsum1 = HVX_OP_ADD_F32(Q6_V_lo_W(rsum1_p), Q6_V_hi_W(rsum1_p));
HVX_Vector rsum2 = HVX_OP_ADD_F32(Q6_V_lo_W(rsum2_p), Q6_V_hi_W(rsum2_p));
HVX_Vector rsum3 = HVX_OP_ADD_F32(Q6_V_lo_W(rsum3_p), Q6_V_hi_W(rsum3_p));
HVX_Vector_x4 rsum0123 = { .v = { rsum0, rsum1, rsum2, rsum3 } };
return hvx_vec_reduce_sum_f32x4(rsum0123);
}
static inline HVX_Vector hvx_dot_f16_f16_aa_rx32(const void * restrict y,
const uint8_t * restrict x,
const size_t stride_x,
const size_t n,
float s) {
const size_t nvec = n / VLEN_FP16; // num full fp16 hvx vectors
const size_t nloe = n % VLEN_FP16; // leftover elements
HVX_Vector sums = Q6_V_vzero();
const size_t stride_x_4 = stride_x * 4;
for (uint32_t j = 0; j < VLEN_FP32; j += 4) {
HVX_Vector sums_x4 = hvx_dot_f16_f16_aa_rx4(y, x, stride_x, nvec, nloe);
HVX_VectorPred pred = Q6_Q_vsetq_R(j * SIZEOF_FP32);
sums = Q6_V_vmux_QVV(pred, sums, sums_x4);
x += stride_x_4;
}
return HVX_OP_MUL_F32(hvx_vec_splat_f32(s), sums);
}
// MAD: y (F32) += x (F16) * s (F16)
static inline void hvx_mad_f32_f16_aa(float * restrict y, const void * restrict x, const __fp16 * restrict s, uint32_t n) {
const HVX_Vector * restrict vx0 = (const HVX_Vector *) x;
HVX_VectorPair * restrict vy_p = (HVX_VectorPair *) y;
HVX_Vector * restrict vy = (HVX_Vector *) y;
uint32_t nvec = n / VLEN_FP16; // num full fp16 hvx vectors
uint32_t nloe = n % VLEN_FP16; // leftover elements
HVX_Vector S0 = hvx_vec_splat_f16(*s);
uint32_t i = 0;
#pragma unroll(2)
for (i = 0; i < nvec; ++i) {
vy_p[i] = hvx_vec_mpyacc_f32_f16(vy_p[i], Q6_Vh_vshuff_Vh(vx0[i]), S0);
}
if (nloe) {
HVX_VectorPair xy_p = vy_p[i];
xy_p = hvx_vec_mpyacc_f32_f16(xy_p, Q6_Vh_vshuff_Vh(vx0[i]), S0);
HVX_Vector xy = Q6_V_lo_W(xy_p);
i = 2 * i; // index for vy
if (nloe >= VLEN_FP32) {
vy[i] = xy;
nloe -= VLEN_FP32; ++i; xy = Q6_V_hi_W(xy_p);
}
if (nloe) {
hvx_vec_store_a(&vy[i], nloe * 4, xy);
}
}
}
// MAD: y (F32) += x0 (F16) * s0 (F16) + x1 (F16) * s1 (F16)
static inline void hvx_mad_f32_f16_aa_rx2(float * restrict y, const void * restrict x0, const void * restrict x1,
const __fp16 * restrict s0, const __fp16 * restrict s1, uint32_t n) {
const HVX_Vector * restrict vx0 = (const HVX_Vector *) x0;
const HVX_Vector * restrict vx1 = (const HVX_Vector *) x1;
HVX_VectorPair * restrict vy_p = (HVX_VectorPair *) y;
HVX_Vector * restrict vy = (HVX_Vector *) y;
uint32_t nvec = n / VLEN_FP16; // num full fp16 hvx vectors
uint32_t nloe = n % VLEN_FP16; // leftover elements
HVX_Vector S0 = hvx_vec_splat_f16(*s0);
HVX_Vector S1 = hvx_vec_splat_f16(*s1);
uint32_t i = 0;
#pragma unroll(2)
for (i = 0; i < nvec; ++i) {
vy_p[i] = hvx_vec_mpyacc_f32_f16(vy_p[i], Q6_Vh_vshuff_Vh(vx0[i]), S0);
vy_p[i] = hvx_vec_mpyacc_f32_f16(vy_p[i], Q6_Vh_vshuff_Vh(vx1[i]), S1);
}
if (nloe) {
HVX_VectorPair xy_p = vy_p[i];
xy_p = hvx_vec_mpyacc_f32_f16(xy_p, Q6_Vh_vshuff_Vh(vx0[i]), S0);
xy_p = hvx_vec_mpyacc_f32_f16(xy_p, Q6_Vh_vshuff_Vh(vx1[i]), S1);
HVX_Vector xy = Q6_V_lo_W(xy_p);
i = 2 * i; // index for vy
if (nloe >= VLEN_FP32) {
vy[i] = xy;
nloe -= VLEN_FP32; ++i; xy = Q6_V_hi_W(xy_p);
}
if (nloe) {
hvx_vec_store_a(&vy[i], nloe * 4, xy);
}
}
}
static inline void hvx_scale_vec_f32_aa(uint8_t * restrict dst, const uint8_t * restrict src, const uint32_t n, HVX_Vector vs) {
assert((size_t) dst % 128 == 0);
assert((size_t) src % 128 == 0);
const HVX_Vector * restrict vsrc = (const HVX_Vector * restrict) src;
HVX_Vector * restrict vdst = (HVX_Vector * restrict) dst;
const uint32_t nvec = n / VLEN_FP32;
const uint32_t nloe = n % VLEN_FP32;
uint32_t i = 0;
#pragma unroll(4)
for (; i < nvec; ++i) {
vdst[i] = HVX_OP_MUL_F32(vsrc[i], vs);
}
if (nloe) {
hvx_vec_store_a(&vdst[i], nloe * sizeof(float), HVX_OP_MUL_F32(vsrc[i], vs));
}
}
#endif /* HVX_FA_KERNELS_H */
+512 -25
View File
@@ -256,7 +256,7 @@ static inline void quantize_f16_f16_flat_kernel(
// Dot kernels that consume flat (non-tiled) activations
static void flat_vec_dot_q4_0_32x1(const uint32_t n, float * restrict s, const void * restrict vx, const void * restrict vy, uint32_t valid_rows) {
static void flat_vec_dot_q4_0_32x1(const uint32_t n, float * restrict s, const void * restrict vx, const void * restrict vy, uint32_t valid_rows, const float * restrict sz) {
const uint8_t * restrict tile_ptr = vx;
const uint8_t * restrict y_q = vy;
@@ -312,10 +312,14 @@ static void flat_vec_dot_q4_0_32x1(const uint32_t n, float * restrict s, const v
v_sum_float = hvx_vec_add_f32_f32(v_sum_float, v_sum_scaled);
}
hvx_vec_store_u(s, valid_rows * sizeof(float), v_sum_float);
if (sz) {
hvx_vec_store_u(s, valid_rows * sizeof(float), hvx_vec_add_f32_f32(v_sum_float, hvx_vmemu(sz)));
} else {
hvx_vec_store_u(s, valid_rows * sizeof(float), v_sum_float);
}
}
static void flat_vec_dot_q4_0_32x2(const uint32_t n, float * restrict s0, float * restrict s1, const void * restrict vx, const void * restrict vy0, const void * restrict vy1, uint32_t valid_rows) {
static void flat_vec_dot_q4_0_32x2(const uint32_t n, float * restrict s0, float * restrict s1, const void * restrict vx, const void * restrict vy0, const void * restrict vy1, uint32_t valid_rows, const float * restrict sz0, const float * restrict sz1) {
const uint8_t * restrict tile_ptr = vx;
const uint8_t * restrict y0_q = vy0;
const uint8_t * restrict y1_q = vy1;
@@ -397,11 +401,19 @@ static void flat_vec_dot_q4_0_32x2(const uint32_t n, float * restrict s0, float
v_sum_float_c1 = hvx_vec_add_f32_f32(v_sum_float_c1, v_sum_scaled_c1);
}
hvx_vec_store_u(s0, valid_rows * sizeof(float), v_sum_float_c0);
hvx_vec_store_u(s1, valid_rows * sizeof(float), v_sum_float_c1);
if (sz0) {
hvx_vec_store_u(s0, valid_rows * sizeof(float), hvx_vec_add_f32_f32(v_sum_float_c0, hvx_vmemu(sz0)));
} else {
hvx_vec_store_u(s0, valid_rows * sizeof(float), v_sum_float_c0);
}
if (sz1) {
hvx_vec_store_u(s1, valid_rows * sizeof(float), hvx_vec_add_f32_f32(v_sum_float_c1, hvx_vmemu(sz1)));
} else {
hvx_vec_store_u(s1, valid_rows * sizeof(float), v_sum_float_c1);
}
}
static void flat_vec_dot_q4_1_32x1(const uint32_t n, float * restrict s, const void * restrict vx, const void * restrict vy, uint32_t valid_rows) {
static void flat_vec_dot_q4_1_32x1(const uint32_t n, float * restrict s, const void * restrict vx, const void * restrict vy, uint32_t valid_rows, const float * restrict sz) {
const uint8_t * restrict tile_ptr = vx;
const uint8_t * restrict y_q = vy;
@@ -464,10 +476,14 @@ static void flat_vec_dot_q4_1_32x1(const uint32_t n, float * restrict s, const v
v_sum_float = hvx_vec_add_f32_f32(v_sum_float, v_sum_scaled);
}
hvx_vec_store_u(s, valid_rows * sizeof(float), v_sum_float);
if (sz) {
hvx_vec_store_u(s, valid_rows * sizeof(float), hvx_vec_add_f32_f32(v_sum_float, hvx_vmemu(sz)));
} else {
hvx_vec_store_u(s, valid_rows * sizeof(float), v_sum_float);
}
}
static void flat_vec_dot_q4_1_32x2(const uint32_t n, float * restrict s0, float * restrict s1, const void * restrict vx, const void * restrict vy0, const void * restrict vy1, uint32_t valid_rows) {
static void flat_vec_dot_q4_1_32x2(const uint32_t n, float * restrict s0, float * restrict s1, const void * restrict vx, const void * restrict vy0, const void * restrict vy1, uint32_t valid_rows, const float * restrict sz0, const float * restrict sz1) {
const uint8_t * restrict tile_ptr = vx;
const uint8_t * restrict y0_q = vy0;
const uint8_t * restrict y1_q = vy1;
@@ -561,11 +577,19 @@ static void flat_vec_dot_q4_1_32x2(const uint32_t n, float * restrict s0, float
v_sum_float_c1 = hvx_vec_add_f32_f32(v_sum_float_c1, v_sum_scaled_c1);
}
hvx_vec_store_u(s0, valid_rows * sizeof(float), v_sum_float_c0);
hvx_vec_store_u(s1, valid_rows * sizeof(float), v_sum_float_c1);
if (sz0) {
hvx_vec_store_u(s0, valid_rows * sizeof(float), hvx_vec_add_f32_f32(v_sum_float_c0, hvx_vmemu(sz0)));
} else {
hvx_vec_store_u(s0, valid_rows * sizeof(float), v_sum_float_c0);
}
if (sz1) {
hvx_vec_store_u(s1, valid_rows * sizeof(float), hvx_vec_add_f32_f32(v_sum_float_c1, hvx_vmemu(sz1)));
} else {
hvx_vec_store_u(s1, valid_rows * sizeof(float), v_sum_float_c1);
}
}
static void flat_vec_dot_q8_0_32x1(const uint32_t n, float * restrict s, const void * restrict vx, const void * restrict vy, uint32_t valid_rows) {
static void flat_vec_dot_q8_0_32x1(const uint32_t n, float * restrict s, const void * restrict vx, const void * restrict vy, uint32_t valid_rows, const float * restrict sz) {
const uint8_t * restrict tile_ptr = vx;
const uint8_t * restrict y_q = vy;
@@ -620,10 +644,14 @@ static void flat_vec_dot_q8_0_32x1(const uint32_t n, float * restrict s, const v
v_sum_float = hvx_vec_add_f32_f32(v_sum_float, v_sum_scaled);
}
hvx_vec_store_u(s, valid_rows * sizeof(float), v_sum_float);
if (sz) {
hvx_vec_store_u(s, valid_rows * sizeof(float), hvx_vec_add_f32_f32(v_sum_float, hvx_vmemu(sz)));
} else {
hvx_vec_store_u(s, valid_rows * sizeof(float), v_sum_float);
}
}
static void flat_vec_dot_q8_0_32x2(const uint32_t n, float * restrict s0, float * restrict s1, const void * restrict vx, const void * restrict vy0, const void * restrict vy1, uint32_t valid_rows) {
static void flat_vec_dot_q8_0_32x2(const uint32_t n, float * restrict s0, float * restrict s1, const void * restrict vx, const void * restrict vy0, const void * restrict vy1, uint32_t valid_rows, const float * restrict sz0, const float * restrict sz1) {
const uint8_t * restrict tile_ptr = vx;
const uint8_t * restrict y0_q = vy0;
const uint8_t * restrict y1_q = vy1;
@@ -704,11 +732,19 @@ static void flat_vec_dot_q8_0_32x2(const uint32_t n, float * restrict s0, float
v_sum_float_c1 = hvx_vec_add_f32_f32(v_sum_float_c1, v_sum_scaled_c1);
}
hvx_vec_store_u(s0, valid_rows * sizeof(float), v_sum_float_c0);
hvx_vec_store_u(s1, valid_rows * sizeof(float), v_sum_float_c1);
if (sz0) {
hvx_vec_store_u(s0, valid_rows * sizeof(float), hvx_vec_add_f32_f32(v_sum_float_c0, hvx_vmemu(sz0)));
} else {
hvx_vec_store_u(s0, valid_rows * sizeof(float), v_sum_float_c0);
}
if (sz1) {
hvx_vec_store_u(s1, valid_rows * sizeof(float), hvx_vec_add_f32_f32(v_sum_float_c1, hvx_vmemu(sz1)));
} else {
hvx_vec_store_u(s1, valid_rows * sizeof(float), v_sum_float_c1);
}
}
static void flat_vec_dot_iq4nl_32x1(const uint32_t n, float * restrict s, const void * restrict vx, const void * restrict vy, uint32_t valid_rows) {
static void flat_vec_dot_iq4nl_32x1(const uint32_t n, float * restrict s, const void * restrict vx, const void * restrict vy, uint32_t valid_rows, const float * restrict sz) {
const uint8_t * restrict tile_ptr = vx;
const uint8_t * restrict y_q = vy;
@@ -765,10 +801,14 @@ static void flat_vec_dot_iq4nl_32x1(const uint32_t n, float * restrict s, const
v_sum_float = hvx_vec_add_f32_f32(v_sum_float, v_sum_scaled);
}
hvx_vec_store_u(s, valid_rows * sizeof(float), v_sum_float);
if (sz) {
hvx_vec_store_u(s, valid_rows * sizeof(float), hvx_vec_add_f32_f32(v_sum_float, hvx_vmemu(sz)));
} else {
hvx_vec_store_u(s, valid_rows * sizeof(float), v_sum_float);
}
}
static void flat_vec_dot_iq4nl_32x2(const uint32_t n, float * restrict s0, float * restrict s1, const void * restrict vx, const void * restrict vy0, const void * restrict vy1, uint32_t valid_rows) {
static void flat_vec_dot_iq4nl_32x2(const uint32_t n, float * restrict s0, float * restrict s1, const void * restrict vx, const void * restrict vy0, const void * restrict vy1, uint32_t valid_rows, const float * restrict sz0, const float * restrict sz1) {
const uint8_t * restrict tile_ptr = vx;
const uint8_t * restrict y0_q = vy0;
const uint8_t * restrict y1_q = vy1;
@@ -851,11 +891,19 @@ static void flat_vec_dot_iq4nl_32x2(const uint32_t n, float * restrict s0, float
v_sum_float_c1 = hvx_vec_add_f32_f32(v_sum_float_c1, v_sum_scaled_c1);
}
hvx_vec_store_u(s0, valid_rows * sizeof(float), v_sum_float_c0);
hvx_vec_store_u(s1, valid_rows * sizeof(float), v_sum_float_c1);
if (sz0) {
hvx_vec_store_u(s0, valid_rows * sizeof(float), hvx_vec_add_f32_f32(v_sum_float_c0, hvx_vmemu(sz0)));
} else {
hvx_vec_store_u(s0, valid_rows * sizeof(float), v_sum_float_c0);
}
if (sz1) {
hvx_vec_store_u(s1, valid_rows * sizeof(float), hvx_vec_add_f32_f32(v_sum_float_c1, hvx_vmemu(sz1)));
} else {
hvx_vec_store_u(s1, valid_rows * sizeof(float), v_sum_float_c1);
}
}
static void flat_vec_dot_mxfp4_32x1(const uint32_t n, float * restrict s, const void * restrict vx, const void * restrict vy, uint32_t valid_rows) {
static void flat_vec_dot_mxfp4_32x1(const uint32_t n, float * restrict s, const void * restrict vx, const void * restrict vy, uint32_t valid_rows, const float * restrict sz) {
const uint8_t * restrict tile_ptr = vx;
const uint8_t * restrict y_q = vy;
@@ -921,10 +969,14 @@ static void flat_vec_dot_mxfp4_32x1(const uint32_t n, float * restrict s, const
v_sum_float = hvx_vec_mul_f32_f32(v_sum_float, hvx_vec_splat_f32(0.5f));
hvx_vec_store_u(s, valid_rows * sizeof(float), v_sum_float);
if (sz) {
hvx_vec_store_u(s, valid_rows * sizeof(float), hvx_vec_add_f32_f32(v_sum_float, hvx_vmemu(sz)));
} else {
hvx_vec_store_u(s, valid_rows * sizeof(float), v_sum_float);
}
}
static void flat_vec_dot_mxfp4_32x2(const uint32_t n, float * restrict s0, float * restrict s1, const void * restrict vx, const void * restrict vy0, const void * restrict vy1, uint32_t valid_rows) {
static void flat_vec_dot_mxfp4_32x2(const uint32_t n, float * restrict s0, float * restrict s1, const void * restrict vx, const void * restrict vy0, const void * restrict vy1, uint32_t valid_rows, const float * restrict sz0, const float * restrict sz1) {
const uint8_t * restrict tile_ptr = vx;
const uint8_t * restrict y0_q = vy0;
const uint8_t * restrict y1_q = vy1;
@@ -1019,6 +1071,441 @@ static void flat_vec_dot_mxfp4_32x2(const uint32_t n, float * restrict s0, float
v_sum_float_c0 = hvx_vec_mul_f32_f32(v_sum_float_c0, hvx_vec_splat_f32(0.5f));
v_sum_float_c1 = hvx_vec_mul_f32_f32(v_sum_float_c1, hvx_vec_splat_f32(0.5f));
hvx_vec_store_u(s0, valid_rows * sizeof(float), v_sum_float_c0);
hvx_vec_store_u(s1, valid_rows * sizeof(float), v_sum_float_c1);
if (sz0) {
hvx_vec_store_u(s0, valid_rows * sizeof(float), hvx_vec_add_f32_f32(v_sum_float_c0, hvx_vmemu(sz0)));
} else {
hvx_vec_store_u(s0, valid_rows * sizeof(float), v_sum_float_c0);
}
if (sz1) {
hvx_vec_store_u(s1, valid_rows * sizeof(float), hvx_vec_add_f32_f32(v_sum_float_c1, hvx_vmemu(sz1)));
} else {
hvx_vec_store_u(s1, valid_rows * sizeof(float), v_sum_float_c1);
}
}
#if __HVX_ARCH__ < 79
#define HVX_OP_ADD_F32(a, b) Q6_Vsf_equals_Vqf32(Q6_Vqf32_vadd_VsfVsf(a, b))
#define HVX_OP_MUL_F32(a, b) Q6_Vsf_equals_Vqf32(Q6_Vqf32_vmpy_VsfVsf(a, b))
#else
#define HVX_OP_ADD_F32(a, b) Q6_Vsf_vadd_VsfVsf(a, b)
#define HVX_OP_MUL_F32(a, b) Q6_Vsf_vmpy_VsfVsf(a, b)
#endif
static inline void vec_dot_f32_f32_aa_1x1(const uint32_t n, float * restrict s, const void * restrict vx, const void * restrict vy) {
const HVX_Vector * restrict x = (const HVX_Vector *) vx;
const HVX_Vector * restrict y = (const HVX_Vector *) vy;
uint32_t nvec = n / VLEN_FP32; // num full fp32 hvx vectors
uint32_t nloe = n % VLEN_FP32; // leftover elements
HVX_Vector rsum = Q6_V_vzero();
uint32_t i = 0;
#pragma unroll(4)
for (i = 0; i < nvec; i++) {
HVX_Vector prod = HVX_OP_MUL_F32(x[i], y[i]);
rsum = HVX_OP_ADD_F32(rsum, prod);
}
if (nloe) {
HVX_VectorPred bmask = Q6_Q_vsetq_R(nloe * 4);
HVX_Vector x_sf = Q6_V_vand_QV(bmask, x[i]);
HVX_Vector y_sf = Q6_V_vand_QV(bmask, y[i]);
HVX_Vector prod = HVX_OP_MUL_F32(x_sf, y_sf);
rsum = HVX_OP_ADD_F32(rsum, prod);
}
*s = hvx_vec_get_f32(hvx_vec_reduce_sum_f32(rsum));
}
static inline void vec_dot_f32_f32_aa_2x1(const uint32_t n, float * restrict s0,
const void * restrict vx0, const void * restrict vx1,
const void * restrict vy0) {
const HVX_Vector * restrict x0 = (const HVX_Vector *) vx0;
const HVX_Vector * restrict x1 = (const HVX_Vector *) vx1;
const HVX_Vector * restrict y = (const HVX_Vector *) vy0;
uint32_t nvec = n / VLEN_FP32;
uint32_t nloe = n % VLEN_FP32;
HVX_Vector rsum0 = Q6_V_vzero();
HVX_Vector rsum1 = Q6_V_vzero();
uint32_t i = 0;
#pragma unroll(2)
for (i = 0; i < nvec; i++) {
HVX_Vector y_sf = y[i];
HVX_Vector prod0 = HVX_OP_MUL_F32(x0[i], y_sf);
HVX_Vector prod1 = HVX_OP_MUL_F32(x1[i], y_sf);
rsum0 = HVX_OP_ADD_F32(rsum0, prod0);
rsum1 = HVX_OP_ADD_F32(rsum1, prod1);
}
if (nloe) {
HVX_VectorPred bmask = Q6_Q_vsetq_R(nloe * 4);
HVX_Vector y_sf = Q6_V_vand_QV(bmask, y[i]);
HVX_Vector x0_sf = Q6_V_vand_QV(bmask, x0[i]);
HVX_Vector x1_sf = Q6_V_vand_QV(bmask, x1[i]);
HVX_Vector prod0 = HVX_OP_MUL_F32(x0_sf, y_sf);
HVX_Vector prod1 = HVX_OP_MUL_F32(x1_sf, y_sf);
rsum0 = HVX_OP_ADD_F32(rsum0, prod0);
rsum1 = HVX_OP_ADD_F32(rsum1, prod1);
}
HVX_Vector rsum = hvx_vec_reduce_sum_f32x2(rsum0, rsum1);
hvx_vec_store_u(s0, 8, rsum);
}
static inline void vec_dot_f32_f32_aa_2x2(const uint32_t n, float * restrict s0, float * restrict s1,
const void * restrict vx0, const void * restrict vx1,
const void * restrict vy0, const void * restrict vy1) {
const HVX_Vector * restrict x0 = (const HVX_Vector *) vx0;
const HVX_Vector * restrict x1 = (const HVX_Vector *) vx1;
const HVX_Vector * restrict y0 = (const HVX_Vector *) vy0;
const HVX_Vector * restrict y1 = (const HVX_Vector *) vy1;
uint32_t nvec = n / VLEN_FP32;
uint32_t nloe = n % VLEN_FP32;
HVX_Vector r0_c0_sum = Q6_V_vzero();
HVX_Vector r0_c1_sum = Q6_V_vzero();
HVX_Vector r1_c0_sum = Q6_V_vzero();
HVX_Vector r1_c1_sum = Q6_V_vzero();
uint32_t i = 0;
#pragma unroll(2)
for (i = 0; i < nvec; i++) {
HVX_Vector r0_sf = x0[i];
HVX_Vector r1_sf = x1[i];
HVX_Vector c0_sf = y0[i];
HVX_Vector c1_sf = y1[i];
r0_c0_sum = HVX_OP_ADD_F32(r0_c0_sum, HVX_OP_MUL_F32(r0_sf, c0_sf));
r0_c1_sum = HVX_OP_ADD_F32(r0_c1_sum, HVX_OP_MUL_F32(r0_sf, c1_sf));
r1_c0_sum = HVX_OP_ADD_F32(r1_c0_sum, HVX_OP_MUL_F32(r1_sf, c0_sf));
r1_c1_sum = HVX_OP_ADD_F32(r1_c1_sum, HVX_OP_MUL_F32(r1_sf, c1_sf));
}
if (nloe) {
HVX_VectorPred bmask = Q6_Q_vsetq_R(nloe * 4);
HVX_Vector r0_sf = Q6_V_vand_QV(bmask, x0[i]);
HVX_Vector r1_sf = Q6_V_vand_QV(bmask, x1[i]);
HVX_Vector c0_sf = Q6_V_vand_QV(bmask, y0[i]);
HVX_Vector c1_sf = Q6_V_vand_QV(bmask, y1[i]);
r0_c0_sum = HVX_OP_ADD_F32(r0_c0_sum, HVX_OP_MUL_F32(r0_sf, c0_sf));
r0_c1_sum = HVX_OP_ADD_F32(r0_c1_sum, HVX_OP_MUL_F32(r0_sf, c1_sf));
r1_c0_sum = HVX_OP_ADD_F32(r1_c0_sum, HVX_OP_MUL_F32(r1_sf, c0_sf));
r1_c1_sum = HVX_OP_ADD_F32(r1_c1_sum, HVX_OP_MUL_F32(r1_sf, c1_sf));
}
// Reduce and store results
HVX_Vector r0_r1_c0_sum = hvx_vec_reduce_sum_f32x2(r0_c0_sum, r1_c0_sum);
HVX_Vector r0_r1_c1_sum = hvx_vec_reduce_sum_f32x2(r0_c1_sum, r1_c1_sum);
hvx_vec_store_u(s0, 8, r0_r1_c0_sum);
hvx_vec_store_u(s1, 8, r0_r1_c1_sum);
}
static inline void vec_dot_f32_f32_uu_1x1(const uint32_t n, float * restrict s, const void * restrict x, const void * restrict y) {
const HVX_UVector * restrict vx = (const HVX_UVector * restrict) x;
const HVX_UVector * restrict vy = (const HVX_UVector * restrict) y;
uint32_t nvec = n / VLEN_FP32; // num full fp32 hvx vectors
uint32_t nloe = n % VLEN_FP32; // leftover elements
HVX_Vector rsum = Q6_V_vzero();
uint32_t i = 0;
#pragma unroll(2)
for (i = 0; i < nvec; i++) {
HVX_Vector x_sf = vx[i];
HVX_Vector y_sf = vy[i];
rsum = HVX_OP_ADD_F32(rsum, HVX_OP_MUL_F32(x_sf, y_sf));
}
if (nloe) {
HVX_Vector x_sf = vx[i];
HVX_Vector y_sf = vy[i];
HVX_VectorPred bmask = Q6_Q_vsetq_R(nloe * 4);
x_sf = Q6_V_vand_QV(bmask, x_sf);
y_sf = Q6_V_vand_QV(bmask, y_sf);
rsum = HVX_OP_ADD_F32(rsum, HVX_OP_MUL_F32(x_sf, y_sf));
}
rsum = hvx_vec_reduce_sum_f32(rsum);
hvx_vec_store_u(&s[0], 4, rsum);
}
#undef HVX_OP_ADD_F32
#undef HVX_OP_MUL_F32
static inline void vec_dot_f16_f16_aa_1x1(const uint32_t n, float * restrict s, const void * restrict vx, const void * restrict vy) {
const HVX_Vector * restrict x = (const HVX_Vector *) vx;
const HVX_Vector * restrict y = (const HVX_Vector *) vy;
uint32_t nvec = n / VLEN_FP16; // num full fp16 hvx vectors
uint32_t nloe = n % VLEN_FP16; // leftover elements
HVX_VectorPair rsum_p = Q6_W_vzero();
uint32_t i = 0;
#pragma unroll(4)
for (i = 0; i < nvec; i++) {
rsum_p = hvx_vec_mpyacc_f32_f16(rsum_p, x[i], y[i]);
}
if (nloe) {
HVX_VectorPred bmask = Q6_Q_vsetq_R(nloe * 2);
HVX_Vector x_hf = Q6_V_vand_QV(bmask, x[i]);
HVX_Vector y_hf = Q6_V_vand_QV(bmask, y[i]);
rsum_p = hvx_vec_mpyacc_f32_f16(rsum_p, x_hf, y_hf);
}
HVX_Vector rsum = Q6_Vsf_equals_Vqf32(Q6_Vqf32_vadd_VsfVsf(Q6_V_lo_W(rsum_p), Q6_V_hi_W(rsum_p)));
hvx_vec_store_u(s, 4, hvx_vec_reduce_sum_f32(rsum));
}
static inline void vec_dot_f16_f16_aa_2x1(const uint32_t n, float * restrict s0,
const void * restrict vx0, const void * restrict vx1,
const void * restrict vy0) {
const HVX_Vector * restrict x0 = (const HVX_Vector *) vx0;
const HVX_Vector * restrict x1 = (const HVX_Vector *) vx1;
const HVX_Vector * restrict y = (const HVX_Vector *) vy0;
uint32_t nvec = n / VLEN_FP16;
uint32_t nloe = n % VLEN_FP16;
HVX_VectorPair rsum0_p = Q6_W_vzero();
HVX_VectorPair rsum1_p = Q6_W_vzero();
uint32_t i = 0;
#pragma unroll(2)
for (i = 0; i < nvec; i++) {
HVX_Vector y_hf = y[i];
rsum0_p = hvx_vec_mpyacc_f32_f16(rsum0_p, x0[i], y_hf);
rsum1_p = hvx_vec_mpyacc_f32_f16(rsum1_p, x1[i], y_hf);
}
if (nloe) {
HVX_VectorPred bmask = Q6_Q_vsetq_R(nloe * 2);
HVX_Vector y_hf = Q6_V_vand_QV(bmask, y[i]);
HVX_Vector x0_hf = Q6_V_vand_QV(bmask, x0[i]);
HVX_Vector x1_hf = Q6_V_vand_QV(bmask, x1[i]);
rsum0_p = hvx_vec_mpyacc_f32_f16(rsum0_p, x0_hf, y_hf);
rsum1_p = hvx_vec_mpyacc_f32_f16(rsum1_p, x1_hf, y_hf);
}
HVX_Vector rsum0 = Q6_Vsf_equals_Vqf32(Q6_Vqf32_vadd_VsfVsf(Q6_V_lo_W(rsum0_p), Q6_V_hi_W(rsum0_p)));
HVX_Vector rsum1 = Q6_Vsf_equals_Vqf32(Q6_Vqf32_vadd_VsfVsf(Q6_V_lo_W(rsum1_p), Q6_V_hi_W(rsum1_p)));
HVX_Vector rsum = hvx_vec_reduce_sum_f32x2(rsum0, rsum1);
hvx_vec_store_u(s0, 8, rsum);
}
static inline void vec_dot_f16_f16_aa_2x2(const uint32_t n, float * restrict s0, float * restrict s1,
const void * restrict vx0, const void * restrict vx1,
const void * restrict vy0, const void * restrict vy1) {
const HVX_Vector * restrict x0 = (const HVX_Vector *) vx0;
const HVX_Vector * restrict x1 = (const HVX_Vector *) vx1;
const HVX_Vector * restrict y0 = (const HVX_Vector *) vy0;
const HVX_Vector * restrict y1 = (const HVX_Vector *) vy1;
uint32_t nvec = n / VLEN_FP16;
uint32_t nloe = n % VLEN_FP16;
// Row sums (sf) - 4 accumulators for 2x2 tile
HVX_VectorPair r0_c0_sum_p = Q6_W_vzero();
HVX_VectorPair r0_c1_sum_p = Q6_W_vzero();
HVX_VectorPair r1_c0_sum_p = Q6_W_vzero();
HVX_VectorPair r1_c1_sum_p = Q6_W_vzero();
uint32_t i = 0;
#pragma unroll(2)
for (i = 0; i < nvec; i++) {
HVX_Vector r0_hf = x0[i];
HVX_Vector r1_hf = x1[i];
HVX_Vector c0_hf = y0[i];
HVX_Vector c1_hf = y1[i];
// Compute 4 dot products: r0xc0, r0xc1, r1xc0, r1xc1
r0_c0_sum_p = hvx_vec_mpyacc_f32_f16(r0_c0_sum_p, r0_hf, c0_hf);
r0_c1_sum_p = hvx_vec_mpyacc_f32_f16(r0_c1_sum_p, r0_hf, c1_hf);
r1_c0_sum_p = hvx_vec_mpyacc_f32_f16(r1_c0_sum_p, r1_hf, c0_hf);
r1_c1_sum_p = hvx_vec_mpyacc_f32_f16(r1_c1_sum_p, r1_hf, c1_hf);
}
if (nloe) {
HVX_VectorPred bmask = Q6_Q_vsetq_R(nloe * 2);
HVX_Vector r0_hf = Q6_V_vand_QV(bmask, x0[i]);
HVX_Vector r1_hf = Q6_V_vand_QV(bmask, x1[i]);
HVX_Vector c0_hf = Q6_V_vand_QV(bmask, y0[i]);
HVX_Vector c1_hf = Q6_V_vand_QV(bmask, y1[i]);
r0_c0_sum_p = hvx_vec_mpyacc_f32_f16(r0_c0_sum_p, r0_hf, c0_hf);
r0_c1_sum_p = hvx_vec_mpyacc_f32_f16(r0_c1_sum_p, r0_hf, c1_hf);
r1_c0_sum_p = hvx_vec_mpyacc_f32_f16(r1_c0_sum_p, r1_hf, c0_hf);
r1_c1_sum_p = hvx_vec_mpyacc_f32_f16(r1_c1_sum_p, r1_hf, c1_hf);
}
HVX_Vector r0_c0_sum = Q6_Vsf_equals_Vqf32(Q6_Vqf32_vadd_VsfVsf(Q6_V_lo_W(r0_c0_sum_p), Q6_V_hi_W(r0_c0_sum_p)));
HVX_Vector r0_c1_sum = Q6_Vsf_equals_Vqf32(Q6_Vqf32_vadd_VsfVsf(Q6_V_lo_W(r0_c1_sum_p), Q6_V_hi_W(r0_c1_sum_p)));
HVX_Vector r1_c0_sum = Q6_Vsf_equals_Vqf32(Q6_Vqf32_vadd_VsfVsf(Q6_V_lo_W(r1_c0_sum_p), Q6_V_hi_W(r1_c0_sum_p)));
HVX_Vector r1_c1_sum = Q6_Vsf_equals_Vqf32(Q6_Vqf32_vadd_VsfVsf(Q6_V_lo_W(r1_c1_sum_p), Q6_V_hi_W(r1_c1_sum_p)));
// Reduce and store results
HVX_Vector r0_r1_c0_sum = hvx_vec_reduce_sum_f32x2(r0_c0_sum, r1_c0_sum);
HVX_Vector r0_r1_c1_sum = hvx_vec_reduce_sum_f32x2(r0_c1_sum, r1_c1_sum);
hvx_vec_store_u(&s0[0], 8, r0_r1_c0_sum); // row0,col0 row1,col0
hvx_vec_store_u(&s1[0], 8, r0_r1_c1_sum); // row0,col1 row1,col1
}
static inline void vec_dot_f16_f16_uu_1x1(const uint32_t n, float * restrict s, const void * restrict vx, const void * restrict vy) {
const HVX_UVector * restrict x = (const HVX_UVector *) vx;
const HVX_UVector * restrict y = (const HVX_UVector *) vy;
uint32_t nvec = n / VLEN_FP16; // num full fp16 hvx vectors
uint32_t nloe = n % VLEN_FP16; // leftover elements
HVX_Vector rsum = Q6_V_vzero();
uint32_t i = 0;
#pragma unroll(4)
for (i = 0; i < nvec; i++) {
HVX_VectorPair xy_qf = Q6_Wqf32_vmpy_VhfVhf(x[i], y[i]);
rsum = Q6_Vqf32_vadd_Vqf32Vqf32(rsum, Q6_Vqf32_vadd_Vqf32Vqf32(Q6_V_lo_W(xy_qf), Q6_V_hi_W(xy_qf)));
}
if (nloe) {
HVX_VectorPred bmask = Q6_Q_vsetq_R(nloe * 2);
HVX_Vector x_hf = Q6_V_vand_QV(bmask, x[i]);
HVX_Vector y_hf = Q6_V_vand_QV(bmask, y[i]);
HVX_VectorPair xy_qf = Q6_Wqf32_vmpy_VhfVhf(x_hf, y_hf);
rsum = Q6_Vqf32_vadd_Vqf32Vqf32(rsum, Q6_Vqf32_vadd_Vqf32Vqf32(Q6_V_lo_W(xy_qf), Q6_V_hi_W(xy_qf)));
}
rsum = hvx_vec_reduce_sum_f32(Q6_Vsf_equals_Vqf32(rsum));
hvx_vec_store_u(&s[0], 4, rsum);
}
static inline void vec_dot_f16_f32_uu_1x1(const uint32_t n, float * restrict s, const void * restrict x, const void * restrict y) {
const HVX_UVector * restrict vx = (const HVX_UVector * restrict) x;
const HVX_UVector * restrict vy = (const HVX_UVector * restrict) y;
uint32_t nvec = n / VLEN_FP16; // num full fp16 hvx vectors
uint32_t nloe = n % VLEN_FP16; // leftover elements
const HVX_Vector zero = Q6_V_vzero();
HVX_Vector rsum = Q6_V_vzero();
uint32_t i = 0;
#pragma unroll(2)
for (i = 0; i < nvec; i++) {
// Load y (fp32) and convert into fp16
HVX_Vector y0_qf = Q6_Vqf32_vsub_VsfVsf(vy[i*2+0], zero); // 32 elements
HVX_Vector y1_qf = Q6_Vqf32_vsub_VsfVsf(vy[i*2+1], zero); // 32 elements
HVX_Vector y_hf = Q6_Vh_vdeal_Vh(Q6_Vhf_equals_Wqf32(Q6_W_vcombine_VV(y1_qf, y0_qf)));
// Load x (fp16)
HVX_Vector x_hf = vx[i];
HVX_VectorPair xy_qf = Q6_Wqf32_vmpy_VhfVhf(x_hf, y_hf);
rsum = Q6_Vqf32_vadd_Vqf32Vqf32(rsum, Q6_Vqf32_vadd_Vqf32Vqf32(Q6_V_lo_W(xy_qf), Q6_V_hi_W(xy_qf)));
}
if (nloe) {
// Load y (fp32) and convert into fp16
HVX_Vector y0_qf = Q6_Vqf32_vsub_VsfVsf(vy[i*2+0], zero); // 32 elements
HVX_Vector y1_qf = Q6_Vqf32_vsub_VsfVsf(vy[i*2+1], zero); // 32 elements
HVX_Vector y_hf = Q6_Vh_vdeal_Vh(Q6_Vhf_equals_Wqf32(Q6_W_vcombine_VV(y1_qf, y0_qf)));
// Load x (fp16)
HVX_Vector x_hf = vx[i];
// Zero-out unused elements
// Note that we need to clear both x and y because they may contain NANs
HVX_VectorPred bmask = Q6_Q_vsetq_R(nloe * 2);
x_hf = Q6_V_vand_QV(bmask, x_hf);
y_hf = Q6_V_vand_QV(bmask, y_hf);
HVX_VectorPair xy_qf = Q6_Wqf32_vmpy_VhfVhf(x_hf, y_hf);
rsum = Q6_Vqf32_vadd_Vqf32Vqf32(rsum, Q6_Vqf32_vadd_Vqf32Vqf32(Q6_V_lo_W(xy_qf), Q6_V_hi_W(xy_qf)));
}
// Convert into fp32 and reduce
rsum = hvx_vec_reduce_sum_f32(Q6_Vsf_equals_Vqf32(rsum));
hvx_vec_store_u(&s[0], 4, rsum);
}
static inline void hvx_tensor_add_f32_grid(
const struct htp_tensor * restrict dst,
const struct htp_tensor * restrict src2,
uint32_t start_row,
uint32_t end_row,
uint32_t start_col,
uint32_t end_col,
const struct fastdiv_values * div_ne11_12,
const struct fastdiv_values * div_ne11
) {
if (start_row >= end_row || start_col >= end_col) return;
const uint32_t nb1 = dst->nb[1]; // row stride in bytes
const uint32_t ne11 = dst->ne[1];
const uint32_t ne12 = dst->ne[2];
const uint32_t ne11_12 = ne11 * ne12;
const bool is_broadcast1 = (src2->ne[1] == 1);
const bool is_broadcast2 = (src2->ne[2] == 1);
const bool is_broadcast3 = (src2->ne[3] == 1);
for (uint32_t r = start_row; r < end_row; r++) {
float * dst_row = (float *) ((uint8_t *) dst->data + r * nb1);
uint32_t i13 = fastdiv(r, div_ne11_12);
uint32_t i12 = fastdiv(r - i13 * ne11_12, div_ne11);
uint32_t i11 = r - i13 * ne11_12 - i12 * ne11;
uint32_t i23 = is_broadcast3 ? 0 : i13;
uint32_t i22 = is_broadcast2 ? 0 : i12;
uint32_t i21 = is_broadcast1 ? 0 : i11;
const float * src2_row = (const float *) ((const uint8_t *) src2->data +
i21 * src2->nb[1] + i22 * src2->nb[2] + i23 * src2->nb[3]);
float * dst_ptr = &dst_row[start_col];
const float * src2_ptr = &src2_row[start_col];
int remaining = end_col - start_col;
while (remaining >= 32) {
HVX_Vector v_out = hvx_vmemu(dst_ptr);
HVX_Vector v_z = hvx_vmemu(src2_ptr);
hvx_vmemu(dst_ptr) = hvx_vec_add_f32_f32(v_out, v_z);
dst_ptr += 32;
src2_ptr += 32;
remaining -= 32;
}
if (remaining > 0) {
HVX_Vector v_out = hvx_vmemu(dst_ptr);
HVX_Vector v_z = hvx_vmemu(src2_ptr);
hvx_vec_store_u(dst_ptr, remaining * sizeof(float), hvx_vec_add_f32_f32(v_out, v_z));
}
}
}
@@ -378,7 +378,7 @@ static inline HVX_VectorPair accum_q8_0_32x2(
return Q6_W_vcombine_VV(v_sum1, v_sum0);
}
static void tiled_vec_dot_q4_0_32x1(const uint32_t n, float * restrict s, const void * restrict vx, const void * restrict vy, uint32_t valid_rows) {
static void tiled_vec_dot_q4_0_32x1(const uint32_t n, float * restrict s, const void * restrict vx, const void * restrict vy, uint32_t valid_rows, const float * restrict sz) {
const uint8_t * restrict tile_ptr = vx;
const uint8_t * restrict y_q = vy;
@@ -401,10 +401,14 @@ static void tiled_vec_dot_q4_0_32x1(const uint32_t n, float * restrict s, const
v_sum_float = hvx_vec_add_f32_f32(v_sum_float, v_sum_scaled);
}
hvx_vec_store_u(s, valid_rows * sizeof(float), v_sum_float);
if (sz) {
hvx_vec_store_u(s, valid_rows * sizeof(float), hvx_vec_add_f32_f32(v_sum_float, hvx_vmemu(sz)));
} else {
hvx_vec_store_u(s, valid_rows * sizeof(float), v_sum_float);
}
}
static void tiled_vec_dot_q4_0_32x2(const uint32_t n, float * restrict s0, float * restrict s1, const void * restrict vx, const void * restrict vy0, const void * restrict vy1, uint32_t valid_rows) {
static void tiled_vec_dot_q4_0_32x2(const uint32_t n, float * restrict s0, float * restrict s1, const void * restrict vx, const void * restrict vy0, const void * restrict vy1, uint32_t valid_rows, const float * restrict sz0, const float * restrict sz1) {
const uint8_t * restrict tile_ptr = vx;
const uint8_t * restrict y0_q = vy0;
const uint8_t * restrict y1_q = vy1;
@@ -484,11 +488,19 @@ static void tiled_vec_dot_q4_0_32x2(const uint32_t n, float * restrict s0, float
v_sum_float_c1 = hvx_vec_add_f32_f32(v_sum_float_c1, v_sum_scaled_c1);
}
hvx_vec_store_u(s0, valid_rows * sizeof(float), v_sum_float_c0);
hvx_vec_store_u(s1, valid_rows * sizeof(float), v_sum_float_c1);
if (sz0) {
hvx_vec_store_u(s0, valid_rows * sizeof(float), hvx_vec_add_f32_f32(v_sum_float_c0, hvx_vmemu(sz0)));
} else {
hvx_vec_store_u(s0, valid_rows * sizeof(float), v_sum_float_c0);
}
if (sz1) {
hvx_vec_store_u(s1, valid_rows * sizeof(float), hvx_vec_add_f32_f32(v_sum_float_c1, hvx_vmemu(sz1)));
} else {
hvx_vec_store_u(s1, valid_rows * sizeof(float), v_sum_float_c1);
}
}
static void tiled_vec_dot_q4_1_32x1(const uint32_t n, float * restrict s, const void * restrict vx, const void * restrict vy, uint32_t valid_rows) {
static void tiled_vec_dot_q4_1_32x1(const uint32_t n, float * restrict s, const void * restrict vx, const void * restrict vy, uint32_t valid_rows, const float * restrict sz) {
const uint8_t * restrict tile_ptr = vx;
const uint8_t * restrict y_q = vy;
@@ -519,10 +531,14 @@ static void tiled_vec_dot_q4_1_32x1(const uint32_t n, float * restrict s, const
v_sum_float = hvx_vec_add_f32_f32(v_sum_float, v_sum_scaled);
}
hvx_vec_store_u(s, valid_rows * sizeof(float), v_sum_float);
if (sz) {
hvx_vec_store_u(s, valid_rows * sizeof(float), hvx_vec_add_f32_f32(v_sum_float, hvx_vmemu(sz)));
} else {
hvx_vec_store_u(s, valid_rows * sizeof(float), v_sum_float);
}
}
static void tiled_vec_dot_q4_1_32x2(const uint32_t n, float * restrict s0, float * restrict s1, const void * restrict vx, const void * restrict vy0, const void * restrict vy1, uint32_t valid_rows) {
static void tiled_vec_dot_q4_1_32x2(const uint32_t n, float * restrict s0, float * restrict s1, const void * restrict vx, const void * restrict vy0, const void * restrict vy1, uint32_t valid_rows, const float * restrict sz0, const float * restrict sz1) {
const uint8_t * restrict tile_ptr = vx;
const uint8_t * restrict y0_q = vy0;
const uint8_t * restrict y1_q = vy1;
@@ -637,11 +653,19 @@ static void tiled_vec_dot_q4_1_32x2(const uint32_t n, float * restrict s0, float
v_sum_float_c1 = hvx_vec_add_f32_f32(v_sum_float_c1, v_sum_scaled_c1);
}
hvx_vec_store_u(s0, valid_rows * sizeof(float), v_sum_float_c0);
hvx_vec_store_u(s1, valid_rows * sizeof(float), v_sum_float_c1);
if (sz0) {
hvx_vec_store_u(s0, valid_rows * sizeof(float), hvx_vec_add_f32_f32(v_sum_float_c0, hvx_vmemu(sz0)));
} else {
hvx_vec_store_u(s0, valid_rows * sizeof(float), v_sum_float_c0);
}
if (sz1) {
hvx_vec_store_u(s1, valid_rows * sizeof(float), hvx_vec_add_f32_f32(v_sum_float_c1, hvx_vmemu(sz1)));
} else {
hvx_vec_store_u(s1, valid_rows * sizeof(float), v_sum_float_c1);
}
}
static void tiled_vec_dot_q8_0_32x1(const uint32_t n, float * restrict s, const void * restrict vx, const void * restrict vy, uint32_t valid_rows) {
static void tiled_vec_dot_q8_0_32x1(const uint32_t n, float * restrict s, const void * restrict vx, const void * restrict vy, uint32_t valid_rows, const float * restrict sz) {
const uint8_t * restrict tile_ptr = vx;
const uint8_t * restrict y_q = vy;
@@ -663,10 +687,14 @@ static void tiled_vec_dot_q8_0_32x1(const uint32_t n, float * restrict s, const
v_sum_float = hvx_vec_add_f32_f32(v_sum_float, v_sum_scaled);
}
hvx_vec_store_u(s, valid_rows * sizeof(float), v_sum_float);
if (sz) {
hvx_vec_store_u(s, valid_rows * sizeof(float), hvx_vec_add_f32_f32(v_sum_float, hvx_vmemu(sz)));
} else {
hvx_vec_store_u(s, valid_rows * sizeof(float), v_sum_float);
}
}
static void tiled_vec_dot_q8_0_32x2(const uint32_t n, float * restrict s0, float * restrict s1, const void * restrict vx, const void * restrict vy0, const void * restrict vy1, uint32_t valid_rows) {
static void tiled_vec_dot_q8_0_32x2(const uint32_t n, float * restrict s0, float * restrict s1, const void * restrict vx, const void * restrict vy0, const void * restrict vy1, uint32_t valid_rows, const float * restrict sz0, const float * restrict sz1) {
const uint8_t * restrict tile_ptr = vx;
const uint8_t * restrict y0_q = vy0;
const uint8_t * restrict y1_q = vy1;
@@ -745,11 +773,19 @@ static void tiled_vec_dot_q8_0_32x2(const uint32_t n, float * restrict s0, float
v_sum_float_c1 = hvx_vec_add_f32_f32(v_sum_float_c1, v_sum_scaled_c1);
}
hvx_vec_store_u(s0, valid_rows * sizeof(float), v_sum_float_c0);
hvx_vec_store_u(s1, valid_rows * sizeof(float), v_sum_float_c1);
if (sz0) {
hvx_vec_store_u(s0, valid_rows * sizeof(float), hvx_vec_add_f32_f32(v_sum_float_c0, hvx_vmemu(sz0)));
} else {
hvx_vec_store_u(s0, valid_rows * sizeof(float), v_sum_float_c0);
}
if (sz1) {
hvx_vec_store_u(s1, valid_rows * sizeof(float), hvx_vec_add_f32_f32(v_sum_float_c1, hvx_vmemu(sz1)));
} else {
hvx_vec_store_u(s1, valid_rows * sizeof(float), v_sum_float_c1);
}
}
static void tiled_vec_dot_iq4nl_32x1(const uint32_t n, float * restrict s, const void * restrict vx, const void * restrict vy, uint32_t valid_rows) {
static void tiled_vec_dot_iq4nl_32x1(const uint32_t n, float * restrict s, const void * restrict vx, const void * restrict vy, uint32_t valid_rows, const float * restrict sz) {
const uint8_t * restrict tile_ptr = vx;
const uint8_t * restrict y_q = vy;
@@ -773,10 +809,14 @@ static void tiled_vec_dot_iq4nl_32x1(const uint32_t n, float * restrict s, const
v_sum_float = hvx_vec_add_f32_f32(v_sum_float, v_sum_scaled);
}
hvx_vec_store_u(s, valid_rows * sizeof(float), v_sum_float);
if (sz) {
hvx_vec_store_u(s, valid_rows * sizeof(float), hvx_vec_add_f32_f32(v_sum_float, hvx_vmemu(sz)));
} else {
hvx_vec_store_u(s, valid_rows * sizeof(float), v_sum_float);
}
}
static void tiled_vec_dot_iq4nl_32x2(const uint32_t n, float * restrict s0, float * restrict s1, const void * restrict vx, const void * restrict vy0, const void * restrict vy1, uint32_t valid_rows) {
static void tiled_vec_dot_iq4nl_32x2(const uint32_t n, float * restrict s0, float * restrict s1, const void * restrict vx, const void * restrict vy0, const void * restrict vy1, uint32_t valid_rows, const float * restrict sz0, const float * restrict sz1) {
const uint8_t * restrict tile_ptr = vx;
const uint8_t * restrict y0_q = vy0;
const uint8_t * restrict y1_q = vy1;
@@ -857,11 +897,19 @@ static void tiled_vec_dot_iq4nl_32x2(const uint32_t n, float * restrict s0, floa
v_sum_float_c1 = hvx_vec_add_f32_f32(v_sum_float_c1, v_sum_scaled_c1);
}
hvx_vec_store_u(s0, valid_rows * sizeof(float), v_sum_float_c0);
hvx_vec_store_u(s1, valid_rows * sizeof(float), v_sum_float_c1);
if (sz0) {
hvx_vec_store_u(s0, valid_rows * sizeof(float), hvx_vec_add_f32_f32(v_sum_float_c0, hvx_vmemu(sz0)));
} else {
hvx_vec_store_u(s0, valid_rows * sizeof(float), v_sum_float_c0);
}
if (sz1) {
hvx_vec_store_u(s1, valid_rows * sizeof(float), hvx_vec_add_f32_f32(v_sum_float_c1, hvx_vmemu(sz1)));
} else {
hvx_vec_store_u(s1, valid_rows * sizeof(float), v_sum_float_c1);
}
}
static void tiled_vec_dot_mxfp4_32x1(const uint32_t n, float * restrict s, const void * restrict vx, const void * restrict vy, uint32_t valid_rows) {
static void tiled_vec_dot_mxfp4_32x1(const uint32_t n, float * restrict s, const void * restrict vx, const void * restrict vy, uint32_t valid_rows, const float * restrict sz) {
const uint8_t * restrict tile_ptr = vx;
const uint8_t * restrict y_q = vy;
@@ -896,10 +944,14 @@ static void tiled_vec_dot_mxfp4_32x1(const uint32_t n, float * restrict s, const
v_sum_float = hvx_vec_mul_f32_f32(v_sum_float, hvx_vec_splat_f32(0.5f));
hvx_vec_store_u(s, valid_rows * sizeof(float), v_sum_float);
if (sz) {
hvx_vec_store_u(s, valid_rows * sizeof(float), hvx_vec_add_f32_f32(v_sum_float, hvx_vmemu(sz)));
} else {
hvx_vec_store_u(s, valid_rows * sizeof(float), v_sum_float);
}
}
static void tiled_vec_dot_mxfp4_32x2(const uint32_t n, float * restrict s0, float * restrict s1, const void * restrict vx, const void * restrict vy0, const void * restrict vy1, uint32_t valid_rows) {
static void tiled_vec_dot_mxfp4_32x2(const uint32_t n, float * restrict s0, float * restrict s1, const void * restrict vx, const void * restrict vy0, const void * restrict vy1, uint32_t valid_rows, const float * restrict sz0, const float * restrict sz1) {
const uint8_t * restrict tile_ptr = vx;
const uint8_t * restrict y0_q = vy0;
const uint8_t * restrict y1_q = vy1;
@@ -1013,8 +1065,16 @@ static void tiled_vec_dot_mxfp4_32x2(const uint32_t n, float * restrict s0, floa
v_sum_float_c0 = hvx_vec_mul_f32_f32(v_sum_float_c0, hvx_vec_splat_f32(0.5f));
v_sum_float_c1 = hvx_vec_mul_f32_f32(v_sum_float_c1, hvx_vec_splat_f32(0.5f));
hvx_vec_store_u(s0, valid_rows * sizeof(float), v_sum_float_c0);
hvx_vec_store_u(s1, valid_rows * sizeof(float), v_sum_float_c1);
if (sz0) {
hvx_vec_store_u(s0, valid_rows * sizeof(float), hvx_vec_add_f32_f32(v_sum_float_c0, hvx_vmemu(sz0)));
} else {
hvx_vec_store_u(s0, valid_rows * sizeof(float), v_sum_float_c0);
}
if (sz1) {
hvx_vec_store_u(s1, valid_rows * sizeof(float), hvx_vec_add_f32_f32(v_sum_float_c1, hvx_vmemu(sz1)));
} else {
hvx_vec_store_u(s1, valid_rows * sizeof(float), v_sum_float_c1);
}
}
static inline void quantize_f32_q8_0_tiled_kernel(
+39
View File
@@ -3,6 +3,7 @@
#include "hvx-base.h"
#include "hvx-inverse.h"
#include "hvx-exp.h"
#define FAST_SIGMOID_LOG2F (0x3fb8aa3b) // 1.442695022
#define FAST_SIGMOID_C1 (0x3d009076) // 0.03138777
@@ -139,4 +140,42 @@ static inline void hvx_tanh_f32_aa(uint8_t * restrict dst, const uint8_t * restr
hvx_tanh_loop_body(HVX_Vector, HVX_Vector, hvx_vec_store_a);
}
static inline HVX_Vector hvx_vec_fast_sigmoid_f16(HVX_Vector x_v) {
const HVX_Vector v_one = hvx_vec_splat_f16(1.0f);
const HVX_Vector v_neg_log2e = hvx_vec_splat_f16(-EXP_LOG2E_F);
const HVX_Vector em_mask = Q6_Vh_vsplat_R(0x7FFF);
// Compute absolute value of x_v
HVX_Vector abs_x = Q6_V_vand_VV(x_v, em_mask);
// Compute u = -abs_x * log2(e) <= 0.
HVX_Vector u = hvx_vec_mul_f16_f16(abs_x, v_neg_log2e);
// Clamp input to prevent underflow in exp2
const HVX_Vector v_clamp_min = hvx_vec_splat_f16(-24.0f);
u = Q6_Vhf_vmax_VhfVhf(v_clamp_min, u);
HVX_Vector exp_val = hvx_vec_exp2_f16(u);
HVX_Vector denom = hvx_vec_add_f16_f16(v_one, exp_val);
HVX_Vector sig_abs = hvx_vec_inverse_f16(denom);
// check if x_v < 0 (using integer comparison on absolute value)
HVX_VectorPred is_neg = Q6_Q_vcmp_gt_VhVh(abs_x, x_v);
// If x_v < 0, return 1.0f - sig_abs
HVX_Vector sig_neg = Q6_Vhf_equals_Vqf16(Q6_Vqf16_vsub_VhfVhf(v_one, sig_abs));
return Q6_V_vmux_QVV(is_neg, sig_neg, sig_abs);
}
static inline HVX_Vector hvx_vec_tanh_f16(HVX_Vector x) {
// tanh(x) = 2 * sigmoid(2x) - 1
const HVX_Vector v_two = hvx_vec_splat_f16(2.0f);
HVX_Vector x2 = hvx_vec_mul_f16_f16(x, v_two);
HVX_Vector sig2x = hvx_vec_fast_sigmoid_f16(x2);
const HVX_Vector v_neg_one = hvx_vec_splat_f16(-1.0f);
return hvx_vec_add_f16_f16(hvx_vec_mul_f16_f16(sig2x, v_two), v_neg_one);
}
#endif /* HVX_SIGMOID_H */
+1
View File
@@ -575,6 +575,7 @@ static inline void profile_stop(uint32_t mode, struct profile_data * d) {
static int execute_op(struct htp_ops_context * octx) {
switch (octx->op) {
case HTP_OP_MUL_MAT:
case HTP_OP_MUL_MAT_ADD:
return op_matmul(octx);
case HTP_OP_MUL_MAT_ID:
File diff suppressed because it is too large Load Diff
+19 -32
View File
@@ -392,56 +392,49 @@ static inline size_t htp_mm_hvx_get_vtcm_sizes(
case HTP_MM_KERNEL_HVX_QUANT_ROW: {
size_t q_src1_row_size = (wtype == HTP_TYPE_Q4_1) ? htp_mm_q8_1_tiled_row_size(ne10) : htp_mm_q8_0_tiled_row_size(ne10);
vtcm_dst_size = dst_nrows > 0 ? htp_mm_round_up(dst_row_size, 128) : 0;
vtcm_src0_size = htp_mm_round_up(n_prefetch * src0_row_size_padded, 256);
vtcm_src1_size = htp_mm_round_up(q_src1_row_size * src1_nrows, 256);
// src0 spad is also used in dynamic quantizer to store padded src1 rows
size_t src1_row_size_padded = htp_mm_round_up(q_src1_row_size, QK_Q8_0_TILED * sizeof(float));
if (vtcm_src0_size < src1_row_size_padded) {
vtcm_src0_size = src1_row_size_padded;
}
vtcm_src0_size = vtcm_src0_size * n_threads;
vtcm_dst_size = vtcm_dst_size * n_threads;
if (is_repack) {
uint32_t aligned_tile_size = htp_mm_get_weight_aligned_tile_size(wtype);
uint32_t n_k_tiles = ne10 / 32;
uint32_t tile_row_size = n_k_tiles * aligned_tile_size;
size_t repacked_vtcm_size = htp_mm_round_up(n_prefetch * tile_row_size, 256);
if (repacked_vtcm_size < src1_row_size_padded) {
repacked_vtcm_size = src1_row_size_padded;
}
vtcm_src0_size = repacked_vtcm_size * n_threads;
}
size_t quant_scratch_size_per_thread = htp_mm_round_up(ne10 * sizeof(float), QK_Q8_0_TILED * sizeof(float));
size_t dst_size_per_thread = dst_nrows > 0 ? htp_mm_round_up(dst_row_size, 128) : 0;
if (dst_size_per_thread < quant_scratch_size_per_thread) {
dst_size_per_thread = quant_scratch_size_per_thread;
}
vtcm_dst_size = dst_size_per_thread * n_threads;
break;
}
case HTP_MM_KERNEL_HVX_QUANT_ROW_FLAT: {
size_t q_src1_row_size = (wtype == HTP_TYPE_Q4_1) ? htp_mm_q8_1_flat_row_size(ne10) : htp_mm_q8_0_flat_row_size(ne10);
vtcm_dst_size = dst_nrows > 0 ? htp_mm_round_up(dst_row_size, 128) : 0;
vtcm_src0_size = htp_mm_round_up(n_prefetch * src0_row_size_padded, 256);
vtcm_src1_size = htp_mm_round_up(q_src1_row_size * src1_nrows, 256);
size_t src1_row_size_padded = htp_mm_round_up(q_src1_row_size, 256);
if (vtcm_src0_size < src1_row_size_padded) {
vtcm_src0_size = src1_row_size_padded;
}
vtcm_src0_size = vtcm_src0_size * n_threads;
vtcm_dst_size = vtcm_dst_size * n_threads;
if (is_repack) {
uint32_t aligned_tile_size = htp_mm_get_weight_aligned_tile_size(wtype);
uint32_t n_k_tiles = ne10 / 32;
uint32_t tile_row_size = n_k_tiles * aligned_tile_size;
size_t repacked_vtcm_size = htp_mm_round_up(n_prefetch * tile_row_size, 256);
if (repacked_vtcm_size < src1_row_size_padded) {
repacked_vtcm_size = src1_row_size_padded;
}
vtcm_src0_size = repacked_vtcm_size * n_threads;
}
size_t quant_scratch_size_per_thread = htp_mm_round_up(ne10 * sizeof(float), QK_Q8_0_TILED * sizeof(float));
size_t dst_size_per_thread = dst_nrows > 0 ? htp_mm_round_up(dst_row_size, 128) : 0;
if (dst_size_per_thread < quant_scratch_size_per_thread) {
dst_size_per_thread = quant_scratch_size_per_thread;
}
vtcm_dst_size = dst_size_per_thread * n_threads;
break;
}
default:
@@ -463,7 +456,8 @@ static inline size_t htp_mm_hvx_id_get_vtcm_sizes(
size_t src0_row_size, // nb01
uint32_t n_prefetch,
size_t * vtcm_src0_size_out,
size_t * vtcm_src1_size_out
size_t * vtcm_src1_size_out,
size_t * vtcm_dst_size_out
) {
const bool is_repack = (wtype == HTP_TYPE_Q4_0 || wtype == HTP_TYPE_Q4_1 ||
wtype == HTP_TYPE_Q8_0 || wtype == HTP_TYPE_IQ4_NL ||
@@ -476,29 +470,22 @@ static inline size_t htp_mm_hvx_id_get_vtcm_sizes(
size_t src0_sz_per_thread = htp_mm_round_up(n_prefetch * src0_row_size_padded, 256);
size_t src1_sz = htp_mm_round_up(src1_row_size * src1_nrows, 256);
// src0 spad also holds temporary transposed src1 columns during dynamic quantization.
const size_t src1_row_size_padded = htp_mm_round_up(src1_row_size, QK_Q8_0_TILED * sizeof(float));
if (src0_sz_per_thread < src1_row_size_padded) {
src0_sz_per_thread = src1_row_size_padded;
}
if (is_repack) {
const uint32_t aligned_tile_size = htp_mm_get_weight_aligned_tile_size(wtype);
const uint32_t n_k_tiles = ne10 / 32;
const uint32_t tile_row_size = n_k_tiles * aligned_tile_size;
size_t repacked_vtcm_size = htp_mm_round_up(n_prefetch * tile_row_size, 256);
if (repacked_vtcm_size < src1_row_size_padded) {
repacked_vtcm_size = src1_row_size_padded;
}
src0_sz_per_thread = repacked_vtcm_size;
}
const size_t vtcm_src0_size = src0_sz_per_thread * n_threads;
const size_t vtcm_dst_size = htp_mm_round_up(ne10 * sizeof(float), QK_Q8_0_TILED * sizeof(float)) * n_threads;
*vtcm_src0_size_out = vtcm_src0_size;
*vtcm_src1_size_out = src1_sz;
*vtcm_dst_size_out = vtcm_dst_size;
return vtcm_src0_size + src1_sz;
return vtcm_src0_size + src1_sz + vtcm_dst_size;
}
#ifdef __cplusplus
+13
View File
@@ -31,6 +31,11 @@ if (GGML_OPENCL_EMBED_KERNELS)
target_include_directories(${TARGET_NAME} PRIVATE "${CMAKE_CURRENT_BINARY_DIR}/autogenerated")
endif ()
if (GGML_OPENCL_USE_ADRENO_BIN_KERNELS)
message(STATUS "OpenCL will use precompiled binary kernels for Adreno (improved performance on some platforms)")
add_compile_definitions(GGML_OPENCL_USE_ADRENO_BIN_KERNELS)
endif ()
function(ggml_opencl_add_kernel KNAME)
set(KERN_HDR ${CMAKE_CURRENT_BINARY_DIR}/autogenerated/${KNAME}.cl.h)
set(KERN_SRC ${CMAKE_CURRENT_SOURCE_DIR}/kernels/${KNAME}.cl)
@@ -78,6 +83,8 @@ set(GGML_OPENCL_KERNELS
mul_mv_f16_f32_l4
mul_mv_f16_f32
mul_mv_f32_f32
mul_mv_q1_0_f32
mul_mv_q1_0_f32_flat
mul_mv_q4_0_f32
mul_mv_q4_0_f32_v
mul_mv_q4_0_f32_8x_flat
@@ -128,6 +135,7 @@ set(GGML_OPENCL_KERNELS
moe_sort_by_expert
mul_mm_f32_f32_l4_lm
mul_mm_f16_f32_l4_lm
mul_mm_q1_0_f32_l4_lm
mul_mm_q4_0_f32_l4_lm
mul_mm_q4_1_f32_l4_lm
mul_mm_q5_0_f32_l4_lm
@@ -137,6 +145,8 @@ set(GGML_OPENCL_KERNELS
mul_mm_q4_k_f32_l4_lm
mul_mm_q5_k_f32_l4_lm
mul_mm_q6_k_f32_l4_lm
gemv_noshuffle_q1_0_f32
gemm_noshuffle_q1_0_f32
gemv_noshuffle_q4_0_f32
gemv_noshuffle_q4_0_f32_spec
gemm_noshuffle_q4_0_f32
@@ -192,7 +202,10 @@ set(GGML_OPENCL_KERNELS
mul_mm_f16_f32_kq_kqv
conv2d
conv2d_f16_f32
flash_attn_pre_f16
flash_attn_f32_f16
flash_attn_f32_q8_0
flash_attn_f32_q4_0
flash_attn_f16
flash_attn_f32
)
+91
View File
@@ -0,0 +1,91 @@
#pragma once
// Flash-attention per-(dk,dv) tile tuning for the Adreno OpenCL backend.
// Isolated from ggml-opencl.cpp so the tuning numbers are easy to find and
// edit; the FA dispatch and kernel-compile logic stay in the main file.
// This header is a file section — it is #included exactly once, at the point
// in ggml-opencl.cpp where the ggml logging macros are already in scope.
// Per-(dk, dv) FA config; shared by dispatch and supports_op.
struct ggml_opencl_fa_dim {
int dk; int dv; int bm; int bn; int n_split; int nkv_split_threshold;
};
// Split variant fires when n_kv >= threshold (threshold=0 -> always split).
// Default tuning covers Adreno 7xx/8xx mobile and X1-series laptop GPUs.
static const ggml_opencl_fa_dim g_fa_dims_adreno_default[] = {
{ 40, 40, 64, 32, 1, 0}, { 64, 64, 64, 32, 2, 64},
{ 80, 80, 64, 32, 2, 64}, { 96, 96, 64, 32, 2, 64},
{112, 112, 64, 32, 2, 64}, {128, 128, 64, 32, 2, 64},
{192, 128, 16, 16, 1, 0},
{192, 192, 16, 16, 1, 0},
{256, 256, 16, 16, 16, 0},
};
struct ggml_opencl_fa_dim_table {
const ggml_opencl_fa_dim * data;
size_t count;
const ggml_opencl_fa_dim * begin() const { return data; }
const ggml_opencl_fa_dim * end() const { return data + count; }
};
// Mutable copy of the active table; GGML_OPENCL_FA_TUNE patches entries here
// at backend init without touching the const source table.
static ggml_opencl_fa_dim g_fa_dims_runtime[
sizeof(g_fa_dims_adreno_default) / sizeof(g_fa_dims_adreno_default[0])];
static ggml_opencl_fa_dim_table g_opencl_fa_dims = {
g_fa_dims_adreno_default,
sizeof(g_fa_dims_adreno_default) / sizeof(g_fa_dims_adreno_default[0]),
};
// GGML_OPENCL_FA_TUNE=dk:dv:bm:bn:nsplit:thr[,…] — patches matching entries
// in the active table at backend init, before the first FA kernel compiles.
// Unmatched (dk,dv) pairs are warned and ignored.
static void ggml_opencl_fa_apply_env_overrides() {
const char * e = std::getenv("GGML_OPENCL_FA_TUNE");
if (!e || !e[0]) {
return;
}
std::string s = e;
size_t pos = 0;
while (pos < s.size()) {
size_t comma = s.find(',', pos);
std::string entry = s.substr(pos, comma == std::string::npos ? std::string::npos : comma - pos);
int dk, dv, bm, bn, nsplit, thr;
if (std::sscanf(entry.c_str(), "%d:%d:%d:%d:%d:%d", &dk, &dv, &bm, &bn, &nsplit, &thr) == 6) {
bool patched = false;
for (size_t i = 0; i < g_opencl_fa_dims.count; ++i) {
ggml_opencl_fa_dim & d = g_fa_dims_runtime[i];
if (d.dk == dk && d.dv == dv) {
d.bm = bm; d.bn = bn; d.n_split = nsplit; d.nkv_split_threshold = thr;
GGML_LOG_INFO("ggml_opencl: FA tune override DK=%d DV=%d -> bm=%d bn=%d n_split=%d thr=%d\n",
dk, dv, bm, bn, nsplit, thr);
patched = true;
break;
}
}
if (!patched) {
GGML_LOG_WARN("ggml_opencl: FA tune override DK=%d DV=%d ignored (no matching dim)\n", dk, dv);
}
} else {
GGML_LOG_WARN("ggml_opencl: FA tune override entry malformed: '%s'\n", entry.c_str());
}
if (comma == std::string::npos) break;
pos = comma + 1;
}
}
// Copy the default table into the mutable runtime buffer and apply any
// GGML_OPENCL_FA_TUNE overrides. A per-generation table can be added here
// once it has been tuned on hardware.
static void ggml_cl_init_fa_dims_table() {
const size_t count = sizeof(g_fa_dims_adreno_default) / sizeof(g_fa_dims_adreno_default[0]);
for (size_t i = 0; i < count; ++i) {
g_fa_dims_runtime[i] = g_fa_dims_adreno_default[i];
}
g_opencl_fa_dims = { g_fa_dims_runtime, count };
ggml_opencl_fa_apply_env_overrides();
}
File diff suppressed because it is too large Load Diff
+198
View File
@@ -27,6 +27,8 @@
#define QR5_1 2
#define QK8_0 32
#define QR8_0 1
#define QK1_0 128
#define QR1_0 1
#define QK_K 256
#define K_SCALE_SIZE (3 * QK_K / 64)
#define K_QUANTS_PER_ITERATION 2
@@ -38,6 +40,14 @@ typedef ushort uint16_t;
typedef int int32_t;
typedef uint uint32_t;
//------------------------------------------------------------------------------
// block_q1_0
//------------------------------------------------------------------------------
typedef struct {
half d; // delta
uchar qs[QK1_0/8]; // 1-bit signs (16 bytes)
} block_q1_0;
//------------------------------------------------------------------------------
// block_q4_0
//------------------------------------------------------------------------------
@@ -159,6 +169,42 @@ kernel void kernel_convert_f16_to_bf16(
}
}
//------------------------------------------------------------------------------
// kernel_convert_block_q1_0
// Convert block_q1_0 (AOS) to 2 separate arrays (SOA): quant bytes + scales.
// q1_0 bits are stored in natural order (bit j of byte i -> weight 8*i + j)
//------------------------------------------------------------------------------
kernel void kernel_convert_block_q1_0(
global block_q1_0 * src0,
global uchar * dst_q,
global half * dst_d
) {
global block_q1_0 * b = (global block_q1_0 *) src0 + get_global_id(0);
global uchar * q = (global uchar *) dst_q + (QK1_0/8)*get_global_id(0);
global half * d = (global half *) dst_d + get_global_id(0);
*d = b->d;
for (int i = 0; i < QK1_0/8; ++i) {
q[i] = b->qs[i];
}
}
kernel void kernel_restore_block_q1_0(
global uchar * src_q,
global half * src_d,
global block_q1_0 * dst
) {
global block_q1_0 * b = (global block_q1_0 *) dst + get_global_id(0);
global uchar * q = (global uchar *) src_q + (QK1_0/8)*get_global_id(0);
global half * d = (global half *) src_d + get_global_id(0);
b->d = *d;
for (int i = 0; i < QK1_0/8; ++i) {
b->qs[i] = q[i];
}
}
//------------------------------------------------------------------------------
// kernel_convert_block_q4_0
// Convert the block_q4_0 format to 2 separate arrays (AOS -> SOA).
@@ -1582,6 +1628,158 @@ kernel void kernel_restore_block_q8_0(
}
}
// View-aware AoS q8_0 -> f32 dequant (f32/f32 FA path).
kernel void kernel_dequant_q8_0_f32_view_aos(
global char * src,
ulong src_offset,
ulong src_nb1,
ulong src_nb2,
ulong src_nb3,
int nblk0,
int ne1,
int ne2,
int ne3,
global float * dst
) {
int blk_i0 = get_global_id(0);
int i1 = get_global_id(1);
int batch = get_global_id(2);
if (blk_i0 >= nblk0) return;
if (i1 >= ne1) return;
int i2 = batch % ne2;
int i3 = batch / ne2;
if (i3 >= ne3) return;
global char * block = src + src_offset + (ulong)i3*src_nb3 + (ulong)i2*src_nb2 + (ulong)i1*src_nb1 + (ulong)blk_i0 * (2 + QK8_0);
float d = vload_half(0, (global half *)block);
global char * qs = block + 2;
ulong dst_row_base = ((ulong)i3 * ne2 * ne1 + (ulong)i2 * ne1 + (ulong)i1) * nblk0;
global float * out = dst + (dst_row_base + blk_i0) * QK8_0;
for (int i = 0; i < QK8_0; ++i) {
out[i] = d * (float)qs[i];
}
}
// View-aware AoS q8_0 -> f16 dequant. Rows tight, batch strides may be gapped.
kernel void kernel_dequant_q8_0_f16_view_aos(
global char * src,
ulong src_offset,
ulong src_nb1,
ulong src_nb2,
ulong src_nb3,
int nblk0,
int ne1,
int ne2,
int ne3,
global half * dst
) {
int blk_i0 = get_global_id(0);
int i1 = get_global_id(1);
int batch = get_global_id(2);
if (blk_i0 >= nblk0) return;
if (i1 >= ne1) return;
int i2 = batch % ne2;
int i3 = batch / ne2;
if (i3 >= ne3) return;
global char * block = src + src_offset + (ulong)i3*src_nb3 + (ulong)i2*src_nb2 + (ulong)i1*src_nb1 + (ulong)blk_i0 * (2 + QK8_0);
float d = vload_half(0, (global half *)block);
global char * qs = block + 2;
ulong dst_row_base = ((ulong)i3 * ne2 * ne1 + (ulong)i2 * ne1 + (ulong)i1) * nblk0;
global half * out = dst + (dst_row_base + blk_i0) * QK8_0;
for (int i = 0; i < QK8_0; ++i) {
out[i] = (half)(d * (float)qs[i]);
}
}
// View-aware AoS q4_0 -> f32 dequant (mirrors the q8_0 view variant).
kernel void kernel_dequant_q4_0_f32_view_aos(
global char * src,
ulong src_offset,
ulong src_nb1,
ulong src_nb2,
ulong src_nb3,
int nblk0,
int ne1,
int ne2,
int ne3,
global float * dst
) {
int blk_i0 = get_global_id(0);
int i1 = get_global_id(1);
int batch = get_global_id(2);
if (blk_i0 >= nblk0) return;
if (i1 >= ne1) return;
int i2 = batch % ne2;
int i3 = batch / ne2;
if (i3 >= ne3) return;
global char * block = src + src_offset + (ulong)i3*src_nb3 + (ulong)i2*src_nb2 + (ulong)i1*src_nb1 + (ulong)blk_i0 * (2 + QK4_0/2);
float d = vload_half(0, (global half *)block);
global uchar * qs = (global uchar *)(block + 2);
ulong dst_row_base = ((ulong)i3 * ne2 * ne1 + (ulong)i2 * ne1 + (ulong)i1) * nblk0;
global float * out = dst + (dst_row_base + blk_i0) * QK4_0;
for (int i = 0; i < QK4_0/2; ++i) {
uchar byte = qs[i];
int q0 = (int)(byte & 0x0F) - 8;
int q1 = (int)(byte >> 4) - 8;
out[i] = d * (float)q0;
out[i + QK4_0/2] = d * (float)q1;
}
}
// View-aware AoS q4_0 -> f16 dequant (mirrors the q8_0 view variant).
kernel void kernel_dequant_q4_0_f16_view_aos(
global char * src,
ulong src_offset,
ulong src_nb1,
ulong src_nb2,
ulong src_nb3,
int nblk0,
int ne1,
int ne2,
int ne3,
global half * dst
) {
int blk_i0 = get_global_id(0);
int i1 = get_global_id(1);
int batch = get_global_id(2);
if (blk_i0 >= nblk0) return;
if (i1 >= ne1) return;
int i2 = batch % ne2;
int i3 = batch / ne2;
if (i3 >= ne3) return;
global char * block = src + src_offset + (ulong)i3*src_nb3 + (ulong)i2*src_nb2 + (ulong)i1*src_nb1 + (ulong)blk_i0 * (2 + QK4_0/2);
float d = vload_half(0, (global half *)block);
global uchar * qs = (global uchar *)(block + 2);
ulong dst_row_base = ((ulong)i3 * ne2 * ne1 + (ulong)i2 * ne1 + (ulong)i1) * nblk0;
global half * out = dst + (dst_row_base + blk_i0) * QK4_0;
for (int i = 0; i < QK4_0/2; ++i) {
uchar byte = qs[i];
int q0 = (int)(byte & 0x0F) - 8;
int q1 = (int)(byte >> 4) - 8;
out[i] = (half)(d * (float)q0);
out[i + QK4_0/2] = (half)(d * (float)q1);
}
}
kernel void kernel_restore_block_q8_0_trans(
global uchar * src_q,
global half * src_d,
+75 -40
View File
@@ -4,14 +4,26 @@
#define ACC_TYPE4 float4
#define DATA_TYPE half
#define DATA_TYPE4 half4
#define CONVERT_ACC4(x) convert_float4(x)
#define CONVERT_DATA4(x) convert_half4(x)
#define CONVERT_ACC4(x) ((float4)((float)(x).s0, (float)(x).s1, (float)(x).s2, (float)(x).s3))
#define CONVERT_DATA4(x) ((half4)((half)(x).s0, (half)(x).s1, (half)(x).s2, (half)(x).s3))
#define DK_VEC (DK/4)
#define DV_VEC (DV/4)
#define WG_SIZE (BLOCK_M)
#define Q1_WG_SIZE 64
// The kernels are built with -cl-finite-math-only. On some older Adreno GPUs,
// infinite operand can cause undefined behavior and miscompilation for exp.
// Therefore, a large negative value is used instead.
#define FA_M_INIT (-3.0e38f)
// Drop full unroll at DK>=192 Adreno compiler host-memory budget.
#if DK >= 192
#define FA_UNROLL
#else
#define FA_UNROLL _Pragma("unroll")
#endif
inline float get_alibi_slope(
const float max_bias, const uint h, const uint n_head_log2, const float m0, const float m1
) {
@@ -81,18 +93,18 @@ __kernel void flash_attn_f16(
if (my_query_row < n_q) {
const ulong q_row_offset = batch_idx * q_nb3 + head_idx * q_nb2 + my_query_row * q_nb1;
const global DATA_TYPE4* q_ptr = (const global DATA_TYPE4*)(q_base + q_row_offset);
#pragma unroll
FA_UNROLL
for (int i = 0; i < DK_VEC; ++i) {
q_priv[i] = CONVERT_ACC4(q_ptr[i]);
}
}
ACC_TYPE4 o_acc[DV_VEC];
#pragma unroll
FA_UNROLL
for (int i = 0; i < DV_VEC; ++i) {
o_acc[i] = (ACC_TYPE4)(0.0f);
}
ACC_TYPE m_i = -INFINITY;
ACC_TYPE m_i = FA_M_INIT;
ACC_TYPE l_i = 0.0f;
float slope = get_alibi_slope(max_bias, head_idx, n_head_log2, m0, m1);
@@ -125,49 +137,72 @@ __kernel void flash_attn_f16(
continue;
}
for (int j = 0; j < BLOCK_N; j += 2) {
for (int j = 0; j < BLOCK_N; j += 4) {
const int k_row0 = k_start + j;
const int k_row1 = k_start + j + 1;
const int k_row2 = k_start + j + 2;
const int k_row3 = k_start + j + 3;
ACC_TYPE4 dot_acc0 = (ACC_TYPE4)(0.0f);
ACC_TYPE4 dot_acc1 = (ACC_TYPE4)(0.0f);
#pragma unroll
ACC_TYPE4 dot_acc2 = (ACC_TYPE4)(0.0f);
ACC_TYPE4 dot_acc3 = (ACC_TYPE4)(0.0f);
FA_UNROLL
for (int k = 0; k < DK_VEC; k++) {
dot_acc0 = mad(q_priv[k], CONVERT_ACC4(l_k[j][k]), dot_acc0);
dot_acc1 = mad(q_priv[k], CONVERT_ACC4(l_k[j+1][k]), dot_acc1);
const ACC_TYPE4 qk = q_priv[k];
dot_acc0 = mad(qk, CONVERT_ACC4(l_k[j][k]), dot_acc0);
dot_acc1 = mad(qk, CONVERT_ACC4(l_k[j+1][k]), dot_acc1);
dot_acc2 = mad(qk, CONVERT_ACC4(l_k[j+2][k]), dot_acc2);
dot_acc3 = mad(qk, CONVERT_ACC4(l_k[j+3][k]), dot_acc3);
}
ACC_TYPE score0 = (dot_acc0.s0 + dot_acc0.s1 + dot_acc0.s2 + dot_acc0.s3) * scale;
ACC_TYPE score1 = (dot_acc1.s0 + dot_acc1.s1 + dot_acc1.s2 + dot_acc1.s3) * scale;
ACC_TYPE s0 = (dot_acc0.s0 + dot_acc0.s1 + dot_acc0.s2 + dot_acc0.s3) * scale;
ACC_TYPE s1 = (dot_acc1.s0 + dot_acc1.s1 + dot_acc1.s2 + dot_acc1.s3) * scale;
ACC_TYPE s2 = (dot_acc2.s0 + dot_acc2.s1 + dot_acc2.s2 + dot_acc2.s3) * scale;
ACC_TYPE s3 = (dot_acc3.s0 + dot_acc3.s1 + dot_acc3.s2 + dot_acc3.s3) * scale;
if (is_causal) {
if (k_row0 > (n_kv - n_q + my_query_row)) score0 = -INFINITY;
if (k_row1 > (n_kv - n_q + my_query_row)) score1 = -INFINITY;
const int causal_limit = n_kv - n_q + my_query_row;
if (k_row0 > causal_limit) s0 = FA_M_INIT;
if (k_row1 > causal_limit) s1 = FA_M_INIT;
if (k_row2 > causal_limit) s2 = FA_M_INIT;
if (k_row3 > causal_limit) s3 = FA_M_INIT;
}
if (k_row0 >= n_kv) score0 = -INFINITY;
if (k_row1 >= n_kv) score1 = -INFINITY;
if (k_row0 >= n_kv) s0 = FA_M_INIT;
if (k_row1 >= n_kv) s1 = FA_M_INIT;
if (k_row2 >= n_kv) s2 = FA_M_INIT;
if (k_row3 >= n_kv) s3 = FA_M_INIT;
if (mask_base != NULL) {
const global DATA_TYPE* mask_ptr = (const global DATA_TYPE*)(mask_base + my_query_row * mask_nb1);
if (k_row0 < n_kv) score0 += slope * (ACC_TYPE)mask_ptr[k_row0];
if (k_row1 < n_kv) score1 += slope * (ACC_TYPE)mask_ptr[k_row1];
if (k_row0 < n_kv) s0 += slope * (ACC_TYPE)mask_ptr[k_row0];
if (k_row1 < n_kv) s1 += slope * (ACC_TYPE)mask_ptr[k_row1];
if (k_row2 < n_kv) s2 += slope * (ACC_TYPE)mask_ptr[k_row2];
if (k_row3 < n_kv) s3 += slope * (ACC_TYPE)mask_ptr[k_row3];
}
if (logit_softcap > 0.0f) {
score0 = logit_softcap * tanh(score0 / logit_softcap);
score1 = logit_softcap * tanh(score1 / logit_softcap);
s0 = logit_softcap * tanh(s0 / logit_softcap);
s1 = logit_softcap * tanh(s1 / logit_softcap);
s2 = logit_softcap * tanh(s2 / logit_softcap);
s3 = logit_softcap * tanh(s3 / logit_softcap);
}
const ACC_TYPE m_new = max(m_i, max(score0, score1));
const ACC_TYPE p0 = exp(score0 - m_new);
const ACC_TYPE p1 = exp(score1 - m_new);
const ACC_TYPE scale_prev = exp(m_i - m_new);
const ACC_TYPE m_new = max(m_i, max(max(s0, s1), max(s2, s3)));
const ACC_TYPE scale_prev = native_exp(m_i - m_new);
const ACC_TYPE p0 = native_exp(s0 - m_new);
const ACC_TYPE p1 = native_exp(s1 - m_new);
const ACC_TYPE p2 = native_exp(s2 - m_new);
const ACC_TYPE p3 = native_exp(s3 - m_new);
#pragma unroll
FA_UNROLL
for (int i = 0; i < DV_VEC; ++i) {
o_acc[i] = o_acc[i] * scale_prev + p0 * CONVERT_ACC4(l_v[j][i]) + p1 * CONVERT_ACC4(l_v[j+1][i]);
o_acc[i] = mad(p3, CONVERT_ACC4(l_v[j+3][i]),
mad(p2, CONVERT_ACC4(l_v[j+2][i]),
mad(p1, CONVERT_ACC4(l_v[j+1][i]),
mad(p0, CONVERT_ACC4(l_v[j][i]),
o_acc[i] * scale_prev))));
}
l_i = l_i * scale_prev + p0 + p1;
l_i = l_i * scale_prev + p0 + p1 + p2 + p3;
m_i = m_new;
}
}
@@ -179,7 +214,7 @@ __kernel void flash_attn_f16(
const ACC_TYPE m_final = max(m_i, m_sink);
const ACC_TYPE scale_o = exp(m_i - m_final);
#pragma unroll
FA_UNROLL
for (int i = 0; i < DV_VEC; ++i) {
o_acc[i] *= scale_o;
}
@@ -191,12 +226,12 @@ __kernel void flash_attn_f16(
global DATA_TYPE4 *o_row = (global DATA_TYPE4 *)(o_base + o_row_offset);
if (l_i > 0.0f) {
const ACC_TYPE l_inv = 1.0f / l_i;
#pragma unroll
FA_UNROLL
for (int i = 0; i < DV_VEC; ++i) {
o_row[i] = CONVERT_DATA4(o_acc[i] * l_inv);
}
} else {
#pragma unroll
FA_UNROLL
for (int i = 0; i < DV_VEC; ++i) {
o_row[i] = (DATA_TYPE4)(0.0f);
}
@@ -258,7 +293,7 @@ __kernel void flash_attn_f16_q1(
ACC_TYPE4 q_priv[DK_VEC];
const ulong q_row_offset = batch_idx * q_nb3 + head_idx * q_nb2;
const global DATA_TYPE4* q_ptr = (const global DATA_TYPE4*)(q_base + q_row_offset);
#pragma unroll
FA_UNROLL
for (int i = 0; i < DK_VEC; ++i) {
q_priv[i] = CONVERT_ACC4(q_ptr[i]);
}
@@ -270,12 +305,12 @@ __kernel void flash_attn_f16_q1(
sinks_ptr = (const global ACC_TYPE*)((const global char*)sinks_void + sinks_offset);
}
ACC_TYPE m_i = (sinks_ptr != NULL) ? sinks_ptr[head_idx] : -INFINITY;
ACC_TYPE m_i = (sinks_ptr != NULL) ? sinks_ptr[head_idx] : FA_M_INIT;
for (int k_idx = tid; k_idx < n_kv; k_idx += Q1_WG_SIZE) {
const ulong k_row_offset = batch_idx * k_nb3 + head_kv_idx * k_nb2 + k_idx * k_nb1;
const global DATA_TYPE4* k_ptr = (const global DATA_TYPE4*)(k_base + k_row_offset);
ACC_TYPE4 dot_acc = (ACC_TYPE4)(0.0f);
#pragma unroll
FA_UNROLL
for (int k = 0; k < DK_VEC; k++) {
dot_acc = mad(q_priv[k], CONVERT_ACC4(k_ptr[k]), dot_acc);
}
@@ -293,7 +328,7 @@ __kernel void flash_attn_f16_q1(
__local ACC_TYPE local_m[Q1_WG_SIZE];
local_m[tid] = m_i;
barrier(CLK_LOCAL_MEM_FENCE);
#pragma unroll
FA_UNROLL
for (int s = Q1_WG_SIZE / 2; s > 0; s >>= 1) {
if (tid < s) local_m[tid] = max(local_m[tid], local_m[tid + s]);
barrier(CLK_LOCAL_MEM_FENCE);
@@ -301,7 +336,7 @@ __kernel void flash_attn_f16_q1(
const ACC_TYPE m_final = local_m[0];
ACC_TYPE4 o_acc[DV_VEC];
#pragma unroll
FA_UNROLL
for (int i = 0; i < DV_VEC; ++i) o_acc[i] = (ACC_TYPE4)(0.0f);
ACC_TYPE l_i = 0.0f;
@@ -311,7 +346,7 @@ __kernel void flash_attn_f16_q1(
const global DATA_TYPE4* k_ptr = (const global DATA_TYPE4*)(k_base + k_row_offset);
const global DATA_TYPE4* v_ptr = (const global DATA_TYPE4*)(v_base + v_row_offset);
ACC_TYPE4 dot_acc = (ACC_TYPE4)(0.0f);
#pragma unroll
FA_UNROLL
for (int k = 0; k < DK_VEC; k++) {
dot_acc = mad(q_priv[k], CONVERT_ACC4(k_ptr[k]), dot_acc);
}
@@ -325,7 +360,7 @@ __kernel void flash_attn_f16_q1(
}
const ACC_TYPE p = exp(score - m_final);
l_i += p;
#pragma unroll
FA_UNROLL
for (int i = 0; i < DV_VEC; i++) {
o_acc[i] = mad(p, CONVERT_ACC4(v_ptr[i]), o_acc[i]);
}
@@ -335,7 +370,7 @@ __kernel void flash_attn_f16_q1(
__local ACC_TYPE4 local_o_comp[Q1_WG_SIZE];
local_l[tid] = l_i;
barrier(CLK_LOCAL_MEM_FENCE);
#pragma unroll
FA_UNROLL
for (int s = Q1_WG_SIZE / 2; s > 0; s >>= 1) {
if (tid < s) local_l[tid] += local_l[tid + s];
barrier(CLK_LOCAL_MEM_FENCE);
@@ -354,7 +389,7 @@ __kernel void flash_attn_f16_q1(
for (int i = 0; i < DV_VEC; i++) {
local_o_comp[tid] = o_acc[i];
barrier(CLK_LOCAL_MEM_FENCE);
#pragma unroll
FA_UNROLL
for (int s = Q1_WG_SIZE / 2; s > 0; s >>= 1) {
if (tid < s) local_o_comp[tid] += local_o_comp[tid + s];
barrier(CLK_LOCAL_MEM_FENCE);
@@ -364,7 +399,7 @@ __kernel void flash_attn_f16_q1(
}
}
} else if (tid == 0) {
#pragma unroll
FA_UNROLL
for (int i = 0; i < DV_VEC; ++i) o_row[i] = (DATA_TYPE4)(0.0f);
}
}
+73 -38
View File
@@ -13,6 +13,18 @@
#define WG_SIZE (BLOCK_M)
#define Q1_WG_SIZE 64
// The kernels are built with -cl-finite-math-only. On some older Adreno GPUs,
// infinite operand can cause undefined behavior and miscompilation for exp.
// Therefore, a large negative value is used instead.
#define FA_M_INIT (-3.0e38f)
// Drop full unroll at DK>=192 Adreno compiler host-memory budget.
#if DK >= 192
#define FA_UNROLL
#else
#define FA_UNROLL _Pragma("unroll")
#endif
inline float get_alibi_slope(
const float max_bias, const uint h, const uint n_head_log2, const float m0, const float m1
) {
@@ -82,18 +94,18 @@ __kernel void flash_attn_f32(
if (my_query_row < n_q) {
const ulong q_row_offset = batch_idx * q_nb3 + head_idx * q_nb2 + my_query_row * q_nb1;
const global DATA_TYPE4* q_ptr = (const global DATA_TYPE4*)(q_base + q_row_offset);
#pragma unroll
FA_UNROLL
for (int i = 0; i < DK_VEC; ++i) {
q_priv[i] = CONVERT_ACC4(q_ptr[i]);
}
}
ACC_TYPE4 o_acc[DV_VEC];
#pragma unroll
FA_UNROLL
for (int i = 0; i < DV_VEC; ++i) {
o_acc[i] = (ACC_TYPE4)(0.0f);
}
ACC_TYPE m_i = -INFINITY;
ACC_TYPE m_i = FA_M_INIT;
ACC_TYPE l_i = 0.0f;
float slope = get_alibi_slope(max_bias, head_idx, n_head_log2, m0, m1);
@@ -126,49 +138,72 @@ __kernel void flash_attn_f32(
continue;
}
for (int j = 0; j < BLOCK_N; j += 2) {
for (int j = 0; j < BLOCK_N; j += 4) {
const int k_row0 = k_start + j;
const int k_row1 = k_start + j + 1;
const int k_row2 = k_start + j + 2;
const int k_row3 = k_start + j + 3;
ACC_TYPE4 dot_acc0 = (ACC_TYPE4)(0.0f);
ACC_TYPE4 dot_acc1 = (ACC_TYPE4)(0.0f);
#pragma unroll
ACC_TYPE4 dot_acc2 = (ACC_TYPE4)(0.0f);
ACC_TYPE4 dot_acc3 = (ACC_TYPE4)(0.0f);
FA_UNROLL
for (int k = 0; k < DK_VEC; k++) {
dot_acc0 = mad(q_priv[k], CONVERT_ACC4(l_k[j][k]), dot_acc0);
dot_acc1 = mad(q_priv[k], CONVERT_ACC4(l_k[j+1][k]), dot_acc1);
const ACC_TYPE4 qk = q_priv[k];
dot_acc0 = mad(qk, CONVERT_ACC4(l_k[j][k]), dot_acc0);
dot_acc1 = mad(qk, CONVERT_ACC4(l_k[j+1][k]), dot_acc1);
dot_acc2 = mad(qk, CONVERT_ACC4(l_k[j+2][k]), dot_acc2);
dot_acc3 = mad(qk, CONVERT_ACC4(l_k[j+3][k]), dot_acc3);
}
ACC_TYPE score0 = (dot_acc0.s0 + dot_acc0.s1 + dot_acc0.s2 + dot_acc0.s3) * scale;
ACC_TYPE score1 = (dot_acc1.s0 + dot_acc1.s1 + dot_acc1.s2 + dot_acc1.s3) * scale;
ACC_TYPE s0 = (dot_acc0.s0 + dot_acc0.s1 + dot_acc0.s2 + dot_acc0.s3) * scale;
ACC_TYPE s1 = (dot_acc1.s0 + dot_acc1.s1 + dot_acc1.s2 + dot_acc1.s3) * scale;
ACC_TYPE s2 = (dot_acc2.s0 + dot_acc2.s1 + dot_acc2.s2 + dot_acc2.s3) * scale;
ACC_TYPE s3 = (dot_acc3.s0 + dot_acc3.s1 + dot_acc3.s2 + dot_acc3.s3) * scale;
if (is_causal) {
if (k_row0 > (n_kv - n_q + my_query_row)) score0 = -INFINITY;
if (k_row1 > (n_kv - n_q + my_query_row)) score1 = -INFINITY;
const int causal_limit = n_kv - n_q + my_query_row;
if (k_row0 > causal_limit) s0 = FA_M_INIT;
if (k_row1 > causal_limit) s1 = FA_M_INIT;
if (k_row2 > causal_limit) s2 = FA_M_INIT;
if (k_row3 > causal_limit) s3 = FA_M_INIT;
}
if (k_row0 >= n_kv) score0 = -INFINITY;
if (k_row1 >= n_kv) score1 = -INFINITY;
if (k_row0 >= n_kv) s0 = FA_M_INIT;
if (k_row1 >= n_kv) s1 = FA_M_INIT;
if (k_row2 >= n_kv) s2 = FA_M_INIT;
if (k_row3 >= n_kv) s3 = FA_M_INIT;
if (mask_base != NULL) {
const global MASK_DATA_TYPE* mask_ptr = (const global MASK_DATA_TYPE*)(mask_base + my_query_row * mask_nb1);
if (k_row0 < n_kv) score0 += slope * (ACC_TYPE)mask_ptr[k_row0];
if (k_row1 < n_kv) score1 += slope * (ACC_TYPE)mask_ptr[k_row1];
if (k_row0 < n_kv) s0 += slope * (ACC_TYPE)mask_ptr[k_row0];
if (k_row1 < n_kv) s1 += slope * (ACC_TYPE)mask_ptr[k_row1];
if (k_row2 < n_kv) s2 += slope * (ACC_TYPE)mask_ptr[k_row2];
if (k_row3 < n_kv) s3 += slope * (ACC_TYPE)mask_ptr[k_row3];
}
if (logit_softcap > 0.0f) {
score0 = logit_softcap * tanh(score0 / logit_softcap);
score1 = logit_softcap * tanh(score1 / logit_softcap);
s0 = logit_softcap * tanh(s0 / logit_softcap);
s1 = logit_softcap * tanh(s1 / logit_softcap);
s2 = logit_softcap * tanh(s2 / logit_softcap);
s3 = logit_softcap * tanh(s3 / logit_softcap);
}
const ACC_TYPE m_new = max(m_i, max(score0, score1));
const ACC_TYPE p0 = exp(score0 - m_new);
const ACC_TYPE p1 = exp(score1 - m_new);
const ACC_TYPE scale_prev = exp(m_i - m_new);
const ACC_TYPE m_new = max(m_i, max(max(s0, s1), max(s2, s3)));
const ACC_TYPE scale_prev = native_exp(m_i - m_new);
const ACC_TYPE p0 = native_exp(s0 - m_new);
const ACC_TYPE p1 = native_exp(s1 - m_new);
const ACC_TYPE p2 = native_exp(s2 - m_new);
const ACC_TYPE p3 = native_exp(s3 - m_new);
#pragma unroll
FA_UNROLL
for (int i = 0; i < DV_VEC; ++i) {
o_acc[i] = o_acc[i] * scale_prev + p0 * CONVERT_ACC4(l_v[j][i]) + p1 * CONVERT_ACC4(l_v[j+1][i]);
o_acc[i] = mad(p3, CONVERT_ACC4(l_v[j+3][i]),
mad(p2, CONVERT_ACC4(l_v[j+2][i]),
mad(p1, CONVERT_ACC4(l_v[j+1][i]),
mad(p0, CONVERT_ACC4(l_v[j][i]),
o_acc[i] * scale_prev))));
}
l_i = l_i * scale_prev + p0 + p1;
l_i = l_i * scale_prev + p0 + p1 + p2 + p3;
m_i = m_new;
}
}
@@ -180,7 +215,7 @@ __kernel void flash_attn_f32(
const ACC_TYPE m_final = max(m_i, m_sink);
const ACC_TYPE scale_o = exp(m_i - m_final);
#pragma unroll
FA_UNROLL
for (int i = 0; i < DV_VEC; ++i) {
o_acc[i] *= scale_o;
}
@@ -192,12 +227,12 @@ __kernel void flash_attn_f32(
global DATA_TYPE4 *o_row = (global DATA_TYPE4 *)(o_base + o_row_offset);
if (l_i > 0.0f) {
const ACC_TYPE l_inv = 1.0f / l_i;
#pragma unroll
FA_UNROLL
for (int i = 0; i < DV_VEC; ++i) {
o_row[i] = CONVERT_DATA4(o_acc[i] * l_inv);
}
} else {
#pragma unroll
FA_UNROLL
for (int i = 0; i < DV_VEC; ++i) {
o_row[i] = (DATA_TYPE4)(0.0f);
}
@@ -259,7 +294,7 @@ __kernel void flash_attn_f32_q1(
ACC_TYPE4 q_priv[DK_VEC];
const ulong q_row_offset = batch_idx * q_nb3 + head_idx * q_nb2;
const global DATA_TYPE4* q_ptr = (const global DATA_TYPE4*)(q_base + q_row_offset);
#pragma unroll
FA_UNROLL
for (int i = 0; i < DK_VEC; ++i) {
q_priv[i] = CONVERT_ACC4(q_ptr[i]);
}
@@ -271,12 +306,12 @@ __kernel void flash_attn_f32_q1(
sinks_ptr = (const global ACC_TYPE*)((const global char*)sinks_void + sinks_offset);
}
ACC_TYPE m_i = (sinks_ptr != NULL) ? sinks_ptr[head_idx] : -INFINITY;
ACC_TYPE m_i = (sinks_ptr != NULL) ? sinks_ptr[head_idx] : FA_M_INIT;
for (int k_idx = tid; k_idx < n_kv; k_idx += Q1_WG_SIZE) {
const ulong k_row_offset = batch_idx * k_nb3 + head_kv_idx * k_nb2 + k_idx * k_nb1;
const global DATA_TYPE4* k_ptr = (const global DATA_TYPE4*)(k_base + k_row_offset);
ACC_TYPE4 dot_acc = (ACC_TYPE4)(0.0f);
#pragma unroll
FA_UNROLL
for (int k = 0; k < DK_VEC; k++) {
dot_acc = mad(q_priv[k], CONVERT_ACC4(k_ptr[k]), dot_acc);
}
@@ -294,7 +329,7 @@ __kernel void flash_attn_f32_q1(
__local ACC_TYPE local_m[Q1_WG_SIZE];
local_m[tid] = m_i;
barrier(CLK_LOCAL_MEM_FENCE);
#pragma unroll
FA_UNROLL
for (int s = Q1_WG_SIZE / 2; s > 0; s >>= 1) {
if (tid < s) local_m[tid] = max(local_m[tid], local_m[tid + s]);
barrier(CLK_LOCAL_MEM_FENCE);
@@ -302,7 +337,7 @@ __kernel void flash_attn_f32_q1(
const ACC_TYPE m_final = local_m[0];
ACC_TYPE4 o_acc[DV_VEC];
#pragma unroll
FA_UNROLL
for (int i = 0; i < DV_VEC; ++i) o_acc[i] = (ACC_TYPE4)(0.0f);
ACC_TYPE l_i = 0.0f;
@@ -312,7 +347,7 @@ __kernel void flash_attn_f32_q1(
const global DATA_TYPE4* k_ptr = (const global DATA_TYPE4*)(k_base + k_row_offset);
const global DATA_TYPE4* v_ptr = (const global DATA_TYPE4*)(v_base + v_row_offset);
ACC_TYPE4 dot_acc = (ACC_TYPE4)(0.0f);
#pragma unroll
FA_UNROLL
for (int k = 0; k < DK_VEC; k++) {
dot_acc = mad(q_priv[k], CONVERT_ACC4(k_ptr[k]), dot_acc);
}
@@ -326,7 +361,7 @@ __kernel void flash_attn_f32_q1(
}
const ACC_TYPE p = exp(score - m_final);
l_i += p;
#pragma unroll
FA_UNROLL
for (int i = 0; i < DV_VEC; i++) {
o_acc[i] = mad(p, CONVERT_ACC4(v_ptr[i]), o_acc[i]);
}
@@ -336,7 +371,7 @@ __kernel void flash_attn_f32_q1(
__local ACC_TYPE4 local_o_comp[Q1_WG_SIZE];
local_l[tid] = l_i;
barrier(CLK_LOCAL_MEM_FENCE);
#pragma unroll
FA_UNROLL
for (int s = Q1_WG_SIZE / 2; s > 0; s >>= 1) {
if (tid < s) local_l[tid] += local_l[tid + s];
barrier(CLK_LOCAL_MEM_FENCE);
@@ -355,7 +390,7 @@ __kernel void flash_attn_f32_q1(
for (int i = 0; i < DV_VEC; i++) {
local_o_comp[tid] = o_acc[i];
barrier(CLK_LOCAL_MEM_FENCE);
#pragma unroll
FA_UNROLL
for (int s = Q1_WG_SIZE / 2; s > 0; s >>= 1) {
if (tid < s) local_o_comp[tid] += local_o_comp[tid + s];
barrier(CLK_LOCAL_MEM_FENCE);
@@ -365,7 +400,7 @@ __kernel void flash_attn_f32_q1(
}
}
} else if (tid == 0) {
#pragma unroll
FA_UNROLL
for (int i = 0; i < DV_VEC; ++i) o_row[i] = (DATA_TYPE4)(0.0f);
}
}
@@ -1,5 +1,13 @@
#pragma OPENCL EXTENSION cl_khr_fp16 : enable
#ifdef cl_khr_subgroup_shuffle
#pragma OPENCL EXTENSION cl_khr_subgroup_shuffle : enable
#define HAS_SUBGROUP_SHUFFLE 1
#elif defined(cl_qcom_subgroup_shuffle)
#pragma OPENCL EXTENSION cl_qcom_subgroup_shuffle : enable
#define HAS_SUBGROUP_SHUFFLE 1
#endif
#define ACC_TYPE float
#define ACC_TYPE4 float4
#define Q_DATA_TYPE4 float4
@@ -12,9 +20,34 @@
#define DK_VEC (DK/4)
#define DV_VEC (DV/4)
#define WG_SIZE (BLOCK_M)
#define Q1_WG_SIZE 64
// The kernels are built with -cl-finite-math-only. On some older Adreno GPUs,
// infinite operand can cause undefined behavior and miscompilation for exp.
// Therefore, a large negative value is used instead.
#define FA_M_INIT (-3.0e38f)
// Drop full unroll at DK>=192 Adreno compiler host-memory budget.
#if DK >= 192
#define FA_UNROLL
#else
#define FA_UNROLL _Pragma("unroll")
#endif
// N_SPLIT>1 splits DK/DV across threads to cut per-thread register use.
#ifndef N_SPLIT
#define N_SPLIT 1
#endif
#define SPLIT_DK_VEC (DK_VEC / N_SPLIT)
#define SPLIT_DV_VEC (DV_VEC / N_SPLIT)
#if N_SPLIT > 1
#define WG_SIZE (BLOCK_M * N_SPLIT)
#else
#define WG_SIZE (BLOCK_M)
#endif
inline float get_alibi_slope(
const float max_bias, const uint h, const uint n_head_log2, const float m0, const float m1
) {
@@ -54,19 +87,38 @@ __kernel void flash_attn_f32_f16(
const int mask_ne2,
const int mask_ne3,
const global void* sinks_void,
const ulong sinks_offset
const ulong sinks_offset,
const global void * k_pad_void,
const global void * v_pad_void,
const global void * mask_pad_void,
const global char * blk,
const int n_kv_blocks,
const ulong mask_pad_nb1,
const ulong mask_pad_nb2,
const ulong mask_pad_nb3
) {
const int tid = get_local_id(0);
const int block_q_idx = get_group_id(0);
const int head_batch_idx = get_global_id(1);
const int my_query_row = block_q_idx * BLOCK_M + tid;
#if N_SPLIT > 1
const int q_lane = tid / N_SPLIT;
const int split_idx = tid % N_SPLIT;
#else
const int q_lane = tid;
const int split_idx = 0;
#endif
const int my_query_row = block_q_idx * BLOCK_M + q_lane;
const int query_valid = my_query_row < n_q;
const int batch_idx = head_batch_idx / n_head;
const int head_idx = head_batch_idx % n_head;
const int gqa_ratio = n_head / n_head_kv;
const int head_kv_idx = head_idx / gqa_ratio;
const int mask_head_idx = mask_void != NULL ? head_idx % mask_ne2 : 0;
const int mask_batch_idx = mask_void != NULL ? batch_idx % mask_ne3 : 0;
const global char* q_base = (const global char*)q_void + q_offset;
const global char* k_base = (const global char*)k_void + k_offset;
@@ -75,27 +127,41 @@ __kernel void flash_attn_f32_f16(
const global char* mask_base = NULL;
if (mask_void != NULL) {
const int mask_head_idx = head_idx % mask_ne2;
const int mask_batch_idx = batch_idx % mask_ne3;
mask_base = (const global char*)mask_void + mask_offset + mask_batch_idx * mask_nb3 + mask_head_idx * mask_nb2;
}
const global char* mask_pad_base = NULL;
if (mask_pad_void != NULL) {
mask_pad_base = (const global char*)mask_pad_void + mask_batch_idx * mask_pad_nb3 + mask_head_idx * mask_pad_nb2;
}
const global char* blk_base = NULL;
if (blk != NULL) {
const int n_q_blocks = (n_q + BLOCK_M - 1) / BLOCK_M;
blk_base = blk + (((mask_batch_idx * mask_ne2) + mask_head_idx) * n_q_blocks + block_q_idx) * n_kv_blocks;
}
ACC_TYPE4 q_priv[DK_VEC];
if (my_query_row < n_q) {
ACC_TYPE4 q_priv[SPLIT_DK_VEC];
const int dk_off = split_idx * SPLIT_DK_VEC;
if (query_valid) {
const ulong q_row_offset = batch_idx * q_nb3 + head_idx * q_nb2 + my_query_row * q_nb1;
const global Q_DATA_TYPE4* q_ptr = (const global Q_DATA_TYPE4*)(q_base + q_row_offset);
#pragma unroll
for (int i = 0; i < DK_VEC; ++i) {
q_priv[i] = CONVERT_Q_ACC4(q_ptr[i]);
FA_UNROLL
for (int i = 0; i < SPLIT_DK_VEC; ++i) {
q_priv[i] = CONVERT_Q_ACC4(q_ptr[dk_off + i]);
}
} else {
FA_UNROLL
for (int i = 0; i < SPLIT_DK_VEC; ++i) {
q_priv[i] = (ACC_TYPE4)(0.0f);
}
}
ACC_TYPE4 o_acc[DV_VEC];
#pragma unroll
for (int i = 0; i < DV_VEC; ++i) {
ACC_TYPE4 o_acc[SPLIT_DV_VEC];
FA_UNROLL
for (int i = 0; i < SPLIT_DV_VEC; ++i) {
o_acc[i] = (ACC_TYPE4)(0.0f);
}
ACC_TYPE m_i = -INFINITY;
ACC_TYPE m_i = FA_M_INIT;
ACC_TYPE l_i = 0.0f;
float slope = get_alibi_slope(max_bias, head_idx, n_head_log2, m0, m1);
@@ -103,86 +169,369 @@ __kernel void flash_attn_f32_f16(
__local KV_DATA_TYPE4 l_k[BLOCK_N][DK_VEC];
__local KV_DATA_TYPE4 l_v[BLOCK_N][DV_VEC];
#if N_SPLIT > 1 && !defined(HAS_SUBGROUP_SHUFFLE)
__local ACC_TYPE local_partial[BLOCK_N][WG_SIZE];
__local ACC_TYPE local_p[BLOCK_M][BLOCK_N];
__local ACC_TYPE local_softmax_scale[BLOCK_M];
__local ACC_TYPE local_l_inv[BLOCK_M];
#endif
for (int k_start = 0; k_start < n_kv; k_start += BLOCK_N) {
char blk_cur = 1;
if (blk_base != NULL) {
blk_cur = blk_base[k_start / BLOCK_N];
if (blk_cur == 0) continue;
}
const int use_kv_pad = k_pad_void != NULL && k_start + BLOCK_N > n_kv;
const int k_tile_start = use_kv_pad ? 0 : k_start;
const ulong k_tile_nb2 = use_kv_pad ? (ulong) BLOCK_N * k_nb1 : k_nb2;
const ulong k_tile_nb3 = use_kv_pad ? (ulong) n_head_kv * k_tile_nb2 : k_nb3;
const ulong v_tile_nb2 = use_kv_pad ? (ulong) BLOCK_N * v_nb1 : v_nb2;
const ulong v_tile_nb3 = use_kv_pad ? (ulong) n_head_kv * v_tile_nb2 : v_nb3;
const global char* k_tile_base = use_kv_pad ? (const global char*) k_pad_void : k_base;
const global char* v_tile_base = use_kv_pad ? (const global char*) v_pad_void : v_base;
for (int i = tid; i < BLOCK_N * DK_VEC; i += WG_SIZE) {
const int row = i / DK_VEC;
const int col = i % DK_VEC;
const int k_row_idx = k_start + row;
if (k_row_idx < n_kv) {
const ulong k_row_offset = batch_idx * k_nb3 + head_kv_idx * k_nb2 + k_row_idx * k_nb1;
l_k[row][col] = ((__global KV_DATA_TYPE4*)(k_base + k_row_offset))[col];
const int k_row_idx = k_tile_start + row;
if (use_kv_pad || k_row_idx < n_kv) {
const ulong k_row_offset = batch_idx * k_tile_nb3 + head_kv_idx * k_tile_nb2 + k_row_idx * k_nb1;
l_k[row][col] = ((__global KV_DATA_TYPE4*)(k_tile_base + k_row_offset))[col];
} else {
l_k[row][col] = (KV_DATA_TYPE4)(0.0h);
}
}
for (int i = tid; i < BLOCK_N * DV_VEC; i += WG_SIZE) {
const int row = i / DV_VEC;
const int col = i % DV_VEC;
const int v_row_idx = k_start + row;
if (v_row_idx < n_kv) {
const ulong v_row_offset = batch_idx * v_nb3 + head_kv_idx * v_nb2 + v_row_idx * v_nb1;
l_v[row][col] = ((__global KV_DATA_TYPE4*)(v_base + v_row_offset))[col];
const int v_row_idx = k_tile_start + row;
if (use_kv_pad || v_row_idx < n_kv) {
const ulong v_row_offset = batch_idx * v_tile_nb3 + head_kv_idx * v_tile_nb2 + v_row_idx * v_nb1;
l_v[row][col] = ((__global KV_DATA_TYPE4*)(v_tile_base + v_row_offset))[col];
} else {
l_v[row][col] = (KV_DATA_TYPE4)(0.0h);
}
}
barrier(CLK_LOCAL_MEM_FENCE);
if (my_query_row >= n_q) {
continue;
#if N_SPLIT > 1 && defined(HAS_SUBGROUP_SHUFFLE)
{
const int dv_off = split_idx * SPLIT_DV_VEC;
for (int j = 0; j < BLOCK_N; j += 2) {
const int k_row0 = k_start + j;
const int k_row1 = k_start + j + 1;
ACC_TYPE partial0 = 0.0f;
ACC_TYPE partial1 = 0.0f;
FA_UNROLL
for (int k = 0; k < SPLIT_DK_VEC; k++) {
const ACC_TYPE4 qk = q_priv[k];
ACC_TYPE4 dot0 = qk * CONVERT_KV_ACC4(l_k[j ][dk_off + k]);
ACC_TYPE4 dot1 = qk * CONVERT_KV_ACC4(l_k[j+1][dk_off + k]);
partial0 += dot0.s0 + dot0.s1 + dot0.s2 + dot0.s3;
partial1 += dot1.s0 + dot1.s1 + dot1.s2 + dot1.s3;
}
FA_UNROLL
for (int step = 1; step < N_SPLIT; step <<= 1) {
partial0 += sub_group_shuffle_xor(partial0, step);
partial1 += sub_group_shuffle_xor(partial1, step);
}
ACC_TYPE score0 = partial0 * scale;
ACC_TYPE score1 = partial1 * scale;
if (!query_valid) { score0 = FA_M_INIT; score1 = FA_M_INIT; }
if (is_causal) {
if (k_row0 > (n_kv - n_q + my_query_row)) score0 = FA_M_INIT;
if (k_row1 > (n_kv - n_q + my_query_row)) score1 = FA_M_INIT;
}
if (k_row0 >= n_kv) score0 = FA_M_INIT;
if (k_row1 >= n_kv) score1 = FA_M_INIT;
if (query_valid && mask_base != NULL && blk_cur != 2) {
if (use_kv_pad && mask_pad_base != NULL) {
const global MASK_DATA_TYPE* mask_ptr =
(const global MASK_DATA_TYPE*)(mask_pad_base + my_query_row * mask_pad_nb1);
score0 += slope * (ACC_TYPE)mask_ptr[j];
score1 += slope * (ACC_TYPE)mask_ptr[j + 1];
} else {
const global MASK_DATA_TYPE* mask_ptr =
(const global MASK_DATA_TYPE*)(mask_base + my_query_row * mask_nb1);
if (k_row0 < n_kv) score0 += slope * (ACC_TYPE)mask_ptr[k_row0];
if (k_row1 < n_kv) score1 += slope * (ACC_TYPE)mask_ptr[k_row1];
}
}
if (logit_softcap > 0.0f) {
score0 = logit_softcap * tanh(score0 / logit_softcap);
score1 = logit_softcap * tanh(score1 / logit_softcap);
}
const ACC_TYPE m_new = max(m_i, max(score0, score1));
// Whole tile masked (m_new == FA_M_INIT): force the exp() args
// far negative so the tile contributes 0, not exp(0)=1.
const ACC_TYPE m_exp = (m_new == FA_M_INIT) ? 0.0f : m_new;
const ACC_TYPE sp = native_exp(m_i - m_exp);
const ACC_TYPE p0 = native_exp(score0 - m_exp);
const ACC_TYPE p1 = native_exp(score1 - m_exp);
FA_UNROLL
for (int i = 0; i < SPLIT_DV_VEC; ++i) {
o_acc[i] = o_acc[i] * sp
+ p0 * CONVERT_KV_ACC4(l_v[j ][dv_off + i])
+ p1 * CONVERT_KV_ACC4(l_v[j+1][dv_off + i]);
}
l_i = l_i * sp + p0 + p1;
m_i = m_new;
}
}
for (int j = 0; j < BLOCK_N; j += 2) {
const int k_row0 = k_start + j;
const int k_row1 = k_start + j + 1;
ACC_TYPE4 dot_acc0 = (ACC_TYPE4)(0.0f);
ACC_TYPE4 dot_acc1 = (ACC_TYPE4)(0.0f);
#pragma unroll
for (int k = 0; k < DK_VEC; k++) {
dot_acc0 = mad(q_priv[k], CONVERT_KV_ACC4(l_k[j][k]), dot_acc0);
dot_acc1 = mad(q_priv[k], CONVERT_KV_ACC4(l_k[j+1][k]), dot_acc1);
#elif N_SPLIT > 1
// N_SPLIT>1 fallback (no shuffle): 3-phase local-memory reduction.
// Phase 1 partial dots for all BLOCK_N tokens.
for (int j = 0; j < BLOCK_N; ++j) {
ACC_TYPE4 dot_acc = (ACC_TYPE4)(0.0f);
FA_UNROLL
for (int k = 0; k < SPLIT_DK_VEC; k++) {
dot_acc = mad(q_priv[k], CONVERT_KV_ACC4(l_k[j][dk_off + k]), dot_acc);
}
ACC_TYPE score0 = (dot_acc0.s0 + dot_acc0.s1 + dot_acc0.s2 + dot_acc0.s3) * scale;
ACC_TYPE score1 = (dot_acc1.s0 + dot_acc1.s1 + dot_acc1.s2 + dot_acc1.s3) * scale;
if (is_causal) {
if (k_row0 > (n_kv - n_q + my_query_row)) score0 = -INFINITY;
if (k_row1 > (n_kv - n_q + my_query_row)) score1 = -INFINITY;
}
if (k_row0 >= n_kv) score0 = -INFINITY;
if (k_row1 >= n_kv) score1 = -INFINITY;
if (mask_base != NULL) {
const global MASK_DATA_TYPE* mask_ptr = (const global MASK_DATA_TYPE*)(mask_base + my_query_row * mask_nb1);
if (k_row0 < n_kv) score0 += slope * (ACC_TYPE)mask_ptr[k_row0];
if (k_row1 < n_kv) score1 += slope * (ACC_TYPE)mask_ptr[k_row1];
}
if (logit_softcap > 0.0f) {
score0 = logit_softcap * tanh(score0 / logit_softcap);
score1 = logit_softcap * tanh(score1 / logit_softcap);
}
const ACC_TYPE m_new = max(m_i, max(score0, score1));
const ACC_TYPE p0 = exp(score0 - m_new);
const ACC_TYPE p1 = exp(score1 - m_new);
const ACC_TYPE scale_prev = exp(m_i - m_new);
#pragma unroll
for (int i = 0; i < DV_VEC; ++i) {
o_acc[i] = o_acc[i] * scale_prev + p0 * CONVERT_KV_ACC4(l_v[j][i]) + p1 * CONVERT_KV_ACC4(l_v[j+1][i]);
}
l_i = l_i * scale_prev + p0 + p1;
m_i = m_new;
local_partial[j][tid] =
dot_acc.s0 + dot_acc.s1 + dot_acc.s2 + dot_acc.s3;
}
barrier(CLK_LOCAL_MEM_FENCE); // 1 barrier: partial dots visible
// Phase 2 split_idx==0 reduces partial sums and computes block softmax.
if (split_idx == 0) {
if (query_valid) {
ACC_TYPE m_new = m_i;
for (int j = 0; j < BLOCK_N; ++j) {
const int k_row = k_start + j;
ACC_TYPE score = 0.0f;
FA_UNROLL
for (int s = 0; s < N_SPLIT; s++) {
score += local_partial[j][q_lane * N_SPLIT + s];
}
score *= scale;
if (is_causal && k_row > (n_kv - n_q + my_query_row)) score = FA_M_INIT;
if (k_row >= n_kv) score = FA_M_INIT;
if (mask_base != NULL && blk_cur != 2) {
if (use_kv_pad && mask_pad_base != NULL) {
const global MASK_DATA_TYPE* mask_ptr =
(const global MASK_DATA_TYPE*)(mask_pad_base + my_query_row * mask_pad_nb1);
score += slope * (ACC_TYPE)mask_ptr[j];
} else {
const global MASK_DATA_TYPE* mask_ptr =
(const global MASK_DATA_TYPE*)(mask_base + my_query_row * mask_nb1);
if (k_row < n_kv) score += slope * (ACC_TYPE)mask_ptr[k_row];
}
}
if (logit_softcap > 0.0f) {
score = logit_softcap * tanh(score / logit_softcap);
}
m_new = max(m_new, score);
local_p[q_lane][j] = score;
}
const ACC_TYPE m_exp = (m_new == FA_M_INIT) ? 0.0f : m_new;
const ACC_TYPE sp = native_exp(m_i - m_exp);
ACC_TYPE l_new = l_i * sp;
for (int j = 0; j < BLOCK_N; ++j) {
const ACC_TYPE p = native_exp(local_p[q_lane][j] - m_exp);
local_p[q_lane][j] = p;
l_new += p;
}
local_softmax_scale[q_lane] = sp;
l_i = l_new;
m_i = m_new;
} else {
local_softmax_scale[q_lane] = 1.0f;
for (int j = 0; j < BLOCK_N; ++j) local_p[q_lane][j] = 0.0f;
}
}
barrier(CLK_LOCAL_MEM_FENCE);
// Phase 3 V accumulate using broadcast probabilities.
{
const ACC_TYPE sp_block = local_softmax_scale[q_lane];
const int dv_off = split_idx * SPLIT_DV_VEC;
FA_UNROLL
for (int i = 0; i < SPLIT_DV_VEC; ++i) {
o_acc[i] *= sp_block;
}
for (int j = 0; j < BLOCK_N; ++j) {
const ACC_TYPE p = local_p[q_lane][j];
FA_UNROLL
for (int i = 0; i < SPLIT_DV_VEC; ++i) {
o_acc[i] = mad(p, CONVERT_KV_ACC4(l_v[j][dv_off + i]), o_acc[i]);
}
}
}
#else
// N_SPLIT==1: j+=4 unroll. Requires BLOCK_N % 4 == 0.
if (query_valid) {
for (int j = 0; j < BLOCK_N; j += 4) {
const int k_row0 = k_start + j;
const int k_row1 = k_start + j + 1;
const int k_row2 = k_start + j + 2;
const int k_row3 = k_start + j + 3;
ACC_TYPE4 dot_acc0 = (ACC_TYPE4)(0.0f);
ACC_TYPE4 dot_acc1 = (ACC_TYPE4)(0.0f);
ACC_TYPE4 dot_acc2 = (ACC_TYPE4)(0.0f);
ACC_TYPE4 dot_acc3 = (ACC_TYPE4)(0.0f);
FA_UNROLL
for (int k = 0; k < DK_VEC; k++) {
const ACC_TYPE4 qk = q_priv[k];
dot_acc0 = mad(qk, CONVERT_KV_ACC4(l_k[j][k]), dot_acc0);
dot_acc1 = mad(qk, CONVERT_KV_ACC4(l_k[j+1][k]), dot_acc1);
dot_acc2 = mad(qk, CONVERT_KV_ACC4(l_k[j+2][k]), dot_acc2);
dot_acc3 = mad(qk, CONVERT_KV_ACC4(l_k[j+3][k]), dot_acc3);
}
ACC_TYPE s0 = (dot_acc0.s0 + dot_acc0.s1 + dot_acc0.s2 + dot_acc0.s3) * scale;
ACC_TYPE s1 = (dot_acc1.s0 + dot_acc1.s1 + dot_acc1.s2 + dot_acc1.s3) * scale;
ACC_TYPE s2 = (dot_acc2.s0 + dot_acc2.s1 + dot_acc2.s2 + dot_acc2.s3) * scale;
ACC_TYPE s3 = (dot_acc3.s0 + dot_acc3.s1 + dot_acc3.s2 + dot_acc3.s3) * scale;
if (is_causal) {
const int causal_limit = n_kv - n_q + my_query_row;
if (k_row0 > causal_limit) s0 = FA_M_INIT;
if (k_row1 > causal_limit) s1 = FA_M_INIT;
if (k_row2 > causal_limit) s2 = FA_M_INIT;
if (k_row3 > causal_limit) s3 = FA_M_INIT;
}
if (k_row0 >= n_kv) s0 = FA_M_INIT;
if (k_row1 >= n_kv) s1 = FA_M_INIT;
if (k_row2 >= n_kv) s2 = FA_M_INIT;
if (k_row3 >= n_kv) s3 = FA_M_INIT;
if (mask_base != NULL && blk_cur != 2) {
if (use_kv_pad && mask_pad_base != NULL) {
const global MASK_DATA_TYPE* mask_ptr = (const global MASK_DATA_TYPE*)(mask_pad_base + my_query_row * mask_pad_nb1);
s0 += slope * (ACC_TYPE)mask_ptr[j];
s1 += slope * (ACC_TYPE)mask_ptr[j + 1];
s2 += slope * (ACC_TYPE)mask_ptr[j + 2];
s3 += slope * (ACC_TYPE)mask_ptr[j + 3];
} else {
const global MASK_DATA_TYPE* mask_ptr = (const global MASK_DATA_TYPE*)(mask_base + my_query_row * mask_nb1);
if (k_row0 < n_kv) s0 += slope * (ACC_TYPE)mask_ptr[k_row0];
if (k_row1 < n_kv) s1 += slope * (ACC_TYPE)mask_ptr[k_row1];
if (k_row2 < n_kv) s2 += slope * (ACC_TYPE)mask_ptr[k_row2];
if (k_row3 < n_kv) s3 += slope * (ACC_TYPE)mask_ptr[k_row3];
}
}
if (logit_softcap > 0.0f) {
s0 = logit_softcap * tanh(s0 / logit_softcap);
s1 = logit_softcap * tanh(s1 / logit_softcap);
s2 = logit_softcap * tanh(s2 / logit_softcap);
s3 = logit_softcap * tanh(s3 / logit_softcap);
}
const ACC_TYPE m_new = max(m_i, max(max(s0, s1), max(s2, s3)));
// Whole tile masked (m_new == FA_M_INIT): force the exp() args
// far negative so the tile contributes 0, not exp(0)=1.
const ACC_TYPE m_exp = (m_new == FA_M_INIT) ? 0.0f : m_new;
const ACC_TYPE scale_prev = native_exp(m_i - m_exp);
const ACC_TYPE p0 = native_exp(s0 - m_exp);
const ACC_TYPE p1 = native_exp(s1 - m_exp);
const ACC_TYPE p2 = native_exp(s2 - m_exp);
const ACC_TYPE p3 = native_exp(s3 - m_exp);
FA_UNROLL
for (int i = 0; i < DV_VEC; ++i) {
o_acc[i] = mad(p3, CONVERT_KV_ACC4(l_v[j+3][i]),
mad(p2, CONVERT_KV_ACC4(l_v[j+2][i]),
mad(p1, CONVERT_KV_ACC4(l_v[j+1][i]),
mad(p0, CONVERT_KV_ACC4(l_v[j][i]),
o_acc[i] * scale_prev))));
}
l_i = l_i * scale_prev + p0 + p1 + p2 + p3;
m_i = m_new;
}
}
#endif
// End of tile: every thread must finish reading l_k/l_v before the
// next iteration's load overwrites them (WAR hazard on local memory).
barrier(CLK_LOCAL_MEM_FENCE);
}
if (my_query_row < n_q) {
// Write output.
#if N_SPLIT > 1 && defined(HAS_SUBGROUP_SHUFFLE)
if (query_valid) {
ACC_TYPE sinks_sp = 1.0f;
if (sinks_void != NULL) {
const global ACC_TYPE* sinks_ptr = (const global ACC_TYPE*)((const global char*)sinks_void + sinks_offset);
const ACC_TYPE m_sink = sinks_ptr[head_idx];
const ACC_TYPE m_final = max(m_i, m_sink);
sinks_sp = exp(m_i - m_final);
l_i = l_i * sinks_sp + exp(m_sink - m_final);
m_i = m_final;
}
const ACC_TYPE l_inv = (l_i > 0.0f) ? (1.0f / l_i) : 0.0f;
const int dv_off = split_idx * SPLIT_DV_VEC;
const ulong o_row_offset = batch_idx * o_nb3 + my_query_row * o_nb2 + head_idx * o_nb1;
global O_DATA_TYPE4 *o_row = (global O_DATA_TYPE4 *)(o_base + o_row_offset);
if (l_inv > 0.0f) {
FA_UNROLL
for (int i = 0; i < SPLIT_DV_VEC; ++i) {
o_row[dv_off + i] = CONVERT_O_DATA4(o_acc[i] * sinks_sp * l_inv);
}
} else {
FA_UNROLL
for (int i = 0; i < SPLIT_DV_VEC; ++i) {
o_row[dv_off + i] = (O_DATA_TYPE4)(0.0f);
}
}
}
#elif N_SPLIT > 1
if (split_idx == 0) {
ACC_TYPE sinks_sp = 1.0f;
if (query_valid && sinks_void != NULL) {
const global ACC_TYPE* sinks_ptr = (const global ACC_TYPE*)((const global char*)sinks_void + sinks_offset);
const ACC_TYPE m_sink = sinks_ptr[head_idx];
const ACC_TYPE m_final = max(m_i, m_sink);
sinks_sp = exp(m_i - m_final);
l_i = l_i * sinks_sp + exp(m_sink - m_final);
m_i = m_final;
}
local_softmax_scale[q_lane] = sinks_sp;
local_l_inv[q_lane] = (query_valid && l_i > 0.0f) ? (1.0f / l_i) : 0.0f;
}
barrier(CLK_LOCAL_MEM_FENCE);
if (query_valid) {
const ACC_TYPE sinks_sp = local_softmax_scale[q_lane];
const ACC_TYPE l_inv = local_l_inv[q_lane];
const int dv_off = split_idx * SPLIT_DV_VEC;
const ulong o_row_offset = batch_idx * o_nb3 + my_query_row * o_nb2 + head_idx * o_nb1;
global O_DATA_TYPE4 *o_row = (global O_DATA_TYPE4 *)(o_base + o_row_offset);
if (l_inv > 0.0f) {
FA_UNROLL
for (int i = 0; i < SPLIT_DV_VEC; ++i) {
o_row[dv_off + i] = CONVERT_O_DATA4(o_acc[i] * sinks_sp * l_inv);
}
} else {
FA_UNROLL
for (int i = 0; i < SPLIT_DV_VEC; ++i) {
o_row[dv_off + i] = (O_DATA_TYPE4)(0.0f);
}
}
}
#else
if (query_valid) {
if (sinks_void != NULL) {
const global ACC_TYPE* sinks_ptr = (const global ACC_TYPE*)((const global char*)sinks_void + sinks_offset);
const ACC_TYPE m_sink = sinks_ptr[head_idx];
const ACC_TYPE m_final = max(m_i, m_sink);
const ACC_TYPE scale_o = exp(m_i - m_final);
#pragma unroll
FA_UNROLL
for (int i = 0; i < DV_VEC; ++i) {
o_acc[i] *= scale_o;
}
@@ -194,17 +543,18 @@ __kernel void flash_attn_f32_f16(
global O_DATA_TYPE4 *o_row = (global O_DATA_TYPE4 *)(o_base + o_row_offset);
if (l_i > 0.0f) {
const ACC_TYPE l_inv = 1.0f / l_i;
#pragma unroll
FA_UNROLL
for (int i = 0; i < DV_VEC; ++i) {
o_row[i] = CONVERT_O_DATA4(o_acc[i] * l_inv);
}
} else {
#pragma unroll
FA_UNROLL
for (int i = 0; i < DV_VEC; ++i) {
o_row[i] = (O_DATA_TYPE4)(0.0f);
}
}
}
#endif
}
__kernel void flash_attn_f32_f16_q1(
@@ -258,13 +608,16 @@ __kernel void flash_attn_f32_f16_q1(
mask_base = (const global char*)mask_void + mask_offset + mask_batch_idx * mask_nb3 + mask_head_idx * mask_nb2;
}
ACC_TYPE4 q_priv[DK_VEC];
// Q is uniform across WG threads (n_q=1). Share via local memory to
// avoid per-thread q_priv[DK_VEC] dynamic-indexed private array that
// spills to DDR on Adreno.
__local ACC_TYPE4 q_shared[DK_VEC];
const ulong q_row_offset = batch_idx * q_nb3 + head_idx * q_nb2;
const global Q_DATA_TYPE4* q_ptr = (const global Q_DATA_TYPE4*)(q_base + q_row_offset);
#pragma unroll
for (int i = 0; i < DK_VEC; ++i) {
q_priv[i] = CONVERT_Q_ACC4(q_ptr[i]);
for (int i = tid; i < DK_VEC; i += Q1_WG_SIZE) {
q_shared[i] = CONVERT_Q_ACC4(q_ptr[i]);
}
barrier(CLK_LOCAL_MEM_FENCE);
float slope = get_alibi_slope(max_bias, head_idx, n_head_log2, m0, m1);
@@ -273,14 +626,14 @@ __kernel void flash_attn_f32_f16_q1(
sinks_ptr = (const global ACC_TYPE*)((const global char*)sinks_void + sinks_offset);
}
ACC_TYPE m_i = (sinks_ptr != NULL) ? sinks_ptr[head_idx] : -INFINITY;
ACC_TYPE m_i = (sinks_ptr != NULL) ? sinks_ptr[head_idx] : FA_M_INIT;
for (int k_idx = tid; k_idx < n_kv; k_idx += Q1_WG_SIZE) {
const ulong k_row_offset = batch_idx * k_nb3 + head_kv_idx * k_nb2 + k_idx * k_nb1;
const global KV_DATA_TYPE4* k_ptr = (const global KV_DATA_TYPE4*)(k_base + k_row_offset);
ACC_TYPE4 dot_acc = (ACC_TYPE4)(0.0f);
#pragma unroll
FA_UNROLL
for (int k = 0; k < DK_VEC; k++) {
dot_acc = mad(q_priv[k], CONVERT_KV_ACC4(k_ptr[k]), dot_acc);
dot_acc = mad(q_shared[k], CONVERT_KV_ACC4(k_ptr[k]), dot_acc);
}
ACC_TYPE score = (dot_acc.s0 + dot_acc.s1 + dot_acc.s2 + dot_acc.s3) * scale;
if (mask_base != NULL) {
@@ -296,7 +649,7 @@ __kernel void flash_attn_f32_f16_q1(
__local ACC_TYPE local_m[Q1_WG_SIZE];
local_m[tid] = m_i;
barrier(CLK_LOCAL_MEM_FENCE);
#pragma unroll
FA_UNROLL
for (int s = Q1_WG_SIZE / 2; s > 0; s >>= 1) {
if (tid < s) local_m[tid] = max(local_m[tid], local_m[tid + s]);
barrier(CLK_LOCAL_MEM_FENCE);
@@ -304,7 +657,7 @@ __kernel void flash_attn_f32_f16_q1(
const ACC_TYPE m_final = local_m[0];
ACC_TYPE4 o_acc[DV_VEC];
#pragma unroll
FA_UNROLL
for (int i = 0; i < DV_VEC; ++i) o_acc[i] = (ACC_TYPE4)(0.0f);
ACC_TYPE l_i = 0.0f;
@@ -314,9 +667,9 @@ __kernel void flash_attn_f32_f16_q1(
const global KV_DATA_TYPE4* k_ptr = (const global KV_DATA_TYPE4*)(k_base + k_row_offset);
const global KV_DATA_TYPE4* v_ptr = (const global KV_DATA_TYPE4*)(v_base + v_row_offset);
ACC_TYPE4 dot_acc = (ACC_TYPE4)(0.0f);
#pragma unroll
FA_UNROLL
for (int k = 0; k < DK_VEC; k++) {
dot_acc = mad(q_priv[k], CONVERT_KV_ACC4(k_ptr[k]), dot_acc);
dot_acc = mad(q_shared[k], CONVERT_KV_ACC4(k_ptr[k]), dot_acc);
}
ACC_TYPE score = (dot_acc.s0 + dot_acc.s1 + dot_acc.s2 + dot_acc.s3) * scale;
if (mask_base != NULL) {
@@ -328,7 +681,7 @@ __kernel void flash_attn_f32_f16_q1(
}
const ACC_TYPE p = exp(score - m_final);
l_i += p;
#pragma unroll
FA_UNROLL
for (int i = 0; i < DV_VEC; i++) {
o_acc[i] = mad(p, CONVERT_KV_ACC4(v_ptr[i]), o_acc[i]);
}
@@ -338,7 +691,7 @@ __kernel void flash_attn_f32_f16_q1(
__local ACC_TYPE4 local_o_comp[Q1_WG_SIZE];
local_l[tid] = l_i;
barrier(CLK_LOCAL_MEM_FENCE);
#pragma unroll
FA_UNROLL
for (int s = Q1_WG_SIZE / 2; s > 0; s >>= 1) {
if (tid < s) local_l[tid] += local_l[tid + s];
barrier(CLK_LOCAL_MEM_FENCE);
@@ -357,7 +710,7 @@ __kernel void flash_attn_f32_f16_q1(
for (int i = 0; i < DV_VEC; i++) {
local_o_comp[tid] = o_acc[i];
barrier(CLK_LOCAL_MEM_FENCE);
#pragma unroll
FA_UNROLL
for (int s = Q1_WG_SIZE / 2; s > 0; s >>= 1) {
if (tid < s) local_o_comp[tid] += local_o_comp[tid + s];
barrier(CLK_LOCAL_MEM_FENCE);
@@ -367,7 +720,257 @@ __kernel void flash_attn_f32_f16_q1(
}
}
} else if (tid == 0) {
#pragma unroll
FA_UNROLL
for (int i = 0; i < DV_VEC; ++i) o_row[i] = (O_DATA_TYPE4)(0.0f);
}
}
// Flash-decoding split pass. gid(2) = q_idx * n_splits + split_idx.
// Partial record per split: [m, l, O[DV]]. Merge kernel applies sink + norm.
#define FA_PARTIAL_FLOATS (2 + DV)
__kernel void flash_attn_f32_f16_q1_split(
const global void * q_void, ulong q_offset,
const global void * k_void, ulong k_offset,
const global void * v_void, ulong v_offset,
const float scale,
const int n_q,
const int n_kv,
const int n_head,
const ulong q_nb1, const ulong q_nb2, const ulong q_nb3,
const ulong k_nb1, const ulong k_nb2, const ulong k_nb3,
const ulong v_nb1, const ulong v_nb2, const ulong v_nb3,
const float max_bias,
const float m0,
const float m1,
const int n_head_log2,
const float logit_softcap,
const int n_head_kv,
const global void * mask_void,
const ulong mask_offset,
const ulong mask_nb1,
const ulong mask_nb2,
const ulong mask_nb3,
const int mask_ne2,
const int mask_ne3,
global float * partial_void,
const int n_splits,
const int kv_per_split
) {
const int tid = get_local_id(0);
const int head_batch_idx = get_global_id(1);
const int split_q_idx = get_global_id(2);
const int split_idx = split_q_idx % n_splits;
const int q_idx = split_q_idx / n_splits;
const int batch_idx = head_batch_idx / n_head;
const int head_idx = head_batch_idx % n_head;
const int gqa_ratio = n_head / n_head_kv;
const int head_kv_idx = head_idx / gqa_ratio;
const int kv_start = split_idx * kv_per_split;
const int kv_end = min(kv_start + kv_per_split, n_kv);
const ulong record_stride = (ulong) FA_PARTIAL_FLOATS;
const ulong record_idx = ((((ulong) batch_idx * n_head + head_idx) * n_q + q_idx)
* n_splits + split_idx);
global float * rec = partial_void + record_idx * record_stride;
global float4 * rec_o = (global float4 *) (rec + 2);
if (kv_start >= kv_end) {
// Empty split: leave sentinel partial for merge.
if (tid == 0) {
rec[0] = FA_M_INIT;
rec[1] = 0.0f;
}
return;
}
const global char * q_base = (const global char *) q_void + q_offset;
const global char * k_base = (const global char *) k_void + k_offset;
const global char * v_base = (const global char *) v_void + v_offset;
const global char * mask_base = NULL;
if (mask_void != NULL) {
const int mask_head_idx = head_idx % mask_ne2;
const int mask_batch_idx = batch_idx % mask_ne3;
mask_base = (const global char *) mask_void + mask_offset +
mask_batch_idx * mask_nb3 + mask_head_idx * mask_nb2 +
(ulong) q_idx * mask_nb1;
}
// Share Q via local memory (n_q=1 per split -> uniform across WG).
__local ACC_TYPE4 q_shared[DK_VEC];
const ulong q_row_offset = batch_idx * q_nb3 + head_idx * q_nb2 + (ulong) q_idx * q_nb1;
const global Q_DATA_TYPE4 * q_ptr = (const global Q_DATA_TYPE4 *) (q_base + q_row_offset);
for (int i = tid; i < DK_VEC; i += Q1_WG_SIZE) {
q_shared[i] = CONVERT_Q_ACC4(q_ptr[i]);
}
barrier(CLK_LOCAL_MEM_FENCE);
const float slope = get_alibi_slope(max_bias, head_idx, n_head_log2, m0, m1);
// Pass 1a split-local max.
ACC_TYPE m_i = FA_M_INIT;
for (int k_idx = kv_start + tid; k_idx < kv_end; k_idx += Q1_WG_SIZE) {
const ulong k_row_offset = batch_idx * k_nb3 + head_kv_idx * k_nb2 + k_idx * k_nb1;
const global KV_DATA_TYPE4 * k_ptr = (const global KV_DATA_TYPE4 *) (k_base + k_row_offset);
ACC_TYPE4 dot_acc = (ACC_TYPE4)(0.0f);
#pragma unroll
for (int k = 0; k < DK_VEC; ++k) {
dot_acc = mad(q_shared[k], CONVERT_KV_ACC4(k_ptr[k]), dot_acc);
}
ACC_TYPE score = (dot_acc.s0 + dot_acc.s1 + dot_acc.s2 + dot_acc.s3) * scale;
if (mask_base != NULL) {
const global MASK_DATA_TYPE * mask_ptr = (const global MASK_DATA_TYPE *) (mask_base);
score += slope * (ACC_TYPE) mask_ptr[k_idx];
}
if (logit_softcap > 0.0f) {
score = logit_softcap * tanh(score / logit_softcap);
}
m_i = max(m_i, score);
}
__local ACC_TYPE local_m[Q1_WG_SIZE];
local_m[tid] = m_i;
barrier(CLK_LOCAL_MEM_FENCE);
#pragma unroll
for (int s = Q1_WG_SIZE / 2; s > 0; s >>= 1) {
if (tid < s) local_m[tid] = max(local_m[tid], local_m[tid + s]);
barrier(CLK_LOCAL_MEM_FENCE);
}
const ACC_TYPE m_c = local_m[0];
// Pass 1b softmax-weighted V accumulate.
ACC_TYPE4 o_acc[DV_VEC];
#pragma unroll
for (int i = 0; i < DV_VEC; ++i) o_acc[i] = (ACC_TYPE4)(0.0f);
ACC_TYPE l_i = 0.0f;
for (int k_idx = kv_start + tid; k_idx < kv_end; k_idx += Q1_WG_SIZE) {
const ulong k_row_offset = batch_idx * k_nb3 + head_kv_idx * k_nb2 + k_idx * k_nb1;
const ulong v_row_offset = batch_idx * v_nb3 + head_kv_idx * v_nb2 + k_idx * v_nb1;
const global KV_DATA_TYPE4 * k_ptr = (const global KV_DATA_TYPE4 *) (k_base + k_row_offset);
const global KV_DATA_TYPE4 * v_ptr = (const global KV_DATA_TYPE4 *) (v_base + v_row_offset);
ACC_TYPE4 dot_acc = (ACC_TYPE4)(0.0f);
#pragma unroll
for (int k = 0; k < DK_VEC; ++k) {
dot_acc = mad(q_shared[k], CONVERT_KV_ACC4(k_ptr[k]), dot_acc);
}
ACC_TYPE score = (dot_acc.s0 + dot_acc.s1 + dot_acc.s2 + dot_acc.s3) * scale;
if (mask_base != NULL) {
const global MASK_DATA_TYPE * mask_ptr = (const global MASK_DATA_TYPE *) (mask_base);
score += slope * (ACC_TYPE) mask_ptr[k_idx];
}
if (logit_softcap > 0.0f) {
score = logit_softcap * tanh(score / logit_softcap);
}
const ACC_TYPE p = exp(score - m_c);
l_i += p;
#pragma unroll
for (int i = 0; i < DV_VEC; ++i) {
o_acc[i] = mad(p, CONVERT_KV_ACC4(v_ptr[i]), o_acc[i]);
}
}
__local ACC_TYPE local_l[Q1_WG_SIZE];
__local ACC_TYPE4 local_o[Q1_WG_SIZE];
local_l[tid] = l_i;
barrier(CLK_LOCAL_MEM_FENCE);
#pragma unroll
for (int s = Q1_WG_SIZE / 2; s > 0; s >>= 1) {
if (tid < s) local_l[tid] += local_l[tid + s];
barrier(CLK_LOCAL_MEM_FENCE);
}
const ACC_TYPE l_c = local_l[0];
if (tid == 0) {
rec[0] = (float) m_c;
rec[1] = (float) l_c;
}
for (int i = 0; i < DV_VEC; ++i) {
local_o[tid] = o_acc[i];
barrier(CLK_LOCAL_MEM_FENCE);
#pragma unroll
for (int s = Q1_WG_SIZE / 2; s > 0; s >>= 1) {
if (tid < s) local_o[tid] += local_o[tid + s];
barrier(CLK_LOCAL_MEM_FENCE);
}
if (tid == 0) {
rec_o[i] = local_o[0];
}
}
}
// FD Pass 2: merge per-split partials into final O. Empty splits drop via exp(-INF)=0.
__kernel void flash_attn_f32_merge(
const global float * partial_void,
global void * o_void,
const ulong o_offset,
const int n_head,
const int n_splits,
const ulong o_nb1, const ulong o_nb2, const ulong o_nb3,
const global void * sinks_void,
const ulong sinks_offset,
const int n_q
) {
const int lane = get_local_id(0); // 0..DV_VEC-1
const int head_batch_idx = get_global_id(1);
const int q_idx = get_global_id(2);
const int batch_idx = head_batch_idx / n_head;
const int head_idx = head_batch_idx % n_head;
const ulong record_stride = (ulong) FA_PARTIAL_FLOATS;
const ulong record_idx_0 = (((ulong) batch_idx * n_head + head_idx) * n_q + q_idx) * n_splits;
const global float * rec0 = partial_void + record_idx_0 * record_stride;
__local ACC_TYPE m_final_shared;
__local ACC_TYPE l_final_shared;
if (lane == 0) {
ACC_TYPE m = FA_M_INIT;
for (int c = 0; c < n_splits; ++c) {
const ACC_TYPE m_c = rec0[c * record_stride + 0];
m = max(m, m_c);
}
ACC_TYPE m_sink = 0.0f;
bool has_sink = false;
if (sinks_void != NULL) {
const global ACC_TYPE * sinks_ptr =
(const global ACC_TYPE *) ((const global char *) sinks_void + sinks_offset);
m_sink = sinks_ptr[head_idx];
has_sink = true;
m = max(m, m_sink);
}
ACC_TYPE l = 0.0f;
for (int c = 0; c < n_splits; ++c) {
const ACC_TYPE m_c = rec0[c * record_stride + 0];
const ACC_TYPE l_c = rec0[c * record_stride + 1];
if (m_c > FA_M_INIT) {
l += l_c * exp(m_c - m);
}
}
if (has_sink) {
l += exp(m_sink - m);
}
m_final_shared = m;
l_final_shared = l;
}
barrier(CLK_LOCAL_MEM_FENCE);
const ACC_TYPE m_final = m_final_shared;
const ACC_TYPE l_final = l_final_shared;
const ACC_TYPE l_inv = (l_final > 0.0f) ? (1.0f / l_final) : 0.0f;
ACC_TYPE4 o = (ACC_TYPE4)(0.0f);
for (int c = 0; c < n_splits; ++c) {
const global float * rec_c = rec0 + c * record_stride;
const ACC_TYPE m_c = rec_c[0];
if (m_c <= FA_M_INIT) continue;
const global float4 * rec_oc = (const global float4 *) (rec_c + 2);
const ACC_TYPE scale_c = exp(m_c - m_final);
o = mad((ACC_TYPE4)(scale_c), rec_oc[lane], o);
}
o = o * l_inv;
const ulong o_row_offset = (ulong) batch_idx * o_nb3 + (ulong) q_idx * o_nb2 + (ulong) head_idx * o_nb1;
global O_DATA_TYPE4 * o_row = (global O_DATA_TYPE4 *) ((global char *) o_void + o_offset + o_row_offset);
o_row[lane] = CONVERT_O_DATA4(o);
}
File diff suppressed because it is too large Load Diff
File diff suppressed because it is too large Load Diff
@@ -0,0 +1,156 @@
#pragma OPENCL EXTENSION cl_khr_fp16 : enable
__kernel void flash_attn_kv_pad_f16(
const global void * k_void, ulong k_offset,
const global void * v_void, ulong v_offset,
global void * k_pad_void,
global void * v_pad_void,
const int n_kv,
const int n_head_kv,
const int n_batch,
const ulong k_nb1, const ulong k_nb2, const ulong k_nb3,
const ulong v_nb1, const ulong v_nb2, const ulong v_nb3
) {
const int row_idx = get_global_id(0);
const int head_kv_idx = get_global_id(1);
const int batch_idx = get_global_id(2);
if (row_idx >= BLOCK_N || head_kv_idx >= n_head_kv || batch_idx >= n_batch) {
return;
}
const int tail_start = n_kv - (n_kv % BLOCK_N);
const int src_row_idx = tail_start + row_idx;
const global char * k_src = (const global char *) k_void + k_offset;
const global char * v_src = (const global char *) v_void + v_offset;
global char * k_pad = (global char *) k_pad_void;
global char * v_pad = (global char *) v_pad_void;
const ulong k_dst_offset = ((ulong) batch_idx * (ulong) n_head_kv + (ulong) head_kv_idx) * ((ulong) BLOCK_N * k_nb1) + (ulong) row_idx * k_nb1;
const ulong v_dst_offset = ((ulong) batch_idx * (ulong) n_head_kv + (ulong) head_kv_idx) * ((ulong) BLOCK_N * v_nb1) + (ulong) row_idx * v_nb1;
if (src_row_idx < n_kv) {
const ulong k_src_offset = (ulong) batch_idx * k_nb3 + (ulong) head_kv_idx * k_nb2 + (ulong) src_row_idx * k_nb1;
const ulong v_src_offset = (ulong) batch_idx * v_nb3 + (ulong) head_kv_idx * v_nb2 + (ulong) src_row_idx * v_nb1;
for (ulong i = 0; i < k_nb1; ++i) {
k_pad[k_dst_offset + i] = k_src[k_src_offset + i];
}
for (ulong i = 0; i < v_nb1; ++i) {
v_pad[v_dst_offset + i] = v_src[v_src_offset + i];
}
} else {
for (ulong i = 0; i < k_nb1; ++i) {
k_pad[k_dst_offset + i] = 0;
}
for (ulong i = 0; i < v_nb1; ++i) {
v_pad[v_dst_offset + i] = 0;
}
}
}
__kernel void flash_attn_mask_pad_f16(
const global void * mask_void, ulong mask_offset,
global void * mask_pad_void,
const int n_q,
const int n_kv,
const ulong mask_nb1,
const ulong mask_nb2,
const ulong mask_nb3,
const int mask_ne2,
const int mask_ne3
) {
const int col_idx = get_global_id(0);
const int q_row = get_global_id(1);
const int mask_slice = get_global_id(2);
if (col_idx >= BLOCK_N || q_row >= n_q || mask_slice >= mask_ne2 * mask_ne3) {
return;
}
const int tail_start = n_kv - (n_kv % BLOCK_N);
const int src_col_idx = tail_start + col_idx;
const int mask_head_idx = mask_slice % mask_ne2;
const int mask_batch_idx = mask_slice / mask_ne2;
const global char * mask_src_base = (const global char *) mask_void + mask_offset +
(ulong) mask_batch_idx * mask_nb3 +
(ulong) mask_head_idx * mask_nb2 +
(ulong) q_row * mask_nb1;
const global half * mask_src = (const global half *) mask_src_base;
global half * mask_pad = (global half *) mask_pad_void;
const ulong dst_idx =
(((ulong) mask_batch_idx * (ulong) mask_ne2 + (ulong) mask_head_idx) * (ulong) n_q + (ulong) q_row) * (ulong) BLOCK_N +
(ulong) col_idx;
mask_pad[dst_idx] = src_col_idx < n_kv ? mask_src[src_col_idx] : (half) (-INFINITY);
}
// Per-KV-tile mask class. 0=all -inf (skip tile), 1=mixed (apply mask),
// 2=all zero, no -inf (skip mask lookup). Causal diagonal tiles are class 1.
__kernel void flash_attn_blk_f16(
const global void * mask_void, ulong mask_offset,
global char * blk,
const int n_q,
const int n_kv,
const ulong mask_nb1,
const ulong mask_nb2,
const ulong mask_nb3,
const int mask_ne2,
const int mask_ne3
) {
const int kv_block_idx = get_global_id(0);
const int q_block_idx = get_global_id(1);
const int mask_slice = get_global_id(2);
const int n_q_blocks = (n_q + BLOCK_M - 1) / BLOCK_M;
const int n_kv_blocks = (n_kv + BLOCK_N - 1) / BLOCK_N;
if (kv_block_idx >= n_kv_blocks || q_block_idx >= n_q_blocks || mask_slice >= mask_ne2 * mask_ne3) {
return;
}
const int mask_head_idx = mask_slice % mask_ne2;
const int mask_batch_idx = mask_slice / mask_ne2;
const int q_start = q_block_idx * BLOCK_M;
const int k_start = kv_block_idx * BLOCK_N;
const int q_count = min(BLOCK_M, n_q - q_start);
const int k_count = min(BLOCK_N, n_kv - k_start);
const half neg_max_half = (half) (-65504.0f);
char has_unmasked = 0;
char has_masked = 0;
char has_nonzero = 0;
const global char * mask_base = (const global char *) mask_void + mask_offset +
(ulong) mask_batch_idx * mask_nb3 +
(ulong) mask_head_idx * mask_nb2;
for (int qi = 0; qi < q_count; ++qi) {
const global half * mask_row = (const global half *) (mask_base + (ulong) (q_start + qi) * mask_nb1) + k_start;
for (int ki = 0; ki < k_count; ++ki) {
const half v = mask_row[ki];
if (v <= neg_max_half) {
has_masked = 1;
} else {
has_unmasked = 1;
if (v != (half) 0.0f) {
has_nonzero = 1;
}
}
}
if (has_masked && has_unmasked) break; // mixed tile — short-circuit.
}
char res;
if (has_unmasked == 0) {
res = 0;
} else if (has_masked || has_nonzero) {
res = 1;
} else {
res = 2;
}
blk[((ulong) mask_slice * (ulong) n_q_blocks + (ulong) q_block_idx) * (ulong) n_kv_blocks + (ulong) kv_block_idx] = res;
}
@@ -0,0 +1,94 @@
#pragma OPENCL EXTENSION cl_khr_fp16 : enable
#pragma OPENCL EXTENSION cl_qcom_reqd_sub_group_size : enable
#ifdef cl_qcom_reqd_sub_group_size
#pragma OPENCL EXTENSION cl_qcom_reqd_sub_group_size : enable
#define ADRENO_GPU 1
#define REQD_SUBGROUP_SIZE_128 __attribute__((qcom_reqd_sub_group_size("full")))
#endif
// each work-item computes a 4 (rows of A / m) x 8 (cols of B / n) output tile.
#ifdef ADRENO_GPU
REQD_SUBGROUP_SIZE_128
#endif
kernel void kernel_gemm_noshuffle_q1_0_f32(
global const uint * src0_q,
global const half * src0_d,
read_only image1d_buffer_t src1,
global float * dst,
int k,
int m,
int n,
int n_no_padding,
ulong offsetd
) {
int n_4 = n >> 2;
int gy = get_global_id(0);
int gx = get_global_id(1);
int gx_2 = gx << 2;
dst = (global float *)((global char*)dst + offsetd);
half8 c0 = 0, c1 = 0, c2 = 0, c3 = 0;
half8 B;
global const uint* wptr = src0_q + gx_2;
global const half* sptr = src0_d + gx_2;
// 32 weights per uint32, 128 weights (one block / one scale) per 4 uint32.
for (int i = 0; i < k; i += 32) {
uint4 pack4 = vload4(0, wptr + (i / 32) * m); // 4 rows, 32 K-values each
half4 scale = vload4(0, sptr + (i / 128) * m); // 4 rows, one scale per 128
for (int j = 0; j < 32; ++j) {
B.s0123 = read_imageh(src1, gy * 2 + (i + j) * n_4);
B.s4567 = read_imageh(src1, gy * 2 + (i + j) * n_4 + 1);
// sign bit -> +-1 (half arithmetic avoids unsigned underflow)
half4 wj = (half4)(
2.0h * (half)((pack4.s0 >> j) & 1u) - 1.0h,
2.0h * (half)((pack4.s1 >> j) & 1u) - 1.0h,
2.0h * (half)((pack4.s2 >> j) & 1u) - 1.0h,
2.0h * (half)((pack4.s3 >> j) & 1u) - 1.0h) * scale;
c0 += B * wj.s0;
c1 += B * wj.s1;
c2 += B * wj.s2;
c3 += B * wj.s3;
}
}
int idx = (gy << 3) * m + (gx << 2);
if(idx+3 < m*n_no_padding){
vstore4((float4)(c0.s0, c1.s0, c2.s0, c3.s0), 0, dst + idx);
idx += m;
}
if(idx+3 < m*n_no_padding){
vstore4((float4)(c0.s1, c1.s1, c2.s1, c3.s1), 0, dst + idx);
idx += m;
}
if(idx+3 < m*n_no_padding){
vstore4((float4)(c0.s2, c1.s2, c2.s2, c3.s2), 0, dst + idx);
idx += m;
}
if(idx+3 < m*n_no_padding){
vstore4((float4)(c0.s3, c1.s3, c2.s3, c3.s3), 0, dst + idx);
idx += m;
}
if(idx+3 < m*n_no_padding){
vstore4((float4)(c0.s4, c1.s4, c2.s4, c3.s4), 0, dst + idx);
idx += m;
}
if(idx+3 < m*n_no_padding){
vstore4((float4)(c0.s5, c1.s5, c2.s5, c3.s5), 0, dst + idx);
idx += m;
}
if(idx+3 < m*n_no_padding){
vstore4((float4)(c0.s6, c1.s6, c2.s6, c3.s6), 0, dst + idx);
idx += m;
}
if(idx+3 < m*n_no_padding){
vstore4((float4)(c0.s7, c1.s7, c2.s7, c3.s7), 0, dst + idx);
}
}
@@ -0,0 +1,121 @@
#pragma OPENCL EXTENSION cl_khr_fp16 : enable
#pragma OPENCL EXTENSION cl_khr_subgroups : enable
#ifdef cl_qcom_reqd_sub_group_size
#pragma OPENCL EXTENSION cl_qcom_reqd_sub_group_size : enable
#define ADRENO_GPU 1
#define REQD_SUBGROUP_SIZE_64 __attribute__((qcom_reqd_sub_group_size("half")))
#endif
#define QK1_0 128
#define N_SIMDGROUP 4
#define dequantizeBlockAccum_q1(total, bits, scale, regB, lb) \
total += (2.0f*(float)((bits >> 0) & 1u) - 1.0f) * scale * sub_group_broadcast(regB.s0, lb+0); \
total += (2.0f*(float)((bits >> 1) & 1u) - 1.0f) * scale * sub_group_broadcast(regB.s1, lb+0); \
total += (2.0f*(float)((bits >> 2) & 1u) - 1.0f) * scale * sub_group_broadcast(regB.s2, lb+0); \
total += (2.0f*(float)((bits >> 3) & 1u) - 1.0f) * scale * sub_group_broadcast(regB.s3, lb+0); \
total += (2.0f*(float)((bits >> 4) & 1u) - 1.0f) * scale * sub_group_broadcast(regB.s4, lb+0); \
total += (2.0f*(float)((bits >> 5) & 1u) - 1.0f) * scale * sub_group_broadcast(regB.s5, lb+0); \
total += (2.0f*(float)((bits >> 6) & 1u) - 1.0f) * scale * sub_group_broadcast(regB.s6, lb+0); \
total += (2.0f*(float)((bits >> 7) & 1u) - 1.0f) * scale * sub_group_broadcast(regB.s7, lb+0); \
total += (2.0f*(float)((bits >> 8) & 1u) - 1.0f) * scale * sub_group_broadcast(regB.s0, lb+1); \
total += (2.0f*(float)((bits >> 9) & 1u) - 1.0f) * scale * sub_group_broadcast(regB.s1, lb+1); \
total += (2.0f*(float)((bits >> 10) & 1u) - 1.0f) * scale * sub_group_broadcast(regB.s2, lb+1); \
total += (2.0f*(float)((bits >> 11) & 1u) - 1.0f) * scale * sub_group_broadcast(regB.s3, lb+1); \
total += (2.0f*(float)((bits >> 12) & 1u) - 1.0f) * scale * sub_group_broadcast(regB.s4, lb+1); \
total += (2.0f*(float)((bits >> 13) & 1u) - 1.0f) * scale * sub_group_broadcast(regB.s5, lb+1); \
total += (2.0f*(float)((bits >> 14) & 1u) - 1.0f) * scale * sub_group_broadcast(regB.s6, lb+1); \
total += (2.0f*(float)((bits >> 15) & 1u) - 1.0f) * scale * sub_group_broadcast(regB.s7, lb+1); \
total += (2.0f*(float)((bits >> 16) & 1u) - 1.0f) * scale * sub_group_broadcast(regB.s0, lb+2); \
total += (2.0f*(float)((bits >> 17) & 1u) - 1.0f) * scale * sub_group_broadcast(regB.s1, lb+2); \
total += (2.0f*(float)((bits >> 18) & 1u) - 1.0f) * scale * sub_group_broadcast(regB.s2, lb+2); \
total += (2.0f*(float)((bits >> 19) & 1u) - 1.0f) * scale * sub_group_broadcast(regB.s3, lb+2); \
total += (2.0f*(float)((bits >> 20) & 1u) - 1.0f) * scale * sub_group_broadcast(regB.s4, lb+2); \
total += (2.0f*(float)((bits >> 21) & 1u) - 1.0f) * scale * sub_group_broadcast(regB.s5, lb+2); \
total += (2.0f*(float)((bits >> 22) & 1u) - 1.0f) * scale * sub_group_broadcast(regB.s6, lb+2); \
total += (2.0f*(float)((bits >> 23) & 1u) - 1.0f) * scale * sub_group_broadcast(regB.s7, lb+2); \
total += (2.0f*(float)((bits >> 24) & 1u) - 1.0f) * scale * sub_group_broadcast(regB.s0, lb+3); \
total += (2.0f*(float)((bits >> 25) & 1u) - 1.0f) * scale * sub_group_broadcast(regB.s1, lb+3); \
total += (2.0f*(float)((bits >> 26) & 1u) - 1.0f) * scale * sub_group_broadcast(regB.s2, lb+3); \
total += (2.0f*(float)((bits >> 27) & 1u) - 1.0f) * scale * sub_group_broadcast(regB.s3, lb+3); \
total += (2.0f*(float)((bits >> 28) & 1u) - 1.0f) * scale * sub_group_broadcast(regB.s4, lb+3); \
total += (2.0f*(float)((bits >> 29) & 1u) - 1.0f) * scale * sub_group_broadcast(regB.s5, lb+3); \
total += (2.0f*(float)((bits >> 30) & 1u) - 1.0f) * scale * sub_group_broadcast(regB.s6, lb+3); \
total += (2.0f*(float)((bits >> 31) & 1u) - 1.0f) * scale * sub_group_broadcast(regB.s7, lb+3);
#ifdef ADRENO_GPU
REQD_SUBGROUP_SIZE_64
#endif
__kernel void kernel_gemv_noshuffle_q1_0_f32(
read_only image1d_buffer_t src0_q,
global half * src0_d,
read_only image1d_buffer_t src1,
ulong offset1,
global float * dst,
ulong offsetd,
int ne00,
int ne01,
int ne02,
int ne10,
int ne12,
int ne0,
int ne1,
int r2,
int r3)
{
uint groupId = get_local_id(1);
uint gid = get_global_id(0);
ushort slid = get_sub_group_local_id();
uint K = ne00;
uint M = ne01;
uint LINE_STRIDE_A = M;
uint BLOCK_STRIDE_A = 4 * M;
uint4 regA;
half regS;
float8 regB;
float totalSum = 0.0f;
#pragma unroll 1
for (uint kb = groupId; kb < (K / QK1_0); kb += N_SIMDGROUP) {
regS = src0_d[gid + kb * LINE_STRIDE_A]; // each fiber loads its row's scale
// first 16 fibers load 8 B values each -> 128 activations for this block
if (slid < 16) {
regB.s0123 = read_imagef(src1, (slid * 2 + kb * 32));
regB.s4567 = read_imagef(src1, (1 + slid * 2 + kb * 32));
}
// load this row's 4 uint32 (128 sign bits)
regA.s0 = read_imageui(src0_q, (gid + kb * BLOCK_STRIDE_A + LINE_STRIDE_A * 0)).x;
regA.s1 = read_imageui(src0_q, (gid + kb * BLOCK_STRIDE_A + LINE_STRIDE_A * 1)).x;
regA.s2 = read_imageui(src0_q, (gid + kb * BLOCK_STRIDE_A + LINE_STRIDE_A * 2)).x;
regA.s3 = read_imageui(src0_q, (gid + kb * BLOCK_STRIDE_A + LINE_STRIDE_A * 3)).x;
float scale = (float)regS;
dequantizeBlockAccum_q1(totalSum, regA.s0, scale, regB, 0);
dequantizeBlockAccum_q1(totalSum, regA.s1, scale, regB, 4);
dequantizeBlockAccum_q1(totalSum, regA.s2, scale, regB, 8);
dequantizeBlockAccum_q1(totalSum, regA.s3, scale, regB, 12);
}
// reduction in local memory, assumes #wave = N_SIMDGROUP = 4
local float reduceLM[SIMDGROUP_WIDTH * 3];
if (groupId == 1) reduceLM[SIMDGROUP_WIDTH * 0 + slid] = totalSum;
if (groupId == 2) reduceLM[SIMDGROUP_WIDTH * 1 + slid] = totalSum;
if (groupId == 3) reduceLM[SIMDGROUP_WIDTH * 2 + slid] = totalSum;
barrier(CLK_LOCAL_MEM_FENCE);
if (groupId == 0) totalSum += reduceLM[SIMDGROUP_WIDTH * 0 + slid];
if (groupId == 0) totalSum += reduceLM[SIMDGROUP_WIDTH * 1 + slid];
if (groupId == 0) totalSum += reduceLM[SIMDGROUP_WIDTH * 2 + slid];
if (groupId == 0) {
dst = (global float*)((global char*)dst + offsetd);
dst[gid] = totalSum;
}
}
@@ -0,0 +1,156 @@
#pragma OPENCL EXTENSION cl_khr_fp16 : enable
// LOAD_VEC_A is 8 because one q1_0 quant byte expands to 8 weights along K.
#define LOAD_VEC_A 8
#define LOAD_VEC_B 4
#define BM 64
#define BN 64
#define BK 32
#define TM 4
#define TN 8
kernel void kernel_mul_mm_q1_0_f32_l4_lm(
global uchar * src0_q,
global half * src0_d,
global float4 * src1,
ulong offset1,
global float * dst,
ulong offsetd,
int ne00,
int ne01,
int ne02,
int ne11,
int ne12,
int stride_a,
int stride_b,
int stride_d,
int batch_stride_a,
int batch_stride_b,
int batch_stride_d,
int r2,
int r3
) {
src1 = (global float4*)((global char*)src1 + offset1);
dst = (global float *)((global char*)dst + offsetd);
local float buf_a[BM * BK];
local float buf_b[BN * BK];
const int batch_idx = get_global_id(2);
const int i13 = batch_idx / ne12;
const int i12 = batch_idx % ne12;
const int i03 = i13 / r3;
const int i02 = i12 / r2;
const int batch_idx_a = i03 * ne02 + i02;
const int ir = get_group_id(0);
const int ic = get_group_id(1);
const int tid = get_local_id(0);
const int th_r = tid % (BM / TM);
const int th_c = tid / (BM / TM);
const int loadr_a = get_local_id(0) % (BK / LOAD_VEC_A);
const int loadc_a = get_local_id(0) / (BK / LOAD_VEC_A);
const int loadr_b = get_local_id(0) % (BK / LOAD_VEC_B);
const int loadc_b = get_local_id(0) / (BK / LOAD_VEC_B);
const int loadstride_a = get_local_size(0) * LOAD_VEC_A / BK;
const int loadstride_b = get_local_size(0) * LOAD_VEC_B / BK;
int pos_a = (batch_idx_a * batch_stride_a + ir * BM * stride_a) / LOAD_VEC_A;
int pos_b = (batch_idx * batch_stride_b + ic * BN * stride_b) / LOAD_VEC_B;
float sums[TM * TN];
float cache_a[TM];
float cache_b[TN];
for (int i = 0; i < TM * TN; i++) {
sums[i] = 0.0f;
}
for (int block = 0; block < ne00; block += BK) {
for (int l = 0; l < BM; l += loadstride_a) {
if (ir*BM + loadc_a + l < ne01) {
int idx = pos_a + (loadc_a + l) * stride_a / LOAD_VEC_A + loadr_a;
int ib = idx / 16; // 16 quant bytes per q1_0 block
float d = (float)src0_d[ib];
uint bits = src0_q[idx];
// use float to avoid unsigned underflow of (2*0 - 1).
buf_a[(loadr_a * LOAD_VEC_A + 0) * BM + loadc_a + l] = d * (2.0f*(float)((bits >> 0) & 1) - 1.0f);
buf_a[(loadr_a * LOAD_VEC_A + 1) * BM + loadc_a + l] = d * (2.0f*(float)((bits >> 1) & 1) - 1.0f);
buf_a[(loadr_a * LOAD_VEC_A + 2) * BM + loadc_a + l] = d * (2.0f*(float)((bits >> 2) & 1) - 1.0f);
buf_a[(loadr_a * LOAD_VEC_A + 3) * BM + loadc_a + l] = d * (2.0f*(float)((bits >> 3) & 1) - 1.0f);
buf_a[(loadr_a * LOAD_VEC_A + 4) * BM + loadc_a + l] = d * (2.0f*(float)((bits >> 4) & 1) - 1.0f);
buf_a[(loadr_a * LOAD_VEC_A + 5) * BM + loadc_a + l] = d * (2.0f*(float)((bits >> 5) & 1) - 1.0f);
buf_a[(loadr_a * LOAD_VEC_A + 6) * BM + loadc_a + l] = d * (2.0f*(float)((bits >> 6) & 1) - 1.0f);
buf_a[(loadr_a * LOAD_VEC_A + 7) * BM + loadc_a + l] = d * (2.0f*(float)((bits >> 7) & 1) - 1.0f);
} else {
for (int b = 0; b < LOAD_VEC_A; ++b) {
buf_a[(loadr_a * LOAD_VEC_A + b) * BM + loadc_a + l] = 0.0f;
}
}
}
for (int l = 0; l < BN; l += loadstride_b) {
if (ic*BN + loadc_b + l < ne11) {
int idx = pos_b + (loadc_b + l) * stride_b / LOAD_VEC_B + loadr_b;
buf_b[(loadr_b * LOAD_VEC_B + 0) * BN + loadc_b + l] = src1[idx].s0;
buf_b[(loadr_b * LOAD_VEC_B + 1) * BN + loadc_b + l] = src1[idx].s1;
buf_b[(loadr_b * LOAD_VEC_B + 2) * BN + loadc_b + l] = src1[idx].s2;
buf_b[(loadr_b * LOAD_VEC_B + 3) * BN + loadc_b + l] = src1[idx].s3;
} else {
buf_b[(loadr_b * LOAD_VEC_B + 0) * BN + loadc_b + l] = 0.0f;
buf_b[(loadr_b * LOAD_VEC_B + 1) * BN + loadc_b + l] = 0.0f;
buf_b[(loadr_b * LOAD_VEC_B + 2) * BN + loadc_b + l] = 0.0f;
buf_b[(loadr_b * LOAD_VEC_B + 3) * BN + loadc_b + l] = 0.0f;
}
}
barrier(CLK_LOCAL_MEM_FENCE);
pos_a += BK / LOAD_VEC_A;
pos_b += BK / LOAD_VEC_B;
for (int i = 0; i < BK; i++) {
for (int j = 0; j < TM; j++) {
cache_a[j] = buf_a[(i) * BM + th_r * TM + j];
}
for (int j = 0; j < TN; j++) {
cache_b[j] = buf_b[(i) * BN + th_c * TN + j];
}
for (int cc = 0; cc < TN; cc++) {
for (int cr = 0; cr < TM; cr++) {
const int sums_idx = cc*TM + cr;
sums[sums_idx] = mad(cache_a[cr], cache_b[cc], sums[sums_idx]);
}
}
}
barrier(CLK_LOCAL_MEM_FENCE);
}
const int dr = ir * BM + th_r * TM;
const int dc = ic * BN + th_c * TN;
const int offsets = batch_idx * batch_stride_d;
for (int cc = 0; cc < TN; cc++) {
for (int cr = 0; cr < TM; cr++) {
if (dr + cr < ne01 && dc + cc < ne11) {
dst[offsets + (dc + cc) * stride_d + dr + cr] = sums[cc * TM + cr];
}
}
}
}
@@ -0,0 +1,141 @@
#pragma OPENCL EXTENSION cl_khr_fp16 : enable
#ifdef cl_intel_subgroups
#pragma OPENCL EXTENSION cl_intel_subgroups : enable
#else
#pragma OPENCL EXTENSION cl_khr_subgroups : enable
#endif
#ifdef cl_intel_required_subgroup_size
#pragma OPENCL EXTENSION cl_intel_required_subgroup_size : enable
#define INTEL_GPU 1
#define REQD_SUBGROUP_SIZE_16 __attribute__((intel_reqd_sub_group_size(16)))
#define REQD_SUBGROUP_SIZE_32 __attribute__((intel_reqd_sub_group_size(32)))
#elif defined(cl_qcom_reqd_sub_group_size)
#pragma OPENCL EXTENSION cl_qcom_reqd_sub_group_size : enable
#define ADRENO_GPU 1
#define REQD_SUBGROUP_SIZE_64 __attribute__((qcom_reqd_sub_group_size("half")))
#define REQD_SUBGROUP_SIZE_128 __attribute__((qcom_reqd_sub_group_size("full")))
#endif
#define QK1_0 128
typedef struct {
half d;
uchar qs[QK1_0/8];
} block_q1_0;
#define NB_Q1_0 16
#ifdef INTEL_GPU
#define N_R0_Q1_0 4 // number of rows each subgroup works on
#define N_SG_Q1_0 2 // number of subgroups in a work group
#define N_SIMDWIDTH 16 // subgroup size
#elif defined (ADRENO_GPU)
#define N_R0_Q1_0 4
#define N_SG_Q1_0 2
#define N_SIMDWIDTH 64
#endif
inline float block_q_1_0_dot_y(global block_q1_0 * qb, float sumy, float yl[NB_Q1_0], short il) {
global uchar * qs = qb->qs + il*2;
uint b0 = qs[0];
uint b1 = qs[1];
float acc = 0.f;
acc += yl[ 0]*(float)((b0 >> 0) & 1) + yl[ 1]*(float)((b0 >> 1) & 1);
acc += yl[ 2]*(float)((b0 >> 2) & 1) + yl[ 3]*(float)((b0 >> 3) & 1);
acc += yl[ 4]*(float)((b0 >> 4) & 1) + yl[ 5]*(float)((b0 >> 5) & 1);
acc += yl[ 6]*(float)((b0 >> 6) & 1) + yl[ 7]*(float)((b0 >> 7) & 1);
acc += yl[ 8]*(float)((b1 >> 0) & 1) + yl[ 9]*(float)((b1 >> 1) & 1);
acc += yl[10]*(float)((b1 >> 2) & 1) + yl[11]*(float)((b1 >> 3) & 1);
acc += yl[12]*(float)((b1 >> 4) & 1) + yl[13]*(float)((b1 >> 5) & 1);
acc += yl[14]*(float)((b1 >> 6) & 1) + yl[15]*(float)((b1 >> 7) & 1);
return qb->d * (2.0f*acc - sumy);
}
#ifdef INTEL_GPU
REQD_SUBGROUP_SIZE_16
#elif defined (ADRENO_GPU)
REQD_SUBGROUP_SIZE_64
#endif
kernel void kernel_mul_mv_q1_0_f32(
global char * src0,
ulong offset0,
global char * src1,
ulong offset1,
global char * dst,
ulong offsetd,
int ne00,
int ne01,
ulong nb01,
ulong nb02,
ulong nb03,
int ne12,
ulong nb11,
ulong nb12,
ulong nb13,
int ne0,
int ne1,
int r2,
int r3
) {
src0 = (global char*)((global char*)src0 + offset0);
src1 = (global char*)((global char*)src1 + offset1);
dst = (global char*)((global char*)dst + offsetd);
int nb = ne00/QK1_0;
int r0 = get_group_id(0);
int r1 = get_group_id(1);
int im = get_group_id(2);
int first_row = (r0*N_SG_Q1_0 + get_sub_group_id()) * N_R0_Q1_0;
uint i12 = im%ne12;
uint i13 = im/ne12;
ulong offset_src1 = r1*nb11 + i12*nb12 + i13*nb13;
global float * y = (global float *) (src1 + offset_src1);
// pointers to src0 rows
global block_q1_0 * ax[N_R0_Q1_0];
for (int row = 0; row < N_R0_Q1_0; ++row) {
ulong offset_src0 = (first_row + row)*nb01 + (i12/r2)*nb02 + (i13/r3)*nb03;
ax[row] = (global block_q1_0 *) ((global char *) src0 + offset_src0);
}
float yl[NB_Q1_0];
float sumf[N_R0_Q1_0] = { 0.f };
const short ix = get_sub_group_local_id()/8;
const short il = get_sub_group_local_id()%8;
global float * yb = y + ix*QK1_0 + il*NB_Q1_0;
// each thread handles NB_Q1_0 quants at a time
for (int ib = ix; ib < nb; ib += N_SIMDWIDTH/8) {
float sumy = 0.f;
for (short i = 0; i < NB_Q1_0; ++i) {
yl[i] = yb[i];
sumy += yb[i];
}
for (short row = 0; row < N_R0_Q1_0; row++) {
sumf[row] += block_q_1_0_dot_y(ax[row] + ib, sumy, yl, il);
}
yb += N_SIMDWIDTH*NB_Q1_0;
}
global float * dst_f32 = (global float *) dst + (ulong)im*ne0*ne1 + (ulong)r1*ne0;
for (int row = 0; row < N_R0_Q1_0; ++row) {
float tot = sub_group_reduce_add(sumf[row]);
if (get_sub_group_local_id() == 0 && first_row + row < ne01) {
dst_f32[first_row + row] = tot;
}
}
}
@@ -0,0 +1,190 @@
#pragma OPENCL EXTENSION cl_khr_fp16 : enable
#ifdef cl_intel_subgroups
#pragma OPENCL EXTENSION cl_intel_subgroups : enable
#else
#pragma OPENCL EXTENSION cl_khr_subgroups : enable
#endif
#ifdef cl_intel_required_subgroup_size
#pragma OPENCL EXTENSION cl_intel_required_subgroup_size : enable
#define INTEL_GPU 1
#define REQD_SUBGROUP_SIZE_16 __attribute__((intel_reqd_sub_group_size(16)))
#define REQD_SUBGROUP_SIZE_32 __attribute__((intel_reqd_sub_group_size(32)))
#elif defined(cl_qcom_reqd_sub_group_size)
#pragma OPENCL EXTENSION cl_qcom_reqd_sub_group_size : enable
#define ADRENO_GPU 1
#define REQD_SUBGROUP_SIZE_64 __attribute__((qcom_reqd_sub_group_size("half")))
#define REQD_SUBGROUP_SIZE_128 __attribute__((qcom_reqd_sub_group_size("full")))
#endif
#define QK1_0 128
#define QK1_0_BYTES (QK1_0/8) // 16 quant bytes per block
#define QK1_0_BLK_BYTES (QK1_0_BYTES + 2) // d + qs in original tensor = 18
#define NB_Q1_0 16 // quants handled per thread (two qs bytes)
#ifdef INTEL_GPU
#define N_R0_Q1_0 4 // number of rows each subgroup works on
#define N_SG_Q1_0 2 // number of subgroups in a work group
#define N_SIMDWIDTH 16 // subgroup size
#elif defined (ADRENO_GPU)
#define N_R0_Q1_0 4
#define N_SG_Q1_0 2
#define N_SIMDWIDTH 64
#endif
#ifdef INTEL_GPU
REQD_SUBGROUP_SIZE_16
#elif defined (ADRENO_GPU)
REQD_SUBGROUP_SIZE_64
#endif
kernel void kernel_mul_mv_q1_0_f32_flat(
global char * src0_q,
global half * src0_d,
global char * src1,
ulong offset1,
global char * dst,
ulong offsetd,
int ne00,
int ne01,
ulong nb01,
ulong nb02,
ulong nb03,
int ne12,
ulong nb11,
ulong nb12,
ulong nb13,
int ne0,
int ne1,
int r2,
int r3
) {
src1 = (global char*)((global char*)src1 + offset1);
dst = (global char*)((global char*)dst + offsetd);
int nb = ne00/QK1_0;
int r0 = get_group_id(0);
int r1 = get_group_id(1);
int im = get_group_id(2);
int first_row = (r0*N_SG_Q1_0 + get_sub_group_id()) * N_R0_Q1_0;
uint i12 = im%ne12;
uint i13 = im/ne12;
ulong offset_src1 = r1*nb11 + i12*nb12 + i13*nb13;
global float * y = (global float *) (src1 + offset_src1);
// pointers to src0 rows (flat: q bytes + scales)
uint offset_src0_base = first_row*nb01 + (i12/r2)*nb02 + (i13/r3)*nb03;
global uchar * ax0, * ax1, * ax2, * ax3;
global half * ad0, * ad1, * ad2, * ad3;
uint offset_src0;
offset_src0 = (offset_src0_base + 0*nb01) / QK1_0_BLK_BYTES;
ax0 = (global uchar *) ((global char *) src0_q + offset_src0*QK1_0_BYTES);
ad0 = (global half *) ((global char *) src0_d + offset_src0*sizeof(half));
offset_src0 = (offset_src0_base + 1*nb01) / QK1_0_BLK_BYTES;
ax1 = (global uchar *) ((global char *) src0_q + offset_src0*QK1_0_BYTES);
ad1 = (global half *) ((global char *) src0_d + offset_src0*sizeof(half));
offset_src0 = (offset_src0_base + 2*nb01) / QK1_0_BLK_BYTES;
ax2 = (global uchar *) ((global char *) src0_q + offset_src0*QK1_0_BYTES);
ad2 = (global half *) ((global char *) src0_d + offset_src0*sizeof(half));
offset_src0 = (offset_src0_base + 3*nb01) / QK1_0_BLK_BYTES;
ax3 = (global uchar *) ((global char *) src0_q + offset_src0*QK1_0_BYTES);
ad3 = (global half *) ((global char *) src0_d + offset_src0*sizeof(half));
const short ix = get_sub_group_local_id()/8;
const short il = get_sub_group_local_id()%8;
global float * yb = y + ix*QK1_0 + il*NB_Q1_0;
float8 yl_lo;
float8 yl_hi;
float4 sumf = 0.f;
// each thread handles NB_Q1_0 = 16 quants (two qs bytes) at a time
for (int ib = ix; ib < nb; ib += N_SIMDWIDTH/8) {
yl_lo = vload8(0, yb);
yl_hi = vload8(0, yb + 8);
float sumy = yl_lo.s0 + yl_lo.s1 + yl_lo.s2 + yl_lo.s3
+ yl_lo.s4 + yl_lo.s5 + yl_lo.s6 + yl_lo.s7
+ yl_hi.s0 + yl_hi.s1 + yl_hi.s2 + yl_hi.s3
+ yl_hi.s4 + yl_hi.s5 + yl_hi.s6 + yl_hi.s7;
uint b0, b1;
float acc;
b0 = ax0[ib*QK1_0_BYTES + il*2 + 0];
b1 = ax0[ib*QK1_0_BYTES + il*2 + 1];
acc = yl_lo.s0*(float)((b0 >> 0) & 1) + yl_lo.s1*(float)((b0 >> 1) & 1)
+ yl_lo.s2*(float)((b0 >> 2) & 1) + yl_lo.s3*(float)((b0 >> 3) & 1)
+ yl_lo.s4*(float)((b0 >> 4) & 1) + yl_lo.s5*(float)((b0 >> 5) & 1)
+ yl_lo.s6*(float)((b0 >> 6) & 1) + yl_lo.s7*(float)((b0 >> 7) & 1)
+ yl_hi.s0*(float)((b1 >> 0) & 1) + yl_hi.s1*(float)((b1 >> 1) & 1)
+ yl_hi.s2*(float)((b1 >> 2) & 1) + yl_hi.s3*(float)((b1 >> 3) & 1)
+ yl_hi.s4*(float)((b1 >> 4) & 1) + yl_hi.s5*(float)((b1 >> 5) & 1)
+ yl_hi.s6*(float)((b1 >> 6) & 1) + yl_hi.s7*(float)((b1 >> 7) & 1);
sumf.s0 += (float)ad0[ib] * (2.0f*acc - sumy);
b0 = ax1[ib*QK1_0_BYTES + il*2 + 0];
b1 = ax1[ib*QK1_0_BYTES + il*2 + 1];
acc = yl_lo.s0*(float)((b0 >> 0) & 1) + yl_lo.s1*(float)((b0 >> 1) & 1)
+ yl_lo.s2*(float)((b0 >> 2) & 1) + yl_lo.s3*(float)((b0 >> 3) & 1)
+ yl_lo.s4*(float)((b0 >> 4) & 1) + yl_lo.s5*(float)((b0 >> 5) & 1)
+ yl_lo.s6*(float)((b0 >> 6) & 1) + yl_lo.s7*(float)((b0 >> 7) & 1)
+ yl_hi.s0*(float)((b1 >> 0) & 1) + yl_hi.s1*(float)((b1 >> 1) & 1)
+ yl_hi.s2*(float)((b1 >> 2) & 1) + yl_hi.s3*(float)((b1 >> 3) & 1)
+ yl_hi.s4*(float)((b1 >> 4) & 1) + yl_hi.s5*(float)((b1 >> 5) & 1)
+ yl_hi.s6*(float)((b1 >> 6) & 1) + yl_hi.s7*(float)((b1 >> 7) & 1);
sumf.s1 += (float)ad1[ib] * (2.0f*acc - sumy);
b0 = ax2[ib*QK1_0_BYTES + il*2 + 0];
b1 = ax2[ib*QK1_0_BYTES + il*2 + 1];
acc = yl_lo.s0*(float)((b0 >> 0) & 1) + yl_lo.s1*(float)((b0 >> 1) & 1)
+ yl_lo.s2*(float)((b0 >> 2) & 1) + yl_lo.s3*(float)((b0 >> 3) & 1)
+ yl_lo.s4*(float)((b0 >> 4) & 1) + yl_lo.s5*(float)((b0 >> 5) & 1)
+ yl_lo.s6*(float)((b0 >> 6) & 1) + yl_lo.s7*(float)((b0 >> 7) & 1)
+ yl_hi.s0*(float)((b1 >> 0) & 1) + yl_hi.s1*(float)((b1 >> 1) & 1)
+ yl_hi.s2*(float)((b1 >> 2) & 1) + yl_hi.s3*(float)((b1 >> 3) & 1)
+ yl_hi.s4*(float)((b1 >> 4) & 1) + yl_hi.s5*(float)((b1 >> 5) & 1)
+ yl_hi.s6*(float)((b1 >> 6) & 1) + yl_hi.s7*(float)((b1 >> 7) & 1);
sumf.s2 += (float)ad2[ib] * (2.0f*acc - sumy);
b0 = ax3[ib*QK1_0_BYTES + il*2 + 0];
b1 = ax3[ib*QK1_0_BYTES + il*2 + 1];
acc = yl_lo.s0*(float)((b0 >> 0) & 1) + yl_lo.s1*(float)((b0 >> 1) & 1)
+ yl_lo.s2*(float)((b0 >> 2) & 1) + yl_lo.s3*(float)((b0 >> 3) & 1)
+ yl_lo.s4*(float)((b0 >> 4) & 1) + yl_lo.s5*(float)((b0 >> 5) & 1)
+ yl_lo.s6*(float)((b0 >> 6) & 1) + yl_lo.s7*(float)((b0 >> 7) & 1)
+ yl_hi.s0*(float)((b1 >> 0) & 1) + yl_hi.s1*(float)((b1 >> 1) & 1)
+ yl_hi.s2*(float)((b1 >> 2) & 1) + yl_hi.s3*(float)((b1 >> 3) & 1)
+ yl_hi.s4*(float)((b1 >> 4) & 1) + yl_hi.s5*(float)((b1 >> 5) & 1)
+ yl_hi.s6*(float)((b1 >> 6) & 1) + yl_hi.s7*(float)((b1 >> 7) & 1);
sumf.s3 += (float)ad3[ib] * (2.0f*acc - sumy);
yb += N_SIMDWIDTH*NB_Q1_0;
}
global float * dst_f32 = (global float *) dst + (ulong)im*ne0*ne1 + (ulong)r1*ne0;
float4 tot = (float4)(
sub_group_reduce_add(sumf.s0),
sub_group_reduce_add(sumf.s1),
sub_group_reduce_add(sumf.s2),
sub_group_reduce_add(sumf.s3)
);
if (get_sub_group_local_id() == 0) {
if (first_row + 0 < ne01) dst_f32[first_row + 0] = tot.s0;
if (first_row + 1 < ne01) dst_f32[first_row + 1] = tot.s1;
if (first_row + 2 < ne01) dst_f32[first_row + 2] = tot.s2;
if (first_row + 3 < ne01) dst_f32[first_row + 3] = tot.s3;
}
}
+500
View File
@@ -158,6 +158,239 @@ kernel void kernel_set_rows_f32_i32(
}
}
// f32 -> q8_0 quantize set_rows. Block = half d + char qs[32].
#define QK8_0 32
inline void quantize_q8_0_block(global float * x, global char * qs, global half * d_out) {
float amax = 0.0f;
for (int j = 0; j < QK8_0; j++) {
amax = fmax(amax, fabs(x[j]));
}
float d = amax / 127.0f;
float id = (d != 0.0f) ? 127.0f / amax : 0.0f;
vstore_half(d, 0, d_out);
for (int j = 0; j < QK8_0; j++) {
qs[j] = (char)((int)round(x[j] * id));
}
}
kernel void kernel_set_rows_q8_0_i64(
global char * src0,
ulong offset0,
global char * src1,
ulong offset1,
global char * dst,
ulong offsetd,
int ne01,
ulong nb01,
ulong nb02,
ulong nb03,
uint4 ne11,
uint4 ne12,
ulong nb10,
ulong nb11,
ulong nb12,
int nblk0,
ulong nb1,
ulong nb2,
ulong nb3
) {
src0 = src0 + offset0;
src1 = src1 + offset1;
dst = dst + offsetd;
int i03 = get_group_id(2);
int i02 = get_group_id(1);
int i01 = get_group_id(0)*get_local_size(1) + get_local_id(1);
if (i01 >= ne01) {
return;
}
int i12 = fastmod(i03, ne12);
int i11 = fastmod(i02, ne11);
int i10 = i01;
long i1 = ((global long *)(src1 + i10*nb10 + i11*nb11 + i12*nb12))[0];
global char * dst_row = (global char *) (dst + i1*nb1 + i02*nb2 + i03*nb3);
global float * src_row = (global float *) (src0 + i01*nb01 + i02*nb02 + i03*nb03);
for (int blk = get_local_id(0); blk < nblk0; blk += get_local_size(0)) {
global float * x = src_row + blk * QK8_0;
global char * y = dst_row + blk * (2 + QK8_0);
quantize_q8_0_block(x, y + 2, (global half *)y);
}
}
kernel void kernel_set_rows_q8_0_i32(
global char * src0,
ulong offset0,
global char * src1,
ulong offset1,
global char * dst,
ulong offsetd,
int ne01,
ulong nb01,
ulong nb02,
ulong nb03,
uint4 ne11,
uint4 ne12,
ulong nb10,
ulong nb11,
ulong nb12,
int nblk0,
ulong nb1,
ulong nb2,
ulong nb3
) {
src0 = src0 + offset0;
src1 = src1 + offset1;
dst = dst + offsetd;
int i03 = get_group_id(2);
int i02 = get_group_id(1);
int i01 = get_group_id(0)*get_local_size(1) + get_local_id(1);
if (i01 >= ne01) {
return;
}
int i12 = fastmod(i03, ne12);
int i11 = fastmod(i02, ne11);
int i10 = i01;
int i1 = ((global int *)(src1 + i10*nb10 + i11*nb11 + i12*nb12))[0];
global char * dst_row = (global char *) (dst + i1*nb1 + i02*nb2 + i03*nb3);
global float * src_row = (global float *) (src0 + i01*nb01 + i02*nb02 + i03*nb03);
for (int blk = get_local_id(0); blk < nblk0; blk += get_local_size(0)) {
global float * x = src_row + blk * QK8_0;
global char * y = dst_row + blk * (2 + QK8_0);
quantize_q8_0_block(x, y + 2, (global half *)y);
}
}
// SoA q8_0 variants. dst_q: int8[QK8_0] per block; dst_d: fp16 scale per block.
// Layout matches kernel_convert_block_q8_0; block index follows dst element order.
kernel void kernel_set_rows_q8_0_soa_i64(
global char * src0,
ulong offset0,
global char * src1,
ulong offset1,
global char * dst_q,
ulong offset_q,
global char * dst_d,
ulong offset_d,
int ne01,
ulong nb01,
ulong nb02,
ulong nb03,
uint4 ne11,
uint4 ne12,
ulong nb10,
ulong nb11,
ulong nb12,
int nblk0,
int ne1_dst,
int ne2_dst,
int ne3_dst
) {
src0 = src0 + offset0;
src1 = src1 + offset1;
dst_q = dst_q + offset_q;
dst_d = dst_d + offset_d;
int i03 = get_group_id(2);
int i02 = get_group_id(1);
int i01 = get_group_id(0)*get_local_size(1) + get_local_id(1);
if (i01 >= ne01) {
return;
}
int i12 = fastmod(i03, ne12);
int i11 = fastmod(i02, ne11);
int i10 = i01;
long i1 = ((global long *)(src1 + i10*nb10 + i11*nb11 + i12*nb12))[0];
long row_blk_base = ((long)i03 * ne2_dst * ne1_dst + (long)i02 * ne1_dst + i1) * nblk0;
global half * d_row = (global half *)(dst_d) + row_blk_base;
global char * q_row = (global char *)(dst_q) + row_blk_base * QK8_0;
global float * src_row = (global float *)(src0 + i01*nb01 + i02*nb02 + i03*nb03);
for (int blk = get_local_id(0); blk < nblk0; blk += get_local_size(0)) {
global float * x = src_row + blk * QK8_0;
global char * q = q_row + blk * QK8_0;
quantize_q8_0_block(x, q, d_row + blk);
}
}
kernel void kernel_set_rows_q8_0_soa_i32(
global char * src0,
ulong offset0,
global char * src1,
ulong offset1,
global char * dst_q,
ulong offset_q,
global char * dst_d,
ulong offset_d,
int ne01,
ulong nb01,
ulong nb02,
ulong nb03,
uint4 ne11,
uint4 ne12,
ulong nb10,
ulong nb11,
ulong nb12,
int nblk0,
int ne1_dst,
int ne2_dst,
int ne3_dst
) {
src0 = src0 + offset0;
src1 = src1 + offset1;
dst_q = dst_q + offset_q;
dst_d = dst_d + offset_d;
int i03 = get_group_id(2);
int i02 = get_group_id(1);
int i01 = get_group_id(0)*get_local_size(1) + get_local_id(1);
if (i01 >= ne01) {
return;
}
int i12 = fastmod(i03, ne12);
int i11 = fastmod(i02, ne11);
int i10 = i01;
int i1 = ((global int *)(src1 + i10*nb10 + i11*nb11 + i12*nb12))[0];
long row_blk_base = ((long)i03 * ne2_dst * ne1_dst + (long)i02 * ne1_dst + i1) * nblk0;
global half * d_row = (global half *)(dst_d) + row_blk_base;
global char * q_row = (global char *)(dst_q) + row_blk_base * QK8_0;
global float * src_row = (global float *)(src0 + i01*nb01 + i02*nb02 + i03*nb03);
for (int blk = get_local_id(0); blk < nblk0; blk += get_local_size(0)) {
global float * x = src_row + blk * QK8_0;
global char * q = q_row + blk * QK8_0;
quantize_q8_0_block(x, q, d_row + blk);
}
}
kernel void kernel_set_rows_f16_i32(
global char * src0,
ulong offset0,
@@ -206,3 +439,270 @@ kernel void kernel_set_rows_f16_i32(
dst_row[ind] = src_row[ind];
}
}
// f32 -> q4_0 quantize set_rows. Block = half d + uchar qs[16] (shuffled
// nibbles: qs[j] low/high = elem j / j+16).
// Dequant: val[i] = d * (nibble_i - 8)
// nblk0 = number of q4_0 blocks per row = ne00 / 32.
#define QK4_0 32
#define Q4_0_BLOCK_SIZE 18
inline void quantize_q4_0_block(global float * x, global uchar * qs, global half * d_out) {
// Find the signed value with the largest absolute magnitude (matches ggml ref).
float max = 0.0f;
float amax = 0.0f;
for (int j = 0; j < QK4_0; j++) {
float v = x[j];
float a = fabs(v);
if (a > amax) {
amax = a;
max = v;
}
}
float d = max / -8.0f;
float id = (d != 0.0f) ? 1.0f / d : 0.0f;
vstore_half(d, 0, d_out);
for (int j = 0; j < QK4_0/2; j++) {
float x0 = x[j] * id;
float x1 = x[j + QK4_0/2] * id;
int i0 = (int)(x0 + 8.5f);
int i1 = (int)(x1 + 8.5f);
if (i0 < 0) i0 = 0;
if (i0 > 15) i0 = 15;
if (i1 < 0) i1 = 0;
if (i1 > 15) i1 = 15;
qs[j] = (uchar)i0 | ((uchar)i1 << 4);
}
}
kernel void kernel_set_rows_q4_0_i64(
global char * src0,
ulong offset0,
global char * src1,
ulong offset1,
global char * dst,
ulong offsetd,
int ne01,
ulong nb01,
ulong nb02,
ulong nb03,
uint4 ne11,
uint4 ne12,
ulong nb10,
ulong nb11,
ulong nb12,
int nblk0,
ulong nb1,
ulong nb2,
ulong nb3
) {
src0 = src0 + offset0;
src1 = src1 + offset1;
dst = dst + offsetd;
int i03 = get_group_id(2);
int i02 = get_group_id(1);
int i01 = get_group_id(0)*get_local_size(1) + get_local_id(1);
if (i01 >= ne01) {
return;
}
int i12 = fastmod(i03, ne12);
int i11 = fastmod(i02, ne11);
int i10 = i01;
long i1 = ((global long *)(src1 + i10*nb10 + i11*nb11 + i12*nb12))[0];
global char * dst_row = (global char *) (dst + i1*nb1 + i02*nb2 + i03*nb3);
global float * src_row = (global float *) (src0 + i01*nb01 + i02*nb02 + i03*nb03);
for (int blk = get_local_id(0); blk < nblk0; blk += get_local_size(0)) {
global float * x = src_row + blk * QK4_0;
global char * y = dst_row + blk * Q4_0_BLOCK_SIZE;
global half * yd = (global half *)(y);
global uchar * yqs = (global uchar *)(y + 2);
quantize_q4_0_block(x, yqs, yd);
}
}
kernel void kernel_set_rows_q4_0_i32(
global char * src0,
ulong offset0,
global char * src1,
ulong offset1,
global char * dst,
ulong offsetd,
int ne01,
ulong nb01,
ulong nb02,
ulong nb03,
uint4 ne11,
uint4 ne12,
ulong nb10,
ulong nb11,
ulong nb12,
int nblk0,
ulong nb1,
ulong nb2,
ulong nb3
) {
src0 = src0 + offset0;
src1 = src1 + offset1;
dst = dst + offsetd;
int i03 = get_group_id(2);
int i02 = get_group_id(1);
int i01 = get_group_id(0)*get_local_size(1) + get_local_id(1);
if (i01 >= ne01) {
return;
}
int i12 = fastmod(i03, ne12);
int i11 = fastmod(i02, ne11);
int i10 = i01;
int i1 = ((global int *)(src1 + i10*nb10 + i11*nb11 + i12*nb12))[0];
global char * dst_row = (global char *) (dst + i1*nb1 + i02*nb2 + i03*nb3);
global float * src_row = (global float *) (src0 + i01*nb01 + i02*nb02 + i03*nb03);
for (int blk = get_local_id(0); blk < nblk0; blk += get_local_size(0)) {
global float * x = src_row + blk * QK4_0;
global char * y = dst_row + blk * Q4_0_BLOCK_SIZE;
global half * yd = (global half *)(y);
global uchar * yqs = (global uchar *)(y + 2);
quantize_q4_0_block(x, yqs, yd);
}
}
// SoA variants for q4_0 dst. Used when the backend has split block_q4_0 records
// into separate quant (dst_q) and scale (dst_d) sub-buffers same pattern as
// the q8_0 SoA variants above.
//
// Layout (matches kernel_convert_block_q4_0, the "shuffled" variant):
// dst_q: contiguous 16 packed nibbles per block, block i at offset i * 16 bytes.
// dst_d: contiguous fp16 scales, block i at offset i * 2 bytes.
// Nibble layout inside each byte is unchanged from AoS: qs[j] low nibble = element j,
// qs[j] high nibble = element j+16. kernel_restore_block_q4_0 copies bytes as-is.
kernel void kernel_set_rows_q4_0_soa_i64(
global char * src0,
ulong offset0,
global char * src1,
ulong offset1,
global char * dst_q,
ulong offset_q,
global char * dst_d,
ulong offset_d,
int ne01,
ulong nb01,
ulong nb02,
ulong nb03,
uint4 ne11,
uint4 ne12,
ulong nb10,
ulong nb11,
ulong nb12,
int nblk0,
int ne1_dst,
int ne2_dst,
int ne3_dst
) {
src0 = src0 + offset0;
src1 = src1 + offset1;
dst_q = dst_q + offset_q;
dst_d = dst_d + offset_d;
int i03 = get_group_id(2);
int i02 = get_group_id(1);
int i01 = get_group_id(0)*get_local_size(1) + get_local_id(1);
if (i01 >= ne01) {
return;
}
int i12 = fastmod(i03, ne12);
int i11 = fastmod(i02, ne11);
int i10 = i01;
long i1 = ((global long *)(src1 + i10*nb10 + i11*nb11 + i12*nb12))[0];
long row_blk_base = ((long)i03 * ne2_dst * ne1_dst + (long)i02 * ne1_dst + i1) * nblk0;
global half * d_row = (global half *)(dst_d) + row_blk_base;
global uchar * q_row = (global uchar *)(dst_q) + row_blk_base * (QK4_0/2);
global float * src_row = (global float *)(src0 + i01*nb01 + i02*nb02 + i03*nb03);
for (int blk = get_local_id(0); blk < nblk0; blk += get_local_size(0)) {
global float * x = src_row + blk * QK4_0;
global uchar * qs = q_row + blk * (QK4_0/2);
global half * d_bk = d_row + blk;
quantize_q4_0_block(x, qs, d_bk);
}
}
kernel void kernel_set_rows_q4_0_soa_i32(
global char * src0,
ulong offset0,
global char * src1,
ulong offset1,
global char * dst_q,
ulong offset_q,
global char * dst_d,
ulong offset_d,
int ne01,
ulong nb01,
ulong nb02,
ulong nb03,
uint4 ne11,
uint4 ne12,
ulong nb10,
ulong nb11,
ulong nb12,
int nblk0,
int ne1_dst,
int ne2_dst,
int ne3_dst
) {
src0 = src0 + offset0;
src1 = src1 + offset1;
dst_q = dst_q + offset_q;
dst_d = dst_d + offset_d;
int i03 = get_group_id(2);
int i02 = get_group_id(1);
int i01 = get_group_id(0)*get_local_size(1) + get_local_id(1);
if (i01 >= ne01) {
return;
}
int i12 = fastmod(i03, ne12);
int i11 = fastmod(i02, ne11);
int i10 = i01;
int i1 = ((global int *)(src1 + i10*nb10 + i11*nb11 + i12*nb12))[0];
long row_blk_base = ((long)i03 * ne2_dst * ne1_dst + (long)i02 * ne1_dst + i1) * nblk0;
global half * d_row = (global half *)(dst_d) + row_blk_base;
global uchar * q_row = (global uchar *)(dst_q) + row_blk_base * (QK4_0/2);
global float * src_row = (global float *)(src0 + i01*nb01 + i02*nb02 + i03*nb03);
for (int blk = get_local_id(0); blk < nblk0; blk += get_local_size(0)) {
global float * x = src_row + blk * QK4_0;
global uchar * qs = q_row + blk * (QK4_0/2);
global half * d_bk = d_row + blk;
quantize_q4_0_block(x, qs, d_bk);
}
}
+79
View File
@@ -0,0 +1,79 @@
#pragma once
#ifdef _WIN32
# define WIN32_LEAN_AND_MEAN
# ifndef NOMINMAX
# define NOMINMAX
# endif
# include <windows.h>
# include <winevt.h>
#else
# include <dlfcn.h>
# include <unistd.h>
#endif
#include <filesystem>
namespace fs = std::filesystem;
#ifdef _WIN32
using dl_handle = std::remove_pointer_t<HMODULE>;
struct dl_handle_deleter {
void operator()(HMODULE handle) {
FreeLibrary(handle);
}
};
static inline dl_handle * dl_load_library(const fs::path & path) {
// suppress error dialogs for missing DLLs
DWORD old_mode = SetErrorMode(SEM_FAILCRITICALERRORS);
SetErrorMode(old_mode | SEM_FAILCRITICALERRORS);
HMODULE handle = LoadLibraryW(path.wstring().c_str());
SetErrorMode(old_mode);
return handle;
}
static inline void * dl_get_sym(dl_handle * handle, const char * name) {
DWORD old_mode = SetErrorMode(SEM_FAILCRITICALERRORS);
SetErrorMode(old_mode | SEM_FAILCRITICALERRORS);
void * p = (void *) GetProcAddress(handle, name);
SetErrorMode(old_mode);
return p;
}
static inline const char * dl_error() {
return "";
}
#else
using dl_handle = void;
struct dl_handle_deleter {
void operator()(void * handle) {
dlclose(handle);
}
};
static inline dl_handle * dl_load_library(const fs::path & path) {
dl_handle * handle = dlopen(path.string().c_str(), RTLD_NOW | RTLD_LOCAL);
return handle;
}
static inline void * dl_get_sym(dl_handle * handle, const char * name) {
return dlsym(handle, name);
}
static inline const char * dl_error() {
const char *rslt = dlerror();
return rslt != nullptr ? rslt : "";
}
#endif
+4
View File
@@ -1053,6 +1053,10 @@ static bool is_op_unsupported_case(const ggml_tensor * op) {
(op->ne[0] == 2 && op->ne[1] == 4 && op->ne[2] == 3 && op->ne[3] == 2)) {
return true;
}
// CPY into a strided view of a larger buffer (recurrent-state snapshots) not supported
if (op->view_src && ggml_nbytes(op) != ggml_nbytes(op->view_src)) {
return true;
}
break;
}
case GGML_OP_MUL_MAT: {
+153 -39
View File
@@ -1907,6 +1907,38 @@ static bool vk_enable_sync_logger = false;
static uint32_t vk_perf_logger_frequency = 1;
static std::string vk_pipeline_stats_filter;
static uint64_t ggml_vk_get_node_flops(const ggml_tensor * node) {
if (node->op == GGML_OP_MUL_MAT || node->op == GGML_OP_MUL_MAT_ID) {
const uint64_t m = node->ne[0];
const uint64_t n = node->ne[1];
const uint64_t k = node->src[1]->ne[0];
const uint64_t batch = node->ne[2] * node->ne[3];
return m * n * (k + (k - 1)) * batch;
}
if (node->op == GGML_OP_CONV_2D || node->op == GGML_OP_CONV_TRANSPOSE_2D) {
const ggml_tensor * knl = node->src[0];
const uint64_t Cout = node->ne[2];
const uint64_t size_K = node->src[1]->ne[2] * knl->ne[0] * knl->ne[1];
const uint64_t size_N = node->ne[3] * node->ne[0] * node->ne[1];
return Cout * size_N * (size_K + (size_K - 1));
}
if (node->op == GGML_OP_CONV_3D) {
const ggml_tensor * knl = node->src[0];
const uint64_t OC = ggml_get_op_params_i32(node, 11);
const uint64_t IC = ggml_get_op_params_i32(node, 9);
const uint64_t size_K = IC * knl->ne[0] * knl->ne[1] * knl->ne[2];
const uint64_t size_N = node->ne[3] / OC * node->ne[0] * node->ne[1] * node->ne[2];
return OC * size_N * (size_K + (size_K - 1));
}
if (node->op == GGML_OP_FLASH_ATTN_EXT) {
const ggml_tensor * q = node->src[0];
const ggml_tensor * k = node->src[1];
const ggml_tensor * v = node->src[2];
return 2ull * q->ne[1] * q->ne[2] * (k->ne[0] + v->ne[0]) * k->ne[1] * q->ne[3];
}
return 0;
}
class vk_perf_logger {
public:
void print_timings(bool force = false) {
@@ -1955,7 +1987,7 @@ class vk_perf_logger {
}
std::string get_node_fusion_name(const ggml_tensor * node, const char *fusion_name, uint64_t *n_flops) {
*n_flops = 0;
*n_flops = ggml_vk_get_node_flops(node);
std::string fusion_str;
if (fusion_name) {
fusion_str = fusion_name + std::string(" ");
@@ -1982,35 +2014,22 @@ class vk_perf_logger {
if (batch > 1) {
name += " batch=" + std::to_string(batch);
}
name = fusion_str + name;
*n_flops = m * n * (k + (k - 1)) * batch;
return name;
return fusion_str + name;
}
if (node->op == GGML_OP_CONV_2D || node->op == GGML_OP_CONV_TRANSPOSE_2D) {
std::string name = ggml_op_name(node->op);
ggml_tensor * knl = node->src[0];
uint64_t OW = node->ne[0];
uint64_t OH = node->ne[1];
uint64_t N = node->ne[3];
const ggml_tensor * knl = node->src[0];
uint64_t Cout = node->ne[2];
uint64_t KW = knl->ne[0];
uint64_t KH = knl->ne[1];
uint64_t Cin = node->src[1]->ne[2];
// KxCRS @ CRSxNPQ = KxNPQ -> M=K, K=CRS, N=NPQ
uint64_t size_M = Cout;
uint64_t size_K = Cin * KW * KH;
uint64_t size_N = N * OW * OH;
*n_flops = size_M * size_N * (size_K + (size_K - 1));
name += " M=Cout=" + std::to_string(size_M) + ", K=Cin*KW*KH=" + std::to_string(size_K) +
uint64_t size_K = node->src[1]->ne[2] * knl->ne[0] * knl->ne[1];
uint64_t size_N = node->ne[3] * node->ne[0] * node->ne[1];
name += " M=Cout=" + std::to_string(Cout) + ", K=Cin*KW*KH=" + std::to_string(size_K) +
", N=N*OW*OH=" + std::to_string(size_N);
name = fusion_str + name;
return name;
return fusion_str + name;
}
if (node->op == GGML_OP_RMS_NORM) {
std::string name = ggml_op_name(node->op);
name += "(" + std::to_string(node->ne[0]) + "," + std::to_string(node->ne[1]) + "," + std::to_string(node->ne[2]) + "," + std::to_string(node->ne[3]) + ")";
name = fusion_str + name;
return name;
return fusion_str + name;
}
if (node->op == GGML_OP_FLASH_ATTN_EXT) {
const ggml_tensor * dst = node;
@@ -2026,7 +2045,6 @@ class vk_perf_logger {
" k(" << k->ne[0] << "," << k->ne[1] << "," << k->ne[2] << "," << k->ne[3] << "), " <<
" v(" << v->ne[0] << "," << v->ne[1] << "," << v->ne[2] << "," << v->ne[3] << "), " <<
" m(" << (m?m->ne[0]:0) << "," << (m?m->ne[1]:0) << "," << (m?m->ne[2]:0) << "," << (m?m->ne[3]:0) << ")";
*n_flops = 2ull * q->ne[1] * q->ne[2] * (k->ne[0] + v->ne[0]) * k->ne[1] * q->ne[3];
return name.str();
}
if (node->op == GGML_OP_TOP_K) {
@@ -2090,7 +2108,7 @@ struct ggml_backend_vk_context {
bool do_add_rms_partials_offset_calculation;
bool do_add_rms_partials;
uint64_t last_total_mul_mat_bytes {};
uint64_t last_total_flops {UINT64_MAX};
// Cache most recent tensor that was converted into prealloc_y, and what pipeline it used to convert.
vk_pipeline_struct * prealloc_y_last_pipeline_used {};
@@ -2457,6 +2475,85 @@ static bool ggml_vk_strip_decode_vector(const uint32_t * code, size_t word_count
return true;
}
// Remove the loop unrolling hint of the matmul shader's BK loop
// and replace it with the dont_unroll hint for better performance on
// hardware like Apple M1/M2.
// Assumes 1. code comes from mul_mm.comp 2. the K-tile loop has no loop
// control hint and 3. the BK loop is the last loop nested directly inside
// the K-tile loop.
// Returns true when the input was modified; returns false otherwise
// without touching `out`.
static bool ggml_vk_roll_bk_loop(const uint32_t * code, size_t word_count, std::vector<uint32_t> & out) {
if (word_count < 5) {
return false;
}
struct vk_spv_loop {
size_t header;
size_t end;
uint32_t control;
};
std::vector<vk_spv_loop> loops;
// Collect a list of all loops in the module.
for (size_t pos = 5; pos < word_count; ) {
const uint32_t wc = code[pos] >> spv::WordCountShift;
const uint32_t op = code[pos] & spv::OpCodeMask;
if (wc == 0 || pos + wc > word_count) {
return false;
}
if (op == spv::OpLoopMerge && wc >= 4) { loops.push_back({ pos, 0, code[pos + 3] }); }
if (op == spv::OpLabel && wc >= 2) {
for (auto & l : loops) {
if (l.end == 0 && code[l.header + 1] == code[pos + 1]) { l.end = pos; }
}
}
pos += wc;
}
auto encloses = [](const vk_spv_loop & a, const vk_spv_loop & b) {
return a.header < b.header && b.header < a.end;
};
// Find the BK loop.
const vk_spv_loop * bk = nullptr;
for (const auto & h : loops) {
if (h.control != spv::LoopControlUnrollMask) {
continue;
}
const vk_spv_loop * parent = nullptr;
bool has_child = false;
for (const auto & g : loops) {
if (encloses(g, h) && (!parent || g.header > parent->header)) {
parent = &g;
}
if (encloses(h, g)) {
has_child = true;
}
}
// BK loop should be the last loop nested inside the loop with no hint
// and have at least one child loop.
if (parent &&
parent->control == spv::LoopControlMaskNone &&
has_child &&
(!bk || h.header > bk->header)) {
bk = &h;
}
}
if (!bk) {
return false;
}
// set DontUnroll instead of Unroll
out.assign(code, code + word_count);
out[bk->header + 3] = spv::LoopControlDontUnrollMask;
return true;
}
static void ggml_vk_create_pipeline_func(vk_device& device, vk_pipeline& pipeline, size_t spv_size, const void* spv_data, const std::string entrypoint,
uint32_t parameter_count, std::array<uint32_t, 3> wg_denoms, std::vector<uint32_t> specialization_constants,
bool disable_robustness, bool require_full_subgroups, uint32_t required_subgroup_size) {
@@ -2540,6 +2637,22 @@ static void ggml_vk_create_pipeline_func(vk_device& device, vk_pipeline& pipelin
}
#endif
#if VK_HEADER_VERSION >= 287
// Roll the mul_mm BK loop on Asahi Linux. Skip bf16 and the mul_mmq pipelines.
if (device->driver_id == vk::DriverId::eMesaHoneykrisp &&
pipeline->name.rfind("matmul", 0) == 0 &&
pipeline->name.find("bf16") == std::string::npos &&
pipeline->name.find("q8_1") == std::string::npos) {
const uint32_t * src = spirv.empty() ? reinterpret_cast<const uint32_t *>(spv_data) : spirv.data();
size_t src_n = spirv.empty() ? spv_size / sizeof(uint32_t) : spirv.size();
std::vector<uint32_t> rolled;
if (ggml_vk_roll_bk_loop(src, src_n, rolled)) {
spirv = std::move(rolled);
shader_module_create_info = vk::ShaderModuleCreateInfo({}, spirv.size() * sizeof(uint32_t), spirv.data());
}
}
#endif
pipeline->shader_module = device->device.createShaderModule(shader_module_create_info);
vk::PushConstantRange pcr(
@@ -16188,22 +16301,23 @@ static ggml_status ggml_backend_vk_graph_compute(ggml_backend_t backend, ggml_cg
}
// Submit after enough work has accumulated, to overlap CPU cmdbuffer generation with GPU execution.
// Estimate the amount of matmul work by looking at the weight matrix size, and submit every 100MB
// (and scaled down based on model size, so smaller models submit earlier).
int submitted_nodes = 0;
int submit_count = 0;
uint64_t mul_mat_bytes = 0;
uint64_t total_mul_mat_bytes = 0;
uint64_t mul_mat_bytes_per_submit = std::min(uint64_t(100*1000*1000), ctx->last_total_mul_mat_bytes / 40u);
// Estimate the amount of compute work using flops, and submit every 200 GFLOP
// (and scaled down based on total graph flops, so smaller models submit earlier).
// Also submit at least every 100 nodes, in case there are workloads without heavy compute.
uint32_t submitted_nodes = 0;
uint32_t submit_count = 0;
uint64_t batch_flops = 0;
uint64_t total_flops = 0;
uint64_t flops_per_submit = std::min(uint64_t(200'000'000'000), ctx->last_total_flops / 40u);
for (int i = 0; i < cgraph->n_nodes; i++) {
if (first_node_in_batch) {
submit_node_idx = i;
}
if (cgraph->nodes[i]->op == GGML_OP_MUL_MAT || cgraph->nodes[i]->op == GGML_OP_MUL_MAT_ID) {
auto bytes = ggml_nbytes(cgraph->nodes[i]->src[0]);
mul_mat_bytes += bytes;
total_mul_mat_bytes += bytes;
{
auto node_flops = ggml_vk_get_node_flops(cgraph->nodes[i]);
batch_flops += node_flops;
total_flops += node_flops;
}
// op_srcs_fused_elementwise indicates whether an op's srcs all contribute to
@@ -16415,8 +16529,8 @@ static ggml_status ggml_backend_vk_graph_compute(ggml_backend_t backend, ggml_cg
// Signal the almost_ready fence when the graph is mostly complete (< 20% remaining)
bool almost_ready = (cgraph->n_nodes - i) < cgraph->n_nodes / 5;
bool submit = ((uint32_t)submitted_nodes >= ctx->device->max_nodes_per_submit) ||
(mul_mat_bytes_per_submit != 0 && mul_mat_bytes >= mul_mat_bytes_per_submit) ||
bool submit = (submitted_nodes >= ctx->device->max_nodes_per_submit) ||
(flops_per_submit != 0 && batch_flops >= flops_per_submit) ||
(i + ctx->num_additional_fused_ops >= last_node) ||
(almost_ready && !ctx->almost_ready_fence_pending);
@@ -16450,9 +16564,9 @@ static ggml_status ggml_backend_vk_graph_compute(ggml_backend_t backend, ggml_cg
if (submit && enqueued) {
first_node_in_batch = true;
submitted_nodes = 0;
mul_mat_bytes = 0;
batch_flops = 0;
if (submit_count < 3) {
mul_mat_bytes_per_submit *= 2;
flops_per_submit *= 2;
}
submit_count++;
}
@@ -16461,7 +16575,7 @@ static ggml_status ggml_backend_vk_graph_compute(ggml_backend_t backend, ggml_cg
ctx->fused_ops_write_mask = 0;
}
ctx->last_total_mul_mat_bytes = total_mul_mat_bytes;
ctx->last_total_flops = total_flops;
if (vk_perf_logger_enabled) {
// End the command buffer and submit/wait
@@ -1563,6 +1563,7 @@ class ggml_webgpu_shader_lib {
case GGML_TYPE_IQ1_S:
case GGML_TYPE_IQ4_NL:
case GGML_TYPE_MXFP4:
case GGML_TYPE_NVFP4:
{
// Quantized types using u32 buffers for portability.
defines.push_back("SRC_TYPE=u32");
@@ -1593,6 +1594,8 @@ class ggml_webgpu_shader_lib {
} else if ((key.src_type >= GGML_TYPE_Q4_0 && key.src_type <= GGML_TYPE_Q8_1) ||
key.src_type == GGML_TYPE_IQ4_NL || key.src_type == GGML_TYPE_MXFP4) {
defines.push_back("BLOCK_SIZE=32u");
} else if (key.src_type == GGML_TYPE_NVFP4) {
defines.push_back("BLOCK_SIZE=64u");
} else if (key.src_type >= GGML_TYPE_Q2_K) {
defines.push_back("BLOCK_SIZE=256u");
} else {
@@ -1960,6 +1963,7 @@ class ggml_webgpu_shader_lib {
defines.push_back(type_upper + "_TABLES");
break;
case GGML_TYPE_MXFP4:
case GGML_TYPE_NVFP4:
defines.push_back(type_upper + "_LUT");
break;
default:
@@ -2103,6 +2107,7 @@ class ggml_webgpu_shader_lib {
defines.push_back(type_upper + "_TABLES");
break;
case GGML_TYPE_MXFP4:
case GGML_TYPE_NVFP4:
defines.push_back(type_upper + "_LUT");
break;
default:
@@ -2274,6 +2279,7 @@ class ggml_webgpu_shader_lib {
defines.push_back(type_upper + "_TABLES");
break;
case GGML_TYPE_MXFP4:
case GGML_TYPE_NVFP4:
defines.push_back(type_upper + "_LUT");
break;
default:
@@ -2394,6 +2400,7 @@ class ggml_webgpu_shader_lib {
defines.push_back(type_upper + "_TABLES");
break;
case GGML_TYPE_MXFP4:
case GGML_TYPE_NVFP4:
defines.push_back(type_upper + "_LUT");
break;
default:
+3
View File
@@ -4056,6 +4056,7 @@ static bool ggml_webgpu_supported_qtype(ggml_type type) {
case GGML_TYPE_IQ4_NL:
case GGML_TYPE_IQ4_XS:
case GGML_TYPE_MXFP4:
case GGML_TYPE_NVFP4:
return true;
default:
return false;
@@ -4156,6 +4157,7 @@ static bool ggml_backend_webgpu_device_supports_op(ggml_backend_dev_t dev, const
case GGML_TYPE_IQ4_NL:
case GGML_TYPE_IQ4_XS:
case GGML_TYPE_MXFP4:
case GGML_TYPE_NVFP4:
supports_op = true;
break;
default:
@@ -4196,6 +4198,7 @@ static bool ggml_backend_webgpu_device_supports_op(ggml_backend_dev_t dev, const
case GGML_TYPE_IQ4_NL:
case GGML_TYPE_IQ4_XS:
case GGML_TYPE_MXFP4:
case GGML_TYPE_NVFP4:
supports_op = true;
break;
default:
@@ -896,9 +896,23 @@ const kvalues_iq4nl = array<i32, 16>(
#endif
#ifdef MXFP4_LUT
#if defined(MXFP4_LUT) || defined(NVFP4_LUT)
const kvalues_mxfp4 = array<i32, 16>(
0, 1, 2, 3, 4, 6, 8, 12, 0, -1, -2, -3, -4, -6, -8, -12
);
#endif
#endif // MXFP4_LUT || NVFP4_LUT
#ifdef NVFP4_LUT
fn ue4m3_to_fp32(u: u32) -> f32 {
if (u == 0u || u == 127u) {
return 0.0;
}
let exp = (u >> 3u) & 15u;
let man = u & 7u;
if (exp == 0u) {
return f32(man) * (1.0 / 512.0);
}
let bits = ((exp + 120u) << 23u) | (man << 20u);
return bitcast<f32>(bits);
}
#endif // NVFP4_LUT
@@ -672,6 +672,27 @@ fn copy_elements(src_base: u32, dst_base: u32, offset: u32) {
}
#endif
#ifdef NVFP4
fn copy_elements(src_base: u32, dst_base: u32, offset: u32) {
let block_byte_base = (src_base + offset) * 36;
let d_word = load_u32_at_src(block_byte_base);
for (var sub: u32 = 0u; sub < 4; sub++) {
let d = ue4m3_to_fp32(get_byte(d_word, sub)) * 0.5;
for (var j: u32 = 0u; j < 2; j++) {
let q_packed = load_u32_at_src(block_byte_base + 4 + sub * 8 + j * 4);
for (var k: u32 = 0; k < 4; k++) {
let q_byte = get_byte(q_packed, k);
let q_lo = f32(kvalues_mxfp4[q_byte & 0xFu]) * d;
let q_hi = f32(kvalues_mxfp4[(q_byte >> 4) & 0xF]) * d;
let dst_offset = dst_base + offset * 64 + sub * 16 + j * 4 + k;
dst[dst_offset] = q_lo;
dst[dst_offset + 8u] = q_hi;
}
}
}
}
#endif
@group(0) @binding(0)
var<storage, read_write> src: array<SRC_TYPE>;
@@ -241,7 +241,7 @@ fn init_shmem_src0(thread_id: u32, batch_offset: u32, offset_m: u32, k_outer: u3
#endif // INIT_SRC0_SHMEM_Q8_1
#if defined(INIT_SRC0_SHMEM_MXFP4)
let block_byte_base = src0_idx * 17u;
let block_byte_base = src0_idx * 17u; // BLOCK_SIZE_BYTES = 17u;
let eu8 = get_byte(load_u32_at_src0_aligned(block_byte_base), block_byte_base & 3u);
let e = ldexp(1.0, i32(eu8) - 128);
@@ -263,6 +263,47 @@ fn init_shmem_src0(thread_id: u32, batch_offset: u32, offset_m: u32, k_outer: u3
}
#endif // legacy-quants
#if defined(INIT_SRC0_SHMEM_NVFP4)
const BLOCK_SIZE = 64u;
const BLOCK_SIZE_BYTES = 36u;
const SUB_BLOCK_SIZE = 16u; // elements sharing one UE4M3 scale
const NQ = 16u;
const BYTES_PER_THREAD = 8u;
const BYTES_PER_INNER_LOOP = 4u;
fn init_shmem_src0(thread_id: u32, batch_offset: u32, offset_m: u32, k_outer: u32) {
for (var i = thread_id * NQ; i < TILE_SRC0_SHMEM; i += TOTAL_WORKGROUP_SIZE * NQ) {
let tile_m = i / TILE_K;
let tile_k_start = i % TILE_K;
let global_m = offset_m + tile_m;
let global_k_start = k_outer + tile_k_start;
if (global_m >= params.m) {
break;
}
let block_k = global_k_start / BLOCK_SIZE;
let sub_block = (global_k_start % BLOCK_SIZE) / SUB_BLOCK_SIZE;
let src0_idx = batch_offset + global_m * params.stride_01 + block_k;
let block_byte_base = src0_idx * BLOCK_SIZE_BYTES;
let d_byte_base = block_byte_base;
let qs_byte_base = block_byte_base + 4u;
let d = ue4m3_to_fp32(get_byte(load_u32_at_src0_aligned(d_byte_base), sub_block)) * 0.5;
for (var j = 0u; j < BYTES_PER_THREAD / BYTES_PER_INNER_LOOP; j++) {
let q_packed = load_u32_at_src0_aligned(qs_byte_base + sub_block * 8u + j * 4u);
for (var k = 0u; k < BYTES_PER_INNER_LOOP; k++) {
let q_byte = get_byte(q_packed, k);
shmem[i + j * BYTES_PER_INNER_LOOP + k] = f16(f32(kvalues_mxfp4[q_byte & 0xF]) * d);
shmem[i + j * BYTES_PER_INNER_LOOP + k + 8u] = f16(f32(kvalues_mxfp4[(q_byte >> 4) & 0xF]) * d);
}
}
}
}
#endif // INIT_SRC0_SHMEM_NVFP4
// k-quants
#if defined(INIT_SRC0_SHMEM_Q2_K) || defined(INIT_SRC0_SHMEM_Q3_K) || defined(INIT_SRC0_SHMEM_Q4_K) || defined(INIT_SRC0_SHMEM_Q5_K) || defined(INIT_SRC0_SHMEM_Q6_K)
const BLOCK_SIZE = 256u;
@@ -1505,3 +1505,49 @@ fn accumulate_vec_dot(thread_id: u32, row_base: u32, src0_batch_offset: u32, src
return acc;
}
#endif
#ifdef MUL_ACC_NVFP4
#define BLOCK_SIZE 64
#define BLOCK_SIZE_BYTES 36
#define THREADS_PER_BLOCK 4
#define ELEMS_PER_THREAD (BLOCK_SIZE/THREADS_PER_BLOCK)
fn accumulate_vec_dot(thread_id: u32, row_base: u32, src0_batch_offset: u32, src1_idx_base: u32) -> array<array<f32, OUTPUTS_PER_WG>, NUM_COLS> {
var acc: array<array<f32, OUTPUTS_PER_WG>, NUM_COLS>;
let num_blocks = params.k / BLOCK_SIZE;
let sub = thread_id % THREADS_PER_BLOCK;
for (var block = thread_id/THREADS_PER_BLOCK; block < num_blocks; block += WG_SIZE/THREADS_PER_BLOCK) {
let x_base = src1_idx_base + block * BLOCK_SIZE + sub * ELEMS_PER_THREAD;
var x_block: array<array<f32, ELEMS_PER_THREAD>, NUM_COLS>;
for (var col = 0u; col < NUM_COLS;col += 1) {
for (var i = 0u; i < ELEMS_PER_THREAD / 2; i++) {
x_block[col][i] = f32(src1[x_base + col * params.stride_11 + i]);
x_block[col][i + 8] = f32(src1[x_base + col * params.stride_11 + i + 8]);
}
}
for (var row = 0u; row < OUTPUTS_PER_WG; row++) {
let output_row = row_base + row;
if (output_row < params.m) {
let block_byte_base = (src0_batch_offset + output_row * params.stride_01 + block) * BLOCK_SIZE_BYTES;
let d = ue4m3_to_fp32(get_byte(load_u32_at_src0_aligned(block_byte_base), sub)) * 0.5;
let q_w0 = load_u32_at_src0_aligned(block_byte_base + 4u + 8u * sub);
let q_w1 = load_u32_at_src0_aligned(block_byte_base + 8u + 8u * sub);
for (var col = 0u;col < NUM_COLS;col += 1) {
var row_sum = 0.0;
for (var l = 0u; l < 8u; l++) {
let q_word = select(q_w0, q_w1, l >= 4u);
let q_byte = get_byte(q_word, l % 4u);
let q_lo = f32(kvalues_mxfp4[q_byte & 0xFu]) * d;
let q_hi = f32(kvalues_mxfp4[(q_byte >> 4u) & 0xFu]) * d;
row_sum += q_lo * x_block[col][l];
row_sum += q_hi * x_block[col][l + 8u];
}
acc[col][row] += row_sum;
}
}
}
}
return acc;
}
#endif
+121 -2
View File
@@ -145,6 +145,7 @@ class Keys:
TOKEN_SHIFT_COUNT = "{arch}.token_shift_count"
INTERLEAVE_MOE_LAYER_STEP = "{arch}.interleave_moe_layer_step"
FULL_ATTENTION_INTERVAL = "{arch}.full_attention_interval"
HASH_LAYER_COUNT = "{arch}.hash_layer_count"
ACTIVATION_SPARSITY_SCALE = "{arch}.activation_sparsity_scale"
ALTUP_ACTIVE_IDX = "{arch}.altup.active_idx"
ALTUP_NUM_INPUTS = "{arch}.altup.num_inputs"
@@ -156,6 +157,7 @@ class Keys:
DENSE_FEAT_OUT_SIZE = "{arch}.{dense}_feat_out"
TARGET_LAYERS = "{arch}.target_layers"
TARGET_HIDDEN_SIZE = "{arch}.target_hidden_size"
BLOCK_SIZE = "{arch}.block_size"
NORM_BEFORE_RESIDUAL = "{arch}.norm_before_residual"
class Attention:
@@ -179,8 +181,12 @@ class Keys:
REL_BUCKETS_COUNT = "{arch}.attention.relative_buckets_count"
SLIDING_WINDOW = "{arch}.attention.sliding_window"
SCALE = "{arch}.attention.scale"
OUTPUT_GROUP_COUNT = "{arch}.attention.output_group_count"
OUTPUT_LORA_RANK = "{arch}.attention.output_lora_rank"
OUTPUT_SCALE = "{arch}.attention.output_scale"
VALUE_SCALE = "{arch}.attention.value_scale"
COMPRESS_RATIOS = "{arch}.attention.compress_ratios"
COMPRESS_ROPE_FREQ_BASE = "{arch}.attention.compress_rope_freq_base"
TEMPERATURE_LENGTH = "{arch}.attention.temperature_length"
KEY_LENGTH_MLA = "{arch}.attention.key_length_mla"
VALUE_LENGTH_MLA = "{arch}.attention.value_length_mla"
@@ -195,6 +201,11 @@ class Keys:
KEY_LENGTH = "{arch}.attention.indexer.key_length"
TOP_K = "{arch}.attention.indexer.top_k"
class HyperConnection:
COUNT = "{arch}.hyper_connection.count"
SINKHORN_ITERATIONS = "{arch}.hyper_connection.sinkhorn_iterations"
EPSILON = "{arch}.hyper_connection.epsilon"
class Rope:
DIMENSION_COUNT = "{arch}.rope.dimension_count"
DIMENSION_COUNT_SWA = "{arch}.rope.dimension_count_swa"
@@ -469,6 +480,7 @@ class MODEL_ARCH(IntEnum):
DEEPSEEK2 = auto()
DEEPSEEK2OCR = auto()
DEEPSEEK32 = auto()
DEEPSEEK4 = auto()
CHATGLM = auto()
GLM4 = auto()
GLM4_MOE = auto()
@@ -517,6 +529,7 @@ class MODEL_ARCH(IntEnum):
PANGU_EMBED = auto()
MISTRAL3 = auto()
EAGLE3 = auto()
DFLASH = auto()
MISTRAL4 = auto()
PADDLEOCR = auto()
MIMO2 = auto()
@@ -553,6 +566,9 @@ class MODEL_TENSOR(IntEnum):
DENSE_2_OUT = auto() # embeddinggemma 2_Dense
DENSE_3_OUT = auto() # embeddinggemma 3_Dense
OUTPUT_NORM = auto()
HC_HEAD_FN = auto()
HC_HEAD_BASE = auto()
HC_HEAD_SCALE = auto()
ROPE_FREQS = auto()
ROPE_FACTORS_LONG = auto()
ROPE_FACTORS_SHORT = auto()
@@ -592,6 +608,7 @@ class MODEL_TENSOR(IntEnum):
FFN_DOWN_CHEXP = auto()
FFN_UP_CHEXP = auto()
FFN_EXP_PROBS_B = auto()
FFN_GATE_TID2EID = auto()
MOE_LATENT_DOWN = auto() # nemotron 3 super
MOE_LATENT_UP = auto() # nemotron 3 super
ATTN_Q_NORM = auto()
@@ -679,6 +696,20 @@ class MODEL_TENSOR(IntEnum):
ATTN_V_B = auto()
ATTN_Q_A_NORM = auto()
ATTN_KV_A_NORM = auto()
ATTN_KV = auto()
ATTN_KV_NORM = auto()
ATTN_OUT_A = auto()
ATTN_OUT_B = auto()
HC_ATTN_FN = auto()
HC_ATTN_BASE = auto()
HC_ATTN_SCALE = auto()
HC_FFN_FN = auto()
HC_FFN_BASE = auto()
HC_FFN_SCALE = auto()
ATTN_COMPRESSOR_WKV = auto()
ATTN_COMPRESSOR_WGATE = auto()
ATTN_COMPRESSOR_APE = auto()
ATTN_COMPRESSOR_NORM = auto()
FFN_SUB_NORM = auto()
ATTN_SUB_NORM = auto()
DEC_ATTN_NORM = auto()
@@ -740,6 +771,10 @@ class MODEL_TENSOR(IntEnum):
INDEXER_PROJ = auto()
INDEXER_ATTN_K = auto()
INDEXER_ATTN_Q_B = auto()
INDEXER_COMPRESSOR_WKV = auto()
INDEXER_COMPRESSOR_WGATE = auto()
INDEXER_COMPRESSOR_APE = auto()
INDEXER_COMPRESSOR_NORM = auto()
# vision
V_MMPROJ = auto()
V_MMPROJ_FC = auto()
@@ -1025,6 +1060,7 @@ MODEL_ARCH_NAMES: dict[MODEL_ARCH, str] = {
MODEL_ARCH.DEEPSEEK2: "deepseek2",
MODEL_ARCH.DEEPSEEK2OCR: "deepseek2-ocr",
MODEL_ARCH.DEEPSEEK32: "deepseek32",
MODEL_ARCH.DEEPSEEK4: "deepseek4",
MODEL_ARCH.CHATGLM: "chatglm",
MODEL_ARCH.GLM4: "glm4",
MODEL_ARCH.GLM4_MOE: "glm4moe",
@@ -1074,6 +1110,7 @@ MODEL_ARCH_NAMES: dict[MODEL_ARCH, str] = {
MODEL_ARCH.PANGU_EMBED: "pangu-embedded",
MODEL_ARCH.MISTRAL3: "mistral3",
MODEL_ARCH.EAGLE3: "eagle3",
MODEL_ARCH.DFLASH: "dflash",
MODEL_ARCH.MISTRAL4: "mistral4",
MODEL_ARCH.PADDLEOCR: "paddleocr",
MODEL_ARCH.MIMO2: "mimo2",
@@ -1108,6 +1145,9 @@ TENSOR_NAMES: dict[MODEL_TENSOR, str] = {
MODEL_TENSOR.OUTPUT: "output",
MODEL_TENSOR.DENSE_2_OUT: "dense_2", # embeddinggemma 2_Dense
MODEL_TENSOR.DENSE_3_OUT: "dense_3", # embeddinggemma 2_Dense
MODEL_TENSOR.HC_HEAD_FN: "output_hc_fn",
MODEL_TENSOR.HC_HEAD_BASE: "output_hc_base",
MODEL_TENSOR.HC_HEAD_SCALE: "output_hc_scale",
MODEL_TENSOR.ROPE_FREQS: "rope_freqs",
MODEL_TENSOR.ROPE_FACTORS_LONG: "rope_factors_long",
MODEL_TENSOR.ROPE_FACTORS_SHORT: "rope_factors_short",
@@ -1149,6 +1189,7 @@ TENSOR_NAMES: dict[MODEL_TENSOR, str] = {
MODEL_TENSOR.FFN_UP_EXP: "blk.{bid}.ffn_up_exps",
MODEL_TENSOR.FFN_GATE_UP_EXP: "blk.{bid}.ffn_gate_up_exps",
MODEL_TENSOR.FFN_EXP_PROBS_B: "blk.{bid}.exp_probs_b",
MODEL_TENSOR.FFN_GATE_TID2EID: "blk.{bid}.ffn_gate_tid2eid",
MODEL_TENSOR.MOE_LATENT_DOWN: "blk.{bid}.ffn_latent_down", # nemotron 3 super
MODEL_TENSOR.MOE_LATENT_UP: "blk.{bid}.ffn_latent_up", # nemotron 3 super
MODEL_TENSOR.LAYER_OUT_NORM: "blk.{bid}.layer_output_norm",
@@ -1234,6 +1275,20 @@ TENSOR_NAMES: dict[MODEL_TENSOR, str] = {
MODEL_TENSOR.ATTN_V_B: "blk.{bid}.attn_v_b",
MODEL_TENSOR.ATTN_Q_A_NORM: "blk.{bid}.attn_q_a_norm",
MODEL_TENSOR.ATTN_KV_A_NORM: "blk.{bid}.attn_kv_a_norm",
MODEL_TENSOR.ATTN_KV: "blk.{bid}.attn_kv",
MODEL_TENSOR.ATTN_KV_NORM: "blk.{bid}.attn_kv_a_norm",
MODEL_TENSOR.ATTN_OUT_A: "blk.{bid}.attn_output_a",
MODEL_TENSOR.ATTN_OUT_B: "blk.{bid}.attn_output_b",
MODEL_TENSOR.HC_ATTN_FN: "blk.{bid}.hc_attn_fn",
MODEL_TENSOR.HC_ATTN_BASE: "blk.{bid}.hc_attn_base",
MODEL_TENSOR.HC_ATTN_SCALE: "blk.{bid}.hc_attn_scale",
MODEL_TENSOR.HC_FFN_FN: "blk.{bid}.hc_ffn_fn",
MODEL_TENSOR.HC_FFN_BASE: "blk.{bid}.hc_ffn_base",
MODEL_TENSOR.HC_FFN_SCALE: "blk.{bid}.hc_ffn_scale",
MODEL_TENSOR.ATTN_COMPRESSOR_WKV: "blk.{bid}.attn_compressor_kv",
MODEL_TENSOR.ATTN_COMPRESSOR_WGATE: "blk.{bid}.attn_compressor_gate",
MODEL_TENSOR.ATTN_COMPRESSOR_APE: "blk.{bid}.attn_compressor_ape",
MODEL_TENSOR.ATTN_COMPRESSOR_NORM: "blk.{bid}.attn_compressor_norm",
MODEL_TENSOR.ATTN_SUB_NORM: "blk.{bid}.attn_sub_norm",
MODEL_TENSOR.FFN_SUB_NORM: "blk.{bid}.ffn_sub_norm",
MODEL_TENSOR.DEC_ATTN_NORM: "dec.blk.{bid}.attn_norm",
@@ -1295,6 +1350,10 @@ TENSOR_NAMES: dict[MODEL_TENSOR, str] = {
MODEL_TENSOR.INDEXER_PROJ: "blk.{bid}.indexer.proj",
MODEL_TENSOR.INDEXER_ATTN_K: "blk.{bid}.indexer.attn_k",
MODEL_TENSOR.INDEXER_ATTN_Q_B: "blk.{bid}.indexer.attn_q_b",
MODEL_TENSOR.INDEXER_COMPRESSOR_WKV: "blk.{bid}.indexer_compressor_kv",
MODEL_TENSOR.INDEXER_COMPRESSOR_WGATE: "blk.{bid}.indexer_compressor_gate",
MODEL_TENSOR.INDEXER_COMPRESSOR_APE: "blk.{bid}.indexer_compressor_ape",
MODEL_TENSOR.INDEXER_COMPRESSOR_NORM: "blk.{bid}.indexer_compressor_norm",
# vision
MODEL_TENSOR.V_MMPROJ: "mm.{bid}",
MODEL_TENSOR.V_MMPROJ_FC: "mm.model.fc",
@@ -3135,6 +3194,49 @@ MODEL_TENSORS: dict[MODEL_ARCH, list[MODEL_TENSOR]] = {
MODEL_TENSOR.NEXTN_SHARED_HEAD_HEAD,
MODEL_TENSOR.NEXTN_SHARED_HEAD_NORM,
],
MODEL_ARCH.DEEPSEEK4: [
MODEL_TENSOR.TOKEN_EMBD,
MODEL_TENSOR.OUTPUT_NORM,
MODEL_TENSOR.OUTPUT,
MODEL_TENSOR.HC_HEAD_FN,
MODEL_TENSOR.HC_HEAD_BASE,
MODEL_TENSOR.HC_HEAD_SCALE,
MODEL_TENSOR.ATTN_NORM,
MODEL_TENSOR.ATTN_SINKS,
MODEL_TENSOR.ATTN_Q_A,
MODEL_TENSOR.ATTN_Q_B,
MODEL_TENSOR.ATTN_Q_A_NORM,
MODEL_TENSOR.ATTN_KV,
MODEL_TENSOR.ATTN_KV_NORM,
MODEL_TENSOR.ATTN_OUT_A,
MODEL_TENSOR.ATTN_OUT_B,
MODEL_TENSOR.HC_ATTN_FN,
MODEL_TENSOR.HC_ATTN_BASE,
MODEL_TENSOR.HC_ATTN_SCALE,
MODEL_TENSOR.HC_FFN_FN,
MODEL_TENSOR.HC_FFN_BASE,
MODEL_TENSOR.HC_FFN_SCALE,
MODEL_TENSOR.ATTN_COMPRESSOR_WKV,
MODEL_TENSOR.ATTN_COMPRESSOR_WGATE,
MODEL_TENSOR.ATTN_COMPRESSOR_APE,
MODEL_TENSOR.ATTN_COMPRESSOR_NORM,
MODEL_TENSOR.INDEXER_PROJ,
MODEL_TENSOR.INDEXER_ATTN_Q_B,
MODEL_TENSOR.INDEXER_COMPRESSOR_WKV,
MODEL_TENSOR.INDEXER_COMPRESSOR_WGATE,
MODEL_TENSOR.INDEXER_COMPRESSOR_APE,
MODEL_TENSOR.INDEXER_COMPRESSOR_NORM,
MODEL_TENSOR.FFN_GATE_INP,
MODEL_TENSOR.FFN_GATE_TID2EID,
MODEL_TENSOR.FFN_EXP_PROBS_B,
MODEL_TENSOR.FFN_NORM,
MODEL_TENSOR.FFN_GATE_EXP,
MODEL_TENSOR.FFN_DOWN_EXP,
MODEL_TENSOR.FFN_UP_EXP,
MODEL_TENSOR.FFN_GATE_SHEXP,
MODEL_TENSOR.FFN_DOWN_SHEXP,
MODEL_TENSOR.FFN_UP_SHEXP,
],
MODEL_ARCH.ERNIE4_5_MOE: [
MODEL_TENSOR.TOKEN_EMBD,
MODEL_TENSOR.OUTPUT_NORM,
@@ -4086,6 +4188,22 @@ MODEL_TENSORS: dict[MODEL_ARCH, list[MODEL_TENSOR]] = {
MODEL_TENSOR.FC,
MODEL_TENSOR.D2T,
],
MODEL_ARCH.DFLASH: [
MODEL_TENSOR.OUTPUT_NORM,
MODEL_TENSOR.ATTN_NORM,
MODEL_TENSOR.ATTN_Q,
MODEL_TENSOR.ATTN_K,
MODEL_TENSOR.ATTN_V,
MODEL_TENSOR.ATTN_OUT,
MODEL_TENSOR.ATTN_Q_NORM,
MODEL_TENSOR.ATTN_K_NORM,
MODEL_TENSOR.FFN_NORM,
MODEL_TENSOR.FFN_GATE,
MODEL_TENSOR.FFN_DOWN,
MODEL_TENSOR.FFN_UP,
MODEL_TENSOR.FC,
MODEL_TENSOR.ENC_OUTPUT_NORM,
],
MODEL_ARCH.MISTRAL4: [
MODEL_TENSOR.TOKEN_EMBD,
MODEL_TENSOR.OUTPUT_NORM,
@@ -4418,8 +4536,9 @@ class GGMLQuantizationType(IntEnum):
class ExpertGatingFuncType(IntEnum):
SOFTMAX = 1
SIGMOID = 2
SOFTMAX = 1
SIGMOID = 2
SQRTSOFTPLUS = 4
# TODO: add GGMLFileType from ggml_ftype in ggml.h
+36
View File
@@ -715,6 +715,9 @@ class GGUFWriter:
def add_full_attention_interval(self, interval: int) -> None:
self.add_uint32(Keys.LLM.FULL_ATTENTION_INTERVAL.format(arch=self.arch), interval)
def add_hash_layer_count(self, count: int) -> None:
self.add_uint32(Keys.LLM.HASH_LAYER_COUNT.format(arch=self.arch), count)
def add_feed_forward_length(self, length: int | Sequence[int]) -> None:
if isinstance(length, int):
self.add_uint32(Keys.LLM.FEED_FORWARD_LENGTH.format(arch=self.arch), length)
@@ -940,6 +943,39 @@ class GGUFWriter:
def add_sliding_window(self, value: int) -> None:
self.add_uint32(Keys.Attention.SLIDING_WINDOW.format(arch=self.arch), value)
def add_block_size(self, value: int) -> None:
self.add_uint32(Keys.LLM.BLOCK_SIZE.format(arch=self.arch), value)
def add_target_layers(self, value: Sequence[int]) -> None:
self.add_array(Keys.LLM.TARGET_LAYERS.format(arch=self.arch), value)
def add_target_hidden_size(self, value: int) -> None:
self.add_uint32(Keys.LLM.TARGET_HIDDEN_SIZE.format(arch=self.arch), value)
def add_norm_before_residual(self, value: bool) -> None:
self.add_bool(Keys.LLM.NORM_BEFORE_RESIDUAL.format(arch=self.arch), value)
def add_attention_output_group_count(self, count: int) -> None:
self.add_uint32(Keys.Attention.OUTPUT_GROUP_COUNT.format(arch=self.arch), count)
def add_attention_output_lora_rank(self, length: int) -> None:
self.add_uint32(Keys.Attention.OUTPUT_LORA_RANK.format(arch=self.arch), length)
def add_attention_compress_ratios(self, values: Sequence[int]) -> None:
self.add_array(Keys.Attention.COMPRESS_RATIOS.format(arch=self.arch), values)
def add_attention_compress_rope_freq_base(self, value: float) -> None:
self.add_float32(Keys.Attention.COMPRESS_ROPE_FREQ_BASE.format(arch=self.arch), value)
def add_hyper_connection_count(self, count: int) -> None:
self.add_uint32(Keys.HyperConnection.COUNT.format(arch=self.arch), count)
def add_hyper_connection_sinkhorn_iterations(self, count: int) -> None:
self.add_uint32(Keys.HyperConnection.SINKHORN_ITERATIONS.format(arch=self.arch), count)
def add_hyper_connection_epsilon(self, value: float) -> None:
self.add_float32(Keys.HyperConnection.EPSILON.format(arch=self.arch), value)
def add_attention_scale(self, value: float) -> None:
self.add_float32(Keys.Attention.SCALE.format(arch=self.arch), value)
+5
View File
@@ -1283,6 +1283,11 @@ class TensorNameMap:
MODEL_TENSOR.ENC_OUTPUT_NORM: (
"encoder.final_layer_norm", # t5
"layer_norm", # neobert
"model.hidden_norm", # dflash
),
MODEL_TENSOR.FC: (
"model.fc", # dflash
),
MODEL_TENSOR.CLS: (
@@ -0,0 +1,112 @@
{%- if not add_generation_prompt is defined -%}
{%- set add_generation_prompt = false -%}
{%- endif -%}
{%- if not thinking is defined -%}
{%- if enable_thinking is defined -%}
{%- set thinking = enable_thinking -%}
{%- else -%}
{%- set thinking = false -%}
{%- endif -%}
{%- endif -%}
{%- set dsml_token = 'DSML' -%}
{%- set thinking_start_token = '<think>' -%}
{%- set thinking_end_token = '</think>' -%}
{%- set tools_header = '## Tools\n\nYou have access to a set of tools to help answer the user\'s question. You can invoke tools by writing a "<' + dsml_token + 'tool_calls>" block like the following:\n\n<' + dsml_token + 'tool_calls>\n<' + dsml_token + 'invoke name="$TOOL_NAME">\n<' + dsml_token + 'parameter name="$PARAMETER_NAME" string="true|false">$PARAMETER_VALUE</' + dsml_token + 'parameter>\n...\n</' + dsml_token + 'invoke>\n<' + dsml_token + 'invoke name="$TOOL_NAME2">\n...\n</' + dsml_token + 'invoke>\n</' + dsml_token + 'tool_calls>\n\nString parameters should be specified as is and set `string="true"`. For all other types (numbers, booleans, arrays, objects), pass the value in JSON format and set `string="false"`.\n\nIf thinking_mode is enabled (triggered by ' + thinking_start_token + '), you MUST output your complete reasoning inside ' + thinking_start_token + '...' + thinking_end_token + ' BEFORE any tool calls or final response.\n\nOtherwise, output directly after ' + thinking_end_token + ' with tool calls or final response.\n\n### Available Tool Schemas\n\n' -%}
{%- set tools_footer = '\nYou MUST strictly follow the above defined tool name and parameter schemas to invoke tool calls.\n' -%}
{%- set ns = namespace(system_prompt='', is_first_sp=true) -%}
{%- for message in messages -%}
{%- if message['role'] == 'system' -%}
{%- if ns.is_first_sp -%}
{%- set ns.system_prompt = ns.system_prompt + (message['content'] or '') -%}
{%- set ns.is_first_sp = false -%}
{%- else -%}
{%- set ns.system_prompt = ns.system_prompt + '\n\n' + (message['content'] or '') -%}
{%- endif -%}
{%- endif -%}
{%- endfor -%}
{%- if tools is defined and tools -%}
{%- set ts = namespace(schemas='') -%}
{%- for tool in tools -%}
{%- if tool['type'] == 'function' -%}
{%- set ts.schemas = ts.schemas + (tool['function'] | tojson) + '\n' -%}
{%- endif -%}
{%- endfor -%}
{%- if ns.system_prompt -%}
{%- set ns.system_prompt = ns.system_prompt + '\n\n' + tools_header + ts.schemas + tools_footer -%}
{%- else -%}
{%- set ns.system_prompt = tools_header + ts.schemas + tools_footer -%}
{%- endif -%}
{%- endif -%}
{{- bos_token -}}
{{- ns.system_prompt -}}
{%- set last_user_idx = namespace(value=-1) -%}
{%- for message in messages -%}
{%- if message['role'] == 'user' or message['role'] == 'developer' or message['role'] == 'tool' -%}
{%- set last_user_idx.value = loop.index0 -%}
{%- endif -%}
{%- endfor -%}
{%- set state = namespace(in_user=false) -%}
{%- for message in messages -%}
{%- if message['role'] == 'user' or message['role'] == 'developer' -%}
{%- if state.in_user -%}
{{- '\n\n' -}}
{%- else -%}
{{- '<User>' -}}
{%- set state.in_user = true -%}
{%- endif -%}
{{- message['content'] or '' -}}
{%- elif message['role'] == 'tool' -%}
{%- if state.in_user -%}
{{- '\n\n' -}}
{%- else -%}
{{- '<User>' -}}
{%- set state.in_user = true -%}
{%- endif -%}
{{- '<tool_result>' + (message['content'] or '') + '</tool_result>' -}}
{%- elif message['role'] == 'assistant' -%}
{%- set state.in_user = false -%}
{{- '<Assistant>' -}}
{%- set is_after_last_user = loop.index0 > last_user_idx.value -%}
{%- if is_after_last_user and thinking -%}
{{- thinking_start_token -}}
{%- if message['reasoning_content'] is defined and message['reasoning_content'] -%}
{{- message['reasoning_content'] -}}
{%- endif -%}
{{- thinking_end_token -}}
{%- else -%}
{{- thinking_end_token -}}
{%- endif -%}
{%- if message['content'] is defined and message['content'] -%}
{{- message['content'] -}}
{%- endif -%}
{%- if message['tool_calls'] -%}
{{- '\n\n<' + dsml_token + 'tool_calls>\n' -}}
{%- for tool in message['tool_calls'] -%}
{%- set func = tool['function'] -%}
{{- '<' + dsml_token + 'invoke name="' + func['name'] + '">\n' -}}
{%- set args = func['arguments'] -%}
{%- if args is string -%}
{%- set args = args | from_json -%}
{%- endif -%}
{%- for key, val in args.items() -%}
{%- if val is string -%}
{{- '<' + dsml_token + 'parameter name="' + key + '" string="true">' + val + '</' + dsml_token + 'parameter>\n' -}}
{%- else -%}
{{- '<' + dsml_token + 'parameter name="' + key + '" string="false">' + (val | tojson) + '</' + dsml_token + 'parameter>\n' -}}
{%- endif -%}
{%- endfor -%}
{{- '</' + dsml_token + 'invoke>\n' -}}
{%- endfor -%}
{{- '</' + dsml_token + 'tool_calls>' -}}
{%- endif -%}
{{- '<end▁of▁sentence>' -}}
{%- endif -%}
{%- endfor -%}
{%- if add_generation_prompt -%}
{{- '<Assistant>' -}}
{%- if thinking -%}
{{- thinking_start_token -}}
{%- else -%}
{{- thinking_end_token -}}
{%- endif -%}
{%- endif -%}
+179
View File
@@ -0,0 +1,179 @@
{{- bos_token }}{%- if tools %}
{%- set tool_definitions %}
{{- "# Tools\n\nYou are provided with function signatures within <tools></tools> XML tags:\n<tools>" }}
{%- for tool in tools %}
{{- "\n" }}
{{- tool | tojson(ensure_ascii=False) }}
{%- endfor %}
{{- '\n</tools>\n\nTool usage guidelines:\n- You may call zero or more functions. If no function calls are needed, just answer normally and do not include any <function ... </function>.\n- When calling a function, return an XML object within <function ... </function> using:\n<function name="function-name"><param name="param-name">param-value</param></function>\n- param-value may be multi-line. If it contains <, & or newline characters, wrap it in a CDATA block: <param name="param-name"><![CDATA[...multi-line value...]]></param>' }}
{%- endset %}
{{- '<|im_start|>system\n' }}
{%- if messages[0].role == 'system' %}
{%- if '<tool_def_sep>' in messages[0].content %}
{{- messages[0].content.replace('<tool_def_sep>', tool_definitions) }}
{%- else %}
{{- messages[0].content + '\n\n' + tool_definitions }}
{%- endif %}
{%- else %}
{{- tool_definitions.lstrip() }}
{%- endif %}
{{- '<|im_end|>\n' }}
{%- else %}
{%- if messages[0].role == 'system' %}
{{- '<|im_start|>system\n' + messages[0].content + '<|im_end|>\n' }}
{%- endif %}
{%- endif %}
{%- set ns = namespace(multi_step_tool=true, last_query_index=messages|length - 1) %}
{%- for message in messages[::-1] %}
{%- set index = (messages|length - 1) - loop.index0 %}
{%- if ns.multi_step_tool and message.role == "user" and message.content is string and not(message.content.startswith('<tool_response>') and message.content.endswith('</tool_response>')) %}
{%- set ns.multi_step_tool = false %}
{%- set ns.last_query_index = index %}
{%- endif %}
{%- endfor %}
{%- for message in messages %}
{%- if message.content is string %}
{%- set content = message.content %}
{%- else %}
{%- set content = '' %}
{%- endif %}
{%- if (message.role == "user") or (message.role == "system" and not loop.first) %}
{{- '<|im_start|>' + message.role + '\n' + content + '<|im_end|>' + '\n' }}
{%- elif message.role == "assistant" %}
{%- set reasoning_content = '' %}
{%- if message.reasoning_content is string %}
{%- set reasoning_content = message.reasoning_content %}
{%- else %}
{%- if '</think>' in content %}
{%- set reasoning_content = content.split('</think>')[0].rstrip('\n').split('<think>')[-1].lstrip('\n') %}
{%- set content = content.split('</think>')[-1].lstrip('\n') %}
{%- endif %}
{%- endif %}
{%- if message.tool_calls %}
{%- set content_parts = content.split('<tool_sep>') %}
{%- set processed_content = content_parts[0] %}
{%- set tool_calls_count = message.tool_calls|length %}
{%- set tool_sep_count = content_parts|length - 1 %}
{%- set min_count = [tool_calls_count, tool_sep_count]|min %}
{%- for i in range(1, content_parts|length) %}
{%- set tool_index = i - 1 %}
{%- if tool_index < tool_calls_count %}
{%- set tool_call = message.tool_calls[tool_index] %}
{%- if tool_call.function %}
{%- set tool_call = tool_call.function %}
{%- endif %}
{%- set single_tool_xml %}
{{- '<function name="' ~ tool_call.name ~ '">' }}
{%- if tool_call.arguments %}
{%- set args_dict = tool_call.arguments %}
{%- for param_name, param_value in args_dict.items() %}
{{- '<param name="' ~ param_name ~ '">' }}
{%- if param_value is string and ('<' in param_value or '&' in param_value or '\n' in param_value) %}
{{- '<![CDATA[' + param_value + ']]>' }}
{%- else %}
{{- param_value }}
{%- endif %}
{{- '</param>' }}
{%- endfor %}
{%- endif %}
{{- '</function>' }}
{%- endset %}
{%- set processed_content = processed_content + single_tool_xml + content_parts[i] %}
{%- else %}
{%- set processed_content = processed_content + content_parts[i] %}
{%- endif %}
{%- endfor %}
{%- if tool_calls_count > tool_sep_count %}
{%- for remaining_index in range(tool_sep_count, tool_calls_count) %}
{%- set tool_call = message.tool_calls[remaining_index] %}
{%- if tool_call.function %}
{%- set tool_call = tool_call.function %}
{%- endif %}
{%- set remaining_tool_xml %}
{{- '<function name="' ~ tool_call.name ~ '">' }}
{%- if tool_call.arguments %}
{%- set args_dict = tool_call.arguments %}
{%- for param_name, param_value in args_dict.items() %}
{{- '<param name="' ~ param_name ~ '">' }}
{%- if param_value is string and ('<' in param_value or '&' in param_value or '\n' in param_value) %}
{{- '<![CDATA[' + param_value + ']]>' }}
{%- else %}
{{- param_value }}
{%- endif %}
{{- '</param>' }}
{%- endfor %}
{%- endif %}
{{- '</function>' }}
{%- endset %}
{%- set processed_content = processed_content + remaining_tool_xml %}
{%- endfor %}
{%- endif %}
{%- set content = processed_content %}
{%- endif %}
{%- if loop.index0 > ns.last_query_index %}
{%- if reasoning_content %}
{{- '<|im_start|>' + message.role + '\n<think>\n' + reasoning_content.strip('\n') + '\n</think>\n\n' + content.lstrip('\n') }}
{%- else %}
{{- '<|im_start|>' + message.role + '\n' + content }}
{%- endif %}
{%- else %}
{{- '<|im_start|>' + message.role + '\n' + content }}
{%- endif %}
{%- if message.tool_calls and not has_tool_sep %}
{%- for tool_call in message.tool_calls %}
{%- if (loop.first and content) or (not loop.first) %}
{{- '\n' }}
{%- endif %}
{%- if tool_call.function %}
{%- set tool_call = tool_call.function %}
{%- endif %}
{{- '<function name="' ~ tool_call.name ~ '">' }}
{%- if tool_call.arguments %}
{%- set args_dict = tool_call.arguments %}
{%- for param_name, param_value in args_dict.items() %}
{{- '<param name="' ~ param_name ~ '">' }}
{%- if param_value is string and ('<' in param_value or '&' in param_value or '\n' in param_value) %}
{{- '<![CDATA[' + param_value + ']]>' }}
{%- else %}
{{- param_value }}
{%- endif %}
{{- '</param>' }}
{%- endfor %}
{%- endif %}
{{- '</function>' }}
{%- endfor %}
{%- endif %}
{{- '<|im_end|>\n' }}
{%- elif message.role == "tool" %}
{%- if loop.first or (messages[loop.index0 - 1].role != "tool") %}
{{- '<|im_start|>user' }}
{%- endif %}
{{- '\n<tool_response>\n' }}
{%- if message.content is string %}
{{- content }}
{%- else %}
{{- message.content | tojson(ensure_ascii=False) }}
{%- endif %}
{{- '\n</tool_response>' }}
{%- if loop.last or (messages[loop.index0 + 1].role != "tool") %}
{{- '<|im_end|>\n' }}
{%- endif %}
{%- endif %}
{%- endfor %}
{%- if add_generation_prompt %}
{{- '<|im_start|>assistant\n' }}
{%- if enable_thinking is defined %}
{%- if enable_thinking is false %}
{{- '<think>\n\n</think>\n\n' }}
{%- elif enable_thinking is true %}
{{- '<think>\n' }}
{%- endif %}
{%- endif %}
{%- endif %}
+4 -1
View File
@@ -69,13 +69,16 @@ mbuf=
mmsel=
[ "$MM" != "" ] && mmsel="GGML_HEXAGON_MM_SELECT=$MM"
fasel=
[ "$FA" != "" ] && fasel="GGML_HEXAGON_FA_SELECT=$FA"
set -x
adb $adbserial $adbhost shell " \
cd $basedir; ulimit -c unlimited; \
LD_LIBRARY_PATH=$basedir/$branch/lib \
ADSP_LIBRARY_PATH=$basedir/$branch/lib \
$verbose $sched $opmask $profile $nhvx $hmx $ndev $hb $opbatch $opqueue $oppoll $opflt $opfuse $vmem $mbuf $mmsel \
$verbose $sched $opmask $profile $nhvx $hmx $ndev $hb $opbatch $opqueue $oppoll $opflt $opfuse $vmem $mbuf $mmsel $fasel \
./$branch/bin/llama-completion --no-mmap -m $basedir/../gguf/$model \
--poll 1000 -t 6 --cpu-mask 0xfc --cpu-strict 1 \
--ctx-size 8192 --ubatch-size 1024 -fa on \
+4 -1
View File
@@ -57,6 +57,9 @@ opfuse=
mmsel=
[ "$MM" != "" ] && mmsel="GGML_HEXAGON_MM_SELECT=$MM"
fasel=
[ "$FA" != "" ] && fasel="GGML_HEXAGON_FA_SELECT=$FA"
set -x
tool=$1; shift
@@ -65,5 +68,5 @@ adb $adbserial $adbhost shell " \
cd $basedir; ulimit -c unlimited; \
LD_LIBRARY_PATH=$basedir/$branch/lib \
ADSP_LIBRARY_PATH=$basedir/$branch/lib \
$verbose $sched $opmask $profile $nhvx $hmx $ndev $hb $opbatch $opqueue $oppoll $opfuse $mmsel ./$branch/bin/$tool $@ \
$verbose $sched $opmask $profile $nhvx $hmx $ndev $hb $opbatch $opqueue $oppoll $opfuse $mmsel $fasel ./$branch/bin/$tool $@ \
"
@@ -230,6 +230,12 @@ def print_ascii_timeline(op_name, dims, types, usec, cycles, events, evt_val=Non
char = 'Q'
elif norm_evt == 'A-PREP':
char = 'A'
elif norm_evt == 'Q-PREP':
char = 'q'
elif norm_evt == 'K-PREP':
char = 'k'
elif norm_evt == 'V-PREP':
char = 'v'
elif norm_evt == 'W-DEQUANT':
char = 'D'
elif norm_evt == 'O-PROC':
+1
View File
@@ -25,6 +25,7 @@ add_library(llama
llama-kv-cache.cpp
llama-kv-cache-iswa.cpp
llama-kv-cache-dsa.cpp
llama-kv-cache-dsv4.cpp
llama-memory.cpp
llama-memory-hybrid.cpp
llama-memory-hybrid-iswa.cpp
+57
View File
@@ -77,6 +77,7 @@ static const std::map<llm_arch, const char *> LLM_ARCH_NAMES = {
{ LLM_ARCH_DEEPSEEK2, "deepseek2" },
{ LLM_ARCH_DEEPSEEK2OCR, "deepseek2-ocr" },
{ LLM_ARCH_DEEPSEEK32, "deepseek32" },
{ LLM_ARCH_DEEPSEEK4, "deepseek4" },
{ LLM_ARCH_CHATGLM, "chatglm" },
{ LLM_ARCH_GLM4, "glm4" },
{ LLM_ARCH_GLM4_MOE, "glm4moe" },
@@ -129,6 +130,7 @@ static const std::map<llm_arch, const char *> LLM_ARCH_NAMES = {
{ LLM_ARCH_PANGU_EMBED, "pangu-embedded" },
{ LLM_ARCH_MISTRAL3, "mistral3" },
{ LLM_ARCH_EAGLE3, "eagle3" },
{ LLM_ARCH_DFLASH, "dflash" },
{ LLM_ARCH_MISTRAL4, "mistral4" },
{ LLM_ARCH_PADDLEOCR, "paddleocr" },
{ LLM_ARCH_MIMO2, "mimo2" },
@@ -249,9 +251,19 @@ static const std::map<llm_kv, const char *> LLM_KV_NAMES = {
{ LLM_KV_ATTENTION_INDEXER_HEAD_COUNT, "%s.attention.indexer.head_count" },
{ LLM_KV_ATTENTION_INDEXER_KEY_LENGTH, "%s.attention.indexer.key_length" },
{ LLM_KV_ATTENTION_INDEXER_TOP_K, "%s.attention.indexer.top_k" },
{ LLM_KV_ATTENTION_OUTPUT_GROUP_COUNT, "%s.attention.output_group_count" },
{ LLM_KV_ATTENTION_OUTPUT_LORA_RANK, "%s.attention.output_lora_rank" },
{ LLM_KV_ATTENTION_COMPRESS_ROPE_FREQ_BASE, "%s.attention.compress_rope_freq_base" },
{ LLM_KV_ATTENTION_COMPRESS_RATIOS, "%s.attention.compress_ratios" },
{ LLM_KV_ATTENTION_SHARED_KV_LAYERS, "%s.attention.shared_kv_layers" },
{ LLM_KV_ATTENTION_RECURRENT_LAYERS, "%s.attention.recurrent_layers" },
{ LLM_KV_HYPER_CONNECTION_COUNT, "%s.hyper_connection.count" },
{ LLM_KV_HYPER_CONNECTION_SINKHORN_ITERATIONS, "%s.hyper_connection.sinkhorn_iterations" },
{ LLM_KV_HYPER_CONNECTION_EPSILON, "%s.hyper_connection.epsilon" },
{ LLM_KV_HASH_LAYER_COUNT, "%s.hash_layer_count" },
{ LLM_KV_ROPE_DIMENSION_COUNT, "%s.rope.dimension_count" },
{ LLM_KV_ROPE_DIMENSION_COUNT_SWA, "%s.rope.dimension_count_swa" },
{ LLM_KV_ROPE_DIMENSION_SECTIONS, "%s.rope.dimension_sections" },
@@ -439,6 +451,23 @@ static const std::map<llm_tensor, const char *> LLM_TENSOR_NAMES = {
{ LLM_TENSOR_ATTN_Q_B, "blk.%d.attn_q_b" },
{ LLM_TENSOR_ATTN_KV_A_MQA, "blk.%d.attn_kv_a_mqa" },
{ LLM_TENSOR_ATTN_KV_B, "blk.%d.attn_kv_b" },
{ LLM_TENSOR_ATTN_KV, "blk.%d.attn_kv" },
{ LLM_TENSOR_ATTN_KV_NORM, "blk.%d.attn_kv_a_norm" },
{ LLM_TENSOR_ATTN_OUT_A, "blk.%d.attn_output_a" },
{ LLM_TENSOR_ATTN_OUT_B, "blk.%d.attn_output_b" },
{ LLM_TENSOR_HC_HEAD_FN, "output_hc_fn" },
{ LLM_TENSOR_HC_HEAD_BASE, "output_hc_base" },
{ LLM_TENSOR_HC_HEAD_SCALE, "output_hc_scale" },
{ LLM_TENSOR_HC_ATTN_FN, "blk.%d.hc_attn_fn" },
{ LLM_TENSOR_HC_ATTN_BASE, "blk.%d.hc_attn_base" },
{ LLM_TENSOR_HC_ATTN_SCALE, "blk.%d.hc_attn_scale" },
{ LLM_TENSOR_HC_FFN_FN, "blk.%d.hc_ffn_fn" },
{ LLM_TENSOR_HC_FFN_BASE, "blk.%d.hc_ffn_base" },
{ LLM_TENSOR_HC_FFN_SCALE, "blk.%d.hc_ffn_scale" },
{ LLM_TENSOR_ATTN_COMPRESSOR_WKV, "blk.%d.attn_compressor_kv" },
{ LLM_TENSOR_ATTN_COMPRESSOR_WGATE, "blk.%d.attn_compressor_gate" },
{ LLM_TENSOR_ATTN_COMPRESSOR_APE, "blk.%d.attn_compressor_ape" },
{ LLM_TENSOR_ATTN_COMPRESSOR_NORM, "blk.%d.attn_compressor_norm" },
{ LLM_TENSOR_PER_LAYER_TOKEN_EMBD, "per_layer_token_embd" },
{ LLM_TENSOR_PER_LAYER_MODEL_PROJ, "per_layer_model_proj" },
{ LLM_TENSOR_PER_LAYER_PROJ_NORM, "per_layer_proj_norm" },
@@ -565,6 +594,11 @@ static const std::map<llm_tensor, const char *> LLM_TENSOR_NAMES = {
{ LLM_TENSOR_INDEXER_PROJ, "blk.%d.indexer.proj" },
{ LLM_TENSOR_INDEXER_ATTN_K, "blk.%d.indexer.attn_k" },
{ LLM_TENSOR_INDEXER_ATTN_Q_B, "blk.%d.indexer.attn_q_b" },
{ LLM_TENSOR_INDEXER_COMPRESSOR_WKV, "blk.%d.indexer_compressor_kv" },
{ LLM_TENSOR_INDEXER_COMPRESSOR_WGATE, "blk.%d.indexer_compressor_gate" },
{ LLM_TENSOR_INDEXER_COMPRESSOR_APE, "blk.%d.indexer_compressor_ape" },
{ LLM_TENSOR_INDEXER_COMPRESSOR_NORM, "blk.%d.indexer_compressor_norm" },
{ LLM_TENSOR_FFN_GATE_TID2EID, "blk.%d.ffn_gate_tid2eid" },
{ LLM_TENSOR_MASKED_EMBD_CENTROIDS, "masked_embd_centroids" },
{ LLM_TENSOR_MASKED_EMBD_ORDERING, "masked_embd_ordering" },
{ LLM_TENSOR_FC, "fc" },
@@ -615,6 +649,23 @@ static const std::map<llm_tensor, llm_tensor_info> LLM_TENSOR_INFOS = {
{LLM_TENSOR_ATTN_Q_B, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL_MAT}},
{LLM_TENSOR_ATTN_KV_A_MQA, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL_MAT}},
{LLM_TENSOR_ATTN_KV_B, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL_MAT}},
{LLM_TENSOR_ATTN_KV, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL_MAT}},
{LLM_TENSOR_ATTN_KV_NORM, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL}},
{LLM_TENSOR_ATTN_OUT_A, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL_MAT}},
{LLM_TENSOR_ATTN_OUT_B, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL_MAT}},
{LLM_TENSOR_HC_HEAD_FN, {LLM_TENSOR_LAYER_OUTPUT, GGML_OP_MUL_MAT}},
{LLM_TENSOR_HC_HEAD_BASE, {LLM_TENSOR_LAYER_OUTPUT, GGML_OP_ADD}},
{LLM_TENSOR_HC_HEAD_SCALE, {LLM_TENSOR_LAYER_OUTPUT, GGML_OP_MUL}},
{LLM_TENSOR_HC_ATTN_FN, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL_MAT}},
{LLM_TENSOR_HC_ATTN_BASE, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_ADD}},
{LLM_TENSOR_HC_ATTN_SCALE, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL}},
{LLM_TENSOR_HC_FFN_FN, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL_MAT}},
{LLM_TENSOR_HC_FFN_BASE, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_ADD}},
{LLM_TENSOR_HC_FFN_SCALE, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL}},
{LLM_TENSOR_ATTN_COMPRESSOR_WKV, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL_MAT}},
{LLM_TENSOR_ATTN_COMPRESSOR_WGATE, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL_MAT}},
{LLM_TENSOR_ATTN_COMPRESSOR_APE, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_ADD}},
{LLM_TENSOR_ATTN_COMPRESSOR_NORM, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL}},
{LLM_TENSOR_ATTN_K_B, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL_MAT}},
{LLM_TENSOR_ATTN_V_B, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL_MAT}},
{LLM_TENSOR_ATTN_SINKS, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_SCALE}},
@@ -778,6 +829,11 @@ static const std::map<llm_tensor, llm_tensor_info> LLM_TENSOR_INFOS = {
{LLM_TENSOR_INDEXER_PROJ, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL_MAT}},
{LLM_TENSOR_INDEXER_ATTN_K, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL_MAT}},
{LLM_TENSOR_INDEXER_ATTN_Q_B, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL_MAT}},
{LLM_TENSOR_INDEXER_COMPRESSOR_WKV, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL_MAT}},
{LLM_TENSOR_INDEXER_COMPRESSOR_WGATE, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL_MAT}},
{LLM_TENSOR_INDEXER_COMPRESSOR_APE, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_ADD}},
{LLM_TENSOR_INDEXER_COMPRESSOR_NORM, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL}},
{LLM_TENSOR_FFN_GATE_TID2EID, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_GET_ROWS}},
{LLM_TENSOR_NEXTN_PROJ_PRE, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL_MAT}},
{LLM_TENSOR_NEXTN_PROJ_POST, {LLM_TENSOR_LAYER_OUTPUT, GGML_OP_MUL_MAT}},
// NextN/MTP tensors are stored per-block (blk.%d.nextn.*) even though only the
@@ -932,6 +988,7 @@ bool llm_arch_supports_sm_tensor(const llm_arch & arch) {
case LLM_ARCH_OLMOE:
case LLM_ARCH_DEEPSEEK2:
case LLM_ARCH_DEEPSEEK32:
case LLM_ARCH_DEEPSEEK4:
case LLM_ARCH_GLM_DSA:
case LLM_ARCH_BITNET:
case LLM_ARCH_T5:
+34
View File
@@ -82,6 +82,7 @@ enum llm_arch {
LLM_ARCH_DEEPSEEK2,
LLM_ARCH_DEEPSEEK2OCR,
LLM_ARCH_DEEPSEEK32,
LLM_ARCH_DEEPSEEK4,
LLM_ARCH_CHATGLM,
LLM_ARCH_GLM4,
LLM_ARCH_GLM4_MOE,
@@ -143,6 +144,7 @@ enum llm_arch {
LLM_ARCH_TALKIE,
LLM_ARCH_MELLUM,
LLM_ARCH_EAGLE3,
LLM_ARCH_DFLASH,
LLM_ARCH_UNKNOWN,
};
@@ -254,9 +256,19 @@ enum llm_kv {
LLM_KV_ATTENTION_INDEXER_HEAD_COUNT,
LLM_KV_ATTENTION_INDEXER_KEY_LENGTH,
LLM_KV_ATTENTION_INDEXER_TOP_K,
LLM_KV_ATTENTION_OUTPUT_GROUP_COUNT,
LLM_KV_ATTENTION_OUTPUT_LORA_RANK,
LLM_KV_ATTENTION_COMPRESS_ROPE_FREQ_BASE,
LLM_KV_ATTENTION_COMPRESS_RATIOS,
LLM_KV_ATTENTION_SHARED_KV_LAYERS,
LLM_KV_ATTENTION_RECURRENT_LAYERS,
LLM_KV_HYPER_CONNECTION_COUNT,
LLM_KV_HYPER_CONNECTION_SINKHORN_ITERATIONS,
LLM_KV_HYPER_CONNECTION_EPSILON,
LLM_KV_HASH_LAYER_COUNT,
LLM_KV_ROPE_DIMENSION_COUNT,
LLM_KV_ROPE_DIMENSION_COUNT_SWA,
LLM_KV_ROPE_DIMENSION_SECTIONS,
@@ -500,10 +512,27 @@ enum llm_tensor {
LLM_TENSOR_ATTN_Q_B,
LLM_TENSOR_ATTN_KV_A_MQA,
LLM_TENSOR_ATTN_KV_B,
LLM_TENSOR_ATTN_KV,
LLM_TENSOR_ATTN_KV_NORM,
LLM_TENSOR_ATTN_OUT_A,
LLM_TENSOR_ATTN_OUT_B,
LLM_TENSOR_ATTN_K_B,
LLM_TENSOR_ATTN_V_B,
LLM_TENSOR_ATTN_Q_A_NORM,
LLM_TENSOR_ATTN_KV_A_NORM,
LLM_TENSOR_HC_HEAD_FN,
LLM_TENSOR_HC_HEAD_BASE,
LLM_TENSOR_HC_HEAD_SCALE,
LLM_TENSOR_HC_ATTN_FN,
LLM_TENSOR_HC_ATTN_BASE,
LLM_TENSOR_HC_ATTN_SCALE,
LLM_TENSOR_HC_FFN_FN,
LLM_TENSOR_HC_FFN_BASE,
LLM_TENSOR_HC_FFN_SCALE,
LLM_TENSOR_ATTN_COMPRESSOR_WKV,
LLM_TENSOR_ATTN_COMPRESSOR_WGATE,
LLM_TENSOR_ATTN_COMPRESSOR_APE,
LLM_TENSOR_ATTN_COMPRESSOR_NORM,
LLM_TENSOR_ATTN_SUB_NORM,
LLM_TENSOR_FFN_SUB_NORM,
LLM_TENSOR_DEC_ATTN_NORM,
@@ -565,6 +594,11 @@ enum llm_tensor {
LLM_TENSOR_INDEXER_PROJ,
LLM_TENSOR_INDEXER_ATTN_K,
LLM_TENSOR_INDEXER_ATTN_Q_B,
LLM_TENSOR_INDEXER_COMPRESSOR_WKV,
LLM_TENSOR_INDEXER_COMPRESSOR_WGATE,
LLM_TENSOR_INDEXER_COMPRESSOR_APE,
LLM_TENSOR_INDEXER_COMPRESSOR_NORM,
LLM_TENSOR_FFN_GATE_TID2EID,
LLM_TENSOR_NEXTN_PROJ_PRE,
LLM_TENSOR_NEXTN_PROJ_POST,
LLM_TENSOR_NEXTN_EH_PROJ,

Some files were not shown because too many files have changed in this diff Show More