Compare commits

..

1 Commits

Author SHA1 Message Date
fairydreaming 0eca4d490e cuda : prevent integer truncation and overflow errors when using KQ mask strides in flash_attn_mask_to_KV_max kernel (#24945)
Co-authored-by: Stanisław Szymczyk <sszymczy@gmail.com>
2026-06-30 20:47:05 +02:00
+3 -3
View File
@@ -664,7 +664,7 @@ constexpr __device__ dequantize_V_t get_dequantize_V() {
template <int ncols1>
__launch_bounds__(FATTN_KQ_STRIDE/2, 1)
static __global__ void flash_attn_mask_to_KV_max(
const half2 * __restrict__ mask, int * __restrict__ KV_max, const int ne30, const int s31, const int s33) {
const half2 * __restrict__ mask, int * __restrict__ KV_max, const int ne30, const int64_t s31, const int64_t s33) {
const int ne31 = gridDim.x;
const int tid = threadIdx.x;
const int sequence = blockIdx.y;
@@ -1089,8 +1089,8 @@ void launch_fattn(
// Only worth the overhead if there is at lease one FATTN_KQ_STRIDE x FATTN_KQ_STRIDE square to be skipped or
// multiple sequences of possibly different lengths.
if (mask && K->ne[1] % FATTN_KQ_STRIDE == 0 && (Q->ne[1] >= 1024 || Q->ne[3] > 1)) {
const int s31 = mask->nb[1] / sizeof(half2);
const int s33 = mask->nb[3] / sizeof(half2);
const int64_t s31 = mask->nb[1] / sizeof(half2);
const int64_t s33 = mask->nb[3] / sizeof(half2);
const dim3 blocks_num_KV_max(ntiles_x, Q->ne[3], 1);
const dim3 block_dim_KV_max(FATTN_KQ_STRIDE/2, 1, 1);