Compare commits

...

3 Commits

Author SHA1 Message Date
Jeff Bolz 92e854ab83 vulkan: Support GET_ROWS_BACK (#24883) 2026-06-23 15:39:37 +02:00
Jeff Bolz c5606364b2 vulkan: support CONV_3D (#24612)
* vulkan: support CONV_3D

This is a pretty direct port of conv2d_mm.comp to CONV_3D, done by codex
and cleaned up by me.

* disable slower perf tests
2026-06-23 15:39:20 +02:00
Jeff Bolz 0eb874d374 vulkan: make mul_mm ALIGNED a spec constant (#24689)
This trims down some of the shader variant explosion and reduces binary size.
2026-06-23 14:26:17 +02:00
8 changed files with 955 additions and 140 deletions
+319 -33
View File
@@ -493,6 +493,20 @@ struct vk_conv2d_pipeline_state {
}
};
struct vk_conv3d_pipeline_state {
vk_conv3d_pipeline_state(uint32_t s0, uint32_t s1, uint32_t s2, uint32_t p0, uint32_t p1, uint32_t p2,
uint32_t d0, uint32_t d1, uint32_t d2, uint32_t KW, uint32_t KH, uint32_t KD, uint32_t aligned)
: s0(s0), s1(s1), s2(s2), p0(p0), p1(p1), p2(p2), d0(d0), d1(d1), d2(d2), KW(KW), KH(KH), KD(KD), aligned(aligned) {}
uint32_t s0, s1, s2, p0, p1, p2, d0, d1, d2, KW, KH, KD;
uint32_t aligned;
bool operator<(const vk_conv3d_pipeline_state &b) const {
return std::tie(s0, s1, s2, p0, p1, p2, d0, d1, d2, KW, KH, KD, aligned) <
std::tie(b.s0, b.s1, b.s2, b.p0, b.p1, b.p2, b.d0, b.d1, b.d2, b.KW, b.KH, b.KD, b.aligned);
}
};
struct vk_solve_tri_pipeline_state {
vk_solve_tri_pipeline_state(uint32_t N, uint32_t K)
: N(N), K(K) {}
@@ -777,6 +791,7 @@ struct vk_device_struct {
vk_pipeline pipeline_mul_mat_vec_nc_f16_f32;
vk_pipeline pipeline_get_rows[GGML_TYPE_COUNT];
vk_pipeline pipeline_get_rows_f32[GGML_TYPE_COUNT];
vk_pipeline pipeline_get_rows_back_f32;
vk_pipeline pipeline_acc_f32;
vk_pipeline pipeline_set_f32;
@@ -924,6 +939,8 @@ struct vk_device_struct {
std::map<vk_conv2d_pipeline_state, vk_pipeline> pipeline_conv2d_f16_f32[CONV_SHAPE_COUNT];
std::map<vk_conv2d_pipeline_state, vk_pipeline> pipeline_conv_transpose_2d_f32[CONV_SHAPE_COUNT];
std::map<vk_conv2d_pipeline_state, vk_pipeline> pipeline_conv_transpose_2d_f16_f32[CONV_SHAPE_COUNT];
std::map<vk_conv3d_pipeline_state, vk_pipeline> pipeline_conv3d_f32[CONV_SHAPE_COUNT];
std::map<vk_conv3d_pipeline_state, vk_pipeline> pipeline_conv3d_f16_f32[CONV_SHAPE_COUNT];
vk_pipeline pipeline_conv2d_dw_whcn_f32, pipeline_conv2d_dw_whcn_f16_f32;
vk_pipeline pipeline_conv2d_dw_cwhn_f32, pipeline_conv2d_dw_cwhn_f16_f32;
@@ -1669,6 +1686,41 @@ template <> void init_pushconst_fastdiv(vk_op_conv2d_push_constants &p) {
init_fastdiv_values(p.OW*p.OH, p.OWOHmp, p.OWOHL);
}
struct vk_op_conv3d_push_constants {
uint32_t OC;
uint32_t IC;
uint32_t N;
uint32_t IW;
uint32_t IH;
uint32_t ID;
uint32_t OW;
uint32_t OH;
uint32_t OD;
uint32_t nb01;
uint32_t nb02;
uint32_t nb03;
uint32_t nb11;
uint32_t nb12;
uint32_t nb13;
uint32_t nb1;
uint32_t nb2;
uint32_t nb3;
uint32_t OWmp; uint32_t OWL;
uint32_t OWOHmp; uint32_t OWOHL;
uint32_t OWOHODmp; uint32_t OWOHODL;
};
template <> void init_pushconst_fastdiv(vk_op_conv3d_push_constants &p) {
init_fastdiv_values(p.OW, p.OWmp, p.OWL);
init_fastdiv_values(p.OW*p.OH, p.OWOHmp, p.OWOHL);
init_fastdiv_values(p.OW*p.OH*p.OD, p.OWOHODmp, p.OWOHODL);
}
struct vk_op_conv2d_dw_push_constants {
uint32_t ne;
uint32_t batches;
@@ -4074,19 +4126,35 @@ static void ggml_vk_load_shaders(vk_device& device, vk_pipeline requested) {
}
#endif
auto const &ggml_vk_mul_mm_spec = [](std::vector<uint32_t> spec, bool aligned) {
spec.push_back(aligned ? 1u : 0u);
return spec;
};
const int mul_mat_id_param_count = 5;
#if defined(VK_NV_cooperative_matrix2) && defined(GGML_VULKAN_COOPMAT2_GLSLC_SUPPORT)
if (device->coopmat2) {
auto const &ggml_vk_mul_mm_cm2_spec = [](std::vector<uint32_t> spec, bool aligned, bool mul_mat_id) {
if (mul_mat_id && spec.size() > 5) {
spec.insert(spec.begin() + 5, aligned ? 1u : 0u);
} else {
spec.push_back(aligned ? 1u : 0u);
}
if (mul_mat_id && spec.size() == 6) {
spec.push_back(32);
}
return spec;
};
// Create 6 variants, {s,m,l}x{unaligned,aligned}
#define CREATE_MM(PIPELINE_NAME, NAMELC, F16ACC, WG_DENOMS, WARPTILE, PUSHCONST, PARAMCOUNT) \
ggml_vk_create_pipeline(device, device-> PIPELINE_NAME ->l, #NAMELC #F16ACC "_l", NAMELC ## F16ACC ## _cm2_len, NAMELC ## F16ACC ## _cm2_data, "main", PARAMCOUNT, sizeof(PUSHCONST), l_ ## WG_DENOMS, l_ ## WARPTILE, 1, true); \
ggml_vk_create_pipeline(device, device-> PIPELINE_NAME ->m, #NAMELC #F16ACC "_m", NAMELC ## F16ACC ## _cm2_len, NAMELC ## F16ACC ## _cm2_data, "main", PARAMCOUNT, sizeof(PUSHCONST), m_ ## WG_DENOMS, m_ ## WARPTILE, 1, true); \
ggml_vk_create_pipeline(device, device-> PIPELINE_NAME ->s, #NAMELC #F16ACC "_s", NAMELC ## F16ACC ## _cm2_len, NAMELC ## F16ACC ## _cm2_data, "main", PARAMCOUNT, sizeof(PUSHCONST), s_ ## WG_DENOMS, s_ ## WARPTILE, 1, true); \
ggml_vk_create_pipeline(device, device-> PIPELINE_NAME ->a_l, #NAMELC #F16ACC "_aligned_l", NAMELC ## _aligned ## F16ACC ## _cm2_len, NAMELC ## _aligned ## F16ACC ## _cm2_data, "main", PARAMCOUNT, sizeof(PUSHCONST), l_ ## WG_DENOMS, l_ ## WARPTILE, l_align, true); \
ggml_vk_create_pipeline(device, device-> PIPELINE_NAME ->a_m, #NAMELC #F16ACC "_aligned_m", NAMELC ## _aligned ## F16ACC ## _cm2_len, NAMELC ## _aligned ## F16ACC ## _cm2_data, "main", PARAMCOUNT, sizeof(PUSHCONST), m_ ## WG_DENOMS, m_ ## WARPTILE, m_align, true); \
ggml_vk_create_pipeline(device, device-> PIPELINE_NAME ->a_s, #NAMELC #F16ACC "_aligned_s", NAMELC ## _aligned ## F16ACC ## _cm2_len, NAMELC ## _aligned ## F16ACC ## _cm2_data, "main", PARAMCOUNT, sizeof(PUSHCONST), s_ ## WG_DENOMS, s_ ## WARPTILE, s_align, true); \
ggml_vk_create_pipeline(device, device-> PIPELINE_NAME ->l, #NAMELC #F16ACC "_l", NAMELC ## F16ACC ## _cm2_len, NAMELC ## F16ACC ## _cm2_data, "main", PARAMCOUNT, sizeof(PUSHCONST), l_ ## WG_DENOMS, ggml_vk_mul_mm_cm2_spec(l_ ## WARPTILE, false, PARAMCOUNT == mul_mat_id_param_count), 1, true); \
ggml_vk_create_pipeline(device, device-> PIPELINE_NAME ->m, #NAMELC #F16ACC "_m", NAMELC ## F16ACC ## _cm2_len, NAMELC ## F16ACC ## _cm2_data, "main", PARAMCOUNT, sizeof(PUSHCONST), m_ ## WG_DENOMS, ggml_vk_mul_mm_cm2_spec(m_ ## WARPTILE, false, PARAMCOUNT == mul_mat_id_param_count), 1, true); \
ggml_vk_create_pipeline(device, device-> PIPELINE_NAME ->s, #NAMELC #F16ACC "_s", NAMELC ## F16ACC ## _cm2_len, NAMELC ## F16ACC ## _cm2_data, "main", PARAMCOUNT, sizeof(PUSHCONST), s_ ## WG_DENOMS, ggml_vk_mul_mm_cm2_spec(s_ ## WARPTILE, false, PARAMCOUNT == mul_mat_id_param_count), 1, true); \
ggml_vk_create_pipeline(device, device-> PIPELINE_NAME ->a_l, #NAMELC #F16ACC "_aligned_l", NAMELC ## F16ACC ## _cm2_len, NAMELC ## F16ACC ## _cm2_data, "main", PARAMCOUNT, sizeof(PUSHCONST), l_ ## WG_DENOMS, ggml_vk_mul_mm_cm2_spec(l_ ## WARPTILE, true, PARAMCOUNT == mul_mat_id_param_count), l_align, true); \
ggml_vk_create_pipeline(device, device-> PIPELINE_NAME ->a_m, #NAMELC #F16ACC "_aligned_m", NAMELC ## F16ACC ## _cm2_len, NAMELC ## F16ACC ## _cm2_data, "main", PARAMCOUNT, sizeof(PUSHCONST), m_ ## WG_DENOMS, ggml_vk_mul_mm_cm2_spec(m_ ## WARPTILE, true, PARAMCOUNT == mul_mat_id_param_count), m_align, true); \
ggml_vk_create_pipeline(device, device-> PIPELINE_NAME ->a_s, #NAMELC #F16ACC "_aligned_s", NAMELC ## F16ACC ## _cm2_len, NAMELC ## F16ACC ## _cm2_data, "main", PARAMCOUNT, sizeof(PUSHCONST), s_ ## WG_DENOMS, ggml_vk_mul_mm_cm2_spec(s_ ## WARPTILE, true, PARAMCOUNT == mul_mat_id_param_count), s_align, true); \
// Create 2 variants, {f16,f32} accumulator
#define CREATE_MM2(PIPELINE_NAME, NAMELC, WG_DENOMS, WARPTILE, PUSHCONST, PARAMCOUNT) \
@@ -4161,17 +4229,17 @@ static void ggml_vk_load_shaders(vk_device& device, vk_pipeline requested) {
// Create 6 variants, {s,m,l}x{unaligned,aligned}
#define CREATE_MM(TYPE, PIPELINE_NAME, NAMELC, F16ACC, WG_DENOMS, WARPTILE, PUSHCONST, PARAMCOUNT, ID) \
if (device->mul_mat ## ID ## _l[TYPE]) \
ggml_vk_create_pipeline(device, device-> PIPELINE_NAME ->l, #NAMELC #F16ACC "_l", NAMELC ## F16ACC ## _cm1_len, NAMELC ## F16ACC ## _cm1_data, "main", PARAMCOUNT, sizeof(PUSHCONST), l_ ## WG_DENOMS, l_ ## WARPTILE, 1, false, true); \
ggml_vk_create_pipeline(device, device-> PIPELINE_NAME ->l, #NAMELC #F16ACC "_l", NAMELC ## F16ACC ## _cm1_len, NAMELC ## F16ACC ## _cm1_data, "main", PARAMCOUNT, sizeof(PUSHCONST), l_ ## WG_DENOMS, ggml_vk_mul_mm_spec(l_ ## WARPTILE, false), 1, false, true); \
if (device->mul_mat ## ID ## _m[TYPE]) \
ggml_vk_create_pipeline(device, device-> PIPELINE_NAME ->m, #NAMELC #F16ACC "_m", NAMELC ## F16ACC ## _cm1_len, NAMELC ## F16ACC ## _cm1_data, "main", PARAMCOUNT, sizeof(PUSHCONST), m_ ## WG_DENOMS, m_ ## WARPTILE, 1, false, true); \
ggml_vk_create_pipeline(device, device-> PIPELINE_NAME ->m, #NAMELC #F16ACC "_m", NAMELC ## F16ACC ## _cm1_len, NAMELC ## F16ACC ## _cm1_data, "main", PARAMCOUNT, sizeof(PUSHCONST), m_ ## WG_DENOMS, ggml_vk_mul_mm_spec(m_ ## WARPTILE, false), 1, false, true); \
if (device->mul_mat ## ID ## _s[TYPE]) \
ggml_vk_create_pipeline(device, device-> PIPELINE_NAME ->s, #NAMELC #F16ACC "_s", NAMELC ## F16ACC ## _cm1_len, NAMELC ## F16ACC ## _cm1_data, "main", PARAMCOUNT, sizeof(PUSHCONST), s_ ## WG_DENOMS, s_ ## WARPTILE, 1, false, true); \
ggml_vk_create_pipeline(device, device-> PIPELINE_NAME ->s, #NAMELC #F16ACC "_s", NAMELC ## F16ACC ## _cm1_len, NAMELC ## F16ACC ## _cm1_data, "main", PARAMCOUNT, sizeof(PUSHCONST), s_ ## WG_DENOMS, ggml_vk_mul_mm_spec(s_ ## WARPTILE, false), 1, false, true); \
if (device->mul_mat ## ID ## _l[TYPE]) \
ggml_vk_create_pipeline(device, device-> PIPELINE_NAME ->a_l, #NAMELC #F16ACC "_aligned_l", NAMELC ## _aligned ## F16ACC ## _cm1_len, NAMELC ## _aligned ## F16ACC ## _cm1_data, "main", PARAMCOUNT, sizeof(PUSHCONST), l_ ## WG_DENOMS, l_ ## WARPTILE, l_align, false, true); \
ggml_vk_create_pipeline(device, device-> PIPELINE_NAME ->a_l, #NAMELC #F16ACC "_aligned_l", NAMELC ## F16ACC ## _cm1_len, NAMELC ## F16ACC ## _cm1_data, "main", PARAMCOUNT, sizeof(PUSHCONST), l_ ## WG_DENOMS, ggml_vk_mul_mm_spec(l_ ## WARPTILE, true), l_align, false, true); \
if (device->mul_mat ## ID ## _m[TYPE]) \
ggml_vk_create_pipeline(device, device-> PIPELINE_NAME ->a_m, #NAMELC #F16ACC "_aligned_m", NAMELC ## _aligned ## F16ACC ## _cm1_len, NAMELC ## _aligned ## F16ACC ## _cm1_data, "main", PARAMCOUNT, sizeof(PUSHCONST), m_ ## WG_DENOMS, m_ ## WARPTILE, m_align, false, true); \
ggml_vk_create_pipeline(device, device-> PIPELINE_NAME ->a_m, #NAMELC #F16ACC "_aligned_m", NAMELC ## F16ACC ## _cm1_len, NAMELC ## F16ACC ## _cm1_data, "main", PARAMCOUNT, sizeof(PUSHCONST), m_ ## WG_DENOMS, ggml_vk_mul_mm_spec(m_ ## WARPTILE, true), m_align, false, true); \
if (device->mul_mat ## ID ## _s[TYPE]) \
ggml_vk_create_pipeline(device, device-> PIPELINE_NAME ->a_s, #NAMELC #F16ACC "_aligned_s", NAMELC ## _aligned ## F16ACC ## _cm1_len, NAMELC ## _aligned ## F16ACC ## _cm1_data, "main", PARAMCOUNT, sizeof(PUSHCONST), s_ ## WG_DENOMS, s_ ## WARPTILE, s_align, false, true); \
ggml_vk_create_pipeline(device, device-> PIPELINE_NAME ->a_s, #NAMELC #F16ACC "_aligned_s", NAMELC ## F16ACC ## _cm1_len, NAMELC ## F16ACC ## _cm1_data, "main", PARAMCOUNT, sizeof(PUSHCONST), s_ ## WG_DENOMS, ggml_vk_mul_mm_spec(s_ ## WARPTILE, true), s_align, false, true); \
// Create 2 variants, {f16,f32} accumulator
#define CREATE_MM2(TYPE, PIPELINE_NAME, NAMELC, WG_DENOMS, WARPTILE, PUSHCONST, PARAMCOUNT, ID) \
@@ -4284,32 +4352,32 @@ static void ggml_vk_load_shaders(vk_device& device, vk_pipeline requested) {
// Selects dot2 SPIR-V variant at runtime when device->dot2_f16 is true
#define CREATE_MM(TYPE, PIPELINE_NAME, NAMELC, F16ACC, WG_DENOMS, WARPTILE, PUSHCONST, PARAMCOUNT, ID, REQSUBGROUPSIZE) \
if (device->mul_mat ## ID ## _l[TYPE]) \
ggml_vk_create_pipeline(device, device-> PIPELINE_NAME ->l, #NAMELC #F16ACC "_l", (device->dot2_f16 ? NAMELC ## _dot2 ## F16ACC ## _len : NAMELC ## F16ACC ## _len), (device->dot2_f16 ? NAMELC ## _dot2 ## F16ACC ## _data : NAMELC ## F16ACC ## _data), "main", PARAMCOUNT, sizeof(PUSHCONST), l_ ## WG_DENOMS, l_ ## WARPTILE, 1, false, REQSUBGROUPSIZE > 0, REQSUBGROUPSIZE); \
ggml_vk_create_pipeline(device, device-> PIPELINE_NAME ->l, #NAMELC #F16ACC "_l", (device->dot2_f16 ? NAMELC ## _dot2 ## F16ACC ## _len : NAMELC ## F16ACC ## _len), (device->dot2_f16 ? NAMELC ## _dot2 ## F16ACC ## _data : NAMELC ## F16ACC ## _data), "main", PARAMCOUNT, sizeof(PUSHCONST), l_ ## WG_DENOMS, ggml_vk_mul_mm_spec(l_ ## WARPTILE, false), 1, false, REQSUBGROUPSIZE > 0, REQSUBGROUPSIZE); \
if (device->mul_mat ## ID ## _m[TYPE]) \
ggml_vk_create_pipeline(device, device-> PIPELINE_NAME ->m, #NAMELC #F16ACC "_m", (device->dot2_f16 ? NAMELC ## _dot2 ## F16ACC ## _len : NAMELC ## F16ACC ## _len), (device->dot2_f16 ? NAMELC ## _dot2 ## F16ACC ## _data : NAMELC ## F16ACC ## _data), "main", PARAMCOUNT, sizeof(PUSHCONST), m_ ## WG_DENOMS, m_ ## WARPTILE, 1, false, REQSUBGROUPSIZE > 0, REQSUBGROUPSIZE); \
ggml_vk_create_pipeline(device, device-> PIPELINE_NAME ->m, #NAMELC #F16ACC "_m", (device->dot2_f16 ? NAMELC ## _dot2 ## F16ACC ## _len : NAMELC ## F16ACC ## _len), (device->dot2_f16 ? NAMELC ## _dot2 ## F16ACC ## _data : NAMELC ## F16ACC ## _data), "main", PARAMCOUNT, sizeof(PUSHCONST), m_ ## WG_DENOMS, ggml_vk_mul_mm_spec(m_ ## WARPTILE, false), 1, false, REQSUBGROUPSIZE > 0, REQSUBGROUPSIZE); \
if (device->mul_mat ## ID ## _s[TYPE]) \
ggml_vk_create_pipeline(device, device-> PIPELINE_NAME ->s, #NAMELC #F16ACC "_s", (device->dot2_f16 ? NAMELC ## _dot2 ## F16ACC ## _len : NAMELC ## F16ACC ## _len), (device->dot2_f16 ? NAMELC ## _dot2 ## F16ACC ## _data : NAMELC ## F16ACC ## _data), "main", PARAMCOUNT, sizeof(PUSHCONST), s_ ## WG_DENOMS, s_ ## WARPTILE, 1, false, REQSUBGROUPSIZE > 0, REQSUBGROUPSIZE); \
ggml_vk_create_pipeline(device, device-> PIPELINE_NAME ->s, #NAMELC #F16ACC "_s", (device->dot2_f16 ? NAMELC ## _dot2 ## F16ACC ## _len : NAMELC ## F16ACC ## _len), (device->dot2_f16 ? NAMELC ## _dot2 ## F16ACC ## _data : NAMELC ## F16ACC ## _data), "main", PARAMCOUNT, sizeof(PUSHCONST), s_ ## WG_DENOMS, ggml_vk_mul_mm_spec(s_ ## WARPTILE, false), 1, false, REQSUBGROUPSIZE > 0, REQSUBGROUPSIZE); \
if (device->mul_mat ## ID ## _l[TYPE]) \
ggml_vk_create_pipeline(device, device-> PIPELINE_NAME ->a_l, #NAMELC #F16ACC "_aligned_l", (device->dot2_f16 ? NAMELC ## _dot2_aligned ## F16ACC ## _len : NAMELC ## _aligned ## F16ACC ## _len), (device->dot2_f16 ? NAMELC ## _dot2_aligned ## F16ACC ## _data : NAMELC ## _aligned ## F16ACC ## _data), "main", PARAMCOUNT, sizeof(PUSHCONST), l_ ## WG_DENOMS, l_ ## WARPTILE, l_align, false, REQSUBGROUPSIZE > 0, REQSUBGROUPSIZE); \
ggml_vk_create_pipeline(device, device-> PIPELINE_NAME ->a_l, #NAMELC #F16ACC "_aligned_l", (device->dot2_f16 ? NAMELC ## _dot2 ## F16ACC ## _len : NAMELC ## F16ACC ## _len), (device->dot2_f16 ? NAMELC ## _dot2 ## F16ACC ## _data : NAMELC ## F16ACC ## _data), "main", PARAMCOUNT, sizeof(PUSHCONST), l_ ## WG_DENOMS, ggml_vk_mul_mm_spec(l_ ## WARPTILE, true), l_align, false, REQSUBGROUPSIZE > 0, REQSUBGROUPSIZE); \
if (device->mul_mat ## ID ## _m[TYPE]) \
ggml_vk_create_pipeline(device, device-> PIPELINE_NAME ->a_m, #NAMELC #F16ACC "_aligned_m", (device->dot2_f16 ? NAMELC ## _dot2_aligned ## F16ACC ## _len : NAMELC ## _aligned ## F16ACC ## _len), (device->dot2_f16 ? NAMELC ## _dot2_aligned ## F16ACC ## _data : NAMELC ## _aligned ## F16ACC ## _data), "main", PARAMCOUNT, sizeof(PUSHCONST), m_ ## WG_DENOMS, m_ ## WARPTILE, m_align, false, REQSUBGROUPSIZE > 0, REQSUBGROUPSIZE); \
ggml_vk_create_pipeline(device, device-> PIPELINE_NAME ->a_m, #NAMELC #F16ACC "_aligned_m", (device->dot2_f16 ? NAMELC ## _dot2 ## F16ACC ## _len : NAMELC ## F16ACC ## _len), (device->dot2_f16 ? NAMELC ## _dot2 ## F16ACC ## _data : NAMELC ## F16ACC ## _data), "main", PARAMCOUNT, sizeof(PUSHCONST), m_ ## WG_DENOMS, ggml_vk_mul_mm_spec(m_ ## WARPTILE, true), m_align, false, REQSUBGROUPSIZE > 0, REQSUBGROUPSIZE); \
if (device->mul_mat ## ID ## _s[TYPE]) \
ggml_vk_create_pipeline(device, device-> PIPELINE_NAME ->a_s, #NAMELC #F16ACC "_aligned_s", (device->dot2_f16 ? NAMELC ## _dot2_aligned ## F16ACC ## _len : NAMELC ## _aligned ## F16ACC ## _len), (device->dot2_f16 ? NAMELC ## _dot2_aligned ## F16ACC ## _data : NAMELC ## _aligned ## F16ACC ## _data), "main", PARAMCOUNT, sizeof(PUSHCONST), s_ ## WG_DENOMS, s_ ## WARPTILE, s_align, false, REQSUBGROUPSIZE > 0, REQSUBGROUPSIZE); \
ggml_vk_create_pipeline(device, device-> PIPELINE_NAME ->a_s, #NAMELC #F16ACC "_aligned_s", (device->dot2_f16 ? NAMELC ## _dot2 ## F16ACC ## _len : NAMELC ## F16ACC ## _len), (device->dot2_f16 ? NAMELC ## _dot2 ## F16ACC ## _data : NAMELC ## F16ACC ## _data), "main", PARAMCOUNT, sizeof(PUSHCONST), s_ ## WG_DENOMS, ggml_vk_mul_mm_spec(s_ ## WARPTILE, true), s_align, false, REQSUBGROUPSIZE > 0, REQSUBGROUPSIZE); \
// bf16 scalar path promotes to f32, no dot2 variant
#define CREATE_MM_NODOT2(TYPE, PIPELINE_NAME, NAMELC, F16ACC, WG_DENOMS, WARPTILE, PUSHCONST, PARAMCOUNT, ID, REQSUBGROUPSIZE) \
if (device->mul_mat ## ID ## _l[TYPE]) \
ggml_vk_create_pipeline(device, device-> PIPELINE_NAME ->l, #NAMELC #F16ACC "_l", NAMELC ## F16ACC ## _len, NAMELC ## F16ACC ## _data, "main", PARAMCOUNT, sizeof(PUSHCONST), l_ ## WG_DENOMS, l_ ## WARPTILE, 1, false, REQSUBGROUPSIZE > 0, REQSUBGROUPSIZE); \
ggml_vk_create_pipeline(device, device-> PIPELINE_NAME ->l, #NAMELC #F16ACC "_l", NAMELC ## F16ACC ## _len, NAMELC ## F16ACC ## _data, "main", PARAMCOUNT, sizeof(PUSHCONST), l_ ## WG_DENOMS, ggml_vk_mul_mm_spec(l_ ## WARPTILE, false), 1, false, REQSUBGROUPSIZE > 0, REQSUBGROUPSIZE); \
if (device->mul_mat ## ID ## _m[TYPE]) \
ggml_vk_create_pipeline(device, device-> PIPELINE_NAME ->m, #NAMELC #F16ACC "_m", NAMELC ## F16ACC ## _len, NAMELC ## F16ACC ## _data, "main", PARAMCOUNT, sizeof(PUSHCONST), m_ ## WG_DENOMS, m_ ## WARPTILE, 1, false, REQSUBGROUPSIZE > 0, REQSUBGROUPSIZE); \
ggml_vk_create_pipeline(device, device-> PIPELINE_NAME ->m, #NAMELC #F16ACC "_m", NAMELC ## F16ACC ## _len, NAMELC ## F16ACC ## _data, "main", PARAMCOUNT, sizeof(PUSHCONST), m_ ## WG_DENOMS, ggml_vk_mul_mm_spec(m_ ## WARPTILE, false), 1, false, REQSUBGROUPSIZE > 0, REQSUBGROUPSIZE); \
if (device->mul_mat ## ID ## _s[TYPE]) \
ggml_vk_create_pipeline(device, device-> PIPELINE_NAME ->s, #NAMELC #F16ACC "_s", NAMELC ## F16ACC ## _len, NAMELC ## F16ACC ## _data, "main", PARAMCOUNT, sizeof(PUSHCONST), s_ ## WG_DENOMS, s_ ## WARPTILE, 1, false, REQSUBGROUPSIZE > 0, REQSUBGROUPSIZE); \
ggml_vk_create_pipeline(device, device-> PIPELINE_NAME ->s, #NAMELC #F16ACC "_s", NAMELC ## F16ACC ## _len, NAMELC ## F16ACC ## _data, "main", PARAMCOUNT, sizeof(PUSHCONST), s_ ## WG_DENOMS, ggml_vk_mul_mm_spec(s_ ## WARPTILE, false), 1, false, REQSUBGROUPSIZE > 0, REQSUBGROUPSIZE); \
if (device->mul_mat ## ID ## _l[TYPE]) \
ggml_vk_create_pipeline(device, device-> PIPELINE_NAME ->a_l, #NAMELC #F16ACC "_aligned_l", NAMELC ## _aligned ## F16ACC ## _len, NAMELC ## _aligned ## F16ACC ## _data, "main", PARAMCOUNT, sizeof(PUSHCONST), l_ ## WG_DENOMS, l_ ## WARPTILE, l_align, false, REQSUBGROUPSIZE > 0, REQSUBGROUPSIZE); \
ggml_vk_create_pipeline(device, device-> PIPELINE_NAME ->a_l, #NAMELC #F16ACC "_aligned_l", NAMELC ## F16ACC ## _len, NAMELC ## F16ACC ## _data, "main", PARAMCOUNT, sizeof(PUSHCONST), l_ ## WG_DENOMS, ggml_vk_mul_mm_spec(l_ ## WARPTILE, true), l_align, false, REQSUBGROUPSIZE > 0, REQSUBGROUPSIZE); \
if (device->mul_mat ## ID ## _m[TYPE]) \
ggml_vk_create_pipeline(device, device-> PIPELINE_NAME ->a_m, #NAMELC #F16ACC "_aligned_m", NAMELC ## _aligned ## F16ACC ## _len, NAMELC ## _aligned ## F16ACC ## _data, "main", PARAMCOUNT, sizeof(PUSHCONST), m_ ## WG_DENOMS, m_ ## WARPTILE, m_align, false, REQSUBGROUPSIZE > 0, REQSUBGROUPSIZE); \
ggml_vk_create_pipeline(device, device-> PIPELINE_NAME ->a_m, #NAMELC #F16ACC "_aligned_m", NAMELC ## F16ACC ## _len, NAMELC ## F16ACC ## _data, "main", PARAMCOUNT, sizeof(PUSHCONST), m_ ## WG_DENOMS, ggml_vk_mul_mm_spec(m_ ## WARPTILE, true), m_align, false, REQSUBGROUPSIZE > 0, REQSUBGROUPSIZE); \
if (device->mul_mat ## ID ## _s[TYPE]) \
ggml_vk_create_pipeline(device, device-> PIPELINE_NAME ->a_s, #NAMELC #F16ACC "_aligned_s", NAMELC ## _aligned ## F16ACC ## _len, NAMELC ## _aligned ## F16ACC ## _data, "main", PARAMCOUNT, sizeof(PUSHCONST), s_ ## WG_DENOMS, s_ ## WARPTILE, s_align, false, REQSUBGROUPSIZE > 0, REQSUBGROUPSIZE); \
ggml_vk_create_pipeline(device, device-> PIPELINE_NAME ->a_s, #NAMELC #F16ACC "_aligned_s", NAMELC ## F16ACC ## _len, NAMELC ## F16ACC ## _data, "main", PARAMCOUNT, sizeof(PUSHCONST), s_ ## WG_DENOMS, ggml_vk_mul_mm_spec(s_ ## WARPTILE, true), s_align, false, REQSUBGROUPSIZE > 0, REQSUBGROUPSIZE); \
#define CREATE_MMQ(TYPE, PIPELINE_NAME, NAMELC, WG_DENOMS, WARPTILE, PUSHCONST, PARAMCOUNT, ID, REQSUBGROUPSIZE) \
if (device->mul_mat ## ID ## _l_int[TYPE]) { \
@@ -4474,17 +4542,17 @@ static void ggml_vk_load_shaders(vk_device& device, vk_pipeline requested) {
// Create 6 variants, {s,m,l}x{unaligned,aligned}
#define CREATE_MM(TYPE, PIPELINE_NAME, NAMELC, F16ACC, WG_DENOMS, WARPTILE, PUSHCONST, PARAMCOUNT, ID, REQSUBGROUPSIZE) \
if (device->mul_mat ## ID ## _l[TYPE]) \
ggml_vk_create_pipeline(device, device-> PIPELINE_NAME ->l, #NAMELC #F16ACC "_l", NAMELC ## F16ACC ## _fp32_len, NAMELC ## F16ACC ## _fp32_data, "main", PARAMCOUNT, sizeof(PUSHCONST), l_ ## WG_DENOMS, l_ ## WARPTILE, 1, false, REQSUBGROUPSIZE > 0, REQSUBGROUPSIZE); \
ggml_vk_create_pipeline(device, device-> PIPELINE_NAME ->l, #NAMELC #F16ACC "_l", NAMELC ## F16ACC ## _fp32_len, NAMELC ## F16ACC ## _fp32_data, "main", PARAMCOUNT, sizeof(PUSHCONST), l_ ## WG_DENOMS, ggml_vk_mul_mm_spec(l_ ## WARPTILE, false), 1, false, REQSUBGROUPSIZE > 0, REQSUBGROUPSIZE); \
if (device->mul_mat ## ID ## _m[TYPE]) \
ggml_vk_create_pipeline(device, device-> PIPELINE_NAME ->m, #NAMELC #F16ACC "_m", NAMELC ## F16ACC ## _fp32_len, NAMELC ## F16ACC ## _fp32_data, "main", PARAMCOUNT, sizeof(PUSHCONST), m_ ## WG_DENOMS, m_ ## WARPTILE, 1, false, REQSUBGROUPSIZE > 0, REQSUBGROUPSIZE); \
ggml_vk_create_pipeline(device, device-> PIPELINE_NAME ->m, #NAMELC #F16ACC "_m", NAMELC ## F16ACC ## _fp32_len, NAMELC ## F16ACC ## _fp32_data, "main", PARAMCOUNT, sizeof(PUSHCONST), m_ ## WG_DENOMS, ggml_vk_mul_mm_spec(m_ ## WARPTILE, false), 1, false, REQSUBGROUPSIZE > 0, REQSUBGROUPSIZE); \
if (device->mul_mat ## ID ## _s[TYPE]) \
ggml_vk_create_pipeline(device, device-> PIPELINE_NAME ->s, #NAMELC #F16ACC "_s", NAMELC ## F16ACC ## _fp32_len, NAMELC ## F16ACC ## _fp32_data, "main", PARAMCOUNT, sizeof(PUSHCONST), s_ ## WG_DENOMS, s_ ## WARPTILE, 1, false, REQSUBGROUPSIZE > 0, REQSUBGROUPSIZE); \
ggml_vk_create_pipeline(device, device-> PIPELINE_NAME ->s, #NAMELC #F16ACC "_s", NAMELC ## F16ACC ## _fp32_len, NAMELC ## F16ACC ## _fp32_data, "main", PARAMCOUNT, sizeof(PUSHCONST), s_ ## WG_DENOMS, ggml_vk_mul_mm_spec(s_ ## WARPTILE, false), 1, false, REQSUBGROUPSIZE > 0, REQSUBGROUPSIZE); \
if (device->mul_mat ## ID ## _l[TYPE]) \
ggml_vk_create_pipeline(device, device-> PIPELINE_NAME ->a_l, #NAMELC #F16ACC "_aligned_l", NAMELC ## _aligned ## F16ACC ## _fp32_len, NAMELC ## _aligned ## F16ACC ## _fp32_data, "main", PARAMCOUNT, sizeof(PUSHCONST), l_ ## WG_DENOMS, l_ ## WARPTILE, l_align, false, REQSUBGROUPSIZE > 0, REQSUBGROUPSIZE); \
ggml_vk_create_pipeline(device, device-> PIPELINE_NAME ->a_l, #NAMELC #F16ACC "_aligned_l", NAMELC ## F16ACC ## _fp32_len, NAMELC ## F16ACC ## _fp32_data, "main", PARAMCOUNT, sizeof(PUSHCONST), l_ ## WG_DENOMS, ggml_vk_mul_mm_spec(l_ ## WARPTILE, true), l_align, false, REQSUBGROUPSIZE > 0, REQSUBGROUPSIZE); \
if (device->mul_mat ## ID ## _m[TYPE]) \
ggml_vk_create_pipeline(device, device-> PIPELINE_NAME ->a_m, #NAMELC #F16ACC "_aligned_m", NAMELC ## _aligned ## F16ACC ## _fp32_len, NAMELC ## _aligned ## F16ACC ## _fp32_data, "main", PARAMCOUNT, sizeof(PUSHCONST), m_ ## WG_DENOMS, m_ ## WARPTILE, m_align, false, REQSUBGROUPSIZE > 0, REQSUBGROUPSIZE); \
ggml_vk_create_pipeline(device, device-> PIPELINE_NAME ->a_m, #NAMELC #F16ACC "_aligned_m", NAMELC ## F16ACC ## _fp32_len, NAMELC ## F16ACC ## _fp32_data, "main", PARAMCOUNT, sizeof(PUSHCONST), m_ ## WG_DENOMS, ggml_vk_mul_mm_spec(m_ ## WARPTILE, true), m_align, false, REQSUBGROUPSIZE > 0, REQSUBGROUPSIZE); \
if (device->mul_mat ## ID ## _s[TYPE]) \
ggml_vk_create_pipeline(device, device-> PIPELINE_NAME ->a_s, #NAMELC #F16ACC "_aligned_s", NAMELC ## _aligned ## F16ACC ## _fp32_len, NAMELC ## _aligned ## F16ACC ## _fp32_data, "main", PARAMCOUNT, sizeof(PUSHCONST), s_ ## WG_DENOMS, s_ ## WARPTILE, s_align, false, REQSUBGROUPSIZE > 0, REQSUBGROUPSIZE); \
ggml_vk_create_pipeline(device, device-> PIPELINE_NAME ->a_s, #NAMELC #F16ACC "_aligned_s", NAMELC ## F16ACC ## _fp32_len, NAMELC ## F16ACC ## _fp32_data, "main", PARAMCOUNT, sizeof(PUSHCONST), s_ ## WG_DENOMS, ggml_vk_mul_mm_spec(s_ ## WARPTILE, true), s_align, false, REQSUBGROUPSIZE > 0, REQSUBGROUPSIZE); \
#define CREATE_MMQ(TYPE, PIPELINE_NAME, NAMELC, WG_DENOMS, WARPTILE, PUSHCONST, PARAMCOUNT, ID) \
if (device->mul_mat ## ID ## _l_int[TYPE]) \
@@ -4879,6 +4947,7 @@ static void ggml_vk_load_shaders(vk_device& device, vk_pipeline requested) {
ggml_vk_create_pipeline(device, device->pipeline_get_rows_f32[GGML_TYPE_IQ4_NL], "get_rows_iq4_nl_f32", get_rows_iq4_nl_f32_len, get_rows_iq4_nl_f32_data, "main", 3, sizeof(vk_op_binary_push_constants), {1024, 1, 1}, {}, 1);
ggml_vk_create_pipeline(device, device->pipeline_get_rows_f32[GGML_TYPE_MXFP4], "get_rows_mxfp4_f32", get_rows_mxfp4_f32_len, get_rows_mxfp4_f32_data, "main", 3, sizeof(vk_op_binary_push_constants), {1024, 1, 1}, {}, 1);
ggml_vk_create_pipeline(device, device->pipeline_get_rows_f32[GGML_TYPE_NVFP4], "get_rows_nvfp4_f32", get_rows_nvfp4_f32_len, get_rows_nvfp4_f32_data, "main", 3, sizeof(vk_op_binary_push_constants), {1024, 1, 1}, {}, 1);
ggml_vk_create_pipeline(device, device->pipeline_get_rows_back_f32, "get_rows_back_f32", get_rows_back_f32_len, get_rows_back_f32_data, "main", 3, sizeof(vk_op_binary_push_constants), {256, 1, 1}, {}, 1, true);
ggml_vk_create_pipeline(device, device->pipeline_matmul_split_k_reduce, "split_k_reduce", split_k_reduce_len, split_k_reduce_data, "main", 2, 2 * sizeof(uint32_t), {256 * 4, 1, 1}, {}, 1);
ggml_vk_create_pipeline(device, device->pipeline_flash_attn_split_k_reduce, "fa_split_k_reduce", fa_split_k_reduce_len, fa_split_k_reduce_data, "main", 3, sizeof(vk_op_flash_attn_split_k_reduce_push_constants), {1, device->subgroup_size, 1}, {device->subgroup_size}, 1, true);
@@ -5314,7 +5383,7 @@ static void ggml_vk_load_shaders(vk_device& device, vk_pipeline requested) {
ggml_vk_create_pipeline(device, device->pipeline_opt_step_sgd_f32, "opt_step_sgd_f32", opt_step_sgd_f32_len, opt_step_sgd_f32_data, "main", 3, sizeof(vk_op_push_constants), {512, 1, 1}, {}, 1);
// conv2d, conv_transpose_2d
// conv2d, conv_transpose_2d, conv3d
for (uint32_t s = 0; s < CONV_SHAPE_COUNT; ++s) {
// smaller WG for the small-tile fallback gives more concurrent WGs per SM
uint32_t conv2d_WG_SIZE = (s == CONV_SHAPE_64x32) ? 128 : 256;
@@ -5377,8 +5446,8 @@ static void ggml_vk_load_shaders(vk_device& device, vk_pipeline requested) {
return (conv2d_BS.K * (conv2d_BS.CRS + pad) + conv2d_BS.CRS * (conv2d_BS.NPQ + pad) + csh_elems) * elem_size;
};
// coopmat1 needs to store the output through shared memory, so check up front
// whether it'll fit and disable it before applying coopmat1 parameters.
// 2D, transpose-2D, and 3D conv use the same KxCRS @ CRSxNPQ shmem
// layout. cm1 needs Csh for output, so check before applying cm1 params.
if (conv2d_use_cm1 && device->properties.limits.maxComputeSharedMemorySize < shmem_req(conv2d_cm1_shmem_pad, true, true)) {
conv2d_use_cm1 = false;
}
@@ -5470,6 +5539,53 @@ static void ggml_vk_load_shaders(vk_device& device, vk_pipeline requested) {
}
#undef CREATE_CONV
#undef CREATE_CONVS
std::vector<uint32_t> conv3d_spec_constants = { conv2d_WG_SIZE, conv2d_BS.K, conv2d_BS.CRS, conv2d_BS.NPQ, conv2d_TS_K, conv2d_SHMEM_PAD };
#define CREATE_CONV3D(type_suffix, spv_suffix) \
for (auto &c : device->pipeline_conv3d##type_suffix[s]) { \
const vk_conv3d_pipeline_state &state = c.first; \
std::vector<uint32_t> spec_constants_cpy = conv3d_spec_constants; \
spec_constants_cpy.push_back(state.s0); \
spec_constants_cpy.push_back(state.s1); \
spec_constants_cpy.push_back(state.s2); \
spec_constants_cpy.push_back(state.p0); \
spec_constants_cpy.push_back(state.p1); \
spec_constants_cpy.push_back(state.p2); \
spec_constants_cpy.push_back(state.d0); \
spec_constants_cpy.push_back(state.d1); \
spec_constants_cpy.push_back(state.d2); \
spec_constants_cpy.push_back(state.KW); \
spec_constants_cpy.push_back(state.KH); \
spec_constants_cpy.push_back(state.KD); \
spec_constants_cpy.push_back(state.aligned); \
spec_constants_cpy.push_back(conv2d_csh_store); \
spec_constants_cpy.push_back(conv2d_WM); \
spec_constants_cpy.push_back(conv2d_WN); \
ggml_vk_create_pipeline( \
device, c.second, "conv3d" #type_suffix, \
conv3d##type_suffix##spv_suffix##_len, conv3d##type_suffix##spv_suffix##_data, "main", 3, \
sizeof(vk_op_conv3d_push_constants), wg_denoms, spec_constants_cpy, 1, true, conv2d_required_subgroup_size != 0, conv2d_required_subgroup_size); \
}
#if defined(GGML_VULKAN_COOPMAT2_GLSLC_SUPPORT)
if (device->coopmat2) {
CREATE_CONV3D(_f32, _cm2)
CREATE_CONV3D(_f16_f32, _cm2)
} else
#endif
#if defined(VK_KHR_cooperative_matrix) && defined(GGML_VULKAN_COOPMAT_GLSLC_SUPPORT)
if (conv2d_use_cm1) {
CREATE_CONV3D(_f32, _cm1)
CREATE_CONV3D(_f16_f32, _cm1)
} else
#endif
if (conv2d_UNROLL) {
CREATE_CONV3D(_f32, _unroll)
CREATE_CONV3D(_f16_f32, _unroll)
} else {
CREATE_CONV3D(_f32, )
CREATE_CONV3D(_f16_f32, )
}
#undef CREATE_CONV3D
}
ggml_vk_create_pipeline(device, device->pipeline_conv2d_dw_whcn_f32, "conv2d_dw_whcn_f32", conv2d_dw_whcn_f32_len, conv2d_dw_whcn_f32_data, "main", 3, sizeof(vk_op_conv2d_dw_push_constants), {512, 1, 1}, {}, 1);
@@ -10294,6 +10410,11 @@ static vk_pipeline ggml_vk_op_get_pipeline(ggml_backend_vk_context * ctx, const
return ctx->device->pipeline_get_rows_f32[src0->type];
}
return nullptr;
case GGML_OP_GET_ROWS_BACK:
if (src0->type == GGML_TYPE_F32 && src1->type == GGML_TYPE_I32 && dst->type == GGML_TYPE_F32) {
return ctx->device->pipeline_get_rows_back_f32;
}
return nullptr;
case GGML_OP_ACC:
if (src0->type == GGML_TYPE_F32 && src1->type == GGML_TYPE_F32 && dst->type == GGML_TYPE_F32) {
return ctx->device->pipeline_acc_f32;
@@ -10885,6 +11006,61 @@ static vk_pipeline ggml_vk_op_get_pipeline(ggml_backend_vk_context * ctx, const
}
}
return nullptr;
case GGML_OP_CONV_3D:
if (src1->type == GGML_TYPE_F32 && dst->type == GGML_TYPE_F32) {
const uint32_t OC = (uint32_t)ggml_get_op_params_i32(dst, 11);
const uint32_t IC = (uint32_t)ggml_get_op_params_i32(dst, 9);
const uint32_t N = (uint32_t)ggml_get_op_params_i32(dst, 10);
const uint32_t NPQ = N * dst->ne[2] * dst->ne[1] * dst->ne[0];
const vk_conv_shapes shape = ggml_vk_conv_select_shape(ctx, OC, NPQ);
const uint32_t KW = (uint32_t)src0->ne[0];
const uint32_t KH = (uint32_t)src0->ne[1];
const uint32_t KD = (uint32_t)src0->ne[2];
const uint32_t s0 = (uint32_t)ggml_get_op_params_i32(dst, 0);
const uint32_t s1 = (uint32_t)ggml_get_op_params_i32(dst, 1);
const uint32_t s2 = (uint32_t)ggml_get_op_params_i32(dst, 2);
const uint32_t p0 = (uint32_t)ggml_get_op_params_i32(dst, 3);
const uint32_t p1 = (uint32_t)ggml_get_op_params_i32(dst, 4);
const uint32_t p2 = (uint32_t)ggml_get_op_params_i32(dst, 5);
const uint32_t d0 = (uint32_t)ggml_get_op_params_i32(dst, 6);
const uint32_t d1 = (uint32_t)ggml_get_op_params_i32(dst, 7);
const uint32_t d2 = (uint32_t)ggml_get_op_params_i32(dst, 8);
const uint32_t CRS = IC * KW * KH * KD;
const uint32_t BS_K = vk_conv_block_sizes[shape].K;
const uint32_t BS_CRS = vk_conv_block_sizes[shape].CRS;
const uint32_t BS_NPQ = vk_conv_block_sizes[shape].NPQ;
const uint32_t aligned = ((OC % BS_K == 0) &&
(CRS % BS_CRS == 0) &&
(NPQ % BS_NPQ == 0)) ? 1u : 0u;
vk_conv3d_pipeline_state conv3d_pipeline_state(s0, s1, s2, p0, p1, p2, d0, d1, d2, KW, KH, KD, aligned);
std::map<vk_conv3d_pipeline_state, vk_pipeline> *pipelines = nullptr;
if (src0->type == GGML_TYPE_F32) {
pipelines = &ctx->device->pipeline_conv3d_f32[shape];
} else if (src0->type == GGML_TYPE_F16) {
pipelines = &ctx->device->pipeline_conv3d_f16_f32[shape];
} else {
return nullptr;
}
vk_pipeline pipeline = nullptr;
{
std::lock_guard<std::mutex> guard(ctx->device->compile_mutex);
auto it = pipelines->find(conv3d_pipeline_state);
if (it != pipelines->end()) {
pipeline = it->second;
} else {
(*pipelines)[conv3d_pipeline_state] = pipeline = std::make_shared<vk_pipeline_struct>();
}
}
return pipeline;
}
return nullptr;
case GGML_OP_ADD1:
if (src0->type == GGML_TYPE_F16 && src1->type == GGML_TYPE_F16 && dst->type == GGML_TYPE_F16) {
return ctx->device->pipeline_add1_f16_f16;
@@ -11135,6 +11311,10 @@ static void ggml_vk_op_f32(ggml_backend_vk_context * ctx, vk_context& subctx, co
elements[1] = std::min(elements[1], ctx->device->properties.limits.maxComputeWorkGroupCount[1]);
elements[2] = std::min(elements[2], ctx->device->properties.limits.maxComputeWorkGroupCount[2]);
break;
case GGML_OP_GET_ROWS_BACK:
elements = { (uint32_t)dst->ne[0], (uint32_t)dst->ne[1], 1 };
elements[1] = std::min(elements[1], ctx->device->properties.limits.maxComputeWorkGroupCount[1]);
break;
case GGML_OP_ARGSORT:
GGML_ASSERT(0);
break;
@@ -11220,6 +11400,21 @@ static void ggml_vk_op_f32(ggml_backend_vk_context * ctx, vk_context& subctx, co
GGML_ABORT("invalid push constant type for CONV_2D");
}
break;
case GGML_OP_CONV_3D:
if constexpr (std::is_same_v<PC, vk_op_conv3d_push_constants>) {
const uint32_t NPQ = pc.N * pc.OD * pc.OH * pc.OW;
const vk_conv_shapes shape = ggml_vk_conv_select_shape(ctx, pc.OC, NPQ);
const uint32_t NPQ_blocks = CEIL_DIV(NPQ, vk_conv_block_sizes[shape].NPQ);
elements = { pc.OC, NPQ_blocks, 1 };
if (elements[1] > 512) {
elements[2] = CEIL_DIV(elements[1], 512);
elements[1] = 512;
}
} else {
GGML_ABORT("invalid push constant type for CONV_3D");
}
break;
case GGML_OP_ADD:
case GGML_OP_SUB:
case GGML_OP_DIV:
@@ -11380,6 +11575,21 @@ static void ggml_vk_get_rows(ggml_backend_vk_context * ctx, vk_context& subctx,
});
}
static void ggml_vk_get_rows_back(ggml_backend_vk_context * ctx, vk_context& subctx, const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst) {
const uint32_t src0_type_size = ggml_type_size(src0->type);
const uint32_t src1_type_size = ggml_type_size(src1->type);
const uint32_t dst_type_size = ggml_type_size(dst->type);
ggml_vk_op_f32<vk_op_binary_push_constants>(ctx, subctx, src0, src1, nullptr, nullptr, dst, GGML_OP_GET_ROWS_BACK, {
(uint32_t)ggml_nelements(src0),
(uint32_t)src0->ne[0], (uint32_t)src0->ne[1], (uint32_t)src0->ne[2], (uint32_t)src0->ne[3], (uint32_t)src0->nb[0] / src0_type_size, (uint32_t)src0->nb[1] / src0_type_size, (uint32_t)src0->nb[2] / src0_type_size, (uint32_t)src0->nb[3] / src0_type_size,
(uint32_t)src1->ne[0], (uint32_t)src1->ne[1], (uint32_t)src1->ne[2], (uint32_t)src1->ne[3], (uint32_t)src1->nb[0] / src1_type_size, (uint32_t)src1->nb[1] / src1_type_size, (uint32_t)src1->nb[2] / src1_type_size, (uint32_t)src1->nb[3] / src1_type_size,
(uint32_t) dst->ne[0], (uint32_t) dst->ne[1], (uint32_t) dst->ne[2], (uint32_t) dst->ne[3], (uint32_t) dst->nb[0] / dst_type_size, (uint32_t) dst->nb[1] / dst_type_size, (uint32_t) dst->nb[2] / dst_type_size, (uint32_t) dst->nb[3] / dst_type_size,
0,
0.0f, 0.0f, 0,
});
}
static void ggml_vk_acc(ggml_backend_vk_context * ctx, vk_context& subctx, const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst) {
const uint32_t src0_type_size = ggml_type_size(src0->type);
const uint32_t src1_type_size = ggml_type_size(src1->type);
@@ -13118,6 +13328,51 @@ static void ggml_vk_conv_2d(ggml_backend_vk_context * ctx, vk_context & subctx,
ggml_vk_op_f32(ctx, subctx, src0, src1, nullptr, nullptr, dst, dst->op, std::move(p));
}
static void ggml_vk_conv_3d(ggml_backend_vk_context * ctx, vk_context & subctx, const ggml_tensor * src0,
const ggml_tensor * src1, ggml_tensor * dst) {
GGML_ASSERT(src0->type == GGML_TYPE_F32 || src0->type == GGML_TYPE_F16);
GGML_ASSERT(src1->type == GGML_TYPE_F32);
GGML_ASSERT(dst->type == GGML_TYPE_F32);
GGML_TENSOR_BINARY_OP_LOCALS
GGML_ASSERT(nb00 == sizeof(float) || nb00 == sizeof(ggml_fp16_t));
GGML_ASSERT(nb10 == sizeof(float));
GGML_ASSERT(nb0 == sizeof(float));
vk_op_conv3d_push_constants p{};
p.IC = static_cast<uint32_t>(ggml_get_op_params_i32(dst, 9));
p.N = static_cast<uint32_t>(ggml_get_op_params_i32(dst, 10));
p.OC = static_cast<uint32_t>(ggml_get_op_params_i32(dst, 11));
GGML_ASSERT(src0->ne[3] == (int64_t)p.IC * p.OC);
GGML_ASSERT(src1->ne[3] == (int64_t)p.IC * p.N);
GGML_ASSERT(dst->ne[3] == (int64_t)p.OC * p.N);
p.IW = static_cast<uint32_t>(ne10);
p.IH = static_cast<uint32_t>(ne11);
p.ID = static_cast<uint32_t>(ne12);
p.OW = static_cast<uint32_t>(ne0);
p.OH = static_cast<uint32_t>(ne1);
p.OD = static_cast<uint32_t>(ne2);
// the shader clamps src addresses to p.IC * p.N * p.IW * p.IH * p.ID - 1 in uint32, so the
// total input element count must fit in a uint32.
GGML_ASSERT((uint64_t)p.IC * p.N * p.IW * p.IH * p.ID <= 0xFFFFFFFFull);
p.nb01 = static_cast<uint32_t>(nb01 / nb00);
p.nb02 = static_cast<uint32_t>(nb02 / nb00);
p.nb03 = static_cast<uint32_t>(nb03 / nb00);
p.nb11 = static_cast<uint32_t>(nb11 / nb10);
p.nb12 = static_cast<uint32_t>(nb12 / nb10);
p.nb13 = static_cast<uint32_t>(nb13 / nb10);
p.nb1 = static_cast<uint32_t>(nb1 / nb0);
p.nb2 = static_cast<uint32_t>(nb2 / nb0);
p.nb3 = static_cast<uint32_t>(nb3 / nb0);
ggml_vk_op_f32(ctx, subctx, src0, src1, nullptr, nullptr, dst, GGML_OP_CONV_3D, std::move(p));
}
static void ggml_vk_conv_2d_dw(ggml_backend_vk_context * ctx, vk_context& subctx, const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst) {
vk_op_conv2d_dw_push_constants p{};
p.ne = ggml_nelements(dst);
@@ -14247,6 +14502,10 @@ static bool ggml_vk_build_graph(ggml_backend_vk_context * ctx, ggml_cgraph * cgr
case GGML_OP_GET_ROWS:
ggml_vk_get_rows(ctx, compute_ctx, src0, src1, node);
break;
case GGML_OP_GET_ROWS_BACK:
ggml_vk_get_rows_back(ctx, compute_ctx, src0, src1, node);
break;
case GGML_OP_ADD:
if (ctx->num_additional_fused_ops) {
@@ -14515,6 +14774,10 @@ static bool ggml_vk_build_graph(ggml_backend_vk_context * ctx, ggml_cgraph * cgr
case GGML_OP_CONV_TRANSPOSE_2D:
ggml_vk_conv_2d(ctx, compute_ctx, src0, src1, node);
break;
case GGML_OP_CONV_3D:
ggml_vk_conv_3d(ctx, compute_ctx, src0, src1, node);
break;
case GGML_OP_CONV_2D_DW:
ggml_vk_conv_2d_dw(ctx, compute_ctx, src0, src1, node);
@@ -16964,6 +17227,8 @@ static bool ggml_backend_vk_device_supports_op(ggml_backend_dev_t dev, const ggm
return false;
}
}
case GGML_OP_GET_ROWS_BACK:
return op->type == GGML_TYPE_F32 && op->src[0]->type == GGML_TYPE_F32;
case GGML_OP_SET_ROWS:
{
switch (op->type) {
@@ -17285,6 +17550,13 @@ static bool ggml_backend_vk_device_supports_op(ggml_backend_dev_t dev, const ggm
ggml_is_contiguous(op->src[1]) &&
ggml_is_contiguous(op));
}
case GGML_OP_CONV_3D:
return (op->src[0]->type == GGML_TYPE_F32 || op->src[0]->type == GGML_TYPE_F16) &&
op->src[1]->type == GGML_TYPE_F32 &&
op->type == GGML_TYPE_F32 &&
ggml_is_contiguous(op->src[0]) &&
ggml_is_contiguous(op->src[1]) &&
ggml_is_contiguous(op);
default:
return false;
}
@@ -18128,6 +18400,20 @@ static void ggml_vk_check_results_0(ggml_backend_vk_context * ctx, ggml_cgraph *
const int32_t d0 = tensor->op_params[4];
const int32_t d1 = tensor->op_params[5];
tensor_clone = ggml_conv_2d(ggml_ctx, src_clone[0], src_clone[1], s0, s1, p0, p1, d0, d1);
} else if (tensor->op == GGML_OP_CONV_3D) {
const int32_t s0 = tensor->op_params[0];
const int32_t s1 = tensor->op_params[1];
const int32_t s2 = tensor->op_params[2];
const int32_t p0 = tensor->op_params[3];
const int32_t p1 = tensor->op_params[4];
const int32_t p2 = tensor->op_params[5];
const int32_t d0 = tensor->op_params[6];
const int32_t d1 = tensor->op_params[7];
const int32_t d2 = tensor->op_params[8];
const int32_t IC = tensor->op_params[9];
const int32_t N = tensor->op_params[10];
const int32_t OC = tensor->op_params[11];
tensor_clone = ggml_conv_3d_direct(ggml_ctx, src_clone[0], src_clone[1], s0, s1, s2, p0, p1, p2, d0, d1, d2, IC, N, OC);
} else if (tensor->op == GGML_OP_CONV_2D_DW) {
const int32_t s0 = tensor->op_params[0];
const int32_t s1 = tensor->op_params[1];
@@ -0,0 +1,431 @@
#version 450
#extension GL_EXT_control_flow_attributes : enable
#ifdef COOPMAT2
#extension GL_NV_cooperative_matrix2 : enable
#extension GL_EXT_shader_explicit_arithmetic_types_float16 : require
#extension GL_KHR_memory_scope_semantics : enable
#endif
#ifdef COOPMAT
#extension GL_KHR_cooperative_matrix : enable
#extension GL_KHR_shader_subgroup_basic : enable
#extension GL_EXT_shader_explicit_arithmetic_types_float16 : require
#extension GL_KHR_memory_scope_semantics : enable
#endif
#include "types.glsl"
// shape notation: [dim(N), ..., dim(0)] -- stride(dim(j)) >= stride(dim(i)) if i > j
layout(binding = 0) readonly buffer A {
A_TYPE knl_data[];
}; // src0 - kernel: [KW, KH, KD, IC*OC]
layout(binding = 1) readonly buffer B {
B_TYPE src_data[];
}; // src1 - input: [IW, IH, ID, IC*N] -- channel_first format
layout(binding = 2) writeonly buffer D {
D_TYPE dst_data[];
}; // dst - result: [OW, OH, OD, OC*N]
layout(push_constant) uniform parameter {
// I/O channels, batch size
uint32_t OC;
uint32_t IC;
uint32_t N;
// Tensor spatial sizes: input, output
uint32_t IW;
uint32_t IH;
uint32_t ID;
uint32_t OW;
uint32_t OH;
uint32_t OD;
// Strides in elements
uint32_t nb01;
uint32_t nb02;
uint32_t nb03;
uint32_t nb11;
uint32_t nb12;
uint32_t nb13;
uint32_t nb1;
uint32_t nb2;
uint32_t nb3;
// fastdiv helper values
uint32_t OWmp; uint32_t OWL;
uint32_t OWOHmp; uint32_t OWOHL;
uint32_t OWOHODmp; uint32_t OWOHODL;
}
p;
layout(local_size_x_id = 0, local_size_y = 1, local_size_z = 1) in;
// Blocktile sizes
layout(constant_id = 1) const uint BS_K = 128;
layout(constant_id = 2) const uint BS_CRS = 16;
layout(constant_id = 3) const uint BS_NPQ = 128;
// Thread-tile sizes
layout(constant_id = 4) const uint TS_K = 8;
layout(constant_id = 5) const uint SHMEM_PAD = 4;
// Stride, padding, dilation
layout(constant_id = 6) const uint s0 = 1;
layout(constant_id = 7) const uint s1 = 1;
layout(constant_id = 8) const uint s2 = 1;
layout(constant_id = 9) const uint p0 = 0;
layout(constant_id = 10) const uint p1 = 0;
layout(constant_id = 11) const uint p2 = 0;
layout(constant_id = 12) const uint d0 = 1;
layout(constant_id = 13) const uint d1 = 1;
layout(constant_id = 14) const uint d2 = 1;
// Kernel spatial sizes
layout(constant_id = 15) const uint KW = 1;
layout(constant_id = 16) const uint KH = 1;
layout(constant_id = 17) const uint KD = 1;
// when set, skip bounds checks and address clamps (K/CRS/NPQ are tile-aligned)
layout(constant_id = 18) const uint aligned = 0;
// stage cm2 result through shmem (Csh) for coalesced stores. cm1 always does this.
layout(constant_id = 19) const uint csh_store = 0;
#ifdef COOPMAT
// cm1 subgroup tile: each subgroup computes a WM x WN region as a grid of
// TM x TN x TK fragments. Requires WM%TM == WN%TN == BS_K%WM == BS_NPQ%WN ==
// BS_CRS%TK == 0, and WG_SIZE == (BS_K/WM) * (BS_NPQ/WN) * subgroup_size.
layout(constant_id = 20) const uint WM = 32;
layout(constant_id = 21) const uint WN = 32;
const uint TM = 16;
const uint TN = 16;
const uint TK = 16;
const uint cms_per_row = WM / TM;
const uint cms_per_col = WN / TN;
const uint warps_M = BS_K / WM;
const uint warps_N = BS_NPQ / WN;
#endif
// without padding, ID_idx/IH_idx/IW_idx are in bounds by construction
const bool dhw_in_bounds = (p0 == 0) && (p1 == 0) && (p2 == 0);
uint32_t tid = gl_LocalInvocationID.x;
const uint32_t WG_SIZE = gl_WorkGroupSize.x;
uint splitWork(uint work_size, uint block_size) {
return (block_size + work_size - 1) / block_size;
}
uint32_t K = p.OC;
uint32_t CRS = p.IC * KD * KH * KW;
uint32_t NPQ = p.N * p.OD * p.OH * p.OW;
// Number of blocktiles per input
uint32_t NB_CRS = splitWork(CRS, BS_CRS);
#if defined(COOPMAT2) || defined(COOPMAT)
#define SHMEM_TYPE float16_t
#else
#define SHMEM_TYPE float
#endif
const uint32_t Ash_stride = BS_CRS + SHMEM_PAD;
const uint32_t Bsh_stride = BS_NPQ + SHMEM_PAD;
const uint32_t Ash_len = BS_K * Ash_stride;
const uint32_t Bsh_len = BS_CRS * Bsh_stride;
shared SHMEM_TYPE Ash[Ash_len]; // K x CRS
shared SHMEM_TYPE Bsh[Bsh_len]; // CRS x NPQ
#if defined(COOPMAT2) || defined(COOPMAT)
// stage matC through shmem so global stores are row-major (NPQ-contiguous)
const uint32_t Csh_stride = BS_NPQ;
#ifdef COOPMAT
const uint32_t Csh_len = BS_K * Csh_stride;
#else
const uint32_t Csh_len = csh_store != 0 ? BS_K * Csh_stride : 1;
#endif
shared SHMEM_TYPE Csh[Csh_len]; // K x NPQ
#endif
// Threadtile sizes
const uint32_t TS_NPQ = BS_K * BS_NPQ / WG_SIZE / TS_K;
// Number of threadtiles per blocktile
const uint32_t NT_NPQ = BS_NPQ / TS_NPQ;
/*
Compute
KxCRS @ CRSxNPQ = K x NPQ
K=OC
C=IC
D,R,S=KD,KH,KW
Z,P,Q=OD,OH,OW
*/
uint32_t B_idx_K = gl_WorkGroupID.x;
uint32_t B_idx_NPQ = gl_WorkGroupID.y + gl_WorkGroupID.z * 512;
uint32_t T_y = tid / NT_NPQ;
uint32_t T_x = tid % NT_NPQ;
uint32_t Ar = tid / BS_CRS;
uint32_t Ac = tid % BS_CRS;
const uint32_t ArpWg = WG_SIZE / BS_CRS;
uint32_t Br = tid / BS_NPQ;
uint32_t Bc = tid % BS_NPQ;
const uint32_t BrpWg = WG_SIZE / BS_NPQ;
// see init_fastdiv_values in ggml-vulkan.cpp
uint fastdiv(uint n, uint mp, uint L) {
uint msbs, lsbs;
// msbs = mulhi(n, mp)
umulExtended(n, mp, msbs, lsbs);
return (msbs + n) >> L;
}
void split_crs(uint32_t crs_idx, out uint32_t ic, out uint32_t kd, out uint32_t kh, out uint32_t kw) {
const uint32_t KHKW = KH * KW;
const uint32_t KDKHKW = KD * KHKW;
ic = crs_idx / KDKHKW;
uint32_t rem = crs_idx - ic * KDKHKW;
kd = rem / KHKW;
rem = rem - kd * KHKW;
kh = rem / KW;
kw = rem - kh * KW;
}
void split_npq(uint32_t npq_idx, out uint32_t n, out uint32_t od, out uint32_t oh, out uint32_t ow) {
const uint32_t OWOH = p.OW * p.OH;
n = fastdiv(npq_idx, p.OWOHODmp, p.OWOHODL);
uint32_t rem = npq_idx - n * p.OD * OWOH;
od = fastdiv(rem, p.OWOHmp, p.OWOHL);
rem = rem - od * OWOH;
oh = fastdiv(rem, p.OWmp, p.OWL);
ow = rem - oh * p.OW;
}
#ifdef COOPMAT2
#define ACC_TYPE float16_t
ACC_TYPE perElemOpStore(const in uint32_t r, const in uint32_t c, const in ACC_TYPE elem)
{
uint32_t K_idx = B_idx_K * BS_K + r;
uint32_t NPQ_idx = B_idx_NPQ * BS_NPQ + c;
uint32_t N_idx;
uint32_t OD_idx;
uint32_t OH_idx;
uint32_t OW_idx;
split_npq(NPQ_idx, N_idx, OD_idx, OH_idx, OW_idx);
uint32_t dst_idx = OW_idx + OH_idx * p.nb1 + OD_idx * p.nb2 + (N_idx * p.OC + K_idx) * p.nb3;
if (aligned != 0 || (K_idx < K && NPQ_idx < NPQ)) {
dst_data[dst_idx] = D_TYPE(elem);
}
return elem;
}
#endif
void main() {
if (B_idx_NPQ * BS_NPQ >= NPQ) {
return;
}
#ifdef COOPMAT2
coopmat<ACC_TYPE, gl_ScopeWorkgroup, BS_K, BS_NPQ, gl_MatrixUseAccumulator> matC;
matC = coopmat<ACC_TYPE, gl_ScopeWorkgroup, BS_K, BS_NPQ, gl_MatrixUseAccumulator>(0.0);
#elif defined(COOPMAT)
coopmat<float16_t, gl_ScopeSubgroup, TM, TN, gl_MatrixUseAccumulator> sums[cms_per_row * cms_per_col];
[[unroll]] for (uint i = 0; i < cms_per_row * cms_per_col; i++) {
sums[i] = coopmat<float16_t, gl_ScopeSubgroup, TM, TN, gl_MatrixUseAccumulator>(0.0);
}
const uint warp_r = gl_SubgroupID / warps_N;
const uint warp_c = gl_SubgroupID % warps_N;
#else
float regC[TS_K][TS_NPQ];
for (uint32_t T_ly = 0; T_ly < TS_K; T_ly++) {
for (uint32_t T_lx = 0; T_lx < TS_NPQ; T_lx++) {
regC[T_ly][T_lx] = 0.0;
}
}
#endif
/* Advance block in CRS dim */
[[dont_unroll]] for (uint32_t B_idx_CRS = 0; B_idx_CRS < NB_CRS; B_idx_CRS++) {
uint32_t CRS_idx_a = B_idx_CRS * BS_CRS + Ac;
uint32_t IC_idx_a;
uint32_t KD_idx_a;
uint32_t KH_idx_a;
uint32_t KW_idx_a;
split_crs(CRS_idx_a, IC_idx_a, KD_idx_a, KH_idx_a, KW_idx_a);
/* Load kernel to A_block: (BS_K x BS_CRS)*/
UNROLL for (uint32_t r_offset = 0; r_offset < BS_K; r_offset += ArpWg) {
uint32_t B_ly = r_offset + Ar;
uint32_t B_lx = Ac;
uint32_t K_idx = B_idx_K * BS_K + B_ly; /* Global K_idx (row index of A)*/
uint32_t knl_idx = KW_idx_a + KH_idx_a * p.nb01 + KD_idx_a * p.nb02 + (K_idx * p.IC + IC_idx_a) * p.nb03;
if (aligned == 0) {
knl_idx = min(knl_idx, K * CRS - 1);
}
float val = knl_data[knl_idx];
if (aligned == 0 && (K_idx >= K || CRS_idx_a >= CRS)) {
val = 0.0;
}
Ash[B_ly * Ash_stride + B_lx] = SHMEM_TYPE(val);
}
/* Load input to B_block: (BS_CRS x BS_NPQ) */
UNROLL for (uint32_t r_offset = 0; r_offset < BS_CRS; r_offset += BrpWg) {
uint32_t B_ly = r_offset + Br; /* Row index of B block */
uint32_t B_lx = Bc;
uint32_t NPQ_idx = B_idx_NPQ * BS_NPQ + B_lx; /* Global NPQ index (column index of B) */
uint32_t N_idx;
uint32_t OD_idx;
uint32_t OH_idx;
uint32_t OW_idx;
split_npq(NPQ_idx, N_idx, OD_idx, OH_idx, OW_idx);
uint32_t CRS_idx_b = B_idx_CRS * BS_CRS + B_ly;
uint32_t IC_idx_b;
uint32_t KD_idx_b;
uint32_t KH_idx_b;
uint32_t KW_idx_b;
split_crs(CRS_idx_b, IC_idx_b, KD_idx_b, KH_idx_b, KW_idx_b);
uint32_t ID_idx = OD_idx * s2 + KD_idx_b * d2 - p2;
uint32_t IH_idx = OH_idx * s1 + KH_idx_b * d1 - p1;
uint32_t IW_idx = OW_idx * s0 + KW_idx_b * d0 - p0;
uint32_t src_idx = IW_idx + IH_idx * p.nb11 + ID_idx * p.nb12 + (N_idx * p.IC + IC_idx_b) * p.nb13;
// skip clamp when address can't go OOB
if (aligned == 0 || !dhw_in_bounds) {
src_idx = min(src_idx, p.IC * p.N * p.IW * p.IH * p.ID - 1);
}
float val = src_data[src_idx];
bool oob = false;
if (aligned == 0 && (CRS_idx_b >= CRS || NPQ_idx >= NPQ)) {
oob = true;
}
// also catches lower-bound underflow (idx wraps to 0x80000000+)
if (!dhw_in_bounds && (ID_idx >= p.ID || IH_idx >= p.IH || IW_idx >= p.IW)) {
oob = true;
}
if (oob) {
val = 0.0;
}
Bsh[B_ly * Bsh_stride + B_lx] = SHMEM_TYPE(val);
}
barrier();
#ifdef COOPMAT2
coopmat<float16_t, gl_ScopeWorkgroup, BS_K, BS_CRS, gl_MatrixUseA> matA;
coopmat<float16_t, gl_ScopeWorkgroup, BS_CRS, BS_NPQ, gl_MatrixUseB> matB;
coopMatLoad(matA, Ash, 0, Ash_stride, gl_CooperativeMatrixLayoutRowMajor);
coopMatLoad(matB, Bsh, 0, Bsh_stride, gl_CooperativeMatrixLayoutRowMajor);
matC = coopMatMulAdd(matA, matB, matC);
#elif defined(COOPMAT)
// each subgroup multiplies its grid of fragments per TK-sized CRS chunk
[[unroll]] for (uint k_step = 0; k_step < BS_CRS / TK; k_step++) {
coopmat<float16_t, gl_ScopeSubgroup, TM, TK, gl_MatrixUseA> cache_a[cms_per_row];
[[unroll]] for (uint cm_row = 0; cm_row < cms_per_row; cm_row++) {
const uint a_off = (warp_r * WM + cm_row * TM) * Ash_stride + k_step * TK;
coopMatLoad(cache_a[cm_row], Ash, a_off, Ash_stride, gl_CooperativeMatrixLayoutRowMajor);
}
[[unroll]] for (uint cm_col = 0; cm_col < cms_per_col; cm_col++) {
coopmat<float16_t, gl_ScopeSubgroup, TK, TN, gl_MatrixUseB> cache_b;
const uint b_off = k_step * TK * Bsh_stride + warp_c * WN + cm_col * TN;
coopMatLoad(cache_b, Bsh, b_off, Bsh_stride, gl_CooperativeMatrixLayoutRowMajor);
[[unroll]] for (uint cm_row = 0; cm_row < cms_per_row; cm_row++) {
sums[cm_col * cms_per_row + cm_row] = coopMatMulAdd(cache_a[cm_row], cache_b, sums[cm_col * cms_per_row + cm_row]);
}
}
}
#else
if (T_y * TS_K < K) {
UNROLL for (uint32_t CRS_lidx = 0; CRS_lidx < BS_CRS; CRS_lidx++) {
float regA[TS_K];
float regB[TS_NPQ];
for (uint32_t T_ly = 0; T_ly < TS_K; T_ly++) {
regA[T_ly] = Ash[(T_y * TS_K + T_ly) * Ash_stride + CRS_lidx];
}
for (uint32_t T_lx = 0; T_lx < TS_NPQ; T_lx++) {
regB[T_lx] = Bsh[CRS_lidx * Bsh_stride + T_x * TS_NPQ + T_lx];
}
for (uint32_t T_ly = 0; T_ly < TS_K; T_ly++) {
for (uint32_t T_lx = 0; T_lx < TS_NPQ; T_lx++) {
regC[T_ly][T_lx] = fma(regA[T_ly], regB[T_lx], regC[T_ly][T_lx]);
}
}
}
}
#endif
barrier();
}
/* Save C* */
#if defined(COOPMAT2) || defined(COOPMAT)
// stage matC into Csh, then write to dst with coalesced NPQ-contiguous stores
#ifdef COOPMAT
const bool use_staged_store = true;
#else
const bool use_staged_store = (csh_store != 0);
#endif
if (use_staged_store) {
#ifdef COOPMAT
// cm1: each subgroup stores its fragment grid into its Csh slot
[[unroll]] for (uint cm_row = 0; cm_row < cms_per_row; cm_row++) {
[[unroll]] for (uint cm_col = 0; cm_col < cms_per_col; cm_col++) {
const uint csh_off = (warp_r * WM + cm_row * TM) * Csh_stride + warp_c * WN + cm_col * TN;
coopMatStore(sums[cm_col * cms_per_row + cm_row], Csh, csh_off, Csh_stride, gl_CooperativeMatrixLayoutRowMajor);
}
}
#else
coopMatStore(matC, Csh, 0, Csh_stride, gl_CooperativeMatrixLayoutRowMajor);
#endif
barrier();
// cooperative shmem->global: WG threads spread across BS_NPQ (the
// contiguous direction of dst), each iter covers store_rows_per_iter K-rows
const uint32_t store_rows_per_iter = WG_SIZE / BS_NPQ;
const uint32_t store_iters = BS_K / store_rows_per_iter;
const uint32_t k_thread_offset = tid / BS_NPQ;
const uint32_t npq_thread = tid % BS_NPQ;
[[unroll]] for (uint32_t i = 0; i < store_iters; i++) {
uint32_t k_local = i * store_rows_per_iter + k_thread_offset;
uint32_t K_idx = B_idx_K * BS_K + k_local;
uint32_t NPQ_idx = B_idx_NPQ * BS_NPQ + npq_thread;
uint32_t N_idx;
uint32_t OD_idx;
uint32_t OH_idx;
uint32_t OW_idx;
split_npq(NPQ_idx, N_idx, OD_idx, OH_idx, OW_idx);
uint32_t dst_idx = OW_idx + OH_idx * p.nb1 + OD_idx * p.nb2 + (N_idx * p.OC + K_idx) * p.nb3;
if (aligned != 0 || (K_idx < K && NPQ_idx < NPQ)) {
dst_data[dst_idx] = D_TYPE(Csh[k_local * Csh_stride + npq_thread]);
}
}
}
#ifdef COOPMAT2
else {
coopMatPerElementNV(matC, matC, perElemOpStore);
}
#endif
#else
if (T_y * TS_K < K) {
for (uint32_t T_ly = 0; T_ly < TS_K; T_ly++) {
for (uint32_t T_lx = 0; T_lx < TS_NPQ; T_lx++) {
uint32_t K_idx = B_idx_K * BS_K + T_y * TS_K + T_ly;
uint32_t NPQ_idx = B_idx_NPQ * BS_NPQ + T_x * TS_NPQ + T_lx;
uint32_t N_idx;
uint32_t OD_idx;
uint32_t OH_idx;
uint32_t OW_idx;
split_npq(NPQ_idx, N_idx, OD_idx, OH_idx, OW_idx);
uint32_t dst_idx = OW_idx + OH_idx * p.nb1 + OD_idx * p.nb2 + (N_idx * p.OC + K_idx) * p.nb3;
if (aligned != 0 || (K_idx < K && NPQ_idx < NPQ)) {
dst_data[dst_idx] = D_TYPE(regC[T_ly][T_lx]);
}
}
}
}
#endif
}
@@ -0,0 +1,25 @@
#version 450
#include "types.glsl"
#include "generic_binary_head.glsl"
layout(local_size_x = 256, local_size_y = 1, local_size_z = 1) in;
void main() {
const uint col = gl_GlobalInvocationID.x;
if (col >= p.ne20) {
return;
}
for (uint row = gl_GlobalInvocationID.y; row < p.ne21; row += gl_WorkGroupSize.y * gl_NumWorkGroups.y) {
float sum = 0.0f;
for (uint i = 0; i < p.ne10; ++i) {
if (data_b[get_boffset() + i*p.nb10] == int(row)) {
sum += data_a[get_aoffset() + i*p.nb01 + col*p.nb00];
}
}
data_d[get_doffset() + row*p.nb21 + col*p.nb20] = sum;
}
}
+31 -23
View File
@@ -38,17 +38,7 @@
#define LOAD_VEC_B 1
#endif
// Load 2 values at once without affecting index calculations through LOAD_VEC
#if (defined(DATA_A_F32) || defined(DATA_A_F16) || defined(DATA_A_BF16)) && !defined(ALIGNED)
#define LOAD_VEC_BATCH_A 2
#else
#define LOAD_VEC_BATCH_A 1
#endif
#if !defined(ALIGNED)
#define LOAD_VEC_BATCH_B 2
#else
#define LOAD_VEC_BATCH_B 1
#endif
layout (constant_id = 11) const uint ALIGNED = 0;
#if !defined(TO_FLOAT_TYPE)
#define TO_FLOAT_TYPE FLOAT_TYPE
@@ -57,6 +47,13 @@
layout(local_size_x_id = 0, local_size_y = 1, local_size_z = 1) in;
layout (binding = 0) readonly buffer A {A_TYPE data_a[];};
#if defined(DATA_A_F32)
layout (binding = 0) readonly buffer A_SCALAR {float data_a_scalar[];};
#elif defined(DATA_A_F16)
layout (binding = 0) readonly buffer A_SCALAR {float16_t data_a_scalar[];};
#elif defined(DATA_A_BF16)
layout (binding = 0) readonly buffer A_SCALAR {uint16_t data_a_scalar[];};
#endif
#if defined(A_TYPE_PACKED16)
layout (binding = 0) readonly buffer A_PACKED16 {A_TYPE_PACKED16 data_a_packed16[];};
#endif
@@ -65,6 +62,7 @@ layout (binding = 0) readonly buffer A_PACKED32 {A_TYPE_PACKED32 data_a_packed32
#endif
layout (binding = 1) readonly buffer B {B_TYPE data_b[];};
layout (binding = 1) readonly buffer B_SCALAR {B_TYPE_SCALAR data_b_scalar[];};
layout (binding = 2) writeonly buffer D {D_TYPE data_d[];};
#ifdef MUL_MAT_ID
@@ -194,13 +192,23 @@ void main() {
const uint warp_r = warp_i % (BM / WM);
const uint warp_c = warp_i / (BM / WM);
const uint loadr_a = gl_LocalInvocationID.x % (BK / LOAD_VEC_A / LOAD_VEC_BATCH_A);
const uint loadc_a = gl_LocalInvocationID.x / (BK / LOAD_VEC_A / LOAD_VEC_BATCH_A);
const uint loadr_b = gl_LocalInvocationID.x % (BK / LOAD_VEC_B / LOAD_VEC_BATCH_B);
const uint loadc_b = gl_LocalInvocationID.x / (BK / LOAD_VEC_B / LOAD_VEC_BATCH_B);
#if defined(DATA_A_F32) || defined(DATA_A_F16) || defined(DATA_A_BF16)
const uint LOAD_VEC_A_EFF = (ALIGNED != 0) ? LOAD_VEC_A : 1;
const uint LOAD_VEC_BATCH_A = (ALIGNED != 0) ? 1 : 2;
#else
const uint LOAD_VEC_A_EFF = LOAD_VEC_A;
const uint LOAD_VEC_BATCH_A = 1;
#endif
const uint LOAD_VEC_B_EFF = (ALIGNED != 0) ? LOAD_VEC_B : 1;
const uint LOAD_VEC_BATCH_B = (ALIGNED != 0) ? 1 : 2;
const uint loadstride_a = gl_WorkGroupSize.x * LOAD_VEC_A * LOAD_VEC_BATCH_A / BK;
const uint loadstride_b = gl_WorkGroupSize.x * LOAD_VEC_B * LOAD_VEC_BATCH_B / BK;
const uint loadr_a = gl_LocalInvocationID.x % (BK / LOAD_VEC_A_EFF / LOAD_VEC_BATCH_A);
const uint loadc_a = gl_LocalInvocationID.x / (BK / LOAD_VEC_A_EFF / LOAD_VEC_BATCH_A);
const uint loadr_b = gl_LocalInvocationID.x % (BK / LOAD_VEC_B_EFF / LOAD_VEC_BATCH_B);
const uint loadc_b = gl_LocalInvocationID.x / (BK / LOAD_VEC_B_EFF / LOAD_VEC_BATCH_B);
const uint loadstride_a = gl_WorkGroupSize.x * LOAD_VEC_A_EFF * LOAD_VEC_BATCH_A / BK;
const uint loadstride_b = gl_WorkGroupSize.x * LOAD_VEC_B_EFF * LOAD_VEC_BATCH_B / BK;
#ifdef MUL_MAT_ID
#ifdef MUL_MAT_ID_USE_SUBGROUPS
@@ -239,15 +247,15 @@ void main() {
uint pos_a =
#ifdef MUL_MAT_ID
expert_idx * (p.batch_stride_a / LOAD_VEC_A) +
expert_idx * (p.batch_stride_a / LOAD_VEC_A_EFF) +
#else
batch_idx_a * (p.batch_stride_a / LOAD_VEC_A) +
batch_idx_a * (p.batch_stride_a / LOAD_VEC_A_EFF) +
#endif
(ir * BM * p.stride_a + start_k) / LOAD_VEC_A;
(ir * BM * p.stride_a + start_k) / LOAD_VEC_A_EFF;
#ifdef MUL_MAT_ID
uint pos_b = 0;
#else
uint pos_b = (batch_idx * p.batch_stride_b + ic * BN * p.stride_b + start_k) / LOAD_VEC_B;
uint pos_b = (batch_idx * p.batch_stride_b + ic * BN * p.stride_b + start_k) / LOAD_VEC_B_EFF;
#endif
#ifdef COOPMAT
@@ -287,8 +295,8 @@ void main() {
barrier();
pos_a += BK / LOAD_VEC_A;
pos_b += BK / LOAD_VEC_B;
pos_a += BK / LOAD_VEC_A_EFF;
pos_b += BK / LOAD_VEC_B_EFF;
#ifdef COOPMAT
[[unroll]] for (uint i = 0; i < BK; i += TK) {
@@ -36,6 +36,7 @@ layout (constant_id = 3) const uint BK = 16; // Assumed to be 32 if working wit
layout (constant_id = 4) const bool enable_smaller_matrices = false;
const uint BNover2 = enable_smaller_matrices ? (BN / 2) : BN;
const uint BNover4 = enable_smaller_matrices ? (BN / 4) : BN;
layout (constant_id = 5) const uint ALIGNED = 0;
layout (push_constant) uniform parameter
{
@@ -111,7 +112,7 @@ layout(buffer_reference, std430, buffer_reference_align = 2) buffer decodeBufB {
};
uint _ne1;
layout (constant_id = 5) const uint subgroup_size = 32;
layout (constant_id = 6) const uint subgroup_size = 32;
shared uvec4 ballots_sh[BLOCK_SIZE / subgroup_size];
B_TYPE decodeFuncB(const in decodeBufB bl, const in uint blockCoords[2], const in uint coordInBlock[2])
@@ -297,12 +298,12 @@ void main() {
// Hint to the compiler that values are aligned (want 16B alignment).
// Quants are always block-aligned, no alignment needed.
#if ALIGNED
if (ALIGNED != 0) {
#if QUANT_K == 1
stride_a &= ~7;
#endif
stride_b &= ~7;
stride_a &= ~7;
#endif
stride_b &= ~7;
}
// Create layouts for both clamped and unclamped accesses
tensorLayoutNV<2> tensorLayoutA = createTensorLayoutNV(2);
@@ -1,50 +1,57 @@
void load_a_to_shmem(const uint pos_a, const uint row, const uint col, const uint idx_m, const uint block, const uint end_k) {
#if defined(DATA_A_F32) || defined(DATA_A_F16)
#if LOAD_VEC_A == 8
const uint idx = pos_a + col * p.stride_a / LOAD_VEC_A + row;
const uint buf_idx = col * SHMEM_STRIDE + row * LOAD_VEC_A / 2;
FLOAT_TYPEV8 aa = FLOAT_TYPEV8(data_a[idx]);
buf_a[buf_idx ] = aa[0].xy;
buf_a[buf_idx + 1] = aa[0].zw;
buf_a[buf_idx + 2] = aa[1].xy;
buf_a[buf_idx + 3] = aa[1].zw;
if (ALIGNED != 0) {
const uint idx = pos_a + col * p.stride_a / LOAD_VEC_A + row;
const uint buf_idx = col * SHMEM_STRIDE + row * LOAD_VEC_A / 2;
FLOAT_TYPEV8 aa = FLOAT_TYPEV8(data_a[idx]);
buf_a[buf_idx ] = aa[0].xy;
buf_a[buf_idx + 1] = aa[0].zw;
buf_a[buf_idx + 2] = aa[1].xy;
buf_a[buf_idx + 3] = aa[1].zw;
return;
}
#elif LOAD_VEC_A == 4
const uint idx = pos_a + col * p.stride_a / LOAD_VEC_A + row;
const uint buf_idx = col * SHMEM_STRIDE + row * LOAD_VEC_A / 2;
FLOAT_TYPEV4 aa = FLOAT_TYPEV4(data_a[idx]);
buf_a[buf_idx ] = aa.xy;
buf_a[buf_idx + 1] = aa.zw;
#else // LOAD_VEC_BATCH_A == 2
if (ALIGNED != 0) {
const uint idx = pos_a + col * p.stride_a / LOAD_VEC_A + row;
const uint buf_idx = col * SHMEM_STRIDE + row * LOAD_VEC_A / 2;
FLOAT_TYPEV4 aa = FLOAT_TYPEV4(data_a[idx]);
buf_a[buf_idx ] = aa.xy;
buf_a[buf_idx + 1] = aa.zw;
return;
}
#endif
const uint idx = pos_a + col * p.stride_a + row * 2;
const uint buf_idx = col * SHMEM_STRIDE + row;
if (idx_m < p.M && block + row * 2 + 1 < end_k) {
buf_a[buf_idx] = FLOAT_TYPEV2(data_a[idx],
data_a[idx + 1]);
buf_a[buf_idx] = FLOAT_TYPEV2(data_a_scalar[idx],
data_a_scalar[idx + 1]);
} else if (idx_m < p.M && block + row * 2 < end_k) {
buf_a[buf_idx] = FLOAT_TYPEV2(data_a[idx], 0.0f);
buf_a[buf_idx] = FLOAT_TYPEV2(data_a_scalar[idx], 0.0f);
} else {
buf_a[buf_idx] = FLOAT_TYPEV2(0.0f);
}
#endif
#elif defined(DATA_A_BF16)
#if LOAD_VEC_A == 4
const uint idx = pos_a + col * p.stride_a / LOAD_VEC_A + row;
const uint buf_idx = col * SHMEM_STRIDE + row * LOAD_VEC_A / 2;
FLOAT_TYPEV4 aa = FLOAT_TYPEV4(TO_FLOAT_TYPE(data_a[idx]));
buf_a[buf_idx ] = aa.xy;
buf_a[buf_idx + 1] = aa.zw;
#else // LOAD_VEC_BATCH_A == 2
if (ALIGNED != 0) {
const uint idx = pos_a + col * p.stride_a / LOAD_VEC_A + row;
const uint buf_idx = col * SHMEM_STRIDE + row * LOAD_VEC_A / 2;
FLOAT_TYPEV4 aa = FLOAT_TYPEV4(TO_FLOAT_TYPE(data_a[idx]));
buf_a[buf_idx ] = aa.xy;
buf_a[buf_idx + 1] = aa.zw;
return;
}
#endif
const uint idx = pos_a + col * p.stride_a + row * 2;
const uint buf_idx = col * SHMEM_STRIDE + row;
if (idx_m < p.M && block + row * 2 + 1 < end_k) {
buf_a[buf_idx] = FLOAT_TYPEV2(TO_FLOAT_TYPE(data_a[idx]),
TO_FLOAT_TYPE(data_a[idx + 1]));
buf_a[buf_idx] = FLOAT_TYPEV2(TO_FLOAT_TYPE(data_a_scalar[idx]),
TO_FLOAT_TYPE(data_a_scalar[idx + 1]));
} else if (idx_m < p.M && block + row * 2 < end_k) {
buf_a[buf_idx] = FLOAT_TYPEV2(TO_FLOAT_TYPE(data_a[idx]), 0.0f);
buf_a[buf_idx] = FLOAT_TYPEV2(TO_FLOAT_TYPE(data_a_scalar[idx]), 0.0f);
} else {
buf_a[buf_idx] = FLOAT_TYPEV2(0.0f);
}
#endif
#elif defined(DATA_A_Q4_0)
const uint idx = pos_a + col * p.stride_a / LOAD_VEC_A + row;
const uint buf_idx = col * SHMEM_STRIDE + row * LOAD_VEC_A / 4;
@@ -526,75 +533,85 @@ void load_a_to_shmem(const uint pos_a, const uint row, const uint col, const uin
#if !defined(MUL_MAT_ID)
void load_b_to_shmem(const uint pos_b, const uint row, const uint col, const uint idx_n, const uint block, const uint end_k) {
#if LOAD_VEC_B == 8
// Not supported for b_type bf16 because bf16mat2x4 does not exist
const uint idx = pos_b + col * p.stride_b / LOAD_VEC_B + row;
const uint buf_idx = col * SHMEM_STRIDE + row * LOAD_VEC_B / 2;
FLOAT_TYPEV8 bb = FLOAT_TYPEV8(data_b[idx]);
buf_b[buf_idx + 0] = bb[0].xy;
buf_b[buf_idx + 1] = bb[0].zw;
buf_b[buf_idx + 2] = bb[1].xy;
buf_b[buf_idx + 3] = bb[1].zw;
if (ALIGNED != 0) {
// Not supported for b_type bf16 because bf16mat2x4 does not exist
const uint idx = pos_b + col * p.stride_b / LOAD_VEC_B + row;
const uint buf_idx = col * SHMEM_STRIDE + row * LOAD_VEC_B / 2;
FLOAT_TYPEV8 bb = FLOAT_TYPEV8(data_b[idx]);
buf_b[buf_idx + 0] = bb[0].xy;
buf_b[buf_idx + 1] = bb[0].zw;
buf_b[buf_idx + 2] = bb[1].xy;
buf_b[buf_idx + 3] = bb[1].zw;
return;
}
#elif LOAD_VEC_B == 4
const uint idx = pos_b + col * p.stride_b / LOAD_VEC_B + row;
const uint buf_idx = col * SHMEM_STRIDE + row * LOAD_VEC_B / 2;
if (ALIGNED != 0) {
const uint idx = pos_b + col * p.stride_b / LOAD_VEC_B + row;
const uint buf_idx = col * SHMEM_STRIDE + row * LOAD_VEC_B / 2;
#if defined(DATA_B_BF16)
FLOAT_TYPEV4 bb = FLOAT_TYPEV4(TO_FLOAT_TYPE(data_b[idx]));
FLOAT_TYPEV4 bb = FLOAT_TYPEV4(TO_FLOAT_TYPE(data_b[idx]));
#else
FLOAT_TYPEV4 bb = FLOAT_TYPEV4(data_b[idx]);
FLOAT_TYPEV4 bb = FLOAT_TYPEV4(data_b[idx]);
#endif
buf_b[buf_idx + 0] = bb.xy;
buf_b[buf_idx + 1] = bb.zw;
return;
}
#endif
buf_b[buf_idx + 0] = bb.xy;
buf_b[buf_idx + 1] = bb.zw;
#else // LOAD_VEC_BATCH_B == 2
const uint idx = pos_b + col * p.stride_b + row * 2;
const uint buf_idx = col * SHMEM_STRIDE + row;
if (idx_n < p.N && block + row * 2 + 1 < end_k) {
buf_b[buf_idx] = FLOAT_TYPEV2(TO_FLOAT_TYPE(data_b[idx]),
TO_FLOAT_TYPE(data_b[idx + 1]));
buf_b[buf_idx] = FLOAT_TYPEV2(TO_FLOAT_TYPE(data_b_scalar[idx]),
TO_FLOAT_TYPE(data_b_scalar[idx + 1]));
} else if (idx_n < p.N && block + row * 2 < end_k) {
buf_b[buf_idx] = FLOAT_TYPEV2(TO_FLOAT_TYPE(data_b[idx]), 0.0f);
buf_b[buf_idx] = FLOAT_TYPEV2(TO_FLOAT_TYPE(data_b_scalar[idx]), 0.0f);
} else {
buf_b[buf_idx] = FLOAT_TYPEV2(0.0f);
}
#endif
}
#else
void load_b_to_shmem(const uint pos_b, const uint row, const uint col, const uint ic, const uint _ne1, const uint block, const uint end_k) {
#if LOAD_VEC_B == 8
// Not supported for b_type bf16 because bf16mat2x4 does not exist
const u16vec2 row_idx = row_ids[col];
const uint idx = pos_b + row_idx.y * p.batch_stride_b / LOAD_VEC_B + (row_idx.x % p.ne11) * p.stride_b / LOAD_VEC_B + row;
const uint buf_idx = col * SHMEM_STRIDE + row * LOAD_VEC_B / 2;
FLOAT_TYPEV8 bb = FLOAT_TYPEV8(data_b[idx]);
buf_b[buf_idx + 0] = bb[0].xy;
buf_b[buf_idx + 1] = bb[0].zw;
buf_b[buf_idx + 2] = bb[1].xy;
buf_b[buf_idx + 3] = bb[1].zw;
if (ALIGNED != 0) {
// Not supported for b_type bf16 because bf16mat2x4 does not exist
const u16vec2 row_idx = row_ids[col];
const uint idx = pos_b + row_idx.y * p.batch_stride_b / LOAD_VEC_B + (row_idx.x % p.ne11) * p.stride_b / LOAD_VEC_B + row;
const uint buf_idx = col * SHMEM_STRIDE + row * LOAD_VEC_B / 2;
FLOAT_TYPEV8 bb = FLOAT_TYPEV8(data_b[idx]);
buf_b[buf_idx + 0] = bb[0].xy;
buf_b[buf_idx + 1] = bb[0].zw;
buf_b[buf_idx + 2] = bb[1].xy;
buf_b[buf_idx + 3] = bb[1].zw;
return;
}
#elif LOAD_VEC_B == 4
const u16vec2 row_idx = row_ids[col];
const uint idx = pos_b + row_idx.y * p.batch_stride_b / LOAD_VEC_B + (row_idx.x % p.ne11) * p.stride_b / LOAD_VEC_B + row;
const uint buf_idx = col * SHMEM_STRIDE + row * LOAD_VEC_B / 2;
if (ALIGNED != 0) {
const u16vec2 row_idx = row_ids[col];
const uint idx = pos_b + row_idx.y * p.batch_stride_b / LOAD_VEC_B + (row_idx.x % p.ne11) * p.stride_b / LOAD_VEC_B + row;
const uint buf_idx = col * SHMEM_STRIDE + row * LOAD_VEC_B / 2;
#if defined(DATA_B_BF16)
FLOAT_TYPEV4 bb = FLOAT_TYPEV4(TO_FLOAT_TYPE(data_b[idx]));
FLOAT_TYPEV4 bb = FLOAT_TYPEV4(TO_FLOAT_TYPE(data_b[idx]));
#else
FLOAT_TYPEV4 bb = FLOAT_TYPEV4(data_b[idx]);
FLOAT_TYPEV4 bb = FLOAT_TYPEV4(data_b[idx]);
#endif
buf_b[buf_idx + 0] = bb.xy;
buf_b[buf_idx + 1] = bb.zw;
return;
}
#endif
buf_b[buf_idx + 0] = bb.xy;
buf_b[buf_idx + 1] = bb.zw;
#else // LOAD_VEC_BATCH_B == 2
const uint row_i = ic * BN + col;
const uint buf_idx = col * SHMEM_STRIDE + row;
if (row_i < _ne1 && block + row * 2 + 1 < end_k) {
const u16vec2 row_idx = row_ids[col];
const uint idx = pos_b + row_idx.y * p.batch_stride_b + (row_idx.x % p.ne11) * p.stride_b + row * 2;
buf_b[buf_idx] = FLOAT_TYPEV2(TO_FLOAT_TYPE(data_b[idx]),
TO_FLOAT_TYPE(data_b[idx + 1]));
buf_b[buf_idx] = FLOAT_TYPEV2(TO_FLOAT_TYPE(data_b_scalar[idx]),
TO_FLOAT_TYPE(data_b_scalar[idx + 1]));
} else if (row_i < _ne1 && block + row * 2 < end_k) {
const u16vec2 row_idx = row_ids[col];
const uint idx = pos_b + row_idx.y * p.batch_stride_b + (row_idx.x % p.ne11) * p.stride_b + row * 2;
buf_b[buf_idx] = FLOAT_TYPEV2(TO_FLOAT_TYPE(data_b[idx]), 0.0f);
buf_b[buf_idx] = FLOAT_TYPEV2(TO_FLOAT_TYPE(data_b_scalar[idx]), 0.0f);
} else {
buf_b[buf_idx] = FLOAT_TYPEV2(0.0f);
}
#endif
}
#endif
@@ -539,11 +539,9 @@ void matmul_shaders(bool fp16, MatMulIdType matmul_id_type, bool coopmat, bool c
};
// Shaders with f16 B_TYPE
string_to_spv(shader_name + "_f32_f16" + dot2_sfx, source_name, merge_maps(merge_maps(base_dict, float_type_dict_f16), {{"DATA_A_F32", "1"}, {"B_TYPE", "float16_t"}, {"B_TYPEV4", "f16vec4"}, {"D_TYPE", "float"}, }), fp16, coopmat, coopmat2, f16acc);
string_to_spv(shader_name + "_f32_f16" + dot2_sfx + "_aligned", source_name, merge_maps(merge_maps(base_dict, float_type_dict_f16), {{"DATA_A_F32", "1"}, {"LOAD_VEC_A", load_vec}, {"LOAD_VEC_B", load_vec}, {"B_TYPE", aligned_b_type_f16}, {"B_TYPEV4", "f16vec4"}, {"D_TYPE", "float"}, {"ALIGNED", "1"}}), fp16, coopmat, coopmat2, f16acc);
string_to_spv(shader_name + "_f32_f16" + dot2_sfx, source_name, merge_maps(merge_maps(base_dict, float_type_dict_f16), {{"DATA_A_F32", "1"}, {"LOAD_VEC_A", load_vec}, {"LOAD_VEC_B", load_vec}, {"B_TYPE", aligned_b_type_f16}, {"B_TYPE_SCALAR", "float16_t"}, {"B_TYPEV4", "f16vec4"}, {"D_TYPE", "float"}}), fp16, coopmat, coopmat2, f16acc);
string_to_spv(shader_name + "_f16" + dot2_sfx, source_name, merge_maps(merge_maps(base_dict, float_type_dict_f16), {{"DATA_A_F16", "1"}, {"B_TYPE", "float16_t"}, {"B_TYPEV4", "f16vec4"}, {"D_TYPE", "float"}}), fp16, coopmat, coopmat2, f16acc);
string_to_spv(shader_name + "_f16" + dot2_sfx + "_aligned", source_name, merge_maps(merge_maps(base_dict, float_type_dict_f16), {{"DATA_A_F16", "1"}, {"LOAD_VEC_A", load_vec}, {"LOAD_VEC_B", load_vec}, {"B_TYPE", aligned_b_type_f16}, {"B_TYPEV4", "f16vec4"}, {"D_TYPE", "float"}, {"ALIGNED", "1"}}), fp16, coopmat, coopmat2, f16acc);
string_to_spv(shader_name + "_f16" + dot2_sfx, source_name, merge_maps(merge_maps(base_dict, float_type_dict_f16), {{"DATA_A_F16", "1"}, {"LOAD_VEC_A", load_vec}, {"LOAD_VEC_B", load_vec}, {"B_TYPE", aligned_b_type_f16}, {"B_TYPE_SCALAR", "float16_t"}, {"B_TYPEV4", "f16vec4"}, {"D_TYPE", "float"}}), fp16, coopmat, coopmat2, f16acc);
// bf16
{
@@ -565,8 +563,7 @@ void matmul_shaders(bool fp16, MatMulIdType matmul_id_type, bool coopmat, bool c
#endif
{
if (!dot2) {
string_to_spv(shader_name + "_bf16", source_name, merge_maps(merge_maps(base_dict, float_type_dict_bf16), {{"TO_FLOAT_TYPE", to_float_type}, {"DATA_A_BF16", "1"}, {"B_TYPE", coopmat2 ? "bfloat16_t" : "uint16_t"}, {"B_TYPEV4", "bf16vec4"}, {"D_TYPE", "float"}, {"B_IS_FLOAT", "1"}, {"DATA_B_BF16", "1"}}), fp16, coopmat, coopmat2, f16acc);
string_to_spv(shader_name + "_bf16_aligned", source_name, merge_maps(merge_maps(base_dict, float_type_dict_bf16), {{"TO_FLOAT_TYPE", to_float_type}, {"DATA_A_BF16", "1"}, {"LOAD_VEC_A", load_vec_a}, {"LOAD_VEC_B", "4"}, {"B_TYPE", coopmat2 ? "bfloat16_t" : "u16vec4"}, {"B_TYPEV4", "bf16vec4"}, {"D_TYPE", "float"}, {"B_IS_FLOAT", "1"}, {"DATA_B_BF16", "1"}, {"ALIGNED", "1"}}), fp16, coopmat, coopmat2, f16acc);
string_to_spv(shader_name + "_bf16", source_name, merge_maps(merge_maps(base_dict, float_type_dict_bf16), {{"TO_FLOAT_TYPE", to_float_type}, {"DATA_A_BF16", "1"}, {"LOAD_VEC_A", load_vec_a}, {"LOAD_VEC_B", "4"}, {"B_TYPE", coopmat2 ? "bfloat16_t" : "u16vec4"}, {"B_TYPE_SCALAR", coopmat2 ? "bfloat16_t" : "uint16_t"}, {"B_TYPEV4", "bf16vec4"}, {"D_TYPE", "float"}, {"B_IS_FLOAT", "1"}, {"DATA_B_BF16", "1"}}), fp16, coopmat, coopmat2, f16acc);
}
}
}
@@ -583,8 +580,6 @@ void matmul_shaders(bool fp16, MatMulIdType matmul_id_type, bool coopmat, bool c
}
std::string data_a_key = "DATA_A_" + to_uppercase(tname);
// For unaligned, load one at a time for f32/f16, or two at a time for quants
std::string load_vec_a_unaligned = (coopmat2 || tname == "f32" || tname == "f16" || tname == "bf16") ? "1" : load_vec_quant;
// For aligned matmul loads
std::string load_vec_a = (coopmat2 || tname == "f32" || tname == "f16" || tname == "bf16") ? load_vec : load_vec_quant;
@@ -597,13 +592,11 @@ void matmul_shaders(bool fp16, MatMulIdType matmul_id_type, bool coopmat, bool c
// don't generate f32 variants for coopmat2
if (!coopmat2) {
string_to_spv(shader_name + "_" + tname + "_f32" + dot2_sfx, source_name, merge_maps(merge_maps(base_dict, float_type_dict), {{data_a_key, "1"}, {"LOAD_VEC_A", load_vec_a_unaligned}, {"B_TYPE", "float"}, {"B_TYPEV4", "vec4"}, {"D_TYPE", "float"}}), fp16, coopmat, coopmat2, f16acc);
string_to_spv(shader_name + "_" + tname + "_f32" + dot2_sfx + "_aligned", source_name, merge_maps(merge_maps(base_dict, float_type_dict), {{data_a_key, "1"}, {"LOAD_VEC_A", load_vec_a}, {"LOAD_VEC_B", load_vec}, {"B_TYPE", aligned_b_type_f32}, {"B_TYPEV4", "vec4"}, {"D_TYPE", "float"}, {"ALIGNED", "1"}}), fp16, coopmat, coopmat2, f16acc);
string_to_spv(shader_name + "_" + tname + "_f32" + dot2_sfx, source_name, merge_maps(merge_maps(base_dict, float_type_dict), {{data_a_key, "1"}, {"LOAD_VEC_A", load_vec_a}, {"LOAD_VEC_B", load_vec}, {"B_TYPE", aligned_b_type_f32}, {"B_TYPE_SCALAR", "float"}, {"B_TYPEV4", "vec4"}, {"D_TYPE", "float"}}), fp16, coopmat, coopmat2, f16acc);
}
if (tname != "f16" && tname != "f32") {
string_to_spv(shader_name + "_" + tname + "_f16" + dot2_sfx, source_name, merge_maps(merge_maps(base_dict, float_type_dict), {{data_a_key, "1"}, {"LOAD_VEC_A", load_vec_a_unaligned}, {"B_TYPE", "float16_t"}, {"B_TYPEV4", "f16vec4"}, {"D_TYPE", "float"}}), fp16, coopmat, coopmat2, f16acc);
string_to_spv(shader_name + "_" + tname + "_f16" + dot2_sfx + "_aligned", source_name, merge_maps(merge_maps(base_dict, float_type_dict), {{data_a_key, "1"}, {"LOAD_VEC_A", load_vec_a}, {"LOAD_VEC_B", load_vec}, {"B_TYPE", aligned_b_type_f16}, {"B_TYPEV4", "f16vec4"}, {"D_TYPE", "float"}, {"ALIGNED", "1"}}), fp16, coopmat, coopmat2, f16acc);
string_to_spv(shader_name + "_" + tname + "_f16" + dot2_sfx, source_name, merge_maps(merge_maps(base_dict, float_type_dict), {{data_a_key, "1"}, {"LOAD_VEC_A", load_vec_a}, {"LOAD_VEC_B", load_vec}, {"B_TYPE", aligned_b_type_f16}, {"B_TYPE_SCALAR", "float16_t"}, {"B_TYPEV4", "f16vec4"}, {"D_TYPE", "float"}}), fp16, coopmat, coopmat2, f16acc);
}
#if defined(GGML_VULKAN_INTEGER_DOT_GLSLC_SUPPORT)
@@ -850,6 +843,7 @@ void process_shaders() {
string_to_spv("repeat_i32", "repeat.comp", {{"A_TYPE", "int32_t"}, {"D_TYPE", "int32_t"}});
string_to_spv("repeat_back_f32", "repeat_back.comp", {{"A_TYPE", "float"}, {"D_TYPE", "float"}});
string_to_spv("get_rows_back_f32", "get_rows_back.comp", {{"A_TYPE", "float"}, {"B_TYPE", "int"}, {"D_TYPE", "float"}});
string_to_spv("repeat_i16", "repeat.comp", {{"A_TYPE", "int16_t"}, {"D_TYPE", "int16_t"}});
@@ -1060,6 +1054,31 @@ void process_shaders() {
}
}
for (auto unroll : {false, true}) {
for (auto a_f16 : {false, true}) {
std::map<std::string, std::string> defines = {
{"A_TYPE", a_f16 ? "float16_t" : "float"}, {"B_TYPE", "float"}, {"D_TYPE", "float"},
{"UNROLL", unroll ? "[[unroll]]" : ""},
};
std::string name = std::string("conv3d") + (a_f16 ? "_f16" : "") + "_f32";
string_to_spv(name + (unroll ? "_unroll" : ""), "conv3d_mm.comp", defines);
#if defined(GGML_VULKAN_COOPMAT2_GLSLC_SUPPORT)
if (unroll) {
auto cm2_defines = defines;
cm2_defines["COOPMAT2"] = "1";
string_to_spv(name, "conv3d_mm.comp", cm2_defines, true, false, true);
}
#endif
#if defined(GGML_VULKAN_COOPMAT_GLSLC_SUPPORT)
if (unroll) {
auto cm1_defines = defines;
cm1_defines["COOPMAT"] = "1";
string_to_spv(name, "conv3d_mm.comp", cm1_defines, true, true, false);
}
#endif
}
}
string_to_spv("conv2d_dw_whcn_f32", "conv2d_dw.comp", merge_maps(base_dict, {{"A_TYPE", "float"}, {"B_TYPE", "float"}, {"D_TYPE", "float"}, {"WHCN", "1"}}));
string_to_spv("conv2d_dw_cwhn_f32", "conv2d_dw.comp", merge_maps(base_dict, {{"A_TYPE", "float"}, {"B_TYPE", "float"}, {"D_TYPE", "float"}, {"CWHN", "1"}}));
string_to_spv("conv2d_dw_whcn_f16_f32", "conv2d_dw.comp", merge_maps(base_dict, {{"A_TYPE", "float16_t"}, {"B_TYPE", "float"}, {"D_TYPE", "float"}, {"WHCN", "1"}}));
+28
View File
@@ -9272,6 +9272,34 @@ static std::vector<std::unique_ptr<test_case>> make_test_cases_perf() {
}
}
struct conv3d_perf_case {
int N, IC, ID, IH, IW, OC, KD, KH, KW, s0, s1, s2, p0, p1, p2, d0, d1, d2;
};
const std::vector<conv3d_perf_case> conv3d_cases = {
{1, 320, 8, 38, 26, 1280, 3, 3, 3, 1, 1, 1, 1, 1, 1, 1, 1, 1},
{1, 1280, 8, 38, 26, 1280, 3, 3, 3, 1, 1, 1, 1, 1, 1, 1, 1, 1},
{1, 320, 8, 76, 52, 1280, 3, 3, 3, 1, 1, 1, 1, 1, 1, 1, 1, 1},
{1, 1280, 8, 76, 52, 1280, 3, 3, 3, 1, 1, 1, 1, 1, 1, 1, 1, 1},
{1, 320, 8, 152, 104, 1280, 3, 3, 3, 1, 1, 1, 1, 1, 1, 1, 1, 1},
#if 0
// too slow on some devices
{1, 1280, 8, 152, 104, 1280, 3, 3, 3, 1, 1, 1, 1, 1, 1, 1, 1, 1},
{1, 320, 4, 304, 208, 1280, 3, 3, 3, 1, 1, 1, 1, 1, 1, 1, 1, 1},
{1, 640, 4, 304, 208, 1280, 3, 3, 3, 1, 1, 1, 1, 1, 1, 1, 1, 1},
#endif
};
for (ggml_type kernel_type : {GGML_TYPE_F32, GGML_TYPE_F16}) {
for (const conv3d_perf_case & c : conv3d_cases) {
test_cases.emplace_back(new test_conv_3d(
c.N, c.IC, c.ID, c.IH, c.IW,
c.OC, c.KD, c.KH, c.KW,
c.s0, c.s1, c.s2, c.p0, c.p1, c.p2, c.d0, c.d1, c.d2,
kernel_type));
}
}
test_cases.emplace_back(new test_bin_bcast(ggml_add, GGML_TYPE_F32, {4096, 1, 1, 1}, {1, 1, 1, 1}));
test_cases.emplace_back(new test_bin_bcast(ggml_add, GGML_TYPE_F32, {4096, 1, 1, 1}, {1, 512, 1, 1}));