mirror of
https://github.com/ggml-org/llama.cpp.git
synced 2026-06-23 04:03:15 +02:00
Compare commits
5 Commits
b9760
...
xsn/server
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
095058ca19 | ||
|
|
c62fdd5fd0 | ||
|
|
41ed530be2 | ||
|
|
fe03cce8db | ||
|
|
721354fbdf |
@@ -396,7 +396,7 @@ static bool parse_bool_value(const std::string & value) {
|
||||
// CLI argument parsing functions
|
||||
//
|
||||
|
||||
bool common_params_handle_models(common_params & params, llama_example curr_ex) {
|
||||
bool common_params_handle_models(common_params & params, llama_example curr_ex, common_download_callback * callback) {
|
||||
const bool spec_type_draft_mtp = std::find(params.speculative.types.begin(),
|
||||
params.speculative.types.end(),
|
||||
COMMON_SPECULATIVE_TYPE_DRAFT_MTP) != params.speculative.types.end();
|
||||
@@ -408,6 +408,10 @@ bool common_params_handle_models(common_params & params, llama_example curr_ex)
|
||||
opts.download_mtp = spec_type_draft_mtp;
|
||||
opts.download_mmproj = !params.no_mmproj && params.mmproj.path.empty() && params.mmproj.url.empty();
|
||||
|
||||
if (callback) {
|
||||
opts.callback = callback;
|
||||
}
|
||||
|
||||
// sub-models (draft, mmproj, vocoder) are explicitly specified by the user,
|
||||
// so we should not auto-discover mtp/mmproj siblings for them
|
||||
common_download_opts sub_opts = opts;
|
||||
@@ -584,8 +588,11 @@ static bool common_params_parse_ex(int argc, char ** argv, common_params_context
|
||||
throw std::invalid_argument("error: --prompt-cache-all not supported in interactive mode yet\n");
|
||||
}
|
||||
|
||||
// export_graph_ops loads only metadata
|
||||
const bool skip_model_download = ctx_arg.ex == LLAMA_EXAMPLE_EXPORT_GRAPH_OPS;
|
||||
const bool skip_model_download =
|
||||
// server will call common_params_handle_models() later, so we skip it here
|
||||
ctx_arg.ex == LLAMA_EXAMPLE_SERVER ||
|
||||
// export_graph_ops loads only metadata
|
||||
ctx_arg.ex == LLAMA_EXAMPLE_EXPORT_GRAPH_OPS;
|
||||
|
||||
if (!skip_model_download) {
|
||||
// handle model and download
|
||||
@@ -594,7 +601,6 @@ static bool common_params_parse_ex(int argc, char ** argv, common_params_context
|
||||
// model is required (except for server)
|
||||
// TODO @ngxson : maybe show a list of available models in CLI in this case
|
||||
if (params.model.path.empty()
|
||||
&& ctx_arg.ex != LLAMA_EXAMPLE_SERVER
|
||||
&& !params.usage
|
||||
&& !params.completion) {
|
||||
throw std::invalid_argument("error: --model is required\n");
|
||||
@@ -1162,6 +1168,16 @@ common_params_context common_params_parser_init(common_params & params, llama_ex
|
||||
}
|
||||
}
|
||||
));
|
||||
add_opt(common_arg(
|
||||
{"--threads-sampling"}, "N",
|
||||
"number of threads to use during sampling (default: same as --threads)",
|
||||
[](common_params & params, int value) {
|
||||
params.sampling_n_threads = value;
|
||||
if (params.sampling_n_threads <= 0) {
|
||||
params.sampling_n_threads = std::thread::hardware_concurrency();
|
||||
}
|
||||
}
|
||||
).set_examples({LLAMA_EXAMPLE_SERVER}));
|
||||
add_opt(common_arg(
|
||||
{"-C", "--cpu-mask"}, "M",
|
||||
"CPU affinity mask: arbitrarily long hex. Complements cpu-range (default: \"\")",
|
||||
|
||||
@@ -1,6 +1,7 @@
|
||||
#pragma once
|
||||
|
||||
#include "common.h"
|
||||
#include "download.h"
|
||||
|
||||
#include <set>
|
||||
#include <map>
|
||||
@@ -133,7 +134,10 @@ void common_params_add_preset_options(std::vector<common_arg> & args);
|
||||
// return true if the model is ready to use
|
||||
// throw an exception if there is an error that prevents the model from being used (e.g. network error, model not found, etc)
|
||||
// if params.skip_download is true, no downloads will be attempted. return false if the model is invalid or missing (e.g. ETag check failed)
|
||||
bool common_params_handle_models(common_params & params, llama_example curr_ex);
|
||||
bool common_params_handle_models(
|
||||
common_params & params,
|
||||
llama_example curr_ex,
|
||||
common_download_callback * callback = nullptr);
|
||||
|
||||
// initialize argument parser context - used by test-arg-parser and preset
|
||||
common_params_context common_params_parser_init(common_params & params, llama_example ex, void(*print_usage)(int, char **) = nullptr);
|
||||
|
||||
@@ -471,6 +471,8 @@ struct common_params {
|
||||
common_cpu_params cpuparams;
|
||||
common_cpu_params cpuparams_batch;
|
||||
|
||||
int sampling_n_threads = -1; // number of threads for sampling, used by server
|
||||
|
||||
ggml_backend_sched_eval_callback cb_eval = nullptr;
|
||||
void * cb_eval_user_data = nullptr;
|
||||
|
||||
|
||||
@@ -204,9 +204,9 @@ Instead of building everything from the ground up (like what most AI agents will
|
||||
|
||||
The flow for downloading a new model:
|
||||
- POST request comes in --> `post_router_models` --> validation
|
||||
- `server_models::download()` is called
|
||||
- Sets up a new thread `inst.th` and runs the download inside
|
||||
- If a stop request comes in, set `stop_download` to `true`
|
||||
- A new `llama-server` subprocess will be spawned with special `SERVER_CHILD_MODE_DOWNLOAD`
|
||||
- Child process runs the download and report status back to router via stdin/out
|
||||
- If a stop request comes in, the router asks the child process to stop (same mechanism as running a model in child process)
|
||||
- Otherwise, upon completion, we call `load_models()` to refresh the list of models
|
||||
|
||||
### Notable Related PRs
|
||||
|
||||
@@ -1583,3 +1583,82 @@ server_tokens format_prompt_rerank(
|
||||
|
||||
return result;
|
||||
}
|
||||
|
||||
//
|
||||
// threadpool
|
||||
//
|
||||
|
||||
server_threadpool::~server_threadpool() {
|
||||
{
|
||||
std::lock_guard<std::mutex> lock(mtx);
|
||||
stop = true;
|
||||
}
|
||||
cv.notify_all();
|
||||
for (auto & t : threads) t.join();
|
||||
}
|
||||
|
||||
void server_threadpool::init(int n) {
|
||||
// the caller (main thread) participates as a worker, so spawn n-1 threads
|
||||
const int n_workers = std::max(1, n) - 1;
|
||||
for (int i = 0; i < n_workers; i++) {
|
||||
threads.emplace_back([this]() { run_worker(); });
|
||||
}
|
||||
}
|
||||
|
||||
void server_threadpool::run_worker() {
|
||||
while (true) {
|
||||
std::function<void()> task;
|
||||
{
|
||||
std::unique_lock<std::mutex> lock(mtx);
|
||||
cv.wait(lock, [this]() { return stop || !tasks.empty(); });
|
||||
if (stop && tasks.empty()) return;
|
||||
task = std::move(tasks.front());
|
||||
tasks.pop();
|
||||
}
|
||||
task();
|
||||
{
|
||||
std::lock_guard<std::mutex> lock(mtx);
|
||||
pending--;
|
||||
}
|
||||
cv_done.notify_all();
|
||||
}
|
||||
}
|
||||
|
||||
void server_threadpool::enqueue(std::function<void()> fn) {
|
||||
{
|
||||
std::lock_guard<std::mutex> lock(mtx);
|
||||
GGML_ASSERT(!stop);
|
||||
tasks.push(std::move(fn));
|
||||
pending++;
|
||||
}
|
||||
cv.notify_one();
|
||||
}
|
||||
|
||||
void server_threadpool::wait_all() {
|
||||
// the calling thread helps drain the queue until no tasks remain pending
|
||||
while (true) {
|
||||
std::function<void()> task;
|
||||
{
|
||||
std::lock_guard<std::mutex> lock(mtx);
|
||||
if (pending == 0) {
|
||||
return;
|
||||
}
|
||||
if (!tasks.empty()) {
|
||||
task = std::move(tasks.front());
|
||||
tasks.pop();
|
||||
}
|
||||
}
|
||||
if (task) {
|
||||
task();
|
||||
{
|
||||
std::lock_guard<std::mutex> lock(mtx);
|
||||
pending--;
|
||||
}
|
||||
cv_done.notify_all();
|
||||
} else {
|
||||
// no task available right now, but some are still pending (being run by workers)
|
||||
std::unique_lock<std::mutex> lock(mtx);
|
||||
cv_done.wait(lock, [this]() { return pending == 0 || !tasks.empty(); });
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
@@ -12,6 +12,11 @@
|
||||
#include <string>
|
||||
#include <vector>
|
||||
#include <cinttypes>
|
||||
#include <thread>
|
||||
#include <mutex>
|
||||
#include <condition_variable>
|
||||
#include <queue>
|
||||
#include <functional>
|
||||
|
||||
using json = nlohmann::ordered_json;
|
||||
|
||||
@@ -370,3 +375,39 @@ server_tokens format_prompt_rerank(
|
||||
mtmd_context * mctx,
|
||||
const std::string & query,
|
||||
const std::string & doc);
|
||||
|
||||
//
|
||||
// threadpool utils
|
||||
// to be used for multi-threaded sampling
|
||||
//
|
||||
|
||||
// the main thread participates as one of the pool's workers, so init(n)
|
||||
// only spawns n-1 background threads (the caller is the nth)
|
||||
struct server_threadpool {
|
||||
std::vector<std::thread> threads;
|
||||
std::queue<std::function<void()>> tasks;
|
||||
std::mutex mtx;
|
||||
std::condition_variable cv;
|
||||
std::condition_variable cv_done;
|
||||
int pending = 0;
|
||||
bool stop = false;
|
||||
|
||||
~server_threadpool();
|
||||
void init(int n);
|
||||
|
||||
template<typename T>
|
||||
void run_all(std::vector<T> & tasks, std::function<void(T&)> handler) {
|
||||
for (auto & item : tasks) {
|
||||
enqueue([&handler, &item]() {
|
||||
handler(item);
|
||||
});
|
||||
}
|
||||
// the calling thread runs tasks too, until all are done
|
||||
wait_all();
|
||||
}
|
||||
|
||||
private:
|
||||
void enqueue(std::function<void()> fn);
|
||||
void wait_all();
|
||||
void run_worker();
|
||||
};
|
||||
|
||||
@@ -866,6 +866,9 @@ public:
|
||||
// note: chat_params must not be refreshed upon existing sleeping state
|
||||
server_chat_params chat_params;
|
||||
|
||||
// threadpool for parallel sampling
|
||||
server_threadpool threadpool;
|
||||
|
||||
server_state_callback_t callback_state = [](server_state, json) -> void {};
|
||||
|
||||
server_context_impl() {
|
||||
@@ -931,6 +934,8 @@ private:
|
||||
|
||||
bool sleeping = false;
|
||||
|
||||
int64_t t_last_load_progress_ms = 0;
|
||||
|
||||
void destroy() {
|
||||
spec.reset();
|
||||
ctx_dft.reset();
|
||||
@@ -1244,6 +1249,10 @@ private:
|
||||
}
|
||||
|
||||
if (has_mmproj) {
|
||||
if (callback_state) {
|
||||
callback_state(SERVER_STATE_LOADING, {{"stage", "mmproj_model"}});
|
||||
}
|
||||
|
||||
if (!is_resume) {
|
||||
mtmd_helper_log_set(common_log_default_callback, nullptr);
|
||||
}
|
||||
@@ -1451,6 +1460,16 @@ private:
|
||||
|
||||
metrics.init();
|
||||
|
||||
// initialize threadpool
|
||||
{
|
||||
int threadpool_size = params_base.sampling_n_threads;
|
||||
if (threadpool_size <= 0) {
|
||||
threadpool_size = params_base.cpuparams.n_threads;
|
||||
}
|
||||
SRV_DBG("%s: initializing threadpool, size = %d\n", __func__, threadpool_size);
|
||||
threadpool.init(threadpool_size);
|
||||
}
|
||||
|
||||
if (params_base.cache_idle_slots) {
|
||||
if (params_base.cache_ram_mib == 0) {
|
||||
SRV_WRN("%s", "--cache-idle-slots requires --cache-ram, disabling\n");
|
||||
@@ -3693,6 +3712,12 @@ private:
|
||||
return true;
|
||||
}
|
||||
|
||||
struct sampling_task {
|
||||
server_slot * slot = nullptr;
|
||||
int32_t tok_idx = 0;
|
||||
llama_token sampled_id = LLAMA_TOKEN_NULL; // result
|
||||
};
|
||||
|
||||
void post_decode(int32_t n_batch_tokens, int32_t off, llama_batch & batch_view) {
|
||||
// for checking if a given batch index is inside batch_view
|
||||
auto is_inside_view = [&](int32_t idx) {
|
||||
@@ -3714,7 +3739,13 @@ private:
|
||||
slot.task->params.sampling.preserved_tokens.find(token) != slot.task->params.sampling.preserved_tokens.end();
|
||||
};
|
||||
|
||||
std::vector<sampling_task> smpl_tasks;
|
||||
smpl_tasks.resize(slots.size());
|
||||
bool need_sampling = false;
|
||||
|
||||
iterate(slots, [&](server_slot & slot) {
|
||||
auto & smpl_task = smpl_tasks[slot.id];
|
||||
|
||||
// optionally send prompt processing progress
|
||||
if (slot.state == SLOT_STATE_PROCESSING_PROMPT || slot.state == SLOT_STATE_DONE_PROMPT) {
|
||||
if (slot.task->params.stream && slot.task->params.return_progress) {
|
||||
@@ -3759,15 +3790,35 @@ private:
|
||||
return; // sample using speculative decoding
|
||||
}
|
||||
|
||||
// shifted according to the current sub-batch
|
||||
const int tok_idx = slot.i_batch - off;
|
||||
// otherwise, we must sample the next token
|
||||
// also shift batch idx according to the current sub-batch
|
||||
smpl_task.slot = &slot;
|
||||
smpl_task.tok_idx = slot.i_batch - off;
|
||||
need_sampling = true;
|
||||
});
|
||||
|
||||
llama_token id;
|
||||
{
|
||||
scoped_timer timer(t_sampl, n_sampl);
|
||||
id = common_sampler_sample(slot.smpl.get(), slot.ctx_tgt, tok_idx);
|
||||
// run multiple sampling tasks in parallel
|
||||
GGML_ASSERT(smpl_tasks.size() == slots.size());
|
||||
if (need_sampling) {
|
||||
llama_synchronize(ctx_tgt);
|
||||
threadpool.run_all<sampling_task>(smpl_tasks, [](sampling_task & task) {
|
||||
if (task.slot) {
|
||||
task.sampled_id = common_sampler_sample(task.slot->smpl.get(),
|
||||
task.slot->ctx_tgt, task.tok_idx);
|
||||
}
|
||||
});
|
||||
}
|
||||
|
||||
iterate(slots, [&](server_slot & slot) {
|
||||
auto & smpl_task = smpl_tasks[slot.id];
|
||||
|
||||
if (!smpl_task.slot) {
|
||||
return;
|
||||
}
|
||||
|
||||
auto tok_idx = smpl_task.tok_idx;
|
||||
auto id = smpl_task.sampled_id;
|
||||
|
||||
slot.i_batch = -1;
|
||||
|
||||
common_sampler_accept(slot.smpl.get(), id, true);
|
||||
|
||||
@@ -53,7 +53,7 @@ struct server_context_meta {
|
||||
};
|
||||
|
||||
enum server_state {
|
||||
// SERVER_STATE_DOWNLOADING,
|
||||
SERVER_STATE_DOWNLOADING,
|
||||
SERVER_STATE_LOADING,
|
||||
SERVER_STATE_READY,
|
||||
SERVER_STATE_SLEEPING,
|
||||
@@ -61,6 +61,7 @@ enum server_state {
|
||||
|
||||
static std::string server_state_to_str(server_state state) {
|
||||
switch (state) {
|
||||
case SERVER_STATE_DOWNLOADING: return "downloading";
|
||||
case SERVER_STATE_LOADING: return "loading";
|
||||
case SERVER_STATE_READY: return "ready";
|
||||
case SERVER_STATE_SLEEPING: return "sleeping";
|
||||
@@ -69,6 +70,7 @@ static std::string server_state_to_str(server_state state) {
|
||||
}
|
||||
|
||||
static server_state server_state_from_str(const std::string & str) {
|
||||
if (str == "downloading") return SERVER_STATE_DOWNLOADING;
|
||||
if (str == "loading") return SERVER_STATE_LOADING;
|
||||
if (str == "ready") return SERVER_STATE_READY;
|
||||
if (str == "sleeping") return SERVER_STATE_SLEEPING;
|
||||
|
||||
@@ -64,6 +64,17 @@ struct server_subproc {
|
||||
return sproc.has_value() && subprocess_alive(&sproc.value());
|
||||
}
|
||||
|
||||
void request_exit() {
|
||||
if (sproc.has_value()) {
|
||||
FILE * stdin_file = subprocess_stdin(&sproc.value());
|
||||
if (stdin_file) {
|
||||
fprintf(stdin_file, "%s\n", CMD_ROUTER_TO_CHILD_EXIT);
|
||||
fflush(stdin_file);
|
||||
}
|
||||
}
|
||||
stopped.store(true, std::memory_order_relaxed);
|
||||
}
|
||||
|
||||
void terminate() {
|
||||
if (!sproc.has_value()) {
|
||||
return;
|
||||
@@ -323,7 +334,7 @@ void server_models::notify_sse(const std::string & event, const std::string & mo
|
||||
}
|
||||
|
||||
void server_models::load_models() {
|
||||
// Phase 1: load presets from all sources — pure I/O, no lock needed
|
||||
// Phase 1: load presets from all sources - pure I/O, no lock needed
|
||||
// 1. cached models
|
||||
common_presets cached_models = ctx_preset.load_from_cache();
|
||||
SRV_INF("Loaded %zu cached model presets\n", cached_models.size());
|
||||
@@ -376,7 +387,7 @@ void server_models::load_models() {
|
||||
return source_map.count(name) ? source_map.at(name) : SERVER_MODEL_SOURCE_PRESET;
|
||||
};
|
||||
|
||||
// Helpers that read `mapping` — must be called while holding the lock.
|
||||
// Helpers that read `mapping` - must be called while holding the lock.
|
||||
std::unordered_set<std::string> custom_names;
|
||||
for (const auto & [name, preset] : custom_presets) custom_names.insert(name);
|
||||
auto join_set = [](const std::set<std::string> & s) {
|
||||
@@ -523,7 +534,7 @@ void server_models::load_models() {
|
||||
}
|
||||
}
|
||||
|
||||
// join outside the lock — monitoring thread calls update_status (needs lock)
|
||||
// join outside the lock - monitoring thread calls update_status (needs lock)
|
||||
lk.unlock();
|
||||
for (auto & th : threads_to_join) th.join();
|
||||
lk.lock();
|
||||
@@ -622,7 +633,7 @@ void server_models::load_models() {
|
||||
|
||||
apply_stop_timeout();
|
||||
|
||||
// clear reload flag before unlocking for autoload — load() blocks on !is_reloading,
|
||||
// clear reload flag before unlocking for autoload - load() blocks on !is_reloading,
|
||||
// so clearing it here (while still locked) prevents a deadlock in the autoload calls below
|
||||
is_reloading = false;
|
||||
cv.notify_all();
|
||||
@@ -815,17 +826,23 @@ void server_models::unload_lru() {
|
||||
}
|
||||
|
||||
void server_models::load(const std::string & name) {
|
||||
if (!has_model(name)) {
|
||||
throw std::runtime_error("model name=" + name + " is not found");
|
||||
load(name, load_options{});
|
||||
}
|
||||
|
||||
void server_models::load(const std::string & name, const load_options & opts) {
|
||||
if (!opts.custom_meta.has_value()) {
|
||||
if (!has_model(name)) {
|
||||
throw std::runtime_error("model name=" + name + " is not found");
|
||||
}
|
||||
unload_lru();
|
||||
}
|
||||
unload_lru();
|
||||
|
||||
std::unique_lock<std::mutex> lk(mutex);
|
||||
// edge case: block until any in-progress reload has finished so we always load
|
||||
// against the freshest preset and a consistent mapping state
|
||||
cv.wait(lk, [this]() { return !is_reloading; });
|
||||
|
||||
auto meta = mapping[name].meta;
|
||||
auto meta = opts.custom_meta.has_value() ? *opts.custom_meta : mapping[name].meta;
|
||||
if (meta.status != SERVER_MODEL_STATUS_UNLOADED) {
|
||||
SRV_INF("model %s is not ready\n", name.c_str());
|
||||
return;
|
||||
@@ -869,6 +886,12 @@ void server_models::load(const std::string & name) {
|
||||
std::vector<std::string> child_env = base_env; // copy
|
||||
child_env.push_back("LLAMA_SERVER_ROUTER_PORT=" + std::to_string(base_params.port));
|
||||
|
||||
if (opts.mode == SERVER_CHILD_MODE_DOWNLOAD) {
|
||||
inst.meta.status = SERVER_MODEL_STATUS_DOWNLOADING;
|
||||
child_env.push_back("LLAMA_SERVER_CHILD_MODE=download");
|
||||
child_env.push_back("LLAMA_ARG_HF_REPO=" + name);
|
||||
}
|
||||
|
||||
SRV_INF("%s", "spawning server instance with args:\n");
|
||||
for (const auto & arg : child_args) {
|
||||
SRV_INF(" %s\n", arg.c_str());
|
||||
@@ -886,13 +909,17 @@ void server_models::load(const std::string & name) {
|
||||
if (result != 0) {
|
||||
throw std::runtime_error("failed to spawn server instance");
|
||||
}
|
||||
|
||||
inst.stdin_file = subprocess_stdin(&inst.subproc->get());
|
||||
}
|
||||
|
||||
// start a thread to manage the child process
|
||||
// captured variables are guaranteed to be destroyed only after the thread is joined
|
||||
inst.th = std::thread([this, name, child_proc = inst.subproc, port = inst.meta.port, stop_timeout = inst.meta.stop_timeout]() {
|
||||
inst.th = std::thread([
|
||||
this, name,
|
||||
child_proc = inst.subproc,
|
||||
port = inst.meta.port,
|
||||
stop_timeout = inst.meta.stop_timeout,
|
||||
child_mode = opts.mode
|
||||
]() {
|
||||
FILE * stdin_file = subprocess_stdin(&child_proc->get());
|
||||
FILE * stdout_file = subprocess_stdout(&child_proc->get()); // combined stdout/stderr
|
||||
|
||||
@@ -925,7 +952,7 @@ void server_models::load(const std::string & name) {
|
||||
return is_stopping() || child_proc->stopped.load(std::memory_order_acquire);
|
||||
});
|
||||
}
|
||||
// child crashed or finished on its own — skip graceful shutdown sequence
|
||||
// child crashed or finished on its own, skip graceful shutdown sequence
|
||||
if (child_proc->stopped.load(std::memory_order_acquire)) {
|
||||
return;
|
||||
}
|
||||
@@ -973,10 +1000,14 @@ void server_models::load(const std::string & name) {
|
||||
subprocess_destroy(&child_proc->get());
|
||||
|
||||
// update status and exit code
|
||||
this->update_status(name, {
|
||||
SERVER_MODEL_STATUS_UNLOADED,
|
||||
exit_code
|
||||
});
|
||||
if (child_mode == SERVER_CHILD_MODE_DOWNLOAD) {
|
||||
// instance will be cleaned up on next load_models() call
|
||||
} else {
|
||||
this->update_status(name, {
|
||||
SERVER_MODEL_STATUS_UNLOADED,
|
||||
exit_code
|
||||
});
|
||||
}
|
||||
SRV_INF("instance name=%s exited with status %d\n", name.c_str(), exit_code);
|
||||
});
|
||||
|
||||
@@ -984,7 +1015,7 @@ void server_models::load(const std::string & name) {
|
||||
{
|
||||
auto & old_instance = mapping[name];
|
||||
// old process should have exited already, but just in case, we clean it up here
|
||||
if (old_instance.subproc->is_alive()) {
|
||||
if (old_instance.subproc && old_instance.subproc->is_alive()) {
|
||||
SRV_WRN("old process for model name=%s is still alive, this is unexpected\n", name.c_str());
|
||||
old_instance.subproc->terminate(); // force kill
|
||||
}
|
||||
@@ -1001,92 +1032,13 @@ void server_models::load(const std::string & name) {
|
||||
cv.notify_all();
|
||||
}
|
||||
|
||||
// callback for model downloading functionality
|
||||
struct server_models_download_res : public common_download_callback {
|
||||
common_params_model model;
|
||||
common_download_opts opts;
|
||||
|
||||
std::function<bool()> should_stop;
|
||||
std::function<void(const common_download_progress & p)> on_progress;
|
||||
|
||||
bool is_ok = false;
|
||||
|
||||
bool run() {
|
||||
try {
|
||||
common_download_model(model, opts);
|
||||
is_ok = true;
|
||||
} catch (const std::exception & e) {
|
||||
auto model_name = model.get_name();
|
||||
SRV_ERR("download failed for model name=%s: %s\n", model_name.c_str(), e.what());
|
||||
is_ok = false;
|
||||
}
|
||||
return is_ok;
|
||||
}
|
||||
void on_start(const common_download_progress & p) override {
|
||||
on_progress(p);
|
||||
}
|
||||
void on_update(const common_download_progress & p) override {
|
||||
on_progress(p);
|
||||
}
|
||||
void on_done(const common_download_progress &, bool ok) override {
|
||||
is_ok = ok;
|
||||
}
|
||||
bool is_cancelled() const override {
|
||||
return should_stop();
|
||||
}
|
||||
};
|
||||
|
||||
void server_models::download(common_params_model && model, common_download_opts && opts) {
|
||||
std::string name = model.get_name();
|
||||
GGML_ASSERT(name == model.hf_repo);
|
||||
|
||||
std::unique_lock<std::mutex> lk(mutex);
|
||||
if (mapping.find(name) != mapping.end()) {
|
||||
throw std::runtime_error("model name=" + name + " already exists");
|
||||
}
|
||||
|
||||
instance_t inst;
|
||||
inst.meta.name = name;
|
||||
inst.meta.status = SERVER_MODEL_STATUS_DOWNLOADING;
|
||||
inst.subproc = std::make_shared<server_subproc>();
|
||||
|
||||
auto dl = std::make_unique<server_models_download_res>();
|
||||
dl->model = model; // copy
|
||||
dl->opts = opts; // copy
|
||||
|
||||
dl->should_stop = [sp = inst.subproc]() {
|
||||
return sp->stopped.load(std::memory_order_relaxed);
|
||||
};
|
||||
|
||||
dl->on_progress = [this, name](const common_download_progress & p) {
|
||||
update_download_progress(name, p, false);
|
||||
};
|
||||
|
||||
inst.th = std::thread([this, dl = std::move(dl)]() {
|
||||
dl->opts.callback = dl.get();
|
||||
bool ok = dl->run();
|
||||
auto model_name = dl->model.get_name();
|
||||
SRV_INF("download finished for model name=%s with status=%s\n",
|
||||
model_name.c_str(), ok ? "success" : "failure");
|
||||
update_download_progress(model_name, {}, true, ok);
|
||||
// need_reload is set inside update_download_progress under the mutex;
|
||||
// the next load_models() call will clean up this instance
|
||||
});
|
||||
|
||||
mapping[name] = std::move(inst);
|
||||
notify_sse("status_update", name, {
|
||||
{"status", server_model_status_to_string(SERVER_MODEL_STATUS_DOWNLOADING)},
|
||||
});
|
||||
cv.notify_all();
|
||||
}
|
||||
|
||||
void server_models::unload(const std::string & name) {
|
||||
std::unique_lock<std::mutex> lk(mutex);
|
||||
auto it = mapping.find(name);
|
||||
if (it != mapping.end()) {
|
||||
if (it->second.meta.status == SERVER_MODEL_STATUS_DOWNLOADING) {
|
||||
SRV_INF("cancelling download for model name=%s\n", name.c_str());
|
||||
it->second.subproc->stopped.store(true, std::memory_order_relaxed);
|
||||
it->second.subproc->request_exit();
|
||||
// for convenience, we wait the status change here
|
||||
wait(lk, name, [](const server_model_meta & new_meta) {
|
||||
return new_meta.status != SERVER_MODEL_STATUS_DOWNLOADING;
|
||||
@@ -1198,37 +1150,65 @@ void server_models::update_download_progress(const std::string & name, const com
|
||||
}
|
||||
|
||||
bool server_models::remove(const std::string & name) {
|
||||
auto meta = get_meta(name);
|
||||
// do everything under one lock acquisition; avoid get_meta() /
|
||||
// unload() because they can trigger load_models() which erases
|
||||
// transient DOWNLOADING / DOWNLOADED entries as a side-effect
|
||||
std::unique_lock<std::mutex> lk(mutex);
|
||||
|
||||
if (!meta.has_value()) {
|
||||
auto it = mapping.find(name);
|
||||
if (it == mapping.end()) {
|
||||
throw std::runtime_error("model name=" + name + " is not found");
|
||||
}
|
||||
if (meta->source != SERVER_MODEL_SOURCE_CACHE) {
|
||||
if (it->second.meta.source != SERVER_MODEL_SOURCE_CACHE) {
|
||||
throw std::runtime_error("model name=" + name + " is not removable (not from cache)");
|
||||
}
|
||||
|
||||
unload(name); // cancel download or stop running instance
|
||||
{
|
||||
std::unique_lock<std::mutex> lk(mutex);
|
||||
// a cancelled download lands on DOWNLOADED; a stopped instance lands on UNLOADED
|
||||
wait(lk, name, [](const server_model_meta & new_meta) {
|
||||
return new_meta.status == SERVER_MODEL_STATUS_UNLOADED
|
||||
|| new_meta.status == SERVER_MODEL_STATUS_DOWNLOADED;
|
||||
});
|
||||
// join before erasing - after status reaches UNLOADED/DOWNLOADED the thread no
|
||||
// longer acquires this mutex, so joining while holding it is safe
|
||||
if (mapping[name].th.joinable()) {
|
||||
mapping[name].th.join();
|
||||
if (it->second.meta.status == SERVER_MODEL_STATUS_DOWNLOADING) {
|
||||
// cancel in-flight download
|
||||
SRV_INF("cancelling download for model name=%s\n", name.c_str());
|
||||
it->second.subproc->request_exit();
|
||||
} else if (it->second.meta.is_running()) {
|
||||
// stop running instance
|
||||
SRV_INF("stopping model instance name=%s\n", name.c_str());
|
||||
stopping_models.insert(name);
|
||||
if (it->second.meta.status == SERVER_MODEL_STATUS_LOADING) {
|
||||
it->second.subproc->terminate();
|
||||
}
|
||||
// remove the model from disk (hold lock to prevent concurrent load)
|
||||
bool ok = common_download_remove(name);
|
||||
if (ok) {
|
||||
mapping.erase(name);
|
||||
}
|
||||
SRV_INF("removing model name=%s from cache (%s)\n", name.c_str(), ok ? "succeeded" : "failed");
|
||||
notify_sse("model_remove", name, {});
|
||||
return ok;
|
||||
cv_stop.notify_all();
|
||||
}
|
||||
|
||||
// wait until the monitoring thread finishes
|
||||
wait(lk, name, [](const server_model_meta & meta) {
|
||||
return meta.status == SERVER_MODEL_STATUS_UNLOADED
|
||||
|| meta.status == SERVER_MODEL_STATUS_DOWNLOADED;
|
||||
});
|
||||
|
||||
// re-find after wait - load_models() may have erased the entry during the wait
|
||||
it = mapping.find(name);
|
||||
if (it == mapping.end()) {
|
||||
// load_models() already joined the thread and erased the entry;
|
||||
// we just need to clean up the cached files on disk
|
||||
lk.unlock();
|
||||
bool ok = common_download_remove(name);
|
||||
SRV_INF("removing model name=%s from cache (%s)\n", name.c_str(), ok ? "succeeded" : "partial");
|
||||
notify_sse("model_remove", name, {});
|
||||
return true;
|
||||
}
|
||||
|
||||
// join before erasing - thread no longer acquires this mutex
|
||||
if (it->second.th.joinable()) {
|
||||
it->second.th.join();
|
||||
}
|
||||
|
||||
// remove from disk (best-effort: cancelled downloads may have no cached files)
|
||||
bool ok = common_download_remove(name);
|
||||
mapping.erase(name);
|
||||
if (!ok) {
|
||||
SRV_WRN("removing model name=%s from disk returned false (no cached files?)\n", name.c_str());
|
||||
}
|
||||
SRV_INF("removing model name=%s from cache (%s)\n", name.c_str(), ok ? "succeeded" : "partial");
|
||||
notify_sse("model_remove", name, {});
|
||||
return true;
|
||||
}
|
||||
|
||||
void server_models::wait(const std::string & name, std::function<bool(const server_model_meta &)> predicate) {
|
||||
@@ -1243,7 +1223,9 @@ void server_models::wait(std::unique_lock<std::mutex> & lk, const std::string &
|
||||
return predicate(it->second.meta);
|
||||
|
||||
}
|
||||
return false;
|
||||
// model was removed from mapping by another code path (e.g. load_models()).
|
||||
// nothing left to wait for - tell the caller to proceed.
|
||||
return true;
|
||||
});
|
||||
}
|
||||
|
||||
@@ -1328,6 +1310,31 @@ void server_models::handle_child_state(const std::string & name, const std::stri
|
||||
}
|
||||
|
||||
switch (state) {
|
||||
case SERVER_STATE_DOWNLOADING:
|
||||
{
|
||||
std::string result = json_value(payload, "result", std::string());
|
||||
std::string url = json_value(payload, "url", std::string());
|
||||
auto request_exit = [&]() {
|
||||
std::lock_guard<std::mutex> lk(mutex);
|
||||
auto it = mapping.find(name);
|
||||
if (it != mapping.end()) {
|
||||
return it->second.subproc->request_exit();
|
||||
}
|
||||
};
|
||||
if (result == "download_finished") {
|
||||
update_download_progress(name, {}, true, true);
|
||||
request_exit();
|
||||
} else if (result == "download_failed") {
|
||||
update_download_progress(name, {}, true, false);
|
||||
request_exit();
|
||||
} else if (!url.empty()) {
|
||||
common_download_progress p;
|
||||
p.url = url;
|
||||
p.downloaded = json_value(payload, "downloaded", (size_t)0);
|
||||
p.total = json_value(payload, "total", (size_t)0);
|
||||
update_download_progress(name, p, false);
|
||||
}
|
||||
} break;
|
||||
case SERVER_STATE_LOADING:
|
||||
{
|
||||
update_status(name, {
|
||||
@@ -1366,6 +1373,90 @@ bool server_child::is_child() {
|
||||
return router_port != nullptr;
|
||||
}
|
||||
|
||||
server_child_mode server_child::get_mode() {
|
||||
const char * mode = std::getenv("LLAMA_SERVER_CHILD_MODE");
|
||||
std::string mode_str(mode ? mode : "");
|
||||
if (mode_str == "download") {
|
||||
return SERVER_CHILD_MODE_DOWNLOAD;
|
||||
} else {
|
||||
return SERVER_CHILD_MODE_NORMAL;
|
||||
}
|
||||
}
|
||||
|
||||
struct server_download_state : public common_download_callback {
|
||||
server_child * self;
|
||||
std::function<bool()> should_stop;
|
||||
std::atomic<int64_t> last_progress_time{0}; // multiple files downloading in different threads
|
||||
bool is_ok = false;
|
||||
|
||||
server_download_state(server_child * s) : self(s) {}
|
||||
|
||||
bool run(common_params & params) {
|
||||
try {
|
||||
common_params_handle_models(params, LLAMA_EXAMPLE_SERVER, this);
|
||||
is_ok = true;
|
||||
} catch (const std::exception & e) {
|
||||
auto model_name = params.model.get_name();
|
||||
SRV_ERR("download failed for model name=%s: %s\n", model_name.c_str(), e.what());
|
||||
is_ok = false;
|
||||
}
|
||||
return is_ok;
|
||||
}
|
||||
void on_progress(const common_download_progress & p) {
|
||||
json data = {
|
||||
{"url", p.url},
|
||||
{"downloaded", p.downloaded},
|
||||
{"total", p.total},
|
||||
};
|
||||
self->notify_to_router(server_state_to_str(SERVER_STATE_DOWNLOADING), data);
|
||||
}
|
||||
void on_start(const common_download_progress & p) override {
|
||||
on_progress(p);
|
||||
}
|
||||
void on_update(const common_download_progress & p) override {
|
||||
int64_t now = ggml_time_ms();
|
||||
// throttle progress updates to avoid flooding logs
|
||||
if (now - last_progress_time.load(std::memory_order_relaxed) >= 100) {
|
||||
on_progress(p);
|
||||
last_progress_time.store(now, std::memory_order_relaxed);
|
||||
}
|
||||
}
|
||||
void on_done(const common_download_progress & p, bool) override {
|
||||
on_progress(p);
|
||||
}
|
||||
bool is_cancelled() const override {
|
||||
return should_stop ? should_stop() : false;
|
||||
}
|
||||
};
|
||||
|
||||
int server_child::run_download(common_params & params) {
|
||||
auto cancelled = std::make_shared<std::atomic<bool>>(false);
|
||||
|
||||
// monitor stdin for cancellation command from the router
|
||||
std::thread signal_thread = setup([cancelled](int) {
|
||||
cancelled->store(true, std::memory_order_relaxed);
|
||||
});
|
||||
|
||||
server_download_state dl(this);
|
||||
dl.should_stop = [cancelled]() {
|
||||
return cancelled->load(std::memory_order_relaxed);
|
||||
};
|
||||
|
||||
bool ok = dl.run(params);
|
||||
|
||||
notify_to_router(server_state_to_str(SERVER_STATE_DOWNLOADING), {
|
||||
{"result", ok ? "download_finished" : "download_failed"},
|
||||
});
|
||||
|
||||
// router should send CMD_ROUTER_TO_CHILD_EXIT after receiving the result
|
||||
if (signal_thread.joinable()) {
|
||||
signal_thread.join();
|
||||
}
|
||||
|
||||
SRV_INF("download completed %s\n", ok ? "successfully" : "with errors");
|
||||
return 0;
|
||||
}
|
||||
|
||||
std::thread server_child::setup(const std::function<void(int)> & shutdown_handler) {
|
||||
// setup thread for monitoring stdin
|
||||
return std::thread([shutdown_handler]() {
|
||||
@@ -1639,7 +1730,7 @@ void server_models_routes::init_routes() {
|
||||
res_err(res, format_error_response("model is not found", ERROR_TYPE_INVALID_REQUEST));
|
||||
return res;
|
||||
}
|
||||
if (!model->is_running()) {
|
||||
if (!model->is_running() && model->status != SERVER_MODEL_STATUS_DOWNLOADING) {
|
||||
res_err(res, format_error_response("model is not running", ERROR_TYPE_INVALID_REQUEST));
|
||||
return res;
|
||||
}
|
||||
@@ -1680,8 +1771,9 @@ void server_models_routes::init_routes() {
|
||||
|
||||
model.hf_repo = name;
|
||||
opts.bearer_token = params.hf_token;
|
||||
opts.download_mmproj = true;
|
||||
opts.download_mtp = true;
|
||||
// note: we only check main model, no need sidecar here
|
||||
opts.download_mmproj = false;
|
||||
opts.download_mtp = false;
|
||||
|
||||
// first, only check if the model is valid and can be downloaded
|
||||
opts.skip_download = true;
|
||||
@@ -1702,10 +1794,21 @@ void server_models_routes::init_routes() {
|
||||
throw std::invalid_argument("model validation failed, unable to download");
|
||||
}
|
||||
|
||||
// reject if model already exists
|
||||
if (models.has_model(name)) {
|
||||
throw std::invalid_argument("model '" + name + "' already exists");
|
||||
}
|
||||
|
||||
// then, proceed with the actual download
|
||||
opts.skip_download = false;
|
||||
SRV_INF("starting download for model '%s'\n", name.c_str());
|
||||
models.download(std::move(model), std::move(opts));
|
||||
{
|
||||
server_models::load_options load_opts;
|
||||
load_opts.mode = SERVER_CHILD_MODE_DOWNLOAD;
|
||||
load_opts.custom_meta = server_model_meta{};
|
||||
load_opts.custom_meta->source = SERVER_MODEL_SOURCE_CACHE;
|
||||
load_opts.custom_meta->name = name;
|
||||
models.load(name, load_opts);
|
||||
}
|
||||
|
||||
res_ok(res, {{"success", true}});
|
||||
return res;
|
||||
@@ -1719,10 +1822,7 @@ void server_models_routes::init_routes() {
|
||||
throw std::invalid_argument("model must be a non-empty string");
|
||||
}
|
||||
|
||||
bool ok = models.remove(name);
|
||||
if (!ok) {
|
||||
throw std::runtime_error("failed to remove model '" + name + "'");
|
||||
}
|
||||
models.remove(name); // throws on error
|
||||
|
||||
res_ok(res, {{"success", true}});
|
||||
return res;
|
||||
|
||||
@@ -40,6 +40,11 @@ enum server_model_source {
|
||||
SERVER_MODEL_SOURCE_CACHE,
|
||||
};
|
||||
|
||||
enum server_child_mode {
|
||||
SERVER_CHILD_MODE_NORMAL, // load the model and run normally
|
||||
SERVER_CHILD_MODE_DOWNLOAD, // download the model and exit
|
||||
};
|
||||
|
||||
static std::string server_model_status_to_string(server_model_status status) {
|
||||
switch (status) {
|
||||
case SERVER_MODEL_STATUS_DOWNLOADING: return "downloading";
|
||||
@@ -105,7 +110,6 @@ private:
|
||||
std::shared_ptr<server_subproc> subproc; // shared between main thread and monitoring thread
|
||||
std::thread th;
|
||||
server_model_meta meta;
|
||||
FILE * stdin_file = nullptr;
|
||||
};
|
||||
|
||||
std::mutex mutex;
|
||||
@@ -161,16 +165,19 @@ public:
|
||||
// return a copy of all model metadata (thread-safe)
|
||||
std::vector<server_model_meta> get_all_meta();
|
||||
|
||||
struct load_options {
|
||||
server_child_mode mode = SERVER_CHILD_MODE_NORMAL;
|
||||
// used for spawning a downloading child process
|
||||
std::optional<server_model_meta> custom_meta = std::nullopt;
|
||||
};
|
||||
|
||||
// load and unload model instances
|
||||
// these functions are thread-safe
|
||||
void load(const std::string & name);
|
||||
void load(const std::string & name, const load_options & opts);
|
||||
void unload(const std::string & name);
|
||||
void unload_all();
|
||||
|
||||
// download a new model, progress is reported via SSE
|
||||
// to stop the download, call unload()
|
||||
void download(common_params_model && model, common_download_opts && opts);
|
||||
|
||||
struct update_status_args {
|
||||
server_model_status status;
|
||||
int exit_code = 0; // only valid if status == UNLOADED
|
||||
@@ -213,9 +220,12 @@ public:
|
||||
struct server_child {
|
||||
// serializes the notify_to_router writes
|
||||
std::mutex mtx_stdout;
|
||||
std::atomic<bool> is_finished_downloading = false; // set by run_download
|
||||
|
||||
// return true if the current process is a child server instance
|
||||
bool is_child();
|
||||
server_child_mode get_mode();
|
||||
int run_download(common_params & params);
|
||||
|
||||
// register the shutdown_handler to be called by the router
|
||||
// return the monitoring thread (to be joined by the caller)
|
||||
|
||||
@@ -134,6 +134,7 @@ int llama_server(int argc, char ** argv) {
|
||||
//
|
||||
|
||||
// register API routes
|
||||
server_child child; // only used in non-router mode
|
||||
server_routes routes(params, ctx_server);
|
||||
server_tools tools;
|
||||
|
||||
@@ -254,11 +255,21 @@ int llama_server(int argc, char ** argv) {
|
||||
ctx_http.post("/tools", ex_wrapper(tools.handle_post));
|
||||
}
|
||||
|
||||
//
|
||||
// Handle downloading model
|
||||
//
|
||||
|
||||
if (child.is_child() && child.get_mode() == SERVER_CHILD_MODE_DOWNLOAD) {
|
||||
return child.run_download(params);
|
||||
} else if (!is_router_server) {
|
||||
// single-model mode (NOT spawned by router)
|
||||
common_params_handle_models(params, LLAMA_EXAMPLE_SERVER);
|
||||
}
|
||||
|
||||
//
|
||||
// Start the server
|
||||
//
|
||||
|
||||
server_child child; // only used in non-router mode
|
||||
std::function<void()> clean_up;
|
||||
|
||||
if (is_router_server) {
|
||||
|
||||
@@ -257,14 +257,25 @@ def test_router_reload_models():
|
||||
|
||||
|
||||
MODEL_DOWNLOAD_ID = "ggml-org/test-model-router-download:F16"
|
||||
MODEL_DOWNLOAD_TIMEOUT = 300
|
||||
MODEL_DOWNLOAD_TIMEOUT = 30
|
||||
|
||||
|
||||
def _listen_sse(server: ServerProcess, collected: list, stop: threading.Event):
|
||||
"""Collect /models/sse events into `collected` until `stop` is set."""
|
||||
def _listen_sse(
|
||||
server: ServerProcess, collected: list, stop: threading.Event, ready: threading.Event | None = None
|
||||
):
|
||||
"""Collect /models/sse events into `collected` until `stop` is set.
|
||||
|
||||
When `ready` is provided, it is set once the streaming response is open,
|
||||
i.e. the server has accepted the connection and registered us as a
|
||||
subscriber. Callers that trigger one-shot events (e.g. download_finished)
|
||||
must wait on `ready` before acting, otherwise the event can be broadcast
|
||||
before this client is subscribed and be lost.
|
||||
"""
|
||||
url = f"http://{server.server_host}:{server.server_port}/models/sse"
|
||||
try:
|
||||
with requests.get(url, stream=True, timeout=MODEL_DOWNLOAD_TIMEOUT) as resp:
|
||||
if ready is not None:
|
||||
ready.set()
|
||||
for line_bytes in resp.iter_lines():
|
||||
if stop.is_set():
|
||||
break
|
||||
@@ -294,11 +305,17 @@ def test_router_download_model():
|
||||
|
||||
sse_events: list = []
|
||||
stop = threading.Event()
|
||||
sse_ready = threading.Event()
|
||||
sse_thread = threading.Thread(
|
||||
target=_listen_sse, args=(server, sse_events, stop), daemon=True
|
||||
target=_listen_sse, args=(server, sse_events, stop, sse_ready), daemon=True
|
||||
)
|
||||
sse_thread.start()
|
||||
|
||||
# wait for the SSE client to be subscribed before triggering the download,
|
||||
# otherwise the one-shot download_finished event can be broadcast before
|
||||
# this client is registered and be lost
|
||||
assert sse_ready.wait(10), "SSE client failed to connect"
|
||||
|
||||
# Trigger the download
|
||||
res = server.make_request("POST", "/models", data={"model": MODEL_DOWNLOAD_ID})
|
||||
assert res.status_code == 200
|
||||
@@ -328,13 +345,17 @@ def test_router_delete_model():
|
||||
|
||||
# Ensure the model exists (download it if needed)
|
||||
if MODEL_DOWNLOAD_ID not in _get_model_ids(is_reload=False):
|
||||
res = server.make_request("POST", "/models", data={"model": MODEL_DOWNLOAD_ID})
|
||||
assert res.status_code == 200
|
||||
sse_events: list = []
|
||||
stop = threading.Event()
|
||||
sse_ready = threading.Event()
|
||||
threading.Thread(
|
||||
target=_listen_sse, args=(server, sse_events, stop), daemon=True
|
||||
target=_listen_sse, args=(server, sse_events, stop, sse_ready), daemon=True
|
||||
).start()
|
||||
# subscribe before triggering the download so the one-shot
|
||||
# download_finished event is not lost (see test_router_download_model)
|
||||
assert sse_ready.wait(10), "SSE client failed to connect"
|
||||
res = server.make_request("POST", "/models", data={"model": MODEL_DOWNLOAD_ID})
|
||||
assert res.status_code == 200
|
||||
finished = _wait_for_sse_event(
|
||||
sse_events, "download_finished", MODEL_DOWNLOAD_ID, MODEL_DOWNLOAD_TIMEOUT
|
||||
)
|
||||
|
||||
Reference in New Issue
Block a user