Compare commits

..

4 Commits

Author SHA1 Message Date
Xuan Son Nguyen 095058ca19 add arg --threads-sampling 2026-06-22 20:03:49 +02:00
Xuan Son Nguyen c62fdd5fd0 working 2026-06-22 19:38:25 +02:00
Xuan Son Nguyen 41ed530be2 wip 2026-06-22 19:30:11 +02:00
Xuan Son Nguyen fe03cce8db server: run sampling in a threadpool 2026-06-22 19:05:39 +02:00
47 changed files with 1712 additions and 2464 deletions
+1 -1
View File
@@ -10,7 +10,7 @@
# ggml-org/ggml-rpc : rgerganov
# ggml-org/ggml-sycl : arthw
# ggml-org/ggml-vulkan : 0cc4m, jeffbolznv
# ggml-org/ggml-webgpu : reeselevine, yomaytk
# ggml-org/ggml-webgpu : reeselevine
# ggml-org/ggml-zdnn : taronaeo
# ggml-org/llama-common : ggerganov, aldehir, angt, danbev, ngxson, pwilkin
# ggml-org/llama-mtmd : ngxson
+17 -16
View File
@@ -301,8 +301,6 @@ static handle_model_result common_params_handle_model(struct common_params_model
const common_download_opts & opts) {
handle_model_result result;
// TODO @ngxson : refactor this into a new common_model_download_context
if (!model.docker_repo.empty()) {
model.path = common_docker_resolve_model(model.docker_repo);
} else if (!model.hf_repo.empty()) {
@@ -398,7 +396,7 @@ static bool parse_bool_value(const std::string & value) {
// CLI argument parsing functions
//
bool common_params_handle_models(common_params & params, llama_example curr_ex, const common_params_handle_models_params & handle_params) {
bool common_params_handle_models(common_params & params, llama_example curr_ex, common_download_callback * callback) {
const bool spec_type_draft_mtp = std::find(params.speculative.types.begin(),
params.speculative.types.end(),
COMMON_SPECULATIVE_TYPE_DRAFT_MTP) != params.speculative.types.end();
@@ -409,10 +407,9 @@ bool common_params_handle_models(common_params & params, llama_example curr_ex,
opts.skip_download = params.skip_download;
opts.download_mtp = spec_type_draft_mtp;
opts.download_mmproj = !params.no_mmproj && params.mmproj.path.empty() && params.mmproj.url.empty();
opts.preset_only = handle_params.preset_only;
if (handle_params.callback) {
opts.callback = handle_params.callback;
if (callback) {
opts.callback = callback;
}
// sub-models (draft, mmproj, vocoder) are explicitly specified by the user,
@@ -599,12 +596,13 @@ static bool common_params_parse_ex(int argc, char ** argv, common_params_context
if (!skip_model_download) {
// handle model and download
common_params_handle_models(params, ctx_arg.ex, {});
common_params_handle_models(params, ctx_arg.ex);
// model is required (except for server)
// TODO @ngxson : maybe show a list of available models in CLI in this case
bool can_skip_model = params.usage || params.completion || !params.server_base.empty();
if (!can_skip_model && params.model.path.empty()) {
if (params.model.path.empty()
&& !params.usage
&& !params.completion) {
throw std::invalid_argument("error: --model is required\n");
}
}
@@ -1118,13 +1116,6 @@ common_params_context common_params_parser_init(common_params & params, llama_ex
params.completion = true;
}
));
add_opt(common_arg(
{"--server-base"}, "URL",
string_format("connect to this server instead of starting a new one, example: 'http://localhost:8080' (default: none)"),
[](common_params & params, const std::string & value) {
params.server_base = value;
}
).set_examples({LLAMA_EXAMPLE_CLI}));
add_opt(common_arg(
{"--verbose-prompt"},
string_format("print a verbose prompt before generation (default: %s)", params.verbose_prompt ? "true" : "false"),
@@ -1177,6 +1168,16 @@ common_params_context common_params_parser_init(common_params & params, llama_ex
}
}
));
add_opt(common_arg(
{"--threads-sampling"}, "N",
"number of threads to use during sampling (default: same as --threads)",
[](common_params & params, int value) {
params.sampling_n_threads = value;
if (params.sampling_n_threads <= 0) {
params.sampling_n_threads = std::thread::hardware_concurrency();
}
}
).set_examples({LLAMA_EXAMPLE_SERVER}));
add_opt(common_arg(
{"-C", "--cpu-mask"}, "M",
"CPU affinity mask: arbitrarily long hex. Complements cpu-range (default: \"\")",
+1 -6
View File
@@ -130,11 +130,6 @@ bool common_params_to_map(int argc, char ** argv, llama_example ex, std::map<com
// see: https://github.com/ggml-org/llama.cpp/issues/18163
void common_params_add_preset_options(std::vector<common_arg> & args);
struct common_params_handle_models_params {
common_download_callback * callback = nullptr;
bool preset_only = false; // if true, only check & download remote preset (for router mode)
};
// populate model paths (main model, mmproj, etc) from -hf if necessary
// return true if the model is ready to use
// throw an exception if there is an error that prevents the model from being used (e.g. network error, model not found, etc)
@@ -142,7 +137,7 @@ struct common_params_handle_models_params {
bool common_params_handle_models(
common_params & params,
llama_example curr_ex,
const common_params_handle_models_params & handle_params);
common_download_callback * callback = nullptr);
// initialize argument parser context - used by test-arg-parser and preset
common_params_context common_params_parser_init(common_params & params, llama_example ex, void(*print_usage)(int, char **) = nullptr);
+53 -103
View File
@@ -90,93 +90,41 @@ std::string common_chat_msg::render_content(const std::string & delimiter) const
return text;
}
common_chat_role common_chat_role_from_string(const std::string & role) {
if (role == "system") { return COMMON_CHAT_ROLE_SYSTEM; }
if (role == "assistant") { return COMMON_CHAT_ROLE_ASSISTANT; }
if (role == "user") { return COMMON_CHAT_ROLE_USER; }
if (role == "tool") { return COMMON_CHAT_ROLE_TOOL; }
return COMMON_CHAT_ROLE_UNKNOWN;
}
const char * common_chat_role_to_string(common_chat_role role) {
switch (role) {
case COMMON_CHAT_ROLE_SYSTEM: return "system";
case COMMON_CHAT_ROLE_ASSISTANT: return "assistant";
case COMMON_CHAT_ROLE_USER: return "user";
case COMMON_CHAT_ROLE_TOOL: return "tool";
case COMMON_CHAT_ROLE_UNKNOWN: return "";
}
return "";
}
json common_chat_msg_delimiters::to_json() const {
json result = json::array();
for (const auto & d : delimiters) {
result.push_back({
{ "role", common_chat_role_to_string(d.role) },
{ "delimiter", d.delimiter },
});
}
return result;
}
common_chat_msg_delimiters common_chat_msg_delimiters_parse(const json & delimiters) {
common_chat_msg_delimiters result;
if (!delimiters.is_array()) {
return result;
std::vector<common_chat_msg_span> common_chat_split_by_role(const std::string & prompt, const std::vector<common_chat_msg_delimiter> & delims) {
if (delims.empty() || prompt.empty()) {
return {};
}
result.delimiters.reserve(delimiters.size());
for (const auto & d : delimiters) {
if (!d.is_object()) {
continue;
auto parser = build_peg_parser([&](common_peg_parser_builder & p) {
std::vector<std::string> all_delims;
std::vector<common_peg_parser> tagged_messages;
all_delims.reserve(delims.size());
tagged_messages.reserve(delims.size());
for (const auto & d : delims) {
all_delims.push_back(d.delimiter);
}
result.delimiters.push_back({
common_chat_role_from_string(d.value("role", std::string())),
d.value("delimiter", std::string()),
});
}
return result;
}
void common_chat_msg_delimiters::tokenize(const llama_vocab * vocab) {
for (auto & d : delimiters) {
d.tokens = common_tokenize(vocab, d.delimiter, false, true);
}
}
common_chat_msg_spans common_chat_msg_delimiters::split(const llama_tokens & tokens, const std::map<size_t, size_t> & skips) const {
std::vector<std::pair<common_chat_role, size_t>> matches;
auto skip = skips.begin();
for (size_t i = 0; i < tokens.size();) {
if (skip != skips.end() && i == skip->first) {
i += skip->second;
++skip;
continue;
auto any_delim = p.until_one_of(all_delims);
for (const auto & d : delims) {
tagged_messages.push_back(p.tag(d.role, p.literal(d.delimiter) + any_delim));
}
for (const auto & d : delimiters) {
if (i + d.tokens.size() > tokens.size()) {
continue;
}
if (std::equal(d.tokens.begin(), d.tokens.end(), tokens.begin() + i)) {
matches.emplace_back(d.role, i);
break;
}
return any_delim + p.zero_or_more(p.choice(tagged_messages)) + p.end();
});
common_peg_parse_context ctx(prompt);
const auto result = parser.parse(ctx);
if (!result.success()) {
return {};
}
std::vector<common_chat_msg_span> spans;
ctx.ast.visit(result, [&](const common_peg_ast_node & node) {
if (!node.tag.empty()) {
spans.push_back({ node.tag, node.start, node.end - node.start });
}
i++;
}
matches.emplace_back(COMMON_CHAT_ROLE_UNKNOWN, tokens.size());
common_chat_msg_spans spans;
for (size_t i = 0; i + 1 < matches.size(); i++) {
const auto & curr = matches[i];
const auto & next = matches[i + 1];
spans.add(curr.first, curr.second, next.second - curr.second);
}
});
return spans;
}
@@ -1133,13 +1081,13 @@ static common_chat_params common_chat_params_init_gpt_oss(const common_chat_temp
data.prompt = prompt;
data.generation_prompt = common_chat_template_generation_prompt_impl(tmpl, inputs, /* messages_override= */ adjusted_messages);
data.message_delimiters = {
{ COMMON_CHAT_ROLE_ASSISTANT, "<|start|>assistant" },
{ COMMON_CHAT_ROLE_USER, "<|start|>user" },
{ COMMON_CHAT_ROLE_SYSTEM, "<|start|>developer" },
{ COMMON_CHAT_ROLE_SYSTEM, "<|start|>system" },
{ COMMON_CHAT_ROLE_TOOL, "<|start|>functions" },
};
data.message_spans = common_chat_split_by_role(prompt, {
{ "assistant", "<|start|>assistant" },
{ "user", "<|start|>user" },
{ "system", "<|start|>developer" },
{ "system", "<|start|>system" },
{ "tool", "<|start|>functions" },
});
data.format = COMMON_CHAT_FORMAT_PEG_NATIVE;
data.supports_thinking = true;
@@ -1280,10 +1228,10 @@ static common_chat_params common_chat_params_init_gemma4(const common_chat_templ
data.prompt += data.generation_prompt;
}
data.message_delimiters = {
{ COMMON_CHAT_ROLE_USER, "<|turn>user" },
{ COMMON_CHAT_ROLE_ASSISTANT, "<|turn>model" },
};
data.message_spans = common_chat_split_by_role(data.prompt, {
{ "user", "<|turn>user\n" },
{ "assistant", "<|turn>model\n" },
});
data.format = COMMON_CHAT_FORMAT_PEG_GEMMA4;
data.supports_thinking = true;
@@ -2082,15 +2030,15 @@ static common_chat_params common_chat_params_init_cohere2moe(const common_chat_t
RESULT_START, RESULT_END,
};
// Declare per-role message delimiters. Tool results are rendered with the
// Split the rendered prompt into per-role message spans. Tool results are rendered with the
// system token followed by <|START_TOOL_RESULT|>, so the "tool" delimiter must be listed before
// the plain "system" one (it is a strict superset, and the role split tries delimiters in order).
data.message_delimiters = {
{ COMMON_CHAT_ROLE_ASSISTANT, GEN_PREFIX },
{ COMMON_CHAT_ROLE_USER, TURN_START + USER },
{ COMMON_CHAT_ROLE_TOOL, TURN_START + SYSTEM + RESULT_START },
{ COMMON_CHAT_ROLE_SYSTEM, TURN_START + SYSTEM },
};
data.message_spans = common_chat_split_by_role(data.prompt, {
{ "assistant", GEN_PREFIX },
{ "user", TURN_START + USER },
{ "tool", TURN_START + SYSTEM + RESULT_START },
{ "system", TURN_START + SYSTEM },
});
auto has_tools = inputs.tools.is_array() && !inputs.tools.empty();
auto extract_reasoning = inputs.reasoning_format != COMMON_REASONING_FORMAT_NONE;
@@ -2578,15 +2526,17 @@ static common_chat_params common_chat_templates_apply_jinja(const struct common_
autoparser.analyze_template(tmpl);
auto auto_params = autoparser::peg_generator::generate_parser(tmpl, params, autoparser);
common_chat_msg_delimiters delimiters;
std::vector<common_chat_msg_delimiter> delimiters;
if (!autoparser.assistant_start.empty()) {
delimiters.add(COMMON_CHAT_ROLE_ASSISTANT, autoparser.assistant_start);
delimiters.push_back({ "assistant", autoparser.assistant_start });
}
if (!autoparser.user_start.empty()) {
delimiters.add(COMMON_CHAT_ROLE_USER, autoparser.user_start);
delimiters.push_back({ "user", autoparser.user_start });
}
auto_params.message_delimiters = std::move(delimiters);
if (!delimiters.empty()) {
auto_params.message_spans = common_chat_split_by_role(auto_params.prompt, delimiters);
}
auto_params.supports_thinking = autoparser.reasoning.mode != autoparser::reasoning_mode::NONE;
if (auto_params.supports_thinking) {
+6 -65
View File
@@ -143,75 +143,15 @@ struct common_chat_msg_diff {
}
};
enum common_chat_role {
COMMON_CHAT_ROLE_UNKNOWN,
COMMON_CHAT_ROLE_SYSTEM,
COMMON_CHAT_ROLE_ASSISTANT,
COMMON_CHAT_ROLE_USER,
COMMON_CHAT_ROLE_TOOL
};
common_chat_role common_chat_role_from_string(const std::string & role);
const char * common_chat_role_to_string(common_chat_role role);
struct common_chat_msg_span {
common_chat_role role = COMMON_CHAT_ROLE_UNKNOWN;
std::string role;
std::size_t pos = 0;
std::size_t len = 0;
bool valid() const {
return role != COMMON_CHAT_ROLE_UNKNOWN;
}
};
struct common_chat_msg_spans {
std::vector<common_chat_msg_span> spans;
void add(common_chat_role role, size_t pos, size_t len) {
spans.push_back({ role, pos, len });
}
bool is_user_start(int32_t pos) const {
for (auto it = spans.begin(); it != spans.end(); ++it) {
if (it->role == COMMON_CHAT_ROLE_USER && pos == (int32_t) it->pos) {
return true;
}
}
return false;
}
int32_t last_user_message_pos() const {
for (auto it = spans.rbegin(); it != spans.rend(); ++it) {
if (it->role == COMMON_CHAT_ROLE_USER) {
return (int32_t) it->pos;
}
}
return -1;
}
};
struct common_chat_msg_delimiter {
common_chat_role role = COMMON_CHAT_ROLE_UNKNOWN;
std::string delimiter;
llama_tokens tokens = {};
};
struct common_chat_msg_delimiters {
std::vector<common_chat_msg_delimiter> delimiters;
common_chat_msg_delimiters() = default;
common_chat_msg_delimiters(std::initializer_list<common_chat_msg_delimiter> delims) : delimiters(delims) {}
void add(common_chat_role role, const std::string & delimiter) {
delimiters.push_back({ role, delimiter });
}
void tokenize(const llama_vocab * vocab);
// split tokens into message spans. skips maps a start index to a length of a region to jump over without matching
common_chat_msg_spans split(const llama_tokens & tokens, const std::map<size_t, size_t> & skips = {}) const;
nlohmann::ordered_json to_json() const;
std::string role;
std::string delimiter;
};
struct common_chat_tool {
@@ -279,7 +219,7 @@ struct common_chat_params {
std::vector<std::string> preserved_tokens;
std::vector<std::string> additional_stops;
std::string parser;
common_chat_msg_delimiters message_delimiters;
std::vector<common_chat_msg_span> message_spans;
};
// per-message parsing syntax
@@ -385,4 +325,5 @@ struct common_chat_prompt_preset {
common_chat_prompt_preset common_chat_get_asr_prompt(const common_chat_templates * chat_templates);
common_chat_msg_delimiters common_chat_msg_delimiters_parse(const nlohmann::ordered_json & delimiters);
std::vector<common_chat_msg_span> common_chat_split_by_role(const std::string & prompt, const std::vector<common_chat_msg_delimiter> & delims);
+3 -4
View File
@@ -471,6 +471,8 @@ struct common_params {
common_cpu_params cpuparams;
common_cpu_params cpuparams_batch;
int sampling_n_threads = -1; // number of threads for sampling, used by server
ggml_backend_sched_eval_callback cb_eval = nullptr;
void * cb_eval_user_data = nullptr;
@@ -609,7 +611,7 @@ struct common_params {
bool cache_prompt = true; // whether to enable prompt caching
bool cache_idle_slots = true; // save and clear idle slots upon starting a new task
int32_t n_ctx_checkpoints = 32; // max number of context checkpoints per slot
int32_t checkpoint_min_step = 8192; // minimum spacing between context checkpoints
int32_t checkpoint_min_step = 256; // minimum spacing between context checkpoints
int32_t cache_ram_mib = 8192; // -1 = no limit, 0 - disable, 1 = 1 MiB, etc.
std::string hostname = "127.0.0.1";
@@ -631,9 +633,6 @@ struct common_params {
std::map<std::string, std::string> default_template_kwargs;
// CLI params
std::string server_base; // if set, connect to this server instead of starting a new one
// UI configs
bool ui = true;
bool ui_mcp_proxy = false;
+1 -3
View File
@@ -799,7 +799,6 @@ common_download_model_result common_download_model(const common_params_model &
bool download_mmproj = opts.download_mmproj;
bool download_mtp = opts.download_mtp;
bool preset_only = opts.preset_only;
bool is_hf = !model.hf_repo.empty();
if (is_hf) {
@@ -807,8 +806,7 @@ common_download_model_result common_download_model(const common_params_model &
if (!hf.preset.path.empty()) {
// if preset.ini exists, only download that file alone
tasks.push_back({hf.preset.url, hf.preset.local_path});
} else if (!preset_only) {
// only add other files if we're NOT in preset-only mode (normal run, non-router)
} else {
for (const auto & f : hf.model_files) {
tasks.push_back({f.url, f.local_path});
}
-1
View File
@@ -55,7 +55,6 @@ struct common_download_opts {
bool skip_download = false; // if true, only validation is performed, common_skip_download_exception may be thrown if the file is missing or invalid
bool download_mmproj = false;
bool download_mtp = false;
bool preset_only = false; // if true, only check & download remote preset (for router mode)
common_download_callback * callback = nullptr;
};
-70
View File
@@ -2,16 +2,6 @@
#include <cpp-httplib/httplib.h>
#ifdef _WIN32
#include <winsock2.h>
#include <windows.h>
#else
#include <sys/socket.h>
#include <netinet/in.h>
#include <arpa/inet.h>
#include <unistd.h>
#endif
struct common_http_url {
std::string scheme;
std::string user;
@@ -107,63 +97,3 @@ static std::pair<httplib::Client, common_http_url> common_http_client(const std:
static std::string common_http_show_masked_url(const common_http_url & parts) {
return parts.scheme + "://" + (parts.user.empty() ? "" : "****:****@") + parts.host + parts.path;
}
static int common_http_get_free_port() {
#ifdef _WIN32
WSADATA wsaData;
if (WSAStartup(MAKEWORD(2, 2), &wsaData) != 0) {
return -1;
}
typedef SOCKET native_socket_t;
#define INVALID_SOCKET_VAL INVALID_SOCKET
#define CLOSE_SOCKET(s) closesocket(s)
#else
typedef int native_socket_t;
#define INVALID_SOCKET_VAL -1
#define CLOSE_SOCKET(s) close(s)
#endif
native_socket_t sock = socket(AF_INET, SOCK_STREAM, 0);
if (sock == INVALID_SOCKET_VAL) {
#ifdef _WIN32
WSACleanup();
#endif
return -1;
}
struct sockaddr_in serv_addr;
std::memset(&serv_addr, 0, sizeof(serv_addr));
serv_addr.sin_family = AF_INET;
serv_addr.sin_addr.s_addr = htonl(INADDR_ANY);
serv_addr.sin_port = htons(0);
if (bind(sock, (struct sockaddr*)&serv_addr, sizeof(serv_addr)) != 0) {
CLOSE_SOCKET(sock);
#ifdef _WIN32
WSACleanup();
#endif
return -1;
}
#ifdef _WIN32
int namelen = sizeof(serv_addr);
#else
socklen_t namelen = sizeof(serv_addr);
#endif
if (getsockname(sock, (struct sockaddr*)&serv_addr, &namelen) != 0) {
CLOSE_SOCKET(sock);
#ifdef _WIN32
WSACleanup();
#endif
return -1;
}
int port = ntohs(serv_addr.sin_port);
CLOSE_SOCKET(sock);
#ifdef _WIN32
WSACleanup();
#endif
return port;
}
-2
View File
@@ -96,7 +96,6 @@ TEXT_MODEL_MAP: dict[str, str] = {
"GraniteMoeHybridForCausalLM": "granite",
"GraniteMoeSharedForCausalLM": "granite",
"GraniteSpeechForConditionalGeneration": "granite",
"GraniteSpeechPlusForConditionalGeneration": "granite",
"Grok1ForCausalLM": "grok",
"GrokForCausalLM": "grok",
"GroveMoeForCausalLM": "grovemoe",
@@ -262,7 +261,6 @@ MMPROJ_MODEL_MAP: dict[str, str] = {
"GlmasrModel": "ultravox",
"Granite4VisionForConditionalGeneration": "granite",
"GraniteSpeechForConditionalGeneration": "granite",
"GraniteSpeechPlusForConditionalGeneration": "granite",
"HunYuanVLForConditionalGeneration": "hunyuan",
"Idefics3ForConditionalGeneration": "smolvlm",
"InternVisionModel": "internvl",
-28
View File
@@ -348,34 +348,6 @@ class GraniteSpeechMmprojModel(MmprojModel):
yield from super().modify_tensors(data_torch, name, bid)
@ModelBase.register("GraniteSpeechPlusForConditionalGeneration")
class GraniteSpeechPlusMmprojModel(GraniteSpeechMmprojModel):
"""Conversion for GraniteSpeechPlus - extends GraniteSpeech with feature layer concatenation"""
has_vision_encoder = False
has_audio_encoder = True
def set_gguf_parameters(self):
assert self.hparams_audio is not None
super().set_gguf_parameters()
# Add feature_layer if present in encoder config
if feature_layers := self.hparams_audio.get("cat_hidden_layers"):
self.gguf_writer.add_audio_feature_layers(feature_layers)
logger.info(f"gguf: audio feature_layers = {feature_layers}")
# Validate projector dimension matches concatenated encoder output
hidden_dim = self.hparams_audio["hidden_dim"]
expected_dim = hidden_dim * (len(feature_layers) + 1)
projector_dim = self.global_config["projector_config"]["encoder_hidden_size"]
if projector_dim != expected_dim:
raise ValueError(
f"Projector encoder_hidden_size ({projector_dim}) does not match "
f"expected concatenated dimension ({expected_dim}). "
f"Expected: hidden_dim ({hidden_dim}) * (len(feature_layers) + 1) = {expected_dim}"
)
@ModelBase.register("Granite4VisionForConditionalGeneration")
class Granite4VisionMmprojModel(MmprojModel):
has_vision_encoder = True
@@ -174,7 +174,7 @@ __kernel void kernel_gemv_noshuffle_q8_0_f32(
regA.s6 = read_imageui(src0_q, (gid + k * BLOCK_STRIDE_A + LINE_STRIDE_A * 6)).x;
regA.s7 = read_imageui(src0_q, (gid + k * BLOCK_STRIDE_A + LINE_STRIDE_A * 7)).x;
dequantizeBlockAccum_ns_sgbroadcast_1(totalSum, regA, convert_float(regS), regB);
dequantizeBlockAccum_ns_sgbroadcast_1(totalSum, regA, regS, regB);
}
// reduction in local memory, assumes #wave=4
-5
View File
@@ -108,9 +108,6 @@ if (Vulkan_FOUND)
if (GGML_VULKAN_CHECK_RESULTS)
add_compile_definitions(GGML_VULKAN_CHECK_RESULTS)
# the result-checking path computes a CPU reference graph via
# ggml_graph_compute_with_ctx(), which is defined in ggml-cpu
target_link_libraries(ggml-vulkan PRIVATE ggml-cpu)
endif()
if (GGML_VULKAN_DEBUG)
@@ -132,8 +129,6 @@ if (Vulkan_FOUND)
if (GGML_VULKAN_RUN_TESTS)
add_compile_definitions(GGML_VULKAN_RUN_TESTS)
# the test path also calls ggml_graph_compute_with_ctx() (ggml-cpu)
target_link_libraries(ggml-vulkan PRIVATE ggml-cpu)
endif()
# Set up toolchain for host compilation whether cross-compiling or not
@@ -905,12 +905,11 @@ struct ggml_webgpu_mul_mat_vec_pipeline_key {
ggml_type src0_type;
ggml_type src1_type;
int vectorized;
uint32_t num_cols;
bool use_mmvq;
bool operator==(const ggml_webgpu_mul_mat_vec_pipeline_key & other) const {
return src0_type == other.src0_type && src1_type == other.src1_type && vectorized == other.vectorized &&
num_cols == other.num_cols && use_mmvq == other.use_mmvq;
use_mmvq == other.use_mmvq;
}
};
@@ -920,7 +919,6 @@ struct ggml_webgpu_mul_mat_vec_pipeline_key_hash {
ggml_webgpu_hash_combine(seed, key.src0_type);
ggml_webgpu_hash_combine(seed, key.src1_type);
ggml_webgpu_hash_combine(seed, key.vectorized);
ggml_webgpu_hash_combine(seed, key.num_cols);
ggml_webgpu_hash_combine(seed, key.use_mmvq);
return seed;
}
@@ -995,12 +993,11 @@ struct ggml_webgpu_mul_mat_id_pipeline_key {
ggml_type src0_type;
ggml_type src1_type;
uint32_t n_experts;
uint32_t num_cols;
int vectorized;
bool operator==(const ggml_webgpu_mul_mat_id_pipeline_key & other) const {
return src0_type == other.src0_type && src1_type == other.src1_type && n_experts == other.n_experts &&
num_cols == other.num_cols && vectorized == other.vectorized;
vectorized == other.vectorized;
}
};
@@ -1010,7 +1007,6 @@ struct ggml_webgpu_mul_mat_id_pipeline_key_hash {
ggml_webgpu_hash_combine(seed, key.src0_type);
ggml_webgpu_hash_combine(seed, key.src1_type);
ggml_webgpu_hash_combine(seed, key.n_experts);
ggml_webgpu_hash_combine(seed, key.num_cols);
ggml_webgpu_hash_combine(seed, key.vectorized);
return seed;
}
@@ -1111,7 +1107,7 @@ inline bool ggml_webgpu_can_use_mmvq(const ggml_tensor * src0,
const ggml_tensor * src1,
bool supports_dot_product,
const std::string & vendor) {
if (src1->ne[1] <= 4) {
if (src1->ne[1] == 1) {
bool supports_dp4a = vendor == "amd" || vendor == "intel" || vendor == "nvidia";
if (supports_dp4a && supports_dot_product) {
switch (src1->type) {
@@ -1893,7 +1889,6 @@ class ggml_webgpu_shader_lib {
(context.src0->type == GGML_TYPE_F32 || context.src0->type == GGML_TYPE_F16)) ?
1 :
0;
key.num_cols = context.dst->ne[1];
key.use_mmvq =
ggml_webgpu_can_use_mmvq(context.src0, context.src1, context.supports_dot_product, context.vendor);
@@ -2009,7 +2004,6 @@ class ggml_webgpu_shader_lib {
if (key.vectorized) {
variant += "_vectorized";
}
defines.push_back(std::string("NUM_COLS=") + std::to_string(key.num_cols));
auto processed = preprocessor.preprocess(shader_src, defines);
auto decisions = std::make_shared<ggml_webgpu_mul_mat_vec_shader_decisions>();
@@ -2427,7 +2421,6 @@ class ggml_webgpu_shader_lib {
if (key.vectorized) {
variant += "_vectorized";
}
defines.push_back(std::string("NUM_COLS=1"));
defines.push_back(std::string("N_EXPERTS=") + std::to_string(key.n_experts));
+9 -11
View File
@@ -1418,17 +1418,15 @@ static void ggml_webgpu_quantize_q8_dispatch(webgpu_context &
const size_t dst_offset = ggml_webgpu_tensor_offset(dst);
const size_t q8_src1_align_offset = ROUNDUP_POW2(
dst_offset + ggml_nbytes(dst), ctx->global_ctx->capabilities.limits.minStorageBufferOffsetAlignment);
const size_t q8_src1_binding_size = ROUNDUP_POW2(
src1->ne[3] * src1->ne[2] * src1->ne[1] * (36 /* sizeof(q8_1) */ * (src1->ne[0] / /* block_size */ 32)),
WEBGPU_STORAGE_BUF_BINDING_MULT);
const size_t q8_src1_binding_size =
ROUNDUP_POW2(src1->ne[3] * src1->ne[2] * (36 /* sizeof(q8_1) */ * (src1->ne[0] / /* block_size */ 32)),
WEBGPU_STORAGE_BUF_BINDING_MULT);
std::vector<uint32_t> q8_params = {
(uint32_t) (ggml_webgpu_tensor_misalignment(ctx, src1) / ggml_type_size(src1->type)),
(uint32_t) (src1->nb[1] / ggml_type_size(src1->type)),
(uint32_t) (src1->nb[2] / ggml_type_size(src1->type)),
(uint32_t) (src1->nb[3] / ggml_type_size(src1->type)),
(uint32_t) src1->ne[0],
(uint32_t) src1->ne[1],
(uint32_t) src1->ne[2],
(uint32_t) src1->ne[3],
};
@@ -1444,7 +1442,7 @@ static void ggml_webgpu_quantize_q8_dispatch(webgpu_context &
uint32_t q8_wg_x = 1;
uint32_t q8_wg_y = 1;
const uint32_t wg_per_vec = (src0->ne[0] / 4 + (q8_wg_size - 1)) / q8_wg_size;
const uint32_t q8_total_wg = src1->ne[1] * src1->ne[2] * src1->ne[3] * wg_per_vec;
const uint32_t q8_total_wg = src1->ne[2] * src1->ne[3] * wg_per_vec;
const uint32_t max_wg_per_dim = ctx->global_ctx->capabilities.limits.maxComputeWorkgroupsPerDimension;
compute_2d_workgroups(q8_total_wg, max_wg_per_dim, q8_wg_x, q8_wg_y);
@@ -1458,7 +1456,7 @@ static webgpu_encoded_op ggml_webgpu_mul_mat(webgpu_context & ctx,
ggml_tensor * src1,
ggml_tensor * dst) {
// Determine if this is a mat-vec operation
bool use_mat_vec = (dst->ne[1] <= 4);
bool is_vec = (dst->ne[1] == 1);
// use MMVQ path for mat-vec
bool use_mmvq = ggml_webgpu_can_use_mmvq(src0, src1, ctx->global_ctx->capabilities.supports_dot_product,
@@ -1484,7 +1482,7 @@ static webgpu_encoded_op ggml_webgpu_mul_mat(webgpu_context & ctx,
webgpu_pipeline pipeline;
std::vector<webgpu_dispatch_desc> dispatches;
if (use_mat_vec) {
if (is_vec) {
if (use_mmvq) {
ggml_webgpu_quantize_q8_dispatch(ctx, src0, src1, dst, dispatches);
}
@@ -1531,7 +1529,7 @@ static webgpu_encoded_op ggml_webgpu_mul_mat(webgpu_context & ctx,
uint32_t wg_y = 1;
const uint32_t max_wg_per_dim = ctx->global_ctx->capabilities.limits.maxComputeWorkgroupsPerDimension;
if (use_mat_vec) {
if (is_vec) {
auto * decisions = static_cast<ggml_webgpu_mul_mat_vec_shader_decisions *>(pipeline.context.get());
uint32_t batches = dst->ne[2] * dst->ne[3];
@@ -3693,8 +3691,8 @@ static size_t ggml_backend_webgpu_buffer_type_get_alloc_size(ggml_backend_buffer
ggml_webgpu_can_use_mmvq(src0, src1, ctx->webgpu_global_ctx->capabilities.supports_dot_product,
ctx->webgpu_global_ctx->vendor);
if (use_mmvq) {
const size_t q8_src1_size = src1->ne[3] * src1->ne[2] * src1->ne[1] *
(36 /* sizeof(q8_1) */ * (src1->ne[0] / /* block_size */ 32));
const size_t q8_src1_size =
src1->ne[3] * src1->ne[2] * (36 /* sizeof(q8_1) */ * (src1->ne[0] / /* block_size */ 32));
res = ROUNDUP_POW2(res + q8_src1_size +
ctx->webgpu_global_ctx->capabilities.limits.minStorageBufferOffsetAlignment,
WEBGPU_STORAGE_BUF_BINDING_MULT);
@@ -103,7 +103,7 @@ fn main(
#ifdef USE_SUBGROUP_REDUCTION
for (var row = 0u; row < OUTPUTS_PER_WG; row++) {
let subgroup_total = subgroupAdd(acc[0][row]);
let subgroup_total = subgroupAdd(acc[row]);
if (subgroup_invocation_id == 0u) {
partial_sums[partial_index(row, subgroup_id)] = subgroup_total;
}
@@ -126,7 +126,7 @@ fn main(
#ifdef USE_WORKGROUP_REDUCTION
for (var row = 0u; row < OUTPUTS_PER_WG; row++) {
partial_sums[partial_index(row, thread_id)] = acc[0][row];
partial_sums[partial_index(row, thread_id)] = acc[row];
}
workgroupBarrier();
@@ -91,67 +91,61 @@ fn main(
let dst_idx_base = params.offset_dst + dst3_idx * dst3_stride + dst2_idx * dst2_stride + row_base;
#ifdef MMVQ
let src1q_idx_base = (src13_idx * params.bs02 * params.broadcast2 + src12_idx) * params.n * (params.k / 32u);
let src1q_idx_base = (src13_idx * params.bs02 * params.broadcast2 + src12_idx) * (params.k / 32u);
let acc = accumulate_vec_q_dot(thread_id, row_base, src0_batch_offset, src1q_idx_base);
#else
let src1_idx_base = params.offset_src1 + src13_idx * params.stride_13 + src12_idx * params.stride_12;
let acc = accumulate_vec_dot(thread_id, row_base, src0_batch_offset, src1_idx_base);
#endif
for (var col = 0u;col < NUM_COLS;col += 1) {
#ifdef USE_SUBGROUP_REDUCTION
for (var row = 0u; row < OUTPUTS_PER_WG; row++) {
let subgroup_total = subgroupAdd(acc[col][row]);
if (subgroup_invocation_id == 0u) {
partial_sums[partial_index(row, subgroup_id)] = subgroup_total;
}
}
workgroupBarrier();
for (var row = subgroup_id; (row < OUTPUTS_PER_WG) && (row_base + row < params.m); row += num_subgroups) {
let output_row = row_base + row;
var row_acc = 0.0f;
for (var k = subgroup_invocation_id; k < num_subgroups; k += subgroup_size) {
row_acc += partial_sums[partial_index(row, k)];
}
let row_total = subgroupAdd(row_acc);
if (subgroup_invocation_id == 0) {
dst[dst_idx_base + col * params.m + row] = row_total;
}
}
#endif
#ifdef USE_WORKGROUP_REDUCTION
for (var row = 0u; row < OUTPUTS_PER_WG; row++) {
partial_sums[partial_index(row, thread_id)] = acc[col][row];
}
workgroupBarrier();
var stride = WG_SIZE / 2u;
while (stride > 0) {
if (thread_id < stride) {
for (var row = 0u; row < OUTPUTS_PER_WG; row++) {
partial_sums[partial_index(row, thread_id)] += partial_sums[partial_index(row, thread_id + stride)];
}
}
workgroupBarrier();
stride = stride / 2;
}
if (thread_id < OUTPUTS_PER_WG) {
let output_row = row_base + thread_id;
if (output_row < params.m) {
dst[dst_idx_base + col * params.m + thread_id] = partial_sums[partial_index(thread_id, 0)];
}
}
#endif
for (var row = 0u; row < OUTPUTS_PER_WG; row++) {
let subgroup_total = subgroupAdd(acc[row]);
if (subgroup_invocation_id == 0u) {
partial_sums[partial_index(row, subgroup_id)] = subgroup_total;
}
}
workgroupBarrier();
for (var row = subgroup_id; (row < OUTPUTS_PER_WG) && (row_base + row < params.m); row += num_subgroups) {
let output_row = row_base + row;
var row_acc = 0.0f;
for (var k = subgroup_invocation_id; k < num_subgroups; k += subgroup_size) {
row_acc += partial_sums[partial_index(row, k)];
}
let row_total = subgroupAdd(row_acc);
if (subgroup_invocation_id == 0) {
dst[dst_idx_base + row] = row_total;
}
}
#endif
#ifdef USE_WORKGROUP_REDUCTION
for (var row = 0u; row < OUTPUTS_PER_WG; row++) {
partial_sums[partial_index(row, thread_id)] = acc[row];
}
workgroupBarrier();
var stride = WG_SIZE / 2u;
while (stride > 0) {
if (thread_id < stride) {
for (var row = 0u; row < OUTPUTS_PER_WG; row++) {
partial_sums[partial_index(row, thread_id)] += partial_sums[partial_index(row, thread_id + stride)];
}
}
workgroupBarrier();
stride = stride / 2;
}
if (thread_id < OUTPUTS_PER_WG) {
let output_row = row_base + thread_id;
if (output_row < params.m) {
dst[dst_idx_base + thread_id] = partial_sums[partial_index(thread_id, 0)];
}
}
#endif
}
File diff suppressed because it is too large Load Diff
@@ -51,7 +51,10 @@ fn repack_b_dm(block: u32) -> B_DS_TYPE {
fn get_dm(block_byte_base: u32) -> f32 {
return f32(load_f16_at_src0(block_byte_base));
}
#endif // MUL_ACC_Q4_0
fn mul_q8_1(row_sum: i32, da: f32, b_ds: B_DS_TYPE) -> f32 {
return f32(row_sum) * (da * b_ds.x) - 8.0 * da * b_ds.y / THREADS_PER_BLOCK;
}
#endif
#ifdef MUL_ACC_Q4_1
#define BLOCK_SIZE_BYTES 20
@@ -82,7 +85,10 @@ fn get_dm(block_byte_base: u32) -> vec2<f32> {
f32(load_f16_at_src0(block_byte_base + 2u))
);
}
#endif // MUL_ACC_Q4_1
fn mul_q8_1(row_sum: i32, dma: vec2<f32>, b_ds: B_DS_TYPE) -> f32 {
return f32(row_sum) * (dma.x * b_ds.x) + dma.y * b_ds.y / THREADS_PER_BLOCK;
}
#endif
#ifdef MUL_ACC_Q8_0
#define BLOCK_SIZE_BYTES 34
@@ -105,48 +111,46 @@ fn repack_b_dm(block: u32) -> B_DS_TYPE {
fn get_dm(block_byte_base: u32) -> f32 {
return f32(load_f16_at_src0(block_byte_base));
}
#endif // MUL_ACC_Q8_0
fn mul_q8_1(row_sum: i32, da: f32, b_ds: B_DS_TYPE) -> f32 {
return f32(row_sum) * (da * b_ds);
}
#endif
#if defined(LEGACY_QUANTS)
fn accumulate_vec_q_dot(thread_id: u32, row_base: u32, src0_batch_offset: u32, src1q_idx_base: u32) -> array<array<f32, OUTPUTS_PER_WG>, NUM_COLS> {
var acc: array<array<f32, OUTPUTS_PER_WG>, NUM_COLS>;
#ifdef LEGACY_QUANTS
fn mmvq_dot_product(a_byte_base: u32, b_inner_id: u32, b_repacked: vec2<u32>, b_ds: B_DS_TYPE) -> f32 {
var row_sum = 0;
let a_repacked = repack_a(a_byte_base, b_inner_id);
row_sum += dot4I8Packed(a_repacked[0], b_repacked[0]);
row_sum += dot4I8Packed(a_repacked[1], b_repacked[1]);
return mul_q8_1(row_sum, get_dm(a_byte_base), b_ds);
}
fn accumulate_vec_q_dot(thread_id: u32, row_base: u32, src0_batch_offset: u32, src1q_idx_base: u32) -> array<f32, OUTPUTS_PER_WG> {
var acc: array<f32, OUTPUTS_PER_WG>;
let num_blocks = params.k / BLOCK_SIZE;
for (var block = thread_id / THREADS_PER_BLOCK; block < num_blocks; block += WG_SIZE / THREADS_PER_BLOCK) {
let inner_id = thread_id % THREADS_PER_BLOCK;
let b_inner_id = thread_id % THREADS_PER_BLOCK;
let b_block_idx = src1q_idx_base + block;
let b_repacked = repack_b_qs(b_block_idx, b_inner_id);
let b_ds = repack_b_dm(b_block_idx);
for (var row = 0u; row < OUTPUTS_PER_WG; row++) {
let output_row = row_base + row;
if (output_row < params.m) {
let block_byte_base = (src0_batch_offset + output_row * params.stride_01 + block) * BLOCK_SIZE_BYTES;
let a_repacked = repack_a(block_byte_base, inner_id);
let da = get_dm(block_byte_base);
for (var col = 0u;col < NUM_COLS;col += 1) {
let src1q_idx = src1q_idx_base + col * (params.k / Q8_BLOCK_SIZE) + block;
let b_repacked = repack_b_qs(src1q_idx, inner_id);
let b_ds = repack_b_dm(src1q_idx);
let row_sum = dot4I8Packed(a_repacked[0], b_repacked[0]) + dot4I8Packed(a_repacked[1], b_repacked[1]);
#if defined(MUL_ACC_Q4_0)
acc[col][row] += f32(row_sum) * (da * b_ds.x) - 8.0 * da * b_ds.y / THREADS_PER_BLOCK;
#endif // MUL_ACC_Q4_0
#if defined(MUL_ACC_Q4_1)
acc[col][row] += f32(row_sum) * (da.x * b_ds.x) + da.y * b_ds.y / THREADS_PER_BLOCK;
#endif // MUL_ACC_Q4_1
#if defined(MUL_ACC_Q8_0)
acc[col][row] += f32(row_sum) * (da * b_ds);
#endif // MUL_ACC_Q8_0
}
acc[row] += mmvq_dot_product(block_byte_base, b_inner_id, b_repacked, b_ds);
}
}
}
return acc;
}
#endif // LEGACY_QUANTS
#endif
#ifdef MUL_ACC_Q2_K
#define BLOCK_SIZE_BYTES 84
@@ -187,7 +191,22 @@ fn get_scale_min(block_byte_base: u32, tid: u32) -> vec2<f32> {
let scale = byte_of(load_u32_at_src0_aligned(scale_byte), scale_byte & 3u);
return vec2<f32>(f32(scale & 0xFu), f32(scale >> 4u));
}
#endif // MUL_ACC_Q2_K
fn mmvq_dot_product(a_byte_base: u32, tid: u32, b_repacked: vec4<u32>, b_ds: B_DS_TYPE) -> f32 {
let a_repacked = repack_a(a_byte_base, tid);
let dm = get_dm(a_byte_base);
let scale_min = get_scale_min(a_byte_base, tid);
let scale_q = i32(scale_min.x);
let scale_m_i8x4 = u32(scale_min.y) * 0x01010101u;
let row_sum_d = (dot4I8Packed(b_repacked[0], a_repacked[0]) + dot4I8Packed(b_repacked[1], a_repacked[1])
+ dot4I8Packed(b_repacked[2], a_repacked[2]) + dot4I8Packed(b_repacked[3], a_repacked[3])) * scale_q;
let row_sum_m = dot4I8Packed(b_repacked[0], scale_m_i8x4) + dot4I8Packed(b_repacked[1], scale_m_i8x4)
+ dot4I8Packed(b_repacked[2], scale_m_i8x4) + dot4I8Packed(b_repacked[3], scale_m_i8x4);
return b_ds * (dm.x * f32(row_sum_d) - dm.y * f32(row_sum_m));
}
#endif
#ifdef MUL_ACC_Q4_K
#define BLOCK_SIZE_BYTES 144
@@ -246,52 +265,39 @@ fn get_scale_min(block_byte_base: u32, tid: u32) -> vec2<f32> {
return vec2<f32>(scale, min_val);
}
#endif // MUL_ACC_Q4_K
fn mmvq_dot_product(a_byte_base: u32, tid: u32, b_repacked: vec4<u32>, b_ds: B_DS_TYPE) -> f32 {
let a_repacked = repack_a(a_byte_base, tid);
let dm = get_dm(a_byte_base);
let scale_min = get_scale_min(a_byte_base, tid);
let row_sum = dot4I8Packed(a_repacked[0], b_repacked[0]) + dot4I8Packed(a_repacked[1], b_repacked[1])
+ dot4I8Packed(a_repacked[2], b_repacked[2]) + dot4I8Packed(a_repacked[3], b_repacked[3]);
// Each thread covers half of the Q8_1 block, so add only b_ds.y/2.
return b_ds.x * dm.x * scale_min.x * f32(row_sum) - dm.y * scale_min.y * (b_ds.y / (Q8_BLOCK_SIZE / ELEMS_PER_THREAD));
}
#endif
#ifdef K_QUANTS
fn accumulate_vec_q_dot(thread_id: u32, row_base: u32, src0_batch_offset: u32, src1q_idx_base: u32) -> array<array<f32, OUTPUTS_PER_WG>, NUM_COLS> {
var acc: array<array<f32, OUTPUTS_PER_WG>, NUM_COLS>;
fn accumulate_vec_q_dot(thread_id: u32, row_base: u32, src0_batch_offset: u32, src1q_idx_base: u32) -> array<f32, OUTPUTS_PER_WG> {
var acc: array<f32, OUTPUTS_PER_WG>;
let tid = thread_id % THREADS_PER_BLOCK;
for (var block = thread_id / THREADS_PER_BLOCK; block < params.k / BLOCK_SIZE; block += WG_SIZE / THREADS_PER_BLOCK) {
let src1q_idx = src1q_idx_base + (block * BLOCK_SIZE + ELEMS_PER_THREAD * tid) / Q8_BLOCK_SIZE;
let b_repacked = repack_b_qs(src1q_idx, tid);
let b_ds = repack_b_dm(src1q_idx);
for (var row = 0u; row < OUTPUTS_PER_WG; row++) {
let output_row = row_base + row;
if (output_row < params.m) {
let block_byte_base = (src0_batch_offset + output_row * params.stride_01 + block) * BLOCK_SIZE_BYTES;
let a_repacked = repack_a(block_byte_base, tid);
let dm = get_dm(block_byte_base);
let scale_min = get_scale_min(block_byte_base, tid);
for (var col = 0u;col < NUM_COLS;col += 1) {
let src1q_idx = src1q_idx_base + col * (params.k / Q8_BLOCK_SIZE) + (block * BLOCK_SIZE + ELEMS_PER_THREAD * tid) / Q8_BLOCK_SIZE;
let b_repacked = repack_b_qs(src1q_idx, tid);
let b_ds = repack_b_dm(src1q_idx);
#if defined(MUL_ACC_Q2_K)
let scale_q = i32(scale_min.x);
let scale_m_i8x4 = u32(scale_min.y) * 0x01010101u;
let row_sum_d = (dot4I8Packed(b_repacked[0], a_repacked[0]) + dot4I8Packed(b_repacked[1], a_repacked[1])
+ dot4I8Packed(b_repacked[2], a_repacked[2]) + dot4I8Packed(b_repacked[3], a_repacked[3])) * scale_q;
let row_sum_m = dot4I8Packed(b_repacked[0], scale_m_i8x4) + dot4I8Packed(b_repacked[1], scale_m_i8x4)
+ dot4I8Packed(b_repacked[2], scale_m_i8x4) + dot4I8Packed(b_repacked[3], scale_m_i8x4);
acc[col][row] += b_ds * (dm.x * f32(row_sum_d) - dm.y * f32(row_sum_m));
#endif // MUL_ACC_Q2_K
#if defined(MUL_ACC_Q4_K)
let row_sum = dot4I8Packed(a_repacked[0], b_repacked[0]) + dot4I8Packed(a_repacked[1], b_repacked[1])
+ dot4I8Packed(a_repacked[2], b_repacked[2]) + dot4I8Packed(a_repacked[3], b_repacked[3]);
// Each thread covers half of the Q8_1 block, so add only b_ds.y/2.
acc[col][row] += b_ds.x * dm.x * scale_min.x * f32(row_sum) - dm.y * scale_min.y * (b_ds.y / (Q8_BLOCK_SIZE / ELEMS_PER_THREAD));
#endif // MUL_ACC_Q4_K
}
acc[row] += mmvq_dot_product(block_byte_base, tid, b_repacked, b_ds);
}
}
}
return acc;
}
#endif // K_QUANTS
#endif
@@ -9,11 +9,9 @@ requires packed_4x8_integer_dot_product;
struct Params {
offset_src1: u32,
stride_11: u32,
stride_12: u32,
stride_13: u32,
ne0: u32,
ne1: u32,
ne2: u32,
ne3: u32,
};
@@ -59,28 +57,25 @@ fn main(
@builtin(num_workgroups) num_wg: vec3<u32>
) {
let thread_id = local_id.x;
let ne0_vec4 = params.ne0 / 4u;
let num_vec4 = params.ne0 / 4u;
let wg_per_vec = (ne0_vec4 + (WG_SIZE - 1u)) / WG_SIZE;
let total_batches = wg_per_vec * params.ne1 * params.ne2 * params.ne3;
let wg_per_vec = (num_vec4 + (WG_SIZE - 1u)) / WG_SIZE;
let total_batches = wg_per_vec * params.ne2 * params.ne3;
let wg_linear = wg_id.y * num_wg.x + wg_id.x;
if (wg_linear >= total_batches) {
return;
}
let vec_idx = wg_linear / wg_per_vec;
let src13_idx = vec_idx / (params.ne2 * params.ne1);
let vec_ne12_num = vec_idx % (params.ne2 * params.ne1);
let src12_idx = vec_ne12_num / params.ne1;
let src11_idx = vec_ne12_num % params.ne1;
let src1_idx_base = params.offset_src1 + src13_idx * params.stride_13 + src12_idx * params.stride_12 + src11_idx * params.stride_11;
let src13_idx = wg_linear / (params.ne2 * wg_per_vec);
let src12_idx = (wg_linear - src13_idx * (params.ne2 * wg_per_vec)) / wg_per_vec;
let src11_wg_idx = wg_linear % wg_per_vec;
let src1_idx_base = params.offset_src1 + src13_idx * params.stride_13 + src12_idx * params.stride_12;
let src1_idx_vec4_base = src1_idx_base / 4u;
let blocks_per_row = params.ne0 / 32u;
let blocks_per_wg = (WG_SIZE * 4u) / 32u;
let src1q_idx_base = ((src13_idx * params.ne2 + src12_idx) * params.ne1 + src11_idx) * blocks_per_row;
let src11_wg_idx = wg_linear % wg_per_vec;
let src1q_idx_base = (src13_idx * params.ne2 + src12_idx) * blocks_per_row;
let src1q_idx = src1q_idx_base + src11_wg_idx * blocks_per_wg + thread_id / 8u;
let qs_idx = thread_id % 8u;
@@ -90,7 +85,7 @@ fn main(
var thread_amax = 0.0;
let src11_vec4_idx = src11_wg_idx * WG_SIZE + thread_id;
let is_valid = src11_vec4_idx < ne0_vec4;
let is_valid = src11_vec4_idx < num_vec4;
#ifdef USE_SUBGROUP_REDUCTION
-1
View File
@@ -359,7 +359,6 @@ class Keys:
CHUNK_SIZE = "clip.audio.chunk_size"
CONV_KERNEL_SIZE = "clip.audio.conv_kernel_size"
MAX_POS_EMB = "clip.audio.max_pos_emb"
FEATURE_LAYERS = "clip.audio.feature_layer" # Granite Speech Plus
class Attention:
HEAD_COUNT = "clip.audio.attention.head_count"
-3
View File
@@ -1310,9 +1310,6 @@ class GGUFWriter:
def add_audio_max_pos_emb(self, value: int) -> None:
self.add_uint32(Keys.ClipAudio.MAX_POS_EMB, value)
def add_audio_feature_layers(self, layers: Sequence[int]) -> None:
self.add_array(Keys.ClipAudio.FEATURE_LAYERS, layers)
def add_audio_projector_window_size(self, value: int) -> None:
self.add_uint32(Keys.ClipAudio.Projector.WINDOW_SIZE, value)
-2
View File
@@ -8433,7 +8433,6 @@ static std::vector<std::unique_ptr<test_case>> make_test_cases_eval() {
test_cases.emplace_back(new test_mul_mat(type_a, type_b, 16, 1, k, {3, 2}, {2, 1}));
test_cases.emplace_back(new test_mul_mat(type_a, type_b, 16, 1, k, {3, 2}, {1, 2}));
test_cases.emplace_back(new test_mul_mat(type_a, type_b, 16, 1, k, {3, 2}, {2, 2}));
test_cases.emplace_back(new test_mul_mat(type_a, type_b, 16, 4, k, {3, 2}, {2, 2}));
test_cases.emplace_back(new test_mul_mat(type_a, type_b, 16, 16, k, {1, 1}, {1, 1}));
test_cases.emplace_back(new test_mul_mat(type_a, type_b, 16, 16, k, {1, 1}, {2, 1}));
@@ -8450,7 +8449,6 @@ static std::vector<std::unique_ptr<test_case>> make_test_cases_eval() {
test_cases.emplace_back(new test_mul_mat(type_a, type_b, 16, 1, k, {2, 3}, {1, 1}, {0, 1, 3, 2}));
test_cases.emplace_back(new test_mul_mat(type_a, type_b, 16, 1, k, {2, 3}, {1, 1}, {0, 3, 2, 1}));
test_cases.emplace_back(new test_mul_mat(type_a, type_b, 16, 4, k, {2, 3}, {1, 1}, {0, 3, 2, 1}));
test_cases.emplace_back(new test_mul_mat(type_a, type_b, 16, 8, k, {2, 3}, {1, 1}, {0, 2, 1, 3}));
test_cases.emplace_back(new test_mul_mat(type_a, type_b, 16, 8, k, {2, 3}, {1, 1}, {0, 1, 3, 2}));
test_cases.emplace_back(new test_mul_mat(type_a, type_b, 16, 8, k, {2, 3}, {1, 1}, {0, 3, 2, 1}));
+24 -99
View File
@@ -1562,112 +1562,37 @@ static void test_msgs_oaicompat_json_conversion() {
}
}
static void test_msg_token_delimiters_split() {
static void test_split_by_role() {
LOG_DBG("%s\n", __func__);
// Delimiters that share a leading token, distinguished by the second token,
// to exercise the per-position token matching.
const common_chat_msg_delimiters delims = {
{ { COMMON_CHAT_ROLE_USER, "", { 10, 11 } },
{ COMMON_CHAT_ROLE_ASSISTANT, "", { 10, 12 } } }
};
// Empty inputs
assert_equals<size_t>(0, common_chat_msg_delimiters{}.split({}).spans.size());
assert_equals<size_t>(0, common_chat_msg_delimiters{}.split({ 10, 11 }).spans.size());
assert_equals<size_t>(0, delims.split({}).spans.size());
assert_equals<size_t>(0, common_chat_split_by_role("", {}).size());
assert_equals<size_t>(0, common_chat_split_by_role("hello", {}).size());
assert_equals<size_t>(0, common_chat_split_by_role("", { { "user", "<|user|>" } }).size());
// No delimiters match -> no spans
assert_equals<size_t>(0, delims.split({ 100, 101, 102 }).spans.size());
// Multi-role conversation: <user>Hi<assistant>Hello<user>Bye
// Multi-role conversation, no leading/trailing content
{
const llama_tokens tokens = {
10, 11, // <user>
100, 101, // Hi
10, 12, // <assistant>
200, 201, 202, // Hello
10, 11, // <user>
300, 301, // Bye
};
const std::string prompt = "<|user|>Hi<|assistant|>Hello<|user|>Bye";
const auto splits = common_chat_split_by_role(prompt, {
{ "user", "<|user|>" },
{ "assistant", "<|assistant|>" },
});
assert_equals<size_t>(3, splits.size());
const auto result = delims.split(tokens);
const auto & spans = result.spans;
assert_equals<size_t>(3, spans.size());
assert_equals<std::string>("user", splits[0].role);
assert_equals<size_t>(0, splits[0].pos);
assert_equals<size_t>(10, splits[0].len);
assert_equals<std::string>("<|user|>Hi", prompt.substr(splits[0].pos, splits[0].len));
assert_equals(COMMON_CHAT_ROLE_USER, spans[0].role);
assert_equals<size_t>(0, spans[0].pos);
assert_equals<size_t>(4, spans[0].len);
assert_equals<std::string>("assistant", splits[1].role);
assert_equals<size_t>(10, splits[1].pos);
assert_equals<size_t>(18, splits[1].len);
assert_equals<std::string>("<|assistant|>Hello", prompt.substr(splits[1].pos, splits[1].len));
assert_equals(COMMON_CHAT_ROLE_ASSISTANT, spans[1].role);
assert_equals<size_t>(4, spans[1].pos);
assert_equals<size_t>(5, spans[1].len);
assert_equals(COMMON_CHAT_ROLE_USER, spans[2].role);
assert_equals<size_t>(9, spans[2].pos);
assert_equals<size_t>(4, spans[2].len);
// is_user_start() is true at the token position where a user span begins
assert_equals(true, result.is_user_start(0));
assert_equals(false, result.is_user_start(4)); // assistant span
assert_equals(true, result.is_user_start(9));
}
// Content before the first delimiter is not captured as a span
{
const llama_tokens tokens = {
500, 501, // leading content (dropped)
10, 11, // <user>
100, // Hi
};
const auto spans = delims.split(tokens).spans;
assert_equals<size_t>(1, spans.size());
assert_equals(COMMON_CHAT_ROLE_USER, spans[0].role);
assert_equals<size_t>(2, spans[0].pos);
assert_equals<size_t>(3, spans[0].len);
}
// Skipped regions (media chunks) are jumped over but still count as span content
{
const llama_tokens tokens = {
10, 11, // <user>
LLAMA_TOKEN_NULL, // media chunk (3 tokens)
LLAMA_TOKEN_NULL,
LLAMA_TOKEN_NULL,
100, // Hi
10, 12, // <assistant>
};
const std::map<size_t, size_t> skips = { { 2, 3 } };
const auto spans = delims.split(tokens, skips).spans;
assert_equals<size_t>(2, spans.size());
assert_equals(COMMON_CHAT_ROLE_USER, spans[0].role);
assert_equals<size_t>(0, spans[0].pos);
assert_equals<size_t>(6, spans[0].len);
assert_equals(COMMON_CHAT_ROLE_ASSISTANT, spans[1].role);
assert_equals<size_t>(6, spans[1].pos);
assert_equals<size_t>(2, spans[1].len);
}
// A delimiter sequence inside a skipped region is not matched
{
const llama_tokens tokens = {
10, 11, // <user>
10, 12, // skipped region that happens to contain delimiter tokens
100, // Hi
};
const std::map<size_t, size_t> skips = { { 2, 2 } };
const auto spans = delims.split(tokens, skips).spans;
assert_equals<size_t>(1, spans.size());
assert_equals(COMMON_CHAT_ROLE_USER, spans[0].role);
assert_equals<size_t>(0, spans[0].pos);
assert_equals<size_t>(5, spans[0].len);
assert_equals<std::string>("user", splits[2].role);
assert_equals<size_t>(28, splits[2].pos);
assert_equals<size_t>(11, splits[2].len);
assert_equals<std::string>("<|user|>Bye", prompt.substr(splits[2].pos, splits[2].len));
}
}
@@ -5932,7 +5857,7 @@ int main(int argc, char ** argv) {
{
test_msg_diffs_compute();
test_msgs_oaicompat_json_conversion();
test_msg_token_delimiters_split();
test_split_by_role();
test_tools_oaicompat_json_conversion();
test_convert_responses_to_chatcmpl();
test_developer_role_to_system_workaround();
+2 -4
View File
@@ -2,13 +2,11 @@
set(TARGET llama-cli-impl)
add_library(${TARGET} cli.cpp
cli-client.cpp
cli-context.cpp)
add_library(${TARGET} cli.cpp)
set_target_properties(${TARGET} PROPERTIES WINDOWS_EXPORT_ALL_SYMBOLS ON)
target_include_directories(${TARGET} PUBLIC ${CMAKE_CURRENT_SOURCE_DIR} ../server)
target_link_libraries(${TARGET} PUBLIC llama-server-impl llama-common ${CMAKE_THREAD_LIBS_INIT})
target_link_libraries(${TARGET} PUBLIC server-context llama-common ${CMAKE_THREAD_LIBS_INIT})
if(LLAMA_TOOLS_INSTALL)
install(TARGETS ${TARGET} LIBRARY)
-164
View File
@@ -1,164 +0,0 @@
#include "cli-client.h"
#include "http.h"
#include <algorithm>
#include <chrono>
#include <thread>
// generation can stall for a long time during prompt processing, so the
// read timeout must be generous
static constexpr time_t CLI_HTTP_READ_TIMEOUT_SEC = 3600;
// upper bound for the accumulated response body kept for error reporting
static constexpr size_t CLI_HTTP_MAX_ERROR_BODY = 1024 * 1024;
// returns the path with the base url's path prefix prepended (if any)
static std::string join_path(const common_http_url & parts, const std::string & path) {
if (parts.path.empty() || parts.path == "/") {
return path;
}
std::string prefix = parts.path;
if (prefix.back() == '/') {
prefix.pop_back();
}
return prefix + path;
}
json cli_client::get(const std::string & path) {
auto [cli, parts] = common_http_client(server_base);
cli.set_read_timeout(CLI_HTTP_READ_TIMEOUT_SEC, 0);
auto path_with_model = path + (model.empty() ? "" : ("?model=" + model));
auto res = cli.Get(join_path(parts, path_with_model));
if (!res) {
throw std::runtime_error("failed to connect to " + server_base + ": " + httplib::to_string(res.error()));
}
if (res->status < 200 || res->status >= 300) {
throw std::runtime_error("GET " + path + " failed with status " + std::to_string(res->status) + ": " + res->body);
}
json result = json::parse(res->body, nullptr, false);
if (result.is_discarded()) {
throw std::runtime_error("GET " + path + " returned invalid JSON");
}
return result;
}
json cli_client::post(const std::string & path, const json & body) {
auto [cli, parts] = common_http_client(server_base);
cli.set_read_timeout(CLI_HTTP_READ_TIMEOUT_SEC, 0);
auto body_with_model = body;
if (!model.empty()) {
body_with_model["model"] = model;
}
auto res = cli.Post(join_path(parts, path), body_with_model.dump(), "application/json");
if (!res) {
throw std::runtime_error("failed to connect to " + server_base + ": " + httplib::to_string(res.error()));
}
if (res->status < 200 || res->status >= 300) {
throw std::runtime_error("POST " + path + " failed with status " + std::to_string(res->status) + ": " + res->body);
}
json result = json::parse(res->body, nullptr, false);
if (result.is_discarded()) {
throw std::runtime_error("POST " + path + " returned invalid JSON");
}
return result;
}
json cli_client::post_sse(const std::string & path,
const json & body,
const std::function<bool()> & should_stop,
const std::function<void(const json &)> & on_data) {
auto [cli, parts] = common_http_client(server_base);
cli.set_read_timeout(CLI_HTTP_READ_TIMEOUT_SEC, 0);
std::string pending; // buffer for incomplete SSE lines
std::string raw_body; // accumulated body, used only for error reporting
auto receiver = [&](const char * data, size_t len) -> bool {
if (should_stop()) {
return false; // aborts the request
}
if (raw_body.size() < CLI_HTTP_MAX_ERROR_BODY) {
raw_body.append(data, std::min(len, CLI_HTTP_MAX_ERROR_BODY - raw_body.size()));
}
pending.append(data, len);
size_t pos;
while ((pos = pending.find('\n')) != std::string::npos) {
std::string line = pending.substr(0, pos);
pending.erase(0, pos + 1);
if (!line.empty() && line.back() == '\r') {
line.pop_back();
}
if (line.rfind("data: ", 0) != 0) {
continue;
}
std::string payload = line.substr(6);
if (payload == "[DONE]") {
continue;
}
json event = json::parse(payload, nullptr, false);
if (!event.is_discarded()) {
on_data(event);
}
}
return true;
};
httplib::Headers headers = {{"Accept", "text/event-stream"}};
auto body_with_model = body;
if (!model.empty()) {
body_with_model["model"] = model;
}
auto res = cli.Post(join_path(parts, path), headers, body_with_model.dump(), "application/json", receiver);
if (!res) {
if (res.error() == httplib::Error::Canceled && should_stop()) {
return json(); // cancelled by the user
}
return json {{"error", {{"message", "failed to connect to " + server_base + ": " + httplib::to_string(res.error())}}}};
}
if (res->status < 200 || res->status >= 300) {
json error_body = json::parse(raw_body, nullptr, false);
if (!error_body.is_discarded() && error_body.contains("error")) {
return error_body;
}
return json {{"error", {{"message", "request failed with status " + std::to_string(res->status)}}}};
}
return json();
}
bool cli_client::wait_health(const std::function<bool()> & is_aborted) {
int connect_attempts = 0;
while (!is_aborted()) {
auto [cli, parts] = common_http_client(server_base);
cli.set_connection_timeout(1, 0);
auto res = cli.Get(join_path(parts, "/health"));
if (res) {
if (res->status == 200) {
return true;
}
// any other status means the server is up but not ready yet
// (e.g. 503 while the model is still loading)
} else if (++connect_attempts >= 10) {
last_error = "failed to connect to " + server_base + ": " + httplib::to_string(res.error());
return false;
}
std::this_thread::sleep_for(std::chrono::milliseconds(300));
}
last_error = "aborted while waiting for the server to become ready";
return false;
}
std::vector<std::string> cli_client::list_models() {
json resp = get("/v1/models");
if (!resp.contains("data") || !resp.at("data").is_array()) {
throw std::runtime_error("invalid response from /v1/models");
}
std::vector<std::string> models;
for (const auto & m : resp.at("data")) {
if (m.contains("id") && m.at("id").is_string()) {
models.push_back(m.at("id").get<std::string>());
}
}
return models;
}
-56
View File
@@ -1,56 +0,0 @@
#pragma once
#include "ggml.h"
#define JSON_ASSERT GGML_ASSERT
#include <nlohmann/json.hpp>
#include <functional>
#include <string>
using json = nlohmann::ordered_json;
// openai-like client for CLI
struct cli_client {
std::string server_base; // base url, for example "http://127.0.0.1:8080"
std::string last_error; // set when wait_health() fails
std::string model; // optional, set when the server has multiple models (router mode)
// simple GET request, returns the response json
// throws std::runtime_error on transport error or non-2xx status
json get(const std::string & path);
// simple POST request, returns the response json
// throws std::runtime_error on transport error or non-2xx status
json post(const std::string & path, const json & body);
// POST request with an SSE streaming response; on_data is invoked once
// per "data:" event; the function returns after the stream is finished:
// a null json on graceful exit (incl. cancellation via should_stop),
// the error response json otherwise
json post_sse(const std::string & path,
const json & body,
const std::function<bool()> & should_stop,
const std::function<void(const json &)> & on_data);
// poll /health until the server is ready to accept requests
// returns false if is_aborted returned true or the server is unreachable
bool wait_health(const std::function<bool()> & is_aborted);
//
// higher-level wrappers
//
json create_chat_completion(const json & request,
const std::function<bool()> & should_stop,
const std::function<void(const json &)> & on_data) {
return post_sse("/v1/chat/completions", request, should_stop, on_data);
}
json get_props() {
return get("/props");
}
std::vector<std::string> list_models();
};
-559
View File
@@ -1,559 +0,0 @@
#include "cli-context.h"
#include "cli-view.h"
#include "arg.h"
#include "base64.hpp"
#include "log.h"
#include "console.h"
#include <algorithm>
#include <filesystem>
#include <fstream>
#include <map>
#include <set>
std::atomic<bool> g_cli_interrupted = false;
static bool should_stop() {
return g_cli_interrupted.load();
}
static constexpr size_t FILE_GLOB_MAX_RESULTS = 100;
const char * LLAMA_ASCII_LOGO = R"(
)";
// number of values an arg consumes on the command line
static int arg_num_values(const common_arg & opt) {
if (opt.value_hint_2 != nullptr) {
return 2;
}
if (opt.value_hint != nullptr) {
return 1;
}
return 0;
}
static std::string format_error_message(const json & err) {
if (err.contains("error") && err.at("error").is_object()) {
const auto & e = err.at("error");
if (e.contains("message") && e.at("message").is_string()) {
return e.at("message").get<std::string>();
}
}
return err.dump();
}
static std::string media_type_from_ext(const std::string & fname) {
std::string ext = std::filesystem::path(fname).extension().string();
std::transform(ext.begin(), ext.end(), ext.begin(), ::tolower);
if (ext == ".wav" || ext == ".mp3") {
return "audio";
}
if (ext == ".mp4" || ext == ".avi" || ext == ".mkv" || ext == ".mov" || ext == ".webm") {
return "video";
}
return "image";
}
bool cli_context::init() {
view::init(params);
std::optional<view::spinner> spinner;
bool use_external_server = !params.server_base.empty();
if (use_external_server) {
std::string base = params.server_base;
while (!base.empty() && base.back() == '/') {
base.pop_back();
}
client.server_base = base;
spinner.emplace("Connecting to server at " + base);
} else {
if (params.model.path.empty() && params.model.url.empty() &&
params.model.hf_repo.empty() && params.model.docker_repo.empty()) {
view::show_error(
"no model specified",
"use -m <file.gguf> or -hf <user/repo> to run a local model,\n"
"or --server-base <url> to connect to a running llama-server"
);
return false;
}
spinner.emplace("\n\nLoading model...");
server.emplace();
if (!server->start(params)) {
view::show_error("server start failed");
return false;
}
if (!server->wait_ready(should_stop)) {
if (!should_stop()) {
view::show_error("the server exited before becoming ready");
}
return false;
}
client.server_base = server->address();
}
// for --server-base this is the main availability check; for a spawned
// server it is a cheap sanity check on top of the ready signal
auto is_aborted = [this]() {
return should_stop() || (server && !server->alive());
};
bool healthy = false;
try {
healthy = client.wait_health(is_aborted);
} catch (const std::exception & e) {
client.last_error = e.what();
}
if (!healthy) {
if (!should_stop()) {
view::show_error(client.last_error);
}
return false;
}
if (use_external_server) {
spinner.reset();
if (!list_and_ask_models()) {
return false;
}
// restore the spinner for the next step
spinner.emplace("Waiting for server...");
}
fetch_server_props();
return true;
}
void cli_context::fetch_server_props() {
try {
json props = client.get_props();
model_name = props.value("model_alias", "");
if (model_name.empty()) {
const std::string path = props.value("model_path", "");
if (!path.empty()) {
model_name = std::filesystem::path(path).filename().string();
}
}
build_info = props.value("build_info", "");
if (props.contains("modalities") && props.at("modalities").is_object()) {
const auto & modalities = props.at("modalities");
has_vision = modalities.value("vision", false);
has_audio = modalities.value("audio", false);
has_video = modalities.value("video", false);
}
} catch (const std::exception & e) {
// /props can be disabled on remote servers; not fatal
LOG_DBG("failed to fetch /props: %s\n", e.what());
}
}
bool cli_context::list_and_ask_models() {
auto models = client.list_models();
// only one model: use it without asking
if (models.size() == 1) {
model_name = models[0];
client.model = model_name;
return true;
}
std::string message = "\nAvailable models:";
if (!models.empty()) {
for (size_t i = 0; i < models.size(); ++i) {
message += "\n " + std::to_string(i + 1) + ". " + models[i];
}
}
message += "\n";
view::show_message(message);
std::string selection;
while (selection.empty()) {
if (should_stop()) {
return false;
}
view::user_turn user_turn;
selection = user_turn.read_input(false, "Select model by number: ");
if (selection.empty()) {
continue;
}
try {
size_t idx = std::stoul(selection);
if (idx > 0 && idx <= models.size()) {
model_name = models[idx - 1];
client.model = model_name;
view::show_message("Selected model: " + model_name);
break;
}
} catch (...) {
// ignore
}
view::show_error("Invalid selection. Please enter a valid number.");
selection.clear();
continue;
}
return true;
}
void cli_context::add_system_prompt() {
if (!params.system_prompt.empty()) {
messages.push_back({
{"role", "system"},
{"content", params.system_prompt}
});
}
}
void cli_context::push_user_message(const std::string & text) {
json content;
if (pending_media.empty()) {
content = text;
} else {
// multimodal message: media parts first, then the text
content = pending_media;
content.push_back({
{"type", "text"},
{"text", text}
});
pending_media = json::array();
}
messages.push_back({
{"role", "user"},
{"content", content}
});
}
bool cli_context::stage_media_file(const std::string & fname, const std::string & type) {
std::ifstream file(fname, std::ios::binary);
if (!file) {
return false;
}
std::string data((std::istreambuf_iterator<char>(file)), std::istreambuf_iterator<char>());
std::string encoded = base64::encode(data);
if (type == "audio") {
std::string ext = std::filesystem::path(fname).extension().string();
std::transform(ext.begin(), ext.end(), ext.begin(), ::tolower);
pending_media.push_back({
{"type", "input_audio"},
{"input_audio", {
{"data", encoded},
{"format", ext == ".mp3" ? "mp3" : "wav"}
}}
});
} else if (type == "video") {
pending_media.push_back({
{"type", "input_video"},
{"input_video", {
{"data", encoded}
}}
});
} else {
// the server detects the actual image type from the data
pending_media.push_back({
{"type", "image_url"},
{"image_url", {
{"url", "data:image/unknown;base64," + encoded}
}}
});
}
return true;
}
bool cli_context::generate_completion(std::string & assistant_content, cli_timings & timings) {
json body = {
{"messages", messages},
{"stream", true},
// in order to get timings even when we cancel mid-way
{"timings_per_token", true},
};
bool stream_error = false;
view::assistant_turn a;
json err = client.create_chat_completion(body, should_stop, [&](const json & chunk) {
if (chunk.contains("error")) {
stream_error = true;
view::show_error(format_error_message(chunk));
return;
}
if (chunk.contains("timings")) {
const auto & t = chunk.at("timings");
timings.prompt_per_second = t.value("prompt_per_second", 0.0);
timings.predicted_per_second = t.value("predicted_per_second", 0.0);
}
if (!chunk.contains("choices") || !chunk.at("choices").is_array() || chunk.at("choices").empty()) {
return;
}
const auto & choice = chunk.at("choices").at(0);
if (!choice.contains("delta")) {
return;
}
const auto & delta = choice.at("delta");
if (delta.contains("reasoning_content") && delta.at("reasoning_content").is_string()) {
const std::string text = delta.at("reasoning_content").get<std::string>();
if (!text.empty()) {
a.push(view::ASSISTANT_DISPLAY_MODE_REASONING, text);
}
}
if (delta.contains("content") && delta.at("content").is_string()) {
const std::string text = delta.at("content").get<std::string>();
if (!text.empty()) {
assistant_content += text;
a.push(view::ASSISTANT_DISPLAY_MODE_CONTENT, text);
}
}
});
g_cli_interrupted.store(false);
if (!err.is_null()) {
view::show_error(format_error_message(err));
return false;
}
return !stream_error;
}
int cli_context::run() {
add_system_prompt();
std::string modalities = "text";
if (has_vision) {
modalities += ", vision";
}
if (has_audio) {
modalities += ", audio";
}
if (has_video) {
modalities += ", video";
}
std::string banner;
banner += "\n";
banner += LLAMA_ASCII_LOGO;
banner += "\n";
banner += "build : " + build_info + "\n";
banner += "model : " + model_name + "\n";
banner += "modalities : " + modalities + "\n";
if (!params.system_prompt.empty()) {
banner += "using custom system prompt\n";
}
banner += "\n";
banner += "available commands:\n";
banner += " /exit or Ctrl+C stop or exit\n";
banner += " /regen regenerate the last response\n";
banner += " /clear clear the chat history\n";
banner += " /read <file> add a text file\n";
banner += " /glob <pattern> add text files using globbing pattern\n";
if (has_vision) {
banner += " /image <file> add an image file\n";
}
if (has_audio) {
banner += " /audio <file> add an audio file\n";
}
if (has_video) {
banner += " /video <file> add a video file\n";
}
banner += "\n";
view::show_message(banner);
// interactive loop
std::string cur_msg;
auto add_text_file = [&](const std::string & fname) -> bool {
std::ifstream file(fname, std::ios::binary);
if (!file) {
view::show_error(string_format("file does not exist or cannot be opened: '%s'", fname.c_str()));
return false;
}
std::string content((std::istreambuf_iterator<char>(file)), std::istreambuf_iterator<char>());
cur_msg += "--- File: ";
cur_msg += fname;
cur_msg += " ---\n";
cur_msg += content;
view::show_message(string_format("Loaded text from '%s'", fname.c_str()));
return true;
};
while (true) {
std::string buffer;
{
view::user_turn user_turn;
if (params.prompt.empty()) {
buffer = user_turn.read_input(params.multiline_input);
} else {
// process input prompt from args
for (auto & fname : params.image) {
if (!stage_media_file(fname, media_type_from_ext(fname))) {
view::show_error(string_format("file does not exist or cannot be opened: '%s'", fname.c_str()));
break;
}
view::show_message(string_format("Loaded media from '%s'", fname.c_str()));
}
buffer = params.prompt;
user_turn.echo(buffer);
params.prompt.clear(); // only use it once
}
}
if (should_stop()) {
g_cli_interrupted.store(false);
break;
}
// remove trailing newline
if (!buffer.empty() && buffer.back() == '\n') {
buffer.pop_back();
}
// skip empty messages
if (buffer.empty()) {
continue;
}
bool add_user_msg = true;
// process commands
if (string_starts_with(buffer, "/exit")) {
break;
} else if (string_starts_with(buffer, "/regen")) {
if (messages.size() >= 2) {
size_t last_idx = messages.size() - 1;
messages.erase(last_idx);
add_user_msg = false;
} else {
view::show_error("No message to regenerate.");
continue;
}
} else if (string_starts_with(buffer, "/clear")) {
messages.clear();
add_system_prompt();
pending_media = json::array();
view::show_message("Chat history cleared.");
continue;
} else if (
(string_starts_with(buffer, "/image ") && has_vision) ||
(string_starts_with(buffer, "/audio ") && has_audio) ||
(string_starts_with(buffer, "/video ") && has_video)) {
std::string type = buffer.substr(1, 5);
// just in case (bad copy-paste for example), we strip all trailing/leading spaces
std::string fname = string_strip(buffer.substr(7));
if (!stage_media_file(fname, type)) {
view::show_error(string_format("file does not exist or cannot be opened: '%s'", fname.c_str()));
continue;
}
view::show_message(string_format("Loaded media from '%s'", fname.c_str()));
continue;
} else if (string_starts_with(buffer, "/read ")) {
std::string fname = string_strip(buffer.substr(6));
add_text_file(fname);
continue;
} else if (string_starts_with(buffer, "/glob ")) {
std::error_code ec;
size_t count = 0;
auto curdir = std::filesystem::current_path();
std::string pattern = string_strip(buffer.substr(6));
std::filesystem::path rel_path;
auto startglob = pattern.find_first_of("![*?");
if (startglob != std::string::npos && startglob != 0) {
auto endpath = pattern.substr(0, startglob).find_last_of('/');
if (endpath != std::string::npos) {
std::string rel_pattern = pattern.substr(0, endpath);
#if !defined(_WIN32)
if (string_starts_with(rel_pattern, '~')) {
const char * home = std::getenv("HOME");
if (home && home[0]) {
rel_pattern = home + rel_pattern.substr(1);
}
}
#endif
rel_path = rel_pattern;
pattern.erase(0, endpath + 1);
curdir /= rel_path;
}
}
for (const auto & entry : std::filesystem::recursive_directory_iterator(curdir,
std::filesystem::directory_options::skip_permission_denied, ec)) {
if (!entry.is_regular_file()) {
continue;
}
std::string rel = std::filesystem::relative(entry.path(), curdir, ec).string();
if (ec) {
ec.clear();
continue;
}
std::replace(rel.begin(), rel.end(), '\\', '/');
if (!glob_match(pattern, rel)) {
continue;
}
if (!add_text_file((rel_path / rel).string())) {
continue;
}
if (++count >= FILE_GLOB_MAX_RESULTS) {
view::show_error(string_format("Maximum number of globbed files allowed (%zu) reached.", FILE_GLOB_MAX_RESULTS));
break;
}
}
continue;
} else {
// not a command
cur_msg += buffer;
}
// generate response
if (add_user_msg) {
push_user_message(cur_msg);
cur_msg.clear();
}
cli_timings timings;
std::string assistant_content;
generate_completion(assistant_content, timings);
messages.push_back({
{"role", "assistant"},
{"content", assistant_content}
});
if (params.show_timings) {
view::show_info(string_format(
"\n[ Prompt: %.1f t/s | Generation: %.1f t/s ]",
timings.prompt_per_second,
timings.predicted_per_second
));
}
if (params.single_turn) {
break;
}
}
view::show_message("\n\nExiting...");
return 0;
}
void cli_context::shutdown() {
if (server) {
server->stop();
server.reset();
}
}
-65
View File
@@ -1,65 +0,0 @@
#pragma once
#include "common.h"
#include "cli-client.h"
#include "cli-server.h"
#include <atomic>
#include <optional>
#include <string>
struct cli_timings {
double prompt_per_second = 0.0;
double predicted_per_second = 0.0;
};
// set by the SIGINT handler; cleared once the interrupt has been handled
extern std::atomic<bool> g_cli_interrupted;
struct cli_context {
common_params params;
cli_client client; // always initialized
std::optional<cli_server> server; // only set when no --server-base is given
json messages = json::array();
json pending_media = json::array(); // staged multimodal content parts
// properties of the connected server
// will be populated by fetch_server_props()
std::string model_name;
std::string build_info;
bool has_vision = false;
bool has_audio = false;
bool has_video = false;
cli_context(const common_params & params) : params(params) {}
~cli_context() {
shutdown();
}
// connect to --server-base or spawn a local llama-server child;
// argc/argv are needed to forward the server-relevant args to the child
bool init();
// run the interactive chat loop, returns the process exit code
int run();
// stop the local server child (if any)
void shutdown();
private:
bool generate_completion(std::string & assistant_content, cli_timings & timings);
void fetch_server_props();
void add_system_prompt();
void push_user_message(const std::string & text);
// check if server have multiple models (router mode)
// if yes, list them then ask; do nothing otherwise
bool list_and_ask_models();
// read a file and stage it as a multimodal content part; type is one of
// "image", "audio", "video"; returns false if the file cannot be read
bool stage_media_file(const std::string & fname, const std::string & type);
};
-83
View File
@@ -1,83 +0,0 @@
#pragma once
#include <thread>
#include "http.h"
// llama_server will be available as a dynamic library symbol
int llama_server(common_params & params, int argc, char ** argv);
void llama_server_terminate();
struct cli_server {
std::thread th;
int port = -1;
std::atomic<bool> is_alive = false;
std::atomic<bool> is_stopping = false;
~cli_server() {
stop();
}
void stop() {
if (alive() && !is_stopping.exchange(true)) {
llama_server_terminate();
th.join();
}
}
// spawn llama-server in a thread and interact with it via a random port
bool start(common_params & params) {
port = common_http_get_free_port();
if (port <= 0) {
fprintf(stderr, "failed to get a free port\n");
exit(1);
}
is_alive.store(true, std::memory_order_release);
th = std::thread([&]() {
common_params server_params = params; // copy
server_params.port = port;
// argc / argv are only used in router mode, we can skip them for now
int res = llama_server(server_params, 0, nullptr);
if (res != 0) {
fprintf(stderr, "llama_server exited with code %d\n", res);
}
is_alive.store(false, std::memory_order_release);
});
return true;
}
std::string address() const {
return "http://127.0.0.1:" + std::to_string(port);
}
bool wait_ready(std::function<bool()> should_stop) {
if (!alive()) {
return false;
}
while (!should_stop()) {
auto [cli, parts] = common_http_client(address());
cli.set_connection_timeout(1, 0);
auto res = cli.Get("/health");
if (res) {
if (res->status == 200) {
return true;
}
// any other status means the server is up but not ready yet
// (e.g. 503 while the model is still loading)
}
if (!alive()) {
// in case server die permanently
return false;
}
std::this_thread::sleep_for(std::chrono::milliseconds(200));
}
return true;
}
bool alive() const {
return is_alive.load(std::memory_order_acquire);
}
};
-250
View File
@@ -1,250 +0,0 @@
#pragma once
#include "common.h"
#include "console.h"
#include <array>
#include <algorithm>
#include <filesystem>
#include <string_view>
// TODO?: Make this reusable, enums, docs
static const std::array<std::string_view, 8> cmds = {
"/audio ",
"/clear",
"/exit",
"/glob ",
"/image ",
"/read ",
"/regen",
"/video ",
};
static std::vector<std::pair<std::string, size_t>> auto_completion_callback(std::string_view line, size_t cursor_byte_pos) {
std::vector<std::pair<std::string, size_t>> matches;
std::string cmd;
if (line.length() > 1 && line.front() == '/' && !std::any_of(cmds.begin(), cmds.end(), [line](std::string_view prefix) {
return string_starts_with(line, prefix);
})) {
auto it = cmds.begin();
while ((it = std::find_if(it, cmds.end(), [line](std::string_view cmd_line) {
return string_starts_with(cmd_line, line);
})) != cmds.end()) {
matches.emplace_back(*it, it->length());
++it;
}
} else {
auto it = std::find_if(cmds.begin(), cmds.end(), [line](std::string_view prefix) {
return prefix.back() == ' ' && string_starts_with(line, prefix);
});
if (it != cmds.end()) {
cmd = *it;
}
}
if (!cmd.empty() && cmd != "/glob " && line.length() >= cmd.length() && cursor_byte_pos >= cmd.length()) {
const std::string path_prefix = std::string(line.substr(cmd.length(), cursor_byte_pos - cmd.length()));
const std::string path_postfix = std::string(line.substr(cursor_byte_pos));
auto cur_dir = std::filesystem::current_path();
std::string cur_dir_str = cur_dir.string();
std::string expanded_prefix = path_prefix;
#if !defined(_WIN32)
if (string_starts_with(path_prefix, '~')) {
const char * home = std::getenv("HOME");
if (home && home[0]) {
expanded_prefix = home + path_prefix.substr(1);
}
}
if (string_starts_with(expanded_prefix, '/')) {
#else
if (std::isalpha(expanded_prefix[0]) && expanded_prefix.find(':') == 1) {
#endif
cur_dir = std::filesystem::path(expanded_prefix).parent_path();
cur_dir_str.clear();
} else if (!path_prefix.empty()) {
cur_dir /= std::filesystem::path(path_prefix).parent_path();
}
std::error_code ec;
for (const auto & entry : std::filesystem::directory_iterator(cur_dir, ec)) {
if (ec) {
break;
}
if (!entry.exists(ec)) {
ec.clear();
continue;
}
const std::string path_full = entry.path().string();
std::string path_entry = !cur_dir_str.empty() && string_starts_with(path_full, cur_dir_str) ? path_full.substr(cur_dir_str.length() + 1) : path_full;
if (entry.is_directory(ec)) {
path_entry.push_back(std::filesystem::path::preferred_separator);
}
if (expanded_prefix.empty() || string_starts_with(path_entry, expanded_prefix)) {
const std::string updated_line = cmd + path_entry;
matches.emplace_back(updated_line + path_postfix, updated_line.length());
}
if (ec) {
ec.clear();
}
}
if (matches.empty()) {
const std::string updated_line = cmd + path_prefix;
matches.emplace_back(updated_line + path_postfix, updated_line.length());
}
// Add the longest common prefix
if (!expanded_prefix.empty() && matches.size() > 1) {
const std::string_view match0(matches[0].first);
const std::string_view match1(matches[1].first);
auto it = std::mismatch(match0.begin(), match0.end(), match1.begin(), match1.end());
size_t len = it.first - match0.begin();
for (size_t i = 2; i < matches.size(); ++i) {
const std::string_view matchi(matches[i].first);
auto cmp = std::mismatch(match0.begin(), match0.end(), matchi.begin(), matchi.end());
len = std::min(len, static_cast<size_t>(cmp.first - match0.begin()));
}
const std::string updated_line = std::string(match0.substr(0, len));
matches.emplace_back(updated_line + path_postfix, updated_line.length());
}
std::sort(matches.begin(), matches.end(), [](const auto & a, const auto & b) {
return a.first.compare(0, a.second, b.first, 0, b.second) < 0;
});
}
return matches;
}
// note: make this view implementation generic, so that we can move to TUI in the future if we want to
namespace view {
static void init(const common_params & params) {
// TODO: avoid using atexit() here by making `console` a singleton
console::init(params.simple_io, params.use_color);
atexit([]() { console::cleanup(); });
console::set_completion_callback(auto_completion_callback);
}
struct spinner {
spinner(const std::string & message) {
if (!message.empty()) {
console::log("%s ", message.c_str());
}
console::spinner::start();
}
~spinner() {
console::spinner::stop();
}
};
struct user_turn {
user_turn() {
console::set_display(DISPLAY_TYPE_USER_INPUT);
}
~user_turn() {
console::set_display(DISPLAY_TYPE_RESET);
}
void echo(const std::string & buffer) {
if (buffer.size() > 500) {
console::log("\n> %s ... (truncated)\n", buffer.substr(0, 500).c_str());
} else {
console::log("\n> %s\n", buffer.c_str());
}
}
std::string read_input(bool multiline_input, const char * prompt = nullptr) {
if (prompt) {
console::log("%s", prompt);
} else {
console::log("\n> ");
}
std::string buffer;
std::string line;
bool another_line = true;
do {
another_line = console::readline(line, multiline_input);
buffer += line;
} while (another_line);
return buffer;
}
};
enum assistant_display_mode {
ASSISTANT_DISPLAY_MODE_REASONING,
ASSISTANT_DISPLAY_MODE_CONTENT,
};
struct assistant_turn {
assistant_display_mode mode = ASSISTANT_DISPLAY_MODE_CONTENT;
bool trailing_newline = true;
bool is_inside_reasoning = false;
assistant_turn() {
console::set_display(DISPLAY_TYPE_RESET);
}
~assistant_turn() {
console::set_display(DISPLAY_TYPE_RESET);
add_newline_if_needed();
}
void push(assistant_display_mode m, const std::string & buffer) {
if (m != mode) {
add_newline_if_needed();
switch (m) {
case ASSISTANT_DISPLAY_MODE_CONTENT:
{
if (is_inside_reasoning) {
console::log("[End thinking]\n\n");
is_inside_reasoning = false;
}
console::set_display(DISPLAY_TYPE_RESET);
} break;
case ASSISTANT_DISPLAY_MODE_REASONING:
{
console::set_display(DISPLAY_TYPE_REASONING);
is_inside_reasoning = true;
console::log("\n[Start thinking]\n\n");
} break;
}
}
mode = m;
if (buffer.empty()) {
return;
}
trailing_newline = buffer.back() == '\n';
console::log("%s", buffer.c_str());
console::flush();
}
void add_newline_if_needed() {
if (!trailing_newline) {
console::log("\n");
console::flush();
}
}
};
static void show_error(const std::string & title, const std::string & message = "") {
console::spinner::stop();
console::error("Error: %s\n", title.c_str());
if (!message.empty()) {
console::log("%s\n", message.c_str());
}
}
static void show_message(const std::string & message) {
console::log("%s\n", message.c_str());
}
static void show_info(const std::string & message) {
console::set_display(DISPLAY_TYPE_INFO);
console::log("%s\n", message.c_str());
console::set_display(DISPLAY_TYPE_RESET);
}
}
+624 -10
View File
@@ -1,10 +1,20 @@
#include "arg.h"
#include "chat.h"
#include "common.h"
#include "log.h"
#include "arg.h"
#include "console.h"
#include "fit.h"
// #include "log.h"
#include "cli-context.h"
#include "cli-view.h"
#include "server-common.h"
#include "server-context.h"
#include "server-task.h"
#include <array>
#include <atomic>
#include <algorithm>
#include <filesystem>
#include <fstream>
#include <thread>
#include <signal.h>
#if defined(_WIN32)
@@ -15,19 +25,342 @@
#include <windows.h>
#endif
const char * LLAMA_ASCII_LOGO = R"(
)";
static std::atomic<bool> g_is_interrupted = false;
static bool should_stop() {
return g_is_interrupted.load();
}
#if defined (__unix__) || (defined (__APPLE__) && defined (__MACH__)) || defined (_WIN32)
static void signal_handler(int) {
if (g_cli_interrupted.load()) {
if (g_is_interrupted.load()) {
// second Ctrl+C - exit immediately
// make sure to clear colors before exiting (not using LOG or console.cpp here to avoid deadlock)
fprintf(stdout, "\033[0m\n");
fflush(stdout);
std::exit(130);
}
g_cli_interrupted.store(true);
g_is_interrupted.store(true);
}
#endif
struct cli_context {
server_context ctx_server;
json messages = json::array();
std::vector<raw_buffer> input_files;
task_params defaults;
bool verbose_prompt;
// thread for showing "loading" animation
std::atomic<bool> loading_show;
cli_context(const common_params & params) {
defaults.sampling = params.sampling;
defaults.speculative = params.speculative;
defaults.n_keep = params.n_keep;
defaults.n_predict = params.n_predict;
defaults.antiprompt = params.antiprompt;
defaults.stream = true; // make sure we always use streaming mode
defaults.timings_per_token = true; // in order to get timings even when we cancel mid-way
// defaults.return_progress = true; // TODO: show progress
verbose_prompt = params.verbose_prompt;
}
std::string generate_completion(result_timings & out_timings) {
server_response_reader rd = ctx_server.get_response_reader();
auto chat_params = format_chat();
{
// TODO: reduce some copies here in the future
server_task task = server_task(SERVER_TASK_TYPE_COMPLETION);
task.id = rd.get_new_id();
task.index = 0;
task.params = defaults; // copy
task.cli_prompt = chat_params.prompt; // copy
task.cli_files = input_files; // copy
task.cli = true;
// chat template settings
task.params.chat_parser_params = common_chat_parser_params(chat_params);
task.params.chat_parser_params.reasoning_format = COMMON_REASONING_FORMAT_DEEPSEEK;
if (!chat_params.parser.empty()) {
task.params.chat_parser_params.parser.load(chat_params.parser);
}
// Copy the preserved tokens into the sampling params
const llama_vocab * vocab = llama_model_get_vocab(
llama_get_model(ctx_server.get_llama_context()));
for (const auto & token : chat_params.preserved_tokens) {
auto ids = common_tokenize(vocab, token, false, true);
if (ids.size() == 1) {
task.params.sampling.preserved_tokens.insert(ids[0]);
}
}
// reasoning budget sampler
if (!chat_params.thinking_end_tag.empty()) {
task.params.sampling.reasoning_budget_tokens = defaults.sampling.reasoning_budget_tokens;
task.params.sampling.generation_prompt = chat_params.generation_prompt;
if (!chat_params.thinking_start_tag.empty()) {
task.params.sampling.reasoning_budget_start =
common_tokenize(vocab, chat_params.thinking_start_tag, false, true);
}
task.params.sampling.reasoning_budget_end =
common_tokenize(vocab, chat_params.thinking_end_tag, false, true);
task.params.sampling.reasoning_budget_forced =
common_tokenize(vocab, defaults.sampling.reasoning_budget_message + chat_params.thinking_end_tag, false, true);
}
rd.post_task({std::move(task)});
}
if (verbose_prompt) {
console::set_display(DISPLAY_TYPE_PROMPT);
console::log("%s\n\n", chat_params.prompt.c_str());
console::set_display(DISPLAY_TYPE_RESET);
}
// wait for first result
console::spinner::start();
server_task_result_ptr result = rd.next(should_stop);
while (true) {
auto res_partial = dynamic_cast<server_task_result_cmpl_partial *>(result.get());
if (res_partial && res_partial->is_begin) {
// this is the "send 200 status to client" signal in streaming mode
// skip, do not stop the spinner
result = rd.next(should_stop);
} else {
console::spinner::stop();
break;
}
}
std::string curr_content;
bool is_thinking = false;
while (result) {
if (should_stop()) {
break;
}
if (result->is_error()) {
json err_data = result->to_json();
if (err_data.contains("message")) {
console::error("Error: %s\n", err_data["message"].get<std::string>().c_str());
} else {
console::error("Error: %s\n", err_data.dump().c_str());
}
return curr_content;
}
auto res_partial = dynamic_cast<server_task_result_cmpl_partial *>(result.get());
if (res_partial) {
out_timings = std::move(res_partial->timings);
for (const auto & diff : res_partial->oaicompat_msg_diffs) {
if (!diff.content_delta.empty()) {
if (is_thinking) {
console::log("\n[End thinking]\n\n");
console::set_display(DISPLAY_TYPE_RESET);
is_thinking = false;
}
curr_content += diff.content_delta;
console::log("%s", diff.content_delta.c_str());
console::flush();
}
if (!diff.reasoning_content_delta.empty()) {
console::set_display(DISPLAY_TYPE_REASONING);
if (!is_thinking) {
console::log("[Start thinking]\n");
}
is_thinking = true;
console::log("%s", diff.reasoning_content_delta.c_str());
console::flush();
}
}
}
auto res_final = dynamic_cast<server_task_result_cmpl_final *>(result.get());
if (res_final) {
out_timings = std::move(res_final->timings);
break;
}
result = rd.next(should_stop);
}
g_is_interrupted.store(false);
// server_response_reader automatically cancels pending tasks upon destruction
return curr_content;
}
// TODO: support remote files in the future (http, https, etc)
std::string load_input_file(const std::string & fname, bool is_media) {
std::ifstream file = fs_open_ifstream(fname, std::ios::binary);
if (!file) {
return "";
}
if (is_media) {
raw_buffer buf;
buf.assign((std::istreambuf_iterator<char>(file)), std::istreambuf_iterator<char>());
input_files.push_back(std::move(buf));
return get_media_marker();
} else {
std::string content((std::istreambuf_iterator<char>(file)), std::istreambuf_iterator<char>());
return content;
}
}
common_chat_params format_chat() {
auto meta = ctx_server.get_meta();
auto & chat_params = meta.chat_params;
auto caps = common_chat_templates_get_caps(chat_params.tmpls.get());
common_chat_templates_inputs inputs;
inputs.messages = common_chat_msgs_parse_oaicompat(messages);
inputs.tools = {}; // TODO
inputs.tool_choice = COMMON_CHAT_TOOL_CHOICE_NONE;
inputs.json_schema = ""; // TODO
inputs.grammar = ""; // TODO
inputs.use_jinja = chat_params.use_jinja;
inputs.parallel_tool_calls = caps["supports_parallel_tool_calls"];
inputs.add_generation_prompt = true;
inputs.reasoning_format = COMMON_REASONING_FORMAT_DEEPSEEK;
inputs.force_pure_content = chat_params.force_pure_content;
inputs.enable_thinking = chat_params.enable_thinking ? common_chat_templates_support_enable_thinking(chat_params.tmpls.get()) : false;
// Apply chat template to the list of messages
return common_chat_templates_apply(chat_params.tmpls.get(), inputs);
}
};
// TODO?: Make this reusable, enums, docs
static const std::array<std::string_view, 8> cmds = {
"/audio ",
"/clear",
"/exit",
"/glob ",
"/image ",
"/read ",
"/regen",
"/video ",
};
static std::vector<std::pair<std::string, size_t>> auto_completion_callback(std::string_view line, size_t cursor_byte_pos) {
std::vector<std::pair<std::string, size_t>> matches;
std::string cmd;
if (line.length() > 1 && line.front() == '/' && !std::any_of(cmds.begin(), cmds.end(), [line](std::string_view prefix) {
return string_starts_with(line, prefix);
})) {
auto it = cmds.begin();
while ((it = std::find_if(it, cmds.end(), [line](std::string_view cmd_line) {
return string_starts_with(cmd_line, line);
})) != cmds.end()) {
matches.emplace_back(*it, it->length());
++it;
}
} else {
auto it = std::find_if(cmds.begin(), cmds.end(), [line](std::string_view prefix) {
return prefix.back() == ' ' && string_starts_with(line, prefix);
});
if (it != cmds.end()) {
cmd = *it;
}
}
if (!cmd.empty() && cmd != "/glob " && line.length() >= cmd.length() && cursor_byte_pos >= cmd.length()) {
const std::string path_prefix = std::string(line.substr(cmd.length(), cursor_byte_pos - cmd.length()));
const std::string path_postfix = std::string(line.substr(cursor_byte_pos));
auto cur_dir = std::filesystem::current_path();
std::string cur_dir_str = cur_dir.string();
std::string expanded_prefix = path_prefix;
#if !defined(_WIN32)
if (string_starts_with(path_prefix, '~')) {
const char * home = std::getenv("HOME");
if (home && home[0]) {
expanded_prefix = home + path_prefix.substr(1);
}
}
if (string_starts_with(expanded_prefix, '/')) {
#else
if (std::isalpha(expanded_prefix[0]) && expanded_prefix.find(':') == 1) {
#endif
cur_dir = std::filesystem::path(expanded_prefix).parent_path();
cur_dir_str.clear();
} else if (!path_prefix.empty()) {
cur_dir /= std::filesystem::path(path_prefix).parent_path();
}
std::error_code ec;
for (const auto & entry : std::filesystem::directory_iterator(cur_dir, ec)) {
if (ec) {
break;
}
if (!entry.exists(ec)) {
ec.clear();
continue;
}
const std::string path_full = entry.path().string();
std::string path_entry = !cur_dir_str.empty() && string_starts_with(path_full, cur_dir_str) ? path_full.substr(cur_dir_str.length() + 1) : path_full;
if (entry.is_directory(ec)) {
path_entry.push_back(std::filesystem::path::preferred_separator);
}
if (expanded_prefix.empty() || string_starts_with(path_entry, expanded_prefix)) {
const std::string updated_line = cmd + path_entry;
matches.emplace_back(updated_line + path_postfix, updated_line.length());
}
if (ec) {
ec.clear();
}
}
if (matches.empty()) {
const std::string updated_line = cmd + path_prefix;
matches.emplace_back(updated_line + path_postfix, updated_line.length());
}
// Add the longest common prefix
if (!expanded_prefix.empty() && matches.size() > 1) {
const std::string_view match0(matches[0].first);
const std::string_view match1(matches[1].first);
auto it = std::mismatch(match0.begin(), match0.end(), match1.begin(), match1.end());
size_t len = it.first - match0.begin();
for (size_t i = 2; i < matches.size(); ++i) {
const std::string_view matchi(matches[i].first);
auto cmp = std::mismatch(match0.begin(), match0.end(), matchi.begin(), matchi.end());
len = std::min(len, static_cast<size_t>(cmp.first - match0.begin()));
}
const std::string updated_line = std::string(match0.substr(0, len));
matches.emplace_back(updated_line + path_postfix, updated_line.length());
}
std::sort(matches.begin(), matches.end(), [](const auto & a, const auto & b) {
return a.first.compare(0, a.second, b.first, 0, b.second) < 0;
});
}
return matches;
}
static constexpr size_t FILE_GLOB_MAX_RESULTS = 100;
// satisfies -Wmissing-declarations
int llama_cli(int argc, char ** argv);
@@ -42,6 +375,25 @@ int llama_cli(int argc, char ** argv) {
return 1;
}
// TODO: maybe support it later?
if (params.conversation_mode == COMMON_CONVERSATION_MODE_DISABLED) {
console::error("--no-conversation is not supported by llama-cli\n");
console::error("please use llama-completion instead\n");
}
// struct that contains llama context and inference
cli_context ctx_cli(params);
llama_backend_init();
llama_numa_init(params.numa);
// TODO: avoid using atexit() here by making `console` a singleton
console::init(params.simple_io, params.use_color);
atexit([]() { console::cleanup(); });
console::set_display(DISPLAY_TYPE_RESET);
console::set_completion_callback(auto_completion_callback);
#if defined (__unix__) || (defined (__APPLE__) && defined (__MACH__))
struct sigaction sigint_action;
sigint_action.sa_handler = signal_handler;
@@ -56,11 +408,273 @@ int llama_cli(int argc, char ** argv) {
SetConsoleCtrlHandler(reinterpret_cast<PHANDLER_ROUTINE>(console_ctrl_handler), true);
#endif
cli_context ctx_cli(params);
if (!ctx_cli.init()) {
console::log("\nLoading model... "); // followed by loading animation
console::spinner::start();
if (!ctx_cli.ctx_server.load_model(params)) {
console::spinner::stop();
console::error("\nFailed to load the model\n");
return 1;
}
return ctx_cli.run();
ctx_cli.defaults.sampling = params.sampling;
console::spinner::stop();
console::log("\n");
std::thread inference_thread([&ctx_cli]() {
ctx_cli.ctx_server.start_loop();
});
auto inf = ctx_cli.ctx_server.get_meta();
std::string modalities = "text";
if (inf.has_inp_image) {
modalities += ", vision";
}
if (inf.has_inp_audio) {
modalities += ", audio";
}
auto add_system_prompt = [&]() {
if (!params.system_prompt.empty()) {
ctx_cli.messages.push_back({
{"role", "system"},
{"content", params.system_prompt}
});
}
};
add_system_prompt();
console::log("\n");
console::log("%s\n", LLAMA_ASCII_LOGO);
console::log("build : %s\n", inf.build_info.c_str());
console::log("model : %s\n", inf.model_name.c_str());
console::log("modalities : %s\n", modalities.c_str());
if (!params.system_prompt.empty()) {
console::log("using custom system prompt\n");
}
console::log("\n");
console::log("available commands:\n");
console::log(" /exit or Ctrl+C stop or exit\n");
console::log(" /regen regenerate the last response\n");
console::log(" /clear clear the chat history\n");
console::log(" /read <file> add a text file\n");
console::log(" /glob <pattern> add text files using globbing pattern\n");
if (inf.has_inp_image) {
console::log(" /image <file> add an image file\n");
}
if (inf.has_inp_audio) {
console::log(" /audio <file> add an audio file\n");
}
if (inf.has_inp_video) {
console::log(" /video <file> add a video file\n");
}
console::log("\n");
// interactive loop
std::string cur_msg;
auto add_text_file = [&](const std::string & fname) -> bool {
std::string marker = ctx_cli.load_input_file(fname, false);
if (marker.empty()) {
console::error("file does not exist or cannot be opened: '%s'\n", fname.c_str());
return false;
}
if (inf.fim_sep_token != LLAMA_TOKEN_NULL) {
cur_msg += common_token_to_piece(ctx_cli.ctx_server.get_llama_context(), inf.fim_sep_token, true);
cur_msg += fname;
cur_msg.push_back('\n');
} else {
cur_msg += "--- File: ";
cur_msg += fname;
cur_msg += " ---\n";
}
cur_msg += marker;
console::log("Loaded text from '%s'\n", fname.c_str());
return true;
};
while (true) {
std::string buffer;
console::set_display(DISPLAY_TYPE_USER_INPUT);
if (params.prompt.empty()) {
console::log("\n> ");
std::string line;
bool another_line = true;
do {
another_line = console::readline(line, params.multiline_input);
buffer += line;
} while (another_line);
} else {
// process input prompt from args
for (auto & fname : params.image) {
std::string marker = ctx_cli.load_input_file(fname, true);
if (marker.empty()) {
console::error("file does not exist or cannot be opened: '%s'\n", fname.c_str());
break;
}
console::log("Loaded media from '%s'\n", fname.c_str());
cur_msg += marker;
}
buffer = params.prompt;
if (buffer.size() > 500) {
console::log("\n> %s ... (truncated)\n", buffer.substr(0, 500).c_str());
} else {
console::log("\n> %s\n", buffer.c_str());
}
params.prompt.clear(); // only use it once
}
console::set_display(DISPLAY_TYPE_RESET);
console::log("\n");
if (should_stop()) {
g_is_interrupted.store(false);
break;
}
// remove trailing newline
if (!buffer.empty() &&buffer.back() == '\n') {
buffer.pop_back();
}
// skip empty messages
if (buffer.empty()) {
continue;
}
bool add_user_msg = true;
// process commands
if (string_starts_with(buffer, "/exit")) {
break;
} else if (string_starts_with(buffer, "/regen")) {
if (ctx_cli.messages.size() >= 2) {
size_t last_idx = ctx_cli.messages.size() - 1;
ctx_cli.messages.erase(last_idx);
add_user_msg = false;
} else {
console::error("No message to regenerate.\n");
continue;
}
} else if (string_starts_with(buffer, "/clear")) {
ctx_cli.messages.clear();
add_system_prompt();
ctx_cli.input_files.clear();
console::log("Chat history cleared.\n");
continue;
} else if (
(string_starts_with(buffer, "/image ") && inf.has_inp_image) ||
(string_starts_with(buffer, "/audio ") && inf.has_inp_audio) ||
(string_starts_with(buffer, "/video ") && inf.has_inp_video)) {
// just in case (bad copy-paste for example), we strip all trailing/leading spaces
std::string fname = string_strip(buffer.substr(7));
std::string marker = ctx_cli.load_input_file(fname, true);
if (marker.empty()) {
console::error("file does not exist or cannot be opened: '%s'\n", fname.c_str());
continue;
}
cur_msg += marker;
console::log("Loaded media from '%s'\n", fname.c_str());
continue;
} else if (string_starts_with(buffer, "/read ")) {
std::string fname = string_strip(buffer.substr(6));
add_text_file(fname);
continue;
} else if (string_starts_with(buffer, "/glob ")) {
std::error_code ec;
size_t count = 0;
auto curdir = std::filesystem::current_path();
std::string pattern = string_strip(buffer.substr(6));
std::filesystem::path rel_path;
auto startglob = pattern.find_first_of("![*?");
if (startglob != std::string::npos && startglob != 0) {
auto endpath = pattern.substr(0, startglob).find_last_of('/');
if (endpath != std::string::npos) {
std::string rel_pattern = pattern.substr(0, endpath);
#if !defined(_WIN32)
if (string_starts_with(rel_pattern, '~')) {
const char * home = std::getenv("HOME");
if (home && home[0]) {
rel_pattern = home + rel_pattern.substr(1);
}
}
#endif
rel_path = rel_pattern;
pattern.erase(0, endpath + 1);
curdir /= rel_path;
}
}
for (const auto & entry : std::filesystem::recursive_directory_iterator(curdir,
std::filesystem::directory_options::skip_permission_denied, ec)) {
if (!entry.is_regular_file()) {
continue;
}
std::string rel = std::filesystem::relative(entry.path(), curdir, ec).string();
if (ec) {
ec.clear();
continue;
}
std::replace(rel.begin(), rel.end(), '\\', '/');
if (!glob_match(pattern, rel)) {
continue;
}
if (!add_text_file((rel_path / rel).string())) {
continue;
}
if (++count >= FILE_GLOB_MAX_RESULTS) {
console::error("Maximum number of globbed files allowed (%zu) reached.\n", FILE_GLOB_MAX_RESULTS);
break;
}
}
continue;
} else {
// not a command
cur_msg += buffer;
}
// generate response
if (add_user_msg) {
ctx_cli.messages.push_back({
{"role", "user"},
{"content", cur_msg}
});
cur_msg.clear();
}
result_timings timings;
std::string assistant_content = ctx_cli.generate_completion(timings);
ctx_cli.messages.push_back({
{"role", "assistant"},
{"content", assistant_content}
});
console::log("\n");
if (params.show_timings) {
console::set_display(DISPLAY_TYPE_INFO);
console::log("\n");
console::log("[ Prompt: %.1f t/s | Generation: %.1f t/s ]\n", timings.prompt_per_second, timings.predicted_per_second);
console::set_display(DISPLAY_TYPE_RESET);
}
if (params.single_turn) {
break;
}
}
console::set_display(DISPLAY_TYPE_RESET);
console::log("\nExiting...\n");
ctx_cli.ctx_server.terminate();
inference_thread.join();
// bump the log level to display timings
common_log_set_verbosity_thold(LOG_LEVEL_INFO);
common_memory_breakdown_print(ctx_cli.ctx_server.get_llama_context());
return 0;
}
+1 -1
View File
@@ -42,7 +42,6 @@
#define KEY_N_HEAD "clip.%s.attention.head_count"
#define KEY_N_HEAD_KV "clip.%s.attention.head_count_kv"
#define KEY_LAYER_NORM_EPS "clip.%s.attention.layer_norm_epsilon"
#define KEY_FEATURE_LAYERS "clip.%s.feature_layer"
// vision-specific
#define KEY_VISION_PROJ_TYPE "clip.vision.projector_type" // for models with mixed modalities
@@ -55,6 +54,7 @@
#define KEY_PATCH_SIZE "clip.vision.patch_size"
#define KEY_IMAGE_MEAN "clip.vision.image_mean"
#define KEY_IMAGE_STD "clip.vision.image_std"
#define KEY_FEATURE_LAYER "clip.vision.feature_layer"
#define KEY_PROJ_SCALE_FACTOR "clip.vision.projector.scale_factor"
#define KEY_PROJ_SAMPLE_QUERY_SIDE "clip.vision.projector.query_side"
#define KEY_PROJ_SAMPLE_WINDOW_SIDE "clip.vision.projector.window_side"
+3 -3
View File
@@ -91,7 +91,7 @@ struct clip_hparams {
float eps = 1e-6;
float rope_theta = 0.0;
std::vector<int32_t> feature_layers;
std::vector<int32_t> vision_feature_layer;
int32_t attn_window_size = 0;
int32_t n_wa_pattern = 0;
std::unordered_set<int32_t> wa_layer_indexes; // explicit layer indexes that use full attention (for irregular patterns like YoutuVL)
@@ -165,8 +165,8 @@ struct clip_hparams {
return false;
}
bool is_feature_layer(int32_t layer) const {
return std::find(feature_layers.begin(), feature_layers.end(), layer) != feature_layers.end();
bool is_vision_feature_layer(int32_t layer) const {
return std::find(vision_feature_layer.begin(), vision_feature_layer.end(), layer) != vision_feature_layer.end();
}
};
+10 -9
View File
@@ -1264,10 +1264,12 @@ struct clip_model_loader {
}
}
// Load the vision/audio feature layer indices if they are explicitly provided
// Load the vision feature layer indices if they are explicitly provided;
// if multiple vision feature layers are present, the values will be concatenated
// to form the final visual features.
// NOTE: gguf conversions should standardize the values of the vision feature layer to
// be non-negative, since we use -1 to mark values as unset here.
get_arr_int(string_format(KEY_FEATURE_LAYERS, prefix), hparams.feature_layers, false);
get_arr_int(KEY_FEATURE_LAYER, hparams.vision_feature_layer, false);
// model-specific params
switch (model.proj_type) {
@@ -1649,7 +1651,6 @@ struct clip_model_loader {
get_u32(KEY_A_PROJ_WINDOW_SIZE, hparams.audio_proj_window_size);
get_u32(KEY_A_PROJ_DOWNSAMPLE_RATE, hparams.audio_proj_downsample_rate);
get_u32(KEY_A_PROJ_HEAD_COUNT, hparams.audio_proj_head_count);
// NOTE: feature layers loaded above in common path
} break;
case PROJECTOR_TYPE_JANUS_PRO:
{
@@ -1662,11 +1663,11 @@ struct clip_model_loader {
hparams.image_resize_algo = RESIZE_ALGO_BICUBIC_PILLOW;
hparams.image_resize_pad = PAD_CEIL;
// NOTE: feature_layers loaded in common path as optional
get_arr_int(KEY_FEATURE_LAYER, hparams.vision_feature_layer);
get_arr_int(KEY_PROJ_SPATIAL_OFFSETS, hparams.proj_spatial_offsets);
if (hparams.feature_layers.size() != hparams.proj_spatial_offsets.size()) {
throw std::runtime_error(string_format("%s: feature_layers.size() %d != proj_spatial_offsets.size() %d",
hparams.feature_layers.size(), hparams.proj_spatial_offsets.size()));
if (hparams.vision_feature_layer.size() != hparams.proj_spatial_offsets.size()) {
throw std::runtime_error(string_format("%s: vision_feature_layer.size() %d != proj_spatial_offsets.size() %d",
hparams.vision_feature_layer.size(), hparams.proj_spatial_offsets.size()));
}
get_u32(KEY_PROJ_SAMPLE_QUERY_SIDE, hparams.downsample_query_side);
@@ -2739,7 +2740,7 @@ struct clip_model_loader {
model.image_newline = get_tensor(TN_IMAGE_NEWLINE);
// Load separate layerwise and spatial projector tensors
const auto projector_count = hparams.feature_layers.size();
const auto projector_count = hparams.vision_feature_layer.size();
model.qf_proj_blocks.resize(projector_count);
for (size_t bid = 0; bid < projector_count; ++bid) {
auto & b = model.qf_proj_blocks[bid];
@@ -4387,7 +4388,7 @@ bool clip_image_batch_encode(clip_ctx * ctx, int n_threads, const clip_image_f32
// Stage 1b only uses block 0's permutations; future stages
// will upload all blocks.
for (size_t bid = 0; bid < hparams.feature_layers.size(); ++bid) {
for (size_t bid = 0; bid < hparams.vision_feature_layer.size(); ++bid) {
const std::string prefix = "g4v_blk" + std::to_string(bid) + "_";
upload(prefix + "win_idx", make_win_idx(image_side, window_side));
upload(prefix + "qwin_idx", make_win_idx(new_side, query_side));
+1 -35
View File
@@ -1,7 +1,5 @@
#include "models.h"
#include <algorithm>
ggml_cgraph * clip_graph_granite_speech::build() {
const int n_frames = img.nx();
const int context_size = hparams.audio_chunk_size;
@@ -13,10 +11,6 @@ ggml_cgraph * clip_graph_granite_speech::build() {
const int padded_len = num_blocks * context_size;
const int remainder = n_frames % context_size;
// Calculate projector input dimension based on feature layers
const int proj_input_dim = n_embd * (hparams.feature_layers.size() + 1);
const bool use_feature_concat = !hparams.feature_layers.empty();
ggml_tensor * attn_dists = ggml_new_tensor_1d(ctx0, GGML_TYPE_I32, context_size * context_size);
ggml_set_name(attn_dists, "attn_dists");
ggml_set_input(attn_dists);
@@ -37,15 +31,6 @@ ggml_cgraph * clip_graph_granite_speech::build() {
cur = ggml_add(ctx0, cur, model.inp_proj_b);
cb(cur, "inp_linear", -1);
// Capture layer 0 if requested (after input_linear)
ggml_tensor * concat_result = nullptr;
if (use_feature_concat) {
if (std::find(hparams.feature_layers.begin(), hparams.feature_layers.end(), 0) != hparams.feature_layers.end()) {
concat_result = cur;
cb(concat_result, "feature_layer_0", -1);
}
}
for (int il = 0; il < n_layer; il++) {
const auto & layer = model.layers[il];
auto * residual = cur;
@@ -183,18 +168,6 @@ ggml_cgraph * clip_graph_granite_speech::build() {
NORM_TYPE_NORMAL, eps, il);
cb(cur, "layer_out", il);
// Capture intermediate layer (il + 1) if requested
if (use_feature_concat) {
if (hparams.is_feature_layer(il + 1)) {
if (concat_result == nullptr) {
concat_result = cur;
} else {
concat_result = ggml_concat(ctx0, concat_result, cur, 0);
}
cb(concat_result, string_format("feature_layer_%d", il + 1).c_str(), il);
}
}
// CTC branch
if (il + 1 == ctc_layer) {
auto * mid = build_mm(model.ctc_out_w, cur);
@@ -207,13 +180,6 @@ ggml_cgraph * clip_graph_granite_speech::build() {
}
}
// Append final output to concatenated features if using feature concatenation
if (use_feature_concat && concat_result != nullptr) {
concat_result = ggml_concat(ctx0, concat_result, cur, 0);
cb(concat_result, "concat_final", -1);
cur = concat_result;
}
cb(cur, "encoder_out", -1);
// QFormer projector
@@ -231,7 +197,7 @@ ggml_cgraph * clip_graph_granite_speech::build() {
cur = ggml_pad(ctx0, cur, 0, padded_proj - n_frames, 0, 0);
}
ggml_tensor * enc_windows = ggml_reshape_3d(ctx0, cur, proj_input_dim, window_size, nblocks_proj);
ggml_tensor * enc_windows = ggml_reshape_3d(ctx0, cur, n_embd, window_size, nblocks_proj);
ggml_tensor * queries = build_norm(model.qf_proj_blocks[0].qf_proj_query,
model.qf_proj_blocks[0].qf_proj_norm_w, model.qf_proj_blocks[0].qf_proj_norm_b,
+2 -2
View File
@@ -304,14 +304,14 @@ ggml_cgraph * clip_graph_granite4_vision::build() {
}
// --- Stage 1b/1c: WindowQFormer blocks ---
const int projector_count = hparams.feature_layers.size();
const int projector_count = hparams.vision_feature_layer.size();
const float qformer_eps = 1e-12f;
ggml_tensor * mmproj = nullptr;
for (int bid = 0; bid < projector_count; ++bid) {
const auto & blk = model.qf_proj_blocks[bid];
int vlayer = hparams.feature_layers[bid];
int vlayer = hparams.vision_feature_layer[bid];
GGML_ASSERT(vlayer >= 0 && vlayer < n_layer);
ggml_tensor * h = layer_outs[vlayer];
+3 -3
View File
@@ -21,7 +21,7 @@ ggml_cgraph * clip_graph_llava::build() {
// If we set explicit vision feature layers, only go up to the deepest one
// NOTE: only used by granite-vision models for now
for (const auto & feature_layer : hparams.feature_layers) {
for (const auto & feature_layer : hparams.vision_feature_layer) {
if (feature_layer > deepest_feature_layer) {
deepest_feature_layer = feature_layer;
}
@@ -59,7 +59,7 @@ ggml_cgraph * clip_graph_llava::build() {
// If this is an embedding feature layer, save the output.
// NOTE: 0 index here refers to the input to the encoder.
if (hparams.is_feature_layer(il)) {
if (hparams.is_vision_feature_layer(il)) {
embedding_stack.push_back(cur);
}
@@ -134,7 +134,7 @@ ggml_cgraph * clip_graph_llava::build() {
// process vision feature layers (used by granite)
{
// final layer is a vision feature layer
if (hparams.is_feature_layer(max_feature_layer)) {
if (hparams.is_vision_feature_layer(max_feature_layer)) {
embedding_stack.push_back(inpL);
}
+88 -9
View File
@@ -518,14 +518,6 @@ size_t server_tokens::get_common_prefix(const server_tokens & b) const {
return max_idx; // all tokens are equal
}
common_chat_msg_spans server_tokens::find_message_spans(const common_chat_msg_delimiters & delims) const {
std::map<size_t, size_t> skips;
for (const auto & it : map_idx_to_media) {
skips[it.first] = mtmd_input_chunk_get_n_tokens(it.second.get());
}
return delims.split(tokens, skips);
}
bool server_tokens::validate(const struct llama_context * ctx) const {
const llama_model * model = llama_get_model(ctx);
const llama_vocab * vocab = llama_model_get_vocab(model);
@@ -1112,7 +1104,15 @@ json oaicompat_chat_params_parse(
llama_params["chat_parser"] = chat_params.parser;
}
llama_params["message_delimiters"] = chat_params.message_delimiters.to_json();
llama_params["message_spans"] = json::array();
for (const auto & span : chat_params.message_spans) {
llama_params["message_spans"].push_back({
{ "role", span.role },
{ "pos", span.pos },
{ "len", span.len },
});
}
// Reasoning budget: pass parameters through to sampling layer
{
@@ -1583,3 +1583,82 @@ server_tokens format_prompt_rerank(
return result;
}
//
// threadpool
//
server_threadpool::~server_threadpool() {
{
std::lock_guard<std::mutex> lock(mtx);
stop = true;
}
cv.notify_all();
for (auto & t : threads) t.join();
}
void server_threadpool::init(int n) {
// the caller (main thread) participates as a worker, so spawn n-1 threads
const int n_workers = std::max(1, n) - 1;
for (int i = 0; i < n_workers; i++) {
threads.emplace_back([this]() { run_worker(); });
}
}
void server_threadpool::run_worker() {
while (true) {
std::function<void()> task;
{
std::unique_lock<std::mutex> lock(mtx);
cv.wait(lock, [this]() { return stop || !tasks.empty(); });
if (stop && tasks.empty()) return;
task = std::move(tasks.front());
tasks.pop();
}
task();
{
std::lock_guard<std::mutex> lock(mtx);
pending--;
}
cv_done.notify_all();
}
}
void server_threadpool::enqueue(std::function<void()> fn) {
{
std::lock_guard<std::mutex> lock(mtx);
GGML_ASSERT(!stop);
tasks.push(std::move(fn));
pending++;
}
cv.notify_one();
}
void server_threadpool::wait_all() {
// the calling thread helps drain the queue until no tasks remain pending
while (true) {
std::function<void()> task;
{
std::lock_guard<std::mutex> lock(mtx);
if (pending == 0) {
return;
}
if (!tasks.empty()) {
task = std::move(tasks.front());
tasks.pop();
}
}
if (task) {
task();
{
std::lock_guard<std::mutex> lock(mtx);
pending--;
}
cv_done.notify_all();
} else {
// no task available right now, but some are still pending (being run by workers)
std::unique_lock<std::mutex> lock(mtx);
cv_done.wait(lock, [this]() { return pending == 0 || !tasks.empty(); });
}
}
}
+41 -3
View File
@@ -12,6 +12,11 @@
#include <string>
#include <vector>
#include <cinttypes>
#include <thread>
#include <mutex>
#include <condition_variable>
#include <queue>
#include <functional>
using json = nlohmann::ordered_json;
@@ -218,9 +223,6 @@ public:
size_t get_common_prefix(const server_tokens & b) const;
// split the tokens into message spans, skipping over media chunks
common_chat_msg_spans find_message_spans(const common_chat_msg_delimiters & delims) const;
// make sure all text tokens are within the vocab range
bool validate(const struct llama_context * ctx) const;
@@ -373,3 +375,39 @@ server_tokens format_prompt_rerank(
mtmd_context * mctx,
const std::string & query,
const std::string & doc);
//
// threadpool utils
// to be used for multi-threaded sampling
//
// the main thread participates as one of the pool's workers, so init(n)
// only spawns n-1 background threads (the caller is the nth)
struct server_threadpool {
std::vector<std::thread> threads;
std::queue<std::function<void()>> tasks;
std::mutex mtx;
std::condition_variable cv;
std::condition_variable cv_done;
int pending = 0;
bool stop = false;
~server_threadpool();
void init(int n);
template<typename T>
void run_all(std::vector<T> & tasks, std::function<void(T&)> handler) {
for (auto & item : tasks) {
enqueue([&handler, &item]() {
handler(item);
});
}
// the calling thread runs tasks too, until all are done
wait_all();
}
private:
void enqueue(std::function<void()> fn);
void wait_all();
void run_worker();
};
+142 -25
View File
@@ -866,6 +866,9 @@ public:
// note: chat_params must not be refreshed upon existing sleeping state
server_chat_params chat_params;
// threadpool for parallel sampling
server_threadpool threadpool;
server_state_callback_t callback_state = [](server_state, json) -> void {};
server_context_impl() {
@@ -1457,6 +1460,16 @@ private:
metrics.init();
// initialize threadpool
{
int threadpool_size = params_base.sampling_n_threads;
if (threadpool_size <= 0) {
threadpool_size = params_base.cpuparams.n_threads;
}
SRV_DBG("%s: initializing threadpool, size = %d\n", __func__, threadpool_size);
threadpool.init(threadpool_size);
}
if (params_base.cache_idle_slots) {
if (params_base.cache_ram_mib == 0) {
SRV_WRN("%s", "--cache-idle-slots requires --cache-ram, disabling\n");
@@ -3436,8 +3449,8 @@ private:
has_mtmd = true;
}
const auto & spans = slot.task->params.message_spans;
const auto last_user_pos = spans.last_user_message_pos();
const int32_t n_before_user = slot.task->params.n_before_user;
const bool n_before_user_known = n_before_user > 0;
// add prompt tokens for processing in the current batch
while (slot.prompt.n_tokens() < slot.task->n_tokens() && batch.size() < n_batch) {
@@ -3466,8 +3479,10 @@ private:
slot.n_prompt_tokens_processed++;
// stop the prompt batch exactly before a user message
if (spans.is_user_start(slot.prompt.n_tokens())) {
// stop the prompt batch exactly before the latest user input, so a checkpoint
// can be created after the previous messages
if (n_before_user_known &&
slot.prompt.n_tokens() == n_before_user) {
break;
}
@@ -3496,13 +3511,8 @@ private:
// the number of tokens added to the batch for the current slot
const auto n_tokens_cur = batch.size() - n_tokens_prev;
const auto n_tokens_start = slot.prompt.n_tokens() - n_tokens_cur;
const bool near_prompt_end = slot.task->n_tokens() < slot.prompt.n_tokens() + n_ubatch;
const bool is_user_start = spans.is_user_start(n_tokens_start);
const bool is_last_user_message = n_tokens_start == last_user_pos;
// entire prompt has been processed
if (slot.prompt.n_tokens() == slot.task->n_tokens()) {
slot.state = SLOT_STATE_DONE_PROMPT;
@@ -3517,9 +3527,8 @@ private:
slot.init_sampler();
} else {
// skip ordinary mid-prompt checkpoints, unless the batch starts a user
// message or we are near the end of the prompt
if (!is_user_start && !near_prompt_end) {
// skip ordinary mid-prompt checkpoints
if (!n_before_user_known && !near_prompt_end) {
do_checkpoint = false;
}
}
@@ -3527,6 +3536,29 @@ private:
const auto pos_min = llama_memory_seq_pos_min(llama_get_memory(ctx_tgt), slot.id);
const auto pos_max = llama_memory_seq_pos_max(llama_get_memory(ctx_tgt), slot.id);
// checkpoints are created before the current batch is decoded, so
// their token position is the batch start rather than the prompt end
const int32_t n_tokens_start = slot.prompt.n_tokens() - n_tokens_cur;
{
const bool is_on_user =
n_before_user_known &&
n_tokens_start == n_before_user;
const bool is_after_user =
n_before_user_known &&
n_tokens_start > n_before_user;
const bool is_allowed =
!n_before_user_known ||
is_on_user ||
(is_after_user && near_prompt_end);
if (do_checkpoint && !is_allowed) {
do_checkpoint = false;
}
}
// nothing to checkpoint yet
// TODO: is this check needed?
if (do_checkpoint && pos_min < 0) {
@@ -3536,8 +3568,8 @@ private:
// do not checkpoint after mtmd chunks
do_checkpoint = do_checkpoint && !has_mtmd;
// no need to create checkpoints that are too close together, unless it's the last user message
do_checkpoint = do_checkpoint && (slot.prompt.checkpoints.empty() || is_last_user_message || n_tokens_start > slot.prompt.checkpoints.back().n_tokens + params_base.checkpoint_min_step);
// no need to create checkpoints that are too close together
do_checkpoint = do_checkpoint && (slot.prompt.checkpoints.empty() || n_tokens_start > slot.prompt.checkpoints.back().n_tokens + params_base.checkpoint_min_step);
SLT_DBG(slot, "main/do_checkpoint = %s, pos_min = %d, pos_max = %d\n", do_checkpoint ? "yes" : "no", pos_min, pos_max);
// note: we create the checkpoint before calling llama_decode(), so the current batch is not
@@ -3680,6 +3712,12 @@ private:
return true;
}
struct sampling_task {
server_slot * slot = nullptr;
int32_t tok_idx = 0;
llama_token sampled_id = LLAMA_TOKEN_NULL; // result
};
void post_decode(int32_t n_batch_tokens, int32_t off, llama_batch & batch_view) {
// for checking if a given batch index is inside batch_view
auto is_inside_view = [&](int32_t idx) {
@@ -3701,7 +3739,13 @@ private:
slot.task->params.sampling.preserved_tokens.find(token) != slot.task->params.sampling.preserved_tokens.end();
};
std::vector<sampling_task> smpl_tasks;
smpl_tasks.resize(slots.size());
bool need_sampling = false;
iterate(slots, [&](server_slot & slot) {
auto & smpl_task = smpl_tasks[slot.id];
// optionally send prompt processing progress
if (slot.state == SLOT_STATE_PROCESSING_PROMPT || slot.state == SLOT_STATE_DONE_PROMPT) {
if (slot.task->params.stream && slot.task->params.return_progress) {
@@ -3746,15 +3790,35 @@ private:
return; // sample using speculative decoding
}
// shifted according to the current sub-batch
const int tok_idx = slot.i_batch - off;
// otherwise, we must sample the next token
// also shift batch idx according to the current sub-batch
smpl_task.slot = &slot;
smpl_task.tok_idx = slot.i_batch - off;
need_sampling = true;
});
llama_token id;
{
scoped_timer timer(t_sampl, n_sampl);
id = common_sampler_sample(slot.smpl.get(), slot.ctx_tgt, tok_idx);
// run multiple sampling tasks in parallel
GGML_ASSERT(smpl_tasks.size() == slots.size());
if (need_sampling) {
llama_synchronize(ctx_tgt);
threadpool.run_all<sampling_task>(smpl_tasks, [](sampling_task & task) {
if (task.slot) {
task.sampled_id = common_sampler_sample(task.slot->smpl.get(),
task.slot->ctx_tgt, task.tok_idx);
}
});
}
iterate(slots, [&](server_slot & slot) {
auto & smpl_task = smpl_tasks[slot.id];
if (!smpl_task.slot) {
return;
}
auto tok_idx = smpl_task.tok_idx;
auto id = smpl_task.sampled_id;
slot.i_batch = -1;
common_sampler_accept(slot.smpl.get(), id, true);
@@ -4036,6 +4100,54 @@ void server_context::set_state_callback(server_state_callback_t callback) {
});
}
// compute the number of tokens before the last user message in the prompt
static int32_t prompt_get_n_before_user(
const json & message_spans,
const std::string & prompt,
const std::vector<raw_buffer> & files,
const llama_vocab * vocab,
mtmd_context * mctx) {
int32_t result = -1;
int32_t byte_pos = -1;
for (const auto & span : message_spans) {
const std::string role = json_value(span, "role", std::string());
if (role == "user") {
byte_pos = json_value(span, "pos", -1);
}
}
if (byte_pos >= 0) {
GGML_ASSERT((size_t) byte_pos <= prompt.size());
const std::string prefix = prompt.substr(0, (size_t) byte_pos);
const std::string marker = get_media_marker();
size_t n_prefix_media = 0;
for (size_t pos = 0; (pos = prefix.find(marker, pos)) != std::string::npos; pos += marker.size()) {
n_prefix_media++;
}
GGML_ASSERT(n_prefix_media <= files.size());
if (mctx != nullptr && n_prefix_media > 0) {
// TODO: this makes a copy - avoid it
std::vector<raw_buffer> prefix_files(files.begin(), files.begin() + n_prefix_media);
result = (int32_t) process_mtmd_prompt(mctx, prefix, prefix_files).size();
} else {
result = (int32_t) tokenize_input_prompts(vocab, nullptr, prefix, true, true)[0].size();
}
SRV_TRC("message_spans: last user message: byte_pos=%d, media=%zu, n_before_user=%d\n",
byte_pos, n_prefix_media, result);
}
return result;
}
//
// server_routes
//
@@ -4083,10 +4195,6 @@ std::unique_ptr<server_res_generator> server_routes::handle_completions_impl(
// tasks.reserve(inputs.size()); // TODO: this is inaccurate due to child tasks
// message delimiters for checkpointing
auto delimiters = common_chat_msg_delimiters_parse(json_value(data, "message_delimiters", json::array()));
delimiters.tokenize(ctx_server.vocab);
for (size_t i = 0; i < inputs.size(); i++) {
server_task task = server_task(type);
@@ -4100,7 +4208,16 @@ std::unique_ptr<server_res_generator> server_routes::handle_completions_impl(
meta->logit_bias_eog,
data);
task.params.message_spans = task.tokens.find_message_spans(delimiters);
const auto message_spans = json_value(data, "message_spans", json::array());
if (prompt.is_string() && message_spans.is_array()) {
task.params.n_before_user =
prompt_get_n_before_user(
message_spans,
prompt.get<std::string>(),
files,
ctx_server.vocab,
ctx_server.mctx);
}
task.id_slot = json_value(data, "id_slot", -1);
+71 -7
View File
@@ -5,7 +5,6 @@
#include "build-info.h"
#include "preset.h"
#include "download.h"
#include "http.h"
#include <cpp-httplib/httplib.h> // TODO: remove this once we use HTTP client from download.h
#include <sheredom/subprocess.h>
@@ -26,7 +25,14 @@
#include <sstream>
#include <cstring>
#ifndef _WIN32
#ifdef _WIN32
#include <winsock2.h>
#include <windows.h>
#else
#include <sys/socket.h>
#include <netinet/in.h>
#include <arpa/inet.h>
#include <unistd.h>
extern char **environ;
#endif
@@ -218,7 +224,7 @@ void server_model_meta::update_caps() {
});
params.offline = true;
// params.skip_download = true; // TODO: ideally, we should validate the model here, but it takes too much time
common_params_handle_models(params, LLAMA_EXAMPLE_SERVER, {});
common_params_handle_models(params, LLAMA_EXAMPLE_SERVER);
if (params.mmproj.path.empty()) {
multimodal = { false, false };
} else {
@@ -698,6 +704,66 @@ std::optional<server_model_meta> server_models::get_meta(const std::string & nam
return std::nullopt;
}
static int get_free_port() {
#ifdef _WIN32
WSADATA wsaData;
if (WSAStartup(MAKEWORD(2, 2), &wsaData) != 0) {
return -1;
}
typedef SOCKET native_socket_t;
#define INVALID_SOCKET_VAL INVALID_SOCKET
#define CLOSE_SOCKET(s) closesocket(s)
#else
typedef int native_socket_t;
#define INVALID_SOCKET_VAL -1
#define CLOSE_SOCKET(s) close(s)
#endif
native_socket_t sock = socket(AF_INET, SOCK_STREAM, 0);
if (sock == INVALID_SOCKET_VAL) {
#ifdef _WIN32
WSACleanup();
#endif
return -1;
}
struct sockaddr_in serv_addr;
std::memset(&serv_addr, 0, sizeof(serv_addr));
serv_addr.sin_family = AF_INET;
serv_addr.sin_addr.s_addr = htonl(INADDR_ANY);
serv_addr.sin_port = htons(0);
if (bind(sock, (struct sockaddr*)&serv_addr, sizeof(serv_addr)) != 0) {
CLOSE_SOCKET(sock);
#ifdef _WIN32
WSACleanup();
#endif
return -1;
}
#ifdef _WIN32
int namelen = sizeof(serv_addr);
#else
socklen_t namelen = sizeof(serv_addr);
#endif
if (getsockname(sock, (struct sockaddr*)&serv_addr, &namelen) != 0) {
CLOSE_SOCKET(sock);
#ifdef _WIN32
WSACleanup();
#endif
return -1;
}
int port = ntohs(serv_addr.sin_port);
CLOSE_SOCKET(sock);
#ifdef _WIN32
WSACleanup();
#endif
return port;
}
// helper to convert vector<string> to char **
// pointers are only valid as long as the original vector is valid
static std::vector<char *> to_char_ptr_array(const std::vector<std::string> & vec) {
@@ -801,7 +867,7 @@ void server_models::load(const std::string & name, const load_options & opts) {
// prepare new instance info
instance_t inst;
inst.meta = meta;
inst.meta.port = common_http_get_free_port();
inst.meta.port = get_free_port();
inst.meta.status = SERVER_MODEL_STATUS_LOADING;
inst.meta.loaded_info = json{};
inst.meta.last_used = ggml_time_ms();
@@ -1327,9 +1393,7 @@ struct server_download_state : public common_download_callback {
bool run(common_params & params) {
try {
common_params_handle_models_params p;
p.callback = this;
common_params_handle_models(params, LLAMA_EXAMPLE_SERVER, p);
common_params_handle_models(params, LLAMA_EXAMPLE_SERVER, this);
is_ok = true;
} catch (const std::exception & e) {
auto model_name = params.model.get_name();
+3 -6
View File
@@ -591,11 +591,10 @@ json server_task_result_cmpl_final::to_json_oaicompat_resp() {
for (const common_chat_tool_call & tool_call : oaicompat_msg.tool_calls) {
output.push_back(json {
{"id", "fc_" + tool_call.id},
{"type", "function_call"},
{"status", "completed"},
{"arguments", tool_call.arguments},
{"call_id", "call_" + tool_call.id},
{"call_id", "fc_" + tool_call.id},
{"name", tool_call.name},
});
}
@@ -691,11 +690,10 @@ json server_task_result_cmpl_final::to_json_oaicompat_resp_stream() {
for (const common_chat_tool_call & tool_call : oaicompat_msg.tool_calls) {
const json output_item = {
{"id", "fc_" + tool_call.id},
{"type", "function_call"},
{"status", "completed"},
{"arguments", tool_call.arguments},
{"call_id", "call_" + tool_call.id},
{"call_id", "fc_" + tool_call.id},
{"name", tool_call.name}
};
server_sent_events.push_back(json {
@@ -1279,9 +1277,8 @@ json server_task_result_cmpl_partial::to_json_oaicompat_resp() {
{"data", json {
{"type", "response.output_item.added"},
{"item", json {
{"id", "fc_" + diff.tool_call_delta.id},
{"arguments", ""},
{"call_id", "call_" + diff.tool_call_delta.id},
{"call_id", "fc_" + diff.tool_call_delta.id},
{"name", diff.tool_call_delta.name},
{"type", "function_call"},
{"status", "in_progress"},
+3 -3
View File
@@ -62,6 +62,9 @@ struct task_params {
int32_t n_cache_reuse = 0; // min chunk size to attempt reusing from the cache via KV shifting (0 = disabled)
// number of prompt tokens before the latest user message
int32_t n_before_user = -1;
int64_t t_max_prompt_ms = -1; // TODO: implement
int64_t t_max_predict_ms = -1; // if positive, limit the generation phase to this time limit
@@ -89,9 +92,6 @@ struct task_params {
// per-request parameters for chat parsing
common_chat_parser_params chat_parser_params;
// message spans for checkpointing
common_chat_msg_spans message_spans;
// Embeddings
int32_t embd_normalize = 2; // (-1=none, 0=max absolute int16, 1=taxicab, 2=Euclidean/L2, >2=p-norm)
+16 -46
View File
@@ -35,19 +35,6 @@ static inline void signal_handler(int signal) {
shutdown_handler(signal);
}
// satisfies -Wmissing-declarations (used by llama command)
int llama_server(int argc, char ** argv);
// to be used via CLI (argc / argv are used by router mode only)
int llama_server(common_params & params, int argc, char ** argv);
void llama_server_terminate();
void llama_server_terminate() {
if (shutdown_handler) {
shutdown_handler(0);
}
}
// wrapper function that handles exceptions and logs errors
// this is to make sure handler_t never throws exceptions; instead, it returns an error response
static server_http_context::handler_t ex_wrapper(server_http_context::handler_t func) {
@@ -84,6 +71,9 @@ static server_http_context::handler_t ex_wrapper(server_http_context::handler_t
};
}
// satisfies -Wmissing-declarations
int llama_server(int argc, char ** argv);
int llama_server(int argc, char ** argv) {
std::setlocale(LC_NUMERIC, "C");
@@ -99,23 +89,6 @@ int llama_server(int argc, char ** argv) {
llama_backend_init();
llama_numa_init(params.numa);
return llama_server(params, argc, argv);
}
int llama_server(common_params & params, int argc, char ** argv) {
bool is_run_by_cli = (argv == nullptr);
// note: router mode also accepts -hf remote-preset, so we need to check that first
if (!is_run_by_cli && !params.model.hf_repo.empty()) {
try {
common_params_handle_models_params handle_params;
handle_params.preset_only = true;
common_params_handle_models(params, LLAMA_EXAMPLE_SERVER, handle_params);
} catch (const std::exception & e) {
// ignored for now
}
}
// router server never loads a model and must not touch the GPU
const bool is_router_server = params.model.path.empty()
&& params.model.hf_repo.empty();
@@ -288,10 +261,9 @@ int llama_server(common_params & params, int argc, char ** argv) {
if (child.is_child() && child.get_mode() == SERVER_CHILD_MODE_DOWNLOAD) {
return child.run_download(params);
} else if (!is_router_server && !is_run_by_cli) {
} else if (!is_router_server) {
// single-model mode (NOT spawned by router)
// if this is invoked by CLI, model downloading should be already handled
common_params_handle_models(params, LLAMA_EXAMPLE_SERVER, {});
common_params_handle_models(params, LLAMA_EXAMPLE_SERVER);
}
//
@@ -373,22 +345,20 @@ int llama_server(common_params & params, int argc, char ** argv) {
};
}
// register signal handler if not running by CLI
if (!is_run_by_cli) {
// TODO: refactor in common/console
#if defined (__unix__) || (defined (__APPLE__) && defined (__MACH__))
struct sigaction sigint_action;
sigint_action.sa_handler = signal_handler;
sigemptyset (&sigint_action.sa_mask);
sigint_action.sa_flags = 0;
sigaction(SIGINT, &sigint_action, NULL);
sigaction(SIGTERM, &sigint_action, NULL);
struct sigaction sigint_action;
sigint_action.sa_handler = signal_handler;
sigemptyset (&sigint_action.sa_mask);
sigint_action.sa_flags = 0;
sigaction(SIGINT, &sigint_action, NULL);
sigaction(SIGTERM, &sigint_action, NULL);
#elif defined (_WIN32)
auto console_ctrl_handler = +[](DWORD ctrl_type) -> BOOL {
return (ctrl_type == CTRL_C_EVENT) ? (signal_handler(SIGINT), true) : false;
};
SetConsoleCtrlHandler(reinterpret_cast<PHANDLER_ROUTINE>(console_ctrl_handler), true);
auto console_ctrl_handler = +[](DWORD ctrl_type) -> BOOL {
return (ctrl_type == CTRL_C_EVENT) ? (signal_handler(SIGINT), true) : false;
};
SetConsoleCtrlHandler(reinterpret_cast<PHANDLER_ROUTINE>(console_ctrl_handler), true);
#endif
}
if (is_router_server) {
SRV_INF("router server is listening on %s\n", ctx_http.listening_address.c_str());
-19
View File
@@ -256,25 +256,6 @@ def test_router_reload_models():
os.remove(preset_path)
def test_router_remote_preset():
global server
server.model_hf_repo = "ggml-org/test-preset-ci"
server.model_hf_file = None
server.offline = False
server.start()
# Should see preset models in GET /models
res = server.make_request("GET", "/models")
assert res.status_code == 200
ids = {item["id"] for item in res.body.get("data", [])}
assert "tinygemma3-preset" in ids
assert "stories260K-test" in ids
# Should be able to load a preset model
model_id = "tinygemma3-preset"
_load_model_and_wait(model_id)
MODEL_DOWNLOAD_ID = "ggml-org/test-model-router-download:F16"
MODEL_DOWNLOAD_TIMEOUT = 30
+1 -9
View File
@@ -545,8 +545,7 @@ class ModelsStore {
* 1. Model from active conversation's last assistant response (if loaded)
* 2. Model from active conversation's last assistant response (if not loaded)
* 3. First loaded model (not from active conversation)
* 4. A favorite model
* 5. First available model
* 4. First available model
*/
async ensureFirstModelSelected(): Promise<void> {
if (this.selectedModelName) return;
@@ -575,13 +574,6 @@ class ModelsStore {
return;
}
// Try loading a favorite model
const favorite = this.favoriteModelIds.values().next()?.value
if (favorite) {
await this.selectModelById(favorite);
return;
}
// Fall back to the first available model
await this.selectModelById(availableModels[0].id);
}