Compare commits

...

9 Commits
b6029 ... b6038

Author SHA1 Message Date
Ed Addario
e9192bec56 quantize : fix using combined imatrix GGUFs (multiple datasets) (#14973) 2025-07-30 21:11:56 +02:00
Daniel Bevenius
41e78c567e server : add support for embd_normalize parameter (#14964)
This commit adds support for the `embd_normalize` parameter in the
server code.

The motivation for this is that currently if the server is started with
a pooling type that is not `none`, then Euclidean/L2 normalization will
be the normalization method used for embeddings. However, this is not
always the desired behavior, and users may want to use other
normalization (or none) and this commit allows that.

Example usage:
```console
curl --request POST \
    --url http://localhost:8080/embedding \
    --header "Content-Type: application/json" \
    --data '{"input": "Hello world today", "embd_normalize": -1}'
```
2025-07-30 18:07:11 +02:00
uvos
ad4a700117 HIP: enable mfma mmq on gfx908 and gfx90a for select datatypes and shapes (#14949) 2025-07-30 17:38:06 +02:00
Georgi Gerganov
e32a4ec60e sync : ggml
ggml-ci
2025-07-30 17:33:11 +03:00
Kai Pastor
e228de9449 cmake : Fix BLAS link interface (ggml/1316) 2025-07-30 17:33:11 +03:00
Kai Pastor
73a8e5ca03 vulkan : fix 32-bit builds (ggml/1313)
The pipeline member can be cast to VkPipeline.
This is a VkPipeline_T* on 64 bit but a uint64_t on 32 bit.
Cf. VK_DEFINE_NON_DISPATCHABLE_HANDLE documentation.
2025-07-30 17:33:11 +03:00
Johannes Gäßler
92b8810ec7 CUDA: skip masked KV slices for all FA kernels (#14924) 2025-07-30 15:46:13 +02:00
Georgi Gerganov
00131d6eaf tests : update for LLAMA_SET_ROWS=1 (#14961)
* test-thread-safety : each context uses a single sequence

* embedding : handle --parallel argument

ggml-ci

* save-load : handle -np 1

ggml-ci

* thread-safety : avoid overriding threads, reduce test case arg

ggml-ci
2025-07-30 15:12:02 +03:00
Georgi Gerganov
1e15bfd42c graph : fix stack-use-after-return (#14960)
ggml-ci
2025-07-30 13:52:11 +03:00
23 changed files with 206 additions and 86 deletions

View File

@@ -81,6 +81,14 @@ int main(int argc, char ** argv) {
params.embedding = true;
// if the number of prompts that would be encoded is known in advance, it's more efficient to specify the
// --parallel argument accordingly. for convenience, if not specified, we fallback to unified KV cache
// in order to support any number of prompts
if (params.n_parallel == 1) {
LOG_INF("%s: n_parallel == 1 -> unified KV cache is enabled\n", __func__);
params.kv_unified = true;
}
// utilize the full context
if (params.n_batch < params.n_ctx) {
LOG_WRN("%s: setting batch size to %d\n", __func__, params.n_ctx);

View File

@@ -15,6 +15,12 @@ int main(int argc, char ** argv) {
return 1;
}
if (params.n_parallel == 1) {
// the example uses 2 sequences, so when n_parallel == 1, we need to enable unified kv cache
printf("%s: n_parallel == 1, enabling unified kv cache\n", __func__);
params.kv_unified = true;
}
common_init();
if (params.n_predict < 0) {

View File

@@ -34,8 +34,8 @@ if (NOT GGML_SHARED_LIB)
if (GGML_BLAS)
find_dependency(BLAS)
list(APPEND GGML_CPU_INTERFACE_LINK_LIBRARIES ${BLAS_LIBRARIES})
list(APPEND GGML_CPU_INTERFACE_LINK_OPTIONS ${BLAS_LINKER_FLAGS})
list(APPEND GGML_BLAS_INTERFACE_LINK_LIBRARIES ${BLAS_LIBRARIES})
list(APPEND GGML_BLAS_INTERFACE_LINK_OPTIONS ${BLAS_LINKER_FLAGS})
endif()
if (GGML_CUDA)

View File

@@ -227,9 +227,9 @@ typedef float2 dfloat2;
#define FP16_MMA_AVAILABLE
#endif // defined(GGML_HIP_ROCWMMA_FATTN) && (defined(CDNA) || defined(RDNA3) || (defined(GGML_HIP_ROCWMMA_FATTN_GFX12) && defined(RDNA4)))
#if defined(GGML_USE_HIP) && defined(CDNA3) && !defined(GGML_HIP_NO_MMQ_MFMA)
#if defined(GGML_USE_HIP) && defined(CDNA) && !defined(GGML_HIP_NO_MMQ_MFMA)
#define AMD_MFMA_AVAILABLE
#endif // defined(GGML_USE_HIP) && defined(CDNA3) && !defined(GGML_HIP_NO_MMQ_MFMA)
#endif // defined(GGML_USE_HIP) && defined(CDNA) && !defined(GGML_HIP_NO_MMQ_MFMA)
#if !defined(GGML_USE_HIP) && __CUDA_ARCH__ >= GGML_CUDA_CC_TURING
#define NEW_MMA_AVAILABLE
@@ -293,10 +293,9 @@ static bool fp32_mma_hardware_available(const int cc) {
return GGML_CUDA_CC_IS_CDNA(cc);
}
// AMD CDNA3 matrix cores.. Will add support for other CDNA generations later.
static bool amd_mfma_available(const int cc) {
#if !defined(GGML_HIP_NO_MMQ_MFMA)
return GGML_CUDA_CC_IS_CDNA3(cc);
return GGML_CUDA_CC_IS_CDNA(cc);
#else
return false;
#endif //!defined(GGML_HIP_NO_MMQ_MFMA)
@@ -432,6 +431,20 @@ static __global__ void reduce_rows_f32(const float * x, float * dst, const int n
dst[row] = norm ? sum / ncols : sum;
}
// Warp-level logical AND: returns a nonzero value iff x is nonzero on every participating lane.
// On HIP the reduction runs over groups of `width` lanes via XOR shuffles;
// on CUDA it uses the warp vote intrinsic and therefore only supports width == WARP_SIZE.
template<int width = WARP_SIZE>
static __device__ __forceinline__ int warp_reduce_all(int x) {
#ifdef GGML_USE_HIP
    // Butterfly reduction: after each step every lane holds the AND over a group twice as large.
    // The shuffle must be the LEFT operand of &&. With `x && shuffle(...)` a lane whose x is already 0
    // short-circuits and never executes the shuffle, so its partner lane would exchange data with a lane
    // that is not taking part in the operation.
#pragma unroll
    for (int offset = width/2; offset > 0; offset >>= 1) {
        x = __shfl_xor_sync(0xffffffff, x, offset, width) && x;
    }
    return x;
#else
    static_assert(width == WARP_SIZE, "width != WARP_SIZE not implemented");
    return __all_sync(0xffffffff, x);
#endif // GGML_USE_HIP
}
template<int width = WARP_SIZE>
static __device__ __forceinline__ float warp_reduce_max(float x) {
#pragma unroll

View File

@@ -15,6 +15,7 @@ typedef void (* fattn_kernel_t)(
const char * __restrict__ K,
const char * __restrict__ V,
const char * __restrict__ mask,
const int * __restrict__ KV_max,
float * __restrict__ dst,
float2 * __restrict__ dst_meta,
const float scale,
@@ -500,6 +501,55 @@ constexpr __device__ dequantize_1_f32_t get_dequantize_1_f32(ggml_type type_V) {
nullptr;
}
// Scans the attention mask from the end of the KV dimension towards the start and stores, per
// (sequence, tile of ncols1 mask rows), the exclusive upper KV bound after which every mask value is infinite.
// The FlashAttention kernels can use this bound to stop iterating early.
//
// Grid:  blockIdx.x = tile index jt (gridDim.x tiles), blockIdx.y = sequence index.
// Block: one thread per half2 of a FATTN_KQ_STRIDE-wide slice (see __launch_bounds__: FATTN_KQ_STRIDE/2 threads).
//
// mask:   mask data viewed as half2; s31 and s33 are the row and sequence strides in units of half2.
// KV_max: output, one int per (sequence, tile), written at index sequence*gridDim.x + jt.
// ne30:   number of FATTN_KQ_STRIDE-wide slices to scan.
template <int ncols1>
__launch_bounds__(FATTN_KQ_STRIDE/2, 1)
static __global__ void flash_attn_mask_to_KV_max(
const half2 * __restrict__ mask, int * __restrict__ KV_max, const int ne30, const int s31, const int s33) {
const int ne31 = gridDim.x;
const int tid = threadIdx.x;
const int sequence = blockIdx.y;
const int jt = blockIdx.x;
// Move to the first mask row handled by this block.
mask += sequence*s33 + jt*ncols1*s31;
// Per-warp partial results. Entries are preset to 1 (the neutral element of AND) so that slots
// which no warp of this block writes to do not affect the block-wide reduction below.
__shared__ int buf_iw[WARP_SIZE];
if (tid < WARP_SIZE) {
buf_iw[tid] = 1;
}
__syncthreads();
// Start at the last slice and walk backwards one slice at a time.
int KV_max_sj = (ne30 - 1) * FATTN_KQ_STRIDE;
for (; KV_max_sj >= 0; KV_max_sj -= FATTN_KQ_STRIDE) {
int all_inf = 1;
// Each thread checks one half2 (two mask values) of the current slice in each of the ncols1 rows.
#pragma unroll
for (int j = 0; j < ncols1; ++j) {
const float2 tmp = __half22float2(mask[j*s31 + KV_max_sj/2 + tid]);
all_inf = all_inf && int(isinf(tmp.x)) && int(isinf(tmp.y));
}
// Stage 1: AND within each warp, lane 0 of every warp publishes the warp result.
all_inf = warp_reduce_all(all_inf);
if (tid % WARP_SIZE == 0) {
buf_iw[tid / WARP_SIZE] = all_inf;
}
__syncthreads();
// Stage 2: every warp loads all per-warp results and reduces them again -> block-wide AND.
all_inf = buf_iw[tid % WARP_SIZE];
// Make sure all reads of buf_iw are done before the next iteration overwrites it.
__syncthreads();
all_inf = warp_reduce_all(all_inf);
if (!all_inf) {
// This slice contains at least one finite mask value: the bound is the end of this slice.
KV_max_sj += FATTN_KQ_STRIDE;
break;
}
}
// If every slice was fully infinite the loop ends without the break and KV_max_sj == -FATTN_KQ_STRIDE.
// Only one thread per block writes the result.
if (threadIdx.x != 0) {
return;
}
KV_max[sequence*ne31 + jt] = KV_max_sj;
}
template<int D, int ncols1, int ncols2> // D == head size
__launch_bounds__(D, 1)
static __global__ void flash_attn_stream_k_fixup(
@@ -711,6 +761,7 @@ void launch_fattn(
ggml_cuda_pool_alloc<half> K_f16(pool);
ggml_cuda_pool_alloc<half> V_f16(pool);
ggml_cuda_pool_alloc<int> KV_max(pool);
ggml_cuda_pool_alloc<float> dst_tmp(pool);
ggml_cuda_pool_alloc<float2> dst_tmp_meta(pool);
@@ -779,11 +830,30 @@ void launch_fattn(
V_data = (char *) V_f16.ptr;
}
int parallel_blocks = 1;
const int ntiles_x = ((Q->ne[1] + ncols1 - 1) / ncols1);
const int ntiles_total = ntiles_x * (Q->ne[2] / ncols2) * Q->ne[3];
// Optional optimization where the mask is scanned to determine whether part of the calculation can be skipped.
// Only worth the overhead if there is at least one FATTN_KQ_STRIDE x FATTN_KQ_STRIDE square to be skipped or
// multiple sequences of possibly different lengths.
if (mask && (Q->ne[1] >= 1024 || Q->ne[3] > 1)) {
const int s31 = mask->nb[1] / sizeof(half2);
const int s33 = mask->nb[3] / sizeof(half2);
const dim3 blocks_num_KV_max(ntiles_x, Q->ne[3], 1);
const dim3 block_dim_KV_max(FATTN_KQ_STRIDE/2, 1, 1);
const int ne_KV_max = blocks_num_KV_max.x*blocks_num_KV_max.y;
const int iter_k = K->ne[1] / FATTN_KQ_STRIDE;
KV_max.alloc(ne_KV_max);
flash_attn_mask_to_KV_max<ncols1><<<blocks_num_KV_max, block_dim_KV_max, 0, main_stream>>>
((const half2 *) mask->data, KV_max.ptr, iter_k, s31, s33);
CUDA_CHECK(cudaGetLastError());
}
int parallel_blocks = 1;
const dim3 block_dim(warp_size, nwarps, 1);
int max_blocks_per_sm = 1; // Max. number of active blocks limited by occupancy.
CUDA_CHECK(cudaOccupancyMaxActiveBlocksPerMultiprocessor(&max_blocks_per_sm, fattn_kernel, block_dim.x * block_dim.y * block_dim.z, nbytes_shared));
@@ -870,6 +940,7 @@ void launch_fattn(
K_data,
V_data,
mask ? ((const char *) mask->data) : nullptr,
KV_max.ptr,
!stream_k && parallel_blocks > 1 ? dst_tmp.ptr : (float *) KQV->data, dst_tmp_meta.ptr,
scale, max_bias, m0, m1, n_head_log2, logit_softcap,
Q->ne[0], Q->ne[1], Q->ne[2], Q->ne[3], Q->nb[1], Q->nb[2], Q->nb[3],

View File

@@ -392,7 +392,8 @@ static __device__ __forceinline__ void flash_attn_ext_f16_load_mask(
}
}
template<int DKQ, int DV, int ncols1, int ncols2, int nwarps, int ntiles, bool use_logit_softcap, bool mla, bool needs_fixup, bool is_fixup, bool last_iter>
template<int DKQ, int DV, int ncols1, int ncols2, int nwarps, int ntiles,
bool use_logit_softcap, bool mla, bool needs_fixup, bool is_fixup, bool last_iter>
static __device__ __forceinline__ void flash_attn_ext_f16_iter(
const float2 * const __restrict__ Q_f2,
const half2 * const __restrict__ K_h2,
@@ -922,7 +923,8 @@ static __device__ __forceinline__ void flash_attn_ext_f16_process_tile(
}
// Iterate over ne11 == previous tokens:
for (int kb0 = kb0_start; kb0 < kb0_stop-1; ++kb0) {
int kb0 = kb0_start;
for (; kb0 < kb0_stop-1; ++kb0) {
constexpr bool last_iter = false;
flash_attn_ext_f16_iter<DKQ, DV, ncols1, ncols2, nwarps, ntiles, use_logit_softcap, mla, needs_fixup, is_fixup, last_iter>
(Q_f2, K_h2, V_h2, mask_h2, dstk, dstk_fixup, scale, slope, logit_softcap,
@@ -932,7 +934,7 @@ static __device__ __forceinline__ void flash_attn_ext_f16_process_tile(
constexpr bool last_iter = true;
flash_attn_ext_f16_iter<DKQ, DV, ncols1, ncols2, nwarps, ntiles, use_logit_softcap, mla, needs_fixup, is_fixup, last_iter>
(Q_f2, K_h2, V_h2, mask_h2, dstk, dstk_fixup, scale, slope, logit_softcap,
ne01, ne02, stride_K, stride_V, stride_mask, tile_Q, tile_K, tile_V, tile_mask, Q_B, VKQ_C, KQ_max, KQ_rowsum, kb0_stop-1);
ne01, ne02, stride_K, stride_V, stride_mask, tile_Q, tile_K, tile_V, tile_mask, Q_B, VKQ_C, KQ_max, KQ_rowsum, kb0);
}
// With multi-stage loading there is no __syncthreads at the end of the iter,
@@ -1204,6 +1206,7 @@ static __global__ void flash_attn_ext_f16(
const char * __restrict__ K,
const char * __restrict__ V,
const char * __restrict__ mask,
const int * __restrict__ KV_max,
float * __restrict__ dst,
float2 * __restrict__ dst_meta,
const float scale,
@@ -1280,7 +1283,11 @@ static __global__ void flash_attn_ext_f16(
const float slope = ncols2 == 1 ? get_alibi_slope(max_bias, head, n_head_log2, m0, m1) : 1.0f;
const int kb0_start_kernel = kb0_start * kb_niter;
const int kb0_stop_kernel = kb0_stop * kb_niter;
int kb0_stop_kernel = kb0_stop * kb_niter;
if (KV_max) {
kb0_stop_kernel = min(kb0_stop_kernel, KV_max[sequence*iter_j + jt] / c::nbatch_fa);
}
constexpr bool is_fixup = false; // All but (potentially) the last iterations write their data to dst rather than the fixup buffer.
if (kb0_start == 0) {
@@ -1321,7 +1328,11 @@ static __global__ void flash_attn_ext_f16(
const float slope = ncols2 == 1 ? get_alibi_slope(max_bias, head, n_head_log2, m0, m1) : 1.0f;
const int kb0_start_kernel = kb0_start * kb_niter;
const int kb0_stop_kernel = kb0_stop * kb_niter;
int kb0_stop_kernel = kb0_stop * kb_niter;
if (KV_max) {
kb0_stop_kernel = min(kb0_stop_kernel, KV_max[sequence*iter_j + jt] / c::nbatch_fa);
}
constexpr bool is_fixup = true; // Last index writes its data to fixup buffer to avoid data races with other blocks.
constexpr bool needs_fixup = false;

View File

@@ -13,6 +13,7 @@ static __global__ void flash_attn_tile_ext_f16(
const char * __restrict__ K,
const char * __restrict__ V,
const char * __restrict__ mask,
const int * __restrict__ KV_max,
float * __restrict__ dst,
float2 * __restrict__ dst_meta,
const float scale,
@@ -90,7 +91,8 @@ static __global__ void flash_attn_tile_ext_f16(
__syncthreads();
for (int k_VKQ_0 = blockIdx.y*FATTN_KQ_STRIDE_TILE_F16; k_VKQ_0 < ne11; k_VKQ_0 += gridDim.y*FATTN_KQ_STRIDE_TILE_F16) {
const int k_VKQ_max = KV_max ? KV_max[sequence*gridDim.x + blockIdx.x] : ne11;
for (int k_VKQ_0 = blockIdx.y*FATTN_KQ_STRIDE_TILE_F16; k_VKQ_0 < k_VKQ_max; k_VKQ_0 += gridDim.y*FATTN_KQ_STRIDE_TILE_F16) {
// Calculate KQ tile and keep track of new maximum KQ values:
half kqmax_new[ncols/nwarps];

View File

@@ -13,6 +13,7 @@ static __global__ void flash_attn_tile_ext_f32(
const char * __restrict__ K,
const char * __restrict__ V,
const char * __restrict__ mask,
const int * __restrict__ KV_max,
float * __restrict__ dst,
float2 * __restrict__ dst_meta,
const float scale,
@@ -99,7 +100,8 @@ static __global__ void flash_attn_tile_ext_f32(
__syncthreads();
for (int k_VKQ_0 = blockIdx.y*FATTN_KQ_STRIDE_TILE_F32; k_VKQ_0 < ne11; k_VKQ_0 += gridDim.y*FATTN_KQ_STRIDE_TILE_F32) {
const int k_VKQ_max = KV_max ? KV_max[sequence*gridDim.x + blockIdx.x] : ne11;
for (int k_VKQ_0 = blockIdx.y*FATTN_KQ_STRIDE_TILE_F32; k_VKQ_0 < k_VKQ_max; k_VKQ_0 += gridDim.y*FATTN_KQ_STRIDE_TILE_F32) {
// Calculate KQ tile and keep track of new maximum KQ values:
float kqmax_new[ncols/nwarps];

View File

@@ -16,6 +16,7 @@ static __global__ void flash_attn_vec_ext_f16(
const char * __restrict__ K,
const char * __restrict__ V,
const char * __restrict__ mask,
const int * __restrict__ KV_max,
float * __restrict__ dst,
float2 * __restrict__ dst_meta,
const float scale,
@@ -177,10 +178,11 @@ static __global__ void flash_attn_vec_ext_f16(
half2 VKQ[ncols] = {{0.0f, 0.0f}};
const int k_VKQ_max = KV_max ? KV_max[sequence*gridDim.x + blockIdx.x] : ne11;
K += blockIdx.y*D * nb11;
V += blockIdx.y*D * nb21;
maskh += blockIdx.y*D;
for (int k_VKQ_0 = blockIdx.y*D; k_VKQ_0 < ne11; k_VKQ_0 += gridDim.y*D,
for (int k_VKQ_0 = blockIdx.y*D; k_VKQ_0 < k_VKQ_max; k_VKQ_0 += gridDim.y*D,
// Increment pointers after each loop:
K += gridDim.y*D*nb11, V += gridDim.y*D*nb21, maskh += gridDim.y*D) {
@@ -191,29 +193,7 @@ static __global__ void flash_attn_vec_ext_f16(
for (int j = 0; j < ncols; ++j) {
maskh_shared[j*D + tid] = slopeh*maskh[j*ne11 + tid];
}
__syncthreads();
// When using multiple parallel sequences in llama.cpp, some KV slices can be fully masked out.
// In such cases, skip the KV slice.
// On AMD __all_sync would not work correctly because it assumes a warp size of 64.
#ifndef GGML_USE_HIP
bool skip = true;
#pragma unroll
for (int j = 0; j < ncols; ++j) {
#pragma unroll
for (int i0 = 0; i0 < D/2; i0 += WARP_SIZE) {
const int i = i0 + threadIdx.x;
const float2 tmp = __half22float2(((const half2 *) maskh_shared)[j*(D/2) + i]);
skip = skip && isinf(tmp.x) && isinf(tmp.y);
}
}
if (__all_sync(0xFFFFFFFF, skip)) {
__syncthreads();
continue;
}
#endif // GGML_USE_HIP
}
// For unknown reasons using a half array of size 1 for kqmax_new causes a performance regression,

View File

@@ -16,6 +16,7 @@ static __global__ void flash_attn_vec_ext_f32(
const char * __restrict__ K,
const char * __restrict__ V,
const char * __restrict__ mask,
const int * __restrict__ KV_max,
float * __restrict__ dst,
float2 * __restrict__ dst_meta,
const float scale,
@@ -183,10 +184,11 @@ static __global__ void flash_attn_vec_ext_f32(
float VKQ[ncols] = {0.0f};
const int k_VKQ_max = KV_max ? KV_max[sequence*gridDim.x + blockIdx.x] : ne11;
K += blockIdx.y*D * nb11;
V += blockIdx.y*D * nb21;
maskh += blockIdx.y*D;
for (int k_VKQ_0 = blockIdx.y*D; k_VKQ_0 < ne11; k_VKQ_0 += gridDim.y*D,
for (int k_VKQ_0 = blockIdx.y*D; k_VKQ_0 < k_VKQ_max; k_VKQ_0 += gridDim.y*D,
// Increment pointers after each loop:
K += gridDim.y*D*nb11, V += gridDim.y*D*nb21, maskh += gridDim.y*D) {
@@ -197,28 +199,7 @@ static __global__ void flash_attn_vec_ext_f32(
for (int j = 0; j < ncols; ++j) {
maskf_shared[j*D + tid] = slope*__half2float(maskh[j*ne11 + tid]);
}
__syncthreads();
// When using multiple parallel sequences in llama.cpp, some KV slices can be fully masked out.
// In such cases, skip the KV slice.
// On AMD __all_sync would not work correctly because it assumes a warp size of 64.
#ifndef GGML_USE_HIP
bool skip = true;
#pragma unroll
for (int j = 0; j < ncols; ++j) {
#pragma unroll
for (int i0 = 0; i0 < D; i0 += WARP_SIZE) {
const int i = i0 + threadIdx.x;
skip = skip && isinf(maskf_shared[j*D + i]);
}
}
if (__all_sync(0xFFFFFFFF, skip)) {
__syncthreads();
continue;
}
#endif // GGML_USE_HIP
}
float kqmax_new_arr[ncols];

View File

@@ -29,6 +29,7 @@ static __global__ void flash_attn_ext_f16(
const char * __restrict__ K,
const char * __restrict__ V,
const char * __restrict__ mask,
const int * __restrict__ KV_max,
float * __restrict__ dst,
float2 * __restrict__ dst_meta,
const float scale,
@@ -165,7 +166,8 @@ static __global__ void flash_attn_ext_f16(
__syncthreads();
// Iterate over ne11 == previous tokens:
for (int k_VKQ_0 = blockIdx.y*FATTN_KQ_STRIDE; k_VKQ_0 < ne11; k_VKQ_0 += gridDim.y*FATTN_KQ_STRIDE) {
const int k_VKQ_max = KV_max ? KV_max[sequence*gridDim.x + blockIdx.x] : ne11;
for (int k_VKQ_0 = blockIdx.y*FATTN_KQ_STRIDE; k_VKQ_0 < k_VKQ_max; k_VKQ_0 += gridDim.y*FATTN_KQ_STRIDE) {
// Calculate tile of KQ:
#pragma unroll
for (int i_KQ_0 = 0; i_KQ_0 < FATTN_KQ_STRIDE; i_KQ_0 += KQ_stride_tc) {

View File

@@ -315,7 +315,8 @@ void ggml_cuda_flash_attn_ext(ggml_backend_cuda_context & ctx, ggml_tensor * dst
const bool gqa_opt_applies = ((Q->ne[2] / K->ne[2]) % 2 == 0) && mask; // The mma-based kernels have GQA-specific optimizations
const bool mma_needs_data_conversion = K->type != GGML_TYPE_F16 || V->type != GGML_TYPE_F16;
const bool mma_faster_for_bs1 = new_mma_available(cc) && gqa_opt_applies && cc < GGML_CUDA_CC_ADA_LOVELACE && !mma_needs_data_conversion;
const bool mma_faster_for_bs1 = new_mma_available(cc) && gqa_opt_applies &&
(Q->ne[3] > 1 || cc < GGML_CUDA_CC_ADA_LOVELACE) && !mma_needs_data_conversion;
const bool can_use_vector_kernel = Q->ne[0] <= 256 && Q->ne[0] % (2*warp_size) == 0;
if (Q->ne[1] == 1 && can_use_vector_kernel && !mma_faster_for_bs1) {
if (prec == GGML_PREC_DEFAULT) {

View File

@@ -109,8 +109,8 @@ void ggml_cuda_mul_mat_q(
const int64_t s03 = src0->nb[3] / ts_src0;
const int64_t s3 = dst->nb[3] / ts_dst;
const bool use_stream_k = ((GGML_CUDA_CC_IS_NVIDIA(cc) && ggml_cuda_highest_compiled_arch(cc) >= GGML_CUDA_CC_VOLTA)
|| (GGML_CUDA_CC_IS_AMD(cc) && GGML_CUDA_CC_IS_CDNA3(cc)));
const bool use_stream_k = (GGML_CUDA_CC_IS_NVIDIA(cc) && ggml_cuda_highest_compiled_arch(cc) >= GGML_CUDA_CC_VOLTA)
|| GGML_CUDA_CC_IS_CDNA(cc);
if (!ids) {
const size_t nbytes_src1_q8_1 = ne13*ne12 * ne11*ne10_padded * sizeof(block_q8_1)/QK8_1 +
@@ -252,7 +252,7 @@ void ggml_cuda_op_mul_mat_q(
// Also its fixup needs to allocate a temporary buffer in the memory pool.
// There are multiple parallel CUDA streams for src1_ncols != ne11 which would introduce a race condition for this buffer.
const bool use_stream_k = ((GGML_CUDA_CC_IS_NVIDIA(cc) && ggml_cuda_highest_compiled_arch(cc) >= GGML_CUDA_CC_VOLTA)
|| (GGML_CUDA_CC_IS_AMD(cc) && GGML_CUDA_CC_IS_CDNA3(cc)))
|| GGML_CUDA_CC_IS_CDNA(cc))
&& src1_ncols == ne11;
const mmq_args args = {
src0_dd_i, src0->type, (const int *) src1_ddq_i, nullptr, nullptr, dst_dd_i,
@@ -306,7 +306,7 @@ bool ggml_cuda_should_use_mmq(enum ggml_type type, int cc, int64_t ne11) {
return false;
}
if (new_mma_available(cc) || amd_mfma_available(cc)) {
if (new_mma_available(cc)) {
return true;
}
@@ -322,5 +322,21 @@ bool ggml_cuda_should_use_mmq(enum ggml_type type, int cc, int64_t ne11) {
return !fp16_mma_hardware_available(cc) || ne11 < MMQ_DP4A_MAX_BATCH_SIZE;
}
if (amd_mfma_available(cc)) {
// As of ROCM 7.0 rocblas/tensile performs very poorly on CDNA3 and hipblaslt (via ROCBLAS_USE_HIPBLASLT)
// performs better but is currently suffering from a crash on this architecture.
// TODO: Revisit when hipblaslt is fixed on CDNA3
if (GGML_CUDA_CC_IS_CDNA3(cc)) {
return true;
}
if (ne11 <= 128 || type == GGML_TYPE_Q4_0 || type == GGML_TYPE_Q4_1 || type == GGML_TYPE_Q5_0 || type == GGML_TYPE_Q5_1) {
return true;
}
if (ne11 <= 256 && (type == GGML_TYPE_Q4_K || type == GGML_TYPE_Q5_K)) {
return true;
}
return false;
}
return (!GGML_CUDA_CC_IS_RDNA4(cc) && !GGML_CUDA_CC_IS_RDNA3(cc) && !GGML_CUDA_CC_IS_CDNA(cc)) || ne11 < MMQ_DP4A_MAX_BATCH_SIZE;
}

View File

@@ -3096,8 +3096,8 @@ static __global__ void mul_mat_q(
}
__syncthreads();
// On AMD or old CUDA the performance with stream-k was worse, use conventional tiling instead:
#if (defined(GGML_USE_HIP) && !defined(CDNA3)) || __CUDA_ARCH__ < GGML_CUDA_CC_VOLTA
// On non-CDNA AMD or old CUDA the performance with stream-k was worse, use conventional tiling instead:
#if (defined(GGML_USE_HIP) && !defined(CDNA)) || __CUDA_ARCH__ < GGML_CUDA_CC_VOLTA
{
const int wt = blockIdx.z / nchannels_y;
const int zt = blockIdx.z - wt*nchannels_y;

View File

@@ -1341,7 +1341,7 @@ static void ggml_vk_create_pipeline_func(vk_device& device, vk_pipeline& pipelin
vk::DebugUtilsObjectNameInfoEXT duoni;
duoni.objectType = vk::ObjectType::ePipeline;
duoni.pObjectName = pipeline->name.c_str();
duoni.objectHandle = reinterpret_cast<uint64_t>(static_cast<VkPipeline_T*>(pipeline->pipeline));
duoni.objectHandle = /*reinterpret_cast*/(uint64_t)(static_cast<VkPipeline>(pipeline->pipeline));
vk_instance.pfn_vkSetDebugUtilsObjectNameEXT(device->device, &static_cast<VkDebugUtilsObjectNameInfoEXT &>(duoni));
}

View File

@@ -1 +1 @@
b7bfde9c88aa4b063ce68dab6cc4f5c6caae37fd
daf7906728036a82f20c69fcbd74b6f536c74d3f

View File

@@ -59,7 +59,7 @@ bool llama_batch_allocr::init(
for (int32_t i = 0; i < batch.n_tokens; ++i) {
for (int32_t s = 0; s < batch.n_seq_id[i]; ++s) {
if (batch.seq_id && (batch.seq_id[i][s] < 0 || batch.seq_id[i][s] >= (llama_seq_id) n_seq_max)) {
LLAMA_LOG_ERROR("%s: invalid seq_id[%d][%d] = %d > %d\n", __func__, i, s, batch.seq_id[i][s], (llama_seq_id) n_seq_max);
LLAMA_LOG_ERROR("%s: invalid seq_id[%d][%d] = %d >= %d\n", __func__, i, s, batch.seq_id[i][s], (llama_seq_id) n_seq_max);
return false;
}
}

View File

@@ -144,7 +144,7 @@ public:
ggml_tensor * pos_bucket = nullptr; // I32 [n_batch, n_batch]
const llama_hparams & hparams;
const llama_hparams hparams;
};
class llm_graph_input_pos_bucket_kv : public llm_graph_input_i {
@@ -158,7 +158,7 @@ public:
ggml_tensor * pos_bucket = nullptr; // I32 [n_kv, n_batch]
const llama_hparams & hparams;
const llama_hparams hparams;
const llama_kv_cache_unified_context * mctx;
};
@@ -177,8 +177,8 @@ public:
ggml_tensor * out_ids; // I32 [n_outputs]
const llama_hparams & hparams;
const llama_cparams & cparams;
const llama_hparams hparams;
const llama_cparams cparams;
const uint32_t n_outputs;
};
@@ -192,7 +192,7 @@ public:
ggml_tensor * mean; // F32 [n_batch, n_batch]
const llama_cparams & cparams;
const llama_cparams cparams;
};
class llm_graph_input_cls : public llm_graph_input_i {
@@ -204,7 +204,7 @@ public:
ggml_tensor * cls; // I32 [n_batch]
const llama_cparams & cparams;
const llama_cparams cparams;
};
class llm_graph_input_rs : public llm_graph_input_i {
@@ -247,8 +247,8 @@ public:
ggml_tensor * kq_mask = nullptr; // F32 [n_tokens, n_batch, 1, 1]
ggml_tensor * kq_mask_cnv = nullptr; // [n_tokens, n_batch, 1, 1]
const llama_hparams & hparams;
const llama_cparams & cparams;
const llama_hparams hparams;
const llama_cparams cparams;
};
class llm_graph_input_attn_kv_unified : public llm_graph_input_i {
@@ -278,8 +278,11 @@ public:
ggml_tensor * self_kq_mask = nullptr; // F32 [n_kv, n_batch/n_stream, 1, n_stream]
ggml_tensor * self_kq_mask_cnv = nullptr; // [n_kv, n_batch/n_stream, 1, n_stream]
const llama_hparams & hparams;
const llama_cparams & cparams;
// note: these have to be copies because in order to be able to reuse a graph, its inputs
// need to carry these parameters with them. otherwise, they can point to freed
// llm_graph_params from a previous batch, causing stack-use-after-return
const llama_hparams hparams;
const llama_cparams cparams;
const llama_kv_cache_unified_context * mctx;
};
@@ -318,8 +321,8 @@ public:
ggml_tensor * self_kq_mask_swa = nullptr; // F32 [n_kv, n_batch/n_stream, 1, n_stream]
ggml_tensor * self_kq_mask_swa_cnv = nullptr; // [n_kv, n_batch/n_stream, 1, n_stream]
const llama_hparams & hparams;
const llama_cparams & cparams;
const llama_hparams hparams;
const llama_cparams cparams;
const llama_kv_cache_unified_iswa_context * mctx;
};

View File

@@ -185,7 +185,7 @@ llama_build_and_test(test-json-partial.cpp)
llama_build_and_test(test-log.cpp)
llama_build_and_test(test-regex-partial.cpp)
llama_build_and_test(test-thread-safety.cpp ARGS -hf ggml-org/models -hff tinyllamas/stories15M-q4_0.gguf -ngl 99 -p "The meaning of life is" -n 128 -c 256 -ub 32 -np 4)
llama_build_and_test(test-thread-safety.cpp ARGS -hf ggml-org/models -hff tinyllamas/stories15M-q4_0.gguf -ngl 99 -p "The meaning of life is" -n 128 -c 256 -ub 32 -np 4 -t 2)
# this fails on windows (github hosted runner) due to curl DLL not found (exit code 0xc0000135)
if (NOT WIN32)

View File

@@ -34,6 +34,9 @@ int main(int argc, char ** argv) {
auto cparams = common_context_params_to_llama(params);
// each context has a single sequence
cparams.n_seq_max = 1;
int dev_count = ggml_backend_dev_count();
int gpu_dev_count = 0;
for (int i = 0; i < dev_count; ++i) {

View File

@@ -311,7 +311,7 @@ static int load_imatrix(const std::string & imatrix_file, std::vector<std::strin
int64_t n_datasets = gguf_get_arr_n(ctx_gguf, dataset_idx);
imatrix_datasets.reserve(n_datasets);
for (int64_t i = 0; i < n_datasets; ++i) {
imatrix_datasets.push_back(gguf_get_val_str(ctx_gguf, dataset_idx));
imatrix_datasets.push_back(gguf_get_arr_str(ctx_gguf, dataset_idx, i));
}
printf("%s: imatrix datasets=['%s'", __func__, imatrix_datasets[0].c_str());
for (size_t i = 1; i < imatrix_datasets.size(); ++i) {

View File

@@ -644,6 +644,15 @@ The same as [the embedding example](../embedding) does.
`image_data`: An array of objects to hold base64-encoded image `data` and its `id`s to be referenced in `content`. You can determine the place of the image in the content as in the following: `Image: [img-21].\nCaption: This is a picture of a house`. In this case, `[img-21]` will be replaced by the embeddings of the image with id `21` in the following `image_data` array: `{..., "image_data": [{"data": "<BASE64_STRING>", "id": 21}]}`. Use `image_data` only with multimodal models, e.g., LLaVA.
`embd_normalize`: Normalization for pooled embeddings. Can be one of the following values:
```
-1: No normalization
0: Max absolute
1: Taxicab
2: Euclidean/L2
>2: P-Norm
```
### POST `/reranking`: Rerank documents according to a given query
Similar to https://jina.ai/reranker/ but might change in the future.

View File

@@ -138,6 +138,9 @@ struct slot_params {
std::string oaicompat_cmpl_id;
common_chat_syntax oaicompat_chat_syntax;
// Embeddings
int32_t embd_normalize = 2; // (-1=none, 0=max absolute int16, 1=taxicab, 2=Euclidean/L2, >2=p-norm)
json to_json() const {
std::vector<std::string> samplers;
samplers.reserve(sampling.samplers.size());
@@ -2601,7 +2604,7 @@ struct server_context {
// normalize only when there is pooling
if (llama_pooling_type(slot.ctx) != LLAMA_POOLING_TYPE_NONE) {
common_embd_normalize(embd, embd_res.data(), n_embd, 2);
common_embd_normalize(embd, embd_res.data(), n_embd, slot.params.embd_normalize);
res->embedding.push_back(embd_res);
break;
} else {
@@ -4614,6 +4617,14 @@ int main(int argc, char ** argv) {
}
}
int embd_normalize = 2; // default to Euclidean/L2 norm
if (body.count("embd_normalize") != 0) {
embd_normalize = body.at("embd_normalize");
if (llama_pooling_type(ctx_server.ctx) == LLAMA_POOLING_TYPE_NONE) {
SRV_DBG("embd_normalize is not supported by pooling type %d, ignoring it\n", llama_pooling_type(ctx_server.ctx));
}
}
// create and queue the task
json responses = json::array();
bool error = false;
@@ -4629,6 +4640,7 @@ int main(int argc, char ** argv) {
// OAI-compat
task.params.oaicompat = oaicompat;
task.params.embd_normalize = embd_normalize;
tasks.push_back(std::move(task));
}