mirror of
https://github.com/ggml-org/llama.cpp.git
synced 2026-06-24 20:49:45 +02:00
Compare commits
4 Commits
| Author | SHA1 | Date | |
|---|---|---|---|
| 095058ca19 | |||
| c62fdd5fd0 | |||
| 41ed530be2 | |||
| fe03cce8db |
+1
-1
@@ -10,7 +10,7 @@
|
||||
# ggml-org/ggml-rpc : rgerganov
|
||||
# ggml-org/ggml-sycl : arthw
|
||||
# ggml-org/ggml-vulkan : 0cc4m, jeffbolznv
|
||||
# ggml-org/ggml-webgpu : reeselevine, yomaytk
|
||||
# ggml-org/ggml-webgpu : reeselevine
|
||||
# ggml-org/ggml-zdnn : taronaeo
|
||||
# ggml-org/llama-common : ggerganov, aldehir, angt, danbev, ngxson, pwilkin
|
||||
# ggml-org/llama-mtmd : ngxson
|
||||
|
||||
+17
-16
@@ -301,8 +301,6 @@ static handle_model_result common_params_handle_model(struct common_params_model
|
||||
const common_download_opts & opts) {
|
||||
handle_model_result result;
|
||||
|
||||
// TODO @ngxson : refactor this into a new common_model_download_context
|
||||
|
||||
if (!model.docker_repo.empty()) {
|
||||
model.path = common_docker_resolve_model(model.docker_repo);
|
||||
} else if (!model.hf_repo.empty()) {
|
||||
@@ -398,7 +396,7 @@ static bool parse_bool_value(const std::string & value) {
|
||||
// CLI argument parsing functions
|
||||
//
|
||||
|
||||
bool common_params_handle_models(common_params & params, llama_example curr_ex, const common_params_handle_models_params & handle_params) {
|
||||
bool common_params_handle_models(common_params & params, llama_example curr_ex, common_download_callback * callback) {
|
||||
const bool spec_type_draft_mtp = std::find(params.speculative.types.begin(),
|
||||
params.speculative.types.end(),
|
||||
COMMON_SPECULATIVE_TYPE_DRAFT_MTP) != params.speculative.types.end();
|
||||
@@ -409,10 +407,9 @@ bool common_params_handle_models(common_params & params, llama_example curr_ex,
|
||||
opts.skip_download = params.skip_download;
|
||||
opts.download_mtp = spec_type_draft_mtp;
|
||||
opts.download_mmproj = !params.no_mmproj && params.mmproj.path.empty() && params.mmproj.url.empty();
|
||||
opts.preset_only = handle_params.preset_only;
|
||||
|
||||
if (handle_params.callback) {
|
||||
opts.callback = handle_params.callback;
|
||||
if (callback) {
|
||||
opts.callback = callback;
|
||||
}
|
||||
|
||||
// sub-models (draft, mmproj, vocoder) are explicitly specified by the user,
|
||||
@@ -599,12 +596,13 @@ static bool common_params_parse_ex(int argc, char ** argv, common_params_context
|
||||
|
||||
if (!skip_model_download) {
|
||||
// handle model and download
|
||||
common_params_handle_models(params, ctx_arg.ex, {});
|
||||
common_params_handle_models(params, ctx_arg.ex);
|
||||
|
||||
// model is required (except for server)
|
||||
// TODO @ngxson : maybe show a list of available models in CLI in this case
|
||||
bool can_skip_model = params.usage || params.completion || !params.server_base.empty();
|
||||
if (!can_skip_model && params.model.path.empty()) {
|
||||
if (params.model.path.empty()
|
||||
&& !params.usage
|
||||
&& !params.completion) {
|
||||
throw std::invalid_argument("error: --model is required\n");
|
||||
}
|
||||
}
|
||||
@@ -1118,13 +1116,6 @@ common_params_context common_params_parser_init(common_params & params, llama_ex
|
||||
params.completion = true;
|
||||
}
|
||||
));
|
||||
add_opt(common_arg(
|
||||
{"--server-base"}, "URL",
|
||||
string_format("connect to this server instead of starting a new one, example: 'http://localhost:8080' (default: none)"),
|
||||
[](common_params & params, const std::string & value) {
|
||||
params.server_base = value;
|
||||
}
|
||||
).set_examples({LLAMA_EXAMPLE_CLI}));
|
||||
add_opt(common_arg(
|
||||
{"--verbose-prompt"},
|
||||
string_format("print a verbose prompt before generation (default: %s)", params.verbose_prompt ? "true" : "false"),
|
||||
@@ -1177,6 +1168,16 @@ common_params_context common_params_parser_init(common_params & params, llama_ex
|
||||
}
|
||||
}
|
||||
));
|
||||
add_opt(common_arg(
|
||||
{"--threads-sampling"}, "N",
|
||||
"number of threads to use during sampling (default: same as --threads)",
|
||||
[](common_params & params, int value) {
|
||||
params.sampling_n_threads = value;
|
||||
if (params.sampling_n_threads <= 0) {
|
||||
params.sampling_n_threads = std::thread::hardware_concurrency();
|
||||
}
|
||||
}
|
||||
).set_examples({LLAMA_EXAMPLE_SERVER}));
|
||||
add_opt(common_arg(
|
||||
{"-C", "--cpu-mask"}, "M",
|
||||
"CPU affinity mask: arbitrarily long hex. Complements cpu-range (default: \"\")",
|
||||
|
||||
+1
-6
@@ -130,11 +130,6 @@ bool common_params_to_map(int argc, char ** argv, llama_example ex, std::map<com
|
||||
// see: https://github.com/ggml-org/llama.cpp/issues/18163
|
||||
void common_params_add_preset_options(std::vector<common_arg> & args);
|
||||
|
||||
struct common_params_handle_models_params {
|
||||
common_download_callback * callback = nullptr;
|
||||
bool preset_only = false; // if true, only check & download remote preset (for router mode)
|
||||
};
|
||||
|
||||
// populate model paths (main model, mmproj, etc) from -hf if necessary
|
||||
// return true if the model is ready to use
|
||||
// throw an exception if there is an error that prevents the model from being used (e.g. network error, model not found, etc)
|
||||
@@ -142,7 +137,7 @@ struct common_params_handle_models_params {
|
||||
bool common_params_handle_models(
|
||||
common_params & params,
|
||||
llama_example curr_ex,
|
||||
const common_params_handle_models_params & handle_params);
|
||||
common_download_callback * callback = nullptr);
|
||||
|
||||
// initialize argument parser context - used by test-arg-parser and preset
|
||||
common_params_context common_params_parser_init(common_params & params, llama_example ex, void(*print_usage)(int, char **) = nullptr);
|
||||
|
||||
+53
-103
@@ -90,93 +90,41 @@ std::string common_chat_msg::render_content(const std::string & delimiter) const
|
||||
return text;
|
||||
}
|
||||
|
||||
common_chat_role common_chat_role_from_string(const std::string & role) {
|
||||
if (role == "system") { return COMMON_CHAT_ROLE_SYSTEM; }
|
||||
if (role == "assistant") { return COMMON_CHAT_ROLE_ASSISTANT; }
|
||||
if (role == "user") { return COMMON_CHAT_ROLE_USER; }
|
||||
if (role == "tool") { return COMMON_CHAT_ROLE_TOOL; }
|
||||
return COMMON_CHAT_ROLE_UNKNOWN;
|
||||
}
|
||||
|
||||
const char * common_chat_role_to_string(common_chat_role role) {
|
||||
switch (role) {
|
||||
case COMMON_CHAT_ROLE_SYSTEM: return "system";
|
||||
case COMMON_CHAT_ROLE_ASSISTANT: return "assistant";
|
||||
case COMMON_CHAT_ROLE_USER: return "user";
|
||||
case COMMON_CHAT_ROLE_TOOL: return "tool";
|
||||
case COMMON_CHAT_ROLE_UNKNOWN: return "";
|
||||
}
|
||||
return "";
|
||||
}
|
||||
|
||||
json common_chat_msg_delimiters::to_json() const {
|
||||
json result = json::array();
|
||||
for (const auto & d : delimiters) {
|
||||
result.push_back({
|
||||
{ "role", common_chat_role_to_string(d.role) },
|
||||
{ "delimiter", d.delimiter },
|
||||
});
|
||||
}
|
||||
return result;
|
||||
}
|
||||
|
||||
common_chat_msg_delimiters common_chat_msg_delimiters_parse(const json & delimiters) {
|
||||
common_chat_msg_delimiters result;
|
||||
|
||||
if (!delimiters.is_array()) {
|
||||
return result;
|
||||
std::vector<common_chat_msg_span> common_chat_split_by_role(const std::string & prompt, const std::vector<common_chat_msg_delimiter> & delims) {
|
||||
if (delims.empty() || prompt.empty()) {
|
||||
return {};
|
||||
}
|
||||
|
||||
result.delimiters.reserve(delimiters.size());
|
||||
for (const auto & d : delimiters) {
|
||||
if (!d.is_object()) {
|
||||
continue;
|
||||
auto parser = build_peg_parser([&](common_peg_parser_builder & p) {
|
||||
std::vector<std::string> all_delims;
|
||||
std::vector<common_peg_parser> tagged_messages;
|
||||
|
||||
all_delims.reserve(delims.size());
|
||||
tagged_messages.reserve(delims.size());
|
||||
for (const auto & d : delims) {
|
||||
all_delims.push_back(d.delimiter);
|
||||
}
|
||||
result.delimiters.push_back({
|
||||
common_chat_role_from_string(d.value("role", std::string())),
|
||||
d.value("delimiter", std::string()),
|
||||
});
|
||||
}
|
||||
|
||||
return result;
|
||||
}
|
||||
|
||||
void common_chat_msg_delimiters::tokenize(const llama_vocab * vocab) {
|
||||
for (auto & d : delimiters) {
|
||||
d.tokens = common_tokenize(vocab, d.delimiter, false, true);
|
||||
}
|
||||
}
|
||||
|
||||
common_chat_msg_spans common_chat_msg_delimiters::split(const llama_tokens & tokens, const std::map<size_t, size_t> & skips) const {
|
||||
std::vector<std::pair<common_chat_role, size_t>> matches;
|
||||
|
||||
auto skip = skips.begin();
|
||||
for (size_t i = 0; i < tokens.size();) {
|
||||
if (skip != skips.end() && i == skip->first) {
|
||||
i += skip->second;
|
||||
++skip;
|
||||
continue;
|
||||
auto any_delim = p.until_one_of(all_delims);
|
||||
for (const auto & d : delims) {
|
||||
tagged_messages.push_back(p.tag(d.role, p.literal(d.delimiter) + any_delim));
|
||||
}
|
||||
for (const auto & d : delimiters) {
|
||||
if (i + d.tokens.size() > tokens.size()) {
|
||||
continue;
|
||||
}
|
||||
if (std::equal(d.tokens.begin(), d.tokens.end(), tokens.begin() + i)) {
|
||||
matches.emplace_back(d.role, i);
|
||||
break;
|
||||
}
|
||||
|
||||
return any_delim + p.zero_or_more(p.choice(tagged_messages)) + p.end();
|
||||
});
|
||||
|
||||
common_peg_parse_context ctx(prompt);
|
||||
const auto result = parser.parse(ctx);
|
||||
if (!result.success()) {
|
||||
return {};
|
||||
}
|
||||
|
||||
std::vector<common_chat_msg_span> spans;
|
||||
ctx.ast.visit(result, [&](const common_peg_ast_node & node) {
|
||||
if (!node.tag.empty()) {
|
||||
spans.push_back({ node.tag, node.start, node.end - node.start });
|
||||
}
|
||||
i++;
|
||||
}
|
||||
|
||||
matches.emplace_back(COMMON_CHAT_ROLE_UNKNOWN, tokens.size());
|
||||
|
||||
common_chat_msg_spans spans;
|
||||
for (size_t i = 0; i + 1 < matches.size(); i++) {
|
||||
const auto & curr = matches[i];
|
||||
const auto & next = matches[i + 1];
|
||||
spans.add(curr.first, curr.second, next.second - curr.second);
|
||||
}
|
||||
});
|
||||
|
||||
return spans;
|
||||
}
|
||||
@@ -1133,13 +1081,13 @@ static common_chat_params common_chat_params_init_gpt_oss(const common_chat_temp
|
||||
|
||||
data.prompt = prompt;
|
||||
data.generation_prompt = common_chat_template_generation_prompt_impl(tmpl, inputs, /* messages_override= */ adjusted_messages);
|
||||
data.message_delimiters = {
|
||||
{ COMMON_CHAT_ROLE_ASSISTANT, "<|start|>assistant" },
|
||||
{ COMMON_CHAT_ROLE_USER, "<|start|>user" },
|
||||
{ COMMON_CHAT_ROLE_SYSTEM, "<|start|>developer" },
|
||||
{ COMMON_CHAT_ROLE_SYSTEM, "<|start|>system" },
|
||||
{ COMMON_CHAT_ROLE_TOOL, "<|start|>functions" },
|
||||
};
|
||||
data.message_spans = common_chat_split_by_role(prompt, {
|
||||
{ "assistant", "<|start|>assistant" },
|
||||
{ "user", "<|start|>user" },
|
||||
{ "system", "<|start|>developer" },
|
||||
{ "system", "<|start|>system" },
|
||||
{ "tool", "<|start|>functions" },
|
||||
});
|
||||
|
||||
data.format = COMMON_CHAT_FORMAT_PEG_NATIVE;
|
||||
data.supports_thinking = true;
|
||||
@@ -1280,10 +1228,10 @@ static common_chat_params common_chat_params_init_gemma4(const common_chat_templ
|
||||
data.prompt += data.generation_prompt;
|
||||
}
|
||||
|
||||
data.message_delimiters = {
|
||||
{ COMMON_CHAT_ROLE_USER, "<|turn>user" },
|
||||
{ COMMON_CHAT_ROLE_ASSISTANT, "<|turn>model" },
|
||||
};
|
||||
data.message_spans = common_chat_split_by_role(data.prompt, {
|
||||
{ "user", "<|turn>user\n" },
|
||||
{ "assistant", "<|turn>model\n" },
|
||||
});
|
||||
|
||||
data.format = COMMON_CHAT_FORMAT_PEG_GEMMA4;
|
||||
data.supports_thinking = true;
|
||||
@@ -2082,15 +2030,15 @@ static common_chat_params common_chat_params_init_cohere2moe(const common_chat_t
|
||||
RESULT_START, RESULT_END,
|
||||
};
|
||||
|
||||
// Declare per-role message delimiters. Tool results are rendered with the
|
||||
// Split the rendered prompt into per-role message spans. Tool results are rendered with the
|
||||
// system token followed by <|START_TOOL_RESULT|>, so the "tool" delimiter must be listed before
|
||||
// the plain "system" one (it is a strict superset, and the role split tries delimiters in order).
|
||||
data.message_delimiters = {
|
||||
{ COMMON_CHAT_ROLE_ASSISTANT, GEN_PREFIX },
|
||||
{ COMMON_CHAT_ROLE_USER, TURN_START + USER },
|
||||
{ COMMON_CHAT_ROLE_TOOL, TURN_START + SYSTEM + RESULT_START },
|
||||
{ COMMON_CHAT_ROLE_SYSTEM, TURN_START + SYSTEM },
|
||||
};
|
||||
data.message_spans = common_chat_split_by_role(data.prompt, {
|
||||
{ "assistant", GEN_PREFIX },
|
||||
{ "user", TURN_START + USER },
|
||||
{ "tool", TURN_START + SYSTEM + RESULT_START },
|
||||
{ "system", TURN_START + SYSTEM },
|
||||
});
|
||||
|
||||
auto has_tools = inputs.tools.is_array() && !inputs.tools.empty();
|
||||
auto extract_reasoning = inputs.reasoning_format != COMMON_REASONING_FORMAT_NONE;
|
||||
@@ -2578,15 +2526,17 @@ static common_chat_params common_chat_templates_apply_jinja(const struct common_
|
||||
autoparser.analyze_template(tmpl);
|
||||
auto auto_params = autoparser::peg_generator::generate_parser(tmpl, params, autoparser);
|
||||
|
||||
common_chat_msg_delimiters delimiters;
|
||||
std::vector<common_chat_msg_delimiter> delimiters;
|
||||
if (!autoparser.assistant_start.empty()) {
|
||||
delimiters.add(COMMON_CHAT_ROLE_ASSISTANT, autoparser.assistant_start);
|
||||
delimiters.push_back({ "assistant", autoparser.assistant_start });
|
||||
}
|
||||
if (!autoparser.user_start.empty()) {
|
||||
delimiters.add(COMMON_CHAT_ROLE_USER, autoparser.user_start);
|
||||
delimiters.push_back({ "user", autoparser.user_start });
|
||||
}
|
||||
|
||||
auto_params.message_delimiters = std::move(delimiters);
|
||||
if (!delimiters.empty()) {
|
||||
auto_params.message_spans = common_chat_split_by_role(auto_params.prompt, delimiters);
|
||||
}
|
||||
|
||||
auto_params.supports_thinking = autoparser.reasoning.mode != autoparser::reasoning_mode::NONE;
|
||||
if (auto_params.supports_thinking) {
|
||||
|
||||
+6
-65
@@ -143,75 +143,15 @@ struct common_chat_msg_diff {
|
||||
}
|
||||
};
|
||||
|
||||
enum common_chat_role {
|
||||
COMMON_CHAT_ROLE_UNKNOWN,
|
||||
COMMON_CHAT_ROLE_SYSTEM,
|
||||
COMMON_CHAT_ROLE_ASSISTANT,
|
||||
COMMON_CHAT_ROLE_USER,
|
||||
COMMON_CHAT_ROLE_TOOL
|
||||
};
|
||||
|
||||
common_chat_role common_chat_role_from_string(const std::string & role);
|
||||
const char * common_chat_role_to_string(common_chat_role role);
|
||||
|
||||
struct common_chat_msg_span {
|
||||
common_chat_role role = COMMON_CHAT_ROLE_UNKNOWN;
|
||||
std::string role;
|
||||
std::size_t pos = 0;
|
||||
std::size_t len = 0;
|
||||
|
||||
bool valid() const {
|
||||
return role != COMMON_CHAT_ROLE_UNKNOWN;
|
||||
}
|
||||
};
|
||||
|
||||
struct common_chat_msg_spans {
|
||||
std::vector<common_chat_msg_span> spans;
|
||||
|
||||
void add(common_chat_role role, size_t pos, size_t len) {
|
||||
spans.push_back({ role, pos, len });
|
||||
}
|
||||
|
||||
bool is_user_start(int32_t pos) const {
|
||||
for (auto it = spans.begin(); it != spans.end(); ++it) {
|
||||
if (it->role == COMMON_CHAT_ROLE_USER && pos == (int32_t) it->pos) {
|
||||
return true;
|
||||
}
|
||||
}
|
||||
return false;
|
||||
}
|
||||
|
||||
int32_t last_user_message_pos() const {
|
||||
for (auto it = spans.rbegin(); it != spans.rend(); ++it) {
|
||||
if (it->role == COMMON_CHAT_ROLE_USER) {
|
||||
return (int32_t) it->pos;
|
||||
}
|
||||
}
|
||||
return -1;
|
||||
}
|
||||
};
|
||||
|
||||
struct common_chat_msg_delimiter {
|
||||
common_chat_role role = COMMON_CHAT_ROLE_UNKNOWN;
|
||||
std::string delimiter;
|
||||
llama_tokens tokens = {};
|
||||
};
|
||||
|
||||
struct common_chat_msg_delimiters {
|
||||
std::vector<common_chat_msg_delimiter> delimiters;
|
||||
|
||||
common_chat_msg_delimiters() = default;
|
||||
common_chat_msg_delimiters(std::initializer_list<common_chat_msg_delimiter> delims) : delimiters(delims) {}
|
||||
|
||||
void add(common_chat_role role, const std::string & delimiter) {
|
||||
delimiters.push_back({ role, delimiter });
|
||||
}
|
||||
|
||||
void tokenize(const llama_vocab * vocab);
|
||||
|
||||
// split tokens into message spans. skips maps a start index to a length of a region to jump over without matching
|
||||
common_chat_msg_spans split(const llama_tokens & tokens, const std::map<size_t, size_t> & skips = {}) const;
|
||||
|
||||
nlohmann::ordered_json to_json() const;
|
||||
std::string role;
|
||||
std::string delimiter;
|
||||
};
|
||||
|
||||
struct common_chat_tool {
|
||||
@@ -279,7 +219,7 @@ struct common_chat_params {
|
||||
std::vector<std::string> preserved_tokens;
|
||||
std::vector<std::string> additional_stops;
|
||||
std::string parser;
|
||||
common_chat_msg_delimiters message_delimiters;
|
||||
std::vector<common_chat_msg_span> message_spans;
|
||||
};
|
||||
|
||||
// per-message parsing syntax
|
||||
@@ -385,4 +325,5 @@ struct common_chat_prompt_preset {
|
||||
|
||||
common_chat_prompt_preset common_chat_get_asr_prompt(const common_chat_templates * chat_templates);
|
||||
|
||||
common_chat_msg_delimiters common_chat_msg_delimiters_parse(const nlohmann::ordered_json & delimiters);
|
||||
std::vector<common_chat_msg_span> common_chat_split_by_role(const std::string & prompt, const std::vector<common_chat_msg_delimiter> & delims);
|
||||
|
||||
|
||||
+3
-4
@@ -471,6 +471,8 @@ struct common_params {
|
||||
common_cpu_params cpuparams;
|
||||
common_cpu_params cpuparams_batch;
|
||||
|
||||
int sampling_n_threads = -1; // number of threads for sampling, used by server
|
||||
|
||||
ggml_backend_sched_eval_callback cb_eval = nullptr;
|
||||
void * cb_eval_user_data = nullptr;
|
||||
|
||||
@@ -609,7 +611,7 @@ struct common_params {
|
||||
bool cache_prompt = true; // whether to enable prompt caching
|
||||
bool cache_idle_slots = true; // save and clear idle slots upon starting a new task
|
||||
int32_t n_ctx_checkpoints = 32; // max number of context checkpoints per slot
|
||||
int32_t checkpoint_min_step = 8192; // minimum spacing between context checkpoints
|
||||
int32_t checkpoint_min_step = 256; // minimum spacing between context checkpoints
|
||||
int32_t cache_ram_mib = 8192; // -1 = no limit, 0 - disable, 1 = 1 MiB, etc.
|
||||
|
||||
std::string hostname = "127.0.0.1";
|
||||
@@ -631,9 +633,6 @@ struct common_params {
|
||||
|
||||
std::map<std::string, std::string> default_template_kwargs;
|
||||
|
||||
// CLI params
|
||||
std::string server_base; // if set, connect to this server instead of starting a new one
|
||||
|
||||
// UI configs
|
||||
bool ui = true;
|
||||
bool ui_mcp_proxy = false;
|
||||
|
||||
+1
-3
@@ -799,7 +799,6 @@ common_download_model_result common_download_model(const common_params_model &
|
||||
|
||||
bool download_mmproj = opts.download_mmproj;
|
||||
bool download_mtp = opts.download_mtp;
|
||||
bool preset_only = opts.preset_only;
|
||||
bool is_hf = !model.hf_repo.empty();
|
||||
|
||||
if (is_hf) {
|
||||
@@ -807,8 +806,7 @@ common_download_model_result common_download_model(const common_params_model &
|
||||
if (!hf.preset.path.empty()) {
|
||||
// if preset.ini exists, only download that file alone
|
||||
tasks.push_back({hf.preset.url, hf.preset.local_path});
|
||||
} else if (!preset_only) {
|
||||
// only add other files if we're NOT in preset-only mode (normal run, non-router)
|
||||
} else {
|
||||
for (const auto & f : hf.model_files) {
|
||||
tasks.push_back({f.url, f.local_path});
|
||||
}
|
||||
|
||||
@@ -55,7 +55,6 @@ struct common_download_opts {
|
||||
bool skip_download = false; // if true, only validation is performed, common_skip_download_exception may be thrown if the file is missing or invalid
|
||||
bool download_mmproj = false;
|
||||
bool download_mtp = false;
|
||||
bool preset_only = false; // if true, only check & download remote preset (for router mode)
|
||||
common_download_callback * callback = nullptr;
|
||||
};
|
||||
|
||||
|
||||
@@ -2,16 +2,6 @@
|
||||
|
||||
#include <cpp-httplib/httplib.h>
|
||||
|
||||
#ifdef _WIN32
|
||||
#include <winsock2.h>
|
||||
#include <windows.h>
|
||||
#else
|
||||
#include <sys/socket.h>
|
||||
#include <netinet/in.h>
|
||||
#include <arpa/inet.h>
|
||||
#include <unistd.h>
|
||||
#endif
|
||||
|
||||
struct common_http_url {
|
||||
std::string scheme;
|
||||
std::string user;
|
||||
@@ -107,63 +97,3 @@ static std::pair<httplib::Client, common_http_url> common_http_client(const std:
|
||||
static std::string common_http_show_masked_url(const common_http_url & parts) {
|
||||
return parts.scheme + "://" + (parts.user.empty() ? "" : "****:****@") + parts.host + parts.path;
|
||||
}
|
||||
|
||||
static int common_http_get_free_port() {
|
||||
#ifdef _WIN32
|
||||
WSADATA wsaData;
|
||||
if (WSAStartup(MAKEWORD(2, 2), &wsaData) != 0) {
|
||||
return -1;
|
||||
}
|
||||
typedef SOCKET native_socket_t;
|
||||
#define INVALID_SOCKET_VAL INVALID_SOCKET
|
||||
#define CLOSE_SOCKET(s) closesocket(s)
|
||||
#else
|
||||
typedef int native_socket_t;
|
||||
#define INVALID_SOCKET_VAL -1
|
||||
#define CLOSE_SOCKET(s) close(s)
|
||||
#endif
|
||||
|
||||
native_socket_t sock = socket(AF_INET, SOCK_STREAM, 0);
|
||||
if (sock == INVALID_SOCKET_VAL) {
|
||||
#ifdef _WIN32
|
||||
WSACleanup();
|
||||
#endif
|
||||
return -1;
|
||||
}
|
||||
|
||||
struct sockaddr_in serv_addr;
|
||||
std::memset(&serv_addr, 0, sizeof(serv_addr));
|
||||
serv_addr.sin_family = AF_INET;
|
||||
serv_addr.sin_addr.s_addr = htonl(INADDR_ANY);
|
||||
serv_addr.sin_port = htons(0);
|
||||
|
||||
if (bind(sock, (struct sockaddr*)&serv_addr, sizeof(serv_addr)) != 0) {
|
||||
CLOSE_SOCKET(sock);
|
||||
#ifdef _WIN32
|
||||
WSACleanup();
|
||||
#endif
|
||||
return -1;
|
||||
}
|
||||
|
||||
#ifdef _WIN32
|
||||
int namelen = sizeof(serv_addr);
|
||||
#else
|
||||
socklen_t namelen = sizeof(serv_addr);
|
||||
#endif
|
||||
if (getsockname(sock, (struct sockaddr*)&serv_addr, &namelen) != 0) {
|
||||
CLOSE_SOCKET(sock);
|
||||
#ifdef _WIN32
|
||||
WSACleanup();
|
||||
#endif
|
||||
return -1;
|
||||
}
|
||||
|
||||
int port = ntohs(serv_addr.sin_port);
|
||||
|
||||
CLOSE_SOCKET(sock);
|
||||
#ifdef _WIN32
|
||||
WSACleanup();
|
||||
#endif
|
||||
|
||||
return port;
|
||||
}
|
||||
|
||||
@@ -96,7 +96,6 @@ TEXT_MODEL_MAP: dict[str, str] = {
|
||||
"GraniteMoeHybridForCausalLM": "granite",
|
||||
"GraniteMoeSharedForCausalLM": "granite",
|
||||
"GraniteSpeechForConditionalGeneration": "granite",
|
||||
"GraniteSpeechPlusForConditionalGeneration": "granite",
|
||||
"Grok1ForCausalLM": "grok",
|
||||
"GrokForCausalLM": "grok",
|
||||
"GroveMoeForCausalLM": "grovemoe",
|
||||
@@ -262,7 +261,6 @@ MMPROJ_MODEL_MAP: dict[str, str] = {
|
||||
"GlmasrModel": "ultravox",
|
||||
"Granite4VisionForConditionalGeneration": "granite",
|
||||
"GraniteSpeechForConditionalGeneration": "granite",
|
||||
"GraniteSpeechPlusForConditionalGeneration": "granite",
|
||||
"HunYuanVLForConditionalGeneration": "hunyuan",
|
||||
"Idefics3ForConditionalGeneration": "smolvlm",
|
||||
"InternVisionModel": "internvl",
|
||||
|
||||
@@ -348,34 +348,6 @@ class GraniteSpeechMmprojModel(MmprojModel):
|
||||
yield from super().modify_tensors(data_torch, name, bid)
|
||||
|
||||
|
||||
@ModelBase.register("GraniteSpeechPlusForConditionalGeneration")
|
||||
class GraniteSpeechPlusMmprojModel(GraniteSpeechMmprojModel):
|
||||
"""Conversion for GraniteSpeechPlus - extends GraniteSpeech with feature layer concatenation"""
|
||||
has_vision_encoder = False
|
||||
has_audio_encoder = True
|
||||
|
||||
def set_gguf_parameters(self):
|
||||
assert self.hparams_audio is not None
|
||||
super().set_gguf_parameters()
|
||||
|
||||
# Add feature_layer if present in encoder config
|
||||
if feature_layers := self.hparams_audio.get("cat_hidden_layers"):
|
||||
self.gguf_writer.add_audio_feature_layers(feature_layers)
|
||||
logger.info(f"gguf: audio feature_layers = {feature_layers}")
|
||||
|
||||
# Validate projector dimension matches concatenated encoder output
|
||||
hidden_dim = self.hparams_audio["hidden_dim"]
|
||||
expected_dim = hidden_dim * (len(feature_layers) + 1)
|
||||
projector_dim = self.global_config["projector_config"]["encoder_hidden_size"]
|
||||
|
||||
if projector_dim != expected_dim:
|
||||
raise ValueError(
|
||||
f"Projector encoder_hidden_size ({projector_dim}) does not match "
|
||||
f"expected concatenated dimension ({expected_dim}). "
|
||||
f"Expected: hidden_dim ({hidden_dim}) * (len(feature_layers) + 1) = {expected_dim}"
|
||||
)
|
||||
|
||||
|
||||
@ModelBase.register("Granite4VisionForConditionalGeneration")
|
||||
class Granite4VisionMmprojModel(MmprojModel):
|
||||
has_vision_encoder = True
|
||||
|
||||
@@ -174,7 +174,7 @@ __kernel void kernel_gemv_noshuffle_q8_0_f32(
|
||||
regA.s6 = read_imageui(src0_q, (gid + k * BLOCK_STRIDE_A + LINE_STRIDE_A * 6)).x;
|
||||
regA.s7 = read_imageui(src0_q, (gid + k * BLOCK_STRIDE_A + LINE_STRIDE_A * 7)).x;
|
||||
|
||||
dequantizeBlockAccum_ns_sgbroadcast_1(totalSum, regA, convert_float(regS), regB);
|
||||
dequantizeBlockAccum_ns_sgbroadcast_1(totalSum, regA, regS, regB);
|
||||
}
|
||||
|
||||
// reduction in local memory, assumes #wave=4
|
||||
|
||||
@@ -108,9 +108,6 @@ if (Vulkan_FOUND)
|
||||
|
||||
if (GGML_VULKAN_CHECK_RESULTS)
|
||||
add_compile_definitions(GGML_VULKAN_CHECK_RESULTS)
|
||||
# the result-checking path computes a CPU reference graph via
|
||||
# ggml_graph_compute_with_ctx(), which is defined in ggml-cpu
|
||||
target_link_libraries(ggml-vulkan PRIVATE ggml-cpu)
|
||||
endif()
|
||||
|
||||
if (GGML_VULKAN_DEBUG)
|
||||
@@ -132,8 +129,6 @@ if (Vulkan_FOUND)
|
||||
|
||||
if (GGML_VULKAN_RUN_TESTS)
|
||||
add_compile_definitions(GGML_VULKAN_RUN_TESTS)
|
||||
# the test path also calls ggml_graph_compute_with_ctx() (ggml-cpu)
|
||||
target_link_libraries(ggml-vulkan PRIVATE ggml-cpu)
|
||||
endif()
|
||||
|
||||
# Set up toolchain for host compilation whether cross-compiling or not
|
||||
|
||||
@@ -905,12 +905,11 @@ struct ggml_webgpu_mul_mat_vec_pipeline_key {
|
||||
ggml_type src0_type;
|
||||
ggml_type src1_type;
|
||||
int vectorized;
|
||||
uint32_t num_cols;
|
||||
bool use_mmvq;
|
||||
|
||||
bool operator==(const ggml_webgpu_mul_mat_vec_pipeline_key & other) const {
|
||||
return src0_type == other.src0_type && src1_type == other.src1_type && vectorized == other.vectorized &&
|
||||
num_cols == other.num_cols && use_mmvq == other.use_mmvq;
|
||||
use_mmvq == other.use_mmvq;
|
||||
}
|
||||
};
|
||||
|
||||
@@ -920,7 +919,6 @@ struct ggml_webgpu_mul_mat_vec_pipeline_key_hash {
|
||||
ggml_webgpu_hash_combine(seed, key.src0_type);
|
||||
ggml_webgpu_hash_combine(seed, key.src1_type);
|
||||
ggml_webgpu_hash_combine(seed, key.vectorized);
|
||||
ggml_webgpu_hash_combine(seed, key.num_cols);
|
||||
ggml_webgpu_hash_combine(seed, key.use_mmvq);
|
||||
return seed;
|
||||
}
|
||||
@@ -995,12 +993,11 @@ struct ggml_webgpu_mul_mat_id_pipeline_key {
|
||||
ggml_type src0_type;
|
||||
ggml_type src1_type;
|
||||
uint32_t n_experts;
|
||||
uint32_t num_cols;
|
||||
int vectorized;
|
||||
|
||||
bool operator==(const ggml_webgpu_mul_mat_id_pipeline_key & other) const {
|
||||
return src0_type == other.src0_type && src1_type == other.src1_type && n_experts == other.n_experts &&
|
||||
num_cols == other.num_cols && vectorized == other.vectorized;
|
||||
vectorized == other.vectorized;
|
||||
}
|
||||
};
|
||||
|
||||
@@ -1010,7 +1007,6 @@ struct ggml_webgpu_mul_mat_id_pipeline_key_hash {
|
||||
ggml_webgpu_hash_combine(seed, key.src0_type);
|
||||
ggml_webgpu_hash_combine(seed, key.src1_type);
|
||||
ggml_webgpu_hash_combine(seed, key.n_experts);
|
||||
ggml_webgpu_hash_combine(seed, key.num_cols);
|
||||
ggml_webgpu_hash_combine(seed, key.vectorized);
|
||||
return seed;
|
||||
}
|
||||
@@ -1111,7 +1107,7 @@ inline bool ggml_webgpu_can_use_mmvq(const ggml_tensor * src0,
|
||||
const ggml_tensor * src1,
|
||||
bool supports_dot_product,
|
||||
const std::string & vendor) {
|
||||
if (src1->ne[1] <= 4) {
|
||||
if (src1->ne[1] == 1) {
|
||||
bool supports_dp4a = vendor == "amd" || vendor == "intel" || vendor == "nvidia";
|
||||
if (supports_dp4a && supports_dot_product) {
|
||||
switch (src1->type) {
|
||||
@@ -1893,7 +1889,6 @@ class ggml_webgpu_shader_lib {
|
||||
(context.src0->type == GGML_TYPE_F32 || context.src0->type == GGML_TYPE_F16)) ?
|
||||
1 :
|
||||
0;
|
||||
key.num_cols = context.dst->ne[1];
|
||||
key.use_mmvq =
|
||||
ggml_webgpu_can_use_mmvq(context.src0, context.src1, context.supports_dot_product, context.vendor);
|
||||
|
||||
@@ -2009,7 +2004,6 @@ class ggml_webgpu_shader_lib {
|
||||
if (key.vectorized) {
|
||||
variant += "_vectorized";
|
||||
}
|
||||
defines.push_back(std::string("NUM_COLS=") + std::to_string(key.num_cols));
|
||||
|
||||
auto processed = preprocessor.preprocess(shader_src, defines);
|
||||
auto decisions = std::make_shared<ggml_webgpu_mul_mat_vec_shader_decisions>();
|
||||
@@ -2427,7 +2421,6 @@ class ggml_webgpu_shader_lib {
|
||||
if (key.vectorized) {
|
||||
variant += "_vectorized";
|
||||
}
|
||||
defines.push_back(std::string("NUM_COLS=1"));
|
||||
|
||||
defines.push_back(std::string("N_EXPERTS=") + std::to_string(key.n_experts));
|
||||
|
||||
|
||||
@@ -1418,17 +1418,15 @@ static void ggml_webgpu_quantize_q8_dispatch(webgpu_context &
|
||||
const size_t dst_offset = ggml_webgpu_tensor_offset(dst);
|
||||
const size_t q8_src1_align_offset = ROUNDUP_POW2(
|
||||
dst_offset + ggml_nbytes(dst), ctx->global_ctx->capabilities.limits.minStorageBufferOffsetAlignment);
|
||||
const size_t q8_src1_binding_size = ROUNDUP_POW2(
|
||||
src1->ne[3] * src1->ne[2] * src1->ne[1] * (36 /* sizeof(q8_1) */ * (src1->ne[0] / /* block_size */ 32)),
|
||||
WEBGPU_STORAGE_BUF_BINDING_MULT);
|
||||
const size_t q8_src1_binding_size =
|
||||
ROUNDUP_POW2(src1->ne[3] * src1->ne[2] * (36 /* sizeof(q8_1) */ * (src1->ne[0] / /* block_size */ 32)),
|
||||
WEBGPU_STORAGE_BUF_BINDING_MULT);
|
||||
|
||||
std::vector<uint32_t> q8_params = {
|
||||
(uint32_t) (ggml_webgpu_tensor_misalignment(ctx, src1) / ggml_type_size(src1->type)),
|
||||
(uint32_t) (src1->nb[1] / ggml_type_size(src1->type)),
|
||||
(uint32_t) (src1->nb[2] / ggml_type_size(src1->type)),
|
||||
(uint32_t) (src1->nb[3] / ggml_type_size(src1->type)),
|
||||
(uint32_t) src1->ne[0],
|
||||
(uint32_t) src1->ne[1],
|
||||
(uint32_t) src1->ne[2],
|
||||
(uint32_t) src1->ne[3],
|
||||
};
|
||||
@@ -1444,7 +1442,7 @@ static void ggml_webgpu_quantize_q8_dispatch(webgpu_context &
|
||||
uint32_t q8_wg_x = 1;
|
||||
uint32_t q8_wg_y = 1;
|
||||
const uint32_t wg_per_vec = (src0->ne[0] / 4 + (q8_wg_size - 1)) / q8_wg_size;
|
||||
const uint32_t q8_total_wg = src1->ne[1] * src1->ne[2] * src1->ne[3] * wg_per_vec;
|
||||
const uint32_t q8_total_wg = src1->ne[2] * src1->ne[3] * wg_per_vec;
|
||||
const uint32_t max_wg_per_dim = ctx->global_ctx->capabilities.limits.maxComputeWorkgroupsPerDimension;
|
||||
compute_2d_workgroups(q8_total_wg, max_wg_per_dim, q8_wg_x, q8_wg_y);
|
||||
|
||||
@@ -1458,7 +1456,7 @@ static webgpu_encoded_op ggml_webgpu_mul_mat(webgpu_context & ctx,
|
||||
ggml_tensor * src1,
|
||||
ggml_tensor * dst) {
|
||||
// Determine if this is a mat-vec operation
|
||||
bool use_mat_vec = (dst->ne[1] <= 4);
|
||||
bool is_vec = (dst->ne[1] == 1);
|
||||
|
||||
// use MMVQ path for mat-vec
|
||||
bool use_mmvq = ggml_webgpu_can_use_mmvq(src0, src1, ctx->global_ctx->capabilities.supports_dot_product,
|
||||
@@ -1484,7 +1482,7 @@ static webgpu_encoded_op ggml_webgpu_mul_mat(webgpu_context & ctx,
|
||||
webgpu_pipeline pipeline;
|
||||
std::vector<webgpu_dispatch_desc> dispatches;
|
||||
|
||||
if (use_mat_vec) {
|
||||
if (is_vec) {
|
||||
if (use_mmvq) {
|
||||
ggml_webgpu_quantize_q8_dispatch(ctx, src0, src1, dst, dispatches);
|
||||
}
|
||||
@@ -1531,7 +1529,7 @@ static webgpu_encoded_op ggml_webgpu_mul_mat(webgpu_context & ctx,
|
||||
uint32_t wg_y = 1;
|
||||
const uint32_t max_wg_per_dim = ctx->global_ctx->capabilities.limits.maxComputeWorkgroupsPerDimension;
|
||||
|
||||
if (use_mat_vec) {
|
||||
if (is_vec) {
|
||||
auto * decisions = static_cast<ggml_webgpu_mul_mat_vec_shader_decisions *>(pipeline.context.get());
|
||||
|
||||
uint32_t batches = dst->ne[2] * dst->ne[3];
|
||||
@@ -3693,8 +3691,8 @@ static size_t ggml_backend_webgpu_buffer_type_get_alloc_size(ggml_backend_buffer
|
||||
ggml_webgpu_can_use_mmvq(src0, src1, ctx->webgpu_global_ctx->capabilities.supports_dot_product,
|
||||
ctx->webgpu_global_ctx->vendor);
|
||||
if (use_mmvq) {
|
||||
const size_t q8_src1_size = src1->ne[3] * src1->ne[2] * src1->ne[1] *
|
||||
(36 /* sizeof(q8_1) */ * (src1->ne[0] / /* block_size */ 32));
|
||||
const size_t q8_src1_size =
|
||||
src1->ne[3] * src1->ne[2] * (36 /* sizeof(q8_1) */ * (src1->ne[0] / /* block_size */ 32));
|
||||
res = ROUNDUP_POW2(res + q8_src1_size +
|
||||
ctx->webgpu_global_ctx->capabilities.limits.minStorageBufferOffsetAlignment,
|
||||
WEBGPU_STORAGE_BUF_BINDING_MULT);
|
||||
|
||||
@@ -103,7 +103,7 @@ fn main(
|
||||
|
||||
#ifdef USE_SUBGROUP_REDUCTION
|
||||
for (var row = 0u; row < OUTPUTS_PER_WG; row++) {
|
||||
let subgroup_total = subgroupAdd(acc[0][row]);
|
||||
let subgroup_total = subgroupAdd(acc[row]);
|
||||
if (subgroup_invocation_id == 0u) {
|
||||
partial_sums[partial_index(row, subgroup_id)] = subgroup_total;
|
||||
}
|
||||
@@ -126,7 +126,7 @@ fn main(
|
||||
|
||||
#ifdef USE_WORKGROUP_REDUCTION
|
||||
for (var row = 0u; row < OUTPUTS_PER_WG; row++) {
|
||||
partial_sums[partial_index(row, thread_id)] = acc[0][row];
|
||||
partial_sums[partial_index(row, thread_id)] = acc[row];
|
||||
}
|
||||
|
||||
workgroupBarrier();
|
||||
|
||||
@@ -91,67 +91,61 @@ fn main(
|
||||
let dst_idx_base = params.offset_dst + dst3_idx * dst3_stride + dst2_idx * dst2_stride + row_base;
|
||||
|
||||
#ifdef MMVQ
|
||||
let src1q_idx_base = (src13_idx * params.bs02 * params.broadcast2 + src12_idx) * params.n * (params.k / 32u);
|
||||
let src1q_idx_base = (src13_idx * params.bs02 * params.broadcast2 + src12_idx) * (params.k / 32u);
|
||||
let acc = accumulate_vec_q_dot(thread_id, row_base, src0_batch_offset, src1q_idx_base);
|
||||
#else
|
||||
let src1_idx_base = params.offset_src1 + src13_idx * params.stride_13 + src12_idx * params.stride_12;
|
||||
let acc = accumulate_vec_dot(thread_id, row_base, src0_batch_offset, src1_idx_base);
|
||||
#endif
|
||||
|
||||
for (var col = 0u;col < NUM_COLS;col += 1) {
|
||||
|
||||
#ifdef USE_SUBGROUP_REDUCTION
|
||||
for (var row = 0u; row < OUTPUTS_PER_WG; row++) {
|
||||
let subgroup_total = subgroupAdd(acc[col][row]);
|
||||
if (subgroup_invocation_id == 0u) {
|
||||
partial_sums[partial_index(row, subgroup_id)] = subgroup_total;
|
||||
}
|
||||
}
|
||||
|
||||
workgroupBarrier();
|
||||
|
||||
for (var row = subgroup_id; (row < OUTPUTS_PER_WG) && (row_base + row < params.m); row += num_subgroups) {
|
||||
let output_row = row_base + row;
|
||||
var row_acc = 0.0f;
|
||||
for (var k = subgroup_invocation_id; k < num_subgroups; k += subgroup_size) {
|
||||
row_acc += partial_sums[partial_index(row, k)];
|
||||
}
|
||||
let row_total = subgroupAdd(row_acc);
|
||||
if (subgroup_invocation_id == 0) {
|
||||
dst[dst_idx_base + col * params.m + row] = row_total;
|
||||
}
|
||||
}
|
||||
#endif
|
||||
|
||||
#ifdef USE_WORKGROUP_REDUCTION
|
||||
for (var row = 0u; row < OUTPUTS_PER_WG; row++) {
|
||||
partial_sums[partial_index(row, thread_id)] = acc[col][row];
|
||||
}
|
||||
|
||||
workgroupBarrier();
|
||||
|
||||
var stride = WG_SIZE / 2u;
|
||||
|
||||
while (stride > 0) {
|
||||
if (thread_id < stride) {
|
||||
for (var row = 0u; row < OUTPUTS_PER_WG; row++) {
|
||||
partial_sums[partial_index(row, thread_id)] += partial_sums[partial_index(row, thread_id + stride)];
|
||||
}
|
||||
}
|
||||
|
||||
workgroupBarrier();
|
||||
stride = stride / 2;
|
||||
}
|
||||
|
||||
if (thread_id < OUTPUTS_PER_WG) {
|
||||
let output_row = row_base + thread_id;
|
||||
if (output_row < params.m) {
|
||||
dst[dst_idx_base + col * params.m + thread_id] = partial_sums[partial_index(thread_id, 0)];
|
||||
}
|
||||
}
|
||||
#endif
|
||||
for (var row = 0u; row < OUTPUTS_PER_WG; row++) {
|
||||
let subgroup_total = subgroupAdd(acc[row]);
|
||||
if (subgroup_invocation_id == 0u) {
|
||||
partial_sums[partial_index(row, subgroup_id)] = subgroup_total;
|
||||
}
|
||||
}
|
||||
|
||||
workgroupBarrier();
|
||||
|
||||
for (var row = subgroup_id; (row < OUTPUTS_PER_WG) && (row_base + row < params.m); row += num_subgroups) {
|
||||
let output_row = row_base + row;
|
||||
var row_acc = 0.0f;
|
||||
for (var k = subgroup_invocation_id; k < num_subgroups; k += subgroup_size) {
|
||||
row_acc += partial_sums[partial_index(row, k)];
|
||||
}
|
||||
let row_total = subgroupAdd(row_acc);
|
||||
if (subgroup_invocation_id == 0) {
|
||||
dst[dst_idx_base + row] = row_total;
|
||||
}
|
||||
}
|
||||
#endif
|
||||
|
||||
#ifdef USE_WORKGROUP_REDUCTION
|
||||
for (var row = 0u; row < OUTPUTS_PER_WG; row++) {
|
||||
partial_sums[partial_index(row, thread_id)] = acc[row];
|
||||
}
|
||||
|
||||
workgroupBarrier();
|
||||
|
||||
var stride = WG_SIZE / 2u;
|
||||
|
||||
while (stride > 0) {
|
||||
if (thread_id < stride) {
|
||||
for (var row = 0u; row < OUTPUTS_PER_WG; row++) {
|
||||
partial_sums[partial_index(row, thread_id)] += partial_sums[partial_index(row, thread_id + stride)];
|
||||
}
|
||||
}
|
||||
|
||||
workgroupBarrier();
|
||||
stride = stride / 2;
|
||||
}
|
||||
|
||||
if (thread_id < OUTPUTS_PER_WG) {
|
||||
let output_row = row_base + thread_id;
|
||||
if (output_row < params.m) {
|
||||
dst[dst_idx_base + thread_id] = partial_sums[partial_index(thread_id, 0)];
|
||||
}
|
||||
}
|
||||
#endif
|
||||
}
|
||||
|
||||
File diff suppressed because it is too large
Load Diff
@@ -51,7 +51,10 @@ fn repack_b_dm(block: u32) -> B_DS_TYPE {
|
||||
fn get_dm(block_byte_base: u32) -> f32 {
|
||||
return f32(load_f16_at_src0(block_byte_base));
|
||||
}
|
||||
#endif // MUL_ACC_Q4_0
|
||||
fn mul_q8_1(row_sum: i32, da: f32, b_ds: B_DS_TYPE) -> f32 {
|
||||
return f32(row_sum) * (da * b_ds.x) - 8.0 * da * b_ds.y / THREADS_PER_BLOCK;
|
||||
}
|
||||
#endif
|
||||
|
||||
#ifdef MUL_ACC_Q4_1
|
||||
#define BLOCK_SIZE_BYTES 20
|
||||
@@ -82,7 +85,10 @@ fn get_dm(block_byte_base: u32) -> vec2<f32> {
|
||||
f32(load_f16_at_src0(block_byte_base + 2u))
|
||||
);
|
||||
}
|
||||
#endif // MUL_ACC_Q4_1
|
||||
fn mul_q8_1(row_sum: i32, dma: vec2<f32>, b_ds: B_DS_TYPE) -> f32 {
|
||||
return f32(row_sum) * (dma.x * b_ds.x) + dma.y * b_ds.y / THREADS_PER_BLOCK;
|
||||
}
|
||||
#endif
|
||||
|
||||
#ifdef MUL_ACC_Q8_0
|
||||
#define BLOCK_SIZE_BYTES 34
|
||||
@@ -105,48 +111,46 @@ fn repack_b_dm(block: u32) -> B_DS_TYPE {
|
||||
fn get_dm(block_byte_base: u32) -> f32 {
|
||||
return f32(load_f16_at_src0(block_byte_base));
|
||||
}
|
||||
#endif // MUL_ACC_Q8_0
|
||||
fn mul_q8_1(row_sum: i32, da: f32, b_ds: B_DS_TYPE) -> f32 {
|
||||
return f32(row_sum) * (da * b_ds);
|
||||
}
|
||||
#endif
|
||||
|
||||
#if defined(LEGACY_QUANTS)
|
||||
fn accumulate_vec_q_dot(thread_id: u32, row_base: u32, src0_batch_offset: u32, src1q_idx_base: u32) -> array<array<f32, OUTPUTS_PER_WG>, NUM_COLS> {
|
||||
var acc: array<array<f32, OUTPUTS_PER_WG>, NUM_COLS>;
|
||||
#ifdef LEGACY_QUANTS
|
||||
fn mmvq_dot_product(a_byte_base: u32, b_inner_id: u32, b_repacked: vec2<u32>, b_ds: B_DS_TYPE) -> f32 {
|
||||
var row_sum = 0;
|
||||
let a_repacked = repack_a(a_byte_base, b_inner_id);
|
||||
|
||||
row_sum += dot4I8Packed(a_repacked[0], b_repacked[0]);
|
||||
row_sum += dot4I8Packed(a_repacked[1], b_repacked[1]);
|
||||
|
||||
return mul_q8_1(row_sum, get_dm(a_byte_base), b_ds);
|
||||
}
|
||||
|
||||
fn accumulate_vec_q_dot(thread_id: u32, row_base: u32, src0_batch_offset: u32, src1q_idx_base: u32) -> array<f32, OUTPUTS_PER_WG> {
|
||||
var acc: array<f32, OUTPUTS_PER_WG>;
|
||||
|
||||
let num_blocks = params.k / BLOCK_SIZE;
|
||||
|
||||
for (var block = thread_id / THREADS_PER_BLOCK; block < num_blocks; block += WG_SIZE / THREADS_PER_BLOCK) {
|
||||
let inner_id = thread_id % THREADS_PER_BLOCK;
|
||||
let b_inner_id = thread_id % THREADS_PER_BLOCK;
|
||||
let b_block_idx = src1q_idx_base + block;
|
||||
|
||||
let b_repacked = repack_b_qs(b_block_idx, b_inner_id);
|
||||
let b_ds = repack_b_dm(b_block_idx);
|
||||
|
||||
for (var row = 0u; row < OUTPUTS_PER_WG; row++) {
|
||||
let output_row = row_base + row;
|
||||
if (output_row < params.m) {
|
||||
let block_byte_base = (src0_batch_offset + output_row * params.stride_01 + block) * BLOCK_SIZE_BYTES;
|
||||
let a_repacked = repack_a(block_byte_base, inner_id);
|
||||
let da = get_dm(block_byte_base);
|
||||
for (var col = 0u;col < NUM_COLS;col += 1) {
|
||||
let src1q_idx = src1q_idx_base + col * (params.k / Q8_BLOCK_SIZE) + block;
|
||||
let b_repacked = repack_b_qs(src1q_idx, inner_id);
|
||||
let b_ds = repack_b_dm(src1q_idx);
|
||||
|
||||
let row_sum = dot4I8Packed(a_repacked[0], b_repacked[0]) + dot4I8Packed(a_repacked[1], b_repacked[1]);
|
||||
|
||||
#if defined(MUL_ACC_Q4_0)
|
||||
acc[col][row] += f32(row_sum) * (da * b_ds.x) - 8.0 * da * b_ds.y / THREADS_PER_BLOCK;
|
||||
#endif // MUL_ACC_Q4_0
|
||||
|
||||
#if defined(MUL_ACC_Q4_1)
|
||||
acc[col][row] += f32(row_sum) * (da.x * b_ds.x) + da.y * b_ds.y / THREADS_PER_BLOCK;
|
||||
#endif // MUL_ACC_Q4_1
|
||||
|
||||
#if defined(MUL_ACC_Q8_0)
|
||||
acc[col][row] += f32(row_sum) * (da * b_ds);
|
||||
#endif // MUL_ACC_Q8_0
|
||||
}
|
||||
acc[row] += mmvq_dot_product(block_byte_base, b_inner_id, b_repacked, b_ds);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
return acc;
|
||||
}
|
||||
#endif // LEGACY_QUANTS
|
||||
#endif
|
||||
|
||||
#ifdef MUL_ACC_Q2_K
|
||||
#define BLOCK_SIZE_BYTES 84
|
||||
@@ -187,7 +191,22 @@ fn get_scale_min(block_byte_base: u32, tid: u32) -> vec2<f32> {
|
||||
let scale = byte_of(load_u32_at_src0_aligned(scale_byte), scale_byte & 3u);
|
||||
return vec2<f32>(f32(scale & 0xFu), f32(scale >> 4u));
|
||||
}
|
||||
#endif // MUL_ACC_Q2_K
|
||||
fn mmvq_dot_product(a_byte_base: u32, tid: u32, b_repacked: vec4<u32>, b_ds: B_DS_TYPE) -> f32 {
|
||||
let a_repacked = repack_a(a_byte_base, tid);
|
||||
let dm = get_dm(a_byte_base);
|
||||
let scale_min = get_scale_min(a_byte_base, tid);
|
||||
|
||||
let scale_q = i32(scale_min.x);
|
||||
let scale_m_i8x4 = u32(scale_min.y) * 0x01010101u;
|
||||
|
||||
let row_sum_d = (dot4I8Packed(b_repacked[0], a_repacked[0]) + dot4I8Packed(b_repacked[1], a_repacked[1])
|
||||
+ dot4I8Packed(b_repacked[2], a_repacked[2]) + dot4I8Packed(b_repacked[3], a_repacked[3])) * scale_q;
|
||||
let row_sum_m = dot4I8Packed(b_repacked[0], scale_m_i8x4) + dot4I8Packed(b_repacked[1], scale_m_i8x4)
|
||||
+ dot4I8Packed(b_repacked[2], scale_m_i8x4) + dot4I8Packed(b_repacked[3], scale_m_i8x4);
|
||||
|
||||
return b_ds * (dm.x * f32(row_sum_d) - dm.y * f32(row_sum_m));
|
||||
}
|
||||
#endif
|
||||
|
||||
#ifdef MUL_ACC_Q4_K
|
||||
#define BLOCK_SIZE_BYTES 144
|
||||
@@ -246,52 +265,39 @@ fn get_scale_min(block_byte_base: u32, tid: u32) -> vec2<f32> {
|
||||
|
||||
return vec2<f32>(scale, min_val);
|
||||
}
|
||||
#endif // MUL_ACC_Q4_K
|
||||
fn mmvq_dot_product(a_byte_base: u32, tid: u32, b_repacked: vec4<u32>, b_ds: B_DS_TYPE) -> f32 {
|
||||
let a_repacked = repack_a(a_byte_base, tid);
|
||||
let dm = get_dm(a_byte_base);
|
||||
let scale_min = get_scale_min(a_byte_base, tid);
|
||||
|
||||
let row_sum = dot4I8Packed(a_repacked[0], b_repacked[0]) + dot4I8Packed(a_repacked[1], b_repacked[1])
|
||||
+ dot4I8Packed(a_repacked[2], b_repacked[2]) + dot4I8Packed(a_repacked[3], b_repacked[3]);
|
||||
|
||||
// Each thread covers half of the Q8_1 block, so add only b_ds.y/2.
|
||||
return b_ds.x * dm.x * scale_min.x * f32(row_sum) - dm.y * scale_min.y * (b_ds.y / (Q8_BLOCK_SIZE / ELEMS_PER_THREAD));
|
||||
}
|
||||
#endif
|
||||
|
||||
#ifdef K_QUANTS
|
||||
fn accumulate_vec_q_dot(thread_id: u32, row_base: u32, src0_batch_offset: u32, src1q_idx_base: u32) -> array<array<f32, OUTPUTS_PER_WG>, NUM_COLS> {
|
||||
var acc: array<array<f32, OUTPUTS_PER_WG>, NUM_COLS>;
|
||||
fn accumulate_vec_q_dot(thread_id: u32, row_base: u32, src0_batch_offset: u32, src1q_idx_base: u32) -> array<f32, OUTPUTS_PER_WG> {
|
||||
var acc: array<f32, OUTPUTS_PER_WG>;
|
||||
|
||||
let tid = thread_id % THREADS_PER_BLOCK;
|
||||
|
||||
for (var block = thread_id / THREADS_PER_BLOCK; block < params.k / BLOCK_SIZE; block += WG_SIZE / THREADS_PER_BLOCK) {
|
||||
let src1q_idx = src1q_idx_base + (block * BLOCK_SIZE + ELEMS_PER_THREAD * tid) / Q8_BLOCK_SIZE;
|
||||
let b_repacked = repack_b_qs(src1q_idx, tid);
|
||||
let b_ds = repack_b_dm(src1q_idx);
|
||||
|
||||
for (var row = 0u; row < OUTPUTS_PER_WG; row++) {
|
||||
let output_row = row_base + row;
|
||||
if (output_row < params.m) {
|
||||
let block_byte_base = (src0_batch_offset + output_row * params.stride_01 + block) * BLOCK_SIZE_BYTES;
|
||||
let a_repacked = repack_a(block_byte_base, tid);
|
||||
let dm = get_dm(block_byte_base);
|
||||
let scale_min = get_scale_min(block_byte_base, tid);
|
||||
for (var col = 0u;col < NUM_COLS;col += 1) {
|
||||
let src1q_idx = src1q_idx_base + col * (params.k / Q8_BLOCK_SIZE) + (block * BLOCK_SIZE + ELEMS_PER_THREAD * tid) / Q8_BLOCK_SIZE;
|
||||
let b_repacked = repack_b_qs(src1q_idx, tid);
|
||||
let b_ds = repack_b_dm(src1q_idx);
|
||||
|
||||
#if defined(MUL_ACC_Q2_K)
|
||||
let scale_q = i32(scale_min.x);
|
||||
let scale_m_i8x4 = u32(scale_min.y) * 0x01010101u;
|
||||
|
||||
let row_sum_d = (dot4I8Packed(b_repacked[0], a_repacked[0]) + dot4I8Packed(b_repacked[1], a_repacked[1])
|
||||
+ dot4I8Packed(b_repacked[2], a_repacked[2]) + dot4I8Packed(b_repacked[3], a_repacked[3])) * scale_q;
|
||||
let row_sum_m = dot4I8Packed(b_repacked[0], scale_m_i8x4) + dot4I8Packed(b_repacked[1], scale_m_i8x4)
|
||||
+ dot4I8Packed(b_repacked[2], scale_m_i8x4) + dot4I8Packed(b_repacked[3], scale_m_i8x4);
|
||||
|
||||
acc[col][row] += b_ds * (dm.x * f32(row_sum_d) - dm.y * f32(row_sum_m));
|
||||
#endif // MUL_ACC_Q2_K
|
||||
|
||||
#if defined(MUL_ACC_Q4_K)
|
||||
let row_sum = dot4I8Packed(a_repacked[0], b_repacked[0]) + dot4I8Packed(a_repacked[1], b_repacked[1])
|
||||
+ dot4I8Packed(a_repacked[2], b_repacked[2]) + dot4I8Packed(a_repacked[3], b_repacked[3]);
|
||||
|
||||
// Each thread covers half of the Q8_1 block, so add only b_ds.y/2.
|
||||
acc[col][row] += b_ds.x * dm.x * scale_min.x * f32(row_sum) - dm.y * scale_min.y * (b_ds.y / (Q8_BLOCK_SIZE / ELEMS_PER_THREAD));
|
||||
#endif // MUL_ACC_Q4_K
|
||||
|
||||
}
|
||||
acc[row] += mmvq_dot_product(block_byte_base, tid, b_repacked, b_ds);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
return acc;
|
||||
}
|
||||
#endif // K_QUANTS
|
||||
#endif
|
||||
|
||||
@@ -9,11 +9,9 @@ requires packed_4x8_integer_dot_product;
|
||||
|
||||
struct Params {
|
||||
offset_src1: u32,
|
||||
stride_11: u32,
|
||||
stride_12: u32,
|
||||
stride_13: u32,
|
||||
ne0: u32,
|
||||
ne1: u32,
|
||||
ne2: u32,
|
||||
ne3: u32,
|
||||
};
|
||||
@@ -59,28 +57,25 @@ fn main(
|
||||
@builtin(num_workgroups) num_wg: vec3<u32>
|
||||
) {
|
||||
let thread_id = local_id.x;
|
||||
let ne0_vec4 = params.ne0 / 4u;
|
||||
let num_vec4 = params.ne0 / 4u;
|
||||
|
||||
let wg_per_vec = (ne0_vec4 + (WG_SIZE - 1u)) / WG_SIZE;
|
||||
let total_batches = wg_per_vec * params.ne1 * params.ne2 * params.ne3;
|
||||
let wg_per_vec = (num_vec4 + (WG_SIZE - 1u)) / WG_SIZE;
|
||||
let total_batches = wg_per_vec * params.ne2 * params.ne3;
|
||||
|
||||
let wg_linear = wg_id.y * num_wg.x + wg_id.x;
|
||||
if (wg_linear >= total_batches) {
|
||||
return;
|
||||
}
|
||||
|
||||
let vec_idx = wg_linear / wg_per_vec;
|
||||
let src13_idx = vec_idx / (params.ne2 * params.ne1);
|
||||
let vec_ne12_num = vec_idx % (params.ne2 * params.ne1);
|
||||
let src12_idx = vec_ne12_num / params.ne1;
|
||||
let src11_idx = vec_ne12_num % params.ne1;
|
||||
let src1_idx_base = params.offset_src1 + src13_idx * params.stride_13 + src12_idx * params.stride_12 + src11_idx * params.stride_11;
|
||||
let src13_idx = wg_linear / (params.ne2 * wg_per_vec);
|
||||
let src12_idx = (wg_linear - src13_idx * (params.ne2 * wg_per_vec)) / wg_per_vec;
|
||||
let src11_wg_idx = wg_linear % wg_per_vec;
|
||||
let src1_idx_base = params.offset_src1 + src13_idx * params.stride_13 + src12_idx * params.stride_12;
|
||||
let src1_idx_vec4_base = src1_idx_base / 4u;
|
||||
|
||||
let blocks_per_row = params.ne0 / 32u;
|
||||
let blocks_per_wg = (WG_SIZE * 4u) / 32u;
|
||||
let src1q_idx_base = ((src13_idx * params.ne2 + src12_idx) * params.ne1 + src11_idx) * blocks_per_row;
|
||||
let src11_wg_idx = wg_linear % wg_per_vec;
|
||||
let src1q_idx_base = (src13_idx * params.ne2 + src12_idx) * blocks_per_row;
|
||||
let src1q_idx = src1q_idx_base + src11_wg_idx * blocks_per_wg + thread_id / 8u;
|
||||
let qs_idx = thread_id % 8u;
|
||||
|
||||
@@ -90,7 +85,7 @@ fn main(
|
||||
var thread_amax = 0.0;
|
||||
|
||||
let src11_vec4_idx = src11_wg_idx * WG_SIZE + thread_id;
|
||||
let is_valid = src11_vec4_idx < ne0_vec4;
|
||||
let is_valid = src11_vec4_idx < num_vec4;
|
||||
|
||||
#ifdef USE_SUBGROUP_REDUCTION
|
||||
|
||||
|
||||
@@ -359,7 +359,6 @@ class Keys:
|
||||
CHUNK_SIZE = "clip.audio.chunk_size"
|
||||
CONV_KERNEL_SIZE = "clip.audio.conv_kernel_size"
|
||||
MAX_POS_EMB = "clip.audio.max_pos_emb"
|
||||
FEATURE_LAYERS = "clip.audio.feature_layer" # Granite Speech Plus
|
||||
|
||||
class Attention:
|
||||
HEAD_COUNT = "clip.audio.attention.head_count"
|
||||
|
||||
@@ -1310,9 +1310,6 @@ class GGUFWriter:
|
||||
def add_audio_max_pos_emb(self, value: int) -> None:
|
||||
self.add_uint32(Keys.ClipAudio.MAX_POS_EMB, value)
|
||||
|
||||
def add_audio_feature_layers(self, layers: Sequence[int]) -> None:
|
||||
self.add_array(Keys.ClipAudio.FEATURE_LAYERS, layers)
|
||||
|
||||
def add_audio_projector_window_size(self, value: int) -> None:
|
||||
self.add_uint32(Keys.ClipAudio.Projector.WINDOW_SIZE, value)
|
||||
|
||||
|
||||
@@ -8433,7 +8433,6 @@ static std::vector<std::unique_ptr<test_case>> make_test_cases_eval() {
|
||||
test_cases.emplace_back(new test_mul_mat(type_a, type_b, 16, 1, k, {3, 2}, {2, 1}));
|
||||
test_cases.emplace_back(new test_mul_mat(type_a, type_b, 16, 1, k, {3, 2}, {1, 2}));
|
||||
test_cases.emplace_back(new test_mul_mat(type_a, type_b, 16, 1, k, {3, 2}, {2, 2}));
|
||||
test_cases.emplace_back(new test_mul_mat(type_a, type_b, 16, 4, k, {3, 2}, {2, 2}));
|
||||
|
||||
test_cases.emplace_back(new test_mul_mat(type_a, type_b, 16, 16, k, {1, 1}, {1, 1}));
|
||||
test_cases.emplace_back(new test_mul_mat(type_a, type_b, 16, 16, k, {1, 1}, {2, 1}));
|
||||
@@ -8450,7 +8449,6 @@ static std::vector<std::unique_ptr<test_case>> make_test_cases_eval() {
|
||||
test_cases.emplace_back(new test_mul_mat(type_a, type_b, 16, 1, k, {2, 3}, {1, 1}, {0, 1, 3, 2}));
|
||||
test_cases.emplace_back(new test_mul_mat(type_a, type_b, 16, 1, k, {2, 3}, {1, 1}, {0, 3, 2, 1}));
|
||||
|
||||
test_cases.emplace_back(new test_mul_mat(type_a, type_b, 16, 4, k, {2, 3}, {1, 1}, {0, 3, 2, 1}));
|
||||
test_cases.emplace_back(new test_mul_mat(type_a, type_b, 16, 8, k, {2, 3}, {1, 1}, {0, 2, 1, 3}));
|
||||
test_cases.emplace_back(new test_mul_mat(type_a, type_b, 16, 8, k, {2, 3}, {1, 1}, {0, 1, 3, 2}));
|
||||
test_cases.emplace_back(new test_mul_mat(type_a, type_b, 16, 8, k, {2, 3}, {1, 1}, {0, 3, 2, 1}));
|
||||
|
||||
+24
-99
@@ -1562,112 +1562,37 @@ static void test_msgs_oaicompat_json_conversion() {
|
||||
}
|
||||
}
|
||||
|
||||
static void test_msg_token_delimiters_split() {
|
||||
static void test_split_by_role() {
|
||||
LOG_DBG("%s\n", __func__);
|
||||
|
||||
// Delimiters that share a leading token, distinguished by the second token,
|
||||
// to exercise the per-position token matching.
|
||||
const common_chat_msg_delimiters delims = {
|
||||
{ { COMMON_CHAT_ROLE_USER, "", { 10, 11 } },
|
||||
{ COMMON_CHAT_ROLE_ASSISTANT, "", { 10, 12 } } }
|
||||
};
|
||||
|
||||
// Empty inputs
|
||||
assert_equals<size_t>(0, common_chat_msg_delimiters{}.split({}).spans.size());
|
||||
assert_equals<size_t>(0, common_chat_msg_delimiters{}.split({ 10, 11 }).spans.size());
|
||||
assert_equals<size_t>(0, delims.split({}).spans.size());
|
||||
assert_equals<size_t>(0, common_chat_split_by_role("", {}).size());
|
||||
assert_equals<size_t>(0, common_chat_split_by_role("hello", {}).size());
|
||||
assert_equals<size_t>(0, common_chat_split_by_role("", { { "user", "<|user|>" } }).size());
|
||||
|
||||
// No delimiters match -> no spans
|
||||
assert_equals<size_t>(0, delims.split({ 100, 101, 102 }).spans.size());
|
||||
|
||||
// Multi-role conversation: <user>Hi<assistant>Hello<user>Bye
|
||||
// Multi-role conversation, no leading/trailing content
|
||||
{
|
||||
const llama_tokens tokens = {
|
||||
10, 11, // <user>
|
||||
100, 101, // Hi
|
||||
10, 12, // <assistant>
|
||||
200, 201, 202, // Hello
|
||||
10, 11, // <user>
|
||||
300, 301, // Bye
|
||||
};
|
||||
const std::string prompt = "<|user|>Hi<|assistant|>Hello<|user|>Bye";
|
||||
const auto splits = common_chat_split_by_role(prompt, {
|
||||
{ "user", "<|user|>" },
|
||||
{ "assistant", "<|assistant|>" },
|
||||
});
|
||||
assert_equals<size_t>(3, splits.size());
|
||||
|
||||
const auto result = delims.split(tokens);
|
||||
const auto & spans = result.spans;
|
||||
assert_equals<size_t>(3, spans.size());
|
||||
assert_equals<std::string>("user", splits[0].role);
|
||||
assert_equals<size_t>(0, splits[0].pos);
|
||||
assert_equals<size_t>(10, splits[0].len);
|
||||
assert_equals<std::string>("<|user|>Hi", prompt.substr(splits[0].pos, splits[0].len));
|
||||
|
||||
assert_equals(COMMON_CHAT_ROLE_USER, spans[0].role);
|
||||
assert_equals<size_t>(0, spans[0].pos);
|
||||
assert_equals<size_t>(4, spans[0].len);
|
||||
assert_equals<std::string>("assistant", splits[1].role);
|
||||
assert_equals<size_t>(10, splits[1].pos);
|
||||
assert_equals<size_t>(18, splits[1].len);
|
||||
assert_equals<std::string>("<|assistant|>Hello", prompt.substr(splits[1].pos, splits[1].len));
|
||||
|
||||
assert_equals(COMMON_CHAT_ROLE_ASSISTANT, spans[1].role);
|
||||
assert_equals<size_t>(4, spans[1].pos);
|
||||
assert_equals<size_t>(5, spans[1].len);
|
||||
|
||||
assert_equals(COMMON_CHAT_ROLE_USER, spans[2].role);
|
||||
assert_equals<size_t>(9, spans[2].pos);
|
||||
assert_equals<size_t>(4, spans[2].len);
|
||||
|
||||
// is_user_start() is true at the token position where a user span begins
|
||||
assert_equals(true, result.is_user_start(0));
|
||||
assert_equals(false, result.is_user_start(4)); // assistant span
|
||||
assert_equals(true, result.is_user_start(9));
|
||||
}
|
||||
|
||||
// Content before the first delimiter is not captured as a span
|
||||
{
|
||||
const llama_tokens tokens = {
|
||||
500, 501, // leading content (dropped)
|
||||
10, 11, // <user>
|
||||
100, // Hi
|
||||
};
|
||||
|
||||
const auto spans = delims.split(tokens).spans;
|
||||
assert_equals<size_t>(1, spans.size());
|
||||
assert_equals(COMMON_CHAT_ROLE_USER, spans[0].role);
|
||||
assert_equals<size_t>(2, spans[0].pos);
|
||||
assert_equals<size_t>(3, spans[0].len);
|
||||
}
|
||||
|
||||
// Skipped regions (media chunks) are jumped over but still count as span content
|
||||
{
|
||||
const llama_tokens tokens = {
|
||||
10, 11, // <user>
|
||||
LLAMA_TOKEN_NULL, // media chunk (3 tokens)
|
||||
LLAMA_TOKEN_NULL,
|
||||
LLAMA_TOKEN_NULL,
|
||||
100, // Hi
|
||||
10, 12, // <assistant>
|
||||
};
|
||||
|
||||
const std::map<size_t, size_t> skips = { { 2, 3 } };
|
||||
|
||||
const auto spans = delims.split(tokens, skips).spans;
|
||||
assert_equals<size_t>(2, spans.size());
|
||||
|
||||
assert_equals(COMMON_CHAT_ROLE_USER, spans[0].role);
|
||||
assert_equals<size_t>(0, spans[0].pos);
|
||||
assert_equals<size_t>(6, spans[0].len);
|
||||
|
||||
assert_equals(COMMON_CHAT_ROLE_ASSISTANT, spans[1].role);
|
||||
assert_equals<size_t>(6, spans[1].pos);
|
||||
assert_equals<size_t>(2, spans[1].len);
|
||||
}
|
||||
|
||||
// A delimiter sequence inside a skipped region is not matched
|
||||
{
|
||||
const llama_tokens tokens = {
|
||||
10, 11, // <user>
|
||||
10, 12, // skipped region that happens to contain delimiter tokens
|
||||
100, // Hi
|
||||
};
|
||||
|
||||
const std::map<size_t, size_t> skips = { { 2, 2 } };
|
||||
|
||||
const auto spans = delims.split(tokens, skips).spans;
|
||||
assert_equals<size_t>(1, spans.size());
|
||||
assert_equals(COMMON_CHAT_ROLE_USER, spans[0].role);
|
||||
assert_equals<size_t>(0, spans[0].pos);
|
||||
assert_equals<size_t>(5, spans[0].len);
|
||||
assert_equals<std::string>("user", splits[2].role);
|
||||
assert_equals<size_t>(28, splits[2].pos);
|
||||
assert_equals<size_t>(11, splits[2].len);
|
||||
assert_equals<std::string>("<|user|>Bye", prompt.substr(splits[2].pos, splits[2].len));
|
||||
}
|
||||
}
|
||||
|
||||
@@ -5932,7 +5857,7 @@ int main(int argc, char ** argv) {
|
||||
{
|
||||
test_msg_diffs_compute();
|
||||
test_msgs_oaicompat_json_conversion();
|
||||
test_msg_token_delimiters_split();
|
||||
test_split_by_role();
|
||||
test_tools_oaicompat_json_conversion();
|
||||
test_convert_responses_to_chatcmpl();
|
||||
test_developer_role_to_system_workaround();
|
||||
|
||||
@@ -2,13 +2,11 @@
|
||||
|
||||
set(TARGET llama-cli-impl)
|
||||
|
||||
add_library(${TARGET} cli.cpp
|
||||
cli-client.cpp
|
||||
cli-context.cpp)
|
||||
add_library(${TARGET} cli.cpp)
|
||||
set_target_properties(${TARGET} PROPERTIES WINDOWS_EXPORT_ALL_SYMBOLS ON)
|
||||
|
||||
target_include_directories(${TARGET} PUBLIC ${CMAKE_CURRENT_SOURCE_DIR} ../server)
|
||||
target_link_libraries(${TARGET} PUBLIC llama-server-impl llama-common ${CMAKE_THREAD_LIBS_INIT})
|
||||
target_link_libraries(${TARGET} PUBLIC server-context llama-common ${CMAKE_THREAD_LIBS_INIT})
|
||||
|
||||
if(LLAMA_TOOLS_INSTALL)
|
||||
install(TARGETS ${TARGET} LIBRARY)
|
||||
|
||||
@@ -1,164 +0,0 @@
|
||||
#include "cli-client.h"
|
||||
|
||||
#include "http.h"
|
||||
|
||||
#include <algorithm>
|
||||
#include <chrono>
|
||||
#include <thread>
|
||||
|
||||
// generation can stall for a long time during prompt processing, so the
|
||||
// read timeout must be generous
|
||||
static constexpr time_t CLI_HTTP_READ_TIMEOUT_SEC = 3600;
|
||||
|
||||
// upper bound for the accumulated response body kept for error reporting
|
||||
static constexpr size_t CLI_HTTP_MAX_ERROR_BODY = 1024 * 1024;
|
||||
|
||||
// returns the path with the base url's path prefix prepended (if any)
|
||||
static std::string join_path(const common_http_url & parts, const std::string & path) {
|
||||
if (parts.path.empty() || parts.path == "/") {
|
||||
return path;
|
||||
}
|
||||
std::string prefix = parts.path;
|
||||
if (prefix.back() == '/') {
|
||||
prefix.pop_back();
|
||||
}
|
||||
return prefix + path;
|
||||
}
|
||||
|
||||
json cli_client::get(const std::string & path) {
|
||||
auto [cli, parts] = common_http_client(server_base);
|
||||
cli.set_read_timeout(CLI_HTTP_READ_TIMEOUT_SEC, 0);
|
||||
auto path_with_model = path + (model.empty() ? "" : ("?model=" + model));
|
||||
auto res = cli.Get(join_path(parts, path_with_model));
|
||||
if (!res) {
|
||||
throw std::runtime_error("failed to connect to " + server_base + ": " + httplib::to_string(res.error()));
|
||||
}
|
||||
if (res->status < 200 || res->status >= 300) {
|
||||
throw std::runtime_error("GET " + path + " failed with status " + std::to_string(res->status) + ": " + res->body);
|
||||
}
|
||||
json result = json::parse(res->body, nullptr, false);
|
||||
if (result.is_discarded()) {
|
||||
throw std::runtime_error("GET " + path + " returned invalid JSON");
|
||||
}
|
||||
return result;
|
||||
}
|
||||
|
||||
json cli_client::post(const std::string & path, const json & body) {
|
||||
auto [cli, parts] = common_http_client(server_base);
|
||||
cli.set_read_timeout(CLI_HTTP_READ_TIMEOUT_SEC, 0);
|
||||
auto body_with_model = body;
|
||||
if (!model.empty()) {
|
||||
body_with_model["model"] = model;
|
||||
}
|
||||
auto res = cli.Post(join_path(parts, path), body_with_model.dump(), "application/json");
|
||||
if (!res) {
|
||||
throw std::runtime_error("failed to connect to " + server_base + ": " + httplib::to_string(res.error()));
|
||||
}
|
||||
if (res->status < 200 || res->status >= 300) {
|
||||
throw std::runtime_error("POST " + path + " failed with status " + std::to_string(res->status) + ": " + res->body);
|
||||
}
|
||||
json result = json::parse(res->body, nullptr, false);
|
||||
if (result.is_discarded()) {
|
||||
throw std::runtime_error("POST " + path + " returned invalid JSON");
|
||||
}
|
||||
return result;
|
||||
}
|
||||
|
||||
json cli_client::post_sse(const std::string & path,
|
||||
const json & body,
|
||||
const std::function<bool()> & should_stop,
|
||||
const std::function<void(const json &)> & on_data) {
|
||||
auto [cli, parts] = common_http_client(server_base);
|
||||
cli.set_read_timeout(CLI_HTTP_READ_TIMEOUT_SEC, 0);
|
||||
|
||||
std::string pending; // buffer for incomplete SSE lines
|
||||
std::string raw_body; // accumulated body, used only for error reporting
|
||||
|
||||
auto receiver = [&](const char * data, size_t len) -> bool {
|
||||
if (should_stop()) {
|
||||
return false; // aborts the request
|
||||
}
|
||||
if (raw_body.size() < CLI_HTTP_MAX_ERROR_BODY) {
|
||||
raw_body.append(data, std::min(len, CLI_HTTP_MAX_ERROR_BODY - raw_body.size()));
|
||||
}
|
||||
pending.append(data, len);
|
||||
size_t pos;
|
||||
while ((pos = pending.find('\n')) != std::string::npos) {
|
||||
std::string line = pending.substr(0, pos);
|
||||
pending.erase(0, pos + 1);
|
||||
if (!line.empty() && line.back() == '\r') {
|
||||
line.pop_back();
|
||||
}
|
||||
if (line.rfind("data: ", 0) != 0) {
|
||||
continue;
|
||||
}
|
||||
std::string payload = line.substr(6);
|
||||
if (payload == "[DONE]") {
|
||||
continue;
|
||||
}
|
||||
json event = json::parse(payload, nullptr, false);
|
||||
if (!event.is_discarded()) {
|
||||
on_data(event);
|
||||
}
|
||||
}
|
||||
return true;
|
||||
};
|
||||
|
||||
httplib::Headers headers = {{"Accept", "text/event-stream"}};
|
||||
auto body_with_model = body;
|
||||
if (!model.empty()) {
|
||||
body_with_model["model"] = model;
|
||||
}
|
||||
auto res = cli.Post(join_path(parts, path), headers, body_with_model.dump(), "application/json", receiver);
|
||||
|
||||
if (!res) {
|
||||
if (res.error() == httplib::Error::Canceled && should_stop()) {
|
||||
return json(); // cancelled by the user
|
||||
}
|
||||
return json {{"error", {{"message", "failed to connect to " + server_base + ": " + httplib::to_string(res.error())}}}};
|
||||
}
|
||||
if (res->status < 200 || res->status >= 300) {
|
||||
json error_body = json::parse(raw_body, nullptr, false);
|
||||
if (!error_body.is_discarded() && error_body.contains("error")) {
|
||||
return error_body;
|
||||
}
|
||||
return json {{"error", {{"message", "request failed with status " + std::to_string(res->status)}}}};
|
||||
}
|
||||
return json();
|
||||
}
|
||||
|
||||
bool cli_client::wait_health(const std::function<bool()> & is_aborted) {
|
||||
int connect_attempts = 0;
|
||||
while (!is_aborted()) {
|
||||
auto [cli, parts] = common_http_client(server_base);
|
||||
cli.set_connection_timeout(1, 0);
|
||||
auto res = cli.Get(join_path(parts, "/health"));
|
||||
if (res) {
|
||||
if (res->status == 200) {
|
||||
return true;
|
||||
}
|
||||
// any other status means the server is up but not ready yet
|
||||
// (e.g. 503 while the model is still loading)
|
||||
} else if (++connect_attempts >= 10) {
|
||||
last_error = "failed to connect to " + server_base + ": " + httplib::to_string(res.error());
|
||||
return false;
|
||||
}
|
||||
std::this_thread::sleep_for(std::chrono::milliseconds(300));
|
||||
}
|
||||
last_error = "aborted while waiting for the server to become ready";
|
||||
return false;
|
||||
}
|
||||
|
||||
std::vector<std::string> cli_client::list_models() {
|
||||
json resp = get("/v1/models");
|
||||
if (!resp.contains("data") || !resp.at("data").is_array()) {
|
||||
throw std::runtime_error("invalid response from /v1/models");
|
||||
}
|
||||
std::vector<std::string> models;
|
||||
for (const auto & m : resp.at("data")) {
|
||||
if (m.contains("id") && m.at("id").is_string()) {
|
||||
models.push_back(m.at("id").get<std::string>());
|
||||
}
|
||||
}
|
||||
return models;
|
||||
}
|
||||
@@ -1,56 +0,0 @@
|
||||
#pragma once
|
||||
|
||||
#include "ggml.h"
|
||||
|
||||
#define JSON_ASSERT GGML_ASSERT
|
||||
#include <nlohmann/json.hpp>
|
||||
|
||||
#include <functional>
|
||||
#include <string>
|
||||
|
||||
using json = nlohmann::ordered_json;
|
||||
|
||||
// openai-like client for CLI
|
||||
struct cli_client {
|
||||
std::string server_base; // base url, for example "http://127.0.0.1:8080"
|
||||
std::string last_error; // set when wait_health() fails
|
||||
|
||||
std::string model; // optional, set when the server has multiple models (router mode)
|
||||
|
||||
// simple GET request, returns the response json
|
||||
// throws std::runtime_error on transport error or non-2xx status
|
||||
json get(const std::string & path);
|
||||
|
||||
// simple POST request, returns the response json
|
||||
// throws std::runtime_error on transport error or non-2xx status
|
||||
json post(const std::string & path, const json & body);
|
||||
|
||||
// POST request with an SSE streaming response; on_data is invoked once
|
||||
// per "data:" event; the function returns after the stream is finished:
|
||||
// a null json on graceful exit (incl. cancellation via should_stop),
|
||||
// the error response json otherwise
|
||||
json post_sse(const std::string & path,
|
||||
const json & body,
|
||||
const std::function<bool()> & should_stop,
|
||||
const std::function<void(const json &)> & on_data);
|
||||
|
||||
// poll /health until the server is ready to accept requests
|
||||
// returns false if is_aborted returned true or the server is unreachable
|
||||
bool wait_health(const std::function<bool()> & is_aborted);
|
||||
|
||||
//
|
||||
// higher-level wrappers
|
||||
//
|
||||
|
||||
json create_chat_completion(const json & request,
|
||||
const std::function<bool()> & should_stop,
|
||||
const std::function<void(const json &)> & on_data) {
|
||||
return post_sse("/v1/chat/completions", request, should_stop, on_data);
|
||||
}
|
||||
|
||||
json get_props() {
|
||||
return get("/props");
|
||||
}
|
||||
|
||||
std::vector<std::string> list_models();
|
||||
};
|
||||
@@ -1,559 +0,0 @@
|
||||
#include "cli-context.h"
|
||||
#include "cli-view.h"
|
||||
|
||||
#include "arg.h"
|
||||
#include "base64.hpp"
|
||||
#include "log.h"
|
||||
#include "console.h"
|
||||
|
||||
#include <algorithm>
|
||||
#include <filesystem>
|
||||
#include <fstream>
|
||||
#include <map>
|
||||
#include <set>
|
||||
|
||||
std::atomic<bool> g_cli_interrupted = false;
|
||||
|
||||
static bool should_stop() {
|
||||
return g_cli_interrupted.load();
|
||||
}
|
||||
|
||||
static constexpr size_t FILE_GLOB_MAX_RESULTS = 100;
|
||||
|
||||
const char * LLAMA_ASCII_LOGO = R"(
|
||||
▄▄ ▄▄
|
||||
██ ██
|
||||
██ ██ ▀▀█▄ ███▄███▄ ▀▀█▄ ▄████ ████▄ ████▄
|
||||
██ ██ ▄█▀██ ██ ██ ██ ▄█▀██ ██ ██ ██ ██ ██
|
||||
██ ██ ▀█▄██ ██ ██ ██ ▀█▄██ ██ ▀████ ████▀ ████▀
|
||||
██ ██
|
||||
▀▀ ▀▀
|
||||
)";
|
||||
|
||||
// number of values an arg consumes on the command line
|
||||
static int arg_num_values(const common_arg & opt) {
|
||||
if (opt.value_hint_2 != nullptr) {
|
||||
return 2;
|
||||
}
|
||||
if (opt.value_hint != nullptr) {
|
||||
return 1;
|
||||
}
|
||||
return 0;
|
||||
}
|
||||
|
||||
static std::string format_error_message(const json & err) {
|
||||
if (err.contains("error") && err.at("error").is_object()) {
|
||||
const auto & e = err.at("error");
|
||||
if (e.contains("message") && e.at("message").is_string()) {
|
||||
return e.at("message").get<std::string>();
|
||||
}
|
||||
}
|
||||
return err.dump();
|
||||
}
|
||||
|
||||
static std::string media_type_from_ext(const std::string & fname) {
|
||||
std::string ext = std::filesystem::path(fname).extension().string();
|
||||
std::transform(ext.begin(), ext.end(), ext.begin(), ::tolower);
|
||||
if (ext == ".wav" || ext == ".mp3") {
|
||||
return "audio";
|
||||
}
|
||||
if (ext == ".mp4" || ext == ".avi" || ext == ".mkv" || ext == ".mov" || ext == ".webm") {
|
||||
return "video";
|
||||
}
|
||||
return "image";
|
||||
}
|
||||
|
||||
bool cli_context::init() {
|
||||
view::init(params);
|
||||
|
||||
std::optional<view::spinner> spinner;
|
||||
|
||||
bool use_external_server = !params.server_base.empty();
|
||||
if (use_external_server) {
|
||||
std::string base = params.server_base;
|
||||
while (!base.empty() && base.back() == '/') {
|
||||
base.pop_back();
|
||||
}
|
||||
client.server_base = base;
|
||||
|
||||
spinner.emplace("Connecting to server at " + base);
|
||||
} else {
|
||||
if (params.model.path.empty() && params.model.url.empty() &&
|
||||
params.model.hf_repo.empty() && params.model.docker_repo.empty()) {
|
||||
view::show_error(
|
||||
"no model specified",
|
||||
"use -m <file.gguf> or -hf <user/repo> to run a local model,\n"
|
||||
"or --server-base <url> to connect to a running llama-server"
|
||||
);
|
||||
return false;
|
||||
}
|
||||
|
||||
spinner.emplace("\n\nLoading model...");
|
||||
|
||||
server.emplace();
|
||||
if (!server->start(params)) {
|
||||
view::show_error("server start failed");
|
||||
return false;
|
||||
}
|
||||
if (!server->wait_ready(should_stop)) {
|
||||
if (!should_stop()) {
|
||||
view::show_error("the server exited before becoming ready");
|
||||
}
|
||||
return false;
|
||||
}
|
||||
client.server_base = server->address();
|
||||
}
|
||||
|
||||
// for --server-base this is the main availability check; for a spawned
|
||||
// server it is a cheap sanity check on top of the ready signal
|
||||
auto is_aborted = [this]() {
|
||||
return should_stop() || (server && !server->alive());
|
||||
};
|
||||
bool healthy = false;
|
||||
try {
|
||||
healthy = client.wait_health(is_aborted);
|
||||
} catch (const std::exception & e) {
|
||||
client.last_error = e.what();
|
||||
}
|
||||
if (!healthy) {
|
||||
if (!should_stop()) {
|
||||
view::show_error(client.last_error);
|
||||
}
|
||||
return false;
|
||||
}
|
||||
|
||||
if (use_external_server) {
|
||||
spinner.reset();
|
||||
if (!list_and_ask_models()) {
|
||||
return false;
|
||||
}
|
||||
// restore the spinner for the next step
|
||||
spinner.emplace("Waiting for server...");
|
||||
}
|
||||
|
||||
fetch_server_props();
|
||||
|
||||
return true;
|
||||
}
|
||||
|
||||
void cli_context::fetch_server_props() {
|
||||
try {
|
||||
json props = client.get_props();
|
||||
model_name = props.value("model_alias", "");
|
||||
if (model_name.empty()) {
|
||||
const std::string path = props.value("model_path", "");
|
||||
if (!path.empty()) {
|
||||
model_name = std::filesystem::path(path).filename().string();
|
||||
}
|
||||
}
|
||||
build_info = props.value("build_info", "");
|
||||
if (props.contains("modalities") && props.at("modalities").is_object()) {
|
||||
const auto & modalities = props.at("modalities");
|
||||
has_vision = modalities.value("vision", false);
|
||||
has_audio = modalities.value("audio", false);
|
||||
has_video = modalities.value("video", false);
|
||||
}
|
||||
} catch (const std::exception & e) {
|
||||
// /props can be disabled on remote servers; not fatal
|
||||
LOG_DBG("failed to fetch /props: %s\n", e.what());
|
||||
}
|
||||
}
|
||||
|
||||
bool cli_context::list_and_ask_models() {
|
||||
auto models = client.list_models();
|
||||
|
||||
// only one model: use it without asking
|
||||
if (models.size() == 1) {
|
||||
model_name = models[0];
|
||||
client.model = model_name;
|
||||
return true;
|
||||
}
|
||||
|
||||
std::string message = "\nAvailable models:";
|
||||
if (!models.empty()) {
|
||||
for (size_t i = 0; i < models.size(); ++i) {
|
||||
message += "\n " + std::to_string(i + 1) + ". " + models[i];
|
||||
}
|
||||
}
|
||||
message += "\n";
|
||||
view::show_message(message);
|
||||
std::string selection;
|
||||
while (selection.empty()) {
|
||||
if (should_stop()) {
|
||||
return false;
|
||||
}
|
||||
view::user_turn user_turn;
|
||||
selection = user_turn.read_input(false, "Select model by number: ");
|
||||
if (selection.empty()) {
|
||||
continue;
|
||||
}
|
||||
try {
|
||||
size_t idx = std::stoul(selection);
|
||||
if (idx > 0 && idx <= models.size()) {
|
||||
model_name = models[idx - 1];
|
||||
client.model = model_name;
|
||||
view::show_message("Selected model: " + model_name);
|
||||
break;
|
||||
}
|
||||
} catch (...) {
|
||||
// ignore
|
||||
}
|
||||
view::show_error("Invalid selection. Please enter a valid number.");
|
||||
selection.clear();
|
||||
continue;
|
||||
}
|
||||
return true;
|
||||
}
|
||||
|
||||
void cli_context::add_system_prompt() {
|
||||
if (!params.system_prompt.empty()) {
|
||||
messages.push_back({
|
||||
{"role", "system"},
|
||||
{"content", params.system_prompt}
|
||||
});
|
||||
}
|
||||
}
|
||||
|
||||
void cli_context::push_user_message(const std::string & text) {
|
||||
json content;
|
||||
if (pending_media.empty()) {
|
||||
content = text;
|
||||
} else {
|
||||
// multimodal message: media parts first, then the text
|
||||
content = pending_media;
|
||||
content.push_back({
|
||||
{"type", "text"},
|
||||
{"text", text}
|
||||
});
|
||||
pending_media = json::array();
|
||||
}
|
||||
messages.push_back({
|
||||
{"role", "user"},
|
||||
{"content", content}
|
||||
});
|
||||
}
|
||||
|
||||
bool cli_context::stage_media_file(const std::string & fname, const std::string & type) {
|
||||
std::ifstream file(fname, std::ios::binary);
|
||||
if (!file) {
|
||||
return false;
|
||||
}
|
||||
std::string data((std::istreambuf_iterator<char>(file)), std::istreambuf_iterator<char>());
|
||||
std::string encoded = base64::encode(data);
|
||||
|
||||
if (type == "audio") {
|
||||
std::string ext = std::filesystem::path(fname).extension().string();
|
||||
std::transform(ext.begin(), ext.end(), ext.begin(), ::tolower);
|
||||
pending_media.push_back({
|
||||
{"type", "input_audio"},
|
||||
{"input_audio", {
|
||||
{"data", encoded},
|
||||
{"format", ext == ".mp3" ? "mp3" : "wav"}
|
||||
}}
|
||||
});
|
||||
} else if (type == "video") {
|
||||
pending_media.push_back({
|
||||
{"type", "input_video"},
|
||||
{"input_video", {
|
||||
{"data", encoded}
|
||||
}}
|
||||
});
|
||||
} else {
|
||||
// the server detects the actual image type from the data
|
||||
pending_media.push_back({
|
||||
{"type", "image_url"},
|
||||
{"image_url", {
|
||||
{"url", "data:image/unknown;base64," + encoded}
|
||||
}}
|
||||
});
|
||||
}
|
||||
return true;
|
||||
}
|
||||
|
||||
bool cli_context::generate_completion(std::string & assistant_content, cli_timings & timings) {
|
||||
json body = {
|
||||
{"messages", messages},
|
||||
{"stream", true},
|
||||
// in order to get timings even when we cancel mid-way
|
||||
{"timings_per_token", true},
|
||||
};
|
||||
|
||||
bool stream_error = false;
|
||||
|
||||
view::assistant_turn a;
|
||||
|
||||
json err = client.create_chat_completion(body, should_stop, [&](const json & chunk) {
|
||||
if (chunk.contains("error")) {
|
||||
stream_error = true;
|
||||
view::show_error(format_error_message(chunk));
|
||||
return;
|
||||
}
|
||||
if (chunk.contains("timings")) {
|
||||
const auto & t = chunk.at("timings");
|
||||
timings.prompt_per_second = t.value("prompt_per_second", 0.0);
|
||||
timings.predicted_per_second = t.value("predicted_per_second", 0.0);
|
||||
}
|
||||
if (!chunk.contains("choices") || !chunk.at("choices").is_array() || chunk.at("choices").empty()) {
|
||||
return;
|
||||
}
|
||||
const auto & choice = chunk.at("choices").at(0);
|
||||
if (!choice.contains("delta")) {
|
||||
return;
|
||||
}
|
||||
const auto & delta = choice.at("delta");
|
||||
if (delta.contains("reasoning_content") && delta.at("reasoning_content").is_string()) {
|
||||
const std::string text = delta.at("reasoning_content").get<std::string>();
|
||||
if (!text.empty()) {
|
||||
a.push(view::ASSISTANT_DISPLAY_MODE_REASONING, text);
|
||||
}
|
||||
}
|
||||
if (delta.contains("content") && delta.at("content").is_string()) {
|
||||
const std::string text = delta.at("content").get<std::string>();
|
||||
if (!text.empty()) {
|
||||
assistant_content += text;
|
||||
a.push(view::ASSISTANT_DISPLAY_MODE_CONTENT, text);
|
||||
}
|
||||
}
|
||||
});
|
||||
|
||||
g_cli_interrupted.store(false);
|
||||
|
||||
if (!err.is_null()) {
|
||||
view::show_error(format_error_message(err));
|
||||
return false;
|
||||
}
|
||||
return !stream_error;
|
||||
}
|
||||
|
||||
int cli_context::run() {
|
||||
add_system_prompt();
|
||||
|
||||
std::string modalities = "text";
|
||||
if (has_vision) {
|
||||
modalities += ", vision";
|
||||
}
|
||||
if (has_audio) {
|
||||
modalities += ", audio";
|
||||
}
|
||||
if (has_video) {
|
||||
modalities += ", video";
|
||||
}
|
||||
|
||||
std::string banner;
|
||||
banner += "\n";
|
||||
banner += LLAMA_ASCII_LOGO;
|
||||
banner += "\n";
|
||||
banner += "build : " + build_info + "\n";
|
||||
banner += "model : " + model_name + "\n";
|
||||
banner += "modalities : " + modalities + "\n";
|
||||
if (!params.system_prompt.empty()) {
|
||||
banner += "using custom system prompt\n";
|
||||
}
|
||||
banner += "\n";
|
||||
banner += "available commands:\n";
|
||||
banner += " /exit or Ctrl+C stop or exit\n";
|
||||
banner += " /regen regenerate the last response\n";
|
||||
banner += " /clear clear the chat history\n";
|
||||
banner += " /read <file> add a text file\n";
|
||||
banner += " /glob <pattern> add text files using globbing pattern\n";
|
||||
if (has_vision) {
|
||||
banner += " /image <file> add an image file\n";
|
||||
}
|
||||
if (has_audio) {
|
||||
banner += " /audio <file> add an audio file\n";
|
||||
}
|
||||
if (has_video) {
|
||||
banner += " /video <file> add a video file\n";
|
||||
}
|
||||
banner += "\n";
|
||||
|
||||
view::show_message(banner);
|
||||
|
||||
// interactive loop
|
||||
std::string cur_msg;
|
||||
|
||||
auto add_text_file = [&](const std::string & fname) -> bool {
|
||||
std::ifstream file(fname, std::ios::binary);
|
||||
if (!file) {
|
||||
view::show_error(string_format("file does not exist or cannot be opened: '%s'", fname.c_str()));
|
||||
return false;
|
||||
}
|
||||
std::string content((std::istreambuf_iterator<char>(file)), std::istreambuf_iterator<char>());
|
||||
cur_msg += "--- File: ";
|
||||
cur_msg += fname;
|
||||
cur_msg += " ---\n";
|
||||
cur_msg += content;
|
||||
view::show_message(string_format("Loaded text from '%s'", fname.c_str()));
|
||||
return true;
|
||||
};
|
||||
|
||||
while (true) {
|
||||
std::string buffer;
|
||||
{
|
||||
view::user_turn user_turn;
|
||||
|
||||
if (params.prompt.empty()) {
|
||||
buffer = user_turn.read_input(params.multiline_input);
|
||||
} else {
|
||||
// process input prompt from args
|
||||
for (auto & fname : params.image) {
|
||||
if (!stage_media_file(fname, media_type_from_ext(fname))) {
|
||||
view::show_error(string_format("file does not exist or cannot be opened: '%s'", fname.c_str()));
|
||||
break;
|
||||
}
|
||||
view::show_message(string_format("Loaded media from '%s'", fname.c_str()));
|
||||
}
|
||||
buffer = params.prompt;
|
||||
user_turn.echo(buffer);
|
||||
params.prompt.clear(); // only use it once
|
||||
}
|
||||
}
|
||||
|
||||
if (should_stop()) {
|
||||
g_cli_interrupted.store(false);
|
||||
break;
|
||||
}
|
||||
|
||||
// remove trailing newline
|
||||
if (!buffer.empty() && buffer.back() == '\n') {
|
||||
buffer.pop_back();
|
||||
}
|
||||
|
||||
// skip empty messages
|
||||
if (buffer.empty()) {
|
||||
continue;
|
||||
}
|
||||
|
||||
bool add_user_msg = true;
|
||||
|
||||
// process commands
|
||||
if (string_starts_with(buffer, "/exit")) {
|
||||
break;
|
||||
} else if (string_starts_with(buffer, "/regen")) {
|
||||
if (messages.size() >= 2) {
|
||||
size_t last_idx = messages.size() - 1;
|
||||
messages.erase(last_idx);
|
||||
add_user_msg = false;
|
||||
} else {
|
||||
view::show_error("No message to regenerate.");
|
||||
continue;
|
||||
}
|
||||
} else if (string_starts_with(buffer, "/clear")) {
|
||||
messages.clear();
|
||||
add_system_prompt();
|
||||
|
||||
pending_media = json::array();
|
||||
view::show_message("Chat history cleared.");
|
||||
continue;
|
||||
} else if (
|
||||
(string_starts_with(buffer, "/image ") && has_vision) ||
|
||||
(string_starts_with(buffer, "/audio ") && has_audio) ||
|
||||
(string_starts_with(buffer, "/video ") && has_video)) {
|
||||
std::string type = buffer.substr(1, 5);
|
||||
// just in case (bad copy-paste for example), we strip all trailing/leading spaces
|
||||
std::string fname = string_strip(buffer.substr(7));
|
||||
if (!stage_media_file(fname, type)) {
|
||||
view::show_error(string_format("file does not exist or cannot be opened: '%s'", fname.c_str()));
|
||||
continue;
|
||||
}
|
||||
view::show_message(string_format("Loaded media from '%s'", fname.c_str()));
|
||||
continue;
|
||||
} else if (string_starts_with(buffer, "/read ")) {
|
||||
std::string fname = string_strip(buffer.substr(6));
|
||||
add_text_file(fname);
|
||||
continue;
|
||||
} else if (string_starts_with(buffer, "/glob ")) {
|
||||
std::error_code ec;
|
||||
size_t count = 0;
|
||||
auto curdir = std::filesystem::current_path();
|
||||
std::string pattern = string_strip(buffer.substr(6));
|
||||
std::filesystem::path rel_path;
|
||||
|
||||
auto startglob = pattern.find_first_of("![*?");
|
||||
if (startglob != std::string::npos && startglob != 0) {
|
||||
auto endpath = pattern.substr(0, startglob).find_last_of('/');
|
||||
if (endpath != std::string::npos) {
|
||||
std::string rel_pattern = pattern.substr(0, endpath);
|
||||
#if !defined(_WIN32)
|
||||
if (string_starts_with(rel_pattern, '~')) {
|
||||
const char * home = std::getenv("HOME");
|
||||
if (home && home[0]) {
|
||||
rel_pattern = home + rel_pattern.substr(1);
|
||||
}
|
||||
}
|
||||
#endif
|
||||
rel_path = rel_pattern;
|
||||
pattern.erase(0, endpath + 1);
|
||||
curdir /= rel_path;
|
||||
}
|
||||
}
|
||||
|
||||
for (const auto & entry : std::filesystem::recursive_directory_iterator(curdir,
|
||||
std::filesystem::directory_options::skip_permission_denied, ec)) {
|
||||
if (!entry.is_regular_file()) {
|
||||
continue;
|
||||
}
|
||||
|
||||
std::string rel = std::filesystem::relative(entry.path(), curdir, ec).string();
|
||||
if (ec) {
|
||||
ec.clear();
|
||||
continue;
|
||||
}
|
||||
std::replace(rel.begin(), rel.end(), '\\', '/');
|
||||
|
||||
if (!glob_match(pattern, rel)) {
|
||||
continue;
|
||||
}
|
||||
|
||||
if (!add_text_file((rel_path / rel).string())) {
|
||||
continue;
|
||||
}
|
||||
|
||||
if (++count >= FILE_GLOB_MAX_RESULTS) {
|
||||
view::show_error(string_format("Maximum number of globbed files allowed (%zu) reached.", FILE_GLOB_MAX_RESULTS));
|
||||
break;
|
||||
}
|
||||
}
|
||||
continue;
|
||||
} else {
|
||||
// not a command
|
||||
cur_msg += buffer;
|
||||
}
|
||||
|
||||
// generate response
|
||||
if (add_user_msg) {
|
||||
push_user_message(cur_msg);
|
||||
cur_msg.clear();
|
||||
}
|
||||
cli_timings timings;
|
||||
std::string assistant_content;
|
||||
generate_completion(assistant_content, timings);
|
||||
messages.push_back({
|
||||
{"role", "assistant"},
|
||||
{"content", assistant_content}
|
||||
});
|
||||
|
||||
if (params.show_timings) {
|
||||
view::show_info(string_format(
|
||||
"\n[ Prompt: %.1f t/s | Generation: %.1f t/s ]",
|
||||
timings.prompt_per_second,
|
||||
timings.predicted_per_second
|
||||
));
|
||||
}
|
||||
|
||||
if (params.single_turn) {
|
||||
break;
|
||||
}
|
||||
}
|
||||
|
||||
view::show_message("\n\nExiting...");
|
||||
|
||||
return 0;
|
||||
}
|
||||
|
||||
void cli_context::shutdown() {
|
||||
if (server) {
|
||||
server->stop();
|
||||
server.reset();
|
||||
}
|
||||
}
|
||||
@@ -1,65 +0,0 @@
|
||||
#pragma once
|
||||
|
||||
#include "common.h"
|
||||
|
||||
#include "cli-client.h"
|
||||
#include "cli-server.h"
|
||||
|
||||
#include <atomic>
|
||||
#include <optional>
|
||||
#include <string>
|
||||
|
||||
struct cli_timings {
|
||||
double prompt_per_second = 0.0;
|
||||
double predicted_per_second = 0.0;
|
||||
};
|
||||
|
||||
// set by the SIGINT handler; cleared once the interrupt has been handled
|
||||
extern std::atomic<bool> g_cli_interrupted;
|
||||
|
||||
struct cli_context {
|
||||
common_params params;
|
||||
|
||||
cli_client client; // always initialized
|
||||
std::optional<cli_server> server; // only set when no --server-base is given
|
||||
|
||||
json messages = json::array();
|
||||
json pending_media = json::array(); // staged multimodal content parts
|
||||
|
||||
// properties of the connected server
|
||||
// will be populated by fetch_server_props()
|
||||
std::string model_name;
|
||||
std::string build_info;
|
||||
bool has_vision = false;
|
||||
bool has_audio = false;
|
||||
bool has_video = false;
|
||||
|
||||
cli_context(const common_params & params) : params(params) {}
|
||||
~cli_context() {
|
||||
shutdown();
|
||||
}
|
||||
|
||||
// connect to --server-base or spawn a local llama-server child;
|
||||
// argc/argv are needed to forward the server-relevant args to the child
|
||||
bool init();
|
||||
|
||||
// run the interactive chat loop, returns the process exit code
|
||||
int run();
|
||||
|
||||
// stop the local server child (if any)
|
||||
void shutdown();
|
||||
|
||||
private:
|
||||
bool generate_completion(std::string & assistant_content, cli_timings & timings);
|
||||
void fetch_server_props();
|
||||
void add_system_prompt();
|
||||
void push_user_message(const std::string & text);
|
||||
|
||||
// check if server have multiple models (router mode)
|
||||
// if yes, list them then ask; do nothing otherwise
|
||||
bool list_and_ask_models();
|
||||
|
||||
// read a file and stage it as a multimodal content part; type is one of
|
||||
// "image", "audio", "video"; returns false if the file cannot be read
|
||||
bool stage_media_file(const std::string & fname, const std::string & type);
|
||||
};
|
||||
@@ -1,83 +0,0 @@
|
||||
#pragma once
|
||||
|
||||
#include <thread>
|
||||
|
||||
#include "http.h"
|
||||
|
||||
// llama_server will be available as a dynamic library symbol
|
||||
int llama_server(common_params & params, int argc, char ** argv);
|
||||
void llama_server_terminate();
|
||||
|
||||
struct cli_server {
|
||||
std::thread th;
|
||||
int port = -1;
|
||||
std::atomic<bool> is_alive = false;
|
||||
std::atomic<bool> is_stopping = false;
|
||||
|
||||
~cli_server() {
|
||||
stop();
|
||||
}
|
||||
|
||||
void stop() {
|
||||
if (alive() && !is_stopping.exchange(true)) {
|
||||
llama_server_terminate();
|
||||
th.join();
|
||||
}
|
||||
}
|
||||
|
||||
// spawn llama-server in a thread and interact with it via a random port
|
||||
bool start(common_params & params) {
|
||||
port = common_http_get_free_port();
|
||||
if (port <= 0) {
|
||||
fprintf(stderr, "failed to get a free port\n");
|
||||
exit(1);
|
||||
}
|
||||
|
||||
is_alive.store(true, std::memory_order_release);
|
||||
|
||||
th = std::thread([&]() {
|
||||
common_params server_params = params; // copy
|
||||
server_params.port = port;
|
||||
// argc / argv are only used in router mode, we can skip them for now
|
||||
int res = llama_server(server_params, 0, nullptr);
|
||||
if (res != 0) {
|
||||
fprintf(stderr, "llama_server exited with code %d\n", res);
|
||||
}
|
||||
is_alive.store(false, std::memory_order_release);
|
||||
});
|
||||
|
||||
return true;
|
||||
}
|
||||
|
||||
std::string address() const {
|
||||
return "http://127.0.0.1:" + std::to_string(port);
|
||||
}
|
||||
|
||||
bool wait_ready(std::function<bool()> should_stop) {
|
||||
if (!alive()) {
|
||||
return false;
|
||||
}
|
||||
while (!should_stop()) {
|
||||
auto [cli, parts] = common_http_client(address());
|
||||
cli.set_connection_timeout(1, 0);
|
||||
auto res = cli.Get("/health");
|
||||
if (res) {
|
||||
if (res->status == 200) {
|
||||
return true;
|
||||
}
|
||||
// any other status means the server is up but not ready yet
|
||||
// (e.g. 503 while the model is still loading)
|
||||
}
|
||||
if (!alive()) {
|
||||
// in case server die permanently
|
||||
return false;
|
||||
}
|
||||
std::this_thread::sleep_for(std::chrono::milliseconds(200));
|
||||
}
|
||||
return true;
|
||||
}
|
||||
|
||||
bool alive() const {
|
||||
return is_alive.load(std::memory_order_acquire);
|
||||
}
|
||||
};
|
||||
@@ -1,250 +0,0 @@
|
||||
#pragma once
|
||||
|
||||
#include "common.h"
|
||||
#include "console.h"
|
||||
|
||||
#include <array>
|
||||
#include <algorithm>
|
||||
#include <filesystem>
|
||||
#include <string_view>
|
||||
|
||||
// TODO?: Make this reusable, enums, docs
|
||||
static const std::array<std::string_view, 8> cmds = {
|
||||
"/audio ",
|
||||
"/clear",
|
||||
"/exit",
|
||||
"/glob ",
|
||||
"/image ",
|
||||
"/read ",
|
||||
"/regen",
|
||||
"/video ",
|
||||
};
|
||||
|
||||
static std::vector<std::pair<std::string, size_t>> auto_completion_callback(std::string_view line, size_t cursor_byte_pos) {
|
||||
std::vector<std::pair<std::string, size_t>> matches;
|
||||
std::string cmd;
|
||||
|
||||
if (line.length() > 1 && line.front() == '/' && !std::any_of(cmds.begin(), cmds.end(), [line](std::string_view prefix) {
|
||||
return string_starts_with(line, prefix);
|
||||
})) {
|
||||
auto it = cmds.begin();
|
||||
|
||||
while ((it = std::find_if(it, cmds.end(), [line](std::string_view cmd_line) {
|
||||
return string_starts_with(cmd_line, line);
|
||||
})) != cmds.end()) {
|
||||
matches.emplace_back(*it, it->length());
|
||||
++it;
|
||||
}
|
||||
} else {
|
||||
auto it = std::find_if(cmds.begin(), cmds.end(), [line](std::string_view prefix) {
|
||||
return prefix.back() == ' ' && string_starts_with(line, prefix);
|
||||
});
|
||||
|
||||
if (it != cmds.end()) {
|
||||
cmd = *it;
|
||||
}
|
||||
}
|
||||
|
||||
if (!cmd.empty() && cmd != "/glob " && line.length() >= cmd.length() && cursor_byte_pos >= cmd.length()) {
|
||||
const std::string path_prefix = std::string(line.substr(cmd.length(), cursor_byte_pos - cmd.length()));
|
||||
const std::string path_postfix = std::string(line.substr(cursor_byte_pos));
|
||||
auto cur_dir = std::filesystem::current_path();
|
||||
std::string cur_dir_str = cur_dir.string();
|
||||
std::string expanded_prefix = path_prefix;
|
||||
|
||||
#if !defined(_WIN32)
|
||||
if (string_starts_with(path_prefix, '~')) {
|
||||
const char * home = std::getenv("HOME");
|
||||
if (home && home[0]) {
|
||||
expanded_prefix = home + path_prefix.substr(1);
|
||||
}
|
||||
}
|
||||
if (string_starts_with(expanded_prefix, '/')) {
|
||||
#else
|
||||
if (std::isalpha(expanded_prefix[0]) && expanded_prefix.find(':') == 1) {
|
||||
#endif
|
||||
cur_dir = std::filesystem::path(expanded_prefix).parent_path();
|
||||
cur_dir_str.clear();
|
||||
} else if (!path_prefix.empty()) {
|
||||
cur_dir /= std::filesystem::path(path_prefix).parent_path();
|
||||
}
|
||||
|
||||
std::error_code ec;
|
||||
for (const auto & entry : std::filesystem::directory_iterator(cur_dir, ec)) {
|
||||
if (ec) {
|
||||
break;
|
||||
}
|
||||
if (!entry.exists(ec)) {
|
||||
ec.clear();
|
||||
continue;
|
||||
}
|
||||
|
||||
const std::string path_full = entry.path().string();
|
||||
std::string path_entry = !cur_dir_str.empty() && string_starts_with(path_full, cur_dir_str) ? path_full.substr(cur_dir_str.length() + 1) : path_full;
|
||||
|
||||
if (entry.is_directory(ec)) {
|
||||
path_entry.push_back(std::filesystem::path::preferred_separator);
|
||||
}
|
||||
|
||||
if (expanded_prefix.empty() || string_starts_with(path_entry, expanded_prefix)) {
|
||||
const std::string updated_line = cmd + path_entry;
|
||||
matches.emplace_back(updated_line + path_postfix, updated_line.length());
|
||||
}
|
||||
|
||||
if (ec) {
|
||||
ec.clear();
|
||||
}
|
||||
}
|
||||
|
||||
if (matches.empty()) {
|
||||
const std::string updated_line = cmd + path_prefix;
|
||||
matches.emplace_back(updated_line + path_postfix, updated_line.length());
|
||||
}
|
||||
|
||||
// Add the longest common prefix
|
||||
if (!expanded_prefix.empty() && matches.size() > 1) {
|
||||
const std::string_view match0(matches[0].first);
|
||||
const std::string_view match1(matches[1].first);
|
||||
auto it = std::mismatch(match0.begin(), match0.end(), match1.begin(), match1.end());
|
||||
size_t len = it.first - match0.begin();
|
||||
|
||||
for (size_t i = 2; i < matches.size(); ++i) {
|
||||
const std::string_view matchi(matches[i].first);
|
||||
auto cmp = std::mismatch(match0.begin(), match0.end(), matchi.begin(), matchi.end());
|
||||
len = std::min(len, static_cast<size_t>(cmp.first - match0.begin()));
|
||||
}
|
||||
|
||||
const std::string updated_line = std::string(match0.substr(0, len));
|
||||
matches.emplace_back(updated_line + path_postfix, updated_line.length());
|
||||
}
|
||||
|
||||
std::sort(matches.begin(), matches.end(), [](const auto & a, const auto & b) {
|
||||
return a.first.compare(0, a.second, b.first, 0, b.second) < 0;
|
||||
});
|
||||
}
|
||||
|
||||
return matches;
|
||||
}
|
||||
|
||||
// note: make this view implementation generic, so that we can move to TUI in the future if we want to
|
||||
namespace view {
|
||||
static void init(const common_params & params) {
|
||||
// TODO: avoid using atexit() here by making `console` a singleton
|
||||
console::init(params.simple_io, params.use_color);
|
||||
atexit([]() { console::cleanup(); });
|
||||
|
||||
console::set_completion_callback(auto_completion_callback);
|
||||
}
|
||||
|
||||
struct spinner {
|
||||
spinner(const std::string & message) {
|
||||
if (!message.empty()) {
|
||||
console::log("%s ", message.c_str());
|
||||
}
|
||||
console::spinner::start();
|
||||
}
|
||||
~spinner() {
|
||||
console::spinner::stop();
|
||||
}
|
||||
};
|
||||
|
||||
struct user_turn {
|
||||
user_turn() {
|
||||
console::set_display(DISPLAY_TYPE_USER_INPUT);
|
||||
}
|
||||
~user_turn() {
|
||||
console::set_display(DISPLAY_TYPE_RESET);
|
||||
}
|
||||
void echo(const std::string & buffer) {
|
||||
if (buffer.size() > 500) {
|
||||
console::log("\n> %s ... (truncated)\n", buffer.substr(0, 500).c_str());
|
||||
} else {
|
||||
console::log("\n> %s\n", buffer.c_str());
|
||||
}
|
||||
}
|
||||
std::string read_input(bool multiline_input, const char * prompt = nullptr) {
|
||||
if (prompt) {
|
||||
console::log("%s", prompt);
|
||||
} else {
|
||||
console::log("\n> ");
|
||||
}
|
||||
std::string buffer;
|
||||
std::string line;
|
||||
bool another_line = true;
|
||||
do {
|
||||
another_line = console::readline(line, multiline_input);
|
||||
buffer += line;
|
||||
} while (another_line);
|
||||
return buffer;
|
||||
}
|
||||
};
|
||||
|
||||
enum assistant_display_mode {
|
||||
ASSISTANT_DISPLAY_MODE_REASONING,
|
||||
ASSISTANT_DISPLAY_MODE_CONTENT,
|
||||
};
|
||||
struct assistant_turn {
|
||||
assistant_display_mode mode = ASSISTANT_DISPLAY_MODE_CONTENT;
|
||||
bool trailing_newline = true;
|
||||
bool is_inside_reasoning = false;
|
||||
assistant_turn() {
|
||||
console::set_display(DISPLAY_TYPE_RESET);
|
||||
}
|
||||
~assistant_turn() {
|
||||
console::set_display(DISPLAY_TYPE_RESET);
|
||||
add_newline_if_needed();
|
||||
}
|
||||
void push(assistant_display_mode m, const std::string & buffer) {
|
||||
if (m != mode) {
|
||||
add_newline_if_needed();
|
||||
switch (m) {
|
||||
case ASSISTANT_DISPLAY_MODE_CONTENT:
|
||||
{
|
||||
if (is_inside_reasoning) {
|
||||
console::log("[End thinking]\n\n");
|
||||
is_inside_reasoning = false;
|
||||
}
|
||||
console::set_display(DISPLAY_TYPE_RESET);
|
||||
} break;
|
||||
case ASSISTANT_DISPLAY_MODE_REASONING:
|
||||
{
|
||||
console::set_display(DISPLAY_TYPE_REASONING);
|
||||
is_inside_reasoning = true;
|
||||
console::log("\n[Start thinking]\n\n");
|
||||
} break;
|
||||
}
|
||||
}
|
||||
mode = m;
|
||||
if (buffer.empty()) {
|
||||
return;
|
||||
}
|
||||
trailing_newline = buffer.back() == '\n';
|
||||
console::log("%s", buffer.c_str());
|
||||
console::flush();
|
||||
}
|
||||
void add_newline_if_needed() {
|
||||
if (!trailing_newline) {
|
||||
console::log("\n");
|
||||
console::flush();
|
||||
}
|
||||
}
|
||||
};
|
||||
|
||||
static void show_error(const std::string & title, const std::string & message = "") {
|
||||
console::spinner::stop();
|
||||
console::error("Error: %s\n", title.c_str());
|
||||
if (!message.empty()) {
|
||||
console::log("%s\n", message.c_str());
|
||||
}
|
||||
}
|
||||
|
||||
static void show_message(const std::string & message) {
|
||||
console::log("%s\n", message.c_str());
|
||||
}
|
||||
|
||||
static void show_info(const std::string & message) {
|
||||
console::set_display(DISPLAY_TYPE_INFO);
|
||||
console::log("%s\n", message.c_str());
|
||||
console::set_display(DISPLAY_TYPE_RESET);
|
||||
}
|
||||
}
|
||||
+624
-10
@@ -1,10 +1,20 @@
|
||||
#include "arg.h"
|
||||
#include "chat.h"
|
||||
#include "common.h"
|
||||
#include "log.h"
|
||||
#include "arg.h"
|
||||
#include "console.h"
|
||||
#include "fit.h"
|
||||
// #include "log.h"
|
||||
|
||||
#include "cli-context.h"
|
||||
#include "cli-view.h"
|
||||
#include "server-common.h"
|
||||
#include "server-context.h"
|
||||
#include "server-task.h"
|
||||
|
||||
#include <array>
|
||||
#include <atomic>
|
||||
#include <algorithm>
|
||||
#include <filesystem>
|
||||
#include <fstream>
|
||||
#include <thread>
|
||||
#include <signal.h>
|
||||
|
||||
#if defined(_WIN32)
|
||||
@@ -15,19 +25,342 @@
|
||||
#include <windows.h>
|
||||
#endif
|
||||
|
||||
const char * LLAMA_ASCII_LOGO = R"(
|
||||
▄▄ ▄▄
|
||||
██ ██
|
||||
██ ██ ▀▀█▄ ███▄███▄ ▀▀█▄ ▄████ ████▄ ████▄
|
||||
██ ██ ▄█▀██ ██ ██ ██ ▄█▀██ ██ ██ ██ ██ ██
|
||||
██ ██ ▀█▄██ ██ ██ ██ ▀█▄██ ██ ▀████ ████▀ ████▀
|
||||
██ ██
|
||||
▀▀ ▀▀
|
||||
)";
|
||||
|
||||
static std::atomic<bool> g_is_interrupted = false;
|
||||
static bool should_stop() {
|
||||
return g_is_interrupted.load();
|
||||
}
|
||||
|
||||
#if defined (__unix__) || (defined (__APPLE__) && defined (__MACH__)) || defined (_WIN32)
|
||||
static void signal_handler(int) {
|
||||
if (g_cli_interrupted.load()) {
|
||||
if (g_is_interrupted.load()) {
|
||||
// second Ctrl+C - exit immediately
|
||||
// make sure to clear colors before exiting (not using LOG or console.cpp here to avoid deadlock)
|
||||
fprintf(stdout, "\033[0m\n");
|
||||
fflush(stdout);
|
||||
std::exit(130);
|
||||
}
|
||||
g_cli_interrupted.store(true);
|
||||
g_is_interrupted.store(true);
|
||||
}
|
||||
#endif
|
||||
|
||||
struct cli_context {
|
||||
server_context ctx_server;
|
||||
json messages = json::array();
|
||||
std::vector<raw_buffer> input_files;
|
||||
task_params defaults;
|
||||
bool verbose_prompt;
|
||||
|
||||
// thread for showing "loading" animation
|
||||
std::atomic<bool> loading_show;
|
||||
|
||||
cli_context(const common_params & params) {
|
||||
defaults.sampling = params.sampling;
|
||||
defaults.speculative = params.speculative;
|
||||
defaults.n_keep = params.n_keep;
|
||||
defaults.n_predict = params.n_predict;
|
||||
defaults.antiprompt = params.antiprompt;
|
||||
|
||||
defaults.stream = true; // make sure we always use streaming mode
|
||||
defaults.timings_per_token = true; // in order to get timings even when we cancel mid-way
|
||||
// defaults.return_progress = true; // TODO: show progress
|
||||
|
||||
verbose_prompt = params.verbose_prompt;
|
||||
}
|
||||
|
||||
std::string generate_completion(result_timings & out_timings) {
|
||||
server_response_reader rd = ctx_server.get_response_reader();
|
||||
auto chat_params = format_chat();
|
||||
{
|
||||
// TODO: reduce some copies here in the future
|
||||
server_task task = server_task(SERVER_TASK_TYPE_COMPLETION);
|
||||
task.id = rd.get_new_id();
|
||||
task.index = 0;
|
||||
task.params = defaults; // copy
|
||||
task.cli_prompt = chat_params.prompt; // copy
|
||||
task.cli_files = input_files; // copy
|
||||
task.cli = true;
|
||||
|
||||
// chat template settings
|
||||
task.params.chat_parser_params = common_chat_parser_params(chat_params);
|
||||
task.params.chat_parser_params.reasoning_format = COMMON_REASONING_FORMAT_DEEPSEEK;
|
||||
if (!chat_params.parser.empty()) {
|
||||
task.params.chat_parser_params.parser.load(chat_params.parser);
|
||||
}
|
||||
|
||||
// Copy the preserved tokens into the sampling params
|
||||
const llama_vocab * vocab = llama_model_get_vocab(
|
||||
llama_get_model(ctx_server.get_llama_context()));
|
||||
for (const auto & token : chat_params.preserved_tokens) {
|
||||
auto ids = common_tokenize(vocab, token, false, true);
|
||||
if (ids.size() == 1) {
|
||||
task.params.sampling.preserved_tokens.insert(ids[0]);
|
||||
}
|
||||
}
|
||||
|
||||
// reasoning budget sampler
|
||||
if (!chat_params.thinking_end_tag.empty()) {
|
||||
task.params.sampling.reasoning_budget_tokens = defaults.sampling.reasoning_budget_tokens;
|
||||
task.params.sampling.generation_prompt = chat_params.generation_prompt;
|
||||
|
||||
if (!chat_params.thinking_start_tag.empty()) {
|
||||
task.params.sampling.reasoning_budget_start =
|
||||
common_tokenize(vocab, chat_params.thinking_start_tag, false, true);
|
||||
}
|
||||
task.params.sampling.reasoning_budget_end =
|
||||
common_tokenize(vocab, chat_params.thinking_end_tag, false, true);
|
||||
task.params.sampling.reasoning_budget_forced =
|
||||
common_tokenize(vocab, defaults.sampling.reasoning_budget_message + chat_params.thinking_end_tag, false, true);
|
||||
}
|
||||
|
||||
rd.post_task({std::move(task)});
|
||||
}
|
||||
|
||||
if (verbose_prompt) {
|
||||
console::set_display(DISPLAY_TYPE_PROMPT);
|
||||
console::log("%s\n\n", chat_params.prompt.c_str());
|
||||
console::set_display(DISPLAY_TYPE_RESET);
|
||||
}
|
||||
|
||||
// wait for first result
|
||||
console::spinner::start();
|
||||
server_task_result_ptr result = rd.next(should_stop);
|
||||
|
||||
while (true) {
|
||||
auto res_partial = dynamic_cast<server_task_result_cmpl_partial *>(result.get());
|
||||
if (res_partial && res_partial->is_begin) {
|
||||
// this is the "send 200 status to client" signal in streaming mode
|
||||
// skip, do not stop the spinner
|
||||
result = rd.next(should_stop);
|
||||
} else {
|
||||
console::spinner::stop();
|
||||
break;
|
||||
}
|
||||
}
|
||||
|
||||
std::string curr_content;
|
||||
bool is_thinking = false;
|
||||
|
||||
while (result) {
|
||||
if (should_stop()) {
|
||||
break;
|
||||
}
|
||||
if (result->is_error()) {
|
||||
json err_data = result->to_json();
|
||||
if (err_data.contains("message")) {
|
||||
console::error("Error: %s\n", err_data["message"].get<std::string>().c_str());
|
||||
} else {
|
||||
console::error("Error: %s\n", err_data.dump().c_str());
|
||||
}
|
||||
return curr_content;
|
||||
}
|
||||
auto res_partial = dynamic_cast<server_task_result_cmpl_partial *>(result.get());
|
||||
if (res_partial) {
|
||||
out_timings = std::move(res_partial->timings);
|
||||
for (const auto & diff : res_partial->oaicompat_msg_diffs) {
|
||||
if (!diff.content_delta.empty()) {
|
||||
if (is_thinking) {
|
||||
console::log("\n[End thinking]\n\n");
|
||||
console::set_display(DISPLAY_TYPE_RESET);
|
||||
is_thinking = false;
|
||||
}
|
||||
curr_content += diff.content_delta;
|
||||
console::log("%s", diff.content_delta.c_str());
|
||||
console::flush();
|
||||
}
|
||||
if (!diff.reasoning_content_delta.empty()) {
|
||||
console::set_display(DISPLAY_TYPE_REASONING);
|
||||
if (!is_thinking) {
|
||||
console::log("[Start thinking]\n");
|
||||
}
|
||||
is_thinking = true;
|
||||
console::log("%s", diff.reasoning_content_delta.c_str());
|
||||
console::flush();
|
||||
}
|
||||
}
|
||||
}
|
||||
auto res_final = dynamic_cast<server_task_result_cmpl_final *>(result.get());
|
||||
if (res_final) {
|
||||
out_timings = std::move(res_final->timings);
|
||||
break;
|
||||
}
|
||||
result = rd.next(should_stop);
|
||||
}
|
||||
g_is_interrupted.store(false);
|
||||
// server_response_reader automatically cancels pending tasks upon destruction
|
||||
return curr_content;
|
||||
}
|
||||
|
||||
// TODO: support remote files in the future (http, https, etc)
|
||||
std::string load_input_file(const std::string & fname, bool is_media) {
|
||||
std::ifstream file = fs_open_ifstream(fname, std::ios::binary);
|
||||
if (!file) {
|
||||
return "";
|
||||
}
|
||||
if (is_media) {
|
||||
raw_buffer buf;
|
||||
buf.assign((std::istreambuf_iterator<char>(file)), std::istreambuf_iterator<char>());
|
||||
input_files.push_back(std::move(buf));
|
||||
return get_media_marker();
|
||||
} else {
|
||||
std::string content((std::istreambuf_iterator<char>(file)), std::istreambuf_iterator<char>());
|
||||
return content;
|
||||
}
|
||||
}
|
||||
|
||||
common_chat_params format_chat() {
|
||||
auto meta = ctx_server.get_meta();
|
||||
auto & chat_params = meta.chat_params;
|
||||
|
||||
auto caps = common_chat_templates_get_caps(chat_params.tmpls.get());
|
||||
|
||||
common_chat_templates_inputs inputs;
|
||||
inputs.messages = common_chat_msgs_parse_oaicompat(messages);
|
||||
inputs.tools = {}; // TODO
|
||||
inputs.tool_choice = COMMON_CHAT_TOOL_CHOICE_NONE;
|
||||
inputs.json_schema = ""; // TODO
|
||||
inputs.grammar = ""; // TODO
|
||||
inputs.use_jinja = chat_params.use_jinja;
|
||||
inputs.parallel_tool_calls = caps["supports_parallel_tool_calls"];
|
||||
inputs.add_generation_prompt = true;
|
||||
inputs.reasoning_format = COMMON_REASONING_FORMAT_DEEPSEEK;
|
||||
inputs.force_pure_content = chat_params.force_pure_content;
|
||||
inputs.enable_thinking = chat_params.enable_thinking ? common_chat_templates_support_enable_thinking(chat_params.tmpls.get()) : false;
|
||||
|
||||
// Apply chat template to the list of messages
|
||||
return common_chat_templates_apply(chat_params.tmpls.get(), inputs);
|
||||
}
|
||||
};
|
||||
|
||||
// TODO?: Make this reusable, enums, docs
|
||||
static const std::array<std::string_view, 8> cmds = {
|
||||
"/audio ",
|
||||
"/clear",
|
||||
"/exit",
|
||||
"/glob ",
|
||||
"/image ",
|
||||
"/read ",
|
||||
"/regen",
|
||||
"/video ",
|
||||
};
|
||||
|
||||
static std::vector<std::pair<std::string, size_t>> auto_completion_callback(std::string_view line, size_t cursor_byte_pos) {
|
||||
std::vector<std::pair<std::string, size_t>> matches;
|
||||
std::string cmd;
|
||||
|
||||
if (line.length() > 1 && line.front() == '/' && !std::any_of(cmds.begin(), cmds.end(), [line](std::string_view prefix) {
|
||||
return string_starts_with(line, prefix);
|
||||
})) {
|
||||
auto it = cmds.begin();
|
||||
|
||||
while ((it = std::find_if(it, cmds.end(), [line](std::string_view cmd_line) {
|
||||
return string_starts_with(cmd_line, line);
|
||||
})) != cmds.end()) {
|
||||
matches.emplace_back(*it, it->length());
|
||||
++it;
|
||||
}
|
||||
} else {
|
||||
auto it = std::find_if(cmds.begin(), cmds.end(), [line](std::string_view prefix) {
|
||||
return prefix.back() == ' ' && string_starts_with(line, prefix);
|
||||
});
|
||||
|
||||
if (it != cmds.end()) {
|
||||
cmd = *it;
|
||||
}
|
||||
}
|
||||
|
||||
if (!cmd.empty() && cmd != "/glob " && line.length() >= cmd.length() && cursor_byte_pos >= cmd.length()) {
|
||||
const std::string path_prefix = std::string(line.substr(cmd.length(), cursor_byte_pos - cmd.length()));
|
||||
const std::string path_postfix = std::string(line.substr(cursor_byte_pos));
|
||||
auto cur_dir = std::filesystem::current_path();
|
||||
std::string cur_dir_str = cur_dir.string();
|
||||
std::string expanded_prefix = path_prefix;
|
||||
|
||||
#if !defined(_WIN32)
|
||||
if (string_starts_with(path_prefix, '~')) {
|
||||
const char * home = std::getenv("HOME");
|
||||
if (home && home[0]) {
|
||||
expanded_prefix = home + path_prefix.substr(1);
|
||||
}
|
||||
}
|
||||
if (string_starts_with(expanded_prefix, '/')) {
|
||||
#else
|
||||
if (std::isalpha(expanded_prefix[0]) && expanded_prefix.find(':') == 1) {
|
||||
#endif
|
||||
cur_dir = std::filesystem::path(expanded_prefix).parent_path();
|
||||
cur_dir_str.clear();
|
||||
} else if (!path_prefix.empty()) {
|
||||
cur_dir /= std::filesystem::path(path_prefix).parent_path();
|
||||
}
|
||||
|
||||
std::error_code ec;
|
||||
for (const auto & entry : std::filesystem::directory_iterator(cur_dir, ec)) {
|
||||
if (ec) {
|
||||
break;
|
||||
}
|
||||
if (!entry.exists(ec)) {
|
||||
ec.clear();
|
||||
continue;
|
||||
}
|
||||
|
||||
const std::string path_full = entry.path().string();
|
||||
std::string path_entry = !cur_dir_str.empty() && string_starts_with(path_full, cur_dir_str) ? path_full.substr(cur_dir_str.length() + 1) : path_full;
|
||||
|
||||
if (entry.is_directory(ec)) {
|
||||
path_entry.push_back(std::filesystem::path::preferred_separator);
|
||||
}
|
||||
|
||||
if (expanded_prefix.empty() || string_starts_with(path_entry, expanded_prefix)) {
|
||||
const std::string updated_line = cmd + path_entry;
|
||||
matches.emplace_back(updated_line + path_postfix, updated_line.length());
|
||||
}
|
||||
|
||||
if (ec) {
|
||||
ec.clear();
|
||||
}
|
||||
}
|
||||
|
||||
if (matches.empty()) {
|
||||
const std::string updated_line = cmd + path_prefix;
|
||||
matches.emplace_back(updated_line + path_postfix, updated_line.length());
|
||||
}
|
||||
|
||||
// Add the longest common prefix
|
||||
if (!expanded_prefix.empty() && matches.size() > 1) {
|
||||
const std::string_view match0(matches[0].first);
|
||||
const std::string_view match1(matches[1].first);
|
||||
auto it = std::mismatch(match0.begin(), match0.end(), match1.begin(), match1.end());
|
||||
size_t len = it.first - match0.begin();
|
||||
|
||||
for (size_t i = 2; i < matches.size(); ++i) {
|
||||
const std::string_view matchi(matches[i].first);
|
||||
auto cmp = std::mismatch(match0.begin(), match0.end(), matchi.begin(), matchi.end());
|
||||
len = std::min(len, static_cast<size_t>(cmp.first - match0.begin()));
|
||||
}
|
||||
|
||||
const std::string updated_line = std::string(match0.substr(0, len));
|
||||
matches.emplace_back(updated_line + path_postfix, updated_line.length());
|
||||
}
|
||||
|
||||
std::sort(matches.begin(), matches.end(), [](const auto & a, const auto & b) {
|
||||
return a.first.compare(0, a.second, b.first, 0, b.second) < 0;
|
||||
});
|
||||
}
|
||||
|
||||
return matches;
|
||||
}
|
||||
|
||||
static constexpr size_t FILE_GLOB_MAX_RESULTS = 100;
|
||||
|
||||
// satisfies -Wmissing-declarations
|
||||
int llama_cli(int argc, char ** argv);
|
||||
|
||||
@@ -42,6 +375,25 @@ int llama_cli(int argc, char ** argv) {
|
||||
return 1;
|
||||
}
|
||||
|
||||
// TODO: maybe support it later?
|
||||
if (params.conversation_mode == COMMON_CONVERSATION_MODE_DISABLED) {
|
||||
console::error("--no-conversation is not supported by llama-cli\n");
|
||||
console::error("please use llama-completion instead\n");
|
||||
}
|
||||
|
||||
// struct that contains llama context and inference
|
||||
cli_context ctx_cli(params);
|
||||
|
||||
llama_backend_init();
|
||||
llama_numa_init(params.numa);
|
||||
|
||||
// TODO: avoid using atexit() here by making `console` a singleton
|
||||
console::init(params.simple_io, params.use_color);
|
||||
atexit([]() { console::cleanup(); });
|
||||
|
||||
console::set_display(DISPLAY_TYPE_RESET);
|
||||
console::set_completion_callback(auto_completion_callback);
|
||||
|
||||
#if defined (__unix__) || (defined (__APPLE__) && defined (__MACH__))
|
||||
struct sigaction sigint_action;
|
||||
sigint_action.sa_handler = signal_handler;
|
||||
@@ -56,11 +408,273 @@ int llama_cli(int argc, char ** argv) {
|
||||
SetConsoleCtrlHandler(reinterpret_cast<PHANDLER_ROUTINE>(console_ctrl_handler), true);
|
||||
#endif
|
||||
|
||||
cli_context ctx_cli(params);
|
||||
|
||||
if (!ctx_cli.init()) {
|
||||
console::log("\nLoading model... "); // followed by loading animation
|
||||
console::spinner::start();
|
||||
if (!ctx_cli.ctx_server.load_model(params)) {
|
||||
console::spinner::stop();
|
||||
console::error("\nFailed to load the model\n");
|
||||
return 1;
|
||||
}
|
||||
|
||||
return ctx_cli.run();
|
||||
ctx_cli.defaults.sampling = params.sampling;
|
||||
|
||||
console::spinner::stop();
|
||||
console::log("\n");
|
||||
|
||||
std::thread inference_thread([&ctx_cli]() {
|
||||
ctx_cli.ctx_server.start_loop();
|
||||
});
|
||||
|
||||
auto inf = ctx_cli.ctx_server.get_meta();
|
||||
std::string modalities = "text";
|
||||
if (inf.has_inp_image) {
|
||||
modalities += ", vision";
|
||||
}
|
||||
if (inf.has_inp_audio) {
|
||||
modalities += ", audio";
|
||||
}
|
||||
|
||||
auto add_system_prompt = [&]() {
|
||||
if (!params.system_prompt.empty()) {
|
||||
ctx_cli.messages.push_back({
|
||||
{"role", "system"},
|
||||
{"content", params.system_prompt}
|
||||
});
|
||||
}
|
||||
};
|
||||
add_system_prompt();
|
||||
|
||||
console::log("\n");
|
||||
console::log("%s\n", LLAMA_ASCII_LOGO);
|
||||
console::log("build : %s\n", inf.build_info.c_str());
|
||||
console::log("model : %s\n", inf.model_name.c_str());
|
||||
console::log("modalities : %s\n", modalities.c_str());
|
||||
if (!params.system_prompt.empty()) {
|
||||
console::log("using custom system prompt\n");
|
||||
}
|
||||
console::log("\n");
|
||||
console::log("available commands:\n");
|
||||
console::log(" /exit or Ctrl+C stop or exit\n");
|
||||
console::log(" /regen regenerate the last response\n");
|
||||
console::log(" /clear clear the chat history\n");
|
||||
console::log(" /read <file> add a text file\n");
|
||||
console::log(" /glob <pattern> add text files using globbing pattern\n");
|
||||
if (inf.has_inp_image) {
|
||||
console::log(" /image <file> add an image file\n");
|
||||
}
|
||||
if (inf.has_inp_audio) {
|
||||
console::log(" /audio <file> add an audio file\n");
|
||||
}
|
||||
if (inf.has_inp_video) {
|
||||
console::log(" /video <file> add a video file\n");
|
||||
}
|
||||
console::log("\n");
|
||||
|
||||
// interactive loop
|
||||
std::string cur_msg;
|
||||
|
||||
auto add_text_file = [&](const std::string & fname) -> bool {
|
||||
std::string marker = ctx_cli.load_input_file(fname, false);
|
||||
if (marker.empty()) {
|
||||
console::error("file does not exist or cannot be opened: '%s'\n", fname.c_str());
|
||||
return false;
|
||||
}
|
||||
if (inf.fim_sep_token != LLAMA_TOKEN_NULL) {
|
||||
cur_msg += common_token_to_piece(ctx_cli.ctx_server.get_llama_context(), inf.fim_sep_token, true);
|
||||
cur_msg += fname;
|
||||
cur_msg.push_back('\n');
|
||||
} else {
|
||||
cur_msg += "--- File: ";
|
||||
cur_msg += fname;
|
||||
cur_msg += " ---\n";
|
||||
}
|
||||
cur_msg += marker;
|
||||
console::log("Loaded text from '%s'\n", fname.c_str());
|
||||
return true;
|
||||
};
|
||||
|
||||
while (true) {
|
||||
std::string buffer;
|
||||
console::set_display(DISPLAY_TYPE_USER_INPUT);
|
||||
if (params.prompt.empty()) {
|
||||
console::log("\n> ");
|
||||
std::string line;
|
||||
bool another_line = true;
|
||||
do {
|
||||
another_line = console::readline(line, params.multiline_input);
|
||||
buffer += line;
|
||||
} while (another_line);
|
||||
} else {
|
||||
// process input prompt from args
|
||||
for (auto & fname : params.image) {
|
||||
std::string marker = ctx_cli.load_input_file(fname, true);
|
||||
if (marker.empty()) {
|
||||
console::error("file does not exist or cannot be opened: '%s'\n", fname.c_str());
|
||||
break;
|
||||
}
|
||||
console::log("Loaded media from '%s'\n", fname.c_str());
|
||||
cur_msg += marker;
|
||||
}
|
||||
buffer = params.prompt;
|
||||
if (buffer.size() > 500) {
|
||||
console::log("\n> %s ... (truncated)\n", buffer.substr(0, 500).c_str());
|
||||
} else {
|
||||
console::log("\n> %s\n", buffer.c_str());
|
||||
}
|
||||
params.prompt.clear(); // only use it once
|
||||
}
|
||||
console::set_display(DISPLAY_TYPE_RESET);
|
||||
console::log("\n");
|
||||
|
||||
if (should_stop()) {
|
||||
g_is_interrupted.store(false);
|
||||
break;
|
||||
}
|
||||
|
||||
// remove trailing newline
|
||||
if (!buffer.empty() &&buffer.back() == '\n') {
|
||||
buffer.pop_back();
|
||||
}
|
||||
|
||||
// skip empty messages
|
||||
if (buffer.empty()) {
|
||||
continue;
|
||||
}
|
||||
|
||||
bool add_user_msg = true;
|
||||
|
||||
// process commands
|
||||
if (string_starts_with(buffer, "/exit")) {
|
||||
break;
|
||||
} else if (string_starts_with(buffer, "/regen")) {
|
||||
if (ctx_cli.messages.size() >= 2) {
|
||||
size_t last_idx = ctx_cli.messages.size() - 1;
|
||||
ctx_cli.messages.erase(last_idx);
|
||||
add_user_msg = false;
|
||||
} else {
|
||||
console::error("No message to regenerate.\n");
|
||||
continue;
|
||||
}
|
||||
} else if (string_starts_with(buffer, "/clear")) {
|
||||
ctx_cli.messages.clear();
|
||||
add_system_prompt();
|
||||
|
||||
ctx_cli.input_files.clear();
|
||||
console::log("Chat history cleared.\n");
|
||||
continue;
|
||||
} else if (
|
||||
(string_starts_with(buffer, "/image ") && inf.has_inp_image) ||
|
||||
(string_starts_with(buffer, "/audio ") && inf.has_inp_audio) ||
|
||||
(string_starts_with(buffer, "/video ") && inf.has_inp_video)) {
|
||||
// just in case (bad copy-paste for example), we strip all trailing/leading spaces
|
||||
std::string fname = string_strip(buffer.substr(7));
|
||||
std::string marker = ctx_cli.load_input_file(fname, true);
|
||||
if (marker.empty()) {
|
||||
console::error("file does not exist or cannot be opened: '%s'\n", fname.c_str());
|
||||
continue;
|
||||
}
|
||||
cur_msg += marker;
|
||||
console::log("Loaded media from '%s'\n", fname.c_str());
|
||||
continue;
|
||||
} else if (string_starts_with(buffer, "/read ")) {
|
||||
std::string fname = string_strip(buffer.substr(6));
|
||||
add_text_file(fname);
|
||||
continue;
|
||||
} else if (string_starts_with(buffer, "/glob ")) {
|
||||
std::error_code ec;
|
||||
size_t count = 0;
|
||||
auto curdir = std::filesystem::current_path();
|
||||
std::string pattern = string_strip(buffer.substr(6));
|
||||
std::filesystem::path rel_path;
|
||||
|
||||
auto startglob = pattern.find_first_of("![*?");
|
||||
if (startglob != std::string::npos && startglob != 0) {
|
||||
auto endpath = pattern.substr(0, startglob).find_last_of('/');
|
||||
if (endpath != std::string::npos) {
|
||||
std::string rel_pattern = pattern.substr(0, endpath);
|
||||
#if !defined(_WIN32)
|
||||
if (string_starts_with(rel_pattern, '~')) {
|
||||
const char * home = std::getenv("HOME");
|
||||
if (home && home[0]) {
|
||||
rel_pattern = home + rel_pattern.substr(1);
|
||||
}
|
||||
}
|
||||
#endif
|
||||
rel_path = rel_pattern;
|
||||
pattern.erase(0, endpath + 1);
|
||||
curdir /= rel_path;
|
||||
}
|
||||
}
|
||||
|
||||
for (const auto & entry : std::filesystem::recursive_directory_iterator(curdir,
|
||||
std::filesystem::directory_options::skip_permission_denied, ec)) {
|
||||
if (!entry.is_regular_file()) {
|
||||
continue;
|
||||
}
|
||||
|
||||
std::string rel = std::filesystem::relative(entry.path(), curdir, ec).string();
|
||||
if (ec) {
|
||||
ec.clear();
|
||||
continue;
|
||||
}
|
||||
std::replace(rel.begin(), rel.end(), '\\', '/');
|
||||
|
||||
if (!glob_match(pattern, rel)) {
|
||||
continue;
|
||||
}
|
||||
|
||||
if (!add_text_file((rel_path / rel).string())) {
|
||||
continue;
|
||||
}
|
||||
|
||||
if (++count >= FILE_GLOB_MAX_RESULTS) {
|
||||
console::error("Maximum number of globbed files allowed (%zu) reached.\n", FILE_GLOB_MAX_RESULTS);
|
||||
break;
|
||||
}
|
||||
}
|
||||
continue;
|
||||
} else {
|
||||
// not a command
|
||||
cur_msg += buffer;
|
||||
}
|
||||
|
||||
// generate response
|
||||
if (add_user_msg) {
|
||||
ctx_cli.messages.push_back({
|
||||
{"role", "user"},
|
||||
{"content", cur_msg}
|
||||
});
|
||||
cur_msg.clear();
|
||||
}
|
||||
result_timings timings;
|
||||
std::string assistant_content = ctx_cli.generate_completion(timings);
|
||||
ctx_cli.messages.push_back({
|
||||
{"role", "assistant"},
|
||||
{"content", assistant_content}
|
||||
});
|
||||
console::log("\n");
|
||||
|
||||
if (params.show_timings) {
|
||||
console::set_display(DISPLAY_TYPE_INFO);
|
||||
console::log("\n");
|
||||
console::log("[ Prompt: %.1f t/s | Generation: %.1f t/s ]\n", timings.prompt_per_second, timings.predicted_per_second);
|
||||
console::set_display(DISPLAY_TYPE_RESET);
|
||||
}
|
||||
|
||||
if (params.single_turn) {
|
||||
break;
|
||||
}
|
||||
}
|
||||
|
||||
console::set_display(DISPLAY_TYPE_RESET);
|
||||
|
||||
console::log("\nExiting...\n");
|
||||
ctx_cli.ctx_server.terminate();
|
||||
inference_thread.join();
|
||||
|
||||
// bump the log level to display timings
|
||||
common_log_set_verbosity_thold(LOG_LEVEL_INFO);
|
||||
common_memory_breakdown_print(ctx_cli.ctx_server.get_llama_context());
|
||||
|
||||
return 0;
|
||||
}
|
||||
|
||||
@@ -42,7 +42,6 @@
|
||||
#define KEY_N_HEAD "clip.%s.attention.head_count"
|
||||
#define KEY_N_HEAD_KV "clip.%s.attention.head_count_kv"
|
||||
#define KEY_LAYER_NORM_EPS "clip.%s.attention.layer_norm_epsilon"
|
||||
#define KEY_FEATURE_LAYERS "clip.%s.feature_layer"
|
||||
|
||||
// vision-specific
|
||||
#define KEY_VISION_PROJ_TYPE "clip.vision.projector_type" // for models with mixed modalities
|
||||
@@ -55,6 +54,7 @@
|
||||
#define KEY_PATCH_SIZE "clip.vision.patch_size"
|
||||
#define KEY_IMAGE_MEAN "clip.vision.image_mean"
|
||||
#define KEY_IMAGE_STD "clip.vision.image_std"
|
||||
#define KEY_FEATURE_LAYER "clip.vision.feature_layer"
|
||||
#define KEY_PROJ_SCALE_FACTOR "clip.vision.projector.scale_factor"
|
||||
#define KEY_PROJ_SAMPLE_QUERY_SIDE "clip.vision.projector.query_side"
|
||||
#define KEY_PROJ_SAMPLE_WINDOW_SIDE "clip.vision.projector.window_side"
|
||||
|
||||
@@ -91,7 +91,7 @@ struct clip_hparams {
|
||||
|
||||
float eps = 1e-6;
|
||||
float rope_theta = 0.0;
|
||||
std::vector<int32_t> feature_layers;
|
||||
std::vector<int32_t> vision_feature_layer;
|
||||
int32_t attn_window_size = 0;
|
||||
int32_t n_wa_pattern = 0;
|
||||
std::unordered_set<int32_t> wa_layer_indexes; // explicit layer indexes that use full attention (for irregular patterns like YoutuVL)
|
||||
@@ -165,8 +165,8 @@ struct clip_hparams {
|
||||
return false;
|
||||
}
|
||||
|
||||
bool is_feature_layer(int32_t layer) const {
|
||||
return std::find(feature_layers.begin(), feature_layers.end(), layer) != feature_layers.end();
|
||||
bool is_vision_feature_layer(int32_t layer) const {
|
||||
return std::find(vision_feature_layer.begin(), vision_feature_layer.end(), layer) != vision_feature_layer.end();
|
||||
}
|
||||
};
|
||||
|
||||
|
||||
+10
-9
@@ -1264,10 +1264,12 @@ struct clip_model_loader {
|
||||
}
|
||||
}
|
||||
|
||||
// Load the vision/audio feature layer indices if they are explicitly provided
|
||||
// Load the vision feature layer indices if they are explicitly provided;
|
||||
// if multiple vision feature layers are present, the values will be concatenated
|
||||
// to form the final visual features.
|
||||
// NOTE: gguf conversions should standardize the values of the vision feature layer to
|
||||
// be non-negative, since we use -1 to mark values as unset here.
|
||||
get_arr_int(string_format(KEY_FEATURE_LAYERS, prefix), hparams.feature_layers, false);
|
||||
get_arr_int(KEY_FEATURE_LAYER, hparams.vision_feature_layer, false);
|
||||
|
||||
// model-specific params
|
||||
switch (model.proj_type) {
|
||||
@@ -1649,7 +1651,6 @@ struct clip_model_loader {
|
||||
get_u32(KEY_A_PROJ_WINDOW_SIZE, hparams.audio_proj_window_size);
|
||||
get_u32(KEY_A_PROJ_DOWNSAMPLE_RATE, hparams.audio_proj_downsample_rate);
|
||||
get_u32(KEY_A_PROJ_HEAD_COUNT, hparams.audio_proj_head_count);
|
||||
// NOTE: feature layers loaded above in common path
|
||||
} break;
|
||||
case PROJECTOR_TYPE_JANUS_PRO:
|
||||
{
|
||||
@@ -1662,11 +1663,11 @@ struct clip_model_loader {
|
||||
hparams.image_resize_algo = RESIZE_ALGO_BICUBIC_PILLOW;
|
||||
hparams.image_resize_pad = PAD_CEIL;
|
||||
|
||||
// NOTE: feature_layers loaded in common path as optional
|
||||
get_arr_int(KEY_FEATURE_LAYER, hparams.vision_feature_layer);
|
||||
get_arr_int(KEY_PROJ_SPATIAL_OFFSETS, hparams.proj_spatial_offsets);
|
||||
if (hparams.feature_layers.size() != hparams.proj_spatial_offsets.size()) {
|
||||
throw std::runtime_error(string_format("%s: feature_layers.size() %d != proj_spatial_offsets.size() %d",
|
||||
hparams.feature_layers.size(), hparams.proj_spatial_offsets.size()));
|
||||
if (hparams.vision_feature_layer.size() != hparams.proj_spatial_offsets.size()) {
|
||||
throw std::runtime_error(string_format("%s: vision_feature_layer.size() %d != proj_spatial_offsets.size() %d",
|
||||
hparams.vision_feature_layer.size(), hparams.proj_spatial_offsets.size()));
|
||||
}
|
||||
|
||||
get_u32(KEY_PROJ_SAMPLE_QUERY_SIDE, hparams.downsample_query_side);
|
||||
@@ -2739,7 +2740,7 @@ struct clip_model_loader {
|
||||
model.image_newline = get_tensor(TN_IMAGE_NEWLINE);
|
||||
|
||||
// Load separate layerwise and spatial projector tensors
|
||||
const auto projector_count = hparams.feature_layers.size();
|
||||
const auto projector_count = hparams.vision_feature_layer.size();
|
||||
model.qf_proj_blocks.resize(projector_count);
|
||||
for (size_t bid = 0; bid < projector_count; ++bid) {
|
||||
auto & b = model.qf_proj_blocks[bid];
|
||||
@@ -4387,7 +4388,7 @@ bool clip_image_batch_encode(clip_ctx * ctx, int n_threads, const clip_image_f32
|
||||
|
||||
// Stage 1b only uses block 0's permutations; future stages
|
||||
// will upload all blocks.
|
||||
for (size_t bid = 0; bid < hparams.feature_layers.size(); ++bid) {
|
||||
for (size_t bid = 0; bid < hparams.vision_feature_layer.size(); ++bid) {
|
||||
const std::string prefix = "g4v_blk" + std::to_string(bid) + "_";
|
||||
upload(prefix + "win_idx", make_win_idx(image_side, window_side));
|
||||
upload(prefix + "qwin_idx", make_win_idx(new_side, query_side));
|
||||
|
||||
@@ -1,7 +1,5 @@
|
||||
#include "models.h"
|
||||
|
||||
#include <algorithm>
|
||||
|
||||
ggml_cgraph * clip_graph_granite_speech::build() {
|
||||
const int n_frames = img.nx();
|
||||
const int context_size = hparams.audio_chunk_size;
|
||||
@@ -13,10 +11,6 @@ ggml_cgraph * clip_graph_granite_speech::build() {
|
||||
const int padded_len = num_blocks * context_size;
|
||||
const int remainder = n_frames % context_size;
|
||||
|
||||
// Calculate projector input dimension based on feature layers
|
||||
const int proj_input_dim = n_embd * (hparams.feature_layers.size() + 1);
|
||||
const bool use_feature_concat = !hparams.feature_layers.empty();
|
||||
|
||||
ggml_tensor * attn_dists = ggml_new_tensor_1d(ctx0, GGML_TYPE_I32, context_size * context_size);
|
||||
ggml_set_name(attn_dists, "attn_dists");
|
||||
ggml_set_input(attn_dists);
|
||||
@@ -37,15 +31,6 @@ ggml_cgraph * clip_graph_granite_speech::build() {
|
||||
cur = ggml_add(ctx0, cur, model.inp_proj_b);
|
||||
cb(cur, "inp_linear", -1);
|
||||
|
||||
// Capture layer 0 if requested (after input_linear)
|
||||
ggml_tensor * concat_result = nullptr;
|
||||
if (use_feature_concat) {
|
||||
if (std::find(hparams.feature_layers.begin(), hparams.feature_layers.end(), 0) != hparams.feature_layers.end()) {
|
||||
concat_result = cur;
|
||||
cb(concat_result, "feature_layer_0", -1);
|
||||
}
|
||||
}
|
||||
|
||||
for (int il = 0; il < n_layer; il++) {
|
||||
const auto & layer = model.layers[il];
|
||||
auto * residual = cur;
|
||||
@@ -183,18 +168,6 @@ ggml_cgraph * clip_graph_granite_speech::build() {
|
||||
NORM_TYPE_NORMAL, eps, il);
|
||||
cb(cur, "layer_out", il);
|
||||
|
||||
// Capture intermediate layer (il + 1) if requested
|
||||
if (use_feature_concat) {
|
||||
if (hparams.is_feature_layer(il + 1)) {
|
||||
if (concat_result == nullptr) {
|
||||
concat_result = cur;
|
||||
} else {
|
||||
concat_result = ggml_concat(ctx0, concat_result, cur, 0);
|
||||
}
|
||||
cb(concat_result, string_format("feature_layer_%d", il + 1).c_str(), il);
|
||||
}
|
||||
}
|
||||
|
||||
// CTC branch
|
||||
if (il + 1 == ctc_layer) {
|
||||
auto * mid = build_mm(model.ctc_out_w, cur);
|
||||
@@ -207,13 +180,6 @@ ggml_cgraph * clip_graph_granite_speech::build() {
|
||||
}
|
||||
}
|
||||
|
||||
// Append final output to concatenated features if using feature concatenation
|
||||
if (use_feature_concat && concat_result != nullptr) {
|
||||
concat_result = ggml_concat(ctx0, concat_result, cur, 0);
|
||||
cb(concat_result, "concat_final", -1);
|
||||
cur = concat_result;
|
||||
}
|
||||
|
||||
cb(cur, "encoder_out", -1);
|
||||
|
||||
// QFormer projector
|
||||
@@ -231,7 +197,7 @@ ggml_cgraph * clip_graph_granite_speech::build() {
|
||||
cur = ggml_pad(ctx0, cur, 0, padded_proj - n_frames, 0, 0);
|
||||
}
|
||||
|
||||
ggml_tensor * enc_windows = ggml_reshape_3d(ctx0, cur, proj_input_dim, window_size, nblocks_proj);
|
||||
ggml_tensor * enc_windows = ggml_reshape_3d(ctx0, cur, n_embd, window_size, nblocks_proj);
|
||||
|
||||
ggml_tensor * queries = build_norm(model.qf_proj_blocks[0].qf_proj_query,
|
||||
model.qf_proj_blocks[0].qf_proj_norm_w, model.qf_proj_blocks[0].qf_proj_norm_b,
|
||||
|
||||
@@ -304,14 +304,14 @@ ggml_cgraph * clip_graph_granite4_vision::build() {
|
||||
}
|
||||
|
||||
// --- Stage 1b/1c: WindowQFormer blocks ---
|
||||
const int projector_count = hparams.feature_layers.size();
|
||||
const int projector_count = hparams.vision_feature_layer.size();
|
||||
const float qformer_eps = 1e-12f;
|
||||
|
||||
ggml_tensor * mmproj = nullptr;
|
||||
for (int bid = 0; bid < projector_count; ++bid) {
|
||||
const auto & blk = model.qf_proj_blocks[bid];
|
||||
|
||||
int vlayer = hparams.feature_layers[bid];
|
||||
int vlayer = hparams.vision_feature_layer[bid];
|
||||
GGML_ASSERT(vlayer >= 0 && vlayer < n_layer);
|
||||
ggml_tensor * h = layer_outs[vlayer];
|
||||
|
||||
|
||||
@@ -21,7 +21,7 @@ ggml_cgraph * clip_graph_llava::build() {
|
||||
|
||||
// If we set explicit vision feature layers, only go up to the deepest one
|
||||
// NOTE: only used by granite-vision models for now
|
||||
for (const auto & feature_layer : hparams.feature_layers) {
|
||||
for (const auto & feature_layer : hparams.vision_feature_layer) {
|
||||
if (feature_layer > deepest_feature_layer) {
|
||||
deepest_feature_layer = feature_layer;
|
||||
}
|
||||
@@ -59,7 +59,7 @@ ggml_cgraph * clip_graph_llava::build() {
|
||||
|
||||
// If this is an embedding feature layer, save the output.
|
||||
// NOTE: 0 index here refers to the input to the encoder.
|
||||
if (hparams.is_feature_layer(il)) {
|
||||
if (hparams.is_vision_feature_layer(il)) {
|
||||
embedding_stack.push_back(cur);
|
||||
}
|
||||
|
||||
@@ -134,7 +134,7 @@ ggml_cgraph * clip_graph_llava::build() {
|
||||
// process vision feature layers (used by granite)
|
||||
{
|
||||
// final layer is a vision feature layer
|
||||
if (hparams.is_feature_layer(max_feature_layer)) {
|
||||
if (hparams.is_vision_feature_layer(max_feature_layer)) {
|
||||
embedding_stack.push_back(inpL);
|
||||
}
|
||||
|
||||
|
||||
@@ -518,14 +518,6 @@ size_t server_tokens::get_common_prefix(const server_tokens & b) const {
|
||||
return max_idx; // all tokens are equal
|
||||
}
|
||||
|
||||
common_chat_msg_spans server_tokens::find_message_spans(const common_chat_msg_delimiters & delims) const {
|
||||
std::map<size_t, size_t> skips;
|
||||
for (const auto & it : map_idx_to_media) {
|
||||
skips[it.first] = mtmd_input_chunk_get_n_tokens(it.second.get());
|
||||
}
|
||||
return delims.split(tokens, skips);
|
||||
}
|
||||
|
||||
bool server_tokens::validate(const struct llama_context * ctx) const {
|
||||
const llama_model * model = llama_get_model(ctx);
|
||||
const llama_vocab * vocab = llama_model_get_vocab(model);
|
||||
@@ -1112,7 +1104,15 @@ json oaicompat_chat_params_parse(
|
||||
llama_params["chat_parser"] = chat_params.parser;
|
||||
}
|
||||
|
||||
llama_params["message_delimiters"] = chat_params.message_delimiters.to_json();
|
||||
llama_params["message_spans"] = json::array();
|
||||
|
||||
for (const auto & span : chat_params.message_spans) {
|
||||
llama_params["message_spans"].push_back({
|
||||
{ "role", span.role },
|
||||
{ "pos", span.pos },
|
||||
{ "len", span.len },
|
||||
});
|
||||
}
|
||||
|
||||
// Reasoning budget: pass parameters through to sampling layer
|
||||
{
|
||||
@@ -1583,3 +1583,82 @@ server_tokens format_prompt_rerank(
|
||||
|
||||
return result;
|
||||
}
|
||||
|
||||
//
|
||||
// threadpool
|
||||
//
|
||||
|
||||
server_threadpool::~server_threadpool() {
|
||||
{
|
||||
std::lock_guard<std::mutex> lock(mtx);
|
||||
stop = true;
|
||||
}
|
||||
cv.notify_all();
|
||||
for (auto & t : threads) t.join();
|
||||
}
|
||||
|
||||
void server_threadpool::init(int n) {
|
||||
// the caller (main thread) participates as a worker, so spawn n-1 threads
|
||||
const int n_workers = std::max(1, n) - 1;
|
||||
for (int i = 0; i < n_workers; i++) {
|
||||
threads.emplace_back([this]() { run_worker(); });
|
||||
}
|
||||
}
|
||||
|
||||
void server_threadpool::run_worker() {
|
||||
while (true) {
|
||||
std::function<void()> task;
|
||||
{
|
||||
std::unique_lock<std::mutex> lock(mtx);
|
||||
cv.wait(lock, [this]() { return stop || !tasks.empty(); });
|
||||
if (stop && tasks.empty()) return;
|
||||
task = std::move(tasks.front());
|
||||
tasks.pop();
|
||||
}
|
||||
task();
|
||||
{
|
||||
std::lock_guard<std::mutex> lock(mtx);
|
||||
pending--;
|
||||
}
|
||||
cv_done.notify_all();
|
||||
}
|
||||
}
|
||||
|
||||
void server_threadpool::enqueue(std::function<void()> fn) {
|
||||
{
|
||||
std::lock_guard<std::mutex> lock(mtx);
|
||||
GGML_ASSERT(!stop);
|
||||
tasks.push(std::move(fn));
|
||||
pending++;
|
||||
}
|
||||
cv.notify_one();
|
||||
}
|
||||
|
||||
void server_threadpool::wait_all() {
|
||||
// the calling thread helps drain the queue until no tasks remain pending
|
||||
while (true) {
|
||||
std::function<void()> task;
|
||||
{
|
||||
std::lock_guard<std::mutex> lock(mtx);
|
||||
if (pending == 0) {
|
||||
return;
|
||||
}
|
||||
if (!tasks.empty()) {
|
||||
task = std::move(tasks.front());
|
||||
tasks.pop();
|
||||
}
|
||||
}
|
||||
if (task) {
|
||||
task();
|
||||
{
|
||||
std::lock_guard<std::mutex> lock(mtx);
|
||||
pending--;
|
||||
}
|
||||
cv_done.notify_all();
|
||||
} else {
|
||||
// no task available right now, but some are still pending (being run by workers)
|
||||
std::unique_lock<std::mutex> lock(mtx);
|
||||
cv_done.wait(lock, [this]() { return pending == 0 || !tasks.empty(); });
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
@@ -12,6 +12,11 @@
|
||||
#include <string>
|
||||
#include <vector>
|
||||
#include <cinttypes>
|
||||
#include <thread>
|
||||
#include <mutex>
|
||||
#include <condition_variable>
|
||||
#include <queue>
|
||||
#include <functional>
|
||||
|
||||
using json = nlohmann::ordered_json;
|
||||
|
||||
@@ -218,9 +223,6 @@ public:
|
||||
|
||||
size_t get_common_prefix(const server_tokens & b) const;
|
||||
|
||||
// split the tokens into message spans, skipping over media chunks
|
||||
common_chat_msg_spans find_message_spans(const common_chat_msg_delimiters & delims) const;
|
||||
|
||||
// make sure all text tokens are within the vocab range
|
||||
bool validate(const struct llama_context * ctx) const;
|
||||
|
||||
@@ -373,3 +375,39 @@ server_tokens format_prompt_rerank(
|
||||
mtmd_context * mctx,
|
||||
const std::string & query,
|
||||
const std::string & doc);
|
||||
|
||||
//
|
||||
// threadpool utils
|
||||
// to be used for multi-threaded sampling
|
||||
//
|
||||
|
||||
// the main thread participates as one of the pool's workers, so init(n)
|
||||
// only spawns n-1 background threads (the caller is the nth)
|
||||
struct server_threadpool {
|
||||
std::vector<std::thread> threads;
|
||||
std::queue<std::function<void()>> tasks;
|
||||
std::mutex mtx;
|
||||
std::condition_variable cv;
|
||||
std::condition_variable cv_done;
|
||||
int pending = 0;
|
||||
bool stop = false;
|
||||
|
||||
~server_threadpool();
|
||||
void init(int n);
|
||||
|
||||
template<typename T>
|
||||
void run_all(std::vector<T> & tasks, std::function<void(T&)> handler) {
|
||||
for (auto & item : tasks) {
|
||||
enqueue([&handler, &item]() {
|
||||
handler(item);
|
||||
});
|
||||
}
|
||||
// the calling thread runs tasks too, until all are done
|
||||
wait_all();
|
||||
}
|
||||
|
||||
private:
|
||||
void enqueue(std::function<void()> fn);
|
||||
void wait_all();
|
||||
void run_worker();
|
||||
};
|
||||
|
||||
+142
-25
@@ -866,6 +866,9 @@ public:
|
||||
// note: chat_params must not be refreshed upon existing sleeping state
|
||||
server_chat_params chat_params;
|
||||
|
||||
// threadpool for parallel sampling
|
||||
server_threadpool threadpool;
|
||||
|
||||
server_state_callback_t callback_state = [](server_state, json) -> void {};
|
||||
|
||||
server_context_impl() {
|
||||
@@ -1457,6 +1460,16 @@ private:
|
||||
|
||||
metrics.init();
|
||||
|
||||
// initialize threadpool
|
||||
{
|
||||
int threadpool_size = params_base.sampling_n_threads;
|
||||
if (threadpool_size <= 0) {
|
||||
threadpool_size = params_base.cpuparams.n_threads;
|
||||
}
|
||||
SRV_DBG("%s: initializing threadpool, size = %d\n", __func__, threadpool_size);
|
||||
threadpool.init(threadpool_size);
|
||||
}
|
||||
|
||||
if (params_base.cache_idle_slots) {
|
||||
if (params_base.cache_ram_mib == 0) {
|
||||
SRV_WRN("%s", "--cache-idle-slots requires --cache-ram, disabling\n");
|
||||
@@ -3436,8 +3449,8 @@ private:
|
||||
has_mtmd = true;
|
||||
}
|
||||
|
||||
const auto & spans = slot.task->params.message_spans;
|
||||
const auto last_user_pos = spans.last_user_message_pos();
|
||||
const int32_t n_before_user = slot.task->params.n_before_user;
|
||||
const bool n_before_user_known = n_before_user > 0;
|
||||
|
||||
// add prompt tokens for processing in the current batch
|
||||
while (slot.prompt.n_tokens() < slot.task->n_tokens() && batch.size() < n_batch) {
|
||||
@@ -3466,8 +3479,10 @@ private:
|
||||
|
||||
slot.n_prompt_tokens_processed++;
|
||||
|
||||
// stop the prompt batch exactly before a user message
|
||||
if (spans.is_user_start(slot.prompt.n_tokens())) {
|
||||
// stop the prompt batch exactly before the latest user input, so a checkpoint
|
||||
// can be created after the previous messages
|
||||
if (n_before_user_known &&
|
||||
slot.prompt.n_tokens() == n_before_user) {
|
||||
break;
|
||||
}
|
||||
|
||||
@@ -3496,13 +3511,8 @@ private:
|
||||
// the number of tokens added to the batch for the current slot
|
||||
const auto n_tokens_cur = batch.size() - n_tokens_prev;
|
||||
|
||||
const auto n_tokens_start = slot.prompt.n_tokens() - n_tokens_cur;
|
||||
|
||||
const bool near_prompt_end = slot.task->n_tokens() < slot.prompt.n_tokens() + n_ubatch;
|
||||
|
||||
const bool is_user_start = spans.is_user_start(n_tokens_start);
|
||||
const bool is_last_user_message = n_tokens_start == last_user_pos;
|
||||
|
||||
// entire prompt has been processed
|
||||
if (slot.prompt.n_tokens() == slot.task->n_tokens()) {
|
||||
slot.state = SLOT_STATE_DONE_PROMPT;
|
||||
@@ -3517,9 +3527,8 @@ private:
|
||||
|
||||
slot.init_sampler();
|
||||
} else {
|
||||
// skip ordinary mid-prompt checkpoints, unless the batch starts a user
|
||||
// message or we are near the end of the prompt
|
||||
if (!is_user_start && !near_prompt_end) {
|
||||
// skip ordinary mid-prompt checkpoints
|
||||
if (!n_before_user_known && !near_prompt_end) {
|
||||
do_checkpoint = false;
|
||||
}
|
||||
}
|
||||
@@ -3527,6 +3536,29 @@ private:
|
||||
const auto pos_min = llama_memory_seq_pos_min(llama_get_memory(ctx_tgt), slot.id);
|
||||
const auto pos_max = llama_memory_seq_pos_max(llama_get_memory(ctx_tgt), slot.id);
|
||||
|
||||
// checkpoints are created before the current batch is decoded, so
|
||||
// their token position is the batch start rather than the prompt end
|
||||
const int32_t n_tokens_start = slot.prompt.n_tokens() - n_tokens_cur;
|
||||
|
||||
{
|
||||
const bool is_on_user =
|
||||
n_before_user_known &&
|
||||
n_tokens_start == n_before_user;
|
||||
|
||||
const bool is_after_user =
|
||||
n_before_user_known &&
|
||||
n_tokens_start > n_before_user;
|
||||
|
||||
const bool is_allowed =
|
||||
!n_before_user_known ||
|
||||
is_on_user ||
|
||||
(is_after_user && near_prompt_end);
|
||||
|
||||
if (do_checkpoint && !is_allowed) {
|
||||
do_checkpoint = false;
|
||||
}
|
||||
}
|
||||
|
||||
// nothing to checkpoint yet
|
||||
// TODO: is this check needed?
|
||||
if (do_checkpoint && pos_min < 0) {
|
||||
@@ -3536,8 +3568,8 @@ private:
|
||||
// do not checkpoint after mtmd chunks
|
||||
do_checkpoint = do_checkpoint && !has_mtmd;
|
||||
|
||||
// no need to create checkpoints that are too close together, unless it's the last user message
|
||||
do_checkpoint = do_checkpoint && (slot.prompt.checkpoints.empty() || is_last_user_message || n_tokens_start > slot.prompt.checkpoints.back().n_tokens + params_base.checkpoint_min_step);
|
||||
// no need to create checkpoints that are too close together
|
||||
do_checkpoint = do_checkpoint && (slot.prompt.checkpoints.empty() || n_tokens_start > slot.prompt.checkpoints.back().n_tokens + params_base.checkpoint_min_step);
|
||||
SLT_DBG(slot, "main/do_checkpoint = %s, pos_min = %d, pos_max = %d\n", do_checkpoint ? "yes" : "no", pos_min, pos_max);
|
||||
|
||||
// note: we create the checkpoint before calling llama_decode(), so the current batch is not
|
||||
@@ -3680,6 +3712,12 @@ private:
|
||||
return true;
|
||||
}
|
||||
|
||||
struct sampling_task {
|
||||
server_slot * slot = nullptr;
|
||||
int32_t tok_idx = 0;
|
||||
llama_token sampled_id = LLAMA_TOKEN_NULL; // result
|
||||
};
|
||||
|
||||
void post_decode(int32_t n_batch_tokens, int32_t off, llama_batch & batch_view) {
|
||||
// for checking if a given batch index is inside batch_view
|
||||
auto is_inside_view = [&](int32_t idx) {
|
||||
@@ -3701,7 +3739,13 @@ private:
|
||||
slot.task->params.sampling.preserved_tokens.find(token) != slot.task->params.sampling.preserved_tokens.end();
|
||||
};
|
||||
|
||||
std::vector<sampling_task> smpl_tasks;
|
||||
smpl_tasks.resize(slots.size());
|
||||
bool need_sampling = false;
|
||||
|
||||
iterate(slots, [&](server_slot & slot) {
|
||||
auto & smpl_task = smpl_tasks[slot.id];
|
||||
|
||||
// optionally send prompt processing progress
|
||||
if (slot.state == SLOT_STATE_PROCESSING_PROMPT || slot.state == SLOT_STATE_DONE_PROMPT) {
|
||||
if (slot.task->params.stream && slot.task->params.return_progress) {
|
||||
@@ -3746,15 +3790,35 @@ private:
|
||||
return; // sample using speculative decoding
|
||||
}
|
||||
|
||||
// shifted according to the current sub-batch
|
||||
const int tok_idx = slot.i_batch - off;
|
||||
// otherwise, we must sample the next token
|
||||
// also shift batch idx according to the current sub-batch
|
||||
smpl_task.slot = &slot;
|
||||
smpl_task.tok_idx = slot.i_batch - off;
|
||||
need_sampling = true;
|
||||
});
|
||||
|
||||
llama_token id;
|
||||
{
|
||||
scoped_timer timer(t_sampl, n_sampl);
|
||||
id = common_sampler_sample(slot.smpl.get(), slot.ctx_tgt, tok_idx);
|
||||
// run multiple sampling tasks in parallel
|
||||
GGML_ASSERT(smpl_tasks.size() == slots.size());
|
||||
if (need_sampling) {
|
||||
llama_synchronize(ctx_tgt);
|
||||
threadpool.run_all<sampling_task>(smpl_tasks, [](sampling_task & task) {
|
||||
if (task.slot) {
|
||||
task.sampled_id = common_sampler_sample(task.slot->smpl.get(),
|
||||
task.slot->ctx_tgt, task.tok_idx);
|
||||
}
|
||||
});
|
||||
}
|
||||
|
||||
iterate(slots, [&](server_slot & slot) {
|
||||
auto & smpl_task = smpl_tasks[slot.id];
|
||||
|
||||
if (!smpl_task.slot) {
|
||||
return;
|
||||
}
|
||||
|
||||
auto tok_idx = smpl_task.tok_idx;
|
||||
auto id = smpl_task.sampled_id;
|
||||
|
||||
slot.i_batch = -1;
|
||||
|
||||
common_sampler_accept(slot.smpl.get(), id, true);
|
||||
@@ -4036,6 +4100,54 @@ void server_context::set_state_callback(server_state_callback_t callback) {
|
||||
});
|
||||
}
|
||||
|
||||
// compute the number of tokens before the last user message in the prompt
|
||||
static int32_t prompt_get_n_before_user(
|
||||
const json & message_spans,
|
||||
const std::string & prompt,
|
||||
const std::vector<raw_buffer> & files,
|
||||
const llama_vocab * vocab,
|
||||
mtmd_context * mctx) {
|
||||
int32_t result = -1;
|
||||
int32_t byte_pos = -1;
|
||||
|
||||
for (const auto & span : message_spans) {
|
||||
const std::string role = json_value(span, "role", std::string());
|
||||
|
||||
if (role == "user") {
|
||||
byte_pos = json_value(span, "pos", -1);
|
||||
}
|
||||
}
|
||||
|
||||
if (byte_pos >= 0) {
|
||||
GGML_ASSERT((size_t) byte_pos <= prompt.size());
|
||||
|
||||
const std::string prefix = prompt.substr(0, (size_t) byte_pos);
|
||||
|
||||
const std::string marker = get_media_marker();
|
||||
size_t n_prefix_media = 0;
|
||||
for (size_t pos = 0; (pos = prefix.find(marker, pos)) != std::string::npos; pos += marker.size()) {
|
||||
n_prefix_media++;
|
||||
}
|
||||
|
||||
GGML_ASSERT(n_prefix_media <= files.size());
|
||||
|
||||
if (mctx != nullptr && n_prefix_media > 0) {
|
||||
// TODO: this makes a copy - avoid it
|
||||
std::vector<raw_buffer> prefix_files(files.begin(), files.begin() + n_prefix_media);
|
||||
|
||||
result = (int32_t) process_mtmd_prompt(mctx, prefix, prefix_files).size();
|
||||
} else {
|
||||
result = (int32_t) tokenize_input_prompts(vocab, nullptr, prefix, true, true)[0].size();
|
||||
}
|
||||
|
||||
SRV_TRC("message_spans: last user message: byte_pos=%d, media=%zu, n_before_user=%d\n",
|
||||
byte_pos, n_prefix_media, result);
|
||||
}
|
||||
|
||||
return result;
|
||||
}
|
||||
|
||||
|
||||
//
|
||||
// server_routes
|
||||
//
|
||||
@@ -4083,10 +4195,6 @@ std::unique_ptr<server_res_generator> server_routes::handle_completions_impl(
|
||||
|
||||
// tasks.reserve(inputs.size()); // TODO: this is inaccurate due to child tasks
|
||||
|
||||
// message delimiters for checkpointing
|
||||
auto delimiters = common_chat_msg_delimiters_parse(json_value(data, "message_delimiters", json::array()));
|
||||
delimiters.tokenize(ctx_server.vocab);
|
||||
|
||||
for (size_t i = 0; i < inputs.size(); i++) {
|
||||
server_task task = server_task(type);
|
||||
|
||||
@@ -4100,7 +4208,16 @@ std::unique_ptr<server_res_generator> server_routes::handle_completions_impl(
|
||||
meta->logit_bias_eog,
|
||||
data);
|
||||
|
||||
task.params.message_spans = task.tokens.find_message_spans(delimiters);
|
||||
const auto message_spans = json_value(data, "message_spans", json::array());
|
||||
if (prompt.is_string() && message_spans.is_array()) {
|
||||
task.params.n_before_user =
|
||||
prompt_get_n_before_user(
|
||||
message_spans,
|
||||
prompt.get<std::string>(),
|
||||
files,
|
||||
ctx_server.vocab,
|
||||
ctx_server.mctx);
|
||||
}
|
||||
|
||||
task.id_slot = json_value(data, "id_slot", -1);
|
||||
|
||||
|
||||
@@ -5,7 +5,6 @@
|
||||
#include "build-info.h"
|
||||
#include "preset.h"
|
||||
#include "download.h"
|
||||
#include "http.h"
|
||||
|
||||
#include <cpp-httplib/httplib.h> // TODO: remove this once we use HTTP client from download.h
|
||||
#include <sheredom/subprocess.h>
|
||||
@@ -26,7 +25,14 @@
|
||||
#include <sstream>
|
||||
#include <cstring>
|
||||
|
||||
#ifndef _WIN32
|
||||
#ifdef _WIN32
|
||||
#include <winsock2.h>
|
||||
#include <windows.h>
|
||||
#else
|
||||
#include <sys/socket.h>
|
||||
#include <netinet/in.h>
|
||||
#include <arpa/inet.h>
|
||||
#include <unistd.h>
|
||||
extern char **environ;
|
||||
#endif
|
||||
|
||||
@@ -218,7 +224,7 @@ void server_model_meta::update_caps() {
|
||||
});
|
||||
params.offline = true;
|
||||
// params.skip_download = true; // TODO: ideally, we should validate the model here, but it takes too much time
|
||||
common_params_handle_models(params, LLAMA_EXAMPLE_SERVER, {});
|
||||
common_params_handle_models(params, LLAMA_EXAMPLE_SERVER);
|
||||
if (params.mmproj.path.empty()) {
|
||||
multimodal = { false, false };
|
||||
} else {
|
||||
@@ -698,6 +704,66 @@ std::optional<server_model_meta> server_models::get_meta(const std::string & nam
|
||||
return std::nullopt;
|
||||
}
|
||||
|
||||
static int get_free_port() {
|
||||
#ifdef _WIN32
|
||||
WSADATA wsaData;
|
||||
if (WSAStartup(MAKEWORD(2, 2), &wsaData) != 0) {
|
||||
return -1;
|
||||
}
|
||||
typedef SOCKET native_socket_t;
|
||||
#define INVALID_SOCKET_VAL INVALID_SOCKET
|
||||
#define CLOSE_SOCKET(s) closesocket(s)
|
||||
#else
|
||||
typedef int native_socket_t;
|
||||
#define INVALID_SOCKET_VAL -1
|
||||
#define CLOSE_SOCKET(s) close(s)
|
||||
#endif
|
||||
|
||||
native_socket_t sock = socket(AF_INET, SOCK_STREAM, 0);
|
||||
if (sock == INVALID_SOCKET_VAL) {
|
||||
#ifdef _WIN32
|
||||
WSACleanup();
|
||||
#endif
|
||||
return -1;
|
||||
}
|
||||
|
||||
struct sockaddr_in serv_addr;
|
||||
std::memset(&serv_addr, 0, sizeof(serv_addr));
|
||||
serv_addr.sin_family = AF_INET;
|
||||
serv_addr.sin_addr.s_addr = htonl(INADDR_ANY);
|
||||
serv_addr.sin_port = htons(0);
|
||||
|
||||
if (bind(sock, (struct sockaddr*)&serv_addr, sizeof(serv_addr)) != 0) {
|
||||
CLOSE_SOCKET(sock);
|
||||
#ifdef _WIN32
|
||||
WSACleanup();
|
||||
#endif
|
||||
return -1;
|
||||
}
|
||||
|
||||
#ifdef _WIN32
|
||||
int namelen = sizeof(serv_addr);
|
||||
#else
|
||||
socklen_t namelen = sizeof(serv_addr);
|
||||
#endif
|
||||
if (getsockname(sock, (struct sockaddr*)&serv_addr, &namelen) != 0) {
|
||||
CLOSE_SOCKET(sock);
|
||||
#ifdef _WIN32
|
||||
WSACleanup();
|
||||
#endif
|
||||
return -1;
|
||||
}
|
||||
|
||||
int port = ntohs(serv_addr.sin_port);
|
||||
|
||||
CLOSE_SOCKET(sock);
|
||||
#ifdef _WIN32
|
||||
WSACleanup();
|
||||
#endif
|
||||
|
||||
return port;
|
||||
}
|
||||
|
||||
// helper to convert vector<string> to char **
|
||||
// pointers are only valid as long as the original vector is valid
|
||||
static std::vector<char *> to_char_ptr_array(const std::vector<std::string> & vec) {
|
||||
@@ -801,7 +867,7 @@ void server_models::load(const std::string & name, const load_options & opts) {
|
||||
// prepare new instance info
|
||||
instance_t inst;
|
||||
inst.meta = meta;
|
||||
inst.meta.port = common_http_get_free_port();
|
||||
inst.meta.port = get_free_port();
|
||||
inst.meta.status = SERVER_MODEL_STATUS_LOADING;
|
||||
inst.meta.loaded_info = json{};
|
||||
inst.meta.last_used = ggml_time_ms();
|
||||
@@ -1327,9 +1393,7 @@ struct server_download_state : public common_download_callback {
|
||||
|
||||
bool run(common_params & params) {
|
||||
try {
|
||||
common_params_handle_models_params p;
|
||||
p.callback = this;
|
||||
common_params_handle_models(params, LLAMA_EXAMPLE_SERVER, p);
|
||||
common_params_handle_models(params, LLAMA_EXAMPLE_SERVER, this);
|
||||
is_ok = true;
|
||||
} catch (const std::exception & e) {
|
||||
auto model_name = params.model.get_name();
|
||||
|
||||
@@ -591,11 +591,10 @@ json server_task_result_cmpl_final::to_json_oaicompat_resp() {
|
||||
|
||||
for (const common_chat_tool_call & tool_call : oaicompat_msg.tool_calls) {
|
||||
output.push_back(json {
|
||||
{"id", "fc_" + tool_call.id},
|
||||
{"type", "function_call"},
|
||||
{"status", "completed"},
|
||||
{"arguments", tool_call.arguments},
|
||||
{"call_id", "call_" + tool_call.id},
|
||||
{"call_id", "fc_" + tool_call.id},
|
||||
{"name", tool_call.name},
|
||||
});
|
||||
}
|
||||
@@ -691,11 +690,10 @@ json server_task_result_cmpl_final::to_json_oaicompat_resp_stream() {
|
||||
|
||||
for (const common_chat_tool_call & tool_call : oaicompat_msg.tool_calls) {
|
||||
const json output_item = {
|
||||
{"id", "fc_" + tool_call.id},
|
||||
{"type", "function_call"},
|
||||
{"status", "completed"},
|
||||
{"arguments", tool_call.arguments},
|
||||
{"call_id", "call_" + tool_call.id},
|
||||
{"call_id", "fc_" + tool_call.id},
|
||||
{"name", tool_call.name}
|
||||
};
|
||||
server_sent_events.push_back(json {
|
||||
@@ -1279,9 +1277,8 @@ json server_task_result_cmpl_partial::to_json_oaicompat_resp() {
|
||||
{"data", json {
|
||||
{"type", "response.output_item.added"},
|
||||
{"item", json {
|
||||
{"id", "fc_" + diff.tool_call_delta.id},
|
||||
{"arguments", ""},
|
||||
{"call_id", "call_" + diff.tool_call_delta.id},
|
||||
{"call_id", "fc_" + diff.tool_call_delta.id},
|
||||
{"name", diff.tool_call_delta.name},
|
||||
{"type", "function_call"},
|
||||
{"status", "in_progress"},
|
||||
|
||||
@@ -62,6 +62,9 @@ struct task_params {
|
||||
|
||||
int32_t n_cache_reuse = 0; // min chunk size to attempt reusing from the cache via KV shifting (0 = disabled)
|
||||
|
||||
// number of prompt tokens before the latest user message
|
||||
int32_t n_before_user = -1;
|
||||
|
||||
int64_t t_max_prompt_ms = -1; // TODO: implement
|
||||
int64_t t_max_predict_ms = -1; // if positive, limit the generation phase to this time limit
|
||||
|
||||
@@ -89,9 +92,6 @@ struct task_params {
|
||||
// per-request parameters for chat parsing
|
||||
common_chat_parser_params chat_parser_params;
|
||||
|
||||
// message spans for checkpointing
|
||||
common_chat_msg_spans message_spans;
|
||||
|
||||
// Embeddings
|
||||
int32_t embd_normalize = 2; // (-1=none, 0=max absolute int16, 1=taxicab, 2=Euclidean/L2, >2=p-norm)
|
||||
|
||||
|
||||
+16
-46
@@ -35,19 +35,6 @@ static inline void signal_handler(int signal) {
|
||||
shutdown_handler(signal);
|
||||
}
|
||||
|
||||
// satisfies -Wmissing-declarations (used by llama command)
|
||||
int llama_server(int argc, char ** argv);
|
||||
|
||||
// to be used via CLI (argc / argv are used by router mode only)
|
||||
int llama_server(common_params & params, int argc, char ** argv);
|
||||
void llama_server_terminate();
|
||||
void llama_server_terminate() {
|
||||
if (shutdown_handler) {
|
||||
shutdown_handler(0);
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
// wrapper function that handles exceptions and logs errors
|
||||
// this is to make sure handler_t never throws exceptions; instead, it returns an error response
|
||||
static server_http_context::handler_t ex_wrapper(server_http_context::handler_t func) {
|
||||
@@ -84,6 +71,9 @@ static server_http_context::handler_t ex_wrapper(server_http_context::handler_t
|
||||
};
|
||||
}
|
||||
|
||||
// satisfies -Wmissing-declarations
|
||||
int llama_server(int argc, char ** argv);
|
||||
|
||||
int llama_server(int argc, char ** argv) {
|
||||
std::setlocale(LC_NUMERIC, "C");
|
||||
|
||||
@@ -99,23 +89,6 @@ int llama_server(int argc, char ** argv) {
|
||||
llama_backend_init();
|
||||
llama_numa_init(params.numa);
|
||||
|
||||
return llama_server(params, argc, argv);
|
||||
}
|
||||
|
||||
int llama_server(common_params & params, int argc, char ** argv) {
|
||||
bool is_run_by_cli = (argv == nullptr);
|
||||
|
||||
// note: router mode also accepts -hf remote-preset, so we need to check that first
|
||||
if (!is_run_by_cli && !params.model.hf_repo.empty()) {
|
||||
try {
|
||||
common_params_handle_models_params handle_params;
|
||||
handle_params.preset_only = true;
|
||||
common_params_handle_models(params, LLAMA_EXAMPLE_SERVER, handle_params);
|
||||
} catch (const std::exception & e) {
|
||||
// ignored for now
|
||||
}
|
||||
}
|
||||
|
||||
// router server never loads a model and must not touch the GPU
|
||||
const bool is_router_server = params.model.path.empty()
|
||||
&& params.model.hf_repo.empty();
|
||||
@@ -288,10 +261,9 @@ int llama_server(common_params & params, int argc, char ** argv) {
|
||||
|
||||
if (child.is_child() && child.get_mode() == SERVER_CHILD_MODE_DOWNLOAD) {
|
||||
return child.run_download(params);
|
||||
} else if (!is_router_server && !is_run_by_cli) {
|
||||
} else if (!is_router_server) {
|
||||
// single-model mode (NOT spawned by router)
|
||||
// if this is invoked by CLI, model downloading should be already handled
|
||||
common_params_handle_models(params, LLAMA_EXAMPLE_SERVER, {});
|
||||
common_params_handle_models(params, LLAMA_EXAMPLE_SERVER);
|
||||
}
|
||||
|
||||
//
|
||||
@@ -373,22 +345,20 @@ int llama_server(common_params & params, int argc, char ** argv) {
|
||||
};
|
||||
}
|
||||
|
||||
// register signal handler if not running by CLI
|
||||
if (!is_run_by_cli) {
|
||||
// TODO: refactor in common/console
|
||||
#if defined (__unix__) || (defined (__APPLE__) && defined (__MACH__))
|
||||
struct sigaction sigint_action;
|
||||
sigint_action.sa_handler = signal_handler;
|
||||
sigemptyset (&sigint_action.sa_mask);
|
||||
sigint_action.sa_flags = 0;
|
||||
sigaction(SIGINT, &sigint_action, NULL);
|
||||
sigaction(SIGTERM, &sigint_action, NULL);
|
||||
struct sigaction sigint_action;
|
||||
sigint_action.sa_handler = signal_handler;
|
||||
sigemptyset (&sigint_action.sa_mask);
|
||||
sigint_action.sa_flags = 0;
|
||||
sigaction(SIGINT, &sigint_action, NULL);
|
||||
sigaction(SIGTERM, &sigint_action, NULL);
|
||||
#elif defined (_WIN32)
|
||||
auto console_ctrl_handler = +[](DWORD ctrl_type) -> BOOL {
|
||||
return (ctrl_type == CTRL_C_EVENT) ? (signal_handler(SIGINT), true) : false;
|
||||
};
|
||||
SetConsoleCtrlHandler(reinterpret_cast<PHANDLER_ROUTINE>(console_ctrl_handler), true);
|
||||
auto console_ctrl_handler = +[](DWORD ctrl_type) -> BOOL {
|
||||
return (ctrl_type == CTRL_C_EVENT) ? (signal_handler(SIGINT), true) : false;
|
||||
};
|
||||
SetConsoleCtrlHandler(reinterpret_cast<PHANDLER_ROUTINE>(console_ctrl_handler), true);
|
||||
#endif
|
||||
}
|
||||
|
||||
if (is_router_server) {
|
||||
SRV_INF("router server is listening on %s\n", ctx_http.listening_address.c_str());
|
||||
|
||||
@@ -256,25 +256,6 @@ def test_router_reload_models():
|
||||
os.remove(preset_path)
|
||||
|
||||
|
||||
def test_router_remote_preset():
|
||||
global server
|
||||
server.model_hf_repo = "ggml-org/test-preset-ci"
|
||||
server.model_hf_file = None
|
||||
server.offline = False
|
||||
server.start()
|
||||
|
||||
# Should see preset models in GET /models
|
||||
res = server.make_request("GET", "/models")
|
||||
assert res.status_code == 200
|
||||
ids = {item["id"] for item in res.body.get("data", [])}
|
||||
assert "tinygemma3-preset" in ids
|
||||
assert "stories260K-test" in ids
|
||||
|
||||
# Should be able to load a preset model
|
||||
model_id = "tinygemma3-preset"
|
||||
_load_model_and_wait(model_id)
|
||||
|
||||
|
||||
MODEL_DOWNLOAD_ID = "ggml-org/test-model-router-download:F16"
|
||||
MODEL_DOWNLOAD_TIMEOUT = 30
|
||||
|
||||
|
||||
@@ -545,8 +545,7 @@ class ModelsStore {
|
||||
* 1. Model from active conversation's last assistant response (if loaded)
|
||||
* 2. Model from active conversation's last assistant response (if not loaded)
|
||||
* 3. First loaded model (not from active conversation)
|
||||
* 4. A favorite model
|
||||
* 5. First available model
|
||||
* 4. First available model
|
||||
*/
|
||||
async ensureFirstModelSelected(): Promise<void> {
|
||||
if (this.selectedModelName) return;
|
||||
@@ -575,13 +574,6 @@ class ModelsStore {
|
||||
return;
|
||||
}
|
||||
|
||||
// Try loading a favorite model
|
||||
const favorite = this.favoriteModelIds.values().next()?.value
|
||||
if (favorite) {
|
||||
await this.selectModelById(favorite);
|
||||
return;
|
||||
}
|
||||
|
||||
// Fall back to the first available model
|
||||
await this.selectModelById(availableModels[0].id);
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user