mirror of
https://github.com/ggml-org/llama.cpp.git
synced 2026-04-29 06:14:19 +02:00
Compare commits
3 Commits
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
fc2b0053ff | ||
|
|
7b8443ac78 | ||
|
|
5d56effdee |
@@ -728,6 +728,9 @@ class ModelBase:
|
||||
|
||||
del experts, merged
|
||||
|
||||
def _needs_nvfp4_processing(self) -> bool:
|
||||
return True
|
||||
|
||||
def prepare_tensors(self):
|
||||
# detect NVFP4 quantization (ModelOpt format)
|
||||
quant_algo = (self.hparams.get("quantization_config") or {}).get("quant_algo")
|
||||
@@ -758,7 +761,7 @@ class ModelBase:
|
||||
# NVFP4 weights are repacked and written directly to gguf_writer.
|
||||
# This must run before dequant_model so NVFP4 tensors are removed
|
||||
# from model_tensors, leaving only non-NVFP4 (e.g. FP8) for dequant.
|
||||
if self._is_nvfp4:
|
||||
if self._is_nvfp4 and self._needs_nvfp4_processing():
|
||||
self._generate_nvfp4_tensors()
|
||||
|
||||
self.dequant_model()
|
||||
@@ -2190,6 +2193,10 @@ class MmprojModel(ModelBase):
|
||||
# merge configs
|
||||
self.preprocessor_config = {**self.preprocessor_config, **cfg}
|
||||
|
||||
def _needs_nvfp4_processing(self) -> bool:
|
||||
# nvfp4 quantization applies to the text model only.
|
||||
return False
|
||||
|
||||
def get_vision_config(self) -> dict[str, Any] | None:
|
||||
config_name = "vision_config" if not self.is_mistral_format else "vision_encoder"
|
||||
return self.global_config.get(config_name)
|
||||
@@ -4450,6 +4457,12 @@ class NemotronNanoV2VLModel(MmprojModel):
|
||||
}
|
||||
return vision_config
|
||||
|
||||
def dequant_model(self):
|
||||
if self._is_nvfp4:
|
||||
# Skip nvfp4 quantization for vision/audio model.
|
||||
return
|
||||
super().dequant_model()
|
||||
|
||||
def set_gguf_parameters(self):
|
||||
if "image_mean" not in self.preprocessor_config:
|
||||
self.preprocessor_config["image_mean"] = [0.485, 0.456, 0.406]
|
||||
@@ -4473,6 +4486,10 @@ class NemotronNanoV2VLModel(MmprojModel):
|
||||
if "input_conditioner" in name:
|
||||
return
|
||||
|
||||
# mtmd does not support video yet so skip tensors related to video.
|
||||
if "radio_model.model.patch_generator.video_embedder" in name:
|
||||
return
|
||||
|
||||
# RADIO's pos_embed doesn't have .weight suffix, but clip.cpp expects it
|
||||
if "patch_generator.pos_embed" in name:
|
||||
if not name.endswith(".weight"):
|
||||
@@ -10820,7 +10837,11 @@ class NemotronHModel(GraniteHybridModel):
|
||||
# uses self.model_arch to build the tensor name map, and all MoE-specific
|
||||
# mappings would be missed if it were called with the default non-MoE arch.
|
||||
hparams = ModelBase.load_hparams(args[0], self.is_mistral_format)
|
||||
if "num_experts_per_tok" in hparams:
|
||||
has_moe_params = (
|
||||
"num_experts_per_tok" in hparams
|
||||
or (isinstance(hparams.get("llm_config"), dict) and "num_experts_per_tok" in hparams["llm_config"])
|
||||
)
|
||||
if has_moe_params:
|
||||
self.model_arch = gguf.MODEL_ARCH.NEMOTRON_H_MOE
|
||||
self.is_moe = True
|
||||
|
||||
@@ -10967,6 +10988,11 @@ class NemotronHModel(GraniteHybridModel):
|
||||
if name.startswith(("vision_model.", "mlp1.")):
|
||||
return
|
||||
|
||||
if name.startswith(("sound_encoder.")):
|
||||
return
|
||||
if name.startswith(("sound_projection.")):
|
||||
return
|
||||
|
||||
# Strip language_model. prefix for VLM models (e.g., Nemotron Nano 12B v2 VL)
|
||||
if name.startswith("language_model."):
|
||||
name = name[len("language_model."):]
|
||||
|
||||
@@ -830,6 +830,18 @@ static __device__ __forceinline__ float ggml_cuda_ue4m3_to_fp32(uint8_t x) {
|
||||
#endif // defined(GGML_USE_HIP) && defined(CDNA3) && defined(FP8_AVAILABLE) && HIP_VERSION >= 60200000
|
||||
}
|
||||
|
||||
static __device__ __forceinline__ uint8_t ggml_cuda_fp32_to_ue4m3(float x) {
|
||||
#if defined(BLACKWELL_MMA_AVAILABLE) // This is used for NVFP4 subblock scale quantizations only
|
||||
if (!(x > 0.0f)) {
|
||||
return 0;
|
||||
}
|
||||
const __nv_fp8_e4m3 xf(x);
|
||||
return xf.__x;
|
||||
#else
|
||||
NO_DEVICE_CODE; // Used only for NVFP4 Scales for Activations, only for Blackwell
|
||||
#endif // defined(BLACKWELL_MMA_AVAILABLE)
|
||||
}
|
||||
|
||||
__device__ __forceinline__ uint8_t ggml_cuda_float_to_fp4_e2m1(float x, float e) {
|
||||
const uint8_t sign_bit = (x < 0.0f) << 3;
|
||||
float ax = fabsf(x) * e;
|
||||
|
||||
@@ -66,6 +66,9 @@ static constexpr __host__ __device__ fattn_mma_config ggml_cuda_fattn_mma_get_co
|
||||
GGML_CUDA_FATTN_MMA_CONFIG_CASE(256, 256, 32, 128, 2, 32, 128, 128, 128, 2, true);
|
||||
GGML_CUDA_FATTN_MMA_CONFIG_CASE(256, 256, 64, 128, 2, 32, 128, 128, 128, 2, true);
|
||||
|
||||
GGML_CUDA_FATTN_MMA_CONFIG_CASE(320, 256, 32, 128, 2, 32, 128, 128, 128, 1, false);
|
||||
GGML_CUDA_FATTN_MMA_CONFIG_CASE(320, 256, 64, 256, 1, 32, 128, 128, 128, 1, false);
|
||||
|
||||
GGML_CUDA_FATTN_MMA_CONFIG_CASE(512, 512, 8, 64, 4, 32, 256, 256, 128, 1, false);
|
||||
GGML_CUDA_FATTN_MMA_CONFIG_CASE(512, 512, 16, 64, 4, 32, 256, 256, 128, 1, false);
|
||||
GGML_CUDA_FATTN_MMA_CONFIG_CASE(512, 512, 32, 128, 2, 32, 128, 128, 128, 1, false);
|
||||
@@ -85,6 +88,9 @@ static constexpr __host__ __device__ fattn_mma_config ggml_cuda_fattn_mma_get_co
|
||||
GGML_CUDA_FATTN_MMA_CONFIG_CASE(256, 256, 32, 128, 2, 64, 128, 128, 64, 2, true);
|
||||
GGML_CUDA_FATTN_MMA_CONFIG_CASE(256, 256, 64, 128, 2, 64, 128, 128, 64, 2, true);
|
||||
|
||||
GGML_CUDA_FATTN_MMA_CONFIG_CASE(320, 256, 32, 128, 2, 32, 128, 128, 128, 1, false);
|
||||
GGML_CUDA_FATTN_MMA_CONFIG_CASE(320, 256, 64, 256, 1, 32, 128, 128, 128, 1, false);
|
||||
|
||||
GGML_CUDA_FATTN_MMA_CONFIG_CASE(512, 512, 8, 64, 4, 32, 96, 64, 128, 1, false);
|
||||
GGML_CUDA_FATTN_MMA_CONFIG_CASE(512, 512, 16, 64, 4, 32, 96, 64, 128, 1, false);
|
||||
GGML_CUDA_FATTN_MMA_CONFIG_CASE(512, 512, 32, 128, 2, 32, 128, 128, 128, 1, false);
|
||||
@@ -118,6 +124,9 @@ static constexpr __host__ __device__ fattn_mma_config ggml_cuda_fattn_mma_get_co
|
||||
GGML_CUDA_FATTN_MMA_CONFIG_CASE(256, 256, 32, 128, 2, 64, 128, 128, 64, 2, true);
|
||||
GGML_CUDA_FATTN_MMA_CONFIG_CASE(256, 256, 64, 128, 2, 64, 128, 128, 64, 2, true);
|
||||
|
||||
GGML_CUDA_FATTN_MMA_CONFIG_CASE(320, 256, 32, 128, 2, 64, 160, 128, 64, 2, true);
|
||||
GGML_CUDA_FATTN_MMA_CONFIG_CASE(320, 256, 64, 128, 2, 64, 160, 128, 64, 2, false);
|
||||
|
||||
GGML_CUDA_FATTN_MMA_CONFIG_CASE(512, 512, 16, 64, 4, 32, 128, 128, 128, 1, false);
|
||||
GGML_CUDA_FATTN_MMA_CONFIG_CASE(512, 512, 32, 128, 2, 32, 128, 128, 128, 1, false);
|
||||
GGML_CUDA_FATTN_MMA_CONFIG_CASE(512, 512, 64, 256, 1, 32, 128, 128, 128, 1, false);
|
||||
@@ -1217,7 +1226,7 @@ static __device__ __forceinline__ void flash_attn_ext_f16_process_tile(
|
||||
float KQ_max_scale[cols_per_thread];
|
||||
#pragma unroll
|
||||
for (int col = 0; col < cols_per_thread; ++col) {
|
||||
const int jc = cols_per_warp == 8 ? T_C_KQ::get_j(col) : T_C_KQ::get_i(2*col);
|
||||
const int jc = (threadIdx.y/np)*cols_per_warp + (cols_per_warp == 8 ? T_C_KQ::get_j(col) : T_C_KQ::get_i(2*col));
|
||||
const float sink = sinks_f[jc % ncols2];
|
||||
|
||||
const float KQ_max_new = fmaxf(KQ_max[col], sink);
|
||||
@@ -1825,6 +1834,10 @@ extern DECL_FATTN_MMA_F16_CASE(576, 512, 1, 16);
|
||||
extern DECL_FATTN_MMA_F16_CASE(576, 512, 2, 16);
|
||||
extern DECL_FATTN_MMA_F16_CASE(576, 512, 4, 16);
|
||||
|
||||
// Mistral Small 4 (DKQ=320, DV=256), GQA=32-only build:
|
||||
extern DECL_FATTN_MMA_F16_CASE(320, 256, 1, 32);
|
||||
extern DECL_FATTN_MMA_F16_CASE(320, 256, 2, 32);
|
||||
|
||||
// For GLM 4.7 Flash
|
||||
extern DECL_FATTN_MMA_F16_CASE(576, 512, 4, 4);
|
||||
extern DECL_FATTN_MMA_F16_CASE(576, 512, 8, 4);
|
||||
|
||||
@@ -38,6 +38,10 @@ void ggml_cuda_flash_attn_ext_tile(ggml_backend_cuda_context & ctx, ggml_tensor
|
||||
GGML_ASSERT(V->ne[0] == K->ne[0]);
|
||||
ggml_cuda_flash_attn_ext_tile_case<256, 256>(ctx, dst);
|
||||
} break;
|
||||
case 320: {
|
||||
GGML_ASSERT(V->ne[0] == 256);
|
||||
ggml_cuda_flash_attn_ext_tile_case<320, 256>(ctx, dst);
|
||||
} break;
|
||||
case 512: {
|
||||
GGML_ASSERT(V->ne[0] == K->ne[0]);
|
||||
ggml_cuda_flash_attn_ext_tile_case<512, 512>(ctx, dst);
|
||||
|
||||
@@ -68,6 +68,8 @@ static constexpr __host__ __device__ uint32_t ggml_cuda_fattn_tile_get_config_nv
|
||||
GGML_CUDA_FATTN_TILE_CONFIG_CASE(256, 256, 16, 256, 2, 64, 64)
|
||||
GGML_CUDA_FATTN_TILE_CONFIG_CASE(256, 256, 32, 256, 2, 64, 64)
|
||||
|
||||
GGML_CUDA_FATTN_TILE_CONFIG_CASE(320, 256, 32, 256, 2, 64, 64)
|
||||
|
||||
GGML_CUDA_FATTN_TILE_CONFIG_CASE(512, 512, 4, 128, 2, 64, 64)
|
||||
GGML_CUDA_FATTN_TILE_CONFIG_CASE(512, 512, 8, 256, 2, 64, 64)
|
||||
GGML_CUDA_FATTN_TILE_CONFIG_CASE(512, 512, 16, 256, 2, 64, 64)
|
||||
@@ -128,6 +130,8 @@ static constexpr __host__ __device__ uint32_t ggml_cuda_fattn_tile_get_config_nv
|
||||
GGML_CUDA_FATTN_TILE_CONFIG_CASE(256, 256, 16, 256, 2, 32, 128)
|
||||
GGML_CUDA_FATTN_TILE_CONFIG_CASE(256, 256, 32, 256, 2, 32, 64)
|
||||
|
||||
GGML_CUDA_FATTN_TILE_CONFIG_CASE(320, 256, 32, 256, 2, 32, 64)
|
||||
|
||||
GGML_CUDA_FATTN_TILE_CONFIG_CASE(512, 512, 4, 128, 2, 32, 64)
|
||||
GGML_CUDA_FATTN_TILE_CONFIG_CASE(512, 512, 8, 256, 2, 32, 64)
|
||||
GGML_CUDA_FATTN_TILE_CONFIG_CASE(512, 512, 16, 256, 2, 32, 64)
|
||||
@@ -195,6 +199,8 @@ static constexpr __host__ __device__ uint32_t ggml_cuda_fattn_tile_get_config_am
|
||||
GGML_CUDA_FATTN_TILE_CONFIG_CASE(256, 256, 16, 256, 2, 32, 128)
|
||||
GGML_CUDA_FATTN_TILE_CONFIG_CASE(256, 256, 32, 256, 2, 32, 128)
|
||||
|
||||
GGML_CUDA_FATTN_TILE_CONFIG_CASE(320, 256, 32, 512, 1, 128, 64)
|
||||
|
||||
GGML_CUDA_FATTN_TILE_CONFIG_CASE(512, 512, 4, 128, 2, 64, 64)
|
||||
GGML_CUDA_FATTN_TILE_CONFIG_CASE(512, 512, 8, 256, 2, 64, 64)
|
||||
GGML_CUDA_FATTN_TILE_CONFIG_CASE(512, 512, 16, 256, 2, 64, 64)
|
||||
@@ -264,6 +270,8 @@ static constexpr __host__ __device__ uint32_t ggml_cuda_fattn_tile_get_config_am
|
||||
GGML_CUDA_FATTN_TILE_CONFIG_CASE(256, 256, 16, 256, 5, 32, 256)
|
||||
GGML_CUDA_FATTN_TILE_CONFIG_CASE(256, 256, 32, 256, 3, 64, 128)
|
||||
|
||||
GGML_CUDA_FATTN_TILE_CONFIG_CASE(320, 256, 32, 256, 2, 128, 64)
|
||||
|
||||
GGML_CUDA_FATTN_TILE_CONFIG_CASE(512, 512, 4, 128, 2, 64, 64)
|
||||
GGML_CUDA_FATTN_TILE_CONFIG_CASE(512, 512, 8, 256, 2, 64, 64)
|
||||
GGML_CUDA_FATTN_TILE_CONFIG_CASE(512, 512, 16, 256, 4, 64, 64)
|
||||
@@ -1144,14 +1152,16 @@ static void launch_fattn_tile_switch_ncols1(ggml_backend_cuda_context & ctx, ggm
|
||||
}
|
||||
}
|
||||
|
||||
if (Q->ne[1] > 8/ncols2) {
|
||||
constexpr int cols_per_block = 16;
|
||||
const int nwarps = ggml_cuda_fattn_tile_get_nthreads (DKQ, DV, cols_per_block, cc) / warp_size;
|
||||
const int nbatch_fa = ggml_cuda_fattn_tile_get_nbatch_fa(DKQ, DV, cols_per_block, cc);
|
||||
fattn_kernel_t fattn_kernel = flash_attn_tile<DKQ, DV, cols_per_block/ncols2, ncols2, use_logit_softcap>;
|
||||
launch_fattn<DV, cols_per_block/ncols2, ncols2>
|
||||
(ctx, dst, fattn_kernel, nwarps, nbytes_shared, nbatch_fa, true, true, false, warp_size);
|
||||
return;
|
||||
if constexpr (ncols2 <= 16) {
|
||||
if (Q->ne[1] > 8/ncols2) {
|
||||
constexpr int cols_per_block = 16;
|
||||
const int nwarps = ggml_cuda_fattn_tile_get_nthreads (DKQ, DV, cols_per_block, cc) / warp_size;
|
||||
const int nbatch_fa = ggml_cuda_fattn_tile_get_nbatch_fa(DKQ, DV, cols_per_block, cc);
|
||||
fattn_kernel_t fattn_kernel = flash_attn_tile<DKQ, DV, cols_per_block/ncols2, ncols2, use_logit_softcap>;
|
||||
launch_fattn<DV, cols_per_block/ncols2, ncols2>
|
||||
(ctx, dst, fattn_kernel, nwarps, nbytes_shared, nbatch_fa, true, true, false, warp_size);
|
||||
return;
|
||||
}
|
||||
}
|
||||
|
||||
if constexpr (ncols2 <= 8) {
|
||||
@@ -1210,6 +1220,14 @@ static void launch_fattn_tile_switch_ncols2(ggml_backend_cuda_context & ctx, ggm
|
||||
const int gqa_limit = nvidia && gqa_ratio <= 4 && DV <= 256 ? 16 : INT_MAX;
|
||||
const bool use_gqa_opt = mask && max_bias == 0.0f && Q->ne[1] <= gqa_limit && K->ne[1] % FATTN_KQ_STRIDE == 0;
|
||||
|
||||
if constexpr (DKQ == 320) { // Mistral Small 4
|
||||
if (use_gqa_opt && gqa_ratio % 32 == 0) {
|
||||
launch_fattn_tile_switch_ncols1<DKQ, DV, 32, use_logit_softcap>(ctx, dst);
|
||||
return;
|
||||
}
|
||||
GGML_ABORT("flash-attn tile (320/256): expected GQA ratio multiple of 32");
|
||||
}
|
||||
|
||||
if constexpr (DKQ == 576) {
|
||||
if (use_gqa_opt && gqa_ratio % 16 == 0) {
|
||||
launch_fattn_tile_switch_ncols1<DKQ, DV, 16, use_logit_softcap>(ctx, dst);
|
||||
@@ -1221,7 +1239,7 @@ static void launch_fattn_tile_switch_ncols2(ggml_backend_cuda_context & ctx, ggm
|
||||
}
|
||||
}
|
||||
|
||||
if constexpr (DKQ <= 512) {
|
||||
if constexpr (DKQ <= 512 && DKQ != 320) {
|
||||
if (use_gqa_opt && gqa_ratio % 8 == 0) {
|
||||
launch_fattn_tile_switch_ncols1<DKQ, DV, 8, use_logit_softcap>(ctx, dst);
|
||||
return;
|
||||
@@ -1275,5 +1293,6 @@ extern DECL_FATTN_TILE_CASE( 96, 96);
|
||||
extern DECL_FATTN_TILE_CASE(112, 112);
|
||||
extern DECL_FATTN_TILE_CASE(128, 128);
|
||||
extern DECL_FATTN_TILE_CASE(256, 256);
|
||||
extern DECL_FATTN_TILE_CASE(320, 256);
|
||||
extern DECL_FATTN_TILE_CASE(512, 512);
|
||||
extern DECL_FATTN_TILE_CASE(576, 512);
|
||||
|
||||
@@ -143,6 +143,22 @@ static void ggml_cuda_flash_attn_ext_mma_f16(ggml_backend_cuda_context & ctx, gg
|
||||
GGML_ASSERT(V->ne[0] == 256);
|
||||
ggml_cuda_flash_attn_ext_mma_f16_switch_ncols2<256, 256>(ctx, dst);
|
||||
break;
|
||||
case 320:
|
||||
// For Mistral Small 4, go straight to the ncols1 switch (ncols2=32-only build).
|
||||
GGML_ASSERT(V->ne[0] == 256);
|
||||
{
|
||||
float max_bias = 0.0f;
|
||||
memcpy(&max_bias, (const float *) KQV->op_params + 1, sizeof(float));
|
||||
|
||||
const bool use_gqa_opt = mask && max_bias == 0.0f;
|
||||
GGML_ASSERT(use_gqa_opt);
|
||||
GGML_ASSERT(Q->ne[2] % K->ne[2] == 0);
|
||||
const int gqa_ratio = Q->ne[2] / K->ne[2];
|
||||
GGML_ASSERT(gqa_ratio % 32 == 0);
|
||||
|
||||
ggml_cuda_flash_attn_ext_mma_f16_switch_ncols1<320, 256, 32>(ctx, dst);
|
||||
}
|
||||
break;
|
||||
case 512:
|
||||
GGML_ASSERT(V->ne[0] == 512);
|
||||
ggml_cuda_flash_attn_ext_mma_f16_switch_ncols2<512, 512>(ctx, dst);
|
||||
@@ -352,6 +368,14 @@ static best_fattn_kernel ggml_cuda_get_best_fattn_kernel(const int device, const
|
||||
return BEST_FATTN_KERNEL_NONE;
|
||||
}
|
||||
break;
|
||||
case 320:
|
||||
if (V->ne[0] != 256 || !gqa_opt_applies) {
|
||||
return BEST_FATTN_KERNEL_NONE;
|
||||
}
|
||||
if (gqa_ratio % 32 != 0) {
|
||||
return BEST_FATTN_KERNEL_NONE;
|
||||
}
|
||||
break;
|
||||
case 512:
|
||||
if (V->ne[0] != K->ne[0]) {
|
||||
return BEST_FATTN_KERNEL_NONE;
|
||||
|
||||
@@ -1015,25 +1015,35 @@ namespace ggml_cuda_mma {
|
||||
#endif // AMD_MFMA_AVAILABLE
|
||||
}
|
||||
|
||||
static __device__ __forceinline__ void mma_block_scaled(tile<16, 8, float> & D,
|
||||
const tile<16, 8, int> & A,
|
||||
const tile<8, 8, int> & B,
|
||||
uint32_t a_scale,
|
||||
uint32_t b_scale) {
|
||||
template <ggml_type type>
|
||||
static __device__ __forceinline__ void mma_block_scaled_fp4(tile<16, 8, float> & D,
|
||||
const tile<16, 8, int> & A,
|
||||
const tile<8, 8, int> & B,
|
||||
uint32_t a_scale,
|
||||
uint32_t b_scale) {
|
||||
#ifdef BLACKWELL_MMA_AVAILABLE
|
||||
const int * Axi = (const int *) A.x;
|
||||
const int * Bxi = (const int *) B.x;
|
||||
float * Dxi = (float *) D.x;
|
||||
|
||||
asm volatile(
|
||||
"mma.sync.aligned.kind::mxf4.block_scale.scale_vec::2X.m16n8k64.row.col.f32.e2m1.e2m1.f32.ue8m0 "
|
||||
"{%0, %1, %2, %3}, {%4, %5, %6, %7}, {%8, %9}, {%0, %1, %2, %3}, "
|
||||
"%10, {0, 0}, %11, {0, 0};"
|
||||
: "+f"(Dxi[0]), "+f"(Dxi[1]), "+f"(Dxi[2]), "+f"(Dxi[3])
|
||||
: "r"(Axi[0]), "r"(Axi[1]), "r"(Axi[2]), "r"(Axi[3]), "r"(Bxi[0]), "r"(Bxi[1]), "r"(a_scale), "r"(b_scale));
|
||||
if constexpr (type == GGML_TYPE_MXFP4) {
|
||||
asm volatile(
|
||||
"mma.sync.aligned.kind::mxf4.block_scale.scale_vec::2X.m16n8k64.row.col.f32.e2m1.e2m1.f32.ue8m0 "
|
||||
"{%0, %1, %2, %3}, {%4, %5, %6, %7}, {%8, %9}, {%0, %1, %2, %3}, "
|
||||
"%10, {0, 0}, %11, {0, 0};"
|
||||
: "+f"(Dxi[0]), "+f"(Dxi[1]), "+f"(Dxi[2]), "+f"(Dxi[3])
|
||||
: "r"(Axi[0]), "r"(Axi[1]), "r"(Axi[2]), "r"(Axi[3]), "r"(Bxi[0]), "r"(Bxi[1]), "r"(a_scale), "r"(b_scale));
|
||||
} else {
|
||||
asm volatile(
|
||||
"mma.sync.aligned.kind::mxf4nvf4.block_scale.scale_vec::4X.m16n8k64.row.col.f32.e2m1.e2m1.f32.ue4m3 "
|
||||
"{%0, %1, %2, %3}, {%4, %5, %6, %7}, {%8, %9}, {%0, %1, %2, %3}, "
|
||||
"%10, {0, 0}, %11, {0, 0};"
|
||||
: "+f"(Dxi[0]), "+f"(Dxi[1]), "+f"(Dxi[2]), "+f"(Dxi[3])
|
||||
: "r"(Axi[0]), "r"(Axi[1]), "r"(Axi[2]), "r"(Axi[3]), "r"(Bxi[0]), "r"(Bxi[1]), "r"(a_scale), "r"(b_scale));
|
||||
}
|
||||
#else
|
||||
GGML_UNUSED_VARS(D, A, B, a_scale, b_scale);
|
||||
#endif // BLACKWELL_MMA_AVAILABLE
|
||||
#endif // BLACKWELL_MMA_AVAILABLE
|
||||
}
|
||||
|
||||
static __device__ __forceinline__ void mma(
|
||||
|
||||
@@ -122,7 +122,7 @@ void ggml_cuda_mul_mat_q(
|
||||
|| GGML_CUDA_CC_IS_CDNA(cc);
|
||||
|
||||
// TODO: tighter pool buffer size vs q8 path
|
||||
const bool use_native_mxfp4 = blackwell_mma_available(cc) && src0->type == GGML_TYPE_MXFP4;
|
||||
const bool use_native_fp4 = blackwell_mma_available(cc) && (src0->type == GGML_TYPE_MXFP4 || src0->type == GGML_TYPE_NVFP4);
|
||||
|
||||
if (!ids) {
|
||||
const size_t nbytes_src1_q8_1 = ne13*ne12 * ne11*ne10_padded * sizeof(block_q8_1)/QK8_1 +
|
||||
@@ -133,9 +133,9 @@ void ggml_cuda_mul_mat_q(
|
||||
const int64_t s11 = src1->nb[1] / ts_src1;
|
||||
const int64_t s12 = src1->nb[2] / ts_src1;
|
||||
const int64_t s13 = src1->nb[3] / ts_src1;
|
||||
if (use_native_mxfp4) {
|
||||
if (use_native_fp4) {
|
||||
static_assert(sizeof(block_fp4_mmq) == 4 * sizeof(block_q8_1));
|
||||
quantize_mmq_mxfp4_cuda(src1_d, nullptr, src1_q8_1.get(), src0->type, ne10, s11, s12, s13, ne10_padded,
|
||||
quantize_mmq_fp4_cuda(src1_d, nullptr, src1_q8_1.get(), src0->type, ne10, s11, s12, s13, ne10_padded,
|
||||
ne11, ne12, ne13, stream);
|
||||
|
||||
} else {
|
||||
@@ -146,10 +146,8 @@ void ggml_cuda_mul_mat_q(
|
||||
}
|
||||
|
||||
// Stride depends on quantization format
|
||||
const int64_t s12 = use_native_mxfp4 ?
|
||||
ne11 * ne10_padded * sizeof(block_fp4_mmq) /
|
||||
(8 * QK_MXFP4 * sizeof(int)) // block_fp4_mmq holds 256 values (8 blocks of 32)
|
||||
:
|
||||
const int64_t s12 = use_native_fp4 ?
|
||||
ne11 * ne10_padded * sizeof(block_fp4_mmq) / (QK_K * sizeof(int)) : // block_fp4_mmq holds 256 values
|
||||
ne11 * ne10_padded * sizeof(block_q8_1) / (QK8_1 * sizeof(int));
|
||||
const int64_t s13 = ne12*s12;
|
||||
|
||||
@@ -198,8 +196,8 @@ void ggml_cuda_mul_mat_q(
|
||||
const int64_t s12 = src1->nb[2] / ts_src1;
|
||||
const int64_t s13 = src1->nb[3] / ts_src1;
|
||||
|
||||
if (use_native_mxfp4) {
|
||||
quantize_mmq_mxfp4_cuda(src1_d, ids_src1.get(), src1_q8_1.get(), src0->type, ne10, s11, s12, s13,
|
||||
if (use_native_fp4) {
|
||||
quantize_mmq_fp4_cuda(src1_d, ids_src1.get(), src1_q8_1.get(), src0->type, ne10, s11, s12, s13,
|
||||
ne10_padded, ne11_flat, ne12_flat, ne13_flat, stream);
|
||||
} else {
|
||||
quantize_mmq_q8_1_cuda(src1_d, ids_src1.get(), src1_q8_1.get(), src0->type, ne10, s11, s12, s13,
|
||||
@@ -208,8 +206,9 @@ void ggml_cuda_mul_mat_q(
|
||||
CUDA_CHECK(cudaGetLastError());
|
||||
}
|
||||
|
||||
const int64_t s12 = use_native_mxfp4 ? ne11 * ne10_padded * sizeof(block_fp4_mmq) / (8 * QK_MXFP4 * sizeof(int)) :
|
||||
ne11 * ne10_padded * sizeof(block_q8_1) / (QK8_1 * sizeof(int));
|
||||
static_assert(QK_K == 8 * QK_MXFP4, "QK_K needs to be 8 * QK_MXFP4");
|
||||
const int64_t s12 = use_native_fp4 ? ne11 * ne10_padded * sizeof(block_fp4_mmq) / (QK_K * sizeof(int)) :
|
||||
ne11 * ne10_padded * sizeof(block_q8_1) / (QK8_1 * sizeof(int));
|
||||
const int64_t s13 = ne12*s12;
|
||||
|
||||
// Note that ne02 is used instead of ne12 because the number of y channels determines the z dimension of the CUDA grid.
|
||||
|
||||
@@ -10,9 +10,9 @@
|
||||
using namespace ggml_cuda_mma;
|
||||
|
||||
#define MMQ_DP4A_MAX_BATCH_SIZE 64 // Max. batch size to use for dp4a MMQ kernels when FP16 tensor cores are available.
|
||||
#define MMQ_ITER_K 256
|
||||
#define MMQ_ITER_K_MXFP4_FP4 512
|
||||
#define MMQ_NWARPS 8
|
||||
#define MMQ_ITER_K 256
|
||||
#define MMQ_ITER_K_FP4 512
|
||||
#define MMQ_NWARPS 8
|
||||
|
||||
typedef void (*load_tiles_mmq_t)(const char * __restrict__ x, int * x_tile, const int kbx0, const int i_max, const int stride);
|
||||
typedef void (*vec_dot_mmq_t)(const int * __restrict__ x, const int * __restrict__ y, float * __restrict__ sum, const int k00);
|
||||
@@ -46,9 +46,12 @@ struct block_q8_1_mmq {
|
||||
int8_t qs[4*QK8_1]; // 128 values quantized to 8 bit each
|
||||
};
|
||||
|
||||
// this struct is used for fp4 data types (currently only used for Blackwell)
|
||||
// mxfp4 has block size 32, each int32 of d4 contains 2 e8m0 scales in the lower 16 bits
|
||||
// nvfp4 has block size 16, each int32 of d4 contains 4 ue4m3 scales
|
||||
struct block_fp4_mmq {
|
||||
uint32_t d4[4]; // 8 E8M0 scales (1 per 32 values), 2 packed per uint32: d4[0]={s0,s1}, d4[1]={s2,s3}, etc.
|
||||
int8_t qs[4 * 32]; // 256 FP4 values packed as 4-bit pairs (2 per byte), 8 blocks of 32 values
|
||||
uint32_t d4[4];
|
||||
int8_t qs[4 * 32]; // 256 FP4 values packed as 4-bit pairs (2 per byte)
|
||||
};
|
||||
|
||||
static_assert(sizeof(block_q8_1_mmq) == 4*QK8_1 + 4*sizeof(half2), "Unexpected block_q8_1_mmq size");
|
||||
@@ -143,10 +146,11 @@ static int get_mmq_y_host(const int cc) {
|
||||
|
||||
static constexpr __device__ int get_iter_k([[maybe_unused]] const ggml_type type) {
|
||||
#if defined(BLACKWELL_MMA_AVAILABLE)
|
||||
return type == GGML_TYPE_MXFP4 ? MMQ_ITER_K_MXFP4_FP4 : MMQ_ITER_K;
|
||||
#else
|
||||
return MMQ_ITER_K;
|
||||
if (type == GGML_TYPE_NVFP4 || type == GGML_TYPE_MXFP4) {
|
||||
return MMQ_ITER_K_FP4;
|
||||
}
|
||||
#endif // defined(BLACKWELL_MMA_AVAILABLE)
|
||||
return MMQ_ITER_K;
|
||||
}
|
||||
|
||||
static constexpr __device__ int get_mmq_y_device() {
|
||||
@@ -213,8 +217,8 @@ static constexpr __host__ __device__ tile_x_sizes mmq_get_dp4a_tile_x_sizes(ggml
|
||||
}
|
||||
|
||||
#define MMQ_MMA_TILE_X_K_Q8_0 (2*MMQ_TILE_NE_K + 2*MMQ_TILE_NE_K/QI8_0 + 4)
|
||||
#define MMQ_MMA_TILE_X_K_FP4 (2*MMQ_TILE_NE_K + 8 + 4) // MXFP4
|
||||
#define MMQ_MMA_TILE_X_K_NVFP4 (2*MMQ_TILE_NE_K + MMQ_TILE_NE_K/2 + 4) // NVFP4
|
||||
#define MMQ_MMA_TILE_X_K_FP4 (2*MMQ_TILE_NE_K + 8 + 4) // MXFP4 and NVFP4 Blackwell
|
||||
#define MMQ_MMA_TILE_X_K_NVFP4 (2*MMQ_TILE_NE_K + MMQ_TILE_NE_K/2 + 4) // NVFP4 Generic
|
||||
#define MMQ_MMA_TILE_X_K_Q8_1 (2*MMQ_TILE_NE_K + 2*MMQ_TILE_NE_K/QI8_0 + 4)
|
||||
#define MMQ_MMA_TILE_X_K_Q2_K (2*MMQ_TILE_NE_K + MMQ_TILE_NE_K + 4)
|
||||
#define MMQ_MMA_TILE_X_K_Q3_K (2*MMQ_TILE_NE_K + MMQ_TILE_NE_K/2 + 4)
|
||||
@@ -240,7 +244,11 @@ static constexpr __host__ __device__ int mmq_get_mma_tile_x_k(ggml_type type) {
|
||||
case GGML_TYPE_Q8_0: return MMQ_MMA_TILE_X_K_Q8_0;
|
||||
// tile sizes are the same for Q8_1 and FP4 for blackwell
|
||||
case GGML_TYPE_MXFP4: return MMQ_MMA_TILE_X_K_Q8_1;
|
||||
#if defined(BLACKWELL_MMA_AVAILABLE)
|
||||
case GGML_TYPE_NVFP4: return MMQ_MMA_TILE_X_K_FP4;
|
||||
#else
|
||||
case GGML_TYPE_NVFP4: return MMQ_MMA_TILE_X_K_NVFP4;
|
||||
#endif // defined(BLACKWELL_MMA_AVAILABLE)
|
||||
case GGML_TYPE_Q2_K: return MMQ_MMA_TILE_X_K_Q2_K;
|
||||
case GGML_TYPE_Q3_K: return MMQ_MMA_TILE_X_K_Q3_K;
|
||||
case GGML_TYPE_Q4_K: return MMQ_MMA_TILE_X_K_Q8_1;
|
||||
@@ -934,6 +942,128 @@ static __device__ __forceinline__ void load_tiles_mxfp4_fp4(const char * __restr
|
||||
}
|
||||
}
|
||||
|
||||
#ifdef BLACKWELL_MMA_AVAILABLE
|
||||
template <int mmq_y, bool need_check>
|
||||
static __device__ __forceinline__ void load_tiles_nvfp4_nvfp4(const char * __restrict__ x,
|
||||
int * __restrict__ x_tile,
|
||||
const int kbx0,
|
||||
const int i_max,
|
||||
const int stride) {
|
||||
constexpr int nwarps = mmq_get_nwarps_device();
|
||||
constexpr int warp_size = ggml_cuda_get_physical_warp_size();
|
||||
constexpr int iter_k = get_iter_k(GGML_TYPE_NVFP4);
|
||||
constexpr int threads_per_row = iter_k / QK_NVFP4; // each thread processes 1 block
|
||||
constexpr int rows_per_warp = warp_size / threads_per_row;
|
||||
|
||||
uint32_t * x_u32 = (uint32_t *) x_tile;
|
||||
|
||||
const int txi = threadIdx.x;
|
||||
const int kbx = txi % threads_per_row;
|
||||
const int row_in_warp = txi / threads_per_row;
|
||||
|
||||
const block_nvfp4 * bxi_base = (const block_nvfp4 *) x + kbx0 + kbx;
|
||||
uint32_t * x_u32_scale = x_u32 + 64 + kbx;
|
||||
|
||||
#pragma unroll
|
||||
for (int i0 = 0; i0 < mmq_y; i0 += rows_per_warp * nwarps) {
|
||||
int i = i0 + threadIdx.y * rows_per_warp + row_in_warp;
|
||||
|
||||
if constexpr (need_check) {
|
||||
i = min(i, i_max);
|
||||
}
|
||||
|
||||
const block_nvfp4 * bxi = bxi_base + i * stride;
|
||||
const int row_base = i * MMQ_MMA_TILE_X_K_FP4;
|
||||
const int q_base = row_base + 8 * kbx;
|
||||
|
||||
const uint32_t * src_qs = reinterpret_cast<const uint32_t *>(bxi->qs);
|
||||
|
||||
#pragma unroll
|
||||
for (int sub = 0; sub < QK_NVFP4 / QK_NVFP4_SUB; ++sub) {
|
||||
x_u32[q_base + 2 * sub + 0] = src_qs[2 * sub + 0];
|
||||
x_u32[q_base + 2 * sub + 1] = src_qs[2 * sub + 1];
|
||||
}
|
||||
|
||||
x_u32_scale[row_base] = get_int_b4(bxi->d, 0);
|
||||
}
|
||||
}
|
||||
|
||||
// Shared MMA kernel for MXFP4 and NVFP4 on Blackwell.
|
||||
// Both quantizations encode values as e2m1 (FP4) and produce one uint32 scale per
|
||||
// m16n8k64 MMA call; only the PTX kind (scale_vec::2X ue8m0 vs scale_vec::4X ue4m3)
|
||||
// and the per-type stride constant differ.
|
||||
template <int mmq_x, int mmq_y, ggml_type type>
|
||||
static __device__ __forceinline__ void vec_dot_fp4_fp4_mma(const int * __restrict__ x,
|
||||
const int * __restrict__ y,
|
||||
float * __restrict__ sum,
|
||||
const int k00) {
|
||||
static_assert(type == GGML_TYPE_MXFP4 || type == GGML_TYPE_NVFP4,
|
||||
"vec_dot_fp4_fp4_mma: type must be MXFP4 or NVFP4");
|
||||
|
||||
typedef tile<16, 8, int> tile_A;
|
||||
typedef tile<8, 8, int> tile_B;
|
||||
typedef tile<16, 8, float> tile_C;
|
||||
|
||||
constexpr int stride = MMQ_MMA_TILE_X_K_FP4;
|
||||
constexpr int granularity = mmq_get_granularity_device(mmq_x);
|
||||
constexpr int rows_per_warp = 2 * granularity;
|
||||
constexpr int ntx = rows_per_warp / tile_C::I;
|
||||
constexpr int nfrags = MMQ_TILE_NE_K / tile_A::J;
|
||||
|
||||
y += (threadIdx.y % ntx) * (tile_C::J * MMQ_TILE_Y_K);
|
||||
|
||||
const int * x_qs = (const int *) x;
|
||||
const uint32_t * x_sc = (const uint32_t *) (x_qs + 2 * MMQ_TILE_NE_K);
|
||||
const int * y_qs = (const int *) y + 4;
|
||||
const uint32_t * y_sc = (const uint32_t *) y;
|
||||
|
||||
// 2 threads per quad supply the packed scale register to the block_scale MMA,
|
||||
// see https://docs.nvidia.com/cuda/parallel-thread-execution/#warp-level-block-scaling
|
||||
const int tidx_A = threadIdx.x / 4 + (threadIdx.x % 2) * 8;
|
||||
const int tidx_B = threadIdx.x / 4;
|
||||
const int i0 = (threadIdx.y / ntx) * rows_per_warp;
|
||||
|
||||
tile_A A[ntx][nfrags];
|
||||
uint32_t scaleA[ntx][nfrags];
|
||||
|
||||
#pragma unroll
|
||||
for (int n = 0; n < ntx; ++n) {
|
||||
#pragma unroll
|
||||
for (int frag = 0; frag < nfrags; ++frag) {
|
||||
const int k0 = k00 + frag * tile_A::J;
|
||||
load_ldmatrix(A[n][frag], x_qs + (i0 + n * tile_A::I) * stride + k0, stride);
|
||||
scaleA[n][frag] = x_sc[(i0 + n * tile_A::I + tidx_A) * stride + k0 / tile_A::J];
|
||||
}
|
||||
}
|
||||
|
||||
#pragma unroll
|
||||
for (int j0 = 0; j0 < mmq_x; j0 += ntx * tile_C::J) {
|
||||
tile_B B[nfrags];
|
||||
uint32_t scaleB[nfrags];
|
||||
|
||||
#pragma unroll
|
||||
for (int frag = 0; frag < nfrags; ++frag) {
|
||||
const int k0 = frag * tile_B::J;
|
||||
load_generic(B[frag], y_qs + j0 * MMQ_TILE_Y_K + k0, MMQ_TILE_Y_K);
|
||||
scaleB[frag] = y_sc[(j0 + tidx_B) * MMQ_TILE_Y_K + frag];
|
||||
}
|
||||
|
||||
#pragma unroll
|
||||
for (int n = 0; n < ntx; ++n) {
|
||||
#pragma unroll
|
||||
for (int frag = 0; frag < nfrags; ++frag) {
|
||||
tile_C C = {};
|
||||
mma_block_scaled_fp4<type>(C, A[n][frag], B[frag], scaleA[n][frag], scaleB[frag]);
|
||||
#pragma unroll
|
||||
for (int l = 0; l < tile_C::ne; ++l) {
|
||||
sum[(j0 / tile_C::J + n) * tile_C::ne + l] += C.x[l];
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
#endif // BLACKWELL_MMA_AVAILABLE
|
||||
|
||||
|
||||
template <int mmq_y, bool need_check>
|
||||
static __device__ __forceinline__ void load_tiles_nvfp4(const char * __restrict__ x,
|
||||
@@ -1163,77 +1293,6 @@ static __device__ __forceinline__ void vec_dot_q8_0_q8_1_mma(
|
||||
#endif // defined(AMD_MFMA_AVAILABLE) || defined(AMD_WMMA_AVAILABLE)
|
||||
}
|
||||
|
||||
template <int mmq_x, int mmq_y>
|
||||
static __device__ __forceinline__ void vec_dot_mxfp4_mxfp4_mma(const int * __restrict__ x,
|
||||
const int * __restrict__ y,
|
||||
float * __restrict__ sum,
|
||||
const int k00) {
|
||||
typedef tile<16, 8, int> tile_A;
|
||||
typedef tile<8, 8, int> tile_B;
|
||||
typedef tile<16, 8, float> tile_C; // Output is float for native scaled MMA
|
||||
|
||||
constexpr int granularity = mmq_get_granularity_device(mmq_x);
|
||||
constexpr int rows_per_warp = 2 * granularity;
|
||||
constexpr int ntx = rows_per_warp / tile_C::I; // Number of x minitiles per warp.
|
||||
|
||||
y += (threadIdx.y % ntx) * (tile_C::J * MMQ_TILE_Y_FP4_K);
|
||||
|
||||
// Match layout from load_tiles_mxfp4_fp4
|
||||
const int * x_qs = (const int *) x;
|
||||
const uint32_t * x_sc = (const uint32_t *) (x_qs + 2 * MMQ_TILE_NE_K);
|
||||
const int * y_qs = (const int *) y + 4;
|
||||
const uint32_t * y_sc = (const uint32_t *) y;
|
||||
|
||||
// tile_A has a length of 64 logical values vs. 32 values in block_mxfp4
|
||||
tile_A A[ntx][MMQ_TILE_NE_K / (2 * QI_MXFP4)];
|
||||
uint32_t scaleA[ntx][MMQ_TILE_NE_K / (2 * QI_MXFP4)];
|
||||
|
||||
// Block scale
|
||||
// Each thread has to point to a 4 byte scale value
|
||||
// https://docs.nvidia.com/cuda/parallel-thread-execution/#warp-level-block-scaling
|
||||
|
||||
const int i0 = (threadIdx.y / ntx) * rows_per_warp;
|
||||
|
||||
#pragma unroll
|
||||
for (int n = 0; n < ntx; ++n) {
|
||||
#pragma unroll
|
||||
for (int k01 = 0; k01 < MMQ_TILE_NE_K; k01 += 2 * QI_MXFP4) {
|
||||
const int k0 = k00 + k01;
|
||||
|
||||
load_ldmatrix(A[n][k01 / (2 * QI_MXFP4)], x_qs + (i0 + n * tile_A::I) * MMQ_MMA_TILE_X_K_FP4 + k0,
|
||||
MMQ_MMA_TILE_X_K_FP4);
|
||||
|
||||
// based on block-scaling document, 2 threads in each quad need to supply to the scale value
|
||||
const int tidx = threadIdx.x / 4 + (threadIdx.x % 2) * 8;
|
||||
scaleA[n][k01 / (2 * QI_MXFP4)] =
|
||||
*(x_sc + (i0 + n * tile_A::I + tidx) * MMQ_MMA_TILE_X_K_FP4 + k0 / (2 * QI_MXFP4));
|
||||
}
|
||||
}
|
||||
|
||||
#pragma unroll
|
||||
for (int j0 = 0; j0 < mmq_x; j0 += ntx * tile_C::J) {
|
||||
#pragma unroll
|
||||
for (int k01 = 0; k01 < MMQ_TILE_NE_K; k01 += 2 * QI_MXFP4) {
|
||||
tile_B B;
|
||||
uint32_t scaleB; // 2xN scales
|
||||
|
||||
load_generic(B, y_qs + j0 * MMQ_TILE_Y_FP4_K + k01, MMQ_TILE_Y_FP4_K);
|
||||
|
||||
scaleB = y_sc[(j0 + threadIdx.x / 4) * MMQ_TILE_Y_FP4_K + k01 / (2 * QI_MXFP4)];
|
||||
|
||||
#pragma unroll
|
||||
for (int n = 0; n < ntx; ++n) {
|
||||
tile_C C;
|
||||
|
||||
mma_block_scaled(C, A[n][k01 / (2 * QI_MXFP4)], B, scaleA[n][k01 / (2 * QI_MXFP4)], scaleB);
|
||||
#pragma unroll
|
||||
for (int l = 0; l < tile_C::ne; ++l) {
|
||||
sum[(j0 / tile_C::J + n) * tile_C::ne + l] += C.x[l];
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
template <int mmq_x, int mmq_y>
|
||||
static __device__ __forceinline__ void vec_dot_q8_1_q8_1_dp4a(
|
||||
@@ -3259,7 +3318,7 @@ struct mmq_type_traits<mmq_x, mmq_y, need_check, GGML_TYPE_MXFP4> {
|
||||
static constexpr int vdr = VDR_MXFP4_Q8_1_MMQ;
|
||||
#ifdef BLACKWELL_MMA_AVAILABLE
|
||||
static constexpr load_tiles_mmq_t load_tiles = load_tiles_mxfp4_fp4<mmq_y, need_check>;
|
||||
static constexpr vec_dot_mmq_t vec_dot_mma = vec_dot_mxfp4_mxfp4_mma<mmq_x, mmq_y>;
|
||||
static constexpr vec_dot_mmq_t vec_dot_mma = vec_dot_fp4_fp4_mma<mmq_x, mmq_y, GGML_TYPE_MXFP4>;
|
||||
#else
|
||||
static constexpr load_tiles_mmq_t load_tiles = load_tiles_mxfp4<mmq_y, need_check>;
|
||||
static constexpr vec_dot_mmq_t vec_dot_mma = vec_dot_q8_0_q8_1_mma<mmq_x, mmq_y, MMQ_Q8_1_DS_LAYOUT_D4>;
|
||||
@@ -3270,8 +3329,13 @@ struct mmq_type_traits<mmq_x, mmq_y, need_check, GGML_TYPE_MXFP4> {
|
||||
template <int mmq_x, int mmq_y, bool need_check>
|
||||
struct mmq_type_traits<mmq_x, mmq_y, need_check, GGML_TYPE_NVFP4> {
|
||||
static constexpr int vdr = VDR_NVFP4_Q8_1_MMQ;
|
||||
#ifdef BLACKWELL_MMA_AVAILABLE
|
||||
static constexpr load_tiles_mmq_t load_tiles = load_tiles_nvfp4_nvfp4<mmq_y, need_check>;
|
||||
static constexpr vec_dot_mmq_t vec_dot_mma = vec_dot_fp4_fp4_mma<mmq_x, mmq_y, GGML_TYPE_NVFP4>;
|
||||
#else
|
||||
static constexpr load_tiles_mmq_t load_tiles = load_tiles_nvfp4<mmq_y, need_check>;
|
||||
static constexpr vec_dot_mmq_t vec_dot_mma = vec_dot_q8_0_16_q8_1_mma<mmq_x, mmq_y>;
|
||||
#endif // BLACKWELL_MMA_AVAILABLE
|
||||
static constexpr vec_dot_mmq_t vec_dot_dp4a = vec_dot_q8_0_16_q8_1_dp4a<mmq_x, mmq_y>;
|
||||
};
|
||||
|
||||
@@ -3406,7 +3470,7 @@ static __device__ __forceinline__ void mul_mat_q_process_tile(
|
||||
|
||||
#if defined(BLACKWELL_MMA_AVAILABLE)
|
||||
// FP4 tile stores 8 blocks
|
||||
constexpr int ne_block = (type == GGML_TYPE_MXFP4) ? 8 * QK_MXFP4 : 4 * QK8_1;
|
||||
constexpr int ne_block = (type == GGML_TYPE_MXFP4 || type == GGML_TYPE_NVFP4) ? QK_K : 4 * QK8_1;
|
||||
#else
|
||||
constexpr int ne_block = 4 * QK8_1;
|
||||
#endif // defined(BLACKWELL_MMA_AVAILABLE)
|
||||
|
||||
@@ -115,6 +115,7 @@ static constexpr __host__ __device__ int get_mmvq_mmid_max_batch_pascal_older(gg
|
||||
case GGML_TYPE_IQ4_NL: return 6;
|
||||
case GGML_TYPE_IQ4_XS: return 5;
|
||||
case GGML_TYPE_MXFP4: return 4;
|
||||
case GGML_TYPE_NVFP4: return 4;
|
||||
case GGML_TYPE_Q2_K: return 4;
|
||||
case GGML_TYPE_Q3_K: return 4;
|
||||
case GGML_TYPE_Q4_0: return 6;
|
||||
@@ -135,6 +136,7 @@ static constexpr __host__ __device__ int get_mmvq_mmid_max_batch_turing_plus(ggm
|
||||
case GGML_TYPE_IQ3_S: return 6;
|
||||
case GGML_TYPE_IQ3_XXS: return 7;
|
||||
case GGML_TYPE_MXFP4: return 7;
|
||||
case GGML_TYPE_NVFP4: return 8;
|
||||
case GGML_TYPE_Q2_K: return 7;
|
||||
case GGML_TYPE_Q3_K: return 5;
|
||||
default: return MMVQ_MAX_BATCH_SIZE;
|
||||
@@ -221,6 +223,7 @@ static constexpr __host__ __device__ int get_mmvq_mmid_max_batch_rdna4(ggml_type
|
||||
case GGML_TYPE_IQ4_NL: return 7;
|
||||
case GGML_TYPE_IQ4_XS: return 5;
|
||||
case GGML_TYPE_MXFP4: return 5;
|
||||
case GGML_TYPE_NVFP4: return 5;
|
||||
case GGML_TYPE_Q3_K: return 4;
|
||||
case GGML_TYPE_Q4_0: return 7;
|
||||
case GGML_TYPE_Q4_1: return 7;
|
||||
|
||||
@@ -70,6 +70,102 @@ __device__ __forceinline__ uint8_t compute_e8m0_scale(float amax) {
|
||||
return static_cast<uint8_t>(biased);
|
||||
}
|
||||
|
||||
|
||||
static __global__ void quantize_mmq_nvfp4(
|
||||
const float * __restrict__ x, const int32_t * __restrict__ ids, void * __restrict__ vy,
|
||||
const int64_t ne00, const int64_t s01, const int64_t s02, const int64_t s03,
|
||||
const int64_t ne0, const int64_t ne1, const int64_t ne2) {
|
||||
#if defined(BLACKWELL_MMA_AVAILABLE)
|
||||
|
||||
const int64_t i0_base = ((int64_t) blockDim.x * blockIdx.y + threadIdx.x) * QK_NVFP4_SUB;
|
||||
if (i0_base >= ne0) {
|
||||
return;
|
||||
}
|
||||
|
||||
const int64_t i1 = blockIdx.x;
|
||||
const int64_t i2 = blockIdx.z % ne2;
|
||||
const int64_t i3 = blockIdx.z / ne2;
|
||||
const int64_t i01 = ids ? ids[i1] : i1;
|
||||
const int64_t k_block = i0_base / QK_K;
|
||||
const int64_t blocks_per_col = (ne0 + QK_K - 1) / QK_K;
|
||||
if (k_block >= blocks_per_col) {
|
||||
return;
|
||||
}
|
||||
|
||||
const int64_t ib = blockIdx.z * ((int64_t) blocks_per_col * ne1) + k_block * ne1 + blockIdx.x;
|
||||
block_fp4_mmq * y = (block_fp4_mmq *) vy;
|
||||
block_fp4_mmq * yb = y + ib;
|
||||
|
||||
const int sub = (i0_base % QK_K) / QK_NVFP4_SUB;
|
||||
|
||||
float vals_raw[QK_NVFP4_SUB];
|
||||
float amax_raw = 0.0f;
|
||||
const int64_t base_idx = i3 * s03 + i2 * s02 + i01 * s01;
|
||||
#pragma unroll
|
||||
for (int k = 0; k < QK_NVFP4_SUB; k++) {
|
||||
const int64_t i00 = i0_base + k;
|
||||
if (i00 < ne00) {
|
||||
const float v = x[base_idx + i00];
|
||||
vals_raw[k] = v;
|
||||
amax_raw = fmaxf(amax_raw, fabsf(v));
|
||||
} else {
|
||||
vals_raw[k] = 0.0f;
|
||||
}
|
||||
}
|
||||
|
||||
static constexpr int test_offsets[5] = { 0, -1, 1, -2, 2};
|
||||
const int first_fp8_code = (int) ggml_cuda_fp32_to_ue4m3(amax_raw / 6.0f);
|
||||
|
||||
float best_err = FLT_MAX;
|
||||
uint8_t fp8_code = 0;
|
||||
float subblock_scale = 0.0f;
|
||||
|
||||
#pragma unroll // Check +/- 2 to find best code to reduce NVFP4 activation loss. Negligible overhead on Blackwell.
|
||||
for (int i = 0; i < 5; i++) {
|
||||
const int test_code = first_fp8_code + test_offsets[i];
|
||||
if (test_code < 0 || test_code > 0x7e) {
|
||||
continue;
|
||||
}
|
||||
const uint8_t code = (uint8_t) test_code;
|
||||
const float test_scale = ggml_cuda_ue4m3_to_fp32(code);
|
||||
const float test_inv_scale = test_scale > 0.0f ? 0.5f / test_scale : 0.0f;
|
||||
float cur_err = 0.0f;
|
||||
#pragma unroll
|
||||
for (int k = 0; k < QK_NVFP4_SUB; ++k) {
|
||||
const float v = vals_raw[k];
|
||||
const uint8_t q = ggml_cuda_float_to_fp4_e2m1(v, test_inv_scale);
|
||||
const float err_diff = fabsf(v) - fabsf(kvalues_mxfp4[q & 0x7]) * test_scale;
|
||||
cur_err = fmaf(err_diff, err_diff, cur_err);
|
||||
}
|
||||
|
||||
if (cur_err < best_err) {
|
||||
best_err = cur_err;
|
||||
fp8_code = test_code;
|
||||
subblock_scale = test_scale;
|
||||
}
|
||||
}
|
||||
|
||||
const float inv_scale = subblock_scale > 0.0f ? 0.5f / subblock_scale : 0.0f;
|
||||
uint32_t q0 = 0;
|
||||
uint32_t q1 = 0;
|
||||
#pragma unroll // this is faster than the previous __nv_fp4x4_e2m1
|
||||
for (int k = 0; k < QK_NVFP4_SUB / 4; ++k) {
|
||||
q0 |= (uint32_t) ggml_cuda_float_to_fp4_e2m1(vals_raw[k + 0], inv_scale) << (8 * k);
|
||||
q0 |= (uint32_t) ggml_cuda_float_to_fp4_e2m1(vals_raw[k + 8], inv_scale) << (8 * k + 4);
|
||||
q1 |= (uint32_t) ggml_cuda_float_to_fp4_e2m1(vals_raw[k + 4], inv_scale) << (8 * k);
|
||||
q1 |= (uint32_t) ggml_cuda_float_to_fp4_e2m1(vals_raw[k + 12], inv_scale) << (8 * k + 4);
|
||||
}
|
||||
|
||||
uint32_t * yqs = reinterpret_cast<uint32_t *>(yb->qs);
|
||||
yqs[2 * sub + 0] = q0;
|
||||
yqs[2 * sub + 1] = q1;
|
||||
reinterpret_cast<uint8_t *>(yb->d4)[sub] = fp8_code;
|
||||
#else
|
||||
NO_DEVICE_CODE; // This is for Blackwell NVFP4 activations only.
|
||||
#endif // defined(BLACKWELL_MMA_AVAILABLE)
|
||||
|
||||
}
|
||||
|
||||
// quantize values in the format mxfp4 is stored which is interleaved nibbles
|
||||
// i.e. a block a0-a31 is represented as a0a16,a1a17 ...a15a31
|
||||
static __global__ void quantize_mmq_mxfp4(const float * __restrict__ x,
|
||||
@@ -316,28 +412,32 @@ void quantize_mmq_q8_1_cuda(
|
||||
}
|
||||
}
|
||||
|
||||
void quantize_mmq_mxfp4_cuda(const float * x,
|
||||
const int32_t * ids,
|
||||
void * vy,
|
||||
[[maybe_unused]] const ggml_type type_src0,
|
||||
const int64_t ne00,
|
||||
const int64_t s01,
|
||||
const int64_t s02,
|
||||
const int64_t s03,
|
||||
const int64_t ne0,
|
||||
const int64_t ne1,
|
||||
const int64_t ne2,
|
||||
const int64_t ne3,
|
||||
cudaStream_t stream) {
|
||||
GGML_ASSERT(ne0 % (2 * QK_MXFP4) == 0);
|
||||
void quantize_mmq_fp4_cuda(
|
||||
const float * x, const int32_t * ids, void * vy, const ggml_type type_src0,
|
||||
const int64_t ne00, const int64_t s01, const int64_t s02, const int64_t s03,
|
||||
const int64_t ne0, const int64_t ne1, const int64_t ne2, const int64_t ne3, cudaStream_t stream) {
|
||||
GGML_ASSERT(type_src0 == GGML_TYPE_MXFP4 || type_src0 == GGML_TYPE_NVFP4);
|
||||
GGML_ASSERT(ne0 > 0);
|
||||
|
||||
constexpr int nwarps = 8;
|
||||
constexpr int vals_per_warp = 2 * QK_MXFP4;
|
||||
constexpr int vals_per_block = nwarps * vals_per_warp;
|
||||
if (type_src0 == GGML_TYPE_NVFP4) {
|
||||
GGML_ASSERT(ne00 % QK_NVFP4 == 0);
|
||||
constexpr int nvfp4_block_size = 128;
|
||||
const int64_t block_num_y = (ne0 + QK_NVFP4_SUB * nvfp4_block_size - 1) / (QK_NVFP4_SUB * nvfp4_block_size);
|
||||
const dim3 block_size(nvfp4_block_size, 1, 1);
|
||||
const dim3 num_blocks(ne1, block_num_y, ne2 * ne3);
|
||||
quantize_mmq_nvfp4<<<num_blocks, block_size, 0, stream>>>(
|
||||
x, ids, vy, ne00, s01, s02, s03, ne0, ne1, ne2);
|
||||
} else {
|
||||
GGML_ASSERT(ne0 % (2 * QK_MXFP4) == 0);
|
||||
|
||||
const int64_t block_num_y = (ne0 + vals_per_block - 1) / vals_per_block;
|
||||
const dim3 num_blocks(ne1, block_num_y, ne2 * ne3);
|
||||
const dim3 block_size(WARP_SIZE, nwarps, 1);
|
||||
constexpr int nwarps = 8;
|
||||
constexpr int vals_per_warp = 2 * QK_MXFP4;
|
||||
constexpr int vals_per_block = nwarps * vals_per_warp;
|
||||
|
||||
quantize_mmq_mxfp4<<<num_blocks, block_size, 0, stream>>>(x, ids, vy, ne00, s01, s02, s03, ne0, ne1, ne2);
|
||||
const int64_t block_num_y = (ne0 + vals_per_block - 1) / vals_per_block;
|
||||
const dim3 num_blocks(ne1, block_num_y, ne2 * ne3);
|
||||
const dim3 block_size(WARP_SIZE, nwarps, 1);
|
||||
|
||||
quantize_mmq_mxfp4<<<num_blocks, block_size, 0, stream>>>(x, ids, vy, ne00, s01, s02, s03, ne0, ne1, ne2);
|
||||
}
|
||||
}
|
||||
|
||||
@@ -26,7 +26,7 @@ void quantize_mmq_q8_1_cuda(
|
||||
ggml_type type_src0, int64_t ne00, int64_t s01, int64_t s02, int64_t s03,
|
||||
int64_t ne0, int64_t ne1, int64_t ne2, int64_t ne3, cudaStream_t stream);
|
||||
|
||||
void quantize_mmq_mxfp4_cuda(const float * x,
|
||||
void quantize_mmq_fp4_cuda(const float * x,
|
||||
const int32_t * ids,
|
||||
void * vy,
|
||||
ggml_type type_src0,
|
||||
|
||||
@@ -2,4 +2,5 @@
|
||||
|
||||
#include "../fattn-mma-f16.cuh"
|
||||
|
||||
DECL_FATTN_MMA_F16_CASE(320, 256, 1, 32);
|
||||
DECL_FATTN_MMA_F16_CASE(576, 512, 1, 32);
|
||||
|
||||
@@ -2,4 +2,5 @@
|
||||
|
||||
#include "../fattn-mma-f16.cuh"
|
||||
|
||||
DECL_FATTN_MMA_F16_CASE(320, 256, 2, 32);
|
||||
DECL_FATTN_MMA_F16_CASE(576, 512, 2, 32);
|
||||
|
||||
@@ -0,0 +1,5 @@
|
||||
// This file has been autogenerated by generate_cu_files.py, do not edit manually.
|
||||
|
||||
#include "../fattn-tile.cuh"
|
||||
|
||||
DECL_FATTN_TILE_CASE(320, 256);
|
||||
@@ -3,7 +3,7 @@
|
||||
from glob import glob
|
||||
import os
|
||||
|
||||
HEAD_SIZES_KQ = [40, 64, 72, 80, 96, 112, 128, 256, 512, 576]
|
||||
HEAD_SIZES_KQ = [40, 64, 72, 80, 96, 112, 128, 256, 320, 512, 576]
|
||||
|
||||
TYPES_KV = ["GGML_TYPE_F16", "GGML_TYPE_Q4_0", "GGML_TYPE_Q4_1", "GGML_TYPE_Q5_0", "GGML_TYPE_Q5_1", "GGML_TYPE_Q8_0", "GGML_TYPE_BF16"]
|
||||
|
||||
@@ -62,7 +62,7 @@ for filename in glob("*.cu"):
|
||||
os.remove(filename)
|
||||
|
||||
for head_size_kq in HEAD_SIZES_KQ:
|
||||
head_size_v = head_size_kq if head_size_kq != 576 else 512
|
||||
head_size_v = 256 if head_size_kq == 320 else (head_size_kq if head_size_kq != 576 else 512)
|
||||
with open(f"fattn-tile-instance-dkq{head_size_kq}-dv{head_size_v}.cu", "w") as f:
|
||||
f.write(SOURCE_FATTN_TILE.format(head_size_kq=head_size_kq, head_size_v=head_size_v))
|
||||
|
||||
@@ -84,13 +84,16 @@ for ncols in [8, 16, 32, 64]:
|
||||
continue
|
||||
if head_size_kq == 72:
|
||||
continue
|
||||
if head_size_kq == 512 and ncols2 not in (4, 8):
|
||||
# Skip compilation of unused ncols2 values for niche head sizes:
|
||||
if head_size_kq == 320 and ncols2 != 32: # Mistral Small 4
|
||||
continue
|
||||
if head_size_kq != 576 and ncols2 in (16, 32):
|
||||
if head_size_kq == 512 and ncols2 not in (4, 8): # Gemma 4
|
||||
continue
|
||||
if head_size_kq == 576 and ncols2 not in (4, 16, 32):
|
||||
if head_size_kq == 576 and ncols2 not in (4, 16, 32): # Deepseek, GLM 4.7 Flash
|
||||
continue
|
||||
head_size_v = head_size_kq if head_size_kq != 576 else 512
|
||||
if head_size_kq not in (320, 576) and ncols2 in (16, 32):
|
||||
continue
|
||||
head_size_v = 256 if head_size_kq == 320 else (head_size_kq if head_size_kq != 576 else 512)
|
||||
f.write(SOURCE_FATTN_MMA_CASE.format(ncols1=ncols1, ncols2=ncols2, head_size_kq=head_size_kq, head_size_v=head_size_v))
|
||||
|
||||
for type in TYPES_MMQ:
|
||||
|
||||
@@ -3815,7 +3815,7 @@ struct test_mul_mat : public test_case {
|
||||
|
||||
double max_nmse_err(ggml_backend_t backend) override {
|
||||
// for blackwell we quantize activations to mxfp4 instead of q8_1 so we add higher tolerance
|
||||
if (type_a == GGML_TYPE_MXFP4 && backend_has_feature(backend, "BLACKWELL_NATIVE_FP4")) {
|
||||
if ((type_a == GGML_TYPE_MXFP4 || type_a == GGML_TYPE_NVFP4) && backend_has_feature(backend, "BLACKWELL_NATIVE_FP4")) {
|
||||
return 2e-2;
|
||||
}
|
||||
return max_nmse_err();
|
||||
@@ -3951,7 +3951,7 @@ struct test_mul_mat_id : public test_case {
|
||||
|
||||
double max_nmse_err(ggml_backend_t backend) override {
|
||||
// for blackwell we quantize activations to mxfp4 instead of q8_1 so we add higher tolerance
|
||||
if (type_a == GGML_TYPE_MXFP4 && backend_has_feature(backend, "BLACKWELL_NATIVE_FP4")) {
|
||||
if ((type_a == GGML_TYPE_MXFP4 || type_a == GGML_TYPE_NVFP4) && backend_has_feature(backend, "BLACKWELL_NATIVE_FP4")) {
|
||||
return 2e-2;
|
||||
}
|
||||
return max_nmse_err();
|
||||
|
||||
Reference in New Issue
Block a user