Compare commits

...

8 Commits

Author SHA1 Message Date
Hongqiang Wang ebd048fc5e opencl: flash attention improvement (#25069)
* opencl: rework FA kernel for f16 and f32

* opencl: flash-attention prefill prepass kernels

- flash_attn_kv_pad_f16    pads the tail KV tile to a BLOCK_N multiple
- flash_attn_mask_pad_f16  pads the matching mask tile
- flash_attn_blk_f16       classifies each KV tile per query block as
                           fully masked / mixed / fully unmasked, so
                           the main kernel can skip fully-masked tiles
                           and the mask lookup for fully-unmasked ones

* opencl: FA kernels for q4_0 and q8_0

* opencl: `set_rows` for f32 to q8_0/q4_0

* opencl: dequant kernels for q4_0 and q8_0

* opencl: add FA tile tuning table with override

* opencl: wire host side for FA

* opencl: q4_0 MoE tensors are also SOA'ed

* opencl: cosmetic fix

* opencl: refactor, also clarify some code paths in comments

* opencl: fix inifity for `-cl-finite-math-only`

---------

Co-authored-by: Li He <lih@qti.qualcomm.com>
2026-06-27 15:36:06 -07:00
Gaurav Garg 0ed235ea2c [CUDA] Added a cudaMemcpy2DAsync fast path to ggml_cuda_cpy (#25057)
* [CUDA] Added a cudaMemcpy2DAsync fast path to ggml_cuda_cpy

Add a CUDA ggml_cpy fast path for same-type, same-shape strided copies that are just 2D pitched block copies.
When tensors are not fully contiguous but each row is contiguous, it now uses cudaMemcpy2DAsync instead of the slow element-wise scalar copy kernel.

This fixes the GDN recurrent snapshot update with -np 4, where rollback slots are separated by cache stride gaps.

* Add new tests that execute the new optimized strided copy path

* Return unsupported for strided copy in OpenVINO, as new tests are failing
2026-06-27 17:46:21 +05:30
Neo Zhang 9bebfcb4bc sycl : fix failed ut cases of norm (#25044) 2026-06-27 12:13:43 +03:00
Ruben Ortlam 0b6529d818 vulkan: fix step operator for 0 input (#25036) 2026-06-27 10:57:31 +02:00
Christian Kastner c299a92c38 binaries : Improve rpc-server and export-graph-ops names. (#25045)
Tests are generally prefixed with -test, so rename export-graph-ops
accordingly.

rpc-server is probably too generic a name for /usr/bin. Because it
should work with any ggml application, it is renamed to ggml-rpc-server.
2026-06-27 10:31:29 +03:00
Sigbjørn Skjæret 0275c0f800 ci : add windows-openvino to check-release (#25022) 2026-06-27 10:30:56 +03:00
Sigbjørn Skjæret 83d385b429 tests : fix test-chat-template --no-common option (#25075) 2026-06-27 10:30:19 +03:00
Adrien Gallouët 050ee92d04 app : allow --version, --licenses & --help (#25054)
Signed-off-by: Adrien Gallouët <angt@huggingface.co>
2026-06-26 23:18:11 +02:00
24 changed files with 5834 additions and 503 deletions
+3
View File
@@ -552,6 +552,9 @@ jobs:
name: llama-bin-ubuntu-openvino-${{ env.OPENVINO_VERSION_MAJOR }}-x64.tar.gz
windows-openvino:
needs: [check-release]
if: ${{ needs.check-release.outputs.should_release == 'true' }}
runs-on: windows-2022
outputs:
+1 -1
View File
@@ -80,7 +80,7 @@ To protect sensitive data from potential leaks or unauthorized access, it is cru
### Untrusted environments or networks
If you can't run your models in a secure and isolated environment or if it must be exposed to an untrusted network, make sure to take the following security precautions:
* Do not use the RPC backend, [rpc-server](https://github.com/ggml-org/llama.cpp/tree/master/tools/rpc) and [llama-server](https://github.com/ggml-org/llama.cpp/tree/master/tools/server) functionality (see https://github.com/ggml-org/llama.cpp/pull/13061).
* Do not use the RPC backend, [ggml-rpc-server](https://github.com/ggml-org/llama.cpp/tree/master/tools/rpc) and [llama-server](https://github.com/ggml-org/llama.cpp/tree/master/tools/server) functionality (see https://github.com/ggml-org/llama.cpp/pull/13061).
* Confirm the hash of any downloaded artifact (e.g. pre-trained model weights) matches a known-good value.
* Encrypt your data if sending it over the network.
+8 -4
View File
@@ -50,6 +50,7 @@ struct command {
std::vector<std::string> aliases;
bool hidden;
int (*func)(int, char **);
bool flags = false; // allow --name
};
#ifdef LLAMA_INSTALL_BUILD
@@ -69,9 +70,9 @@ static const command cmds[] = {
{"fit-params", "Compute parameters to fit a model in device memory", {}, true, llama_fit_params },
{"quantize", "Quantize a model", {}, true, llama_quantize },
{"perplexity", "Compute model perplexity and KL divergence", {}, true, llama_perplexity },
{"version", "Show version", {}, false, version },
{"licenses", "Show third-party licenses", {"credits"}, false, licenses },
{"help", "Show available commands", {}, false, help },
{"version", "Show version", {}, false, version, true },
{"licenses", "Show third-party licenses", {"credits"}, false, licenses, true },
{"help", "Show available commands", {}, false, help, true },
};
#undef UPDATE_HIDDEN
@@ -108,7 +109,10 @@ static int help(int argc, char ** argv) {
return 0;
}
static bool matches(const std::string & arg, const command & cmd) {
static bool matches(std::string arg, const command & cmd) {
if (cmd.flags && arg.size() > 2 && arg[0] == '-' && arg[1] == '-') {
arg.erase(0, 2);
}
if (arg == cmd.name) {
return true;
}
+45
View File
@@ -386,6 +386,46 @@ static void ggml_cpy_f32_iq4_nl_cuda(
(cx, cdst, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13);
}
// check if a same-type copy reduces to a 2D strided copy (height rows of width
// contiguous bytes), so it can use cudaMemcpy2DAsync instead of the scalar kernel
static bool ggml_cuda_cpy_as_memcpy_2d(const ggml_tensor * src0, const ggml_tensor * src1,
size_t & width, size_t & height, size_t & spitch, size_t & dpitch) {
// require matching shape: a reshaped copy maps elements by flat order, which the
// prefix walk below does not handle
if (src0->type != src1->type || !ggml_are_same_shape(src0, src1)) {
return false;
}
// grow the contiguous prefix block shared by both tensors
size_t block_nb = ggml_element_size(src0);
int d = 0;
for (; d < GGML_MAX_DIMS; ++d) {
if (src0->nb[d] != block_nb || src1->nb[d] != block_nb) {
break;
}
block_nb *= src0->ne[d];
}
// d == 0: nothing contiguous; d == GGML_MAX_DIMS: fully contiguous (handled by memcpy)
if (d == 0 || d == GGML_MAX_DIMS) {
return false;
}
// dim d carries the rows; everything above it must be a single element
for (int i = d + 1; i < GGML_MAX_DIMS; ++i) {
if (src0->ne[i] != 1) {
return false;
}
}
width = block_nb;
height = src0->ne[d];
spitch = src0->nb[d];
dpitch = src1->nb[d];
return spitch >= width && dpitch >= width;
}
void ggml_cuda_cpy(ggml_backend_cuda_context & ctx, const ggml_tensor * src0, ggml_tensor * src1) {
const int64_t ne = ggml_nelements(src0);
GGML_ASSERT(ne == ggml_nelements(src1));
@@ -421,6 +461,8 @@ void ggml_cuda_cpy(ggml_backend_cuda_context & ctx, const ggml_tensor * src0, gg
const bool can_be_transposed = nb01 == (int64_t)ggml_element_size(src0) &&
src0->ne[3] == 1 && nb02 == ne00 * ne01 * (int64_t)ggml_element_size(src0);
size_t mc_width = 0, mc_height = 0, mc_spitch = 0, mc_dpitch = 0;
if (src0->type == src1->type && contiguous_srcs) {
GGML_ASSERT(ggml_nbytes(src0) == ggml_nbytes(src1));
#if defined(GGML_USE_MUSA) && defined(GGML_MUSA_MUDNN_COPY)
@@ -431,6 +473,9 @@ void ggml_cuda_cpy(ggml_backend_cuda_context & ctx, const ggml_tensor * src0, gg
{
CUDA_CHECK(cudaMemcpyAsync(src1_ddc, src0_ddc, ggml_nbytes(src0), cudaMemcpyDeviceToDevice, main_stream));
}
} else if (ggml_cuda_cpy_as_memcpy_2d(src0, src1, mc_width, mc_height, mc_spitch, mc_dpitch)) {
CUDA_CHECK(cudaMemcpy2DAsync(src1_ddc, mc_dpitch, src0_ddc, mc_spitch,
mc_width, mc_height, cudaMemcpyDeviceToDevice, main_stream));
} else if (src0->type == GGML_TYPE_F32 && src1->type == GGML_TYPE_F32) {
if (can_be_transposed) {
ggml_cpy_scalar_cuda<float, float, true>
+3
View File
@@ -192,7 +192,10 @@ set(GGML_OPENCL_KERNELS
mul_mm_f16_f32_kq_kqv
conv2d
conv2d_f16_f32
flash_attn_pre_f16
flash_attn_f32_f16
flash_attn_f32_q8_0
flash_attn_f32_q4_0
flash_attn_f16
flash_attn_f32
)
+91
View File
@@ -0,0 +1,91 @@
#pragma once
// Flash-attention per-(dk,dv) tile tuning for the Adreno OpenCL backend.
// Isolated from ggml-opencl.cpp so the tuning numbers are easy to find and
// edit; the FA dispatch and kernel-compile logic stay in the main file.
// This header is a file section — it is #included exactly once, at the point
// in ggml-opencl.cpp where the ggml logging macros are already in scope.
// Per-(dk, dv) FA config; shared by dispatch and supports_op.
struct ggml_opencl_fa_dim {
int dk; int dv; int bm; int bn; int n_split; int nkv_split_threshold;
};
// Split variant fires when n_kv >= threshold (threshold=0 -> always split).
// Default tuning covers Adreno 7xx/8xx mobile and X1-series laptop GPUs.
static const ggml_opencl_fa_dim g_fa_dims_adreno_default[] = {
{ 40, 40, 64, 32, 1, 0}, { 64, 64, 64, 32, 2, 64},
{ 80, 80, 64, 32, 2, 64}, { 96, 96, 64, 32, 2, 64},
{112, 112, 64, 32, 2, 64}, {128, 128, 64, 32, 2, 64},
{192, 128, 16, 16, 1, 0},
{192, 192, 16, 16, 1, 0},
{256, 256, 16, 16, 16, 0},
};
struct ggml_opencl_fa_dim_table {
const ggml_opencl_fa_dim * data;
size_t count;
const ggml_opencl_fa_dim * begin() const { return data; }
const ggml_opencl_fa_dim * end() const { return data + count; }
};
// Mutable copy of the active table; GGML_OPENCL_FA_TUNE patches entries here
// at backend init without touching the const source table.
static ggml_opencl_fa_dim g_fa_dims_runtime[
sizeof(g_fa_dims_adreno_default) / sizeof(g_fa_dims_adreno_default[0])];
static ggml_opencl_fa_dim_table g_opencl_fa_dims = {
g_fa_dims_adreno_default,
sizeof(g_fa_dims_adreno_default) / sizeof(g_fa_dims_adreno_default[0]),
};
// GGML_OPENCL_FA_TUNE=dk:dv:bm:bn:nsplit:thr[,…] — patches matching entries
// in the active table at backend init, before the first FA kernel compiles.
// Unmatched (dk,dv) pairs are warned and ignored.
static void ggml_opencl_fa_apply_env_overrides() {
const char * e = std::getenv("GGML_OPENCL_FA_TUNE");
if (!e || !e[0]) {
return;
}
std::string s = e;
size_t pos = 0;
while (pos < s.size()) {
size_t comma = s.find(',', pos);
std::string entry = s.substr(pos, comma == std::string::npos ? std::string::npos : comma - pos);
int dk, dv, bm, bn, nsplit, thr;
if (std::sscanf(entry.c_str(), "%d:%d:%d:%d:%d:%d", &dk, &dv, &bm, &bn, &nsplit, &thr) == 6) {
bool patched = false;
for (size_t i = 0; i < g_opencl_fa_dims.count; ++i) {
ggml_opencl_fa_dim & d = g_fa_dims_runtime[i];
if (d.dk == dk && d.dv == dv) {
d.bm = bm; d.bn = bn; d.n_split = nsplit; d.nkv_split_threshold = thr;
GGML_LOG_INFO("ggml_opencl: FA tune override DK=%d DV=%d -> bm=%d bn=%d n_split=%d thr=%d\n",
dk, dv, bm, bn, nsplit, thr);
patched = true;
break;
}
}
if (!patched) {
GGML_LOG_WARN("ggml_opencl: FA tune override DK=%d DV=%d ignored (no matching dim)\n", dk, dv);
}
} else {
GGML_LOG_WARN("ggml_opencl: FA tune override entry malformed: '%s'\n", entry.c_str());
}
if (comma == std::string::npos) break;
pos = comma + 1;
}
}
// Copy the default table into the mutable runtime buffer and apply any
// GGML_OPENCL_FA_TUNE overrides. A per-generation table can be added here
// once it has been tuned on hardware.
static void ggml_cl_init_fa_dims_table() {
const size_t count = sizeof(g_fa_dims_adreno_default) / sizeof(g_fa_dims_adreno_default[0]);
for (size_t i = 0; i < count; ++i) {
g_fa_dims_runtime[i] = g_fa_dims_adreno_default[i];
}
g_opencl_fa_dims = { g_fa_dims_runtime, count };
ggml_opencl_fa_apply_env_overrides();
}
File diff suppressed because it is too large Load Diff
+152
View File
@@ -1582,6 +1582,158 @@ kernel void kernel_restore_block_q8_0(
}
}
// View-aware AoS q8_0 -> f32 dequant (f32/f32 FA path).
kernel void kernel_dequant_q8_0_f32_view_aos(
global char * src,
ulong src_offset,
ulong src_nb1,
ulong src_nb2,
ulong src_nb3,
int nblk0,
int ne1,
int ne2,
int ne3,
global float * dst
) {
int blk_i0 = get_global_id(0);
int i1 = get_global_id(1);
int batch = get_global_id(2);
if (blk_i0 >= nblk0) return;
if (i1 >= ne1) return;
int i2 = batch % ne2;
int i3 = batch / ne2;
if (i3 >= ne3) return;
global char * block = src + src_offset + (ulong)i3*src_nb3 + (ulong)i2*src_nb2 + (ulong)i1*src_nb1 + (ulong)blk_i0 * (2 + QK8_0);
float d = vload_half(0, (global half *)block);
global char * qs = block + 2;
ulong dst_row_base = ((ulong)i3 * ne2 * ne1 + (ulong)i2 * ne1 + (ulong)i1) * nblk0;
global float * out = dst + (dst_row_base + blk_i0) * QK8_0;
for (int i = 0; i < QK8_0; ++i) {
out[i] = d * (float)qs[i];
}
}
// View-aware AoS q8_0 -> f16 dequant. Rows tight, batch strides may be gapped.
kernel void kernel_dequant_q8_0_f16_view_aos(
global char * src,
ulong src_offset,
ulong src_nb1,
ulong src_nb2,
ulong src_nb3,
int nblk0,
int ne1,
int ne2,
int ne3,
global half * dst
) {
int blk_i0 = get_global_id(0);
int i1 = get_global_id(1);
int batch = get_global_id(2);
if (blk_i0 >= nblk0) return;
if (i1 >= ne1) return;
int i2 = batch % ne2;
int i3 = batch / ne2;
if (i3 >= ne3) return;
global char * block = src + src_offset + (ulong)i3*src_nb3 + (ulong)i2*src_nb2 + (ulong)i1*src_nb1 + (ulong)blk_i0 * (2 + QK8_0);
float d = vload_half(0, (global half *)block);
global char * qs = block + 2;
ulong dst_row_base = ((ulong)i3 * ne2 * ne1 + (ulong)i2 * ne1 + (ulong)i1) * nblk0;
global half * out = dst + (dst_row_base + blk_i0) * QK8_0;
for (int i = 0; i < QK8_0; ++i) {
out[i] = (half)(d * (float)qs[i]);
}
}
// View-aware AoS q4_0 -> f32 dequant (mirrors the q8_0 view variant).
kernel void kernel_dequant_q4_0_f32_view_aos(
global char * src,
ulong src_offset,
ulong src_nb1,
ulong src_nb2,
ulong src_nb3,
int nblk0,
int ne1,
int ne2,
int ne3,
global float * dst
) {
int blk_i0 = get_global_id(0);
int i1 = get_global_id(1);
int batch = get_global_id(2);
if (blk_i0 >= nblk0) return;
if (i1 >= ne1) return;
int i2 = batch % ne2;
int i3 = batch / ne2;
if (i3 >= ne3) return;
global char * block = src + src_offset + (ulong)i3*src_nb3 + (ulong)i2*src_nb2 + (ulong)i1*src_nb1 + (ulong)blk_i0 * (2 + QK4_0/2);
float d = vload_half(0, (global half *)block);
global uchar * qs = (global uchar *)(block + 2);
ulong dst_row_base = ((ulong)i3 * ne2 * ne1 + (ulong)i2 * ne1 + (ulong)i1) * nblk0;
global float * out = dst + (dst_row_base + blk_i0) * QK4_0;
for (int i = 0; i < QK4_0/2; ++i) {
uchar byte = qs[i];
int q0 = (int)(byte & 0x0F) - 8;
int q1 = (int)(byte >> 4) - 8;
out[i] = d * (float)q0;
out[i + QK4_0/2] = d * (float)q1;
}
}
// View-aware AoS q4_0 -> f16 dequant (mirrors the q8_0 view variant).
kernel void kernel_dequant_q4_0_f16_view_aos(
global char * src,
ulong src_offset,
ulong src_nb1,
ulong src_nb2,
ulong src_nb3,
int nblk0,
int ne1,
int ne2,
int ne3,
global half * dst
) {
int blk_i0 = get_global_id(0);
int i1 = get_global_id(1);
int batch = get_global_id(2);
if (blk_i0 >= nblk0) return;
if (i1 >= ne1) return;
int i2 = batch % ne2;
int i3 = batch / ne2;
if (i3 >= ne3) return;
global char * block = src + src_offset + (ulong)i3*src_nb3 + (ulong)i2*src_nb2 + (ulong)i1*src_nb1 + (ulong)blk_i0 * (2 + QK4_0/2);
float d = vload_half(0, (global half *)block);
global uchar * qs = (global uchar *)(block + 2);
ulong dst_row_base = ((ulong)i3 * ne2 * ne1 + (ulong)i2 * ne1 + (ulong)i1) * nblk0;
global half * out = dst + (dst_row_base + blk_i0) * QK4_0;
for (int i = 0; i < QK4_0/2; ++i) {
uchar byte = qs[i];
int q0 = (int)(byte & 0x0F) - 8;
int q1 = (int)(byte >> 4) - 8;
out[i] = (half)(d * (float)q0);
out[i + QK4_0/2] = (half)(d * (float)q1);
}
}
kernel void kernel_restore_block_q8_0_trans(
global uchar * src_q,
global half * src_d,
+75 -40
View File
@@ -4,14 +4,26 @@
#define ACC_TYPE4 float4
#define DATA_TYPE half
#define DATA_TYPE4 half4
#define CONVERT_ACC4(x) convert_float4(x)
#define CONVERT_DATA4(x) convert_half4(x)
#define CONVERT_ACC4(x) ((float4)((float)(x).s0, (float)(x).s1, (float)(x).s2, (float)(x).s3))
#define CONVERT_DATA4(x) ((half4)((half)(x).s0, (half)(x).s1, (half)(x).s2, (half)(x).s3))
#define DK_VEC (DK/4)
#define DV_VEC (DV/4)
#define WG_SIZE (BLOCK_M)
#define Q1_WG_SIZE 64
// The kernels are built with -cl-finite-math-only. On some older Adreno GPUs,
// infinite operand can cause undefined behavior and miscompilation for exp.
// Therefore, a large negative value is used instead.
#define FA_M_INIT (-3.0e38f)
// Drop full unroll at DK>=192 Adreno compiler host-memory budget.
#if DK >= 192
#define FA_UNROLL
#else
#define FA_UNROLL _Pragma("unroll")
#endif
inline float get_alibi_slope(
const float max_bias, const uint h, const uint n_head_log2, const float m0, const float m1
) {
@@ -81,18 +93,18 @@ __kernel void flash_attn_f16(
if (my_query_row < n_q) {
const ulong q_row_offset = batch_idx * q_nb3 + head_idx * q_nb2 + my_query_row * q_nb1;
const global DATA_TYPE4* q_ptr = (const global DATA_TYPE4*)(q_base + q_row_offset);
#pragma unroll
FA_UNROLL
for (int i = 0; i < DK_VEC; ++i) {
q_priv[i] = CONVERT_ACC4(q_ptr[i]);
}
}
ACC_TYPE4 o_acc[DV_VEC];
#pragma unroll
FA_UNROLL
for (int i = 0; i < DV_VEC; ++i) {
o_acc[i] = (ACC_TYPE4)(0.0f);
}
ACC_TYPE m_i = -INFINITY;
ACC_TYPE m_i = FA_M_INIT;
ACC_TYPE l_i = 0.0f;
float slope = get_alibi_slope(max_bias, head_idx, n_head_log2, m0, m1);
@@ -125,49 +137,72 @@ __kernel void flash_attn_f16(
continue;
}
for (int j = 0; j < BLOCK_N; j += 2) {
for (int j = 0; j < BLOCK_N; j += 4) {
const int k_row0 = k_start + j;
const int k_row1 = k_start + j + 1;
const int k_row2 = k_start + j + 2;
const int k_row3 = k_start + j + 3;
ACC_TYPE4 dot_acc0 = (ACC_TYPE4)(0.0f);
ACC_TYPE4 dot_acc1 = (ACC_TYPE4)(0.0f);
#pragma unroll
ACC_TYPE4 dot_acc2 = (ACC_TYPE4)(0.0f);
ACC_TYPE4 dot_acc3 = (ACC_TYPE4)(0.0f);
FA_UNROLL
for (int k = 0; k < DK_VEC; k++) {
dot_acc0 = mad(q_priv[k], CONVERT_ACC4(l_k[j][k]), dot_acc0);
dot_acc1 = mad(q_priv[k], CONVERT_ACC4(l_k[j+1][k]), dot_acc1);
const ACC_TYPE4 qk = q_priv[k];
dot_acc0 = mad(qk, CONVERT_ACC4(l_k[j][k]), dot_acc0);
dot_acc1 = mad(qk, CONVERT_ACC4(l_k[j+1][k]), dot_acc1);
dot_acc2 = mad(qk, CONVERT_ACC4(l_k[j+2][k]), dot_acc2);
dot_acc3 = mad(qk, CONVERT_ACC4(l_k[j+3][k]), dot_acc3);
}
ACC_TYPE score0 = (dot_acc0.s0 + dot_acc0.s1 + dot_acc0.s2 + dot_acc0.s3) * scale;
ACC_TYPE score1 = (dot_acc1.s0 + dot_acc1.s1 + dot_acc1.s2 + dot_acc1.s3) * scale;
ACC_TYPE s0 = (dot_acc0.s0 + dot_acc0.s1 + dot_acc0.s2 + dot_acc0.s3) * scale;
ACC_TYPE s1 = (dot_acc1.s0 + dot_acc1.s1 + dot_acc1.s2 + dot_acc1.s3) * scale;
ACC_TYPE s2 = (dot_acc2.s0 + dot_acc2.s1 + dot_acc2.s2 + dot_acc2.s3) * scale;
ACC_TYPE s3 = (dot_acc3.s0 + dot_acc3.s1 + dot_acc3.s2 + dot_acc3.s3) * scale;
if (is_causal) {
if (k_row0 > (n_kv - n_q + my_query_row)) score0 = -INFINITY;
if (k_row1 > (n_kv - n_q + my_query_row)) score1 = -INFINITY;
const int causal_limit = n_kv - n_q + my_query_row;
if (k_row0 > causal_limit) s0 = FA_M_INIT;
if (k_row1 > causal_limit) s1 = FA_M_INIT;
if (k_row2 > causal_limit) s2 = FA_M_INIT;
if (k_row3 > causal_limit) s3 = FA_M_INIT;
}
if (k_row0 >= n_kv) score0 = -INFINITY;
if (k_row1 >= n_kv) score1 = -INFINITY;
if (k_row0 >= n_kv) s0 = FA_M_INIT;
if (k_row1 >= n_kv) s1 = FA_M_INIT;
if (k_row2 >= n_kv) s2 = FA_M_INIT;
if (k_row3 >= n_kv) s3 = FA_M_INIT;
if (mask_base != NULL) {
const global DATA_TYPE* mask_ptr = (const global DATA_TYPE*)(mask_base + my_query_row * mask_nb1);
if (k_row0 < n_kv) score0 += slope * (ACC_TYPE)mask_ptr[k_row0];
if (k_row1 < n_kv) score1 += slope * (ACC_TYPE)mask_ptr[k_row1];
if (k_row0 < n_kv) s0 += slope * (ACC_TYPE)mask_ptr[k_row0];
if (k_row1 < n_kv) s1 += slope * (ACC_TYPE)mask_ptr[k_row1];
if (k_row2 < n_kv) s2 += slope * (ACC_TYPE)mask_ptr[k_row2];
if (k_row3 < n_kv) s3 += slope * (ACC_TYPE)mask_ptr[k_row3];
}
if (logit_softcap > 0.0f) {
score0 = logit_softcap * tanh(score0 / logit_softcap);
score1 = logit_softcap * tanh(score1 / logit_softcap);
s0 = logit_softcap * tanh(s0 / logit_softcap);
s1 = logit_softcap * tanh(s1 / logit_softcap);
s2 = logit_softcap * tanh(s2 / logit_softcap);
s3 = logit_softcap * tanh(s3 / logit_softcap);
}
const ACC_TYPE m_new = max(m_i, max(score0, score1));
const ACC_TYPE p0 = exp(score0 - m_new);
const ACC_TYPE p1 = exp(score1 - m_new);
const ACC_TYPE scale_prev = exp(m_i - m_new);
const ACC_TYPE m_new = max(m_i, max(max(s0, s1), max(s2, s3)));
const ACC_TYPE scale_prev = native_exp(m_i - m_new);
const ACC_TYPE p0 = native_exp(s0 - m_new);
const ACC_TYPE p1 = native_exp(s1 - m_new);
const ACC_TYPE p2 = native_exp(s2 - m_new);
const ACC_TYPE p3 = native_exp(s3 - m_new);
#pragma unroll
FA_UNROLL
for (int i = 0; i < DV_VEC; ++i) {
o_acc[i] = o_acc[i] * scale_prev + p0 * CONVERT_ACC4(l_v[j][i]) + p1 * CONVERT_ACC4(l_v[j+1][i]);
o_acc[i] = mad(p3, CONVERT_ACC4(l_v[j+3][i]),
mad(p2, CONVERT_ACC4(l_v[j+2][i]),
mad(p1, CONVERT_ACC4(l_v[j+1][i]),
mad(p0, CONVERT_ACC4(l_v[j][i]),
o_acc[i] * scale_prev))));
}
l_i = l_i * scale_prev + p0 + p1;
l_i = l_i * scale_prev + p0 + p1 + p2 + p3;
m_i = m_new;
}
}
@@ -179,7 +214,7 @@ __kernel void flash_attn_f16(
const ACC_TYPE m_final = max(m_i, m_sink);
const ACC_TYPE scale_o = exp(m_i - m_final);
#pragma unroll
FA_UNROLL
for (int i = 0; i < DV_VEC; ++i) {
o_acc[i] *= scale_o;
}
@@ -191,12 +226,12 @@ __kernel void flash_attn_f16(
global DATA_TYPE4 *o_row = (global DATA_TYPE4 *)(o_base + o_row_offset);
if (l_i > 0.0f) {
const ACC_TYPE l_inv = 1.0f / l_i;
#pragma unroll
FA_UNROLL
for (int i = 0; i < DV_VEC; ++i) {
o_row[i] = CONVERT_DATA4(o_acc[i] * l_inv);
}
} else {
#pragma unroll
FA_UNROLL
for (int i = 0; i < DV_VEC; ++i) {
o_row[i] = (DATA_TYPE4)(0.0f);
}
@@ -258,7 +293,7 @@ __kernel void flash_attn_f16_q1(
ACC_TYPE4 q_priv[DK_VEC];
const ulong q_row_offset = batch_idx * q_nb3 + head_idx * q_nb2;
const global DATA_TYPE4* q_ptr = (const global DATA_TYPE4*)(q_base + q_row_offset);
#pragma unroll
FA_UNROLL
for (int i = 0; i < DK_VEC; ++i) {
q_priv[i] = CONVERT_ACC4(q_ptr[i]);
}
@@ -270,12 +305,12 @@ __kernel void flash_attn_f16_q1(
sinks_ptr = (const global ACC_TYPE*)((const global char*)sinks_void + sinks_offset);
}
ACC_TYPE m_i = (sinks_ptr != NULL) ? sinks_ptr[head_idx] : -INFINITY;
ACC_TYPE m_i = (sinks_ptr != NULL) ? sinks_ptr[head_idx] : FA_M_INIT;
for (int k_idx = tid; k_idx < n_kv; k_idx += Q1_WG_SIZE) {
const ulong k_row_offset = batch_idx * k_nb3 + head_kv_idx * k_nb2 + k_idx * k_nb1;
const global DATA_TYPE4* k_ptr = (const global DATA_TYPE4*)(k_base + k_row_offset);
ACC_TYPE4 dot_acc = (ACC_TYPE4)(0.0f);
#pragma unroll
FA_UNROLL
for (int k = 0; k < DK_VEC; k++) {
dot_acc = mad(q_priv[k], CONVERT_ACC4(k_ptr[k]), dot_acc);
}
@@ -293,7 +328,7 @@ __kernel void flash_attn_f16_q1(
__local ACC_TYPE local_m[Q1_WG_SIZE];
local_m[tid] = m_i;
barrier(CLK_LOCAL_MEM_FENCE);
#pragma unroll
FA_UNROLL
for (int s = Q1_WG_SIZE / 2; s > 0; s >>= 1) {
if (tid < s) local_m[tid] = max(local_m[tid], local_m[tid + s]);
barrier(CLK_LOCAL_MEM_FENCE);
@@ -301,7 +336,7 @@ __kernel void flash_attn_f16_q1(
const ACC_TYPE m_final = local_m[0];
ACC_TYPE4 o_acc[DV_VEC];
#pragma unroll
FA_UNROLL
for (int i = 0; i < DV_VEC; ++i) o_acc[i] = (ACC_TYPE4)(0.0f);
ACC_TYPE l_i = 0.0f;
@@ -311,7 +346,7 @@ __kernel void flash_attn_f16_q1(
const global DATA_TYPE4* k_ptr = (const global DATA_TYPE4*)(k_base + k_row_offset);
const global DATA_TYPE4* v_ptr = (const global DATA_TYPE4*)(v_base + v_row_offset);
ACC_TYPE4 dot_acc = (ACC_TYPE4)(0.0f);
#pragma unroll
FA_UNROLL
for (int k = 0; k < DK_VEC; k++) {
dot_acc = mad(q_priv[k], CONVERT_ACC4(k_ptr[k]), dot_acc);
}
@@ -325,7 +360,7 @@ __kernel void flash_attn_f16_q1(
}
const ACC_TYPE p = exp(score - m_final);
l_i += p;
#pragma unroll
FA_UNROLL
for (int i = 0; i < DV_VEC; i++) {
o_acc[i] = mad(p, CONVERT_ACC4(v_ptr[i]), o_acc[i]);
}
@@ -335,7 +370,7 @@ __kernel void flash_attn_f16_q1(
__local ACC_TYPE4 local_o_comp[Q1_WG_SIZE];
local_l[tid] = l_i;
barrier(CLK_LOCAL_MEM_FENCE);
#pragma unroll
FA_UNROLL
for (int s = Q1_WG_SIZE / 2; s > 0; s >>= 1) {
if (tid < s) local_l[tid] += local_l[tid + s];
barrier(CLK_LOCAL_MEM_FENCE);
@@ -354,7 +389,7 @@ __kernel void flash_attn_f16_q1(
for (int i = 0; i < DV_VEC; i++) {
local_o_comp[tid] = o_acc[i];
barrier(CLK_LOCAL_MEM_FENCE);
#pragma unroll
FA_UNROLL
for (int s = Q1_WG_SIZE / 2; s > 0; s >>= 1) {
if (tid < s) local_o_comp[tid] += local_o_comp[tid + s];
barrier(CLK_LOCAL_MEM_FENCE);
@@ -364,7 +399,7 @@ __kernel void flash_attn_f16_q1(
}
}
} else if (tid == 0) {
#pragma unroll
FA_UNROLL
for (int i = 0; i < DV_VEC; ++i) o_row[i] = (DATA_TYPE4)(0.0f);
}
}
+73 -38
View File
@@ -13,6 +13,18 @@
#define WG_SIZE (BLOCK_M)
#define Q1_WG_SIZE 64
// The kernels are built with -cl-finite-math-only. On some older Adreno GPUs,
// infinite operand can cause undefined behavior and miscompilation for exp.
// Therefore, a large negative value is used instead.
#define FA_M_INIT (-3.0e38f)
// Drop full unroll at DK>=192 Adreno compiler host-memory budget.
#if DK >= 192
#define FA_UNROLL
#else
#define FA_UNROLL _Pragma("unroll")
#endif
inline float get_alibi_slope(
const float max_bias, const uint h, const uint n_head_log2, const float m0, const float m1
) {
@@ -82,18 +94,18 @@ __kernel void flash_attn_f32(
if (my_query_row < n_q) {
const ulong q_row_offset = batch_idx * q_nb3 + head_idx * q_nb2 + my_query_row * q_nb1;
const global DATA_TYPE4* q_ptr = (const global DATA_TYPE4*)(q_base + q_row_offset);
#pragma unroll
FA_UNROLL
for (int i = 0; i < DK_VEC; ++i) {
q_priv[i] = CONVERT_ACC4(q_ptr[i]);
}
}
ACC_TYPE4 o_acc[DV_VEC];
#pragma unroll
FA_UNROLL
for (int i = 0; i < DV_VEC; ++i) {
o_acc[i] = (ACC_TYPE4)(0.0f);
}
ACC_TYPE m_i = -INFINITY;
ACC_TYPE m_i = FA_M_INIT;
ACC_TYPE l_i = 0.0f;
float slope = get_alibi_slope(max_bias, head_idx, n_head_log2, m0, m1);
@@ -126,49 +138,72 @@ __kernel void flash_attn_f32(
continue;
}
for (int j = 0; j < BLOCK_N; j += 2) {
for (int j = 0; j < BLOCK_N; j += 4) {
const int k_row0 = k_start + j;
const int k_row1 = k_start + j + 1;
const int k_row2 = k_start + j + 2;
const int k_row3 = k_start + j + 3;
ACC_TYPE4 dot_acc0 = (ACC_TYPE4)(0.0f);
ACC_TYPE4 dot_acc1 = (ACC_TYPE4)(0.0f);
#pragma unroll
ACC_TYPE4 dot_acc2 = (ACC_TYPE4)(0.0f);
ACC_TYPE4 dot_acc3 = (ACC_TYPE4)(0.0f);
FA_UNROLL
for (int k = 0; k < DK_VEC; k++) {
dot_acc0 = mad(q_priv[k], CONVERT_ACC4(l_k[j][k]), dot_acc0);
dot_acc1 = mad(q_priv[k], CONVERT_ACC4(l_k[j+1][k]), dot_acc1);
const ACC_TYPE4 qk = q_priv[k];
dot_acc0 = mad(qk, CONVERT_ACC4(l_k[j][k]), dot_acc0);
dot_acc1 = mad(qk, CONVERT_ACC4(l_k[j+1][k]), dot_acc1);
dot_acc2 = mad(qk, CONVERT_ACC4(l_k[j+2][k]), dot_acc2);
dot_acc3 = mad(qk, CONVERT_ACC4(l_k[j+3][k]), dot_acc3);
}
ACC_TYPE score0 = (dot_acc0.s0 + dot_acc0.s1 + dot_acc0.s2 + dot_acc0.s3) * scale;
ACC_TYPE score1 = (dot_acc1.s0 + dot_acc1.s1 + dot_acc1.s2 + dot_acc1.s3) * scale;
ACC_TYPE s0 = (dot_acc0.s0 + dot_acc0.s1 + dot_acc0.s2 + dot_acc0.s3) * scale;
ACC_TYPE s1 = (dot_acc1.s0 + dot_acc1.s1 + dot_acc1.s2 + dot_acc1.s3) * scale;
ACC_TYPE s2 = (dot_acc2.s0 + dot_acc2.s1 + dot_acc2.s2 + dot_acc2.s3) * scale;
ACC_TYPE s3 = (dot_acc3.s0 + dot_acc3.s1 + dot_acc3.s2 + dot_acc3.s3) * scale;
if (is_causal) {
if (k_row0 > (n_kv - n_q + my_query_row)) score0 = -INFINITY;
if (k_row1 > (n_kv - n_q + my_query_row)) score1 = -INFINITY;
const int causal_limit = n_kv - n_q + my_query_row;
if (k_row0 > causal_limit) s0 = FA_M_INIT;
if (k_row1 > causal_limit) s1 = FA_M_INIT;
if (k_row2 > causal_limit) s2 = FA_M_INIT;
if (k_row3 > causal_limit) s3 = FA_M_INIT;
}
if (k_row0 >= n_kv) score0 = -INFINITY;
if (k_row1 >= n_kv) score1 = -INFINITY;
if (k_row0 >= n_kv) s0 = FA_M_INIT;
if (k_row1 >= n_kv) s1 = FA_M_INIT;
if (k_row2 >= n_kv) s2 = FA_M_INIT;
if (k_row3 >= n_kv) s3 = FA_M_INIT;
if (mask_base != NULL) {
const global MASK_DATA_TYPE* mask_ptr = (const global MASK_DATA_TYPE*)(mask_base + my_query_row * mask_nb1);
if (k_row0 < n_kv) score0 += slope * (ACC_TYPE)mask_ptr[k_row0];
if (k_row1 < n_kv) score1 += slope * (ACC_TYPE)mask_ptr[k_row1];
if (k_row0 < n_kv) s0 += slope * (ACC_TYPE)mask_ptr[k_row0];
if (k_row1 < n_kv) s1 += slope * (ACC_TYPE)mask_ptr[k_row1];
if (k_row2 < n_kv) s2 += slope * (ACC_TYPE)mask_ptr[k_row2];
if (k_row3 < n_kv) s3 += slope * (ACC_TYPE)mask_ptr[k_row3];
}
if (logit_softcap > 0.0f) {
score0 = logit_softcap * tanh(score0 / logit_softcap);
score1 = logit_softcap * tanh(score1 / logit_softcap);
s0 = logit_softcap * tanh(s0 / logit_softcap);
s1 = logit_softcap * tanh(s1 / logit_softcap);
s2 = logit_softcap * tanh(s2 / logit_softcap);
s3 = logit_softcap * tanh(s3 / logit_softcap);
}
const ACC_TYPE m_new = max(m_i, max(score0, score1));
const ACC_TYPE p0 = exp(score0 - m_new);
const ACC_TYPE p1 = exp(score1 - m_new);
const ACC_TYPE scale_prev = exp(m_i - m_new);
const ACC_TYPE m_new = max(m_i, max(max(s0, s1), max(s2, s3)));
const ACC_TYPE scale_prev = native_exp(m_i - m_new);
const ACC_TYPE p0 = native_exp(s0 - m_new);
const ACC_TYPE p1 = native_exp(s1 - m_new);
const ACC_TYPE p2 = native_exp(s2 - m_new);
const ACC_TYPE p3 = native_exp(s3 - m_new);
#pragma unroll
FA_UNROLL
for (int i = 0; i < DV_VEC; ++i) {
o_acc[i] = o_acc[i] * scale_prev + p0 * CONVERT_ACC4(l_v[j][i]) + p1 * CONVERT_ACC4(l_v[j+1][i]);
o_acc[i] = mad(p3, CONVERT_ACC4(l_v[j+3][i]),
mad(p2, CONVERT_ACC4(l_v[j+2][i]),
mad(p1, CONVERT_ACC4(l_v[j+1][i]),
mad(p0, CONVERT_ACC4(l_v[j][i]),
o_acc[i] * scale_prev))));
}
l_i = l_i * scale_prev + p0 + p1;
l_i = l_i * scale_prev + p0 + p1 + p2 + p3;
m_i = m_new;
}
}
@@ -180,7 +215,7 @@ __kernel void flash_attn_f32(
const ACC_TYPE m_final = max(m_i, m_sink);
const ACC_TYPE scale_o = exp(m_i - m_final);
#pragma unroll
FA_UNROLL
for (int i = 0; i < DV_VEC; ++i) {
o_acc[i] *= scale_o;
}
@@ -192,12 +227,12 @@ __kernel void flash_attn_f32(
global DATA_TYPE4 *o_row = (global DATA_TYPE4 *)(o_base + o_row_offset);
if (l_i > 0.0f) {
const ACC_TYPE l_inv = 1.0f / l_i;
#pragma unroll
FA_UNROLL
for (int i = 0; i < DV_VEC; ++i) {
o_row[i] = CONVERT_DATA4(o_acc[i] * l_inv);
}
} else {
#pragma unroll
FA_UNROLL
for (int i = 0; i < DV_VEC; ++i) {
o_row[i] = (DATA_TYPE4)(0.0f);
}
@@ -259,7 +294,7 @@ __kernel void flash_attn_f32_q1(
ACC_TYPE4 q_priv[DK_VEC];
const ulong q_row_offset = batch_idx * q_nb3 + head_idx * q_nb2;
const global DATA_TYPE4* q_ptr = (const global DATA_TYPE4*)(q_base + q_row_offset);
#pragma unroll
FA_UNROLL
for (int i = 0; i < DK_VEC; ++i) {
q_priv[i] = CONVERT_ACC4(q_ptr[i]);
}
@@ -271,12 +306,12 @@ __kernel void flash_attn_f32_q1(
sinks_ptr = (const global ACC_TYPE*)((const global char*)sinks_void + sinks_offset);
}
ACC_TYPE m_i = (sinks_ptr != NULL) ? sinks_ptr[head_idx] : -INFINITY;
ACC_TYPE m_i = (sinks_ptr != NULL) ? sinks_ptr[head_idx] : FA_M_INIT;
for (int k_idx = tid; k_idx < n_kv; k_idx += Q1_WG_SIZE) {
const ulong k_row_offset = batch_idx * k_nb3 + head_kv_idx * k_nb2 + k_idx * k_nb1;
const global DATA_TYPE4* k_ptr = (const global DATA_TYPE4*)(k_base + k_row_offset);
ACC_TYPE4 dot_acc = (ACC_TYPE4)(0.0f);
#pragma unroll
FA_UNROLL
for (int k = 0; k < DK_VEC; k++) {
dot_acc = mad(q_priv[k], CONVERT_ACC4(k_ptr[k]), dot_acc);
}
@@ -294,7 +329,7 @@ __kernel void flash_attn_f32_q1(
__local ACC_TYPE local_m[Q1_WG_SIZE];
local_m[tid] = m_i;
barrier(CLK_LOCAL_MEM_FENCE);
#pragma unroll
FA_UNROLL
for (int s = Q1_WG_SIZE / 2; s > 0; s >>= 1) {
if (tid < s) local_m[tid] = max(local_m[tid], local_m[tid + s]);
barrier(CLK_LOCAL_MEM_FENCE);
@@ -302,7 +337,7 @@ __kernel void flash_attn_f32_q1(
const ACC_TYPE m_final = local_m[0];
ACC_TYPE4 o_acc[DV_VEC];
#pragma unroll
FA_UNROLL
for (int i = 0; i < DV_VEC; ++i) o_acc[i] = (ACC_TYPE4)(0.0f);
ACC_TYPE l_i = 0.0f;
@@ -312,7 +347,7 @@ __kernel void flash_attn_f32_q1(
const global DATA_TYPE4* k_ptr = (const global DATA_TYPE4*)(k_base + k_row_offset);
const global DATA_TYPE4* v_ptr = (const global DATA_TYPE4*)(v_base + v_row_offset);
ACC_TYPE4 dot_acc = (ACC_TYPE4)(0.0f);
#pragma unroll
FA_UNROLL
for (int k = 0; k < DK_VEC; k++) {
dot_acc = mad(q_priv[k], CONVERT_ACC4(k_ptr[k]), dot_acc);
}
@@ -326,7 +361,7 @@ __kernel void flash_attn_f32_q1(
}
const ACC_TYPE p = exp(score - m_final);
l_i += p;
#pragma unroll
FA_UNROLL
for (int i = 0; i < DV_VEC; i++) {
o_acc[i] = mad(p, CONVERT_ACC4(v_ptr[i]), o_acc[i]);
}
@@ -336,7 +371,7 @@ __kernel void flash_attn_f32_q1(
__local ACC_TYPE4 local_o_comp[Q1_WG_SIZE];
local_l[tid] = l_i;
barrier(CLK_LOCAL_MEM_FENCE);
#pragma unroll
FA_UNROLL
for (int s = Q1_WG_SIZE / 2; s > 0; s >>= 1) {
if (tid < s) local_l[tid] += local_l[tid + s];
barrier(CLK_LOCAL_MEM_FENCE);
@@ -355,7 +390,7 @@ __kernel void flash_attn_f32_q1(
for (int i = 0; i < DV_VEC; i++) {
local_o_comp[tid] = o_acc[i];
barrier(CLK_LOCAL_MEM_FENCE);
#pragma unroll
FA_UNROLL
for (int s = Q1_WG_SIZE / 2; s > 0; s >>= 1) {
if (tid < s) local_o_comp[tid] += local_o_comp[tid + s];
barrier(CLK_LOCAL_MEM_FENCE);
@@ -365,7 +400,7 @@ __kernel void flash_attn_f32_q1(
}
}
} else if (tid == 0) {
#pragma unroll
FA_UNROLL
for (int i = 0; i < DV_VEC; ++i) o_row[i] = (DATA_TYPE4)(0.0f);
}
}
@@ -1,5 +1,13 @@
#pragma OPENCL EXTENSION cl_khr_fp16 : enable
#ifdef cl_khr_subgroup_shuffle
#pragma OPENCL EXTENSION cl_khr_subgroup_shuffle : enable
#define HAS_SUBGROUP_SHUFFLE 1
#elif defined(cl_qcom_subgroup_shuffle)
#pragma OPENCL EXTENSION cl_qcom_subgroup_shuffle : enable
#define HAS_SUBGROUP_SHUFFLE 1
#endif
#define ACC_TYPE float
#define ACC_TYPE4 float4
#define Q_DATA_TYPE4 float4
@@ -12,9 +20,34 @@
#define DK_VEC (DK/4)
#define DV_VEC (DV/4)
#define WG_SIZE (BLOCK_M)
#define Q1_WG_SIZE 64
// The kernels are built with -cl-finite-math-only. On some older Adreno GPUs,
// infinite operand can cause undefined behavior and miscompilation for exp.
// Therefore, a large negative value is used instead.
#define FA_M_INIT (-3.0e38f)
// Drop full unroll at DK>=192 Adreno compiler host-memory budget.
#if DK >= 192
#define FA_UNROLL
#else
#define FA_UNROLL _Pragma("unroll")
#endif
// N_SPLIT>1 splits DK/DV across threads to cut per-thread register use.
#ifndef N_SPLIT
#define N_SPLIT 1
#endif
#define SPLIT_DK_VEC (DK_VEC / N_SPLIT)
#define SPLIT_DV_VEC (DV_VEC / N_SPLIT)
#if N_SPLIT > 1
#define WG_SIZE (BLOCK_M * N_SPLIT)
#else
#define WG_SIZE (BLOCK_M)
#endif
inline float get_alibi_slope(
const float max_bias, const uint h, const uint n_head_log2, const float m0, const float m1
) {
@@ -54,19 +87,38 @@ __kernel void flash_attn_f32_f16(
const int mask_ne2,
const int mask_ne3,
const global void* sinks_void,
const ulong sinks_offset
const ulong sinks_offset,
const global void * k_pad_void,
const global void * v_pad_void,
const global void * mask_pad_void,
const global char * blk,
const int n_kv_blocks,
const ulong mask_pad_nb1,
const ulong mask_pad_nb2,
const ulong mask_pad_nb3
) {
const int tid = get_local_id(0);
const int block_q_idx = get_group_id(0);
const int head_batch_idx = get_global_id(1);
const int my_query_row = block_q_idx * BLOCK_M + tid;
#if N_SPLIT > 1
const int q_lane = tid / N_SPLIT;
const int split_idx = tid % N_SPLIT;
#else
const int q_lane = tid;
const int split_idx = 0;
#endif
const int my_query_row = block_q_idx * BLOCK_M + q_lane;
const int query_valid = my_query_row < n_q;
const int batch_idx = head_batch_idx / n_head;
const int head_idx = head_batch_idx % n_head;
const int gqa_ratio = n_head / n_head_kv;
const int head_kv_idx = head_idx / gqa_ratio;
const int mask_head_idx = mask_void != NULL ? head_idx % mask_ne2 : 0;
const int mask_batch_idx = mask_void != NULL ? batch_idx % mask_ne3 : 0;
const global char* q_base = (const global char*)q_void + q_offset;
const global char* k_base = (const global char*)k_void + k_offset;
@@ -75,27 +127,41 @@ __kernel void flash_attn_f32_f16(
const global char* mask_base = NULL;
if (mask_void != NULL) {
const int mask_head_idx = head_idx % mask_ne2;
const int mask_batch_idx = batch_idx % mask_ne3;
mask_base = (const global char*)mask_void + mask_offset + mask_batch_idx * mask_nb3 + mask_head_idx * mask_nb2;
}
const global char* mask_pad_base = NULL;
if (mask_pad_void != NULL) {
mask_pad_base = (const global char*)mask_pad_void + mask_batch_idx * mask_pad_nb3 + mask_head_idx * mask_pad_nb2;
}
const global char* blk_base = NULL;
if (blk != NULL) {
const int n_q_blocks = (n_q + BLOCK_M - 1) / BLOCK_M;
blk_base = blk + (((mask_batch_idx * mask_ne2) + mask_head_idx) * n_q_blocks + block_q_idx) * n_kv_blocks;
}
ACC_TYPE4 q_priv[DK_VEC];
if (my_query_row < n_q) {
ACC_TYPE4 q_priv[SPLIT_DK_VEC];
const int dk_off = split_idx * SPLIT_DK_VEC;
if (query_valid) {
const ulong q_row_offset = batch_idx * q_nb3 + head_idx * q_nb2 + my_query_row * q_nb1;
const global Q_DATA_TYPE4* q_ptr = (const global Q_DATA_TYPE4*)(q_base + q_row_offset);
#pragma unroll
for (int i = 0; i < DK_VEC; ++i) {
q_priv[i] = CONVERT_Q_ACC4(q_ptr[i]);
FA_UNROLL
for (int i = 0; i < SPLIT_DK_VEC; ++i) {
q_priv[i] = CONVERT_Q_ACC4(q_ptr[dk_off + i]);
}
} else {
FA_UNROLL
for (int i = 0; i < SPLIT_DK_VEC; ++i) {
q_priv[i] = (ACC_TYPE4)(0.0f);
}
}
ACC_TYPE4 o_acc[DV_VEC];
#pragma unroll
for (int i = 0; i < DV_VEC; ++i) {
ACC_TYPE4 o_acc[SPLIT_DV_VEC];
FA_UNROLL
for (int i = 0; i < SPLIT_DV_VEC; ++i) {
o_acc[i] = (ACC_TYPE4)(0.0f);
}
ACC_TYPE m_i = -INFINITY;
ACC_TYPE m_i = FA_M_INIT;
ACC_TYPE l_i = 0.0f;
float slope = get_alibi_slope(max_bias, head_idx, n_head_log2, m0, m1);
@@ -103,86 +169,369 @@ __kernel void flash_attn_f32_f16(
__local KV_DATA_TYPE4 l_k[BLOCK_N][DK_VEC];
__local KV_DATA_TYPE4 l_v[BLOCK_N][DV_VEC];
#if N_SPLIT > 1 && !defined(HAS_SUBGROUP_SHUFFLE)
__local ACC_TYPE local_partial[BLOCK_N][WG_SIZE];
__local ACC_TYPE local_p[BLOCK_M][BLOCK_N];
__local ACC_TYPE local_softmax_scale[BLOCK_M];
__local ACC_TYPE local_l_inv[BLOCK_M];
#endif
for (int k_start = 0; k_start < n_kv; k_start += BLOCK_N) {
char blk_cur = 1;
if (blk_base != NULL) {
blk_cur = blk_base[k_start / BLOCK_N];
if (blk_cur == 0) continue;
}
const int use_kv_pad = k_pad_void != NULL && k_start + BLOCK_N > n_kv;
const int k_tile_start = use_kv_pad ? 0 : k_start;
const ulong k_tile_nb2 = use_kv_pad ? (ulong) BLOCK_N * k_nb1 : k_nb2;
const ulong k_tile_nb3 = use_kv_pad ? (ulong) n_head_kv * k_tile_nb2 : k_nb3;
const ulong v_tile_nb2 = use_kv_pad ? (ulong) BLOCK_N * v_nb1 : v_nb2;
const ulong v_tile_nb3 = use_kv_pad ? (ulong) n_head_kv * v_tile_nb2 : v_nb3;
const global char* k_tile_base = use_kv_pad ? (const global char*) k_pad_void : k_base;
const global char* v_tile_base = use_kv_pad ? (const global char*) v_pad_void : v_base;
for (int i = tid; i < BLOCK_N * DK_VEC; i += WG_SIZE) {
const int row = i / DK_VEC;
const int col = i % DK_VEC;
const int k_row_idx = k_start + row;
if (k_row_idx < n_kv) {
const ulong k_row_offset = batch_idx * k_nb3 + head_kv_idx * k_nb2 + k_row_idx * k_nb1;
l_k[row][col] = ((__global KV_DATA_TYPE4*)(k_base + k_row_offset))[col];
const int k_row_idx = k_tile_start + row;
if (use_kv_pad || k_row_idx < n_kv) {
const ulong k_row_offset = batch_idx * k_tile_nb3 + head_kv_idx * k_tile_nb2 + k_row_idx * k_nb1;
l_k[row][col] = ((__global KV_DATA_TYPE4*)(k_tile_base + k_row_offset))[col];
} else {
l_k[row][col] = (KV_DATA_TYPE4)(0.0h);
}
}
for (int i = tid; i < BLOCK_N * DV_VEC; i += WG_SIZE) {
const int row = i / DV_VEC;
const int col = i % DV_VEC;
const int v_row_idx = k_start + row;
if (v_row_idx < n_kv) {
const ulong v_row_offset = batch_idx * v_nb3 + head_kv_idx * v_nb2 + v_row_idx * v_nb1;
l_v[row][col] = ((__global KV_DATA_TYPE4*)(v_base + v_row_offset))[col];
const int v_row_idx = k_tile_start + row;
if (use_kv_pad || v_row_idx < n_kv) {
const ulong v_row_offset = batch_idx * v_tile_nb3 + head_kv_idx * v_tile_nb2 + v_row_idx * v_nb1;
l_v[row][col] = ((__global KV_DATA_TYPE4*)(v_tile_base + v_row_offset))[col];
} else {
l_v[row][col] = (KV_DATA_TYPE4)(0.0h);
}
}
barrier(CLK_LOCAL_MEM_FENCE);
if (my_query_row >= n_q) {
continue;
#if N_SPLIT > 1 && defined(HAS_SUBGROUP_SHUFFLE)
{
const int dv_off = split_idx * SPLIT_DV_VEC;
for (int j = 0; j < BLOCK_N; j += 2) {
const int k_row0 = k_start + j;
const int k_row1 = k_start + j + 1;
ACC_TYPE partial0 = 0.0f;
ACC_TYPE partial1 = 0.0f;
FA_UNROLL
for (int k = 0; k < SPLIT_DK_VEC; k++) {
const ACC_TYPE4 qk = q_priv[k];
ACC_TYPE4 dot0 = qk * CONVERT_KV_ACC4(l_k[j ][dk_off + k]);
ACC_TYPE4 dot1 = qk * CONVERT_KV_ACC4(l_k[j+1][dk_off + k]);
partial0 += dot0.s0 + dot0.s1 + dot0.s2 + dot0.s3;
partial1 += dot1.s0 + dot1.s1 + dot1.s2 + dot1.s3;
}
FA_UNROLL
for (int step = 1; step < N_SPLIT; step <<= 1) {
partial0 += sub_group_shuffle_xor(partial0, step);
partial1 += sub_group_shuffle_xor(partial1, step);
}
ACC_TYPE score0 = partial0 * scale;
ACC_TYPE score1 = partial1 * scale;
if (!query_valid) { score0 = FA_M_INIT; score1 = FA_M_INIT; }
if (is_causal) {
if (k_row0 > (n_kv - n_q + my_query_row)) score0 = FA_M_INIT;
if (k_row1 > (n_kv - n_q + my_query_row)) score1 = FA_M_INIT;
}
if (k_row0 >= n_kv) score0 = FA_M_INIT;
if (k_row1 >= n_kv) score1 = FA_M_INIT;
if (query_valid && mask_base != NULL && blk_cur != 2) {
if (use_kv_pad && mask_pad_base != NULL) {
const global MASK_DATA_TYPE* mask_ptr =
(const global MASK_DATA_TYPE*)(mask_pad_base + my_query_row * mask_pad_nb1);
score0 += slope * (ACC_TYPE)mask_ptr[j];
score1 += slope * (ACC_TYPE)mask_ptr[j + 1];
} else {
const global MASK_DATA_TYPE* mask_ptr =
(const global MASK_DATA_TYPE*)(mask_base + my_query_row * mask_nb1);
if (k_row0 < n_kv) score0 += slope * (ACC_TYPE)mask_ptr[k_row0];
if (k_row1 < n_kv) score1 += slope * (ACC_TYPE)mask_ptr[k_row1];
}
}
if (logit_softcap > 0.0f) {
score0 = logit_softcap * tanh(score0 / logit_softcap);
score1 = logit_softcap * tanh(score1 / logit_softcap);
}
const ACC_TYPE m_new = max(m_i, max(score0, score1));
// Whole tile masked (m_new == FA_M_INIT): force the exp() args
// far negative so the tile contributes 0, not exp(0)=1.
const ACC_TYPE m_exp = (m_new == FA_M_INIT) ? 0.0f : m_new;
const ACC_TYPE sp = native_exp(m_i - m_exp);
const ACC_TYPE p0 = native_exp(score0 - m_exp);
const ACC_TYPE p1 = native_exp(score1 - m_exp);
FA_UNROLL
for (int i = 0; i < SPLIT_DV_VEC; ++i) {
o_acc[i] = o_acc[i] * sp
+ p0 * CONVERT_KV_ACC4(l_v[j ][dv_off + i])
+ p1 * CONVERT_KV_ACC4(l_v[j+1][dv_off + i]);
}
l_i = l_i * sp + p0 + p1;
m_i = m_new;
}
}
for (int j = 0; j < BLOCK_N; j += 2) {
const int k_row0 = k_start + j;
const int k_row1 = k_start + j + 1;
ACC_TYPE4 dot_acc0 = (ACC_TYPE4)(0.0f);
ACC_TYPE4 dot_acc1 = (ACC_TYPE4)(0.0f);
#pragma unroll
for (int k = 0; k < DK_VEC; k++) {
dot_acc0 = mad(q_priv[k], CONVERT_KV_ACC4(l_k[j][k]), dot_acc0);
dot_acc1 = mad(q_priv[k], CONVERT_KV_ACC4(l_k[j+1][k]), dot_acc1);
#elif N_SPLIT > 1
// N_SPLIT>1 fallback (no shuffle): 3-phase local-memory reduction.
// Phase 1 partial dots for all BLOCK_N tokens.
for (int j = 0; j < BLOCK_N; ++j) {
ACC_TYPE4 dot_acc = (ACC_TYPE4)(0.0f);
FA_UNROLL
for (int k = 0; k < SPLIT_DK_VEC; k++) {
dot_acc = mad(q_priv[k], CONVERT_KV_ACC4(l_k[j][dk_off + k]), dot_acc);
}
ACC_TYPE score0 = (dot_acc0.s0 + dot_acc0.s1 + dot_acc0.s2 + dot_acc0.s3) * scale;
ACC_TYPE score1 = (dot_acc1.s0 + dot_acc1.s1 + dot_acc1.s2 + dot_acc1.s3) * scale;
if (is_causal) {
if (k_row0 > (n_kv - n_q + my_query_row)) score0 = -INFINITY;
if (k_row1 > (n_kv - n_q + my_query_row)) score1 = -INFINITY;
}
if (k_row0 >= n_kv) score0 = -INFINITY;
if (k_row1 >= n_kv) score1 = -INFINITY;
if (mask_base != NULL) {
const global MASK_DATA_TYPE* mask_ptr = (const global MASK_DATA_TYPE*)(mask_base + my_query_row * mask_nb1);
if (k_row0 < n_kv) score0 += slope * (ACC_TYPE)mask_ptr[k_row0];
if (k_row1 < n_kv) score1 += slope * (ACC_TYPE)mask_ptr[k_row1];
}
if (logit_softcap > 0.0f) {
score0 = logit_softcap * tanh(score0 / logit_softcap);
score1 = logit_softcap * tanh(score1 / logit_softcap);
}
const ACC_TYPE m_new = max(m_i, max(score0, score1));
const ACC_TYPE p0 = exp(score0 - m_new);
const ACC_TYPE p1 = exp(score1 - m_new);
const ACC_TYPE scale_prev = exp(m_i - m_new);
#pragma unroll
for (int i = 0; i < DV_VEC; ++i) {
o_acc[i] = o_acc[i] * scale_prev + p0 * CONVERT_KV_ACC4(l_v[j][i]) + p1 * CONVERT_KV_ACC4(l_v[j+1][i]);
}
l_i = l_i * scale_prev + p0 + p1;
m_i = m_new;
local_partial[j][tid] =
dot_acc.s0 + dot_acc.s1 + dot_acc.s2 + dot_acc.s3;
}
barrier(CLK_LOCAL_MEM_FENCE); // 1 barrier: partial dots visible
// Phase 2 split_idx==0 reduces partial sums and computes block softmax.
if (split_idx == 0) {
if (query_valid) {
ACC_TYPE m_new = m_i;
for (int j = 0; j < BLOCK_N; ++j) {
const int k_row = k_start + j;
ACC_TYPE score = 0.0f;
FA_UNROLL
for (int s = 0; s < N_SPLIT; s++) {
score += local_partial[j][q_lane * N_SPLIT + s];
}
score *= scale;
if (is_causal && k_row > (n_kv - n_q + my_query_row)) score = FA_M_INIT;
if (k_row >= n_kv) score = FA_M_INIT;
if (mask_base != NULL && blk_cur != 2) {
if (use_kv_pad && mask_pad_base != NULL) {
const global MASK_DATA_TYPE* mask_ptr =
(const global MASK_DATA_TYPE*)(mask_pad_base + my_query_row * mask_pad_nb1);
score += slope * (ACC_TYPE)mask_ptr[j];
} else {
const global MASK_DATA_TYPE* mask_ptr =
(const global MASK_DATA_TYPE*)(mask_base + my_query_row * mask_nb1);
if (k_row < n_kv) score += slope * (ACC_TYPE)mask_ptr[k_row];
}
}
if (logit_softcap > 0.0f) {
score = logit_softcap * tanh(score / logit_softcap);
}
m_new = max(m_new, score);
local_p[q_lane][j] = score;
}
const ACC_TYPE m_exp = (m_new == FA_M_INIT) ? 0.0f : m_new;
const ACC_TYPE sp = native_exp(m_i - m_exp);
ACC_TYPE l_new = l_i * sp;
for (int j = 0; j < BLOCK_N; ++j) {
const ACC_TYPE p = native_exp(local_p[q_lane][j] - m_exp);
local_p[q_lane][j] = p;
l_new += p;
}
local_softmax_scale[q_lane] = sp;
l_i = l_new;
m_i = m_new;
} else {
local_softmax_scale[q_lane] = 1.0f;
for (int j = 0; j < BLOCK_N; ++j) local_p[q_lane][j] = 0.0f;
}
}
barrier(CLK_LOCAL_MEM_FENCE);
// Phase 3 V accumulate using broadcast probabilities.
{
const ACC_TYPE sp_block = local_softmax_scale[q_lane];
const int dv_off = split_idx * SPLIT_DV_VEC;
FA_UNROLL
for (int i = 0; i < SPLIT_DV_VEC; ++i) {
o_acc[i] *= sp_block;
}
for (int j = 0; j < BLOCK_N; ++j) {
const ACC_TYPE p = local_p[q_lane][j];
FA_UNROLL
for (int i = 0; i < SPLIT_DV_VEC; ++i) {
o_acc[i] = mad(p, CONVERT_KV_ACC4(l_v[j][dv_off + i]), o_acc[i]);
}
}
}
#else
// N_SPLIT==1: j+=4 unroll. Requires BLOCK_N % 4 == 0.
if (query_valid) {
for (int j = 0; j < BLOCK_N; j += 4) {
const int k_row0 = k_start + j;
const int k_row1 = k_start + j + 1;
const int k_row2 = k_start + j + 2;
const int k_row3 = k_start + j + 3;
ACC_TYPE4 dot_acc0 = (ACC_TYPE4)(0.0f);
ACC_TYPE4 dot_acc1 = (ACC_TYPE4)(0.0f);
ACC_TYPE4 dot_acc2 = (ACC_TYPE4)(0.0f);
ACC_TYPE4 dot_acc3 = (ACC_TYPE4)(0.0f);
FA_UNROLL
for (int k = 0; k < DK_VEC; k++) {
const ACC_TYPE4 qk = q_priv[k];
dot_acc0 = mad(qk, CONVERT_KV_ACC4(l_k[j][k]), dot_acc0);
dot_acc1 = mad(qk, CONVERT_KV_ACC4(l_k[j+1][k]), dot_acc1);
dot_acc2 = mad(qk, CONVERT_KV_ACC4(l_k[j+2][k]), dot_acc2);
dot_acc3 = mad(qk, CONVERT_KV_ACC4(l_k[j+3][k]), dot_acc3);
}
ACC_TYPE s0 = (dot_acc0.s0 + dot_acc0.s1 + dot_acc0.s2 + dot_acc0.s3) * scale;
ACC_TYPE s1 = (dot_acc1.s0 + dot_acc1.s1 + dot_acc1.s2 + dot_acc1.s3) * scale;
ACC_TYPE s2 = (dot_acc2.s0 + dot_acc2.s1 + dot_acc2.s2 + dot_acc2.s3) * scale;
ACC_TYPE s3 = (dot_acc3.s0 + dot_acc3.s1 + dot_acc3.s2 + dot_acc3.s3) * scale;
if (is_causal) {
const int causal_limit = n_kv - n_q + my_query_row;
if (k_row0 > causal_limit) s0 = FA_M_INIT;
if (k_row1 > causal_limit) s1 = FA_M_INIT;
if (k_row2 > causal_limit) s2 = FA_M_INIT;
if (k_row3 > causal_limit) s3 = FA_M_INIT;
}
if (k_row0 >= n_kv) s0 = FA_M_INIT;
if (k_row1 >= n_kv) s1 = FA_M_INIT;
if (k_row2 >= n_kv) s2 = FA_M_INIT;
if (k_row3 >= n_kv) s3 = FA_M_INIT;
if (mask_base != NULL && blk_cur != 2) {
if (use_kv_pad && mask_pad_base != NULL) {
const global MASK_DATA_TYPE* mask_ptr = (const global MASK_DATA_TYPE*)(mask_pad_base + my_query_row * mask_pad_nb1);
s0 += slope * (ACC_TYPE)mask_ptr[j];
s1 += slope * (ACC_TYPE)mask_ptr[j + 1];
s2 += slope * (ACC_TYPE)mask_ptr[j + 2];
s3 += slope * (ACC_TYPE)mask_ptr[j + 3];
} else {
const global MASK_DATA_TYPE* mask_ptr = (const global MASK_DATA_TYPE*)(mask_base + my_query_row * mask_nb1);
if (k_row0 < n_kv) s0 += slope * (ACC_TYPE)mask_ptr[k_row0];
if (k_row1 < n_kv) s1 += slope * (ACC_TYPE)mask_ptr[k_row1];
if (k_row2 < n_kv) s2 += slope * (ACC_TYPE)mask_ptr[k_row2];
if (k_row3 < n_kv) s3 += slope * (ACC_TYPE)mask_ptr[k_row3];
}
}
if (logit_softcap > 0.0f) {
s0 = logit_softcap * tanh(s0 / logit_softcap);
s1 = logit_softcap * tanh(s1 / logit_softcap);
s2 = logit_softcap * tanh(s2 / logit_softcap);
s3 = logit_softcap * tanh(s3 / logit_softcap);
}
const ACC_TYPE m_new = max(m_i, max(max(s0, s1), max(s2, s3)));
// Whole tile masked (m_new == FA_M_INIT): force the exp() args
// far negative so the tile contributes 0, not exp(0)=1.
const ACC_TYPE m_exp = (m_new == FA_M_INIT) ? 0.0f : m_new;
const ACC_TYPE scale_prev = native_exp(m_i - m_exp);
const ACC_TYPE p0 = native_exp(s0 - m_exp);
const ACC_TYPE p1 = native_exp(s1 - m_exp);
const ACC_TYPE p2 = native_exp(s2 - m_exp);
const ACC_TYPE p3 = native_exp(s3 - m_exp);
FA_UNROLL
for (int i = 0; i < DV_VEC; ++i) {
o_acc[i] = mad(p3, CONVERT_KV_ACC4(l_v[j+3][i]),
mad(p2, CONVERT_KV_ACC4(l_v[j+2][i]),
mad(p1, CONVERT_KV_ACC4(l_v[j+1][i]),
mad(p0, CONVERT_KV_ACC4(l_v[j][i]),
o_acc[i] * scale_prev))));
}
l_i = l_i * scale_prev + p0 + p1 + p2 + p3;
m_i = m_new;
}
}
#endif
// End of tile: every thread must finish reading l_k/l_v before the
// next iteration's load overwrites them (WAR hazard on local memory).
barrier(CLK_LOCAL_MEM_FENCE);
}
if (my_query_row < n_q) {
// Write output.
#if N_SPLIT > 1 && defined(HAS_SUBGROUP_SHUFFLE)
if (query_valid) {
ACC_TYPE sinks_sp = 1.0f;
if (sinks_void != NULL) {
const global ACC_TYPE* sinks_ptr = (const global ACC_TYPE*)((const global char*)sinks_void + sinks_offset);
const ACC_TYPE m_sink = sinks_ptr[head_idx];
const ACC_TYPE m_final = max(m_i, m_sink);
sinks_sp = exp(m_i - m_final);
l_i = l_i * sinks_sp + exp(m_sink - m_final);
m_i = m_final;
}
const ACC_TYPE l_inv = (l_i > 0.0f) ? (1.0f / l_i) : 0.0f;
const int dv_off = split_idx * SPLIT_DV_VEC;
const ulong o_row_offset = batch_idx * o_nb3 + my_query_row * o_nb2 + head_idx * o_nb1;
global O_DATA_TYPE4 *o_row = (global O_DATA_TYPE4 *)(o_base + o_row_offset);
if (l_inv > 0.0f) {
FA_UNROLL
for (int i = 0; i < SPLIT_DV_VEC; ++i) {
o_row[dv_off + i] = CONVERT_O_DATA4(o_acc[i] * sinks_sp * l_inv);
}
} else {
FA_UNROLL
for (int i = 0; i < SPLIT_DV_VEC; ++i) {
o_row[dv_off + i] = (O_DATA_TYPE4)(0.0f);
}
}
}
#elif N_SPLIT > 1
if (split_idx == 0) {
ACC_TYPE sinks_sp = 1.0f;
if (query_valid && sinks_void != NULL) {
const global ACC_TYPE* sinks_ptr = (const global ACC_TYPE*)((const global char*)sinks_void + sinks_offset);
const ACC_TYPE m_sink = sinks_ptr[head_idx];
const ACC_TYPE m_final = max(m_i, m_sink);
sinks_sp = exp(m_i - m_final);
l_i = l_i * sinks_sp + exp(m_sink - m_final);
m_i = m_final;
}
local_softmax_scale[q_lane] = sinks_sp;
local_l_inv[q_lane] = (query_valid && l_i > 0.0f) ? (1.0f / l_i) : 0.0f;
}
barrier(CLK_LOCAL_MEM_FENCE);
if (query_valid) {
const ACC_TYPE sinks_sp = local_softmax_scale[q_lane];
const ACC_TYPE l_inv = local_l_inv[q_lane];
const int dv_off = split_idx * SPLIT_DV_VEC;
const ulong o_row_offset = batch_idx * o_nb3 + my_query_row * o_nb2 + head_idx * o_nb1;
global O_DATA_TYPE4 *o_row = (global O_DATA_TYPE4 *)(o_base + o_row_offset);
if (l_inv > 0.0f) {
FA_UNROLL
for (int i = 0; i < SPLIT_DV_VEC; ++i) {
o_row[dv_off + i] = CONVERT_O_DATA4(o_acc[i] * sinks_sp * l_inv);
}
} else {
FA_UNROLL
for (int i = 0; i < SPLIT_DV_VEC; ++i) {
o_row[dv_off + i] = (O_DATA_TYPE4)(0.0f);
}
}
}
#else
if (query_valid) {
if (sinks_void != NULL) {
const global ACC_TYPE* sinks_ptr = (const global ACC_TYPE*)((const global char*)sinks_void + sinks_offset);
const ACC_TYPE m_sink = sinks_ptr[head_idx];
const ACC_TYPE m_final = max(m_i, m_sink);
const ACC_TYPE scale_o = exp(m_i - m_final);
#pragma unroll
FA_UNROLL
for (int i = 0; i < DV_VEC; ++i) {
o_acc[i] *= scale_o;
}
@@ -194,17 +543,18 @@ __kernel void flash_attn_f32_f16(
global O_DATA_TYPE4 *o_row = (global O_DATA_TYPE4 *)(o_base + o_row_offset);
if (l_i > 0.0f) {
const ACC_TYPE l_inv = 1.0f / l_i;
#pragma unroll
FA_UNROLL
for (int i = 0; i < DV_VEC; ++i) {
o_row[i] = CONVERT_O_DATA4(o_acc[i] * l_inv);
}
} else {
#pragma unroll
FA_UNROLL
for (int i = 0; i < DV_VEC; ++i) {
o_row[i] = (O_DATA_TYPE4)(0.0f);
}
}
}
#endif
}
__kernel void flash_attn_f32_f16_q1(
@@ -258,13 +608,16 @@ __kernel void flash_attn_f32_f16_q1(
mask_base = (const global char*)mask_void + mask_offset + mask_batch_idx * mask_nb3 + mask_head_idx * mask_nb2;
}
ACC_TYPE4 q_priv[DK_VEC];
// Q is uniform across WG threads (n_q=1). Share via local memory to
// avoid per-thread q_priv[DK_VEC] dynamic-indexed private array that
// spills to DDR on Adreno.
__local ACC_TYPE4 q_shared[DK_VEC];
const ulong q_row_offset = batch_idx * q_nb3 + head_idx * q_nb2;
const global Q_DATA_TYPE4* q_ptr = (const global Q_DATA_TYPE4*)(q_base + q_row_offset);
#pragma unroll
for (int i = 0; i < DK_VEC; ++i) {
q_priv[i] = CONVERT_Q_ACC4(q_ptr[i]);
for (int i = tid; i < DK_VEC; i += Q1_WG_SIZE) {
q_shared[i] = CONVERT_Q_ACC4(q_ptr[i]);
}
barrier(CLK_LOCAL_MEM_FENCE);
float slope = get_alibi_slope(max_bias, head_idx, n_head_log2, m0, m1);
@@ -273,14 +626,14 @@ __kernel void flash_attn_f32_f16_q1(
sinks_ptr = (const global ACC_TYPE*)((const global char*)sinks_void + sinks_offset);
}
ACC_TYPE m_i = (sinks_ptr != NULL) ? sinks_ptr[head_idx] : -INFINITY;
ACC_TYPE m_i = (sinks_ptr != NULL) ? sinks_ptr[head_idx] : FA_M_INIT;
for (int k_idx = tid; k_idx < n_kv; k_idx += Q1_WG_SIZE) {
const ulong k_row_offset = batch_idx * k_nb3 + head_kv_idx * k_nb2 + k_idx * k_nb1;
const global KV_DATA_TYPE4* k_ptr = (const global KV_DATA_TYPE4*)(k_base + k_row_offset);
ACC_TYPE4 dot_acc = (ACC_TYPE4)(0.0f);
#pragma unroll
FA_UNROLL
for (int k = 0; k < DK_VEC; k++) {
dot_acc = mad(q_priv[k], CONVERT_KV_ACC4(k_ptr[k]), dot_acc);
dot_acc = mad(q_shared[k], CONVERT_KV_ACC4(k_ptr[k]), dot_acc);
}
ACC_TYPE score = (dot_acc.s0 + dot_acc.s1 + dot_acc.s2 + dot_acc.s3) * scale;
if (mask_base != NULL) {
@@ -296,7 +649,7 @@ __kernel void flash_attn_f32_f16_q1(
__local ACC_TYPE local_m[Q1_WG_SIZE];
local_m[tid] = m_i;
barrier(CLK_LOCAL_MEM_FENCE);
#pragma unroll
FA_UNROLL
for (int s = Q1_WG_SIZE / 2; s > 0; s >>= 1) {
if (tid < s) local_m[tid] = max(local_m[tid], local_m[tid + s]);
barrier(CLK_LOCAL_MEM_FENCE);
@@ -304,7 +657,7 @@ __kernel void flash_attn_f32_f16_q1(
const ACC_TYPE m_final = local_m[0];
ACC_TYPE4 o_acc[DV_VEC];
#pragma unroll
FA_UNROLL
for (int i = 0; i < DV_VEC; ++i) o_acc[i] = (ACC_TYPE4)(0.0f);
ACC_TYPE l_i = 0.0f;
@@ -314,9 +667,9 @@ __kernel void flash_attn_f32_f16_q1(
const global KV_DATA_TYPE4* k_ptr = (const global KV_DATA_TYPE4*)(k_base + k_row_offset);
const global KV_DATA_TYPE4* v_ptr = (const global KV_DATA_TYPE4*)(v_base + v_row_offset);
ACC_TYPE4 dot_acc = (ACC_TYPE4)(0.0f);
#pragma unroll
FA_UNROLL
for (int k = 0; k < DK_VEC; k++) {
dot_acc = mad(q_priv[k], CONVERT_KV_ACC4(k_ptr[k]), dot_acc);
dot_acc = mad(q_shared[k], CONVERT_KV_ACC4(k_ptr[k]), dot_acc);
}
ACC_TYPE score = (dot_acc.s0 + dot_acc.s1 + dot_acc.s2 + dot_acc.s3) * scale;
if (mask_base != NULL) {
@@ -328,7 +681,7 @@ __kernel void flash_attn_f32_f16_q1(
}
const ACC_TYPE p = exp(score - m_final);
l_i += p;
#pragma unroll
FA_UNROLL
for (int i = 0; i < DV_VEC; i++) {
o_acc[i] = mad(p, CONVERT_KV_ACC4(v_ptr[i]), o_acc[i]);
}
@@ -338,7 +691,7 @@ __kernel void flash_attn_f32_f16_q1(
__local ACC_TYPE4 local_o_comp[Q1_WG_SIZE];
local_l[tid] = l_i;
barrier(CLK_LOCAL_MEM_FENCE);
#pragma unroll
FA_UNROLL
for (int s = Q1_WG_SIZE / 2; s > 0; s >>= 1) {
if (tid < s) local_l[tid] += local_l[tid + s];
barrier(CLK_LOCAL_MEM_FENCE);
@@ -357,7 +710,7 @@ __kernel void flash_attn_f32_f16_q1(
for (int i = 0; i < DV_VEC; i++) {
local_o_comp[tid] = o_acc[i];
barrier(CLK_LOCAL_MEM_FENCE);
#pragma unroll
FA_UNROLL
for (int s = Q1_WG_SIZE / 2; s > 0; s >>= 1) {
if (tid < s) local_o_comp[tid] += local_o_comp[tid + s];
barrier(CLK_LOCAL_MEM_FENCE);
@@ -367,7 +720,257 @@ __kernel void flash_attn_f32_f16_q1(
}
}
} else if (tid == 0) {
#pragma unroll
FA_UNROLL
for (int i = 0; i < DV_VEC; ++i) o_row[i] = (O_DATA_TYPE4)(0.0f);
}
}
// Flash-decoding split pass. gid(2) = q_idx * n_splits + split_idx.
// Partial record per split: [m, l, O[DV]]. Merge kernel applies sink + norm.
#define FA_PARTIAL_FLOATS (2 + DV)
__kernel void flash_attn_f32_f16_q1_split(
const global void * q_void, ulong q_offset,
const global void * k_void, ulong k_offset,
const global void * v_void, ulong v_offset,
const float scale,
const int n_q,
const int n_kv,
const int n_head,
const ulong q_nb1, const ulong q_nb2, const ulong q_nb3,
const ulong k_nb1, const ulong k_nb2, const ulong k_nb3,
const ulong v_nb1, const ulong v_nb2, const ulong v_nb3,
const float max_bias,
const float m0,
const float m1,
const int n_head_log2,
const float logit_softcap,
const int n_head_kv,
const global void * mask_void,
const ulong mask_offset,
const ulong mask_nb1,
const ulong mask_nb2,
const ulong mask_nb3,
const int mask_ne2,
const int mask_ne3,
global float * partial_void,
const int n_splits,
const int kv_per_split
) {
const int tid = get_local_id(0);
const int head_batch_idx = get_global_id(1);
const int split_q_idx = get_global_id(2);
const int split_idx = split_q_idx % n_splits;
const int q_idx = split_q_idx / n_splits;
const int batch_idx = head_batch_idx / n_head;
const int head_idx = head_batch_idx % n_head;
const int gqa_ratio = n_head / n_head_kv;
const int head_kv_idx = head_idx / gqa_ratio;
const int kv_start = split_idx * kv_per_split;
const int kv_end = min(kv_start + kv_per_split, n_kv);
const ulong record_stride = (ulong) FA_PARTIAL_FLOATS;
const ulong record_idx = ((((ulong) batch_idx * n_head + head_idx) * n_q + q_idx)
* n_splits + split_idx);
global float * rec = partial_void + record_idx * record_stride;
global float4 * rec_o = (global float4 *) (rec + 2);
if (kv_start >= kv_end) {
// Empty split: leave sentinel partial for merge.
if (tid == 0) {
rec[0] = FA_M_INIT;
rec[1] = 0.0f;
}
return;
}
const global char * q_base = (const global char *) q_void + q_offset;
const global char * k_base = (const global char *) k_void + k_offset;
const global char * v_base = (const global char *) v_void + v_offset;
const global char * mask_base = NULL;
if (mask_void != NULL) {
const int mask_head_idx = head_idx % mask_ne2;
const int mask_batch_idx = batch_idx % mask_ne3;
mask_base = (const global char *) mask_void + mask_offset +
mask_batch_idx * mask_nb3 + mask_head_idx * mask_nb2 +
(ulong) q_idx * mask_nb1;
}
// Share Q via local memory (n_q=1 per split -> uniform across WG).
__local ACC_TYPE4 q_shared[DK_VEC];
const ulong q_row_offset = batch_idx * q_nb3 + head_idx * q_nb2 + (ulong) q_idx * q_nb1;
const global Q_DATA_TYPE4 * q_ptr = (const global Q_DATA_TYPE4 *) (q_base + q_row_offset);
for (int i = tid; i < DK_VEC; i += Q1_WG_SIZE) {
q_shared[i] = CONVERT_Q_ACC4(q_ptr[i]);
}
barrier(CLK_LOCAL_MEM_FENCE);
const float slope = get_alibi_slope(max_bias, head_idx, n_head_log2, m0, m1);
// Pass 1a split-local max.
ACC_TYPE m_i = FA_M_INIT;
for (int k_idx = kv_start + tid; k_idx < kv_end; k_idx += Q1_WG_SIZE) {
const ulong k_row_offset = batch_idx * k_nb3 + head_kv_idx * k_nb2 + k_idx * k_nb1;
const global KV_DATA_TYPE4 * k_ptr = (const global KV_DATA_TYPE4 *) (k_base + k_row_offset);
ACC_TYPE4 dot_acc = (ACC_TYPE4)(0.0f);
#pragma unroll
for (int k = 0; k < DK_VEC; ++k) {
dot_acc = mad(q_shared[k], CONVERT_KV_ACC4(k_ptr[k]), dot_acc);
}
ACC_TYPE score = (dot_acc.s0 + dot_acc.s1 + dot_acc.s2 + dot_acc.s3) * scale;
if (mask_base != NULL) {
const global MASK_DATA_TYPE * mask_ptr = (const global MASK_DATA_TYPE *) (mask_base);
score += slope * (ACC_TYPE) mask_ptr[k_idx];
}
if (logit_softcap > 0.0f) {
score = logit_softcap * tanh(score / logit_softcap);
}
m_i = max(m_i, score);
}
__local ACC_TYPE local_m[Q1_WG_SIZE];
local_m[tid] = m_i;
barrier(CLK_LOCAL_MEM_FENCE);
#pragma unroll
for (int s = Q1_WG_SIZE / 2; s > 0; s >>= 1) {
if (tid < s) local_m[tid] = max(local_m[tid], local_m[tid + s]);
barrier(CLK_LOCAL_MEM_FENCE);
}
const ACC_TYPE m_c = local_m[0];
// Pass 1b softmax-weighted V accumulate.
ACC_TYPE4 o_acc[DV_VEC];
#pragma unroll
for (int i = 0; i < DV_VEC; ++i) o_acc[i] = (ACC_TYPE4)(0.0f);
ACC_TYPE l_i = 0.0f;
for (int k_idx = kv_start + tid; k_idx < kv_end; k_idx += Q1_WG_SIZE) {
const ulong k_row_offset = batch_idx * k_nb3 + head_kv_idx * k_nb2 + k_idx * k_nb1;
const ulong v_row_offset = batch_idx * v_nb3 + head_kv_idx * v_nb2 + k_idx * v_nb1;
const global KV_DATA_TYPE4 * k_ptr = (const global KV_DATA_TYPE4 *) (k_base + k_row_offset);
const global KV_DATA_TYPE4 * v_ptr = (const global KV_DATA_TYPE4 *) (v_base + v_row_offset);
ACC_TYPE4 dot_acc = (ACC_TYPE4)(0.0f);
#pragma unroll
for (int k = 0; k < DK_VEC; ++k) {
dot_acc = mad(q_shared[k], CONVERT_KV_ACC4(k_ptr[k]), dot_acc);
}
ACC_TYPE score = (dot_acc.s0 + dot_acc.s1 + dot_acc.s2 + dot_acc.s3) * scale;
if (mask_base != NULL) {
const global MASK_DATA_TYPE * mask_ptr = (const global MASK_DATA_TYPE *) (mask_base);
score += slope * (ACC_TYPE) mask_ptr[k_idx];
}
if (logit_softcap > 0.0f) {
score = logit_softcap * tanh(score / logit_softcap);
}
const ACC_TYPE p = exp(score - m_c);
l_i += p;
#pragma unroll
for (int i = 0; i < DV_VEC; ++i) {
o_acc[i] = mad(p, CONVERT_KV_ACC4(v_ptr[i]), o_acc[i]);
}
}
__local ACC_TYPE local_l[Q1_WG_SIZE];
__local ACC_TYPE4 local_o[Q1_WG_SIZE];
local_l[tid] = l_i;
barrier(CLK_LOCAL_MEM_FENCE);
#pragma unroll
for (int s = Q1_WG_SIZE / 2; s > 0; s >>= 1) {
if (tid < s) local_l[tid] += local_l[tid + s];
barrier(CLK_LOCAL_MEM_FENCE);
}
const ACC_TYPE l_c = local_l[0];
if (tid == 0) {
rec[0] = (float) m_c;
rec[1] = (float) l_c;
}
for (int i = 0; i < DV_VEC; ++i) {
local_o[tid] = o_acc[i];
barrier(CLK_LOCAL_MEM_FENCE);
#pragma unroll
for (int s = Q1_WG_SIZE / 2; s > 0; s >>= 1) {
if (tid < s) local_o[tid] += local_o[tid + s];
barrier(CLK_LOCAL_MEM_FENCE);
}
if (tid == 0) {
rec_o[i] = local_o[0];
}
}
}
// FD Pass 2: merge per-split partials into final O. Empty splits drop via exp(-INF)=0.
__kernel void flash_attn_f32_merge(
const global float * partial_void,
global void * o_void,
const ulong o_offset,
const int n_head,
const int n_splits,
const ulong o_nb1, const ulong o_nb2, const ulong o_nb3,
const global void * sinks_void,
const ulong sinks_offset,
const int n_q
) {
const int lane = get_local_id(0); // 0..DV_VEC-1
const int head_batch_idx = get_global_id(1);
const int q_idx = get_global_id(2);
const int batch_idx = head_batch_idx / n_head;
const int head_idx = head_batch_idx % n_head;
const ulong record_stride = (ulong) FA_PARTIAL_FLOATS;
const ulong record_idx_0 = (((ulong) batch_idx * n_head + head_idx) * n_q + q_idx) * n_splits;
const global float * rec0 = partial_void + record_idx_0 * record_stride;
__local ACC_TYPE m_final_shared;
__local ACC_TYPE l_final_shared;
if (lane == 0) {
ACC_TYPE m = FA_M_INIT;
for (int c = 0; c < n_splits; ++c) {
const ACC_TYPE m_c = rec0[c * record_stride + 0];
m = max(m, m_c);
}
ACC_TYPE m_sink = 0.0f;
bool has_sink = false;
if (sinks_void != NULL) {
const global ACC_TYPE * sinks_ptr =
(const global ACC_TYPE *) ((const global char *) sinks_void + sinks_offset);
m_sink = sinks_ptr[head_idx];
has_sink = true;
m = max(m, m_sink);
}
ACC_TYPE l = 0.0f;
for (int c = 0; c < n_splits; ++c) {
const ACC_TYPE m_c = rec0[c * record_stride + 0];
const ACC_TYPE l_c = rec0[c * record_stride + 1];
if (m_c > FA_M_INIT) {
l += l_c * exp(m_c - m);
}
}
if (has_sink) {
l += exp(m_sink - m);
}
m_final_shared = m;
l_final_shared = l;
}
barrier(CLK_LOCAL_MEM_FENCE);
const ACC_TYPE m_final = m_final_shared;
const ACC_TYPE l_final = l_final_shared;
const ACC_TYPE l_inv = (l_final > 0.0f) ? (1.0f / l_final) : 0.0f;
ACC_TYPE4 o = (ACC_TYPE4)(0.0f);
for (int c = 0; c < n_splits; ++c) {
const global float * rec_c = rec0 + c * record_stride;
const ACC_TYPE m_c = rec_c[0];
if (m_c <= FA_M_INIT) continue;
const global float4 * rec_oc = (const global float4 *) (rec_c + 2);
const ACC_TYPE scale_c = exp(m_c - m_final);
o = mad((ACC_TYPE4)(scale_c), rec_oc[lane], o);
}
o = o * l_inv;
const ulong o_row_offset = (ulong) batch_idx * o_nb3 + (ulong) q_idx * o_nb2 + (ulong) head_idx * o_nb1;
global O_DATA_TYPE4 * o_row = (global O_DATA_TYPE4 *) ((global char *) o_void + o_offset + o_row_offset);
o_row[lane] = CONVERT_O_DATA4(o);
}
File diff suppressed because it is too large Load Diff
File diff suppressed because it is too large Load Diff
@@ -0,0 +1,156 @@
#pragma OPENCL EXTENSION cl_khr_fp16 : enable
__kernel void flash_attn_kv_pad_f16(
const global void * k_void, ulong k_offset,
const global void * v_void, ulong v_offset,
global void * k_pad_void,
global void * v_pad_void,
const int n_kv,
const int n_head_kv,
const int n_batch,
const ulong k_nb1, const ulong k_nb2, const ulong k_nb3,
const ulong v_nb1, const ulong v_nb2, const ulong v_nb3
) {
const int row_idx = get_global_id(0);
const int head_kv_idx = get_global_id(1);
const int batch_idx = get_global_id(2);
if (row_idx >= BLOCK_N || head_kv_idx >= n_head_kv || batch_idx >= n_batch) {
return;
}
const int tail_start = n_kv - (n_kv % BLOCK_N);
const int src_row_idx = tail_start + row_idx;
const global char * k_src = (const global char *) k_void + k_offset;
const global char * v_src = (const global char *) v_void + v_offset;
global char * k_pad = (global char *) k_pad_void;
global char * v_pad = (global char *) v_pad_void;
const ulong k_dst_offset = ((ulong) batch_idx * (ulong) n_head_kv + (ulong) head_kv_idx) * ((ulong) BLOCK_N * k_nb1) + (ulong) row_idx * k_nb1;
const ulong v_dst_offset = ((ulong) batch_idx * (ulong) n_head_kv + (ulong) head_kv_idx) * ((ulong) BLOCK_N * v_nb1) + (ulong) row_idx * v_nb1;
if (src_row_idx < n_kv) {
const ulong k_src_offset = (ulong) batch_idx * k_nb3 + (ulong) head_kv_idx * k_nb2 + (ulong) src_row_idx * k_nb1;
const ulong v_src_offset = (ulong) batch_idx * v_nb3 + (ulong) head_kv_idx * v_nb2 + (ulong) src_row_idx * v_nb1;
for (ulong i = 0; i < k_nb1; ++i) {
k_pad[k_dst_offset + i] = k_src[k_src_offset + i];
}
for (ulong i = 0; i < v_nb1; ++i) {
v_pad[v_dst_offset + i] = v_src[v_src_offset + i];
}
} else {
for (ulong i = 0; i < k_nb1; ++i) {
k_pad[k_dst_offset + i] = 0;
}
for (ulong i = 0; i < v_nb1; ++i) {
v_pad[v_dst_offset + i] = 0;
}
}
}
__kernel void flash_attn_mask_pad_f16(
const global void * mask_void, ulong mask_offset,
global void * mask_pad_void,
const int n_q,
const int n_kv,
const ulong mask_nb1,
const ulong mask_nb2,
const ulong mask_nb3,
const int mask_ne2,
const int mask_ne3
) {
const int col_idx = get_global_id(0);
const int q_row = get_global_id(1);
const int mask_slice = get_global_id(2);
if (col_idx >= BLOCK_N || q_row >= n_q || mask_slice >= mask_ne2 * mask_ne3) {
return;
}
const int tail_start = n_kv - (n_kv % BLOCK_N);
const int src_col_idx = tail_start + col_idx;
const int mask_head_idx = mask_slice % mask_ne2;
const int mask_batch_idx = mask_slice / mask_ne2;
const global char * mask_src_base = (const global char *) mask_void + mask_offset +
(ulong) mask_batch_idx * mask_nb3 +
(ulong) mask_head_idx * mask_nb2 +
(ulong) q_row * mask_nb1;
const global half * mask_src = (const global half *) mask_src_base;
global half * mask_pad = (global half *) mask_pad_void;
const ulong dst_idx =
(((ulong) mask_batch_idx * (ulong) mask_ne2 + (ulong) mask_head_idx) * (ulong) n_q + (ulong) q_row) * (ulong) BLOCK_N +
(ulong) col_idx;
mask_pad[dst_idx] = src_col_idx < n_kv ? mask_src[src_col_idx] : (half) (-INFINITY);
}
// Per-KV-tile mask class. 0=all -inf (skip tile), 1=mixed (apply mask),
// 2=all zero, no -inf (skip mask lookup). Causal diagonal tiles are class 1.
__kernel void flash_attn_blk_f16(
const global void * mask_void, ulong mask_offset,
global char * blk,
const int n_q,
const int n_kv,
const ulong mask_nb1,
const ulong mask_nb2,
const ulong mask_nb3,
const int mask_ne2,
const int mask_ne3
) {
const int kv_block_idx = get_global_id(0);
const int q_block_idx = get_global_id(1);
const int mask_slice = get_global_id(2);
const int n_q_blocks = (n_q + BLOCK_M - 1) / BLOCK_M;
const int n_kv_blocks = (n_kv + BLOCK_N - 1) / BLOCK_N;
if (kv_block_idx >= n_kv_blocks || q_block_idx >= n_q_blocks || mask_slice >= mask_ne2 * mask_ne3) {
return;
}
const int mask_head_idx = mask_slice % mask_ne2;
const int mask_batch_idx = mask_slice / mask_ne2;
const int q_start = q_block_idx * BLOCK_M;
const int k_start = kv_block_idx * BLOCK_N;
const int q_count = min(BLOCK_M, n_q - q_start);
const int k_count = min(BLOCK_N, n_kv - k_start);
const half neg_max_half = (half) (-65504.0f);
char has_unmasked = 0;
char has_masked = 0;
char has_nonzero = 0;
const global char * mask_base = (const global char *) mask_void + mask_offset +
(ulong) mask_batch_idx * mask_nb3 +
(ulong) mask_head_idx * mask_nb2;
for (int qi = 0; qi < q_count; ++qi) {
const global half * mask_row = (const global half *) (mask_base + (ulong) (q_start + qi) * mask_nb1) + k_start;
for (int ki = 0; ki < k_count; ++ki) {
const half v = mask_row[ki];
if (v <= neg_max_half) {
has_masked = 1;
} else {
has_unmasked = 1;
if (v != (half) 0.0f) {
has_nonzero = 1;
}
}
}
if (has_masked && has_unmasked) break; // mixed tile — short-circuit.
}
char res;
if (has_unmasked == 0) {
res = 0;
} else if (has_masked || has_nonzero) {
res = 1;
} else {
res = 2;
}
blk[((ulong) mask_slice * (ulong) n_q_blocks + (ulong) q_block_idx) * (ulong) n_kv_blocks + (ulong) kv_block_idx] = res;
}
+500
View File
@@ -158,6 +158,239 @@ kernel void kernel_set_rows_f32_i32(
}
}
// f32 -> q8_0 quantize set_rows. Block = half d + char qs[32].
#define QK8_0 32
inline void quantize_q8_0_block(global float * x, global char * qs, global half * d_out) {
float amax = 0.0f;
for (int j = 0; j < QK8_0; j++) {
amax = fmax(amax, fabs(x[j]));
}
float d = amax / 127.0f;
float id = (d != 0.0f) ? 127.0f / amax : 0.0f;
vstore_half(d, 0, d_out);
for (int j = 0; j < QK8_0; j++) {
qs[j] = (char)((int)round(x[j] * id));
}
}
kernel void kernel_set_rows_q8_0_i64(
global char * src0,
ulong offset0,
global char * src1,
ulong offset1,
global char * dst,
ulong offsetd,
int ne01,
ulong nb01,
ulong nb02,
ulong nb03,
uint4 ne11,
uint4 ne12,
ulong nb10,
ulong nb11,
ulong nb12,
int nblk0,
ulong nb1,
ulong nb2,
ulong nb3
) {
src0 = src0 + offset0;
src1 = src1 + offset1;
dst = dst + offsetd;
int i03 = get_group_id(2);
int i02 = get_group_id(1);
int i01 = get_group_id(0)*get_local_size(1) + get_local_id(1);
if (i01 >= ne01) {
return;
}
int i12 = fastmod(i03, ne12);
int i11 = fastmod(i02, ne11);
int i10 = i01;
long i1 = ((global long *)(src1 + i10*nb10 + i11*nb11 + i12*nb12))[0];
global char * dst_row = (global char *) (dst + i1*nb1 + i02*nb2 + i03*nb3);
global float * src_row = (global float *) (src0 + i01*nb01 + i02*nb02 + i03*nb03);
for (int blk = get_local_id(0); blk < nblk0; blk += get_local_size(0)) {
global float * x = src_row + blk * QK8_0;
global char * y = dst_row + blk * (2 + QK8_0);
quantize_q8_0_block(x, y + 2, (global half *)y);
}
}
kernel void kernel_set_rows_q8_0_i32(
global char * src0,
ulong offset0,
global char * src1,
ulong offset1,
global char * dst,
ulong offsetd,
int ne01,
ulong nb01,
ulong nb02,
ulong nb03,
uint4 ne11,
uint4 ne12,
ulong nb10,
ulong nb11,
ulong nb12,
int nblk0,
ulong nb1,
ulong nb2,
ulong nb3
) {
src0 = src0 + offset0;
src1 = src1 + offset1;
dst = dst + offsetd;
int i03 = get_group_id(2);
int i02 = get_group_id(1);
int i01 = get_group_id(0)*get_local_size(1) + get_local_id(1);
if (i01 >= ne01) {
return;
}
int i12 = fastmod(i03, ne12);
int i11 = fastmod(i02, ne11);
int i10 = i01;
int i1 = ((global int *)(src1 + i10*nb10 + i11*nb11 + i12*nb12))[0];
global char * dst_row = (global char *) (dst + i1*nb1 + i02*nb2 + i03*nb3);
global float * src_row = (global float *) (src0 + i01*nb01 + i02*nb02 + i03*nb03);
for (int blk = get_local_id(0); blk < nblk0; blk += get_local_size(0)) {
global float * x = src_row + blk * QK8_0;
global char * y = dst_row + blk * (2 + QK8_0);
quantize_q8_0_block(x, y + 2, (global half *)y);
}
}
// SoA q8_0 variants. dst_q: int8[QK8_0] per block; dst_d: fp16 scale per block.
// Layout matches kernel_convert_block_q8_0; block index follows dst element order.
kernel void kernel_set_rows_q8_0_soa_i64(
global char * src0,
ulong offset0,
global char * src1,
ulong offset1,
global char * dst_q,
ulong offset_q,
global char * dst_d,
ulong offset_d,
int ne01,
ulong nb01,
ulong nb02,
ulong nb03,
uint4 ne11,
uint4 ne12,
ulong nb10,
ulong nb11,
ulong nb12,
int nblk0,
int ne1_dst,
int ne2_dst,
int ne3_dst
) {
src0 = src0 + offset0;
src1 = src1 + offset1;
dst_q = dst_q + offset_q;
dst_d = dst_d + offset_d;
int i03 = get_group_id(2);
int i02 = get_group_id(1);
int i01 = get_group_id(0)*get_local_size(1) + get_local_id(1);
if (i01 >= ne01) {
return;
}
int i12 = fastmod(i03, ne12);
int i11 = fastmod(i02, ne11);
int i10 = i01;
long i1 = ((global long *)(src1 + i10*nb10 + i11*nb11 + i12*nb12))[0];
long row_blk_base = ((long)i03 * ne2_dst * ne1_dst + (long)i02 * ne1_dst + i1) * nblk0;
global half * d_row = (global half *)(dst_d) + row_blk_base;
global char * q_row = (global char *)(dst_q) + row_blk_base * QK8_0;
global float * src_row = (global float *)(src0 + i01*nb01 + i02*nb02 + i03*nb03);
for (int blk = get_local_id(0); blk < nblk0; blk += get_local_size(0)) {
global float * x = src_row + blk * QK8_0;
global char * q = q_row + blk * QK8_0;
quantize_q8_0_block(x, q, d_row + blk);
}
}
kernel void kernel_set_rows_q8_0_soa_i32(
global char * src0,
ulong offset0,
global char * src1,
ulong offset1,
global char * dst_q,
ulong offset_q,
global char * dst_d,
ulong offset_d,
int ne01,
ulong nb01,
ulong nb02,
ulong nb03,
uint4 ne11,
uint4 ne12,
ulong nb10,
ulong nb11,
ulong nb12,
int nblk0,
int ne1_dst,
int ne2_dst,
int ne3_dst
) {
src0 = src0 + offset0;
src1 = src1 + offset1;
dst_q = dst_q + offset_q;
dst_d = dst_d + offset_d;
int i03 = get_group_id(2);
int i02 = get_group_id(1);
int i01 = get_group_id(0)*get_local_size(1) + get_local_id(1);
if (i01 >= ne01) {
return;
}
int i12 = fastmod(i03, ne12);
int i11 = fastmod(i02, ne11);
int i10 = i01;
int i1 = ((global int *)(src1 + i10*nb10 + i11*nb11 + i12*nb12))[0];
long row_blk_base = ((long)i03 * ne2_dst * ne1_dst + (long)i02 * ne1_dst + i1) * nblk0;
global half * d_row = (global half *)(dst_d) + row_blk_base;
global char * q_row = (global char *)(dst_q) + row_blk_base * QK8_0;
global float * src_row = (global float *)(src0 + i01*nb01 + i02*nb02 + i03*nb03);
for (int blk = get_local_id(0); blk < nblk0; blk += get_local_size(0)) {
global float * x = src_row + blk * QK8_0;
global char * q = q_row + blk * QK8_0;
quantize_q8_0_block(x, q, d_row + blk);
}
}
kernel void kernel_set_rows_f16_i32(
global char * src0,
ulong offset0,
@@ -206,3 +439,270 @@ kernel void kernel_set_rows_f16_i32(
dst_row[ind] = src_row[ind];
}
}
// f32 -> q4_0 quantize set_rows. Block = half d + uchar qs[16] (shuffled
// nibbles: qs[j] low/high = elem j / j+16).
// Dequant: val[i] = d * (nibble_i - 8)
// nblk0 = number of q4_0 blocks per row = ne00 / 32.
#define QK4_0 32
#define Q4_0_BLOCK_SIZE 18
inline void quantize_q4_0_block(global float * x, global uchar * qs, global half * d_out) {
// Find the signed value with the largest absolute magnitude (matches ggml ref).
float max = 0.0f;
float amax = 0.0f;
for (int j = 0; j < QK4_0; j++) {
float v = x[j];
float a = fabs(v);
if (a > amax) {
amax = a;
max = v;
}
}
float d = max / -8.0f;
float id = (d != 0.0f) ? 1.0f / d : 0.0f;
vstore_half(d, 0, d_out);
for (int j = 0; j < QK4_0/2; j++) {
float x0 = x[j] * id;
float x1 = x[j + QK4_0/2] * id;
int i0 = (int)(x0 + 8.5f);
int i1 = (int)(x1 + 8.5f);
if (i0 < 0) i0 = 0;
if (i0 > 15) i0 = 15;
if (i1 < 0) i1 = 0;
if (i1 > 15) i1 = 15;
qs[j] = (uchar)i0 | ((uchar)i1 << 4);
}
}
kernel void kernel_set_rows_q4_0_i64(
global char * src0,
ulong offset0,
global char * src1,
ulong offset1,
global char * dst,
ulong offsetd,
int ne01,
ulong nb01,
ulong nb02,
ulong nb03,
uint4 ne11,
uint4 ne12,
ulong nb10,
ulong nb11,
ulong nb12,
int nblk0,
ulong nb1,
ulong nb2,
ulong nb3
) {
src0 = src0 + offset0;
src1 = src1 + offset1;
dst = dst + offsetd;
int i03 = get_group_id(2);
int i02 = get_group_id(1);
int i01 = get_group_id(0)*get_local_size(1) + get_local_id(1);
if (i01 >= ne01) {
return;
}
int i12 = fastmod(i03, ne12);
int i11 = fastmod(i02, ne11);
int i10 = i01;
long i1 = ((global long *)(src1 + i10*nb10 + i11*nb11 + i12*nb12))[0];
global char * dst_row = (global char *) (dst + i1*nb1 + i02*nb2 + i03*nb3);
global float * src_row = (global float *) (src0 + i01*nb01 + i02*nb02 + i03*nb03);
for (int blk = get_local_id(0); blk < nblk0; blk += get_local_size(0)) {
global float * x = src_row + blk * QK4_0;
global char * y = dst_row + blk * Q4_0_BLOCK_SIZE;
global half * yd = (global half *)(y);
global uchar * yqs = (global uchar *)(y + 2);
quantize_q4_0_block(x, yqs, yd);
}
}
kernel void kernel_set_rows_q4_0_i32(
global char * src0,
ulong offset0,
global char * src1,
ulong offset1,
global char * dst,
ulong offsetd,
int ne01,
ulong nb01,
ulong nb02,
ulong nb03,
uint4 ne11,
uint4 ne12,
ulong nb10,
ulong nb11,
ulong nb12,
int nblk0,
ulong nb1,
ulong nb2,
ulong nb3
) {
src0 = src0 + offset0;
src1 = src1 + offset1;
dst = dst + offsetd;
int i03 = get_group_id(2);
int i02 = get_group_id(1);
int i01 = get_group_id(0)*get_local_size(1) + get_local_id(1);
if (i01 >= ne01) {
return;
}
int i12 = fastmod(i03, ne12);
int i11 = fastmod(i02, ne11);
int i10 = i01;
int i1 = ((global int *)(src1 + i10*nb10 + i11*nb11 + i12*nb12))[0];
global char * dst_row = (global char *) (dst + i1*nb1 + i02*nb2 + i03*nb3);
global float * src_row = (global float *) (src0 + i01*nb01 + i02*nb02 + i03*nb03);
for (int blk = get_local_id(0); blk < nblk0; blk += get_local_size(0)) {
global float * x = src_row + blk * QK4_0;
global char * y = dst_row + blk * Q4_0_BLOCK_SIZE;
global half * yd = (global half *)(y);
global uchar * yqs = (global uchar *)(y + 2);
quantize_q4_0_block(x, yqs, yd);
}
}
// SoA variants for q4_0 dst. Used when the backend has split block_q4_0 records
// into separate quant (dst_q) and scale (dst_d) sub-buffers same pattern as
// the q8_0 SoA variants above.
//
// Layout (matches kernel_convert_block_q4_0, the "shuffled" variant):
// dst_q: contiguous 16 packed nibbles per block, block i at offset i * 16 bytes.
// dst_d: contiguous fp16 scales, block i at offset i * 2 bytes.
// Nibble layout inside each byte is unchanged from AoS: qs[j] low nibble = element j,
// qs[j] high nibble = element j+16. kernel_restore_block_q4_0 copies bytes as-is.
kernel void kernel_set_rows_q4_0_soa_i64(
global char * src0,
ulong offset0,
global char * src1,
ulong offset1,
global char * dst_q,
ulong offset_q,
global char * dst_d,
ulong offset_d,
int ne01,
ulong nb01,
ulong nb02,
ulong nb03,
uint4 ne11,
uint4 ne12,
ulong nb10,
ulong nb11,
ulong nb12,
int nblk0,
int ne1_dst,
int ne2_dst,
int ne3_dst
) {
src0 = src0 + offset0;
src1 = src1 + offset1;
dst_q = dst_q + offset_q;
dst_d = dst_d + offset_d;
int i03 = get_group_id(2);
int i02 = get_group_id(1);
int i01 = get_group_id(0)*get_local_size(1) + get_local_id(1);
if (i01 >= ne01) {
return;
}
int i12 = fastmod(i03, ne12);
int i11 = fastmod(i02, ne11);
int i10 = i01;
long i1 = ((global long *)(src1 + i10*nb10 + i11*nb11 + i12*nb12))[0];
long row_blk_base = ((long)i03 * ne2_dst * ne1_dst + (long)i02 * ne1_dst + i1) * nblk0;
global half * d_row = (global half *)(dst_d) + row_blk_base;
global uchar * q_row = (global uchar *)(dst_q) + row_blk_base * (QK4_0/2);
global float * src_row = (global float *)(src0 + i01*nb01 + i02*nb02 + i03*nb03);
for (int blk = get_local_id(0); blk < nblk0; blk += get_local_size(0)) {
global float * x = src_row + blk * QK4_0;
global uchar * qs = q_row + blk * (QK4_0/2);
global half * d_bk = d_row + blk;
quantize_q4_0_block(x, qs, d_bk);
}
}
kernel void kernel_set_rows_q4_0_soa_i32(
global char * src0,
ulong offset0,
global char * src1,
ulong offset1,
global char * dst_q,
ulong offset_q,
global char * dst_d,
ulong offset_d,
int ne01,
ulong nb01,
ulong nb02,
ulong nb03,
uint4 ne11,
uint4 ne12,
ulong nb10,
ulong nb11,
ulong nb12,
int nblk0,
int ne1_dst,
int ne2_dst,
int ne3_dst
) {
src0 = src0 + offset0;
src1 = src1 + offset1;
dst_q = dst_q + offset_q;
dst_d = dst_d + offset_d;
int i03 = get_group_id(2);
int i02 = get_group_id(1);
int i01 = get_group_id(0)*get_local_size(1) + get_local_id(1);
if (i01 >= ne01) {
return;
}
int i12 = fastmod(i03, ne12);
int i11 = fastmod(i02, ne11);
int i10 = i01;
int i1 = ((global int *)(src1 + i10*nb10 + i11*nb11 + i12*nb12))[0];
long row_blk_base = ((long)i03 * ne2_dst * ne1_dst + (long)i02 * ne1_dst + i1) * nblk0;
global half * d_row = (global half *)(dst_d) + row_blk_base;
global uchar * q_row = (global uchar *)(dst_q) + row_blk_base * (QK4_0/2);
global float * src_row = (global float *)(src0 + i01*nb01 + i02*nb02 + i03*nb03);
for (int blk = get_local_id(0); blk < nblk0; blk += get_local_size(0)) {
global float * x = src_row + blk * QK4_0;
global uchar * qs = q_row + blk * (QK4_0/2);
global half * d_bk = d_row + blk;
quantize_q4_0_block(x, qs, d_bk);
}
}
+4
View File
@@ -1053,6 +1053,10 @@ static bool is_op_unsupported_case(const ggml_tensor * op) {
(op->ne[0] == 2 && op->ne[1] == 4 && op->ne[2] == 3 && op->ne[3] == 2)) {
return true;
}
// CPY into a strided view of a larger buffer (recurrent-state snapshots) not supported
if (op->view_src && ggml_nbytes(op) != ggml_nbytes(op->view_src)) {
return true;
}
break;
}
case GGML_OP_MUL_MAT: {
+102 -48
View File
@@ -2,8 +2,10 @@
#include "ggml-sycl/common.hpp"
#include "ggml-sycl/presets.hpp"
static void norm_f32(const float* x, float* dst, const int ncols, const int64_t stride_row, const int64_t stride_channel,
const int64_t stride_sample, const float eps, const sycl::nd_item<3>& item_ct1, sycl::float2* s_sum, int block_size) {
static void norm_f32(const float* x, float* dst, const int ncols,
const int64_t src_stride_col, const int64_t src_stride_row, const int64_t src_stride_channel, const int64_t src_stride_sample,
const int64_t dst_stride_col, const int64_t dst_stride_row, const int64_t dst_stride_channel, const int64_t dst_stride_sample,
const float eps, const sycl::nd_item<3>& item_ct1, sycl::float2* s_sum, int block_size) {
const int nrows = item_ct1.get_group_range(2);
const int nchannels = item_ct1.get_group_range(1);
@@ -16,16 +18,16 @@ static void norm_f32(const float* x, float* dst, const int ncols, const int64_t
const int tid = item_ct1.get_local_id(2);
const int nwarps = nthreads / WARP_SIZE;
const auto strided_offset = calculate_offset<3>({stride_sample, stride_channel, stride_row}, {sample, channel, row});
const auto packed_offset = calculate_offset<3>({nchannels * nrows * ncols, nrows * ncols, ncols}, {sample, channel, row});
const auto src_offset = calculate_offset<3>({src_stride_sample, src_stride_channel, src_stride_row}, {sample, channel, row});
const auto dst_offset = calculate_offset<3>({dst_stride_sample, dst_stride_channel, dst_stride_row}, {sample, channel, row});
x += strided_offset;
dst += packed_offset;
x += src_offset;
dst += dst_offset;
sycl::float2 mean_var = sycl::float2(0.f, 0.f);
for (int col = tid; col < ncols; col += block_size) {
const float xi = x[col];
const float xi = x[col * src_stride_col];
mean_var.x() += xi;
mean_var.y() += xi * xi;
}
@@ -54,7 +56,7 @@ static void norm_f32(const float* x, float* dst, const int ncols, const int64_t
const float inv_std = sycl::rsqrt(var + eps);
for (int col = tid; col < ncols; col += block_size) {
dst[col] = (x[col] - mean) * inv_std;
dst[col * dst_stride_col] = (x[col * src_stride_col] - mean) * inv_std;
}
}
@@ -145,8 +147,10 @@ static void group_norm_f32(const float* x, float* dst, const int group_size, con
}
}
static void rms_norm_f32(const float* x, float* dst, const int ncols, const int64_t stride_row, const int64_t stride_channel,
const int64_t stride_sample, const float eps, const sycl::nd_item<3>& item_ct1, float* s_sum, int block_size) {
static void rms_norm_f32(const float* x, float* dst, const int ncols,
const int64_t src_stride_col, const int64_t src_stride_row, const int64_t src_stride_channel, const int64_t src_stride_sample,
const int64_t dst_stride_col, const int64_t dst_stride_row, const int64_t dst_stride_channel, const int64_t dst_stride_sample,
const float eps, const sycl::nd_item<3>& item_ct1, float* s_sum, int block_size) {
const int nrows = item_ct1.get_group_range(2);
const int nchannels = item_ct1.get_group_range(1);
@@ -160,17 +164,17 @@ static void rms_norm_f32(const float* x, float* dst, const int ncols, const int6
const int tid = item_ct1.get_local_id(2);
const int nwarps = nthreads / WARP_SIZE;
const auto strided_offset = calculate_offset<3>({stride_sample, stride_channel, stride_row}, {sample, channel, row});
const auto packed_offset = calculate_offset<3>({nchannels * nrows * ncols, nrows * ncols, ncols}, {sample, channel, row});
const auto src_offset = calculate_offset<3>({src_stride_sample, src_stride_channel, src_stride_row}, {sample, channel, row});
const auto dst_offset = calculate_offset<3>({dst_stride_sample, dst_stride_channel, dst_stride_row}, {sample, channel, row});
x += strided_offset;
dst += packed_offset;
x += src_offset;
dst += dst_offset;
float tmp = 0.0f; // partial sum for thread in warp
for (int col = tid; col < ncols; col += block_size) {
const float xi = x[col];
const float xi = x[col * src_stride_col];
tmp += xi * xi;
}
@@ -198,14 +202,15 @@ static void rms_norm_f32(const float* x, float* dst, const int ncols, const int6
const float scale = sycl::rsqrt(mean + eps);
for (int col = tid; col < ncols; col += block_size) {
dst[col] = scale * x[col];
dst[col * dst_stride_col] = scale * x[col * src_stride_col];
}
}
template<int warp_size>
static void l2_norm_f32(const float * x, float * dst, const int ncols,
const int64_t stride_row, const int64_t stride_channel,
const int64_t stride_sample, const float eps,
const int64_t src_stride_col, const int64_t src_stride_row, const int64_t src_stride_channel,
const int64_t src_stride_sample, const int64_t dst_stride_col, const int64_t dst_stride_row,
const int64_t dst_stride_channel, const int64_t dst_stride_sample, const float eps,
const sycl::nd_item<3>& item_ct1, float* s_sum, const int block_size) {
const int nrows = item_ct1.get_group_range(2);
const int nchannels = item_ct1.get_group_range(1);
@@ -215,13 +220,13 @@ static void l2_norm_f32(const float * x, float * dst, const int ncols,
const int sample = item_ct1.get_group(0);
const int tid = item_ct1.get_local_id(2);
x += sample*stride_sample + channel*stride_channel + row*stride_row;
dst += ((sample*nchannels + channel)*nrows + row)*ncols;
x += sample*src_stride_sample + channel*src_stride_channel + row*src_stride_row;
dst += sample*dst_stride_sample + channel*dst_stride_channel + row*dst_stride_row;
float tmp = 0.0f; // partial sum for thread in warp
for (int col = tid; col < ncols; col += block_size) {
const float xi = x[col];
const float xi = x[col * src_stride_col];
tmp += xi * xi;
}
@@ -229,12 +234,13 @@ static void l2_norm_f32(const float * x, float * dst, const int ncols,
const float scale = sycl::rsqrt(sycl::fmax(tmp, eps * eps));
for (int col = tid; col < ncols; col += block_size) {
dst[col] = scale * x[col];
dst[col * dst_stride_col] = scale * x[col * src_stride_col];
}
}
static void norm_f32_sycl(const float * x, float * dst, const int ncols, const int nrows, const int nchannels, const int nsamples,
const int64_t stride_row, const int64_t stride_channel, const int64_t stride_sample,
const int64_t src_stride_col, const int64_t src_stride_row, const int64_t src_stride_channel, const int64_t src_stride_sample,
const int64_t dst_stride_col, const int64_t dst_stride_row, const int64_t dst_stride_channel, const int64_t dst_stride_sample,
const float eps, queue_ptr stream, int device) {
const sycl::range<3> global_dims(nsamples, nchannels, nrows);
@@ -245,7 +251,10 @@ static void norm_f32_sycl(const float * x, float * dst, const int ncols, const i
sycl::nd_range<3>(global_dims * block_dims, block_dims),
[=](sycl::nd_item<3> item_ct1)
[[sycl::reqd_sub_group_size(WARP_SIZE)]] {
norm_f32(x, dst, ncols, stride_row, stride_channel, stride_sample, eps, item_ct1, nullptr, WARP_SIZE);
norm_f32(x, dst, ncols,
src_stride_col, src_stride_row, src_stride_channel, src_stride_sample,
dst_stride_col, dst_stride_row, dst_stride_channel, dst_stride_sample,
eps, item_ct1, nullptr, WARP_SIZE);
});
});
}
@@ -265,7 +274,10 @@ static void norm_f32_sycl(const float * x, float * dst, const int ncols, const i
sycl::nd_range<3>(global_dims * block_dims, block_dims),
[=](sycl::nd_item<3> item_ct1)
[[sycl::reqd_sub_group_size(WARP_SIZE)]] {
norm_f32(x, dst, ncols, stride_row, stride_channel, stride_sample, eps, item_ct1, get_pointer(s_sum_acc_ct1), work_group_size);
norm_f32(x, dst, ncols,
src_stride_col, src_stride_row, src_stride_channel, src_stride_sample,
dst_stride_col, dst_stride_row, dst_stride_channel, dst_stride_sample,
eps, item_ct1, get_pointer(s_sum_acc_ct1), work_group_size);
});
});
}
@@ -319,7 +331,9 @@ static void group_norm_f32_sycl(const float* x, float* dst,
}
static void rms_norm_f32_sycl(const float* x, float* dst, const int ncols, const int nrows, const int nchannels, const int nsamples,
const int64_t stride_row, const int64_t stride_channel, const int64_t stride_sample, const float eps, queue_ptr stream, int device) {
const int64_t src_stride_col, const int64_t src_stride_row, const int64_t src_stride_channel, const int64_t src_stride_sample,
const int64_t dst_stride_col, const int64_t dst_stride_row, const int64_t dst_stride_channel, const int64_t dst_stride_sample,
const float eps, queue_ptr stream, int device) {
// printf("%s ncols=%d, nrows=%d, WARP_SIZE=%d\n", __func__, ncols, nrows, WARP_SIZE);
const sycl::range<3> global_dims(nsamples, nchannels, nrows);
@@ -330,7 +344,10 @@ static void rms_norm_f32_sycl(const float* x, float* dst, const int ncols, const
sycl::nd_range<3>(global_dims * block_dims, block_dims),
[=](sycl::nd_item<3> item_ct1)
[[sycl::reqd_sub_group_size(WARP_SIZE)]] {
rms_norm_f32(x, dst, ncols, stride_row, stride_channel, stride_sample, eps, item_ct1, nullptr, WARP_SIZE);
rms_norm_f32(x, dst, ncols,
src_stride_col, src_stride_row, src_stride_channel, src_stride_sample,
dst_stride_col, dst_stride_row, dst_stride_channel, dst_stride_sample,
eps, item_ct1, nullptr, WARP_SIZE);
});
});
}
@@ -350,7 +367,10 @@ static void rms_norm_f32_sycl(const float* x, float* dst, const int ncols, const
sycl::nd_range<3>(global_dims * block_dims, block_dims),
[=](sycl::nd_item<3> item_ct1)
[[sycl::reqd_sub_group_size(WARP_SIZE)]] {
rms_norm_f32(x, dst, ncols, stride_row, stride_channel, stride_sample, eps, item_ct1, get_pointer(s_sum_acc_ct1), work_group_size);
rms_norm_f32(x, dst, ncols,
src_stride_col, src_stride_row, src_stride_channel, src_stride_sample,
dst_stride_col, dst_stride_row, dst_stride_channel, dst_stride_sample,
eps, item_ct1, get_pointer(s_sum_acc_ct1), work_group_size);
});
});
}
@@ -363,9 +383,14 @@ static void l2_norm_f32_sycl(const float * x,
const int nrows,
const int nchannels,
const int nsamples,
const int64_t stride_row,
const int64_t stride_channel,
const int64_t stride_sample,
const int64_t src_stride_col,
const int64_t src_stride_row,
const int64_t src_stride_channel,
const int64_t src_stride_sample,
const int64_t dst_stride_col,
const int64_t dst_stride_row,
const int64_t dst_stride_channel,
const int64_t dst_stride_sample,
const float eps,
queue_ptr stream,
int device) {
@@ -379,7 +404,10 @@ static void l2_norm_f32_sycl(const float * x,
block_dims),
[=](sycl::nd_item<3> item_ct1)
[[sycl::reqd_sub_group_size(warp_size)]] {
l2_norm_f32<warp_size>(x, dst, ncols, stride_row, stride_channel, stride_sample, eps, item_ct1,
l2_norm_f32<warp_size>(x, dst, ncols,
src_stride_col, src_stride_row, src_stride_channel, src_stride_sample,
dst_stride_col, dst_stride_row, dst_stride_channel, dst_stride_sample,
eps, item_ct1,
nullptr, warp_size);
});
});
@@ -398,7 +426,9 @@ static void l2_norm_f32_sycl(const float * x,
block_dims),
[=](sycl::nd_item<3> item_ct1)
[[sycl::reqd_sub_group_size(warp_size)]] {
l2_norm_f32<warp_size>(x, dst, ncols, stride_row, stride_channel, stride_sample,
l2_norm_f32<warp_size>(x, dst, ncols,
src_stride_col, src_stride_row, src_stride_channel, src_stride_sample,
dst_stride_col, dst_stride_row, dst_stride_channel, dst_stride_sample,
eps, item_ct1, get_pointer(s_sum_acc_ct1), work_group_size);
});
});
@@ -421,12 +451,20 @@ void ggml_sycl_op_norm(ggml_backend_sycl_context& ctx, ggml_tensor* dst) {
memcpy(&eps, dst->op_params, sizeof(float));
GGML_ASSERT(eps >= 0.0f);
const size_t ts0 = ggml_type_size(src0->type);
GGML_ASSERT(nb00 == ts0);
const int64_t s01 = nb01 / ts0;
const int64_t s02 = nb02 / ts0;
const int64_t s03 = nb03 / ts0;
const size_t tdst = ggml_type_size(dst->type);
GGML_ASSERT(nb00 % ts0 == 0 && nb01 % ts0 == 0 && nb02 % ts0 == 0 && nb03 % ts0 == 0);
GGML_ASSERT(nb0 % tdst == 0 && nb1 % tdst == 0 && nb2 % tdst == 0 && nb3 % tdst == 0);
const int64_t ss0 = nb00 / ts0;
const int64_t ss1 = nb01 / ts0;
const int64_t ss2 = nb02 / ts0;
const int64_t ss3 = nb03 / ts0;
const int64_t ds0 = nb0 / tdst;
const int64_t ds1 = nb1 / tdst;
const int64_t ds2 = nb2 / tdst;
const int64_t ds3 = nb3 / tdst;
norm_f32_sycl(src0_dd, dst_dd, ne00, ne01, ne02, ne03, s01, s02, s03, eps, main_stream, ctx.device);
norm_f32_sycl(src0_dd, dst_dd, ne00, ne01, ne02, ne03,
ss0, ss1, ss2, ss3, ds0, ds1, ds2, ds3, eps, main_stream, ctx.device);
}
void ggml_sycl_op_group_norm(ggml_backend_sycl_context& ctx, ggml_tensor* dst) {
@@ -465,11 +503,19 @@ void ggml_sycl_op_rms_norm(ggml_backend_sycl_context & ctx, ggml_tensor * dst) {
GGML_TENSOR_UNARY_OP_LOCALS
const size_t ts0 = ggml_type_size(src0->type);
GGML_ASSERT(nb00 == ts0);
const int64_t s01 = nb01 / ts0;
const int64_t s02 = nb02 / ts0;
const int64_t s03 = nb03 / ts0;
rms_norm_f32_sycl(src0_dd, dst_dd, ne00, ne01, ne02, ne03, s01, s02, s03, eps, main_stream, ctx.device);
const size_t tdst = ggml_type_size(dst->type);
GGML_ASSERT(nb00 % ts0 == 0 && nb01 % ts0 == 0 && nb02 % ts0 == 0 && nb03 % ts0 == 0);
GGML_ASSERT(nb0 % tdst == 0 && nb1 % tdst == 0 && nb2 % tdst == 0 && nb3 % tdst == 0);
const int64_t ss0 = nb00 / ts0;
const int64_t ss1 = nb01 / ts0;
const int64_t ss2 = nb02 / ts0;
const int64_t ss3 = nb03 / ts0;
const int64_t ds0 = nb0 / tdst;
const int64_t ds1 = nb1 / tdst;
const int64_t ds2 = nb2 / tdst;
const int64_t ds3 = nb3 / tdst;
rms_norm_f32_sycl(src0_dd, dst_dd, ne00, ne01, ne02, ne03,
ss0, ss1, ss2, ss3, ds0, ds1, ds2, ds3, eps, main_stream, ctx.device);
}
void ggml_sycl_op_rms_norm_back(ggml_backend_sycl_context & ctx, ggml_tensor * dst) {
@@ -644,13 +690,21 @@ void ggml_sycl_op_l2_norm(ggml_backend_sycl_context& ctx, ggml_tensor* dst) {
GGML_ASSERT(eps >= 0.0f);
const size_t ts0 = ggml_type_size(src0->type);
GGML_ASSERT(nb00 == ts0);
const int64_t s01 = nb01 / ts0;
const int64_t s02 = nb02 / ts0;
const int64_t s03 = nb03 / ts0;
const size_t tdst = ggml_type_size(dst->type);
GGML_ASSERT(nb00 % ts0 == 0 && nb01 % ts0 == 0 && nb02 % ts0 == 0 && nb03 % ts0 == 0);
GGML_ASSERT(nb0 % tdst == 0 && nb1 % tdst == 0 && nb2 % tdst == 0 && nb3 % tdst == 0);
const int64_t ss0 = nb00 / ts0;
const int64_t ss1 = nb01 / ts0;
const int64_t ss2 = nb02 / ts0;
const int64_t ss3 = nb03 / ts0;
const int64_t ds0 = nb0 / tdst;
const int64_t ds1 = nb1 / tdst;
const int64_t ds2 = nb2 / tdst;
const int64_t ds3 = nb3 / tdst;
/*support both WARP_SIZE or WARP_32_SIZE in code
choose by hardware for better performance
*/
l2_norm_f32_sycl<WARP_SIZE>(src0_d, dst_d, ne00, ne01, ne02, ne03, s01, s02, s03, eps, stream, ctx.device);
l2_norm_f32_sycl<WARP_SIZE>(src0_d, dst_d, ne00, ne01, ne02, ne03,
ss0, ss1, ss2, ss3, ds0, ds1, ds2, ds3, eps, stream, ctx.device);
}
@@ -42,7 +42,7 @@ float op_leaky_relu(float x) {
}
float op_step(float x) {
return x >= 0.0f ? 1.0f : 0.0f;
return x > 0.0f ? 1.0f : 0.0f;
}
float op_tanh(float x) {
+4 -4
View File
@@ -302,9 +302,9 @@ target_link_libraries(${TEST_TARGET} PRIVATE llama)
llama_build_and_test(test-alloc.cpp)
target_include_directories(test-alloc PRIVATE ${PROJECT_SOURCE_DIR}/ggml/src)
llama_build(export-graph-ops.cpp)
target_include_directories(export-graph-ops PRIVATE ${PROJECT_SOURCE_DIR}/ggml/src)
llama_build(test-export-graph-ops.cpp)
target_include_directories(test-export-graph-ops PRIVATE ${PROJECT_SOURCE_DIR}/ggml/src)
if (TARGET gguf-model-data)
target_link_libraries(export-graph-ops PRIVATE gguf-model-data)
target_compile_definitions(export-graph-ops PRIVATE LLAMA_HF_FETCH)
target_link_libraries(test-export-graph-ops PRIVATE gguf-model-data)
target_compile_definitions(test-export-graph-ops PRIVATE LLAMA_HF_FETCH)
endif()
+29 -8
View File
@@ -2890,12 +2890,17 @@ struct test_cpy : public test_case {
const std::array<int64_t, 4> ne_dst;
const std::array<int64_t, 4> permute_src;
const std::array<int64_t, 4> permute_dst;
const std::array<int64_t, 4> dst_alloc; // if set, dst is a view into a larger buffer (strided)
bool _src_use_permute;
bool _dst_use_permute;
bool _src_transpose;
bool _use_dst_shape;
bool _use_dst_alloc;
std::string vars() override {
if (_use_dst_alloc) {
return VARS_TO_STR8(type_src, type_dst, ne_src, ne_dst, permute_src, permute_dst, _src_transpose, dst_alloc);
}
if (_use_dst_shape) {
return VARS_TO_STR7(type_src, type_dst, ne_src, ne_dst, permute_src, permute_dst, _src_transpose);
}
@@ -2943,12 +2948,15 @@ struct test_cpy : public test_case {
std::array<int64_t, 4> ne_dst = {-1, -1, -1, -1},
std::array<int64_t, 4> permute_src = {0, 0, 0, 0},
std::array<int64_t, 4> permute_dst = {0, 0, 0, 0},
bool transpose_src = false)
bool transpose_src = false,
std::array<int64_t, 4> dst_alloc = {0, 0, 0, 0})
: type_src(type_src), type_dst(type_dst), ne_src(ne_src), ne_dst(ne_dst), permute_src(permute_src), permute_dst(permute_dst),
dst_alloc(dst_alloc),
_src_use_permute(permute_src[0] + permute_src[1] + permute_src[2] + permute_src[3] > 0),
_dst_use_permute(permute_dst[0] + permute_dst[1] + permute_dst[2] + permute_dst[3] > 0),
_src_transpose(transpose_src),
_use_dst_shape(ne_dst[0] >= 0 && ne_dst[1] >= 0 && ne_dst[2] >= 0 && ne_dst[3] >= 0){}
_use_dst_shape(ne_dst[0] >= 0 && ne_dst[1] >= 0 && ne_dst[2] >= 0 && ne_dst[3] >= 0),
_use_dst_alloc(dst_alloc[0] > 0){}
ggml_tensor * build_graph(ggml_context * ctx) override {
ggml_tensor * src = ggml_new_tensor(ctx, type_src, 4, ne_src.data());
@@ -2966,12 +2974,23 @@ struct test_cpy : public test_case {
}
std::array<int64_t, 4> dst_ne = _use_dst_shape ? ne_dst : std::array<int64_t, 4>{src->ne[0], src->ne[1], src->ne[2], src->ne[3]};
ggml_tensor * dst = ggml_new_tensor(ctx, type_dst, 4, dst_ne.data());
ggml_set_name(dst, "dst");
ggml_tensor * dst;
if (_dst_use_permute) {
dst = ggml_permute(ctx, dst, permute_dst[0], permute_dst[1], permute_dst[2], permute_dst[3]);
ggml_set_name(dst, "dst_permuted");
if (_use_dst_alloc) {
// view a sub-block of a larger buffer -> strided dst
ggml_tensor * dst_buf = ggml_new_tensor(ctx, type_dst, 4, dst_alloc.data());
ggml_set_name(dst_buf, "dst_buf");
dst = ggml_view_4d(ctx, dst_buf, dst_ne[0], dst_ne[1], dst_ne[2], dst_ne[3],
dst_buf->nb[1], dst_buf->nb[2], dst_buf->nb[3], 0);
ggml_set_name(dst, "dst_view");
} else {
dst = ggml_new_tensor(ctx, type_dst, 4, dst_ne.data());
ggml_set_name(dst, "dst");
if (_dst_use_permute) {
dst = ggml_permute(ctx, dst, permute_dst[0], permute_dst[1], permute_dst[2], permute_dst[3]);
ggml_set_name(dst, "dst_permuted");
}
}
ggml_tensor * out = ggml_cpy(ctx, src, dst);
@@ -8181,6 +8200,8 @@ static std::vector<std::unique_ptr<test_case>> make_test_cases_eval() {
test_cases.emplace_back(new test_cpy(GGML_TYPE_F32, GGML_TYPE_F32, {256, 1, 4, 1}, {-1,-1,-1,-1}, {1, 2, 0, 3}, {0, 0, 0, 0}));
test_cases.emplace_back(new test_cpy(GGML_TYPE_F32, GGML_TYPE_F32, {2, 2097121, 1, 1}, {-1,-1,-1,-1}, {1, 0, 2, 3}));
test_cases.emplace_back(new test_cpy(GGML_TYPE_F32, GGML_TYPE_F32, {2, 2, 524281, 1}, {-1,-1,-1,-1}, {1, 0, 2, 3}));
test_cases.emplace_back(new test_cpy(GGML_TYPE_F32, GGML_TYPE_F32, {128, 2, 3, 1}, {128, 2, 3, 1}, {0, 0, 0, 0}, {0, 0, 0, 0}, false, {128, 4, 3, 1})); // strided dst
test_cases.emplace_back(new test_cpy(GGML_TYPE_F16, GGML_TYPE_F16, {128, 2, 3, 1}, {128, 2, 3, 1}, {0, 0, 0, 0}, {0, 0, 0, 0}, false, {128, 4, 3, 1})); // strided dst
// CPY - different src/dst shapes (reshaping via CPY)
// Use permutations of {3, 5, 7, 32}. Total elements: 3*5*7*32 = 3360.
@@ -9943,7 +9964,7 @@ static void usage(char ** argv) {
printf(" --output specifies output format (default: console, options: console, sql, csv)\n");
printf(" --list-ops lists all available GGML operations\n");
printf(" --show-coverage shows test coverage\n");
printf(" --test-file reads test operators from a test file generated by llama-export-graph-ops\n");
printf(" --test-file reads test operators from a test file generated by test-export-graph-ops\n");
printf(" -j <n> runs tests using <n> parallel worker threads (default: 1, test mode only)\n");
}
+1 -1
View File
@@ -135,7 +135,7 @@ int main(int argc, char ** argv) {
output_path = args[i + 1];
i++;
} else if (args[i] == "--no-common") {
use_common = true;
use_common = false;
} else if (tmpl_path.empty()) {
tmpl_path = args[i];
} else {
@@ -185,7 +185,7 @@ int main(int argc, char ** argv) {
return 1;
}
#else
LOG_ERR("export-graph-ops compiled without HF fetch support\n");
LOG_ERR("test-export-graph-ops compiled without HF fetch support\n");
return 1;
#endif
}
+1 -1
View File
@@ -1,4 +1,4 @@
set(TARGET rpc-server)
set(TARGET ggml-rpc-server)
add_executable(${TARGET} rpc-server.cpp)
target_link_libraries(${TARGET} PRIVATE ggml)
target_compile_features(${TARGET} PRIVATE cxx_std_17)
+17 -17
View File
@@ -4,8 +4,8 @@
> This example and the RPC backend are currently in a proof-of-concept development stage. As such, the functionality is fragile and
> insecure. **Never run the RPC server on an open network or in a sensitive environment!**
The `rpc-server` allows exposing `ggml` devices on a remote host.
The RPC backend communicates with one or several instances of `rpc-server` and offloads computations to them.
The `ggml-rpc-server` allows exposing `ggml` devices on a remote host.
The RPC backend communicates with one or several instances of `ggml-rpc-server` and offloads computations to them.
This can be used for distributed LLM inference with `llama.cpp` in the following way:
```mermaid
@@ -14,15 +14,15 @@ flowchart TD
rpcb<-->|TCP|srvb
rpcb<-.->|TCP|srvn
subgraph hostn[Host N]
srvn[rpc-server]<-.->dev4["CUDA0"]
srvn[rpc-server]<-.->dev5["CPU"]
srvn[ggml-rpc-server]<-.->dev4["CUDA0"]
srvn[ggml-rpc-server]<-.->dev5["CPU"]
end
subgraph hostb[Host B]
srvb[rpc-server]<-->dev3["Metal"]
srvb[ggml-rpc-server]<-->dev3["Metal"]
end
subgraph hosta[Host A]
srva[rpc-server]<-->dev["CUDA0"]
srva[rpc-server]<-->dev2["CUDA1"]
srva[ggml-rpc-server]<-->dev["CUDA0"]
srva[ggml-rpc-server]<-->dev2["CUDA1"]
end
subgraph host[Main Host]
local["Local devices"]<-->ggml[llama-cli]
@@ -33,7 +33,7 @@ flowchart TD
class local,dev,dev2,dev3,dev4,dev5 devcls
```
By default, `rpc-server` exposes all available accelerator devices on the host.
By default, `ggml-rpc-server` exposes all available accelerator devices on the host.
If there are no accelerators, it exposes a single `CPU` device.
## Usage
@@ -41,7 +41,7 @@ If there are no accelerators, it exposes a single `CPU` device.
### Remote hosts
On each remote host, build the backends for each accelerator by adding `-DGGML_RPC=ON` to the build options.
For example, to build the `rpc-server` with support for CUDA accelerators:
For example, to build the `ggml-rpc-server` with support for CUDA accelerators:
```bash
mkdir build-rpc-cuda
@@ -50,10 +50,10 @@ cmake .. -DGGML_CUDA=ON -DGGML_RPC=ON
cmake --build . --config Release
```
When started, the `rpc-server` will detect and expose all available `CUDA` devices:
When started, the `ggml-rpc-server` will detect and expose all available `CUDA` devices:
```bash
$ bin/rpc-server
$ bin/ggml-rpc-server
ggml_cuda_init: GGML_CUDA_FORCE_MMQ: no
ggml_cuda_init: GGML_CUDA_FORCE_CUBLAS: no
ggml_cuda_init: found 1 CUDA devices:
@@ -67,14 +67,14 @@ Devices:
You can control the set of exposed CUDA devices with the `CUDA_VISIBLE_DEVICES` environment variable or the `--device` command line option. The following two commands have the same effect:
```bash
$ CUDA_VISIBLE_DEVICES=0 bin/rpc-server -p 50052
$ bin/rpc-server --device CUDA0 -p 50052
$ CUDA_VISIBLE_DEVICES=0 bin/ggml-rpc-server -p 50052
$ bin/ggml-rpc-server --device CUDA0 -p 50052
```
### Main host
On the main host build `llama.cpp` with the backends for the local devices and add `-DGGML_RPC=ON` to the build options.
Finally, when running `llama-cli` or `llama-server`, use the `--rpc` option to specify the host and port of each `rpc-server`:
Finally, when running `llama-cli` or `llama-server`, use the `--rpc` option to specify the host and port of each `ggml-rpc-server`:
```bash
$ llama-cli -hf ggml-org/gemma-3-1b-it-GGUF -ngl 99 --rpc 192.168.88.10:50052,192.168.88.11:50052
@@ -90,7 +90,7 @@ This can speed up model loading significantly, especially when using large model
To enable the cache, use the `-c` option:
```bash
$ bin/rpc-server -c
$ bin/ggml-rpc-server -c
```
By default, the cache is stored in the `$HOME/.cache/llama.cpp/rpc` directory and can be controlled via the `LLAMA_CACHE` environment variable.
@@ -103,8 +103,8 @@ RDMA is enabled by default when `libibverbs` is found at build time.
### Troubleshooting
Use the `GGML_RPC_DEBUG` environment variable to enable debug messages from `rpc-server`:
Use the `GGML_RPC_DEBUG` environment variable to enable debug messages from `ggml-rpc-server`:
```bash
$ GGML_RPC_DEBUG=1 bin/rpc-server
$ GGML_RPC_DEBUG=1 bin/ggml-rpc-server
```