Compare commits

...

2 Commits
b6035 ... b6037

Author SHA1 Message Date
Daniel Bevenius
41e78c567e server : add support for embd_normalize parameter (#14964)
This commit adds support for the `embd_normalize` parameter in the
server code.

The motivation for this is that currently if the server is started with
a pooling type that is not `none`, then Euclidean/L2 normalization will
be the normalization method used for embeddings. However, this is not
always the desired behavior, and users may want to use other
normalization (or none) and this commit allows that.

Example usage:
```console
curl --request POST \
    --url http://localhost:8080/embedding \
    --header "Content-Type: application/json" \
    --data '{"input": "Hello world today", "embd_normalize": -1}
```
2025-07-30 18:07:11 +02:00
uvos
ad4a700117 HIP: enable mfma mmq on gfx908 and gfx90a for select datatypes and shapes (#14949) 2025-07-30 17:38:06 +02:00
5 changed files with 47 additions and 11 deletions

View File

@@ -227,9 +227,9 @@ typedef float2 dfloat2;
#define FP16_MMA_AVAILABLE
#endif // defined(GGML_HIP_ROCWMMA_FATTN) && (defined(CDNA) || defined(RDNA3) || (defined(GGML_HIP_ROCWMMA_FATTN_GFX12) && defined(RDNA4)))
#if defined(GGML_USE_HIP) && defined(CDNA3) && !defined(GGML_HIP_NO_MMQ_MFMA)
#if defined(GGML_USE_HIP) && defined(CDNA) && !defined(GGML_HIP_NO_MMQ_MFMA)
#define AMD_MFMA_AVAILABLE
#endif // defined(GGML_USE_HIP) && defined(CDNA3) && !defined(GGML_HIP_NO_MMQ_MFMA)
#endif // defined(GGML_USE_HIP) && defined(CDNA) && !defined(GGML_HIP_NO_MMQ_MFMA)
#if !defined(GGML_USE_HIP) && __CUDA_ARCH__ >= GGML_CUDA_CC_TURING
#define NEW_MMA_AVAILABLE
@@ -293,10 +293,9 @@ static bool fp32_mma_hardware_available(const int cc) {
return GGML_CUDA_CC_IS_CDNA(cc);
}
// AMD CDNA3 matrix cores.. Will add support for other CDNA generations later.
static bool amd_mfma_available(const int cc) {
#if !defined(GGML_HIP_NO_MMQ_MFMA)
return GGML_CUDA_CC_IS_CDNA3(cc);
return GGML_CUDA_CC_IS_CDNA(cc);
#else
return false;
#endif //!defined(GGML_HIP_NO_MMQ_MFMA)

View File

@@ -109,8 +109,8 @@ void ggml_cuda_mul_mat_q(
const int64_t s03 = src0->nb[3] / ts_src0;
const int64_t s3 = dst->nb[3] / ts_dst;
const bool use_stream_k = ((GGML_CUDA_CC_IS_NVIDIA(cc) && ggml_cuda_highest_compiled_arch(cc) >= GGML_CUDA_CC_VOLTA)
|| (GGML_CUDA_CC_IS_AMD(cc) && GGML_CUDA_CC_IS_CDNA3(cc)));
const bool use_stream_k = (GGML_CUDA_CC_IS_NVIDIA(cc) && ggml_cuda_highest_compiled_arch(cc) >= GGML_CUDA_CC_VOLTA)
|| GGML_CUDA_CC_IS_CDNA(cc);
if (!ids) {
const size_t nbytes_src1_q8_1 = ne13*ne12 * ne11*ne10_padded * sizeof(block_q8_1)/QK8_1 +
@@ -252,7 +252,7 @@ void ggml_cuda_op_mul_mat_q(
// Also its fixup needs to allocate a temporary buffer in the memory pool.
// There are multiple parallel CUDA streams for src1_ncols != ne11 which would introduce a race condition for this buffer.
const bool use_stream_k = ((GGML_CUDA_CC_IS_NVIDIA(cc) && ggml_cuda_highest_compiled_arch(cc) >= GGML_CUDA_CC_VOLTA)
|| (GGML_CUDA_CC_IS_AMD(cc) && GGML_CUDA_CC_IS_CDNA3(cc)))
|| GGML_CUDA_CC_IS_CDNA(cc))
&& src1_ncols == ne11;
const mmq_args args = {
src0_dd_i, src0->type, (const int *) src1_ddq_i, nullptr, nullptr, dst_dd_i,
@@ -306,7 +306,7 @@ bool ggml_cuda_should_use_mmq(enum ggml_type type, int cc, int64_t ne11) {
return false;
}
if (new_mma_available(cc) || amd_mfma_available(cc)) {
if (new_mma_available(cc)) {
return true;
}
@@ -322,5 +322,21 @@ bool ggml_cuda_should_use_mmq(enum ggml_type type, int cc, int64_t ne11) {
return !fp16_mma_hardware_available(cc) || ne11 < MMQ_DP4A_MAX_BATCH_SIZE;
}
if (amd_mfma_available(cc)) {
// As of ROCM 7.0 rocblas/tensile performs very poorly on CDNA3 and hipblaslt (via ROCBLAS_USE_HIPBLASLT)
// performs better but is currently suffering from a crash on this architecture.
// TODO: Revisit when hipblaslt is fixed on CDNA3
if (GGML_CUDA_CC_IS_CDNA3(cc)) {
return true;
}
if (ne11 <= 128 || type == GGML_TYPE_Q4_0 || type == GGML_TYPE_Q4_1 || type == GGML_TYPE_Q5_0 || type == GGML_TYPE_Q5_1) {
return true;
}
if (ne11 <= 256 && (type == GGML_TYPE_Q4_K || type == GGML_TYPE_Q5_K)) {
return true;
}
return false;
}
return (!GGML_CUDA_CC_IS_RDNA4(cc) && !GGML_CUDA_CC_IS_RDNA3(cc) && !GGML_CUDA_CC_IS_CDNA(cc)) || ne11 < MMQ_DP4A_MAX_BATCH_SIZE;
}

View File

@@ -3096,8 +3096,8 @@ static __global__ void mul_mat_q(
}
__syncthreads();
// On AMD or old CUDA the performance with stream-k was worse, use conventional tiling instead:
#if (defined(GGML_USE_HIP) && !defined(CDNA3)) || __CUDA_ARCH__ < GGML_CUDA_CC_VOLTA
// On non-CDNA AMD or old CUDA the performance with stream-k was worse, use conventional tiling instead:
#if (defined(GGML_USE_HIP) && !defined(CDNA)) || __CUDA_ARCH__ < GGML_CUDA_CC_VOLTA
{
const int wt = blockIdx.z / nchannels_y;
const int zt = blockIdx.z - wt*nchannels_y;

View File

@@ -644,6 +644,15 @@ The same as [the embedding example](../embedding) does.
`image_data`: An array of objects to hold base64-encoded image `data` and its `id`s to be reference in `content`. You can determine the place of the image in the content as in the following: `Image: [img-21].\nCaption: This is a picture of a house`. In this case, `[img-21]` will be replaced by the embeddings of the image with id `21` in the following `image_data` array: `{..., "image_data": [{"data": "<BASE64_STRING>", "id": 21}]}`. Use `image_data` only with multimodal models, e.g., LLaVA.
`embd_normalize`: Normalization for pooled embeddings. Can be one of the following values:
```
-1: No normalization
0: Max absolute
1: Taxicab
2: Euclidean/L2
>2: P-Norm
```
### POST `/reranking`: Rerank documents according to a given query
Similar to https://jina.ai/reranker/ but might change in the future.

View File

@@ -138,6 +138,9 @@ struct slot_params {
std::string oaicompat_cmpl_id;
common_chat_syntax oaicompat_chat_syntax;
// Embeddings
int32_t embd_normalize = 2; // (-1=none, 0=max absolute int16, 1=taxicab, 2=Euclidean/L2, >2=p-norm)
json to_json() const {
std::vector<std::string> samplers;
samplers.reserve(sampling.samplers.size());
@@ -2601,7 +2604,7 @@ struct server_context {
// normalize only when there is pooling
if (llama_pooling_type(slot.ctx) != LLAMA_POOLING_TYPE_NONE) {
common_embd_normalize(embd, embd_res.data(), n_embd, 2);
common_embd_normalize(embd, embd_res.data(), n_embd, slot.params.embd_normalize);
res->embedding.push_back(embd_res);
break;
} else {
@@ -4614,6 +4617,14 @@ int main(int argc, char ** argv) {
}
}
int embd_normalize = 2; // default to Euclidean/L2 norm
if (body.count("embd_normalize") != 0) {
embd_normalize = body.at("embd_normalize");
if (llama_pooling_type(ctx_server.ctx) == LLAMA_POOLING_TYPE_NONE) {
SRV_DBG("embd_normalize is not supported by pooling type %d, ignoring it\n", llama_pooling_type(ctx_server.ctx));
}
}
// create and queue the task
json responses = json::array();
bool error = false;
@@ -4629,6 +4640,7 @@ int main(int argc, char ** argv) {
// OAI-compat
task.params.oaicompat = oaicompat;
task.params.embd_normalize = embd_normalize;
tasks.push_back(std::move(task));
}