Compare commits

..

19 Commits

Author SHA1 Message Date
liminfei-amd 1191758c5d vulkan: fail the build when a shader fails to compile (#24450)
* vulkan-shaders-gen: fail the build when a shader fails to compile

vulkan-shaders-gen did not detect shader-compile subprocess failures, so a
broken libggml-vulkan could be produced while the build reported success and
the breakage only surfaced at run time. execute_command() discarded the child
exit code (POSIX waitpid passed nullptr for status; the Windows branch never
called GetExitCodeProcess) and string_to_spv decided success only from whether
stderr was empty, so a non-zero exit with empty stderr, or a subprocess that
failed to launch, was treated as success.

Return the child exit code from execute_command() (WEXITSTATUS on POSIX,
GetExitCodeProcess on Windows), treat a non-zero exit or non-empty stderr or a
launch exception as a failure, and record it in an atomic flag. main() checks
the flag after process_shaders() and returns EXIT_FAILURE before writing the
output files, so the build stops instead of emitting a broken backend.

Fixes #24393

Signed-off-by: liminfei-amd <91481003+liminfei-amd@users.noreply.github.com>

* vulkan-shaders-gen: simplify compile_failed access and drop unreachable return

Address review feedback on #24450:
- Access the std::atomic<bool> compile_failed directly (= / implicit bool)
  instead of .store()/.load(); the flag stays atomic because the worker
  threads in process_shaders() set it concurrently.
- Remove the unreachable trailing return -1 in execute_command(): on POSIX the
  child _exit()s after execvp and the parent returns (fork()<0 throws); on
  Windows the block returns the exit code.

Signed-off-by: liminfei-amd <91481003+liminfei-amd@users.noreply.github.com>

---------

Signed-off-by: liminfei-amd <91481003+liminfei-amd@users.noreply.github.com>
2026-06-24 11:42:03 +02:00
Pascal 00139b660b ui: loading bar below the model picker (#24931)
* ui: show model load progress on the selector trigger

Mirror the in-dropdown stage progress as a thin bar on the selector
trigger, so the active model's load percent stays visible when the menu
is closed. Same status gating and composite fraction as the dropdown
row, so both bars track the selected model in sync.

Suggested-by: Julien Chaumond <@julien-c>

* ui: show model load progress bar on the in-conversation model selector

* ui: tune model load indicator to a pulsing highlight (suggested by @ngxson)

Also wire the indicator onto the mobile sheet trigger, which was missing
it since mobile uses the sheet instead of the dropdown.

* ui: thin (@allozaur) pulsating (@ngxson) model load bar
2026-06-24 10:50:44 +02:00
Aleksander Grygier ef9c13d4c2 ui: New Logo + Navigation cleanup & Mobile UI/UX improvements (#24897)
* chore: `npm audit fix --force`

* feat: Update sidebar toggle to use Logo

* refactor: Clean up favicon SVG

* feat: Refactor logo component and implement theme-aware favicon generation

* feat: Add configurable padding to generated PWA assets

* test: Add unit tests for writeThemeFavicons

* refactor: Componentization

* feat: WIP

* feat: WIP

* feat: WIP

* feat: Mobile UI

* feat: add SEARCH route constant

* feat: create SidebarNavigationSearchResults component

* refactor: use SidebarNavigationSearchResults in conversation list

* feat: enable mobile search navigation in sidebar actions

* feat: add mobile search route and page

* fix: prevent sidebar overflow on mobile viewports

* fix: Mobile sidebar

* feat: Mobile Search WIP

* feat: Mobile WIP

* feat: Add PWA standalone detection and refine mobile UI

* feat: Improve mobile layout, sidebar handling, and chat scrolling

* feat: Improve mobile sidebar visibility and iOS Safari chat spacing

* fix: Disable auto-scroll on mobile

* chore: Linting

* fix: Wrong condition

* feat: Mobile chat scroll

* refactor: WIP

* fix: Desktop initial scroll always working again

* fix: Partial fix for mobile auto-scroll / initial scroll

* fix: Desktop auto-scroll on initial load and during streaming

* fix: Mobile scrolling logic

* refactor: Clean up

* feat: Improve start UI

* feat: Add `delay` to `fadeInView`

* feat: Auto-scroll button

* refactor: Cleanup

* refactor: Extract chat dialogs and alerts into dedicated component

* refactor: Reorganize ChatScreen component structure and initialization

* feat: Improve auto-scroll after sending message

* feat: UI improvements

* fix: Settings link

* feat: UI improvements

* fix: better UI spacing

* fix: Remove unneeded logic

* fix: Chat Processing Info UI rendering

* feat: Improve mobile UI

* feat: UI improvement

* fix: Conditional transition delay for Chat Messages based on route from

* fix: Delay mobile sidebar collapse for smoother transitions

* fix: Mobile scroll down button + sidebar pointer events

* fix: Mobile UI

* fix: Auto scrolling

* fix: Implement dynamic height calculations for chat auto-scroll positioning and UI elements

* fix: Retrieve `autofocus` for Chat Form textarea

* fix: Use proper class

Co-authored-by: Copilot Autofix powered by AI <175728472+Copilot@users.noreply.github.com>

* refactor: extract scroll-to-bottom logic and fix message send flow

* fix: update viewport store usage and remove conflicting autofocus

* feat: add accessibility labels to scroll down button

* fix: correct HTML structure in sidebar empty states

* fix: dynamically toggle processing info visibility

* chore: remove commented exports and fix formatting

* fix

* fix: Mobile Chat Form Add Action Sheet interactions

---------

Co-authored-by: Copilot Autofix powered by AI <175728472+Copilot@users.noreply.github.com>
2026-06-24 10:21:33 +02:00
Tarek Dakhran 88636e178f model : Add LFM2.5-ColBERT-350M and LFM2.5-Embedding-350M (#24913)
* model : Add LFM2.5-ColBERT-350M and LFM2.5-Embedding-350M

* Restore LFM2 models in README.md
2026-06-24 09:49:46 +03:00
Jeff Bolz ac4105d68b vulkan: Apply bias before softmax in FA, to avoid overflow (#24909) 2026-06-23 22:34:00 -05:00
kononnable be4a6a63eb server : check draft context creation error (#24922) 2026-06-23 16:56:50 +02:00
Jeff Bolz 72a9269172 vulkan: support all backend tests for SQR/SQRT/SIN/COS/CLAMP/LEAKY_RELU/NORM (#24582)
* vulkan: make SQR/SQRT/SIN/COS/CLAMP/LEAKY_RELU use unary.comp

* vulkan: make NORM support noncontig

* add noncontiguous row test cases for norm/l2_norm, handle this in the CPU backend and l2_norm.comp

* fix supports_op for cuda and webgpu
2026-06-23 09:48:24 -05:00
Jeff Bolz 92e854ab83 vulkan: Support GET_ROWS_BACK (#24883) 2026-06-23 15:39:37 +02:00
Jeff Bolz c5606364b2 vulkan: support CONV_3D (#24612)
* vulkan: support CONV_3D

This is a pretty direct port of conv2d_mm.comp to CONV_3D, done by codex
and cleaned up by me.

* disable slower perf tests
2026-06-23 15:39:20 +02:00
Jeff Bolz 0eb874d374 vulkan: make mul_mm ALIGNED a spec constant (#24689)
This trims down some of the shader variant explosion and reduces binary size.
2026-06-23 14:26:17 +02:00
Xuan-Son Nguyen 75ad0b23ed server: fix remote preset handling, add test (#24938)
* server: add test for remote preset

* fix remote preset handling

* fix

* fix test
2026-06-23 13:28:34 +02:00
Wyatt Caldwell c926ad0985 vulkan: link ggml-cpu when GGML_VULKAN_CHECK_RESULTS / RUN_TESTS are enabled (#24444)
The result-checking and test debug paths in ggml-vulkan.cpp call ggml_graph_compute_with_ctx() to compute a CPU reference graph, but that symbol is defined in ggml-cpu, which ggml-vulkan does not link. Enabling -DGGML_VULKAN_CHECK_RESULTS=ON (or -DGGML_VULKAN_RUN_TESTS=ON) therefore fails to link with an unresolved external (e.g. LNK2019 on MSVC, undefined reference on GCC/Clang). This regressed after ggml-cpu was split into its own library. Link ggml-cpu under those two options so the debug builds link again.

Signed-off-by: Wyatt Caldwell <218154709+Detensable@users.noreply.github.com>
2026-06-23 12:55:46 +02:00
Gabe Goodhart a3900a6694 model: Granite Speech Plus (#24818)
* feat: Add conversion support for Granite Speech Plus

Branch: GraniteSpeechPlus
AI-usage: full (Bob, OpenCode + Qwen3.6-35b)
Signed-off-by: Gabe Goodhart <ghart@us.ibm.com>

* feat: Extend granite_speech to support plus multi-layer concatenation

Branch: GraniteSpeechPlus
AI-usage: draft (Bob, OpenCode + Qwen3.6-35b)
Signed-off-by: Gabe Goodhart <ghart@us.ibm.com>

* fix(conversion): Fix plural naming for feature_layers for audio

Branch: GraniteSpeechPlus
AI-usage: none
Signed-off-by: Gabe Goodhart <ghart@us.ibm.com>

* fix(mtmd): Align feature_layer usage and naming everywhere

Branch: GraniteSpeechPlus
AI-usage: none
Signed-off-by: Gabe Goodhart <ghart@us.ibm.com>

* style: Use fstring for log

Signed-off-by: Gabe Goodhart <ghart@us.ibm.com>

Co-authored-by: Xuan-Son Nguyen <thichthat@gmail.com>

---------

Co-authored-by: Xuan-Son Nguyen <thichthat@gmail.com>
2026-06-23 12:03:31 +02:00
Masashi Yoshimura 7c908502ea ggml-webgpu: improve MTP inference by using mat-vec path for small batches (#24811)
* ggml-webgpu: improve small batches decoding

* Add barrier to the NUM_COLS loop in mul-mat-vec
2026-06-23 17:13:55 +09:00
Masashi Yoshimura 035cd8f9a6 codeowners: add yomaytk to ggml-webgpu (#24930) 2026-06-23 15:19:34 +09:00
Aldehir Rojas 73618f27a8 server: improve user message detection and create checkpoints at every user message (#24176)
* server : improve message span logic

* cont : cast size_t to int32_t in comparisons

* server : create checkpoints before every user msg

* chat : remove \n in gemma4 delimiters

* chat : merge msg delimiter structs into one

* cont : reword comment

* cont : initialize tokens in delimiter

* cont : add server_tokens::get_raw_tokens() for mtmd

* cont : move message finding to server_tokens and skip mtmd tokens

* cont : update cohere2moe parser

* cont : increase min-step to 8192 and always produce a chkpt for last user message
2026-06-23 08:27:28 +03:00
Shawn Gu 23ee8797e1 opencl: q8_0 gemv precision improvement (#24923) 2026-06-22 22:25:21 -07:00
Matt Thompson dec5ca5577 server : Add id to tool call responses api (#24882) 2026-06-22 23:03:12 +02:00
Mahdiou Diallo 9c0ac887f3 ui: Prioritize favorite models in model selection (#24766)
Updated model selection prioritization to include favorite models.
2026-06-22 21:00:21 +02:00
152 changed files with 4519 additions and 3332 deletions
+1 -1
View File
@@ -10,7 +10,7 @@
# ggml-org/ggml-rpc : rgerganov
# ggml-org/ggml-sycl : arthw
# ggml-org/ggml-vulkan : 0cc4m, jeffbolznv
# ggml-org/ggml-webgpu : reeselevine
# ggml-org/ggml-webgpu : reeselevine, yomaytk
# ggml-org/ggml-zdnn : taronaeo
# ggml-org/llama-common : ggerganov, aldehir, angt, danbev, ngxson, pwilkin
# ggml-org/llama-mtmd : ngxson
+3 -1
View File
@@ -142,7 +142,9 @@ Instructions for adding support for new models: [HOWTO-add-model.md](docs/develo
- [x] [GigaChat-20B-A3B](https://huggingface.co/ai-sage/GigaChat-20B-A3B-instruct)
- [X] [Trillion-7B-preview](https://huggingface.co/trillionlabs/Trillion-7B-preview)
- [x] [Ling models](https://huggingface.co/collections/inclusionAI/ling-67c51c85b34a7ea0aba94c32)
- [x] [LFM2 models](https://huggingface.co/collections/LiquidAI/lfm2-686d721927015b2ad73eaa38)
- [x] [Liquid LFM2 models](https://huggingface.co/collections/LiquidAI/lfm2)
- [x] [Liquid LFM2.5 models](https://huggingface.co/collections/LiquidAI/lfm25)
- [x] [Liquid Nanos](https://huggingface.co/collections/LiquidAI/liquid-nanos)
- [x] [Hunyuan models](https://huggingface.co/collections/tencent/hunyuan-dense-model-6890632cda26b19119c9c5e7)
- [x] [BailingMoeV2 (Ring/Ling 2.0) models](https://huggingface.co/collections/inclusionAI/ling-v2-68bf1dd2fc34c306c1fa6f86)
- [x] [Mellum models](https://huggingface.co/JetBrains/models?search=mellum)
+7 -4
View File
@@ -301,6 +301,8 @@ static handle_model_result common_params_handle_model(struct common_params_model
const common_download_opts & opts) {
handle_model_result result;
// TODO @ngxson : refactor this into a new common_model_download_context
if (!model.docker_repo.empty()) {
model.path = common_docker_resolve_model(model.docker_repo);
} else if (!model.hf_repo.empty()) {
@@ -396,7 +398,7 @@ static bool parse_bool_value(const std::string & value) {
// CLI argument parsing functions
//
bool common_params_handle_models(common_params & params, llama_example curr_ex, common_download_callback * callback) {
bool common_params_handle_models(common_params & params, llama_example curr_ex, const common_params_handle_models_params & handle_params) {
const bool spec_type_draft_mtp = std::find(params.speculative.types.begin(),
params.speculative.types.end(),
COMMON_SPECULATIVE_TYPE_DRAFT_MTP) != params.speculative.types.end();
@@ -407,9 +409,10 @@ bool common_params_handle_models(common_params & params, llama_example curr_ex,
opts.skip_download = params.skip_download;
opts.download_mtp = spec_type_draft_mtp;
opts.download_mmproj = !params.no_mmproj && params.mmproj.path.empty() && params.mmproj.url.empty();
opts.preset_only = handle_params.preset_only;
if (callback) {
opts.callback = callback;
if (handle_params.callback) {
opts.callback = handle_params.callback;
}
// sub-models (draft, mmproj, vocoder) are explicitly specified by the user,
@@ -596,7 +599,7 @@ static bool common_params_parse_ex(int argc, char ** argv, common_params_context
if (!skip_model_download) {
// handle model and download
common_params_handle_models(params, ctx_arg.ex);
common_params_handle_models(params, ctx_arg.ex, {});
// model is required (except for server)
// TODO @ngxson : maybe show a list of available models in CLI in this case
+6 -1
View File
@@ -130,6 +130,11 @@ bool common_params_to_map(int argc, char ** argv, llama_example ex, std::map<com
// see: https://github.com/ggml-org/llama.cpp/issues/18163
void common_params_add_preset_options(std::vector<common_arg> & args);
struct common_params_handle_models_params {
common_download_callback * callback = nullptr;
bool preset_only = false; // if true, only check & download remote preset (for router mode)
};
// populate model paths (main model, mmproj, etc) from -hf if necessary
// return true if the model is ready to use
// throw an exception if there is an error that prevents the model from being used (e.g. network error, model not found, etc)
@@ -137,7 +142,7 @@ void common_params_add_preset_options(std::vector<common_arg> & args);
bool common_params_handle_models(
common_params & params,
llama_example curr_ex,
common_download_callback * callback = nullptr);
const common_params_handle_models_params & handle_params);
// initialize argument parser context - used by test-arg-parser and preset
common_params_context common_params_parser_init(common_params & params, llama_example ex, void(*print_usage)(int, char **) = nullptr);
+103 -53
View File
@@ -90,41 +90,93 @@ std::string common_chat_msg::render_content(const std::string & delimiter) const
return text;
}
std::vector<common_chat_msg_span> common_chat_split_by_role(const std::string & prompt, const std::vector<common_chat_msg_delimiter> & delims) {
if (delims.empty() || prompt.empty()) {
return {};
common_chat_role common_chat_role_from_string(const std::string & role) {
if (role == "system") { return COMMON_CHAT_ROLE_SYSTEM; }
if (role == "assistant") { return COMMON_CHAT_ROLE_ASSISTANT; }
if (role == "user") { return COMMON_CHAT_ROLE_USER; }
if (role == "tool") { return COMMON_CHAT_ROLE_TOOL; }
return COMMON_CHAT_ROLE_UNKNOWN;
}
const char * common_chat_role_to_string(common_chat_role role) {
switch (role) {
case COMMON_CHAT_ROLE_SYSTEM: return "system";
case COMMON_CHAT_ROLE_ASSISTANT: return "assistant";
case COMMON_CHAT_ROLE_USER: return "user";
case COMMON_CHAT_ROLE_TOOL: return "tool";
case COMMON_CHAT_ROLE_UNKNOWN: return "";
}
return "";
}
json common_chat_msg_delimiters::to_json() const {
json result = json::array();
for (const auto & d : delimiters) {
result.push_back({
{ "role", common_chat_role_to_string(d.role) },
{ "delimiter", d.delimiter },
});
}
return result;
}
common_chat_msg_delimiters common_chat_msg_delimiters_parse(const json & delimiters) {
common_chat_msg_delimiters result;
if (!delimiters.is_array()) {
return result;
}
auto parser = build_peg_parser([&](common_peg_parser_builder & p) {
std::vector<std::string> all_delims;
std::vector<common_peg_parser> tagged_messages;
all_delims.reserve(delims.size());
tagged_messages.reserve(delims.size());
for (const auto & d : delims) {
all_delims.push_back(d.delimiter);
result.delimiters.reserve(delimiters.size());
for (const auto & d : delimiters) {
if (!d.is_object()) {
continue;
}
auto any_delim = p.until_one_of(all_delims);
for (const auto & d : delims) {
tagged_messages.push_back(p.tag(d.role, p.literal(d.delimiter) + any_delim));
}
return any_delim + p.zero_or_more(p.choice(tagged_messages)) + p.end();
});
common_peg_parse_context ctx(prompt);
const auto result = parser.parse(ctx);
if (!result.success()) {
return {};
result.delimiters.push_back({
common_chat_role_from_string(d.value("role", std::string())),
d.value("delimiter", std::string()),
});
}
std::vector<common_chat_msg_span> spans;
ctx.ast.visit(result, [&](const common_peg_ast_node & node) {
if (!node.tag.empty()) {
spans.push_back({ node.tag, node.start, node.end - node.start });
return result;
}
void common_chat_msg_delimiters::tokenize(const llama_vocab * vocab) {
for (auto & d : delimiters) {
d.tokens = common_tokenize(vocab, d.delimiter, false, true);
}
}
common_chat_msg_spans common_chat_msg_delimiters::split(const llama_tokens & tokens, const std::map<size_t, size_t> & skips) const {
std::vector<std::pair<common_chat_role, size_t>> matches;
auto skip = skips.begin();
for (size_t i = 0; i < tokens.size();) {
if (skip != skips.end() && i == skip->first) {
i += skip->second;
++skip;
continue;
}
});
for (const auto & d : delimiters) {
if (i + d.tokens.size() > tokens.size()) {
continue;
}
if (std::equal(d.tokens.begin(), d.tokens.end(), tokens.begin() + i)) {
matches.emplace_back(d.role, i);
break;
}
}
i++;
}
matches.emplace_back(COMMON_CHAT_ROLE_UNKNOWN, tokens.size());
common_chat_msg_spans spans;
for (size_t i = 0; i + 1 < matches.size(); i++) {
const auto & curr = matches[i];
const auto & next = matches[i + 1];
spans.add(curr.first, curr.second, next.second - curr.second);
}
return spans;
}
@@ -1081,13 +1133,13 @@ static common_chat_params common_chat_params_init_gpt_oss(const common_chat_temp
data.prompt = prompt;
data.generation_prompt = common_chat_template_generation_prompt_impl(tmpl, inputs, /* messages_override= */ adjusted_messages);
data.message_spans = common_chat_split_by_role(prompt, {
{ "assistant", "<|start|>assistant" },
{ "user", "<|start|>user" },
{ "system", "<|start|>developer" },
{ "system", "<|start|>system" },
{ "tool", "<|start|>functions" },
});
data.message_delimiters = {
{ COMMON_CHAT_ROLE_ASSISTANT, "<|start|>assistant" },
{ COMMON_CHAT_ROLE_USER, "<|start|>user" },
{ COMMON_CHAT_ROLE_SYSTEM, "<|start|>developer" },
{ COMMON_CHAT_ROLE_SYSTEM, "<|start|>system" },
{ COMMON_CHAT_ROLE_TOOL, "<|start|>functions" },
};
data.format = COMMON_CHAT_FORMAT_PEG_NATIVE;
data.supports_thinking = true;
@@ -1228,10 +1280,10 @@ static common_chat_params common_chat_params_init_gemma4(const common_chat_templ
data.prompt += data.generation_prompt;
}
data.message_spans = common_chat_split_by_role(data.prompt, {
{ "user", "<|turn>user\n" },
{ "assistant", "<|turn>model\n" },
});
data.message_delimiters = {
{ COMMON_CHAT_ROLE_USER, "<|turn>user" },
{ COMMON_CHAT_ROLE_ASSISTANT, "<|turn>model" },
};
data.format = COMMON_CHAT_FORMAT_PEG_GEMMA4;
data.supports_thinking = true;
@@ -2030,15 +2082,15 @@ static common_chat_params common_chat_params_init_cohere2moe(const common_chat_t
RESULT_START, RESULT_END,
};
// Split the rendered prompt into per-role message spans. Tool results are rendered with the
// Declare per-role message delimiters. Tool results are rendered with the
// system token followed by <|START_TOOL_RESULT|>, so the "tool" delimiter must be listed before
// the plain "system" one (it is a strict superset, and the role split tries delimiters in order).
data.message_spans = common_chat_split_by_role(data.prompt, {
{ "assistant", GEN_PREFIX },
{ "user", TURN_START + USER },
{ "tool", TURN_START + SYSTEM + RESULT_START },
{ "system", TURN_START + SYSTEM },
});
data.message_delimiters = {
{ COMMON_CHAT_ROLE_ASSISTANT, GEN_PREFIX },
{ COMMON_CHAT_ROLE_USER, TURN_START + USER },
{ COMMON_CHAT_ROLE_TOOL, TURN_START + SYSTEM + RESULT_START },
{ COMMON_CHAT_ROLE_SYSTEM, TURN_START + SYSTEM },
};
auto has_tools = inputs.tools.is_array() && !inputs.tools.empty();
auto extract_reasoning = inputs.reasoning_format != COMMON_REASONING_FORMAT_NONE;
@@ -2526,17 +2578,15 @@ static common_chat_params common_chat_templates_apply_jinja(const struct common_
autoparser.analyze_template(tmpl);
auto auto_params = autoparser::peg_generator::generate_parser(tmpl, params, autoparser);
std::vector<common_chat_msg_delimiter> delimiters;
common_chat_msg_delimiters delimiters;
if (!autoparser.assistant_start.empty()) {
delimiters.push_back({ "assistant", autoparser.assistant_start });
delimiters.add(COMMON_CHAT_ROLE_ASSISTANT, autoparser.assistant_start);
}
if (!autoparser.user_start.empty()) {
delimiters.push_back({ "user", autoparser.user_start });
delimiters.add(COMMON_CHAT_ROLE_USER, autoparser.user_start);
}
if (!delimiters.empty()) {
auto_params.message_spans = common_chat_split_by_role(auto_params.prompt, delimiters);
}
auto_params.message_delimiters = std::move(delimiters);
auto_params.supports_thinking = autoparser.reasoning.mode != autoparser::reasoning_mode::NONE;
if (auto_params.supports_thinking) {
+65 -6
View File
@@ -143,15 +143,75 @@ struct common_chat_msg_diff {
}
};
enum common_chat_role {
COMMON_CHAT_ROLE_UNKNOWN,
COMMON_CHAT_ROLE_SYSTEM,
COMMON_CHAT_ROLE_ASSISTANT,
COMMON_CHAT_ROLE_USER,
COMMON_CHAT_ROLE_TOOL
};
common_chat_role common_chat_role_from_string(const std::string & role);
const char * common_chat_role_to_string(common_chat_role role);
struct common_chat_msg_span {
std::string role;
common_chat_role role = COMMON_CHAT_ROLE_UNKNOWN;
std::size_t pos = 0;
std::size_t len = 0;
bool valid() const {
return role != COMMON_CHAT_ROLE_UNKNOWN;
}
};
struct common_chat_msg_spans {
std::vector<common_chat_msg_span> spans;
void add(common_chat_role role, size_t pos, size_t len) {
spans.push_back({ role, pos, len });
}
bool is_user_start(int32_t pos) const {
for (auto it = spans.begin(); it != spans.end(); ++it) {
if (it->role == COMMON_CHAT_ROLE_USER && pos == (int32_t) it->pos) {
return true;
}
}
return false;
}
int32_t last_user_message_pos() const {
for (auto it = spans.rbegin(); it != spans.rend(); ++it) {
if (it->role == COMMON_CHAT_ROLE_USER) {
return (int32_t) it->pos;
}
}
return -1;
}
};
struct common_chat_msg_delimiter {
std::string role;
std::string delimiter;
common_chat_role role = COMMON_CHAT_ROLE_UNKNOWN;
std::string delimiter;
llama_tokens tokens = {};
};
struct common_chat_msg_delimiters {
std::vector<common_chat_msg_delimiter> delimiters;
common_chat_msg_delimiters() = default;
common_chat_msg_delimiters(std::initializer_list<common_chat_msg_delimiter> delims) : delimiters(delims) {}
void add(common_chat_role role, const std::string & delimiter) {
delimiters.push_back({ role, delimiter });
}
void tokenize(const llama_vocab * vocab);
// split tokens into message spans. skips maps a start index to a length of a region to jump over without matching
common_chat_msg_spans split(const llama_tokens & tokens, const std::map<size_t, size_t> & skips = {}) const;
nlohmann::ordered_json to_json() const;
};
struct common_chat_tool {
@@ -219,7 +279,7 @@ struct common_chat_params {
std::vector<std::string> preserved_tokens;
std::vector<std::string> additional_stops;
std::string parser;
std::vector<common_chat_msg_span> message_spans;
common_chat_msg_delimiters message_delimiters;
};
// per-message parsing syntax
@@ -325,5 +385,4 @@ struct common_chat_prompt_preset {
common_chat_prompt_preset common_chat_get_asr_prompt(const common_chat_templates * chat_templates);
std::vector<common_chat_msg_span> common_chat_split_by_role(const std::string & prompt, const std::vector<common_chat_msg_delimiter> & delims);
common_chat_msg_delimiters common_chat_msg_delimiters_parse(const nlohmann::ordered_json & delimiters);
+1 -1
View File
@@ -609,7 +609,7 @@ struct common_params {
bool cache_prompt = true; // whether to enable prompt caching
bool cache_idle_slots = true; // save and clear idle slots upon starting a new task
int32_t n_ctx_checkpoints = 32; // max number of context checkpoints per slot
int32_t checkpoint_min_step = 256; // minimum spacing between context checkpoints
int32_t checkpoint_min_step = 8192; // minimum spacing between context checkpoints
int32_t cache_ram_mib = 8192; // -1 = no limit, 0 - disable, 1 = 1 MiB, etc.
std::string hostname = "127.0.0.1";
+3 -1
View File
@@ -799,6 +799,7 @@ common_download_model_result common_download_model(const common_params_model &
bool download_mmproj = opts.download_mmproj;
bool download_mtp = opts.download_mtp;
bool preset_only = opts.preset_only;
bool is_hf = !model.hf_repo.empty();
if (is_hf) {
@@ -806,7 +807,8 @@ common_download_model_result common_download_model(const common_params_model &
if (!hf.preset.path.empty()) {
// if preset.ini exists, only download that file alone
tasks.push_back({hf.preset.url, hf.preset.local_path});
} else {
} else if (!preset_only) {
// only add other files if we're NOT in preset-only mode (normal run, non-router)
for (const auto & f : hf.model_files) {
tasks.push_back({f.url, f.local_path});
}
+1
View File
@@ -55,6 +55,7 @@ struct common_download_opts {
bool skip_download = false; // if true, only validation is performed, common_skip_download_exception may be thrown if the file is missing or invalid
bool download_mmproj = false;
bool download_mtp = false;
bool preset_only = false; // if true, only check & download remote preset (for router mode)
common_download_callback * callback = nullptr;
};
+3
View File
@@ -96,6 +96,7 @@ TEXT_MODEL_MAP: dict[str, str] = {
"GraniteMoeHybridForCausalLM": "granite",
"GraniteMoeSharedForCausalLM": "granite",
"GraniteSpeechForConditionalGeneration": "granite",
"GraniteSpeechPlusForConditionalGeneration": "granite",
"Grok1ForCausalLM": "grok",
"GrokForCausalLM": "grok",
"GroveMoeForCausalLM": "grovemoe",
@@ -123,6 +124,7 @@ TEXT_MODEL_MAP: dict[str, str] = {
"LLaDAModelLM": "llada",
"LLaMAForCausalLM": "llama",
"Lfm25AudioTokenizer": "lfm2",
"Lfm2BidirectionalModel": "lfm2",
"Lfm2ForCausalLM": "lfm2",
"Lfm2Model": "lfm2",
"Lfm2MoeForCausalLM": "lfm2",
@@ -261,6 +263,7 @@ MMPROJ_MODEL_MAP: dict[str, str] = {
"GlmasrModel": "ultravox",
"Granite4VisionForConditionalGeneration": "granite",
"GraniteSpeechForConditionalGeneration": "granite",
"GraniteSpeechPlusForConditionalGeneration": "granite",
"HunYuanVLForConditionalGeneration": "hunyuan",
"Idefics3ForConditionalGeneration": "smolvlm",
"InternVisionModel": "internvl",
+28
View File
@@ -348,6 +348,34 @@ class GraniteSpeechMmprojModel(MmprojModel):
yield from super().modify_tensors(data_torch, name, bid)
@ModelBase.register("GraniteSpeechPlusForConditionalGeneration")
class GraniteSpeechPlusMmprojModel(GraniteSpeechMmprojModel):
"""Conversion for GraniteSpeechPlus - extends GraniteSpeech with feature layer concatenation"""
has_vision_encoder = False
has_audio_encoder = True
def set_gguf_parameters(self):
assert self.hparams_audio is not None
super().set_gguf_parameters()
# Add feature_layer if present in encoder config
if feature_layers := self.hparams_audio.get("cat_hidden_layers"):
self.gguf_writer.add_audio_feature_layers(feature_layers)
logger.info(f"gguf: audio feature_layers = {feature_layers}")
# Validate projector dimension matches concatenated encoder output
hidden_dim = self.hparams_audio["hidden_dim"]
expected_dim = hidden_dim * (len(feature_layers) + 1)
projector_dim = self.global_config["projector_config"]["encoder_hidden_size"]
if projector_dim != expected_dim:
raise ValueError(
f"Projector encoder_hidden_size ({projector_dim}) does not match "
f"expected concatenated dimension ({expected_dim}). "
f"Expected: hidden_dim ({hidden_dim}) * (len(feature_layers) + 1) = {expected_dim}"
)
@ModelBase.register("Granite4VisionForConditionalGeneration")
class Granite4VisionMmprojModel(MmprojModel):
has_vision_encoder = True
+10 -3
View File
@@ -64,11 +64,17 @@ class LFM2Model(TextModel):
yield from super().modify_tensors(data_torch, name, bid)
@ModelBase.register("Lfm2Model")
@ModelBase.register("Lfm2Model", "Lfm2BidirectionalModel")
class LFM2ColBertModel(LFM2Model):
model_arch = gguf.MODEL_ARCH.LFM2
dense_tensor_name = "dense_2"
def set_gguf_parameters(self):
super().set_gguf_parameters()
if self.hf_arch == "Lfm2BidirectionalModel":
self.gguf_writer.add_causal_attention(False)
self._try_set_pooling_type()
def modify_tensors(self, data_torch: Tensor, name: str, bid: int | None) -> Iterable[tuple[str, Tensor]]:
if not name.startswith(self.dense_tensor_name):
name = "model." + name
@@ -76,10 +82,11 @@ class LFM2ColBertModel(LFM2Model):
yield from super().modify_tensors(data_torch, name, bid)
def generate_extra_tensors(self) -> Iterable[tuple[str, Tensor]]:
# dense tensor is stored in a separate safetensors file
# optional dense tensor is stored in a separate safetensors file
from safetensors.torch import load_file
tensors_file = self.dir_model / "1_Dense" / "model.safetensors"
assert tensors_file.is_file()
if not tensors_file.is_file():
return
tensor = load_file(tensors_file)["linear.weight"]
self.gguf_writer.add_embedding_length_out(tensor.shape[0])
yield f"{self.dense_tensor_name}.weight", tensor.clone()
+50 -23
View File
@@ -3688,8 +3688,6 @@ static void ggml_compute_forward_norm_f32(
GGML_ASSERT(ggml_are_same_shape(src0, dst));
GGML_ASSERT(src0->nb[0] == sizeof(float));
const int ith = params->ith;
const int nth = params->nth;
@@ -3703,25 +3701,49 @@ static void ggml_compute_forward_norm_f32(
for (int64_t i03 = 0; i03 < ne03; i03++) {
for (int64_t i02 = 0; i02 < ne02; i02++) {
for (int64_t i01 = ith; i01 < ne01; i01 += nth) {
const float * x = (float *) ((char *) src0->data + i01*nb01 + i02*nb02 + i03*nb03);
const char * x = (const char *) src0->data + i01*nb01 + i02*nb02 + i03*nb03;
char * y = (char *) dst->data + i01*nb1 + i02*nb2 + i03*nb3;
float sum = 0.0;
ggml_vec_sum_f32(ne00, &sum, x);
float mean = sum/ne00;
if (nb00 == sizeof(float) && nb0 == sizeof(float)) {
const float * xf = (const float *) x;
float * y = (float *) ((char *) dst->data + i01*nb1 + i02*nb2 + i03*nb3);
float variance = 0;
float sum = 0.0;
ggml_vec_sum_f32(ne00, &sum, xf);
float mean = sum/ne00;
float * yf = (float *) y;
float variance = 0;
#ifdef GGML_USE_ACCELERATE
mean = -mean;
vDSP_vsadd(x, 1, &mean, y, 1, ne00);
vDSP_measqv(y, 1, &variance, ne00);
mean = -mean;
vDSP_vsadd(xf, 1, &mean, yf, 1, ne00);
vDSP_measqv(yf, 1, &variance, ne00);
#else
variance = ggml_vec_cvar_f32(ne00, y, x, mean);
variance = ggml_vec_cvar_f32(ne00, yf, xf, mean);
#endif //GGML_USE_ACCELERATE
const float scale = 1.0f/sqrtf(variance + eps);
ggml_vec_scale_f32(ne00, y, scale);
const float scale = 1.0f/sqrtf(variance + eps);
ggml_vec_scale_f32(ne00, yf, scale);
} else {
float sum = 0.0;
for (int64_t i00 = 0; i00 < ne00; i00++) {
sum += *(const float *) (x + i00*nb00);
}
const float mean = sum/ne00;
float variance = 0.0f;
for (int64_t i00 = 0; i00 < ne00; i00++) {
const float v = *(const float *) (x + i00*nb00) - mean;
*(float *) (y + i00*nb0) = v;
variance += v * v;
}
variance /= ne00;
const float scale = 1.0f/sqrtf(variance + eps);
for (int64_t i00 = 0; i00 < ne00; i00++) {
*(float *) (y + i00*nb0) *= scale;
}
}
}
}
}
@@ -4142,8 +4164,6 @@ static void ggml_compute_forward_l2_norm_f32(
GGML_ASSERT(ggml_are_same_shape(src0, dst));
GGML_ASSERT(src0->nb[0] == sizeof(float));
const int ith = params->ith;
const int nth = params->nth;
@@ -4158,20 +4178,27 @@ static void ggml_compute_forward_l2_norm_f32(
for (int64_t i03 = 0; i03 < ne03; i03++) {
for (int64_t i02 = 0; i02 < ne02; i02++) {
for (int64_t i01 = ith; i01 < ne01; i01 += nth) {
const float * x = (float *) ((char *) src0->data + i01*nb01 + i02*nb02 + i03*nb03);
const char * x = (const char *) src0->data + i01*nb01 + i02*nb02 + i03*nb03;
ggml_float sum = 0.0;
for (int64_t i00 = 0; i00 < ne00; i00++) {
sum += (ggml_float)(x[i00] * x[i00]);
const float xi = *(const float *) (x + i00*nb00);
sum += (ggml_float)(xi * xi);
}
float * y = (float *) ((char *) dst->data + i01*nb1 + i02*nb2 + i03*nb3);
memcpy(y, x, ne00 * sizeof(float));
const float scale = 1.0f/fmaxf(sqrtf(sum), eps);
ggml_vec_scale_f32(ne00, y, scale);
char * y = (char *) dst->data + i01*nb1 + i02*nb2 + i03*nb3;
if (nb00 == sizeof(float) && nb0 == sizeof(float)) {
memcpy(y, x, ne00 * sizeof(float));
ggml_vec_scale_f32(ne00, (float *) y, scale);
} else {
for (int64_t i00 = 0; i00 < ne00; i00++) {
const float xi = *(const float *) (x + i00*nb00);
*(float *) (y + i00*nb0) = xi * scale;
}
}
}
}
}
+1 -1
View File
@@ -5334,7 +5334,7 @@ static bool ggml_backend_cuda_device_supports_op(ggml_backend_dev_t dev, const g
case GGML_OP_NORM:
case GGML_OP_RMS_NORM:
case GGML_OP_L2_NORM:
return true;
return ggml_is_contiguous_rows(op->src[0]);
case GGML_OP_RMS_NORM_BACK:
return ggml_is_contiguous(op->src[0]);
break;
@@ -174,7 +174,7 @@ __kernel void kernel_gemv_noshuffle_q8_0_f32(
regA.s6 = read_imageui(src0_q, (gid + k * BLOCK_STRIDE_A + LINE_STRIDE_A * 6)).x;
regA.s7 = read_imageui(src0_q, (gid + k * BLOCK_STRIDE_A + LINE_STRIDE_A * 7)).x;
dequantizeBlockAccum_ns_sgbroadcast_1(totalSum, regA, regS, regB);
dequantizeBlockAccum_ns_sgbroadcast_1(totalSum, regA, convert_float(regS), regB);
}
// reduction in local memory, assumes #wave=4
+5
View File
@@ -108,6 +108,9 @@ if (Vulkan_FOUND)
if (GGML_VULKAN_CHECK_RESULTS)
add_compile_definitions(GGML_VULKAN_CHECK_RESULTS)
# the result-checking path computes a CPU reference graph via
# ggml_graph_compute_with_ctx(), which is defined in ggml-cpu
target_link_libraries(ggml-vulkan PRIVATE ggml-cpu)
endif()
if (GGML_VULKAN_DEBUG)
@@ -129,6 +132,8 @@ if (Vulkan_FOUND)
if (GGML_VULKAN_RUN_TESTS)
add_compile_definitions(GGML_VULKAN_RUN_TESTS)
# the test path also calls ggml_graph_compute_with_ctx() (ggml-cpu)
target_link_libraries(ggml-vulkan PRIVATE ggml-cpu)
endif()
# Set up toolchain for host compilation whether cross-compiling or not
+362 -66
View File
@@ -493,6 +493,20 @@ struct vk_conv2d_pipeline_state {
}
};
struct vk_conv3d_pipeline_state {
vk_conv3d_pipeline_state(uint32_t s0, uint32_t s1, uint32_t s2, uint32_t p0, uint32_t p1, uint32_t p2,
uint32_t d0, uint32_t d1, uint32_t d2, uint32_t KW, uint32_t KH, uint32_t KD, uint32_t aligned)
: s0(s0), s1(s1), s2(s2), p0(p0), p1(p1), p2(p2), d0(d0), d1(d1), d2(d2), KW(KW), KH(KH), KD(KD), aligned(aligned) {}
uint32_t s0, s1, s2, p0, p1, p2, d0, d1, d2, KW, KH, KD;
uint32_t aligned;
bool operator<(const vk_conv3d_pipeline_state &b) const {
return std::tie(s0, s1, s2, p0, p1, p2, d0, d1, d2, KW, KH, KD, aligned) <
std::tie(b.s0, b.s1, b.s2, b.p0, b.p1, b.p2, b.d0, b.d1, b.d2, b.KW, b.KH, b.KD, b.aligned);
}
};
struct vk_solve_tri_pipeline_state {
vk_solve_tri_pipeline_state(uint32_t N, uint32_t K)
: N(N), K(K) {}
@@ -777,6 +791,7 @@ struct vk_device_struct {
vk_pipeline pipeline_mul_mat_vec_nc_f16_f32;
vk_pipeline pipeline_get_rows[GGML_TYPE_COUNT];
vk_pipeline pipeline_get_rows_f32[GGML_TYPE_COUNT];
vk_pipeline pipeline_get_rows_back_f32;
vk_pipeline pipeline_acc_f32;
vk_pipeline pipeline_set_f32;
@@ -801,14 +816,10 @@ struct vk_device_struct {
vk_pipeline pipeline_concat_i8, pipeline_concat_i16, pipeline_concat_i32, pipeline_concat_i64;
vk_pipeline pipeline_upscale_nearest_f32, pipeline_upscale_bilinear_f32, pipeline_upscale_bicubic_f32, pipeline_upscale_bilinear_antialias_f32;
vk_pipeline pipeline_scale_f32;
vk_pipeline pipeline_sqr_f32;
vk_pipeline pipeline_sqrt_f32;
vk_pipeline pipeline_sin_f32;
vk_pipeline pipeline_cos_f32;
vk_pipeline pipeline_log[2];
vk_pipeline pipeline_tri[2];
vk_pipeline pipeline_diag[2];
vk_pipeline pipeline_clamp_f32;
vk_pipeline pipeline_clamp[2];
vk_pipeline pipeline_pad_f32;
vk_pipeline pipeline_roll_f32;
vk_pipeline pipeline_repeat_i32, pipeline_repeat_back_f32;
@@ -840,6 +851,10 @@ struct vk_device_struct {
vk_pipeline pipeline_gelu_quick[2];
vk_pipeline pipeline_silu[2];
vk_pipeline pipeline_relu[2];
vk_pipeline pipeline_sqr[2];
vk_pipeline pipeline_sqrt[2];
vk_pipeline pipeline_sin[2];
vk_pipeline pipeline_cos[2];
vk_pipeline pipeline_xielu[2];
vk_pipeline pipeline_neg[2];
vk_pipeline pipeline_tanh[2];
@@ -871,7 +886,7 @@ struct vk_device_struct {
vk_pipeline pipeline_geglu_erf[2];
vk_pipeline pipeline_geglu_quick[2];
vk_pipeline pipeline_leaky_relu_f32;
vk_pipeline pipeline_leaky_relu[2];
vk_pipeline pipeline_silu_back_f32;
vk_pipeline pipeline_diag_mask_inf_f32;
vk_pipeline pipeline_soft_max_f32, pipeline_soft_max_f32_f16;
@@ -924,6 +939,8 @@ struct vk_device_struct {
std::map<vk_conv2d_pipeline_state, vk_pipeline> pipeline_conv2d_f16_f32[CONV_SHAPE_COUNT];
std::map<vk_conv2d_pipeline_state, vk_pipeline> pipeline_conv_transpose_2d_f32[CONV_SHAPE_COUNT];
std::map<vk_conv2d_pipeline_state, vk_pipeline> pipeline_conv_transpose_2d_f16_f32[CONV_SHAPE_COUNT];
std::map<vk_conv3d_pipeline_state, vk_pipeline> pipeline_conv3d_f32[CONV_SHAPE_COUNT];
std::map<vk_conv3d_pipeline_state, vk_pipeline> pipeline_conv3d_f16_f32[CONV_SHAPE_COUNT];
vk_pipeline pipeline_conv2d_dw_whcn_f32, pipeline_conv2d_dw_whcn_f16_f32;
vk_pipeline pipeline_conv2d_dw_cwhn_f32, pipeline_conv2d_dw_cwhn_f16_f32;
@@ -1669,6 +1686,41 @@ template <> void init_pushconst_fastdiv(vk_op_conv2d_push_constants &p) {
init_fastdiv_values(p.OW*p.OH, p.OWOHmp, p.OWOHL);
}
struct vk_op_conv3d_push_constants {
uint32_t OC;
uint32_t IC;
uint32_t N;
uint32_t IW;
uint32_t IH;
uint32_t ID;
uint32_t OW;
uint32_t OH;
uint32_t OD;
uint32_t nb01;
uint32_t nb02;
uint32_t nb03;
uint32_t nb11;
uint32_t nb12;
uint32_t nb13;
uint32_t nb1;
uint32_t nb2;
uint32_t nb3;
uint32_t OWmp; uint32_t OWL;
uint32_t OWOHmp; uint32_t OWOHL;
uint32_t OWOHODmp; uint32_t OWOHODL;
};
template <> void init_pushconst_fastdiv(vk_op_conv3d_push_constants &p) {
init_fastdiv_values(p.OW, p.OWmp, p.OWL);
init_fastdiv_values(p.OW*p.OH, p.OWOHmp, p.OWOHL);
init_fastdiv_values(p.OW*p.OH*p.OD, p.OWOHODmp, p.OWOHODL);
}
struct vk_op_conv2d_dw_push_constants {
uint32_t ne;
uint32_t batches;
@@ -4074,19 +4126,35 @@ static void ggml_vk_load_shaders(vk_device& device, vk_pipeline requested) {
}
#endif
auto const &ggml_vk_mul_mm_spec = [](std::vector<uint32_t> spec, bool aligned) {
spec.push_back(aligned ? 1u : 0u);
return spec;
};
const int mul_mat_id_param_count = 5;
#if defined(VK_NV_cooperative_matrix2) && defined(GGML_VULKAN_COOPMAT2_GLSLC_SUPPORT)
if (device->coopmat2) {
auto const &ggml_vk_mul_mm_cm2_spec = [](std::vector<uint32_t> spec, bool aligned, bool mul_mat_id) {
if (mul_mat_id && spec.size() > 5) {
spec.insert(spec.begin() + 5, aligned ? 1u : 0u);
} else {
spec.push_back(aligned ? 1u : 0u);
}
if (mul_mat_id && spec.size() == 6) {
spec.push_back(32);
}
return spec;
};
// Create 6 variants, {s,m,l}x{unaligned,aligned}
#define CREATE_MM(PIPELINE_NAME, NAMELC, F16ACC, WG_DENOMS, WARPTILE, PUSHCONST, PARAMCOUNT) \
ggml_vk_create_pipeline(device, device-> PIPELINE_NAME ->l, #NAMELC #F16ACC "_l", NAMELC ## F16ACC ## _cm2_len, NAMELC ## F16ACC ## _cm2_data, "main", PARAMCOUNT, sizeof(PUSHCONST), l_ ## WG_DENOMS, l_ ## WARPTILE, 1, true); \
ggml_vk_create_pipeline(device, device-> PIPELINE_NAME ->m, #NAMELC #F16ACC "_m", NAMELC ## F16ACC ## _cm2_len, NAMELC ## F16ACC ## _cm2_data, "main", PARAMCOUNT, sizeof(PUSHCONST), m_ ## WG_DENOMS, m_ ## WARPTILE, 1, true); \
ggml_vk_create_pipeline(device, device-> PIPELINE_NAME ->s, #NAMELC #F16ACC "_s", NAMELC ## F16ACC ## _cm2_len, NAMELC ## F16ACC ## _cm2_data, "main", PARAMCOUNT, sizeof(PUSHCONST), s_ ## WG_DENOMS, s_ ## WARPTILE, 1, true); \
ggml_vk_create_pipeline(device, device-> PIPELINE_NAME ->a_l, #NAMELC #F16ACC "_aligned_l", NAMELC ## _aligned ## F16ACC ## _cm2_len, NAMELC ## _aligned ## F16ACC ## _cm2_data, "main", PARAMCOUNT, sizeof(PUSHCONST), l_ ## WG_DENOMS, l_ ## WARPTILE, l_align, true); \
ggml_vk_create_pipeline(device, device-> PIPELINE_NAME ->a_m, #NAMELC #F16ACC "_aligned_m", NAMELC ## _aligned ## F16ACC ## _cm2_len, NAMELC ## _aligned ## F16ACC ## _cm2_data, "main", PARAMCOUNT, sizeof(PUSHCONST), m_ ## WG_DENOMS, m_ ## WARPTILE, m_align, true); \
ggml_vk_create_pipeline(device, device-> PIPELINE_NAME ->a_s, #NAMELC #F16ACC "_aligned_s", NAMELC ## _aligned ## F16ACC ## _cm2_len, NAMELC ## _aligned ## F16ACC ## _cm2_data, "main", PARAMCOUNT, sizeof(PUSHCONST), s_ ## WG_DENOMS, s_ ## WARPTILE, s_align, true); \
ggml_vk_create_pipeline(device, device-> PIPELINE_NAME ->l, #NAMELC #F16ACC "_l", NAMELC ## F16ACC ## _cm2_len, NAMELC ## F16ACC ## _cm2_data, "main", PARAMCOUNT, sizeof(PUSHCONST), l_ ## WG_DENOMS, ggml_vk_mul_mm_cm2_spec(l_ ## WARPTILE, false, PARAMCOUNT == mul_mat_id_param_count), 1, true); \
ggml_vk_create_pipeline(device, device-> PIPELINE_NAME ->m, #NAMELC #F16ACC "_m", NAMELC ## F16ACC ## _cm2_len, NAMELC ## F16ACC ## _cm2_data, "main", PARAMCOUNT, sizeof(PUSHCONST), m_ ## WG_DENOMS, ggml_vk_mul_mm_cm2_spec(m_ ## WARPTILE, false, PARAMCOUNT == mul_mat_id_param_count), 1, true); \
ggml_vk_create_pipeline(device, device-> PIPELINE_NAME ->s, #NAMELC #F16ACC "_s", NAMELC ## F16ACC ## _cm2_len, NAMELC ## F16ACC ## _cm2_data, "main", PARAMCOUNT, sizeof(PUSHCONST), s_ ## WG_DENOMS, ggml_vk_mul_mm_cm2_spec(s_ ## WARPTILE, false, PARAMCOUNT == mul_mat_id_param_count), 1, true); \
ggml_vk_create_pipeline(device, device-> PIPELINE_NAME ->a_l, #NAMELC #F16ACC "_aligned_l", NAMELC ## F16ACC ## _cm2_len, NAMELC ## F16ACC ## _cm2_data, "main", PARAMCOUNT, sizeof(PUSHCONST), l_ ## WG_DENOMS, ggml_vk_mul_mm_cm2_spec(l_ ## WARPTILE, true, PARAMCOUNT == mul_mat_id_param_count), l_align, true); \
ggml_vk_create_pipeline(device, device-> PIPELINE_NAME ->a_m, #NAMELC #F16ACC "_aligned_m", NAMELC ## F16ACC ## _cm2_len, NAMELC ## F16ACC ## _cm2_data, "main", PARAMCOUNT, sizeof(PUSHCONST), m_ ## WG_DENOMS, ggml_vk_mul_mm_cm2_spec(m_ ## WARPTILE, true, PARAMCOUNT == mul_mat_id_param_count), m_align, true); \
ggml_vk_create_pipeline(device, device-> PIPELINE_NAME ->a_s, #NAMELC #F16ACC "_aligned_s", NAMELC ## F16ACC ## _cm2_len, NAMELC ## F16ACC ## _cm2_data, "main", PARAMCOUNT, sizeof(PUSHCONST), s_ ## WG_DENOMS, ggml_vk_mul_mm_cm2_spec(s_ ## WARPTILE, true, PARAMCOUNT == mul_mat_id_param_count), s_align, true); \
// Create 2 variants, {f16,f32} accumulator
#define CREATE_MM2(PIPELINE_NAME, NAMELC, WG_DENOMS, WARPTILE, PUSHCONST, PARAMCOUNT) \
@@ -4161,17 +4229,17 @@ static void ggml_vk_load_shaders(vk_device& device, vk_pipeline requested) {
// Create 6 variants, {s,m,l}x{unaligned,aligned}
#define CREATE_MM(TYPE, PIPELINE_NAME, NAMELC, F16ACC, WG_DENOMS, WARPTILE, PUSHCONST, PARAMCOUNT, ID) \
if (device->mul_mat ## ID ## _l[TYPE]) \
ggml_vk_create_pipeline(device, device-> PIPELINE_NAME ->l, #NAMELC #F16ACC "_l", NAMELC ## F16ACC ## _cm1_len, NAMELC ## F16ACC ## _cm1_data, "main", PARAMCOUNT, sizeof(PUSHCONST), l_ ## WG_DENOMS, l_ ## WARPTILE, 1, false, true); \
ggml_vk_create_pipeline(device, device-> PIPELINE_NAME ->l, #NAMELC #F16ACC "_l", NAMELC ## F16ACC ## _cm1_len, NAMELC ## F16ACC ## _cm1_data, "main", PARAMCOUNT, sizeof(PUSHCONST), l_ ## WG_DENOMS, ggml_vk_mul_mm_spec(l_ ## WARPTILE, false), 1, false, true); \
if (device->mul_mat ## ID ## _m[TYPE]) \
ggml_vk_create_pipeline(device, device-> PIPELINE_NAME ->m, #NAMELC #F16ACC "_m", NAMELC ## F16ACC ## _cm1_len, NAMELC ## F16ACC ## _cm1_data, "main", PARAMCOUNT, sizeof(PUSHCONST), m_ ## WG_DENOMS, m_ ## WARPTILE, 1, false, true); \
ggml_vk_create_pipeline(device, device-> PIPELINE_NAME ->m, #NAMELC #F16ACC "_m", NAMELC ## F16ACC ## _cm1_len, NAMELC ## F16ACC ## _cm1_data, "main", PARAMCOUNT, sizeof(PUSHCONST), m_ ## WG_DENOMS, ggml_vk_mul_mm_spec(m_ ## WARPTILE, false), 1, false, true); \
if (device->mul_mat ## ID ## _s[TYPE]) \
ggml_vk_create_pipeline(device, device-> PIPELINE_NAME ->s, #NAMELC #F16ACC "_s", NAMELC ## F16ACC ## _cm1_len, NAMELC ## F16ACC ## _cm1_data, "main", PARAMCOUNT, sizeof(PUSHCONST), s_ ## WG_DENOMS, s_ ## WARPTILE, 1, false, true); \
ggml_vk_create_pipeline(device, device-> PIPELINE_NAME ->s, #NAMELC #F16ACC "_s", NAMELC ## F16ACC ## _cm1_len, NAMELC ## F16ACC ## _cm1_data, "main", PARAMCOUNT, sizeof(PUSHCONST), s_ ## WG_DENOMS, ggml_vk_mul_mm_spec(s_ ## WARPTILE, false), 1, false, true); \
if (device->mul_mat ## ID ## _l[TYPE]) \
ggml_vk_create_pipeline(device, device-> PIPELINE_NAME ->a_l, #NAMELC #F16ACC "_aligned_l", NAMELC ## _aligned ## F16ACC ## _cm1_len, NAMELC ## _aligned ## F16ACC ## _cm1_data, "main", PARAMCOUNT, sizeof(PUSHCONST), l_ ## WG_DENOMS, l_ ## WARPTILE, l_align, false, true); \
ggml_vk_create_pipeline(device, device-> PIPELINE_NAME ->a_l, #NAMELC #F16ACC "_aligned_l", NAMELC ## F16ACC ## _cm1_len, NAMELC ## F16ACC ## _cm1_data, "main", PARAMCOUNT, sizeof(PUSHCONST), l_ ## WG_DENOMS, ggml_vk_mul_mm_spec(l_ ## WARPTILE, true), l_align, false, true); \
if (device->mul_mat ## ID ## _m[TYPE]) \
ggml_vk_create_pipeline(device, device-> PIPELINE_NAME ->a_m, #NAMELC #F16ACC "_aligned_m", NAMELC ## _aligned ## F16ACC ## _cm1_len, NAMELC ## _aligned ## F16ACC ## _cm1_data, "main", PARAMCOUNT, sizeof(PUSHCONST), m_ ## WG_DENOMS, m_ ## WARPTILE, m_align, false, true); \
ggml_vk_create_pipeline(device, device-> PIPELINE_NAME ->a_m, #NAMELC #F16ACC "_aligned_m", NAMELC ## F16ACC ## _cm1_len, NAMELC ## F16ACC ## _cm1_data, "main", PARAMCOUNT, sizeof(PUSHCONST), m_ ## WG_DENOMS, ggml_vk_mul_mm_spec(m_ ## WARPTILE, true), m_align, false, true); \
if (device->mul_mat ## ID ## _s[TYPE]) \
ggml_vk_create_pipeline(device, device-> PIPELINE_NAME ->a_s, #NAMELC #F16ACC "_aligned_s", NAMELC ## _aligned ## F16ACC ## _cm1_len, NAMELC ## _aligned ## F16ACC ## _cm1_data, "main", PARAMCOUNT, sizeof(PUSHCONST), s_ ## WG_DENOMS, s_ ## WARPTILE, s_align, false, true); \
ggml_vk_create_pipeline(device, device-> PIPELINE_NAME ->a_s, #NAMELC #F16ACC "_aligned_s", NAMELC ## F16ACC ## _cm1_len, NAMELC ## F16ACC ## _cm1_data, "main", PARAMCOUNT, sizeof(PUSHCONST), s_ ## WG_DENOMS, ggml_vk_mul_mm_spec(s_ ## WARPTILE, true), s_align, false, true); \
// Create 2 variants, {f16,f32} accumulator
#define CREATE_MM2(TYPE, PIPELINE_NAME, NAMELC, WG_DENOMS, WARPTILE, PUSHCONST, PARAMCOUNT, ID) \
@@ -4284,32 +4352,32 @@ static void ggml_vk_load_shaders(vk_device& device, vk_pipeline requested) {
// Selects dot2 SPIR-V variant at runtime when device->dot2_f16 is true
#define CREATE_MM(TYPE, PIPELINE_NAME, NAMELC, F16ACC, WG_DENOMS, WARPTILE, PUSHCONST, PARAMCOUNT, ID, REQSUBGROUPSIZE) \
if (device->mul_mat ## ID ## _l[TYPE]) \
ggml_vk_create_pipeline(device, device-> PIPELINE_NAME ->l, #NAMELC #F16ACC "_l", (device->dot2_f16 ? NAMELC ## _dot2 ## F16ACC ## _len : NAMELC ## F16ACC ## _len), (device->dot2_f16 ? NAMELC ## _dot2 ## F16ACC ## _data : NAMELC ## F16ACC ## _data), "main", PARAMCOUNT, sizeof(PUSHCONST), l_ ## WG_DENOMS, l_ ## WARPTILE, 1, false, REQSUBGROUPSIZE > 0, REQSUBGROUPSIZE); \
ggml_vk_create_pipeline(device, device-> PIPELINE_NAME ->l, #NAMELC #F16ACC "_l", (device->dot2_f16 ? NAMELC ## _dot2 ## F16ACC ## _len : NAMELC ## F16ACC ## _len), (device->dot2_f16 ? NAMELC ## _dot2 ## F16ACC ## _data : NAMELC ## F16ACC ## _data), "main", PARAMCOUNT, sizeof(PUSHCONST), l_ ## WG_DENOMS, ggml_vk_mul_mm_spec(l_ ## WARPTILE, false), 1, false, REQSUBGROUPSIZE > 0, REQSUBGROUPSIZE); \
if (device->mul_mat ## ID ## _m[TYPE]) \
ggml_vk_create_pipeline(device, device-> PIPELINE_NAME ->m, #NAMELC #F16ACC "_m", (device->dot2_f16 ? NAMELC ## _dot2 ## F16ACC ## _len : NAMELC ## F16ACC ## _len), (device->dot2_f16 ? NAMELC ## _dot2 ## F16ACC ## _data : NAMELC ## F16ACC ## _data), "main", PARAMCOUNT, sizeof(PUSHCONST), m_ ## WG_DENOMS, m_ ## WARPTILE, 1, false, REQSUBGROUPSIZE > 0, REQSUBGROUPSIZE); \
ggml_vk_create_pipeline(device, device-> PIPELINE_NAME ->m, #NAMELC #F16ACC "_m", (device->dot2_f16 ? NAMELC ## _dot2 ## F16ACC ## _len : NAMELC ## F16ACC ## _len), (device->dot2_f16 ? NAMELC ## _dot2 ## F16ACC ## _data : NAMELC ## F16ACC ## _data), "main", PARAMCOUNT, sizeof(PUSHCONST), m_ ## WG_DENOMS, ggml_vk_mul_mm_spec(m_ ## WARPTILE, false), 1, false, REQSUBGROUPSIZE > 0, REQSUBGROUPSIZE); \
if (device->mul_mat ## ID ## _s[TYPE]) \
ggml_vk_create_pipeline(device, device-> PIPELINE_NAME ->s, #NAMELC #F16ACC "_s", (device->dot2_f16 ? NAMELC ## _dot2 ## F16ACC ## _len : NAMELC ## F16ACC ## _len), (device->dot2_f16 ? NAMELC ## _dot2 ## F16ACC ## _data : NAMELC ## F16ACC ## _data), "main", PARAMCOUNT, sizeof(PUSHCONST), s_ ## WG_DENOMS, s_ ## WARPTILE, 1, false, REQSUBGROUPSIZE > 0, REQSUBGROUPSIZE); \
ggml_vk_create_pipeline(device, device-> PIPELINE_NAME ->s, #NAMELC #F16ACC "_s", (device->dot2_f16 ? NAMELC ## _dot2 ## F16ACC ## _len : NAMELC ## F16ACC ## _len), (device->dot2_f16 ? NAMELC ## _dot2 ## F16ACC ## _data : NAMELC ## F16ACC ## _data), "main", PARAMCOUNT, sizeof(PUSHCONST), s_ ## WG_DENOMS, ggml_vk_mul_mm_spec(s_ ## WARPTILE, false), 1, false, REQSUBGROUPSIZE > 0, REQSUBGROUPSIZE); \
if (device->mul_mat ## ID ## _l[TYPE]) \
ggml_vk_create_pipeline(device, device-> PIPELINE_NAME ->a_l, #NAMELC #F16ACC "_aligned_l", (device->dot2_f16 ? NAMELC ## _dot2_aligned ## F16ACC ## _len : NAMELC ## _aligned ## F16ACC ## _len), (device->dot2_f16 ? NAMELC ## _dot2_aligned ## F16ACC ## _data : NAMELC ## _aligned ## F16ACC ## _data), "main", PARAMCOUNT, sizeof(PUSHCONST), l_ ## WG_DENOMS, l_ ## WARPTILE, l_align, false, REQSUBGROUPSIZE > 0, REQSUBGROUPSIZE); \
ggml_vk_create_pipeline(device, device-> PIPELINE_NAME ->a_l, #NAMELC #F16ACC "_aligned_l", (device->dot2_f16 ? NAMELC ## _dot2 ## F16ACC ## _len : NAMELC ## F16ACC ## _len), (device->dot2_f16 ? NAMELC ## _dot2 ## F16ACC ## _data : NAMELC ## F16ACC ## _data), "main", PARAMCOUNT, sizeof(PUSHCONST), l_ ## WG_DENOMS, ggml_vk_mul_mm_spec(l_ ## WARPTILE, true), l_align, false, REQSUBGROUPSIZE > 0, REQSUBGROUPSIZE); \
if (device->mul_mat ## ID ## _m[TYPE]) \
ggml_vk_create_pipeline(device, device-> PIPELINE_NAME ->a_m, #NAMELC #F16ACC "_aligned_m", (device->dot2_f16 ? NAMELC ## _dot2_aligned ## F16ACC ## _len : NAMELC ## _aligned ## F16ACC ## _len), (device->dot2_f16 ? NAMELC ## _dot2_aligned ## F16ACC ## _data : NAMELC ## _aligned ## F16ACC ## _data), "main", PARAMCOUNT, sizeof(PUSHCONST), m_ ## WG_DENOMS, m_ ## WARPTILE, m_align, false, REQSUBGROUPSIZE > 0, REQSUBGROUPSIZE); \
ggml_vk_create_pipeline(device, device-> PIPELINE_NAME ->a_m, #NAMELC #F16ACC "_aligned_m", (device->dot2_f16 ? NAMELC ## _dot2 ## F16ACC ## _len : NAMELC ## F16ACC ## _len), (device->dot2_f16 ? NAMELC ## _dot2 ## F16ACC ## _data : NAMELC ## F16ACC ## _data), "main", PARAMCOUNT, sizeof(PUSHCONST), m_ ## WG_DENOMS, ggml_vk_mul_mm_spec(m_ ## WARPTILE, true), m_align, false, REQSUBGROUPSIZE > 0, REQSUBGROUPSIZE); \
if (device->mul_mat ## ID ## _s[TYPE]) \
ggml_vk_create_pipeline(device, device-> PIPELINE_NAME ->a_s, #NAMELC #F16ACC "_aligned_s", (device->dot2_f16 ? NAMELC ## _dot2_aligned ## F16ACC ## _len : NAMELC ## _aligned ## F16ACC ## _len), (device->dot2_f16 ? NAMELC ## _dot2_aligned ## F16ACC ## _data : NAMELC ## _aligned ## F16ACC ## _data), "main", PARAMCOUNT, sizeof(PUSHCONST), s_ ## WG_DENOMS, s_ ## WARPTILE, s_align, false, REQSUBGROUPSIZE > 0, REQSUBGROUPSIZE); \
ggml_vk_create_pipeline(device, device-> PIPELINE_NAME ->a_s, #NAMELC #F16ACC "_aligned_s", (device->dot2_f16 ? NAMELC ## _dot2 ## F16ACC ## _len : NAMELC ## F16ACC ## _len), (device->dot2_f16 ? NAMELC ## _dot2 ## F16ACC ## _data : NAMELC ## F16ACC ## _data), "main", PARAMCOUNT, sizeof(PUSHCONST), s_ ## WG_DENOMS, ggml_vk_mul_mm_spec(s_ ## WARPTILE, true), s_align, false, REQSUBGROUPSIZE > 0, REQSUBGROUPSIZE); \
// bf16 scalar path promotes to f32, no dot2 variant
#define CREATE_MM_NODOT2(TYPE, PIPELINE_NAME, NAMELC, F16ACC, WG_DENOMS, WARPTILE, PUSHCONST, PARAMCOUNT, ID, REQSUBGROUPSIZE) \
if (device->mul_mat ## ID ## _l[TYPE]) \
ggml_vk_create_pipeline(device, device-> PIPELINE_NAME ->l, #NAMELC #F16ACC "_l", NAMELC ## F16ACC ## _len, NAMELC ## F16ACC ## _data, "main", PARAMCOUNT, sizeof(PUSHCONST), l_ ## WG_DENOMS, l_ ## WARPTILE, 1, false, REQSUBGROUPSIZE > 0, REQSUBGROUPSIZE); \
ggml_vk_create_pipeline(device, device-> PIPELINE_NAME ->l, #NAMELC #F16ACC "_l", NAMELC ## F16ACC ## _len, NAMELC ## F16ACC ## _data, "main", PARAMCOUNT, sizeof(PUSHCONST), l_ ## WG_DENOMS, ggml_vk_mul_mm_spec(l_ ## WARPTILE, false), 1, false, REQSUBGROUPSIZE > 0, REQSUBGROUPSIZE); \
if (device->mul_mat ## ID ## _m[TYPE]) \
ggml_vk_create_pipeline(device, device-> PIPELINE_NAME ->m, #NAMELC #F16ACC "_m", NAMELC ## F16ACC ## _len, NAMELC ## F16ACC ## _data, "main", PARAMCOUNT, sizeof(PUSHCONST), m_ ## WG_DENOMS, m_ ## WARPTILE, 1, false, REQSUBGROUPSIZE > 0, REQSUBGROUPSIZE); \
ggml_vk_create_pipeline(device, device-> PIPELINE_NAME ->m, #NAMELC #F16ACC "_m", NAMELC ## F16ACC ## _len, NAMELC ## F16ACC ## _data, "main", PARAMCOUNT, sizeof(PUSHCONST), m_ ## WG_DENOMS, ggml_vk_mul_mm_spec(m_ ## WARPTILE, false), 1, false, REQSUBGROUPSIZE > 0, REQSUBGROUPSIZE); \
if (device->mul_mat ## ID ## _s[TYPE]) \
ggml_vk_create_pipeline(device, device-> PIPELINE_NAME ->s, #NAMELC #F16ACC "_s", NAMELC ## F16ACC ## _len, NAMELC ## F16ACC ## _data, "main", PARAMCOUNT, sizeof(PUSHCONST), s_ ## WG_DENOMS, s_ ## WARPTILE, 1, false, REQSUBGROUPSIZE > 0, REQSUBGROUPSIZE); \
ggml_vk_create_pipeline(device, device-> PIPELINE_NAME ->s, #NAMELC #F16ACC "_s", NAMELC ## F16ACC ## _len, NAMELC ## F16ACC ## _data, "main", PARAMCOUNT, sizeof(PUSHCONST), s_ ## WG_DENOMS, ggml_vk_mul_mm_spec(s_ ## WARPTILE, false), 1, false, REQSUBGROUPSIZE > 0, REQSUBGROUPSIZE); \
if (device->mul_mat ## ID ## _l[TYPE]) \
ggml_vk_create_pipeline(device, device-> PIPELINE_NAME ->a_l, #NAMELC #F16ACC "_aligned_l", NAMELC ## _aligned ## F16ACC ## _len, NAMELC ## _aligned ## F16ACC ## _data, "main", PARAMCOUNT, sizeof(PUSHCONST), l_ ## WG_DENOMS, l_ ## WARPTILE, l_align, false, REQSUBGROUPSIZE > 0, REQSUBGROUPSIZE); \
ggml_vk_create_pipeline(device, device-> PIPELINE_NAME ->a_l, #NAMELC #F16ACC "_aligned_l", NAMELC ## F16ACC ## _len, NAMELC ## F16ACC ## _data, "main", PARAMCOUNT, sizeof(PUSHCONST), l_ ## WG_DENOMS, ggml_vk_mul_mm_spec(l_ ## WARPTILE, true), l_align, false, REQSUBGROUPSIZE > 0, REQSUBGROUPSIZE); \
if (device->mul_mat ## ID ## _m[TYPE]) \
ggml_vk_create_pipeline(device, device-> PIPELINE_NAME ->a_m, #NAMELC #F16ACC "_aligned_m", NAMELC ## _aligned ## F16ACC ## _len, NAMELC ## _aligned ## F16ACC ## _data, "main", PARAMCOUNT, sizeof(PUSHCONST), m_ ## WG_DENOMS, m_ ## WARPTILE, m_align, false, REQSUBGROUPSIZE > 0, REQSUBGROUPSIZE); \
ggml_vk_create_pipeline(device, device-> PIPELINE_NAME ->a_m, #NAMELC #F16ACC "_aligned_m", NAMELC ## F16ACC ## _len, NAMELC ## F16ACC ## _data, "main", PARAMCOUNT, sizeof(PUSHCONST), m_ ## WG_DENOMS, ggml_vk_mul_mm_spec(m_ ## WARPTILE, true), m_align, false, REQSUBGROUPSIZE > 0, REQSUBGROUPSIZE); \
if (device->mul_mat ## ID ## _s[TYPE]) \
ggml_vk_create_pipeline(device, device-> PIPELINE_NAME ->a_s, #NAMELC #F16ACC "_aligned_s", NAMELC ## _aligned ## F16ACC ## _len, NAMELC ## _aligned ## F16ACC ## _data, "main", PARAMCOUNT, sizeof(PUSHCONST), s_ ## WG_DENOMS, s_ ## WARPTILE, s_align, false, REQSUBGROUPSIZE > 0, REQSUBGROUPSIZE); \
ggml_vk_create_pipeline(device, device-> PIPELINE_NAME ->a_s, #NAMELC #F16ACC "_aligned_s", NAMELC ## F16ACC ## _len, NAMELC ## F16ACC ## _data, "main", PARAMCOUNT, sizeof(PUSHCONST), s_ ## WG_DENOMS, ggml_vk_mul_mm_spec(s_ ## WARPTILE, true), s_align, false, REQSUBGROUPSIZE > 0, REQSUBGROUPSIZE); \
#define CREATE_MMQ(TYPE, PIPELINE_NAME, NAMELC, WG_DENOMS, WARPTILE, PUSHCONST, PARAMCOUNT, ID, REQSUBGROUPSIZE) \
if (device->mul_mat ## ID ## _l_int[TYPE]) { \
@@ -4474,17 +4542,17 @@ static void ggml_vk_load_shaders(vk_device& device, vk_pipeline requested) {
// Create 6 variants, {s,m,l}x{unaligned,aligned}
#define CREATE_MM(TYPE, PIPELINE_NAME, NAMELC, F16ACC, WG_DENOMS, WARPTILE, PUSHCONST, PARAMCOUNT, ID, REQSUBGROUPSIZE) \
if (device->mul_mat ## ID ## _l[TYPE]) \
ggml_vk_create_pipeline(device, device-> PIPELINE_NAME ->l, #NAMELC #F16ACC "_l", NAMELC ## F16ACC ## _fp32_len, NAMELC ## F16ACC ## _fp32_data, "main", PARAMCOUNT, sizeof(PUSHCONST), l_ ## WG_DENOMS, l_ ## WARPTILE, 1, false, REQSUBGROUPSIZE > 0, REQSUBGROUPSIZE); \
ggml_vk_create_pipeline(device, device-> PIPELINE_NAME ->l, #NAMELC #F16ACC "_l", NAMELC ## F16ACC ## _fp32_len, NAMELC ## F16ACC ## _fp32_data, "main", PARAMCOUNT, sizeof(PUSHCONST), l_ ## WG_DENOMS, ggml_vk_mul_mm_spec(l_ ## WARPTILE, false), 1, false, REQSUBGROUPSIZE > 0, REQSUBGROUPSIZE); \
if (device->mul_mat ## ID ## _m[TYPE]) \
ggml_vk_create_pipeline(device, device-> PIPELINE_NAME ->m, #NAMELC #F16ACC "_m", NAMELC ## F16ACC ## _fp32_len, NAMELC ## F16ACC ## _fp32_data, "main", PARAMCOUNT, sizeof(PUSHCONST), m_ ## WG_DENOMS, m_ ## WARPTILE, 1, false, REQSUBGROUPSIZE > 0, REQSUBGROUPSIZE); \
ggml_vk_create_pipeline(device, device-> PIPELINE_NAME ->m, #NAMELC #F16ACC "_m", NAMELC ## F16ACC ## _fp32_len, NAMELC ## F16ACC ## _fp32_data, "main", PARAMCOUNT, sizeof(PUSHCONST), m_ ## WG_DENOMS, ggml_vk_mul_mm_spec(m_ ## WARPTILE, false), 1, false, REQSUBGROUPSIZE > 0, REQSUBGROUPSIZE); \
if (device->mul_mat ## ID ## _s[TYPE]) \
ggml_vk_create_pipeline(device, device-> PIPELINE_NAME ->s, #NAMELC #F16ACC "_s", NAMELC ## F16ACC ## _fp32_len, NAMELC ## F16ACC ## _fp32_data, "main", PARAMCOUNT, sizeof(PUSHCONST), s_ ## WG_DENOMS, s_ ## WARPTILE, 1, false, REQSUBGROUPSIZE > 0, REQSUBGROUPSIZE); \
ggml_vk_create_pipeline(device, device-> PIPELINE_NAME ->s, #NAMELC #F16ACC "_s", NAMELC ## F16ACC ## _fp32_len, NAMELC ## F16ACC ## _fp32_data, "main", PARAMCOUNT, sizeof(PUSHCONST), s_ ## WG_DENOMS, ggml_vk_mul_mm_spec(s_ ## WARPTILE, false), 1, false, REQSUBGROUPSIZE > 0, REQSUBGROUPSIZE); \
if (device->mul_mat ## ID ## _l[TYPE]) \
ggml_vk_create_pipeline(device, device-> PIPELINE_NAME ->a_l, #NAMELC #F16ACC "_aligned_l", NAMELC ## _aligned ## F16ACC ## _fp32_len, NAMELC ## _aligned ## F16ACC ## _fp32_data, "main", PARAMCOUNT, sizeof(PUSHCONST), l_ ## WG_DENOMS, l_ ## WARPTILE, l_align, false, REQSUBGROUPSIZE > 0, REQSUBGROUPSIZE); \
ggml_vk_create_pipeline(device, device-> PIPELINE_NAME ->a_l, #NAMELC #F16ACC "_aligned_l", NAMELC ## F16ACC ## _fp32_len, NAMELC ## F16ACC ## _fp32_data, "main", PARAMCOUNT, sizeof(PUSHCONST), l_ ## WG_DENOMS, ggml_vk_mul_mm_spec(l_ ## WARPTILE, true), l_align, false, REQSUBGROUPSIZE > 0, REQSUBGROUPSIZE); \
if (device->mul_mat ## ID ## _m[TYPE]) \
ggml_vk_create_pipeline(device, device-> PIPELINE_NAME ->a_m, #NAMELC #F16ACC "_aligned_m", NAMELC ## _aligned ## F16ACC ## _fp32_len, NAMELC ## _aligned ## F16ACC ## _fp32_data, "main", PARAMCOUNT, sizeof(PUSHCONST), m_ ## WG_DENOMS, m_ ## WARPTILE, m_align, false, REQSUBGROUPSIZE > 0, REQSUBGROUPSIZE); \
ggml_vk_create_pipeline(device, device-> PIPELINE_NAME ->a_m, #NAMELC #F16ACC "_aligned_m", NAMELC ## F16ACC ## _fp32_len, NAMELC ## F16ACC ## _fp32_data, "main", PARAMCOUNT, sizeof(PUSHCONST), m_ ## WG_DENOMS, ggml_vk_mul_mm_spec(m_ ## WARPTILE, true), m_align, false, REQSUBGROUPSIZE > 0, REQSUBGROUPSIZE); \
if (device->mul_mat ## ID ## _s[TYPE]) \
ggml_vk_create_pipeline(device, device-> PIPELINE_NAME ->a_s, #NAMELC #F16ACC "_aligned_s", NAMELC ## _aligned ## F16ACC ## _fp32_len, NAMELC ## _aligned ## F16ACC ## _fp32_data, "main", PARAMCOUNT, sizeof(PUSHCONST), s_ ## WG_DENOMS, s_ ## WARPTILE, s_align, false, REQSUBGROUPSIZE > 0, REQSUBGROUPSIZE); \
ggml_vk_create_pipeline(device, device-> PIPELINE_NAME ->a_s, #NAMELC #F16ACC "_aligned_s", NAMELC ## F16ACC ## _fp32_len, NAMELC ## F16ACC ## _fp32_data, "main", PARAMCOUNT, sizeof(PUSHCONST), s_ ## WG_DENOMS, ggml_vk_mul_mm_spec(s_ ## WARPTILE, true), s_align, false, REQSUBGROUPSIZE > 0, REQSUBGROUPSIZE); \
#define CREATE_MMQ(TYPE, PIPELINE_NAME, NAMELC, WG_DENOMS, WARPTILE, PUSHCONST, PARAMCOUNT, ID) \
if (device->mul_mat ## ID ## _l_int[TYPE]) \
@@ -4879,6 +4947,7 @@ static void ggml_vk_load_shaders(vk_device& device, vk_pipeline requested) {
ggml_vk_create_pipeline(device, device->pipeline_get_rows_f32[GGML_TYPE_IQ4_NL], "get_rows_iq4_nl_f32", get_rows_iq4_nl_f32_len, get_rows_iq4_nl_f32_data, "main", 3, sizeof(vk_op_binary_push_constants), {1024, 1, 1}, {}, 1);
ggml_vk_create_pipeline(device, device->pipeline_get_rows_f32[GGML_TYPE_MXFP4], "get_rows_mxfp4_f32", get_rows_mxfp4_f32_len, get_rows_mxfp4_f32_data, "main", 3, sizeof(vk_op_binary_push_constants), {1024, 1, 1}, {}, 1);
ggml_vk_create_pipeline(device, device->pipeline_get_rows_f32[GGML_TYPE_NVFP4], "get_rows_nvfp4_f32", get_rows_nvfp4_f32_len, get_rows_nvfp4_f32_data, "main", 3, sizeof(vk_op_binary_push_constants), {1024, 1, 1}, {}, 1);
ggml_vk_create_pipeline(device, device->pipeline_get_rows_back_f32, "get_rows_back_f32", get_rows_back_f32_len, get_rows_back_f32_data, "main", 3, sizeof(vk_op_binary_push_constants), {256, 1, 1}, {}, 1, true);
ggml_vk_create_pipeline(device, device->pipeline_matmul_split_k_reduce, "split_k_reduce", split_k_reduce_len, split_k_reduce_data, "main", 2, 2 * sizeof(uint32_t), {256 * 4, 1, 1}, {}, 1);
ggml_vk_create_pipeline(device, device->pipeline_flash_attn_split_k_reduce, "fa_split_k_reduce", fa_split_k_reduce_len, fa_split_k_reduce_data, "main", 3, sizeof(vk_op_flash_attn_split_k_reduce_push_constants), {1, device->subgroup_size, 1}, {device->subgroup_size}, 1, true);
@@ -4903,7 +4972,7 @@ static void ggml_vk_load_shaders(vk_device& device, vk_pipeline requested) {
}
ggml_vk_create_pipeline(device, device->pipeline_mul_mat_vec_nc_f16_f32, "mul_mat_vec_nc_f16_f32", mul_mat_vec_nc_f16_f32_len, mul_mat_vec_nc_f16_f32_data, "main", mul_mat_vec_num_bindings, sizeof(vk_mat_vec_nc_push_constants), {1, 1, 1}, {}, 1);
ggml_vk_create_pipeline(device, device->pipeline_norm_f32, "norm_f32", norm_f32_len, norm_f32_data, "main", 2, sizeof(vk_op_push_constants), {1, 1, 1}, {}, 1);
ggml_vk_create_pipeline(device, device->pipeline_norm_f32, "norm_f32", norm_f32_len, norm_f32_data, "main", 2, sizeof(vk_op_unary_push_constants), {1, 1, 1}, {}, 1);
ggml_vk_create_pipeline(device, device->pipeline_group_norm_f32, "group_norm_f32", group_norm_f32_len, group_norm_f32_data, "main", 2, sizeof(vk_op_push_constants), {1, 1, 1}, {}, 1);
ggml_vk_create_pipeline(device, device->pipeline_rms_norm_f32, "rms_norm_f32", rms_norm_f32_len, rms_norm_f32_data, "main", 4, sizeof(vk_op_binary_push_constants), {1, 1, 1}, {0, 0}, 1, true);
@@ -5023,11 +5092,6 @@ static void ggml_vk_load_shaders(vk_device& device, vk_pipeline requested) {
ggml_vk_create_pipeline(device, device->pipeline_scale_f32, "scale_f32", scale_f32_len, scale_f32_data, "main", 2, sizeof(vk_op_unary_push_constants), {512, 1, 1}, {}, 1);
ggml_vk_create_pipeline(device, device->pipeline_sqr_f32, "sqr_f32", sqr_f32_len, sqr_f32_data, "main", 2, sizeof(vk_op_unary_push_constants), {512, 1, 1}, {}, 1);
ggml_vk_create_pipeline(device, device->pipeline_sqrt_f32, "sqrt_f32", sqrt_f32_len, sqrt_f32_data, "main", 2, sizeof(vk_op_unary_push_constants), {512, 1, 1}, {}, 1);
ggml_vk_create_pipeline(device, device->pipeline_sin_f32, "sin_f32", sin_f32_len, sin_f32_data, "main", 2, sizeof(vk_op_unary_push_constants), {512, 1, 1}, {}, 1);
ggml_vk_create_pipeline(device, device->pipeline_cos_f32, "cos_f32", cos_f32_len, cos_f32_data, "main", 2, sizeof(vk_op_unary_push_constants), {512, 1, 1}, {}, 1);
ggml_vk_create_pipeline(device, device->pipeline_log[0], "log_f32", log_f32_len, log_f32_data, "main", 2, sizeof(vk_op_unary_push_constants), {512, 1, 1}, {}, 1);
ggml_vk_create_pipeline(device, device->pipeline_log[1], "log_f16", log_f16_len, log_f16_data, "main", 2, sizeof(vk_op_unary_push_constants), {512, 1, 1}, {}, 1);
@@ -5037,8 +5101,6 @@ static void ggml_vk_load_shaders(vk_device& device, vk_pipeline requested) {
ggml_vk_create_pipeline(device, device->pipeline_diag[0], "diag_f32", diag_f32_len, diag_f32_data, "main", 2, sizeof(vk_op_unary_push_constants), {512, 1, 1}, {}, 1);
ggml_vk_create_pipeline(device, device->pipeline_diag[1], "diag_f16", diag_f16_len, diag_f16_data, "main", 2, sizeof(vk_op_unary_push_constants), {512, 1, 1}, {}, 1);
ggml_vk_create_pipeline(device, device->pipeline_clamp_f32, "clamp_f32", clamp_f32_len, clamp_f32_data, "main", 2, sizeof(vk_op_unary_push_constants), {512, 1, 1}, {}, 1);
ggml_vk_create_pipeline(device, device->pipeline_pad_f32, "pad_f32", pad_f32_len, pad_f32_data, "main", 2, sizeof(vk_op_pad_push_constants), {512, 1, 1}, {}, 1);
ggml_vk_create_pipeline(device, device->pipeline_roll_f32, "roll_f32", roll_f32_len, roll_f32_data, "main", 2, sizeof(vk_op_unary_push_constants), {512, 1, 1}, {}, 1);
@@ -5058,6 +5120,12 @@ static void ggml_vk_load_shaders(vk_device& device, vk_pipeline requested) {
CREATE_UNARY(gelu_quick)
CREATE_UNARY(silu)
CREATE_UNARY(relu)
CREATE_UNARY(sqr)
CREATE_UNARY(sqrt)
CREATE_UNARY(sin)
CREATE_UNARY(cos)
CREATE_UNARY(clamp)
CREATE_UNARY(leaky_relu)
CREATE_UNARY(xielu)
CREATE_UNARY(neg)
CREATE_UNARY(tanh)
@@ -5097,7 +5165,6 @@ static void ggml_vk_load_shaders(vk_device& device, vk_pipeline requested) {
CREATE_GLU(geglu_quick)
#undef CREATE_GLU
ggml_vk_create_pipeline(device, device->pipeline_leaky_relu_f32, "leaky_relu_f32", leaky_relu_f32_len, leaky_relu_f32_data, "main", 2, sizeof(vk_op_push_constants), {512, 1, 1}, {}, 1);
ggml_vk_create_pipeline(device, device->pipeline_silu_back_f32, "silu_back_f32", silu_back_f32_len, silu_back_f32_data, "main", 3, sizeof(vk_op_push_constants), {512, 1, 1}, {}, 1);
ggml_vk_create_pipeline(device, device->pipeline_diag_mask_inf_f32, "diag_mask_inf_f32", diag_mask_inf_f32_len, diag_mask_inf_f32_data, "main", 2, sizeof(vk_op_diag_mask_push_constants), {1, 512, 1}, {}, 1, true);
@@ -5314,7 +5381,7 @@ static void ggml_vk_load_shaders(vk_device& device, vk_pipeline requested) {
ggml_vk_create_pipeline(device, device->pipeline_opt_step_sgd_f32, "opt_step_sgd_f32", opt_step_sgd_f32_len, opt_step_sgd_f32_data, "main", 3, sizeof(vk_op_push_constants), {512, 1, 1}, {}, 1);
// conv2d, conv_transpose_2d
// conv2d, conv_transpose_2d, conv3d
for (uint32_t s = 0; s < CONV_SHAPE_COUNT; ++s) {
// smaller WG for the small-tile fallback gives more concurrent WGs per SM
uint32_t conv2d_WG_SIZE = (s == CONV_SHAPE_64x32) ? 128 : 256;
@@ -5377,8 +5444,8 @@ static void ggml_vk_load_shaders(vk_device& device, vk_pipeline requested) {
return (conv2d_BS.K * (conv2d_BS.CRS + pad) + conv2d_BS.CRS * (conv2d_BS.NPQ + pad) + csh_elems) * elem_size;
};
// coopmat1 needs to store the output through shared memory, so check up front
// whether it'll fit and disable it before applying coopmat1 parameters.
// 2D, transpose-2D, and 3D conv use the same KxCRS @ CRSxNPQ shmem
// layout. cm1 needs Csh for output, so check before applying cm1 params.
if (conv2d_use_cm1 && device->properties.limits.maxComputeSharedMemorySize < shmem_req(conv2d_cm1_shmem_pad, true, true)) {
conv2d_use_cm1 = false;
}
@@ -5470,6 +5537,53 @@ static void ggml_vk_load_shaders(vk_device& device, vk_pipeline requested) {
}
#undef CREATE_CONV
#undef CREATE_CONVS
std::vector<uint32_t> conv3d_spec_constants = { conv2d_WG_SIZE, conv2d_BS.K, conv2d_BS.CRS, conv2d_BS.NPQ, conv2d_TS_K, conv2d_SHMEM_PAD };
#define CREATE_CONV3D(type_suffix, spv_suffix) \
for (auto &c : device->pipeline_conv3d##type_suffix[s]) { \
const vk_conv3d_pipeline_state &state = c.first; \
std::vector<uint32_t> spec_constants_cpy = conv3d_spec_constants; \
spec_constants_cpy.push_back(state.s0); \
spec_constants_cpy.push_back(state.s1); \
spec_constants_cpy.push_back(state.s2); \
spec_constants_cpy.push_back(state.p0); \
spec_constants_cpy.push_back(state.p1); \
spec_constants_cpy.push_back(state.p2); \
spec_constants_cpy.push_back(state.d0); \
spec_constants_cpy.push_back(state.d1); \
spec_constants_cpy.push_back(state.d2); \
spec_constants_cpy.push_back(state.KW); \
spec_constants_cpy.push_back(state.KH); \
spec_constants_cpy.push_back(state.KD); \
spec_constants_cpy.push_back(state.aligned); \
spec_constants_cpy.push_back(conv2d_csh_store); \
spec_constants_cpy.push_back(conv2d_WM); \
spec_constants_cpy.push_back(conv2d_WN); \
ggml_vk_create_pipeline( \
device, c.second, "conv3d" #type_suffix, \
conv3d##type_suffix##spv_suffix##_len, conv3d##type_suffix##spv_suffix##_data, "main", 3, \
sizeof(vk_op_conv3d_push_constants), wg_denoms, spec_constants_cpy, 1, true, conv2d_required_subgroup_size != 0, conv2d_required_subgroup_size); \
}
#if defined(GGML_VULKAN_COOPMAT2_GLSLC_SUPPORT)
if (device->coopmat2) {
CREATE_CONV3D(_f32, _cm2)
CREATE_CONV3D(_f16_f32, _cm2)
} else
#endif
#if defined(VK_KHR_cooperative_matrix) && defined(GGML_VULKAN_COOPMAT_GLSLC_SUPPORT)
if (conv2d_use_cm1) {
CREATE_CONV3D(_f32, _cm1)
CREATE_CONV3D(_f16_f32, _cm1)
} else
#endif
if (conv2d_UNROLL) {
CREATE_CONV3D(_f32, _unroll)
CREATE_CONV3D(_f16_f32, _unroll)
} else {
CREATE_CONV3D(_f32, )
CREATE_CONV3D(_f16_f32, )
}
#undef CREATE_CONV3D
}
ggml_vk_create_pipeline(device, device->pipeline_conv2d_dw_whcn_f32, "conv2d_dw_whcn_f32", conv2d_dw_whcn_f32_len, conv2d_dw_whcn_f32_data, "main", 3, sizeof(vk_op_conv2d_dw_push_constants), {512, 1, 1}, {}, 1);
@@ -10294,6 +10408,11 @@ static vk_pipeline ggml_vk_op_get_pipeline(ggml_backend_vk_context * ctx, const
return ctx->device->pipeline_get_rows_f32[src0->type];
}
return nullptr;
case GGML_OP_GET_ROWS_BACK:
if (src0->type == GGML_TYPE_F32 && src1->type == GGML_TYPE_I32 && dst->type == GGML_TYPE_F32) {
return ctx->device->pipeline_get_rows_back_f32;
}
return nullptr;
case GGML_OP_ACC:
if (src0->type == GGML_TYPE_F32 && src1->type == GGML_TYPE_F32 && dst->type == GGML_TYPE_F32) {
return ctx->device->pipeline_acc_f32;
@@ -10400,23 +10519,27 @@ static vk_pipeline ggml_vk_op_get_pipeline(ggml_backend_vk_context * ctx, const
}
return nullptr;
case GGML_OP_SQR:
if (src0->type == GGML_TYPE_F32 && dst->type == GGML_TYPE_F32) {
return ctx->device->pipeline_sqr_f32;
if (src0->type == dst->type &&
(src0->type == GGML_TYPE_F32 || src0->type == GGML_TYPE_F16)) {
return ctx->device->pipeline_sqr[dst->type == GGML_TYPE_F16];
}
return nullptr;
case GGML_OP_SQRT:
if (src0->type == GGML_TYPE_F32 && dst->type == GGML_TYPE_F32) {
return ctx->device->pipeline_sqrt_f32;
if (src0->type == dst->type &&
(src0->type == GGML_TYPE_F32 || src0->type == GGML_TYPE_F16)) {
return ctx->device->pipeline_sqrt[dst->type == GGML_TYPE_F16];
}
return nullptr;
case GGML_OP_SIN:
if (src0->type == GGML_TYPE_F32 && dst->type == GGML_TYPE_F32) {
return ctx->device->pipeline_sin_f32;
if (src0->type == dst->type &&
(src0->type == GGML_TYPE_F32 || src0->type == GGML_TYPE_F16)) {
return ctx->device->pipeline_sin[dst->type == GGML_TYPE_F16];
}
return nullptr;
case GGML_OP_COS:
if (src0->type == GGML_TYPE_F32 && dst->type == GGML_TYPE_F32) {
return ctx->device->pipeline_cos_f32;
if (src0->type == dst->type &&
(src0->type == GGML_TYPE_F32 || src0->type == GGML_TYPE_F16)) {
return ctx->device->pipeline_cos[dst->type == GGML_TYPE_F16];
}
return nullptr;
case GGML_OP_LOG:
@@ -10438,8 +10561,9 @@ static vk_pipeline ggml_vk_op_get_pipeline(ggml_backend_vk_context * ctx, const
}
return nullptr;
case GGML_OP_CLAMP:
if (src0->type == GGML_TYPE_F32 && dst->type == GGML_TYPE_F32) {
return ctx->device->pipeline_clamp_f32;
if (src0->type == dst->type &&
(src0->type == GGML_TYPE_F32 || src0->type == GGML_TYPE_F16)) {
return ctx->device->pipeline_clamp[dst->type == GGML_TYPE_F16];
}
return nullptr;
case GGML_OP_PAD:
@@ -10807,8 +10931,9 @@ static vk_pipeline ggml_vk_op_get_pipeline(ggml_backend_vk_context * ctx, const
}
return nullptr;
case GGML_OP_LEAKY_RELU:
if (src0->type == GGML_TYPE_F32 && dst->type == GGML_TYPE_F32) {
return ctx->device->pipeline_leaky_relu_f32;
if (src0->type == dst->type &&
(src0->type == GGML_TYPE_F32 || src0->type == GGML_TYPE_F16)) {
return ctx->device->pipeline_leaky_relu[dst->type == GGML_TYPE_F16];
}
return nullptr;
case GGML_OP_CONV_2D:
@@ -10885,6 +11010,61 @@ static vk_pipeline ggml_vk_op_get_pipeline(ggml_backend_vk_context * ctx, const
}
}
return nullptr;
case GGML_OP_CONV_3D:
if (src1->type == GGML_TYPE_F32 && dst->type == GGML_TYPE_F32) {
const uint32_t OC = (uint32_t)ggml_get_op_params_i32(dst, 11);
const uint32_t IC = (uint32_t)ggml_get_op_params_i32(dst, 9);
const uint32_t N = (uint32_t)ggml_get_op_params_i32(dst, 10);
const uint32_t NPQ = N * dst->ne[2] * dst->ne[1] * dst->ne[0];
const vk_conv_shapes shape = ggml_vk_conv_select_shape(ctx, OC, NPQ);
const uint32_t KW = (uint32_t)src0->ne[0];
const uint32_t KH = (uint32_t)src0->ne[1];
const uint32_t KD = (uint32_t)src0->ne[2];
const uint32_t s0 = (uint32_t)ggml_get_op_params_i32(dst, 0);
const uint32_t s1 = (uint32_t)ggml_get_op_params_i32(dst, 1);
const uint32_t s2 = (uint32_t)ggml_get_op_params_i32(dst, 2);
const uint32_t p0 = (uint32_t)ggml_get_op_params_i32(dst, 3);
const uint32_t p1 = (uint32_t)ggml_get_op_params_i32(dst, 4);
const uint32_t p2 = (uint32_t)ggml_get_op_params_i32(dst, 5);
const uint32_t d0 = (uint32_t)ggml_get_op_params_i32(dst, 6);
const uint32_t d1 = (uint32_t)ggml_get_op_params_i32(dst, 7);
const uint32_t d2 = (uint32_t)ggml_get_op_params_i32(dst, 8);
const uint32_t CRS = IC * KW * KH * KD;
const uint32_t BS_K = vk_conv_block_sizes[shape].K;
const uint32_t BS_CRS = vk_conv_block_sizes[shape].CRS;
const uint32_t BS_NPQ = vk_conv_block_sizes[shape].NPQ;
const uint32_t aligned = ((OC % BS_K == 0) &&
(CRS % BS_CRS == 0) &&
(NPQ % BS_NPQ == 0)) ? 1u : 0u;
vk_conv3d_pipeline_state conv3d_pipeline_state(s0, s1, s2, p0, p1, p2, d0, d1, d2, KW, KH, KD, aligned);
std::map<vk_conv3d_pipeline_state, vk_pipeline> *pipelines = nullptr;
if (src0->type == GGML_TYPE_F32) {
pipelines = &ctx->device->pipeline_conv3d_f32[shape];
} else if (src0->type == GGML_TYPE_F16) {
pipelines = &ctx->device->pipeline_conv3d_f16_f32[shape];
} else {
return nullptr;
}
vk_pipeline pipeline = nullptr;
{
std::lock_guard<std::mutex> guard(ctx->device->compile_mutex);
auto it = pipelines->find(conv3d_pipeline_state);
if (it != pipelines->end()) {
pipeline = it->second;
} else {
(*pipelines)[conv3d_pipeline_state] = pipeline = std::make_shared<vk_pipeline_struct>();
}
}
return pipeline;
}
return nullptr;
case GGML_OP_ADD1:
if (src0->type == GGML_TYPE_F16 && src1->type == GGML_TYPE_F16 && dst->type == GGML_TYPE_F16) {
return ctx->device->pipeline_add1_f16_f16;
@@ -11135,6 +11315,10 @@ static void ggml_vk_op_f32(ggml_backend_vk_context * ctx, vk_context& subctx, co
elements[1] = std::min(elements[1], ctx->device->properties.limits.maxComputeWorkGroupCount[1]);
elements[2] = std::min(elements[2], ctx->device->properties.limits.maxComputeWorkGroupCount[2]);
break;
case GGML_OP_GET_ROWS_BACK:
elements = { (uint32_t)dst->ne[0], (uint32_t)dst->ne[1], 1 };
elements[1] = std::min(elements[1], ctx->device->properties.limits.maxComputeWorkGroupCount[1]);
break;
case GGML_OP_ARGSORT:
GGML_ASSERT(0);
break;
@@ -11220,6 +11404,21 @@ static void ggml_vk_op_f32(ggml_backend_vk_context * ctx, vk_context& subctx, co
GGML_ABORT("invalid push constant type for CONV_2D");
}
break;
case GGML_OP_CONV_3D:
if constexpr (std::is_same_v<PC, vk_op_conv3d_push_constants>) {
const uint32_t NPQ = pc.N * pc.OD * pc.OH * pc.OW;
const vk_conv_shapes shape = ggml_vk_conv_select_shape(ctx, pc.OC, NPQ);
const uint32_t NPQ_blocks = CEIL_DIV(NPQ, vk_conv_block_sizes[shape].NPQ);
elements = { pc.OC, NPQ_blocks, 1 };
if (elements[1] > 512) {
elements[2] = CEIL_DIV(elements[1], 512);
elements[1] = 512;
}
} else {
GGML_ABORT("invalid push constant type for CONV_3D");
}
break;
case GGML_OP_ADD:
case GGML_OP_SUB:
case GGML_OP_DIV:
@@ -11236,6 +11435,7 @@ static void ggml_vk_op_f32(ggml_backend_vk_context * ctx, vk_context& subctx, co
case GGML_OP_TRI:
case GGML_OP_DIAG:
case GGML_OP_CLAMP:
case GGML_OP_LEAKY_RELU:
case GGML_OP_PAD:
case GGML_OP_ROLL:
case GGML_OP_REPEAT:
@@ -11380,6 +11580,21 @@ static void ggml_vk_get_rows(ggml_backend_vk_context * ctx, vk_context& subctx,
});
}
static void ggml_vk_get_rows_back(ggml_backend_vk_context * ctx, vk_context& subctx, const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst) {
const uint32_t src0_type_size = ggml_type_size(src0->type);
const uint32_t src1_type_size = ggml_type_size(src1->type);
const uint32_t dst_type_size = ggml_type_size(dst->type);
ggml_vk_op_f32<vk_op_binary_push_constants>(ctx, subctx, src0, src1, nullptr, nullptr, dst, GGML_OP_GET_ROWS_BACK, {
(uint32_t)ggml_nelements(src0),
(uint32_t)src0->ne[0], (uint32_t)src0->ne[1], (uint32_t)src0->ne[2], (uint32_t)src0->ne[3], (uint32_t)src0->nb[0] / src0_type_size, (uint32_t)src0->nb[1] / src0_type_size, (uint32_t)src0->nb[2] / src0_type_size, (uint32_t)src0->nb[3] / src0_type_size,
(uint32_t)src1->ne[0], (uint32_t)src1->ne[1], (uint32_t)src1->ne[2], (uint32_t)src1->ne[3], (uint32_t)src1->nb[0] / src1_type_size, (uint32_t)src1->nb[1] / src1_type_size, (uint32_t)src1->nb[2] / src1_type_size, (uint32_t)src1->nb[3] / src1_type_size,
(uint32_t) dst->ne[0], (uint32_t) dst->ne[1], (uint32_t) dst->ne[2], (uint32_t) dst->ne[3], (uint32_t) dst->nb[0] / dst_type_size, (uint32_t) dst->nb[1] / dst_type_size, (uint32_t) dst->nb[2] / dst_type_size, (uint32_t) dst->nb[3] / dst_type_size,
0,
0.0f, 0.0f, 0,
});
}
static void ggml_vk_acc(ggml_backend_vk_context * ctx, vk_context& subctx, const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst) {
const uint32_t src0_type_size = ggml_type_size(src0->type);
const uint32_t src1_type_size = ggml_type_size(src1->type);
@@ -12087,8 +12302,10 @@ static void ggml_vk_silu_back(ggml_backend_vk_context * ctx, vk_context& subctx,
static void ggml_vk_norm(ggml_backend_vk_context * ctx, vk_context& subctx, const ggml_tensor * src0, ggml_tensor * dst) {
float * op_params = (float *)dst->op_params;
vk_op_unary_push_constants p = vk_op_unary_push_constants_init(src0, dst);
p.param1 = op_params[0];
ggml_vk_op_f32<vk_op_push_constants>(ctx, subctx, src0, nullptr, nullptr, nullptr, dst, GGML_OP_NORM, { (uint32_t)src0->ne[0], (uint32_t)src0->ne[1], op_params[0], 0.0f, 0.0f, 0.0f });
ggml_vk_op_f32(ctx, subctx, src0, nullptr, nullptr, nullptr, dst, GGML_OP_NORM, std::move(p));
}
static void ggml_vk_group_norm(ggml_backend_vk_context * ctx, vk_context& subctx, const ggml_tensor * src0, ggml_tensor * dst) {
@@ -13118,6 +13335,51 @@ static void ggml_vk_conv_2d(ggml_backend_vk_context * ctx, vk_context & subctx,
ggml_vk_op_f32(ctx, subctx, src0, src1, nullptr, nullptr, dst, dst->op, std::move(p));
}
static void ggml_vk_conv_3d(ggml_backend_vk_context * ctx, vk_context & subctx, const ggml_tensor * src0,
const ggml_tensor * src1, ggml_tensor * dst) {
GGML_ASSERT(src0->type == GGML_TYPE_F32 || src0->type == GGML_TYPE_F16);
GGML_ASSERT(src1->type == GGML_TYPE_F32);
GGML_ASSERT(dst->type == GGML_TYPE_F32);
GGML_TENSOR_BINARY_OP_LOCALS
GGML_ASSERT(nb00 == sizeof(float) || nb00 == sizeof(ggml_fp16_t));
GGML_ASSERT(nb10 == sizeof(float));
GGML_ASSERT(nb0 == sizeof(float));
vk_op_conv3d_push_constants p{};
p.IC = static_cast<uint32_t>(ggml_get_op_params_i32(dst, 9));
p.N = static_cast<uint32_t>(ggml_get_op_params_i32(dst, 10));
p.OC = static_cast<uint32_t>(ggml_get_op_params_i32(dst, 11));
GGML_ASSERT(src0->ne[3] == (int64_t)p.IC * p.OC);
GGML_ASSERT(src1->ne[3] == (int64_t)p.IC * p.N);
GGML_ASSERT(dst->ne[3] == (int64_t)p.OC * p.N);
p.IW = static_cast<uint32_t>(ne10);
p.IH = static_cast<uint32_t>(ne11);
p.ID = static_cast<uint32_t>(ne12);
p.OW = static_cast<uint32_t>(ne0);
p.OH = static_cast<uint32_t>(ne1);
p.OD = static_cast<uint32_t>(ne2);
// the shader clamps src addresses to p.IC * p.N * p.IW * p.IH * p.ID - 1 in uint32, so the
// total input element count must fit in a uint32.
GGML_ASSERT((uint64_t)p.IC * p.N * p.IW * p.IH * p.ID <= 0xFFFFFFFFull);
p.nb01 = static_cast<uint32_t>(nb01 / nb00);
p.nb02 = static_cast<uint32_t>(nb02 / nb00);
p.nb03 = static_cast<uint32_t>(nb03 / nb00);
p.nb11 = static_cast<uint32_t>(nb11 / nb10);
p.nb12 = static_cast<uint32_t>(nb12 / nb10);
p.nb13 = static_cast<uint32_t>(nb13 / nb10);
p.nb1 = static_cast<uint32_t>(nb1 / nb0);
p.nb2 = static_cast<uint32_t>(nb2 / nb0);
p.nb3 = static_cast<uint32_t>(nb3 / nb0);
ggml_vk_op_f32(ctx, subctx, src0, src1, nullptr, nullptr, dst, GGML_OP_CONV_3D, std::move(p));
}
static void ggml_vk_conv_2d_dw(ggml_backend_vk_context * ctx, vk_context& subctx, const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst) {
vk_op_conv2d_dw_push_constants p{};
p.ne = ggml_nelements(dst);
@@ -13144,7 +13406,10 @@ static void ggml_vk_conv_2d_dw(ggml_backend_vk_context * ctx, vk_context& subctx
static void ggml_vk_leaky_relu(ggml_backend_vk_context * ctx, vk_context& subctx, const ggml_tensor * src0, ggml_tensor * dst) {
const float * op_params = (const float *)dst->op_params;
ggml_vk_op_f32<vk_op_push_constants>(ctx, subctx, src0, nullptr, nullptr, nullptr, dst, GGML_OP_LEAKY_RELU, { (uint32_t)ggml_nelements(src0), 0, op_params[0], 0.0f, 0.0f, 0.0f });
vk_op_unary_push_constants p = vk_op_unary_push_constants_init(src0, dst);
p.param1 = op_params[0];
ggml_vk_op_f32(ctx, subctx, src0, nullptr, nullptr, nullptr, dst, GGML_OP_LEAKY_RELU, std::move(p));
}
#ifdef GGML_VULKAN_RUN_TESTS
@@ -14247,6 +14512,10 @@ static bool ggml_vk_build_graph(ggml_backend_vk_context * ctx, ggml_cgraph * cgr
case GGML_OP_GET_ROWS:
ggml_vk_get_rows(ctx, compute_ctx, src0, src1, node);
break;
case GGML_OP_GET_ROWS_BACK:
ggml_vk_get_rows_back(ctx, compute_ctx, src0, src1, node);
break;
case GGML_OP_ADD:
if (ctx->num_additional_fused_ops) {
@@ -14515,6 +14784,10 @@ static bool ggml_vk_build_graph(ggml_backend_vk_context * ctx, ggml_cgraph * cgr
case GGML_OP_CONV_TRANSPOSE_2D:
ggml_vk_conv_2d(ctx, compute_ctx, src0, src1, node);
break;
case GGML_OP_CONV_3D:
ggml_vk_conv_3d(ctx, compute_ctx, src0, src1, node);
break;
case GGML_OP_CONV_2D_DW:
ggml_vk_conv_2d_dw(ctx, compute_ctx, src0, src1, node);
@@ -16964,6 +17237,8 @@ static bool ggml_backend_vk_device_supports_op(ggml_backend_dev_t dev, const ggm
return false;
}
}
case GGML_OP_GET_ROWS_BACK:
return op->type == GGML_TYPE_F32 && op->src[0]->type == GGML_TYPE_F32;
case GGML_OP_SET_ROWS:
{
switch (op->type) {
@@ -17060,12 +17335,11 @@ static bool ggml_backend_vk_device_supports_op(ggml_backend_dev_t dev, const ggm
case GGML_OP_TRANSPOSE:
case GGML_OP_RMS_NORM:
return true;
case GGML_OP_NORM:
case GGML_OP_GROUP_NORM:
return ggml_is_contiguous(op->src[0]);
case GGML_OP_NORM:
case GGML_OP_L2_NORM:
return ggml_is_contiguous_rows(op->src[0]) &&
op->src[0]->type == GGML_TYPE_F32 && op->type == GGML_TYPE_F32;
return op->src[0]->type == GGML_TYPE_F32 && op->type == GGML_TYPE_F32;
case GGML_OP_ADD:
case GGML_OP_SUB:
case GGML_OP_MUL:
@@ -17084,8 +17358,9 @@ static bool ggml_backend_vk_device_supports_op(ggml_backend_dev_t dev, const ggm
case GGML_OP_SIN:
case GGML_OP_COS:
case GGML_OP_CLAMP:
return op->src[0]->type == GGML_TYPE_F32;
case GGML_OP_LEAKY_RELU:
return (op->src[0]->type == GGML_TYPE_F32 || op->src[0]->type == GGML_TYPE_F16) &&
op->type == op->src[0]->type;
case GGML_OP_OPT_STEP_ADAMW:
case GGML_OP_OPT_STEP_SGD:
return ggml_is_contiguous(op->src[0]) && op->src[0]->type == GGML_TYPE_F32;
@@ -17285,6 +17560,13 @@ static bool ggml_backend_vk_device_supports_op(ggml_backend_dev_t dev, const ggm
ggml_is_contiguous(op->src[1]) &&
ggml_is_contiguous(op));
}
case GGML_OP_CONV_3D:
return (op->src[0]->type == GGML_TYPE_F32 || op->src[0]->type == GGML_TYPE_F16) &&
op->src[1]->type == GGML_TYPE_F32 &&
op->type == GGML_TYPE_F32 &&
ggml_is_contiguous(op->src[0]) &&
ggml_is_contiguous(op->src[1]) &&
ggml_is_contiguous(op);
default:
return false;
}
@@ -18128,6 +18410,20 @@ static void ggml_vk_check_results_0(ggml_backend_vk_context * ctx, ggml_cgraph *
const int32_t d0 = tensor->op_params[4];
const int32_t d1 = tensor->op_params[5];
tensor_clone = ggml_conv_2d(ggml_ctx, src_clone[0], src_clone[1], s0, s1, p0, p1, d0, d1);
} else if (tensor->op == GGML_OP_CONV_3D) {
const int32_t s0 = tensor->op_params[0];
const int32_t s1 = tensor->op_params[1];
const int32_t s2 = tensor->op_params[2];
const int32_t p0 = tensor->op_params[3];
const int32_t p1 = tensor->op_params[4];
const int32_t p2 = tensor->op_params[5];
const int32_t d0 = tensor->op_params[6];
const int32_t d1 = tensor->op_params[7];
const int32_t d2 = tensor->op_params[8];
const int32_t IC = tensor->op_params[9];
const int32_t N = tensor->op_params[10];
const int32_t OC = tensor->op_params[11];
tensor_clone = ggml_conv_3d_direct(ggml_ctx, src_clone[0], src_clone[1], s0, s1, s2, p0, p1, p2, d0, d1, d2, IC, N, OC);
} else if (tensor->op == GGML_OP_CONV_2D_DW) {
const int32_t s0 = tensor->op_params[0];
const int32_t s1 = tensor->op_params[1];
@@ -1,17 +0,0 @@
#version 450
#include "types.glsl"
#include "generic_unary_head.glsl"
layout(local_size_x = 512, local_size_y = 1, local_size_z = 1) in;
void main() {
const uint idx = get_idx();
if (idx >= p.ne) {
return;
}
const FLOAT_TYPE val = FLOAT_TYPE(data_a[get_aoffset() + src0_idx(idx)]);
data_d[get_doffset() + dst_idx(idx)] = D_TYPE(val < p.param1 ? p.param1 : (val > p.param2 ? p.param2 : val));
}
@@ -0,0 +1,431 @@
#version 450
#extension GL_EXT_control_flow_attributes : enable
#ifdef COOPMAT2
#extension GL_NV_cooperative_matrix2 : enable
#extension GL_EXT_shader_explicit_arithmetic_types_float16 : require
#extension GL_KHR_memory_scope_semantics : enable
#endif
#ifdef COOPMAT
#extension GL_KHR_cooperative_matrix : enable
#extension GL_KHR_shader_subgroup_basic : enable
#extension GL_EXT_shader_explicit_arithmetic_types_float16 : require
#extension GL_KHR_memory_scope_semantics : enable
#endif
#include "types.glsl"
// shape notation: [dim(N), ..., dim(0)] -- stride(dim(j)) >= stride(dim(i)) if i > j
layout(binding = 0) readonly buffer A {
A_TYPE knl_data[];
}; // src0 - kernel: [KW, KH, KD, IC*OC]
layout(binding = 1) readonly buffer B {
B_TYPE src_data[];
}; // src1 - input: [IW, IH, ID, IC*N] -- channel_first format
layout(binding = 2) writeonly buffer D {
D_TYPE dst_data[];
}; // dst - result: [OW, OH, OD, OC*N]
layout(push_constant) uniform parameter {
// I/O channels, batch size
uint32_t OC;
uint32_t IC;
uint32_t N;
// Tensor spatial sizes: input, output
uint32_t IW;
uint32_t IH;
uint32_t ID;
uint32_t OW;
uint32_t OH;
uint32_t OD;
// Strides in elements
uint32_t nb01;
uint32_t nb02;
uint32_t nb03;
uint32_t nb11;
uint32_t nb12;
uint32_t nb13;
uint32_t nb1;
uint32_t nb2;
uint32_t nb3;
// fastdiv helper values
uint32_t OWmp; uint32_t OWL;
uint32_t OWOHmp; uint32_t OWOHL;
uint32_t OWOHODmp; uint32_t OWOHODL;
}
p;
layout(local_size_x_id = 0, local_size_y = 1, local_size_z = 1) in;
// Blocktile sizes
layout(constant_id = 1) const uint BS_K = 128;
layout(constant_id = 2) const uint BS_CRS = 16;
layout(constant_id = 3) const uint BS_NPQ = 128;
// Thread-tile sizes
layout(constant_id = 4) const uint TS_K = 8;
layout(constant_id = 5) const uint SHMEM_PAD = 4;
// Stride, padding, dilation
layout(constant_id = 6) const uint s0 = 1;
layout(constant_id = 7) const uint s1 = 1;
layout(constant_id = 8) const uint s2 = 1;
layout(constant_id = 9) const uint p0 = 0;
layout(constant_id = 10) const uint p1 = 0;
layout(constant_id = 11) const uint p2 = 0;
layout(constant_id = 12) const uint d0 = 1;
layout(constant_id = 13) const uint d1 = 1;
layout(constant_id = 14) const uint d2 = 1;
// Kernel spatial sizes
layout(constant_id = 15) const uint KW = 1;
layout(constant_id = 16) const uint KH = 1;
layout(constant_id = 17) const uint KD = 1;
// when set, skip bounds checks and address clamps (K/CRS/NPQ are tile-aligned)
layout(constant_id = 18) const uint aligned = 0;
// stage cm2 result through shmem (Csh) for coalesced stores. cm1 always does this.
layout(constant_id = 19) const uint csh_store = 0;
#ifdef COOPMAT
// cm1 subgroup tile: each subgroup computes a WM x WN region as a grid of
// TM x TN x TK fragments. Requires WM%TM == WN%TN == BS_K%WM == BS_NPQ%WN ==
// BS_CRS%TK == 0, and WG_SIZE == (BS_K/WM) * (BS_NPQ/WN) * subgroup_size.
layout(constant_id = 20) const uint WM = 32;
layout(constant_id = 21) const uint WN = 32;
const uint TM = 16;
const uint TN = 16;
const uint TK = 16;
const uint cms_per_row = WM / TM;
const uint cms_per_col = WN / TN;
const uint warps_M = BS_K / WM;
const uint warps_N = BS_NPQ / WN;
#endif
// without padding, ID_idx/IH_idx/IW_idx are in bounds by construction
const bool dhw_in_bounds = (p0 == 0) && (p1 == 0) && (p2 == 0);
uint32_t tid = gl_LocalInvocationID.x;
const uint32_t WG_SIZE = gl_WorkGroupSize.x;
uint splitWork(uint work_size, uint block_size) {
return (block_size + work_size - 1) / block_size;
}
uint32_t K = p.OC;
uint32_t CRS = p.IC * KD * KH * KW;
uint32_t NPQ = p.N * p.OD * p.OH * p.OW;
// Number of blocktiles per input
uint32_t NB_CRS = splitWork(CRS, BS_CRS);
#if defined(COOPMAT2) || defined(COOPMAT)
#define SHMEM_TYPE float16_t
#else
#define SHMEM_TYPE float
#endif
const uint32_t Ash_stride = BS_CRS + SHMEM_PAD;
const uint32_t Bsh_stride = BS_NPQ + SHMEM_PAD;
const uint32_t Ash_len = BS_K * Ash_stride;
const uint32_t Bsh_len = BS_CRS * Bsh_stride;
shared SHMEM_TYPE Ash[Ash_len]; // K x CRS
shared SHMEM_TYPE Bsh[Bsh_len]; // CRS x NPQ
#if defined(COOPMAT2) || defined(COOPMAT)
// stage matC through shmem so global stores are row-major (NPQ-contiguous)
const uint32_t Csh_stride = BS_NPQ;
#ifdef COOPMAT
const uint32_t Csh_len = BS_K * Csh_stride;
#else
const uint32_t Csh_len = csh_store != 0 ? BS_K * Csh_stride : 1;
#endif
shared SHMEM_TYPE Csh[Csh_len]; // K x NPQ
#endif
// Threadtile sizes
const uint32_t TS_NPQ = BS_K * BS_NPQ / WG_SIZE / TS_K;
// Number of threadtiles per blocktile
const uint32_t NT_NPQ = BS_NPQ / TS_NPQ;
/*
Compute
KxCRS @ CRSxNPQ = K x NPQ
K=OC
C=IC
D,R,S=KD,KH,KW
Z,P,Q=OD,OH,OW
*/
uint32_t B_idx_K = gl_WorkGroupID.x;
uint32_t B_idx_NPQ = gl_WorkGroupID.y + gl_WorkGroupID.z * 512;
uint32_t T_y = tid / NT_NPQ;
uint32_t T_x = tid % NT_NPQ;
uint32_t Ar = tid / BS_CRS;
uint32_t Ac = tid % BS_CRS;
const uint32_t ArpWg = WG_SIZE / BS_CRS;
uint32_t Br = tid / BS_NPQ;
uint32_t Bc = tid % BS_NPQ;
const uint32_t BrpWg = WG_SIZE / BS_NPQ;
// see init_fastdiv_values in ggml-vulkan.cpp
uint fastdiv(uint n, uint mp, uint L) {
uint msbs, lsbs;
// msbs = mulhi(n, mp)
umulExtended(n, mp, msbs, lsbs);
return (msbs + n) >> L;
}
void split_crs(uint32_t crs_idx, out uint32_t ic, out uint32_t kd, out uint32_t kh, out uint32_t kw) {
const uint32_t KHKW = KH * KW;
const uint32_t KDKHKW = KD * KHKW;
ic = crs_idx / KDKHKW;
uint32_t rem = crs_idx - ic * KDKHKW;
kd = rem / KHKW;
rem = rem - kd * KHKW;
kh = rem / KW;
kw = rem - kh * KW;
}
void split_npq(uint32_t npq_idx, out uint32_t n, out uint32_t od, out uint32_t oh, out uint32_t ow) {
const uint32_t OWOH = p.OW * p.OH;
n = fastdiv(npq_idx, p.OWOHODmp, p.OWOHODL);
uint32_t rem = npq_idx - n * p.OD * OWOH;
od = fastdiv(rem, p.OWOHmp, p.OWOHL);
rem = rem - od * OWOH;
oh = fastdiv(rem, p.OWmp, p.OWL);
ow = rem - oh * p.OW;
}
#ifdef COOPMAT2
#define ACC_TYPE float16_t
ACC_TYPE perElemOpStore(const in uint32_t r, const in uint32_t c, const in ACC_TYPE elem)
{
uint32_t K_idx = B_idx_K * BS_K + r;
uint32_t NPQ_idx = B_idx_NPQ * BS_NPQ + c;
uint32_t N_idx;
uint32_t OD_idx;
uint32_t OH_idx;
uint32_t OW_idx;
split_npq(NPQ_idx, N_idx, OD_idx, OH_idx, OW_idx);
uint32_t dst_idx = OW_idx + OH_idx * p.nb1 + OD_idx * p.nb2 + (N_idx * p.OC + K_idx) * p.nb3;
if (aligned != 0 || (K_idx < K && NPQ_idx < NPQ)) {
dst_data[dst_idx] = D_TYPE(elem);
}
return elem;
}
#endif
void main() {
if (B_idx_NPQ * BS_NPQ >= NPQ) {
return;
}
#ifdef COOPMAT2
coopmat<ACC_TYPE, gl_ScopeWorkgroup, BS_K, BS_NPQ, gl_MatrixUseAccumulator> matC;
matC = coopmat<ACC_TYPE, gl_ScopeWorkgroup, BS_K, BS_NPQ, gl_MatrixUseAccumulator>(0.0);
#elif defined(COOPMAT)
coopmat<float16_t, gl_ScopeSubgroup, TM, TN, gl_MatrixUseAccumulator> sums[cms_per_row * cms_per_col];
[[unroll]] for (uint i = 0; i < cms_per_row * cms_per_col; i++) {
sums[i] = coopmat<float16_t, gl_ScopeSubgroup, TM, TN, gl_MatrixUseAccumulator>(0.0);
}
const uint warp_r = gl_SubgroupID / warps_N;
const uint warp_c = gl_SubgroupID % warps_N;
#else
float regC[TS_K][TS_NPQ];
for (uint32_t T_ly = 0; T_ly < TS_K; T_ly++) {
for (uint32_t T_lx = 0; T_lx < TS_NPQ; T_lx++) {
regC[T_ly][T_lx] = 0.0;
}
}
#endif
/* Advance block in CRS dim */
[[dont_unroll]] for (uint32_t B_idx_CRS = 0; B_idx_CRS < NB_CRS; B_idx_CRS++) {
uint32_t CRS_idx_a = B_idx_CRS * BS_CRS + Ac;
uint32_t IC_idx_a;
uint32_t KD_idx_a;
uint32_t KH_idx_a;
uint32_t KW_idx_a;
split_crs(CRS_idx_a, IC_idx_a, KD_idx_a, KH_idx_a, KW_idx_a);
/* Load kernel to A_block: (BS_K x BS_CRS)*/
UNROLL for (uint32_t r_offset = 0; r_offset < BS_K; r_offset += ArpWg) {
uint32_t B_ly = r_offset + Ar;
uint32_t B_lx = Ac;
uint32_t K_idx = B_idx_K * BS_K + B_ly; /* Global K_idx (row index of A)*/
uint32_t knl_idx = KW_idx_a + KH_idx_a * p.nb01 + KD_idx_a * p.nb02 + (K_idx * p.IC + IC_idx_a) * p.nb03;
if (aligned == 0) {
knl_idx = min(knl_idx, K * CRS - 1);
}
float val = knl_data[knl_idx];
if (aligned == 0 && (K_idx >= K || CRS_idx_a >= CRS)) {
val = 0.0;
}
Ash[B_ly * Ash_stride + B_lx] = SHMEM_TYPE(val);
}
/* Load input to B_block: (BS_CRS x BS_NPQ) */
UNROLL for (uint32_t r_offset = 0; r_offset < BS_CRS; r_offset += BrpWg) {
uint32_t B_ly = r_offset + Br; /* Row index of B block */
uint32_t B_lx = Bc;
uint32_t NPQ_idx = B_idx_NPQ * BS_NPQ + B_lx; /* Global NPQ index (column index of B) */
uint32_t N_idx;
uint32_t OD_idx;
uint32_t OH_idx;
uint32_t OW_idx;
split_npq(NPQ_idx, N_idx, OD_idx, OH_idx, OW_idx);
uint32_t CRS_idx_b = B_idx_CRS * BS_CRS + B_ly;
uint32_t IC_idx_b;
uint32_t KD_idx_b;
uint32_t KH_idx_b;
uint32_t KW_idx_b;
split_crs(CRS_idx_b, IC_idx_b, KD_idx_b, KH_idx_b, KW_idx_b);
uint32_t ID_idx = OD_idx * s2 + KD_idx_b * d2 - p2;
uint32_t IH_idx = OH_idx * s1 + KH_idx_b * d1 - p1;
uint32_t IW_idx = OW_idx * s0 + KW_idx_b * d0 - p0;
uint32_t src_idx = IW_idx + IH_idx * p.nb11 + ID_idx * p.nb12 + (N_idx * p.IC + IC_idx_b) * p.nb13;
// skip clamp when address can't go OOB
if (aligned == 0 || !dhw_in_bounds) {
src_idx = min(src_idx, p.IC * p.N * p.IW * p.IH * p.ID - 1);
}
float val = src_data[src_idx];
bool oob = false;
if (aligned == 0 && (CRS_idx_b >= CRS || NPQ_idx >= NPQ)) {
oob = true;
}
// also catches lower-bound underflow (idx wraps to 0x80000000+)
if (!dhw_in_bounds && (ID_idx >= p.ID || IH_idx >= p.IH || IW_idx >= p.IW)) {
oob = true;
}
if (oob) {
val = 0.0;
}
Bsh[B_ly * Bsh_stride + B_lx] = SHMEM_TYPE(val);
}
barrier();
#ifdef COOPMAT2
coopmat<float16_t, gl_ScopeWorkgroup, BS_K, BS_CRS, gl_MatrixUseA> matA;
coopmat<float16_t, gl_ScopeWorkgroup, BS_CRS, BS_NPQ, gl_MatrixUseB> matB;
coopMatLoad(matA, Ash, 0, Ash_stride, gl_CooperativeMatrixLayoutRowMajor);
coopMatLoad(matB, Bsh, 0, Bsh_stride, gl_CooperativeMatrixLayoutRowMajor);
matC = coopMatMulAdd(matA, matB, matC);
#elif defined(COOPMAT)
// each subgroup multiplies its grid of fragments per TK-sized CRS chunk
[[unroll]] for (uint k_step = 0; k_step < BS_CRS / TK; k_step++) {
coopmat<float16_t, gl_ScopeSubgroup, TM, TK, gl_MatrixUseA> cache_a[cms_per_row];
[[unroll]] for (uint cm_row = 0; cm_row < cms_per_row; cm_row++) {
const uint a_off = (warp_r * WM + cm_row * TM) * Ash_stride + k_step * TK;
coopMatLoad(cache_a[cm_row], Ash, a_off, Ash_stride, gl_CooperativeMatrixLayoutRowMajor);
}
[[unroll]] for (uint cm_col = 0; cm_col < cms_per_col; cm_col++) {
coopmat<float16_t, gl_ScopeSubgroup, TK, TN, gl_MatrixUseB> cache_b;
const uint b_off = k_step * TK * Bsh_stride + warp_c * WN + cm_col * TN;
coopMatLoad(cache_b, Bsh, b_off, Bsh_stride, gl_CooperativeMatrixLayoutRowMajor);
[[unroll]] for (uint cm_row = 0; cm_row < cms_per_row; cm_row++) {
sums[cm_col * cms_per_row + cm_row] = coopMatMulAdd(cache_a[cm_row], cache_b, sums[cm_col * cms_per_row + cm_row]);
}
}
}
#else
if (T_y * TS_K < K) {
UNROLL for (uint32_t CRS_lidx = 0; CRS_lidx < BS_CRS; CRS_lidx++) {
float regA[TS_K];
float regB[TS_NPQ];
for (uint32_t T_ly = 0; T_ly < TS_K; T_ly++) {
regA[T_ly] = Ash[(T_y * TS_K + T_ly) * Ash_stride + CRS_lidx];
}
for (uint32_t T_lx = 0; T_lx < TS_NPQ; T_lx++) {
regB[T_lx] = Bsh[CRS_lidx * Bsh_stride + T_x * TS_NPQ + T_lx];
}
for (uint32_t T_ly = 0; T_ly < TS_K; T_ly++) {
for (uint32_t T_lx = 0; T_lx < TS_NPQ; T_lx++) {
regC[T_ly][T_lx] = fma(regA[T_ly], regB[T_lx], regC[T_ly][T_lx]);
}
}
}
}
#endif
barrier();
}
/* Save C* */
#if defined(COOPMAT2) || defined(COOPMAT)
// stage matC into Csh, then write to dst with coalesced NPQ-contiguous stores
#ifdef COOPMAT
const bool use_staged_store = true;
#else
const bool use_staged_store = (csh_store != 0);
#endif
if (use_staged_store) {
#ifdef COOPMAT
// cm1: each subgroup stores its fragment grid into its Csh slot
[[unroll]] for (uint cm_row = 0; cm_row < cms_per_row; cm_row++) {
[[unroll]] for (uint cm_col = 0; cm_col < cms_per_col; cm_col++) {
const uint csh_off = (warp_r * WM + cm_row * TM) * Csh_stride + warp_c * WN + cm_col * TN;
coopMatStore(sums[cm_col * cms_per_row + cm_row], Csh, csh_off, Csh_stride, gl_CooperativeMatrixLayoutRowMajor);
}
}
#else
coopMatStore(matC, Csh, 0, Csh_stride, gl_CooperativeMatrixLayoutRowMajor);
#endif
barrier();
// cooperative shmem->global: WG threads spread across BS_NPQ (the
// contiguous direction of dst), each iter covers store_rows_per_iter K-rows
const uint32_t store_rows_per_iter = WG_SIZE / BS_NPQ;
const uint32_t store_iters = BS_K / store_rows_per_iter;
const uint32_t k_thread_offset = tid / BS_NPQ;
const uint32_t npq_thread = tid % BS_NPQ;
[[unroll]] for (uint32_t i = 0; i < store_iters; i++) {
uint32_t k_local = i * store_rows_per_iter + k_thread_offset;
uint32_t K_idx = B_idx_K * BS_K + k_local;
uint32_t NPQ_idx = B_idx_NPQ * BS_NPQ + npq_thread;
uint32_t N_idx;
uint32_t OD_idx;
uint32_t OH_idx;
uint32_t OW_idx;
split_npq(NPQ_idx, N_idx, OD_idx, OH_idx, OW_idx);
uint32_t dst_idx = OW_idx + OH_idx * p.nb1 + OD_idx * p.nb2 + (N_idx * p.OC + K_idx) * p.nb3;
if (aligned != 0 || (K_idx < K && NPQ_idx < NPQ)) {
dst_data[dst_idx] = D_TYPE(Csh[k_local * Csh_stride + npq_thread]);
}
}
}
#ifdef COOPMAT2
else {
coopMatPerElementNV(matC, matC, perElemOpStore);
}
#endif
#else
if (T_y * TS_K < K) {
for (uint32_t T_ly = 0; T_ly < TS_K; T_ly++) {
for (uint32_t T_lx = 0; T_lx < TS_NPQ; T_lx++) {
uint32_t K_idx = B_idx_K * BS_K + T_y * TS_K + T_ly;
uint32_t NPQ_idx = B_idx_NPQ * BS_NPQ + T_x * TS_NPQ + T_lx;
uint32_t N_idx;
uint32_t OD_idx;
uint32_t OH_idx;
uint32_t OW_idx;
split_npq(NPQ_idx, N_idx, OD_idx, OH_idx, OW_idx);
uint32_t dst_idx = OW_idx + OH_idx * p.nb1 + OD_idx * p.nb2 + (N_idx * p.OC + K_idx) * p.nb3;
if (aligned != 0 || (K_idx < K && NPQ_idx < NPQ)) {
dst_data[dst_idx] = D_TYPE(regC[T_ly][T_lx]);
}
}
}
}
#endif
}
@@ -1,17 +0,0 @@
#version 450
#include "types.glsl"
#include "generic_unary_head.glsl"
layout(local_size_x = 512, local_size_y = 1, local_size_z = 1) in;
void main() {
const uint idx = get_idx();
if (idx >= p.ne) {
return;
}
const FLOAT_TYPE val = FLOAT_TYPE(data_a[get_aoffset() + src0_idx(idx)]);
data_d[get_doffset() + dst_idx(idx)] = D_TYPE(cos(val));
}
@@ -463,6 +463,7 @@ void main() {
}
rowmaxf = max(rowmaxf, float(Sf[r][c]));
}
rowmaxf += FATTN_KQ_MAX_OFFSET;
float Moldf = Mf[r];
// M = max(rowmax, Mold)
@@ -352,6 +352,7 @@ void main() {
}
rowmaxf = max(rowmaxf, float(sfsh[r_vec + (c * cols_per_iter + col_tid) * sfshstride][r_comp]));
}
rowmaxf += FATTN_KQ_MAX_OFFSET;
float Moldf = Mf[r];
// Compute max across the row
@@ -0,0 +1,25 @@
#version 450
#include "types.glsl"
#include "generic_binary_head.glsl"
layout(local_size_x = 256, local_size_y = 1, local_size_z = 1) in;
void main() {
const uint col = gl_GlobalInvocationID.x;
if (col >= p.ne20) {
return;
}
for (uint row = gl_GlobalInvocationID.y; row < p.ne21; row += gl_WorkGroupSize.y * gl_NumWorkGroups.y) {
float sum = 0.0f;
for (uint i = 0; i < p.ne10; ++i) {
if (data_b[get_boffset() + i*p.nb10] == int(row)) {
sum += data_a[get_aoffset() + i*p.nb01 + col*p.nb00];
}
}
data_d[get_doffset() + row*p.nb21 + col*p.nb20] = sum;
}
}
@@ -14,16 +14,13 @@ void main() {
const uint row = gl_WorkGroupID.z * 262144 + gl_WorkGroupID.y * 512 + gl_WorkGroupID.x;
const uint tid = gl_LocalInvocationID.x;
const uint i3 = row / (p.ne11 * p.ne12);
const uint i3_offset = i3 * p.ne12 * p.ne11;
const uint i2 = (row - i3_offset) / p.ne11;
const uint i2_offset = i2 * p.ne11;
const uint i1 = row - i3_offset - i2_offset;
const uint a_base = get_aoffset() + src0_idx(row * p.ne00);
const uint d_base = get_doffset() + dst_idx(row * p.ne10);
sum[tid] = FLOAT_TYPE(0.0f); // partial sum for thread in warp
[[unroll]] for (uint i0 = tid; i0 < p.ne00; i0 += BLOCK_SIZE) {
const FLOAT_TYPE xi = FLOAT_TYPE(data_a[i3*p.nb03 + i2*p.nb02 + i1*p.nb01 + i0]);
const FLOAT_TYPE xi = FLOAT_TYPE(data_a[a_base + i0*p.nb00]);
sum[tid] += xi * xi;
}
@@ -39,6 +36,6 @@ void main() {
const FLOAT_TYPE scale = 1.0f / max(sqrt(sum[0]), FLOAT_TYPE(p.param1));
[[unroll]] for (uint i0 = tid; i0 < p.ne00; i0 += BLOCK_SIZE) {
data_d[i3*p.nb13 + i2*p.nb12 + i1*p.nb11 + i0] = D_TYPE(scale * FLOAT_TYPE(data_a[i3*p.nb03 + i2*p.nb02 + i1*p.nb01 + i0]));
data_d[d_base + i0*p.nb10] = D_TYPE(scale * FLOAT_TYPE(data_a[a_base + i0*p.nb00]));
}
}
@@ -1,22 +0,0 @@
#version 450
#include "generic_head.glsl"
#include "types.glsl"
#extension GL_EXT_control_flow_attributes : enable
layout(local_size_x = 512, local_size_y = 1, local_size_z = 1) in;
layout (binding = 0) readonly buffer X {A_TYPE data_a[];};
layout (binding = 1) writeonly buffer D {D_TYPE data_d[];};
void main() {
const uint i = gl_GlobalInvocationID.z * 262144 + gl_GlobalInvocationID.y * 512 + gl_GlobalInvocationID.x;
if (i >= p.KX) {
return;
}
const float val = float(data_a[i]);
data_d[i] = D_TYPE(max(val, 0.0f) + min(val, 0.0f) * p.param1);
}
+31 -23
View File
@@ -38,17 +38,7 @@
#define LOAD_VEC_B 1
#endif
// Load 2 values at once without affecting index calculations through LOAD_VEC
#if (defined(DATA_A_F32) || defined(DATA_A_F16) || defined(DATA_A_BF16)) && !defined(ALIGNED)
#define LOAD_VEC_BATCH_A 2
#else
#define LOAD_VEC_BATCH_A 1
#endif
#if !defined(ALIGNED)
#define LOAD_VEC_BATCH_B 2
#else
#define LOAD_VEC_BATCH_B 1
#endif
layout (constant_id = 11) const uint ALIGNED = 0;
#if !defined(TO_FLOAT_TYPE)
#define TO_FLOAT_TYPE FLOAT_TYPE
@@ -57,6 +47,13 @@
layout(local_size_x_id = 0, local_size_y = 1, local_size_z = 1) in;
layout (binding = 0) readonly buffer A {A_TYPE data_a[];};
#if defined(DATA_A_F32)
layout (binding = 0) readonly buffer A_SCALAR {float data_a_scalar[];};
#elif defined(DATA_A_F16)
layout (binding = 0) readonly buffer A_SCALAR {float16_t data_a_scalar[];};
#elif defined(DATA_A_BF16)
layout (binding = 0) readonly buffer A_SCALAR {uint16_t data_a_scalar[];};
#endif
#if defined(A_TYPE_PACKED16)
layout (binding = 0) readonly buffer A_PACKED16 {A_TYPE_PACKED16 data_a_packed16[];};
#endif
@@ -65,6 +62,7 @@ layout (binding = 0) readonly buffer A_PACKED32 {A_TYPE_PACKED32 data_a_packed32
#endif
layout (binding = 1) readonly buffer B {B_TYPE data_b[];};
layout (binding = 1) readonly buffer B_SCALAR {B_TYPE_SCALAR data_b_scalar[];};
layout (binding = 2) writeonly buffer D {D_TYPE data_d[];};
#ifdef MUL_MAT_ID
@@ -194,13 +192,23 @@ void main() {
const uint warp_r = warp_i % (BM / WM);
const uint warp_c = warp_i / (BM / WM);
const uint loadr_a = gl_LocalInvocationID.x % (BK / LOAD_VEC_A / LOAD_VEC_BATCH_A);
const uint loadc_a = gl_LocalInvocationID.x / (BK / LOAD_VEC_A / LOAD_VEC_BATCH_A);
const uint loadr_b = gl_LocalInvocationID.x % (BK / LOAD_VEC_B / LOAD_VEC_BATCH_B);
const uint loadc_b = gl_LocalInvocationID.x / (BK / LOAD_VEC_B / LOAD_VEC_BATCH_B);
#if defined(DATA_A_F32) || defined(DATA_A_F16) || defined(DATA_A_BF16)
const uint LOAD_VEC_A_EFF = (ALIGNED != 0) ? LOAD_VEC_A : 1;
const uint LOAD_VEC_BATCH_A = (ALIGNED != 0) ? 1 : 2;
#else
const uint LOAD_VEC_A_EFF = LOAD_VEC_A;
const uint LOAD_VEC_BATCH_A = 1;
#endif
const uint LOAD_VEC_B_EFF = (ALIGNED != 0) ? LOAD_VEC_B : 1;
const uint LOAD_VEC_BATCH_B = (ALIGNED != 0) ? 1 : 2;
const uint loadstride_a = gl_WorkGroupSize.x * LOAD_VEC_A * LOAD_VEC_BATCH_A / BK;
const uint loadstride_b = gl_WorkGroupSize.x * LOAD_VEC_B * LOAD_VEC_BATCH_B / BK;
const uint loadr_a = gl_LocalInvocationID.x % (BK / LOAD_VEC_A_EFF / LOAD_VEC_BATCH_A);
const uint loadc_a = gl_LocalInvocationID.x / (BK / LOAD_VEC_A_EFF / LOAD_VEC_BATCH_A);
const uint loadr_b = gl_LocalInvocationID.x % (BK / LOAD_VEC_B_EFF / LOAD_VEC_BATCH_B);
const uint loadc_b = gl_LocalInvocationID.x / (BK / LOAD_VEC_B_EFF / LOAD_VEC_BATCH_B);
const uint loadstride_a = gl_WorkGroupSize.x * LOAD_VEC_A_EFF * LOAD_VEC_BATCH_A / BK;
const uint loadstride_b = gl_WorkGroupSize.x * LOAD_VEC_B_EFF * LOAD_VEC_BATCH_B / BK;
#ifdef MUL_MAT_ID
#ifdef MUL_MAT_ID_USE_SUBGROUPS
@@ -239,15 +247,15 @@ void main() {
uint pos_a =
#ifdef MUL_MAT_ID
expert_idx * (p.batch_stride_a / LOAD_VEC_A) +
expert_idx * (p.batch_stride_a / LOAD_VEC_A_EFF) +
#else
batch_idx_a * (p.batch_stride_a / LOAD_VEC_A) +
batch_idx_a * (p.batch_stride_a / LOAD_VEC_A_EFF) +
#endif
(ir * BM * p.stride_a + start_k) / LOAD_VEC_A;
(ir * BM * p.stride_a + start_k) / LOAD_VEC_A_EFF;
#ifdef MUL_MAT_ID
uint pos_b = 0;
#else
uint pos_b = (batch_idx * p.batch_stride_b + ic * BN * p.stride_b + start_k) / LOAD_VEC_B;
uint pos_b = (batch_idx * p.batch_stride_b + ic * BN * p.stride_b + start_k) / LOAD_VEC_B_EFF;
#endif
#ifdef COOPMAT
@@ -287,8 +295,8 @@ void main() {
barrier();
pos_a += BK / LOAD_VEC_A;
pos_b += BK / LOAD_VEC_B;
pos_a += BK / LOAD_VEC_A_EFF;
pos_b += BK / LOAD_VEC_B_EFF;
#ifdef COOPMAT
[[unroll]] for (uint i = 0; i < BK; i += TK) {
@@ -36,6 +36,7 @@ layout (constant_id = 3) const uint BK = 16; // Assumed to be 32 if working wit
layout (constant_id = 4) const bool enable_smaller_matrices = false;
const uint BNover2 = enable_smaller_matrices ? (BN / 2) : BN;
const uint BNover4 = enable_smaller_matrices ? (BN / 4) : BN;
layout (constant_id = 5) const uint ALIGNED = 0;
layout (push_constant) uniform parameter
{
@@ -111,7 +112,7 @@ layout(buffer_reference, std430, buffer_reference_align = 2) buffer decodeBufB {
};
uint _ne1;
layout (constant_id = 5) const uint subgroup_size = 32;
layout (constant_id = 6) const uint subgroup_size = 32;
shared uvec4 ballots_sh[BLOCK_SIZE / subgroup_size];
B_TYPE decodeFuncB(const in decodeBufB bl, const in uint blockCoords[2], const in uint coordInBlock[2])
@@ -297,12 +298,12 @@ void main() {
// Hint to the compiler that values are aligned (want 16B alignment).
// Quants are always block-aligned, no alignment needed.
#if ALIGNED
if (ALIGNED != 0) {
#if QUANT_K == 1
stride_a &= ~7;
#endif
stride_b &= ~7;
stride_a &= ~7;
#endif
stride_b &= ~7;
}
// Create layouts for both clamped and unclamped accesses
tensorLayoutNV<2> tensorLayoutA = createTensorLayoutNV(2);
@@ -1,50 +1,57 @@
void load_a_to_shmem(const uint pos_a, const uint row, const uint col, const uint idx_m, const uint block, const uint end_k) {
#if defined(DATA_A_F32) || defined(DATA_A_F16)
#if LOAD_VEC_A == 8
const uint idx = pos_a + col * p.stride_a / LOAD_VEC_A + row;
const uint buf_idx = col * SHMEM_STRIDE + row * LOAD_VEC_A / 2;
FLOAT_TYPEV8 aa = FLOAT_TYPEV8(data_a[idx]);
buf_a[buf_idx ] = aa[0].xy;
buf_a[buf_idx + 1] = aa[0].zw;
buf_a[buf_idx + 2] = aa[1].xy;
buf_a[buf_idx + 3] = aa[1].zw;
if (ALIGNED != 0) {
const uint idx = pos_a + col * p.stride_a / LOAD_VEC_A + row;
const uint buf_idx = col * SHMEM_STRIDE + row * LOAD_VEC_A / 2;
FLOAT_TYPEV8 aa = FLOAT_TYPEV8(data_a[idx]);
buf_a[buf_idx ] = aa[0].xy;
buf_a[buf_idx + 1] = aa[0].zw;
buf_a[buf_idx + 2] = aa[1].xy;
buf_a[buf_idx + 3] = aa[1].zw;
return;
}
#elif LOAD_VEC_A == 4
const uint idx = pos_a + col * p.stride_a / LOAD_VEC_A + row;
const uint buf_idx = col * SHMEM_STRIDE + row * LOAD_VEC_A / 2;
FLOAT_TYPEV4 aa = FLOAT_TYPEV4(data_a[idx]);
buf_a[buf_idx ] = aa.xy;
buf_a[buf_idx + 1] = aa.zw;
#else // LOAD_VEC_BATCH_A == 2
if (ALIGNED != 0) {
const uint idx = pos_a + col * p.stride_a / LOAD_VEC_A + row;
const uint buf_idx = col * SHMEM_STRIDE + row * LOAD_VEC_A / 2;
FLOAT_TYPEV4 aa = FLOAT_TYPEV4(data_a[idx]);
buf_a[buf_idx ] = aa.xy;
buf_a[buf_idx + 1] = aa.zw;
return;
}
#endif
const uint idx = pos_a + col * p.stride_a + row * 2;
const uint buf_idx = col * SHMEM_STRIDE + row;
if (idx_m < p.M && block + row * 2 + 1 < end_k) {
buf_a[buf_idx] = FLOAT_TYPEV2(data_a[idx],
data_a[idx + 1]);
buf_a[buf_idx] = FLOAT_TYPEV2(data_a_scalar[idx],
data_a_scalar[idx + 1]);
} else if (idx_m < p.M && block + row * 2 < end_k) {
buf_a[buf_idx] = FLOAT_TYPEV2(data_a[idx], 0.0f);
buf_a[buf_idx] = FLOAT_TYPEV2(data_a_scalar[idx], 0.0f);
} else {
buf_a[buf_idx] = FLOAT_TYPEV2(0.0f);
}
#endif
#elif defined(DATA_A_BF16)
#if LOAD_VEC_A == 4
const uint idx = pos_a + col * p.stride_a / LOAD_VEC_A + row;
const uint buf_idx = col * SHMEM_STRIDE + row * LOAD_VEC_A / 2;
FLOAT_TYPEV4 aa = FLOAT_TYPEV4(TO_FLOAT_TYPE(data_a[idx]));
buf_a[buf_idx ] = aa.xy;
buf_a[buf_idx + 1] = aa.zw;
#else // LOAD_VEC_BATCH_A == 2
if (ALIGNED != 0) {
const uint idx = pos_a + col * p.stride_a / LOAD_VEC_A + row;
const uint buf_idx = col * SHMEM_STRIDE + row * LOAD_VEC_A / 2;
FLOAT_TYPEV4 aa = FLOAT_TYPEV4(TO_FLOAT_TYPE(data_a[idx]));
buf_a[buf_idx ] = aa.xy;
buf_a[buf_idx + 1] = aa.zw;
return;
}
#endif
const uint idx = pos_a + col * p.stride_a + row * 2;
const uint buf_idx = col * SHMEM_STRIDE + row;
if (idx_m < p.M && block + row * 2 + 1 < end_k) {
buf_a[buf_idx] = FLOAT_TYPEV2(TO_FLOAT_TYPE(data_a[idx]),
TO_FLOAT_TYPE(data_a[idx + 1]));
buf_a[buf_idx] = FLOAT_TYPEV2(TO_FLOAT_TYPE(data_a_scalar[idx]),
TO_FLOAT_TYPE(data_a_scalar[idx + 1]));
} else if (idx_m < p.M && block + row * 2 < end_k) {
buf_a[buf_idx] = FLOAT_TYPEV2(TO_FLOAT_TYPE(data_a[idx]), 0.0f);
buf_a[buf_idx] = FLOAT_TYPEV2(TO_FLOAT_TYPE(data_a_scalar[idx]), 0.0f);
} else {
buf_a[buf_idx] = FLOAT_TYPEV2(0.0f);
}
#endif
#elif defined(DATA_A_Q4_0)
const uint idx = pos_a + col * p.stride_a / LOAD_VEC_A + row;
const uint buf_idx = col * SHMEM_STRIDE + row * LOAD_VEC_A / 4;
@@ -526,75 +533,85 @@ void load_a_to_shmem(const uint pos_a, const uint row, const uint col, const uin
#if !defined(MUL_MAT_ID)
void load_b_to_shmem(const uint pos_b, const uint row, const uint col, const uint idx_n, const uint block, const uint end_k) {
#if LOAD_VEC_B == 8
// Not supported for b_type bf16 because bf16mat2x4 does not exist
const uint idx = pos_b + col * p.stride_b / LOAD_VEC_B + row;
const uint buf_idx = col * SHMEM_STRIDE + row * LOAD_VEC_B / 2;
FLOAT_TYPEV8 bb = FLOAT_TYPEV8(data_b[idx]);
buf_b[buf_idx + 0] = bb[0].xy;
buf_b[buf_idx + 1] = bb[0].zw;
buf_b[buf_idx + 2] = bb[1].xy;
buf_b[buf_idx + 3] = bb[1].zw;
if (ALIGNED != 0) {
// Not supported for b_type bf16 because bf16mat2x4 does not exist
const uint idx = pos_b + col * p.stride_b / LOAD_VEC_B + row;
const uint buf_idx = col * SHMEM_STRIDE + row * LOAD_VEC_B / 2;
FLOAT_TYPEV8 bb = FLOAT_TYPEV8(data_b[idx]);
buf_b[buf_idx + 0] = bb[0].xy;
buf_b[buf_idx + 1] = bb[0].zw;
buf_b[buf_idx + 2] = bb[1].xy;
buf_b[buf_idx + 3] = bb[1].zw;
return;
}
#elif LOAD_VEC_B == 4
const uint idx = pos_b + col * p.stride_b / LOAD_VEC_B + row;
const uint buf_idx = col * SHMEM_STRIDE + row * LOAD_VEC_B / 2;
if (ALIGNED != 0) {
const uint idx = pos_b + col * p.stride_b / LOAD_VEC_B + row;
const uint buf_idx = col * SHMEM_STRIDE + row * LOAD_VEC_B / 2;
#if defined(DATA_B_BF16)
FLOAT_TYPEV4 bb = FLOAT_TYPEV4(TO_FLOAT_TYPE(data_b[idx]));
FLOAT_TYPEV4 bb = FLOAT_TYPEV4(TO_FLOAT_TYPE(data_b[idx]));
#else
FLOAT_TYPEV4 bb = FLOAT_TYPEV4(data_b[idx]);
FLOAT_TYPEV4 bb = FLOAT_TYPEV4(data_b[idx]);
#endif
buf_b[buf_idx + 0] = bb.xy;
buf_b[buf_idx + 1] = bb.zw;
return;
}
#endif
buf_b[buf_idx + 0] = bb.xy;
buf_b[buf_idx + 1] = bb.zw;
#else // LOAD_VEC_BATCH_B == 2
const uint idx = pos_b + col * p.stride_b + row * 2;
const uint buf_idx = col * SHMEM_STRIDE + row;
if (idx_n < p.N && block + row * 2 + 1 < end_k) {
buf_b[buf_idx] = FLOAT_TYPEV2(TO_FLOAT_TYPE(data_b[idx]),
TO_FLOAT_TYPE(data_b[idx + 1]));
buf_b[buf_idx] = FLOAT_TYPEV2(TO_FLOAT_TYPE(data_b_scalar[idx]),
TO_FLOAT_TYPE(data_b_scalar[idx + 1]));
} else if (idx_n < p.N && block + row * 2 < end_k) {
buf_b[buf_idx] = FLOAT_TYPEV2(TO_FLOAT_TYPE(data_b[idx]), 0.0f);
buf_b[buf_idx] = FLOAT_TYPEV2(TO_FLOAT_TYPE(data_b_scalar[idx]), 0.0f);
} else {
buf_b[buf_idx] = FLOAT_TYPEV2(0.0f);
}
#endif
}
#else
void load_b_to_shmem(const uint pos_b, const uint row, const uint col, const uint ic, const uint _ne1, const uint block, const uint end_k) {
#if LOAD_VEC_B == 8
// Not supported for b_type bf16 because bf16mat2x4 does not exist
const u16vec2 row_idx = row_ids[col];
const uint idx = pos_b + row_idx.y * p.batch_stride_b / LOAD_VEC_B + (row_idx.x % p.ne11) * p.stride_b / LOAD_VEC_B + row;
const uint buf_idx = col * SHMEM_STRIDE + row * LOAD_VEC_B / 2;
FLOAT_TYPEV8 bb = FLOAT_TYPEV8(data_b[idx]);
buf_b[buf_idx + 0] = bb[0].xy;
buf_b[buf_idx + 1] = bb[0].zw;
buf_b[buf_idx + 2] = bb[1].xy;
buf_b[buf_idx + 3] = bb[1].zw;
if (ALIGNED != 0) {
// Not supported for b_type bf16 because bf16mat2x4 does not exist
const u16vec2 row_idx = row_ids[col];
const uint idx = pos_b + row_idx.y * p.batch_stride_b / LOAD_VEC_B + (row_idx.x % p.ne11) * p.stride_b / LOAD_VEC_B + row;
const uint buf_idx = col * SHMEM_STRIDE + row * LOAD_VEC_B / 2;
FLOAT_TYPEV8 bb = FLOAT_TYPEV8(data_b[idx]);
buf_b[buf_idx + 0] = bb[0].xy;
buf_b[buf_idx + 1] = bb[0].zw;
buf_b[buf_idx + 2] = bb[1].xy;
buf_b[buf_idx + 3] = bb[1].zw;
return;
}
#elif LOAD_VEC_B == 4
const u16vec2 row_idx = row_ids[col];
const uint idx = pos_b + row_idx.y * p.batch_stride_b / LOAD_VEC_B + (row_idx.x % p.ne11) * p.stride_b / LOAD_VEC_B + row;
const uint buf_idx = col * SHMEM_STRIDE + row * LOAD_VEC_B / 2;
if (ALIGNED != 0) {
const u16vec2 row_idx = row_ids[col];
const uint idx = pos_b + row_idx.y * p.batch_stride_b / LOAD_VEC_B + (row_idx.x % p.ne11) * p.stride_b / LOAD_VEC_B + row;
const uint buf_idx = col * SHMEM_STRIDE + row * LOAD_VEC_B / 2;
#if defined(DATA_B_BF16)
FLOAT_TYPEV4 bb = FLOAT_TYPEV4(TO_FLOAT_TYPE(data_b[idx]));
FLOAT_TYPEV4 bb = FLOAT_TYPEV4(TO_FLOAT_TYPE(data_b[idx]));
#else
FLOAT_TYPEV4 bb = FLOAT_TYPEV4(data_b[idx]);
FLOAT_TYPEV4 bb = FLOAT_TYPEV4(data_b[idx]);
#endif
buf_b[buf_idx + 0] = bb.xy;
buf_b[buf_idx + 1] = bb.zw;
return;
}
#endif
buf_b[buf_idx + 0] = bb.xy;
buf_b[buf_idx + 1] = bb.zw;
#else // LOAD_VEC_BATCH_B == 2
const uint row_i = ic * BN + col;
const uint buf_idx = col * SHMEM_STRIDE + row;
if (row_i < _ne1 && block + row * 2 + 1 < end_k) {
const u16vec2 row_idx = row_ids[col];
const uint idx = pos_b + row_idx.y * p.batch_stride_b + (row_idx.x % p.ne11) * p.stride_b + row * 2;
buf_b[buf_idx] = FLOAT_TYPEV2(TO_FLOAT_TYPE(data_b[idx]),
TO_FLOAT_TYPE(data_b[idx + 1]));
buf_b[buf_idx] = FLOAT_TYPEV2(TO_FLOAT_TYPE(data_b_scalar[idx]),
TO_FLOAT_TYPE(data_b_scalar[idx + 1]));
} else if (row_i < _ne1 && block + row * 2 < end_k) {
const u16vec2 row_idx = row_ids[col];
const uint idx = pos_b + row_idx.y * p.batch_stride_b + (row_idx.x % p.ne11) * p.stride_b + row * 2;
buf_b[buf_idx] = FLOAT_TYPEV2(TO_FLOAT_TYPE(data_b[idx]), 0.0f);
buf_b[buf_idx] = FLOAT_TYPEV2(TO_FLOAT_TYPE(data_b_scalar[idx]), 0.0f);
} else {
buf_b[buf_idx] = FLOAT_TYPEV2(0.0f);
}
#endif
}
#endif
+10 -10
View File
@@ -1,26 +1,26 @@
#version 450
#include "generic_head.glsl"
#include "types.glsl"
#include "generic_unary_head.glsl"
#extension GL_EXT_control_flow_attributes : enable
#define BLOCK_SIZE 512
layout(local_size_x = BLOCK_SIZE, local_size_y = 1, local_size_z = 1) in;
layout (binding = 0) readonly buffer X {A_TYPE data_a[];};
layout (binding = 1) writeonly buffer D {D_TYPE data_d[];};
shared vec2 sum[BLOCK_SIZE];
void main() {
const uint row = gl_WorkGroupID.z * 262144 + gl_WorkGroupID.y * 512 + gl_WorkGroupID.x;
const uint tid = gl_LocalInvocationID.x;
const uint a_base = get_aoffset() + src0_idx(row * p.ne00);
const uint d_base = get_doffset() + dst_idx(row * p.ne10);
sum[tid] = vec2(0.0f, 0.0f);
[[unroll]] for (uint col = tid; col < p.KX; col += BLOCK_SIZE) {
const float xi = float(data_a[row*p.KX + col]);
[[unroll]] for (uint i0 = tid; i0 < p.ne00; i0 += BLOCK_SIZE) {
const float xi = float(data_a[a_base + i0*p.nb00]);
sum[tid].x += xi;
sum[tid].y += xi * xi;
}
@@ -34,11 +34,11 @@ void main() {
barrier();
}
const float mean = sum[0].x / p.KX;
const float var = sum[0].y / p.KX - mean * mean;
const float mean = sum[0].x / p.ne00;
const float var = sum[0].y / p.ne00 - mean * mean;
const float inv_std = inversesqrt(var + p.param1);
[[unroll]] for (uint col = tid; col < p.KX; col += BLOCK_SIZE) {
data_d[row*p.KX + col] = D_TYPE((float(data_a[row*p.KX + col]) - mean) * inv_std);
[[unroll]] for (uint i0 = tid; i0 < p.ne00; i0 += BLOCK_SIZE) {
data_d[d_base + i0*p.nb10] = D_TYPE((float(data_a[a_base + i0*p.nb00]) - mean) * inv_std);
}
}
@@ -1,17 +0,0 @@
#version 450
#include "types.glsl"
#include "generic_unary_head.glsl"
layout(local_size_x = 512, local_size_y = 1, local_size_z = 1) in;
void main() {
const uint idx = get_idx();
if (idx >= p.ne) {
return;
}
const FLOAT_TYPE val = FLOAT_TYPE(data_a[get_aoffset() + src0_idx(idx)]);
data_d[get_doffset() + dst_idx(idx)] = D_TYPE(sin(val));
}
@@ -1,17 +0,0 @@
#version 450
#include "types.glsl"
#include "generic_unary_head.glsl"
layout(local_size_x = 512, local_size_y = 1, local_size_z = 1) in;
void main() {
const uint idx = get_idx();
if (idx >= p.ne) {
return;
}
const FLOAT_TYPE val = FLOAT_TYPE(data_a[get_aoffset() + src0_idx(idx)]);
data_d[get_doffset() + dst_idx(idx)] = D_TYPE(sqrt(val));
}
@@ -1,17 +0,0 @@
#version 450
#include "types.glsl"
#include "generic_unary_head.glsl"
layout(local_size_x = 512, local_size_y = 1, local_size_z = 1) in;
void main() {
const uint idx = get_idx();
if (idx >= p.ne) {
return;
}
const FLOAT_TYPE val = FLOAT_TYPE(data_a[get_aoffset() + src0_idx(idx)]);
data_d[get_doffset() + dst_idx(idx)] = D_TYPE(val * val);
}
@@ -17,6 +17,30 @@ float op_neg(float x) {
return -x;
}
float op_sqr(float x) {
return x * x;
}
float op_sqrt(float x) {
return sqrt(x);
}
float op_sin(float x) {
return sin(x);
}
float op_cos(float x) {
return cos(x);
}
float op_clamp(float x) {
return clamp(x, p.param1, p.param2);
}
float op_leaky_relu(float x) {
return max(x, 0.0f) + min(x, 0.0f) * p.param1;
}
float op_step(float x) {
return x >= 0.0f ? 1.0f : 0.0f;
}
@@ -11,6 +11,7 @@
#include <future>
#include <queue>
#include <condition_variable>
#include <atomic>
#include <cstdio>
#include <cstring>
#include <cstdlib>
@@ -34,6 +35,9 @@
std::mutex lock;
std::vector<std::pair<std::string, std::string>> shader_fnames;
// Set when any shader subprocess fails (non-zero exit / stderr / launch failure) so the
// build is stopped instead of silently producing a broken libggml-vulkan. (issue #24393)
static std::atomic<bool> compile_failed{false};
std::locale c_locale("C");
std::string GLSLC = "glslc";
@@ -78,7 +82,7 @@ enum MatMulIdType {
namespace {
void execute_command(std::vector<std::string>& command, std::string& stdout_str, std::string& stderr_str) {
int execute_command(std::vector<std::string>& command, std::string& stdout_str, std::string& stderr_str) {
#ifdef _WIN32
HANDLE stdout_read, stdout_write;
HANDLE stderr_read, stderr_write;
@@ -127,8 +131,11 @@ void execute_command(std::vector<std::string>& command, std::string& stdout_str,
CloseHandle(stdout_read);
CloseHandle(stderr_read);
WaitForSingleObject(pi.hProcess, INFINITE);
DWORD exit_code = 1;
GetExitCodeProcess(pi.hProcess, &exit_code);
CloseHandle(pi.hProcess);
CloseHandle(pi.hThread);
return (int)exit_code;
#else
int stdout_pipe[2];
int stderr_pipe[2];
@@ -175,7 +182,9 @@ void execute_command(std::vector<std::string>& command, std::string& stdout_str,
close(stdout_pipe[0]);
close(stderr_pipe[0]);
waitpid(pid, nullptr, 0);
int status = 0;
waitpid(pid, &status, 0);
return WIFEXITED(status) ? WEXITSTATUS(status) : -1;
}
#endif
}
@@ -372,13 +381,14 @@ void string_to_spv_func(std::string name, std::string in_path, std::string out_p
// }
// std::cout << std::endl;
execute_command(cmd, stdout_str, stderr_str);
if (!stderr_str.empty()) {
std::cerr << "cannot compile " << name << "\n\n";
int exit_code = execute_command(cmd, stdout_str, stderr_str);
if (exit_code != 0 || !stderr_str.empty()) {
std::cerr << "cannot compile " << name << " (exit code " << exit_code << ")\n\n";
for (const auto& part : cmd) {
std::cerr << part << " ";
}
std::cerr << "\n\n" << stderr_str << std::endl;
compile_failed = true;
return;
}
@@ -398,6 +408,7 @@ void string_to_spv_func(std::string name, std::string in_path, std::string out_p
shader_fnames.push_back(std::make_pair(name, out_path));
} catch (const std::exception& e) {
std::cerr << "Error executing command for " << name << ": " << e.what() << std::endl;
compile_failed = true;
}
}
@@ -539,11 +550,9 @@ void matmul_shaders(bool fp16, MatMulIdType matmul_id_type, bool coopmat, bool c
};
// Shaders with f16 B_TYPE
string_to_spv(shader_name + "_f32_f16" + dot2_sfx, source_name, merge_maps(merge_maps(base_dict, float_type_dict_f16), {{"DATA_A_F32", "1"}, {"B_TYPE", "float16_t"}, {"B_TYPEV4", "f16vec4"}, {"D_TYPE", "float"}, }), fp16, coopmat, coopmat2, f16acc);
string_to_spv(shader_name + "_f32_f16" + dot2_sfx + "_aligned", source_name, merge_maps(merge_maps(base_dict, float_type_dict_f16), {{"DATA_A_F32", "1"}, {"LOAD_VEC_A", load_vec}, {"LOAD_VEC_B", load_vec}, {"B_TYPE", aligned_b_type_f16}, {"B_TYPEV4", "f16vec4"}, {"D_TYPE", "float"}, {"ALIGNED", "1"}}), fp16, coopmat, coopmat2, f16acc);
string_to_spv(shader_name + "_f32_f16" + dot2_sfx, source_name, merge_maps(merge_maps(base_dict, float_type_dict_f16), {{"DATA_A_F32", "1"}, {"LOAD_VEC_A", load_vec}, {"LOAD_VEC_B", load_vec}, {"B_TYPE", aligned_b_type_f16}, {"B_TYPE_SCALAR", "float16_t"}, {"B_TYPEV4", "f16vec4"}, {"D_TYPE", "float"}}), fp16, coopmat, coopmat2, f16acc);
string_to_spv(shader_name + "_f16" + dot2_sfx, source_name, merge_maps(merge_maps(base_dict, float_type_dict_f16), {{"DATA_A_F16", "1"}, {"B_TYPE", "float16_t"}, {"B_TYPEV4", "f16vec4"}, {"D_TYPE", "float"}}), fp16, coopmat, coopmat2, f16acc);
string_to_spv(shader_name + "_f16" + dot2_sfx + "_aligned", source_name, merge_maps(merge_maps(base_dict, float_type_dict_f16), {{"DATA_A_F16", "1"}, {"LOAD_VEC_A", load_vec}, {"LOAD_VEC_B", load_vec}, {"B_TYPE", aligned_b_type_f16}, {"B_TYPEV4", "f16vec4"}, {"D_TYPE", "float"}, {"ALIGNED", "1"}}), fp16, coopmat, coopmat2, f16acc);
string_to_spv(shader_name + "_f16" + dot2_sfx, source_name, merge_maps(merge_maps(base_dict, float_type_dict_f16), {{"DATA_A_F16", "1"}, {"LOAD_VEC_A", load_vec}, {"LOAD_VEC_B", load_vec}, {"B_TYPE", aligned_b_type_f16}, {"B_TYPE_SCALAR", "float16_t"}, {"B_TYPEV4", "f16vec4"}, {"D_TYPE", "float"}}), fp16, coopmat, coopmat2, f16acc);
// bf16
{
@@ -565,8 +574,7 @@ void matmul_shaders(bool fp16, MatMulIdType matmul_id_type, bool coopmat, bool c
#endif
{
if (!dot2) {
string_to_spv(shader_name + "_bf16", source_name, merge_maps(merge_maps(base_dict, float_type_dict_bf16), {{"TO_FLOAT_TYPE", to_float_type}, {"DATA_A_BF16", "1"}, {"B_TYPE", coopmat2 ? "bfloat16_t" : "uint16_t"}, {"B_TYPEV4", "bf16vec4"}, {"D_TYPE", "float"}, {"B_IS_FLOAT", "1"}, {"DATA_B_BF16", "1"}}), fp16, coopmat, coopmat2, f16acc);
string_to_spv(shader_name + "_bf16_aligned", source_name, merge_maps(merge_maps(base_dict, float_type_dict_bf16), {{"TO_FLOAT_TYPE", to_float_type}, {"DATA_A_BF16", "1"}, {"LOAD_VEC_A", load_vec_a}, {"LOAD_VEC_B", "4"}, {"B_TYPE", coopmat2 ? "bfloat16_t" : "u16vec4"}, {"B_TYPEV4", "bf16vec4"}, {"D_TYPE", "float"}, {"B_IS_FLOAT", "1"}, {"DATA_B_BF16", "1"}, {"ALIGNED", "1"}}), fp16, coopmat, coopmat2, f16acc);
string_to_spv(shader_name + "_bf16", source_name, merge_maps(merge_maps(base_dict, float_type_dict_bf16), {{"TO_FLOAT_TYPE", to_float_type}, {"DATA_A_BF16", "1"}, {"LOAD_VEC_A", load_vec_a}, {"LOAD_VEC_B", "4"}, {"B_TYPE", coopmat2 ? "bfloat16_t" : "u16vec4"}, {"B_TYPE_SCALAR", coopmat2 ? "bfloat16_t" : "uint16_t"}, {"B_TYPEV4", "bf16vec4"}, {"D_TYPE", "float"}, {"B_IS_FLOAT", "1"}, {"DATA_B_BF16", "1"}}), fp16, coopmat, coopmat2, f16acc);
}
}
}
@@ -583,8 +591,6 @@ void matmul_shaders(bool fp16, MatMulIdType matmul_id_type, bool coopmat, bool c
}
std::string data_a_key = "DATA_A_" + to_uppercase(tname);
// For unaligned, load one at a time for f32/f16, or two at a time for quants
std::string load_vec_a_unaligned = (coopmat2 || tname == "f32" || tname == "f16" || tname == "bf16") ? "1" : load_vec_quant;
// For aligned matmul loads
std::string load_vec_a = (coopmat2 || tname == "f32" || tname == "f16" || tname == "bf16") ? load_vec : load_vec_quant;
@@ -597,13 +603,11 @@ void matmul_shaders(bool fp16, MatMulIdType matmul_id_type, bool coopmat, bool c
// don't generate f32 variants for coopmat2
if (!coopmat2) {
string_to_spv(shader_name + "_" + tname + "_f32" + dot2_sfx, source_name, merge_maps(merge_maps(base_dict, float_type_dict), {{data_a_key, "1"}, {"LOAD_VEC_A", load_vec_a_unaligned}, {"B_TYPE", "float"}, {"B_TYPEV4", "vec4"}, {"D_TYPE", "float"}}), fp16, coopmat, coopmat2, f16acc);
string_to_spv(shader_name + "_" + tname + "_f32" + dot2_sfx + "_aligned", source_name, merge_maps(merge_maps(base_dict, float_type_dict), {{data_a_key, "1"}, {"LOAD_VEC_A", load_vec_a}, {"LOAD_VEC_B", load_vec}, {"B_TYPE", aligned_b_type_f32}, {"B_TYPEV4", "vec4"}, {"D_TYPE", "float"}, {"ALIGNED", "1"}}), fp16, coopmat, coopmat2, f16acc);
string_to_spv(shader_name + "_" + tname + "_f32" + dot2_sfx, source_name, merge_maps(merge_maps(base_dict, float_type_dict), {{data_a_key, "1"}, {"LOAD_VEC_A", load_vec_a}, {"LOAD_VEC_B", load_vec}, {"B_TYPE", aligned_b_type_f32}, {"B_TYPE_SCALAR", "float"}, {"B_TYPEV4", "vec4"}, {"D_TYPE", "float"}}), fp16, coopmat, coopmat2, f16acc);
}
if (tname != "f16" && tname != "f32") {
string_to_spv(shader_name + "_" + tname + "_f16" + dot2_sfx, source_name, merge_maps(merge_maps(base_dict, float_type_dict), {{data_a_key, "1"}, {"LOAD_VEC_A", load_vec_a_unaligned}, {"B_TYPE", "float16_t"}, {"B_TYPEV4", "f16vec4"}, {"D_TYPE", "float"}}), fp16, coopmat, coopmat2, f16acc);
string_to_spv(shader_name + "_" + tname + "_f16" + dot2_sfx + "_aligned", source_name, merge_maps(merge_maps(base_dict, float_type_dict), {{data_a_key, "1"}, {"LOAD_VEC_A", load_vec_a}, {"LOAD_VEC_B", load_vec}, {"B_TYPE", aligned_b_type_f16}, {"B_TYPEV4", "f16vec4"}, {"D_TYPE", "float"}, {"ALIGNED", "1"}}), fp16, coopmat, coopmat2, f16acc);
string_to_spv(shader_name + "_" + tname + "_f16" + dot2_sfx, source_name, merge_maps(merge_maps(base_dict, float_type_dict), {{data_a_key, "1"}, {"LOAD_VEC_A", load_vec_a}, {"LOAD_VEC_B", load_vec}, {"B_TYPE", aligned_b_type_f16}, {"B_TYPE_SCALAR", "float16_t"}, {"B_TYPEV4", "f16vec4"}, {"D_TYPE", "float"}}), fp16, coopmat, coopmat2, f16acc);
}
#if defined(GGML_VULKAN_INTEGER_DOT_GLSLC_SUPPORT)
@@ -850,21 +854,12 @@ void process_shaders() {
string_to_spv("repeat_i32", "repeat.comp", {{"A_TYPE", "int32_t"}, {"D_TYPE", "int32_t"}});
string_to_spv("repeat_back_f32", "repeat_back.comp", {{"A_TYPE", "float"}, {"D_TYPE", "float"}});
string_to_spv("get_rows_back_f32", "get_rows_back.comp", {{"A_TYPE", "float"}, {"B_TYPE", "int"}, {"D_TYPE", "float"}});
string_to_spv("repeat_i16", "repeat.comp", {{"A_TYPE", "int16_t"}, {"D_TYPE", "int16_t"}});
string_to_spv("scale_f32", "scale.comp", {{"A_TYPE", "float"}, {"D_TYPE", "float"}, {"FLOAT_TYPE", "float"}});
string_to_spv("sqr_f32", "square.comp", {{"A_TYPE", "float"}, {"D_TYPE", "float"}, {"FLOAT_TYPE", "float"}});
string_to_spv("sqrt_f32", "sqrt.comp", {{"A_TYPE", "float"}, {"D_TYPE", "float"}, {"FLOAT_TYPE", "float"}});
string_to_spv("sin_f32", "sin.comp", {{"A_TYPE", "float"}, {"D_TYPE", "float"}, {"FLOAT_TYPE", "float"}});
string_to_spv("cos_f32", "cos.comp", {{"A_TYPE", "float"}, {"D_TYPE", "float"}, {"FLOAT_TYPE", "float"}});
string_to_spv("clamp_f32", "clamp.comp", {{"A_TYPE", "float"}, {"D_TYPE", "float"}, {"FLOAT_TYPE", "float"}});
string_to_spv("pad_f32", "pad.comp", {{"A_TYPE", "float"}, {"D_TYPE", "float"}});
string_to_spv("concat_i8", "concat.comp", {{"A_TYPE", "uint8_t"}, {"B_TYPE", "uint8_t"}, {"D_TYPE", "uint8_t"}});
@@ -891,6 +886,18 @@ void process_shaders() {
string_to_spv("silu_f32", "unary.comp", {{"A_TYPE", "float"}, {"D_TYPE", "float"}, {"OP", "op_silu"}});
string_to_spv("relu_f16", "unary.comp", {{"A_TYPE", "float16_t"}, {"D_TYPE", "float16_t"}, {"OP", "op_relu"}});
string_to_spv("relu_f32", "unary.comp", {{"A_TYPE", "float"}, {"D_TYPE", "float"}, {"OP", "op_relu"}});
string_to_spv("sqr_f16", "unary.comp", {{"A_TYPE", "float16_t"}, {"D_TYPE", "float16_t"}, {"OP", "op_sqr"}});
string_to_spv("sqr_f32", "unary.comp", {{"A_TYPE", "float"}, {"D_TYPE", "float"}, {"OP", "op_sqr"}});
string_to_spv("sqrt_f16", "unary.comp", {{"A_TYPE", "float16_t"}, {"D_TYPE", "float16_t"}, {"OP", "op_sqrt"}});
string_to_spv("sqrt_f32", "unary.comp", {{"A_TYPE", "float"}, {"D_TYPE", "float"}, {"OP", "op_sqrt"}});
string_to_spv("sin_f16", "unary.comp", {{"A_TYPE", "float16_t"}, {"D_TYPE", "float16_t"}, {"OP", "op_sin"}});
string_to_spv("sin_f32", "unary.comp", {{"A_TYPE", "float"}, {"D_TYPE", "float"}, {"OP", "op_sin"}});
string_to_spv("cos_f16", "unary.comp", {{"A_TYPE", "float16_t"}, {"D_TYPE", "float16_t"}, {"OP", "op_cos"}});
string_to_spv("cos_f32", "unary.comp", {{"A_TYPE", "float"}, {"D_TYPE", "float"}, {"OP", "op_cos"}});
string_to_spv("clamp_f16", "unary.comp", {{"A_TYPE", "float16_t"}, {"D_TYPE", "float16_t"}, {"OP", "op_clamp"}});
string_to_spv("clamp_f32", "unary.comp", {{"A_TYPE", "float"}, {"D_TYPE", "float"}, {"OP", "op_clamp"}});
string_to_spv("leaky_relu_f16", "unary.comp", {{"A_TYPE", "float16_t"}, {"D_TYPE", "float16_t"}, {"OP", "op_leaky_relu"}});
string_to_spv("leaky_relu_f32", "unary.comp", {{"A_TYPE", "float"}, {"D_TYPE", "float"}, {"OP", "op_leaky_relu"}});
string_to_spv("neg_f16", "unary.comp", {{"A_TYPE", "float16_t"}, {"D_TYPE", "float16_t"}, {"OP", "op_neg"}});
string_to_spv("neg_f32", "unary.comp", {{"A_TYPE", "float"}, {"D_TYPE", "float"}, {"OP", "op_neg"}});
string_to_spv("tanh_f16", "unary.comp", {{"A_TYPE", "float16_t"}, {"D_TYPE", "float16_t"}, {"OP", "op_tanh"}});
@@ -948,7 +955,6 @@ void process_shaders() {
string_to_spv("geglu_quick_f16","geglu_quick.comp", {{"A_TYPE", "float16_t"}, {"D_TYPE", "float16_t"}});
string_to_spv("geglu_quick_f32","geglu_quick.comp", {{"A_TYPE", "float"}, {"D_TYPE", "float"}});
string_to_spv("leaky_relu_f32", "leaky_relu.comp", {{"A_TYPE", "float"}, {"D_TYPE", "float"}});
string_to_spv("silu_back_f32", "silu_back.comp", {{"A_TYPE", "float"}, {"B_TYPE", "float"}, {"D_TYPE", "float"}});
string_to_spv("diag_mask_inf_f32", "diag_mask_inf.comp", {{"A_TYPE", "float"}, {"D_TYPE", "float"}});
@@ -1060,6 +1066,31 @@ void process_shaders() {
}
}
for (auto unroll : {false, true}) {
for (auto a_f16 : {false, true}) {
std::map<std::string, std::string> defines = {
{"A_TYPE", a_f16 ? "float16_t" : "float"}, {"B_TYPE", "float"}, {"D_TYPE", "float"},
{"UNROLL", unroll ? "[[unroll]]" : ""},
};
std::string name = std::string("conv3d") + (a_f16 ? "_f16" : "") + "_f32";
string_to_spv(name + (unroll ? "_unroll" : ""), "conv3d_mm.comp", defines);
#if defined(GGML_VULKAN_COOPMAT2_GLSLC_SUPPORT)
if (unroll) {
auto cm2_defines = defines;
cm2_defines["COOPMAT2"] = "1";
string_to_spv(name, "conv3d_mm.comp", cm2_defines, true, false, true);
}
#endif
#if defined(GGML_VULKAN_COOPMAT_GLSLC_SUPPORT)
if (unroll) {
auto cm1_defines = defines;
cm1_defines["COOPMAT"] = "1";
string_to_spv(name, "conv3d_mm.comp", cm1_defines, true, true, false);
}
#endif
}
}
string_to_spv("conv2d_dw_whcn_f32", "conv2d_dw.comp", merge_maps(base_dict, {{"A_TYPE", "float"}, {"B_TYPE", "float"}, {"D_TYPE", "float"}, {"WHCN", "1"}}));
string_to_spv("conv2d_dw_cwhn_f32", "conv2d_dw.comp", merge_maps(base_dict, {{"A_TYPE", "float"}, {"B_TYPE", "float"}, {"D_TYPE", "float"}, {"CWHN", "1"}}));
string_to_spv("conv2d_dw_whcn_f16_f32", "conv2d_dw.comp", merge_maps(base_dict, {{"A_TYPE", "float16_t"}, {"B_TYPE", "float"}, {"D_TYPE", "float"}, {"WHCN", "1"}}));
@@ -1251,6 +1282,11 @@ int main(int argc, char** argv) {
process_shaders();
if (compile_failed) {
std::cerr << "vulkan-shaders-gen: one or more shaders failed to compile" << std::endl;
return EXIT_FAILURE;
}
write_output_files();
return EXIT_SUCCESS;
@@ -905,11 +905,12 @@ struct ggml_webgpu_mul_mat_vec_pipeline_key {
ggml_type src0_type;
ggml_type src1_type;
int vectorized;
uint32_t num_cols;
bool use_mmvq;
bool operator==(const ggml_webgpu_mul_mat_vec_pipeline_key & other) const {
return src0_type == other.src0_type && src1_type == other.src1_type && vectorized == other.vectorized &&
use_mmvq == other.use_mmvq;
num_cols == other.num_cols && use_mmvq == other.use_mmvq;
}
};
@@ -919,6 +920,7 @@ struct ggml_webgpu_mul_mat_vec_pipeline_key_hash {
ggml_webgpu_hash_combine(seed, key.src0_type);
ggml_webgpu_hash_combine(seed, key.src1_type);
ggml_webgpu_hash_combine(seed, key.vectorized);
ggml_webgpu_hash_combine(seed, key.num_cols);
ggml_webgpu_hash_combine(seed, key.use_mmvq);
return seed;
}
@@ -993,11 +995,12 @@ struct ggml_webgpu_mul_mat_id_pipeline_key {
ggml_type src0_type;
ggml_type src1_type;
uint32_t n_experts;
uint32_t num_cols;
int vectorized;
bool operator==(const ggml_webgpu_mul_mat_id_pipeline_key & other) const {
return src0_type == other.src0_type && src1_type == other.src1_type && n_experts == other.n_experts &&
vectorized == other.vectorized;
num_cols == other.num_cols && vectorized == other.vectorized;
}
};
@@ -1007,6 +1010,7 @@ struct ggml_webgpu_mul_mat_id_pipeline_key_hash {
ggml_webgpu_hash_combine(seed, key.src0_type);
ggml_webgpu_hash_combine(seed, key.src1_type);
ggml_webgpu_hash_combine(seed, key.n_experts);
ggml_webgpu_hash_combine(seed, key.num_cols);
ggml_webgpu_hash_combine(seed, key.vectorized);
return seed;
}
@@ -1107,7 +1111,7 @@ inline bool ggml_webgpu_can_use_mmvq(const ggml_tensor * src0,
const ggml_tensor * src1,
bool supports_dot_product,
const std::string & vendor) {
if (src1->ne[1] == 1) {
if (src1->ne[1] <= 4) {
bool supports_dp4a = vendor == "amd" || vendor == "intel" || vendor == "nvidia";
if (supports_dp4a && supports_dot_product) {
switch (src1->type) {
@@ -1889,6 +1893,7 @@ class ggml_webgpu_shader_lib {
(context.src0->type == GGML_TYPE_F32 || context.src0->type == GGML_TYPE_F16)) ?
1 :
0;
key.num_cols = context.dst->ne[1];
key.use_mmvq =
ggml_webgpu_can_use_mmvq(context.src0, context.src1, context.supports_dot_product, context.vendor);
@@ -2004,6 +2009,7 @@ class ggml_webgpu_shader_lib {
if (key.vectorized) {
variant += "_vectorized";
}
defines.push_back(std::string("NUM_COLS=") + std::to_string(key.num_cols));
auto processed = preprocessor.preprocess(shader_src, defines);
auto decisions = std::make_shared<ggml_webgpu_mul_mat_vec_shader_decisions>();
@@ -2421,6 +2427,7 @@ class ggml_webgpu_shader_lib {
if (key.vectorized) {
variant += "_vectorized";
}
defines.push_back(std::string("NUM_COLS=1"));
defines.push_back(std::string("N_EXPERTS=") + std::to_string(key.n_experts));
+12 -10
View File
@@ -1418,15 +1418,17 @@ static void ggml_webgpu_quantize_q8_dispatch(webgpu_context &
const size_t dst_offset = ggml_webgpu_tensor_offset(dst);
const size_t q8_src1_align_offset = ROUNDUP_POW2(
dst_offset + ggml_nbytes(dst), ctx->global_ctx->capabilities.limits.minStorageBufferOffsetAlignment);
const size_t q8_src1_binding_size =
ROUNDUP_POW2(src1->ne[3] * src1->ne[2] * (36 /* sizeof(q8_1) */ * (src1->ne[0] / /* block_size */ 32)),
WEBGPU_STORAGE_BUF_BINDING_MULT);
const size_t q8_src1_binding_size = ROUNDUP_POW2(
src1->ne[3] * src1->ne[2] * src1->ne[1] * (36 /* sizeof(q8_1) */ * (src1->ne[0] / /* block_size */ 32)),
WEBGPU_STORAGE_BUF_BINDING_MULT);
std::vector<uint32_t> q8_params = {
(uint32_t) (ggml_webgpu_tensor_misalignment(ctx, src1) / ggml_type_size(src1->type)),
(uint32_t) (src1->nb[1] / ggml_type_size(src1->type)),
(uint32_t) (src1->nb[2] / ggml_type_size(src1->type)),
(uint32_t) (src1->nb[3] / ggml_type_size(src1->type)),
(uint32_t) src1->ne[0],
(uint32_t) src1->ne[1],
(uint32_t) src1->ne[2],
(uint32_t) src1->ne[3],
};
@@ -1442,7 +1444,7 @@ static void ggml_webgpu_quantize_q8_dispatch(webgpu_context &
uint32_t q8_wg_x = 1;
uint32_t q8_wg_y = 1;
const uint32_t wg_per_vec = (src0->ne[0] / 4 + (q8_wg_size - 1)) / q8_wg_size;
const uint32_t q8_total_wg = src1->ne[2] * src1->ne[3] * wg_per_vec;
const uint32_t q8_total_wg = src1->ne[1] * src1->ne[2] * src1->ne[3] * wg_per_vec;
const uint32_t max_wg_per_dim = ctx->global_ctx->capabilities.limits.maxComputeWorkgroupsPerDimension;
compute_2d_workgroups(q8_total_wg, max_wg_per_dim, q8_wg_x, q8_wg_y);
@@ -1456,7 +1458,7 @@ static webgpu_encoded_op ggml_webgpu_mul_mat(webgpu_context & ctx,
ggml_tensor * src1,
ggml_tensor * dst) {
// Determine if this is a mat-vec operation
bool is_vec = (dst->ne[1] == 1);
bool use_mat_vec = (dst->ne[1] <= 4);
// use MMVQ path for mat-vec
bool use_mmvq = ggml_webgpu_can_use_mmvq(src0, src1, ctx->global_ctx->capabilities.supports_dot_product,
@@ -1482,7 +1484,7 @@ static webgpu_encoded_op ggml_webgpu_mul_mat(webgpu_context & ctx,
webgpu_pipeline pipeline;
std::vector<webgpu_dispatch_desc> dispatches;
if (is_vec) {
if (use_mat_vec) {
if (use_mmvq) {
ggml_webgpu_quantize_q8_dispatch(ctx, src0, src1, dst, dispatches);
}
@@ -1529,7 +1531,7 @@ static webgpu_encoded_op ggml_webgpu_mul_mat(webgpu_context & ctx,
uint32_t wg_y = 1;
const uint32_t max_wg_per_dim = ctx->global_ctx->capabilities.limits.maxComputeWorkgroupsPerDimension;
if (is_vec) {
if (use_mat_vec) {
auto * decisions = static_cast<ggml_webgpu_mul_mat_vec_shader_decisions *>(pipeline.context.get());
uint32_t batches = dst->ne[2] * dst->ne[3];
@@ -3691,8 +3693,8 @@ static size_t ggml_backend_webgpu_buffer_type_get_alloc_size(ggml_backend_buffer
ggml_webgpu_can_use_mmvq(src0, src1, ctx->webgpu_global_ctx->capabilities.supports_dot_product,
ctx->webgpu_global_ctx->vendor);
if (use_mmvq) {
const size_t q8_src1_size =
src1->ne[3] * src1->ne[2] * (36 /* sizeof(q8_1) */ * (src1->ne[0] / /* block_size */ 32));
const size_t q8_src1_size = src1->ne[3] * src1->ne[2] * src1->ne[1] *
(36 /* sizeof(q8_1) */ * (src1->ne[0] / /* block_size */ 32));
res = ROUNDUP_POW2(res + q8_src1_size +
ctx->webgpu_global_ctx->capabilities.limits.minStorageBufferOffsetAlignment,
WEBGPU_STORAGE_BUF_BINDING_MULT);
@@ -4268,7 +4270,7 @@ static bool ggml_backend_webgpu_device_supports_op(ggml_backend_dev_t dev, const
case GGML_OP_RMS_NORM:
case GGML_OP_NORM:
case GGML_OP_L2_NORM:
supports_op = op->type == GGML_TYPE_F32 && src0->type == GGML_TYPE_F32;
supports_op = (op->type == GGML_TYPE_F32 && src0->type == GGML_TYPE_F32) && ggml_is_contiguous_rows(src0);
break;
case GGML_OP_ROPE:
supports_op = op->type == GGML_TYPE_F32 || op->type == GGML_TYPE_F16;
@@ -103,7 +103,7 @@ fn main(
#ifdef USE_SUBGROUP_REDUCTION
for (var row = 0u; row < OUTPUTS_PER_WG; row++) {
let subgroup_total = subgroupAdd(acc[row]);
let subgroup_total = subgroupAdd(acc[0][row]);
if (subgroup_invocation_id == 0u) {
partial_sums[partial_index(row, subgroup_id)] = subgroup_total;
}
@@ -126,7 +126,7 @@ fn main(
#ifdef USE_WORKGROUP_REDUCTION
for (var row = 0u; row < OUTPUTS_PER_WG; row++) {
partial_sums[partial_index(row, thread_id)] = acc[row];
partial_sums[partial_index(row, thread_id)] = acc[0][row];
}
workgroupBarrier();
@@ -91,61 +91,67 @@ fn main(
let dst_idx_base = params.offset_dst + dst3_idx * dst3_stride + dst2_idx * dst2_stride + row_base;
#ifdef MMVQ
let src1q_idx_base = (src13_idx * params.bs02 * params.broadcast2 + src12_idx) * (params.k / 32u);
let src1q_idx_base = (src13_idx * params.bs02 * params.broadcast2 + src12_idx) * params.n * (params.k / 32u);
let acc = accumulate_vec_q_dot(thread_id, row_base, src0_batch_offset, src1q_idx_base);
#else
let src1_idx_base = params.offset_src1 + src13_idx * params.stride_13 + src12_idx * params.stride_12;
let acc = accumulate_vec_dot(thread_id, row_base, src0_batch_offset, src1_idx_base);
#endif
for (var col = 0u;col < NUM_COLS;col += 1) {
#ifdef USE_SUBGROUP_REDUCTION
for (var row = 0u; row < OUTPUTS_PER_WG; row++) {
let subgroup_total = subgroupAdd(acc[row]);
if (subgroup_invocation_id == 0u) {
partial_sums[partial_index(row, subgroup_id)] = subgroup_total;
}
}
for (var row = 0u; row < OUTPUTS_PER_WG; row++) {
let subgroup_total = subgroupAdd(acc[col][row]);
if (subgroup_invocation_id == 0u) {
partial_sums[partial_index(row, subgroup_id)] = subgroup_total;
}
}
workgroupBarrier();
workgroupBarrier();
for (var row = subgroup_id; (row < OUTPUTS_PER_WG) && (row_base + row < params.m); row += num_subgroups) {
let output_row = row_base + row;
var row_acc = 0.0f;
for (var k = subgroup_invocation_id; k < num_subgroups; k += subgroup_size) {
row_acc += partial_sums[partial_index(row, k)];
}
let row_total = subgroupAdd(row_acc);
if (subgroup_invocation_id == 0) {
dst[dst_idx_base + row] = row_total;
}
}
for (var row = subgroup_id; (row < OUTPUTS_PER_WG) && (row_base + row < params.m); row += num_subgroups) {
let output_row = row_base + row;
var row_acc = 0.0f;
for (var k = subgroup_invocation_id; k < num_subgroups; k += subgroup_size) {
row_acc += partial_sums[partial_index(row, k)];
}
let row_total = subgroupAdd(row_acc);
if (subgroup_invocation_id == 0) {
dst[dst_idx_base + col * params.m + row] = row_total;
}
}
#endif
#ifdef USE_WORKGROUP_REDUCTION
for (var row = 0u; row < OUTPUTS_PER_WG; row++) {
partial_sums[partial_index(row, thread_id)] = acc[row];
}
for (var row = 0u; row < OUTPUTS_PER_WG; row++) {
partial_sums[partial_index(row, thread_id)] = acc[col][row];
}
workgroupBarrier();
var stride = WG_SIZE / 2u;
while (stride > 0) {
if (thread_id < stride) {
for (var row = 0u; row < OUTPUTS_PER_WG; row++) {
partial_sums[partial_index(row, thread_id)] += partial_sums[partial_index(row, thread_id + stride)];
}
}
workgroupBarrier();
stride = stride / 2;
}
if (thread_id < OUTPUTS_PER_WG) {
let output_row = row_base + thread_id;
if (output_row < params.m) {
dst[dst_idx_base + col * params.m + thread_id] = partial_sums[partial_index(thread_id, 0)];
}
}
#endif
workgroupBarrier();
var stride = WG_SIZE / 2u;
while (stride > 0) {
if (thread_id < stride) {
for (var row = 0u; row < OUTPUTS_PER_WG; row++) {
partial_sums[partial_index(row, thread_id)] += partial_sums[partial_index(row, thread_id + stride)];
}
}
workgroupBarrier();
stride = stride / 2;
}
if (thread_id < OUTPUTS_PER_WG) {
let output_row = row_base + thread_id;
if (output_row < params.m) {
dst[dst_idx_base + thread_id] = partial_sums[partial_index(thread_id, 0)];
}
}
#endif
}
File diff suppressed because it is too large Load Diff
@@ -51,10 +51,7 @@ fn repack_b_dm(block: u32) -> B_DS_TYPE {
fn get_dm(block_byte_base: u32) -> f32 {
return f32(load_f16_at_src0(block_byte_base));
}
fn mul_q8_1(row_sum: i32, da: f32, b_ds: B_DS_TYPE) -> f32 {
return f32(row_sum) * (da * b_ds.x) - 8.0 * da * b_ds.y / THREADS_PER_BLOCK;
}
#endif
#endif // MUL_ACC_Q4_0
#ifdef MUL_ACC_Q4_1
#define BLOCK_SIZE_BYTES 20
@@ -85,10 +82,7 @@ fn get_dm(block_byte_base: u32) -> vec2<f32> {
f32(load_f16_at_src0(block_byte_base + 2u))
);
}
fn mul_q8_1(row_sum: i32, dma: vec2<f32>, b_ds: B_DS_TYPE) -> f32 {
return f32(row_sum) * (dma.x * b_ds.x) + dma.y * b_ds.y / THREADS_PER_BLOCK;
}
#endif
#endif // MUL_ACC_Q4_1
#ifdef MUL_ACC_Q8_0
#define BLOCK_SIZE_BYTES 34
@@ -111,46 +105,48 @@ fn repack_b_dm(block: u32) -> B_DS_TYPE {
fn get_dm(block_byte_base: u32) -> f32 {
return f32(load_f16_at_src0(block_byte_base));
}
fn mul_q8_1(row_sum: i32, da: f32, b_ds: B_DS_TYPE) -> f32 {
return f32(row_sum) * (da * b_ds);
}
#endif
#endif // MUL_ACC_Q8_0
#ifdef LEGACY_QUANTS
fn mmvq_dot_product(a_byte_base: u32, b_inner_id: u32, b_repacked: vec2<u32>, b_ds: B_DS_TYPE) -> f32 {
var row_sum = 0;
let a_repacked = repack_a(a_byte_base, b_inner_id);
row_sum += dot4I8Packed(a_repacked[0], b_repacked[0]);
row_sum += dot4I8Packed(a_repacked[1], b_repacked[1]);
return mul_q8_1(row_sum, get_dm(a_byte_base), b_ds);
}
fn accumulate_vec_q_dot(thread_id: u32, row_base: u32, src0_batch_offset: u32, src1q_idx_base: u32) -> array<f32, OUTPUTS_PER_WG> {
var acc: array<f32, OUTPUTS_PER_WG>;
#if defined(LEGACY_QUANTS)
fn accumulate_vec_q_dot(thread_id: u32, row_base: u32, src0_batch_offset: u32, src1q_idx_base: u32) -> array<array<f32, OUTPUTS_PER_WG>, NUM_COLS> {
var acc: array<array<f32, OUTPUTS_PER_WG>, NUM_COLS>;
let num_blocks = params.k / BLOCK_SIZE;
for (var block = thread_id / THREADS_PER_BLOCK; block < num_blocks; block += WG_SIZE / THREADS_PER_BLOCK) {
let b_inner_id = thread_id % THREADS_PER_BLOCK;
let b_block_idx = src1q_idx_base + block;
let b_repacked = repack_b_qs(b_block_idx, b_inner_id);
let b_ds = repack_b_dm(b_block_idx);
let inner_id = thread_id % THREADS_PER_BLOCK;
for (var row = 0u; row < OUTPUTS_PER_WG; row++) {
let output_row = row_base + row;
if (output_row < params.m) {
let block_byte_base = (src0_batch_offset + output_row * params.stride_01 + block) * BLOCK_SIZE_BYTES;
acc[row] += mmvq_dot_product(block_byte_base, b_inner_id, b_repacked, b_ds);
let a_repacked = repack_a(block_byte_base, inner_id);
let da = get_dm(block_byte_base);
for (var col = 0u;col < NUM_COLS;col += 1) {
let src1q_idx = src1q_idx_base + col * (params.k / Q8_BLOCK_SIZE) + block;
let b_repacked = repack_b_qs(src1q_idx, inner_id);
let b_ds = repack_b_dm(src1q_idx);
let row_sum = dot4I8Packed(a_repacked[0], b_repacked[0]) + dot4I8Packed(a_repacked[1], b_repacked[1]);
#if defined(MUL_ACC_Q4_0)
acc[col][row] += f32(row_sum) * (da * b_ds.x) - 8.0 * da * b_ds.y / THREADS_PER_BLOCK;
#endif // MUL_ACC_Q4_0
#if defined(MUL_ACC_Q4_1)
acc[col][row] += f32(row_sum) * (da.x * b_ds.x) + da.y * b_ds.y / THREADS_PER_BLOCK;
#endif // MUL_ACC_Q4_1
#if defined(MUL_ACC_Q8_0)
acc[col][row] += f32(row_sum) * (da * b_ds);
#endif // MUL_ACC_Q8_0
}
}
}
}
return acc;
}
#endif
#endif // LEGACY_QUANTS
#ifdef MUL_ACC_Q2_K
#define BLOCK_SIZE_BYTES 84
@@ -191,22 +187,7 @@ fn get_scale_min(block_byte_base: u32, tid: u32) -> vec2<f32> {
let scale = byte_of(load_u32_at_src0_aligned(scale_byte), scale_byte & 3u);
return vec2<f32>(f32(scale & 0xFu), f32(scale >> 4u));
}
fn mmvq_dot_product(a_byte_base: u32, tid: u32, b_repacked: vec4<u32>, b_ds: B_DS_TYPE) -> f32 {
let a_repacked = repack_a(a_byte_base, tid);
let dm = get_dm(a_byte_base);
let scale_min = get_scale_min(a_byte_base, tid);
let scale_q = i32(scale_min.x);
let scale_m_i8x4 = u32(scale_min.y) * 0x01010101u;
let row_sum_d = (dot4I8Packed(b_repacked[0], a_repacked[0]) + dot4I8Packed(b_repacked[1], a_repacked[1])
+ dot4I8Packed(b_repacked[2], a_repacked[2]) + dot4I8Packed(b_repacked[3], a_repacked[3])) * scale_q;
let row_sum_m = dot4I8Packed(b_repacked[0], scale_m_i8x4) + dot4I8Packed(b_repacked[1], scale_m_i8x4)
+ dot4I8Packed(b_repacked[2], scale_m_i8x4) + dot4I8Packed(b_repacked[3], scale_m_i8x4);
return b_ds * (dm.x * f32(row_sum_d) - dm.y * f32(row_sum_m));
}
#endif
#endif // MUL_ACC_Q2_K
#ifdef MUL_ACC_Q4_K
#define BLOCK_SIZE_BYTES 144
@@ -265,39 +246,52 @@ fn get_scale_min(block_byte_base: u32, tid: u32) -> vec2<f32> {
return vec2<f32>(scale, min_val);
}
fn mmvq_dot_product(a_byte_base: u32, tid: u32, b_repacked: vec4<u32>, b_ds: B_DS_TYPE) -> f32 {
let a_repacked = repack_a(a_byte_base, tid);
let dm = get_dm(a_byte_base);
let scale_min = get_scale_min(a_byte_base, tid);
let row_sum = dot4I8Packed(a_repacked[0], b_repacked[0]) + dot4I8Packed(a_repacked[1], b_repacked[1])
+ dot4I8Packed(a_repacked[2], b_repacked[2]) + dot4I8Packed(a_repacked[3], b_repacked[3]);
// Each thread covers half of the Q8_1 block, so add only b_ds.y/2.
return b_ds.x * dm.x * scale_min.x * f32(row_sum) - dm.y * scale_min.y * (b_ds.y / (Q8_BLOCK_SIZE / ELEMS_PER_THREAD));
}
#endif
#endif // MUL_ACC_Q4_K
#ifdef K_QUANTS
fn accumulate_vec_q_dot(thread_id: u32, row_base: u32, src0_batch_offset: u32, src1q_idx_base: u32) -> array<f32, OUTPUTS_PER_WG> {
var acc: array<f32, OUTPUTS_PER_WG>;
fn accumulate_vec_q_dot(thread_id: u32, row_base: u32, src0_batch_offset: u32, src1q_idx_base: u32) -> array<array<f32, OUTPUTS_PER_WG>, NUM_COLS> {
var acc: array<array<f32, OUTPUTS_PER_WG>, NUM_COLS>;
let tid = thread_id % THREADS_PER_BLOCK;
for (var block = thread_id / THREADS_PER_BLOCK; block < params.k / BLOCK_SIZE; block += WG_SIZE / THREADS_PER_BLOCK) {
let src1q_idx = src1q_idx_base + (block * BLOCK_SIZE + ELEMS_PER_THREAD * tid) / Q8_BLOCK_SIZE;
let b_repacked = repack_b_qs(src1q_idx, tid);
let b_ds = repack_b_dm(src1q_idx);
for (var row = 0u; row < OUTPUTS_PER_WG; row++) {
let output_row = row_base + row;
if (output_row < params.m) {
let block_byte_base = (src0_batch_offset + output_row * params.stride_01 + block) * BLOCK_SIZE_BYTES;
acc[row] += mmvq_dot_product(block_byte_base, tid, b_repacked, b_ds);
let a_repacked = repack_a(block_byte_base, tid);
let dm = get_dm(block_byte_base);
let scale_min = get_scale_min(block_byte_base, tid);
for (var col = 0u;col < NUM_COLS;col += 1) {
let src1q_idx = src1q_idx_base + col * (params.k / Q8_BLOCK_SIZE) + (block * BLOCK_SIZE + ELEMS_PER_THREAD * tid) / Q8_BLOCK_SIZE;
let b_repacked = repack_b_qs(src1q_idx, tid);
let b_ds = repack_b_dm(src1q_idx);
#if defined(MUL_ACC_Q2_K)
let scale_q = i32(scale_min.x);
let scale_m_i8x4 = u32(scale_min.y) * 0x01010101u;
let row_sum_d = (dot4I8Packed(b_repacked[0], a_repacked[0]) + dot4I8Packed(b_repacked[1], a_repacked[1])
+ dot4I8Packed(b_repacked[2], a_repacked[2]) + dot4I8Packed(b_repacked[3], a_repacked[3])) * scale_q;
let row_sum_m = dot4I8Packed(b_repacked[0], scale_m_i8x4) + dot4I8Packed(b_repacked[1], scale_m_i8x4)
+ dot4I8Packed(b_repacked[2], scale_m_i8x4) + dot4I8Packed(b_repacked[3], scale_m_i8x4);
acc[col][row] += b_ds * (dm.x * f32(row_sum_d) - dm.y * f32(row_sum_m));
#endif // MUL_ACC_Q2_K
#if defined(MUL_ACC_Q4_K)
let row_sum = dot4I8Packed(a_repacked[0], b_repacked[0]) + dot4I8Packed(a_repacked[1], b_repacked[1])
+ dot4I8Packed(a_repacked[2], b_repacked[2]) + dot4I8Packed(a_repacked[3], b_repacked[3]);
// Each thread covers half of the Q8_1 block, so add only b_ds.y/2.
acc[col][row] += b_ds.x * dm.x * scale_min.x * f32(row_sum) - dm.y * scale_min.y * (b_ds.y / (Q8_BLOCK_SIZE / ELEMS_PER_THREAD));
#endif // MUL_ACC_Q4_K
}
}
}
}
return acc;
}
#endif
#endif // K_QUANTS
@@ -9,9 +9,11 @@ requires packed_4x8_integer_dot_product;
struct Params {
offset_src1: u32,
stride_11: u32,
stride_12: u32,
stride_13: u32,
ne0: u32,
ne1: u32,
ne2: u32,
ne3: u32,
};
@@ -57,25 +59,28 @@ fn main(
@builtin(num_workgroups) num_wg: vec3<u32>
) {
let thread_id = local_id.x;
let num_vec4 = params.ne0 / 4u;
let ne0_vec4 = params.ne0 / 4u;
let wg_per_vec = (num_vec4 + (WG_SIZE - 1u)) / WG_SIZE;
let total_batches = wg_per_vec * params.ne2 * params.ne3;
let wg_per_vec = (ne0_vec4 + (WG_SIZE - 1u)) / WG_SIZE;
let total_batches = wg_per_vec * params.ne1 * params.ne2 * params.ne3;
let wg_linear = wg_id.y * num_wg.x + wg_id.x;
if (wg_linear >= total_batches) {
return;
}
let src13_idx = wg_linear / (params.ne2 * wg_per_vec);
let src12_idx = (wg_linear - src13_idx * (params.ne2 * wg_per_vec)) / wg_per_vec;
let src11_wg_idx = wg_linear % wg_per_vec;
let src1_idx_base = params.offset_src1 + src13_idx * params.stride_13 + src12_idx * params.stride_12;
let vec_idx = wg_linear / wg_per_vec;
let src13_idx = vec_idx / (params.ne2 * params.ne1);
let vec_ne12_num = vec_idx % (params.ne2 * params.ne1);
let src12_idx = vec_ne12_num / params.ne1;
let src11_idx = vec_ne12_num % params.ne1;
let src1_idx_base = params.offset_src1 + src13_idx * params.stride_13 + src12_idx * params.stride_12 + src11_idx * params.stride_11;
let src1_idx_vec4_base = src1_idx_base / 4u;
let blocks_per_row = params.ne0 / 32u;
let blocks_per_wg = (WG_SIZE * 4u) / 32u;
let src1q_idx_base = (src13_idx * params.ne2 + src12_idx) * blocks_per_row;
let src1q_idx_base = ((src13_idx * params.ne2 + src12_idx) * params.ne1 + src11_idx) * blocks_per_row;
let src11_wg_idx = wg_linear % wg_per_vec;
let src1q_idx = src1q_idx_base + src11_wg_idx * blocks_per_wg + thread_id / 8u;
let qs_idx = thread_id % 8u;
@@ -85,7 +90,7 @@ fn main(
var thread_amax = 0.0;
let src11_vec4_idx = src11_wg_idx * WG_SIZE + thread_id;
let is_valid = src11_vec4_idx < num_vec4;
let is_valid = src11_vec4_idx < ne0_vec4;
#ifdef USE_SUBGROUP_REDUCTION
+1
View File
@@ -359,6 +359,7 @@ class Keys:
CHUNK_SIZE = "clip.audio.chunk_size"
CONV_KERNEL_SIZE = "clip.audio.conv_kernel_size"
MAX_POS_EMB = "clip.audio.max_pos_emb"
FEATURE_LAYERS = "clip.audio.feature_layer" # Granite Speech Plus
class Attention:
HEAD_COUNT = "clip.audio.attention.head_count"
+3
View File
@@ -1310,6 +1310,9 @@ class GGUFWriter:
def add_audio_max_pos_emb(self, value: int) -> None:
self.add_uint32(Keys.ClipAudio.MAX_POS_EMB, value)
def add_audio_feature_layers(self, layers: Sequence[int]) -> None:
self.add_array(Keys.ClipAudio.FEATURE_LAYERS, layers)
def add_audio_projector_window_size(self, value: int) -> None:
self.add_uint32(Keys.ClipAudio.Projector.WINDOW_SIZE, value)
+14 -4
View File
@@ -190,7 +190,15 @@ llama_model_lfm2::graph<iswa>::graph(const llama_model & model, const llm_graph_
auto * conv_rs = build_rs(inp_recr, conv_state, hparams.n_embd_r(), n_seqs);
auto * conv = ggml_reshape_3d(ctx0, conv_rs, d_conv, hparams.n_embd, n_seqs);
bx = ggml_concat(ctx0, conv, bx, 0);
// causal prepends the state, non-causal pads symmetrically for a centered window
if (hparams.causal_attn) {
bx = ggml_concat(ctx0, conv, bx, 0);
} else {
const int64_t pad = (hparams.n_shortconv_l_cache - 1) / 2;
auto * left = ggml_cont(ctx0,
ggml_view_3d(ctx0, conv, pad, hparams.n_embd, n_seqs, conv->nb[1], conv->nb[2], (d_conv - pad) * conv->nb[0]));
bx = ggml_pad_ext(ctx0, ggml_concat(ctx0, left, bx, 0), 0, pad, 0, 0, 0, 0, 0, 0);
}
GGML_ASSERT(bx->ne[0] > conv->ne[0]);
// last d_conv columns is a new conv state
@@ -266,10 +274,12 @@ llama_model_lfm2::graph<iswa>::graph(const llama_model & model, const llm_graph_
cb(cur, "result_norm", -1);
res->t_embd = cur;
cur = build_lora_mm(model.output, cur, model.output_s);
cb(cur, "result_output", -1);
if (!cparams.embeddings) {
cur = build_lora_mm(model.output, cur, model.output_s);
cb(cur, "result_output", -1);
res->t_logits = cur;
res->t_logits = cur;
}
ggml_build_forward_expand(gf, cur);
}
+56 -8
View File
@@ -3298,21 +3298,29 @@ struct test_norm : public test_case {
const std::array<int64_t, 4> ne;
const bool v; // whether a is a non-contiguous view
const float eps;
const bool noncontig_rows;
std::string vars() override {
return VARS_TO_STR4(type, ne, v, eps);
return VARS_TO_STR5(type, ne, v, eps, noncontig_rows);
}
test_norm(ggml_type type = GGML_TYPE_F32,
std::array<int64_t, 4> ne = {64, 5, 4, 3},
bool v = false,
float eps = 1e-6f)
: type(type), ne(ne), v(v), eps(eps) {}
float eps = 1e-6f,
bool noncontig_rows = false)
: type(type), ne(ne), v(v), eps(eps), noncontig_rows(noncontig_rows) {}
ggml_tensor * build_graph(ggml_context * ctx) override {
ggml_tensor * a = ggml_new_tensor(ctx, type, 4, ne.data());
const std::array<int64_t, 4> ne_a = noncontig_rows ?
std::array<int64_t, 4>{ ne[1], ne[0], ne[2], ne[3] } : ne;
ggml_tensor * a = ggml_new_tensor(ctx, type, 4, ne_a.data());
ggml_set_name(a, "a");
if (noncontig_rows) {
a = ggml_permute(ctx, a, 1, 0, 2, 3);
ggml_set_name(a, "permuted a");
}
if (v) {
a = ggml_view_4d(ctx, a, a->ne[0]/2, a->ne[1]/2, a->ne[2]/2, a->ne[3]/2, a->nb[1], a->nb[2], a->nb[3], 0);
ggml_set_name(a, "view of a");
@@ -6193,21 +6201,29 @@ struct test_l2_norm : public test_case {
const std::array<int64_t, 4> ne;
const float eps;
bool v;
bool noncontig_rows;
std::string vars() override {
return VARS_TO_STR4(type, ne, eps, v);
return VARS_TO_STR5(type, ne, eps, v, noncontig_rows);
}
test_l2_norm(ggml_type type = GGML_TYPE_F32,
std::array<int64_t, 4> ne = {64, 64, 320, 1},
float eps = 1e-12f,
bool v = false)
: type(type), ne(ne), eps(eps), v(v) {}
bool v = false,
bool noncontig_rows = false)
: type(type), ne(ne), eps(eps), v(v), noncontig_rows(noncontig_rows) {}
ggml_tensor * build_graph(ggml_context * ctx) override {
ggml_tensor * a = ggml_new_tensor(ctx, type, 4, ne.data());
const std::array<int64_t, 4> ne_a = noncontig_rows ?
std::array<int64_t, 4>{ ne[1], ne[0], ne[2], ne[3] } : ne;
ggml_tensor * a = ggml_new_tensor(ctx, type, 4, ne_a.data());
ggml_set_name(a, "a");
if (noncontig_rows) {
a = ggml_permute(ctx, a, 1, 0, 2, 3);
ggml_set_name(a, "permuted a");
}
if (v) {
a = ggml_view_4d(ctx, a, a->ne[0]/2, a->ne[1]/2, a->ne[2]/2, a->ne[3]/2, a->nb[1], a->nb[2], a->nb[3], 0);
ggml_set_name(a, "view of a");
@@ -8282,9 +8298,11 @@ static std::vector<std::unique_ptr<test_case>> make_test_cases_eval() {
test_cases.emplace_back(new test_norm(GGML_TYPE_F32, { n, 5, 4, 3 }, v, eps));
test_cases.emplace_back(new test_rms_norm(GGML_TYPE_F32, { n, 5, 4, 3 }, v, eps));
}
test_cases.emplace_back(new test_norm(GGML_TYPE_F32, { n, 5, 4, 3 }, false, eps, true));
test_cases.emplace_back(new test_rms_norm_back(GGML_TYPE_F32, { n, 5, 4, 3 }, eps));
test_cases.emplace_back(new test_l2_norm(GGML_TYPE_F32, { n, 5, 4, 3 }, eps, false));
test_cases.emplace_back(new test_l2_norm(GGML_TYPE_F32, { n, 5, 4, 3 }, eps, true));
test_cases.emplace_back(new test_l2_norm(GGML_TYPE_F32, { n, 5, 4, 3 }, eps, false, true));
}
}
@@ -8433,6 +8451,7 @@ static std::vector<std::unique_ptr<test_case>> make_test_cases_eval() {
test_cases.emplace_back(new test_mul_mat(type_a, type_b, 16, 1, k, {3, 2}, {2, 1}));
test_cases.emplace_back(new test_mul_mat(type_a, type_b, 16, 1, k, {3, 2}, {1, 2}));
test_cases.emplace_back(new test_mul_mat(type_a, type_b, 16, 1, k, {3, 2}, {2, 2}));
test_cases.emplace_back(new test_mul_mat(type_a, type_b, 16, 4, k, {3, 2}, {2, 2}));
test_cases.emplace_back(new test_mul_mat(type_a, type_b, 16, 16, k, {1, 1}, {1, 1}));
test_cases.emplace_back(new test_mul_mat(type_a, type_b, 16, 16, k, {1, 1}, {2, 1}));
@@ -8449,6 +8468,7 @@ static std::vector<std::unique_ptr<test_case>> make_test_cases_eval() {
test_cases.emplace_back(new test_mul_mat(type_a, type_b, 16, 1, k, {2, 3}, {1, 1}, {0, 1, 3, 2}));
test_cases.emplace_back(new test_mul_mat(type_a, type_b, 16, 1, k, {2, 3}, {1, 1}, {0, 3, 2, 1}));
test_cases.emplace_back(new test_mul_mat(type_a, type_b, 16, 4, k, {2, 3}, {1, 1}, {0, 3, 2, 1}));
test_cases.emplace_back(new test_mul_mat(type_a, type_b, 16, 8, k, {2, 3}, {1, 1}, {0, 2, 1, 3}));
test_cases.emplace_back(new test_mul_mat(type_a, type_b, 16, 8, k, {2, 3}, {1, 1}, {0, 1, 3, 2}));
test_cases.emplace_back(new test_mul_mat(type_a, type_b, 16, 8, k, {2, 3}, {1, 1}, {0, 3, 2, 1}));
@@ -9270,6 +9290,34 @@ static std::vector<std::unique_ptr<test_case>> make_test_cases_perf() {
}
}
struct conv3d_perf_case {
int N, IC, ID, IH, IW, OC, KD, KH, KW, s0, s1, s2, p0, p1, p2, d0, d1, d2;
};
const std::vector<conv3d_perf_case> conv3d_cases = {
{1, 320, 8, 38, 26, 1280, 3, 3, 3, 1, 1, 1, 1, 1, 1, 1, 1, 1},
{1, 1280, 8, 38, 26, 1280, 3, 3, 3, 1, 1, 1, 1, 1, 1, 1, 1, 1},
{1, 320, 8, 76, 52, 1280, 3, 3, 3, 1, 1, 1, 1, 1, 1, 1, 1, 1},
{1, 1280, 8, 76, 52, 1280, 3, 3, 3, 1, 1, 1, 1, 1, 1, 1, 1, 1},
{1, 320, 8, 152, 104, 1280, 3, 3, 3, 1, 1, 1, 1, 1, 1, 1, 1, 1},
#if 0
// too slow on some devices
{1, 1280, 8, 152, 104, 1280, 3, 3, 3, 1, 1, 1, 1, 1, 1, 1, 1, 1},
{1, 320, 4, 304, 208, 1280, 3, 3, 3, 1, 1, 1, 1, 1, 1, 1, 1, 1},
{1, 640, 4, 304, 208, 1280, 3, 3, 3, 1, 1, 1, 1, 1, 1, 1, 1, 1},
#endif
};
for (ggml_type kernel_type : {GGML_TYPE_F32, GGML_TYPE_F16}) {
for (const conv3d_perf_case & c : conv3d_cases) {
test_cases.emplace_back(new test_conv_3d(
c.N, c.IC, c.ID, c.IH, c.IW,
c.OC, c.KD, c.KH, c.KW,
c.s0, c.s1, c.s2, c.p0, c.p1, c.p2, c.d0, c.d1, c.d2,
kernel_type));
}
}
test_cases.emplace_back(new test_bin_bcast(ggml_add, GGML_TYPE_F32, {4096, 1, 1, 1}, {1, 1, 1, 1}));
test_cases.emplace_back(new test_bin_bcast(ggml_add, GGML_TYPE_F32, {4096, 1, 1, 1}, {1, 512, 1, 1}));
+99 -24
View File
@@ -1562,37 +1562,112 @@ static void test_msgs_oaicompat_json_conversion() {
}
}
static void test_split_by_role() {
static void test_msg_token_delimiters_split() {
LOG_DBG("%s\n", __func__);
// Delimiters that share a leading token, distinguished by the second token,
// to exercise the per-position token matching.
const common_chat_msg_delimiters delims = {
{ { COMMON_CHAT_ROLE_USER, "", { 10, 11 } },
{ COMMON_CHAT_ROLE_ASSISTANT, "", { 10, 12 } } }
};
// Empty inputs
assert_equals<size_t>(0, common_chat_split_by_role("", {}).size());
assert_equals<size_t>(0, common_chat_split_by_role("hello", {}).size());
assert_equals<size_t>(0, common_chat_split_by_role("", { { "user", "<|user|>" } }).size());
assert_equals<size_t>(0, common_chat_msg_delimiters{}.split({}).spans.size());
assert_equals<size_t>(0, common_chat_msg_delimiters{}.split({ 10, 11 }).spans.size());
assert_equals<size_t>(0, delims.split({}).spans.size());
// Multi-role conversation, no leading/trailing content
// No delimiters match -> no spans
assert_equals<size_t>(0, delims.split({ 100, 101, 102 }).spans.size());
// Multi-role conversation: <user>Hi<assistant>Hello<user>Bye
{
const std::string prompt = "<|user|>Hi<|assistant|>Hello<|user|>Bye";
const auto splits = common_chat_split_by_role(prompt, {
{ "user", "<|user|>" },
{ "assistant", "<|assistant|>" },
});
assert_equals<size_t>(3, splits.size());
const llama_tokens tokens = {
10, 11, // <user>
100, 101, // Hi
10, 12, // <assistant>
200, 201, 202, // Hello
10, 11, // <user>
300, 301, // Bye
};
assert_equals<std::string>("user", splits[0].role);
assert_equals<size_t>(0, splits[0].pos);
assert_equals<size_t>(10, splits[0].len);
assert_equals<std::string>("<|user|>Hi", prompt.substr(splits[0].pos, splits[0].len));
const auto result = delims.split(tokens);
const auto & spans = result.spans;
assert_equals<size_t>(3, spans.size());
assert_equals<std::string>("assistant", splits[1].role);
assert_equals<size_t>(10, splits[1].pos);
assert_equals<size_t>(18, splits[1].len);
assert_equals<std::string>("<|assistant|>Hello", prompt.substr(splits[1].pos, splits[1].len));
assert_equals(COMMON_CHAT_ROLE_USER, spans[0].role);
assert_equals<size_t>(0, spans[0].pos);
assert_equals<size_t>(4, spans[0].len);
assert_equals<std::string>("user", splits[2].role);
assert_equals<size_t>(28, splits[2].pos);
assert_equals<size_t>(11, splits[2].len);
assert_equals<std::string>("<|user|>Bye", prompt.substr(splits[2].pos, splits[2].len));
assert_equals(COMMON_CHAT_ROLE_ASSISTANT, spans[1].role);
assert_equals<size_t>(4, spans[1].pos);
assert_equals<size_t>(5, spans[1].len);
assert_equals(COMMON_CHAT_ROLE_USER, spans[2].role);
assert_equals<size_t>(9, spans[2].pos);
assert_equals<size_t>(4, spans[2].len);
// is_user_start() is true at the token position where a user span begins
assert_equals(true, result.is_user_start(0));
assert_equals(false, result.is_user_start(4)); // assistant span
assert_equals(true, result.is_user_start(9));
}
// Content before the first delimiter is not captured as a span
{
const llama_tokens tokens = {
500, 501, // leading content (dropped)
10, 11, // <user>
100, // Hi
};
const auto spans = delims.split(tokens).spans;
assert_equals<size_t>(1, spans.size());
assert_equals(COMMON_CHAT_ROLE_USER, spans[0].role);
assert_equals<size_t>(2, spans[0].pos);
assert_equals<size_t>(3, spans[0].len);
}
// Skipped regions (media chunks) are jumped over but still count as span content
{
const llama_tokens tokens = {
10, 11, // <user>
LLAMA_TOKEN_NULL, // media chunk (3 tokens)
LLAMA_TOKEN_NULL,
LLAMA_TOKEN_NULL,
100, // Hi
10, 12, // <assistant>
};
const std::map<size_t, size_t> skips = { { 2, 3 } };
const auto spans = delims.split(tokens, skips).spans;
assert_equals<size_t>(2, spans.size());
assert_equals(COMMON_CHAT_ROLE_USER, spans[0].role);
assert_equals<size_t>(0, spans[0].pos);
assert_equals<size_t>(6, spans[0].len);
assert_equals(COMMON_CHAT_ROLE_ASSISTANT, spans[1].role);
assert_equals<size_t>(6, spans[1].pos);
assert_equals<size_t>(2, spans[1].len);
}
// A delimiter sequence inside a skipped region is not matched
{
const llama_tokens tokens = {
10, 11, // <user>
10, 12, // skipped region that happens to contain delimiter tokens
100, // Hi
};
const std::map<size_t, size_t> skips = { { 2, 2 } };
const auto spans = delims.split(tokens, skips).spans;
assert_equals<size_t>(1, spans.size());
assert_equals(COMMON_CHAT_ROLE_USER, spans[0].role);
assert_equals<size_t>(0, spans[0].pos);
assert_equals<size_t>(5, spans[0].len);
}
}
@@ -5857,7 +5932,7 @@ int main(int argc, char ** argv) {
{
test_msg_diffs_compute();
test_msgs_oaicompat_json_conversion();
test_split_by_role();
test_msg_token_delimiters_split();
test_tools_oaicompat_json_conversion();
test_convert_responses_to_chatcmpl();
test_developer_role_to_system_workaround();
+1 -1
View File
@@ -42,6 +42,7 @@
#define KEY_N_HEAD "clip.%s.attention.head_count"
#define KEY_N_HEAD_KV "clip.%s.attention.head_count_kv"
#define KEY_LAYER_NORM_EPS "clip.%s.attention.layer_norm_epsilon"
#define KEY_FEATURE_LAYERS "clip.%s.feature_layer"
// vision-specific
#define KEY_VISION_PROJ_TYPE "clip.vision.projector_type" // for models with mixed modalities
@@ -54,7 +55,6 @@
#define KEY_PATCH_SIZE "clip.vision.patch_size"
#define KEY_IMAGE_MEAN "clip.vision.image_mean"
#define KEY_IMAGE_STD "clip.vision.image_std"
#define KEY_FEATURE_LAYER "clip.vision.feature_layer"
#define KEY_PROJ_SCALE_FACTOR "clip.vision.projector.scale_factor"
#define KEY_PROJ_SAMPLE_QUERY_SIDE "clip.vision.projector.query_side"
#define KEY_PROJ_SAMPLE_WINDOW_SIDE "clip.vision.projector.window_side"
+3 -3
View File
@@ -91,7 +91,7 @@ struct clip_hparams {
float eps = 1e-6;
float rope_theta = 0.0;
std::vector<int32_t> vision_feature_layer;
std::vector<int32_t> feature_layers;
int32_t attn_window_size = 0;
int32_t n_wa_pattern = 0;
std::unordered_set<int32_t> wa_layer_indexes; // explicit layer indexes that use full attention (for irregular patterns like YoutuVL)
@@ -165,8 +165,8 @@ struct clip_hparams {
return false;
}
bool is_vision_feature_layer(int32_t layer) const {
return std::find(vision_feature_layer.begin(), vision_feature_layer.end(), layer) != vision_feature_layer.end();
bool is_feature_layer(int32_t layer) const {
return std::find(feature_layers.begin(), feature_layers.end(), layer) != feature_layers.end();
}
};
+9 -10
View File
@@ -1264,12 +1264,10 @@ struct clip_model_loader {
}
}
// Load the vision feature layer indices if they are explicitly provided;
// if multiple vision feature layers are present, the values will be concatenated
// to form the final visual features.
// Load the vision/audio feature layer indices if they are explicitly provided
// NOTE: gguf conversions should standardize the values of the vision feature layer to
// be non-negative, since we use -1 to mark values as unset here.
get_arr_int(KEY_FEATURE_LAYER, hparams.vision_feature_layer, false);
get_arr_int(string_format(KEY_FEATURE_LAYERS, prefix), hparams.feature_layers, false);
// model-specific params
switch (model.proj_type) {
@@ -1651,6 +1649,7 @@ struct clip_model_loader {
get_u32(KEY_A_PROJ_WINDOW_SIZE, hparams.audio_proj_window_size);
get_u32(KEY_A_PROJ_DOWNSAMPLE_RATE, hparams.audio_proj_downsample_rate);
get_u32(KEY_A_PROJ_HEAD_COUNT, hparams.audio_proj_head_count);
// NOTE: feature layers loaded above in common path
} break;
case PROJECTOR_TYPE_JANUS_PRO:
{
@@ -1663,11 +1662,11 @@ struct clip_model_loader {
hparams.image_resize_algo = RESIZE_ALGO_BICUBIC_PILLOW;
hparams.image_resize_pad = PAD_CEIL;
get_arr_int(KEY_FEATURE_LAYER, hparams.vision_feature_layer);
// NOTE: feature_layers loaded in common path as optional
get_arr_int(KEY_PROJ_SPATIAL_OFFSETS, hparams.proj_spatial_offsets);
if (hparams.vision_feature_layer.size() != hparams.proj_spatial_offsets.size()) {
throw std::runtime_error(string_format("%s: vision_feature_layer.size() %d != proj_spatial_offsets.size() %d",
hparams.vision_feature_layer.size(), hparams.proj_spatial_offsets.size()));
if (hparams.feature_layers.size() != hparams.proj_spatial_offsets.size()) {
throw std::runtime_error(string_format("%s: feature_layers.size() %d != proj_spatial_offsets.size() %d",
hparams.feature_layers.size(), hparams.proj_spatial_offsets.size()));
}
get_u32(KEY_PROJ_SAMPLE_QUERY_SIDE, hparams.downsample_query_side);
@@ -2740,7 +2739,7 @@ struct clip_model_loader {
model.image_newline = get_tensor(TN_IMAGE_NEWLINE);
// Load separate layerwise and spatial projector tensors
const auto projector_count = hparams.vision_feature_layer.size();
const auto projector_count = hparams.feature_layers.size();
model.qf_proj_blocks.resize(projector_count);
for (size_t bid = 0; bid < projector_count; ++bid) {
auto & b = model.qf_proj_blocks[bid];
@@ -4388,7 +4387,7 @@ bool clip_image_batch_encode(clip_ctx * ctx, int n_threads, const clip_image_f32
// Stage 1b only uses block 0's permutations; future stages
// will upload all blocks.
for (size_t bid = 0; bid < hparams.vision_feature_layer.size(); ++bid) {
for (size_t bid = 0; bid < hparams.feature_layers.size(); ++bid) {
const std::string prefix = "g4v_blk" + std::to_string(bid) + "_";
upload(prefix + "win_idx", make_win_idx(image_side, window_side));
upload(prefix + "qwin_idx", make_win_idx(new_side, query_side));
+35 -1
View File
@@ -1,5 +1,7 @@
#include "models.h"
#include <algorithm>
ggml_cgraph * clip_graph_granite_speech::build() {
const int n_frames = img.nx();
const int context_size = hparams.audio_chunk_size;
@@ -11,6 +13,10 @@ ggml_cgraph * clip_graph_granite_speech::build() {
const int padded_len = num_blocks * context_size;
const int remainder = n_frames % context_size;
// Calculate projector input dimension based on feature layers
const int proj_input_dim = n_embd * (hparams.feature_layers.size() + 1);
const bool use_feature_concat = !hparams.feature_layers.empty();
ggml_tensor * attn_dists = ggml_new_tensor_1d(ctx0, GGML_TYPE_I32, context_size * context_size);
ggml_set_name(attn_dists, "attn_dists");
ggml_set_input(attn_dists);
@@ -31,6 +37,15 @@ ggml_cgraph * clip_graph_granite_speech::build() {
cur = ggml_add(ctx0, cur, model.inp_proj_b);
cb(cur, "inp_linear", -1);
// Capture layer 0 if requested (after input_linear)
ggml_tensor * concat_result = nullptr;
if (use_feature_concat) {
if (std::find(hparams.feature_layers.begin(), hparams.feature_layers.end(), 0) != hparams.feature_layers.end()) {
concat_result = cur;
cb(concat_result, "feature_layer_0", -1);
}
}
for (int il = 0; il < n_layer; il++) {
const auto & layer = model.layers[il];
auto * residual = cur;
@@ -168,6 +183,18 @@ ggml_cgraph * clip_graph_granite_speech::build() {
NORM_TYPE_NORMAL, eps, il);
cb(cur, "layer_out", il);
// Capture intermediate layer (il + 1) if requested
if (use_feature_concat) {
if (hparams.is_feature_layer(il + 1)) {
if (concat_result == nullptr) {
concat_result = cur;
} else {
concat_result = ggml_concat(ctx0, concat_result, cur, 0);
}
cb(concat_result, string_format("feature_layer_%d", il + 1).c_str(), il);
}
}
// CTC branch
if (il + 1 == ctc_layer) {
auto * mid = build_mm(model.ctc_out_w, cur);
@@ -180,6 +207,13 @@ ggml_cgraph * clip_graph_granite_speech::build() {
}
}
// Append final output to concatenated features if using feature concatenation
if (use_feature_concat && concat_result != nullptr) {
concat_result = ggml_concat(ctx0, concat_result, cur, 0);
cb(concat_result, "concat_final", -1);
cur = concat_result;
}
cb(cur, "encoder_out", -1);
// QFormer projector
@@ -197,7 +231,7 @@ ggml_cgraph * clip_graph_granite_speech::build() {
cur = ggml_pad(ctx0, cur, 0, padded_proj - n_frames, 0, 0);
}
ggml_tensor * enc_windows = ggml_reshape_3d(ctx0, cur, n_embd, window_size, nblocks_proj);
ggml_tensor * enc_windows = ggml_reshape_3d(ctx0, cur, proj_input_dim, window_size, nblocks_proj);
ggml_tensor * queries = build_norm(model.qf_proj_blocks[0].qf_proj_query,
model.qf_proj_blocks[0].qf_proj_norm_w, model.qf_proj_blocks[0].qf_proj_norm_b,
+2 -2
View File
@@ -304,14 +304,14 @@ ggml_cgraph * clip_graph_granite4_vision::build() {
}
// --- Stage 1b/1c: WindowQFormer blocks ---
const int projector_count = hparams.vision_feature_layer.size();
const int projector_count = hparams.feature_layers.size();
const float qformer_eps = 1e-12f;
ggml_tensor * mmproj = nullptr;
for (int bid = 0; bid < projector_count; ++bid) {
const auto & blk = model.qf_proj_blocks[bid];
int vlayer = hparams.vision_feature_layer[bid];
int vlayer = hparams.feature_layers[bid];
GGML_ASSERT(vlayer >= 0 && vlayer < n_layer);
ggml_tensor * h = layer_outs[vlayer];
+3 -3
View File
@@ -21,7 +21,7 @@ ggml_cgraph * clip_graph_llava::build() {
// If we set explicit vision feature layers, only go up to the deepest one
// NOTE: only used by granite-vision models for now
for (const auto & feature_layer : hparams.vision_feature_layer) {
for (const auto & feature_layer : hparams.feature_layers) {
if (feature_layer > deepest_feature_layer) {
deepest_feature_layer = feature_layer;
}
@@ -59,7 +59,7 @@ ggml_cgraph * clip_graph_llava::build() {
// If this is an embedding feature layer, save the output.
// NOTE: 0 index here refers to the input to the encoder.
if (hparams.is_vision_feature_layer(il)) {
if (hparams.is_feature_layer(il)) {
embedding_stack.push_back(cur);
}
@@ -134,7 +134,7 @@ ggml_cgraph * clip_graph_llava::build() {
// process vision feature layers (used by granite)
{
// final layer is a vision feature layer
if (hparams.is_vision_feature_layer(max_feature_layer)) {
if (hparams.is_feature_layer(max_feature_layer)) {
embedding_stack.push_back(inpL);
}
+9 -9
View File
@@ -518,6 +518,14 @@ size_t server_tokens::get_common_prefix(const server_tokens & b) const {
return max_idx; // all tokens are equal
}
common_chat_msg_spans server_tokens::find_message_spans(const common_chat_msg_delimiters & delims) const {
std::map<size_t, size_t> skips;
for (const auto & it : map_idx_to_media) {
skips[it.first] = mtmd_input_chunk_get_n_tokens(it.second.get());
}
return delims.split(tokens, skips);
}
bool server_tokens::validate(const struct llama_context * ctx) const {
const llama_model * model = llama_get_model(ctx);
const llama_vocab * vocab = llama_model_get_vocab(model);
@@ -1104,15 +1112,7 @@ json oaicompat_chat_params_parse(
llama_params["chat_parser"] = chat_params.parser;
}
llama_params["message_spans"] = json::array();
for (const auto & span : chat_params.message_spans) {
llama_params["message_spans"].push_back({
{ "role", span.role },
{ "pos", span.pos },
{ "len", span.len },
});
}
llama_params["message_delimiters"] = chat_params.message_delimiters.to_json();
// Reasoning budget: pass parameters through to sampling layer
{
+3
View File
@@ -218,6 +218,9 @@ public:
size_t get_common_prefix(const server_tokens & b) const;
// split the tokens into message spans, skipping over media chunks
common_chat_msg_spans find_message_spans(const common_chat_msg_delimiters & delims) const;
// make sure all text tokens are within the vocab range
bool validate(const struct llama_context * ctx) const;
+26 -92
View File
@@ -89,7 +89,9 @@ struct server_batch {
}
~server_batch() {
llama_batch_free(batch);
if (batch.token != nullptr) {
llama_batch_free(batch);
}
}
void init(int32_t n_tokens_alloc) {
@@ -1215,6 +1217,10 @@ private:
cparams.ctx_other = ctx_tgt;
ctx_dft.reset(llama_init_from_model(model_dft.get(), cparams));
if (ctx_dft == nullptr) {
SRV_ERR("%s", "failed to create draft context\n");
return false;
}
params_base.speculative.draft.ctx_tgt = ctx_tgt;
params_base.speculative.draft.ctx_dft = ctx_dft.get();
@@ -3436,8 +3442,8 @@ private:
has_mtmd = true;
}
const int32_t n_before_user = slot.task->params.n_before_user;
const bool n_before_user_known = n_before_user > 0;
const auto & spans = slot.task->params.message_spans;
const auto last_user_pos = spans.last_user_message_pos();
// add prompt tokens for processing in the current batch
while (slot.prompt.n_tokens() < slot.task->n_tokens() && batch.size() < n_batch) {
@@ -3466,10 +3472,8 @@ private:
slot.n_prompt_tokens_processed++;
// stop the prompt batch exactly before the latest user input, so a checkpoint
// can be created after the previous messages
if (n_before_user_known &&
slot.prompt.n_tokens() == n_before_user) {
// stop the prompt batch exactly before a user message
if (spans.is_user_start(slot.prompt.n_tokens())) {
break;
}
@@ -3498,8 +3502,13 @@ private:
// the number of tokens added to the batch for the current slot
const auto n_tokens_cur = batch.size() - n_tokens_prev;
const auto n_tokens_start = slot.prompt.n_tokens() - n_tokens_cur;
const bool near_prompt_end = slot.task->n_tokens() < slot.prompt.n_tokens() + n_ubatch;
const bool is_user_start = spans.is_user_start(n_tokens_start);
const bool is_last_user_message = n_tokens_start == last_user_pos;
// entire prompt has been processed
if (slot.prompt.n_tokens() == slot.task->n_tokens()) {
slot.state = SLOT_STATE_DONE_PROMPT;
@@ -3514,8 +3523,9 @@ private:
slot.init_sampler();
} else {
// skip ordinary mid-prompt checkpoints
if (!n_before_user_known && !near_prompt_end) {
// skip ordinary mid-prompt checkpoints, unless the batch starts a user
// message or we are near the end of the prompt
if (!is_user_start && !near_prompt_end) {
do_checkpoint = false;
}
}
@@ -3523,29 +3533,6 @@ private:
const auto pos_min = llama_memory_seq_pos_min(llama_get_memory(ctx_tgt), slot.id);
const auto pos_max = llama_memory_seq_pos_max(llama_get_memory(ctx_tgt), slot.id);
// checkpoints are created before the current batch is decoded, so
// their token position is the batch start rather than the prompt end
const int32_t n_tokens_start = slot.prompt.n_tokens() - n_tokens_cur;
{
const bool is_on_user =
n_before_user_known &&
n_tokens_start == n_before_user;
const bool is_after_user =
n_before_user_known &&
n_tokens_start > n_before_user;
const bool is_allowed =
!n_before_user_known ||
is_on_user ||
(is_after_user && near_prompt_end);
if (do_checkpoint && !is_allowed) {
do_checkpoint = false;
}
}
// nothing to checkpoint yet
// TODO: is this check needed?
if (do_checkpoint && pos_min < 0) {
@@ -3555,8 +3542,8 @@ private:
// do not checkpoint after mtmd chunks
do_checkpoint = do_checkpoint && !has_mtmd;
// no need to create checkpoints that are too close together
do_checkpoint = do_checkpoint && (slot.prompt.checkpoints.empty() || n_tokens_start > slot.prompt.checkpoints.back().n_tokens + params_base.checkpoint_min_step);
// no need to create checkpoints that are too close together, unless it's the last user message
do_checkpoint = do_checkpoint && (slot.prompt.checkpoints.empty() || is_last_user_message || n_tokens_start > slot.prompt.checkpoints.back().n_tokens + params_base.checkpoint_min_step);
SLT_DBG(slot, "main/do_checkpoint = %s, pos_min = %d, pos_max = %d\n", do_checkpoint ? "yes" : "no", pos_min, pos_max);
// note: we create the checkpoint before calling llama_decode(), so the current batch is not
@@ -4055,54 +4042,6 @@ void server_context::set_state_callback(server_state_callback_t callback) {
});
}
// compute the number of tokens before the last user message in the prompt
static int32_t prompt_get_n_before_user(
const json & message_spans,
const std::string & prompt,
const std::vector<raw_buffer> & files,
const llama_vocab * vocab,
mtmd_context * mctx) {
int32_t result = -1;
int32_t byte_pos = -1;
for (const auto & span : message_spans) {
const std::string role = json_value(span, "role", std::string());
if (role == "user") {
byte_pos = json_value(span, "pos", -1);
}
}
if (byte_pos >= 0) {
GGML_ASSERT((size_t) byte_pos <= prompt.size());
const std::string prefix = prompt.substr(0, (size_t) byte_pos);
const std::string marker = get_media_marker();
size_t n_prefix_media = 0;
for (size_t pos = 0; (pos = prefix.find(marker, pos)) != std::string::npos; pos += marker.size()) {
n_prefix_media++;
}
GGML_ASSERT(n_prefix_media <= files.size());
if (mctx != nullptr && n_prefix_media > 0) {
// TODO: this makes a copy - avoid it
std::vector<raw_buffer> prefix_files(files.begin(), files.begin() + n_prefix_media);
result = (int32_t) process_mtmd_prompt(mctx, prefix, prefix_files).size();
} else {
result = (int32_t) tokenize_input_prompts(vocab, nullptr, prefix, true, true)[0].size();
}
SRV_TRC("message_spans: last user message: byte_pos=%d, media=%zu, n_before_user=%d\n",
byte_pos, n_prefix_media, result);
}
return result;
}
//
// server_routes
//
@@ -4150,6 +4089,10 @@ std::unique_ptr<server_res_generator> server_routes::handle_completions_impl(
// tasks.reserve(inputs.size()); // TODO: this is inaccurate due to child tasks
// message delimiters for checkpointing
auto delimiters = common_chat_msg_delimiters_parse(json_value(data, "message_delimiters", json::array()));
delimiters.tokenize(ctx_server.vocab);
for (size_t i = 0; i < inputs.size(); i++) {
server_task task = server_task(type);
@@ -4163,16 +4106,7 @@ std::unique_ptr<server_res_generator> server_routes::handle_completions_impl(
meta->logit_bias_eog,
data);
const auto message_spans = json_value(data, "message_spans", json::array());
if (prompt.is_string() && message_spans.is_array()) {
task.params.n_before_user =
prompt_get_n_before_user(
message_spans,
prompt.get<std::string>(),
files,
ctx_server.vocab,
ctx_server.mctx);
}
task.params.message_spans = task.tokens.find_message_spans(delimiters);
task.id_slot = json_value(data, "id_slot", -1);
+4 -2
View File
@@ -224,7 +224,7 @@ void server_model_meta::update_caps() {
});
params.offline = true;
// params.skip_download = true; // TODO: ideally, we should validate the model here, but it takes too much time
common_params_handle_models(params, LLAMA_EXAMPLE_SERVER);
common_params_handle_models(params, LLAMA_EXAMPLE_SERVER, {});
if (params.mmproj.path.empty()) {
multimodal = { false, false };
} else {
@@ -1393,7 +1393,9 @@ struct server_download_state : public common_download_callback {
bool run(common_params & params) {
try {
common_params_handle_models(params, LLAMA_EXAMPLE_SERVER, this);
common_params_handle_models_params p;
p.callback = this;
common_params_handle_models(params, LLAMA_EXAMPLE_SERVER, p);
is_ok = true;
} catch (const std::exception & e) {
auto model_name = params.model.get_name();
+6 -3
View File
@@ -591,10 +591,11 @@ json server_task_result_cmpl_final::to_json_oaicompat_resp() {
for (const common_chat_tool_call & tool_call : oaicompat_msg.tool_calls) {
output.push_back(json {
{"id", "fc_" + tool_call.id},
{"type", "function_call"},
{"status", "completed"},
{"arguments", tool_call.arguments},
{"call_id", "fc_" + tool_call.id},
{"call_id", "call_" + tool_call.id},
{"name", tool_call.name},
});
}
@@ -690,10 +691,11 @@ json server_task_result_cmpl_final::to_json_oaicompat_resp_stream() {
for (const common_chat_tool_call & tool_call : oaicompat_msg.tool_calls) {
const json output_item = {
{"id", "fc_" + tool_call.id},
{"type", "function_call"},
{"status", "completed"},
{"arguments", tool_call.arguments},
{"call_id", "fc_" + tool_call.id},
{"call_id", "call_" + tool_call.id},
{"name", tool_call.name}
};
server_sent_events.push_back(json {
@@ -1277,8 +1279,9 @@ json server_task_result_cmpl_partial::to_json_oaicompat_resp() {
{"data", json {
{"type", "response.output_item.added"},
{"item", json {
{"id", "fc_" + diff.tool_call_delta.id},
{"arguments", ""},
{"call_id", "fc_" + diff.tool_call_delta.id},
{"call_id", "call_" + diff.tool_call_delta.id},
{"name", diff.tool_call_delta.name},
{"type", "function_call"},
{"status", "in_progress"},
+3 -3
View File
@@ -62,9 +62,6 @@ struct task_params {
int32_t n_cache_reuse = 0; // min chunk size to attempt reusing from the cache via KV shifting (0 = disabled)
// number of prompt tokens before the latest user message
int32_t n_before_user = -1;
int64_t t_max_prompt_ms = -1; // TODO: implement
int64_t t_max_predict_ms = -1; // if positive, limit the generation phase to this time limit
@@ -92,6 +89,9 @@ struct task_params {
// per-request parameters for chat parsing
common_chat_parser_params chat_parser_params;
// message spans for checkpointing
common_chat_msg_spans message_spans;
// Embeddings
int32_t embd_normalize = 2; // (-1=none, 0=max absolute int16, 1=taxicab, 2=Euclidean/L2, >2=p-norm)
+12 -1
View File
@@ -89,6 +89,17 @@ int llama_server(int argc, char ** argv) {
llama_backend_init();
llama_numa_init(params.numa);
// note: router mode also accepts -hf remote-preset, so we need to check that first
if (!params.model.hf_repo.empty()) {
try {
common_params_handle_models_params handle_params;
handle_params.preset_only = true;
common_params_handle_models(params, LLAMA_EXAMPLE_SERVER, handle_params);
} catch (const std::exception & e) {
// ignored for now
}
}
// router server never loads a model and must not touch the GPU
const bool is_router_server = params.model.path.empty()
&& params.model.hf_repo.empty();
@@ -263,7 +274,7 @@ int llama_server(int argc, char ** argv) {
return child.run_download(params);
} else if (!is_router_server) {
// single-model mode (NOT spawned by router)
common_params_handle_models(params, LLAMA_EXAMPLE_SERVER);
common_params_handle_models(params, LLAMA_EXAMPLE_SERVER, {});
}
//
+19
View File
@@ -256,6 +256,25 @@ def test_router_reload_models():
os.remove(preset_path)
def test_router_remote_preset():
global server
server.model_hf_repo = "ggml-org/test-preset-ci"
server.model_hf_file = None
server.offline = False
server.start()
# Should see preset models in GET /models
res = server.make_request("GET", "/models")
assert res.status_code == 200
ids = {item["id"] for item in res.body.get("data", [])}
assert "tinygemma3-preset" in ids
assert "stories260K-test" in ids
# Should be able to load a preset model
model_id = "tinygemma3-preset"
_load_model_and_wait(model_id)
MODEL_DOWNLOAD_ID = "ggml-org/test-model-router-download:F16"
MODEL_DOWNLOAD_TIMEOUT = 30
+1 -2
View File
@@ -28,10 +28,9 @@ vite.config.ts.timestamp-*
# PWA Artifacts
apple-splash-*.png
apple-touch-icon-*.png
favicon.ico
favicon-dark.ico
maskable-icon-*.png
pwa-*.png
static/favicon*
# Storybook
*storybook.log
+7 -7
View File
@@ -35,7 +35,7 @@
"bits-ui": "2.18.1",
"clsx": "2.1.1",
"dexie": "4.4.3",
"dompurify": "3.4.5",
"dompurify": "3.4.11",
"eslint": "9.39.4",
"eslint-config-prettier": "10.1.8",
"eslint-plugin-storybook": "10.4.2",
@@ -8653,9 +8653,9 @@
"peer": true
},
"node_modules/dompurify": {
"version": "3.4.5",
"resolved": "https://registry.npmjs.org/dompurify/-/dompurify-3.4.5.tgz",
"integrity": "sha512-OrwIBKsdNSVEeubdJ1HBv/wNENRM9ytAVCv7YXt//A3vPdVMNuACRqK9mXCGCBW2ln7BT/A4X0jXHo2Gu89miA==",
"version": "3.4.11",
"resolved": "https://registry.npmjs.org/dompurify/-/dompurify-3.4.11.tgz",
"integrity": "sha512-zhlUV12GsaRzMsf9q5M254YhA4+VuF0fG+QFqu6aYpoGlKtz+w8//jBcGVYBgQkR5GHjUomejY84AV+/uPbWdw==",
"dev": true,
"license": "(MPL-2.0 OR Apache-2.0)",
"optionalDependencies": {
@@ -10226,9 +10226,9 @@
}
},
"node_modules/hono": {
"version": "4.12.23",
"resolved": "https://registry.npmjs.org/hono/-/hono-4.12.23.tgz",
"integrity": "sha512-eIaZ9qDgu7XV0pxOCrg7/WhnQ6Ivm22UcxhXx/A3dcbqbbYgBEkc6e/J/s7j2tS96zoB0S9VBdLwQNCWwUo4LA==",
"version": "4.12.26",
"resolved": "https://registry.npmjs.org/hono/-/hono-4.12.26.tgz",
"integrity": "sha512-uyZtpnYxM9CmQ7QsQknM4zN8EftNqhON1qYeIKM0Se67CCEe2c44xyGURwB0axX2fBDu1dqHrHAc1hmNT8ITkw==",
"dev": true,
"license": "MIT",
"engines": {
+1 -1
View File
@@ -54,7 +54,7 @@
"bits-ui": "2.18.1",
"clsx": "2.1.1",
"dexie": "4.4.3",
"dompurify": "3.4.5",
"dompurify": "3.4.11",
"eslint": "9.39.4",
"eslint-config-prettier": "10.1.8",
"eslint-plugin-storybook": "10.4.2",
+8 -1
View File
@@ -1,4 +1,10 @@
import { defineConfig } from '@vite-pwa/assets-generator/config';
import { FAVICON_COLORS, PWA_ASSET_GENERATOR } from './src/lib/constants/pwa';
import { writeThemeFavicons } from './scripts/favicon-colorize';
writeThemeFavicons(FAVICON_COLORS.LIGHT, FAVICON_COLORS.DARK, {
padding: PWA_ASSET_GENERATOR.FAVICON_PADDING
});
export default defineConfig({
headLinkOptions: {
@@ -7,7 +13,8 @@ export default defineConfig({
preset: {
transparent: {
sizes: [],
favicons: [[48, 'favicon-dark.ico']]
favicons: [[48, 'favicon-dark.ico']],
padding: PWA_ASSET_GENERATOR.FAVICON_PADDING
},
maskable: {
sizes: []
+19 -2
View File
@@ -5,15 +5,32 @@ import {
} from '@vite-pwa/assets-generator/config';
import { readFileSync } from 'node:fs';
import { resolve } from 'node:path';
import { THEME_COLORS, PWA_GENERATOR_DEVICES, PWA_ASSET_GENERATOR } from './src/lib/constants/pwa';
import {
THEME_COLORS,
PWA_GENERATOR_DEVICES,
PWA_ASSET_GENERATOR,
FAVICON_COLORS
} from './src/lib/constants/pwa';
import { SplashOrientation } from './src/lib/enums/splash.enums';
import { writeThemeFavicons } from './scripts/favicon-colorize';
writeThemeFavicons(FAVICON_COLORS.LIGHT, FAVICON_COLORS.DARK, {
padding: PWA_ASSET_GENERATOR.FAVICON_PADDING
});
export default defineConfig({
headLinkOptions: {
preset: PWA_ASSET_GENERATOR.LINK_PRESET
},
preset: combinePresetAndAppleSplashScreens(
minimal2023Preset,
{
...minimal2023Preset,
// tiny margin so favicon.ico / pwa-*.png breathe inside the canvas
transparent: {
...minimal2023Preset.transparent,
padding: PWA_ASSET_GENERATOR.FAVICON_PADDING
}
},
{
padding: PWA_ASSET_GENERATOR.SPLASH_PADDING,
resizeOptions: {
+107
View File
@@ -0,0 +1,107 @@
import { mkdirSync, readFileSync, writeFileSync } from 'node:fs';
import { dirname, resolve } from 'node:path';
import { fileURLToPath } from 'node:url';
const HERE = dirname(fileURLToPath(import.meta.url));
const PROJECT_ROOT = resolve(HERE, '..');
const DEFAULT_LOGO = resolve(PROJECT_ROOT, 'src/lib/assets/logo.svg');
const DEFAULT_OUT_DIR = resolve(PROJECT_ROOT, 'static');
const DEFAULT_OUT_LIGHT = resolve(DEFAULT_OUT_DIR, 'favicon.svg');
const DEFAULT_OUT_DARK = resolve(DEFAULT_OUT_DIR, 'favicon-dark.svg');
const CURRENT_COLOR = 'currentColor';
export interface ColorizedFavicon {
light: string;
dark: string;
}
export interface WriteThemeFaviconsOptions {
sourcePath?: string;
lightOutPath?: string;
darkOutPath?: string;
/**
* Fraction of the icon (0..1) to leave as an even margin on each side.
* Applied by wrapping the inner content in a `<g transform="...">` so the
* source `src/lib/assets/logo.svg` is not modified. Pass 0 to disable.
*/
padding?: number;
}
/**
* Replace every `currentColor` occurrence in the SVG with the given color.
* Pure: no filesystem access, so it is straightforward to unit-test.
*/
export function colorizeFaviconSvg(
svg: string,
lightColor: string,
darkColor: string
): ColorizedFavicon {
return {
light: svg.replaceAll(CURRENT_COLOR, lightColor),
dark: svg.replaceAll(CURRENT_COLOR, darkColor)
};
}
/**
* Shrink the inner SVG content uniformly and re-center it so `padding` (a
* 0..1 fraction) is reserved as equal margin on each side. Returns the input
* unchanged for non-positive padding, missing/invalid `viewBox`, or unexpected
* markup so the caller always gets a renderable SVG.
*/
export function padFaviconSvg(svg: string, padding: number): string {
if (!(padding > 0) || padding >= 1) return svg;
const viewBoxMatch = svg.match(/viewBox\s*=\s*["']([^"']+)["']/i);
if (!viewBoxMatch) return svg;
const parts = viewBoxMatch[1]
.trim()
.split(/[\s,]+/)
.map(Number);
if (parts.length !== 4 || parts.some((n) => !Number.isFinite(n))) return svg;
const [, , width, height] = parts;
if (width <= 0 || height <= 0) return svg;
const scale = 1 - padding;
const translateX = (padding * width) / 2;
const translateY = (padding * height) / 2;
const openTagStart = svg.search(/<svg\b/i);
if (openTagStart === -1) return svg;
const openTagEnd = svg.indexOf('>', openTagStart);
if (openTagEnd === -1) return svg;
const closeStart = svg.lastIndexOf('</svg');
if (closeStart === -1 || closeStart <= openTagEnd) return svg;
const openTag = svg.slice(0, openTagEnd + 1);
const inner = svg.slice(openTagEnd + 1, closeStart);
const closeTag = svg.slice(closeStart);
const group = `<g transform="translate(${translateX} ${translateY}) scale(${scale})">`;
return `${openTag}${group}${inner}</g>${closeTag}`;
}
/**
* Read `src/lib/assets/logo.svg`, colorize it for both themes, and write
* the results to the static directory so the PWA asset generator can consume
* them. Paths can be overridden for tests.
*/
export function writeThemeFavicons(
lightColor: string,
darkColor: string,
{
sourcePath = DEFAULT_LOGO,
lightOutPath = DEFAULT_OUT_LIGHT,
darkOutPath = DEFAULT_OUT_DARK,
padding = 0
}: WriteThemeFaviconsOptions = {}
): void {
const source = readFileSync(sourcePath, 'utf-8');
const { light, dark } = colorizeFaviconSvg(source, lightColor, darkColor);
mkdirSync(dirname(lightOutPath), { recursive: true });
writeFileSync(lightOutPath, padFaviconSvg(light, padding));
writeFileSync(darkOutPath, padFaviconSvg(dark, padding));
}
+6 -1
View File
@@ -48,6 +48,7 @@
--chat-form-area-height: 8rem;
--chat-form-area-offset: 2rem;
--chat-form-padding-top: 6rem;
--max-message-height: max(24rem, min(80dvh, calc(100dvh - var(--chat-form-area-height) - 12rem)));
}
@@ -55,6 +56,7 @@
:root {
--chat-form-area-height: 24rem;
--chat-form-area-offset: 12rem;
--chat-form-padding-top: 6rem;
}
}
@@ -141,7 +143,6 @@
@apply bg-background text-foreground;
scrollbar-width: thin;
scrollbar-gutter: stable;
overflow: hidden; /* Added due to Mermaid rendering somehow causing the double scrollbar */
}
/* Global scrollbar styling - visible only on hover */
@@ -193,3 +194,7 @@
scrollbar-width: none;
}
}
.mermaidTooltip {
display: none !important;
}
@@ -10,9 +10,9 @@ import { isElementInViewport } from '$lib/utils/viewport';
*/
export function fadeInView(
node: HTMLElement,
options: { duration?: number; y?: number; skipIfVisible?: boolean } = {}
options: { duration?: number; y?: number; delay?: number; skipIfVisible?: boolean } = {}
) {
const { duration = 300, y = 0, skipIfVisible = false } = options;
const { duration = 300, y = 0, delay = 0, skipIfVisible = false } = options;
if (skipIfVisible && isElementInViewport(node)) {
return;
@@ -27,10 +27,12 @@ export function fadeInView(
(entries) => {
for (const entry of entries) {
if (entry.isIntersecting) {
requestAnimationFrame(() => {
node.style.opacity = '1';
node.style.transform = 'translateY(0)';
});
setTimeout(() => {
requestAnimationFrame(() => {
node.style.opacity = '1';
node.style.transform = 'translateY(0)';
});
}, delay);
observer.disconnect();
}
}
+7
View File
@@ -0,0 +1,7 @@
<svg width="512" height="512" viewBox="0 0 512 512" fill="none" xmlns="http://www.w3.org/2000/svg">
<path d="M244.95 8C215.233 8 187.774 23.8591 172.923 49.5999L95.6009 183.625C60.2162 244.959 104.481 321.6 175.29 321.6H208L316.977 132.708C348.959 77.2719 308.95 8 244.95 8ZM208 321.6H351.947C415.982 321.6 456.013 390.91 424.013 446.377C409.155 472.132 381.681 488 351.947 488H271.29C200.481 488 156.216 411.359 191.601 350.026L208 321.6Z" fill="currentColor"/>
<path d="M208 321.6H16L106.462 164.8L208 321.6Z" fill="currentColor"/>
<path d="M388.923 8L208 321.6L253.6 8H388.923Z" fill="currentColor"/>
<path d="M304 488H112L202.462 331.2L304 488Z" fill="currentColor"/>
<path d="M496 321.6H208L419.399 454.4L496 321.6Z" fill="currentColor"/>
</svg>

After

Width:  |  Height:  |  Size: 771 B

@@ -8,12 +8,13 @@
ariaLabel?: string;
class?: string;
disabled?: boolean;
href?: string;
icon: Component;
iconSize?: string;
onclick: (e?: MouseEvent) => void;
onclick?: (e?: MouseEvent) => void;
size?: ButtonSize;
stopPropagationOnClick?: boolean;
tooltip: string;
tooltip?: string;
variant?: ButtonVariant;
tooltipSide?: TooltipSide;
}
@@ -22,6 +23,7 @@
icon,
tooltip,
variant = 'ghost',
href = '',
size = 'sm',
class: className = '',
disabled = false,
@@ -31,34 +33,49 @@
onclick,
ariaLabel
}: Props = $props();
let innerWidth = $state(0);
const showTooltip = $derived(!!tooltip && innerWidth > 768);
</script>
<Tooltip.Root>
<Tooltip.Trigger>
<!-- prevent another nested button element -->
{#snippet child({ props })}
<Button
{...props}
{variant}
{size}
{disabled}
onclick={(e: MouseEvent) => {
if (stopPropagationOnClick) e.stopPropagation();
{#snippet button(props = {})}
<Button
{...props}
{href}
{variant}
{size}
{disabled}
onclick={(e: MouseEvent) => {
if (stopPropagationOnClick) e.stopPropagation();
onclick?.(e);
}}
class="h-6 w-6 p-0 {className} flex hover:bg-transparent data-[state=open]:bg-transparent!"
aria-label={ariaLabel || tooltip}
>
{#if icon}
{@const IconComponent = icon}
<IconComponent class={iconSize} />
{/if}
</Button>
{/snippet}
</Tooltip.Trigger>
onclick?.(e);
}}
class="h-6 w-6 p-0 {className} flex hover:bg-transparent data-[state=open]:bg-transparent!"
aria-label={ariaLabel || tooltip}
>
{#if icon}
{@const IconComponent = icon}
<Tooltip.Content side={tooltipSide}>
<p>{tooltip}</p>
</Tooltip.Content>
</Tooltip.Root>
<IconComponent class={iconSize} />
{/if}
</Button>
{/snippet}
{#if showTooltip}
<Tooltip.Root>
<Tooltip.Trigger>
<!-- prevent another nested button element -->
{#snippet child({ props })}
{@render button(props)}
{/snippet}
</Tooltip.Trigger>
<Tooltip.Content side={tooltipSide}>
<p>{tooltip}</p>
</Tooltip.Content>
</Tooltip.Root>
{:else}
{@render button({ href })}
{/if}
<svelte:window bind:innerWidth />
@@ -494,7 +494,7 @@
/>
<div
class="{INPUT_CLASSES} overflow-hidden rounded-3xl backdrop-blur-md {disabled
class="{INPUT_CLASSES} overflow-hidden rounded-4xl md:rounded-3xl backdrop-blur-md {disabled
? 'cursor-not-allowed opacity-60'
: ''}"
data-slot="input-area"
@@ -510,7 +510,7 @@
/>
<div
class="flex-column relative min-h-[48px] items-center rounded-3xl py-2 pb-2.25 shadow-sm transition-all focus-within:shadow-md md:!py-3"
class="flex-column relative min-h-12 items-center rounded-4xl md:rounded-3xl py-2 pb-2.25 shadow-sm transition-all focus-within:shadow-md md:py-3!"
onpaste={handlePaste}
>
<ChatFormTextarea
@@ -15,7 +15,7 @@
<Tooltip.Root>
<Tooltip.Trigger class="w-full">
<Button
class="file-upload-button h-8 w-8 rounded-full p-0"
class="file-upload-button md:h-8 md:w-8 h-9 w-9 rounded-full p-0"
{disabled}
{onclick}
variant="secondary"
@@ -15,6 +15,7 @@
import { McpLogo } from '$lib/components/app';
import { PencilRuler, ChevronDown, ChevronRight } from '@lucide/svelte';
import { HealthCheckStatus } from '$lib/enums';
import { AttachmentAction } from '$lib/enums/attachment.enums';
interface Props {
class?: string;
@@ -270,14 +271,22 @@
</Collapsible.Root>
{/if}
<button type="button" class={sheetItemClass} onclick={onSystemPromptClick}>
<button
type="button"
class={sheetItemClass}
onclick={() => attachmentMenu.callbacks[AttachmentAction.SYSTEM_PROMPT_CLICK]()}
>
<MessageSquare class="h-4 w-4 shrink-0" />
<span>System Message</span>
</button>
{#if hasMcpPromptsSupport}
<button type="button" class={sheetItemClass} onclick={onMcpPromptClick}>
<button
type="button"
class={sheetItemClass}
onclick={() => attachmentMenu.callbacks[AttachmentAction.MCP_PROMPT_CLICK]()}
>
<Zap class="h-4 w-4 shrink-0" />
<span>MCP Prompt</span>
@@ -285,7 +294,11 @@
{/if}
{#if hasMcpResourcesSupport}
<button type="button" class={sheetItemClass} onclick={onMcpResourcesClick}>
<button
type="button"
class={sheetItemClass}
onclick={() => attachmentMenu.callbacks[AttachmentAction.MCP_RESOURCES_CLICK]()}
>
<FolderOpen class="h-4 w-4 shrink-0" />
<span>MCP Resources</span>
@@ -42,6 +42,7 @@
{hasMcpPromptsSupport}
{hasMcpResourcesSupport}
{onFileUpload}
{onSystemPromptClick}
{onMcpPromptClick}
{onMcpResourcesClick}
>
@@ -20,7 +20,7 @@
type="submit"
disabled={isDisabled}
class={[
'h-8 w-8 rounded-full p-0',
'md:h-8 md:w-8 h-9 w-9 rounded-full p-0',
showErrorState &&
'bg-red-400/10 text-red-400 hover:bg-red-400/20 hover:text-red-400 disabled:opacity-100'
]}
@@ -1,4 +1,5 @@
<script lang="ts">
import { isMobile } from '$lib/stores/viewport.svelte';
import { autoResizeTextarea } from '$lib/utils';
import { onMount } from 'svelte';
@@ -37,7 +38,9 @@
}
export function focus() {
textareaElement?.focus();
if (isMobile.current) return;
textareaElement?.focus({ preventScroll: true });
}
export function resetHeight() {
@@ -231,7 +231,7 @@
editedContent = message.content;
}
textareaElement?.focus();
textareaElement?.focus({ preventScroll: true });
editedExtras = message.extra ? [...message.extra] : [];
editedUploadedFiles = [];
@@ -324,7 +324,7 @@
}
</script>
<div use:fadeInView>
<div use:fadeInView class="chat-message">
{#if message.role === MessageRole.SYSTEM}
<ChatMessageSystem
bind:textareaElement
@@ -180,6 +180,9 @@
let displayedModel = $derived(message.model ?? null);
// model being switched to while it loads, so the selector bar tracks it
let pendingModel = $state<string | null>(null);
let isCurrentlyLoading = $derived(isLoading());
let isStreaming = $derived(isChatStreaming());
let hasNoContent = $derived(!message?.content?.trim());
@@ -207,6 +210,42 @@
isLastAssistantMessage
);
let assistantEl: HTMLDivElement | undefined = $state();
let lastUserMessageHeight = $state(0);
let assistantMarginTop = $state(0);
$effect(() => {
if (!assistantEl) return;
assistantMarginTop = Math.round(parseFloat(getComputedStyle(assistantEl).marginTop));
const chatMessageEl = assistantEl.closest('.chat-message');
const previousChatMessage = chatMessageEl?.previousElementSibling;
const userMessageEl = previousChatMessage?.querySelector(
'.chat-message-user'
) as HTMLElement | null;
if (!userMessageEl) {
lastUserMessageHeight = 0;
return;
}
const updateHeight = () => {
const rect = userMessageEl.getBoundingClientRect();
const marginTop = Math.round(parseFloat(getComputedStyle(userMessageEl).marginTop));
lastUserMessageHeight = Math.round(rect.height + marginTop);
};
updateHeight();
const resizeObserver = new ResizeObserver(updateHeight);
resizeObserver.observe(userMessageEl);
return () => {
resizeObserver.disconnect();
};
});
function handleCopyModel() {
void copyToClipboard(displayedModel ?? '');
}
@@ -219,12 +258,17 @@
</script>
<div
class="text-md group w-full leading-7.5 {className}"
bind:this={assistantEl}
class="chat-message-assistant text-md group w-full leading-7.5 {className}"
style:--last-user-message-height={lastUserMessageHeight > 0
? `${lastUserMessageHeight}px`
: undefined}
style:--assistant-margin-top={assistantMarginTop > 0 ? `${assistantMarginTop}px` : undefined}
role="group"
aria-label="Assistant message with actions"
>
{#if showProcessingInfoTop}
<div class="mt-6 w-full max-w-[48rem]" in:fade>
<div class="mt-6 w-full max-w-3xl" in:fade>
<div class="processing-container">
<span class="processing-text">
{modelLoadingText ??
@@ -257,7 +301,7 @@
{/if}
{#if showProcessingInfoBottom}
<div class="mt-4 w-full max-w-[48rem]" in:fade>
<div class="mt-4 w-full max-w-3xl" in:fade>
<div class="processing-container">
<span class="processing-text">
{modelLoadingText ??
@@ -277,13 +321,19 @@
>
{#if isRouter}
<ModelsSelectorDropdown
currentModel={displayedModel}
currentModel={pendingModel ?? displayedModel}
disabled={isLoading()}
onModelChange={async (modelId: string, modelName: string) => {
const status = modelsStore.getModelStatus(modelId);
if (status !== ServerModelStatus.LOADED) {
await modelsStore.loadModel(modelId);
pendingModel = modelId;
try {
await modelsStore.loadModel(modelId);
} finally {
pendingModel = null;
}
}
onRegenerate(modelName);
@@ -351,6 +401,23 @@
</div>
<style>
:global(.chat-message):last-child .chat-message-assistant {
--assistant-min-height-offset: calc(
var(--last-user-message-height, 19rem) + var(--chat-form-height, 6rem) +
var(--chat-form-bottom-position, 0.5rem) + var(--chat-form-padding-top, 6rem) +
var(--assistant-margin-top, 3rem)
);
min-height: calc(100dvh - var(--assistant-min-height-offset));
@media (width > 768px) {
--assistant-min-height-offset: calc(
var(--last-user-message-height, 18rem) + var(--chat-form-height, 6rem) +
var(--chat-form-bottom-position, 1rem) + var(--chat-form-padding-top, 6rem) +
var(--assistant-margin-top, 3rem)
);
}
}
.processing-container {
display: flex;
flex-direction: column;
@@ -48,7 +48,7 @@
<div
aria-label="User message with actions"
class="group flex flex-col items-end gap-3 md:gap-2 {className}"
class="chat-message-user group flex flex-col items-end gap-3 md:gap-2 {className}"
role="group"
>
{#if editCtx.isEditing}
@@ -19,7 +19,7 @@
renderMarkdown = false,
textColorClass = 'text-foreground',
cardBgClass = 'dark:bg-primary/15',
maxHeightStyle = 'max-height: var(--max-message-height);'
maxHeightStyle = ''
}: Props = $props();
let isMultiline = $state(false);
@@ -59,7 +59,7 @@
{#if content.trim()}
<Card
class="max-w-[80%] overflow-y-auto rounded-[1.125rem] border-none bg-primary/5 px-3.75 py-1.5 {textColorClass} backdrop-blur-md data-[multiline]:py-2.5 {cardBgClass}"
class="chat-message-user-bubble max-w-[80%] overflow-y-auto rounded-[1.125rem] border-none bg-primary/5 px-3.75 py-1.5 {textColorClass} backdrop-blur-md data-multiline:py-2.5 {cardBgClass}"
data-multiline={isMultiline ? '' : undefined}
style="{maxHeightStyle} overflow-wrap: anywhere; word-break: break-word;"
>
@@ -37,6 +37,7 @@
let allConversationMessages = $state<DatabaseMessage[]>([]);
let isVisible = $state(false);
let previousConversationId = $state<string | null>(null);
let previousRouteId = $state<string | null>(null);
const currentConfig = config();
@@ -157,8 +158,9 @@
});
});
beforeNavigate(() => {
beforeNavigate((navigation) => {
isVisible = false;
previousRouteId = navigation.from?.route.id ?? null;
});
afterNavigate(() => {
@@ -249,12 +251,13 @@
</script>
<div
class="transition-opacity delay-300 duration-500 ease-out
{isVisible ? 'opacity-100' : 'opacity-0'}"
class="transition-opacity duration-500 ease-out
{isVisible ? 'opacity-100' : 'opacity-0'}
{previousRouteId === '/(chat)/chat/[id]' ? '' : 'delay-300'}"
>
{#each displayMessages as { message, toolMessages, isLastAssistantMessage, siblingInfo } (message.id)}
<ChatMessage
class="mx-auto mt-12 w-full max-w-[48rem]"
class="mx-auto mt-12 w-full max-w-3xl"
{message}
{toolMessages}
{isLastAssistantMessage}
@@ -1,31 +1,28 @@
<script lang="ts">
import { Trash2 } from '@lucide/svelte';
import { afterNavigate } from '$app/navigation';
import { page } from '$app/state';
import {
ChatScreenForm,
ChatMessages,
ChatScreenDragOverlay,
ChatScreenProcessingInfo,
ChatScreenActionScrollDown,
DialogEmptyFileAlert,
DialogFileUploadError,
DialogChatError,
ServerLoadingSplash,
DialogConfirmation,
ChatScreenServerError
} from '$lib/components/app';
import { setProcessingInfoContext } from '$lib/contexts';
import { ErrorDialogType } from '$lib/enums';
import { createAutoScrollController } from '$lib/hooks/use-auto-scroll.svelte';
import { useChatScreenActiveModel } from '$lib/hooks/use-chat-screen-active-model.svelte';
import { useChatScreenDragAndDrop } from '$lib/hooks/use-chat-screen-drag-and-drop.svelte';
import { useChatScreenFileUpload } from '$lib/hooks/use-chat-screen-file-upload.svelte';
import { useChatScreenScroll } from '$lib/hooks/use-chat-screen-scroll.svelte';
import { useKeyboardShortcuts } from '$lib/hooks/use-keyboard-shortcuts.svelte';
import { device } from '$lib/stores/device.svelte';
import { isMobile } from '$lib/stores/viewport.svelte';
import {
chatStore,
errorDialog,
isLoading,
isChatStreaming,
isEditing,
getAddFilesHandler,
activeProcessingState
} from '$lib/stores/chat.svelte';
import {
@@ -34,138 +31,81 @@
activeConversation
} from '$lib/stores/conversations.svelte';
import { config } from '$lib/stores/settings.svelte';
import { serverLoading, serverError, isRouterMode } from '$lib/stores/server.svelte';
import { modelsStore, modelOptions, selectedModelId } from '$lib/stores/models.svelte';
import { isFileTypeSupported, filterFilesByModalities } from '$lib/utils';
import { parseFilesToMessageExtras, processFilesToChatUploaded } from '$lib/utils/browser-only';
import { onMount } from 'svelte';
import { serverLoading, serverError } from '$lib/stores/server.svelte';
import { parseFilesToMessageExtras } from '$lib/utils/browser-only';
import { onDestroy, onMount } from 'svelte';
import ChatScreenGreeting from './ChatScreenGreeting.svelte';
import ChatScreenActionScrollDown from './ChatScreenActionScrollDown.svelte';
import ChatScreenDialogsAndAlerts from './ChatScreenDialogsAndAlerts.svelte';
import { ROUTES } from '$lib/constants';
let { showCenteredEmpty = false } = $props();
const autoScroll = createAutoScrollController();
let disableAutoScroll = $derived(Boolean(config().disableAutoScroll));
let chatScrollContainer: HTMLDivElement | undefined = $state();
let dragCounter = $state(0);
let isDragOver = $state(false);
let showFileErrorDialog = $state(false);
let uploadedFiles = $state<ChatUploadedFile[]>([]);
let fileErrorData = $state<{
generallyUnsupported: File[];
modalityUnsupported: File[];
modalityReasons: Record<string, string>;
supportedTypes: string[];
}>({
generallyUnsupported: [],
modalityUnsupported: [],
modalityReasons: {},
supportedTypes: []
});
let showDeleteDialog = $state(false);
let showEmptyFileDialog = $state(false);
let processingInfoVisible = $state(false);
let emptyFileNames = $state<string[]>([]);
let initialMessage = $state('');
let isEmpty = $derived(
showCenteredEmpty && !activeConversation() && activeMessages().length === 0 && !isLoading()
);
let activeErrorDialog = $derived(errorDialog());
let isServerLoading = $derived(serverLoading());
let hasPropsError = $derived(!!serverError());
let isCurrentConversationLoading = $derived(isLoading() || isChatStreaming());
let showProcessingInfo = $derived(
isCurrentConversationLoading ||
(config().keepStatsVisible && !!page.params.id) ||
activeProcessingState() !== null
);
let isRouter = $derived(isRouterMode());
let conversationModel = $derived(
chatStore.getConversationModel(activeMessages() as DatabaseMessage[])
);
let activeModelId = $derived.by(() => {
const options = modelOptions();
if (!isRouter) {
return options.length > 0 ? options[0].model : null;
}
const selectedId = selectedModelId();
if (selectedId) {
const model = options.find((m) => m.id === selectedId);
if (model) return model.model;
}
if (conversationModel) {
const model = options.find((m) => m.model === conversationModel);
if (model) return model.model;
}
return null;
});
let modelPropsVersion = $state(0);
setProcessingInfoContext({
get showProcessingInfo() {
return showProcessingInfo;
}
});
$effect(() => {
if (activeModelId) {
const cached = modelsStore.getModelProps(activeModelId);
let disableAutoScroll = $derived(Boolean(config().disableAutoScroll) || isMobile.current);
let isMobileUserScrolledUp = $state(false);
let mobileScrollDownHint = $state(false);
let mobileScrollDownHintLockedUntil = $state(0);
let emptyFileNames = $state<string[]>([]);
let initialMessage = $state('');
let showDeleteDialog = $state(false);
let showEmptyFileDialog = $state(false);
let isEmpty = $derived(
showCenteredEmpty && !activeConversation() && activeMessages().length === 0 && !isLoading()
);
let activeErrorDialog = $derived(errorDialog());
let isServerLoading = $derived(serverLoading());
let hasPropsError = $derived(!!serverError());
let isCurrentConversationLoading = $derived(isLoading() || isChatStreaming());
let showProcessingInfo = $derived(
isCurrentConversationLoading ||
(config().keepStatsVisible && !!page.params.id) ||
activeProcessingState() !== null
);
let chatFormBottomPosition = $derived.by(() => {
if (!isMobile.current) return '1rem';
if (device.isStandalone) return '1.5rem';
if (device.isIOSSafari) return '0.25rem';
return '0.5rem';
});
if (!cached) {
modelsStore.fetchModelProps(activeModelId).then(() => {
modelPropsVersion++;
});
const autoScroll = createAutoScrollController();
const scroll = useChatScreenScroll(autoScroll);
const activeModel = useChatScreenActiveModel();
const fileUpload = useChatScreenFileUpload({
capabilities: () => ({
hasVision: activeModel.hasVisionModality,
hasAudio: activeModel.hasAudioModality,
hasVideo: activeModel.hasVideoModality
}),
activeModelId: () => activeModel.activeModelId
});
const dragAndDrop = useChatScreenDragAndDrop({
onDrop: fileUpload.handleFileUpload
});
const { handleKeydown } = useKeyboardShortcuts({
deleteActiveConversation: () => {
if (activeConversation()) {
showDeleteDialog = true;
}
}
});
let hasAudioModality = $derived.by(() => {
if (activeModelId) {
void modelPropsVersion;
function handleMobileScroll() {
if (!isMobile.current) return;
return modelsStore.modelSupportsAudio(activeModelId);
}
const container = scroll.chatScrollContainer;
if (!container) return;
return false;
});
let hasVideoModality = $derived.by(() => {
if (activeModelId) {
void modelPropsVersion;
return modelsStore.modelSupportsVideo(activeModelId);
}
return false;
});
let hasVisionModality = $derived.by(() => {
if (activeModelId) {
void modelPropsVersion;
return modelsStore.modelSupportsVision(activeModelId);
}
return false;
});
const distanceFromBottom =
container.scrollHeight - container.clientHeight - container.scrollTop;
isMobileUserScrolledUp = distanceFromBottom > 300;
}
async function handleDeleteConfirm() {
const conversation = activeConversation();
@@ -177,27 +117,69 @@
showDeleteDialog = false;
}
function handleProcessingInfoVisibility(visible: boolean) {
processingInfoVisible = visible;
}
async function handleSendMessage(message: string, files?: ChatUploadedFile[]): Promise<boolean> {
const plainFiles = files ? $state.snapshot(files) : undefined;
const result = plainFiles
? await parseFilesToMessageExtras(plainFiles, activeModel.activeModelId ?? undefined)
: undefined;
function handleDragEnter(event: DragEvent) {
event.preventDefault();
dragCounter++;
if (event.dataTransfer?.types.includes('Files')) {
isDragOver = true;
if (result?.emptyFiles && result.emptyFiles.length > 0) {
emptyFileNames = result.emptyFiles;
showEmptyFileDialog = true;
if (files) {
const emptyFileNamesSet = new Set(result.emptyFiles);
fileUpload.uploadedFiles = fileUpload.uploadedFiles.filter(
(file) => !emptyFileNamesSet.has(file.name)
);
}
return false;
}
handleSendLikeScroll();
await chatStore.sendMessage(message, result?.extras);
return true;
}
function handleDragLeave(event: DragEvent) {
event.preventDefault();
function handleSendLikeScroll() {
if (!isMobile.current) {
autoScroll.enable();
}
dragCounter--;
setTimeout(() => {
const container = scroll.chatScrollContainer;
if (!container) return;
if (dragCounter === 0) {
isDragOver = false;
const lastUserBubble = container.querySelector(
'.chat-message:nth-last-child(2) .chat-message-user .chat-message-user-bubble'
) as HTMLElement | null;
if (isMobile.current) {
// Keep the last user message bubble just above the input on mobile
const bubbleHeight = lastUserBubble?.scrollHeight ?? 0;
const baseHeight = container.scrollHeight - innerHeight;
container.scrollTo({
top: bubbleHeight > 0 ? baseHeight - bubbleHeight : baseHeight,
behavior: 'smooth'
});
} else if (lastUserBubble) {
// On desktop, place the last user message near the top of the viewport
const topPadding = 24;
const bubbleRect = lastUserBubble.getBoundingClientRect();
container.scrollTo({
top: Math.max(0, container.scrollTop + bubbleRect.top - topPadding),
behavior: 'smooth'
});
} else {
autoScroll.scrollToBottom();
}
}, 100);
if (isMobile.current) {
autoScroll.setDisabled(disableAutoScroll);
mobileScrollDownHint = true;
mobileScrollDownHintLockedUntil = Date.now() + 500;
}
}
@@ -207,273 +189,138 @@
}
}
function handleDragOver(event: DragEvent) {
event.preventDefault();
}
function handleDrop(event: DragEvent) {
event.preventDefault();
isDragOver = false;
dragCounter = 0;
if (event.dataTransfer?.files) {
const files = Array.from(event.dataTransfer.files);
if (isEditing()) {
const handler = getAddFilesHandler();
if (handler) {
handler(files);
return;
}
}
processFiles(files);
}
}
function handleFileRemove(fileId: string) {
uploadedFiles = uploadedFiles.filter((f) => f.id !== fileId);
}
function handleFileUpload(files: File[]) {
processFiles(files);
}
const { handleKeydown } = useKeyboardShortcuts({
deleteActiveConversation: () => {
if (activeConversation()) {
showDeleteDialog = true;
}
}
});
async function handleSystemPromptAdd(draft: { message: string; files: ChatUploadedFile[] }) {
if (draft.message || draft.files.length > 0) {
chatStore.savePendingDraft(draft.message, draft.files);
}
await chatStore.addSystemPrompt();
}
function handleScroll() {
autoScroll.handleScroll();
}
async function handleSendMessage(message: string, files?: ChatUploadedFile[]): Promise<boolean> {
const plainFiles = files ? $state.snapshot(files) : undefined;
const result = plainFiles
? await parseFilesToMessageExtras(plainFiles, activeModelId ?? undefined)
: undefined;
if (result?.emptyFiles && result.emptyFiles.length > 0) {
emptyFileNames = result.emptyFiles;
showEmptyFileDialog = true;
if (files) {
const emptyFileNamesSet = new Set(result.emptyFiles);
uploadedFiles = uploadedFiles.filter((file) => !emptyFileNamesSet.has(file.name));
}
return false;
}
const extras = result?.extras;
// Enable autoscroll for user-initiated message sending
autoScroll.enable();
await chatStore.sendMessage(message, extras);
autoScroll.scrollToBottom();
return true;
}
async function processFiles(files: File[]) {
const generallySupported: File[] = [];
const generallyUnsupported: File[] = [];
for (const file of files) {
if (isFileTypeSupported(file.name, file.type)) {
generallySupported.push(file);
} else {
generallyUnsupported.push(file);
}
}
// Use model-specific capabilities for file validation
const capabilities = {
hasVision: hasVisionModality,
hasAudio: hasAudioModality,
hasVideo: hasVideoModality
};
const { supportedFiles, unsupportedFiles, modalityReasons } = filterFilesByModalities(
generallySupported,
capabilities
);
const allUnsupportedFiles = [...generallyUnsupported, ...unsupportedFiles];
if (allUnsupportedFiles.length > 0) {
const supportedTypes: string[] = ['text files', 'PDFs'];
if (hasVisionModality) supportedTypes.push('images');
if (hasAudioModality) supportedTypes.push('audio files');
if (hasVideoModality) supportedTypes.push('video files');
fileErrorData = {
generallyUnsupported,
modalityUnsupported: unsupportedFiles,
modalityReasons,
supportedTypes
};
showFileErrorDialog = true;
}
if (supportedFiles.length > 0) {
const processed = await processFilesToChatUploaded(
supportedFiles,
activeModelId ?? undefined
);
uploadedFiles = [...uploadedFiles, ...processed];
}
}
afterNavigate(() => {
if (!disableAutoScroll) {
$effect(() => {
const shouldDisableAutoScroll =
config().disableAutoScroll || (isMobile.current && isCurrentConversationLoading);
autoScroll.setDisabled(shouldDisableAutoScroll);
if (!shouldDisableAutoScroll) {
autoScroll.enable();
}
});
function handleMessagesReady() {
if (disableAutoScroll) return;
if (!autoScroll.userScrolledUp) {
requestAnimationFrame(() => {
autoScroll.scrollToBottom('instant');
});
}
}
onMount(() => {
const pendingDraft = chatStore.consumePendingDraft();
if (pendingDraft) {
initialMessage = pendingDraft.message;
fileUpload.uploadedFiles = pendingDraft.files;
}
autoScroll.startObserving();
if (!disableAutoScroll) {
autoScroll.enable();
}
const pendingDraft = chatStore.consumePendingDraft();
if (pendingDraft) {
initialMessage = pendingDraft.message;
uploadedFiles = pendingDraft.files;
if (isMobile.current && isCurrentConversationLoading) {
mobileScrollDownHint = true;
mobileScrollDownHintLockedUntil = Date.now() + 500;
}
handleMobileScroll();
});
$effect(() => {
autoScroll.setContainer(chatScrollContainer);
});
$effect(() => {
autoScroll.setDisabled(disableAutoScroll);
});
onDestroy(() => autoScroll.destroy());
</script>
{#if isDragOver}
{#if dragAndDrop.isDragOver}
<ChatScreenDragOverlay />
{/if}
<svelte:window onkeydown={handleKeydown} />
<svelte:window
onkeydown={handleKeydown}
onscroll={(e) => {
scroll.handleScroll(e);
handleMobileScroll();
if (e.isTrusted && Date.now() > mobileScrollDownHintLockedUntil) {
mobileScrollDownHint = false;
}
}}
/>
{#if isServerLoading}
<ServerLoadingSplash />
{:else}
<div
bind:this={chatScrollContainer}
aria-label="Chat interface with file drop zone"
class="flex h-full flex-col overflow-y-auto px-4 md:px-6"
ondragenter={handleDragEnter}
ondragleave={handleDragLeave}
ondragover={handleDragOver}
ondrop={handleDrop}
onscroll={handleScroll}
class="chat-screen flex grow flex-col min-h-[calc(100dvh-1rem)] md:min-h-full px-4 md:py-0 pt-12 pb-48 md:pb-4"
style:--chat-form-bottom-position={chatFormBottomPosition}
ondragenter={dragAndDrop.dragHandlers.dragenter}
ondragleave={dragAndDrop.dragHandlers.dragleave}
ondragover={dragAndDrop.dragHandlers.dragover}
ondrop={dragAndDrop.dragHandlers.drop}
role="main"
>
<div class="flex grow flex-col pt-14">
{#if !isEmpty}
<ChatMessages
messages={activeMessages()}
onMessagesReady={handleMessagesReady}
onUserAction={() => {
autoScroll.enable();
if (!autoScroll.userScrolledUp) {
autoScroll.scrollToBottom();
}
}}
/>
{/if}
{#if !isEmpty}
<ChatMessages
messages={activeMessages()}
onUserAction={() => {
handleSendLikeScroll();
}}
/>
{/if}
<div
class={[
'pointer-events-none sticky right-4 left-4 mt-auto transition-all duration-200',
isEmpty ? 'bottom-[calc(50dvh-7rem)]' : 'bottom-4 pt-24 md:pt-32'
]}
>
<ChatScreenGreeting {isEmpty} />
<div
class={[
'pointer-events-none md:sticky fixed mt-auto transition-all duration-200',
device.isStandalone
? 'bottom-6 right-4 left-4'
: device.isIOSSafari
? 'bottom-1 left-2 right-2'
: 'bottom-2 right-2 left-2',
isEmpty ? 'md:bottom-[calc(50dvh-7rem)] 2xl:bottom-[calc(50dvh-4rem)]' : 'md:bottom-4'
]}
style:padding-top={!isEmpty ? 'var(--chat-form-padding-top)' : undefined}
>
<ChatScreenGreeting {isEmpty} />
<ChatScreenActionScrollDown
container={chatScrollContainer}
hasProcessingInfoVisible={processingInfoVisible}
/>
<ChatScreenServerError />
<ChatScreenProcessingInfo onVisibilityChange={handleProcessingInfoVisibility} />
<ChatScreenServerError />
<div class="conversation-chat-form pointer-events-auto rounded-t-3xl">
<ChatScreenForm
disabled={hasPropsError || isEditing()}
{initialMessage}
isLoading={isCurrentConversationLoading}
onFileRemove={handleFileRemove}
onFileUpload={handleFileUpload}
onSend={handleSendMessage}
onStop={() => chatStore.stopGeneration()}
onSystemPromptAdd={handleSystemPromptAdd}
bind:uploadedFiles
<div class="pointer-events-none flex flex-col gap-6 items-center w-full">
{#if (isMobile.current ? mobileScrollDownHint || isMobileUserScrolledUp : autoScroll.userScrolledUp) && page.url.hash.includes(ROUTES.CHAT) && page.params.id}
<ChatScreenActionScrollDown
onclick={() => {
mobileScrollDownHint = false;
scroll.chatScrollContainer?.scrollTo({
top: scroll.chatScrollContainer.scrollHeight,
behavior: 'smooth'
});
}}
/>
</div>
{/if}
{#if showProcessingInfo}
<ChatScreenProcessingInfo />
{/if}
</div>
<ChatScreenForm
class="pointer-events-auto conversation-chat-form"
disabled={hasPropsError || isEditing()}
{initialMessage}
isLoading={isCurrentConversationLoading}
onFileRemove={fileUpload.handleFileRemove}
onFileUpload={fileUpload.handleFileUpload}
onSend={handleSendMessage}
onStop={() => chatStore.stopGeneration()}
onSystemPromptAdd={handleSystemPromptAdd}
bind:uploadedFiles={fileUpload.uploadedFiles}
/>
</div>
</div>
{/if}
<DialogFileUploadError bind:open={showFileErrorDialog} {fileErrorData} />
<DialogConfirmation
bind:open={showDeleteDialog}
title="Delete Conversation"
description="Are you sure you want to delete this conversation? This action cannot be undone and will permanently remove all messages in this conversation."
confirmText="Delete"
cancelText="Cancel"
variant="destructive"
icon={Trash2}
onConfirm={handleDeleteConfirm}
onCancel={() => (showDeleteDialog = false)}
/>
<DialogEmptyFileAlert
bind:open={showEmptyFileDialog}
emptyFiles={emptyFileNames}
onOpenChange={(open) => {
if (!open) {
emptyFileNames = [];
}
}}
/>
<DialogChatError
message={activeErrorDialog?.message ?? ''}
contextInfo={activeErrorDialog?.contextInfo}
onOpenChange={handleErrorDialogOpenChange}
open={Boolean(activeErrorDialog)}
type={activeErrorDialog?.type ?? ErrorDialogType.SERVER}
<ChatScreenDialogsAndAlerts
{showDeleteDialog}
{handleDeleteConfirm}
{showEmptyFileDialog}
{emptyFileNames}
{activeErrorDialog}
{handleErrorDialogOpenChange}
{fileUpload}
/>
@@ -1,58 +1,18 @@
<script lang="ts">
import { ArrowDown } from '@lucide/svelte';
import { Button } from '$lib/components/ui/button';
import ActionIcon from '$lib/components/app/actions/ActionIcon.svelte';
interface Props {
container: HTMLDivElement | undefined;
hasProcessingInfoVisible: boolean;
}
let { container, hasProcessingInfoVisible }: Props = $props();
let show = $state(false);
let buttonBottom = $derived(hasProcessingInfoVisible ? '2rem' : '0');
function checkVisibility() {
if (!container) return;
const { scrollTop, scrollHeight, clientHeight } = container;
const distanceFromBottom = scrollHeight - clientHeight - scrollTop;
show = distanceFromBottom > clientHeight * 0.5;
}
function scrollToBottom() {
if (container) {
container.scrollTo({
top: container.scrollHeight,
behavior: 'smooth'
});
}
}
$effect(() => {
const c = container;
if (c) {
c.addEventListener('scroll', checkVisibility);
checkVisibility();
return () => {
c.removeEventListener('scroll', checkVisibility);
};
}
});
let { onclick }: { onclick: (e?: MouseEvent) => void } = $props();
</script>
<div class="relative z-50 mx-auto mb-4 flex max-w-[48rem] justify-center">
<Button
onclick={scrollToBottom}
variant="secondary"
size="icon"
disabled={!show}
class="pointer-events-auto absolute h-10 w-10 rounded-full bg-background/80 shadow-lg backdrop-blur-sm transition-all duration-200 hover:bg-muted/80"
style="bottom: {buttonBottom}; transform: translateY({show ? '0' : '2rem'}); opacity: {show
? 1
: 0};"
aria-label="Scroll to bottom"
>
<ArrowDown class="h-4 w-4" />
</Button>
<div class="pointer-events-auto flex justify-center relative h-0">
<ActionIcon
icon={ArrowDown}
{onclick}
ariaLabel="Scroll to bottom"
tooltip="Scroll to bottom"
size="lg"
iconSize="h-4 w-4"
class="h-9 w-9 rounded-full bg-accent text-accent-foreground absolute bottom-4 shadow-md"
/>
</div>
@@ -0,0 +1,55 @@
<script lang="ts">
import { Trash2 } from '@lucide/svelte';
import { ErrorDialogType } from '$lib/enums';
import {
DialogChatError,
DialogConfirmation,
DialogEmptyFileAlert,
DialogFileUploadError
} from '$lib/components/app';
let {
showDeleteDialog,
handleDeleteConfirm,
showEmptyFileDialog,
emptyFileNames,
activeErrorDialog,
handleErrorDialogOpenChange,
fileUpload
} = $props();
</script>
<DialogFileUploadError
bind:open={fileUpload.showFileErrorDialog}
fileErrorData={fileUpload.fileErrorData}
/>
<DialogConfirmation
bind:open={showDeleteDialog}
title="Delete Conversation"
description="Are you sure you want to delete this conversation? This action cannot be undone and will permanently remove all messages in this conversation."
confirmText="Delete"
cancelText="Cancel"
variant="destructive"
icon={Trash2}
onConfirm={handleDeleteConfirm}
onCancel={() => (showDeleteDialog = false)}
/>
<DialogEmptyFileAlert
bind:open={showEmptyFileDialog}
emptyFiles={emptyFileNames}
onOpenChange={(open) => {
if (!open) {
emptyFileNames = [];
}
}}
/>
<DialogChatError
message={activeErrorDialog?.message ?? ''}
contextInfo={activeErrorDialog?.contextInfo}
onOpenChange={handleErrorDialogOpenChange}
open={Boolean(activeErrorDialog)}
type={activeErrorDialog?.type ?? ErrorDialogType.SERVER}
/>
@@ -2,6 +2,7 @@
import { afterNavigate } from '$app/navigation';
import { page } from '$app/state';
import { ChatForm } from '$lib/components/app';
import { isMobile } from '$lib/stores/viewport.svelte';
import { onMount } from 'svelte';
import { useDraftMessages } from '$lib/hooks/use-draft-messages.svelte';
@@ -32,7 +33,30 @@
}: Props = $props();
let chatFormRef: ChatForm | undefined = $state(undefined);
let formWrapperEl: HTMLDivElement | undefined = $state();
let chatId = $derived(page.params.id as string | undefined);
$effect(() => {
if (!formWrapperEl) return;
const formEl = formWrapperEl.querySelector('form') as HTMLElement | null;
if (!formEl) return;
const updateHeight = () => {
const height = Math.round(formEl.getBoundingClientRect().height);
document.documentElement.style.setProperty('--chat-form-height', `${height}px`);
};
updateHeight();
const resizeObserver = new ResizeObserver(updateHeight);
resizeObserver.observe(formEl);
return () => {
resizeObserver.disconnect();
document.documentElement.style.removeProperty('--chat-form-height');
};
});
let hasLoadingAttachments = $derived(uploadedFiles.some((f) => f.isLoading));
let message = $derived(initialMessage);
let previousIsLoading = $derived(isLoading);
@@ -83,12 +107,14 @@
}
onMount(() => {
setTimeout(() => chatFormRef?.focus(), 10);
if (!isMobile.current) {
setTimeout(() => chatFormRef?.focus(), 100);
}
});
afterNavigate((navigation) => {
if (navigation?.from != null) {
setTimeout(() => chatFormRef?.focus(), 10);
if (navigation?.from != null && !isMobile.current) {
setTimeout(() => chatFormRef?.focus(), 100);
}
});
@@ -108,12 +134,12 @@
});
</script>
<div class="relative mx-auto max-w-[48rem]">
<div class="chat-screen-form-wrapper" bind:this={formWrapperEl}>
<ChatForm
class="mx-auto max-w-3xl {className}"
bind:this={chatFormRef}
bind:value={message}
bind:uploadedFiles
class={className}
{disabled}
{isLoading}
showMcpPromptButton
@@ -1,5 +1,4 @@
<script lang="ts">
import { fadeInView } from '$lib/actions/fade-in-view.svelte';
import { serverStore } from '$lib/stores/server.svelte';
interface Props {
@@ -11,10 +10,9 @@
<div
class={[
'pointer-events-none mb-4 hidden px-4 text-center',
isEmpty && 'pointer-events-auto block!'
'pointer-events-none mb-4 hidden px-4 text-center text-balance',
isEmpty && 'mb-[calc(50dvh-8rem)] md:mb-6 pointer-events-auto block!'
]}
use:fadeInView={{ duration: 300 }}
>
<h1 class="mb-2 text-2xl font-semibold tracking-tight md:text-3xl">Hello there</h1>
@@ -5,13 +5,8 @@
import { chatStore, isLoading, isChatStreaming } from '$lib/stores/chat.svelte';
import { activeMessages, activeConversation } from '$lib/stores/conversations.svelte';
import { config } from '$lib/stores/settings.svelte';
import { getProcessingInfoContext } from '$lib/contexts';
import { page } from '$app/state';
const processingState = useProcessingState();
const processingInfoCtx = getProcessingInfoContext();
let showProcessingInfo = $derived(processingInfoCtx.showProcessingInfo);
let isCurrentConversationLoading = $derived(isLoading());
let isStreaming = $derived(isChatStreaming());
@@ -70,8 +65,8 @@
<div
class={[
'chat-processing-info-container pointer-events-none relative',
page.params.id && showProcessingInfo && 'visible'
'chat-processing-info-container pointer-events-none relative w-full hidden md:block',
processingVisible && 'visible'
]}
>
<div class="chat-processing-info-content absolute bottom-4 left-1/2 -translate-x-1/2">
@@ -677,13 +677,6 @@ export { default as ChatScreenForm } from './ChatScreen/ChatScreenForm.svelte';
*/
export { default as ChatScreenProcessingInfo } from './ChatScreen/ChatScreenProcessingInfo.svelte';
/**
* Scroll-to-bottom action button. Displays a floating button when the user
* has scrolled up more than half a viewport height from the bottom.
* Takes the chat container element as a prop to manage scroll state internally.
*/
export { default as ChatScreenActionScrollDown } from './ChatScreen/ChatScreenActionScrollDown.svelte';
/**
* Server error alert displayed when the server is unreachable.
* Shows the error message with a retry button.
@@ -3,6 +3,7 @@
import { Search, X } from '@lucide/svelte';
interface Props {
autofocus?: boolean;
value?: string;
placeholder?: string;
onInput?: (value: string) => void;
@@ -15,6 +16,7 @@
}
let {
autofocus,
value = $bindable(''),
placeholder = 'Search...',
onInput,
@@ -39,7 +41,7 @@
if (value) {
value = '';
onInput?.('');
ref?.focus();
ref?.focus({ preventScroll: true });
} else {
onClose?.();
}
@@ -52,6 +54,7 @@
/>
<Input
{autofocus}
{id}
bind:value
bind:ref
@@ -0,0 +1,15 @@
<script>
import logoMark from '$lib/assets/logo.svg?raw';
let { class: className = '', style = '' } = $props();
</script>
<div class={className} {style}>
{@html logoMark}
</div>
<style>
div :global(svg) {
width: var(--size, 1rem);
height: var(--size, 1rem);
}
</style>
@@ -51,3 +51,11 @@ export { default as KeyboardShortcutInfo } from './KeyboardShortcutInfo.svelte';
* Preview button is shown only for HTML code blocks.
*/
export { default as CodeBlockActions } from './CodeBlockActions.svelte';
/**
* **Logo** - Application brand mark
*
* Inline SVG of the application logo. Accepts styling via the standard
* `class` and `style` props and inherits color via `currentColor`.
*/
export { default as Logo } from './Logo.svelte';
@@ -0,0 +1,11 @@
<script lang="ts">
let { percent }: { percent: number } = $props();
</script>
<!-- thin determinate load bar pinned to the bottom edge, pulsing while it fills -->
<div class="pointer-events-none absolute inset-x-0 bottom-0 h-0.5 overflow-hidden rounded-b-sm">
<div
class="h-full animate-pulse bg-primary transition-[width] duration-200 ease-out"
style="width: {percent}%"
></div>
</div>
@@ -2,8 +2,10 @@
import { ChevronDown, Loader2, Package } from '@lucide/svelte';
import * as DropdownMenu from '$lib/components/ui/dropdown-menu';
import * as Tooltip from '$lib/components/ui/tooltip';
import { KeyboardKey } from '$lib/enums';
import { KeyboardKey, ServerModelStatus } from '$lib/enums';
import { useModelsSelector } from '$lib/hooks/use-models-selector.svelte';
import { modelsStore, routerModels } from '$lib/stores/models.svelte';
import { modelLoadFraction } from '$lib/utils';
import {
DialogModelInformation,
DropdownMenuSearchable,
@@ -11,6 +13,7 @@
ModelsSelectorList,
ModelsSelectorOption
} from '$lib/components/app';
import ModelLoadHighlight from './ModelLoadHighlight.svelte';
import type { ModelItem } from './utils';
interface Props {
@@ -113,6 +116,17 @@
{/if}
{:else}
{@const selectedOption = ms.getDisplayOption()}
{@const triggerModel = selectedOption?.model}
{@const triggerStatus = triggerModel
? routerModels().find((m) => m.id === triggerModel)?.status?.value
: undefined}
{@const triggerLoading =
!!triggerModel &&
(triggerStatus === ServerModelStatus.LOADING ||
modelsStore.isModelOperationInProgress(triggerModel))}
{@const triggerLoadPercent = triggerLoading
? Math.round(modelLoadFraction(modelsStore.getLoadProgress(triggerModel)) * 100)
: 0}
{#if ms.isRouter}
<DropdownMenu.Root bind:open={isOpen} onOpenChange={ms.handleOpenChange}>
@@ -123,7 +137,7 @@
<DropdownMenu.Trigger
{...props}
class={[
`inline-grid cursor-pointer grid-cols-[1fr_auto_1fr] items-center gap-1.5 rounded-sm bg-background px-1.5 py-1 text-xs shadow-sm transition hover:bg-muted-foreground/20 focus:outline-none focus-visible:ring-2 focus-visible:ring-ring focus-visible:ring-offset-2 disabled:cursor-not-allowed disabled:opacity-60 dark:bg-muted-foreground/15 dark:text-secondary-foreground`,
`relative inline-grid cursor-pointer grid-cols-[1fr_auto_1fr] items-center gap-1.5 rounded-sm bg-background px-1.5 py-1 text-xs shadow-sm transition hover:bg-muted-foreground/20 focus:outline-none focus-visible:ring-2 focus-visible:ring-ring focus-visible:ring-offset-2 disabled:cursor-not-allowed disabled:opacity-60 dark:bg-muted-foreground/15 dark:text-secondary-foreground`,
!ms.isCurrentModelInCache
? 'bg-red-400/10 !text-red-400 hover:bg-red-400/20 hover:text-red-400'
: forceForegroundText
@@ -154,6 +168,10 @@
{:else}
<ChevronDown class="h-3 w-3.5 shrink-0" />
{/if}
{#if triggerLoading}
<ModelLoadHighlight percent={triggerLoadPercent} />
{/if}
</DropdownMenu.Trigger>
{/snippet}
</Tooltip.Trigger>
@@ -10,6 +10,7 @@
RotateCw
} from '@lucide/svelte';
import { ActionIcon, ModelId } from '$lib/components/app';
import ModelLoadHighlight from './ModelLoadHighlight.svelte';
import type { ModelOption } from '$lib/types/models';
import { ServerModelStatus } from '$lib/enums';
import { modelsStore, routerModels } from '$lib/stores/models.svelte';
@@ -119,11 +120,11 @@
</div>
{#if isLoading}
<div class="flex w-4 [@media(pointer:coarse)]:w-5 items-center justify-center">
<div class="flex w-4 items-center justify-center [@media(pointer:coarse)]:w-5">
<Loader2 class="h-4 w-4 animate-spin text-muted-foreground" />
</div>
{:else if isFailed}
<div class="flex w-4 [@media(pointer:coarse)]:w-auto items-center justify-center">
<div class="flex w-4 items-center justify-center [@media(pointer:coarse)]:w-auto">
<CircleAlert
class="h-3.5 w-3.5 text-red-500 group-hover:hidden [@media(pointer:coarse)]:hidden"
/>
@@ -140,7 +141,7 @@
</div>
</div>
{:else if isSleeping}
<div class="flex w-4 [@media(pointer:coarse)]:w-auto items-center justify-center">
<div class="flex w-4 items-center justify-center [@media(pointer:coarse)]:w-auto">
<span
class="h-2 w-2 rounded-full bg-orange-400 group-hover:hidden [@media(pointer:coarse)]:hidden"
></span>
@@ -159,7 +160,7 @@
</div>
</div>
{:else if isLoaded}
<div class="flex w-4 [@media(pointer:coarse)]:w-auto items-center justify-center">
<div class="flex w-4 items-center justify-center [@media(pointer:coarse)]:w-auto">
<span
class="h-2 w-2 rounded-full bg-green-500 group-hover:hidden [@media(pointer:coarse)]:hidden"
></span>
@@ -176,7 +177,7 @@
</div>
</div>
{:else}
<div class="flex w-4 [@media(pointer:coarse)]:w-auto items-center justify-center">
<div class="flex w-4 items-center justify-center [@media(pointer:coarse)]:w-auto">
<span
class="h-2 w-2 rounded-full bg-muted-foreground/50 group-hover:hidden [@media(pointer:coarse)]:hidden"
></span>
@@ -196,13 +197,6 @@
</div>
{#if isLoading}
<div
class="pointer-events-none absolute inset-x-0 bottom-0 h-0.5 overflow-hidden rounded-b-sm bg-muted"
>
<div
class="h-full bg-primary transition-[width] duration-200 ease-out"
style="width: {loadPercent}%"
></div>
</div>
<ModelLoadHighlight percent={loadPercent} />
{/if}
</div>
@@ -8,6 +8,10 @@
ModelsSelectorList,
SearchInput
} from '$lib/components/app';
import ModelLoadHighlight from './ModelLoadHighlight.svelte';
import { ServerModelStatus } from '$lib/enums';
import { modelsStore, routerModels } from '$lib/stores/models.svelte';
import { modelLoadFraction } from '$lib/utils';
interface Props {
class?: string;
@@ -61,12 +65,23 @@
<p class="text-xs text-muted-foreground">No models available.</p>
{:else}
{@const selectedOption = ms.getDisplayOption()}
{@const triggerModel = selectedOption?.model}
{@const triggerStatus = triggerModel
? routerModels().find((m) => m.id === triggerModel)?.status?.value
: undefined}
{@const triggerLoading =
!!triggerModel &&
(triggerStatus === ServerModelStatus.LOADING ||
modelsStore.isModelOperationInProgress(triggerModel))}
{@const triggerLoadPercent = triggerLoading
? Math.round(modelLoadFraction(modelsStore.getLoadProgress(triggerModel)) * 100)
: 0}
{#if ms.isRouter}
<button
type="button"
class={[
`inline-flex cursor-pointer items-center gap-1.5 rounded-sm bg-background px-1.5 py-1 max-sm:px-3 max-sm:py-2 text-xs max-sm:text-sm shadow-sm transition hover:bg-muted-foreground/20 focus:outline-none focus-visible:ring-2 focus-visible:ring-ring focus-visible:ring-offset-2 disabled:cursor-not-allowed disabled:opacity-60 dark:bg-muted-foreground/15 dark:text-secondary-foreground`,
`relative inline-flex cursor-pointer items-center gap-1.5 rounded-sm bg-background px-1.5 py-1 text-xs shadow-sm transition hover:bg-muted-foreground/20 focus:outline-none focus-visible:ring-2 focus-visible:ring-ring focus-visible:ring-offset-2 disabled:cursor-not-allowed disabled:opacity-60 max-sm:px-3 max-sm:py-2 max-sm:text-sm dark:bg-muted-foreground/15 dark:text-secondary-foreground`,
!ms.isCurrentModelInCache
? 'bg-red-400/10 !text-red-400 hover:bg-red-400/20 hover:text-red-400'
: forceForegroundText
@@ -99,6 +114,10 @@
{:else}
<ChevronDown class="h-3 w-3.5 shrink-0" />
{/if}
{#if triggerLoading}
<ModelLoadHighlight percent={triggerLoadPercent} />
{/if}
</button>
<Sheet.Root bind:open={sheetOpen} onOpenChange={handleSheetOpenChange}>
@@ -1,84 +0,0 @@
<script lang="ts">
import { goto } from '$app/navigation';
import { page } from '$app/state';
import { ActionIcon } from '$lib/components/app';
import {
ICON_STRIP_TRANSITION_DURATION,
ICON_STRIP_TRANSITION_DELAY_MULTIPLIER,
SIDEBAR_ACTIONS_ITEMS
} from '$lib/constants';
import { TooltipSide } from '$lib/enums';
import { fade } from 'svelte/transition';
import { circIn } from 'svelte/easing';
import { onMount } from 'svelte';
import { useKeyboardShortcuts } from '$lib/hooks/use-keyboard-shortcuts.svelte';
interface Props {
sidebarOpen: boolean;
onSearchClick: () => void;
}
let { sidebarOpen = false, onSearchClick }: Props = $props();
const { handleKeydown } = useKeyboardShortcuts({ activateSearchMode: () => onSearchClick() });
let initialized = $state(false);
let showIcons = $derived(!sidebarOpen);
showIcons = false;
onMount(() => {
showIcons = !sidebarOpen;
setTimeout(() => {
initialized = true;
}, ICON_STRIP_TRANSITION_DELAY_MULTIPLIER * SIDEBAR_ACTIONS_ITEMS.length);
});
</script>
<svelte:window onkeydown={handleKeydown} />
<div
class="hidden shrink-0 transition-[width] duration-200 ease-linear md:block {sidebarOpen
? 'w-0'
: 'w-[calc(var(--sidebar-width-icon)+1.5rem)]'}"
></div>
<aside
class="fixed top-0 bottom-0 left-0 z-10 hidden w-[calc(var(--sidebar-width-icon)+1.5rem)] flex-col items-center justify-between py-3 transition-opacity duration-200 ease-linear md:flex {sidebarOpen
? 'pointer-events-none opacity-0'
: 'opacity-100'}"
>
<div class="mt-12 flex flex-col items-center gap-1">
{#each SIDEBAR_ACTIONS_ITEMS as item, i (item.tooltip)}
{@const onclick = item.route ? () => goto(item.route!) : onSearchClick}
{@const isActive = item.activeRouteId
? page.route.id === item.activeRouteId
: item.activeRoutePrefix
? !!page.route.id?.startsWith(item.activeRoutePrefix)
: false}
{#if showIcons}
<div
in:fade={{
duration: ICON_STRIP_TRANSITION_DURATION,
delay: !initialized
? ICON_STRIP_TRANSITION_DELAY_MULTIPLIER + i * ICON_STRIP_TRANSITION_DELAY_MULTIPLIER
: 0,
easing: circIn
}}
>
<ActionIcon
icon={item.icon}
tooltip={item.tooltip}
tooltipSide={TooltipSide.RIGHT}
size="lg"
iconSize="h-4 w-4"
class="h-9 w-9 rounded-full hover:bg-accent! {isActive
? 'bg-accent text-accent-foreground'
: ''}"
{onclick}
/>
</div>
{/if}
{/each}
</div>
</aside>
@@ -1,40 +1,67 @@
<script lang="ts">
import { goto } from '$app/navigation';
import { page } from '$app/state';
import { Trash2, Pencil, Pin, X } from '@lucide/svelte';
import { Button } from '$lib/components/ui/button';
import { DialogConfirmation } from '$lib/components/app';
import SidebarNavigationActions from './SidebarNavigationActions.svelte';
import SidebarNavigationConversationItem from './SidebarNavigationConversationItem.svelte';
import { Checkbox } from '$lib/components/ui/checkbox';
import Label from '$lib/components/ui/label/label.svelte';
import ScrollArea from '$lib/components/ui/scroll-area/scroll-area.svelte';
import * as Sidebar from '$lib/components/ui/sidebar';
import Input from '$lib/components/ui/input/input.svelte';
import { ROUTES } from '$lib/constants/routes';
import { RouterService } from '$lib/services/router.service';
import { PanelLeftClose, PanelLeftOpen, X } from '@lucide/svelte';
import {
conversationsStore,
conversations,
buildConversationTree
} from '$lib/stores/conversations.svelte';
import { chatStore } from '$lib/stores/chat.svelte';
import { getPreviewText } from '$lib/utils';
import { APP_NAME } from '$lib/constants';
ActionIcon,
Logo,
SidebarNavigationConversationList,
SidebarNavigationActions
} from '$lib/components/app';
import { ROUTES } from '$lib/constants';
import { fade } from 'svelte/transition';
const sidebar = Sidebar.useSidebar();
import { useKeyboardShortcuts } from '$lib/hooks/use-keyboard-shortcuts.svelte';
import { conversationsStore, conversations } from '$lib/stores/conversations.svelte';
import { chatStore } from '$lib/stores/chat.svelte';
import { RouterService } from '$lib/services/router.service';
import { isMobile } from '$lib/stores/viewport.svelte';
import { TooltipSide } from '$lib/enums';
import { device } from '$lib/stores/device.svelte';
import { circIn } from 'svelte/easing';
interface Props {
onSearchClick?: () => void;
}
let { onSearchClick = () => {} }: Props = $props();
const { handleKeydown } = useKeyboardShortcuts({ activateSearchMode: () => onSearchClick() });
let isExpandedMode = $state(false);
let hoveredTooltip = $state<string | null>(null);
let logoHovered = $state(false);
const isStripExpanded = $derived(isExpandedMode || hoveredTooltip !== null);
const isOnMobile = $derived(isMobile.current);
function toggleExpandedMode() {
isExpandedMode = !isExpandedMode;
if (!isExpandedMode) {
hoveredTooltip = null;
}
}
$effect(() => {
if (!isExpandedMode) {
isSearchModeActive = false;
searchQuery = '';
cancelMobileCollapse();
}
});
// On mobile the dedicated /search route hides the sidebar (see the aside
// render guard below). Collapse it as we enter /search so it doesn't
// reappear expanded when the user navigates back via the back button.
$effect(() => {
if (isMobile.current && page.url.hash.includes(ROUTES.SEARCH)) {
isExpandedMode = false;
}
});
let currentChatId = $derived(page.params.id);
let isSearchModeActive = $state(false);
let searchQuery = $state('');
let showDeleteDialog = $state(false);
let deleteWithForks = $state(false);
let showEditDialog = $state(false);
let selectedConversation = $state<DatabaseConversation | null>(null);
let editedName = $state('');
let selectedConversationNamePreview = $derived.by(() =>
selectedConversation ? getPreviewText(selectedConversation.name) : ''
);
let filteredConversations = $derived.by(() => {
if (isSearchModeActive) {
@@ -50,294 +77,206 @@
return conversations();
});
let conversationTree = $derived(buildConversationTree(filteredConversations));
let pinnedConversations = $derived.by(() => {
return conversationTree.filter(({ conversation }) => conversation.pinned);
});
let unpinnedConversations = $derived.by(() => {
return conversationTree.filter(({ conversation }) => !conversation.pinned);
});
let selectedConversationHasDescendants = $derived.by(() => {
if (!selectedConversation) return false;
const allConvs = conversations();
const queue = [selectedConversation.id];
while (queue.length > 0) {
const parentId = queue.pop()!;
for (const c of allConvs) {
if (c.forkedFromConversationId === parentId) return true;
}
}
return false;
});
async function handleDeleteConversation(id: string) {
const conversation = conversations().find((conv) => conv.id === id);
if (conversation) {
selectedConversation = conversation;
deleteWithForks = false;
showDeleteDialog = true;
async function selectConversation(id: string) {
if (isMobile.current) {
scheduleMobileCollapse();
}
await goto(RouterService.chat(id));
}
async function handleEditConversation(id: string) {
const conversation = conversations().find((conv) => conv.id === id);
if (conversation) {
selectedConversation = conversation;
editedName = conversation.name;
showEditDialog = true;
if (!conversation) return;
const newName = window.prompt('Rename conversation', conversation.name);
if (newName && newName.trim()) {
await conversationsStore.updateConversationName(id, newName.trim());
}
}
function handleConfirmDelete() {
if (selectedConversation) {
const convId = selectedConversation.id;
const withForks = deleteWithForks;
showDeleteDialog = false;
async function handleDeleteConversation(id: string) {
const conversation = conversations().find((conv) => conv.id === id);
if (!conversation) return;
setTimeout(() => {
conversationsStore.deleteConversation(convId, {
deleteWithForks: withForks
});
}, 100); // Wait for animation to finish
}
}
const confirmed = window.confirm(
`Delete "${conversation.name}"? This action cannot be undone.`
);
if (!confirmed) return;
function handleConfirmEdit() {
if (!editedName.trim() || !selectedConversation) return;
showEditDialog = false;
conversationsStore.updateConversationName(selectedConversation.id, editedName);
selectedConversation = null;
}
export function handleMobileSidebarItemClick() {
if (sidebar.isMobile) {
sidebar.toggle();
}
}
let chatSidebarActions: { activateSearch?: () => void } | undefined = $state();
let openedForSearch = $state(false);
export function activateSearchMode() {
if (!sidebar.open) {
openedForSearch = true;
}
chatSidebarActions?.activateSearch?.();
}
function handleSearchDeactivated() {
if (openedForSearch) {
openedForSearch = false;
sidebar.toggle();
}
}
$effect(() => {
if (!sidebar.open) {
isSearchModeActive = false;
searchQuery = '';
openedForSearch = false;
}
});
export function editActiveConversation() {
if (currentChatId) {
const activeConversation = filteredConversations.find((conv) => conv.id === currentChatId);
if (activeConversation) {
const event = new CustomEvent('edit-active-conversation', {
detail: { conversationId: currentChatId }
});
document.dispatchEvent(event);
}
}
}
async function selectConversation(id: string) {
if (isSearchModeActive) {
isSearchModeActive = false;
searchQuery = '';
}
handleMobileSidebarItemClick();
await goto(RouterService.chat(id));
await conversationsStore.deleteConversation(id, { deleteWithForks: false });
}
function handleStopGeneration(id: string) {
chatStore.stopGenerationForChat(id);
}
let innerWidth = $state(0);
let pendingCollapse = $state<ReturnType<typeof setTimeout> | null>(null);
function scheduleMobileCollapse() {
if (pendingCollapse) {
clearTimeout(pendingCollapse);
}
pendingCollapse = setTimeout(() => {
isExpandedMode = false;
pendingCollapse = null;
}, 100);
}
function cancelMobileCollapse() {
if (pendingCollapse) {
clearTimeout(pendingCollapse);
pendingCollapse = null;
}
}
</script>
<div class="flex h-full flex-col">
<ScrollArea class="h-full flex-1">
<Sidebar.Header class="gap-4 bg-sidebar/50 p-3 backdrop-blur-lg md:pt-4 md:pb-2">
<div class="flex items-center justify-between">
<a href={ROUTES.START} onclick={handleMobileSidebarItemClick}>
<h1 class="inline-flex items-center gap-1 px-2 text-xl font-semibold">
{APP_NAME}
</h1>
</a>
<svelte:window onkeydown={handleKeydown} bind:innerWidth />
<Button
class="rounded-full md:hidden"
variant="ghost"
size="icon"
onclick={() => sidebar.toggle()}
>
<X class="h-4 w-4" />
<span class="sr-only">Close sidebar</span>
</Button>
{#if innerWidth > 768 || (!page.url.hash.includes(ROUTES.SETTINGS) && !page.url.hash.includes(ROUTES.MCP_SERVERS) && !page.url.hash.includes(ROUTES.SEARCH))}
<aside
class={[
// Layout & positioning
'fixed md:sticky top-2 left-2 md:left-0 md:ml-2 md:mt-2 pt-2 z-10 w-[calc(100dvw-1rem)]',
// Dimensions & overflow
'md:h-[calc(100dvh-1.125rem)]',
isExpandedMode &&
(device.isStandalone
? 'h-[calc(100dvh-2rem)]'
: device.isIOSDevice
? 'h-[calc(100dvh-0.5rem)]'
: 'h-[calc(100dvh-1rem)]'),
// Shape & depth
'rounded-3xl md:rounded-2xl',
// Flex layout
'flex flex-col justify-between',
// Transition
'md:transition-[width,padding] duration-200 ease-out',
// Expanded state: width, surface, depth
isStripExpanded && 'md:w-72 md:bg-muted/60 md:backdrop-blur-xl border-border shadow-md',
// Collapsed state
!isStripExpanded && 'md:w-12',
// Expanded mode flag (for mobile ::before overlay)
isExpandedMode && 'is-expanded'
]}
>
<div class="px-2 flex items-center justify-between">
<div
role="button"
tabindex="0"
class="relative"
onmouseenter={() => (logoHovered = true)}
onmouseleave={() => (logoHovered = false)}
>
<ActionIcon
icon={!isExpandedMode && logoHovered && innerWidth > 768 ? PanelLeftOpen : Logo}
size="lg"
iconSize="h-4.5 w-4.5 md:h-4 md:w-4"
class="{isExpandedMode
? 'bg-muted! md:bg-foreground/5!'
: 'bg-transparent!'} md:h-9 md:w-9 h-10 w-10 rounded-full md:hover:bg-foreground/10! pointer-events-auto"
href={isExpandedMode ? ROUTES.START : undefined}
onclick={isExpandedMode ? undefined : toggleExpandedMode}
tooltip={isExpandedMode ? undefined : 'Open Sidebar'}
tooltipSide={TooltipSide.RIGHT}
ariaLabel={isExpandedMode ? 'Go to start' : 'Expand navigation'}
/>
</div>
<SidebarNavigationActions
bind:this={chatSidebarActions}
{handleMobileSidebarItemClick}
bind:isSearchModeActive
bind:searchQuery
onSearchDeactivated={handleSearchDeactivated}
/>
</Sidebar.Header>
{#if !isSearchModeActive && pinnedConversations.length > 0}
<Sidebar.Group class="p-0 px-4">
<Sidebar.GroupLabel>
<div class="flex items-center gap-1">
<Pin class="h-3.5 w-3.5" />
<span>Pinned</span>
</div>
</Sidebar.GroupLabel>
<Sidebar.GroupContent>
<Sidebar.Menu>
{#each pinnedConversations as { conversation, depth } (conversation.id)}
<Sidebar.MenuItem class="mb-1 p-0">
<SidebarNavigationConversationItem
conversation={{
id: conversation.id,
name: conversation.name,
lastModified: conversation.lastModified,
currNode: conversation.currNode,
forkedFromConversationId: conversation.forkedFromConversationId,
pinned: conversation.pinned
}}
{depth}
isActive={currentChatId === conversation.id}
onSelect={selectConversation}
onEdit={handleEditConversation}
onDelete={handleDeleteConversation}
onStop={handleStopGeneration}
/>
</Sidebar.MenuItem>
{/each}
</Sidebar.Menu>
</Sidebar.GroupContent>
</Sidebar.Group>
{/if}
<Sidebar.Group class="mt-2 h-[calc(100vh-21rem)] space-y-2 p-0 px-3">
{#if (filteredConversations.length > 0 && isSearchModeActive) || !isSearchModeActive}
<Sidebar.GroupLabel>
{isSearchModeActive ? 'Search results' : 'Recent conversations'}
</Sidebar.GroupLabel>
{#if isExpandedMode || isOnMobile}
<div
class="flex items-center transition-all duration-150 ease-out {isMobile.current &&
!isExpandedMode
? 'opacity-0 h-0!'
: ''}"
in:fade={{ duration: 150, easing: circIn, delay: 50 }}
out:fade={{ duration: 100 }}
>
<ActionIcon
icon={isMobile.current ? X : PanelLeftClose}
size="lg"
iconSize="h-4.5 w-4.5 md:h-4 md:w-4"
class="backdrop-blur-none md:h-9 md:w-9 h-10 w-10 rounded-full mr-1 hover:bg-accent!"
onclick={toggleExpandedMode}
tooltip="Close Sidebar"
tooltipSide={TooltipSide.LEFT}
ariaLabel="Collapse navigation"
/>
</div>
{/if}
<Sidebar.GroupContent>
<Sidebar.Menu>
{#each isSearchModeActive ? conversationTree : unpinnedConversations as { conversation, depth } (conversation.id)}
<Sidebar.MenuItem class="mb-1 p-0">
<SidebarNavigationConversationItem
conversation={{
id: conversation.id,
name: conversation.name,
lastModified: conversation.lastModified,
currNode: conversation.currNode,
forkedFromConversationId: conversation.forkedFromConversationId,
pinned: conversation.pinned
}}
{depth}
isActive={currentChatId === conversation.id}
onSelect={selectConversation}
onEdit={handleEditConversation}
onDelete={handleDeleteConversation}
onStop={handleStopGeneration}
/>
</Sidebar.MenuItem>
{/each}
{#if (isSearchModeActive ? conversationTree : unpinnedConversations).length === 0}
<div class="px-2 py-4 text-center">
<p class="mb-4 p-4 text-sm text-muted-foreground">
{searchQuery.length > 0
? 'No results found'
: isSearchModeActive
? 'Start typing to see results'
: 'No conversations yet'}
</p>
</div>
{/if}
</Sidebar.Menu>
</Sidebar.GroupContent>
</Sidebar.Group>
</ScrollArea>
</div>
<DialogConfirmation
bind:open={showDeleteDialog}
title="Delete Conversation"
description={selectedConversation
? `Are you sure you want to delete "${selectedConversationNamePreview}"? This action cannot be undone and will permanently remove all messages in this conversation.`
: ''}
confirmText="Delete"
cancelText="Cancel"
variant="destructive"
icon={Trash2}
onConfirm={handleConfirmDelete}
onCancel={() => {
showDeleteDialog = false;
selectedConversation = null;
}}
>
{#if selectedConversationHasDescendants}
<div class="flex items-center gap-2 py-2">
<Checkbox id="delete-with-forks" bind:checked={deleteWithForks} />
<Label for="delete-with-forks" class="text-sm">Also delete all forked conversations</Label>
</div>
{/if}
</DialogConfirmation>
<DialogConfirmation
bind:open={showEditDialog}
title="Edit Conversation Name"
description=""
confirmText="Save"
cancelText="Cancel"
icon={Pencil}
onConfirm={handleConfirmEdit}
onCancel={() => {
showEditDialog = false;
selectedConversation = null;
}}
onKeydown={(event) => {
if (event.key === 'Enter') {
event.preventDefault();
event.stopImmediatePropagation();
handleConfirmEdit();
<div class="mt-2 flex min-h-0 flex-1 flex-col gap-4 md:gap-1 overflow-y-auto">
<div
class="flex min-h-0 flex-1 flex-col gap-4 md:gap-1 {isMobile.current
? 'transition-[opacity,height] duration-200 ease-out'
: ''} {isMobile.current && !isExpandedMode ? 'opacity-0 !h-0' : ''}"
in:fade={{ duration: 200 }}
out:fade={{ duration: 200 }}
>
<SidebarNavigationActions
isExpandedMode={innerWidth > 768 ? isExpandedMode : true}
class="px-2"
bind:isSearchModeActive
bind:searchQuery
onSearchDeactivated={() => {
isSearchModeActive = false;
searchQuery = '';
}}
onSearchClick={() => {
isExpandedMode = true;
isSearchModeActive = true;
}}
onNewChat={() => {
if (isMobile.current) {
scheduleMobileCollapse();
}
}}
/>
{#if isExpandedMode || isOnMobile}
<SidebarNavigationConversationList
class="px-2"
{filteredConversations}
{currentChatId}
{isSearchModeActive}
{searchQuery}
onSelect={selectConversation}
onEdit={handleEditConversation}
onDelete={handleDeleteConversation}
onStop={handleStopGeneration}
/>
{/if}
</div>
</div>
</aside>
{/if}
<style>
aside {
@media (max-width: 768px) {
--size: 1.125rem;
}
}}
>
<Input
class="text-foreground"
placeholder="Enter a new name"
type="text"
bind:value={editedName}
/>
</DialogConfirmation>
}
@media (max-width: 768px) {
aside {
&:not(.is-expanded) {
pointer-events: none;
}
}
aside.is-expanded::before {
content: '';
position: fixed;
top: -0.5rem;
bottom: -0.25rem;
left: -0.5rem;
right: -0.5rem;
z-index: -1;
background: var(--background);
backdrop-filter: blur(1rem);
pointer-events: none;
}
}
</style>
@@ -1,39 +1,86 @@
<script lang="ts">
import { KeyboardShortcutInfo } from '$lib/components/app';
import { Button } from '$lib/components/ui/button';
import type { Component } from 'svelte';
import { SearchInput } from '$lib/components/app';
import { goto } from '$app/navigation';
import { page } from '$app/state';
import { SIDEBAR_ACTIONS_ITEMS } from '$lib/constants/ui';
import { Search } from '@lucide/svelte';
import { ActionIcon, KeyboardShortcutInfo, SearchInput } from '$lib/components/app';
import { Button } from '$lib/components/ui/button';
import {
ICON_STRIP_TRANSITION_DURATION,
ICON_STRIP_TRANSITION_DELAY_MULTIPLIER,
ROUTES,
SIDEBAR_ACTIONS_ITEMS
} from '$lib/constants';
import { isMobile } from '$lib/stores/viewport.svelte';
import { TooltipSide } from '$lib/enums';
import { fade } from 'svelte/transition';
import { circIn } from 'svelte/easing';
import { onMount } from 'svelte';
import type { Component } from 'svelte';
interface Props {
handleMobileSidebarItemClick: () => void;
class: string;
isExpandedMode: boolean;
isSearchModeActive: boolean;
searchQuery: string;
isCancelAlwaysVisible?: boolean;
onSearchDeactivated?: () => void;
onSearchClick?: () => void;
onNewChat?: () => void;
}
let {
handleMobileSidebarItemClick,
isSearchModeActive = $bindable(),
searchQuery = $bindable(),
isCancelAlwaysVisible = false,
onSearchDeactivated
class: className,
isExpandedMode = false,
isSearchModeActive = $bindable(false),
searchQuery = $bindable(''),
onSearchDeactivated,
onSearchClick,
onNewChat
}: Props = $props();
let initialized = $state(false);
let showIcons = $state(false);
let searchInputRef = $state<HTMLInputElement | null>(null);
const isOnMobile = $derived(isMobile.current);
$effect(() => {
if (isSearchModeActive && searchInputRef) {
searchInputRef.focus();
}
});
onMount(() => {
showIcons = true;
setTimeout(() => {
initialized = true;
}, ICON_STRIP_TRANSITION_DELAY_MULTIPLIER * SIDEBAR_ACTIONS_ITEMS.length);
});
function handleSearchModeDeactivate() {
isSearchModeActive = false;
searchQuery = '';
onSearchDeactivated?.();
}
export function activateSearch() {
isSearchModeActive = true;
// Focus after Svelte renders the input
queueMicrotask(() => searchInputRef?.focus());
function isItemActive(item: {
activeRouteId?: string;
activeRoutePrefix?: string;
activeUrlIncludes?: string;
}): boolean {
if (item.activeRouteId) {
return page.route.id === item.activeRouteId;
}
if (item.activeRoutePrefix) {
return !!page.route.id?.startsWith(item.activeRoutePrefix);
}
if (item.activeUrlIncludes) {
return page.url?.hash?.includes(item.activeUrlIncludes) ?? false;
}
return false;
}
</script>
@@ -41,56 +88,109 @@
<IconComponent class="h-4 w-4" />
{/snippet}
<div class="my-1 space-y-1">
{#if isSearchModeActive}
{#if isSearchModeActive}
<div class="px-4 my-2">
<SearchInput
bind:value={searchQuery}
bind:ref={searchInputRef}
onClose={handleSearchModeDeactivate}
onKeyDown={(e) => e.key === 'Escape' && handleSearchModeDeactivate()}
placeholder="Search conversations..."
{isCancelAlwaysVisible}
/>
{:else}
{#each SIDEBAR_ACTIONS_ITEMS as item (item.route)}
{#if !item.route}
<Button
class="w-full justify-between px-2 backdrop-blur-none! hover:[&>kbd]:opacity-100"
onclick={activateSearch}
variant="ghost"
>
<div class="flex items-center gap-2">
{@render itemIcon(item.icon)}
</div>
{:else if isExpandedMode || isOnMobile}
<div
class="{className} flex flex-col gap-5 md:gap-1 mt-2 md:mt-0 {!isExpandedMode && isOnMobile
? 'hidden pointer-events-none'
: ''}"
>
{#each SIDEBAR_ACTIONS_ITEMS as item, i (item.tooltip)}
{@const isActive = isItemActive(item)}
{@const isSearchOnMobile = item.icon === Search && isMobile.current}
{@const itemHref = isSearchOnMobile ? ROUTES.SEARCH : item.route}
{@const itemOnClick = item.route
? () => {
onNewChat?.();
goto(item.route!);
}
: isSearchOnMobile
? undefined
: onSearchClick}
{@const itemTransition = {
duration: ICON_STRIP_TRANSITION_DURATION,
delay: !initialized
? ICON_STRIP_TRANSITION_DELAY_MULTIPLIER + i * ICON_STRIP_TRANSITION_DELAY_MULTIPLIER
: 0,
easing: circIn
}}
{item.tooltip}
</div>
{#if showIcons}
<div transition:fade={itemTransition}>
<Button
class="w-full min-w-9 justify-between px-2 backdrop-blur-none! hover:[&>kbd]:opacity-100 {isActive
? 'bg-accent text-accent-foreground'
: ''}"
href={itemHref}
onclick={itemOnClick}
variant="ghost"
size="default"
>
<span class="flex min-w-0 items-center px-0.5 gap-2">
{@render itemIcon(item.icon)}
{#if item.keys}
<KeyboardShortcutInfo keys={item.keys} />
{/if}
</Button>
{:else}
<Button
class="w-full justify-between px-2 backdrop-blur-none! hover:[&>kbd]:opacity-100 {(item.activeRouteId &&
page.route.id === item.activeRouteId) ||
(item.activeRoutePrefix && page.route.id?.startsWith(item.activeRoutePrefix))
? 'bg-accent text-accent-foreground'
: ''}"
href={item.route}
onclick={handleMobileSidebarItemClick}
variant="ghost"
>
<div class="flex items-center gap-2">
{@render itemIcon(item.icon)}
{#if showIcons}
<span
in:fade={{ duration: 150, easing: circIn, delay: 50 }}
out:fade={{ duration: 100 }}
class="min-w-0 truncate">{item.tooltip}</span
>
{/if}
</span>
{item.tooltip}
</div>
{#if item.keys}
<KeyboardShortcutInfo keys={item.keys} />
{/if}
</Button>
{#if item.keys}
<KeyboardShortcutInfo keys={item.keys} />
{/if}
</Button>
</div>
{/if}
{/each}
{/if}
</div>
</div>
{:else}
<div class="{className} flex-col gap-1 hidden md:flex">
{#each SIDEBAR_ACTIONS_ITEMS as item, i (item.tooltip)}
{@const isActive = isItemActive(item)}
{@const isSearchOnMobile = item.icon === Search && isMobile.current}
{@const itemOnClick = item.route
? () => {
onNewChat?.();
goto(item.route!);
}
: isSearchOnMobile
? undefined
: onSearchClick}
{@const itemTransition = {
duration: ICON_STRIP_TRANSITION_DURATION,
delay: !initialized
? ICON_STRIP_TRANSITION_DELAY_MULTIPLIER + i * ICON_STRIP_TRANSITION_DELAY_MULTIPLIER
: 0,
easing: circIn
}}
{#if showIcons}
<div transition:fade={itemTransition}>
<ActionIcon
icon={item.icon}
tooltip={item.tooltip}
tooltipSide={TooltipSide.RIGHT}
size="lg"
iconSize="h-4 w-4"
class="h-9 w-9 rounded-full hover:bg-accent! {isActive
? 'bg-accent text-accent-foreground'
: ''}"
onclick={itemOnClick}
/>
</div>
{/if}
{/each}
</div>
{/if}
@@ -0,0 +1,135 @@
<script lang="ts">
import { Pin } from '@lucide/svelte';
import { buildConversationTree } from '$lib/stores/conversations.svelte';
import SidebarNavigationConversationItem from './SidebarNavigationConversationItem.svelte';
import SidebarNavigationSearchResults from './SidebarNavigationSearchResults.svelte';
interface Props {
class: string;
filteredConversations: DatabaseConversation[];
currentChatId: string | undefined;
isSearchModeActive: boolean;
searchQuery: string;
onSelect: (id: string) => void;
onEdit: (id: string) => void;
onDelete: (id: string) => void;
onStop: (id: string) => void;
}
let {
class: className,
filteredConversations,
currentChatId,
isSearchModeActive,
searchQuery,
onSelect,
onEdit,
onDelete,
onStop
}: Props = $props();
let conversationTree = $derived(buildConversationTree(filteredConversations));
let pinnedConversations = $derived(
conversationTree.filter(({ conversation }) => conversation.pinned)
);
let unpinnedConversations = $derived(
conversationTree.filter(({ conversation }) => !conversation.pinned)
);
const recentEmptyMessage = $derived(
searchQuery.length > 0 ? 'No results found' : 'No conversations yet'
);
</script>
{#if isSearchModeActive}
<SidebarNavigationSearchResults
class={className}
{searchQuery}
{filteredConversations}
{currentChatId}
{onSelect}
{onEdit}
{onDelete}
{onStop}
/>
{:else}
{#if pinnedConversations.length > 0}
<div class="py-2 flex whitespace-nowrap {className}">
<div
class="text-muted-foreground inline-flex h-8 shrink-0 items-center rounded-md px-2 text-xs font-medium gap-1"
>
<Pin class="h-3.5 w-3.5" />
<span>Pinned</span>
</div>
</div>
<ul class="flex w-full min-w-0 flex-col gap-4 md:gap-1 {className}">
{#each pinnedConversations as { conversation, depth } (conversation.id)}
<li class="group/item relative mb-1 p-0">
<SidebarNavigationConversationItem
conversation={{
id: conversation.id,
name: conversation.name,
lastModified: conversation.lastModified,
currNode: conversation.currNode,
forkedFromConversationId: conversation.forkedFromConversationId,
pinned: conversation.pinned
}}
{depth}
isActive={currentChatId === conversation.id}
{onSelect}
{onEdit}
{onDelete}
{onStop}
/>
</li>
{/each}
</ul>
{/if}
<div class="mt-2 flex min-h-0 flex-1 flex-col gap-4 md:gap-2 whitespace-nowrap {className}">
{#if filteredConversations.length > 0}
<div
class="text-muted-foreground flex h-8 shrink-0 items-center rounded-md px-2 text-xs font-medium"
>
Recent conversations
</div>
{/if}
<div class="min-h-0 flex-1 md:overflow-y-auto">
<ul class="flex w-full min-w-0 flex-col gap-4 md:gap-1">
{#each unpinnedConversations as { conversation, depth } (conversation.id)}
<li class="group/item relative mb-1 p-0">
<SidebarNavigationConversationItem
conversation={{
id: conversation.id,
name: conversation.name,
lastModified: conversation.lastModified,
currNode: conversation.currNode,
forkedFromConversationId: conversation.forkedFromConversationId,
pinned: conversation.pinned
}}
{depth}
isActive={currentChatId === conversation.id}
{onSelect}
{onEdit}
{onDelete}
{onStop}
/>
</li>
{/each}
{#if unpinnedConversations.length === 0}
<li class="px-2 py-4 text-center">
<p class="mb-4 p-4 text-sm text-muted-foreground">
{recentEmptyMessage}
</p>
</li>
{/if}
</ul>
</div>
</div>
{/if}
@@ -16,4 +16,6 @@
}: Props = $props();
</script>
<SearchInput bind:value {placeholder} {onInput} class="mb-4 {className}" />
<div class="mb-4 px-2 {className}">
<SearchInput bind:value {placeholder} {onInput} />
</div>

Some files were not shown because too many files have changed in this diff Show More