|
|
|
|
@@ -676,9 +676,96 @@ static __global__ void flash_attn_mask_to_KV_max(
|
|
|
|
|
|
|
|
|
|
template<int D, int ncols1, int ncols2> // D == head size
|
|
|
|
|
__launch_bounds__(D, 1)
|
|
|
|
|
static __global__ void flash_attn_stream_k_fixup(
|
|
|
|
|
float * __restrict__ dst, const float2 * __restrict__ dst_fixup, const int ne01, const int ne02, const int ne03,
|
|
|
|
|
const int ne11, const int ne12, const int nbatch_fa) {
|
|
|
|
|
static __global__ void flash_attn_stream_k_fixup_uniform(
|
|
|
|
|
float * __restrict__ dst,
|
|
|
|
|
const float2 * __restrict__ dst_fixup,
|
|
|
|
|
const int ne01, const int ne02,
|
|
|
|
|
const int ne12, const int nblocks_stream_k,
|
|
|
|
|
const int gqa_ratio,
|
|
|
|
|
const int blocks_per_tile,
|
|
|
|
|
const uint3 fd_iter_j_z_ne12,
|
|
|
|
|
const uint3 fd_iter_j_z,
|
|
|
|
|
const uint3 fd_iter_j) {
|
|
|
|
|
constexpr int ncols = ncols1*ncols2;
|
|
|
|
|
|
|
|
|
|
const int tile_idx = blockIdx.x; // One block per output tile.
|
|
|
|
|
const int j = blockIdx.y;
|
|
|
|
|
const int c = blockIdx.z;
|
|
|
|
|
const int jc = j*ncols2 + c;
|
|
|
|
|
const int tid = threadIdx.x;
|
|
|
|
|
|
|
|
|
|
// nblocks_stream_k is a multiple of ntiles_dst (== gridDim.x), so each tile gets the same number of blocks.
|
|
|
|
|
const int b_first = tile_idx * blocks_per_tile;
|
|
|
|
|
const int b_last = b_first + blocks_per_tile - 1;
|
|
|
|
|
|
|
|
|
|
const float * dst_fixup_data = ((const float *) dst_fixup) + nblocks_stream_k*(2*2*ncols);
|
|
|
|
|
|
|
|
|
|
// z_KV == K/V head index, zt_gqa = Q head start index per K/V head, jt = token position start index
|
|
|
|
|
const uint2 dm0 = fast_div_modulo(tile_idx, fd_iter_j_z_ne12);
|
|
|
|
|
const uint2 dm1 = fast_div_modulo(dm0.y, fd_iter_j_z);
|
|
|
|
|
const uint2 dm2 = fast_div_modulo(dm1.y, fd_iter_j);
|
|
|
|
|
|
|
|
|
|
const int sequence = dm0.x;
|
|
|
|
|
const int z_KV = dm1.x;
|
|
|
|
|
const int zt_gqa = dm2.x;
|
|
|
|
|
const int jt = dm2.y;
|
|
|
|
|
|
|
|
|
|
const int zt_Q = z_KV*gqa_ratio + zt_gqa*ncols2; // Global Q head start index.
|
|
|
|
|
|
|
|
|
|
if (jt*ncols1 + j >= ne01 || zt_gqa*ncols2 + c >= gqa_ratio) {
|
|
|
|
|
return;
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
dst += sequence*ne02*ne01*D + jt*ne02*(ncols1*D) + zt_Q*D + (j*ne02 + c)*D + tid;
|
|
|
|
|
|
|
|
|
|
// Load the partial result that needs a fixup
|
|
|
|
|
float dst_val = *dst;
|
|
|
|
|
float max_val;
|
|
|
|
|
float rowsum;
|
|
|
|
|
{
|
|
|
|
|
const float2 tmp = dst_fixup[b_last*ncols + jc];
|
|
|
|
|
max_val = tmp.x;
|
|
|
|
|
rowsum = tmp.y;
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
// Combine with all previous blocks in this tile.
|
|
|
|
|
for (int bidx = b_last - 1; bidx >= b_first; --bidx) {
|
|
|
|
|
const float dst_add = dst_fixup_data[bidx*ncols*D + jc*D + tid];
|
|
|
|
|
|
|
|
|
|
const float2 tmp = dst_fixup[(nblocks_stream_k + bidx)*ncols + jc];
|
|
|
|
|
|
|
|
|
|
const float max_val_new = fmaxf(max_val, tmp.x);
|
|
|
|
|
|
|
|
|
|
const float diff_val = max_val - max_val_new;
|
|
|
|
|
const float diff_add = tmp.x - max_val_new;
|
|
|
|
|
|
|
|
|
|
const float scale_val = diff_val >= SOFTMAX_FTZ_THRESHOLD ? expf(diff_val) : 0.0f;
|
|
|
|
|
const float scale_add = diff_add >= SOFTMAX_FTZ_THRESHOLD ? expf(diff_add) : 0.0f;
|
|
|
|
|
|
|
|
|
|
dst_val = scale_val*dst_val + scale_add*dst_add;
|
|
|
|
|
rowsum = scale_val*rowsum + scale_add*tmp.y;
|
|
|
|
|
|
|
|
|
|
max_val = max_val_new;
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
// Write back final result:
|
|
|
|
|
*dst = dst_val / rowsum;
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
// General fixup kernel for the case where the number of blocks per tile is not uniform across tiles
|
|
|
|
|
// (blocks_num.x not a multiple of ntiles_dst)
|
|
|
|
|
template <int D, int ncols1, int ncols2> // D == head size
|
|
|
|
|
__launch_bounds__(D, 1)
|
|
|
|
|
static __global__ void flash_attn_stream_k_fixup_general(
|
|
|
|
|
float * __restrict__ dst,
|
|
|
|
|
const float2 * __restrict__ dst_fixup,
|
|
|
|
|
const int ne01, const int ne02,
|
|
|
|
|
const int gqa_ratio,
|
|
|
|
|
const int total_work,
|
|
|
|
|
const uint3 fd_iter_k_j_z_ne12,
|
|
|
|
|
const uint3 fd_iter_k_j_z,
|
|
|
|
|
const uint3 fd_iter_k_j,
|
|
|
|
|
const uint3 fd_iter_k) {
|
|
|
|
|
constexpr int ncols = ncols1*ncols2;
|
|
|
|
|
|
|
|
|
|
const int bidx0 = blockIdx.x;
|
|
|
|
|
@@ -689,27 +776,26 @@ static __global__ void flash_attn_stream_k_fixup(
|
|
|
|
|
|
|
|
|
|
const float * dst_fixup_data = ((const float *) dst_fixup) + gridDim.x*(2*2*ncols);
|
|
|
|
|
|
|
|
|
|
const int gqa_ratio = ne02 / ne12; // With grouped query attention there are > 1 Q matrices per K, V matrix.
|
|
|
|
|
|
|
|
|
|
const int iter_k = (ne11 + (nbatch_fa - 1)) / nbatch_fa;
|
|
|
|
|
const int iter_j = (ne01 + (ncols1 - 1)) / ncols1;
|
|
|
|
|
const int iter_z_gqa = (gqa_ratio + (ncols2 - 1)) / ncols2;
|
|
|
|
|
|
|
|
|
|
const int kbc0 = int64_t(bidx0 + 0)*(iter_k*iter_j*iter_z_gqa*ne12*ne03) / gridDim.x;
|
|
|
|
|
const int kbc0_stop = int64_t(bidx0 + 1)*(iter_k*iter_j*iter_z_gqa*ne12*ne03) / gridDim.x;
|
|
|
|
|
const int kbc0 = int64_t(bidx0 + 0)*total_work / gridDim.x;
|
|
|
|
|
const int kbc0_stop = int64_t(bidx0 + 1)*total_work / gridDim.x;
|
|
|
|
|
|
|
|
|
|
const bool did_not_have_any_data = kbc0 == kbc0_stop;
|
|
|
|
|
const bool wrote_beginning_of_tile = kbc0 % iter_k == 0;
|
|
|
|
|
const bool did_not_write_last = kbc0/iter_k == kbc0_stop/iter_k && kbc0_stop % iter_k != 0;
|
|
|
|
|
const bool wrote_beginning_of_tile = fastmodulo(kbc0, fd_iter_k) == 0;
|
|
|
|
|
const bool did_not_write_last = fastdiv(kbc0, fd_iter_k) == fastdiv(kbc0_stop, fd_iter_k) && fastmodulo(kbc0_stop, fd_iter_k) != 0;
|
|
|
|
|
if (did_not_have_any_data || wrote_beginning_of_tile || did_not_write_last) {
|
|
|
|
|
return;
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
// z_KV == K/V head index, zt_gqa = Q head start index per K/V head, jt = token position start index
|
|
|
|
|
const int sequence = kbc0 /(iter_k*iter_j*iter_z_gqa*ne12);
|
|
|
|
|
const int z_KV = (kbc0 - iter_k*iter_j*iter_z_gqa*ne12 * sequence)/(iter_k*iter_j*iter_z_gqa);
|
|
|
|
|
const int zt_gqa = (kbc0 - iter_k*iter_j*iter_z_gqa*ne12 * sequence - iter_k*iter_j*iter_z_gqa * z_KV)/(iter_k*iter_j);
|
|
|
|
|
const int jt = (kbc0 - iter_k*iter_j*iter_z_gqa*ne12 * sequence - iter_k*iter_j*iter_z_gqa * z_KV - iter_k*iter_j * zt_gqa) / iter_k;
|
|
|
|
|
const uint2 dm0 = fast_div_modulo(kbc0, fd_iter_k_j_z_ne12);
|
|
|
|
|
const uint2 dm1 = fast_div_modulo(dm0.y, fd_iter_k_j_z);
|
|
|
|
|
const uint2 dm2 = fast_div_modulo(dm1.y, fd_iter_k_j);
|
|
|
|
|
const uint2 dm3 = fast_div_modulo(dm2.y, fd_iter_k);
|
|
|
|
|
|
|
|
|
|
const int sequence = dm0.x;
|
|
|
|
|
const int z_KV = dm1.x;
|
|
|
|
|
const int zt_gqa = dm2.x;
|
|
|
|
|
const int jt = dm3.x;
|
|
|
|
|
|
|
|
|
|
const int zt_Q = z_KV*gqa_ratio + zt_gqa*ncols2; // Global Q head start index.
|
|
|
|
|
|
|
|
|
|
@@ -733,10 +819,11 @@ static __global__ void flash_attn_stream_k_fixup(
|
|
|
|
|
|
|
|
|
|
// Iterate over previous blocks and compute the combined results.
|
|
|
|
|
// All CUDA blocks that get here must have a previous block that needs a fixup.
|
|
|
|
|
const int tile_kbc0 = fastdiv(kbc0, fd_iter_k);
|
|
|
|
|
int bidx = bidx0 - 1;
|
|
|
|
|
int kbc_stop = kbc0;
|
|
|
|
|
while(true) {
|
|
|
|
|
const int kbc = int64_t(bidx)*(iter_k*iter_j*iter_z_gqa*ne12*ne03) / gridDim.x;
|
|
|
|
|
const int kbc = int64_t(bidx)*total_work / gridDim.x;
|
|
|
|
|
if (kbc == kbc_stop) { // Did not have any data.
|
|
|
|
|
bidx--;
|
|
|
|
|
kbc_stop = kbc;
|
|
|
|
|
@@ -762,7 +849,7 @@ static __global__ void flash_attn_stream_k_fixup(
|
|
|
|
|
max_val = max_val_new;
|
|
|
|
|
|
|
|
|
|
// If this block started in a previous tile we are done and don't need to combine additional partial results.
|
|
|
|
|
if (kbc % iter_k == 0 || kbc/iter_k < kbc0/iter_k) {
|
|
|
|
|
if (fastmodulo(kbc, fd_iter_k) == 0 || fastdiv(kbc, fd_iter_k) < tile_kbc0) {
|
|
|
|
|
break;
|
|
|
|
|
}
|
|
|
|
|
bidx--;
|
|
|
|
|
@@ -976,14 +1063,28 @@ void launch_fattn(
|
|
|
|
|
const int tiles_nwaves = (ntiles_dst + max_blocks - 1) / max_blocks;
|
|
|
|
|
const int tiles_efficiency_percent = 100 * ntiles_dst / (max_blocks*tiles_nwaves);
|
|
|
|
|
|
|
|
|
|
const int nblocks_stream_k = std::min(max_blocks, ntiles_KV*ntiles_dst);
|
|
|
|
|
|
|
|
|
|
const bool use_stream_k = cc >= GGML_CUDA_CC_ADA_LOVELACE || amd_wmma_available(cc) || tiles_efficiency_percent < 75;
|
|
|
|
|
|
|
|
|
|
blocks_num.x = use_stream_k ? nblocks_stream_k : ntiles_dst;
|
|
|
|
|
blocks_num.x = ntiles_dst;
|
|
|
|
|
blocks_num.y = 1;
|
|
|
|
|
blocks_num.z = 1;
|
|
|
|
|
|
|
|
|
|
if(use_stream_k) {
|
|
|
|
|
const int nblocks_stream_k_raw = std::min(max_blocks, ntiles_KV*ntiles_dst);
|
|
|
|
|
// Round down to a multiple of ntiles_dst so that each output tile gets the same number of blocks (avoids fixup).
|
|
|
|
|
// Only do this if the occupancy loss from rounding is acceptable.
|
|
|
|
|
const int nblocks_stream_k_rounded = (nblocks_stream_k_raw / ntiles_dst) * ntiles_dst;
|
|
|
|
|
const int max_efficiency_loss_percent = 5;
|
|
|
|
|
const int efficiency_loss_percent = nblocks_stream_k_rounded > 0
|
|
|
|
|
? 100 * (nblocks_stream_k_raw - nblocks_stream_k_rounded) / nblocks_stream_k_raw
|
|
|
|
|
: 100;
|
|
|
|
|
const int nblocks_stream_k = efficiency_loss_percent <= max_efficiency_loss_percent
|
|
|
|
|
? nblocks_stream_k_rounded
|
|
|
|
|
: nblocks_stream_k_raw;
|
|
|
|
|
|
|
|
|
|
blocks_num.x = nblocks_stream_k;
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
if (ntiles_dst % blocks_num.x != 0) { // Fixup is only needed if the SMs work on fractional tiles.
|
|
|
|
|
dst_tmp_meta.alloc((size_t(blocks_num.x) * ncols * (2 + DV/2)));
|
|
|
|
|
}
|
|
|
|
|
@@ -1063,13 +1164,40 @@ void launch_fattn(
|
|
|
|
|
CUDA_CHECK(cudaGetLastError());
|
|
|
|
|
|
|
|
|
|
if (stream_k) {
|
|
|
|
|
if (ntiles_dst % blocks_num.x != 0) { // Fixup is only needed if the SMs work on fractional tiles.
|
|
|
|
|
if ((int)blocks_num.x % ntiles_dst == 0 && (int)blocks_num.x > ntiles_dst) {
|
|
|
|
|
// Optimized fixup: nblocks_stream_k is a multiple of ntiles_dst, launch one block per tile.
|
|
|
|
|
const int nblocks_sk = (int)blocks_num.x;
|
|
|
|
|
const int bpt = nblocks_sk / ntiles_dst;
|
|
|
|
|
|
|
|
|
|
const uint3 fd0 = init_fastdiv_values(ntiles_x * ntiles_z_gqa * K->ne[2]);
|
|
|
|
|
const uint3 fd1 = init_fastdiv_values(ntiles_x * ntiles_z_gqa);
|
|
|
|
|
const uint3 fd2 = init_fastdiv_values(ntiles_x);
|
|
|
|
|
|
|
|
|
|
const dim3 block_dim_combine(DV, 1, 1);
|
|
|
|
|
const dim3 blocks_num_combine = {(unsigned)ntiles_dst, ncols1, ncols2};
|
|
|
|
|
|
|
|
|
|
flash_attn_stream_k_fixup_uniform<DV, ncols1, ncols2>
|
|
|
|
|
<<<blocks_num_combine, block_dim_combine, 0, main_stream>>>
|
|
|
|
|
((float *) KQV->data, dst_tmp_meta.ptr,
|
|
|
|
|
Q->ne[1], Q->ne[2], K->ne[2], nblocks_sk,
|
|
|
|
|
gqa_ratio, bpt, fd0, fd1, fd2);
|
|
|
|
|
} else if (ntiles_dst % blocks_num.x != 0) {
|
|
|
|
|
// General fixup for the cases where nblocks_stream_k < ntiles_dst.
|
|
|
|
|
const int total_work = ntiles_KV * ntiles_dst;
|
|
|
|
|
|
|
|
|
|
const uint3 fd_k_j_z_ne12 = init_fastdiv_values(ntiles_KV * ntiles_x * ntiles_z_gqa * K->ne[2]);
|
|
|
|
|
const uint3 fd_k_j_z = init_fastdiv_values(ntiles_KV * ntiles_x * ntiles_z_gqa);
|
|
|
|
|
const uint3 fd_k_j = init_fastdiv_values(ntiles_KV * ntiles_x);
|
|
|
|
|
const uint3 fd_k = init_fastdiv_values(ntiles_KV);
|
|
|
|
|
|
|
|
|
|
const dim3 block_dim_combine(DV, 1, 1);
|
|
|
|
|
const dim3 blocks_num_combine = {blocks_num.x, ncols1, ncols2};
|
|
|
|
|
|
|
|
|
|
flash_attn_stream_k_fixup<DV, ncols1, ncols2>
|
|
|
|
|
flash_attn_stream_k_fixup_general<DV, ncols1, ncols2>
|
|
|
|
|
<<<blocks_num_combine, block_dim_combine, 0, main_stream>>>
|
|
|
|
|
((float *) KQV->data, dst_tmp_meta.ptr, Q->ne[1], Q->ne[2], Q->ne[3], K->ne[1], K->ne[2], nbatch_fa);
|
|
|
|
|
((float *) KQV->data, dst_tmp_meta.ptr,
|
|
|
|
|
Q->ne[1], Q->ne[2], gqa_ratio, total_work,
|
|
|
|
|
fd_k_j_z_ne12, fd_k_j_z, fd_k_j, fd_k);
|
|
|
|
|
}
|
|
|
|
|
} else if (parallel_blocks > 1) {
|
|
|
|
|
const dim3 block_dim_combine(DV, 1, 1);
|
|
|
|
|
|