Compare commits

...

20 Commits

Author SHA1 Message Date
Al G 2da6686176 Fix stale tensor-split params for draft models (#24814)
* meta: fix tensor split metadata for GQA attention

* Tidied the code a bit to match existing style

* Revert "Tidied the code a bit to match existing style"

This reverts commit b90c6c6300.

* Reverted the ggml-backend-meta asset hack.
2026-07-05 20:39:36 +02:00
Eve 3e5036fbfb abort if we see a multi buffer (#25276) 2026-07-05 20:38:47 +02:00
liminfei-amd 4b2a0cdee1 ggml : fix tensor-parallel + -ncmoe crash on MoE models (#25028)
Tensor parallelism (-sm tensor) combined with -ncmoe (CPU-offloaded MoE
experts) aborts during warm-up on MoE models with
GGML_ASSERT(ggml_is_contiguous(tensor)) in ggml-backend-meta.cpp.

The failing tensor is the MoE router output (ffn_moe_topk): it is mirrored
(GGML_BACKEND_SPLIT_AXIS_MIRRORED, replicated across backends since routing
must be identical) and happens to be a non-contiguous view.
ggml_backend_meta_buffer_{get,set}_tensor asserted contiguity before
consulting the split state, so a mirrored non-contiguous tensor tripped the
assert even though the GGML_BACKEND_SPLIT_AXIS_MIRRORED case right below
already handles it.

Move the split-state lookup above the assert and allow the mirrored case in
both get_tensor and set_tensor.

Diagnosis credit to the reporter (@nathanmp).

Fixes #24886

Signed-off-by: liminfei-amd <91481003+liminfei-amd@users.noreply.github.com>
2026-07-05 19:56:11 +02:00
Vexxie 7a63fdede1 ggml: Update VMM Pool allocation ggml-cuda.cu - Turing P2P access fix (fixes #24489) (#24491)
* Update ggml-cuda.cu - Turing P2P access fix.

* Add original code as fallback behaviour when NCCL or P2P is not set/true.

* Update ggml/src/ggml-cuda/ggml-cuda.cu to add comment as per suggestion

Co-authored-by: Johannes Gäßler <johannesg@5d6.de>

---------

Co-authored-by: Johannes Gäßler <johannesg@5d6.de>
2026-07-05 19:10:09 +02:00
fairydreaming 78d2f52468 cuda : concat implementation for quantized types (#25303)
* cuda : concat implementation for quantized types

* chore : apply am17an clever suggestion to shorten the code

---------

Co-authored-by: Stanisław Szymczyk <sszymczy@gmail.com>
2026-07-05 23:26:24 +08:00
liminfei-amd a4107133a6 llama : add guard for K/V rotation input when buffer is unallocated (#25215)
llm_graph_input_attn_kv::set_input and llm_graph_input_attn_kv_iswa::set_input
call set_input_k_rot / set_input_v_rot whenever the rotation tensor pointer is
non-null, but the tensor's buffer can be unallocated (NULL) when a graph only
stores K/V without attending -- e.g. DFlash speculative decoding's KV-injection
pass. set_input_k_rot then calls ggml_backend_buffer_is_host() on a NULL buffer
and aborts with GGML_ASSERT(buffer).

Guard the four k_rot/v_rot inputs with the same "&& ->buffer" check that the
adjacent kq_mask inputs already use in these two functions. When the buffer is
unallocated there is no data to upload, so skipping is correct.

Fixes #25191

Signed-off-by: liminfei-amd <91481003+liminfei-amd@users.noreply.github.com>
2026-07-04 22:37:38 +02:00
Pascal 665892536d ui: add sync blocks so display/behavior settings can be set via --ui-config-file (#25132)
* ui: add sync blocks so display/behavior settings can be set via --ui-config-file

* ui: remove enable thinking setting
2026-07-04 16:12:27 +02:00
fairydreaming ef2d770117 ggml : fix broken CPU concat implementation for quantized types (#25247)
* ggml : fix broken CPU concat implementation for quantized types

* tests : concat tests for quantized types

---------

Co-authored-by: Stanisław Szymczyk <sszymczy@gmail.com>
2026-07-04 13:37:37 +02:00
Piotr Wilkin (ilintar) 2d973636e2 chat: trim messages sent to StepFun parser (fixes long reasoning loops) (#25238)
* chat: trim messages sent to StepFun parser (fixes long reasoning loops)

* add regression test; remove duplicate template

* chat: trim StepFun content parts before rendering

The StepFun trim workaround ran on the already-rendered messages, where
typed content parts have been concatenated into a single string, so the
per-part whitespace could no longer be reached. Move the trim ahead of
rendering and apply it to content_parts text as well as the string
content and reasoning_content. Adds a content-parts regression test.

Co-Authored-By: Piotr Wilkin <ilintar@gmail.com>
Assisted-By: Claude Fable 5 <noreply@anthropic.com>

---------

Co-authored-by: tarruda <tpadilha84@gmail.com>
2026-07-03 23:12:11 +02:00
Nick Towle d4cff114c0 ui: Improve performance when streaming (#25225)
* ui: Improve performance when streaming

* ui: build sibling info map in branching utils

Moves the node map and sibling map construction from the
.by block into buildSiblingInfoMap() in branching.ts.

The map is built once per structural change and only read
afterwards, so it does not need SvelteMap reactivity. Keeping
the construction in plain TypeScript fixes the
svelte/prefer-svelte-reactivity lint error and groups the
branching logic where it already lives.

---------

Co-authored-by: Pascal <admin@serveurperso.com>
2026-07-03 19:03:51 +02:00
Pascal f113e02d5a ui: strip path and weight extension from model id in single model mode (#25137) 2026-07-03 17:32:48 +02:00
Ruixiang Wang 152d337fad spec: support spec-draft-p-min in DFlash (#25246)
* spec: support spec-draft-p-min in DFlash

* dflash: add n_min guard

* dflash: guard both n_min and n_max
2026-07-03 15:40:06 +02:00
Piotr Wilkin (ilintar) 75a48a9055 cuda: enable topk-moe fusion for 288 experts (#25267)
* cuda: enable topk-moe fusion for 288 experts

The topk-moe fusion only accepted power-of-2 expert counts (or the
special-cased 576), so models with 288 experts (e.g. Step-3.7-Flash)
fell back to the unfused per-layer routing chain: softmax/sigmoid,
argsort, get_rows, sum_rows, div, clamp, scale. At batch size 1 that
is ~330 extra tiny graph nodes per token.

288 is a multiple of the warp size, so the existing kernel already
handles it; this adds the missing template instantiation and accepts
288 in the eligibility check.

Measured on gfx1151 with Step-3.7-Flash IQ4_XS (llama-bench,
-b 4096 -ub 4096 -fa 1 -dio 1 -ctk q8_0 -ctv q8_0; machine idle,
before/after paired so pp4096 stays matched as a load control):

  test            | before         | after
  ----------------+----------------+----------------
  pp4096          | 460.99 ± 0.45  | 462.47 ± 0.34   (unchanged)
  tg128           |  19.10 ± 0.04  |  19.56 ± 0.03   (+2.4%)
  tg128 @ d30000  |  12.68 ± 0.04  |  12.69 ± 0.03   (unchanged)

Prompt processing is unaffected (the fusion only touches decode
routing). The decode gain is ~+2.4% at shallow context and fades with
depth: by 30k tokens each step is attention-bound over the KV cache,
so removing the fixed routing overhead is no longer visible.

Assisted-By: Claude Fable 5 <noreply@anthropic.com>

* Update tests/test-backend-ops.cpp

Co-authored-by: Oliver Simons <osimons@nvidia.com>

* Add comment for case 288 in topk-moe.cu

---------

Co-authored-by: Oliver Simons <osimons@nvidia.com>
2026-07-03 15:36:55 +02:00
Pascal 067de93718 ui: align persisted config with strict server schema and enable thinking by default (#25242)
* ui: migrate legacy string-encoded booleans in persisted config

* ui: enable thinking by default

Fresh users and legacy conversations without a persisted thinking
preference now default to enabled. The per-conversation toggle and
the persisted localStorage choice keep taking precedence.

Picks up the enable_thinking default from #24876.
2026-07-03 13:14:52 +02:00
Pascal b5315e16e0 server + ui: ping silent SSE streams every 1s and kick only after 3s so slow prefill never drops healthy connections (#25241)
* server + ui: ping silent SSE streams every 1s and kick only after 3s so slow prefill never drops healthy connections

* server + ui: sse_ping_interval becomes a per-request body field

Address review from ngxson: the global default returns to 30 so API
clients see no behavior change, and the WebUI sends sse_ping_interval: 1
in the request body since it owns the 3s visibility-kick contract and
declares the cadence it needs. Positive values keep the existing > 0
gate, -1 keeps its disabled semantics.

* server: move sse_ping_interval into the request schema

Address review from ngxson: the field is now a typed field_num with
hard limits (-1, INT32_MAX) bound to task_params, seeded from the CLI
default alongside the other inherited parameters. The raw json_value
read and its redundant comment are gone, and schema evaluation brings
type and range validation for free.
2026-07-03 12:47:04 +02:00
Aleksander Grygier 94875285e4 ui: Add MCP Servers Opt-In for first time visitors (#25239)
* feat: ui: Add predefined recommended MCP servers to settings

* feat: ui: Add MCP server recommendation dialog with custom server support

* feat: Auto-focus input fields on mount and dynamic addition

* feat: Add header validation to MCP server add and edit forms

* feat: Persist recommended MCP server opt-in selections

* test: Cover MCP configuration with tests

* chore: Format & cleanup

* feat: Centralize MCP server overrides to settings config and improve recommendation UI

* fix: Capture index before mutation to prevent focus drift

* refactor: Extract MCP_CARD_VISIBLE_TOOL_LIMIT to shared constants

* refactor: Support arbitrary authorization header schemes

* refactor: Consolidate MCP recommendations dismissal into existing storage key

* fix: Use case-insensitive comparison for MCP server ID prefix check

* refactor: Centralize MCP server visibility logic and extract recommendations hook

* refactor: Cleanup
2026-07-03 12:16:29 +02:00
Gaurav Garg 5a460dea9f Remove redundant CUDA copies after gated_delta_net. (#23940)
* Remove redundant CUDA copies after gated_delta_net.

Currently, GDN writes recurrent state snapshots into its output tail, then the graph immediately copies those snapshots into ssm_states_all. With MTP draft length 3, target decode uses K=4, so that becomes 4 extra ggml_cuda_cpy calls.

The change detects that gated_delta_net -> view -> cpy pattern and makes the CUDA GDN kernel write the state snapshot(s) directly into the recurrent cache, skipping the intermediate tail writes and copy kernels when safe.

* Address review comments
2026-07-03 14:36:29 +05:30
Alessandro de Oliveira Faria (A.K.A.CABELO) c8ae9a750c vendor : update cpp-httplib to 0.49.0 (#25218) 2026-07-03 10:26:54 +02:00
Adrien Gallouët fdb1db877c llama : add llama_model_ftype_name() (#25134)
* llama : add llama_model_ftype_name()

Expose the model file type (quantization) name, e.g. "Q8_0" or
"Q4_K - Medium", through a new public C API. The returned pointer is
valid for the lifetime of the model and nullptr when the model is
invalid or the file type is unknown.

Signed-off-by: Adrien Gallouët <angt@huggingface.co>

* Export enum

Signed-off-by: Adrien Gallouët <angt@huggingface.co>

* s/llama_model_ftype_name/llama_ftype_name/

Signed-off-by: Adrien Gallouët <angt@huggingface.co>

* Move "(guessed)" to the front in llama_ftype_name

Prepend the "(guessed)" label instead of appending it. This allows removing
the non-thread-safe static std::string, making the function allocation-free.

Signed-off-by: Adrien Gallouët <angt@huggingface.co>

* Add LLAMA_FTYPE_PREFIX

Signed-off-by: Adrien Gallouët <angt@huggingface.co>

* Dont check for model

Signed-off-by: Adrien Gallouët <angt@huggingface.co>

---------

Signed-off-by: Adrien Gallouët <angt@huggingface.co>
2026-07-02 17:26:47 +02:00
lhez 4fc4ec5541 opencl: allow loading precompiled binary kernels from library (#23042)
* opencl: allow loading binary kernel

* opencl: add libdl.h

* ggml-backend-dl is in ggml, which depends backend libs, thus
  ggml-opencl cannot depend on ggml-backend-dl
* add libdl.h to break cyclic dep

* opencl: allow loading bin kernel lib

* opencl: load `gemm_moe_mxfp4_f32_ns` from kernel lib if available

* opencl: load q8_0 gemm from kernel lib

* opencl: load q4_0 moe gemm from kernel lib

* opencl: load q4_1 moe gemm from kernel lib

* opencl: load q4_k moe gemm from kernel lib

* opencl: always declare `get_adreno_bin_kernel_func_t`

* opencl: rephrase message

* opencl: fix for rebase

* opencl: update doc
2026-07-01 10:29:22 -07:00
72 changed files with 2646 additions and 589 deletions
+27 -1
View File
@@ -2378,6 +2378,23 @@ static void func_args_not_string(json & messages) {
}
}
// Trim leading/trailing whitespace from message contents before rendering. This
// has to run on the messages (not on the rendered JSON) because templates with
// string-only content caps concatenate typed content parts into a single string
// during rendering, after which the per-part whitespace can no longer be reached.
// Both the plain string content and the text of typed content parts are trimmed.
static void trim_all_content(std::vector<common_chat_msg> & messages) {
for (auto & message : messages) {
message.content = trim_whitespace(message.content);
message.reasoning_content = trim_whitespace(message.reasoning_content);
for (auto & part : message.content_parts) {
if (part.type == "text") {
part.text = trim_whitespace(part.text);
}
}
}
}
}
// MiniCPM5 format:
@@ -2634,7 +2651,16 @@ static common_chat_params common_chat_templates_apply_jinja(const struct common_
params.tools.is_array() && tmpls->template_tool_use ? *tmpls->template_tool_use : *tmpls->template_default;
const auto & src = tmpl.source();
const auto & caps = tmpl.original_caps();
params.messages = render_message_to_json(inputs.messages, tmpl.original_caps());
std::vector<common_chat_msg> trimmed_messages;
const std::vector<common_chat_msg> * messages_to_render = &inputs.messages;
if (src.find("You have access to the following functions in JSONSchema format") != std::string::npos) {
// StepFun: trim message contents (including typed content parts) before rendering,
// otherwise leftover whitespace drives the model into reasoning loops (issue #24181)
trimmed_messages = inputs.messages;
workaround::trim_all_content(trimmed_messages);
messages_to_render = &trimmed_messages;
}
params.messages = render_message_to_json(*messages_to_render, tmpl.original_caps());
params.tool_choice = inputs.tool_choice;
params.reasoning_format = inputs.reasoning_format;
params.enable_thinking = inputs.enable_thinking;
+14 -5
View File
@@ -955,10 +955,11 @@ struct common_speculative_impl_draft_dflash : public common_speculative_impl {
LOG_INF("%s: - block_size=%d, mask_token_id=%d, n_extract=%u\n", __func__, block_size, mask_token_id, target_layer_ids_n);
// DFlash input is [id_last, <mask> * (block_size-1)], so it can draft at most block_size-1 tokens per step
if (this->params.n_max > block_size - 1) {
LOG_WRN("%s: requested draft size %d exceeds the trained DFlash block size %d -- clamping to %d draft tokens per step\n",
__func__, this->params.n_max, block_size - 1, block_size - 1);
this->params.n_max = block_size - 1;
if (this->params.n_max > block_size - 1 || this->params.n_min > block_size - 1) {
LOG_WRN("%s: requested draft size (n_max=%d, n_min=%d) exceeds the trained DFlash block size %d -- clamping to %d\n",
__func__, this->params.n_max, this->params.n_min, block_size, block_size - 1);
this->params.n_max = std::min(this->params.n_max, block_size - 1);
this->params.n_min = std::min(this->params.n_min, block_size - 1);
}
batch = llama_batch_init(llama_n_batch(ctx_dft), 0, n_seq);
@@ -968,7 +969,7 @@ struct common_speculative_impl_draft_dflash : public common_speculative_impl {
for (auto & s : smpls) {
common_params_sampling sparams;
sparams.no_perf = false;
sparams.top_k = 1;
sparams.top_k = 10;
sparams.samplers = { COMMON_SAMPLER_TYPE_TOP_K };
s.reset(common_sampler_init(model_dft, sparams));
}
@@ -1173,10 +1174,18 @@ struct common_speculative_impl_draft_dflash : public common_speculative_impl {
const llama_token id = cur_p->data[0].id;
if (cur_p->data[0].p < params.p_min) {
break;
}
common_sampler_accept(smpl, id, true);
result.push_back(id);
}
if (result.size() < (size_t) params.n_min) {
result.clear();
}
}
}
+51 -39
View File
@@ -1,16 +1,26 @@
# llama.cpp for OpenCL
- [Background](#background)
- [OS](#os)
- [Hardware](#hardware)
- [DataType Supports](#datatype-supports)
- [Model Preparation](#model-preparation)
- [CMake Options](#cmake-options)
- [Android](#android)
- [Windows 11 Arm64](#windows-11-arm64)
- [Linux](#Linux)
- [Known Issue](#known-issues)
- [TODO](#todo)
- [llama.cpp for OpenCL](#llamacpp-for-opencl)
- [Background](#background)
- [Llama.cpp + OpenCL](#llamacpp--opencl)
- [OS](#os)
- [Hardware](#hardware)
- [Adreno GPU](#adreno-gpu)
- [DataType Supports](#datatype-supports)
- [Model Preparation](#model-preparation)
- [Binary Kernel Library](#binary-kernel-library)
- [CMake Options](#cmake-options)
- [Android](#android)
- [I. Setup Environment](#i-setup-environment)
- [II. Build llama.cpp](#ii-build-llamacpp)
- [Windows 11 Arm64](#windows-11-arm64)
- [I. Setup Environment](#i-setup-environment-1)
- [II. Build llama.cpp](#ii-build-llamacpp-1)
- [Linux](#linux)
- [I. Setup Environment](#i-setup-environment-2)
- [II. Build llama.cpp](#ii-build-llamacpp-2)
- [Known Issues](#known-issues)
- [TODO](#todo)
## Background
@@ -34,11 +44,13 @@ The llama.cpp OpenCL backend is designed to enable llama.cpp on **Qualcomm Adren
**Verified devices**
| Adreno GPU | Status |
|:------------------------------------:|:-------:|
| Adreno 750 (Snapdragon 8 Gen 3) | Support |
| Adreno 830 (Snapdragon 8 Elite) | Support |
| Adreno X85 (Snapdragon X Elite) | Support |
| Adreno GPU | Status |
|:-------------------------------------:|:-------:|
| Adreno 750 (Snapdragon 8 Gen 3) | Support |
| Adreno 830 (Snapdragon 8 Elite) | Support |
| Adreno 840 (Snapdragon 8 Elite Gen 5) | Support |
| Adreno X1-85 (Snapdragon X Elite) | Support |
| Adreno X2-90 (Snapdragon X2 Elite) | Support |
> A6x GPUs with a recent driver and compiler are supported; they are usually found in IoT platforms.
However, A6x GPUs in phones are likely not supported due to the outdated driver and compiler.
@@ -47,42 +59,43 @@ However, A6x GPUs in phones are likely not supported due to the outdated driver
| DataType | Status |
|:----------------------:|:--------------------------:|
| Q1_0 | Support |
| Q4_0 | Support |
| Q6_K | Support, but not optimized |
| Q4_1 | Support |
| Q5_0 | Support |
| Q5_1 | Support |
| Q8_0 | Support |
| Q4_K | Support |
| Q5_K | Support |
| Q6_K | Support |
| MXFP4 | Support |
| IQ4_NL | Support |
## Model Preparation
You can refer to the general [llama-quantize tool](/tools/quantize/README.md) for steps to convert a model in Hugging Face safetensor format to GGUF with quantization.
Since common quantizations are supported now, it is recommanded to download GGUF models directly from Huggingface.
Currently we support `Q4_0` quantization and have optimized for it. To achieve best performance on Adreno GPU, add `--pure` to `llama-quantize` (i.e., make all weights in `Q4_0`). For example,
## Binary Kernel Library
```sh
./llama-quantize --pure ggml-model-qwen2.5-3b-f16.gguf ggml-model-qwen-3b-Q4_0.gguf Q4_0
```
A prebuilt binary kernel library has been introduced for Adreno GPUs.
It currently targets X2 GPUs (X2-90, X2-85 and X2-45) found in Snapdragon X2 SoC.
The library currently contains kernels for MUL_MAT_ID with Q4_0, Q4_1, Q4_K, MXFP4.
The library must be manually downloaded from https://softwarecenter.qualcomm.com/catalog/item/Adreno_Kernel_Library_GGML.
Since `Q6_K` is also supported, `Q4_0` quantization without `--pure` will also work. However, the performance will be worse compared to pure `Q4_0` quantization.
To allow using the kernel library, add `-DGGML_OPENCL_USE_ADRENO_BIN_KERNELS=ON` when configuring with CMake.
Then, extract `adreno-opencl-kernels.dll` from the zip file downloaded from the above URL and put it alongside the executables.
If kernels compatible with the current GPU are found in the library, they will be loaded and used.
### `MXFP4` MoE Models
OpenAI gpt-oss models are MoE models in `MXFP4`. The quantized model will be in `MXFP4_MOE`, a mixture of `MXFP4` and `Q8_0`.
For this quantization, there is no need to specify `--pure`.
For gpt-oss-20b model, you can directly [download](https://huggingface.co/ggml-org/gpt-oss-20b-GGUF) the quantized GGUF file in `MXFP4_MOE` from Hugging Face.
Although it is possible to quantize gpt-oss-20b model in pure `Q4_0` (all weights in `Q4_0`), it is not recommended since `MXFP4` has been optimized for MoE while `Q4_0` is not. In addition, accuracy should degrade with such pure `Q4_0` quantization.
Hence, using the default `MXFP4_MOE` quantization (see the link above) is recommended for this model.
> Note that the `Q4_0` model found [here](https://huggingface.co/unsloth/gpt-oss-20b-GGUF/blob/main/gpt-oss-20b-Q4_0.gguf) is a mixture of `Q4_0`, `Q8_0` and `MXFP4` and gives better performance than `MXFP4_MOE` quantization.
## CMake Options
The OpenCL backend has the following CMake options that control the behavior of the backend.
| CMake options | Default value | Description |
|:---------------------------------:|:--------------:|:------------------------------------------|
| `GGML_OPENCL_EMBED_KERNELS` | `ON` | Embed OpenCL kernels into the executable. |
| `GGML_OPENCL_USE_ADRENO_KERNELS` | `ON` | Use kernels optimized for Adreno. |
| CMake options | Default value | Description |
|:------------------------------------:|:--------------:|:------------------------------------------|
| `GGML_OPENCL_EMBED_KERNELS` | `ON` | Embed OpenCL kernels into the executable. |
| `GGML_OPENCL_USE_ADRENO_KERNELS` | `ON` | Use kernels optimized for Adreno. |
| `GGML_OPENCL_USE_ADRENO_BIN_KERNELS` | `OFF` | Allow using binary kernel lib for Adreno. |
## Android
@@ -277,6 +290,5 @@ ninja
## TODO
- Optimization for Q6_K
- Support and optimization for Q4_K
- Improve flash attention
- Improve OpenCL C kernels performance
+7 -4
View File
@@ -1144,6 +1144,11 @@ static enum ggml_status ggml_backend_meta_buffer_init_tensor_impl(ggml_backend_m
ggml_context * simple_ctx = stc.ctxs[j].get();
ggml_backend_buffer_t simple_buf = buf_ctx->bufs[j].get();
if ((simple_buf != nullptr) && ggml_backend_buffer_is_multi_buffer(simple_buf)) {
// see https://github.com/ggml-org/llama.cpp/issues/22197
GGML_ABORT("multi buffers are not supported by the meta backend");
}
if (split_dim >= 0 && split_dim < GGML_MAX_DIMS) {
// TODO: the following assert fails for llama-parallel even though the results are correct:
// GGML_ASSERT(ggml_is_contiguously_allocated(tensor));
@@ -1245,9 +1250,8 @@ static enum ggml_status ggml_backend_meta_buffer_init_tensor(ggml_backend_buffer
static void ggml_backend_meta_buffer_set_tensor(ggml_backend_buffer_t buffer, ggml_tensor * tensor, const void * data, size_t offset, size_t size) {
const size_t n_bufs = ggml_backend_meta_buffer_n_bufs(buffer);
GGML_ASSERT(ggml_is_contiguous(tensor));
const ggml_backend_meta_split_state split_state = ggml_backend_meta_get_split_state(tensor, /*assume_sync =*/ false);
GGML_ASSERT(ggml_is_contiguous(tensor) || split_state.axis == GGML_BACKEND_SPLIT_AXIS_MIRRORED);
if (split_state.n_segments != 1 || split_state.nr[0] != 1) {
GGML_ASSERT(split_state.axis >= 0 && split_state.axis < GGML_MAX_DIMS);
@@ -1360,9 +1364,8 @@ static void ggml_backend_meta_buffer_set_tensor(ggml_backend_buffer_t buffer, gg
static void ggml_backend_meta_buffer_get_tensor(ggml_backend_buffer_t buffer, const ggml_tensor * tensor, void * data, size_t offset, size_t size) {
const size_t n_bufs = ggml_backend_meta_buffer_n_bufs(buffer);
GGML_ASSERT(ggml_is_contiguous(tensor));
const ggml_backend_meta_split_state split_state = ggml_backend_meta_get_split_state(tensor, /*assume_sync =*/ false);
GGML_ASSERT(ggml_is_contiguous(tensor) || split_state.axis == GGML_BACKEND_SPLIT_AXIS_MIRRORED);
if (split_state.n_segments != 1 || split_state.nr[0] != 1) {
GGML_ASSERT(split_state.axis >= 0 && split_state.axis < GGML_MAX_DIMS);
+15 -3
View File
@@ -1913,7 +1913,11 @@ static void ggml_compute_forward_concat_any(
GGML_ASSERT(dim >= 0 && dim < 4);
int64_t o[4] = {0, 0, 0, 0};
o[dim] = src0->ne[dim];
if (dim == 0) {
o[dim] = src0->ne[dim]/ggml_blck_size(src0->type);
} else {
o[dim] = src0->ne[dim];
}
const char * x;
@@ -1921,8 +1925,8 @@ static void ggml_compute_forward_concat_any(
for (int i3 = 0; i3 < ne3; i3++) {
for (int i2 = ith; i2 < ne2; i2 += nth) {
for (int i1 = 0; i1 < ne1; i1++) {
for (int i0 = 0; i0 < ne0; i0++) {
if (i0 < ne00 && i1 < ne01 && i2 < ne02 && i3 < ne03) {
for (int i0 = 0; i0 < ne0/ggml_blck_size(dst->type); i0++) {
if (i0 < ne00/ggml_blck_size(src0->type) && i1 < ne01 && i2 < ne02 && i3 < ne03) {
x = (const char *)src0->data + (i0 )*nb00 + (i1 )*nb01 + (i2 )*nb02 + (i3 )*nb03;
} else {
x = (const char *)src1->data + (i0 - o[0])*nb10 + (i1 - o[1])*nb11 + (i2 - o[2])*nb12 + (i3 - o[3])*nb13;
@@ -2071,6 +2075,14 @@ void ggml_compute_forward_concat(
ggml_tensor * dst) {
const ggml_tensor * src0 = dst->src[0];
const ggml_tensor * src1 = dst->src[1];
if (ggml_is_quantized(src0->type)) {
GGML_ASSERT(ggml_is_contiguous(src0));
GGML_ASSERT(ggml_is_contiguous(src1));
GGML_ASSERT(src0->ne[0] % ggml_blck_size(src0->type) == 0);
GGML_ASSERT(src1->ne[0] % ggml_blck_size(src1->type) == 0);
}
switch (src0->type) {
case GGML_TYPE_F16:
+32 -20
View File
@@ -152,8 +152,8 @@ static void concat_cuda(const ggml_tensor * src0, const ggml_tensor * src1, ggml
src0_d + i3*(src0->nb[3] / sizeof(T)),
src1_d + i3*(src1->nb[3] / sizeof(T)),
dst_d + i3*( dst->nb[3] / sizeof(T)),
src0->ne[0], src0->ne[1], src0->ne[2],
dst->ne[0], dst->ne[1], dst->ne[2], dim, stream);
ggml_row_size(src0->type, src0->ne[0])/sizeof(T), src0->ne[1], src0->ne[2],
ggml_row_size(dst->type, dst->ne[0])/sizeof(T), dst->ne[1], dst->ne[2], dim, stream);
}
} else {
const size_t size0 = ggml_nbytes(src0);
@@ -163,6 +163,8 @@ static void concat_cuda(const ggml_tensor * src0, const ggml_tensor * src1, ggml
CUDA_CHECK(cudaMemcpyAsync((char *) dst->data + size0, src1->data, size1, cudaMemcpyDeviceToDevice, stream));
}
} else {
GGML_ASSERT(!ggml_is_quantized(src0->type));
dim3 grid_dim(dst->ne[1], dst->ne[2], dst->ne[3]);
auto launch_kernel = [&](auto dim) {
concat_non_cont<T, dim><<<grid_dim, CUDA_CONCAT_BLOCK_SIZE, 0, stream>>>(
@@ -204,24 +206,34 @@ void ggml_cuda_op_concat(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
GGML_ASSERT(src0->type == src1->type);
GGML_ASSERT(dst->type == src0->type);
GGML_ASSERT(!ggml_is_quantized(src0->type));
GGML_ASSERT(ggml_blck_size(src0->type) == 1);
switch (ggml_type_size(src0->type)) {
case 1:
concat_cuda<uint8_t>(src0, src1, dst, dim, stream);
break;
case 2:
concat_cuda<uint16_t>(src0, src1, dst, dim, stream);
break;
case 4:
concat_cuda<uint32_t>(src0, src1, dst, dim, stream);
break;
case 8:
concat_cuda<uint64_t>(src0, src1, dst, dim, stream);
break;
default:
GGML_ABORT("Unsupported type size: %zu", ggml_type_size(src0->type));
break;
if (ggml_is_quantized(src0->type)) {
GGML_ASSERT(ggml_is_contiguous(src0));
GGML_ASSERT(ggml_is_contiguous(src1));
GGML_ASSERT(src0->ne[0] % ggml_blck_size(src0->type) == 0);
GGML_ASSERT(src1->ne[0] % ggml_blck_size(src1->type) == 0);
// if tensors are contiguous and ne[0] is multiple of the block size we can concat both tensors as byte tensors
concat_cuda<uint8_t>(src0, src1, dst, dim, stream);
} else {
GGML_ASSERT(ggml_blck_size(src0->type) == 1);
switch (ggml_type_size(src0->type)) {
case 1:
concat_cuda<uint8_t>(src0, src1, dst, dim, stream);
break;
case 2:
concat_cuda<uint16_t>(src0, src1, dst, dim, stream);
break;
case 4:
concat_cuda<uint32_t>(src0, src1, dst, dim, stream);
break;
case 8:
concat_cuda<uint64_t>(src0, src1, dst, dim, stream);
break;
default:
GGML_ABORT("Unsupported type size: %zu", ggml_type_size(src0->type));
break;
}
}
}
+40 -25
View File
@@ -10,6 +10,7 @@ gated_delta_net_cuda(const float * q,
const float * beta,
const float * curr_state,
float * dst,
float * state,
int64_t H,
int64_t n_tokens,
int64_t n_seqs,
@@ -25,6 +26,7 @@ gated_delta_net_cuda(const float * q,
const uint3 neqk1_magic,
const uint3 rq3_magic,
float scale,
int64_t state_slot_stride,
int K) {
const uint32_t h_idx = blockIdx.x;
const uint32_t sequence = blockIdx.y;
@@ -35,9 +37,7 @@ gated_delta_net_cuda(const float * q,
const uint32_t iq1 = fastmodulo(h_idx, neqk1_magic);
const uint32_t iq3 = fastdiv(sequence, rq3_magic);
const int64_t attn_score_elems = S_v * H * n_tokens * n_seqs;
float * attn_data = dst;
float * state = dst + attn_score_elems;
// input state holds s0 only: [S_v, S_v, H, n_seqs] — seq stride is D = H * S_v * S_v.
// output state layout (per-slot D * n_seqs) — same per-(seq,head) offset as before.
@@ -145,10 +145,9 @@ gated_delta_net_cuda(const float * q,
if constexpr (keep_rs_t) {
// snapshot slot mapping: slot 0 = most recent state, slot s = s tokens back.
// When n_tokens < K only slots 0..n_tokens-1 are written; older slots are caller-owned.
const int64_t state_size_per_token = S_v * S_v * H * n_seqs; // per-slot stride in output
const int target_slot = (int) n_tokens - 1 - t;
if (target_slot >= 0 && target_slot < K) {
float * curr_state = (dst + attn_score_elems) + target_slot * state_size_per_token + state_out_offset;
float * curr_state = state + target_slot * state_slot_stride;
#pragma unroll
for (int r = 0; r < rows_per_lane; r++) {
const int i = r * warp_size + lane;
@@ -171,13 +170,13 @@ template <bool KDA, bool keep_rs_t>
static void launch_gated_delta_net(
const float * q_d, const float * k_d, const float * v_d,
const float * g_d, const float * b_d, const float * s_d,
float * dst_d,
float * dst_d, float * state_d,
int64_t S_v, int64_t H, int64_t n_tokens, int64_t n_seqs,
int64_t sq1, int64_t sq2, int64_t sq3,
int64_t sv1, int64_t sv2, int64_t sv3,
int64_t sb1, int64_t sb2, int64_t sb3,
int64_t neqk1, int64_t rq3,
float scale, int K, cudaStream_t stream) {
float scale, int64_t state_slot_stride, int K, cudaStream_t stream) {
//TODO: Add chunked kernel for even faster pre-fill
const int warp_size = ggml_cuda_info().devices[ggml_cuda_get_device()].warp_size;
const int num_warps = 4;
@@ -187,34 +186,32 @@ static void launch_gated_delta_net(
const uint3 neqk1_magic = init_fastdiv_values(neqk1);
const uint3 rq3_magic = init_fastdiv_values(rq3);
int cc = ggml_cuda_info().devices[ggml_cuda_get_device()].cc;
const ggml_cuda_kernel_launch_params launch_params = ggml_cuda_kernel_launch_params(grid_dims, block_dims, 0, stream);
switch (S_v) {
case 16:
ggml_cuda_kernel_launch(gated_delta_net_cuda<16, KDA, keep_rs_t>, launch_params,
q_d, k_d, v_d, g_d, b_d, s_d, dst_d, H,
q_d, k_d, v_d, g_d, b_d, s_d, dst_d, state_d, H,
n_tokens, n_seqs, sq1, sq2, sq3, sv1, sv2, sv3,
sb1, sb2, sb3, neqk1_magic, rq3_magic, scale, K);
sb1, sb2, sb3, neqk1_magic, rq3_magic, scale, state_slot_stride, K);
break;
case 32:
ggml_cuda_kernel_launch(gated_delta_net_cuda<32, KDA, keep_rs_t>, launch_params,
q_d, k_d, v_d, g_d, b_d, s_d, dst_d, H,
q_d, k_d, v_d, g_d, b_d, s_d, dst_d, state_d, H,
n_tokens, n_seqs, sq1, sq2, sq3, sv1, sv2, sv3,
sb1, sb2, sb3, neqk1_magic, rq3_magic, scale, K);
sb1, sb2, sb3, neqk1_magic, rq3_magic, scale, state_slot_stride, K);
break;
case 64: {
ggml_cuda_kernel_launch(gated_delta_net_cuda<64, KDA, keep_rs_t>, launch_params,
q_d, k_d, v_d, g_d, b_d, s_d, dst_d, H,
q_d, k_d, v_d, g_d, b_d, s_d, dst_d, state_d, H,
n_tokens, n_seqs, sq1, sq2, sq3, sv1, sv2, sv3,
sb1, sb2, sb3, neqk1_magic, rq3_magic, scale, K);
sb1, sb2, sb3, neqk1_magic, rq3_magic, scale, state_slot_stride, K);
break;
}
case 128: {
ggml_cuda_kernel_launch(gated_delta_net_cuda<128, KDA, keep_rs_t>, launch_params,
q_d, k_d, v_d, g_d, b_d, s_d, dst_d, H,
q_d, k_d, v_d, g_d, b_d, s_d, dst_d, state_d, H,
n_tokens, n_seqs, sq1, sq2, sq3, sv1, sv2, sv3,
sb1, sb2, sb3, neqk1_magic, rq3_magic, scale, K);
sb1, sb2, sb3, neqk1_magic, rq3_magic, scale, state_slot_stride, K);
break;
}
default:
@@ -223,7 +220,8 @@ static void launch_gated_delta_net(
}
}
void ggml_cuda_op_gated_delta_net(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
static void ggml_cuda_op_gated_delta_net_impl(
ggml_backend_cuda_context & ctx, ggml_tensor * dst, const ggml_cuda_gated_delta_net_fused_cache * cache) {
ggml_tensor * src_q = dst->src[0];
ggml_tensor * src_k = dst->src[1];
ggml_tensor * src_v = dst->src[2];
@@ -288,25 +286,42 @@ void ggml_cuda_op_gated_delta_net(ggml_backend_cuda_context & ctx, ggml_tensor *
const int K = ggml_get_op_params_i32(dst, 0);
const bool keep_rs = K > 1;
// recurrent state -> gdn_out tail (after attention scores), or the cache when fusing
float * state_d = dst_d + S_v * H * n_tokens * n_seqs;
int64_t state_slot_stride = S_v * S_v * H * n_seqs;
if (cache != nullptr) {
state_d = cache->data;
state_slot_stride = cache->slot_stride;
}
if (kda) {
if (keep_rs) {
launch_gated_delta_net<true, true>(q_d, k_d, v_d, g_d, b_d, s_d, dst_d,
launch_gated_delta_net<true, true>(q_d, k_d, v_d, g_d, b_d, s_d, dst_d, state_d,
S_v, H, n_tokens, n_seqs, sq1, sq2, sq3, sv1, sv2, sv3,
sb1, sb2, sb3, neqk1, rq3, scale, K, stream);
sb1, sb2, sb3, neqk1, rq3, scale, state_slot_stride, K, stream);
} else {
launch_gated_delta_net<true, false>(q_d, k_d, v_d, g_d, b_d, s_d, dst_d,
launch_gated_delta_net<true, false>(q_d, k_d, v_d, g_d, b_d, s_d, dst_d, state_d,
S_v, H, n_tokens, n_seqs, sq1, sq2, sq3, sv1, sv2, sv3,
sb1, sb2, sb3, neqk1, rq3, scale, K, stream);
sb1, sb2, sb3, neqk1, rq3, scale, state_slot_stride, K, stream);
}
} else {
if (keep_rs) {
launch_gated_delta_net<false, true>(q_d, k_d, v_d, g_d, b_d, s_d, dst_d,
launch_gated_delta_net<false, true>(q_d, k_d, v_d, g_d, b_d, s_d, dst_d, state_d,
S_v, H, n_tokens, n_seqs, sq1, sq2, sq3, sv1, sv2, sv3,
sb1, sb2, sb3, neqk1, rq3, scale, K, stream);
sb1, sb2, sb3, neqk1, rq3, scale, state_slot_stride, K, stream);
} else {
launch_gated_delta_net<false, false>(q_d, k_d, v_d, g_d, b_d, s_d, dst_d,
launch_gated_delta_net<false, false>(q_d, k_d, v_d, g_d, b_d, s_d, dst_d, state_d,
S_v, H, n_tokens, n_seqs, sq1, sq2, sq3, sv1, sv2, sv3,
sb1, sb2, sb3, neqk1, rq3, scale, K, stream);
sb1, sb2, sb3, neqk1, rq3, scale, state_slot_stride, K, stream);
}
}
}
void ggml_cuda_op_gated_delta_net(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
ggml_cuda_op_gated_delta_net_impl(ctx, dst, nullptr);
}
void ggml_cuda_op_gated_delta_net_fused_cache(
ggml_backend_cuda_context & ctx, ggml_tensor * dst, ggml_cuda_gated_delta_net_fused_cache cache) {
ggml_cuda_op_gated_delta_net_impl(ctx, dst, &cache);
}
+10
View File
@@ -1,4 +1,14 @@
#include "common.cuh"
#include "ggml.h"
// fused-kernel recurrent-state output; strides in elements (per-seq stride is always D, set in-kernel)
struct ggml_cuda_gated_delta_net_fused_cache {
float * data; // rollback slot 0
int64_t slot_stride; // between rollback slots (0 when K==1)
};
void ggml_cuda_op_gated_delta_net(ggml_backend_cuda_context & ctx, ggml_tensor * dst);
// same op, but writes the snapshot(s) into the cache instead of dst (see ggml_cuda_try_gdn_cache_fusion)
void ggml_cuda_op_gated_delta_net_fused_cache(ggml_backend_cuda_context & ctx, ggml_tensor * dst,
ggml_cuda_gated_delta_net_fused_cache cache);
+139 -14
View File
@@ -543,12 +543,42 @@ struct ggml_cuda_pool_vmm : public ggml_cuda_pool {
// the memory allocation handle is no longer needed after mapping
CU_CHECK(cuMemRelease(handle));
// set access
CUmemAccessDesc access = {};
access.location.type = CU_MEM_LOCATION_TYPE_DEVICE;
access.location.id = device;
access.flags = CU_MEM_ACCESS_FLAGS_PROT_READWRITE;
CU_CHECK(cuMemSetAccess((CUdeviceptr)((char *)(pool_addr) + pool_size), reserve_size, &access, 1));
// VMM Bug fix for P2P access if GGML_CUDA_P2P is set, or if NCCL build
bool use_peer_access = getenv("GGML_CUDA_P2P") != nullptr;
#if defined(GGML_USE_NCCL)
use_peer_access = true;
#endif // defined(GGML_USE_NCCL)
if (use_peer_access) {
// NCCL implicitly enables peer access (cudaDeviceEnablePeerAccess), and
// GGML_CUDA_P2P enables it explicitly. Unlike cudaMalloc buffers, VMM
// allocations do not become peer-accessible from that alone, so access
// must be granted explicitly here.
std::vector<CUmemAccessDesc> access_descs;
const int device_count = ggml_cuda_info().device_count;
for (int id = 0; id < device_count; ++id) {
if (id != device) {
int can_access_peer = 0;
CUDA_CHECK(cudaDeviceCanAccessPeer(&can_access_peer, id, device));
if (!can_access_peer) {
continue;
}
}
CUmemAccessDesc access = {};
access.location.type = CU_MEM_LOCATION_TYPE_DEVICE;
access.location.id = id;
access.flags = CU_MEM_ACCESS_FLAGS_PROT_READWRITE;
access_descs.push_back(access);
}
CU_CHECK(cuMemSetAccess(start_ptr, reserve_size, access_descs.data(), access_descs.size()));
} else {
// set access for non P2P
CUmemAccessDesc access = {};
access.location.type = CU_MEM_LOCATION_TYPE_DEVICE;
access.location.id = device;
access.flags = CU_MEM_ACCESS_FLAGS_PROT_READWRITE;
CU_CHECK(cuMemSetAccess(start_ptr, reserve_size, &access, 1));
}
// add to the pool
pool_size += reserve_size;
@@ -3251,6 +3281,11 @@ static void ggml_backend_cuda_synchronize(ggml_backend_t backend) {
GGML_UNUSED(backend);
}
static bool ggml_cuda_is_view_or_noop(const ggml_tensor * t) {
return ggml_is_empty(t) || t->op == GGML_OP_RESHAPE || t->op == GGML_OP_TRANSPOSE ||
t->op == GGML_OP_VIEW || t->op == GGML_OP_PERMUTE || t->op == GGML_OP_NONE;
}
#ifdef USE_CUDA_GRAPH
static bool ggml_cuda_graph_check_compability(ggml_cgraph * cgraph) {
@@ -3260,7 +3295,7 @@ static bool ggml_cuda_graph_check_compability(ggml_cgraph * cgraph) {
for (int i = 0; i < cgraph->n_nodes; i++) {
ggml_tensor * node = cgraph->nodes[i];
if (ggml_is_empty(node) || node->op == GGML_OP_RESHAPE || node->op == GGML_OP_TRANSPOSE || node->op == GGML_OP_VIEW || node->op == GGML_OP_PERMUTE || node->op == GGML_OP_NONE) {
if (ggml_cuda_is_view_or_noop(node)) {
continue;
}
@@ -3403,6 +3438,70 @@ static bool ggml_cuda_should_fuse_rope_set_rows(const ggml_tensor * rope,
return true;
}
// match gated_delta_net + the strided cpy that scatters its state snapshots into the cache
// (slot i -> rollback group i, slot 0 newest), so the kernel can write them and skip the cpy.
static int ggml_cuda_try_gdn_cache_fusion(
const ggml_cgraph * cgraph, int node_idx, ggml_cuda_gated_delta_net_fused_cache & fused_state_cpy) {
const ggml_tensor * gdn = cgraph->nodes[node_idx];
// the kernel skips the snapshot tail, so the gdn output must not be a graph output
if (gdn->op != GGML_OP_GATED_DELTA_NET || gdn->type != GGML_TYPE_F32 ||
(gdn->flags & GGML_TENSOR_FLAG_OUTPUT)) {
return 0;
}
const ggml_tensor * src_v = gdn->src[2];
const int64_t S_v = src_v->ne[0];
const int64_t H = src_v->ne[1];
const int64_t n_tokens = src_v->ne[2];
const int64_t n_seqs = src_v->ne[3];
const int64_t D = S_v * S_v * H;
const int64_t K = ggml_get_op_params_i32(gdn, 0); // snapshot slot count
const int64_t n_written = std::min<int64_t>(n_tokens, K); // newest n_written slots are written
// snapshot tail starts right after the attention scores
const size_t tail_off = ggml_row_size(GGML_TYPE_F32, S_v * H * n_tokens * n_seqs);
// snapshot cpy is the first real node after the gdn (skip views/no-ops)
const ggml_tensor * cpy = nullptr;
int skip = 0;
for (int j = node_idx + 1; j < cgraph->n_nodes && cpy == nullptr; ++j) {
const ggml_tensor * n = cgraph->nodes[j];
if (ggml_cuda_is_view_or_noop(n)) {
continue;
}
if (n->op != GGML_OP_CPY || (n->flags & GGML_TENSOR_FLAG_OUTPUT)) {
return 0;
}
cpy = n;
skip = j - node_idx;
}
if (cpy == nullptr) {
return 0;
}
const ggml_tensor * src = cpy->src[0]; // view of the gdn snapshot tail
const ggml_tensor * dst = cpy->src[1]; // cache view the kernel writes to
// src must be this gdn's snapshot tail (contiguous, at the tail offset)
if (src->op != GGML_OP_VIEW || src->view_src != gdn || src->view_offs != tail_off ||
!ggml_is_contiguous(src)) {
return 0;
}
// dst is the [D, n_seqs, n_written] cache view; require nb[1] == D (the per-seq stride the kernel
// assumes). ggml_cpy pins src to the same element count.
const std::array<int64_t, GGML_MAX_DIMS> expected_ne = { D, n_seqs, n_written, 1 };
if (dst->op != GGML_OP_VIEW || dst->type != GGML_TYPE_F32 || dst->data == nullptr ||
!std::equal(expected_ne.begin(), expected_ne.end(), dst->ne) ||
dst->nb[0] != ggml_type_size(GGML_TYPE_F32) || dst->nb[1] != (size_t) ggml_row_size(GGML_TYPE_F32, D)) {
return 0;
}
fused_state_cpy.data = (float *) dst->data; // rollback group 0 (newest)
fused_state_cpy.slot_stride = K > 1 ? (int64_t) (dst->nb[2] / sizeof(float)) : 0;
return skip;
}
static bool ggml_cuda_topk_moe_fusion(const struct ggml_cgraph * cgraph, int node_idx, ggml_cuda_topk_moe_args & args) {
args.sigmoid = false;
args.softmax = false;
@@ -3844,6 +3943,20 @@ static int ggml_cuda_try_fuse(ggml_backend_cuda_context * cuda_ctx, ggml_cgraph
ggml_tensor * node = cgraph->nodes[i];
// gated_delta_net -> cpy: scatter recurrent-state snapshots into the cache
if (node->op == GGML_OP_GATED_DELTA_NET) {
ggml_cuda_gated_delta_net_fused_cache fused_state_cpy;
const int nodes_to_skip = ggml_cuda_try_gdn_cache_fusion(cgraph, i, fused_state_cpy);
if (nodes_to_skip > 0) {
#ifdef GGML_CUDA_DEBUG
GGML_LOG_INFO("%s: fused gated_delta_net snapshot copies for %s (skipped %d nodes)\n",
__func__, node->name, nodes_to_skip);
#endif
ggml_cuda_op_gated_delta_net_fused_cache(*cuda_ctx, node, fused_state_cpy);
return nodes_to_skip;
}
}
//topk-moe
if (cgraph->nodes[i]->op == GGML_OP_UNARY || cgraph->nodes[i]->op == GGML_OP_SOFT_MAX ||
cgraph->nodes[i]->op == GGML_OP_ARGSORT) {
@@ -4372,7 +4485,7 @@ static void ggml_cuda_graph_evaluate_and_capture(ggml_backend_cuda_context * cud
#endif
prev_i = i;
if (ggml_is_empty(node) || node->op == GGML_OP_RESHAPE || node->op == GGML_OP_TRANSPOSE || node->op == GGML_OP_VIEW || node->op == GGML_OP_PERMUTE || node->op == GGML_OP_NONE) {
if (ggml_cuda_is_view_or_noop(node)) {
continue;
}
@@ -5304,12 +5417,24 @@ static bool ggml_backend_cuda_device_supports_op(ggml_backend_dev_t dev, const g
ggml_type src1_type = op->src[1]->type;
return src0_type == src1_type &&
src0_type == op->type &&
!ggml_is_quantized(src0_type) &&
ggml_blck_size(src0_type) == 1 &&
(ggml_type_size(src0_type) == 1 ||
ggml_type_size(src0_type) == 2 ||
ggml_type_size(src0_type) == 4 ||
ggml_type_size(src0_type) == 8);
(
(
ggml_is_quantized(src0_type) &&
ggml_is_contiguous(op->src[0]) &&
ggml_is_contiguous(op->src[1]) &&
op->src[0]->ne[0] % ggml_blck_size(src0_type) == 0 &&
op->src[1]->ne[0] % ggml_blck_size(src0_type) == 0
) || (
!ggml_is_quantized(src0_type) &&
ggml_blck_size(src0_type) == 1 &&
(
ggml_type_size(src0_type) == 1 ||
ggml_type_size(src0_type) == 2 ||
ggml_type_size(src0_type) == 4 ||
ggml_type_size(src0_type) == 8
)
)
);
} break;
case GGML_OP_CONV_TRANSPOSE_1D:
{
+7 -1
View File
@@ -312,6 +312,10 @@ static void launch_topk_moe_cuda(ggml_backend_cuda_context & ctx,
ggml_cuda_kernel_launch(topk_moe_cuda<256, has_bias>, launch_params,
logits, weights, ids, bias, n_rows, n_expert_used, clamp_val, scale_val, config);
break;
case 288: // StepFun 3.7
ggml_cuda_kernel_launch(topk_moe_cuda<288, has_bias>, launch_params,
logits, weights, ids, bias, n_rows, n_expert_used, clamp_val, scale_val, config);
break;
case 512:
ggml_cuda_kernel_launch(topk_moe_cuda<512, has_bias>, launch_params,
logits, weights, ids, bias, n_rows, n_expert_used, clamp_val, scale_val, config);
@@ -377,8 +381,10 @@ bool ggml_cuda_should_use_topk_moe(const ggml_tensor * gating_op,
const ggml_tensor * weights,
const ggml_tensor * logits,
const ggml_tensor * ids) {
// must match an instantiation of launch_topk_moe_cuda: a power of 2 up to 512,
// or one of the non-power-of-2 expert counts of supported models
const int n_expert = ids->nb[1] / ids->nb[0];
if (((n_expert & (n_expert - 1)) != 0 || n_expert > 512) && n_expert != 576) {
if (((n_expert & (n_expert - 1)) != 0 || n_expert > 512) && n_expert != 288 && n_expert != 576) {
return false;
}
+5
View File
@@ -31,6 +31,11 @@ if (GGML_OPENCL_EMBED_KERNELS)
target_include_directories(${TARGET_NAME} PRIVATE "${CMAKE_CURRENT_BINARY_DIR}/autogenerated")
endif ()
if (GGML_OPENCL_USE_ADRENO_BIN_KERNELS)
message(STATUS "OpenCL will use precompiled binary kernels for Adreno (improved performance on some platforms)")
add_compile_definitions(GGML_OPENCL_USE_ADRENO_BIN_KERNELS)
endif ()
function(ggml_opencl_add_kernel KNAME)
set(KERN_HDR ${CMAKE_CURRENT_BINARY_DIR}/autogenerated/${KNAME}.cl.h)
set(KERN_SRC ${CMAKE_CURRENT_SOURCE_DIR}/kernels/${KNAME}.cl)
+313 -8
View File
@@ -13,6 +13,22 @@
#include "ggml-backend-impl.h"
#include "ggml.h"
#ifdef GGML_OPENCL_USE_ADRENO_BIN_KERNELS
#include "libdl.h"
#ifdef _WIN32
#define KERNEL_LIB_NAME "adreno-opencl-kernels.dll"
#else
#define KERNEL_LIB_NAME "libadreno-opencl-kernels.so"
#endif // _WIN32
#endif // GGML_OPENCL_USE_ADRENO_BIN_KERNELS
typedef const void * (*get_adreno_bin_kernel_func_t)(
const char * name,
const char * gpu_name,
const char * compiler_ver,
size_t * out_size
);
#include <CL/cl.h>
#include <inttypes.h>
@@ -476,6 +492,8 @@ struct ggml_backend_opencl_context {
bool adreno_has_large_buffer;
bool adreno_use_large_buffer;
bool adreno_use_bin_kernels;
get_adreno_bin_kernel_func_t get_adreno_bin_kernel_func = nullptr;
ggml_cl_compiler_version adreno_cl_compiler_version;
std::string kernel_compile_opts; // cached for lazy-compiled kernels.
@@ -718,15 +736,15 @@ struct ggml_backend_opencl_context {
cl_kernel kernel_gated_delta_net_f32[4][2][2] = {};
cl_kernel kernel_timestep_embedding;
cl_kernel kernel_gemv_moe_q4_0_f32_ns, kernel_gemm_moe_q4_0_f32_ns;
cl_kernel kernel_gemv_moe_q4_1_f32_ns, kernel_gemm_moe_q4_1_f32_ns;
cl_kernel kernel_gemv_moe_q4_0_f32_ns, kernel_gemm_moe_q4_0_f32_ns, kernel_gemm_moe_q4_0_f32_ns_bin;
cl_kernel kernel_gemv_moe_q4_1_f32_ns, kernel_gemm_moe_q4_1_f32_ns, kernel_gemm_moe_q4_1_f32_ns_bin;
cl_kernel kernel_gemv_moe_q5_0_f32_ns, kernel_gemm_moe_q5_0_f32_ns;
cl_kernel kernel_gemv_moe_q5_1_f32_ns, kernel_gemm_moe_q5_1_f32_ns;
cl_kernel kernel_gemv_moe_q4_k_f32_ns, kernel_gemm_moe_q4_k_f32_ns;
cl_kernel kernel_gemv_moe_q4_k_f32_ns, kernel_gemm_moe_q4_k_f32_ns, kernel_gemm_moe_q4_k_f32_ns_bin;
cl_kernel kernel_gemv_moe_q5_k_f32_ns, kernel_gemm_moe_q5_k_f32_ns;
cl_kernel kernel_gemv_moe_q6_k_f32_ns, kernel_gemm_moe_q6_k_f32_ns;
cl_kernel kernel_gemv_moe_mxfp4_f32, kernel_gemm_moe_mxfp4_f32;
cl_kernel kernel_gemv_moe_mxfp4_f32_ns, kernel_gemm_moe_mxfp4_f32_ns;
cl_kernel kernel_gemv_moe_mxfp4_f32_ns, kernel_gemm_moe_mxfp4_f32_ns, kernel_gemm_moe_mxfp4_f32_ns_bin;
cl_kernel kernel_moe_reorder_b;
cl_kernel kernel_moe_histogram, kernel_moe_scan, kernel_moe_fill, kernel_moe_scatter;
cl_kernel kernel_mul_mv_id_q4_0_f32_8x_flat;
@@ -870,6 +888,20 @@ struct ggml_backend_opencl_context {
#endif
}
const void * get_adreno_bin_kernel(const std::string &kernel_name, size_t *bin_size) const {
if (!get_adreno_bin_kernel_func) {
return nullptr;
}
size_t sz;
const void * kernel_bin = get_adreno_bin_kernel_func(
kernel_name.c_str(), device_name.c_str(), driver_version.c_str(), &sz);
if (bin_size) {
*bin_size = sz;
}
return kernel_bin;
}
#ifdef GGML_OPENCL_USE_ADRENO_KERNELS
// Transpose kernels
cl_program program_transpose;
@@ -891,7 +923,7 @@ struct ggml_backend_opencl_context {
cl_kernel kernel_gemv_noshuffle_q4_0_f32_32000_1_4096;
cl_kernel kernel_gemv_noshuffle_q4_1_f32;
cl_kernel kernel_gemm_noshuffle_q4_1_f32;
cl_kernel kernel_gemm_noshuffle_q8_0_f32;
cl_kernel kernel_gemm_noshuffle_q8_0_f32, kernel_gemm_noshuffle_q8_0_f32_bin;
cl_kernel kernel_gemv_noshuffle_q8_0_f32;
cl_kernel kernel_gemm_noshuffle_q1_0_f32;
cl_kernel kernel_gemv_noshuffle_q1_0_f32;
@@ -988,6 +1020,32 @@ static cl_program build_program_from_source(cl_context ctx, cl_device_id dev, co
return build_program_from_source_ex(ctx, dev, program_buffer, compile_opts, /*fatal=*/true);
}
static cl_program build_program_from_binary(cl_context ctx, cl_device_id dev, const char* program_buffer, const std::string &compile_opts, size_t bin_size = 0) {
cl_program p;
char *program_log;
size_t log_size;
int err;
p = clCreateProgramWithBinary(ctx, 1, &dev, &bin_size, (const unsigned char**)&program_buffer, NULL, &err);
if(err < 0) {
GGML_LOG_ERROR("OpenCL error creating program from binary");
exit(1);
}
err = clBuildProgram(p, 0, NULL, compile_opts.c_str(), NULL, NULL);
if(err < 0) {
clGetProgramBuildInfo(p, dev, CL_PROGRAM_BUILD_LOG, 0, NULL, &log_size);
program_log = (char*) malloc(log_size + 1);
program_log[log_size] = '\0';
clGetProgramBuildInfo(p, dev, CL_PROGRAM_BUILD_LOG, log_size + 1, program_log, NULL);
GGML_LOG_ERROR("ggml_opencl: kernel compile error:\n\n%s\n", program_log);
free(program_log);
exit(1);
}
return p;
}
static void load_cl_kernels_argsort(ggml_backend_opencl_context *backend_ctx) {
// compiler options for general kernels
auto opencl_c_std =
@@ -1014,6 +1072,17 @@ static void load_cl_kernels_argsort(ggml_backend_opencl_context *backend_ctx) {
}
}
static bool use_adreno_bin_kernels(ggml_backend_opencl_context * backend_ctx) {
#ifndef GGML_OPENCL_USE_ADRENO_BIN_KERNELS
return false;
#else
if (backend_ctx->gpu_family != GPU_FAMILY::ADRENO) {
return false;
}
return backend_ctx->adreno_use_bin_kernels;
#endif // GGML_OPENCL_USE_ADRENO_BIN_KERNELS
}
static void load_cl_kernels(ggml_backend_opencl_context *backend_ctx) {
if (backend_ctx->kernels_loaded) {
return;
@@ -3323,6 +3392,24 @@ static void load_cl_kernels(ggml_backend_opencl_context *backend_ctx) {
GGML_LOG_CONT(".");
}
// gemm_noshuffle_q8_0_f32_bin
{
size_t bin_size = 0;
backend_ctx->kernel_gemm_noshuffle_q8_0_f32_bin = nullptr;
if (use_adreno_bin_kernels(backend_ctx)) {
const char * kernel_bin = (const char *)backend_ctx->get_adreno_bin_kernel("gemm_noshuffle_q8_0_f32_ila", &bin_size);
if (kernel_bin && bin_size > 0) {
cl_program prog =
build_program_from_binary(backend_ctx->context, backend_ctx->device, kernel_bin, compile_opts, bin_size);
CL_CHECK((backend_ctx->kernel_gemm_noshuffle_q8_0_f32_bin = clCreateKernel(prog, "kernel_gemm_noshuffle_q8_0_f32_ila", &err), err));
CL_CHECK(clReleaseProgram(prog));
GGML_LOG_CONT(".");
}
}
}
// gemv_noshuffle_general_q8_0_f32
{
std::string CL_gemv_compile_opts = std::string("-cl-std=") + opencl_c_std +
@@ -3424,6 +3511,24 @@ static void load_cl_kernels(ggml_backend_opencl_context *backend_ctx) {
GGML_LOG_CONT(".");
}
// gemm_moe_q4_1_f32_ns_bin
{
size_t bin_size = 0;
backend_ctx->kernel_gemm_moe_q4_1_f32_ns_bin = nullptr;
if (use_adreno_bin_kernels(backend_ctx)) {
const char * kernel_bin = (const char *)backend_ctx->get_adreno_bin_kernel("gemm_moe_q4_1_f32_ns_ila", &bin_size);
if (kernel_bin && bin_size > 0) {
cl_program prog =
build_program_from_binary(backend_ctx->context, backend_ctx->device, kernel_bin, CL_moe_compile_opts, bin_size);
CL_CHECK((backend_ctx->kernel_gemm_moe_q4_1_f32_ns_bin = clCreateKernel(prog, "kernel_gemm_moe_q4_1_f32_ns_ila", &err), err));
CL_CHECK(clReleaseProgram(prog));
GGML_LOG_CONT(".");
}
}
}
// gemv_moe_mxfp4_f32
{
#ifdef GGML_OPENCL_EMBED_KERNELS
@@ -3490,6 +3595,24 @@ static void load_cl_kernels(ggml_backend_opencl_context *backend_ctx) {
GGML_LOG_CONT(".");
}
// gemm_moe_q4_0_f32_ns_bin
{
size_t bin_size = 0;
backend_ctx->kernel_gemm_moe_q4_0_f32_ns_bin = nullptr;
if (use_adreno_bin_kernels(backend_ctx)) {
const char * kernel_bin = (const char *)backend_ctx->get_adreno_bin_kernel("gemm_moe_q4_0_f32_ns_ila", &bin_size);
if (kernel_bin && bin_size > 0) {
cl_program prog =
build_program_from_binary(backend_ctx->context, backend_ctx->device, kernel_bin, CL_moe_compile_opts, bin_size);
CL_CHECK((backend_ctx->kernel_gemm_moe_q4_0_f32_ns_bin = clCreateKernel(prog, "kernel_gemm_moe_q4_0_f32_ns_ila", &err), err));
CL_CHECK(clReleaseProgram(prog));
GGML_LOG_CONT(".");
}
}
}
// gemv_moe_q5_0_f32_ns
{
#ifdef GGML_OPENCL_EMBED_KERNELS
@@ -3592,6 +3715,24 @@ static void load_cl_kernels(ggml_backend_opencl_context *backend_ctx) {
GGML_LOG_CONT(".");
}
// gemm_moe_q4_k_f32_ns_bin
{
size_t bin_size = 0;
backend_ctx->kernel_gemm_moe_q4_k_f32_ns_bin = nullptr;
if (use_adreno_bin_kernels(backend_ctx)) {
const char * kernel_bin = (const char *)backend_ctx->get_adreno_bin_kernel("gemm_moe_q4_k_f32_ns_ila", &bin_size);
if (kernel_bin && bin_size > 0) {
cl_program prog =
build_program_from_binary(backend_ctx->context, backend_ctx->device, kernel_bin, CL_moe_compile_opts, bin_size);
CL_CHECK((backend_ctx->kernel_gemm_moe_q4_k_f32_ns_bin = clCreateKernel(prog, "kernel_gemm_moe_q4_k_f32_ns_ila", &err), err));
CL_CHECK(clReleaseProgram(prog));
GGML_LOG_CONT(".");
}
}
}
// gemv_moe_q5_k_f32_ns
{
#ifdef GGML_OPENCL_EMBED_KERNELS
@@ -3689,9 +3830,27 @@ static void load_cl_kernels(ggml_backend_opencl_context *backend_ctx) {
cl_program prog =
build_program_from_source(backend_ctx->context, backend_ctx->device, kernel_src.c_str(), CL_moe_compile_opts);
CL_CHECK((backend_ctx->kernel_gemm_moe_mxfp4_f32_ns = clCreateKernel(prog, "kernel_gemm_moe_mxfp4_f32_ns", &err), err));
CL_CHECK(clReleaseProgram(prog));
GGML_LOG_CONT(".");
CL_CHECK((backend_ctx->kernel_gemm_moe_mxfp4_f32_ns = clCreateKernel(prog, "kernel_gemm_moe_mxfp4_f32_ns", &err), err));
CL_CHECK(clReleaseProgram(prog));
GGML_LOG_CONT(".");
}
// gemm_moe_mxfp4_f32_ns_bin
{
size_t bin_size = 0;
backend_ctx->kernel_gemm_moe_mxfp4_f32_ns_bin = nullptr;
if (use_adreno_bin_kernels(backend_ctx)) {
const char * kernel_bin = (const char *)backend_ctx->get_adreno_bin_kernel("gemm_moe_mxfp4_f32_ns_ila", &bin_size);
if (kernel_bin && bin_size > 0) {
cl_program prog =
build_program_from_binary(backend_ctx->context, backend_ctx->device, kernel_bin, CL_moe_compile_opts, bin_size);
CL_CHECK((backend_ctx->kernel_gemm_moe_mxfp4_f32_ns_bin = clCreateKernel(prog, "kernel_gemm_moe_mxfp4_f32_ns_ila", &err), err));
CL_CHECK(clReleaseProgram(prog));
GGML_LOG_CONT(".");
}
}
}
// moe_reorder_b
@@ -4770,6 +4929,27 @@ static ggml_backend_opencl_context * ggml_cl_init(ggml_backend_dev_t dev) {
backend_ctx->adreno_use_large_buffer = getenv("GGML_OPENCL_ADRENO_USE_LARGE_BUFFER") != nullptr &&
backend_ctx->gpu_family == GPU_FAMILY::ADRENO;
#ifdef GGML_OPENCL_USE_ADRENO_BIN_KERNELS
// try loading adreno binary kernels if enabled
// if fails to load, builtin kernels will be used
{
dl_handle * kernel_lib_handle = dl_load_library(KERNEL_LIB_NAME);
backend_ctx->adreno_use_bin_kernels = false;
if (kernel_lib_handle) {
backend_ctx->get_adreno_bin_kernel_func = (get_adreno_bin_kernel_func_t)dl_get_sym(kernel_lib_handle, "get_adreno_kernels");
if (backend_ctx->get_adreno_bin_kernel_func) {
GGML_LOG_INFO("ggml_opencl: loaded bin kernel library %s\n", KERNEL_LIB_NAME);
backend_ctx->adreno_use_bin_kernels = true;
} else {
GGML_LOG_INFO("ggml_opencl: bin kernel library %s is invalid, will use builtin kernels\n", KERNEL_LIB_NAME);
}
} else {
GGML_LOG_INFO("ggml_opencl: failed to load %s, will use builtin kernels\n", KERNEL_LIB_NAME);
}
}
#endif // GGML_OPENCL_USE_ADRENO_BIN_KERNELS
cl_int err;
// A local ref of cl_context for convenience
@@ -14972,6 +15152,99 @@ static void ggml_cl_mul_mat_q8_0_f32_adreno(ggml_backend_t backend, const ggml_t
CL_CHECK(clReleaseMemObject(b_img));
CL_CHECK(clReleaseMemObject(b_sub_buf));
} else {
// use bin kernel if available
if (backend_ctx->kernel_gemm_noshuffle_q8_0_f32_bin) {
int K_pad = K;
cl_mem b_sub_buf = nullptr;
cl_mem d_sub_buf = nullptr;
cl_mem a_img = nullptr;
cl_mem s_img = nullptr;
cl_mem b_img = nullptr;
cl_mem d_img = nullptr;
// subbuffer for activations
region.origin = offset1;
region.size = K_pad * N * sizeof(float);
CL_CHECK((b_sub_buf = clCreateSubBuffer(extra1->data_device, 0, CL_BUFFER_CREATE_TYPE_REGION, &region, &err), err));
// Create subbuffer and image1d_buffer for dst
region.origin = (extrad->offset); // + dst->view_offs;
region.size = M * N * sizeof(float);
CL_CHECK((d_sub_buf = clCreateSubBuffer((extrad->data_device), 0, CL_BUFFER_CREATE_TYPE_REGION, &region, &err), err));
// create an image for A
img_fmt = { CL_R, CL_FLOAT};
memset(&img_desc, 0, sizeof(img_desc));
img_desc.image_type = CL_MEM_OBJECT_IMAGE1D_BUFFER;
img_desc.image_width = M * K / 4; // Divide by 4 for char -> float
img_desc.buffer = extra0_q8_0->q;
CL_CHECK((a_img = clCreateImage(context, CL_MEM_READ_ONLY, &img_fmt, &img_desc, NULL, &err), err));
// create an image for Scale
img_fmt = { CL_R, CL_HALF_FLOAT};
memset(&img_desc, 0, sizeof(img_desc));
img_desc.image_type = CL_MEM_OBJECT_IMAGE1D_BUFFER;
img_desc.image_width = M * K / 32; // Block size is 32
img_desc.buffer = extra0_q8_0->d;
CL_CHECK((s_img = clCreateImage(context, CL_MEM_READ_ONLY, &img_fmt, &img_desc, NULL, &err), err));
// create an image for B from sub_buffer
img_fmt = {CL_R, CL_FLOAT};
memset(&img_desc, 0, sizeof(img_desc));
img_desc.image_type = CL_MEM_OBJECT_IMAGE1D_BUFFER;
img_desc.image_width = K_pad * N;
img_desc.buffer = b_sub_buf;
CL_CHECK((b_img = clCreateImage(context, CL_MEM_READ_ONLY, &img_fmt, &img_desc, NULL, &err), err));
// img for d
img_fmt = {CL_R, CL_FLOAT};
memset(&img_desc, 0, sizeof(img_desc));
img_desc.image_type = CL_MEM_OBJECT_IMAGE1D_BUFFER;
img_desc.image_width = M * N;
img_desc.buffer = d_sub_buf;
CL_CHECK((d_img = clCreateImage(context, CL_MEM_WRITE_ONLY, &img_fmt, &img_desc, NULL, &err), err));
// gemm
kernel = backend_ctx->kernel_gemm_noshuffle_q8_0_f32_bin;
bool layoutA_Mfirst = true;
bool layoutS_Mfirst = true;
bool layoutB_Nfirst = false;
bool layoutC_Mfirst = true;
cl_uint lineStrideMatrixAinBytes = layoutA_Mfirst ? M * 4 : K; // int8
cl_uint lineStrideMatrixSinBytes = layoutS_Mfirst ? M * 2 : (K / 32) * 2; // fp16
cl_uint lineStrideMatrixBinBytes = layoutB_Nfirst ? N * 4 : K_pad * 4; // fp32
cl_uint lineStrideMatrixCinBytes = layoutC_Mfirst ? M * 4 : N * 4; // fp32
CL_CHECK(clSetKernelArg(kernel, 0, sizeof(cl_mem), &a_img));
CL_CHECK(clSetKernelArg(kernel, 1, sizeof(cl_mem), &s_img));
CL_CHECK(clSetKernelArg(kernel, 2, sizeof(cl_mem), &b_img));
CL_CHECK(clSetKernelArg(kernel, 3, sizeof(int), &extra1->offset));
CL_CHECK(clSetKernelArg(kernel, 4, sizeof(cl_mem), &d_img));
CL_CHECK(clSetKernelArg(kernel, 5, sizeof(int), &extrad->offset));
CL_CHECK(clSetKernelArg(kernel, 6, sizeof(int), &K));
CL_CHECK(clSetKernelArg(kernel, 7, sizeof(int), &lineStrideMatrixAinBytes));
CL_CHECK(clSetKernelArg(kernel, 8, sizeof(int), &lineStrideMatrixSinBytes));
CL_CHECK(clSetKernelArg(kernel, 9, sizeof(int), &lineStrideMatrixBinBytes));
CL_CHECK(clSetKernelArg(kernel, 10, sizeof(int), &lineStrideMatrixCinBytes));
size_t global_work_size[] = { 64, (size_t)CEIL_DIV(M, 64), (size_t)CEIL_DIV(N, 64)};
size_t local_work_size[] = { 64, 2, 2 };
backend_ctx->enqueue_ndrange_kernel(kernel, 3, global_work_size, local_work_size, dst);
CL_CHECK(clReleaseMemObject(b_sub_buf));
CL_CHECK(clReleaseMemObject(d_sub_buf));
CL_CHECK(clReleaseMemObject(a_img));
CL_CHECK(clReleaseMemObject(s_img));
CL_CHECK(clReleaseMemObject(b_img));
CL_CHECK(clReleaseMemObject(d_img));
return;
}
cl_mem b_sub_buf = nullptr;
cl_mem b_sub_buf_trans = nullptr;
cl_mem b_img = nullptr;
@@ -17825,6 +18098,9 @@ static void ggml_cl_mul_mat_id(ggml_backend_t backend, const ggml_tensor * src0,
} else { // for gemm
kernel = backend_ctx->kernel_gemm_moe_q4_0_f32_ns;
if (backend_ctx->kernel_gemm_moe_q4_0_f32_ns_bin) {
kernel = backend_ctx->kernel_gemm_moe_q4_0_f32_ns_bin;
}
// Reorder router if called from test-backend-ops or when new router is generated.
// Otherwise reuse the reordered result from previous mul_mat_id call.
@@ -17870,6 +18146,11 @@ static void ggml_cl_mul_mat_id(ggml_backend_t backend, const ggml_tensor * src0,
cl_image_desc image_desc_buf_src1;
image_format_buf_src1 = {CL_RGBA, CL_FLOAT};
image_desc_buf_src1 = {CL_MEM_OBJECT_IMAGE1D_BUFFER, static_cast<size_t>(ne00 * max_post_router_tile * n_tile_size / 4), 0,0,0,0,0,0,0, {buf_src1_reordered}};
if (backend_ctx->kernel_gemm_moe_q4_0_f32_ns_bin) {
// bin kernel uses slightly different image format
image_format_buf_src1 = {CL_R, CL_FLOAT};
image_desc_buf_src1.image_width = static_cast<size_t>(ne00 * max_post_router_tile * n_tile_size);
}
image_src1_reordered = clCreateImage(backend_ctx->context, CL_MEM_READ_ONLY, &image_format_buf_src1, &image_desc_buf_src1, NULL, &status);
CL_CHECK(status);
@@ -18042,6 +18323,9 @@ static void ggml_cl_mul_mat_id(ggml_backend_t backend, const ggml_tensor * src0,
} else { // for gemm
kernel = backend_ctx->kernel_gemm_moe_q4_1_f32_ns;
if (backend_ctx->kernel_gemm_moe_q4_1_f32_ns_bin) {
kernel = backend_ctx->kernel_gemm_moe_q4_1_f32_ns_bin;
}
// Reorder router if called from test-backend-ops or when new router is generated.
// Otherwise reuse the reordered result from previous mul_mat_id call.
@@ -18087,6 +18371,11 @@ static void ggml_cl_mul_mat_id(ggml_backend_t backend, const ggml_tensor * src0,
cl_image_desc image_desc_buf_src1;
image_format_buf_src1 = {CL_RGBA, CL_FLOAT};
image_desc_buf_src1 = {CL_MEM_OBJECT_IMAGE1D_BUFFER, static_cast<size_t>(ne00 * max_post_router_tile * n_tile_size / 4), 0,0,0,0,0,0,0, {buf_src1_reordered}};
if (backend_ctx->kernel_gemm_moe_q4_1_f32_ns_bin) {
// bin kernel uses slightly different image format
image_format_buf_src1 = {CL_R, CL_FLOAT};
image_desc_buf_src1.image_width = static_cast<size_t>(ne00 * max_post_router_tile * n_tile_size);
}
image_src1_reordered = clCreateImage(backend_ctx->context, CL_MEM_READ_ONLY, &image_format_buf_src1, &image_desc_buf_src1, NULL, &status);
CL_CHECK(status);
@@ -18648,6 +18937,9 @@ static void ggml_cl_mul_mat_id(ggml_backend_t backend, const ggml_tensor * src0,
} else { // for gemm
kernel = backend_ctx->kernel_gemm_moe_q4_k_f32_ns;
if (backend_ctx->kernel_gemm_moe_q4_k_f32_ns_bin) {
kernel = backend_ctx->kernel_gemm_moe_q4_k_f32_ns_bin;
}
// Reorder router if called from test-backend-ops or when new router is generated.
// Otherwise reuse the reordered result from previous mul_mat_id call.
@@ -18689,6 +18981,11 @@ static void ggml_cl_mul_mat_id(ggml_backend_t backend, const ggml_tensor * src0,
CL_CHECK(status);
cl_image_format image_format_buf_src1 = {CL_RGBA, CL_FLOAT};
cl_image_desc image_desc_buf_src1 = {CL_MEM_OBJECT_IMAGE1D_BUFFER, static_cast<size_t>(ne00 * max_post_router_tile * n_tile_size / 4), 0,0,0,0,0,0,0, {buf_src1_reordered}};
if (backend_ctx->kernel_gemm_moe_q4_k_f32_ns_bin) {
// bin kernel uses slightly different image format
image_format_buf_src1 = {CL_R, CL_FLOAT};
image_desc_buf_src1.image_width = static_cast<size_t>(ne00 * max_post_router_tile * n_tile_size);
}
image_src1_reordered = clCreateImage(backend_ctx->context, CL_MEM_READ_ONLY, &image_format_buf_src1, &image_desc_buf_src1, NULL, &status);
CL_CHECK(status);
@@ -19172,6 +19469,9 @@ static void ggml_cl_mul_mat_id(ggml_backend_t backend, const ggml_tensor * src0,
} else { // for gemm
kernel = backend_ctx->kernel_gemm_moe_mxfp4_f32_ns;
if (backend_ctx->kernel_gemm_moe_mxfp4_f32_ns_bin) {
kernel = backend_ctx->kernel_gemm_moe_mxfp4_f32_ns_bin;
}
// Reorder router if called from test-backend-ops or when new router is generated.
// Otherwise reuse the reordered result from previous mul_mat_id call.
@@ -19218,6 +19518,11 @@ static void ggml_cl_mul_mat_id(ggml_backend_t backend, const ggml_tensor * src0,
cl_image_desc image_desc_buf_src1;
image_format_buf_src1 = {CL_RGBA, CL_FLOAT};
image_desc_buf_src1 = {CL_MEM_OBJECT_IMAGE1D_BUFFER, static_cast<size_t>(ne00 * max_post_router_tile * n_tile_size / 4), 0,0,0,0,0,0,0, {buf_src1_reordered}};
if (backend_ctx->kernel_gemm_moe_mxfp4_f32_ns_bin) {
// bin kernel uses slightly different image format
image_format_buf_src1 = {CL_R, CL_FLOAT};
image_desc_buf_src1.image_width = static_cast<size_t>(ne00 * max_post_router_tile * n_tile_size);
}
image_src1_reordered = clCreateImage(backend_ctx->context, CL_MEM_READ_ONLY, &image_format_buf_src1, &image_desc_buf_src1, NULL, &status);
CL_CHECK(status);
+79
View File
@@ -0,0 +1,79 @@
#pragma once
#ifdef _WIN32
# define WIN32_LEAN_AND_MEAN
# ifndef NOMINMAX
# define NOMINMAX
# endif
# include <windows.h>
# include <winevt.h>
#else
# include <dlfcn.h>
# include <unistd.h>
#endif
#include <filesystem>
namespace fs = std::filesystem;
#ifdef _WIN32
using dl_handle = std::remove_pointer_t<HMODULE>;
struct dl_handle_deleter {
void operator()(HMODULE handle) {
FreeLibrary(handle);
}
};
static inline dl_handle * dl_load_library(const fs::path & path) {
// suppress error dialogs for missing DLLs
DWORD old_mode = SetErrorMode(SEM_FAILCRITICALERRORS);
SetErrorMode(old_mode | SEM_FAILCRITICALERRORS);
HMODULE handle = LoadLibraryW(path.wstring().c_str());
SetErrorMode(old_mode);
return handle;
}
static inline void * dl_get_sym(dl_handle * handle, const char * name) {
DWORD old_mode = SetErrorMode(SEM_FAILCRITICALERRORS);
SetErrorMode(old_mode | SEM_FAILCRITICALERRORS);
void * p = (void *) GetProcAddress(handle, name);
SetErrorMode(old_mode);
return p;
}
static inline const char * dl_error() {
return "";
}
#else
using dl_handle = void;
struct dl_handle_deleter {
void operator()(void * handle) {
dlclose(handle);
}
};
static inline dl_handle * dl_load_library(const fs::path & path) {
dl_handle * handle = dlopen(path.string().c_str(), RTLD_NOW | RTLD_LOCAL);
return handle;
}
static inline void * dl_get_sym(dl_handle * handle, const char * name) {
return dlsym(handle, name);
}
static inline const char * dl_error() {
const char *rslt = dlerror();
return rslt != nullptr ? rslt : "";
}
#endif
+6
View File
@@ -159,6 +159,9 @@ extern "C" {
LLAMA_FTYPE_GUESSED = 1024, // not specified in the model file
};
// Get the model file type (quantization) as a string, e.g. "Q8_0" or "Q4_K - Medium"
LLAMA_API const char * llama_ftype_name(enum llama_ftype ftype);
enum llama_rope_scaling_type {
LLAMA_ROPE_SCALING_TYPE_UNSPECIFIED = -1,
LLAMA_ROPE_SCALING_TYPE_NONE = 0,
@@ -606,6 +609,9 @@ extern "C" {
// Get a string describing the model type
LLAMA_API int32_t llama_model_desc(const struct llama_model * model, char * buf, size_t buf_size);
// Get the model file type (quantization), e.g. LLAMA_FTYPE_MOSTLY_Q8_0
LLAMA_API enum llama_ftype llama_model_ftype(const struct llama_model * model);
// Returns the total size of all the tensors in the model in bytes
LLAMA_API uint64_t llama_model_size(const struct llama_model * model);
@@ -1,80 +0,0 @@
{% macro render_content(content) %}{% if content is none %}{{- '' }}{% elif content is string %}{{- content }}{% elif content is mapping %}{{- content['value'] if 'value' in content else content['text'] }}{% elif content is iterable %}{% for item in content %}{% if item.type == 'text' %}{{- item['value'] if 'value' in item else item['text'] }}{% elif item.type == 'image' %}<im_patch>{% endif %}{% endfor %}{% endif %}{% endmacro %}
{{bos_token}}{%- if tools %}
{{- '<|im_start|>system\n' }}
{%- if messages[0].role == 'system' %}
{{- render_content(messages[0].content) + '\n\n' }}
{%- endif %}
{{- "# Tools\n\nYou have access to the following functions in JSONSchema format:\n\n<tools>" }}
{%- for tool in tools %}
{{- "\n" }}
{{- tool | tojson(ensure_ascii=False) }}
{%- endfor %}
{{- "\n</tools>\n\nIf you choose to call a function ONLY reply in the following format with NO suffix:\n\n<tool_call>\n<function=example_function_name>\n<parameter=example_parameter_1>\nvalue_1\n</parameter>\n<parameter=example_parameter_2>\nThis is the value for the second parameter\nthat can span\nmultiple lines\n</parameter>\n</function>\n</tool_call>\n\n<IMPORTANT>\nReminder:\n- Function calls MUST follow the specified format: an inner <function=...>\n...\n</function> block must be nested within <tool_call>\n...\n</tool_call> XML tags\n- Required parameters MUST be specified\n</IMPORTANT><|im_end|>\n" }}
{%- else %}
{%- if messages[0].role == 'system' %}
{{- '<|im_start|>system\n' + render_content(messages[0].content) + '<|im_end|>\n' }}
{%- endif %}
{%- endif %}
{%- set ns = namespace(multi_step_tool=true, last_query_index=messages|length - 1) %}
{%- for message in messages[::-1] %}
{%- set index = (messages|length - 1) - loop.index0 %}
{%- if ns.multi_step_tool and message.role == "user" and render_content(message.content) is string and not(render_content(message.content).startswith('<tool_response>') and render_content(message.content).endswith('</tool_response>')) %}
{%- set ns.multi_step_tool = false %}
{%- set ns.last_query_index = index %}
{%- endif %}
{%- endfor %}
{%- for message in messages %}
{%- set content = render_content(message.content) %}
{%- if (message.role == "user") or (message.role == "system" and not loop.first) %}
{%- set role_name = 'observation' if (message.role == "system" and not loop.first and message.name == 'observation') else message.role %}
{{- '<|im_start|>' + role_name + '\n' + content + '<|im_end|>' + '\n' }}
{%- elif message.role == "assistant" %}
{%- if message.reasoning_content is string %}
{%- set reasoning_content = render_content(message.reasoning_content) %}
{%- else %}
{%- if '</think>' in content %}
{%- set reasoning_content = content.split('</think>')[0].rstrip('\n').split('<think>')[-1].lstrip('\n') %}
{%- set content = content.split('</think>')[-1].lstrip('\n') %}
{%- else %}
{%- set reasoning_content = '' %}
{%- endif %}
{%- endif %}
{%- if loop.index0 > ns.last_query_index %}
{{- '<|im_start|>' + message.role + '\n<think>\n' + reasoning_content + '\n</think>\n' + content }}
{%- else %}
{{- '<|im_start|>' + message.role + '\n' + content }}
{%- endif %}
{%- if message.tool_calls %}
{%- for tool_call in message.tool_calls %}
{%- if tool_call.function is defined %}
{%- set tool_call = tool_call.function %}
{%- endif %}
{{- '<tool_call>\n<function=' + tool_call.name + '>\n' }}
{%- if tool_call.arguments is defined %}
{%- set arguments = tool_call.arguments %}
{%- for args_name, args_value in arguments|items %}
{{- '<parameter=' + args_name + '>\n' }}
{%- set args_value = args_value | tojson(ensure_ascii=False) | safe if args_value is mapping or (args_value is sequence and args_value is not string) else args_value | string %}
{{- args_value }}
{{- '\n</parameter>\n' }}
{%- endfor %}
{%- endif %}
{{- '</function>\n</tool_call>' }}
{%- endfor %}
{%- endif %}
{{- '<|im_end|>\n' }}
{%- elif message.role == "tool" %}
{%- if loop.first or (messages[loop.index0 - 1].role != "tool") %}
{{- '<|im_start|>tool_response\n' }}
{%- endif %}
{{- '<tool_response>' }}
{{- content }}
{{- '</tool_response>' }}
{%- if loop.last or (messages[loop.index0 + 1].role != "tool") %}
{{- '<|im_end|>\n' }}
{%- endif %}
{%- endif %}
{%- endfor %}
{%- if add_generation_prompt %}
{{- '<|im_start|>assistant\n<think>\n' }}
{%- endif %}
+1 -1
View File
@@ -5,7 +5,7 @@ import os
import sys
import subprocess
HTTPLIB_VERSION = "refs/tags/v0.48.0"
HTTPLIB_VERSION = "refs/tags/v0.49.0"
vendor = {
"https://github.com/nlohmann/json/releases/latest/download/json.hpp": "vendor/nlohmann/json.hpp",
+6 -6
View File
@@ -494,11 +494,11 @@ void llm_graph_input_attn_kv::set_input(const llama_ubatch * ubatch) {
mctx->set_input_kq_mask(self_kq_mask, ubatch, cparams.causal_attn);
}
if (self_k_rot) {
if (self_k_rot && self_k_rot->buffer) {
mctx->set_input_k_rot(self_k_rot);
}
if (self_v_rot) {
if (self_v_rot && self_v_rot->buffer) {
mctx->set_input_v_rot(self_v_rot);
}
}
@@ -592,19 +592,19 @@ void llm_graph_input_attn_kv_iswa::set_input(const llama_ubatch * ubatch) {
mctx->get_swa()->set_input_kq_mask(self_kq_mask_swa, ubatch, cparams.causal_attn);
}
if (self_k_rot) {
if (self_k_rot && self_k_rot->buffer) {
mctx->get_base()->set_input_k_rot(self_k_rot);
}
if (self_v_rot) {
if (self_v_rot && self_v_rot->buffer) {
mctx->get_base()->set_input_v_rot(self_v_rot);
}
if (self_k_rot_swa) {
if (self_k_rot_swa && self_k_rot_swa->buffer) {
mctx->get_swa()->set_input_k_rot(self_k_rot_swa);
}
if (self_v_rot_swa) {
if (self_v_rot_swa && self_v_rot_swa->buffer) {
mctx->get_swa()->set_input_v_rot(self_v_rot_swa);
}
}
+46 -44
View File
@@ -27,52 +27,54 @@ const char * llama_file_version_name(llama_fver version) {
return "unknown";
}
static std::string llama_model_ftype_name(llama_ftype ftype) {
if (ftype & LLAMA_FTYPE_GUESSED) {
return llama_model_ftype_name((enum llama_ftype) (ftype & ~LLAMA_FTYPE_GUESSED)) + " (guessed)";
}
#define LLAMA_FTYPE_PREFIX "(guessed) "
switch (ftype) {
case LLAMA_FTYPE_ALL_F32: return "all F32";
case LLAMA_FTYPE_MOSTLY_F16: return "F16";
case LLAMA_FTYPE_MOSTLY_BF16: return "BF16";
case LLAMA_FTYPE_MOSTLY_Q1_0: return "Q1_0";
case LLAMA_FTYPE_MOSTLY_Q4_0: return "Q4_0";
case LLAMA_FTYPE_MOSTLY_Q4_1: return "Q4_1";
case LLAMA_FTYPE_MOSTLY_Q5_0: return "Q5_0";
case LLAMA_FTYPE_MOSTLY_Q5_1: return "Q5_1";
case LLAMA_FTYPE_MOSTLY_Q8_0: return "Q8_0";
case LLAMA_FTYPE_MOSTLY_MXFP4_MOE: return "MXFP4 MoE";
case LLAMA_FTYPE_MOSTLY_NVFP4: return "NVFP4";
case LLAMA_FTYPE_MOSTLY_Q2_K: return "Q2_K - Medium";
case LLAMA_FTYPE_MOSTLY_Q2_K_S: return "Q2_K - Small";
case LLAMA_FTYPE_MOSTLY_Q3_K_S: return "Q3_K - Small";
case LLAMA_FTYPE_MOSTLY_Q3_K_M: return "Q3_K - Medium";
case LLAMA_FTYPE_MOSTLY_Q3_K_L: return "Q3_K - Large";
case LLAMA_FTYPE_MOSTLY_Q4_K_S: return "Q4_K - Small";
case LLAMA_FTYPE_MOSTLY_Q4_K_M: return "Q4_K - Medium";
case LLAMA_FTYPE_MOSTLY_Q5_K_S: return "Q5_K - Small";
case LLAMA_FTYPE_MOSTLY_Q5_K_M: return "Q5_K - Medium";
case LLAMA_FTYPE_MOSTLY_Q6_K: return "Q6_K";
case LLAMA_FTYPE_MOSTLY_TQ1_0: return "TQ1_0 - 1.69 bpw ternary";
case LLAMA_FTYPE_MOSTLY_TQ2_0: return "TQ2_0 - 2.06 bpw ternary";
case LLAMA_FTYPE_MOSTLY_IQ2_XXS: return "IQ2_XXS - 2.0625 bpw";
case LLAMA_FTYPE_MOSTLY_IQ2_XS: return "IQ2_XS - 2.3125 bpw";
case LLAMA_FTYPE_MOSTLY_IQ2_S: return "IQ2_S - 2.5 bpw";
case LLAMA_FTYPE_MOSTLY_IQ2_M: return "IQ2_M - 2.7 bpw";
case LLAMA_FTYPE_MOSTLY_IQ3_XS: return "IQ3_XS - 3.3 bpw";
case LLAMA_FTYPE_MOSTLY_IQ3_XXS: return "IQ3_XXS - 3.0625 bpw";
case LLAMA_FTYPE_MOSTLY_IQ1_S: return "IQ1_S - 1.5625 bpw";
case LLAMA_FTYPE_MOSTLY_IQ1_M: return "IQ1_M - 1.75 bpw";
case LLAMA_FTYPE_MOSTLY_IQ4_NL: return "IQ4_NL - 4.5 bpw";
case LLAMA_FTYPE_MOSTLY_IQ4_XS: return "IQ4_XS - 4.25 bpw";
case LLAMA_FTYPE_MOSTLY_IQ3_S: return "IQ3_S - 3.4375 bpw";
case LLAMA_FTYPE_MOSTLY_IQ3_M: return "IQ3_S mix - 3.66 bpw";
default: return "unknown, may not work";
const char * llama_ftype_name(llama_ftype ftype) {
static constexpr size_t guessed_prefix_len = sizeof(LLAMA_FTYPE_PREFIX) - 1;
const char * name;
switch ((enum llama_ftype) (ftype & ~LLAMA_FTYPE_GUESSED)) {
case LLAMA_FTYPE_ALL_F32: name = LLAMA_FTYPE_PREFIX "all F32"; break;
case LLAMA_FTYPE_MOSTLY_F16: name = LLAMA_FTYPE_PREFIX "F16"; break;
case LLAMA_FTYPE_MOSTLY_BF16: name = LLAMA_FTYPE_PREFIX "BF16"; break;
case LLAMA_FTYPE_MOSTLY_Q1_0: name = LLAMA_FTYPE_PREFIX "Q1_0"; break;
case LLAMA_FTYPE_MOSTLY_Q4_0: name = LLAMA_FTYPE_PREFIX "Q4_0"; break;
case LLAMA_FTYPE_MOSTLY_Q4_1: name = LLAMA_FTYPE_PREFIX "Q4_1"; break;
case LLAMA_FTYPE_MOSTLY_Q5_0: name = LLAMA_FTYPE_PREFIX "Q5_0"; break;
case LLAMA_FTYPE_MOSTLY_Q5_1: name = LLAMA_FTYPE_PREFIX "Q5_1"; break;
case LLAMA_FTYPE_MOSTLY_Q8_0: name = LLAMA_FTYPE_PREFIX "Q8_0"; break;
case LLAMA_FTYPE_MOSTLY_MXFP4_MOE: name = LLAMA_FTYPE_PREFIX "MXFP4 MoE"; break;
case LLAMA_FTYPE_MOSTLY_NVFP4: name = LLAMA_FTYPE_PREFIX "NVFP4"; break;
case LLAMA_FTYPE_MOSTLY_Q2_K: name = LLAMA_FTYPE_PREFIX "Q2_K - Medium"; break;
case LLAMA_FTYPE_MOSTLY_Q2_K_S: name = LLAMA_FTYPE_PREFIX "Q2_K - Small"; break;
case LLAMA_FTYPE_MOSTLY_Q3_K_S: name = LLAMA_FTYPE_PREFIX "Q3_K - Small"; break;
case LLAMA_FTYPE_MOSTLY_Q3_K_M: name = LLAMA_FTYPE_PREFIX "Q3_K - Medium"; break;
case LLAMA_FTYPE_MOSTLY_Q3_K_L: name = LLAMA_FTYPE_PREFIX "Q3_K - Large"; break;
case LLAMA_FTYPE_MOSTLY_Q4_K_S: name = LLAMA_FTYPE_PREFIX "Q4_K - Small"; break;
case LLAMA_FTYPE_MOSTLY_Q4_K_M: name = LLAMA_FTYPE_PREFIX "Q4_K - Medium"; break;
case LLAMA_FTYPE_MOSTLY_Q5_K_S: name = LLAMA_FTYPE_PREFIX "Q5_K - Small"; break;
case LLAMA_FTYPE_MOSTLY_Q5_K_M: name = LLAMA_FTYPE_PREFIX "Q5_K - Medium"; break;
case LLAMA_FTYPE_MOSTLY_Q6_K: name = LLAMA_FTYPE_PREFIX "Q6_K"; break;
case LLAMA_FTYPE_MOSTLY_TQ1_0: name = LLAMA_FTYPE_PREFIX "TQ1_0 - 1.69 bpw ternary"; break;
case LLAMA_FTYPE_MOSTLY_TQ2_0: name = LLAMA_FTYPE_PREFIX "TQ2_0 - 2.06 bpw ternary"; break;
case LLAMA_FTYPE_MOSTLY_IQ2_XXS: name = LLAMA_FTYPE_PREFIX "IQ2_XXS - 2.0625 bpw"; break;
case LLAMA_FTYPE_MOSTLY_IQ2_XS: name = LLAMA_FTYPE_PREFIX "IQ2_XS - 2.3125 bpw"; break;
case LLAMA_FTYPE_MOSTLY_IQ2_S: name = LLAMA_FTYPE_PREFIX "IQ2_S - 2.5 bpw"; break;
case LLAMA_FTYPE_MOSTLY_IQ2_M: name = LLAMA_FTYPE_PREFIX "IQ2_M - 2.7 bpw"; break;
case LLAMA_FTYPE_MOSTLY_IQ3_XS: name = LLAMA_FTYPE_PREFIX "IQ3_XS - 3.3 bpw"; break;
case LLAMA_FTYPE_MOSTLY_IQ3_XXS: name = LLAMA_FTYPE_PREFIX "IQ3_XXS - 3.0625 bpw"; break;
case LLAMA_FTYPE_MOSTLY_IQ1_S: name = LLAMA_FTYPE_PREFIX "IQ1_S - 1.5625 bpw"; break;
case LLAMA_FTYPE_MOSTLY_IQ1_M: name = LLAMA_FTYPE_PREFIX "IQ1_M - 1.75 bpw"; break;
case LLAMA_FTYPE_MOSTLY_IQ4_NL: name = LLAMA_FTYPE_PREFIX "IQ4_NL - 4.5 bpw"; break;
case LLAMA_FTYPE_MOSTLY_IQ4_XS: name = LLAMA_FTYPE_PREFIX "IQ4_XS - 4.25 bpw"; break;
case LLAMA_FTYPE_MOSTLY_IQ3_S: name = LLAMA_FTYPE_PREFIX "IQ3_S - 3.4375 bpw"; break;
case LLAMA_FTYPE_MOSTLY_IQ3_M: name = LLAMA_FTYPE_PREFIX "IQ3_S mix - 3.66 bpw"; break;
default: name = LLAMA_FTYPE_PREFIX "unknown, may not work"; break;
}
return (ftype & LLAMA_FTYPE_GUESSED) ? name : name + guessed_prefix_len;
}
#undef LLAMA_FTYPE_PREFIX
// return a list of splits for a given path
// for example, given "<name>-00002-of-00004.gguf", returns list of all 4 splits
static std::vector<std::string> llama_get_list_splits(const std::string & path, const int idx, const int n_split) {
@@ -1693,12 +1695,12 @@ bool llama_model_loader::load_all_data(
}
std::string llama_model_loader::ftype_name() const {
return llama_model_ftype_name(ftype);
return llama_ftype_name(ftype);
}
void llama_model_loader::print_info() const {
LLAMA_LOG_INFO("%s: file format = %s\n", __func__, llama_file_version_name(fver));
LLAMA_LOG_INFO("%s: file type = %s\n", __func__, llama_model_ftype_name(ftype).c_str());
LLAMA_LOG_INFO("%s: file type = %s\n", __func__, llama_ftype_name(ftype));
if (n_bytes < GiB) {
LLAMA_LOG_INFO("%s: file size = %.2f MiB (%.2f BPW) \n", __func__, n_bytes/1024.0/1024.0, n_bytes*8.0/n_elements);
} else {
+20
View File
@@ -987,6 +987,8 @@ struct llama_model::impl {
std::string desc_str;
llama_ftype ftype = LLAMA_FTYPE_ALL_F32;
// model memory mapped files
llama_mmaps mappings;
@@ -1010,9 +1012,17 @@ struct llama_model::impl {
std::vector<layer_dev> dev_layer;
bool has_tensor_overrides;
std::vector<float> tensor_split_owned;
};
llama_model::llama_model(const llama_model_params & params) : params(params), pimpl(std::make_unique<impl>()) {
if (params.tensor_split != nullptr) {
// llama_model_params stores tensor_split as a borrowed pointer, but the model
// may need it later for tensor-parallel KV-cache split metadata.
pimpl->tensor_split_owned.assign(params.tensor_split, params.tensor_split + llama_max_devices());
this->params.tensor_split = pimpl->tensor_split_owned.data();
}
pimpl->has_tensor_overrides = params.tensor_buft_overrides && params.tensor_buft_overrides[0].pattern;
}
@@ -1200,6 +1210,8 @@ void llama_model_base::load_hparams(llama_model_loader & ml) {
pimpl->desc_str = arch_name() + " " + type_name() + " " + ml.ftype_name();
pimpl->ftype = ml.ftype;
if (hparams.f_max_alibi_bias > 0.0f) {
hparams.use_alibi = true;
}
@@ -1646,6 +1658,10 @@ std::string llama_model::desc() const {
return pimpl->desc_str;
}
llama_ftype llama_model::ftype() const {
return pimpl->ftype;
}
size_t llama_model::size() const {
return pimpl->n_bytes;
}
@@ -2616,6 +2632,10 @@ int32_t llama_model_desc(const llama_model * model, char * buf, size_t buf_size)
return snprintf(buf, buf_size, "%s", model->desc().c_str());
}
llama_ftype llama_model_ftype(const llama_model * model) {
return model->ftype();
}
uint64_t llama_model_size(const llama_model * model) {
return model->size();
}
+2
View File
@@ -637,6 +637,8 @@ struct llama_model {
std::string desc() const;
llama_ftype ftype() const;
size_t size() const; // file size
size_t n_tensors() const;
size_t n_devices() const;
+7
View File
@@ -8918,6 +8918,12 @@ static std::vector<std::unique_ptr<test_case>> make_test_cases_eval() {
}
}
for (ggml_type type_a : { GGML_TYPE_Q4_0, GGML_TYPE_Q4_1, GGML_TYPE_Q5_0, GGML_TYPE_Q5_1, GGML_TYPE_Q8_0 }) {
for (int dim : { 0, 1, 2, 3, }) {
test_cases.emplace_back(new test_concat(type_a, {128, 12, 13, 14}, dim == 0 ? 256 : 7, dim, 0));
}
}
for (ggml_sort_order order : {GGML_SORT_ORDER_ASC, GGML_SORT_ORDER_DESC}) {
for (uint32_t i = 4; i <= 1024*1024; i *= 2) {
test_cases.emplace_back(new test_argsort(GGML_TYPE_F32, {i-1, 1, 1, 1}));
@@ -9219,6 +9225,7 @@ static std::vector<std::unique_ptr<test_case>> make_test_cases_eval() {
test_cases.emplace_back(new test_topk_moe({128, 1, 1, 1}, 128, with_norm, bias_probs, gate, scale_w));
test_cases.emplace_back(new test_topk_moe({129, 1, 1, 1}, 128, with_norm, bias_probs, gate, scale_w));
test_cases.emplace_back(new test_topk_moe({160, 4, 1, 1}, 160, with_norm, bias_probs, gate, scale_w));
test_cases.emplace_back(new test_topk_moe({288, 22, 1, 1}, 8, with_norm, bias_probs, gate, scale_w)); // Used by StepFun 3.7
}
}
}
-1
View File
@@ -1887,7 +1887,6 @@ static void test_role_markers_all_templates(testing & t) {
{ "Qwen-Qwen3-0.6B.jinja", "<|im_start|>user", "<|im_start|>assistant" },
{ "Qwen-QwQ-32B.jinja", "<|im_start|>user", "<|im_start|>assistant" },
{ "StepFun3.5-Flash.jinja", "<|im_start|>user", "<|im_start|>assistant" },
{ "stepfun-ai-Step-3.5-Flash.jinja", "<|im_start|>user", "<|im_start|>assistant" },
// DeepSeek family
{ "deepseek-ai-DeepSeek-R1-Distill-Llama-8B.jinja", "<User>", "<Assistant>" },
+53
View File
@@ -3155,6 +3155,59 @@ static void test_template_output_peg_parsers(bool detailed_debug) {
}
}
}
{
// StepFun trimming regression test (see https://github.com/ggml-org/llama.cpp/pull/25238)
auto tmpls = read_templates("models/templates/StepFun3.5-Flash.jinja");
common_chat_msg message_chatbot = simple_assist_msg("Let me check.\n\n", "I am thinking.\n\n");
{
common_chat_templates_inputs inputs;
inputs.messages = { message_chatbot };
inputs.add_generation_prompt = true;
auto params = common_chat_templates_apply(tmpls.get(), inputs);
if (params.prompt.find("Let me check.\n\n") != std::string::npos) {
throw std::runtime_error("StepFun 3.5: content not trimmed");
}
if (params.prompt.find("I am thinking.\n\n") != std::string::npos) {
throw std::runtime_error("StepFun 3.5: reasoning_content not trimmed");
}
}
{
// Trimming must also reach typed (text) content parts, not just string content
// (see https://github.com/ggml-org/llama.cpp/pull/25238)
common_chat_msg message_parts;
message_parts.role = "user";
message_parts.content_parts = {
{ /* .type = */ "text", /* .text = */ "First part.\n\n" },
{ /* .type = */ "media_marker", /* .text = */ "<__media__>" },
{ /* .type = */ "text", /* .text = */ "Second part.\n\n" },
};
common_chat_templates_inputs inputs;
inputs.messages = { message_parts };
inputs.add_generation_prompt = true;
auto params = common_chat_templates_apply(tmpls.get(), inputs);
if (params.prompt.find("First part.\n\n") != std::string::npos ||
params.prompt.find("Second part.\n\n") != std::string::npos) {
throw std::runtime_error("StepFun 3.5: text content parts not trimmed");
}
// the trimmed text itself must still be present
if (params.prompt.find("First part.") == std::string::npos ||
params.prompt.find("Second part.") == std::string::npos) {
throw std::runtime_error("StepFun 3.5: text content parts missing after trim");
}
}
}
}
{
+3
View File
@@ -448,6 +448,9 @@ int llama_cli(int argc, char ** argv) {
console::log("%s\n", LLAMA_ASCII_LOGO);
console::log("build : %s\n", inf.build_info.c_str());
console::log("model : %s\n", inf.model_name.c_str());
if (!inf.model_ftype.empty()) {
console::log("ftype : %s\n", inf.model_ftype.c_str());
}
console::log("modalities : %s\n", modalities.c_str());
if (!params.system_prompt.empty()) {
console::log("using custom system prompt\n");
+2
View File
@@ -521,6 +521,8 @@ These words will not be included in the completion, so make sure to add them to
`return_progress`: Include prompt processing progress in `stream` mode. The progress will be contained inside `prompt_progress` with 4 values: `total`, `cache`, `processed`, and `time_ms`. The overall progress is `processed/total`, while the actual timed progress is `(processed-cache)/(total-cache)`. The `time_ms` field contains the elapsed time in milliseconds since prompt processing started. Default: `false`
`sse_ping_interval`: Interval in seconds between SSE comment pings emitted while the stream stays silent, keeping the connection observable during long prompt processing. Overrides the server `--sse-ping-interval` setting for this request, `-1` disables pings. Default: server setting
`post_sampling_probs`: Returns the probabilities of top `n_probs` tokens after applying sampling chain.
`response_fields`: A list of response fields, for example: `"response_fields": ["content", "generation_settings/n_predict"]`. If the specified field is missing, it will simply be omitted from the response without triggering an error. Note that fields with a slash will be unnested; for example, `generation_settings/n_predict` will move the field `n_predict` from the `generation_settings` object to the root of the response and give it a new name.
+10 -3
View File
@@ -3989,6 +3989,8 @@ server_context_meta server_context::get_meta() const {
auto bos_token_str = bos_id != LLAMA_TOKEN_NULL ? common_token_to_piece(impl->ctx_tgt, bos_id, true) : "";
auto eos_token_str = eos_id != LLAMA_TOKEN_NULL ? common_token_to_piece(impl->ctx_tgt, eos_id, true) : "";
const char * ftype_name = llama_ftype_name(llama_model_ftype(impl->model_tgt));
return server_context_meta {
/* build_info */ std::string(llama_build_info()),
/* model_name */ impl->model_name,
@@ -4023,6 +4025,7 @@ server_context_meta server_context::get_meta() const {
/* model_n_embd_inp */ llama_model_n_embd(impl->model_tgt),
/* model_n_params */ llama_model_n_params(impl->model_tgt),
/* model_size */ llama_model_size(impl->model_tgt),
/* model_ftype */ ftype_name,
};
}
@@ -4086,6 +4089,8 @@ std::unique_ptr<server_res_generator> server_routes::handle_completions_impl(
auto & rd = res->rd;
auto & params = this->params;
int32_t sse_ping_interval = params.sse_ping_interval;
try {
std::vector<server_task> tasks;
@@ -4136,6 +4141,7 @@ std::unique_ptr<server_res_generator> server_routes::handle_completions_impl(
task.params.message_spans = task.tokens.find_message_spans(delimiters);
task.id_slot = json_value(data, "id_slot", -1);
sse_ping_interval = task.params.sse_ping_interval;
// OAI-compat
task.params.res_type = res_type;
@@ -4225,7 +4231,7 @@ std::unique_ptr<server_res_generator> server_routes::handle_completions_impl(
}
res->status = 200;
res->content_type = "text/event-stream";
res->next = [res_this = res.get(), res_type, &req, &params](std::string & output) -> bool {
res->next = [res_this = res.get(), res_type, sse_ping_interval, &req](std::string & output) -> bool {
static auto format_error = [](task_response_type res_type, const json & res_json) {
if (res_type == TASK_RESPONSE_TYPE_ANTHROPIC) {
return format_anthropic_sse({
@@ -4274,10 +4280,10 @@ std::unique_ptr<server_res_generator> server_routes::handle_completions_impl(
// receive subsequent results
bool timeout = false;
int64_t start_time = ggml_time_ms();
auto result = rd.next([&timeout, &start_time, &params, &effective_should_stop]() {
auto result = rd.next([&timeout, &start_time, sse_ping_interval, &effective_should_stop]() {
if (effective_should_stop()) {
return true; // should_stop condition met
} else if (params.sse_ping_interval > 0 && ggml_time_ms() - start_time > (int64_t)params.sse_ping_interval * 1000) {
} else if (sse_ping_interval > 0 && ggml_time_ms() - start_time > (int64_t)sse_ping_interval * 1000) {
timeout = true;
return true; // timeout
}
@@ -5118,6 +5124,7 @@ json server_routes::get_model_info() const {
{"n_embd", meta->model_n_embd_inp},
{"n_params", meta->model_n_params},
{"size", meta->model_size},
{"ftype", meta->model_ftype},
}},
};
}
+1
View File
@@ -50,6 +50,7 @@ struct server_context_meta {
int32_t model_n_embd_inp;
uint64_t model_n_params;
uint64_t model_size;
std::string model_ftype;
};
enum server_state {
+5
View File
@@ -37,6 +37,10 @@ std::vector<std::unique_ptr<field>> make_llama_cmpl_schema(const common_params &
add((new field_bool("return_progress", params.return_progress))
->set_desc("Include prompt processing progress events in stream mode"));
add((new field_num("sse_ping_interval", params.sse_ping_interval))
->set_hard_limits(-1, INT32_MAX)
->set_desc("Interval in seconds between SSE comment pings emitted while the stream stays silent, -1 disables pings"));
add((new field_num("n_predict", params.n_predict))
->set_hard_limits(-1, INT32_MAX)
->add_alias("max_completion_tokens")
@@ -504,6 +508,7 @@ task_params eval_llama_cmpl_schema(
params.n_cache_reuse = params_base.n_cache_reuse;
params.cache_prompt = params_base.cache_prompt;
params.antiprompt = params_base.antiprompt;
params.sse_ping_interval = params_base.sse_ping_interval;
// enabling this will output extra debug information in the HTTP responses from the server
params.verbose = params_base.verbosity > 9;
+2
View File
@@ -54,6 +54,8 @@ struct task_params {
bool return_tokens = false;
bool return_progress = false;
int32_t sse_ping_interval = 30; // seconds between SSE comment pings while the stream stays silent, -1 disables
int32_t n_keep = 0; // number of tokens to keep from initial prompt
int32_t n_discard = 0; // number of tokens after n_keep that may be discarded when shifting context, 0 defaults to half
int32_t n_predict = -1; // new tokens to predict
@@ -18,7 +18,7 @@
let mcpSearchQuery = $state('');
let allMcpServers = $derived(mcpStore.getServersSorted());
let mcpServers = $derived(allMcpServers.filter((s) => s.enabled));
let mcpServers = $derived(mcpStore.visibleMcpServers);
let hasMcpServers = $derived(mcpServers.length > 0);
// let hasAnyMcpServers = $derived(allMcpServers.length > 0);
let filteredMcpServers = $derived.by(() => {
@@ -74,9 +74,7 @@
const sheetItemRowClass =
'flex w-full items-center justify-between gap-2 rounded-md px-3 py-2 text-left text-sm transition-colors hover:bg-accent';
function getEnabledMcpServers() {
return mcpStore.getServersSorted().filter((s) => s.enabled);
}
let visibleMcpServers = $derived(mcpStore.visibleMcpServers);
</script>
<div class="flex items-center gap-1 {className}">
@@ -153,13 +151,13 @@
<span class="flex-1">MCP Servers</span>
<span class="text-xs text-muted-foreground">
{getEnabledMcpServers().length} server{getEnabledMcpServers().length !== 1 ? 's' : ''}
{visibleMcpServers.length} server{visibleMcpServers.length !== 1 ? 's' : ''}
</span>
</Collapsible.Trigger>
<Collapsible.Content>
<div class="flex flex-col gap-0.5 pl-4">
{#each getEnabledMcpServers() as server (server.id)}
{#each visibleMcpServers as server (server.id)}
{@const healthState = mcpStore.getHealthCheckState(server.id)}
{@const hasError = healthState.status === HealthCheckStatus.ERROR}
{@const displayName = mcpStore.getServerLabel(server)}
@@ -202,7 +200,7 @@
</button>
{/each}
{#if getEnabledMcpServers().length === 0}
{#if visibleMcpServers.length === 0}
<div class="px-3 py-2 text-center text-sm text-muted-foreground">
No MCP servers configured
</div>
@@ -43,7 +43,7 @@
assistantMessages: number;
messageTypes: string[];
} | null>(null);
let editedContent = $state(message.content);
let editedContent = $derived(message.content);
let rawEditContent = $derived.by(() => {
if (message.role !== MessageRole.ASSISTANT) return undefined;
@@ -1,8 +1,9 @@
<script lang="ts">
import { ChevronDown, ShieldQuestion } from '@lucide/svelte';
import { ChatMessageActionCard } from '$lib/components/app';
import { Button } from '$lib/components/ui/button';
import { Button, buttonVariants } from '$lib/components/ui/button';
import * as ButtonGroup from '$lib/components/ui/button-group';
import { cn } from '$lib/components/ui/utils';
import * as DropdownMenu from '$lib/components/ui/dropdown-menu';
import { ToolSource, ToolPermissionDecision } from '$lib/enums';
import { TOOL_SERVER_LABELS } from '$lib/constants';
@@ -19,25 +20,17 @@
<ChatMessageActionCard icon={ShieldQuestion}>
{#snippet message()}
Allow use of
<span class="font-semibold">{toolName}</span>
{#if serverLabel}
from <span class="font-semibold">{serverLabel}</span>
{/if}
?
Allow use of <span class="font-semibold">{toolName}</span>{#if serverLabel}
from <span class="font-semibold">{serverLabel}</span>{/if}?
{/snippet}
{#snippet actions()}
<DropdownMenu.Root>
<ButtonGroup.Root
class="overflow-hidden rounded-md bg-foreground text-white shadow-sm dark:bg-secondary dark:text-foreground"
>
<ButtonGroup.Root class="overflow-hidden rounded-md shadow-sm">
<Button
class="rounded-none! shadow-none!"
variant="secondary"
size="sm"
class="!rounded-r-none !shadow-none"
onclick={() => onDecision(ToolPermissionDecision.ONCE)}
>
Allow once
@@ -45,10 +38,14 @@
<ButtonGroup.Separator />
<DropdownMenu.Trigger>
<Button size="sm" class="rounded-none! !ps-2 shadow-none!">
<ChevronDown class="h-3.5 w-3.5" />
</Button>
<DropdownMenu.Trigger
class={cn(
buttonVariants({ variant: 'secondary', size: 'sm' }),
'inline-flex cursor-pointer items-center !rounded-l-none !shadow-none !px-2'
)}
aria-label="More allow options"
>
<ChevronDown class="h-3.5 w-3.5" />
</DropdownMenu.Trigger>
</ButtonGroup.Root>
@@ -76,12 +73,7 @@
</DropdownMenu.Content>
</DropdownMenu.Root>
<Button
variant="destructive"
size="sm"
class="text-destructive hover:text-destructive"
onclick={() => onDecision(ToolPermissionDecision.DENY)}
>
<Button variant="destructive" size="sm" onclick={() => onDecision(ToolPermissionDecision.DENY)}>
Deny
</Button>
{/snippet}
@@ -20,9 +20,9 @@
agenticInjectSteeringMessage
} from '$lib/stores/agentic.svelte';
import {
buildSiblingInfoMap,
copyToClipboard,
formatMessageForClipboard,
getMessageSiblings,
hasAgenticContent
} from '$lib/utils';
@@ -169,6 +169,8 @@
});
});
let siblingInfoByMessageId = $derived(buildSiblingInfoMap(allConversationMessages));
let displayMessages = $derived.by(() => {
if (!messages.length) {
return [];
@@ -223,18 +225,18 @@
}
}
const siblingInfo = getMessageSiblings(allConversationMessages, msg.id);
const siblingInfo = siblingInfoByMessageId.get(msg.id) ?? {
message: msg,
siblingIds: [msg.id],
currentIndex: 0,
totalSiblings: 1
};
result.push({
message: msg,
toolMessages,
isLastAssistantMessage: false,
siblingInfo: siblingInfo || {
message: msg,
siblingIds: [msg.id],
currentIndex: 0,
totalSiblings: 1
}
siblingInfo
});
}
@@ -4,7 +4,7 @@
import { McpServerForm } from '$lib/components/app/mcp';
import { mcpStore } from '$lib/stores/mcp.svelte';
import { conversationsStore } from '$lib/stores/conversations.svelte';
import { uuid } from '$lib/utils';
import { parseHeadersToArray, uuid } from '$lib/utils';
import { MCP_SERVER_ID_PREFIX } from '$lib/constants';
interface Props {
@@ -26,6 +26,10 @@
return 'Invalid URL format';
}
});
let newServerHeaderPairsValid = $derived(
parseHeadersToArray(newServerHeaders).every((p) => p.key.trim() && p.value.trim())
);
let canSave = $derived(!newServerUrlError && newServerHeaderPairsValid);
function handleOpenChange(value: boolean) {
if (!value) {
@@ -37,7 +41,7 @@
}
function saveNewServer() {
if (newServerUrlError) return;
if (!canSave) return;
const newServerId = uuid() ?? `${MCP_SERVER_ID_PREFIX}-${Date.now()}`;
@@ -52,6 +56,11 @@
handleOpenChange(false);
}
function handleSubmit(event: SubmitEvent) {
event.preventDefault();
saveNewServer();
}
</script>
<Dialog.Root {open} onOpenChange={handleOpenChange}>
@@ -60,29 +69,27 @@
<Dialog.Title>Add New Server</Dialog.Title>
</Dialog.Header>
<div class="space-y-4 py-4">
<McpServerForm
url={newServerUrl}
headers={newServerHeaders}
onUrlChange={(v) => (newServerUrl = v)}
onHeadersChange={(v) => (newServerHeaders = v)}
urlError={newServerUrl ? newServerUrlError : null}
id="new-server"
/>
</div>
<form onsubmit={handleSubmit} class="contents">
<div class="space-y-4 py-4">
<McpServerForm
url={newServerUrl}
headers={newServerHeaders}
onUrlChange={(v) => (newServerUrl = v)}
onHeadersChange={(v) => (newServerHeaders = v)}
urlError={newServerUrl ? newServerUrlError : null}
id="new-server"
/>
</div>
<Dialog.Footer>
<Button variant="secondary" size="sm" onclick={() => handleOpenChange(false)}>Cancel</Button>
<Dialog.Footer>
<Button variant="secondary" size="sm" onclick={() => handleOpenChange(false)}>
Cancel
</Button>
<Button
variant="default"
size="sm"
onclick={saveNewServer}
disabled={!!newServerUrlError}
aria-label="Save"
>
Add
</Button>
</Dialog.Footer>
<Button variant="default" size="sm" type="submit" disabled={!canSave} aria-label="Save">
Add
</Button>
</Dialog.Footer>
</form>
</Dialog.Content>
</Dialog.Root>
@@ -0,0 +1,180 @@
<script lang="ts">
import { Button } from '$lib/components/ui/button';
import * as Card from '$lib/components/ui/card';
import * as Dialog from '$lib/components/ui/dialog';
import { fly } from 'svelte/transition';
import { McpServerCardCompact, McpServerForm } from '$lib/components/app/mcp';
import { RECOMMENDED_MCP_SERVERS } from '$lib/constants';
import { conversationsStore } from '$lib/stores/conversations.svelte';
import { mcpStore } from '$lib/stores/mcp.svelte';
import { uuid } from '$lib/utils';
import { MCP_SERVERS_ADDED_TO_CHAT_LOCALSTORAGE_KEY, MCP_SERVER_ID_PREFIX } from '$lib/constants';
import type { MCPServerSettingsEntry } from '$lib/types';
import { Plus } from '@lucide/svelte';
interface Props {
open: boolean;
onOpenChange?: (open: boolean) => void;
}
let { open = $bindable(), onOpenChange }: Props = $props();
let selected = $state<Record<string, boolean>>(
Object.fromEntries(RECOMMENDED_MCP_SERVERS.map((server) => [server.id, false]))
);
let addedServers = $state<MCPServerSettingsEntry[]>([]);
let showAddForm = $state(false);
let newServerUrl = $state('');
let newServerHeaders = $state('');
let newServerUrlError = $derived.by(() => {
if (!newServerUrl.trim()) return 'URL is required';
try {
new URL(newServerUrl);
return null;
} catch {
return 'Invalid URL format';
}
});
function handleOpenChange(value: boolean) {
if (!value) {
showAddForm = false;
newServerUrl = '';
newServerHeaders = '';
addedServers = [];
localStorage.setItem(MCP_SERVERS_ADDED_TO_CHAT_LOCALSTORAGE_KEY, 'true');
}
open = value;
onOpenChange?.(value);
}
function resetAddForm() {
showAddForm = false;
newServerUrl = '';
newServerHeaders = '';
}
function enableSelected() {
localStorage.setItem(MCP_SERVERS_ADDED_TO_CHAT_LOCALSTORAGE_KEY, 'true');
for (const server of RECOMMENDED_MCP_SERVERS) {
if (selected[server.id]) {
const existing = mcpStore.getServerById(server.id);
if (existing) {
mcpStore.updateServer(server.id, { enabled: true });
} else {
mcpStore.addServer({
id: server.id,
enabled: true,
url: server.url,
name: server.name
});
}
conversationsStore.setMcpServerOverride(server.id, true);
}
}
handleOpenChange(false);
}
function saveNewServer() {
if (newServerUrlError) return;
const newServerId = uuid() ?? `${MCP_SERVER_ID_PREFIX}-${Date.now()}`;
localStorage.setItem(MCP_SERVERS_ADDED_TO_CHAT_LOCALSTORAGE_KEY, 'true');
const newServer = mcpStore.addServer({
id: newServerId,
enabled: true,
url: newServerUrl.trim(),
headers: newServerHeaders.trim() || undefined
});
conversationsStore.setMcpServerOverride(newServerId, true);
if (newServer) {
addedServers = [...addedServers, newServer];
}
resetAddForm();
}
</script>
<Dialog.Root bind:open onOpenChange={handleOpenChange}>
<Dialog.Content class="sm:max-w-lg">
<Dialog.Header>
<Dialog.Title>Do more with MCP</Dialog.Title>
<Dialog.Description>
Power-up your experience by adding tools, resources and more capabilities provided by MCP
servers.
</Dialog.Description>
</Dialog.Header>
<div class="max-h-[60vh] space-y-4 overflow-y-auto py-4" in:fly={{ y: 16, duration: 300 }}>
<h3 class="text-sm font-semibold">Quickly get started with</h3>
{#each RECOMMENDED_MCP_SERVERS as server (server.id)}
<McpServerCardCompact
{server}
enabled={selected[server.id]}
onToggle={(enabled) => (selected[server.id] = enabled)}
/>
{/each}
{#if addedServers.length > 0}
{#each addedServers as server (server.id)}
<McpServerCardCompact {server} enabled={true} />
{/each}
{/if}
{#if showAddForm}
<Card.Root class="gap-3! bg-muted/30 p-4">
<McpServerForm
url={newServerUrl}
headers={newServerHeaders}
onUrlChange={(v) => (newServerUrl = v)}
onHeadersChange={(v) => (newServerHeaders = v)}
urlError={newServerUrl ? newServerUrlError : null}
id="recommendation-new-server"
/>
<div class="flex justify-end gap-2 pt-2">
<Button variant="secondary" size="sm" onclick={resetAddForm}>Cancel</Button>
<Button
variant="default"
size="sm"
onclick={saveNewServer}
disabled={!!newServerUrlError}
aria-label="Save"
>
Add
</Button>
</div>
</Card.Root>
{:else}
<Card.Root class="gap-0 border-dashed bg-muted/30 p-0 transition-colors hover:bg-muted/50">
<button
type="button"
class="flex w-full items-center justify-center gap-2 rounded-lg p-6 text-sm text-muted-foreground transition-colors hover:text-foreground"
onclick={() => (showAddForm = true)}
aria-label="Add your own MCP server"
>
<Plus class="h-4 w-4" />
<span>Add your own server</span>
</button>
</Card.Root>
{/if}
</div>
<Dialog.Footer>
<Button variant="secondary" size="sm" onclick={() => handleOpenChange(false)}>Not now</Button>
<Button variant="default" size="sm" onclick={enableSelected}>Add selected</Button>
</Dialog.Footer>
</Dialog.Content>
</Dialog.Root>
@@ -18,6 +18,15 @@
*/
export { default as DialogMcpServerAddNew } from './DialogMcpServerAddNew.svelte';
/**
* **DialogMcpServerRecommendations** - Suggested MCP servers opt-in dialog
*
* Prompts the user to enable pre-defined recommended MCP servers on first launch.
* Shows one switch per suggested server and persists the choice as a per-chat
* override so the selected servers become available in conversations.
*/
export { default as DialogMcpServerRecommendations } from './DialogMcpServerRecommendations.svelte';
/**
* **DialogExportSettings** - Settings export dialog with sensitive data warning
*
@@ -1,4 +1,5 @@
<script lang="ts">
import { tick } from 'svelte';
import { Plus, Trash2 } from '@lucide/svelte';
import { Input } from '$lib/components/ui/input';
import {
@@ -33,8 +34,18 @@
sectionLabelOptional = true
}: Props = $props();
function addPair() {
// Pre-allocate the ref array so `bind:ref={keyInputRefs[index]}` never reads `undefined`
// for in-range indices; the $effect below keeps it in sync when `pairs` grows.
// svelte-ignore state_referenced_locally
let keyInputRefs: (HTMLInputElement | null)[] = $state(pairs.map(() => null));
async function addPair() {
// Capture the target index before mutating so deletions earlier in the
// list can't make keyInputRefs.length drift past the newly-appended row.
const newIndex = pairs.length;
onPairsChange([...pairs, { key: '', value: '' }]);
await tick();
keyInputRefs[newIndex]?.focus();
}
function removePair(index: number) {
@@ -76,6 +87,15 @@
newPairs[index] = { ...newPairs[index], value: trimmed };
onPairsChange(newPairs);
}
// Keep keyInputRefs aligned with pairs length so bind:ref never sees `undefined`.
// $effect.pre runs during traversal in tree order, before the {#each} block re-renders,
// so newly-appended items always have a defined slot when their binding is set up.
$effect.pre(() => {
while (keyInputRefs.length < pairs.length) {
keyInputRefs.push(null);
}
});
</script>
<div class={className}>
@@ -103,6 +123,7 @@
{#each pairs as pair, index (index)}
<div class="flex items-start gap-2">
<Input
bind:ref={keyInputRefs[index]}
type="text"
placeholder={keyPlaceholder}
value={pair.key}
@@ -163,7 +163,7 @@
{/if}
</div>
<div class="flex justify-between gap-4">
<div class="mt-auto flex justify-between gap-4">
{#if showSkeleton}
<Skeleton class="h-3 w-28" />
{:else if protocolVersion}
@@ -0,0 +1,156 @@
<script lang="ts">
import * as Card from '$lib/components/ui/card';
import { Badge } from '$lib/components/ui/badge';
import { Skeleton } from '$lib/components/ui/skeleton';
import { Switch } from '$lib/components/ui/switch';
import * as Tooltip from '$lib/components/ui/tooltip';
import { McpServerIdentity } from '$lib/components/app/mcp';
import { mcpStore } from '$lib/stores/mcp.svelte';
import { HealthCheckStatus } from '$lib/enums';
import type { MCPServerDisplayInfo, HealthCheckState, MCPServerSettingsEntry } from '$lib/types';
import { onMount } from 'svelte';
import { MCP_CARD_VISIBLE_TOOL_LIMIT, NEWLINE } from '$lib/constants';
interface Props {
server: MCPServerDisplayInfo & { description?: string };
enabled?: boolean;
onToggle?: (enabled: boolean) => void;
}
let { server, enabled = false, onToggle }: Props = $props();
onMount(() => {
const state = mcpStore.getHealthCheckState(server.id);
if (state.status === HealthCheckStatus.IDLE) {
mcpStore.runHealthCheck(server as MCPServerSettingsEntry).catch(() => {});
}
});
let healthState = $derived<HealthCheckState>(mcpStore.getHealthCheckState(server.id));
let displayName = $derived(mcpStore.getServerLabel(server));
let faviconUrl = $derived(mcpStore.getServerFavicon(server.id));
let isIdle = $derived(healthState.status === HealthCheckStatus.IDLE);
let isHealthChecking = $derived(healthState.status === HealthCheckStatus.CONNECTING);
let isError = $derived(healthState.status === HealthCheckStatus.ERROR);
let errorMessage = $derived(
healthState.status === HealthCheckStatus.ERROR ? healthState.message : undefined
);
let serverInfo = $derived(
healthState.status === HealthCheckStatus.SUCCESS ? healthState.serverInfo : undefined
);
let tools = $derived(healthState.status === HealthCheckStatus.SUCCESS ? healthState.tools : []);
let instructions = $derived(
healthState.status === HealthCheckStatus.SUCCESS ? healthState.instructions : undefined
);
let showSkeleton = $derived(isIdle || isHealthChecking);
// Curated descriptions get two lines; instructions fallback is one line so the
// compact card stays scannable.
let description = $derived.by(() => {
if (server.description) {
return { text: server.description, lines: 2 };
}
if (!instructions) return null;
const firstLine = instructions.split(NEWLINE).find((line: string) => line.trim().length > 0);
const trimmed = firstLine?.trim();
return trimmed ? { text: trimmed, lines: 1 } : null;
});
let visibleTools = $derived(tools.slice(0, MCP_CARD_VISIBLE_TOOL_LIMIT));
let hiddenTools = $derived(tools.slice(MCP_CARD_VISIBLE_TOOL_LIMIT));
let hiddenToolCount = $derived(hiddenTools.length);
function handleToggle(checked: boolean) {
onToggle?.(checked);
}
</script>
<Card.Root class="!gap-3 bg-muted/30 p-4">
<div class="flex items-start justify-between gap-3">
<div class="min-w-0 flex-1">
{#if showSkeleton}
<span class="flex min-w-0 items-center gap-1.5">
<Skeleton class="h-5 w-5 rounded" />
<Skeleton class="h-4 w-32" />
</span>
{:else}
<McpServerIdentity
{displayName}
{faviconUrl}
{serverInfo}
iconClass="h-5 w-5"
iconRounded="rounded"
nameClass="font-medium"
/>
{/if}
</div>
<Switch checked={enabled} disabled={isError || showSkeleton} onCheckedChange={handleToggle} />
</div>
{#if isError && errorMessage}
<p class="text-xs text-destructive">{errorMessage}</p>
{/if}
{#if showSkeleton}
<div class="space-y-1.5">
<Skeleton class="h-3 w-full max-w-md" />
</div>
<div class="flex flex-wrap items-center gap-1.5">
<Skeleton class="h-5 w-16 rounded-full" />
<Skeleton class="h-5 w-20 rounded-full" />
<Skeleton class="h-5 w-24 rounded-full" />
<Skeleton class="h-5 w-14 rounded-full" />
</div>
{:else}
{#if description}
{#if description.lines === 2}
<p class="line-clamp-2 text-xs text-muted-foreground" title={description.text}>
{description.text}
</p>
{:else}
<p class="line-clamp-1 truncate text-xs text-muted-foreground" title={description.text}>
{description.text}
</p>
{/if}
{/if}
{#if tools.length > 0}
<div class="flex flex-wrap items-center gap-1.5">
{#each visibleTools as tool (tool.name)}
<Tooltip.Root>
<Tooltip.Trigger>
<Badge variant="secondary" class="h-5 max-w-40 px-2 text-[11px]">
<span class="block min-w-0 flex-1 truncate">{tool.name}</span>
</Badge>
</Tooltip.Trigger>
<Tooltip.Content>
<p class="max-w-xs text-xs">
{tool.description ?? 'No description'}
</p>
</Tooltip.Content>
</Tooltip.Root>
{/each}
{#if hiddenToolCount > 0}
<Tooltip.Root>
<Tooltip.Trigger>
<Badge variant="secondary" class="h-5 px-2 text-[11px] text-muted-foreground">
+ {hiddenToolCount} more tools
</Badge>
</Tooltip.Trigger>
<Tooltip.Content class="max-w-md">
<p class="text-xs">
{hiddenTools.map((tool) => tool.name).join(', ')}
</p>
</Tooltip.Content>
</Tooltip.Root>
{/if}
</div>
{/if}
{/if}
</Card.Root>
@@ -1,6 +1,7 @@
<script lang="ts">
import { Button } from '$lib/components/ui/button';
import { McpServerForm } from '$lib/components/app/mcp';
import { parseHeadersToArray } from '$lib/utils';
interface Props {
serverId: string;
@@ -26,13 +27,21 @@
}
});
let canSave = $derived(!urlError);
let headerPairsValid = $derived(
parseHeadersToArray(editHeaders).every((p) => p.key.trim() && p.value.trim())
);
let canSave = $derived(!urlError && headerPairsValid);
function handleSave() {
if (!canSave) return;
onSave(editUrl.trim(), editHeaders.trim(), editUseProxy);
}
function handleSubmit(event: SubmitEvent) {
event.preventDefault();
handleSave();
}
export function setInitialValues(url: string, headers: string, useProxy: boolean) {
editUrl = url;
editHeaders = headers;
@@ -40,25 +49,27 @@
}
</script>
<div class="space-y-4">
<p class="font-medium">Configure Server</p>
<form onsubmit={handleSubmit} class="contents">
<div class="space-y-4">
<p class="font-medium">Configure Server</p>
<McpServerForm
url={editUrl}
headers={editHeaders}
useProxy={editUseProxy}
onUrlChange={(v) => (editUrl = v)}
onHeadersChange={(v) => (editHeaders = v)}
onUseProxyChange={(v) => (editUseProxy = v)}
urlError={editUrl ? urlError : null}
id={serverId}
/>
<McpServerForm
url={editUrl}
headers={editHeaders}
useProxy={editUseProxy}
onUrlChange={(v) => (editUrl = v)}
onHeadersChange={(v) => (editHeaders = v)}
onUseProxyChange={(v) => (editUseProxy = v)}
urlError={editUrl ? urlError : null}
id={serverId}
/>
<div class="flex items-center justify-end gap-2">
<Button variant="secondary" size="sm" onclick={onCancel}>Cancel</Button>
<div class="flex items-center justify-end gap-2">
<Button variant="secondary" size="sm" onclick={onCancel}>Cancel</Button>
<Button size="sm" onclick={handleSave} disabled={!canSave}>
{serverUrl.trim() ? 'Update' : 'Add'}
</Button>
<Button size="sm" type="submit" disabled={!canSave}>
{serverUrl.trim() ? 'Update' : 'Add'}
</Button>
</div>
</div>
</div>
</form>
@@ -38,14 +38,87 @@
let headerPairs = $derived<KeyValuePair[]>(parseHeadersToArray(headers));
const AUTHORIZATION_HEADER = 'Authorization';
const BEARER_PREFIX = 'Bearer ';
// Heuristic: this dedicated UI only owns Authorization headers that already
// carry a Bearer scheme. Anything else (e.g. Basic, raw tokens) stays in the
// KV section so the user can still edit those values verbatim.
const matchesAuthorizationKey = (key: string): boolean =>
key.trim().toLowerCase() === AUTHORIZATION_HEADER.toLowerCase();
const isBearerScheme = (value: string): boolean =>
value.trim().toLowerCase().startsWith(BEARER_PREFIX.toLowerCase());
const ownedByBearerUi = (p: KeyValuePair): boolean =>
matchesAuthorizationKey(p.key) && isBearerScheme(p.value);
let hasAuthorization = $derived(headerPairs.some(ownedByBearerUi));
let wantsAuthorization = $state(false);
let showAuthorization = $derived(hasAuthorization || wantsAuthorization);
let urlInput: HTMLInputElement | null = $state(null);
let bearerInput: HTMLInputElement | null = $state(null);
$effect(() => {
urlInput?.focus();
});
$effect(() => {
if (wantsAuthorization && bearerInput) {
bearerInput.focus();
}
});
let bearerToken = $derived.by(() => {
const auth = headerPairs.find(ownedByBearerUi);
if (!auth) return '';
return auth.value.trim().slice(BEARER_PREFIX.length).trim();
});
$effect(() => {
if (!headers.trim()) {
wantsAuthorization = false;
}
});
function updateHeaderPairs(newPairs: KeyValuePair[]) {
headerPairs = newPairs;
onHeadersChange(serializeHeaders(newPairs));
}
// The dedicated UI owns the Authorization slot end-to-end when the user
// engages it: any prior Authorization row (Bearer or otherwise) is replaced
// by exactly one { Authorization: "Bearer <token>" } entry. JSON's last-key
// behavior would otherwise pick one arbitrarily, so we strip first.
function updateBearerToken(token: string) {
const filtered = headerPairs.filter((p) => !matchesAuthorizationKey(p.key));
const trimmed = token.trim();
if (trimmed) {
filtered.push({ key: AUTHORIZATION_HEADER, value: `${BEARER_PREFIX}${trimmed}` });
}
updateHeaderPairs(filtered);
}
function setUseAuthorization(checked: boolean) {
wantsAuthorization = checked;
if (!checked) {
// Only drop the entry this UI owns; a non-Bearer Authorization row
// authored in the KV section must survive a toggle off untouched.
const filtered = headerPairs.filter((p) => !ownedByBearerUi(p));
updateHeaderPairs(filtered);
}
}
</script>
<div class="grid gap-3">
<div>
<div class="grid gap-2">
<div class="mb-4">
<label for="server-url-{id}" class="mb-2 block text-xs font-medium">
Server URL <span class="text-destructive">*</span>
</label>
@@ -57,50 +130,52 @@
value={url}
oninput={(e) => onUrlChange(e.currentTarget.value)}
class={urlError ? 'border-destructive' : ''}
bind:ref={urlInput}
/>
{#if urlError}
<p class="mt-1.5 text-xs text-destructive">{urlError}</p>
{/if}
{#if !isWebSocket && onUseProxyChange}
<label
class={[
'mt-3 flex items-start gap-2',
mcpStore.isProxyAvailable && 'cursor-pointer',
!mcpStore.isProxyAvailable && 'opacity-80'
]}
>
<Switch
class="mt-1"
id="use-proxy-{id}"
checked={useProxy}
disabled={!mcpStore.isProxyAvailable}
onCheckedChange={(checked) => onUseProxyChange?.(checked)}
/>
<span>
<span class="text-xs text-muted-foreground">Use llama-server proxy</span>
<br />
{#if !mcpStore.isProxyAvailable}
<span class="inline-flex gap-0.75 text-xs text-muted-foreground/60"
>(Run <pre>llama-server</pre>
with
<pre>{CLI_FLAGS.MCP_PROXY}</pre>
flag)</span
>
{/if}
</span>
</label>
{/if}
</div>
<label class="flex items-center gap-2 cursor-pointer">
<Switch
id="use-authorization-{id}"
checked={showAuthorization}
onCheckedChange={setUseAuthorization}
/>
<span class="text-xs text-muted-foreground">Authorization</span>
</label>
{#if showAuthorization}
<div class="relative mt-2">
<Input
id="bearer-token-{id}"
type="password"
autocomplete="off"
placeholder="Paste token here"
value={bearerToken}
oninput={(e) => updateBearerToken(e.currentTarget.value)}
class="pl-16"
bind:ref={bearerInput}
/>
<span
class="pointer-events-none absolute inset-y-0 left-3 flex items-center text-sm font-medium text-foreground"
>
Bearer
</span>
</div>
{/if}
<KeyValuePairs
class="mt-2"
pairs={headerPairs}
onPairsChange={updateHeaderPairs}
class="mt-3"
pairs={headerPairs.filter((p) => !ownedByBearerUi(p))}
onPairsChange={(pairs) => {
const auth = headerPairs.find(ownedByBearerUi);
updateHeaderPairs(auth ? [...pairs, auth] : pairs);
}}
keyPlaceholder="Header name"
valuePlaceholder="Value"
addButtonLabel="Add"
@@ -108,4 +183,37 @@
sectionLabel="Custom Headers"
sectionLabelOptional
/>
{#if !isWebSocket && onUseProxyChange}
<label
class={[
'mt-3 flex items-start gap-2',
mcpStore.isProxyAvailable && 'cursor-pointer',
!mcpStore.isProxyAvailable && 'opacity-80'
]}
>
<Switch
class="mt-1"
id="use-proxy-{id}"
checked={useProxy}
disabled={!mcpStore.isProxyAvailable}
onCheckedChange={(checked) => onUseProxyChange?.(checked)}
/>
<span>
<span class="text-xs text-muted-foreground">Use llama-server proxy</span>
<br />
{#if !mcpStore.isProxyAvailable}
<span class="inline-flex gap-0.75 text-xs text-muted-foreground/60"
>(Run <pre>llama-server</pre>
with
<pre>{CLI_FLAGS.MCP_PROXY}</pre>
flag)</span
>
{/if}
</span>
</label>
{/if}
</div>
@@ -1,6 +1,7 @@
<script lang="ts">
import { ExternalLink } from '@lucide/svelte';
import { Badge } from '$lib/components/ui/badge';
import { McpLogo } from '$lib/components/app/mcp';
import { TruncatedText } from '$lib/components/app/misc';
import { sanitizeExternalUrl } from '$lib/utils';
import type { MCPServerInfo } from '$lib/types';
@@ -34,20 +35,15 @@
<span class="flex min-w-0 items-center gap-1.5">
{#if faviconUrl}
<img
src={faviconUrl}
alt=""
class={['shrink-0', iconRounded, iconClass]}
onerror={(e) => {
(e.currentTarget as HTMLImageElement).style.display = 'none';
}}
/>
<img src={faviconUrl} alt="" class={['shrink-0 text-foreground', iconRounded, iconClass]} />
{:else}
<McpLogo class={['shrink-0 text-foreground', iconRounded, iconClass].join(' ')} />
{/if}
<TruncatedText text={displayName ?? ''} class={nameClass ?? ''} />
{#if showVersion && serverInfo?.version}
<Badge variant="secondary" class="h-4 min-w-0 shrink px-1 text-[10px]">
<Badge variant="secondary" class="h-4 max-w-24 min-w-0 shrink px-1 text-[10px]">
<TruncatedText text={`v${serverInfo.version}`} />
</Badge>
{/if}
@@ -180,6 +180,16 @@ export { default as McpServerCardDeleteDialog } from './McpServerCard/McpServerC
/** Skeleton loading state for server card during health checks. */
export { default as McpServerCardSkeleton } from './McpServerCardSkeleton.svelte';
/**
* **McpServerCardCompact** - Condensed MCP server card
*
* Compact alternative to McpServerCard tailored for picker-style UIs.
* Shows the server identity, status, and a flex-wrapped list of available tools.
* Tool names are rendered as badges; hovering a badge shows its description in a tooltip.
* Does not show connection logs or server instructions.
*/
export { default as McpServerCardCompact } from './McpServerCard/McpServerCardCompact.svelte';
/**
* **McpServerIdentity** - Server identity display (icon, name, version)
*
@@ -21,7 +21,7 @@
let { class: className }: Props = $props();
let servers = $derived(mcpStore.getServersSorted());
let servers = $derived(mcpStore.visibleMcpServers);
let initialLoadComplete = $state(false);
let isAddingServer = $state(false);
+1
View File
@@ -8,6 +8,7 @@ export * from './attachment-labels';
export * from './database';
export * from './reasoning-effort';
export * from './reasoning-effort-tokens';
export * from './recommended-mcp-servers';
export * from './storage';
export * from './attachment-menu';
export * from './auto-scroll';
+2
View File
@@ -1,2 +1,4 @@
export const MCP_SERVER_URL_PLACEHOLDER = 'https://mcp.example.com/sse';
export const MIN_AUTOCOMPLETE_INPUT_LENGTH = 1;
/** Number of tools shown on the compact MCP server card before collapsing to a "+ N more" badge */
export const MCP_CARD_VISIBLE_TOOL_LIMIT = 4;
+5
View File
@@ -37,3 +37,8 @@ export const MODEL_ACTIVATED_PARAMS_RE = /^[Aa]\d+(\.\d+)?[BbMmKkTt]$/;
* Container format segments to exclude from tags (every model uses these).
*/
export const MODEL_IGNORED_SEGMENTS = new Set(['GGUF', 'GGML']);
/**
* Matches a trailing weight file extension, e.g. `model.gguf` -> `model`.
*/
export const MODEL_WEIGHT_EXTENSION_RE = /\.(gguf|ggml)$/i;
@@ -0,0 +1,35 @@
import { DEFAULT_MCP_CONFIG } from './mcp';
import type { RecommendedMCPServer } from '$lib/types';
/**
* Pre-defined recommended MCP servers.
*
* Servers are enabled by default, but they are not turned on for individual
* conversations until the user explicitly enables them (so their tools are
* disabled by default).
*/
export const RECOMMENDED_MCP_SERVERS: RecommendedMCPServer[] = [
{
id: 'exa-web-search',
name: 'Exa Web Search',
description: 'Search the web and retrieve relevant content.',
url: 'https://mcp.exa.ai/mcp',
enabled: true,
requestTimeoutSeconds: DEFAULT_MCP_CONFIG.requestTimeoutSeconds
},
{
id: 'huggingface-mcp',
name: 'Hugging Face',
description:
'Browse models, datasets, spaces and machine learning papers from the Hugging Face hub.',
url: 'https://huggingface.co/mcp',
enabled: true,
requestTimeoutSeconds: DEFAULT_MCP_CONFIG.requestTimeoutSeconds
}
];
export const RECOMMENDED_MCP_SERVER_IDS = new Set(
RECOMMENDED_MCP_SERVERS.map((server) => server.id)
);
export const RECOMMENDED_MCP_SERVERS_OPTIN_DIALOG_DELAY = 1000;
+1 -1
View File
@@ -59,6 +59,7 @@ export const SETTINGS_KEYS = {
// MCP
MCP_SERVERS: 'mcpServers',
MCP_REQUEST_TIMEOUT_SECONDS: 'mcpRequestTimeoutSeconds',
MCP_DEFAULT_SERVER_OVERRIDES: 'mcpDefaultServerOverrides',
AGENTIC_MAX_TURNS: 'agenticMaxTurns',
ALWAYS_SHOW_AGENTIC_TURNS: 'alwaysShowAgenticTurns',
AGENTIC_MAX_TOOL_PREVIEW_LINES: 'agenticMaxToolPreviewLines',
@@ -68,7 +69,6 @@ export const SETTINGS_KEYS = {
// Developer
DISABLE_REASONING_PARSING: 'disableReasoningParsing',
EXCLUDE_REASONING_FROM_CONTEXT: 'excludeReasoningFromContext',
ENABLE_THINKING: 'enableThinking',
SHOW_RAW_OUTPUT_SWITCH: 'showRawOutputSwitch',
// PY_INTERPRETER_ENABLED: 'pyInterpreterEnabled',
JS_SANDBOX_ENABLED: 'jsSandboxEnabled',
+49 -17
View File
@@ -28,6 +28,7 @@ import McpLogo from '$lib/components/app/mcp/McpLogo.svelte';
import { SETTINGS_KEYS } from './settings-keys';
import { ROUTES, SETTINGS_SECTION_SLUGS } from './routes';
import { TITLE_GENERATION } from './title-generation';
import { RECOMMENDED_MCP_SERVERS } from './recommended-mcp-servers';
export const SETTINGS_SECTION_TITLES = {
GENERAL: 'General',
@@ -184,7 +185,11 @@ const SETTINGS_REGISTRY: Record<string, SettingsSectionEntry> = {
defaultValue: false,
type: SettingsFieldType.CHECKBOX,
section: SETTINGS_SECTION_SLUGS.GENERAL,
isExperimental: true
isExperimental: true,
sync: {
serverKey: SETTINGS_KEYS.TITLE_GENERATION_USE_LLM,
paramType: SyncableParameterType.BOOLEAN
}
},
{
key: SETTINGS_KEYS.TITLE_GENERATION_PROMPT,
@@ -192,7 +197,11 @@ const SETTINGS_REGISTRY: Record<string, SettingsSectionEntry> = {
help: 'Optional template for the title generation prompt. Use {{USER}} for the user message and {{ASSISTANT}} for the assistant message.',
defaultValue: TITLE_GENERATION.DEFAULT_PROMPT,
type: SettingsFieldType.TEXTAREA,
section: SETTINGS_SECTION_SLUGS.GENERAL
section: SETTINGS_SECTION_SLUGS.GENERAL,
sync: {
serverKey: SETTINGS_KEYS.TITLE_GENERATION_PROMPT,
paramType: SyncableParameterType.STRING
}
},
{
key: SETTINGS_KEYS.MAX_IMAGE_RESOLUTION,
@@ -200,7 +209,11 @@ const SETTINGS_REGISTRY: Record<string, SettingsSectionEntry> = {
help: 'Images larger than this will be resized before sending to server. Set to 0 to disable.',
defaultValue: 0,
type: SettingsFieldType.INPUT,
section: SETTINGS_SECTION_SLUGS.GENERAL
section: SETTINGS_SECTION_SLUGS.GENERAL,
sync: {
serverKey: SETTINGS_KEYS.MAX_IMAGE_RESOLUTION,
paramType: SyncableParameterType.NUMBER
}
}
]
},
@@ -384,7 +397,11 @@ const SETTINGS_REGISTRY: Record<string, SettingsSectionEntry> = {
help: 'Display the current build version in the bottom-right corner of the interface.',
defaultValue: false,
type: SettingsFieldType.CHECKBOX,
section: SETTINGS_SECTION_SLUGS.DISPLAY
section: SETTINGS_SECTION_SLUGS.DISPLAY,
sync: {
serverKey: SETTINGS_KEYS.SHOW_BUILD_VERSION,
paramType: SyncableParameterType.BOOLEAN
}
}
]
},
@@ -668,7 +685,11 @@ const SETTINGS_REGISTRY: Record<string, SettingsSectionEntry> = {
help: 'After each response, re-submit the conversation to pre-fill the server KV cache. Makes the next turn faster since the prompt is already encoded while you read the response.',
defaultValue: false,
type: SettingsFieldType.CHECKBOX,
section: SETTINGS_SECTION_SLUGS.DEVELOPER
section: SETTINGS_SECTION_SLUGS.DEVELOPER,
sync: {
serverKey: SETTINGS_KEYS.PRE_ENCODE_CONVERSATION,
paramType: SyncableParameterType.BOOLEAN
}
},
{
key: SETTINGS_KEYS.DISABLE_REASONING_PARSING,
@@ -676,7 +697,11 @@ const SETTINGS_REGISTRY: Record<string, SettingsSectionEntry> = {
help: 'Send reasoning_format=none so the server returns thinking tokens inline instead of extracting them into a separate field.',
defaultValue: false,
type: SettingsFieldType.CHECKBOX,
section: SETTINGS_SECTION_SLUGS.DEVELOPER
section: SETTINGS_SECTION_SLUGS.DEVELOPER,
sync: {
serverKey: SETTINGS_KEYS.DISABLE_REASONING_PARSING,
paramType: SyncableParameterType.BOOLEAN
}
},
{
key: SETTINGS_KEYS.EXCLUDE_REASONING_FROM_CONTEXT,
@@ -690,14 +715,6 @@ const SETTINGS_REGISTRY: Record<string, SettingsSectionEntry> = {
paramType: SyncableParameterType.BOOLEAN
}
},
{
key: SETTINGS_KEYS.ENABLE_THINKING,
label: 'Enable thinking',
help: 'Enable model thinking/reasoning for each request. When off, the model will skip the thinking phase and go straight to the response.',
defaultValue: false,
type: SettingsFieldType.CHECKBOX,
section: SETTINGS_SECTION_SLUGS.DEVELOPER
},
{
key: SETTINGS_KEYS.SHOW_RAW_OUTPUT_SWITCH,
label: 'Enable raw output toggle',
@@ -716,7 +733,11 @@ const SETTINGS_REGISTRY: Record<string, SettingsSectionEntry> = {
help: 'Expose a run_javascript tool to the model. Code runs in a Web Worker inside a sandboxed iframe with an opaque origin, isolated from the WebUI and its API, with a hard timeout.',
defaultValue: false,
type: SettingsFieldType.CHECKBOX,
section: SETTINGS_SECTION_SLUGS.DEVELOPER
section: SETTINGS_SECTION_SLUGS.DEVELOPER,
sync: {
serverKey: SETTINGS_KEYS.JS_SANDBOX_ENABLED,
paramType: SyncableParameterType.BOOLEAN
}
},
{
key: SETTINGS_KEYS.CUSTOM_JSON,
@@ -752,7 +773,11 @@ const SETTINGS_REGISTRY: Record<string, SettingsSectionEntry> = {
defaultValue: DEFAULT_MCP_CONFIG.requestTimeoutSeconds,
type: SettingsFieldType.INPUT,
section: SETTINGS_SECTION_SLUGS.MCP,
isPositiveInteger: true
isPositiveInteger: true,
sync: {
serverKey: SETTINGS_KEYS.MCP_REQUEST_TIMEOUT_SECONDS,
paramType: SyncableParameterType.NUMBER
}
}
]
}
@@ -774,9 +799,16 @@ const NON_UI_SETTINGS: SettingsEntry[] = [
key: SETTINGS_KEYS.MCP_SERVERS,
label: 'MCP servers',
help: 'Configure MCP servers as a JSON list. Use the form in the MCP Client settings section to edit.',
defaultValue: '[]',
defaultValue: JSON.stringify(RECOMMENDED_MCP_SERVERS),
type: SettingsFieldType.INPUT,
sync: { serverKey: SETTINGS_KEYS.MCP_SERVERS, paramType: SyncableParameterType.STRING }
},
{
key: SETTINGS_KEYS.MCP_DEFAULT_SERVER_OVERRIDES,
label: 'MCP default server overrides',
help: 'Per-server enable/disable defaults inherited by new chats. JSON-serialized list of {serverId, enabled} entries.',
defaultValue: '[]',
type: SettingsFieldType.INPUT
}
// {
// key: SETTINGS_KEYS.PY_INTERPRETER_ENABLED,
+2 -4
View File
@@ -21,9 +21,10 @@ export const DISABLED_TOOLS_LOCALSTORAGE_KEY = `${STORAGE_APP_NAME}.disabledTool
/** Disabled tools keyed by stable selection identity, no migration from the name based key */
export const DISABLED_TOOL_KEYS_LOCALSTORAGE_KEY = `${STORAGE_APP_NAME}.disabledToolKeys`;
export const FAVORITE_MODELS_LOCALSTORAGE_KEY = `${STORAGE_APP_NAME}.favoriteModels`;
export const MCP_DEFAULT_ENABLED_LOCALSTORAGE_KEY = `${STORAGE_APP_NAME}.mcpDefaultEnabled`;
export const THINKING_ENABLED_DEFAULT_LOCALSTORAGE_KEY = `${STORAGE_APP_NAME}.thinkingEnabledDefault`;
export const REASONING_EFFORT_DEFAULT_LOCALSTORAGE_KEY = `${STORAGE_APP_NAME}.reasoningEffortDefault`;
/** Set when user has interacted with the MCP server recommendations dialog (checked servers, added custom server, or dismissed) */
export const MCP_SERVERS_ADDED_TO_CHAT_LOCALSTORAGE_KEY = `${STORAGE_APP_NAME}.mcpServersSetupDone`;
export const USER_OVERRIDES_LOCALSTORAGE_KEY = `${STORAGE_APP_NAME}.userOverrides`;
/** Key prefix for per-conversation resumable stream state, conversationId is appended */
@@ -38,8 +39,6 @@ export const DEPRECATED_CONFIG_LOCALSTORAGE_KEY = `${STORAGE_APP_NAME_DEPRECATED
export const DEPRECATED_DISABLED_TOOLS_LOCALSTORAGE_KEY = `${STORAGE_APP_NAME_DEPRECATED}.disabledTools`;
/** @deprecated Use {@link FAVORITE_MODELS_LOCALSTORAGE_KEY} instead */
export const DEPRECATED_FAVORITE_MODELS_LOCALSTORAGE_KEY = `${STORAGE_APP_NAME_DEPRECATED}.favoriteModels`;
/** @deprecated Use {@link MCP_DEFAULT_ENABLED_LOCALSTORAGE_KEY} instead */
export const DEPRECATED_MCP_DEFAULT_ENABLED_LOCALSTORAGE_KEY = `${STORAGE_APP_NAME_DEPRECATED}.mcpDefaultEnabled`;
/** @deprecated Use {@link USER_OVERRIDES_LOCALSTORAGE_KEY} instead */
export const DEPRECATED_USER_OVERRIDES_LOCALSTORAGE_KEY = `${STORAGE_APP_NAME_DEPRECATED}.userOverrides`;
@@ -52,6 +51,5 @@ export const NEW_TO_DEPRECATED_MAP: Record<string, string> = {
[CONFIG_LOCALSTORAGE_KEY]: DEPRECATED_CONFIG_LOCALSTORAGE_KEY,
[DISABLED_TOOLS_LOCALSTORAGE_KEY]: DEPRECATED_DISABLED_TOOLS_LOCALSTORAGE_KEY,
[FAVORITE_MODELS_LOCALSTORAGE_KEY]: DEPRECATED_FAVORITE_MODELS_LOCALSTORAGE_KEY,
[MCP_DEFAULT_ENABLED_LOCALSTORAGE_KEY]: DEPRECATED_MCP_DEFAULT_ENABLED_LOCALSTORAGE_KEY,
[USER_OVERRIDES_LOCALSTORAGE_KEY]: DEPRECATED_USER_OVERRIDES_LOCALSTORAGE_KEY
};
+1 -1
View File
@@ -1,3 +1,3 @@
// grace window after a visibilitychange before we kick a reader whose socket likely died
// while the tab was hidden. covers brief background pauses without thrashing live streams
export const STREAM_VISIBILITY_KICK_MS = 1000;
export const STREAM_VISIBILITY_KICK_MS = 3000;
@@ -0,0 +1,85 @@
import { browser } from '$app/environment';
import {
MCP_SERVERS_ADDED_TO_CHAT_LOCALSTORAGE_KEY,
RECOMMENDED_MCP_SERVER_IDS,
RECOMMENDED_MCP_SERVERS_OPTIN_DIALOG_DELAY
} from '$lib/constants';
import { mcpStore } from '$lib/stores/mcp.svelte';
/**
* First-run opt-in dialog for the recommended MCP servers.
*
* Owns the dismissed / open / trigger-timeout state and the effect that
* schedules the dialog. Reads opt-in status and the configured server list
* from `mcpStore`, so callers don't need to recompute on their side.
*/
export function useMcpRecommendations() {
let dismissed = $state(
browser && localStorage.getItem(MCP_SERVERS_ADDED_TO_CHAT_LOCALSTORAGE_KEY) === 'true'
);
let open = $state(false);
let checked = $state(false);
let triggerTimeout: ReturnType<typeof setTimeout> | null = null;
function dismiss() {
if (browser) {
localStorage.setItem(MCP_SERVERS_ADDED_TO_CHAT_LOCALSTORAGE_KEY, 'true');
}
dismissed = true;
open = false;
if (triggerTimeout) {
clearTimeout(triggerTimeout);
triggerTimeout = null;
}
}
function handleOpenChange(next: boolean) {
open = next;
if (!next) dismiss();
}
$effect(() => {
if (!browser) return;
if (open || dismissed) {
if (triggerTimeout) {
clearTimeout(triggerTimeout);
triggerTimeout = null;
}
return;
}
// Already evaluated once this session; leave any pending trigger alone so
// it can still fire later. Setting `checked = true` below re-runs this
// effect, and we must not wipe the timeout that was just scheduled.
if (checked) return;
if (mcpStore.optedInRecommendationIds.size > 0) {
checked = true;
return;
}
const hasRecommendations = mcpStore
.getServers()
.some((server) => RECOMMENDED_MCP_SERVER_IDS.has(server.id));
if (hasRecommendations) {
triggerTimeout = setTimeout(() => {
open = true;
}, RECOMMENDED_MCP_SERVERS_OPTIN_DIALOG_DELAY);
}
checked = true;
});
return {
get open() {
return open;
},
get dismissed() {
return dismissed;
},
dismiss,
handleOpenChange
};
}
@@ -255,6 +255,7 @@ export class ChatService {
}),
stream,
return_progress: stream ? true : undefined,
sse_ping_interval: stream ? 1 : undefined,
tools: tools && tools.length > 0 ? tools : undefined
};
+95 -1
View File
@@ -20,6 +20,7 @@
import Dexie from 'dexie';
import {
STORAGE_APP_NAME,
STORAGE_APP_NAME_DEPRECATED,
DB_APP_NAME_DEPRECATED,
CONFIG_LOCALSTORAGE_KEY,
IDXDB_TABLES,
@@ -494,12 +495,105 @@ const customJsonKeyMigration: Migration = {
}
};
const MCP_DEFAULT_ENABLED_MIGRATION_ID = 'mcp-default-enabled-to-config-v1';
const LEGACY_MCP_DEFAULT_ENABLED_KEY = `${STORAGE_APP_NAME}.mcpDefaultEnabled`;
const DEPRECATED_LEGACY_MCP_DEFAULT_ENABLED_KEY = `${STORAGE_APP_NAME_DEPRECATED}.mcpDefaultEnabled`;
const mcpDefaultEnabledMigration: Migration = {
id: MCP_DEFAULT_ENABLED_MIGRATION_ID,
description:
'Copy mcpDefaultEnabled localStorage key into settings config (preserves legacy keys)',
async run(): Promise<void> {
const raw =
localStorage.getItem(LEGACY_MCP_DEFAULT_ENABLED_KEY) ??
localStorage.getItem(DEPRECATED_LEGACY_MCP_DEFAULT_ENABLED_KEY);
// Legacy keys intentionally left in place so a downgrade keeps reading them.
if (raw === null) {
if (import.meta.env.DEV && import.meta.env.VITE_DEBUG)
console.log('[Migration] MCP default enabled: no legacy key found, skipping');
return;
}
const configRaw = localStorage.getItem(CONFIG_LOCALSTORAGE_KEY);
const config = configRaw ? JSON.parse(configRaw) : {};
// Don't overwrite an existing config entry — current data wins.
if (SETTINGS_KEYS.MCP_DEFAULT_SERVER_OVERRIDES in config) {
if (import.meta.env.DEV && import.meta.env.VITE_DEBUG)
console.log('[Migration] MCP default enabled: config already has overrides, skipping');
return;
}
try {
const parsed = JSON.parse(raw);
if (!Array.isArray(parsed)) return;
const valid = parsed.every(
(o) =>
typeof o === 'object' &&
o !== null &&
typeof (o as Record<string, unknown>).serverId === 'string' &&
typeof (o as Record<string, unknown>).enabled === 'boolean'
);
if (!valid) return;
} catch {
return;
}
config[SETTINGS_KEYS.MCP_DEFAULT_SERVER_OVERRIDES] = raw;
localStorage.setItem(CONFIG_LOCALSTORAGE_KEY, JSON.stringify(config));
if (import.meta.env.DEV && import.meta.env.VITE_DEBUG)
console.log('[Migration] MCP default enabled: moved legacy key into config');
}
};
const CONFIG_TYPES_MIGRATION_ID = 'config-type-normalization-v1';
const configTypesMigration: Migration = {
id: CONFIG_TYPES_MIGRATION_ID,
description: 'Coerce legacy string-encoded booleans in persisted config to real booleans',
async run(): Promise<void> {
const configRaw = localStorage.getItem(CONFIG_LOCALSTORAGE_KEY);
if (configRaw === null) return;
const config = JSON.parse(configRaw);
let changed = false;
// Pre-schema configs persisted booleans as the strings "true"/"false", which the
// strict server schema now rejects. Coerce those back to real booleans. No config
// string field holds exactly "true"/"false", so the match is unambiguous.
for (const key of Object.keys(config)) {
if (config[key] === 'true') {
config[key] = true;
changed = true;
} else if (config[key] === 'false') {
config[key] = false;
changed = true;
}
}
if (changed) {
localStorage.setItem(CONFIG_LOCALSTORAGE_KEY, JSON.stringify(config));
}
if (import.meta.env.DEV && import.meta.env.VITE_DEBUG)
console.log(`[Migration] Config types: coerced string booleans (changed=${changed})`);
}
};
const migrations: Migration[] = [
localStorageMigration,
idxdbMigration,
legacyMessageMigration,
themeMigration,
customJsonKeyMigration
customJsonKeyMigration,
mcpDefaultEnabledMigration,
configTypesMigration
];
export const MigrationService = {
+10 -5
View File
@@ -1,5 +1,5 @@
import { ServerModelStatus } from '$lib/enums';
import { apiFetch, apiPost } from '$lib/utils';
import { apiFetch, apiPost, normalizeModelName } from '$lib/utils';
import type { ParsedModelId } from '$lib/types/models';
import {
MODEL_QUANTIZATION_SEGMENT_RE,
@@ -7,6 +7,7 @@ import {
MODEL_PARAMS_RE,
MODEL_ACTIVATED_PARAMS_RE,
MODEL_IGNORED_SEGMENTS,
MODEL_WEIGHT_EXTENSION_RE,
MODEL_ID_NOT_FOUND,
MODEL_ID_ORG_SEPARATOR,
MODEL_ID_SEGMENT_SEPARATOR,
@@ -139,15 +140,19 @@ export class ModelsService {
tags: []
};
// strip directory path and weight extension so a bare `-m /path/file.gguf`
// parses like a clean repo id; the HF `org/model` form is preserved
const source = normalizeModelName(modelId).replace(MODEL_WEIGHT_EXTENSION_RE, '');
// 1. Extract colon-separated quantization (e.g. `model:Q4_K_M`)
const colonIdx = modelId.indexOf(MODEL_ID_QUANTIZATION_SEPARATOR);
const colonIdx = source.indexOf(MODEL_ID_QUANTIZATION_SEPARATOR);
let modelPath: string;
if (colonIdx !== MODEL_ID_NOT_FOUND) {
result.quantization = modelId.slice(colonIdx + 1) || null;
modelPath = modelId.slice(0, colonIdx);
result.quantization = source.slice(colonIdx + 1) || null;
modelPath = source.slice(0, colonIdx);
} else {
modelPath = modelId;
modelPath = source;
}
// 2. Extract org name (e.g. `org/model` -> org = "org")
+14 -19
View File
@@ -23,7 +23,7 @@ import { browser } from '$app/environment';
import { toast } from 'svelte-sonner';
import { DatabaseService } from '$lib/services/database.service';
import { MigrationService } from '$lib/services/migration.service';
import { config } from '$lib/stores/settings.svelte';
import { config, settingsStore } from '$lib/stores/settings.svelte';
import { filterByLeafNodeId, findLeafNode, generateConversationTitle } from '$lib/utils';
import type { McpServerOverride } from '$lib/types/database';
import { zipSync, unzipSync, strToU8, strFromU8 } from 'fflate';
@@ -46,7 +46,7 @@ import {
ISO_TIME_SEPARATOR_REPLACEMENT,
NON_ALPHANUMERIC_REGEX,
MULTIPLE_UNDERSCORE_REGEX,
MCP_DEFAULT_ENABLED_LOCALSTORAGE_KEY,
SETTINGS_KEYS,
THINKING_ENABLED_DEFAULT_LOCALSTORAGE_KEY,
REASONING_EFFORT_DEFAULT_LOCALSTORAGE_KEY
} from '$lib/constants';
@@ -90,12 +90,10 @@ class ConversationsStore {
/** Global (non-conversation-specific) reasoning effort default */
pendingReasoningEffort = $state<ReasoningEffort>(ConversationsStore.loadReasoningEffortDefault());
/** Load MCP default overrides from localStorage */
private static loadMcpDefaults(): McpServerOverride[] {
if (typeof globalThis.localStorage === 'undefined') return [];
const raw = config()[SETTINGS_KEYS.MCP_DEFAULT_SERVER_OVERRIDES];
if (typeof raw !== 'string' || raw.length === 0) return [];
try {
const raw = localStorage.getItem(MCP_DEFAULT_ENABLED_LOCALSTORAGE_KEY);
if (!raw) return [];
const parsed = JSON.parse(raw);
if (!Array.isArray(parsed)) return [];
return parsed.filter(
@@ -106,30 +104,23 @@ class ConversationsStore {
}
}
/** Persist MCP default overrides to localStorage */
private saveMcpDefaults(): void {
if (typeof globalThis.localStorage === 'undefined') return;
const plain = this.pendingMcpServerOverrides.map((o) => ({
serverId: o.serverId,
enabled: o.enabled
}));
if (plain.length > 0) {
localStorage.setItem(MCP_DEFAULT_ENABLED_LOCALSTORAGE_KEY, JSON.stringify(plain));
} else {
localStorage.removeItem(MCP_DEFAULT_ENABLED_LOCALSTORAGE_KEY);
}
settingsStore.updateConfig(SETTINGS_KEYS.MCP_DEFAULT_SERVER_OVERRIDES, JSON.stringify(plain));
}
/** Load thinking-enabled default from localStorage */
private static loadThinkingDefaults(): boolean {
if (typeof globalThis.localStorage === 'undefined') return false;
if (typeof globalThis.localStorage === 'undefined') return true;
try {
const raw = localStorage.getItem(THINKING_ENABLED_DEFAULT_LOCALSTORAGE_KEY);
if (!raw) return false;
const parsed = raw === 'true';
return typeof parsed === 'boolean' ? parsed : false;
if (!raw) return true;
return raw === 'true';
} catch {
return false;
return true;
}
}
@@ -189,6 +180,10 @@ class ConversationsStore {
try {
await MigrationService.runAllMigrations();
// Re-read defaults after migrations: a migration may have populated
// the settings config (e.g. moved legacy MCP overrides into it).
this.pendingMcpServerOverrides = ConversationsStore.loadMcpDefaults();
await this.loadConversations();
this.isInitialized = true;
} catch (error) {
@@ -337,7 +332,7 @@ class ConversationsStore {
}
this.pendingMcpServerOverrides = [];
this.pendingThinkingEnabled = false;
this.pendingThinkingEnabled = ConversationsStore.loadThinkingDefaults();
this.activeConversation = conversation;
if (conversation.currNode) {
+36 -4
View File
@@ -20,11 +20,13 @@
*/
import { browser } from '$app/environment';
import { SvelteSet } from 'svelte/reactivity';
import { SETTINGS_KEYS } from '$lib/constants';
import { MCPService } from '$lib/services/mcp.service';
import { config, settingsStore } from '$lib/stores/settings.svelte';
import { mcpResourceStore } from '$lib/stores/mcp-resources.svelte';
import { serverStore } from '$lib/stores/server.svelte';
import { conversationsStore } from '$lib/stores/conversations.svelte';
import { mode } from 'mode-watcher';
import {
parseMcpServerSettings,
@@ -48,10 +50,11 @@ import {
EXPECTED_THEMED_ICON_PAIR_COUNT,
MCP_ALLOWED_ICON_MIME_TYPES,
MCP_SERVER_ID_PREFIX,
MCP_RECONNECT_INITIAL_DELAY,
MCP_RECONNECT_BACKOFF_MULTIPLIER,
MCP_RECONNECT_INITIAL_DELAY,
MCP_RECONNECT_MAX_DELAY,
MCP_RECONNECT_ATTEMPT_TIMEOUT_MS
MCP_RECONNECT_ATTEMPT_TIMEOUT_MS,
RECOMMENDED_MCP_SERVER_IDS
} from '$lib/constants';
import type {
MCPToolCall,
@@ -70,6 +73,7 @@ import type {
Tool,
HealthCheckState,
MCPServerSettingsEntry,
MCPServerDisplayInfo,
MCPServerConfig,
MCPResourceIcon,
MCPResourceAttachment,
@@ -365,7 +369,7 @@ class MCPStore {
return this.connections;
}
getServerLabel(server: MCPServerSettingsEntry): string {
getServerLabel(server: MCPServerDisplayInfo): string {
const healthState = this.getHealthCheckState(server.id);
if (healthState?.status === HealthCheckStatus.SUCCESS)
@@ -527,7 +531,7 @@ class MCPStore {
addServer(
serverData: Omit<MCPServerSettingsEntry, 'id' | 'requestTimeoutSeconds'> & { id?: string }
): void {
): MCPServerSettingsEntry {
const servers = this.getServers();
const newServer: MCPServerSettingsEntry = {
id: serverData.id || (uuid() ?? `server-${Date.now()}`),
@@ -540,6 +544,7 @@ class MCPStore {
useProxy: serverData.useProxy
};
settingsStore.updateConfig(SETTINGS_KEYS.MCP_SERVERS, JSON.stringify([...servers, newServer]));
return newServer;
}
updateServer(id: string, updates: Partial<MCPServerSettingsEntry>): void {
@@ -576,6 +581,33 @@ class MCPStore {
});
}
/**
* Recommended MCP server IDs the user opted in to via per-chat overrides.
* Single source of truth for "which recommendations has the user accepted",
* shared by the recommendations hook and the visible-servers getter.
*/
get optedInRecommendationIds(): ReadonlySet<string> {
const ids = new SvelteSet<string>();
for (const override of conversationsStore.pendingMcpServerOverrides) {
if (RECOMMENDED_MCP_SERVER_IDS.has(override.serverId) && override.enabled) {
ids.add(override.serverId);
}
}
return ids;
}
/**
* MCP servers selectable in chat-add UIs and the settings page:
* enabled in settings and either non-recommended or explicitly opted in.
*/
get visibleMcpServers(): MCPServerSettingsEntry[] {
const optedIn = this.optedInRecommendationIds;
return this.getServersSorted().filter(
(server) =>
server.enabled && (!RECOMMENDED_MCP_SERVER_IDS.has(server.id) || optedIn.has(server.id))
);
}
async ensureInitialized(perChatOverrides?: McpServerOverride[]): Promise<boolean> {
if (!browser) {
return false;
+1
View File
@@ -265,6 +265,7 @@ export interface ApiChatCompletionRequest {
stream?: boolean;
model?: string;
return_progress?: boolean;
sse_ping_interval?: number;
tools?: ApiChatCompletionTool[];
// Reasoning parameters
reasoning_format?: string;
+2
View File
@@ -127,6 +127,8 @@ export type {
MCPServerConfig,
MCPClientConfig,
MCPServerSettingsEntry,
MCPServerDisplayInfo,
RecommendedMCPServer,
MCPToolCall,
OpenAIToolDefinition,
ServerStatus,
+18 -3
View File
@@ -209,17 +209,32 @@ export type MCPToolCall = {
};
};
export type MCPServerSettingsEntry = {
/**
* Minimum fields needed to display or identify an MCP server.
*/
export interface MCPServerDisplayInfo {
id: string;
enabled: boolean;
name?: string;
url: string;
}
export type MCPServerSettingsEntry = MCPServerDisplayInfo & {
enabled: boolean;
requestTimeoutSeconds: number;
headers?: string;
name?: string;
iconUrl?: string;
useProxy?: boolean;
};
/**
* Pre-defined recommended MCP server shown to the user in onboarding/picker UIs.
*/
export interface RecommendedMCPServer extends MCPServerDisplayInfo {
description: string;
enabled: boolean;
requestTimeoutSeconds: number;
}
export interface MCPHostManagerConfig {
servers: MCPClientConfig['servers'];
clientInfo?: Implementation;
+39 -95
View File
@@ -92,18 +92,14 @@ export function filterByLeafNodeId(
* Finds the leaf node (message with no children) for a given message branch.
* Traverses down the tree following the last child until reaching a leaf.
*
* @param messages - All messages in the conversation
* @param nodeMap - Map of messages keyed by ID
* @param messageId - Starting message ID to find leaf for
* @returns The leaf node ID, or the original messageId if no children
*/
export function findLeafNode(messages: readonly DatabaseMessage[], messageId: string): string {
const nodeMap = new Map<string, DatabaseMessage>();
// Build node map for quick lookups
for (const msg of messages) {
nodeMap.set(msg.id, msg);
}
function findLeafNodeInMap(
nodeMap: ReadonlyMap<string, DatabaseMessage>,
messageId: string
): string {
let currentNode: DatabaseMessage | undefined = nodeMap.get(messageId);
while (currentNode && currentNode.children.length > 0) {
// Follow the last child (most recent branch)
@@ -114,6 +110,22 @@ export function findLeafNode(messages: readonly DatabaseMessage[], messageId: st
return currentNode?.id ?? messageId;
}
/**
* Convenience wrapper around {@link findLeafNodeInMap} for callers that only have
* a flat message array.
*
* Finds the leaf node (message with no children) for a given message branch.
* Traverses down the tree following the last child until reaching a leaf.
*
* @param messages - All messages in the conversation
* @param messageId - Starting message ID to find leaf for
* @returns The leaf node ID, or the original messageId if no children
*/
export function findLeafNode(messages: readonly DatabaseMessage[], messageId: string): string {
const nodeMap = new Map(messages.map((msg) => [msg.id, msg] as const));
return findLeafNodeInMap(nodeMap, messageId);
}
/**
* Finds all descendant messages (children, grandchildren, etc.) of a given message.
* This is used for cascading deletion to remove all messages in a branch.
@@ -156,21 +168,14 @@ export function findDescendantMessages(
* Gets sibling information for a message, including all sibling IDs and current position.
* Siblings are messages that share the same parent.
*
* @param messages - All messages in the conversation
* @param nodeMap - Map of messages keyed by ID
* @param messageId - The message to get sibling info for
* @returns Sibling information including leaf node IDs for navigation
*/
export function getMessageSiblings(
messages: readonly DatabaseMessage[],
nodeMap: ReadonlyMap<string, DatabaseMessage>,
messageId: string
): ChatMessageSiblingInfo | null {
const nodeMap = new Map<string, DatabaseMessage>();
// Build node map for quick lookups
for (const msg of messages) {
nodeMap.set(msg.id, msg);
}
const message = nodeMap.get(messageId);
if (!message) {
return null;
@@ -203,7 +208,9 @@ export function getMessageSiblings(
// Convert sibling message IDs to their corresponding leaf node IDs
// This allows navigation between different conversation branches
const siblingLeafIds = siblingIds.map((siblingId: string) => findLeafNode(messages, siblingId));
const siblingLeafIds = siblingIds.map((siblingId: string) =>
findLeafNodeInMap(nodeMap, siblingId)
);
// Find current message's position among siblings
const currentIndex = siblingIds.indexOf(messageId);
@@ -217,85 +224,22 @@ export function getMessageSiblings(
}
/**
* Creates a display-ready list of messages with sibling information for UI rendering.
* This is the main function used by chat components to render conversation branches.
* Builds sibling information for every message in a conversation.
* A single node map is shared across all lookups for O(1) access.
*
* @param messages - All messages in the conversation
* @param leafNodeId - Current leaf node being viewed
* @returns Array of messages with sibling navigation info
* @returns Map of message ID to its sibling information
*/
export function getMessageDisplayList(
messages: readonly DatabaseMessage[],
leafNodeId: string
): ChatMessageSiblingInfo[] {
// Get the current conversation path
const currentPath = filterByLeafNodeId(messages, leafNodeId, true);
const result: ChatMessageSiblingInfo[] = [];
// Add sibling info for each message in the current path
for (const message of currentPath) {
if (message.type === 'root') {
continue; // Skip root messages in display
}
const siblingInfo = getMessageSiblings(messages, message.id);
if (siblingInfo) {
result.push(siblingInfo);
export function buildSiblingInfoMap(
messages: readonly DatabaseMessage[]
): Map<string, ChatMessageSiblingInfo> {
const nodeMap = new Map(messages.map((msg) => [msg.id, msg] as const));
const siblingMap = new Map<string, ChatMessageSiblingInfo>();
for (const msg of messages) {
const info = getMessageSiblings(nodeMap, msg.id);
if (info) {
siblingMap.set(msg.id, info);
}
}
return result;
}
/**
* Checks if a message has multiple siblings (indicating branching at that point).
*
* @param messages - All messages in the conversation
* @param messageId - The message to check
* @returns True if the message has siblings
*/
export function hasMessageSiblings(
messages: readonly DatabaseMessage[],
messageId: string
): boolean {
const siblingInfo = getMessageSiblings(messages, messageId);
return siblingInfo ? siblingInfo.totalSiblings > 1 : false;
}
/**
* Gets the next sibling message ID for navigation.
*
* @param messages - All messages in the conversation
* @param messageId - Current message ID
* @returns Next sibling's leaf node ID, or null if at the end
*/
export function getNextSibling(
messages: readonly DatabaseMessage[],
messageId: string
): string | null {
const siblingInfo = getMessageSiblings(messages, messageId);
if (!siblingInfo || siblingInfo.currentIndex >= siblingInfo.totalSiblings - 1) {
return null;
}
return siblingInfo.siblingIds[siblingInfo.currentIndex + 1];
}
/**
* Gets the previous sibling message ID for navigation.
*
* @param messages - All messages in the conversation
* @param messageId - Current message ID
* @returns Previous sibling's leaf node ID, or null if at the beginning
*/
export function getPreviousSibling(
messages: readonly DatabaseMessage[],
messageId: string
): string | null {
const siblingInfo = getMessageSiblings(messages, messageId);
if (!siblingInfo || siblingInfo.currentIndex <= 0) {
return null;
}
return siblingInfo.siblingIds[siblingInfo.currentIndex - 1];
return siblingMap;
}
+1 -4
View File
@@ -26,10 +26,7 @@ export {
findLeafNode,
findDescendantMessages,
getMessageSiblings,
getMessageDisplayList,
hasMessageSiblings,
getNextSibling,
getPreviousSibling
buildSiblingInfoMap
} from './branching';
// Code
+9
View File
@@ -8,6 +8,7 @@
import { onMount } from 'svelte';
import { SidebarNavigation, DialogConversationTitleUpdate } from '$lib/components/app';
import { DialogMcpServerRecommendations } from '$lib/components/app/dialogs';
import { PwaMetaTags, PwaRefreshAlert } from '$lib/components/pwa';
import { pwaAssetsHead } from 'virtual:pwa-assets/head';
@@ -26,6 +27,7 @@
import { FAVICON_PATHS, FAVICON_SELECTORS } from '$lib/constants/pwa';
import { useKeyboardShortcuts } from '$lib/hooks/use-keyboard-shortcuts.svelte';
import { usePwa } from '$lib/hooks/use-pwa.svelte';
import { useMcpRecommendations } from '$lib/hooks/use-mcp-recommendations.svelte';
import { conversations } from '$lib/stores/conversations.svelte';
import { isMobile } from '$lib/stores/viewport.svelte';
import { theme } from '$lib/stores/theme.svelte';
@@ -37,6 +39,8 @@
let innerHeight = $state<number | undefined>();
let innerWidth = $state(browser ? window.innerWidth : 0);
const mcpRecommendations = useMcpRecommendations();
let chatSidebar:
| {
activateSearchMode?: () => void;
@@ -321,6 +325,11 @@
onConfirm={handleTitleUpdateConfirm}
onCancel={handleTitleUpdateCancel}
/>
<DialogMcpServerRecommendations
open={mcpRecommendations.open}
onOpenChange={mcpRecommendations.handleOpenChange}
/>
</Tooltip.Provider>
<!-- PWA update prompt + version -->
@@ -0,0 +1,37 @@
<script lang="ts">
import { untrack } from 'svelte';
import McpServerForm from '$lib/components/app/mcp/McpServerForm.svelte';
interface Props {
headers?: string;
}
let { headers = '' }: Props = $props();
let headersState = $state(untrack(() => headers));
let lastCapturedHeaders = $state(untrack(() => headers));
$effect(() => {
if (headers !== lastCapturedHeaders) {
headersState = headers;
lastCapturedHeaders = headers;
}
});
</script>
<!--
Drives McpServerForm with a controlled `headers` string and exposes the
latest captured value through `data-captured-headers` so the client test
can read it back without a custom binding API.
-->
<McpServerForm
url="https://example.test/mcp"
headers={headersState}
onUrlChange={() => {}}
onHeadersChange={(value) => {
headersState = value;
}}
id="mcp-server-form-test"
/>
<div data-testid="captured-headers" data-captured-headers={headersState} hidden></div>
@@ -0,0 +1,133 @@
import { describe, expect, it } from 'vitest';
import { render } from 'vitest-browser-svelte';
import McpServerFormWrapper from './components/McpServerFormWrapper.svelte';
const AUTHORIZATION_HEADER = 'Authorization';
const BEARER_PREFIX = 'Bearer ';
const BEARER_PLACEHOLDER = 'Paste token here';
/**
* Client-side tests for the McpServerForm bearer UI.
*
* The dedicated UI only "owns" Authorization headers that already carry a
* Bearer scheme (heuristic check on the value). Other Authorization values
* stay in the KV section so the user can still edit them verbatim. Storage
* always goes through the same custom-headers slot, so a round-trip via this
* UI produces exactly one `Authorization: Bearer <token>` entry.
*
* Equivalent parser coverage lives in `tests/unit/headers.test.ts`.
*/
describe('McpServerForm - Authorization / bearer UI', () => {
function bearerInput(screen: Awaited<ReturnType<typeof render>>) {
return screen.locator.getByPlaceholder(BEARER_PLACEHOLDER);
}
function capturedHeaders(screen: Awaited<ReturnType<typeof render>>) {
return screen.getByTestId('captured-headers');
}
it('mounts with the bearer input hidden when no auth header is present', async () => {
const screen = await render(McpServerFormWrapper, { headers: '' });
await expect.element(screen.getByRole('textbox', { name: /server url/i })).toBeVisible();
await expect.element(bearerInput(screen)).not.toBeInTheDocument();
});
it('toggling Authorization shows the bearer input', async () => {
const screen = await render(McpServerFormWrapper, { headers: '' });
await screen.getByRole('switch', { name: /authorization/i }).click();
await expect.element(bearerInput(screen)).toBeVisible();
});
it('typing a token writes the Authorization row with the Bearer prefix prepended', async () => {
const screen = await render(McpServerFormWrapper, { headers: '' });
await screen.getByRole('switch', { name: /authorization/i }).click();
const token = 'super-secret';
await bearerInput(screen).fill(token);
const expected = JSON.stringify({ [AUTHORIZATION_HEADER]: `${BEARER_PREFIX}${token}` });
await expect
.element(capturedHeaders(screen))
.toHaveAttribute('data-captured-headers', expected);
});
it('pre-existing Bearer header pre-fills the bearer input with the token stripped', async () => {
const existing = JSON.stringify({
'X-Trace-Id': 'abc',
[AUTHORIZATION_HEADER]: `${BEARER_PREFIX}preexisting`
});
const screen = await render(McpServerFormWrapper, { headers: existing });
await expect.element(bearerInput(screen)).toBeVisible();
await expect.element(bearerInput(screen)).toHaveValue('preexisting');
});
it('non-Bearer Authorization is ignored by the dedicated UI and stays in the KV section', async () => {
const existing = JSON.stringify({ [AUTHORIZATION_HEADER]: 'Basic czNjcjpwYXNz' });
const screen = await render(McpServerFormWrapper, { headers: existing });
await expect.element(bearerInput(screen)).not.toBeInTheDocument();
const headerKeyInput = screen.getByPlaceholder('Header name');
await expect.element(headerKeyInput).toBeVisible();
});
it('engaging the token UI replaces a non-Bearer Authorization with the Bearer scheme', async () => {
const existing = JSON.stringify({ [AUTHORIZATION_HEADER]: 'Basic old' });
const screen = await render(McpServerFormWrapper, { headers: existing });
await screen.getByRole('switch', { name: /authorization/i }).click();
await bearerInput(screen).fill('new');
const expected = JSON.stringify({ [AUTHORIZATION_HEADER]: `${BEARER_PREFIX}new` });
await expect
.element(capturedHeaders(screen))
.toHaveAttribute('data-captured-headers', expected);
});
it('toggling Authorization off with no token drops the Bearer row but keeps non-Bearer schemes', async () => {
const existing = JSON.stringify({ [AUTHORIZATION_HEADER]: `${BEARER_PREFIX}xyz` });
const screen = await render(McpServerFormWrapper, { headers: existing });
await screen.getByRole('switch', { name: /authorization/i }).click();
await expect.element(capturedHeaders(screen)).toHaveAttribute('data-captured-headers', '');
});
it('toggling Authorization off when no Bearer row is present leaves headers untouched', async () => {
const existing = JSON.stringify({ [AUTHORIZATION_HEADER]: 'Basic czNjcjpwYXNz' });
const screen = await render(McpServerFormWrapper, { headers: existing });
await screen.getByRole('switch', { name: /authorization/i }).click();
await screen.getByRole('switch', { name: /authorization/i }).click();
await expect
.element(capturedHeaders(screen))
.toHaveAttribute('data-captured-headers', existing);
});
it('clearing the bearer input drops the Authorization row', async () => {
const existing = JSON.stringify({ [AUTHORIZATION_HEADER]: `${BEARER_PREFIX}xyz` });
const screen = await render(McpServerFormWrapper, { headers: existing });
await bearerInput(screen).fill('');
await expect.element(capturedHeaders(screen)).toHaveAttribute('data-captured-headers', '');
});
it('does not surface Bearer Authorization in the KV section even when pre-existing', async () => {
const existing = JSON.stringify({ [AUTHORIZATION_HEADER]: `${BEARER_PREFIX}xyz` });
const screen = await render(McpServerFormWrapper, { headers: existing });
const headerKeyInput = screen.getByPlaceholder('Header name');
await expect.element(headerKeyInput).not.toBeInTheDocument();
});
});
+126
View File
@@ -0,0 +1,126 @@
import { describe, expect, it } from 'vitest';
import { parseHeadersToArray, serializeHeaders } from '$lib/utils/headers';
/**
* Tests for the header serialization helpers used by the MCP server form
* (custom header rows) and the new Authorization/Bearer-token flow.
*/
describe('parseHeadersToArray', () => {
it('returns an empty array for empty or whitespace-only input', () => {
expect(parseHeadersToArray('')).toEqual([]);
expect(parseHeadersToArray(' ')).toEqual([]);
expect(parseHeadersToArray(undefined as unknown as string)).toEqual([]);
});
it('returns an empty array for invalid JSON input', () => {
expect(parseHeadersToArray('{not-json')).toEqual([]);
expect(parseHeadersToArray('[]')).toEqual([]);
expect(parseHeadersToArray('"plain-string"')).toEqual([]);
});
it('converts an object into ordered key/value pairs', () => {
expect(parseHeadersToArray('{"X-Foo":"bar","Authorization":"Bearer abc"}')).toEqual([
{ key: 'X-Foo', value: 'bar' },
{ key: 'Authorization', value: 'Bearer abc' }
]);
});
it('stringifies non-string values', () => {
expect(parseHeadersToArray('{"count":"42","flag":"true"}')).toEqual([
{ key: 'count', value: '42' },
{ key: 'flag', value: 'true' }
]);
});
});
describe('serializeHeaders', () => {
it('returns an empty string when there are no valid pairs', () => {
expect(serializeHeaders([])).toBe('');
expect(serializeHeaders([{ key: '', value: 'value' }])).toBe('');
expect(serializeHeaders([{ key: ' ', value: 'value' }])).toBe('');
});
it('returns an empty string when every pair has a blank key', () => {
expect(
serializeHeaders([
{ key: '', value: 'drop-me' },
{ key: ' ', value: 'drop-me-too' },
{ key: '\t', value: 'tab-key' }
])
).toBe('');
});
it('drops pairs with empty keys but keeps the rest', () => {
expect(
serializeHeaders([
{ key: '', value: 'drop-me' },
{ key: 'X-Keep', value: 'ok' }
])
).toBe('{"X-Keep":"ok"}');
});
it('trims keys before serializing', () => {
expect(serializeHeaders([{ key: ' X-Space ', value: 'ok' }])).toBe('{"X-Space":"ok"}');
});
it('preserves the input order of surviving pairs', () => {
const serialized = serializeHeaders([
{ key: 'X-C', value: '3' },
{ key: 'X-A', value: '1' },
{ key: 'X-B', value: '2' }
]);
// Object key order follows insertion order in modern JS engines, so
// the serialized JSON writes keys in our input order.
expect(JSON.parse(serialized)).toEqual({ 'X-C': '3', 'X-A': '1', 'X-B': '2' });
});
});
describe('parseHeadersToArray / serializeHeaders roundtrip', () => {
it('serializes back to an equal header object after a parse', () => {
const original = JSON.stringify({
'Content-Type': 'application/json',
'X-Trace-Id': 'abc-123'
});
const roundtrip = serializeHeaders(parseHeadersToArray(original));
expect(JSON.parse(roundtrip)).toEqual(JSON.parse(original));
});
it('drops rows whose keys are blank after trimming during serialization', () => {
const pairs = parseHeadersToArray('{"X-Keep":"ok","":"drop-me"}');
// parseHeadersToArray keeps raw key strings (the consumer is expected to
// filter blanks, not the parser); serialization must strip them.
expect(pairs).toEqual([
{ key: 'X-Keep', value: 'ok' },
{ key: '', value: 'drop-me' }
]);
expect(serializeHeaders(pairs)).toBe('{"X-Keep":"ok"}');
});
it('preserves upstream keys untouched (does not lowercase them)', () => {
const upperCased = '{"Authorization":"Bearer xyz"}';
const parsed = parseHeadersToArray(upperCased);
expect(parsed).toEqual([{ key: 'Authorization', value: 'Bearer xyz' }]);
});
it('bearer-token write survives a re-parse when paired with regular custom headers', () => {
// The McpServerForm bearer UI writes {Authorization: `Bearer <token>`}
// into the same headers string as the custom KV section. The round
// trip below mirrors the exact shape the form produces so a future
// refactor of either code path cannot silently change the on-disk key.
const pairs = [
{ key: 'X-Trace-Id', value: 'abc-123' },
{ key: 'Authorization', value: 'Bearer super-secret' }
];
const serialized = serializeHeaders(pairs);
expect(serialized).toBe('{"X-Trace-Id":"abc-123","Authorization":"Bearer super-secret"}');
expect(parseHeadersToArray(serialized)).toEqual(pairs);
});
});
@@ -0,0 +1,144 @@
import { describe, expect, it, vi } from 'vitest';
import { parseMcpServerSettings } from '$lib/utils/mcp';
import { DEFAULT_MCP_CONFIG, MCP_SERVER_ID_PREFIX } from '$lib/constants/mcp';
/**
* Tests for the mcpServers settings parser.
*
* The branch seeds the MCP servers setting with a default value of
* `JSON.stringify(RECOMMENDED_MCP_SERVERS)`, so the parser has to be
* resilient to anything that may live in the user's localStorage: malformed
* JSON, wrong shapes, missing fields, falsy-but-not-zero numbers, and entry
* arrays that have been mutated by the user via the settings form.
*/
describe('parseMcpServerSettings', () => {
it('returns an empty array for falsy or whitespace-only input', () => {
expect(parseMcpServerSettings(null)).toEqual([]);
expect(parseMcpServerSettings(undefined)).toEqual([]);
expect(parseMcpServerSettings('')).toEqual([]);
expect(parseMcpServerSettings(' ')).toEqual([]);
});
it('returns an empty array and logs a warning for invalid JSON strings', () => {
const warn = vi.spyOn(console, 'warn').mockImplementation(() => {});
expect(parseMcpServerSettings('{not-json')).toEqual([]);
expect(warn).toHaveBeenCalled();
warn.mockRestore();
});
it('returns an empty array for valid JSON that is not an array', () => {
expect(parseMcpServerSettings('"plain-string"')).toEqual([]);
expect(parseMcpServerSettings('{"id":"foo"}')).toEqual([]);
expect(parseMcpServerSettings('42')).toEqual([]);
expect(parseMcpServerSettings('null')).toEqual([]);
});
it('drops entries with no parseable id and substitutes a stable fallback', () => {
const parsed = parseMcpServerSettings(
JSON.stringify([{ url: 'https://a.test', enabled: true }, { url: 'https://b.test' }])
);
expect(parsed).toHaveLength(2);
expect(parsed[0]?.id).toBe(`${MCP_SERVER_ID_PREFIX}-1`);
expect(parsed[1]?.id).toBe(`${MCP_SERVER_ID_PREFIX}-2`);
});
it('reuses the first id when it is present and falls back only for missing ones', () => {
const parsed = parseMcpServerSettings(
JSON.stringify([
{ id: 'custom-1', url: 'https://a.test' },
{ url: 'https://b.test' },
{ id: 'custom-3', url: 'https://c.test' }
])
);
expect(parsed[0]?.id).toBe('custom-1');
expect(parsed[1]?.id).toBe(`${MCP_SERVER_ID_PREFIX}-2`);
expect(parsed[2]?.id).toBe('custom-3');
});
it('falls back to the configured default requestTimeoutSeconds only for nullish values', () => {
const fallback = DEFAULT_MCP_CONFIG.requestTimeoutSeconds;
const parsed = parseMcpServerSettings(
JSON.stringify([
{ id: 'a', url: 'https://a.test' },
{ id: 'b', url: 'https://b.test', requestTimeoutSeconds: undefined },
{ id: 'c', url: 'https://c.test', requestTimeoutSeconds: 0 },
{ id: 'd', url: 'https://d.test', requestTimeoutSeconds: 45 }
])
);
// The parser uses ?? for timeout fallback, which only triggers on
// null/undefined. Explicit 0 is preserved at face value.
expect(parsed[0]?.requestTimeoutSeconds).toBe(fallback);
expect(parsed[1]?.requestTimeoutSeconds).toBe(fallback);
expect(parsed[2]?.requestTimeoutSeconds).toBe(0);
expect(parsed[3]?.requestTimeoutSeconds).toBe(45);
});
it('treats whitespace-only headers strings as undefined', () => {
const parsed = parseMcpServerSettings(
JSON.stringify([
{ id: 'a', url: 'https://a.test', headers: ' ' },
{ id: 'b', url: 'https://b.test', headers: '{"X-Foo":"bar"}' }
])
);
// The parser trims headers and coerces empty/whitespace to undefined.
expect(parsed[0]?.headers).toBeUndefined();
expect(parsed[1]?.headers).toBe('{"X-Foo":"bar"}');
});
it('defaults coercion for booleans (undefined -> false, true -> true)', () => {
const parsed = parseMcpServerSettings(
JSON.stringify([
{ id: 'a', url: 'https://a.test' },
{ id: 'b', url: 'https://b.test', enabled: true },
{ id: 'c', url: 'https://c.test', enabled: false },
{ id: 'd', url: 'https://d.test', useProxy: true }
])
);
expect(parsed[0]?.enabled).toBe(false);
expect(parsed[1]?.enabled).toBe(true);
expect(parsed[2]?.enabled).toBe(false);
expect(parsed[0]?.useProxy).toBe(false);
expect(parsed[3]?.useProxy).toBe(true);
});
it('preserves input order when mapping entries', () => {
const source = [
{ id: 'gamma', url: 'https://c.test' },
{ id: 'alpha', url: 'https://a.test' },
{ id: 'beta', url: 'https://b.test' }
];
const parsed = parseMcpServerSettings(JSON.stringify(source));
expect(parsed.map((entry) => entry.id)).toEqual(['gamma', 'alpha', 'beta']);
});
it('passes non-string raw input through the JSON-equality path', () => {
const parsed = parseMcpServerSettings([
{ id: 'a', url: 'https://a.test' },
{ id: 'b', url: 'https://b.test', enabled: true }
]);
expect(parsed).toHaveLength(2);
expect(parsed[0]?.id).toBe('a');
expect(parsed[1]?.enabled).toBe(true);
});
it('coerces non-string url values to an empty string rather than throwing', () => {
const parsed = parseMcpServerSettings(
JSON.stringify([{ id: 'a', url: 42 }, { id: 'b' }, { id: 'c', url: 'https://c.test' }])
);
expect(parsed[0]?.url).toBe('');
expect(parsed[1]?.url).toBe('');
expect(parsed[2]?.url).toBe('https://c.test');
});
});
@@ -0,0 +1,90 @@
import { describe, expect, it } from 'vitest';
import {
RECOMMENDED_MCP_SERVER_IDS,
RECOMMENDED_MCP_SERVERS
} from '$lib/constants/recommended-mcp-servers';
import { parseMcpServerSettings } from '$lib/utils/mcp';
import { DEFAULT_MCP_CONFIG, MCP_SERVER_ID_PREFIX } from '$lib/constants/mcp';
/**
* Tests for the predefined recommended MCP servers.
*
* These are surfaced to first-time users via
* DialogMcpServerRecommendations and used as the default value of the MCP
* servers setting, so a regression that breaks the round-trip through the
* settings parser would silently break onboarding for new users.
*/
describe('RECOMMENDED_MCP_SERVERS', () => {
it('lists at least one entry and uses stable, unique ids', () => {
expect(RECOMMENDED_MCP_SERVERS.length).toBeGreaterThan(0);
const ids = RECOMMENDED_MCP_SERVERS.map((server) => server.id);
expect(new Set(ids).size).toBe(ids.length);
for (const id of ids) {
expect(id).toMatch(/^[a-z0-9-]+$/);
expect(id.toLowerCase()).not.toContain(MCP_SERVER_ID_PREFIX.toLowerCase());
}
});
it('requires a name, description and url for every entry', () => {
for (const server of RECOMMENDED_MCP_SERVERS) {
expect(server.name?.trim().length ?? 0).toBeGreaterThan(0);
expect(server.description.trim().length).toBeGreaterThan(0);
expect(server.url.trim().length).toBeGreaterThan(0);
expect(() => new URL(server.url)).not.toThrow();
}
});
});
describe('RECOMMENDED_MCP_SERVER_IDS', () => {
it('matches the ids declared in RECOMMENDED_MCP_SERVERS', () => {
expect(RECOMMENDED_MCP_SERVER_IDS.size).toBe(RECOMMENDED_MCP_SERVERS.length);
for (const server of RECOMMENDED_MCP_SERVERS) {
expect(RECOMMENDED_MCP_SERVER_IDS.has(server.id)).toBe(true);
}
});
});
describe('recommended-mcp-servers default value', () => {
it('round-trips cleanly through parseMcpServerSettings', () => {
const serialized = JSON.stringify(RECOMMENDED_MCP_SERVERS);
const parsed = parseMcpServerSettings(serialized);
expect(parsed).toHaveLength(RECOMMENDED_MCP_SERVERS.length);
for (let index = 0; index < RECOMMENDED_MCP_SERVERS.length; index++) {
const source = RECOMMENDED_MCP_SERVERS[index];
const entry = parsed[index];
expect(entry).toBeDefined();
expect(entry?.id).toBe(source.id);
expect(entry?.url).toBe(source.url);
expect(entry?.enabled).toBe(source.enabled);
expect(entry?.requestTimeoutSeconds).toBe(source.requestTimeoutSeconds);
expect(entry?.name).toBe(source.name);
// Headers and useProxy are not set on recommended servers; the
// parser must fall back to the inactive defaults rather than
// surfacing undefined-boundary states.
expect(entry?.headers).toBeUndefined();
expect(entry?.useProxy).toBe(false);
}
});
it('uses the global default timeout when one is not specified on an entry', () => {
const sourceOnlyRequired = {
id: 'roundtrip-only',
name: 'Only required fields',
url: 'https://example.test/mcp',
description: 'Smoke entry for parser roundtrip with default timeout.',
enabled: true
};
const parsed = parseMcpServerSettings(JSON.stringify([sourceOnlyRequired]));
const entry = parsed[0];
expect(entry?.requestTimeoutSeconds).toBe(DEFAULT_MCP_CONFIG.requestTimeoutSeconds);
});
});
+127 -35
View File
@@ -478,7 +478,7 @@ bool set_socket_opt_time(socket_t sock, int level, int optname,
}
bool is_hex(char c, int &v) {
if (isdigit(static_cast<unsigned char>(c))) {
if (is_ascii_digit(c)) {
v = c - '0';
return true;
} else if ('A' <= c && c <= 'F') {
@@ -695,7 +695,11 @@ std::string base64_encode(const std::string &in) {
std::string out;
out.reserve(in.size());
auto val = 0;
// Unsigned: the accumulator is never masked, so with a signed int the
// `val << 8` below overflows once enough bytes are folded in (undefined
// behaviour before C++20). Only the low bits are ever emitted, so the
// wrap-around of an unsigned accumulator does not affect the output.
uint32_t val = 0;
auto valb = -6;
for (auto c : in) {
@@ -3887,8 +3891,7 @@ bool parse_range_header(const std::string &s, Ranges &ranges) {
bool parse_range_header(const std::string &s, Ranges &ranges) try {
#endif
auto is_valid = [](const std::string &str) {
return std::all_of(str.cbegin(), str.cend(),
[](unsigned char c) { return std::isdigit(c); });
return std::all_of(str.cbegin(), str.cend(), is_ascii_digit);
};
if (s.size() > 7 && s.compare(0, 6, "bytes=") == 0) {
@@ -4336,7 +4339,7 @@ bool is_multipart_boundary_chars_valid(const std::string &boundary) {
auto valid = true;
for (size_t i = 0; i < boundary.size(); i++) {
auto c = boundary[i];
if (!std::isalnum(static_cast<unsigned char>(c)) && c != '-' && c != '_') {
if (!is_ascii_alnum(c) && c != '-' && c != '_') {
valid = false;
break;
}
@@ -4344,18 +4347,47 @@ bool is_multipart_boundary_chars_valid(const std::string &boundary) {
return valid;
}
// Escape a multipart field name/filename following the WHATWG HTML standard
// ("escape a multipart form-data name"), which is what browsers send:
// '"' -> %22, CR -> %0D, LF -> %0A
// With escape_quote = false, only CR and LF are escaped; this is for header
// values outside a quoted-string (e.g. Content-Type), where '"' is legal.
std::string escape_multipart_field(const std::string &s,
bool escape_quote = true) {
std::string result;
result.reserve(s.size());
for (auto c : s) {
switch (c) {
case '"':
if (escape_quote) {
result += "%22";
} else {
result += c;
}
break;
case '\r': result += "%0D"; break;
case '\n': result += "%0A"; break;
default: result += c; break;
}
}
return result;
}
template <typename T>
std::string
serialize_multipart_formdata_item_begin(const T &item,
const std::string &boundary) {
std::string body = "--" + boundary + "\r\n";
body += "Content-Disposition: form-data; name=\"" + item.name + "\"";
body += "Content-Disposition: form-data; name=\"" +
escape_multipart_field(item.name) + "\"";
if (!item.filename.empty()) {
body += "; filename=\"" + item.filename + "\"";
body += "; filename=\"" + escape_multipart_field(item.filename) + "\"";
}
body += "\r\n";
if (!item.content_type.empty()) {
body += "Content-Type: " + item.content_type + "\r\n";
body +=
"Content-Type: " + escape_multipart_field(item.content_type, false) +
"\r\n";
}
body += "\r\n";
@@ -4821,10 +4853,9 @@ private:
namespace fields {
bool is_token_char(char c) {
return std::isalnum(static_cast<unsigned char>(c)) || c == '!' || c == '#' ||
c == '$' || c == '%' || c == '&' || c == '\'' || c == '*' ||
c == '+' || c == '-' || c == '.' || c == '^' || c == '_' || c == '`' ||
c == '|' || c == '~';
return is_ascii_alnum(c) || c == '!' || c == '#' || c == '$' || c == '%' ||
c == '&' || c == '\'' || c == '*' || c == '+' || c == '-' ||
c == '.' || c == '^' || c == '_' || c == '`' || c == '|' || c == '~';
}
bool is_token(const std::string &s) {
@@ -4873,7 +4904,8 @@ bool is_field_value(const std::string &s) { return is_field_content(s); }
} // namespace fields
bool perform_websocket_handshake(Stream &strm, const std::string &host,
int port, const std::string &path,
int port, bool is_ssl,
const std::string &path,
const Headers &headers,
std::string &selected_subprotocol) {
// Validate path and host
@@ -4899,7 +4931,7 @@ bool perform_websocket_handshake(Stream &strm, const std::string &host,
// Build upgrade request
std::string req_str = "GET " + path + " HTTP/1.1\r\n";
req_str += "Host: " + host + ":" + std::to_string(port) + "\r\n";
req_str += "Host: " + make_host_and_port_string(host, port, is_ssl) + "\r\n";
req_str += "Upgrade: websocket\r\n";
req_str += "Connection: Upgrade\r\n";
req_str += "Sec-WebSocket-Key: " + client_key + "\r\n";
@@ -5599,9 +5631,8 @@ std::string encode_uri_component(const std::string &value) {
escaped << std::hex;
for (auto c : value) {
if (std::isalnum(static_cast<uint8_t>(c)) || c == '-' || c == '_' ||
c == '.' || c == '!' || c == '~' || c == '*' || c == '\'' || c == '(' ||
c == ')') {
if (detail::is_ascii_alnum(c) || c == '-' || c == '_' || c == '.' ||
c == '!' || c == '~' || c == '*' || c == '\'' || c == '(' || c == ')') {
escaped << c;
} else {
escaped << std::uppercase;
@@ -5620,10 +5651,10 @@ std::string encode_uri(const std::string &value) {
escaped << std::hex;
for (auto c : value) {
if (std::isalnum(static_cast<uint8_t>(c)) || c == '-' || c == '_' ||
c == '.' || c == '!' || c == '~' || c == '*' || c == '\'' || c == '(' ||
c == ')' || c == ';' || c == '/' || c == '?' || c == ':' || c == '@' ||
c == '&' || c == '=' || c == '+' || c == '$' || c == ',' || c == '#') {
if (detail::is_ascii_alnum(c) || c == '-' || c == '_' || c == '.' ||
c == '!' || c == '~' || c == '*' || c == '\'' || c == '(' || c == ')' ||
c == ';' || c == '/' || c == '?' || c == ':' || c == '@' || c == '&' ||
c == '=' || c == '+' || c == '$' || c == ',' || c == '#') {
escaped << c;
} else {
escaped << std::uppercase;
@@ -5684,7 +5715,8 @@ std::string encode_path_component(const std::string &component) {
auto c = static_cast<unsigned char>(component[i]);
// Unreserved characters per RFC 3986: ALPHA / DIGIT / "-" / "." / "_" / "~"
if (std::isalnum(c) || c == '-' || c == '.' || c == '_' || c == '~') {
if (detail::is_ascii_alnum(static_cast<char>(c)) || c == '-' || c == '.' ||
c == '_' || c == '~') {
result += static_cast<char>(c);
}
// Path-safe sub-delimiters: "!" / "$" / "&" / "'" / "(" / ")" / "*" / "+" /
@@ -5757,7 +5789,8 @@ std::string encode_query_component(const std::string &component,
auto c = static_cast<unsigned char>(component[i]);
// Unreserved characters per RFC 3986
if (std::isalnum(c) || c == '-' || c == '.' || c == '_' || c == '~') {
if (detail::is_ascii_alnum(static_cast<char>(c)) || c == '-' || c == '.' ||
c == '_' || c == '~') {
result += static_cast<char>(c);
}
// Space handling
@@ -6010,6 +6043,48 @@ size_t MultipartFormData::get_file_count(const std::string &key) const {
return static_cast<size_t>(std::distance(r.first, r.second));
}
// Multipart FormData writer implementation
bool is_valid_multipart_boundary(const std::string &boundary) {
return detail::is_multipart_boundary_chars_valid(boundary);
}
MultipartFormDataWriter::MultipartFormDataWriter()
: boundary_(detail::make_multipart_data_boundary()) {}
MultipartFormDataWriter::MultipartFormDataWriter(std::string boundary)
: boundary_(std::move(boundary)) {}
const std::string &MultipartFormDataWriter::boundary() const {
return boundary_;
}
std::string MultipartFormDataWriter::content_type() const {
return detail::serialize_multipart_formdata_get_content_type(boundary_);
}
std::string
MultipartFormDataWriter::serialize(const UploadFormDataItems &items) const {
return detail::serialize_multipart_formdata(items, boundary_);
}
size_t MultipartFormDataWriter::content_length(
const UploadFormDataItems &items) const {
return detail::get_multipart_content_length(items, boundary_);
}
std::string
MultipartFormDataWriter::item_begin(const UploadFormData &item) const {
return detail::serialize_multipart_formdata_item_begin(item, boundary_);
}
std::string MultipartFormDataWriter::item_end() {
return detail::serialize_multipart_formdata_item_end();
}
std::string MultipartFormDataWriter::finish() const {
return detail::serialize_multipart_formdata_finish(boundary_);
}
// Response implementation
size_t Response::get_header_value_u64(const std::string &key, size_t def,
size_t id) const {
@@ -6229,8 +6304,10 @@ ssize_t detail::BodyReader::read(char *buf, size_t len) {
}
// ThreadPool implementation
ThreadPool::ThreadPool(size_t n, size_t max_n, size_t mqr)
: base_thread_count_(n), max_queued_requests_(mqr), idle_thread_count_(0),
ThreadPool::ThreadPool(size_t n, size_t max_n, size_t mqr,
time_t idle_timeout_sec)
: base_thread_count_(n), max_queued_requests_(mqr),
idle_timeout_sec_(idle_timeout_sec), idle_thread_count_(0),
shutdown_(false) {
#ifndef CPPHTTPLIB_NO_EXCEPTIONS
if (max_n != 0 && max_n < n) {
@@ -6340,9 +6417,9 @@ void ThreadPool::worker(bool is_dynamic) {
idle_thread_count_++;
if (is_dynamic) {
auto has_work = cond_.wait_for(
lock, std::chrono::seconds(CPPHTTPLIB_THREAD_POOL_IDLE_TIMEOUT),
[&] { return !jobs_.empty() || shutdown_; });
auto has_work =
cond_.wait_for(lock, std::chrono::seconds(idle_timeout_sec_),
[&] { return !jobs_.empty() || shutdown_; });
if (!has_work) {
// Timed out with no work - exit this dynamic thread
idle_thread_count_--;
@@ -9687,9 +9764,18 @@ bool ClientImpl::write_request(Stream &strm, Request &req,
if (!query_part.empty()) {
// Normalize the query string (decode then re-encode) while preserving
// the original parameter order.
auto normalized = detail::normalize_query_string(query_part);
if (!normalized.empty()) { path_with_query += '?' + normalized; }
// the original parameter order. When path encoding is disabled the
// caller has supplied an already-encoded target and expects the exact
// bytes to be sent on the wire, so skip normalization for the query
// too. Normalizing here would decode-then-re-encode the query and
// corrupt pre-encoded binary payloads (e.g. turning `%20` into `+`,
// which a strict RFC 3986 server decodes back as `+`, not a space).
if (path_encode_) {
auto normalized = detail::normalize_query_string(query_part);
if (!normalized.empty()) { path_with_query += '?' + normalized; }
} else {
path_with_query += '?' + query_part;
}
// Still populate req.params for handlers/users who read them.
detail::parse_query_text(query_part, req.params);
@@ -12518,7 +12604,7 @@ bool is_ipv4_address(const std::string &str) {
for (char c : str) {
if (c == '.') {
dots++;
} else if (!isdigit(static_cast<unsigned char>(c))) {
} else if (!detail::is_ascii_digit(c)) {
return false;
}
}
@@ -12535,7 +12621,7 @@ bool parse_ipv4(const std::string &str, unsigned char *out) {
}
int val = 0;
int digits = 0;
while (*p >= '0' && *p <= '9') {
while (detail::is_ascii_digit(*p)) {
val = val * 10 + (*p - '0');
if (val > 255) { return false; }
p++;
@@ -16487,9 +16573,15 @@ bool WebSocketClient::connect() {
return false;
}
#ifdef CPPHTTPLIB_SSL_ENABLED
auto is_ssl = is_ssl_;
#else
auto is_ssl = false;
#endif
std::string selected_subprotocol;
if (!detail::perform_websocket_handshake(*strm, host_, port_, path_, headers_,
selected_subprotocol)) {
if (!detail::perform_websocket_handshake(*strm, host_, port_, is_ssl, path_,
headers_, selected_subprotocol)) {
shutdown_and_close();
return false;
}
+56 -12
View File
@@ -8,8 +8,8 @@
#ifndef CPPHTTPLIB_HTTPLIB_H
#define CPPHTTPLIB_HTTPLIB_H
#define CPPHTTPLIB_VERSION "0.48.0"
#define CPPHTTPLIB_VERSION_NUM "0x003000"
#define CPPHTTPLIB_VERSION "0.49.0"
#define CPPHTTPLIB_VERSION_NUM "0x003100"
#ifdef _WIN32
#if defined(_WIN32_WINNT) && _WIN32_WINNT < 0x0A00
@@ -309,7 +309,6 @@ using socket_t = int;
#include <array>
#include <atomic>
#include <cassert>
#include <cctype>
#include <chrono>
#include <climits>
#include <condition_variable>
@@ -540,6 +539,21 @@ make_unique(std::size_t n) {
return std::unique_ptr<T>(new RT[n]);
}
// Locale-independent ASCII character classification. The <cctype>
// counterparts (std::isalnum, std::isdigit, ...) consult the global C locale,
// so e.g. std::isalnum(0xC5) can return true once an embedder calls
// setlocale(). HTTP grammars are defined over ASCII, so raw bytes must be
// classified without regard to the locale.
inline bool is_ascii_digit(char c) { return '0' <= c && c <= '9'; }
inline bool is_ascii_alpha(char c) {
return ('a' <= c && c <= 'z') || ('A' <= c && c <= 'Z');
}
inline bool is_ascii_alnum(char c) {
return is_ascii_digit(c) || is_ascii_alpha(c);
}
namespace case_ignore {
inline unsigned char to_lower(int c) {
@@ -661,7 +675,7 @@ inline from_chars_result<T> from_chars(const char *first, const char *last,
for (; p != last; ++p) {
char c = *p;
int digit = -1;
if ('0' <= c && c <= '9') {
if (is_ascii_digit(c)) {
digit = c - '0';
} else if ('a' <= c && c <= 'z') {
digit = c - 'a' + 10;
@@ -733,14 +747,14 @@ inline from_chars_result<double> from_chars(const char *first, const char *last,
return false;
};
for (; p != last && '0' <= *p && *p <= '9'; ++p) {
for (; p != last && is_ascii_digit(*p); ++p) {
seen_digit = true;
accumulate(*p);
}
if (p != last && *p == '.') {
++p;
for (; p != last && '0' <= *p && *p <= '9'; ++p) {
for (; p != last && is_ascii_digit(*p); ++p) {
seen_digit = true;
if (frac_digits < max_frac_digits && accumulate(*p)) { ++frac_digits; }
}
@@ -803,8 +817,8 @@ inline bool parse_url(const std::string &url, UrlComponents &uc) {
// IPv6 host must be [a-fA-F0-9:]+ only
if (uc.host.empty()) { return false; }
for (auto c : uc.host) {
if (!((c >= 'a' && c <= 'f') || (c >= 'A' && c <= 'F') ||
(c >= '0' && c <= '9') || c == ':')) {
if (!(is_ascii_digit(c) || (c >= 'a' && c <= 'f') ||
(c >= 'A' && c <= 'F') || c == ':')) {
return false;
}
}
@@ -1541,7 +1555,9 @@ public:
class ThreadPool final : public TaskQueue {
public:
explicit ThreadPool(size_t n, size_t max_n = 0, size_t mqr = 0);
explicit ThreadPool(
size_t n, size_t max_n = 0, size_t mqr = 0,
time_t idle_timeout_sec = CPPHTTPLIB_THREAD_POOL_IDLE_TIMEOUT);
ThreadPool(const ThreadPool &) = delete;
~ThreadPool() override = default;
@@ -1556,6 +1572,7 @@ private:
size_t base_thread_count_;
size_t max_thread_count_;
size_t max_queued_requests_;
time_t idle_timeout_sec_;
size_t idle_thread_count_;
bool shutdown_;
@@ -1680,6 +1697,35 @@ make_multipart_content_provider(const UploadFormDataItems &items,
} // namespace detail
bool is_valid_multipart_boundary(const std::string &boundary);
// Serializer for multipart/form-data request bodies. The boundary is owned
// by the writer so that per-part framing and the final terminator always
// agree. Field names and filenames are escaped following the WHATWG HTML
// standard ('"' -> %22, CR -> %0D, LF -> %0A); CR and LF are also escaped
// in content types.
class MultipartFormDataWriter {
public:
MultipartFormDataWriter();
// precondition: is_valid_multipart_boundary(boundary)
explicit MultipartFormDataWriter(std::string boundary);
const std::string &boundary() const;
std::string content_type() const;
// In-memory items -> whole body (known length)
std::string serialize(const UploadFormDataItems &items) const;
size_t content_length(const UploadFormDataItems &items) const;
// Per-part framing for streaming via a content provider
std::string item_begin(const UploadFormData &item) const;
static std::string item_end();
std::string finish() const;
private:
std::string boundary_;
};
class Server {
public:
using Handler = std::function<void(const Request &, Response &)>;
@@ -2897,9 +2943,7 @@ template <size_t N> inline constexpr size_t str_len(const char (&)[N]) {
}
inline bool is_numeric(const std::string &str) {
return !str.empty() &&
std::all_of(str.cbegin(), str.cend(),
[](unsigned char c) { return std::isdigit(c); });
return !str.empty() && std::all_of(str.cbegin(), str.cend(), is_ascii_digit);
}
inline size_t get_header_value_u64(const Headers &headers,