Compare commits

..

2 Commits
b9563 ... b9565

Author SHA1 Message Date
Nikhil Jain
1705d434f6 [ggml-webgpu] Handle buffer overlap / buffer aliasing for concat operator (#24000)
* Only run webgpu CI on my fork

* Add webgpu only workflow

* handle buffer overlap case for concat operator

* restore build-webgpu.yml

Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>

* Run clang-format

* Update ggml/src/ggml-webgpu/wgsl-shaders/concat.wgsl

---------

Co-authored-by: Claude Sonnet 4.6 <noreply@anthropic.com>
Co-authored-by: Reese Levine <reeselevine1@gmail.com>
2026-06-08 08:07:31 -07:00
Nikhil Jain
3b3da01dc2 [ggml-webgpu] Implement 2D workgroups for scale, binary, and unary ops (#24044)
* Only run webgpu CI on my fork

* Add webgpu only workflow

* Implement 2d workgroups for more operations

* fix

* Fix type

* Move back to global_invocation_id
2026-06-08 08:07:15 -07:00
6 changed files with 120 additions and 58 deletions

View File

@@ -448,15 +448,19 @@ struct ggml_webgpu_upscale_pipeline_key_hash {
/** Concat **/
struct ggml_webgpu_concat_pipeline_key {
int type;
int type;
bool src_overlap;
bool operator==(const ggml_webgpu_concat_pipeline_key & other) const { return type == other.type; }
bool operator==(const ggml_webgpu_concat_pipeline_key & other) const {
return type == other.type && src_overlap == other.src_overlap;
}
};
struct ggml_webgpu_concat_pipeline_key_hash {
size_t operator()(const ggml_webgpu_concat_pipeline_key & key) const {
size_t seed = 0;
ggml_webgpu_hash_combine(seed, key.type);
ggml_webgpu_hash_combine(seed, key.src_overlap);
return seed;
}
};
@@ -2634,6 +2638,7 @@ class ggml_webgpu_shader_lib {
webgpu_pipeline get_concat_pipeline(const ggml_webgpu_shader_lib_context & context) {
ggml_webgpu_concat_pipeline_key key = {};
key.type = context.dst->type;
key.src_overlap = ggml_webgpu_tensor_overlap(context.src0, context.src1);
auto it = concat_pipelines.find(key);
if (it != concat_pipelines.end()) {
@@ -2656,11 +2661,17 @@ class ggml_webgpu_shader_lib {
GGML_ABORT("Unsupported type for concat shader");
}
if (key.src_overlap) {
defines.push_back("SRC_OVERLAP");
variant += "_src_overlap";
}
defines.push_back(std::string("WG_SIZE=") + std::to_string(context.max_wg_size));
auto processed = preprocessor.preprocess(wgsl_concat, defines);
auto decisions = std::make_shared<ggml_webgpu_generic_shader_decisions>();
auto decisions = std::make_shared<ggml_webgpu_binary_shader_decisions>();
decisions->wg_size = context.max_wg_size;
decisions->src_overlap = key.src_overlap;
webgpu_pipeline pipeline = ggml_webgpu_create_pipeline(device, processed, variant);
pipeline.context = decisions;
concat_pipelines[key] = pipeline;

View File

@@ -621,10 +621,11 @@ static void ggml_backend_webgpu_buffer_memset(webgpu_global_context & ctx,
uint32_t value,
size_t offset,
size_t size) {
std::vector<uint32_t> params = { (uint32_t) offset, (uint32_t) size, value };
std::vector<wgpu::BindGroupEntry> entries = { ggml_webgpu_make_bind_group_entry(0, buf, 0, buf.GetSize()) };
size_t bytes_per_wg = ctx->capabilities.limits.maxComputeInvocationsPerWorkgroup * ctx->capabilities.memset_bytes_per_thread;
uint32_t wg_x = CEIL_DIV(size + 3, bytes_per_wg);
std::vector<uint32_t> params = { (uint32_t) offset, (uint32_t) size, value };
std::vector<wgpu::BindGroupEntry> entries = { ggml_webgpu_make_bind_group_entry(0, buf, 0, buf.GetSize()) };
size_t bytes_per_wg =
ctx->capabilities.limits.maxComputeInvocationsPerWorkgroup * ctx->capabilities.memset_bytes_per_thread;
uint32_t wg_x = CEIL_DIV(size + 3, bytes_per_wg);
ctx->queue.WriteBuffer(ctx->memset_params_buf, 0, params.data(), params.size() * sizeof(uint32_t));
@@ -1362,7 +1363,7 @@ static webgpu_encoded_op ggml_webgpu_get_rows(webgpu_context & ctx,
shader_lib_ctx.src0 = src;
shader_lib_ctx.src1 = nullptr;
shader_lib_ctx.dst = dst;
shader_lib_ctx.max_wg_size = ctx->global_ctx->capabilities.limits.maxComputeInvocationsPerWorkgroup;
shader_lib_ctx.max_wg_size = ctx->global_ctx->capabilities.limits.maxComputeInvocationsPerWorkgroup;
webgpu_pipeline pipeline = ctx->shader_lib->get_get_rows_pipeline(shader_lib_ctx);
auto * decisions = static_cast<ggml_webgpu_generic_shader_decisions *>(pipeline.context.get());
@@ -2169,8 +2170,10 @@ static webgpu_encoded_op ggml_webgpu_unary_op(webgpu_context & ctx, ggml_tensor
entries.push_back(ggml_webgpu_make_tensor_bind_group_entry(ctx, 1, dst));
}
uint32_t wg_x = CEIL_DIV(ne, decisions->wg_size);
return ggml_backend_webgpu_build(ctx, pipeline, params, entries, wg_x);
uint32_t wg_x, wg_y;
uint32_t total_wg = CEIL_DIV(ggml_nelements(dst), decisions->wg_size);
compute_2d_workgroups(total_wg, ctx->global_ctx->capabilities.limits.maxComputeWorkgroupsPerDimension, wg_x, wg_y);
return ggml_backend_webgpu_build(ctx, pipeline, params, entries, wg_x, wg_y);
}
static webgpu_encoded_op ggml_webgpu_binary_op(webgpu_context & ctx,
@@ -2244,8 +2247,10 @@ static webgpu_encoded_op ggml_webgpu_binary_op(webgpu_context & ctx,
}
}
uint32_t wg_x = CEIL_DIV(ne, decisions->wg_size);
return ggml_backend_webgpu_build(ctx, pipeline, params, entries, wg_x);
uint32_t wg_x, wg_y;
uint32_t total_wg = CEIL_DIV(ggml_nelements(dst), decisions->wg_size);
compute_2d_workgroups(total_wg, ctx->global_ctx->capabilities.limits.maxComputeWorkgroupsPerDimension, wg_x, wg_y);
return ggml_backend_webgpu_build(ctx, pipeline, params, entries, wg_x, wg_y);
}
static webgpu_encoded_op ggml_webgpu_add_id(webgpu_context & ctx,
@@ -2305,33 +2310,6 @@ static webgpu_encoded_op ggml_webgpu_concat(webgpu_context & ctx,
uint32_t ne = (uint32_t) ggml_nelements(dst);
uint32_t dim = (uint32_t) dst->op_params[0];
std::vector<uint32_t> params = {
ne,
(uint32_t) (ggml_webgpu_tensor_misalignment(ctx, src0) / ggml_type_size(src0->type)),
(uint32_t) (ggml_webgpu_tensor_misalignment(ctx, src1) / ggml_type_size(src1->type)),
(uint32_t) (ggml_webgpu_tensor_misalignment(ctx, dst) / ggml_type_size(dst->type)),
(uint32_t) (src0->nb[0] / ggml_type_size(src0->type)),
(uint32_t) (src0->nb[1] / ggml_type_size(src0->type)),
(uint32_t) (src0->nb[2] / ggml_type_size(src0->type)),
(uint32_t) (src0->nb[3] / ggml_type_size(src0->type)),
(uint32_t) (src1->nb[0] / ggml_type_size(src1->type)),
(uint32_t) (src1->nb[1] / ggml_type_size(src1->type)),
(uint32_t) (src1->nb[2] / ggml_type_size(src1->type)),
(uint32_t) (src1->nb[3] / ggml_type_size(src1->type)),
(uint32_t) dst->ne[0],
(uint32_t) dst->ne[1],
(uint32_t) dst->ne[2],
(uint32_t) dst->ne[3],
dim,
(uint32_t) src0->ne[dim]
};
std::vector<wgpu::BindGroupEntry> entries = {
ggml_webgpu_make_tensor_bind_group_entry(ctx, 0, src0),
ggml_webgpu_make_tensor_bind_group_entry(ctx, 1, src1),
ggml_webgpu_make_tensor_bind_group_entry(ctx, 2, dst),
};
ggml_webgpu_shader_lib_context shader_lib_ctx = {};
shader_lib_ctx.src0 = src0;
shader_lib_ctx.src1 = src1;
@@ -2339,8 +2317,52 @@ static webgpu_encoded_op ggml_webgpu_concat(webgpu_context & ctx,
shader_lib_ctx.max_wg_size = ctx->global_ctx->capabilities.limits.maxComputeInvocationsPerWorkgroup;
webgpu_pipeline pipeline = ctx->shader_lib->get_concat_pipeline(shader_lib_ctx);
auto * decisions = static_cast<ggml_webgpu_generic_shader_decisions *>(pipeline.context.get());
uint32_t wg_x = CEIL_DIV(ne, decisions->wg_size);
auto * decisions = static_cast<ggml_webgpu_binary_shader_decisions *>(pipeline.context.get());
uint32_t offset_src0 = (uint32_t) (ggml_webgpu_tensor_misalignment(ctx, src0) / ggml_type_size(src0->type));
uint32_t offset_src1 = (uint32_t) (ggml_webgpu_tensor_misalignment(ctx, src1) / ggml_type_size(src1->type));
size_t merged_offset = 0;
size_t merged_size = 0;
if (decisions->src_overlap) {
const ggml_webgpu_merged_binding_range merged_range =
ggml_webgpu_tensor_merged_binding_range(ctx, { src0, src1 });
merged_offset = merged_range.offset;
merged_size = merged_range.size;
offset_src0 = ggml_webgpu_tensor_merged_element_offset(src0, merged_range);
offset_src1 = ggml_webgpu_tensor_merged_element_offset(src1, merged_range);
}
std::vector<uint32_t> params = { ne,
offset_src0,
offset_src1,
(uint32_t) (ggml_webgpu_tensor_misalignment(ctx, dst) / ggml_type_size(dst->type)),
(uint32_t) (src0->nb[0] / ggml_type_size(src0->type)),
(uint32_t) (src0->nb[1] / ggml_type_size(src0->type)),
(uint32_t) (src0->nb[2] / ggml_type_size(src0->type)),
(uint32_t) (src0->nb[3] / ggml_type_size(src0->type)),
(uint32_t) (src1->nb[0] / ggml_type_size(src1->type)),
(uint32_t) (src1->nb[1] / ggml_type_size(src1->type)),
(uint32_t) (src1->nb[2] / ggml_type_size(src1->type)),
(uint32_t) (src1->nb[3] / ggml_type_size(src1->type)),
(uint32_t) dst->ne[0],
(uint32_t) dst->ne[1],
(uint32_t) dst->ne[2],
(uint32_t) dst->ne[3],
dim,
(uint32_t) src0->ne[dim] };
std::vector<wgpu::BindGroupEntry> entries = {};
if (decisions->src_overlap) {
entries.push_back(
ggml_webgpu_make_bind_group_entry(0, ggml_webgpu_tensor_buf(src0), merged_offset, merged_size));
entries.push_back(ggml_webgpu_make_tensor_bind_group_entry(ctx, 1, dst));
} else {
entries.push_back(ggml_webgpu_make_tensor_bind_group_entry(ctx, 0, src0));
entries.push_back(ggml_webgpu_make_tensor_bind_group_entry(ctx, 1, src1));
entries.push_back(ggml_webgpu_make_tensor_bind_group_entry(ctx, 2, dst));
}
uint32_t wg_x = CEIL_DIV(ne, decisions->wg_size);
return ggml_backend_webgpu_build(ctx, pipeline, params, entries, wg_x);
}
@@ -2673,8 +2695,10 @@ static webgpu_encoded_op ggml_webgpu_scale(webgpu_context & ctx, ggml_tensor * s
entries.push_back(ggml_webgpu_make_tensor_bind_group_entry(ctx, 1, dst));
}
uint32_t wg_x = CEIL_DIV(ggml_nelements(dst), decisions->wg_size);
return ggml_backend_webgpu_build(ctx, pipeline, params, entries, wg_x);
uint32_t wg_x, wg_y;
uint32_t total_wg = CEIL_DIV(ggml_nelements(dst), decisions->wg_size);
compute_2d_workgroups(total_wg, ctx->global_ctx->capabilities.limits.maxComputeWorkgroupsPerDimension, wg_x, wg_y);
return ggml_backend_webgpu_build(ctx, pipeline, params, entries, wg_x, wg_y);
}
static webgpu_encoded_op ggml_webgpu_soft_max(webgpu_context & ctx,
@@ -3751,7 +3775,8 @@ static ggml_guid_t ggml_backend_webgpu_guid(void) {
static void ggml_webgpu_init_memset_pipeline(webgpu_global_context & ctx) {
// we use the maximum workgroup size for the memset pipeline
size_t max_threads = ctx->capabilities.limits.maxComputeInvocationsPerWorkgroup * ctx->capabilities.limits.maxComputeWorkgroupsPerDimension;
size_t max_threads = ctx->capabilities.limits.maxComputeInvocationsPerWorkgroup *
ctx->capabilities.limits.maxComputeWorkgroupsPerDimension;
// Size the bytes_per_thread so that the largest buffer size can be handled
ctx->capabilities.memset_bytes_per_thread =
CEIL_DIV(ctx->capabilities.limits.maxStorageBufferBindingSize, max_threads);

View File

@@ -130,10 +130,13 @@ fn update(dst_i: u32, src0_i: u32, src1_i: u32) {
}
@compute @workgroup_size(WG_SIZE)
fn main(@builtin(global_invocation_id) gid: vec3<u32>) {
if (gid.x < params.ne) {
let src0_i = params.offset_src0 + src0_index(gid.x);
let src1_i = params.offset_src1 + src1_index(gid.x);
update(params.offset_dst + gid.x, src0_i, src1_i);
fn main(@builtin(global_invocation_id) gid: vec3<u32>,
@builtin(num_workgroups) num_wg: vec3<u32>) {
let threads_per_group = u32(WG_SIZE);
let i = gid.x + (num_wg.x * threads_per_group) * gid.y;
if (i < params.ne) {
let src0_i = params.offset_src0 + src0_index(i);
let src1_i = params.offset_src1 + src1_index(i);
update(params.offset_dst + i, src0_i, src1_i);
}
}

View File

@@ -31,6 +31,16 @@ struct Params {
#define DataType i32
#endif
#ifdef SRC_OVERLAP
@group(0) @binding(0)
var<storage, read_write> merged_src: array<DataType>;
@group(0) @binding(1)
var<storage, read_write> dst: array<DataType>;
@group(0) @binding(2)
var<uniform> params: Params;
#else
@group(0) @binding(0)
var<storage, read_write> src0: array<DataType>;
@@ -42,7 +52,7 @@ var<storage, read_write> dst: array<DataType>;
@group(0) @binding(3)
var<uniform> params: Params;
#endif
@compute @workgroup_size(WG_SIZE)
fn main(@builtin(global_invocation_id) gid: vec3<u32>) {
@@ -62,14 +72,22 @@ fn main(@builtin(global_invocation_id) gid: vec3<u32>) {
ni[1] * params.stride_src0_1 +
ni[2] * params.stride_src0_2 +
ni[3] * params.stride_src0_3;
#ifdef SRC_OVERLAP
dst[params.offset_dst + gid.x] = merged_src[params.offset_src0 + src_i];
#else
dst[params.offset_dst + gid.x] = src0[params.offset_src0 + src_i];
#endif
} else {
ni[params.dim] -= params.src0_nedim;
let src_i = ni[0] * params.stride_src1_0 +
ni[1] * params.stride_src1_1 +
ni[2] * params.stride_src1_2 +
ni[3] * params.stride_src1_3;
#ifdef SRC_OVERLAP
dst[params.offset_dst + gid.x] = merged_src[params.offset_src1 + src_i];
#else
dst[params.offset_dst + gid.x] = src1[params.offset_src1 + src_i];
#endif
}
}
}

View File

@@ -43,12 +43,14 @@ struct Params {
var<storage, read_write> src: array<f32>;
@compute @workgroup_size(WG_SIZE)
fn main(@builtin(global_invocation_id) gid: vec3<u32>) {
if (gid.x >= params.ne) {
fn main(
@builtin(global_invocation_id) gid: vec3<u32>,
@builtin(num_workgroups) num_wg: vec3<u32>) {
let threads_per_group = u32(WG_SIZE);
var i = gid.x + (num_wg.x * threads_per_group) * gid.y;
if (i >= params.ne) {
return;
}
var i = gid.x;
let i3 = i / (params.ne2 * params.ne1 * params.ne0);
i = i % (params.ne2 * params.ne1 * params.ne0);
let i2 = i / (params.ne1 * params.ne0);

View File

@@ -66,11 +66,14 @@ fn erf_approx(x: TYPE) -> TYPE {
}
@compute @workgroup_size(WG_SIZE)
fn main(@builtin(global_invocation_id) gid: vec3<u32>) {
if (gid.x >= params.ne) {
fn main(@builtin(global_invocation_id) gid: vec3<u32>,
@builtin(num_workgroups) num_wg: vec3<u32>) {
let threads_per_group = u32(WG_SIZE);
let flat_i = gid.x + (num_wg.x * threads_per_group) * gid.y;
if (flat_i >= params.ne) {
return;
}
var i = gid.x;
var i = flat_i;
let ne2 = params.ne2;
#ifdef DIAG
let ne1 = params.ne0;
@@ -205,6 +208,6 @@ fn main(@builtin(global_invocation_id) gid: vec3<u32>) {
#ifdef INPLACE
src[params.offset_src + src_idx] = res;
#else
dst[params.offset_dst + gid.x] = res;
dst[params.offset_dst + flat_i] = res;
#endif
}