Compare commits


6 Commits

Author: lucy | SHA1: d05fe1d7da | Date: 2026-05-02 22:19:25 +02:00
fix: CUDA device PCI bus ID de-dupe OOMing (ignoring other 3 gpus entirely) (#22533)
* fix: CUDA device PCI bus ID detection for multi-GPU de-dupe
* HIP, MUSA macros
---------
Co-authored-by: Johannes Gäßler <johannesg@5d6.de>

Author: Georgi Gerganov | SHA1: 0754b7b6fe | Date: 2026-05-02 18:03:25 +03:00
server : avoid checkpoint data host copies (#22558)
* server : avoid checkpoint data host copies
* llama : refactor llama_io_read_i

Author: JusteLeo | SHA1: 09294365a9 | Date: 2026-05-02 21:28:50 +08:00
ggml-virtgpu: fix circular dependency in headers (#22557)

Author: Csaba Kecskemeti | SHA1: 63d93d1733 | Date: 2026-05-02 09:05:59 +03:00
convert : disable uint types (#18908)

Author: Shawn Gu | SHA1: c5a3bc39b1 | Date: 2026-05-01 23:02:24 -07:00
opencl: Adreno optimization for MoE - MxFP4 (#22301)
* MoE Mxfp4 CLC kernel added, router reorder on GPU
* Pass test-backend-ops for MoE mxfp4 Adreno CLC
* remove putenv in llama-model.cpp
* fix indent style and whitespace
* opencl: remove unnecessary headers
* opencl: do not save cl_program objects
* opencl: remove unnecessary assert
* fix precision issue
---------
Co-authored-by: Li He <lih@qti.qualcomm.com>

Author: Johannes Gäßler | SHA1: 9dbb372610 | Date: 2026-05-02 07:56:13 +02:00
Github: update issue templates (#22594)
26 changed files with 1196 additions and 158 deletions

View File

@@ -12,6 +12,8 @@ body:
after recreating the CMake build directory and with `-DGGML_CCACHE=OFF`.
If the compilation succeeds with ccache disabled you should be able to permanently fix the issue
by clearing `~/.cache/ccache` (on Linux).
Please fill out this template yourself, copypasting language model outputs is [strictly prohibited](https://github.com/ggml-org/llama.cpp/blob/master/CONTRIBUTING.md#ai-usage-policy).
- type: textarea
id: commit
attributes:

View File

@@ -1,5 +1,5 @@
name: Bug (model use)
description: Something goes wrong when using a model (in general, not specific to a single llama.cpp module).
description: Something goes wrong when running a model (crashes, garbled outputs, etc.).
title: "Eval bug: "
labels: ["bug-unconfirmed", "model evaluation"]
body:
@@ -12,6 +12,8 @@ body:
If you encountered the issue while using an external UI (e.g. ollama),
please reproduce your issue using one of the examples/binaries in this repository.
The `llama-completion` binary can be used for simple and reproducible model inference.
Please fill out this template yourself, copypasting language model outputs is [strictly prohibited](https://github.com/ggml-org/llama.cpp/blob/master/CONTRIBUTING.md#ai-usage-policy).
- type: textarea
id: version
attributes:

View File

@@ -10,6 +10,8 @@ body:
This issue template is intended for miscellaneous bugs that don't fit into any other category.
If you encountered the issue while using an external UI (e.g. ollama),
please reproduce your issue using one of the examples/binaries in this repository.
Please fill out this template yourself, copypasting language model outputs is [strictly prohibited](https://github.com/ggml-org/llama.cpp/blob/master/CONTRIBUTING.md#ai-usage-policy).
- type: textarea
id: version
attributes:

View File

@@ -8,6 +8,8 @@ body:
value: |
[Please post your idea first in Discussion if there is not yet a consensus for this enhancement request. This will help to keep this issue tracker focused on enhancements that the community has agreed needs to be implemented.](https://github.com/ggml-org/llama.cpp/discussions/categories/ideas)
Please fill out this template yourself, copypasting language model outputs is [strictly prohibited](https://github.com/ggml-org/llama.cpp/blob/master/CONTRIBUTING.md#ai-usage-policy).
- type: checkboxes
id: prerequisites
attributes:

View File

@@ -8,6 +8,8 @@ body:
value: |
Don't forget to check for any [duplicate research issue tickets](https://github.com/ggml-org/llama.cpp/issues?q=is%3Aopen+is%3Aissue+label%3A%22research+%F0%9F%94%AC%22)
Please fill out this template yourself, copypasting language model outputs is [strictly prohibited](https://github.com/ggml-org/llama.cpp/blob/master/CONTRIBUTING.md#ai-usage-policy).
- type: checkboxes
id: research-stage
attributes:

View File

@@ -9,6 +9,8 @@ body:
Don't forget to [check for existing refactor issue tickets](https://github.com/ggml-org/llama.cpp/issues?q=is%3Aopen+is%3Aissue+label%3Arefactoring) in case it's already covered.
Also you may want to check [Pull request refactor label as well](https://github.com/ggml-org/llama.cpp/pulls?q=is%3Aopen+is%3Apr+label%3Arefactoring) for duplicates too.
Please fill out this template yourself, copypasting language model outputs is [strictly prohibited](https://github.com/ggml-org/llama.cpp/blob/master/CONTRIBUTING.md#ai-usage-policy).
- type: textarea
id: background-description
attributes:

View File

@@ -13232,17 +13232,18 @@ class LazyTorchTensor(gguf.LazyBase):
}
# only used when byteswapping data. Only correct size is needed
# TODO: uncomment uint64, uint32, and uint16, ref: https://github.com/pytorch/pytorch/issues/58734
_dtype_byteswap_map: dict[torch.dtype, type] = {
torch.float64: np.float64,
torch.float32: np.float32,
torch.bfloat16: np.float16,
torch.float16: np.float16,
torch.int64: np.int64,
torch.uint64: np.uint64,
# torch.uint64: np.uint64,
torch.int32: np.int32,
torch.uint32: np.uint32,
# torch.uint32: np.uint32,
torch.int16: np.int16,
torch.uint16: np.uint16,
# torch.uint16: np.uint16,
torch.int8: np.int8,
torch.uint8: np.uint8,
torch.bool: np.uint8,

View File

@@ -5431,8 +5431,8 @@ ggml_backend_reg_t ggml_backend_cuda_reg() {
CUDA_CHECK(cudaGetDeviceProperties(&prop, i));
dev_ctx->description = prop.name;
char pci_bus_id[16] = {};
snprintf(pci_bus_id, sizeof(pci_bus_id), "%04x:%02x:%02x.0", prop.pciDomainID, prop.pciBusID, prop.pciDeviceID);
char pci_bus_id[32] = {};
CUDA_CHECK(cudaDeviceGetPCIBusId(pci_bus_id, sizeof(pci_bus_id), i));
dev_ctx->pci_bus_id = pci_bus_id;
dev_ctx->op_offload_min_batch_size = min_batch_size;
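The change above swaps a hand-built snprintf string for the canonical PCI bus ID reported by the CUDA runtime, which is what the multi-GPU de-dupe compares. A minimal standalone sketch of the call, not taken from the backend (device loop, buffer size and printing are illustrative):
#include <cstdio>
#include <cuda_runtime.h>

int main() {
    int n_devices = 0;
    cudaGetDeviceCount(&n_devices);
    for (int i = 0; i < n_devices; ++i) {
        char pci_bus_id[32] = {};
        // the runtime writes the canonical "domain:bus:device.function" string,
        // so distinct physical GPUs always get distinct strings and string-based
        // de-duplication cannot collapse them into a single entry
        if (cudaDeviceGetPCIBusId(pci_bus_id, sizeof(pci_bus_id), i) == cudaSuccess) {
            printf("device %d: %s\n", i, pci_bus_id);
        }
    }
    return 0;
}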

View File

@@ -55,6 +55,7 @@
#define cudaDeviceDisablePeerAccess hipDeviceDisablePeerAccess
#define cudaDeviceEnablePeerAccess hipDeviceEnablePeerAccess
#define cudaDeviceGetAttribute hipDeviceGetAttribute
#define cudaDeviceGetPCIBusId hipDeviceGetPCIBusId
#define cudaDeviceProp hipDeviceProp_t
#define cudaDeviceSynchronize hipDeviceSynchronize
#define cudaError_t hipError_t

View File

@@ -39,6 +39,7 @@
#define cudaDeviceCanAccessPeer musaDeviceCanAccessPeer
#define cudaDeviceDisablePeerAccess musaDeviceDisablePeerAccess
#define cudaDeviceEnablePeerAccess musaDeviceEnablePeerAccess
#define cudaDeviceGetPCIBusId musaDeviceGetPCIBusId
#define cudaDeviceProp musaDeviceProp
#define cudaDeviceSynchronize musaDeviceSynchronize
#define cudaError_t musaError_t

View File

@@ -107,6 +107,10 @@ set(GGML_OPENCL_KERNELS
mul_mv_id_mxfp4_f32_flat
gemm_moe_mxfp4_f32
gemv_moe_mxfp4_f32
gemm_moe_mxfp4_f32_ns
gemv_moe_mxfp4_f32_ns
moe_reorder_b
moe_sort_by_expert
mul_mm_f32_f32_l4_lm
mul_mm_f16_f32_l4_lm
mul_mm_q4_0_f32_l4_lm

View File

@@ -416,6 +416,15 @@ struct ggml_backend_opencl_context {
ggml_cl_buffer prealloc_src0;
ggml_cl_buffer prealloc_src1;
// prealloc buffers for MoE router table preprocess
bool toggle_reorder = false;
ggml_cl_buffer prealloc_post_router;
ggml_cl_buffer prealloc_emap;
ggml_cl_buffer prealloc_hist;
ggml_cl_buffer prealloc_tile_offset;
ggml_cl_buffer prealloc_total_tiles;
ggml_cl_buffer prealloc_slot_counter;
cl_program program_add;
cl_program program_add_id;
cl_program program_clamp;
@@ -531,6 +540,7 @@ struct ggml_backend_opencl_context {
cl_kernel kernel_convert_block_q4_0, kernel_restore_block_q4_0;
cl_kernel kernel_convert_block_q4_1, kernel_restore_block_q4_1;
cl_kernel kernel_convert_block_mxfp4, kernel_convert_block_mxfp4_trans, kernel_restore_block_mxfp4, kernel_restore_block_mxfp4_trans;
cl_kernel kernel_convert_block_mxfp4_trans4_ns, kernel_restore_block_mxfp4_trans4_ns;
cl_kernel kernel_convert_block_q8_0, kernel_restore_block_q8_0, kernel_restore_block_q8_0_trans;
cl_kernel kernel_convert_block_q6_K_noshuffle, kernel_restore_block_q6_K_noshuffle;
cl_kernel kernel_mul_mat_q4_0_f32_8x_flat;
@@ -587,6 +597,9 @@ struct ggml_backend_opencl_context {
cl_kernel kernel_ssm_conv_f32_f32, kernel_ssm_conv_f32_f32_4;
cl_kernel kernel_timestep_embedding;
cl_kernel kernel_gemv_moe_mxfp4_f32, kernel_gemm_moe_mxfp4_f32;
cl_kernel kernel_gemv_moe_mxfp4_f32_ns, kernel_gemm_moe_mxfp4_f32_ns;
cl_kernel kernel_moe_reorder_b;
cl_kernel kernel_moe_histogram, kernel_moe_scan, kernel_moe_fill, kernel_moe_scatter;
cl_kernel kernel_mul_mv_id_q4_0_f32_8x_flat;
cl_kernel kernel_mul_mv_id_q8_0_f32, kernel_mul_mv_id_q8_0_f32_flat;
cl_kernel kernel_mul_mv_id_mxfp4_f32;
@@ -945,6 +958,8 @@ static void load_cl_kernels(ggml_backend_opencl_context *backend_ctx, ggml_cl_ve
CL_CHECK((backend_ctx->kernel_restore_block_q4_1 = clCreateKernel(backend_ctx->program_cvt, "kernel_restore_block_q4_1", &err), err));
CL_CHECK((backend_ctx->kernel_convert_block_mxfp4 = clCreateKernel(backend_ctx->program_cvt, "kernel_convert_block_mxfp4", &err), err));
CL_CHECK((backend_ctx->kernel_convert_block_mxfp4_trans = clCreateKernel(backend_ctx->program_cvt, "kernel_convert_block_mxfp4_trans", &err), err));
CL_CHECK((backend_ctx->kernel_convert_block_mxfp4_trans4_ns = clCreateKernel(backend_ctx->program_cvt, "kernel_convert_block_mxfp4_trans4_ns", &err), err));
CL_CHECK((backend_ctx->kernel_restore_block_mxfp4_trans4_ns = clCreateKernel(backend_ctx->program_cvt, "kernel_restore_block_mxfp4_trans4_ns", &err), err));
CL_CHECK((backend_ctx->kernel_restore_block_mxfp4_trans = clCreateKernel(backend_ctx->program_cvt, "kernel_restore_block_mxfp4_trans", &err), err));
CL_CHECK((backend_ctx->kernel_restore_block_mxfp4 = clCreateKernel(backend_ctx->program_cvt, "kernel_restore_block_mxfp4", &err), err));
CL_CHECK((backend_ctx->kernel_convert_block_q8_0 = clCreateKernel(backend_ctx->program_cvt, "kernel_convert_block_q8_0", &err), err));
@@ -2864,6 +2879,77 @@ static void load_cl_kernels(ggml_backend_opencl_context *backend_ctx, ggml_cl_ve
GGML_LOG_CONT(".");
}
// gemv_moe_mxfp4_f32_ns
{
#ifdef GGML_OPENCL_EMBED_KERNELS
const std::string kernel_src {
#include "gemv_moe_mxfp4_f32_ns.cl.h"
};
#else
const std::string kernel_src = read_file("gemv_moe_mxfp4_f32_ns.cl");
#endif
cl_program prog =
build_program_from_source(backend_ctx->context, backend_ctx->device, kernel_src.c_str(), CL_moe_compile_opts);
CL_CHECK((backend_ctx->kernel_gemv_moe_mxfp4_f32_ns = clCreateKernel(prog, "kernel_gemv_moe_mxfp4_f32_ns", &err), err));
CL_CHECK(clReleaseProgram(prog));
GGML_LOG_CONT(".");
}
// gemm_moe_mxfp4_f32_ns
{
#ifdef GGML_OPENCL_EMBED_KERNELS
const std::string kernel_src {
#include "gemm_moe_mxfp4_f32_ns.cl.h"
};
#else
const std::string kernel_src = read_file("gemm_moe_mxfp4_f32_ns.cl");
#endif
cl_program prog =
build_program_from_source(backend_ctx->context, backend_ctx->device, kernel_src.c_str(), CL_moe_compile_opts);
CL_CHECK((backend_ctx->kernel_gemm_moe_mxfp4_f32_ns = clCreateKernel(prog, "kernel_gemm_moe_mxfp4_f32_ns", &err), err));
CL_CHECK(clReleaseProgram(prog));
GGML_LOG_CONT(".");
}
// moe_reorder_b
{
#ifdef GGML_OPENCL_EMBED_KERNELS
const std::string kernel_src {
#include "moe_reorder_b.cl.h"
};
#else
const std::string kernel_src = read_file("moe_reorder_b.cl");
#endif
cl_program prog =
build_program_from_source(backend_ctx->context, backend_ctx->device, kernel_src.c_str(), CL_moe_compile_opts);
CL_CHECK((backend_ctx->kernel_moe_reorder_b = clCreateKernel(prog, "kernel_moe_reorder_b", &err), err));
CL_CHECK(clReleaseProgram(prog));
GGML_LOG_CONT(".");
}
// moe_sort_by_expert
{
#ifdef GGML_OPENCL_EMBED_KERNELS
const std::string kernel_src {
#include "moe_sort_by_expert.cl.h"
};
#else
const std::string kernel_src = read_file("moe_sort_by_expert.cl");
#endif
cl_program prog =
build_program_from_source(backend_ctx->context, backend_ctx->device, kernel_src.c_str(), CL_moe_compile_opts);
CL_CHECK((backend_ctx->kernel_moe_histogram = clCreateKernel(prog, "kernel_moe_histogram", &err), err));
CL_CHECK((backend_ctx->kernel_moe_scan = clCreateKernel(prog, "kernel_moe_scan", &err), err));
CL_CHECK((backend_ctx->kernel_moe_fill = clCreateKernel(prog, "kernel_moe_fill", &err), err));
CL_CHECK((backend_ctx->kernel_moe_scatter = clCreateKernel(prog, "kernel_moe_scatter", &err), err));
CL_CHECK(clReleaseProgram(prog));
GGML_LOG_CONT(".");
}
// gemv_noshuffle_q6_k_f32
{
#ifdef GGML_OPENCL_EMBED_KERNELS
@@ -3651,13 +3737,12 @@ struct ggml_tensor_extra_cl_mxfp4 {
CL_CHECK(clReleaseMemObject(e));
e = nullptr;
}
if (q != nullptr) {
if (q_img != nullptr) {
CL_CHECK(clReleaseMemObject(q_img));
q = nullptr;
q_img = nullptr;
}
// Currently, q_img and d_img are not used. They can be image1d_buffer_t
// Currently, e_img is not used. They can be image1d_buffer_t
// that wraps around q and d to utilize image access path.
q_img = nullptr;
e_img = nullptr;
size_q = 0;
size_e = 0;
@@ -4740,7 +4825,7 @@ inline bool use_adreno_kernels(const ggml_backend_opencl_context *backend_ctx, c
inline bool use_adreno_moe_kernels(const ggml_backend_opencl_context *backend_ctx, const ggml_tensor *tensor) {
GGML_UNUSED(backend_ctx);
int ne01 = tensor->ne[1];
return ((strstr(tensor->name, "ffn") != NULL) || (strstr(tensor->name, "as") != NULL)) && (ne01 % 64 == 0);
return (((strstr(tensor->name, "ffn") != NULL) && (strstr(tensor->name, "exps") != NULL)) || (strstr(tensor->name, "as") != NULL)) && (ne01 % 64 == 0);
}
inline bool enable_adreno_trans_weight(const ggml_backend_opencl_context *backend_ctx, const ggml_tensor *tensor) {
@@ -5151,8 +5236,9 @@ static void ggml_backend_opencl_buffer_set_tensor(ggml_backend_buffer_t buffer,
CL_CHECK(err);
#ifdef GGML_OPENCL_USE_ADRENO_KERNELS
// Adreno moe mxfp4 kernel needs special transpose and unshuffling
if (use_adreno_moe_kernels(backend_ctx, tensor)) {
cl_kernel kernel = backend_ctx->kernel_convert_block_mxfp4_trans;
cl_kernel kernel = backend_ctx->kernel_convert_block_mxfp4_trans4_ns;
int ne00 = tensor->ne[0];
int ne01 = tensor->ne[1];
@@ -5172,9 +5258,21 @@ static void ggml_backend_opencl_buffer_set_tensor(ggml_backend_buffer_t buffer,
CL_CHECK(clReleaseMemObject(data_device));
tensor->extra = extra;
// Create image for Q
cl_image_format img_format_q = {CL_R, CL_UNSIGNED_INT32};
cl_image_desc img_desc_q = {
CL_MEM_OBJECT_IMAGE1D_BUFFER,
static_cast<size_t>(ggml_nelements(tensor) / 8),
0, 0, 0, 0, 0, 0, 0,
{ extra->q }
};
extra->q_img = clCreateImage(context, CL_MEM_READ_ONLY, &img_format_q, &img_desc_q, NULL, &err);
tensor->extra = extra;
return;
}
#endif
#endif // GGML_OPENCL_USE_ADRENO_KERNELS
cl_kernel kernel = backend_ctx->kernel_convert_block_mxfp4;
CL_CHECK(clSetKernelArg(kernel, 0, sizeof(cl_mem), &data_device));
@@ -5912,7 +6010,7 @@ static void ggml_backend_opencl_buffer_get_tensor(ggml_backend_buffer_t buffer,
#ifdef GGML_OPENCL_USE_ADRENO_KERNELS
if (use_adreno_moe_kernels(backend_ctx, tensor)) {
cl_kernel kernel = backend_ctx->kernel_restore_block_mxfp4_trans;
cl_kernel kernel = backend_ctx->kernel_restore_block_mxfp4_trans4_ns;
int ne00 = tensor->ne[0];
int ne01 = tensor->ne[1];
@@ -5936,7 +6034,8 @@ static void ggml_backend_opencl_buffer_get_tensor(ggml_backend_buffer_t buffer,
CL_CHECK(clReleaseMemObject(data_device));
return;
}
#endif
#endif // GGML_OPENCL_USE_ADRENO_KERNELS
cl_kernel kernel = backend_ctx->kernel_restore_block_mxfp4;
CL_CHECK(clSetKernelArg(kernel, 0, sizeof(cl_mem), &extra->q));
CL_CHECK(clSetKernelArg(kernel, 1, sizeof(cl_mem), &extra->e));
@@ -12763,6 +12862,118 @@ static void ggml_cl_mul_mat(ggml_backend_t backend, const ggml_tensor * src0, co
}
}
static void moe_router_reoerder(ggml_backend_t backend, const ggml_tensor * src, int ne20) {
cl_int err;
ggml_backend_opencl_context *backend_ctx = (ggml_backend_opencl_context *)backend->context;
ggml_tensor_extra_cl * extra = (ggml_tensor_extra_cl *)src->extra;
cl_ulong offset = extra->offset + src->view_offs;
const int ne21 = src->ne[1];
const int nb21 = src->nb[1];
const int ne02 = nb21 / src->nb[0];
const int n_tile_size = 32;
const int max_post_router_tile = (ne20 * ne21 / n_tile_size) + ne02;
cl_buffer_region region;
region.origin = offset;
region.size = nb21 * ne21;
cl_mem original_router_buf = clCreateSubBuffer(extra->data_device, 0, CL_BUFFER_CREATE_TYPE_REGION, &region, &err);
CL_CHECK(err);
backend_ctx->prealloc_post_router.allocate(backend_ctx->context, sizeof(int) * max_post_router_tile * n_tile_size);
region.origin = 0;
region.size = sizeof(int) * max_post_router_tile * n_tile_size;
cl_mem post_router_buf = clCreateSubBuffer(backend_ctx->prealloc_post_router.buffer, 0, CL_BUFFER_CREATE_TYPE_REGION, &region, &err);
CL_CHECK(err);
backend_ctx->prealloc_emap.allocate(backend_ctx->context, sizeof(short) * max_post_router_tile);
region.origin = 0;
region.size = sizeof(short) * max_post_router_tile;
cl_mem emap_buf = clCreateSubBuffer(backend_ctx->prealloc_emap.buffer, 0, CL_BUFFER_CREATE_TYPE_REGION, &region, &err);
CL_CHECK(err);
backend_ctx->prealloc_hist.allocate(backend_ctx->context, sizeof(int) * ne02);
region.origin = 0;
region.size = sizeof(int) * ne02;
cl_mem hist_buf = clCreateSubBuffer(backend_ctx->prealloc_hist.buffer, 0, CL_BUFFER_CREATE_TYPE_REGION, &region, &err);
CL_CHECK(err);
backend_ctx->prealloc_tile_offset.allocate(backend_ctx->context, sizeof(int) * ne02);
region.origin = 0;
region.size = sizeof(int) * ne02;
cl_mem tile_offset_buf = clCreateSubBuffer(backend_ctx->prealloc_tile_offset.buffer, 0, CL_BUFFER_CREATE_TYPE_REGION, &region, &err);
CL_CHECK(err);
backend_ctx->prealloc_slot_counter.allocate(backend_ctx->context, sizeof(int) * ne02);
region.origin = 0;
region.size = sizeof(int) * ne02;
cl_mem slot_counter_buf = clCreateSubBuffer(backend_ctx->prealloc_slot_counter.buffer, 0, CL_BUFFER_CREATE_TYPE_REGION, &region, &err);
CL_CHECK(err);
backend_ctx->prealloc_total_tiles.allocate(backend_ctx->context, sizeof(int));
region.origin = 0;
region.size = sizeof(int);
cl_mem total_tiles_buf = clCreateSubBuffer(backend_ctx->prealloc_total_tiles.buffer, 0, CL_BUFFER_CREATE_TYPE_REGION, &region, &err);
CL_CHECK(err);
// Histogram
cl_kernel kernel = backend_ctx->kernel_moe_histogram;
CL_CHECK(clSetKernelArg(kernel, 0, sizeof(cl_mem), &original_router_buf));
CL_CHECK(clSetKernelArg(kernel, 1, sizeof(cl_mem), &hist_buf));
CL_CHECK(clSetKernelArg(kernel, 2, sizeof(int), &ne21));
CL_CHECK(clSetKernelArg(kernel, 3, sizeof(int), &ne20));
CL_CHECK(clSetKernelArg(kernel, 4, sizeof(int), &ne02));
size_t histogram_global_size[] = {(size_t)(((ne21 + 63) / 64) * 64), static_cast<size_t>(ne20), 1};
size_t histogram_local_size[] = {64, static_cast<size_t>(ne20), 1};
backend_ctx->enqueue_ndrange_kernel(kernel, 3, histogram_global_size, histogram_local_size, src);
// Scan
kernel = backend_ctx->kernel_moe_scan;
CL_CHECK(clSetKernelArg(kernel, 0, sizeof(cl_mem), &hist_buf));
CL_CHECK(clSetKernelArg(kernel, 1, sizeof(cl_mem), &tile_offset_buf));
CL_CHECK(clSetKernelArg(kernel, 2, sizeof(cl_mem), &total_tiles_buf));
CL_CHECK(clSetKernelArg(kernel, 3, sizeof(cl_mem), &slot_counter_buf));
CL_CHECK(clSetKernelArg(kernel, 4, sizeof(int), &n_tile_size));
CL_CHECK(clSetKernelArg(kernel, 5, sizeof(int), &ne02));
size_t scan_global_size[] = {1};
size_t scan_local_size[] = {1};
backend_ctx->enqueue_ndrange_kernel(kernel, 1, scan_global_size, scan_local_size, src);
// Fill
kernel = backend_ctx->kernel_moe_fill;
CL_CHECK(clSetKernelArg(kernel, 0, sizeof(cl_mem), &post_router_buf));
CL_CHECK(clSetKernelArg(kernel, 1, sizeof(cl_mem), &total_tiles_buf));
CL_CHECK(clSetKernelArg(kernel, 2, sizeof(int), &n_tile_size));
size_t fill_global_size[] = {(size_t)(((max_post_router_tile + 63) / 64) * 64), n_tile_size, 1};
size_t fill_local_size[] = {64, 1, 1};
backend_ctx->enqueue_ndrange_kernel(kernel, 3, fill_global_size, fill_local_size, src);
// Scatter
kernel = backend_ctx->kernel_moe_scatter;
CL_CHECK(clSetKernelArg(kernel, 0, sizeof(cl_mem), &original_router_buf));
CL_CHECK(clSetKernelArg(kernel, 1, sizeof(cl_mem), &post_router_buf));
CL_CHECK(clSetKernelArg(kernel, 2, sizeof(cl_mem), &emap_buf));
CL_CHECK(clSetKernelArg(kernel, 3, sizeof(cl_mem), &tile_offset_buf));
CL_CHECK(clSetKernelArg(kernel, 4, sizeof(cl_mem), &slot_counter_buf));
CL_CHECK(clSetKernelArg(kernel, 5, sizeof(int), &ne21));
CL_CHECK(clSetKernelArg(kernel, 6, sizeof(int), &ne20));
CL_CHECK(clSetKernelArg(kernel, 7, sizeof(int), &ne02));
backend_ctx->enqueue_ndrange_kernel(kernel, 3, histogram_global_size, histogram_local_size, src);
CL_CHECK(clReleaseMemObject(original_router_buf));
CL_CHECK(clReleaseMemObject(hist_buf));
CL_CHECK(clReleaseMemObject(tile_offset_buf));
CL_CHECK(clReleaseMemObject(total_tiles_buf));
CL_CHECK(clReleaseMemObject(slot_counter_buf));
CL_CHECK(clReleaseMemObject(post_router_buf));
CL_CHECK(clReleaseMemObject(emap_buf));
}
static void ggml_cl_mul_mat_id(ggml_backend_t backend, const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst) {
GGML_ASSERT(src0);
GGML_ASSERT(src0->extra);
@@ -12824,6 +13035,7 @@ static void ggml_cl_mul_mat_id(ggml_backend_t backend, const ggml_tensor * src0,
const int ne0 = dst->ne[0];
const int ne1 = dst->ne[1];
const int ne2 = dst->ne[2];
const int r2 = ne12/ne02;
const int r3 = ne13/ne03;
@@ -12836,6 +13048,9 @@ static void ggml_cl_mul_mat_id(ggml_backend_t backend, const ggml_tensor * src0,
int nrows = 1; // number of row in src1
int ndst = 4; // number of values produced by each subgroup
const int n_tile_size = 32;
const int max_post_router_tile = (ne20 * ne21 / n_tile_size) + ne02;
cl_kernel kernel;
// subgroup mat vec
@@ -12967,11 +13182,10 @@ static void ggml_cl_mul_mat_id(ggml_backend_t backend, const ggml_tensor * src0,
size_t local_size[3] = {64, 2, 1};
size_t global_size[3] = {64, 2, 1};
cl_mem src1_sub_buffer, buf_src1_image, buf_src2;
int tile_size = 320;
if (ne12 == 1) { // for gemv
kernel = backend_ctx->kernel_gemv_moe_mxfp4_f32;
kernel = backend_ctx->kernel_gemv_moe_mxfp4_f32_ns;
cl_mem src1_sub_buffer, buf_src1_image, buf_src2;
// create a sub_buffer for src2
cl_buffer_region region;
@@ -12985,78 +13199,154 @@ static void ggml_cl_mul_mat_id(ggml_backend_t backend, const ggml_tensor * src0,
global_size[1] = 4;
global_size[2] = static_cast<size_t>(ne20);
local_size[1] = 4;
} else { // for gemm
kernel = backend_ctx->kernel_gemm_moe_mxfp4_f32;
// preprocess router table
int num_tiles_per_expert = (ne01 + tile_size - 1) / tile_size;
void * host_src2_reorder = malloc(ne20 * ne21 * 4 * num_tiles_per_expert * sizeof(short));
void * host_src2 = malloc(ne21 * nb21);
CL_CHECK(clEnqueueReadBuffer(backend_ctx->queue, extra2->data_device, CL_TRUE, offset2, ne21 * nb21, host_src2, 0, NULL, NULL));
int total_experts = nb21 / nb20;
int out_idx = 0;
for (int i_expert = 0; i_expert < ne02; i_expert++) {
for (int i_tile = 0; i_tile < num_tiles_per_expert; i_tile++) {
for (int j = 0; j < ne21; j++) {
for (int i = 0; i < ne20; i++) {
int expert = ((int *)host_src2)[j * total_experts + i];
if (i_expert == expert) {
((short *)host_src2_reorder)[out_idx] = static_cast<short>(expert);
((short *)host_src2_reorder)[out_idx + 1] = static_cast<short>(j * ne11 + (i % ne11));
((short *)host_src2_reorder)[out_idx + 2] = static_cast<short>(j * ne20 + i);
((short *)host_src2_reorder)[out_idx + 3] = static_cast<short>(i_tile);
out_idx += 4;
}
}
}
}
}
buf_src2 = clCreateBuffer(backend_ctx->context, CL_MEM_READ_ONLY | CL_MEM_COPY_HOST_PTR, ne20 * ne21 * 4 * num_tiles_per_expert * sizeof(short), host_src2_reorder, &status);
// create a sub_buffer for src1
region.origin = offset1;
region.size = ne10 * ne11 * ne12 * sizeof(float);
src1_sub_buffer = clCreateSubBuffer(extra1->data_device, 0, CL_BUFFER_CREATE_TYPE_REGION, &region, &status);
CL_CHECK(status);
// set thread grid
global_size[0] = static_cast<size_t>(tile_size);
global_size[2] = static_cast<size_t>(ne20 * ne21 * num_tiles_per_expert);
}
// create image for src1
cl_image_format image_format_buf_src1 = {CL_RGBA, CL_FLOAT};
cl_image_desc image_desc_buf_src1 = {CL_MEM_OBJECT_IMAGE1D_BUFFER, static_cast<size_t>(ne10 * ne11 * ne12 / 4), 0,0,0,0,0,0,0, {src1_sub_buffer}};
buf_src1_image = clCreateImage(backend_ctx->context, CL_MEM_READ_ONLY, &image_format_buf_src1, &image_desc_buf_src1, NULL, &status);
CL_CHECK(status);
// create a sub_buffer for src1
cl_buffer_region region;
region.origin = offset1;
region.size = ne10 * ne11 * ne12 * sizeof(float);
src1_sub_buffer = clCreateSubBuffer(extra1->data_device, 0, CL_BUFFER_CREATE_TYPE_REGION, &region, &status);
CL_CHECK(status);
// create image for src1
cl_image_format image_format_buf_src1 = {CL_RGBA, CL_FLOAT};
cl_image_desc image_desc_buf_src1 = {CL_MEM_OBJECT_IMAGE1D_BUFFER, static_cast<size_t>(ne10 * ne11 * ne12 / 4), 0,0,0,0,0,0,0, {src1_sub_buffer}};
buf_src1_image = clCreateImage(backend_ctx->context, CL_MEM_READ_ONLY, &image_format_buf_src1, &image_desc_buf_src1, NULL, &status);
CL_CHECK(status);
// Set kernel args
int arg_idx = 0;
CL_CHECK(clSetKernelArg(kernel, arg_idx++, sizeof(cl_mem), &extra0_mxfp4->q));
CL_CHECK(clSetKernelArg(kernel, arg_idx++, sizeof(cl_mem), &extra0_mxfp4->e));
CL_CHECK(clSetKernelArg(kernel, arg_idx++, sizeof(cl_mem), &buf_src1_image));
CL_CHECK(clSetKernelArg(kernel, arg_idx++, sizeof(cl_mem), &buf_src2));
CL_CHECK(clSetKernelArg(kernel, arg_idx++, sizeof(cl_mem), &extrad->data_device));
CL_CHECK(clSetKernelArg(kernel, arg_idx++, sizeof(cl_ulong), &offsetd));
CL_CHECK(clSetKernelArg(kernel, arg_idx++, sizeof(int), &ne00));
CL_CHECK(clSetKernelArg(kernel, arg_idx++, sizeof(int), &ne01));
if (ne12 == 1) {
// Set kernel args
int arg_idx = 0;
CL_CHECK(clSetKernelArg(kernel, arg_idx++, sizeof(cl_mem), &extra0_mxfp4->q));
CL_CHECK(clSetKernelArg(kernel, arg_idx++, sizeof(cl_mem), &extra0_mxfp4->e));
CL_CHECK(clSetKernelArg(kernel, arg_idx++, sizeof(cl_mem), &buf_src1_image));
CL_CHECK(clSetKernelArg(kernel, arg_idx++, sizeof(cl_mem), &buf_src2));
CL_CHECK(clSetKernelArg(kernel, arg_idx++, sizeof(cl_mem), &extrad->data_device));
CL_CHECK(clSetKernelArg(kernel, arg_idx++, sizeof(cl_ulong), &offsetd));
CL_CHECK(clSetKernelArg(kernel, arg_idx++, sizeof(int), &ne00));
CL_CHECK(clSetKernelArg(kernel, arg_idx++, sizeof(int), &ne01));
CL_CHECK(clSetKernelArg(kernel, arg_idx++, sizeof(int), &ne11));
} else {
CL_CHECK(clSetKernelArg(kernel, arg_idx++, sizeof(int), &tile_size));
// launch kernel
backend_ctx->enqueue_ndrange_kernel(kernel, 3, global_size, local_size, dst);
// deallocate sub buffers and images
CL_CHECK(clReleaseMemObject(src1_sub_buffer));
CL_CHECK(clReleaseMemObject(buf_src1_image));
CL_CHECK(clReleaseMemObject(buf_src2));
} else { // for gemm
kernel = backend_ctx->kernel_gemm_moe_mxfp4_f32_ns;
// Reorder router if called from test-backend-ops or when new router is generated.
// Otherwise reuse the reordered result from previous mul_mat_id call.
if ((strstr(src0->name, "as") != NULL) || backend_ctx->toggle_reorder) {
moe_router_reoerder(backend, src2, ne20);
backend_ctx->toggle_reorder = false;
}
cl_mem sub_buf_src1_pre, buf_src1_reordered, image_src1_reordered, sub_buf_dst, buf_dst_image;
cl_mem buf_src2, buf_src2_emap;
cl_buffer_region region;
region.origin = 0;
region.size = sizeof(int) * max_post_router_tile * n_tile_size;
GGML_ASSERT(backend_ctx->prealloc_post_router.buffer);
buf_src2 = clCreateSubBuffer(backend_ctx->prealloc_post_router.buffer, 0, CL_BUFFER_CREATE_TYPE_REGION, &region, &status);
CL_CHECK(status);
region.origin = 0;
region.size = sizeof(short) * max_post_router_tile;
buf_src2_emap = clCreateSubBuffer(backend_ctx->prealloc_emap.buffer, 0, CL_BUFFER_CREATE_TYPE_REGION, &region, &status);
CL_CHECK(status);
// Reorder activations
// create a sub_buffer for src1
region.origin = offset1;
region.size = ne10 * ne11 * ne12 * sizeof(float);
sub_buf_src1_pre = clCreateSubBuffer(extra1->data_device, 0, CL_BUFFER_CREATE_TYPE_REGION, &region, &status);
CL_CHECK(status);
// Create image for reordered src1
// Use pre-allocated placeholder
region.origin = 0;
region.size = ne00 * max_post_router_tile * n_tile_size * sizeof(float);
backend_ctx->prealloc_act_trans.allocate(backend_ctx->context, region.size);
buf_src1_reordered = clCreateSubBuffer(
backend_ctx->prealloc_act_trans.buffer,
0,
CL_BUFFER_CREATE_TYPE_REGION,
&region,
&status);
CL_CHECK(status);
cl_image_format image_format_buf_src1;
cl_image_desc image_desc_buf_src1;
image_format_buf_src1 = {CL_RGBA, CL_FLOAT};
image_desc_buf_src1 = {CL_MEM_OBJECT_IMAGE1D_BUFFER, static_cast<size_t>(ne00 * max_post_router_tile * n_tile_size / 4), 0,0,0,0,0,0,0, {buf_src1_reordered}};
image_src1_reordered = clCreateImage(backend_ctx->context, CL_MEM_READ_ONLY, &image_format_buf_src1, &image_desc_buf_src1, NULL, &status);
CL_CHECK(status);
unsigned short map_ratio = ne20 / ne11;
GGML_ASSERT(((map_ratio == 1) || (map_ratio == ne20)) && "Map ratio not supported\n");
CL_CHECK(clSetKernelArg(backend_ctx->kernel_moe_reorder_b, 0, sizeof(cl_mem), &sub_buf_src1_pre));
CL_CHECK(clSetKernelArg(backend_ctx->kernel_moe_reorder_b, 1, sizeof(cl_mem), &buf_src2));
CL_CHECK(clSetKernelArg(backend_ctx->kernel_moe_reorder_b, 2, sizeof(cl_mem), &buf_src1_reordered));
CL_CHECK(clSetKernelArg(backend_ctx->kernel_moe_reorder_b, 3, sizeof(cl_mem), &(backend_ctx->prealloc_total_tiles.buffer)));
CL_CHECK(clSetKernelArg(backend_ctx->kernel_moe_reorder_b, 4, sizeof(unsigned int), &ne00));
CL_CHECK(clSetKernelArg(backend_ctx->kernel_moe_reorder_b, 5, sizeof(unsigned short), &map_ratio));
CL_CHECK(clSetKernelArg(backend_ctx->kernel_moe_reorder_b, 6, sizeof(unsigned int), &n_tile_size));
size_t reorder_b_local_size[3] = {256, 1, 1};
size_t reorder_b_global_size[3] = {static_cast<size_t>(((ne00 / 4) + 255) / 256 * 256), static_cast<size_t>(max_post_router_tile * n_tile_size), 1};
// Dispatch reorder kernel
backend_ctx->enqueue_ndrange_kernel(backend_ctx->kernel_moe_reorder_b, 3, reorder_b_global_size, reorder_b_local_size, dst);
// MoE kernel prepare
// Create sub buffer for dst
region.origin = offsetd;
region.size = ne0 * ne1 * ne2 * sizeof(float);
sub_buf_dst = clCreateSubBuffer(
extrad->data_device,
0,
CL_BUFFER_CREATE_TYPE_REGION,
&region,
&status);
CL_CHECK(status);
// Create image for dst
cl_image_format image_format_buf_dst = {CL_R, CL_FLOAT};
cl_image_desc image_desc_buf_dst = {CL_MEM_OBJECT_IMAGE1D_BUFFER, static_cast<size_t>(ne0 * ne1 * ne2), 0,0,0,0,0,0,0, {sub_buf_dst}};
buf_dst_image = clCreateImage(backend_ctx->context, CL_MEM_WRITE_ONLY, &image_format_buf_dst, &image_desc_buf_dst, NULL, &status);
CL_CHECK(status);
// Set kernel args
int arg_idx = 0;
CL_CHECK(clSetKernelArg(kernel, arg_idx++, sizeof(cl_mem), &extra0_mxfp4->q_img));
CL_CHECK(clSetKernelArg(kernel, arg_idx++, sizeof(cl_mem), &extra0_mxfp4->e));
CL_CHECK(clSetKernelArg(kernel, arg_idx++, sizeof(cl_mem), &image_src1_reordered));
CL_CHECK(clSetKernelArg(kernel, arg_idx++, sizeof(cl_mem), &buf_src2));
CL_CHECK(clSetKernelArg(kernel, arg_idx++, sizeof(cl_mem), &buf_src2_emap));
CL_CHECK(clSetKernelArg(kernel, arg_idx++, sizeof(cl_mem), &buf_dst_image));
CL_CHECK(clSetKernelArg(kernel, arg_idx++, sizeof(cl_mem), &(backend_ctx->prealloc_total_tiles.buffer)));
CL_CHECK(clSetKernelArg(kernel, arg_idx++, sizeof(int), &ne00));
CL_CHECK(clSetKernelArg(kernel, arg_idx++, sizeof(int), &ne01));
// set thread grid
global_size[1] = static_cast<size_t>((ne01 + 63) / 64);
global_size[2] = static_cast<size_t>(max_post_router_tile);
local_size[1] = 1;
local_size[2] = 1;
// Dispatch kernel
backend_ctx->enqueue_ndrange_kernel(kernel, 3, global_size, local_size, dst);
clReleaseMemObject(sub_buf_src1_pre);
clReleaseMemObject(buf_src1_reordered);
clReleaseMemObject(image_src1_reordered);
clReleaseMemObject(buf_src2);
clReleaseMemObject(buf_src2_emap);
clReleaseMemObject(sub_buf_dst);
clReleaseMemObject(buf_dst_image);
}
// launch kernel
backend_ctx->enqueue_ndrange_kernel(kernel, 3, global_size, local_size, dst);
// deallocate sub buffers and images
CL_CHECK(clReleaseMemObject(src1_sub_buffer));
CL_CHECK(clReleaseMemObject(buf_src1_image));
CL_CHECK(clReleaseMemObject(buf_src2));
return;
} // else fallback to generic kernel
} // fallback to generic MoE mxfp4 kernel
#endif // GGML_OPENCL_USE_ADRENO_KERNELS
#ifdef GGML_OPENCL_SOA_Q
@@ -14002,6 +14292,13 @@ static void ggml_cl_argsort(ggml_backend_t backend, const ggml_tensor * src0, co
size_t local_work_size[] = {(size_t)ne00_padded, 1, 1};
backend_ctx->enqueue_ndrange_kernel(kernel, 3, global_work_size, local_work_size, dst);
#ifdef GGML_OPENCL_USE_ADRENO_KERNELS
const int ne21 = dst->ne[1];
if ((strstr(src0->name, "_moe") != NULL) && (ne21 != 1)) {
backend_ctx->toggle_reorder = true;
}
#endif // GGML_OPENCL_USE_ADRENO_KERNELS
}
static void ggml_cl_sum_rows(ggml_backend_t backend, const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst) {

View File

@@ -371,6 +371,93 @@ kernel void kernel_restore_block_mxfp4_trans(
b->e = src_e[src_blk_offset];
}
kernel void kernel_convert_block_mxfp4_trans4_ns(
global struct block_mxfp4 * src0,
__global uint * dst_q,
__global uchar * dst_e,
uint ne00,
uint ne01
) {
uint i00 = get_global_id(1);
uint i01 = get_global_id(0);
uint i02 = get_global_id(2);
uint ne00_blk = ne00 / QK_MXFP4;
uint src_blk_offset = i00 + i01 * ne00_blk + i02 * ne00_blk * ne01;
uint dst_blk_offset = i01 + i00 * ne01 + i02 * ne00_blk * ne01;
global struct block_mxfp4 * b = src0 + src_blk_offset;
dst_e[dst_blk_offset] = b->e;
// extract quantization and unshuffle
ushort8 pre_block = ((global ushort8 *)(&(b->qs[0])))[0];
ushort8 post_block = (ushort8)(0);
uchar * pre_block_ptr = (uchar *)(&pre_block);
uchar * post_block_ptr = (uchar *)(&post_block);
for (int i = 0; i < QK_MXFP4 / 4; ++i) {
uchar x0 = pre_block_ptr[2*i + 0];
uchar x1 = pre_block_ptr[2*i + 1];
post_block_ptr[i + 0 ] = convert_uchar(x0 & 0x0F) | convert_uchar((x1 & 0x0F) << 4);
post_block_ptr[i + QK_MXFP4 / 4] = convert_uchar((x0 & 0xF0) >> 4) | convert_uchar(x1 & 0xF0);
}
uint4 q_block = as_uint4(post_block);
uint offset = i02 * ne00_blk * ne01 * 4 + i00 * ne01 * 4 + i01;
dst_q[offset] = q_block.x;
dst_q[offset + ne01] = q_block.y;
dst_q[offset + ne01 * 2] = q_block.z;
dst_q[offset + ne01 * 3] = q_block.w;
}
kernel void kernel_restore_block_mxfp4_trans4_ns(
__global uint * src_q,
__global uchar * src_e,
__global struct block_mxfp4 * dst0,
uint ne00,
uint ne01
) {
uint i00 = get_global_id(1);
uint i01 = get_global_id(0);
uint i02 = get_global_id(2);
uint ne00_blk = ne00 / QK_MXFP4;
uint dst_blk_offset = i00 + i01 * ne00_blk + i02 * ne00_blk * ne01;
uint src_d_offset = i01 + i00 * ne01 + i02 * ne00_blk * ne01;
__global struct block_mxfp4 * b = dst0 + dst_blk_offset;
b->e = src_e[src_d_offset];
// collect transposed quantization parts for a block
uint src_q_offset = i02 * ne00_blk * ne01 * 4 + i00 * ne01 * 4 + i01;
uint4 q_block;
q_block.x = src_q[src_q_offset];
q_block.y = src_q[src_q_offset + ne01];
q_block.z = src_q[src_q_offset + ne01 * 2];
q_block.w = src_q[src_q_offset + ne01 * 3];
ushort8 post_block = as_ushort8(q_block);
ushort8 pre_block = (ushort8)(0);
uchar * pre_block_ptr = (uchar *)(&pre_block);
uchar * post_block_ptr = (uchar *)(&post_block);
for (int i = 0; i < QK_MXFP4 / 4; ++i) {
uchar x0 = post_block_ptr[i + 0];
uchar x1 = post_block_ptr[i + QK_MXFP4 / 4];
pre_block_ptr[2 * i + 0] = convert_uchar(x0 & 0x0F) | convert_uchar((x1 & 0x0F) << 4);
pre_block_ptr[2 * i + 1] = convert_uchar((x0 & 0xF0) >> 4) | convert_uchar(x1 & 0xF0);
}
((__global ushort8 *)(&(b->qs[0])))[0] = pre_block;
}
//------------------------------------------------------------------------------
// block_q8_0
//------------------------------------------------------------------------------

View File

@@ -0,0 +1,302 @@
#pragma OPENCL EXTENSION cl_khr_fp16 : enable
#pragma OPENCL EXTENSION cl_khr_subgroups : enable
#pragma OPENCL EXTENSION cl_qcom_subgroup_uniform_load: enable
#pragma OPENCL EXTENSION cl_qcom_subgroup_constant_load: enable
#pragma OPENCL EXTENSION cl_qcom_extra_vector_types : enable
#define TILESIZE_K 16
#define TILESIZE_M 64
#define TILESIZE_N 32
static inline half8 mxfp4_to_fp16_packed8(ushort2 fp4x8) {
ushort2 fp16_packed_a_0, fp16_packed_b_0, bias_a, bias_b, sign_a, sign_b;
fp16_packed_a_0.lo = (fp4x8.s0 << 9) & 0x0E00;
fp16_packed_a_0.hi = (fp4x8.s0 << 5) & 0x0E00;
fp16_packed_b_0.lo = (fp4x8.s0 << 1) & 0x0E00;
fp16_packed_b_0.hi = (fp4x8.s0 >> 3) & 0x0E00;
bias_a.lo = (fp16_packed_a_0.lo != 0) ? 0x3800 : 0x0;
bias_a.hi = (fp16_packed_a_0.hi != 0) ? 0x3800 : 0x0;
bias_b.lo = (fp16_packed_b_0.lo != 0) ? 0x3800 : 0x0;
bias_b.hi = (fp16_packed_b_0.hi != 0) ? 0x3800 : 0x0;
fp16_packed_a_0.lo = (fp16_packed_a_0.lo != 0x0200) ? fp16_packed_a_0.lo : 0x0;
fp16_packed_a_0.hi = (fp16_packed_a_0.hi != 0x0200) ? fp16_packed_a_0.hi : 0x0;
fp16_packed_b_0.lo = (fp16_packed_b_0.lo != 0x0200) ? fp16_packed_b_0.lo : 0x0;
fp16_packed_b_0.hi = (fp16_packed_b_0.hi != 0x0200) ? fp16_packed_b_0.hi : 0x0;
sign_a.lo = (fp4x8.s0 << 12) & 0x8000;
sign_a.hi = (fp4x8.s0 << 8) & 0x8000;
sign_b.lo = (fp4x8.s0 << 4) & 0x8000;
sign_b.hi = fp4x8.s0 & 0x8000;
fp16_packed_a_0 = sign_a + bias_a + fp16_packed_a_0;
fp16_packed_b_0 = sign_b + bias_b + fp16_packed_b_0;
ushort2 fp16_packed_a_1, fp16_packed_b_1;
fp16_packed_a_1.lo = (fp4x8.s1 << 9) & 0x0E00;
fp16_packed_a_1.hi = (fp4x8.s1 << 5) & 0x0E00;
fp16_packed_b_1.lo = (fp4x8.s1 << 1) & 0x0E00;
fp16_packed_b_1.hi = (fp4x8.s1 >> 3) & 0x0E00;
bias_a.lo = (fp16_packed_a_1.lo != 0) ? 0x3800 : 0x0;
bias_a.hi = (fp16_packed_a_1.hi != 0) ? 0x3800 : 0x0;
bias_b.lo = (fp16_packed_b_1.lo != 0) ? 0x3800 : 0x0;
bias_b.hi = (fp16_packed_b_1.hi != 0) ? 0x3800 : 0x0;
fp16_packed_a_1.lo = (fp16_packed_a_1.lo != 0x0200) ? fp16_packed_a_1.lo : 0x0;
fp16_packed_a_1.hi = (fp16_packed_a_1.hi != 0x0200) ? fp16_packed_a_1.hi : 0x0;
fp16_packed_b_1.lo = (fp16_packed_b_1.lo != 0x0200) ? fp16_packed_b_1.lo : 0x0;
fp16_packed_b_1.hi = (fp16_packed_b_1.hi != 0x0200) ? fp16_packed_b_1.hi : 0x0;
sign_a.lo = (fp4x8.s1 << 12) & 0x8000;
sign_a.hi = (fp4x8.s1 << 8) & 0x8000;
sign_b.lo = (fp4x8.s1 << 4) & 0x8000;
sign_b.hi = fp4x8.s1 & 0x8000;
fp16_packed_a_1 = sign_a + bias_a + fp16_packed_a_1;
fp16_packed_b_1 = sign_b + bias_b + fp16_packed_b_1;
return as_half8((ushort8)(fp16_packed_a_0, fp16_packed_b_0, fp16_packed_a_1, fp16_packed_b_1));
}
#define dotx16_reduce8(a_reg, b_lm, c_reg, lm_offset) \
acc.s0 = dot(a_reg.s0123, b_lm[lm_offset + 0]); \
acc.s1 = dot(a_reg.s0123, b_lm[lm_offset + 1]); \
acc.s2 = dot(a_reg.s0123, b_lm[lm_offset + 2]); \
acc.s3 = dot(a_reg.s0123, b_lm[lm_offset + 3]); \
acc.s4 = dot(a_reg.s0123, b_lm[lm_offset + 4]); \
acc.s5 = dot(a_reg.s0123, b_lm[lm_offset + 5]); \
acc.s6 = dot(a_reg.s0123, b_lm[lm_offset + 6]); \
acc.s7 = dot(a_reg.s0123, b_lm[lm_offset + 7]); \
acc.s8 = dot(a_reg.s0123, b_lm[lm_offset + 8]); \
acc.s9 = dot(a_reg.s0123, b_lm[lm_offset + 9]); \
acc.sa = dot(a_reg.s0123, b_lm[lm_offset + 10]); \
acc.sb = dot(a_reg.s0123, b_lm[lm_offset + 11]); \
acc.sc = dot(a_reg.s0123, b_lm[lm_offset + 12]); \
acc.sd = dot(a_reg.s0123, b_lm[lm_offset + 13]); \
acc.se = dot(a_reg.s0123, b_lm[lm_offset + 14]); \
acc.sf = dot(a_reg.s0123, b_lm[lm_offset + 15]); \
acc.s0 += dot(a_reg.s4567, b_lm[lm_offset + 32]); \
acc.s1 += dot(a_reg.s4567, b_lm[lm_offset + 33]); \
acc.s2 += dot(a_reg.s4567, b_lm[lm_offset + 34]); \
acc.s3 += dot(a_reg.s4567, b_lm[lm_offset + 35]); \
acc.s4 += dot(a_reg.s4567, b_lm[lm_offset + 36]); \
acc.s5 += dot(a_reg.s4567, b_lm[lm_offset + 37]); \
acc.s6 += dot(a_reg.s4567, b_lm[lm_offset + 38]); \
acc.s7 += dot(a_reg.s4567, b_lm[lm_offset + 39]); \
acc.s8 += dot(a_reg.s4567, b_lm[lm_offset + 40]); \
acc.s9 += dot(a_reg.s4567, b_lm[lm_offset + 41]); \
acc.sa += dot(a_reg.s4567, b_lm[lm_offset + 42]); \
acc.sb += dot(a_reg.s4567, b_lm[lm_offset + 43]); \
acc.sc += dot(a_reg.s4567, b_lm[lm_offset + 44]); \
acc.sd += dot(a_reg.s4567, b_lm[lm_offset + 45]); \
acc.se += dot(a_reg.s4567, b_lm[lm_offset + 46]); \
acc.sf += dot(a_reg.s4567, b_lm[lm_offset + 47]); \
c_reg.lo += convert_float8(acc.lo); \
c_reg.hi += convert_float8(acc.hi); \
acc.s0 = dot(a_reg.s89ab, b_lm[lm_offset + 64]); \
acc.s1 = dot(a_reg.s89ab, b_lm[lm_offset + 65]); \
acc.s2 = dot(a_reg.s89ab, b_lm[lm_offset + 66]); \
acc.s3 = dot(a_reg.s89ab, b_lm[lm_offset + 67]); \
acc.s4 = dot(a_reg.s89ab, b_lm[lm_offset + 68]); \
acc.s5 = dot(a_reg.s89ab, b_lm[lm_offset + 69]); \
acc.s6 = dot(a_reg.s89ab, b_lm[lm_offset + 70]); \
acc.s7 = dot(a_reg.s89ab, b_lm[lm_offset + 71]); \
acc.s8 = dot(a_reg.s89ab, b_lm[lm_offset + 72]); \
acc.s9 = dot(a_reg.s89ab, b_lm[lm_offset + 73]); \
acc.sa = dot(a_reg.s89ab, b_lm[lm_offset + 74]); \
acc.sb = dot(a_reg.s89ab, b_lm[lm_offset + 75]); \
acc.sc = dot(a_reg.s89ab, b_lm[lm_offset + 76]); \
acc.sd = dot(a_reg.s89ab, b_lm[lm_offset + 77]); \
acc.se = dot(a_reg.s89ab, b_lm[lm_offset + 78]); \
acc.sf = dot(a_reg.s89ab, b_lm[lm_offset + 79]); \
acc.s0 += dot(a_reg.scdef, b_lm[lm_offset + 96]); \
acc.s1 += dot(a_reg.scdef, b_lm[lm_offset + 97]); \
acc.s2 += dot(a_reg.scdef, b_lm[lm_offset + 98]); \
acc.s3 += dot(a_reg.scdef, b_lm[lm_offset + 99]); \
acc.s4 += dot(a_reg.scdef, b_lm[lm_offset + 100]); \
acc.s5 += dot(a_reg.scdef, b_lm[lm_offset + 101]); \
acc.s6 += dot(a_reg.scdef, b_lm[lm_offset + 102]); \
acc.s7 += dot(a_reg.scdef, b_lm[lm_offset + 103]); \
acc.s8 += dot(a_reg.scdef, b_lm[lm_offset + 104]); \
acc.s9 += dot(a_reg.scdef, b_lm[lm_offset + 105]); \
acc.sa += dot(a_reg.scdef, b_lm[lm_offset + 106]); \
acc.sb += dot(a_reg.scdef, b_lm[lm_offset + 107]); \
acc.sc += dot(a_reg.scdef, b_lm[lm_offset + 108]); \
acc.sd += dot(a_reg.scdef, b_lm[lm_offset + 109]); \
acc.se += dot(a_reg.scdef, b_lm[lm_offset + 110]); \
acc.sf += dot(a_reg.scdef, b_lm[lm_offset + 111]); \
c_reg.lo += convert_float8(acc.lo); \
c_reg.hi += convert_float8(acc.hi); \
static inline half e8m0_to_fp16(uchar x) {
ushort bits;
bits = (ushort)(x) - (ushort)(112);
bits = ((bits & 0x00E0) != 0) ? 0x7C00 : (bits << 10);
return as_half(bits);
}
static inline float e8m0_to_fp32(uchar x) {
int bits;
bits = (x == 0) ? 0x00400000 : ((uint) x << 23);
return as_float(bits);
}
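// Explanatory note on the two scale decoders above: an E8M0 scale byte x encodes the
// power of two 2^(x - 127). e8m0_to_fp32 builds that value directly by shifting x into
// the float32 exponent field; x == 0 would land on a zero exponent, so it is mapped to
// the subnormal bit pattern 0x00400000 (= 2^-127) instead. e8m0_to_fp16 rebiases to the
// fp16 exponent (subtract 127 - 15 = 112) and returns +inf (0x7C00) whenever the
// rebiased exponent does not fit in the 5-bit fp16 exponent field.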
__attribute__((qcom_wave_pair_mode(1))) // 1=force single 2=force pair
kernel void kernel_gemm_moe_mxfp4_f32_ns(
__read_only image1d_buffer_t src0_q,
__global uchar * src0_d,
__read_only image1d_buffer_t src1,
__global uint * src2,
__global ushort * src2_emap,
__write_only image1d_buffer_t dst,
__global int * total_tiles,
uint ne00,
uint ne01
) {
uint block_id_m = get_global_id(1); // m_tile
uint block_id_n = get_global_id(2); // n_tile
// Boundary check
if (((get_global_id(0) + block_id_m * TILESIZE_M) >= ne01) || (block_id_n >= total_tiles[0])) {
return;
}
__private half16 reg_a;
__private float32 reg_c = (float32)(0);
__local half4 shared_b[128];
const ushort expert_id = src2_emap[block_id_n];
const uint row = block_id_m * TILESIZE_M;
const uint col = block_id_n * TILESIZE_N;
uint sub_block_id_m = get_local_id(0);
uint2 b_global_offset;
b_global_offset.x = ((sub_block_id_m & 3) << 2) + (sub_block_id_m >> 2) * ne00;
b_global_offset.y = b_global_offset.x + (16 * ne00);
uint2 b_local_offset;
b_local_offset.x = (sub_block_id_m & 3) * 32 + (sub_block_id_m >> 2);
b_local_offset.y = b_local_offset.x + 16;
// Loop along K axis, 32 elements (one block) for each iteration, divided into 2 sub-blocks
for (uint step = 0; step < ne00; step += TILESIZE_K * 2) {
// First sub-block
uint q_sub_offset = row + ((ne01 * step) >> 3) + ((expert_id * ne00 * ne01) >> 3);
uint s_sub_offset = row + ((ne01 * step) >> 5) + ((expert_id * ne00 * ne01) >> 5);
uint b_sub_offset = col * ne00 + step;
// Load scale for current mxfp4 block
uint s_offset = s_sub_offset + get_global_id(0);
float s = e8m0_to_fp32(src0_d[s_offset]);
// Load 16 fp4 (64-bits) in transposed layout
uint2 mxfp4x16;
mxfp4x16.x = read_imageui(src0_q, q_sub_offset + sub_block_id_m).x;
mxfp4x16.y = read_imageui(src0_q, q_sub_offset + sub_block_id_m + ne01).x;
// Load 16x32 floats from matrix B, each fiber out of 64 in a sub-group loads 8 elements
float8 bx8_f32;
bx8_f32.lo = read_imagef(src1, (b_sub_offset + b_global_offset.x) / 4);
bx8_f32.hi = read_imagef(src1, (b_sub_offset + b_global_offset.y) / 4);
// Convert to half and store to LM to share within the subgroup
half8 bx8_f16 = convert_half8(bx8_f32);
shared_b[b_local_offset.x] = bx8_f16.lo;
shared_b[b_local_offset.y] = bx8_f16.hi;
// Dequantization
reg_a.lo = mxfp4_to_fp16_packed8(as_ushort2(mxfp4x16.lo)) * s;
reg_a.hi = mxfp4_to_fp16_packed8(as_ushort2(mxfp4x16.hi)) * s;
sub_group_barrier(CLK_LOCAL_MEM_FENCE);
// 32 16x16 fp16 dot product with 8 elements reduction for better precision
half16 acc;
dotx16_reduce8(reg_a, shared_b, reg_c.lo, 0);
dotx16_reduce8(reg_a, shared_b, reg_c.hi, 16);
// Repeat for second sub-block
uint half_step = step + TILESIZE_K;
q_sub_offset = row + ((ne01 * half_step) >> 3) + ((expert_id * ne00 * ne01) >> 3);
b_sub_offset = col * ne00 + half_step;
// Load next 16 fp4 (64-bits) in transposed layout
mxfp4x16.x = read_imageui(src0_q, q_sub_offset + sub_block_id_m).x;
mxfp4x16.y = read_imageui(src0_q, q_sub_offset + sub_block_id_m + ne01).x;
// Load 16x32 floats from matrix B, each fiber out of 64 in a sub-group loads 8 elements
bx8_f32.lo = read_imagef(src1, (b_sub_offset + b_global_offset.x) / 4);
bx8_f32.hi = read_imagef(src1, (b_sub_offset + b_global_offset.y) / 4);
// Convert to half and store to LM to share within the subgroup
bx8_f16 = convert_half8(bx8_f32);
shared_b[b_local_offset.x] = bx8_f16.lo;
shared_b[b_local_offset.y] = bx8_f16.hi;
// Dequantization
reg_a.lo = mxfp4_to_fp16_packed8(as_ushort2(mxfp4x16.lo)) * s;
reg_a.hi = mxfp4_to_fp16_packed8(as_ushort2(mxfp4x16.hi)) * s;
sub_group_barrier(CLK_LOCAL_MEM_FENCE);
// 32 16x16 fp16 dot product with 3-levels reduction for better precision
dotx16_reduce8(reg_a, shared_b, reg_c.lo, 0);
dotx16_reduce8(reg_a, shared_b, reg_c.hi, 16);
}
// Load poster router and share in LM
__local uint out_idx[TILESIZE_N];
if (get_local_id(0) < TILESIZE_N) {
uint idx = src2[block_id_n * TILESIZE_N + get_local_id(0)];
if (idx == 0xFFFFFFFF) {
idx = src2[block_id_n * TILESIZE_N + 0];
}
out_idx[get_local_id(0)] = idx * ne01;
}
barrier(CLK_LOCAL_MEM_FENCE);
// Scatter results back to original position in output grid
uint m_offset = row + get_local_id(0);
write_imagef(dst, out_idx[1] + m_offset, (reg_c.s1));
write_imagef(dst, out_idx[2] + m_offset, (reg_c.s2));
write_imagef(dst, out_idx[3] + m_offset, (reg_c.s3));
write_imagef(dst, out_idx[4] + m_offset, (reg_c.s4));
write_imagef(dst, out_idx[5] + m_offset, (reg_c.s5));
write_imagef(dst, out_idx[6] + m_offset, (reg_c.s6));
write_imagef(dst, out_idx[7] + m_offset, (reg_c.s7));
write_imagef(dst, out_idx[8] + m_offset, (reg_c.s8));
write_imagef(dst, out_idx[9] + m_offset, (reg_c.s9));
write_imagef(dst, out_idx[10] + m_offset, (reg_c.sa));
write_imagef(dst, out_idx[11] + m_offset, (reg_c.sb));
write_imagef(dst, out_idx[12] + m_offset, (reg_c.sc));
write_imagef(dst, out_idx[13] + m_offset, (reg_c.sd));
write_imagef(dst, out_idx[14] + m_offset, (reg_c.se));
write_imagef(dst, out_idx[15] + m_offset, (reg_c.sf));
write_imagef(dst, out_idx[16] + m_offset, (reg_c.sg));
write_imagef(dst, out_idx[17] + m_offset, (reg_c.sh));
write_imagef(dst, out_idx[18] + m_offset, (reg_c.si));
write_imagef(dst, out_idx[19] + m_offset, (reg_c.sj));
write_imagef(dst, out_idx[20] + m_offset, (reg_c.sk));
write_imagef(dst, out_idx[21] + m_offset, (reg_c.sl));
write_imagef(dst, out_idx[22] + m_offset, (reg_c.sm));
write_imagef(dst, out_idx[23] + m_offset, (reg_c.sn));
write_imagef(dst, out_idx[24] + m_offset, (reg_c.so));
write_imagef(dst, out_idx[25] + m_offset, (reg_c.sp));
write_imagef(dst, out_idx[26] + m_offset, (reg_c.sq));
write_imagef(dst, out_idx[27] + m_offset, (reg_c.sr));
write_imagef(dst, out_idx[28] + m_offset, (reg_c.ss));
write_imagef(dst, out_idx[29] + m_offset, (reg_c.st));
write_imagef(dst, out_idx[30] + m_offset, (reg_c.su));
write_imagef(dst, out_idx[31] + m_offset, (reg_c.sv));
// Store zero padding parts to the index of first output in tile, override correct result in the end
barrier(CLK_GLOBAL_MEM_FENCE);
write_imagef(dst, out_idx[0] + m_offset, (reg_c.s0));
}
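The packed decode in mxfp4_to_fp16_packed8 relies on a bit-field rebias rather than a lookup table. Below is a scalar host C++ sketch of the same trick applied to a single e2m1 nibble; the helper names are hypothetical and only the bit arithmetic mirrors the kernel:
#include <cmath>
#include <cstdint>
#include <cstdio>

// reference decode of an fp16 bit pattern
static float fp16_bits_to_float(uint16_t h) {
    const uint32_t s = (h >> 15) & 1, e = (h >> 10) & 0x1F, m = h & 0x3FF;
    const float v = e ? std::ldexp(1.0f + m / 1024.0f, (int) e - 15)
                      : std::ldexp(m / 1024.0f, -14);
    return s ? -v : v;
}

// same rebiasing the kernel applies per nibble: exponent bits to 11..10, the single
// mantissa bit to bit 9, then add the 0x3800 (= 14 << 10) bias so 2^(e-1)*(1 + m/2)
// falls out; the e=0,m=1 nibble becomes 0.5 through the bias alone and 0 stays 0
static uint16_t mxfp4_nibble_to_fp16_bits(uint8_t fp4) {
    uint16_t field = (uint16_t) ((fp4 << 9) & 0x0E00);
    const uint16_t bias = field != 0 ? 0x3800 : 0x0000;
    if (field == 0x0200) field = 0;
    const uint16_t sign = (uint16_t) ((fp4 << 12) & 0x8000);
    return (uint16_t) (sign + bias + field);
}

int main() {
    // prints the 8 non-negative e2m1 values: 0, 0.5, 1, 1.5, 2, 3, 4, 6
    for (uint8_t v = 0; v < 8; ++v) {
        printf("%u -> %g\n", (unsigned) v, fp16_bits_to_float(mxfp4_nibble_to_fp16_bits(v)));
    }
    return 0;
}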

View File

@@ -0,0 +1,161 @@
#pragma OPENCL EXTENSION cl_khr_fp16 : enable
#pragma OPENCL EXTENSION cl_khr_subgroups : enable
#pragma OPENCL EXTENSION cl_qcom_reqd_sub_group_size : enable
#define QK_MXFP4 32
#define N_SIMDGROUP 4
#define SIMDGROUP_WIDTH 64
static inline half8 mxfp4_to_fp16_packed8(ushort2 fp4x8) {
ushort2 fp16_packed_a_0, fp16_packed_b_0, bias_a, bias_b, sign_a, sign_b;
fp16_packed_a_0.lo = (fp4x8.s0 << 9) & 0x0E00;
fp16_packed_a_0.hi = (fp4x8.s0 << 5) & 0x0E00;
fp16_packed_b_0.lo = (fp4x8.s0 << 1) & 0x0E00;
fp16_packed_b_0.hi = (fp4x8.s0 >> 3) & 0x0E00;
bias_a.lo = (fp16_packed_a_0.lo != 0) ? 0x3800 : 0x0;
bias_a.hi = (fp16_packed_a_0.hi != 0) ? 0x3800 : 0x0;
bias_b.lo = (fp16_packed_b_0.lo != 0) ? 0x3800 : 0x0;
bias_b.hi = (fp16_packed_b_0.hi != 0) ? 0x3800 : 0x0;
fp16_packed_a_0.lo = (fp16_packed_a_0.lo != 0x0200) ? fp16_packed_a_0.lo : 0x0;
fp16_packed_a_0.hi = (fp16_packed_a_0.hi != 0x0200) ? fp16_packed_a_0.hi : 0x0;
fp16_packed_b_0.lo = (fp16_packed_b_0.lo != 0x0200) ? fp16_packed_b_0.lo : 0x0;
fp16_packed_b_0.hi = (fp16_packed_b_0.hi != 0x0200) ? fp16_packed_b_0.hi : 0x0;
sign_a.lo = (fp4x8.s0 << 12) & 0x8000;
sign_a.hi = (fp4x8.s0 << 8) & 0x8000;
sign_b.lo = (fp4x8.s0 << 4) & 0x8000;
sign_b.hi = fp4x8.s0 & 0x8000;
fp16_packed_a_0 = sign_a + bias_a + fp16_packed_a_0;
fp16_packed_b_0 = sign_b + bias_b + fp16_packed_b_0;
ushort2 fp16_packed_a_1, fp16_packed_b_1;
fp16_packed_a_1.lo = (fp4x8.s1 << 9) & 0x0E00;
fp16_packed_a_1.hi = (fp4x8.s1 << 5) & 0x0E00;
fp16_packed_b_1.lo = (fp4x8.s1 << 1) & 0x0E00;
fp16_packed_b_1.hi = (fp4x8.s1 >> 3) & 0x0E00;
bias_a.lo = (fp16_packed_a_1.lo != 0) ? 0x3800 : 0x0;
bias_a.hi = (fp16_packed_a_1.hi != 0) ? 0x3800 : 0x0;
bias_b.lo = (fp16_packed_b_1.lo != 0) ? 0x3800 : 0x0;
bias_b.hi = (fp16_packed_b_1.hi != 0) ? 0x3800 : 0x0;
fp16_packed_a_1.lo = (fp16_packed_a_1.lo != 0x0200) ? fp16_packed_a_1.lo : 0x0;
fp16_packed_a_1.hi = (fp16_packed_a_1.hi != 0x0200) ? fp16_packed_a_1.hi : 0x0;
fp16_packed_b_1.lo = (fp16_packed_b_1.lo != 0x0200) ? fp16_packed_b_1.lo : 0x0;
fp16_packed_b_1.hi = (fp16_packed_b_1.hi != 0x0200) ? fp16_packed_b_1.hi : 0x0;
sign_a.lo = (fp4x8.s1 << 12) & 0x8000;
sign_a.hi = (fp4x8.s1 << 8) & 0x8000;
sign_b.lo = (fp4x8.s1 << 4) & 0x8000;
sign_b.hi = fp4x8.s1 & 0x8000;
fp16_packed_a_1 = sign_a + bias_a + fp16_packed_a_1;
fp16_packed_b_1 = sign_b + bias_b + fp16_packed_b_1;
return as_half8((ushort8)(fp16_packed_a_0, fp16_packed_b_0, fp16_packed_a_1, fp16_packed_b_1));
}
static inline float e8m0_to_fp32(uchar x) {
int bits;
bits = (x == 0) ? 0x00400000 : ((uint) x << 23);
return as_float(bits);
}
__attribute__((qcom_reqd_sub_group_size("half")))
__kernel void kernel_gemv_moe_mxfp4_f32_ns(
__global uint * src0_q,
__global uchar * src0_e,
__read_only image1d_buffer_t src1,
__global uint * src2,
__global float * dst,
ulong offsetd,
int ne00,
int ne01,
int ne11
) {
uint i01 = get_global_id(0);
uint i20 = get_global_id(2);
uint sgid = get_local_id(1);
uint slid = get_sub_group_local_id();
uint i11 = i20 % ne11;
uint expert_id = src2[i20];
uint expert_offset = expert_id * ne00 * ne01 / 32;
__private float sum = 0.0f; // each thread calculate partial sum of one output
// loop along ne00 in block granularity, skip 4 blocks every iter
for (uint ib00 = sgid; ib00 < (ne00 / QK_MXFP4); ib00 += N_SIMDGROUP) {
// load one block of q
uint4 regQ;
uint block_offset = expert_offset * 4 + ib00 * ne01 * 4 + i01;
regQ.s0 = src0_q[block_offset];
regQ.s1 = src0_q[block_offset + ne01];
regQ.s2 = src0_q[block_offset + ne01 * 2];
regQ.s3 = src0_q[block_offset + ne01 * 3];
uint offset = i11 * ne00 / 4 + ib00 * 8;
half8 fp16x8 = mxfp4_to_fp16_packed8(as_ushort2(regQ.s0));
float4 shared_y4;
shared_y4 = read_imagef(src1, (offset + 0));
float4 acc = shared_y4 * convert_float4(fp16x8.lo);
shared_y4 = read_imagef(src1, (offset + 1));
acc += shared_y4 * convert_float4(fp16x8.hi);
fp16x8 = mxfp4_to_fp16_packed8(as_ushort2(regQ.s1));
shared_y4 = read_imagef(src1, (offset + 2));
acc += shared_y4 * convert_float4(fp16x8.lo);
shared_y4 = read_imagef(src1, (offset + 3));
acc += shared_y4 * convert_float4(fp16x8.hi);
fp16x8 = mxfp4_to_fp16_packed8(as_ushort2(regQ.s2));
shared_y4 = read_imagef(src1, (offset + 4));
acc += shared_y4 * convert_float4(fp16x8.lo);
shared_y4 = read_imagef(src1, (offset + 5));
acc += shared_y4 * convert_float4(fp16x8.hi);
fp16x8 = mxfp4_to_fp16_packed8(as_ushort2(regQ.s3));
shared_y4 = read_imagef(src1, (offset + 6));
acc += shared_y4 * convert_float4(fp16x8.lo);
shared_y4 = read_imagef(src1, (offset + 7));
acc += shared_y4 * convert_float4(fp16x8.hi);
uchar regE = src0_e[ib00 * ne01 + i01 + expert_offset];
sum += e8m0_to_fp32(regE) * ((acc.s0 + acc.s1) + (acc.s2 + acc.s3));
}
// reduction in local memory, assumes #subgroups=4
__local float reduceLM[SIMDGROUP_WIDTH * (N_SIMDGROUP - 1)];
if (sgid == 1) reduceLM[SIMDGROUP_WIDTH * 0 + slid] = sum;
if (sgid == 2) reduceLM[SIMDGROUP_WIDTH * 1 + slid] = sum;
if (sgid == 3) reduceLM[SIMDGROUP_WIDTH * 2 + slid] = sum;
barrier(CLK_LOCAL_MEM_FENCE);
if (sgid == 0) sum += reduceLM[SIMDGROUP_WIDTH * 0 + slid];
if (sgid == 0) sum += reduceLM[SIMDGROUP_WIDTH * 1 + slid];
if (sgid == 0) sum += reduceLM[SIMDGROUP_WIDTH * 2 + slid];
// 1 outputs per thread in subgroup 0
if (sgid == 0) {
dst = dst + (offsetd >> 2);
dst[i01 + i20 * ne01] = sum;
}
}

View File

@@ -0,0 +1,30 @@
#pragma OPENCL EXTENSION cl_khr_fp16 : enable
#define QK4_0 32
kernel void kernel_moe_reorder_b(
global float4 * src,
global uint * router,
global float4 * dst,
global int * total_tiles,
uint K,
ushort map_ratio,
uint tile_size
) {
uint k_4 = get_global_id(0);
uint post_router_idx = get_global_id(1);
if ((k_4 >= (K / 4)) || (post_router_idx >= total_tiles[0] * tile_size)) {
return;
}
uint router_idx = router[post_router_idx];
float4 out = (float4)(0);
if (router_idx != 0xFFFFFFFF) {
ushort activation_idx = router_idx / map_ratio;
out = src[activation_idx * K / 4 + k_4];
}
dst[post_router_idx * K / 4 + k_4] = out;
}
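kernel_moe_reorder_b above gathers one float4 of an activation row into each post-router slot and leaves padding slots zero. A scalar host sketch of the same gather, assuming src holds the activations row-major and post_router/map_ratio carry the meanings used by the kernel (rows are per token or per (token, top-k) pair depending on map_ratio); names are illustrative:
#include <algorithm>
#include <cstdint>
#include <vector>

// dst gets one K-wide row per post-router slot; padding slots (0xFFFFFFFF) stay zero
static void reorder_activations(const std::vector<float> & src,
                                const std::vector<uint32_t> & post_router,
                                std::vector<float> & dst,
                                uint32_t K, uint32_t map_ratio) {
    dst.assign(post_router.size() * K, 0.0f);
    for (size_t slot = 0; slot < post_router.size(); ++slot) {
        const uint32_t idx = post_router[slot];
        if (idx == 0xFFFFFFFFu) {
            continue;                          // padding written by kernel_moe_fill
        }
        const uint32_t row = idx / map_ratio;  // which activation row this slot reads
        std::copy_n(src.begin() + row * K, K, dst.begin() + slot * K);
    }
}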

View File

@@ -0,0 +1,82 @@
#pragma OPENCL EXTENSION cl_khr_fp16 : enable
__kernel void kernel_moe_histogram(
__global const int * input,
__global int * hist,
uint N,
uint topK,
uint n_experts
) {
uint n = get_global_id(0);
uint k = get_global_id(1);
if (n >= N || k >= topK) {
return;
}
int expert_id = input[n * n_experts + k];
atomic_inc(&hist[expert_id]);
}
__kernel void kernel_moe_scan(
__global int * hist,
__global int * tile_offset,
__global int * total_tiles,
__global int * slot_counter,
int tile_size,
uint n_experts
) {
int offset = 0;
for (int v = 0; v < n_experts; v++) {
int count = hist[v];
int tiles = (count + tile_size - 1) / tile_size;
tile_offset[v] = offset;
offset += tiles;
hist[v] = 0;
slot_counter[v] = 0;
}
*total_tiles = offset;
}
__kernel void kernel_moe_scatter(
__global const int * input,
__global int * post_router,
__global ushort * emap,
__global const int * tile_offset,
__global int * slot_counter,
int N,
int topK,
uint n_experts
) {
uint n = get_global_id(0);
uint k = get_global_id(1);
if (n >= N || k >= topK) {
return;
}
int val = input[n * n_experts + k];
int local_slot = atomic_inc(&slot_counter[val]);
int tile_idx = tile_offset[val] + (local_slot / 32);
int lane = local_slot % 32;
int out_pos = tile_idx * 32 + lane;
post_router[out_pos] = n * topK + k;
emap[tile_idx] = val;
}
__kernel void kernel_moe_fill(
__global int * post_router,
__global int * total_tiles,
int tile_size
) {
int tile_id = get_global_id(0);
int vec_id_in_tile = get_global_id(1);
if (tile_id < total_tiles[0]) {
post_router[tile_id * tile_size + vec_id_in_tile] = 0xFFFFFFFF;
}
}
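Taken together, the four kernels above implement a bucket-by-expert pass: count how many (token, top-k) pairs each expert received, reserve whole tiles per expert, pad unused lanes with 0xFFFFFFFF, then scatter each pair into its tile. A single-threaded host sketch of the same pipeline, assuming the router stores the selected expert ids in the first top_k entries of an n_experts-strided row, as the histogram kernel indexes it (all names here are illustrative):
#include <cstdint>
#include <vector>

struct reordered_router {
    std::vector<uint32_t> post_router; // tile-major: slots holding n*top_k + k, 0xFFFFFFFF = padding
    std::vector<uint16_t> emap;        // per tile: the expert the whole tile belongs to
    int total_tiles = 0;
};

static reordered_router reorder_router(const std::vector<int> & router,
                                       int n_tokens, int top_k, int n_experts,
                                       int tile_size = 32) {
    std::vector<int> hist(n_experts, 0), tile_offset(n_experts, 0), slot(n_experts, 0);
    // histogram: how many (token, slot) pairs landed on each expert
    for (int n = 0; n < n_tokens; ++n) {
        for (int k = 0; k < top_k; ++k) {
            hist[router[n*n_experts + k]]++;
        }
    }
    reordered_router out;
    // scan: exclusive prefix sum over per-expert tile counts
    for (int e = 0; e < n_experts; ++e) {
        tile_offset[e] = out.total_tiles;
        out.total_tiles += (hist[e] + tile_size - 1) / tile_size;
    }
    // fill: every slot starts out as padding
    out.post_router.assign((size_t) out.total_tiles * tile_size, 0xFFFFFFFFu);
    out.emap.resize(out.total_tiles);
    // scatter: place each pair into the next free lane of its expert's tiles
    for (int n = 0; n < n_tokens; ++n) {
        for (int k = 0; k < top_k; ++k) {
            const int e   = router[n*n_experts + k];
            const int s   = slot[e]++;
            const int tid = tile_offset[e] + s / tile_size;
            out.post_router[(size_t) tid*tile_size + s % tile_size] = (uint32_t) (n*top_k + k);
            out.emap[tid] = (uint16_t) e;
        }
    }
    return out;
}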

View File

@@ -1,6 +1,7 @@
#include "virtgpu-shm.h"
#include "virtgpu.h"
#include "ggml-remoting.h"
#include <assert.h>

View File

@@ -1,4 +1,5 @@
#include "virtgpu.h"
#include "ggml-remoting.h"
#include <stdio.h>
#include <unistd.h>

View File

@@ -18,8 +18,6 @@
#include <cstring>
#include "ggml-remoting.h"
#define VIRGL_RENDERER_UNSTABLE_APIS 1
#include "apir_hw.h"
#include <drm/virtgpu_drm.h>

View File

@@ -2253,6 +2253,28 @@ public:
llama_io_write_buffer(
uint8_t * p, size_t len) : ptr(p), buf_size(len) {}
~llama_io_write_buffer() {
#if 1
// TODO: add backend support to batch tensor_get? or some other way to speed this up
for (const auto & info : winfos) {
ggml_backend_tensor_get(info.tensor, info.ptr, info.offset, info.size);
}
#else
// flush the writes asynchronously
// this helps on Macs, but on other devices - it does not. just an example
std::vector<std::future<void>> futures;
futures.reserve(winfos.size());
for (const auto & info : winfos) {
futures.push_back(std::async(std::launch::async, [info]() {
ggml_backend_tensor_get(info.tensor, info.ptr, info.offset, info.size);
}));
}
for (auto & f : futures) {
f.wait();
}
#endif
}
void write(const void * src, size_t size) override {
if (size > buf_size) {
throw std::runtime_error("unexpectedly reached end of buffer");
@@ -2267,7 +2289,10 @@ public:
if (size > buf_size) {
throw std::runtime_error("unexpectedly reached end of buffer");
}
ggml_backend_tensor_get(tensor, ptr, offset, size);
// save the write for later during destruction
winfos.push_back({tensor, ptr, size, offset});
ptr += size;
size_written += size;
buf_size -= size;
@@ -2281,25 +2306,48 @@ private:
uint8_t * ptr;
size_t buf_size = 0;
size_t size_written = 0;
struct write_info {
const ggml_tensor * tensor;
uint8_t * ptr;
size_t size;
size_t offset;
};
std::vector<write_info> winfos;
};
class llama_io_read_buffer : public llama_io_read_i {
public:
llama_io_read_buffer(const uint8_t * p, size_t len) : ptr(p), buf_size(len) {}
const uint8_t * read(size_t size) override {
const uint8_t * base_ptr = ptr;
~llama_io_read_buffer() {
// flush the reads
for (const auto & info : rinfos) {
ggml_backend_tensor_set(info.tensor, info.ptr, info.offset, info.size);
}
}
void read(void * dst, size_t size) override {
if (size > buf_size) {
throw std::runtime_error("unexpectedly reached end of buffer");
}
memcpy(dst, ptr, size);
ptr += size;
size_read += size;
buf_size -= size;
return base_ptr;
}
void read_to(void * dst, size_t size) override {
memcpy(dst, read(size), size);
void read_tensor(ggml_tensor * tensor, size_t offset, size_t size) override {
if (size > buf_size) {
throw std::runtime_error("unexpectedly reached end of buffer");
}
// save for later during destruction
rinfos.push_back({tensor, ptr, size, offset});
ptr += size;
size_read += size;
buf_size -= size;
}
size_t n_bytes() override {
@@ -2310,6 +2358,14 @@ private:
const uint8_t * ptr;
size_t buf_size = 0;
size_t size_read = 0;
struct read_info {
ggml_tensor * tensor;
const uint8_t * ptr;
size_t size;
size_t offset;
};
std::vector<read_info> rinfos;
};
class llama_io_write_file : public llama_io_write_i {
@@ -2341,15 +2397,15 @@ class llama_io_read_file : public llama_io_read_i {
public:
llama_io_read_file(llama_file * f) : file(f) {}
void read_to(void * dst, size_t size) override {
void read(void * dst, size_t size) override {
file->read_raw(dst, size);
size_read += size;
}
const uint8_t * read(size_t size) override {
void read_tensor(ggml_tensor * tensor, size_t offset, size_t size) override {
temp_buffer.resize(size);
read_to(temp_buffer.data(), size);
return temp_buffer.data();
read(temp_buffer.data(), size);
ggml_backend_tensor_set(tensor, temp_buffer.data(), offset, size);
}
size_t n_bytes() override {
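
From the caller's side the refactor is transparent: llama_state_seq_get_data_ext() still fills a plain byte buffer, but the per-tensor ggml_backend_tensor_get() calls are now deferred and issued together when the writer is destroyed, which also makes it possible to overlap them (see the #else branch above). A rough way to observe the effect, using only functions that already appear in this diff; the timing scaffolding is illustrative.

// Sketch: time a partial sequence-state checkpoint through the public API.
#include "llama.h"

#include <chrono>
#include <cstdint>
#include <cstdio>
#include <vector>

static std::vector<uint8_t> save_seq_checkpoint(llama_context * ctx, llama_seq_id seq_id) {
    const size_t size = llama_state_seq_get_size_ext(ctx, seq_id, LLAMA_STATE_SEQ_FLAGS_PARTIAL_ONLY);

    std::vector<uint8_t> data(size);

    const auto t0 = std::chrono::steady_clock::now();
    const size_t n = llama_state_seq_get_data_ext(ctx, data.data(), size, seq_id, LLAMA_STATE_SEQ_FLAGS_PARTIAL_ONLY);
    const auto t1 = std::chrono::steady_clock::now();

    if (n != size) {
        fprintf(stderr, "checkpoint size mismatch: expected %zu, got %zu\n", size, n);
        data.clear();
        return data;
    }

    fprintf(stderr, "checkpoint: %zu bytes in %.2f ms\n", n,
            std::chrono::duration<double, std::milli>(t1 - t0).count());

    return data;
}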

View File

@@ -1,5 +1,7 @@
#include "llama-io.h"
#include <vector>
void llama_io_write_i::write_string(const std::string & str) {
uint32_t str_size = str.size();
@@ -9,7 +11,10 @@ void llama_io_write_i::write_string(const std::string & str) {
void llama_io_read_i::read_string(std::string & str) {
uint32_t str_size;
read_to(&str_size, sizeof(str_size));
read(&str_size, sizeof(str_size));
str.assign((const char *) read(str_size), str_size);
std::vector<char> buf(str_size);
read(buf.data(), str_size);
str.assign(buf.data(), str_size);
}

View File

@@ -25,8 +25,8 @@ public:
llama_io_read_i() = default;
virtual ~llama_io_read_i() = default;
virtual const uint8_t * read(size_t size) = 0;
virtual void read_to(void * dst, size_t size) = 0;
virtual void read(void * dst, size_t size) = 0;
virtual void read_tensor(ggml_tensor * tensor, size_t offset, size_t size) = 0;
// bytes read so far
virtual size_t n_bytes() = 0;

View File

@@ -1900,14 +1900,14 @@ void llama_kv_cache::state_read(llama_io_read_i & io, llama_seq_id seq_id, llama
GGML_ASSERT(seq_id == -1 || (seq_id >= 0 && (size_t) seq_id < seq_to_stream.size()));
uint32_t n_stream_cur;
io.read_to(&n_stream_cur, sizeof(n_stream_cur));
io.read(&n_stream_cur, sizeof(n_stream_cur));
if (n_stream_cur != n_stream) {
throw std::runtime_error("n_stream mismatch");
}
for (uint32_t s = 0; s < n_stream; ++s) {
uint32_t cell_count;
io.read_to(&cell_count, sizeof(cell_count));
io.read(&cell_count, sizeof(cell_count));
if (cell_count == 0) {
continue;
@@ -2082,8 +2082,8 @@ bool llama_kv_cache::state_read_meta(llama_io_read_i & io, uint32_t strm, uint32
llama_pos pos;
uint32_t n_seq_id;
io.read_to(&pos, sizeof(pos));
io.read_to(&n_seq_id, sizeof(n_seq_id));
io.read(&pos, sizeof(pos));
io.read(&n_seq_id, sizeof(n_seq_id));
if (n_seq_id != 1) {
LLAMA_LOG_ERROR("%s: invalid seq_id-agnostic kv cell\n", __func__);
@@ -2092,7 +2092,7 @@ bool llama_kv_cache::state_read_meta(llama_io_read_i & io, uint32_t strm, uint32
if (hparams.n_pos_per_embd() > 1) {
llama_kv_cell_ext ext;
io.read_to(&ext, sizeof(ext));
io.read(&ext, sizeof(ext));
ubatch.pos[i + ubatch.n_tokens] = ext.y;
ubatch.pos[i + ubatch.n_tokens*2] = ext.x;
@@ -2101,7 +2101,7 @@ bool llama_kv_cache::state_read_meta(llama_io_read_i & io, uint32_t strm, uint32
// read the sequence id, but directly discard it - we will use dest_seq_id instead
{
llama_seq_id seq_id;
io.read_to(&seq_id, sizeof(seq_id));
io.read(&seq_id, sizeof(seq_id));
}
ubatch.pos[i] = pos;
@@ -2143,20 +2143,20 @@ bool llama_kv_cache::state_read_meta(llama_io_read_i & io, uint32_t strm, uint32
llama_pos pos;
uint32_t n_seq_id;
io.read_to(&pos, sizeof(pos));
io.read_to(&n_seq_id, sizeof(n_seq_id));
io.read(&pos, sizeof(pos));
io.read(&n_seq_id, sizeof(n_seq_id));
cells.pos_set(i, pos);
if (hparams.n_pos_per_embd() > 1) {
llama_kv_cell_ext ext;
io.read_to(&ext, sizeof(ext));
io.read(&ext, sizeof(ext));
cells.ext_set(i, ext);
}
for (uint32_t j = 0; j < n_seq_id; ++j) {
llama_seq_id seq_id;
io.read_to(&seq_id, sizeof(seq_id));
io.read(&seq_id, sizeof(seq_id));
if (seq_id < 0 || (uint32_t) seq_id >= n_seq_max) {
LLAMA_LOG_ERROR("%s: invalid seq_id, %d is out of range [0, %u)\n", __func__, seq_id, n_seq_max);
@@ -2189,8 +2189,8 @@ bool llama_kv_cache::state_read_data(llama_io_read_i & io, uint32_t strm, uint32
uint32_t v_trans;
uint32_t n_layer;
io.read_to(&v_trans, sizeof(v_trans));
io.read_to(&n_layer, sizeof(n_layer));
io.read(&v_trans, sizeof(v_trans));
io.read(&n_layer, sizeof(n_layer));
if (n_layer != layers.size()) {
LLAMA_LOG_ERROR("%s: mismatched layer count (%u instead of %u)\n", __func__, n_layer, (uint32_t) layers.size());
@@ -2217,7 +2217,7 @@ bool llama_kv_cache::state_read_data(llama_io_read_i & io, uint32_t strm, uint32
// Read type of key
int32_t k_type_i_ref;
io.read_to(&k_type_i_ref, sizeof(k_type_i_ref));
io.read(&k_type_i_ref, sizeof(k_type_i_ref));
const int32_t k_type_i = (int32_t) k->type;
if (k_type_i != k_type_i_ref) {
LLAMA_LOG_ERROR("%s: mismatched key type (%d != %d, layer %d)\n", __func__, k_type_i, k_type_i_ref, il);
@@ -2226,7 +2226,7 @@ bool llama_kv_cache::state_read_data(llama_io_read_i & io, uint32_t strm, uint32
// Read row size of key
uint64_t k_size_row_ref;
io.read_to(&k_size_row_ref, sizeof(k_size_row_ref));
io.read(&k_size_row_ref, sizeof(k_size_row_ref));
const size_t k_size_row = ggml_row_size(k->type, n_embd_k_gqa);
if (k_size_row != k_size_row_ref) {
LLAMA_LOG_ERROR("%s: mismatched key row size (%zu != %zu, layer %d)\n", __func__, k_size_row, (size_t) k_size_row_ref, il);
@@ -2236,13 +2236,12 @@ bool llama_kv_cache::state_read_data(llama_io_read_i & io, uint32_t strm, uint32
if (cell_count) {
if (sinfo.is_contiguous()) {
// Fast path: contiguous cells, single memcpy
ggml_backend_tensor_set(k, io.read(cell_count * k_size_row), sinfo.head() * k_size_row, cell_count * k_size_row);
io.read_tensor(k, sinfo.head() * k_size_row, cell_count * k_size_row);
} else {
// Slow path: scatter to non-contiguous positions
const void * src = io.read(cell_count * k_size_row);
for (uint32_t i = 0; i < cell_count; ++i) {
const size_t dst_offset = sinfo.idxs[0][i] * k_size_row;
ggml_backend_tensor_set(k, (const char*)src + i * k_size_row, dst_offset, k_size_row);
io.read_tensor(k, dst_offset, k_size_row);
}
}
}
@@ -2261,7 +2260,7 @@ bool llama_kv_cache::state_read_data(llama_io_read_i & io, uint32_t strm, uint32
// Read type of value
int32_t v_type_i_ref;
io.read_to(&v_type_i_ref, sizeof(v_type_i_ref));
io.read(&v_type_i_ref, sizeof(v_type_i_ref));
const int32_t v_type_i = (int32_t) v->type;
if (v_type_i != v_type_i_ref) {
LLAMA_LOG_ERROR("%s: mismatched value type (%d != %d, layer %d)\n", __func__, v_type_i, v_type_i_ref, il);
@@ -2270,7 +2269,7 @@ bool llama_kv_cache::state_read_data(llama_io_read_i & io, uint32_t strm, uint32
// Read row size of value
uint64_t v_size_row_ref;
io.read_to(&v_size_row_ref, sizeof(v_size_row_ref));
io.read(&v_size_row_ref, sizeof(v_size_row_ref));
const size_t v_size_row = ggml_row_size(v->type, n_embd_v_gqa);
if (v_size_row != v_size_row_ref) {
LLAMA_LOG_ERROR("%s: mismatched value row size (%zu != %zu, layer %d)\n", __func__, v_size_row, (size_t) v_size_row_ref, il);
@@ -2280,13 +2279,12 @@ bool llama_kv_cache::state_read_data(llama_io_read_i & io, uint32_t strm, uint32
if (cell_count) {
if (sinfo.is_contiguous()) {
// Fast path: contiguous cells, single memcpy
ggml_backend_tensor_set(v, io.read(cell_count * v_size_row), sinfo.head() * v_size_row, cell_count * v_size_row);
io.read_tensor(v, sinfo.head() * v_size_row, cell_count * v_size_row);
} else {
// Slow path: scatter to non-contiguous positions
const void * src = io.read(cell_count * v_size_row);
for (uint32_t i = 0; i < cell_count; ++i) {
const size_t dst_offset = sinfo.idxs[0][i] * v_size_row;
ggml_backend_tensor_set(v, (const char*)src + i * v_size_row, dst_offset, v_size_row);
io.read_tensor(v, dst_offset, v_size_row);
}
}
}
@@ -2305,7 +2303,7 @@ bool llama_kv_cache::state_read_data(llama_io_read_i & io, uint32_t strm, uint32
// Read type of value
int32_t v_type_i_ref;
io.read_to(&v_type_i_ref, sizeof(v_type_i_ref));
io.read(&v_type_i_ref, sizeof(v_type_i_ref));
const int32_t v_type_i = (int32_t) v->type;
if (v_type_i != v_type_i_ref) {
LLAMA_LOG_ERROR("%s: mismatched value type (%d != %d, layer %d)\n", __func__, v_type_i, v_type_i_ref, il);
@@ -2314,7 +2312,7 @@ bool llama_kv_cache::state_read_data(llama_io_read_i & io, uint32_t strm, uint32
// Read element size of value
uint32_t v_size_el_ref;
io.read_to(&v_size_el_ref, sizeof(v_size_el_ref));
io.read(&v_size_el_ref, sizeof(v_size_el_ref));
const size_t v_size_el = ggml_type_size(v->type);
if (v_size_el != v_size_el_ref) {
LLAMA_LOG_ERROR("%s: mismatched value element size (%zu != %zu, layer %d)\n", __func__, v_size_el, (size_t) v_size_el_ref, il);
@@ -2323,7 +2321,7 @@ bool llama_kv_cache::state_read_data(llama_io_read_i & io, uint32_t strm, uint32
// Read GQA embedding size
uint32_t n_embd_v_gqa_ref;
io.read_to(&n_embd_v_gqa_ref, sizeof(n_embd_v_gqa_ref));
io.read(&n_embd_v_gqa_ref, sizeof(n_embd_v_gqa_ref));
if (n_embd_v_gqa != n_embd_v_gqa_ref) {
LLAMA_LOG_ERROR("%s: mismatched GQA embedding size (%u != %u, layer %d)\n", __func__, n_embd_v_gqa, n_embd_v_gqa_ref, il);
return false;
@@ -2335,15 +2333,14 @@ bool llama_kv_cache::state_read_data(llama_io_read_i & io, uint32_t strm, uint32
const uint32_t h = sinfo.head();
for (uint32_t j = 0; j < n_embd_v_gqa; ++j) {
const size_t dst_offset = (h + j * cells.size()) * v_size_el;
ggml_backend_tensor_set(v, io.read(cell_count * v_size_el), dst_offset, cell_count * v_size_el);
io.read_tensor(v, dst_offset, cell_count * v_size_el);
}
} else {
// Slow path: scatter to non-contiguous positions
for (uint32_t j = 0; j < n_embd_v_gqa; ++j) {
const void * src = io.read(cell_count * v_size_el);
for (uint32_t i = 0; i < cell_count; ++i) {
const size_t dst_offset = (sinfo.idxs[0][i] + j * cells.size()) * v_size_el;
ggml_backend_tensor_set(v, (const char*)src + i * v_size_el, dst_offset, v_size_el);
io.read_tensor(v, dst_offset, v_size_el);
}
}
}

View File

@@ -743,7 +743,7 @@ void llama_memory_recurrent::state_read(llama_io_read_i & io, llama_seq_id seq_i
GGML_UNUSED(flags);
uint32_t cell_count;
io.read_to(&cell_count, sizeof(cell_count));
io.read(&cell_count, sizeof(cell_count));
bool res = true;
@@ -879,8 +879,8 @@ bool llama_memory_recurrent::state_read_meta(llama_io_read_i & io, uint32_t cell
llama_pos pos;
uint32_t n_seq_id;
io.read_to(&pos, sizeof(pos));
io.read_to(&n_seq_id, sizeof(n_seq_id));
io.read(&pos, sizeof(pos));
io.read(&n_seq_id, sizeof(n_seq_id));
if (n_seq_id != 0) {
LLAMA_LOG_ERROR("%s: invalid seq_id-agnostic kv cell\n", __func__);
@@ -920,14 +920,14 @@ bool llama_memory_recurrent::state_read_meta(llama_io_read_i & io, uint32_t cell
llama_pos pos;
uint32_t n_seq_id;
io.read_to(&pos, sizeof(pos));
io.read_to(&n_seq_id, sizeof(n_seq_id));
io.read(&pos, sizeof(pos));
io.read(&n_seq_id, sizeof(n_seq_id));
cell.pos = pos;
for (uint32_t j = 0; j < n_seq_id; ++j) {
llama_seq_id seq_id;
io.read_to(&seq_id, sizeof(seq_id));
io.read(&seq_id, sizeof(seq_id));
if (seq_id < 0 || (uint32_t) seq_id >= this->n_seq_max) {
LLAMA_LOG_ERROR("%s: invalid seq_id, %d is out of range [0, %u)\n", __func__, seq_id, this->n_seq_max);
@@ -961,8 +961,8 @@ bool llama_memory_recurrent::state_read_meta(llama_io_read_i & io, uint32_t cell
bool llama_memory_recurrent::state_read_data(llama_io_read_i & io, uint32_t cell_count) {
uint32_t s_trans;
uint32_t n_layer;
io.read_to(&s_trans, sizeof(s_trans));
io.read_to(&n_layer, sizeof(n_layer));
io.read(&s_trans, sizeof(s_trans));
io.read(&n_layer, sizeof(n_layer));
if (n_layer != hparams.n_layer) {
LLAMA_LOG_ERROR("%s: mismatched layer count (%u instead of %u)\n", __func__, n_layer, hparams.n_layer);
@@ -984,7 +984,7 @@ bool llama_memory_recurrent::state_read_data(llama_io_read_i & io, uint32_t cell
// Read type of key
int32_t r_type_i_ref;
io.read_to(&r_type_i_ref, sizeof(r_type_i_ref));
io.read(&r_type_i_ref, sizeof(r_type_i_ref));
const int32_t r_type_i = (int32_t) r_l[il]->type;
if (r_type_i != r_type_i_ref) {
LLAMA_LOG_ERROR("%s: mismatched r type (%d != %d, layer %d)\n", __func__, r_type_i, r_type_i_ref, il);
@@ -993,7 +993,7 @@ bool llama_memory_recurrent::state_read_data(llama_io_read_i & io, uint32_t cell
// Read row size of key
uint64_t r_size_row_ref;
io.read_to(&r_size_row_ref, sizeof(r_size_row_ref));
io.read(&r_size_row_ref, sizeof(r_size_row_ref));
const size_t r_size_row = ggml_row_size(r_l[il]->type, hparams.n_embd_r());
if (r_size_row != r_size_row_ref) {
LLAMA_LOG_ERROR("%s: mismatched r row size (%zu != %zu, layer %d)\n", __func__, r_size_row, (size_t) r_size_row_ref, il);
@@ -1002,7 +1002,7 @@ bool llama_memory_recurrent::state_read_data(llama_io_read_i & io, uint32_t cell
if (cell_count) {
// Read and set the keys for the whole cell range
ggml_backend_tensor_set(r_l[il], io.read(cell_count * r_size_row), head * r_size_row, cell_count * r_size_row);
io.read_tensor(r_l[il], head * r_size_row, cell_count * r_size_row);
}
}
@@ -1013,7 +1013,7 @@ bool llama_memory_recurrent::state_read_data(llama_io_read_i & io, uint32_t cell
// Read type of value
int32_t s_type_i_ref;
io.read_to(&s_type_i_ref, sizeof(s_type_i_ref));
io.read(&s_type_i_ref, sizeof(s_type_i_ref));
const int32_t s_type_i = (int32_t)s_l[il]->type;
if (s_type_i != s_type_i_ref) {
@@ -1023,7 +1023,7 @@ bool llama_memory_recurrent::state_read_data(llama_io_read_i & io, uint32_t cell
// Read row size of value
uint64_t s_size_row_ref;
io.read_to(&s_size_row_ref, sizeof(s_size_row_ref));
io.read(&s_size_row_ref, sizeof(s_size_row_ref));
const size_t s_size_row = ggml_row_size(s_l[il]->type, hparams.n_embd_s());
if (s_size_row != s_size_row_ref) {
LLAMA_LOG_ERROR("%s: mismatched s row size (%zu != %zu, layer %d)\n", __func__, s_size_row, (size_t) s_size_row_ref, il);
@@ -1032,7 +1032,7 @@ bool llama_memory_recurrent::state_read_data(llama_io_read_i & io, uint32_t cell
if (cell_count) {
// Read and set the values for the whole cell range
ggml_backend_tensor_set(s_l[il], io.read(cell_count * s_size_row), head * s_size_row, cell_count * s_size_row);
io.read_tensor(s_l[il], head * s_size_row, cell_count * s_size_row);
}
}
} else {
@@ -1045,7 +1045,7 @@ bool llama_memory_recurrent::state_read_data(llama_io_read_i & io, uint32_t cell
// Read type of value
int32_t s_type_i_ref;
io.read_to(&s_type_i_ref, sizeof(s_type_i_ref));
io.read(&s_type_i_ref, sizeof(s_type_i_ref));
const int32_t s_type_i = (int32_t)s_l[il]->type;
if (s_type_i != s_type_i_ref) {
LLAMA_LOG_ERROR("%s: mismatched s type (%d != %d, layer %d)\n", __func__, s_type_i, s_type_i_ref, il);
@@ -1054,7 +1054,7 @@ bool llama_memory_recurrent::state_read_data(llama_io_read_i & io, uint32_t cell
// Read element size of value
uint32_t s_size_el_ref;
io.read_to(&s_size_el_ref, sizeof(s_size_el_ref));
io.read(&s_size_el_ref, sizeof(s_size_el_ref));
const size_t s_size_el = ggml_type_size(s_l[il]->type);
if (s_size_el != s_size_el_ref) {
LLAMA_LOG_ERROR("%s: mismatched s element size (%zu != %zu, layer %d)\n", __func__, s_size_el, (size_t) s_size_el_ref, il);
@@ -1063,7 +1063,7 @@ bool llama_memory_recurrent::state_read_data(llama_io_read_i & io, uint32_t cell
// Read state embedding size
uint32_t n_embd_s_ref;
io.read_to(&n_embd_s_ref, sizeof(n_embd_s_ref));
io.read(&n_embd_s_ref, sizeof(n_embd_s_ref));
if (n_embd_s != n_embd_s_ref) {
LLAMA_LOG_ERROR("%s: mismatched s embedding size (%u != %u, layer %d)\n", __func__, n_embd_s, n_embd_s_ref, il);
return false;
@@ -1073,7 +1073,7 @@ bool llama_memory_recurrent::state_read_data(llama_io_read_i & io, uint32_t cell
// For each row in the transposed matrix, read the values for the whole cell range
for (uint32_t j = 0; j < n_embd_s; ++j) {
const size_t dst_offset = (head + j * size) * s_size_el;
ggml_backend_tensor_set(s_l[il], io.read(cell_count * s_size_el), dst_offset, cell_count * s_size_el);
io.read_tensor(s_l[il], dst_offset, cell_count * s_size_el);
}
}
}

View File

@@ -36,7 +36,7 @@ using json = nlohmann::ordered_json;
constexpr int HTTP_POLLING_SECONDS = 1;
static server_prompt_checkpoint server_get_checkpoint(llama_context * ctx, int id, int64_t n_tokens, llama_pos pos_min = -1, llama_pos pos_max = -1) {
static void server_prompt_checkpoint_update(server_prompt_checkpoint & ckpt, llama_context * ctx, int id, int64_t n_tokens, llama_pos pos_min = -1, llama_pos pos_max = -1) {
if (pos_min == -1) {
pos_min = llama_memory_seq_pos_min(llama_get_memory(ctx), id);
}
@@ -46,19 +46,15 @@ static server_prompt_checkpoint server_get_checkpoint(llama_context * ctx, int i
const size_t checkpoint_size = llama_state_seq_get_size_ext(ctx, id, LLAMA_STATE_SEQ_FLAGS_PARTIAL_ONLY);
auto cur = server_prompt_checkpoint {
/*.pos_min = */ pos_min,
/*.pos_max = */ pos_max,
/*.n_tokens = */ n_tokens,
/*.data = */ std::vector<uint8_t>(checkpoint_size),
};
ckpt.pos_min = pos_min;
ckpt.pos_max = pos_max;
ckpt.n_tokens = n_tokens;
ckpt.data.resize(checkpoint_size);
const size_t n = llama_state_seq_get_data_ext(ctx, cur.data.data(), checkpoint_size, id, LLAMA_STATE_SEQ_FLAGS_PARTIAL_ONLY);
const size_t n = llama_state_seq_get_data_ext(ctx, ckpt.data.data(), checkpoint_size, id, LLAMA_STATE_SEQ_FLAGS_PARTIAL_ONLY);
if (n != checkpoint_size) {
GGML_ABORT("checkpoint size mismatch: expected %zu, got %zu\n", checkpoint_size, n);
}
return cur;
}
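
The switch from returning a fresh server_prompt_checkpoint by value to updating one in place matters because the checkpoint buffer can be large: with a persistent object, ckpt.data.resize() reuses the previous allocation once its capacity is big enough. A self-contained illustration of that effect (the struct below is a simplified stand-in, not the server's type):

// Sketch: in-place update lets the vector reuse its capacity across checkpoints.
#include <cstdint>
#include <cstdio>
#include <vector>

struct checkpoint {
    int64_t n_tokens = 0;
    std::vector<uint8_t> data;
};

static void checkpoint_update(checkpoint & ckpt, int64_t n_tokens, size_t size) {
    ckpt.n_tokens = n_tokens;
    ckpt.data.resize(size); // no reallocation while size <= capacity()
    // ... fill ckpt.data with the serialized state ...
}

int main() {
    checkpoint ckpt; // persistent across iterations, like the slot's spec_ckpt
    for (int i = 0; i < 4; ++i) {
        checkpoint_update(ckpt, /*n_tokens =*/ 128 * (i + 1), /*size =*/ 8u * 1024 * 1024);
        printf("update %d: size = %zu, capacity = %zu\n", i, ckpt.data.size(), ckpt.data.capacity());
    }
    return 0;
}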
// state diagram: https://github.com/ggml-org/llama.cpp/pull/9283
@@ -364,7 +360,12 @@ struct server_slot {
if (!spec_draft.empty() && ctx_seq_rm_type == COMMON_CONTEXT_SEQ_RM_TYPE_FULL) {
const auto n_tokens = prompt.tokens.size();
spec_ckpt = server_get_checkpoint(ctx, this->id, n_tokens);
//const int64_t t_start = ggml_time_us();
server_prompt_checkpoint_update(spec_ckpt, ctx, this->id, n_tokens);
//const int64_t t_total = ggml_time_us() - t_start;
//printf("checkpoint total: %f ms\n", t_total / 1000.0);
SLT_DBG(*this, "created speculative checkpoint (pos_min = %d, pos_max = %d, n_tokens = %zu, size = %.3f MiB)\n",
spec_ckpt.pos_min, spec_ckpt.pos_max, n_tokens, (float) spec_ckpt.data.size() / 1024 / 1024);
@@ -1836,7 +1837,8 @@ private:
slot.prompt.checkpoints.erase(slot.prompt.checkpoints.begin());
}
const auto & cur = slot.prompt.checkpoints.emplace_back(server_get_checkpoint(ctx, slot.id, slot.prompt.n_tokens() - n_tokens_cur, pos_min, pos_max));
auto & cur = slot.prompt.checkpoints.emplace_back();
server_prompt_checkpoint_update(cur, ctx, slot.id, slot.prompt.n_tokens() - n_tokens_cur, pos_min, pos_max);
SLT_WRN(slot,
"created context checkpoint %d of %d (pos_min = %d, pos_max = %d, n_tokens = %" PRId64 ", size = %.3f MiB)\n",