Compare commits

...

6 Commits
b5979 ... b5985

Author SHA1 Message Date
R0CKSTAR
3f4fc97f1d musa: upgrade musa sdk to rc4.2.0 (#14498)
* musa: apply mublas API changes

Signed-off-by: Xiaodong Ye <xiaodong.ye@mthreads.com>

* musa: update musa version to 4.2.0

Signed-off-by: Xiaodong Ye <xiaodong.ye@mthreads.com>

* musa: restore MUSA graph settings in CMakeLists.txt

Signed-off-by: Xiaodong Ye <xiaodong.ye@mthreads.com>

* musa: disable mudnnMemcpyAsync by default

Signed-off-by: Xiaodong Ye <xiaodong.ye@mthreads.com>

* musa: switch back to non-mudnn images

Signed-off-by: Xiaodong Ye <xiaodong.ye@mthreads.com>

* minor changes

Signed-off-by: Xiaodong Ye <xiaodong.ye@mthreads.com>

* musa: restore rc in docker image tag

Signed-off-by: Xiaodong Ye <xiaodong.ye@mthreads.com>

---------

Signed-off-by: Xiaodong Ye <xiaodong.ye@mthreads.com>
2025-07-24 20:05:37 +01:00
Georgi Gerganov
2df255da3c sync : ggml
ggml-ci
2025-07-24 20:27:23 +03:00
Kai Pastor
60f816a79d cmake : fix usage issues (ggml/1257)
* CMake config: Create target only once

Fix error on repeated find_package(ggml).
For simplicity, check only for the top-level ggml::ggml.

* CMake config: Add CUDA link libs

* CMake config: Add OpenCL link libs

* CMake config: Use canonical find_dependency

Use set and append to control link lib variables.
Apply more $<LINK_ONLY...>.

* CMake config: Wire OpenMP dependency
2025-07-24 20:27:23 +03:00
Daniel Bevenius
5592f278b6 ggml-cpu : remove stdlib include from repack.cpp (ggml/1276)
This commit removes the inclusion of `<cstdlib>`.

The motivation for this change is that this source file does not seem to
use any functions from this header and the comment about `qsort` is a
little misleading/confusing.
2025-07-24 20:27:23 +03:00
Georgi Gerganov
e4868d16d2 context : perform output reorder lazily upon access after sync (#14853)
* context : perform output reorder after lazily upon access after sync

ggml-ci

* cont : add TODO
2025-07-24 16:31:48 +03:00
Xuan-Son Nguyen
820de57d4f chat : fix kimi-k2 chat template (#14852) 2025-07-24 13:59:56 +02:00
18 changed files with 194 additions and 106 deletions

View File

@@ -1,10 +1,10 @@
ARG UBUNTU_VERSION=22.04
# This needs to generally match the container host's environment.
ARG MUSA_VERSION=rc4.0.1
ARG MUSA_VERSION=rc4.2.0
# Target the MUSA build image
ARG BASE_MUSA_DEV_CONTAINER=mthreads/musa:${MUSA_VERSION}-mudnn-devel-ubuntu${UBUNTU_VERSION}
ARG BASE_MUSA_DEV_CONTAINER=mthreads/musa:${MUSA_VERSION}-devel-ubuntu${UBUNTU_VERSION}-amd64
ARG BASE_MUSA_RUN_CONTAINER=mthreads/musa:${MUSA_VERSION}-mudnn-runtime-ubuntu${UBUNTU_VERSION}
ARG BASE_MUSA_RUN_CONTAINER=mthreads/musa:${MUSA_VERSION}-runtime-ubuntu${UBUNTU_VERSION}-amd64
FROM ${BASE_MUSA_DEV_CONTAINER} AS build

View File

@@ -515,7 +515,7 @@ jobs:
ubuntu-22-cmake-musa:
runs-on: ubuntu-22.04
container: mthreads/musa:rc4.0.1-mudnn-devel-ubuntu22.04
container: mthreads/musa:rc4.2.0-devel-ubuntu22.04-amd64
steps:
- name: Clone

View File

@@ -54,7 +54,7 @@ docker run --privileged -it \
-v $HOME/llama.cpp/ci-cache:/ci-cache \
-v $HOME/llama.cpp/ci-results:/ci-results \
-v $PWD:/ws -w /ws \
mthreads/musa:rc4.0.1-mudnn-devel-ubuntu22.04
mthreads/musa:rc4.2.0-devel-ubuntu22.04-amd64
```
Inside the container, execute the following commands:

View File

@@ -110,7 +110,7 @@ You may want to pass in some different `ARGS`, depending on the MUSA environment
The defaults are:
- `MUSA_VERSION` set to `rc4.0.1`
- `MUSA_VERSION` set to `rc4.2.0`
The resulting images, are essentially the same as the non-MUSA images:

View File

@@ -174,6 +174,8 @@ option(GGML_HIP_GRAPHS "ggml: use HIP graph, experimental,
option(GGML_HIP_NO_VMM "ggml: do not try to use HIP VMM" ON)
option(GGML_HIP_ROCWMMA_FATTN "ggml: enable rocWMMA for FlashAttention" OFF)
option(GGML_HIP_FORCE_ROCWMMA_FATTN_GFX12 "ggml: enable rocWMMA FlashAttention on GFX12" OFF)
option(GGML_MUSA_GRAPHS "ggml: use MUSA graph, experimental, unstable" OFF)
option(GGML_MUSA_MUDNN_COPY "ggml: enable muDNN for accelerated copy" OFF)
option(GGML_VULKAN "ggml: use Vulkan" OFF)
option(GGML_VULKAN_CHECK_RESULTS "ggml: run Vulkan op checks" OFF)
option(GGML_VULKAN_DEBUG "ggml: enable Vulkan debug output" OFF)

View File

@@ -1,12 +1,108 @@
@PACKAGE_INIT@
@GGML_VARIABLES_EXPANDED@
@PACKAGE_INIT@
# Find all dependencies before creating any target.
include(CMakeFindDependencyMacro)
find_dependency(Threads)
if (NOT GGML_SHARED_LIB)
set(GGML_CPU_INTERFACE_LINK_LIBRARIES "")
set(GGML_CPU_INTERFACE_LINK_OPTIONS "")
if (APPLE AND GGML_ACCELERATE)
find_library(ACCELERATE_FRAMEWORK Accelerate)
if(NOT ACCELERATE_FRAMEWORK)
set(${CMAKE_FIND_PACKAGE_NAME}_FOUND 0)
return()
endif()
list(APPEND GGML_CPU_INTERFACE_LINK_LIBRARIES ${ACCELERATE_FRAMEWORK})
endif()
if (GGML_OPENMP_ENABLED)
find_dependency(OpenMP)
list(APPEND GGML_CPU_INTERFACE_LINK_LIBRARIES OpenMP::OpenMP_C OpenMP::OpenMP_CXX)
endif()
if (GGML_CPU_HBM)
find_library(memkind memkind)
if(NOT memkind)
set(${CMAKE_FIND_PACKAGE_NAME}_FOUND 0)
return()
endif()
list(APPEND GGML_CPU_INTERFACE_LINK_LIBRARIES memkind)
endif()
if (GGML_BLAS)
find_dependency(BLAS)
list(APPEND GGML_CPU_INTERFACE_LINK_LIBRARIES ${BLAS_LIBRARIES})
list(APPEND GGML_CPU_INTERFACE_LINK_OPTIONS ${BLAS_LINKER_FLAGS})
endif()
if (GGML_CUDA)
set(GGML_CUDA_INTERFACE_LINK_LIBRARIES "")
find_dependency(CUDAToolkit)
if (GGML_STATIC)
list(APPEND GGML_CUDA_INTERFACE_LINK_LIBRARIES $<LINK_ONLY:CUDA::cudart_static>)
if (WIN32)
list(APPEND GGML_CUDA_INTERFACE_LINK_LIBRARIES $<LINK_ONLY:CUDA::cublas> $<LINK_ONLY:CUDA::cublasLt>)
else()
list(APPEND GGML_CUDA_INTERFACE_LINK_LIBRARIES $<LINK_ONLY:CUDA::cublas_static> $<LINK_ONLY:CUDA::cublasLt_static>)
endif()
endif()
if (NOT GGML_CUDA_NO_VMM)
list(APPEND GGML_CUDA_INTERFACE_LINK_LIBRARIES $<LINK_ONLY:CUDA::cuda_driver>)
endif()
endif()
if (GGML_METAL)
find_library(FOUNDATION_LIBRARY Foundation)
find_library(METAL_FRAMEWORK Metal)
find_library(METALKIT_FRAMEWORK MetalKit)
if(NOT FOUNDATION_LIBRARY OR NOT METAL_FRAMEWORK OR NOT METALKIT_FRAMEWORK)
set(${CMAKE_FIND_PACKAGE_NAME}_FOUND 0)
return()
endif()
set(GGML_METAL_INTERFACE_LINK_LIBRARIES
${FOUNDATION_LIBRARY} ${METAL_FRAMEWORK} ${METALKIT_FRAMEWORK})
endif()
if (GGML_OPENCL)
find_dependency(OpenCL)
set(GGML_OPENCL_INTERFACE_LINK_LIBRARIES $<LINK_ONLY:OpenCL::OpenCL>)
endif()
if (GGML_VULKAN)
find_dependency(Vulkan)
set(GGML_VULKAN_INTERFACE_LINK_LIBRARIES $<LINK_ONLY:Vulkan::Vulkan>)
endif()
if (GGML_HIP)
find_dependency(hip)
find_dependency(hipblas)
find_dependency(rocblas)
set(GGML_HIP_INTERFACE_LINK_LIBRARIES hip::host roc::rocblas roc::hipblas)
endif()
if (GGML_SYCL)
set(GGML_SYCL_INTERFACE_LINK_LIBRARIES "")
find_package(DNNL)
if (${DNNL_FOUND} AND GGML_SYCL_TARGET STREQUAL "INTEL")
list(APPEND GGML_SYCL_INTERFACE_LINK_LIBRARIES DNNL::dnnl)
endif()
if (WIN32)
find_dependency(IntelSYCL)
find_dependency(MKL)
list(APPEND GGML_SYCL_INTERFACE_LINK_LIBRARIES IntelSYCL::SYCL_CXX MKL::MKL MKL::MKL_SYCL)
endif()
endif()
endif()
set_and_check(GGML_INCLUDE_DIR "@PACKAGE_GGML_INCLUDE_INSTALL_DIR@")
set_and_check(GGML_LIB_DIR "@PACKAGE_GGML_LIB_INSTALL_DIR@")
#set_and_check(GGML_BIN_DIR "@PACKAGE_GGML_BIN_INSTALL_DIR@")
if(NOT TARGET ggml::ggml)
find_package(Threads REQUIRED)
find_library(GGML_LIBRARY ggml
@@ -29,66 +125,6 @@ set_target_properties(ggml::ggml-base
PROPERTIES
IMPORTED_LOCATION "${GGML_BASE_LIBRARY}")
if (NOT GGML_SHARED_LIB)
if (APPLE AND GGML_ACCELERATE)
find_library(ACCELERATE_FRAMEWORK Accelerate REQUIRED)
list(APPEND GGML_CPU_INTERFACE_LINK_LIBRARIES ${ACCELERATE_FRAMEWORK})
endif()
if (GGML_OPENMP)
find_package(OpenMP REQUIRED)
list(APPEND GGML_CPU_INTERFACE_LINK_LIBRARIES OpenMP::OpenMP_C OpenMP::OpenMP_CXX)
endif()
if (GGML_CPU_HBM)
find_library(memkind memkind REQUIRED)
list(APPEND GGML_CPU_INTERFACE_LINK_LIBRARIES memkind)
endif()
if (GGML_BLAS)
find_package(BLAS REQUIRED)
list(APPEND GGML_CPU_INTERFACE_LINK_LIBRARIES ${BLAS_LIBRARIES})
list(APPEND GGML_CPU_INTERFACE_LINK_OPTIONS ${BLAS_LINKER_FLAGS})
endif()
if (GGML_CUDA)
find_package(CUDAToolkit REQUIRED)
endif()
if (GGML_METAL)
find_library(FOUNDATION_LIBRARY Foundation REQUIRED)
find_library(METAL_FRAMEWORK Metal REQUIRED)
find_library(METALKIT_FRAMEWORK MetalKit REQUIRED)
list(APPEND GGML_METAL_INTERFACE_LINK_LIBRARIES
${FOUNDATION_LIBRARY} ${METAL_FRAMEWORK} ${METALKIT_FRAMEWORK})
endif()
if (GGML_VULKAN)
find_package(Vulkan REQUIRED)
list(APPEND GGML_VULKAN_INTERFACE_LINK_LIBRARIES Vulkan::Vulkan)
endif()
if (GGML_HIP)
find_package(hip REQUIRED)
find_package(hipblas REQUIRED)
find_package(rocblas REQUIRED)
list(APPEND GGML_HIP_INTERFACE_LINK_LIBRARIES hip::host roc::rocblas roc::hipblas)
endif()
if (GGML_SYCL)
find_package(DNNL)
if (${DNNL_FOUND} AND GGML_SYCL_TARGET STREQUAL "INTEL")
list(APPEND GGML_SYCL_INTERFACE_LINK_LIBRARIES DNNL::dnnl)
endif()
if (WIN32)
find_package(IntelSYCL REQUIRED)
find_package(MKL REQUIRED)
list(APPEND GGML_SYCL_INTERFACE_LINK_LIBRARIES IntelSYCL::SYCL_CXX MKL::MKL MKL::MKL_SYCL)
endif()
endif()
endif()
set(_ggml_all_targets "")
foreach(_ggml_backend ${GGML_AVAILABLE_BACKENDS})
string(REPLACE "-" "_" _ggml_backend_pfx "${_ggml_backend}")
@@ -149,4 +185,6 @@ set_target_properties(ggml::all
PROPERTIES
INTERFACE_LINK_LIBRARIES "${_ggml_all_targets}")
endif() # TARGET ggml::ggml
check_required_components(ggml)

View File

@@ -70,10 +70,12 @@ function(ggml_add_cpu_backend_variant_impl tag_name)
if (GGML_OPENMP)
find_package(OpenMP)
if (OpenMP_FOUND)
set(GGML_OPENMP_ENABLED "ON" CACHE INTERNAL "")
target_compile_definitions(${GGML_CPU_NAME} PRIVATE GGML_USE_OPENMP)
target_link_libraries(${GGML_CPU_NAME} PRIVATE OpenMP::OpenMP_C OpenMP::OpenMP_CXX)
else()
set(GGML_OPENMP_ENABLED "OFF" CACHE INTERNAL "")
message(WARNING "OpenMP not found")
endif()
endif()

View File

@@ -14,7 +14,6 @@
#include <cmath>
#include <cstring>
#include <cassert>
#include <cstdlib> // for qsort
#include <cstdio> // for GGML_ASSERT
#include "repack.h"

View File

@@ -765,7 +765,7 @@ struct ggml_tensor_extra_gpu {
};
#if (defined(GGML_CUDA_USE_GRAPHS) || defined(GGML_HIP_GRAPHS))
#if (defined(GGML_CUDA_USE_GRAPHS) || defined(GGML_HIP_GRAPHS)) || defined(GGML_MUSA_GRAPHS)
#define USE_CUDA_GRAPH
#endif

View File

@@ -1,9 +1,9 @@
#include "cpy.cuh"
#include "dequantize.cuh"
#include "cpy-utils.cuh"
#ifdef GGML_USE_MUSA
#if defined(GGML_USE_MUSA) && defined(GGML_MUSA_MUDNN_COPY)
#include "ggml-musa/mudnn.cuh"
#endif // GGML_USE_MUSA
#endif // GGML_USE_MUSA && GGML_MUSA_MUDNN_COPY
typedef void (*cpy_kernel_t)(const char * cx, char * cdst);
@@ -121,7 +121,7 @@ static __global__ void cpy_q_f32(const char * cx, char * cdst_direct, const int
// Copy destination pointers to GPU to be available when pointer indirection is in use
void ggml_cuda_cpy_dest_ptrs_copy(ggml_cuda_graph * cuda_graph, char ** host_dest_ptrs, const int host_dest_ptrs_size, cudaStream_t stream) {
#if defined(GGML_CUDA_USE_GRAPHS) || defined(GGML_HIP_GRAPHS)
#if defined(GGML_CUDA_USE_GRAPHS) || defined(GGML_HIP_GRAPHS) || defined(GGML_MUSA_GRAPHS)
if (cuda_graph->dest_ptrs_size < host_dest_ptrs_size) { // (re-)allocate GPU memory for destination pointers
CUDA_CHECK(cudaStreamSynchronize(stream));
if (cuda_graph->dest_ptrs_d != nullptr) {
@@ -314,7 +314,7 @@ void ggml_cuda_cpy(ggml_backend_cuda_context & ctx, const ggml_tensor * src0, gg
char ** dest_ptrs_d = nullptr;
int graph_cpynode_index = -1;
#if defined(GGML_CUDA_USE_GRAPHS) || defined(GGML_HIP_GRAPHS)
#if defined(GGML_CUDA_USE_GRAPHS) || defined(GGML_HIP_GRAPHS) || defined(GGML_MUSA_GRAPHS)
if(ctx.cuda_graph->use_cpy_indirection && !disable_indirection_for_this_node) {
dest_ptrs_d = ctx.cuda_graph->dest_ptrs_d;
graph_cpynode_index = ctx.cuda_graph->graph_cpynode_index;
@@ -324,11 +324,11 @@ void ggml_cuda_cpy(ggml_backend_cuda_context & ctx, const ggml_tensor * src0, gg
#endif
if (src0->type == src1->type && ggml_is_contiguous(src0) && ggml_is_contiguous(src1)) {
GGML_ASSERT(ggml_nbytes(src0) == ggml_nbytes(src1));
#ifdef GGML_USE_MUSA
#if defined(GGML_USE_MUSA) && defined(GGML_MUSA_MUDNN_COPY)
if (src0->type == GGML_TYPE_F32 || src0->type == GGML_TYPE_F16) {
CUDA_CHECK(mudnnMemcpyAsync(ctx, src1, src0));
} else
#endif // GGML_USE_MUSA
#endif // GGML_USE_MUSA && GGML_MUSA_MUDNN_COPY
{
CUDA_CHECK(cudaMemcpyAsync(src1_ddc, src0_ddc, ggml_nbytes(src0), cudaMemcpyDeviceToDevice, main_stream));
}
@@ -379,7 +379,7 @@ void ggml_cuda_cpy(ggml_backend_cuda_context & ctx, const ggml_tensor * src0, gg
GGML_ABORT("%s: unsupported type combination (%s to %s)\n", __func__,
ggml_type_name(src0->type), ggml_type_name(src1->type));
}
#if defined(GGML_CUDA_USE_GRAPHS) || defined(GGML_HIP_GRAPHS)
#if defined(GGML_CUDA_USE_GRAPHS) || defined(GGML_HIP_GRAPHS) || defined(GGML_MUSA_GRAPHS)
if(ctx.cuda_graph->use_cpy_indirection && !disable_indirection_for_this_node) {
ctx.cuda_graph->graph_cpynode_index = graph_cpynode_index;
}

View File

@@ -13,7 +13,7 @@
#define CUBLAS_OP_N MUBLAS_OP_N
#define CUBLAS_OP_T MUBLAS_OP_T
#define CUBLAS_STATUS_SUCCESS MUBLAS_STATUS_SUCCESS
#define CUBLAS_TF32_TENSOR_OP_MATH MUBLAS_MATH_MODE_DEFAULT
#define CUBLAS_TF32_TENSOR_OP_MATH MUBLAS_TENSOR_OP_MATH
#define CUDA_R_16F MUSA_R_16F
#define CUDA_R_16BF MUSA_R_16BF
#define CUDA_R_32F MUSA_R_32F
@@ -29,7 +29,7 @@
#define cublasSgemm mublasSgemm
#define cublasStatus_t mublasStatus_t
#define cublasOperation_t mublasOperation_t
#define cublasGetStatusString mublasStatus_to_string
#define cublasGetStatusString mublasGetStatusString
#define cudaDataType_t musaDataType_t
#define cudaDeviceCanAccessPeer musaDeviceCanAccessPeer
#define cudaDeviceDisablePeerAccess musaDeviceDisablePeerAccess

View File

@@ -34,8 +34,12 @@ if (MUSAToolkit_FOUND)
list(APPEND GGML_SOURCES_MUSA ${SRCS})
file(GLOB SRCS "../ggml-cuda/template-instances/mmq*.cu")
list(APPEND GGML_SOURCES_MUSA ${SRCS})
file(GLOB SRCS "../ggml-musa/*.cu")
list(APPEND GGML_SOURCES_MUSA ${SRCS})
if (GGML_MUSA_MUDNN_COPY)
file(GLOB SRCS "../ggml-musa/*.cu")
list(APPEND GGML_SOURCES_MUSA ${SRCS})
add_compile_definitions(GGML_MUSA_MUDNN_COPY)
endif()
if (GGML_CUDA_FA_ALL_QUANTS)
file(GLOB SRCS "../ggml-cuda/template-instances/fattn-vec*.cu")
@@ -72,6 +76,10 @@ if (MUSAToolkit_FOUND)
add_compile_definitions(GGML_USE_MUSA)
add_compile_definitions(GGML_CUDA_PEER_MAX_BATCH_SIZE=${GGML_CUDA_PEER_MAX_BATCH_SIZE})
if (GGML_MUSA_GRAPHS)
add_compile_definitions(GGML_MUSA_GRAPHS)
endif()
if (GGML_CUDA_FORCE_MMQ)
add_compile_definitions(GGML_CUDA_FORCE_MMQ)
endif()
@@ -97,10 +105,16 @@ if (MUSAToolkit_FOUND)
endif()
if (GGML_STATIC)
# TODO: mudnn has not provided static libraries yet
target_link_libraries(ggml-musa PRIVATE MUSA::musart_static MUSA::mublas_static)
# TODO: mudnn has not provided static libraries yet
# if (GGML_MUSA_MUDNN_COPY)
# target_link_libraries(ggml-musa PRIVATE mudnn_static)
# endif()
else()
target_link_libraries(ggml-musa PRIVATE MUSA::musart MUSA::mublas mudnn)
target_link_libraries(ggml-musa PRIVATE MUSA::musart MUSA::mublas)
if (GGML_MUSA_MUDNN_COPY)
target_link_libraries(ggml-musa PRIVATE mudnn)
endif()
endif()
if (GGML_CUDA_NO_VMM)

View File

@@ -956,6 +956,7 @@ extern "C" {
// in the order they have appeared in the batch.
// Rows: number of tokens for which llama_batch.logits[i] != 0
// Cols: n_vocab
// TODO: deprecate in favor of llama_get_logits_ith() (ref: https://github.com/ggml-org/llama.cpp/pull/14853#issuecomment-3113143522)
LLAMA_API float * llama_get_logits(struct llama_context * ctx);
// Logits for the ith token. For positive indices, Equivalent to:
@@ -970,6 +971,7 @@ extern "C" {
// in the order they have appeared in the batch.
// shape: [n_outputs*n_embd]
// Otherwise, returns NULL.
// TODO: deprecate in favor of llama_get_embeddings_ith() (ref: https://github.com/ggml-org/llama.cpp/pull/14853#issuecomment-3113143522)
LLAMA_API float * llama_get_embeddings(struct llama_context * ctx);
// Get the embeddings for the ith token. For positive indices, Equivalent to:

View File

@@ -1 +1 @@
3323219cd3cc050e5c7133cd4fc1e50d1f590faf
56938c4a3b2d923f42040f9ad32d229c76c466cd

View File

@@ -1933,12 +1933,6 @@ static const std::map<llm_arch, std::map<llm_tensor, const char *>> LLM_TENSOR_N
{ LLM_TENSOR_TOKEN_EMBD_NORM, "token_embd_norm" },
}
},
{
LLM_ARCH_UNKNOWN,
{
{ LLM_TENSOR_TOKEN_EMBD, "token_embd" },
},
},
{
LLM_ARCH_DREAM,
{
@@ -1956,6 +1950,12 @@ static const std::map<llm_arch, std::map<llm_tensor, const char *>> LLM_TENSOR_N
{ LLM_TENSOR_FFN_UP, "blk.%d.ffn_up" },
},
},
{
LLM_ARCH_UNKNOWN,
{
{ LLM_TENSOR_TOKEN_EMBD, "token_embd" },
},
},
};
static const std::map<llm_tensor, llm_tensor_info> LLM_TENSOR_INFOS = {

View File

@@ -718,10 +718,9 @@ int32_t llm_chat_apply_template(
}
ss << message->content << "<|im_end|>";
if (add_ass) {
ss << "<|im_assistant|>assistant<|im_middle|>";
}
}
if (add_ass) {
ss << "<|im_assistant|>assistant<|im_middle|>";
}
} else {
// template not supported

View File

@@ -508,12 +508,16 @@ enum llama_pooling_type llama_context::pooling_type() const {
}
float * llama_context::get_logits() {
output_reorder();
return logits;
}
float * llama_context::get_logits_ith(int32_t i) {
int64_t j = -1;
output_reorder();
try {
if (logits == nullptr) {
throw std::runtime_error("no logits");
@@ -550,12 +554,16 @@ float * llama_context::get_logits_ith(int32_t i) {
}
float * llama_context::get_embeddings() {
output_reorder();
return embd;
}
float * llama_context::get_embeddings_ith(int32_t i) {
int64_t j = -1;
output_reorder();
try {
if (embd == nullptr) {
throw std::runtime_error("no embeddings");
@@ -970,6 +978,7 @@ int llama_context::decode(const llama_batch & batch_inp) {
// TODO: this clear of the buffer can easily be forgotten - need something better
embd_seq.clear();
output_swaps.clear();
bool did_optimize = false;
@@ -1189,9 +1198,6 @@ int llama_context::decode(const llama_batch & batch_inp) {
// make the outputs have the same order they had in the user-provided batch
// note: this is mostly relevant for recurrent models atm
if (!sorted_output) {
const uint32_t n_vocab = model.vocab.n_tokens();
const uint64_t n_embd = model.hparams.n_embd;
GGML_ASSERT((size_t) n_outputs == out_ids.size());
// TODO: is there something more efficient which also minimizes swaps?
@@ -1207,16 +1213,9 @@ int llama_context::decode(const llama_batch & batch_inp) {
continue;
}
std::swap(out_ids[i], out_ids[j_min]);
if (logits_size > 0) {
for (uint32_t k = 0; k < n_vocab; k++) {
std::swap(logits[i*n_vocab + k], logits[j_min*n_vocab + k]);
}
}
if (embd_size > 0) {
for (uint32_t k = 0; k < n_embd; k++) {
std::swap(embd[i*n_embd + k], embd[j_min*n_embd + k]);
}
}
// remember the swaps and apply them lazily upon logits/embeddings access
output_swaps.push_back({ i, j_min });
}
std::fill(output_ids.begin(), output_ids.end(), -1);
@@ -1307,6 +1306,30 @@ uint32_t llama_context::output_reserve(int32_t n_outputs) {
return n_outputs_max;
}
void llama_context::output_reorder() {
const uint32_t n_vocab = model.vocab.n_tokens();
const uint64_t n_embd = model.hparams.n_embd;
for (uint32_t s = 0; s < output_swaps.size(); ++s) {
const uint32_t i0 = output_swaps[s].i0;
const uint32_t i1 = output_swaps[s].i1;
if (logits_size > 0) {
for (uint32_t k = 0; k < n_vocab; k++) {
std::swap(logits[i0*n_vocab + k], logits[i1*n_vocab + k]);
}
}
if (embd_size > 0) {
for (uint32_t k = 0; k < n_embd; k++) {
std::swap(embd[i0*n_embd + k], embd[i1*n_embd + k]);
}
}
}
output_swaps.clear();
}
//
// graph
//

View File

@@ -181,6 +181,8 @@ private:
// Returns max number of outputs for which space was reserved.
uint32_t output_reserve(int32_t n_outputs);
void output_reorder();
//
// graph
//
@@ -250,6 +252,13 @@ private:
std::vector<int32_t> output_ids; // map batch token positions to ids of the logits and embd buffers
struct swap_info {
uint32_t i0;
uint32_t i1;
};
std::vector<swap_info> output_swaps;
ggml_backend_sched_ptr sched;
ggml_backend_t backend_cpu = nullptr;