Compare commits

...

5 Commits
b9110 ... b9115

Author SHA1 Message Date
Jesus Talavera
78fbbc2c07 convert : add split() to LoraTorchTensor in LoRA converter (#22832)
* convert : add split() method to LoraTorchTensor

* Fix python type-check

* Fix flake8 Lint

* fix: handle positional dim arg in torch.split dispatch

* Fix type-check again

* Fix type-checks

* Remove unit test per reviewers feedback

* work around ty deficiency

---------

Co-authored-by: Sigbjørn Skjæret <sigbjorn.skjaeret@scala.com>
2026-05-12 08:17:04 +03:00
guyfischman
da44953329 metal : promote mul_mv/mul_mm batch divisors to function constants (#22711)
* metal : promote mul_mv/mul_mm batch divisors to function constants

* metal : take op directly in get_pipeline_mul_mv_ext
2026-05-12 08:15:02 +03:00
Shawn Gu
1ec7ba0c14 opencl: add q4_1 MoE for Adreno (#22856)
* Q4_1 MoE CLC pass sanity check

* remove unnecessary code

* opencl: remove unnecessary asserts and reformat

* opencl: fix supports_op for q4_1 moe

* q4_1 moe is supported by Adreno with certain shapes

---------

Co-authored-by: Li He <lih@qti.qualcomm.com>
2026-05-11 11:57:26 -07:00
CrispStrobe
8e1f9d0834 CUDA: handle OW > 65535 in im2col (2D and 3D) (#22944)
`im2col_cuda` and `im2col_3d_cuda` both dispatch with
`block_nums.y = OW`. CUDA caps grid Y at 65535. Conv1d encoders on
raw 16 kHz audio with T > 65535 (~ 4 s) trip the limit -- e.g. SEANet
at 11 s lands at OW = 176000 -- and the launch returns
`invalid configuration argument`.

Clamp `block_nums.y` to `MIN(OW, MAX_GRIDDIM_Y)` and loop inside the
kernel with stride `MAX_GRIDDIM_Y`. Same in-kernel stride pattern
already used for the z axis (`MAX_GRIDDIM_Z`). Both 2D `im2col_kernel`
and 3D `im2col_3d_kernel` need the same fix. Bit-identical for
OW <= 65535 (single iteration of the new outer loop).

Tested on T4 / Jetson Orin with a SEANet encoder running on 11 s /
16 kHz audio (im2col reaching OW ~ 176000); pre-fix launch returns
`invalid configuration argument`, post-fix runs to completion.
Existing test-backend-ops im2col cases unchanged.
2026-05-11 19:48:29 +02:00
Pascal
e936660760 Ggml/cuda snake fusion hardening (#22912)
* cuda: tighten snake fusion type checks for all operands (defensive, sync vulkan)

* cuda: reject snake fusion when ne[2] or ne[3] > 1 (mirror vulkan PR review)

* cuda: merge type_ok and types_ok into a single types_ok (address am17an review)

* cuda: filter ADD/SUB/MUL/DIV in supports_op to F32/F16

bin_bcast only dispatches F32/F16 type triplets, mirror the
vulkan filter so unsupported types fall back through cpy
instead of aborting.

* test-backend-ops: extend snake_fuse to rank-4 with ne[2]/ne[3] > 1 cases
2026-05-11 18:42:08 +02:00
13 changed files with 1019 additions and 167 deletions

View File

@@ -188,6 +188,24 @@ class LoraTorchTensor:
def swapaxes(self, axis0: int, axis1: int) -> LoraTorchTensor:
return self.transpose(axis0, axis1)
def split(self, split_size: int | Sequence[int], dim: int = 0) -> tuple[LoraTorchTensor, ...]:
shape = self.shape
ndim = len(shape)
if dim < 0:
dim += ndim
if dim == ndim - 1:
A_chunks = self._lora_A.split(split_size, dim=-1)
return tuple(LoraTorchTensor(a, self._lora_B) for a in A_chunks)
elif dim == ndim - 2:
B_chunks = self._lora_B.split(split_size, dim=-2)
return tuple(LoraTorchTensor(self._lora_A, b) for b in B_chunks)
else:
B_chunks = self._lora_B.split(split_size, dim=dim)
if self._lora_A.shape[dim] == 1:
return tuple(LoraTorchTensor(self._lora_A, b) for b in B_chunks)
A_chunks = self._lora_A.split(split_size, dim=dim)
return tuple(LoraTorchTensor(a, b) for a, b in zip(A_chunks, B_chunks))
def to(self, *args, **kwargs):
return LoraTorchTensor(self._lora_A.to(*args, **kwargs), self._lora_B.to(*args, **kwargs))
@@ -230,6 +248,11 @@ class LoraTorchTensor:
)
else:
raise NotImplementedError
elif func is torch.split:
assert len(args) and len(args) >= 2
tensor, split_size = args[0], args[1]
dim = args[2] if len(args) > 2 else kwargs.get("dim", 0)
return tensor.split(split_size, dim=dim)
else:
raise NotImplementedError

View File

@@ -3929,10 +3929,25 @@ static int ggml_cuda_try_fuse(ggml_backend_cuda_context * cuda_ctx, ggml_cgraph
// closure check: the trailing add must read the same x as the leading mul
const ggml_tensor * x_in_add = (add->src[0] == mul1) ? add->src[1] : add->src[0];
const bool type_ok = (x->type == GGML_TYPE_F32 || x->type == GGML_TYPE_F16 || x->type == GGML_TYPE_BF16);
// Kernel iterates over total = T * C, so x and add must be 2D and
// a / inv_b must collapse to [1, C, 1, 1]. Higher dims are not handled.
const bool dim_ok = (x->ne[2] == 1 && x->ne[3] == 1) &&
(add->ne[2] == 1 && add->ne[3] == 1) &&
(a->ne[2] == 1 && a->ne[3] == 1);
const bool shape_ok = ggml_are_same_shape(a, inv_b) && a->ne[0] == 1 && a->ne[1] == x->ne[1];
if (type_ok && shape_ok && x_in_add == x && add->type == x->type) {
// x must be in the supported whitelist and every operand / intermediate
// result must share x's type, since launch_snake casts a / inv_b as
// float and templates the kernel on a single T. Mixed precision chains
// fall back to the naive path.
const ggml_tensor * sin1 = cgraph->nodes[i + 1];
const bool types_ok = (x->type == GGML_TYPE_F32 || x->type == GGML_TYPE_F16 || x->type == GGML_TYPE_BF16) &&
(a->type == x->type) && (inv_b->type == x->type) &&
(mul0->type == x->type) && (sin1->type == x->type) &&
(sqr->type == x->type) && (mul1->type == x->type) &&
(add->type == x->type);
if (types_ok && shape_ok && dim_ok && x_in_add == x) {
ggml_cuda_op_snake_fused(*cuda_ctx, x, a, inv_b, add);
return 4;
}
@@ -5291,12 +5306,8 @@ static bool ggml_backend_cuda_device_supports_op(ggml_backend_dev_t dev, const g
case GGML_OP_VIEW:
case GGML_OP_PERMUTE:
case GGML_OP_TRANSPOSE:
case GGML_OP_ADD:
case GGML_OP_ADD_ID:
case GGML_OP_ADD1:
case GGML_OP_SUB:
case GGML_OP_MUL:
case GGML_OP_DIV:
case GGML_OP_SCALE:
case GGML_OP_SQR:
case GGML_OP_SQRT:
@@ -5305,6 +5316,13 @@ static bool ggml_backend_cuda_device_supports_op(ggml_backend_dev_t dev, const g
case GGML_OP_CLAMP:
case GGML_OP_LOG:
return true;
case GGML_OP_ADD:
case GGML_OP_SUB:
case GGML_OP_MUL:
case GGML_OP_DIV:
return (op->src[0]->type == GGML_TYPE_F32 || op->src[0]->type == GGML_TYPE_F16) &&
(op->src[1]->type == GGML_TYPE_F32 || op->src[1]->type == GGML_TYPE_F16) &&
(op->type == GGML_TYPE_F32 || op->type == GGML_TYPE_F16);
case GGML_OP_SSM_SCAN: {
if (op->src[3]->ne[0] == 1) {
// Mamba2

View File

@@ -1,5 +1,6 @@
#include "im2col.cuh"
#define MAX_GRIDDIM_Y 65535
#define MAX_GRIDDIM_Z 65535
template <typename T>
@@ -18,22 +19,23 @@ static __global__ void im2col_kernel(
const int64_t ikh = rem / KW;
const int64_t ikw = rem - ikh * KW;
const int64_t iow = blockIdx.y;
for (int64_t iz = blockIdx.z; iz < N_OH; iz+=MAX_GRIDDIM_Z) {
const int64_t in = iz / OH;
const int64_t ioh = iz - in * OH;
for (int64_t iow = blockIdx.y; iow < OW; iow += MAX_GRIDDIM_Y) {
for (int64_t iz = blockIdx.z; iz < N_OH; iz += MAX_GRIDDIM_Z) {
const int64_t in = iz / OH;
const int64_t ioh = iz - in * OH;
const int64_t iiw = iow * s0 + ikw * d0 - p0;
const int64_t iih = ioh * s1 + ikh * d1 - p1;
const int64_t iiw = iow * s0 + ikw * d0 - p0;
const int64_t iih = ioh * s1 + ikh * d1 - p1;
const int64_t offset_dst =
((in * OH + ioh) * OW + iow) * IC_KH_KW + iic * KH_KW + ikh * KW + ikw;
const int64_t offset_dst =
((in * OH + ioh) * OW + iow) * IC_KH_KW + iic * KH_KW + ikh * KW + ikw;
if (iih < 0 || iih >= IH || iiw < 0 || iiw >= IW) {
dst[offset_dst] = 0.0f;
} else {
const int64_t offset_src = iic * IC_IH_IW + in * IH_IW;
dst[offset_dst] = x[offset_src + iih * IW + iiw];
if (iih < 0 || iih >= IH || iiw < 0 || iiw >= IW) {
dst[offset_dst] = 0.0f;
} else {
const int64_t offset_src = iic * IC_IH_IW + in * IH_IW;
dst[offset_dst] = x[offset_src + iih * IW + iiw];
}
}
}
@@ -51,7 +53,7 @@ static void im2col_cuda(const float * x, T* dst,
const int64_t num_blocks = (IC_KH_KW + CUDA_IM2COL_BLOCK_SIZE - 1) / CUDA_IM2COL_BLOCK_SIZE;
const int64_t N_OH = N * OH;
const int64_t KH_KW = KW*KH;
dim3 block_nums(num_blocks, OW, MIN(N_OH, MAX_GRIDDIM_Z));
dim3 block_nums(num_blocks, MIN(OW, MAX_GRIDDIM_Y), MIN(N_OH, MAX_GRIDDIM_Z));
im2col_kernel<<<block_nums, MIN(IC_KH_KW, CUDA_IM2COL_BLOCK_SIZE) , 0, stream>>>(x, dst, IC, IW, IH, OH, OW, KW, KH,
IC_IH_IW, IH_IW, N_OH, KH_KW, IC_KH_KW,
s0, s1, p0, p1, d0, d1);
@@ -136,23 +138,24 @@ static __global__ void im2col_3d_kernel(
const int64_t ikh = (i - iic * KD_KH_KW - ikd * KH_KW) / KW;
const int64_t ikw = i % KW;
const int64_t iow = blockIdx.y;
for (int64_t iz = blockIdx.z; iz < N_OD_OH; iz+=MAX_GRIDDIM_Z) {
const int64_t in = iz / OD_OH;
const int64_t iod = (iz - in*OD_OH) / OH;
const int64_t ioh = iz % OH;
for (int64_t iow = blockIdx.y; iow < OW; iow += MAX_GRIDDIM_Y) {
for (int64_t iz = blockIdx.z; iz < N_OD_OH; iz += MAX_GRIDDIM_Z) {
const int64_t in = iz / OD_OH;
const int64_t iod = (iz - in*OD_OH) / OH;
const int64_t ioh = iz % OH;
const int64_t iiw = iow * s0 + ikw * d0 - p0;
const int64_t iih = ioh * s1 + ikh * d1 - p1;
const int64_t iid = iod * s2 + ikd * d2 - p2;
const int64_t iiw = iow * s0 + ikw * d0 - p0;
const int64_t iih = ioh * s1 + ikh * d1 - p1;
const int64_t iid = iod * s2 + ikd * d2 - p2;
const int64_t offset_dst = in*OD_OH_OW_IC_KD_KH_KW + iod*OH_OW_IC_KD_KH_KW + ioh*OW_IC_KD_KH_KW + iow*IC_KD_KH_KW + iic*KD_KH_KW + ikd * KH_KW + ikh*KW + ikw;
const int64_t offset_dst = in*OD_OH_OW_IC_KD_KH_KW + iod*OH_OW_IC_KD_KH_KW + ioh*OW_IC_KD_KH_KW + iow*IC_KD_KH_KW + iic*KD_KH_KW + ikd * KH_KW + ikh*KW + ikw;
if (iih < 0 || iih >= IH || iiw < 0 || iiw >= IW || iid < 0 || iid >= ID) {
dst[offset_dst] = 0.0f;
} else {
const int64_t offset_src = ((in * IC + iic) * stride_q) + (iid * stride_z) + (iih * stride_y) + (iiw * stride_x);
dst[offset_dst] = src[offset_src];
if (iih < 0 || iih >= IH || iiw < 0 || iiw >= IW || iid < 0 || iid >= ID) {
dst[offset_dst] = 0.0f;
} else {
const int64_t offset_src = ((in * IC + iic) * stride_q) + (iid * stride_z) + (iih * stride_y) + (iiw * stride_x);
dst[offset_dst] = src[offset_src];
}
}
}
}
@@ -178,7 +181,7 @@ static void im2col_3d_cuda(const float * src, T* dst,
const int64_t OH_OW_IC_KD_KH_KW = OH*OW*IC*KD*KH*KW;
const int64_t OW_IC_KD_KH_KW = OW*IC*KD*KH*KW;
const int64_t num_blocks = (IC_KD_KH_KW + CUDA_IM2COL_BLOCK_SIZE - 1) / CUDA_IM2COL_BLOCK_SIZE;
dim3 block_nums(num_blocks, OW, MIN(N_OD_OH, MAX_GRIDDIM_Z));
dim3 block_nums(num_blocks, MIN(OW, MAX_GRIDDIM_Y), MIN(N_OD_OH, MAX_GRIDDIM_Z));
im2col_3d_kernel<<<block_nums, MIN(IC_KD_KH_KW, CUDA_IM2COL_BLOCK_SIZE) , 0, stream>>>(src, dst, N, IC, ID, IH, IW, OC, KD, KH, KW, OD, OH, OW,
OH_OW, KD_KH_KW, ID_IH_IW, KH_KW, IH_IW, IC_ID_IH_IW,
IC_KD_KH_KW, OW_KD_KH_KW, OD_OH_OW_IC_KD_KH_KW,

View File

@@ -647,19 +647,30 @@ ggml_metal_pipeline_with_params ggml_metal_library_get_pipeline_solve_tri(ggml_m
return res;
}
ggml_metal_pipeline_with_params ggml_metal_library_get_pipeline_mul_mv_ext(ggml_metal_library_t lib, ggml_type tsrc0, ggml_type tsrc1, int nsg, int nxpsg, int r1ptg) {
ggml_metal_pipeline_with_params ggml_metal_library_get_pipeline_mul_mv_ext(ggml_metal_library_t lib, const ggml_tensor * op, int nsg, int nxpsg, int r1ptg) {
char base[256];
char name[256];
const ggml_type tsrc0 = op->src[0]->type;
const ggml_type tsrc1 = op->src[1]->type;
const int ne12 = op->src[1]->ne[2];
const int r2 = ne12 / op->src[0]->ne[2];
const int r3 = op->src[1]->ne[3] / op->src[0]->ne[3];
GGML_ASSERT(ne12 <= INT16_MAX && r2 <= INT16_MAX && r3 <= INT16_MAX);
snprintf(base, 256, "kernel_mul_mv_ext_%s_%s_r1_%d", ggml_type_name(tsrc0), ggml_type_name(tsrc1), r1ptg);
snprintf(name, 256, "%s_nsg=%d_nxpsg=%d", base, nsg, nxpsg);
snprintf(name, 256, "%s_nsg=%d_nxpsg=%d_ne12=%d_r2=%d_r3=%d", base, nsg, nxpsg, ne12, r2, r3);
ggml_metal_pipeline_with_params res = ggml_metal_library_get_pipeline(lib, name);
if (!res.pipeline) {
ggml_metal_cv_t cv = ggml_metal_cv_init();
ggml_metal_cv_set_int16(cv, nsg, FC_MUL_MV + 0);
ggml_metal_cv_set_int16(cv, nxpsg, FC_MUL_MV + 1);
ggml_metal_cv_set_int16(cv, nsg, FC_MUL_MV + 0);
ggml_metal_cv_set_int16(cv, nxpsg, FC_MUL_MV + 1);
ggml_metal_cv_set_int16(cv, (int16_t) ne12, FC_MUL_MV + 2);
ggml_metal_cv_set_int16(cv, (int16_t) r2, FC_MUL_MV + 3);
ggml_metal_cv_set_int16(cv, (int16_t) r3, FC_MUL_MV + 4);
res = ggml_metal_library_compile_pipeline(lib, base, name, cv);
@@ -687,8 +698,15 @@ ggml_metal_pipeline_with_params ggml_metal_library_get_pipeline_mul_mm(ggml_meta
? (op->ne[0] % NRA != 0 || op->ne[1] % NRB != 0)
: (op->ne[0] % 64 != 0 || op->ne[1] % 32 != 0);
GGML_ASSERT(op->src[1]->ne[2] <= INT16_MAX && op->src[1]->ne[3] <= INT16_MAX);
const int16_t ne12 = (int16_t) op->src[1]->ne[2];
const int16_t ne13 = (int16_t) op->src[1]->ne[3];
const int16_t r2 = (int16_t) (ne12 / op->src[0]->ne[2]);
const int16_t r3 = (int16_t) (ne13 / op->src[0]->ne[3]);
snprintf(base, 256, "kernel_mul_mm_%s_%s", ggml_type_name(tsrc0), ggml_type_name(tsrc1));
snprintf(name, 256, "%s_bci=%d_bco=%d", base, bc_inp, bc_out);
snprintf(name, 256, "%s_bci=%d_bco=%d_ne12=%d_ne13=%d_r2=%d_r3=%d",
base, bc_inp, bc_out, ne12, ne13, r2, r3);
ggml_metal_pipeline_with_params res = ggml_metal_library_get_pipeline(lib, name);
if (!res.pipeline) {
@@ -696,6 +714,10 @@ ggml_metal_pipeline_with_params ggml_metal_library_get_pipeline_mul_mm(ggml_meta
ggml_metal_cv_set_bool(cv, bc_inp, FC_MUL_MM + 0);
ggml_metal_cv_set_bool(cv, bc_out, FC_MUL_MM + 1);
ggml_metal_cv_set_int16(cv, ne12, FC_MUL_MM + 2);
ggml_metal_cv_set_int16(cv, ne13, FC_MUL_MM + 3);
ggml_metal_cv_set_int16(cv, r2, FC_MUL_MM + 4);
ggml_metal_cv_set_int16(cv, r3, FC_MUL_MM + 5);
res = ggml_metal_library_compile_pipeline(lib, base, name, cv);
@@ -877,14 +899,21 @@ ggml_metal_pipeline_with_params ggml_metal_library_get_pipeline_mul_mv(ggml_meta
}
};
GGML_ASSERT(ne12 <= INT16_MAX && ne13 <= INT16_MAX);
const int16_t r2 = (int16_t) (ne12 / ne02);
const int16_t r3 = (int16_t) (ne13 / ne03);
snprintf(base, 256, "kernel_mul_mv_%s_%s%s", ggml_type_name(tsrc0), ggml_type_name(tsrc1), suffix);
snprintf(name, 256, "%s_nsg=%d", base, nsg);
snprintf(name, 256, "%s_nsg=%d_ne12=%d_r2=%d_r3=%d", base, nsg, ne12, r2, r3);
ggml_metal_pipeline_with_params res = ggml_metal_library_get_pipeline(lib, name);
if (!res.pipeline) {
ggml_metal_cv_t cv = ggml_metal_cv_init();
ggml_metal_cv_set_int16(cv, nsg, FC_MUL_MV + 0);
ggml_metal_cv_set_int16(cv, nsg, FC_MUL_MV + 0);
ggml_metal_cv_set_int16(cv, (int16_t) ne12, FC_MUL_MV + 2);
ggml_metal_cv_set_int16(cv, r2, FC_MUL_MV + 3);
ggml_metal_cv_set_int16(cv, r3, FC_MUL_MV + 4);
res = ggml_metal_library_compile_pipeline(lib, base, name, cv);
@@ -1102,6 +1131,9 @@ ggml_metal_pipeline_with_params ggml_metal_library_get_pipeline_mul_mv_id(ggml_m
ggml_metal_cv_t cv = ggml_metal_cv_init();
ggml_metal_cv_set_int16(cv, nsg, FC_MUL_MV + 0);
ggml_metal_cv_set_int16(cv, 1, FC_MUL_MV + 2);
ggml_metal_cv_set_int16(cv, 1, FC_MUL_MV + 3);
ggml_metal_cv_set_int16(cv, 1, FC_MUL_MV + 4);
res = ggml_metal_library_compile_pipeline(lib, base, name, cv);

View File

@@ -129,7 +129,7 @@ struct ggml_metal_pipeline_with_params ggml_metal_library_get_pipeline_ssm_scan
struct ggml_metal_pipeline_with_params ggml_metal_library_get_pipeline_rwkv (ggml_metal_library_t lib, const struct ggml_tensor * op);
struct ggml_metal_pipeline_with_params ggml_metal_library_get_pipeline_gated_delta_net (ggml_metal_library_t lib, const struct ggml_tensor * op);
struct ggml_metal_pipeline_with_params ggml_metal_library_get_pipeline_solve_tri (ggml_metal_library_t lib, const struct ggml_tensor * op);
struct ggml_metal_pipeline_with_params ggml_metal_library_get_pipeline_mul_mv_ext (ggml_metal_library_t lib, enum ggml_type tsrc0, enum ggml_type tsrc1, int nsg, int nxpsg, int r1ptg);
struct ggml_metal_pipeline_with_params ggml_metal_library_get_pipeline_mul_mv_ext (ggml_metal_library_t lib, const struct ggml_tensor * op, int nsg, int nxpsg, int r1ptg);
struct ggml_metal_pipeline_with_params ggml_metal_library_get_pipeline_mul_mm (ggml_metal_library_t lib, const struct ggml_tensor * op);
struct ggml_metal_pipeline_with_params ggml_metal_library_get_pipeline_mul_mv (ggml_metal_library_t lib, const struct ggml_tensor * op);
struct ggml_metal_pipeline_with_params ggml_metal_library_get_pipeline_mul_mm_id_map0 (ggml_metal_library_t lib, int ne02, int ne20);

View File

@@ -2120,7 +2120,7 @@ int ggml_metal_op_mul_mat(ggml_metal_op_t ctx, int idx) {
GGML_ABORT("unsupported ne11");
};
auto pipeline = ggml_metal_library_get_pipeline_mul_mv_ext(lib, op->src[0]->type, op->src[1]->type, nsg, nxpsg, r1ptg);
auto pipeline = ggml_metal_library_get_pipeline_mul_mv_ext(lib, op, nsg, nxpsg, r1ptg);
ggml_metal_kargs_mul_mv_ext args = {
/*.ne00 =*/ ne00,

View File

@@ -3353,6 +3353,9 @@ static inline void helper_mv_reduce_and_write(
constant short FC_mul_mv_nsg [[function_constant(FC_MUL_MV + 0)]];
constant short FC_mul_mv_nxpsg [[function_constant(FC_MUL_MV + 1)]];
constant short FC_mul_mv_ne12 [[function_constant(FC_MUL_MV + 2)]];
constant short FC_mul_mv_r2 [[function_constant(FC_MUL_MV + 3)]];
constant short FC_mul_mv_r3 [[function_constant(FC_MUL_MV + 4)]];
template<typename block_q_type, short NR0, typename args_t>
void mul_vec_q_n_f32_impl(
@@ -3376,10 +3379,10 @@ void mul_vec_q_n_f32_impl(
const int r1 = tgpig.y;
const int im = tgpig.z;
const uint i12 = im%args.ne12;
const uint i13 = im/args.ne12;
const uint i12 = im%FC_mul_mv_ne12;
const uint i13 = im/FC_mul_mv_ne12;
//const uint64_t offset0 = r0*args.nb01 + (i12/args.r2)*args.nb02 + (i13/args.r3)*args.nb03;
//const uint64_t offset0 = r0*args.nb01 + (i12/FC_mul_mv_r2)*args.nb02 + (i13/FC_mul_mv_r3)*args.nb03;
const uint64_t offset1 = r1*args.nb11 + (i12 )*args.nb12 + (i13 )*args.nb13;
//device const block_q_type * x = (device const block_q_type *) (src0 + offset0);
@@ -3388,7 +3391,7 @@ void mul_vec_q_n_f32_impl(
// pointers to src0 rows
device const block_q_type * ax[NR0];
FOR_UNROLL (int row = 0; row < NR0; ++row) {
const uint64_t offset0 = (r0 + row)*args.nb01 + (i12/args.r2)*args.nb02 + (i13/args.r3)*args.nb03;
const uint64_t offset0 = (r0 + row)*args.nb01 + (i12/FC_mul_mv_r2)*args.nb02 + (i13/FC_mul_mv_r3)*args.nb03;
ax[row] = (device const block_q_type *) ((device char *) src0 + offset0);
}
@@ -3462,8 +3465,8 @@ void kernel_mul_mv_q1_0_f32_impl(
const int first_row = (r0 * NSG + sgitg) * nr0;
const uint i12 = im%args.ne12;
const uint i13 = im/args.ne12;
const uint i12 = im%FC_mul_mv_ne12;
const uint i13 = im/FC_mul_mv_ne12;
const uint64_t offset1 = r1*args.nb11 + (i12)*args.nb12 + (i13)*args.nb13;
@@ -3471,7 +3474,7 @@ void kernel_mul_mv_q1_0_f32_impl(
device const block_q1_0 * ax[nr0];
for (int row = 0; row < nr0; ++row) {
const uint64_t offset0 = (first_row + row)*args.nb01 + (i12/args.r2)*args.nb02 + (i13/args.r3)*args.nb03;
const uint64_t offset0 = (first_row + row)*args.nb01 + (i12/FC_mul_mv_r2)*args.nb02 + (i13/FC_mul_mv_r3)*args.nb03;
ax[row] = (device const block_q1_0 *) ((device char *) src0 + offset0);
}
@@ -3590,10 +3593,10 @@ void kernel_mul_mv_q8_0_f32_impl(
const int r1 = tgpig.y;
const int im = tgpig.z;
const uint i12 = im%args.ne12;
const uint i13 = im/args.ne12;
const uint i12 = im%FC_mul_mv_ne12;
const uint i13 = im/FC_mul_mv_ne12;
//const uint64_t offset0 = r0*args.nb01 + (i12/args.r2)*args.nb02 + (i13/args.r3)*args.nb03;
//const uint64_t offset0 = r0*args.nb01 + (i12/FC_mul_mv_r2)*args.nb02 + (i13/FC_mul_mv_r3)*args.nb03;
const uint64_t offset1 = r1*args.nb11 + (i12 )*args.nb12 + (i13 )*args.nb13;
//device const block_q8_0 * x = (device const block_q8_0 *) (src0 + offset0);
@@ -3602,7 +3605,7 @@ void kernel_mul_mv_q8_0_f32_impl(
// pointers to src0 rows
device const block_q8_0 * ax[NR0];
FOR_UNROLL (short row = 0; row < NR0; ++row) {
const uint64_t offset0 = (r0 + row)*args.nb01 + (i12/args.r2)*args.nb02 + (i13/args.r3)*args.nb03;
const uint64_t offset0 = (r0 + row)*args.nb01 + (i12/FC_mul_mv_r2)*args.nb02 + (i13/FC_mul_mv_r3)*args.nb03;
ax[row] = (device const block_q8_0 *) ((device char *) src0 + offset0);
}
@@ -3682,10 +3685,10 @@ void kernel_mul_mv_ext_q4_f32_impl(
const int i11 = tgpig.y*r1ptg;
const int i1m = tgpig.z;
const int i12 = i1m%args.ne12;
const int i13 = i1m/args.ne12;
const int i12 = i1m%FC_mul_mv_ne12;
const int i13 = i1m/FC_mul_mv_ne12;
const uint64_t offset0 = i01*args.nb01 + (i12/args.r2)*args.nb02 + (i13/args.r3)*args.nb03;
const uint64_t offset0 = i01*args.nb01 + (i12/FC_mul_mv_r2)*args.nb02 + (i13/FC_mul_mv_r3)*args.nb03;
const uint64_t offset1 = i11*args.nb11 + (i12 )*args.nb12 + (i13 )*args.nb13;
device const q_t * xq = (i01 < args.ne01) ? (device const q_t *) (src0 + offset0) + tx/chpb : (device const q_t *) src0;
@@ -3785,10 +3788,10 @@ void kernel_mul_mv_ext_q4x4_f32_impl(
const int i11 = tgpig.y*r1ptg;
const int i1m = tgpig.z;
const int i12 = i1m%args.ne12;
const int i13 = i1m/args.ne12;
const int i12 = i1m%FC_mul_mv_ne12;
const int i13 = i1m/FC_mul_mv_ne12;
const uint64_t offset0 = i01*args.nb01 + (i12/args.r2)*args.nb02 + (i13/args.r3)*args.nb03;
const uint64_t offset0 = i01*args.nb01 + (i12/FC_mul_mv_r2)*args.nb02 + (i13/FC_mul_mv_r3)*args.nb03;
const uint64_t offset1 = i11*args.nb11 + (i12 )*args.nb12 + (i13 )*args.nb13;
device const q_t * xq = (i01 < args.ne01) ? (device const q_t *) (src0 + offset0) + tx/chpb : (device const q_t *) src0;
@@ -4000,10 +4003,10 @@ void kernel_mul_mv_t_t_impl(
const int r1 = tgpig.y;
const int im = tgpig.z;
const uint i12 = im%args.ne12;
const uint i13 = im/args.ne12;
const uint i12 = im%FC_mul_mv_ne12;
const uint i13 = im/FC_mul_mv_ne12;
//const uint64_t offset0 = r0*args.nb01 + (i12/args.r2)*args.nb02 + (i13/args.r3)*args.nb03;
//const uint64_t offset0 = r0*args.nb01 + (i12/FC_mul_mv_r2)*args.nb02 + (i13/FC_mul_mv_r3)*args.nb03;
const uint64_t offset1 = r1*args.nb11 + (i12 )*args.nb12 + (i13 )*args.nb13;
//device const T0 * x = (device const T0 *) (src0 + offset0);
@@ -4012,7 +4015,7 @@ void kernel_mul_mv_t_t_impl(
// pointers to src0 rows
device const T0 * ax [NR0];
FOR_UNROLL (short row = 0; row < NR0; ++row) {
const uint64_t offset0 = (r0 + row)*args.nb01 + (i12/args.r2)*args.nb02 + (i13/args.r3)*args.nb03;
const uint64_t offset0 = (r0 + row)*args.nb01 + (i12/FC_mul_mv_r2)*args.nb02 + (i13/FC_mul_mv_r3)*args.nb03;
ax[row] = (device const T0 *) ((device char *) src0 + offset0);
}
@@ -4122,10 +4125,10 @@ void kernel_mul_mv_t_t_4_impl(
const int r1 = tgpig.y;
const int im = tgpig.z;
const uint i12 = im%args.ne12;
const uint i13 = im/args.ne12;
const uint i12 = im%FC_mul_mv_ne12;
const uint i13 = im/FC_mul_mv_ne12;
//const uint64_t offset0 = r0*args.nb01 + (i12/args.r2)*args.nb02 + (i13/args.r3)*args.nb03;
//const uint64_t offset0 = r0*args.nb01 + (i12/FC_mul_mv_r2)*args.nb02 + (i13/FC_mul_mv_r3)*args.nb03;
const uint64_t offset1 = r1*args.nb11 + (i12 )*args.nb12 + (i13 )*args.nb13;
device const T1 * y = (device const T1 *) (src1 + offset1);
@@ -4135,7 +4138,7 @@ void kernel_mul_mv_t_t_4_impl(
device const T0 * ax [NR0];
device const T04 * ax4[NR0];
FOR_UNROLL (short row = 0; row < NR0; ++row) {
const uint64_t offset0 = (r0 + row)*args.nb01 + (i12/args.r2)*args.nb02 + (i13/args.r3)*args.nb03;
const uint64_t offset0 = (r0 + row)*args.nb01 + (i12/FC_mul_mv_r2)*args.nb02 + (i13/FC_mul_mv_r3)*args.nb03;
ax [row] = (device const T0 *) ((device char *) src0 + offset0);
ax4[row] = (device const T04 *) ((device char *) src0 + offset0);
@@ -4239,10 +4242,10 @@ void kernel_mul_mv_t_t_short_impl(
return;
}
const uint i12 = im%args.ne12;
const uint i13 = im/args.ne12;
const uint i12 = im%FC_mul_mv_ne12;
const uint i13 = im/FC_mul_mv_ne12;
const uint64_t offset0 = r0*args.nb01 + (i12/args.r2)*args.nb02 + (i13/args.r3)*args.nb03;
const uint64_t offset0 = r0*args.nb01 + (i12/FC_mul_mv_r2)*args.nb02 + (i13/FC_mul_mv_r3)*args.nb03;
device const T0 * x = (device const T0 *) (src0 + offset0);
@@ -7462,10 +7465,10 @@ void kernel_mul_mv_q2_K_f32_impl(
const int first_row = (r0 * NSG + sgitg) * nr0;
const uint i12 = im%args.ne12;
const uint i13 = im/args.ne12;
const uint i12 = im%FC_mul_mv_ne12;
const uint i13 = im/FC_mul_mv_ne12;
const uint64_t offset0 = first_row*args.nb01 + (i12/args.r2)*args.nb02 + (i13/args.r3)*args.nb03;
const uint64_t offset0 = first_row*args.nb01 + (i12/FC_mul_mv_r2)*args.nb02 + (i13/FC_mul_mv_r3)*args.nb03;
const uint64_t offset1 = r1*args.nb11 + (i12 )*args.nb12 + (i13 )*args.nb13;
device const block_q2_K * x = (device const block_q2_K *) (src0 + offset0);
@@ -7567,10 +7570,10 @@ void kernel_mul_mv_q3_K_f32_impl(
const int first_row = (r0 * NSG + sgitg) * nr0;
const uint i12 = im%args.ne12;
const uint i13 = im/args.ne12;
const uint i12 = im%FC_mul_mv_ne12;
const uint i13 = im/FC_mul_mv_ne12;
const uint64_t offset0 = first_row*args.nb01 + (i12/args.r2)*args.nb02 + (i13/args.r3)*args.nb03;
const uint64_t offset0 = first_row*args.nb01 + (i12/FC_mul_mv_r2)*args.nb02 + (i13/FC_mul_mv_r3)*args.nb03;
const uint64_t offset1 = r1*args.nb11 + (i12 )*args.nb12 + (i13 )*args.nb13;
device const block_q3_K * x = (device const block_q3_K *) (src0 + offset0);
@@ -7741,10 +7744,10 @@ void kernel_mul_mv_q4_K_f32_impl(
const int first_row = (r0 * NSG + sgitg) * nr0;
const uint i12 = im%args.ne12;
const uint i13 = im/args.ne12;
const uint i12 = im%FC_mul_mv_ne12;
const uint i13 = im/FC_mul_mv_ne12;
const uint64_t offset0 = first_row*args.nb01 + (i12/args.r2)*args.nb02 + (i13/args.r3)*args.nb03;
const uint64_t offset0 = first_row*args.nb01 + (i12/FC_mul_mv_r2)*args.nb02 + (i13/FC_mul_mv_r3)*args.nb03;
const uint64_t offset1 = r1*args.nb11 + (i12 )*args.nb12 + (i13 )*args.nb13;
device const block_q4_K * x = (device const block_q4_K *) (src0 + offset0);
@@ -7853,10 +7856,10 @@ void kernel_mul_mv_q5_K_f32_impl(
const int first_row = (r0 * NSG + sgitg) * nr0;
const uint i12 = im%args.ne12;
const uint i13 = im/args.ne12;
const uint i12 = im%FC_mul_mv_ne12;
const uint i13 = im/FC_mul_mv_ne12;
const uint64_t offset0 = first_row*args.nb01 + (i12/args.r2)*args.nb02 + (i13/args.r3)*args.nb03;
const uint64_t offset0 = first_row*args.nb01 + (i12/FC_mul_mv_r2)*args.nb02 + (i13/FC_mul_mv_r3)*args.nb03;
const uint64_t offset1 = r1*args.nb11 + (i12 )*args.nb12 + (i13 )*args.nb13;
device const block_q5_K * x = (device const block_q5_K *) (src0 + offset0);
@@ -7989,10 +7992,10 @@ void kernel_mul_mv_q6_K_f32_impl(
const int first_row = (r0 * NSG + sgitg) * nr0;
const uint i12 = im%args.ne12;
const uint i13 = im/args.ne12;
const uint i12 = im%FC_mul_mv_ne12;
const uint i13 = im/FC_mul_mv_ne12;
const uint64_t offset0 = first_row*args.nb01 + (i12/args.r2)*args.nb02 + (i13/args.r3)*args.nb03;
const uint64_t offset0 = first_row*args.nb01 + (i12/FC_mul_mv_r2)*args.nb02 + (i13/FC_mul_mv_r3)*args.nb03;
const uint64_t offset1 = r1*args.nb11 + (i12 )*args.nb12 + (i13 )*args.nb13;
device const block_q6_K * x = (device const block_q6_K *) (src0 + offset0);
@@ -8094,10 +8097,10 @@ void kernel_mul_mv_iq2_xxs_f32_impl(
const int first_row = (r0 * NSG + sgitg) * nr0;
const uint i12 = im%args.ne12;
const uint i13 = im/args.ne12;
const uint i12 = im%FC_mul_mv_ne12;
const uint i13 = im/FC_mul_mv_ne12;
const uint64_t offset0 = first_row*args.nb01 + (i12/args.r2)*args.nb02 + (i13/args.r3)*args.nb03;
const uint64_t offset0 = first_row*args.nb01 + (i12/FC_mul_mv_r2)*args.nb02 + (i13/FC_mul_mv_r3)*args.nb03;
const uint64_t offset1 = r1*args.nb11 + (i12 )*args.nb12 + (i13 )*args.nb13;
device const block_iq2_xxs * x = (device const block_iq2_xxs *) (src0 + offset0);
@@ -8202,10 +8205,10 @@ void kernel_mul_mv_iq2_xs_f32_impl(
const int first_row = (r0 * NSG + sgitg) * nr0;
const uint i12 = im%args.ne12;
const uint i13 = im/args.ne12;
const uint i12 = im%FC_mul_mv_ne12;
const uint i13 = im/FC_mul_mv_ne12;
const uint64_t offset0 = first_row*args.nb01 + (i12/args.r2)*args.nb02 + (i13/args.r3)*args.nb03;
const uint64_t offset0 = first_row*args.nb01 + (i12/FC_mul_mv_r2)*args.nb02 + (i13/FC_mul_mv_r3)*args.nb03;
const uint64_t offset1 = r1*args.nb11 + (i12 )*args.nb12 + (i13 )*args.nb13;
device const block_iq2_xs * x = (device const block_iq2_xs *) (src0 + offset0);
@@ -8321,10 +8324,10 @@ void kernel_mul_mv_iq3_xxs_f32_impl(
const int first_row = (r0 * NSG + sgitg) * nr0;
const uint i12 = im%args.ne12;
const uint i13 = im/args.ne12;
const uint i12 = im%FC_mul_mv_ne12;
const uint i13 = im/FC_mul_mv_ne12;
const uint64_t offset0 = first_row*args.nb01 + (i12/args.r2)*args.nb02 + (i13/args.r3)*args.nb03;
const uint64_t offset0 = first_row*args.nb01 + (i12/FC_mul_mv_r2)*args.nb02 + (i13/FC_mul_mv_r3)*args.nb03;
const uint64_t offset1 = r1*args.nb11 + (i12 )*args.nb12 + (i13 )*args.nb13;
device const block_iq3_xxs * x = (device const block_iq3_xxs *) (src0 + offset0);
@@ -8433,10 +8436,10 @@ void kernel_mul_mv_iq3_s_f32_impl(
const int first_row = (r0 * NSG + sgitg) * nr0;
const uint i12 = im%args.ne12;
const uint i13 = im/args.ne12;
const uint i12 = im%FC_mul_mv_ne12;
const uint i13 = im/FC_mul_mv_ne12;
const uint64_t offset0 = first_row*args.nb01 + (i12/args.r2)*args.nb02 + (i13/args.r3)*args.nb03;
const uint64_t offset0 = first_row*args.nb01 + (i12/FC_mul_mv_r2)*args.nb02 + (i13/FC_mul_mv_r3)*args.nb03;
const uint64_t offset1 = r1*args.nb11 + (i12 )*args.nb12 + (i13 )*args.nb13;
device const block_iq3_s * x = (device const block_iq3_s *) (src0 + offset0);
@@ -8545,10 +8548,10 @@ void kernel_mul_mv_iq2_s_f32_impl(
const int first_row = (r0 * NSG + sgitg) * nr0;
const uint i12 = im%args.ne12;
const uint i13 = im/args.ne12;
const uint i12 = im%FC_mul_mv_ne12;
const uint i13 = im/FC_mul_mv_ne12;
const uint64_t offset0 = first_row*args.nb01 + (i12/args.r2)*args.nb02 + (i13/args.r3)*args.nb03;
const uint64_t offset0 = first_row*args.nb01 + (i12/FC_mul_mv_r2)*args.nb02 + (i13/FC_mul_mv_r3)*args.nb03;
const uint64_t offset1 = r1*args.nb11 + (i12 )*args.nb12 + (i13 )*args.nb13;
device const block_iq2_s * x = (device const block_iq2_s *) (src0 + offset0);
@@ -8658,10 +8661,10 @@ void kernel_mul_mv_iq1_s_f32_impl(
const int first_row = (r0 * NSG + sgitg) * nr0;
const uint i12 = im%args.ne12;
const uint i13 = im/args.ne12;
const uint i12 = im%FC_mul_mv_ne12;
const uint i13 = im/FC_mul_mv_ne12;
const uint64_t offset0 = first_row*args.nb01 + (i12/args.r2)*args.nb02 + (i13/args.r3)*args.nb03;
const uint64_t offset0 = first_row*args.nb01 + (i12/FC_mul_mv_r2)*args.nb02 + (i13/FC_mul_mv_r3)*args.nb03;
const uint64_t offset1 = r1*args.nb11 + (i12 )*args.nb12 + (i13 )*args.nb13;
device const block_iq1_s * x = (device const block_iq1_s *) (src0 + offset0);
@@ -8757,10 +8760,10 @@ void kernel_mul_mv_iq1_m_f32_impl(
const int first_row = (r0 * NSG + sgitg) * nr0;
const uint i12 = im%args.ne12;
const uint i13 = im/args.ne12;
const uint i12 = im%FC_mul_mv_ne12;
const uint i13 = im/FC_mul_mv_ne12;
const uint64_t offset0 = first_row*args.nb01 + (i12/args.r2)*args.nb02 + (i13/args.r3)*args.nb03;
const uint64_t offset0 = first_row*args.nb01 + (i12/FC_mul_mv_r2)*args.nb02 + (i13/FC_mul_mv_r3)*args.nb03;
const uint64_t offset1 = r1*args.nb11 + (i12 )*args.nb12 + (i13 )*args.nb13;
device const block_iq1_m * x = (device const block_iq1_m *) (src0 + offset0);
@@ -8866,10 +8869,10 @@ void kernel_mul_mv_iq4_nl_f32_impl(
const int first_row = (r0 * NSG + sgitg) * NR0;
const uint i12 = im%args.ne12;
const uint i13 = im/args.ne12;
const uint i12 = im%FC_mul_mv_ne12;
const uint i13 = im/FC_mul_mv_ne12;
const uint64_t offset0 = first_row*args.nb01 + (i12/args.r2)*args.nb02 + (i13/args.r3)*args.nb03;
const uint64_t offset0 = first_row*args.nb01 + (i12/FC_mul_mv_r2)*args.nb02 + (i13/FC_mul_mv_r3)*args.nb03;
const uint64_t offset1 = r1*args.nb11 + (i12 )*args.nb12 + (i13 )*args.nb13;
device const block_iq4_nl * x = (device const block_iq4_nl *) (src0 + offset0);
@@ -8975,10 +8978,10 @@ void kernel_mul_mv_iq4_xs_f32_impl(
const int im = tgpig.z;
const int first_row = (r0 * NSG + sgitg) * NR0;
const uint i12 = im%args.ne12;
const uint i13 = im/args.ne12;
const uint i12 = im%FC_mul_mv_ne12;
const uint i13 = im/FC_mul_mv_ne12;
const uint64_t offset0 = first_row*args.nb01 + (i12/args.r2)*args.nb02 + (i13/args.r3)*args.nb03;
const uint64_t offset0 = first_row*args.nb01 + (i12/FC_mul_mv_r2)*args.nb02 + (i13/FC_mul_mv_r3)*args.nb03;
const uint64_t offset1 = r1*args.nb11 + (i12 )*args.nb12 + (i13 )*args.nb13;
device const block_iq4_xs * x = (device const block_iq4_xs *) (src0 + offset0);
@@ -9086,10 +9089,10 @@ void kernel_mul_mv_mxfp4_f32_impl(
const int first_row = (r0 * NSG + sgitg) * NR0;
const uint i12 = im%args.ne12;
const uint i13 = im/args.ne12;
const uint i12 = im%FC_mul_mv_ne12;
const uint i13 = im/FC_mul_mv_ne12;
const uint64_t offset0 = first_row*args.nb01 + (i12/args.r2)*args.nb02 + (i13/args.r3)*args.nb03;
const uint64_t offset0 = first_row*args.nb01 + (i12/FC_mul_mv_r2)*args.nb02 + (i13/FC_mul_mv_r3)*args.nb03;
const uint64_t offset1 = r1*args.nb11 + (i12 )*args.nb12 + (i13 )*args.nb13;
device const block_mxfp4 * x = (device const block_mxfp4 *) (src0 + offset0);
@@ -9304,6 +9307,10 @@ kernel void kernel_diag_f32(
constant bool FC_mul_mm_bc_inp [[function_constant(FC_MUL_MM + 0)]];
constant bool FC_mul_mm_bc_out [[function_constant(FC_MUL_MM + 1)]];
constant short FC_mul_mm_ne12 [[function_constant(FC_MUL_MM + 2)]];
constant short FC_mul_mm_ne13 [[function_constant(FC_MUL_MM + 3)]];
constant short FC_mul_mm_r2 [[function_constant(FC_MUL_MM + 4)]];
constant short FC_mul_mm_r3 [[function_constant(FC_MUL_MM + 5)]];
// each block_q contains 16*nl weights
#ifdef GGML_METAL_HAS_TENSOR
@@ -9330,11 +9337,11 @@ kernel void kernel_mul_mm(
// Batch dimension handling
const int im = tgpig.z;
const int i12 = im % args.ne12;
const int i13 = im / args.ne12;
const int i12 = im % FC_mul_mm_ne12;
const int i13 = im / FC_mul_mm_ne12;
// Batch offsets for srcA and srcB
const uint64_t offset0 = (i12/args.r2)*args.nb02 + (i13/args.r3)*args.nb03;
const uint64_t offset0 = (i12/FC_mul_mm_r2)*args.nb02 + (i13/FC_mul_mm_r3)*args.nb03;
// Tile dimensions
constexpr int NRB = SZ_SIMDGROUP * N_MM_BLOCK_X * N_MM_SIMD_GROUP_X;
@@ -9473,10 +9480,10 @@ kernel void kernel_mul_mm(
short il = il0;
const int i12 = im%args.ne12;
const int i13 = im/args.ne12;
const int i12 = im % FC_mul_mm_ne12;
const int i13 = im / FC_mul_mm_ne12;
const uint64_t offset0 = (i12/args.r2)*args.nb02 + (i13/args.r3)*args.nb03;
const uint64_t offset0 = (i12/FC_mul_mm_r2)*args.nb02 + (i13/FC_mul_mm_r3)*args.nb03;
const short offset1 = il0/nl;
device const block_q * x = (device const block_q *)(src0 + args.nb01*(r0 + lr0) + offset0) + offset1;

View File

@@ -104,6 +104,8 @@ set(GGML_OPENCL_KERNELS
mul_mv_id_mxfp4_f32_flat
gemm_moe_q4_0_f32_ns
gemv_moe_q4_0_f32_ns
gemm_moe_q4_1_f32_ns
gemv_moe_q4_1_f32_ns
gemm_moe_mxfp4_f32
gemv_moe_mxfp4_f32
gemm_moe_mxfp4_f32_ns

View File

@@ -544,6 +544,7 @@ struct ggml_backend_opencl_context {
cl_kernel kernel_convert_block_q4_0, kernel_restore_block_q4_0;
cl_kernel kernel_convert_block_q4_0_trans4_ns, kernel_restore_block_q4_0_trans4_ns;
cl_kernel kernel_convert_block_q4_1, kernel_restore_block_q4_1;
cl_kernel kernel_convert_block_q4_1_trans4_ns, kernel_restore_block_q4_1_trans4_ns;
cl_kernel kernel_convert_block_mxfp4, kernel_convert_block_mxfp4_trans, kernel_restore_block_mxfp4, kernel_restore_block_mxfp4_trans;
cl_kernel kernel_convert_block_mxfp4_trans4_ns, kernel_restore_block_mxfp4_trans4_ns;
cl_kernel kernel_convert_block_q8_0, kernel_restore_block_q8_0, kernel_restore_block_q8_0_trans;
@@ -602,6 +603,7 @@ struct ggml_backend_opencl_context {
cl_kernel kernel_ssm_conv_f32_f32, kernel_ssm_conv_f32_f32_4;
cl_kernel kernel_timestep_embedding;
cl_kernel kernel_gemv_moe_q4_0_f32_ns, kernel_gemm_moe_q4_0_f32_ns;
cl_kernel kernel_gemv_moe_q4_1_f32_ns, kernel_gemm_moe_q4_1_f32_ns;
cl_kernel kernel_gemv_moe_mxfp4_f32, kernel_gemm_moe_mxfp4_f32;
cl_kernel kernel_gemv_moe_mxfp4_f32_ns, kernel_gemm_moe_mxfp4_f32_ns;
cl_kernel kernel_moe_reorder_b;
@@ -958,6 +960,8 @@ static void load_cl_kernels(ggml_backend_opencl_context *backend_ctx, ggml_cl_ve
CL_CHECK((backend_ctx->kernel_restore_block_q4_1_noshuffle = clCreateKernel(backend_ctx->program_cvt, "kernel_restore_block_q4_1_noshuffle", &err), err));
CL_CHECK((backend_ctx->kernel_convert_block_q4_1 = clCreateKernel(backend_ctx->program_cvt, "kernel_convert_block_q4_1", &err), err));
CL_CHECK((backend_ctx->kernel_restore_block_q4_1 = clCreateKernel(backend_ctx->program_cvt, "kernel_restore_block_q4_1", &err), err));
CL_CHECK((backend_ctx->kernel_convert_block_q4_1_trans4_ns = clCreateKernel(backend_ctx->program_cvt, "kernel_convert_block_q4_1_trans4_ns", &err), err));
CL_CHECK((backend_ctx->kernel_restore_block_q4_1_trans4_ns = clCreateKernel(backend_ctx->program_cvt, "kernel_restore_block_q4_1_trans4_ns", &err), err));
CL_CHECK((backend_ctx->kernel_convert_block_mxfp4 = clCreateKernel(backend_ctx->program_cvt, "kernel_convert_block_mxfp4", &err), err));
CL_CHECK((backend_ctx->kernel_convert_block_mxfp4_trans = clCreateKernel(backend_ctx->program_cvt, "kernel_convert_block_mxfp4_trans", &err), err));
CL_CHECK((backend_ctx->kernel_convert_block_mxfp4_trans4_ns = clCreateKernel(backend_ctx->program_cvt, "kernel_convert_block_mxfp4_trans4_ns", &err), err));
@@ -2856,6 +2860,38 @@ static void load_cl_kernels(ggml_backend_opencl_context *backend_ctx, ggml_cl_ve
" -cl-mad-enable "
" -cl-fast-relaxed-math";
// gemv_moe_q4_1_f32_ns
{
#ifdef GGML_OPENCL_EMBED_KERNELS
const std::string kernel_src {
#include "gemv_moe_q4_1_f32_ns.cl.h"
};
#else
const std::string kernel_src = read_file("gemv_moe_q4_1_f32_ns.cl");
#endif
cl_program prog = build_program_from_source(backend_ctx->context, backend_ctx->device, kernel_src.c_str(), CL_moe_compile_opts);
CL_CHECK((backend_ctx->kernel_gemv_moe_q4_1_f32_ns = clCreateKernel(prog, "kernel_gemv_moe_q4_1_f32_ns", &err), err));
CL_CHECK(clReleaseProgram(prog));
GGML_LOG_CONT(".");
}
// gemm_moe_q4_1_f32_ns
{
#ifdef GGML_OPENCL_EMBED_KERNELS
const std::string kernel_src {
#include "gemm_moe_q4_1_f32_ns.cl.h"
};
#else
const std::string kernel_src = read_file("gemm_moe_q4_1_f32_ns.cl");
#endif
cl_program prog = build_program_from_source(backend_ctx->context, backend_ctx->device, kernel_src.c_str(), CL_moe_compile_opts);
CL_CHECK((backend_ctx->kernel_gemm_moe_q4_1_f32_ns = clCreateKernel(prog, "kernel_gemm_moe_q4_1_f32_ns", &err), err));
CL_CHECK(clReleaseProgram(prog));
GGML_LOG_CONT(".");
}
// gemv_moe_mxfp4_f32
{
#ifdef GGML_OPENCL_EMBED_KERNELS
@@ -3749,11 +3785,14 @@ struct ggml_tensor_extra_cl_q4_1 {
CL_CHECK(clReleaseMemObject(m));
m = nullptr;
}
if (q_img != nullptr) {
CL_CHECK(clReleaseMemObject(q_img));
q_img = nullptr;
}
// Currently, q_img and d_img are only initialized when SMALL_ALLOC is
// enabled. They point to the images in ggml_backend_opencl_buffer_context.
// So, there is no need to release them here.
// TODO: initialize them for non SMALL_PATH path, or remove them.
q_img = nullptr;
d_img = nullptr;
m_img = nullptr;
size_q = 0;
@@ -4189,6 +4228,35 @@ static ggml_status ggml_backend_opencl_graph_compute(ggml_backend_t backend, ggm
return GGML_STATUS_SUCCESS;
}
// The optimized gemm and gemv kernels are used for large matrices without batch.
// tensor is the quantized weights matrix.
inline bool use_adreno_kernels(const ggml_backend_opencl_context *backend_ctx, const ggml_tensor *tensor) {
int64_t threshold_ne0 = 512;
int64_t threshold_ne1 = 512;
if (!backend_ctx->adreno_cl_compiler_version.newer_than_or_same(E031, 38, 11, 0) &&
backend_ctx->adreno_cl_compiler_version.type != DX) {
threshold_ne0 = 128;
threshold_ne1 = 128;
}
return tensor->ne[0] >= threshold_ne0 && tensor->ne[1] >= threshold_ne1 &&
tensor->ne[2] == 1 && tensor->ne[3] == 1;
}
inline bool use_adreno_moe_kernels(const ggml_backend_opencl_context *backend_ctx, const ggml_tensor *tensor) {
GGML_UNUSED(backend_ctx);
int ne01 = tensor->ne[1];
return (((strstr(tensor->name, "ffn") != NULL) && (strstr(tensor->name, "exps") != NULL)) || (strstr(tensor->name, "as") != NULL)) && (ne01 % 64 == 0);
}
inline bool enable_adreno_trans_weight(const ggml_backend_opencl_context *backend_ctx, const ggml_tensor *tensor) {
bool adreno_kernel = use_adreno_kernels(backend_ctx, tensor);
size_t elem_num = tensor->ne[0] * tensor->ne[1] * tensor->ne[2] * tensor->ne[3];
return ((elem_num < 128 * 1024 * 1024) && adreno_kernel); // max element num: 2**27
}
static bool ggml_opencl_supports_op(ggml_backend_dev_t dev, const struct ggml_tensor * op) {
ggml_backend_opencl_device_context * dev_ctx = (ggml_backend_opencl_device_context *)dev->context;
ggml_backend_opencl_context * backend_ctx = dev_ctx->backend_ctx;
@@ -4385,6 +4453,18 @@ static bool ggml_opencl_supports_op(ggml_backend_dev_t dev, const struct ggml_te
return ggml_is_contiguous(op->src[0]) && ggml_is_contiguous(op->src[1]);
}
}
// q4_0, q8_0 and mxfp4 have general MUL_MAT_ID support,
// the quantizations here currently do not - they are only supported by Adreno with certain shapes
if (op->src[0]->type == GGML_TYPE_Q4_1) {
#ifdef GGML_OPENCL_USE_ADRENO_KERNELS
if (op->src[1]->type == GGML_TYPE_F32) {
return use_adreno_moe_kernels(backend_ctx, op->src[0])
&& ggml_is_contiguous(op->src[0])
&& ggml_is_contiguous(op->src[1]);
}
#endif
return false;
}
return false;
case GGML_OP_RESHAPE:
case GGML_OP_VIEW:
@@ -4555,6 +4635,12 @@ struct ggml_backend_opencl_buffer_context {
for (ggml_tensor_extra_cl_q4_0 * e : temp_tensor_extras_q4_0_in_use) {
delete e;
}
for (ggml_tensor_extra_cl_q4_1 * e : temp_tensor_extras_q4_1) {
delete e;
}
for (ggml_tensor_extra_cl_q4_1 * e : temp_tensor_extras_q4_1_in_use) {
delete e;
}
for (ggml_tensor_extra_cl_mxfp4 * e : temp_tensor_extras_mxfp4) {
delete e;
}
@@ -4868,35 +4954,6 @@ static enum ggml_status ggml_backend_opencl_buffer_init_tensor(ggml_backend_buff
return GGML_STATUS_SUCCESS;
}
// The optimized gemm and gemv kernels are used for large matrices without batch.
// tensor is the quantized weights matrix.
inline bool use_adreno_kernels(const ggml_backend_opencl_context *backend_ctx, const ggml_tensor *tensor) {
int64_t threshold_ne0 = 512;
int64_t threshold_ne1 = 512;
if (!backend_ctx->adreno_cl_compiler_version.newer_than_or_same(E031, 38, 11, 0) &&
backend_ctx->adreno_cl_compiler_version.type != DX) {
threshold_ne0 = 128;
threshold_ne1 = 128;
}
return tensor->ne[0] >= threshold_ne0 && tensor->ne[1] >= threshold_ne1 &&
tensor->ne[2] == 1 && tensor->ne[3] == 1;
}
inline bool use_adreno_moe_kernels(const ggml_backend_opencl_context *backend_ctx, const ggml_tensor *tensor) {
GGML_UNUSED(backend_ctx);
int ne01 = tensor->ne[1];
return (((strstr(tensor->name, "ffn") != NULL) && (strstr(tensor->name, "exps") != NULL)) || (strstr(tensor->name, "as") != NULL)) && (ne01 % 64 == 0);
}
inline bool enable_adreno_trans_weight(const ggml_backend_opencl_context *backend_ctx, const ggml_tensor *tensor) {
bool adreno_kernel = use_adreno_kernels(backend_ctx, tensor);
size_t elem_num = tensor->ne[0] * tensor->ne[1] * tensor->ne[2] * tensor->ne[3];
return ((elem_num < 128 * 1024 * 1024) && adreno_kernel); // max element num: 2**27
}
static void ggml_backend_opencl_buffer_set_tensor(ggml_backend_buffer_t buffer, ggml_tensor * tensor, const void * data, size_t offset, size_t size) {
ggml_backend_opencl_context *backend_ctx = ggml_cl2_init(buffer->buft->device);
@@ -5097,15 +5154,54 @@ static void ggml_backend_opencl_buffer_set_tensor(ggml_backend_buffer_t buffer,
CL_BUFFER_CREATE_TYPE_REGION, &region, &err);
CL_CHECK(err);
#ifdef GGML_OPENCL_USE_ADRENO_KERNELS
#ifdef GGML_OPENCL_USE_ADRENO_KERNELS
// Adreno moe q4_1 kernel needs special transpose and unshuffling
if (use_adreno_moe_kernels(backend_ctx, tensor)) {
cl_kernel kernel = backend_ctx->kernel_convert_block_q4_1_trans4_ns;
int ne00 = tensor->ne[0];
int ne01 = tensor->ne[1];
int ne02 = tensor->ne[2];
CL_CHECK(clSetKernelArg(kernel, 0, sizeof(cl_mem), &data_device));
CL_CHECK(clSetKernelArg(kernel, 1, sizeof(cl_mem), &extra->q));
CL_CHECK(clSetKernelArg(kernel, 2, sizeof(cl_mem), &extra->d));
CL_CHECK(clSetKernelArg(kernel, 3, sizeof(cl_mem), &extra->m));
CL_CHECK(clSetKernelArg(kernel, 4, sizeof(int), &ne00));
CL_CHECK(clSetKernelArg(kernel, 5, sizeof(int), &ne01));
size_t global_work_size[3] = {static_cast<size_t>(((ne01 + 63) / 64) * 64), static_cast<size_t>(ne00 / 32), static_cast<size_t>(ne02)};
size_t local_work_size[3] = {64, 2, 1};
cl_event evt;
CL_CHECK(clEnqueueNDRangeKernel(queue, kernel, 3, NULL, global_work_size, local_work_size, 0, NULL, &evt));
CL_CHECK(clWaitForEvents(1, &evt));
CL_CHECK(clReleaseMemObject(data_device));
// Create image for Q
cl_image_format img_format_q = {CL_R, CL_UNSIGNED_INT32};
cl_image_desc img_desc_q = {
CL_MEM_OBJECT_IMAGE1D_BUFFER,
static_cast<size_t>(ggml_nelements(tensor) / 8),
0, 0, 0, 0, 0, 0, 0,
{ extra->q }
};
extra->q_img = clCreateImage(context, CL_MEM_READ_ONLY, &img_format_q, &img_desc_q, NULL, &err);
tensor->extra = extra;
return;
}
#endif // GGML_OPENCL_USE_ADRENO_KERNELS
// normal q4_1 repack
#ifdef GGML_OPENCL_USE_ADRENO_KERNELS
cl_kernel kernel = backend_ctx->kernel_convert_block_q4_1;
if (use_adreno_kernels(backend_ctx, tensor)) {
kernel = backend_ctx->kernel_convert_block_q4_1_noshuffle;
}
#else
#else
cl_kernel kernel = backend_ctx->kernel_convert_block_q4_1;
#endif // GGML_OPENCL_USE_ADRENO_KERNELS
#endif // GGML_OPENCL_USE_ADRENO_KERNELS
CL_CHECK(clSetKernelArg(kernel, 0, sizeof(cl_mem), &data_device));
CL_CHECK(clSetKernelArg(kernel, 1, sizeof(cl_mem), &extra->q));
CL_CHECK(clSetKernelArg(kernel, 2, sizeof(cl_mem), &extra->d));
@@ -5862,6 +5958,36 @@ static void ggml_backend_opencl_buffer_get_tensor(ggml_backend_buffer_t buffer,
ggml_tensor_extra_cl_q4_1 * extra = (ggml_tensor_extra_cl_q4_1 *)tensor->extra;
#ifdef GGML_OPENCL_USE_ADRENO_KERNELS
if (use_adreno_moe_kernels(backend_ctx, tensor)) {
cl_int err;
cl_mem data_device = clCreateBuffer(context, CL_MEM_READ_WRITE,
ggml_nbytes(tensor), NULL, &err);
CL_CHECK(err);
cl_kernel kernel = backend_ctx->kernel_restore_block_q4_1_trans4_ns;
int ne00 = tensor->ne[0];
int ne01 = tensor->ne[1];
int ne02 = tensor->ne[2];
CL_CHECK(clSetKernelArg(kernel, 0, sizeof(cl_mem), &extra->q));
CL_CHECK(clSetKernelArg(kernel, 1, sizeof(cl_mem), &extra->d));
CL_CHECK(clSetKernelArg(kernel, 2, sizeof(cl_mem), &extra->m));
CL_CHECK(clSetKernelArg(kernel, 3, sizeof(cl_mem), &data_device));
CL_CHECK(clSetKernelArg(kernel, 4, sizeof(cl_int), &ne00));
CL_CHECK(clSetKernelArg(kernel, 5, sizeof(cl_int), &ne01));
size_t global_work_size[3] = {static_cast<size_t>(((ne01 + 63) / 64) * 64), static_cast<size_t>(ne00 / 32), static_cast<size_t>(ne02)};
size_t local_work_size[3] = {64, 2, 1};
cl_event evt;
CL_CHECK(clEnqueueNDRangeKernel(queue, kernel, 3, NULL,
global_work_size, local_work_size, 0, NULL, &evt));
CL_CHECK(clWaitForEvents(1, &evt));
CL_CHECK(clEnqueueReadBuffer(
queue, data_device, CL_TRUE, offset,
size, data, 0, NULL, NULL));
CL_CHECK(clReleaseMemObject(data_device));
return;
}
if (use_adreno_kernels(backend_ctx, tensor)) {
static ggml_cl_buffer buf_trans_q;
static ggml_cl_buffer buf_trans_m;
@@ -12862,6 +12988,7 @@ static void ggml_cl_mul_mat_id(ggml_backend_t backend, const ggml_tensor * src0,
#ifdef GGML_OPENCL_SOA_Q
ggml_tensor_extra_cl_q4_0 * extra0_q4_0 = (ggml_tensor_extra_cl_q4_0 *)src0->extra;
ggml_tensor_extra_cl_q4_1 * extra0_q4_1 = (ggml_tensor_extra_cl_q4_1 *)src0->extra;
ggml_tensor_extra_cl_mxfp4 * extra0_mxfp4 = (ggml_tensor_extra_cl_mxfp4 *)src0->extra;
ggml_tensor_extra_cl_q8_0 * extra0_q8_0 = (ggml_tensor_extra_cl_q8_0 *)src0->extra;
#endif
@@ -13131,6 +13258,179 @@ static void ggml_cl_mul_mat_id(ggml_backend_t backend, const ggml_tensor * src0,
break;
}
case GGML_TYPE_Q4_1: {
#ifdef GGML_OPENCL_USE_ADRENO_KERNELS
if (use_adreno_moe_kernels(backend_ctx, src0)) {
cl_int status;
size_t local_size[3] = {64, 2, 1};
size_t global_size[3] = {64, 2, 1};
if (ne12 == 1) { // for gemv
kernel = backend_ctx->kernel_gemv_moe_q4_1_f32_ns;
cl_mem src1_sub_buffer, buf_src1_image, buf_src2;
// create a sub_buffer for src2
cl_buffer_region region;
region.origin = offset2;
region.size = ne20 * ne21 * sizeof(int);
buf_src2 = clCreateSubBuffer(extra2->data_device, 0, CL_BUFFER_CREATE_TYPE_REGION, &region, &status);
CL_CHECK(status);
// set thread grid
global_size[0] = static_cast<size_t>(ne01);
global_size[1] = 4;
global_size[2] = static_cast<size_t>(ne20);
local_size[1] = 4;
// create a sub_buffer for src1
region.origin = offset1;
region.size = ne10 * ne11 * ne12 * sizeof(float);
src1_sub_buffer = clCreateSubBuffer(extra1->data_device, 0, CL_BUFFER_CREATE_TYPE_REGION, &region, &status);
CL_CHECK(status);
// create image for src1
cl_image_format image_format_buf_src1 = {CL_RGBA, CL_FLOAT};
cl_image_desc image_desc_buf_src1 = {CL_MEM_OBJECT_IMAGE1D_BUFFER, static_cast<size_t>(ne10 * ne11 * ne12 / 4), 0,0,0,0,0,0,0, {src1_sub_buffer}};
buf_src1_image = clCreateImage(backend_ctx->context, CL_MEM_READ_ONLY, &image_format_buf_src1, &image_desc_buf_src1, NULL, &status);
CL_CHECK(status);
// Set kernel args
int arg_idx = 0;
CL_CHECK(clSetKernelArg(kernel, arg_idx++, sizeof(cl_mem), &extra0_q4_1->q));
CL_CHECK(clSetKernelArg(kernel, arg_idx++, sizeof(cl_mem), &extra0_q4_1->d));
CL_CHECK(clSetKernelArg(kernel, arg_idx++, sizeof(cl_mem), &extra0_q4_1->m));
CL_CHECK(clSetKernelArg(kernel, arg_idx++, sizeof(cl_mem), &buf_src1_image));
CL_CHECK(clSetKernelArg(kernel, arg_idx++, sizeof(cl_mem), &buf_src2));
CL_CHECK(clSetKernelArg(kernel, arg_idx++, sizeof(cl_mem), &extrad->data_device));
CL_CHECK(clSetKernelArg(kernel, arg_idx++, sizeof(cl_ulong), &offsetd));
CL_CHECK(clSetKernelArg(kernel, arg_idx++, sizeof(int), &ne00));
CL_CHECK(clSetKernelArg(kernel, arg_idx++, sizeof(int), &ne01));
CL_CHECK(clSetKernelArg(kernel, arg_idx++, sizeof(int), &ne11));
// launch kernel
backend_ctx->enqueue_ndrange_kernel(kernel, 3, global_size, local_size, dst);
// deallocate sub buffers and images
CL_CHECK(clReleaseMemObject(src1_sub_buffer));
CL_CHECK(clReleaseMemObject(buf_src1_image));
CL_CHECK(clReleaseMemObject(buf_src2));
} else { // for gemm
kernel = backend_ctx->kernel_gemm_moe_q4_1_f32_ns;
if (strstr(src0->name, "as") != NULL) {
moe_router_reoerder(backend, src2, ne20);
}
cl_mem sub_buf_src1_pre, buf_src1_reordered, image_src1_reordered, sub_buf_dst, buf_dst_image;
cl_mem buf_src2, buf_src2_emap;
cl_buffer_region region;
region.origin = 0;
region.size = sizeof(int) * max_post_router_tile * n_tile_size;
buf_src2 = clCreateSubBuffer(backend_ctx->prealloc_post_router.buffer, 0, CL_BUFFER_CREATE_TYPE_REGION, &region, &status);
CL_CHECK(status);
region.origin = 0;
region.size = sizeof(short) * max_post_router_tile;
buf_src2_emap = clCreateSubBuffer(backend_ctx->prealloc_emap.buffer, 0, CL_BUFFER_CREATE_TYPE_REGION, &region, &status);
CL_CHECK(status);
// Reorder activations
// create a sub_buffer for src1
region.origin = offset1;
region.size = ne10 * ne11 * ne12 * sizeof(float);
sub_buf_src1_pre = clCreateSubBuffer(extra1->data_device, 0, CL_BUFFER_CREATE_TYPE_REGION, &region, &status);
CL_CHECK(status);
// Create image for reordered src1
// Use pre-allocated placeholder
region.origin = 0;
region.size = ne00 * max_post_router_tile * n_tile_size * sizeof(float);
backend_ctx->prealloc_act_trans.allocate(backend_ctx->context, region.size);
buf_src1_reordered = clCreateSubBuffer(
backend_ctx->prealloc_act_trans.buffer,
0,
CL_BUFFER_CREATE_TYPE_REGION,
&region,
&status);
CL_CHECK(status);
cl_image_format image_format_buf_src1;
cl_image_desc image_desc_buf_src1;
image_format_buf_src1 = {CL_RGBA, CL_FLOAT};
image_desc_buf_src1 = {CL_MEM_OBJECT_IMAGE1D_BUFFER, static_cast<size_t>(ne00 * max_post_router_tile * n_tile_size / 4), 0,0,0,0,0,0,0, {buf_src1_reordered}};
image_src1_reordered = clCreateImage(backend_ctx->context, CL_MEM_READ_ONLY, &image_format_buf_src1, &image_desc_buf_src1, NULL, &status);
CL_CHECK(status);
unsigned short map_ratio = ne20 / ne11;
GGML_ASSERT(((map_ratio == 1) || (map_ratio == ne20)) && "Map ratio not supported\n");
CL_CHECK(clSetKernelArg(backend_ctx->kernel_moe_reorder_b, 0, sizeof(cl_mem), &sub_buf_src1_pre));
CL_CHECK(clSetKernelArg(backend_ctx->kernel_moe_reorder_b, 1, sizeof(cl_mem), &buf_src2));
CL_CHECK(clSetKernelArg(backend_ctx->kernel_moe_reorder_b, 2, sizeof(cl_mem), &buf_src1_reordered));
CL_CHECK(clSetKernelArg(backend_ctx->kernel_moe_reorder_b, 3, sizeof(cl_mem), &(backend_ctx->prealloc_total_tiles.buffer)));
CL_CHECK(clSetKernelArg(backend_ctx->kernel_moe_reorder_b, 4, sizeof(unsigned int), &ne00));
CL_CHECK(clSetKernelArg(backend_ctx->kernel_moe_reorder_b, 5, sizeof(unsigned short), &map_ratio));
CL_CHECK(clSetKernelArg(backend_ctx->kernel_moe_reorder_b, 6, sizeof(unsigned int), &n_tile_size));
size_t reorder_b_local_size[3] = {256, 1, 1};
size_t reorder_b_global_size[3] = {static_cast<size_t>(((ne00 / 4) + 255) / 256 * 256), static_cast<size_t>(max_post_router_tile * n_tile_size), 1};
// Dispatch reorder kernel
backend_ctx->enqueue_ndrange_kernel(backend_ctx->kernel_moe_reorder_b, 3, reorder_b_global_size, reorder_b_local_size, dst);
// MoE kernel prepare
// Create sub buffer for dst
region.origin = offsetd;
region.size = ne0 * ne1 * ne2 * sizeof(float);
sub_buf_dst = clCreateSubBuffer(
extrad->data_device,
0,
CL_BUFFER_CREATE_TYPE_REGION,
&region,
&status);
CL_CHECK(status);
// Create image for dst
cl_image_format image_format_buf_dst = {CL_R, CL_FLOAT};
cl_image_desc image_desc_buf_dst = {CL_MEM_OBJECT_IMAGE1D_BUFFER, static_cast<size_t>(ne0 * ne1 * ne2), 0,0,0,0,0,0,0, {sub_buf_dst}};
buf_dst_image = clCreateImage(backend_ctx->context, CL_MEM_WRITE_ONLY, &image_format_buf_dst, &image_desc_buf_dst, NULL, &status);
CL_CHECK(status);
// Set kernel args
int arg_idx = 0;
CL_CHECK(clSetKernelArg(kernel, arg_idx++, sizeof(cl_mem), &extra0_q4_1->q_img));
CL_CHECK(clSetKernelArg(kernel, arg_idx++, sizeof(cl_mem), &extra0_q4_1->d));
CL_CHECK(clSetKernelArg(kernel, arg_idx++, sizeof(cl_mem), &extra0_q4_1->m));
CL_CHECK(clSetKernelArg(kernel, arg_idx++, sizeof(cl_mem), &image_src1_reordered));
CL_CHECK(clSetKernelArg(kernel, arg_idx++, sizeof(cl_mem), &buf_src2));
CL_CHECK(clSetKernelArg(kernel, arg_idx++, sizeof(cl_mem), &buf_src2_emap));
CL_CHECK(clSetKernelArg(kernel, arg_idx++, sizeof(cl_mem), &buf_dst_image));
CL_CHECK(clSetKernelArg(kernel, arg_idx++, sizeof(cl_mem), &(backend_ctx->prealloc_total_tiles.buffer)));
CL_CHECK(clSetKernelArg(kernel, arg_idx++, sizeof(int), &ne00));
CL_CHECK(clSetKernelArg(kernel, arg_idx++, sizeof(int), &ne01));
// set thread grid
global_size[1] = static_cast<size_t>((ne01 + 63) / 64);
global_size[2] = static_cast<size_t>(max_post_router_tile);
local_size[1] = 1;
local_size[2] = 1;
// Dispatch kernel
backend_ctx->enqueue_ndrange_kernel(kernel, 3, global_size, local_size, dst);
clReleaseMemObject(sub_buf_src1_pre);
clReleaseMemObject(buf_src1_reordered);
clReleaseMemObject(image_src1_reordered);
clReleaseMemObject(buf_src2);
clReleaseMemObject(buf_src2_emap);
clReleaseMemObject(sub_buf_dst);
clReleaseMemObject(buf_dst_image);
}
return;
}
#endif //GGML_OPENCL_USE_ADRENO_KERNELS
}
case GGML_TYPE_Q8_0: {
#ifdef GGML_OPENCL_SOA_Q
kernel = backend_ctx->kernel_mul_mv_id_q8_0_f32_flat;

View File

@@ -370,6 +370,96 @@ kernel void kernel_restore_block_q4_1_noshuffle(
}
}
kernel void kernel_convert_block_q4_1_trans4_ns(
__global struct block_q4_1 * src0,
__global uint * dst_q,
__global half * dst_d,
__global half * dst_m,
uint ne00,
uint ne01
) {
uint i00 = get_global_id(1);
uint i01 = get_global_id(0);
uint i02 = get_global_id(2);
uint ne00_blk = ne00 / QK4_1;
uint src_blk_offset = i00 + i01 * ne00_blk + i02 * ne00_blk * ne01;
uint dst_blk_offset = i01 + i00 * ne01 + i02 * ne00_blk * ne01;
global struct block_q4_1 * b = src0 + src_blk_offset;
dst_d[dst_blk_offset] = b->d;
dst_m[dst_blk_offset] = b->m;
// extract quantization and unshuffle
ushort8 pre_block = ((global ushort8 *)(&(b->qs[0])))[0];
ushort8 post_block = (ushort8)(0);
uchar * pre_block_ptr = (uchar *)(&pre_block);
uchar * post_block_ptr = (uchar *)(&post_block);
for (int i = 0; i < QK4_1 / 4; ++i) {
uchar x0 = pre_block_ptr[2*i + 0];
uchar x1 = pre_block_ptr[2*i + 1];
post_block_ptr[i + 0 ] = convert_uchar(x0 & 0x0F) | convert_uchar((x1 & 0x0F) << 4);
post_block_ptr[i + QK4_1 / 4] = convert_uchar((x0 & 0xF0) >> 4) | convert_uchar(x1 & 0xF0);
}
uint4 q_block = as_uint4(post_block);
uint offset = i02 * ne00_blk * ne01 * 4 + i00 * ne01 * 4 + i01;
dst_q[offset] = q_block.x;
dst_q[offset + ne01] = q_block.y;
dst_q[offset + ne01 * 2] = q_block.z;
dst_q[offset + ne01 * 3] = q_block.w;
}
kernel void kernel_restore_block_q4_1_trans4_ns(
__global uint * src_q,
__global half * src_d,
__global half * src_m,
__global struct block_q4_1 * dst0,
uint ne00,
uint ne01
) {
int i00 = get_global_id(1);
uint i01 = get_global_id(0);
uint i02 = get_global_id(2);
uint ne00_blk = ne00 / QK4_1;
uint dst_blk_offset = i00 + i01 * ne00_blk + i02 * ne00_blk * ne01;
uint src_dm_offset = i01 + i00 * ne01 + i02 * ne00_blk * ne01;
__global struct block_q4_1 * b = dst0 + dst_blk_offset;
b->d = src_d[src_dm_offset];
b->m = src_m[src_dm_offset];
// collect transposed quantization parts for a block
uint src_q_offset = i02 * ne00_blk * ne01 * 4 + i00 * ne01 * 4 + i01;
uint4 q_block;
q_block.x = src_q[src_q_offset];
q_block.y = src_q[src_q_offset + ne01];
q_block.z = src_q[src_q_offset + ne01 * 2];
q_block.w = src_q[src_q_offset + ne01 * 3];
ushort8 post_block = as_ushort8(q_block);
ushort8 pre_block = (ushort8)(0);
uchar * pre_block_ptr = (uchar *)(&pre_block);
uchar * post_block_ptr = (uchar *)(&post_block);
for (int i = 0; i < QK4_0 / 4; ++i) {
uchar x0 = post_block_ptr[i + 0];
uchar x1 = post_block_ptr[i + QK4_0 / 4];
pre_block_ptr[2 * i + 0] = convert_uchar(x0 & 0x0F) | convert_uchar((x1 & 0x0F) << 4);
pre_block_ptr[2 * i + 1] = convert_uchar((x0 & 0xF0) >> 4) | convert_uchar(x1 & 0xF0);
}
((__global ushort8 *)(&(b->qs[0])))[0] = pre_block;
}
//------------------------------------------------------------------------------
// block_mxfp4
//------------------------------------------------------------------------------

View File

@@ -0,0 +1,254 @@
#pragma OPENCL EXTENSION cl_khr_fp16 : enable
#pragma OPENCL EXTENSION cl_khr_subgroups : enable
#pragma OPENCL EXTENSION cl_qcom_subgroup_uniform_load: enable
#pragma OPENCL EXTENSION cl_qcom_subgroup_constant_load: enable
#pragma OPENCL EXTENSION cl_qcom_extra_vector_types : enable
#define TILESIZE_K 16
#define TILESIZE_M 64
#define TILESIZE_N 32
#define dequantize_q4_1(q4, a_f16, scale, m) \
a_f16.s0 = (half)(q4.s0 & 0x000F) * scale + m; \
a_f16.s1 = (half)((q4.s0 & 0x00F0) >> 4) * scale + m; \
a_f16.s2 = (half)((q4.s0 & 0x0F00) >> 8) * scale + m; \
a_f16.s3 = (half)((q4.s0 & 0xF000) >> 12) * scale + m; \
a_f16.s4 = (half)(q4.s1 & 0x000F) * scale + m; \
a_f16.s5 = (half)((q4.s1 & 0x00F0) >> 4) * scale + m; \
a_f16.s6 = (half)((q4.s1 & 0x0F00) >> 8) * scale + m; \
a_f16.s7 = (half)((q4.s1 & 0xF000) >> 12) * scale + m; \
a_f16.s8 = (half)(q4.s2 & 0x000F) * scale + m; \
a_f16.s9 = (half)((q4.s2 & 0x00F0) >> 4) * scale + m; \
a_f16.sa = (half)((q4.s2 & 0x0F00) >> 8) * scale + m; \
a_f16.sb = (half)((q4.s2 & 0xF000) >> 12) * scale + m; \
a_f16.sc = (half)(q4.s3 & 0x000F) * scale + m; \
a_f16.sd = (half)((q4.s3 & 0x00F0) >> 4) * scale + m; \
a_f16.se = (half)((q4.s3 & 0x0F00) >> 8) * scale + m; \
a_f16.sf = (half)((q4.s3 & 0xF000) >> 12) * scale + m; \
#define dotx16_reduce8(a_reg, b_lm, c_reg, lm_offset) \
acc.s0 = dot(a_reg.s0123, b_lm[lm_offset + 0]); \
acc.s1 = dot(a_reg.s0123, b_lm[lm_offset + 1]); \
acc.s2 = dot(a_reg.s0123, b_lm[lm_offset + 2]); \
acc.s3 = dot(a_reg.s0123, b_lm[lm_offset + 3]); \
acc.s4 = dot(a_reg.s0123, b_lm[lm_offset + 4]); \
acc.s5 = dot(a_reg.s0123, b_lm[lm_offset + 5]); \
acc.s6 = dot(a_reg.s0123, b_lm[lm_offset + 6]); \
acc.s7 = dot(a_reg.s0123, b_lm[lm_offset + 7]); \
acc.s8 = dot(a_reg.s0123, b_lm[lm_offset + 8]); \
acc.s9 = dot(a_reg.s0123, b_lm[lm_offset + 9]); \
acc.sa = dot(a_reg.s0123, b_lm[lm_offset + 10]); \
acc.sb = dot(a_reg.s0123, b_lm[lm_offset + 11]); \
acc.sc = dot(a_reg.s0123, b_lm[lm_offset + 12]); \
acc.sd = dot(a_reg.s0123, b_lm[lm_offset + 13]); \
acc.se = dot(a_reg.s0123, b_lm[lm_offset + 14]); \
acc.sf = dot(a_reg.s0123, b_lm[lm_offset + 15]); \
acc.s0 += dot(a_reg.s4567, b_lm[lm_offset + 32]); \
acc.s1 += dot(a_reg.s4567, b_lm[lm_offset + 33]); \
acc.s2 += dot(a_reg.s4567, b_lm[lm_offset + 34]); \
acc.s3 += dot(a_reg.s4567, b_lm[lm_offset + 35]); \
acc.s4 += dot(a_reg.s4567, b_lm[lm_offset + 36]); \
acc.s5 += dot(a_reg.s4567, b_lm[lm_offset + 37]); \
acc.s6 += dot(a_reg.s4567, b_lm[lm_offset + 38]); \
acc.s7 += dot(a_reg.s4567, b_lm[lm_offset + 39]); \
acc.s8 += dot(a_reg.s4567, b_lm[lm_offset + 40]); \
acc.s9 += dot(a_reg.s4567, b_lm[lm_offset + 41]); \
acc.sa += dot(a_reg.s4567, b_lm[lm_offset + 42]); \
acc.sb += dot(a_reg.s4567, b_lm[lm_offset + 43]); \
acc.sc += dot(a_reg.s4567, b_lm[lm_offset + 44]); \
acc.sd += dot(a_reg.s4567, b_lm[lm_offset + 45]); \
acc.se += dot(a_reg.s4567, b_lm[lm_offset + 46]); \
acc.sf += dot(a_reg.s4567, b_lm[lm_offset + 47]); \
c_reg.lo += convert_float8(acc.lo); \
c_reg.hi += convert_float8(acc.hi); \
acc.s0 = dot(a_reg.s89ab, b_lm[lm_offset + 64]); \
acc.s1 = dot(a_reg.s89ab, b_lm[lm_offset + 65]); \
acc.s2 = dot(a_reg.s89ab, b_lm[lm_offset + 66]); \
acc.s3 = dot(a_reg.s89ab, b_lm[lm_offset + 67]); \
acc.s4 = dot(a_reg.s89ab, b_lm[lm_offset + 68]); \
acc.s5 = dot(a_reg.s89ab, b_lm[lm_offset + 69]); \
acc.s6 = dot(a_reg.s89ab, b_lm[lm_offset + 70]); \
acc.s7 = dot(a_reg.s89ab, b_lm[lm_offset + 71]); \
acc.s8 = dot(a_reg.s89ab, b_lm[lm_offset + 72]); \
acc.s9 = dot(a_reg.s89ab, b_lm[lm_offset + 73]); \
acc.sa = dot(a_reg.s89ab, b_lm[lm_offset + 74]); \
acc.sb = dot(a_reg.s89ab, b_lm[lm_offset + 75]); \
acc.sc = dot(a_reg.s89ab, b_lm[lm_offset + 76]); \
acc.sd = dot(a_reg.s89ab, b_lm[lm_offset + 77]); \
acc.se = dot(a_reg.s89ab, b_lm[lm_offset + 78]); \
acc.sf = dot(a_reg.s89ab, b_lm[lm_offset + 79]); \
acc.s0 += dot(a_reg.scdef, b_lm[lm_offset + 96]); \
acc.s1 += dot(a_reg.scdef, b_lm[lm_offset + 97]); \
acc.s2 += dot(a_reg.scdef, b_lm[lm_offset + 98]); \
acc.s3 += dot(a_reg.scdef, b_lm[lm_offset + 99]); \
acc.s4 += dot(a_reg.scdef, b_lm[lm_offset + 100]); \
acc.s5 += dot(a_reg.scdef, b_lm[lm_offset + 101]); \
acc.s6 += dot(a_reg.scdef, b_lm[lm_offset + 102]); \
acc.s7 += dot(a_reg.scdef, b_lm[lm_offset + 103]); \
acc.s8 += dot(a_reg.scdef, b_lm[lm_offset + 104]); \
acc.s9 += dot(a_reg.scdef, b_lm[lm_offset + 105]); \
acc.sa += dot(a_reg.scdef, b_lm[lm_offset + 106]); \
acc.sb += dot(a_reg.scdef, b_lm[lm_offset + 107]); \
acc.sc += dot(a_reg.scdef, b_lm[lm_offset + 108]); \
acc.sd += dot(a_reg.scdef, b_lm[lm_offset + 109]); \
acc.se += dot(a_reg.scdef, b_lm[lm_offset + 110]); \
acc.sf += dot(a_reg.scdef, b_lm[lm_offset + 111]); \
c_reg.lo += convert_float8(acc.lo); \
c_reg.hi += convert_float8(acc.hi); \
__attribute__((qcom_wave_pair_mode(1))) // 1=force single 2=force pair
kernel void kernel_gemm_moe_q4_1_f32_ns(
__read_only image1d_buffer_t src0_q,
__global half * src0_d,
__global half * src0_m,
__read_only image1d_buffer_t src1,
__global uint * src2,
__global ushort * src2_emap,
__write_only image1d_buffer_t dst,
__global int * total_tiles,
uint ne00,
uint ne01
) {
uint block_id_m = get_global_id(1); // m_tile
uint block_id_n = get_global_id(2); // n_tile
// Boundary check
if (((get_global_id(0) + block_id_m * TILESIZE_M) >= ne01) || (block_id_n >= total_tiles[0])) {
return;
}
__private half16 reg_a;
__private float32 reg_c = (float32)(0);
__local half4 shared_b[128];
const ushort expert_id = src2_emap[block_id_n];
const uint row = block_id_m * TILESIZE_M;
const uint col = block_id_n * TILESIZE_N;
uint sub_block_id_m = get_local_id(0);
uint2 b_global_offset;
b_global_offset.x = ((sub_block_id_m & 3) << 2) + (sub_block_id_m >> 2) * ne00;
b_global_offset.y = b_global_offset.x + (16 * ne00);
uint2 b_local_offset;
b_local_offset.x = (sub_block_id_m & 3) * 32 + (sub_block_id_m >> 2);
b_local_offset.y = b_local_offset.x + 16;
// Loop along K axis, 32 elements (one block) for each iteration, divided into 2 sub-blocks
for (uint step = 0; step < ne00; step += TILESIZE_K * 2) {
// First sub-block
uint q_sub_offset = row + ((ne01 * step) >> 3) + ((expert_id * ne00 * ne01) >> 3);
uint s_sub_offset = row + ((ne01 * step) >> 5) + ((expert_id * ne00 * ne01) >> 5);
uint b_sub_offset = col * ne00 + step;
// Load scale and m for current Q4_1 block
uint sm_offset = s_sub_offset + get_global_id(0);
half s = src0_d[sm_offset];
half m = src0_m[sm_offset];
// Load 16 q (64-bits) in transposed layout
uint2 q4x16;
q4x16.x = read_imageui(src0_q, q_sub_offset + sub_block_id_m).x;
q4x16.y = read_imageui(src0_q, q_sub_offset + sub_block_id_m + ne01).x;
// Load 16x32 floats from matrix B, each fiber out of 64 in a sub-group loads 8 elements
float8 bx8_f32;
bx8_f32.lo = read_imagef(src1, (b_sub_offset + b_global_offset.x) / 4);
bx8_f32.hi = read_imagef(src1, (b_sub_offset + b_global_offset.y) / 4);
// Convert to half and store to LM to share within the subgroup
half8 bx8_f16 = convert_half8(bx8_f32);
shared_b[b_local_offset.x] = bx8_f16.lo;
shared_b[b_local_offset.y] = bx8_f16.hi;
// Dequantization
dequantize_q4_1(as_ushort4(q4x16), reg_a, s, m);
sub_group_barrier(CLK_LOCAL_MEM_FENCE);
// 32 16x16 fp16 dot product with 8 elements reduction for better precision
half16 acc;
dotx16_reduce8(reg_a, shared_b, reg_c.lo, 0);
dotx16_reduce8(reg_a, shared_b, reg_c.hi, 16);
// Repeat for second sub-block
uint half_step = step + TILESIZE_K;
q_sub_offset = row + ((ne01 * half_step) >> 3) + ((expert_id * ne00 * ne01) >> 3);
b_sub_offset = col * ne00 + half_step;
// Load next 16 q (64-bits) in transposed layout
q4x16.x = read_imageui(src0_q, q_sub_offset + sub_block_id_m).x;
q4x16.y = read_imageui(src0_q, q_sub_offset + sub_block_id_m + ne01).x;
// Load 16x32 floats from matrix B, each fiber out of 64 in a sub-group loads 8 elements
bx8_f32.lo = read_imagef(src1, (b_sub_offset + b_global_offset.x) / 4);
bx8_f32.hi = read_imagef(src1, (b_sub_offset + b_global_offset.y) / 4);
// Convert to half and store to LM to share within the subgroup
bx8_f16 = convert_half8(bx8_f32);
shared_b[b_local_offset.x] = bx8_f16.lo;
shared_b[b_local_offset.y] = bx8_f16.hi;
// Dequantization
dequantize_q4_1(as_ushort4(q4x16), reg_a, s, m);
sub_group_barrier(CLK_LOCAL_MEM_FENCE);
// 32 16x16 fp16 dot product with 3-levels reduction for better precision
dotx16_reduce8(reg_a, shared_b, reg_c.lo, 0);
dotx16_reduce8(reg_a, shared_b, reg_c.hi, 16);
}
// Load poster router and share in LM
__local uint out_idx[TILESIZE_N];
if (get_local_id(0) < TILESIZE_N) {
uint idx = src2[block_id_n * TILESIZE_N + get_local_id(0)];
if (idx == 0xFFFFFFFF) {
idx = src2[block_id_n * TILESIZE_N + 0];
}
out_idx[get_local_id(0)] = idx * ne01;
}
barrier(CLK_LOCAL_MEM_FENCE);
// Scatter results back to original position in output grid
uint m_offset = row + get_local_id(0);
write_imagef(dst, out_idx[1] + m_offset, (reg_c.s1));
write_imagef(dst, out_idx[2] + m_offset, (reg_c.s2));
write_imagef(dst, out_idx[3] + m_offset, (reg_c.s3));
write_imagef(dst, out_idx[4] + m_offset, (reg_c.s4));
write_imagef(dst, out_idx[5] + m_offset, (reg_c.s5));
write_imagef(dst, out_idx[6] + m_offset, (reg_c.s6));
write_imagef(dst, out_idx[7] + m_offset, (reg_c.s7));
write_imagef(dst, out_idx[8] + m_offset, (reg_c.s8));
write_imagef(dst, out_idx[9] + m_offset, (reg_c.s9));
write_imagef(dst, out_idx[10] + m_offset, (reg_c.sa));
write_imagef(dst, out_idx[11] + m_offset, (reg_c.sb));
write_imagef(dst, out_idx[12] + m_offset, (reg_c.sc));
write_imagef(dst, out_idx[13] + m_offset, (reg_c.sd));
write_imagef(dst, out_idx[14] + m_offset, (reg_c.se));
write_imagef(dst, out_idx[15] + m_offset, (reg_c.sf));
write_imagef(dst, out_idx[16] + m_offset, (reg_c.sg));
write_imagef(dst, out_idx[17] + m_offset, (reg_c.sh));
write_imagef(dst, out_idx[18] + m_offset, (reg_c.si));
write_imagef(dst, out_idx[19] + m_offset, (reg_c.sj));
write_imagef(dst, out_idx[20] + m_offset, (reg_c.sk));
write_imagef(dst, out_idx[21] + m_offset, (reg_c.sl));
write_imagef(dst, out_idx[22] + m_offset, (reg_c.sm));
write_imagef(dst, out_idx[23] + m_offset, (reg_c.sn));
write_imagef(dst, out_idx[24] + m_offset, (reg_c.so));
write_imagef(dst, out_idx[25] + m_offset, (reg_c.sp));
write_imagef(dst, out_idx[26] + m_offset, (reg_c.sq));
write_imagef(dst, out_idx[27] + m_offset, (reg_c.sr));
write_imagef(dst, out_idx[28] + m_offset, (reg_c.ss));
write_imagef(dst, out_idx[29] + m_offset, (reg_c.st));
write_imagef(dst, out_idx[30] + m_offset, (reg_c.su));
write_imagef(dst, out_idx[31] + m_offset, (reg_c.sv));
// Store zero padding parts to the index of first output in tile, override correct result in the end
barrier(CLK_GLOBAL_MEM_FENCE);
write_imagef(dst, out_idx[0] + m_offset, (reg_c.s0));
}

View File

@@ -0,0 +1,119 @@
#pragma OPENCL EXTENSION cl_khr_fp16 : enable
#pragma OPENCL EXTENSION cl_khr_subgroups : enable
#pragma OPENCL EXTENSION cl_qcom_reqd_sub_group_size : enable
#define QK_Q4_1 32
#define N_SIMDGROUP 4
#define SIMDGROUP_WIDTH 64
static inline float8 q4_1_to_fp32_packed8(ushort2 q4x8, half s, half m) {
float8 fp32x8;
fp32x8.s0 = (float)((q4x8.s0 & 0x000F) * s + m);
fp32x8.s1 = (float)(((q4x8.s0 & 0x00F0) >> 4) * s + m);
fp32x8.s2 = (float)(((q4x8.s0 & 0x0F00) >> 8) * s + m);
fp32x8.s3 = (float)(((q4x8.s0 & 0xF000) >> 12) * s + m);
fp32x8.s4 = (float)((q4x8.s1 & 0x000F) * s + m);
fp32x8.s5 = (float)(((q4x8.s1 & 0x00F0) >> 4) * s + m);
fp32x8.s6 = (float)(((q4x8.s1 & 0x0F00) >> 8) * s + m);
fp32x8.s7 = (float)(((q4x8.s1 & 0xF000) >> 12) * s + m);
return fp32x8;
}
__attribute__((qcom_reqd_sub_group_size("half")))
__kernel void kernel_gemv_moe_q4_1_f32_ns(
__global uint * src0_q,
__global half * src0_d,
__global half * src0_m,
__read_only image1d_buffer_t src1,
__global uint * src2,
__global float * dst,
ulong offsetd,
int ne00,
int ne01,
int ne11
) {
uint i01 = get_global_id(0);
uint i20 = get_global_id(2);
uint sgid = get_local_id(1);
uint slid = get_sub_group_local_id();
uint i11 = i20 % ne11;
uint expert_id = src2[i20];
uint expert_offset = expert_id * ne00 * ne01 / 32;
__private float sum = 0.0f; // each thread calculate partial sum of one output
// loop along ne00 in block granularity, skip 4 blocks every iter
for (uint ib00 = sgid; ib00 < (ne00 / QK_Q4_1); ib00 += N_SIMDGROUP) {
// load one block of q
uint4 regQ;
uint block_offset = expert_offset * 4 + ib00 * ne01 * 4 + i01;
regQ.s0 = src0_q[block_offset];
regQ.s1 = src0_q[block_offset + ne01];
regQ.s2 = src0_q[block_offset + ne01 * 2];
regQ.s3 = src0_q[block_offset + ne01 * 3];
uint offset = i11 * ne00 / 4 + ib00 * 8;
half regM = src0_m[ib00 * ne01 + i01 + expert_offset];
half regS = src0_d[ib00 * ne01 + i01 + expert_offset];
float8 fp32x8 = q4_1_to_fp32_packed8(as_ushort2(regQ.s0), regS, regM);
float4 shared_y4;
shared_y4 = read_imagef(src1, (offset + 0));
float4 acc = shared_y4 * fp32x8.lo;
shared_y4 = read_imagef(src1, (offset + 1));
acc += shared_y4 * fp32x8.hi;
fp32x8 = q4_1_to_fp32_packed8(as_ushort2(regQ.s1), regS, regM);
shared_y4 = read_imagef(src1, (offset + 2));
acc += shared_y4 * fp32x8.lo;
shared_y4 = read_imagef(src1, (offset + 3));
acc += shared_y4 * fp32x8.hi;
fp32x8 = q4_1_to_fp32_packed8(as_ushort2(regQ.s2), regS, regM);
shared_y4 = read_imagef(src1, (offset + 4));
acc += shared_y4 * fp32x8.lo;
shared_y4 = read_imagef(src1, (offset + 5));
acc += shared_y4 * fp32x8.hi;
fp32x8 = q4_1_to_fp32_packed8(as_ushort2(regQ.s3), regS, regM);
shared_y4 = read_imagef(src1, (offset + 6));
acc += shared_y4 * fp32x8.lo;
shared_y4 = read_imagef(src1, (offset + 7));
acc += shared_y4 * fp32x8.hi;
sum += ((acc.s0 + acc.s1) + (acc.s2 + acc.s3));
}
// reduction in local memory, assumes #subgroups=4
__local float reduceLM[SIMDGROUP_WIDTH * (N_SIMDGROUP - 1)];
if (sgid == 1) reduceLM[SIMDGROUP_WIDTH * 0 + slid] = sum;
if (sgid == 2) reduceLM[SIMDGROUP_WIDTH * 1 + slid] = sum;
if (sgid == 3) reduceLM[SIMDGROUP_WIDTH * 2 + slid] = sum;
barrier(CLK_LOCAL_MEM_FENCE);
if (sgid == 0) sum += reduceLM[SIMDGROUP_WIDTH * 0 + slid];
if (sgid == 0) sum += reduceLM[SIMDGROUP_WIDTH * 1 + slid];
if (sgid == 0) sum += reduceLM[SIMDGROUP_WIDTH * 2 + slid];
// 1 outputs per thread in subgroup 0
if (sgid == 0) {
dst = dst + (offsetd >> 2);
dst[i01 + i20 * ne01] = sum;
}
}

View File

@@ -3561,7 +3561,7 @@ struct test_relu_sqr : public test_case {
// and dispatches a single fused kernel.
struct test_snake_fuse : public test_case {
const ggml_type type;
const std::array<int64_t, 2> ne; // [T, C]
const std::array<int64_t, 4> ne; // [T, C, D2, D3]
std::string op_desc(ggml_tensor * t) override {
GGML_UNUSED(t);
@@ -3586,11 +3586,11 @@ struct test_snake_fuse : public test_case {
}
test_snake_fuse(ggml_type type = GGML_TYPE_F32,
std::array<int64_t, 2> ne = {256, 192})
std::array<int64_t, 4> ne = {256, 192, 1, 1})
: type(type), ne(ne) {}
ggml_tensor * build_graph(ggml_context * ctx) override {
ggml_tensor * x = ggml_new_tensor_2d(ctx, type, ne[0], ne[1]);
ggml_tensor * x = ggml_new_tensor_4d(ctx, type, ne[0], ne[1], ne[2], ne[3]);
ggml_set_name(x, "x");
ggml_tensor * a = ggml_new_tensor_2d(ctx, GGML_TYPE_F32, 1, ne[1]);
@@ -7558,11 +7558,15 @@ static std::vector<std::unique_ptr<test_case>> make_test_cases_eval() {
// SNAKE activation fusion: x + sin(a*x)^2 * inv_b
for (ggml_type type : { GGML_TYPE_F32, GGML_TYPE_F16, GGML_TYPE_BF16 }) {
test_cases.emplace_back(new test_snake_fuse(type, { 5, 7})); // primes sub-block
test_cases.emplace_back(new test_snake_fuse(type, { 33, 32})); // boundary
test_cases.emplace_back(new test_snake_fuse(type, {1025, 13})); // large prime, grid-stride
test_cases.emplace_back(new test_snake_fuse(type, { 128, 16})); // power-of-two
test_cases.emplace_back(new test_snake_fuse(type, { 256, 192})); // BigVGAN-ish
test_cases.emplace_back(new test_snake_fuse(type, { 5, 7, 1, 1})); // primes sub-block
test_cases.emplace_back(new test_snake_fuse(type, { 33, 32, 1, 1})); // boundary
test_cases.emplace_back(new test_snake_fuse(type, {1025, 13, 1, 1})); // large prime, grid-stride
test_cases.emplace_back(new test_snake_fuse(type, { 128, 16, 1, 1})); // power-of-two
test_cases.emplace_back(new test_snake_fuse(type, { 256, 192, 1, 1})); // BigVGAN-ish
// higher-rank shapes: matcher must reject fusion, fallback to naive chain
test_cases.emplace_back(new test_snake_fuse(type, { 64, 32, 2, 1})); // ne[2] > 1
test_cases.emplace_back(new test_snake_fuse(type, { 64, 32, 1, 2})); // ne[3] > 1
test_cases.emplace_back(new test_snake_fuse(type, { 64, 32, 2, 3})); // ne[2] > 1 and ne[3] > 1
}
// glu ops
@@ -9093,9 +9097,9 @@ static std::vector<std::unique_ptr<test_case>> make_test_cases_perf() {
test_cases.emplace_back(new test_pad_reflect_1d(GGML_TYPE_F32, {3000, 384, 4, 1}));
// SNAKE activation fusion at BigVGAN scale (T=7680 = 24 kHz x 320 ms, C=192)
test_cases.emplace_back(new test_snake_fuse(GGML_TYPE_F32, {7680, 192}));
test_cases.emplace_back(new test_snake_fuse(GGML_TYPE_F16, {7680, 192}));
test_cases.emplace_back(new test_snake_fuse(GGML_TYPE_BF16, {7680, 192}));
test_cases.emplace_back(new test_snake_fuse(GGML_TYPE_F32, {7680, 192, 1, 1}));
test_cases.emplace_back(new test_snake_fuse(GGML_TYPE_F16, {7680, 192, 1, 1}));
test_cases.emplace_back(new test_snake_fuse(GGML_TYPE_BF16, {7680, 192, 1, 1}));
test_cases.emplace_back(new test_mul_mat(GGML_TYPE_F16, GGML_TYPE_F32, 16416, 1, 128, {8, 1}, {4, 1}, {0, 2, 1, 3}));
test_cases.emplace_back(new test_mul_mat(GGML_TYPE_F16, GGML_TYPE_F32, 128, 1, 16416, {8, 1}, {4, 1}, {0, 1, 2, 3}, 2*16416));