Compare commits

...

4 Commits

Author SHA1 Message Date
Xuan Son Nguyen 095058ca19 add arg --threads-sampling 2026-06-22 20:03:49 +02:00
Xuan Son Nguyen c62fdd5fd0 working 2026-06-22 19:38:25 +02:00
Xuan Son Nguyen 41ed530be2 wip 2026-06-22 19:30:11 +02:00
Xuan Son Nguyen fe03cce8db server: run sampling in a threadpool 2026-06-22 19:05:39 +02:00
5 changed files with 183 additions and 6 deletions
+10
View File
@@ -1168,6 +1168,16 @@ common_params_context common_params_parser_init(common_params & params, llama_ex
}
}
));
// Server-only option: number of threads used for sampling.
add_opt(common_arg(
    {"--threads-sampling"}, "N",
    "number of threads to use during sampling (default: same as --threads)",
    [](common_params & params, int value) {
        // a non-positive value selects the hardware thread count instead
        const bool use_hw_count = value <= 0;
        params.sampling_n_threads = use_hw_count
            ? static_cast<int>(std::thread::hardware_concurrency())
            : value;
    }
).set_examples({LLAMA_EXAMPLE_SERVER}));
add_opt(common_arg(
{"-C", "--cpu-mask"}, "M",
"CPU affinity mask: arbitrarily long hex. Complements cpu-range (default: \"\")",
+2
View File
@@ -471,6 +471,8 @@ struct common_params {
common_cpu_params cpuparams;
common_cpu_params cpuparams_batch;
int sampling_n_threads = -1; // number of threads for sampling, used by server
ggml_backend_sched_eval_callback cb_eval = nullptr;
void * cb_eval_user_data = nullptr;
+79
View File
@@ -1583,3 +1583,82 @@ server_tokens format_prompt_rerank(
return result;
}
//
// threadpool
//
// Shuts the pool down: raises the stop flag under the mutex, wakes every
// worker waiting on `cv`, then joins all spawned threads.
server_threadpool::~server_threadpool() {
    {
        std::unique_lock<std::mutex> guard(mtx);
        stop = true;
    }
    // wake all workers so each one can observe the stop flag
    cv.notify_all();
    for (std::thread & worker : threads) {
        worker.join();
    }
}
// Spawns the background workers of the pool.
//
// n is the total number of threads that should take part in run_all(),
// including the caller: the calling thread drains tasks itself in wait_all(),
// so only max(1, n) - 1 background threads are needed.
//
// The call is idempotent with respect to the worker count: it only tops the
// pool up to the requested number of workers instead of unconditionally
// spawning n-1 more, so a repeated init() does not over-subscribe the CPU.
// NOTE(review): presumably init() can run more than once when the server
// context is re-initialized (e.g. after a sleeping state) - confirm against caller.
// The pool is never shrunk by this function.
void server_threadpool::init(int n) {
    // the caller (main thread) participates as a worker, so spawn n-1 threads
    const int n_workers = std::max(1, n) - 1;

    // n_workers is never negative, so the conversion below is safe
    threads.reserve(static_cast<size_t>(n_workers));

    while (static_cast<int>(threads.size()) < n_workers) {
        threads.emplace_back([this]() { run_worker(); });
    }
}
void server_threadpool::run_worker() {
while (true) {
std::function<void()> task;
{
std::unique_lock<std::mutex> lock(mtx);
cv.wait(lock, [this]() { return stop || !tasks.empty(); });
if (stop && tasks.empty()) return;
task = std::move(tasks.front());
tasks.pop();
}
task();
{
std::lock_guard<std::mutex> lock(mtx);
pending--;
}
cv_done.notify_all();
}
}
// Appends one task to the queue, counts it as pending and wakes a single
// worker. Asserts that the pool has not been stopped.
void server_threadpool::enqueue(std::function<void()> fn) {
    std::unique_lock<std::mutex> guard(mtx);
    GGML_ASSERT(!stop);
    tasks.push(std::move(fn));
    ++pending;
    // release the mutex before notifying, so the woken worker can take it right away
    guard.unlock();
    cv.notify_one();
}
void server_threadpool::wait_all() {
// the calling thread helps drain the queue until no tasks remain pending
while (true) {
std::function<void()> task;
{
std::lock_guard<std::mutex> lock(mtx);
if (pending == 0) {
return;
}
if (!tasks.empty()) {
task = std::move(tasks.front());
tasks.pop();
}
}
if (task) {
task();
{
std::lock_guard<std::mutex> lock(mtx);
pending--;
}
cv_done.notify_all();
} else {
// no task available right now, but some are still pending (being run by workers)
std::unique_lock<std::mutex> lock(mtx);
cv_done.wait(lock, [this]() { return pending == 0 || !tasks.empty(); });
}
}
}
+41
View File
@@ -12,6 +12,11 @@
#include <string>
#include <vector>
#include <cinttypes>
#include <thread>
#include <mutex>
#include <condition_variable>
#include <queue>
#include <functional>
using json = nlohmann::ordered_json;
@@ -370,3 +375,39 @@ server_tokens format_prompt_rerank(
mtmd_context * mctx,
const std::string & query,
const std::string & doc);
//
// threadpool utils
// to be used for multi-threaded sampling
//
// the main thread participates as one of the pool's workers, so init(n)
// only spawns n-1 background threads (the caller is the nth)
// Minimal task pool used for multi-threaded sampling.
// The thread calling run_all() also executes tasks, which is why init(n)
// only needs n-1 background threads.
struct server_threadpool {
    std::vector<std::thread> threads;        // background workers created by init()
    std::queue<std::function<void()>> tasks; // tasks waiting to be picked up
    std::mutex mtx;                          // guards tasks, pending and stop
    std::condition_variable cv;              // signalled when a task is queued or on shutdown
    std::condition_variable cv_done;         // signalled each time a task finishes
    int pending = 0;                         // tasks queued or running, not yet finished
    bool stop = false;                       // set by the destructor to end the workers

    ~server_threadpool();

    // spawn the background workers; n counts the calling thread as well
    void init(int n);

    // Runs handler(item) for every element of `tasks` and returns only after
    // all of them have completed. Elements and handler are captured by
    // address, which is safe because this call blocks until the work is done.
    // note: inside this function `tasks` names the parameter, not the member queue
    template<typename T>
    void run_all(std::vector<T> & tasks, std::function<void(T&)> handler) {
        const size_t n_items = tasks.size();
        for (size_t i = 0; i < n_items; ++i) {
            T * item = &tasks[i];
            enqueue([&handler, item]() {
                handler(*item);
            });
        }

        // the calling thread runs tasks too, until all are done
        wait_all();
    }

private:
    void enqueue(std::function<void()> fn);
    void wait_all();
    void run_worker();
};
+51 -6
View File
@@ -866,6 +866,9 @@ public:
// note: chat_params must not be refreshed upon existing sleeping state
server_chat_params chat_params;
// threadpool for parallel sampling
server_threadpool threadpool;
server_state_callback_t callback_state = [](server_state, json) -> void {};
server_context_impl() {
@@ -1457,6 +1460,16 @@ private:
metrics.init();
// initialize threadpool
{
    // use the dedicated sampling thread count when it is positive,
    // otherwise fall back to the general thread count
    const int n_sampling      = params_base.sampling_n_threads;
    const int threadpool_size = n_sampling > 0 ? n_sampling : params_base.cpuparams.n_threads;

    SRV_DBG("%s: initializing threadpool, size = %d\n", __func__, threadpool_size);
    threadpool.init(threadpool_size);
}
if (params_base.cache_idle_slots) {
if (params_base.cache_ram_mib == 0) {
SRV_WRN("%s", "--cache-idle-slots requires --cache-ram, disabling\n");
@@ -3699,6 +3712,12 @@ private:
return true;
}
// One unit of sampling work handed to the threadpool.
struct sampling_task {
    // slot to sample for; stays null when there is nothing to sample for this entry
    server_slot * slot{nullptr};
    // index of the token to sample from
    int32_t tok_idx{0};
    // result: the sampled token, LLAMA_TOKEN_NULL until the task has run
    llama_token sampled_id{LLAMA_TOKEN_NULL};
};
void post_decode(int32_t n_batch_tokens, int32_t off, llama_batch & batch_view) {
// for checking if a given batch index is inside batch_view
auto is_inside_view = [&](int32_t idx) {
@@ -3720,7 +3739,13 @@ private:
slot.task->params.sampling.preserved_tokens.find(token) != slot.task->params.sampling.preserved_tokens.end();
};
std::vector<sampling_task> smpl_tasks;
smpl_tasks.resize(slots.size());
bool need_sampling = false;
iterate(slots, [&](server_slot & slot) {
auto & smpl_task = smpl_tasks[slot.id];
// optionally send prompt processing progress
if (slot.state == SLOT_STATE_PROCESSING_PROMPT || slot.state == SLOT_STATE_DONE_PROMPT) {
if (slot.task->params.stream && slot.task->params.return_progress) {
@@ -3765,15 +3790,35 @@ private:
return; // sample using speculative decoding
}
// shifted according to the current sub-batch
const int tok_idx = slot.i_batch - off;
// otherwise, we must sample the next token
// also shift batch idx according to the current sub-batch
smpl_task.slot = &slot;
smpl_task.tok_idx = slot.i_batch - off;
need_sampling = true;
});
llama_token id;
{
scoped_timer timer(t_sampl, n_sampl);
id = common_sampler_sample(slot.smpl.get(), slot.ctx_tgt, tok_idx);
// run multiple sampling tasks in parallel
GGML_ASSERT(smpl_tasks.size() == slots.size());
if (need_sampling) {
llama_synchronize(ctx_tgt);
threadpool.run_all<sampling_task>(smpl_tasks, [](sampling_task & task) {
if (task.slot) {
task.sampled_id = common_sampler_sample(task.slot->smpl.get(),
task.slot->ctx_tgt, task.tok_idx);
}
});
}
iterate(slots, [&](server_slot & slot) {
auto & smpl_task = smpl_tasks[slot.id];
if (!smpl_task.slot) {
return;
}
auto tok_idx = smpl_task.tok_idx;
auto id = smpl_task.sampled_id;
slot.i_batch = -1;
common_sampler_accept(slot.smpl.get(), id, true);