Compare commits

...

10 Commits

Author SHA1 Message Date
Ruben Ortlam 938872e93f fix partial writes 2026-05-15 16:00:57 +02:00
Ruben Ortlam ff6ad60994 wider loads 2026-05-15 15:22:57 +02:00
Ruben Ortlam 13a55c8e50 deduplicate repacking code 2026-05-15 13:25:49 +02:00
Ruben Ortlam 57fb74fba3 add q4_1, q8_0, iq4_nl repacking 2026-05-15 13:10:19 +02:00
Ruben Ortlam 6906f78189 replace malloc/free with thread_local memory 2026-05-15 12:11:01 +02:00
Ruben Ortlam b64f294cbf add missing repacking functions 2026-05-15 12:04:08 +02:00
Ruben Ortlam b4e2621de8 add mxfp4 repacking 2026-05-15 11:58:13 +02:00
Ruben Ortlam b1243aa933 fix double semicolon 2026-05-15 11:23:59 +02:00
Ruben Ortlam 5c1e95c901 add coopmat2 support 2026-05-15 11:23:58 +02:00
Ruben Ortlam c285bb9838 vulkan: repack q4_0 into aligned arrays 2026-05-15 11:20:02 +02:00
12 changed files with 513 additions and 53 deletions
+207 -23
View File
@@ -76,6 +76,8 @@ DispatchLoaderDynamic & ggml_vk_default_dispatcher();
#define YIELD()
#endif
#define GGML_COMMON_DECL_CPP
#include "ggml-common.h"
#include "ggml-impl.h"
#include "ggml-backend-impl.h"
@@ -999,6 +1001,7 @@ struct vk_mat_mat_push_constants {
uint32_t k_split;
uint32_t ne02; uint32_t ne12; uint32_t broadcast2; uint32_t broadcast3;
uint32_t padded_N;
uint32_t deltas_offset;
};
#define MAT_VEC_FUSION_FLAGS_BIAS0 0x1
@@ -1020,6 +1023,7 @@ struct vk_mat_vec_push_constants {
uint32_t ne12;
uint32_t broadcast2;
uint32_t broadcast3;
uint32_t deltas_offset;
};
struct vk_mat_vec_p021_push_constants {
@@ -1054,6 +1058,7 @@ struct vk_mat_mat_id_push_constants {
uint32_t batch_stride_a; uint32_t batch_stride_b; uint32_t batch_stride_d;
uint32_t nei0; uint32_t nei1; uint32_t nbi1; uint32_t ne11;
uint32_t padded_N;
uint32_t deltas_offset;
};
struct vk_mat_vec_id_push_constants {
uint32_t ncols;
@@ -1068,6 +1073,7 @@ struct vk_mat_vec_id_push_constants {
uint32_t ne11;
uint32_t expert_i1;
uint32_t nbi1;
uint32_t deltas_offset;
};
struct vk_flash_attn_push_constants {
@@ -1976,7 +1982,7 @@ static uint64_t vk_tensor_offset(const ggml_tensor * tensor) {
static uint32_t get_misalign_bytes(const ggml_backend_vk_context * ctx, const ggml_tensor * t)
{
return ((vk_tensor_offset(t) + t->view_offs) & (ctx->device->properties.limits.minStorageBufferOffsetAlignment - 1));;
return ((vk_tensor_offset(t) + t->view_offs) & (ctx->device->properties.limits.minStorageBufferOffsetAlignment - 1));
}
template <typename T> void init_pushconst_tensor_offsets(ggml_backend_vk_context * ctx, T &p, const ggml_tensor * src0, const ggml_tensor * src1, const ggml_tensor * src2, const ggml_tensor * src3, ggml_tensor * dst) {
@@ -6674,6 +6680,8 @@ static void ggml_vk_host_get(const vk_device& device, const void * ptr, vk_buffe
}
}
static size_t ggml_vk_repack_size_tensor(const ggml_tensor * tensor);
static vk_subbuffer ggml_vk_tensor_subbuffer(
const ggml_backend_vk_context * ctx, const ggml_tensor * tensor, bool allow_misalign = false) {
@@ -6689,7 +6697,7 @@ static vk_subbuffer ggml_vk_tensor_subbuffer(
}
GGML_ASSERT(buffer != nullptr);
size_t size = ggml_nbytes(tensor);
size_t size = ggml_vk_repack_size_tensor(tensor);
size_t misalign_bytes = offset & (ctx->device->properties.limits.minStorageBufferOffsetAlignment - 1);
// The shader must support misaligned offsets when indexing into the buffer
@@ -7284,6 +7292,134 @@ static void ggml_vk_buffer_memset(vk_buffer& dst, size_t offset, uint32_t c, siz
ggml_vk_queue_command_pools_cleanup(dst->device);
}
constexpr uint32_t VULKAN_REPACK_ALIGNMENT = 256;
static void * ggml_vk_repack_scratch(size_t size) {
thread_local std::vector<uint8_t> buf;
if (buf.size() < size) {
buf.resize(size);
}
return buf.data();
}
static size_t ggml_vk_get_num_blocks(const ggml_tensor * tensor) {
const size_t num_blocks_per_row = tensor->ne[0] / ggml_blck_size(tensor->type);
return num_blocks_per_row * tensor->ne[1] * tensor->ne[2] * tensor->ne[3];
}
struct vk_repack_type_info {
size_t quant_bytes;
size_t delta_bytes;
size_t delta_elem_size;
};
static const vk_repack_type_info * ggml_vk_get_repack_info(ggml_type type) {
static const vk_repack_type_info q4_0_info = { 16, 2, 2 };
static const vk_repack_type_info q4_1_info = { 16, 4, 2 };
static const vk_repack_type_info q8_0_info = { 32, 2, 2 };
static const vk_repack_type_info iq4_nl_info = { 16, 2, 2 };
static const vk_repack_type_info mxfp4_info = { 16, 1, 1 };
switch (type) {
case GGML_TYPE_Q4_0: return &q4_0_info;
case GGML_TYPE_Q4_1: return &q4_1_info;
case GGML_TYPE_Q8_0: return &q8_0_info;
case GGML_TYPE_IQ4_NL: return &iq4_nl_info;
case GGML_TYPE_MXFP4: return &mxfp4_info;
default: return nullptr;
}
}
static size_t ggml_vk_repack_quants_region(const vk_repack_type_info * info, size_t n_blocks) {
return GGML_PAD(n_blocks * info->quant_bytes, VULKAN_REPACK_ALIGNMENT);
}
static size_t ggml_vk_repack_size(const vk_repack_type_info * info, size_t n_blocks) {
return ggml_vk_repack_quants_region(info, n_blocks) + n_blocks * info->delta_bytes;
}
static size_t ggml_vk_repack_size_tensor(const ggml_tensor * tensor) {
const auto * info = ggml_vk_get_repack_info(tensor->type);
if (info) {
return ggml_vk_repack_size(info, ggml_vk_get_num_blocks(tensor));
}
return ggml_nbytes(tensor);
}
static uint32_t ggml_vk_get_deltas_offset(const ggml_tensor * tensor) {
const auto * info = ggml_vk_get_repack_info(tensor->type);
if (!info) {
return 0;
}
return ggml_vk_repack_quants_region(info, ggml_vk_get_num_blocks(tensor)) / info->delta_elem_size;
}
static void ggml_vk_repack_pack(const vk_repack_type_info * info, size_t n_blocks,
const void * data, void * quants_dst, void * deltas_dst) {
const size_t block_size = info->quant_bytes + info->delta_bytes;
uint8_t * dst_q = (uint8_t *)quants_dst;
uint8_t * dst_d = (uint8_t *)deltas_dst;
const uint8_t * src = (const uint8_t *)data;
for (size_t i = 0; i < n_blocks; i++) {
memcpy(dst_q + info->quant_bytes * i, src + block_size * i + info->delta_bytes, info->quant_bytes);
memcpy(dst_d + info->delta_bytes * i, src + block_size * i, info->delta_bytes);
}
}
static void ggml_vk_repack_unpack(const vk_repack_type_info * info, size_t n_blocks,
const void * quants_src, const void * deltas_src, void * data) {
const size_t block_size = info->quant_bytes + info->delta_bytes;
const uint8_t * src_q = (const uint8_t *)quants_src;
const uint8_t * src_d = (const uint8_t *)deltas_src;
uint8_t * dst = (uint8_t *)data;
for (size_t i = 0; i < n_blocks; i++) {
memcpy(dst + block_size * i + info->delta_bytes, src_q + info->quant_bytes * i, info->quant_bytes);
memcpy(dst + block_size * i, src_d + info->delta_bytes * i, info->delta_bytes);
}
}
static void ggml_vk_repack_write(vk_buffer & buf, const ggml_tensor * tensor, size_t offset, const void * data, size_t size) {
const auto * info = ggml_vk_get_repack_info(tensor->type);
GGML_ASSERT(info);
const size_t block_size = info->quant_bytes + info->delta_bytes;
const size_t first_block = offset / block_size;
const size_t n_blocks_chunk = size / block_size;
const size_t n_blocks_total = ggml_vk_get_num_blocks(tensor);
const size_t quants_region = ggml_vk_repack_quants_region(info, n_blocks_total);
const size_t scratch_size = n_blocks_chunk * info->quant_bytes + n_blocks_chunk * info->delta_bytes;
void * scratch = ggml_vk_repack_scratch(scratch_size);
uint8_t * scratch_q = (uint8_t *)scratch;
uint8_t * scratch_d = scratch_q + n_blocks_chunk * info->quant_bytes;
ggml_vk_repack_pack(info, n_blocks_chunk, data, scratch_q, scratch_d);
const size_t buf_base = vk_tensor_offset(tensor) + tensor->view_offs;
ggml_vk_buffer_write(buf, buf_base + first_block * info->quant_bytes, scratch_q, n_blocks_chunk * info->quant_bytes);
ggml_vk_buffer_write(buf, buf_base + quants_region + first_block * info->delta_bytes, scratch_d, n_blocks_chunk * info->delta_bytes);
}
static void ggml_vk_repack_read(vk_buffer & buf, const ggml_tensor * tensor, size_t offset, void * data, size_t size) {
const auto * info = ggml_vk_get_repack_info(tensor->type);
GGML_ASSERT(info);
const size_t block_size = info->quant_bytes + info->delta_bytes;
const size_t first_block = offset / block_size;
const size_t n_blocks_chunk = size / block_size;
const size_t n_blocks_total = ggml_vk_get_num_blocks(tensor);
const size_t quants_region = ggml_vk_repack_quants_region(info, n_blocks_total);
const size_t scratch_size = n_blocks_chunk * info->quant_bytes + n_blocks_chunk * info->delta_bytes;
void * scratch = ggml_vk_repack_scratch(scratch_size);
uint8_t * scratch_q = (uint8_t *)scratch;
uint8_t * scratch_d = scratch_q + n_blocks_chunk * info->quant_bytes;
const size_t buf_base = vk_tensor_offset(tensor) + tensor->view_offs;
ggml_vk_buffer_read(buf, buf_base + first_block * info->quant_bytes, scratch_q, n_blocks_chunk * info->quant_bytes);
ggml_vk_buffer_read(buf, buf_base + quants_region + first_block * info->delta_bytes, scratch_d, n_blocks_chunk * info->delta_bytes);
ggml_vk_repack_unpack(info, n_blocks_chunk, scratch_q, scratch_d, data);
}
static uint32_t ggml_vk_guess_split_k(ggml_backend_vk_context * ctx, uint32_t m, uint32_t n, uint32_t k, bool disable_split_k, const vk_pipeline& pipeline) {
VK_LOG_DEBUG("ggml_vk_guess_split_k(" << m << ", " << n << ", " << k << ", " << disable_split_k << ")");
@@ -7383,7 +7519,7 @@ static void ggml_vk_matmul(
uint32_t m, uint32_t n, uint32_t k, uint32_t stride_a, uint32_t stride_b, uint32_t stride_d,
uint32_t batch_stride_a, uint32_t batch_stride_b, uint32_t batch_stride_d,
uint32_t split_k, uint32_t batch, uint32_t ne02, uint32_t ne12, uint32_t broadcast2, uint32_t broadcast3,
uint32_t padded_n) {
uint32_t padded_n, uint32_t deltas_offset) {
VK_LOG_DEBUG("ggml_vk_matmul(a: (" << a.buffer->buffer << ", " << a.offset << ", " << a.size << "), b: (" << b.buffer->buffer << ", " << b.offset << ", " << b.size << "), d: (" << d.buffer->buffer << ", " << d.offset << ", " << d.size << "), split_k: (" << (split_k_buffer.buffer != nullptr ? split_k_buffer.buffer->buffer : VK_NULL_HANDLE) << ", " << split_k_buffer.offset << ", " << split_k_buffer.size << "), m: " << m << ", n: " << n << ", k: " << k << ", stride_a: " << stride_a << ", stride_b: " << stride_b << ", stride_d: " << stride_d << ", batch_stride_a: " << batch_stride_a << ", batch_stride_b: " << batch_stride_b << ", batch_stride_d: " << batch_stride_d << ", split_k: " << split_k << ", batch: " << batch << ", ne02: " << ne02 << ", ne12: " << ne12 << ", broadcast2: " << broadcast2 << ", broadcast3: " << broadcast3 << ", padded_n: " << padded_n << ")");
if (split_k == 1) {
ggml_pipeline_request_descriptor_sets(ctx, pipeline, CEIL_DIV(batch, ctx->device->properties.limits.maxComputeWorkGroupCount[2]));
@@ -7392,7 +7528,7 @@ static void ggml_vk_matmul(
while (base_work_group_z < batch) {
uint32_t groups_z = std::min(batch - base_work_group_z, ctx->device->properties.limits.maxComputeWorkGroupCount[2]);
const vk_mat_mat_push_constants pc = { m, n, k, stride_a, stride_b, stride_d, batch_stride_a, batch_stride_b, batch_stride_d, base_work_group_z, batch, k, ne02, ne12, broadcast2, broadcast3, padded_n };
const vk_mat_mat_push_constants pc = { m, n, k, stride_a, stride_b, stride_d, batch_stride_a, batch_stride_b, batch_stride_d, base_work_group_z, batch, k, ne02, ne12, broadcast2, broadcast3, padded_n, deltas_offset };
ggml_vk_dispatch_pipeline(ctx, subctx, pipeline, { a, b, d }, pc, { m, n, groups_z });
base_work_group_z += groups_z;
}
@@ -7415,7 +7551,7 @@ static void ggml_vk_matmul(
while (base_work_group_z < batch) {
uint32_t groups_z = std::min(batch - base_work_group_z, ctx->device->properties.limits.maxComputeWorkGroupCount[2]);
const vk_mat_mat_push_constants pc1 = { m, n, k, stride_a, stride_b, stride_d, batch_stride_a, batch_stride_b, batch_stride_d, base_work_group_z, batch, k_split, ne02, ne12, broadcast2, broadcast3, padded_n };
const vk_mat_mat_push_constants pc1 = { m, n, k, stride_a, stride_b, stride_d, batch_stride_a, batch_stride_b, batch_stride_d, base_work_group_z, batch, k_split, ne02, ne12, broadcast2, broadcast3, padded_n, deltas_offset };
// Make sure enough workgroups get assigned for split k to work
ggml_vk_dispatch_pipeline(ctx, subctx, pipeline, { a, b, split_k_buffer }, pc1, { (CEIL_DIV(m, pipeline->wg_denoms[0]) * pipeline->wg_denoms[0]) * split_k, n, groups_z });
base_work_group_z += groups_z;
@@ -7470,13 +7606,13 @@ static void ggml_vk_matmul_id(
uint32_t m, uint32_t n, uint32_t k, uint32_t stride_a, uint32_t stride_b, uint32_t stride_d,
uint32_t batch_stride_a, uint32_t batch_stride_b, uint32_t batch_stride_d,
uint32_t n_as, uint32_t nei0, uint32_t nei1, uint32_t nbi1, uint32_t ne11,
uint32_t padded_n) {
uint32_t padded_n, uint32_t deltas_offset) {
VK_LOG_DEBUG("ggml_vk_matmul_id(a: (" << a.buffer->buffer << ", " << a.offset << ", " << a.size << "), b: (" << b.buffer->buffer << ", " << b.offset << ", " << b.size << "), d: (" << d.buffer->buffer << ", " << d.offset << ", " << d.size << "), ids: (" << ids.buffer->buffer << ", " << ids.offset << ", " << ids.size << "), expert_count: (" << expert_count_buf.buffer->buffer << ", " << expert_count_buf.offset << ", " << expert_count_buf.size << "), " <<
"m: " << m << ", n: " << n << ", k: " << k << ", stride_a: " << stride_a << ", stride_b: " << stride_b << ", stride_d: " << stride_d << ", " <<
"batch_stride_a: " << batch_stride_a << ", batch_stride_b: " << batch_stride_b << ", batch_stride_d: " << batch_stride_d << ", " <<
"n_as: " << n_as << ", nei0: " << nei0 << ", nei1: " << nei1 << ", nbi1: " << nbi1 << ", ne11: " << ne11 << ")");
const vk_mat_mat_id_push_constants pc = { m, n, k, stride_a, stride_b, stride_d, batch_stride_a, batch_stride_b, batch_stride_d,
nei0, nei1, nbi1, ne11, padded_n };
nei0, nei1, nbi1, ne11, padded_n, deltas_offset };
ggml_vk_dispatch_pipeline(ctx, subctx, pipeline, { a, b, d, ids, expert_count_buf }, pc, { m, nei1, n_as });
}
@@ -7777,7 +7913,7 @@ static void ggml_vk_mul_mat_q_f16(ggml_backend_vk_context * ctx, vk_context& sub
const uint32_t split_k = ggml_vk_guess_split_k(ctx, ne01, ne11, ne10, disable_split_k, pipeline);
const uint64_t qx_sz = ggml_type_size(src0->type) * x_ne / ggml_blck_size(src0->type);
const uint64_t qx_sz = ggml_vk_repack_size_tensor(src0);
const uint64_t qy_sz = ggml_type_size(src1->type) * y_ne / ggml_blck_size(src1->type);
const uint64_t x_sz = !qx_needs_dequant ? qx_sz : sizeof(ggml_fp16_t) * x_ne;
const uint64_t y_sz = quantize_y ? (ggml_vk_align_size(y_ne, 128) * ggml_type_size(GGML_TYPE_Q8_1) / ggml_blck_size(GGML_TYPE_Q8_1)) : (y_f32_kernel ? sizeof(float) * y_ne : sizeof(ggml_fp16_t) * y_ne);
@@ -7925,6 +8061,8 @@ static void ggml_vk_mul_mat_q_f16(ggml_backend_vk_context * ctx, vk_context& sub
stride_batch_y = src1->nb[0] / ggml_type_size(src1->type);
}
const uint32_t deltas_offset = ggml_vk_get_deltas_offset(src0);
// compute
ggml_vk_matmul(
ctx, subctx, pipeline,
@@ -7932,7 +8070,7 @@ static void ggml_vk_mul_mat_q_f16(ggml_backend_vk_context * ctx, vk_context& sub
ggml_vk_subbuffer(ctx, d_D, d_buf_offset), { ctx->prealloc_split_k, 0, d_sz * split_k },
ne01, ne11, ne10,
ne10, ne10, stride_d, stride_batch_x, stride_batch_y, stride_batch_d,
split_k, ne12*ne13, ne02, ne12, r2, r3, padded_n
split_k, ne12*ne13, ne02, ne12, r2, r3, padded_n, deltas_offset
); // NOLINT
if (x_non_contig || qx_needs_dequant) {
@@ -8099,7 +8237,7 @@ static void ggml_vk_mul_mat_vec_q_f16(ggml_backend_vk_context * ctx, vk_context&
const uint64_t x_ne = ggml_nelements(src0);
const uint64_t y_ne = ggml_nelements(src1);
const uint64_t qx_sz = ggml_vk_align_size(ggml_type_size(src0->type) * x_ne / ggml_blck_size(src0->type), ctx->device->properties.limits.minStorageBufferOffsetAlignment);
const uint64_t qx_sz = ggml_vk_align_size(ggml_vk_repack_size_tensor(src0), ctx->device->properties.limits.minStorageBufferOffsetAlignment);
const uint64_t x_sz = x_non_contig ? ggml_vk_align_size(ggml_type_size(src0->type) * x_ne, ctx->device->properties.limits.minStorageBufferOffsetAlignment) : qx_sz;
const uint64_t y_sz = quantize_y ? (ggml_vk_align_size(y_ne, 128) * ggml_type_size(GGML_TYPE_Q8_1) / ggml_blck_size(GGML_TYPE_Q8_1)) :
(f16_f32_kernel ? sizeof(float) * y_ne : sizeof(ggml_fp16_t) * y_ne);
@@ -8225,6 +8363,8 @@ static void ggml_vk_mul_mat_vec_q_f16(ggml_backend_vk_context * ctx, vk_context&
ggml_pipeline_request_descriptor_sets(ctx, dmmv, CEIL_DIV(ne12 * ne13, ctx->device->properties.limits.maxComputeWorkGroupCount[1]));
const uint32_t deltas_offset = ggml_vk_get_deltas_offset(src0);
uint32_t base_work_group_y = 0;
while (base_work_group_y < ne12 * ne13) {
@@ -8234,6 +8374,7 @@ static void ggml_vk_mul_mat_vec_q_f16(ggml_backend_vk_context * ctx, vk_context&
stride_batch_x, stride_batch_y, stride_batch_d,
fusion_flags, base_work_group_y,
(uint32_t)ne02, (uint32_t)ne12, (uint32_t)r2, (uint32_t)r3,
deltas_offset,
};
ggml_vk_dispatch_pipeline(ctx, subctx, dmmv,
{
@@ -8610,7 +8751,7 @@ static void ggml_vk_mul_mat_id_q_f16(ggml_backend_vk_context * ctx, vk_context&
const uint64_t y_ne = padded_n * ne10 * ne12 * ne13;
const uint64_t d_ne = ggml_nelements(dst);
const uint64_t qx_sz = ggml_type_size(src0->type) * x_ne / ggml_blck_size(src0->type);
const uint64_t qx_sz = ggml_vk_repack_size_tensor(src0);
const uint64_t qy_sz = ggml_type_size(src1->type) * y_ne / ggml_blck_size(src1->type);
const uint64_t x_sz = !qx_needs_dequant ? qx_sz : sizeof(ggml_fp16_t) * x_ne;
const uint64_t y_sz = quantize_y ? (ggml_vk_align_size(y_ne, 128) * ggml_type_size(GGML_TYPE_Q8_1) / ggml_blck_size(GGML_TYPE_Q8_1)) : (y_f32_kernel ? sizeof(float) * y_ne : sizeof(ggml_fp16_t) * y_ne);
@@ -8778,6 +8919,8 @@ static void ggml_vk_mul_mat_id_q_f16(ggml_backend_vk_context * ctx, vk_context&
stride_batch_y = src1->nb[0] / ggml_type_size(src1->type);
}
const uint32_t deltas_offset = ggml_vk_get_deltas_offset(src0);
// compute
ggml_vk_matmul_id(
ctx, subctx, pipeline,
@@ -8785,7 +8928,7 @@ static void ggml_vk_mul_mat_id_q_f16(ggml_backend_vk_context * ctx, vk_context&
{ d_D, d_buf_offset, d_sz }, { d_ids, ids_buf_offset, ids_sz }, expert_count_buf,
ne01, ne21, ne10, ne10, ne10, ne01,
stride_batch_x, stride_batch_y, ne20*ne21,
n_as, nei0, nei1, nbi1 / ggml_type_size(ids->type), ne11, padded_n
n_as, nei0, nei1, nbi1 / ggml_type_size(ids->type), ne11, padded_n, deltas_offset
); // NOLINT
if (x_non_contig || qx_needs_dequant) {
@@ -8877,7 +9020,7 @@ static void ggml_vk_mul_mat_vec_id_q_f16(ggml_backend_vk_context * ctx, vk_conte
const uint64_t x_ne = ggml_nelements(src0);
const uint64_t y_ne = ggml_nelements(src1);
const uint64_t qx_sz = ggml_vk_align_size(ggml_type_size(src0->type) * x_ne / ggml_blck_size(src0->type), ctx->device->properties.limits.minStorageBufferOffsetAlignment);
const uint64_t qx_sz = ggml_vk_align_size(ggml_vk_repack_size_tensor(src0), ctx->device->properties.limits.minStorageBufferOffsetAlignment);
const uint64_t x_sz = x_non_contig ? ggml_vk_align_size(ggml_type_size(src0->type) * x_ne, ctx->device->properties.limits.minStorageBufferOffsetAlignment) : qx_sz;
const uint64_t y_sz = quantize_y ? (ggml_vk_align_size(y_ne, 128) * ggml_type_size(GGML_TYPE_Q8_1) / ggml_blck_size(GGML_TYPE_Q8_1)) :
(f16_f32_kernel ? sizeof(float) * y_ne : sizeof(ggml_fp16_t) * y_ne);
@@ -9001,13 +9144,16 @@ static void ggml_vk_mul_mat_vec_id_q_f16(ggml_backend_vk_context * ctx, vk_conte
fusion_flags |= MAT_VEC_FUSION_FLAGS_SCALE1;
}
const uint32_t deltas_offset = ggml_vk_get_deltas_offset(src0);
// Loop over the batch dimension
for (uint32_t expert_i1 = 0; expert_i1 < nei1; ++expert_i1) {
const vk_mat_vec_id_push_constants pc = {
(uint32_t)ne00, (uint32_t)ne10, (uint32_t)ne10, (uint32_t)ne01,
(uint32_t)(ne00 * ne01), stride_batch_y, (uint32_t)(ne20 * ne21),
fusion_flags,
(uint32_t)nei0, (uint32_t)ne11, expert_i1, nbi1
(uint32_t)nei0, (uint32_t)ne11, expert_i1, nbi1,
deltas_offset,
};
ggml_vk_dispatch_pipeline(ctx, subctx, dmmv,
{
@@ -12354,7 +12500,7 @@ static void ggml_vk_test_matmul(ggml_backend_vk_context * ctx, size_t m, size_t
ctx, subctx, p, ggml_vk_subbuffer(ctx, d_X), ggml_vk_subbuffer(ctx, d_Y), ggml_vk_subbuffer(ctx, d_D), ggml_vk_subbuffer(ctx, ctx->prealloc_split_k),
m, n, k,
k, k, m, k*m, k*n, m*n,
split_k, batch, batch, batch, 1, 1, n
split_k, batch, batch, batch, 1, 1, n, 0
);
}
ggml_vk_ctx_end(subctx);
@@ -12831,7 +12977,7 @@ static void ggml_vk_test_dequant_matmul(ggml_backend_vk_context * ctx, size_t m,
ctx, subctx, p, { qx_buf, 0, qx_sz }, { qy_buf, 0, qy_sz }, { d_buf, 0, d_sz }, { ctx->prealloc_split_k, 0, ctx->prealloc_size_split_k },
m, n, k,
k, k, m, k*m, k*n, m*n,
split_k, batch, batch, batch, 1, 1, n
split_k, batch, batch, batch, 1, 1, n, 0
);
}
} else {
@@ -12840,7 +12986,7 @@ static void ggml_vk_test_dequant_matmul(ggml_backend_vk_context * ctx, size_t m,
ctx, subctx, p, { qx_buf, 0, qx_sz }, { y_buf, 0, y_sz }, { d_buf, 0, d_sz }, { ctx->prealloc_split_k, 0, ctx->prealloc_size_split_k },
m, n, k,
k, k, m, k*m, k*n, m*n,
split_k, batch, batch, batch, 1, 1, n
split_k, batch, batch, batch, 1, 1, n, 0
);
}
}
@@ -13796,6 +13942,11 @@ static void ggml_backend_vk_buffer_set_tensor(ggml_backend_buffer_t buffer, ggml
return;
}
if (ggml_vk_get_repack_info(tensor->type)) {
ggml_vk_repack_write(buf, tensor, offset, data, size);
return;
}
ggml_vk_buffer_write(buf, vk_tensor_offset(tensor) + tensor->view_offs + offset, data, size);
}
@@ -13810,6 +13961,11 @@ static void ggml_backend_vk_buffer_set_tensor_2d(ggml_backend_buffer_t buffer, g
return;
}
if (ggml_vk_get_repack_info(tensor->type)) {
ggml_vk_repack_write(buf, tensor, offset, data, size);
return;
}
ggml_vk_buffer_write_2d(buf, vk_tensor_offset(tensor) + tensor->view_offs + offset, data, stride_data, stride_tensor, size, n_copies);
}
@@ -13823,6 +13979,11 @@ static void ggml_backend_vk_buffer_get_tensor(ggml_backend_buffer_t buffer, cons
vk_buffer buf = buf_ctx->dev_buffer;
if (ggml_vk_get_repack_info(tensor->type)) {
ggml_vk_repack_read(buf, tensor, offset, data, size);
return;
}
ggml_vk_buffer_read(buf, vk_tensor_offset(tensor) + tensor->view_offs + offset, data, size);
}
@@ -13838,6 +13999,11 @@ static void ggml_backend_vk_buffer_get_tensor_2d(ggml_backend_buffer_t buffer, c
vk_buffer buf = buf_ctx->dev_buffer;
if (ggml_vk_get_repack_info(tensor->type)) {
ggml_vk_repack_read(buf, tensor, offset, data, size);
return;
}
ggml_vk_buffer_read_2d(buf, vk_tensor_offset(tensor) + tensor->view_offs + offset, data, stride_tensor, stride_data, size, n_copies);
}
@@ -13916,9 +14082,8 @@ static size_t ggml_backend_vk_buffer_type_get_max_size(ggml_backend_buffer_type_
}
static size_t ggml_backend_vk_buffer_type_get_alloc_size(ggml_backend_buffer_type_t buft, const ggml_tensor * tensor) {
return ggml_nbytes(tensor);
UNUSED(buft);
return ggml_vk_repack_size_tensor(tensor);
}
ggml_backend_buffer_type_t ggml_backend_vk_buffer_type(size_t dev_num) {
@@ -14043,6 +14208,12 @@ static void ggml_backend_vk_set_tensor_2d_async(ggml_backend_t backend, ggml_ten
}
ggml_backend_vk_buffer_context * buf_ctx = (ggml_backend_vk_buffer_context *)tensor->buffer->context;
vk_buffer buf = buf_ctx->dev_buffer;
if (ggml_vk_get_repack_info(tensor->type)) {
ggml_vk_repack_write(buf, tensor, offset, data, size);
return;
}
vk_context cpy_ctx;
@@ -14058,8 +14229,6 @@ static void ggml_backend_vk_set_tensor_2d_async(ggml_backend_t backend, ggml_ten
cpy_ctx = ggml_vk_get_compute_ctx(ctx);
}
vk_buffer buf = buf_ctx->dev_buffer;
auto dst_offset = vk_tensor_offset(tensor) + tensor->view_offs + offset;
bool ret = ggml_vk_buffer_write_2d_async(cpy_ctx, buf, dst_offset, data, stride_data, stride_tensor, size, n_copies);
@@ -14112,11 +14281,15 @@ static void ggml_backend_vk_get_tensor_2d_async(ggml_backend_t backend, const gg
}
ggml_backend_vk_buffer_context * buf_ctx = (ggml_backend_vk_buffer_context *)tensor->buffer->context;
vk_buffer buf = buf_ctx->dev_buffer;
if (ggml_vk_get_repack_info(tensor->type)) {
ggml_vk_repack_read(buf, tensor, offset, data, size);
return;
}
vk_context compute_ctx = ggml_vk_get_compute_ctx(ctx);
vk_buffer buf = buf_ctx->dev_buffer;
auto src_offset = vk_tensor_offset(tensor) + tensor->view_offs + offset;
bool ret = ggml_vk_buffer_read_2d_async(compute_ctx, buf, src_offset, data, stride_tensor, stride_data, size, n_copies);
@@ -16639,6 +16812,17 @@ static void ggml_vk_check_results_0(ggml_backend_vk_context * ctx, ggml_cgraph *
for (int i = 2; i < GGML_MAX_DIMS; i++) {
srci_clone->nb[i] = srci_clone->nb[i - 1]*srci_clone->ne[i - 1];
}
} else if (ggml_vk_get_repack_info(srci->type)) {
const auto * info = ggml_vk_get_repack_info(srci->type);
const size_t n_blocks = ggml_vk_get_num_blocks(srci);
const size_t quants_region = ggml_vk_repack_quants_region(info, n_blocks);
const size_t repacked_size = ggml_vk_repack_size(info, n_blocks);
void * data_repacked = ggml_vk_repack_scratch(repacked_size);
ggml_vk_buffer_read(buffer_gpu, offset, data_repacked, repacked_size);
ggml_vk_repack_unpack(info, n_blocks,
data_repacked, (uint8_t *)data_repacked + quants_region,
srci_clone->data);
memcpy(srci_clone->nb, srci->nb, sizeof(size_t) * GGML_MAX_DIMS);
} else {
if (offset + srci_size >= buffer_gpu->size) {
srci_size = buffer_gpu->size - offset;
@@ -23,6 +23,16 @@ vec2 dequantize(uint ib, uint iqs, uint a_offset) {
#endif
#if defined(DATA_A_Q4_0)
#if defined(A_TYPE_REPACKED)
vec2 dequantize(uint ib, uint iqs, uint a_offset) {
const uint vui = uint(data_a_quants[(a_offset + ib) * 16 + iqs]);
return (vec2(vui & 0xF, vui >> 4) - 8.0f);
}
vec4 dequantize4(uint ib, uint iqs, uint a_offset) {
const uint vui = uint(data_a_quants16[(a_offset + ib) * 8 + iqs/2]);
return (vec4(vui & 0xF, (vui >> 4) & 0xF, (vui >> 8) & 0xF, vui >> 12) - 8.0f);
}
#else
vec2 dequantize(uint ib, uint iqs, uint a_offset) {
const uint vui = uint(data_a[a_offset + ib].qs[iqs]);
return (vec2(vui & 0xF, vui >> 4) - 8.0f);
@@ -32,8 +42,19 @@ vec4 dequantize4(uint ib, uint iqs, uint a_offset) {
return (vec4(vui & 0xF, (vui >> 4) & 0xF, (vui >> 8) & 0xF, vui >> 12) - 8.0f);
}
#endif
#endif
#if defined(DATA_A_Q4_1)
#if defined(A_TYPE_REPACKED)
vec2 dequantize(uint ib, uint iqs, uint a_offset) {
const uint vui = uint(data_a_quants[(a_offset + ib) * 16 + iqs]);
return vec2(vui & 0xF, vui >> 4);
}
vec4 dequantize4(uint ib, uint iqs, uint a_offset) {
const uint vui = uint(data_a_quants16[(a_offset + ib) * 8 + iqs/2]);
return vec4(vui & 0xF, (vui >> 4) & 0xF, (vui >> 8) & 0xF, vui >> 12);
}
#else
vec2 dequantize(uint ib, uint iqs, uint a_offset) {
const uint vui = uint(data_a[a_offset + ib].qs[iqs]);
return vec2(vui & 0xF, vui >> 4);
@@ -43,6 +64,7 @@ vec4 dequantize4(uint ib, uint iqs, uint a_offset) {
return vec4(vui & 0xF, (vui >> 4) & 0xF, (vui >> 8) & 0xF, vui >> 12);
}
#endif
#endif
#if defined(DATA_A_Q5_0)
vec2 dequantize(uint ib, uint iqs, uint a_offset) {
@@ -77,6 +99,16 @@ vec4 dequantize4(uint ib, uint iqs, uint a_offset) {
#endif
#if defined(DATA_A_Q8_0)
#if defined(A_TYPE_REPACKED)
vec2 dequantize(uint ib, uint iqs, uint a_offset) {
const i8vec2 v = unpack8(int32_t(data_a_quants16[(a_offset + ib) * 16 + iqs/2])).xy;
return vec2(v);
}
vec4 dequantize4(uint ib, uint iqs, uint a_offset) {
const i8vec4 v = unpack8(int32_t(data_a_quants32[(a_offset + ib) * 8 + iqs/4]));
return vec4(v);
}
#else
vec2 dequantize(uint ib, uint iqs, uint a_offset) {
return vec2(int(data_a[a_offset + ib].qs[iqs]), int(data_a[a_offset + ib].qs[iqs + 1]));
}
@@ -86,6 +118,7 @@ vec4 dequantize4(uint ib, uint iqs, uint a_offset) {
return vec4(v0.x, v0.y, v1.x, v1.y);
}
#endif
#endif
#if defined(DATA_A_Q1_0)
vec2 dequantize(uint ib, uint iqs, uint a_offset) {
@@ -428,6 +461,16 @@ vec4 dequantize4(uint ib, uint iqs, uint a_offset) {
#endif
#if defined(DATA_A_IQ4_NL)
#if defined(A_TYPE_REPACKED)
vec2 dequantize(uint ib, uint iqs, uint a_offset) {
const uint vui = uint(data_a_quants[(a_offset + ib) * 16 + iqs]);
return vec2(kvalues_iq4nl[vui & 0xF], kvalues_iq4nl[vui >> 4]);
}
vec4 dequantize4(uint ib, uint iqs, uint a_offset) {
const uint vui = uint(data_a_quants16[(a_offset + ib) * 8 + iqs/2]);
return vec4(kvalues_iq4nl[vui & 0xF], kvalues_iq4nl[(vui >> 4) & 0xF], kvalues_iq4nl[(vui >> 8) & 0xF], kvalues_iq4nl[vui >> 12]);
}
#else
vec2 dequantize(uint ib, uint iqs, uint a_offset) {
const uint vui = uint(data_a[a_offset + ib].qs[iqs]);
return vec2(kvalues_iq4nl[vui & 0xF], kvalues_iq4nl[vui >> 4]);
@@ -437,8 +480,20 @@ vec4 dequantize4(uint ib, uint iqs, uint a_offset) {
return vec4(kvalues_iq4nl[vui & 0xF], kvalues_iq4nl[(vui >> 4) & 0xF], kvalues_iq4nl[(vui >> 8) & 0xF], kvalues_iq4nl[vui >> 12]);
}
#endif
#endif
#if defined(DATA_A_MXFP4)
#if defined(A_TYPE_REPACKED)
vec2 dequantize(uint ib, uint iqs, uint a_offset) {
const uint vui = uint(data_a_quants[(a_offset + ib) * 16 + iqs]);
return vec2(kvalues_mxfp4[vui & 0xF], kvalues_mxfp4[vui >> 4]) * 0.5;
}
vec4 dequantize4(uint ib, uint iqs, uint a_offset) {
const uint vui = uint(data_a_quants16[(a_offset + ib) * 8 + iqs/2]);
return vec4(kvalues_mxfp4[vui & 0xF], kvalues_mxfp4[(vui >> 4) & 0xF],
kvalues_mxfp4[(vui >> 8) & 0xF], kvalues_mxfp4[vui >> 12]) * 0.5;
}
#else
vec2 dequantize(uint ib, uint iqs, uint a_offset) {
const uint vui = uint(data_a[a_offset + ib].qs[iqs]);
return vec2(kvalues_mxfp4[vui & 0xF], kvalues_mxfp4[vui >> 4]) * 0.5;
@@ -449,6 +504,7 @@ vec4 dequantize4(uint ib, uint iqs, uint a_offset) {
return vec4(v0.x, v0.y, v1.x, v1.y);
}
#endif
#endif
#if defined(DATA_A_NVFP4)
vec2 dequantize(uint ib, uint iqs, uint a_offset) {
@@ -486,7 +542,11 @@ vec2 get_dm(uint ib, uint a_offset) {
#if defined(DATA_A_Q4_0) || defined(DATA_A_Q5_0) || defined(DATA_A_Q8_0) || defined(DATA_A_IQ1_S) || defined(DATA_A_IQ2_XXS) || defined(DATA_A_IQ2_XS) || defined(DATA_A_IQ2_S) || defined(DATA_A_IQ3_XXS) || defined(DATA_A_IQ3_S) || defined(DATA_A_IQ4_XS) || defined(DATA_A_IQ4_NL)
vec2 get_dm(uint ib, uint a_offset) {
#if (defined(DATA_A_Q4_0) || defined(DATA_A_Q8_0) || defined(DATA_A_IQ4_NL)) && defined(A_TYPE_REPACKED)
return vec2(float(data_a_deltas[a_offset + p.deltas_offset + ib]), 0);
#else
return vec2(float(data_a[a_offset + ib].d), 0);
#endif
}
#endif
@@ -499,7 +559,11 @@ vec2 get_dm(uint ib, uint a_offset) {
#if defined(DATA_A_MXFP4)
vec2 get_dm(uint ib, uint a_offset) {
#if defined(A_TYPE_REPACKED)
return vec2(e8m0_to_fp32(uint8_t(data_a_quants[p.deltas_offset + a_offset + ib])), 0);
#else
return vec2(e8m0_to_fp32(data_a[a_offset + ib].e), 0);
#endif
}
#endif
@@ -511,8 +575,13 @@ vec2 get_dm(uint ib, uint a_offset) {
#if defined(DATA_A_Q4_1) || defined(DATA_A_Q5_1)
vec2 get_dm(uint ib, uint a_offset) {
#if defined(DATA_A_Q4_1) && defined(A_TYPE_REPACKED)
return vec2(float(data_a_deltas[p.deltas_offset + (a_offset + ib) * 2]),
float(data_a_deltas[p.deltas_offset + (a_offset + ib) * 2 + 1]));
#else
const vec2 dm = vec2(data_a_packed32[a_offset + ib].dm);
return dm;
#endif
}
#endif
@@ -25,36 +25,65 @@ float16_t dequantFuncQ1_0(const in decodeBufQ1_0 bl, const in uint blockCoords[2
return bit != 0u ? d : -d;
}
#ifdef A_TYPE_REPACKED
layout(buffer_reference, std430, buffer_reference_align = 16) buffer decodeBufQ4_0 {
uint32_t qs[4];
};
#else
layout(buffer_reference, std430, buffer_reference_align = 2) buffer decodeBufQ4_0 {
block_q4_0_packed16 block;
};
#endif
float16_t dequantFuncQ4_0(const in decodeBufQ4_0 bl, const in uint blockCoords[2], const in uint coordInBlock[2])
{
const float16_t d = bl.block.d;
const uint idx = coordInBlock[1];
#ifdef A_TYPE_REPACKED
const uint ib = pos_a + blockCoords[0] * (p.stride_a / QUANT_K) + blockCoords[1];
const float16_t d = data_a_deltas[p.deltas_offset + ib];
uint32_t qs = bl.qs[(idx & 0xC) >> 2];
const uint shift = (idx & 0x10) >> 2;
qs >>= ((idx & 3) * 8 + shift);
#else
const float16_t d = bl.block.d;
uint32_t qs = uint32_t(bl.block.qs[(idx & 0xE) >> 1]);
const uint shift = (idx & 0x10) >> 2;
qs >>= shift;
qs &= 0x0F0F;
qs = unpack8(qs)[idx & 1];
#endif
qs &= 0xF;
float16_t ret = (float16_t(qs) - float16_t(8)) * d;
return ret;
}
#ifdef A_TYPE_REPACKED
layout(buffer_reference, std430, buffer_reference_align = 16) buffer decodeBufQ4_1 {
uint32_t qs[4];
};
#else
layout(buffer_reference, std430, buffer_reference_align = 4) buffer decodeBufQ4_1 {
block_q4_1 block;
};
#endif
float16_t dequantFuncQ4_1(const in decodeBufQ4_1 bl, const in uint blockCoords[2], const in uint coordInBlock[2])
{
const float16_t d = bl.block.d;
const float16_t m = bl.block.m;
const uint idx = coordInBlock[1];
const uint iqs = idx & 0xF;
const uint shift = (idx & 0x10) >> 2;
#ifdef A_TYPE_REPACKED
const uint ib = pos_a + blockCoords[0] * (p.stride_a / QUANT_K) + blockCoords[1];
const float16_t d = data_a_deltas[p.deltas_offset + ib * 2];
const float16_t m = data_a_deltas[p.deltas_offset + ib * 2 + 1];
uint32_t qs = bl.qs[(idx & 0xC) >> 2];
qs >>= ((iqs & 3) * 8 + shift);
#else
const float16_t d = bl.block.d;
const float16_t m = bl.block.m;
uint32_t qs = bl.block.qs[iqs];
qs >>= shift;
#endif
qs &= 0xF;
float16_t ret = float16_t(qs) * d + m;
return ret;
@@ -105,18 +134,28 @@ float16_t dequantFuncQ5_1(const in decodeBufQ5_1 bl, const in uint blockCoords[2
return ret;
}
#ifdef A_TYPE_REPACKED
layout(buffer_reference, std430, buffer_reference_align = 16) buffer decodeBufQ8_0 {
int32_t qs[8];
};
#else
layout(buffer_reference, std430, buffer_reference_align = 2) buffer decodeBufQ8_0 {
block_q8_0_packed16 block;
};
#endif
float16_t dequantFuncQ8_0(const in decodeBufQ8_0 bl, const in uint blockCoords[2], const in uint coordInBlock[2])
{
const float16_t d = bl.block.d;
const uint idx = coordInBlock[1];
const uint iqs = idx;
// Load 16b and select the byte for this element
#ifdef A_TYPE_REPACKED
const uint ib = pos_a + blockCoords[0] * (p.stride_a / QUANT_K) + blockCoords[1];
const float16_t d = data_a_deltas[p.deltas_offset + ib];
int32_t qs = unpack8(bl.qs[(iqs & 0x1C) >> 2])[iqs & 3];
#else
const float16_t d = bl.block.d;
int32_t qs = unpack8(bl.block.qs[(iqs & 0x1E) >> 1])[iqs & 1];
#endif
float16_t ret = float16_t(qs) * d;
return ret;
}
@@ -660,18 +699,32 @@ float16_t dequantFuncIQ4_XS(const in decodeBufIQ4_XS bl, const in uint blockCoor
#endif
#if defined(DATA_A_IQ4_NL)
#ifdef A_TYPE_REPACKED
layout(buffer_reference, std430, buffer_reference_align = 16) buffer decodeBufIQ4_NL {
uint32_t qs[4];
};
#else
layout(buffer_reference, std430, buffer_reference_align = 2) buffer decodeBufIQ4_NL {
block_iq4_nl block;
};
#endif
float16_t dequantFuncIQ4_NL(const in decodeBufIQ4_NL bl, const in uint blockCoords[2], const in uint coordInBlock[2])
{
const float16_t d = bl.block.d;
const uint idx = coordInBlock[1];
#ifdef A_TYPE_REPACKED
const uint ib = pos_a + blockCoords[0] * (p.stride_a / QUANT_K) + blockCoords[1];
const float16_t d = data_a_deltas[p.deltas_offset + ib];
uint32_t qs = bl.qs[(idx & 0xC) >> 2];
const uint shift = (idx & 0x10) >> 2;
qs >>= ((idx & 3) * 8 + shift);
#else
const float16_t d = bl.block.d;
const uint iqs = idx & 0xF;
const uint shift = (idx & 0x10) >> 2;
uint32_t qs = bl.block.qs[iqs];
qs >>= shift;
#endif
qs &= 0xF;
float16_t ret = float16_t(kvalues_iq4nl[qs]) * d;
return ret;
@@ -679,18 +732,31 @@ float16_t dequantFuncIQ4_NL(const in decodeBufIQ4_NL bl, const in uint blockCoor
#endif
#if defined(DATA_A_MXFP4)
#ifdef A_TYPE_REPACKED
layout(buffer_reference, std430, buffer_reference_align = 16) buffer decodeBufMXFP4 {
uint32_t qs[4];
};
#else
layout(buffer_reference, std430, buffer_reference_align = 2) buffer decodeBufMXFP4 {
block_mxfp4 block;
};
#endif
float16_t dequantFuncMXFP4(const in decodeBufMXFP4 bl, const in uint blockCoords[2], const in uint coordInBlock[2])
{
const float d = e8m0_to_fp32(bl.block.e);
const uint idx = coordInBlock[1];
const uint iqs = idx & 0xF;
const uint shift = (idx & 0x10) >> 2;
#ifdef A_TYPE_REPACKED
const uint ib = pos_a + blockCoords[0] * (p.stride_a / QUANT_K) + blockCoords[1];
const float d = e8m0_to_fp32(data_a_scales[p.deltas_offset + ib]);
uint32_t qs = bl.qs[(iqs & 0xC) >> 2];
qs >>= ((iqs & 3) * 8 + shift);
#else
const float d = e8m0_to_fp32(bl.block.e);
uint32_t qs = bl.block.qs[iqs];
qs >>= shift;
#endif
qs &= 0xF;
float16_t ret = float16_t(kvalues_mxfp4[qs] * d * 0.5);
return ret;
@@ -38,6 +38,8 @@ layout (push_constant) uniform parameter
uint broadcast2;
uint broadcast3;
#endif
uint deltas_offset;
} p;
#ifdef MUL_MAT_ID
@@ -15,6 +15,12 @@ layout (binding = 0) readonly buffer A_PACKED16 {A_TYPE_PACKED16 data_a_packed16
#if defined(A_TYPE_PACKED32)
layout (binding = 0) readonly buffer A_PACKED32 {A_TYPE_PACKED32 data_a_packed32[];};
#endif
#if defined(A_TYPE_REPACKED)
layout (binding = 0) readonly buffer A_QUANTS {uint8_t data_a_quants[];};
layout (binding = 0) readonly buffer A_QUANTS16 {uint16_t data_a_quants16[];};
layout (binding = 0) readonly buffer A_QUANTS32 {uint32_t data_a_quants32[];};
layout (binding = 0) readonly buffer A_DELTAS {float16_t data_a_deltas[];};
#endif
layout (binding = 1) readonly buffer B {B_TYPE data_b[];};
#ifdef B_TYPEV2
@@ -6,19 +6,32 @@
#if defined(DATA_A_Q4_0) || defined(DATA_A_Q5_0) || defined(DATA_A_Q8_0) || defined(DATA_A_IQ1_S) || defined(DATA_A_IQ2_XXS) || defined(DATA_A_IQ2_XS) || defined(DATA_A_IQ2_S) || defined(DATA_A_IQ3_XXS) || defined(DATA_A_IQ3_S) || defined(DATA_A_IQ4_XS) || defined(DATA_A_IQ4_NL)
FLOAT_TYPE get_dm(uint ib) {
#if (defined(DATA_A_Q4_0) || defined(DATA_A_Q8_0) || defined(DATA_A_IQ4_NL)) && defined(A_TYPE_REPACKED)
return FLOAT_TYPE(data_a_deltas[p.deltas_offset + ib]);
#else
return FLOAT_TYPE(data_a[ib].d);
#endif
}
#endif
#if defined(DATA_A_Q4_1) || defined(DATA_A_Q5_1)
FLOAT_TYPEV2 get_dm(uint ib) {
#if defined(DATA_A_Q4_1) && defined(A_TYPE_REPACKED)
return FLOAT_TYPEV2(data_a_deltas[p.deltas_offset + ib * 2],
data_a_deltas[p.deltas_offset + ib * 2 + 1]);
#else
return FLOAT_TYPEV2(data_a_packed32[ib].dm);
#endif
}
#endif
#if defined(DATA_A_MXFP4)
FLOAT_TYPE get_dm(uint ib) {
#if defined(A_TYPE_REPACKED)
return FLOAT_TYPE(e8m0_to_fp32(uint8_t(data_a_quants[p.deltas_offset + ib])));
#else
return FLOAT_TYPE(e8m0_to_fp32(data_a[ib].e));
#endif
}
#endif
@@ -33,9 +46,13 @@ FLOAT_TYPEV2 get_dm(uint ib) {
#if defined(DATA_A_Q4_0)
// 2-byte loads for Q4_0 blocks (18 bytes)
i32vec2 repack(uint ib, uint iqs) {
#if defined(DATA_A_Q4_0) && defined(A_TYPE_REPACKED)
const uint32_t vui = data_a_quants32[ib * 4 + iqs];
#else
const u16vec2 quants = u16vec2(data_a_packed16[ib].qs[iqs * 2 ],
data_a_packed16[ib].qs[iqs * 2 + 1]);
const uint32_t vui = pack32(quants);
#endif
return i32vec2( vui & 0x0F0F0F0F,
(vui >> 4) & 0x0F0F0F0F);
}
@@ -48,7 +65,11 @@ FLOAT_TYPE mul_q8_1(const int32_t q_sum, const float da, const vec2 dsb, const i
#if defined(DATA_A_Q4_1)
// 4-byte loads for Q4_1 blocks (20 bytes)
i32vec2 repack(uint ib, uint iqs) {
#if defined(A_TYPE_REPACKED)
const uint32_t vui = data_a_quants32[ib * 4 + iqs];
#else
const uint32_t vui = data_a_packed32[ib].qs[iqs];
#endif
return i32vec2( vui & 0x0F0F0F0F,
(vui >> 4) & 0x0F0F0F0F);
}
@@ -103,8 +124,12 @@ FLOAT_TYPE mul_q8_1(const int32_t q_sum, const vec2 dma, const vec2 dsb, const i
#if defined(DATA_A_Q8_0)
// 2-byte loads for Q8_0 blocks (34 bytes)
int32_t repack(uint ib, uint iqs) {
#if defined(A_TYPE_REPACKED)
return int32_t(data_a_quants32[ib * 8 + iqs]);
#else
return pack32(i16vec2(data_a_packed16[ib].qs[iqs * 2 ],
data_a_packed16[ib].qs[iqs * 2 + 1]));
#endif
}
FLOAT_TYPE mul_q8_1(const int32_t q_sum, const float da, const vec2 dsb, const int32_t sum_divisor) {
@@ -115,10 +140,14 @@ FLOAT_TYPE mul_q8_1(const int32_t q_sum, const float da, const vec2 dsb, const i
#if defined(DATA_A_MXFP4)
// 1-byte loads for mxfp4 blocks (17 bytes)
i32vec2 repack(uint ib, uint iqs) {
#if defined(A_TYPE_REPACKED)
const uint32_t qs = data_a_quants32[ib * 4 + iqs];
#else
const uint32_t qs = pack32(u8vec4(data_a[ib].qs[iqs * 4 ],
data_a[ib].qs[iqs * 4 + 1],
data_a[ib].qs[iqs * 4 + 2],
data_a[ib].qs[iqs * 4 + 3]));
#endif
const u8vec4 i_a0 = unpack8( qs & 0x0F0F0F0F);
const u8vec4 i_a1 = unpack8((qs >> 4) & 0x0F0F0F0F);
@@ -62,6 +62,12 @@ layout (binding = 0) readonly buffer A_PACKED16 {A_TYPE_PACKED16 data_a_packed16
#if defined(A_TYPE_PACKED32)
layout (binding = 0) readonly buffer A_PACKED32 {A_TYPE_PACKED32 data_a_packed32[];};
#endif
#if defined(A_TYPE_REPACKED)
layout (binding = 0) readonly buffer A_QUANTS {uint8_t data_a_quants[];};
layout (binding = 0) readonly buffer A_QUANTS16 {uint16_t data_a_quants16[];};
layout (binding = 0) readonly buffer A_QUANTS32 {uint32_t data_a_quants32[];};
layout (binding = 0) readonly buffer A_DELTAS {float16_t data_a_deltas[];};
#endif
layout (binding = 1) readonly buffer B {B_TYPE data_b[];};
layout (binding = 2) writeonly buffer D {D_TYPE data_d[];};
@@ -98,6 +104,9 @@ layout (push_constant) uniform parameter
uint broadcast2;
uint broadcast3;
#endif
uint padded_N;
uint deltas_offset;
} p;
layout (constant_id = 0) const uint BLOCK_SIZE = 64;
@@ -63,13 +63,27 @@ layout (push_constant) uniform parameter
#endif
// N dimension for the B matrix can be >= p.N
uint padded_N;
uint deltas_offset;
} p;
#ifdef A_TYPE_REPACKED
#if defined(DATA_A_Q8_0)
struct block_repacked_quants { uint16_t qs[16]; };
#else
struct block_repacked_quants { uint16_t qs[8]; };
#endif
layout (binding = 0) readonly buffer A {block_repacked_quants data_a[];};
layout (binding = 0) readonly buffer A_DELTAS {float16_t data_a_deltas[];};
layout (binding = 0) readonly buffer A_SCALES {uint8_t data_a_scales[];};
#else
layout (binding = 0) readonly buffer A {A_TYPE data_a[];};
#endif
layout (binding = 1) readonly buffer B {B_TYPE data_b[];};
layout (binding = 2) writeonly buffer D {D_TYPE data_d[];};
uint pos_a;
#if QUANT_K > 1
#define DECODEFUNCA , dequantFuncA
@@ -254,10 +268,10 @@ void main() {
#endif
#ifdef MUL_MAT_ID
uint pos_a = expert_idx * (p.batch_stride_a / QUANT_K);
pos_a = expert_idx * (p.batch_stride_a / QUANT_K);
uint pos_b = 0;
#else
uint pos_a = batch_idx_a * (p.batch_stride_a / QUANT_K);
pos_a = batch_idx_a * (p.batch_stride_a / QUANT_K);
uint pos_b = batch_idx * p.batch_stride_b;
uint pos_d = batch_idx * p.batch_stride_d + ik * p.batch_stride_d * p.num_batches;
#endif
@@ -52,8 +52,13 @@ void load_a_to_shmem(const uint pos_a, const uint row, const uint col, const uin
const uint ib = idx / 4;
const uint iqs = idx & 0x03;
#if defined(A_TYPE_REPACKED)
const float d = float(data_a_deltas[p.deltas_offset + ib]);
const uint vui = data_a_quants32[ib * 4 + iqs];
#else
const float d = float(data_a_packed16[ib].d);
const uint vui = uint(data_a_packed16[ib].qs[2*iqs]) | (uint(data_a_packed16[ib].qs[2*iqs + 1]) << 16);
#endif
const vec4 v0 = (vec4(unpack8(vui & 0x0F0F0F0F)) - 8.0f) * d;
const vec4 v1 = (vec4(unpack8((vui >> 4) & 0x0F0F0F0F)) - 8.0f) * d;
@@ -68,8 +73,14 @@ void load_a_to_shmem(const uint pos_a, const uint row, const uint col, const uin
const uint ib = idx / 4;
const uint iqs = idx & 0x03;
#if defined(A_TYPE_REPACKED)
const vec2 dm = vec2(data_a_deltas[p.deltas_offset + ib * 2],
data_a_deltas[p.deltas_offset + ib * 2 + 1]);
const uint vui = data_a_quants32[ib * 4 + iqs];
#else
const vec2 dm = vec2(data_a_packed32[ib].dm);
const uint vui = data_a_packed32[ib].qs[iqs];
#endif
const vec4 v0 = vec4(unpack8(vui & 0x0F0F0F0F)) * dm.x + dm.y;
const vec4 v1 = vec4(unpack8((vui >> 4) & 0x0F0F0F0F)) * dm.x + dm.y;
@@ -123,10 +134,15 @@ void load_a_to_shmem(const uint pos_a, const uint row, const uint col, const uin
const uint ib = idx / 8;
const uint iqs = idx & 0x07;
#if defined(A_TYPE_REPACKED)
const float d = float(data_a_deltas[p.deltas_offset + ib]);
const vec4 v = vec4(unpack8(int32_t(data_a_quants32[ib * 8 + iqs]))) * d;
#else
const float d = float(data_a_packed16[ib].d);
const i8vec2 v0 = unpack8(int32_t(data_a_packed16[ib].qs[2*iqs])).xy; // vec4 used due to #12147
const i8vec2 v1 = unpack8(int32_t(data_a_packed16[ib].qs[2*iqs + 1])).xy;
const vec4 v = vec4(v0.x, v0.y, v1.x, v1.y) * d;
#endif
buf_a[buf_idx ] = FLOAT_TYPEV2(v.xy);
buf_a[buf_idx + 1] = FLOAT_TYPEV2(v.zw);
@@ -481,8 +497,13 @@ void load_a_to_shmem(const uint pos_a, const uint row, const uint col, const uin
const uint ib = idx / 8;
const uint iqs = idx & 0x07;
#if defined(A_TYPE_REPACKED)
const FLOAT_TYPE d = FLOAT_TYPE(data_a_deltas[p.deltas_offset + ib]);
const uint vui = uint(data_a_quants16[ib * 8 + iqs]);
#else
const FLOAT_TYPE d = FLOAT_TYPE(data_a_packed16[ib].d);
const uint vui = uint(data_a_packed16[ib].qs[iqs]);
#endif
buf_a[buf_idx ] = d * FLOAT_TYPEV2(kvalues_iq4nl[vui & 0xF],
kvalues_iq4nl[bitfieldExtract(vui, 8, 4)]);
@@ -495,9 +516,16 @@ void load_a_to_shmem(const uint pos_a, const uint row, const uint col, const uin
const uint ib = idx / 8;
const uint iqs = (idx & 0x07) * 2;
#if defined(A_TYPE_REPACKED)
const float d = e8m0_to_fp32(uint8_t(data_a_quants[p.deltas_offset + ib])) * 0.5;
const uint vui16 = uint(data_a_quants16[ib * 8 + iqs/2]);
const uint vui = vui16 & 0xFF;
const uint vui2 = vui16 >> 8;
#else
const float d = e8m0_to_fp32(data_a[ib].e) * 0.5;
const uint vui = uint(data_a[ib].qs[iqs]);
const uint vui2 = uint(data_a[ib].qs[iqs+1]);
#endif
buf_a[buf_idx ] = FLOAT_TYPEV2(kvalues_mxfp4[vui & 0xF] * d,
kvalues_mxfp4[vui2 & 0xF] * d);
@@ -30,6 +30,13 @@ layout (binding = 0) readonly buffer A_PACKED16 {A_TYPE_PACKED16 data_a_packed16
#if defined(A_TYPE_PACKED32)
layout (binding = 0) readonly buffer A_PACKED32 {A_TYPE_PACKED32 data_a_packed32[];};
#endif
#if defined(A_TYPE_REPACKED)
layout (binding = 0) readonly buffer A_QUANTS {uint8_t data_a_quants[];};
layout (binding = 0) readonly buffer A_QUANTS16 {uint16_t data_a_quants16[];};
layout (binding = 0) readonly buffer A_QUANTS32 {uint32_t data_a_quants32[];};
layout (binding = 0) readonly buffer A_DELTAS {float16_t data_a_deltas[];};
#endif
layout (binding = 1) readonly buffer B {block_q8_1_x4_packed128 data_b[];};
layout (binding = 2) writeonly buffer D {D_TYPE data_d[];};
@@ -65,6 +72,9 @@ layout (push_constant) uniform parameter
uint broadcast2;
uint broadcast3;
#endif
uint padded_N;
uint deltas_offset;
} p;
layout (constant_id = 0) const uint BLOCK_SIZE = 64;
@@ -11,19 +11,36 @@
// 4-byte loads for Q4_1 blocks (20 bytes)
void block_a_to_shmem(const uint buf_ib, const uint ib, const uint iqs) {
#ifdef DATA_A_Q4_0
#if defined(A_TYPE_REPACKED)
buf_a[buf_ib].qs[iqs] = data_a_quants32[ib * 4 + iqs];
#else
buf_a[buf_ib].qs[iqs] = pack32(u16vec2(data_a_packed16[ib].qs[iqs * 2],
data_a_packed16[ib].qs[iqs * 2 + 1]));
#endif
if (iqs == 0) {
#if defined(A_TYPE_REPACKED)
buf_a[buf_ib].dm = FLOAT_TYPE(data_a_deltas[p.deltas_offset + ib]);
#else
buf_a[buf_ib].dm = FLOAT_TYPE(data_a_packed16[ib].d);
#endif
}
#else // DATA_A_Q4_1
#if defined(A_TYPE_REPACKED)
buf_a[buf_ib].qs[iqs] = data_a_quants32[ib * 4 + iqs];
if (iqs == 0) {
buf_a[buf_ib].dm = FLOAT_TYPEV2(data_a_deltas[p.deltas_offset + ib * 2],
data_a_deltas[p.deltas_offset + ib * 2 + 1]);
}
#else
buf_a[buf_ib].qs[iqs] = data_a_packed32[ib].qs[iqs];
if (iqs == 0) {
buf_a[buf_ib].dm = FLOAT_TYPEV2(data_a_packed32[ib].dm);
}
#endif
#endif
}
void block_a_to_registers(const uint reg_ib, const uint buf_ib) {
@@ -115,12 +132,20 @@ ACC_TYPE mmq_dot_product(const uint ib_a) {
#if defined(DATA_A_Q8_0)
// 2-byte loads for Q8_0 blocks (34 bytes)
void block_a_to_shmem(const uint buf_ib, const uint ib, const uint iqs) {
#if defined(A_TYPE_REPACKED)
buf_a[buf_ib].qs[iqs] = int32_t(data_a_quants32[ib * 8 + iqs]);
if (iqs == 0) {
buf_a[buf_ib].dm = FLOAT_TYPE(data_a_deltas[p.deltas_offset + ib]);
}
#else
buf_a[buf_ib].qs[iqs] = pack32(i16vec2(data_a_packed16[ib].qs[iqs * 2],
data_a_packed16[ib].qs[iqs * 2 + 1]));
if (iqs == 0) {
buf_a[buf_ib].dm = FLOAT_TYPE(data_a_packed16[ib].d);
}
#endif
}
void block_a_to_registers(const uint reg_ib, const uint buf_ib) {
@@ -147,10 +172,14 @@ ACC_TYPE mmq_dot_product(const uint ib_a) {
#if defined(DATA_A_MXFP4)
// 1-byte loads for mxfp4 blocks (17 bytes)
void block_a_to_shmem(const uint buf_ib, const uint ib, const uint iqs) {
#if defined(A_TYPE_REPACKED)
const uint32_t qs = data_a_quants32[ib * 4 + iqs];
#else
const uint32_t qs = pack32(u8vec4(data_a[ib].qs[iqs * 4 ],
data_a[ib].qs[iqs * 4 + 1],
data_a[ib].qs[iqs * 4 + 2],
data_a[ib].qs[iqs * 4 + 3]));
#endif
const u8vec4 i_a0 = unpack8( qs & 0x0F0F0F0F);
const u8vec4 i_a1 = unpack8((qs >> 4) & 0x0F0F0F0F);
@@ -159,7 +188,11 @@ void block_a_to_shmem(const uint buf_ib, const uint ib, const uint iqs) {
buf_a[buf_ib].qs[iqs + 4] = pack32(i8vec4(kvalues_mxfp4[i_a1.x], kvalues_mxfp4[i_a1.y], kvalues_mxfp4[i_a1.z], kvalues_mxfp4[i_a1.w]));
if (iqs == 0) {
#if defined(A_TYPE_REPACKED)
buf_a[buf_ib].d = FLOAT_TYPE(e8m0_to_fp32(uint8_t(data_a_quants[p.deltas_offset + ib])) * 0.5);
#else
buf_a[buf_ib].d = FLOAT_TYPE(e8m0_to_fp32(data_a[ib].e) * 0.5);
#endif
}
}
@@ -564,6 +564,11 @@ void matmul_shaders(bool fp16, MatMulIdType matmul_id_type, bool coopmat, bool c
continue;
}
std::map<std::string, std::string> mm_base_dict = base_dict;
if (tname == "q4_0" || tname == "q4_1" || tname == "q8_0" || tname == "iq4_nl" || tname == "mxfp4") {
mm_base_dict["A_TYPE_REPACKED"] = "1";
}
std::string data_a_key = "DATA_A_" + to_uppercase(tname);
// For unaligned, load one at a time for f32/f16, or two at a time for quants
std::string load_vec_a_unaligned = (coopmat2 || tname == "f32" || tname == "f16" || tname == "bf16") ? "1" : load_vec_quant;
@@ -579,19 +584,19 @@ void matmul_shaders(bool fp16, MatMulIdType matmul_id_type, bool coopmat, bool c
// don't generate f32 variants for coopmat2
if (!coopmat2) {
string_to_spv(shader_name + "_" + tname + "_f32", source_name, merge_maps(merge_maps(base_dict, float_type_dict), {{data_a_key, "1"}, {"LOAD_VEC_A", load_vec_a_unaligned}, {"B_TYPE", "float"}, {"D_TYPE", "float"}}), fp16, coopmat, coopmat2, f16acc);
string_to_spv(shader_name + "_" + tname + "_f32_aligned", source_name, merge_maps(merge_maps(base_dict, float_type_dict), {{data_a_key, "1"}, {"LOAD_VEC_A", load_vec_a}, {"LOAD_VEC_B", load_vec}, {"B_TYPE", aligned_b_type_f32}, {"D_TYPE", "float"}, {"ALIGNED", "1"}}), fp16, coopmat, coopmat2, f16acc);
string_to_spv(shader_name + "_" + tname + "_f32", source_name, merge_maps(merge_maps(mm_base_dict, float_type_dict), {{data_a_key, "1"}, {"LOAD_VEC_A", load_vec_a_unaligned}, {"B_TYPE", "float"}, {"D_TYPE", "float"}}), fp16, coopmat, coopmat2, f16acc);
string_to_spv(shader_name + "_" + tname + "_f32_aligned", source_name, merge_maps(merge_maps(mm_base_dict, float_type_dict), {{data_a_key, "1"}, {"LOAD_VEC_A", load_vec_a}, {"LOAD_VEC_B", load_vec}, {"B_TYPE", aligned_b_type_f32}, {"D_TYPE", "float"}, {"ALIGNED", "1"}}), fp16, coopmat, coopmat2, f16acc);
}
if (tname != "f16" && tname != "f32") {
string_to_spv(shader_name + "_" + tname + "_f16", source_name, merge_maps(merge_maps(base_dict, float_type_dict), {{data_a_key, "1"}, {"LOAD_VEC_A", load_vec_a_unaligned}, {"B_TYPE", "float16_t"}, {"D_TYPE", "float"}}), fp16, coopmat, coopmat2, f16acc);
string_to_spv(shader_name + "_" + tname + "_f16_aligned", source_name, merge_maps(merge_maps(base_dict, float_type_dict), {{data_a_key, "1"}, {"LOAD_VEC_A", load_vec_a}, {"LOAD_VEC_B", load_vec}, {"B_TYPE", aligned_b_type_f16}, {"D_TYPE", "float"}, {"ALIGNED", "1"}}), fp16, coopmat, coopmat2, f16acc);
string_to_spv(shader_name + "_" + tname + "_f16", source_name, merge_maps(merge_maps(mm_base_dict, float_type_dict), {{data_a_key, "1"}, {"LOAD_VEC_A", load_vec_a_unaligned}, {"B_TYPE", "float16_t"}, {"D_TYPE", "float"}}), fp16, coopmat, coopmat2, f16acc);
string_to_spv(shader_name + "_" + tname + "_f16_aligned", source_name, merge_maps(merge_maps(mm_base_dict, float_type_dict), {{data_a_key, "1"}, {"LOAD_VEC_A", load_vec_a}, {"LOAD_VEC_B", load_vec}, {"B_TYPE", aligned_b_type_f16}, {"D_TYPE", "float"}, {"ALIGNED", "1"}}), fp16, coopmat, coopmat2, f16acc);
}
#if defined(GGML_VULKAN_INTEGER_DOT_GLSLC_SUPPORT)
// Integer dot mmq performs better with f32 accumulators
if (!f16acc && !coopmat && !coopmat2 && (is_legacy_quant(tname) || is_k_quant(tname) || tname == "mxfp4")) {
string_to_spv(shader_name + "_" + tname + "_q8_1", "mul_mmq.comp", merge_maps(merge_maps(base_dict, float_type_dict), {{data_a_key, "1"}, {"D_TYPE", "float"},}), fp16, coopmat, coopmat2, f16acc);
string_to_spv(shader_name + "_" + tname + "_q8_1", "mul_mmq.comp", merge_maps(merge_maps(mm_base_dict, float_type_dict), {{data_a_key, "1"}, {"D_TYPE", "float"},}), fp16, coopmat, coopmat2, f16acc);
}
#endif
}
@@ -665,33 +670,38 @@ void process_shaders() {
std::map<std::string, std::string> base_dict = {{"FLOAT_TYPE", "float"}, {"FLOAT_TYPEV2", "vec2"}};
for (const auto& tname : type_names) {
std::map<std::string, std::string> mmv_base_dict = base_dict;
if (tname == "q4_0" || tname == "q4_1" || tname == "q8_0" || tname == "iq4_nl" || tname == "mxfp4") {
mmv_base_dict["A_TYPE_REPACKED"] = "1";
}
// mul mat vec
std::string data_a_key = "DATA_A_" + to_uppercase(tname);
std::string shader = (string_ends_with(tname, "_k") || string_starts_with(tname, "iq1_") || string_starts_with(tname, "iq2_") || string_starts_with(tname, "iq3_")) ? "mul_mat_vec_" + tname + ".comp" : "mul_mat_vec.comp";
string_to_spv("mul_mat_vec_" + tname + "_f32_f32", shader, merge_maps(base_dict, {{data_a_key, "1"}, {"B_TYPE", "float"}, {"B_TYPEV2", "vec2"}, {"B_TYPEV4", "vec4"}, {"D_TYPE", "float"}}));
string_to_spv("mul_mat_vec_" + tname + "_f16_f32", shader, merge_maps(base_dict, {{data_a_key, "1"}, {"B_TYPE", "float16_t"}, {"B_TYPEV2", "f16vec2"}, {"B_TYPEV4", "f16vec4"}, {"D_TYPE", "float"}}));
string_to_spv("mul_mat_vec_" + tname + "_f32_f32", shader, merge_maps(mmv_base_dict, {{data_a_key, "1"}, {"B_TYPE", "float"}, {"B_TYPEV2", "vec2"}, {"B_TYPEV4", "vec4"}, {"D_TYPE", "float"}}));
string_to_spv("mul_mat_vec_" + tname + "_f16_f32", shader, merge_maps(mmv_base_dict, {{data_a_key, "1"}, {"B_TYPE", "float16_t"}, {"B_TYPEV2", "f16vec2"}, {"B_TYPEV4", "f16vec4"}, {"D_TYPE", "float"}}));
string_to_spv("mul_mat_vec_" + tname + "_f32_f32_subgroup", shader, merge_maps(base_dict, {{data_a_key, "1"}, {"B_TYPE", "float"}, {"B_TYPEV2", "vec2"}, {"B_TYPEV4", "vec4"}, {"D_TYPE", "float"}, {"USE_SUBGROUP_ADD", "1"}}));
string_to_spv("mul_mat_vec_" + tname + "_f16_f32_subgroup", shader, merge_maps(base_dict, {{data_a_key, "1"}, {"B_TYPE", "float16_t"}, {"B_TYPEV2", "f16vec2"}, {"B_TYPEV4", "f16vec4"}, {"D_TYPE", "float"}, {"USE_SUBGROUP_ADD", "1"}}));
string_to_spv("mul_mat_vec_" + tname + "_f32_f32_subgroup", shader, merge_maps(mmv_base_dict, {{data_a_key, "1"}, {"B_TYPE", "float"}, {"B_TYPEV2", "vec2"}, {"B_TYPEV4", "vec4"}, {"D_TYPE", "float"}, {"USE_SUBGROUP_ADD", "1"}}));
string_to_spv("mul_mat_vec_" + tname + "_f16_f32_subgroup", shader, merge_maps(mmv_base_dict, {{data_a_key, "1"}, {"B_TYPE", "float16_t"}, {"B_TYPEV2", "f16vec2"}, {"B_TYPEV4", "f16vec4"}, {"D_TYPE", "float"}, {"USE_SUBGROUP_ADD", "1"}}));
string_to_spv("mul_mat_vec_" + tname + "_f32_f32_subgroup_no_shmem", shader, merge_maps(base_dict, {{data_a_key, "1"}, {"B_TYPE", "float"}, {"B_TYPEV2", "vec2"}, {"B_TYPEV4", "vec4"}, {"D_TYPE", "float"}, {"USE_SUBGROUP_ADD_NO_SHMEM", "1"}}));
string_to_spv("mul_mat_vec_" + tname + "_f16_f32_subgroup_no_shmem", shader, merge_maps(base_dict, {{data_a_key, "1"}, {"B_TYPE", "float16_t"}, {"B_TYPEV2", "f16vec2"}, {"B_TYPEV4", "f16vec4"}, {"D_TYPE", "float"}, {"USE_SUBGROUP_ADD_NO_SHMEM", "1"}}));
string_to_spv("mul_mat_vec_" + tname + "_f32_f32_subgroup_no_shmem", shader, merge_maps(mmv_base_dict, {{data_a_key, "1"}, {"B_TYPE", "float"}, {"B_TYPEV2", "vec2"}, {"B_TYPEV4", "vec4"}, {"D_TYPE", "float"}, {"USE_SUBGROUP_ADD_NO_SHMEM", "1"}}));
string_to_spv("mul_mat_vec_" + tname + "_f16_f32_subgroup_no_shmem", shader, merge_maps(mmv_base_dict, {{data_a_key, "1"}, {"B_TYPE", "float16_t"}, {"B_TYPEV2", "f16vec2"}, {"B_TYPEV4", "f16vec4"}, {"D_TYPE", "float"}, {"USE_SUBGROUP_ADD_NO_SHMEM", "1"}}));
string_to_spv("mul_mat_vec_id_" + tname + "_f32_f32", shader, merge_maps(base_dict, {{"MUL_MAT_ID", "1"}, {data_a_key, "1"}, {"B_TYPE", "float"}, {"B_TYPEV2", "vec2"}, {"B_TYPEV4", "vec4"}, {"D_TYPE", "float"}}));
string_to_spv("mul_mat_vec_id_" + tname + "_f32_f32_subgroup", shader, merge_maps(base_dict, {{"MUL_MAT_ID", "1"}, {data_a_key, "1"}, {"B_TYPE", "float"}, {"B_TYPEV2", "vec2"}, {"B_TYPEV4", "vec4"}, {"D_TYPE", "float"}, {"USE_SUBGROUP_ADD", "1"}}));
string_to_spv("mul_mat_vec_id_" + tname + "_f32_f32_subgroup_no_shmem", shader, merge_maps(base_dict, {{"MUL_MAT_ID", "1"}, {data_a_key, "1"}, {"B_TYPE", "float"}, {"B_TYPEV2", "vec2"}, {"B_TYPEV4", "vec4"}, {"D_TYPE", "float"}, {"USE_SUBGROUP_ADD_NO_SHMEM", "1"}}));
string_to_spv("mul_mat_vec_id_" + tname + "_f32_f32", shader, merge_maps(mmv_base_dict, {{"MUL_MAT_ID", "1"}, {data_a_key, "1"}, {"B_TYPE", "float"}, {"B_TYPEV2", "vec2"}, {"B_TYPEV4", "vec4"}, {"D_TYPE", "float"}}));
string_to_spv("mul_mat_vec_id_" + tname + "_f32_f32_subgroup", shader, merge_maps(mmv_base_dict, {{"MUL_MAT_ID", "1"}, {data_a_key, "1"}, {"B_TYPE", "float"}, {"B_TYPEV2", "vec2"}, {"B_TYPEV4", "vec4"}, {"D_TYPE", "float"}, {"USE_SUBGROUP_ADD", "1"}}));
string_to_spv("mul_mat_vec_id_" + tname + "_f32_f32_subgroup_no_shmem", shader, merge_maps(mmv_base_dict, {{"MUL_MAT_ID", "1"}, {data_a_key, "1"}, {"B_TYPE", "float"}, {"B_TYPEV2", "vec2"}, {"B_TYPEV4", "vec4"}, {"D_TYPE", "float"}, {"USE_SUBGROUP_ADD_NO_SHMEM", "1"}}));
// mul mat vec with integer dot product
#if defined(GGML_VULKAN_INTEGER_DOT_GLSLC_SUPPORT)
if (is_legacy_quant(tname) || tname == "mxfp4" || is_k_quant(tname) || tname == "iq1_s" || tname == "iq1_m") {
string_to_spv("mul_mat_vec_" + tname + "_q8_1_f32", "mul_mat_vecq.comp", merge_maps(base_dict, {{data_a_key, "1"}, {"D_TYPE", "float"}, {"FLOAT_TYPE", "float"}, {"FLOAT_TYPEV2", "vec2"}, {"ACC_TYPE", "float"}}));
string_to_spv("mul_mat_vec_" + tname + "_q8_1_f32_subgroup", "mul_mat_vecq.comp", merge_maps(base_dict, {{data_a_key, "1"}, {"D_TYPE", "float"}, {"FLOAT_TYPE", "float"}, {"FLOAT_TYPEV2", "vec2"}, {"ACC_TYPE", "float"}, {"USE_SUBGROUP_ADD", "1"}}));
string_to_spv("mul_mat_vec_" + tname + "_q8_1_f32_subgroup_no_shmem", "mul_mat_vecq.comp", merge_maps(base_dict, {{data_a_key, "1"}, {"D_TYPE", "float"}, {"FLOAT_TYPE", "float"}, {"FLOAT_TYPEV2", "vec2"}, {"ACC_TYPE", "float"}, {"USE_SUBGROUP_ADD_NO_SHMEM", "1"}}));
string_to_spv("mul_mat_vec_" + tname + "_q8_1_f32", "mul_mat_vecq.comp", merge_maps(mmv_base_dict, {{data_a_key, "1"}, {"D_TYPE", "float"}, {"FLOAT_TYPE", "float"}, {"FLOAT_TYPEV2", "vec2"}, {"ACC_TYPE", "float"}}));
string_to_spv("mul_mat_vec_" + tname + "_q8_1_f32_subgroup", "mul_mat_vecq.comp", merge_maps(mmv_base_dict, {{data_a_key, "1"}, {"D_TYPE", "float"}, {"FLOAT_TYPE", "float"}, {"FLOAT_TYPEV2", "vec2"}, {"ACC_TYPE", "float"}, {"USE_SUBGROUP_ADD", "1"}}));
string_to_spv("mul_mat_vec_" + tname + "_q8_1_f32_subgroup_no_shmem", "mul_mat_vecq.comp", merge_maps(mmv_base_dict, {{data_a_key, "1"}, {"D_TYPE", "float"}, {"FLOAT_TYPE", "float"}, {"FLOAT_TYPEV2", "vec2"}, {"ACC_TYPE", "float"}, {"USE_SUBGROUP_ADD_NO_SHMEM", "1"}}));
string_to_spv("mul_mat_vec_id_" + tname + "_q8_1_f32", "mul_mat_vecq.comp", merge_maps(base_dict, {{"MUL_MAT_ID", "1"}, {data_a_key, "1"}, {"D_TYPE", "float"}, {"FLOAT_TYPE", "float"}, {"FLOAT_TYPEV2", "vec2"}, {"ACC_TYPE", "float"}}));
string_to_spv("mul_mat_vec_id_" + tname + "_q8_1_f32_subgroup", "mul_mat_vecq.comp", merge_maps(base_dict, {{"MUL_MAT_ID", "1"}, {data_a_key, "1"}, {"D_TYPE", "float"}, {"FLOAT_TYPE", "float"}, {"FLOAT_TYPEV2", "vec2"}, {"ACC_TYPE", "float"}, {"USE_SUBGROUP_ADD", "1"}}));
string_to_spv("mul_mat_vec_id_" + tname + "_q8_1_f32_subgroup_no_shmem", "mul_mat_vecq.comp", merge_maps(base_dict, {{"MUL_MAT_ID", "1"}, {data_a_key, "1"}, {"D_TYPE", "float"}, {"FLOAT_TYPE", "float"}, {"FLOAT_TYPEV2", "vec2"}, {"ACC_TYPE", "float"}, {"USE_SUBGROUP_ADD_NO_SHMEM", "1"}}));
string_to_spv("mul_mat_vec_id_" + tname + "_q8_1_f32", "mul_mat_vecq.comp", merge_maps(mmv_base_dict, {{"MUL_MAT_ID", "1"}, {data_a_key, "1"}, {"D_TYPE", "float"}, {"FLOAT_TYPE", "float"}, {"FLOAT_TYPEV2", "vec2"}, {"ACC_TYPE", "float"}}));
string_to_spv("mul_mat_vec_id_" + tname + "_q8_1_f32_subgroup", "mul_mat_vecq.comp", merge_maps(mmv_base_dict, {{"MUL_MAT_ID", "1"}, {data_a_key, "1"}, {"D_TYPE", "float"}, {"FLOAT_TYPE", "float"}, {"FLOAT_TYPEV2", "vec2"}, {"ACC_TYPE", "float"}, {"USE_SUBGROUP_ADD", "1"}}));
string_to_spv("mul_mat_vec_id_" + tname + "_q8_1_f32_subgroup_no_shmem", "mul_mat_vecq.comp", merge_maps(mmv_base_dict, {{"MUL_MAT_ID", "1"}, {data_a_key, "1"}, {"D_TYPE", "float"}, {"FLOAT_TYPE", "float"}, {"FLOAT_TYPEV2", "vec2"}, {"ACC_TYPE", "float"}, {"USE_SUBGROUP_ADD_NO_SHMEM", "1"}}));
}
#endif