mirror of
https://github.com/ggml-org/llama.cpp.git
synced 2026-07-03 12:45:45 +02:00
Compare commits
10 Commits
| Author | SHA1 | Date | |
|---|---|---|---|
| 94875285e4 | |||
| 5a460dea9f | |||
| c8ae9a750c | |||
| fdb1db877c | |||
| 4fc4ec5541 | |||
| a6647b1a32 | |||
| 13e673863b | |||
| b820cc8e6f | |||
| 6dbc1174b8 | |||
| 9d88e7cedd |
+10
-8
@@ -496,13 +496,15 @@ void common_models_handler_apply(common_models_handler & handler, common_params
|
||||
}
|
||||
|
||||
// handle hf_plan tasks
|
||||
auto add_tasks = [&opts, &tasks](const hf_cache::hf_files & model_files, common_params_model & model) {
|
||||
auto add_tasks = [&opts, &tasks](const hf_cache::hf_files & model_files,
|
||||
const hf_cache::hf_file & primary,
|
||||
common_params_model & model) {
|
||||
for (size_t i = 0; i < model_files.size(); ++i) {
|
||||
auto & model_file = model_files[i];
|
||||
bool is_first = (i == 0);
|
||||
tasks.emplace_back(model_file, opts, [&, is_first]() {
|
||||
if (is_first) {
|
||||
// only use first part as model path
|
||||
bool is_primary = (model_file.path == primary.path);
|
||||
tasks.emplace_back(model_file, opts, [&, is_primary]() {
|
||||
if (is_primary) {
|
||||
// the primary file is the first split (00001-of), use it as model path
|
||||
model.path = hf_cache::finalize_file(model_file);
|
||||
} else {
|
||||
hf_cache::finalize_file(model_file);
|
||||
@@ -511,7 +513,7 @@ void common_models_handler_apply(common_models_handler & handler, common_params
|
||||
}
|
||||
};
|
||||
if (!plan.model_files.empty()) {
|
||||
add_tasks(plan.model_files, params.model);
|
||||
add_tasks(plan.model_files, plan.primary, params.model);
|
||||
}
|
||||
if (!plan.mmproj.local_path.empty()) {
|
||||
tasks.emplace_back(plan.mmproj, opts, [&]() {
|
||||
@@ -539,12 +541,12 @@ void common_models_handler_apply(common_models_handler & handler, common_params
|
||||
|
||||
// handle plan_spec (e.g. --spec-draft-hf)
|
||||
if (!plan_spec.model_files.empty()) {
|
||||
add_tasks(plan_spec.model_files, params.speculative.draft.mparams);
|
||||
add_tasks(plan_spec.model_files, plan_spec.primary, params.speculative.draft.mparams);
|
||||
}
|
||||
|
||||
// handle vocoder plan (e.g. --hf-repo-v)
|
||||
if (!plan_voc.model_files.empty()) {
|
||||
add_tasks(plan_voc.model_files, params.vocoder.model);
|
||||
add_tasks(plan_voc.model_files, plan_voc.primary, params.vocoder.model);
|
||||
}
|
||||
|
||||
// run all tasks in parallel
|
||||
|
||||
+51
-39
@@ -1,16 +1,26 @@
|
||||
# llama.cpp for OpenCL
|
||||
|
||||
- [Background](#background)
|
||||
- [OS](#os)
|
||||
- [Hardware](#hardware)
|
||||
- [DataType Supports](#datatype-supports)
|
||||
- [Model Preparation](#model-preparation)
|
||||
- [CMake Options](#cmake-options)
|
||||
- [Android](#android)
|
||||
- [Windows 11 Arm64](#windows-11-arm64)
|
||||
- [Linux](#Linux)
|
||||
- [Known Issue](#known-issues)
|
||||
- [TODO](#todo)
|
||||
- [llama.cpp for OpenCL](#llamacpp-for-opencl)
|
||||
- [Background](#background)
|
||||
- [Llama.cpp + OpenCL](#llamacpp--opencl)
|
||||
- [OS](#os)
|
||||
- [Hardware](#hardware)
|
||||
- [Adreno GPU](#adreno-gpu)
|
||||
- [DataType Supports](#datatype-supports)
|
||||
- [Model Preparation](#model-preparation)
|
||||
- [Binary Kernel Library](#binary-kernel-library)
|
||||
- [CMake Options](#cmake-options)
|
||||
- [Android](#android)
|
||||
- [I. Setup Environment](#i-setup-environment)
|
||||
- [II. Build llama.cpp](#ii-build-llamacpp)
|
||||
- [Windows 11 Arm64](#windows-11-arm64)
|
||||
- [I. Setup Environment](#i-setup-environment-1)
|
||||
- [II. Build llama.cpp](#ii-build-llamacpp-1)
|
||||
- [Linux](#linux)
|
||||
- [I. Setup Environment](#i-setup-environment-2)
|
||||
- [II. Build llama.cpp](#ii-build-llamacpp-2)
|
||||
- [Known Issues](#known-issues)
|
||||
- [TODO](#todo)
|
||||
|
||||
## Background
|
||||
|
||||
@@ -34,11 +44,13 @@ The llama.cpp OpenCL backend is designed to enable llama.cpp on **Qualcomm Adren
|
||||
|
||||
**Verified devices**
|
||||
|
||||
| Adreno GPU | Status |
|
||||
|:------------------------------------:|:-------:|
|
||||
| Adreno 750 (Snapdragon 8 Gen 3) | Support |
|
||||
| Adreno 830 (Snapdragon 8 Elite) | Support |
|
||||
| Adreno X85 (Snapdragon X Elite) | Support |
|
||||
| Adreno GPU | Status |
|
||||
|:-------------------------------------:|:-------:|
|
||||
| Adreno 750 (Snapdragon 8 Gen 3) | Support |
|
||||
| Adreno 830 (Snapdragon 8 Elite) | Support |
|
||||
| Adreno 840 (Snapdragon 8 Elite Gen 5) | Support |
|
||||
| Adreno X1-85 (Snapdragon X Elite) | Support |
|
||||
| Adreno X2-90 (Snapdragon X2 Elite) | Support |
|
||||
|
||||
> A6x GPUs with a recent driver and compiler are supported; they are usually found in IoT platforms.
|
||||
However, A6x GPUs in phones are likely not supported due to the outdated driver and compiler.
|
||||
@@ -47,42 +59,43 @@ However, A6x GPUs in phones are likely not supported due to the outdated driver
|
||||
|
||||
| DataType | Status |
|
||||
|:----------------------:|:--------------------------:|
|
||||
| Q1_0 | Support |
|
||||
| Q4_0 | Support |
|
||||
| Q6_K | Support, but not optimized |
|
||||
| Q4_1 | Support |
|
||||
| Q5_0 | Support |
|
||||
| Q5_1 | Support |
|
||||
| Q8_0 | Support |
|
||||
| Q4_K | Support |
|
||||
| Q5_K | Support |
|
||||
| Q6_K | Support |
|
||||
| MXFP4 | Support |
|
||||
| IQ4_NL | Support |
|
||||
|
||||
## Model Preparation
|
||||
|
||||
You can refer to the general [llama-quantize tool](/tools/quantize/README.md) for steps to convert a model in Hugging Face safetensor format to GGUF with quantization.
|
||||
Since common quantizations are now supported, it is recommended to download GGUF models directly from Hugging Face.
|
||||
|
||||
Currently we support `Q4_0` quantization and have optimized for it. To achieve best performance on Adreno GPU, add `--pure` to `llama-quantize` (i.e., make all weights in `Q4_0`). For example,
|
||||
## Binary Kernel Library
|
||||
|
||||
```sh
|
||||
./llama-quantize --pure ggml-model-qwen2.5-3b-f16.gguf ggml-model-qwen-3b-Q4_0.gguf Q4_0
|
||||
```
|
||||
A prebuilt binary kernel library has been introduced for Adreno GPUs.
|
||||
It currently targets X2 GPUs (X2-90, X2-85 and X2-45) found in Snapdragon X2 SoC.
|
||||
The library currently contains kernels for MUL_MAT_ID with Q4_0, Q4_1, Q4_K, MXFP4.
|
||||
The library must be manually downloaded from https://softwarecenter.qualcomm.com/catalog/item/Adreno_Kernel_Library_GGML.
|
||||
|
||||
Since `Q6_K` is also supported, `Q4_0` quantization without `--pure` will also work. However, the performance will be worse compared to pure `Q4_0` quantization.
|
||||
To allow using the kernel library, add `-DGGML_OPENCL_USE_ADRENO_BIN_KERNELS=ON` when configuring with CMake.
|
||||
Then, extract `adreno-opencl-kernels.dll` from the zip file downloaded from the above URL and put it alongside the executables.
|
||||
If kernels compatible with the current GPU are found in the library, they will be loaded and used.
|
||||
|
||||
### `MXFP4` MoE Models
|
||||
|
||||
OpenAI gpt-oss models are MoE models in `MXFP4`. The quantized model will be in `MXFP4_MOE`, a mixture of `MXFP4` and `Q8_0`.
|
||||
For this quantization, there is no need to specify `--pure`.
|
||||
For gpt-oss-20b model, you can directly [download](https://huggingface.co/ggml-org/gpt-oss-20b-GGUF) the quantized GGUF file in `MXFP4_MOE` from Hugging Face.
|
||||
|
||||
Although it is possible to quantize gpt-oss-20b model in pure `Q4_0` (all weights in `Q4_0`), it is not recommended since `MXFP4` has been optimized for MoE while `Q4_0` is not. In addition, accuracy should degrade with such pure `Q4_0` quantization.
|
||||
Hence, using the default `MXFP4_MOE` quantization (see the link above) is recommended for this model.
|
||||
|
||||
> Note that the `Q4_0` model found [here](https://huggingface.co/unsloth/gpt-oss-20b-GGUF/blob/main/gpt-oss-20b-Q4_0.gguf) is a mixture of `Q4_0`, `Q8_0` and `MXFP4` and gives better performance than `MXFP4_MOE` quantization.
|
||||
|
||||
## CMake Options
|
||||
|
||||
The OpenCL backend has the following CMake options that control the behavior of the backend.
|
||||
|
||||
| CMake options | Default value | Description |
|
||||
|:---------------------------------:|:--------------:|:------------------------------------------|
|
||||
| `GGML_OPENCL_EMBED_KERNELS` | `ON` | Embed OpenCL kernels into the executable. |
|
||||
| `GGML_OPENCL_USE_ADRENO_KERNELS` | `ON` | Use kernels optimized for Adreno. |
|
||||
| CMake options | Default value | Description |
|
||||
|:------------------------------------:|:--------------:|:------------------------------------------|
|
||||
| `GGML_OPENCL_EMBED_KERNELS` | `ON` | Embed OpenCL kernels into the executable. |
|
||||
| `GGML_OPENCL_USE_ADRENO_KERNELS` | `ON` | Use kernels optimized for Adreno. |
|
||||
| `GGML_OPENCL_USE_ADRENO_BIN_KERNELS` | `OFF` | Allow using binary kernel lib for Adreno. |
|
||||
|
||||
## Android
|
||||
|
||||
@@ -277,6 +290,5 @@ ninja
|
||||
|
||||
## TODO
|
||||
|
||||
- Optimization for Q6_K
|
||||
- Support and optimization for Q4_K
|
||||
- Improve flash attention
|
||||
- Improve OpenCL C kernels performance
|
||||
|
||||
@@ -1111,11 +1111,12 @@ GGML_TABLE_BEGIN(int8_t, kvalues_iq4nl, 16)
|
||||
-127, -104, -83, -65, -49, -35, -22, -10, 1, 13, 25, 38, 53, 69, 89, 113,
|
||||
GGML_TABLE_END()
|
||||
|
||||
// e2m1 values (doubled)
|
||||
// e2m1 values (doubled), shared by MXFP4 and NVFP4
|
||||
// ref: https://www.opencompute.org/documents/ocp-microscaling-formats-mx-v1-0-spec-final-pdf
|
||||
GGML_TABLE_BEGIN(int8_t, kvalues_mxfp4, 16)
|
||||
GGML_TABLE_BEGIN(int8_t, kvalues_fp4, 16)
|
||||
0, 1, 2, 3, 4, 6, 8, 12, 0, -1, -2, -3, -4, -6, -8, -12,
|
||||
GGML_TABLE_END()
|
||||
#define kvalues_mxfp4 kvalues_fp4
|
||||
|
||||
#define NGRID_IQ1S 2048
|
||||
#define IQ1S_DELTA 0.125f
|
||||
|
||||
@@ -82,7 +82,6 @@
|
||||
#define ggml_gemm_q2_K_8x8_q8_K_generic ggml_gemm_q2_K_8x8_q8_K
|
||||
#elif defined(__x86_64__) || defined(__i386__) || defined(_M_IX86) || defined(_M_X64)
|
||||
// quants.c
|
||||
#define ggml_vec_dot_nvfp4_q8_0_generic ggml_vec_dot_nvfp4_q8_0
|
||||
// repack.cpp
|
||||
#define ggml_quantize_mat_q8_0_4x4_generic ggml_quantize_mat_q8_0_4x4
|
||||
#define ggml_quantize_mat_q8_K_4x4_generic ggml_quantize_mat_q8_K_4x4
|
||||
|
||||
@@ -934,7 +934,7 @@ void ggml_vec_dot_mxfp4_q8_0(int n, float * GGML_RESTRICT s, size_t bs, const vo
|
||||
|
||||
#if defined __AVX2__
|
||||
|
||||
const __m128i values128 = _mm_loadu_si128((const __m128i*)kvalues_mxfp4);
|
||||
const __m128i values128 = _mm_loadu_si128((const __m128i*)kvalues_fp4);
|
||||
const __m128i m4b = _mm_set1_epi8(0x0f);
|
||||
const __m256i mone = _mm256_set1_epi16(1);
|
||||
|
||||
@@ -963,7 +963,7 @@ void ggml_vec_dot_mxfp4_q8_0(int n, float * GGML_RESTRICT s, size_t bs, const vo
|
||||
sumf = hsum_float_8(_mm256_add_ps(accum1, accum2));
|
||||
|
||||
#elif defined __AVX__
|
||||
const __m128i values128 = _mm_loadu_si128((const __m128i*)kvalues_mxfp4);
|
||||
const __m128i values128 = _mm_loadu_si128((const __m128i*)kvalues_fp4);
|
||||
const __m128i m4b = _mm_set1_epi8(0x0f);
|
||||
|
||||
__m256 accum = _mm256_setzero_ps();
|
||||
@@ -993,14 +993,152 @@ void ggml_vec_dot_mxfp4_q8_0(int n, float * GGML_RESTRICT s, size_t bs, const vo
|
||||
int sumi1 = 0;
|
||||
int sumi2 = 0;
|
||||
for (int j = 0; j < QK_MXFP4/2; ++j) {
|
||||
sumi1 += y[ib].qs[j + 0] * kvalues_mxfp4[x[ib].qs[j] & 0xf];
|
||||
sumi2 += y[ib].qs[j + QK_MXFP4/2] * kvalues_mxfp4[x[ib].qs[j] >> 4];
|
||||
sumi1 += y[ib].qs[j + 0] * kvalues_fp4[x[ib].qs[j] & 0xf];
|
||||
sumi2 += y[ib].qs[j + QK_MXFP4/2] * kvalues_fp4[x[ib].qs[j] >> 4];
|
||||
}
|
||||
sumf += d * (sumi1 + sumi2);
|
||||
}
|
||||
*s = sumf;
|
||||
}
|
||||
|
||||
// Dot product of a row of NVFP4 blocks (vx) with a row of Q8_0 blocks (vy); the scalar result is stored in *s.
// n is the number of elements and must be a multiple of QK_NVFP4.
// Each NVFP4 block x[ib] is paired with two consecutive Q8_0 blocks, y[2*ib] and y[2*ib + 1].
// Only nrc == 1 is supported; bs, bx and by are unused.
void ggml_vec_dot_nvfp4_q8_0(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, size_t bx, const void * GGML_RESTRICT vy, size_t by, int nrc) {
|
||||
assert(nrc == 1);
|
||||
UNUSED(nrc);
|
||||
UNUSED(bx);
|
||||
UNUSED(by);
|
||||
UNUSED(bs);
|
||||
assert(n % QK_NVFP4 == 0);
|
||||
|
||||
const block_nvfp4 * GGML_RESTRICT x = vx;
|
||||
const block_q8_0 * GGML_RESTRICT y = vy;
|
||||
|
||||
const int nb = n / QK_NVFP4;
|
||||
// ib is shared by the SIMD loop and the scalar loop at the bottom: whichever SIMD path is
|
||||
// compiled in consumes every block, so the scalar loop then has nothing left to do.
int ib = 0;
|
||||
float sumf = 0;
|
||||
|
||||
#if defined(__AVX2__)
|
||||
|
||||
// 16-entry signed lookup table; 4-bit codes are decoded with a byte shuffle against it
const __m128i values128 = _mm_loadu_si128((const __m128i*)kvalues_fp4);
|
||||
const __m128i m4b = _mm_set1_epi8(0x0f);
|
||||
const __m256i mone = _mm256_set1_epi16(1);
|
||||
|
||||
__m256 accum = _mm256_setzero_ps();
|
||||
for(; ib < nb; ib++){
|
||||
|
||||
// 32 bytes of packed 4-bit codes, loaded as two 16-byte halves
const __m128i q4bits_01 = _mm_loadu_si128((const __m128i *)(x[ib].qs + 0));
|
||||
const __m128i q4bits_23 = _mm_loadu_si128((const __m128i *)(x[ib].qs + 16));
|
||||
|
||||
const __m256i q8_01 = _mm256_loadu_si256((const __m256i *)y[2*ib + 0].qs);
|
||||
const __m256i q8_23 = _mm256_loadu_si256((const __m256i *)y[2*ib + 1].qs);
|
||||
|
||||
// decode low nibbles (_lo) and high nibbles (_hi) of each byte through the table
const __m128i q4_01_lo = _mm_shuffle_epi8(values128, _mm_and_si128(q4bits_01, m4b));
|
||||
const __m128i q4_01_hi = _mm_shuffle_epi8(values128, _mm_and_si128(_mm_srli_epi16(q4bits_01, 4), m4b));
|
||||
const __m128i q4_23_lo = _mm_shuffle_epi8(values128, _mm_and_si128(q4bits_23, m4b));
|
||||
const __m128i q4_23_hi = _mm_shuffle_epi8(values128, _mm_and_si128(_mm_srli_epi16(q4bits_23, 4), m4b));
|
||||
|
||||
// reordering: each 128-bit lane gets the low-nibble values of 8 source bytes followed by
|
||||
// the high-nibble values of the same 8 bytes, i.e. the same element order the scalar
|
||||
// fallback below uses when indexing the Q8_0 values
|
||||
const __m256i q4_01 = MM256_SET_M128I(_mm_unpackhi_epi64(q4_01_lo,q4_01_hi), _mm_unpacklo_epi64(q4_01_lo,q4_01_hi));
|
||||
const __m256i q4_23 = MM256_SET_M128I(_mm_unpackhi_epi64(q4_23_lo,q4_23_hi),_mm_unpacklo_epi64(q4_23_lo,q4_23_hi));
|
||||
|
||||
// int8 products, widened to 32-bit partial sums by the madd against a vector of ones
const __m256i p01 = mul_add_epi8(q4_01,q8_01);
|
||||
const __m256i p_1 = _mm256_madd_epi16(p01, mone);
|
||||
|
||||
const __m256i p23 = mul_add_epi8(q4_23,q8_23);
|
||||
const __m256i p_2 = _mm256_madd_epi16(p23, mone);
|
||||
|
||||
const float dy0 = GGML_CPU_FP16_TO_FP32(y[2*ib].d);
|
||||
const float dy1 = GGML_CPU_FP16_TO_FP32(y[2*ib+1].d);
|
||||
|
||||
// four UE4M3 scales per NVFP4 block: d[0], d[1] combine with the first Q8_0 block's
|
||||
// scale, d[2], d[3] with the second's
|
||||
const float s0 = GGML_CPU_UE4M3_TO_FP32(x[ib].d[0]) * dy0;
|
||||
const float s1 = GGML_CPU_UE4M3_TO_FP32(x[ib].d[1]) * dy0;
|
||||
const float s2 = GGML_CPU_UE4M3_TO_FP32(x[ib].d[2]) * dy1;
|
||||
const float s3 = GGML_CPU_UE4M3_TO_FP32(x[ib].d[3]) * dy1;
|
||||
|
||||
// one scale per 128-bit lane: low lane gets s0/s2, high lane gets s1/s3
const __m256 scales01 = _mm256_set_m128(_mm_set1_ps(s1), _mm_set1_ps(s0));
|
||||
const __m256 scales23 = _mm256_set_m128(_mm_set1_ps(s3), _mm_set1_ps(s2));
|
||||
|
||||
accum = _mm256_fmadd_ps(scales01, _mm256_cvtepi32_ps(p_1), accum);
|
||||
accum = _mm256_fmadd_ps(scales23, _mm256_cvtepi32_ps(p_2), accum);
|
||||
}
|
||||
sumf = hsum_float_8(accum);
|
||||
|
||||
#elif defined(__AVX__)
|
||||
|
||||
// same scheme as the AVX2 path, but the integer work is done on 128-bit vectors
const __m128i values128 = _mm_loadu_si128((const __m128i*)kvalues_fp4);
|
||||
const __m128i m4b = _mm_set1_epi8(0x0f);
|
||||
|
||||
__m256 accum = _mm256_setzero_ps();
|
||||
for(; ib < nb; ib++){
|
||||
|
||||
const __m128i q4bits_01 = _mm_loadu_si128((const __m128i *)(x[ib].qs + 0));
|
||||
const __m128i q4bits_23 = _mm_loadu_si128((const __m128i *)(x[ib].qs + 16));
|
||||
|
||||
// each Q8_0 block is read as two 16-byte halves
const __m128i q8_0 = _mm_loadu_si128((const __m128i *)(y[2*ib + 0].qs + 0));
|
||||
const __m128i q8_1 = _mm_loadu_si128((const __m128i *)(y[2*ib + 0].qs + 16));
|
||||
const __m128i q8_2 = _mm_loadu_si128((const __m128i *)(y[2*ib + 1].qs + 0));
|
||||
const __m128i q8_3 = _mm_loadu_si128((const __m128i *)(y[2*ib + 1].qs + 16));
|
||||
|
||||
const __m128i q4_01_lo = _mm_shuffle_epi8(values128, _mm_and_si128(q4bits_01, m4b));
|
||||
const __m128i q4_01_hi = _mm_shuffle_epi8(values128, _mm_and_si128(_mm_srli_epi16(q4bits_01, 4), m4b));
|
||||
const __m128i q4_23_lo = _mm_shuffle_epi8(values128, _mm_and_si128(q4bits_23, m4b));
|
||||
const __m128i q4_23_hi = _mm_shuffle_epi8(values128, _mm_and_si128(_mm_srli_epi16(q4bits_23, 4), m4b));
|
||||
|
||||
// q4_k: low-nibble values of 8 source bytes followed by their high-nibble values
const __m128i q4_0 = _mm_unpacklo_epi64(q4_01_lo, q4_01_hi);
|
||||
const __m128i q4_1 = _mm_unpackhi_epi64(q4_01_lo, q4_01_hi);
|
||||
const __m128i q4_2 = _mm_unpacklo_epi64(q4_23_lo, q4_23_hi);
|
||||
const __m128i q4_3 = _mm_unpackhi_epi64(q4_23_lo, q4_23_hi);
|
||||
|
||||
const __m128i p0_i32 = mul_sum_i8_pairs(q4_0, q8_0);
|
||||
const __m128i p1_i32 = mul_sum_i8_pairs(q4_1, q8_1);
|
||||
const __m128i p2_i32 = mul_sum_i8_pairs(q4_2, q8_2);
|
||||
const __m128i p3_i32 = mul_sum_i8_pairs(q4_3, q8_3);
|
||||
|
||||
const __m128 p0 = _mm_cvtepi32_ps(p0_i32);
|
||||
const __m128 p1 = _mm_cvtepi32_ps(p1_i32);
|
||||
const __m128 p2 = _mm_cvtepi32_ps(p2_i32);
|
||||
const __m128 p3 = _mm_cvtepi32_ps(p3_i32);
|
||||
|
||||
const __m256 p01 = _mm256_set_m128(p1, p0);
|
||||
const __m256 p23 = _mm256_set_m128(p3, p2);
|
||||
|
||||
const float dy0 = GGML_CPU_FP16_TO_FP32(y[2*ib].d);
|
||||
const float dy1 = GGML_CPU_FP16_TO_FP32(y[2*ib+1].d);
|
||||
|
||||
const float s0 = GGML_CPU_UE4M3_TO_FP32(x[ib].d[0]) * dy0;
|
||||
const float s1 = GGML_CPU_UE4M3_TO_FP32(x[ib].d[1]) * dy0;
|
||||
const float s2 = GGML_CPU_UE4M3_TO_FP32(x[ib].d[2]) * dy1;
|
||||
const float s3 = GGML_CPU_UE4M3_TO_FP32(x[ib].d[3]) * dy1;
|
||||
|
||||
const __m256 scales01 = _mm256_set_m128(_mm_set1_ps(s1), _mm_set1_ps(s0));
|
||||
const __m256 scales23 = _mm256_set_m128(_mm_set1_ps(s3), _mm_set1_ps(s2));
|
||||
|
||||
// separate multiply and add here, where the AVX2 path uses a fused multiply-add
accum = _mm256_add_ps(accum, _mm256_mul_ps(p01, scales01));
|
||||
accum = _mm256_add_ps(accum, _mm256_mul_ps(p23, scales23));
|
||||
}
|
||||
sumf = hsum_float_8(accum);
|
||||
|
||||
#endif
|
||||
|
||||
// scalar path: processes every block not already consumed by a SIMD loop above
for (;ib < nb; ++ib) {
|
||||
// one iteration per scale d[s_idx]; two consecutive scales share one Q8_0 block
for (int s_idx = 0; s_idx < 4; ++s_idx) {
|
||||
const float d = GGML_CPU_UE4M3_TO_FP32(x[ib].d[s_idx]);
|
||||
const int q8_block = s_idx / 2;
|
||||
const int q8_off = (s_idx % 2) * QK_NVFP4_SUB;
|
||||
const float dy = GGML_CPU_FP16_TO_FP32(y[2*ib + q8_block].d);
|
||||
|
||||
int sumi_lo = 0, sumi_hi = 0;
|
||||
for (int j = 0; j < QK_NVFP4_SUB/2; ++j) {
|
||||
// low nibble pairs with element j of the sub-block, high nibble with element j + QK_NVFP4_SUB/2
const uint8_t qv = x[ib].qs[s_idx*(QK_NVFP4_SUB/2) + j];
|
||||
sumi_lo += y[2*ib + q8_block].qs[q8_off + j + 0] * kvalues_fp4[qv & 0xf];
|
||||
sumi_hi += y[2*ib + q8_block].qs[q8_off + j + QK_NVFP4_SUB/2] * kvalues_fp4[qv >> 4];
|
||||
}
|
||||
|
||||
sumf += dy * d * (sumi_lo + sumi_hi);
|
||||
}
|
||||
}
|
||||
*s = sumf;
|
||||
}
|
||||
|
||||
void ggml_vec_dot_q5_0_q8_0(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, size_t bx, const void * GGML_RESTRICT vy, size_t by, int nrc) {
|
||||
const int qk = QK8_0;
|
||||
const int nb = n / qk;
|
||||
|
||||
@@ -82,6 +82,9 @@ float ggml_table_f32_f16[1 << 16];
|
||||
// precomputed f32 table for e8m0 half (1 KB) (simd-mappings.h)
|
||||
float ggml_table_f32_e8m0_half[1 << 8];
|
||||
|
||||
// precomputed f32 table for ue4m3 (1 KB) (simd-mappings.h)
|
||||
float ggml_table_f32_ue4m3[1 << 8];
|
||||
|
||||
#if defined(__ARM_ARCH)
|
||||
struct ggml_arm_arch_features_type {
|
||||
int sve_cnt;
|
||||
@@ -3798,6 +3801,11 @@ void ggml_cpu_init(void) {
|
||||
ggml_table_f32_e8m0_half[i] = GGML_E8M0_TO_FP32_HALF(i);
|
||||
}
|
||||
|
||||
// initialize UE4M3 table (256 entries)
|
||||
for (int i = 0; i < (1 << 8); ++i) {
|
||||
ggml_table_f32_ue4m3[i] = ggml_ue4m3_to_fp32(i);
|
||||
}
|
||||
|
||||
const uint64_t t_end = ggml_time_us(); UNUSED(t_end);
|
||||
|
||||
GGML_PRINT_DEBUG("%s: GELU, Quick GELU, SILU and EXP tables initialized in %f ms\n", __func__, (t_end - t_start)/1000.0);
|
||||
|
||||
@@ -120,6 +120,10 @@ extern float ggml_table_f32_f16[1 << 16];
|
||||
// defined in ggml-cpu.c, initialized in ggml_cpu_init()
|
||||
extern float ggml_table_f32_e8m0_half[1 << 8];
|
||||
|
||||
// precomputed f32 table for ue4m3 (1 KB)
|
||||
// defined in ggml-cpu.c, initialized in ggml_cpu_init()
|
||||
extern float ggml_table_f32_ue4m3[1 << 8];
|
||||
|
||||
// Use lookup table for E8M0 on x86 (faster than bit manipulation)
|
||||
#if defined(__AVX__) || defined(__AVX2__) || defined(__AVX512F__)
|
||||
#define GGML_CPU_E8M0_TO_FP32_HALF(x) ggml_table_f32_e8m0_half[(uint8_t)(x)]
|
||||
@@ -127,6 +131,13 @@ extern float ggml_table_f32_e8m0_half[1 << 8];
|
||||
#define GGML_CPU_E8M0_TO_FP32_HALF(x) GGML_E8M0_TO_FP32_HALF(x)
|
||||
#endif
|
||||
|
||||
// Use lookup table for UE4M3 on x86 (faster than bit manipulation)
|
||||
#if defined(__AVX__) || defined(__AVX2__) || defined(__AVX512F__)
|
||||
#define GGML_CPU_UE4M3_TO_FP32(x) ggml_table_f32_ue4m3[(uint8_t)(x)]
|
||||
#else
|
||||
#define GGML_CPU_UE4M3_TO_FP32(x) ggml_ue4m3_to_fp32(x)
|
||||
#endif
|
||||
|
||||
// On ARM NEON, it's quicker to directly convert x -> x instead of calling into ggml_lookup_fp16_to_fp32,
|
||||
// so we define GGML_CPU_FP16_TO_FP32 and GGML_CPU_FP32_TO_FP16 elsewhere for NEON.
|
||||
// This is also true for POWER9.
|
||||
|
||||
@@ -664,7 +664,10 @@ constexpr __device__ dequantize_V_t get_dequantize_V() {
|
||||
template <int ncols1>
|
||||
__launch_bounds__(FATTN_KQ_STRIDE/2, 1)
|
||||
static __global__ void flash_attn_mask_to_KV_max(
|
||||
const half2 * __restrict__ mask, int * __restrict__ KV_max, const int ne30, const int64_t s31, const int64_t s33) {
|
||||
const half2 * mask_ptr, int * KV_max_ptr, const int ne30, const int64_t s31, const int64_t s33) {
|
||||
const half2 * GGML_CUDA_RESTRICT mask = mask_ptr;
|
||||
int * GGML_CUDA_RESTRICT KV_max = KV_max_ptr;
|
||||
|
||||
const int ne31 = gridDim.x;
|
||||
const int tid = threadIdx.x;
|
||||
const int sequence = blockIdx.y;
|
||||
@@ -1099,8 +1102,9 @@ void launch_fattn(
|
||||
const int iter_k = K->ne[1] / FATTN_KQ_STRIDE;
|
||||
|
||||
KV_max.alloc(ne_KV_max);
|
||||
flash_attn_mask_to_KV_max<ncols1><<<blocks_num_KV_max, block_dim_KV_max, 0, main_stream>>>
|
||||
((const half2 *) mask->data, KV_max.ptr, iter_k, s31, s33);
|
||||
ggml_cuda_kernel_launch_params launch_params = ggml_cuda_kernel_launch_params(blocks_num_KV_max, block_dim_KV_max, 0, main_stream);
|
||||
ggml_cuda_kernel_launch(flash_attn_mask_to_KV_max<ncols1>, launch_params,
|
||||
(const half2 *) mask->data, KV_max.ptr, iter_k, s31, s33);
|
||||
CUDA_CHECK(cudaGetLastError());
|
||||
}
|
||||
|
||||
|
||||
@@ -10,6 +10,7 @@ gated_delta_net_cuda(const float * q,
|
||||
const float * beta,
|
||||
const float * curr_state,
|
||||
float * dst,
|
||||
float * state,
|
||||
int64_t H,
|
||||
int64_t n_tokens,
|
||||
int64_t n_seqs,
|
||||
@@ -25,6 +26,7 @@ gated_delta_net_cuda(const float * q,
|
||||
const uint3 neqk1_magic,
|
||||
const uint3 rq3_magic,
|
||||
float scale,
|
||||
int64_t state_slot_stride,
|
||||
int K) {
|
||||
const uint32_t h_idx = blockIdx.x;
|
||||
const uint32_t sequence = blockIdx.y;
|
||||
@@ -35,9 +37,7 @@ gated_delta_net_cuda(const float * q,
|
||||
const uint32_t iq1 = fastmodulo(h_idx, neqk1_magic);
|
||||
const uint32_t iq3 = fastdiv(sequence, rq3_magic);
|
||||
|
||||
const int64_t attn_score_elems = S_v * H * n_tokens * n_seqs;
|
||||
float * attn_data = dst;
|
||||
float * state = dst + attn_score_elems;
|
||||
|
||||
// input state holds s0 only: [S_v, S_v, H, n_seqs] — seq stride is D = H * S_v * S_v.
|
||||
// output state layout (per-slot D * n_seqs) — same per-(seq,head) offset as before.
|
||||
@@ -145,10 +145,9 @@ gated_delta_net_cuda(const float * q,
|
||||
if constexpr (keep_rs_t) {
|
||||
// snapshot slot mapping: slot 0 = most recent state, slot s = s tokens back.
|
||||
// When n_tokens < K only slots 0..n_tokens-1 are written; older slots are caller-owned.
|
||||
const int64_t state_size_per_token = S_v * S_v * H * n_seqs; // per-slot stride in output
|
||||
const int target_slot = (int) n_tokens - 1 - t;
|
||||
if (target_slot >= 0 && target_slot < K) {
|
||||
float * curr_state = (dst + attn_score_elems) + target_slot * state_size_per_token + state_out_offset;
|
||||
float * curr_state = state + target_slot * state_slot_stride;
|
||||
#pragma unroll
|
||||
for (int r = 0; r < rows_per_lane; r++) {
|
||||
const int i = r * warp_size + lane;
|
||||
@@ -171,13 +170,13 @@ template <bool KDA, bool keep_rs_t>
|
||||
static void launch_gated_delta_net(
|
||||
const float * q_d, const float * k_d, const float * v_d,
|
||||
const float * g_d, const float * b_d, const float * s_d,
|
||||
float * dst_d,
|
||||
float * dst_d, float * state_d,
|
||||
int64_t S_v, int64_t H, int64_t n_tokens, int64_t n_seqs,
|
||||
int64_t sq1, int64_t sq2, int64_t sq3,
|
||||
int64_t sv1, int64_t sv2, int64_t sv3,
|
||||
int64_t sb1, int64_t sb2, int64_t sb3,
|
||||
int64_t neqk1, int64_t rq3,
|
||||
float scale, int K, cudaStream_t stream) {
|
||||
float scale, int64_t state_slot_stride, int K, cudaStream_t stream) {
|
||||
//TODO: Add chunked kernel for even faster pre-fill
|
||||
const int warp_size = ggml_cuda_info().devices[ggml_cuda_get_device()].warp_size;
|
||||
const int num_warps = 4;
|
||||
@@ -187,34 +186,32 @@ static void launch_gated_delta_net(
|
||||
const uint3 neqk1_magic = init_fastdiv_values(neqk1);
|
||||
const uint3 rq3_magic = init_fastdiv_values(rq3);
|
||||
|
||||
int cc = ggml_cuda_info().devices[ggml_cuda_get_device()].cc;
|
||||
|
||||
const ggml_cuda_kernel_launch_params launch_params = ggml_cuda_kernel_launch_params(grid_dims, block_dims, 0, stream);
|
||||
switch (S_v) {
|
||||
case 16:
|
||||
ggml_cuda_kernel_launch(gated_delta_net_cuda<16, KDA, keep_rs_t>, launch_params,
|
||||
q_d, k_d, v_d, g_d, b_d, s_d, dst_d, H,
|
||||
q_d, k_d, v_d, g_d, b_d, s_d, dst_d, state_d, H,
|
||||
n_tokens, n_seqs, sq1, sq2, sq3, sv1, sv2, sv3,
|
||||
sb1, sb2, sb3, neqk1_magic, rq3_magic, scale, K);
|
||||
sb1, sb2, sb3, neqk1_magic, rq3_magic, scale, state_slot_stride, K);
|
||||
break;
|
||||
case 32:
|
||||
ggml_cuda_kernel_launch(gated_delta_net_cuda<32, KDA, keep_rs_t>, launch_params,
|
||||
q_d, k_d, v_d, g_d, b_d, s_d, dst_d, H,
|
||||
q_d, k_d, v_d, g_d, b_d, s_d, dst_d, state_d, H,
|
||||
n_tokens, n_seqs, sq1, sq2, sq3, sv1, sv2, sv3,
|
||||
sb1, sb2, sb3, neqk1_magic, rq3_magic, scale, K);
|
||||
sb1, sb2, sb3, neqk1_magic, rq3_magic, scale, state_slot_stride, K);
|
||||
break;
|
||||
case 64: {
|
||||
ggml_cuda_kernel_launch(gated_delta_net_cuda<64, KDA, keep_rs_t>, launch_params,
|
||||
q_d, k_d, v_d, g_d, b_d, s_d, dst_d, H,
|
||||
q_d, k_d, v_d, g_d, b_d, s_d, dst_d, state_d, H,
|
||||
n_tokens, n_seqs, sq1, sq2, sq3, sv1, sv2, sv3,
|
||||
sb1, sb2, sb3, neqk1_magic, rq3_magic, scale, K);
|
||||
sb1, sb2, sb3, neqk1_magic, rq3_magic, scale, state_slot_stride, K);
|
||||
break;
|
||||
}
|
||||
case 128: {
|
||||
ggml_cuda_kernel_launch(gated_delta_net_cuda<128, KDA, keep_rs_t>, launch_params,
|
||||
q_d, k_d, v_d, g_d, b_d, s_d, dst_d, H,
|
||||
q_d, k_d, v_d, g_d, b_d, s_d, dst_d, state_d, H,
|
||||
n_tokens, n_seqs, sq1, sq2, sq3, sv1, sv2, sv3,
|
||||
sb1, sb2, sb3, neqk1_magic, rq3_magic, scale, K);
|
||||
sb1, sb2, sb3, neqk1_magic, rq3_magic, scale, state_slot_stride, K);
|
||||
break;
|
||||
}
|
||||
default:
|
||||
@@ -223,7 +220,8 @@ static void launch_gated_delta_net(
|
||||
}
|
||||
}
|
||||
|
||||
void ggml_cuda_op_gated_delta_net(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
|
||||
static void ggml_cuda_op_gated_delta_net_impl(
|
||||
ggml_backend_cuda_context & ctx, ggml_tensor * dst, const ggml_cuda_gated_delta_net_fused_cache * cache) {
|
||||
ggml_tensor * src_q = dst->src[0];
|
||||
ggml_tensor * src_k = dst->src[1];
|
||||
ggml_tensor * src_v = dst->src[2];
|
||||
@@ -288,25 +286,42 @@ void ggml_cuda_op_gated_delta_net(ggml_backend_cuda_context & ctx, ggml_tensor *
|
||||
const int K = ggml_get_op_params_i32(dst, 0);
|
||||
const bool keep_rs = K > 1;
|
||||
|
||||
// recurrent state -> gdn_out tail (after attention scores), or the cache when fusing
|
||||
float * state_d = dst_d + S_v * H * n_tokens * n_seqs;
|
||||
int64_t state_slot_stride = S_v * S_v * H * n_seqs;
|
||||
if (cache != nullptr) {
|
||||
state_d = cache->data;
|
||||
state_slot_stride = cache->slot_stride;
|
||||
}
|
||||
|
||||
if (kda) {
|
||||
if (keep_rs) {
|
||||
launch_gated_delta_net<true, true>(q_d, k_d, v_d, g_d, b_d, s_d, dst_d,
|
||||
launch_gated_delta_net<true, true>(q_d, k_d, v_d, g_d, b_d, s_d, dst_d, state_d,
|
||||
S_v, H, n_tokens, n_seqs, sq1, sq2, sq3, sv1, sv2, sv3,
|
||||
sb1, sb2, sb3, neqk1, rq3, scale, K, stream);
|
||||
sb1, sb2, sb3, neqk1, rq3, scale, state_slot_stride, K, stream);
|
||||
} else {
|
||||
launch_gated_delta_net<true, false>(q_d, k_d, v_d, g_d, b_d, s_d, dst_d,
|
||||
launch_gated_delta_net<true, false>(q_d, k_d, v_d, g_d, b_d, s_d, dst_d, state_d,
|
||||
S_v, H, n_tokens, n_seqs, sq1, sq2, sq3, sv1, sv2, sv3,
|
||||
sb1, sb2, sb3, neqk1, rq3, scale, K, stream);
|
||||
sb1, sb2, sb3, neqk1, rq3, scale, state_slot_stride, K, stream);
|
||||
}
|
||||
} else {
|
||||
if (keep_rs) {
|
||||
launch_gated_delta_net<false, true>(q_d, k_d, v_d, g_d, b_d, s_d, dst_d,
|
||||
launch_gated_delta_net<false, true>(q_d, k_d, v_d, g_d, b_d, s_d, dst_d, state_d,
|
||||
S_v, H, n_tokens, n_seqs, sq1, sq2, sq3, sv1, sv2, sv3,
|
||||
sb1, sb2, sb3, neqk1, rq3, scale, K, stream);
|
||||
sb1, sb2, sb3, neqk1, rq3, scale, state_slot_stride, K, stream);
|
||||
} else {
|
||||
launch_gated_delta_net<false, false>(q_d, k_d, v_d, g_d, b_d, s_d, dst_d,
|
||||
launch_gated_delta_net<false, false>(q_d, k_d, v_d, g_d, b_d, s_d, dst_d, state_d,
|
||||
S_v, H, n_tokens, n_seqs, sq1, sq2, sq3, sv1, sv2, sv3,
|
||||
sb1, sb2, sb3, neqk1, rq3, scale, K, stream);
|
||||
sb1, sb2, sb3, neqk1, rq3, scale, state_slot_stride, K, stream);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Public entry point for the gated delta net op without cache fusion: forwards to the shared
// implementation with a null cache pointer.
void ggml_cuda_op_gated_delta_net(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
|
||||
ggml_cuda_op_gated_delta_net_impl(ctx, dst, nullptr);
|
||||
}
|
||||
|
||||
// Fused-cache entry point for the gated delta net op: forwards to the shared implementation with
// a pointer to the by-value `cache` argument, which stays alive for the duration of the call.
void ggml_cuda_op_gated_delta_net_fused_cache(
|
||||
ggml_backend_cuda_context & ctx, ggml_tensor * dst, ggml_cuda_gated_delta_net_fused_cache cache) {
|
||||
ggml_cuda_op_gated_delta_net_impl(ctx, dst, &cache);
|
||||
}
|
||||
|
||||
@@ -1,4 +1,14 @@
|
||||
#include "common.cuh"
|
||||
#include "ggml.h"
|
||||
|
||||
// fused-kernel recurrent-state output; strides in elements (per-seq stride is always D, set in-kernel)
|
||||
// Destination for the recurrent-state snapshots when the op writes them into a cache
// rather than into its own output tensor.
struct ggml_cuda_gated_delta_net_fused_cache {
|
||||
float * data; // base pointer of rollback slot 0
|
||||
int64_t slot_stride; // distance in elements between consecutive rollback slots (0 when K==1)
|
||||
};
|
||||
|
||||
void ggml_cuda_op_gated_delta_net(ggml_backend_cuda_context & ctx, ggml_tensor * dst);
|
||||
|
||||
// same op, but writes the snapshot(s) into the cache instead of dst (see ggml_cuda_try_gdn_cache_fusion)
|
||||
void ggml_cuda_op_gated_delta_net_fused_cache(ggml_backend_cuda_context & ctx, ggml_tensor * dst,
|
||||
ggml_cuda_gated_delta_net_fused_cache cache);
|
||||
|
||||
@@ -3251,6 +3251,11 @@ static void ggml_backend_cuda_synchronize(ggml_backend_t backend) {
|
||||
GGML_UNUSED(backend);
|
||||
}
|
||||
|
||||
// Returns true when the tensor is empty or its op is one of RESHAPE, TRANSPOSE, VIEW, PERMUTE
// or NONE. Callers use this to skip such nodes when walking a graph.
static bool ggml_cuda_is_view_or_noop(const ggml_tensor * t) {
|
||||
return ggml_is_empty(t) || t->op == GGML_OP_RESHAPE || t->op == GGML_OP_TRANSPOSE ||
|
||||
t->op == GGML_OP_VIEW || t->op == GGML_OP_PERMUTE || t->op == GGML_OP_NONE;
|
||||
}
|
||||
|
||||
#ifdef USE_CUDA_GRAPH
|
||||
static bool ggml_cuda_graph_check_compability(ggml_cgraph * cgraph) {
|
||||
|
||||
@@ -3260,7 +3265,7 @@ static bool ggml_cuda_graph_check_compability(ggml_cgraph * cgraph) {
|
||||
for (int i = 0; i < cgraph->n_nodes; i++) {
|
||||
ggml_tensor * node = cgraph->nodes[i];
|
||||
|
||||
if (ggml_is_empty(node) || node->op == GGML_OP_RESHAPE || node->op == GGML_OP_TRANSPOSE || node->op == GGML_OP_VIEW || node->op == GGML_OP_PERMUTE || node->op == GGML_OP_NONE) {
|
||||
if (ggml_cuda_is_view_or_noop(node)) {
|
||||
continue;
|
||||
}
|
||||
|
||||
@@ -3403,6 +3408,70 @@ static bool ggml_cuda_should_fuse_rope_set_rows(const ggml_tensor * rope,
|
||||
return true;
|
||||
}
|
||||
|
||||
// match gated_delta_net + the strided cpy that scatters its state snapshots into the cache
|
||||
// (slot i -> rollback group i, slot 0 newest), so the kernel can write them and skip the cpy.
|
||||
static int ggml_cuda_try_gdn_cache_fusion(
|
||||
const ggml_cgraph * cgraph, int node_idx, ggml_cuda_gated_delta_net_fused_cache & fused_state_cpy) {
|
||||
const ggml_tensor * gdn = cgraph->nodes[node_idx];
|
||||
// the kernel skips the snapshot tail, so the gdn output must not be a graph output
|
||||
if (gdn->op != GGML_OP_GATED_DELTA_NET || gdn->type != GGML_TYPE_F32 ||
|
||||
(gdn->flags & GGML_TENSOR_FLAG_OUTPUT)) {
|
||||
return 0;
|
||||
}
|
||||
|
||||
const ggml_tensor * src_v = gdn->src[2];
|
||||
const int64_t S_v = src_v->ne[0];
|
||||
const int64_t H = src_v->ne[1];
|
||||
const int64_t n_tokens = src_v->ne[2];
|
||||
const int64_t n_seqs = src_v->ne[3];
|
||||
const int64_t D = S_v * S_v * H;
|
||||
const int64_t K = ggml_get_op_params_i32(gdn, 0); // snapshot slot count
|
||||
const int64_t n_written = std::min<int64_t>(n_tokens, K); // newest n_written slots are written
|
||||
|
||||
// snapshot tail starts right after the attention scores
|
||||
const size_t tail_off = ggml_row_size(GGML_TYPE_F32, S_v * H * n_tokens * n_seqs);
|
||||
|
||||
// snapshot cpy is the first real node after the gdn (skip views/no-ops)
|
||||
const ggml_tensor * cpy = nullptr;
|
||||
int skip = 0;
|
||||
for (int j = node_idx + 1; j < cgraph->n_nodes && cpy == nullptr; ++j) {
|
||||
const ggml_tensor * n = cgraph->nodes[j];
|
||||
if (ggml_cuda_is_view_or_noop(n)) {
|
||||
continue;
|
||||
}
|
||||
if (n->op != GGML_OP_CPY || (n->flags & GGML_TENSOR_FLAG_OUTPUT)) {
|
||||
return 0;
|
||||
}
|
||||
cpy = n;
|
||||
skip = j - node_idx;
|
||||
}
|
||||
if (cpy == nullptr) {
|
||||
return 0;
|
||||
}
|
||||
|
||||
const ggml_tensor * src = cpy->src[0]; // view of the gdn snapshot tail
|
||||
const ggml_tensor * dst = cpy->src[1]; // cache view the kernel writes to
|
||||
|
||||
// src must be this gdn's snapshot tail (contiguous, at the tail offset)
|
||||
if (src->op != GGML_OP_VIEW || src->view_src != gdn || src->view_offs != tail_off ||
|
||||
!ggml_is_contiguous(src)) {
|
||||
return 0;
|
||||
}
|
||||
|
||||
// dst is the [D, n_seqs, n_written] cache view; require nb[1] == D (the per-seq stride the kernel
|
||||
// assumes). ggml_cpy pins src to the same element count.
|
||||
const std::array<int64_t, GGML_MAX_DIMS> expected_ne = { D, n_seqs, n_written, 1 };
|
||||
if (dst->op != GGML_OP_VIEW || dst->type != GGML_TYPE_F32 || dst->data == nullptr ||
|
||||
!std::equal(expected_ne.begin(), expected_ne.end(), dst->ne) ||
|
||||
dst->nb[0] != ggml_type_size(GGML_TYPE_F32) || dst->nb[1] != (size_t) ggml_row_size(GGML_TYPE_F32, D)) {
|
||||
return 0;
|
||||
}
|
||||
|
||||
fused_state_cpy.data = (float *) dst->data; // rollback group 0 (newest)
|
||||
fused_state_cpy.slot_stride = K > 1 ? (int64_t) (dst->nb[2] / sizeof(float)) : 0;
|
||||
return skip;
|
||||
}
|
||||
|
||||
static bool ggml_cuda_topk_moe_fusion(const struct ggml_cgraph * cgraph, int node_idx, ggml_cuda_topk_moe_args & args) {
|
||||
args.sigmoid = false;
|
||||
args.softmax = false;
|
||||
@@ -3844,6 +3913,20 @@ static int ggml_cuda_try_fuse(ggml_backend_cuda_context * cuda_ctx, ggml_cgraph
|
||||
|
||||
ggml_tensor * node = cgraph->nodes[i];
|
||||
|
||||
// gated_delta_net -> cpy: scatter recurrent-state snapshots into the cache
|
||||
if (node->op == GGML_OP_GATED_DELTA_NET) {
|
||||
ggml_cuda_gated_delta_net_fused_cache fused_state_cpy;
|
||||
const int nodes_to_skip = ggml_cuda_try_gdn_cache_fusion(cgraph, i, fused_state_cpy);
|
||||
if (nodes_to_skip > 0) {
|
||||
#ifdef GGML_CUDA_DEBUG
|
||||
GGML_LOG_INFO("%s: fused gated_delta_net snapshot copies for %s (skipped %d nodes)\n",
|
||||
__func__, node->name, nodes_to_skip);
|
||||
#endif
|
||||
ggml_cuda_op_gated_delta_net_fused_cache(*cuda_ctx, node, fused_state_cpy);
|
||||
return nodes_to_skip;
|
||||
}
|
||||
}
|
||||
|
||||
//topk-moe
|
||||
if (cgraph->nodes[i]->op == GGML_OP_UNARY || cgraph->nodes[i]->op == GGML_OP_SOFT_MAX ||
|
||||
cgraph->nodes[i]->op == GGML_OP_ARGSORT) {
|
||||
@@ -4372,7 +4455,7 @@ static void ggml_cuda_graph_evaluate_and_capture(ggml_backend_cuda_context * cud
|
||||
#endif
|
||||
prev_i = i;
|
||||
|
||||
if (ggml_is_empty(node) || node->op == GGML_OP_RESHAPE || node->op == GGML_OP_TRANSPOSE || node->op == GGML_OP_VIEW || node->op == GGML_OP_PERMUTE || node->op == GGML_OP_NONE) {
|
||||
if (ggml_cuda_is_view_or_noop(node)) {
|
||||
continue;
|
||||
}
|
||||
|
||||
|
||||
@@ -23,7 +23,6 @@ include(${HEXAGON_SDK_ROOT}/build/cmake/hexagon_fun.cmake)
|
||||
include(ExternalProject)
|
||||
|
||||
option(GGML_HEXAGON_HTP_DEBUG "ggml-hexagon: enable HTP debug output" OFF)
|
||||
option(GGML_HEXAGON_FA_EXP2_HF "ggml-hexagon: use FP16 exp2 polynomial in FA softmax instead of F32 exp round-trip" OFF)
|
||||
set(GGML_HEXAGON_HTP_CERT "$ENV{HEXAGON_HTP_CERT}" CACHE PATH "ggml-hexagon: enable HTP library signing using certificate")
|
||||
|
||||
add_library(htp_iface OBJECT
|
||||
|
||||
@@ -43,6 +43,7 @@
|
||||
#include "htp-opnode.h"
|
||||
#include "htp-ops.h"
|
||||
#include "htp/matmul-ops.h"
|
||||
#include "htp/flash-attn-ops.h"
|
||||
#include "htp_iface.h"
|
||||
#include "htp-drv.h"
|
||||
|
||||
@@ -62,6 +63,7 @@ static int opt_profile = 0; // profiling mode (0-disabled, 1-basic, 2-pmu)
|
||||
static int opt_hostbuf = 1; // hostbuf ON by default
|
||||
|
||||
static int opt_mm_select = 3; // 3 = HMX -> Tiled -> Flat -> CPU, 2 = Tiled -> Flat -> CPU, 1 = Flat -> CPU
|
||||
static int opt_fa_select = 2; // 2 = HMX -> HVX -> CPU, 1 = HVX -> CPU, 0 = CPU (unsupported)
|
||||
|
||||
// Default PMU events, if profiling with PMU (mode=2) is enabled
|
||||
// See https://docs.qualcomm.com/doc/80-N2040-60/topic/pmu-events.html
|
||||
@@ -125,6 +127,11 @@ static const char * htp_event_name(uint16_t id) {
|
||||
case HTP_TRACE_EVT_HVX_W_DEQUANT: return "HVX_W_DEQUANT";
|
||||
case HTP_TRACE_EVT_HVX_W_PREP: return "HVX_W_PREP";
|
||||
case HTP_TRACE_EVT_HVX_O_PROC: return "HVX_O_PROC";
|
||||
case HTP_TRACE_EVT_HVX_FA_QK: return "HVX_QK_FA";
|
||||
case HTP_TRACE_EVT_HVX_FA_SFM: return "HVX_SFM_FA";
|
||||
case HTP_TRACE_EVT_HVX_FA_Q_PREP: return "HVX_Q_PREP";
|
||||
case HTP_TRACE_EVT_HVX_FA_K_PREP: return "HVX_K_PREP";
|
||||
case HTP_TRACE_EVT_HVX_FA_V_PREP: return "HVX_V_PREP";
|
||||
case HTP_TRACE_EVT_HMX_COMP: return "HMX_COMP";
|
||||
default: return "UNKNOWN";
|
||||
}
|
||||
@@ -1879,6 +1886,162 @@ ggml_hexagon_session::~ggml_hexagon_session() noexcept(true) {
|
||||
|
||||
// ** backend interface
|
||||
|
||||
static bool ggml_hexagon_flash_attn_is_hmx_eligible(
|
||||
const struct ggml_hexagon_session * sess,
|
||||
const struct ggml_tensor * q,
|
||||
const struct ggml_tensor * k,
|
||||
const struct ggml_tensor * v,
|
||||
const struct ggml_tensor * sinks
|
||||
) {
|
||||
if (sess->n_hmx == 0) {
|
||||
return false;
|
||||
}
|
||||
|
||||
if (opt_fa_select < 2) {
|
||||
return false;
|
||||
}
|
||||
|
||||
if (k->type != GGML_TYPE_F16 || v->type != GGML_TYPE_F16) {
|
||||
return false;
|
||||
}
|
||||
|
||||
const uint32_t DK = q->ne[0];
|
||||
const uint32_t DV = v->ne[0];
|
||||
|
||||
if (DK % 64 != 0 || DV % 64 != 0) {
|
||||
return false;
|
||||
}
|
||||
|
||||
// Fall back to HVX for small token counts if head dimension is small (DK <= 128)
|
||||
const uint32_t neq1 = q->ne[1];
|
||||
if (DK <= 128 && neq1 < 5) {
|
||||
return false;
|
||||
}
|
||||
|
||||
return true;
|
||||
}
|
||||
|
||||
static bool ggml_hexagon_precompute_flash_attn_params(
|
||||
const struct ggml_hexagon_session * sess,
|
||||
const struct ggml_tensor * op,
|
||||
struct htp_fa_kernel_params * kparams
|
||||
) {
|
||||
if (opt_fa_select < 1) {
|
||||
return false;
|
||||
}
|
||||
|
||||
memset(kparams, 0, sizeof(*kparams));
|
||||
|
||||
const struct ggml_tensor * q = op->src[0];
|
||||
const struct ggml_tensor * k = op->src[1];
|
||||
const struct ggml_tensor * v = op->src[2];
|
||||
const struct ggml_tensor * mask = op->src[3];
|
||||
const struct ggml_tensor * dst = op;
|
||||
|
||||
const uint32_t neq0 = q->ne[0]; // head_dim (DK)
|
||||
const uint32_t neq1 = q->ne[1]; // n_tokens
|
||||
const uint32_t neq2 = q->ne[2]; // n_heads
|
||||
|
||||
const uint32_t nek1 = k->ne[1]; // kv_len
|
||||
|
||||
const uint32_t nev0 = v->ne[0]; // head_dim (DV)
|
||||
|
||||
const uint32_t DK = neq0;
|
||||
const uint32_t DV = nev0;
|
||||
|
||||
const uint32_t n_kv_heads = k->ne[2];
|
||||
const uint32_t G = neq2 / n_kv_heads;
|
||||
|
||||
float scale = 1.0f;
|
||||
float max_bias = 0.0f;
|
||||
float logit_softcap = 0.0f;
|
||||
memcpy(&scale, &op->op_params[0], sizeof(float));
|
||||
memcpy(&max_bias, &op->op_params[1], sizeof(float));
|
||||
memcpy(&logit_softcap, &op->op_params[2], sizeof(float));
|
||||
|
||||
if (logit_softcap != 0.0f) {
|
||||
scale /= logit_softcap;
|
||||
}
|
||||
|
||||
kparams->scale = scale;
|
||||
kparams->max_bias = max_bias;
|
||||
kparams->logit_softcap = logit_softcap;
|
||||
|
||||
kparams->is_q_fp32 = (q->type == GGML_TYPE_F32) ? 1 : 0;
|
||||
kparams->is_dst_fp32 = (dst->type == GGML_TYPE_F32) ? 1 : 0;
|
||||
kparams->G = G;
|
||||
|
||||
const uint32_t n_head = q->ne[2];
|
||||
kparams->n_head_log2 = 1u << (uint32_t) std::floor(std::log2(n_head));
|
||||
kparams->m0 = std::pow(2.0f, -(max_bias) / kparams->n_head_log2);
|
||||
kparams->m1 = std::pow(2.0f, -(max_bias / 2.0f) / kparams->n_head_log2);
|
||||
|
||||
// Check HMX eligibility
|
||||
const struct ggml_tensor * sinks = op->src[4];
|
||||
if (ggml_hexagon_flash_attn_is_hmx_eligible(sess, q, k, v, sinks)) {
|
||||
size_t Br = 0, Bc = 0;
|
||||
int ret = hmx_fa_find_chunk_size(&Br, &Bc, G, DK, DV, neq1, nek1, sess->vtcm_size, sess->n_threads);
|
||||
if (ret == 0) {
|
||||
kparams->kernel_type = HTP_FA_KERNEL_HMX;
|
||||
kparams->Br = Br;
|
||||
kparams->Bc = Bc;
|
||||
kparams->n_kv_blocks = (nek1 + Bc - 1) / Bc;
|
||||
kparams->n_threads = (kparams->n_kv_blocks >= 3 && sess->n_threads >= 2) ? sess->n_threads : 1;
|
||||
|
||||
kparams->u.hmx.g_br = hex_align_up(G * Br, 32);
|
||||
kparams->u.hmx.pipeline = (kparams->n_kv_blocks >= 3 && sess->n_threads >= 2) ? 1 : 0;
|
||||
kparams->vtcm_size = hmx_fa_compute_vtcm_usage(G, DK, DV, Br, Bc, kparams->n_threads, kparams->u.hmx.pipeline != 0);
|
||||
|
||||
const size_t row_vec_bytes = hex_align_up(Bc * sizeof(uint16_t), 256);
|
||||
kparams->u.hmx.row_buf_stride = row_vec_bytes / 128; // HVX vector is 128 bytes
|
||||
|
||||
const size_t m_line_bytes = hex_align_up(Bc * sizeof(uint16_t), 128);
|
||||
kparams->u.hmx.mask_buf_row_stride = m_line_bytes / sizeof(uint16_t);
|
||||
kparams->u.hmx.mask_broadcast = (mask != nullptr && mask->ne[2] == 1) ? 1 : 0;
|
||||
kparams->u.hmx.div_G = init_fastdiv_values(G);
|
||||
if (mask) {
|
||||
kparams->src3_div2 = init_fastdiv_values(mask->ne[2]);
|
||||
kparams->src3_div3 = init_fastdiv_values(mask->ne[3]);
|
||||
}
|
||||
|
||||
kparams->qrows = 0;
|
||||
kparams->qrows_per_thread = 0;
|
||||
return true;
|
||||
}
|
||||
}
|
||||
|
||||
// Fallback to HVX
|
||||
kparams->kernel_type = HTP_FA_KERNEL_HVX;
|
||||
kparams->Br = 1;
|
||||
kparams->Bc = 64; // FLASH_ATTN_BLOCK_SIZE
|
||||
kparams->n_kv_blocks = (k->ne[1] + 64 - 1) / 64;
|
||||
kparams->n_threads = sess->n_threads;
|
||||
|
||||
const size_t size_q_row_padded = hex_round_up(q->ne[0] * (kparams->is_q_fp32 ? 4 : 2), 128);
|
||||
const size_t size_k_row_padded = hex_round_up(k->ne[0] * 2, 128);
|
||||
const size_t size_v_row_padded = hex_round_up(v->ne[0] * 2, 128);
|
||||
|
||||
kparams->vtcm_size = hvx_fa_compute_vtcm_usage(DK, DV, kparams->is_q_fp32 != 0, mask != nullptr, sess->n_threads);
|
||||
|
||||
kparams->u.hvx.size_q_row_padded = size_q_row_padded;
|
||||
kparams->u.hvx.size_k_row_padded = size_k_row_padded;
|
||||
kparams->u.hvx.size_v_row_padded = size_v_row_padded;
|
||||
kparams->u.hvx.src0_div21 = init_fastdiv_values(q->ne[2] * q->ne[1]);
|
||||
kparams->u.hvx.src0_div1 = init_fastdiv_values(q->ne[1]);
|
||||
kparams->u.hvx.broadcast_rk2 = init_fastdiv_values(q->ne[2]/k->ne[2]);
|
||||
kparams->u.hvx.broadcast_rk3 = init_fastdiv_values(q->ne[3]/k->ne[3]);
|
||||
kparams->u.hvx.broadcast_rv2 = init_fastdiv_values(q->ne[2]/v->ne[2]);
|
||||
kparams->u.hvx.broadcast_rv3 = init_fastdiv_values(q->ne[3]/v->ne[3]);
|
||||
if (mask) {
|
||||
kparams->src3_div2 = init_fastdiv_values(mask->ne[2]);
|
||||
kparams->src3_div3 = init_fastdiv_values(mask->ne[3]);
|
||||
}
|
||||
|
||||
kparams->qrows = q->ne[1] * q->ne[2] * q->ne[3];
|
||||
kparams->qrows_per_thread = (kparams->qrows + sess->n_threads - 1) / sess->n_threads;
|
||||
|
||||
return true;
|
||||
}
|
||||
|
||||
static bool ggml_hexagon_supported_flash_attn_ext(const struct ggml_hexagon_session * sess, const struct ggml_tensor * op) {
|
||||
const struct ggml_tensor * src0 = op->src[0];
|
||||
@@ -1912,6 +2075,17 @@ static bool ggml_hexagon_supported_flash_attn_ext(const struct ggml_hexagon_sess
|
||||
return false;
|
||||
}
|
||||
|
||||
struct htp_fa_kernel_params kparams;
|
||||
if (!ggml_hexagon_precompute_flash_attn_params(sess, op, &kparams)) {
|
||||
return false;
|
||||
}
|
||||
|
||||
if ((size_t) kparams.vtcm_size > sess->vtcm_size) {
|
||||
HEX_VERBOSE("ggml-hex: skip flash_attn_ext because VTCM needed (%d) > budget (%zu)\n",
|
||||
kparams.vtcm_size, sess->vtcm_size);
|
||||
return false;
|
||||
}
|
||||
|
||||
return true;
|
||||
}
|
||||
|
||||
@@ -2211,14 +2385,14 @@ static void ggml_hexagon_precompute_hvx_mm_params(
|
||||
kparams->kernel_type = (src1_nrows < (int) sess->n_threads) ? HTP_MM_KERNEL_HVX_QUANT_BLOCK : HTP_MM_KERNEL_HVX_QUANT_ROW;
|
||||
kparams->src1_row_size = (wtype == GGML_TYPE_Q4_1) ? htp_mm_q8_1_tiled_row_size(ne10) : htp_mm_q8_0_tiled_row_size(ne10);
|
||||
|
||||
size_t vtcm_src0_size = 0, vtcm_src1_size = 0;
|
||||
size_t vtcm_src0_size = 0, vtcm_src1_size = 0, vtcm_dst_size = 0;
|
||||
uint32_t max_prefetch = (src1_nrows > HTP_MM_HMX_MIN_NROWS) ? 2 : 16;
|
||||
uint32_t best_n_prefetch = 2;
|
||||
size_t total_size = 0;
|
||||
for (uint32_t d = max_prefetch; d >= 2; d /= 2) {
|
||||
total_size = htp_mm_hvx_id_get_vtcm_sizes(
|
||||
wtype, ne10, src1_nrows, sess->n_threads, src0->nb[1], d,
|
||||
&vtcm_src0_size, &vtcm_src1_size
|
||||
&vtcm_src0_size, &vtcm_src1_size, &vtcm_dst_size
|
||||
);
|
||||
if (total_size <= vtcm_budget) {
|
||||
best_n_prefetch = d;
|
||||
@@ -2228,14 +2402,14 @@ static void ggml_hexagon_precompute_hvx_mm_params(
|
||||
if (best_n_prefetch == 2 && total_size > vtcm_budget) {
|
||||
total_size = htp_mm_hvx_id_get_vtcm_sizes(
|
||||
wtype, ne10, src1_nrows, sess->n_threads, src0->nb[1], 2,
|
||||
&vtcm_src0_size, &vtcm_src1_size
|
||||
&vtcm_src0_size, &vtcm_src1_size, &vtcm_dst_size
|
||||
);
|
||||
}
|
||||
kparams->n_prefetch = best_n_prefetch;
|
||||
kparams->vtcm_size = total_size;
|
||||
kparams->vtcm_src0_size = vtcm_src0_size;
|
||||
kparams->vtcm_src1_size = vtcm_src1_size;
|
||||
kparams->vtcm_dst_size = 0;
|
||||
kparams->vtcm_dst_size = vtcm_dst_size;
|
||||
} else {
|
||||
bool try_tiled = (k_align && opt_mm_select >= 2);
|
||||
if (try_tiled) {
|
||||
@@ -2441,11 +2615,12 @@ static void ggml_hexagon_precompute_fused_qkv_params(
|
||||
size_t src3_sz_per_thread = 0;
|
||||
uint32_t best_n_prefetch = 16;
|
||||
|
||||
size_t quant_scratch_size = hex_round_up(ne10 * sizeof(float), QK_Q8_0_TILED * sizeof(float)) * sess->n_threads;
|
||||
|
||||
if (is_repack) {
|
||||
uint32_t aligned_tile_size = htp_mm_get_weight_aligned_tile_size(wtype);
|
||||
uint32_t n_k_tiles = hex_round_up(ne10, 32) / 32;
|
||||
uint32_t tile_row_size = n_k_tiles * aligned_tile_size;
|
||||
size_t src1_row_size_padded = hex_round_up(src1_row_size, QK_Q8_0_TILED * sizeof(float));
|
||||
size_t src1_sz_per_thread = hex_round_up(src1_row_size * src1_nrows, 128);
|
||||
size_t src1_sz = src1_sz_per_thread;
|
||||
|
||||
@@ -2453,13 +2628,10 @@ static void ggml_hexagon_precompute_fused_qkv_params(
|
||||
best_n_prefetch = 2;
|
||||
for (uint32_t d = max_prefetch; d >= 2; d /= 2) {
|
||||
size_t repacked_vtcm_size = hex_round_up(d * tile_row_size, 128);
|
||||
if (repacked_vtcm_size < src1_row_size_padded) {
|
||||
repacked_vtcm_size = src1_row_size_padded;
|
||||
}
|
||||
size_t src0_sz = repacked_vtcm_size * sess->n_threads;
|
||||
size_t src2_sz = hex_round_up(d * tile_row_size, 128) * sess->n_threads;
|
||||
size_t src3_sz = hex_round_up(d * tile_row_size, 128) * sess->n_threads;
|
||||
size_t tiled_vtcm_size = src0_sz + src1_sz + src2_sz + src3_sz;
|
||||
size_t tiled_vtcm_size = src0_sz + src1_sz + src2_sz + src3_sz + quant_scratch_size;
|
||||
|
||||
if (tiled_vtcm_size <= sess->vtcm_size) {
|
||||
best_n_prefetch = d;
|
||||
@@ -2471,9 +2643,6 @@ static void ggml_hexagon_precompute_fused_qkv_params(
|
||||
}
|
||||
if (best_n_prefetch == 2 && src0_sz_per_thread == 0) {
|
||||
size_t repacked_vtcm_size = hex_round_up(2 * tile_row_size, 128);
|
||||
if (repacked_vtcm_size < src1_row_size_padded) {
|
||||
repacked_vtcm_size = src1_row_size_padded;
|
||||
}
|
||||
src0_sz_per_thread = repacked_vtcm_size;
|
||||
src2_sz_per_thread = hex_round_up(2 * tile_row_size, 128);
|
||||
src3_sz_per_thread = hex_round_up(2 * tile_row_size, 128);
|
||||
@@ -2492,7 +2661,7 @@ static void ggml_hexagon_precompute_fused_qkv_params(
|
||||
size_t src2_sz = src2_sz_per_thread * sess->n_threads;
|
||||
size_t src3_sz = src3_sz_per_thread * sess->n_threads;
|
||||
|
||||
size_t tiled_vtcm_size = src0_sz + src1_sz + src2_sz + src3_sz;
|
||||
size_t tiled_vtcm_size = src0_sz + src1_sz + src2_sz + src3_sz + quant_scratch_size;
|
||||
bool try_tiled = (opt_mm_select >= 2);
|
||||
if (try_tiled && tiled_vtcm_size <= sess->vtcm_size) {
|
||||
kparams->kernel_type = HTP_MM_KERNEL_HVX_QUANT_ROW;
|
||||
@@ -2500,6 +2669,7 @@ static void ggml_hexagon_precompute_fused_qkv_params(
|
||||
kparams->vtcm_src1_size = src1_sz;
|
||||
kparams->vtcm_src2_size = src2_sz;
|
||||
kparams->vtcm_src3_size = src3_sz;
|
||||
kparams->vtcm_dst_size = quant_scratch_size;
|
||||
kparams->vtcm_size = tiled_vtcm_size;
|
||||
kparams->n_prefetch = best_n_prefetch;
|
||||
} else {
|
||||
@@ -2510,7 +2680,8 @@ static void ggml_hexagon_precompute_fused_qkv_params(
|
||||
kparams->vtcm_src1_size = flat_src1_sz;
|
||||
kparams->vtcm_src2_size = src2_sz;
|
||||
kparams->vtcm_src3_size = src3_sz;
|
||||
kparams->vtcm_size = src0_sz + flat_src1_sz + src2_sz + src3_sz;
|
||||
kparams->vtcm_dst_size = quant_scratch_size;
|
||||
kparams->vtcm_size = src0_sz + flat_src1_sz + src2_sz + src3_sz + quant_scratch_size;
|
||||
kparams->n_prefetch = best_n_prefetch;
|
||||
}
|
||||
}
|
||||
@@ -2536,11 +2707,12 @@ static void ggml_hexagon_precompute_fused_ffn_params(
|
||||
size_t src2_sz_per_thread = 0;
|
||||
uint32_t best_n_prefetch = 16;
|
||||
|
||||
size_t quant_scratch_size = hex_round_up(ne10 * sizeof(float), QK_Q8_0_TILED * sizeof(float)) * sess->n_threads;
|
||||
|
||||
if (is_repack) {
|
||||
uint32_t aligned_tile_size = htp_mm_get_weight_aligned_tile_size(wtype);
|
||||
uint32_t n_k_tiles = hex_round_up(ne10, 32) / 32;
|
||||
uint32_t tile_row_size = n_k_tiles * aligned_tile_size;
|
||||
size_t src1_row_size_padded = hex_round_up(src1_row_size, QK_Q8_0_TILED * sizeof(float));
|
||||
size_t src1_sz_per_thread = hex_round_up(src1_row_size * src1_nrows, 128);
|
||||
size_t src1_sz = src1_sz_per_thread;
|
||||
|
||||
@@ -2548,12 +2720,9 @@ static void ggml_hexagon_precompute_fused_ffn_params(
|
||||
best_n_prefetch = 2;
|
||||
for (uint32_t d = max_prefetch; d >= 2; d /= 2) {
|
||||
size_t repacked_vtcm_size = hex_round_up(d * tile_row_size, 128);
|
||||
if (repacked_vtcm_size < src1_row_size_padded) {
|
||||
repacked_vtcm_size = src1_row_size_padded;
|
||||
}
|
||||
size_t src0_sz = repacked_vtcm_size * sess->n_threads;
|
||||
size_t src2_sz = hex_round_up(d * tile_row_size, 128) * sess->n_threads;
|
||||
size_t tiled_vtcm_size = src0_sz + src1_sz + src2_sz;
|
||||
size_t tiled_vtcm_size = src0_sz + src1_sz + src2_sz + quant_scratch_size;
|
||||
|
||||
if (tiled_vtcm_size <= sess->vtcm_size) {
|
||||
best_n_prefetch = d;
|
||||
@@ -2564,9 +2733,6 @@ static void ggml_hexagon_precompute_fused_ffn_params(
|
||||
}
|
||||
if (best_n_prefetch == 2 && src0_sz_per_thread == 0) {
|
||||
size_t repacked_vtcm_size = hex_round_up(2 * tile_row_size, 128);
|
||||
if (repacked_vtcm_size < src1_row_size_padded) {
|
||||
repacked_vtcm_size = src1_row_size_padded;
|
||||
}
|
||||
src0_sz_per_thread = repacked_vtcm_size;
|
||||
src2_sz_per_thread = hex_round_up(2 * tile_row_size, 128);
|
||||
}
|
||||
@@ -2582,13 +2748,14 @@ static void ggml_hexagon_precompute_fused_ffn_params(
|
||||
size_t src1_sz = src1_sz_per_thread;
|
||||
size_t src2_sz = src2_sz_per_thread * sess->n_threads;
|
||||
|
||||
size_t tiled_vtcm_size = src0_sz + src1_sz + src2_sz;
|
||||
size_t tiled_vtcm_size = src0_sz + src1_sz + src2_sz + quant_scratch_size;
|
||||
bool try_tiled = (opt_mm_select >= 2);
|
||||
if (try_tiled && tiled_vtcm_size <= sess->vtcm_size) {
|
||||
kparams->kernel_type = HTP_MM_KERNEL_HVX_QUANT_ROW;
|
||||
kparams->vtcm_src0_size = src0_sz;
|
||||
kparams->vtcm_src1_size = src1_sz;
|
||||
kparams->vtcm_src2_size = src2_sz;
|
||||
kparams->vtcm_dst_size = quant_scratch_size;
|
||||
kparams->vtcm_size = tiled_vtcm_size;
|
||||
kparams->n_prefetch = best_n_prefetch;
|
||||
} else {
|
||||
@@ -2598,7 +2765,8 @@ static void ggml_hexagon_precompute_fused_ffn_params(
|
||||
kparams->vtcm_src0_size = src0_sz;
|
||||
kparams->vtcm_src1_size = flat_src1_sz;
|
||||
kparams->vtcm_src2_size = src2_sz;
|
||||
kparams->vtcm_size = src0_sz + flat_src1_sz + src2_sz;
|
||||
kparams->vtcm_dst_size = quant_scratch_size;
|
||||
kparams->vtcm_size = src0_sz + flat_src1_sz + src2_sz + quant_scratch_size;
|
||||
kparams->n_prefetch = best_n_prefetch;
|
||||
}
|
||||
}
|
||||
@@ -3243,7 +3411,7 @@ static inline bool op_is_compute(ggml_tensor *node)
|
||||
return !ggml_op_is_empty(node->op) && !ggml_is_empty(node) && (node->flags & GGML_TENSOR_FLAG_COMPUTE);
|
||||
}
|
||||
|
||||
static bool is_hmx_eligible(const ggml_tensor * t) {
|
||||
static bool mm_is_hmx_eligible(const ggml_tensor * t) {
|
||||
if (opt_nhmx == 0) { return false; }
|
||||
|
||||
const ggml_tensor * src0 = t->src[0];
|
||||
@@ -3262,7 +3430,7 @@ static bool is_hmx_eligible(const ggml_tensor * t) {
|
||||
static bool is_mergeable_mul_mat(const ggml_tensor * t) {
|
||||
if (!t || t->op != GGML_OP_MUL_MAT) return false;
|
||||
if (t->src[1]->type != GGML_TYPE_F32) return false;
|
||||
return ggml_is_quantized(t->src[0]->type) && !is_hmx_eligible(t);
|
||||
return ggml_is_quantized(t->src[0]->type) && !mm_is_hmx_eligible(t);
|
||||
}
|
||||
|
||||
static bool is_mergeable_mul_mat_pair(const ggml_tensor * n1, const ggml_tensor * n2) {
|
||||
@@ -3357,6 +3525,26 @@ static bool try_fuse_node(const ggml_hexagon_session * sess, const ggml_cgraph *
|
||||
}
|
||||
}
|
||||
|
||||
if (n->op == GGML_OP_MUL_MAT && next_node) {
|
||||
if (next_node->op == GGML_OP_ADD && op_is_compute(next_node) && ggml_can_fuse(graph, i, { GGML_OP_MUL_MAT, GGML_OP_ADD })) {
|
||||
if (next_node->src[0] == n || next_node->src[1] == n) {
|
||||
struct htp_mm_kernel_params kparams;
|
||||
ggml_hexagon_precompute_matmul_params(sess, n->src[0], n->src[1], next_node, &kparams);
|
||||
if ((size_t)kparams.vtcm_size <= sess->vtcm_size) {
|
||||
htp_opnode node(n, {}, HTP_OP_MUL_MAT_ADD);
|
||||
node.add_fused(next_node);
|
||||
memcpy(node.kernel_params, &kparams, sizeof(kparams));
|
||||
nodes.push_back(std::move(node));
|
||||
i += 1;
|
||||
return true;
|
||||
} else {
|
||||
HEX_VERBOSE("ggml-hex: skip MUL_MAT_ADD fusion because VTCM needed (%d) > budget (%zu)\n",
|
||||
kparams.vtcm_size, sess->vtcm_size);
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
return false;
|
||||
}
|
||||
|
||||
@@ -3393,6 +3581,11 @@ static ggml_status ggml_backend_hexagon_graph_compute(ggml_backend_t backend, gg
|
||||
node.node->src[0], node.node->src[1], node.node,
|
||||
(struct htp_mm_kernel_params *)node.kernel_params
|
||||
);
|
||||
} else if (node.opcode == HTP_OP_FLASH_ATTN_EXT) {
|
||||
ggml_hexagon_precompute_flash_attn_params(sess,
|
||||
node.node,
|
||||
(struct htp_fa_kernel_params *)node.kernel_params
|
||||
);
|
||||
}
|
||||
computed_nodes.push_back(std::move(node));
|
||||
}
|
||||
@@ -4079,6 +4272,7 @@ static void ggml_hexagon_init(ggml_backend_reg * reg) {
|
||||
const char * str_use_hmx = getenv("GGML_HEXAGON_USE_HMX");
|
||||
const char * str_nhmx = getenv("GGML_HEXAGON_NHMX");
|
||||
const char * str_mm_select = getenv("GGML_HEXAGON_MM_SELECT");
|
||||
const char * str_fa_select = getenv("GGML_HEXAGON_FA_SELECT");
|
||||
const char * str_ndev = getenv("GGML_HEXAGON_NDEV");
|
||||
const char * str_arch = getenv("GGML_HEXAGON_ARCH");
|
||||
const char * str_vmem = getenv("GGML_HEXAGON_VMEM");
|
||||
@@ -4120,6 +4314,7 @@ static void ggml_hexagon_init(ggml_backend_reg * reg) {
|
||||
opt_nhvx = str_nhvx ? strtoul(str_nhvx, NULL, 0) : opt_nhvx;
|
||||
opt_nhmx = str_nhmx ? atoi(str_nhmx) : (str_use_hmx ? atoi(str_use_hmx) : opt_nhmx);
|
||||
opt_mm_select = str_mm_select ? atoi(str_mm_select) : opt_mm_select;
|
||||
opt_fa_select = str_fa_select ? atoi(str_fa_select) : opt_fa_select;
|
||||
opt_ndev = str_ndev ? strtoul(str_ndev, NULL, 0) : opt_ndev;
|
||||
opt_hostbuf = str_hostbuf ? atoi(str_hostbuf) : opt_hostbuf;
|
||||
opt_mbuf = str_mbuf ? strtoul(str_mbuf, NULL, 0) * MiB : opt_mbuf;
|
||||
|
||||
@@ -11,6 +11,7 @@
|
||||
#include <stdio.h>
|
||||
#include "htp-ops.h"
|
||||
#include "htp/matmul-ops.h"
|
||||
#include "htp/flash-attn-ops.h"
|
||||
|
||||
struct htp_opnode {
|
||||
ggml_tensor * node = nullptr;
|
||||
@@ -335,7 +336,8 @@ struct htp_opformat {
|
||||
}
|
||||
void format_kernel_params(char * str, size_t max_size, const htp_opnode & node) {
|
||||
if (node.opcode == HTP_OP_MUL_MAT || node.opcode == HTP_OP_MUL_MAT_ID ||
|
||||
node.opcode == HTP_OP_MUL_MAT_QKV || node.opcode == HTP_OP_MUL_MAT_FFN) {
|
||||
node.opcode == HTP_OP_MUL_MAT_QKV || node.opcode == HTP_OP_MUL_MAT_FFN ||
|
||||
node.opcode == HTP_OP_MUL_MAT_ADD) {
|
||||
const auto * kparams = (const struct htp_mm_kernel_params *) node.kernel_params;
|
||||
const char * path = "unknown";
|
||||
int32_t type = kparams->kernel_type;
|
||||
@@ -350,6 +352,16 @@ struct htp_opformat {
|
||||
path = "hvx-flat";
|
||||
}
|
||||
snprintf(str, max_size, "%s vtcm %d", path, (int) kparams->vtcm_size);
|
||||
} else if (node.opcode == HTP_OP_FLASH_ATTN_EXT) {
|
||||
const auto * kparams = (const struct htp_fa_kernel_params *) node.kernel_params;
|
||||
const char * path = "unknown";
|
||||
int32_t type = kparams->kernel_type;
|
||||
if (type == HTP_FA_KERNEL_HMX) {
|
||||
path = kparams->u.hmx.pipeline ? "hmx-pipe" : "hmx-seq";
|
||||
} else if (type == HTP_FA_KERNEL_HVX) {
|
||||
path = "hvx";
|
||||
}
|
||||
snprintf(str, max_size, "%s vtcm %d", path, (int) kparams->vtcm_size);
|
||||
} else {
|
||||
snprintf(str, max_size, "----");
|
||||
}
|
||||
|
||||
@@ -20,9 +20,6 @@ add_library(${HTP_LIB} SHARED
|
||||
worker-pool.c
|
||||
hex-dma.c
|
||||
hmx-queue.c
|
||||
flash-attn-ops.c
|
||||
hmx-flash-attn-ops.c
|
||||
matmul-ops.c
|
||||
binary-ops.c
|
||||
unary-ops.c
|
||||
sum-rows-ops.c
|
||||
@@ -42,16 +39,14 @@ add_library(${HTP_LIB} SHARED
|
||||
solve-tri-ops.c
|
||||
gated-delta-net-ops.c
|
||||
pad-ops.c
|
||||
matmul-ops.c
|
||||
flash-attn-ops.c
|
||||
)
|
||||
|
||||
target_compile_definitions(${HTP_LIB} PRIVATE
|
||||
$<IF:$<BOOL:${HEXAGON_HTP_DEBUG}>,HTP_DEBUG=1,NDEBUG=1>
|
||||
$<IF:$<BOOL:${HEXAGON_HTP_DEBUG}>,FARF_HIGH=1,>)
|
||||
|
||||
if (GGML_HEXAGON_FA_EXP2_HF)
|
||||
message(STATUS "ggml-htp: HMX_FA_USE_EXP2_HF=1 (use FP16 exp2 polynomial in FA softmax)")
|
||||
target_compile_definitions(${HTP_LIB} PRIVATE HMX_FA_USE_EXP2_HF=1)
|
||||
endif()
|
||||
|
||||
build_idl(htp_iface.idl ${HTP_LIB})
|
||||
|
||||
|
||||
File diff suppressed because it is too large
Load Diff
@@ -0,0 +1,253 @@
|
||||
#ifndef HTP_FLASH_ATTN_OPS_H
|
||||
#define HTP_FLASH_ATTN_OPS_H
|
||||
|
||||
#include <stdint.h>
|
||||
#include <stddef.h>
|
||||
#include <stdbool.h>
|
||||
|
||||
#include "hex-fastdiv.h"
|
||||
#include "hex-common.h"
|
||||
|
||||
#ifdef __cplusplus
|
||||
extern "C" {
|
||||
#endif
|
||||
|
||||
// Tile constants (mirrored from hmx-utils.h for use on host side if needed)
|
||||
#define HMX_FP16_TILE_N_ROWS 32
|
||||
#define HMX_FP16_TILE_N_COLS 32
|
||||
#define HMX_FP16_TILE_N_ELMS 1024
|
||||
#define HMX_FP16_TILE_SIZE 2048
|
||||
#define HVX_FA_DMA_CACHE_SIZE 128
|
||||
#define HMX_FA_DMA_CACHE_SIZE 4
|
||||
|
||||
#define HTP_FA_M_INITIAL_VAL -10000.0f
|
||||
|
||||
enum htp_fa_kernel_type {
|
||||
HTP_FA_KERNEL_UNSUPPORTED = 0,
|
||||
HTP_FA_KERNEL_HVX,
|
||||
HTP_FA_KERNEL_HMX
|
||||
};
|
||||
|
||||
struct htp_fa_kernel_params {
|
||||
uint8_t kernel_type; // enum htp_fa_kernel_type
|
||||
uint8_t is_q_fp32; // 1 = Q type is F32, 0 = F16
|
||||
uint8_t is_dst_fp32; // 1 = dst type is F32, 0 = F16
|
||||
uint8_t n_threads; // Number of threads to run
|
||||
|
||||
// Common parameters
|
||||
uint16_t Br;
|
||||
uint16_t Bc;
|
||||
uint16_t n_kv_blocks; // also HVX's n_blocks
|
||||
uint16_t G; // GQA factor (n_heads / n_kv_heads)
|
||||
|
||||
float scale;
|
||||
float max_bias;
|
||||
float logit_softcap;
|
||||
uint32_t vtcm_size;
|
||||
|
||||
uint32_t qrows;
|
||||
uint32_t qrows_per_thread;
|
||||
float m0;
|
||||
float m1;
|
||||
uint32_t n_head_log2;
|
||||
|
||||
struct fastdiv_values src3_div2;
|
||||
struct fastdiv_values src3_div3;
|
||||
|
||||
union {
|
||||
struct {
|
||||
uint32_t g_br;
|
||||
uint32_t row_buf_stride;
|
||||
uint32_t mask_buf_row_stride;
|
||||
int32_t mask_broadcast;
|
||||
int32_t pipeline;
|
||||
struct fastdiv_values div_G;
|
||||
} hmx;
|
||||
struct {
|
||||
uint32_t size_q_row_padded;
|
||||
uint32_t size_k_row_padded;
|
||||
uint32_t size_v_row_padded;
|
||||
struct fastdiv_values src0_div21;
|
||||
struct fastdiv_values src0_div1;
|
||||
struct fastdiv_values broadcast_rk2;
|
||||
struct fastdiv_values broadcast_rk3;
|
||||
struct fastdiv_values broadcast_rv2;
|
||||
struct fastdiv_values broadcast_rv3;
|
||||
} hvx;
|
||||
} u;
|
||||
};
|
||||
|
||||
#if defined(__cplusplus)
|
||||
static_assert(sizeof(struct htp_fa_kernel_params) <= 128, "htp_fa_kernel_params is too large for kernel_params blob");
|
||||
#endif
|
||||
|
||||
// Exact VTCM usage for a given (gqa_factor, DK, DV, Br, Bc) configuration.
|
||||
// g_br = hex_align_up(gqa_factor * Br, 32) replaces Br for all Q/O/S/P/D dimensions.
|
||||
// Layout: Q + O_ping + O_pong + K_dma*2 + V_dma*2 + K_tile + V_tile + S + P + D + vectors + scales
|
||||
// Mask is DMA'd into a VTCM buffer (Br rows per KV block) to avoid DDR reads in softmax.
|
||||
static inline size_t hmx_fa_compute_vtcm_usage(size_t gqa_factor, size_t DK, size_t DV, size_t Br, size_t Bc, size_t n_threads, bool pipeline) {
|
||||
const size_t g_br = hex_align_up(gqa_factor * Br, HMX_FP16_TILE_N_ROWS);
|
||||
const size_t q_tile_size = hex_align_up(g_br * DK * sizeof(__fp16), 4096); // Q: [g_br, DK]
|
||||
const size_t o_tile_size = hex_align_up(g_br * DV * sizeof(__fp16), 4096); // O: [g_br, DV] x2 ping-pong
|
||||
const size_t k_dma_size = hex_align_up(Bc * hex_round_up(DK * sizeof(__fp16), 128), 4096); // K DMA: [Bc, DK] x2 double-buf
|
||||
const size_t v_dma_size = hex_align_up(Bc * hex_round_up(DV * sizeof(__fp16), 128), 4096); // V DMA: [Bc, DV] x2 double-buf
|
||||
const size_t k_tile_size = hex_align_up(Bc * DK * sizeof(__fp16), 4096); // K tiles: [Bc, DK] interleaved
|
||||
const size_t v_tile_size = hex_align_up(Bc * DV * sizeof(__fp16), 4096); // V tiles: [Bc, DV] interleaved
|
||||
const size_t s_tile_size = hex_align_up(g_br * Bc * sizeof(__fp16), 4096); // S/P:[g_br, Bc]
|
||||
const size_t d_tile_size = hex_align_up(g_br * g_br * sizeof(__fp16), 4096); // D: [g_br, g_br]
|
||||
const size_t col_vec_size = hex_align_up(g_br * sizeof(float), 256); // m, l, etc.
|
||||
const size_t row_vec_size = hex_align_up(Bc * sizeof(__fp16), 256);
|
||||
const size_t m_line_size = hex_align_up(Bc * sizeof(__fp16), 128);
|
||||
const size_t m_buf_size = hex_align_up(Br * m_line_size, 4096) * HMX_FA_DMA_CACHE_SIZE;
|
||||
const size_t slopes_size = hex_align_up(g_br * sizeof(__fp16), 128);
|
||||
|
||||
return q_tile_size * 1 // Q tiles
|
||||
+ o_tile_size * 2 // O ping-pong
|
||||
+ k_dma_size * 2 // K DMA x2
|
||||
+ v_dma_size * 2 // V DMA x2
|
||||
+ k_tile_size * 1 // K tiles
|
||||
+ v_tile_size * (pipeline ? 2 : 1) // V tiles (double-buffered if pipelining)
|
||||
+ s_tile_size * 2 // S + P
|
||||
+ d_tile_size * 1 // D (diagonal matrix)
|
||||
+ col_vec_size * 4 // m_vec, l_vec, s_rowmax, p_rowsum
|
||||
+ row_vec_size * 2 * n_threads // per-thread softmax row scratch
|
||||
+ m_buf_size * 1 // mask VTCM buffer [Br rows]
|
||||
+ slopes_size // Slopes
|
||||
+ 256 * 2; // HMX scales (id + qk)
|
||||
}
|
||||
|
||||
#define FA_HVX_BLOCK_SIZE 64
|
||||
|
||||
static inline size_t hvx_fa_compute_vtcm_usage(size_t DK, size_t DV, bool is_q_fp32, bool has_mask, size_t n_threads) {
|
||||
const size_t size_q_row_padded = hex_round_up(DK * (is_q_fp32 ? 4 : 2), 128);
|
||||
const size_t size_k_row_padded = hex_round_up(DK * sizeof(__fp16), 128);
|
||||
const size_t size_v_row_padded = hex_round_up(DV * sizeof(__fp16), 128);
|
||||
|
||||
const size_t size_q_block = size_q_row_padded * 1;
|
||||
const size_t size_k_block = size_k_row_padded * FA_HVX_BLOCK_SIZE;
|
||||
const size_t size_v_block = size_v_row_padded * FA_HVX_BLOCK_SIZE;
|
||||
const size_t size_m_block = hex_round_up(FA_HVX_BLOCK_SIZE * sizeof(__fp16), 128);
|
||||
const size_t size_vkq_acc = hex_round_up(DV * sizeof(float), 128);
|
||||
|
||||
const size_t size_per_thread = size_q_block * 1
|
||||
+ size_k_block * 2
|
||||
+ size_v_block * 2
|
||||
+ (has_mask ? size_m_block * HVX_FA_DMA_CACHE_SIZE : 0)
|
||||
+ size_vkq_acc;
|
||||
|
||||
return size_per_thread * n_threads;
|
||||
}
|
||||
|
||||
#define FA_MIN_KV_BLOCKS 3
|
||||
|
||||
// Cost-based (Br, Bc) search for flash attention with pipeline constraint.
|
||||
static inline int hmx_fa_find_chunk_size(size_t * Br_out,
|
||||
size_t * Bc_out,
|
||||
size_t gqa_factor,
|
||||
size_t DK,
|
||||
size_t DV,
|
||||
size_t qo_len,
|
||||
size_t kv_len,
|
||||
size_t vtcm_budget,
|
||||
size_t n_threads) {
|
||||
const size_t T = HMX_FP16_TILE_N_ROWS; // 32
|
||||
const size_t br_unit = hmx_ceil_div(T, gqa_factor);
|
||||
const size_t bc_unit = HMX_FP16_TILE_N_COLS * 2; // 64
|
||||
const size_t fp16 = sizeof(__fp16);
|
||||
const bool can_pipeline = (kv_len >= FA_MIN_KV_BLOCKS * bc_unit && n_threads >= 2);
|
||||
|
||||
// Approximate per-unit VTCM costs (without per-buffer alignment padding).
|
||||
const size_t per_gbr = (DK + 2 * DV) * fp16 + 4 * sizeof(float); // Q + O*2 + 4 col vectors
|
||||
const size_t per_gbr2 = fp16; // D diagonal matrix
|
||||
const size_t per_bc =
|
||||
3 * DK * fp16 + (can_pipeline ? 4 : 3) * DV * fp16 + 2 * n_threads * fp16; // K/V DMA x2 + tiles + row bufs
|
||||
const size_t per_gbr_bc = 2 * fp16; // S + P
|
||||
|
||||
const size_t overhead = 256 * 2 + 13 * 4096;
|
||||
|
||||
if (vtcm_budget <= overhead) {
|
||||
return -1;
|
||||
}
|
||||
const size_t usable = vtcm_budget - overhead;
|
||||
|
||||
// Br_max: largest Br aligned to br_unit that does not exceed qo_len.
|
||||
const size_t Br_max = qo_len >= br_unit ? hex_align_down(qo_len, br_unit) : br_unit;
|
||||
|
||||
// Pipeline constraint: cap Bc so n_kv_blocks >= FA_MIN_KV_BLOCKS.
|
||||
// Only relax when kv_len is too short to form enough blocks.
|
||||
const size_t Bc_limit = can_pipeline ? hex_align_down(kv_len / FA_MIN_KV_BLOCKS, bc_unit) :
|
||||
(kv_len >= bc_unit ? hex_align_down(kv_len, bc_unit) : bc_unit);
|
||||
// Cost coefficients calibrated from profiling
|
||||
const size_t c_q_fixed = 1400; // per-Q-block: q_load + epilogue o_update + o_norm + o_store
|
||||
const size_t c_iter_fixed = 200; // per-KV-iter: HMX queue push/pop + DMA pop + barriers
|
||||
|
||||
size_t best_cost = SIZE_MAX, best_mn = 0;
|
||||
size_t best_Br = 0, best_Bc = 0;
|
||||
|
||||
for (size_t Br = Br_max; Br >= br_unit; Br -= br_unit) {
|
||||
const size_t g_br = hex_align_up(gqa_factor * Br, T);
|
||||
|
||||
// g_br-dependent VTCM cost: g_br * per_gbr + g_br*g_br * per_gbr2
|
||||
const size_t gbr_cost = g_br * per_gbr + g_br * g_br * per_gbr2;
|
||||
if (gbr_cost >= usable) {
|
||||
if (Br == br_unit) {
|
||||
break;
|
||||
}
|
||||
continue;
|
||||
}
|
||||
|
||||
// Analytically solve for max Bc:
|
||||
// remain >= Bc * (per_bc + g_br * per_gbr_bc + Br * fp16 * HMX_FA_DMA_CACHE_SIZE)
|
||||
// The Br * fp16 term accounts for the VTCM mask buffer [Br * Bc].
|
||||
const size_t remain = usable - gbr_cost;
|
||||
const size_t bc_denom = per_bc + g_br * per_gbr_bc + Br * fp16 * HMX_FA_DMA_CACHE_SIZE;
|
||||
size_t Bc = hex_smin(hex_align_down(remain / bc_denom, bc_unit), Bc_limit);
|
||||
if (Bc < bc_unit) {
|
||||
if (Br == br_unit) {
|
||||
break;
|
||||
}
|
||||
continue;
|
||||
}
|
||||
|
||||
// Exact VTCM verification (alignment padding may push over budget)
|
||||
while (Bc >= bc_unit && hmx_fa_compute_vtcm_usage(gqa_factor, DK, DV, Br, Bc, n_threads, can_pipeline) > vtcm_budget) {
|
||||
Bc -= bc_unit;
|
||||
}
|
||||
if (Bc < bc_unit) {
|
||||
if (Br == br_unit) {
|
||||
break;
|
||||
}
|
||||
continue;
|
||||
}
|
||||
|
||||
const size_t q_blocks = (qo_len + Br - 1) / Br;
|
||||
const size_t kv_blocks = (kv_len + Bc - 1) / Bc;
|
||||
const size_t cost = q_blocks * (c_q_fixed + kv_blocks * c_iter_fixed);
|
||||
const size_t mn = Br * Bc;
|
||||
|
||||
if (cost < best_cost || (cost == best_cost && mn > best_mn)) {
|
||||
best_cost = cost;
|
||||
best_mn = mn;
|
||||
best_Br = Br;
|
||||
best_Bc = Bc;
|
||||
}
|
||||
|
||||
if (Br == br_unit) {
|
||||
break;
|
||||
}
|
||||
}
|
||||
|
||||
if (best_Br == 0) {
|
||||
return -1;
|
||||
}
|
||||
|
||||
*Br_out = best_Br;
|
||||
*Bc_out = best_Bc;
|
||||
return 0;
|
||||
}
|
||||
|
||||
#ifdef __cplusplus
|
||||
}
|
||||
#endif
|
||||
|
||||
#endif /* HTP_FLASH_ATTN_OPS_H */
|
||||
@@ -138,27 +138,28 @@ static inline bool dma_queue_push_single_1d(dma_queue * q, dma_ptr dptr, size_t
|
||||
}
|
||||
|
||||
dma_descriptor_1d * desc = (dma_descriptor_1d *) &q->desc[q->push_idx];
|
||||
desc->next = NULL;
|
||||
desc->desc_size = 0; // 1D mode
|
||||
desc->src_bypass = dma_src_l2_bypass_on;
|
||||
desc->dst_bypass = dma_dst_l2_bypass_on;
|
||||
desc->order = 0;
|
||||
desc->done = 0;
|
||||
desc->src = (void *) dptr.src;
|
||||
desc->dst = (void *) dptr.dst;
|
||||
desc->size = size;
|
||||
desc->src = (void *) dptr.src;
|
||||
desc->dst = (void *) dptr.dst;
|
||||
desc->size = size;
|
||||
|
||||
q->dptr[q->push_idx] = dptr;
|
||||
|
||||
if (size) {
|
||||
desc->next = NULL;
|
||||
desc->desc_size = 0; // 1D mode
|
||||
desc->src_bypass = dma_src_l2_bypass_on;
|
||||
desc->dst_bypass = dma_dst_l2_bypass_on;
|
||||
desc->order = 0;
|
||||
desc->done = 0;
|
||||
|
||||
htp_trace_event_start(q->trace, HTP_TRACE_EVT_DMA, q->push_idx);
|
||||
dmlink(q->tail, desc);
|
||||
q->tail = (dma_descriptor_2d *) desc;
|
||||
} else {
|
||||
desc->done = 1;
|
||||
desc->desc_size = 0;
|
||||
desc->done = 1;
|
||||
}
|
||||
|
||||
// FARF(ERROR, "dma-push: i %u row-size %u nrows %d dst %p src %p\n", q->push_idx, row_size, nrows, dptr.dst, dptr.src);
|
||||
q->push_idx = (q->push_idx + 1) & q->idx_mask;
|
||||
return true;
|
||||
}
|
||||
@@ -320,7 +321,7 @@ static inline bool dma_queue_push_vtcm_to_ddr(dma_queue * q, dma_ptr dptr, size_
|
||||
return dma_queue_push(q, dptr, dst_row_size, src_row_size, dst_row_size, nrows);
|
||||
}
|
||||
|
||||
#define DMA_CACHE_MAX_SIZE 64U
|
||||
#define DMA_CACHE_MAX_SIZE 256U
|
||||
|
||||
typedef struct {
|
||||
uint8_t *base;
|
||||
@@ -352,20 +353,19 @@ static inline bool dma_cache_push(dma_queue *q, dma_cache *c, const uint8_t * sr
|
||||
if (c->src[i] == (uint32_t) src) {
|
||||
c->age[i] = 0;
|
||||
dst = c->base + (i * c->line_size); nrows = 0; // dummy dma
|
||||
// FARF(ERROR, "dma-cache: found %p", src);
|
||||
} else {
|
||||
c->age[i]++;
|
||||
if (c->age[i] > o_age) { o_age = c->age[i]; o_idx = i; }
|
||||
}
|
||||
}
|
||||
if (!dst) {
|
||||
// FARF(ERROR, "dma-cache: replacing #%u : age %u %p -> %p", o_idx, c->age[o_idx], (void *) c->src[o_idx], src);
|
||||
c->age[o_idx] = 0;
|
||||
c->src[o_idx] = (uint32_t) src;
|
||||
dst = c->base + o_idx * c->line_size; // normal nrows dma
|
||||
return dma_queue_push(q, dma_make_ptr(dst, src), dst_stride, src_stride, row_size, nrows);
|
||||
}
|
||||
|
||||
return dma_queue_push(q, dma_make_ptr(dst, src), dst_stride, src_stride, row_size, nrows);
|
||||
return dma_queue_push_single_1d(q, dma_make_ptr(dst, src), 0);
|
||||
}
|
||||
|
||||
#ifdef __cplusplus
|
||||
|
||||
@@ -0,0 +1,96 @@
|
||||
#ifndef HMX_FA_KERNELS_H
|
||||
#define HMX_FA_KERNELS_H
|
||||
|
||||
#include <stdint.h>
|
||||
#include <stddef.h>
|
||||
#include <stdbool.h>
|
||||
#include "hvx-utils.h"
|
||||
#include "hmx-utils.h"
|
||||
|
||||
// HMX-specific parameters, offsets and inner kernels for Flash Attention
|
||||
|
||||
// Scatter offsets for the diagonal tile.
// For i in 0..15: entry[2i] = i*136 and entry[2i+1] = i*136 + 6; entries 32..63 are zero.
// 136 = 4 * 32 + 8 = byte offset to diagonal in a 32x32 fp16 interleaved tile
static const int16_t d_tile_scatter_offsets[64] __attribute__((aligned(128))) = {
    // entries 0..31: sixteen (i*136, i*136 + 6) pairs
     0 * 136,  0 * 136 + 6,    1 * 136,  1 * 136 + 6,
     2 * 136,  2 * 136 + 6,    3 * 136,  3 * 136 + 6,
     4 * 136,  4 * 136 + 6,    5 * 136,  5 * 136 + 6,
     6 * 136,  6 * 136 + 6,    7 * 136,  7 * 136 + 6,
     8 * 136,  8 * 136 + 6,    9 * 136,  9 * 136 + 6,
    10 * 136, 10 * 136 + 6,   11 * 136, 11 * 136 + 6,
    12 * 136, 12 * 136 + 6,   13 * 136, 13 * 136 + 6,
    14 * 136, 14 * 136 + 6,   15 * 136, 15 * 136 + 6,
    // entries 32..63: unused, kept at zero
    0, 0, 0, 0, 0, 0, 0, 0,
    0, 0, 0, 0, 0, 0, 0, 0,
    0, 0, 0, 0, 0, 0, 0, 0,
    0, 0, 0, 0, 0, 0, 0, 0,
};
|
||||
// Inner HMX tile computation kernels
|
||||
|
||||
static inline void hmx_fa_qk_dot_tile(
|
||||
const __fp16 * row_tiles,
|
||||
const __fp16 * col_tiles,
|
||||
__fp16 * out_tile,
|
||||
size_t n_dot_tiles
|
||||
) {
|
||||
for (size_t k = 0; k < n_dot_tiles; ++k) {
|
||||
Q6_activation_hf_mxmem_RR((unsigned int) row_tiles, 2047);
|
||||
Q6_weight_hf_mxmem_RR((unsigned int) col_tiles, 2047);
|
||||
row_tiles += HMX_FP16_TILE_N_ELMS;
|
||||
col_tiles += HMX_FP16_TILE_N_ELMS;
|
||||
}
|
||||
Q6_mxmem_AR_after_hf(out_tile, 0);
|
||||
}
|
||||
|
||||
static inline void hmx_fa_o_update_tile(
|
||||
const __fp16 * d_diag,
|
||||
const __fp16 * o_rc,
|
||||
const __fp16 * p_tile_in,
|
||||
const __fp16 * v_tile_in,
|
||||
__fp16 * o_tile_out,
|
||||
size_t n_col_tiles
|
||||
) {
|
||||
Q6_activation_hf_mxmem_RR((unsigned int) d_diag, 2047);
|
||||
Q6_weight_hf_mxmem_RR((unsigned int) o_rc, 2047);
|
||||
|
||||
for (size_t k = 0; k < n_col_tiles; ++k) {
|
||||
Q6_activation_hf_mxmem_RR((unsigned int) p_tile_in, 2047);
|
||||
Q6_weight_hf_mxmem_RR((unsigned int) v_tile_in, 2047);
|
||||
p_tile_in += HMX_FP16_TILE_N_ELMS;
|
||||
v_tile_in += HMX_FP16_TILE_N_ELMS;
|
||||
}
|
||||
|
||||
Q6_mxmem_AR_after_hf(o_tile_out, 0);
|
||||
}
|
||||
|
||||
static inline void hmx_fa_o_norm_tile(
|
||||
const __fp16 * d_diag,
|
||||
const __fp16 * o_rc,
|
||||
__fp16 * o_out
|
||||
) {
|
||||
Q6_activation_hf_mxmem_RR((unsigned int) d_diag, 2047);
|
||||
Q6_weight_hf_mxmem_RR((unsigned int) o_rc, 2047);
|
||||
Q6_mxmem_AR_after_hf(o_out, 0);
|
||||
}
|
||||
|
||||
#endif /* HMX_FA_KERNELS_H */
|
||||
File diff suppressed because it is too large
Load Diff
@@ -712,7 +712,17 @@ static inline void hmx_matmul_job_init(hmx_matmul_job_t * job,
|
||||
|
||||
// output : fp16 -> fp32
|
||||
|
||||
static void transfer_output_chunk_fp16_to_fp32(float *restrict dst, const __fp16 *restrict vtcm_src, uint32_t start_row, uint32_t n_rows, uint32_t n_cols, uint32_t dst_stride, uint32_t dst_cols) {
|
||||
static void transfer_output_chunk_fp16_to_fp32(
|
||||
float *restrict dst,
|
||||
const float *restrict src2,
|
||||
const __fp16 *restrict vtcm_src,
|
||||
uint32_t start_row,
|
||||
uint32_t n_rows,
|
||||
uint32_t n_cols,
|
||||
uint32_t dst_stride,
|
||||
uint32_t src2_stride,
|
||||
uint32_t dst_cols
|
||||
) {
|
||||
assert(n_cols % HTP_MM_HMX_TILE_N_COLS == 0);
|
||||
const size_t tile_row_stride = (n_cols / HTP_MM_HMX_TILE_N_COLS) * HTP_MM_HMX_TILE_N_ELMS;
|
||||
|
||||
@@ -727,6 +737,7 @@ static void transfer_output_chunk_fp16_to_fp32(float *restrict dst, const __fp16
|
||||
const size_t r1 = (r_idx0 % HTP_MM_HMX_TILE_N_ROWS) / 2; // index of the row pair within the tile
|
||||
const __fp16 *row_base = vtcm_src + r0 * tile_row_stride;
|
||||
float *output_row_base = dst + r * dst_stride; // global memory row base for row r (and r+1)
|
||||
const float *src2_row_base = src2 ? (src2 + r * src2_stride) : NULL;
|
||||
|
||||
#pragma unroll(4)
|
||||
for (size_t c = 0; c < limit_c_aligned; c += HTP_MM_HMX_TILE_N_COLS) {
|
||||
@@ -738,9 +749,20 @@ static void transfer_output_chunk_fp16_to_fp32(float *restrict dst, const __fp16
|
||||
HVX_Vector *pv_out0 = (HVX_Vector *) (output_row_base + c + 0);
|
||||
HVX_Vector *pv_out1 = (HVX_Vector *) (output_row_base + c + dst_stride);
|
||||
|
||||
*pv_out0 = Q6_Vsf_equals_Vqf32(Q6_V_lo_W(vp));
|
||||
HVX_Vector v_out0 = Q6_Vsf_equals_Vqf32(Q6_V_lo_W(vp));
|
||||
if (src2_row_base) {
|
||||
HVX_Vector v_src2_0 = hvx_vmemu(src2_row_base + c + 0);
|
||||
v_out0 = hvx_vec_add_f32_f32(v_out0, v_src2_0);
|
||||
}
|
||||
*pv_out0 = v_out0;
|
||||
|
||||
if (r + 1 < n_rows) {
|
||||
*pv_out1 = Q6_Vsf_equals_Vqf32(Q6_V_hi_W(vp));
|
||||
HVX_Vector v_out1 = Q6_Vsf_equals_Vqf32(Q6_V_hi_W(vp));
|
||||
if (src2_row_base) {
|
||||
HVX_Vector v_src2_1 = hvx_vmemu(src2_row_base + c + src2_stride);
|
||||
v_out1 = hvx_vec_add_f32_f32(v_out1, v_src2_1);
|
||||
}
|
||||
*pv_out1 = v_out1;
|
||||
}
|
||||
}
|
||||
|
||||
@@ -752,9 +774,20 @@ static void transfer_output_chunk_fp16_to_fp32(float *restrict dst, const __fp16
|
||||
HVX_Vector v = ((const HVX_Vector *) tile)[r1];
|
||||
HVX_VectorPair vp = Q6_Wqf32_vmpy_VhfVhf(v, one);
|
||||
|
||||
hvx_vec_store_u(output_row_base + c, valid_c * sizeof(float), Q6_Vsf_equals_Vqf32(Q6_V_lo_W(vp)));
|
||||
HVX_Vector v_out0 = Q6_Vsf_equals_Vqf32(Q6_V_lo_W(vp));
|
||||
if (src2_row_base) {
|
||||
HVX_Vector v_src2_0 = hvx_vmemu(src2_row_base + c + 0);
|
||||
v_out0 = hvx_vec_add_f32_f32(v_out0, v_src2_0);
|
||||
}
|
||||
hvx_vec_store_u(output_row_base + c, valid_c * sizeof(float), v_out0);
|
||||
|
||||
if (r + 1 < n_rows) {
|
||||
hvx_vec_store_u(output_row_base + c + dst_stride, valid_c * sizeof(float), Q6_Vsf_equals_Vqf32(Q6_V_hi_W(vp)));
|
||||
HVX_Vector v_out1 = Q6_Vsf_equals_Vqf32(Q6_V_hi_W(vp));
|
||||
if (src2_row_base) {
|
||||
HVX_Vector v_src2_1 = hvx_vmemu(src2_row_base + c + src2_stride);
|
||||
v_out1 = hvx_vec_add_f32_f32(v_out1, v_src2_1);
|
||||
}
|
||||
hvx_vec_store_u(output_row_base + c + dst_stride, valid_c * sizeof(float), v_out1);
|
||||
}
|
||||
}
|
||||
}
|
||||
@@ -763,11 +796,13 @@ static void transfer_output_chunk_fp16_to_fp32(float *restrict dst, const __fp16
|
||||
typedef struct {
|
||||
const __fp16 *vtcm_src;
|
||||
float *dst;
|
||||
const float *src2;
|
||||
uint32_t n_tasks;
|
||||
uint32_t n_tot_chunks;
|
||||
uint32_t n_chunks_per_task;
|
||||
uint32_t n_cols;
|
||||
uint32_t dst_stride; // DDR row stride
|
||||
uint32_t src2_stride; // DDR row stride for residual
|
||||
uint32_t dst_cols; // Actual output columns
|
||||
struct htp_thread_trace * traces;
|
||||
} output_transfer_task_state_t;
|
||||
|
||||
@@ -42,14 +42,14 @@ static const int32_t hmx_transpose_scatter_offsets[32] __attribute__((aligned(VL
|
||||
// Full range: start_row=0, end_row=n_cols.
|
||||
static inline void hmx_interleave_rows_to_tiles(__fp16 * restrict vtcm_dst,
|
||||
const __fp16 * restrict vtcm_src,
|
||||
int n_cols,
|
||||
int k,
|
||||
int src_stride,
|
||||
int start_row,
|
||||
int end_row) {
|
||||
uint32_t n_cols,
|
||||
uint32_t k,
|
||||
size_t src_stride,
|
||||
uint32_t start_row,
|
||||
uint32_t end_row) {
|
||||
assert(k % HMX_FP16_TILE_N_COLS == 0);
|
||||
|
||||
const int n_k_tiles = k / HMX_FP16_TILE_N_COLS;
|
||||
const uint32_t n_k_tiles = k / HMX_FP16_TILE_N_COLS;
|
||||
const HVX_Vector v_scat_base = hvx_vmem(hmx_transpose_scatter_offsets);
|
||||
const HVX_Vector v_scat_step = Q6_V_vsplat_R(4);
|
||||
const HVX_VectorPred q_mask64 = Q6_Q_vsetq_R(64);
|
||||
@@ -65,14 +65,14 @@ static inline void hmx_interleave_rows_to_tiles(__fp16 * restrict vtcm_dst,
|
||||
|
||||
if (pair_scatter) {
|
||||
// Step c by 64 fp16 (two K-tiles per scatter), advance dst by 2 tiles per iter.
|
||||
const int c_step = 2 * HMX_FP16_TILE_N_COLS;
|
||||
const size_t c_byte_step = (size_t) c_step * sizeof(__fp16);
|
||||
const size_t dst_step = 2 * (size_t) HMX_FP16_TILE_N_ELMS;
|
||||
const int n_c_iters = k / c_step;
|
||||
const uint32_t c_step = 2 * HMX_FP16_TILE_N_COLS;
|
||||
const size_t c_byte_step = (size_t) c_step * sizeof(__fp16);
|
||||
const size_t dst_step = 2 * (size_t) HMX_FP16_TILE_N_ELMS;
|
||||
const uint32_t n_c_iters = k / c_step;
|
||||
|
||||
for (int r = start_row; r < end_row; r += 2) {
|
||||
const int ct = r / HMX_FP16_TILE_N_ROWS;
|
||||
const int local_r = r % HMX_FP16_TILE_N_ROWS;
|
||||
for (uint32_t r = start_row; r < end_row; r += 2) {
|
||||
const uint32_t ct = r / HMX_FP16_TILE_N_ROWS;
|
||||
const uint32_t local_r = r % HMX_FP16_TILE_N_ROWS;
|
||||
const bool next_row_valid = (r + 1) < end_row && (r + 1) < n_cols;
|
||||
const HVX_Vector v_off0 = Q6_Vw_vadd_VwVw(v_scat_base, Q6_V_vsplat_R(local_r * 4));
|
||||
const HVX_Vector v_off1 = Q6_Vw_vadd_VwVw(v_off0, v_scat_step);
|
||||
@@ -86,7 +86,7 @@ static inline void hmx_interleave_rows_to_tiles(__fp16 * restrict vtcm_dst,
|
||||
assert(c_byte_step % 128 == 0);
|
||||
|
||||
if (p1) {
|
||||
for (int i = 0; i < n_c_iters; ++i) {
|
||||
for (uint32_t i = 0; i < n_c_iters; ++i) {
|
||||
HVX_Vector v0 = hvx_vmem(p0); p0 += c_byte_step;
|
||||
HVX_Vector v1 = hvx_vmem(p1); p1 += c_byte_step;
|
||||
Q6_vscatter_RMVwV((size_t) tile_base, pair_region, v_off0, v0);
|
||||
@@ -95,7 +95,7 @@ static inline void hmx_interleave_rows_to_tiles(__fp16 * restrict vtcm_dst,
|
||||
}
|
||||
} else {
|
||||
const HVX_Vector vzero = Q6_V_vzero();
|
||||
for (int i = 0; i < n_c_iters; ++i) {
|
||||
for (uint32_t i = 0; i < n_c_iters; ++i) {
|
||||
HVX_Vector v0 = hvx_vmem(p0); p0 += c_byte_step;
|
||||
Q6_vscatter_RMVwV((size_t) tile_base, pair_region, v_off0, v0);
|
||||
Q6_vscatter_RMVwV((size_t) tile_base, pair_region, v_off1, vzero);
|
||||
@@ -105,14 +105,14 @@ static inline void hmx_interleave_rows_to_tiles(__fp16 * restrict vtcm_dst,
|
||||
}
|
||||
} else {
|
||||
// Fallback: scatter one K-tile per call (region 2047, masked).
|
||||
const int c_step = HMX_FP16_TILE_N_COLS;
|
||||
const size_t c_byte_step = (size_t) c_step * sizeof(__fp16);
|
||||
const size_t dst_step = (size_t) HMX_FP16_TILE_N_ELMS;
|
||||
const int n_c_iters = k / c_step;
|
||||
const uint32_t c_step = HMX_FP16_TILE_N_COLS;
|
||||
const size_t c_byte_step = (size_t) c_step * sizeof(__fp16);
|
||||
const size_t dst_step = (size_t) HMX_FP16_TILE_N_ELMS;
|
||||
const uint32_t n_c_iters = k / c_step;
|
||||
|
||||
for (int r = start_row; r < end_row; r += 2) {
|
||||
const int ct = r / HMX_FP16_TILE_N_ROWS;
|
||||
const int local_r = r % HMX_FP16_TILE_N_ROWS;
|
||||
for (uint32_t r = start_row; r < end_row; r += 2) {
|
||||
const uint32_t ct = r / HMX_FP16_TILE_N_ROWS;
|
||||
const uint32_t local_r = r % HMX_FP16_TILE_N_ROWS;
|
||||
const bool next_row_valid = (r + 1) < end_row && (r + 1) < n_cols;
|
||||
const HVX_Vector v_off0 = Q6_Vw_vadd_VwVw(v_scat_base, Q6_V_vsplat_R(local_r * 4));
|
||||
const HVX_Vector v_off1 = Q6_Vw_vadd_VwVw(v_off0, v_scat_step);
|
||||
@@ -122,7 +122,7 @@ static inline void hmx_interleave_rows_to_tiles(__fp16 * restrict vtcm_dst,
|
||||
const uint8_t * p1 = next_row_valid ? (const uint8_t *) (vtcm_src + (r + 1) * src_stride) : NULL;
|
||||
|
||||
if (p1) {
|
||||
for (int i = 0; i < n_c_iters; ++i) {
|
||||
for (uint32_t i = 0; i < n_c_iters; ++i) {
|
||||
HVX_Vector v0 = hvx_vmemu(p0); p0 += c_byte_step;
|
||||
HVX_Vector v1 = hvx_vmemu(p1); p1 += c_byte_step;
|
||||
Q6_vscatter_QRMVwV(q_mask64, (size_t) tile_base, single_region, v_off0, v0);
|
||||
@@ -131,7 +131,7 @@ static inline void hmx_interleave_rows_to_tiles(__fp16 * restrict vtcm_dst,
|
||||
}
|
||||
} else {
|
||||
const HVX_Vector vzero = Q6_V_vzero();
|
||||
for (int i = 0; i < n_c_iters; ++i) {
|
||||
for (uint32_t i = 0; i < n_c_iters; ++i) {
|
||||
HVX_Vector v0 = hvx_vmemu(p0); p0 += c_byte_step;
|
||||
Q6_vscatter_QRMVwV(q_mask64, (size_t) tile_base, single_region, v_off0, v0);
|
||||
Q6_vscatter_QRMVwV(q_mask64, (size_t) tile_base, single_region, v_off1, vzero);
|
||||
@@ -148,24 +148,24 @@ static inline void hmx_interleave_rows_to_tiles(__fp16 * restrict vtcm_dst,
|
||||
// Full range: start_row=0, end_row=n_rows.
|
||||
static inline void hmx_interleave_cols_to_tiles(__fp16 * restrict tiles_out,
|
||||
const __fp16 * restrict src,
|
||||
int n_rows,
|
||||
int head_dim,
|
||||
int src_stride,
|
||||
int n_row_tiles,
|
||||
int start_row,
|
||||
int end_row) {
|
||||
uint32_t n_rows,
|
||||
uint32_t head_dim,
|
||||
size_t src_stride,
|
||||
uint32_t n_row_tiles,
|
||||
uint32_t start_row,
|
||||
uint32_t end_row) {
|
||||
__builtin_assume(head_dim > 0);
|
||||
const size_t tile_stride_elms = (size_t) n_row_tiles * HMX_FP16_TILE_N_ELMS;
|
||||
|
||||
for (int r = start_row; r < end_row; r += 2) {
|
||||
for (uint32_t r = start_row; r < end_row; r += 2) {
|
||||
const bool next_row_valid = (r + 1) < end_row && (r + 1) < n_rows;
|
||||
|
||||
const HVX_Vector * pv_in0 = (const HVX_Vector *) (src + r * src_stride);
|
||||
const HVX_Vector * pv_in1 = next_row_valid ? (const HVX_Vector *) (src + (r + 1) * src_stride) : NULL;
|
||||
|
||||
// Row-pair invariants hoisted out of the c loop.
|
||||
const int r0 = r / HMX_FP16_TILE_N_ROWS;
|
||||
const int r1_half = (r % HMX_FP16_TILE_N_ROWS) / 2;
|
||||
const uint32_t r0 = r / HMX_FP16_TILE_N_ROWS;
|
||||
const uint32_t r1_half = (r % HMX_FP16_TILE_N_ROWS) / 2;
|
||||
|
||||
// tb0 starts at tile (c0=0, r0); tb1 at the adjacent dim-tile (c0=1, r0).
|
||||
// Each c step (+= 64) advances both by 2 dim-tiles worth of fp16.
|
||||
@@ -174,7 +174,7 @@ static inline void hmx_interleave_cols_to_tiles(__fp16 * restrict tiles_out,
|
||||
const size_t tb_step = 2 * tile_stride_elms;
|
||||
|
||||
if (pv_in1) {
|
||||
for (int c = 0; c < head_dim; c += 64) {
|
||||
for (uint32_t c = 0; c < head_dim; c += 64) {
|
||||
HVX_Vector v0 = *pv_in0++;
|
||||
HVX_Vector v1 = *pv_in1++;
|
||||
HVX_VectorPair vp = Q6_W_vshuff_VVR(v1, v0, -2);
|
||||
@@ -185,7 +185,7 @@ static inline void hmx_interleave_cols_to_tiles(__fp16 * restrict tiles_out,
|
||||
}
|
||||
} else {
|
||||
const HVX_Vector vzero = Q6_V_vzero();
|
||||
for (int c = 0; c < head_dim; c += 64) {
|
||||
for (uint32_t c = 0; c < head_dim; c += 64) {
|
||||
HVX_Vector v0 = *pv_in0++;
|
||||
HVX_VectorPair vp = Q6_W_vshuff_VVR(vzero, v0, -2);
|
||||
((HVX_Vector *) tb0)[r1_half] = Q6_V_lo_W(vp);
|
||||
|
||||
@@ -60,6 +60,7 @@ enum htp_op_code {
|
||||
HTP_OP_MUL_MAT_ID,
|
||||
HTP_OP_MUL_MAT_QKV,
|
||||
HTP_OP_MUL_MAT_FFN,
|
||||
HTP_OP_MUL_MAT_ADD,
|
||||
HTP_OP_RMS_NORM,
|
||||
HTP_OP_RMS_NORM_MUL,
|
||||
HTP_OP_UNARY_SILU,
|
||||
@@ -175,6 +176,11 @@ enum htp_trace_event_id {
|
||||
HTP_TRACE_EVT_HVX_W_DEQUANT = 23,
|
||||
HTP_TRACE_EVT_HVX_W_PREP = 24,
|
||||
HTP_TRACE_EVT_HVX_O_PROC = 25,
|
||||
HTP_TRACE_EVT_HVX_FA_QK = 26,
|
||||
HTP_TRACE_EVT_HVX_FA_SFM = 27,
|
||||
HTP_TRACE_EVT_HVX_FA_Q_PREP = 28,
|
||||
HTP_TRACE_EVT_HVX_FA_K_PREP = 29,
|
||||
HTP_TRACE_EVT_HVX_FA_V_PREP = 30,
|
||||
|
||||
HTP_TRACE_EVT_HMX_COMP = 40,
|
||||
};
|
||||
|
||||
@@ -134,16 +134,7 @@ static inline HVX_Vector hvx_vec_f32_to_f16_shuff(HVX_Vector v0, HVX_Vector v1)
|
||||
}
|
||||
|
||||
static inline HVX_Vector hvx_vec_f32_to_f16(HVX_Vector v0, HVX_Vector v1) {
|
||||
HVX_Vector v = Q6_Vh_vdeal_Vh(hvx_vec_f32_to_f16_shuff(v0, v1));
|
||||
|
||||
#if __HVX_ARCH__ < 79
|
||||
// replace NaNs with -INF, older arches produce NaNs for (-INF + 0.0)
|
||||
const HVX_Vector neg_inf = hvx_vec_splat_f16(-INFINITY);
|
||||
HVX_VectorPred nan = hvx_vec_is_nan_f16(v);
|
||||
v = Q6_V_vmux_QVV(nan, neg_inf, v);
|
||||
#endif
|
||||
|
||||
return v;
|
||||
return Q6_Vh_vdeal_Vh(hvx_vec_f32_to_f16_shuff(v0, v1));
|
||||
}
|
||||
|
||||
#if __HVX_ARCH__ >= 79
|
||||
@@ -170,8 +161,6 @@ static inline HVX_VectorPair hvx_vec_f16_to_f32(HVX_Vector v) {
|
||||
}
|
||||
#endif
|
||||
|
||||
|
||||
|
||||
static inline HVX_Vector hvx_vec_i16_from_hf_rnd_sat(HVX_Vector vin) {
|
||||
// This looks complicated.
|
||||
// Ideally should just be Q6_Vh_equals_Vhf(vin)
|
||||
|
||||
@@ -16,6 +16,7 @@
|
||||
#define EXP_COEFF_0 (0x3F000000) // 0.5 = 1/(2!)
|
||||
#define EXP_LOGN2 (0x3F317218) // ln(2) = 0.6931471805
|
||||
#define EXP_LOG2E (0x3FB8AA3B) // log2(e) = 1/ln(2) = 1.4426950408
|
||||
#define EXP_LOG2E_F 1.44269504f
|
||||
#define EXP_ONE (0x3f800000) // 1.0
|
||||
#define EXP_RANGE_R (0x42B17218) // ln(FLT_MAX) approx = 88.7228
|
||||
#define EXP_RANGE_L (0xC2B00000) // -88.0 (approx log(FLT_MIN))
|
||||
@@ -213,4 +214,42 @@ static inline void hvx_exp_f32(uint8_t * restrict dst, const uint8_t * restrict
|
||||
}
|
||||
}
|
||||
|
||||
// Element-wise 2^x on a vector of fp16 values.
//
// Steps, as implemented below:
//   1. clamp x from below at -24 (prevents integer underflow in the fp16 -> int16 conversion),
//   2. k = int16(x - 0.5), f = fp16(k), frac = x - f (kept in qf16),
//   3. evaluate a polynomial in frac with Horner's rule: six multiply/add steps, the last of
//      which adds the constant term 1.0,
//   4. scale by 2^k by adding k into the fp16 exponent field; lanes whose resulting exponent
//      would be negative are flushed to zero.
static inline HVX_Vector hvx_vec_exp2_f16(HVX_Vector x_v) {
    const HVX_Vector v_zero = Q6_V_vzero();
    const HVX_Vector v_half = Q6_Vh_vsplat_R(0x3800);  // fp16 0.5

    // Lower clamp applied before the integer conversion.
    const HVX_Vector v_floor = hvx_vec_splat_f16(-24.0f);
    x_v = Q6_Vhf_vmax_VhfVhf(v_floor, x_v);

    // Integer part k (from x - 0.5), the same value back in fp16, and the remainder in qf16.
    const HVX_Vector v_shifted = Q6_Vhf_equals_Vqf16(Q6_Vqf16_vsub_VhfVhf(x_v, v_half));
    const HVX_Vector v_k       = Q6_Vh_equals_Vhf(v_shifted);
    const HVX_Vector v_kf      = Q6_Vhf_equals_Vh(v_k);
    const HVX_Vector v_frac    = Q6_Vqf16_vsub_VhfVhf(x_v, v_kf);

    // Horner evaluation: poly = (((((c6*t + c5)*t + c4)*t + c3)*t + c2)*t + c1)*t + 1.0
    // NOTE(review): the leading constant 0x5082 is passed as a qf16 operand, the others as fp16.
    HVX_Vector poly = Q6_Vqf16_vmpy_Vqf16Vqf16(Q6_Vh_vsplat_R(0x5082), v_frac);  // c6 * t
    poly = Q6_Vqf16_vadd_Vqf16Vhf(poly, Q6_Vh_vsplat_R(0x157d));                 // + c5
    poly = Q6_Vqf16_vmpy_Vqf16Vqf16(poly, v_frac);
    poly = Q6_Vqf16_vadd_Vqf16Vhf(poly, Q6_Vh_vsplat_R(0x20ed));                 // + c4
    poly = Q6_Vqf16_vmpy_Vqf16Vqf16(poly, v_frac);
    poly = Q6_Vqf16_vadd_Vqf16Vhf(poly, Q6_Vh_vsplat_R(0x2b1b));                 // + c3
    poly = Q6_Vqf16_vmpy_Vqf16Vqf16(poly, v_frac);
    poly = Q6_Vqf16_vadd_Vqf16Vhf(poly, Q6_Vh_vsplat_R(0x33b0));                 // + c2
    poly = Q6_Vqf16_vmpy_Vqf16Vqf16(poly, v_frac);
    poly = Q6_Vqf16_vadd_Vqf16Vhf(poly, Q6_Vh_vsplat_R(0x398c));                 // + c1
    poly = Q6_Vqf16_vmpy_Vqf16Vqf16(poly, v_frac);
    poly = Q6_Vqf16_vadd_Vqf16Vhf(poly, Q6_Vh_vsplat_R(0x3c00));                 // + 1.0

    // Combine polynomial (mantissa) with integer part (exponent): result = poly * 2^k
    poly = Q6_Vhf_equals_Vqf16(poly);

    // Exponent field of poly (sign shifted out, mantissa shifted away) plus k.
    HVX_Vector v_exp = Q6_Vuh_vlsr_VuhR(Q6_Vh_vasl_VhR(poly, 1), 11);
    v_exp = Q6_Vh_vadd_VhVh(v_k, v_exp);
    const HVX_VectorPred q_flush = Q6_Q_vcmp_gt_VhVh(v_zero, v_exp);

    // Add k << 10 so that k lands in the exponent bits.
    poly = Q6_Vh_vaslacc_VhVhR(poly, v_k, 10);

    return Q6_V_vmux_QVV(q_flush, v_zero, poly);
}
|
||||
|
||||
#endif /* HVX_EXP_H */
|
||||
|
||||
@@ -0,0 +1,232 @@
|
||||
#ifndef HVX_FA_KERNELS_H
|
||||
#define HVX_FA_KERNELS_H
|
||||
|
||||
#include <assert.h>
|
||||
#include <math.h>
|
||||
#include "hvx-utils.h"
|
||||
|
||||
// Little inner kernels for HVX
|
||||
|
||||
#if __HVX_ARCH__ < 79
|
||||
#define HVX_OP_ADD_F32(a, b) Q6_Vsf_equals_Vqf32(Q6_Vqf32_vadd_VsfVsf(a, b))
|
||||
#define HVX_OP_SUB_F32(a, b) Q6_Vsf_equals_Vqf32(Q6_Vqf32_vsub_VsfVsf(a, b))
|
||||
#define HVX_OP_MUL_F32(a, b) Q6_Vsf_equals_Vqf32(Q6_Vqf32_vmpy_VsfVsf(a, b))
|
||||
#else
|
||||
#define HVX_OP_ADD_F32(a, b) Q6_Vsf_vadd_VsfVsf(a, b)
|
||||
#define HVX_OP_SUB_F32(a, b) Q6_Vsf_vsub_VsfVsf(a, b)
|
||||
#define HVX_OP_MUL_F32(a, b) Q6_Vsf_vmpy_VsfVsf(a, b)
|
||||
#endif
|
||||
|
||||
// This is a bit of a hack because the compiler is struggling to properly inline
|
||||
// the default hvx_vec_f32_to_f16 with output into the local array.
|
||||
static __attribute__((unused)) __attribute__((noinline)) void hvx_vec_f32_to_f16_a(void *ptr, HVX_Vector v0, HVX_Vector v1)
|
||||
{
|
||||
*(HVX_Vector *) ptr = hvx_vec_f32_to_f16(v0, v1);
|
||||
}
|
||||
|
||||
// Dot product of two F16 vectors, accumulating to float
|
||||
static inline void hvx_dot_f16_f16_aa(float * restrict r, const void * restrict x, const void * restrict y, unsigned int n, float s) {
|
||||
const HVX_Vector * restrict vx = (const HVX_Vector * restrict) x; // fp16
|
||||
const HVX_Vector * restrict vy = (const HVX_Vector * restrict) y; // fp16
|
||||
|
||||
uint32_t nvec = n / VLEN_FP16; // num full fp16 hvx vectors
|
||||
uint32_t nloe = n % VLEN_FP16; // leftover elements
|
||||
|
||||
HVX_VectorPair rsum_p = Q6_W_vcombine_VV(Q6_V_vsplat_R(0), Q6_V_vsplat_R(0));
|
||||
|
||||
uint32_t i = 0;
|
||||
|
||||
#pragma unroll(4)
|
||||
for (i = 0; i < nvec; i++) {
|
||||
rsum_p = hvx_vec_mpyacc_f32_f16(rsum_p, vx[i], vy[i]);
|
||||
}
|
||||
|
||||
if (nloe) {
|
||||
HVX_VectorPred bmask = Q6_Q_vsetq_R(nloe * 2);
|
||||
HVX_Vector y_hf = Q6_V_vand_QV(bmask, vy[i]);
|
||||
HVX_Vector x_hf = Q6_V_vand_QV(bmask, vx[i]);
|
||||
|
||||
rsum_p = hvx_vec_mpyacc_f32_f16(rsum_p, x_hf, y_hf);
|
||||
}
|
||||
|
||||
HVX_Vector rsum = HVX_OP_ADD_F32(Q6_V_lo_W(rsum_p), Q6_V_hi_W(rsum_p));
|
||||
rsum = HVX_OP_MUL_F32(hvx_vec_splat_f32(s), hvx_vec_reduce_sum_f32(rsum));
|
||||
hvx_vec_store_u(r, 4, rsum);
|
||||
}
|
||||
|
||||
static inline HVX_Vector hvx_dot_f16_f16_aa_rx4(const void * restrict y,
|
||||
const uint8_t * restrict x,
|
||||
const size_t stride_x,
|
||||
const size_t nvec,
|
||||
const size_t nloe) {
|
||||
const HVX_Vector * restrict vx0 = (const HVX_Vector * restrict) x; // fp16
|
||||
const HVX_Vector * restrict vx1 = (const HVX_Vector * restrict) (x + stride_x); // fp16
|
||||
const HVX_Vector * restrict vx2 = (const HVX_Vector * restrict) (x + stride_x * 2); // fp16
|
||||
const HVX_Vector * restrict vx3 = (const HVX_Vector * restrict) (x + stride_x * 3); // fp16
|
||||
const HVX_Vector * restrict vy = (const HVX_Vector * restrict) y; // fp16
|
||||
|
||||
HVX_VectorPair rsum0_p = Q6_W_vcombine_VV(Q6_V_vsplat_R(0), Q6_V_vsplat_R(0));
|
||||
HVX_VectorPair rsum1_p = Q6_W_vcombine_VV(Q6_V_vsplat_R(0), Q6_V_vsplat_R(0));
|
||||
HVX_VectorPair rsum2_p = Q6_W_vcombine_VV(Q6_V_vsplat_R(0), Q6_V_vsplat_R(0));
|
||||
HVX_VectorPair rsum3_p = Q6_W_vcombine_VV(Q6_V_vsplat_R(0), Q6_V_vsplat_R(0));
|
||||
|
||||
uint32_t i = 0;
|
||||
|
||||
for (i = 0; i < nvec; i++) {
|
||||
HVX_Vector y_hf = vy[i];
|
||||
HVX_Vector x0_hf = vx0[i];
|
||||
HVX_Vector x1_hf = vx1[i];
|
||||
HVX_Vector x2_hf = vx2[i];
|
||||
HVX_Vector x3_hf = vx3[i];
|
||||
|
||||
rsum0_p = hvx_vec_mpyacc_f32_f16(rsum0_p, x0_hf, y_hf);
|
||||
rsum1_p = hvx_vec_mpyacc_f32_f16(rsum1_p, x1_hf, y_hf);
|
||||
rsum2_p = hvx_vec_mpyacc_f32_f16(rsum2_p, x2_hf, y_hf);
|
||||
rsum3_p = hvx_vec_mpyacc_f32_f16(rsum3_p, x3_hf, y_hf);
|
||||
}
|
||||
|
||||
if (nloe) {
|
||||
// Load x (fp16) and zero-out unused elements
|
||||
HVX_VectorPred bmask = Q6_Q_vsetq_R(nloe * 2);
|
||||
HVX_Vector y_hf = Q6_V_vand_QV(bmask, vy[i]);
|
||||
HVX_Vector x0_hf = Q6_V_vand_QV(bmask, vx0[i]);
|
||||
HVX_Vector x1_hf = Q6_V_vand_QV(bmask, vx1[i]);
|
||||
HVX_Vector x2_hf = Q6_V_vand_QV(bmask, vx2[i]);
|
||||
HVX_Vector x3_hf = Q6_V_vand_QV(bmask, vx3[i]);
|
||||
|
||||
rsum0_p = hvx_vec_mpyacc_f32_f16(rsum0_p, x0_hf, y_hf);
|
||||
rsum1_p = hvx_vec_mpyacc_f32_f16(rsum1_p, x1_hf, y_hf);
|
||||
rsum2_p = hvx_vec_mpyacc_f32_f16(rsum2_p, x2_hf, y_hf);
|
||||
rsum3_p = hvx_vec_mpyacc_f32_f16(rsum3_p, x3_hf, y_hf);
|
||||
}
|
||||
|
||||
HVX_Vector rsum0 = HVX_OP_ADD_F32(Q6_V_lo_W(rsum0_p), Q6_V_hi_W(rsum0_p));
|
||||
HVX_Vector rsum1 = HVX_OP_ADD_F32(Q6_V_lo_W(rsum1_p), Q6_V_hi_W(rsum1_p));
|
||||
HVX_Vector rsum2 = HVX_OP_ADD_F32(Q6_V_lo_W(rsum2_p), Q6_V_hi_W(rsum2_p));
|
||||
HVX_Vector rsum3 = HVX_OP_ADD_F32(Q6_V_lo_W(rsum3_p), Q6_V_hi_W(rsum3_p));
|
||||
|
||||
HVX_Vector_x4 rsum0123 = { .v = { rsum0, rsum1, rsum2, rsum3 } };
|
||||
return hvx_vec_reduce_sum_f32x4(rsum0123);
|
||||
}
|
||||
|
||||
static inline HVX_Vector hvx_dot_f16_f16_aa_rx32(const void * restrict y,
|
||||
const uint8_t * restrict x,
|
||||
const size_t stride_x,
|
||||
const size_t n,
|
||||
float s) {
|
||||
|
||||
const size_t nvec = n / VLEN_FP16; // num full fp16 hvx vectors
|
||||
const size_t nloe = n % VLEN_FP16; // leftover elements
|
||||
|
||||
HVX_Vector sums = Q6_V_vzero();
|
||||
const size_t stride_x_4 = stride_x * 4;
|
||||
for (uint32_t j = 0; j < VLEN_FP32; j += 4) {
|
||||
HVX_Vector sums_x4 = hvx_dot_f16_f16_aa_rx4(y, x, stride_x, nvec, nloe);
|
||||
HVX_VectorPred pred = Q6_Q_vsetq_R(j * SIZEOF_FP32);
|
||||
sums = Q6_V_vmux_QVV(pred, sums, sums_x4);
|
||||
x += stride_x_4;
|
||||
}
|
||||
|
||||
return HVX_OP_MUL_F32(hvx_vec_splat_f32(s), sums);
|
||||
}
|
||||
|
||||
// MAD: y (F32) += x (F16) * s (F16)
|
||||
static inline void hvx_mad_f32_f16_aa(float * restrict y, const void * restrict x, const __fp16 * restrict s, uint32_t n) {
|
||||
const HVX_Vector * restrict vx0 = (const HVX_Vector *) x;
|
||||
|
||||
HVX_VectorPair * restrict vy_p = (HVX_VectorPair *) y;
|
||||
HVX_Vector * restrict vy = (HVX_Vector *) y;
|
||||
|
||||
uint32_t nvec = n / VLEN_FP16; // num full fp16 hvx vectors
|
||||
uint32_t nloe = n % VLEN_FP16; // leftover elements
|
||||
|
||||
HVX_Vector S0 = hvx_vec_splat_f16(*s);
|
||||
|
||||
uint32_t i = 0;
|
||||
|
||||
#pragma unroll(2)
|
||||
for (i = 0; i < nvec; ++i) {
|
||||
vy_p[i] = hvx_vec_mpyacc_f32_f16(vy_p[i], Q6_Vh_vshuff_Vh(vx0[i]), S0);
|
||||
}
|
||||
|
||||
if (nloe) {
|
||||
HVX_VectorPair xy_p = vy_p[i];
|
||||
xy_p = hvx_vec_mpyacc_f32_f16(xy_p, Q6_Vh_vshuff_Vh(vx0[i]), S0);
|
||||
|
||||
HVX_Vector xy = Q6_V_lo_W(xy_p);
|
||||
i = 2 * i; // index for vy
|
||||
|
||||
if (nloe >= VLEN_FP32) {
|
||||
vy[i] = xy;
|
||||
nloe -= VLEN_FP32; ++i; xy = Q6_V_hi_W(xy_p);
|
||||
}
|
||||
|
||||
if (nloe) {
|
||||
hvx_vec_store_a(&vy[i], nloe * 4, xy);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// MAD: y (F32) += x0 (F16) * s0 (F16) + x1 (F16) * s1 (F16)
|
||||
static inline void hvx_mad_f32_f16_aa_rx2(float * restrict y, const void * restrict x0, const void * restrict x1,
|
||||
const __fp16 * restrict s0, const __fp16 * restrict s1, uint32_t n) {
|
||||
const HVX_Vector * restrict vx0 = (const HVX_Vector *) x0;
|
||||
const HVX_Vector * restrict vx1 = (const HVX_Vector *) x1;
|
||||
|
||||
HVX_VectorPair * restrict vy_p = (HVX_VectorPair *) y;
|
||||
HVX_Vector * restrict vy = (HVX_Vector *) y;
|
||||
|
||||
uint32_t nvec = n / VLEN_FP16; // num full fp16 hvx vectors
|
||||
uint32_t nloe = n % VLEN_FP16; // leftover elements
|
||||
|
||||
HVX_Vector S0 = hvx_vec_splat_f16(*s0);
|
||||
HVX_Vector S1 = hvx_vec_splat_f16(*s1);
|
||||
|
||||
uint32_t i = 0;
|
||||
|
||||
#pragma unroll(2)
|
||||
for (i = 0; i < nvec; ++i) {
|
||||
vy_p[i] = hvx_vec_mpyacc_f32_f16(vy_p[i], Q6_Vh_vshuff_Vh(vx0[i]), S0);
|
||||
vy_p[i] = hvx_vec_mpyacc_f32_f16(vy_p[i], Q6_Vh_vshuff_Vh(vx1[i]), S1);
|
||||
}
|
||||
|
||||
if (nloe) {
|
||||
HVX_VectorPair xy_p = vy_p[i];
|
||||
xy_p = hvx_vec_mpyacc_f32_f16(xy_p, Q6_Vh_vshuff_Vh(vx0[i]), S0);
|
||||
xy_p = hvx_vec_mpyacc_f32_f16(xy_p, Q6_Vh_vshuff_Vh(vx1[i]), S1);
|
||||
|
||||
HVX_Vector xy = Q6_V_lo_W(xy_p);
|
||||
i = 2 * i; // index for vy
|
||||
|
||||
if (nloe >= VLEN_FP32) {
|
||||
vy[i] = xy;
|
||||
nloe -= VLEN_FP32; ++i; xy = Q6_V_hi_W(xy_p);
|
||||
}
|
||||
|
||||
if (nloe) {
|
||||
hvx_vec_store_a(&vy[i], nloe * 4, xy);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Element-wise fp32 scaling: dst[k] = src[k] * vs[k] for n floats, where vs is a
// caller-provided multiplier vector applied to every block of VLEN_FP32 elements.
// Both dst and src must be 128-byte aligned (asserted). For a trailing partial block
// only the first (n % VLEN_FP32) floats are stored.
static inline void hvx_scale_vec_f32_aa(uint8_t * restrict dst, const uint8_t * restrict src, const uint32_t n, HVX_Vector vs) {
    assert((size_t) dst % 128 == 0);
    assert((size_t) src % 128 == 0);

    const HVX_Vector * restrict in  = (const HVX_Vector * restrict) src;
    HVX_Vector * restrict       out = (HVX_Vector * restrict) dst;

    const uint32_t full_vecs  = n / VLEN_FP32;
    const uint32_t tail_elems = n % VLEN_FP32;

    uint32_t k = 0;

    #pragma unroll(4)
    for (; k < full_vecs; ++k) {
        out[k] = HVX_OP_MUL_F32(in[k], vs);
    }

    if (tail_elems != 0) {
        // Scale a whole vector, but write back only the valid leading bytes
        hvx_vec_store_a(&out[k], tail_elems * sizeof(float), HVX_OP_MUL_F32(in[k], vs));
    }
}
|
||||
|
||||
#endif /* HVX_FA_KERNELS_H */
|
||||
@@ -256,7 +256,7 @@ static inline void quantize_f16_f16_flat_kernel(
|
||||
|
||||
// Dot kernels that consume flat (non-tiled) activations
|
||||
|
||||
static void flat_vec_dot_q4_0_32x1(const uint32_t n, float * restrict s, const void * restrict vx, const void * restrict vy, uint32_t valid_rows) {
|
||||
static void flat_vec_dot_q4_0_32x1(const uint32_t n, float * restrict s, const void * restrict vx, const void * restrict vy, uint32_t valid_rows, const float * restrict sz) {
|
||||
const uint8_t * restrict tile_ptr = vx;
|
||||
const uint8_t * restrict y_q = vy;
|
||||
|
||||
@@ -312,10 +312,14 @@ static void flat_vec_dot_q4_0_32x1(const uint32_t n, float * restrict s, const v
|
||||
v_sum_float = hvx_vec_add_f32_f32(v_sum_float, v_sum_scaled);
|
||||
}
|
||||
|
||||
hvx_vec_store_u(s, valid_rows * sizeof(float), v_sum_float);
|
||||
if (sz) {
|
||||
hvx_vec_store_u(s, valid_rows * sizeof(float), hvx_vec_add_f32_f32(v_sum_float, hvx_vmemu(sz)));
|
||||
} else {
|
||||
hvx_vec_store_u(s, valid_rows * sizeof(float), v_sum_float);
|
||||
}
|
||||
}
|
||||
|
||||
static void flat_vec_dot_q4_0_32x2(const uint32_t n, float * restrict s0, float * restrict s1, const void * restrict vx, const void * restrict vy0, const void * restrict vy1, uint32_t valid_rows) {
|
||||
static void flat_vec_dot_q4_0_32x2(const uint32_t n, float * restrict s0, float * restrict s1, const void * restrict vx, const void * restrict vy0, const void * restrict vy1, uint32_t valid_rows, const float * restrict sz0, const float * restrict sz1) {
|
||||
const uint8_t * restrict tile_ptr = vx;
|
||||
const uint8_t * restrict y0_q = vy0;
|
||||
const uint8_t * restrict y1_q = vy1;
|
||||
@@ -397,11 +401,19 @@ static void flat_vec_dot_q4_0_32x2(const uint32_t n, float * restrict s0, float
|
||||
v_sum_float_c1 = hvx_vec_add_f32_f32(v_sum_float_c1, v_sum_scaled_c1);
|
||||
}
|
||||
|
||||
hvx_vec_store_u(s0, valid_rows * sizeof(float), v_sum_float_c0);
|
||||
hvx_vec_store_u(s1, valid_rows * sizeof(float), v_sum_float_c1);
|
||||
if (sz0) {
|
||||
hvx_vec_store_u(s0, valid_rows * sizeof(float), hvx_vec_add_f32_f32(v_sum_float_c0, hvx_vmemu(sz0)));
|
||||
} else {
|
||||
hvx_vec_store_u(s0, valid_rows * sizeof(float), v_sum_float_c0);
|
||||
}
|
||||
if (sz1) {
|
||||
hvx_vec_store_u(s1, valid_rows * sizeof(float), hvx_vec_add_f32_f32(v_sum_float_c1, hvx_vmemu(sz1)));
|
||||
} else {
|
||||
hvx_vec_store_u(s1, valid_rows * sizeof(float), v_sum_float_c1);
|
||||
}
|
||||
}
|
||||
|
||||
static void flat_vec_dot_q4_1_32x1(const uint32_t n, float * restrict s, const void * restrict vx, const void * restrict vy, uint32_t valid_rows) {
|
||||
static void flat_vec_dot_q4_1_32x1(const uint32_t n, float * restrict s, const void * restrict vx, const void * restrict vy, uint32_t valid_rows, const float * restrict sz) {
|
||||
const uint8_t * restrict tile_ptr = vx;
|
||||
const uint8_t * restrict y_q = vy;
|
||||
|
||||
@@ -464,10 +476,14 @@ static void flat_vec_dot_q4_1_32x1(const uint32_t n, float * restrict s, const v
|
||||
v_sum_float = hvx_vec_add_f32_f32(v_sum_float, v_sum_scaled);
|
||||
}
|
||||
|
||||
hvx_vec_store_u(s, valid_rows * sizeof(float), v_sum_float);
|
||||
if (sz) {
|
||||
hvx_vec_store_u(s, valid_rows * sizeof(float), hvx_vec_add_f32_f32(v_sum_float, hvx_vmemu(sz)));
|
||||
} else {
|
||||
hvx_vec_store_u(s, valid_rows * sizeof(float), v_sum_float);
|
||||
}
|
||||
}
|
||||
|
||||
static void flat_vec_dot_q4_1_32x2(const uint32_t n, float * restrict s0, float * restrict s1, const void * restrict vx, const void * restrict vy0, const void * restrict vy1, uint32_t valid_rows) {
|
||||
static void flat_vec_dot_q4_1_32x2(const uint32_t n, float * restrict s0, float * restrict s1, const void * restrict vx, const void * restrict vy0, const void * restrict vy1, uint32_t valid_rows, const float * restrict sz0, const float * restrict sz1) {
|
||||
const uint8_t * restrict tile_ptr = vx;
|
||||
const uint8_t * restrict y0_q = vy0;
|
||||
const uint8_t * restrict y1_q = vy1;
|
||||
@@ -561,11 +577,19 @@ static void flat_vec_dot_q4_1_32x2(const uint32_t n, float * restrict s0, float
|
||||
v_sum_float_c1 = hvx_vec_add_f32_f32(v_sum_float_c1, v_sum_scaled_c1);
|
||||
}
|
||||
|
||||
hvx_vec_store_u(s0, valid_rows * sizeof(float), v_sum_float_c0);
|
||||
hvx_vec_store_u(s1, valid_rows * sizeof(float), v_sum_float_c1);
|
||||
if (sz0) {
|
||||
hvx_vec_store_u(s0, valid_rows * sizeof(float), hvx_vec_add_f32_f32(v_sum_float_c0, hvx_vmemu(sz0)));
|
||||
} else {
|
||||
hvx_vec_store_u(s0, valid_rows * sizeof(float), v_sum_float_c0);
|
||||
}
|
||||
if (sz1) {
|
||||
hvx_vec_store_u(s1, valid_rows * sizeof(float), hvx_vec_add_f32_f32(v_sum_float_c1, hvx_vmemu(sz1)));
|
||||
} else {
|
||||
hvx_vec_store_u(s1, valid_rows * sizeof(float), v_sum_float_c1);
|
||||
}
|
||||
}
|
||||
|
||||
static void flat_vec_dot_q8_0_32x1(const uint32_t n, float * restrict s, const void * restrict vx, const void * restrict vy, uint32_t valid_rows) {
|
||||
static void flat_vec_dot_q8_0_32x1(const uint32_t n, float * restrict s, const void * restrict vx, const void * restrict vy, uint32_t valid_rows, const float * restrict sz) {
|
||||
const uint8_t * restrict tile_ptr = vx;
|
||||
const uint8_t * restrict y_q = vy;
|
||||
|
||||
@@ -620,10 +644,14 @@ static void flat_vec_dot_q8_0_32x1(const uint32_t n, float * restrict s, const v
|
||||
v_sum_float = hvx_vec_add_f32_f32(v_sum_float, v_sum_scaled);
|
||||
}
|
||||
|
||||
hvx_vec_store_u(s, valid_rows * sizeof(float), v_sum_float);
|
||||
if (sz) {
|
||||
hvx_vec_store_u(s, valid_rows * sizeof(float), hvx_vec_add_f32_f32(v_sum_float, hvx_vmemu(sz)));
|
||||
} else {
|
||||
hvx_vec_store_u(s, valid_rows * sizeof(float), v_sum_float);
|
||||
}
|
||||
}
|
||||
|
||||
static void flat_vec_dot_q8_0_32x2(const uint32_t n, float * restrict s0, float * restrict s1, const void * restrict vx, const void * restrict vy0, const void * restrict vy1, uint32_t valid_rows) {
|
||||
static void flat_vec_dot_q8_0_32x2(const uint32_t n, float * restrict s0, float * restrict s1, const void * restrict vx, const void * restrict vy0, const void * restrict vy1, uint32_t valid_rows, const float * restrict sz0, const float * restrict sz1) {
|
||||
const uint8_t * restrict tile_ptr = vx;
|
||||
const uint8_t * restrict y0_q = vy0;
|
||||
const uint8_t * restrict y1_q = vy1;
|
||||
@@ -704,11 +732,19 @@ static void flat_vec_dot_q8_0_32x2(const uint32_t n, float * restrict s0, float
|
||||
v_sum_float_c1 = hvx_vec_add_f32_f32(v_sum_float_c1, v_sum_scaled_c1);
|
||||
}
|
||||
|
||||
hvx_vec_store_u(s0, valid_rows * sizeof(float), v_sum_float_c0);
|
||||
hvx_vec_store_u(s1, valid_rows * sizeof(float), v_sum_float_c1);
|
||||
if (sz0) {
|
||||
hvx_vec_store_u(s0, valid_rows * sizeof(float), hvx_vec_add_f32_f32(v_sum_float_c0, hvx_vmemu(sz0)));
|
||||
} else {
|
||||
hvx_vec_store_u(s0, valid_rows * sizeof(float), v_sum_float_c0);
|
||||
}
|
||||
if (sz1) {
|
||||
hvx_vec_store_u(s1, valid_rows * sizeof(float), hvx_vec_add_f32_f32(v_sum_float_c1, hvx_vmemu(sz1)));
|
||||
} else {
|
||||
hvx_vec_store_u(s1, valid_rows * sizeof(float), v_sum_float_c1);
|
||||
}
|
||||
}
|
||||
|
||||
static void flat_vec_dot_iq4nl_32x1(const uint32_t n, float * restrict s, const void * restrict vx, const void * restrict vy, uint32_t valid_rows) {
|
||||
static void flat_vec_dot_iq4nl_32x1(const uint32_t n, float * restrict s, const void * restrict vx, const void * restrict vy, uint32_t valid_rows, const float * restrict sz) {
|
||||
const uint8_t * restrict tile_ptr = vx;
|
||||
const uint8_t * restrict y_q = vy;
|
||||
|
||||
@@ -765,10 +801,14 @@ static void flat_vec_dot_iq4nl_32x1(const uint32_t n, float * restrict s, const
|
||||
v_sum_float = hvx_vec_add_f32_f32(v_sum_float, v_sum_scaled);
|
||||
}
|
||||
|
||||
hvx_vec_store_u(s, valid_rows * sizeof(float), v_sum_float);
|
||||
if (sz) {
|
||||
hvx_vec_store_u(s, valid_rows * sizeof(float), hvx_vec_add_f32_f32(v_sum_float, hvx_vmemu(sz)));
|
||||
} else {
|
||||
hvx_vec_store_u(s, valid_rows * sizeof(float), v_sum_float);
|
||||
}
|
||||
}
|
||||
|
||||
static void flat_vec_dot_iq4nl_32x2(const uint32_t n, float * restrict s0, float * restrict s1, const void * restrict vx, const void * restrict vy0, const void * restrict vy1, uint32_t valid_rows) {
|
||||
static void flat_vec_dot_iq4nl_32x2(const uint32_t n, float * restrict s0, float * restrict s1, const void * restrict vx, const void * restrict vy0, const void * restrict vy1, uint32_t valid_rows, const float * restrict sz0, const float * restrict sz1) {
|
||||
const uint8_t * restrict tile_ptr = vx;
|
||||
const uint8_t * restrict y0_q = vy0;
|
||||
const uint8_t * restrict y1_q = vy1;
|
||||
@@ -851,11 +891,19 @@ static void flat_vec_dot_iq4nl_32x2(const uint32_t n, float * restrict s0, float
|
||||
v_sum_float_c1 = hvx_vec_add_f32_f32(v_sum_float_c1, v_sum_scaled_c1);
|
||||
}
|
||||
|
||||
hvx_vec_store_u(s0, valid_rows * sizeof(float), v_sum_float_c0);
|
||||
hvx_vec_store_u(s1, valid_rows * sizeof(float), v_sum_float_c1);
|
||||
if (sz0) {
|
||||
hvx_vec_store_u(s0, valid_rows * sizeof(float), hvx_vec_add_f32_f32(v_sum_float_c0, hvx_vmemu(sz0)));
|
||||
} else {
|
||||
hvx_vec_store_u(s0, valid_rows * sizeof(float), v_sum_float_c0);
|
||||
}
|
||||
if (sz1) {
|
||||
hvx_vec_store_u(s1, valid_rows * sizeof(float), hvx_vec_add_f32_f32(v_sum_float_c1, hvx_vmemu(sz1)));
|
||||
} else {
|
||||
hvx_vec_store_u(s1, valid_rows * sizeof(float), v_sum_float_c1);
|
||||
}
|
||||
}
|
||||
|
||||
static void flat_vec_dot_mxfp4_32x1(const uint32_t n, float * restrict s, const void * restrict vx, const void * restrict vy, uint32_t valid_rows) {
|
||||
static void flat_vec_dot_mxfp4_32x1(const uint32_t n, float * restrict s, const void * restrict vx, const void * restrict vy, uint32_t valid_rows, const float * restrict sz) {
|
||||
const uint8_t * restrict tile_ptr = vx;
|
||||
const uint8_t * restrict y_q = vy;
|
||||
|
||||
@@ -921,10 +969,14 @@ static void flat_vec_dot_mxfp4_32x1(const uint32_t n, float * restrict s, const
|
||||
|
||||
v_sum_float = hvx_vec_mul_f32_f32(v_sum_float, hvx_vec_splat_f32(0.5f));
|
||||
|
||||
hvx_vec_store_u(s, valid_rows * sizeof(float), v_sum_float);
|
||||
if (sz) {
|
||||
hvx_vec_store_u(s, valid_rows * sizeof(float), hvx_vec_add_f32_f32(v_sum_float, hvx_vmemu(sz)));
|
||||
} else {
|
||||
hvx_vec_store_u(s, valid_rows * sizeof(float), v_sum_float);
|
||||
}
|
||||
}
|
||||
|
||||
static void flat_vec_dot_mxfp4_32x2(const uint32_t n, float * restrict s0, float * restrict s1, const void * restrict vx, const void * restrict vy0, const void * restrict vy1, uint32_t valid_rows) {
|
||||
static void flat_vec_dot_mxfp4_32x2(const uint32_t n, float * restrict s0, float * restrict s1, const void * restrict vx, const void * restrict vy0, const void * restrict vy1, uint32_t valid_rows, const float * restrict sz0, const float * restrict sz1) {
|
||||
const uint8_t * restrict tile_ptr = vx;
|
||||
const uint8_t * restrict y0_q = vy0;
|
||||
const uint8_t * restrict y1_q = vy1;
|
||||
@@ -1019,6 +1071,441 @@ static void flat_vec_dot_mxfp4_32x2(const uint32_t n, float * restrict s0, float
|
||||
v_sum_float_c0 = hvx_vec_mul_f32_f32(v_sum_float_c0, hvx_vec_splat_f32(0.5f));
|
||||
v_sum_float_c1 = hvx_vec_mul_f32_f32(v_sum_float_c1, hvx_vec_splat_f32(0.5f));
|
||||
|
||||
hvx_vec_store_u(s0, valid_rows * sizeof(float), v_sum_float_c0);
|
||||
hvx_vec_store_u(s1, valid_rows * sizeof(float), v_sum_float_c1);
|
||||
if (sz0) {
|
||||
hvx_vec_store_u(s0, valid_rows * sizeof(float), hvx_vec_add_f32_f32(v_sum_float_c0, hvx_vmemu(sz0)));
|
||||
} else {
|
||||
hvx_vec_store_u(s0, valid_rows * sizeof(float), v_sum_float_c0);
|
||||
}
|
||||
if (sz1) {
|
||||
hvx_vec_store_u(s1, valid_rows * sizeof(float), hvx_vec_add_f32_f32(v_sum_float_c1, hvx_vmemu(sz1)));
|
||||
} else {
|
||||
hvx_vec_store_u(s1, valid_rows * sizeof(float), v_sum_float_c1);
|
||||
}
|
||||
}
|
||||
|
||||
// fp32 vector add / multiply helpers, selected by HVX architecture version.
// Both variants take two sf (fp32) vectors and yield an sf vector.
// Below v79: compute in qf32 and convert the result back to sf.
#if __HVX_ARCH__ < 79
|
||||
#define HVX_OP_ADD_F32(a, b) Q6_Vsf_equals_Vqf32(Q6_Vqf32_vadd_VsfVsf(a, b))
|
||||
#define HVX_OP_MUL_F32(a, b) Q6_Vsf_equals_Vqf32(Q6_Vqf32_vmpy_VsfVsf(a, b))
|
||||
// v79 and newer: use the sf add / multiply intrinsics directly.
#else
|
||||
#define HVX_OP_ADD_F32(a, b) Q6_Vsf_vadd_VsfVsf(a, b)
|
||||
#define HVX_OP_MUL_F32(a, b) Q6_Vsf_vmpy_VsfVsf(a, b)
|
||||
#endif
|
||||
|
||||
static inline void vec_dot_f32_f32_aa_1x1(const uint32_t n, float * restrict s, const void * restrict vx, const void * restrict vy) {
|
||||
const HVX_Vector * restrict x = (const HVX_Vector *) vx;
|
||||
const HVX_Vector * restrict y = (const HVX_Vector *) vy;
|
||||
|
||||
uint32_t nvec = n / VLEN_FP32; // num full fp32 hvx vectors
|
||||
uint32_t nloe = n % VLEN_FP32; // leftover elements
|
||||
|
||||
HVX_Vector rsum = Q6_V_vzero();
|
||||
|
||||
uint32_t i = 0;
|
||||
|
||||
#pragma unroll(4)
|
||||
for (i = 0; i < nvec; i++) {
|
||||
HVX_Vector prod = HVX_OP_MUL_F32(x[i], y[i]);
|
||||
rsum = HVX_OP_ADD_F32(rsum, prod);
|
||||
}
|
||||
|
||||
if (nloe) {
|
||||
HVX_VectorPred bmask = Q6_Q_vsetq_R(nloe * 4);
|
||||
HVX_Vector x_sf = Q6_V_vand_QV(bmask, x[i]);
|
||||
HVX_Vector y_sf = Q6_V_vand_QV(bmask, y[i]);
|
||||
HVX_Vector prod = HVX_OP_MUL_F32(x_sf, y_sf);
|
||||
rsum = HVX_OP_ADD_F32(rsum, prod);
|
||||
}
|
||||
|
||||
*s = hvx_vec_get_f32(hvx_vec_reduce_sum_f32(rsum));
|
||||
}
|
||||
|
||||
static inline void vec_dot_f32_f32_aa_2x1(const uint32_t n, float * restrict s0,
|
||||
const void * restrict vx0, const void * restrict vx1,
|
||||
const void * restrict vy0) {
|
||||
const HVX_Vector * restrict x0 = (const HVX_Vector *) vx0;
|
||||
const HVX_Vector * restrict x1 = (const HVX_Vector *) vx1;
|
||||
const HVX_Vector * restrict y = (const HVX_Vector *) vy0;
|
||||
|
||||
uint32_t nvec = n / VLEN_FP32;
|
||||
uint32_t nloe = n % VLEN_FP32;
|
||||
|
||||
HVX_Vector rsum0 = Q6_V_vzero();
|
||||
HVX_Vector rsum1 = Q6_V_vzero();
|
||||
|
||||
uint32_t i = 0;
|
||||
|
||||
#pragma unroll(2)
|
||||
for (i = 0; i < nvec; i++) {
|
||||
HVX_Vector y_sf = y[i];
|
||||
HVX_Vector prod0 = HVX_OP_MUL_F32(x0[i], y_sf);
|
||||
HVX_Vector prod1 = HVX_OP_MUL_F32(x1[i], y_sf);
|
||||
rsum0 = HVX_OP_ADD_F32(rsum0, prod0);
|
||||
rsum1 = HVX_OP_ADD_F32(rsum1, prod1);
|
||||
}
|
||||
|
||||
if (nloe) {
|
||||
HVX_VectorPred bmask = Q6_Q_vsetq_R(nloe * 4);
|
||||
HVX_Vector y_sf = Q6_V_vand_QV(bmask, y[i]);
|
||||
HVX_Vector x0_sf = Q6_V_vand_QV(bmask, x0[i]);
|
||||
HVX_Vector x1_sf = Q6_V_vand_QV(bmask, x1[i]);
|
||||
HVX_Vector prod0 = HVX_OP_MUL_F32(x0_sf, y_sf);
|
||||
HVX_Vector prod1 = HVX_OP_MUL_F32(x1_sf, y_sf);
|
||||
rsum0 = HVX_OP_ADD_F32(rsum0, prod0);
|
||||
rsum1 = HVX_OP_ADD_F32(rsum1, prod1);
|
||||
}
|
||||
|
||||
HVX_Vector rsum = hvx_vec_reduce_sum_f32x2(rsum0, rsum1);
|
||||
hvx_vec_store_u(s0, 8, rsum);
|
||||
}
|
||||
|
||||
static inline void vec_dot_f32_f32_aa_2x2(const uint32_t n, float * restrict s0, float * restrict s1,
|
||||
const void * restrict vx0, const void * restrict vx1,
|
||||
const void * restrict vy0, const void * restrict vy1) {
|
||||
const HVX_Vector * restrict x0 = (const HVX_Vector *) vx0;
|
||||
const HVX_Vector * restrict x1 = (const HVX_Vector *) vx1;
|
||||
const HVX_Vector * restrict y0 = (const HVX_Vector *) vy0;
|
||||
const HVX_Vector * restrict y1 = (const HVX_Vector *) vy1;
|
||||
|
||||
uint32_t nvec = n / VLEN_FP32;
|
||||
uint32_t nloe = n % VLEN_FP32;
|
||||
|
||||
HVX_Vector r0_c0_sum = Q6_V_vzero();
|
||||
HVX_Vector r0_c1_sum = Q6_V_vzero();
|
||||
HVX_Vector r1_c0_sum = Q6_V_vzero();
|
||||
HVX_Vector r1_c1_sum = Q6_V_vzero();
|
||||
|
||||
uint32_t i = 0;
|
||||
|
||||
#pragma unroll(2)
|
||||
for (i = 0; i < nvec; i++) {
|
||||
HVX_Vector r0_sf = x0[i];
|
||||
HVX_Vector r1_sf = x1[i];
|
||||
HVX_Vector c0_sf = y0[i];
|
||||
HVX_Vector c1_sf = y1[i];
|
||||
|
||||
r0_c0_sum = HVX_OP_ADD_F32(r0_c0_sum, HVX_OP_MUL_F32(r0_sf, c0_sf));
|
||||
r0_c1_sum = HVX_OP_ADD_F32(r0_c1_sum, HVX_OP_MUL_F32(r0_sf, c1_sf));
|
||||
r1_c0_sum = HVX_OP_ADD_F32(r1_c0_sum, HVX_OP_MUL_F32(r1_sf, c0_sf));
|
||||
r1_c1_sum = HVX_OP_ADD_F32(r1_c1_sum, HVX_OP_MUL_F32(r1_sf, c1_sf));
|
||||
}
|
||||
|
||||
if (nloe) {
|
||||
HVX_VectorPred bmask = Q6_Q_vsetq_R(nloe * 4);
|
||||
|
||||
HVX_Vector r0_sf = Q6_V_vand_QV(bmask, x0[i]);
|
||||
HVX_Vector r1_sf = Q6_V_vand_QV(bmask, x1[i]);
|
||||
HVX_Vector c0_sf = Q6_V_vand_QV(bmask, y0[i]);
|
||||
HVX_Vector c1_sf = Q6_V_vand_QV(bmask, y1[i]);
|
||||
|
||||
r0_c0_sum = HVX_OP_ADD_F32(r0_c0_sum, HVX_OP_MUL_F32(r0_sf, c0_sf));
|
||||
r0_c1_sum = HVX_OP_ADD_F32(r0_c1_sum, HVX_OP_MUL_F32(r0_sf, c1_sf));
|
||||
r1_c0_sum = HVX_OP_ADD_F32(r1_c0_sum, HVX_OP_MUL_F32(r1_sf, c0_sf));
|
||||
r1_c1_sum = HVX_OP_ADD_F32(r1_c1_sum, HVX_OP_MUL_F32(r1_sf, c1_sf));
|
||||
}
|
||||
|
||||
// Reduce and store results
|
||||
HVX_Vector r0_r1_c0_sum = hvx_vec_reduce_sum_f32x2(r0_c0_sum, r1_c0_sum);
|
||||
HVX_Vector r0_r1_c1_sum = hvx_vec_reduce_sum_f32x2(r0_c1_sum, r1_c1_sum);
|
||||
|
||||
hvx_vec_store_u(s0, 8, r0_r1_c0_sum);
|
||||
hvx_vec_store_u(s1, 8, r0_r1_c1_sum);
|
||||
}
|
||||
|
||||
static inline void vec_dot_f32_f32_uu_1x1(const uint32_t n, float * restrict s, const void * restrict x, const void * restrict y) {
|
||||
const HVX_UVector * restrict vx = (const HVX_UVector * restrict) x;
|
||||
const HVX_UVector * restrict vy = (const HVX_UVector * restrict) y;
|
||||
|
||||
uint32_t nvec = n / VLEN_FP32; // num full fp32 hvx vectors
|
||||
uint32_t nloe = n % VLEN_FP32; // leftover elements
|
||||
|
||||
HVX_Vector rsum = Q6_V_vzero();
|
||||
|
||||
uint32_t i = 0;
|
||||
|
||||
#pragma unroll(2)
|
||||
for (i = 0; i < nvec; i++) {
|
||||
HVX_Vector x_sf = vx[i];
|
||||
HVX_Vector y_sf = vy[i];
|
||||
|
||||
rsum = HVX_OP_ADD_F32(rsum, HVX_OP_MUL_F32(x_sf, y_sf));
|
||||
}
|
||||
|
||||
if (nloe) {
|
||||
HVX_Vector x_sf = vx[i];
|
||||
HVX_Vector y_sf = vy[i];
|
||||
|
||||
HVX_VectorPred bmask = Q6_Q_vsetq_R(nloe * 4);
|
||||
x_sf = Q6_V_vand_QV(bmask, x_sf);
|
||||
y_sf = Q6_V_vand_QV(bmask, y_sf);
|
||||
|
||||
rsum = HVX_OP_ADD_F32(rsum, HVX_OP_MUL_F32(x_sf, y_sf));
|
||||
}
|
||||
|
||||
rsum = hvx_vec_reduce_sum_f32(rsum);
|
||||
hvx_vec_store_u(&s[0], 4, rsum);
|
||||
}
|
||||
|
||||
#undef HVX_OP_ADD_F32
|
||||
#undef HVX_OP_MUL_F32
|
||||
|
||||
static inline void vec_dot_f16_f16_aa_1x1(const uint32_t n, float * restrict s, const void * restrict vx, const void * restrict vy) {
|
||||
const HVX_Vector * restrict x = (const HVX_Vector *) vx;
|
||||
const HVX_Vector * restrict y = (const HVX_Vector *) vy;
|
||||
|
||||
uint32_t nvec = n / VLEN_FP16; // num full fp16 hvx vectors
|
||||
uint32_t nloe = n % VLEN_FP16; // leftover elements
|
||||
|
||||
HVX_VectorPair rsum_p = Q6_W_vzero();
|
||||
|
||||
uint32_t i = 0;
|
||||
|
||||
#pragma unroll(4)
|
||||
for (i = 0; i < nvec; i++) {
|
||||
rsum_p = hvx_vec_mpyacc_f32_f16(rsum_p, x[i], y[i]);
|
||||
}
|
||||
|
||||
if (nloe) {
|
||||
HVX_VectorPred bmask = Q6_Q_vsetq_R(nloe * 2);
|
||||
HVX_Vector x_hf = Q6_V_vand_QV(bmask, x[i]);
|
||||
HVX_Vector y_hf = Q6_V_vand_QV(bmask, y[i]);
|
||||
rsum_p = hvx_vec_mpyacc_f32_f16(rsum_p, x_hf, y_hf);
|
||||
}
|
||||
|
||||
HVX_Vector rsum = Q6_Vsf_equals_Vqf32(Q6_Vqf32_vadd_VsfVsf(Q6_V_lo_W(rsum_p), Q6_V_hi_W(rsum_p)));
|
||||
hvx_vec_store_u(s, 4, hvx_vec_reduce_sum_f32(rsum));
|
||||
}
|
||||
|
||||
static inline void vec_dot_f16_f16_aa_2x1(const uint32_t n, float * restrict s0,
|
||||
const void * restrict vx0, const void * restrict vx1,
|
||||
const void * restrict vy0) {
|
||||
const HVX_Vector * restrict x0 = (const HVX_Vector *) vx0;
|
||||
const HVX_Vector * restrict x1 = (const HVX_Vector *) vx1;
|
||||
const HVX_Vector * restrict y = (const HVX_Vector *) vy0;
|
||||
|
||||
uint32_t nvec = n / VLEN_FP16;
|
||||
uint32_t nloe = n % VLEN_FP16;
|
||||
|
||||
HVX_VectorPair rsum0_p = Q6_W_vzero();
|
||||
HVX_VectorPair rsum1_p = Q6_W_vzero();
|
||||
|
||||
uint32_t i = 0;
|
||||
|
||||
#pragma unroll(2)
|
||||
for (i = 0; i < nvec; i++) {
|
||||
HVX_Vector y_hf = y[i];
|
||||
rsum0_p = hvx_vec_mpyacc_f32_f16(rsum0_p, x0[i], y_hf);
|
||||
rsum1_p = hvx_vec_mpyacc_f32_f16(rsum1_p, x1[i], y_hf);
|
||||
}
|
||||
|
||||
if (nloe) {
|
||||
HVX_VectorPred bmask = Q6_Q_vsetq_R(nloe * 2);
|
||||
HVX_Vector y_hf = Q6_V_vand_QV(bmask, y[i]);
|
||||
HVX_Vector x0_hf = Q6_V_vand_QV(bmask, x0[i]);
|
||||
HVX_Vector x1_hf = Q6_V_vand_QV(bmask, x1[i]);
|
||||
rsum0_p = hvx_vec_mpyacc_f32_f16(rsum0_p, x0_hf, y_hf);
|
||||
rsum1_p = hvx_vec_mpyacc_f32_f16(rsum1_p, x1_hf, y_hf);
|
||||
}
|
||||
|
||||
HVX_Vector rsum0 = Q6_Vsf_equals_Vqf32(Q6_Vqf32_vadd_VsfVsf(Q6_V_lo_W(rsum0_p), Q6_V_hi_W(rsum0_p)));
|
||||
HVX_Vector rsum1 = Q6_Vsf_equals_Vqf32(Q6_Vqf32_vadd_VsfVsf(Q6_V_lo_W(rsum1_p), Q6_V_hi_W(rsum1_p)));
|
||||
HVX_Vector rsum = hvx_vec_reduce_sum_f32x2(rsum0, rsum1);
|
||||
hvx_vec_store_u(s0, 8, rsum);
|
||||
}
|
||||
|
||||
static inline void vec_dot_f16_f16_aa_2x2(const uint32_t n, float * restrict s0, float * restrict s1,
|
||||
const void * restrict vx0, const void * restrict vx1,
|
||||
const void * restrict vy0, const void * restrict vy1) {
|
||||
const HVX_Vector * restrict x0 = (const HVX_Vector *) vx0;
|
||||
const HVX_Vector * restrict x1 = (const HVX_Vector *) vx1;
|
||||
const HVX_Vector * restrict y0 = (const HVX_Vector *) vy0;
|
||||
const HVX_Vector * restrict y1 = (const HVX_Vector *) vy1;
|
||||
|
||||
uint32_t nvec = n / VLEN_FP16;
|
||||
uint32_t nloe = n % VLEN_FP16;
|
||||
|
||||
// Row sums (sf) - 4 accumulators for 2x2 tile
|
||||
HVX_VectorPair r0_c0_sum_p = Q6_W_vzero();
|
||||
HVX_VectorPair r0_c1_sum_p = Q6_W_vzero();
|
||||
HVX_VectorPair r1_c0_sum_p = Q6_W_vzero();
|
||||
HVX_VectorPair r1_c1_sum_p = Q6_W_vzero();
|
||||
|
||||
uint32_t i = 0;
|
||||
|
||||
#pragma unroll(2)
|
||||
for (i = 0; i < nvec; i++) {
|
||||
HVX_Vector r0_hf = x0[i];
|
||||
HVX_Vector r1_hf = x1[i];
|
||||
HVX_Vector c0_hf = y0[i];
|
||||
HVX_Vector c1_hf = y1[i];
|
||||
|
||||
// Compute 4 dot products: r0xc0, r0xc1, r1xc0, r1xc1
|
||||
r0_c0_sum_p = hvx_vec_mpyacc_f32_f16(r0_c0_sum_p, r0_hf, c0_hf);
|
||||
r0_c1_sum_p = hvx_vec_mpyacc_f32_f16(r0_c1_sum_p, r0_hf, c1_hf);
|
||||
r1_c0_sum_p = hvx_vec_mpyacc_f32_f16(r1_c0_sum_p, r1_hf, c0_hf);
|
||||
r1_c1_sum_p = hvx_vec_mpyacc_f32_f16(r1_c1_sum_p, r1_hf, c1_hf);
|
||||
}
|
||||
|
||||
if (nloe) {
|
||||
HVX_VectorPred bmask = Q6_Q_vsetq_R(nloe * 2);
|
||||
|
||||
HVX_Vector r0_hf = Q6_V_vand_QV(bmask, x0[i]);
|
||||
HVX_Vector r1_hf = Q6_V_vand_QV(bmask, x1[i]);
|
||||
HVX_Vector c0_hf = Q6_V_vand_QV(bmask, y0[i]);
|
||||
HVX_Vector c1_hf = Q6_V_vand_QV(bmask, y1[i]);
|
||||
|
||||
r0_c0_sum_p = hvx_vec_mpyacc_f32_f16(r0_c0_sum_p, r0_hf, c0_hf);
|
||||
r0_c1_sum_p = hvx_vec_mpyacc_f32_f16(r0_c1_sum_p, r0_hf, c1_hf);
|
||||
r1_c0_sum_p = hvx_vec_mpyacc_f32_f16(r1_c0_sum_p, r1_hf, c0_hf);
|
||||
r1_c1_sum_p = hvx_vec_mpyacc_f32_f16(r1_c1_sum_p, r1_hf, c1_hf);
|
||||
}
|
||||
|
||||
HVX_Vector r0_c0_sum = Q6_Vsf_equals_Vqf32(Q6_Vqf32_vadd_VsfVsf(Q6_V_lo_W(r0_c0_sum_p), Q6_V_hi_W(r0_c0_sum_p)));
|
||||
HVX_Vector r0_c1_sum = Q6_Vsf_equals_Vqf32(Q6_Vqf32_vadd_VsfVsf(Q6_V_lo_W(r0_c1_sum_p), Q6_V_hi_W(r0_c1_sum_p)));
|
||||
HVX_Vector r1_c0_sum = Q6_Vsf_equals_Vqf32(Q6_Vqf32_vadd_VsfVsf(Q6_V_lo_W(r1_c0_sum_p), Q6_V_hi_W(r1_c0_sum_p)));
|
||||
HVX_Vector r1_c1_sum = Q6_Vsf_equals_Vqf32(Q6_Vqf32_vadd_VsfVsf(Q6_V_lo_W(r1_c1_sum_p), Q6_V_hi_W(r1_c1_sum_p)));
|
||||
|
||||
// Reduce and store results
|
||||
HVX_Vector r0_r1_c0_sum = hvx_vec_reduce_sum_f32x2(r0_c0_sum, r1_c0_sum);
|
||||
HVX_Vector r0_r1_c1_sum = hvx_vec_reduce_sum_f32x2(r0_c1_sum, r1_c1_sum);
|
||||
|
||||
hvx_vec_store_u(&s0[0], 8, r0_r1_c0_sum); // row0,col0 row1,col0
|
||||
hvx_vec_store_u(&s1[0], 8, r0_r1_c1_sum); // row0,col1 row1,col1
|
||||
}
|
||||
|
||||
static inline void vec_dot_f16_f16_uu_1x1(const uint32_t n, float * restrict s, const void * restrict vx, const void * restrict vy) {
|
||||
const HVX_UVector * restrict x = (const HVX_UVector *) vx;
|
||||
const HVX_UVector * restrict y = (const HVX_UVector *) vy;
|
||||
|
||||
uint32_t nvec = n / VLEN_FP16; // num full fp16 hvx vectors
|
||||
uint32_t nloe = n % VLEN_FP16; // leftover elements
|
||||
|
||||
HVX_Vector rsum = Q6_V_vzero();
|
||||
|
||||
uint32_t i = 0;
|
||||
|
||||
#pragma unroll(4)
|
||||
for (i = 0; i < nvec; i++) {
|
||||
HVX_VectorPair xy_qf = Q6_Wqf32_vmpy_VhfVhf(x[i], y[i]);
|
||||
rsum = Q6_Vqf32_vadd_Vqf32Vqf32(rsum, Q6_Vqf32_vadd_Vqf32Vqf32(Q6_V_lo_W(xy_qf), Q6_V_hi_W(xy_qf)));
|
||||
}
|
||||
|
||||
if (nloe) {
|
||||
HVX_VectorPred bmask = Q6_Q_vsetq_R(nloe * 2);
|
||||
HVX_Vector x_hf = Q6_V_vand_QV(bmask, x[i]);
|
||||
HVX_Vector y_hf = Q6_V_vand_QV(bmask, y[i]);
|
||||
|
||||
HVX_VectorPair xy_qf = Q6_Wqf32_vmpy_VhfVhf(x_hf, y_hf);
|
||||
rsum = Q6_Vqf32_vadd_Vqf32Vqf32(rsum, Q6_Vqf32_vadd_Vqf32Vqf32(Q6_V_lo_W(xy_qf), Q6_V_hi_W(xy_qf)));
|
||||
}
|
||||
|
||||
rsum = hvx_vec_reduce_sum_f32(Q6_Vsf_equals_Vqf32(rsum));
|
||||
hvx_vec_store_u(&s[0], 4, rsum);
|
||||
}
|
||||
|
||||
static inline void vec_dot_f16_f32_uu_1x1(const uint32_t n, float * restrict s, const void * restrict x, const void * restrict y) {
|
||||
const HVX_UVector * restrict vx = (const HVX_UVector * restrict) x;
|
||||
const HVX_UVector * restrict vy = (const HVX_UVector * restrict) y;
|
||||
|
||||
uint32_t nvec = n / VLEN_FP16; // num full fp16 hvx vectors
|
||||
uint32_t nloe = n % VLEN_FP16; // leftover elements
|
||||
|
||||
const HVX_Vector zero = Q6_V_vzero();
|
||||
|
||||
HVX_Vector rsum = Q6_V_vzero();
|
||||
|
||||
uint32_t i = 0;
|
||||
|
||||
#pragma unroll(2)
|
||||
for (i = 0; i < nvec; i++) {
|
||||
// Load y (fp32) and convert into fp16
|
||||
HVX_Vector y0_qf = Q6_Vqf32_vsub_VsfVsf(vy[i*2+0], zero); // 32 elements
|
||||
HVX_Vector y1_qf = Q6_Vqf32_vsub_VsfVsf(vy[i*2+1], zero); // 32 elements
|
||||
HVX_Vector y_hf = Q6_Vh_vdeal_Vh(Q6_Vhf_equals_Wqf32(Q6_W_vcombine_VV(y1_qf, y0_qf)));
|
||||
|
||||
// Load x (fp16)
|
||||
HVX_Vector x_hf = vx[i];
|
||||
|
||||
HVX_VectorPair xy_qf = Q6_Wqf32_vmpy_VhfVhf(x_hf, y_hf);
|
||||
|
||||
rsum = Q6_Vqf32_vadd_Vqf32Vqf32(rsum, Q6_Vqf32_vadd_Vqf32Vqf32(Q6_V_lo_W(xy_qf), Q6_V_hi_W(xy_qf)));
|
||||
}
|
||||
|
||||
if (nloe) {
|
||||
// Load y (fp32) and convert into fp16
|
||||
HVX_Vector y0_qf = Q6_Vqf32_vsub_VsfVsf(vy[i*2+0], zero); // 32 elements
|
||||
HVX_Vector y1_qf = Q6_Vqf32_vsub_VsfVsf(vy[i*2+1], zero); // 32 elements
|
||||
HVX_Vector y_hf = Q6_Vh_vdeal_Vh(Q6_Vhf_equals_Wqf32(Q6_W_vcombine_VV(y1_qf, y0_qf)));
|
||||
|
||||
// Load x (fp16)
|
||||
HVX_Vector x_hf = vx[i];
|
||||
|
||||
// Zero-out unused elements
|
||||
// Note that we need to clear both x and y because they may contain NANs
|
||||
HVX_VectorPred bmask = Q6_Q_vsetq_R(nloe * 2);
|
||||
x_hf = Q6_V_vand_QV(bmask, x_hf);
|
||||
y_hf = Q6_V_vand_QV(bmask, y_hf);
|
||||
|
||||
HVX_VectorPair xy_qf = Q6_Wqf32_vmpy_VhfVhf(x_hf, y_hf);
|
||||
|
||||
rsum = Q6_Vqf32_vadd_Vqf32Vqf32(rsum, Q6_Vqf32_vadd_Vqf32Vqf32(Q6_V_lo_W(xy_qf), Q6_V_hi_W(xy_qf)));
|
||||
}
|
||||
|
||||
// Convert into fp32 and reduce
|
||||
rsum = hvx_vec_reduce_sum_f32(Q6_Vsf_equals_Vqf32(rsum));
|
||||
hvx_vec_store_u(&s[0], 4, rsum);
|
||||
}
|
||||
|
||||
static inline void hvx_tensor_add_f32_grid(
|
||||
const struct htp_tensor * restrict dst,
|
||||
const struct htp_tensor * restrict src2,
|
||||
uint32_t start_row,
|
||||
uint32_t end_row,
|
||||
uint32_t start_col,
|
||||
uint32_t end_col,
|
||||
const struct fastdiv_values * div_ne11_12,
|
||||
const struct fastdiv_values * div_ne11
|
||||
) {
|
||||
if (start_row >= end_row || start_col >= end_col) return;
|
||||
const uint32_t nb1 = dst->nb[1]; // row stride in bytes
|
||||
|
||||
const uint32_t ne11 = dst->ne[1];
|
||||
const uint32_t ne12 = dst->ne[2];
|
||||
const uint32_t ne11_12 = ne11 * ne12;
|
||||
|
||||
const bool is_broadcast1 = (src2->ne[1] == 1);
|
||||
const bool is_broadcast2 = (src2->ne[2] == 1);
|
||||
const bool is_broadcast3 = (src2->ne[3] == 1);
|
||||
|
||||
for (uint32_t r = start_row; r < end_row; r++) {
|
||||
float * dst_row = (float *) ((uint8_t *) dst->data + r * nb1);
|
||||
|
||||
uint32_t i13 = fastdiv(r, div_ne11_12);
|
||||
uint32_t i12 = fastdiv(r - i13 * ne11_12, div_ne11);
|
||||
uint32_t i11 = r - i13 * ne11_12 - i12 * ne11;
|
||||
|
||||
uint32_t i23 = is_broadcast3 ? 0 : i13;
|
||||
uint32_t i22 = is_broadcast2 ? 0 : i12;
|
||||
uint32_t i21 = is_broadcast1 ? 0 : i11;
|
||||
|
||||
const float * src2_row = (const float *) ((const uint8_t *) src2->data +
|
||||
i21 * src2->nb[1] + i22 * src2->nb[2] + i23 * src2->nb[3]);
|
||||
|
||||
float * dst_ptr = &dst_row[start_col];
|
||||
const float * src2_ptr = &src2_row[start_col];
|
||||
int remaining = end_col - start_col;
|
||||
while (remaining >= 32) {
|
||||
HVX_Vector v_out = hvx_vmemu(dst_ptr);
|
||||
HVX_Vector v_z = hvx_vmemu(src2_ptr);
|
||||
hvx_vmemu(dst_ptr) = hvx_vec_add_f32_f32(v_out, v_z);
|
||||
dst_ptr += 32;
|
||||
src2_ptr += 32;
|
||||
remaining -= 32;
|
||||
}
|
||||
if (remaining > 0) {
|
||||
HVX_Vector v_out = hvx_vmemu(dst_ptr);
|
||||
HVX_Vector v_z = hvx_vmemu(src2_ptr);
|
||||
hvx_vec_store_u(dst_ptr, remaining * sizeof(float), hvx_vec_add_f32_f32(v_out, v_z));
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
@@ -378,7 +378,7 @@ static inline HVX_VectorPair accum_q8_0_32x2(
|
||||
return Q6_W_vcombine_VV(v_sum1, v_sum0);
|
||||
}
|
||||
|
||||
static void tiled_vec_dot_q4_0_32x1(const uint32_t n, float * restrict s, const void * restrict vx, const void * restrict vy, uint32_t valid_rows) {
|
||||
static void tiled_vec_dot_q4_0_32x1(const uint32_t n, float * restrict s, const void * restrict vx, const void * restrict vy, uint32_t valid_rows, const float * restrict sz) {
|
||||
const uint8_t * restrict tile_ptr = vx;
|
||||
const uint8_t * restrict y_q = vy;
|
||||
|
||||
@@ -401,10 +401,14 @@ static void tiled_vec_dot_q4_0_32x1(const uint32_t n, float * restrict s, const
|
||||
v_sum_float = hvx_vec_add_f32_f32(v_sum_float, v_sum_scaled);
|
||||
}
|
||||
|
||||
hvx_vec_store_u(s, valid_rows * sizeof(float), v_sum_float);
|
||||
if (sz) {
|
||||
hvx_vec_store_u(s, valid_rows * sizeof(float), hvx_vec_add_f32_f32(v_sum_float, hvx_vmemu(sz)));
|
||||
} else {
|
||||
hvx_vec_store_u(s, valid_rows * sizeof(float), v_sum_float);
|
||||
}
|
||||
}
|
||||
|
||||
static void tiled_vec_dot_q4_0_32x2(const uint32_t n, float * restrict s0, float * restrict s1, const void * restrict vx, const void * restrict vy0, const void * restrict vy1, uint32_t valid_rows) {
|
||||
static void tiled_vec_dot_q4_0_32x2(const uint32_t n, float * restrict s0, float * restrict s1, const void * restrict vx, const void * restrict vy0, const void * restrict vy1, uint32_t valid_rows, const float * restrict sz0, const float * restrict sz1) {
|
||||
const uint8_t * restrict tile_ptr = vx;
|
||||
const uint8_t * restrict y0_q = vy0;
|
||||
const uint8_t * restrict y1_q = vy1;
|
||||
@@ -484,11 +488,19 @@ static void tiled_vec_dot_q4_0_32x2(const uint32_t n, float * restrict s0, float
|
||||
v_sum_float_c1 = hvx_vec_add_f32_f32(v_sum_float_c1, v_sum_scaled_c1);
|
||||
}
|
||||
|
||||
hvx_vec_store_u(s0, valid_rows * sizeof(float), v_sum_float_c0);
|
||||
hvx_vec_store_u(s1, valid_rows * sizeof(float), v_sum_float_c1);
|
||||
if (sz0) {
|
||||
hvx_vec_store_u(s0, valid_rows * sizeof(float), hvx_vec_add_f32_f32(v_sum_float_c0, hvx_vmemu(sz0)));
|
||||
} else {
|
||||
hvx_vec_store_u(s0, valid_rows * sizeof(float), v_sum_float_c0);
|
||||
}
|
||||
if (sz1) {
|
||||
hvx_vec_store_u(s1, valid_rows * sizeof(float), hvx_vec_add_f32_f32(v_sum_float_c1, hvx_vmemu(sz1)));
|
||||
} else {
|
||||
hvx_vec_store_u(s1, valid_rows * sizeof(float), v_sum_float_c1);
|
||||
}
|
||||
}
|
||||
|
||||
static void tiled_vec_dot_q4_1_32x1(const uint32_t n, float * restrict s, const void * restrict vx, const void * restrict vy, uint32_t valid_rows) {
|
||||
static void tiled_vec_dot_q4_1_32x1(const uint32_t n, float * restrict s, const void * restrict vx, const void * restrict vy, uint32_t valid_rows, const float * restrict sz) {
|
||||
const uint8_t * restrict tile_ptr = vx;
|
||||
const uint8_t * restrict y_q = vy;
|
||||
|
||||
@@ -519,10 +531,14 @@ static void tiled_vec_dot_q4_1_32x1(const uint32_t n, float * restrict s, const
|
||||
v_sum_float = hvx_vec_add_f32_f32(v_sum_float, v_sum_scaled);
|
||||
}
|
||||
|
||||
hvx_vec_store_u(s, valid_rows * sizeof(float), v_sum_float);
|
||||
if (sz) {
|
||||
hvx_vec_store_u(s, valid_rows * sizeof(float), hvx_vec_add_f32_f32(v_sum_float, hvx_vmemu(sz)));
|
||||
} else {
|
||||
hvx_vec_store_u(s, valid_rows * sizeof(float), v_sum_float);
|
||||
}
|
||||
}
|
||||
|
||||
static void tiled_vec_dot_q4_1_32x2(const uint32_t n, float * restrict s0, float * restrict s1, const void * restrict vx, const void * restrict vy0, const void * restrict vy1, uint32_t valid_rows) {
|
||||
static void tiled_vec_dot_q4_1_32x2(const uint32_t n, float * restrict s0, float * restrict s1, const void * restrict vx, const void * restrict vy0, const void * restrict vy1, uint32_t valid_rows, const float * restrict sz0, const float * restrict sz1) {
|
||||
const uint8_t * restrict tile_ptr = vx;
|
||||
const uint8_t * restrict y0_q = vy0;
|
||||
const uint8_t * restrict y1_q = vy1;
|
||||
@@ -637,11 +653,19 @@ static void tiled_vec_dot_q4_1_32x2(const uint32_t n, float * restrict s0, float
|
||||
v_sum_float_c1 = hvx_vec_add_f32_f32(v_sum_float_c1, v_sum_scaled_c1);
|
||||
}
|
||||
|
||||
hvx_vec_store_u(s0, valid_rows * sizeof(float), v_sum_float_c0);
|
||||
hvx_vec_store_u(s1, valid_rows * sizeof(float), v_sum_float_c1);
|
||||
if (sz0) {
|
||||
hvx_vec_store_u(s0, valid_rows * sizeof(float), hvx_vec_add_f32_f32(v_sum_float_c0, hvx_vmemu(sz0)));
|
||||
} else {
|
||||
hvx_vec_store_u(s0, valid_rows * sizeof(float), v_sum_float_c0);
|
||||
}
|
||||
if (sz1) {
|
||||
hvx_vec_store_u(s1, valid_rows * sizeof(float), hvx_vec_add_f32_f32(v_sum_float_c1, hvx_vmemu(sz1)));
|
||||
} else {
|
||||
hvx_vec_store_u(s1, valid_rows * sizeof(float), v_sum_float_c1);
|
||||
}
|
||||
}
|
||||
|
||||
static void tiled_vec_dot_q8_0_32x1(const uint32_t n, float * restrict s, const void * restrict vx, const void * restrict vy, uint32_t valid_rows) {
|
||||
static void tiled_vec_dot_q8_0_32x1(const uint32_t n, float * restrict s, const void * restrict vx, const void * restrict vy, uint32_t valid_rows, const float * restrict sz) {
|
||||
const uint8_t * restrict tile_ptr = vx;
|
||||
const uint8_t * restrict y_q = vy;
|
||||
|
||||
@@ -663,10 +687,14 @@ static void tiled_vec_dot_q8_0_32x1(const uint32_t n, float * restrict s, const
|
||||
v_sum_float = hvx_vec_add_f32_f32(v_sum_float, v_sum_scaled);
|
||||
}
|
||||
|
||||
hvx_vec_store_u(s, valid_rows * sizeof(float), v_sum_float);
|
||||
if (sz) {
|
||||
hvx_vec_store_u(s, valid_rows * sizeof(float), hvx_vec_add_f32_f32(v_sum_float, hvx_vmemu(sz)));
|
||||
} else {
|
||||
hvx_vec_store_u(s, valid_rows * sizeof(float), v_sum_float);
|
||||
}
|
||||
}
|
||||
|
||||
static void tiled_vec_dot_q8_0_32x2(const uint32_t n, float * restrict s0, float * restrict s1, const void * restrict vx, const void * restrict vy0, const void * restrict vy1, uint32_t valid_rows) {
|
||||
static void tiled_vec_dot_q8_0_32x2(const uint32_t n, float * restrict s0, float * restrict s1, const void * restrict vx, const void * restrict vy0, const void * restrict vy1, uint32_t valid_rows, const float * restrict sz0, const float * restrict sz1) {
|
||||
const uint8_t * restrict tile_ptr = vx;
|
||||
const uint8_t * restrict y0_q = vy0;
|
||||
const uint8_t * restrict y1_q = vy1;
|
||||
@@ -745,11 +773,19 @@ static void tiled_vec_dot_q8_0_32x2(const uint32_t n, float * restrict s0, float
|
||||
v_sum_float_c1 = hvx_vec_add_f32_f32(v_sum_float_c1, v_sum_scaled_c1);
|
||||
}
|
||||
|
||||
hvx_vec_store_u(s0, valid_rows * sizeof(float), v_sum_float_c0);
|
||||
hvx_vec_store_u(s1, valid_rows * sizeof(float), v_sum_float_c1);
|
||||
if (sz0) {
|
||||
hvx_vec_store_u(s0, valid_rows * sizeof(float), hvx_vec_add_f32_f32(v_sum_float_c0, hvx_vmemu(sz0)));
|
||||
} else {
|
||||
hvx_vec_store_u(s0, valid_rows * sizeof(float), v_sum_float_c0);
|
||||
}
|
||||
if (sz1) {
|
||||
hvx_vec_store_u(s1, valid_rows * sizeof(float), hvx_vec_add_f32_f32(v_sum_float_c1, hvx_vmemu(sz1)));
|
||||
} else {
|
||||
hvx_vec_store_u(s1, valid_rows * sizeof(float), v_sum_float_c1);
|
||||
}
|
||||
}
|
||||
|
||||
static void tiled_vec_dot_iq4nl_32x1(const uint32_t n, float * restrict s, const void * restrict vx, const void * restrict vy, uint32_t valid_rows) {
|
||||
static void tiled_vec_dot_iq4nl_32x1(const uint32_t n, float * restrict s, const void * restrict vx, const void * restrict vy, uint32_t valid_rows, const float * restrict sz) {
|
||||
const uint8_t * restrict tile_ptr = vx;
|
||||
const uint8_t * restrict y_q = vy;
|
||||
|
||||
@@ -773,10 +809,14 @@ static void tiled_vec_dot_iq4nl_32x1(const uint32_t n, float * restrict s, const
|
||||
v_sum_float = hvx_vec_add_f32_f32(v_sum_float, v_sum_scaled);
|
||||
}
|
||||
|
||||
hvx_vec_store_u(s, valid_rows * sizeof(float), v_sum_float);
|
||||
if (sz) {
|
||||
hvx_vec_store_u(s, valid_rows * sizeof(float), hvx_vec_add_f32_f32(v_sum_float, hvx_vmemu(sz)));
|
||||
} else {
|
||||
hvx_vec_store_u(s, valid_rows * sizeof(float), v_sum_float);
|
||||
}
|
||||
}
|
||||
|
||||
static void tiled_vec_dot_iq4nl_32x2(const uint32_t n, float * restrict s0, float * restrict s1, const void * restrict vx, const void * restrict vy0, const void * restrict vy1, uint32_t valid_rows) {
|
||||
static void tiled_vec_dot_iq4nl_32x2(const uint32_t n, float * restrict s0, float * restrict s1, const void * restrict vx, const void * restrict vy0, const void * restrict vy1, uint32_t valid_rows, const float * restrict sz0, const float * restrict sz1) {
|
||||
const uint8_t * restrict tile_ptr = vx;
|
||||
const uint8_t * restrict y0_q = vy0;
|
||||
const uint8_t * restrict y1_q = vy1;
|
||||
@@ -857,11 +897,19 @@ static void tiled_vec_dot_iq4nl_32x2(const uint32_t n, float * restrict s0, floa
|
||||
v_sum_float_c1 = hvx_vec_add_f32_f32(v_sum_float_c1, v_sum_scaled_c1);
|
||||
}
|
||||
|
||||
hvx_vec_store_u(s0, valid_rows * sizeof(float), v_sum_float_c0);
|
||||
hvx_vec_store_u(s1, valid_rows * sizeof(float), v_sum_float_c1);
|
||||
if (sz0) {
|
||||
hvx_vec_store_u(s0, valid_rows * sizeof(float), hvx_vec_add_f32_f32(v_sum_float_c0, hvx_vmemu(sz0)));
|
||||
} else {
|
||||
hvx_vec_store_u(s0, valid_rows * sizeof(float), v_sum_float_c0);
|
||||
}
|
||||
if (sz1) {
|
||||
hvx_vec_store_u(s1, valid_rows * sizeof(float), hvx_vec_add_f32_f32(v_sum_float_c1, hvx_vmemu(sz1)));
|
||||
} else {
|
||||
hvx_vec_store_u(s1, valid_rows * sizeof(float), v_sum_float_c1);
|
||||
}
|
||||
}
|
||||
|
||||
static void tiled_vec_dot_mxfp4_32x1(const uint32_t n, float * restrict s, const void * restrict vx, const void * restrict vy, uint32_t valid_rows) {
|
||||
static void tiled_vec_dot_mxfp4_32x1(const uint32_t n, float * restrict s, const void * restrict vx, const void * restrict vy, uint32_t valid_rows, const float * restrict sz) {
|
||||
const uint8_t * restrict tile_ptr = vx;
|
||||
const uint8_t * restrict y_q = vy;
|
||||
|
||||
@@ -896,10 +944,14 @@ static void tiled_vec_dot_mxfp4_32x1(const uint32_t n, float * restrict s, const
|
||||
|
||||
v_sum_float = hvx_vec_mul_f32_f32(v_sum_float, hvx_vec_splat_f32(0.5f));
|
||||
|
||||
hvx_vec_store_u(s, valid_rows * sizeof(float), v_sum_float);
|
||||
if (sz) {
|
||||
hvx_vec_store_u(s, valid_rows * sizeof(float), hvx_vec_add_f32_f32(v_sum_float, hvx_vmemu(sz)));
|
||||
} else {
|
||||
hvx_vec_store_u(s, valid_rows * sizeof(float), v_sum_float);
|
||||
}
|
||||
}
|
||||
|
||||
static void tiled_vec_dot_mxfp4_32x2(const uint32_t n, float * restrict s0, float * restrict s1, const void * restrict vx, const void * restrict vy0, const void * restrict vy1, uint32_t valid_rows) {
|
||||
static void tiled_vec_dot_mxfp4_32x2(const uint32_t n, float * restrict s0, float * restrict s1, const void * restrict vx, const void * restrict vy0, const void * restrict vy1, uint32_t valid_rows, const float * restrict sz0, const float * restrict sz1) {
|
||||
const uint8_t * restrict tile_ptr = vx;
|
||||
const uint8_t * restrict y0_q = vy0;
|
||||
const uint8_t * restrict y1_q = vy1;
|
||||
@@ -1013,8 +1065,16 @@ static void tiled_vec_dot_mxfp4_32x2(const uint32_t n, float * restrict s0, floa
|
||||
v_sum_float_c0 = hvx_vec_mul_f32_f32(v_sum_float_c0, hvx_vec_splat_f32(0.5f));
|
||||
v_sum_float_c1 = hvx_vec_mul_f32_f32(v_sum_float_c1, hvx_vec_splat_f32(0.5f));
|
||||
|
||||
hvx_vec_store_u(s0, valid_rows * sizeof(float), v_sum_float_c0);
|
||||
hvx_vec_store_u(s1, valid_rows * sizeof(float), v_sum_float_c1);
|
||||
if (sz0) {
|
||||
hvx_vec_store_u(s0, valid_rows * sizeof(float), hvx_vec_add_f32_f32(v_sum_float_c0, hvx_vmemu(sz0)));
|
||||
} else {
|
||||
hvx_vec_store_u(s0, valid_rows * sizeof(float), v_sum_float_c0);
|
||||
}
|
||||
if (sz1) {
|
||||
hvx_vec_store_u(s1, valid_rows * sizeof(float), hvx_vec_add_f32_f32(v_sum_float_c1, hvx_vmemu(sz1)));
|
||||
} else {
|
||||
hvx_vec_store_u(s1, valid_rows * sizeof(float), v_sum_float_c1);
|
||||
}
|
||||
}
|
||||
|
||||
static inline void quantize_f32_q8_0_tiled_kernel(
|
||||
|
||||
@@ -3,6 +3,7 @@
|
||||
|
||||
#include "hvx-base.h"
|
||||
#include "hvx-inverse.h"
|
||||
#include "hvx-exp.h"
|
||||
|
||||
#define FAST_SIGMOID_LOG2F (0x3fb8aa3b) // 1.442695022
|
||||
#define FAST_SIGMOID_C1 (0x3d009076) // 0.03138777
|
||||
@@ -139,4 +140,42 @@ static inline void hvx_tanh_f32_aa(uint8_t * restrict dst, const uint8_t * restr
|
||||
hvx_tanh_loop_body(HVX_Vector, HVX_Vector, hvx_vec_store_a);
|
||||
}
|
||||
|
||||
// Element-wise fp16 sigmoid. Evaluates 1 / (1 + 2^(-|x| * log2(e))) on the
// magnitude of each lane, then returns one minus that value for the lanes
// detected as negative.
static inline HVX_Vector hvx_vec_fast_sigmoid_f16(HVX_Vector x_v) {
    const HVX_Vector one_hf       = hvx_vec_splat_f16(1.0f);
    const HVX_Vector neg_log2e_hf = hvx_vec_splat_f16(-EXP_LOG2E_F);
    const HVX_Vector exp_floor_hf = hvx_vec_splat_f16(-24.0f);
    const HVX_Vector sign_clear   = Q6_Vh_vsplat_R(0x7FFF);

    // |x|: clear the top bit of every 16-bit lane
    const HVX_Vector x_mag = Q6_V_vand_VV(x_v, sign_clear);

    // exponent = -|x| * log2(e) <= 0, bounded below at -24 so exp2 does not underflow
    const HVX_Vector expo = Q6_Vhf_vmax_VhfVhf(exp_floor_hf, hvx_vec_mul_f16_f16(x_mag, neg_log2e_hf));

    // sigmoid(|x|) = 1 / (1 + 2^exponent)
    const HVX_Vector sig_pos = hvx_vec_inverse_f16(hvx_vec_add_f16_f16(one_hf, hvx_vec_exp2_f16(expo)));

    // Value used for negative inputs: 1 - sigmoid(|x|)
    const HVX_Vector sig_mirror = Q6_Vhf_equals_Vqf16(Q6_Vqf16_vsub_VhfVhf(one_hf, sig_pos));

    // Negative lanes are found with an integer compare: |x| > x on the raw 16-bit patterns
    const HVX_VectorPred neg_lanes = Q6_Q_vcmp_gt_VhVh(x_mag, x_v);

    return Q6_V_vmux_QVV(neg_lanes, sig_mirror, sig_pos);
}
|
||||
|
||||
// Element-wise fp16 tanh built on the sigmoid above, using the identity
// tanh(x) = 2 * sigmoid(2x) - 1.
static inline HVX_Vector hvx_vec_tanh_f16(HVX_Vector x) {
    const HVX_Vector two_hf       = hvx_vec_splat_f16(2.0f);
    const HVX_Vector minus_one_hf = hvx_vec_splat_f16(-1.0f);

    // sigmoid(2x)
    const HVX_Vector sig = hvx_vec_fast_sigmoid_f16(hvx_vec_mul_f16_f16(x, two_hf));

    // 2 * sigmoid(2x) + (-1)
    const HVX_Vector doubled = hvx_vec_mul_f16_f16(sig, two_hf);
    return hvx_vec_add_f16_f16(doubled, minus_one_hf);
}
|
||||
|
||||
#endif /* HVX_SIGMOID_H */
|
||||
|
||||
@@ -575,6 +575,7 @@ static inline void profile_stop(uint32_t mode, struct profile_data * d) {
|
||||
static int execute_op(struct htp_ops_context * octx) {
|
||||
switch (octx->op) {
|
||||
case HTP_OP_MUL_MAT:
|
||||
case HTP_OP_MUL_MAT_ADD:
|
||||
return op_matmul(octx);
|
||||
|
||||
case HTP_OP_MUL_MAT_ID:
|
||||
|
||||
File diff suppressed because it is too large
Load Diff
@@ -392,56 +392,49 @@ static inline size_t htp_mm_hvx_get_vtcm_sizes(
|
||||
case HTP_MM_KERNEL_HVX_QUANT_ROW: {
|
||||
size_t q_src1_row_size = (wtype == HTP_TYPE_Q4_1) ? htp_mm_q8_1_tiled_row_size(ne10) : htp_mm_q8_0_tiled_row_size(ne10);
|
||||
|
||||
vtcm_dst_size = dst_nrows > 0 ? htp_mm_round_up(dst_row_size, 128) : 0;
|
||||
vtcm_src0_size = htp_mm_round_up(n_prefetch * src0_row_size_padded, 256);
|
||||
vtcm_src1_size = htp_mm_round_up(q_src1_row_size * src1_nrows, 256);
|
||||
|
||||
// src0 spad is also used in dynamic quantizer to store padded src1 rows
|
||||
size_t src1_row_size_padded = htp_mm_round_up(q_src1_row_size, QK_Q8_0_TILED * sizeof(float));
|
||||
if (vtcm_src0_size < src1_row_size_padded) {
|
||||
vtcm_src0_size = src1_row_size_padded;
|
||||
}
|
||||
|
||||
vtcm_src0_size = vtcm_src0_size * n_threads;
|
||||
vtcm_dst_size = vtcm_dst_size * n_threads;
|
||||
|
||||
if (is_repack) {
|
||||
uint32_t aligned_tile_size = htp_mm_get_weight_aligned_tile_size(wtype);
|
||||
uint32_t n_k_tiles = ne10 / 32;
|
||||
uint32_t tile_row_size = n_k_tiles * aligned_tile_size;
|
||||
size_t repacked_vtcm_size = htp_mm_round_up(n_prefetch * tile_row_size, 256);
|
||||
if (repacked_vtcm_size < src1_row_size_padded) {
|
||||
repacked_vtcm_size = src1_row_size_padded;
|
||||
}
|
||||
vtcm_src0_size = repacked_vtcm_size * n_threads;
|
||||
}
|
||||
|
||||
size_t quant_scratch_size_per_thread = htp_mm_round_up(ne10 * sizeof(float), QK_Q8_0_TILED * sizeof(float));
|
||||
size_t dst_size_per_thread = dst_nrows > 0 ? htp_mm_round_up(dst_row_size, 128) : 0;
|
||||
if (dst_size_per_thread < quant_scratch_size_per_thread) {
|
||||
dst_size_per_thread = quant_scratch_size_per_thread;
|
||||
}
|
||||
vtcm_dst_size = dst_size_per_thread * n_threads;
|
||||
break;
|
||||
}
|
||||
case HTP_MM_KERNEL_HVX_QUANT_ROW_FLAT: {
|
||||
size_t q_src1_row_size = (wtype == HTP_TYPE_Q4_1) ? htp_mm_q8_1_flat_row_size(ne10) : htp_mm_q8_0_flat_row_size(ne10);
|
||||
|
||||
vtcm_dst_size = dst_nrows > 0 ? htp_mm_round_up(dst_row_size, 128) : 0;
|
||||
vtcm_src0_size = htp_mm_round_up(n_prefetch * src0_row_size_padded, 256);
|
||||
vtcm_src1_size = htp_mm_round_up(q_src1_row_size * src1_nrows, 256);
|
||||
|
||||
size_t src1_row_size_padded = htp_mm_round_up(q_src1_row_size, 256);
|
||||
if (vtcm_src0_size < src1_row_size_padded) {
|
||||
vtcm_src0_size = src1_row_size_padded;
|
||||
}
|
||||
|
||||
vtcm_src0_size = vtcm_src0_size * n_threads;
|
||||
vtcm_dst_size = vtcm_dst_size * n_threads;
|
||||
|
||||
if (is_repack) {
|
||||
uint32_t aligned_tile_size = htp_mm_get_weight_aligned_tile_size(wtype);
|
||||
uint32_t n_k_tiles = ne10 / 32;
|
||||
uint32_t tile_row_size = n_k_tiles * aligned_tile_size;
|
||||
size_t repacked_vtcm_size = htp_mm_round_up(n_prefetch * tile_row_size, 256);
|
||||
if (repacked_vtcm_size < src1_row_size_padded) {
|
||||
repacked_vtcm_size = src1_row_size_padded;
|
||||
}
|
||||
vtcm_src0_size = repacked_vtcm_size * n_threads;
|
||||
}
|
||||
|
||||
size_t quant_scratch_size_per_thread = htp_mm_round_up(ne10 * sizeof(float), QK_Q8_0_TILED * sizeof(float));
|
||||
size_t dst_size_per_thread = dst_nrows > 0 ? htp_mm_round_up(dst_row_size, 128) : 0;
|
||||
if (dst_size_per_thread < quant_scratch_size_per_thread) {
|
||||
dst_size_per_thread = quant_scratch_size_per_thread;
|
||||
}
|
||||
vtcm_dst_size = dst_size_per_thread * n_threads;
|
||||
break;
|
||||
}
|
||||
default:
|
||||
@@ -463,7 +456,8 @@ static inline size_t htp_mm_hvx_id_get_vtcm_sizes(
|
||||
size_t src0_row_size, // nb01
|
||||
uint32_t n_prefetch,
|
||||
size_t * vtcm_src0_size_out,
|
||||
size_t * vtcm_src1_size_out
|
||||
size_t * vtcm_src1_size_out,
|
||||
size_t * vtcm_dst_size_out
|
||||
) {
|
||||
const bool is_repack = (wtype == HTP_TYPE_Q4_0 || wtype == HTP_TYPE_Q4_1 ||
|
||||
wtype == HTP_TYPE_Q8_0 || wtype == HTP_TYPE_IQ4_NL ||
|
||||
@@ -476,29 +470,22 @@ static inline size_t htp_mm_hvx_id_get_vtcm_sizes(
|
||||
size_t src0_sz_per_thread = htp_mm_round_up(n_prefetch * src0_row_size_padded, 256);
|
||||
size_t src1_sz = htp_mm_round_up(src1_row_size * src1_nrows, 256);
|
||||
|
||||
// src0 spad also holds temporary transposed src1 columns during dynamic quantization.
|
||||
const size_t src1_row_size_padded = htp_mm_round_up(src1_row_size, QK_Q8_0_TILED * sizeof(float));
|
||||
if (src0_sz_per_thread < src1_row_size_padded) {
|
||||
src0_sz_per_thread = src1_row_size_padded;
|
||||
}
|
||||
|
||||
if (is_repack) {
|
||||
const uint32_t aligned_tile_size = htp_mm_get_weight_aligned_tile_size(wtype);
|
||||
const uint32_t n_k_tiles = ne10 / 32;
|
||||
const uint32_t tile_row_size = n_k_tiles * aligned_tile_size;
|
||||
size_t repacked_vtcm_size = htp_mm_round_up(n_prefetch * tile_row_size, 256);
|
||||
if (repacked_vtcm_size < src1_row_size_padded) {
|
||||
repacked_vtcm_size = src1_row_size_padded;
|
||||
}
|
||||
src0_sz_per_thread = repacked_vtcm_size;
|
||||
}
|
||||
|
||||
const size_t vtcm_src0_size = src0_sz_per_thread * n_threads;
|
||||
const size_t vtcm_dst_size = htp_mm_round_up(ne10 * sizeof(float), QK_Q8_0_TILED * sizeof(float)) * n_threads;
|
||||
|
||||
*vtcm_src0_size_out = vtcm_src0_size;
|
||||
*vtcm_src1_size_out = src1_sz;
|
||||
*vtcm_dst_size_out = vtcm_dst_size;
|
||||
|
||||
return vtcm_src0_size + src1_sz;
|
||||
return vtcm_src0_size + src1_sz + vtcm_dst_size;
|
||||
}
|
||||
|
||||
#ifdef __cplusplus
|
||||
|
||||
@@ -31,6 +31,11 @@ if (GGML_OPENCL_EMBED_KERNELS)
|
||||
target_include_directories(${TARGET_NAME} PRIVATE "${CMAKE_CURRENT_BINARY_DIR}/autogenerated")
|
||||
endif ()
|
||||
|
||||
if (GGML_OPENCL_USE_ADRENO_BIN_KERNELS)
|
||||
message(STATUS "OpenCL will use precompiled binary kernels for Adreno (improved performance on some platforms)")
|
||||
add_compile_definitions(GGML_OPENCL_USE_ADRENO_BIN_KERNELS)
|
||||
endif ()
|
||||
|
||||
function(ggml_opencl_add_kernel KNAME)
|
||||
set(KERN_HDR ${CMAKE_CURRENT_BINARY_DIR}/autogenerated/${KNAME}.cl.h)
|
||||
set(KERN_SRC ${CMAKE_CURRENT_SOURCE_DIR}/kernels/${KNAME}.cl)
|
||||
|
||||
@@ -13,6 +13,22 @@
|
||||
#include "ggml-backend-impl.h"
|
||||
#include "ggml.h"
|
||||
|
||||
#ifdef GGML_OPENCL_USE_ADRENO_BIN_KERNELS
|
||||
#include "libdl.h"
|
||||
#ifdef _WIN32
|
||||
#define KERNEL_LIB_NAME "adreno-opencl-kernels.dll"
|
||||
#else
|
||||
#define KERNEL_LIB_NAME "libadreno-opencl-kernels.so"
|
||||
#endif // _WIN32
|
||||
#endif // GGML_OPENCL_USE_ADRENO_BIN_KERNELS
|
||||
|
||||
// Pointer to the lookup routine for precompiled Adreno kernel binaries.
// It takes a kernel name, a GPU name and a compiler version string, returns
// a pointer to the binary, and has a size_t out-parameter (out_size).
using get_adreno_bin_kernel_func_t =
    const void * (*)(const char * name, const char * gpu_name, const char * compiler_ver, size_t * out_size);
|
||||
|
||||
#include <CL/cl.h>
|
||||
|
||||
#include <inttypes.h>
|
||||
@@ -476,6 +492,8 @@ struct ggml_backend_opencl_context {
|
||||
|
||||
bool adreno_has_large_buffer;
|
||||
bool adreno_use_large_buffer;
|
||||
bool adreno_use_bin_kernels;
|
||||
get_adreno_bin_kernel_func_t get_adreno_bin_kernel_func = nullptr;
|
||||
ggml_cl_compiler_version adreno_cl_compiler_version;
|
||||
|
||||
std::string kernel_compile_opts; // cached for lazy-compiled kernels.
|
||||
@@ -718,15 +736,15 @@ struct ggml_backend_opencl_context {
|
||||
cl_kernel kernel_gated_delta_net_f32[4][2][2] = {};
|
||||
|
||||
cl_kernel kernel_timestep_embedding;
|
||||
cl_kernel kernel_gemv_moe_q4_0_f32_ns, kernel_gemm_moe_q4_0_f32_ns;
|
||||
cl_kernel kernel_gemv_moe_q4_1_f32_ns, kernel_gemm_moe_q4_1_f32_ns;
|
||||
cl_kernel kernel_gemv_moe_q4_0_f32_ns, kernel_gemm_moe_q4_0_f32_ns, kernel_gemm_moe_q4_0_f32_ns_bin;
|
||||
cl_kernel kernel_gemv_moe_q4_1_f32_ns, kernel_gemm_moe_q4_1_f32_ns, kernel_gemm_moe_q4_1_f32_ns_bin;
|
||||
cl_kernel kernel_gemv_moe_q5_0_f32_ns, kernel_gemm_moe_q5_0_f32_ns;
|
||||
cl_kernel kernel_gemv_moe_q5_1_f32_ns, kernel_gemm_moe_q5_1_f32_ns;
|
||||
cl_kernel kernel_gemv_moe_q4_k_f32_ns, kernel_gemm_moe_q4_k_f32_ns;
|
||||
cl_kernel kernel_gemv_moe_q4_k_f32_ns, kernel_gemm_moe_q4_k_f32_ns, kernel_gemm_moe_q4_k_f32_ns_bin;
|
||||
cl_kernel kernel_gemv_moe_q5_k_f32_ns, kernel_gemm_moe_q5_k_f32_ns;
|
||||
cl_kernel kernel_gemv_moe_q6_k_f32_ns, kernel_gemm_moe_q6_k_f32_ns;
|
||||
cl_kernel kernel_gemv_moe_mxfp4_f32, kernel_gemm_moe_mxfp4_f32;
|
||||
cl_kernel kernel_gemv_moe_mxfp4_f32_ns, kernel_gemm_moe_mxfp4_f32_ns;
|
||||
cl_kernel kernel_gemv_moe_mxfp4_f32_ns, kernel_gemm_moe_mxfp4_f32_ns, kernel_gemm_moe_mxfp4_f32_ns_bin;
|
||||
cl_kernel kernel_moe_reorder_b;
|
||||
cl_kernel kernel_moe_histogram, kernel_moe_scan, kernel_moe_fill, kernel_moe_scatter;
|
||||
cl_kernel kernel_mul_mv_id_q4_0_f32_8x_flat;
|
||||
@@ -870,6 +888,20 @@ struct ggml_backend_opencl_context {
|
||||
#endif
|
||||
}
|
||||
|
||||
const void * get_adreno_bin_kernel(const std::string &kernel_name, size_t *bin_size) const {
|
||||
if (!get_adreno_bin_kernel_func) {
|
||||
return nullptr;
|
||||
}
|
||||
|
||||
size_t sz;
|
||||
const void * kernel_bin = get_adreno_bin_kernel_func(
|
||||
kernel_name.c_str(), device_name.c_str(), driver_version.c_str(), &sz);
|
||||
if (bin_size) {
|
||||
*bin_size = sz;
|
||||
}
|
||||
return kernel_bin;
|
||||
}
|
||||
|
||||
#ifdef GGML_OPENCL_USE_ADRENO_KERNELS
|
||||
// Transpose kernels
|
||||
cl_program program_transpose;
|
||||
@@ -891,7 +923,7 @@ struct ggml_backend_opencl_context {
|
||||
cl_kernel kernel_gemv_noshuffle_q4_0_f32_32000_1_4096;
|
||||
cl_kernel kernel_gemv_noshuffle_q4_1_f32;
|
||||
cl_kernel kernel_gemm_noshuffle_q4_1_f32;
|
||||
cl_kernel kernel_gemm_noshuffle_q8_0_f32;
|
||||
cl_kernel kernel_gemm_noshuffle_q8_0_f32, kernel_gemm_noshuffle_q8_0_f32_bin;
|
||||
cl_kernel kernel_gemv_noshuffle_q8_0_f32;
|
||||
cl_kernel kernel_gemm_noshuffle_q1_0_f32;
|
||||
cl_kernel kernel_gemv_noshuffle_q1_0_f32;
|
||||
@@ -988,6 +1020,32 @@ static cl_program build_program_from_source(cl_context ctx, cl_device_id dev, co
|
||||
return build_program_from_source_ex(ctx, dev, program_buffer, compile_opts, /*fatal=*/true);
|
||||
}
|
||||
|
||||
// Creates an OpenCL program for a single device from a precompiled binary
// and builds it with the given options.
//
//   program_buffer - pointer to the binary image
//   compile_opts   - options string passed to clBuildProgram
//   bin_size       - size of the binary image in bytes
//
// Returns the built program. On any failure the error is logged and the
// process exits, so the function never returns an unusable program.
static cl_program build_program_from_binary(cl_context ctx, cl_device_id dev, const char* program_buffer, const std::string &compile_opts, size_t bin_size = 0) {
    cl_int err           = CL_SUCCESS;
    cl_int binary_status = CL_SUCCESS;  // per-device load status, reported on failure

    const unsigned char * binary = reinterpret_cast<const unsigned char *>(program_buffer);

    cl_program p = clCreateProgramWithBinary(ctx, 1, &dev, &bin_size, &binary, &binary_status, &err);
    if (err < 0) {
        GGML_LOG_ERROR("OpenCL error creating program from binary (err = %d, binary status = %d)\n",
                       (int) err, (int) binary_status);
        exit(1);
    }

    err = clBuildProgram(p, 0, NULL, compile_opts.c_str(), NULL, NULL);
    if (err < 0) {
        // Query the log length first, then fetch the log into an owned buffer
        size_t log_size = 0;
        clGetProgramBuildInfo(p, dev, CL_PROGRAM_BUILD_LOG, 0, NULL, &log_size);

        std::string program_log(log_size, '\0');
        if (log_size > 0) {
            clGetProgramBuildInfo(p, dev, CL_PROGRAM_BUILD_LOG, log_size, &program_log[0], NULL);
        }

        GGML_LOG_ERROR("ggml_opencl: kernel compile error:\n\n%s\n", program_log.c_str());
        exit(1);
    }

    return p;
}
|
||||
|
||||
static void load_cl_kernels_argsort(ggml_backend_opencl_context *backend_ctx) {
|
||||
// compiler options for general kernels
|
||||
auto opencl_c_std =
|
||||
@@ -1014,6 +1072,17 @@ static void load_cl_kernels_argsort(ggml_backend_opencl_context *backend_ctx) {
|
||||
}
|
||||
}
|
||||
|
||||
// Reports whether precompiled Adreno binary kernels should be used.
// Always false when GGML_OPENCL_USE_ADRENO_BIN_KERNELS is not defined;
// otherwise true only if the context's gpu_family is ADRENO and its
// adreno_use_bin_kernels flag is set.
static bool use_adreno_bin_kernels(ggml_backend_opencl_context * backend_ctx) {
#ifdef GGML_OPENCL_USE_ADRENO_BIN_KERNELS
    if (backend_ctx->gpu_family != GPU_FAMILY::ADRENO) {
        return false;
    }
    return backend_ctx->adreno_use_bin_kernels;
#else
    // Parameter is only read when the feature is compiled in
    (void) backend_ctx;
    return false;
#endif // GGML_OPENCL_USE_ADRENO_BIN_KERNELS
}
|
||||
|
||||
static void load_cl_kernels(ggml_backend_opencl_context *backend_ctx) {
|
||||
if (backend_ctx->kernels_loaded) {
|
||||
return;
|
||||
@@ -3323,6 +3392,24 @@ static void load_cl_kernels(ggml_backend_opencl_context *backend_ctx) {
|
||||
GGML_LOG_CONT(".");
|
||||
}
|
||||
|
||||
// gemm_noshuffle_q8_0_f32_bin
|
||||
{
|
||||
size_t bin_size = 0;
|
||||
backend_ctx->kernel_gemm_noshuffle_q8_0_f32_bin = nullptr;
|
||||
|
||||
if (use_adreno_bin_kernels(backend_ctx)) {
|
||||
const char * kernel_bin = (const char *)backend_ctx->get_adreno_bin_kernel("gemm_noshuffle_q8_0_f32_ila", &bin_size);
|
||||
if (kernel_bin && bin_size > 0) {
|
||||
cl_program prog =
|
||||
build_program_from_binary(backend_ctx->context, backend_ctx->device, kernel_bin, compile_opts, bin_size);
|
||||
|
||||
CL_CHECK((backend_ctx->kernel_gemm_noshuffle_q8_0_f32_bin = clCreateKernel(prog, "kernel_gemm_noshuffle_q8_0_f32_ila", &err), err));
|
||||
CL_CHECK(clReleaseProgram(prog));
|
||||
GGML_LOG_CONT(".");
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// gemv_noshuffle_general_q8_0_f32
|
||||
{
|
||||
std::string CL_gemv_compile_opts = std::string("-cl-std=") + opencl_c_std +
|
||||
@@ -3424,6 +3511,24 @@ static void load_cl_kernels(ggml_backend_opencl_context *backend_ctx) {
|
||||
GGML_LOG_CONT(".");
|
||||
}
|
||||
|
||||
// gemm_moe_q4_1_f32_ns_bin
|
||||
{
|
||||
size_t bin_size = 0;
|
||||
backend_ctx->kernel_gemm_moe_q4_1_f32_ns_bin = nullptr;
|
||||
|
||||
if (use_adreno_bin_kernels(backend_ctx)) {
|
||||
const char * kernel_bin = (const char *)backend_ctx->get_adreno_bin_kernel("gemm_moe_q4_1_f32_ns_ila", &bin_size);
|
||||
if (kernel_bin && bin_size > 0) {
|
||||
cl_program prog =
|
||||
build_program_from_binary(backend_ctx->context, backend_ctx->device, kernel_bin, CL_moe_compile_opts, bin_size);
|
||||
|
||||
CL_CHECK((backend_ctx->kernel_gemm_moe_q4_1_f32_ns_bin = clCreateKernel(prog, "kernel_gemm_moe_q4_1_f32_ns_ila", &err), err));
|
||||
CL_CHECK(clReleaseProgram(prog));
|
||||
GGML_LOG_CONT(".");
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// gemv_moe_mxfp4_f32
|
||||
{
|
||||
#ifdef GGML_OPENCL_EMBED_KERNELS
|
||||
@@ -3490,6 +3595,24 @@ static void load_cl_kernels(ggml_backend_opencl_context *backend_ctx) {
|
||||
GGML_LOG_CONT(".");
|
||||
}
|
||||
|
||||
// gemm_moe_q4_0_f32_ns_bin
|
||||
{
|
||||
size_t bin_size = 0;
|
||||
backend_ctx->kernel_gemm_moe_q4_0_f32_ns_bin = nullptr;
|
||||
|
||||
if (use_adreno_bin_kernels(backend_ctx)) {
|
||||
const char * kernel_bin = (const char *)backend_ctx->get_adreno_bin_kernel("gemm_moe_q4_0_f32_ns_ila", &bin_size);
|
||||
if (kernel_bin && bin_size > 0) {
|
||||
cl_program prog =
|
||||
build_program_from_binary(backend_ctx->context, backend_ctx->device, kernel_bin, CL_moe_compile_opts, bin_size);
|
||||
|
||||
CL_CHECK((backend_ctx->kernel_gemm_moe_q4_0_f32_ns_bin = clCreateKernel(prog, "kernel_gemm_moe_q4_0_f32_ns_ila", &err), err));
|
||||
CL_CHECK(clReleaseProgram(prog));
|
||||
GGML_LOG_CONT(".");
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// gemv_moe_q5_0_f32_ns
|
||||
{
|
||||
#ifdef GGML_OPENCL_EMBED_KERNELS
|
||||
@@ -3592,6 +3715,24 @@ static void load_cl_kernels(ggml_backend_opencl_context *backend_ctx) {
|
||||
GGML_LOG_CONT(".");
|
||||
}
|
||||
|
||||
// gemm_moe_q4_k_f32_ns_bin
// Optional binary variant of the q4_k MoE gemm kernel, obtained by name through
// backend_ctx->get_adreno_bin_kernel. The context field is reset to nullptr first
// and is only assigned when bin kernels are enabled and the lookup yields a
// non-null, non-empty blob.
|
||||
{
|
||||
size_t bin_size = 0;
|
||||
backend_ctx->kernel_gemm_moe_q4_k_f32_ns_bin = nullptr;
|
||||
|
||||
if (use_adreno_bin_kernels(backend_ctx)) {
|
||||
// bin_size is passed by address and checked right below
const char * kernel_bin = (const char *)backend_ctx->get_adreno_bin_kernel("gemm_moe_q4_k_f32_ns_ila", &bin_size);
|
||||
if (kernel_bin && bin_size > 0) {
|
||||
// built from the binary blob, not from kernel source
cl_program prog =
|
||||
build_program_from_binary(backend_ctx->context, backend_ctx->device, kernel_bin, CL_moe_compile_opts, bin_size);
|
||||
|
||||
// the kernel entry point carries an "_ila" suffix while the context field uses "_bin"
CL_CHECK((backend_ctx->kernel_gemm_moe_q4_k_f32_ns_bin = clCreateKernel(prog, "kernel_gemm_moe_q4_k_f32_ns_ila", &err), err));
|
||||
CL_CHECK(clReleaseProgram(prog));
|
||||
GGML_LOG_CONT(".");
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// gemv_moe_q5_k_f32_ns
|
||||
{
|
||||
#ifdef GGML_OPENCL_EMBED_KERNELS
|
||||
@@ -3689,9 +3830,27 @@ static void load_cl_kernels(ggml_backend_opencl_context *backend_ctx) {
|
||||
cl_program prog =
|
||||
build_program_from_source(backend_ctx->context, backend_ctx->device, kernel_src.c_str(), CL_moe_compile_opts);
|
||||
|
||||
CL_CHECK((backend_ctx->kernel_gemm_moe_mxfp4_f32_ns = clCreateKernel(prog, "kernel_gemm_moe_mxfp4_f32_ns", &err), err));
|
||||
CL_CHECK(clReleaseProgram(prog));
|
||||
GGML_LOG_CONT(".");
|
||||
CL_CHECK((backend_ctx->kernel_gemm_moe_mxfp4_f32_ns = clCreateKernel(prog, "kernel_gemm_moe_mxfp4_f32_ns", &err), err));
|
||||
CL_CHECK(clReleaseProgram(prog));
|
||||
GGML_LOG_CONT(".");
|
||||
}
|
||||
|
||||
// gemm_moe_mxfp4_f32_ns_bin
// Optional binary variant of the mxfp4 MoE gemm kernel, obtained by name through
// backend_ctx->get_adreno_bin_kernel. The context field is reset to nullptr first
// and is only assigned when bin kernels are enabled and the lookup yields a
// non-null, non-empty blob.
|
||||
{
|
||||
size_t bin_size = 0;
|
||||
backend_ctx->kernel_gemm_moe_mxfp4_f32_ns_bin = nullptr;
|
||||
|
||||
if (use_adreno_bin_kernels(backend_ctx)) {
|
||||
// bin_size is passed by address and checked right below
const char * kernel_bin = (const char *)backend_ctx->get_adreno_bin_kernel("gemm_moe_mxfp4_f32_ns_ila", &bin_size);
|
||||
if (kernel_bin && bin_size > 0) {
|
||||
// built from the binary blob, not from kernel source
cl_program prog =
|
||||
build_program_from_binary(backend_ctx->context, backend_ctx->device, kernel_bin, CL_moe_compile_opts, bin_size);
|
||||
|
||||
// the kernel entry point carries an "_ila" suffix while the context field uses "_bin"
CL_CHECK((backend_ctx->kernel_gemm_moe_mxfp4_f32_ns_bin = clCreateKernel(prog, "kernel_gemm_moe_mxfp4_f32_ns_ila", &err), err));
|
||||
CL_CHECK(clReleaseProgram(prog));
|
||||
GGML_LOG_CONT(".");
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// moe_reorder_b
|
||||
@@ -4770,6 +4929,27 @@ static ggml_backend_opencl_context * ggml_cl_init(ggml_backend_dev_t dev) {
|
||||
backend_ctx->adreno_use_large_buffer = getenv("GGML_OPENCL_ADRENO_USE_LARGE_BUFFER") != nullptr &&
|
||||
backend_ctx->gpu_family == GPU_FAMILY::ADRENO;
|
||||
|
||||
#ifdef GGML_OPENCL_USE_ADRENO_BIN_KERNELS
|
||||
// Try to load the Adreno binary-kernel library (KERNEL_LIB_NAME) and resolve its
// "get_adreno_kernels" entry point.
|
||||
// If the library or the symbol cannot be loaded, adreno_use_bin_kernels stays
// false and the builtin kernels are used instead.
|
||||
{
|
||||
// NOTE(review): the handle is neither stored nor closed in this block, so the
// library presumably stays loaded for the process lifetime (also when the symbol
// lookup fails) - confirm this is intended.
dl_handle * kernel_lib_handle = dl_load_library(KERNEL_LIB_NAME);
|
||||
backend_ctx->adreno_use_bin_kernels = false;
|
||||
|
||||
if (kernel_lib_handle) {
|
||||
backend_ctx->get_adreno_bin_kernel_func = (get_adreno_bin_kernel_func_t)dl_get_sym(kernel_lib_handle, "get_adreno_kernels");
|
||||
// the flag is only raised once the entry point has actually been resolved
if (backend_ctx->get_adreno_bin_kernel_func) {
|
||||
GGML_LOG_INFO("ggml_opencl: loaded bin kernel library %s\n", KERNEL_LIB_NAME);
|
||||
backend_ctx->adreno_use_bin_kernels = true;
|
||||
} else {
|
||||
GGML_LOG_INFO("ggml_opencl: bin kernel library %s is invalid, will use builtin kernels\n", KERNEL_LIB_NAME);
|
||||
}
|
||||
} else {
|
||||
GGML_LOG_INFO("ggml_opencl: failed to load %s, will use builtin kernels\n", KERNEL_LIB_NAME);
|
||||
}
|
||||
}
|
||||
#endif // GGML_OPENCL_USE_ADRENO_BIN_KERNELS
|
||||
|
||||
cl_int err;
|
||||
|
||||
// A local ref of cl_context for convenience
|
||||
@@ -14972,6 +15152,99 @@ static void ggml_cl_mul_mat_q8_0_f32_adreno(ggml_backend_t backend, const ggml_t
|
||||
CL_CHECK(clReleaseMemObject(b_img));
|
||||
CL_CHECK(clReleaseMemObject(b_sub_buf));
|
||||
} else {
|
||||
// use bin kernel if available
// When the binary q8_0 gemm kernel was loaded, run it here and return early;
// otherwise fall through to the code after this block.
// The path wraps the quantized weights (q), their scales (d), the activations and
// the destination in 1D image buffers, sets the 11 kernel arguments, enqueues the
// kernel and releases every temporary object it created.
|
||||
if (backend_ctx->kernel_gemm_noshuffle_q8_0_f32_bin) {
|
||||
// no padding is applied on this path: K_pad is simply K
int K_pad = K;
|
||||
|
||||
cl_mem b_sub_buf = nullptr;
|
||||
cl_mem d_sub_buf = nullptr;
|
||||
|
||||
cl_mem a_img = nullptr;
|
||||
cl_mem s_img = nullptr;
|
||||
cl_mem b_img = nullptr;
|
||||
cl_mem d_img = nullptr;
|
||||
|
||||
// subbuffer for activations
|
||||
region.origin = offset1;
|
||||
region.size = K_pad * N * sizeof(float);
|
||||
// clCreateSubBuffer takes a pointer to the cl_buffer_region describing origin/size
CL_CHECK((b_sub_buf = clCreateSubBuffer(extra1->data_device, 0, CL_BUFFER_CREATE_TYPE_REGION, &region, &err), err));
|
||||
|
||||
// Create subbuffer and image1d_buffer for dst
|
||||
region.origin = (extrad->offset); // + dst->view_offs;
|
||||
region.size = M * N * sizeof(float);
|
||||
CL_CHECK((d_sub_buf = clCreateSubBuffer((extrad->data_device), 0, CL_BUFFER_CREATE_TYPE_REGION, &region, &err), err));
|
||||
|
||||
// create an image for A
|
||||
img_fmt = { CL_R, CL_FLOAT};
|
||||
memset(&img_desc, 0, sizeof(img_desc));
|
||||
img_desc.image_type = CL_MEM_OBJECT_IMAGE1D_BUFFER;
|
||||
img_desc.image_width = M * K / 4; // Divide by 4 for char -> float
|
||||
img_desc.buffer = extra0_q8_0->q;
|
||||
CL_CHECK((a_img = clCreateImage(context, CL_MEM_READ_ONLY, &img_fmt, &img_desc, NULL, &err), err));
|
||||
|
||||
// create an image for Scale
|
||||
img_fmt = { CL_R, CL_HALF_FLOAT};
|
||||
memset(&img_desc, 0, sizeof(img_desc));
|
||||
img_desc.image_type = CL_MEM_OBJECT_IMAGE1D_BUFFER;
|
||||
img_desc.image_width = M * K / 32; // Block size is 32
|
||||
img_desc.buffer = extra0_q8_0->d;
|
||||
CL_CHECK((s_img = clCreateImage(context, CL_MEM_READ_ONLY, &img_fmt, &img_desc, NULL, &err), err));
|
||||
|
||||
// create an image for B from sub_buffer
|
||||
img_fmt = {CL_R, CL_FLOAT};
|
||||
memset(&img_desc, 0, sizeof(img_desc));
|
||||
img_desc.image_type = CL_MEM_OBJECT_IMAGE1D_BUFFER;
|
||||
img_desc.image_width = K_pad * N;
|
||||
img_desc.buffer = b_sub_buf;
|
||||
CL_CHECK((b_img = clCreateImage(context, CL_MEM_READ_ONLY, &img_fmt, &img_desc, NULL, &err), err));
|
||||
|
||||
// img for d
|
||||
img_fmt = {CL_R, CL_FLOAT};
|
||||
memset(&img_desc, 0, sizeof(img_desc));
|
||||
img_desc.image_type = CL_MEM_OBJECT_IMAGE1D_BUFFER;
|
||||
img_desc.image_width = M * N;
|
||||
img_desc.buffer = d_sub_buf;
|
||||
CL_CHECK((d_img = clCreateImage(context, CL_MEM_WRITE_ONLY, &img_fmt, &img_desc, NULL, &err), err));
|
||||
|
||||
// gemm
|
||||
kernel = backend_ctx->kernel_gemm_noshuffle_q8_0_f32_bin;
|
||||
|
||||
// fixed layout flags; they only select which line stride is passed below
bool layoutA_Mfirst = true;
|
||||
bool layoutS_Mfirst = true;
|
||||
bool layoutB_Nfirst = false;
|
||||
bool layoutC_Mfirst = true;
|
||||
|
||||
cl_uint lineStrideMatrixAinBytes = layoutA_Mfirst ? M * 4 : K; // int8
|
||||
cl_uint lineStrideMatrixSinBytes = layoutS_Mfirst ? M * 2 : (K / 32) * 2; // fp16
|
||||
cl_uint lineStrideMatrixBinBytes = layoutB_Nfirst ? N * 4 : K_pad * 4; // fp32
|
||||
cl_uint lineStrideMatrixCinBytes = layoutC_Mfirst ? M * 4 : N * 4; // fp32
|
||||
|
||||
CL_CHECK(clSetKernelArg(kernel, 0, sizeof(cl_mem), &a_img));
|
||||
CL_CHECK(clSetKernelArg(kernel, 1, sizeof(cl_mem), &s_img));
|
||||
CL_CHECK(clSetKernelArg(kernel, 2, sizeof(cl_mem), &b_img));
|
||||
// NOTE(review): args 3 and 5 pass sizeof(int) bytes of the offset fields; if those
// fields are wider than int only part of the value is passed - confirm the kernel
// signature and the field type.
CL_CHECK(clSetKernelArg(kernel, 3, sizeof(int), &extra1->offset));
|
||||
CL_CHECK(clSetKernelArg(kernel, 4, sizeof(cl_mem), &d_img));
|
||||
CL_CHECK(clSetKernelArg(kernel, 5, sizeof(int), &extrad->offset));
|
||||
CL_CHECK(clSetKernelArg(kernel, 6, sizeof(int), &K));
|
||||
CL_CHECK(clSetKernelArg(kernel, 7, sizeof(int), &lineStrideMatrixAinBytes));
|
||||
CL_CHECK(clSetKernelArg(kernel, 8, sizeof(int), &lineStrideMatrixSinBytes));
|
||||
CL_CHECK(clSetKernelArg(kernel, 9, sizeof(int), &lineStrideMatrixBinBytes));
|
||||
CL_CHECK(clSetKernelArg(kernel, 10, sizeof(int), &lineStrideMatrixCinBytes));
|
||||
|
||||
size_t global_work_size[] = { 64, (size_t)CEIL_DIV(M, 64), (size_t)CEIL_DIV(N, 64)};
|
||||
size_t local_work_size[] = { 64, 2, 2 };
|
||||
|
||||
backend_ctx->enqueue_ndrange_kernel(kernel, 3, global_work_size, local_work_size, dst);
|
||||
|
||||
// release every sub-buffer and image created above before leaving
CL_CHECK(clReleaseMemObject(b_sub_buf));
|
||||
CL_CHECK(clReleaseMemObject(d_sub_buf));
|
||||
CL_CHECK(clReleaseMemObject(a_img));
|
||||
CL_CHECK(clReleaseMemObject(s_img));
|
||||
CL_CHECK(clReleaseMemObject(b_img));
|
||||
CL_CHECK(clReleaseMemObject(d_img));
|
||||
// bin kernel handled the op; skip the builtin path that follows
return;
|
||||
}
|
||||
|
||||
cl_mem b_sub_buf = nullptr;
|
||||
cl_mem b_sub_buf_trans = nullptr;
|
||||
cl_mem b_img = nullptr;
|
||||
@@ -17825,6 +18098,9 @@ static void ggml_cl_mul_mat_id(ggml_backend_t backend, const ggml_tensor * src0,
|
||||
|
||||
} else { // for gemm
|
||||
kernel = backend_ctx->kernel_gemm_moe_q4_0_f32_ns;
|
||||
if (backend_ctx->kernel_gemm_moe_q4_0_f32_ns_bin) {
|
||||
kernel = backend_ctx->kernel_gemm_moe_q4_0_f32_ns_bin;
|
||||
}
|
||||
|
||||
// Reorder router if called from test-backend-ops or when new router is generated.
|
||||
// Otherwise reuse the reordered result from previous mul_mat_id call.
|
||||
@@ -17870,6 +18146,11 @@ static void ggml_cl_mul_mat_id(ggml_backend_t backend, const ggml_tensor * src0,
|
||||
cl_image_desc image_desc_buf_src1;
|
||||
image_format_buf_src1 = {CL_RGBA, CL_FLOAT};
|
||||
image_desc_buf_src1 = {CL_MEM_OBJECT_IMAGE1D_BUFFER, static_cast<size_t>(ne00 * max_post_router_tile * n_tile_size / 4), 0,0,0,0,0,0,0, {buf_src1_reordered}};
|
||||
if (backend_ctx->kernel_gemm_moe_q4_0_f32_ns_bin) {
|
||||
// bin kernel uses slightly different image format
|
||||
image_format_buf_src1 = {CL_R, CL_FLOAT};
|
||||
image_desc_buf_src1.image_width = static_cast<size_t>(ne00 * max_post_router_tile * n_tile_size);
|
||||
}
|
||||
image_src1_reordered = clCreateImage(backend_ctx->context, CL_MEM_READ_ONLY, &image_format_buf_src1, &image_desc_buf_src1, NULL, &status);
|
||||
CL_CHECK(status);
|
||||
|
||||
@@ -18042,6 +18323,9 @@ static void ggml_cl_mul_mat_id(ggml_backend_t backend, const ggml_tensor * src0,
|
||||
|
||||
} else { // for gemm
|
||||
kernel = backend_ctx->kernel_gemm_moe_q4_1_f32_ns;
|
||||
if (backend_ctx->kernel_gemm_moe_q4_1_f32_ns_bin) {
|
||||
kernel = backend_ctx->kernel_gemm_moe_q4_1_f32_ns_bin;
|
||||
}
|
||||
|
||||
// Reorder router if called from test-backend-ops or when new router is generated.
|
||||
// Otherwise reuse the reordered result from previous mul_mat_id call.
|
||||
@@ -18087,6 +18371,11 @@ static void ggml_cl_mul_mat_id(ggml_backend_t backend, const ggml_tensor * src0,
|
||||
cl_image_desc image_desc_buf_src1;
|
||||
image_format_buf_src1 = {CL_RGBA, CL_FLOAT};
|
||||
image_desc_buf_src1 = {CL_MEM_OBJECT_IMAGE1D_BUFFER, static_cast<size_t>(ne00 * max_post_router_tile * n_tile_size / 4), 0,0,0,0,0,0,0, {buf_src1_reordered}};
|
||||
if (backend_ctx->kernel_gemm_moe_q4_1_f32_ns_bin) {
|
||||
// bin kernel uses slightly different image format
|
||||
image_format_buf_src1 = {CL_R, CL_FLOAT};
|
||||
image_desc_buf_src1.image_width = static_cast<size_t>(ne00 * max_post_router_tile * n_tile_size);
|
||||
}
|
||||
image_src1_reordered = clCreateImage(backend_ctx->context, CL_MEM_READ_ONLY, &image_format_buf_src1, &image_desc_buf_src1, NULL, &status);
|
||||
CL_CHECK(status);
|
||||
|
||||
@@ -18648,6 +18937,9 @@ static void ggml_cl_mul_mat_id(ggml_backend_t backend, const ggml_tensor * src0,
|
||||
|
||||
} else { // for gemm
|
||||
kernel = backend_ctx->kernel_gemm_moe_q4_k_f32_ns;
|
||||
if (backend_ctx->kernel_gemm_moe_q4_k_f32_ns_bin) {
|
||||
kernel = backend_ctx->kernel_gemm_moe_q4_k_f32_ns_bin;
|
||||
}
|
||||
|
||||
// Reorder router if called from test-backend-ops or when new router is generated.
|
||||
// Otherwise reuse the reordered result from previous mul_mat_id call.
|
||||
@@ -18689,6 +18981,11 @@ static void ggml_cl_mul_mat_id(ggml_backend_t backend, const ggml_tensor * src0,
|
||||
CL_CHECK(status);
|
||||
cl_image_format image_format_buf_src1 = {CL_RGBA, CL_FLOAT};
|
||||
cl_image_desc image_desc_buf_src1 = {CL_MEM_OBJECT_IMAGE1D_BUFFER, static_cast<size_t>(ne00 * max_post_router_tile * n_tile_size / 4), 0,0,0,0,0,0,0, {buf_src1_reordered}};
|
||||
if (backend_ctx->kernel_gemm_moe_q4_k_f32_ns_bin) {
|
||||
// bin kernel uses slightly different image format
|
||||
image_format_buf_src1 = {CL_R, CL_FLOAT};
|
||||
image_desc_buf_src1.image_width = static_cast<size_t>(ne00 * max_post_router_tile * n_tile_size);
|
||||
}
|
||||
image_src1_reordered = clCreateImage(backend_ctx->context, CL_MEM_READ_ONLY, &image_format_buf_src1, &image_desc_buf_src1, NULL, &status);
|
||||
CL_CHECK(status);
|
||||
|
||||
@@ -19172,6 +19469,9 @@ static void ggml_cl_mul_mat_id(ggml_backend_t backend, const ggml_tensor * src0,
|
||||
|
||||
} else { // for gemm
|
||||
kernel = backend_ctx->kernel_gemm_moe_mxfp4_f32_ns;
|
||||
if (backend_ctx->kernel_gemm_moe_mxfp4_f32_ns_bin) {
|
||||
kernel = backend_ctx->kernel_gemm_moe_mxfp4_f32_ns_bin;
|
||||
}
|
||||
|
||||
// Reorder router if called from test-backend-ops or when new router is generated.
|
||||
// Otherwise reuse the reordered result from previous mul_mat_id call.
|
||||
@@ -19218,6 +19518,11 @@ static void ggml_cl_mul_mat_id(ggml_backend_t backend, const ggml_tensor * src0,
|
||||
cl_image_desc image_desc_buf_src1;
|
||||
image_format_buf_src1 = {CL_RGBA, CL_FLOAT};
|
||||
image_desc_buf_src1 = {CL_MEM_OBJECT_IMAGE1D_BUFFER, static_cast<size_t>(ne00 * max_post_router_tile * n_tile_size / 4), 0,0,0,0,0,0,0, {buf_src1_reordered}};
|
||||
if (backend_ctx->kernel_gemm_moe_mxfp4_f32_ns_bin) {
|
||||
// bin kernel uses slightly different image format
|
||||
image_format_buf_src1 = {CL_R, CL_FLOAT};
|
||||
image_desc_buf_src1.image_width = static_cast<size_t>(ne00 * max_post_router_tile * n_tile_size);
|
||||
}
|
||||
image_src1_reordered = clCreateImage(backend_ctx->context, CL_MEM_READ_ONLY, &image_format_buf_src1, &image_desc_buf_src1, NULL, &status);
|
||||
CL_CHECK(status);
|
||||
|
||||
|
||||
@@ -0,0 +1,79 @@
|
||||
#pragma once
|
||||
|
||||
#ifdef _WIN32
|
||||
# define WIN32_LEAN_AND_MEAN
|
||||
# ifndef NOMINMAX
|
||||
# define NOMINMAX
|
||||
# endif
|
||||
# include <windows.h>
|
||||
# include <winevt.h>
|
||||
#else
|
||||
# include <dlfcn.h>
|
||||
# include <unistd.h>
|
||||
#endif
|
||||
#include <filesystem>
|
||||
|
||||
namespace fs = std::filesystem;
|
||||
|
||||
#ifdef _WIN32
|
||||
|
||||
// On Windows a library handle is an HMODULE; dl_handle is its pointee type, so
// that `dl_handle *` is the same type as HMODULE.
using dl_handle = std::remove_pointer_t<HMODULE>;
|
||||
|
||||
// Deleter functor that releases a module handle with FreeLibrary.
// Presumably intended for std::unique_ptr<dl_handle, dl_handle_deleter> - no such
// use is visible in this header.
struct dl_handle_deleter {
|
||||
void operator()(HMODULE handle) {
|
||||
FreeLibrary(handle);
|
||||
}
|
||||
};
|
||||
|
||||
// Loads the library at `path` with LoadLibraryW and returns the resulting handle
// unchanged (LoadLibraryW yields NULL on failure). The process error mode is saved
// before the call and restored afterwards.
static inline dl_handle * dl_load_library(const fs::path & path) {
|
||||
// suppress error dialogs for missing DLLs
|
||||
// the first call only serves to read the current mode; the second sets it with
// SEM_FAILCRITICALERRORS added to the flags that were already active
DWORD old_mode = SetErrorMode(SEM_FAILCRITICALERRORS);
|
||||
SetErrorMode(old_mode | SEM_FAILCRITICALERRORS);
|
||||
|
||||
HMODULE handle = LoadLibraryW(path.wstring().c_str());
|
||||
|
||||
// restore the caller's error mode
SetErrorMode(old_mode);
|
||||
|
||||
return handle;
|
||||
}
|
||||
|
||||
// Looks up the exported symbol `name` in `handle` with GetProcAddress and returns
// it as void * (GetProcAddress yields NULL when the symbol is not found). The
// process error mode is saved before the call and restored afterwards.
static inline void * dl_get_sym(dl_handle * handle, const char * name) {
|
||||
// same save / set / restore sequence as in dl_load_library
DWORD old_mode = SetErrorMode(SEM_FAILCRITICALERRORS);
|
||||
SetErrorMode(old_mode | SEM_FAILCRITICALERRORS);
|
||||
|
||||
void * p = (void *) GetProcAddress(handle, name);
|
||||
|
||||
SetErrorMode(old_mode);
|
||||
|
||||
return p;
|
||||
}
|
||||
|
||||
// Windows variant: always returns an empty string; no error text is retrieved.
static inline const char * dl_error() {
|
||||
return "";
|
||||
}
|
||||
|
||||
#else
|
||||
|
||||
// On non-Windows platforms a library handle is the void * returned by dlopen.
using dl_handle = void;
|
||||
|
||||
// Deleter functor that closes a library handle with dlclose.
// Presumably intended for std::unique_ptr<dl_handle, dl_handle_deleter> - no such
// use is visible in this header.
struct dl_handle_deleter {
|
||||
void operator()(void * handle) {
|
||||
dlclose(handle);
|
||||
}
|
||||
};
|
||||
|
||||
// Loads the library at `path` with dlopen and returns the handle unchanged
// (dlopen yields NULL on failure).
static inline dl_handle * dl_load_library(const fs::path & path) {
|
||||
// RTLD_NOW: resolve all symbols at load time; RTLD_LOCAL: do not export them
// to libraries loaded later
dl_handle * handle = dlopen(path.string().c_str(), RTLD_NOW | RTLD_LOCAL);
|
||||
return handle;
|
||||
}
|
||||
|
||||
// Looks up the symbol `name` in `handle`; thin wrapper that returns dlsym's result.
static inline void * dl_get_sym(dl_handle * handle, const char * name) {
|
||||
return dlsym(handle, name);
|
||||
}
|
||||
|
||||
// Returns the message reported by dlerror(), or an empty string when dlerror()
// returns a null pointer, so the result is never null.
static inline const char * dl_error() {
|
||||
const char *rslt = dlerror();
|
||||
return rslt != nullptr ? rslt : "";
|
||||
}
|
||||
|
||||
#endif
|
||||
@@ -159,6 +159,9 @@ extern "C" {
|
||||
LLAMA_FTYPE_GUESSED = 1024, // not specified in the model file
|
||||
};
|
||||
|
||||
// Get the model file type (quantization) as a string, e.g. "Q8_0" or "Q4_K - Medium"
|
||||
LLAMA_API const char * llama_ftype_name(enum llama_ftype ftype);
|
||||
|
||||
enum llama_rope_scaling_type {
|
||||
LLAMA_ROPE_SCALING_TYPE_UNSPECIFIED = -1,
|
||||
LLAMA_ROPE_SCALING_TYPE_NONE = 0,
|
||||
@@ -606,6 +609,9 @@ extern "C" {
|
||||
// Get a string describing the model type
|
||||
LLAMA_API int32_t llama_model_desc(const struct llama_model * model, char * buf, size_t buf_size);
|
||||
|
||||
// Get the model file type (quantization), e.g. LLAMA_FTYPE_MOSTLY_Q8_0
|
||||
LLAMA_API enum llama_ftype llama_model_ftype(const struct llama_model * model);
|
||||
|
||||
// Returns the total size of all the tensors in the model in bytes
|
||||
LLAMA_API uint64_t llama_model_size(const struct llama_model * model);
|
||||
|
||||
|
||||
@@ -69,13 +69,16 @@ mbuf=
|
||||
mmsel=
|
||||
[ "$MM" != "" ] && mmsel="GGML_HEXAGON_MM_SELECT=$MM"
|
||||
|
||||
fasel=
|
||||
[ "$FA" != "" ] && fasel="GGML_HEXAGON_FA_SELECT=$FA"
|
||||
|
||||
set -x
|
||||
|
||||
adb $adbserial $adbhost shell " \
|
||||
cd $basedir; ulimit -c unlimited; \
|
||||
LD_LIBRARY_PATH=$basedir/$branch/lib \
|
||||
ADSP_LIBRARY_PATH=$basedir/$branch/lib \
|
||||
$verbose $sched $opmask $profile $nhvx $hmx $ndev $hb $opbatch $opqueue $oppoll $opflt $opfuse $vmem $mbuf $mmsel \
|
||||
$verbose $sched $opmask $profile $nhvx $hmx $ndev $hb $opbatch $opqueue $oppoll $opflt $opfuse $vmem $mbuf $mmsel $fasel \
|
||||
./$branch/bin/llama-completion --no-mmap -m $basedir/../gguf/$model \
|
||||
--poll 1000 -t 6 --cpu-mask 0xfc --cpu-strict 1 \
|
||||
--ctx-size 8192 --ubatch-size 1024 -fa on \
|
||||
|
||||
@@ -57,6 +57,9 @@ opfuse=
|
||||
mmsel=
|
||||
[ "$MM" != "" ] && mmsel="GGML_HEXAGON_MM_SELECT=$MM"
|
||||
|
||||
fasel=
|
||||
[ "$FA" != "" ] && fasel="GGML_HEXAGON_FA_SELECT=$FA"
|
||||
|
||||
set -x
|
||||
|
||||
tool=$1; shift
|
||||
@@ -65,5 +68,5 @@ adb $adbserial $adbhost shell " \
|
||||
cd $basedir; ulimit -c unlimited; \
|
||||
LD_LIBRARY_PATH=$basedir/$branch/lib \
|
||||
ADSP_LIBRARY_PATH=$basedir/$branch/lib \
|
||||
$verbose $sched $opmask $profile $nhvx $hmx $ndev $hb $opbatch $opqueue $oppoll $opfuse $mmsel ./$branch/bin/$tool $@ \
|
||||
$verbose $sched $opmask $profile $nhvx $hmx $ndev $hb $opbatch $opqueue $oppoll $opfuse $mmsel $fasel ./$branch/bin/$tool $@ \
|
||||
"
|
||||
|
||||
@@ -230,6 +230,12 @@ def print_ascii_timeline(op_name, dims, types, usec, cycles, events, evt_val=Non
|
||||
char = 'Q'
|
||||
elif norm_evt == 'A-PREP':
|
||||
char = 'A'
|
||||
elif norm_evt == 'Q-PREP':
|
||||
char = 'q'
|
||||
elif norm_evt == 'K-PREP':
|
||||
char = 'k'
|
||||
elif norm_evt == 'V-PREP':
|
||||
char = 'v'
|
||||
elif norm_evt == 'W-DEQUANT':
|
||||
char = 'D'
|
||||
elif norm_evt == 'O-PROC':
|
||||
|
||||
@@ -5,7 +5,7 @@ import os
|
||||
import sys
|
||||
import subprocess
|
||||
|
||||
HTTPLIB_VERSION = "refs/tags/v0.48.0"
|
||||
HTTPLIB_VERSION = "refs/tags/v0.49.0"
|
||||
|
||||
vendor = {
|
||||
"https://github.com/nlohmann/json/releases/latest/download/json.hpp": "vendor/nlohmann/json.hpp",
|
||||
|
||||
+46
-44
@@ -27,52 +27,54 @@ const char * llama_file_version_name(llama_fver version) {
|
||||
return "unknown";
|
||||
}
|
||||
|
||||
static std::string llama_model_ftype_name(llama_ftype ftype) {
|
||||
if (ftype & LLAMA_FTYPE_GUESSED) {
|
||||
return llama_model_ftype_name((enum llama_ftype) (ftype & ~LLAMA_FTYPE_GUESSED)) + " (guessed)";
|
||||
}
|
||||
#define LLAMA_FTYPE_PREFIX "(guessed) "
|
||||
|
||||
switch (ftype) {
|
||||
case LLAMA_FTYPE_ALL_F32: return "all F32";
|
||||
case LLAMA_FTYPE_MOSTLY_F16: return "F16";
|
||||
case LLAMA_FTYPE_MOSTLY_BF16: return "BF16";
|
||||
case LLAMA_FTYPE_MOSTLY_Q1_0: return "Q1_0";
|
||||
case LLAMA_FTYPE_MOSTLY_Q4_0: return "Q4_0";
|
||||
case LLAMA_FTYPE_MOSTLY_Q4_1: return "Q4_1";
|
||||
case LLAMA_FTYPE_MOSTLY_Q5_0: return "Q5_0";
|
||||
case LLAMA_FTYPE_MOSTLY_Q5_1: return "Q5_1";
|
||||
case LLAMA_FTYPE_MOSTLY_Q8_0: return "Q8_0";
|
||||
case LLAMA_FTYPE_MOSTLY_MXFP4_MOE: return "MXFP4 MoE";
|
||||
case LLAMA_FTYPE_MOSTLY_NVFP4: return "NVFP4";
|
||||
case LLAMA_FTYPE_MOSTLY_Q2_K: return "Q2_K - Medium";
|
||||
case LLAMA_FTYPE_MOSTLY_Q2_K_S: return "Q2_K - Small";
|
||||
case LLAMA_FTYPE_MOSTLY_Q3_K_S: return "Q3_K - Small";
|
||||
case LLAMA_FTYPE_MOSTLY_Q3_K_M: return "Q3_K - Medium";
|
||||
case LLAMA_FTYPE_MOSTLY_Q3_K_L: return "Q3_K - Large";
|
||||
case LLAMA_FTYPE_MOSTLY_Q4_K_S: return "Q4_K - Small";
|
||||
case LLAMA_FTYPE_MOSTLY_Q4_K_M: return "Q4_K - Medium";
|
||||
case LLAMA_FTYPE_MOSTLY_Q5_K_S: return "Q5_K - Small";
|
||||
case LLAMA_FTYPE_MOSTLY_Q5_K_M: return "Q5_K - Medium";
|
||||
case LLAMA_FTYPE_MOSTLY_Q6_K: return "Q6_K";
|
||||
case LLAMA_FTYPE_MOSTLY_TQ1_0: return "TQ1_0 - 1.69 bpw ternary";
|
||||
case LLAMA_FTYPE_MOSTLY_TQ2_0: return "TQ2_0 - 2.06 bpw ternary";
|
||||
case LLAMA_FTYPE_MOSTLY_IQ2_XXS: return "IQ2_XXS - 2.0625 bpw";
|
||||
case LLAMA_FTYPE_MOSTLY_IQ2_XS: return "IQ2_XS - 2.3125 bpw";
|
||||
case LLAMA_FTYPE_MOSTLY_IQ2_S: return "IQ2_S - 2.5 bpw";
|
||||
case LLAMA_FTYPE_MOSTLY_IQ2_M: return "IQ2_M - 2.7 bpw";
|
||||
case LLAMA_FTYPE_MOSTLY_IQ3_XS: return "IQ3_XS - 3.3 bpw";
|
||||
case LLAMA_FTYPE_MOSTLY_IQ3_XXS: return "IQ3_XXS - 3.0625 bpw";
|
||||
case LLAMA_FTYPE_MOSTLY_IQ1_S: return "IQ1_S - 1.5625 bpw";
|
||||
case LLAMA_FTYPE_MOSTLY_IQ1_M: return "IQ1_M - 1.75 bpw";
|
||||
case LLAMA_FTYPE_MOSTLY_IQ4_NL: return "IQ4_NL - 4.5 bpw";
|
||||
case LLAMA_FTYPE_MOSTLY_IQ4_XS: return "IQ4_XS - 4.25 bpw";
|
||||
case LLAMA_FTYPE_MOSTLY_IQ3_S: return "IQ3_S - 3.4375 bpw";
|
||||
case LLAMA_FTYPE_MOSTLY_IQ3_M: return "IQ3_S mix - 3.66 bpw";
|
||||
|
||||
default: return "unknown, may not work";
|
||||
// Returns a human-readable name for `ftype` as a pointer into a string literal.
// The LLAMA_FTYPE_GUESSED bit is masked off before the switch, so a guessed and a
// non-guessed value select the same literal. Every literal is stored with
// LLAMA_FTYPE_PREFIX in front of it: when the GUESSED bit is set the whole literal
// is returned, otherwise the returned pointer is advanced past the prefix.
// Values without a matching case yield "unknown, may not work".
const char * llama_ftype_name(llama_ftype ftype) {
|
||||
// length of the prefix without its terminating NUL
static constexpr size_t guessed_prefix_len = sizeof(LLAMA_FTYPE_PREFIX) - 1;
|
||||
const char * name;
|
||||
switch ((enum llama_ftype) (ftype & ~LLAMA_FTYPE_GUESSED)) {
|
||||
case LLAMA_FTYPE_ALL_F32: name = LLAMA_FTYPE_PREFIX "all F32"; break;
|
||||
case LLAMA_FTYPE_MOSTLY_F16: name = LLAMA_FTYPE_PREFIX "F16"; break;
|
||||
case LLAMA_FTYPE_MOSTLY_BF16: name = LLAMA_FTYPE_PREFIX "BF16"; break;
|
||||
case LLAMA_FTYPE_MOSTLY_Q1_0: name = LLAMA_FTYPE_PREFIX "Q1_0"; break;
|
||||
case LLAMA_FTYPE_MOSTLY_Q4_0: name = LLAMA_FTYPE_PREFIX "Q4_0"; break;
|
||||
case LLAMA_FTYPE_MOSTLY_Q4_1: name = LLAMA_FTYPE_PREFIX "Q4_1"; break;
|
||||
case LLAMA_FTYPE_MOSTLY_Q5_0: name = LLAMA_FTYPE_PREFIX "Q5_0"; break;
|
||||
case LLAMA_FTYPE_MOSTLY_Q5_1: name = LLAMA_FTYPE_PREFIX "Q5_1"; break;
|
||||
case LLAMA_FTYPE_MOSTLY_Q8_0: name = LLAMA_FTYPE_PREFIX "Q8_0"; break;
|
||||
case LLAMA_FTYPE_MOSTLY_MXFP4_MOE: name = LLAMA_FTYPE_PREFIX "MXFP4 MoE"; break;
|
||||
case LLAMA_FTYPE_MOSTLY_NVFP4: name = LLAMA_FTYPE_PREFIX "NVFP4"; break;
|
||||
case LLAMA_FTYPE_MOSTLY_Q2_K: name = LLAMA_FTYPE_PREFIX "Q2_K - Medium"; break;
|
||||
case LLAMA_FTYPE_MOSTLY_Q2_K_S: name = LLAMA_FTYPE_PREFIX "Q2_K - Small"; break;
|
||||
case LLAMA_FTYPE_MOSTLY_Q3_K_S: name = LLAMA_FTYPE_PREFIX "Q3_K - Small"; break;
|
||||
case LLAMA_FTYPE_MOSTLY_Q3_K_M: name = LLAMA_FTYPE_PREFIX "Q3_K - Medium"; break;
|
||||
case LLAMA_FTYPE_MOSTLY_Q3_K_L: name = LLAMA_FTYPE_PREFIX "Q3_K - Large"; break;
|
||||
case LLAMA_FTYPE_MOSTLY_Q4_K_S: name = LLAMA_FTYPE_PREFIX "Q4_K - Small"; break;
|
||||
case LLAMA_FTYPE_MOSTLY_Q4_K_M: name = LLAMA_FTYPE_PREFIX "Q4_K - Medium"; break;
|
||||
case LLAMA_FTYPE_MOSTLY_Q5_K_S: name = LLAMA_FTYPE_PREFIX "Q5_K - Small"; break;
|
||||
case LLAMA_FTYPE_MOSTLY_Q5_K_M: name = LLAMA_FTYPE_PREFIX "Q5_K - Medium"; break;
|
||||
case LLAMA_FTYPE_MOSTLY_Q6_K: name = LLAMA_FTYPE_PREFIX "Q6_K"; break;
|
||||
case LLAMA_FTYPE_MOSTLY_TQ1_0: name = LLAMA_FTYPE_PREFIX "TQ1_0 - 1.69 bpw ternary"; break;
|
||||
case LLAMA_FTYPE_MOSTLY_TQ2_0: name = LLAMA_FTYPE_PREFIX "TQ2_0 - 2.06 bpw ternary"; break;
|
||||
case LLAMA_FTYPE_MOSTLY_IQ2_XXS: name = LLAMA_FTYPE_PREFIX "IQ2_XXS - 2.0625 bpw"; break;
|
||||
case LLAMA_FTYPE_MOSTLY_IQ2_XS: name = LLAMA_FTYPE_PREFIX "IQ2_XS - 2.3125 bpw"; break;
|
||||
case LLAMA_FTYPE_MOSTLY_IQ2_S: name = LLAMA_FTYPE_PREFIX "IQ2_S - 2.5 bpw"; break;
|
||||
case LLAMA_FTYPE_MOSTLY_IQ2_M: name = LLAMA_FTYPE_PREFIX "IQ2_M - 2.7 bpw"; break;
|
||||
case LLAMA_FTYPE_MOSTLY_IQ3_XS: name = LLAMA_FTYPE_PREFIX "IQ3_XS - 3.3 bpw"; break;
|
||||
case LLAMA_FTYPE_MOSTLY_IQ3_XXS: name = LLAMA_FTYPE_PREFIX "IQ3_XXS - 3.0625 bpw"; break;
|
||||
case LLAMA_FTYPE_MOSTLY_IQ1_S: name = LLAMA_FTYPE_PREFIX "IQ1_S - 1.5625 bpw"; break;
|
||||
case LLAMA_FTYPE_MOSTLY_IQ1_M: name = LLAMA_FTYPE_PREFIX "IQ1_M - 1.75 bpw"; break;
|
||||
case LLAMA_FTYPE_MOSTLY_IQ4_NL: name = LLAMA_FTYPE_PREFIX "IQ4_NL - 4.5 bpw"; break;
|
||||
case LLAMA_FTYPE_MOSTLY_IQ4_XS: name = LLAMA_FTYPE_PREFIX "IQ4_XS - 4.25 bpw"; break;
|
||||
case LLAMA_FTYPE_MOSTLY_IQ3_S: name = LLAMA_FTYPE_PREFIX "IQ3_S - 3.4375 bpw"; break;
|
||||
case LLAMA_FTYPE_MOSTLY_IQ3_M: name = LLAMA_FTYPE_PREFIX "IQ3_S mix - 3.66 bpw"; break;
|
||||
default: name = LLAMA_FTYPE_PREFIX "unknown, may not work"; break;
|
||||
}
|
||||
// guessed: keep the "(guessed) " prefix; otherwise skip over it
return (ftype & LLAMA_FTYPE_GUESSED) ? name : name + guessed_prefix_len;
|
||||
}
|
||||
|
||||
#undef LLAMA_FTYPE_PREFIX
|
||||
|
||||
// return a list of splits for a given path
|
||||
// for example, given "<name>-00002-of-00004.gguf", returns list of all 4 splits
|
||||
static std::vector<std::string> llama_get_list_splits(const std::string & path, const int idx, const int n_split) {
|
||||
@@ -1693,12 +1695,12 @@ bool llama_model_loader::load_all_data(
|
||||
}
|
||||
|
||||
// Returns the human-readable name of this loader's file type as a std::string.
// NOTE(review): two return statements appear here (old and new line of the diff
// this text was taken from); as written only the first one is reachable.
std::string llama_model_loader::ftype_name() const {
|
||||
return llama_model_ftype_name(ftype);
|
||||
return llama_ftype_name(ftype);
|
||||
}
|
||||
|
||||
void llama_model_loader::print_info() const {
|
||||
LLAMA_LOG_INFO("%s: file format = %s\n", __func__, llama_file_version_name(fver));
|
||||
LLAMA_LOG_INFO("%s: file type = %s\n", __func__, llama_model_ftype_name(ftype).c_str());
|
||||
LLAMA_LOG_INFO("%s: file type = %s\n", __func__, llama_ftype_name(ftype));
|
||||
if (n_bytes < GiB) {
|
||||
LLAMA_LOG_INFO("%s: file size = %.2f MiB (%.2f BPW) \n", __func__, n_bytes/1024.0/1024.0, n_bytes*8.0/n_elements);
|
||||
} else {
|
||||
|
||||
@@ -987,6 +987,8 @@ struct llama_model::impl {
|
||||
|
||||
std::string desc_str;
|
||||
|
||||
llama_ftype ftype = LLAMA_FTYPE_ALL_F32;
|
||||
|
||||
// model memory mapped files
|
||||
llama_mmaps mappings;
|
||||
|
||||
@@ -1200,6 +1202,8 @@ void llama_model_base::load_hparams(llama_model_loader & ml) {
|
||||
|
||||
pimpl->desc_str = arch_name() + " " + type_name() + " " + ml.ftype_name();
|
||||
|
||||
pimpl->ftype = ml.ftype;
|
||||
|
||||
if (hparams.f_max_alibi_bias > 0.0f) {
|
||||
hparams.use_alibi = true;
|
||||
}
|
||||
@@ -1646,6 +1650,10 @@ std::string llama_model::desc() const {
|
||||
return pimpl->desc_str;
|
||||
}
|
||||
|
||||
// Returns the file type stored in the model's private implementation (pimpl->ftype).
llama_ftype llama_model::ftype() const {
|
||||
return pimpl->ftype;
|
||||
}
|
||||
|
||||
// Returns the byte count stored in the model's private implementation (pimpl->n_bytes).
size_t llama_model::size() const {
|
||||
return pimpl->n_bytes;
|
||||
}
|
||||
@@ -2616,6 +2624,10 @@ int32_t llama_model_desc(const llama_model * model, char * buf, size_t buf_size)
|
||||
return snprintf(buf, buf_size, "%s", model->desc().c_str());
|
||||
}
|
||||
|
||||
// C API wrapper: forwards to model->ftype(). `model` is dereferenced without a
// null check.
llama_ftype llama_model_ftype(const llama_model * model) {
|
||||
return model->ftype();
|
||||
}
|
||||
|
||||
// C API wrapper: forwards to model->size() and returns it as uint64_t. `model` is
// dereferenced without a null check.
uint64_t llama_model_size(const llama_model * model) {
|
||||
return model->size();
|
||||
}
|
||||
|
||||
@@ -637,6 +637,8 @@ struct llama_model {
|
||||
|
||||
std::string desc() const;
|
||||
|
||||
llama_ftype ftype() const;
|
||||
|
||||
size_t size() const; // file size
|
||||
size_t n_tensors() const;
|
||||
size_t n_devices() const;
|
||||
|
||||
@@ -448,6 +448,9 @@ int llama_cli(int argc, char ** argv) {
|
||||
console::log("%s\n", LLAMA_ASCII_LOGO);
|
||||
console::log("build : %s\n", inf.build_info.c_str());
|
||||
console::log("model : %s\n", inf.model_name.c_str());
|
||||
if (!inf.model_ftype.empty()) {
|
||||
console::log("ftype : %s\n", inf.model_ftype.c_str());
|
||||
}
|
||||
console::log("modalities : %s\n", modalities.c_str());
|
||||
if (!params.system_prompt.empty()) {
|
||||
console::log("using custom system prompt\n");
|
||||
|
||||
@@ -3989,6 +3989,8 @@ server_context_meta server_context::get_meta() const {
|
||||
auto bos_token_str = bos_id != LLAMA_TOKEN_NULL ? common_token_to_piece(impl->ctx_tgt, bos_id, true) : "";
|
||||
auto eos_token_str = eos_id != LLAMA_TOKEN_NULL ? common_token_to_piece(impl->ctx_tgt, eos_id, true) : "";
|
||||
|
||||
const char * ftype_name = llama_ftype_name(llama_model_ftype(impl->model_tgt));
|
||||
|
||||
return server_context_meta {
|
||||
/* build_info */ std::string(llama_build_info()),
|
||||
/* model_name */ impl->model_name,
|
||||
@@ -4023,6 +4025,7 @@ server_context_meta server_context::get_meta() const {
|
||||
/* model_n_embd_inp */ llama_model_n_embd(impl->model_tgt),
|
||||
/* model_n_params */ llama_model_n_params(impl->model_tgt),
|
||||
/* model_size */ llama_model_size(impl->model_tgt),
|
||||
/* model_ftype */ ftype_name,
|
||||
};
|
||||
}
|
||||
|
||||
@@ -5118,6 +5121,7 @@ json server_routes::get_model_info() const {
|
||||
{"n_embd", meta->model_n_embd_inp},
|
||||
{"n_params", meta->model_n_params},
|
||||
{"size", meta->model_size},
|
||||
{"ftype", meta->model_ftype},
|
||||
}},
|
||||
};
|
||||
}
|
||||
|
||||
@@ -50,6 +50,7 @@ struct server_context_meta {
|
||||
int32_t model_n_embd_inp;
|
||||
uint64_t model_n_params;
|
||||
uint64_t model_size;
|
||||
std::string model_ftype;
|
||||
};
|
||||
|
||||
enum server_state {
|
||||
|
||||
+1
-1
@@ -18,7 +18,7 @@
|
||||
|
||||
let mcpSearchQuery = $state('');
|
||||
let allMcpServers = $derived(mcpStore.getServersSorted());
|
||||
let mcpServers = $derived(allMcpServers.filter((s) => s.enabled));
|
||||
let mcpServers = $derived(mcpStore.visibleMcpServers);
|
||||
let hasMcpServers = $derived(mcpServers.length > 0);
|
||||
// let hasAnyMcpServers = $derived(allMcpServers.length > 0);
|
||||
let filteredMcpServers = $derived.by(() => {
|
||||
|
||||
+4
-6
@@ -74,9 +74,7 @@
|
||||
const sheetItemRowClass =
|
||||
'flex w-full items-center justify-between gap-2 rounded-md px-3 py-2 text-left text-sm transition-colors hover:bg-accent';
|
||||
|
||||
// Returns the servers from mcpStore.getServersSorted() whose `enabled` flag is truthy.
function getEnabledMcpServers() {
|
||||
return mcpStore.getServersSorted().filter((s) => s.enabled);
|
||||
}
|
||||
let visibleMcpServers = $derived(mcpStore.visibleMcpServers);
|
||||
</script>
|
||||
|
||||
<div class="flex items-center gap-1 {className}">
|
||||
@@ -153,13 +151,13 @@
|
||||
<span class="flex-1">MCP Servers</span>
|
||||
|
||||
<span class="text-xs text-muted-foreground">
|
||||
{getEnabledMcpServers().length} server{getEnabledMcpServers().length !== 1 ? 's' : ''}
|
||||
{visibleMcpServers.length} server{visibleMcpServers.length !== 1 ? 's' : ''}
|
||||
</span>
|
||||
</Collapsible.Trigger>
|
||||
|
||||
<Collapsible.Content>
|
||||
<div class="flex flex-col gap-0.5 pl-4">
|
||||
{#each getEnabledMcpServers() as server (server.id)}
|
||||
{#each visibleMcpServers as server (server.id)}
|
||||
{@const healthState = mcpStore.getHealthCheckState(server.id)}
|
||||
{@const hasError = healthState.status === HealthCheckStatus.ERROR}
|
||||
{@const displayName = mcpStore.getServerLabel(server)}
|
||||
@@ -202,7 +200,7 @@
|
||||
</button>
|
||||
{/each}
|
||||
|
||||
{#if getEnabledMcpServers().length === 0}
|
||||
{#if visibleMcpServers.length === 0}
|
||||
<div class="px-3 py-2 text-center text-sm text-muted-foreground">
|
||||
No MCP servers configured
|
||||
</div>
|
||||
|
||||
+16
-24
@@ -1,8 +1,9 @@
|
||||
<script lang="ts">
|
||||
import { ChevronDown, ShieldQuestion } from '@lucide/svelte';
|
||||
import { ChatMessageActionCard } from '$lib/components/app';
|
||||
import { Button } from '$lib/components/ui/button';
|
||||
import { Button, buttonVariants } from '$lib/components/ui/button';
|
||||
import * as ButtonGroup from '$lib/components/ui/button-group';
|
||||
import { cn } from '$lib/components/ui/utils';
|
||||
import * as DropdownMenu from '$lib/components/ui/dropdown-menu';
|
||||
import { ToolSource, ToolPermissionDecision } from '$lib/enums';
|
||||
import { TOOL_SERVER_LABELS } from '$lib/constants';
|
||||
@@ -19,25 +20,17 @@
|
||||
|
||||
<ChatMessageActionCard icon={ShieldQuestion}>
|
||||
{#snippet message()}
|
||||
Allow use of
|
||||
|
||||
<span class="font-semibold">{toolName}</span>
|
||||
|
||||
{#if serverLabel}
|
||||
from <span class="font-semibold">{serverLabel}</span>
|
||||
{/if}
|
||||
|
||||
?
|
||||
Allow use of <span class="font-semibold">{toolName}</span>{#if serverLabel}
|
||||
from <span class="font-semibold">{serverLabel}</span>{/if}?
|
||||
{/snippet}
|
||||
|
||||
{#snippet actions()}
|
||||
<DropdownMenu.Root>
|
||||
<ButtonGroup.Root
|
||||
class="overflow-hidden rounded-md bg-foreground text-white shadow-sm dark:bg-secondary dark:text-foreground"
|
||||
>
|
||||
<ButtonGroup.Root class="overflow-hidden rounded-md shadow-sm">
|
||||
<Button
|
||||
class="rounded-none! shadow-none!"
|
||||
variant="secondary"
|
||||
size="sm"
|
||||
class="!rounded-r-none !shadow-none"
|
||||
onclick={() => onDecision(ToolPermissionDecision.ONCE)}
|
||||
>
|
||||
Allow once
|
||||
@@ -45,10 +38,14 @@
|
||||
|
||||
<ButtonGroup.Separator />
|
||||
|
||||
<DropdownMenu.Trigger>
|
||||
<Button size="sm" class="rounded-none! !ps-2 shadow-none!">
|
||||
<ChevronDown class="h-3.5 w-3.5" />
|
||||
</Button>
|
||||
<DropdownMenu.Trigger
|
||||
class={cn(
|
||||
buttonVariants({ variant: 'secondary', size: 'sm' }),
|
||||
'inline-flex cursor-pointer items-center !rounded-l-none !shadow-none !px-2'
|
||||
)}
|
||||
aria-label="More allow options"
|
||||
>
|
||||
<ChevronDown class="h-3.5 w-3.5" />
|
||||
</DropdownMenu.Trigger>
|
||||
</ButtonGroup.Root>
|
||||
|
||||
@@ -76,12 +73,7 @@
|
||||
</DropdownMenu.Content>
|
||||
</DropdownMenu.Root>
|
||||
|
||||
<Button
|
||||
variant="destructive"
|
||||
size="sm"
|
||||
class="text-destructive hover:text-destructive"
|
||||
onclick={() => onDecision(ToolPermissionDecision.DENY)}
|
||||
>
|
||||
<Button variant="destructive" size="sm" onclick={() => onDecision(ToolPermissionDecision.DENY)}>
|
||||
Deny
|
||||
</Button>
|
||||
{/snippet}
|
||||
|
||||
@@ -4,7 +4,7 @@
|
||||
import { McpServerForm } from '$lib/components/app/mcp';
|
||||
import { mcpStore } from '$lib/stores/mcp.svelte';
|
||||
import { conversationsStore } from '$lib/stores/conversations.svelte';
|
||||
import { uuid } from '$lib/utils';
|
||||
import { parseHeadersToArray, uuid } from '$lib/utils';
|
||||
import { MCP_SERVER_ID_PREFIX } from '$lib/constants';
|
||||
|
||||
interface Props {
|
||||
@@ -26,6 +26,10 @@
|
||||
return 'Invalid URL format';
|
||||
}
|
||||
});
|
||||
let newServerHeaderPairsValid = $derived(
|
||||
parseHeadersToArray(newServerHeaders).every((p) => p.key.trim() && p.value.trim())
|
||||
);
|
||||
let canSave = $derived(!newServerUrlError && newServerHeaderPairsValid);
|
||||
|
||||
function handleOpenChange(value: boolean) {
|
||||
if (!value) {
|
||||
@@ -37,7 +41,7 @@
|
||||
}
|
||||
|
||||
function saveNewServer() {
|
||||
if (newServerUrlError) return;
|
||||
if (!canSave) return;
|
||||
|
||||
const newServerId = uuid() ?? `${MCP_SERVER_ID_PREFIX}-${Date.now()}`;
|
||||
|
||||
@@ -52,6 +56,11 @@
|
||||
|
||||
handleOpenChange(false);
|
||||
}
|
||||
|
||||
function handleSubmit(event: SubmitEvent) {
|
||||
event.preventDefault();
|
||||
saveNewServer();
|
||||
}
|
||||
</script>
|
||||
|
||||
<Dialog.Root {open} onOpenChange={handleOpenChange}>
|
||||
@@ -60,29 +69,27 @@
|
||||
<Dialog.Title>Add New Server</Dialog.Title>
|
||||
</Dialog.Header>
|
||||
|
||||
<div class="space-y-4 py-4">
|
||||
<McpServerForm
|
||||
url={newServerUrl}
|
||||
headers={newServerHeaders}
|
||||
onUrlChange={(v) => (newServerUrl = v)}
|
||||
onHeadersChange={(v) => (newServerHeaders = v)}
|
||||
urlError={newServerUrl ? newServerUrlError : null}
|
||||
id="new-server"
|
||||
/>
|
||||
</div>
|
||||
<form onsubmit={handleSubmit} class="contents">
|
||||
<div class="space-y-4 py-4">
|
||||
<McpServerForm
|
||||
url={newServerUrl}
|
||||
headers={newServerHeaders}
|
||||
onUrlChange={(v) => (newServerUrl = v)}
|
||||
onHeadersChange={(v) => (newServerHeaders = v)}
|
||||
urlError={newServerUrl ? newServerUrlError : null}
|
||||
id="new-server"
|
||||
/>
|
||||
</div>
|
||||
|
||||
<Dialog.Footer>
|
||||
<Button variant="secondary" size="sm" onclick={() => handleOpenChange(false)}>Cancel</Button>
|
||||
<Dialog.Footer>
|
||||
<Button variant="secondary" size="sm" onclick={() => handleOpenChange(false)}>
|
||||
Cancel
|
||||
</Button>
|
||||
|
||||
<Button
|
||||
variant="default"
|
||||
size="sm"
|
||||
onclick={saveNewServer}
|
||||
disabled={!!newServerUrlError}
|
||||
aria-label="Save"
|
||||
>
|
||||
Add
|
||||
</Button>
|
||||
</Dialog.Footer>
|
||||
<Button variant="default" size="sm" type="submit" disabled={!canSave} aria-label="Save">
|
||||
Add
|
||||
</Button>
|
||||
</Dialog.Footer>
|
||||
</form>
|
||||
</Dialog.Content>
|
||||
</Dialog.Root>
|
||||
|
||||
@@ -0,0 +1,180 @@
|
||||
<script lang="ts">
|
||||
import { Button } from '$lib/components/ui/button';
|
||||
import * as Card from '$lib/components/ui/card';
|
||||
import * as Dialog from '$lib/components/ui/dialog';
|
||||
import { fly } from 'svelte/transition';
|
||||
import { McpServerCardCompact, McpServerForm } from '$lib/components/app/mcp';
|
||||
import { RECOMMENDED_MCP_SERVERS } from '$lib/constants';
|
||||
import { conversationsStore } from '$lib/stores/conversations.svelte';
|
||||
import { mcpStore } from '$lib/stores/mcp.svelte';
|
||||
import { uuid } from '$lib/utils';
|
||||
import { MCP_SERVERS_ADDED_TO_CHAT_LOCALSTORAGE_KEY, MCP_SERVER_ID_PREFIX } from '$lib/constants';
|
||||
import type { MCPServerSettingsEntry } from '$lib/types';
|
||||
import { Plus } from '@lucide/svelte';
|
||||
|
||||
interface Props {
|
||||
open: boolean;
|
||||
onOpenChange?: (open: boolean) => void;
|
||||
}
|
||||
|
||||
let { open = $bindable(), onOpenChange }: Props = $props();
|
||||
|
||||
let selected = $state<Record<string, boolean>>(
|
||||
Object.fromEntries(RECOMMENDED_MCP_SERVERS.map((server) => [server.id, false]))
|
||||
);
|
||||
|
||||
let addedServers = $state<MCPServerSettingsEntry[]>([]);
|
||||
|
||||
let showAddForm = $state(false);
|
||||
let newServerUrl = $state('');
|
||||
let newServerHeaders = $state('');
|
||||
let newServerUrlError = $derived.by(() => {
|
||||
if (!newServerUrl.trim()) return 'URL is required';
|
||||
try {
|
||||
new URL(newServerUrl);
|
||||
|
||||
return null;
|
||||
} catch {
|
||||
return 'Invalid URL format';
|
||||
}
|
||||
});
|
||||
|
||||
function handleOpenChange(value: boolean) {
|
||||
if (!value) {
|
||||
showAddForm = false;
|
||||
newServerUrl = '';
|
||||
newServerHeaders = '';
|
||||
addedServers = [];
|
||||
|
||||
localStorage.setItem(MCP_SERVERS_ADDED_TO_CHAT_LOCALSTORAGE_KEY, 'true');
|
||||
}
|
||||
open = value;
|
||||
onOpenChange?.(value);
|
||||
}
|
||||
|
||||
function resetAddForm() {
|
||||
showAddForm = false;
|
||||
newServerUrl = '';
|
||||
newServerHeaders = '';
|
||||
}
|
||||
|
||||
function enableSelected() {
|
||||
localStorage.setItem(MCP_SERVERS_ADDED_TO_CHAT_LOCALSTORAGE_KEY, 'true');
|
||||
|
||||
for (const server of RECOMMENDED_MCP_SERVERS) {
|
||||
if (selected[server.id]) {
|
||||
const existing = mcpStore.getServerById(server.id);
|
||||
if (existing) {
|
||||
mcpStore.updateServer(server.id, { enabled: true });
|
||||
} else {
|
||||
mcpStore.addServer({
|
||||
id: server.id,
|
||||
enabled: true,
|
||||
url: server.url,
|
||||
name: server.name
|
||||
});
|
||||
}
|
||||
conversationsStore.setMcpServerOverride(server.id, true);
|
||||
}
|
||||
}
|
||||
handleOpenChange(false);
|
||||
}
|
||||
|
||||
function saveNewServer() {
|
||||
if (newServerUrlError) return;
|
||||
|
||||
const newServerId = uuid() ?? `${MCP_SERVER_ID_PREFIX}-${Date.now()}`;
|
||||
|
||||
localStorage.setItem(MCP_SERVERS_ADDED_TO_CHAT_LOCALSTORAGE_KEY, 'true');
|
||||
|
||||
const newServer = mcpStore.addServer({
|
||||
id: newServerId,
|
||||
enabled: true,
|
||||
url: newServerUrl.trim(),
|
||||
headers: newServerHeaders.trim() || undefined
|
||||
});
|
||||
|
||||
conversationsStore.setMcpServerOverride(newServerId, true);
|
||||
|
||||
if (newServer) {
|
||||
addedServers = [...addedServers, newServer];
|
||||
}
|
||||
|
||||
resetAddForm();
|
||||
}
|
||||
</script>
|
||||
|
||||
<Dialog.Root bind:open onOpenChange={handleOpenChange}>
|
||||
<Dialog.Content class="sm:max-w-lg">
|
||||
<Dialog.Header>
|
||||
<Dialog.Title>Do more with MCP</Dialog.Title>
|
||||
<Dialog.Description>
|
||||
Power-up your experience by adding tools, resources and more capabilities provided by MCP
|
||||
servers.
|
||||
</Dialog.Description>
|
||||
</Dialog.Header>
|
||||
|
||||
<div class="max-h-[60vh] space-y-4 overflow-y-auto py-4" in:fly={{ y: 16, duration: 300 }}>
|
||||
<h3 class="text-sm font-semibold">Quickly get started with</h3>
|
||||
|
||||
{#each RECOMMENDED_MCP_SERVERS as server (server.id)}
|
||||
<McpServerCardCompact
|
||||
{server}
|
||||
enabled={selected[server.id]}
|
||||
onToggle={(enabled) => (selected[server.id] = enabled)}
|
||||
/>
|
||||
{/each}
|
||||
|
||||
{#if addedServers.length > 0}
|
||||
{#each addedServers as server (server.id)}
|
||||
<McpServerCardCompact {server} enabled={true} />
|
||||
{/each}
|
||||
{/if}
|
||||
|
||||
{#if showAddForm}
|
||||
<Card.Root class="gap-3! bg-muted/30 p-4">
|
||||
<McpServerForm
|
||||
url={newServerUrl}
|
||||
headers={newServerHeaders}
|
||||
onUrlChange={(v) => (newServerUrl = v)}
|
||||
onHeadersChange={(v) => (newServerHeaders = v)}
|
||||
urlError={newServerUrl ? newServerUrlError : null}
|
||||
id="recommendation-new-server"
|
||||
/>
|
||||
|
||||
<div class="flex justify-end gap-2 pt-2">
|
||||
<Button variant="secondary" size="sm" onclick={resetAddForm}>Cancel</Button>
|
||||
|
||||
<Button
|
||||
variant="default"
|
||||
size="sm"
|
||||
onclick={saveNewServer}
|
||||
disabled={!!newServerUrlError}
|
||||
aria-label="Save"
|
||||
>
|
||||
Add
|
||||
</Button>
|
||||
</div>
|
||||
</Card.Root>
|
||||
{:else}
|
||||
<Card.Root class="gap-0 border-dashed bg-muted/30 p-0 transition-colors hover:bg-muted/50">
|
||||
<button
|
||||
type="button"
|
||||
class="flex w-full items-center justify-center gap-2 rounded-lg p-6 text-sm text-muted-foreground transition-colors hover:text-foreground"
|
||||
onclick={() => (showAddForm = true)}
|
||||
aria-label="Add your own MCP server"
|
||||
>
|
||||
<Plus class="h-4 w-4" />
|
||||
<span>Add your own server</span>
|
||||
</button>
|
||||
</Card.Root>
|
||||
{/if}
|
||||
</div>
|
||||
|
||||
<Dialog.Footer>
|
||||
<Button variant="secondary" size="sm" onclick={() => handleOpenChange(false)}>Not now</Button>
|
||||
|
||||
<Button variant="default" size="sm" onclick={enableSelected}>Add selected</Button>
|
||||
</Dialog.Footer>
|
||||
</Dialog.Content>
|
||||
</Dialog.Root>
|
||||
@@ -18,6 +18,15 @@
|
||||
*/
|
||||
export { default as DialogMcpServerAddNew } from './DialogMcpServerAddNew.svelte';
|
||||
|
||||
/**
|
||||
* **DialogMcpServerRecommendations** - Suggested MCP servers opt-in dialog
|
||||
*
|
||||
* Prompts the user to enable pre-defined recommended MCP servers on first launch.
|
||||
* Shows one switch per suggested server and persists the choice as a per-chat
|
||||
* override so the selected servers become available in conversations.
|
||||
*/
|
||||
export { default as DialogMcpServerRecommendations } from './DialogMcpServerRecommendations.svelte';
|
||||
|
||||
/**
|
||||
* **DialogExportSettings** - Settings export dialog with sensitive data warning
|
||||
*
|
||||
|
||||
@@ -1,4 +1,5 @@
|
||||
<script lang="ts">
|
||||
import { tick } from 'svelte';
|
||||
import { Plus, Trash2 } from '@lucide/svelte';
|
||||
import { Input } from '$lib/components/ui/input';
|
||||
import {
|
||||
@@ -33,8 +34,18 @@
|
||||
sectionLabelOptional = true
|
||||
}: Props = $props();
|
||||
|
||||
function addPair() {
|
||||
// Pre-allocate the ref array so `bind:ref={keyInputRefs[index]}` never reads `undefined`
|
||||
// for in-range indices; the $effect below keeps it in sync when `pairs` grows.
|
||||
// svelte-ignore state_referenced_locally
|
||||
let keyInputRefs: (HTMLInputElement | null)[] = $state(pairs.map(() => null));
|
||||
|
||||
async function addPair() {
|
||||
// Capture the target index before mutating so deletions earlier in the
|
||||
// list can't make keyInputRefs.length drift past the newly-appended row.
|
||||
const newIndex = pairs.length;
|
||||
onPairsChange([...pairs, { key: '', value: '' }]);
|
||||
await tick();
|
||||
keyInputRefs[newIndex]?.focus();
|
||||
}
|
||||
|
||||
function removePair(index: number) {
|
||||
@@ -76,6 +87,15 @@
|
||||
newPairs[index] = { ...newPairs[index], value: trimmed };
|
||||
onPairsChange(newPairs);
|
||||
}
|
||||
|
||||
// Keep keyInputRefs aligned with pairs length so bind:ref never sees `undefined`.
|
||||
// $effect.pre runs during traversal in tree order, before the {#each} block re-renders,
|
||||
// so newly-appended items always have a defined slot when their binding is set up.
|
||||
$effect.pre(() => {
|
||||
while (keyInputRefs.length < pairs.length) {
|
||||
keyInputRefs.push(null);
|
||||
}
|
||||
});
|
||||
</script>
|
||||
|
||||
<div class={className}>
|
||||
@@ -103,6 +123,7 @@
|
||||
{#each pairs as pair, index (index)}
|
||||
<div class="flex items-start gap-2">
|
||||
<Input
|
||||
bind:ref={keyInputRefs[index]}
|
||||
type="text"
|
||||
placeholder={keyPlaceholder}
|
||||
value={pair.key}
|
||||
|
||||
@@ -163,7 +163,7 @@
|
||||
{/if}
|
||||
</div>
|
||||
|
||||
<div class="flex justify-between gap-4">
|
||||
<div class="mt-auto flex justify-between gap-4">
|
||||
{#if showSkeleton}
|
||||
<Skeleton class="h-3 w-28" />
|
||||
{:else if protocolVersion}
|
||||
|
||||
@@ -0,0 +1,156 @@
|
||||
<script lang="ts">
|
||||
import * as Card from '$lib/components/ui/card';
|
||||
import { Badge } from '$lib/components/ui/badge';
|
||||
import { Skeleton } from '$lib/components/ui/skeleton';
|
||||
import { Switch } from '$lib/components/ui/switch';
|
||||
import * as Tooltip from '$lib/components/ui/tooltip';
|
||||
import { McpServerIdentity } from '$lib/components/app/mcp';
|
||||
import { mcpStore } from '$lib/stores/mcp.svelte';
|
||||
import { HealthCheckStatus } from '$lib/enums';
|
||||
import type { MCPServerDisplayInfo, HealthCheckState, MCPServerSettingsEntry } from '$lib/types';
|
||||
import { onMount } from 'svelte';
|
||||
import { MCP_CARD_VISIBLE_TOOL_LIMIT, NEWLINE } from '$lib/constants';
|
||||
|
||||
interface Props {
|
||||
server: MCPServerDisplayInfo & { description?: string };
|
||||
enabled?: boolean;
|
||||
onToggle?: (enabled: boolean) => void;
|
||||
}
|
||||
|
||||
let { server, enabled = false, onToggle }: Props = $props();
|
||||
|
||||
onMount(() => {
|
||||
const state = mcpStore.getHealthCheckState(server.id);
|
||||
|
||||
if (state.status === HealthCheckStatus.IDLE) {
|
||||
mcpStore.runHealthCheck(server as MCPServerSettingsEntry).catch(() => {});
|
||||
}
|
||||
});
|
||||
|
||||
let healthState = $derived<HealthCheckState>(mcpStore.getHealthCheckState(server.id));
|
||||
let displayName = $derived(mcpStore.getServerLabel(server));
|
||||
let faviconUrl = $derived(mcpStore.getServerFavicon(server.id));
|
||||
let isIdle = $derived(healthState.status === HealthCheckStatus.IDLE);
|
||||
let isHealthChecking = $derived(healthState.status === HealthCheckStatus.CONNECTING);
|
||||
let isError = $derived(healthState.status === HealthCheckStatus.ERROR);
|
||||
let errorMessage = $derived(
|
||||
healthState.status === HealthCheckStatus.ERROR ? healthState.message : undefined
|
||||
);
|
||||
let serverInfo = $derived(
|
||||
healthState.status === HealthCheckStatus.SUCCESS ? healthState.serverInfo : undefined
|
||||
);
|
||||
let tools = $derived(healthState.status === HealthCheckStatus.SUCCESS ? healthState.tools : []);
|
||||
let instructions = $derived(
|
||||
healthState.status === HealthCheckStatus.SUCCESS ? healthState.instructions : undefined
|
||||
);
|
||||
let showSkeleton = $derived(isIdle || isHealthChecking);
|
||||
|
||||
// Curated descriptions get two lines; instructions fallback is one line so the
|
||||
// compact card stays scannable.
|
||||
let description = $derived.by(() => {
|
||||
if (server.description) {
|
||||
return { text: server.description, lines: 2 };
|
||||
}
|
||||
if (!instructions) return null;
|
||||
const firstLine = instructions.split(NEWLINE).find((line: string) => line.trim().length > 0);
|
||||
const trimmed = firstLine?.trim();
|
||||
return trimmed ? { text: trimmed, lines: 1 } : null;
|
||||
});
|
||||
|
||||
let visibleTools = $derived(tools.slice(0, MCP_CARD_VISIBLE_TOOL_LIMIT));
|
||||
let hiddenTools = $derived(tools.slice(MCP_CARD_VISIBLE_TOOL_LIMIT));
|
||||
let hiddenToolCount = $derived(hiddenTools.length);
|
||||
|
||||
function handleToggle(checked: boolean) {
|
||||
onToggle?.(checked);
|
||||
}
|
||||
</script>
|
||||
|
||||
<Card.Root class="!gap-3 bg-muted/30 p-4">
|
||||
<div class="flex items-start justify-between gap-3">
|
||||
<div class="min-w-0 flex-1">
|
||||
{#if showSkeleton}
|
||||
<span class="flex min-w-0 items-center gap-1.5">
|
||||
<Skeleton class="h-5 w-5 rounded" />
|
||||
<Skeleton class="h-4 w-32" />
|
||||
</span>
|
||||
{:else}
|
||||
<McpServerIdentity
|
||||
{displayName}
|
||||
{faviconUrl}
|
||||
{serverInfo}
|
||||
iconClass="h-5 w-5"
|
||||
iconRounded="rounded"
|
||||
nameClass="font-medium"
|
||||
/>
|
||||
{/if}
|
||||
</div>
|
||||
|
||||
<Switch checked={enabled} disabled={isError || showSkeleton} onCheckedChange={handleToggle} />
|
||||
</div>
|
||||
|
||||
{#if isError && errorMessage}
|
||||
<p class="text-xs text-destructive">{errorMessage}</p>
|
||||
{/if}
|
||||
|
||||
{#if showSkeleton}
|
||||
<div class="space-y-1.5">
|
||||
<Skeleton class="h-3 w-full max-w-md" />
|
||||
</div>
|
||||
|
||||
<div class="flex flex-wrap items-center gap-1.5">
|
||||
<Skeleton class="h-5 w-16 rounded-full" />
|
||||
<Skeleton class="h-5 w-20 rounded-full" />
|
||||
<Skeleton class="h-5 w-24 rounded-full" />
|
||||
<Skeleton class="h-5 w-14 rounded-full" />
|
||||
</div>
|
||||
{:else}
|
||||
{#if description}
|
||||
{#if description.lines === 2}
|
||||
<p class="line-clamp-2 text-xs text-muted-foreground" title={description.text}>
|
||||
{description.text}
|
||||
</p>
|
||||
{:else}
|
||||
<p class="line-clamp-1 truncate text-xs text-muted-foreground" title={description.text}>
|
||||
{description.text}
|
||||
</p>
|
||||
{/if}
|
||||
{/if}
|
||||
|
||||
{#if tools.length > 0}
|
||||
<div class="flex flex-wrap items-center gap-1.5">
|
||||
{#each visibleTools as tool (tool.name)}
|
||||
<Tooltip.Root>
|
||||
<Tooltip.Trigger>
|
||||
<Badge variant="secondary" class="h-5 max-w-40 px-2 text-[11px]">
|
||||
<span class="block min-w-0 flex-1 truncate">{tool.name}</span>
|
||||
</Badge>
|
||||
</Tooltip.Trigger>
|
||||
|
||||
<Tooltip.Content>
|
||||
<p class="max-w-xs text-xs">
|
||||
{tool.description ?? 'No description'}
|
||||
</p>
|
||||
</Tooltip.Content>
|
||||
</Tooltip.Root>
|
||||
{/each}
|
||||
|
||||
{#if hiddenToolCount > 0}
|
||||
<Tooltip.Root>
|
||||
<Tooltip.Trigger>
|
||||
<Badge variant="secondary" class="h-5 px-2 text-[11px] text-muted-foreground">
|
||||
+ {hiddenToolCount} more tools
|
||||
</Badge>
|
||||
</Tooltip.Trigger>
|
||||
|
||||
<Tooltip.Content class="max-w-md">
|
||||
<p class="text-xs">
|
||||
{hiddenTools.map((tool) => tool.name).join(', ')}
|
||||
</p>
|
||||
</Tooltip.Content>
|
||||
</Tooltip.Root>
|
||||
{/if}
|
||||
</div>
|
||||
{/if}
|
||||
{/if}
|
||||
</Card.Root>
|
||||
@@ -1,6 +1,7 @@
|
||||
<script lang="ts">
|
||||
import { Button } from '$lib/components/ui/button';
|
||||
import { McpServerForm } from '$lib/components/app/mcp';
|
||||
import { parseHeadersToArray } from '$lib/utils';
|
||||
|
||||
interface Props {
|
||||
serverId: string;
|
||||
@@ -26,13 +27,21 @@
|
||||
}
|
||||
});
|
||||
|
||||
let canSave = $derived(!urlError);
|
||||
let headerPairsValid = $derived(
|
||||
parseHeadersToArray(editHeaders).every((p) => p.key.trim() && p.value.trim())
|
||||
);
|
||||
let canSave = $derived(!urlError && headerPairsValid);
|
||||
|
||||
function handleSave() {
|
||||
if (!canSave) return;
|
||||
onSave(editUrl.trim(), editHeaders.trim(), editUseProxy);
|
||||
}
|
||||
|
||||
function handleSubmit(event: SubmitEvent) {
|
||||
event.preventDefault();
|
||||
handleSave();
|
||||
}
|
||||
|
||||
export function setInitialValues(url: string, headers: string, useProxy: boolean) {
|
||||
editUrl = url;
|
||||
editHeaders = headers;
|
||||
@@ -40,25 +49,27 @@
|
||||
}
|
||||
</script>
|
||||
|
||||
<div class="space-y-4">
|
||||
<p class="font-medium">Configure Server</p>
|
||||
<form onsubmit={handleSubmit} class="contents">
|
||||
<div class="space-y-4">
|
||||
<p class="font-medium">Configure Server</p>
|
||||
|
||||
<McpServerForm
|
||||
url={editUrl}
|
||||
headers={editHeaders}
|
||||
useProxy={editUseProxy}
|
||||
onUrlChange={(v) => (editUrl = v)}
|
||||
onHeadersChange={(v) => (editHeaders = v)}
|
||||
onUseProxyChange={(v) => (editUseProxy = v)}
|
||||
urlError={editUrl ? urlError : null}
|
||||
id={serverId}
|
||||
/>
|
||||
<McpServerForm
|
||||
url={editUrl}
|
||||
headers={editHeaders}
|
||||
useProxy={editUseProxy}
|
||||
onUrlChange={(v) => (editUrl = v)}
|
||||
onHeadersChange={(v) => (editHeaders = v)}
|
||||
onUseProxyChange={(v) => (editUseProxy = v)}
|
||||
urlError={editUrl ? urlError : null}
|
||||
id={serverId}
|
||||
/>
|
||||
|
||||
<div class="flex items-center justify-end gap-2">
|
||||
<Button variant="secondary" size="sm" onclick={onCancel}>Cancel</Button>
|
||||
<div class="flex items-center justify-end gap-2">
|
||||
<Button variant="secondary" size="sm" onclick={onCancel}>Cancel</Button>
|
||||
|
||||
<Button size="sm" onclick={handleSave} disabled={!canSave}>
|
||||
{serverUrl.trim() ? 'Update' : 'Add'}
|
||||
</Button>
|
||||
<Button size="sm" type="submit" disabled={!canSave}>
|
||||
{serverUrl.trim() ? 'Update' : 'Add'}
|
||||
</Button>
|
||||
</div>
|
||||
</div>
|
||||
</div>
|
||||
</form>
|
||||
|
||||
@@ -38,14 +38,87 @@
|
||||
|
||||
let headerPairs = $derived<KeyValuePair[]>(parseHeadersToArray(headers));
|
||||
|
||||
const AUTHORIZATION_HEADER = 'Authorization';
|
||||
const BEARER_PREFIX = 'Bearer ';
|
||||
|
||||
// Heuristic: this dedicated UI only owns Authorization headers that already
|
||||
// carry a Bearer scheme. Anything else (e.g. Basic, raw tokens) stays in the
|
||||
// KV section so the user can still edit those values verbatim.
|
||||
const matchesAuthorizationKey = (key: string): boolean =>
|
||||
key.trim().toLowerCase() === AUTHORIZATION_HEADER.toLowerCase();
|
||||
|
||||
const isBearerScheme = (value: string): boolean =>
|
||||
value.trim().toLowerCase().startsWith(BEARER_PREFIX.toLowerCase());
|
||||
|
||||
const ownedByBearerUi = (p: KeyValuePair): boolean =>
|
||||
matchesAuthorizationKey(p.key) && isBearerScheme(p.value);
|
||||
|
||||
let hasAuthorization = $derived(headerPairs.some(ownedByBearerUi));
|
||||
|
||||
let wantsAuthorization = $state(false);
|
||||
|
||||
let showAuthorization = $derived(hasAuthorization || wantsAuthorization);
|
||||
|
||||
let urlInput: HTMLInputElement | null = $state(null);
|
||||
let bearerInput: HTMLInputElement | null = $state(null);
|
||||
|
||||
$effect(() => {
|
||||
urlInput?.focus();
|
||||
});
|
||||
|
||||
$effect(() => {
|
||||
if (wantsAuthorization && bearerInput) {
|
||||
bearerInput.focus();
|
||||
}
|
||||
});
|
||||
|
||||
let bearerToken = $derived.by(() => {
|
||||
const auth = headerPairs.find(ownedByBearerUi);
|
||||
if (!auth) return '';
|
||||
return auth.value.trim().slice(BEARER_PREFIX.length).trim();
|
||||
});
|
||||
|
||||
$effect(() => {
|
||||
if (!headers.trim()) {
|
||||
wantsAuthorization = false;
|
||||
}
|
||||
});
|
||||
|
||||
function updateHeaderPairs(newPairs: KeyValuePair[]) {
|
||||
headerPairs = newPairs;
|
||||
onHeadersChange(serializeHeaders(newPairs));
|
||||
}
|
||||
|
||||
// The dedicated UI owns the Authorization slot end-to-end when the user
|
||||
// engages it: any prior Authorization row (Bearer or otherwise) is replaced
|
||||
// by exactly one { Authorization: "Bearer <token>" } entry. JSON's last-key
|
||||
// behavior would otherwise pick one arbitrarily, so we strip first.
|
||||
function updateBearerToken(token: string) {
|
||||
const filtered = headerPairs.filter((p) => !matchesAuthorizationKey(p.key));
|
||||
|
||||
const trimmed = token.trim();
|
||||
|
||||
if (trimmed) {
|
||||
filtered.push({ key: AUTHORIZATION_HEADER, value: `${BEARER_PREFIX}${trimmed}` });
|
||||
}
|
||||
|
||||
updateHeaderPairs(filtered);
|
||||
}
|
||||
|
||||
function setUseAuthorization(checked: boolean) {
|
||||
wantsAuthorization = checked;
|
||||
|
||||
if (!checked) {
|
||||
// Only drop the entry this UI owns; a non-Bearer Authorization row
|
||||
// authored in the KV section must survive a toggle off untouched.
|
||||
const filtered = headerPairs.filter((p) => !ownedByBearerUi(p));
|
||||
updateHeaderPairs(filtered);
|
||||
}
|
||||
}
|
||||
</script>
|
||||
|
||||
<div class="grid gap-3">
|
||||
<div>
|
||||
<div class="grid gap-2">
|
||||
<div class="mb-4">
|
||||
<label for="server-url-{id}" class="mb-2 block text-xs font-medium">
|
||||
Server URL <span class="text-destructive">*</span>
|
||||
</label>
|
||||
@@ -57,50 +130,52 @@
|
||||
value={url}
|
||||
oninput={(e) => onUrlChange(e.currentTarget.value)}
|
||||
class={urlError ? 'border-destructive' : ''}
|
||||
bind:ref={urlInput}
|
||||
/>
|
||||
|
||||
{#if urlError}
|
||||
<p class="mt-1.5 text-xs text-destructive">{urlError}</p>
|
||||
{/if}
|
||||
|
||||
{#if !isWebSocket && onUseProxyChange}
|
||||
<label
|
||||
class={[
|
||||
'mt-3 flex items-start gap-2',
|
||||
mcpStore.isProxyAvailable && 'cursor-pointer',
|
||||
!mcpStore.isProxyAvailable && 'opacity-80'
|
||||
]}
|
||||
>
|
||||
<Switch
|
||||
class="mt-1"
|
||||
id="use-proxy-{id}"
|
||||
checked={useProxy}
|
||||
disabled={!mcpStore.isProxyAvailable}
|
||||
onCheckedChange={(checked) => onUseProxyChange?.(checked)}
|
||||
/>
|
||||
|
||||
<span>
|
||||
<span class="text-xs text-muted-foreground">Use llama-server proxy</span>
|
||||
|
||||
<br />
|
||||
|
||||
{#if !mcpStore.isProxyAvailable}
|
||||
<span class="inline-flex gap-0.75 text-xs text-muted-foreground/60"
|
||||
>(Run <pre>llama-server</pre>
|
||||
with
|
||||
<pre>{CLI_FLAGS.MCP_PROXY}</pre>
|
||||
flag)</span
|
||||
>
|
||||
{/if}
|
||||
</span>
|
||||
</label>
|
||||
{/if}
|
||||
</div>
|
||||
|
||||
<label class="flex items-center gap-2 cursor-pointer">
|
||||
<Switch
|
||||
id="use-authorization-{id}"
|
||||
checked={showAuthorization}
|
||||
onCheckedChange={setUseAuthorization}
|
||||
/>
|
||||
|
||||
<span class="text-xs text-muted-foreground">Authorization</span>
|
||||
</label>
|
||||
|
||||
{#if showAuthorization}
|
||||
<div class="relative mt-2">
|
||||
<Input
|
||||
id="bearer-token-{id}"
|
||||
type="password"
|
||||
autocomplete="off"
|
||||
placeholder="Paste token here"
|
||||
value={bearerToken}
|
||||
oninput={(e) => updateBearerToken(e.currentTarget.value)}
|
||||
class="pl-16"
|
||||
bind:ref={bearerInput}
|
||||
/>
|
||||
|
||||
<span
|
||||
class="pointer-events-none absolute inset-y-0 left-3 flex items-center text-sm font-medium text-foreground"
|
||||
>
|
||||
Bearer
|
||||
</span>
|
||||
</div>
|
||||
{/if}
|
||||
|
||||
<KeyValuePairs
|
||||
class="mt-2"
|
||||
pairs={headerPairs}
|
||||
onPairsChange={updateHeaderPairs}
|
||||
class="mt-3"
|
||||
pairs={headerPairs.filter((p) => !ownedByBearerUi(p))}
|
||||
onPairsChange={(pairs) => {
|
||||
const auth = headerPairs.find(ownedByBearerUi);
|
||||
updateHeaderPairs(auth ? [...pairs, auth] : pairs);
|
||||
}}
|
||||
keyPlaceholder="Header name"
|
||||
valuePlaceholder="Value"
|
||||
addButtonLabel="Add"
|
||||
@@ -108,4 +183,37 @@
|
||||
sectionLabel="Custom Headers"
|
||||
sectionLabelOptional
|
||||
/>
|
||||
|
||||
{#if !isWebSocket && onUseProxyChange}
|
||||
<label
|
||||
class={[
|
||||
'mt-3 flex items-start gap-2',
|
||||
mcpStore.isProxyAvailable && 'cursor-pointer',
|
||||
!mcpStore.isProxyAvailable && 'opacity-80'
|
||||
]}
|
||||
>
|
||||
<Switch
|
||||
class="mt-1"
|
||||
id="use-proxy-{id}"
|
||||
checked={useProxy}
|
||||
disabled={!mcpStore.isProxyAvailable}
|
||||
onCheckedChange={(checked) => onUseProxyChange?.(checked)}
|
||||
/>
|
||||
|
||||
<span>
|
||||
<span class="text-xs text-muted-foreground">Use llama-server proxy</span>
|
||||
|
||||
<br />
|
||||
|
||||
{#if !mcpStore.isProxyAvailable}
|
||||
<span class="inline-flex gap-0.75 text-xs text-muted-foreground/60"
|
||||
>(Run <pre>llama-server</pre>
|
||||
with
|
||||
<pre>{CLI_FLAGS.MCP_PROXY}</pre>
|
||||
flag)</span
|
||||
>
|
||||
{/if}
|
||||
</span>
|
||||
</label>
|
||||
{/if}
|
||||
</div>
|
||||
|
||||
@@ -1,6 +1,7 @@
|
||||
<script lang="ts">
|
||||
import { ExternalLink } from '@lucide/svelte';
|
||||
import { Badge } from '$lib/components/ui/badge';
|
||||
import { McpLogo } from '$lib/components/app/mcp';
|
||||
import { TruncatedText } from '$lib/components/app/misc';
|
||||
import { sanitizeExternalUrl } from '$lib/utils';
|
||||
import type { MCPServerInfo } from '$lib/types';
|
||||
@@ -34,20 +35,15 @@
|
||||
|
||||
<span class="flex min-w-0 items-center gap-1.5">
|
||||
{#if faviconUrl}
|
||||
<img
|
||||
src={faviconUrl}
|
||||
alt=""
|
||||
class={['shrink-0', iconRounded, iconClass]}
|
||||
onerror={(e) => {
|
||||
(e.currentTarget as HTMLImageElement).style.display = 'none';
|
||||
}}
|
||||
/>
|
||||
<img src={faviconUrl} alt="" class={['shrink-0 text-foreground', iconRounded, iconClass]} />
|
||||
{:else}
|
||||
<McpLogo class={['shrink-0 text-foreground', iconRounded, iconClass].join(' ')} />
|
||||
{/if}
|
||||
|
||||
<TruncatedText text={displayName ?? ''} class={nameClass ?? ''} />
|
||||
|
||||
{#if showVersion && serverInfo?.version}
|
||||
<Badge variant="secondary" class="h-4 min-w-0 shrink px-1 text-[10px]">
|
||||
<Badge variant="secondary" class="h-4 max-w-24 min-w-0 shrink px-1 text-[10px]">
|
||||
<TruncatedText text={`v${serverInfo.version}`} />
|
||||
</Badge>
|
||||
{/if}
|
||||
|
||||
@@ -180,6 +180,16 @@ export { default as McpServerCardDeleteDialog } from './McpServerCard/McpServerC
|
||||
/** Skeleton loading state for server card during health checks. */
|
||||
export { default as McpServerCardSkeleton } from './McpServerCardSkeleton.svelte';
|
||||
|
||||
/**
|
||||
* **McpServerCardCompact** - Condensed MCP server card
|
||||
*
|
||||
* Compact alternative to McpServerCard tailored for picker-style UIs.
|
||||
* Shows the server identity, status, and a flex-wrapped list of available tools.
|
||||
* Tool names are rendered as badges; hovering a badge shows its description in a tooltip.
|
||||
* Does not show connection logs or server instructions.
|
||||
*/
|
||||
export { default as McpServerCardCompact } from './McpServerCard/McpServerCardCompact.svelte';
|
||||
|
||||
/**
|
||||
* **McpServerIdentity** - Server identity display (icon, name, version)
|
||||
*
|
||||
|
||||
@@ -21,7 +21,7 @@
|
||||
|
||||
let { class: className }: Props = $props();
|
||||
|
||||
let servers = $derived(mcpStore.getServersSorted());
|
||||
let servers = $derived(mcpStore.visibleMcpServers);
|
||||
|
||||
let initialLoadComplete = $state(false);
|
||||
let isAddingServer = $state(false);
|
||||
|
||||
@@ -8,6 +8,7 @@ export * from './attachment-labels';
|
||||
export * from './database';
|
||||
export * from './reasoning-effort';
|
||||
export * from './reasoning-effort-tokens';
|
||||
export * from './recommended-mcp-servers';
|
||||
export * from './storage';
|
||||
export * from './attachment-menu';
|
||||
export * from './auto-scroll';
|
||||
|
||||
@@ -1,2 +1,4 @@
|
||||
export const MCP_SERVER_URL_PLACEHOLDER = 'https://mcp.example.com/sse';
|
||||
export const MIN_AUTOCOMPLETE_INPUT_LENGTH = 1;
|
||||
/** Number of tools shown on the compact MCP server card before collapsing to a "+ N more" badge */
|
||||
export const MCP_CARD_VISIBLE_TOOL_LIMIT = 4;
|
||||
|
||||
@@ -0,0 +1,35 @@
|
||||
import { DEFAULT_MCP_CONFIG } from './mcp';
|
||||
import type { RecommendedMCPServer } from '$lib/types';
|
||||
|
||||
/**
 * Fields shared by every recommended server entry.
 *
 * Spread LAST in each entry so the resulting key order stays
 * id, name, description, url, enabled, requestTimeoutSeconds.
 */
const RECOMMENDED_SERVER_DEFAULTS = {
	enabled: true,
	requestTimeoutSeconds: DEFAULT_MCP_CONFIG.requestTimeoutSeconds
};

/**
 * Pre-defined recommended MCP servers.
 *
 * Each entry is marked enabled, but a server is not turned on for an
 * individual conversation until the user explicitly enables it there
 * (so its tools are disabled by default).
 */
export const RECOMMENDED_MCP_SERVERS: RecommendedMCPServer[] = [
	{
		id: 'exa-web-search',
		name: 'Exa Web Search',
		description: 'Search the web and retrieve relevant content.',
		url: 'https://mcp.exa.ai/mcp',
		...RECOMMENDED_SERVER_DEFAULTS
	},
	{
		id: 'huggingface-mcp',
		name: 'Hugging Face',
		description:
			'Browse models, datasets, spaces and machine learning papers from the Hugging Face hub.',
		url: 'https://huggingface.co/mcp',
		...RECOMMENDED_SERVER_DEFAULTS
	}
];

/** IDs of the recommended servers, for constant-time membership checks. */
export const RECOMMENDED_MCP_SERVER_IDS = new Set(RECOMMENDED_MCP_SERVERS.map(({ id }) => id));

/** Numeric delay before the recommendations opt-in dialog is opened. */
export const RECOMMENDED_MCP_SERVERS_OPTIN_DIALOG_DELAY = 1000;
|
||||
@@ -59,6 +59,7 @@ export const SETTINGS_KEYS = {
|
||||
// MCP
|
||||
MCP_SERVERS: 'mcpServers',
|
||||
MCP_REQUEST_TIMEOUT_SECONDS: 'mcpRequestTimeoutSeconds',
|
||||
MCP_DEFAULT_SERVER_OVERRIDES: 'mcpDefaultServerOverrides',
|
||||
AGENTIC_MAX_TURNS: 'agenticMaxTurns',
|
||||
ALWAYS_SHOW_AGENTIC_TURNS: 'alwaysShowAgenticTurns',
|
||||
AGENTIC_MAX_TOOL_PREVIEW_LINES: 'agenticMaxToolPreviewLines',
|
||||
|
||||
@@ -28,6 +28,7 @@ import McpLogo from '$lib/components/app/mcp/McpLogo.svelte';
|
||||
import { SETTINGS_KEYS } from './settings-keys';
|
||||
import { ROUTES, SETTINGS_SECTION_SLUGS } from './routes';
|
||||
import { TITLE_GENERATION } from './title-generation';
|
||||
import { RECOMMENDED_MCP_SERVERS } from './recommended-mcp-servers';
|
||||
|
||||
export const SETTINGS_SECTION_TITLES = {
|
||||
GENERAL: 'General',
|
||||
@@ -774,9 +775,16 @@ const NON_UI_SETTINGS: SettingsEntry[] = [
|
||||
key: SETTINGS_KEYS.MCP_SERVERS,
|
||||
label: 'MCP servers',
|
||||
help: 'Configure MCP servers as a JSON list. Use the form in the MCP Client settings section to edit.',
|
||||
defaultValue: '[]',
|
||||
defaultValue: JSON.stringify(RECOMMENDED_MCP_SERVERS),
|
||||
type: SettingsFieldType.INPUT,
|
||||
sync: { serverKey: SETTINGS_KEYS.MCP_SERVERS, paramType: SyncableParameterType.STRING }
|
||||
},
|
||||
{
|
||||
key: SETTINGS_KEYS.MCP_DEFAULT_SERVER_OVERRIDES,
|
||||
label: 'MCP default server overrides',
|
||||
help: 'Per-server enable/disable defaults inherited by new chats. JSON-serialized list of {serverId, enabled} entries.',
|
||||
defaultValue: '[]',
|
||||
type: SettingsFieldType.INPUT
|
||||
}
|
||||
// {
|
||||
// key: SETTINGS_KEYS.PY_INTERPRETER_ENABLED,
|
||||
|
||||
@@ -21,9 +21,10 @@ export const DISABLED_TOOLS_LOCALSTORAGE_KEY = `${STORAGE_APP_NAME}.disabledTool
|
||||
/** Disabled tools keyed by stable selection identity, no migration from the name based key */
|
||||
export const DISABLED_TOOL_KEYS_LOCALSTORAGE_KEY = `${STORAGE_APP_NAME}.disabledToolKeys`;
|
||||
export const FAVORITE_MODELS_LOCALSTORAGE_KEY = `${STORAGE_APP_NAME}.favoriteModels`;
|
||||
export const MCP_DEFAULT_ENABLED_LOCALSTORAGE_KEY = `${STORAGE_APP_NAME}.mcpDefaultEnabled`;
|
||||
export const THINKING_ENABLED_DEFAULT_LOCALSTORAGE_KEY = `${STORAGE_APP_NAME}.thinkingEnabledDefault`;
|
||||
export const REASONING_EFFORT_DEFAULT_LOCALSTORAGE_KEY = `${STORAGE_APP_NAME}.reasoningEffortDefault`;
|
||||
/** Set when user has interacted with the MCP server recommendations dialog (checked servers, added custom server, or dismissed) */
|
||||
export const MCP_SERVERS_ADDED_TO_CHAT_LOCALSTORAGE_KEY = `${STORAGE_APP_NAME}.mcpServersSetupDone`;
|
||||
export const USER_OVERRIDES_LOCALSTORAGE_KEY = `${STORAGE_APP_NAME}.userOverrides`;
|
||||
|
||||
/** Key prefix for per-conversation resumable stream state, conversationId is appended */
|
||||
@@ -38,8 +39,6 @@ export const DEPRECATED_CONFIG_LOCALSTORAGE_KEY = `${STORAGE_APP_NAME_DEPRECATED
|
||||
export const DEPRECATED_DISABLED_TOOLS_LOCALSTORAGE_KEY = `${STORAGE_APP_NAME_DEPRECATED}.disabledTools`;
|
||||
/** @deprecated Use {@link FAVORITE_MODELS_LOCALSTORAGE_KEY} instead */
|
||||
export const DEPRECATED_FAVORITE_MODELS_LOCALSTORAGE_KEY = `${STORAGE_APP_NAME_DEPRECATED}.favoriteModels`;
|
||||
/** @deprecated Use {@link MCP_DEFAULT_ENABLED_LOCALSTORAGE_KEY} instead */
|
||||
export const DEPRECATED_MCP_DEFAULT_ENABLED_LOCALSTORAGE_KEY = `${STORAGE_APP_NAME_DEPRECATED}.mcpDefaultEnabled`;
|
||||
/** @deprecated Use {@link USER_OVERRIDES_LOCALSTORAGE_KEY} instead */
|
||||
export const DEPRECATED_USER_OVERRIDES_LOCALSTORAGE_KEY = `${STORAGE_APP_NAME_DEPRECATED}.userOverrides`;
|
||||
|
||||
@@ -52,6 +51,5 @@ export const NEW_TO_DEPRECATED_MAP: Record<string, string> = {
|
||||
[CONFIG_LOCALSTORAGE_KEY]: DEPRECATED_CONFIG_LOCALSTORAGE_KEY,
|
||||
[DISABLED_TOOLS_LOCALSTORAGE_KEY]: DEPRECATED_DISABLED_TOOLS_LOCALSTORAGE_KEY,
|
||||
[FAVORITE_MODELS_LOCALSTORAGE_KEY]: DEPRECATED_FAVORITE_MODELS_LOCALSTORAGE_KEY,
|
||||
[MCP_DEFAULT_ENABLED_LOCALSTORAGE_KEY]: DEPRECATED_MCP_DEFAULT_ENABLED_LOCALSTORAGE_KEY,
|
||||
[USER_OVERRIDES_LOCALSTORAGE_KEY]: DEPRECATED_USER_OVERRIDES_LOCALSTORAGE_KEY
|
||||
};
|
||||
|
||||
@@ -0,0 +1,85 @@
|
||||
import { browser } from '$app/environment';
|
||||
import {
|
||||
MCP_SERVERS_ADDED_TO_CHAT_LOCALSTORAGE_KEY,
|
||||
RECOMMENDED_MCP_SERVER_IDS,
|
||||
RECOMMENDED_MCP_SERVERS_OPTIN_DIALOG_DELAY
|
||||
} from '$lib/constants';
|
||||
import { mcpStore } from '$lib/stores/mcp.svelte';
|
||||
|
||||
/**
 * First-run opt-in dialog for the recommended MCP servers.
 *
 * Owns the dismissed / open / trigger-timeout state and the effect that
 * schedules the dialog. Reads opt-in status and the configured server list
 * from `mcpStore`, so callers don't need to recompute on their side.
 *
 * @returns reactive `open` / `dismissed` getters plus the `dismiss` and
 *          `handleOpenChange` callbacks to wire into the dialog.
 */
export function useMcpRecommendations() {
	let dismissed = $state(
		browser && localStorage.getItem(MCP_SERVERS_ADDED_TO_CHAT_LOCALSTORAGE_KEY) === 'true'
	);
	let open = $state(false);
	let checked = $state(false);
	let triggerTimeout: ReturnType<typeof setTimeout> | null = null;

	/** Cancel the pending dialog trigger, if one is scheduled. */
	function clearTrigger() {
		if (triggerTimeout) {
			clearTimeout(triggerTimeout);
			triggerTimeout = null;
		}
	}

	/** Persist the "already handled" flag, close the dialog and drop any pending trigger. */
	function dismiss() {
		if (browser) {
			localStorage.setItem(MCP_SERVERS_ADDED_TO_CHAT_LOCALSTORAGE_KEY, 'true');
		}
		dismissed = true;
		open = false;
		clearTrigger();
	}

	/** Dialog open-state callback; closing the dialog counts as dismissing it. */
	function handleOpenChange(next: boolean) {
		open = next;
		if (!next) dismiss();
	}

	$effect(() => {
		if (!browser) return;

		if (open || dismissed) {
			clearTrigger();
			return;
		}

		// Already evaluated once this session; leave any pending trigger alone so
		// it can still fire later. Setting `checked = true` below re-runs this
		// effect, and we must not wipe the timeout that was just scheduled.
		if (checked) return;

		if (mcpStore.optedInRecommendationIds.size > 0) {
			checked = true;
			return;
		}

		const hasRecommendations = mcpStore
			.getServers()
			.some((server) => RECOMMENDED_MCP_SERVER_IDS.has(server.id));

		if (hasRecommendations) {
			triggerTimeout = setTimeout(() => {
				open = true;
			}, RECOMMENDED_MCP_SERVERS_OPTIN_DIALOG_DELAY);
		}

		checked = true;
	});

	// Teardown-only effect. It reads no reactive state, so it never re-runs and
	// its cleanup fires only when the owner is destroyed. The cleanup cannot live
	// on the scheduling effect above: that effect re-runs right after scheduling
	// (because it sets `checked`), which would cancel the timer immediately.
	$effect(() => {
		return () => clearTrigger();
	});

	return {
		get open() {
			return open;
		},
		get dismissed() {
			return dismissed;
		},
		dismiss,
		handleOpenChange
	};
}
|
||||
@@ -20,6 +20,7 @@
|
||||
import Dexie from 'dexie';
|
||||
import {
|
||||
STORAGE_APP_NAME,
|
||||
STORAGE_APP_NAME_DEPRECATED,
|
||||
DB_APP_NAME_DEPRECATED,
|
||||
CONFIG_LOCALSTORAGE_KEY,
|
||||
IDXDB_TABLES,
|
||||
@@ -494,12 +495,69 @@ const customJsonKeyMigration: Migration = {
|
||||
}
|
||||
};
|
||||
|
||||
const MCP_DEFAULT_ENABLED_MIGRATION_ID = 'mcp-default-enabled-to-config-v1';

const LEGACY_MCP_DEFAULT_ENABLED_KEY = `${STORAGE_APP_NAME}.mcpDefaultEnabled`;
const DEPRECATED_LEGACY_MCP_DEFAULT_ENABLED_KEY = `${STORAGE_APP_NAME_DEPRECATED}.mcpDefaultEnabled`;

/**
 * Copies the standalone `mcpDefaultEnabled` localStorage value into the
 * settings config under `SETTINGS_KEYS.MCP_DEFAULT_SERVER_OVERRIDES`.
 *
 * The migration is a no-op when:
 * - neither legacy key exists,
 * - the stored config cannot be read as a plain JSON object,
 * - the config already has an overrides entry (current data wins), or
 * - the legacy value is not a JSON array of `{ serverId: string, enabled: boolean }`.
 *
 * The legacy keys are never deleted.
 */
const mcpDefaultEnabledMigration: Migration = {
	id: MCP_DEFAULT_ENABLED_MIGRATION_ID,
	description:
		'Copy mcpDefaultEnabled localStorage key into settings config (preserves legacy keys)',

	async run(): Promise<void> {
		const raw =
			localStorage.getItem(LEGACY_MCP_DEFAULT_ENABLED_KEY) ??
			localStorage.getItem(DEPRECATED_LEGACY_MCP_DEFAULT_ENABLED_KEY);

		// Legacy keys intentionally left in place so a downgrade keeps reading them.

		if (raw === null) {
			if (import.meta.env.DEV && import.meta.env.VITE_DEBUG)
				console.log('[Migration] MCP default enabled: no legacy key found, skipping');
			return;
		}

		// Read the current config defensively. A corrupt or non-object value must
		// not throw out of the migration, and must not be "repaired" by writing
		// over it; an array would silently lose the new key when re-serialised.
		// NOTE(review): returning here ends the migration without copying anything;
		// confirm how the migration runner records completion for this case.
		let config: Record<string, unknown> = {};
		const configRaw = localStorage.getItem(CONFIG_LOCALSTORAGE_KEY);
		if (configRaw) {
			let parsedConfig: unknown;
			try {
				parsedConfig = JSON.parse(configRaw);
			} catch {
				return;
			}
			if (
				typeof parsedConfig !== 'object' ||
				parsedConfig === null ||
				Array.isArray(parsedConfig)
			) {
				return;
			}
			config = parsedConfig as Record<string, unknown>;
		}

		// Don't overwrite an existing config entry — current data wins.
		if (SETTINGS_KEYS.MCP_DEFAULT_SERVER_OVERRIDES in config) {
			if (import.meta.env.DEV && import.meta.env.VITE_DEBUG)
				console.log('[Migration] MCP default enabled: config already has overrides, skipping');
			return;
		}

		// Only copy a well-formed list of { serverId, enabled } entries.
		try {
			const parsed = JSON.parse(raw);
			if (!Array.isArray(parsed)) return;
			const valid = parsed.every(
				(o) =>
					typeof o === 'object' &&
					o !== null &&
					typeof (o as Record<string, unknown>).serverId === 'string' &&
					typeof (o as Record<string, unknown>).enabled === 'boolean'
			);
			if (!valid) return;
		} catch {
			return;
		}

		// The value is stored as the original JSON string, not as a parsed array.
		config[SETTINGS_KEYS.MCP_DEFAULT_SERVER_OVERRIDES] = raw;
		localStorage.setItem(CONFIG_LOCALSTORAGE_KEY, JSON.stringify(config));

		if (import.meta.env.DEV && import.meta.env.VITE_DEBUG)
			console.log('[Migration] MCP default enabled: copied legacy key into config');
	}
};
|
||||
|
||||
const migrations: Migration[] = [
|
||||
localStorageMigration,
|
||||
idxdbMigration,
|
||||
legacyMessageMigration,
|
||||
themeMigration,
|
||||
customJsonKeyMigration
|
||||
customJsonKeyMigration,
|
||||
mcpDefaultEnabledMigration
|
||||
];
|
||||
|
||||
export const MigrationService = {
|
||||
|
||||
@@ -1083,6 +1083,11 @@ class ChatStore {
|
||||
let resolvedModel: string | null = null;
|
||||
let modelPersisted = false;
|
||||
const convId = assistantMessage.convId;
|
||||
// Tracks the last message created in this flow. Used as the parent for the next
|
||||
// turn's assistant message so createAssistantMessage does not have to read
|
||||
// conversationsStore.activeMessages, which may belong to a different conversation
|
||||
// after the user navigates while the loop is still running.
|
||||
let lastCreatedInFlow = currentMessageId;
|
||||
// freeze the POST identity from t0 so a stop cancels with the exact session key,
|
||||
// never a stale or empty model resolved later
|
||||
this.setChatStreaming(convId, streamedContent, currentMessageId, effectiveModel);
|
||||
@@ -1208,8 +1213,15 @@ class ChatStore {
|
||||
};
|
||||
if (timings) uiUpdate.timings = timings;
|
||||
if (resolvedModel) uiUpdate.model = resolvedModel;
|
||||
conversationsStore.updateMessageAtIndex(idx, uiUpdate);
|
||||
await conversationsStore.updateCurrentNode(currentMessageId);
|
||||
// touch the active ui array and node pointer only when this conversation
|
||||
// is displayed; otherwise persist the node move straight to the db so a
|
||||
// foreign conv's currNode stays untouched
|
||||
if (conversationsStore.activeConversation?.id === convId) {
|
||||
conversationsStore.updateMessageAtIndex(idx, uiUpdate);
|
||||
await conversationsStore.updateCurrentNode(currentMessageId);
|
||||
} else {
|
||||
await DatabaseService.updateCurrentNode(convId, currentMessageId);
|
||||
}
|
||||
},
|
||||
createToolResultMessage: async (
|
||||
toolCallId: string,
|
||||
@@ -1230,8 +1242,16 @@ class ChatStore {
|
||||
},
|
||||
currentMessageId
|
||||
);
|
||||
conversationsStore.addMessageToActive(msg);
|
||||
await conversationsStore.updateCurrentNode(msg.id);
|
||||
// mirror into the active store and move the node pointer only when this
|
||||
// conversation is displayed; otherwise persist the node move straight to
|
||||
// the db for the owning conv so a foreign conv's currNode stays untouched
|
||||
if (conversationsStore.activeConversation?.id === convId) {
|
||||
conversationsStore.addMessageToActive(msg);
|
||||
await conversationsStore.updateCurrentNode(msg.id);
|
||||
} else {
|
||||
await DatabaseService.updateCurrentNode(convId, msg.id);
|
||||
}
|
||||
lastCreatedInFlow = msg.id;
|
||||
return msg;
|
||||
},
|
||||
createAssistantMessage: async () => {
|
||||
@@ -1239,8 +1259,6 @@ class ChatStore {
|
||||
streamedContent = '';
|
||||
streamedReasoningContent = '';
|
||||
|
||||
const lastMsg =
|
||||
conversationsStore.activeMessages[conversationsStore.activeMessages.length - 1];
|
||||
const msg = await DatabaseService.createMessageBranch(
|
||||
{
|
||||
convId,
|
||||
@@ -1252,10 +1270,13 @@ class ChatStore {
|
||||
children: [],
|
||||
model: resolvedModel
|
||||
},
|
||||
lastMsg.id
|
||||
lastCreatedInFlow
|
||||
);
|
||||
conversationsStore.addMessageToActive(msg);
|
||||
if (conversationsStore.activeConversation?.id === convId) {
|
||||
conversationsStore.addMessageToActive(msg);
|
||||
}
|
||||
currentMessageId = msg.id;
|
||||
lastCreatedInFlow = msg.id;
|
||||
return msg;
|
||||
},
|
||||
onFlowComplete: (finalTimings?: ChatMessageTimings) => {
|
||||
|
||||
@@ -23,7 +23,7 @@ import { browser } from '$app/environment';
|
||||
import { toast } from 'svelte-sonner';
|
||||
import { DatabaseService } from '$lib/services/database.service';
|
||||
import { MigrationService } from '$lib/services/migration.service';
|
||||
import { config } from '$lib/stores/settings.svelte';
|
||||
import { config, settingsStore } from '$lib/stores/settings.svelte';
|
||||
import { filterByLeafNodeId, findLeafNode, generateConversationTitle } from '$lib/utils';
|
||||
import type { McpServerOverride } from '$lib/types/database';
|
||||
import { zipSync, unzipSync, strToU8, strFromU8 } from 'fflate';
|
||||
@@ -46,7 +46,7 @@ import {
|
||||
ISO_TIME_SEPARATOR_REPLACEMENT,
|
||||
NON_ALPHANUMERIC_REGEX,
|
||||
MULTIPLE_UNDERSCORE_REGEX,
|
||||
MCP_DEFAULT_ENABLED_LOCALSTORAGE_KEY,
|
||||
SETTINGS_KEYS,
|
||||
THINKING_ENABLED_DEFAULT_LOCALSTORAGE_KEY,
|
||||
REASONING_EFFORT_DEFAULT_LOCALSTORAGE_KEY
|
||||
} from '$lib/constants';
|
||||
@@ -90,12 +90,10 @@ class ConversationsStore {
|
||||
/** Global (non-conversation-specific) reasoning effort default */
|
||||
pendingReasoningEffort = $state<ReasoningEffort>(ConversationsStore.loadReasoningEffortDefault());
|
||||
|
||||
/** Load MCP default overrides from localStorage */
|
||||
private static loadMcpDefaults(): McpServerOverride[] {
|
||||
if (typeof globalThis.localStorage === 'undefined') return [];
|
||||
const raw = config()[SETTINGS_KEYS.MCP_DEFAULT_SERVER_OVERRIDES];
|
||||
if (typeof raw !== 'string' || raw.length === 0) return [];
|
||||
try {
|
||||
const raw = localStorage.getItem(MCP_DEFAULT_ENABLED_LOCALSTORAGE_KEY);
|
||||
if (!raw) return [];
|
||||
const parsed = JSON.parse(raw);
|
||||
if (!Array.isArray(parsed)) return [];
|
||||
return parsed.filter(
|
||||
@@ -106,18 +104,12 @@ class ConversationsStore {
|
||||
}
|
||||
}
|
||||
|
||||
/** Persist MCP default overrides to localStorage */
|
||||
private saveMcpDefaults(): void {
|
||||
if (typeof globalThis.localStorage === 'undefined') return;
|
||||
const plain = this.pendingMcpServerOverrides.map((o) => ({
|
||||
serverId: o.serverId,
|
||||
enabled: o.enabled
|
||||
}));
|
||||
if (plain.length > 0) {
|
||||
localStorage.setItem(MCP_DEFAULT_ENABLED_LOCALSTORAGE_KEY, JSON.stringify(plain));
|
||||
} else {
|
||||
localStorage.removeItem(MCP_DEFAULT_ENABLED_LOCALSTORAGE_KEY);
|
||||
}
|
||||
settingsStore.updateConfig(SETTINGS_KEYS.MCP_DEFAULT_SERVER_OVERRIDES, JSON.stringify(plain));
|
||||
}
|
||||
|
||||
/** Load thinking-enabled default from localStorage */
|
||||
@@ -189,6 +181,10 @@ class ConversationsStore {
|
||||
try {
|
||||
await MigrationService.runAllMigrations();
|
||||
|
||||
// Re-read defaults after migrations: a migration may have populated
|
||||
// the settings config (e.g. moved legacy MCP overrides into it).
|
||||
this.pendingMcpServerOverrides = ConversationsStore.loadMcpDefaults();
|
||||
|
||||
await this.loadConversations();
|
||||
this.isInitialized = true;
|
||||
} catch (error) {
|
||||
|
||||
@@ -20,11 +20,13 @@
|
||||
*/
|
||||
|
||||
import { browser } from '$app/environment';
|
||||
import { SvelteSet } from 'svelte/reactivity';
|
||||
import { SETTINGS_KEYS } from '$lib/constants';
|
||||
import { MCPService } from '$lib/services/mcp.service';
|
||||
import { config, settingsStore } from '$lib/stores/settings.svelte';
|
||||
import { mcpResourceStore } from '$lib/stores/mcp-resources.svelte';
|
||||
import { serverStore } from '$lib/stores/server.svelte';
|
||||
import { conversationsStore } from '$lib/stores/conversations.svelte';
|
||||
import { mode } from 'mode-watcher';
|
||||
import {
|
||||
parseMcpServerSettings,
|
||||
@@ -48,10 +50,11 @@ import {
|
||||
EXPECTED_THEMED_ICON_PAIR_COUNT,
|
||||
MCP_ALLOWED_ICON_MIME_TYPES,
|
||||
MCP_SERVER_ID_PREFIX,
|
||||
MCP_RECONNECT_INITIAL_DELAY,
|
||||
MCP_RECONNECT_BACKOFF_MULTIPLIER,
|
||||
MCP_RECONNECT_INITIAL_DELAY,
|
||||
MCP_RECONNECT_MAX_DELAY,
|
||||
MCP_RECONNECT_ATTEMPT_TIMEOUT_MS
|
||||
MCP_RECONNECT_ATTEMPT_TIMEOUT_MS,
|
||||
RECOMMENDED_MCP_SERVER_IDS
|
||||
} from '$lib/constants';
|
||||
import type {
|
||||
MCPToolCall,
|
||||
@@ -70,6 +73,7 @@ import type {
|
||||
Tool,
|
||||
HealthCheckState,
|
||||
MCPServerSettingsEntry,
|
||||
MCPServerDisplayInfo,
|
||||
MCPServerConfig,
|
||||
MCPResourceIcon,
|
||||
MCPResourceAttachment,
|
||||
@@ -365,7 +369,7 @@ class MCPStore {
|
||||
return this.connections;
|
||||
}
|
||||
|
||||
getServerLabel(server: MCPServerSettingsEntry): string {
|
||||
getServerLabel(server: MCPServerDisplayInfo): string {
|
||||
const healthState = this.getHealthCheckState(server.id);
|
||||
|
||||
if (healthState?.status === HealthCheckStatus.SUCCESS)
|
||||
@@ -527,7 +531,7 @@ class MCPStore {
|
||||
|
||||
addServer(
|
||||
serverData: Omit<MCPServerSettingsEntry, 'id' | 'requestTimeoutSeconds'> & { id?: string }
|
||||
): void {
|
||||
): MCPServerSettingsEntry {
|
||||
const servers = this.getServers();
|
||||
const newServer: MCPServerSettingsEntry = {
|
||||
id: serverData.id || (uuid() ?? `server-${Date.now()}`),
|
||||
@@ -540,6 +544,7 @@ class MCPStore {
|
||||
useProxy: serverData.useProxy
|
||||
};
|
||||
settingsStore.updateConfig(SETTINGS_KEYS.MCP_SERVERS, JSON.stringify([...servers, newServer]));
|
||||
return newServer;
|
||||
}
|
||||
|
||||
updateServer(id: string, updates: Partial<MCPServerSettingsEntry>): void {
|
||||
@@ -576,6 +581,33 @@ class MCPStore {
|
||||
});
|
||||
}
|
||||
|
||||
/**
 * IDs of recommended MCP servers that have an enabled entry in
 * `conversationsStore.pendingMcpServerOverrides`.
 *
 * Single source of truth for "which recommendations has the user accepted",
 * shared by the recommendations hook and the visible-servers getter.
 */
get optedInRecommendationIds(): ReadonlySet<string> {
	const accepted = new SvelteSet<string>();

	conversationsStore.pendingMcpServerOverrides.forEach(({ serverId, enabled }) => {
		if (!enabled) return;
		if (!RECOMMENDED_MCP_SERVER_IDS.has(serverId)) return;
		accepted.add(serverId);
	});

	return accepted;
}
|
||||
|
||||
/**
 * MCP servers selectable in chat-add UIs and the settings page.
 *
 * A server is included when it is enabled in settings and is either not a
 * recommended server or is a recommended server the user has opted in to.
 * Order follows `getServersSorted()`.
 */
get visibleMcpServers(): MCPServerSettingsEntry[] {
	const optedIn = this.optedInRecommendationIds;

	const isVisible = (server: MCPServerSettingsEntry): boolean => {
		if (!server.enabled) return false;
		return !RECOMMENDED_MCP_SERVER_IDS.has(server.id) || optedIn.has(server.id);
	};

	return this.getServersSorted().filter(isVisible);
}
|
||||
|
||||
async ensureInitialized(perChatOverrides?: McpServerOverride[]): Promise<boolean> {
|
||||
if (!browser) {
|
||||
return false;
|
||||
|
||||
@@ -127,6 +127,8 @@ export type {
|
||||
MCPServerConfig,
|
||||
MCPClientConfig,
|
||||
MCPServerSettingsEntry,
|
||||
MCPServerDisplayInfo,
|
||||
RecommendedMCPServer,
|
||||
MCPToolCall,
|
||||
OpenAIToolDefinition,
|
||||
ServerStatus,
|
||||
|
||||
Vendored
+18
-3
@@ -209,17 +209,32 @@ export type MCPToolCall = {
|
||||
};
|
||||
};
|
||||
|
||||
export type MCPServerSettingsEntry = {
|
||||
/**
|
||||
* Minimum fields needed to display or identify an MCP server.
|
||||
*/
|
||||
export interface MCPServerDisplayInfo {
|
||||
id: string;
|
||||
enabled: boolean;
|
||||
name?: string;
|
||||
url: string;
|
||||
}
|
||||
|
||||
export type MCPServerSettingsEntry = MCPServerDisplayInfo & {
|
||||
enabled: boolean;
|
||||
requestTimeoutSeconds: number;
|
||||
headers?: string;
|
||||
name?: string;
|
||||
iconUrl?: string;
|
||||
useProxy?: boolean;
|
||||
};
|
||||
|
||||
/**
|
||||
* Pre-defined recommended MCP server shown to the user in onboarding/picker UIs.
|
||||
*/
|
||||
export interface RecommendedMCPServer extends MCPServerDisplayInfo {
|
||||
description: string;
|
||||
enabled: boolean;
|
||||
requestTimeoutSeconds: number;
|
||||
}
|
||||
|
||||
export interface MCPHostManagerConfig {
|
||||
servers: MCPClientConfig['servers'];
|
||||
clientInfo?: Implementation;
|
||||
|
||||
@@ -8,6 +8,7 @@
|
||||
import { onMount } from 'svelte';
|
||||
|
||||
import { SidebarNavigation, DialogConversationTitleUpdate } from '$lib/components/app';
|
||||
import { DialogMcpServerRecommendations } from '$lib/components/app/dialogs';
|
||||
import { PwaMetaTags, PwaRefreshAlert } from '$lib/components/pwa';
|
||||
import { pwaAssetsHead } from 'virtual:pwa-assets/head';
|
||||
|
||||
@@ -26,6 +27,7 @@
|
||||
import { FAVICON_PATHS, FAVICON_SELECTORS } from '$lib/constants/pwa';
|
||||
import { useKeyboardShortcuts } from '$lib/hooks/use-keyboard-shortcuts.svelte';
|
||||
import { usePwa } from '$lib/hooks/use-pwa.svelte';
|
||||
import { useMcpRecommendations } from '$lib/hooks/use-mcp-recommendations.svelte';
|
||||
import { conversations } from '$lib/stores/conversations.svelte';
|
||||
import { isMobile } from '$lib/stores/viewport.svelte';
|
||||
import { theme } from '$lib/stores/theme.svelte';
|
||||
@@ -37,6 +39,8 @@
|
||||
let innerHeight = $state<number | undefined>();
|
||||
let innerWidth = $state(browser ? window.innerWidth : 0);
|
||||
|
||||
const mcpRecommendations = useMcpRecommendations();
|
||||
|
||||
let chatSidebar:
|
||||
| {
|
||||
activateSearchMode?: () => void;
|
||||
@@ -321,6 +325,11 @@
|
||||
onConfirm={handleTitleUpdateConfirm}
|
||||
onCancel={handleTitleUpdateCancel}
|
||||
/>
|
||||
|
||||
<DialogMcpServerRecommendations
|
||||
open={mcpRecommendations.open}
|
||||
onOpenChange={mcpRecommendations.handleOpenChange}
|
||||
/>
|
||||
</Tooltip.Provider>
|
||||
|
||||
<!-- PWA update prompt + version -->
|
||||
|
||||
@@ -0,0 +1,37 @@
|
||||
<script lang="ts">
	import { untrack } from 'svelte';
	import McpServerForm from '$lib/components/app/mcp/McpServerForm.svelte';

	interface Props {
		headers?: string;
	}

	let { headers = '' }: Props = $props();

	// Seed both pieces of local state from the initial prop value; `untrack`
	// keeps the initialisers from registering `headers` as a dependency.
	let headersState = $state(untrack(() => headers));
	let lastCapturedHeaders = $state(untrack(() => headers));

	// Re-sync local state only when the incoming prop actually changes.
	$effect(() => {
		if (headers === lastCapturedHeaders) return;

		headersState = headers;
		lastCapturedHeaders = headers;
	});

	/** Store the latest headers string emitted by the form. */
	function captureHeaders(value: string) {
		headersState = value;
	}
</script>

<!--
	Test harness around McpServerForm: feeds it a controlled `headers` string
	and mirrors the most recent value into `data-captured-headers`, so the
	client test can read it back without a custom binding API.
-->
<McpServerForm
	url="https://example.test/mcp"
	headers={headersState}
	onUrlChange={() => {}}
	onHeadersChange={captureHeaders}
	id="mcp-server-form-test"
/>

<div data-testid="captured-headers" data-captured-headers={headersState} hidden></div>
|
||||
@@ -0,0 +1,133 @@
|
||||
import { describe, expect, it } from 'vitest';
|
||||
import { render } from 'vitest-browser-svelte';
|
||||
import McpServerFormWrapper from './components/McpServerFormWrapper.svelte';
|
||||
|
||||
const AUTHORIZATION_HEADER = 'Authorization';
|
||||
const BEARER_PREFIX = 'Bearer ';
|
||||
const BEARER_PLACEHOLDER = 'Paste token here';
|
||||
|
||||
/**
 * Client-side tests for the McpServerForm bearer UI.
 *
 * The dedicated UI only "owns" Authorization headers that already carry a
 * Bearer scheme (heuristic check on the value). Other Authorization values
 * stay in the KV section so the user can still edit them verbatim. Storage
 * always goes through the same custom-headers slot, so a round-trip via this
 * UI produces exactly one `Authorization: Bearer <token>` entry.
 *
 * Equivalent parser coverage lives in `tests/unit/headers.test.ts`.
 */
describe('McpServerForm - Authorization / bearer UI', () => {
	type Screen = Awaited<ReturnType<typeof render>>;

	/** Token input of the dedicated bearer UI. */
	function bearerInput(screen: Screen) {
		return screen.locator.getByPlaceholder(BEARER_PLACEHOLDER);
	}

	/** Hidden element the wrapper uses to expose the latest headers string. */
	function capturedHeaders(screen: Screen) {
		return screen.getByTestId('captured-headers');
	}

	/** The Authorization on/off switch. */
	function authorizationSwitch(screen: Screen) {
		return screen.getByRole('switch', { name: /authorization/i });
	}

	it('mounts with the bearer input hidden when no auth header is present', async () => {
		const screen = await render(McpServerFormWrapper, { headers: '' });

		await expect.element(screen.getByRole('textbox', { name: /server url/i })).toBeVisible();

		await expect.element(bearerInput(screen)).not.toBeInTheDocument();
	});

	it('toggling Authorization shows the bearer input', async () => {
		const screen = await render(McpServerFormWrapper, { headers: '' });

		await authorizationSwitch(screen).click();

		await expect.element(bearerInput(screen)).toBeVisible();
	});

	it('typing a token writes the Authorization row with the Bearer prefix prepended', async () => {
		const screen = await render(McpServerFormWrapper, { headers: '' });

		await authorizationSwitch(screen).click();

		const token = 'super-secret';
		await bearerInput(screen).fill(token);

		const expected = JSON.stringify({ [AUTHORIZATION_HEADER]: `${BEARER_PREFIX}${token}` });
		await expect
			.element(capturedHeaders(screen))
			.toHaveAttribute('data-captured-headers', expected);
	});

	it('pre-existing Bearer header pre-fills the bearer input with the token stripped', async () => {
		const existing = JSON.stringify({
			'X-Trace-Id': 'abc',
			[AUTHORIZATION_HEADER]: `${BEARER_PREFIX}preexisting`
		});

		const screen = await render(McpServerFormWrapper, { headers: existing });

		await expect.element(bearerInput(screen)).toBeVisible();
		await expect.element(bearerInput(screen)).toHaveValue('preexisting');
	});

	it('non-Bearer Authorization is ignored by the dedicated UI and stays in the KV section', async () => {
		const existing = JSON.stringify({ [AUTHORIZATION_HEADER]: 'Basic czNjcjpwYXNz' });

		const screen = await render(McpServerFormWrapper, { headers: existing });

		await expect.element(bearerInput(screen)).not.toBeInTheDocument();

		const headerKeyInput = screen.getByPlaceholder('Header name');
		await expect.element(headerKeyInput).toBeVisible();
	});

	it('engaging the token UI replaces a non-Bearer Authorization with the Bearer scheme', async () => {
		const existing = JSON.stringify({ [AUTHORIZATION_HEADER]: 'Basic old' });

		const screen = await render(McpServerFormWrapper, { headers: existing });

		await authorizationSwitch(screen).click();
		await bearerInput(screen).fill('new');

		const expected = JSON.stringify({ [AUTHORIZATION_HEADER]: `${BEARER_PREFIX}new` });
		await expect
			.element(capturedHeaders(screen))
			.toHaveAttribute('data-captured-headers', expected);
	});

	// Starts from an existing Bearer token; non-Bearer preservation is covered
	// by the next test.
	it('toggling Authorization off drops an existing Bearer row', async () => {
		const existing = JSON.stringify({ [AUTHORIZATION_HEADER]: `${BEARER_PREFIX}xyz` });
		const screen = await render(McpServerFormWrapper, { headers: existing });

		await authorizationSwitch(screen).click();

		await expect.element(capturedHeaders(screen)).toHaveAttribute('data-captured-headers', '');
	});

	it('toggling Authorization off when no Bearer row is present leaves headers untouched', async () => {
		const existing = JSON.stringify({ [AUTHORIZATION_HEADER]: 'Basic czNjcjpwYXNz' });
		const screen = await render(McpServerFormWrapper, { headers: existing });

		// On, then straight back off without typing a token.
		await authorizationSwitch(screen).click();
		await authorizationSwitch(screen).click();

		await expect
			.element(capturedHeaders(screen))
			.toHaveAttribute('data-captured-headers', existing);
	});

	it('clearing the bearer input drops the Authorization row', async () => {
		const existing = JSON.stringify({ [AUTHORIZATION_HEADER]: `${BEARER_PREFIX}xyz` });
		const screen = await render(McpServerFormWrapper, { headers: existing });

		await bearerInput(screen).fill('');

		await expect.element(capturedHeaders(screen)).toHaveAttribute('data-captured-headers', '');
	});

	it('does not surface Bearer Authorization in the KV section even when pre-existing', async () => {
		const existing = JSON.stringify({ [AUTHORIZATION_HEADER]: `${BEARER_PREFIX}xyz` });
		const screen = await render(McpServerFormWrapper, { headers: existing });

		const headerKeyInput = screen.getByPlaceholder('Header name');
		await expect.element(headerKeyInput).not.toBeInTheDocument();
	});
});
|
||||
@@ -0,0 +1,126 @@
|
||||
import { describe, expect, it } from 'vitest';
|
||||
import { parseHeadersToArray, serializeHeaders } from '$lib/utils/headers';
|
||||
|
||||
/**
|
||||
* Tests for the header serialization helpers used by the MCP server form
|
||||
* (custom header rows) and the new Authorization/Bearer-token flow.
|
||||
*/
|
||||
describe('parseHeadersToArray', () => {
|
||||
it('returns an empty array for empty or whitespace-only input', () => {
|
||||
expect(parseHeadersToArray('')).toEqual([]);
|
||||
expect(parseHeadersToArray(' ')).toEqual([]);
|
||||
expect(parseHeadersToArray(undefined as unknown as string)).toEqual([]);
|
||||
});
|
||||
|
||||
it('returns an empty array for invalid JSON input', () => {
|
||||
expect(parseHeadersToArray('{not-json')).toEqual([]);
|
||||
expect(parseHeadersToArray('[]')).toEqual([]);
|
||||
expect(parseHeadersToArray('"plain-string"')).toEqual([]);
|
||||
});
|
||||
|
||||
it('converts an object into ordered key/value pairs', () => {
|
||||
expect(parseHeadersToArray('{"X-Foo":"bar","Authorization":"Bearer abc"}')).toEqual([
|
||||
{ key: 'X-Foo', value: 'bar' },
|
||||
{ key: 'Authorization', value: 'Bearer abc' }
|
||||
]);
|
||||
});
|
||||
|
||||
it('stringifies non-string values', () => {
	// The fixture deliberately carries a JSON number and a JSON boolean.
	// With already-quoted values ("42", "true") the parser could return the
	// raw values unchanged and this test would still pass, so the coercion
	// named in the title would never be exercised.
	expect(parseHeadersToArray('{"count":42,"flag":true}')).toEqual([
		{ key: 'count', value: '42' },
		{ key: 'flag', value: 'true' }
	]);
});
|
||||
});
|
||||
|
||||
describe('serializeHeaders', () => {
|
||||
it('returns an empty string when there are no valid pairs', () => {
|
||||
expect(serializeHeaders([])).toBe('');
|
||||
expect(serializeHeaders([{ key: '', value: 'value' }])).toBe('');
|
||||
expect(serializeHeaders([{ key: ' ', value: 'value' }])).toBe('');
|
||||
});
|
||||
|
||||
it('returns an empty string when every pair has a blank key', () => {
|
||||
expect(
|
||||
serializeHeaders([
|
||||
{ key: '', value: 'drop-me' },
|
||||
{ key: ' ', value: 'drop-me-too' },
|
||||
{ key: '\t', value: 'tab-key' }
|
||||
])
|
||||
).toBe('');
|
||||
});
|
||||
|
||||
it('drops pairs with empty keys but keeps the rest', () => {
|
||||
expect(
|
||||
serializeHeaders([
|
||||
{ key: '', value: 'drop-me' },
|
||||
{ key: 'X-Keep', value: 'ok' }
|
||||
])
|
||||
).toBe('{"X-Keep":"ok"}');
|
||||
});
|
||||
|
||||
it('trims keys before serializing', () => {
|
||||
expect(serializeHeaders([{ key: ' X-Space ', value: 'ok' }])).toBe('{"X-Space":"ok"}');
|
||||
});
|
||||
|
||||
it('preserves the input order of surviving pairs', () => {
|
||||
const serialized = serializeHeaders([
|
||||
{ key: 'X-C', value: '3' },
|
||||
{ key: 'X-A', value: '1' },
|
||||
{ key: 'X-B', value: '2' }
|
||||
]);
|
||||
|
||||
// Object key order follows insertion order in modern JS engines, so
|
||||
// the serialized JSON writes keys in our input order.
|
||||
expect(JSON.parse(serialized)).toEqual({ 'X-C': '3', 'X-A': '1', 'X-B': '2' });
|
||||
});
|
||||
});
|
||||
|
||||
describe('parseHeadersToArray / serializeHeaders roundtrip', () => {
|
||||
it('serializes back to an equal header object after a parse', () => {
|
||||
const original = JSON.stringify({
|
||||
'Content-Type': 'application/json',
|
||||
'X-Trace-Id': 'abc-123'
|
||||
});
|
||||
|
||||
const roundtrip = serializeHeaders(parseHeadersToArray(original));
|
||||
|
||||
expect(JSON.parse(roundtrip)).toEqual(JSON.parse(original));
|
||||
});
|
||||
|
||||
it('drops rows whose keys are blank after trimming during serialization', () => {
|
||||
const pairs = parseHeadersToArray('{"X-Keep":"ok","":"drop-me"}');
|
||||
|
||||
// parseHeadersToArray keeps raw key strings (the consumer is expected to
|
||||
// filter blanks, not the parser); serialization must strip them.
|
||||
expect(pairs).toEqual([
|
||||
{ key: 'X-Keep', value: 'ok' },
|
||||
{ key: '', value: 'drop-me' }
|
||||
]);
|
||||
expect(serializeHeaders(pairs)).toBe('{"X-Keep":"ok"}');
|
||||
});
|
||||
|
||||
it('preserves upstream keys untouched (does not lowercase them)', () => {
|
||||
const upperCased = '{"Authorization":"Bearer xyz"}';
|
||||
|
||||
const parsed = parseHeadersToArray(upperCased);
|
||||
|
||||
expect(parsed).toEqual([{ key: 'Authorization', value: 'Bearer xyz' }]);
|
||||
});
|
||||
|
||||
it('bearer-token write survives a re-parse when paired with regular custom headers', () => {
|
||||
// The McpServerForm bearer UI writes {Authorization: `Bearer <token>`}
|
||||
// into the same headers string as the custom KV section. The round
|
||||
// trip below mirrors the exact shape the form produces so a future
|
||||
// refactor of either code path cannot silently change the on-disk key.
|
||||
const pairs = [
|
||||
{ key: 'X-Trace-Id', value: 'abc-123' },
|
||||
{ key: 'Authorization', value: 'Bearer super-secret' }
|
||||
];
|
||||
|
||||
const serialized = serializeHeaders(pairs);
|
||||
|
||||
expect(serialized).toBe('{"X-Trace-Id":"abc-123","Authorization":"Bearer super-secret"}');
|
||||
expect(parseHeadersToArray(serialized)).toEqual(pairs);
|
||||
});
|
||||
});
|
||||
@@ -0,0 +1,144 @@
|
||||
import { describe, expect, it, vi } from 'vitest';
|
||||
import { parseMcpServerSettings } from '$lib/utils/mcp';
|
||||
import { DEFAULT_MCP_CONFIG, MCP_SERVER_ID_PREFIX } from '$lib/constants/mcp';
|
||||
|
||||
/**
|
||||
* Tests for the mcpServers settings parser.
|
||||
*
|
||||
* The branch seeds the MCP servers setting with a default value of
|
||||
* `JSON.stringify(RECOMMENDED_MCP_SERVERS)`, so the parser has to be
|
||||
* resilient to anything that may live in the user's localStorage: malformed
|
||||
* JSON, wrong shapes, missing fields, falsy-but-not-zero numbers, and entry
|
||||
* arrays that have been mutated by the user via the settings form.
|
||||
*/
|
||||
describe('parseMcpServerSettings', () => {
|
||||
it('returns an empty array for falsy or whitespace-only input', () => {
|
||||
expect(parseMcpServerSettings(null)).toEqual([]);
|
||||
expect(parseMcpServerSettings(undefined)).toEqual([]);
|
||||
expect(parseMcpServerSettings('')).toEqual([]);
|
||||
expect(parseMcpServerSettings(' ')).toEqual([]);
|
||||
});
|
||||
|
||||
it('returns an empty array and logs a warning for invalid JSON strings', () => {
	// Silence console.warn for the duration of the test while recording calls.
	const warn = vi.spyOn(console, 'warn').mockImplementation(() => {});

	try {
		expect(parseMcpServerSettings('{not-json')).toEqual([]);
		expect(warn).toHaveBeenCalled();
	} finally {
		// Restore in `finally`: a failing expectation throws, and without this
		// the mocked console.warn would leak into every later test in the run.
		warn.mockRestore();
	}
});
|
||||
|
||||
it('returns an empty array for valid JSON that is not an array', () => {
|
||||
expect(parseMcpServerSettings('"plain-string"')).toEqual([]);
|
||||
expect(parseMcpServerSettings('{"id":"foo"}')).toEqual([]);
|
||||
expect(parseMcpServerSettings('42')).toEqual([]);
|
||||
expect(parseMcpServerSettings('null')).toEqual([]);
|
||||
});
|
||||
|
||||
// Entries without an id are NOT discarded: both survive and each receives a
// positional fallback id (`<prefix>-<1-based index>`), as asserted below.
it('keeps entries with no parseable id and substitutes a stable fallback', () => {
	const parsed = parseMcpServerSettings(
		JSON.stringify([{ url: 'https://a.test', enabled: true }, { url: 'https://b.test' }])
	);

	expect(parsed).toHaveLength(2);
	expect(parsed[0]?.id).toBe(`${MCP_SERVER_ID_PREFIX}-1`);
	expect(parsed[1]?.id).toBe(`${MCP_SERVER_ID_PREFIX}-2`);
});
|
||||
|
||||
it('reuses the first id when it is present and falls back only for missing ones', () => {
|
||||
const parsed = parseMcpServerSettings(
|
||||
JSON.stringify([
|
||||
{ id: 'custom-1', url: 'https://a.test' },
|
||||
{ url: 'https://b.test' },
|
||||
{ id: 'custom-3', url: 'https://c.test' }
|
||||
])
|
||||
);
|
||||
|
||||
expect(parsed[0]?.id).toBe('custom-1');
|
||||
expect(parsed[1]?.id).toBe(`${MCP_SERVER_ID_PREFIX}-2`);
|
||||
expect(parsed[2]?.id).toBe('custom-3');
|
||||
});
|
||||
|
||||
it('falls back to the configured default requestTimeoutSeconds only for nullish values', () => {
|
||||
const fallback = DEFAULT_MCP_CONFIG.requestTimeoutSeconds;
|
||||
|
||||
const parsed = parseMcpServerSettings(
|
||||
JSON.stringify([
|
||||
{ id: 'a', url: 'https://a.test' },
|
||||
{ id: 'b', url: 'https://b.test', requestTimeoutSeconds: undefined },
|
||||
{ id: 'c', url: 'https://c.test', requestTimeoutSeconds: 0 },
|
||||
{ id: 'd', url: 'https://d.test', requestTimeoutSeconds: 45 }
|
||||
])
|
||||
);
|
||||
|
||||
// The parser uses ?? for timeout fallback, which only triggers on
|
||||
// null/undefined. Explicit 0 is preserved at face value.
|
||||
expect(parsed[0]?.requestTimeoutSeconds).toBe(fallback);
|
||||
expect(parsed[1]?.requestTimeoutSeconds).toBe(fallback);
|
||||
expect(parsed[2]?.requestTimeoutSeconds).toBe(0);
|
||||
expect(parsed[3]?.requestTimeoutSeconds).toBe(45);
|
||||
});
|
||||
|
||||
it('treats whitespace-only headers strings as undefined', () => {
|
||||
const parsed = parseMcpServerSettings(
|
||||
JSON.stringify([
|
||||
{ id: 'a', url: 'https://a.test', headers: ' ' },
|
||||
{ id: 'b', url: 'https://b.test', headers: '{"X-Foo":"bar"}' }
|
||||
])
|
||||
);
|
||||
|
||||
// The parser trims headers and coerces empty/whitespace to undefined.
|
||||
expect(parsed[0]?.headers).toBeUndefined();
|
||||
expect(parsed[1]?.headers).toBe('{"X-Foo":"bar"}');
|
||||
});
|
||||
|
||||
it('defaults coercion for booleans (undefined -> false, true -> true)', () => {
|
||||
const parsed = parseMcpServerSettings(
|
||||
JSON.stringify([
|
||||
{ id: 'a', url: 'https://a.test' },
|
||||
{ id: 'b', url: 'https://b.test', enabled: true },
|
||||
{ id: 'c', url: 'https://c.test', enabled: false },
|
||||
{ id: 'd', url: 'https://d.test', useProxy: true }
|
||||
])
|
||||
);
|
||||
|
||||
expect(parsed[0]?.enabled).toBe(false);
|
||||
expect(parsed[1]?.enabled).toBe(true);
|
||||
expect(parsed[2]?.enabled).toBe(false);
|
||||
expect(parsed[0]?.useProxy).toBe(false);
|
||||
expect(parsed[3]?.useProxy).toBe(true);
|
||||
});
|
||||
|
||||
it('preserves input order when mapping entries', () => {
|
||||
const source = [
|
||||
{ id: 'gamma', url: 'https://c.test' },
|
||||
{ id: 'alpha', url: 'https://a.test' },
|
||||
{ id: 'beta', url: 'https://b.test' }
|
||||
];
|
||||
|
||||
const parsed = parseMcpServerSettings(JSON.stringify(source));
|
||||
|
||||
expect(parsed.map((entry) => entry.id)).toEqual(['gamma', 'alpha', 'beta']);
|
||||
});
|
||||
|
||||
it('passes non-string raw input through the JSON-equality path', () => {
|
||||
const parsed = parseMcpServerSettings([
|
||||
{ id: 'a', url: 'https://a.test' },
|
||||
{ id: 'b', url: 'https://b.test', enabled: true }
|
||||
]);
|
||||
|
||||
expect(parsed).toHaveLength(2);
|
||||
expect(parsed[0]?.id).toBe('a');
|
||||
expect(parsed[1]?.enabled).toBe(true);
|
||||
});
|
||||
|
||||
it('coerces non-string url values to an empty string rather than throwing', () => {
|
||||
const parsed = parseMcpServerSettings(
|
||||
JSON.stringify([{ id: 'a', url: 42 }, { id: 'b' }, { id: 'c', url: 'https://c.test' }])
|
||||
);
|
||||
|
||||
expect(parsed[0]?.url).toBe('');
|
||||
expect(parsed[1]?.url).toBe('');
|
||||
expect(parsed[2]?.url).toBe('https://c.test');
|
||||
});
|
||||
});
|
||||
@@ -0,0 +1,90 @@
|
||||
import { describe, expect, it } from 'vitest';
|
||||
import {
|
||||
RECOMMENDED_MCP_SERVER_IDS,
|
||||
RECOMMENDED_MCP_SERVERS
|
||||
} from '$lib/constants/recommended-mcp-servers';
|
||||
import { parseMcpServerSettings } from '$lib/utils/mcp';
|
||||
import { DEFAULT_MCP_CONFIG, MCP_SERVER_ID_PREFIX } from '$lib/constants/mcp';
|
||||
|
||||
/**
|
||||
* Tests for the predefined recommended MCP servers.
|
||||
*
|
||||
* These are surfaced to first-time users via
|
||||
* DialogMcpServerRecommendations and used as the default value of the MCP
|
||||
* servers setting, so a regression that breaks the round-trip through the
|
||||
* settings parser would silently break onboarding for new users.
|
||||
*/
|
||||
describe('RECOMMENDED_MCP_SERVERS', () => {
|
||||
it('lists at least one entry and uses stable, unique ids', () => {
|
||||
expect(RECOMMENDED_MCP_SERVERS.length).toBeGreaterThan(0);
|
||||
|
||||
const ids = RECOMMENDED_MCP_SERVERS.map((server) => server.id);
|
||||
expect(new Set(ids).size).toBe(ids.length);
|
||||
|
||||
for (const id of ids) {
|
||||
expect(id).toMatch(/^[a-z0-9-]+$/);
|
||||
expect(id.toLowerCase()).not.toContain(MCP_SERVER_ID_PREFIX.toLowerCase());
|
||||
}
|
||||
});
|
||||
|
||||
it('requires a name, description and url for every entry', () => {
|
||||
for (const server of RECOMMENDED_MCP_SERVERS) {
|
||||
expect(server.name?.trim().length ?? 0).toBeGreaterThan(0);
|
||||
expect(server.description.trim().length).toBeGreaterThan(0);
|
||||
expect(server.url.trim().length).toBeGreaterThan(0);
|
||||
expect(() => new URL(server.url)).not.toThrow();
|
||||
}
|
||||
});
|
||||
});
|
||||
|
||||
describe('RECOMMENDED_MCP_SERVER_IDS', () => {
|
||||
it('matches the ids declared in RECOMMENDED_MCP_SERVERS', () => {
|
||||
expect(RECOMMENDED_MCP_SERVER_IDS.size).toBe(RECOMMENDED_MCP_SERVERS.length);
|
||||
|
||||
for (const server of RECOMMENDED_MCP_SERVERS) {
|
||||
expect(RECOMMENDED_MCP_SERVER_IDS.has(server.id)).toBe(true);
|
||||
}
|
||||
});
|
||||
});
|
||||
|
||||
describe('recommended-mcp-servers default value', () => {
|
||||
it('round-trips cleanly through parseMcpServerSettings', () => {
|
||||
const serialized = JSON.stringify(RECOMMENDED_MCP_SERVERS);
|
||||
const parsed = parseMcpServerSettings(serialized);
|
||||
|
||||
expect(parsed).toHaveLength(RECOMMENDED_MCP_SERVERS.length);
|
||||
|
||||
for (let index = 0; index < RECOMMENDED_MCP_SERVERS.length; index++) {
|
||||
const source = RECOMMENDED_MCP_SERVERS[index];
|
||||
const entry = parsed[index];
|
||||
|
||||
expect(entry).toBeDefined();
|
||||
expect(entry?.id).toBe(source.id);
|
||||
expect(entry?.url).toBe(source.url);
|
||||
expect(entry?.enabled).toBe(source.enabled);
|
||||
expect(entry?.requestTimeoutSeconds).toBe(source.requestTimeoutSeconds);
|
||||
expect(entry?.name).toBe(source.name);
|
||||
|
||||
// Headers and useProxy are not set on recommended servers; the
|
||||
// parser must fall back to the inactive defaults rather than
|
||||
// surfacing undefined-boundary states.
|
||||
expect(entry?.headers).toBeUndefined();
|
||||
expect(entry?.useProxy).toBe(false);
|
||||
}
|
||||
});
|
||||
|
||||
it('uses the global default timeout when one is not specified on an entry', () => {
|
||||
const sourceOnlyRequired = {
|
||||
id: 'roundtrip-only',
|
||||
name: 'Only required fields',
|
||||
url: 'https://example.test/mcp',
|
||||
description: 'Smoke entry for parser roundtrip with default timeout.',
|
||||
enabled: true
|
||||
};
|
||||
|
||||
const parsed = parseMcpServerSettings(JSON.stringify([sourceOnlyRequired]));
|
||||
const entry = parsed[0];
|
||||
|
||||
expect(entry?.requestTimeoutSeconds).toBe(DEFAULT_MCP_CONFIG.requestTimeoutSeconds);
|
||||
});
|
||||
});
|
||||
Vendored
+127
-35
@@ -478,7 +478,7 @@ bool set_socket_opt_time(socket_t sock, int level, int optname,
|
||||
}
|
||||
|
||||
bool is_hex(char c, int &v) {
|
||||
if (isdigit(static_cast<unsigned char>(c))) {
|
||||
if (is_ascii_digit(c)) {
|
||||
v = c - '0';
|
||||
return true;
|
||||
} else if ('A' <= c && c <= 'F') {
|
||||
@@ -695,7 +695,11 @@ std::string base64_encode(const std::string &in) {
|
||||
std::string out;
|
||||
out.reserve(in.size());
|
||||
|
||||
auto val = 0;
|
||||
// Unsigned: the accumulator is never masked, so with a signed int the
|
||||
// `val << 8` below overflows once enough bytes are folded in (undefined
|
||||
// behaviour before C++20). Only the low bits are ever emitted, so the
|
||||
// wrap-around of an unsigned accumulator does not affect the output.
|
||||
uint32_t val = 0;
|
||||
auto valb = -6;
|
||||
|
||||
for (auto c : in) {
|
||||
@@ -3887,8 +3891,7 @@ bool parse_range_header(const std::string &s, Ranges &ranges) {
|
||||
bool parse_range_header(const std::string &s, Ranges &ranges) try {
|
||||
#endif
|
||||
auto is_valid = [](const std::string &str) {
|
||||
return std::all_of(str.cbegin(), str.cend(),
|
||||
[](unsigned char c) { return std::isdigit(c); });
|
||||
return std::all_of(str.cbegin(), str.cend(), is_ascii_digit);
|
||||
};
|
||||
|
||||
if (s.size() > 7 && s.compare(0, 6, "bytes=") == 0) {
|
||||
@@ -4336,7 +4339,7 @@ bool is_multipart_boundary_chars_valid(const std::string &boundary) {
|
||||
auto valid = true;
|
||||
for (size_t i = 0; i < boundary.size(); i++) {
|
||||
auto c = boundary[i];
|
||||
if (!std::isalnum(static_cast<unsigned char>(c)) && c != '-' && c != '_') {
|
||||
if (!is_ascii_alnum(c) && c != '-' && c != '_') {
|
||||
valid = false;
|
||||
break;
|
||||
}
|
||||
@@ -4344,18 +4347,47 @@ bool is_multipart_boundary_chars_valid(const std::string &boundary) {
|
||||
return valid;
|
||||
}
|
||||
|
||||
// Escape a multipart field name/filename following the WHATWG HTML standard
|
||||
// ("escape a multipart form-data name"), which is what browsers send:
|
||||
// '"' -> %22, CR -> %0D, LF -> %0A
|
||||
// With escape_quote = false, only CR and LF are escaped; this is for header
|
||||
// values outside a quoted-string (e.g. Content-Type), where '"' is legal.
|
||||
// Percent-escapes the characters that would break a multipart header line.
//   s            - raw field name, filename or header value
//   escape_quote - when true (default) '"' becomes %22; when false it is
//                  copied through unchanged
// CR and LF are always rewritten to %0D / %0A. Every other byte is copied
// verbatim. Returns the escaped copy; the input is not modified.
std::string escape_multipart_field(const std::string &s,
                                   bool escape_quote = true) {
  std::string escaped;
  escaped.reserve(s.size());
  for (const char ch : s) {
    if (ch == '\r') {
      escaped += "%0D";
    } else if (ch == '\n') {
      escaped += "%0A";
    } else if (ch == '"' && escape_quote) {
      escaped += "%22";
    } else {
      escaped += ch;
    }
  }
  return escaped;
}
|
||||
|
||||
template <typename T>
|
||||
std::string
|
||||
serialize_multipart_formdata_item_begin(const T &item,
|
||||
const std::string &boundary) {
|
||||
std::string body = "--" + boundary + "\r\n";
|
||||
body += "Content-Disposition: form-data; name=\"" + item.name + "\"";
|
||||
body += "Content-Disposition: form-data; name=\"" +
|
||||
escape_multipart_field(item.name) + "\"";
|
||||
if (!item.filename.empty()) {
|
||||
body += "; filename=\"" + item.filename + "\"";
|
||||
body += "; filename=\"" + escape_multipart_field(item.filename) + "\"";
|
||||
}
|
||||
body += "\r\n";
|
||||
if (!item.content_type.empty()) {
|
||||
body += "Content-Type: " + item.content_type + "\r\n";
|
||||
body +=
|
||||
"Content-Type: " + escape_multipart_field(item.content_type, false) +
|
||||
"\r\n";
|
||||
}
|
||||
body += "\r\n";
|
||||
|
||||
@@ -4821,10 +4853,9 @@ private:
|
||||
namespace fields {
|
||||
|
||||
bool is_token_char(char c) {
|
||||
return std::isalnum(static_cast<unsigned char>(c)) || c == '!' || c == '#' ||
|
||||
c == '$' || c == '%' || c == '&' || c == '\'' || c == '*' ||
|
||||
c == '+' || c == '-' || c == '.' || c == '^' || c == '_' || c == '`' ||
|
||||
c == '|' || c == '~';
|
||||
return is_ascii_alnum(c) || c == '!' || c == '#' || c == '$' || c == '%' ||
|
||||
c == '&' || c == '\'' || c == '*' || c == '+' || c == '-' ||
|
||||
c == '.' || c == '^' || c == '_' || c == '`' || c == '|' || c == '~';
|
||||
}
|
||||
|
||||
bool is_token(const std::string &s) {
|
||||
@@ -4873,7 +4904,8 @@ bool is_field_value(const std::string &s) { return is_field_content(s); }
|
||||
} // namespace fields
|
||||
|
||||
bool perform_websocket_handshake(Stream &strm, const std::string &host,
|
||||
int port, const std::string &path,
|
||||
int port, bool is_ssl,
|
||||
const std::string &path,
|
||||
const Headers &headers,
|
||||
std::string &selected_subprotocol) {
|
||||
// Validate path and host
|
||||
@@ -4899,7 +4931,7 @@ bool perform_websocket_handshake(Stream &strm, const std::string &host,
|
||||
|
||||
// Build upgrade request
|
||||
std::string req_str = "GET " + path + " HTTP/1.1\r\n";
|
||||
req_str += "Host: " + host + ":" + std::to_string(port) + "\r\n";
|
||||
req_str += "Host: " + make_host_and_port_string(host, port, is_ssl) + "\r\n";
|
||||
req_str += "Upgrade: websocket\r\n";
|
||||
req_str += "Connection: Upgrade\r\n";
|
||||
req_str += "Sec-WebSocket-Key: " + client_key + "\r\n";
|
||||
@@ -5599,9 +5631,8 @@ std::string encode_uri_component(const std::string &value) {
|
||||
escaped << std::hex;
|
||||
|
||||
for (auto c : value) {
|
||||
if (std::isalnum(static_cast<uint8_t>(c)) || c == '-' || c == '_' ||
|
||||
c == '.' || c == '!' || c == '~' || c == '*' || c == '\'' || c == '(' ||
|
||||
c == ')') {
|
||||
if (detail::is_ascii_alnum(c) || c == '-' || c == '_' || c == '.' ||
|
||||
c == '!' || c == '~' || c == '*' || c == '\'' || c == '(' || c == ')') {
|
||||
escaped << c;
|
||||
} else {
|
||||
escaped << std::uppercase;
|
||||
@@ -5620,10 +5651,10 @@ std::string encode_uri(const std::string &value) {
|
||||
escaped << std::hex;
|
||||
|
||||
for (auto c : value) {
|
||||
if (std::isalnum(static_cast<uint8_t>(c)) || c == '-' || c == '_' ||
|
||||
c == '.' || c == '!' || c == '~' || c == '*' || c == '\'' || c == '(' ||
|
||||
c == ')' || c == ';' || c == '/' || c == '?' || c == ':' || c == '@' ||
|
||||
c == '&' || c == '=' || c == '+' || c == '$' || c == ',' || c == '#') {
|
||||
if (detail::is_ascii_alnum(c) || c == '-' || c == '_' || c == '.' ||
|
||||
c == '!' || c == '~' || c == '*' || c == '\'' || c == '(' || c == ')' ||
|
||||
c == ';' || c == '/' || c == '?' || c == ':' || c == '@' || c == '&' ||
|
||||
c == '=' || c == '+' || c == '$' || c == ',' || c == '#') {
|
||||
escaped << c;
|
||||
} else {
|
||||
escaped << std::uppercase;
|
||||
@@ -5684,7 +5715,8 @@ std::string encode_path_component(const std::string &component) {
|
||||
auto c = static_cast<unsigned char>(component[i]);
|
||||
|
||||
// Unreserved characters per RFC 3986: ALPHA / DIGIT / "-" / "." / "_" / "~"
|
||||
if (std::isalnum(c) || c == '-' || c == '.' || c == '_' || c == '~') {
|
||||
if (detail::is_ascii_alnum(static_cast<char>(c)) || c == '-' || c == '.' ||
|
||||
c == '_' || c == '~') {
|
||||
result += static_cast<char>(c);
|
||||
}
|
||||
// Path-safe sub-delimiters: "!" / "$" / "&" / "'" / "(" / ")" / "*" / "+" /
|
||||
@@ -5757,7 +5789,8 @@ std::string encode_query_component(const std::string &component,
|
||||
auto c = static_cast<unsigned char>(component[i]);
|
||||
|
||||
// Unreserved characters per RFC 3986
|
||||
if (std::isalnum(c) || c == '-' || c == '.' || c == '_' || c == '~') {
|
||||
if (detail::is_ascii_alnum(static_cast<char>(c)) || c == '-' || c == '.' ||
|
||||
c == '_' || c == '~') {
|
||||
result += static_cast<char>(c);
|
||||
}
|
||||
// Space handling
|
||||
@@ -6010,6 +6043,48 @@ size_t MultipartFormData::get_file_count(const std::string &key) const {
|
||||
return static_cast<size_t>(std::distance(r.first, r.second));
|
||||
}
|
||||
|
||||
// Multipart FormData writer implementation
|
||||
bool is_valid_multipart_boundary(const std::string &boundary) {
|
||||
return detail::is_multipart_boundary_chars_valid(boundary);
|
||||
}
|
||||
|
||||
MultipartFormDataWriter::MultipartFormDataWriter()
|
||||
: boundary_(detail::make_multipart_data_boundary()) {}
|
||||
|
||||
MultipartFormDataWriter::MultipartFormDataWriter(std::string boundary)
|
||||
: boundary_(std::move(boundary)) {}
|
||||
|
||||
const std::string &MultipartFormDataWriter::boundary() const {
|
||||
return boundary_;
|
||||
}
|
||||
|
||||
std::string MultipartFormDataWriter::content_type() const {
|
||||
return detail::serialize_multipart_formdata_get_content_type(boundary_);
|
||||
}
|
||||
|
||||
std::string
|
||||
MultipartFormDataWriter::serialize(const UploadFormDataItems &items) const {
|
||||
return detail::serialize_multipart_formdata(items, boundary_);
|
||||
}
|
||||
|
||||
size_t MultipartFormDataWriter::content_length(
|
||||
const UploadFormDataItems &items) const {
|
||||
return detail::get_multipart_content_length(items, boundary_);
|
||||
}
|
||||
|
||||
std::string
|
||||
MultipartFormDataWriter::item_begin(const UploadFormData &item) const {
|
||||
return detail::serialize_multipart_formdata_item_begin(item, boundary_);
|
||||
}
|
||||
|
||||
std::string MultipartFormDataWriter::item_end() {
|
||||
return detail::serialize_multipart_formdata_item_end();
|
||||
}
|
||||
|
||||
std::string MultipartFormDataWriter::finish() const {
|
||||
return detail::serialize_multipart_formdata_finish(boundary_);
|
||||
}
|
||||
|
||||
// Response implementation
|
||||
size_t Response::get_header_value_u64(const std::string &key, size_t def,
|
||||
size_t id) const {
|
||||
@@ -6229,8 +6304,10 @@ ssize_t detail::BodyReader::read(char *buf, size_t len) {
|
||||
}
|
||||
|
||||
// ThreadPool implementation
|
||||
ThreadPool::ThreadPool(size_t n, size_t max_n, size_t mqr)
|
||||
: base_thread_count_(n), max_queued_requests_(mqr), idle_thread_count_(0),
|
||||
ThreadPool::ThreadPool(size_t n, size_t max_n, size_t mqr,
|
||||
time_t idle_timeout_sec)
|
||||
: base_thread_count_(n), max_queued_requests_(mqr),
|
||||
idle_timeout_sec_(idle_timeout_sec), idle_thread_count_(0),
|
||||
shutdown_(false) {
|
||||
#ifndef CPPHTTPLIB_NO_EXCEPTIONS
|
||||
if (max_n != 0 && max_n < n) {
|
||||
@@ -6340,9 +6417,9 @@ void ThreadPool::worker(bool is_dynamic) {
|
||||
idle_thread_count_++;
|
||||
|
||||
if (is_dynamic) {
|
||||
auto has_work = cond_.wait_for(
|
||||
lock, std::chrono::seconds(CPPHTTPLIB_THREAD_POOL_IDLE_TIMEOUT),
|
||||
[&] { return !jobs_.empty() || shutdown_; });
|
||||
auto has_work =
|
||||
cond_.wait_for(lock, std::chrono::seconds(idle_timeout_sec_),
|
||||
[&] { return !jobs_.empty() || shutdown_; });
|
||||
if (!has_work) {
|
||||
// Timed out with no work - exit this dynamic thread
|
||||
idle_thread_count_--;
|
||||
@@ -9687,9 +9764,18 @@ bool ClientImpl::write_request(Stream &strm, Request &req,
|
||||
|
||||
if (!query_part.empty()) {
|
||||
// Normalize the query string (decode then re-encode) while preserving
|
||||
// the original parameter order.
|
||||
auto normalized = detail::normalize_query_string(query_part);
|
||||
if (!normalized.empty()) { path_with_query += '?' + normalized; }
|
||||
// the original parameter order. When path encoding is disabled the
|
||||
// caller has supplied an already-encoded target and expects the exact
|
||||
// bytes to be sent on the wire, so skip normalization for the query
|
||||
// too. Normalizing here would decode-then-re-encode the query and
|
||||
// corrupt pre-encoded binary payloads (e.g. turning `%20` into `+`,
|
||||
// which a strict RFC 3986 server decodes back as `+`, not a space).
|
||||
if (path_encode_) {
|
||||
auto normalized = detail::normalize_query_string(query_part);
|
||||
if (!normalized.empty()) { path_with_query += '?' + normalized; }
|
||||
} else {
|
||||
path_with_query += '?' + query_part;
|
||||
}
|
||||
|
||||
// Still populate req.params for handlers/users who read them.
|
||||
detail::parse_query_text(query_part, req.params);
|
||||
@@ -12518,7 +12604,7 @@ bool is_ipv4_address(const std::string &str) {
|
||||
for (char c : str) {
|
||||
if (c == '.') {
|
||||
dots++;
|
||||
} else if (!isdigit(static_cast<unsigned char>(c))) {
|
||||
} else if (!detail::is_ascii_digit(c)) {
|
||||
return false;
|
||||
}
|
||||
}
|
||||
@@ -12535,7 +12621,7 @@ bool parse_ipv4(const std::string &str, unsigned char *out) {
|
||||
}
|
||||
int val = 0;
|
||||
int digits = 0;
|
||||
while (*p >= '0' && *p <= '9') {
|
||||
while (detail::is_ascii_digit(*p)) {
|
||||
val = val * 10 + (*p - '0');
|
||||
if (val > 255) { return false; }
|
||||
p++;
|
||||
@@ -16487,9 +16573,15 @@ bool WebSocketClient::connect() {
|
||||
return false;
|
||||
}
|
||||
|
||||
#ifdef CPPHTTPLIB_SSL_ENABLED
|
||||
auto is_ssl = is_ssl_;
|
||||
#else
|
||||
auto is_ssl = false;
|
||||
#endif
|
||||
|
||||
std::string selected_subprotocol;
|
||||
if (!detail::perform_websocket_handshake(*strm, host_, port_, path_, headers_,
|
||||
selected_subprotocol)) {
|
||||
if (!detail::perform_websocket_handshake(*strm, host_, port_, is_ssl, path_,
|
||||
headers_, selected_subprotocol)) {
|
||||
shutdown_and_close();
|
||||
return false;
|
||||
}
|
||||
|
||||
Vendored
+56
-12
@@ -8,8 +8,8 @@
|
||||
#ifndef CPPHTTPLIB_HTTPLIB_H
|
||||
#define CPPHTTPLIB_HTTPLIB_H
|
||||
|
||||
#define CPPHTTPLIB_VERSION "0.48.0"
|
||||
#define CPPHTTPLIB_VERSION_NUM "0x003000"
|
||||
#define CPPHTTPLIB_VERSION "0.49.0"
|
||||
#define CPPHTTPLIB_VERSION_NUM "0x003100"
|
||||
|
||||
#ifdef _WIN32
|
||||
#if defined(_WIN32_WINNT) && _WIN32_WINNT < 0x0A00
|
||||
@@ -309,7 +309,6 @@ using socket_t = int;
|
||||
#include <array>
|
||||
#include <atomic>
|
||||
#include <cassert>
|
||||
#include <cctype>
|
||||
#include <chrono>
|
||||
#include <climits>
|
||||
#include <condition_variable>
|
||||
@@ -540,6 +539,21 @@ make_unique(std::size_t n) {
|
||||
return std::unique_ptr<T>(new RT[n]);
|
||||
}
|
||||
|
||||
// Locale-independent ASCII character classification. The <cctype>
|
||||
// counterparts (std::isalnum, std::isdigit, ...) consult the global C locale,
|
||||
// so e.g. std::isalnum(0xC5) can return true once an embedder calls
|
||||
// setlocale(). HTTP grammars are defined over ASCII, so raw bytes must be
|
||||
// classified without regard to the locale.
|
||||
// True only for the ten ASCII decimal digits '0'..'9'; pure range check.
inline bool is_ascii_digit(char c) { return c >= '0' && c <= '9'; }
|
||||
|
||||
// True only for the ASCII letters 'a'..'z' and 'A'..'Z'; pure range check.
inline bool is_ascii_alpha(char c) {
  const bool is_lower = c >= 'a' && c <= 'z';
  const bool is_upper = c >= 'A' && c <= 'Z';
  return is_lower || is_upper;
}
|
||||
|
||||
inline bool is_ascii_alnum(char c) {
|
||||
return is_ascii_digit(c) || is_ascii_alpha(c);
|
||||
}
|
||||
|
||||
namespace case_ignore {
|
||||
|
||||
inline unsigned char to_lower(int c) {
|
||||
@@ -661,7 +675,7 @@ inline from_chars_result<T> from_chars(const char *first, const char *last,
|
||||
for (; p != last; ++p) {
|
||||
char c = *p;
|
||||
int digit = -1;
|
||||
if ('0' <= c && c <= '9') {
|
||||
if (is_ascii_digit(c)) {
|
||||
digit = c - '0';
|
||||
} else if ('a' <= c && c <= 'z') {
|
||||
digit = c - 'a' + 10;
|
||||
@@ -733,14 +747,14 @@ inline from_chars_result<double> from_chars(const char *first, const char *last,
|
||||
return false;
|
||||
};
|
||||
|
||||
for (; p != last && '0' <= *p && *p <= '9'; ++p) {
|
||||
for (; p != last && is_ascii_digit(*p); ++p) {
|
||||
seen_digit = true;
|
||||
accumulate(*p);
|
||||
}
|
||||
|
||||
if (p != last && *p == '.') {
|
||||
++p;
|
||||
for (; p != last && '0' <= *p && *p <= '9'; ++p) {
|
||||
for (; p != last && is_ascii_digit(*p); ++p) {
|
||||
seen_digit = true;
|
||||
if (frac_digits < max_frac_digits && accumulate(*p)) { ++frac_digits; }
|
||||
}
|
||||
@@ -803,8 +817,8 @@ inline bool parse_url(const std::string &url, UrlComponents &uc) {
|
||||
// IPv6 host must be [a-fA-F0-9:]+ only
|
||||
if (uc.host.empty()) { return false; }
|
||||
for (auto c : uc.host) {
|
||||
if (!((c >= 'a' && c <= 'f') || (c >= 'A' && c <= 'F') ||
|
||||
(c >= '0' && c <= '9') || c == ':')) {
|
||||
if (!(is_ascii_digit(c) || (c >= 'a' && c <= 'f') ||
|
||||
(c >= 'A' && c <= 'F') || c == ':')) {
|
||||
return false;
|
||||
}
|
||||
}
|
||||
@@ -1541,7 +1555,9 @@ public:
|
||||
|
||||
class ThreadPool final : public TaskQueue {
|
||||
public:
|
||||
explicit ThreadPool(size_t n, size_t max_n = 0, size_t mqr = 0);
|
||||
explicit ThreadPool(
|
||||
size_t n, size_t max_n = 0, size_t mqr = 0,
|
||||
time_t idle_timeout_sec = CPPHTTPLIB_THREAD_POOL_IDLE_TIMEOUT);
|
||||
ThreadPool(const ThreadPool &) = delete;
|
||||
~ThreadPool() override = default;
|
||||
|
||||
@@ -1556,6 +1572,7 @@ private:
|
||||
size_t base_thread_count_;
|
||||
size_t max_thread_count_;
|
||||
size_t max_queued_requests_;
|
||||
time_t idle_timeout_sec_;
|
||||
size_t idle_thread_count_;
|
||||
|
||||
bool shutdown_;
|
||||
@@ -1680,6 +1697,35 @@ make_multipart_content_provider(const UploadFormDataItems &items,
|
||||
|
||||
} // namespace detail
|
||||
|
||||
bool is_valid_multipart_boundary(const std::string &boundary);
|
||||
|
||||
// Serializer for multipart/form-data request bodies. The boundary is owned
|
||||
// by the writer so that per-part framing and the final terminator always
|
||||
// agree. Field names and filenames are escaped following the WHATWG HTML
|
||||
// standard ('"' -> %22, CR -> %0D, LF -> %0A); CR and LF are also escaped
|
||||
// in content types.
|
||||
class MultipartFormDataWriter {
|
||||
public:
|
||||
MultipartFormDataWriter();
|
||||
// precondition: is_valid_multipart_boundary(boundary)
|
||||
explicit MultipartFormDataWriter(std::string boundary);
|
||||
|
||||
const std::string &boundary() const;
|
||||
std::string content_type() const;
|
||||
|
||||
// In-memory items -> whole body (known length)
|
||||
std::string serialize(const UploadFormDataItems &items) const;
|
||||
size_t content_length(const UploadFormDataItems &items) const;
|
||||
|
||||
// Per-part framing for streaming via a content provider
|
||||
std::string item_begin(const UploadFormData &item) const;
|
||||
static std::string item_end();
|
||||
std::string finish() const;
|
||||
|
||||
private:
|
||||
std::string boundary_;
|
||||
};
|
||||
|
||||
class Server {
|
||||
public:
|
||||
using Handler = std::function<void(const Request &, Response &)>;
|
||||
@@ -2897,9 +2943,7 @@ template <size_t N> inline constexpr size_t str_len(const char (&)[N]) {
|
||||
}
|
||||
|
||||
inline bool is_numeric(const std::string &str) {
|
||||
return !str.empty() &&
|
||||
std::all_of(str.cbegin(), str.cend(),
|
||||
[](unsigned char c) { return std::isdigit(c); });
|
||||
return !str.empty() && std::all_of(str.cbegin(), str.cend(), is_ascii_digit);
|
||||
}
|
||||
|
||||
inline size_t get_header_value_u64(const Headers &headers,
|
||||
|
||||
Reference in New Issue
Block a user