Compare commits

...

15 Commits

Author SHA1 Message Date
lhez 4fc4ec5541 opencl: allow loading precompiled binary kernels from library (#23042)
* opencl: allow loading binary kernel

* opencl: add libdl.h

* ggml-backend-dl is in ggml, which depends backend libs, thus
  ggml-opencl cannot depend on ggml-backend-dl
* add libdl.h to break cyclic dep

* opencl: allow loading bin kernel lib

* opencl: load `gemm_moe_mxfp4_f32_ns` from kernel lib if available

* opencl: load q8_0 gemm from kernel lib

* opencl: load q4_0 moe gemm from kernel lib

* opencl: load q4_1 moe gemm from kernel lib

* opencl: load q4_k moe gemm from kernel lib

* opencl: always declare `get_adreno_bin_kernel_func_t`

* opencl: rephrase message

* opencl: fix for rebase

* opencl: update doc
2026-07-01 10:29:22 -07:00
Adrien Gallouët a6647b1a32 common : use hf primary split as model path (#25194)
Fixes #25181
2026-07-01 18:33:00 +02:00
Max Krasnyansky 13e673863b hexagon: flash attention rework (optimizations, accuracy improvements, etc) (#25085)
* hex-mm: fold mm quant tasks into the main matmul threads

* hex-mm: minor formatting fixes

* hex-mm: cleanup is_quant checks in dma dispatch

* hex-mm: fix dst-spad alignment

* hex-mm: move fp kernels in the hvx-mm-kernels header

* hex-mm: fuse with ADD

* hex-fa: factor out ukernels into separate headers and unify the rest

* hex-fa: move kernel-params compute into the host

* hex-fa: refactor vtcm alloc for consistency

* hex-fa: add support for FA_SELECT

* hex-fa: update tracing insrumentation to cover all functions

* hex-fa: update hvx fallback thresholds to recover t/g regressions

* hex-fa: update tracing instrumentation

* hex-fa: improved tracing with additional events

* hex-fa: optimize mask processing (fastdiv, etc)

* hex-fa: improve mask dma caching

* hmx-fa: change loop order to maximize mask cache hits

* hex-fa: remove over instrumentation

* hex-fa: breakdown QKV prep trace events

* hmx-fa: further mask proc optimizations

* hex-fa: mask broadcast is the common case, optimize for that

* hex-fa: use aligned loads where possible

* hex-fa: update loops to use uint32_t indices

* hmx-fa: fold vtcm init into q prep task

* hex-fa: update rest of the hmx funcs to use uint32_t

* hmx-fa: fold build_d into the main softmax loop

* hmx-fa: start kv dmas earlier

* hmx-fa: start mask dma a bit earlier

* hex-fa: precompute rows per task to avoid divs

* hmx-fa: specialize fa_o_store for f16 and f32

* hmx-fa: prelim support for Sinks

* hmx-fa: keep softmax accumulators in fp32

* hex-fa: add tanh_f16 and exp2_f16 and use that in FA

* hex-fa: use fp16 math in the hvx kernel

* hex-fa: avoid expensive float -> __fp16 cast for slopes and softcap

* hex-fa: replace most vec_exp_f32 with vec_exp2_f16

* hmx-fa: vectorize sinks update

* hex-fa: minor formatting

* hmx-fa: fold softcap loop into the tile load

* hmx-fa: use vectoralias to populate sinks

* hex-fa: remove redudant check

* hex-fa: fix vtcm size compute to use fp32 for accumulators

* hex-mm: fix trailing spaces

* hmx-fa: dont use -inf to init mask to avoid conversion overflows

* hex-fa: no need to explicitly guard -inf in the f16->f32 converter now

* hmx-fa: cleanup fa sinks handling

* hex-mm: fixed src2 stride handling when mm is fused with add

* hex-fa: make lto happy
2026-07-01 06:59:19 -07:00
Johannes Gäßler b820cc8e6f CUDA: consistent use of __restrict__ + PDL for FA (#25185) 2026-07-01 10:55:14 +02:00
ragz4125 6dbc1174b8 ggml-cpu: add AVX2 optimization for nvfp4 dot product and use UE4M3 LUT (#23961) 2026-07-01 15:31:20 +08:00
Aleksander Grygier 9d88e7cedd ui Prevent tool messages from incorrectly appending to other conversations (#25177)
* fix: Prevent tool messages from incorrectly appending to other conversations

* ui: prevent agentic loop from poisoning another conv's currNode

* ui: make editedContent a  so background recompute does not wipe in-progress edits

---------

Co-authored-by: Pascal <admin@serveurperso.com>
2026-07-01 09:25:18 +02:00
Aleksander Grygier 7af4279f45 ui: Remove PWA navigate fallback to prevent caching API endpoint requests (#25174) 2026-07-01 07:32:55 +02:00
lhez fd1a05791d opencl: initial q1_0 support (#25160)
* opencl: general q1_0 support

* opencl: add Adreno GEMM/GEMV for q1_0
2026-06-30 21:43:20 -07:00
fairydreaming 0eca4d490e cuda : prevent integer truncation and overflow errors when using KQ mask strides in flash_attn_mask_to_KV_max kernel (#24945)
Co-authored-by: Stanisław Szymczyk <sszymczy@gmail.com>
2026-06-30 20:47:05 +02:00
Jürgen Schmied 4f31eedb0c model : register t_layer_inp for qwen3next (#25141)
* Fix input assignment in layer processing loop

Fix DFLASH for qwen-coder-next

* add line break

Added tensor for attention normalization in Qwen3 model.
2026-06-30 17:57:14 +02:00
Pascal 799fcc04a5 common,server: handle bracketed IPv6 literals in URL authority (#25140)
* common,server: handle bracketed IPv6 literals in URL authority

Parse the [host]:port form (RFC 3986) and bracket IPv6 hosts when
formatting a URL authority: listening log, proxy Host header, proxy
log, client rebuild. The per-request remote_addr stays bare.

* common: restore unsupported scheme throw in url parser

Address @ngxson review: keep the explicit reject in port resolution so
the block stays self-contained. Non-http(s) schemes still throw (also
gated at the top of common_http_parse_url).
2026-06-30 16:16:44 +02:00
Matt Jallo 931eb37f8c CUDA: fix get_rows_back for tables with more than 65535 rows (grid-y clamp + stride) (#25103) 2026-06-30 14:16:24 +02:00
Johannes Gäßler e495d1e748 CUDA: fix Gemma E4B MTP FlashAttention (#25148)
* CUDA: fix Gemma E4B MTP FlashAttention

* remove unused template declaration
2026-06-30 14:06:54 +02:00
Kevin Liu f708a5b2ca vulkan: roll bk loop in matmul for asahi linux (#24663)
* vulkan: roll bk loop in matmul for asahi linux

* vulkan: fix inline comment

* vulkan: revert BK-loop unroll change

* vulkan: edit spirv directly for asahi roll bk loop

* vulkan: remove trailing whitespace at the end of comments
2026-06-30 12:27:38 +02:00
zduford d9df11006f HIP: use hipBLAS for dense prefill on gfx900, keep MMQ for MoE (#24588)
* HIP: keep MMQ for gfx900 MoE and Q8_0, use hipBLAS for dense K-quants

Assisted-by: GitHub Copilot CLI

* HIP: tighten conditional block to be explicitly for gfx900

* HIP: Further simplified gfx900 conditional block

* removed unnecessary comment
2026-06-30 11:51:38 +02:00
63 changed files with 6003 additions and 3143 deletions
+10 -8
View File
@@ -496,13 +496,15 @@ void common_models_handler_apply(common_models_handler & handler, common_params
}
// handle hf_plan tasks
auto add_tasks = [&opts, &tasks](const hf_cache::hf_files & model_files, common_params_model & model) {
auto add_tasks = [&opts, &tasks](const hf_cache::hf_files & model_files,
const hf_cache::hf_file & primary,
common_params_model & model) {
for (size_t i = 0; i < model_files.size(); ++i) {
auto & model_file = model_files[i];
bool is_first = (i == 0);
tasks.emplace_back(model_file, opts, [&, is_first]() {
if (is_first) {
// only use first part as model path
bool is_primary = (model_file.path == primary.path);
tasks.emplace_back(model_file, opts, [&, is_primary]() {
if (is_primary) {
// the primary file is the first split (00001-of), use it as model path
model.path = hf_cache::finalize_file(model_file);
} else {
hf_cache::finalize_file(model_file);
@@ -511,7 +513,7 @@ void common_models_handler_apply(common_models_handler & handler, common_params
}
};
if (!plan.model_files.empty()) {
add_tasks(plan.model_files, params.model);
add_tasks(plan.model_files, plan.primary, params.model);
}
if (!plan.mmproj.local_path.empty()) {
tasks.emplace_back(plan.mmproj, opts, [&]() {
@@ -539,12 +541,12 @@ void common_models_handler_apply(common_models_handler & handler, common_params
// handle plan_spec (e.g. --spec-draft-hf)
if (!plan_spec.model_files.empty()) {
add_tasks(plan_spec.model_files, params.speculative.draft.mparams);
add_tasks(plan_spec.model_files, plan_spec.primary, params.speculative.draft.mparams);
}
// handle vocoder plan (e.g. --hf-repo-v)
if (!plan_voc.model_files.empty()) {
add_tasks(plan_voc.model_files, params.vocoder.model);
add_tasks(plan_voc.model_files, plan_voc.primary, params.vocoder.model);
}
// run all tasks in parallel
+28 -6
View File
@@ -11,6 +11,11 @@ struct common_http_url {
std::string path;
};
// bracket an IPv6 literal host for a URL authority (RFC 3986)
static std::string common_http_format_host(const std::string & host) {
return host.find(':') != std::string::npos ? "[" + host + "]" : host;
}
static common_http_url common_http_parse_url(const std::string & url) {
common_http_url parts;
auto scheme_end = url.find("://");
@@ -49,11 +54,28 @@ static common_http_url common_http_parse_url(const std::string & url) {
parts.path = "/";
}
auto colon_pos = parts.host.find(':');
// split the authority into host and optional port, a bracketed IPv6 literal keeps its inner colons (RFC 3986)
std::string port_str;
if (!parts.host.empty() && parts.host.front() == '[') {
auto close = parts.host.find(']');
if (close == std::string::npos) {
throw std::runtime_error("invalid IPv6 URL authority: " + parts.host);
}
auto after = parts.host.substr(close + 1);
if (!after.empty() && after.front() == ':') {
port_str = after.substr(1);
}
parts.host = parts.host.substr(1, close - 1);
} else {
auto colon_pos = parts.host.find(':');
if (colon_pos != std::string::npos) {
port_str = parts.host.substr(colon_pos + 1);
parts.host = parts.host.substr(0, colon_pos);
}
}
if (colon_pos != std::string::npos) {
parts.port = std::stoi(parts.host.substr(colon_pos + 1));
parts.host = parts.host.substr(0, colon_pos);
if (!port_str.empty()) {
parts.port = std::stoi(port_str);
} else if (parts.scheme == "http") {
parts.port = 80;
} else if (parts.scheme == "https") {
@@ -83,7 +105,7 @@ static std::pair<httplib::Client, common_http_url> common_http_client(const std:
}
#endif
httplib::Client cli(parts.scheme + "://" + parts.host + ":" + std::to_string(parts.port));
httplib::Client cli(parts.scheme + "://" + common_http_format_host(parts.host) + ":" + std::to_string(parts.port));
if (!parts.user.empty()) {
cli.set_basic_auth(parts.user, parts.password);
@@ -95,5 +117,5 @@ static std::pair<httplib::Client, common_http_url> common_http_client(const std:
}
static std::string common_http_show_masked_url(const common_http_url & parts) {
return parts.scheme + "://" + (parts.user.empty() ? "" : "****:****@") + parts.host + parts.path;
return parts.scheme + "://" + (parts.user.empty() ? "" : "****:****@") + common_http_format_host(parts.host) + parts.path;
}
+51 -39
View File
@@ -1,16 +1,26 @@
# llama.cpp for OpenCL
- [Background](#background)
- [OS](#os)
- [Hardware](#hardware)
- [DataType Supports](#datatype-supports)
- [Model Preparation](#model-preparation)
- [CMake Options](#cmake-options)
- [Android](#android)
- [Windows 11 Arm64](#windows-11-arm64)
- [Linux](#Linux)
- [Known Issue](#known-issues)
- [TODO](#todo)
- [llama.cpp for OpenCL](#llamacpp-for-opencl)
- [Background](#background)
- [Llama.cpp + OpenCL](#llamacpp--opencl)
- [OS](#os)
- [Hardware](#hardware)
- [Adreno GPU](#adreno-gpu)
- [DataType Supports](#datatype-supports)
- [Model Preparation](#model-preparation)
- [Binary Kernel Library](#binary-kernel-library)
- [CMake Options](#cmake-options)
- [Android](#android)
- [I. Setup Environment](#i-setup-environment)
- [II. Build llama.cpp](#ii-build-llamacpp)
- [Windows 11 Arm64](#windows-11-arm64)
- [I. Setup Environment](#i-setup-environment-1)
- [II. Build llama.cpp](#ii-build-llamacpp-1)
- [Linux](#linux)
- [I. Setup Environment](#i-setup-environment-2)
- [II. Build llama.cpp](#ii-build-llamacpp-2)
- [Known Issues](#known-issues)
- [TODO](#todo)
## Background
@@ -34,11 +44,13 @@ The llama.cpp OpenCL backend is designed to enable llama.cpp on **Qualcomm Adren
**Verified devices**
| Adreno GPU | Status |
|:------------------------------------:|:-------:|
| Adreno 750 (Snapdragon 8 Gen 3) | Support |
| Adreno 830 (Snapdragon 8 Elite) | Support |
| Adreno X85 (Snapdragon X Elite) | Support |
| Adreno GPU | Status |
|:-------------------------------------:|:-------:|
| Adreno 750 (Snapdragon 8 Gen 3) | Support |
| Adreno 830 (Snapdragon 8 Elite) | Support |
| Adreno 840 (Snapdragon 8 Elite Gen 5) | Support |
| Adreno X1-85 (Snapdragon X Elite) | Support |
| Adreno X2-90 (Snapdragon X2 Elite) | Support |
> A6x GPUs with a recent driver and compiler are supported; they are usually found in IoT platforms.
However, A6x GPUs in phones are likely not supported due to the outdated driver and compiler.
@@ -47,42 +59,43 @@ However, A6x GPUs in phones are likely not supported due to the outdated driver
| DataType | Status |
|:----------------------:|:--------------------------:|
| Q1_0 | Support |
| Q4_0 | Support |
| Q6_K | Support, but not optimized |
| Q4_1 | Support |
| Q5_0 | Support |
| Q5_1 | Support |
| Q8_0 | Support |
| Q4_K | Support |
| Q5_K | Support |
| Q6_K | Support |
| MXFP4 | Support |
| IQ4_NL | Support |
## Model Preparation
You can refer to the general [llama-quantize tool](/tools/quantize/README.md) for steps to convert a model in Hugging Face safetensor format to GGUF with quantization.
Since common quantizations are supported now, it is recommanded to download GGUF models directly from Huggingface.
Currently we support `Q4_0` quantization and have optimized for it. To achieve best performance on Adreno GPU, add `--pure` to `llama-quantize` (i.e., make all weights in `Q4_0`). For example,
## Binary Kernel Library
```sh
./llama-quantize --pure ggml-model-qwen2.5-3b-f16.gguf ggml-model-qwen-3b-Q4_0.gguf Q4_0
```
A prebuilt binary kernel library has been introduced for Adreno GPUs.
It currently targets X2 GPUs (X2-90, X2-85 and X2-45) found in Snapdragon X2 SoC.
The library currently contains kernels for MUL_MAT_ID with Q4_0, Q4_1, Q4_K, MXFP4.
The library must be manually downloaded from https://softwarecenter.qualcomm.com/catalog/item/Adreno_Kernel_Library_GGML.
Since `Q6_K` is also supported, `Q4_0` quantization without `--pure` will also work. However, the performance will be worse compared to pure `Q4_0` quantization.
To allow using the kernel library, add `-DGGML_OPENCL_USE_ADRENO_BIN_KERNELS=ON` when configuring with CMake.
Then, extract `adreno-opencl-kernels.dll` from the zip file downloaded from the above URL and put it alongside the executables.
If kernels compatible with the current GPU are found in the library, they will be loaded and used.
### `MXFP4` MoE Models
OpenAI gpt-oss models are MoE models in `MXFP4`. The quantized model will be in `MXFP4_MOE`, a mixture of `MXFP4` and `Q8_0`.
For this quantization, there is no need to specify `--pure`.
For gpt-oss-20b model, you can directly [download](https://huggingface.co/ggml-org/gpt-oss-20b-GGUF) the quantized GGUF file in `MXFP4_MOE` from Hugging Face.
Although it is possible to quantize gpt-oss-20b model in pure `Q4_0` (all weights in `Q4_0`), it is not recommended since `MXFP4` has been optimized for MoE while `Q4_0` is not. In addition, accuracy should degrade with such pure `Q4_0` quantization.
Hence, using the default `MXFP4_MOE` quantization (see the link above) is recommended for this model.
> Note that the `Q4_0` model found [here](https://huggingface.co/unsloth/gpt-oss-20b-GGUF/blob/main/gpt-oss-20b-Q4_0.gguf) is a mixture of `Q4_0`, `Q8_0` and `MXFP4` and gives better performance than `MXFP4_MOE` quantization.
## CMake Options
The OpenCL backend has the following CMake options that control the behavior of the backend.
| CMake options | Default value | Description |
|:---------------------------------:|:--------------:|:------------------------------------------|
| `GGML_OPENCL_EMBED_KERNELS` | `ON` | Embed OpenCL kernels into the executable. |
| `GGML_OPENCL_USE_ADRENO_KERNELS` | `ON` | Use kernels optimized for Adreno. |
| CMake options | Default value | Description |
|:------------------------------------:|:--------------:|:------------------------------------------|
| `GGML_OPENCL_EMBED_KERNELS` | `ON` | Embed OpenCL kernels into the executable. |
| `GGML_OPENCL_USE_ADRENO_KERNELS` | `ON` | Use kernels optimized for Adreno. |
| `GGML_OPENCL_USE_ADRENO_BIN_KERNELS` | `OFF` | Allow using binary kernel lib for Adreno. |
## Android
@@ -277,6 +290,5 @@ ninja
## TODO
- Optimization for Q6_K
- Support and optimization for Q4_K
- Improve flash attention
- Improve OpenCL C kernels performance
+3 -2
View File
@@ -1111,11 +1111,12 @@ GGML_TABLE_BEGIN(int8_t, kvalues_iq4nl, 16)
-127, -104, -83, -65, -49, -35, -22, -10, 1, 13, 25, 38, 53, 69, 89, 113,
GGML_TABLE_END()
// e2m1 values (doubled)
// e2m1 values (doubled), shared by MXFP4 and NVFP4
// ref: https://www.opencompute.org/documents/ocp-microscaling-formats-mx-v1-0-spec-final-pdf
GGML_TABLE_BEGIN(int8_t, kvalues_mxfp4, 16)
GGML_TABLE_BEGIN(int8_t, kvalues_fp4, 16)
0, 1, 2, 3, 4, 6, 8, 12, 0, -1, -2, -3, -4, -6, -8, -12,
GGML_TABLE_END()
#define kvalues_mxfp4 kvalues_fp4
#define NGRID_IQ1S 2048
#define IQ1S_DELTA 0.125f
-1
View File
@@ -82,7 +82,6 @@
#define ggml_gemm_q2_K_8x8_q8_K_generic ggml_gemm_q2_K_8x8_q8_K
#elif defined(__x86_64__) || defined(__i386__) || defined(_M_IX86) || defined(_M_X64)
// quants.c
#define ggml_vec_dot_nvfp4_q8_0_generic ggml_vec_dot_nvfp4_q8_0
// repack.cpp
#define ggml_quantize_mat_q8_0_4x4_generic ggml_quantize_mat_q8_0_4x4
#define ggml_quantize_mat_q8_K_4x4_generic ggml_quantize_mat_q8_K_4x4
+142 -4
View File
@@ -934,7 +934,7 @@ void ggml_vec_dot_mxfp4_q8_0(int n, float * GGML_RESTRICT s, size_t bs, const vo
#if defined __AVX2__
const __m128i values128 = _mm_loadu_si128((const __m128i*)kvalues_mxfp4);
const __m128i values128 = _mm_loadu_si128((const __m128i*)kvalues_fp4);
const __m128i m4b = _mm_set1_epi8(0x0f);
const __m256i mone = _mm256_set1_epi16(1);
@@ -963,7 +963,7 @@ void ggml_vec_dot_mxfp4_q8_0(int n, float * GGML_RESTRICT s, size_t bs, const vo
sumf = hsum_float_8(_mm256_add_ps(accum1, accum2));
#elif defined __AVX__
const __m128i values128 = _mm_loadu_si128((const __m128i*)kvalues_mxfp4);
const __m128i values128 = _mm_loadu_si128((const __m128i*)kvalues_fp4);
const __m128i m4b = _mm_set1_epi8(0x0f);
__m256 accum = _mm256_setzero_ps();
@@ -993,14 +993,152 @@ void ggml_vec_dot_mxfp4_q8_0(int n, float * GGML_RESTRICT s, size_t bs, const vo
int sumi1 = 0;
int sumi2 = 0;
for (int j = 0; j < QK_MXFP4/2; ++j) {
sumi1 += y[ib].qs[j + 0] * kvalues_mxfp4[x[ib].qs[j] & 0xf];
sumi2 += y[ib].qs[j + QK_MXFP4/2] * kvalues_mxfp4[x[ib].qs[j] >> 4];
sumi1 += y[ib].qs[j + 0] * kvalues_fp4[x[ib].qs[j] & 0xf];
sumi2 += y[ib].qs[j + QK_MXFP4/2] * kvalues_fp4[x[ib].qs[j] >> 4];
}
sumf += d * (sumi1 + sumi2);
}
*s = sumf;
}
void ggml_vec_dot_nvfp4_q8_0(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, size_t bx, const void * GGML_RESTRICT vy, size_t by, int nrc) {
assert(nrc == 1);
UNUSED(nrc);
UNUSED(bx);
UNUSED(by);
UNUSED(bs);
assert(n % QK_NVFP4 == 0);
const block_nvfp4 * GGML_RESTRICT x = vx;
const block_q8_0 * GGML_RESTRICT y = vy;
const int nb = n / QK_NVFP4;
int ib = 0;
float sumf = 0;
#if defined(__AVX2__)
const __m128i values128 = _mm_loadu_si128((const __m128i*)kvalues_fp4);
const __m128i m4b = _mm_set1_epi8(0x0f);
const __m256i mone = _mm256_set1_epi16(1);
__m256 accum = _mm256_setzero_ps();
for(; ib < nb; ib++){
const __m128i q4bits_01 = _mm_loadu_si128((const __m128i *)(x[ib].qs + 0));
const __m128i q4bits_23 = _mm_loadu_si128((const __m128i *)(x[ib].qs + 16));
const __m256i q8_01 = _mm256_loadu_si256((const __m256i *)y[2*ib + 0].qs);
const __m256i q8_23 = _mm256_loadu_si256((const __m256i *)y[2*ib + 1].qs);
const __m128i q4_01_lo = _mm_shuffle_epi8(values128, _mm_and_si128(q4bits_01, m4b));
const __m128i q4_01_hi = _mm_shuffle_epi8(values128, _mm_and_si128(_mm_srli_epi16(q4bits_01, 4), m4b));
const __m128i q4_23_lo = _mm_shuffle_epi8(values128, _mm_and_si128(q4bits_23, m4b));
const __m128i q4_23_hi = _mm_shuffle_epi8(values128, _mm_and_si128(_mm_srli_epi16(q4bits_23, 4), m4b));
//reordering
const __m256i q4_01 = MM256_SET_M128I(_mm_unpackhi_epi64(q4_01_lo,q4_01_hi), _mm_unpacklo_epi64(q4_01_lo,q4_01_hi));
const __m256i q4_23 = MM256_SET_M128I(_mm_unpackhi_epi64(q4_23_lo,q4_23_hi),_mm_unpacklo_epi64(q4_23_lo,q4_23_hi));
const __m256i p01 = mul_add_epi8(q4_01,q8_01);
const __m256i p_1 = _mm256_madd_epi16(p01, mone);
const __m256i p23 = mul_add_epi8(q4_23,q8_23);
const __m256i p_2 = _mm256_madd_epi16(p23, mone);
const float dy0 = GGML_CPU_FP16_TO_FP32(y[2*ib].d);
const float dy1 = GGML_CPU_FP16_TO_FP32(y[2*ib+1].d);
const float s0 = GGML_CPU_UE4M3_TO_FP32(x[ib].d[0]) * dy0;
const float s1 = GGML_CPU_UE4M3_TO_FP32(x[ib].d[1]) * dy0;
const float s2 = GGML_CPU_UE4M3_TO_FP32(x[ib].d[2]) * dy1;
const float s3 = GGML_CPU_UE4M3_TO_FP32(x[ib].d[3]) * dy1;
const __m256 scales01 = _mm256_set_m128(_mm_set1_ps(s1), _mm_set1_ps(s0));
const __m256 scales23 = _mm256_set_m128(_mm_set1_ps(s3), _mm_set1_ps(s2));
accum = _mm256_fmadd_ps(scales01, _mm256_cvtepi32_ps(p_1), accum);
accum = _mm256_fmadd_ps(scales23, _mm256_cvtepi32_ps(p_2), accum);
}
sumf = hsum_float_8(accum);
#elif defined(__AVX__)
const __m128i values128 = _mm_loadu_si128((const __m128i*)kvalues_fp4);
const __m128i m4b = _mm_set1_epi8(0x0f);
__m256 accum = _mm256_setzero_ps();
for(; ib < nb; ib++){
const __m128i q4bits_01 = _mm_loadu_si128((const __m128i *)(x[ib].qs + 0));
const __m128i q4bits_23 = _mm_loadu_si128((const __m128i *)(x[ib].qs + 16));
const __m128i q8_0 = _mm_loadu_si128((const __m128i *)(y[2*ib + 0].qs + 0));
const __m128i q8_1 = _mm_loadu_si128((const __m128i *)(y[2*ib + 0].qs + 16));
const __m128i q8_2 = _mm_loadu_si128((const __m128i *)(y[2*ib + 1].qs + 0));
const __m128i q8_3 = _mm_loadu_si128((const __m128i *)(y[2*ib + 1].qs + 16));
const __m128i q4_01_lo = _mm_shuffle_epi8(values128, _mm_and_si128(q4bits_01, m4b));
const __m128i q4_01_hi = _mm_shuffle_epi8(values128, _mm_and_si128(_mm_srli_epi16(q4bits_01, 4), m4b));
const __m128i q4_23_lo = _mm_shuffle_epi8(values128, _mm_and_si128(q4bits_23, m4b));
const __m128i q4_23_hi = _mm_shuffle_epi8(values128, _mm_and_si128(_mm_srli_epi16(q4bits_23, 4), m4b));
const __m128i q4_0 = _mm_unpacklo_epi64(q4_01_lo, q4_01_hi);
const __m128i q4_1 = _mm_unpackhi_epi64(q4_01_lo, q4_01_hi);
const __m128i q4_2 = _mm_unpacklo_epi64(q4_23_lo, q4_23_hi);
const __m128i q4_3 = _mm_unpackhi_epi64(q4_23_lo, q4_23_hi);
const __m128i p0_i32 = mul_sum_i8_pairs(q4_0, q8_0);
const __m128i p1_i32 = mul_sum_i8_pairs(q4_1, q8_1);
const __m128i p2_i32 = mul_sum_i8_pairs(q4_2, q8_2);
const __m128i p3_i32 = mul_sum_i8_pairs(q4_3, q8_3);
const __m128 p0 = _mm_cvtepi32_ps(p0_i32);
const __m128 p1 = _mm_cvtepi32_ps(p1_i32);
const __m128 p2 = _mm_cvtepi32_ps(p2_i32);
const __m128 p3 = _mm_cvtepi32_ps(p3_i32);
const __m256 p01 = _mm256_set_m128(p1, p0);
const __m256 p23 = _mm256_set_m128(p3, p2);
const float dy0 = GGML_CPU_FP16_TO_FP32(y[2*ib].d);
const float dy1 = GGML_CPU_FP16_TO_FP32(y[2*ib+1].d);
const float s0 = GGML_CPU_UE4M3_TO_FP32(x[ib].d[0]) * dy0;
const float s1 = GGML_CPU_UE4M3_TO_FP32(x[ib].d[1]) * dy0;
const float s2 = GGML_CPU_UE4M3_TO_FP32(x[ib].d[2]) * dy1;
const float s3 = GGML_CPU_UE4M3_TO_FP32(x[ib].d[3]) * dy1;
const __m256 scales01 = _mm256_set_m128(_mm_set1_ps(s1), _mm_set1_ps(s0));
const __m256 scales23 = _mm256_set_m128(_mm_set1_ps(s3), _mm_set1_ps(s2));
accum = _mm256_add_ps(accum, _mm256_mul_ps(p01, scales01));
accum = _mm256_add_ps(accum, _mm256_mul_ps(p23, scales23));
}
sumf = hsum_float_8(accum);
#endif
for (;ib < nb; ++ib) {
for (int s_idx = 0; s_idx < 4; ++s_idx) {
const float d = GGML_CPU_UE4M3_TO_FP32(x[ib].d[s_idx]);
const int q8_block = s_idx / 2;
const int q8_off = (s_idx % 2) * QK_NVFP4_SUB;
const float dy = GGML_CPU_FP16_TO_FP32(y[2*ib + q8_block].d);
int sumi_lo = 0, sumi_hi = 0;
for (int j = 0; j < QK_NVFP4_SUB/2; ++j) {
const uint8_t qv = x[ib].qs[s_idx*(QK_NVFP4_SUB/2) + j];
sumi_lo += y[2*ib + q8_block].qs[q8_off + j + 0] * kvalues_fp4[qv & 0xf];
sumi_hi += y[2*ib + q8_block].qs[q8_off + j + QK_NVFP4_SUB/2] * kvalues_fp4[qv >> 4];
}
sumf += dy * d * (sumi_lo + sumi_hi);
}
}
*s = sumf;
}
void ggml_vec_dot_q5_0_q8_0(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, size_t bx, const void * GGML_RESTRICT vy, size_t by, int nrc) {
const int qk = QK8_0;
const int nb = n / qk;
+8
View File
@@ -82,6 +82,9 @@ float ggml_table_f32_f16[1 << 16];
// precomputed f32 table for e8m0 half (1 KB) (simd-mappings.h)
float ggml_table_f32_e8m0_half[1 << 8];
// precomputed f32 table for ue4m3 (1 KB) (simd-mappings.h)
float ggml_table_f32_ue4m3[1 << 8];
#if defined(__ARM_ARCH)
struct ggml_arm_arch_features_type {
int sve_cnt;
@@ -3798,6 +3801,11 @@ void ggml_cpu_init(void) {
ggml_table_f32_e8m0_half[i] = GGML_E8M0_TO_FP32_HALF(i);
}
// initialize UE4M3 table (256 entries)
for (int i = 0; i < (1 << 8); ++i) {
ggml_table_f32_ue4m3[i] = ggml_ue4m3_to_fp32(i);
}
const uint64_t t_end = ggml_time_us(); UNUSED(t_end);
GGML_PRINT_DEBUG("%s: GELU, Quick GELU, SILU and EXP tables initialized in %f ms\n", __func__, (t_end - t_start)/1000.0);
+11
View File
@@ -120,6 +120,10 @@ extern float ggml_table_f32_f16[1 << 16];
// defined in ggml-cpu.c, initialized in ggml_cpu_init()
extern float ggml_table_f32_e8m0_half[1 << 8];
// precomputed f32 table for ue4m3 (1 KB)
// defined in ggml-cpu.c, initialized in ggml_cpu_init()
extern float ggml_table_f32_ue4m3[1 << 8];
// Use lookup table for E8M0 on x86 (faster than bit manipulation)
#if defined(__AVX__) || defined(__AVX2__) || defined(__AVX512F__)
#define GGML_CPU_E8M0_TO_FP32_HALF(x) ggml_table_f32_e8m0_half[(uint8_t)(x)]
@@ -127,6 +131,13 @@ extern float ggml_table_f32_e8m0_half[1 << 8];
#define GGML_CPU_E8M0_TO_FP32_HALF(x) GGML_E8M0_TO_FP32_HALF(x)
#endif
// Use lookup table for UE4M3 on x86 (faster than bit manipulation)
#if defined(__AVX__) || defined(__AVX2__) || defined(__AVX512F__)
#define GGML_CPU_UE4M3_TO_FP32(x) ggml_table_f32_ue4m3[(uint8_t)(x)]
#else
#define GGML_CPU_UE4M3_TO_FP32(x) ggml_ue4m3_to_fp32(x)
#endif
// On ARM NEON, it's quicker to directly convert x -> x instead of calling into ggml_lookup_fp16_to_fp32,
// so we define GGML_CPU_FP16_TO_FP32 and GGML_CPU_FP32_TO_FP16 elsewhere for NEON.
// This is also true for POWER9.
+9 -5
View File
@@ -664,7 +664,10 @@ constexpr __device__ dequantize_V_t get_dequantize_V() {
template <int ncols1>
__launch_bounds__(FATTN_KQ_STRIDE/2, 1)
static __global__ void flash_attn_mask_to_KV_max(
const half2 * __restrict__ mask, int * __restrict__ KV_max, const int ne30, const int s31, const int s33) {
const half2 * mask_ptr, int * KV_max_ptr, const int ne30, const int64_t s31, const int64_t s33) {
const half2 * GGML_CUDA_RESTRICT mask = mask_ptr;
int * GGML_CUDA_RESTRICT KV_max = KV_max_ptr;
const int ne31 = gridDim.x;
const int tid = threadIdx.x;
const int sequence = blockIdx.y;
@@ -1089,8 +1092,8 @@ void launch_fattn(
// Only worth the overhead if there is at lease one FATTN_KQ_STRIDE x FATTN_KQ_STRIDE square to be skipped or
// multiple sequences of possibly different lengths.
if (mask && K->ne[1] % FATTN_KQ_STRIDE == 0 && (Q->ne[1] >= 1024 || Q->ne[3] > 1)) {
const int s31 = mask->nb[1] / sizeof(half2);
const int s33 = mask->nb[3] / sizeof(half2);
const int64_t s31 = mask->nb[1] / sizeof(half2);
const int64_t s33 = mask->nb[3] / sizeof(half2);
const dim3 blocks_num_KV_max(ntiles_x, Q->ne[3], 1);
const dim3 block_dim_KV_max(FATTN_KQ_STRIDE/2, 1, 1);
@@ -1099,8 +1102,9 @@ void launch_fattn(
const int iter_k = K->ne[1] / FATTN_KQ_STRIDE;
KV_max.alloc(ne_KV_max);
flash_attn_mask_to_KV_max<ncols1><<<blocks_num_KV_max, block_dim_KV_max, 0, main_stream>>>
((const half2 *) mask->data, KV_max.ptr, iter_k, s31, s33);
ggml_cuda_kernel_launch_params launch_params = ggml_cuda_kernel_launch_params(blocks_num_KV_max, block_dim_KV_max, 0, main_stream);
ggml_cuda_kernel_launch(flash_attn_mask_to_KV_max<ncols1>, launch_params,
(const half2 *) mask->data, KV_max.ptr, iter_k, s31, s33);
CUDA_CHECK(cudaGetLastError());
}
+4
View File
@@ -2003,6 +2003,10 @@ DECL_FATTN_MMA_F16_CASE_ALL_NCOLS2(112, 112, 64)
DECL_FATTN_MMA_F16_CASE_ALL_NCOLS2(128, 128, 64)
DECL_FATTN_MMA_F16_CASE_ALL_NCOLS2(256, 256, 64)
extern DECL_FATTN_MMA_F16_CASE(512, 512, 4, 2);
extern DECL_FATTN_MMA_F16_CASE(512, 512, 8, 2);
extern DECL_FATTN_MMA_F16_CASE(512, 512, 16, 2);
extern DECL_FATTN_MMA_F16_CASE(512, 512, 32, 2);
extern DECL_FATTN_MMA_F16_CASE(512, 512, 2, 4);
extern DECL_FATTN_MMA_F16_CASE(512, 512, 4, 4);
extern DECL_FATTN_MMA_F16_CASE(512, 512, 8, 4);
+9 -5
View File
@@ -76,6 +76,7 @@ static constexpr __host__ __device__ uint32_t ggml_cuda_fattn_tile_get_config_nv
GGML_CUDA_FATTN_TILE_CONFIG_CASE(320, 256, 16, 256, 2, 64, 64)
GGML_CUDA_FATTN_TILE_CONFIG_CASE(512, 512, 2, 64, 2, 64, 64)
GGML_CUDA_FATTN_TILE_CONFIG_CASE(512, 512, 4, 128, 2, 64, 64)
GGML_CUDA_FATTN_TILE_CONFIG_CASE(512, 512, 8, 256, 2, 64, 64)
GGML_CUDA_FATTN_TILE_CONFIG_CASE(512, 512, 16, 256, 2, 64, 64)
@@ -144,6 +145,7 @@ static constexpr __host__ __device__ uint32_t ggml_cuda_fattn_tile_get_config_nv
GGML_CUDA_FATTN_TILE_CONFIG_CASE(320, 256, 16, 256, 2, 32, 64)
GGML_CUDA_FATTN_TILE_CONFIG_CASE(512, 512, 2, 64, 2, 32, 64)
GGML_CUDA_FATTN_TILE_CONFIG_CASE(512, 512, 4, 128, 2, 32, 64)
GGML_CUDA_FATTN_TILE_CONFIG_CASE(512, 512, 8, 256, 2, 32, 64)
GGML_CUDA_FATTN_TILE_CONFIG_CASE(512, 512, 16, 256, 2, 32, 64)
@@ -219,6 +221,7 @@ static constexpr __host__ __device__ uint32_t ggml_cuda_fattn_tile_get_config_am
GGML_CUDA_FATTN_TILE_CONFIG_CASE(320, 256, 32, 512, 1, 128, 64)
GGML_CUDA_FATTN_TILE_CONFIG_CASE(512, 512, 2, 64, 2, 64, 64)
GGML_CUDA_FATTN_TILE_CONFIG_CASE(512, 512, 4, 128, 2, 64, 64)
GGML_CUDA_FATTN_TILE_CONFIG_CASE(512, 512, 8, 256, 2, 64, 64)
GGML_CUDA_FATTN_TILE_CONFIG_CASE(512, 512, 16, 256, 2, 64, 64)
@@ -296,6 +299,7 @@ static constexpr __host__ __device__ uint32_t ggml_cuda_fattn_tile_get_config_am
GGML_CUDA_FATTN_TILE_CONFIG_CASE(320, 256, 32, 256, 2, 128, 64)
GGML_CUDA_FATTN_TILE_CONFIG_CASE(512, 512, 2, 64, 2, 64, 64)
GGML_CUDA_FATTN_TILE_CONFIG_CASE(512, 512, 4, 128, 2, 64, 64)
GGML_CUDA_FATTN_TILE_CONFIG_CASE(512, 512, 8, 256, 2, 64, 64)
GGML_CUDA_FATTN_TILE_CONFIG_CASE(512, 512, 16, 256, 4, 64, 64)
@@ -1308,12 +1312,12 @@ static void launch_fattn_tile_switch_ncols2(ggml_backend_cuda_context & ctx, ggm
return;
}
if constexpr (DV <= 256) {
if (use_gqa_opt && gqa_ratio % 2 == 0) {
launch_fattn_tile_switch_ncols1<DKQ, DV, 2, use_logit_softcap>(ctx, dst);
return;
}
if (use_gqa_opt && gqa_ratio % 2 == 0) {
launch_fattn_tile_switch_ncols1<DKQ, DV, 2, use_logit_softcap>(ctx, dst);
return;
}
if constexpr (DV <= 256) {
launch_fattn_tile_switch_ncols1<DKQ, DV, 1, use_logit_softcap>(ctx, dst);
return;
}
+5 -5
View File
@@ -99,12 +99,12 @@ static void ggml_cuda_flash_attn_ext_mma_f16_switch_ncols2(ggml_backend_cuda_con
return;
}
if constexpr (DKQ <= 256) {
if (use_gqa_opt && gqa_ratio > 1) {
ggml_cuda_flash_attn_ext_mma_f16_switch_ncols1<DKQ, DV, 2>(ctx, dst);
return;
}
if (use_gqa_opt && gqa_ratio > 1) {
ggml_cuda_flash_attn_ext_mma_f16_switch_ncols1<DKQ, DV, 2>(ctx, dst);
return;
}
if constexpr (DKQ <= 256) {
ggml_cuda_flash_attn_ext_mma_f16_switch_ncols1<DKQ, DV, 1>(ctx, dst);
} else {
GGML_ABORT("fatal error");
+17 -14
View File
@@ -78,26 +78,29 @@ static __global__ void k_get_rows_float(
template<typename grad_t, typename dst_t>
static __global__ void k_get_rows_back_float(
const grad_t * __restrict__ grad, const int32_t * __restrict__ rows, dst_t * __restrict__ dst, const int64_t ncols, const int64_t nrows_grad) {
const grad_t * __restrict__ grad, const int32_t * __restrict__ rows, dst_t * __restrict__ dst,
const int64_t ncols, const int64_t nrows_grad, const int64_t nrows_dst) {
const int col = blockIdx.x*blockDim.x + threadIdx.x;
if (col >= ncols) {
return;
}
const int dst_row = blockIdx.y*blockDim.y + threadIdx.y;
float sum = 0.0f;
ggml_cuda_pdl_sync();
for (int64_t i = 0; i < nrows_grad; ++i) {
if (rows[i] != dst_row) {
continue;
}
sum += grad[i*ncols + col];
}
dst[dst_row*ncols + col] = sum;
// grid.y is clamped to the CUDA grid limit, so stride over the destination rows
for (int64_t dst_row = blockIdx.y; dst_row < nrows_dst; dst_row += gridDim.y) {
float sum = 0.0f;
for (int64_t i = 0; i < nrows_grad; ++i) {
if (rows[i] != dst_row) {
continue;
}
sum += grad[i*ncols + col];
}
dst[dst_row*ncols + col] = sum;
}
}
template<int qk, int qr, dequantize_kernel_t dq, typename dst_t>
@@ -302,7 +305,7 @@ void ggml_cuda_op_get_rows_back(ggml_backend_cuda_context & ctx, ggml_tensor * d
const dim3 block_dims(CUDA_GET_ROWS_BACK_BLOCK_SIZE, 1, 1);
const int block_num_x = (ne00 + CUDA_GET_ROWS_BACK_BLOCK_SIZE - 1) / CUDA_GET_ROWS_BACK_BLOCK_SIZE;
const dim3 block_nums(block_num_x, ne1, 1);
const dim3 block_nums(block_num_x, MIN(ne1, (int64_t)UINT16_MAX), 1);
k_get_rows_back_float<<<block_nums, block_dims, 0, stream>>>(src0_d, src1_d, dst_d, ne00, ne10);
k_get_rows_back_float<<<block_nums, block_dims, 0, stream>>>(src0_d, src1_d, dst_d, ne00, ne10, ne1);
}
+7
View File
@@ -368,5 +368,12 @@ bool ggml_cuda_should_use_mmq(enum ggml_type type, int cc, int64_t ne11, int64_t
return true;
}
// gfx900 (Vega 10) lacks native dp4a, loses to dequant + hipBLAS
// for dense matrices; keep MMQ only for MoE, where the
// hipBLAS path is much slower.
if (cc == GGML_CUDA_CC_VEGA) {
return n_experts > 0;
}
return (!GGML_CUDA_CC_IS_CDNA(cc)) || ne11 < MMQ_DP4A_MAX_BATCH_SIZE;
}
@@ -8,3 +8,4 @@ DECL_FATTN_MMA_F16_CASE(96, 96, 16, 2);
DECL_FATTN_MMA_F16_CASE(112, 112, 16, 2);
DECL_FATTN_MMA_F16_CASE(128, 128, 16, 2);
DECL_FATTN_MMA_F16_CASE(256, 256, 16, 2);
DECL_FATTN_MMA_F16_CASE(512, 512, 16, 2);
@@ -8,3 +8,4 @@ DECL_FATTN_MMA_F16_CASE(96, 96, 32, 2);
DECL_FATTN_MMA_F16_CASE(112, 112, 32, 2);
DECL_FATTN_MMA_F16_CASE(128, 128, 32, 2);
DECL_FATTN_MMA_F16_CASE(256, 256, 32, 2);
DECL_FATTN_MMA_F16_CASE(512, 512, 32, 2);
@@ -8,3 +8,4 @@ DECL_FATTN_MMA_F16_CASE(96, 96, 4, 2);
DECL_FATTN_MMA_F16_CASE(112, 112, 4, 2);
DECL_FATTN_MMA_F16_CASE(128, 128, 4, 2);
DECL_FATTN_MMA_F16_CASE(256, 256, 4, 2);
DECL_FATTN_MMA_F16_CASE(512, 512, 4, 2);
@@ -8,3 +8,4 @@ DECL_FATTN_MMA_F16_CASE(96, 96, 8, 2);
DECL_FATTN_MMA_F16_CASE(112, 112, 8, 2);
DECL_FATTN_MMA_F16_CASE(128, 128, 8, 2);
DECL_FATTN_MMA_F16_CASE(256, 256, 8, 2);
DECL_FATTN_MMA_F16_CASE(512, 512, 8, 2);
@@ -92,7 +92,7 @@ for ncols in [8, 16, 32, 64]:
continue
if head_size_kq == 320 and ncols2 != 32: # Mistral Small 4
continue
if head_size_kq == 512 and ncols2 not in (4, 8): # Gemma 4
if head_size_kq == 512 and ncols2 not in (2, 4, 8): # Gemma 4 (+ MTP)
continue
if head_size_kq == 576 and ncols2 not in (4, 16, 32): # Deepseek, GLM 4.7 Flash
continue
-1
View File
@@ -23,7 +23,6 @@ include(${HEXAGON_SDK_ROOT}/build/cmake/hexagon_fun.cmake)
include(ExternalProject)
option(GGML_HEXAGON_HTP_DEBUG "ggml-hexagon: enable HTP debug output" OFF)
option(GGML_HEXAGON_FA_EXP2_HF "ggml-hexagon: use FP16 exp2 polynomial in FA softmax instead of F32 exp round-trip" OFF)
set(GGML_HEXAGON_HTP_CERT "$ENV{HEXAGON_HTP_CERT}" CACHE PATH "ggml-hexagon: enable HTP library signing using certificate")
add_library(htp_iface OBJECT
+221 -26
View File
@@ -43,6 +43,7 @@
#include "htp-opnode.h"
#include "htp-ops.h"
#include "htp/matmul-ops.h"
#include "htp/flash-attn-ops.h"
#include "htp_iface.h"
#include "htp-drv.h"
@@ -62,6 +63,7 @@ static int opt_profile = 0; // profiling mode (0-disabled, 1-basic, 2-pmu)
static int opt_hostbuf = 1; // hostbuf ON by default
static int opt_mm_select = 3; // 3 = HMX -> Tiled -> Flat -> CPU, 2 = Tiled -> Flat -> CPU, 1 = Flat -> CPU
static int opt_fa_select = 2; // 2 = HMX -> HVX -> CPU, 1 = HVX -> CPU, 0 = CPU (unsupported)
// Default PMU events, if profiling with PMU (mode=2) is enabled
// See https://docs.qualcomm.com/doc/80-N2040-60/topic/pmu-events.html
@@ -125,6 +127,11 @@ static const char * htp_event_name(uint16_t id) {
case HTP_TRACE_EVT_HVX_W_DEQUANT: return "HVX_W_DEQUANT";
case HTP_TRACE_EVT_HVX_W_PREP: return "HVX_W_PREP";
case HTP_TRACE_EVT_HVX_O_PROC: return "HVX_O_PROC";
case HTP_TRACE_EVT_HVX_FA_QK: return "HVX_QK_FA";
case HTP_TRACE_EVT_HVX_FA_SFM: return "HVX_SFM_FA";
case HTP_TRACE_EVT_HVX_FA_Q_PREP: return "HVX_Q_PREP";
case HTP_TRACE_EVT_HVX_FA_K_PREP: return "HVX_K_PREP";
case HTP_TRACE_EVT_HVX_FA_V_PREP: return "HVX_V_PREP";
case HTP_TRACE_EVT_HMX_COMP: return "HMX_COMP";
default: return "UNKNOWN";
}
@@ -1879,6 +1886,162 @@ ggml_hexagon_session::~ggml_hexagon_session() noexcept(true) {
// ** backend interface
static bool ggml_hexagon_flash_attn_is_hmx_eligible(
const struct ggml_hexagon_session * sess,
const struct ggml_tensor * q,
const struct ggml_tensor * k,
const struct ggml_tensor * v,
const struct ggml_tensor * sinks
) {
if (sess->n_hmx == 0) {
return false;
}
if (opt_fa_select < 2) {
return false;
}
if (k->type != GGML_TYPE_F16 || v->type != GGML_TYPE_F16) {
return false;
}
const uint32_t DK = q->ne[0];
const uint32_t DV = v->ne[0];
if (DK % 64 != 0 || DV % 64 != 0) {
return false;
}
// Fall back to HVX for small token counts if head dimension is small (DK <= 128)
const uint32_t neq1 = q->ne[1];
if (DK <= 128 && neq1 < 5) {
return false;
}
return true;
}
static bool ggml_hexagon_precompute_flash_attn_params(
const struct ggml_hexagon_session * sess,
const struct ggml_tensor * op,
struct htp_fa_kernel_params * kparams
) {
if (opt_fa_select < 1) {
return false;
}
memset(kparams, 0, sizeof(*kparams));
const struct ggml_tensor * q = op->src[0];
const struct ggml_tensor * k = op->src[1];
const struct ggml_tensor * v = op->src[2];
const struct ggml_tensor * mask = op->src[3];
const struct ggml_tensor * dst = op;
const uint32_t neq0 = q->ne[0]; // head_dim (DK)
const uint32_t neq1 = q->ne[1]; // n_tokens
const uint32_t neq2 = q->ne[2]; // n_heads
const uint32_t nek1 = k->ne[1]; // kv_len
const uint32_t nev0 = v->ne[0]; // head_dim (DV)
const uint32_t DK = neq0;
const uint32_t DV = nev0;
const uint32_t n_kv_heads = k->ne[2];
const uint32_t G = neq2 / n_kv_heads;
float scale = 1.0f;
float max_bias = 0.0f;
float logit_softcap = 0.0f;
memcpy(&scale, &op->op_params[0], sizeof(float));
memcpy(&max_bias, &op->op_params[1], sizeof(float));
memcpy(&logit_softcap, &op->op_params[2], sizeof(float));
if (logit_softcap != 0.0f) {
scale /= logit_softcap;
}
kparams->scale = scale;
kparams->max_bias = max_bias;
kparams->logit_softcap = logit_softcap;
kparams->is_q_fp32 = (q->type == GGML_TYPE_F32) ? 1 : 0;
kparams->is_dst_fp32 = (dst->type == GGML_TYPE_F32) ? 1 : 0;
kparams->G = G;
const uint32_t n_head = q->ne[2];
kparams->n_head_log2 = 1u << (uint32_t) std::floor(std::log2(n_head));
kparams->m0 = std::pow(2.0f, -(max_bias) / kparams->n_head_log2);
kparams->m1 = std::pow(2.0f, -(max_bias / 2.0f) / kparams->n_head_log2);
// Check HMX eligibility
const struct ggml_tensor * sinks = op->src[4];
if (ggml_hexagon_flash_attn_is_hmx_eligible(sess, q, k, v, sinks)) {
size_t Br = 0, Bc = 0;
int ret = hmx_fa_find_chunk_size(&Br, &Bc, G, DK, DV, neq1, nek1, sess->vtcm_size, sess->n_threads);
if (ret == 0) {
kparams->kernel_type = HTP_FA_KERNEL_HMX;
kparams->Br = Br;
kparams->Bc = Bc;
kparams->n_kv_blocks = (nek1 + Bc - 1) / Bc;
kparams->n_threads = (kparams->n_kv_blocks >= 3 && sess->n_threads >= 2) ? sess->n_threads : 1;
kparams->u.hmx.g_br = hex_align_up(G * Br, 32);
kparams->u.hmx.pipeline = (kparams->n_kv_blocks >= 3 && sess->n_threads >= 2) ? 1 : 0;
kparams->vtcm_size = hmx_fa_compute_vtcm_usage(G, DK, DV, Br, Bc, kparams->n_threads, kparams->u.hmx.pipeline != 0);
const size_t row_vec_bytes = hex_align_up(Bc * sizeof(uint16_t), 256);
kparams->u.hmx.row_buf_stride = row_vec_bytes / 128; // HVX vector is 128 bytes
const size_t m_line_bytes = hex_align_up(Bc * sizeof(uint16_t), 128);
kparams->u.hmx.mask_buf_row_stride = m_line_bytes / sizeof(uint16_t);
kparams->u.hmx.mask_broadcast = (mask != nullptr && mask->ne[2] == 1) ? 1 : 0;
kparams->u.hmx.div_G = init_fastdiv_values(G);
if (mask) {
kparams->src3_div2 = init_fastdiv_values(mask->ne[2]);
kparams->src3_div3 = init_fastdiv_values(mask->ne[3]);
}
kparams->qrows = 0;
kparams->qrows_per_thread = 0;
return true;
}
}
// Fallback to HVX
kparams->kernel_type = HTP_FA_KERNEL_HVX;
kparams->Br = 1;
kparams->Bc = 64; // FLASH_ATTN_BLOCK_SIZE
kparams->n_kv_blocks = (k->ne[1] + 64 - 1) / 64;
kparams->n_threads = sess->n_threads;
const size_t size_q_row_padded = hex_round_up(q->ne[0] * (kparams->is_q_fp32 ? 4 : 2), 128);
const size_t size_k_row_padded = hex_round_up(k->ne[0] * 2, 128);
const size_t size_v_row_padded = hex_round_up(v->ne[0] * 2, 128);
kparams->vtcm_size = hvx_fa_compute_vtcm_usage(DK, DV, kparams->is_q_fp32 != 0, mask != nullptr, sess->n_threads);
kparams->u.hvx.size_q_row_padded = size_q_row_padded;
kparams->u.hvx.size_k_row_padded = size_k_row_padded;
kparams->u.hvx.size_v_row_padded = size_v_row_padded;
kparams->u.hvx.src0_div21 = init_fastdiv_values(q->ne[2] * q->ne[1]);
kparams->u.hvx.src0_div1 = init_fastdiv_values(q->ne[1]);
kparams->u.hvx.broadcast_rk2 = init_fastdiv_values(q->ne[2]/k->ne[2]);
kparams->u.hvx.broadcast_rk3 = init_fastdiv_values(q->ne[3]/k->ne[3]);
kparams->u.hvx.broadcast_rv2 = init_fastdiv_values(q->ne[2]/v->ne[2]);
kparams->u.hvx.broadcast_rv3 = init_fastdiv_values(q->ne[3]/v->ne[3]);
if (mask) {
kparams->src3_div2 = init_fastdiv_values(mask->ne[2]);
kparams->src3_div3 = init_fastdiv_values(mask->ne[3]);
}
kparams->qrows = q->ne[1] * q->ne[2] * q->ne[3];
kparams->qrows_per_thread = (kparams->qrows + sess->n_threads - 1) / sess->n_threads;
return true;
}
static bool ggml_hexagon_supported_flash_attn_ext(const struct ggml_hexagon_session * sess, const struct ggml_tensor * op) {
const struct ggml_tensor * src0 = op->src[0];
@@ -1912,6 +2075,17 @@ static bool ggml_hexagon_supported_flash_attn_ext(const struct ggml_hexagon_sess
return false;
}
struct htp_fa_kernel_params kparams;
if (!ggml_hexagon_precompute_flash_attn_params(sess, op, &kparams)) {
return false;
}
if ((size_t) kparams.vtcm_size > sess->vtcm_size) {
HEX_VERBOSE("ggml-hex: skip flash_attn_ext because VTCM needed (%d) > budget (%zu)\n",
kparams.vtcm_size, sess->vtcm_size);
return false;
}
return true;
}
@@ -2211,14 +2385,14 @@ static void ggml_hexagon_precompute_hvx_mm_params(
kparams->kernel_type = (src1_nrows < (int) sess->n_threads) ? HTP_MM_KERNEL_HVX_QUANT_BLOCK : HTP_MM_KERNEL_HVX_QUANT_ROW;
kparams->src1_row_size = (wtype == GGML_TYPE_Q4_1) ? htp_mm_q8_1_tiled_row_size(ne10) : htp_mm_q8_0_tiled_row_size(ne10);
size_t vtcm_src0_size = 0, vtcm_src1_size = 0;
size_t vtcm_src0_size = 0, vtcm_src1_size = 0, vtcm_dst_size = 0;
uint32_t max_prefetch = (src1_nrows > HTP_MM_HMX_MIN_NROWS) ? 2 : 16;
uint32_t best_n_prefetch = 2;
size_t total_size = 0;
for (uint32_t d = max_prefetch; d >= 2; d /= 2) {
total_size = htp_mm_hvx_id_get_vtcm_sizes(
wtype, ne10, src1_nrows, sess->n_threads, src0->nb[1], d,
&vtcm_src0_size, &vtcm_src1_size
&vtcm_src0_size, &vtcm_src1_size, &vtcm_dst_size
);
if (total_size <= vtcm_budget) {
best_n_prefetch = d;
@@ -2228,14 +2402,14 @@ static void ggml_hexagon_precompute_hvx_mm_params(
if (best_n_prefetch == 2 && total_size > vtcm_budget) {
total_size = htp_mm_hvx_id_get_vtcm_sizes(
wtype, ne10, src1_nrows, sess->n_threads, src0->nb[1], 2,
&vtcm_src0_size, &vtcm_src1_size
&vtcm_src0_size, &vtcm_src1_size, &vtcm_dst_size
);
}
kparams->n_prefetch = best_n_prefetch;
kparams->vtcm_size = total_size;
kparams->vtcm_src0_size = vtcm_src0_size;
kparams->vtcm_src1_size = vtcm_src1_size;
kparams->vtcm_dst_size = 0;
kparams->vtcm_dst_size = vtcm_dst_size;
} else {
bool try_tiled = (k_align && opt_mm_select >= 2);
if (try_tiled) {
@@ -2441,11 +2615,12 @@ static void ggml_hexagon_precompute_fused_qkv_params(
size_t src3_sz_per_thread = 0;
uint32_t best_n_prefetch = 16;
size_t quant_scratch_size = hex_round_up(ne10 * sizeof(float), QK_Q8_0_TILED * sizeof(float)) * sess->n_threads;
if (is_repack) {
uint32_t aligned_tile_size = htp_mm_get_weight_aligned_tile_size(wtype);
uint32_t n_k_tiles = hex_round_up(ne10, 32) / 32;
uint32_t tile_row_size = n_k_tiles * aligned_tile_size;
size_t src1_row_size_padded = hex_round_up(src1_row_size, QK_Q8_0_TILED * sizeof(float));
size_t src1_sz_per_thread = hex_round_up(src1_row_size * src1_nrows, 128);
size_t src1_sz = src1_sz_per_thread;
@@ -2453,13 +2628,10 @@ static void ggml_hexagon_precompute_fused_qkv_params(
best_n_prefetch = 2;
for (uint32_t d = max_prefetch; d >= 2; d /= 2) {
size_t repacked_vtcm_size = hex_round_up(d * tile_row_size, 128);
if (repacked_vtcm_size < src1_row_size_padded) {
repacked_vtcm_size = src1_row_size_padded;
}
size_t src0_sz = repacked_vtcm_size * sess->n_threads;
size_t src2_sz = hex_round_up(d * tile_row_size, 128) * sess->n_threads;
size_t src3_sz = hex_round_up(d * tile_row_size, 128) * sess->n_threads;
size_t tiled_vtcm_size = src0_sz + src1_sz + src2_sz + src3_sz;
size_t tiled_vtcm_size = src0_sz + src1_sz + src2_sz + src3_sz + quant_scratch_size;
if (tiled_vtcm_size <= sess->vtcm_size) {
best_n_prefetch = d;
@@ -2471,9 +2643,6 @@ static void ggml_hexagon_precompute_fused_qkv_params(
}
if (best_n_prefetch == 2 && src0_sz_per_thread == 0) {
size_t repacked_vtcm_size = hex_round_up(2 * tile_row_size, 128);
if (repacked_vtcm_size < src1_row_size_padded) {
repacked_vtcm_size = src1_row_size_padded;
}
src0_sz_per_thread = repacked_vtcm_size;
src2_sz_per_thread = hex_round_up(2 * tile_row_size, 128);
src3_sz_per_thread = hex_round_up(2 * tile_row_size, 128);
@@ -2492,7 +2661,7 @@ static void ggml_hexagon_precompute_fused_qkv_params(
size_t src2_sz = src2_sz_per_thread * sess->n_threads;
size_t src3_sz = src3_sz_per_thread * sess->n_threads;
size_t tiled_vtcm_size = src0_sz + src1_sz + src2_sz + src3_sz;
size_t tiled_vtcm_size = src0_sz + src1_sz + src2_sz + src3_sz + quant_scratch_size;
bool try_tiled = (opt_mm_select >= 2);
if (try_tiled && tiled_vtcm_size <= sess->vtcm_size) {
kparams->kernel_type = HTP_MM_KERNEL_HVX_QUANT_ROW;
@@ -2500,6 +2669,7 @@ static void ggml_hexagon_precompute_fused_qkv_params(
kparams->vtcm_src1_size = src1_sz;
kparams->vtcm_src2_size = src2_sz;
kparams->vtcm_src3_size = src3_sz;
kparams->vtcm_dst_size = quant_scratch_size;
kparams->vtcm_size = tiled_vtcm_size;
kparams->n_prefetch = best_n_prefetch;
} else {
@@ -2510,7 +2680,8 @@ static void ggml_hexagon_precompute_fused_qkv_params(
kparams->vtcm_src1_size = flat_src1_sz;
kparams->vtcm_src2_size = src2_sz;
kparams->vtcm_src3_size = src3_sz;
kparams->vtcm_size = src0_sz + flat_src1_sz + src2_sz + src3_sz;
kparams->vtcm_dst_size = quant_scratch_size;
kparams->vtcm_size = src0_sz + flat_src1_sz + src2_sz + src3_sz + quant_scratch_size;
kparams->n_prefetch = best_n_prefetch;
}
}
@@ -2536,11 +2707,12 @@ static void ggml_hexagon_precompute_fused_ffn_params(
size_t src2_sz_per_thread = 0;
uint32_t best_n_prefetch = 16;
size_t quant_scratch_size = hex_round_up(ne10 * sizeof(float), QK_Q8_0_TILED * sizeof(float)) * sess->n_threads;
if (is_repack) {
uint32_t aligned_tile_size = htp_mm_get_weight_aligned_tile_size(wtype);
uint32_t n_k_tiles = hex_round_up(ne10, 32) / 32;
uint32_t tile_row_size = n_k_tiles * aligned_tile_size;
size_t src1_row_size_padded = hex_round_up(src1_row_size, QK_Q8_0_TILED * sizeof(float));
size_t src1_sz_per_thread = hex_round_up(src1_row_size * src1_nrows, 128);
size_t src1_sz = src1_sz_per_thread;
@@ -2548,12 +2720,9 @@ static void ggml_hexagon_precompute_fused_ffn_params(
best_n_prefetch = 2;
for (uint32_t d = max_prefetch; d >= 2; d /= 2) {
size_t repacked_vtcm_size = hex_round_up(d * tile_row_size, 128);
if (repacked_vtcm_size < src1_row_size_padded) {
repacked_vtcm_size = src1_row_size_padded;
}
size_t src0_sz = repacked_vtcm_size * sess->n_threads;
size_t src2_sz = hex_round_up(d * tile_row_size, 128) * sess->n_threads;
size_t tiled_vtcm_size = src0_sz + src1_sz + src2_sz;
size_t tiled_vtcm_size = src0_sz + src1_sz + src2_sz + quant_scratch_size;
if (tiled_vtcm_size <= sess->vtcm_size) {
best_n_prefetch = d;
@@ -2564,9 +2733,6 @@ static void ggml_hexagon_precompute_fused_ffn_params(
}
if (best_n_prefetch == 2 && src0_sz_per_thread == 0) {
size_t repacked_vtcm_size = hex_round_up(2 * tile_row_size, 128);
if (repacked_vtcm_size < src1_row_size_padded) {
repacked_vtcm_size = src1_row_size_padded;
}
src0_sz_per_thread = repacked_vtcm_size;
src2_sz_per_thread = hex_round_up(2 * tile_row_size, 128);
}
@@ -2582,13 +2748,14 @@ static void ggml_hexagon_precompute_fused_ffn_params(
size_t src1_sz = src1_sz_per_thread;
size_t src2_sz = src2_sz_per_thread * sess->n_threads;
size_t tiled_vtcm_size = src0_sz + src1_sz + src2_sz;
size_t tiled_vtcm_size = src0_sz + src1_sz + src2_sz + quant_scratch_size;
bool try_tiled = (opt_mm_select >= 2);
if (try_tiled && tiled_vtcm_size <= sess->vtcm_size) {
kparams->kernel_type = HTP_MM_KERNEL_HVX_QUANT_ROW;
kparams->vtcm_src0_size = src0_sz;
kparams->vtcm_src1_size = src1_sz;
kparams->vtcm_src2_size = src2_sz;
kparams->vtcm_dst_size = quant_scratch_size;
kparams->vtcm_size = tiled_vtcm_size;
kparams->n_prefetch = best_n_prefetch;
} else {
@@ -2598,7 +2765,8 @@ static void ggml_hexagon_precompute_fused_ffn_params(
kparams->vtcm_src0_size = src0_sz;
kparams->vtcm_src1_size = flat_src1_sz;
kparams->vtcm_src2_size = src2_sz;
kparams->vtcm_size = src0_sz + flat_src1_sz + src2_sz;
kparams->vtcm_dst_size = quant_scratch_size;
kparams->vtcm_size = src0_sz + flat_src1_sz + src2_sz + quant_scratch_size;
kparams->n_prefetch = best_n_prefetch;
}
}
@@ -3243,7 +3411,7 @@ static inline bool op_is_compute(ggml_tensor *node)
return !ggml_op_is_empty(node->op) && !ggml_is_empty(node) && (node->flags & GGML_TENSOR_FLAG_COMPUTE);
}
static bool is_hmx_eligible(const ggml_tensor * t) {
static bool mm_is_hmx_eligible(const ggml_tensor * t) {
if (opt_nhmx == 0) { return false; }
const ggml_tensor * src0 = t->src[0];
@@ -3262,7 +3430,7 @@ static bool is_hmx_eligible(const ggml_tensor * t) {
static bool is_mergeable_mul_mat(const ggml_tensor * t) {
if (!t || t->op != GGML_OP_MUL_MAT) return false;
if (t->src[1]->type != GGML_TYPE_F32) return false;
return ggml_is_quantized(t->src[0]->type) && !is_hmx_eligible(t);
return ggml_is_quantized(t->src[0]->type) && !mm_is_hmx_eligible(t);
}
static bool is_mergeable_mul_mat_pair(const ggml_tensor * n1, const ggml_tensor * n2) {
@@ -3357,6 +3525,26 @@ static bool try_fuse_node(const ggml_hexagon_session * sess, const ggml_cgraph *
}
}
if (n->op == GGML_OP_MUL_MAT && next_node) {
if (next_node->op == GGML_OP_ADD && op_is_compute(next_node) && ggml_can_fuse(graph, i, { GGML_OP_MUL_MAT, GGML_OP_ADD })) {
if (next_node->src[0] == n || next_node->src[1] == n) {
struct htp_mm_kernel_params kparams;
ggml_hexagon_precompute_matmul_params(sess, n->src[0], n->src[1], next_node, &kparams);
if ((size_t)kparams.vtcm_size <= sess->vtcm_size) {
htp_opnode node(n, {}, HTP_OP_MUL_MAT_ADD);
node.add_fused(next_node);
memcpy(node.kernel_params, &kparams, sizeof(kparams));
nodes.push_back(std::move(node));
i += 1;
return true;
} else {
HEX_VERBOSE("ggml-hex: skip MUL_MAT_ADD fusion because VTCM needed (%d) > budget (%zu)\n",
kparams.vtcm_size, sess->vtcm_size);
}
}
}
}
return false;
}
@@ -3393,6 +3581,11 @@ static ggml_status ggml_backend_hexagon_graph_compute(ggml_backend_t backend, gg
node.node->src[0], node.node->src[1], node.node,
(struct htp_mm_kernel_params *)node.kernel_params
);
} else if (node.opcode == HTP_OP_FLASH_ATTN_EXT) {
ggml_hexagon_precompute_flash_attn_params(sess,
node.node,
(struct htp_fa_kernel_params *)node.kernel_params
);
}
computed_nodes.push_back(std::move(node));
}
@@ -4079,6 +4272,7 @@ static void ggml_hexagon_init(ggml_backend_reg * reg) {
const char * str_use_hmx = getenv("GGML_HEXAGON_USE_HMX");
const char * str_nhmx = getenv("GGML_HEXAGON_NHMX");
const char * str_mm_select = getenv("GGML_HEXAGON_MM_SELECT");
const char * str_fa_select = getenv("GGML_HEXAGON_FA_SELECT");
const char * str_ndev = getenv("GGML_HEXAGON_NDEV");
const char * str_arch = getenv("GGML_HEXAGON_ARCH");
const char * str_vmem = getenv("GGML_HEXAGON_VMEM");
@@ -4120,6 +4314,7 @@ static void ggml_hexagon_init(ggml_backend_reg * reg) {
opt_nhvx = str_nhvx ? strtoul(str_nhvx, NULL, 0) : opt_nhvx;
opt_nhmx = str_nhmx ? atoi(str_nhmx) : (str_use_hmx ? atoi(str_use_hmx) : opt_nhmx);
opt_mm_select = str_mm_select ? atoi(str_mm_select) : opt_mm_select;
opt_fa_select = str_fa_select ? atoi(str_fa_select) : opt_fa_select;
opt_ndev = str_ndev ? strtoul(str_ndev, NULL, 0) : opt_ndev;
opt_hostbuf = str_hostbuf ? atoi(str_hostbuf) : opt_hostbuf;
opt_mbuf = str_mbuf ? strtoul(str_mbuf, NULL, 0) * MiB : opt_mbuf;
+13 -1
View File
@@ -11,6 +11,7 @@
#include <stdio.h>
#include "htp-ops.h"
#include "htp/matmul-ops.h"
#include "htp/flash-attn-ops.h"
struct htp_opnode {
ggml_tensor * node = nullptr;
@@ -335,7 +336,8 @@ struct htp_opformat {
}
void format_kernel_params(char * str, size_t max_size, const htp_opnode & node) {
if (node.opcode == HTP_OP_MUL_MAT || node.opcode == HTP_OP_MUL_MAT_ID ||
node.opcode == HTP_OP_MUL_MAT_QKV || node.opcode == HTP_OP_MUL_MAT_FFN) {
node.opcode == HTP_OP_MUL_MAT_QKV || node.opcode == HTP_OP_MUL_MAT_FFN ||
node.opcode == HTP_OP_MUL_MAT_ADD) {
const auto * kparams = (const struct htp_mm_kernel_params *) node.kernel_params;
const char * path = "unknown";
int32_t type = kparams->kernel_type;
@@ -350,6 +352,16 @@ struct htp_opformat {
path = "hvx-flat";
}
snprintf(str, max_size, "%s vtcm %d", path, (int) kparams->vtcm_size);
} else if (node.opcode == HTP_OP_FLASH_ATTN_EXT) {
const auto * kparams = (const struct htp_fa_kernel_params *) node.kernel_params;
const char * path = "unknown";
int32_t type = kparams->kernel_type;
if (type == HTP_FA_KERNEL_HMX) {
path = kparams->u.hmx.pipeline ? "hmx-pipe" : "hmx-seq";
} else if (type == HTP_FA_KERNEL_HVX) {
path = "hvx";
}
snprintf(str, max_size, "%s vtcm %d", path, (int) kparams->vtcm_size);
} else {
snprintf(str, max_size, "----");
}
+2 -7
View File
@@ -20,9 +20,6 @@ add_library(${HTP_LIB} SHARED
worker-pool.c
hex-dma.c
hmx-queue.c
flash-attn-ops.c
hmx-flash-attn-ops.c
matmul-ops.c
binary-ops.c
unary-ops.c
sum-rows-ops.c
@@ -42,16 +39,14 @@ add_library(${HTP_LIB} SHARED
solve-tri-ops.c
gated-delta-net-ops.c
pad-ops.c
matmul-ops.c
flash-attn-ops.c
)
target_compile_definitions(${HTP_LIB} PRIVATE
$<IF:$<BOOL:${HEXAGON_HTP_DEBUG}>,HTP_DEBUG=1,NDEBUG=1>
$<IF:$<BOOL:${HEXAGON_HTP_DEBUG}>,FARF_HIGH=1,>)
if (GGML_HEXAGON_FA_EXP2_HF)
message(STATUS "ggml-htp: HMX_FA_USE_EXP2_HF=1 (use FP16 exp2 polynomial in FA softmax)")
target_compile_definitions(${HTP_LIB} PRIVATE HMX_FA_USE_EXP2_HF=1)
endif()
build_idl(htp_iface.idl ${HTP_LIB})
File diff suppressed because it is too large Load Diff
+253
View File
@@ -0,0 +1,253 @@
#ifndef HTP_FLASH_ATTN_OPS_H
#define HTP_FLASH_ATTN_OPS_H
#include <stdint.h>
#include <stddef.h>
#include <stdbool.h>
#include "hex-fastdiv.h"
#include "hex-common.h"
#ifdef __cplusplus
extern "C" {
#endif
// Tile constants (mirrored from hmx-utils.h for use on host side if needed)
#define HMX_FP16_TILE_N_ROWS 32
#define HMX_FP16_TILE_N_COLS 32
#define HMX_FP16_TILE_N_ELMS 1024
#define HMX_FP16_TILE_SIZE 2048
#define HVX_FA_DMA_CACHE_SIZE 128
#define HMX_FA_DMA_CACHE_SIZE 4
#define HTP_FA_M_INITIAL_VAL -10000.0f
enum htp_fa_kernel_type {
HTP_FA_KERNEL_UNSUPPORTED = 0,
HTP_FA_KERNEL_HVX,
HTP_FA_KERNEL_HMX
};
struct htp_fa_kernel_params {
uint8_t kernel_type; // enum htp_fa_kernel_type
uint8_t is_q_fp32; // 1 = Q type is F32, 0 = F16
uint8_t is_dst_fp32; // 1 = dst type is F32, 0 = F16
uint8_t n_threads; // Number of threads to run
// Common parameters
uint16_t Br;
uint16_t Bc;
uint16_t n_kv_blocks; // also HVX's n_blocks
uint16_t G; // GQA factor (n_heads / n_kv_heads)
float scale;
float max_bias;
float logit_softcap;
uint32_t vtcm_size;
uint32_t qrows;
uint32_t qrows_per_thread;
float m0;
float m1;
uint32_t n_head_log2;
struct fastdiv_values src3_div2;
struct fastdiv_values src3_div3;
union {
struct {
uint32_t g_br;
uint32_t row_buf_stride;
uint32_t mask_buf_row_stride;
int32_t mask_broadcast;
int32_t pipeline;
struct fastdiv_values div_G;
} hmx;
struct {
uint32_t size_q_row_padded;
uint32_t size_k_row_padded;
uint32_t size_v_row_padded;
struct fastdiv_values src0_div21;
struct fastdiv_values src0_div1;
struct fastdiv_values broadcast_rk2;
struct fastdiv_values broadcast_rk3;
struct fastdiv_values broadcast_rv2;
struct fastdiv_values broadcast_rv3;
} hvx;
} u;
};
#if defined(__cplusplus)
static_assert(sizeof(struct htp_fa_kernel_params) <= 128, "htp_fa_kernel_params is too large for kernel_params blob");
#endif
// Exact VTCM usage for a given (gqa_factor, DK, DV, Br, Bc) configuration.
// g_br = hex_align_up(gqa_factor * Br, 32) replaces Br for all Q/O/S/P/D dimensions.
// Layout: Q + O_ping + O_pong + K_dma*2 + V_dma*2 + K_tile + V_tile + S + P + D + vectors + scales
// Mask is DMA'd into a VTCM buffer (Br rows per KV block) to avoid DDR reads in softmax.
static inline size_t hmx_fa_compute_vtcm_usage(size_t gqa_factor, size_t DK, size_t DV, size_t Br, size_t Bc, size_t n_threads, bool pipeline) {
const size_t g_br = hex_align_up(gqa_factor * Br, HMX_FP16_TILE_N_ROWS);
const size_t q_tile_size = hex_align_up(g_br * DK * sizeof(__fp16), 4096); // Q: [g_br, DK]
const size_t o_tile_size = hex_align_up(g_br * DV * sizeof(__fp16), 4096); // O: [g_br, DV] x2 ping-pong
const size_t k_dma_size = hex_align_up(Bc * hex_round_up(DK * sizeof(__fp16), 128), 4096); // K DMA: [Bc, DK] x2 double-buf
const size_t v_dma_size = hex_align_up(Bc * hex_round_up(DV * sizeof(__fp16), 128), 4096); // V DMA: [Bc, DV] x2 double-buf
const size_t k_tile_size = hex_align_up(Bc * DK * sizeof(__fp16), 4096); // K tiles: [Bc, DK] interleaved
const size_t v_tile_size = hex_align_up(Bc * DV * sizeof(__fp16), 4096); // V tiles: [Bc, DV] interleaved
const size_t s_tile_size = hex_align_up(g_br * Bc * sizeof(__fp16), 4096); // S/P:[g_br, Bc]
const size_t d_tile_size = hex_align_up(g_br * g_br * sizeof(__fp16), 4096); // D: [g_br, g_br]
const size_t col_vec_size = hex_align_up(g_br * sizeof(float), 256); // m, l, etc.
const size_t row_vec_size = hex_align_up(Bc * sizeof(__fp16), 256);
const size_t m_line_size = hex_align_up(Bc * sizeof(__fp16), 128);
const size_t m_buf_size = hex_align_up(Br * m_line_size, 4096) * HMX_FA_DMA_CACHE_SIZE;
const size_t slopes_size = hex_align_up(g_br * sizeof(__fp16), 128);
return q_tile_size * 1 // Q tiles
+ o_tile_size * 2 // O ping-pong
+ k_dma_size * 2 // K DMA x2
+ v_dma_size * 2 // V DMA x2
+ k_tile_size * 1 // K tiles
+ v_tile_size * (pipeline ? 2 : 1) // V tiles (double-buffered if pipelining)
+ s_tile_size * 2 // S + P
+ d_tile_size * 1 // D (diagonal matrix)
+ col_vec_size * 4 // m_vec, l_vec, s_rowmax, p_rowsum
+ row_vec_size * 2 * n_threads // per-thread softmax row scratch
+ m_buf_size * 1 // mask VTCM buffer [Br rows]
+ slopes_size // Slopes
+ 256 * 2; // HMX scales (id + qk)
}
#define FA_HVX_BLOCK_SIZE 64
static inline size_t hvx_fa_compute_vtcm_usage(size_t DK, size_t DV, bool is_q_fp32, bool has_mask, size_t n_threads) {
const size_t size_q_row_padded = hex_round_up(DK * (is_q_fp32 ? 4 : 2), 128);
const size_t size_k_row_padded = hex_round_up(DK * sizeof(__fp16), 128);
const size_t size_v_row_padded = hex_round_up(DV * sizeof(__fp16), 128);
const size_t size_q_block = size_q_row_padded * 1;
const size_t size_k_block = size_k_row_padded * FA_HVX_BLOCK_SIZE;
const size_t size_v_block = size_v_row_padded * FA_HVX_BLOCK_SIZE;
const size_t size_m_block = hex_round_up(FA_HVX_BLOCK_SIZE * sizeof(__fp16), 128);
const size_t size_vkq_acc = hex_round_up(DV * sizeof(float), 128);
const size_t size_per_thread = size_q_block * 1
+ size_k_block * 2
+ size_v_block * 2
+ (has_mask ? size_m_block * HVX_FA_DMA_CACHE_SIZE : 0)
+ size_vkq_acc;
return size_per_thread * n_threads;
}
#define FA_MIN_KV_BLOCKS 3
// Cost-based (Br, Bc) search for flash attention with pipeline constraint.
static inline int hmx_fa_find_chunk_size(size_t * Br_out,
size_t * Bc_out,
size_t gqa_factor,
size_t DK,
size_t DV,
size_t qo_len,
size_t kv_len,
size_t vtcm_budget,
size_t n_threads) {
const size_t T = HMX_FP16_TILE_N_ROWS; // 32
const size_t br_unit = hmx_ceil_div(T, gqa_factor);
const size_t bc_unit = HMX_FP16_TILE_N_COLS * 2; // 64
const size_t fp16 = sizeof(__fp16);
const bool can_pipeline = (kv_len >= FA_MIN_KV_BLOCKS * bc_unit && n_threads >= 2);
// Approximate per-unit VTCM costs (without per-buffer alignment padding).
const size_t per_gbr = (DK + 2 * DV) * fp16 + 4 * sizeof(float); // Q + O*2 + 4 col vectors
const size_t per_gbr2 = fp16; // D diagonal matrix
const size_t per_bc =
3 * DK * fp16 + (can_pipeline ? 4 : 3) * DV * fp16 + 2 * n_threads * fp16; // K/V DMA x2 + tiles + row bufs
const size_t per_gbr_bc = 2 * fp16; // S + P
const size_t overhead = 256 * 2 + 13 * 4096;
if (vtcm_budget <= overhead) {
return -1;
}
const size_t usable = vtcm_budget - overhead;
// Br_max: largest Br aligned to br_unit that does not exceed qo_len.
const size_t Br_max = qo_len >= br_unit ? hex_align_down(qo_len, br_unit) : br_unit;
// Pipeline constraint: cap Bc so n_kv_blocks >= FA_MIN_KV_BLOCKS.
// Only relax when kv_len is too short to form enough blocks.
const size_t Bc_limit = can_pipeline ? hex_align_down(kv_len / FA_MIN_KV_BLOCKS, bc_unit) :
(kv_len >= bc_unit ? hex_align_down(kv_len, bc_unit) : bc_unit);
// Cost coefficients calibrated from profiling
const size_t c_q_fixed = 1400; // per-Q-block: q_load + epilogue o_update + o_norm + o_store
const size_t c_iter_fixed = 200; // per-KV-iter: HMX queue push/pop + DMA pop + barriers
size_t best_cost = SIZE_MAX, best_mn = 0;
size_t best_Br = 0, best_Bc = 0;
for (size_t Br = Br_max; Br >= br_unit; Br -= br_unit) {
const size_t g_br = hex_align_up(gqa_factor * Br, T);
// g_br-dependent VTCM cost: g_br * per_gbr + g_br*g_br * per_gbr2
const size_t gbr_cost = g_br * per_gbr + g_br * g_br * per_gbr2;
if (gbr_cost >= usable) {
if (Br == br_unit) {
break;
}
continue;
}
// Analytically solve for max Bc:
// remain >= Bc * (per_bc + g_br * per_gbr_bc + Br * fp16 * HMX_FA_DMA_CACHE_SIZE)
// The Br * fp16 term accounts for the VTCM mask buffer [Br * Bc].
const size_t remain = usable - gbr_cost;
const size_t bc_denom = per_bc + g_br * per_gbr_bc + Br * fp16 * HMX_FA_DMA_CACHE_SIZE;
size_t Bc = hex_smin(hex_align_down(remain / bc_denom, bc_unit), Bc_limit);
if (Bc < bc_unit) {
if (Br == br_unit) {
break;
}
continue;
}
// Exact VTCM verification (alignment padding may push over budget)
while (Bc >= bc_unit && hmx_fa_compute_vtcm_usage(gqa_factor, DK, DV, Br, Bc, n_threads, can_pipeline) > vtcm_budget) {
Bc -= bc_unit;
}
if (Bc < bc_unit) {
if (Br == br_unit) {
break;
}
continue;
}
const size_t q_blocks = (qo_len + Br - 1) / Br;
const size_t kv_blocks = (kv_len + Bc - 1) / Bc;
const size_t cost = q_blocks * (c_q_fixed + kv_blocks * c_iter_fixed);
const size_t mn = Br * Bc;
if (cost < best_cost || (cost == best_cost && mn > best_mn)) {
best_cost = cost;
best_mn = mn;
best_Br = Br;
best_Bc = Bc;
}
if (Br == br_unit) {
break;
}
}
if (best_Br == 0) {
return -1;
}
*Br_out = best_Br;
*Bc_out = best_Bc;
return 0;
}
#ifdef __cplusplus
}
#endif
#endif /* HTP_FLASH_ATTN_OPS_H */
+15 -15
View File
@@ -138,27 +138,28 @@ static inline bool dma_queue_push_single_1d(dma_queue * q, dma_ptr dptr, size_t
}
dma_descriptor_1d * desc = (dma_descriptor_1d *) &q->desc[q->push_idx];
desc->next = NULL;
desc->desc_size = 0; // 1D mode
desc->src_bypass = dma_src_l2_bypass_on;
desc->dst_bypass = dma_dst_l2_bypass_on;
desc->order = 0;
desc->done = 0;
desc->src = (void *) dptr.src;
desc->dst = (void *) dptr.dst;
desc->size = size;
desc->src = (void *) dptr.src;
desc->dst = (void *) dptr.dst;
desc->size = size;
q->dptr[q->push_idx] = dptr;
if (size) {
desc->next = NULL;
desc->desc_size = 0; // 1D mode
desc->src_bypass = dma_src_l2_bypass_on;
desc->dst_bypass = dma_dst_l2_bypass_on;
desc->order = 0;
desc->done = 0;
htp_trace_event_start(q->trace, HTP_TRACE_EVT_DMA, q->push_idx);
dmlink(q->tail, desc);
q->tail = (dma_descriptor_2d *) desc;
} else {
desc->done = 1;
desc->desc_size = 0;
desc->done = 1;
}
// FARF(ERROR, "dma-push: i %u row-size %u nrows %d dst %p src %p\n", q->push_idx, row_size, nrows, dptr.dst, dptr.src);
q->push_idx = (q->push_idx + 1) & q->idx_mask;
return true;
}
@@ -320,7 +321,7 @@ static inline bool dma_queue_push_vtcm_to_ddr(dma_queue * q, dma_ptr dptr, size_
return dma_queue_push(q, dptr, dst_row_size, src_row_size, dst_row_size, nrows);
}
#define DMA_CACHE_MAX_SIZE 64U
#define DMA_CACHE_MAX_SIZE 256U
typedef struct {
uint8_t *base;
@@ -352,20 +353,19 @@ static inline bool dma_cache_push(dma_queue *q, dma_cache *c, const uint8_t * sr
if (c->src[i] == (uint32_t) src) {
c->age[i] = 0;
dst = c->base + (i * c->line_size); nrows = 0; // dummy dma
// FARF(ERROR, "dma-cache: found %p", src);
} else {
c->age[i]++;
if (c->age[i] > o_age) { o_age = c->age[i]; o_idx = i; }
}
}
if (!dst) {
// FARF(ERROR, "dma-cache: replacing #%u : age %u %p -> %p", o_idx, c->age[o_idx], (void *) c->src[o_idx], src);
c->age[o_idx] = 0;
c->src[o_idx] = (uint32_t) src;
dst = c->base + o_idx * c->line_size; // normal nrows dma
return dma_queue_push(q, dma_make_ptr(dst, src), dst_stride, src_stride, row_size, nrows);
}
return dma_queue_push(q, dma_make_ptr(dst, src), dst_stride, src_stride, row_size, nrows);
return dma_queue_push_single_1d(q, dma_make_ptr(dst, src), 0);
}
#ifdef __cplusplus
@@ -0,0 +1,96 @@
#ifndef HMX_FA_KERNELS_H
#define HMX_FA_KERNELS_H
#include <stdint.h>
#include <stddef.h>
#include <stdbool.h>
#include "hvx-utils.h"
#include "hmx-utils.h"
// HMX-specific parameters, offsets and inner kernels for Flash Attention
// Scatter offsets for diagonal tile: entry[2i] = i*136, entry[2i+1] = i*136+6
// 136 = 4 * 32 + 8 = byte offset to diagonal in a 32x32 fp16 interleaved tile
static const int16_t d_tile_scatter_offsets[64] __attribute__((aligned(128))) = {
0 * 136, 0 * 136 + 6,
1 * 136, 1 * 136 + 6,
2 * 136, 2 * 136 + 6,
3 * 136, 3 * 136 + 6,
4 * 136, 4 * 136 + 6,
5 * 136, 5 * 136 + 6,
6 * 136, 6 * 136 + 6,
7 * 136, 7 * 136 + 6,
8 * 136, 8 * 136 + 6,
9 * 136, 9 * 136 + 6,
10 * 136, 10 * 136 + 6,
11 * 136, 11 * 136 + 6,
12 * 136, 12 * 136 + 6,
13 * 136, 13 * 136 + 6,
14 * 136, 14 * 136 + 6,
15 * 136, 15 * 136 + 6,
0, 0,
0, 0,
0, 0,
0, 0,
0, 0,
0, 0,
0, 0,
0, 0,
0, 0,
0, 0,
0, 0,
0, 0,
0, 0,
0, 0,
0, 0,
0, 0,
};
// Inner HMX tile computation kernels
static inline void hmx_fa_qk_dot_tile(
const __fp16 * row_tiles,
const __fp16 * col_tiles,
__fp16 * out_tile,
size_t n_dot_tiles
) {
for (size_t k = 0; k < n_dot_tiles; ++k) {
Q6_activation_hf_mxmem_RR((unsigned int) row_tiles, 2047);
Q6_weight_hf_mxmem_RR((unsigned int) col_tiles, 2047);
row_tiles += HMX_FP16_TILE_N_ELMS;
col_tiles += HMX_FP16_TILE_N_ELMS;
}
Q6_mxmem_AR_after_hf(out_tile, 0);
}
static inline void hmx_fa_o_update_tile(
const __fp16 * d_diag,
const __fp16 * o_rc,
const __fp16 * p_tile_in,
const __fp16 * v_tile_in,
__fp16 * o_tile_out,
size_t n_col_tiles
) {
Q6_activation_hf_mxmem_RR((unsigned int) d_diag, 2047);
Q6_weight_hf_mxmem_RR((unsigned int) o_rc, 2047);
for (size_t k = 0; k < n_col_tiles; ++k) {
Q6_activation_hf_mxmem_RR((unsigned int) p_tile_in, 2047);
Q6_weight_hf_mxmem_RR((unsigned int) v_tile_in, 2047);
p_tile_in += HMX_FP16_TILE_N_ELMS;
v_tile_in += HMX_FP16_TILE_N_ELMS;
}
Q6_mxmem_AR_after_hf(o_tile_out, 0);
}
static inline void hmx_fa_o_norm_tile(
const __fp16 * d_diag,
const __fp16 * o_rc,
__fp16 * o_out
) {
Q6_activation_hf_mxmem_RR((unsigned int) d_diag, 2047);
Q6_weight_hf_mxmem_RR((unsigned int) o_rc, 2047);
Q6_mxmem_AR_after_hf(o_out, 0);
}
#endif /* HMX_FA_KERNELS_H */
File diff suppressed because it is too large Load Diff
@@ -712,7 +712,17 @@ static inline void hmx_matmul_job_init(hmx_matmul_job_t * job,
// output : fp16 -> f32p
static void transfer_output_chunk_fp16_to_fp32(float *restrict dst, const __fp16 *restrict vtcm_src, uint32_t start_row, uint32_t n_rows, uint32_t n_cols, uint32_t dst_stride, uint32_t dst_cols) {
static void transfer_output_chunk_fp16_to_fp32(
float *restrict dst,
const float *restrict src2,
const __fp16 *restrict vtcm_src,
uint32_t start_row,
uint32_t n_rows,
uint32_t n_cols,
uint32_t dst_stride,
uint32_t src2_stride,
uint32_t dst_cols
) {
assert(n_cols % HTP_MM_HMX_TILE_N_COLS == 0);
const size_t tile_row_stride = (n_cols / HTP_MM_HMX_TILE_N_COLS) * HTP_MM_HMX_TILE_N_ELMS;
@@ -727,6 +737,7 @@ static void transfer_output_chunk_fp16_to_fp32(float *restrict dst, const __fp16
const size_t r1 = (r_idx0 % HTP_MM_HMX_TILE_N_ROWS) / 2; // index of the row pair within the tile
const __fp16 *row_base = vtcm_src + r0 * tile_row_stride;
float *output_row_base = dst + r * dst_stride; // global memory row base for row r (and r+1)
const float *src2_row_base = src2 ? (src2 + r * src2_stride) : NULL;
#pragma unroll(4)
for (size_t c = 0; c < limit_c_aligned; c += HTP_MM_HMX_TILE_N_COLS) {
@@ -738,9 +749,20 @@ static void transfer_output_chunk_fp16_to_fp32(float *restrict dst, const __fp16
HVX_Vector *pv_out0 = (HVX_Vector *) (output_row_base + c + 0);
HVX_Vector *pv_out1 = (HVX_Vector *) (output_row_base + c + dst_stride);
*pv_out0 = Q6_Vsf_equals_Vqf32(Q6_V_lo_W(vp));
HVX_Vector v_out0 = Q6_Vsf_equals_Vqf32(Q6_V_lo_W(vp));
if (src2_row_base) {
HVX_Vector v_src2_0 = hvx_vmemu(src2_row_base + c + 0);
v_out0 = hvx_vec_add_f32_f32(v_out0, v_src2_0);
}
*pv_out0 = v_out0;
if (r + 1 < n_rows) {
*pv_out1 = Q6_Vsf_equals_Vqf32(Q6_V_hi_W(vp));
HVX_Vector v_out1 = Q6_Vsf_equals_Vqf32(Q6_V_hi_W(vp));
if (src2_row_base) {
HVX_Vector v_src2_1 = hvx_vmemu(src2_row_base + c + src2_stride);
v_out1 = hvx_vec_add_f32_f32(v_out1, v_src2_1);
}
*pv_out1 = v_out1;
}
}
@@ -752,9 +774,20 @@ static void transfer_output_chunk_fp16_to_fp32(float *restrict dst, const __fp16
HVX_Vector v = ((const HVX_Vector *) tile)[r1];
HVX_VectorPair vp = Q6_Wqf32_vmpy_VhfVhf(v, one);
hvx_vec_store_u(output_row_base + c, valid_c * sizeof(float), Q6_Vsf_equals_Vqf32(Q6_V_lo_W(vp)));
HVX_Vector v_out0 = Q6_Vsf_equals_Vqf32(Q6_V_lo_W(vp));
if (src2_row_base) {
HVX_Vector v_src2_0 = hvx_vmemu(src2_row_base + c + 0);
v_out0 = hvx_vec_add_f32_f32(v_out0, v_src2_0);
}
hvx_vec_store_u(output_row_base + c, valid_c * sizeof(float), v_out0);
if (r + 1 < n_rows) {
hvx_vec_store_u(output_row_base + c + dst_stride, valid_c * sizeof(float), Q6_Vsf_equals_Vqf32(Q6_V_hi_W(vp)));
HVX_Vector v_out1 = Q6_Vsf_equals_Vqf32(Q6_V_hi_W(vp));
if (src2_row_base) {
HVX_Vector v_src2_1 = hvx_vmemu(src2_row_base + c + src2_stride);
v_out1 = hvx_vec_add_f32_f32(v_out1, v_src2_1);
}
hvx_vec_store_u(output_row_base + c + dst_stride, valid_c * sizeof(float), v_out1);
}
}
}
@@ -763,11 +796,13 @@ static void transfer_output_chunk_fp16_to_fp32(float *restrict dst, const __fp16
typedef struct {
const __fp16 *vtcm_src;
float *dst;
const float *src2;
uint32_t n_tasks;
uint32_t n_tot_chunks;
uint32_t n_chunks_per_task;
uint32_t n_cols;
uint32_t dst_stride; // DDR row stride
uint32_t src2_stride; // DDR row stride for residual
uint32_t dst_cols; // Actual output columns
struct htp_thread_trace * traces;
} output_transfer_task_state_t;
+35 -35
View File
@@ -42,14 +42,14 @@ static const int32_t hmx_transpose_scatter_offsets[32] __attribute__((aligned(VL
// Full range: start_row=0, end_row=n_cols.
static inline void hmx_interleave_rows_to_tiles(__fp16 * restrict vtcm_dst,
const __fp16 * restrict vtcm_src,
int n_cols,
int k,
int src_stride,
int start_row,
int end_row) {
uint32_t n_cols,
uint32_t k,
size_t src_stride,
uint32_t start_row,
uint32_t end_row) {
assert(k % HMX_FP16_TILE_N_COLS == 0);
const int n_k_tiles = k / HMX_FP16_TILE_N_COLS;
const uint32_t n_k_tiles = k / HMX_FP16_TILE_N_COLS;
const HVX_Vector v_scat_base = hvx_vmem(hmx_transpose_scatter_offsets);
const HVX_Vector v_scat_step = Q6_V_vsplat_R(4);
const HVX_VectorPred q_mask64 = Q6_Q_vsetq_R(64);
@@ -65,14 +65,14 @@ static inline void hmx_interleave_rows_to_tiles(__fp16 * restrict vtcm_dst,
if (pair_scatter) {
// Step c by 64 fp16 (two K-tiles per scatter), advance dst by 2 tiles per iter.
const int c_step = 2 * HMX_FP16_TILE_N_COLS;
const size_t c_byte_step = (size_t) c_step * sizeof(__fp16);
const size_t dst_step = 2 * (size_t) HMX_FP16_TILE_N_ELMS;
const int n_c_iters = k / c_step;
const uint32_t c_step = 2 * HMX_FP16_TILE_N_COLS;
const size_t c_byte_step = (size_t) c_step * sizeof(__fp16);
const size_t dst_step = 2 * (size_t) HMX_FP16_TILE_N_ELMS;
const uint32_t n_c_iters = k / c_step;
for (int r = start_row; r < end_row; r += 2) {
const int ct = r / HMX_FP16_TILE_N_ROWS;
const int local_r = r % HMX_FP16_TILE_N_ROWS;
for (uint32_t r = start_row; r < end_row; r += 2) {
const uint32_t ct = r / HMX_FP16_TILE_N_ROWS;
const uint32_t local_r = r % HMX_FP16_TILE_N_ROWS;
const bool next_row_valid = (r + 1) < end_row && (r + 1) < n_cols;
const HVX_Vector v_off0 = Q6_Vw_vadd_VwVw(v_scat_base, Q6_V_vsplat_R(local_r * 4));
const HVX_Vector v_off1 = Q6_Vw_vadd_VwVw(v_off0, v_scat_step);
@@ -86,7 +86,7 @@ static inline void hmx_interleave_rows_to_tiles(__fp16 * restrict vtcm_dst,
assert(c_byte_step % 128 == 0);
if (p1) {
for (int i = 0; i < n_c_iters; ++i) {
for (uint32_t i = 0; i < n_c_iters; ++i) {
HVX_Vector v0 = hvx_vmem(p0); p0 += c_byte_step;
HVX_Vector v1 = hvx_vmem(p1); p1 += c_byte_step;
Q6_vscatter_RMVwV((size_t) tile_base, pair_region, v_off0, v0);
@@ -95,7 +95,7 @@ static inline void hmx_interleave_rows_to_tiles(__fp16 * restrict vtcm_dst,
}
} else {
const HVX_Vector vzero = Q6_V_vzero();
for (int i = 0; i < n_c_iters; ++i) {
for (uint32_t i = 0; i < n_c_iters; ++i) {
HVX_Vector v0 = hvx_vmem(p0); p0 += c_byte_step;
Q6_vscatter_RMVwV((size_t) tile_base, pair_region, v_off0, v0);
Q6_vscatter_RMVwV((size_t) tile_base, pair_region, v_off1, vzero);
@@ -105,14 +105,14 @@ static inline void hmx_interleave_rows_to_tiles(__fp16 * restrict vtcm_dst,
}
} else {
// Fallback: scatter one K-tile per call (region 2047, masked).
const int c_step = HMX_FP16_TILE_N_COLS;
const size_t c_byte_step = (size_t) c_step * sizeof(__fp16);
const size_t dst_step = (size_t) HMX_FP16_TILE_N_ELMS;
const int n_c_iters = k / c_step;
const uint32_t c_step = HMX_FP16_TILE_N_COLS;
const size_t c_byte_step = (size_t) c_step * sizeof(__fp16);
const size_t dst_step = (size_t) HMX_FP16_TILE_N_ELMS;
const uint32_t n_c_iters = k / c_step;
for (int r = start_row; r < end_row; r += 2) {
const int ct = r / HMX_FP16_TILE_N_ROWS;
const int local_r = r % HMX_FP16_TILE_N_ROWS;
for (uint32_t r = start_row; r < end_row; r += 2) {
const uint32_t ct = r / HMX_FP16_TILE_N_ROWS;
const uint32_t local_r = r % HMX_FP16_TILE_N_ROWS;
const bool next_row_valid = (r + 1) < end_row && (r + 1) < n_cols;
const HVX_Vector v_off0 = Q6_Vw_vadd_VwVw(v_scat_base, Q6_V_vsplat_R(local_r * 4));
const HVX_Vector v_off1 = Q6_Vw_vadd_VwVw(v_off0, v_scat_step);
@@ -122,7 +122,7 @@ static inline void hmx_interleave_rows_to_tiles(__fp16 * restrict vtcm_dst,
const uint8_t * p1 = next_row_valid ? (const uint8_t *) (vtcm_src + (r + 1) * src_stride) : NULL;
if (p1) {
for (int i = 0; i < n_c_iters; ++i) {
for (uint32_t i = 0; i < n_c_iters; ++i) {
HVX_Vector v0 = hvx_vmemu(p0); p0 += c_byte_step;
HVX_Vector v1 = hvx_vmemu(p1); p1 += c_byte_step;
Q6_vscatter_QRMVwV(q_mask64, (size_t) tile_base, single_region, v_off0, v0);
@@ -131,7 +131,7 @@ static inline void hmx_interleave_rows_to_tiles(__fp16 * restrict vtcm_dst,
}
} else {
const HVX_Vector vzero = Q6_V_vzero();
for (int i = 0; i < n_c_iters; ++i) {
for (uint32_t i = 0; i < n_c_iters; ++i) {
HVX_Vector v0 = hvx_vmemu(p0); p0 += c_byte_step;
Q6_vscatter_QRMVwV(q_mask64, (size_t) tile_base, single_region, v_off0, v0);
Q6_vscatter_QRMVwV(q_mask64, (size_t) tile_base, single_region, v_off1, vzero);
@@ -148,24 +148,24 @@ static inline void hmx_interleave_rows_to_tiles(__fp16 * restrict vtcm_dst,
// Full range: start_row=0, end_row=n_rows.
static inline void hmx_interleave_cols_to_tiles(__fp16 * restrict tiles_out,
const __fp16 * restrict src,
int n_rows,
int head_dim,
int src_stride,
int n_row_tiles,
int start_row,
int end_row) {
uint32_t n_rows,
uint32_t head_dim,
size_t src_stride,
uint32_t n_row_tiles,
uint32_t start_row,
uint32_t end_row) {
__builtin_assume(head_dim > 0);
const size_t tile_stride_elms = (size_t) n_row_tiles * HMX_FP16_TILE_N_ELMS;
for (int r = start_row; r < end_row; r += 2) {
for (uint32_t r = start_row; r < end_row; r += 2) {
const bool next_row_valid = (r + 1) < end_row && (r + 1) < n_rows;
const HVX_Vector * pv_in0 = (const HVX_Vector *) (src + r * src_stride);
const HVX_Vector * pv_in1 = next_row_valid ? (const HVX_Vector *) (src + (r + 1) * src_stride) : NULL;
// Row-pair invariants hoisted out of the c loop.
const int r0 = r / HMX_FP16_TILE_N_ROWS;
const int r1_half = (r % HMX_FP16_TILE_N_ROWS) / 2;
const uint32_t r0 = r / HMX_FP16_TILE_N_ROWS;
const uint32_t r1_half = (r % HMX_FP16_TILE_N_ROWS) / 2;
// tb0 starts at tile (c0=0, r0); tb1 at the adjacent dim-tile (c0=1, r0).
// Each c step (+= 64) advances both by 2 dim-tiles worth of fp16.
@@ -174,7 +174,7 @@ static inline void hmx_interleave_cols_to_tiles(__fp16 * restrict tiles_out,
const size_t tb_step = 2 * tile_stride_elms;
if (pv_in1) {
for (int c = 0; c < head_dim; c += 64) {
for (uint32_t c = 0; c < head_dim; c += 64) {
HVX_Vector v0 = *pv_in0++;
HVX_Vector v1 = *pv_in1++;
HVX_VectorPair vp = Q6_W_vshuff_VVR(v1, v0, -2);
@@ -185,7 +185,7 @@ static inline void hmx_interleave_cols_to_tiles(__fp16 * restrict tiles_out,
}
} else {
const HVX_Vector vzero = Q6_V_vzero();
for (int c = 0; c < head_dim; c += 64) {
for (uint32_t c = 0; c < head_dim; c += 64) {
HVX_Vector v0 = *pv_in0++;
HVX_VectorPair vp = Q6_W_vshuff_VVR(vzero, v0, -2);
((HVX_Vector *) tb0)[r1_half] = Q6_V_lo_W(vp);
+6
View File
@@ -60,6 +60,7 @@ enum htp_op_code {
HTP_OP_MUL_MAT_ID,
HTP_OP_MUL_MAT_QKV,
HTP_OP_MUL_MAT_FFN,
HTP_OP_MUL_MAT_ADD,
HTP_OP_RMS_NORM,
HTP_OP_RMS_NORM_MUL,
HTP_OP_UNARY_SILU,
@@ -175,6 +176,11 @@ enum htp_trace_event_id {
HTP_TRACE_EVT_HVX_W_DEQUANT = 23,
HTP_TRACE_EVT_HVX_W_PREP = 24,
HTP_TRACE_EVT_HVX_O_PROC = 25,
HTP_TRACE_EVT_HVX_FA_QK = 26,
HTP_TRACE_EVT_HVX_FA_SFM = 27,
HTP_TRACE_EVT_HVX_FA_Q_PREP = 28,
HTP_TRACE_EVT_HVX_FA_K_PREP = 29,
HTP_TRACE_EVT_HVX_FA_V_PREP = 30,
HTP_TRACE_EVT_HMX_COMP = 40,
};
+1 -12
View File
@@ -134,16 +134,7 @@ static inline HVX_Vector hvx_vec_f32_to_f16_shuff(HVX_Vector v0, HVX_Vector v1)
}
static inline HVX_Vector hvx_vec_f32_to_f16(HVX_Vector v0, HVX_Vector v1) {
HVX_Vector v = Q6_Vh_vdeal_Vh(hvx_vec_f32_to_f16_shuff(v0, v1));
#if __HVX_ARCH__ < 79
// replace NaNs with -INF, older arches produce NaNs for (-INF + 0.0)
const HVX_Vector neg_inf = hvx_vec_splat_f16(-INFINITY);
HVX_VectorPred nan = hvx_vec_is_nan_f16(v);
v = Q6_V_vmux_QVV(nan, neg_inf, v);
#endif
return v;
return Q6_Vh_vdeal_Vh(hvx_vec_f32_to_f16_shuff(v0, v1));
}
#if __HVX_ARCH__ >= 79
@@ -170,8 +161,6 @@ static inline HVX_VectorPair hvx_vec_f16_to_f32(HVX_Vector v) {
}
#endif
static inline HVX_Vector hvx_vec_i16_from_hf_rnd_sat(HVX_Vector vin) {
// This looks complicated.
// Ideally should just be Q6_Vh_equals_Vhf(vin)
+39
View File
@@ -16,6 +16,7 @@
#define EXP_COEFF_0 (0x3F000000) // 0.5 = 1/(2!)
#define EXP_LOGN2 (0x3F317218) // ln(2) = 0.6931471805
#define EXP_LOG2E (0x3FB8AA3B) // log2(e) = 1/ln(2) = 1.4426950408
#define EXP_LOG2E_F 1.44269504f
#define EXP_ONE (0x3f800000) // 1.0
#define EXP_RANGE_R (0x42B17218) // ln(FLT_MAX) approx = 88.7228
#define EXP_RANGE_L (0xC2B00000) // -88.0 (approx log(FLT_MIN))
@@ -213,4 +214,42 @@ static inline void hvx_exp_f32(uint8_t * restrict dst, const uint8_t * restrict
}
}
static inline HVX_Vector hvx_vec_exp2_f16(HVX_Vector x_v) {
const HVX_Vector zero_v = Q6_V_vzero();
const HVX_Vector half_hf_v = Q6_Vh_vsplat_R(0x3800); // fp16 0.5
// Clamp input to prevent integer underflow in FP16-to-INT16 conversion
const HVX_Vector v_clamp_min = hvx_vec_splat_f16(-24.0f);
x_v = Q6_Vhf_vmax_VhfVhf(v_clamp_min, x_v);
// k = round_toward_neg_inf(x); f = (float)k; frac = x - f
HVX_Vector x_minus_half = Q6_Vhf_equals_Vqf16(Q6_Vqf16_vsub_VhfVhf(x_v, half_hf_v));
HVX_Vector k_v = Q6_Vh_equals_Vhf(x_minus_half); // truncate to int16
HVX_Vector f_v = Q6_Vhf_equals_Vh(k_v); // back to fp16
HVX_Vector x_qf16 = Q6_Vqf16_vsub_VhfVhf(x_v, f_v); // fractional part in qf16
// Horner: y = ((((E5*x + E4)*x + E3)*x + E2)*x + E1)*x + E0
HVX_Vector y = Q6_Vqf16_vmpy_Vqf16Vqf16(Q6_Vh_vsplat_R(0x5082), x_qf16); // E5*x
y = Q6_Vqf16_vadd_Vqf16Vhf(y, Q6_Vh_vsplat_R(0x157d)); // + E4
y = Q6_Vqf16_vmpy_Vqf16Vqf16(y, x_qf16);
y = Q6_Vqf16_vadd_Vqf16Vhf(y, Q6_Vh_vsplat_R(0x20ed)); // + E3
y = Q6_Vqf16_vmpy_Vqf16Vqf16(y, x_qf16);
y = Q6_Vqf16_vadd_Vqf16Vhf(y, Q6_Vh_vsplat_R(0x2b1b)); // + E2
y = Q6_Vqf16_vmpy_Vqf16Vqf16(y, x_qf16);
y = Q6_Vqf16_vadd_Vqf16Vhf(y, Q6_Vh_vsplat_R(0x33b0)); // + E1
y = Q6_Vqf16_vmpy_Vqf16Vqf16(y, x_qf16);
y = Q6_Vqf16_vadd_Vqf16Vhf(y, Q6_Vh_vsplat_R(0x398c)); // + E0
y = Q6_Vqf16_vmpy_Vqf16Vqf16(y, x_qf16); // y = y * x
y = Q6_Vqf16_vadd_Vqf16Vhf(y, Q6_Vh_vsplat_R(0x3c00)); // + 1.0
// Combine polynomial (mantissa) with integer part (exponent): result = y * 2^k
y = Q6_Vhf_equals_Vqf16(y);
HVX_Vector y_exp = Q6_Vuh_vlsr_VuhR(Q6_Vh_vasl_VhR(y, 1), 11);
y_exp = Q6_Vh_vadd_VhVh(k_v, y_exp);
HVX_VectorPred q_underflow = Q6_Q_vcmp_gt_VhVh(zero_v, y_exp);
y = Q6_Vh_vaslacc_VhVhR(y, k_v, 10);
return Q6_V_vmux_QVV(q_underflow, zero_v, y);
}
#endif /* HVX_EXP_H */
+232
View File
@@ -0,0 +1,232 @@
#ifndef HVX_FA_KERNELS_H
#define HVX_FA_KERNELS_H
#include <assert.h>
#include <math.h>
#include "hvx-utils.h"
// Little inner kernels for HVX
#if __HVX_ARCH__ < 79
#define HVX_OP_ADD_F32(a, b) Q6_Vsf_equals_Vqf32(Q6_Vqf32_vadd_VsfVsf(a, b))
#define HVX_OP_SUB_F32(a, b) Q6_Vsf_equals_Vqf32(Q6_Vqf32_vsub_VsfVsf(a, b))
#define HVX_OP_MUL_F32(a, b) Q6_Vsf_equals_Vqf32(Q6_Vqf32_vmpy_VsfVsf(a, b))
#else
#define HVX_OP_ADD_F32(a, b) Q6_Vsf_vadd_VsfVsf(a, b)
#define HVX_OP_SUB_F32(a, b) Q6_Vsf_vsub_VsfVsf(a, b)
#define HVX_OP_MUL_F32(a, b) Q6_Vsf_vmpy_VsfVsf(a, b)
#endif
// This is a bit of a hack because the compiler is struggling to properly inline
// the default hvx_vec_f32_to_f16 with output into the local array.
static __attribute__((unused)) __attribute__((noinline)) void hvx_vec_f32_to_f16_a(void *ptr, HVX_Vector v0, HVX_Vector v1)
{
*(HVX_Vector *) ptr = hvx_vec_f32_to_f16(v0, v1);
}
// Dot product of two F16 vectors, accumulating to float
static inline void hvx_dot_f16_f16_aa(float * restrict r, const void * restrict x, const void * restrict y, unsigned int n, float s) {
const HVX_Vector * restrict vx = (const HVX_Vector * restrict) x; // fp16
const HVX_Vector * restrict vy = (const HVX_Vector * restrict) y; // fp16
uint32_t nvec = n / VLEN_FP16; // num full fp16 hvx vectors
uint32_t nloe = n % VLEN_FP16; // leftover elements
HVX_VectorPair rsum_p = Q6_W_vcombine_VV(Q6_V_vsplat_R(0), Q6_V_vsplat_R(0));
uint32_t i = 0;
#pragma unroll(4)
for (i = 0; i < nvec; i++) {
rsum_p = hvx_vec_mpyacc_f32_f16(rsum_p, vx[i], vy[i]);
}
if (nloe) {
HVX_VectorPred bmask = Q6_Q_vsetq_R(nloe * 2);
HVX_Vector y_hf = Q6_V_vand_QV(bmask, vy[i]);
HVX_Vector x_hf = Q6_V_vand_QV(bmask, vx[i]);
rsum_p = hvx_vec_mpyacc_f32_f16(rsum_p, x_hf, y_hf);
}
HVX_Vector rsum = HVX_OP_ADD_F32(Q6_V_lo_W(rsum_p), Q6_V_hi_W(rsum_p));
rsum = HVX_OP_MUL_F32(hvx_vec_splat_f32(s), hvx_vec_reduce_sum_f32(rsum));
hvx_vec_store_u(r, 4, rsum);
}
static inline HVX_Vector hvx_dot_f16_f16_aa_rx4(const void * restrict y,
const uint8_t * restrict x,
const size_t stride_x,
const size_t nvec,
const size_t nloe) {
const HVX_Vector * restrict vx0 = (const HVX_Vector * restrict) x; // fp16
const HVX_Vector * restrict vx1 = (const HVX_Vector * restrict) (x + stride_x); // fp16
const HVX_Vector * restrict vx2 = (const HVX_Vector * restrict) (x + stride_x * 2); // fp16
const HVX_Vector * restrict vx3 = (const HVX_Vector * restrict) (x + stride_x * 3); // fp16
const HVX_Vector * restrict vy = (const HVX_Vector * restrict) y; // fp16
HVX_VectorPair rsum0_p = Q6_W_vcombine_VV(Q6_V_vsplat_R(0), Q6_V_vsplat_R(0));
HVX_VectorPair rsum1_p = Q6_W_vcombine_VV(Q6_V_vsplat_R(0), Q6_V_vsplat_R(0));
HVX_VectorPair rsum2_p = Q6_W_vcombine_VV(Q6_V_vsplat_R(0), Q6_V_vsplat_R(0));
HVX_VectorPair rsum3_p = Q6_W_vcombine_VV(Q6_V_vsplat_R(0), Q6_V_vsplat_R(0));
uint32_t i = 0;
for (i = 0; i < nvec; i++) {
HVX_Vector y_hf = vy[i];
HVX_Vector x0_hf = vx0[i];
HVX_Vector x1_hf = vx1[i];
HVX_Vector x2_hf = vx2[i];
HVX_Vector x3_hf = vx3[i];
rsum0_p = hvx_vec_mpyacc_f32_f16(rsum0_p, x0_hf, y_hf);
rsum1_p = hvx_vec_mpyacc_f32_f16(rsum1_p, x1_hf, y_hf);
rsum2_p = hvx_vec_mpyacc_f32_f16(rsum2_p, x2_hf, y_hf);
rsum3_p = hvx_vec_mpyacc_f32_f16(rsum3_p, x3_hf, y_hf);
}
if (nloe) {
// Load x (fp16) and zero-out unused elements
HVX_VectorPred bmask = Q6_Q_vsetq_R(nloe * 2);
HVX_Vector y_hf = Q6_V_vand_QV(bmask, vy[i]);
HVX_Vector x0_hf = Q6_V_vand_QV(bmask, vx0[i]);
HVX_Vector x1_hf = Q6_V_vand_QV(bmask, vx1[i]);
HVX_Vector x2_hf = Q6_V_vand_QV(bmask, vx2[i]);
HVX_Vector x3_hf = Q6_V_vand_QV(bmask, vx3[i]);
rsum0_p = hvx_vec_mpyacc_f32_f16(rsum0_p, x0_hf, y_hf);
rsum1_p = hvx_vec_mpyacc_f32_f16(rsum1_p, x1_hf, y_hf);
rsum2_p = hvx_vec_mpyacc_f32_f16(rsum2_p, x2_hf, y_hf);
rsum3_p = hvx_vec_mpyacc_f32_f16(rsum3_p, x3_hf, y_hf);
}
HVX_Vector rsum0 = HVX_OP_ADD_F32(Q6_V_lo_W(rsum0_p), Q6_V_hi_W(rsum0_p));
HVX_Vector rsum1 = HVX_OP_ADD_F32(Q6_V_lo_W(rsum1_p), Q6_V_hi_W(rsum1_p));
HVX_Vector rsum2 = HVX_OP_ADD_F32(Q6_V_lo_W(rsum2_p), Q6_V_hi_W(rsum2_p));
HVX_Vector rsum3 = HVX_OP_ADD_F32(Q6_V_lo_W(rsum3_p), Q6_V_hi_W(rsum3_p));
HVX_Vector_x4 rsum0123 = { .v = { rsum0, rsum1, rsum2, rsum3 } };
return hvx_vec_reduce_sum_f32x4(rsum0123);
}
static inline HVX_Vector hvx_dot_f16_f16_aa_rx32(const void * restrict y,
const uint8_t * restrict x,
const size_t stride_x,
const size_t n,
float s) {
const size_t nvec = n / VLEN_FP16; // num full fp16 hvx vectors
const size_t nloe = n % VLEN_FP16; // leftover elements
HVX_Vector sums = Q6_V_vzero();
const size_t stride_x_4 = stride_x * 4;
for (uint32_t j = 0; j < VLEN_FP32; j += 4) {
HVX_Vector sums_x4 = hvx_dot_f16_f16_aa_rx4(y, x, stride_x, nvec, nloe);
HVX_VectorPred pred = Q6_Q_vsetq_R(j * SIZEOF_FP32);
sums = Q6_V_vmux_QVV(pred, sums, sums_x4);
x += stride_x_4;
}
return HVX_OP_MUL_F32(hvx_vec_splat_f32(s), sums);
}
// MAD: y (F32) += x (F16) * s (F16)
static inline void hvx_mad_f32_f16_aa(float * restrict y, const void * restrict x, const __fp16 * restrict s, uint32_t n) {
const HVX_Vector * restrict vx0 = (const HVX_Vector *) x;
HVX_VectorPair * restrict vy_p = (HVX_VectorPair *) y;
HVX_Vector * restrict vy = (HVX_Vector *) y;
uint32_t nvec = n / VLEN_FP16; // num full fp16 hvx vectors
uint32_t nloe = n % VLEN_FP16; // leftover elements
HVX_Vector S0 = hvx_vec_splat_f16(*s);
uint32_t i = 0;
#pragma unroll(2)
for (i = 0; i < nvec; ++i) {
vy_p[i] = hvx_vec_mpyacc_f32_f16(vy_p[i], Q6_Vh_vshuff_Vh(vx0[i]), S0);
}
if (nloe) {
HVX_VectorPair xy_p = vy_p[i];
xy_p = hvx_vec_mpyacc_f32_f16(xy_p, Q6_Vh_vshuff_Vh(vx0[i]), S0);
HVX_Vector xy = Q6_V_lo_W(xy_p);
i = 2 * i; // index for vy
if (nloe >= VLEN_FP32) {
vy[i] = xy;
nloe -= VLEN_FP32; ++i; xy = Q6_V_hi_W(xy_p);
}
if (nloe) {
hvx_vec_store_a(&vy[i], nloe * 4, xy);
}
}
}
// MAD: y (F32) += x0 (F16) * s0 (F16) + x1 (F16) * s1 (F16)
static inline void hvx_mad_f32_f16_aa_rx2(float * restrict y, const void * restrict x0, const void * restrict x1,
const __fp16 * restrict s0, const __fp16 * restrict s1, uint32_t n) {
const HVX_Vector * restrict vx0 = (const HVX_Vector *) x0;
const HVX_Vector * restrict vx1 = (const HVX_Vector *) x1;
HVX_VectorPair * restrict vy_p = (HVX_VectorPair *) y;
HVX_Vector * restrict vy = (HVX_Vector *) y;
uint32_t nvec = n / VLEN_FP16; // num full fp16 hvx vectors
uint32_t nloe = n % VLEN_FP16; // leftover elements
HVX_Vector S0 = hvx_vec_splat_f16(*s0);
HVX_Vector S1 = hvx_vec_splat_f16(*s1);
uint32_t i = 0;
#pragma unroll(2)
for (i = 0; i < nvec; ++i) {
vy_p[i] = hvx_vec_mpyacc_f32_f16(vy_p[i], Q6_Vh_vshuff_Vh(vx0[i]), S0);
vy_p[i] = hvx_vec_mpyacc_f32_f16(vy_p[i], Q6_Vh_vshuff_Vh(vx1[i]), S1);
}
if (nloe) {
HVX_VectorPair xy_p = vy_p[i];
xy_p = hvx_vec_mpyacc_f32_f16(xy_p, Q6_Vh_vshuff_Vh(vx0[i]), S0);
xy_p = hvx_vec_mpyacc_f32_f16(xy_p, Q6_Vh_vshuff_Vh(vx1[i]), S1);
HVX_Vector xy = Q6_V_lo_W(xy_p);
i = 2 * i; // index for vy
if (nloe >= VLEN_FP32) {
vy[i] = xy;
nloe -= VLEN_FP32; ++i; xy = Q6_V_hi_W(xy_p);
}
if (nloe) {
hvx_vec_store_a(&vy[i], nloe * 4, xy);
}
}
}
static inline void hvx_scale_vec_f32_aa(uint8_t * restrict dst, const uint8_t * restrict src, const uint32_t n, HVX_Vector vs) {
assert((size_t) dst % 128 == 0);
assert((size_t) src % 128 == 0);
const HVX_Vector * restrict vsrc = (const HVX_Vector * restrict) src;
HVX_Vector * restrict vdst = (HVX_Vector * restrict) dst;
const uint32_t nvec = n / VLEN_FP32;
const uint32_t nloe = n % VLEN_FP32;
uint32_t i = 0;
#pragma unroll(4)
for (; i < nvec; ++i) {
vdst[i] = HVX_OP_MUL_F32(vsrc[i], vs);
}
if (nloe) {
hvx_vec_store_a(&vdst[i], nloe * sizeof(float), HVX_OP_MUL_F32(vsrc[i], vs));
}
}
#endif /* HVX_FA_KERNELS_H */
+512 -25
View File
@@ -256,7 +256,7 @@ static inline void quantize_f16_f16_flat_kernel(
// Dot kernels that consume flat (non-tiled) activations
static void flat_vec_dot_q4_0_32x1(const uint32_t n, float * restrict s, const void * restrict vx, const void * restrict vy, uint32_t valid_rows) {
static void flat_vec_dot_q4_0_32x1(const uint32_t n, float * restrict s, const void * restrict vx, const void * restrict vy, uint32_t valid_rows, const float * restrict sz) {
const uint8_t * restrict tile_ptr = vx;
const uint8_t * restrict y_q = vy;
@@ -312,10 +312,14 @@ static void flat_vec_dot_q4_0_32x1(const uint32_t n, float * restrict s, const v
v_sum_float = hvx_vec_add_f32_f32(v_sum_float, v_sum_scaled);
}
hvx_vec_store_u(s, valid_rows * sizeof(float), v_sum_float);
if (sz) {
hvx_vec_store_u(s, valid_rows * sizeof(float), hvx_vec_add_f32_f32(v_sum_float, hvx_vmemu(sz)));
} else {
hvx_vec_store_u(s, valid_rows * sizeof(float), v_sum_float);
}
}
static void flat_vec_dot_q4_0_32x2(const uint32_t n, float * restrict s0, float * restrict s1, const void * restrict vx, const void * restrict vy0, const void * restrict vy1, uint32_t valid_rows) {
static void flat_vec_dot_q4_0_32x2(const uint32_t n, float * restrict s0, float * restrict s1, const void * restrict vx, const void * restrict vy0, const void * restrict vy1, uint32_t valid_rows, const float * restrict sz0, const float * restrict sz1) {
const uint8_t * restrict tile_ptr = vx;
const uint8_t * restrict y0_q = vy0;
const uint8_t * restrict y1_q = vy1;
@@ -397,11 +401,19 @@ static void flat_vec_dot_q4_0_32x2(const uint32_t n, float * restrict s0, float
v_sum_float_c1 = hvx_vec_add_f32_f32(v_sum_float_c1, v_sum_scaled_c1);
}
hvx_vec_store_u(s0, valid_rows * sizeof(float), v_sum_float_c0);
hvx_vec_store_u(s1, valid_rows * sizeof(float), v_sum_float_c1);
if (sz0) {
hvx_vec_store_u(s0, valid_rows * sizeof(float), hvx_vec_add_f32_f32(v_sum_float_c0, hvx_vmemu(sz0)));
} else {
hvx_vec_store_u(s0, valid_rows * sizeof(float), v_sum_float_c0);
}
if (sz1) {
hvx_vec_store_u(s1, valid_rows * sizeof(float), hvx_vec_add_f32_f32(v_sum_float_c1, hvx_vmemu(sz1)));
} else {
hvx_vec_store_u(s1, valid_rows * sizeof(float), v_sum_float_c1);
}
}
static void flat_vec_dot_q4_1_32x1(const uint32_t n, float * restrict s, const void * restrict vx, const void * restrict vy, uint32_t valid_rows) {
static void flat_vec_dot_q4_1_32x1(const uint32_t n, float * restrict s, const void * restrict vx, const void * restrict vy, uint32_t valid_rows, const float * restrict sz) {
const uint8_t * restrict tile_ptr = vx;
const uint8_t * restrict y_q = vy;
@@ -464,10 +476,14 @@ static void flat_vec_dot_q4_1_32x1(const uint32_t n, float * restrict s, const v
v_sum_float = hvx_vec_add_f32_f32(v_sum_float, v_sum_scaled);
}
hvx_vec_store_u(s, valid_rows * sizeof(float), v_sum_float);
if (sz) {
hvx_vec_store_u(s, valid_rows * sizeof(float), hvx_vec_add_f32_f32(v_sum_float, hvx_vmemu(sz)));
} else {
hvx_vec_store_u(s, valid_rows * sizeof(float), v_sum_float);
}
}
static void flat_vec_dot_q4_1_32x2(const uint32_t n, float * restrict s0, float * restrict s1, const void * restrict vx, const void * restrict vy0, const void * restrict vy1, uint32_t valid_rows) {
static void flat_vec_dot_q4_1_32x2(const uint32_t n, float * restrict s0, float * restrict s1, const void * restrict vx, const void * restrict vy0, const void * restrict vy1, uint32_t valid_rows, const float * restrict sz0, const float * restrict sz1) {
const uint8_t * restrict tile_ptr = vx;
const uint8_t * restrict y0_q = vy0;
const uint8_t * restrict y1_q = vy1;
@@ -561,11 +577,19 @@ static void flat_vec_dot_q4_1_32x2(const uint32_t n, float * restrict s0, float
v_sum_float_c1 = hvx_vec_add_f32_f32(v_sum_float_c1, v_sum_scaled_c1);
}
hvx_vec_store_u(s0, valid_rows * sizeof(float), v_sum_float_c0);
hvx_vec_store_u(s1, valid_rows * sizeof(float), v_sum_float_c1);
if (sz0) {
hvx_vec_store_u(s0, valid_rows * sizeof(float), hvx_vec_add_f32_f32(v_sum_float_c0, hvx_vmemu(sz0)));
} else {
hvx_vec_store_u(s0, valid_rows * sizeof(float), v_sum_float_c0);
}
if (sz1) {
hvx_vec_store_u(s1, valid_rows * sizeof(float), hvx_vec_add_f32_f32(v_sum_float_c1, hvx_vmemu(sz1)));
} else {
hvx_vec_store_u(s1, valid_rows * sizeof(float), v_sum_float_c1);
}
}
static void flat_vec_dot_q8_0_32x1(const uint32_t n, float * restrict s, const void * restrict vx, const void * restrict vy, uint32_t valid_rows) {
static void flat_vec_dot_q8_0_32x1(const uint32_t n, float * restrict s, const void * restrict vx, const void * restrict vy, uint32_t valid_rows, const float * restrict sz) {
const uint8_t * restrict tile_ptr = vx;
const uint8_t * restrict y_q = vy;
@@ -620,10 +644,14 @@ static void flat_vec_dot_q8_0_32x1(const uint32_t n, float * restrict s, const v
v_sum_float = hvx_vec_add_f32_f32(v_sum_float, v_sum_scaled);
}
hvx_vec_store_u(s, valid_rows * sizeof(float), v_sum_float);
if (sz) {
hvx_vec_store_u(s, valid_rows * sizeof(float), hvx_vec_add_f32_f32(v_sum_float, hvx_vmemu(sz)));
} else {
hvx_vec_store_u(s, valid_rows * sizeof(float), v_sum_float);
}
}
static void flat_vec_dot_q8_0_32x2(const uint32_t n, float * restrict s0, float * restrict s1, const void * restrict vx, const void * restrict vy0, const void * restrict vy1, uint32_t valid_rows) {
static void flat_vec_dot_q8_0_32x2(const uint32_t n, float * restrict s0, float * restrict s1, const void * restrict vx, const void * restrict vy0, const void * restrict vy1, uint32_t valid_rows, const float * restrict sz0, const float * restrict sz1) {
const uint8_t * restrict tile_ptr = vx;
const uint8_t * restrict y0_q = vy0;
const uint8_t * restrict y1_q = vy1;
@@ -704,11 +732,19 @@ static void flat_vec_dot_q8_0_32x2(const uint32_t n, float * restrict s0, float
v_sum_float_c1 = hvx_vec_add_f32_f32(v_sum_float_c1, v_sum_scaled_c1);
}
hvx_vec_store_u(s0, valid_rows * sizeof(float), v_sum_float_c0);
hvx_vec_store_u(s1, valid_rows * sizeof(float), v_sum_float_c1);
if (sz0) {
hvx_vec_store_u(s0, valid_rows * sizeof(float), hvx_vec_add_f32_f32(v_sum_float_c0, hvx_vmemu(sz0)));
} else {
hvx_vec_store_u(s0, valid_rows * sizeof(float), v_sum_float_c0);
}
if (sz1) {
hvx_vec_store_u(s1, valid_rows * sizeof(float), hvx_vec_add_f32_f32(v_sum_float_c1, hvx_vmemu(sz1)));
} else {
hvx_vec_store_u(s1, valid_rows * sizeof(float), v_sum_float_c1);
}
}
static void flat_vec_dot_iq4nl_32x1(const uint32_t n, float * restrict s, const void * restrict vx, const void * restrict vy, uint32_t valid_rows) {
static void flat_vec_dot_iq4nl_32x1(const uint32_t n, float * restrict s, const void * restrict vx, const void * restrict vy, uint32_t valid_rows, const float * restrict sz) {
const uint8_t * restrict tile_ptr = vx;
const uint8_t * restrict y_q = vy;
@@ -765,10 +801,14 @@ static void flat_vec_dot_iq4nl_32x1(const uint32_t n, float * restrict s, const
v_sum_float = hvx_vec_add_f32_f32(v_sum_float, v_sum_scaled);
}
hvx_vec_store_u(s, valid_rows * sizeof(float), v_sum_float);
if (sz) {
hvx_vec_store_u(s, valid_rows * sizeof(float), hvx_vec_add_f32_f32(v_sum_float, hvx_vmemu(sz)));
} else {
hvx_vec_store_u(s, valid_rows * sizeof(float), v_sum_float);
}
}
static void flat_vec_dot_iq4nl_32x2(const uint32_t n, float * restrict s0, float * restrict s1, const void * restrict vx, const void * restrict vy0, const void * restrict vy1, uint32_t valid_rows) {
static void flat_vec_dot_iq4nl_32x2(const uint32_t n, float * restrict s0, float * restrict s1, const void * restrict vx, const void * restrict vy0, const void * restrict vy1, uint32_t valid_rows, const float * restrict sz0, const float * restrict sz1) {
const uint8_t * restrict tile_ptr = vx;
const uint8_t * restrict y0_q = vy0;
const uint8_t * restrict y1_q = vy1;
@@ -851,11 +891,19 @@ static void flat_vec_dot_iq4nl_32x2(const uint32_t n, float * restrict s0, float
v_sum_float_c1 = hvx_vec_add_f32_f32(v_sum_float_c1, v_sum_scaled_c1);
}
hvx_vec_store_u(s0, valid_rows * sizeof(float), v_sum_float_c0);
hvx_vec_store_u(s1, valid_rows * sizeof(float), v_sum_float_c1);
if (sz0) {
hvx_vec_store_u(s0, valid_rows * sizeof(float), hvx_vec_add_f32_f32(v_sum_float_c0, hvx_vmemu(sz0)));
} else {
hvx_vec_store_u(s0, valid_rows * sizeof(float), v_sum_float_c0);
}
if (sz1) {
hvx_vec_store_u(s1, valid_rows * sizeof(float), hvx_vec_add_f32_f32(v_sum_float_c1, hvx_vmemu(sz1)));
} else {
hvx_vec_store_u(s1, valid_rows * sizeof(float), v_sum_float_c1);
}
}
static void flat_vec_dot_mxfp4_32x1(const uint32_t n, float * restrict s, const void * restrict vx, const void * restrict vy, uint32_t valid_rows) {
static void flat_vec_dot_mxfp4_32x1(const uint32_t n, float * restrict s, const void * restrict vx, const void * restrict vy, uint32_t valid_rows, const float * restrict sz) {
const uint8_t * restrict tile_ptr = vx;
const uint8_t * restrict y_q = vy;
@@ -921,10 +969,14 @@ static void flat_vec_dot_mxfp4_32x1(const uint32_t n, float * restrict s, const
v_sum_float = hvx_vec_mul_f32_f32(v_sum_float, hvx_vec_splat_f32(0.5f));
hvx_vec_store_u(s, valid_rows * sizeof(float), v_sum_float);
if (sz) {
hvx_vec_store_u(s, valid_rows * sizeof(float), hvx_vec_add_f32_f32(v_sum_float, hvx_vmemu(sz)));
} else {
hvx_vec_store_u(s, valid_rows * sizeof(float), v_sum_float);
}
}
static void flat_vec_dot_mxfp4_32x2(const uint32_t n, float * restrict s0, float * restrict s1, const void * restrict vx, const void * restrict vy0, const void * restrict vy1, uint32_t valid_rows) {
static void flat_vec_dot_mxfp4_32x2(const uint32_t n, float * restrict s0, float * restrict s1, const void * restrict vx, const void * restrict vy0, const void * restrict vy1, uint32_t valid_rows, const float * restrict sz0, const float * restrict sz1) {
const uint8_t * restrict tile_ptr = vx;
const uint8_t * restrict y0_q = vy0;
const uint8_t * restrict y1_q = vy1;
@@ -1019,6 +1071,441 @@ static void flat_vec_dot_mxfp4_32x2(const uint32_t n, float * restrict s0, float
v_sum_float_c0 = hvx_vec_mul_f32_f32(v_sum_float_c0, hvx_vec_splat_f32(0.5f));
v_sum_float_c1 = hvx_vec_mul_f32_f32(v_sum_float_c1, hvx_vec_splat_f32(0.5f));
hvx_vec_store_u(s0, valid_rows * sizeof(float), v_sum_float_c0);
hvx_vec_store_u(s1, valid_rows * sizeof(float), v_sum_float_c1);
if (sz0) {
hvx_vec_store_u(s0, valid_rows * sizeof(float), hvx_vec_add_f32_f32(v_sum_float_c0, hvx_vmemu(sz0)));
} else {
hvx_vec_store_u(s0, valid_rows * sizeof(float), v_sum_float_c0);
}
if (sz1) {
hvx_vec_store_u(s1, valid_rows * sizeof(float), hvx_vec_add_f32_f32(v_sum_float_c1, hvx_vmemu(sz1)));
} else {
hvx_vec_store_u(s1, valid_rows * sizeof(float), v_sum_float_c1);
}
}
#if __HVX_ARCH__ < 79
#define HVX_OP_ADD_F32(a, b) Q6_Vsf_equals_Vqf32(Q6_Vqf32_vadd_VsfVsf(a, b))
#define HVX_OP_MUL_F32(a, b) Q6_Vsf_equals_Vqf32(Q6_Vqf32_vmpy_VsfVsf(a, b))
#else
#define HVX_OP_ADD_F32(a, b) Q6_Vsf_vadd_VsfVsf(a, b)
#define HVX_OP_MUL_F32(a, b) Q6_Vsf_vmpy_VsfVsf(a, b)
#endif
static inline void vec_dot_f32_f32_aa_1x1(const uint32_t n, float * restrict s, const void * restrict vx, const void * restrict vy) {
const HVX_Vector * restrict x = (const HVX_Vector *) vx;
const HVX_Vector * restrict y = (const HVX_Vector *) vy;
uint32_t nvec = n / VLEN_FP32; // num full fp32 hvx vectors
uint32_t nloe = n % VLEN_FP32; // leftover elements
HVX_Vector rsum = Q6_V_vzero();
uint32_t i = 0;
#pragma unroll(4)
for (i = 0; i < nvec; i++) {
HVX_Vector prod = HVX_OP_MUL_F32(x[i], y[i]);
rsum = HVX_OP_ADD_F32(rsum, prod);
}
if (nloe) {
HVX_VectorPred bmask = Q6_Q_vsetq_R(nloe * 4);
HVX_Vector x_sf = Q6_V_vand_QV(bmask, x[i]);
HVX_Vector y_sf = Q6_V_vand_QV(bmask, y[i]);
HVX_Vector prod = HVX_OP_MUL_F32(x_sf, y_sf);
rsum = HVX_OP_ADD_F32(rsum, prod);
}
*s = hvx_vec_get_f32(hvx_vec_reduce_sum_f32(rsum));
}
static inline void vec_dot_f32_f32_aa_2x1(const uint32_t n, float * restrict s0,
const void * restrict vx0, const void * restrict vx1,
const void * restrict vy0) {
const HVX_Vector * restrict x0 = (const HVX_Vector *) vx0;
const HVX_Vector * restrict x1 = (const HVX_Vector *) vx1;
const HVX_Vector * restrict y = (const HVX_Vector *) vy0;
uint32_t nvec = n / VLEN_FP32;
uint32_t nloe = n % VLEN_FP32;
HVX_Vector rsum0 = Q6_V_vzero();
HVX_Vector rsum1 = Q6_V_vzero();
uint32_t i = 0;
#pragma unroll(2)
for (i = 0; i < nvec; i++) {
HVX_Vector y_sf = y[i];
HVX_Vector prod0 = HVX_OP_MUL_F32(x0[i], y_sf);
HVX_Vector prod1 = HVX_OP_MUL_F32(x1[i], y_sf);
rsum0 = HVX_OP_ADD_F32(rsum0, prod0);
rsum1 = HVX_OP_ADD_F32(rsum1, prod1);
}
if (nloe) {
HVX_VectorPred bmask = Q6_Q_vsetq_R(nloe * 4);
HVX_Vector y_sf = Q6_V_vand_QV(bmask, y[i]);
HVX_Vector x0_sf = Q6_V_vand_QV(bmask, x0[i]);
HVX_Vector x1_sf = Q6_V_vand_QV(bmask, x1[i]);
HVX_Vector prod0 = HVX_OP_MUL_F32(x0_sf, y_sf);
HVX_Vector prod1 = HVX_OP_MUL_F32(x1_sf, y_sf);
rsum0 = HVX_OP_ADD_F32(rsum0, prod0);
rsum1 = HVX_OP_ADD_F32(rsum1, prod1);
}
HVX_Vector rsum = hvx_vec_reduce_sum_f32x2(rsum0, rsum1);
hvx_vec_store_u(s0, 8, rsum);
}
static inline void vec_dot_f32_f32_aa_2x2(const uint32_t n, float * restrict s0, float * restrict s1,
const void * restrict vx0, const void * restrict vx1,
const void * restrict vy0, const void * restrict vy1) {
const HVX_Vector * restrict x0 = (const HVX_Vector *) vx0;
const HVX_Vector * restrict x1 = (const HVX_Vector *) vx1;
const HVX_Vector * restrict y0 = (const HVX_Vector *) vy0;
const HVX_Vector * restrict y1 = (const HVX_Vector *) vy1;
uint32_t nvec = n / VLEN_FP32;
uint32_t nloe = n % VLEN_FP32;
HVX_Vector r0_c0_sum = Q6_V_vzero();
HVX_Vector r0_c1_sum = Q6_V_vzero();
HVX_Vector r1_c0_sum = Q6_V_vzero();
HVX_Vector r1_c1_sum = Q6_V_vzero();
uint32_t i = 0;
#pragma unroll(2)
for (i = 0; i < nvec; i++) {
HVX_Vector r0_sf = x0[i];
HVX_Vector r1_sf = x1[i];
HVX_Vector c0_sf = y0[i];
HVX_Vector c1_sf = y1[i];
r0_c0_sum = HVX_OP_ADD_F32(r0_c0_sum, HVX_OP_MUL_F32(r0_sf, c0_sf));
r0_c1_sum = HVX_OP_ADD_F32(r0_c1_sum, HVX_OP_MUL_F32(r0_sf, c1_sf));
r1_c0_sum = HVX_OP_ADD_F32(r1_c0_sum, HVX_OP_MUL_F32(r1_sf, c0_sf));
r1_c1_sum = HVX_OP_ADD_F32(r1_c1_sum, HVX_OP_MUL_F32(r1_sf, c1_sf));
}
if (nloe) {
HVX_VectorPred bmask = Q6_Q_vsetq_R(nloe * 4);
HVX_Vector r0_sf = Q6_V_vand_QV(bmask, x0[i]);
HVX_Vector r1_sf = Q6_V_vand_QV(bmask, x1[i]);
HVX_Vector c0_sf = Q6_V_vand_QV(bmask, y0[i]);
HVX_Vector c1_sf = Q6_V_vand_QV(bmask, y1[i]);
r0_c0_sum = HVX_OP_ADD_F32(r0_c0_sum, HVX_OP_MUL_F32(r0_sf, c0_sf));
r0_c1_sum = HVX_OP_ADD_F32(r0_c1_sum, HVX_OP_MUL_F32(r0_sf, c1_sf));
r1_c0_sum = HVX_OP_ADD_F32(r1_c0_sum, HVX_OP_MUL_F32(r1_sf, c0_sf));
r1_c1_sum = HVX_OP_ADD_F32(r1_c1_sum, HVX_OP_MUL_F32(r1_sf, c1_sf));
}
// Reduce and store results
HVX_Vector r0_r1_c0_sum = hvx_vec_reduce_sum_f32x2(r0_c0_sum, r1_c0_sum);
HVX_Vector r0_r1_c1_sum = hvx_vec_reduce_sum_f32x2(r0_c1_sum, r1_c1_sum);
hvx_vec_store_u(s0, 8, r0_r1_c0_sum);
hvx_vec_store_u(s1, 8, r0_r1_c1_sum);
}
static inline void vec_dot_f32_f32_uu_1x1(const uint32_t n, float * restrict s, const void * restrict x, const void * restrict y) {
const HVX_UVector * restrict vx = (const HVX_UVector * restrict) x;
const HVX_UVector * restrict vy = (const HVX_UVector * restrict) y;
uint32_t nvec = n / VLEN_FP32; // num full fp32 hvx vectors
uint32_t nloe = n % VLEN_FP32; // leftover elements
HVX_Vector rsum = Q6_V_vzero();
uint32_t i = 0;
#pragma unroll(2)
for (i = 0; i < nvec; i++) {
HVX_Vector x_sf = vx[i];
HVX_Vector y_sf = vy[i];
rsum = HVX_OP_ADD_F32(rsum, HVX_OP_MUL_F32(x_sf, y_sf));
}
if (nloe) {
HVX_Vector x_sf = vx[i];
HVX_Vector y_sf = vy[i];
HVX_VectorPred bmask = Q6_Q_vsetq_R(nloe * 4);
x_sf = Q6_V_vand_QV(bmask, x_sf);
y_sf = Q6_V_vand_QV(bmask, y_sf);
rsum = HVX_OP_ADD_F32(rsum, HVX_OP_MUL_F32(x_sf, y_sf));
}
rsum = hvx_vec_reduce_sum_f32(rsum);
hvx_vec_store_u(&s[0], 4, rsum);
}
#undef HVX_OP_ADD_F32
#undef HVX_OP_MUL_F32
static inline void vec_dot_f16_f16_aa_1x1(const uint32_t n, float * restrict s, const void * restrict vx, const void * restrict vy) {
const HVX_Vector * restrict x = (const HVX_Vector *) vx;
const HVX_Vector * restrict y = (const HVX_Vector *) vy;
uint32_t nvec = n / VLEN_FP16; // num full fp16 hvx vectors
uint32_t nloe = n % VLEN_FP16; // leftover elements
HVX_VectorPair rsum_p = Q6_W_vzero();
uint32_t i = 0;
#pragma unroll(4)
for (i = 0; i < nvec; i++) {
rsum_p = hvx_vec_mpyacc_f32_f16(rsum_p, x[i], y[i]);
}
if (nloe) {
HVX_VectorPred bmask = Q6_Q_vsetq_R(nloe * 2);
HVX_Vector x_hf = Q6_V_vand_QV(bmask, x[i]);
HVX_Vector y_hf = Q6_V_vand_QV(bmask, y[i]);
rsum_p = hvx_vec_mpyacc_f32_f16(rsum_p, x_hf, y_hf);
}
HVX_Vector rsum = Q6_Vsf_equals_Vqf32(Q6_Vqf32_vadd_VsfVsf(Q6_V_lo_W(rsum_p), Q6_V_hi_W(rsum_p)));
hvx_vec_store_u(s, 4, hvx_vec_reduce_sum_f32(rsum));
}
static inline void vec_dot_f16_f16_aa_2x1(const uint32_t n, float * restrict s0,
const void * restrict vx0, const void * restrict vx1,
const void * restrict vy0) {
const HVX_Vector * restrict x0 = (const HVX_Vector *) vx0;
const HVX_Vector * restrict x1 = (const HVX_Vector *) vx1;
const HVX_Vector * restrict y = (const HVX_Vector *) vy0;
uint32_t nvec = n / VLEN_FP16;
uint32_t nloe = n % VLEN_FP16;
HVX_VectorPair rsum0_p = Q6_W_vzero();
HVX_VectorPair rsum1_p = Q6_W_vzero();
uint32_t i = 0;
#pragma unroll(2)
for (i = 0; i < nvec; i++) {
HVX_Vector y_hf = y[i];
rsum0_p = hvx_vec_mpyacc_f32_f16(rsum0_p, x0[i], y_hf);
rsum1_p = hvx_vec_mpyacc_f32_f16(rsum1_p, x1[i], y_hf);
}
if (nloe) {
HVX_VectorPred bmask = Q6_Q_vsetq_R(nloe * 2);
HVX_Vector y_hf = Q6_V_vand_QV(bmask, y[i]);
HVX_Vector x0_hf = Q6_V_vand_QV(bmask, x0[i]);
HVX_Vector x1_hf = Q6_V_vand_QV(bmask, x1[i]);
rsum0_p = hvx_vec_mpyacc_f32_f16(rsum0_p, x0_hf, y_hf);
rsum1_p = hvx_vec_mpyacc_f32_f16(rsum1_p, x1_hf, y_hf);
}
HVX_Vector rsum0 = Q6_Vsf_equals_Vqf32(Q6_Vqf32_vadd_VsfVsf(Q6_V_lo_W(rsum0_p), Q6_V_hi_W(rsum0_p)));
HVX_Vector rsum1 = Q6_Vsf_equals_Vqf32(Q6_Vqf32_vadd_VsfVsf(Q6_V_lo_W(rsum1_p), Q6_V_hi_W(rsum1_p)));
HVX_Vector rsum = hvx_vec_reduce_sum_f32x2(rsum0, rsum1);
hvx_vec_store_u(s0, 8, rsum);
}
static inline void vec_dot_f16_f16_aa_2x2(const uint32_t n, float * restrict s0, float * restrict s1,
const void * restrict vx0, const void * restrict vx1,
const void * restrict vy0, const void * restrict vy1) {
const HVX_Vector * restrict x0 = (const HVX_Vector *) vx0;
const HVX_Vector * restrict x1 = (const HVX_Vector *) vx1;
const HVX_Vector * restrict y0 = (const HVX_Vector *) vy0;
const HVX_Vector * restrict y1 = (const HVX_Vector *) vy1;
uint32_t nvec = n / VLEN_FP16;
uint32_t nloe = n % VLEN_FP16;
// Row sums (sf) - 4 accumulators for 2x2 tile
HVX_VectorPair r0_c0_sum_p = Q6_W_vzero();
HVX_VectorPair r0_c1_sum_p = Q6_W_vzero();
HVX_VectorPair r1_c0_sum_p = Q6_W_vzero();
HVX_VectorPair r1_c1_sum_p = Q6_W_vzero();
uint32_t i = 0;
#pragma unroll(2)
for (i = 0; i < nvec; i++) {
HVX_Vector r0_hf = x0[i];
HVX_Vector r1_hf = x1[i];
HVX_Vector c0_hf = y0[i];
HVX_Vector c1_hf = y1[i];
// Compute 4 dot products: r0xc0, r0xc1, r1xc0, r1xc1
r0_c0_sum_p = hvx_vec_mpyacc_f32_f16(r0_c0_sum_p, r0_hf, c0_hf);
r0_c1_sum_p = hvx_vec_mpyacc_f32_f16(r0_c1_sum_p, r0_hf, c1_hf);
r1_c0_sum_p = hvx_vec_mpyacc_f32_f16(r1_c0_sum_p, r1_hf, c0_hf);
r1_c1_sum_p = hvx_vec_mpyacc_f32_f16(r1_c1_sum_p, r1_hf, c1_hf);
}
if (nloe) {
HVX_VectorPred bmask = Q6_Q_vsetq_R(nloe * 2);
HVX_Vector r0_hf = Q6_V_vand_QV(bmask, x0[i]);
HVX_Vector r1_hf = Q6_V_vand_QV(bmask, x1[i]);
HVX_Vector c0_hf = Q6_V_vand_QV(bmask, y0[i]);
HVX_Vector c1_hf = Q6_V_vand_QV(bmask, y1[i]);
r0_c0_sum_p = hvx_vec_mpyacc_f32_f16(r0_c0_sum_p, r0_hf, c0_hf);
r0_c1_sum_p = hvx_vec_mpyacc_f32_f16(r0_c1_sum_p, r0_hf, c1_hf);
r1_c0_sum_p = hvx_vec_mpyacc_f32_f16(r1_c0_sum_p, r1_hf, c0_hf);
r1_c1_sum_p = hvx_vec_mpyacc_f32_f16(r1_c1_sum_p, r1_hf, c1_hf);
}
HVX_Vector r0_c0_sum = Q6_Vsf_equals_Vqf32(Q6_Vqf32_vadd_VsfVsf(Q6_V_lo_W(r0_c0_sum_p), Q6_V_hi_W(r0_c0_sum_p)));
HVX_Vector r0_c1_sum = Q6_Vsf_equals_Vqf32(Q6_Vqf32_vadd_VsfVsf(Q6_V_lo_W(r0_c1_sum_p), Q6_V_hi_W(r0_c1_sum_p)));
HVX_Vector r1_c0_sum = Q6_Vsf_equals_Vqf32(Q6_Vqf32_vadd_VsfVsf(Q6_V_lo_W(r1_c0_sum_p), Q6_V_hi_W(r1_c0_sum_p)));
HVX_Vector r1_c1_sum = Q6_Vsf_equals_Vqf32(Q6_Vqf32_vadd_VsfVsf(Q6_V_lo_W(r1_c1_sum_p), Q6_V_hi_W(r1_c1_sum_p)));
// Reduce and store results
HVX_Vector r0_r1_c0_sum = hvx_vec_reduce_sum_f32x2(r0_c0_sum, r1_c0_sum);
HVX_Vector r0_r1_c1_sum = hvx_vec_reduce_sum_f32x2(r0_c1_sum, r1_c1_sum);
hvx_vec_store_u(&s0[0], 8, r0_r1_c0_sum); // row0,col0 row1,col0
hvx_vec_store_u(&s1[0], 8, r0_r1_c1_sum); // row0,col1 row1,col1
}
static inline void vec_dot_f16_f16_uu_1x1(const uint32_t n, float * restrict s, const void * restrict vx, const void * restrict vy) {
const HVX_UVector * restrict x = (const HVX_UVector *) vx;
const HVX_UVector * restrict y = (const HVX_UVector *) vy;
uint32_t nvec = n / VLEN_FP16; // num full fp16 hvx vectors
uint32_t nloe = n % VLEN_FP16; // leftover elements
HVX_Vector rsum = Q6_V_vzero();
uint32_t i = 0;
#pragma unroll(4)
for (i = 0; i < nvec; i++) {
HVX_VectorPair xy_qf = Q6_Wqf32_vmpy_VhfVhf(x[i], y[i]);
rsum = Q6_Vqf32_vadd_Vqf32Vqf32(rsum, Q6_Vqf32_vadd_Vqf32Vqf32(Q6_V_lo_W(xy_qf), Q6_V_hi_W(xy_qf)));
}
if (nloe) {
HVX_VectorPred bmask = Q6_Q_vsetq_R(nloe * 2);
HVX_Vector x_hf = Q6_V_vand_QV(bmask, x[i]);
HVX_Vector y_hf = Q6_V_vand_QV(bmask, y[i]);
HVX_VectorPair xy_qf = Q6_Wqf32_vmpy_VhfVhf(x_hf, y_hf);
rsum = Q6_Vqf32_vadd_Vqf32Vqf32(rsum, Q6_Vqf32_vadd_Vqf32Vqf32(Q6_V_lo_W(xy_qf), Q6_V_hi_W(xy_qf)));
}
rsum = hvx_vec_reduce_sum_f32(Q6_Vsf_equals_Vqf32(rsum));
hvx_vec_store_u(&s[0], 4, rsum);
}
static inline void vec_dot_f16_f32_uu_1x1(const uint32_t n, float * restrict s, const void * restrict x, const void * restrict y) {
const HVX_UVector * restrict vx = (const HVX_UVector * restrict) x;
const HVX_UVector * restrict vy = (const HVX_UVector * restrict) y;
uint32_t nvec = n / VLEN_FP16; // num full fp16 hvx vectors
uint32_t nloe = n % VLEN_FP16; // leftover elements
const HVX_Vector zero = Q6_V_vzero();
HVX_Vector rsum = Q6_V_vzero();
uint32_t i = 0;
#pragma unroll(2)
for (i = 0; i < nvec; i++) {
// Load y (fp32) and convert into fp16
HVX_Vector y0_qf = Q6_Vqf32_vsub_VsfVsf(vy[i*2+0], zero); // 32 elements
HVX_Vector y1_qf = Q6_Vqf32_vsub_VsfVsf(vy[i*2+1], zero); // 32 elements
HVX_Vector y_hf = Q6_Vh_vdeal_Vh(Q6_Vhf_equals_Wqf32(Q6_W_vcombine_VV(y1_qf, y0_qf)));
// Load x (fp16)
HVX_Vector x_hf = vx[i];
HVX_VectorPair xy_qf = Q6_Wqf32_vmpy_VhfVhf(x_hf, y_hf);
rsum = Q6_Vqf32_vadd_Vqf32Vqf32(rsum, Q6_Vqf32_vadd_Vqf32Vqf32(Q6_V_lo_W(xy_qf), Q6_V_hi_W(xy_qf)));
}
if (nloe) {
// Load y (fp32) and convert into fp16
HVX_Vector y0_qf = Q6_Vqf32_vsub_VsfVsf(vy[i*2+0], zero); // 32 elements
HVX_Vector y1_qf = Q6_Vqf32_vsub_VsfVsf(vy[i*2+1], zero); // 32 elements
HVX_Vector y_hf = Q6_Vh_vdeal_Vh(Q6_Vhf_equals_Wqf32(Q6_W_vcombine_VV(y1_qf, y0_qf)));
// Load x (fp16)
HVX_Vector x_hf = vx[i];
// Zero-out unused elements
// Note that we need to clear both x and y because they may contain NANs
HVX_VectorPred bmask = Q6_Q_vsetq_R(nloe * 2);
x_hf = Q6_V_vand_QV(bmask, x_hf);
y_hf = Q6_V_vand_QV(bmask, y_hf);
HVX_VectorPair xy_qf = Q6_Wqf32_vmpy_VhfVhf(x_hf, y_hf);
rsum = Q6_Vqf32_vadd_Vqf32Vqf32(rsum, Q6_Vqf32_vadd_Vqf32Vqf32(Q6_V_lo_W(xy_qf), Q6_V_hi_W(xy_qf)));
}
// Convert into fp32 and reduce
rsum = hvx_vec_reduce_sum_f32(Q6_Vsf_equals_Vqf32(rsum));
hvx_vec_store_u(&s[0], 4, rsum);
}
static inline void hvx_tensor_add_f32_grid(
const struct htp_tensor * restrict dst,
const struct htp_tensor * restrict src2,
uint32_t start_row,
uint32_t end_row,
uint32_t start_col,
uint32_t end_col,
const struct fastdiv_values * div_ne11_12,
const struct fastdiv_values * div_ne11
) {
if (start_row >= end_row || start_col >= end_col) return;
const uint32_t nb1 = dst->nb[1]; // row stride in bytes
const uint32_t ne11 = dst->ne[1];
const uint32_t ne12 = dst->ne[2];
const uint32_t ne11_12 = ne11 * ne12;
const bool is_broadcast1 = (src2->ne[1] == 1);
const bool is_broadcast2 = (src2->ne[2] == 1);
const bool is_broadcast3 = (src2->ne[3] == 1);
for (uint32_t r = start_row; r < end_row; r++) {
float * dst_row = (float *) ((uint8_t *) dst->data + r * nb1);
uint32_t i13 = fastdiv(r, div_ne11_12);
uint32_t i12 = fastdiv(r - i13 * ne11_12, div_ne11);
uint32_t i11 = r - i13 * ne11_12 - i12 * ne11;
uint32_t i23 = is_broadcast3 ? 0 : i13;
uint32_t i22 = is_broadcast2 ? 0 : i12;
uint32_t i21 = is_broadcast1 ? 0 : i11;
const float * src2_row = (const float *) ((const uint8_t *) src2->data +
i21 * src2->nb[1] + i22 * src2->nb[2] + i23 * src2->nb[3]);
float * dst_ptr = &dst_row[start_col];
const float * src2_ptr = &src2_row[start_col];
int remaining = end_col - start_col;
while (remaining >= 32) {
HVX_Vector v_out = hvx_vmemu(dst_ptr);
HVX_Vector v_z = hvx_vmemu(src2_ptr);
hvx_vmemu(dst_ptr) = hvx_vec_add_f32_f32(v_out, v_z);
dst_ptr += 32;
src2_ptr += 32;
remaining -= 32;
}
if (remaining > 0) {
HVX_Vector v_out = hvx_vmemu(dst_ptr);
HVX_Vector v_z = hvx_vmemu(src2_ptr);
hvx_vec_store_u(dst_ptr, remaining * sizeof(float), hvx_vec_add_f32_f32(v_out, v_z));
}
}
}
@@ -378,7 +378,7 @@ static inline HVX_VectorPair accum_q8_0_32x2(
return Q6_W_vcombine_VV(v_sum1, v_sum0);
}
static void tiled_vec_dot_q4_0_32x1(const uint32_t n, float * restrict s, const void * restrict vx, const void * restrict vy, uint32_t valid_rows) {
static void tiled_vec_dot_q4_0_32x1(const uint32_t n, float * restrict s, const void * restrict vx, const void * restrict vy, uint32_t valid_rows, const float * restrict sz) {
const uint8_t * restrict tile_ptr = vx;
const uint8_t * restrict y_q = vy;
@@ -401,10 +401,14 @@ static void tiled_vec_dot_q4_0_32x1(const uint32_t n, float * restrict s, const
v_sum_float = hvx_vec_add_f32_f32(v_sum_float, v_sum_scaled);
}
hvx_vec_store_u(s, valid_rows * sizeof(float), v_sum_float);
if (sz) {
hvx_vec_store_u(s, valid_rows * sizeof(float), hvx_vec_add_f32_f32(v_sum_float, hvx_vmemu(sz)));
} else {
hvx_vec_store_u(s, valid_rows * sizeof(float), v_sum_float);
}
}
static void tiled_vec_dot_q4_0_32x2(const uint32_t n, float * restrict s0, float * restrict s1, const void * restrict vx, const void * restrict vy0, const void * restrict vy1, uint32_t valid_rows) {
static void tiled_vec_dot_q4_0_32x2(const uint32_t n, float * restrict s0, float * restrict s1, const void * restrict vx, const void * restrict vy0, const void * restrict vy1, uint32_t valid_rows, const float * restrict sz0, const float * restrict sz1) {
const uint8_t * restrict tile_ptr = vx;
const uint8_t * restrict y0_q = vy0;
const uint8_t * restrict y1_q = vy1;
@@ -484,11 +488,19 @@ static void tiled_vec_dot_q4_0_32x2(const uint32_t n, float * restrict s0, float
v_sum_float_c1 = hvx_vec_add_f32_f32(v_sum_float_c1, v_sum_scaled_c1);
}
hvx_vec_store_u(s0, valid_rows * sizeof(float), v_sum_float_c0);
hvx_vec_store_u(s1, valid_rows * sizeof(float), v_sum_float_c1);
if (sz0) {
hvx_vec_store_u(s0, valid_rows * sizeof(float), hvx_vec_add_f32_f32(v_sum_float_c0, hvx_vmemu(sz0)));
} else {
hvx_vec_store_u(s0, valid_rows * sizeof(float), v_sum_float_c0);
}
if (sz1) {
hvx_vec_store_u(s1, valid_rows * sizeof(float), hvx_vec_add_f32_f32(v_sum_float_c1, hvx_vmemu(sz1)));
} else {
hvx_vec_store_u(s1, valid_rows * sizeof(float), v_sum_float_c1);
}
}
static void tiled_vec_dot_q4_1_32x1(const uint32_t n, float * restrict s, const void * restrict vx, const void * restrict vy, uint32_t valid_rows) {
static void tiled_vec_dot_q4_1_32x1(const uint32_t n, float * restrict s, const void * restrict vx, const void * restrict vy, uint32_t valid_rows, const float * restrict sz) {
const uint8_t * restrict tile_ptr = vx;
const uint8_t * restrict y_q = vy;
@@ -519,10 +531,14 @@ static void tiled_vec_dot_q4_1_32x1(const uint32_t n, float * restrict s, const
v_sum_float = hvx_vec_add_f32_f32(v_sum_float, v_sum_scaled);
}
hvx_vec_store_u(s, valid_rows * sizeof(float), v_sum_float);
if (sz) {
hvx_vec_store_u(s, valid_rows * sizeof(float), hvx_vec_add_f32_f32(v_sum_float, hvx_vmemu(sz)));
} else {
hvx_vec_store_u(s, valid_rows * sizeof(float), v_sum_float);
}
}
static void tiled_vec_dot_q4_1_32x2(const uint32_t n, float * restrict s0, float * restrict s1, const void * restrict vx, const void * restrict vy0, const void * restrict vy1, uint32_t valid_rows) {
static void tiled_vec_dot_q4_1_32x2(const uint32_t n, float * restrict s0, float * restrict s1, const void * restrict vx, const void * restrict vy0, const void * restrict vy1, uint32_t valid_rows, const float * restrict sz0, const float * restrict sz1) {
const uint8_t * restrict tile_ptr = vx;
const uint8_t * restrict y0_q = vy0;
const uint8_t * restrict y1_q = vy1;
@@ -637,11 +653,19 @@ static void tiled_vec_dot_q4_1_32x2(const uint32_t n, float * restrict s0, float
v_sum_float_c1 = hvx_vec_add_f32_f32(v_sum_float_c1, v_sum_scaled_c1);
}
hvx_vec_store_u(s0, valid_rows * sizeof(float), v_sum_float_c0);
hvx_vec_store_u(s1, valid_rows * sizeof(float), v_sum_float_c1);
if (sz0) {
hvx_vec_store_u(s0, valid_rows * sizeof(float), hvx_vec_add_f32_f32(v_sum_float_c0, hvx_vmemu(sz0)));
} else {
hvx_vec_store_u(s0, valid_rows * sizeof(float), v_sum_float_c0);
}
if (sz1) {
hvx_vec_store_u(s1, valid_rows * sizeof(float), hvx_vec_add_f32_f32(v_sum_float_c1, hvx_vmemu(sz1)));
} else {
hvx_vec_store_u(s1, valid_rows * sizeof(float), v_sum_float_c1);
}
}
static void tiled_vec_dot_q8_0_32x1(const uint32_t n, float * restrict s, const void * restrict vx, const void * restrict vy, uint32_t valid_rows) {
static void tiled_vec_dot_q8_0_32x1(const uint32_t n, float * restrict s, const void * restrict vx, const void * restrict vy, uint32_t valid_rows, const float * restrict sz) {
const uint8_t * restrict tile_ptr = vx;
const uint8_t * restrict y_q = vy;
@@ -663,10 +687,14 @@ static void tiled_vec_dot_q8_0_32x1(const uint32_t n, float * restrict s, const
v_sum_float = hvx_vec_add_f32_f32(v_sum_float, v_sum_scaled);
}
hvx_vec_store_u(s, valid_rows * sizeof(float), v_sum_float);
if (sz) {
hvx_vec_store_u(s, valid_rows * sizeof(float), hvx_vec_add_f32_f32(v_sum_float, hvx_vmemu(sz)));
} else {
hvx_vec_store_u(s, valid_rows * sizeof(float), v_sum_float);
}
}
static void tiled_vec_dot_q8_0_32x2(const uint32_t n, float * restrict s0, float * restrict s1, const void * restrict vx, const void * restrict vy0, const void * restrict vy1, uint32_t valid_rows) {
static void tiled_vec_dot_q8_0_32x2(const uint32_t n, float * restrict s0, float * restrict s1, const void * restrict vx, const void * restrict vy0, const void * restrict vy1, uint32_t valid_rows, const float * restrict sz0, const float * restrict sz1) {
const uint8_t * restrict tile_ptr = vx;
const uint8_t * restrict y0_q = vy0;
const uint8_t * restrict y1_q = vy1;
@@ -745,11 +773,19 @@ static void tiled_vec_dot_q8_0_32x2(const uint32_t n, float * restrict s0, float
v_sum_float_c1 = hvx_vec_add_f32_f32(v_sum_float_c1, v_sum_scaled_c1);
}
hvx_vec_store_u(s0, valid_rows * sizeof(float), v_sum_float_c0);
hvx_vec_store_u(s1, valid_rows * sizeof(float), v_sum_float_c1);
if (sz0) {
hvx_vec_store_u(s0, valid_rows * sizeof(float), hvx_vec_add_f32_f32(v_sum_float_c0, hvx_vmemu(sz0)));
} else {
hvx_vec_store_u(s0, valid_rows * sizeof(float), v_sum_float_c0);
}
if (sz1) {
hvx_vec_store_u(s1, valid_rows * sizeof(float), hvx_vec_add_f32_f32(v_sum_float_c1, hvx_vmemu(sz1)));
} else {
hvx_vec_store_u(s1, valid_rows * sizeof(float), v_sum_float_c1);
}
}
static void tiled_vec_dot_iq4nl_32x1(const uint32_t n, float * restrict s, const void * restrict vx, const void * restrict vy, uint32_t valid_rows) {
static void tiled_vec_dot_iq4nl_32x1(const uint32_t n, float * restrict s, const void * restrict vx, const void * restrict vy, uint32_t valid_rows, const float * restrict sz) {
const uint8_t * restrict tile_ptr = vx;
const uint8_t * restrict y_q = vy;
@@ -773,10 +809,14 @@ static void tiled_vec_dot_iq4nl_32x1(const uint32_t n, float * restrict s, const
v_sum_float = hvx_vec_add_f32_f32(v_sum_float, v_sum_scaled);
}
hvx_vec_store_u(s, valid_rows * sizeof(float), v_sum_float);
if (sz) {
hvx_vec_store_u(s, valid_rows * sizeof(float), hvx_vec_add_f32_f32(v_sum_float, hvx_vmemu(sz)));
} else {
hvx_vec_store_u(s, valid_rows * sizeof(float), v_sum_float);
}
}
static void tiled_vec_dot_iq4nl_32x2(const uint32_t n, float * restrict s0, float * restrict s1, const void * restrict vx, const void * restrict vy0, const void * restrict vy1, uint32_t valid_rows) {
static void tiled_vec_dot_iq4nl_32x2(const uint32_t n, float * restrict s0, float * restrict s1, const void * restrict vx, const void * restrict vy0, const void * restrict vy1, uint32_t valid_rows, const float * restrict sz0, const float * restrict sz1) {
const uint8_t * restrict tile_ptr = vx;
const uint8_t * restrict y0_q = vy0;
const uint8_t * restrict y1_q = vy1;
@@ -857,11 +897,19 @@ static void tiled_vec_dot_iq4nl_32x2(const uint32_t n, float * restrict s0, floa
v_sum_float_c1 = hvx_vec_add_f32_f32(v_sum_float_c1, v_sum_scaled_c1);
}
hvx_vec_store_u(s0, valid_rows * sizeof(float), v_sum_float_c0);
hvx_vec_store_u(s1, valid_rows * sizeof(float), v_sum_float_c1);
if (sz0) {
hvx_vec_store_u(s0, valid_rows * sizeof(float), hvx_vec_add_f32_f32(v_sum_float_c0, hvx_vmemu(sz0)));
} else {
hvx_vec_store_u(s0, valid_rows * sizeof(float), v_sum_float_c0);
}
if (sz1) {
hvx_vec_store_u(s1, valid_rows * sizeof(float), hvx_vec_add_f32_f32(v_sum_float_c1, hvx_vmemu(sz1)));
} else {
hvx_vec_store_u(s1, valid_rows * sizeof(float), v_sum_float_c1);
}
}
static void tiled_vec_dot_mxfp4_32x1(const uint32_t n, float * restrict s, const void * restrict vx, const void * restrict vy, uint32_t valid_rows) {
static void tiled_vec_dot_mxfp4_32x1(const uint32_t n, float * restrict s, const void * restrict vx, const void * restrict vy, uint32_t valid_rows, const float * restrict sz) {
const uint8_t * restrict tile_ptr = vx;
const uint8_t * restrict y_q = vy;
@@ -896,10 +944,14 @@ static void tiled_vec_dot_mxfp4_32x1(const uint32_t n, float * restrict s, const
v_sum_float = hvx_vec_mul_f32_f32(v_sum_float, hvx_vec_splat_f32(0.5f));
hvx_vec_store_u(s, valid_rows * sizeof(float), v_sum_float);
if (sz) {
hvx_vec_store_u(s, valid_rows * sizeof(float), hvx_vec_add_f32_f32(v_sum_float, hvx_vmemu(sz)));
} else {
hvx_vec_store_u(s, valid_rows * sizeof(float), v_sum_float);
}
}
static void tiled_vec_dot_mxfp4_32x2(const uint32_t n, float * restrict s0, float * restrict s1, const void * restrict vx, const void * restrict vy0, const void * restrict vy1, uint32_t valid_rows) {
static void tiled_vec_dot_mxfp4_32x2(const uint32_t n, float * restrict s0, float * restrict s1, const void * restrict vx, const void * restrict vy0, const void * restrict vy1, uint32_t valid_rows, const float * restrict sz0, const float * restrict sz1) {
const uint8_t * restrict tile_ptr = vx;
const uint8_t * restrict y0_q = vy0;
const uint8_t * restrict y1_q = vy1;
@@ -1013,8 +1065,16 @@ static void tiled_vec_dot_mxfp4_32x2(const uint32_t n, float * restrict s0, floa
v_sum_float_c0 = hvx_vec_mul_f32_f32(v_sum_float_c0, hvx_vec_splat_f32(0.5f));
v_sum_float_c1 = hvx_vec_mul_f32_f32(v_sum_float_c1, hvx_vec_splat_f32(0.5f));
hvx_vec_store_u(s0, valid_rows * sizeof(float), v_sum_float_c0);
hvx_vec_store_u(s1, valid_rows * sizeof(float), v_sum_float_c1);
if (sz0) {
hvx_vec_store_u(s0, valid_rows * sizeof(float), hvx_vec_add_f32_f32(v_sum_float_c0, hvx_vmemu(sz0)));
} else {
hvx_vec_store_u(s0, valid_rows * sizeof(float), v_sum_float_c0);
}
if (sz1) {
hvx_vec_store_u(s1, valid_rows * sizeof(float), hvx_vec_add_f32_f32(v_sum_float_c1, hvx_vmemu(sz1)));
} else {
hvx_vec_store_u(s1, valid_rows * sizeof(float), v_sum_float_c1);
}
}
static inline void quantize_f32_q8_0_tiled_kernel(
+39
View File
@@ -3,6 +3,7 @@
#include "hvx-base.h"
#include "hvx-inverse.h"
#include "hvx-exp.h"
#define FAST_SIGMOID_LOG2F (0x3fb8aa3b) // 1.442695022
#define FAST_SIGMOID_C1 (0x3d009076) // 0.03138777
@@ -139,4 +140,42 @@ static inline void hvx_tanh_f32_aa(uint8_t * restrict dst, const uint8_t * restr
hvx_tanh_loop_body(HVX_Vector, HVX_Vector, hvx_vec_store_a);
}
static inline HVX_Vector hvx_vec_fast_sigmoid_f16(HVX_Vector x_v) {
const HVX_Vector v_one = hvx_vec_splat_f16(1.0f);
const HVX_Vector v_neg_log2e = hvx_vec_splat_f16(-EXP_LOG2E_F);
const HVX_Vector em_mask = Q6_Vh_vsplat_R(0x7FFF);
// Compute absolute value of x_v
HVX_Vector abs_x = Q6_V_vand_VV(x_v, em_mask);
// Compute u = -abs_x * log2(e) <= 0.
HVX_Vector u = hvx_vec_mul_f16_f16(abs_x, v_neg_log2e);
// Clamp input to prevent underflow in exp2
const HVX_Vector v_clamp_min = hvx_vec_splat_f16(-24.0f);
u = Q6_Vhf_vmax_VhfVhf(v_clamp_min, u);
HVX_Vector exp_val = hvx_vec_exp2_f16(u);
HVX_Vector denom = hvx_vec_add_f16_f16(v_one, exp_val);
HVX_Vector sig_abs = hvx_vec_inverse_f16(denom);
// check if x_v < 0 (using integer comparison on absolute value)
HVX_VectorPred is_neg = Q6_Q_vcmp_gt_VhVh(abs_x, x_v);
// If x_v < 0, return 1.0f - sig_abs
HVX_Vector sig_neg = Q6_Vhf_equals_Vqf16(Q6_Vqf16_vsub_VhfVhf(v_one, sig_abs));
return Q6_V_vmux_QVV(is_neg, sig_neg, sig_abs);
}
static inline HVX_Vector hvx_vec_tanh_f16(HVX_Vector x) {
// tanh(x) = 2 * sigmoid(2x) - 1
const HVX_Vector v_two = hvx_vec_splat_f16(2.0f);
HVX_Vector x2 = hvx_vec_mul_f16_f16(x, v_two);
HVX_Vector sig2x = hvx_vec_fast_sigmoid_f16(x2);
const HVX_Vector v_neg_one = hvx_vec_splat_f16(-1.0f);
return hvx_vec_add_f16_f16(hvx_vec_mul_f16_f16(sig2x, v_two), v_neg_one);
}
#endif /* HVX_SIGMOID_H */
+1
View File
@@ -575,6 +575,7 @@ static inline void profile_stop(uint32_t mode, struct profile_data * d) {
static int execute_op(struct htp_ops_context * octx) {
switch (octx->op) {
case HTP_OP_MUL_MAT:
case HTP_OP_MUL_MAT_ADD:
return op_matmul(octx);
case HTP_OP_MUL_MAT_ID:
File diff suppressed because it is too large Load Diff
+19 -32
View File
@@ -392,56 +392,49 @@ static inline size_t htp_mm_hvx_get_vtcm_sizes(
case HTP_MM_KERNEL_HVX_QUANT_ROW: {
size_t q_src1_row_size = (wtype == HTP_TYPE_Q4_1) ? htp_mm_q8_1_tiled_row_size(ne10) : htp_mm_q8_0_tiled_row_size(ne10);
vtcm_dst_size = dst_nrows > 0 ? htp_mm_round_up(dst_row_size, 128) : 0;
vtcm_src0_size = htp_mm_round_up(n_prefetch * src0_row_size_padded, 256);
vtcm_src1_size = htp_mm_round_up(q_src1_row_size * src1_nrows, 256);
// src0 spad is also used in dynamic quantizer to store padded src1 rows
size_t src1_row_size_padded = htp_mm_round_up(q_src1_row_size, QK_Q8_0_TILED * sizeof(float));
if (vtcm_src0_size < src1_row_size_padded) {
vtcm_src0_size = src1_row_size_padded;
}
vtcm_src0_size = vtcm_src0_size * n_threads;
vtcm_dst_size = vtcm_dst_size * n_threads;
if (is_repack) {
uint32_t aligned_tile_size = htp_mm_get_weight_aligned_tile_size(wtype);
uint32_t n_k_tiles = ne10 / 32;
uint32_t tile_row_size = n_k_tiles * aligned_tile_size;
size_t repacked_vtcm_size = htp_mm_round_up(n_prefetch * tile_row_size, 256);
if (repacked_vtcm_size < src1_row_size_padded) {
repacked_vtcm_size = src1_row_size_padded;
}
vtcm_src0_size = repacked_vtcm_size * n_threads;
}
size_t quant_scratch_size_per_thread = htp_mm_round_up(ne10 * sizeof(float), QK_Q8_0_TILED * sizeof(float));
size_t dst_size_per_thread = dst_nrows > 0 ? htp_mm_round_up(dst_row_size, 128) : 0;
if (dst_size_per_thread < quant_scratch_size_per_thread) {
dst_size_per_thread = quant_scratch_size_per_thread;
}
vtcm_dst_size = dst_size_per_thread * n_threads;
break;
}
case HTP_MM_KERNEL_HVX_QUANT_ROW_FLAT: {
size_t q_src1_row_size = (wtype == HTP_TYPE_Q4_1) ? htp_mm_q8_1_flat_row_size(ne10) : htp_mm_q8_0_flat_row_size(ne10);
vtcm_dst_size = dst_nrows > 0 ? htp_mm_round_up(dst_row_size, 128) : 0;
vtcm_src0_size = htp_mm_round_up(n_prefetch * src0_row_size_padded, 256);
vtcm_src1_size = htp_mm_round_up(q_src1_row_size * src1_nrows, 256);
size_t src1_row_size_padded = htp_mm_round_up(q_src1_row_size, 256);
if (vtcm_src0_size < src1_row_size_padded) {
vtcm_src0_size = src1_row_size_padded;
}
vtcm_src0_size = vtcm_src0_size * n_threads;
vtcm_dst_size = vtcm_dst_size * n_threads;
if (is_repack) {
uint32_t aligned_tile_size = htp_mm_get_weight_aligned_tile_size(wtype);
uint32_t n_k_tiles = ne10 / 32;
uint32_t tile_row_size = n_k_tiles * aligned_tile_size;
size_t repacked_vtcm_size = htp_mm_round_up(n_prefetch * tile_row_size, 256);
if (repacked_vtcm_size < src1_row_size_padded) {
repacked_vtcm_size = src1_row_size_padded;
}
vtcm_src0_size = repacked_vtcm_size * n_threads;
}
size_t quant_scratch_size_per_thread = htp_mm_round_up(ne10 * sizeof(float), QK_Q8_0_TILED * sizeof(float));
size_t dst_size_per_thread = dst_nrows > 0 ? htp_mm_round_up(dst_row_size, 128) : 0;
if (dst_size_per_thread < quant_scratch_size_per_thread) {
dst_size_per_thread = quant_scratch_size_per_thread;
}
vtcm_dst_size = dst_size_per_thread * n_threads;
break;
}
default:
@@ -463,7 +456,8 @@ static inline size_t htp_mm_hvx_id_get_vtcm_sizes(
size_t src0_row_size, // nb01
uint32_t n_prefetch,
size_t * vtcm_src0_size_out,
size_t * vtcm_src1_size_out
size_t * vtcm_src1_size_out,
size_t * vtcm_dst_size_out
) {
const bool is_repack = (wtype == HTP_TYPE_Q4_0 || wtype == HTP_TYPE_Q4_1 ||
wtype == HTP_TYPE_Q8_0 || wtype == HTP_TYPE_IQ4_NL ||
@@ -476,29 +470,22 @@ static inline size_t htp_mm_hvx_id_get_vtcm_sizes(
size_t src0_sz_per_thread = htp_mm_round_up(n_prefetch * src0_row_size_padded, 256);
size_t src1_sz = htp_mm_round_up(src1_row_size * src1_nrows, 256);
// src0 spad also holds temporary transposed src1 columns during dynamic quantization.
const size_t src1_row_size_padded = htp_mm_round_up(src1_row_size, QK_Q8_0_TILED * sizeof(float));
if (src0_sz_per_thread < src1_row_size_padded) {
src0_sz_per_thread = src1_row_size_padded;
}
if (is_repack) {
const uint32_t aligned_tile_size = htp_mm_get_weight_aligned_tile_size(wtype);
const uint32_t n_k_tiles = ne10 / 32;
const uint32_t tile_row_size = n_k_tiles * aligned_tile_size;
size_t repacked_vtcm_size = htp_mm_round_up(n_prefetch * tile_row_size, 256);
if (repacked_vtcm_size < src1_row_size_padded) {
repacked_vtcm_size = src1_row_size_padded;
}
src0_sz_per_thread = repacked_vtcm_size;
}
const size_t vtcm_src0_size = src0_sz_per_thread * n_threads;
const size_t vtcm_dst_size = htp_mm_round_up(ne10 * sizeof(float), QK_Q8_0_TILED * sizeof(float)) * n_threads;
*vtcm_src0_size_out = vtcm_src0_size;
*vtcm_src1_size_out = src1_sz;
*vtcm_dst_size_out = vtcm_dst_size;
return vtcm_src0_size + src1_sz;
return vtcm_src0_size + src1_sz + vtcm_dst_size;
}
#ifdef __cplusplus
+10
View File
@@ -31,6 +31,11 @@ if (GGML_OPENCL_EMBED_KERNELS)
target_include_directories(${TARGET_NAME} PRIVATE "${CMAKE_CURRENT_BINARY_DIR}/autogenerated")
endif ()
if (GGML_OPENCL_USE_ADRENO_BIN_KERNELS)
message(STATUS "OpenCL will use precompiled binary kernels for Adreno (improved performance on some platforms)")
add_compile_definitions(GGML_OPENCL_USE_ADRENO_BIN_KERNELS)
endif ()
function(ggml_opencl_add_kernel KNAME)
set(KERN_HDR ${CMAKE_CURRENT_BINARY_DIR}/autogenerated/${KNAME}.cl.h)
set(KERN_SRC ${CMAKE_CURRENT_SOURCE_DIR}/kernels/${KNAME}.cl)
@@ -78,6 +83,8 @@ set(GGML_OPENCL_KERNELS
mul_mv_f16_f32_l4
mul_mv_f16_f32
mul_mv_f32_f32
mul_mv_q1_0_f32
mul_mv_q1_0_f32_flat
mul_mv_q4_0_f32
mul_mv_q4_0_f32_v
mul_mv_q4_0_f32_8x_flat
@@ -128,6 +135,7 @@ set(GGML_OPENCL_KERNELS
moe_sort_by_expert
mul_mm_f32_f32_l4_lm
mul_mm_f16_f32_l4_lm
mul_mm_q1_0_f32_l4_lm
mul_mm_q4_0_f32_l4_lm
mul_mm_q4_1_f32_l4_lm
mul_mm_q5_0_f32_l4_lm
@@ -137,6 +145,8 @@ set(GGML_OPENCL_KERNELS
mul_mm_q4_k_f32_l4_lm
mul_mm_q5_k_f32_l4_lm
mul_mm_q6_k_f32_l4_lm
gemv_noshuffle_q1_0_f32
gemm_noshuffle_q1_0_f32
gemv_noshuffle_q4_0_f32
gemv_noshuffle_q4_0_f32_spec
gemm_noshuffle_q4_0_f32
File diff suppressed because it is too large Load Diff
+46
View File
@@ -27,6 +27,8 @@
#define QR5_1 2
#define QK8_0 32
#define QR8_0 1
#define QK1_0 128
#define QR1_0 1
#define QK_K 256
#define K_SCALE_SIZE (3 * QK_K / 64)
#define K_QUANTS_PER_ITERATION 2
@@ -38,6 +40,14 @@ typedef ushort uint16_t;
typedef int int32_t;
typedef uint uint32_t;
//------------------------------------------------------------------------------
// block_q1_0
//------------------------------------------------------------------------------
typedef struct {
half d; // delta
uchar qs[QK1_0/8]; // 1-bit signs (16 bytes)
} block_q1_0;
//------------------------------------------------------------------------------
// block_q4_0
//------------------------------------------------------------------------------
@@ -159,6 +169,42 @@ kernel void kernel_convert_f16_to_bf16(
}
}
//------------------------------------------------------------------------------
// kernel_convert_block_q1_0
// Convert block_q1_0 (AOS) to 2 separate arrays (SOA): quant bytes + scales.
// q1_0 bits are stored in natural order (bit j of byte i -> weight 8*i + j)
//------------------------------------------------------------------------------
kernel void kernel_convert_block_q1_0(
global block_q1_0 * src0,
global uchar * dst_q,
global half * dst_d
) {
global block_q1_0 * b = (global block_q1_0 *) src0 + get_global_id(0);
global uchar * q = (global uchar *) dst_q + (QK1_0/8)*get_global_id(0);
global half * d = (global half *) dst_d + get_global_id(0);
*d = b->d;
for (int i = 0; i < QK1_0/8; ++i) {
q[i] = b->qs[i];
}
}
kernel void kernel_restore_block_q1_0(
global uchar * src_q,
global half * src_d,
global block_q1_0 * dst
) {
global block_q1_0 * b = (global block_q1_0 *) dst + get_global_id(0);
global uchar * q = (global uchar *) src_q + (QK1_0/8)*get_global_id(0);
global half * d = (global half *) src_d + get_global_id(0);
b->d = *d;
for (int i = 0; i < QK1_0/8; ++i) {
b->qs[i] = q[i];
}
}
//------------------------------------------------------------------------------
// kernel_convert_block_q4_0
// Convert the block_q4_0 format to 2 separate arrays (AOS -> SOA).
@@ -0,0 +1,94 @@
#pragma OPENCL EXTENSION cl_khr_fp16 : enable
#pragma OPENCL EXTENSION cl_qcom_reqd_sub_group_size : enable
#ifdef cl_qcom_reqd_sub_group_size
#pragma OPENCL EXTENSION cl_qcom_reqd_sub_group_size : enable
#define ADRENO_GPU 1
#define REQD_SUBGROUP_SIZE_128 __attribute__((qcom_reqd_sub_group_size("full")))
#endif
// each work-item computes a 4 (rows of A / m) x 8 (cols of B / n) output tile.
#ifdef ADRENO_GPU
REQD_SUBGROUP_SIZE_128
#endif
kernel void kernel_gemm_noshuffle_q1_0_f32(
global const uint * src0_q,
global const half * src0_d,
read_only image1d_buffer_t src1,
global float * dst,
int k,
int m,
int n,
int n_no_padding,
ulong offsetd
) {
int n_4 = n >> 2;
int gy = get_global_id(0);
int gx = get_global_id(1);
int gx_2 = gx << 2;
dst = (global float *)((global char*)dst + offsetd);
half8 c0 = 0, c1 = 0, c2 = 0, c3 = 0;
half8 B;
global const uint* wptr = src0_q + gx_2;
global const half* sptr = src0_d + gx_2;
// 32 weights per uint32, 128 weights (one block / one scale) per 4 uint32.
for (int i = 0; i < k; i += 32) {
uint4 pack4 = vload4(0, wptr + (i / 32) * m); // 4 rows, 32 K-values each
half4 scale = vload4(0, sptr + (i / 128) * m); // 4 rows, one scale per 128
for (int j = 0; j < 32; ++j) {
B.s0123 = read_imageh(src1, gy * 2 + (i + j) * n_4);
B.s4567 = read_imageh(src1, gy * 2 + (i + j) * n_4 + 1);
// sign bit -> +-1 (half arithmetic avoids unsigned underflow)
half4 wj = (half4)(
2.0h * (half)((pack4.s0 >> j) & 1u) - 1.0h,
2.0h * (half)((pack4.s1 >> j) & 1u) - 1.0h,
2.0h * (half)((pack4.s2 >> j) & 1u) - 1.0h,
2.0h * (half)((pack4.s3 >> j) & 1u) - 1.0h) * scale;
c0 += B * wj.s0;
c1 += B * wj.s1;
c2 += B * wj.s2;
c3 += B * wj.s3;
}
}
int idx = (gy << 3) * m + (gx << 2);
if(idx+3 < m*n_no_padding){
vstore4((float4)(c0.s0, c1.s0, c2.s0, c3.s0), 0, dst + idx);
idx += m;
}
if(idx+3 < m*n_no_padding){
vstore4((float4)(c0.s1, c1.s1, c2.s1, c3.s1), 0, dst + idx);
idx += m;
}
if(idx+3 < m*n_no_padding){
vstore4((float4)(c0.s2, c1.s2, c2.s2, c3.s2), 0, dst + idx);
idx += m;
}
if(idx+3 < m*n_no_padding){
vstore4((float4)(c0.s3, c1.s3, c2.s3, c3.s3), 0, dst + idx);
idx += m;
}
if(idx+3 < m*n_no_padding){
vstore4((float4)(c0.s4, c1.s4, c2.s4, c3.s4), 0, dst + idx);
idx += m;
}
if(idx+3 < m*n_no_padding){
vstore4((float4)(c0.s5, c1.s5, c2.s5, c3.s5), 0, dst + idx);
idx += m;
}
if(idx+3 < m*n_no_padding){
vstore4((float4)(c0.s6, c1.s6, c2.s6, c3.s6), 0, dst + idx);
idx += m;
}
if(idx+3 < m*n_no_padding){
vstore4((float4)(c0.s7, c1.s7, c2.s7, c3.s7), 0, dst + idx);
}
}
@@ -0,0 +1,121 @@
#pragma OPENCL EXTENSION cl_khr_fp16 : enable
#pragma OPENCL EXTENSION cl_khr_subgroups : enable
#ifdef cl_qcom_reqd_sub_group_size
#pragma OPENCL EXTENSION cl_qcom_reqd_sub_group_size : enable
#define ADRENO_GPU 1
#define REQD_SUBGROUP_SIZE_64 __attribute__((qcom_reqd_sub_group_size("half")))
#endif
#define QK1_0 128
#define N_SIMDGROUP 4
#define dequantizeBlockAccum_q1(total, bits, scale, regB, lb) \
total += (2.0f*(float)((bits >> 0) & 1u) - 1.0f) * scale * sub_group_broadcast(regB.s0, lb+0); \
total += (2.0f*(float)((bits >> 1) & 1u) - 1.0f) * scale * sub_group_broadcast(regB.s1, lb+0); \
total += (2.0f*(float)((bits >> 2) & 1u) - 1.0f) * scale * sub_group_broadcast(regB.s2, lb+0); \
total += (2.0f*(float)((bits >> 3) & 1u) - 1.0f) * scale * sub_group_broadcast(regB.s3, lb+0); \
total += (2.0f*(float)((bits >> 4) & 1u) - 1.0f) * scale * sub_group_broadcast(regB.s4, lb+0); \
total += (2.0f*(float)((bits >> 5) & 1u) - 1.0f) * scale * sub_group_broadcast(regB.s5, lb+0); \
total += (2.0f*(float)((bits >> 6) & 1u) - 1.0f) * scale * sub_group_broadcast(regB.s6, lb+0); \
total += (2.0f*(float)((bits >> 7) & 1u) - 1.0f) * scale * sub_group_broadcast(regB.s7, lb+0); \
total += (2.0f*(float)((bits >> 8) & 1u) - 1.0f) * scale * sub_group_broadcast(regB.s0, lb+1); \
total += (2.0f*(float)((bits >> 9) & 1u) - 1.0f) * scale * sub_group_broadcast(regB.s1, lb+1); \
total += (2.0f*(float)((bits >> 10) & 1u) - 1.0f) * scale * sub_group_broadcast(regB.s2, lb+1); \
total += (2.0f*(float)((bits >> 11) & 1u) - 1.0f) * scale * sub_group_broadcast(regB.s3, lb+1); \
total += (2.0f*(float)((bits >> 12) & 1u) - 1.0f) * scale * sub_group_broadcast(regB.s4, lb+1); \
total += (2.0f*(float)((bits >> 13) & 1u) - 1.0f) * scale * sub_group_broadcast(regB.s5, lb+1); \
total += (2.0f*(float)((bits >> 14) & 1u) - 1.0f) * scale * sub_group_broadcast(regB.s6, lb+1); \
total += (2.0f*(float)((bits >> 15) & 1u) - 1.0f) * scale * sub_group_broadcast(regB.s7, lb+1); \
total += (2.0f*(float)((bits >> 16) & 1u) - 1.0f) * scale * sub_group_broadcast(regB.s0, lb+2); \
total += (2.0f*(float)((bits >> 17) & 1u) - 1.0f) * scale * sub_group_broadcast(regB.s1, lb+2); \
total += (2.0f*(float)((bits >> 18) & 1u) - 1.0f) * scale * sub_group_broadcast(regB.s2, lb+2); \
total += (2.0f*(float)((bits >> 19) & 1u) - 1.0f) * scale * sub_group_broadcast(regB.s3, lb+2); \
total += (2.0f*(float)((bits >> 20) & 1u) - 1.0f) * scale * sub_group_broadcast(regB.s4, lb+2); \
total += (2.0f*(float)((bits >> 21) & 1u) - 1.0f) * scale * sub_group_broadcast(regB.s5, lb+2); \
total += (2.0f*(float)((bits >> 22) & 1u) - 1.0f) * scale * sub_group_broadcast(regB.s6, lb+2); \
total += (2.0f*(float)((bits >> 23) & 1u) - 1.0f) * scale * sub_group_broadcast(regB.s7, lb+2); \
total += (2.0f*(float)((bits >> 24) & 1u) - 1.0f) * scale * sub_group_broadcast(regB.s0, lb+3); \
total += (2.0f*(float)((bits >> 25) & 1u) - 1.0f) * scale * sub_group_broadcast(regB.s1, lb+3); \
total += (2.0f*(float)((bits >> 26) & 1u) - 1.0f) * scale * sub_group_broadcast(regB.s2, lb+3); \
total += (2.0f*(float)((bits >> 27) & 1u) - 1.0f) * scale * sub_group_broadcast(regB.s3, lb+3); \
total += (2.0f*(float)((bits >> 28) & 1u) - 1.0f) * scale * sub_group_broadcast(regB.s4, lb+3); \
total += (2.0f*(float)((bits >> 29) & 1u) - 1.0f) * scale * sub_group_broadcast(regB.s5, lb+3); \
total += (2.0f*(float)((bits >> 30) & 1u) - 1.0f) * scale * sub_group_broadcast(regB.s6, lb+3); \
total += (2.0f*(float)((bits >> 31) & 1u) - 1.0f) * scale * sub_group_broadcast(regB.s7, lb+3);
#ifdef ADRENO_GPU
REQD_SUBGROUP_SIZE_64
#endif
__kernel void kernel_gemv_noshuffle_q1_0_f32(
read_only image1d_buffer_t src0_q,
global half * src0_d,
read_only image1d_buffer_t src1,
ulong offset1,
global float * dst,
ulong offsetd,
int ne00,
int ne01,
int ne02,
int ne10,
int ne12,
int ne0,
int ne1,
int r2,
int r3)
{
uint groupId = get_local_id(1);
uint gid = get_global_id(0);
ushort slid = get_sub_group_local_id();
uint K = ne00;
uint M = ne01;
uint LINE_STRIDE_A = M;
uint BLOCK_STRIDE_A = 4 * M;
uint4 regA;
half regS;
float8 regB;
float totalSum = 0.0f;
#pragma unroll 1
for (uint kb = groupId; kb < (K / QK1_0); kb += N_SIMDGROUP) {
regS = src0_d[gid + kb * LINE_STRIDE_A]; // each fiber loads its row's scale
// first 16 fibers load 8 B values each -> 128 activations for this block
if (slid < 16) {
regB.s0123 = read_imagef(src1, (slid * 2 + kb * 32));
regB.s4567 = read_imagef(src1, (1 + slid * 2 + kb * 32));
}
// load this row's 4 uint32 (128 sign bits)
regA.s0 = read_imageui(src0_q, (gid + kb * BLOCK_STRIDE_A + LINE_STRIDE_A * 0)).x;
regA.s1 = read_imageui(src0_q, (gid + kb * BLOCK_STRIDE_A + LINE_STRIDE_A * 1)).x;
regA.s2 = read_imageui(src0_q, (gid + kb * BLOCK_STRIDE_A + LINE_STRIDE_A * 2)).x;
regA.s3 = read_imageui(src0_q, (gid + kb * BLOCK_STRIDE_A + LINE_STRIDE_A * 3)).x;
float scale = (float)regS;
dequantizeBlockAccum_q1(totalSum, regA.s0, scale, regB, 0);
dequantizeBlockAccum_q1(totalSum, regA.s1, scale, regB, 4);
dequantizeBlockAccum_q1(totalSum, regA.s2, scale, regB, 8);
dequantizeBlockAccum_q1(totalSum, regA.s3, scale, regB, 12);
}
// reduction in local memory, assumes #wave = N_SIMDGROUP = 4
local float reduceLM[SIMDGROUP_WIDTH * 3];
if (groupId == 1) reduceLM[SIMDGROUP_WIDTH * 0 + slid] = totalSum;
if (groupId == 2) reduceLM[SIMDGROUP_WIDTH * 1 + slid] = totalSum;
if (groupId == 3) reduceLM[SIMDGROUP_WIDTH * 2 + slid] = totalSum;
barrier(CLK_LOCAL_MEM_FENCE);
if (groupId == 0) totalSum += reduceLM[SIMDGROUP_WIDTH * 0 + slid];
if (groupId == 0) totalSum += reduceLM[SIMDGROUP_WIDTH * 1 + slid];
if (groupId == 0) totalSum += reduceLM[SIMDGROUP_WIDTH * 2 + slid];
if (groupId == 0) {
dst = (global float*)((global char*)dst + offsetd);
dst[gid] = totalSum;
}
}
@@ -0,0 +1,156 @@
#pragma OPENCL EXTENSION cl_khr_fp16 : enable
// LOAD_VEC_A is 8 because one q1_0 quant byte expands to 8 weights along K.
#define LOAD_VEC_A 8
#define LOAD_VEC_B 4
#define BM 64
#define BN 64
#define BK 32
#define TM 4
#define TN 8
kernel void kernel_mul_mm_q1_0_f32_l4_lm(
global uchar * src0_q,
global half * src0_d,
global float4 * src1,
ulong offset1,
global float * dst,
ulong offsetd,
int ne00,
int ne01,
int ne02,
int ne11,
int ne12,
int stride_a,
int stride_b,
int stride_d,
int batch_stride_a,
int batch_stride_b,
int batch_stride_d,
int r2,
int r3
) {
src1 = (global float4*)((global char*)src1 + offset1);
dst = (global float *)((global char*)dst + offsetd);
local float buf_a[BM * BK];
local float buf_b[BN * BK];
const int batch_idx = get_global_id(2);
const int i13 = batch_idx / ne12;
const int i12 = batch_idx % ne12;
const int i03 = i13 / r3;
const int i02 = i12 / r2;
const int batch_idx_a = i03 * ne02 + i02;
const int ir = get_group_id(0);
const int ic = get_group_id(1);
const int tid = get_local_id(0);
const int th_r = tid % (BM / TM);
const int th_c = tid / (BM / TM);
const int loadr_a = get_local_id(0) % (BK / LOAD_VEC_A);
const int loadc_a = get_local_id(0) / (BK / LOAD_VEC_A);
const int loadr_b = get_local_id(0) % (BK / LOAD_VEC_B);
const int loadc_b = get_local_id(0) / (BK / LOAD_VEC_B);
const int loadstride_a = get_local_size(0) * LOAD_VEC_A / BK;
const int loadstride_b = get_local_size(0) * LOAD_VEC_B / BK;
int pos_a = (batch_idx_a * batch_stride_a + ir * BM * stride_a) / LOAD_VEC_A;
int pos_b = (batch_idx * batch_stride_b + ic * BN * stride_b) / LOAD_VEC_B;
float sums[TM * TN];
float cache_a[TM];
float cache_b[TN];
for (int i = 0; i < TM * TN; i++) {
sums[i] = 0.0f;
}
for (int block = 0; block < ne00; block += BK) {
for (int l = 0; l < BM; l += loadstride_a) {
if (ir*BM + loadc_a + l < ne01) {
int idx = pos_a + (loadc_a + l) * stride_a / LOAD_VEC_A + loadr_a;
int ib = idx / 16; // 16 quant bytes per q1_0 block
float d = (float)src0_d[ib];
uint bits = src0_q[idx];
// use float to avoid unsigned underflow of (2*0 - 1).
buf_a[(loadr_a * LOAD_VEC_A + 0) * BM + loadc_a + l] = d * (2.0f*(float)((bits >> 0) & 1) - 1.0f);
buf_a[(loadr_a * LOAD_VEC_A + 1) * BM + loadc_a + l] = d * (2.0f*(float)((bits >> 1) & 1) - 1.0f);
buf_a[(loadr_a * LOAD_VEC_A + 2) * BM + loadc_a + l] = d * (2.0f*(float)((bits >> 2) & 1) - 1.0f);
buf_a[(loadr_a * LOAD_VEC_A + 3) * BM + loadc_a + l] = d * (2.0f*(float)((bits >> 3) & 1) - 1.0f);
buf_a[(loadr_a * LOAD_VEC_A + 4) * BM + loadc_a + l] = d * (2.0f*(float)((bits >> 4) & 1) - 1.0f);
buf_a[(loadr_a * LOAD_VEC_A + 5) * BM + loadc_a + l] = d * (2.0f*(float)((bits >> 5) & 1) - 1.0f);
buf_a[(loadr_a * LOAD_VEC_A + 6) * BM + loadc_a + l] = d * (2.0f*(float)((bits >> 6) & 1) - 1.0f);
buf_a[(loadr_a * LOAD_VEC_A + 7) * BM + loadc_a + l] = d * (2.0f*(float)((bits >> 7) & 1) - 1.0f);
} else {
for (int b = 0; b < LOAD_VEC_A; ++b) {
buf_a[(loadr_a * LOAD_VEC_A + b) * BM + loadc_a + l] = 0.0f;
}
}
}
for (int l = 0; l < BN; l += loadstride_b) {
if (ic*BN + loadc_b + l < ne11) {
int idx = pos_b + (loadc_b + l) * stride_b / LOAD_VEC_B + loadr_b;
buf_b[(loadr_b * LOAD_VEC_B + 0) * BN + loadc_b + l] = src1[idx].s0;
buf_b[(loadr_b * LOAD_VEC_B + 1) * BN + loadc_b + l] = src1[idx].s1;
buf_b[(loadr_b * LOAD_VEC_B + 2) * BN + loadc_b + l] = src1[idx].s2;
buf_b[(loadr_b * LOAD_VEC_B + 3) * BN + loadc_b + l] = src1[idx].s3;
} else {
buf_b[(loadr_b * LOAD_VEC_B + 0) * BN + loadc_b + l] = 0.0f;
buf_b[(loadr_b * LOAD_VEC_B + 1) * BN + loadc_b + l] = 0.0f;
buf_b[(loadr_b * LOAD_VEC_B + 2) * BN + loadc_b + l] = 0.0f;
buf_b[(loadr_b * LOAD_VEC_B + 3) * BN + loadc_b + l] = 0.0f;
}
}
barrier(CLK_LOCAL_MEM_FENCE);
pos_a += BK / LOAD_VEC_A;
pos_b += BK / LOAD_VEC_B;
for (int i = 0; i < BK; i++) {
for (int j = 0; j < TM; j++) {
cache_a[j] = buf_a[(i) * BM + th_r * TM + j];
}
for (int j = 0; j < TN; j++) {
cache_b[j] = buf_b[(i) * BN + th_c * TN + j];
}
for (int cc = 0; cc < TN; cc++) {
for (int cr = 0; cr < TM; cr++) {
const int sums_idx = cc*TM + cr;
sums[sums_idx] = mad(cache_a[cr], cache_b[cc], sums[sums_idx]);
}
}
}
barrier(CLK_LOCAL_MEM_FENCE);
}
const int dr = ir * BM + th_r * TM;
const int dc = ic * BN + th_c * TN;
const int offsets = batch_idx * batch_stride_d;
for (int cc = 0; cc < TN; cc++) {
for (int cr = 0; cr < TM; cr++) {
if (dr + cr < ne01 && dc + cc < ne11) {
dst[offsets + (dc + cc) * stride_d + dr + cr] = sums[cc * TM + cr];
}
}
}
}
@@ -0,0 +1,141 @@
#pragma OPENCL EXTENSION cl_khr_fp16 : enable
#ifdef cl_intel_subgroups
#pragma OPENCL EXTENSION cl_intel_subgroups : enable
#else
#pragma OPENCL EXTENSION cl_khr_subgroups : enable
#endif
#ifdef cl_intel_required_subgroup_size
#pragma OPENCL EXTENSION cl_intel_required_subgroup_size : enable
#define INTEL_GPU 1
#define REQD_SUBGROUP_SIZE_16 __attribute__((intel_reqd_sub_group_size(16)))
#define REQD_SUBGROUP_SIZE_32 __attribute__((intel_reqd_sub_group_size(32)))
#elif defined(cl_qcom_reqd_sub_group_size)
#pragma OPENCL EXTENSION cl_qcom_reqd_sub_group_size : enable
#define ADRENO_GPU 1
#define REQD_SUBGROUP_SIZE_64 __attribute__((qcom_reqd_sub_group_size("half")))
#define REQD_SUBGROUP_SIZE_128 __attribute__((qcom_reqd_sub_group_size("full")))
#endif
#define QK1_0 128
typedef struct {
half d;
uchar qs[QK1_0/8];
} block_q1_0;
#define NB_Q1_0 16
#ifdef INTEL_GPU
#define N_R0_Q1_0 4 // number of rows each subgroup works on
#define N_SG_Q1_0 2 // number of subgroups in a work group
#define N_SIMDWIDTH 16 // subgroup size
#elif defined (ADRENO_GPU)
#define N_R0_Q1_0 4
#define N_SG_Q1_0 2
#define N_SIMDWIDTH 64
#endif
inline float block_q_1_0_dot_y(global block_q1_0 * qb, float sumy, float yl[NB_Q1_0], short il) {
global uchar * qs = qb->qs + il*2;
uint b0 = qs[0];
uint b1 = qs[1];
float acc = 0.f;
acc += yl[ 0]*(float)((b0 >> 0) & 1) + yl[ 1]*(float)((b0 >> 1) & 1);
acc += yl[ 2]*(float)((b0 >> 2) & 1) + yl[ 3]*(float)((b0 >> 3) & 1);
acc += yl[ 4]*(float)((b0 >> 4) & 1) + yl[ 5]*(float)((b0 >> 5) & 1);
acc += yl[ 6]*(float)((b0 >> 6) & 1) + yl[ 7]*(float)((b0 >> 7) & 1);
acc += yl[ 8]*(float)((b1 >> 0) & 1) + yl[ 9]*(float)((b1 >> 1) & 1);
acc += yl[10]*(float)((b1 >> 2) & 1) + yl[11]*(float)((b1 >> 3) & 1);
acc += yl[12]*(float)((b1 >> 4) & 1) + yl[13]*(float)((b1 >> 5) & 1);
acc += yl[14]*(float)((b1 >> 6) & 1) + yl[15]*(float)((b1 >> 7) & 1);
return qb->d * (2.0f*acc - sumy);
}
#ifdef INTEL_GPU
REQD_SUBGROUP_SIZE_16
#elif defined (ADRENO_GPU)
REQD_SUBGROUP_SIZE_64
#endif
kernel void kernel_mul_mv_q1_0_f32(
global char * src0,
ulong offset0,
global char * src1,
ulong offset1,
global char * dst,
ulong offsetd,
int ne00,
int ne01,
ulong nb01,
ulong nb02,
ulong nb03,
int ne12,
ulong nb11,
ulong nb12,
ulong nb13,
int ne0,
int ne1,
int r2,
int r3
) {
src0 = (global char*)((global char*)src0 + offset0);
src1 = (global char*)((global char*)src1 + offset1);
dst = (global char*)((global char*)dst + offsetd);
int nb = ne00/QK1_0;
int r0 = get_group_id(0);
int r1 = get_group_id(1);
int im = get_group_id(2);
int first_row = (r0*N_SG_Q1_0 + get_sub_group_id()) * N_R0_Q1_0;
uint i12 = im%ne12;
uint i13 = im/ne12;
ulong offset_src1 = r1*nb11 + i12*nb12 + i13*nb13;
global float * y = (global float *) (src1 + offset_src1);
// pointers to src0 rows
global block_q1_0 * ax[N_R0_Q1_0];
for (int row = 0; row < N_R0_Q1_0; ++row) {
ulong offset_src0 = (first_row + row)*nb01 + (i12/r2)*nb02 + (i13/r3)*nb03;
ax[row] = (global block_q1_0 *) ((global char *) src0 + offset_src0);
}
float yl[NB_Q1_0];
float sumf[N_R0_Q1_0] = { 0.f };
const short ix = get_sub_group_local_id()/8;
const short il = get_sub_group_local_id()%8;
global float * yb = y + ix*QK1_0 + il*NB_Q1_0;
// each thread handles NB_Q1_0 quants at a time
for (int ib = ix; ib < nb; ib += N_SIMDWIDTH/8) {
float sumy = 0.f;
for (short i = 0; i < NB_Q1_0; ++i) {
yl[i] = yb[i];
sumy += yb[i];
}
for (short row = 0; row < N_R0_Q1_0; row++) {
sumf[row] += block_q_1_0_dot_y(ax[row] + ib, sumy, yl, il);
}
yb += N_SIMDWIDTH*NB_Q1_0;
}
global float * dst_f32 = (global float *) dst + (ulong)im*ne0*ne1 + (ulong)r1*ne0;
for (int row = 0; row < N_R0_Q1_0; ++row) {
float tot = sub_group_reduce_add(sumf[row]);
if (get_sub_group_local_id() == 0 && first_row + row < ne01) {
dst_f32[first_row + row] = tot;
}
}
}
@@ -0,0 +1,190 @@
#pragma OPENCL EXTENSION cl_khr_fp16 : enable
#ifdef cl_intel_subgroups
#pragma OPENCL EXTENSION cl_intel_subgroups : enable
#else
#pragma OPENCL EXTENSION cl_khr_subgroups : enable
#endif
#ifdef cl_intel_required_subgroup_size
#pragma OPENCL EXTENSION cl_intel_required_subgroup_size : enable
#define INTEL_GPU 1
#define REQD_SUBGROUP_SIZE_16 __attribute__((intel_reqd_sub_group_size(16)))
#define REQD_SUBGROUP_SIZE_32 __attribute__((intel_reqd_sub_group_size(32)))
#elif defined(cl_qcom_reqd_sub_group_size)
#pragma OPENCL EXTENSION cl_qcom_reqd_sub_group_size : enable
#define ADRENO_GPU 1
#define REQD_SUBGROUP_SIZE_64 __attribute__((qcom_reqd_sub_group_size("half")))
#define REQD_SUBGROUP_SIZE_128 __attribute__((qcom_reqd_sub_group_size("full")))
#endif
#define QK1_0 128
#define QK1_0_BYTES (QK1_0/8) // 16 quant bytes per block
#define QK1_0_BLK_BYTES (QK1_0_BYTES + 2) // d + qs in original tensor = 18
#define NB_Q1_0 16 // quants handled per thread (two qs bytes)
#ifdef INTEL_GPU
#define N_R0_Q1_0 4 // number of rows each subgroup works on
#define N_SG_Q1_0 2 // number of subgroups in a work group
#define N_SIMDWIDTH 16 // subgroup size
#elif defined (ADRENO_GPU)
#define N_R0_Q1_0 4
#define N_SG_Q1_0 2
#define N_SIMDWIDTH 64
#endif
#ifdef INTEL_GPU
REQD_SUBGROUP_SIZE_16
#elif defined (ADRENO_GPU)
REQD_SUBGROUP_SIZE_64
#endif
kernel void kernel_mul_mv_q1_0_f32_flat(
global char * src0_q,
global half * src0_d,
global char * src1,
ulong offset1,
global char * dst,
ulong offsetd,
int ne00,
int ne01,
ulong nb01,
ulong nb02,
ulong nb03,
int ne12,
ulong nb11,
ulong nb12,
ulong nb13,
int ne0,
int ne1,
int r2,
int r3
) {
src1 = (global char*)((global char*)src1 + offset1);
dst = (global char*)((global char*)dst + offsetd);
int nb = ne00/QK1_0;
int r0 = get_group_id(0);
int r1 = get_group_id(1);
int im = get_group_id(2);
int first_row = (r0*N_SG_Q1_0 + get_sub_group_id()) * N_R0_Q1_0;
uint i12 = im%ne12;
uint i13 = im/ne12;
ulong offset_src1 = r1*nb11 + i12*nb12 + i13*nb13;
global float * y = (global float *) (src1 + offset_src1);
// pointers to src0 rows (flat: q bytes + scales)
uint offset_src0_base = first_row*nb01 + (i12/r2)*nb02 + (i13/r3)*nb03;
global uchar * ax0, * ax1, * ax2, * ax3;
global half * ad0, * ad1, * ad2, * ad3;
uint offset_src0;
offset_src0 = (offset_src0_base + 0*nb01) / QK1_0_BLK_BYTES;
ax0 = (global uchar *) ((global char *) src0_q + offset_src0*QK1_0_BYTES);
ad0 = (global half *) ((global char *) src0_d + offset_src0*sizeof(half));
offset_src0 = (offset_src0_base + 1*nb01) / QK1_0_BLK_BYTES;
ax1 = (global uchar *) ((global char *) src0_q + offset_src0*QK1_0_BYTES);
ad1 = (global half *) ((global char *) src0_d + offset_src0*sizeof(half));
offset_src0 = (offset_src0_base + 2*nb01) / QK1_0_BLK_BYTES;
ax2 = (global uchar *) ((global char *) src0_q + offset_src0*QK1_0_BYTES);
ad2 = (global half *) ((global char *) src0_d + offset_src0*sizeof(half));
offset_src0 = (offset_src0_base + 3*nb01) / QK1_0_BLK_BYTES;
ax3 = (global uchar *) ((global char *) src0_q + offset_src0*QK1_0_BYTES);
ad3 = (global half *) ((global char *) src0_d + offset_src0*sizeof(half));
const short ix = get_sub_group_local_id()/8;
const short il = get_sub_group_local_id()%8;
global float * yb = y + ix*QK1_0 + il*NB_Q1_0;
float8 yl_lo;
float8 yl_hi;
float4 sumf = 0.f;
// each thread handles NB_Q1_0 = 16 quants (two qs bytes) at a time
for (int ib = ix; ib < nb; ib += N_SIMDWIDTH/8) {
yl_lo = vload8(0, yb);
yl_hi = vload8(0, yb + 8);
float sumy = yl_lo.s0 + yl_lo.s1 + yl_lo.s2 + yl_lo.s3
+ yl_lo.s4 + yl_lo.s5 + yl_lo.s6 + yl_lo.s7
+ yl_hi.s0 + yl_hi.s1 + yl_hi.s2 + yl_hi.s3
+ yl_hi.s4 + yl_hi.s5 + yl_hi.s6 + yl_hi.s7;
uint b0, b1;
float acc;
b0 = ax0[ib*QK1_0_BYTES + il*2 + 0];
b1 = ax0[ib*QK1_0_BYTES + il*2 + 1];
acc = yl_lo.s0*(float)((b0 >> 0) & 1) + yl_lo.s1*(float)((b0 >> 1) & 1)
+ yl_lo.s2*(float)((b0 >> 2) & 1) + yl_lo.s3*(float)((b0 >> 3) & 1)
+ yl_lo.s4*(float)((b0 >> 4) & 1) + yl_lo.s5*(float)((b0 >> 5) & 1)
+ yl_lo.s6*(float)((b0 >> 6) & 1) + yl_lo.s7*(float)((b0 >> 7) & 1)
+ yl_hi.s0*(float)((b1 >> 0) & 1) + yl_hi.s1*(float)((b1 >> 1) & 1)
+ yl_hi.s2*(float)((b1 >> 2) & 1) + yl_hi.s3*(float)((b1 >> 3) & 1)
+ yl_hi.s4*(float)((b1 >> 4) & 1) + yl_hi.s5*(float)((b1 >> 5) & 1)
+ yl_hi.s6*(float)((b1 >> 6) & 1) + yl_hi.s7*(float)((b1 >> 7) & 1);
sumf.s0 += (float)ad0[ib] * (2.0f*acc - sumy);
b0 = ax1[ib*QK1_0_BYTES + il*2 + 0];
b1 = ax1[ib*QK1_0_BYTES + il*2 + 1];
acc = yl_lo.s0*(float)((b0 >> 0) & 1) + yl_lo.s1*(float)((b0 >> 1) & 1)
+ yl_lo.s2*(float)((b0 >> 2) & 1) + yl_lo.s3*(float)((b0 >> 3) & 1)
+ yl_lo.s4*(float)((b0 >> 4) & 1) + yl_lo.s5*(float)((b0 >> 5) & 1)
+ yl_lo.s6*(float)((b0 >> 6) & 1) + yl_lo.s7*(float)((b0 >> 7) & 1)
+ yl_hi.s0*(float)((b1 >> 0) & 1) + yl_hi.s1*(float)((b1 >> 1) & 1)
+ yl_hi.s2*(float)((b1 >> 2) & 1) + yl_hi.s3*(float)((b1 >> 3) & 1)
+ yl_hi.s4*(float)((b1 >> 4) & 1) + yl_hi.s5*(float)((b1 >> 5) & 1)
+ yl_hi.s6*(float)((b1 >> 6) & 1) + yl_hi.s7*(float)((b1 >> 7) & 1);
sumf.s1 += (float)ad1[ib] * (2.0f*acc - sumy);
b0 = ax2[ib*QK1_0_BYTES + il*2 + 0];
b1 = ax2[ib*QK1_0_BYTES + il*2 + 1];
acc = yl_lo.s0*(float)((b0 >> 0) & 1) + yl_lo.s1*(float)((b0 >> 1) & 1)
+ yl_lo.s2*(float)((b0 >> 2) & 1) + yl_lo.s3*(float)((b0 >> 3) & 1)
+ yl_lo.s4*(float)((b0 >> 4) & 1) + yl_lo.s5*(float)((b0 >> 5) & 1)
+ yl_lo.s6*(float)((b0 >> 6) & 1) + yl_lo.s7*(float)((b0 >> 7) & 1)
+ yl_hi.s0*(float)((b1 >> 0) & 1) + yl_hi.s1*(float)((b1 >> 1) & 1)
+ yl_hi.s2*(float)((b1 >> 2) & 1) + yl_hi.s3*(float)((b1 >> 3) & 1)
+ yl_hi.s4*(float)((b1 >> 4) & 1) + yl_hi.s5*(float)((b1 >> 5) & 1)
+ yl_hi.s6*(float)((b1 >> 6) & 1) + yl_hi.s7*(float)((b1 >> 7) & 1);
sumf.s2 += (float)ad2[ib] * (2.0f*acc - sumy);
b0 = ax3[ib*QK1_0_BYTES + il*2 + 0];
b1 = ax3[ib*QK1_0_BYTES + il*2 + 1];
acc = yl_lo.s0*(float)((b0 >> 0) & 1) + yl_lo.s1*(float)((b0 >> 1) & 1)
+ yl_lo.s2*(float)((b0 >> 2) & 1) + yl_lo.s3*(float)((b0 >> 3) & 1)
+ yl_lo.s4*(float)((b0 >> 4) & 1) + yl_lo.s5*(float)((b0 >> 5) & 1)
+ yl_lo.s6*(float)((b0 >> 6) & 1) + yl_lo.s7*(float)((b0 >> 7) & 1)
+ yl_hi.s0*(float)((b1 >> 0) & 1) + yl_hi.s1*(float)((b1 >> 1) & 1)
+ yl_hi.s2*(float)((b1 >> 2) & 1) + yl_hi.s3*(float)((b1 >> 3) & 1)
+ yl_hi.s4*(float)((b1 >> 4) & 1) + yl_hi.s5*(float)((b1 >> 5) & 1)
+ yl_hi.s6*(float)((b1 >> 6) & 1) + yl_hi.s7*(float)((b1 >> 7) & 1);
sumf.s3 += (float)ad3[ib] * (2.0f*acc - sumy);
yb += N_SIMDWIDTH*NB_Q1_0;
}
global float * dst_f32 = (global float *) dst + (ulong)im*ne0*ne1 + (ulong)r1*ne0;
float4 tot = (float4)(
sub_group_reduce_add(sumf.s0),
sub_group_reduce_add(sumf.s1),
sub_group_reduce_add(sumf.s2),
sub_group_reduce_add(sumf.s3)
);
if (get_sub_group_local_id() == 0) {
if (first_row + 0 < ne01) dst_f32[first_row + 0] = tot.s0;
if (first_row + 1 < ne01) dst_f32[first_row + 1] = tot.s1;
if (first_row + 2 < ne01) dst_f32[first_row + 2] = tot.s2;
if (first_row + 3 < ne01) dst_f32[first_row + 3] = tot.s3;
}
}
+79
View File
@@ -0,0 +1,79 @@
#pragma once
#ifdef _WIN32
# define WIN32_LEAN_AND_MEAN
# ifndef NOMINMAX
# define NOMINMAX
# endif
# include <windows.h>
# include <winevt.h>
#else
# include <dlfcn.h>
# include <unistd.h>
#endif
#include <filesystem>
namespace fs = std::filesystem;
#ifdef _WIN32
using dl_handle = std::remove_pointer_t<HMODULE>;
struct dl_handle_deleter {
void operator()(HMODULE handle) {
FreeLibrary(handle);
}
};
static inline dl_handle * dl_load_library(const fs::path & path) {
// suppress error dialogs for missing DLLs
DWORD old_mode = SetErrorMode(SEM_FAILCRITICALERRORS);
SetErrorMode(old_mode | SEM_FAILCRITICALERRORS);
HMODULE handle = LoadLibraryW(path.wstring().c_str());
SetErrorMode(old_mode);
return handle;
}
static inline void * dl_get_sym(dl_handle * handle, const char * name) {
DWORD old_mode = SetErrorMode(SEM_FAILCRITICALERRORS);
SetErrorMode(old_mode | SEM_FAILCRITICALERRORS);
void * p = (void *) GetProcAddress(handle, name);
SetErrorMode(old_mode);
return p;
}
static inline const char * dl_error() {
return "";
}
#else
using dl_handle = void;
struct dl_handle_deleter {
void operator()(void * handle) {
dlclose(handle);
}
};
static inline dl_handle * dl_load_library(const fs::path & path) {
dl_handle * handle = dlopen(path.string().c_str(), RTLD_NOW | RTLD_LOCAL);
return handle;
}
static inline void * dl_get_sym(dl_handle * handle, const char * name) {
return dlsym(handle, name);
}
static inline const char * dl_error() {
const char *rslt = dlerror();
return rslt != nullptr ? rslt : "";
}
#endif
+95
View File
@@ -2475,6 +2475,85 @@ static bool ggml_vk_strip_decode_vector(const uint32_t * code, size_t word_count
return true;
}
// Remove the loop unrolling hint of the matmul shader's BK loop
// and replace it with the dont_unroll hint for better performance on
// hardware like Apple M1/M2.
// Assumes 1. code comes from mul_mm.comp 2. the K-tile loop has no loop
// control hint and 3. the BK loop is the last loop nested directly inside
// the K-tile loop.
// Returns true when the input was modified; returns false otherwise
// without touching `out`.
static bool ggml_vk_roll_bk_loop(const uint32_t * code, size_t word_count, std::vector<uint32_t> & out) {
if (word_count < 5) {
return false;
}
struct vk_spv_loop {
size_t header;
size_t end;
uint32_t control;
};
std::vector<vk_spv_loop> loops;
// Collect a list of all loops in the module.
for (size_t pos = 5; pos < word_count; ) {
const uint32_t wc = code[pos] >> spv::WordCountShift;
const uint32_t op = code[pos] & spv::OpCodeMask;
if (wc == 0 || pos + wc > word_count) {
return false;
}
if (op == spv::OpLoopMerge && wc >= 4) { loops.push_back({ pos, 0, code[pos + 3] }); }
if (op == spv::OpLabel && wc >= 2) {
for (auto & l : loops) {
if (l.end == 0 && code[l.header + 1] == code[pos + 1]) { l.end = pos; }
}
}
pos += wc;
}
auto encloses = [](const vk_spv_loop & a, const vk_spv_loop & b) {
return a.header < b.header && b.header < a.end;
};
// Find the BK loop.
const vk_spv_loop * bk = nullptr;
for (const auto & h : loops) {
if (h.control != spv::LoopControlUnrollMask) {
continue;
}
const vk_spv_loop * parent = nullptr;
bool has_child = false;
for (const auto & g : loops) {
if (encloses(g, h) && (!parent || g.header > parent->header)) {
parent = &g;
}
if (encloses(h, g)) {
has_child = true;
}
}
// BK loop should be the last loop nested inside the loop with no hint
// and have at least one child loop.
if (parent &&
parent->control == spv::LoopControlMaskNone &&
has_child &&
(!bk || h.header > bk->header)) {
bk = &h;
}
}
if (!bk) {
return false;
}
// set DontUnroll instead of Unroll
out.assign(code, code + word_count);
out[bk->header + 3] = spv::LoopControlDontUnrollMask;
return true;
}
static void ggml_vk_create_pipeline_func(vk_device& device, vk_pipeline& pipeline, size_t spv_size, const void* spv_data, const std::string entrypoint,
uint32_t parameter_count, std::array<uint32_t, 3> wg_denoms, std::vector<uint32_t> specialization_constants,
bool disable_robustness, bool require_full_subgroups, uint32_t required_subgroup_size) {
@@ -2558,6 +2637,22 @@ static void ggml_vk_create_pipeline_func(vk_device& device, vk_pipeline& pipelin
}
#endif
#if VK_HEADER_VERSION >= 287
// Roll the mul_mm BK loop on Asahi Linux. Skip bf16 and the mul_mmq pipelines.
if (device->driver_id == vk::DriverId::eMesaHoneykrisp &&
pipeline->name.rfind("matmul", 0) == 0 &&
pipeline->name.find("bf16") == std::string::npos &&
pipeline->name.find("q8_1") == std::string::npos) {
const uint32_t * src = spirv.empty() ? reinterpret_cast<const uint32_t *>(spv_data) : spirv.data();
size_t src_n = spirv.empty() ? spv_size / sizeof(uint32_t) : spirv.size();
std::vector<uint32_t> rolled;
if (ggml_vk_roll_bk_loop(src, src_n, rolled)) {
spirv = std::move(rolled);
shader_module_create_info = vk::ShaderModuleCreateInfo({}, spirv.size() * sizeof(uint32_t), spirv.data());
}
}
#endif
pipeline->shader_module = device->device.createShaderModule(shader_module_create_info);
vk::PushConstantRange pcr(
+4 -1
View File
@@ -69,13 +69,16 @@ mbuf=
mmsel=
[ "$MM" != "" ] && mmsel="GGML_HEXAGON_MM_SELECT=$MM"
fasel=
[ "$FA" != "" ] && fasel="GGML_HEXAGON_FA_SELECT=$FA"
set -x
adb $adbserial $adbhost shell " \
cd $basedir; ulimit -c unlimited; \
LD_LIBRARY_PATH=$basedir/$branch/lib \
ADSP_LIBRARY_PATH=$basedir/$branch/lib \
$verbose $sched $opmask $profile $nhvx $hmx $ndev $hb $opbatch $opqueue $oppoll $opflt $opfuse $vmem $mbuf $mmsel \
$verbose $sched $opmask $profile $nhvx $hmx $ndev $hb $opbatch $opqueue $oppoll $opflt $opfuse $vmem $mbuf $mmsel $fasel \
./$branch/bin/llama-completion --no-mmap -m $basedir/../gguf/$model \
--poll 1000 -t 6 --cpu-mask 0xfc --cpu-strict 1 \
--ctx-size 8192 --ubatch-size 1024 -fa on \
+4 -1
View File
@@ -57,6 +57,9 @@ opfuse=
mmsel=
[ "$MM" != "" ] && mmsel="GGML_HEXAGON_MM_SELECT=$MM"
fasel=
[ "$FA" != "" ] && fasel="GGML_HEXAGON_FA_SELECT=$FA"
set -x
tool=$1; shift
@@ -65,5 +68,5 @@ adb $adbserial $adbhost shell " \
cd $basedir; ulimit -c unlimited; \
LD_LIBRARY_PATH=$basedir/$branch/lib \
ADSP_LIBRARY_PATH=$basedir/$branch/lib \
$verbose $sched $opmask $profile $nhvx $hmx $ndev $hb $opbatch $opqueue $oppoll $opfuse $mmsel ./$branch/bin/$tool $@ \
$verbose $sched $opmask $profile $nhvx $hmx $ndev $hb $opbatch $opqueue $oppoll $opfuse $mmsel $fasel ./$branch/bin/$tool $@ \
"
@@ -230,6 +230,12 @@ def print_ascii_timeline(op_name, dims, types, usec, cycles, events, evt_val=Non
char = 'Q'
elif norm_evt == 'A-PREP':
char = 'A'
elif norm_evt == 'Q-PREP':
char = 'q'
elif norm_evt == 'K-PREP':
char = 'k'
elif norm_evt == 'V-PREP':
char = 'v'
elif norm_evt == 'W-DEQUANT':
char = 'D'
elif norm_evt == 'O-PROC':
+2
View File
@@ -121,6 +121,8 @@ llama_model_qwen3next::graph::graph(const llama_model & model, const llm_graph_p
ggml_tensor * inp_out_ids = build_inp_out_ids();
for (int il = 0; il < n_layer; ++il) {
res->t_layer_inp[il] = inpL;
ggml_tensor * inpSA = inpL;
cur = build_norm(inpL, model.layers[il].attn_norm, nullptr, LLM_NORM_RMS, il);
+1
View File
@@ -7759,6 +7759,7 @@ static std::vector<std::unique_ptr<test_case>> make_test_cases_eval() {
}
test_cases.emplace_back(new test_get_rows_back(GGML_TYPE_F32, 1, 8, 2, 1, false));
test_cases.emplace_back(new test_get_rows_back(GGML_TYPE_F32, 1, 70000, 4, 1, false)); // row count > CUDA grid-y limit (65535)
for (ggml_type type : all_types) {
for (bool v : {false, true}) {
test_cases.emplace_back(new test_get_rows_back(type, 256, 5, 4, 1, v));
+1 -1
View File
@@ -39,7 +39,7 @@ static server_http_res_ptr proxy_request(const server_http_req & req, std::strin
throw std::runtime_error("unsupported URL scheme in target URL: " + parsed_url.scheme);
}
SRV_INF("proxying %s request to %s://%s:%i%s\n", method.c_str(), parsed_url.scheme.c_str(), parsed_url.host.c_str(), parsed_url.port, parsed_url.path.c_str());
SRV_INF("proxying %s request to %s://%s:%i%s\n", method.c_str(), parsed_url.scheme.c_str(), common_http_format_host(parsed_url.host).c_str(), parsed_url.port, parsed_url.path.c_str());
std::map<std::string, std::string> headers;
const std::string proxy_header_prefix = "x-llama-server-proxy-header-";
+2 -1
View File
@@ -1,4 +1,5 @@
#include "common.h"
#include "http.h"
#include "server-http.h"
#include "server-stream.h"
#include "server-common.h"
@@ -441,7 +442,7 @@ bool server_http_context::start() {
srv->wait_until_ready();
listening_address = is_sock ? string_format("unix://%s", hostname.c_str())
: string_format("%s://%s:%d", is_ssl ? "https" : "http", hostname.c_str(), port);
: string_format("%s://%s:%d", is_ssl ? "https" : "http", common_http_format_host(hostname).c_str(), port);
return true;
}
+3 -1
View File
@@ -1,4 +1,5 @@
#include "server-common.h"
#include "http.h"
#include "server-models.h"
#include "server-context.h"
#include "server-stream.h"
@@ -2263,7 +2264,8 @@ server_http_proxy::server_http_proxy(
}
if (lowered == "host") {
bool is_default_port = (scheme == "https" && port == 443) || (scheme == "http" && port == 80);
req.set_header(key, is_default_port ? host : host + ":" + std::to_string(port));
const std::string url_host = common_http_format_host(host);
req.set_header(key, is_default_port ? url_host : url_host + ":" + std::to_string(port));
} else {
req.set_header(key, value);
}
@@ -43,7 +43,7 @@
assistantMessages: number;
messageTypes: string[];
} | null>(null);
let editedContent = $derived(message.content);
let editedContent = $state(message.content);
let rawEditContent = $derived.by(() => {
if (message.role !== MessageRole.ASSISTANT) return undefined;
+10 -7
View File
@@ -288,9 +288,7 @@ export const API_CACHING_PATTERNS = {
} as const;
// SvelteKit PWA plugin options
export const PWA_KIT_OPTIONS = {
NAVIGATE_FALLBACK: './'
} as const;
export const PWA_KIT_OPTIONS = {} as const;
export const APPLE_META_TAGS = {
MOBILE_WEB_APP_CAPABLE: { name: 'apple-mobile-web-app-capable', content: 'yes' },
@@ -322,6 +320,14 @@ export const SVELTEKIT_PWA_OPTIONS: SvelteKitPWAOptions = {
globIgnores: GLOB_IGNORES,
maximumFileSizeToCacheInBytes: CACHE_SETTINGS.MAX_FILE_SIZE_BYTES,
// Prevent @vite-pwa/sveltekit from auto-adding a NavigationRoute by
// setting navigateFallback to empty string. This keeps the service
// worker from intercepting direct browser navigation to server API
// endpoints (e.g. /slots, /models, /v1/models) which should return
// JSON, not the SPA HTML shell. The server's own static-file fallback
// handles non-API navigation to index.html for the SPA router.
navigateFallback: '',
// Runtime caching for API calls - use NetworkFirst so APIs are always fresh
runtimeCaching: [
{
@@ -351,10 +357,7 @@ export const SVELTEKIT_PWA_OPTIONS: SvelteKitPWAOptions = {
devOptions: {
enabled: true,
suppressWarnings: true,
// Use PWA_KIT_OPTIONS.NAVIGATE_FALLBACK to match production SW behaviour
// (navigateFallback defaults to the configured base path, which is '/' for this SPA).
navigateFallback: PWA_KIT_OPTIONS.NAVIGATE_FALLBACK
suppressWarnings: true
},
// SvelteKit-specific options
+29 -8
View File
@@ -1083,6 +1083,11 @@ class ChatStore {
let resolvedModel: string | null = null;
let modelPersisted = false;
const convId = assistantMessage.convId;
// Tracks the last message created in this flow. Used as the parent for the next
// turn's assistant message so createAssistantMessage does not have to read
// conversationsStore.activeMessages, which may belong to a different conversation
// after the user navigates while the loop is still running.
let lastCreatedInFlow = currentMessageId;
// freeze the POST identity from t0 so a stop cancels with the exact session key,
// never a stale or empty model resolved later
this.setChatStreaming(convId, streamedContent, currentMessageId, effectiveModel);
@@ -1208,8 +1213,15 @@ class ChatStore {
};
if (timings) uiUpdate.timings = timings;
if (resolvedModel) uiUpdate.model = resolvedModel;
conversationsStore.updateMessageAtIndex(idx, uiUpdate);
await conversationsStore.updateCurrentNode(currentMessageId);
// touch the active ui array and node pointer only when this conversation
// is displayed; otherwise persist the node move straight to the db so a
// foreign conv's currNode stays untouched
if (conversationsStore.activeConversation?.id === convId) {
conversationsStore.updateMessageAtIndex(idx, uiUpdate);
await conversationsStore.updateCurrentNode(currentMessageId);
} else {
await DatabaseService.updateCurrentNode(convId, currentMessageId);
}
},
createToolResultMessage: async (
toolCallId: string,
@@ -1230,8 +1242,16 @@ class ChatStore {
},
currentMessageId
);
conversationsStore.addMessageToActive(msg);
await conversationsStore.updateCurrentNode(msg.id);
// mirror into the active store and move the node pointer only when this
// conversation is displayed; otherwise persist the node move straight to
// the db for the owning conv so a foreign conv's currNode stays untouched
if (conversationsStore.activeConversation?.id === convId) {
conversationsStore.addMessageToActive(msg);
await conversationsStore.updateCurrentNode(msg.id);
} else {
await DatabaseService.updateCurrentNode(convId, msg.id);
}
lastCreatedInFlow = msg.id;
return msg;
},
createAssistantMessage: async () => {
@@ -1239,8 +1259,6 @@ class ChatStore {
streamedContent = '';
streamedReasoningContent = '';
const lastMsg =
conversationsStore.activeMessages[conversationsStore.activeMessages.length - 1];
const msg = await DatabaseService.createMessageBranch(
{
convId,
@@ -1252,10 +1270,13 @@ class ChatStore {
children: [],
model: resolvedModel
},
lastMsg.id
lastCreatedInFlow
);
conversationsStore.addMessageToActive(msg);
if (conversationsStore.activeConversation?.id === convId) {
conversationsStore.addMessageToActive(msg);
}
currentMessageId = msg.id;
lastCreatedInFlow = msg.id;
return msg;
},
onFlowComplete: (finalTimings?: ChatMessageTimings) => {
+4 -1
View File
@@ -43,7 +43,10 @@ test.describe('PWA Service Worker', () => {
expect(swContent).toMatch(/"_app\/immutable\/assets\/bundle\.[a-zA-Z0-9_-]+\.css"/);
expect(swContent).toMatch(/"manifest\.webmanifest"/);
expect(swContent).toMatch(/"_app\/version\.json"/);
expect(swContent).toMatch(/NavigationRoute/);
// NavigationRoute is intentionally absent — server API endpoints
// (e.g. /slots, /models) must not be intercepted by the PWA and
// should return JSON directly from the server.
expect(swContent).not.toMatch(/NavigationRoute/);
expect(swContent).toMatch(/api-cache/);
});
+4 -2
View File
@@ -108,9 +108,11 @@ describe('PWA Build Output', () => {
expect(swContent).toMatch(/"manifest\.webmanifest"/);
});
it('has navigation route registered', () => {
it('no navigation route — API endpoints bypass PWA', () => {
expect(swContent).toBeTruthy();
expect(swContent).toMatch(/NavigationRoute/);
// NavigationRoute is intentionally absent so direct browser
// navigation to server API endpoints returns JSON, not HTML.
expect(swContent).not.toMatch(/NavigationRoute/);
});
it('has runtime caching for API routes', () => {