Compare commits

..

5 Commits

Author SHA1 Message Date
Michael Wand
fc2b0053ff ggml-cuda: Repost of 21896: Blackwell native NVFP4 support (#22196) 2026-04-29 06:47:42 +08:00
lnigam
7b8443ac78 ggml-cuda: add flash-attn support for DKQ=320/DV=256 with ncols2=32 (… (#22286)
* ggml-cuda: add flash-attn support for DKQ=320/DV=256 with ncols2=32 (GQA=32)

Adds MMA-f16 and tile kernel configs, dispatch logic, template instances,
and tile .cu file for Mistral Small 4 (head sizes 320/256), restricting to
ncols2=32 to support GQA ratio 32 only.

* Adding check to return BEST_FATTN_KERNEL_NONE in case GQA!=32

* Apply suggestions from code review

Address review comments

Co-authored-by: Johannes Gäßler <johannesg@5d6.de>

* Address review comments and making kernel config default to DQK=512, DV=512 instead of DQK=256,DV=256

* Fixed bug with sinks=1, with ncols=32, there are two warp-groups created but sinks index is same(0,...,15) for both the groups hence with sinks=1, output is not matching with CPU output. Added sink_base which will be base index for each warp_group (threadIdx.y / np)

* Apply suggestions from code review

Co-authored-by: Johannes Gäßler <johannesg@5d6.de>

* Update ggml/src/ggml-cuda/template-instances/generate_cu_files.py

Co-authored-by: Johannes Gäßler <johannesg@5d6.de>

---------

Co-authored-by: Johannes Gäßler <johannesg@5d6.de>
2026-04-28 21:37:35 +02:00
Daniel Bevenius
5d56effdee convert : add support for Nemotron Nano 3 Omni (#22481)
This commit adds support for NVIDIA Nemotron Nano 3 Omni model enabling
this model to be converted to GGUF.
2026-04-28 19:17:57 +02:00
Jillis ter Hove
52e5f0a5c1 common : re-arm reasoning budget after DONE on new <think> (#22323)
DONE state absorbs all tokens including a new start tag, causing any think blocks after the first to run unbudgeted. Observed on unsloth/Qwen3.6-27B-GGUF which interleaves multiple <think> blocks per response.

Fixed by advancing start_matcher in DONE branch and re-arming to COUNTING with a fresh budget on match. Adds regression test (test-reasoning-budget: test 6).
2026-04-28 19:15:36 +02:00
Matt Corallo
f9f33654a6 vulkan: Coalesce Q4_K/Q5_K scale loads (#21751)
Some SPIR-V compilers (notably mesa) don't handle the current
vulkan Q4_K/Q5_K scale load pattern in mul_mat particularly well.
While reading three `u8`s from the 12-byte scale array should (at
least on some hardware) result in loading the full 12 bytes in a
single LOAD followed by whatever extraction is needed, at least
the ANV Intel driver really can't practically perform this
optimization.

`mesa`'s unsigned upper bound logic doesn't handle tracking bounds
through ternary, resulting in the `(is < 4) ? ... : is - 4` having
an infinite upper bound (as it cannot prove `is - 4` doesn't
underflow). While this could still be rectified if mesa looked at
the array bounds, it currently doesn't and `glslc` currently emits
SPIR-V that doesn't allow for this optimization anyway (though
maybe it will at some point, see
https://github.com/KhronosGroup/glslang/issues/4206).

In mul_mat_vecq we took a different approach to loading the same
fields. We read the first two bytes we needed from `scale` then
took a branch before deciding whether we needed to read a third
byte. In mesa this did, indeed, lead to a top-level branch with
conditional loads. As such these loads ended up not being
coalesced either (at least in the ANV driver) resulting in
additional instructions in our hot loop.

Instead, here, we go ahead and force loading the full 12 bytes and
extract the bits we need from the packed-u32s instead. In mul_mat
there's a few less ternaries and only one extra shift, so even on
drivers that did optimize the previous loads properly the only
material change should be pulling a few extra bytes into registers
(which on most hardware won't cost anything anyway, though
ironically on Intel it theoretically could). In mul_mat_vecq this
requires a bit of extra math and may read bytes from the u32 that
weren't needed, but it seems likely avoiding the branch is a win
on most platforms.

On Intel Xe2/mesa 26.0.4 with the optimizations from
https://gitlab.freedesktop.org/mesa/mesa/-/work_items/15162,

for shader matmul_id_subgroup_q4_k_f32_f16acc_aligned_l:
 * Instruction Count: 2753 -> 2688
 * SEND Count: 269 -> 261
 * Cycle Count: 273976 -> 266138
 * Max live registers: 248 -> 246
 * Non SSA regs after NIR: 381 -> 382

for shader matmul_id_subgroup_q5_k_f32_f16acc_aligned_l:
 * Instruction Count: 2767 -> 2702
 * SEND Count: 271 -> 263
 * Cycle Count: 274140 -> 268144
 * Max live registers: 248 -> 246
 * Non SSA regs after NIR: 381 -> 382

for shader mul_mat_vec_id_q4_k_q8_1_f32:
 * Instruction Count: 1930 -> 1646
 * SEND Count: 116 -> 71
 * Cycle Count: 1348306 -> 843350
 * Max live registers: 78 -> 84
 * Non SSA regs after NIR: 300 -> 135

for shader mul_mat_vec_id_q5_k_q8_1_f32:
 * Instruction Count: 2207 -> 1922
 * SEND Count: 131 -> 86
 * Cycle Count: 1392012 -> 1037836
 * Max live registers: 90 -> 90
 * Non SSA regs after NIR: 300 -> 135

for shader mul_mat_vec_q4_k_q8_1_f32:
 * Instruction Count: 2029 -> 1749
 * SEND Count: 111 -> 66
 * Cycle Count: 1347278 -> 840118
 * Max live registers: 74 -> 80
 * Non SSA regs after NIR: 299 -> 134

for shader mul_mat_vec_q5_k_q8_1_f32:
 * Instruction Count: 2307 -> 2022
 * SEND Count: 126 -> 81
 * Cycle Count: 1379820 -> 954042
 * Max live registers: 86 -> 86
 * Non SSA regs after NIR: 299 -> 134

On one Arc Pro B60, unsloth/Qwen3.5-35B-A3B-GGUF:UD-Q4_K_XL:
 * pp512: 907.34 ± 9.28 -> 941.94 ± 10.53 (+4%)
 * pp2048: 897.95 ± 1.82 -> 931.55 ± 1.79 (+4%)
 * tg128: 49.49 ± 0.02 -> 49.86 ± 0.05 (+ <1%)

On one Arc Pro B60, unsloth/Qwen3.5-27B-GGUF:Q4_K_S:
 * pp512: 324.13 ± 10.52 -> 354.33 ± 6.81 (+9%)
 * pp2048: 329.80 ± 0.25 -> 357.10 ± 0.06 (+8%)
 * tg128: 17.11 ± 0.01 -> 18.11 ± 0.01 (+6%)

On four Arc Pro B60s, unsloth/Qwen3.5-122B-A10B-GGUF:Q5_K_S with
-sm layer (note that -sm tensor improvements will naturally be
less):
 * pp512: 264.55 ± 2.81 -> 280.45 ± 3.94 (+6%)
 * pp2048: 319.32 ± 2.72 -> 335.70 ± 3.48 (+5%)
 * tg128: 26.39 ± 0.01 -> 26.67 ± 0.01 (+1%)
2026-04-28 17:31:04 +02:00
21 changed files with 512 additions and 180 deletions

View File

@@ -122,6 +122,20 @@ static void common_reasoning_budget_accept(struct llama_sampler * smpl, llama_to
}
break;
case REASONING_BUDGET_DONE:
// Re-arm on a new start tag: some models emit multiple <think> blocks
// per response, and each should get a fresh budget window.
if (ctx->start_matcher.advance(token)) {
ctx->state = REASONING_BUDGET_COUNTING;
ctx->remaining = ctx->budget;
ctx->end_matcher.reset();
LOG_INF("reasoning-budget: re-activated on new start tag, budget=%d tokens\n", ctx->budget);
if (ctx->remaining <= 0) {
ctx->state = REASONING_BUDGET_FORCING;
ctx->force_pos = 0;
LOG_INF("reasoning-budget: budget=0, forcing immediately\n");
}
}
break;
}
}

View File

@@ -728,6 +728,9 @@ class ModelBase:
del experts, merged
def _needs_nvfp4_processing(self) -> bool:
return True
def prepare_tensors(self):
# detect NVFP4 quantization (ModelOpt format)
quant_algo = (self.hparams.get("quantization_config") or {}).get("quant_algo")
@@ -758,7 +761,7 @@ class ModelBase:
# NVFP4 weights are repacked and written directly to gguf_writer.
# This must run before dequant_model so NVFP4 tensors are removed
# from model_tensors, leaving only non-NVFP4 (e.g. FP8) for dequant.
if self._is_nvfp4:
if self._is_nvfp4 and self._needs_nvfp4_processing():
self._generate_nvfp4_tensors()
self.dequant_model()
@@ -2190,6 +2193,10 @@ class MmprojModel(ModelBase):
# merge configs
self.preprocessor_config = {**self.preprocessor_config, **cfg}
def _needs_nvfp4_processing(self) -> bool:
# nvfp4 quantization applies to the text model only.
return False
def get_vision_config(self) -> dict[str, Any] | None:
config_name = "vision_config" if not self.is_mistral_format else "vision_encoder"
return self.global_config.get(config_name)
@@ -4450,6 +4457,12 @@ class NemotronNanoV2VLModel(MmprojModel):
}
return vision_config
def dequant_model(self):
if self._is_nvfp4:
# Skip nvfp4 quantization for vision/audio model.
return
super().dequant_model()
def set_gguf_parameters(self):
if "image_mean" not in self.preprocessor_config:
self.preprocessor_config["image_mean"] = [0.485, 0.456, 0.406]
@@ -4473,6 +4486,10 @@ class NemotronNanoV2VLModel(MmprojModel):
if "input_conditioner" in name:
return
# mtmd does not support video yet so skip tensors related to video.
if "radio_model.model.patch_generator.video_embedder" in name:
return
# RADIO's pos_embed doesn't have .weight suffix, but clip.cpp expects it
if "patch_generator.pos_embed" in name:
if not name.endswith(".weight"):
@@ -10820,7 +10837,11 @@ class NemotronHModel(GraniteHybridModel):
# uses self.model_arch to build the tensor name map, and all MoE-specific
# mappings would be missed if it were called with the default non-MoE arch.
hparams = ModelBase.load_hparams(args[0], self.is_mistral_format)
if "num_experts_per_tok" in hparams:
has_moe_params = (
"num_experts_per_tok" in hparams
or (isinstance(hparams.get("llm_config"), dict) and "num_experts_per_tok" in hparams["llm_config"])
)
if has_moe_params:
self.model_arch = gguf.MODEL_ARCH.NEMOTRON_H_MOE
self.is_moe = True
@@ -10967,6 +10988,11 @@ class NemotronHModel(GraniteHybridModel):
if name.startswith(("vision_model.", "mlp1.")):
return
if name.startswith(("sound_encoder.")):
return
if name.startswith(("sound_projection.")):
return
# Strip language_model. prefix for VLM models (e.g., Nemotron Nano 12B v2 VL)
if name.startswith("language_model."):
name = name[len("language_model."):]

View File

@@ -830,6 +830,18 @@ static __device__ __forceinline__ float ggml_cuda_ue4m3_to_fp32(uint8_t x) {
#endif // defined(GGML_USE_HIP) && defined(CDNA3) && defined(FP8_AVAILABLE) && HIP_VERSION >= 60200000
}
static __device__ __forceinline__ uint8_t ggml_cuda_fp32_to_ue4m3(float x) {
#if defined(BLACKWELL_MMA_AVAILABLE) // This is used for NVFP4 subblock scale quantizations only
if (!(x > 0.0f)) {
return 0;
}
const __nv_fp8_e4m3 xf(x);
return xf.__x;
#else
NO_DEVICE_CODE; // Used only for NVFP4 Scales for Activations, only for Blackwell
#endif // defined(BLACKWELL_MMA_AVAILABLE)
}
__device__ __forceinline__ uint8_t ggml_cuda_float_to_fp4_e2m1(float x, float e) {
const uint8_t sign_bit = (x < 0.0f) << 3;
float ax = fabsf(x) * e;

View File

@@ -66,6 +66,9 @@ static constexpr __host__ __device__ fattn_mma_config ggml_cuda_fattn_mma_get_co
GGML_CUDA_FATTN_MMA_CONFIG_CASE(256, 256, 32, 128, 2, 32, 128, 128, 128, 2, true);
GGML_CUDA_FATTN_MMA_CONFIG_CASE(256, 256, 64, 128, 2, 32, 128, 128, 128, 2, true);
GGML_CUDA_FATTN_MMA_CONFIG_CASE(320, 256, 32, 128, 2, 32, 128, 128, 128, 1, false);
GGML_CUDA_FATTN_MMA_CONFIG_CASE(320, 256, 64, 256, 1, 32, 128, 128, 128, 1, false);
GGML_CUDA_FATTN_MMA_CONFIG_CASE(512, 512, 8, 64, 4, 32, 256, 256, 128, 1, false);
GGML_CUDA_FATTN_MMA_CONFIG_CASE(512, 512, 16, 64, 4, 32, 256, 256, 128, 1, false);
GGML_CUDA_FATTN_MMA_CONFIG_CASE(512, 512, 32, 128, 2, 32, 128, 128, 128, 1, false);
@@ -85,6 +88,9 @@ static constexpr __host__ __device__ fattn_mma_config ggml_cuda_fattn_mma_get_co
GGML_CUDA_FATTN_MMA_CONFIG_CASE(256, 256, 32, 128, 2, 64, 128, 128, 64, 2, true);
GGML_CUDA_FATTN_MMA_CONFIG_CASE(256, 256, 64, 128, 2, 64, 128, 128, 64, 2, true);
GGML_CUDA_FATTN_MMA_CONFIG_CASE(320, 256, 32, 128, 2, 32, 128, 128, 128, 1, false);
GGML_CUDA_FATTN_MMA_CONFIG_CASE(320, 256, 64, 256, 1, 32, 128, 128, 128, 1, false);
GGML_CUDA_FATTN_MMA_CONFIG_CASE(512, 512, 8, 64, 4, 32, 96, 64, 128, 1, false);
GGML_CUDA_FATTN_MMA_CONFIG_CASE(512, 512, 16, 64, 4, 32, 96, 64, 128, 1, false);
GGML_CUDA_FATTN_MMA_CONFIG_CASE(512, 512, 32, 128, 2, 32, 128, 128, 128, 1, false);
@@ -118,6 +124,9 @@ static constexpr __host__ __device__ fattn_mma_config ggml_cuda_fattn_mma_get_co
GGML_CUDA_FATTN_MMA_CONFIG_CASE(256, 256, 32, 128, 2, 64, 128, 128, 64, 2, true);
GGML_CUDA_FATTN_MMA_CONFIG_CASE(256, 256, 64, 128, 2, 64, 128, 128, 64, 2, true);
GGML_CUDA_FATTN_MMA_CONFIG_CASE(320, 256, 32, 128, 2, 64, 160, 128, 64, 2, true);
GGML_CUDA_FATTN_MMA_CONFIG_CASE(320, 256, 64, 128, 2, 64, 160, 128, 64, 2, false);
GGML_CUDA_FATTN_MMA_CONFIG_CASE(512, 512, 16, 64, 4, 32, 128, 128, 128, 1, false);
GGML_CUDA_FATTN_MMA_CONFIG_CASE(512, 512, 32, 128, 2, 32, 128, 128, 128, 1, false);
GGML_CUDA_FATTN_MMA_CONFIG_CASE(512, 512, 64, 256, 1, 32, 128, 128, 128, 1, false);
@@ -1217,7 +1226,7 @@ static __device__ __forceinline__ void flash_attn_ext_f16_process_tile(
float KQ_max_scale[cols_per_thread];
#pragma unroll
for (int col = 0; col < cols_per_thread; ++col) {
const int jc = cols_per_warp == 8 ? T_C_KQ::get_j(col) : T_C_KQ::get_i(2*col);
const int jc = (threadIdx.y/np)*cols_per_warp + (cols_per_warp == 8 ? T_C_KQ::get_j(col) : T_C_KQ::get_i(2*col));
const float sink = sinks_f[jc % ncols2];
const float KQ_max_new = fmaxf(KQ_max[col], sink);
@@ -1825,6 +1834,10 @@ extern DECL_FATTN_MMA_F16_CASE(576, 512, 1, 16);
extern DECL_FATTN_MMA_F16_CASE(576, 512, 2, 16);
extern DECL_FATTN_MMA_F16_CASE(576, 512, 4, 16);
// Mistral Small 4 (DKQ=320, DV=256), GQA=32-only build:
extern DECL_FATTN_MMA_F16_CASE(320, 256, 1, 32);
extern DECL_FATTN_MMA_F16_CASE(320, 256, 2, 32);
// For GLM 4.7 Flash
extern DECL_FATTN_MMA_F16_CASE(576, 512, 4, 4);
extern DECL_FATTN_MMA_F16_CASE(576, 512, 8, 4);

View File

@@ -38,6 +38,10 @@ void ggml_cuda_flash_attn_ext_tile(ggml_backend_cuda_context & ctx, ggml_tensor
GGML_ASSERT(V->ne[0] == K->ne[0]);
ggml_cuda_flash_attn_ext_tile_case<256, 256>(ctx, dst);
} break;
case 320: {
GGML_ASSERT(V->ne[0] == 256);
ggml_cuda_flash_attn_ext_tile_case<320, 256>(ctx, dst);
} break;
case 512: {
GGML_ASSERT(V->ne[0] == K->ne[0]);
ggml_cuda_flash_attn_ext_tile_case<512, 512>(ctx, dst);

View File

@@ -68,6 +68,8 @@ static constexpr __host__ __device__ uint32_t ggml_cuda_fattn_tile_get_config_nv
GGML_CUDA_FATTN_TILE_CONFIG_CASE(256, 256, 16, 256, 2, 64, 64)
GGML_CUDA_FATTN_TILE_CONFIG_CASE(256, 256, 32, 256, 2, 64, 64)
GGML_CUDA_FATTN_TILE_CONFIG_CASE(320, 256, 32, 256, 2, 64, 64)
GGML_CUDA_FATTN_TILE_CONFIG_CASE(512, 512, 4, 128, 2, 64, 64)
GGML_CUDA_FATTN_TILE_CONFIG_CASE(512, 512, 8, 256, 2, 64, 64)
GGML_CUDA_FATTN_TILE_CONFIG_CASE(512, 512, 16, 256, 2, 64, 64)
@@ -128,6 +130,8 @@ static constexpr __host__ __device__ uint32_t ggml_cuda_fattn_tile_get_config_nv
GGML_CUDA_FATTN_TILE_CONFIG_CASE(256, 256, 16, 256, 2, 32, 128)
GGML_CUDA_FATTN_TILE_CONFIG_CASE(256, 256, 32, 256, 2, 32, 64)
GGML_CUDA_FATTN_TILE_CONFIG_CASE(320, 256, 32, 256, 2, 32, 64)
GGML_CUDA_FATTN_TILE_CONFIG_CASE(512, 512, 4, 128, 2, 32, 64)
GGML_CUDA_FATTN_TILE_CONFIG_CASE(512, 512, 8, 256, 2, 32, 64)
GGML_CUDA_FATTN_TILE_CONFIG_CASE(512, 512, 16, 256, 2, 32, 64)
@@ -195,6 +199,8 @@ static constexpr __host__ __device__ uint32_t ggml_cuda_fattn_tile_get_config_am
GGML_CUDA_FATTN_TILE_CONFIG_CASE(256, 256, 16, 256, 2, 32, 128)
GGML_CUDA_FATTN_TILE_CONFIG_CASE(256, 256, 32, 256, 2, 32, 128)
GGML_CUDA_FATTN_TILE_CONFIG_CASE(320, 256, 32, 512, 1, 128, 64)
GGML_CUDA_FATTN_TILE_CONFIG_CASE(512, 512, 4, 128, 2, 64, 64)
GGML_CUDA_FATTN_TILE_CONFIG_CASE(512, 512, 8, 256, 2, 64, 64)
GGML_CUDA_FATTN_TILE_CONFIG_CASE(512, 512, 16, 256, 2, 64, 64)
@@ -264,6 +270,8 @@ static constexpr __host__ __device__ uint32_t ggml_cuda_fattn_tile_get_config_am
GGML_CUDA_FATTN_TILE_CONFIG_CASE(256, 256, 16, 256, 5, 32, 256)
GGML_CUDA_FATTN_TILE_CONFIG_CASE(256, 256, 32, 256, 3, 64, 128)
GGML_CUDA_FATTN_TILE_CONFIG_CASE(320, 256, 32, 256, 2, 128, 64)
GGML_CUDA_FATTN_TILE_CONFIG_CASE(512, 512, 4, 128, 2, 64, 64)
GGML_CUDA_FATTN_TILE_CONFIG_CASE(512, 512, 8, 256, 2, 64, 64)
GGML_CUDA_FATTN_TILE_CONFIG_CASE(512, 512, 16, 256, 4, 64, 64)
@@ -1144,14 +1152,16 @@ static void launch_fattn_tile_switch_ncols1(ggml_backend_cuda_context & ctx, ggm
}
}
if (Q->ne[1] > 8/ncols2) {
constexpr int cols_per_block = 16;
const int nwarps = ggml_cuda_fattn_tile_get_nthreads (DKQ, DV, cols_per_block, cc) / warp_size;
const int nbatch_fa = ggml_cuda_fattn_tile_get_nbatch_fa(DKQ, DV, cols_per_block, cc);
fattn_kernel_t fattn_kernel = flash_attn_tile<DKQ, DV, cols_per_block/ncols2, ncols2, use_logit_softcap>;
launch_fattn<DV, cols_per_block/ncols2, ncols2>
(ctx, dst, fattn_kernel, nwarps, nbytes_shared, nbatch_fa, true, true, false, warp_size);
return;
if constexpr (ncols2 <= 16) {
if (Q->ne[1] > 8/ncols2) {
constexpr int cols_per_block = 16;
const int nwarps = ggml_cuda_fattn_tile_get_nthreads (DKQ, DV, cols_per_block, cc) / warp_size;
const int nbatch_fa = ggml_cuda_fattn_tile_get_nbatch_fa(DKQ, DV, cols_per_block, cc);
fattn_kernel_t fattn_kernel = flash_attn_tile<DKQ, DV, cols_per_block/ncols2, ncols2, use_logit_softcap>;
launch_fattn<DV, cols_per_block/ncols2, ncols2>
(ctx, dst, fattn_kernel, nwarps, nbytes_shared, nbatch_fa, true, true, false, warp_size);
return;
}
}
if constexpr (ncols2 <= 8) {
@@ -1210,6 +1220,14 @@ static void launch_fattn_tile_switch_ncols2(ggml_backend_cuda_context & ctx, ggm
const int gqa_limit = nvidia && gqa_ratio <= 4 && DV <= 256 ? 16 : INT_MAX;
const bool use_gqa_opt = mask && max_bias == 0.0f && Q->ne[1] <= gqa_limit && K->ne[1] % FATTN_KQ_STRIDE == 0;
if constexpr (DKQ == 320) { // Mistral Small 4
if (use_gqa_opt && gqa_ratio % 32 == 0) {
launch_fattn_tile_switch_ncols1<DKQ, DV, 32, use_logit_softcap>(ctx, dst);
return;
}
GGML_ABORT("flash-attn tile (320/256): expected GQA ratio multiple of 32");
}
if constexpr (DKQ == 576) {
if (use_gqa_opt && gqa_ratio % 16 == 0) {
launch_fattn_tile_switch_ncols1<DKQ, DV, 16, use_logit_softcap>(ctx, dst);
@@ -1221,7 +1239,7 @@ static void launch_fattn_tile_switch_ncols2(ggml_backend_cuda_context & ctx, ggm
}
}
if constexpr (DKQ <= 512) {
if constexpr (DKQ <= 512 && DKQ != 320) {
if (use_gqa_opt && gqa_ratio % 8 == 0) {
launch_fattn_tile_switch_ncols1<DKQ, DV, 8, use_logit_softcap>(ctx, dst);
return;
@@ -1275,5 +1293,6 @@ extern DECL_FATTN_TILE_CASE( 96, 96);
extern DECL_FATTN_TILE_CASE(112, 112);
extern DECL_FATTN_TILE_CASE(128, 128);
extern DECL_FATTN_TILE_CASE(256, 256);
extern DECL_FATTN_TILE_CASE(320, 256);
extern DECL_FATTN_TILE_CASE(512, 512);
extern DECL_FATTN_TILE_CASE(576, 512);

View File

@@ -143,6 +143,22 @@ static void ggml_cuda_flash_attn_ext_mma_f16(ggml_backend_cuda_context & ctx, gg
GGML_ASSERT(V->ne[0] == 256);
ggml_cuda_flash_attn_ext_mma_f16_switch_ncols2<256, 256>(ctx, dst);
break;
case 320:
// For Mistral Small 4, go straight to the ncols1 switch (ncols2=32-only build).
GGML_ASSERT(V->ne[0] == 256);
{
float max_bias = 0.0f;
memcpy(&max_bias, (const float *) KQV->op_params + 1, sizeof(float));
const bool use_gqa_opt = mask && max_bias == 0.0f;
GGML_ASSERT(use_gqa_opt);
GGML_ASSERT(Q->ne[2] % K->ne[2] == 0);
const int gqa_ratio = Q->ne[2] / K->ne[2];
GGML_ASSERT(gqa_ratio % 32 == 0);
ggml_cuda_flash_attn_ext_mma_f16_switch_ncols1<320, 256, 32>(ctx, dst);
}
break;
case 512:
GGML_ASSERT(V->ne[0] == 512);
ggml_cuda_flash_attn_ext_mma_f16_switch_ncols2<512, 512>(ctx, dst);
@@ -352,6 +368,14 @@ static best_fattn_kernel ggml_cuda_get_best_fattn_kernel(const int device, const
return BEST_FATTN_KERNEL_NONE;
}
break;
case 320:
if (V->ne[0] != 256 || !gqa_opt_applies) {
return BEST_FATTN_KERNEL_NONE;
}
if (gqa_ratio % 32 != 0) {
return BEST_FATTN_KERNEL_NONE;
}
break;
case 512:
if (V->ne[0] != K->ne[0]) {
return BEST_FATTN_KERNEL_NONE;

View File

@@ -1015,25 +1015,35 @@ namespace ggml_cuda_mma {
#endif // AMD_MFMA_AVAILABLE
}
static __device__ __forceinline__ void mma_block_scaled(tile<16, 8, float> & D,
const tile<16, 8, int> & A,
const tile<8, 8, int> & B,
uint32_t a_scale,
uint32_t b_scale) {
template <ggml_type type>
static __device__ __forceinline__ void mma_block_scaled_fp4(tile<16, 8, float> & D,
const tile<16, 8, int> & A,
const tile<8, 8, int> & B,
uint32_t a_scale,
uint32_t b_scale) {
#ifdef BLACKWELL_MMA_AVAILABLE
const int * Axi = (const int *) A.x;
const int * Bxi = (const int *) B.x;
float * Dxi = (float *) D.x;
asm volatile(
"mma.sync.aligned.kind::mxf4.block_scale.scale_vec::2X.m16n8k64.row.col.f32.e2m1.e2m1.f32.ue8m0 "
"{%0, %1, %2, %3}, {%4, %5, %6, %7}, {%8, %9}, {%0, %1, %2, %3}, "
"%10, {0, 0}, %11, {0, 0};"
: "+f"(Dxi[0]), "+f"(Dxi[1]), "+f"(Dxi[2]), "+f"(Dxi[3])
: "r"(Axi[0]), "r"(Axi[1]), "r"(Axi[2]), "r"(Axi[3]), "r"(Bxi[0]), "r"(Bxi[1]), "r"(a_scale), "r"(b_scale));
if constexpr (type == GGML_TYPE_MXFP4) {
asm volatile(
"mma.sync.aligned.kind::mxf4.block_scale.scale_vec::2X.m16n8k64.row.col.f32.e2m1.e2m1.f32.ue8m0 "
"{%0, %1, %2, %3}, {%4, %5, %6, %7}, {%8, %9}, {%0, %1, %2, %3}, "
"%10, {0, 0}, %11, {0, 0};"
: "+f"(Dxi[0]), "+f"(Dxi[1]), "+f"(Dxi[2]), "+f"(Dxi[3])
: "r"(Axi[0]), "r"(Axi[1]), "r"(Axi[2]), "r"(Axi[3]), "r"(Bxi[0]), "r"(Bxi[1]), "r"(a_scale), "r"(b_scale));
} else {
asm volatile(
"mma.sync.aligned.kind::mxf4nvf4.block_scale.scale_vec::4X.m16n8k64.row.col.f32.e2m1.e2m1.f32.ue4m3 "
"{%0, %1, %2, %3}, {%4, %5, %6, %7}, {%8, %9}, {%0, %1, %2, %3}, "
"%10, {0, 0}, %11, {0, 0};"
: "+f"(Dxi[0]), "+f"(Dxi[1]), "+f"(Dxi[2]), "+f"(Dxi[3])
: "r"(Axi[0]), "r"(Axi[1]), "r"(Axi[2]), "r"(Axi[3]), "r"(Bxi[0]), "r"(Bxi[1]), "r"(a_scale), "r"(b_scale));
}
#else
GGML_UNUSED_VARS(D, A, B, a_scale, b_scale);
#endif // BLACKWELL_MMA_AVAILABLE
#endif // BLACKWELL_MMA_AVAILABLE
}
static __device__ __forceinline__ void mma(

View File

@@ -122,7 +122,7 @@ void ggml_cuda_mul_mat_q(
|| GGML_CUDA_CC_IS_CDNA(cc);
// TODO: tighter pool buffer size vs q8 path
const bool use_native_mxfp4 = blackwell_mma_available(cc) && src0->type == GGML_TYPE_MXFP4;
const bool use_native_fp4 = blackwell_mma_available(cc) && (src0->type == GGML_TYPE_MXFP4 || src0->type == GGML_TYPE_NVFP4);
if (!ids) {
const size_t nbytes_src1_q8_1 = ne13*ne12 * ne11*ne10_padded * sizeof(block_q8_1)/QK8_1 +
@@ -133,9 +133,9 @@ void ggml_cuda_mul_mat_q(
const int64_t s11 = src1->nb[1] / ts_src1;
const int64_t s12 = src1->nb[2] / ts_src1;
const int64_t s13 = src1->nb[3] / ts_src1;
if (use_native_mxfp4) {
if (use_native_fp4) {
static_assert(sizeof(block_fp4_mmq) == 4 * sizeof(block_q8_1));
quantize_mmq_mxfp4_cuda(src1_d, nullptr, src1_q8_1.get(), src0->type, ne10, s11, s12, s13, ne10_padded,
quantize_mmq_fp4_cuda(src1_d, nullptr, src1_q8_1.get(), src0->type, ne10, s11, s12, s13, ne10_padded,
ne11, ne12, ne13, stream);
} else {
@@ -146,10 +146,8 @@ void ggml_cuda_mul_mat_q(
}
// Stride depends on quantization format
const int64_t s12 = use_native_mxfp4 ?
ne11 * ne10_padded * sizeof(block_fp4_mmq) /
(8 * QK_MXFP4 * sizeof(int)) // block_fp4_mmq holds 256 values (8 blocks of 32)
:
const int64_t s12 = use_native_fp4 ?
ne11 * ne10_padded * sizeof(block_fp4_mmq) / (QK_K * sizeof(int)) : // block_fp4_mmq holds 256 values
ne11 * ne10_padded * sizeof(block_q8_1) / (QK8_1 * sizeof(int));
const int64_t s13 = ne12*s12;
@@ -198,8 +196,8 @@ void ggml_cuda_mul_mat_q(
const int64_t s12 = src1->nb[2] / ts_src1;
const int64_t s13 = src1->nb[3] / ts_src1;
if (use_native_mxfp4) {
quantize_mmq_mxfp4_cuda(src1_d, ids_src1.get(), src1_q8_1.get(), src0->type, ne10, s11, s12, s13,
if (use_native_fp4) {
quantize_mmq_fp4_cuda(src1_d, ids_src1.get(), src1_q8_1.get(), src0->type, ne10, s11, s12, s13,
ne10_padded, ne11_flat, ne12_flat, ne13_flat, stream);
} else {
quantize_mmq_q8_1_cuda(src1_d, ids_src1.get(), src1_q8_1.get(), src0->type, ne10, s11, s12, s13,
@@ -208,8 +206,9 @@ void ggml_cuda_mul_mat_q(
CUDA_CHECK(cudaGetLastError());
}
const int64_t s12 = use_native_mxfp4 ? ne11 * ne10_padded * sizeof(block_fp4_mmq) / (8 * QK_MXFP4 * sizeof(int)) :
ne11 * ne10_padded * sizeof(block_q8_1) / (QK8_1 * sizeof(int));
static_assert(QK_K == 8 * QK_MXFP4, "QK_K needs to be 8 * QK_MXFP4");
const int64_t s12 = use_native_fp4 ? ne11 * ne10_padded * sizeof(block_fp4_mmq) / (QK_K * sizeof(int)) :
ne11 * ne10_padded * sizeof(block_q8_1) / (QK8_1 * sizeof(int));
const int64_t s13 = ne12*s12;
// Note that ne02 is used instead of ne12 because the number of y channels determines the z dimension of the CUDA grid.

View File

@@ -10,9 +10,9 @@
using namespace ggml_cuda_mma;
#define MMQ_DP4A_MAX_BATCH_SIZE 64 // Max. batch size to use for dp4a MMQ kernels when FP16 tensor cores are available.
#define MMQ_ITER_K 256
#define MMQ_ITER_K_MXFP4_FP4 512
#define MMQ_NWARPS 8
#define MMQ_ITER_K 256
#define MMQ_ITER_K_FP4 512
#define MMQ_NWARPS 8
typedef void (*load_tiles_mmq_t)(const char * __restrict__ x, int * x_tile, const int kbx0, const int i_max, const int stride);
typedef void (*vec_dot_mmq_t)(const int * __restrict__ x, const int * __restrict__ y, float * __restrict__ sum, const int k00);
@@ -46,9 +46,12 @@ struct block_q8_1_mmq {
int8_t qs[4*QK8_1]; // 128 values quantized to 8 bit each
};
// this struct is used for fp4 data types (currently only used for Blackwell)
// mxfp4 has block size 32, each int32 of d4 contains 2 e8m0 scales in the lower 16 bits
// nvfp4 has block size 16, each int32 of d4 contains 4 ue4m3 scales
struct block_fp4_mmq {
uint32_t d4[4]; // 8 E8M0 scales (1 per 32 values), 2 packed per uint32: d4[0]={s0,s1}, d4[1]={s2,s3}, etc.
int8_t qs[4 * 32]; // 256 FP4 values packed as 4-bit pairs (2 per byte), 8 blocks of 32 values
uint32_t d4[4];
int8_t qs[4 * 32]; // 256 FP4 values packed as 4-bit pairs (2 per byte)
};
static_assert(sizeof(block_q8_1_mmq) == 4*QK8_1 + 4*sizeof(half2), "Unexpected block_q8_1_mmq size");
@@ -143,10 +146,11 @@ static int get_mmq_y_host(const int cc) {
static constexpr __device__ int get_iter_k([[maybe_unused]] const ggml_type type) {
#if defined(BLACKWELL_MMA_AVAILABLE)
return type == GGML_TYPE_MXFP4 ? MMQ_ITER_K_MXFP4_FP4 : MMQ_ITER_K;
#else
return MMQ_ITER_K;
if (type == GGML_TYPE_NVFP4 || type == GGML_TYPE_MXFP4) {
return MMQ_ITER_K_FP4;
}
#endif // defined(BLACKWELL_MMA_AVAILABLE)
return MMQ_ITER_K;
}
static constexpr __device__ int get_mmq_y_device() {
@@ -213,8 +217,8 @@ static constexpr __host__ __device__ tile_x_sizes mmq_get_dp4a_tile_x_sizes(ggml
}
#define MMQ_MMA_TILE_X_K_Q8_0 (2*MMQ_TILE_NE_K + 2*MMQ_TILE_NE_K/QI8_0 + 4)
#define MMQ_MMA_TILE_X_K_FP4 (2*MMQ_TILE_NE_K + 8 + 4) // MXFP4
#define MMQ_MMA_TILE_X_K_NVFP4 (2*MMQ_TILE_NE_K + MMQ_TILE_NE_K/2 + 4) // NVFP4
#define MMQ_MMA_TILE_X_K_FP4 (2*MMQ_TILE_NE_K + 8 + 4) // MXFP4 and NVFP4 Blackwell
#define MMQ_MMA_TILE_X_K_NVFP4 (2*MMQ_TILE_NE_K + MMQ_TILE_NE_K/2 + 4) // NVFP4 Generic
#define MMQ_MMA_TILE_X_K_Q8_1 (2*MMQ_TILE_NE_K + 2*MMQ_TILE_NE_K/QI8_0 + 4)
#define MMQ_MMA_TILE_X_K_Q2_K (2*MMQ_TILE_NE_K + MMQ_TILE_NE_K + 4)
#define MMQ_MMA_TILE_X_K_Q3_K (2*MMQ_TILE_NE_K + MMQ_TILE_NE_K/2 + 4)
@@ -240,7 +244,11 @@ static constexpr __host__ __device__ int mmq_get_mma_tile_x_k(ggml_type type) {
case GGML_TYPE_Q8_0: return MMQ_MMA_TILE_X_K_Q8_0;
// tile sizes are the same for Q8_1 and FP4 for blackwell
case GGML_TYPE_MXFP4: return MMQ_MMA_TILE_X_K_Q8_1;
#if defined(BLACKWELL_MMA_AVAILABLE)
case GGML_TYPE_NVFP4: return MMQ_MMA_TILE_X_K_FP4;
#else
case GGML_TYPE_NVFP4: return MMQ_MMA_TILE_X_K_NVFP4;
#endif // defined(BLACKWELL_MMA_AVAILABLE)
case GGML_TYPE_Q2_K: return MMQ_MMA_TILE_X_K_Q2_K;
case GGML_TYPE_Q3_K: return MMQ_MMA_TILE_X_K_Q3_K;
case GGML_TYPE_Q4_K: return MMQ_MMA_TILE_X_K_Q8_1;
@@ -934,6 +942,128 @@ static __device__ __forceinline__ void load_tiles_mxfp4_fp4(const char * __restr
}
}
#ifdef BLACKWELL_MMA_AVAILABLE
template <int mmq_y, bool need_check>
static __device__ __forceinline__ void load_tiles_nvfp4_nvfp4(const char * __restrict__ x,
int * __restrict__ x_tile,
const int kbx0,
const int i_max,
const int stride) {
constexpr int nwarps = mmq_get_nwarps_device();
constexpr int warp_size = ggml_cuda_get_physical_warp_size();
constexpr int iter_k = get_iter_k(GGML_TYPE_NVFP4);
constexpr int threads_per_row = iter_k / QK_NVFP4; // each thread processes 1 block
constexpr int rows_per_warp = warp_size / threads_per_row;
uint32_t * x_u32 = (uint32_t *) x_tile;
const int txi = threadIdx.x;
const int kbx = txi % threads_per_row;
const int row_in_warp = txi / threads_per_row;
const block_nvfp4 * bxi_base = (const block_nvfp4 *) x + kbx0 + kbx;
uint32_t * x_u32_scale = x_u32 + 64 + kbx;
#pragma unroll
for (int i0 = 0; i0 < mmq_y; i0 += rows_per_warp * nwarps) {
int i = i0 + threadIdx.y * rows_per_warp + row_in_warp;
if constexpr (need_check) {
i = min(i, i_max);
}
const block_nvfp4 * bxi = bxi_base + i * stride;
const int row_base = i * MMQ_MMA_TILE_X_K_FP4;
const int q_base = row_base + 8 * kbx;
const uint32_t * src_qs = reinterpret_cast<const uint32_t *>(bxi->qs);
#pragma unroll
for (int sub = 0; sub < QK_NVFP4 / QK_NVFP4_SUB; ++sub) {
x_u32[q_base + 2 * sub + 0] = src_qs[2 * sub + 0];
x_u32[q_base + 2 * sub + 1] = src_qs[2 * sub + 1];
}
x_u32_scale[row_base] = get_int_b4(bxi->d, 0);
}
}
// Shared MMA kernel for MXFP4 and NVFP4 on Blackwell.
// Both quantizations encode values as e2m1 (FP4) and produce one uint32 scale per
// m16n8k64 MMA call; only the PTX kind (scale_vec::2X ue8m0 vs scale_vec::4X ue4m3)
// and the per-type stride constant differ.
template <int mmq_x, int mmq_y, ggml_type type>
static __device__ __forceinline__ void vec_dot_fp4_fp4_mma(const int * __restrict__ x,
const int * __restrict__ y,
float * __restrict__ sum,
const int k00) {
static_assert(type == GGML_TYPE_MXFP4 || type == GGML_TYPE_NVFP4,
"vec_dot_fp4_fp4_mma: type must be MXFP4 or NVFP4");
typedef tile<16, 8, int> tile_A;
typedef tile<8, 8, int> tile_B;
typedef tile<16, 8, float> tile_C;
constexpr int stride = MMQ_MMA_TILE_X_K_FP4;
constexpr int granularity = mmq_get_granularity_device(mmq_x);
constexpr int rows_per_warp = 2 * granularity;
constexpr int ntx = rows_per_warp / tile_C::I;
constexpr int nfrags = MMQ_TILE_NE_K / tile_A::J;
y += (threadIdx.y % ntx) * (tile_C::J * MMQ_TILE_Y_K);
const int * x_qs = (const int *) x;
const uint32_t * x_sc = (const uint32_t *) (x_qs + 2 * MMQ_TILE_NE_K);
const int * y_qs = (const int *) y + 4;
const uint32_t * y_sc = (const uint32_t *) y;
// 2 threads per quad supply the packed scale register to the block_scale MMA,
// see https://docs.nvidia.com/cuda/parallel-thread-execution/#warp-level-block-scaling
const int tidx_A = threadIdx.x / 4 + (threadIdx.x % 2) * 8;
const int tidx_B = threadIdx.x / 4;
const int i0 = (threadIdx.y / ntx) * rows_per_warp;
tile_A A[ntx][nfrags];
uint32_t scaleA[ntx][nfrags];
#pragma unroll
for (int n = 0; n < ntx; ++n) {
#pragma unroll
for (int frag = 0; frag < nfrags; ++frag) {
const int k0 = k00 + frag * tile_A::J;
load_ldmatrix(A[n][frag], x_qs + (i0 + n * tile_A::I) * stride + k0, stride);
scaleA[n][frag] = x_sc[(i0 + n * tile_A::I + tidx_A) * stride + k0 / tile_A::J];
}
}
#pragma unroll
for (int j0 = 0; j0 < mmq_x; j0 += ntx * tile_C::J) {
tile_B B[nfrags];
uint32_t scaleB[nfrags];
#pragma unroll
for (int frag = 0; frag < nfrags; ++frag) {
const int k0 = frag * tile_B::J;
load_generic(B[frag], y_qs + j0 * MMQ_TILE_Y_K + k0, MMQ_TILE_Y_K);
scaleB[frag] = y_sc[(j0 + tidx_B) * MMQ_TILE_Y_K + frag];
}
#pragma unroll
for (int n = 0; n < ntx; ++n) {
#pragma unroll
for (int frag = 0; frag < nfrags; ++frag) {
tile_C C = {};
mma_block_scaled_fp4<type>(C, A[n][frag], B[frag], scaleA[n][frag], scaleB[frag]);
#pragma unroll
for (int l = 0; l < tile_C::ne; ++l) {
sum[(j0 / tile_C::J + n) * tile_C::ne + l] += C.x[l];
}
}
}
}
}
#endif // BLACKWELL_MMA_AVAILABLE
template <int mmq_y, bool need_check>
static __device__ __forceinline__ void load_tiles_nvfp4(const char * __restrict__ x,
@@ -1163,77 +1293,6 @@ static __device__ __forceinline__ void vec_dot_q8_0_q8_1_mma(
#endif // defined(AMD_MFMA_AVAILABLE) || defined(AMD_WMMA_AVAILABLE)
}
template <int mmq_x, int mmq_y>
static __device__ __forceinline__ void vec_dot_mxfp4_mxfp4_mma(const int * __restrict__ x,
const int * __restrict__ y,
float * __restrict__ sum,
const int k00) {
typedef tile<16, 8, int> tile_A;
typedef tile<8, 8, int> tile_B;
typedef tile<16, 8, float> tile_C; // Output is float for native scaled MMA
constexpr int granularity = mmq_get_granularity_device(mmq_x);
constexpr int rows_per_warp = 2 * granularity;
constexpr int ntx = rows_per_warp / tile_C::I; // Number of x minitiles per warp.
y += (threadIdx.y % ntx) * (tile_C::J * MMQ_TILE_Y_FP4_K);
// Match layout from load_tiles_mxfp4_fp4
const int * x_qs = (const int *) x;
const uint32_t * x_sc = (const uint32_t *) (x_qs + 2 * MMQ_TILE_NE_K);
const int * y_qs = (const int *) y + 4;
const uint32_t * y_sc = (const uint32_t *) y;
// tile_A has a length of 64 logical values vs. 32 values in block_mxfp4
tile_A A[ntx][MMQ_TILE_NE_K / (2 * QI_MXFP4)];
uint32_t scaleA[ntx][MMQ_TILE_NE_K / (2 * QI_MXFP4)];
// Block scale
// Each thread has to point to a 4 byte scale value
// https://docs.nvidia.com/cuda/parallel-thread-execution/#warp-level-block-scaling
const int i0 = (threadIdx.y / ntx) * rows_per_warp;
#pragma unroll
for (int n = 0; n < ntx; ++n) {
#pragma unroll
for (int k01 = 0; k01 < MMQ_TILE_NE_K; k01 += 2 * QI_MXFP4) {
const int k0 = k00 + k01;
load_ldmatrix(A[n][k01 / (2 * QI_MXFP4)], x_qs + (i0 + n * tile_A::I) * MMQ_MMA_TILE_X_K_FP4 + k0,
MMQ_MMA_TILE_X_K_FP4);
// based on block-scaling document, 2 threads in each quad need to supply to the scale value
const int tidx = threadIdx.x / 4 + (threadIdx.x % 2) * 8;
scaleA[n][k01 / (2 * QI_MXFP4)] =
*(x_sc + (i0 + n * tile_A::I + tidx) * MMQ_MMA_TILE_X_K_FP4 + k0 / (2 * QI_MXFP4));
}
}
#pragma unroll
for (int j0 = 0; j0 < mmq_x; j0 += ntx * tile_C::J) {
#pragma unroll
for (int k01 = 0; k01 < MMQ_TILE_NE_K; k01 += 2 * QI_MXFP4) {
tile_B B;
uint32_t scaleB; // 2xN scales
load_generic(B, y_qs + j0 * MMQ_TILE_Y_FP4_K + k01, MMQ_TILE_Y_FP4_K);
scaleB = y_sc[(j0 + threadIdx.x / 4) * MMQ_TILE_Y_FP4_K + k01 / (2 * QI_MXFP4)];
#pragma unroll
for (int n = 0; n < ntx; ++n) {
tile_C C;
mma_block_scaled(C, A[n][k01 / (2 * QI_MXFP4)], B, scaleA[n][k01 / (2 * QI_MXFP4)], scaleB);
#pragma unroll
for (int l = 0; l < tile_C::ne; ++l) {
sum[(j0 / tile_C::J + n) * tile_C::ne + l] += C.x[l];
}
}
}
}
}
template <int mmq_x, int mmq_y>
static __device__ __forceinline__ void vec_dot_q8_1_q8_1_dp4a(
@@ -3259,7 +3318,7 @@ struct mmq_type_traits<mmq_x, mmq_y, need_check, GGML_TYPE_MXFP4> {
static constexpr int vdr = VDR_MXFP4_Q8_1_MMQ;
#ifdef BLACKWELL_MMA_AVAILABLE
static constexpr load_tiles_mmq_t load_tiles = load_tiles_mxfp4_fp4<mmq_y, need_check>;
static constexpr vec_dot_mmq_t vec_dot_mma = vec_dot_mxfp4_mxfp4_mma<mmq_x, mmq_y>;
static constexpr vec_dot_mmq_t vec_dot_mma = vec_dot_fp4_fp4_mma<mmq_x, mmq_y, GGML_TYPE_MXFP4>;
#else
static constexpr load_tiles_mmq_t load_tiles = load_tiles_mxfp4<mmq_y, need_check>;
static constexpr vec_dot_mmq_t vec_dot_mma = vec_dot_q8_0_q8_1_mma<mmq_x, mmq_y, MMQ_Q8_1_DS_LAYOUT_D4>;
@@ -3270,8 +3329,13 @@ struct mmq_type_traits<mmq_x, mmq_y, need_check, GGML_TYPE_MXFP4> {
template <int mmq_x, int mmq_y, bool need_check>
struct mmq_type_traits<mmq_x, mmq_y, need_check, GGML_TYPE_NVFP4> {
static constexpr int vdr = VDR_NVFP4_Q8_1_MMQ;
#ifdef BLACKWELL_MMA_AVAILABLE
static constexpr load_tiles_mmq_t load_tiles = load_tiles_nvfp4_nvfp4<mmq_y, need_check>;
static constexpr vec_dot_mmq_t vec_dot_mma = vec_dot_fp4_fp4_mma<mmq_x, mmq_y, GGML_TYPE_NVFP4>;
#else
static constexpr load_tiles_mmq_t load_tiles = load_tiles_nvfp4<mmq_y, need_check>;
static constexpr vec_dot_mmq_t vec_dot_mma = vec_dot_q8_0_16_q8_1_mma<mmq_x, mmq_y>;
#endif // BLACKWELL_MMA_AVAILABLE
static constexpr vec_dot_mmq_t vec_dot_dp4a = vec_dot_q8_0_16_q8_1_dp4a<mmq_x, mmq_y>;
};
@@ -3406,7 +3470,7 @@ static __device__ __forceinline__ void mul_mat_q_process_tile(
#if defined(BLACKWELL_MMA_AVAILABLE)
// FP4 tile stores 8 blocks
constexpr int ne_block = (type == GGML_TYPE_MXFP4) ? 8 * QK_MXFP4 : 4 * QK8_1;
constexpr int ne_block = (type == GGML_TYPE_MXFP4 || type == GGML_TYPE_NVFP4) ? QK_K : 4 * QK8_1;
#else
constexpr int ne_block = 4 * QK8_1;
#endif // defined(BLACKWELL_MMA_AVAILABLE)

View File

@@ -115,6 +115,7 @@ static constexpr __host__ __device__ int get_mmvq_mmid_max_batch_pascal_older(gg
case GGML_TYPE_IQ4_NL: return 6;
case GGML_TYPE_IQ4_XS: return 5;
case GGML_TYPE_MXFP4: return 4;
case GGML_TYPE_NVFP4: return 4;
case GGML_TYPE_Q2_K: return 4;
case GGML_TYPE_Q3_K: return 4;
case GGML_TYPE_Q4_0: return 6;
@@ -135,6 +136,7 @@ static constexpr __host__ __device__ int get_mmvq_mmid_max_batch_turing_plus(ggm
case GGML_TYPE_IQ3_S: return 6;
case GGML_TYPE_IQ3_XXS: return 7;
case GGML_TYPE_MXFP4: return 7;
case GGML_TYPE_NVFP4: return 8;
case GGML_TYPE_Q2_K: return 7;
case GGML_TYPE_Q3_K: return 5;
default: return MMVQ_MAX_BATCH_SIZE;
@@ -221,6 +223,7 @@ static constexpr __host__ __device__ int get_mmvq_mmid_max_batch_rdna4(ggml_type
case GGML_TYPE_IQ4_NL: return 7;
case GGML_TYPE_IQ4_XS: return 5;
case GGML_TYPE_MXFP4: return 5;
case GGML_TYPE_NVFP4: return 5;
case GGML_TYPE_Q3_K: return 4;
case GGML_TYPE_Q4_0: return 7;
case GGML_TYPE_Q4_1: return 7;

View File

@@ -70,6 +70,102 @@ __device__ __forceinline__ uint8_t compute_e8m0_scale(float amax) {
return static_cast<uint8_t>(biased);
}
static __global__ void quantize_mmq_nvfp4(
const float * __restrict__ x, const int32_t * __restrict__ ids, void * __restrict__ vy,
const int64_t ne00, const int64_t s01, const int64_t s02, const int64_t s03,
const int64_t ne0, const int64_t ne1, const int64_t ne2) {
#if defined(BLACKWELL_MMA_AVAILABLE)
const int64_t i0_base = ((int64_t) blockDim.x * blockIdx.y + threadIdx.x) * QK_NVFP4_SUB;
if (i0_base >= ne0) {
return;
}
const int64_t i1 = blockIdx.x;
const int64_t i2 = blockIdx.z % ne2;
const int64_t i3 = blockIdx.z / ne2;
const int64_t i01 = ids ? ids[i1] : i1;
const int64_t k_block = i0_base / QK_K;
const int64_t blocks_per_col = (ne0 + QK_K - 1) / QK_K;
if (k_block >= blocks_per_col) {
return;
}
const int64_t ib = blockIdx.z * ((int64_t) blocks_per_col * ne1) + k_block * ne1 + blockIdx.x;
block_fp4_mmq * y = (block_fp4_mmq *) vy;
block_fp4_mmq * yb = y + ib;
const int sub = (i0_base % QK_K) / QK_NVFP4_SUB;
float vals_raw[QK_NVFP4_SUB];
float amax_raw = 0.0f;
const int64_t base_idx = i3 * s03 + i2 * s02 + i01 * s01;
#pragma unroll
for (int k = 0; k < QK_NVFP4_SUB; k++) {
const int64_t i00 = i0_base + k;
if (i00 < ne00) {
const float v = x[base_idx + i00];
vals_raw[k] = v;
amax_raw = fmaxf(amax_raw, fabsf(v));
} else {
vals_raw[k] = 0.0f;
}
}
static constexpr int test_offsets[5] = { 0, -1, 1, -2, 2};
const int first_fp8_code = (int) ggml_cuda_fp32_to_ue4m3(amax_raw / 6.0f);
float best_err = FLT_MAX;
uint8_t fp8_code = 0;
float subblock_scale = 0.0f;
#pragma unroll // Check +/- 2 to find best code to reduce NVFP4 activation loss. Negligible overhead on Blackwell.
for (int i = 0; i < 5; i++) {
const int test_code = first_fp8_code + test_offsets[i];
if (test_code < 0 || test_code > 0x7e) {
continue;
}
const uint8_t code = (uint8_t) test_code;
const float test_scale = ggml_cuda_ue4m3_to_fp32(code);
const float test_inv_scale = test_scale > 0.0f ? 0.5f / test_scale : 0.0f;
float cur_err = 0.0f;
#pragma unroll
for (int k = 0; k < QK_NVFP4_SUB; ++k) {
const float v = vals_raw[k];
const uint8_t q = ggml_cuda_float_to_fp4_e2m1(v, test_inv_scale);
const float err_diff = fabsf(v) - fabsf(kvalues_mxfp4[q & 0x7]) * test_scale;
cur_err = fmaf(err_diff, err_diff, cur_err);
}
if (cur_err < best_err) {
best_err = cur_err;
fp8_code = test_code;
subblock_scale = test_scale;
}
}
const float inv_scale = subblock_scale > 0.0f ? 0.5f / subblock_scale : 0.0f;
uint32_t q0 = 0;
uint32_t q1 = 0;
#pragma unroll // this is faster than the previous __nv_fp4x4_e2m1
for (int k = 0; k < QK_NVFP4_SUB / 4; ++k) {
q0 |= (uint32_t) ggml_cuda_float_to_fp4_e2m1(vals_raw[k + 0], inv_scale) << (8 * k);
q0 |= (uint32_t) ggml_cuda_float_to_fp4_e2m1(vals_raw[k + 8], inv_scale) << (8 * k + 4);
q1 |= (uint32_t) ggml_cuda_float_to_fp4_e2m1(vals_raw[k + 4], inv_scale) << (8 * k);
q1 |= (uint32_t) ggml_cuda_float_to_fp4_e2m1(vals_raw[k + 12], inv_scale) << (8 * k + 4);
}
uint32_t * yqs = reinterpret_cast<uint32_t *>(yb->qs);
yqs[2 * sub + 0] = q0;
yqs[2 * sub + 1] = q1;
reinterpret_cast<uint8_t *>(yb->d4)[sub] = fp8_code;
#else
NO_DEVICE_CODE; // This is for Blackwell NVFP4 activations only.
#endif // defined(BLACKWELL_MMA_AVAILABLE)
}
// quantize values in the format mxfp4 is stored which is interleaved nibbles
// i.e. a block a0-a31 is represented as a0a16,a1a17 ...a15a31
static __global__ void quantize_mmq_mxfp4(const float * __restrict__ x,
@@ -316,28 +412,32 @@ void quantize_mmq_q8_1_cuda(
}
}
void quantize_mmq_mxfp4_cuda(const float * x,
const int32_t * ids,
void * vy,
[[maybe_unused]] const ggml_type type_src0,
const int64_t ne00,
const int64_t s01,
const int64_t s02,
const int64_t s03,
const int64_t ne0,
const int64_t ne1,
const int64_t ne2,
const int64_t ne3,
cudaStream_t stream) {
GGML_ASSERT(ne0 % (2 * QK_MXFP4) == 0);
void quantize_mmq_fp4_cuda(
const float * x, const int32_t * ids, void * vy, const ggml_type type_src0,
const int64_t ne00, const int64_t s01, const int64_t s02, const int64_t s03,
const int64_t ne0, const int64_t ne1, const int64_t ne2, const int64_t ne3, cudaStream_t stream) {
GGML_ASSERT(type_src0 == GGML_TYPE_MXFP4 || type_src0 == GGML_TYPE_NVFP4);
GGML_ASSERT(ne0 > 0);
constexpr int nwarps = 8;
constexpr int vals_per_warp = 2 * QK_MXFP4;
constexpr int vals_per_block = nwarps * vals_per_warp;
if (type_src0 == GGML_TYPE_NVFP4) {
GGML_ASSERT(ne00 % QK_NVFP4 == 0);
constexpr int nvfp4_block_size = 128;
const int64_t block_num_y = (ne0 + QK_NVFP4_SUB * nvfp4_block_size - 1) / (QK_NVFP4_SUB * nvfp4_block_size);
const dim3 block_size(nvfp4_block_size, 1, 1);
const dim3 num_blocks(ne1, block_num_y, ne2 * ne3);
quantize_mmq_nvfp4<<<num_blocks, block_size, 0, stream>>>(
x, ids, vy, ne00, s01, s02, s03, ne0, ne1, ne2);
} else {
GGML_ASSERT(ne0 % (2 * QK_MXFP4) == 0);
const int64_t block_num_y = (ne0 + vals_per_block - 1) / vals_per_block;
const dim3 num_blocks(ne1, block_num_y, ne2 * ne3);
const dim3 block_size(WARP_SIZE, nwarps, 1);
constexpr int nwarps = 8;
constexpr int vals_per_warp = 2 * QK_MXFP4;
constexpr int vals_per_block = nwarps * vals_per_warp;
quantize_mmq_mxfp4<<<num_blocks, block_size, 0, stream>>>(x, ids, vy, ne00, s01, s02, s03, ne0, ne1, ne2);
const int64_t block_num_y = (ne0 + vals_per_block - 1) / vals_per_block;
const dim3 num_blocks(ne1, block_num_y, ne2 * ne3);
const dim3 block_size(WARP_SIZE, nwarps, 1);
quantize_mmq_mxfp4<<<num_blocks, block_size, 0, stream>>>(x, ids, vy, ne00, s01, s02, s03, ne0, ne1, ne2);
}
}

View File

@@ -26,7 +26,7 @@ void quantize_mmq_q8_1_cuda(
ggml_type type_src0, int64_t ne00, int64_t s01, int64_t s02, int64_t s03,
int64_t ne0, int64_t ne1, int64_t ne2, int64_t ne3, cudaStream_t stream);
void quantize_mmq_mxfp4_cuda(const float * x,
void quantize_mmq_fp4_cuda(const float * x,
const int32_t * ids,
void * vy,
ggml_type type_src0,

View File

@@ -2,4 +2,5 @@
#include "../fattn-mma-f16.cuh"
DECL_FATTN_MMA_F16_CASE(320, 256, 1, 32);
DECL_FATTN_MMA_F16_CASE(576, 512, 1, 32);

View File

@@ -2,4 +2,5 @@
#include "../fattn-mma-f16.cuh"
DECL_FATTN_MMA_F16_CASE(320, 256, 2, 32);
DECL_FATTN_MMA_F16_CASE(576, 512, 2, 32);

View File

@@ -0,0 +1,5 @@
// This file has been autogenerated by generate_cu_files.py, do not edit manually.
#include "../fattn-tile.cuh"
DECL_FATTN_TILE_CASE(320, 256);

View File

@@ -3,7 +3,7 @@
from glob import glob
import os
HEAD_SIZES_KQ = [40, 64, 72, 80, 96, 112, 128, 256, 512, 576]
HEAD_SIZES_KQ = [40, 64, 72, 80, 96, 112, 128, 256, 320, 512, 576]
TYPES_KV = ["GGML_TYPE_F16", "GGML_TYPE_Q4_0", "GGML_TYPE_Q4_1", "GGML_TYPE_Q5_0", "GGML_TYPE_Q5_1", "GGML_TYPE_Q8_0", "GGML_TYPE_BF16"]
@@ -62,7 +62,7 @@ for filename in glob("*.cu"):
os.remove(filename)
for head_size_kq in HEAD_SIZES_KQ:
head_size_v = head_size_kq if head_size_kq != 576 else 512
head_size_v = 256 if head_size_kq == 320 else (head_size_kq if head_size_kq != 576 else 512)
with open(f"fattn-tile-instance-dkq{head_size_kq}-dv{head_size_v}.cu", "w") as f:
f.write(SOURCE_FATTN_TILE.format(head_size_kq=head_size_kq, head_size_v=head_size_v))
@@ -84,13 +84,16 @@ for ncols in [8, 16, 32, 64]:
continue
if head_size_kq == 72:
continue
if head_size_kq == 512 and ncols2 not in (4, 8):
# Skip compilation of unused ncols2 values for niche head sizes:
if head_size_kq == 320 and ncols2 != 32: # Mistral Small 4
continue
if head_size_kq != 576 and ncols2 in (16, 32):
if head_size_kq == 512 and ncols2 not in (4, 8): # Gemma 4
continue
if head_size_kq == 576 and ncols2 not in (4, 16, 32):
if head_size_kq == 576 and ncols2 not in (4, 16, 32): # Deepseek, GLM 4.7 Flash
continue
head_size_v = head_size_kq if head_size_kq != 576 else 512
if head_size_kq not in (320, 576) and ncols2 in (16, 32):
continue
head_size_v = 256 if head_size_kq == 320 else (head_size_kq if head_size_kq != 576 else 512)
f.write(SOURCE_FATTN_MMA_CASE.format(ncols1=ncols1, ncols2=ncols2, head_size_kq=head_size_kq, head_size_v=head_size_v))
for type in TYPES_MMQ:

View File

@@ -296,13 +296,22 @@ vec2 get_dm_scale(uint ib, uint iqs) {
const uint ib_k = ib / 8;
const uint iqs_k = (ib % 8) * 8 + iqs;
const uint is = iqs_k / 8;
u8vec2 scale_dm;
if (is < 4) {
scale_dm = u8vec2(data_a[ib_k].scales[is] & 0x3F, data_a[ib_k].scales[is + 4] & 0x3F);
} else {
scale_dm = u8vec2((data_a[ib_k].scales[is+4] & 0xF) | ((data_a[ib_k].scales[is-4] & 0xC0) >> 2),
(data_a[ib_k].scales[is+4] >> 4) | ((data_a[ib_k].scales[is ] & 0xC0) >> 2));
}
const uvec3 scales = uvec3(data_a_packed32[ib_k].scales[0],
data_a_packed32[ib_k].scales[1],
data_a_packed32[ib_k].scales[2]);
const uint scalesoffs = (is & 3) * 8;
const uint scidx0 = (is < 4) ? 0 : 2;
const uint scidxshift0 = scalesoffs;
const uint scidxshift1 = (is < 4) ? scalesoffs : scalesoffs + 2;
const uint mbidx0 = (is < 4) ? 1 : 2;
const uint mbidxshift0 = (is < 4) ? scalesoffs : scalesoffs + 4;
const uint mbidxshift1 = (is < 4) ? scalesoffs : scalesoffs + 2;
const uint8_t sc = uint8_t(((scales[scidx0] >> scidxshift0) & 0xF) | ((scales[0] >> scidxshift1) & 0x30));
const uint8_t mbyte = uint8_t(((scales[mbidx0] >> mbidxshift0) & 0xF) | ((scales[1] >> mbidxshift1) & 0x30));
u8vec2 scale_dm = u8vec2(sc, mbyte);
return FLOAT_TYPEV2(data_a_packed32[ib_k].dm) * FLOAT_TYPEV2(scale_dm);
}

View File

@@ -201,19 +201,20 @@ void load_a_to_shmem(const uint pos_a, const uint row, const uint col, const uin
const vec2 loadd = vec2(data_a[ib].dm);
const uint scidx0 = (is < 4) ? is : (is + 4);
const uint scidx1 = (is < 4) ? is : (is - 4);
const uint scidxmask1 = (is < 4) ? 0x30 : 0xC0;
const uint scidxshift1 = (is < 4) ? 0 : 2;
const uint mbidx0 = is + 4;
const uint mbidx1 = (is < 4) ? is + 4 : is;
const uint mbidxmask0 = (is < 4) ? 0xF : 0xF0;
const uint mbidxshift0 = (is < 4) ? 0 : 4;
const uint mbidxmask1 = (is < 4) ? 0x30 : 0xC0;
const uint mbidxshift1 = (is < 4) ? 0 : 2;
const uvec3 scales = uvec3(data_a_packed32[ib].scales[0],
data_a_packed32[ib].scales[1],
data_a_packed32[ib].scales[2]);
const uint scalesoffs = (is & 3) * 8;
const uint8_t sc = uint8_t((data_a[ib].scales[scidx0] & 0xF) | ((data_a[ib].scales[scidx1] & scidxmask1) >> scidxshift1));
const uint8_t mbyte = uint8_t((data_a[ib].scales[mbidx0] & mbidxmask0) >> mbidxshift0 | ((data_a[ib].scales[mbidx1] & mbidxmask1) >> mbidxshift1));
const uint scidx0 = (is < 4) ? 0 : 2;
const uint scidxshift0 = scalesoffs;
const uint scidxshift1 = (is < 4) ? scalesoffs : scalesoffs + 2;
const uint mbidx0 = (is < 4) ? 1 : 2;
const uint mbidxshift0 = (is < 4) ? scalesoffs : scalesoffs + 4;
const uint mbidxshift1 = (is < 4) ? scalesoffs : scalesoffs + 2;
const uint8_t sc = uint8_t(((scales[scidx0] >> scidxshift0) & 0xF) | ((scales[0] >> scidxshift1) & 0x30));
const uint8_t mbyte = uint8_t(((scales[mbidx0] >> mbidxshift0) & 0xF) | ((scales[1] >> mbidxshift1) & 0x30));
const float d = loadd.x * sc;
const float m = -loadd.y * mbyte;
@@ -237,19 +238,20 @@ void load_a_to_shmem(const uint pos_a, const uint row, const uint col, const uin
const vec2 loadd = vec2(data_a[ib].dm);
const uint scidx0 = (is < 4) ? is : (is + 4);
const uint scidx1 = (is < 4) ? is : (is - 4);
const uint scidxmask1 = (is < 4) ? 0x30 : 0xC0;
const uint scidxshift1 = (is < 4) ? 0 : 2;
const uint mbidx0 = is + 4;
const uint mbidx1 = (is < 4) ? is + 4 : is;
const uint mbidxmask0 = (is < 4) ? 0xF : 0xF0;
const uint mbidxshift0 = (is < 4) ? 0 : 4;
const uint mbidxmask1 = (is < 4) ? 0x30 : 0xC0;
const uint mbidxshift1 = (is < 4) ? 0 : 2;
const uvec3 scales = uvec3(data_a_packed32[ib].scales[0],
data_a_packed32[ib].scales[1],
data_a_packed32[ib].scales[2]);
const uint scalesoffs = (is & 3) * 8;
const uint8_t sc = uint8_t((data_a[ib].scales[scidx0] & 0xF) | ((data_a[ib].scales[scidx1] & scidxmask1) >> scidxshift1));
const uint8_t mbyte = uint8_t(((data_a[ib].scales[mbidx0] & mbidxmask0) >> mbidxshift0) | ((data_a[ib].scales[mbidx1] & mbidxmask1) >> mbidxshift1));
const uint scidx0 = (is < 4) ? 0 : 2;
const uint scidxshift0 = scalesoffs;
const uint scidxshift1 = (is < 4) ? scalesoffs : scalesoffs + 2;
const uint mbidx0 = (is < 4) ? 1 : 2;
const uint mbidxshift0 = (is < 4) ? scalesoffs : scalesoffs + 4;
const uint mbidxshift1 = (is < 4) ? scalesoffs : scalesoffs + 2;
const uint8_t sc = uint8_t(((scales[scidx0] >> scidxshift0) & 0xF) | ((scales[0] >> scidxshift1) & 0x30));
const uint8_t mbyte = uint8_t(((scales[mbidx0] >> mbidxshift0) & 0xF) | ((scales[1] >> mbidxshift1) & 0x30));
const float d = loadd.x * sc;
const float m = -loadd.y * mbyte;

View File

@@ -3815,7 +3815,7 @@ struct test_mul_mat : public test_case {
double max_nmse_err(ggml_backend_t backend) override {
// for blackwell we quantize activations to mxfp4 instead of q8_1 so we add higher tolerance
if (type_a == GGML_TYPE_MXFP4 && backend_has_feature(backend, "BLACKWELL_NATIVE_FP4")) {
if ((type_a == GGML_TYPE_MXFP4 || type_a == GGML_TYPE_NVFP4) && backend_has_feature(backend, "BLACKWELL_NATIVE_FP4")) {
return 2e-2;
}
return max_nmse_err();
@@ -3951,7 +3951,7 @@ struct test_mul_mat_id : public test_case {
double max_nmse_err(ggml_backend_t backend) override {
// for blackwell we quantize activations to mxfp4 instead of q8_1 so we add higher tolerance
if (type_a == GGML_TYPE_MXFP4 && backend_has_feature(backend, "BLACKWELL_NATIVE_FP4")) {
if ((type_a == GGML_TYPE_MXFP4 || type_a == GGML_TYPE_NVFP4) && backend_has_feature(backend, "BLACKWELL_NATIVE_FP4")) {
return 2e-2;
}
return max_nmse_err();

View File

@@ -227,7 +227,30 @@ int main(void) {
3); // forcing continues through i=3
}
printf("OK (5 tests passed)\n");
// Test 6: Multi-block thinking. First block ends naturally at i=2, second
// start tag at i=3 re-arms the budget, which then exhausts at i=5.
// Regression: before this fix, DONE absorbed all subsequent tokens and a
// second <think> block ran unbudgeted.
// Flow: i=0 accept(100)->COUNTING rem=2; i=1 accept(50)->rem=1;
// i=2 accept(101)->end_matcher matches, DONE;
// i=3 accept(100)->re-arm, COUNTING rem=2;
// i=4 accept(60)->rem=1; i=5 accept(61)->rem=0->FORCING;
// i=6 apply()->forces token[0]=102, accept(62)->force_pos=1, stay FORCING;
// i=7 apply()->forces token[1]=101, accept(63)->force_pos=2->DONE.
{
const std::vector<llama_token> start = {100};
const std::vector<llama_token> end = {101};
const std::vector<llama_token> forced = {102, 101};
const std::vector<llama_token> sequence = {100, 50, 101, 100, 60, 61, 62, 63};
test_reasoning_budget("multi-block re-arms budget after DONE", sequence, start, end, forced,
2, // budget of 2 tokens (per block)
REASONING_BUDGET_IDLE,
6, // forcing starts at i=6 (after second block exhausts at i=5)
7); // forcing continues through i=7
}
printf("OK (6 tests passed)\n");
printf("Testing UTF-8 boundary detection... ");
test_utf8_boundary_detection();