Compare commits

..

2 Commits

Author SHA1 Message Date
Rithik Sharma
665abc6097 add fast mat-vec kernels for i-quants (#22344) 2026-04-27 08:25:45 -07:00
Igor Rudenko
4414c04b9a Additional test for common/gemma4 : handle parsing edge cases (#22420)
* Additional test for common/gemma4 : handle parsing edge cases

* Move tests to Gemma 4 test group
2026-04-27 16:36:59 +02:00
4 changed files with 583 additions and 0 deletions

View File

@@ -1615,6 +1615,24 @@ class ggml_webgpu_shader_lib {
defines.push_back("MUL_ACC_" + type_upper);
defines.push_back("U32_DEQUANT_HELPERS");
defines.push_back("SRC0_INNER_TYPE=u32");
switch (context.src0->type) {
case GGML_TYPE_IQ1_S:
case GGML_TYPE_IQ1_M:
case GGML_TYPE_IQ2_S:
case GGML_TYPE_IQ3_S:
case GGML_TYPE_IQ4_NL:
case GGML_TYPE_IQ4_XS:
defines.push_back(type_upper + "_GRID");
break;
case GGML_TYPE_IQ2_XXS:
case GGML_TYPE_IQ2_XS:
case GGML_TYPE_IQ3_XXS:
defines.push_back(type_upper + "_GRID");
defines.push_back(type_upper + "_TABLES");
break;
default:
break;
}
break;
}
}

View File

@@ -1391,6 +1391,17 @@ static webgpu_encoded_op ggml_webgpu_mul_mat(webgpu_context & ctx,
case GGML_TYPE_Q2_K:
use_fast = true;
break;
case GGML_TYPE_IQ1_S:
case GGML_TYPE_IQ1_M:
case GGML_TYPE_IQ2_XXS:
case GGML_TYPE_IQ2_XS:
case GGML_TYPE_IQ2_S:
case GGML_TYPE_IQ3_XXS:
case GGML_TYPE_IQ3_S:
case GGML_TYPE_IQ4_NL:
case GGML_TYPE_IQ4_XS:
use_fast = is_vec;
break;
default:
break;
}

View File

@@ -812,6 +812,520 @@ fn main(
}
#endif
#ifdef MUL_ACC_IQ1_S
#define BLOCK_SIZE 256
#define BLOCK_SIZE_BYTES 50
#define THREADS_PER_BLOCK 16
let tid = thread_id % THREADS_PER_BLOCK;
let block_group = thread_id / THREADS_PER_BLOCK;
let num_block_groups: u32 = WG_SIZE / THREADS_PER_BLOCK;
let sub_blk = tid / 2u;
let half = tid % 2u;
let slot0 = half * 2u;
let y_offset = sub_blk * 32u + slot0 * 8u;
let num_blocks = params.k / BLOCK_SIZE;
for (var block = block_group; block < num_blocks; block += num_block_groups) {
let x_base = src1_idx_base + block * BLOCK_SIZE + y_offset;
var x_block: array<f32, 16>;
for (var i = 0u; i < 16u; i++) {
x_block[i] = f32(src1[x_base + i]);
}
for (var row = 0u; row < OUTPUTS_PER_WG; row++) {
let output_row = row_base + row;
if (output_row < params.m) {
let block_byte_base = (src0_batch_offset + output_row * params.stride_01 + block) * BLOCK_SIZE_BYTES;
let d = f32(load_f16_at_src0(block_byte_base));
let qh = load_u32_at_src0(block_byte_base + 34u + sub_blk * 2u) & 0xFFFFu;
let dl = d * f32(2u * ((qh >> 12u) & 7u) + 1u);
let delta = select(IQ1_DELTA, -IQ1_DELTA, (qh & 0x8000u) != 0u);
let qs_w = load_u32_at_src0(block_byte_base + 2u + sub_blk * 4u);
var row_sum = 0.0;
for (var ll = 0u; ll < 2u; ll++) {
let l = slot0 + ll;
let qs_byte = get_byte(qs_w, l);
let ig = (qs_byte | (((qh >> (3u * l)) & 7u) << 8u)) * 8u;
let gw = iq1_grid[ig / 16u];
let bit_base = (ig % 16u) * 2u;
for (var j = 0u; j < 8u; j++) {
let g = (gw >> (bit_base + j * 2u)) & 3u;
let gs = select(f32(g), f32(g) - 4.0, (g & 2u) != 0u);
row_sum += dl * (gs + delta) * x_block[ll * 8u + j];
}
}
acc[row] += row_sum;
}
}
}
#endif
#ifdef MUL_ACC_IQ1_M
#define BLOCK_SIZE 256
#define BLOCK_SIZE_BYTES 56
#define THREADS_PER_BLOCK 16
let tid = thread_id % THREADS_PER_BLOCK;
let block_group = thread_id / THREADS_PER_BLOCK;
let num_block_groups: u32 = WG_SIZE / THREADS_PER_BLOCK;
let sub_blk = tid / 2u;
let half = tid % 2u;
let slot0 = half * 2u;
let y_offset = sub_blk * 32u + slot0 * 8u;
let num_blocks = params.k / BLOCK_SIZE;
for (var block = block_group; block < num_blocks; block += num_block_groups) {
let x_base = src1_idx_base + block * BLOCK_SIZE + y_offset;
var x_block: array<f32, 16>;
for (var i = 0u; i < 16u; i++) {
x_block[i] = f32(src1[x_base + i]);
}
for (var row = 0u; row < OUTPUTS_PER_WG; row++) {
let output_row = row_base + row;
if (output_row < params.m) {
let block_byte_base = (src0_batch_offset + output_row * params.stride_01 + block) * BLOCK_SIZE_BYTES;
let sc_lo = load_u32_at_src0(block_byte_base + 48u);
let sc_hi = load_u32_at_src0(block_byte_base + 52u);
let sc0 = sc_lo & 0xFFFFu;
let sc1 = (sc_lo >> 16u) & 0xFFFFu;
let sc2 = sc_hi & 0xFFFFu;
let sc3 = (sc_hi >> 16u) & 0xFFFFu;
let d_bits = (sc0 >> 12u) | ((sc1 >> 8u) & 0xF0u) | ((sc2 >> 4u) & 0xF00u) | (sc3 & 0xF000u);
let d = f32(bitcast<vec2<f16>>(d_bits)[0]);
let sc_u16 = select(select(sc2, sc3, sub_blk >= 6u),
select(sc0, sc1, sub_blk >= 2u),
sub_blk < 4u);
let qs_w = load_u32_at_src0(block_byte_base + sub_blk * 4u);
let qh = load_u32_at_src0(block_byte_base + 32u + sub_blk * 2u) & 0xFFFFu;
let qh_lo = qh & 0xFFu;
let qh_hi = (qh >> 8u) & 0xFFu;
var row_sum = 0.0;
for (var ll = 0u; ll < 2u; ll++) {
let l = slot0 + ll;
let bit_off = 6u * (sub_blk % 2u) + 3u * (l / 2u);
let sub_scale = (sc_u16 >> bit_off) & 0x7u;
let dl = d * f32(2u * sub_scale + 1u);
let qh_byte = select(qh_lo, qh_hi, l >= 2u);
let ll2 = l % 2u;
let grid_idx = get_byte(qs_w, l) | (((qh_byte >> (4u * ll2)) & 7u) << 8u);
let delta = select(IQ1_DELTA, -IQ1_DELTA, ((qh_byte >> (3u + 4u * ll2)) & 1u) != 0u);
let ig = grid_idx * 8u;
let gw = iq1_grid[ig / 16u];
let bit_base = (ig % 16u) * 2u;
for (var j = 0u; j < 8u; j++) {
let g = (gw >> (bit_base + j * 2u)) & 3u;
let gs = select(f32(g), f32(g) - 4.0, (g & 2u) != 0u);
row_sum += dl * (gs + delta) * x_block[ll * 8u + j];
}
}
acc[row] += row_sum;
}
}
}
#endif
#ifdef MUL_ACC_IQ2_XXS
#define BLOCK_SIZE 256
#define BLOCK_SIZE_BYTES 66
#define THREADS_PER_BLOCK 16
let tid = thread_id % THREADS_PER_BLOCK;
let block_group = thread_id / THREADS_PER_BLOCK;
let num_block_groups: u32 = WG_SIZE / THREADS_PER_BLOCK;
let sub_blk = tid / 2u;
let half = tid % 2u;
let slot0 = half * 2u;
let y_offset = sub_blk * 32u + slot0 * 8u;
let num_blocks = params.k / BLOCK_SIZE;
for (var block = block_group; block < num_blocks; block += num_block_groups) {
let x_base = src1_idx_base + block * BLOCK_SIZE + y_offset;
var x_block: array<f32, 16>;
for (var i = 0u; i < 16u; i++) {
x_block[i] = f32(src1[x_base + i]);
}
for (var row = 0u; row < OUTPUTS_PER_WG; row++) {
let output_row = row_base + row;
if (output_row < params.m) {
let block_byte_base = (src0_batch_offset + output_row * params.stride_01 + block) * BLOCK_SIZE_BYTES;
let d = f32(load_f16_at_src0(block_byte_base));
let aux_lo = load_u32_at_src0(block_byte_base + 2u + sub_blk * 8u);
let aux_hi = load_u32_at_src0(block_byte_base + 2u + sub_blk * 8u + 4u);
let ls = aux_hi >> 28u;
let db = d * (0.5 + f32(ls)) * 0.25;
var row_sum = 0.0;
for (var ll = 0u; ll < 2u; ll++) {
let l = slot0 + ll;
let grid_idx = (aux_lo >> (8u * l)) & 0xFFu;
let signs_idx = (aux_hi >> (7u * l)) & 0x7Fu;
let signs = (ksigns_iq2xs[signs_idx / 4u] >> ((signs_idx % 4u) * 8u)) & 0xFFu;
let gw_lo = iq2xxs_grid[grid_idx * 2u];
let gw_hi = iq2xxs_grid[grid_idx * 2u + 1u];
for (var j = 0u; j < 8u; j++) {
let gw = select(gw_hi, gw_lo, j < 4u);
let b = f32((gw >> ((j & 3u) * 8u)) & 0xFFu);
let s = select(1.0, -1.0, ((signs >> j) & 1u) != 0u);
row_sum += db * b * s * x_block[ll * 8u + j];
}
}
acc[row] += row_sum;
}
}
}
#endif
#ifdef MUL_ACC_IQ2_XS
#define BLOCK_SIZE 256
#define BLOCK_SIZE_BYTES 74
#define THREADS_PER_BLOCK 16
let tid = thread_id % THREADS_PER_BLOCK;
let block_group = thread_id / THREADS_PER_BLOCK;
let num_block_groups: u32 = WG_SIZE / THREADS_PER_BLOCK;
let sub_blk = tid / 2u;
let half = tid % 2u;
let slot0 = half * 2u;
let y_offset = sub_blk * 32u + slot0 * 8u;
let num_blocks = params.k / BLOCK_SIZE;
for (var block = block_group; block < num_blocks; block += num_block_groups) {
let x_base = src1_idx_base + block * BLOCK_SIZE + y_offset;
var x_block: array<f32, 16>;
for (var i = 0u; i < 16u; i++) {
x_block[i] = f32(src1[x_base + i]);
}
for (var row = 0u; row < OUTPUTS_PER_WG; row++) {
let output_row = row_base + row;
if (output_row < params.m) {
let block_byte_base = (src0_batch_offset + output_row * params.stride_01 + block) * BLOCK_SIZE_BYTES;
let d = f32(load_f16_at_src0(block_byte_base));
let qs_lo = load_u32_at_src0(block_byte_base + 2u + sub_blk * 8u);
let qs_hi = load_u32_at_src0(block_byte_base + 2u + sub_blk * 8u + 4u);
let scales_word = load_u32_at_src0(block_byte_base + 66u + (sub_blk / 4u) * 4u);
let scales_byte = get_byte(scales_word, sub_blk % 4u);
var row_sum = 0.0;
for (var ll = 0u; ll < 2u; ll++) {
let l = slot0 + ll;
let qs_word = select(qs_hi, qs_lo, l < 2u);
let half2 = (l % 2u) * 16u;
let qs_val = (qs_word >> half2) & 0xFFFFu;
let grid_idx = qs_val & 0x1FFu;
let signs_idx = (qs_val >> 9u) & 0x7Fu;
let sub_scale = (scales_byte >> (4u * (l / 2u))) & 0xFu;
let db = d * (0.5 + f32(sub_scale)) * 0.25;
let signs = (ksigns_iq2xs[signs_idx / 4u] >> ((signs_idx % 4u) * 8u)) & 0xFFu;
let gw_lo = iq2xs_grid[grid_idx * 2u];
let gw_hi = iq2xs_grid[grid_idx * 2u + 1u];
for (var j = 0u; j < 8u; j++) {
let gw = select(gw_hi, gw_lo, j < 4u);
let b = f32((gw >> ((j & 3u) * 8u)) & 0xFFu);
let s = select(1.0, -1.0, ((signs >> j) & 1u) != 0u);
row_sum += db * b * s * x_block[ll * 8u + j];
}
}
acc[row] += row_sum;
}
}
}
#endif
#ifdef MUL_ACC_IQ2_S
#define BLOCK_SIZE 256
#define BLOCK_SIZE_BYTES 82
#define THREADS_PER_BLOCK 16
let tid = thread_id % THREADS_PER_BLOCK;
let block_group = thread_id / THREADS_PER_BLOCK;
let num_block_groups: u32 = WG_SIZE / THREADS_PER_BLOCK;
let sub_blk = tid / 2u;
let half = tid % 2u;
let slot0 = half * 2u;
let y_offset = sub_blk * 32u + slot0 * 8u;
let num_blocks = params.k / BLOCK_SIZE;
for (var block = block_group; block < num_blocks; block += num_block_groups) {
let x_base = src1_idx_base + block * BLOCK_SIZE + y_offset;
var x_block: array<f32, 16>;
for (var i = 0u; i < 16u; i++) {
x_block[i] = f32(src1[x_base + i]);
}
for (var row = 0u; row < OUTPUTS_PER_WG; row++) {
let output_row = row_base + row;
if (output_row < params.m) {
let block_byte_base = (src0_batch_offset + output_row * params.stride_01 + block) * BLOCK_SIZE_BYTES;
let d = f32(load_f16_at_src0(block_byte_base));
let qs_w = load_u32_at_src0(block_byte_base + 2u + sub_blk * 4u);
let sg_w = load_u32_at_src0(block_byte_base + 34u + sub_blk * 4u);
let qh_word = load_u32_at_src0(block_byte_base + 66u + (sub_blk / 4u) * 4u);
let qh_byte = get_byte(qh_word, sub_blk % 4u);
let sc_word = load_u32_at_src0(block_byte_base + 74u + (sub_blk / 4u) * 4u);
let scales_byte = get_byte(sc_word, sub_blk % 4u);
var row_sum = 0.0;
for (var ll = 0u; ll < 2u; ll++) {
let l = slot0 + ll;
let qs_byte = get_byte(qs_w, l);
let sign_byte = get_byte(sg_w, l);
let grid_idx = qs_byte | (((qh_byte >> (2u * l)) & 3u) << 8u);
let sub_scale = (scales_byte >> (4u * (l / 2u))) & 0xFu;
let db = d * (0.5 + f32(sub_scale)) * 0.25;
let gw_lo = iq2s_grid[grid_idx * 2u];
let gw_hi = iq2s_grid[grid_idx * 2u + 1u];
for (var j = 0u; j < 8u; j++) {
let gw = select(gw_hi, gw_lo, j < 4u);
let b = f32((gw >> ((j & 3u) * 8u)) & 0xFFu);
let s = select(1.0, -1.0, ((sign_byte >> j) & 1u) != 0u);
row_sum += db * b * s * x_block[ll * 8u + j];
}
}
acc[row] += row_sum;
}
}
}
#endif
#ifdef MUL_ACC_IQ3_XXS
#define BLOCK_SIZE 256
#define BLOCK_SIZE_BYTES 98
#define THREADS_PER_BLOCK 16
let tid = thread_id % THREADS_PER_BLOCK;
let block_group = thread_id / THREADS_PER_BLOCK;
let num_block_groups: u32 = WG_SIZE / THREADS_PER_BLOCK;
let sub_blk = tid / 2u;
let half = tid % 2u;
let slot0 = half * 2u;
let y_offset = sub_blk * 32u + slot0 * 8u;
let num_blocks = params.k / BLOCK_SIZE;
for (var block = block_group; block < num_blocks; block += num_block_groups) {
let x_base = src1_idx_base + block * BLOCK_SIZE + y_offset;
var x_block: array<f32, 16>;
for (var i = 0u; i < 16u; i++) {
x_block[i] = f32(src1[x_base + i]);
}
for (var row = 0u; row < OUTPUTS_PER_WG; row++) {
let output_row = row_base + row;
if (output_row < params.m) {
let block_byte_base = (src0_batch_offset + output_row * params.stride_01 + block) * BLOCK_SIZE_BYTES;
let d = f32(load_f16_at_src0(block_byte_base));
let qs_lo = load_u32_at_src0(block_byte_base + 2u + sub_blk * 8u);
let qs_hi = load_u32_at_src0(block_byte_base + 2u + sub_blk * 8u + 4u);
let aux = load_u32_at_src0(block_byte_base + 66u + sub_blk * 4u);
let ls = aux >> 28u;
let db = d * (0.5 + f32(ls)) * 0.5;
var row_sum = 0.0;
for (var ll = 0u; ll < 2u; ll++) {
let l = slot0 + ll;
let qs_word = select(qs_hi, qs_lo, l < 2u);
let byte_pos = (l % 2u) * 2u;
let grid_idx_0 = (qs_word >> (byte_pos * 8u)) & 0xFFu;
let grid_idx_1 = (qs_word >> ((byte_pos + 1u) * 8u)) & 0xFFu;
let signs_idx = (aux >> (7u * l)) & 0x7Fu;
let signs = (ksigns_iq2xs[signs_idx / 4u] >> ((signs_idx % 4u) * 8u)) & 0xFFu;
let grid1 = iq3xxs_grid[grid_idx_0];
let grid2 = iq3xxs_grid[grid_idx_1];
for (var j = 0u; j < 4u; j++) {
let b1 = f32((grid1 >> (j * 8u)) & 0xFFu);
let b2 = f32((grid2 >> (j * 8u)) & 0xFFu);
let s1 = select(1.0, -1.0, ((signs >> j) & 1u) != 0u);
let s2 = select(1.0, -1.0, ((signs >> (j + 4u)) & 1u) != 0u);
row_sum += db * b1 * s1 * x_block[ll * 8u + j];
row_sum += db * b2 * s2 * x_block[ll * 8u + j + 4u];
}
}
acc[row] += row_sum;
}
}
}
#endif
#ifdef MUL_ACC_IQ3_S
#define BLOCK_SIZE 256
#define BLOCK_SIZE_BYTES 110
#define THREADS_PER_BLOCK 16
let tid = thread_id % THREADS_PER_BLOCK;
let block_group = thread_id / THREADS_PER_BLOCK;
let num_block_groups: u32 = WG_SIZE / THREADS_PER_BLOCK;
let sub_blk = tid / 2u;
let half = tid % 2u;
let slot0 = half * 2u;
let y_offset = sub_blk * 32u + slot0 * 8u;
let num_blocks = params.k / BLOCK_SIZE;
for (var block = block_group; block < num_blocks; block += num_block_groups) {
let x_base = src1_idx_base + block * BLOCK_SIZE + y_offset;
var x_block: array<f32, 16>;
for (var i = 0u; i < 16u; i++) {
x_block[i] = f32(src1[x_base + i]);
}
for (var row = 0u; row < OUTPUTS_PER_WG; row++) {
let output_row = row_base + row;
if (output_row < params.m) {
let block_byte_base = (src0_batch_offset + output_row * params.stride_01 + block) * BLOCK_SIZE_BYTES;
let d = f32(load_f16_at_src0(block_byte_base));
let qs_lo = load_u32_at_src0(block_byte_base + 2u + sub_blk * 8u);
let qs_hi = load_u32_at_src0(block_byte_base + 2u + sub_blk * 8u + 4u);
let qh_word = load_u32_at_src0(block_byte_base + 66u + (sub_blk / 4u) * 4u);
let qh_byte = get_byte(qh_word, sub_blk % 4u);
let sg_w = load_u32_at_src0(block_byte_base + 74u + sub_blk * 4u);
let sc_word = load_u32_at_src0(block_byte_base + 106u);
let scales_byte = get_byte(sc_word, sub_blk / 2u);
let sub_scale = (scales_byte >> (4u * (sub_blk % 2u))) & 0xFu;
let db = d * (1.0 + 2.0 * f32(sub_scale));
var row_sum = 0.0;
for (var ll = 0u; ll < 2u; ll++) {
let l = slot0 + ll;
let qs_word = select(qs_hi, qs_lo, l < 2u);
let byte_pos = (l % 2u) * 2u;
let qs0 = (qs_word >> (byte_pos * 8u)) & 0xFFu;
let qs1 = (qs_word >> ((byte_pos + 1u) * 8u)) & 0xFFu;
let grid_idx_1 = qs0 | (((qh_byte >> (2u * l)) & 1u) << 8u);
let grid_idx_2 = qs1 | (((qh_byte >> (2u * l + 1u)) & 1u) << 8u);
let sign_byte = get_byte(sg_w, l);
let grid1 = iq3s_grid[grid_idx_1];
let grid2 = iq3s_grid[grid_idx_2];
for (var j = 0u; j < 4u; j++) {
let b1 = f32((grid1 >> (j * 8u)) & 0xFFu);
let b2 = f32((grid2 >> (j * 8u)) & 0xFFu);
let s1 = select(1.0, -1.0, ((sign_byte >> j) & 1u) != 0u);
let s2 = select(1.0, -1.0, ((sign_byte >> (j + 4u)) & 1u) != 0u);
row_sum += db * b1 * s1 * x_block[ll * 8u + j];
row_sum += db * b2 * s2 * x_block[ll * 8u + j + 4u];
}
}
acc[row] += row_sum;
}
}
}
#endif
#ifdef MUL_ACC_IQ4_NL
#define BLOCK_SIZE 32
#define BLOCK_SIZE_BYTES 18
#define THREADS_PER_BLOCK 4
#define ELEMS_PER_THREAD (BLOCK_SIZE/THREADS_PER_BLOCK)
let num_blocks = params.k / BLOCK_SIZE;
let thread_within_block = thread_id % THREADS_PER_BLOCK;
for (var block = thread_id / THREADS_PER_BLOCK; block < num_blocks; block += WG_SIZE / THREADS_PER_BLOCK) {
let x_base = src1_idx_base + block * BLOCK_SIZE + thread_within_block * 4u;
var x_block: array<f32, ELEMS_PER_THREAD>;
for (var i = 0u; i < ELEMS_PER_THREAD / 2u; i++) {
x_block[i] = f32(src1[x_base + i]);
x_block[i + 4u] = f32(src1[x_base + i + 16u]);
}
for (var row = 0u; row < OUTPUTS_PER_WG; row++) {
let output_row = row_base + row;
if (output_row < params.m) {
let block_byte_base = (src0_batch_offset + output_row * params.stride_01 + block) * BLOCK_SIZE_BYTES;
let d = f32(load_f16_at_src0(block_byte_base));
var row_sum = 0.0;
let q_packed = load_u32_at_src0(block_byte_base + 2u + 4u * thread_within_block);
for (var byte_idx = 0u; byte_idx < 4u; byte_idx++) {
let q_byte = get_byte(q_packed, byte_idx);
let q_lo = f32(kvalues_iq4nl[q_byte & 0xFu]) * d;
let q_hi = f32(kvalues_iq4nl[(q_byte >> 4u) & 0xFu]) * d;
row_sum += q_lo * x_block[byte_idx];
row_sum += q_hi * x_block[byte_idx + 4u];
}
acc[row] += row_sum;
}
}
}
#endif
#ifdef MUL_ACC_IQ4_XS
#define BLOCK_SIZE 256
#define BLOCK_SIZE_BYTES 136
#define THREADS_PER_BLOCK 16
let tid = thread_id % THREADS_PER_BLOCK;
let block_group = thread_id / THREADS_PER_BLOCK;
let num_block_groups: u32 = WG_SIZE / THREADS_PER_BLOCK;
let sub_blk = tid / 2u;
let half = tid % 2u;
let y_offset = sub_blk * 32u + half * 16u;
let num_blocks = params.k / BLOCK_SIZE;
for (var block = block_group; block < num_blocks; block += num_block_groups) {
let x_base = src1_idx_base + block * BLOCK_SIZE + y_offset;
var x_block: array<f32, 16>;
for (var i = 0u; i < 16u; i++) {
x_block[i] = f32(src1[x_base + i]);
}
for (var row = 0u; row < OUTPUTS_PER_WG; row++) {
let output_row = row_base + row;
if (output_row < params.m) {
let block_byte_base = (src0_batch_offset + output_row * params.stride_01 + block) * BLOCK_SIZE_BYTES;
let d = f32(load_f16_at_src0(block_byte_base));
let scales_h = load_u16_at_src0(block_byte_base + 2u);
let scales_l_word = load_u32_at_src0(block_byte_base + 4u);
let sl_byte = get_byte(scales_l_word, sub_blk / 2u);
let sl = (sl_byte >> (4u * (sub_blk % 2u))) & 0xFu;
let sh_bits = (scales_h >> (2u * sub_blk)) & 3u;
let ls = i32(sl | (sh_bits << 4u));
let dl = d * f32(ls - 32);
let qs_byte_off = 8u + sub_blk * 16u;
let q_w0 = load_u32_at_src0(block_byte_base + qs_byte_off);
let q_w1 = load_u32_at_src0(block_byte_base + qs_byte_off + 4u);
let q_w2 = load_u32_at_src0(block_byte_base + qs_byte_off + 8u);
let q_w3 = load_u32_at_src0(block_byte_base + qs_byte_off + 12u);
var row_sum = 0.0;
for (var i = 0u; i < 16u; i++) {
let q_word = select(
select(q_w0, q_w1, i >= 4u),
select(q_w2, q_w3, i >= 12u),
i >= 8u);
let q_byte = get_byte(q_word, i % 4u);
let nib = select(q_byte & 0xFu, (q_byte >> 4u) & 0xFu, half == 1u);
row_sum += f32(kvalues_iq4nl[nib]) * dl * x_block[i];
}
acc[row] += row_sum;
}
}
}
#endif
#ifdef USE_SUBGROUP_REDUCTION
for (var row = 0u; row < OUTPUTS_PER_WG; row++) {
let subgroup_total = subgroupAdd(acc[row]);

View File

@@ -2249,6 +2249,46 @@ static void test_template_output_peg_parsers(bool detailed_debug) {
.reasoning_format(COMMON_REASONING_FORMAT_AUTO)
.expect(message_assist)
.run();
{
// additional tests for https://github.com/ggml-org/llama.cpp/pull/21760
auto tmpls = read_templates("models/templates/google-gemma-4-31B-it.jinja");
common_chat_msg tool_call_msg = simple_assist_msg(
"Let me check.", "", "special_function", "{\"arg1\": 1}","c0");
common_chat_msg tool_msg;
tool_msg.role = "tool";
tool_msg.tool_name = "special_function";
tool_msg.tool_call_id = "c0";
tool_msg.content = "{\"r\":\"ok\"}";
{
common_chat_templates_inputs inputs;
inputs.messages = { message_user, tool_call_msg, tool_msg };
inputs.tools = { special_function_tool };
inputs.add_generation_prompt = true;
auto params = common_chat_templates_apply(tmpls.get(), inputs);
if (!string_ends_with(params.prompt, "<turn|>\n<|turn>model\n")) {
throw std::runtime_error("Missing generation prompt for Gemma 4");
}
}
{
common_chat_templates_inputs inputs;
inputs.messages = { message_user, tool_call_msg, tool_msg };
inputs.tools = { special_function_tool };
inputs.add_generation_prompt = false;
auto params = common_chat_templates_apply(tmpls.get(), inputs);
if (string_ends_with(params.prompt, "<|turn>model\n")) {
throw std::runtime_error("Gemma 4: generation prompt was modified despite add_generation_prompt=false");
}
}
}
}
{