Compare commits

...

1 Commit

Author SHA1 Message Date
Georgi Gerganov aa27b85ecf metal : optimize pad 2026-05-19 20:17:14 +03:00
3 changed files with 15 additions and 10 deletions
+4 -2
View File
@@ -4039,14 +4039,16 @@ int ggml_metal_op_pad(ggml_metal_op_t ctx, int idx) {
auto pipeline = ggml_metal_library_get_pipeline_pad(lib, op);
const int nth = std::min(1024, ne0);
const int nth_max = MIN(1024, ggml_metal_pipeline_max_theads_per_threadgroup(pipeline));
const int nth = MIN(args.ne0, nth_max);
const int nk0 = (args.ne0 + nth - 1)/nth;
ggml_metal_encoder_set_pipeline(enc, pipeline);
ggml_metal_encoder_set_bytes (enc, &args, sizeof(args), 0);
ggml_metal_encoder_set_buffer (enc, ggml_metal_get_buffer_id(op->src[0]), 1);
ggml_metal_encoder_set_buffer (enc, ggml_metal_get_buffer_id(op), 2);
ggml_metal_encoder_dispatch_threadgroups(enc, ne1, ne2, ne3, nth, 1, 1);
ggml_metal_encoder_dispatch_threadgroups(enc, nk0*ne1, ne2, ne3, nth, 1, 1);
return 1;
}
+6 -4
View File
@@ -5336,10 +5336,10 @@ kernel void kernel_pad_f32(
uint3 tgpig[[threadgroup_position_in_grid]],
uint3 tpitg[[thread_position_in_threadgroup]],
uint3 ntg[[threads_per_threadgroup]]) {
const int64_t i3 = tgpig.z;
const int64_t i2 = tgpig.y;
const int64_t i1 = tgpig.x;
const int64_t k0 = tgpig.x/args.ne1;
const int64_t i1 = tgpig.x - k0*args.ne1;
const int64_t i03 = i3;
const int64_t i02 = i2;
@@ -5348,8 +5348,10 @@ kernel void kernel_pad_f32(
device const float * src0_ptr = (device const float *) (src0 + i03*args.nb03 + i02*args.nb02 + i01*args.nb01);
device float * dst_ptr = (device float *) (dst + i3*args.nb3 + i2*args.nb2 + i1*args.nb1);
const int64_t i0 = k0*ntg.x + tpitg.x;
if (i1 < args.ne01 && i2 < args.ne02 && i3 < args.ne03) {
for (int i0 = tpitg.x; i0 < args.ne0; i0 += ntg.x) {
if (i0 < args.ne0) {
if (i0 < args.ne00) {
dst_ptr[i0] = src0_ptr[i0];
} else {
@@ -5360,7 +5362,7 @@ kernel void kernel_pad_f32(
return;
}
for (int i0 = tpitg.x; i0 < args.ne0; i0 += ntg.x) {
if (i0 < args.ne0) {
dst_ptr[i0] = 0.0f;
}
}
+5 -4
View File
@@ -562,13 +562,14 @@ ggml_tensor * llm_build_delta_net_base::build_recurrent_attn(
}
const int64_t D = S_v * S_v * H_v;
const int64_t K = (int64_t) cparams.n_rs_seq + 1;
const int64_t K = cparams.n_rs_seq + 1;
// TODO: remove pad + simplify
ggml_tensor * state_in_3d = ggml_reshape_3d(ctx0, s, D, 1, n_seqs);
ggml_tensor * state_3d = ggml_pad(ctx0, state_in_3d, 0, K - 1, 0, 0);
ggml_tensor * s_4d = ggml_reshape_4d(ctx0, s, S_v, S_v*H_v, 1, n_seqs);
ggml_tensor * s_4d_pad = ggml_pad (ctx0, s_4d, 0, 0, K - 1, 0);
ggml_tensor * s_3d = ggml_reshape_3d(ctx0, s_4d_pad, D, K, n_seqs);
ggml_tensor * gdn_out = ggml_gated_delta_net(ctx0, q, k, v, g, b, state_3d);
ggml_tensor * gdn_out = ggml_gated_delta_net(ctx0, q, k, v, g, b, s_3d);
if (n_seq_tokens > 1) {
cb(gdn_out, LLAMA_TENSOR_NAME_FGDN_CH, il);
} else {