Compare commits

...

10 Commits

Author SHA1 Message Date
Aleksander Grygier 94875285e4 ui: Add MCP Servers Opt-In for first time visitors (#25239)
* feat: ui: Add predefined recommended MCP servers to settings

* feat: ui: Add MCP server recommendation dialog with custom server support

* feat: Auto-focus input fields on mount and dynamic addition

* feat: Add header validation to MCP server add and edit forms

* feat: Persist recommended MCP server opt-in selections

* test: Cover MCP configuration with tests

* chore: Format & cleanup

* feat: Centralize MCP server overrides to settings config and improve recommendation UI

* fix: Capture index before mutation to prevent focus drift

* refactor: Extract MCP_CARD_VISIBLE_TOOL_LIMIT to shared constants

* refactor: Support arbitrary authorization header schemes

* refactor: Consolidate MCP recommendations dismissal into existing storage key

* fix: Use case-insensitive comparison for MCP server ID prefix check

* refactor: Centralize MCP server visibility logic and extract recommendations hook

* refactor: Cleanup
2026-07-03 12:16:29 +02:00
Gaurav Garg 5a460dea9f Remove redundant CUDA copies after gated_delta_net. (#23940)
* Remove redundant CUDA copies after gated_delta_net.

Currently, GDN writes recurrent state snapshots into its output tail, then the graph immediately copies those snapshots into ssm_states_all. With MTP draft length 3, target decode uses K=4, so that becomes 4 extra ggml_cuda_cpy calls.

The change detects that gated_delta_net -> view -> cpy pattern and makes the CUDA GDN kernel write the state snapshot(s) directly into the recurrent cache, skipping the intermediate tail writes and copy kernels when safe.

* Address review comments
2026-07-03 14:36:29 +05:30
Alessandro de Oliveira Faria (A.K.A.CABELO) c8ae9a750c vendor : update cpp-httplib to 0.49.0 (#25218) 2026-07-03 10:26:54 +02:00
Adrien Gallouët fdb1db877c llama : add llama_model_ftype_name() (#25134)
* llama : add llama_model_ftype_name()

Expose the model file type (quantization) name, e.g. "Q8_0" or
"Q4_K - Medium", through a new public C API. The returned pointer is
valid for the lifetime of the model and nullptr when the model is
invalid or the file type is unknown.

Signed-off-by: Adrien Gallouët <angt@huggingface.co>

* Export enum

Signed-off-by: Adrien Gallouët <angt@huggingface.co>

* s/llama_model_ftype_name/llama_ftype_name/

Signed-off-by: Adrien Gallouët <angt@huggingface.co>

* Move "(guessed)" to the front in llama_ftype_name

Prepend the "(guessed)" label instead of appending it. This allows removing
the non-thread-safe static std::string, making the function allocation-free.

Signed-off-by: Adrien Gallouët <angt@huggingface.co>

* Add LLAMA_FTYPE_PREFIX

Signed-off-by: Adrien Gallouët <angt@huggingface.co>

* Dont check for model

Signed-off-by: Adrien Gallouët <angt@huggingface.co>

---------

Signed-off-by: Adrien Gallouët <angt@huggingface.co>
2026-07-02 17:26:47 +02:00
lhez 4fc4ec5541 opencl: allow loading precompiled binary kernels from library (#23042)
* opencl: allow loading binary kernel

* opencl: add libdl.h

* ggml-backend-dl is in ggml, which depends backend libs, thus
  ggml-opencl cannot depend on ggml-backend-dl
* add libdl.h to break cyclic dep

* opencl: allow loading bin kernel lib

* opencl: load `gemm_moe_mxfp4_f32_ns` from kernel lib if available

* opencl: load q8_0 gemm from kernel lib

* opencl: load q4_0 moe gemm from kernel lib

* opencl: load q4_1 moe gemm from kernel lib

* opencl: load q4_k moe gemm from kernel lib

* opencl: always declare `get_adreno_bin_kernel_func_t`

* opencl: rephrase message

* opencl: fix for rebase

* opencl: update doc
2026-07-01 10:29:22 -07:00
Adrien Gallouët a6647b1a32 common : use hf primary split as model path (#25194)
Fixes #25181
2026-07-01 18:33:00 +02:00
Max Krasnyansky 13e673863b hexagon: flash attention rework (optimizations, accuracy improvements, etc) (#25085)
* hex-mm: fold mm quant tasks into the main matmul threads

* hex-mm: minor formatting fixes

* hex-mm: cleanup is_quant checks in dma dispatch

* hex-mm: fix dst-spad alignment

* hex-mm: move fp kernels in the hvx-mm-kernels header

* hex-mm: fuse with ADD

* hex-fa: factor out ukernels into separate headers and unify the rest

* hex-fa: move kernel-params compute into the host

* hex-fa: refactor vtcm alloc for consistency

* hex-fa: add support for FA_SELECT

* hex-fa: update tracing insrumentation to cover all functions

* hex-fa: update hvx fallback thresholds to recover t/g regressions

* hex-fa: update tracing instrumentation

* hex-fa: improved tracing with additional events

* hex-fa: optimize mask processing (fastdiv, etc)

* hex-fa: improve mask dma caching

* hmx-fa: change loop order to maximize mask cache hits

* hex-fa: remove over instrumentation

* hex-fa: breakdown QKV prep trace events

* hmx-fa: further mask proc optimizations

* hex-fa: mask broadcast is the common case, optimize for that

* hex-fa: use aligned loads where possible

* hex-fa: update loops to use uint32_t indices

* hmx-fa: fold vtcm init into q prep task

* hex-fa: update rest of the hmx funcs to use uint32_t

* hmx-fa: fold build_d into the main softmax loop

* hmx-fa: start kv dmas earlier

* hmx-fa: start mask dma a bit earlier

* hex-fa: precompute rows per task to avoid divs

* hmx-fa: specialize fa_o_store for f16 and f32

* hmx-fa: prelim support for Sinks

* hmx-fa: keep softmax accumulators in fp32

* hex-fa: add tanh_f16 and exp2_f16 and use that in FA

* hex-fa: use fp16 math in the hvx kernel

* hex-fa: avoid expensive float -> __fp16 cast for slopes and softcap

* hex-fa: replace most vec_exp_f32 with vec_exp2_f16

* hmx-fa: vectorize sinks update

* hex-fa: minor formatting

* hmx-fa: fold softcap loop into the tile load

* hmx-fa: use vectoralias to populate sinks

* hex-fa: remove redudant check

* hex-fa: fix vtcm size compute to use fp32 for accumulators

* hex-mm: fix trailing spaces

* hmx-fa: dont use -inf to init mask to avoid conversion overflows

* hex-fa: no need to explicitly guard -inf in the f16->f32 converter now

* hmx-fa: cleanup fa sinks handling

* hex-mm: fixed src2 stride handling when mm is fused with add

* hex-fa: make lto happy
2026-07-01 06:59:19 -07:00
Johannes Gäßler b820cc8e6f CUDA: consistent use of __restrict__ + PDL for FA (#25185) 2026-07-01 10:55:14 +02:00
ragz4125 6dbc1174b8 ggml-cpu: add AVX2 optimization for nvfp4 dot product and use UE4M3 LUT (#23961) 2026-07-01 15:31:20 +08:00
Aleksander Grygier 9d88e7cedd ui Prevent tool messages from incorrectly appending to other conversations (#25177)
* fix: Prevent tool messages from incorrectly appending to other conversations

* ui: prevent agentic loop from poisoning another conv's currNode

* ui: make editedContent a  so background recompute does not wipe in-progress edits

---------

Co-authored-by: Pascal <admin@serveurperso.com>
2026-07-01 09:25:18 +02:00
81 changed files with 6238 additions and 3365 deletions
+10 -8
View File
@@ -496,13 +496,15 @@ void common_models_handler_apply(common_models_handler & handler, common_params
}
// handle hf_plan tasks
auto add_tasks = [&opts, &tasks](const hf_cache::hf_files & model_files, common_params_model & model) {
auto add_tasks = [&opts, &tasks](const hf_cache::hf_files & model_files,
const hf_cache::hf_file & primary,
common_params_model & model) {
for (size_t i = 0; i < model_files.size(); ++i) {
auto & model_file = model_files[i];
bool is_first = (i == 0);
tasks.emplace_back(model_file, opts, [&, is_first]() {
if (is_first) {
// only use first part as model path
bool is_primary = (model_file.path == primary.path);
tasks.emplace_back(model_file, opts, [&, is_primary]() {
if (is_primary) {
// the primary file is the first split (00001-of), use it as model path
model.path = hf_cache::finalize_file(model_file);
} else {
hf_cache::finalize_file(model_file);
@@ -511,7 +513,7 @@ void common_models_handler_apply(common_models_handler & handler, common_params
}
};
if (!plan.model_files.empty()) {
add_tasks(plan.model_files, params.model);
add_tasks(plan.model_files, plan.primary, params.model);
}
if (!plan.mmproj.local_path.empty()) {
tasks.emplace_back(plan.mmproj, opts, [&]() {
@@ -539,12 +541,12 @@ void common_models_handler_apply(common_models_handler & handler, common_params
// handle plan_spec (e.g. --spec-draft-hf)
if (!plan_spec.model_files.empty()) {
add_tasks(plan_spec.model_files, params.speculative.draft.mparams);
add_tasks(plan_spec.model_files, plan_spec.primary, params.speculative.draft.mparams);
}
// handle vocoder plan (e.g. --hf-repo-v)
if (!plan_voc.model_files.empty()) {
add_tasks(plan_voc.model_files, params.vocoder.model);
add_tasks(plan_voc.model_files, plan_voc.primary, params.vocoder.model);
}
// run all tasks in parallel
+51 -39
View File
@@ -1,16 +1,26 @@
# llama.cpp for OpenCL
- [Background](#background)
- [OS](#os)
- [Hardware](#hardware)
- [DataType Supports](#datatype-supports)
- [Model Preparation](#model-preparation)
- [CMake Options](#cmake-options)
- [Android](#android)
- [Windows 11 Arm64](#windows-11-arm64)
- [Linux](#Linux)
- [Known Issue](#known-issues)
- [TODO](#todo)
- [llama.cpp for OpenCL](#llamacpp-for-opencl)
- [Background](#background)
- [Llama.cpp + OpenCL](#llamacpp--opencl)
- [OS](#os)
- [Hardware](#hardware)
- [Adreno GPU](#adreno-gpu)
- [DataType Supports](#datatype-supports)
- [Model Preparation](#model-preparation)
- [Binary Kernel Library](#binary-kernel-library)
- [CMake Options](#cmake-options)
- [Android](#android)
- [I. Setup Environment](#i-setup-environment)
- [II. Build llama.cpp](#ii-build-llamacpp)
- [Windows 11 Arm64](#windows-11-arm64)
- [I. Setup Environment](#i-setup-environment-1)
- [II. Build llama.cpp](#ii-build-llamacpp-1)
- [Linux](#linux)
- [I. Setup Environment](#i-setup-environment-2)
- [II. Build llama.cpp](#ii-build-llamacpp-2)
- [Known Issues](#known-issues)
- [TODO](#todo)
## Background
@@ -34,11 +44,13 @@ The llama.cpp OpenCL backend is designed to enable llama.cpp on **Qualcomm Adren
**Verified devices**
| Adreno GPU | Status |
|:------------------------------------:|:-------:|
| Adreno 750 (Snapdragon 8 Gen 3) | Support |
| Adreno 830 (Snapdragon 8 Elite) | Support |
| Adreno X85 (Snapdragon X Elite) | Support |
| Adreno GPU | Status |
|:-------------------------------------:|:-------:|
| Adreno 750 (Snapdragon 8 Gen 3) | Support |
| Adreno 830 (Snapdragon 8 Elite) | Support |
| Adreno 840 (Snapdragon 8 Elite Gen 5) | Support |
| Adreno X1-85 (Snapdragon X Elite) | Support |
| Adreno X2-90 (Snapdragon X2 Elite) | Support |
> A6x GPUs with a recent driver and compiler are supported; they are usually found in IoT platforms.
However, A6x GPUs in phones are likely not supported due to the outdated driver and compiler.
@@ -47,42 +59,43 @@ However, A6x GPUs in phones are likely not supported due to the outdated driver
| DataType | Status |
|:----------------------:|:--------------------------:|
| Q1_0 | Support |
| Q4_0 | Support |
| Q6_K | Support, but not optimized |
| Q4_1 | Support |
| Q5_0 | Support |
| Q5_1 | Support |
| Q8_0 | Support |
| Q4_K | Support |
| Q5_K | Support |
| Q6_K | Support |
| MXFP4 | Support |
| IQ4_NL | Support |
## Model Preparation
You can refer to the general [llama-quantize tool](/tools/quantize/README.md) for steps to convert a model in Hugging Face safetensor format to GGUF with quantization.
Since common quantizations are supported now, it is recommanded to download GGUF models directly from Huggingface.
Currently we support `Q4_0` quantization and have optimized for it. To achieve best performance on Adreno GPU, add `--pure` to `llama-quantize` (i.e., make all weights in `Q4_0`). For example,
## Binary Kernel Library
```sh
./llama-quantize --pure ggml-model-qwen2.5-3b-f16.gguf ggml-model-qwen-3b-Q4_0.gguf Q4_0
```
A prebuilt binary kernel library has been introduced for Adreno GPUs.
It currently targets X2 GPUs (X2-90, X2-85 and X2-45) found in Snapdragon X2 SoC.
The library currently contains kernels for MUL_MAT_ID with Q4_0, Q4_1, Q4_K, MXFP4.
The library must be manually downloaded from https://softwarecenter.qualcomm.com/catalog/item/Adreno_Kernel_Library_GGML.
Since `Q6_K` is also supported, `Q4_0` quantization without `--pure` will also work. However, the performance will be worse compared to pure `Q4_0` quantization.
To allow using the kernel library, add `-DGGML_OPENCL_USE_ADRENO_BIN_KERNELS=ON` when configuring with CMake.
Then, extract `adreno-opencl-kernels.dll` from the zip file downloaded from the above URL and put it alongside the executables.
If kernels compatible with the current GPU are found in the library, they will be loaded and used.
### `MXFP4` MoE Models
OpenAI gpt-oss models are MoE models in `MXFP4`. The quantized model will be in `MXFP4_MOE`, a mixture of `MXFP4` and `Q8_0`.
For this quantization, there is no need to specify `--pure`.
For gpt-oss-20b model, you can directly [download](https://huggingface.co/ggml-org/gpt-oss-20b-GGUF) the quantized GGUF file in `MXFP4_MOE` from Hugging Face.
Although it is possible to quantize gpt-oss-20b model in pure `Q4_0` (all weights in `Q4_0`), it is not recommended since `MXFP4` has been optimized for MoE while `Q4_0` is not. In addition, accuracy should degrade with such pure `Q4_0` quantization.
Hence, using the default `MXFP4_MOE` quantization (see the link above) is recommended for this model.
> Note that the `Q4_0` model found [here](https://huggingface.co/unsloth/gpt-oss-20b-GGUF/blob/main/gpt-oss-20b-Q4_0.gguf) is a mixture of `Q4_0`, `Q8_0` and `MXFP4` and gives better performance than `MXFP4_MOE` quantization.
## CMake Options
The OpenCL backend has the following CMake options that control the behavior of the backend.
| CMake options | Default value | Description |
|:---------------------------------:|:--------------:|:------------------------------------------|
| `GGML_OPENCL_EMBED_KERNELS` | `ON` | Embed OpenCL kernels into the executable. |
| `GGML_OPENCL_USE_ADRENO_KERNELS` | `ON` | Use kernels optimized for Adreno. |
| CMake options | Default value | Description |
|:------------------------------------:|:--------------:|:------------------------------------------|
| `GGML_OPENCL_EMBED_KERNELS` | `ON` | Embed OpenCL kernels into the executable. |
| `GGML_OPENCL_USE_ADRENO_KERNELS` | `ON` | Use kernels optimized for Adreno. |
| `GGML_OPENCL_USE_ADRENO_BIN_KERNELS` | `OFF` | Allow using binary kernel lib for Adreno. |
## Android
@@ -277,6 +290,5 @@ ninja
## TODO
- Optimization for Q6_K
- Support and optimization for Q4_K
- Improve flash attention
- Improve OpenCL C kernels performance
+3 -2
View File
@@ -1111,11 +1111,12 @@ GGML_TABLE_BEGIN(int8_t, kvalues_iq4nl, 16)
-127, -104, -83, -65, -49, -35, -22, -10, 1, 13, 25, 38, 53, 69, 89, 113,
GGML_TABLE_END()
// e2m1 values (doubled)
// e2m1 values (doubled), shared by MXFP4 and NVFP4
// ref: https://www.opencompute.org/documents/ocp-microscaling-formats-mx-v1-0-spec-final-pdf
GGML_TABLE_BEGIN(int8_t, kvalues_mxfp4, 16)
GGML_TABLE_BEGIN(int8_t, kvalues_fp4, 16)
0, 1, 2, 3, 4, 6, 8, 12, 0, -1, -2, -3, -4, -6, -8, -12,
GGML_TABLE_END()
#define kvalues_mxfp4 kvalues_fp4
#define NGRID_IQ1S 2048
#define IQ1S_DELTA 0.125f
-1
View File
@@ -82,7 +82,6 @@
#define ggml_gemm_q2_K_8x8_q8_K_generic ggml_gemm_q2_K_8x8_q8_K
#elif defined(__x86_64__) || defined(__i386__) || defined(_M_IX86) || defined(_M_X64)
// quants.c
#define ggml_vec_dot_nvfp4_q8_0_generic ggml_vec_dot_nvfp4_q8_0
// repack.cpp
#define ggml_quantize_mat_q8_0_4x4_generic ggml_quantize_mat_q8_0_4x4
#define ggml_quantize_mat_q8_K_4x4_generic ggml_quantize_mat_q8_K_4x4
+142 -4
View File
@@ -934,7 +934,7 @@ void ggml_vec_dot_mxfp4_q8_0(int n, float * GGML_RESTRICT s, size_t bs, const vo
#if defined __AVX2__
const __m128i values128 = _mm_loadu_si128((const __m128i*)kvalues_mxfp4);
const __m128i values128 = _mm_loadu_si128((const __m128i*)kvalues_fp4);
const __m128i m4b = _mm_set1_epi8(0x0f);
const __m256i mone = _mm256_set1_epi16(1);
@@ -963,7 +963,7 @@ void ggml_vec_dot_mxfp4_q8_0(int n, float * GGML_RESTRICT s, size_t bs, const vo
sumf = hsum_float_8(_mm256_add_ps(accum1, accum2));
#elif defined __AVX__
const __m128i values128 = _mm_loadu_si128((const __m128i*)kvalues_mxfp4);
const __m128i values128 = _mm_loadu_si128((const __m128i*)kvalues_fp4);
const __m128i m4b = _mm_set1_epi8(0x0f);
__m256 accum = _mm256_setzero_ps();
@@ -993,14 +993,152 @@ void ggml_vec_dot_mxfp4_q8_0(int n, float * GGML_RESTRICT s, size_t bs, const vo
int sumi1 = 0;
int sumi2 = 0;
for (int j = 0; j < QK_MXFP4/2; ++j) {
sumi1 += y[ib].qs[j + 0] * kvalues_mxfp4[x[ib].qs[j] & 0xf];
sumi2 += y[ib].qs[j + QK_MXFP4/2] * kvalues_mxfp4[x[ib].qs[j] >> 4];
sumi1 += y[ib].qs[j + 0] * kvalues_fp4[x[ib].qs[j] & 0xf];
sumi2 += y[ib].qs[j + QK_MXFP4/2] * kvalues_fp4[x[ib].qs[j] >> 4];
}
sumf += d * (sumi1 + sumi2);
}
*s = sumf;
}
void ggml_vec_dot_nvfp4_q8_0(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, size_t bx, const void * GGML_RESTRICT vy, size_t by, int nrc) {
assert(nrc == 1);
UNUSED(nrc);
UNUSED(bx);
UNUSED(by);
UNUSED(bs);
assert(n % QK_NVFP4 == 0);
const block_nvfp4 * GGML_RESTRICT x = vx;
const block_q8_0 * GGML_RESTRICT y = vy;
const int nb = n / QK_NVFP4;
int ib = 0;
float sumf = 0;
#if defined(__AVX2__)
const __m128i values128 = _mm_loadu_si128((const __m128i*)kvalues_fp4);
const __m128i m4b = _mm_set1_epi8(0x0f);
const __m256i mone = _mm256_set1_epi16(1);
__m256 accum = _mm256_setzero_ps();
for(; ib < nb; ib++){
const __m128i q4bits_01 = _mm_loadu_si128((const __m128i *)(x[ib].qs + 0));
const __m128i q4bits_23 = _mm_loadu_si128((const __m128i *)(x[ib].qs + 16));
const __m256i q8_01 = _mm256_loadu_si256((const __m256i *)y[2*ib + 0].qs);
const __m256i q8_23 = _mm256_loadu_si256((const __m256i *)y[2*ib + 1].qs);
const __m128i q4_01_lo = _mm_shuffle_epi8(values128, _mm_and_si128(q4bits_01, m4b));
const __m128i q4_01_hi = _mm_shuffle_epi8(values128, _mm_and_si128(_mm_srli_epi16(q4bits_01, 4), m4b));
const __m128i q4_23_lo = _mm_shuffle_epi8(values128, _mm_and_si128(q4bits_23, m4b));
const __m128i q4_23_hi = _mm_shuffle_epi8(values128, _mm_and_si128(_mm_srli_epi16(q4bits_23, 4), m4b));
//reordering
const __m256i q4_01 = MM256_SET_M128I(_mm_unpackhi_epi64(q4_01_lo,q4_01_hi), _mm_unpacklo_epi64(q4_01_lo,q4_01_hi));
const __m256i q4_23 = MM256_SET_M128I(_mm_unpackhi_epi64(q4_23_lo,q4_23_hi),_mm_unpacklo_epi64(q4_23_lo,q4_23_hi));
const __m256i p01 = mul_add_epi8(q4_01,q8_01);
const __m256i p_1 = _mm256_madd_epi16(p01, mone);
const __m256i p23 = mul_add_epi8(q4_23,q8_23);
const __m256i p_2 = _mm256_madd_epi16(p23, mone);
const float dy0 = GGML_CPU_FP16_TO_FP32(y[2*ib].d);
const float dy1 = GGML_CPU_FP16_TO_FP32(y[2*ib+1].d);
const float s0 = GGML_CPU_UE4M3_TO_FP32(x[ib].d[0]) * dy0;
const float s1 = GGML_CPU_UE4M3_TO_FP32(x[ib].d[1]) * dy0;
const float s2 = GGML_CPU_UE4M3_TO_FP32(x[ib].d[2]) * dy1;
const float s3 = GGML_CPU_UE4M3_TO_FP32(x[ib].d[3]) * dy1;
const __m256 scales01 = _mm256_set_m128(_mm_set1_ps(s1), _mm_set1_ps(s0));
const __m256 scales23 = _mm256_set_m128(_mm_set1_ps(s3), _mm_set1_ps(s2));
accum = _mm256_fmadd_ps(scales01, _mm256_cvtepi32_ps(p_1), accum);
accum = _mm256_fmadd_ps(scales23, _mm256_cvtepi32_ps(p_2), accum);
}
sumf = hsum_float_8(accum);
#elif defined(__AVX__)
const __m128i values128 = _mm_loadu_si128((const __m128i*)kvalues_fp4);
const __m128i m4b = _mm_set1_epi8(0x0f);
__m256 accum = _mm256_setzero_ps();
for(; ib < nb; ib++){
const __m128i q4bits_01 = _mm_loadu_si128((const __m128i *)(x[ib].qs + 0));
const __m128i q4bits_23 = _mm_loadu_si128((const __m128i *)(x[ib].qs + 16));
const __m128i q8_0 = _mm_loadu_si128((const __m128i *)(y[2*ib + 0].qs + 0));
const __m128i q8_1 = _mm_loadu_si128((const __m128i *)(y[2*ib + 0].qs + 16));
const __m128i q8_2 = _mm_loadu_si128((const __m128i *)(y[2*ib + 1].qs + 0));
const __m128i q8_3 = _mm_loadu_si128((const __m128i *)(y[2*ib + 1].qs + 16));
const __m128i q4_01_lo = _mm_shuffle_epi8(values128, _mm_and_si128(q4bits_01, m4b));
const __m128i q4_01_hi = _mm_shuffle_epi8(values128, _mm_and_si128(_mm_srli_epi16(q4bits_01, 4), m4b));
const __m128i q4_23_lo = _mm_shuffle_epi8(values128, _mm_and_si128(q4bits_23, m4b));
const __m128i q4_23_hi = _mm_shuffle_epi8(values128, _mm_and_si128(_mm_srli_epi16(q4bits_23, 4), m4b));
const __m128i q4_0 = _mm_unpacklo_epi64(q4_01_lo, q4_01_hi);
const __m128i q4_1 = _mm_unpackhi_epi64(q4_01_lo, q4_01_hi);
const __m128i q4_2 = _mm_unpacklo_epi64(q4_23_lo, q4_23_hi);
const __m128i q4_3 = _mm_unpackhi_epi64(q4_23_lo, q4_23_hi);
const __m128i p0_i32 = mul_sum_i8_pairs(q4_0, q8_0);
const __m128i p1_i32 = mul_sum_i8_pairs(q4_1, q8_1);
const __m128i p2_i32 = mul_sum_i8_pairs(q4_2, q8_2);
const __m128i p3_i32 = mul_sum_i8_pairs(q4_3, q8_3);
const __m128 p0 = _mm_cvtepi32_ps(p0_i32);
const __m128 p1 = _mm_cvtepi32_ps(p1_i32);
const __m128 p2 = _mm_cvtepi32_ps(p2_i32);
const __m128 p3 = _mm_cvtepi32_ps(p3_i32);
const __m256 p01 = _mm256_set_m128(p1, p0);
const __m256 p23 = _mm256_set_m128(p3, p2);
const float dy0 = GGML_CPU_FP16_TO_FP32(y[2*ib].d);
const float dy1 = GGML_CPU_FP16_TO_FP32(y[2*ib+1].d);
const float s0 = GGML_CPU_UE4M3_TO_FP32(x[ib].d[0]) * dy0;
const float s1 = GGML_CPU_UE4M3_TO_FP32(x[ib].d[1]) * dy0;
const float s2 = GGML_CPU_UE4M3_TO_FP32(x[ib].d[2]) * dy1;
const float s3 = GGML_CPU_UE4M3_TO_FP32(x[ib].d[3]) * dy1;
const __m256 scales01 = _mm256_set_m128(_mm_set1_ps(s1), _mm_set1_ps(s0));
const __m256 scales23 = _mm256_set_m128(_mm_set1_ps(s3), _mm_set1_ps(s2));
accum = _mm256_add_ps(accum, _mm256_mul_ps(p01, scales01));
accum = _mm256_add_ps(accum, _mm256_mul_ps(p23, scales23));
}
sumf = hsum_float_8(accum);
#endif
for (;ib < nb; ++ib) {
for (int s_idx = 0; s_idx < 4; ++s_idx) {
const float d = GGML_CPU_UE4M3_TO_FP32(x[ib].d[s_idx]);
const int q8_block = s_idx / 2;
const int q8_off = (s_idx % 2) * QK_NVFP4_SUB;
const float dy = GGML_CPU_FP16_TO_FP32(y[2*ib + q8_block].d);
int sumi_lo = 0, sumi_hi = 0;
for (int j = 0; j < QK_NVFP4_SUB/2; ++j) {
const uint8_t qv = x[ib].qs[s_idx*(QK_NVFP4_SUB/2) + j];
sumi_lo += y[2*ib + q8_block].qs[q8_off + j + 0] * kvalues_fp4[qv & 0xf];
sumi_hi += y[2*ib + q8_block].qs[q8_off + j + QK_NVFP4_SUB/2] * kvalues_fp4[qv >> 4];
}
sumf += dy * d * (sumi_lo + sumi_hi);
}
}
*s = sumf;
}
void ggml_vec_dot_q5_0_q8_0(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, size_t bx, const void * GGML_RESTRICT vy, size_t by, int nrc) {
const int qk = QK8_0;
const int nb = n / qk;
+8
View File
@@ -82,6 +82,9 @@ float ggml_table_f32_f16[1 << 16];
// precomputed f32 table for e8m0 half (1 KB) (simd-mappings.h)
float ggml_table_f32_e8m0_half[1 << 8];
// precomputed f32 table for ue4m3 (1 KB) (simd-mappings.h)
float ggml_table_f32_ue4m3[1 << 8];
#if defined(__ARM_ARCH)
struct ggml_arm_arch_features_type {
int sve_cnt;
@@ -3798,6 +3801,11 @@ void ggml_cpu_init(void) {
ggml_table_f32_e8m0_half[i] = GGML_E8M0_TO_FP32_HALF(i);
}
// initialize UE4M3 table (256 entries)
for (int i = 0; i < (1 << 8); ++i) {
ggml_table_f32_ue4m3[i] = ggml_ue4m3_to_fp32(i);
}
const uint64_t t_end = ggml_time_us(); UNUSED(t_end);
GGML_PRINT_DEBUG("%s: GELU, Quick GELU, SILU and EXP tables initialized in %f ms\n", __func__, (t_end - t_start)/1000.0);
+11
View File
@@ -120,6 +120,10 @@ extern float ggml_table_f32_f16[1 << 16];
// defined in ggml-cpu.c, initialized in ggml_cpu_init()
extern float ggml_table_f32_e8m0_half[1 << 8];
// precomputed f32 table for ue4m3 (1 KB)
// defined in ggml-cpu.c, initialized in ggml_cpu_init()
extern float ggml_table_f32_ue4m3[1 << 8];
// Use lookup table for E8M0 on x86 (faster than bit manipulation)
#if defined(__AVX__) || defined(__AVX2__) || defined(__AVX512F__)
#define GGML_CPU_E8M0_TO_FP32_HALF(x) ggml_table_f32_e8m0_half[(uint8_t)(x)]
@@ -127,6 +131,13 @@ extern float ggml_table_f32_e8m0_half[1 << 8];
#define GGML_CPU_E8M0_TO_FP32_HALF(x) GGML_E8M0_TO_FP32_HALF(x)
#endif
// Use lookup table for UE4M3 on x86 (faster than bit manipulation)
#if defined(__AVX__) || defined(__AVX2__) || defined(__AVX512F__)
#define GGML_CPU_UE4M3_TO_FP32(x) ggml_table_f32_ue4m3[(uint8_t)(x)]
#else
#define GGML_CPU_UE4M3_TO_FP32(x) ggml_ue4m3_to_fp32(x)
#endif
// On ARM NEON, it's quicker to directly convert x -> x instead of calling into ggml_lookup_fp16_to_fp32,
// so we define GGML_CPU_FP16_TO_FP32 and GGML_CPU_FP32_TO_FP16 elsewhere for NEON.
// This is also true for POWER9.
+7 -3
View File
@@ -664,7 +664,10 @@ constexpr __device__ dequantize_V_t get_dequantize_V() {
template <int ncols1>
__launch_bounds__(FATTN_KQ_STRIDE/2, 1)
static __global__ void flash_attn_mask_to_KV_max(
const half2 * __restrict__ mask, int * __restrict__ KV_max, const int ne30, const int64_t s31, const int64_t s33) {
const half2 * mask_ptr, int * KV_max_ptr, const int ne30, const int64_t s31, const int64_t s33) {
const half2 * GGML_CUDA_RESTRICT mask = mask_ptr;
int * GGML_CUDA_RESTRICT KV_max = KV_max_ptr;
const int ne31 = gridDim.x;
const int tid = threadIdx.x;
const int sequence = blockIdx.y;
@@ -1099,8 +1102,9 @@ void launch_fattn(
const int iter_k = K->ne[1] / FATTN_KQ_STRIDE;
KV_max.alloc(ne_KV_max);
flash_attn_mask_to_KV_max<ncols1><<<blocks_num_KV_max, block_dim_KV_max, 0, main_stream>>>
((const half2 *) mask->data, KV_max.ptr, iter_k, s31, s33);
ggml_cuda_kernel_launch_params launch_params = ggml_cuda_kernel_launch_params(blocks_num_KV_max, block_dim_KV_max, 0, main_stream);
ggml_cuda_kernel_launch(flash_attn_mask_to_KV_max<ncols1>, launch_params,
(const half2 *) mask->data, KV_max.ptr, iter_k, s31, s33);
CUDA_CHECK(cudaGetLastError());
}
+40 -25
View File
@@ -10,6 +10,7 @@ gated_delta_net_cuda(const float * q,
const float * beta,
const float * curr_state,
float * dst,
float * state,
int64_t H,
int64_t n_tokens,
int64_t n_seqs,
@@ -25,6 +26,7 @@ gated_delta_net_cuda(const float * q,
const uint3 neqk1_magic,
const uint3 rq3_magic,
float scale,
int64_t state_slot_stride,
int K) {
const uint32_t h_idx = blockIdx.x;
const uint32_t sequence = blockIdx.y;
@@ -35,9 +37,7 @@ gated_delta_net_cuda(const float * q,
const uint32_t iq1 = fastmodulo(h_idx, neqk1_magic);
const uint32_t iq3 = fastdiv(sequence, rq3_magic);
const int64_t attn_score_elems = S_v * H * n_tokens * n_seqs;
float * attn_data = dst;
float * state = dst + attn_score_elems;
// input state holds s0 only: [S_v, S_v, H, n_seqs] — seq stride is D = H * S_v * S_v.
// output state layout (per-slot D * n_seqs) — same per-(seq,head) offset as before.
@@ -145,10 +145,9 @@ gated_delta_net_cuda(const float * q,
if constexpr (keep_rs_t) {
// snapshot slot mapping: slot 0 = most recent state, slot s = s tokens back.
// When n_tokens < K only slots 0..n_tokens-1 are written; older slots are caller-owned.
const int64_t state_size_per_token = S_v * S_v * H * n_seqs; // per-slot stride in output
const int target_slot = (int) n_tokens - 1 - t;
if (target_slot >= 0 && target_slot < K) {
float * curr_state = (dst + attn_score_elems) + target_slot * state_size_per_token + state_out_offset;
float * curr_state = state + target_slot * state_slot_stride;
#pragma unroll
for (int r = 0; r < rows_per_lane; r++) {
const int i = r * warp_size + lane;
@@ -171,13 +170,13 @@ template <bool KDA, bool keep_rs_t>
static void launch_gated_delta_net(
const float * q_d, const float * k_d, const float * v_d,
const float * g_d, const float * b_d, const float * s_d,
float * dst_d,
float * dst_d, float * state_d,
int64_t S_v, int64_t H, int64_t n_tokens, int64_t n_seqs,
int64_t sq1, int64_t sq2, int64_t sq3,
int64_t sv1, int64_t sv2, int64_t sv3,
int64_t sb1, int64_t sb2, int64_t sb3,
int64_t neqk1, int64_t rq3,
float scale, int K, cudaStream_t stream) {
float scale, int64_t state_slot_stride, int K, cudaStream_t stream) {
//TODO: Add chunked kernel for even faster pre-fill
const int warp_size = ggml_cuda_info().devices[ggml_cuda_get_device()].warp_size;
const int num_warps = 4;
@@ -187,34 +186,32 @@ static void launch_gated_delta_net(
const uint3 neqk1_magic = init_fastdiv_values(neqk1);
const uint3 rq3_magic = init_fastdiv_values(rq3);
int cc = ggml_cuda_info().devices[ggml_cuda_get_device()].cc;
const ggml_cuda_kernel_launch_params launch_params = ggml_cuda_kernel_launch_params(grid_dims, block_dims, 0, stream);
switch (S_v) {
case 16:
ggml_cuda_kernel_launch(gated_delta_net_cuda<16, KDA, keep_rs_t>, launch_params,
q_d, k_d, v_d, g_d, b_d, s_d, dst_d, H,
q_d, k_d, v_d, g_d, b_d, s_d, dst_d, state_d, H,
n_tokens, n_seqs, sq1, sq2, sq3, sv1, sv2, sv3,
sb1, sb2, sb3, neqk1_magic, rq3_magic, scale, K);
sb1, sb2, sb3, neqk1_magic, rq3_magic, scale, state_slot_stride, K);
break;
case 32:
ggml_cuda_kernel_launch(gated_delta_net_cuda<32, KDA, keep_rs_t>, launch_params,
q_d, k_d, v_d, g_d, b_d, s_d, dst_d, H,
q_d, k_d, v_d, g_d, b_d, s_d, dst_d, state_d, H,
n_tokens, n_seqs, sq1, sq2, sq3, sv1, sv2, sv3,
sb1, sb2, sb3, neqk1_magic, rq3_magic, scale, K);
sb1, sb2, sb3, neqk1_magic, rq3_magic, scale, state_slot_stride, K);
break;
case 64: {
ggml_cuda_kernel_launch(gated_delta_net_cuda<64, KDA, keep_rs_t>, launch_params,
q_d, k_d, v_d, g_d, b_d, s_d, dst_d, H,
q_d, k_d, v_d, g_d, b_d, s_d, dst_d, state_d, H,
n_tokens, n_seqs, sq1, sq2, sq3, sv1, sv2, sv3,
sb1, sb2, sb3, neqk1_magic, rq3_magic, scale, K);
sb1, sb2, sb3, neqk1_magic, rq3_magic, scale, state_slot_stride, K);
break;
}
case 128: {
ggml_cuda_kernel_launch(gated_delta_net_cuda<128, KDA, keep_rs_t>, launch_params,
q_d, k_d, v_d, g_d, b_d, s_d, dst_d, H,
q_d, k_d, v_d, g_d, b_d, s_d, dst_d, state_d, H,
n_tokens, n_seqs, sq1, sq2, sq3, sv1, sv2, sv3,
sb1, sb2, sb3, neqk1_magic, rq3_magic, scale, K);
sb1, sb2, sb3, neqk1_magic, rq3_magic, scale, state_slot_stride, K);
break;
}
default:
@@ -223,7 +220,8 @@ static void launch_gated_delta_net(
}
}
void ggml_cuda_op_gated_delta_net(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
static void ggml_cuda_op_gated_delta_net_impl(
ggml_backend_cuda_context & ctx, ggml_tensor * dst, const ggml_cuda_gated_delta_net_fused_cache * cache) {
ggml_tensor * src_q = dst->src[0];
ggml_tensor * src_k = dst->src[1];
ggml_tensor * src_v = dst->src[2];
@@ -288,25 +286,42 @@ void ggml_cuda_op_gated_delta_net(ggml_backend_cuda_context & ctx, ggml_tensor *
const int K = ggml_get_op_params_i32(dst, 0);
const bool keep_rs = K > 1;
// recurrent state -> gdn_out tail (after attention scores), or the cache when fusing
float * state_d = dst_d + S_v * H * n_tokens * n_seqs;
int64_t state_slot_stride = S_v * S_v * H * n_seqs;
if (cache != nullptr) {
state_d = cache->data;
state_slot_stride = cache->slot_stride;
}
if (kda) {
if (keep_rs) {
launch_gated_delta_net<true, true>(q_d, k_d, v_d, g_d, b_d, s_d, dst_d,
launch_gated_delta_net<true, true>(q_d, k_d, v_d, g_d, b_d, s_d, dst_d, state_d,
S_v, H, n_tokens, n_seqs, sq1, sq2, sq3, sv1, sv2, sv3,
sb1, sb2, sb3, neqk1, rq3, scale, K, stream);
sb1, sb2, sb3, neqk1, rq3, scale, state_slot_stride, K, stream);
} else {
launch_gated_delta_net<true, false>(q_d, k_d, v_d, g_d, b_d, s_d, dst_d,
launch_gated_delta_net<true, false>(q_d, k_d, v_d, g_d, b_d, s_d, dst_d, state_d,
S_v, H, n_tokens, n_seqs, sq1, sq2, sq3, sv1, sv2, sv3,
sb1, sb2, sb3, neqk1, rq3, scale, K, stream);
sb1, sb2, sb3, neqk1, rq3, scale, state_slot_stride, K, stream);
}
} else {
if (keep_rs) {
launch_gated_delta_net<false, true>(q_d, k_d, v_d, g_d, b_d, s_d, dst_d,
launch_gated_delta_net<false, true>(q_d, k_d, v_d, g_d, b_d, s_d, dst_d, state_d,
S_v, H, n_tokens, n_seqs, sq1, sq2, sq3, sv1, sv2, sv3,
sb1, sb2, sb3, neqk1, rq3, scale, K, stream);
sb1, sb2, sb3, neqk1, rq3, scale, state_slot_stride, K, stream);
} else {
launch_gated_delta_net<false, false>(q_d, k_d, v_d, g_d, b_d, s_d, dst_d,
launch_gated_delta_net<false, false>(q_d, k_d, v_d, g_d, b_d, s_d, dst_d, state_d,
S_v, H, n_tokens, n_seqs, sq1, sq2, sq3, sv1, sv2, sv3,
sb1, sb2, sb3, neqk1, rq3, scale, K, stream);
sb1, sb2, sb3, neqk1, rq3, scale, state_slot_stride, K, stream);
}
}
}
void ggml_cuda_op_gated_delta_net(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
ggml_cuda_op_gated_delta_net_impl(ctx, dst, nullptr);
}
void ggml_cuda_op_gated_delta_net_fused_cache(
ggml_backend_cuda_context & ctx, ggml_tensor * dst, ggml_cuda_gated_delta_net_fused_cache cache) {
ggml_cuda_op_gated_delta_net_impl(ctx, dst, &cache);
}
+10
View File
@@ -1,4 +1,14 @@
#include "common.cuh"
#include "ggml.h"
// fused-kernel recurrent-state output; strides in elements (per-seq stride is always D, set in-kernel)
struct ggml_cuda_gated_delta_net_fused_cache {
float * data; // rollback slot 0
int64_t slot_stride; // between rollback slots (0 when K==1)
};
void ggml_cuda_op_gated_delta_net(ggml_backend_cuda_context & ctx, ggml_tensor * dst);
// same op, but writes the snapshot(s) into the cache instead of dst (see ggml_cuda_try_gdn_cache_fusion)
void ggml_cuda_op_gated_delta_net_fused_cache(ggml_backend_cuda_context & ctx, ggml_tensor * dst,
ggml_cuda_gated_delta_net_fused_cache cache);
+85 -2
View File
@@ -3251,6 +3251,11 @@ static void ggml_backend_cuda_synchronize(ggml_backend_t backend) {
GGML_UNUSED(backend);
}
static bool ggml_cuda_is_view_or_noop(const ggml_tensor * t) {
return ggml_is_empty(t) || t->op == GGML_OP_RESHAPE || t->op == GGML_OP_TRANSPOSE ||
t->op == GGML_OP_VIEW || t->op == GGML_OP_PERMUTE || t->op == GGML_OP_NONE;
}
#ifdef USE_CUDA_GRAPH
static bool ggml_cuda_graph_check_compability(ggml_cgraph * cgraph) {
@@ -3260,7 +3265,7 @@ static bool ggml_cuda_graph_check_compability(ggml_cgraph * cgraph) {
for (int i = 0; i < cgraph->n_nodes; i++) {
ggml_tensor * node = cgraph->nodes[i];
if (ggml_is_empty(node) || node->op == GGML_OP_RESHAPE || node->op == GGML_OP_TRANSPOSE || node->op == GGML_OP_VIEW || node->op == GGML_OP_PERMUTE || node->op == GGML_OP_NONE) {
if (ggml_cuda_is_view_or_noop(node)) {
continue;
}
@@ -3403,6 +3408,70 @@ static bool ggml_cuda_should_fuse_rope_set_rows(const ggml_tensor * rope,
return true;
}
// match gated_delta_net + the strided cpy that scatters its state snapshots into the cache
// (slot i -> rollback group i, slot 0 newest), so the kernel can write them and skip the cpy.
static int ggml_cuda_try_gdn_cache_fusion(
const ggml_cgraph * cgraph, int node_idx, ggml_cuda_gated_delta_net_fused_cache & fused_state_cpy) {
const ggml_tensor * gdn = cgraph->nodes[node_idx];
// the kernel skips the snapshot tail, so the gdn output must not be a graph output
if (gdn->op != GGML_OP_GATED_DELTA_NET || gdn->type != GGML_TYPE_F32 ||
(gdn->flags & GGML_TENSOR_FLAG_OUTPUT)) {
return 0;
}
const ggml_tensor * src_v = gdn->src[2];
const int64_t S_v = src_v->ne[0];
const int64_t H = src_v->ne[1];
const int64_t n_tokens = src_v->ne[2];
const int64_t n_seqs = src_v->ne[3];
const int64_t D = S_v * S_v * H;
const int64_t K = ggml_get_op_params_i32(gdn, 0); // snapshot slot count
const int64_t n_written = std::min<int64_t>(n_tokens, K); // newest n_written slots are written
// snapshot tail starts right after the attention scores
const size_t tail_off = ggml_row_size(GGML_TYPE_F32, S_v * H * n_tokens * n_seqs);
// snapshot cpy is the first real node after the gdn (skip views/no-ops)
const ggml_tensor * cpy = nullptr;
int skip = 0;
for (int j = node_idx + 1; j < cgraph->n_nodes && cpy == nullptr; ++j) {
const ggml_tensor * n = cgraph->nodes[j];
if (ggml_cuda_is_view_or_noop(n)) {
continue;
}
if (n->op != GGML_OP_CPY || (n->flags & GGML_TENSOR_FLAG_OUTPUT)) {
return 0;
}
cpy = n;
skip = j - node_idx;
}
if (cpy == nullptr) {
return 0;
}
const ggml_tensor * src = cpy->src[0]; // view of the gdn snapshot tail
const ggml_tensor * dst = cpy->src[1]; // cache view the kernel writes to
// src must be this gdn's snapshot tail (contiguous, at the tail offset)
if (src->op != GGML_OP_VIEW || src->view_src != gdn || src->view_offs != tail_off ||
!ggml_is_contiguous(src)) {
return 0;
}
// dst is the [D, n_seqs, n_written] cache view; require nb[1] == D (the per-seq stride the kernel
// assumes). ggml_cpy pins src to the same element count.
const std::array<int64_t, GGML_MAX_DIMS> expected_ne = { D, n_seqs, n_written, 1 };
if (dst->op != GGML_OP_VIEW || dst->type != GGML_TYPE_F32 || dst->data == nullptr ||
!std::equal(expected_ne.begin(), expected_ne.end(), dst->ne) ||
dst->nb[0] != ggml_type_size(GGML_TYPE_F32) || dst->nb[1] != (size_t) ggml_row_size(GGML_TYPE_F32, D)) {
return 0;
}
fused_state_cpy.data = (float *) dst->data; // rollback group 0 (newest)
fused_state_cpy.slot_stride = K > 1 ? (int64_t) (dst->nb[2] / sizeof(float)) : 0;
return skip;
}
static bool ggml_cuda_topk_moe_fusion(const struct ggml_cgraph * cgraph, int node_idx, ggml_cuda_topk_moe_args & args) {
args.sigmoid = false;
args.softmax = false;
@@ -3844,6 +3913,20 @@ static int ggml_cuda_try_fuse(ggml_backend_cuda_context * cuda_ctx, ggml_cgraph
ggml_tensor * node = cgraph->nodes[i];
// gated_delta_net -> cpy: scatter recurrent-state snapshots into the cache
if (node->op == GGML_OP_GATED_DELTA_NET) {
ggml_cuda_gated_delta_net_fused_cache fused_state_cpy;
const int nodes_to_skip = ggml_cuda_try_gdn_cache_fusion(cgraph, i, fused_state_cpy);
if (nodes_to_skip > 0) {
#ifdef GGML_CUDA_DEBUG
GGML_LOG_INFO("%s: fused gated_delta_net snapshot copies for %s (skipped %d nodes)\n",
__func__, node->name, nodes_to_skip);
#endif
ggml_cuda_op_gated_delta_net_fused_cache(*cuda_ctx, node, fused_state_cpy);
return nodes_to_skip;
}
}
//topk-moe
if (cgraph->nodes[i]->op == GGML_OP_UNARY || cgraph->nodes[i]->op == GGML_OP_SOFT_MAX ||
cgraph->nodes[i]->op == GGML_OP_ARGSORT) {
@@ -4372,7 +4455,7 @@ static void ggml_cuda_graph_evaluate_and_capture(ggml_backend_cuda_context * cud
#endif
prev_i = i;
if (ggml_is_empty(node) || node->op == GGML_OP_RESHAPE || node->op == GGML_OP_TRANSPOSE || node->op == GGML_OP_VIEW || node->op == GGML_OP_PERMUTE || node->op == GGML_OP_NONE) {
if (ggml_cuda_is_view_or_noop(node)) {
continue;
}
-1
View File
@@ -23,7 +23,6 @@ include(${HEXAGON_SDK_ROOT}/build/cmake/hexagon_fun.cmake)
include(ExternalProject)
option(GGML_HEXAGON_HTP_DEBUG "ggml-hexagon: enable HTP debug output" OFF)
option(GGML_HEXAGON_FA_EXP2_HF "ggml-hexagon: use FP16 exp2 polynomial in FA softmax instead of F32 exp round-trip" OFF)
set(GGML_HEXAGON_HTP_CERT "$ENV{HEXAGON_HTP_CERT}" CACHE PATH "ggml-hexagon: enable HTP library signing using certificate")
add_library(htp_iface OBJECT
+221 -26
View File
@@ -43,6 +43,7 @@
#include "htp-opnode.h"
#include "htp-ops.h"
#include "htp/matmul-ops.h"
#include "htp/flash-attn-ops.h"
#include "htp_iface.h"
#include "htp-drv.h"
@@ -62,6 +63,7 @@ static int opt_profile = 0; // profiling mode (0-disabled, 1-basic, 2-pmu)
static int opt_hostbuf = 1; // hostbuf ON by default
static int opt_mm_select = 3; // 3 = HMX -> Tiled -> Flat -> CPU, 2 = Tiled -> Flat -> CPU, 1 = Flat -> CPU
static int opt_fa_select = 2; // 2 = HMX -> HVX -> CPU, 1 = HVX -> CPU, 0 = CPU (unsupported)
// Default PMU events, if profiling with PMU (mode=2) is enabled
// See https://docs.qualcomm.com/doc/80-N2040-60/topic/pmu-events.html
@@ -125,6 +127,11 @@ static const char * htp_event_name(uint16_t id) {
case HTP_TRACE_EVT_HVX_W_DEQUANT: return "HVX_W_DEQUANT";
case HTP_TRACE_EVT_HVX_W_PREP: return "HVX_W_PREP";
case HTP_TRACE_EVT_HVX_O_PROC: return "HVX_O_PROC";
case HTP_TRACE_EVT_HVX_FA_QK: return "HVX_QK_FA";
case HTP_TRACE_EVT_HVX_FA_SFM: return "HVX_SFM_FA";
case HTP_TRACE_EVT_HVX_FA_Q_PREP: return "HVX_Q_PREP";
case HTP_TRACE_EVT_HVX_FA_K_PREP: return "HVX_K_PREP";
case HTP_TRACE_EVT_HVX_FA_V_PREP: return "HVX_V_PREP";
case HTP_TRACE_EVT_HMX_COMP: return "HMX_COMP";
default: return "UNKNOWN";
}
@@ -1879,6 +1886,162 @@ ggml_hexagon_session::~ggml_hexagon_session() noexcept(true) {
// ** backend interface
static bool ggml_hexagon_flash_attn_is_hmx_eligible(
const struct ggml_hexagon_session * sess,
const struct ggml_tensor * q,
const struct ggml_tensor * k,
const struct ggml_tensor * v,
const struct ggml_tensor * sinks
) {
if (sess->n_hmx == 0) {
return false;
}
if (opt_fa_select < 2) {
return false;
}
if (k->type != GGML_TYPE_F16 || v->type != GGML_TYPE_F16) {
return false;
}
const uint32_t DK = q->ne[0];
const uint32_t DV = v->ne[0];
if (DK % 64 != 0 || DV % 64 != 0) {
return false;
}
// Fall back to HVX for small token counts if head dimension is small (DK <= 128)
const uint32_t neq1 = q->ne[1];
if (DK <= 128 && neq1 < 5) {
return false;
}
return true;
}
static bool ggml_hexagon_precompute_flash_attn_params(
const struct ggml_hexagon_session * sess,
const struct ggml_tensor * op,
struct htp_fa_kernel_params * kparams
) {
if (opt_fa_select < 1) {
return false;
}
memset(kparams, 0, sizeof(*kparams));
const struct ggml_tensor * q = op->src[0];
const struct ggml_tensor * k = op->src[1];
const struct ggml_tensor * v = op->src[2];
const struct ggml_tensor * mask = op->src[3];
const struct ggml_tensor * dst = op;
const uint32_t neq0 = q->ne[0]; // head_dim (DK)
const uint32_t neq1 = q->ne[1]; // n_tokens
const uint32_t neq2 = q->ne[2]; // n_heads
const uint32_t nek1 = k->ne[1]; // kv_len
const uint32_t nev0 = v->ne[0]; // head_dim (DV)
const uint32_t DK = neq0;
const uint32_t DV = nev0;
const uint32_t n_kv_heads = k->ne[2];
const uint32_t G = neq2 / n_kv_heads;
float scale = 1.0f;
float max_bias = 0.0f;
float logit_softcap = 0.0f;
memcpy(&scale, &op->op_params[0], sizeof(float));
memcpy(&max_bias, &op->op_params[1], sizeof(float));
memcpy(&logit_softcap, &op->op_params[2], sizeof(float));
if (logit_softcap != 0.0f) {
scale /= logit_softcap;
}
kparams->scale = scale;
kparams->max_bias = max_bias;
kparams->logit_softcap = logit_softcap;
kparams->is_q_fp32 = (q->type == GGML_TYPE_F32) ? 1 : 0;
kparams->is_dst_fp32 = (dst->type == GGML_TYPE_F32) ? 1 : 0;
kparams->G = G;
const uint32_t n_head = q->ne[2];
kparams->n_head_log2 = 1u << (uint32_t) std::floor(std::log2(n_head));
kparams->m0 = std::pow(2.0f, -(max_bias) / kparams->n_head_log2);
kparams->m1 = std::pow(2.0f, -(max_bias / 2.0f) / kparams->n_head_log2);
// Check HMX eligibility
const struct ggml_tensor * sinks = op->src[4];
if (ggml_hexagon_flash_attn_is_hmx_eligible(sess, q, k, v, sinks)) {
size_t Br = 0, Bc = 0;
int ret = hmx_fa_find_chunk_size(&Br, &Bc, G, DK, DV, neq1, nek1, sess->vtcm_size, sess->n_threads);
if (ret == 0) {
kparams->kernel_type = HTP_FA_KERNEL_HMX;
kparams->Br = Br;
kparams->Bc = Bc;
kparams->n_kv_blocks = (nek1 + Bc - 1) / Bc;
kparams->n_threads = (kparams->n_kv_blocks >= 3 && sess->n_threads >= 2) ? sess->n_threads : 1;
kparams->u.hmx.g_br = hex_align_up(G * Br, 32);
kparams->u.hmx.pipeline = (kparams->n_kv_blocks >= 3 && sess->n_threads >= 2) ? 1 : 0;
kparams->vtcm_size = hmx_fa_compute_vtcm_usage(G, DK, DV, Br, Bc, kparams->n_threads, kparams->u.hmx.pipeline != 0);
const size_t row_vec_bytes = hex_align_up(Bc * sizeof(uint16_t), 256);
kparams->u.hmx.row_buf_stride = row_vec_bytes / 128; // HVX vector is 128 bytes
const size_t m_line_bytes = hex_align_up(Bc * sizeof(uint16_t), 128);
kparams->u.hmx.mask_buf_row_stride = m_line_bytes / sizeof(uint16_t);
kparams->u.hmx.mask_broadcast = (mask != nullptr && mask->ne[2] == 1) ? 1 : 0;
kparams->u.hmx.div_G = init_fastdiv_values(G);
if (mask) {
kparams->src3_div2 = init_fastdiv_values(mask->ne[2]);
kparams->src3_div3 = init_fastdiv_values(mask->ne[3]);
}
kparams->qrows = 0;
kparams->qrows_per_thread = 0;
return true;
}
}
// Fallback to HVX
kparams->kernel_type = HTP_FA_KERNEL_HVX;
kparams->Br = 1;
kparams->Bc = 64; // FLASH_ATTN_BLOCK_SIZE
kparams->n_kv_blocks = (k->ne[1] + 64 - 1) / 64;
kparams->n_threads = sess->n_threads;
const size_t size_q_row_padded = hex_round_up(q->ne[0] * (kparams->is_q_fp32 ? 4 : 2), 128);
const size_t size_k_row_padded = hex_round_up(k->ne[0] * 2, 128);
const size_t size_v_row_padded = hex_round_up(v->ne[0] * 2, 128);
kparams->vtcm_size = hvx_fa_compute_vtcm_usage(DK, DV, kparams->is_q_fp32 != 0, mask != nullptr, sess->n_threads);
kparams->u.hvx.size_q_row_padded = size_q_row_padded;
kparams->u.hvx.size_k_row_padded = size_k_row_padded;
kparams->u.hvx.size_v_row_padded = size_v_row_padded;
kparams->u.hvx.src0_div21 = init_fastdiv_values(q->ne[2] * q->ne[1]);
kparams->u.hvx.src0_div1 = init_fastdiv_values(q->ne[1]);
kparams->u.hvx.broadcast_rk2 = init_fastdiv_values(q->ne[2]/k->ne[2]);
kparams->u.hvx.broadcast_rk3 = init_fastdiv_values(q->ne[3]/k->ne[3]);
kparams->u.hvx.broadcast_rv2 = init_fastdiv_values(q->ne[2]/v->ne[2]);
kparams->u.hvx.broadcast_rv3 = init_fastdiv_values(q->ne[3]/v->ne[3]);
if (mask) {
kparams->src3_div2 = init_fastdiv_values(mask->ne[2]);
kparams->src3_div3 = init_fastdiv_values(mask->ne[3]);
}
kparams->qrows = q->ne[1] * q->ne[2] * q->ne[3];
kparams->qrows_per_thread = (kparams->qrows + sess->n_threads - 1) / sess->n_threads;
return true;
}
static bool ggml_hexagon_supported_flash_attn_ext(const struct ggml_hexagon_session * sess, const struct ggml_tensor * op) {
const struct ggml_tensor * src0 = op->src[0];
@@ -1912,6 +2075,17 @@ static bool ggml_hexagon_supported_flash_attn_ext(const struct ggml_hexagon_sess
return false;
}
struct htp_fa_kernel_params kparams;
if (!ggml_hexagon_precompute_flash_attn_params(sess, op, &kparams)) {
return false;
}
if ((size_t) kparams.vtcm_size > sess->vtcm_size) {
HEX_VERBOSE("ggml-hex: skip flash_attn_ext because VTCM needed (%d) > budget (%zu)\n",
kparams.vtcm_size, sess->vtcm_size);
return false;
}
return true;
}
@@ -2211,14 +2385,14 @@ static void ggml_hexagon_precompute_hvx_mm_params(
kparams->kernel_type = (src1_nrows < (int) sess->n_threads) ? HTP_MM_KERNEL_HVX_QUANT_BLOCK : HTP_MM_KERNEL_HVX_QUANT_ROW;
kparams->src1_row_size = (wtype == GGML_TYPE_Q4_1) ? htp_mm_q8_1_tiled_row_size(ne10) : htp_mm_q8_0_tiled_row_size(ne10);
size_t vtcm_src0_size = 0, vtcm_src1_size = 0;
size_t vtcm_src0_size = 0, vtcm_src1_size = 0, vtcm_dst_size = 0;
uint32_t max_prefetch = (src1_nrows > HTP_MM_HMX_MIN_NROWS) ? 2 : 16;
uint32_t best_n_prefetch = 2;
size_t total_size = 0;
for (uint32_t d = max_prefetch; d >= 2; d /= 2) {
total_size = htp_mm_hvx_id_get_vtcm_sizes(
wtype, ne10, src1_nrows, sess->n_threads, src0->nb[1], d,
&vtcm_src0_size, &vtcm_src1_size
&vtcm_src0_size, &vtcm_src1_size, &vtcm_dst_size
);
if (total_size <= vtcm_budget) {
best_n_prefetch = d;
@@ -2228,14 +2402,14 @@ static void ggml_hexagon_precompute_hvx_mm_params(
if (best_n_prefetch == 2 && total_size > vtcm_budget) {
total_size = htp_mm_hvx_id_get_vtcm_sizes(
wtype, ne10, src1_nrows, sess->n_threads, src0->nb[1], 2,
&vtcm_src0_size, &vtcm_src1_size
&vtcm_src0_size, &vtcm_src1_size, &vtcm_dst_size
);
}
kparams->n_prefetch = best_n_prefetch;
kparams->vtcm_size = total_size;
kparams->vtcm_src0_size = vtcm_src0_size;
kparams->vtcm_src1_size = vtcm_src1_size;
kparams->vtcm_dst_size = 0;
kparams->vtcm_dst_size = vtcm_dst_size;
} else {
bool try_tiled = (k_align && opt_mm_select >= 2);
if (try_tiled) {
@@ -2441,11 +2615,12 @@ static void ggml_hexagon_precompute_fused_qkv_params(
size_t src3_sz_per_thread = 0;
uint32_t best_n_prefetch = 16;
size_t quant_scratch_size = hex_round_up(ne10 * sizeof(float), QK_Q8_0_TILED * sizeof(float)) * sess->n_threads;
if (is_repack) {
uint32_t aligned_tile_size = htp_mm_get_weight_aligned_tile_size(wtype);
uint32_t n_k_tiles = hex_round_up(ne10, 32) / 32;
uint32_t tile_row_size = n_k_tiles * aligned_tile_size;
size_t src1_row_size_padded = hex_round_up(src1_row_size, QK_Q8_0_TILED * sizeof(float));
size_t src1_sz_per_thread = hex_round_up(src1_row_size * src1_nrows, 128);
size_t src1_sz = src1_sz_per_thread;
@@ -2453,13 +2628,10 @@ static void ggml_hexagon_precompute_fused_qkv_params(
best_n_prefetch = 2;
for (uint32_t d = max_prefetch; d >= 2; d /= 2) {
size_t repacked_vtcm_size = hex_round_up(d * tile_row_size, 128);
if (repacked_vtcm_size < src1_row_size_padded) {
repacked_vtcm_size = src1_row_size_padded;
}
size_t src0_sz = repacked_vtcm_size * sess->n_threads;
size_t src2_sz = hex_round_up(d * tile_row_size, 128) * sess->n_threads;
size_t src3_sz = hex_round_up(d * tile_row_size, 128) * sess->n_threads;
size_t tiled_vtcm_size = src0_sz + src1_sz + src2_sz + src3_sz;
size_t tiled_vtcm_size = src0_sz + src1_sz + src2_sz + src3_sz + quant_scratch_size;
if (tiled_vtcm_size <= sess->vtcm_size) {
best_n_prefetch = d;
@@ -2471,9 +2643,6 @@ static void ggml_hexagon_precompute_fused_qkv_params(
}
if (best_n_prefetch == 2 && src0_sz_per_thread == 0) {
size_t repacked_vtcm_size = hex_round_up(2 * tile_row_size, 128);
if (repacked_vtcm_size < src1_row_size_padded) {
repacked_vtcm_size = src1_row_size_padded;
}
src0_sz_per_thread = repacked_vtcm_size;
src2_sz_per_thread = hex_round_up(2 * tile_row_size, 128);
src3_sz_per_thread = hex_round_up(2 * tile_row_size, 128);
@@ -2492,7 +2661,7 @@ static void ggml_hexagon_precompute_fused_qkv_params(
size_t src2_sz = src2_sz_per_thread * sess->n_threads;
size_t src3_sz = src3_sz_per_thread * sess->n_threads;
size_t tiled_vtcm_size = src0_sz + src1_sz + src2_sz + src3_sz;
size_t tiled_vtcm_size = src0_sz + src1_sz + src2_sz + src3_sz + quant_scratch_size;
bool try_tiled = (opt_mm_select >= 2);
if (try_tiled && tiled_vtcm_size <= sess->vtcm_size) {
kparams->kernel_type = HTP_MM_KERNEL_HVX_QUANT_ROW;
@@ -2500,6 +2669,7 @@ static void ggml_hexagon_precompute_fused_qkv_params(
kparams->vtcm_src1_size = src1_sz;
kparams->vtcm_src2_size = src2_sz;
kparams->vtcm_src3_size = src3_sz;
kparams->vtcm_dst_size = quant_scratch_size;
kparams->vtcm_size = tiled_vtcm_size;
kparams->n_prefetch = best_n_prefetch;
} else {
@@ -2510,7 +2680,8 @@ static void ggml_hexagon_precompute_fused_qkv_params(
kparams->vtcm_src1_size = flat_src1_sz;
kparams->vtcm_src2_size = src2_sz;
kparams->vtcm_src3_size = src3_sz;
kparams->vtcm_size = src0_sz + flat_src1_sz + src2_sz + src3_sz;
kparams->vtcm_dst_size = quant_scratch_size;
kparams->vtcm_size = src0_sz + flat_src1_sz + src2_sz + src3_sz + quant_scratch_size;
kparams->n_prefetch = best_n_prefetch;
}
}
@@ -2536,11 +2707,12 @@ static void ggml_hexagon_precompute_fused_ffn_params(
size_t src2_sz_per_thread = 0;
uint32_t best_n_prefetch = 16;
size_t quant_scratch_size = hex_round_up(ne10 * sizeof(float), QK_Q8_0_TILED * sizeof(float)) * sess->n_threads;
if (is_repack) {
uint32_t aligned_tile_size = htp_mm_get_weight_aligned_tile_size(wtype);
uint32_t n_k_tiles = hex_round_up(ne10, 32) / 32;
uint32_t tile_row_size = n_k_tiles * aligned_tile_size;
size_t src1_row_size_padded = hex_round_up(src1_row_size, QK_Q8_0_TILED * sizeof(float));
size_t src1_sz_per_thread = hex_round_up(src1_row_size * src1_nrows, 128);
size_t src1_sz = src1_sz_per_thread;
@@ -2548,12 +2720,9 @@ static void ggml_hexagon_precompute_fused_ffn_params(
best_n_prefetch = 2;
for (uint32_t d = max_prefetch; d >= 2; d /= 2) {
size_t repacked_vtcm_size = hex_round_up(d * tile_row_size, 128);
if (repacked_vtcm_size < src1_row_size_padded) {
repacked_vtcm_size = src1_row_size_padded;
}
size_t src0_sz = repacked_vtcm_size * sess->n_threads;
size_t src2_sz = hex_round_up(d * tile_row_size, 128) * sess->n_threads;
size_t tiled_vtcm_size = src0_sz + src1_sz + src2_sz;
size_t tiled_vtcm_size = src0_sz + src1_sz + src2_sz + quant_scratch_size;
if (tiled_vtcm_size <= sess->vtcm_size) {
best_n_prefetch = d;
@@ -2564,9 +2733,6 @@ static void ggml_hexagon_precompute_fused_ffn_params(
}
if (best_n_prefetch == 2 && src0_sz_per_thread == 0) {
size_t repacked_vtcm_size = hex_round_up(2 * tile_row_size, 128);
if (repacked_vtcm_size < src1_row_size_padded) {
repacked_vtcm_size = src1_row_size_padded;
}
src0_sz_per_thread = repacked_vtcm_size;
src2_sz_per_thread = hex_round_up(2 * tile_row_size, 128);
}
@@ -2582,13 +2748,14 @@ static void ggml_hexagon_precompute_fused_ffn_params(
size_t src1_sz = src1_sz_per_thread;
size_t src2_sz = src2_sz_per_thread * sess->n_threads;
size_t tiled_vtcm_size = src0_sz + src1_sz + src2_sz;
size_t tiled_vtcm_size = src0_sz + src1_sz + src2_sz + quant_scratch_size;
bool try_tiled = (opt_mm_select >= 2);
if (try_tiled && tiled_vtcm_size <= sess->vtcm_size) {
kparams->kernel_type = HTP_MM_KERNEL_HVX_QUANT_ROW;
kparams->vtcm_src0_size = src0_sz;
kparams->vtcm_src1_size = src1_sz;
kparams->vtcm_src2_size = src2_sz;
kparams->vtcm_dst_size = quant_scratch_size;
kparams->vtcm_size = tiled_vtcm_size;
kparams->n_prefetch = best_n_prefetch;
} else {
@@ -2598,7 +2765,8 @@ static void ggml_hexagon_precompute_fused_ffn_params(
kparams->vtcm_src0_size = src0_sz;
kparams->vtcm_src1_size = flat_src1_sz;
kparams->vtcm_src2_size = src2_sz;
kparams->vtcm_size = src0_sz + flat_src1_sz + src2_sz;
kparams->vtcm_dst_size = quant_scratch_size;
kparams->vtcm_size = src0_sz + flat_src1_sz + src2_sz + quant_scratch_size;
kparams->n_prefetch = best_n_prefetch;
}
}
@@ -3243,7 +3411,7 @@ static inline bool op_is_compute(ggml_tensor *node)
return !ggml_op_is_empty(node->op) && !ggml_is_empty(node) && (node->flags & GGML_TENSOR_FLAG_COMPUTE);
}
static bool is_hmx_eligible(const ggml_tensor * t) {
static bool mm_is_hmx_eligible(const ggml_tensor * t) {
if (opt_nhmx == 0) { return false; }
const ggml_tensor * src0 = t->src[0];
@@ -3262,7 +3430,7 @@ static bool is_hmx_eligible(const ggml_tensor * t) {
static bool is_mergeable_mul_mat(const ggml_tensor * t) {
if (!t || t->op != GGML_OP_MUL_MAT) return false;
if (t->src[1]->type != GGML_TYPE_F32) return false;
return ggml_is_quantized(t->src[0]->type) && !is_hmx_eligible(t);
return ggml_is_quantized(t->src[0]->type) && !mm_is_hmx_eligible(t);
}
static bool is_mergeable_mul_mat_pair(const ggml_tensor * n1, const ggml_tensor * n2) {
@@ -3357,6 +3525,26 @@ static bool try_fuse_node(const ggml_hexagon_session * sess, const ggml_cgraph *
}
}
if (n->op == GGML_OP_MUL_MAT && next_node) {
if (next_node->op == GGML_OP_ADD && op_is_compute(next_node) && ggml_can_fuse(graph, i, { GGML_OP_MUL_MAT, GGML_OP_ADD })) {
if (next_node->src[0] == n || next_node->src[1] == n) {
struct htp_mm_kernel_params kparams;
ggml_hexagon_precompute_matmul_params(sess, n->src[0], n->src[1], next_node, &kparams);
if ((size_t)kparams.vtcm_size <= sess->vtcm_size) {
htp_opnode node(n, {}, HTP_OP_MUL_MAT_ADD);
node.add_fused(next_node);
memcpy(node.kernel_params, &kparams, sizeof(kparams));
nodes.push_back(std::move(node));
i += 1;
return true;
} else {
HEX_VERBOSE("ggml-hex: skip MUL_MAT_ADD fusion because VTCM needed (%d) > budget (%zu)\n",
kparams.vtcm_size, sess->vtcm_size);
}
}
}
}
return false;
}
@@ -3393,6 +3581,11 @@ static ggml_status ggml_backend_hexagon_graph_compute(ggml_backend_t backend, gg
node.node->src[0], node.node->src[1], node.node,
(struct htp_mm_kernel_params *)node.kernel_params
);
} else if (node.opcode == HTP_OP_FLASH_ATTN_EXT) {
ggml_hexagon_precompute_flash_attn_params(sess,
node.node,
(struct htp_fa_kernel_params *)node.kernel_params
);
}
computed_nodes.push_back(std::move(node));
}
@@ -4079,6 +4272,7 @@ static void ggml_hexagon_init(ggml_backend_reg * reg) {
const char * str_use_hmx = getenv("GGML_HEXAGON_USE_HMX");
const char * str_nhmx = getenv("GGML_HEXAGON_NHMX");
const char * str_mm_select = getenv("GGML_HEXAGON_MM_SELECT");
const char * str_fa_select = getenv("GGML_HEXAGON_FA_SELECT");
const char * str_ndev = getenv("GGML_HEXAGON_NDEV");
const char * str_arch = getenv("GGML_HEXAGON_ARCH");
const char * str_vmem = getenv("GGML_HEXAGON_VMEM");
@@ -4120,6 +4314,7 @@ static void ggml_hexagon_init(ggml_backend_reg * reg) {
opt_nhvx = str_nhvx ? strtoul(str_nhvx, NULL, 0) : opt_nhvx;
opt_nhmx = str_nhmx ? atoi(str_nhmx) : (str_use_hmx ? atoi(str_use_hmx) : opt_nhmx);
opt_mm_select = str_mm_select ? atoi(str_mm_select) : opt_mm_select;
opt_fa_select = str_fa_select ? atoi(str_fa_select) : opt_fa_select;
opt_ndev = str_ndev ? strtoul(str_ndev, NULL, 0) : opt_ndev;
opt_hostbuf = str_hostbuf ? atoi(str_hostbuf) : opt_hostbuf;
opt_mbuf = str_mbuf ? strtoul(str_mbuf, NULL, 0) * MiB : opt_mbuf;
+13 -1
View File
@@ -11,6 +11,7 @@
#include <stdio.h>
#include "htp-ops.h"
#include "htp/matmul-ops.h"
#include "htp/flash-attn-ops.h"
struct htp_opnode {
ggml_tensor * node = nullptr;
@@ -335,7 +336,8 @@ struct htp_opformat {
}
void format_kernel_params(char * str, size_t max_size, const htp_opnode & node) {
if (node.opcode == HTP_OP_MUL_MAT || node.opcode == HTP_OP_MUL_MAT_ID ||
node.opcode == HTP_OP_MUL_MAT_QKV || node.opcode == HTP_OP_MUL_MAT_FFN) {
node.opcode == HTP_OP_MUL_MAT_QKV || node.opcode == HTP_OP_MUL_MAT_FFN ||
node.opcode == HTP_OP_MUL_MAT_ADD) {
const auto * kparams = (const struct htp_mm_kernel_params *) node.kernel_params;
const char * path = "unknown";
int32_t type = kparams->kernel_type;
@@ -350,6 +352,16 @@ struct htp_opformat {
path = "hvx-flat";
}
snprintf(str, max_size, "%s vtcm %d", path, (int) kparams->vtcm_size);
} else if (node.opcode == HTP_OP_FLASH_ATTN_EXT) {
const auto * kparams = (const struct htp_fa_kernel_params *) node.kernel_params;
const char * path = "unknown";
int32_t type = kparams->kernel_type;
if (type == HTP_FA_KERNEL_HMX) {
path = kparams->u.hmx.pipeline ? "hmx-pipe" : "hmx-seq";
} else if (type == HTP_FA_KERNEL_HVX) {
path = "hvx";
}
snprintf(str, max_size, "%s vtcm %d", path, (int) kparams->vtcm_size);
} else {
snprintf(str, max_size, "----");
}
+2 -7
View File
@@ -20,9 +20,6 @@ add_library(${HTP_LIB} SHARED
worker-pool.c
hex-dma.c
hmx-queue.c
flash-attn-ops.c
hmx-flash-attn-ops.c
matmul-ops.c
binary-ops.c
unary-ops.c
sum-rows-ops.c
@@ -42,16 +39,14 @@ add_library(${HTP_LIB} SHARED
solve-tri-ops.c
gated-delta-net-ops.c
pad-ops.c
matmul-ops.c
flash-attn-ops.c
)
target_compile_definitions(${HTP_LIB} PRIVATE
$<IF:$<BOOL:${HEXAGON_HTP_DEBUG}>,HTP_DEBUG=1,NDEBUG=1>
$<IF:$<BOOL:${HEXAGON_HTP_DEBUG}>,FARF_HIGH=1,>)
if (GGML_HEXAGON_FA_EXP2_HF)
message(STATUS "ggml-htp: HMX_FA_USE_EXP2_HF=1 (use FP16 exp2 polynomial in FA softmax)")
target_compile_definitions(${HTP_LIB} PRIVATE HMX_FA_USE_EXP2_HF=1)
endif()
build_idl(htp_iface.idl ${HTP_LIB})
File diff suppressed because it is too large Load Diff
+253
View File
@@ -0,0 +1,253 @@
#ifndef HTP_FLASH_ATTN_OPS_H
#define HTP_FLASH_ATTN_OPS_H
#include <stdint.h>
#include <stddef.h>
#include <stdbool.h>
#include "hex-fastdiv.h"
#include "hex-common.h"
#ifdef __cplusplus
extern "C" {
#endif
// Tile constants (mirrored from hmx-utils.h for use on host side if needed)
#define HMX_FP16_TILE_N_ROWS 32
#define HMX_FP16_TILE_N_COLS 32
#define HMX_FP16_TILE_N_ELMS 1024
#define HMX_FP16_TILE_SIZE 2048
#define HVX_FA_DMA_CACHE_SIZE 128
#define HMX_FA_DMA_CACHE_SIZE 4
#define HTP_FA_M_INITIAL_VAL -10000.0f
enum htp_fa_kernel_type {
HTP_FA_KERNEL_UNSUPPORTED = 0,
HTP_FA_KERNEL_HVX,
HTP_FA_KERNEL_HMX
};
struct htp_fa_kernel_params {
uint8_t kernel_type; // enum htp_fa_kernel_type
uint8_t is_q_fp32; // 1 = Q type is F32, 0 = F16
uint8_t is_dst_fp32; // 1 = dst type is F32, 0 = F16
uint8_t n_threads; // Number of threads to run
// Common parameters
uint16_t Br;
uint16_t Bc;
uint16_t n_kv_blocks; // also HVX's n_blocks
uint16_t G; // GQA factor (n_heads / n_kv_heads)
float scale;
float max_bias;
float logit_softcap;
uint32_t vtcm_size;
uint32_t qrows;
uint32_t qrows_per_thread;
float m0;
float m1;
uint32_t n_head_log2;
struct fastdiv_values src3_div2;
struct fastdiv_values src3_div3;
union {
struct {
uint32_t g_br;
uint32_t row_buf_stride;
uint32_t mask_buf_row_stride;
int32_t mask_broadcast;
int32_t pipeline;
struct fastdiv_values div_G;
} hmx;
struct {
uint32_t size_q_row_padded;
uint32_t size_k_row_padded;
uint32_t size_v_row_padded;
struct fastdiv_values src0_div21;
struct fastdiv_values src0_div1;
struct fastdiv_values broadcast_rk2;
struct fastdiv_values broadcast_rk3;
struct fastdiv_values broadcast_rv2;
struct fastdiv_values broadcast_rv3;
} hvx;
} u;
};
#if defined(__cplusplus)
static_assert(sizeof(struct htp_fa_kernel_params) <= 128, "htp_fa_kernel_params is too large for kernel_params blob");
#endif
// Exact VTCM usage for a given (gqa_factor, DK, DV, Br, Bc) configuration.
// g_br = hex_align_up(gqa_factor * Br, 32) replaces Br for all Q/O/S/P/D dimensions.
// Layout: Q + O_ping + O_pong + K_dma*2 + V_dma*2 + K_tile + V_tile + S + P + D + vectors + scales
// Mask is DMA'd into a VTCM buffer (Br rows per KV block) to avoid DDR reads in softmax.
static inline size_t hmx_fa_compute_vtcm_usage(size_t gqa_factor, size_t DK, size_t DV, size_t Br, size_t Bc, size_t n_threads, bool pipeline) {
const size_t g_br = hex_align_up(gqa_factor * Br, HMX_FP16_TILE_N_ROWS);
const size_t q_tile_size = hex_align_up(g_br * DK * sizeof(__fp16), 4096); // Q: [g_br, DK]
const size_t o_tile_size = hex_align_up(g_br * DV * sizeof(__fp16), 4096); // O: [g_br, DV] x2 ping-pong
const size_t k_dma_size = hex_align_up(Bc * hex_round_up(DK * sizeof(__fp16), 128), 4096); // K DMA: [Bc, DK] x2 double-buf
const size_t v_dma_size = hex_align_up(Bc * hex_round_up(DV * sizeof(__fp16), 128), 4096); // V DMA: [Bc, DV] x2 double-buf
const size_t k_tile_size = hex_align_up(Bc * DK * sizeof(__fp16), 4096); // K tiles: [Bc, DK] interleaved
const size_t v_tile_size = hex_align_up(Bc * DV * sizeof(__fp16), 4096); // V tiles: [Bc, DV] interleaved
const size_t s_tile_size = hex_align_up(g_br * Bc * sizeof(__fp16), 4096); // S/P:[g_br, Bc]
const size_t d_tile_size = hex_align_up(g_br * g_br * sizeof(__fp16), 4096); // D: [g_br, g_br]
const size_t col_vec_size = hex_align_up(g_br * sizeof(float), 256); // m, l, etc.
const size_t row_vec_size = hex_align_up(Bc * sizeof(__fp16), 256);
const size_t m_line_size = hex_align_up(Bc * sizeof(__fp16), 128);
const size_t m_buf_size = hex_align_up(Br * m_line_size, 4096) * HMX_FA_DMA_CACHE_SIZE;
const size_t slopes_size = hex_align_up(g_br * sizeof(__fp16), 128);
return q_tile_size * 1 // Q tiles
+ o_tile_size * 2 // O ping-pong
+ k_dma_size * 2 // K DMA x2
+ v_dma_size * 2 // V DMA x2
+ k_tile_size * 1 // K tiles
+ v_tile_size * (pipeline ? 2 : 1) // V tiles (double-buffered if pipelining)
+ s_tile_size * 2 // S + P
+ d_tile_size * 1 // D (diagonal matrix)
+ col_vec_size * 4 // m_vec, l_vec, s_rowmax, p_rowsum
+ row_vec_size * 2 * n_threads // per-thread softmax row scratch
+ m_buf_size * 1 // mask VTCM buffer [Br rows]
+ slopes_size // Slopes
+ 256 * 2; // HMX scales (id + qk)
}
#define FA_HVX_BLOCK_SIZE 64
static inline size_t hvx_fa_compute_vtcm_usage(size_t DK, size_t DV, bool is_q_fp32, bool has_mask, size_t n_threads) {
const size_t size_q_row_padded = hex_round_up(DK * (is_q_fp32 ? 4 : 2), 128);
const size_t size_k_row_padded = hex_round_up(DK * sizeof(__fp16), 128);
const size_t size_v_row_padded = hex_round_up(DV * sizeof(__fp16), 128);
const size_t size_q_block = size_q_row_padded * 1;
const size_t size_k_block = size_k_row_padded * FA_HVX_BLOCK_SIZE;
const size_t size_v_block = size_v_row_padded * FA_HVX_BLOCK_SIZE;
const size_t size_m_block = hex_round_up(FA_HVX_BLOCK_SIZE * sizeof(__fp16), 128);
const size_t size_vkq_acc = hex_round_up(DV * sizeof(float), 128);
const size_t size_per_thread = size_q_block * 1
+ size_k_block * 2
+ size_v_block * 2
+ (has_mask ? size_m_block * HVX_FA_DMA_CACHE_SIZE : 0)
+ size_vkq_acc;
return size_per_thread * n_threads;
}
#define FA_MIN_KV_BLOCKS 3
// Cost-based (Br, Bc) search for flash attention with pipeline constraint.
static inline int hmx_fa_find_chunk_size(size_t * Br_out,
size_t * Bc_out,
size_t gqa_factor,
size_t DK,
size_t DV,
size_t qo_len,
size_t kv_len,
size_t vtcm_budget,
size_t n_threads) {
const size_t T = HMX_FP16_TILE_N_ROWS; // 32
const size_t br_unit = hmx_ceil_div(T, gqa_factor);
const size_t bc_unit = HMX_FP16_TILE_N_COLS * 2; // 64
const size_t fp16 = sizeof(__fp16);
const bool can_pipeline = (kv_len >= FA_MIN_KV_BLOCKS * bc_unit && n_threads >= 2);
// Approximate per-unit VTCM costs (without per-buffer alignment padding).
const size_t per_gbr = (DK + 2 * DV) * fp16 + 4 * sizeof(float); // Q + O*2 + 4 col vectors
const size_t per_gbr2 = fp16; // D diagonal matrix
const size_t per_bc =
3 * DK * fp16 + (can_pipeline ? 4 : 3) * DV * fp16 + 2 * n_threads * fp16; // K/V DMA x2 + tiles + row bufs
const size_t per_gbr_bc = 2 * fp16; // S + P
const size_t overhead = 256 * 2 + 13 * 4096;
if (vtcm_budget <= overhead) {
return -1;
}
const size_t usable = vtcm_budget - overhead;
// Br_max: largest Br aligned to br_unit that does not exceed qo_len.
const size_t Br_max = qo_len >= br_unit ? hex_align_down(qo_len, br_unit) : br_unit;
// Pipeline constraint: cap Bc so n_kv_blocks >= FA_MIN_KV_BLOCKS.
// Only relax when kv_len is too short to form enough blocks.
const size_t Bc_limit = can_pipeline ? hex_align_down(kv_len / FA_MIN_KV_BLOCKS, bc_unit) :
(kv_len >= bc_unit ? hex_align_down(kv_len, bc_unit) : bc_unit);
// Cost coefficients calibrated from profiling
const size_t c_q_fixed = 1400; // per-Q-block: q_load + epilogue o_update + o_norm + o_store
const size_t c_iter_fixed = 200; // per-KV-iter: HMX queue push/pop + DMA pop + barriers
size_t best_cost = SIZE_MAX, best_mn = 0;
size_t best_Br = 0, best_Bc = 0;
for (size_t Br = Br_max; Br >= br_unit; Br -= br_unit) {
const size_t g_br = hex_align_up(gqa_factor * Br, T);
// g_br-dependent VTCM cost: g_br * per_gbr + g_br*g_br * per_gbr2
const size_t gbr_cost = g_br * per_gbr + g_br * g_br * per_gbr2;
if (gbr_cost >= usable) {
if (Br == br_unit) {
break;
}
continue;
}
// Analytically solve for max Bc:
// remain >= Bc * (per_bc + g_br * per_gbr_bc + Br * fp16 * HMX_FA_DMA_CACHE_SIZE)
// The Br * fp16 term accounts for the VTCM mask buffer [Br * Bc].
const size_t remain = usable - gbr_cost;
const size_t bc_denom = per_bc + g_br * per_gbr_bc + Br * fp16 * HMX_FA_DMA_CACHE_SIZE;
size_t Bc = hex_smin(hex_align_down(remain / bc_denom, bc_unit), Bc_limit);
if (Bc < bc_unit) {
if (Br == br_unit) {
break;
}
continue;
}
// Exact VTCM verification (alignment padding may push over budget)
while (Bc >= bc_unit && hmx_fa_compute_vtcm_usage(gqa_factor, DK, DV, Br, Bc, n_threads, can_pipeline) > vtcm_budget) {
Bc -= bc_unit;
}
if (Bc < bc_unit) {
if (Br == br_unit) {
break;
}
continue;
}
const size_t q_blocks = (qo_len + Br - 1) / Br;
const size_t kv_blocks = (kv_len + Bc - 1) / Bc;
const size_t cost = q_blocks * (c_q_fixed + kv_blocks * c_iter_fixed);
const size_t mn = Br * Bc;
if (cost < best_cost || (cost == best_cost && mn > best_mn)) {
best_cost = cost;
best_mn = mn;
best_Br = Br;
best_Bc = Bc;
}
if (Br == br_unit) {
break;
}
}
if (best_Br == 0) {
return -1;
}
*Br_out = best_Br;
*Bc_out = best_Bc;
return 0;
}
#ifdef __cplusplus
}
#endif
#endif /* HTP_FLASH_ATTN_OPS_H */
+15 -15
View File
@@ -138,27 +138,28 @@ static inline bool dma_queue_push_single_1d(dma_queue * q, dma_ptr dptr, size_t
}
dma_descriptor_1d * desc = (dma_descriptor_1d *) &q->desc[q->push_idx];
desc->next = NULL;
desc->desc_size = 0; // 1D mode
desc->src_bypass = dma_src_l2_bypass_on;
desc->dst_bypass = dma_dst_l2_bypass_on;
desc->order = 0;
desc->done = 0;
desc->src = (void *) dptr.src;
desc->dst = (void *) dptr.dst;
desc->size = size;
desc->src = (void *) dptr.src;
desc->dst = (void *) dptr.dst;
desc->size = size;
q->dptr[q->push_idx] = dptr;
if (size) {
desc->next = NULL;
desc->desc_size = 0; // 1D mode
desc->src_bypass = dma_src_l2_bypass_on;
desc->dst_bypass = dma_dst_l2_bypass_on;
desc->order = 0;
desc->done = 0;
htp_trace_event_start(q->trace, HTP_TRACE_EVT_DMA, q->push_idx);
dmlink(q->tail, desc);
q->tail = (dma_descriptor_2d *) desc;
} else {
desc->done = 1;
desc->desc_size = 0;
desc->done = 1;
}
// FARF(ERROR, "dma-push: i %u row-size %u nrows %d dst %p src %p\n", q->push_idx, row_size, nrows, dptr.dst, dptr.src);
q->push_idx = (q->push_idx + 1) & q->idx_mask;
return true;
}
@@ -320,7 +321,7 @@ static inline bool dma_queue_push_vtcm_to_ddr(dma_queue * q, dma_ptr dptr, size_
return dma_queue_push(q, dptr, dst_row_size, src_row_size, dst_row_size, nrows);
}
#define DMA_CACHE_MAX_SIZE 64U
#define DMA_CACHE_MAX_SIZE 256U
typedef struct {
uint8_t *base;
@@ -352,20 +353,19 @@ static inline bool dma_cache_push(dma_queue *q, dma_cache *c, const uint8_t * sr
if (c->src[i] == (uint32_t) src) {
c->age[i] = 0;
dst = c->base + (i * c->line_size); nrows = 0; // dummy dma
// FARF(ERROR, "dma-cache: found %p", src);
} else {
c->age[i]++;
if (c->age[i] > o_age) { o_age = c->age[i]; o_idx = i; }
}
}
if (!dst) {
// FARF(ERROR, "dma-cache: replacing #%u : age %u %p -> %p", o_idx, c->age[o_idx], (void *) c->src[o_idx], src);
c->age[o_idx] = 0;
c->src[o_idx] = (uint32_t) src;
dst = c->base + o_idx * c->line_size; // normal nrows dma
return dma_queue_push(q, dma_make_ptr(dst, src), dst_stride, src_stride, row_size, nrows);
}
return dma_queue_push(q, dma_make_ptr(dst, src), dst_stride, src_stride, row_size, nrows);
return dma_queue_push_single_1d(q, dma_make_ptr(dst, src), 0);
}
#ifdef __cplusplus
@@ -0,0 +1,96 @@
#ifndef HMX_FA_KERNELS_H
#define HMX_FA_KERNELS_H
#include <stdint.h>
#include <stddef.h>
#include <stdbool.h>
#include "hvx-utils.h"
#include "hmx-utils.h"
// HMX-specific parameters, offsets and inner kernels for Flash Attention
// Scatter offsets for diagonal tile: entry[2i] = i*136, entry[2i+1] = i*136+6
// 136 = 4 * 32 + 8 = byte offset to diagonal in a 32x32 fp16 interleaved tile
static const int16_t d_tile_scatter_offsets[64] __attribute__((aligned(128))) = {
0 * 136, 0 * 136 + 6,
1 * 136, 1 * 136 + 6,
2 * 136, 2 * 136 + 6,
3 * 136, 3 * 136 + 6,
4 * 136, 4 * 136 + 6,
5 * 136, 5 * 136 + 6,
6 * 136, 6 * 136 + 6,
7 * 136, 7 * 136 + 6,
8 * 136, 8 * 136 + 6,
9 * 136, 9 * 136 + 6,
10 * 136, 10 * 136 + 6,
11 * 136, 11 * 136 + 6,
12 * 136, 12 * 136 + 6,
13 * 136, 13 * 136 + 6,
14 * 136, 14 * 136 + 6,
15 * 136, 15 * 136 + 6,
0, 0,
0, 0,
0, 0,
0, 0,
0, 0,
0, 0,
0, 0,
0, 0,
0, 0,
0, 0,
0, 0,
0, 0,
0, 0,
0, 0,
0, 0,
0, 0,
};
// Inner HMX tile computation kernels
static inline void hmx_fa_qk_dot_tile(
const __fp16 * row_tiles,
const __fp16 * col_tiles,
__fp16 * out_tile,
size_t n_dot_tiles
) {
for (size_t k = 0; k < n_dot_tiles; ++k) {
Q6_activation_hf_mxmem_RR((unsigned int) row_tiles, 2047);
Q6_weight_hf_mxmem_RR((unsigned int) col_tiles, 2047);
row_tiles += HMX_FP16_TILE_N_ELMS;
col_tiles += HMX_FP16_TILE_N_ELMS;
}
Q6_mxmem_AR_after_hf(out_tile, 0);
}
static inline void hmx_fa_o_update_tile(
const __fp16 * d_diag,
const __fp16 * o_rc,
const __fp16 * p_tile_in,
const __fp16 * v_tile_in,
__fp16 * o_tile_out,
size_t n_col_tiles
) {
Q6_activation_hf_mxmem_RR((unsigned int) d_diag, 2047);
Q6_weight_hf_mxmem_RR((unsigned int) o_rc, 2047);
for (size_t k = 0; k < n_col_tiles; ++k) {
Q6_activation_hf_mxmem_RR((unsigned int) p_tile_in, 2047);
Q6_weight_hf_mxmem_RR((unsigned int) v_tile_in, 2047);
p_tile_in += HMX_FP16_TILE_N_ELMS;
v_tile_in += HMX_FP16_TILE_N_ELMS;
}
Q6_mxmem_AR_after_hf(o_tile_out, 0);
}
static inline void hmx_fa_o_norm_tile(
const __fp16 * d_diag,
const __fp16 * o_rc,
__fp16 * o_out
) {
Q6_activation_hf_mxmem_RR((unsigned int) d_diag, 2047);
Q6_weight_hf_mxmem_RR((unsigned int) o_rc, 2047);
Q6_mxmem_AR_after_hf(o_out, 0);
}
#endif /* HMX_FA_KERNELS_H */
File diff suppressed because it is too large Load Diff
@@ -712,7 +712,17 @@ static inline void hmx_matmul_job_init(hmx_matmul_job_t * job,
// output : fp16 -> f32p
static void transfer_output_chunk_fp16_to_fp32(float *restrict dst, const __fp16 *restrict vtcm_src, uint32_t start_row, uint32_t n_rows, uint32_t n_cols, uint32_t dst_stride, uint32_t dst_cols) {
static void transfer_output_chunk_fp16_to_fp32(
float *restrict dst,
const float *restrict src2,
const __fp16 *restrict vtcm_src,
uint32_t start_row,
uint32_t n_rows,
uint32_t n_cols,
uint32_t dst_stride,
uint32_t src2_stride,
uint32_t dst_cols
) {
assert(n_cols % HTP_MM_HMX_TILE_N_COLS == 0);
const size_t tile_row_stride = (n_cols / HTP_MM_HMX_TILE_N_COLS) * HTP_MM_HMX_TILE_N_ELMS;
@@ -727,6 +737,7 @@ static void transfer_output_chunk_fp16_to_fp32(float *restrict dst, const __fp16
const size_t r1 = (r_idx0 % HTP_MM_HMX_TILE_N_ROWS) / 2; // index of the row pair within the tile
const __fp16 *row_base = vtcm_src + r0 * tile_row_stride;
float *output_row_base = dst + r * dst_stride; // global memory row base for row r (and r+1)
const float *src2_row_base = src2 ? (src2 + r * src2_stride) : NULL;
#pragma unroll(4)
for (size_t c = 0; c < limit_c_aligned; c += HTP_MM_HMX_TILE_N_COLS) {
@@ -738,9 +749,20 @@ static void transfer_output_chunk_fp16_to_fp32(float *restrict dst, const __fp16
HVX_Vector *pv_out0 = (HVX_Vector *) (output_row_base + c + 0);
HVX_Vector *pv_out1 = (HVX_Vector *) (output_row_base + c + dst_stride);
*pv_out0 = Q6_Vsf_equals_Vqf32(Q6_V_lo_W(vp));
HVX_Vector v_out0 = Q6_Vsf_equals_Vqf32(Q6_V_lo_W(vp));
if (src2_row_base) {
HVX_Vector v_src2_0 = hvx_vmemu(src2_row_base + c + 0);
v_out0 = hvx_vec_add_f32_f32(v_out0, v_src2_0);
}
*pv_out0 = v_out0;
if (r + 1 < n_rows) {
*pv_out1 = Q6_Vsf_equals_Vqf32(Q6_V_hi_W(vp));
HVX_Vector v_out1 = Q6_Vsf_equals_Vqf32(Q6_V_hi_W(vp));
if (src2_row_base) {
HVX_Vector v_src2_1 = hvx_vmemu(src2_row_base + c + src2_stride);
v_out1 = hvx_vec_add_f32_f32(v_out1, v_src2_1);
}
*pv_out1 = v_out1;
}
}
@@ -752,9 +774,20 @@ static void transfer_output_chunk_fp16_to_fp32(float *restrict dst, const __fp16
HVX_Vector v = ((const HVX_Vector *) tile)[r1];
HVX_VectorPair vp = Q6_Wqf32_vmpy_VhfVhf(v, one);
hvx_vec_store_u(output_row_base + c, valid_c * sizeof(float), Q6_Vsf_equals_Vqf32(Q6_V_lo_W(vp)));
HVX_Vector v_out0 = Q6_Vsf_equals_Vqf32(Q6_V_lo_W(vp));
if (src2_row_base) {
HVX_Vector v_src2_0 = hvx_vmemu(src2_row_base + c + 0);
v_out0 = hvx_vec_add_f32_f32(v_out0, v_src2_0);
}
hvx_vec_store_u(output_row_base + c, valid_c * sizeof(float), v_out0);
if (r + 1 < n_rows) {
hvx_vec_store_u(output_row_base + c + dst_stride, valid_c * sizeof(float), Q6_Vsf_equals_Vqf32(Q6_V_hi_W(vp)));
HVX_Vector v_out1 = Q6_Vsf_equals_Vqf32(Q6_V_hi_W(vp));
if (src2_row_base) {
HVX_Vector v_src2_1 = hvx_vmemu(src2_row_base + c + src2_stride);
v_out1 = hvx_vec_add_f32_f32(v_out1, v_src2_1);
}
hvx_vec_store_u(output_row_base + c + dst_stride, valid_c * sizeof(float), v_out1);
}
}
}
@@ -763,11 +796,13 @@ static void transfer_output_chunk_fp16_to_fp32(float *restrict dst, const __fp16
typedef struct {
const __fp16 *vtcm_src;
float *dst;
const float *src2;
uint32_t n_tasks;
uint32_t n_tot_chunks;
uint32_t n_chunks_per_task;
uint32_t n_cols;
uint32_t dst_stride; // DDR row stride
uint32_t src2_stride; // DDR row stride for residual
uint32_t dst_cols; // Actual output columns
struct htp_thread_trace * traces;
} output_transfer_task_state_t;
+35 -35
View File
@@ -42,14 +42,14 @@ static const int32_t hmx_transpose_scatter_offsets[32] __attribute__((aligned(VL
// Full range: start_row=0, end_row=n_cols.
static inline void hmx_interleave_rows_to_tiles(__fp16 * restrict vtcm_dst,
const __fp16 * restrict vtcm_src,
int n_cols,
int k,
int src_stride,
int start_row,
int end_row) {
uint32_t n_cols,
uint32_t k,
size_t src_stride,
uint32_t start_row,
uint32_t end_row) {
assert(k % HMX_FP16_TILE_N_COLS == 0);
const int n_k_tiles = k / HMX_FP16_TILE_N_COLS;
const uint32_t n_k_tiles = k / HMX_FP16_TILE_N_COLS;
const HVX_Vector v_scat_base = hvx_vmem(hmx_transpose_scatter_offsets);
const HVX_Vector v_scat_step = Q6_V_vsplat_R(4);
const HVX_VectorPred q_mask64 = Q6_Q_vsetq_R(64);
@@ -65,14 +65,14 @@ static inline void hmx_interleave_rows_to_tiles(__fp16 * restrict vtcm_dst,
if (pair_scatter) {
// Step c by 64 fp16 (two K-tiles per scatter), advance dst by 2 tiles per iter.
const int c_step = 2 * HMX_FP16_TILE_N_COLS;
const size_t c_byte_step = (size_t) c_step * sizeof(__fp16);
const size_t dst_step = 2 * (size_t) HMX_FP16_TILE_N_ELMS;
const int n_c_iters = k / c_step;
const uint32_t c_step = 2 * HMX_FP16_TILE_N_COLS;
const size_t c_byte_step = (size_t) c_step * sizeof(__fp16);
const size_t dst_step = 2 * (size_t) HMX_FP16_TILE_N_ELMS;
const uint32_t n_c_iters = k / c_step;
for (int r = start_row; r < end_row; r += 2) {
const int ct = r / HMX_FP16_TILE_N_ROWS;
const int local_r = r % HMX_FP16_TILE_N_ROWS;
for (uint32_t r = start_row; r < end_row; r += 2) {
const uint32_t ct = r / HMX_FP16_TILE_N_ROWS;
const uint32_t local_r = r % HMX_FP16_TILE_N_ROWS;
const bool next_row_valid = (r + 1) < end_row && (r + 1) < n_cols;
const HVX_Vector v_off0 = Q6_Vw_vadd_VwVw(v_scat_base, Q6_V_vsplat_R(local_r * 4));
const HVX_Vector v_off1 = Q6_Vw_vadd_VwVw(v_off0, v_scat_step);
@@ -86,7 +86,7 @@ static inline void hmx_interleave_rows_to_tiles(__fp16 * restrict vtcm_dst,
assert(c_byte_step % 128 == 0);
if (p1) {
for (int i = 0; i < n_c_iters; ++i) {
for (uint32_t i = 0; i < n_c_iters; ++i) {
HVX_Vector v0 = hvx_vmem(p0); p0 += c_byte_step;
HVX_Vector v1 = hvx_vmem(p1); p1 += c_byte_step;
Q6_vscatter_RMVwV((size_t) tile_base, pair_region, v_off0, v0);
@@ -95,7 +95,7 @@ static inline void hmx_interleave_rows_to_tiles(__fp16 * restrict vtcm_dst,
}
} else {
const HVX_Vector vzero = Q6_V_vzero();
for (int i = 0; i < n_c_iters; ++i) {
for (uint32_t i = 0; i < n_c_iters; ++i) {
HVX_Vector v0 = hvx_vmem(p0); p0 += c_byte_step;
Q6_vscatter_RMVwV((size_t) tile_base, pair_region, v_off0, v0);
Q6_vscatter_RMVwV((size_t) tile_base, pair_region, v_off1, vzero);
@@ -105,14 +105,14 @@ static inline void hmx_interleave_rows_to_tiles(__fp16 * restrict vtcm_dst,
}
} else {
// Fallback: scatter one K-tile per call (region 2047, masked).
const int c_step = HMX_FP16_TILE_N_COLS;
const size_t c_byte_step = (size_t) c_step * sizeof(__fp16);
const size_t dst_step = (size_t) HMX_FP16_TILE_N_ELMS;
const int n_c_iters = k / c_step;
const uint32_t c_step = HMX_FP16_TILE_N_COLS;
const size_t c_byte_step = (size_t) c_step * sizeof(__fp16);
const size_t dst_step = (size_t) HMX_FP16_TILE_N_ELMS;
const uint32_t n_c_iters = k / c_step;
for (int r = start_row; r < end_row; r += 2) {
const int ct = r / HMX_FP16_TILE_N_ROWS;
const int local_r = r % HMX_FP16_TILE_N_ROWS;
for (uint32_t r = start_row; r < end_row; r += 2) {
const uint32_t ct = r / HMX_FP16_TILE_N_ROWS;
const uint32_t local_r = r % HMX_FP16_TILE_N_ROWS;
const bool next_row_valid = (r + 1) < end_row && (r + 1) < n_cols;
const HVX_Vector v_off0 = Q6_Vw_vadd_VwVw(v_scat_base, Q6_V_vsplat_R(local_r * 4));
const HVX_Vector v_off1 = Q6_Vw_vadd_VwVw(v_off0, v_scat_step);
@@ -122,7 +122,7 @@ static inline void hmx_interleave_rows_to_tiles(__fp16 * restrict vtcm_dst,
const uint8_t * p1 = next_row_valid ? (const uint8_t *) (vtcm_src + (r + 1) * src_stride) : NULL;
if (p1) {
for (int i = 0; i < n_c_iters; ++i) {
for (uint32_t i = 0; i < n_c_iters; ++i) {
HVX_Vector v0 = hvx_vmemu(p0); p0 += c_byte_step;
HVX_Vector v1 = hvx_vmemu(p1); p1 += c_byte_step;
Q6_vscatter_QRMVwV(q_mask64, (size_t) tile_base, single_region, v_off0, v0);
@@ -131,7 +131,7 @@ static inline void hmx_interleave_rows_to_tiles(__fp16 * restrict vtcm_dst,
}
} else {
const HVX_Vector vzero = Q6_V_vzero();
for (int i = 0; i < n_c_iters; ++i) {
for (uint32_t i = 0; i < n_c_iters; ++i) {
HVX_Vector v0 = hvx_vmemu(p0); p0 += c_byte_step;
Q6_vscatter_QRMVwV(q_mask64, (size_t) tile_base, single_region, v_off0, v0);
Q6_vscatter_QRMVwV(q_mask64, (size_t) tile_base, single_region, v_off1, vzero);
@@ -148,24 +148,24 @@ static inline void hmx_interleave_rows_to_tiles(__fp16 * restrict vtcm_dst,
// Full range: start_row=0, end_row=n_rows.
static inline void hmx_interleave_cols_to_tiles(__fp16 * restrict tiles_out,
const __fp16 * restrict src,
int n_rows,
int head_dim,
int src_stride,
int n_row_tiles,
int start_row,
int end_row) {
uint32_t n_rows,
uint32_t head_dim,
size_t src_stride,
uint32_t n_row_tiles,
uint32_t start_row,
uint32_t end_row) {
__builtin_assume(head_dim > 0);
const size_t tile_stride_elms = (size_t) n_row_tiles * HMX_FP16_TILE_N_ELMS;
for (int r = start_row; r < end_row; r += 2) {
for (uint32_t r = start_row; r < end_row; r += 2) {
const bool next_row_valid = (r + 1) < end_row && (r + 1) < n_rows;
const HVX_Vector * pv_in0 = (const HVX_Vector *) (src + r * src_stride);
const HVX_Vector * pv_in1 = next_row_valid ? (const HVX_Vector *) (src + (r + 1) * src_stride) : NULL;
// Row-pair invariants hoisted out of the c loop.
const int r0 = r / HMX_FP16_TILE_N_ROWS;
const int r1_half = (r % HMX_FP16_TILE_N_ROWS) / 2;
const uint32_t r0 = r / HMX_FP16_TILE_N_ROWS;
const uint32_t r1_half = (r % HMX_FP16_TILE_N_ROWS) / 2;
// tb0 starts at tile (c0=0, r0); tb1 at the adjacent dim-tile (c0=1, r0).
// Each c step (+= 64) advances both by 2 dim-tiles worth of fp16.
@@ -174,7 +174,7 @@ static inline void hmx_interleave_cols_to_tiles(__fp16 * restrict tiles_out,
const size_t tb_step = 2 * tile_stride_elms;
if (pv_in1) {
for (int c = 0; c < head_dim; c += 64) {
for (uint32_t c = 0; c < head_dim; c += 64) {
HVX_Vector v0 = *pv_in0++;
HVX_Vector v1 = *pv_in1++;
HVX_VectorPair vp = Q6_W_vshuff_VVR(v1, v0, -2);
@@ -185,7 +185,7 @@ static inline void hmx_interleave_cols_to_tiles(__fp16 * restrict tiles_out,
}
} else {
const HVX_Vector vzero = Q6_V_vzero();
for (int c = 0; c < head_dim; c += 64) {
for (uint32_t c = 0; c < head_dim; c += 64) {
HVX_Vector v0 = *pv_in0++;
HVX_VectorPair vp = Q6_W_vshuff_VVR(vzero, v0, -2);
((HVX_Vector *) tb0)[r1_half] = Q6_V_lo_W(vp);
+6
View File
@@ -60,6 +60,7 @@ enum htp_op_code {
HTP_OP_MUL_MAT_ID,
HTP_OP_MUL_MAT_QKV,
HTP_OP_MUL_MAT_FFN,
HTP_OP_MUL_MAT_ADD,
HTP_OP_RMS_NORM,
HTP_OP_RMS_NORM_MUL,
HTP_OP_UNARY_SILU,
@@ -175,6 +176,11 @@ enum htp_trace_event_id {
HTP_TRACE_EVT_HVX_W_DEQUANT = 23,
HTP_TRACE_EVT_HVX_W_PREP = 24,
HTP_TRACE_EVT_HVX_O_PROC = 25,
HTP_TRACE_EVT_HVX_FA_QK = 26,
HTP_TRACE_EVT_HVX_FA_SFM = 27,
HTP_TRACE_EVT_HVX_FA_Q_PREP = 28,
HTP_TRACE_EVT_HVX_FA_K_PREP = 29,
HTP_TRACE_EVT_HVX_FA_V_PREP = 30,
HTP_TRACE_EVT_HMX_COMP = 40,
};
+1 -12
View File
@@ -134,16 +134,7 @@ static inline HVX_Vector hvx_vec_f32_to_f16_shuff(HVX_Vector v0, HVX_Vector v1)
}
static inline HVX_Vector hvx_vec_f32_to_f16(HVX_Vector v0, HVX_Vector v1) {
HVX_Vector v = Q6_Vh_vdeal_Vh(hvx_vec_f32_to_f16_shuff(v0, v1));
#if __HVX_ARCH__ < 79
// replace NaNs with -INF, older arches produce NaNs for (-INF + 0.0)
const HVX_Vector neg_inf = hvx_vec_splat_f16(-INFINITY);
HVX_VectorPred nan = hvx_vec_is_nan_f16(v);
v = Q6_V_vmux_QVV(nan, neg_inf, v);
#endif
return v;
return Q6_Vh_vdeal_Vh(hvx_vec_f32_to_f16_shuff(v0, v1));
}
#if __HVX_ARCH__ >= 79
@@ -170,8 +161,6 @@ static inline HVX_VectorPair hvx_vec_f16_to_f32(HVX_Vector v) {
}
#endif
static inline HVX_Vector hvx_vec_i16_from_hf_rnd_sat(HVX_Vector vin) {
// This looks complicated.
// Ideally should just be Q6_Vh_equals_Vhf(vin)
+39
View File
@@ -16,6 +16,7 @@
#define EXP_COEFF_0 (0x3F000000) // 0.5 = 1/(2!)
#define EXP_LOGN2 (0x3F317218) // ln(2) = 0.6931471805
#define EXP_LOG2E (0x3FB8AA3B) // log2(e) = 1/ln(2) = 1.4426950408
#define EXP_LOG2E_F 1.44269504f
#define EXP_ONE (0x3f800000) // 1.0
#define EXP_RANGE_R (0x42B17218) // ln(FLT_MAX) approx = 88.7228
#define EXP_RANGE_L (0xC2B00000) // -88.0 (approx log(FLT_MIN))
@@ -213,4 +214,42 @@ static inline void hvx_exp_f32(uint8_t * restrict dst, const uint8_t * restrict
}
}
static inline HVX_Vector hvx_vec_exp2_f16(HVX_Vector x_v) {
const HVX_Vector zero_v = Q6_V_vzero();
const HVX_Vector half_hf_v = Q6_Vh_vsplat_R(0x3800); // fp16 0.5
// Clamp input to prevent integer underflow in FP16-to-INT16 conversion
const HVX_Vector v_clamp_min = hvx_vec_splat_f16(-24.0f);
x_v = Q6_Vhf_vmax_VhfVhf(v_clamp_min, x_v);
// k = round_toward_neg_inf(x); f = (float)k; frac = x - f
HVX_Vector x_minus_half = Q6_Vhf_equals_Vqf16(Q6_Vqf16_vsub_VhfVhf(x_v, half_hf_v));
HVX_Vector k_v = Q6_Vh_equals_Vhf(x_minus_half); // truncate to int16
HVX_Vector f_v = Q6_Vhf_equals_Vh(k_v); // back to fp16
HVX_Vector x_qf16 = Q6_Vqf16_vsub_VhfVhf(x_v, f_v); // fractional part in qf16
// Horner: y = ((((E5*x + E4)*x + E3)*x + E2)*x + E1)*x + E0
HVX_Vector y = Q6_Vqf16_vmpy_Vqf16Vqf16(Q6_Vh_vsplat_R(0x5082), x_qf16); // E5*x
y = Q6_Vqf16_vadd_Vqf16Vhf(y, Q6_Vh_vsplat_R(0x157d)); // + E4
y = Q6_Vqf16_vmpy_Vqf16Vqf16(y, x_qf16);
y = Q6_Vqf16_vadd_Vqf16Vhf(y, Q6_Vh_vsplat_R(0x20ed)); // + E3
y = Q6_Vqf16_vmpy_Vqf16Vqf16(y, x_qf16);
y = Q6_Vqf16_vadd_Vqf16Vhf(y, Q6_Vh_vsplat_R(0x2b1b)); // + E2
y = Q6_Vqf16_vmpy_Vqf16Vqf16(y, x_qf16);
y = Q6_Vqf16_vadd_Vqf16Vhf(y, Q6_Vh_vsplat_R(0x33b0)); // + E1
y = Q6_Vqf16_vmpy_Vqf16Vqf16(y, x_qf16);
y = Q6_Vqf16_vadd_Vqf16Vhf(y, Q6_Vh_vsplat_R(0x398c)); // + E0
y = Q6_Vqf16_vmpy_Vqf16Vqf16(y, x_qf16); // y = y * x
y = Q6_Vqf16_vadd_Vqf16Vhf(y, Q6_Vh_vsplat_R(0x3c00)); // + 1.0
// Combine polynomial (mantissa) with integer part (exponent): result = y * 2^k
y = Q6_Vhf_equals_Vqf16(y);
HVX_Vector y_exp = Q6_Vuh_vlsr_VuhR(Q6_Vh_vasl_VhR(y, 1), 11);
y_exp = Q6_Vh_vadd_VhVh(k_v, y_exp);
HVX_VectorPred q_underflow = Q6_Q_vcmp_gt_VhVh(zero_v, y_exp);
y = Q6_Vh_vaslacc_VhVhR(y, k_v, 10);
return Q6_V_vmux_QVV(q_underflow, zero_v, y);
}
#endif /* HVX_EXP_H */
+232
View File
@@ -0,0 +1,232 @@
#ifndef HVX_FA_KERNELS_H
#define HVX_FA_KERNELS_H
#include <assert.h>
#include <math.h>
#include "hvx-utils.h"
// Little inner kernels for HVX
#if __HVX_ARCH__ < 79
#define HVX_OP_ADD_F32(a, b) Q6_Vsf_equals_Vqf32(Q6_Vqf32_vadd_VsfVsf(a, b))
#define HVX_OP_SUB_F32(a, b) Q6_Vsf_equals_Vqf32(Q6_Vqf32_vsub_VsfVsf(a, b))
#define HVX_OP_MUL_F32(a, b) Q6_Vsf_equals_Vqf32(Q6_Vqf32_vmpy_VsfVsf(a, b))
#else
#define HVX_OP_ADD_F32(a, b) Q6_Vsf_vadd_VsfVsf(a, b)
#define HVX_OP_SUB_F32(a, b) Q6_Vsf_vsub_VsfVsf(a, b)
#define HVX_OP_MUL_F32(a, b) Q6_Vsf_vmpy_VsfVsf(a, b)
#endif
// This is a bit of a hack because the compiler is struggling to properly inline
// the default hvx_vec_f32_to_f16 with output into the local array.
static __attribute__((unused)) __attribute__((noinline)) void hvx_vec_f32_to_f16_a(void *ptr, HVX_Vector v0, HVX_Vector v1)
{
*(HVX_Vector *) ptr = hvx_vec_f32_to_f16(v0, v1);
}
// Dot product of two F16 vectors, accumulating to float
static inline void hvx_dot_f16_f16_aa(float * restrict r, const void * restrict x, const void * restrict y, unsigned int n, float s) {
const HVX_Vector * restrict vx = (const HVX_Vector * restrict) x; // fp16
const HVX_Vector * restrict vy = (const HVX_Vector * restrict) y; // fp16
uint32_t nvec = n / VLEN_FP16; // num full fp16 hvx vectors
uint32_t nloe = n % VLEN_FP16; // leftover elements
HVX_VectorPair rsum_p = Q6_W_vcombine_VV(Q6_V_vsplat_R(0), Q6_V_vsplat_R(0));
uint32_t i = 0;
#pragma unroll(4)
for (i = 0; i < nvec; i++) {
rsum_p = hvx_vec_mpyacc_f32_f16(rsum_p, vx[i], vy[i]);
}
if (nloe) {
HVX_VectorPred bmask = Q6_Q_vsetq_R(nloe * 2);
HVX_Vector y_hf = Q6_V_vand_QV(bmask, vy[i]);
HVX_Vector x_hf = Q6_V_vand_QV(bmask, vx[i]);
rsum_p = hvx_vec_mpyacc_f32_f16(rsum_p, x_hf, y_hf);
}
HVX_Vector rsum = HVX_OP_ADD_F32(Q6_V_lo_W(rsum_p), Q6_V_hi_W(rsum_p));
rsum = HVX_OP_MUL_F32(hvx_vec_splat_f32(s), hvx_vec_reduce_sum_f32(rsum));
hvx_vec_store_u(r, 4, rsum);
}
static inline HVX_Vector hvx_dot_f16_f16_aa_rx4(const void * restrict y,
const uint8_t * restrict x,
const size_t stride_x,
const size_t nvec,
const size_t nloe) {
const HVX_Vector * restrict vx0 = (const HVX_Vector * restrict) x; // fp16
const HVX_Vector * restrict vx1 = (const HVX_Vector * restrict) (x + stride_x); // fp16
const HVX_Vector * restrict vx2 = (const HVX_Vector * restrict) (x + stride_x * 2); // fp16
const HVX_Vector * restrict vx3 = (const HVX_Vector * restrict) (x + stride_x * 3); // fp16
const HVX_Vector * restrict vy = (const HVX_Vector * restrict) y; // fp16
HVX_VectorPair rsum0_p = Q6_W_vcombine_VV(Q6_V_vsplat_R(0), Q6_V_vsplat_R(0));
HVX_VectorPair rsum1_p = Q6_W_vcombine_VV(Q6_V_vsplat_R(0), Q6_V_vsplat_R(0));
HVX_VectorPair rsum2_p = Q6_W_vcombine_VV(Q6_V_vsplat_R(0), Q6_V_vsplat_R(0));
HVX_VectorPair rsum3_p = Q6_W_vcombine_VV(Q6_V_vsplat_R(0), Q6_V_vsplat_R(0));
uint32_t i = 0;
for (i = 0; i < nvec; i++) {
HVX_Vector y_hf = vy[i];
HVX_Vector x0_hf = vx0[i];
HVX_Vector x1_hf = vx1[i];
HVX_Vector x2_hf = vx2[i];
HVX_Vector x3_hf = vx3[i];
rsum0_p = hvx_vec_mpyacc_f32_f16(rsum0_p, x0_hf, y_hf);
rsum1_p = hvx_vec_mpyacc_f32_f16(rsum1_p, x1_hf, y_hf);
rsum2_p = hvx_vec_mpyacc_f32_f16(rsum2_p, x2_hf, y_hf);
rsum3_p = hvx_vec_mpyacc_f32_f16(rsum3_p, x3_hf, y_hf);
}
if (nloe) {
// Load x (fp16) and zero-out unused elements
HVX_VectorPred bmask = Q6_Q_vsetq_R(nloe * 2);
HVX_Vector y_hf = Q6_V_vand_QV(bmask, vy[i]);
HVX_Vector x0_hf = Q6_V_vand_QV(bmask, vx0[i]);
HVX_Vector x1_hf = Q6_V_vand_QV(bmask, vx1[i]);
HVX_Vector x2_hf = Q6_V_vand_QV(bmask, vx2[i]);
HVX_Vector x3_hf = Q6_V_vand_QV(bmask, vx3[i]);
rsum0_p = hvx_vec_mpyacc_f32_f16(rsum0_p, x0_hf, y_hf);
rsum1_p = hvx_vec_mpyacc_f32_f16(rsum1_p, x1_hf, y_hf);
rsum2_p = hvx_vec_mpyacc_f32_f16(rsum2_p, x2_hf, y_hf);
rsum3_p = hvx_vec_mpyacc_f32_f16(rsum3_p, x3_hf, y_hf);
}
HVX_Vector rsum0 = HVX_OP_ADD_F32(Q6_V_lo_W(rsum0_p), Q6_V_hi_W(rsum0_p));
HVX_Vector rsum1 = HVX_OP_ADD_F32(Q6_V_lo_W(rsum1_p), Q6_V_hi_W(rsum1_p));
HVX_Vector rsum2 = HVX_OP_ADD_F32(Q6_V_lo_W(rsum2_p), Q6_V_hi_W(rsum2_p));
HVX_Vector rsum3 = HVX_OP_ADD_F32(Q6_V_lo_W(rsum3_p), Q6_V_hi_W(rsum3_p));
HVX_Vector_x4 rsum0123 = { .v = { rsum0, rsum1, rsum2, rsum3 } };
return hvx_vec_reduce_sum_f32x4(rsum0123);
}
static inline HVX_Vector hvx_dot_f16_f16_aa_rx32(const void * restrict y,
const uint8_t * restrict x,
const size_t stride_x,
const size_t n,
float s) {
const size_t nvec = n / VLEN_FP16; // num full fp16 hvx vectors
const size_t nloe = n % VLEN_FP16; // leftover elements
HVX_Vector sums = Q6_V_vzero();
const size_t stride_x_4 = stride_x * 4;
for (uint32_t j = 0; j < VLEN_FP32; j += 4) {
HVX_Vector sums_x4 = hvx_dot_f16_f16_aa_rx4(y, x, stride_x, nvec, nloe);
HVX_VectorPred pred = Q6_Q_vsetq_R(j * SIZEOF_FP32);
sums = Q6_V_vmux_QVV(pred, sums, sums_x4);
x += stride_x_4;
}
return HVX_OP_MUL_F32(hvx_vec_splat_f32(s), sums);
}
// MAD: y (F32) += x (F16) * s (F16)
static inline void hvx_mad_f32_f16_aa(float * restrict y, const void * restrict x, const __fp16 * restrict s, uint32_t n) {
const HVX_Vector * restrict vx0 = (const HVX_Vector *) x;
HVX_VectorPair * restrict vy_p = (HVX_VectorPair *) y;
HVX_Vector * restrict vy = (HVX_Vector *) y;
uint32_t nvec = n / VLEN_FP16; // num full fp16 hvx vectors
uint32_t nloe = n % VLEN_FP16; // leftover elements
HVX_Vector S0 = hvx_vec_splat_f16(*s);
uint32_t i = 0;
#pragma unroll(2)
for (i = 0; i < nvec; ++i) {
vy_p[i] = hvx_vec_mpyacc_f32_f16(vy_p[i], Q6_Vh_vshuff_Vh(vx0[i]), S0);
}
if (nloe) {
HVX_VectorPair xy_p = vy_p[i];
xy_p = hvx_vec_mpyacc_f32_f16(xy_p, Q6_Vh_vshuff_Vh(vx0[i]), S0);
HVX_Vector xy = Q6_V_lo_W(xy_p);
i = 2 * i; // index for vy
if (nloe >= VLEN_FP32) {
vy[i] = xy;
nloe -= VLEN_FP32; ++i; xy = Q6_V_hi_W(xy_p);
}
if (nloe) {
hvx_vec_store_a(&vy[i], nloe * 4, xy);
}
}
}
// MAD: y (F32) += x0 (F16) * s0 (F16) + x1 (F16) * s1 (F16)
static inline void hvx_mad_f32_f16_aa_rx2(float * restrict y, const void * restrict x0, const void * restrict x1,
const __fp16 * restrict s0, const __fp16 * restrict s1, uint32_t n) {
const HVX_Vector * restrict vx0 = (const HVX_Vector *) x0;
const HVX_Vector * restrict vx1 = (const HVX_Vector *) x1;
HVX_VectorPair * restrict vy_p = (HVX_VectorPair *) y;
HVX_Vector * restrict vy = (HVX_Vector *) y;
uint32_t nvec = n / VLEN_FP16; // num full fp16 hvx vectors
uint32_t nloe = n % VLEN_FP16; // leftover elements
HVX_Vector S0 = hvx_vec_splat_f16(*s0);
HVX_Vector S1 = hvx_vec_splat_f16(*s1);
uint32_t i = 0;
#pragma unroll(2)
for (i = 0; i < nvec; ++i) {
vy_p[i] = hvx_vec_mpyacc_f32_f16(vy_p[i], Q6_Vh_vshuff_Vh(vx0[i]), S0);
vy_p[i] = hvx_vec_mpyacc_f32_f16(vy_p[i], Q6_Vh_vshuff_Vh(vx1[i]), S1);
}
if (nloe) {
HVX_VectorPair xy_p = vy_p[i];
xy_p = hvx_vec_mpyacc_f32_f16(xy_p, Q6_Vh_vshuff_Vh(vx0[i]), S0);
xy_p = hvx_vec_mpyacc_f32_f16(xy_p, Q6_Vh_vshuff_Vh(vx1[i]), S1);
HVX_Vector xy = Q6_V_lo_W(xy_p);
i = 2 * i; // index for vy
if (nloe >= VLEN_FP32) {
vy[i] = xy;
nloe -= VLEN_FP32; ++i; xy = Q6_V_hi_W(xy_p);
}
if (nloe) {
hvx_vec_store_a(&vy[i], nloe * 4, xy);
}
}
}
static inline void hvx_scale_vec_f32_aa(uint8_t * restrict dst, const uint8_t * restrict src, const uint32_t n, HVX_Vector vs) {
assert((size_t) dst % 128 == 0);
assert((size_t) src % 128 == 0);
const HVX_Vector * restrict vsrc = (const HVX_Vector * restrict) src;
HVX_Vector * restrict vdst = (HVX_Vector * restrict) dst;
const uint32_t nvec = n / VLEN_FP32;
const uint32_t nloe = n % VLEN_FP32;
uint32_t i = 0;
#pragma unroll(4)
for (; i < nvec; ++i) {
vdst[i] = HVX_OP_MUL_F32(vsrc[i], vs);
}
if (nloe) {
hvx_vec_store_a(&vdst[i], nloe * sizeof(float), HVX_OP_MUL_F32(vsrc[i], vs));
}
}
#endif /* HVX_FA_KERNELS_H */
+512 -25
View File
@@ -256,7 +256,7 @@ static inline void quantize_f16_f16_flat_kernel(
// Dot kernels that consume flat (non-tiled) activations
static void flat_vec_dot_q4_0_32x1(const uint32_t n, float * restrict s, const void * restrict vx, const void * restrict vy, uint32_t valid_rows) {
static void flat_vec_dot_q4_0_32x1(const uint32_t n, float * restrict s, const void * restrict vx, const void * restrict vy, uint32_t valid_rows, const float * restrict sz) {
const uint8_t * restrict tile_ptr = vx;
const uint8_t * restrict y_q = vy;
@@ -312,10 +312,14 @@ static void flat_vec_dot_q4_0_32x1(const uint32_t n, float * restrict s, const v
v_sum_float = hvx_vec_add_f32_f32(v_sum_float, v_sum_scaled);
}
hvx_vec_store_u(s, valid_rows * sizeof(float), v_sum_float);
if (sz) {
hvx_vec_store_u(s, valid_rows * sizeof(float), hvx_vec_add_f32_f32(v_sum_float, hvx_vmemu(sz)));
} else {
hvx_vec_store_u(s, valid_rows * sizeof(float), v_sum_float);
}
}
static void flat_vec_dot_q4_0_32x2(const uint32_t n, float * restrict s0, float * restrict s1, const void * restrict vx, const void * restrict vy0, const void * restrict vy1, uint32_t valid_rows) {
static void flat_vec_dot_q4_0_32x2(const uint32_t n, float * restrict s0, float * restrict s1, const void * restrict vx, const void * restrict vy0, const void * restrict vy1, uint32_t valid_rows, const float * restrict sz0, const float * restrict sz1) {
const uint8_t * restrict tile_ptr = vx;
const uint8_t * restrict y0_q = vy0;
const uint8_t * restrict y1_q = vy1;
@@ -397,11 +401,19 @@ static void flat_vec_dot_q4_0_32x2(const uint32_t n, float * restrict s0, float
v_sum_float_c1 = hvx_vec_add_f32_f32(v_sum_float_c1, v_sum_scaled_c1);
}
hvx_vec_store_u(s0, valid_rows * sizeof(float), v_sum_float_c0);
hvx_vec_store_u(s1, valid_rows * sizeof(float), v_sum_float_c1);
if (sz0) {
hvx_vec_store_u(s0, valid_rows * sizeof(float), hvx_vec_add_f32_f32(v_sum_float_c0, hvx_vmemu(sz0)));
} else {
hvx_vec_store_u(s0, valid_rows * sizeof(float), v_sum_float_c0);
}
if (sz1) {
hvx_vec_store_u(s1, valid_rows * sizeof(float), hvx_vec_add_f32_f32(v_sum_float_c1, hvx_vmemu(sz1)));
} else {
hvx_vec_store_u(s1, valid_rows * sizeof(float), v_sum_float_c1);
}
}
static void flat_vec_dot_q4_1_32x1(const uint32_t n, float * restrict s, const void * restrict vx, const void * restrict vy, uint32_t valid_rows) {
static void flat_vec_dot_q4_1_32x1(const uint32_t n, float * restrict s, const void * restrict vx, const void * restrict vy, uint32_t valid_rows, const float * restrict sz) {
const uint8_t * restrict tile_ptr = vx;
const uint8_t * restrict y_q = vy;
@@ -464,10 +476,14 @@ static void flat_vec_dot_q4_1_32x1(const uint32_t n, float * restrict s, const v
v_sum_float = hvx_vec_add_f32_f32(v_sum_float, v_sum_scaled);
}
hvx_vec_store_u(s, valid_rows * sizeof(float), v_sum_float);
if (sz) {
hvx_vec_store_u(s, valid_rows * sizeof(float), hvx_vec_add_f32_f32(v_sum_float, hvx_vmemu(sz)));
} else {
hvx_vec_store_u(s, valid_rows * sizeof(float), v_sum_float);
}
}
static void flat_vec_dot_q4_1_32x2(const uint32_t n, float * restrict s0, float * restrict s1, const void * restrict vx, const void * restrict vy0, const void * restrict vy1, uint32_t valid_rows) {
static void flat_vec_dot_q4_1_32x2(const uint32_t n, float * restrict s0, float * restrict s1, const void * restrict vx, const void * restrict vy0, const void * restrict vy1, uint32_t valid_rows, const float * restrict sz0, const float * restrict sz1) {
const uint8_t * restrict tile_ptr = vx;
const uint8_t * restrict y0_q = vy0;
const uint8_t * restrict y1_q = vy1;
@@ -561,11 +577,19 @@ static void flat_vec_dot_q4_1_32x2(const uint32_t n, float * restrict s0, float
v_sum_float_c1 = hvx_vec_add_f32_f32(v_sum_float_c1, v_sum_scaled_c1);
}
hvx_vec_store_u(s0, valid_rows * sizeof(float), v_sum_float_c0);
hvx_vec_store_u(s1, valid_rows * sizeof(float), v_sum_float_c1);
if (sz0) {
hvx_vec_store_u(s0, valid_rows * sizeof(float), hvx_vec_add_f32_f32(v_sum_float_c0, hvx_vmemu(sz0)));
} else {
hvx_vec_store_u(s0, valid_rows * sizeof(float), v_sum_float_c0);
}
if (sz1) {
hvx_vec_store_u(s1, valid_rows * sizeof(float), hvx_vec_add_f32_f32(v_sum_float_c1, hvx_vmemu(sz1)));
} else {
hvx_vec_store_u(s1, valid_rows * sizeof(float), v_sum_float_c1);
}
}
static void flat_vec_dot_q8_0_32x1(const uint32_t n, float * restrict s, const void * restrict vx, const void * restrict vy, uint32_t valid_rows) {
static void flat_vec_dot_q8_0_32x1(const uint32_t n, float * restrict s, const void * restrict vx, const void * restrict vy, uint32_t valid_rows, const float * restrict sz) {
const uint8_t * restrict tile_ptr = vx;
const uint8_t * restrict y_q = vy;
@@ -620,10 +644,14 @@ static void flat_vec_dot_q8_0_32x1(const uint32_t n, float * restrict s, const v
v_sum_float = hvx_vec_add_f32_f32(v_sum_float, v_sum_scaled);
}
hvx_vec_store_u(s, valid_rows * sizeof(float), v_sum_float);
if (sz) {
hvx_vec_store_u(s, valid_rows * sizeof(float), hvx_vec_add_f32_f32(v_sum_float, hvx_vmemu(sz)));
} else {
hvx_vec_store_u(s, valid_rows * sizeof(float), v_sum_float);
}
}
static void flat_vec_dot_q8_0_32x2(const uint32_t n, float * restrict s0, float * restrict s1, const void * restrict vx, const void * restrict vy0, const void * restrict vy1, uint32_t valid_rows) {
static void flat_vec_dot_q8_0_32x2(const uint32_t n, float * restrict s0, float * restrict s1, const void * restrict vx, const void * restrict vy0, const void * restrict vy1, uint32_t valid_rows, const float * restrict sz0, const float * restrict sz1) {
const uint8_t * restrict tile_ptr = vx;
const uint8_t * restrict y0_q = vy0;
const uint8_t * restrict y1_q = vy1;
@@ -704,11 +732,19 @@ static void flat_vec_dot_q8_0_32x2(const uint32_t n, float * restrict s0, float
v_sum_float_c1 = hvx_vec_add_f32_f32(v_sum_float_c1, v_sum_scaled_c1);
}
hvx_vec_store_u(s0, valid_rows * sizeof(float), v_sum_float_c0);
hvx_vec_store_u(s1, valid_rows * sizeof(float), v_sum_float_c1);
if (sz0) {
hvx_vec_store_u(s0, valid_rows * sizeof(float), hvx_vec_add_f32_f32(v_sum_float_c0, hvx_vmemu(sz0)));
} else {
hvx_vec_store_u(s0, valid_rows * sizeof(float), v_sum_float_c0);
}
if (sz1) {
hvx_vec_store_u(s1, valid_rows * sizeof(float), hvx_vec_add_f32_f32(v_sum_float_c1, hvx_vmemu(sz1)));
} else {
hvx_vec_store_u(s1, valid_rows * sizeof(float), v_sum_float_c1);
}
}
static void flat_vec_dot_iq4nl_32x1(const uint32_t n, float * restrict s, const void * restrict vx, const void * restrict vy, uint32_t valid_rows) {
static void flat_vec_dot_iq4nl_32x1(const uint32_t n, float * restrict s, const void * restrict vx, const void * restrict vy, uint32_t valid_rows, const float * restrict sz) {
const uint8_t * restrict tile_ptr = vx;
const uint8_t * restrict y_q = vy;
@@ -765,10 +801,14 @@ static void flat_vec_dot_iq4nl_32x1(const uint32_t n, float * restrict s, const
v_sum_float = hvx_vec_add_f32_f32(v_sum_float, v_sum_scaled);
}
hvx_vec_store_u(s, valid_rows * sizeof(float), v_sum_float);
if (sz) {
hvx_vec_store_u(s, valid_rows * sizeof(float), hvx_vec_add_f32_f32(v_sum_float, hvx_vmemu(sz)));
} else {
hvx_vec_store_u(s, valid_rows * sizeof(float), v_sum_float);
}
}
static void flat_vec_dot_iq4nl_32x2(const uint32_t n, float * restrict s0, float * restrict s1, const void * restrict vx, const void * restrict vy0, const void * restrict vy1, uint32_t valid_rows) {
static void flat_vec_dot_iq4nl_32x2(const uint32_t n, float * restrict s0, float * restrict s1, const void * restrict vx, const void * restrict vy0, const void * restrict vy1, uint32_t valid_rows, const float * restrict sz0, const float * restrict sz1) {
const uint8_t * restrict tile_ptr = vx;
const uint8_t * restrict y0_q = vy0;
const uint8_t * restrict y1_q = vy1;
@@ -851,11 +891,19 @@ static void flat_vec_dot_iq4nl_32x2(const uint32_t n, float * restrict s0, float
v_sum_float_c1 = hvx_vec_add_f32_f32(v_sum_float_c1, v_sum_scaled_c1);
}
hvx_vec_store_u(s0, valid_rows * sizeof(float), v_sum_float_c0);
hvx_vec_store_u(s1, valid_rows * sizeof(float), v_sum_float_c1);
if (sz0) {
hvx_vec_store_u(s0, valid_rows * sizeof(float), hvx_vec_add_f32_f32(v_sum_float_c0, hvx_vmemu(sz0)));
} else {
hvx_vec_store_u(s0, valid_rows * sizeof(float), v_sum_float_c0);
}
if (sz1) {
hvx_vec_store_u(s1, valid_rows * sizeof(float), hvx_vec_add_f32_f32(v_sum_float_c1, hvx_vmemu(sz1)));
} else {
hvx_vec_store_u(s1, valid_rows * sizeof(float), v_sum_float_c1);
}
}
static void flat_vec_dot_mxfp4_32x1(const uint32_t n, float * restrict s, const void * restrict vx, const void * restrict vy, uint32_t valid_rows) {
static void flat_vec_dot_mxfp4_32x1(const uint32_t n, float * restrict s, const void * restrict vx, const void * restrict vy, uint32_t valid_rows, const float * restrict sz) {
const uint8_t * restrict tile_ptr = vx;
const uint8_t * restrict y_q = vy;
@@ -921,10 +969,14 @@ static void flat_vec_dot_mxfp4_32x1(const uint32_t n, float * restrict s, const
v_sum_float = hvx_vec_mul_f32_f32(v_sum_float, hvx_vec_splat_f32(0.5f));
hvx_vec_store_u(s, valid_rows * sizeof(float), v_sum_float);
if (sz) {
hvx_vec_store_u(s, valid_rows * sizeof(float), hvx_vec_add_f32_f32(v_sum_float, hvx_vmemu(sz)));
} else {
hvx_vec_store_u(s, valid_rows * sizeof(float), v_sum_float);
}
}
static void flat_vec_dot_mxfp4_32x2(const uint32_t n, float * restrict s0, float * restrict s1, const void * restrict vx, const void * restrict vy0, const void * restrict vy1, uint32_t valid_rows) {
static void flat_vec_dot_mxfp4_32x2(const uint32_t n, float * restrict s0, float * restrict s1, const void * restrict vx, const void * restrict vy0, const void * restrict vy1, uint32_t valid_rows, const float * restrict sz0, const float * restrict sz1) {
const uint8_t * restrict tile_ptr = vx;
const uint8_t * restrict y0_q = vy0;
const uint8_t * restrict y1_q = vy1;
@@ -1019,6 +1071,441 @@ static void flat_vec_dot_mxfp4_32x2(const uint32_t n, float * restrict s0, float
v_sum_float_c0 = hvx_vec_mul_f32_f32(v_sum_float_c0, hvx_vec_splat_f32(0.5f));
v_sum_float_c1 = hvx_vec_mul_f32_f32(v_sum_float_c1, hvx_vec_splat_f32(0.5f));
hvx_vec_store_u(s0, valid_rows * sizeof(float), v_sum_float_c0);
hvx_vec_store_u(s1, valid_rows * sizeof(float), v_sum_float_c1);
if (sz0) {
hvx_vec_store_u(s0, valid_rows * sizeof(float), hvx_vec_add_f32_f32(v_sum_float_c0, hvx_vmemu(sz0)));
} else {
hvx_vec_store_u(s0, valid_rows * sizeof(float), v_sum_float_c0);
}
if (sz1) {
hvx_vec_store_u(s1, valid_rows * sizeof(float), hvx_vec_add_f32_f32(v_sum_float_c1, hvx_vmemu(sz1)));
} else {
hvx_vec_store_u(s1, valid_rows * sizeof(float), v_sum_float_c1);
}
}
#if __HVX_ARCH__ < 79
#define HVX_OP_ADD_F32(a, b) Q6_Vsf_equals_Vqf32(Q6_Vqf32_vadd_VsfVsf(a, b))
#define HVX_OP_MUL_F32(a, b) Q6_Vsf_equals_Vqf32(Q6_Vqf32_vmpy_VsfVsf(a, b))
#else
#define HVX_OP_ADD_F32(a, b) Q6_Vsf_vadd_VsfVsf(a, b)
#define HVX_OP_MUL_F32(a, b) Q6_Vsf_vmpy_VsfVsf(a, b)
#endif
static inline void vec_dot_f32_f32_aa_1x1(const uint32_t n, float * restrict s, const void * restrict vx, const void * restrict vy) {
const HVX_Vector * restrict x = (const HVX_Vector *) vx;
const HVX_Vector * restrict y = (const HVX_Vector *) vy;
uint32_t nvec = n / VLEN_FP32; // num full fp32 hvx vectors
uint32_t nloe = n % VLEN_FP32; // leftover elements
HVX_Vector rsum = Q6_V_vzero();
uint32_t i = 0;
#pragma unroll(4)
for (i = 0; i < nvec; i++) {
HVX_Vector prod = HVX_OP_MUL_F32(x[i], y[i]);
rsum = HVX_OP_ADD_F32(rsum, prod);
}
if (nloe) {
HVX_VectorPred bmask = Q6_Q_vsetq_R(nloe * 4);
HVX_Vector x_sf = Q6_V_vand_QV(bmask, x[i]);
HVX_Vector y_sf = Q6_V_vand_QV(bmask, y[i]);
HVX_Vector prod = HVX_OP_MUL_F32(x_sf, y_sf);
rsum = HVX_OP_ADD_F32(rsum, prod);
}
*s = hvx_vec_get_f32(hvx_vec_reduce_sum_f32(rsum));
}
static inline void vec_dot_f32_f32_aa_2x1(const uint32_t n, float * restrict s0,
const void * restrict vx0, const void * restrict vx1,
const void * restrict vy0) {
const HVX_Vector * restrict x0 = (const HVX_Vector *) vx0;
const HVX_Vector * restrict x1 = (const HVX_Vector *) vx1;
const HVX_Vector * restrict y = (const HVX_Vector *) vy0;
uint32_t nvec = n / VLEN_FP32;
uint32_t nloe = n % VLEN_FP32;
HVX_Vector rsum0 = Q6_V_vzero();
HVX_Vector rsum1 = Q6_V_vzero();
uint32_t i = 0;
#pragma unroll(2)
for (i = 0; i < nvec; i++) {
HVX_Vector y_sf = y[i];
HVX_Vector prod0 = HVX_OP_MUL_F32(x0[i], y_sf);
HVX_Vector prod1 = HVX_OP_MUL_F32(x1[i], y_sf);
rsum0 = HVX_OP_ADD_F32(rsum0, prod0);
rsum1 = HVX_OP_ADD_F32(rsum1, prod1);
}
if (nloe) {
HVX_VectorPred bmask = Q6_Q_vsetq_R(nloe * 4);
HVX_Vector y_sf = Q6_V_vand_QV(bmask, y[i]);
HVX_Vector x0_sf = Q6_V_vand_QV(bmask, x0[i]);
HVX_Vector x1_sf = Q6_V_vand_QV(bmask, x1[i]);
HVX_Vector prod0 = HVX_OP_MUL_F32(x0_sf, y_sf);
HVX_Vector prod1 = HVX_OP_MUL_F32(x1_sf, y_sf);
rsum0 = HVX_OP_ADD_F32(rsum0, prod0);
rsum1 = HVX_OP_ADD_F32(rsum1, prod1);
}
HVX_Vector rsum = hvx_vec_reduce_sum_f32x2(rsum0, rsum1);
hvx_vec_store_u(s0, 8, rsum);
}
static inline void vec_dot_f32_f32_aa_2x2(const uint32_t n, float * restrict s0, float * restrict s1,
const void * restrict vx0, const void * restrict vx1,
const void * restrict vy0, const void * restrict vy1) {
const HVX_Vector * restrict x0 = (const HVX_Vector *) vx0;
const HVX_Vector * restrict x1 = (const HVX_Vector *) vx1;
const HVX_Vector * restrict y0 = (const HVX_Vector *) vy0;
const HVX_Vector * restrict y1 = (const HVX_Vector *) vy1;
uint32_t nvec = n / VLEN_FP32;
uint32_t nloe = n % VLEN_FP32;
HVX_Vector r0_c0_sum = Q6_V_vzero();
HVX_Vector r0_c1_sum = Q6_V_vzero();
HVX_Vector r1_c0_sum = Q6_V_vzero();
HVX_Vector r1_c1_sum = Q6_V_vzero();
uint32_t i = 0;
#pragma unroll(2)
for (i = 0; i < nvec; i++) {
HVX_Vector r0_sf = x0[i];
HVX_Vector r1_sf = x1[i];
HVX_Vector c0_sf = y0[i];
HVX_Vector c1_sf = y1[i];
r0_c0_sum = HVX_OP_ADD_F32(r0_c0_sum, HVX_OP_MUL_F32(r0_sf, c0_sf));
r0_c1_sum = HVX_OP_ADD_F32(r0_c1_sum, HVX_OP_MUL_F32(r0_sf, c1_sf));
r1_c0_sum = HVX_OP_ADD_F32(r1_c0_sum, HVX_OP_MUL_F32(r1_sf, c0_sf));
r1_c1_sum = HVX_OP_ADD_F32(r1_c1_sum, HVX_OP_MUL_F32(r1_sf, c1_sf));
}
if (nloe) {
HVX_VectorPred bmask = Q6_Q_vsetq_R(nloe * 4);
HVX_Vector r0_sf = Q6_V_vand_QV(bmask, x0[i]);
HVX_Vector r1_sf = Q6_V_vand_QV(bmask, x1[i]);
HVX_Vector c0_sf = Q6_V_vand_QV(bmask, y0[i]);
HVX_Vector c1_sf = Q6_V_vand_QV(bmask, y1[i]);
r0_c0_sum = HVX_OP_ADD_F32(r0_c0_sum, HVX_OP_MUL_F32(r0_sf, c0_sf));
r0_c1_sum = HVX_OP_ADD_F32(r0_c1_sum, HVX_OP_MUL_F32(r0_sf, c1_sf));
r1_c0_sum = HVX_OP_ADD_F32(r1_c0_sum, HVX_OP_MUL_F32(r1_sf, c0_sf));
r1_c1_sum = HVX_OP_ADD_F32(r1_c1_sum, HVX_OP_MUL_F32(r1_sf, c1_sf));
}
// Reduce and store results
HVX_Vector r0_r1_c0_sum = hvx_vec_reduce_sum_f32x2(r0_c0_sum, r1_c0_sum);
HVX_Vector r0_r1_c1_sum = hvx_vec_reduce_sum_f32x2(r0_c1_sum, r1_c1_sum);
hvx_vec_store_u(s0, 8, r0_r1_c0_sum);
hvx_vec_store_u(s1, 8, r0_r1_c1_sum);
}
static inline void vec_dot_f32_f32_uu_1x1(const uint32_t n, float * restrict s, const void * restrict x, const void * restrict y) {
const HVX_UVector * restrict vx = (const HVX_UVector * restrict) x;
const HVX_UVector * restrict vy = (const HVX_UVector * restrict) y;
uint32_t nvec = n / VLEN_FP32; // num full fp32 hvx vectors
uint32_t nloe = n % VLEN_FP32; // leftover elements
HVX_Vector rsum = Q6_V_vzero();
uint32_t i = 0;
#pragma unroll(2)
for (i = 0; i < nvec; i++) {
HVX_Vector x_sf = vx[i];
HVX_Vector y_sf = vy[i];
rsum = HVX_OP_ADD_F32(rsum, HVX_OP_MUL_F32(x_sf, y_sf));
}
if (nloe) {
HVX_Vector x_sf = vx[i];
HVX_Vector y_sf = vy[i];
HVX_VectorPred bmask = Q6_Q_vsetq_R(nloe * 4);
x_sf = Q6_V_vand_QV(bmask, x_sf);
y_sf = Q6_V_vand_QV(bmask, y_sf);
rsum = HVX_OP_ADD_F32(rsum, HVX_OP_MUL_F32(x_sf, y_sf));
}
rsum = hvx_vec_reduce_sum_f32(rsum);
hvx_vec_store_u(&s[0], 4, rsum);
}
#undef HVX_OP_ADD_F32
#undef HVX_OP_MUL_F32
static inline void vec_dot_f16_f16_aa_1x1(const uint32_t n, float * restrict s, const void * restrict vx, const void * restrict vy) {
const HVX_Vector * restrict x = (const HVX_Vector *) vx;
const HVX_Vector * restrict y = (const HVX_Vector *) vy;
uint32_t nvec = n / VLEN_FP16; // num full fp16 hvx vectors
uint32_t nloe = n % VLEN_FP16; // leftover elements
HVX_VectorPair rsum_p = Q6_W_vzero();
uint32_t i = 0;
#pragma unroll(4)
for (i = 0; i < nvec; i++) {
rsum_p = hvx_vec_mpyacc_f32_f16(rsum_p, x[i], y[i]);
}
if (nloe) {
HVX_VectorPred bmask = Q6_Q_vsetq_R(nloe * 2);
HVX_Vector x_hf = Q6_V_vand_QV(bmask, x[i]);
HVX_Vector y_hf = Q6_V_vand_QV(bmask, y[i]);
rsum_p = hvx_vec_mpyacc_f32_f16(rsum_p, x_hf, y_hf);
}
HVX_Vector rsum = Q6_Vsf_equals_Vqf32(Q6_Vqf32_vadd_VsfVsf(Q6_V_lo_W(rsum_p), Q6_V_hi_W(rsum_p)));
hvx_vec_store_u(s, 4, hvx_vec_reduce_sum_f32(rsum));
}
static inline void vec_dot_f16_f16_aa_2x1(const uint32_t n, float * restrict s0,
const void * restrict vx0, const void * restrict vx1,
const void * restrict vy0) {
const HVX_Vector * restrict x0 = (const HVX_Vector *) vx0;
const HVX_Vector * restrict x1 = (const HVX_Vector *) vx1;
const HVX_Vector * restrict y = (const HVX_Vector *) vy0;
uint32_t nvec = n / VLEN_FP16;
uint32_t nloe = n % VLEN_FP16;
HVX_VectorPair rsum0_p = Q6_W_vzero();
HVX_VectorPair rsum1_p = Q6_W_vzero();
uint32_t i = 0;
#pragma unroll(2)
for (i = 0; i < nvec; i++) {
HVX_Vector y_hf = y[i];
rsum0_p = hvx_vec_mpyacc_f32_f16(rsum0_p, x0[i], y_hf);
rsum1_p = hvx_vec_mpyacc_f32_f16(rsum1_p, x1[i], y_hf);
}
if (nloe) {
HVX_VectorPred bmask = Q6_Q_vsetq_R(nloe * 2);
HVX_Vector y_hf = Q6_V_vand_QV(bmask, y[i]);
HVX_Vector x0_hf = Q6_V_vand_QV(bmask, x0[i]);
HVX_Vector x1_hf = Q6_V_vand_QV(bmask, x1[i]);
rsum0_p = hvx_vec_mpyacc_f32_f16(rsum0_p, x0_hf, y_hf);
rsum1_p = hvx_vec_mpyacc_f32_f16(rsum1_p, x1_hf, y_hf);
}
HVX_Vector rsum0 = Q6_Vsf_equals_Vqf32(Q6_Vqf32_vadd_VsfVsf(Q6_V_lo_W(rsum0_p), Q6_V_hi_W(rsum0_p)));
HVX_Vector rsum1 = Q6_Vsf_equals_Vqf32(Q6_Vqf32_vadd_VsfVsf(Q6_V_lo_W(rsum1_p), Q6_V_hi_W(rsum1_p)));
HVX_Vector rsum = hvx_vec_reduce_sum_f32x2(rsum0, rsum1);
hvx_vec_store_u(s0, 8, rsum);
}
static inline void vec_dot_f16_f16_aa_2x2(const uint32_t n, float * restrict s0, float * restrict s1,
const void * restrict vx0, const void * restrict vx1,
const void * restrict vy0, const void * restrict vy1) {
const HVX_Vector * restrict x0 = (const HVX_Vector *) vx0;
const HVX_Vector * restrict x1 = (const HVX_Vector *) vx1;
const HVX_Vector * restrict y0 = (const HVX_Vector *) vy0;
const HVX_Vector * restrict y1 = (const HVX_Vector *) vy1;
uint32_t nvec = n / VLEN_FP16;
uint32_t nloe = n % VLEN_FP16;
// Row sums (sf) - 4 accumulators for 2x2 tile
HVX_VectorPair r0_c0_sum_p = Q6_W_vzero();
HVX_VectorPair r0_c1_sum_p = Q6_W_vzero();
HVX_VectorPair r1_c0_sum_p = Q6_W_vzero();
HVX_VectorPair r1_c1_sum_p = Q6_W_vzero();
uint32_t i = 0;
#pragma unroll(2)
for (i = 0; i < nvec; i++) {
HVX_Vector r0_hf = x0[i];
HVX_Vector r1_hf = x1[i];
HVX_Vector c0_hf = y0[i];
HVX_Vector c1_hf = y1[i];
// Compute 4 dot products: r0xc0, r0xc1, r1xc0, r1xc1
r0_c0_sum_p = hvx_vec_mpyacc_f32_f16(r0_c0_sum_p, r0_hf, c0_hf);
r0_c1_sum_p = hvx_vec_mpyacc_f32_f16(r0_c1_sum_p, r0_hf, c1_hf);
r1_c0_sum_p = hvx_vec_mpyacc_f32_f16(r1_c0_sum_p, r1_hf, c0_hf);
r1_c1_sum_p = hvx_vec_mpyacc_f32_f16(r1_c1_sum_p, r1_hf, c1_hf);
}
if (nloe) {
HVX_VectorPred bmask = Q6_Q_vsetq_R(nloe * 2);
HVX_Vector r0_hf = Q6_V_vand_QV(bmask, x0[i]);
HVX_Vector r1_hf = Q6_V_vand_QV(bmask, x1[i]);
HVX_Vector c0_hf = Q6_V_vand_QV(bmask, y0[i]);
HVX_Vector c1_hf = Q6_V_vand_QV(bmask, y1[i]);
r0_c0_sum_p = hvx_vec_mpyacc_f32_f16(r0_c0_sum_p, r0_hf, c0_hf);
r0_c1_sum_p = hvx_vec_mpyacc_f32_f16(r0_c1_sum_p, r0_hf, c1_hf);
r1_c0_sum_p = hvx_vec_mpyacc_f32_f16(r1_c0_sum_p, r1_hf, c0_hf);
r1_c1_sum_p = hvx_vec_mpyacc_f32_f16(r1_c1_sum_p, r1_hf, c1_hf);
}
HVX_Vector r0_c0_sum = Q6_Vsf_equals_Vqf32(Q6_Vqf32_vadd_VsfVsf(Q6_V_lo_W(r0_c0_sum_p), Q6_V_hi_W(r0_c0_sum_p)));
HVX_Vector r0_c1_sum = Q6_Vsf_equals_Vqf32(Q6_Vqf32_vadd_VsfVsf(Q6_V_lo_W(r0_c1_sum_p), Q6_V_hi_W(r0_c1_sum_p)));
HVX_Vector r1_c0_sum = Q6_Vsf_equals_Vqf32(Q6_Vqf32_vadd_VsfVsf(Q6_V_lo_W(r1_c0_sum_p), Q6_V_hi_W(r1_c0_sum_p)));
HVX_Vector r1_c1_sum = Q6_Vsf_equals_Vqf32(Q6_Vqf32_vadd_VsfVsf(Q6_V_lo_W(r1_c1_sum_p), Q6_V_hi_W(r1_c1_sum_p)));
// Reduce and store results
HVX_Vector r0_r1_c0_sum = hvx_vec_reduce_sum_f32x2(r0_c0_sum, r1_c0_sum);
HVX_Vector r0_r1_c1_sum = hvx_vec_reduce_sum_f32x2(r0_c1_sum, r1_c1_sum);
hvx_vec_store_u(&s0[0], 8, r0_r1_c0_sum); // row0,col0 row1,col0
hvx_vec_store_u(&s1[0], 8, r0_r1_c1_sum); // row0,col1 row1,col1
}
static inline void vec_dot_f16_f16_uu_1x1(const uint32_t n, float * restrict s, const void * restrict vx, const void * restrict vy) {
const HVX_UVector * restrict x = (const HVX_UVector *) vx;
const HVX_UVector * restrict y = (const HVX_UVector *) vy;
uint32_t nvec = n / VLEN_FP16; // num full fp16 hvx vectors
uint32_t nloe = n % VLEN_FP16; // leftover elements
HVX_Vector rsum = Q6_V_vzero();
uint32_t i = 0;
#pragma unroll(4)
for (i = 0; i < nvec; i++) {
HVX_VectorPair xy_qf = Q6_Wqf32_vmpy_VhfVhf(x[i], y[i]);
rsum = Q6_Vqf32_vadd_Vqf32Vqf32(rsum, Q6_Vqf32_vadd_Vqf32Vqf32(Q6_V_lo_W(xy_qf), Q6_V_hi_W(xy_qf)));
}
if (nloe) {
HVX_VectorPred bmask = Q6_Q_vsetq_R(nloe * 2);
HVX_Vector x_hf = Q6_V_vand_QV(bmask, x[i]);
HVX_Vector y_hf = Q6_V_vand_QV(bmask, y[i]);
HVX_VectorPair xy_qf = Q6_Wqf32_vmpy_VhfVhf(x_hf, y_hf);
rsum = Q6_Vqf32_vadd_Vqf32Vqf32(rsum, Q6_Vqf32_vadd_Vqf32Vqf32(Q6_V_lo_W(xy_qf), Q6_V_hi_W(xy_qf)));
}
rsum = hvx_vec_reduce_sum_f32(Q6_Vsf_equals_Vqf32(rsum));
hvx_vec_store_u(&s[0], 4, rsum);
}
static inline void vec_dot_f16_f32_uu_1x1(const uint32_t n, float * restrict s, const void * restrict x, const void * restrict y) {
const HVX_UVector * restrict vx = (const HVX_UVector * restrict) x;
const HVX_UVector * restrict vy = (const HVX_UVector * restrict) y;
uint32_t nvec = n / VLEN_FP16; // num full fp16 hvx vectors
uint32_t nloe = n % VLEN_FP16; // leftover elements
const HVX_Vector zero = Q6_V_vzero();
HVX_Vector rsum = Q6_V_vzero();
uint32_t i = 0;
#pragma unroll(2)
for (i = 0; i < nvec; i++) {
// Load y (fp32) and convert into fp16
HVX_Vector y0_qf = Q6_Vqf32_vsub_VsfVsf(vy[i*2+0], zero); // 32 elements
HVX_Vector y1_qf = Q6_Vqf32_vsub_VsfVsf(vy[i*2+1], zero); // 32 elements
HVX_Vector y_hf = Q6_Vh_vdeal_Vh(Q6_Vhf_equals_Wqf32(Q6_W_vcombine_VV(y1_qf, y0_qf)));
// Load x (fp16)
HVX_Vector x_hf = vx[i];
HVX_VectorPair xy_qf = Q6_Wqf32_vmpy_VhfVhf(x_hf, y_hf);
rsum = Q6_Vqf32_vadd_Vqf32Vqf32(rsum, Q6_Vqf32_vadd_Vqf32Vqf32(Q6_V_lo_W(xy_qf), Q6_V_hi_W(xy_qf)));
}
if (nloe) {
// Load y (fp32) and convert into fp16
HVX_Vector y0_qf = Q6_Vqf32_vsub_VsfVsf(vy[i*2+0], zero); // 32 elements
HVX_Vector y1_qf = Q6_Vqf32_vsub_VsfVsf(vy[i*2+1], zero); // 32 elements
HVX_Vector y_hf = Q6_Vh_vdeal_Vh(Q6_Vhf_equals_Wqf32(Q6_W_vcombine_VV(y1_qf, y0_qf)));
// Load x (fp16)
HVX_Vector x_hf = vx[i];
// Zero-out unused elements
// Note that we need to clear both x and y because they may contain NANs
HVX_VectorPred bmask = Q6_Q_vsetq_R(nloe * 2);
x_hf = Q6_V_vand_QV(bmask, x_hf);
y_hf = Q6_V_vand_QV(bmask, y_hf);
HVX_VectorPair xy_qf = Q6_Wqf32_vmpy_VhfVhf(x_hf, y_hf);
rsum = Q6_Vqf32_vadd_Vqf32Vqf32(rsum, Q6_Vqf32_vadd_Vqf32Vqf32(Q6_V_lo_W(xy_qf), Q6_V_hi_W(xy_qf)));
}
// Convert into fp32 and reduce
rsum = hvx_vec_reduce_sum_f32(Q6_Vsf_equals_Vqf32(rsum));
hvx_vec_store_u(&s[0], 4, rsum);
}
static inline void hvx_tensor_add_f32_grid(
const struct htp_tensor * restrict dst,
const struct htp_tensor * restrict src2,
uint32_t start_row,
uint32_t end_row,
uint32_t start_col,
uint32_t end_col,
const struct fastdiv_values * div_ne11_12,
const struct fastdiv_values * div_ne11
) {
if (start_row >= end_row || start_col >= end_col) return;
const uint32_t nb1 = dst->nb[1]; // row stride in bytes
const uint32_t ne11 = dst->ne[1];
const uint32_t ne12 = dst->ne[2];
const uint32_t ne11_12 = ne11 * ne12;
const bool is_broadcast1 = (src2->ne[1] == 1);
const bool is_broadcast2 = (src2->ne[2] == 1);
const bool is_broadcast3 = (src2->ne[3] == 1);
for (uint32_t r = start_row; r < end_row; r++) {
float * dst_row = (float *) ((uint8_t *) dst->data + r * nb1);
uint32_t i13 = fastdiv(r, div_ne11_12);
uint32_t i12 = fastdiv(r - i13 * ne11_12, div_ne11);
uint32_t i11 = r - i13 * ne11_12 - i12 * ne11;
uint32_t i23 = is_broadcast3 ? 0 : i13;
uint32_t i22 = is_broadcast2 ? 0 : i12;
uint32_t i21 = is_broadcast1 ? 0 : i11;
const float * src2_row = (const float *) ((const uint8_t *) src2->data +
i21 * src2->nb[1] + i22 * src2->nb[2] + i23 * src2->nb[3]);
float * dst_ptr = &dst_row[start_col];
const float * src2_ptr = &src2_row[start_col];
int remaining = end_col - start_col;
while (remaining >= 32) {
HVX_Vector v_out = hvx_vmemu(dst_ptr);
HVX_Vector v_z = hvx_vmemu(src2_ptr);
hvx_vmemu(dst_ptr) = hvx_vec_add_f32_f32(v_out, v_z);
dst_ptr += 32;
src2_ptr += 32;
remaining -= 32;
}
if (remaining > 0) {
HVX_Vector v_out = hvx_vmemu(dst_ptr);
HVX_Vector v_z = hvx_vmemu(src2_ptr);
hvx_vec_store_u(dst_ptr, remaining * sizeof(float), hvx_vec_add_f32_f32(v_out, v_z));
}
}
}
@@ -378,7 +378,7 @@ static inline HVX_VectorPair accum_q8_0_32x2(
return Q6_W_vcombine_VV(v_sum1, v_sum0);
}
static void tiled_vec_dot_q4_0_32x1(const uint32_t n, float * restrict s, const void * restrict vx, const void * restrict vy, uint32_t valid_rows) {
static void tiled_vec_dot_q4_0_32x1(const uint32_t n, float * restrict s, const void * restrict vx, const void * restrict vy, uint32_t valid_rows, const float * restrict sz) {
const uint8_t * restrict tile_ptr = vx;
const uint8_t * restrict y_q = vy;
@@ -401,10 +401,14 @@ static void tiled_vec_dot_q4_0_32x1(const uint32_t n, float * restrict s, const
v_sum_float = hvx_vec_add_f32_f32(v_sum_float, v_sum_scaled);
}
hvx_vec_store_u(s, valid_rows * sizeof(float), v_sum_float);
if (sz) {
hvx_vec_store_u(s, valid_rows * sizeof(float), hvx_vec_add_f32_f32(v_sum_float, hvx_vmemu(sz)));
} else {
hvx_vec_store_u(s, valid_rows * sizeof(float), v_sum_float);
}
}
static void tiled_vec_dot_q4_0_32x2(const uint32_t n, float * restrict s0, float * restrict s1, const void * restrict vx, const void * restrict vy0, const void * restrict vy1, uint32_t valid_rows) {
static void tiled_vec_dot_q4_0_32x2(const uint32_t n, float * restrict s0, float * restrict s1, const void * restrict vx, const void * restrict vy0, const void * restrict vy1, uint32_t valid_rows, const float * restrict sz0, const float * restrict sz1) {
const uint8_t * restrict tile_ptr = vx;
const uint8_t * restrict y0_q = vy0;
const uint8_t * restrict y1_q = vy1;
@@ -484,11 +488,19 @@ static void tiled_vec_dot_q4_0_32x2(const uint32_t n, float * restrict s0, float
v_sum_float_c1 = hvx_vec_add_f32_f32(v_sum_float_c1, v_sum_scaled_c1);
}
hvx_vec_store_u(s0, valid_rows * sizeof(float), v_sum_float_c0);
hvx_vec_store_u(s1, valid_rows * sizeof(float), v_sum_float_c1);
if (sz0) {
hvx_vec_store_u(s0, valid_rows * sizeof(float), hvx_vec_add_f32_f32(v_sum_float_c0, hvx_vmemu(sz0)));
} else {
hvx_vec_store_u(s0, valid_rows * sizeof(float), v_sum_float_c0);
}
if (sz1) {
hvx_vec_store_u(s1, valid_rows * sizeof(float), hvx_vec_add_f32_f32(v_sum_float_c1, hvx_vmemu(sz1)));
} else {
hvx_vec_store_u(s1, valid_rows * sizeof(float), v_sum_float_c1);
}
}
static void tiled_vec_dot_q4_1_32x1(const uint32_t n, float * restrict s, const void * restrict vx, const void * restrict vy, uint32_t valid_rows) {
static void tiled_vec_dot_q4_1_32x1(const uint32_t n, float * restrict s, const void * restrict vx, const void * restrict vy, uint32_t valid_rows, const float * restrict sz) {
const uint8_t * restrict tile_ptr = vx;
const uint8_t * restrict y_q = vy;
@@ -519,10 +531,14 @@ static void tiled_vec_dot_q4_1_32x1(const uint32_t n, float * restrict s, const
v_sum_float = hvx_vec_add_f32_f32(v_sum_float, v_sum_scaled);
}
hvx_vec_store_u(s, valid_rows * sizeof(float), v_sum_float);
if (sz) {
hvx_vec_store_u(s, valid_rows * sizeof(float), hvx_vec_add_f32_f32(v_sum_float, hvx_vmemu(sz)));
} else {
hvx_vec_store_u(s, valid_rows * sizeof(float), v_sum_float);
}
}
static void tiled_vec_dot_q4_1_32x2(const uint32_t n, float * restrict s0, float * restrict s1, const void * restrict vx, const void * restrict vy0, const void * restrict vy1, uint32_t valid_rows) {
static void tiled_vec_dot_q4_1_32x2(const uint32_t n, float * restrict s0, float * restrict s1, const void * restrict vx, const void * restrict vy0, const void * restrict vy1, uint32_t valid_rows, const float * restrict sz0, const float * restrict sz1) {
const uint8_t * restrict tile_ptr = vx;
const uint8_t * restrict y0_q = vy0;
const uint8_t * restrict y1_q = vy1;
@@ -637,11 +653,19 @@ static void tiled_vec_dot_q4_1_32x2(const uint32_t n, float * restrict s0, float
v_sum_float_c1 = hvx_vec_add_f32_f32(v_sum_float_c1, v_sum_scaled_c1);
}
hvx_vec_store_u(s0, valid_rows * sizeof(float), v_sum_float_c0);
hvx_vec_store_u(s1, valid_rows * sizeof(float), v_sum_float_c1);
if (sz0) {
hvx_vec_store_u(s0, valid_rows * sizeof(float), hvx_vec_add_f32_f32(v_sum_float_c0, hvx_vmemu(sz0)));
} else {
hvx_vec_store_u(s0, valid_rows * sizeof(float), v_sum_float_c0);
}
if (sz1) {
hvx_vec_store_u(s1, valid_rows * sizeof(float), hvx_vec_add_f32_f32(v_sum_float_c1, hvx_vmemu(sz1)));
} else {
hvx_vec_store_u(s1, valid_rows * sizeof(float), v_sum_float_c1);
}
}
static void tiled_vec_dot_q8_0_32x1(const uint32_t n, float * restrict s, const void * restrict vx, const void * restrict vy, uint32_t valid_rows) {
static void tiled_vec_dot_q8_0_32x1(const uint32_t n, float * restrict s, const void * restrict vx, const void * restrict vy, uint32_t valid_rows, const float * restrict sz) {
const uint8_t * restrict tile_ptr = vx;
const uint8_t * restrict y_q = vy;
@@ -663,10 +687,14 @@ static void tiled_vec_dot_q8_0_32x1(const uint32_t n, float * restrict s, const
v_sum_float = hvx_vec_add_f32_f32(v_sum_float, v_sum_scaled);
}
hvx_vec_store_u(s, valid_rows * sizeof(float), v_sum_float);
if (sz) {
hvx_vec_store_u(s, valid_rows * sizeof(float), hvx_vec_add_f32_f32(v_sum_float, hvx_vmemu(sz)));
} else {
hvx_vec_store_u(s, valid_rows * sizeof(float), v_sum_float);
}
}
static void tiled_vec_dot_q8_0_32x2(const uint32_t n, float * restrict s0, float * restrict s1, const void * restrict vx, const void * restrict vy0, const void * restrict vy1, uint32_t valid_rows) {
static void tiled_vec_dot_q8_0_32x2(const uint32_t n, float * restrict s0, float * restrict s1, const void * restrict vx, const void * restrict vy0, const void * restrict vy1, uint32_t valid_rows, const float * restrict sz0, const float * restrict sz1) {
const uint8_t * restrict tile_ptr = vx;
const uint8_t * restrict y0_q = vy0;
const uint8_t * restrict y1_q = vy1;
@@ -745,11 +773,19 @@ static void tiled_vec_dot_q8_0_32x2(const uint32_t n, float * restrict s0, float
v_sum_float_c1 = hvx_vec_add_f32_f32(v_sum_float_c1, v_sum_scaled_c1);
}
hvx_vec_store_u(s0, valid_rows * sizeof(float), v_sum_float_c0);
hvx_vec_store_u(s1, valid_rows * sizeof(float), v_sum_float_c1);
if (sz0) {
hvx_vec_store_u(s0, valid_rows * sizeof(float), hvx_vec_add_f32_f32(v_sum_float_c0, hvx_vmemu(sz0)));
} else {
hvx_vec_store_u(s0, valid_rows * sizeof(float), v_sum_float_c0);
}
if (sz1) {
hvx_vec_store_u(s1, valid_rows * sizeof(float), hvx_vec_add_f32_f32(v_sum_float_c1, hvx_vmemu(sz1)));
} else {
hvx_vec_store_u(s1, valid_rows * sizeof(float), v_sum_float_c1);
}
}
static void tiled_vec_dot_iq4nl_32x1(const uint32_t n, float * restrict s, const void * restrict vx, const void * restrict vy, uint32_t valid_rows) {
static void tiled_vec_dot_iq4nl_32x1(const uint32_t n, float * restrict s, const void * restrict vx, const void * restrict vy, uint32_t valid_rows, const float * restrict sz) {
const uint8_t * restrict tile_ptr = vx;
const uint8_t * restrict y_q = vy;
@@ -773,10 +809,14 @@ static void tiled_vec_dot_iq4nl_32x1(const uint32_t n, float * restrict s, const
v_sum_float = hvx_vec_add_f32_f32(v_sum_float, v_sum_scaled);
}
hvx_vec_store_u(s, valid_rows * sizeof(float), v_sum_float);
if (sz) {
hvx_vec_store_u(s, valid_rows * sizeof(float), hvx_vec_add_f32_f32(v_sum_float, hvx_vmemu(sz)));
} else {
hvx_vec_store_u(s, valid_rows * sizeof(float), v_sum_float);
}
}
static void tiled_vec_dot_iq4nl_32x2(const uint32_t n, float * restrict s0, float * restrict s1, const void * restrict vx, const void * restrict vy0, const void * restrict vy1, uint32_t valid_rows) {
static void tiled_vec_dot_iq4nl_32x2(const uint32_t n, float * restrict s0, float * restrict s1, const void * restrict vx, const void * restrict vy0, const void * restrict vy1, uint32_t valid_rows, const float * restrict sz0, const float * restrict sz1) {
const uint8_t * restrict tile_ptr = vx;
const uint8_t * restrict y0_q = vy0;
const uint8_t * restrict y1_q = vy1;
@@ -857,11 +897,19 @@ static void tiled_vec_dot_iq4nl_32x2(const uint32_t n, float * restrict s0, floa
v_sum_float_c1 = hvx_vec_add_f32_f32(v_sum_float_c1, v_sum_scaled_c1);
}
hvx_vec_store_u(s0, valid_rows * sizeof(float), v_sum_float_c0);
hvx_vec_store_u(s1, valid_rows * sizeof(float), v_sum_float_c1);
if (sz0) {
hvx_vec_store_u(s0, valid_rows * sizeof(float), hvx_vec_add_f32_f32(v_sum_float_c0, hvx_vmemu(sz0)));
} else {
hvx_vec_store_u(s0, valid_rows * sizeof(float), v_sum_float_c0);
}
if (sz1) {
hvx_vec_store_u(s1, valid_rows * sizeof(float), hvx_vec_add_f32_f32(v_sum_float_c1, hvx_vmemu(sz1)));
} else {
hvx_vec_store_u(s1, valid_rows * sizeof(float), v_sum_float_c1);
}
}
static void tiled_vec_dot_mxfp4_32x1(const uint32_t n, float * restrict s, const void * restrict vx, const void * restrict vy, uint32_t valid_rows) {
static void tiled_vec_dot_mxfp4_32x1(const uint32_t n, float * restrict s, const void * restrict vx, const void * restrict vy, uint32_t valid_rows, const float * restrict sz) {
const uint8_t * restrict tile_ptr = vx;
const uint8_t * restrict y_q = vy;
@@ -896,10 +944,14 @@ static void tiled_vec_dot_mxfp4_32x1(const uint32_t n, float * restrict s, const
v_sum_float = hvx_vec_mul_f32_f32(v_sum_float, hvx_vec_splat_f32(0.5f));
hvx_vec_store_u(s, valid_rows * sizeof(float), v_sum_float);
if (sz) {
hvx_vec_store_u(s, valid_rows * sizeof(float), hvx_vec_add_f32_f32(v_sum_float, hvx_vmemu(sz)));
} else {
hvx_vec_store_u(s, valid_rows * sizeof(float), v_sum_float);
}
}
static void tiled_vec_dot_mxfp4_32x2(const uint32_t n, float * restrict s0, float * restrict s1, const void * restrict vx, const void * restrict vy0, const void * restrict vy1, uint32_t valid_rows) {
static void tiled_vec_dot_mxfp4_32x2(const uint32_t n, float * restrict s0, float * restrict s1, const void * restrict vx, const void * restrict vy0, const void * restrict vy1, uint32_t valid_rows, const float * restrict sz0, const float * restrict sz1) {
const uint8_t * restrict tile_ptr = vx;
const uint8_t * restrict y0_q = vy0;
const uint8_t * restrict y1_q = vy1;
@@ -1013,8 +1065,16 @@ static void tiled_vec_dot_mxfp4_32x2(const uint32_t n, float * restrict s0, floa
v_sum_float_c0 = hvx_vec_mul_f32_f32(v_sum_float_c0, hvx_vec_splat_f32(0.5f));
v_sum_float_c1 = hvx_vec_mul_f32_f32(v_sum_float_c1, hvx_vec_splat_f32(0.5f));
hvx_vec_store_u(s0, valid_rows * sizeof(float), v_sum_float_c0);
hvx_vec_store_u(s1, valid_rows * sizeof(float), v_sum_float_c1);
if (sz0) {
hvx_vec_store_u(s0, valid_rows * sizeof(float), hvx_vec_add_f32_f32(v_sum_float_c0, hvx_vmemu(sz0)));
} else {
hvx_vec_store_u(s0, valid_rows * sizeof(float), v_sum_float_c0);
}
if (sz1) {
hvx_vec_store_u(s1, valid_rows * sizeof(float), hvx_vec_add_f32_f32(v_sum_float_c1, hvx_vmemu(sz1)));
} else {
hvx_vec_store_u(s1, valid_rows * sizeof(float), v_sum_float_c1);
}
}
static inline void quantize_f32_q8_0_tiled_kernel(
+39
View File
@@ -3,6 +3,7 @@
#include "hvx-base.h"
#include "hvx-inverse.h"
#include "hvx-exp.h"
#define FAST_SIGMOID_LOG2F (0x3fb8aa3b) // 1.442695022
#define FAST_SIGMOID_C1 (0x3d009076) // 0.03138777
@@ -139,4 +140,42 @@ static inline void hvx_tanh_f32_aa(uint8_t * restrict dst, const uint8_t * restr
hvx_tanh_loop_body(HVX_Vector, HVX_Vector, hvx_vec_store_a);
}
static inline HVX_Vector hvx_vec_fast_sigmoid_f16(HVX_Vector x_v) {
const HVX_Vector v_one = hvx_vec_splat_f16(1.0f);
const HVX_Vector v_neg_log2e = hvx_vec_splat_f16(-EXP_LOG2E_F);
const HVX_Vector em_mask = Q6_Vh_vsplat_R(0x7FFF);
// Compute absolute value of x_v
HVX_Vector abs_x = Q6_V_vand_VV(x_v, em_mask);
// Compute u = -abs_x * log2(e) <= 0.
HVX_Vector u = hvx_vec_mul_f16_f16(abs_x, v_neg_log2e);
// Clamp input to prevent underflow in exp2
const HVX_Vector v_clamp_min = hvx_vec_splat_f16(-24.0f);
u = Q6_Vhf_vmax_VhfVhf(v_clamp_min, u);
HVX_Vector exp_val = hvx_vec_exp2_f16(u);
HVX_Vector denom = hvx_vec_add_f16_f16(v_one, exp_val);
HVX_Vector sig_abs = hvx_vec_inverse_f16(denom);
// check if x_v < 0 (using integer comparison on absolute value)
HVX_VectorPred is_neg = Q6_Q_vcmp_gt_VhVh(abs_x, x_v);
// If x_v < 0, return 1.0f - sig_abs
HVX_Vector sig_neg = Q6_Vhf_equals_Vqf16(Q6_Vqf16_vsub_VhfVhf(v_one, sig_abs));
return Q6_V_vmux_QVV(is_neg, sig_neg, sig_abs);
}
static inline HVX_Vector hvx_vec_tanh_f16(HVX_Vector x) {
// tanh(x) = 2 * sigmoid(2x) - 1
const HVX_Vector v_two = hvx_vec_splat_f16(2.0f);
HVX_Vector x2 = hvx_vec_mul_f16_f16(x, v_two);
HVX_Vector sig2x = hvx_vec_fast_sigmoid_f16(x2);
const HVX_Vector v_neg_one = hvx_vec_splat_f16(-1.0f);
return hvx_vec_add_f16_f16(hvx_vec_mul_f16_f16(sig2x, v_two), v_neg_one);
}
#endif /* HVX_SIGMOID_H */
+1
View File
@@ -575,6 +575,7 @@ static inline void profile_stop(uint32_t mode, struct profile_data * d) {
static int execute_op(struct htp_ops_context * octx) {
switch (octx->op) {
case HTP_OP_MUL_MAT:
case HTP_OP_MUL_MAT_ADD:
return op_matmul(octx);
case HTP_OP_MUL_MAT_ID:
File diff suppressed because it is too large Load Diff
+19 -32
View File
@@ -392,56 +392,49 @@ static inline size_t htp_mm_hvx_get_vtcm_sizes(
case HTP_MM_KERNEL_HVX_QUANT_ROW: {
size_t q_src1_row_size = (wtype == HTP_TYPE_Q4_1) ? htp_mm_q8_1_tiled_row_size(ne10) : htp_mm_q8_0_tiled_row_size(ne10);
vtcm_dst_size = dst_nrows > 0 ? htp_mm_round_up(dst_row_size, 128) : 0;
vtcm_src0_size = htp_mm_round_up(n_prefetch * src0_row_size_padded, 256);
vtcm_src1_size = htp_mm_round_up(q_src1_row_size * src1_nrows, 256);
// src0 spad is also used in dynamic quantizer to store padded src1 rows
size_t src1_row_size_padded = htp_mm_round_up(q_src1_row_size, QK_Q8_0_TILED * sizeof(float));
if (vtcm_src0_size < src1_row_size_padded) {
vtcm_src0_size = src1_row_size_padded;
}
vtcm_src0_size = vtcm_src0_size * n_threads;
vtcm_dst_size = vtcm_dst_size * n_threads;
if (is_repack) {
uint32_t aligned_tile_size = htp_mm_get_weight_aligned_tile_size(wtype);
uint32_t n_k_tiles = ne10 / 32;
uint32_t tile_row_size = n_k_tiles * aligned_tile_size;
size_t repacked_vtcm_size = htp_mm_round_up(n_prefetch * tile_row_size, 256);
if (repacked_vtcm_size < src1_row_size_padded) {
repacked_vtcm_size = src1_row_size_padded;
}
vtcm_src0_size = repacked_vtcm_size * n_threads;
}
size_t quant_scratch_size_per_thread = htp_mm_round_up(ne10 * sizeof(float), QK_Q8_0_TILED * sizeof(float));
size_t dst_size_per_thread = dst_nrows > 0 ? htp_mm_round_up(dst_row_size, 128) : 0;
if (dst_size_per_thread < quant_scratch_size_per_thread) {
dst_size_per_thread = quant_scratch_size_per_thread;
}
vtcm_dst_size = dst_size_per_thread * n_threads;
break;
}
case HTP_MM_KERNEL_HVX_QUANT_ROW_FLAT: {
size_t q_src1_row_size = (wtype == HTP_TYPE_Q4_1) ? htp_mm_q8_1_flat_row_size(ne10) : htp_mm_q8_0_flat_row_size(ne10);
vtcm_dst_size = dst_nrows > 0 ? htp_mm_round_up(dst_row_size, 128) : 0;
vtcm_src0_size = htp_mm_round_up(n_prefetch * src0_row_size_padded, 256);
vtcm_src1_size = htp_mm_round_up(q_src1_row_size * src1_nrows, 256);
size_t src1_row_size_padded = htp_mm_round_up(q_src1_row_size, 256);
if (vtcm_src0_size < src1_row_size_padded) {
vtcm_src0_size = src1_row_size_padded;
}
vtcm_src0_size = vtcm_src0_size * n_threads;
vtcm_dst_size = vtcm_dst_size * n_threads;
if (is_repack) {
uint32_t aligned_tile_size = htp_mm_get_weight_aligned_tile_size(wtype);
uint32_t n_k_tiles = ne10 / 32;
uint32_t tile_row_size = n_k_tiles * aligned_tile_size;
size_t repacked_vtcm_size = htp_mm_round_up(n_prefetch * tile_row_size, 256);
if (repacked_vtcm_size < src1_row_size_padded) {
repacked_vtcm_size = src1_row_size_padded;
}
vtcm_src0_size = repacked_vtcm_size * n_threads;
}
size_t quant_scratch_size_per_thread = htp_mm_round_up(ne10 * sizeof(float), QK_Q8_0_TILED * sizeof(float));
size_t dst_size_per_thread = dst_nrows > 0 ? htp_mm_round_up(dst_row_size, 128) : 0;
if (dst_size_per_thread < quant_scratch_size_per_thread) {
dst_size_per_thread = quant_scratch_size_per_thread;
}
vtcm_dst_size = dst_size_per_thread * n_threads;
break;
}
default:
@@ -463,7 +456,8 @@ static inline size_t htp_mm_hvx_id_get_vtcm_sizes(
size_t src0_row_size, // nb01
uint32_t n_prefetch,
size_t * vtcm_src0_size_out,
size_t * vtcm_src1_size_out
size_t * vtcm_src1_size_out,
size_t * vtcm_dst_size_out
) {
const bool is_repack = (wtype == HTP_TYPE_Q4_0 || wtype == HTP_TYPE_Q4_1 ||
wtype == HTP_TYPE_Q8_0 || wtype == HTP_TYPE_IQ4_NL ||
@@ -476,29 +470,22 @@ static inline size_t htp_mm_hvx_id_get_vtcm_sizes(
size_t src0_sz_per_thread = htp_mm_round_up(n_prefetch * src0_row_size_padded, 256);
size_t src1_sz = htp_mm_round_up(src1_row_size * src1_nrows, 256);
// src0 spad also holds temporary transposed src1 columns during dynamic quantization.
const size_t src1_row_size_padded = htp_mm_round_up(src1_row_size, QK_Q8_0_TILED * sizeof(float));
if (src0_sz_per_thread < src1_row_size_padded) {
src0_sz_per_thread = src1_row_size_padded;
}
if (is_repack) {
const uint32_t aligned_tile_size = htp_mm_get_weight_aligned_tile_size(wtype);
const uint32_t n_k_tiles = ne10 / 32;
const uint32_t tile_row_size = n_k_tiles * aligned_tile_size;
size_t repacked_vtcm_size = htp_mm_round_up(n_prefetch * tile_row_size, 256);
if (repacked_vtcm_size < src1_row_size_padded) {
repacked_vtcm_size = src1_row_size_padded;
}
src0_sz_per_thread = repacked_vtcm_size;
}
const size_t vtcm_src0_size = src0_sz_per_thread * n_threads;
const size_t vtcm_dst_size = htp_mm_round_up(ne10 * sizeof(float), QK_Q8_0_TILED * sizeof(float)) * n_threads;
*vtcm_src0_size_out = vtcm_src0_size;
*vtcm_src1_size_out = src1_sz;
*vtcm_dst_size_out = vtcm_dst_size;
return vtcm_src0_size + src1_sz;
return vtcm_src0_size + src1_sz + vtcm_dst_size;
}
#ifdef __cplusplus
+5
View File
@@ -31,6 +31,11 @@ if (GGML_OPENCL_EMBED_KERNELS)
target_include_directories(${TARGET_NAME} PRIVATE "${CMAKE_CURRENT_BINARY_DIR}/autogenerated")
endif ()
if (GGML_OPENCL_USE_ADRENO_BIN_KERNELS)
message(STATUS "OpenCL will use precompiled binary kernels for Adreno (improved performance on some platforms)")
add_compile_definitions(GGML_OPENCL_USE_ADRENO_BIN_KERNELS)
endif ()
function(ggml_opencl_add_kernel KNAME)
set(KERN_HDR ${CMAKE_CURRENT_BINARY_DIR}/autogenerated/${KNAME}.cl.h)
set(KERN_SRC ${CMAKE_CURRENT_SOURCE_DIR}/kernels/${KNAME}.cl)
+313 -8
View File
@@ -13,6 +13,22 @@
#include "ggml-backend-impl.h"
#include "ggml.h"
#ifdef GGML_OPENCL_USE_ADRENO_BIN_KERNELS
#include "libdl.h"
#ifdef _WIN32
#define KERNEL_LIB_NAME "adreno-opencl-kernels.dll"
#else
#define KERNEL_LIB_NAME "libadreno-opencl-kernels.so"
#endif // _WIN32
#endif // GGML_OPENCL_USE_ADRENO_BIN_KERNELS
typedef const void * (*get_adreno_bin_kernel_func_t)(
const char * name,
const char * gpu_name,
const char * compiler_ver,
size_t * out_size
);
#include <CL/cl.h>
#include <inttypes.h>
@@ -476,6 +492,8 @@ struct ggml_backend_opencl_context {
bool adreno_has_large_buffer;
bool adreno_use_large_buffer;
bool adreno_use_bin_kernels;
get_adreno_bin_kernel_func_t get_adreno_bin_kernel_func = nullptr;
ggml_cl_compiler_version adreno_cl_compiler_version;
std::string kernel_compile_opts; // cached for lazy-compiled kernels.
@@ -718,15 +736,15 @@ struct ggml_backend_opencl_context {
cl_kernel kernel_gated_delta_net_f32[4][2][2] = {};
cl_kernel kernel_timestep_embedding;
cl_kernel kernel_gemv_moe_q4_0_f32_ns, kernel_gemm_moe_q4_0_f32_ns;
cl_kernel kernel_gemv_moe_q4_1_f32_ns, kernel_gemm_moe_q4_1_f32_ns;
cl_kernel kernel_gemv_moe_q4_0_f32_ns, kernel_gemm_moe_q4_0_f32_ns, kernel_gemm_moe_q4_0_f32_ns_bin;
cl_kernel kernel_gemv_moe_q4_1_f32_ns, kernel_gemm_moe_q4_1_f32_ns, kernel_gemm_moe_q4_1_f32_ns_bin;
cl_kernel kernel_gemv_moe_q5_0_f32_ns, kernel_gemm_moe_q5_0_f32_ns;
cl_kernel kernel_gemv_moe_q5_1_f32_ns, kernel_gemm_moe_q5_1_f32_ns;
cl_kernel kernel_gemv_moe_q4_k_f32_ns, kernel_gemm_moe_q4_k_f32_ns;
cl_kernel kernel_gemv_moe_q4_k_f32_ns, kernel_gemm_moe_q4_k_f32_ns, kernel_gemm_moe_q4_k_f32_ns_bin;
cl_kernel kernel_gemv_moe_q5_k_f32_ns, kernel_gemm_moe_q5_k_f32_ns;
cl_kernel kernel_gemv_moe_q6_k_f32_ns, kernel_gemm_moe_q6_k_f32_ns;
cl_kernel kernel_gemv_moe_mxfp4_f32, kernel_gemm_moe_mxfp4_f32;
cl_kernel kernel_gemv_moe_mxfp4_f32_ns, kernel_gemm_moe_mxfp4_f32_ns;
cl_kernel kernel_gemv_moe_mxfp4_f32_ns, kernel_gemm_moe_mxfp4_f32_ns, kernel_gemm_moe_mxfp4_f32_ns_bin;
cl_kernel kernel_moe_reorder_b;
cl_kernel kernel_moe_histogram, kernel_moe_scan, kernel_moe_fill, kernel_moe_scatter;
cl_kernel kernel_mul_mv_id_q4_0_f32_8x_flat;
@@ -870,6 +888,20 @@ struct ggml_backend_opencl_context {
#endif
}
const void * get_adreno_bin_kernel(const std::string &kernel_name, size_t *bin_size) const {
if (!get_adreno_bin_kernel_func) {
return nullptr;
}
size_t sz;
const void * kernel_bin = get_adreno_bin_kernel_func(
kernel_name.c_str(), device_name.c_str(), driver_version.c_str(), &sz);
if (bin_size) {
*bin_size = sz;
}
return kernel_bin;
}
#ifdef GGML_OPENCL_USE_ADRENO_KERNELS
// Transpose kernels
cl_program program_transpose;
@@ -891,7 +923,7 @@ struct ggml_backend_opencl_context {
cl_kernel kernel_gemv_noshuffle_q4_0_f32_32000_1_4096;
cl_kernel kernel_gemv_noshuffle_q4_1_f32;
cl_kernel kernel_gemm_noshuffle_q4_1_f32;
cl_kernel kernel_gemm_noshuffle_q8_0_f32;
cl_kernel kernel_gemm_noshuffle_q8_0_f32, kernel_gemm_noshuffle_q8_0_f32_bin;
cl_kernel kernel_gemv_noshuffle_q8_0_f32;
cl_kernel kernel_gemm_noshuffle_q1_0_f32;
cl_kernel kernel_gemv_noshuffle_q1_0_f32;
@@ -988,6 +1020,32 @@ static cl_program build_program_from_source(cl_context ctx, cl_device_id dev, co
return build_program_from_source_ex(ctx, dev, program_buffer, compile_opts, /*fatal=*/true);
}
static cl_program build_program_from_binary(cl_context ctx, cl_device_id dev, const char* program_buffer, const std::string &compile_opts, size_t bin_size = 0) {
cl_program p;
char *program_log;
size_t log_size;
int err;
p = clCreateProgramWithBinary(ctx, 1, &dev, &bin_size, (const unsigned char**)&program_buffer, NULL, &err);
if(err < 0) {
GGML_LOG_ERROR("OpenCL error creating program from binary");
exit(1);
}
err = clBuildProgram(p, 0, NULL, compile_opts.c_str(), NULL, NULL);
if(err < 0) {
clGetProgramBuildInfo(p, dev, CL_PROGRAM_BUILD_LOG, 0, NULL, &log_size);
program_log = (char*) malloc(log_size + 1);
program_log[log_size] = '\0';
clGetProgramBuildInfo(p, dev, CL_PROGRAM_BUILD_LOG, log_size + 1, program_log, NULL);
GGML_LOG_ERROR("ggml_opencl: kernel compile error:\n\n%s\n", program_log);
free(program_log);
exit(1);
}
return p;
}
static void load_cl_kernels_argsort(ggml_backend_opencl_context *backend_ctx) {
// compiler options for general kernels
auto opencl_c_std =
@@ -1014,6 +1072,17 @@ static void load_cl_kernels_argsort(ggml_backend_opencl_context *backend_ctx) {
}
}
static bool use_adreno_bin_kernels(ggml_backend_opencl_context * backend_ctx) {
#ifndef GGML_OPENCL_USE_ADRENO_BIN_KERNELS
return false;
#else
if (backend_ctx->gpu_family != GPU_FAMILY::ADRENO) {
return false;
}
return backend_ctx->adreno_use_bin_kernels;
#endif // GGML_OPENCL_USE_ADRENO_BIN_KERNELS
}
static void load_cl_kernels(ggml_backend_opencl_context *backend_ctx) {
if (backend_ctx->kernels_loaded) {
return;
@@ -3323,6 +3392,24 @@ static void load_cl_kernels(ggml_backend_opencl_context *backend_ctx) {
GGML_LOG_CONT(".");
}
// gemm_noshuffle_q8_0_f32_bin
{
size_t bin_size = 0;
backend_ctx->kernel_gemm_noshuffle_q8_0_f32_bin = nullptr;
if (use_adreno_bin_kernels(backend_ctx)) {
const char * kernel_bin = (const char *)backend_ctx->get_adreno_bin_kernel("gemm_noshuffle_q8_0_f32_ila", &bin_size);
if (kernel_bin && bin_size > 0) {
cl_program prog =
build_program_from_binary(backend_ctx->context, backend_ctx->device, kernel_bin, compile_opts, bin_size);
CL_CHECK((backend_ctx->kernel_gemm_noshuffle_q8_0_f32_bin = clCreateKernel(prog, "kernel_gemm_noshuffle_q8_0_f32_ila", &err), err));
CL_CHECK(clReleaseProgram(prog));
GGML_LOG_CONT(".");
}
}
}
// gemv_noshuffle_general_q8_0_f32
{
std::string CL_gemv_compile_opts = std::string("-cl-std=") + opencl_c_std +
@@ -3424,6 +3511,24 @@ static void load_cl_kernels(ggml_backend_opencl_context *backend_ctx) {
GGML_LOG_CONT(".");
}
// gemm_moe_q4_1_f32_ns_bin
{
size_t bin_size = 0;
backend_ctx->kernel_gemm_moe_q4_1_f32_ns_bin = nullptr;
if (use_adreno_bin_kernels(backend_ctx)) {
const char * kernel_bin = (const char *)backend_ctx->get_adreno_bin_kernel("gemm_moe_q4_1_f32_ns_ila", &bin_size);
if (kernel_bin && bin_size > 0) {
cl_program prog =
build_program_from_binary(backend_ctx->context, backend_ctx->device, kernel_bin, CL_moe_compile_opts, bin_size);
CL_CHECK((backend_ctx->kernel_gemm_moe_q4_1_f32_ns_bin = clCreateKernel(prog, "kernel_gemm_moe_q4_1_f32_ns_ila", &err), err));
CL_CHECK(clReleaseProgram(prog));
GGML_LOG_CONT(".");
}
}
}
// gemv_moe_mxfp4_f32
{
#ifdef GGML_OPENCL_EMBED_KERNELS
@@ -3490,6 +3595,24 @@ static void load_cl_kernels(ggml_backend_opencl_context *backend_ctx) {
GGML_LOG_CONT(".");
}
// gemm_moe_q4_0_f32_ns_bin
{
size_t bin_size = 0;
backend_ctx->kernel_gemm_moe_q4_0_f32_ns_bin = nullptr;
if (use_adreno_bin_kernels(backend_ctx)) {
const char * kernel_bin = (const char *)backend_ctx->get_adreno_bin_kernel("gemm_moe_q4_0_f32_ns_ila", &bin_size);
if (kernel_bin && bin_size > 0) {
cl_program prog =
build_program_from_binary(backend_ctx->context, backend_ctx->device, kernel_bin, CL_moe_compile_opts, bin_size);
CL_CHECK((backend_ctx->kernel_gemm_moe_q4_0_f32_ns_bin = clCreateKernel(prog, "kernel_gemm_moe_q4_0_f32_ns_ila", &err), err));
CL_CHECK(clReleaseProgram(prog));
GGML_LOG_CONT(".");
}
}
}
// gemv_moe_q5_0_f32_ns
{
#ifdef GGML_OPENCL_EMBED_KERNELS
@@ -3592,6 +3715,24 @@ static void load_cl_kernels(ggml_backend_opencl_context *backend_ctx) {
GGML_LOG_CONT(".");
}
// gemm_moe_q4_k_f32_ns_bin
{
size_t bin_size = 0;
backend_ctx->kernel_gemm_moe_q4_k_f32_ns_bin = nullptr;
if (use_adreno_bin_kernels(backend_ctx)) {
const char * kernel_bin = (const char *)backend_ctx->get_adreno_bin_kernel("gemm_moe_q4_k_f32_ns_ila", &bin_size);
if (kernel_bin && bin_size > 0) {
cl_program prog =
build_program_from_binary(backend_ctx->context, backend_ctx->device, kernel_bin, CL_moe_compile_opts, bin_size);
CL_CHECK((backend_ctx->kernel_gemm_moe_q4_k_f32_ns_bin = clCreateKernel(prog, "kernel_gemm_moe_q4_k_f32_ns_ila", &err), err));
CL_CHECK(clReleaseProgram(prog));
GGML_LOG_CONT(".");
}
}
}
// gemv_moe_q5_k_f32_ns
{
#ifdef GGML_OPENCL_EMBED_KERNELS
@@ -3689,9 +3830,27 @@ static void load_cl_kernels(ggml_backend_opencl_context *backend_ctx) {
cl_program prog =
build_program_from_source(backend_ctx->context, backend_ctx->device, kernel_src.c_str(), CL_moe_compile_opts);
CL_CHECK((backend_ctx->kernel_gemm_moe_mxfp4_f32_ns = clCreateKernel(prog, "kernel_gemm_moe_mxfp4_f32_ns", &err), err));
CL_CHECK(clReleaseProgram(prog));
GGML_LOG_CONT(".");
CL_CHECK((backend_ctx->kernel_gemm_moe_mxfp4_f32_ns = clCreateKernel(prog, "kernel_gemm_moe_mxfp4_f32_ns", &err), err));
CL_CHECK(clReleaseProgram(prog));
GGML_LOG_CONT(".");
}
// gemm_moe_mxfp4_f32_ns_bin
{
size_t bin_size = 0;
backend_ctx->kernel_gemm_moe_mxfp4_f32_ns_bin = nullptr;
if (use_adreno_bin_kernels(backend_ctx)) {
const char * kernel_bin = (const char *)backend_ctx->get_adreno_bin_kernel("gemm_moe_mxfp4_f32_ns_ila", &bin_size);
if (kernel_bin && bin_size > 0) {
cl_program prog =
build_program_from_binary(backend_ctx->context, backend_ctx->device, kernel_bin, CL_moe_compile_opts, bin_size);
CL_CHECK((backend_ctx->kernel_gemm_moe_mxfp4_f32_ns_bin = clCreateKernel(prog, "kernel_gemm_moe_mxfp4_f32_ns_ila", &err), err));
CL_CHECK(clReleaseProgram(prog));
GGML_LOG_CONT(".");
}
}
}
// moe_reorder_b
@@ -4770,6 +4929,27 @@ static ggml_backend_opencl_context * ggml_cl_init(ggml_backend_dev_t dev) {
backend_ctx->adreno_use_large_buffer = getenv("GGML_OPENCL_ADRENO_USE_LARGE_BUFFER") != nullptr &&
backend_ctx->gpu_family == GPU_FAMILY::ADRENO;
#ifdef GGML_OPENCL_USE_ADRENO_BIN_KERNELS
// try loading adreno binary kernels if enabled
// if fails to load, builtin kernels will be used
{
dl_handle * kernel_lib_handle = dl_load_library(KERNEL_LIB_NAME);
backend_ctx->adreno_use_bin_kernels = false;
if (kernel_lib_handle) {
backend_ctx->get_adreno_bin_kernel_func = (get_adreno_bin_kernel_func_t)dl_get_sym(kernel_lib_handle, "get_adreno_kernels");
if (backend_ctx->get_adreno_bin_kernel_func) {
GGML_LOG_INFO("ggml_opencl: loaded bin kernel library %s\n", KERNEL_LIB_NAME);
backend_ctx->adreno_use_bin_kernels = true;
} else {
GGML_LOG_INFO("ggml_opencl: bin kernel library %s is invalid, will use builtin kernels\n", KERNEL_LIB_NAME);
}
} else {
GGML_LOG_INFO("ggml_opencl: failed to load %s, will use builtin kernels\n", KERNEL_LIB_NAME);
}
}
#endif // GGML_OPENCL_USE_ADRENO_BIN_KERNELS
cl_int err;
// A local ref of cl_context for convenience
@@ -14972,6 +15152,99 @@ static void ggml_cl_mul_mat_q8_0_f32_adreno(ggml_backend_t backend, const ggml_t
CL_CHECK(clReleaseMemObject(b_img));
CL_CHECK(clReleaseMemObject(b_sub_buf));
} else {
// use bin kernel if available
if (backend_ctx->kernel_gemm_noshuffle_q8_0_f32_bin) {
int K_pad = K;
cl_mem b_sub_buf = nullptr;
cl_mem d_sub_buf = nullptr;
cl_mem a_img = nullptr;
cl_mem s_img = nullptr;
cl_mem b_img = nullptr;
cl_mem d_img = nullptr;
// subbuffer for activations
region.origin = offset1;
region.size = K_pad * N * sizeof(float);
CL_CHECK((b_sub_buf = clCreateSubBuffer(extra1->data_device, 0, CL_BUFFER_CREATE_TYPE_REGION, &region, &err), err));
// Create subbuffer and image1d_buffer for dst
region.origin = (extrad->offset); // + dst->view_offs;
region.size = M * N * sizeof(float);
CL_CHECK((d_sub_buf = clCreateSubBuffer((extrad->data_device), 0, CL_BUFFER_CREATE_TYPE_REGION, &region, &err), err));
// create an image for A
img_fmt = { CL_R, CL_FLOAT};
memset(&img_desc, 0, sizeof(img_desc));
img_desc.image_type = CL_MEM_OBJECT_IMAGE1D_BUFFER;
img_desc.image_width = M * K / 4; // Divide by 4 for char -> float
img_desc.buffer = extra0_q8_0->q;
CL_CHECK((a_img = clCreateImage(context, CL_MEM_READ_ONLY, &img_fmt, &img_desc, NULL, &err), err));
// create an image for Scale
img_fmt = { CL_R, CL_HALF_FLOAT};
memset(&img_desc, 0, sizeof(img_desc));
img_desc.image_type = CL_MEM_OBJECT_IMAGE1D_BUFFER;
img_desc.image_width = M * K / 32; // Block size is 32
img_desc.buffer = extra0_q8_0->d;
CL_CHECK((s_img = clCreateImage(context, CL_MEM_READ_ONLY, &img_fmt, &img_desc, NULL, &err), err));
// create an image for B from sub_buffer
img_fmt = {CL_R, CL_FLOAT};
memset(&img_desc, 0, sizeof(img_desc));
img_desc.image_type = CL_MEM_OBJECT_IMAGE1D_BUFFER;
img_desc.image_width = K_pad * N;
img_desc.buffer = b_sub_buf;
CL_CHECK((b_img = clCreateImage(context, CL_MEM_READ_ONLY, &img_fmt, &img_desc, NULL, &err), err));
// img for d
img_fmt = {CL_R, CL_FLOAT};
memset(&img_desc, 0, sizeof(img_desc));
img_desc.image_type = CL_MEM_OBJECT_IMAGE1D_BUFFER;
img_desc.image_width = M * N;
img_desc.buffer = d_sub_buf;
CL_CHECK((d_img = clCreateImage(context, CL_MEM_WRITE_ONLY, &img_fmt, &img_desc, NULL, &err), err));
// gemm
kernel = backend_ctx->kernel_gemm_noshuffle_q8_0_f32_bin;
bool layoutA_Mfirst = true;
bool layoutS_Mfirst = true;
bool layoutB_Nfirst = false;
bool layoutC_Mfirst = true;
cl_uint lineStrideMatrixAinBytes = layoutA_Mfirst ? M * 4 : K; // int8
cl_uint lineStrideMatrixSinBytes = layoutS_Mfirst ? M * 2 : (K / 32) * 2; // fp16
cl_uint lineStrideMatrixBinBytes = layoutB_Nfirst ? N * 4 : K_pad * 4; // fp32
cl_uint lineStrideMatrixCinBytes = layoutC_Mfirst ? M * 4 : N * 4; // fp32
CL_CHECK(clSetKernelArg(kernel, 0, sizeof(cl_mem), &a_img));
CL_CHECK(clSetKernelArg(kernel, 1, sizeof(cl_mem), &s_img));
CL_CHECK(clSetKernelArg(kernel, 2, sizeof(cl_mem), &b_img));
CL_CHECK(clSetKernelArg(kernel, 3, sizeof(int), &extra1->offset));
CL_CHECK(clSetKernelArg(kernel, 4, sizeof(cl_mem), &d_img));
CL_CHECK(clSetKernelArg(kernel, 5, sizeof(int), &extrad->offset));
CL_CHECK(clSetKernelArg(kernel, 6, sizeof(int), &K));
CL_CHECK(clSetKernelArg(kernel, 7, sizeof(int), &lineStrideMatrixAinBytes));
CL_CHECK(clSetKernelArg(kernel, 8, sizeof(int), &lineStrideMatrixSinBytes));
CL_CHECK(clSetKernelArg(kernel, 9, sizeof(int), &lineStrideMatrixBinBytes));
CL_CHECK(clSetKernelArg(kernel, 10, sizeof(int), &lineStrideMatrixCinBytes));
size_t global_work_size[] = { 64, (size_t)CEIL_DIV(M, 64), (size_t)CEIL_DIV(N, 64)};
size_t local_work_size[] = { 64, 2, 2 };
backend_ctx->enqueue_ndrange_kernel(kernel, 3, global_work_size, local_work_size, dst);
CL_CHECK(clReleaseMemObject(b_sub_buf));
CL_CHECK(clReleaseMemObject(d_sub_buf));
CL_CHECK(clReleaseMemObject(a_img));
CL_CHECK(clReleaseMemObject(s_img));
CL_CHECK(clReleaseMemObject(b_img));
CL_CHECK(clReleaseMemObject(d_img));
return;
}
cl_mem b_sub_buf = nullptr;
cl_mem b_sub_buf_trans = nullptr;
cl_mem b_img = nullptr;
@@ -17825,6 +18098,9 @@ static void ggml_cl_mul_mat_id(ggml_backend_t backend, const ggml_tensor * src0,
} else { // for gemm
kernel = backend_ctx->kernel_gemm_moe_q4_0_f32_ns;
if (backend_ctx->kernel_gemm_moe_q4_0_f32_ns_bin) {
kernel = backend_ctx->kernel_gemm_moe_q4_0_f32_ns_bin;
}
// Reorder router if called from test-backend-ops or when new router is generated.
// Otherwise reuse the reordered result from previous mul_mat_id call.
@@ -17870,6 +18146,11 @@ static void ggml_cl_mul_mat_id(ggml_backend_t backend, const ggml_tensor * src0,
cl_image_desc image_desc_buf_src1;
image_format_buf_src1 = {CL_RGBA, CL_FLOAT};
image_desc_buf_src1 = {CL_MEM_OBJECT_IMAGE1D_BUFFER, static_cast<size_t>(ne00 * max_post_router_tile * n_tile_size / 4), 0,0,0,0,0,0,0, {buf_src1_reordered}};
if (backend_ctx->kernel_gemm_moe_q4_0_f32_ns_bin) {
// bin kernel uses slightly different image format
image_format_buf_src1 = {CL_R, CL_FLOAT};
image_desc_buf_src1.image_width = static_cast<size_t>(ne00 * max_post_router_tile * n_tile_size);
}
image_src1_reordered = clCreateImage(backend_ctx->context, CL_MEM_READ_ONLY, &image_format_buf_src1, &image_desc_buf_src1, NULL, &status);
CL_CHECK(status);
@@ -18042,6 +18323,9 @@ static void ggml_cl_mul_mat_id(ggml_backend_t backend, const ggml_tensor * src0,
} else { // for gemm
kernel = backend_ctx->kernel_gemm_moe_q4_1_f32_ns;
if (backend_ctx->kernel_gemm_moe_q4_1_f32_ns_bin) {
kernel = backend_ctx->kernel_gemm_moe_q4_1_f32_ns_bin;
}
// Reorder router if called from test-backend-ops or when new router is generated.
// Otherwise reuse the reordered result from previous mul_mat_id call.
@@ -18087,6 +18371,11 @@ static void ggml_cl_mul_mat_id(ggml_backend_t backend, const ggml_tensor * src0,
cl_image_desc image_desc_buf_src1;
image_format_buf_src1 = {CL_RGBA, CL_FLOAT};
image_desc_buf_src1 = {CL_MEM_OBJECT_IMAGE1D_BUFFER, static_cast<size_t>(ne00 * max_post_router_tile * n_tile_size / 4), 0,0,0,0,0,0,0, {buf_src1_reordered}};
if (backend_ctx->kernel_gemm_moe_q4_1_f32_ns_bin) {
// bin kernel uses slightly different image format
image_format_buf_src1 = {CL_R, CL_FLOAT};
image_desc_buf_src1.image_width = static_cast<size_t>(ne00 * max_post_router_tile * n_tile_size);
}
image_src1_reordered = clCreateImage(backend_ctx->context, CL_MEM_READ_ONLY, &image_format_buf_src1, &image_desc_buf_src1, NULL, &status);
CL_CHECK(status);
@@ -18648,6 +18937,9 @@ static void ggml_cl_mul_mat_id(ggml_backend_t backend, const ggml_tensor * src0,
} else { // for gemm
kernel = backend_ctx->kernel_gemm_moe_q4_k_f32_ns;
if (backend_ctx->kernel_gemm_moe_q4_k_f32_ns_bin) {
kernel = backend_ctx->kernel_gemm_moe_q4_k_f32_ns_bin;
}
// Reorder router if called from test-backend-ops or when new router is generated.
// Otherwise reuse the reordered result from previous mul_mat_id call.
@@ -18689,6 +18981,11 @@ static void ggml_cl_mul_mat_id(ggml_backend_t backend, const ggml_tensor * src0,
CL_CHECK(status);
cl_image_format image_format_buf_src1 = {CL_RGBA, CL_FLOAT};
cl_image_desc image_desc_buf_src1 = {CL_MEM_OBJECT_IMAGE1D_BUFFER, static_cast<size_t>(ne00 * max_post_router_tile * n_tile_size / 4), 0,0,0,0,0,0,0, {buf_src1_reordered}};
if (backend_ctx->kernel_gemm_moe_q4_k_f32_ns_bin) {
// bin kernel uses slightly different image format
image_format_buf_src1 = {CL_R, CL_FLOAT};
image_desc_buf_src1.image_width = static_cast<size_t>(ne00 * max_post_router_tile * n_tile_size);
}
image_src1_reordered = clCreateImage(backend_ctx->context, CL_MEM_READ_ONLY, &image_format_buf_src1, &image_desc_buf_src1, NULL, &status);
CL_CHECK(status);
@@ -19172,6 +19469,9 @@ static void ggml_cl_mul_mat_id(ggml_backend_t backend, const ggml_tensor * src0,
} else { // for gemm
kernel = backend_ctx->kernel_gemm_moe_mxfp4_f32_ns;
if (backend_ctx->kernel_gemm_moe_mxfp4_f32_ns_bin) {
kernel = backend_ctx->kernel_gemm_moe_mxfp4_f32_ns_bin;
}
// Reorder router if called from test-backend-ops or when new router is generated.
// Otherwise reuse the reordered result from previous mul_mat_id call.
@@ -19218,6 +19518,11 @@ static void ggml_cl_mul_mat_id(ggml_backend_t backend, const ggml_tensor * src0,
cl_image_desc image_desc_buf_src1;
image_format_buf_src1 = {CL_RGBA, CL_FLOAT};
image_desc_buf_src1 = {CL_MEM_OBJECT_IMAGE1D_BUFFER, static_cast<size_t>(ne00 * max_post_router_tile * n_tile_size / 4), 0,0,0,0,0,0,0, {buf_src1_reordered}};
if (backend_ctx->kernel_gemm_moe_mxfp4_f32_ns_bin) {
// bin kernel uses slightly different image format
image_format_buf_src1 = {CL_R, CL_FLOAT};
image_desc_buf_src1.image_width = static_cast<size_t>(ne00 * max_post_router_tile * n_tile_size);
}
image_src1_reordered = clCreateImage(backend_ctx->context, CL_MEM_READ_ONLY, &image_format_buf_src1, &image_desc_buf_src1, NULL, &status);
CL_CHECK(status);
+79
View File
@@ -0,0 +1,79 @@
#pragma once
#ifdef _WIN32
# define WIN32_LEAN_AND_MEAN
# ifndef NOMINMAX
# define NOMINMAX
# endif
# include <windows.h>
# include <winevt.h>
#else
# include <dlfcn.h>
# include <unistd.h>
#endif
#include <filesystem>
namespace fs = std::filesystem;
#ifdef _WIN32
using dl_handle = std::remove_pointer_t<HMODULE>;
struct dl_handle_deleter {
void operator()(HMODULE handle) {
FreeLibrary(handle);
}
};
static inline dl_handle * dl_load_library(const fs::path & path) {
// suppress error dialogs for missing DLLs
DWORD old_mode = SetErrorMode(SEM_FAILCRITICALERRORS);
SetErrorMode(old_mode | SEM_FAILCRITICALERRORS);
HMODULE handle = LoadLibraryW(path.wstring().c_str());
SetErrorMode(old_mode);
return handle;
}
static inline void * dl_get_sym(dl_handle * handle, const char * name) {
DWORD old_mode = SetErrorMode(SEM_FAILCRITICALERRORS);
SetErrorMode(old_mode | SEM_FAILCRITICALERRORS);
void * p = (void *) GetProcAddress(handle, name);
SetErrorMode(old_mode);
return p;
}
static inline const char * dl_error() {
return "";
}
#else
using dl_handle = void;
struct dl_handle_deleter {
void operator()(void * handle) {
dlclose(handle);
}
};
static inline dl_handle * dl_load_library(const fs::path & path) {
dl_handle * handle = dlopen(path.string().c_str(), RTLD_NOW | RTLD_LOCAL);
return handle;
}
static inline void * dl_get_sym(dl_handle * handle, const char * name) {
return dlsym(handle, name);
}
static inline const char * dl_error() {
const char *rslt = dlerror();
return rslt != nullptr ? rslt : "";
}
#endif
+6
View File
@@ -159,6 +159,9 @@ extern "C" {
LLAMA_FTYPE_GUESSED = 1024, // not specified in the model file
};
// Get the model file type (quantization) as a string, e.g. "Q8_0" or "Q4_K - Medium"
LLAMA_API const char * llama_ftype_name(enum llama_ftype ftype);
enum llama_rope_scaling_type {
LLAMA_ROPE_SCALING_TYPE_UNSPECIFIED = -1,
LLAMA_ROPE_SCALING_TYPE_NONE = 0,
@@ -606,6 +609,9 @@ extern "C" {
// Get a string describing the model type
LLAMA_API int32_t llama_model_desc(const struct llama_model * model, char * buf, size_t buf_size);
// Get the model file type (quantization), e.g. LLAMA_FTYPE_MOSTLY_Q8_0
LLAMA_API enum llama_ftype llama_model_ftype(const struct llama_model * model);
// Returns the total size of all the tensors in the model in bytes
LLAMA_API uint64_t llama_model_size(const struct llama_model * model);
+4 -1
View File
@@ -69,13 +69,16 @@ mbuf=
mmsel=
[ "$MM" != "" ] && mmsel="GGML_HEXAGON_MM_SELECT=$MM"
fasel=
[ "$FA" != "" ] && fasel="GGML_HEXAGON_FA_SELECT=$FA"
set -x
adb $adbserial $adbhost shell " \
cd $basedir; ulimit -c unlimited; \
LD_LIBRARY_PATH=$basedir/$branch/lib \
ADSP_LIBRARY_PATH=$basedir/$branch/lib \
$verbose $sched $opmask $profile $nhvx $hmx $ndev $hb $opbatch $opqueue $oppoll $opflt $opfuse $vmem $mbuf $mmsel \
$verbose $sched $opmask $profile $nhvx $hmx $ndev $hb $opbatch $opqueue $oppoll $opflt $opfuse $vmem $mbuf $mmsel $fasel \
./$branch/bin/llama-completion --no-mmap -m $basedir/../gguf/$model \
--poll 1000 -t 6 --cpu-mask 0xfc --cpu-strict 1 \
--ctx-size 8192 --ubatch-size 1024 -fa on \
+4 -1
View File
@@ -57,6 +57,9 @@ opfuse=
mmsel=
[ "$MM" != "" ] && mmsel="GGML_HEXAGON_MM_SELECT=$MM"
fasel=
[ "$FA" != "" ] && fasel="GGML_HEXAGON_FA_SELECT=$FA"
set -x
tool=$1; shift
@@ -65,5 +68,5 @@ adb $adbserial $adbhost shell " \
cd $basedir; ulimit -c unlimited; \
LD_LIBRARY_PATH=$basedir/$branch/lib \
ADSP_LIBRARY_PATH=$basedir/$branch/lib \
$verbose $sched $opmask $profile $nhvx $hmx $ndev $hb $opbatch $opqueue $oppoll $opfuse $mmsel ./$branch/bin/$tool $@ \
$verbose $sched $opmask $profile $nhvx $hmx $ndev $hb $opbatch $opqueue $oppoll $opfuse $mmsel $fasel ./$branch/bin/$tool $@ \
"
@@ -230,6 +230,12 @@ def print_ascii_timeline(op_name, dims, types, usec, cycles, events, evt_val=Non
char = 'Q'
elif norm_evt == 'A-PREP':
char = 'A'
elif norm_evt == 'Q-PREP':
char = 'q'
elif norm_evt == 'K-PREP':
char = 'k'
elif norm_evt == 'V-PREP':
char = 'v'
elif norm_evt == 'W-DEQUANT':
char = 'D'
elif norm_evt == 'O-PROC':
+1 -1
View File
@@ -5,7 +5,7 @@ import os
import sys
import subprocess
HTTPLIB_VERSION = "refs/tags/v0.48.0"
HTTPLIB_VERSION = "refs/tags/v0.49.0"
vendor = {
"https://github.com/nlohmann/json/releases/latest/download/json.hpp": "vendor/nlohmann/json.hpp",
+46 -44
View File
@@ -27,52 +27,54 @@ const char * llama_file_version_name(llama_fver version) {
return "unknown";
}
static std::string llama_model_ftype_name(llama_ftype ftype) {
if (ftype & LLAMA_FTYPE_GUESSED) {
return llama_model_ftype_name((enum llama_ftype) (ftype & ~LLAMA_FTYPE_GUESSED)) + " (guessed)";
}
#define LLAMA_FTYPE_PREFIX "(guessed) "
switch (ftype) {
case LLAMA_FTYPE_ALL_F32: return "all F32";
case LLAMA_FTYPE_MOSTLY_F16: return "F16";
case LLAMA_FTYPE_MOSTLY_BF16: return "BF16";
case LLAMA_FTYPE_MOSTLY_Q1_0: return "Q1_0";
case LLAMA_FTYPE_MOSTLY_Q4_0: return "Q4_0";
case LLAMA_FTYPE_MOSTLY_Q4_1: return "Q4_1";
case LLAMA_FTYPE_MOSTLY_Q5_0: return "Q5_0";
case LLAMA_FTYPE_MOSTLY_Q5_1: return "Q5_1";
case LLAMA_FTYPE_MOSTLY_Q8_0: return "Q8_0";
case LLAMA_FTYPE_MOSTLY_MXFP4_MOE: return "MXFP4 MoE";
case LLAMA_FTYPE_MOSTLY_NVFP4: return "NVFP4";
case LLAMA_FTYPE_MOSTLY_Q2_K: return "Q2_K - Medium";
case LLAMA_FTYPE_MOSTLY_Q2_K_S: return "Q2_K - Small";
case LLAMA_FTYPE_MOSTLY_Q3_K_S: return "Q3_K - Small";
case LLAMA_FTYPE_MOSTLY_Q3_K_M: return "Q3_K - Medium";
case LLAMA_FTYPE_MOSTLY_Q3_K_L: return "Q3_K - Large";
case LLAMA_FTYPE_MOSTLY_Q4_K_S: return "Q4_K - Small";
case LLAMA_FTYPE_MOSTLY_Q4_K_M: return "Q4_K - Medium";
case LLAMA_FTYPE_MOSTLY_Q5_K_S: return "Q5_K - Small";
case LLAMA_FTYPE_MOSTLY_Q5_K_M: return "Q5_K - Medium";
case LLAMA_FTYPE_MOSTLY_Q6_K: return "Q6_K";
case LLAMA_FTYPE_MOSTLY_TQ1_0: return "TQ1_0 - 1.69 bpw ternary";
case LLAMA_FTYPE_MOSTLY_TQ2_0: return "TQ2_0 - 2.06 bpw ternary";
case LLAMA_FTYPE_MOSTLY_IQ2_XXS: return "IQ2_XXS - 2.0625 bpw";
case LLAMA_FTYPE_MOSTLY_IQ2_XS: return "IQ2_XS - 2.3125 bpw";
case LLAMA_FTYPE_MOSTLY_IQ2_S: return "IQ2_S - 2.5 bpw";
case LLAMA_FTYPE_MOSTLY_IQ2_M: return "IQ2_M - 2.7 bpw";
case LLAMA_FTYPE_MOSTLY_IQ3_XS: return "IQ3_XS - 3.3 bpw";
case LLAMA_FTYPE_MOSTLY_IQ3_XXS: return "IQ3_XXS - 3.0625 bpw";
case LLAMA_FTYPE_MOSTLY_IQ1_S: return "IQ1_S - 1.5625 bpw";
case LLAMA_FTYPE_MOSTLY_IQ1_M: return "IQ1_M - 1.75 bpw";
case LLAMA_FTYPE_MOSTLY_IQ4_NL: return "IQ4_NL - 4.5 bpw";
case LLAMA_FTYPE_MOSTLY_IQ4_XS: return "IQ4_XS - 4.25 bpw";
case LLAMA_FTYPE_MOSTLY_IQ3_S: return "IQ3_S - 3.4375 bpw";
case LLAMA_FTYPE_MOSTLY_IQ3_M: return "IQ3_S mix - 3.66 bpw";
default: return "unknown, may not work";
const char * llama_ftype_name(llama_ftype ftype) {
static constexpr size_t guessed_prefix_len = sizeof(LLAMA_FTYPE_PREFIX) - 1;
const char * name;
switch ((enum llama_ftype) (ftype & ~LLAMA_FTYPE_GUESSED)) {
case LLAMA_FTYPE_ALL_F32: name = LLAMA_FTYPE_PREFIX "all F32"; break;
case LLAMA_FTYPE_MOSTLY_F16: name = LLAMA_FTYPE_PREFIX "F16"; break;
case LLAMA_FTYPE_MOSTLY_BF16: name = LLAMA_FTYPE_PREFIX "BF16"; break;
case LLAMA_FTYPE_MOSTLY_Q1_0: name = LLAMA_FTYPE_PREFIX "Q1_0"; break;
case LLAMA_FTYPE_MOSTLY_Q4_0: name = LLAMA_FTYPE_PREFIX "Q4_0"; break;
case LLAMA_FTYPE_MOSTLY_Q4_1: name = LLAMA_FTYPE_PREFIX "Q4_1"; break;
case LLAMA_FTYPE_MOSTLY_Q5_0: name = LLAMA_FTYPE_PREFIX "Q5_0"; break;
case LLAMA_FTYPE_MOSTLY_Q5_1: name = LLAMA_FTYPE_PREFIX "Q5_1"; break;
case LLAMA_FTYPE_MOSTLY_Q8_0: name = LLAMA_FTYPE_PREFIX "Q8_0"; break;
case LLAMA_FTYPE_MOSTLY_MXFP4_MOE: name = LLAMA_FTYPE_PREFIX "MXFP4 MoE"; break;
case LLAMA_FTYPE_MOSTLY_NVFP4: name = LLAMA_FTYPE_PREFIX "NVFP4"; break;
case LLAMA_FTYPE_MOSTLY_Q2_K: name = LLAMA_FTYPE_PREFIX "Q2_K - Medium"; break;
case LLAMA_FTYPE_MOSTLY_Q2_K_S: name = LLAMA_FTYPE_PREFIX "Q2_K - Small"; break;
case LLAMA_FTYPE_MOSTLY_Q3_K_S: name = LLAMA_FTYPE_PREFIX "Q3_K - Small"; break;
case LLAMA_FTYPE_MOSTLY_Q3_K_M: name = LLAMA_FTYPE_PREFIX "Q3_K - Medium"; break;
case LLAMA_FTYPE_MOSTLY_Q3_K_L: name = LLAMA_FTYPE_PREFIX "Q3_K - Large"; break;
case LLAMA_FTYPE_MOSTLY_Q4_K_S: name = LLAMA_FTYPE_PREFIX "Q4_K - Small"; break;
case LLAMA_FTYPE_MOSTLY_Q4_K_M: name = LLAMA_FTYPE_PREFIX "Q4_K - Medium"; break;
case LLAMA_FTYPE_MOSTLY_Q5_K_S: name = LLAMA_FTYPE_PREFIX "Q5_K - Small"; break;
case LLAMA_FTYPE_MOSTLY_Q5_K_M: name = LLAMA_FTYPE_PREFIX "Q5_K - Medium"; break;
case LLAMA_FTYPE_MOSTLY_Q6_K: name = LLAMA_FTYPE_PREFIX "Q6_K"; break;
case LLAMA_FTYPE_MOSTLY_TQ1_0: name = LLAMA_FTYPE_PREFIX "TQ1_0 - 1.69 bpw ternary"; break;
case LLAMA_FTYPE_MOSTLY_TQ2_0: name = LLAMA_FTYPE_PREFIX "TQ2_0 - 2.06 bpw ternary"; break;
case LLAMA_FTYPE_MOSTLY_IQ2_XXS: name = LLAMA_FTYPE_PREFIX "IQ2_XXS - 2.0625 bpw"; break;
case LLAMA_FTYPE_MOSTLY_IQ2_XS: name = LLAMA_FTYPE_PREFIX "IQ2_XS - 2.3125 bpw"; break;
case LLAMA_FTYPE_MOSTLY_IQ2_S: name = LLAMA_FTYPE_PREFIX "IQ2_S - 2.5 bpw"; break;
case LLAMA_FTYPE_MOSTLY_IQ2_M: name = LLAMA_FTYPE_PREFIX "IQ2_M - 2.7 bpw"; break;
case LLAMA_FTYPE_MOSTLY_IQ3_XS: name = LLAMA_FTYPE_PREFIX "IQ3_XS - 3.3 bpw"; break;
case LLAMA_FTYPE_MOSTLY_IQ3_XXS: name = LLAMA_FTYPE_PREFIX "IQ3_XXS - 3.0625 bpw"; break;
case LLAMA_FTYPE_MOSTLY_IQ1_S: name = LLAMA_FTYPE_PREFIX "IQ1_S - 1.5625 bpw"; break;
case LLAMA_FTYPE_MOSTLY_IQ1_M: name = LLAMA_FTYPE_PREFIX "IQ1_M - 1.75 bpw"; break;
case LLAMA_FTYPE_MOSTLY_IQ4_NL: name = LLAMA_FTYPE_PREFIX "IQ4_NL - 4.5 bpw"; break;
case LLAMA_FTYPE_MOSTLY_IQ4_XS: name = LLAMA_FTYPE_PREFIX "IQ4_XS - 4.25 bpw"; break;
case LLAMA_FTYPE_MOSTLY_IQ3_S: name = LLAMA_FTYPE_PREFIX "IQ3_S - 3.4375 bpw"; break;
case LLAMA_FTYPE_MOSTLY_IQ3_M: name = LLAMA_FTYPE_PREFIX "IQ3_S mix - 3.66 bpw"; break;
default: name = LLAMA_FTYPE_PREFIX "unknown, may not work"; break;
}
return (ftype & LLAMA_FTYPE_GUESSED) ? name : name + guessed_prefix_len;
}
#undef LLAMA_FTYPE_PREFIX
// return a list of splits for a given path
// for example, given "<name>-00002-of-00004.gguf", returns list of all 4 splits
static std::vector<std::string> llama_get_list_splits(const std::string & path, const int idx, const int n_split) {
@@ -1693,12 +1695,12 @@ bool llama_model_loader::load_all_data(
}
std::string llama_model_loader::ftype_name() const {
return llama_model_ftype_name(ftype);
return llama_ftype_name(ftype);
}
void llama_model_loader::print_info() const {
LLAMA_LOG_INFO("%s: file format = %s\n", __func__, llama_file_version_name(fver));
LLAMA_LOG_INFO("%s: file type = %s\n", __func__, llama_model_ftype_name(ftype).c_str());
LLAMA_LOG_INFO("%s: file type = %s\n", __func__, llama_ftype_name(ftype));
if (n_bytes < GiB) {
LLAMA_LOG_INFO("%s: file size = %.2f MiB (%.2f BPW) \n", __func__, n_bytes/1024.0/1024.0, n_bytes*8.0/n_elements);
} else {
+12
View File
@@ -987,6 +987,8 @@ struct llama_model::impl {
std::string desc_str;
llama_ftype ftype = LLAMA_FTYPE_ALL_F32;
// model memory mapped files
llama_mmaps mappings;
@@ -1200,6 +1202,8 @@ void llama_model_base::load_hparams(llama_model_loader & ml) {
pimpl->desc_str = arch_name() + " " + type_name() + " " + ml.ftype_name();
pimpl->ftype = ml.ftype;
if (hparams.f_max_alibi_bias > 0.0f) {
hparams.use_alibi = true;
}
@@ -1646,6 +1650,10 @@ std::string llama_model::desc() const {
return pimpl->desc_str;
}
llama_ftype llama_model::ftype() const {
return pimpl->ftype;
}
size_t llama_model::size() const {
return pimpl->n_bytes;
}
@@ -2616,6 +2624,10 @@ int32_t llama_model_desc(const llama_model * model, char * buf, size_t buf_size)
return snprintf(buf, buf_size, "%s", model->desc().c_str());
}
llama_ftype llama_model_ftype(const llama_model * model) {
return model->ftype();
}
uint64_t llama_model_size(const llama_model * model) {
return model->size();
}
+2
View File
@@ -637,6 +637,8 @@ struct llama_model {
std::string desc() const;
llama_ftype ftype() const;
size_t size() const; // file size
size_t n_tensors() const;
size_t n_devices() const;
+3
View File
@@ -448,6 +448,9 @@ int llama_cli(int argc, char ** argv) {
console::log("%s\n", LLAMA_ASCII_LOGO);
console::log("build : %s\n", inf.build_info.c_str());
console::log("model : %s\n", inf.model_name.c_str());
if (!inf.model_ftype.empty()) {
console::log("ftype : %s\n", inf.model_ftype.c_str());
}
console::log("modalities : %s\n", modalities.c_str());
if (!params.system_prompt.empty()) {
console::log("using custom system prompt\n");
+4
View File
@@ -3989,6 +3989,8 @@ server_context_meta server_context::get_meta() const {
auto bos_token_str = bos_id != LLAMA_TOKEN_NULL ? common_token_to_piece(impl->ctx_tgt, bos_id, true) : "";
auto eos_token_str = eos_id != LLAMA_TOKEN_NULL ? common_token_to_piece(impl->ctx_tgt, eos_id, true) : "";
const char * ftype_name = llama_ftype_name(llama_model_ftype(impl->model_tgt));
return server_context_meta {
/* build_info */ std::string(llama_build_info()),
/* model_name */ impl->model_name,
@@ -4023,6 +4025,7 @@ server_context_meta server_context::get_meta() const {
/* model_n_embd_inp */ llama_model_n_embd(impl->model_tgt),
/* model_n_params */ llama_model_n_params(impl->model_tgt),
/* model_size */ llama_model_size(impl->model_tgt),
/* model_ftype */ ftype_name,
};
}
@@ -5118,6 +5121,7 @@ json server_routes::get_model_info() const {
{"n_embd", meta->model_n_embd_inp},
{"n_params", meta->model_n_params},
{"size", meta->model_size},
{"ftype", meta->model_ftype},
}},
};
}
+1
View File
@@ -50,6 +50,7 @@ struct server_context_meta {
int32_t model_n_embd_inp;
uint64_t model_n_params;
uint64_t model_size;
std::string model_ftype;
};
enum server_state {
@@ -18,7 +18,7 @@
let mcpSearchQuery = $state('');
let allMcpServers = $derived(mcpStore.getServersSorted());
let mcpServers = $derived(allMcpServers.filter((s) => s.enabled));
let mcpServers = $derived(mcpStore.visibleMcpServers);
let hasMcpServers = $derived(mcpServers.length > 0);
// let hasAnyMcpServers = $derived(allMcpServers.length > 0);
let filteredMcpServers = $derived.by(() => {
@@ -74,9 +74,7 @@
const sheetItemRowClass =
'flex w-full items-center justify-between gap-2 rounded-md px-3 py-2 text-left text-sm transition-colors hover:bg-accent';
function getEnabledMcpServers() {
return mcpStore.getServersSorted().filter((s) => s.enabled);
}
let visibleMcpServers = $derived(mcpStore.visibleMcpServers);
</script>
<div class="flex items-center gap-1 {className}">
@@ -153,13 +151,13 @@
<span class="flex-1">MCP Servers</span>
<span class="text-xs text-muted-foreground">
{getEnabledMcpServers().length} server{getEnabledMcpServers().length !== 1 ? 's' : ''}
{visibleMcpServers.length} server{visibleMcpServers.length !== 1 ? 's' : ''}
</span>
</Collapsible.Trigger>
<Collapsible.Content>
<div class="flex flex-col gap-0.5 pl-4">
{#each getEnabledMcpServers() as server (server.id)}
{#each visibleMcpServers as server (server.id)}
{@const healthState = mcpStore.getHealthCheckState(server.id)}
{@const hasError = healthState.status === HealthCheckStatus.ERROR}
{@const displayName = mcpStore.getServerLabel(server)}
@@ -202,7 +200,7 @@
</button>
{/each}
{#if getEnabledMcpServers().length === 0}
{#if visibleMcpServers.length === 0}
<div class="px-3 py-2 text-center text-sm text-muted-foreground">
No MCP servers configured
</div>
@@ -1,8 +1,9 @@
<script lang="ts">
import { ChevronDown, ShieldQuestion } from '@lucide/svelte';
import { ChatMessageActionCard } from '$lib/components/app';
import { Button } from '$lib/components/ui/button';
import { Button, buttonVariants } from '$lib/components/ui/button';
import * as ButtonGroup from '$lib/components/ui/button-group';
import { cn } from '$lib/components/ui/utils';
import * as DropdownMenu from '$lib/components/ui/dropdown-menu';
import { ToolSource, ToolPermissionDecision } from '$lib/enums';
import { TOOL_SERVER_LABELS } from '$lib/constants';
@@ -19,25 +20,17 @@
<ChatMessageActionCard icon={ShieldQuestion}>
{#snippet message()}
Allow use of
<span class="font-semibold">{toolName}</span>
{#if serverLabel}
from <span class="font-semibold">{serverLabel}</span>
{/if}
?
Allow use of <span class="font-semibold">{toolName}</span>{#if serverLabel}
from <span class="font-semibold">{serverLabel}</span>{/if}?
{/snippet}
{#snippet actions()}
<DropdownMenu.Root>
<ButtonGroup.Root
class="overflow-hidden rounded-md bg-foreground text-white shadow-sm dark:bg-secondary dark:text-foreground"
>
<ButtonGroup.Root class="overflow-hidden rounded-md shadow-sm">
<Button
class="rounded-none! shadow-none!"
variant="secondary"
size="sm"
class="!rounded-r-none !shadow-none"
onclick={() => onDecision(ToolPermissionDecision.ONCE)}
>
Allow once
@@ -45,10 +38,14 @@
<ButtonGroup.Separator />
<DropdownMenu.Trigger>
<Button size="sm" class="rounded-none! !ps-2 shadow-none!">
<ChevronDown class="h-3.5 w-3.5" />
</Button>
<DropdownMenu.Trigger
class={cn(
buttonVariants({ variant: 'secondary', size: 'sm' }),
'inline-flex cursor-pointer items-center !rounded-l-none !shadow-none !px-2'
)}
aria-label="More allow options"
>
<ChevronDown class="h-3.5 w-3.5" />
</DropdownMenu.Trigger>
</ButtonGroup.Root>
@@ -76,12 +73,7 @@
</DropdownMenu.Content>
</DropdownMenu.Root>
<Button
variant="destructive"
size="sm"
class="text-destructive hover:text-destructive"
onclick={() => onDecision(ToolPermissionDecision.DENY)}
>
<Button variant="destructive" size="sm" onclick={() => onDecision(ToolPermissionDecision.DENY)}>
Deny
</Button>
{/snippet}
@@ -4,7 +4,7 @@
import { McpServerForm } from '$lib/components/app/mcp';
import { mcpStore } from '$lib/stores/mcp.svelte';
import { conversationsStore } from '$lib/stores/conversations.svelte';
import { uuid } from '$lib/utils';
import { parseHeadersToArray, uuid } from '$lib/utils';
import { MCP_SERVER_ID_PREFIX } from '$lib/constants';
interface Props {
@@ -26,6 +26,10 @@
return 'Invalid URL format';
}
});
let newServerHeaderPairsValid = $derived(
parseHeadersToArray(newServerHeaders).every((p) => p.key.trim() && p.value.trim())
);
let canSave = $derived(!newServerUrlError && newServerHeaderPairsValid);
function handleOpenChange(value: boolean) {
if (!value) {
@@ -37,7 +41,7 @@
}
function saveNewServer() {
if (newServerUrlError) return;
if (!canSave) return;
const newServerId = uuid() ?? `${MCP_SERVER_ID_PREFIX}-${Date.now()}`;
@@ -52,6 +56,11 @@
handleOpenChange(false);
}
function handleSubmit(event: SubmitEvent) {
event.preventDefault();
saveNewServer();
}
</script>
<Dialog.Root {open} onOpenChange={handleOpenChange}>
@@ -60,29 +69,27 @@
<Dialog.Title>Add New Server</Dialog.Title>
</Dialog.Header>
<div class="space-y-4 py-4">
<McpServerForm
url={newServerUrl}
headers={newServerHeaders}
onUrlChange={(v) => (newServerUrl = v)}
onHeadersChange={(v) => (newServerHeaders = v)}
urlError={newServerUrl ? newServerUrlError : null}
id="new-server"
/>
</div>
<form onsubmit={handleSubmit} class="contents">
<div class="space-y-4 py-4">
<McpServerForm
url={newServerUrl}
headers={newServerHeaders}
onUrlChange={(v) => (newServerUrl = v)}
onHeadersChange={(v) => (newServerHeaders = v)}
urlError={newServerUrl ? newServerUrlError : null}
id="new-server"
/>
</div>
<Dialog.Footer>
<Button variant="secondary" size="sm" onclick={() => handleOpenChange(false)}>Cancel</Button>
<Dialog.Footer>
<Button variant="secondary" size="sm" onclick={() => handleOpenChange(false)}>
Cancel
</Button>
<Button
variant="default"
size="sm"
onclick={saveNewServer}
disabled={!!newServerUrlError}
aria-label="Save"
>
Add
</Button>
</Dialog.Footer>
<Button variant="default" size="sm" type="submit" disabled={!canSave} aria-label="Save">
Add
</Button>
</Dialog.Footer>
</form>
</Dialog.Content>
</Dialog.Root>
@@ -0,0 +1,180 @@
<script lang="ts">
import { Button } from '$lib/components/ui/button';
import * as Card from '$lib/components/ui/card';
import * as Dialog from '$lib/components/ui/dialog';
import { fly } from 'svelte/transition';
import { McpServerCardCompact, McpServerForm } from '$lib/components/app/mcp';
import { RECOMMENDED_MCP_SERVERS } from '$lib/constants';
import { conversationsStore } from '$lib/stores/conversations.svelte';
import { mcpStore } from '$lib/stores/mcp.svelte';
import { uuid } from '$lib/utils';
import { MCP_SERVERS_ADDED_TO_CHAT_LOCALSTORAGE_KEY, MCP_SERVER_ID_PREFIX } from '$lib/constants';
import type { MCPServerSettingsEntry } from '$lib/types';
import { Plus } from '@lucide/svelte';
interface Props {
open: boolean;
onOpenChange?: (open: boolean) => void;
}
let { open = $bindable(), onOpenChange }: Props = $props();
let selected = $state<Record<string, boolean>>(
Object.fromEntries(RECOMMENDED_MCP_SERVERS.map((server) => [server.id, false]))
);
let addedServers = $state<MCPServerSettingsEntry[]>([]);
let showAddForm = $state(false);
let newServerUrl = $state('');
let newServerHeaders = $state('');
let newServerUrlError = $derived.by(() => {
if (!newServerUrl.trim()) return 'URL is required';
try {
new URL(newServerUrl);
return null;
} catch {
return 'Invalid URL format';
}
});
function handleOpenChange(value: boolean) {
if (!value) {
showAddForm = false;
newServerUrl = '';
newServerHeaders = '';
addedServers = [];
localStorage.setItem(MCP_SERVERS_ADDED_TO_CHAT_LOCALSTORAGE_KEY, 'true');
}
open = value;
onOpenChange?.(value);
}
function resetAddForm() {
showAddForm = false;
newServerUrl = '';
newServerHeaders = '';
}
function enableSelected() {
localStorage.setItem(MCP_SERVERS_ADDED_TO_CHAT_LOCALSTORAGE_KEY, 'true');
for (const server of RECOMMENDED_MCP_SERVERS) {
if (selected[server.id]) {
const existing = mcpStore.getServerById(server.id);
if (existing) {
mcpStore.updateServer(server.id, { enabled: true });
} else {
mcpStore.addServer({
id: server.id,
enabled: true,
url: server.url,
name: server.name
});
}
conversationsStore.setMcpServerOverride(server.id, true);
}
}
handleOpenChange(false);
}
function saveNewServer() {
if (newServerUrlError) return;
const newServerId = uuid() ?? `${MCP_SERVER_ID_PREFIX}-${Date.now()}`;
localStorage.setItem(MCP_SERVERS_ADDED_TO_CHAT_LOCALSTORAGE_KEY, 'true');
const newServer = mcpStore.addServer({
id: newServerId,
enabled: true,
url: newServerUrl.trim(),
headers: newServerHeaders.trim() || undefined
});
conversationsStore.setMcpServerOverride(newServerId, true);
if (newServer) {
addedServers = [...addedServers, newServer];
}
resetAddForm();
}
</script>
<Dialog.Root bind:open onOpenChange={handleOpenChange}>
<Dialog.Content class="sm:max-w-lg">
<Dialog.Header>
<Dialog.Title>Do more with MCP</Dialog.Title>
<Dialog.Description>
Power-up your experience by adding tools, resources and more capabilities provided by MCP
servers.
</Dialog.Description>
</Dialog.Header>
<div class="max-h-[60vh] space-y-4 overflow-y-auto py-4" in:fly={{ y: 16, duration: 300 }}>
<h3 class="text-sm font-semibold">Quickly get started with</h3>
{#each RECOMMENDED_MCP_SERVERS as server (server.id)}
<McpServerCardCompact
{server}
enabled={selected[server.id]}
onToggle={(enabled) => (selected[server.id] = enabled)}
/>
{/each}
{#if addedServers.length > 0}
{#each addedServers as server (server.id)}
<McpServerCardCompact {server} enabled={true} />
{/each}
{/if}
{#if showAddForm}
<Card.Root class="gap-3! bg-muted/30 p-4">
<McpServerForm
url={newServerUrl}
headers={newServerHeaders}
onUrlChange={(v) => (newServerUrl = v)}
onHeadersChange={(v) => (newServerHeaders = v)}
urlError={newServerUrl ? newServerUrlError : null}
id="recommendation-new-server"
/>
<div class="flex justify-end gap-2 pt-2">
<Button variant="secondary" size="sm" onclick={resetAddForm}>Cancel</Button>
<Button
variant="default"
size="sm"
onclick={saveNewServer}
disabled={!!newServerUrlError}
aria-label="Save"
>
Add
</Button>
</div>
</Card.Root>
{:else}
<Card.Root class="gap-0 border-dashed bg-muted/30 p-0 transition-colors hover:bg-muted/50">
<button
type="button"
class="flex w-full items-center justify-center gap-2 rounded-lg p-6 text-sm text-muted-foreground transition-colors hover:text-foreground"
onclick={() => (showAddForm = true)}
aria-label="Add your own MCP server"
>
<Plus class="h-4 w-4" />
<span>Add your own server</span>
</button>
</Card.Root>
{/if}
</div>
<Dialog.Footer>
<Button variant="secondary" size="sm" onclick={() => handleOpenChange(false)}>Not now</Button>
<Button variant="default" size="sm" onclick={enableSelected}>Add selected</Button>
</Dialog.Footer>
</Dialog.Content>
</Dialog.Root>
@@ -18,6 +18,15 @@
*/
export { default as DialogMcpServerAddNew } from './DialogMcpServerAddNew.svelte';
/**
* **DialogMcpServerRecommendations** - Suggested MCP servers opt-in dialog
*
* Prompts the user to enable pre-defined recommended MCP servers on first launch.
* Shows one switch per suggested server and persists the choice as a per-chat
* override so the selected servers become available in conversations.
*/
export { default as DialogMcpServerRecommendations } from './DialogMcpServerRecommendations.svelte';
/**
* **DialogExportSettings** - Settings export dialog with sensitive data warning
*
@@ -1,4 +1,5 @@
<script lang="ts">
import { tick } from 'svelte';
import { Plus, Trash2 } from '@lucide/svelte';
import { Input } from '$lib/components/ui/input';
import {
@@ -33,8 +34,18 @@
sectionLabelOptional = true
}: Props = $props();
function addPair() {
// Pre-allocate the ref array so `bind:ref={keyInputRefs[index]}` never reads `undefined`
// for in-range indices; the $effect below keeps it in sync when `pairs` grows.
// svelte-ignore state_referenced_locally
let keyInputRefs: (HTMLInputElement | null)[] = $state(pairs.map(() => null));
async function addPair() {
// Capture the target index before mutating so deletions earlier in the
// list can't make keyInputRefs.length drift past the newly-appended row.
const newIndex = pairs.length;
onPairsChange([...pairs, { key: '', value: '' }]);
await tick();
keyInputRefs[newIndex]?.focus();
}
function removePair(index: number) {
@@ -76,6 +87,15 @@
newPairs[index] = { ...newPairs[index], value: trimmed };
onPairsChange(newPairs);
}
// Keep keyInputRefs aligned with pairs length so bind:ref never sees `undefined`.
// $effect.pre runs during traversal in tree order, before the {#each} block re-renders,
// so newly-appended items always have a defined slot when their binding is set up.
$effect.pre(() => {
while (keyInputRefs.length < pairs.length) {
keyInputRefs.push(null);
}
});
</script>
<div class={className}>
@@ -103,6 +123,7 @@
{#each pairs as pair, index (index)}
<div class="flex items-start gap-2">
<Input
bind:ref={keyInputRefs[index]}
type="text"
placeholder={keyPlaceholder}
value={pair.key}
@@ -163,7 +163,7 @@
{/if}
</div>
<div class="flex justify-between gap-4">
<div class="mt-auto flex justify-between gap-4">
{#if showSkeleton}
<Skeleton class="h-3 w-28" />
{:else if protocolVersion}
@@ -0,0 +1,156 @@
<script lang="ts">
import * as Card from '$lib/components/ui/card';
import { Badge } from '$lib/components/ui/badge';
import { Skeleton } from '$lib/components/ui/skeleton';
import { Switch } from '$lib/components/ui/switch';
import * as Tooltip from '$lib/components/ui/tooltip';
import { McpServerIdentity } from '$lib/components/app/mcp';
import { mcpStore } from '$lib/stores/mcp.svelte';
import { HealthCheckStatus } from '$lib/enums';
import type { MCPServerDisplayInfo, HealthCheckState, MCPServerSettingsEntry } from '$lib/types';
import { onMount } from 'svelte';
import { MCP_CARD_VISIBLE_TOOL_LIMIT, NEWLINE } from '$lib/constants';
interface Props {
server: MCPServerDisplayInfo & { description?: string };
enabled?: boolean;
onToggle?: (enabled: boolean) => void;
}
let { server, enabled = false, onToggle }: Props = $props();
onMount(() => {
const state = mcpStore.getHealthCheckState(server.id);
if (state.status === HealthCheckStatus.IDLE) {
mcpStore.runHealthCheck(server as MCPServerSettingsEntry).catch(() => {});
}
});
let healthState = $derived<HealthCheckState>(mcpStore.getHealthCheckState(server.id));
let displayName = $derived(mcpStore.getServerLabel(server));
let faviconUrl = $derived(mcpStore.getServerFavicon(server.id));
let isIdle = $derived(healthState.status === HealthCheckStatus.IDLE);
let isHealthChecking = $derived(healthState.status === HealthCheckStatus.CONNECTING);
let isError = $derived(healthState.status === HealthCheckStatus.ERROR);
let errorMessage = $derived(
healthState.status === HealthCheckStatus.ERROR ? healthState.message : undefined
);
let serverInfo = $derived(
healthState.status === HealthCheckStatus.SUCCESS ? healthState.serverInfo : undefined
);
let tools = $derived(healthState.status === HealthCheckStatus.SUCCESS ? healthState.tools : []);
let instructions = $derived(
healthState.status === HealthCheckStatus.SUCCESS ? healthState.instructions : undefined
);
let showSkeleton = $derived(isIdle || isHealthChecking);
// Curated descriptions get two lines; instructions fallback is one line so the
// compact card stays scannable.
let description = $derived.by(() => {
if (server.description) {
return { text: server.description, lines: 2 };
}
if (!instructions) return null;
const firstLine = instructions.split(NEWLINE).find((line: string) => line.trim().length > 0);
const trimmed = firstLine?.trim();
return trimmed ? { text: trimmed, lines: 1 } : null;
});
let visibleTools = $derived(tools.slice(0, MCP_CARD_VISIBLE_TOOL_LIMIT));
let hiddenTools = $derived(tools.slice(MCP_CARD_VISIBLE_TOOL_LIMIT));
let hiddenToolCount = $derived(hiddenTools.length);
function handleToggle(checked: boolean) {
onToggle?.(checked);
}
</script>
<Card.Root class="!gap-3 bg-muted/30 p-4">
<div class="flex items-start justify-between gap-3">
<div class="min-w-0 flex-1">
{#if showSkeleton}
<span class="flex min-w-0 items-center gap-1.5">
<Skeleton class="h-5 w-5 rounded" />
<Skeleton class="h-4 w-32" />
</span>
{:else}
<McpServerIdentity
{displayName}
{faviconUrl}
{serverInfo}
iconClass="h-5 w-5"
iconRounded="rounded"
nameClass="font-medium"
/>
{/if}
</div>
<Switch checked={enabled} disabled={isError || showSkeleton} onCheckedChange={handleToggle} />
</div>
{#if isError && errorMessage}
<p class="text-xs text-destructive">{errorMessage}</p>
{/if}
{#if showSkeleton}
<div class="space-y-1.5">
<Skeleton class="h-3 w-full max-w-md" />
</div>
<div class="flex flex-wrap items-center gap-1.5">
<Skeleton class="h-5 w-16 rounded-full" />
<Skeleton class="h-5 w-20 rounded-full" />
<Skeleton class="h-5 w-24 rounded-full" />
<Skeleton class="h-5 w-14 rounded-full" />
</div>
{:else}
{#if description}
{#if description.lines === 2}
<p class="line-clamp-2 text-xs text-muted-foreground" title={description.text}>
{description.text}
</p>
{:else}
<p class="line-clamp-1 truncate text-xs text-muted-foreground" title={description.text}>
{description.text}
</p>
{/if}
{/if}
{#if tools.length > 0}
<div class="flex flex-wrap items-center gap-1.5">
{#each visibleTools as tool (tool.name)}
<Tooltip.Root>
<Tooltip.Trigger>
<Badge variant="secondary" class="h-5 max-w-40 px-2 text-[11px]">
<span class="block min-w-0 flex-1 truncate">{tool.name}</span>
</Badge>
</Tooltip.Trigger>
<Tooltip.Content>
<p class="max-w-xs text-xs">
{tool.description ?? 'No description'}
</p>
</Tooltip.Content>
</Tooltip.Root>
{/each}
{#if hiddenToolCount > 0}
<Tooltip.Root>
<Tooltip.Trigger>
<Badge variant="secondary" class="h-5 px-2 text-[11px] text-muted-foreground">
+ {hiddenToolCount} more tools
</Badge>
</Tooltip.Trigger>
<Tooltip.Content class="max-w-md">
<p class="text-xs">
{hiddenTools.map((tool) => tool.name).join(', ')}
</p>
</Tooltip.Content>
</Tooltip.Root>
{/if}
</div>
{/if}
{/if}
</Card.Root>
@@ -1,6 +1,7 @@
<script lang="ts">
import { Button } from '$lib/components/ui/button';
import { McpServerForm } from '$lib/components/app/mcp';
import { parseHeadersToArray } from '$lib/utils';
interface Props {
serverId: string;
@@ -26,13 +27,21 @@
}
});
let canSave = $derived(!urlError);
let headerPairsValid = $derived(
parseHeadersToArray(editHeaders).every((p) => p.key.trim() && p.value.trim())
);
let canSave = $derived(!urlError && headerPairsValid);
function handleSave() {
if (!canSave) return;
onSave(editUrl.trim(), editHeaders.trim(), editUseProxy);
}
function handleSubmit(event: SubmitEvent) {
event.preventDefault();
handleSave();
}
export function setInitialValues(url: string, headers: string, useProxy: boolean) {
editUrl = url;
editHeaders = headers;
@@ -40,25 +49,27 @@
}
</script>
<div class="space-y-4">
<p class="font-medium">Configure Server</p>
<form onsubmit={handleSubmit} class="contents">
<div class="space-y-4">
<p class="font-medium">Configure Server</p>
<McpServerForm
url={editUrl}
headers={editHeaders}
useProxy={editUseProxy}
onUrlChange={(v) => (editUrl = v)}
onHeadersChange={(v) => (editHeaders = v)}
onUseProxyChange={(v) => (editUseProxy = v)}
urlError={editUrl ? urlError : null}
id={serverId}
/>
<McpServerForm
url={editUrl}
headers={editHeaders}
useProxy={editUseProxy}
onUrlChange={(v) => (editUrl = v)}
onHeadersChange={(v) => (editHeaders = v)}
onUseProxyChange={(v) => (editUseProxy = v)}
urlError={editUrl ? urlError : null}
id={serverId}
/>
<div class="flex items-center justify-end gap-2">
<Button variant="secondary" size="sm" onclick={onCancel}>Cancel</Button>
<div class="flex items-center justify-end gap-2">
<Button variant="secondary" size="sm" onclick={onCancel}>Cancel</Button>
<Button size="sm" onclick={handleSave} disabled={!canSave}>
{serverUrl.trim() ? 'Update' : 'Add'}
</Button>
<Button size="sm" type="submit" disabled={!canSave}>
{serverUrl.trim() ? 'Update' : 'Add'}
</Button>
</div>
</div>
</div>
</form>
@@ -38,14 +38,87 @@
let headerPairs = $derived<KeyValuePair[]>(parseHeadersToArray(headers));
const AUTHORIZATION_HEADER = 'Authorization';
const BEARER_PREFIX = 'Bearer ';
// Heuristic: this dedicated UI only owns Authorization headers that already
// carry a Bearer scheme. Anything else (e.g. Basic, raw tokens) stays in the
// KV section so the user can still edit those values verbatim.
const matchesAuthorizationKey = (key: string): boolean =>
key.trim().toLowerCase() === AUTHORIZATION_HEADER.toLowerCase();
const isBearerScheme = (value: string): boolean =>
value.trim().toLowerCase().startsWith(BEARER_PREFIX.toLowerCase());
const ownedByBearerUi = (p: KeyValuePair): boolean =>
matchesAuthorizationKey(p.key) && isBearerScheme(p.value);
let hasAuthorization = $derived(headerPairs.some(ownedByBearerUi));
let wantsAuthorization = $state(false);
let showAuthorization = $derived(hasAuthorization || wantsAuthorization);
let urlInput: HTMLInputElement | null = $state(null);
let bearerInput: HTMLInputElement | null = $state(null);
$effect(() => {
urlInput?.focus();
});
$effect(() => {
if (wantsAuthorization && bearerInput) {
bearerInput.focus();
}
});
let bearerToken = $derived.by(() => {
const auth = headerPairs.find(ownedByBearerUi);
if (!auth) return '';
return auth.value.trim().slice(BEARER_PREFIX.length).trim();
});
$effect(() => {
if (!headers.trim()) {
wantsAuthorization = false;
}
});
function updateHeaderPairs(newPairs: KeyValuePair[]) {
headerPairs = newPairs;
onHeadersChange(serializeHeaders(newPairs));
}
// The dedicated UI owns the Authorization slot end-to-end when the user
// engages it: any prior Authorization row (Bearer or otherwise) is replaced
// by exactly one { Authorization: "Bearer <token>" } entry. JSON's last-key
// behavior would otherwise pick one arbitrarily, so we strip first.
function updateBearerToken(token: string) {
const filtered = headerPairs.filter((p) => !matchesAuthorizationKey(p.key));
const trimmed = token.trim();
if (trimmed) {
filtered.push({ key: AUTHORIZATION_HEADER, value: `${BEARER_PREFIX}${trimmed}` });
}
updateHeaderPairs(filtered);
}
function setUseAuthorization(checked: boolean) {
wantsAuthorization = checked;
if (!checked) {
// Only drop the entry this UI owns; a non-Bearer Authorization row
// authored in the KV section must survive a toggle off untouched.
const filtered = headerPairs.filter((p) => !ownedByBearerUi(p));
updateHeaderPairs(filtered);
}
}
</script>
<div class="grid gap-3">
<div>
<div class="grid gap-2">
<div class="mb-4">
<label for="server-url-{id}" class="mb-2 block text-xs font-medium">
Server URL <span class="text-destructive">*</span>
</label>
@@ -57,50 +130,52 @@
value={url}
oninput={(e) => onUrlChange(e.currentTarget.value)}
class={urlError ? 'border-destructive' : ''}
bind:ref={urlInput}
/>
{#if urlError}
<p class="mt-1.5 text-xs text-destructive">{urlError}</p>
{/if}
{#if !isWebSocket && onUseProxyChange}
<label
class={[
'mt-3 flex items-start gap-2',
mcpStore.isProxyAvailable && 'cursor-pointer',
!mcpStore.isProxyAvailable && 'opacity-80'
]}
>
<Switch
class="mt-1"
id="use-proxy-{id}"
checked={useProxy}
disabled={!mcpStore.isProxyAvailable}
onCheckedChange={(checked) => onUseProxyChange?.(checked)}
/>
<span>
<span class="text-xs text-muted-foreground">Use llama-server proxy</span>
<br />
{#if !mcpStore.isProxyAvailable}
<span class="inline-flex gap-0.75 text-xs text-muted-foreground/60"
>(Run <pre>llama-server</pre>
with
<pre>{CLI_FLAGS.MCP_PROXY}</pre>
flag)</span
>
{/if}
</span>
</label>
{/if}
</div>
<label class="flex items-center gap-2 cursor-pointer">
<Switch
id="use-authorization-{id}"
checked={showAuthorization}
onCheckedChange={setUseAuthorization}
/>
<span class="text-xs text-muted-foreground">Authorization</span>
</label>
{#if showAuthorization}
<div class="relative mt-2">
<Input
id="bearer-token-{id}"
type="password"
autocomplete="off"
placeholder="Paste token here"
value={bearerToken}
oninput={(e) => updateBearerToken(e.currentTarget.value)}
class="pl-16"
bind:ref={bearerInput}
/>
<span
class="pointer-events-none absolute inset-y-0 left-3 flex items-center text-sm font-medium text-foreground"
>
Bearer
</span>
</div>
{/if}
<KeyValuePairs
class="mt-2"
pairs={headerPairs}
onPairsChange={updateHeaderPairs}
class="mt-3"
pairs={headerPairs.filter((p) => !ownedByBearerUi(p))}
onPairsChange={(pairs) => {
const auth = headerPairs.find(ownedByBearerUi);
updateHeaderPairs(auth ? [...pairs, auth] : pairs);
}}
keyPlaceholder="Header name"
valuePlaceholder="Value"
addButtonLabel="Add"
@@ -108,4 +183,37 @@
sectionLabel="Custom Headers"
sectionLabelOptional
/>
{#if !isWebSocket && onUseProxyChange}
<label
class={[
'mt-3 flex items-start gap-2',
mcpStore.isProxyAvailable && 'cursor-pointer',
!mcpStore.isProxyAvailable && 'opacity-80'
]}
>
<Switch
class="mt-1"
id="use-proxy-{id}"
checked={useProxy}
disabled={!mcpStore.isProxyAvailable}
onCheckedChange={(checked) => onUseProxyChange?.(checked)}
/>
<span>
<span class="text-xs text-muted-foreground">Use llama-server proxy</span>
<br />
{#if !mcpStore.isProxyAvailable}
<span class="inline-flex gap-0.75 text-xs text-muted-foreground/60"
>(Run <pre>llama-server</pre>
with
<pre>{CLI_FLAGS.MCP_PROXY}</pre>
flag)</span
>
{/if}
</span>
</label>
{/if}
</div>
@@ -1,6 +1,7 @@
<script lang="ts">
import { ExternalLink } from '@lucide/svelte';
import { Badge } from '$lib/components/ui/badge';
import { McpLogo } from '$lib/components/app/mcp';
import { TruncatedText } from '$lib/components/app/misc';
import { sanitizeExternalUrl } from '$lib/utils';
import type { MCPServerInfo } from '$lib/types';
@@ -34,20 +35,15 @@
<span class="flex min-w-0 items-center gap-1.5">
{#if faviconUrl}
<img
src={faviconUrl}
alt=""
class={['shrink-0', iconRounded, iconClass]}
onerror={(e) => {
(e.currentTarget as HTMLImageElement).style.display = 'none';
}}
/>
<img src={faviconUrl} alt="" class={['shrink-0 text-foreground', iconRounded, iconClass]} />
{:else}
<McpLogo class={['shrink-0 text-foreground', iconRounded, iconClass].join(' ')} />
{/if}
<TruncatedText text={displayName ?? ''} class={nameClass ?? ''} />
{#if showVersion && serverInfo?.version}
<Badge variant="secondary" class="h-4 min-w-0 shrink px-1 text-[10px]">
<Badge variant="secondary" class="h-4 max-w-24 min-w-0 shrink px-1 text-[10px]">
<TruncatedText text={`v${serverInfo.version}`} />
</Badge>
{/if}
@@ -180,6 +180,16 @@ export { default as McpServerCardDeleteDialog } from './McpServerCard/McpServerC
/** Skeleton loading state for server card during health checks. */
export { default as McpServerCardSkeleton } from './McpServerCardSkeleton.svelte';
/**
* **McpServerCardCompact** - Condensed MCP server card
*
* Compact alternative to McpServerCard tailored for picker-style UIs.
* Shows the server identity, status, and a flex-wrapped list of available tools.
* Tool names are rendered as badges; hovering a badge shows its description in a tooltip.
* Does not show connection logs or server instructions.
*/
export { default as McpServerCardCompact } from './McpServerCard/McpServerCardCompact.svelte';
/**
* **McpServerIdentity** - Server identity display (icon, name, version)
*
@@ -21,7 +21,7 @@
let { class: className }: Props = $props();
let servers = $derived(mcpStore.getServersSorted());
let servers = $derived(mcpStore.visibleMcpServers);
let initialLoadComplete = $state(false);
let isAddingServer = $state(false);
+1
View File
@@ -8,6 +8,7 @@ export * from './attachment-labels';
export * from './database';
export * from './reasoning-effort';
export * from './reasoning-effort-tokens';
export * from './recommended-mcp-servers';
export * from './storage';
export * from './attachment-menu';
export * from './auto-scroll';
+2
View File
@@ -1,2 +1,4 @@
export const MCP_SERVER_URL_PLACEHOLDER = 'https://mcp.example.com/sse';
export const MIN_AUTOCOMPLETE_INPUT_LENGTH = 1;
/** Number of tools shown on the compact MCP server card before collapsing to a "+ N more" badge */
export const MCP_CARD_VISIBLE_TOOL_LIMIT = 4;
@@ -0,0 +1,35 @@
import { DEFAULT_MCP_CONFIG } from './mcp';
import type { RecommendedMCPServer } from '$lib/types';
/**
* Pre-defined recommended MCP servers.
*
* Servers are enabled by default, but they are not turned on for individual
* conversations until the user explicitly enables them (so their tools are
* disabled by default).
*/
export const RECOMMENDED_MCP_SERVERS: RecommendedMCPServer[] = [
{
id: 'exa-web-search',
name: 'Exa Web Search',
description: 'Search the web and retrieve relevant content.',
url: 'https://mcp.exa.ai/mcp',
enabled: true,
requestTimeoutSeconds: DEFAULT_MCP_CONFIG.requestTimeoutSeconds
},
{
id: 'huggingface-mcp',
name: 'Hugging Face',
description:
'Browse models, datasets, spaces and machine learning papers from the Hugging Face hub.',
url: 'https://huggingface.co/mcp',
enabled: true,
requestTimeoutSeconds: DEFAULT_MCP_CONFIG.requestTimeoutSeconds
}
];
export const RECOMMENDED_MCP_SERVER_IDS = new Set(
RECOMMENDED_MCP_SERVERS.map((server) => server.id)
);
export const RECOMMENDED_MCP_SERVERS_OPTIN_DIALOG_DELAY = 1000;
@@ -59,6 +59,7 @@ export const SETTINGS_KEYS = {
// MCP
MCP_SERVERS: 'mcpServers',
MCP_REQUEST_TIMEOUT_SECONDS: 'mcpRequestTimeoutSeconds',
MCP_DEFAULT_SERVER_OVERRIDES: 'mcpDefaultServerOverrides',
AGENTIC_MAX_TURNS: 'agenticMaxTurns',
ALWAYS_SHOW_AGENTIC_TURNS: 'alwaysShowAgenticTurns',
AGENTIC_MAX_TOOL_PREVIEW_LINES: 'agenticMaxToolPreviewLines',
@@ -28,6 +28,7 @@ import McpLogo from '$lib/components/app/mcp/McpLogo.svelte';
import { SETTINGS_KEYS } from './settings-keys';
import { ROUTES, SETTINGS_SECTION_SLUGS } from './routes';
import { TITLE_GENERATION } from './title-generation';
import { RECOMMENDED_MCP_SERVERS } from './recommended-mcp-servers';
export const SETTINGS_SECTION_TITLES = {
GENERAL: 'General',
@@ -774,9 +775,16 @@ const NON_UI_SETTINGS: SettingsEntry[] = [
key: SETTINGS_KEYS.MCP_SERVERS,
label: 'MCP servers',
help: 'Configure MCP servers as a JSON list. Use the form in the MCP Client settings section to edit.',
defaultValue: '[]',
defaultValue: JSON.stringify(RECOMMENDED_MCP_SERVERS),
type: SettingsFieldType.INPUT,
sync: { serverKey: SETTINGS_KEYS.MCP_SERVERS, paramType: SyncableParameterType.STRING }
},
{
key: SETTINGS_KEYS.MCP_DEFAULT_SERVER_OVERRIDES,
label: 'MCP default server overrides',
help: 'Per-server enable/disable defaults inherited by new chats. JSON-serialized list of {serverId, enabled} entries.',
defaultValue: '[]',
type: SettingsFieldType.INPUT
}
// {
// key: SETTINGS_KEYS.PY_INTERPRETER_ENABLED,
+2 -4
View File
@@ -21,9 +21,10 @@ export const DISABLED_TOOLS_LOCALSTORAGE_KEY = `${STORAGE_APP_NAME}.disabledTool
/** Disabled tools keyed by stable selection identity, no migration from the name based key */
export const DISABLED_TOOL_KEYS_LOCALSTORAGE_KEY = `${STORAGE_APP_NAME}.disabledToolKeys`;
export const FAVORITE_MODELS_LOCALSTORAGE_KEY = `${STORAGE_APP_NAME}.favoriteModels`;
export const MCP_DEFAULT_ENABLED_LOCALSTORAGE_KEY = `${STORAGE_APP_NAME}.mcpDefaultEnabled`;
export const THINKING_ENABLED_DEFAULT_LOCALSTORAGE_KEY = `${STORAGE_APP_NAME}.thinkingEnabledDefault`;
export const REASONING_EFFORT_DEFAULT_LOCALSTORAGE_KEY = `${STORAGE_APP_NAME}.reasoningEffortDefault`;
/** Set when user has interacted with the MCP server recommendations dialog (checked servers, added custom server, or dismissed) */
export const MCP_SERVERS_ADDED_TO_CHAT_LOCALSTORAGE_KEY = `${STORAGE_APP_NAME}.mcpServersSetupDone`;
export const USER_OVERRIDES_LOCALSTORAGE_KEY = `${STORAGE_APP_NAME}.userOverrides`;
/** Key prefix for per-conversation resumable stream state, conversationId is appended */
@@ -38,8 +39,6 @@ export const DEPRECATED_CONFIG_LOCALSTORAGE_KEY = `${STORAGE_APP_NAME_DEPRECATED
export const DEPRECATED_DISABLED_TOOLS_LOCALSTORAGE_KEY = `${STORAGE_APP_NAME_DEPRECATED}.disabledTools`;
/** @deprecated Use {@link FAVORITE_MODELS_LOCALSTORAGE_KEY} instead */
export const DEPRECATED_FAVORITE_MODELS_LOCALSTORAGE_KEY = `${STORAGE_APP_NAME_DEPRECATED}.favoriteModels`;
/** @deprecated Use {@link MCP_DEFAULT_ENABLED_LOCALSTORAGE_KEY} instead */
export const DEPRECATED_MCP_DEFAULT_ENABLED_LOCALSTORAGE_KEY = `${STORAGE_APP_NAME_DEPRECATED}.mcpDefaultEnabled`;
/** @deprecated Use {@link USER_OVERRIDES_LOCALSTORAGE_KEY} instead */
export const DEPRECATED_USER_OVERRIDES_LOCALSTORAGE_KEY = `${STORAGE_APP_NAME_DEPRECATED}.userOverrides`;
@@ -52,6 +51,5 @@ export const NEW_TO_DEPRECATED_MAP: Record<string, string> = {
[CONFIG_LOCALSTORAGE_KEY]: DEPRECATED_CONFIG_LOCALSTORAGE_KEY,
[DISABLED_TOOLS_LOCALSTORAGE_KEY]: DEPRECATED_DISABLED_TOOLS_LOCALSTORAGE_KEY,
[FAVORITE_MODELS_LOCALSTORAGE_KEY]: DEPRECATED_FAVORITE_MODELS_LOCALSTORAGE_KEY,
[MCP_DEFAULT_ENABLED_LOCALSTORAGE_KEY]: DEPRECATED_MCP_DEFAULT_ENABLED_LOCALSTORAGE_KEY,
[USER_OVERRIDES_LOCALSTORAGE_KEY]: DEPRECATED_USER_OVERRIDES_LOCALSTORAGE_KEY
};
@@ -0,0 +1,85 @@
import { browser } from '$app/environment';
import {
MCP_SERVERS_ADDED_TO_CHAT_LOCALSTORAGE_KEY,
RECOMMENDED_MCP_SERVER_IDS,
RECOMMENDED_MCP_SERVERS_OPTIN_DIALOG_DELAY
} from '$lib/constants';
import { mcpStore } from '$lib/stores/mcp.svelte';
/**
* First-run opt-in dialog for the recommended MCP servers.
*
* Owns the dismissed / open / trigger-timeout state and the effect that
* schedules the dialog. Reads opt-in status and the configured server list
* from `mcpStore`, so callers don't need to recompute on their side.
*/
export function useMcpRecommendations() {
let dismissed = $state(
browser && localStorage.getItem(MCP_SERVERS_ADDED_TO_CHAT_LOCALSTORAGE_KEY) === 'true'
);
let open = $state(false);
let checked = $state(false);
let triggerTimeout: ReturnType<typeof setTimeout> | null = null;
function dismiss() {
if (browser) {
localStorage.setItem(MCP_SERVERS_ADDED_TO_CHAT_LOCALSTORAGE_KEY, 'true');
}
dismissed = true;
open = false;
if (triggerTimeout) {
clearTimeout(triggerTimeout);
triggerTimeout = null;
}
}
function handleOpenChange(next: boolean) {
open = next;
if (!next) dismiss();
}
$effect(() => {
if (!browser) return;
if (open || dismissed) {
if (triggerTimeout) {
clearTimeout(triggerTimeout);
triggerTimeout = null;
}
return;
}
// Already evaluated once this session; leave any pending trigger alone so
// it can still fire later. Setting `checked = true` below re-runs this
// effect, and we must not wipe the timeout that was just scheduled.
if (checked) return;
if (mcpStore.optedInRecommendationIds.size > 0) {
checked = true;
return;
}
const hasRecommendations = mcpStore
.getServers()
.some((server) => RECOMMENDED_MCP_SERVER_IDS.has(server.id));
if (hasRecommendations) {
triggerTimeout = setTimeout(() => {
open = true;
}, RECOMMENDED_MCP_SERVERS_OPTIN_DIALOG_DELAY);
}
checked = true;
});
return {
get open() {
return open;
},
get dismissed() {
return dismissed;
},
dismiss,
handleOpenChange
};
}
+59 -1
View File
@@ -20,6 +20,7 @@
import Dexie from 'dexie';
import {
STORAGE_APP_NAME,
STORAGE_APP_NAME_DEPRECATED,
DB_APP_NAME_DEPRECATED,
CONFIG_LOCALSTORAGE_KEY,
IDXDB_TABLES,
@@ -494,12 +495,69 @@ const customJsonKeyMigration: Migration = {
}
};
const MCP_DEFAULT_ENABLED_MIGRATION_ID = 'mcp-default-enabled-to-config-v1';
const LEGACY_MCP_DEFAULT_ENABLED_KEY = `${STORAGE_APP_NAME}.mcpDefaultEnabled`;
const DEPRECATED_LEGACY_MCP_DEFAULT_ENABLED_KEY = `${STORAGE_APP_NAME_DEPRECATED}.mcpDefaultEnabled`;
const mcpDefaultEnabledMigration: Migration = {
id: MCP_DEFAULT_ENABLED_MIGRATION_ID,
description:
'Copy mcpDefaultEnabled localStorage key into settings config (preserves legacy keys)',
async run(): Promise<void> {
const raw =
localStorage.getItem(LEGACY_MCP_DEFAULT_ENABLED_KEY) ??
localStorage.getItem(DEPRECATED_LEGACY_MCP_DEFAULT_ENABLED_KEY);
// Legacy keys intentionally left in place so a downgrade keeps reading them.
if (raw === null) {
if (import.meta.env.DEV && import.meta.env.VITE_DEBUG)
console.log('[Migration] MCP default enabled: no legacy key found, skipping');
return;
}
const configRaw = localStorage.getItem(CONFIG_LOCALSTORAGE_KEY);
const config = configRaw ? JSON.parse(configRaw) : {};
// Don't overwrite an existing config entry — current data wins.
if (SETTINGS_KEYS.MCP_DEFAULT_SERVER_OVERRIDES in config) {
if (import.meta.env.DEV && import.meta.env.VITE_DEBUG)
console.log('[Migration] MCP default enabled: config already has overrides, skipping');
return;
}
try {
const parsed = JSON.parse(raw);
if (!Array.isArray(parsed)) return;
const valid = parsed.every(
(o) =>
typeof o === 'object' &&
o !== null &&
typeof (o as Record<string, unknown>).serverId === 'string' &&
typeof (o as Record<string, unknown>).enabled === 'boolean'
);
if (!valid) return;
} catch {
return;
}
config[SETTINGS_KEYS.MCP_DEFAULT_SERVER_OVERRIDES] = raw;
localStorage.setItem(CONFIG_LOCALSTORAGE_KEY, JSON.stringify(config));
if (import.meta.env.DEV && import.meta.env.VITE_DEBUG)
console.log('[Migration] MCP default enabled: moved legacy key into config');
}
};
const migrations: Migration[] = [
localStorageMigration,
idxdbMigration,
legacyMessageMigration,
themeMigration,
customJsonKeyMigration
customJsonKeyMigration,
mcpDefaultEnabledMigration
];
export const MigrationService = {
+29 -8
View File
@@ -1083,6 +1083,11 @@ class ChatStore {
let resolvedModel: string | null = null;
let modelPersisted = false;
const convId = assistantMessage.convId;
// Tracks the last message created in this flow. Used as the parent for the next
// turn's assistant message so createAssistantMessage does not have to read
// conversationsStore.activeMessages, which may belong to a different conversation
// after the user navigates while the loop is still running.
let lastCreatedInFlow = currentMessageId;
// freeze the POST identity from t0 so a stop cancels with the exact session key,
// never a stale or empty model resolved later
this.setChatStreaming(convId, streamedContent, currentMessageId, effectiveModel);
@@ -1208,8 +1213,15 @@ class ChatStore {
};
if (timings) uiUpdate.timings = timings;
if (resolvedModel) uiUpdate.model = resolvedModel;
conversationsStore.updateMessageAtIndex(idx, uiUpdate);
await conversationsStore.updateCurrentNode(currentMessageId);
// touch the active ui array and node pointer only when this conversation
// is displayed; otherwise persist the node move straight to the db so a
// foreign conv's currNode stays untouched
if (conversationsStore.activeConversation?.id === convId) {
conversationsStore.updateMessageAtIndex(idx, uiUpdate);
await conversationsStore.updateCurrentNode(currentMessageId);
} else {
await DatabaseService.updateCurrentNode(convId, currentMessageId);
}
},
createToolResultMessage: async (
toolCallId: string,
@@ -1230,8 +1242,16 @@ class ChatStore {
},
currentMessageId
);
conversationsStore.addMessageToActive(msg);
await conversationsStore.updateCurrentNode(msg.id);
// mirror into the active store and move the node pointer only when this
// conversation is displayed; otherwise persist the node move straight to
// the db for the owning conv so a foreign conv's currNode stays untouched
if (conversationsStore.activeConversation?.id === convId) {
conversationsStore.addMessageToActive(msg);
await conversationsStore.updateCurrentNode(msg.id);
} else {
await DatabaseService.updateCurrentNode(convId, msg.id);
}
lastCreatedInFlow = msg.id;
return msg;
},
createAssistantMessage: async () => {
@@ -1239,8 +1259,6 @@ class ChatStore {
streamedContent = '';
streamedReasoningContent = '';
const lastMsg =
conversationsStore.activeMessages[conversationsStore.activeMessages.length - 1];
const msg = await DatabaseService.createMessageBranch(
{
convId,
@@ -1252,10 +1270,13 @@ class ChatStore {
children: [],
model: resolvedModel
},
lastMsg.id
lastCreatedInFlow
);
conversationsStore.addMessageToActive(msg);
if (conversationsStore.activeConversation?.id === convId) {
conversationsStore.addMessageToActive(msg);
}
currentMessageId = msg.id;
lastCreatedInFlow = msg.id;
return msg;
},
onFlowComplete: (finalTimings?: ChatMessageTimings) => {
@@ -23,7 +23,7 @@ import { browser } from '$app/environment';
import { toast } from 'svelte-sonner';
import { DatabaseService } from '$lib/services/database.service';
import { MigrationService } from '$lib/services/migration.service';
import { config } from '$lib/stores/settings.svelte';
import { config, settingsStore } from '$lib/stores/settings.svelte';
import { filterByLeafNodeId, findLeafNode, generateConversationTitle } from '$lib/utils';
import type { McpServerOverride } from '$lib/types/database';
import { zipSync, unzipSync, strToU8, strFromU8 } from 'fflate';
@@ -46,7 +46,7 @@ import {
ISO_TIME_SEPARATOR_REPLACEMENT,
NON_ALPHANUMERIC_REGEX,
MULTIPLE_UNDERSCORE_REGEX,
MCP_DEFAULT_ENABLED_LOCALSTORAGE_KEY,
SETTINGS_KEYS,
THINKING_ENABLED_DEFAULT_LOCALSTORAGE_KEY,
REASONING_EFFORT_DEFAULT_LOCALSTORAGE_KEY
} from '$lib/constants';
@@ -90,12 +90,10 @@ class ConversationsStore {
/** Global (non-conversation-specific) reasoning effort default */
pendingReasoningEffort = $state<ReasoningEffort>(ConversationsStore.loadReasoningEffortDefault());
/** Load MCP default overrides from localStorage */
private static loadMcpDefaults(): McpServerOverride[] {
if (typeof globalThis.localStorage === 'undefined') return [];
const raw = config()[SETTINGS_KEYS.MCP_DEFAULT_SERVER_OVERRIDES];
if (typeof raw !== 'string' || raw.length === 0) return [];
try {
const raw = localStorage.getItem(MCP_DEFAULT_ENABLED_LOCALSTORAGE_KEY);
if (!raw) return [];
const parsed = JSON.parse(raw);
if (!Array.isArray(parsed)) return [];
return parsed.filter(
@@ -106,18 +104,12 @@ class ConversationsStore {
}
}
/** Persist MCP default overrides to localStorage */
private saveMcpDefaults(): void {
if (typeof globalThis.localStorage === 'undefined') return;
const plain = this.pendingMcpServerOverrides.map((o) => ({
serverId: o.serverId,
enabled: o.enabled
}));
if (plain.length > 0) {
localStorage.setItem(MCP_DEFAULT_ENABLED_LOCALSTORAGE_KEY, JSON.stringify(plain));
} else {
localStorage.removeItem(MCP_DEFAULT_ENABLED_LOCALSTORAGE_KEY);
}
settingsStore.updateConfig(SETTINGS_KEYS.MCP_DEFAULT_SERVER_OVERRIDES, JSON.stringify(plain));
}
/** Load thinking-enabled default from localStorage */
@@ -189,6 +181,10 @@ class ConversationsStore {
try {
await MigrationService.runAllMigrations();
// Re-read defaults after migrations: a migration may have populated
// the settings config (e.g. moved legacy MCP overrides into it).
this.pendingMcpServerOverrides = ConversationsStore.loadMcpDefaults();
await this.loadConversations();
this.isInitialized = true;
} catch (error) {
+36 -4
View File
@@ -20,11 +20,13 @@
*/
import { browser } from '$app/environment';
import { SvelteSet } from 'svelte/reactivity';
import { SETTINGS_KEYS } from '$lib/constants';
import { MCPService } from '$lib/services/mcp.service';
import { config, settingsStore } from '$lib/stores/settings.svelte';
import { mcpResourceStore } from '$lib/stores/mcp-resources.svelte';
import { serverStore } from '$lib/stores/server.svelte';
import { conversationsStore } from '$lib/stores/conversations.svelte';
import { mode } from 'mode-watcher';
import {
parseMcpServerSettings,
@@ -48,10 +50,11 @@ import {
EXPECTED_THEMED_ICON_PAIR_COUNT,
MCP_ALLOWED_ICON_MIME_TYPES,
MCP_SERVER_ID_PREFIX,
MCP_RECONNECT_INITIAL_DELAY,
MCP_RECONNECT_BACKOFF_MULTIPLIER,
MCP_RECONNECT_INITIAL_DELAY,
MCP_RECONNECT_MAX_DELAY,
MCP_RECONNECT_ATTEMPT_TIMEOUT_MS
MCP_RECONNECT_ATTEMPT_TIMEOUT_MS,
RECOMMENDED_MCP_SERVER_IDS
} from '$lib/constants';
import type {
MCPToolCall,
@@ -70,6 +73,7 @@ import type {
Tool,
HealthCheckState,
MCPServerSettingsEntry,
MCPServerDisplayInfo,
MCPServerConfig,
MCPResourceIcon,
MCPResourceAttachment,
@@ -365,7 +369,7 @@ class MCPStore {
return this.connections;
}
getServerLabel(server: MCPServerSettingsEntry): string {
getServerLabel(server: MCPServerDisplayInfo): string {
const healthState = this.getHealthCheckState(server.id);
if (healthState?.status === HealthCheckStatus.SUCCESS)
@@ -527,7 +531,7 @@ class MCPStore {
addServer(
serverData: Omit<MCPServerSettingsEntry, 'id' | 'requestTimeoutSeconds'> & { id?: string }
): void {
): MCPServerSettingsEntry {
const servers = this.getServers();
const newServer: MCPServerSettingsEntry = {
id: serverData.id || (uuid() ?? `server-${Date.now()}`),
@@ -540,6 +544,7 @@ class MCPStore {
useProxy: serverData.useProxy
};
settingsStore.updateConfig(SETTINGS_KEYS.MCP_SERVERS, JSON.stringify([...servers, newServer]));
return newServer;
}
updateServer(id: string, updates: Partial<MCPServerSettingsEntry>): void {
@@ -576,6 +581,33 @@ class MCPStore {
});
}
/**
* Recommended MCP server IDs the user opted in to via per-chat overrides.
* Single source of truth for "which recommendations has the user accepted",
* shared by the recommendations hook and the visible-servers getter.
*/
get optedInRecommendationIds(): ReadonlySet<string> {
const ids = new SvelteSet<string>();
for (const override of conversationsStore.pendingMcpServerOverrides) {
if (RECOMMENDED_MCP_SERVER_IDS.has(override.serverId) && override.enabled) {
ids.add(override.serverId);
}
}
return ids;
}
/**
* MCP servers selectable in chat-add UIs and the settings page:
* enabled in settings and either non-recommended or explicitly opted in.
*/
get visibleMcpServers(): MCPServerSettingsEntry[] {
const optedIn = this.optedInRecommendationIds;
return this.getServersSorted().filter(
(server) =>
server.enabled && (!RECOMMENDED_MCP_SERVER_IDS.has(server.id) || optedIn.has(server.id))
);
}
async ensureInitialized(perChatOverrides?: McpServerOverride[]): Promise<boolean> {
if (!browser) {
return false;
+2
View File
@@ -127,6 +127,8 @@ export type {
MCPServerConfig,
MCPClientConfig,
MCPServerSettingsEntry,
MCPServerDisplayInfo,
RecommendedMCPServer,
MCPToolCall,
OpenAIToolDefinition,
ServerStatus,
+18 -3
View File
@@ -209,17 +209,32 @@ export type MCPToolCall = {
};
};
export type MCPServerSettingsEntry = {
/**
* Minimum fields needed to display or identify an MCP server.
*/
export interface MCPServerDisplayInfo {
id: string;
enabled: boolean;
name?: string;
url: string;
}
export type MCPServerSettingsEntry = MCPServerDisplayInfo & {
enabled: boolean;
requestTimeoutSeconds: number;
headers?: string;
name?: string;
iconUrl?: string;
useProxy?: boolean;
};
/**
* Pre-defined recommended MCP server shown to the user in onboarding/picker UIs.
*/
export interface RecommendedMCPServer extends MCPServerDisplayInfo {
description: string;
enabled: boolean;
requestTimeoutSeconds: number;
}
export interface MCPHostManagerConfig {
servers: MCPClientConfig['servers'];
clientInfo?: Implementation;
+9
View File
@@ -8,6 +8,7 @@
import { onMount } from 'svelte';
import { SidebarNavigation, DialogConversationTitleUpdate } from '$lib/components/app';
import { DialogMcpServerRecommendations } from '$lib/components/app/dialogs';
import { PwaMetaTags, PwaRefreshAlert } from '$lib/components/pwa';
import { pwaAssetsHead } from 'virtual:pwa-assets/head';
@@ -26,6 +27,7 @@
import { FAVICON_PATHS, FAVICON_SELECTORS } from '$lib/constants/pwa';
import { useKeyboardShortcuts } from '$lib/hooks/use-keyboard-shortcuts.svelte';
import { usePwa } from '$lib/hooks/use-pwa.svelte';
import { useMcpRecommendations } from '$lib/hooks/use-mcp-recommendations.svelte';
import { conversations } from '$lib/stores/conversations.svelte';
import { isMobile } from '$lib/stores/viewport.svelte';
import { theme } from '$lib/stores/theme.svelte';
@@ -37,6 +39,8 @@
let innerHeight = $state<number | undefined>();
let innerWidth = $state(browser ? window.innerWidth : 0);
const mcpRecommendations = useMcpRecommendations();
let chatSidebar:
| {
activateSearchMode?: () => void;
@@ -321,6 +325,11 @@
onConfirm={handleTitleUpdateConfirm}
onCancel={handleTitleUpdateCancel}
/>
<DialogMcpServerRecommendations
open={mcpRecommendations.open}
onOpenChange={mcpRecommendations.handleOpenChange}
/>
</Tooltip.Provider>
<!-- PWA update prompt + version -->
@@ -0,0 +1,37 @@
<script lang="ts">
import { untrack } from 'svelte';
import McpServerForm from '$lib/components/app/mcp/McpServerForm.svelte';
interface Props {
headers?: string;
}
let { headers = '' }: Props = $props();
let headersState = $state(untrack(() => headers));
let lastCapturedHeaders = $state(untrack(() => headers));
$effect(() => {
if (headers !== lastCapturedHeaders) {
headersState = headers;
lastCapturedHeaders = headers;
}
});
</script>
<!--
Drives McpServerForm with a controlled `headers` string and exposes the
latest captured value through `data-captured-headers` so the client test
can read it back without a custom binding API.
-->
<McpServerForm
url="https://example.test/mcp"
headers={headersState}
onUrlChange={() => {}}
onHeadersChange={(value) => {
headersState = value;
}}
id="mcp-server-form-test"
/>
<div data-testid="captured-headers" data-captured-headers={headersState} hidden></div>
@@ -0,0 +1,133 @@
import { describe, expect, it } from 'vitest';
import { render } from 'vitest-browser-svelte';
import McpServerFormWrapper from './components/McpServerFormWrapper.svelte';
const AUTHORIZATION_HEADER = 'Authorization';
const BEARER_PREFIX = 'Bearer ';
const BEARER_PLACEHOLDER = 'Paste token here';
/**
* Client-side tests for the McpServerForm bearer UI.
*
* The dedicated UI only "owns" Authorization headers that already carry a
* Bearer scheme (heuristic check on the value). Other Authorization values
* stay in the KV section so the user can still edit them verbatim. Storage
* always goes through the same custom-headers slot, so a round-trip via this
* UI produces exactly one `Authorization: Bearer <token>` entry.
*
* Equivalent parser coverage lives in `tests/unit/headers.test.ts`.
*/
describe('McpServerForm - Authorization / bearer UI', () => {
function bearerInput(screen: Awaited<ReturnType<typeof render>>) {
return screen.locator.getByPlaceholder(BEARER_PLACEHOLDER);
}
function capturedHeaders(screen: Awaited<ReturnType<typeof render>>) {
return screen.getByTestId('captured-headers');
}
it('mounts with the bearer input hidden when no auth header is present', async () => {
const screen = await render(McpServerFormWrapper, { headers: '' });
await expect.element(screen.getByRole('textbox', { name: /server url/i })).toBeVisible();
await expect.element(bearerInput(screen)).not.toBeInTheDocument();
});
it('toggling Authorization shows the bearer input', async () => {
const screen = await render(McpServerFormWrapper, { headers: '' });
await screen.getByRole('switch', { name: /authorization/i }).click();
await expect.element(bearerInput(screen)).toBeVisible();
});
it('typing a token writes the Authorization row with the Bearer prefix prepended', async () => {
const screen = await render(McpServerFormWrapper, { headers: '' });
await screen.getByRole('switch', { name: /authorization/i }).click();
const token = 'super-secret';
await bearerInput(screen).fill(token);
const expected = JSON.stringify({ [AUTHORIZATION_HEADER]: `${BEARER_PREFIX}${token}` });
await expect
.element(capturedHeaders(screen))
.toHaveAttribute('data-captured-headers', expected);
});
it('pre-existing Bearer header pre-fills the bearer input with the token stripped', async () => {
const existing = JSON.stringify({
'X-Trace-Id': 'abc',
[AUTHORIZATION_HEADER]: `${BEARER_PREFIX}preexisting`
});
const screen = await render(McpServerFormWrapper, { headers: existing });
await expect.element(bearerInput(screen)).toBeVisible();
await expect.element(bearerInput(screen)).toHaveValue('preexisting');
});
it('non-Bearer Authorization is ignored by the dedicated UI and stays in the KV section', async () => {
const existing = JSON.stringify({ [AUTHORIZATION_HEADER]: 'Basic czNjcjpwYXNz' });
const screen = await render(McpServerFormWrapper, { headers: existing });
await expect.element(bearerInput(screen)).not.toBeInTheDocument();
const headerKeyInput = screen.getByPlaceholder('Header name');
await expect.element(headerKeyInput).toBeVisible();
});
it('engaging the token UI replaces a non-Bearer Authorization with the Bearer scheme', async () => {
const existing = JSON.stringify({ [AUTHORIZATION_HEADER]: 'Basic old' });
const screen = await render(McpServerFormWrapper, { headers: existing });
await screen.getByRole('switch', { name: /authorization/i }).click();
await bearerInput(screen).fill('new');
const expected = JSON.stringify({ [AUTHORIZATION_HEADER]: `${BEARER_PREFIX}new` });
await expect
.element(capturedHeaders(screen))
.toHaveAttribute('data-captured-headers', expected);
});
it('toggling Authorization off with no token drops the Bearer row but keeps non-Bearer schemes', async () => {
const existing = JSON.stringify({ [AUTHORIZATION_HEADER]: `${BEARER_PREFIX}xyz` });
const screen = await render(McpServerFormWrapper, { headers: existing });
await screen.getByRole('switch', { name: /authorization/i }).click();
await expect.element(capturedHeaders(screen)).toHaveAttribute('data-captured-headers', '');
});
it('toggling Authorization off when no Bearer row is present leaves headers untouched', async () => {
const existing = JSON.stringify({ [AUTHORIZATION_HEADER]: 'Basic czNjcjpwYXNz' });
const screen = await render(McpServerFormWrapper, { headers: existing });
await screen.getByRole('switch', { name: /authorization/i }).click();
await screen.getByRole('switch', { name: /authorization/i }).click();
await expect
.element(capturedHeaders(screen))
.toHaveAttribute('data-captured-headers', existing);
});
it('clearing the bearer input drops the Authorization row', async () => {
const existing = JSON.stringify({ [AUTHORIZATION_HEADER]: `${BEARER_PREFIX}xyz` });
const screen = await render(McpServerFormWrapper, { headers: existing });
await bearerInput(screen).fill('');
await expect.element(capturedHeaders(screen)).toHaveAttribute('data-captured-headers', '');
});
it('does not surface Bearer Authorization in the KV section even when pre-existing', async () => {
const existing = JSON.stringify({ [AUTHORIZATION_HEADER]: `${BEARER_PREFIX}xyz` });
const screen = await render(McpServerFormWrapper, { headers: existing });
const headerKeyInput = screen.getByPlaceholder('Header name');
await expect.element(headerKeyInput).not.toBeInTheDocument();
});
});
+126
View File
@@ -0,0 +1,126 @@
import { describe, expect, it } from 'vitest';
import { parseHeadersToArray, serializeHeaders } from '$lib/utils/headers';
/**
* Tests for the header serialization helpers used by the MCP server form
* (custom header rows) and the new Authorization/Bearer-token flow.
*/
describe('parseHeadersToArray', () => {
it('returns an empty array for empty or whitespace-only input', () => {
expect(parseHeadersToArray('')).toEqual([]);
expect(parseHeadersToArray(' ')).toEqual([]);
expect(parseHeadersToArray(undefined as unknown as string)).toEqual([]);
});
it('returns an empty array for invalid JSON input', () => {
expect(parseHeadersToArray('{not-json')).toEqual([]);
expect(parseHeadersToArray('[]')).toEqual([]);
expect(parseHeadersToArray('"plain-string"')).toEqual([]);
});
it('converts an object into ordered key/value pairs', () => {
expect(parseHeadersToArray('{"X-Foo":"bar","Authorization":"Bearer abc"}')).toEqual([
{ key: 'X-Foo', value: 'bar' },
{ key: 'Authorization', value: 'Bearer abc' }
]);
});
it('stringifies non-string values', () => {
expect(parseHeadersToArray('{"count":"42","flag":"true"}')).toEqual([
{ key: 'count', value: '42' },
{ key: 'flag', value: 'true' }
]);
});
});
describe('serializeHeaders', () => {
it('returns an empty string when there are no valid pairs', () => {
expect(serializeHeaders([])).toBe('');
expect(serializeHeaders([{ key: '', value: 'value' }])).toBe('');
expect(serializeHeaders([{ key: ' ', value: 'value' }])).toBe('');
});
it('returns an empty string when every pair has a blank key', () => {
expect(
serializeHeaders([
{ key: '', value: 'drop-me' },
{ key: ' ', value: 'drop-me-too' },
{ key: '\t', value: 'tab-key' }
])
).toBe('');
});
it('drops pairs with empty keys but keeps the rest', () => {
expect(
serializeHeaders([
{ key: '', value: 'drop-me' },
{ key: 'X-Keep', value: 'ok' }
])
).toBe('{"X-Keep":"ok"}');
});
it('trims keys before serializing', () => {
expect(serializeHeaders([{ key: ' X-Space ', value: 'ok' }])).toBe('{"X-Space":"ok"}');
});
it('preserves the input order of surviving pairs', () => {
const serialized = serializeHeaders([
{ key: 'X-C', value: '3' },
{ key: 'X-A', value: '1' },
{ key: 'X-B', value: '2' }
]);
// Object key order follows insertion order in modern JS engines, so
// the serialized JSON writes keys in our input order.
expect(JSON.parse(serialized)).toEqual({ 'X-C': '3', 'X-A': '1', 'X-B': '2' });
});
});
describe('parseHeadersToArray / serializeHeaders roundtrip', () => {
it('serializes back to an equal header object after a parse', () => {
const original = JSON.stringify({
'Content-Type': 'application/json',
'X-Trace-Id': 'abc-123'
});
const roundtrip = serializeHeaders(parseHeadersToArray(original));
expect(JSON.parse(roundtrip)).toEqual(JSON.parse(original));
});
it('drops rows whose keys are blank after trimming during serialization', () => {
const pairs = parseHeadersToArray('{"X-Keep":"ok","":"drop-me"}');
// parseHeadersToArray keeps raw key strings (the consumer is expected to
// filter blanks, not the parser); serialization must strip them.
expect(pairs).toEqual([
{ key: 'X-Keep', value: 'ok' },
{ key: '', value: 'drop-me' }
]);
expect(serializeHeaders(pairs)).toBe('{"X-Keep":"ok"}');
});
it('preserves upstream keys untouched (does not lowercase them)', () => {
const upperCased = '{"Authorization":"Bearer xyz"}';
const parsed = parseHeadersToArray(upperCased);
expect(parsed).toEqual([{ key: 'Authorization', value: 'Bearer xyz' }]);
});
it('bearer-token write survives a re-parse when paired with regular custom headers', () => {
// The McpServerForm bearer UI writes {Authorization: `Bearer <token>`}
// into the same headers string as the custom KV section. The round
// trip below mirrors the exact shape the form produces so a future
// refactor of either code path cannot silently change the on-disk key.
const pairs = [
{ key: 'X-Trace-Id', value: 'abc-123' },
{ key: 'Authorization', value: 'Bearer super-secret' }
];
const serialized = serializeHeaders(pairs);
expect(serialized).toBe('{"X-Trace-Id":"abc-123","Authorization":"Bearer super-secret"}');
expect(parseHeadersToArray(serialized)).toEqual(pairs);
});
});
@@ -0,0 +1,144 @@
import { describe, expect, it, vi } from 'vitest';
import { parseMcpServerSettings } from '$lib/utils/mcp';
import { DEFAULT_MCP_CONFIG, MCP_SERVER_ID_PREFIX } from '$lib/constants/mcp';
/**
* Tests for the mcpServers settings parser.
*
* The branch seeds the MCP servers setting with a default value of
* `JSON.stringify(RECOMMENDED_MCP_SERVERS)`, so the parser has to be
* resilient to anything that may live in the user's localStorage: malformed
* JSON, wrong shapes, missing fields, falsy-but-not-zero numbers, and entry
* arrays that have been mutated by the user via the settings form.
*/
describe('parseMcpServerSettings', () => {
it('returns an empty array for falsy or whitespace-only input', () => {
expect(parseMcpServerSettings(null)).toEqual([]);
expect(parseMcpServerSettings(undefined)).toEqual([]);
expect(parseMcpServerSettings('')).toEqual([]);
expect(parseMcpServerSettings(' ')).toEqual([]);
});
it('returns an empty array and logs a warning for invalid JSON strings', () => {
const warn = vi.spyOn(console, 'warn').mockImplementation(() => {});
expect(parseMcpServerSettings('{not-json')).toEqual([]);
expect(warn).toHaveBeenCalled();
warn.mockRestore();
});
it('returns an empty array for valid JSON that is not an array', () => {
expect(parseMcpServerSettings('"plain-string"')).toEqual([]);
expect(parseMcpServerSettings('{"id":"foo"}')).toEqual([]);
expect(parseMcpServerSettings('42')).toEqual([]);
expect(parseMcpServerSettings('null')).toEqual([]);
});
it('drops entries with no parseable id and substitutes a stable fallback', () => {
const parsed = parseMcpServerSettings(
JSON.stringify([{ url: 'https://a.test', enabled: true }, { url: 'https://b.test' }])
);
expect(parsed).toHaveLength(2);
expect(parsed[0]?.id).toBe(`${MCP_SERVER_ID_PREFIX}-1`);
expect(parsed[1]?.id).toBe(`${MCP_SERVER_ID_PREFIX}-2`);
});
it('reuses the first id when it is present and falls back only for missing ones', () => {
const parsed = parseMcpServerSettings(
JSON.stringify([
{ id: 'custom-1', url: 'https://a.test' },
{ url: 'https://b.test' },
{ id: 'custom-3', url: 'https://c.test' }
])
);
expect(parsed[0]?.id).toBe('custom-1');
expect(parsed[1]?.id).toBe(`${MCP_SERVER_ID_PREFIX}-2`);
expect(parsed[2]?.id).toBe('custom-3');
});
it('falls back to the configured default requestTimeoutSeconds only for nullish values', () => {
const fallback = DEFAULT_MCP_CONFIG.requestTimeoutSeconds;
const parsed = parseMcpServerSettings(
JSON.stringify([
{ id: 'a', url: 'https://a.test' },
{ id: 'b', url: 'https://b.test', requestTimeoutSeconds: undefined },
{ id: 'c', url: 'https://c.test', requestTimeoutSeconds: 0 },
{ id: 'd', url: 'https://d.test', requestTimeoutSeconds: 45 }
])
);
// The parser uses ?? for timeout fallback, which only triggers on
// null/undefined. Explicit 0 is preserved at face value.
expect(parsed[0]?.requestTimeoutSeconds).toBe(fallback);
expect(parsed[1]?.requestTimeoutSeconds).toBe(fallback);
expect(parsed[2]?.requestTimeoutSeconds).toBe(0);
expect(parsed[3]?.requestTimeoutSeconds).toBe(45);
});
it('treats whitespace-only headers strings as undefined', () => {
const parsed = parseMcpServerSettings(
JSON.stringify([
{ id: 'a', url: 'https://a.test', headers: ' ' },
{ id: 'b', url: 'https://b.test', headers: '{"X-Foo":"bar"}' }
])
);
// The parser trims headers and coerces empty/whitespace to undefined.
expect(parsed[0]?.headers).toBeUndefined();
expect(parsed[1]?.headers).toBe('{"X-Foo":"bar"}');
});
it('defaults coercion for booleans (undefined -> false, true -> true)', () => {
const parsed = parseMcpServerSettings(
JSON.stringify([
{ id: 'a', url: 'https://a.test' },
{ id: 'b', url: 'https://b.test', enabled: true },
{ id: 'c', url: 'https://c.test', enabled: false },
{ id: 'd', url: 'https://d.test', useProxy: true }
])
);
expect(parsed[0]?.enabled).toBe(false);
expect(parsed[1]?.enabled).toBe(true);
expect(parsed[2]?.enabled).toBe(false);
expect(parsed[0]?.useProxy).toBe(false);
expect(parsed[3]?.useProxy).toBe(true);
});
it('preserves input order when mapping entries', () => {
const source = [
{ id: 'gamma', url: 'https://c.test' },
{ id: 'alpha', url: 'https://a.test' },
{ id: 'beta', url: 'https://b.test' }
];
const parsed = parseMcpServerSettings(JSON.stringify(source));
expect(parsed.map((entry) => entry.id)).toEqual(['gamma', 'alpha', 'beta']);
});
it('passes non-string raw input through the JSON-equality path', () => {
const parsed = parseMcpServerSettings([
{ id: 'a', url: 'https://a.test' },
{ id: 'b', url: 'https://b.test', enabled: true }
]);
expect(parsed).toHaveLength(2);
expect(parsed[0]?.id).toBe('a');
expect(parsed[1]?.enabled).toBe(true);
});
it('coerces non-string url values to an empty string rather than throwing', () => {
const parsed = parseMcpServerSettings(
JSON.stringify([{ id: 'a', url: 42 }, { id: 'b' }, { id: 'c', url: 'https://c.test' }])
);
expect(parsed[0]?.url).toBe('');
expect(parsed[1]?.url).toBe('');
expect(parsed[2]?.url).toBe('https://c.test');
});
});
@@ -0,0 +1,90 @@
import { describe, expect, it } from 'vitest';
import {
RECOMMENDED_MCP_SERVER_IDS,
RECOMMENDED_MCP_SERVERS
} from '$lib/constants/recommended-mcp-servers';
import { parseMcpServerSettings } from '$lib/utils/mcp';
import { DEFAULT_MCP_CONFIG, MCP_SERVER_ID_PREFIX } from '$lib/constants/mcp';
/**
* Tests for the predefined recommended MCP servers.
*
* These are surfaced to first-time users via
* DialogMcpServerRecommendations and used as the default value of the MCP
* servers setting, so a regression that breaks the round-trip through the
* settings parser would silently break onboarding for new users.
*/
describe('RECOMMENDED_MCP_SERVERS', () => {
it('lists at least one entry and uses stable, unique ids', () => {
expect(RECOMMENDED_MCP_SERVERS.length).toBeGreaterThan(0);
const ids = RECOMMENDED_MCP_SERVERS.map((server) => server.id);
expect(new Set(ids).size).toBe(ids.length);
for (const id of ids) {
expect(id).toMatch(/^[a-z0-9-]+$/);
expect(id.toLowerCase()).not.toContain(MCP_SERVER_ID_PREFIX.toLowerCase());
}
});
it('requires a name, description and url for every entry', () => {
for (const server of RECOMMENDED_MCP_SERVERS) {
expect(server.name?.trim().length ?? 0).toBeGreaterThan(0);
expect(server.description.trim().length).toBeGreaterThan(0);
expect(server.url.trim().length).toBeGreaterThan(0);
expect(() => new URL(server.url)).not.toThrow();
}
});
});
describe('RECOMMENDED_MCP_SERVER_IDS', () => {
it('matches the ids declared in RECOMMENDED_MCP_SERVERS', () => {
expect(RECOMMENDED_MCP_SERVER_IDS.size).toBe(RECOMMENDED_MCP_SERVERS.length);
for (const server of RECOMMENDED_MCP_SERVERS) {
expect(RECOMMENDED_MCP_SERVER_IDS.has(server.id)).toBe(true);
}
});
});
describe('recommended-mcp-servers default value', () => {
it('round-trips cleanly through parseMcpServerSettings', () => {
const serialized = JSON.stringify(RECOMMENDED_MCP_SERVERS);
const parsed = parseMcpServerSettings(serialized);
expect(parsed).toHaveLength(RECOMMENDED_MCP_SERVERS.length);
for (let index = 0; index < RECOMMENDED_MCP_SERVERS.length; index++) {
const source = RECOMMENDED_MCP_SERVERS[index];
const entry = parsed[index];
expect(entry).toBeDefined();
expect(entry?.id).toBe(source.id);
expect(entry?.url).toBe(source.url);
expect(entry?.enabled).toBe(source.enabled);
expect(entry?.requestTimeoutSeconds).toBe(source.requestTimeoutSeconds);
expect(entry?.name).toBe(source.name);
// Headers and useProxy are not set on recommended servers; the
// parser must fall back to the inactive defaults rather than
// surfacing undefined-boundary states.
expect(entry?.headers).toBeUndefined();
expect(entry?.useProxy).toBe(false);
}
});
it('uses the global default timeout when one is not specified on an entry', () => {
const sourceOnlyRequired = {
id: 'roundtrip-only',
name: 'Only required fields',
url: 'https://example.test/mcp',
description: 'Smoke entry for parser roundtrip with default timeout.',
enabled: true
};
const parsed = parseMcpServerSettings(JSON.stringify([sourceOnlyRequired]));
const entry = parsed[0];
expect(entry?.requestTimeoutSeconds).toBe(DEFAULT_MCP_CONFIG.requestTimeoutSeconds);
});
});
+127 -35
View File
@@ -478,7 +478,7 @@ bool set_socket_opt_time(socket_t sock, int level, int optname,
}
bool is_hex(char c, int &v) {
if (isdigit(static_cast<unsigned char>(c))) {
if (is_ascii_digit(c)) {
v = c - '0';
return true;
} else if ('A' <= c && c <= 'F') {
@@ -695,7 +695,11 @@ std::string base64_encode(const std::string &in) {
std::string out;
out.reserve(in.size());
auto val = 0;
// Unsigned: the accumulator is never masked, so with a signed int the
// `val << 8` below overflows once enough bytes are folded in (undefined
// behaviour before C++20). Only the low bits are ever emitted, so the
// wrap-around of an unsigned accumulator does not affect the output.
uint32_t val = 0;
auto valb = -6;
for (auto c : in) {
@@ -3887,8 +3891,7 @@ bool parse_range_header(const std::string &s, Ranges &ranges) {
bool parse_range_header(const std::string &s, Ranges &ranges) try {
#endif
auto is_valid = [](const std::string &str) {
return std::all_of(str.cbegin(), str.cend(),
[](unsigned char c) { return std::isdigit(c); });
return std::all_of(str.cbegin(), str.cend(), is_ascii_digit);
};
if (s.size() > 7 && s.compare(0, 6, "bytes=") == 0) {
@@ -4336,7 +4339,7 @@ bool is_multipart_boundary_chars_valid(const std::string &boundary) {
auto valid = true;
for (size_t i = 0; i < boundary.size(); i++) {
auto c = boundary[i];
if (!std::isalnum(static_cast<unsigned char>(c)) && c != '-' && c != '_') {
if (!is_ascii_alnum(c) && c != '-' && c != '_') {
valid = false;
break;
}
@@ -4344,18 +4347,47 @@ bool is_multipart_boundary_chars_valid(const std::string &boundary) {
return valid;
}
// Escape a multipart field name/filename following the WHATWG HTML standard
// ("escape a multipart form-data name"), which is what browsers send:
// '"' -> %22, CR -> %0D, LF -> %0A
// With escape_quote = false, only CR and LF are escaped; this is for header
// values outside a quoted-string (e.g. Content-Type), where '"' is legal.
std::string escape_multipart_field(const std::string &s,
bool escape_quote = true) {
std::string result;
result.reserve(s.size());
for (auto c : s) {
switch (c) {
case '"':
if (escape_quote) {
result += "%22";
} else {
result += c;
}
break;
case '\r': result += "%0D"; break;
case '\n': result += "%0A"; break;
default: result += c; break;
}
}
return result;
}
template <typename T>
std::string
serialize_multipart_formdata_item_begin(const T &item,
const std::string &boundary) {
std::string body = "--" + boundary + "\r\n";
body += "Content-Disposition: form-data; name=\"" + item.name + "\"";
body += "Content-Disposition: form-data; name=\"" +
escape_multipart_field(item.name) + "\"";
if (!item.filename.empty()) {
body += "; filename=\"" + item.filename + "\"";
body += "; filename=\"" + escape_multipart_field(item.filename) + "\"";
}
body += "\r\n";
if (!item.content_type.empty()) {
body += "Content-Type: " + item.content_type + "\r\n";
body +=
"Content-Type: " + escape_multipart_field(item.content_type, false) +
"\r\n";
}
body += "\r\n";
@@ -4821,10 +4853,9 @@ private:
namespace fields {
bool is_token_char(char c) {
return std::isalnum(static_cast<unsigned char>(c)) || c == '!' || c == '#' ||
c == '$' || c == '%' || c == '&' || c == '\'' || c == '*' ||
c == '+' || c == '-' || c == '.' || c == '^' || c == '_' || c == '`' ||
c == '|' || c == '~';
return is_ascii_alnum(c) || c == '!' || c == '#' || c == '$' || c == '%' ||
c == '&' || c == '\'' || c == '*' || c == '+' || c == '-' ||
c == '.' || c == '^' || c == '_' || c == '`' || c == '|' || c == '~';
}
bool is_token(const std::string &s) {
@@ -4873,7 +4904,8 @@ bool is_field_value(const std::string &s) { return is_field_content(s); }
} // namespace fields
bool perform_websocket_handshake(Stream &strm, const std::string &host,
int port, const std::string &path,
int port, bool is_ssl,
const std::string &path,
const Headers &headers,
std::string &selected_subprotocol) {
// Validate path and host
@@ -4899,7 +4931,7 @@ bool perform_websocket_handshake(Stream &strm, const std::string &host,
// Build upgrade request
std::string req_str = "GET " + path + " HTTP/1.1\r\n";
req_str += "Host: " + host + ":" + std::to_string(port) + "\r\n";
req_str += "Host: " + make_host_and_port_string(host, port, is_ssl) + "\r\n";
req_str += "Upgrade: websocket\r\n";
req_str += "Connection: Upgrade\r\n";
req_str += "Sec-WebSocket-Key: " + client_key + "\r\n";
@@ -5599,9 +5631,8 @@ std::string encode_uri_component(const std::string &value) {
escaped << std::hex;
for (auto c : value) {
if (std::isalnum(static_cast<uint8_t>(c)) || c == '-' || c == '_' ||
c == '.' || c == '!' || c == '~' || c == '*' || c == '\'' || c == '(' ||
c == ')') {
if (detail::is_ascii_alnum(c) || c == '-' || c == '_' || c == '.' ||
c == '!' || c == '~' || c == '*' || c == '\'' || c == '(' || c == ')') {
escaped << c;
} else {
escaped << std::uppercase;
@@ -5620,10 +5651,10 @@ std::string encode_uri(const std::string &value) {
escaped << std::hex;
for (auto c : value) {
if (std::isalnum(static_cast<uint8_t>(c)) || c == '-' || c == '_' ||
c == '.' || c == '!' || c == '~' || c == '*' || c == '\'' || c == '(' ||
c == ')' || c == ';' || c == '/' || c == '?' || c == ':' || c == '@' ||
c == '&' || c == '=' || c == '+' || c == '$' || c == ',' || c == '#') {
if (detail::is_ascii_alnum(c) || c == '-' || c == '_' || c == '.' ||
c == '!' || c == '~' || c == '*' || c == '\'' || c == '(' || c == ')' ||
c == ';' || c == '/' || c == '?' || c == ':' || c == '@' || c == '&' ||
c == '=' || c == '+' || c == '$' || c == ',' || c == '#') {
escaped << c;
} else {
escaped << std::uppercase;
@@ -5684,7 +5715,8 @@ std::string encode_path_component(const std::string &component) {
auto c = static_cast<unsigned char>(component[i]);
// Unreserved characters per RFC 3986: ALPHA / DIGIT / "-" / "." / "_" / "~"
if (std::isalnum(c) || c == '-' || c == '.' || c == '_' || c == '~') {
if (detail::is_ascii_alnum(static_cast<char>(c)) || c == '-' || c == '.' ||
c == '_' || c == '~') {
result += static_cast<char>(c);
}
// Path-safe sub-delimiters: "!" / "$" / "&" / "'" / "(" / ")" / "*" / "+" /
@@ -5757,7 +5789,8 @@ std::string encode_query_component(const std::string &component,
auto c = static_cast<unsigned char>(component[i]);
// Unreserved characters per RFC 3986
if (std::isalnum(c) || c == '-' || c == '.' || c == '_' || c == '~') {
if (detail::is_ascii_alnum(static_cast<char>(c)) || c == '-' || c == '.' ||
c == '_' || c == '~') {
result += static_cast<char>(c);
}
// Space handling
@@ -6010,6 +6043,48 @@ size_t MultipartFormData::get_file_count(const std::string &key) const {
return static_cast<size_t>(std::distance(r.first, r.second));
}
// Multipart FormData writer implementation
bool is_valid_multipart_boundary(const std::string &boundary) {
return detail::is_multipart_boundary_chars_valid(boundary);
}
MultipartFormDataWriter::MultipartFormDataWriter()
: boundary_(detail::make_multipart_data_boundary()) {}
MultipartFormDataWriter::MultipartFormDataWriter(std::string boundary)
: boundary_(std::move(boundary)) {}
const std::string &MultipartFormDataWriter::boundary() const {
return boundary_;
}
std::string MultipartFormDataWriter::content_type() const {
return detail::serialize_multipart_formdata_get_content_type(boundary_);
}
std::string
MultipartFormDataWriter::serialize(const UploadFormDataItems &items) const {
return detail::serialize_multipart_formdata(items, boundary_);
}
size_t MultipartFormDataWriter::content_length(
const UploadFormDataItems &items) const {
return detail::get_multipart_content_length(items, boundary_);
}
std::string
MultipartFormDataWriter::item_begin(const UploadFormData &item) const {
return detail::serialize_multipart_formdata_item_begin(item, boundary_);
}
std::string MultipartFormDataWriter::item_end() {
return detail::serialize_multipart_formdata_item_end();
}
std::string MultipartFormDataWriter::finish() const {
return detail::serialize_multipart_formdata_finish(boundary_);
}
// Response implementation
size_t Response::get_header_value_u64(const std::string &key, size_t def,
size_t id) const {
@@ -6229,8 +6304,10 @@ ssize_t detail::BodyReader::read(char *buf, size_t len) {
}
// ThreadPool implementation
ThreadPool::ThreadPool(size_t n, size_t max_n, size_t mqr)
: base_thread_count_(n), max_queued_requests_(mqr), idle_thread_count_(0),
ThreadPool::ThreadPool(size_t n, size_t max_n, size_t mqr,
time_t idle_timeout_sec)
: base_thread_count_(n), max_queued_requests_(mqr),
idle_timeout_sec_(idle_timeout_sec), idle_thread_count_(0),
shutdown_(false) {
#ifndef CPPHTTPLIB_NO_EXCEPTIONS
if (max_n != 0 && max_n < n) {
@@ -6340,9 +6417,9 @@ void ThreadPool::worker(bool is_dynamic) {
idle_thread_count_++;
if (is_dynamic) {
auto has_work = cond_.wait_for(
lock, std::chrono::seconds(CPPHTTPLIB_THREAD_POOL_IDLE_TIMEOUT),
[&] { return !jobs_.empty() || shutdown_; });
auto has_work =
cond_.wait_for(lock, std::chrono::seconds(idle_timeout_sec_),
[&] { return !jobs_.empty() || shutdown_; });
if (!has_work) {
// Timed out with no work - exit this dynamic thread
idle_thread_count_--;
@@ -9687,9 +9764,18 @@ bool ClientImpl::write_request(Stream &strm, Request &req,
if (!query_part.empty()) {
// Normalize the query string (decode then re-encode) while preserving
// the original parameter order.
auto normalized = detail::normalize_query_string(query_part);
if (!normalized.empty()) { path_with_query += '?' + normalized; }
// the original parameter order. When path encoding is disabled the
// caller has supplied an already-encoded target and expects the exact
// bytes to be sent on the wire, so skip normalization for the query
// too. Normalizing here would decode-then-re-encode the query and
// corrupt pre-encoded binary payloads (e.g. turning `%20` into `+`,
// which a strict RFC 3986 server decodes back as `+`, not a space).
if (path_encode_) {
auto normalized = detail::normalize_query_string(query_part);
if (!normalized.empty()) { path_with_query += '?' + normalized; }
} else {
path_with_query += '?' + query_part;
}
// Still populate req.params for handlers/users who read them.
detail::parse_query_text(query_part, req.params);
@@ -12518,7 +12604,7 @@ bool is_ipv4_address(const std::string &str) {
for (char c : str) {
if (c == '.') {
dots++;
} else if (!isdigit(static_cast<unsigned char>(c))) {
} else if (!detail::is_ascii_digit(c)) {
return false;
}
}
@@ -12535,7 +12621,7 @@ bool parse_ipv4(const std::string &str, unsigned char *out) {
}
int val = 0;
int digits = 0;
while (*p >= '0' && *p <= '9') {
while (detail::is_ascii_digit(*p)) {
val = val * 10 + (*p - '0');
if (val > 255) { return false; }
p++;
@@ -16487,9 +16573,15 @@ bool WebSocketClient::connect() {
return false;
}
#ifdef CPPHTTPLIB_SSL_ENABLED
auto is_ssl = is_ssl_;
#else
auto is_ssl = false;
#endif
std::string selected_subprotocol;
if (!detail::perform_websocket_handshake(*strm, host_, port_, path_, headers_,
selected_subprotocol)) {
if (!detail::perform_websocket_handshake(*strm, host_, port_, is_ssl, path_,
headers_, selected_subprotocol)) {
shutdown_and_close();
return false;
}
+56 -12
View File
@@ -8,8 +8,8 @@
#ifndef CPPHTTPLIB_HTTPLIB_H
#define CPPHTTPLIB_HTTPLIB_H
#define CPPHTTPLIB_VERSION "0.48.0"
#define CPPHTTPLIB_VERSION_NUM "0x003000"
#define CPPHTTPLIB_VERSION "0.49.0"
#define CPPHTTPLIB_VERSION_NUM "0x003100"
#ifdef _WIN32
#if defined(_WIN32_WINNT) && _WIN32_WINNT < 0x0A00
@@ -309,7 +309,6 @@ using socket_t = int;
#include <array>
#include <atomic>
#include <cassert>
#include <cctype>
#include <chrono>
#include <climits>
#include <condition_variable>
@@ -540,6 +539,21 @@ make_unique(std::size_t n) {
return std::unique_ptr<T>(new RT[n]);
}
// Locale-independent ASCII character classification. The <cctype>
// counterparts (std::isalnum, std::isdigit, ...) consult the global C locale,
// so e.g. std::isalnum(0xC5) can return true once an embedder calls
// setlocale(). HTTP grammars are defined over ASCII, so raw bytes must be
// classified without regard to the locale.
inline bool is_ascii_digit(char c) { return '0' <= c && c <= '9'; }
inline bool is_ascii_alpha(char c) {
return ('a' <= c && c <= 'z') || ('A' <= c && c <= 'Z');
}
inline bool is_ascii_alnum(char c) {
return is_ascii_digit(c) || is_ascii_alpha(c);
}
namespace case_ignore {
inline unsigned char to_lower(int c) {
@@ -661,7 +675,7 @@ inline from_chars_result<T> from_chars(const char *first, const char *last,
for (; p != last; ++p) {
char c = *p;
int digit = -1;
if ('0' <= c && c <= '9') {
if (is_ascii_digit(c)) {
digit = c - '0';
} else if ('a' <= c && c <= 'z') {
digit = c - 'a' + 10;
@@ -733,14 +747,14 @@ inline from_chars_result<double> from_chars(const char *first, const char *last,
return false;
};
for (; p != last && '0' <= *p && *p <= '9'; ++p) {
for (; p != last && is_ascii_digit(*p); ++p) {
seen_digit = true;
accumulate(*p);
}
if (p != last && *p == '.') {
++p;
for (; p != last && '0' <= *p && *p <= '9'; ++p) {
for (; p != last && is_ascii_digit(*p); ++p) {
seen_digit = true;
if (frac_digits < max_frac_digits && accumulate(*p)) { ++frac_digits; }
}
@@ -803,8 +817,8 @@ inline bool parse_url(const std::string &url, UrlComponents &uc) {
// IPv6 host must be [a-fA-F0-9:]+ only
if (uc.host.empty()) { return false; }
for (auto c : uc.host) {
if (!((c >= 'a' && c <= 'f') || (c >= 'A' && c <= 'F') ||
(c >= '0' && c <= '9') || c == ':')) {
if (!(is_ascii_digit(c) || (c >= 'a' && c <= 'f') ||
(c >= 'A' && c <= 'F') || c == ':')) {
return false;
}
}
@@ -1541,7 +1555,9 @@ public:
class ThreadPool final : public TaskQueue {
public:
explicit ThreadPool(size_t n, size_t max_n = 0, size_t mqr = 0);
explicit ThreadPool(
size_t n, size_t max_n = 0, size_t mqr = 0,
time_t idle_timeout_sec = CPPHTTPLIB_THREAD_POOL_IDLE_TIMEOUT);
ThreadPool(const ThreadPool &) = delete;
~ThreadPool() override = default;
@@ -1556,6 +1572,7 @@ private:
size_t base_thread_count_;
size_t max_thread_count_;
size_t max_queued_requests_;
time_t idle_timeout_sec_;
size_t idle_thread_count_;
bool shutdown_;
@@ -1680,6 +1697,35 @@ make_multipart_content_provider(const UploadFormDataItems &items,
} // namespace detail
bool is_valid_multipart_boundary(const std::string &boundary);
// Serializer for multipart/form-data request bodies. The boundary is owned
// by the writer so that per-part framing and the final terminator always
// agree. Field names and filenames are escaped following the WHATWG HTML
// standard ('"' -> %22, CR -> %0D, LF -> %0A); CR and LF are also escaped
// in content types.
class MultipartFormDataWriter {
public:
MultipartFormDataWriter();
// precondition: is_valid_multipart_boundary(boundary)
explicit MultipartFormDataWriter(std::string boundary);
const std::string &boundary() const;
std::string content_type() const;
// In-memory items -> whole body (known length)
std::string serialize(const UploadFormDataItems &items) const;
size_t content_length(const UploadFormDataItems &items) const;
// Per-part framing for streaming via a content provider
std::string item_begin(const UploadFormData &item) const;
static std::string item_end();
std::string finish() const;
private:
std::string boundary_;
};
class Server {
public:
using Handler = std::function<void(const Request &, Response &)>;
@@ -2897,9 +2943,7 @@ template <size_t N> inline constexpr size_t str_len(const char (&)[N]) {
}
inline bool is_numeric(const std::string &str) {
return !str.empty() &&
std::all_of(str.cbegin(), str.cend(),
[](unsigned char c) { return std::isdigit(c); });
return !str.empty() && std::all_of(str.cbegin(), str.cend(), is_ascii_digit);
}
inline size_t get_header_value_u64(const Headers &headers,