Compare commits

...

8 Commits

Author SHA1 Message Date
Xuan-Son Nguyen 6ee0f65793 server: refactor/generalize input file schema (#24299)
* server: refactor/generalize input file schema

* wire up input_video, accept raw base64

* nits

* nits (2)

* fix windows
2026-06-22 16:42:47 +02:00
Pascal 099b579acb ui: model status and load progress via /models/sse feed (#24878)
* ui: model status and load progress via /models/sse feed

* ui: centralize SSE wire-format delimiters into shared constants for the chat and /models/sse parsers

* ui: type /models/sse event names as a ServerModelsSseEventType enum

Address review from allozaur
2026-06-22 15:55:30 +02:00
Neo Zhang f8cc15f163 [SYCL] support bf16 on bin_bcast OP and unary OPs (#24838)
* support bf16 on bin_bcast OP and unary OPs

* support the older Intel compiler than 2026.0
2026-06-22 14:09:02 +03:00
Tim Neumann 37957e8531 sampling : remove unconditional softmax+sort in top-n-sigma sampler (#22645) 2026-06-22 14:08:32 +03:00
Pascal d0f9d2e5ac server: fix edit_file crash on append at end of file (line_start -1) (#24893)
line_start -1 normalized to n+1, so append inserted at lines.begin() + n + 1,
one past end() -> heap-buffer-overflow in vector::_M_range_insert.

Normalize -1 to n (insert at end()), restrict -1 to append mode and reject it
for replace/delete instead of silently clobbering the last line. Parenthesize
the insert offset so empty-file append computes the position as int first,
avoiding a transient begin() - 1 on a null vector data pointer.
2026-06-22 10:55:28 +02:00
aafsmarak 0ef6f06d55 docs/android.md: Add dependency libandroid-spawn for building in termux (#21812)
Fixes https://github.com/ggml-org/llama.cpp/issues/18615
2026-06-22 05:48:31 +02:00
Aldehir Rojas 52b3df0023 common/peg : implement ac parser for stricter grammar generation (#24869)
* common/peg : implement ac parser

* cont : extract functions

* cont : tidy up

* cont : remove a test

* cont : move ac() def
2026-06-21 16:20:58 -05:00
Xuan-Son Nguyen 7c082bc417 server: fix report progress for loading spec models, add "stages" list (#24870)
* server: fix report progress for loading spec models, add "stages" list

* improve

* nits

* nits 2
2026-06-21 17:36:52 +02:00
30 changed files with 927 additions and 213 deletions
+5 -4
View File
@@ -395,10 +395,11 @@ common_peg_parser analyze_tools::build_tool_parser_tag_tagged(parser_build_conte
arguments.name_suffix) +
arguments.value_prefix +
(schema_info.resolves_to_string(param_schema) ?
p.tool_arg_string_value(until_suffix) :
p.tool_arg_json_value(p.schema(
p.json(), "tool-" + name + "-arg-" + param_name + "-schema", param_schema, false))) +
p.tool_arg_close(p.literal(arguments.value_suffix)));
p.ac(p.tool_arg_string_value(until_suffix) +
p.tool_arg_close(p.literal(arguments.value_suffix)), arguments.value_suffix) :
(p.tool_arg_json_value(p.schema(
p.json(), "tool-" + name + "-arg-" + param_name + "-schema", param_schema, false)) +
p.tool_arg_close(p.literal(arguments.value_suffix)))));
auto named_arg = p.rule("tool-" + name + "-arg-" + param_name, arg);
if (is_required) {
+102 -29
View File
@@ -921,6 +921,10 @@ struct parser_executor {
common_peg_parse_result operator()(const common_peg_gbnf_parser & p) {
return arena.parse(p.child, ctx, start_pos);
}
common_peg_parse_result operator()(const common_peg_ac_parser & p) {
return arena.parse(p.child, ctx, start_pos);
}
};
common_peg_parse_result common_peg_arena::parse(common_peg_parse_context & ctx, size_t start) const {
@@ -989,7 +993,8 @@ void common_peg_arena::resolve_refs() {
std::is_same_v<T, common_peg_not_parser> ||
std::is_same_v<T, common_peg_tag_parser> ||
std::is_same_v<T, common_peg_atomic_parser> ||
std::is_same_v<T, common_peg_gbnf_parser>) {
std::is_same_v<T, common_peg_gbnf_parser> ||
std::is_same_v<T, common_peg_ac_parser>) {
p.child = resolve_ref(p.child);
} else if constexpr (std::is_same_v<T, common_peg_rule_parser>) {
p.child = resolve_ref(p.child);
@@ -1070,6 +1075,8 @@ std::string common_peg_arena::dump_impl(common_peg_parser_id
return "Atomic(" + dump_impl(p.child, visited) + ")";
} else if constexpr (std::is_same_v<T, common_peg_gbnf_parser>) {
return "Gbnf(" + p.grammar + ", " + dump_impl(p.child, visited) + ")";
} else if constexpr (std::is_same_v<T, common_peg_ac_parser>) {
return "Ac(" + string_join(p.delimiters, " | ") + ", " + dump_impl(p.child, visited) + ")";
} else if constexpr (std::is_same_v<T, common_peg_any_parser>) {
return "Any";
} else if constexpr (std::is_same_v<T, common_peg_space_parser>) {
@@ -1479,6 +1486,13 @@ common_peg_parser common_peg_parser_builder::json_member(const std::string & key
});
}
common_peg_parser common_peg_parser_builder::ac(const common_peg_parser & p, const std::vector<std::string> & delimiters) {
if (delimiters.empty()) {
throw std::runtime_error("ac parser requires at least one delimiter");
}
return add(common_peg_ac_parser{p, delimiters});
}
static std::string gbnf_escape_char_class(uint32_t c) {
if (c == '-' || c == ']' || c == '[' || c == '\\') {
return "\\" + std::string(1, (char) c);
@@ -1529,14 +1543,22 @@ static std::string gbnf_escape_char_class(uint32_t c) {
return std::string(buf);
}
// GBNF grammar matching strings that contain no string in `strings` as a
// substring. Emits the complement of an Aho-Corasick automaton DFA and returns
// the start state rule name.
//
// ref: https://github.com/ggml-org/llama.cpp/pull/24839
static std::string gbnf_excluding_grammar(const common_grammar_builder & builder,
const std::string & prefix,
const std::vector<std::string> & strings) {
static std::string gbnf_char_class(const std::vector<uint32_t> & chars, bool negate) {
std::string s = negate ? "[^" : "[";
for (uint32_t ch : chars) {
s += gbnf_escape_char_class(ch);
}
return s + "]";
}
static std::string gbnf_ac_grammar(
const common_grammar_builder & builder,
const std::string & prefix,
const std::vector<std::string> & strings,
const std::function<std::string(const std::vector<uint32_t> &,
const std::map<size_t, std::vector<uint32_t>> &,
const std::vector<uint32_t> &,
const std::function<std::string(size_t)> &)> & build_rule) {
aho_corasick ac(strings);
auto state_name = [&](size_t s) -> std::string {
@@ -1548,42 +1570,30 @@ static std::string gbnf_excluding_grammar(const common_grammar_builder & builder
return prefix + "-" + num;
};
auto char_class = [](const std::vector<uint32_t> & chars, bool negate) {
std::string s = negate ? "[^" : "[";
for (uint32_t ch : chars) {
s += gbnf_escape_char_class(ch);
}
return s + "]";
};
for (size_t q = 0; q < ac.num_states(); q++) {
if (ac.is_terminal(q)) {
continue; // match states are dropped
continue; // match states
}
std::map<size_t, std::vector<uint32_t>> buckets;
std::vector<uint32_t> excluded;
std::vector<uint32_t> completing; // chars that complete a delimiter
std::vector<uint32_t> specific; // chars with an explicit transition
for (uint32_t c : ac.alphabet) {
size_t d = ac.next(q, c);
if (ac.is_terminal(d)) {
excluded.push_back(c); // completes a forbidden string -> omit
completing.push_back(c);
specific.push_back(c);
} else if (d != 0) {
buckets[d].push_back(c); // specific non-root destination
excluded.push_back(c);
specific.push_back(c);
}
}
std::string rhs = "|"; // every state is accepting
for (const auto & [d, chars] : buckets) {
rhs += " " + char_class(chars, false) + " " + state_name(d) + " |";
}
rhs += " " + char_class(excluded, true) + " " + state_name(0);
builder.add_rule(state_name(q), rhs);
builder.add_rule(state_name(q), build_rule(completing, buckets, specific, state_name));
}
// An empty delimiter makes the start state terminal. Emit an entry rule
// that matches nothing so the returned reference stays valid.
// that matches the empty string so the returned reference stays valid.
if (ac.is_terminal(0)) {
builder.add_rule(prefix, "|");
}
@@ -1591,6 +1601,54 @@ static std::string gbnf_excluding_grammar(const common_grammar_builder & builder
return state_name(0);
}
// GBNF grammar matching strings that contain no string in `strings` as a
// substring. Emits the complement of an Aho-Corasick automaton DFA and returns
// the start state rule name.
//
// ref: https://github.com/ggml-org/llama.cpp/pull/24839
static std::string gbnf_excluding_grammar(const common_grammar_builder & builder,
const std::string & prefix,
const std::vector<std::string> & strings) {
return gbnf_ac_grammar(builder, prefix, strings,
[](const std::vector<uint32_t> & /*completing*/,
const std::map<size_t, std::vector<uint32_t>> & buckets,
const std::vector<uint32_t> & specific,
const std::function<std::string(size_t)> & state_name) {
// every state is accepting and completing chars get no
// alternative, so a forbidden string can never be matched
std::string rhs = "|";
for (const auto & [d, chars] : buckets) {
rhs += " " + gbnf_char_class(chars, false) + " " + state_name(d) + " |";
}
rhs += " " + gbnf_char_class(specific, true) + " " + state_name(0);
return rhs;
});
}
// GBNF grammar matching everything up to and including the first occurrence of
// any string in `strings`. Emits the Aho-Corasick automaton DFA and returns
// the start state rule name.
static std::string gbnf_including_grammar(const common_grammar_builder & builder,
const std::string & prefix,
const std::vector<std::string> & strings) {
return gbnf_ac_grammar(builder, prefix, strings,
[](const std::vector<uint32_t> & completing,
const std::map<size_t, std::vector<uint32_t>> & buckets,
const std::vector<uint32_t> & specific,
const std::function<std::string(size_t)> & state_name) {
std::vector<std::string> alts;
if (!completing.empty()) {
alts.push_back(gbnf_char_class(completing, false)); // terminate on match
}
for (const auto & [d, chars] : buckets) {
alts.push_back(gbnf_char_class(chars, false) + " " + state_name(d));
}
// every other character keeps scanning from the start state
alts.push_back(gbnf_char_class(specific, true) + " " + state_name(0));
return string_join(alts, " | ");
});
}
static std::set<std::string> collect_reachable_rules(
const common_peg_arena & arena,
const common_peg_parser_id & rule
@@ -1628,6 +1686,7 @@ static std::set<std::string> collect_reachable_rules(
std::is_same_v<T, common_peg_tag_parser> ||
std::is_same_v<T, common_peg_atomic_parser> ||
std::is_same_v<T, common_peg_gbnf_parser> ||
std::is_same_v<T, common_peg_ac_parser> ||
std::is_same_v<T, common_peg_schema_parser>) {
visit(p.child);
} else if constexpr (std::is_same_v<T, common_peg_rule_parser>) {
@@ -1822,6 +1881,8 @@ void common_peg_arena::build_grammar(const common_grammar_builder & builder, boo
return to_gbnf(p.child);
} else if constexpr (std::is_same_v<T, common_peg_gbnf_parser>) {
return p.grammar;
} else if constexpr (std::is_same_v<T, common_peg_ac_parser>) {
return gbnf_including_grammar(builder, "ac-" + std::to_string(id), p.delimiters);
} else {
static_assert(is_always_false_v<T>);
}
@@ -1958,6 +2019,8 @@ static nlohmann::json serialize_parser_variant(const common_peg_parser_variant &
};
} else if constexpr (std::is_same_v<T, common_peg_gbnf_parser>) {
return json{{"type", "gbnf"}, {"child", p.child}, {"grammar", p.grammar}};
} else if constexpr (std::is_same_v<T, common_peg_ac_parser>) {
return json{{"type", "ac"}, {"child", p.child}, {"delimiters", p.delimiters}};
}
}, variant);
}
@@ -2130,6 +2193,16 @@ static common_peg_parser_variant deserialize_parser_variant(const nlohmann::json
};
}
if (type == "ac") {
if (!j.contains("child") || !j.contains("delimiters") || !j["delimiters"].is_array() || j["delimiters"].empty()) {
throw std::runtime_error("ac parser requires 'child' and a non-empty 'delimiters' array");
}
return common_peg_ac_parser{
j["child"].get<common_peg_parser_id>(),
j["delimiters"].get<std::vector<std::string>>(),
};
}
throw std::runtime_error("Unknown parser type: " + type);
}
+14 -1
View File
@@ -275,6 +275,11 @@ struct common_peg_gbnf_parser {
std::string grammar;
};
struct common_peg_ac_parser {
common_peg_parser_id child;
std::vector<std::string> delimiters;
};
// Variant holding all parser types
using common_peg_parser_variant = std::variant<
common_peg_epsilon_parser,
@@ -296,7 +301,8 @@ using common_peg_parser_variant = std::variant<
common_peg_ref_parser,
common_peg_atomic_parser,
common_peg_tag_parser,
common_peg_gbnf_parser
common_peg_gbnf_parser,
common_peg_ac_parser
>;
class common_peg_arena {
@@ -514,6 +520,13 @@ class common_peg_parser_builder {
// the child's grammar. Parsing delegates entirely to the child.
common_peg_parser gbnf(const common_peg_parser & p, const std::string & grammar) { return add(common_peg_gbnf_parser{p, grammar}); }
// Wraps a child parser but emits a GBNF grammar built from the Aho-Corasick
// automaton of `delimiters`, matching everything up to and including the
// first delimiter. Parsing delegates entirely to the child, which is
// responsible for consuming the delimiter (e.g. until(D) + literal(D)).
common_peg_parser ac(const common_peg_parser & p, const std::vector<std::string> & delimiters);
common_peg_parser ac(const common_peg_parser & p, const std::string & delimiter) { return ac(p, std::vector<std::string>{delimiter}); }
void set_root(const common_peg_parser & p);
common_peg_arena build();
+1 -1
View File
@@ -29,7 +29,7 @@ With Termux, you can install and run `llama.cpp` as if the environment were Linu
```
$ apt update && apt upgrade -y
$ apt install git cmake
$ apt install git cmake libandroid-spawn
```
Then, follow the [build instructions](https://github.com/ggml-org/llama.cpp/blob/master/docs/build.md), specifically for CMake.
+5
View File
@@ -293,6 +293,11 @@ inline void ggml_sycl_op_bin_bcast(ggml_backend_sycl_context & ctx, const ggml_t
(sycl::ext::oneapi::bfloat16 *) dst->data, ne00, ne01, ne02, ne03, ne10, ne11, ne12, ne13, ne0, ne1, ne2,
ne3, nb00, nb01, nb02, nb03, nb10, nb11, nb12, nb13, nb0, nb1, nb2, nb3, ggml_is_contiguous(src0),
ggml_is_contiguous(src1), ggml_is_permuted(src0), ggml_is_permuted(src1), main_stream);
} else if (src0->type == GGML_TYPE_BF16 && src1->type == GGML_TYPE_F32 && dst->type == GGML_TYPE_BF16) {
op()((const sycl::ext::oneapi::bfloat16 *) src0->data, (const float *) src1->data,
(sycl::ext::oneapi::bfloat16 *) dst->data, ne00, ne01, ne02, ne03, ne10, ne11, ne12, ne13, ne0, ne1, ne2,
ne3, nb00, nb01, nb02, nb03, nb10, nb11, nb12, nb13, nb0, nb1, nb2, nb3, ggml_is_contiguous(src0),
ggml_is_contiguous(src1), ggml_is_permuted(src0), ggml_is_permuted(src1), main_stream);
#endif
} else {
fprintf(stderr, "%s: unsupported types: dst: %s, src0: %s, src1: %s\n", __func__, ggml_type_name(dst->type),
+155 -53
View File
@@ -43,14 +43,44 @@ static __dpct_inline__ T op_sgn(T x) {
return x > static_cast<T>(0.f) ? static_cast<T>(1.f) : ((x < static_cast<T>(0.f) ? static_cast<T>(-1.f) : static_cast<T>(0.f)));
}
template<typename T>
static __dpct_inline__ T op_abs(T x) {
return sycl::fabs(x);
if constexpr (std::is_same_v<T, sycl::ext::oneapi::bfloat16>) {
return sycl::ext::oneapi::experimental::fabs(x); // or experimental namespace if needed
} else {
return sycl::fabs(x);
}
}
template<typename T>
static __dpct_inline__ T op_expm1(T x) {
if constexpr (std::is_same_v<T, sycl::ext::oneapi::bfloat16>) {
return static_cast<sycl::ext::oneapi::bfloat16>(
sycl::expm1(static_cast<float>(x))
);
} else {
return sycl::expm1(x);
}
}
template<typename T>
static __dpct_inline__ T op_elu(T x) {
return (x > static_cast<T>(0.f)) ? x : sycl::expm1(x);
return (x > static_cast<T>(0.f)) ? x : op_expm1(x);
}
template<typename T>
static __dpct_inline__ T op_tanh(T x) {
if constexpr (std::is_same_v<T, sycl::ext::oneapi::bfloat16>) {
constexpr int ver = __INTEL_LLVM_COMPILER;
#if defined(__INTEL_LLVM_COMPILER) && (__INTEL_LLVM_COMPILER >= 20260000)
return sycl::ext::oneapi::experimental::tanh(x);
#else
return static_cast<T>(sycl::tanh(static_cast<float>(x)));
#endif
} else {
return sycl::tanh(x);
}
}
template<typename T>
@@ -59,74 +89,106 @@ static __dpct_inline__ T op_gelu(T x) {
const T SQRT_2_OVER_PI = static_cast<T>(0.79788456080286535587989211986876f);
return static_cast<T>(0.5f) * x *
(static_cast<T>(1.0f) +
sycl::tanh(SQRT_2_OVER_PI * x * (static_cast<T>(1.0f) + GELU_COEF_A * x * x)));
op_tanh(SQRT_2_OVER_PI * x * (static_cast<T>(1.0f) + GELU_COEF_A * x * x)));
}
template<typename T>
static __dpct_inline__ T op_exp(T x) {
if constexpr (std::is_same_v<T, sycl::ext::oneapi::bfloat16>) {
return sycl::ext::oneapi::experimental::exp(x);
} else {
return sycl::exp(x);
}
}
template<typename T>
static __dpct_inline__ T op_silu(T x) {
return x / (static_cast<T>(1.0f) + sycl::native::exp(-x));
return x / (static_cast<T>(1.0f) + op_exp(-x));
}
template<typename T>
static __dpct_inline__ T op_gelu_quick(T x) {
const T GELU_QUICK_COEF_LOCAL = static_cast<T>(-1.702f);
return x * (static_cast<T>(1.0f) / (static_cast<T>(1.0f) + sycl::native::exp(GELU_QUICK_COEF_LOCAL * x)));
static __dpct_inline__ T op_erf(T x) {
if constexpr (std::is_same_v<T, sycl::ext::oneapi::bfloat16>) {
return static_cast<sycl::ext::oneapi::bfloat16>(
sycl::erf(static_cast<float>(x))
);
} else {
return sycl::erf(x);
}
}
template<typename T>
static __dpct_inline__ T op_gelu_erf(T x) {
const T SQRT_2_INV = static_cast<T>(0.70710678118654752440084436210484f);
return static_cast<T>(0.5f) * x * (static_cast<T>(1.0f) + sycl::erf(x * SQRT_2_INV));
return static_cast<T>(0.5f) * x * (static_cast<T>(1.0f) + op_erf(x * SQRT_2_INV));
}
template<typename T>
static __dpct_inline__ T op_tanh(T x) {
return sycl::tanh(x);
static __dpct_inline__ T op_gelu_quick(T x) {
const T GELU_QUICK_COEF_LOCAL = static_cast<T>(-1.702f);
return x * (static_cast<T>(1.0f) / (static_cast<T>(1.0f) + op_exp(GELU_QUICK_COEF_LOCAL * x)));
}
template<typename T>
static __dpct_inline__ T op_relu(T x) {
return sycl::fmax(x, static_cast<T>(0));
if constexpr (std::is_same_v<T, sycl::ext::oneapi::bfloat16>) {
return sycl::ext::oneapi::experimental::fmax(x, static_cast<T>(0));
} else {
return sycl::fmax(x, static_cast<T>(0));
}
}
template<typename T>
static __dpct_inline__ T op_sigmoid(T x) {
return static_cast<T>(1.0f) / (static_cast<T>(1.0f) + sycl::native::exp(-x));
return static_cast<T>(1.0f) / (static_cast<T>(1.0f) + op_exp(-x));
}
template<typename T>
static __dpct_inline__ T op_sqrt(T x) {
return sycl::sqrt(x);
if constexpr (std::is_same_v<T, sycl::ext::oneapi::bfloat16>) {
return sycl::ext::oneapi::experimental::sqrt(x);
} else {
return sycl::sqrt(x);
}
}
template<typename T>
static __dpct_inline__ T op_sin(T x) {
return sycl::sin(x);
if constexpr (std::is_same_v<T, sycl::ext::oneapi::bfloat16>) {
return sycl::ext::oneapi::experimental::sin(x);
} else {
return sycl::sin(x);
}
}
template<typename T>
static __dpct_inline__ T op_cos(T x) {
return sycl::cos(x);
if constexpr (std::is_same_v<T, sycl::ext::oneapi::bfloat16>) {
return sycl::ext::oneapi::experimental::cos(x);
} else {
return sycl::cos(x);
}
}
template<typename T>
static __dpct_inline__ T op_hardsigmoid(T x) {
return sycl::fmin(static_cast<T>(1.0f), sycl::fmax(static_cast<T>(0.0f), (x + static_cast<T>(3.0f)) / static_cast<T>(6.0f)));
if constexpr (std::is_same_v<T, sycl::ext::oneapi::bfloat16>) {
return sycl::ext::oneapi::experimental::fmin(
static_cast<T>(1.0f), sycl::ext::oneapi::experimental::fmax(
static_cast<T>(0.0f), (x + static_cast<T>(3.0f)) / static_cast<T>(6.0f)));
} else {
return sycl::fmin(static_cast<T>(1.0f),
sycl::fmax(static_cast<T>(0.0f), (x + static_cast<T>(3.0f)) / static_cast<T>(6.0f)));
}
}
template<typename T>
static __dpct_inline__ T op_hardswish(T x) {
return x * sycl::fmin(static_cast<T>(1.0f), sycl::fmax(static_cast<T>(0.0f), (x + static_cast<T>(3.0f)) / static_cast<T>(6.0f)));
}
template<typename T>
static __dpct_inline__ T op_exp(T x) {
return sycl::exp(x);
}
template<typename T>
static __dpct_inline__ T op_expm1(T x) {
return sycl::expm1(x);
if constexpr (std::is_same_v<T, sycl::ext::oneapi::bfloat16>) {
return x * sycl::ext::oneapi::experimental::fmin(static_cast<T>(1.0f), sycl::ext::oneapi::experimental::fmax(static_cast<T>(0.0f), (x + static_cast<T>(3.0f)) / static_cast<T>(6.0f)));
} else {
return x * sycl::fmin(static_cast<T>(1.0f), sycl::fmax(static_cast<T>(0.0f), (x + static_cast<T>(3.0f)) / static_cast<T>(6.0f)));
}
}
template<typename T>
@@ -134,13 +196,17 @@ static __dpct_inline__ T op_log(T x) {
if (x <= static_cast<T>(0)) {
return neg_infinity<T>();
}
return sycl::log(x);
if constexpr (std::is_same_v<T, sycl::ext::oneapi::bfloat16>) {
return sycl::ext::oneapi::experimental::log(x);
} else {
return sycl::log(x);
}
}
template<typename T>
static __dpct_inline__ T op_softplus(T x) {
const float xf = (float) x;
const float ax = sycl::fabs(xf);
const float ax = op_abs(xf);
const float m = sycl::fmax(xf, 0.0f);
const float y = m + sycl::log1p(sycl::exp(-ax));
return (T) y;
@@ -159,8 +225,14 @@ static __dpct_inline__ T op_step(T x) {
template<typename T>
static __dpct_inline__ T op_leaky_relu(T x, float negative_slope) {
T neg_slope_T = static_cast<T>(negative_slope);
return sycl::fmax(x, static_cast<T>(0)) +
if constexpr (std::is_same_v<T, sycl::ext::oneapi::bfloat16>) {
return sycl::ext::oneapi::experimental::fmax(x, static_cast<T>(0)) +
sycl::ext::oneapi::experimental::fmin(x, static_cast<T>(0.0f)) * neg_slope_T;
} else {
return sycl::fmax(x, static_cast<T>(0)) +
sycl::fmin(x, static_cast<T>(0.0f)) * neg_slope_T;
}
}
template<typename T>
@@ -175,22 +247,40 @@ static __dpct_inline__ T op_clamp(T x, float min_val, float max_val) {
template<typename T>
static __dpct_inline__ T op_floor(T x) {
return sycl::floor(x);
if constexpr (std::is_same_v<T, sycl::ext::oneapi::bfloat16>) {
return sycl::ext::oneapi::experimental::floor(x);
} else {
return sycl::floor(x);
}
}
template<typename T>
static __dpct_inline__ T op_ceil(T x) {
return sycl::ceil(x);
if constexpr (std::is_same_v<T, sycl::ext::oneapi::bfloat16>) {
return sycl::ext::oneapi::experimental::ceil(x);
} else {
return sycl::ceil(x);
}
}
template<typename T>
static __dpct_inline__ T op_round(T x) {
return sycl::round(x);
if constexpr (std::is_same_v<T, sycl::ext::oneapi::bfloat16>) {
return static_cast<sycl::ext::oneapi::bfloat16>(
sycl::round(static_cast<float>(x))
);
} else {
return sycl::round(x);
}
}
template<typename T>
static __dpct_inline__ T op_trunc(T x) {
return sycl::trunc(x);
if constexpr (std::is_same_v<T, sycl::ext::oneapi::bfloat16>) {
return sycl::ext::oneapi::experimental::trunc(x);
} else {
return sycl::trunc(x);
}
}
template<typename T, typename F>
@@ -339,7 +429,7 @@ static void acc_f32_sycl(const float *x, const float *y, float *dst,
const int num_blocks = (n_elements + SYCL_ACC_BLOCK_SIZE - 1) / SYCL_ACC_BLOCK_SIZE;
stream->parallel_for(sycl::nd_range<3>(sycl::range<3>(1, 1, num_blocks) * sycl::range<3>(1, 1, SYCL_ACC_BLOCK_SIZE),
sycl::range<3>(1, 1, SYCL_ACC_BLOCK_SIZE)),
[=](sycl::nd_item<3> /*item_ct1*/) {
[=](sycl::nd_item<3> /*item_ct1*/) [[sycl::reqd_sub_group_size(WARP_SIZE)]] {
acc_f32(x, y, dst, n_elements, ne10, ne11, ne12, ne13, s1, s2, s3, offset);
});
}
@@ -354,8 +444,8 @@ static void arange_kernel(T * dst, const int k, T start, T step,
template<typename KernelInvoker, typename... Args>
static inline void dispatch_ggml_sycl_op_unary(ggml_backend_sycl_context & ctx, ggml_tensor * dst, KernelInvoker kernel_invoker, Args&&... args) {
GGML_ASSERT(dst->src[0]->type == GGML_TYPE_F32 || dst->src[0]->type == GGML_TYPE_F16);
GGML_ASSERT(dst->type == GGML_TYPE_F32 || dst->type == GGML_TYPE_F16);
GGML_ASSERT(dst->src[0]->type == GGML_TYPE_F32 || dst->src[0]->type == GGML_TYPE_F16 || dst->src[0]->type == GGML_TYPE_BF16);
GGML_ASSERT(dst->type == GGML_TYPE_F32 || dst->type == GGML_TYPE_F16 || dst->type == GGML_TYPE_BF16);
GGML_ASSERT(dst->src[0]->type == dst->type);
dpct::queue_ptr main_stream = ctx.stream();
@@ -367,6 +457,14 @@ static inline void dispatch_ggml_sycl_op_unary(ggml_backend_sycl_context & ctx,
kernel_invoker(data_pts.src, data_pts.dst, (int)ggml_nelements(dst->src[0]), main_stream, std::forward<Args>(args)...);
break;
}
#ifdef GGML_SYCL_HAS_BF16
case GGML_TYPE_BF16:
{
auto data_pts = cast_data<sycl::ext::oneapi::bfloat16>(dst);
kernel_invoker(data_pts.src, data_pts.dst, (int)ggml_nelements(dst->src[0]), main_stream, std::forward<Args>(args)...);
break;
}
#endif
case GGML_TYPE_F32:
{
auto data_pts = cast_data<float>(dst);
@@ -480,7 +578,7 @@ static inline void ggml_sycl_op_unary(
stream->parallel_for(
sycl::nd_range<1>(sycl::range<1>(num_blocks) * sycl::range<1>(256),
sycl::range<1>(256)),
[=](sycl::nd_item<1> item_ct1) {
[=](sycl::nd_item<1> item_ct1) [[sycl::reqd_sub_group_size(WARP_SIZE)]] {
unary_op_generic_kernel(
src, dst_ptr, k_elements,
ne0, ne1, ne2, ne3,
@@ -508,7 +606,7 @@ static inline void ggml_sycl_op_arange(ggml_backend_sycl_context & ctx, ggml_ten
stream->parallel_for(
sycl::nd_range<1>(sycl::range<1>(num_blocks) * sycl::range<1>(SYCL_ARANGE_BLOCK_SIZE),
sycl::range<1>(SYCL_ARANGE_BLOCK_SIZE)),
[=](sycl::nd_item<1> item_ct1) {
[=](sycl::nd_item<1> item_ct1) [[sycl::reqd_sub_group_size(WARP_SIZE)]] {
arange_kernel(dst_ptr, k, start, step, item_ct1);
});
}
@@ -602,7 +700,7 @@ static inline void ggml_sycl_op_log(ggml_backend_sycl_context & ctx, ggml_tensor
stream->parallel_for(
sycl::nd_range<1>(sycl::range<1>(num_blocks) * sycl::range<1>(SYCL_EXP_BLOCK_SIZE),
sycl::range<1>(SYCL_EXP_BLOCK_SIZE)),
[=](sycl::nd_item<1> item_ct1) {
[=](sycl::nd_item<1> item_ct1) [[sycl::reqd_sub_group_size(WARP_SIZE)]] {
unary_op_log_kernel(src, dst_ptr, k_elements, item_ct1);
});
});
@@ -640,7 +738,7 @@ static inline void ggml_sycl_op_sqrt(ggml_backend_sycl_context & ctx, ggml_tenso
stream->parallel_for(
sycl::nd_range<1>(sycl::range<1>(num_blocks) * sycl::range<1>(SYCL_SQRT_BLOCK_SIZE),
sycl::range<1>(SYCL_SQRT_BLOCK_SIZE)),
[=](sycl::nd_item<1> item_ct1) {
[=](sycl::nd_item<1> item_ct1) [[sycl::reqd_sub_group_size(WARP_SIZE)]] {
unary_op_sqrt_kernel(src, dst_ptr, k_elements, item_ct1);
});
});
@@ -653,7 +751,7 @@ static inline void ggml_sycl_op_sin(ggml_backend_sycl_context & ctx, ggml_tensor
stream->parallel_for(
sycl::nd_range<1>(sycl::range<1>(num_blocks) * sycl::range<1>(SYCL_SIN_BLOCK_SIZE),
sycl::range<1>(SYCL_SIN_BLOCK_SIZE)),
[=](sycl::nd_item<1> item_ct1) {
[=](sycl::nd_item<1> item_ct1) [[sycl::reqd_sub_group_size(WARP_SIZE)]] {
unary_op_sin_kernel(src, dst_ptr, k_elements, item_ct1);
});
});
@@ -666,7 +764,7 @@ static inline void ggml_sycl_op_cos(ggml_backend_sycl_context & ctx, ggml_tensor
stream->parallel_for(
sycl::nd_range<1>(sycl::range<1>(num_blocks) * sycl::range<1>(SYCL_SIN_BLOCK_SIZE),
sycl::range<1>(SYCL_SIN_BLOCK_SIZE)),
[=](sycl::nd_item<1> item_ct1) {
[=](sycl::nd_item<1> item_ct1) [[sycl::reqd_sub_group_size(WARP_SIZE)]] {
unary_op_cos_kernel(src, dst_ptr, k_elements, item_ct1);
});
});
@@ -681,7 +779,7 @@ static inline void ggml_sycl_op_leaky_relu(ggml_backend_sycl_context & ctx, ggml
stream->parallel_for(
sycl::nd_range<1>(sycl::range<1>(num_blocks) * sycl::range<1>(SYCL_RELU_BLOCK_SIZE),
sycl::range<1>(SYCL_RELU_BLOCK_SIZE)),
[=](sycl::nd_item<1> item_ct1) {
[=](sycl::nd_item<1> item_ct1) [[sycl::reqd_sub_group_size(WARP_SIZE)]] {
unary_op_leaky_relu_kernel(src, dst_ptr, k_elements, slope, item_ct1);
});
}, negative_slope);
@@ -694,7 +792,7 @@ static inline void ggml_sycl_op_sqr(ggml_backend_sycl_context & ctx, ggml_tensor
stream->parallel_for(
sycl::nd_range<1>(sycl::range<1>(num_blocks) * sycl::range<1>(SYCL_SQR_BLOCK_SIZE),
sycl::range<1>(SYCL_SQR_BLOCK_SIZE)),
[=](sycl::nd_item<1> item_ct1) {
[=](sycl::nd_item<1> item_ct1) [[sycl::reqd_sub_group_size(WARP_SIZE)]] {
unary_op_sqr_kernel(src, dst_ptr, k_elements, item_ct1);
});
});
@@ -711,7 +809,7 @@ static inline void ggml_sycl_op_clamp(ggml_backend_sycl_context & ctx, ggml_tens
stream->parallel_for(
sycl::nd_range<1>(sycl::range<1>(num_blocks) * sycl::range<1>(SYCL_CLAMP_BLOCK_SIZE),
sycl::range<1>(SYCL_CLAMP_BLOCK_SIZE)),
[=](sycl::nd_item<1> item_ct1) {
[=](sycl::nd_item<1> item_ct1) [[sycl::reqd_sub_group_size(WARP_SIZE)]] {
clamp(src, dst_ptr, min_arg, max_arg, k_elements, item_ct1);
});
}, min_val, max_val);
@@ -774,7 +872,8 @@ static inline void ggml_sycl_op_geglu(ggml_backend_sycl_context & ctx, ggml_tens
[](const auto* x_ptr, const auto* g_ptr, auto* dst_ptr, uint64_t k, uint64_t n, uint64_t o0, uint64_t o1, queue_ptr main_stream) {
const uint32_t num_blocks = ceil_div(k, SYCL_GELU_BLOCK_SIZE);
main_stream->parallel_for(
sycl::nd_range<1>((num_blocks * sycl::range<1>(SYCL_GELU_BLOCK_SIZE)), sycl::range<1>(SYCL_GELU_BLOCK_SIZE)), [=](sycl::nd_item<1> item_ct1) {
sycl::nd_range<1>((num_blocks * sycl::range<1>(SYCL_GELU_BLOCK_SIZE)),
sycl::range<1>(SYCL_GELU_BLOCK_SIZE)), [=](sycl::nd_item<1> item_ct1) [[sycl::reqd_sub_group_size(WARP_SIZE)]] {
gated_op_fused_geglu(x_ptr, g_ptr, dst_ptr, k, n, o0, o1, item_ct1);
});
});
@@ -785,7 +884,8 @@ static inline void ggml_sycl_op_reglu(ggml_backend_sycl_context & ctx, ggml_tens
[](const auto* x_ptr, const auto* g_ptr, auto* dst_ptr, uint64_t k, uint64_t n, uint64_t o0, uint64_t o1, queue_ptr main_stream) {
const uint32_t num_blocks = ceil_div((uint32_t)k, SYCL_RELU_BLOCK_SIZE); // Using RELU block size for reglu
main_stream->parallel_for(
sycl::nd_range<1>((num_blocks * sycl::range<1>(SYCL_RELU_BLOCK_SIZE)), sycl::range<1>(SYCL_RELU_BLOCK_SIZE)), [=](sycl::nd_item<1> item_ct1) {
sycl::nd_range<1>((num_blocks * sycl::range<1>(SYCL_RELU_BLOCK_SIZE)),
sycl::range<1>(SYCL_RELU_BLOCK_SIZE)), [=](sycl::nd_item<1> item_ct1) [[sycl::reqd_sub_group_size(WARP_SIZE)]] {
gated_op_fused_reglu(x_ptr, g_ptr, dst_ptr, k, n, o0, o1, item_ct1);
});
});
@@ -796,7 +896,8 @@ static inline void ggml_sycl_op_swiglu(ggml_backend_sycl_context & ctx, ggml_ten
[](const auto* x_ptr, const auto* g_ptr, auto* dst_ptr, uint64_t k, uint64_t n, uint64_t o0, uint64_t o1, queue_ptr main_stream) {
const uint32_t num_blocks = ceil_div((uint32_t)k, SYCL_SILU_BLOCK_SIZE); // Using SILU block size for swiglu
main_stream->parallel_for(
sycl::nd_range<1>((num_blocks * sycl::range<1>(SYCL_SILU_BLOCK_SIZE)), sycl::range<1>(SYCL_SILU_BLOCK_SIZE)), [=](sycl::nd_item<1> item_ct1) {
sycl::nd_range<1>((num_blocks * sycl::range<1>(SYCL_SILU_BLOCK_SIZE)),
sycl::range<1>(SYCL_SILU_BLOCK_SIZE)), [=](sycl::nd_item<1> item_ct1) [[sycl::reqd_sub_group_size(WARP_SIZE)]] {
gated_op_fused_swiglu(x_ptr, g_ptr, dst_ptr, k, n, o0, o1, item_ct1);
});
});
@@ -811,7 +912,6 @@ __dpct_inline__ float ggml_sycl_op_swiglu_oai_single(float x, float g, float alp
return out_glu;
}
template <typename T>
static void swiglu_oai_kernel(const T * x, const T * g, T * dst, const int64_t k,
const int64_t n, const int64_t o0, const int64_t o1,
@@ -845,7 +945,7 @@ static void swiglu_oai_sycl(const T * x,
const int64_t num_blocks = (k + SYCL_GLU_BLOCK_SIZE - 1) / SYCL_GLU_BLOCK_SIZE;
stream->parallel_for(sycl::nd_range<3>(sycl::range<3>(1, 1, num_blocks) * sycl::range<3>(1, 1, SYCL_GLU_BLOCK_SIZE),
sycl::range<3>(1, 1, SYCL_GLU_BLOCK_SIZE)),
[=](sycl::nd_item<3> item_ct1) {
[=](sycl::nd_item<3> item_ct1) [[sycl::reqd_sub_group_size(WARP_SIZE)]] {
swiglu_oai_kernel(x, g, dst, k, n, o0, o1, alpha, limit, item_ct1);
});
}
@@ -899,7 +999,8 @@ static inline void ggml_sycl_op_geglu_erf(ggml_backend_sycl_context & ctx, ggml_
[](const auto* x_ptr, const auto* g_ptr, auto* dst_ptr, uint64_t k, uint64_t n, uint64_t o0, uint64_t o1, queue_ptr main_stream) {
const uint32_t num_blocks = ceil_div(k, SYCL_GELU_BLOCK_SIZE);
main_stream->parallel_for(
sycl::nd_range<1>((num_blocks * sycl::range<1>(SYCL_GELU_BLOCK_SIZE)), sycl::range<1>(SYCL_GELU_BLOCK_SIZE)), [=](sycl::nd_item<1> item_ct1) {
sycl::nd_range<1>((num_blocks * sycl::range<1>(SYCL_GELU_BLOCK_SIZE)),
sycl::range<1>(SYCL_GELU_BLOCK_SIZE)), [=](sycl::nd_item<1> item_ct1) [[sycl::reqd_sub_group_size(WARP_SIZE)]] {
gated_op_fused_geglu_erf(x_ptr, g_ptr, dst_ptr, k, n, o0, o1, item_ct1);
});
});
@@ -910,7 +1011,8 @@ static inline void ggml_sycl_op_geglu_quick(ggml_backend_sycl_context & ctx, ggm
[](const auto* x_ptr, const auto* g_ptr, auto* dst_ptr, uint64_t k, uint64_t n, uint64_t o0, uint64_t o1, queue_ptr main_stream) {
const uint32_t num_blocks = ceil_div(k, SYCL_GELU_BLOCK_SIZE);
main_stream->parallel_for(
sycl::nd_range<1>((num_blocks * sycl::range<1>(SYCL_GELU_BLOCK_SIZE)), sycl::range<1>(SYCL_GELU_BLOCK_SIZE)), [=](sycl::nd_item<1> item_ct1) {
sycl::nd_range<1>((num_blocks * sycl::range<1>(SYCL_GELU_BLOCK_SIZE)),
sycl::range<1>(SYCL_GELU_BLOCK_SIZE)), [=](sycl::nd_item<1> item_ct1) [[sycl::reqd_sub_group_size(WARP_SIZE)]] {
gated_op_fused_geglu_quick(x_ptr, g_ptr, dst_ptr, k, n, o0, o1, item_ct1);
});
});
-2
View File
@@ -2813,8 +2813,6 @@ static void llama_sampler_top_n_sigma_apply(struct llama_sampler * smpl, llama_t
cur_p->data[i].logit = -INFINITY;
}
}
llama_sampler_softmax_impl(cur_p, true);
}
static struct llama_sampler * llama_sampler_top_n_sigma_clone(const struct llama_sampler * smpl) {
+69
View File
@@ -212,6 +212,75 @@ void test_gbnf_generation(testing &t) {
)""", gbnf);
});
t.test("ac grammar", [](testing &t) {
auto parser = build_peg_parser([](common_peg_parser_builder & p) {
return p.ac(p.until("</tag>") + p.literal("</tag>"), "</tag>");
});
auto gbnf = build_grammar([&](const common_grammar_builder & builder) {
parser.build_grammar(builder);
});
assert_gbnf_equal(t, R"""(
ac-3 ::= [<] ac-3-01 | [^<] ac-3
ac-3-01 ::= [<] ac-3-01 | [/] ac-3-02 | [^/<] ac-3
ac-3-02 ::= [<] ac-3-01 | [t] ac-3-03 | [^<t] ac-3
ac-3-03 ::= [<] ac-3-01 | [a] ac-3-04 | [^<a] ac-3
ac-3-04 ::= [<] ac-3-01 | [g] ac-3-05 | [^<g] ac-3
ac-3-05 ::= [>] | [<] ac-3-01 | [^<>] ac-3
root ::= ac-3
space ::= | " " | "\n"{1,2} [ \t]{0,20}
)""", gbnf);
});
t.test("ac grammar terminates at first delimiter", [](testing &t) {
auto parser = build_peg_parser([](common_peg_parser_builder & p) {
return p.ac(p.until("\n</parameter>\n") + p.literal("\n</parameter>\n"), "\n</parameter>\n");
});
auto gbnf = build_grammar([&](const common_grammar_builder & builder) {
parser.build_grammar(builder);
});
assert_gbnf_equal(t, R"""(
ac-3 ::= [\n] ac-3-01 | [^\n] ac-3
ac-3-01 ::= [\n] ac-3-01 | [<] ac-3-02 | [^\n<] ac-3
ac-3-02 ::= [\n] ac-3-01 | [/] ac-3-03 | [^\n/] ac-3
ac-3-03 ::= [\n] ac-3-01 | [p] ac-3-04 | [^\np] ac-3
ac-3-04 ::= [\n] ac-3-01 | [a] ac-3-05 | [^\na] ac-3
ac-3-05 ::= [\n] ac-3-01 | [r] ac-3-06 | [^\nr] ac-3
ac-3-06 ::= [\n] ac-3-01 | [a] ac-3-07 | [^\na] ac-3
ac-3-07 ::= [\n] ac-3-01 | [m] ac-3-08 | [^\nm] ac-3
ac-3-08 ::= [\n] ac-3-01 | [e] ac-3-09 | [^\ne] ac-3
ac-3-09 ::= [\n] ac-3-01 | [t] ac-3-10 | [^\nt] ac-3
ac-3-10 ::= [\n] ac-3-01 | [e] ac-3-11 | [^\ne] ac-3
ac-3-11 ::= [\n] ac-3-01 | [r] ac-3-12 | [^\nr] ac-3
ac-3-12 ::= [\n] ac-3-01 | [>] ac-3-13 | [^\n>] ac-3
ac-3-13 ::= [\n] | [^\n] ac-3
root ::= ac-3
space ::= | " " | "\n"{1,2} [ \t]{0,20}
)""", gbnf);
});
t.test("ac grammar multiple delimiters", [](testing &t) {
auto parser = build_peg_parser([](common_peg_parser_builder & p) {
return p.ac(p.eps(), std::vector<std::string>{"ab", "cd", "ef"});
});
auto gbnf = build_grammar([&](const common_grammar_builder & builder) {
parser.build_grammar(builder);
});
assert_gbnf_equal(t, R"""(
ac-1 ::= [a] ac-1-01 | [c] ac-1-03 | [e] ac-1-05 | [^ace] ac-1
ac-1-01 ::= [b] | [a] ac-1-01 | [c] ac-1-03 | [e] ac-1-05 | [^abce] ac-1
ac-1-03 ::= [d] | [a] ac-1-01 | [c] ac-1-03 | [e] ac-1-05 | [^acde] ac-1
ac-1-05 ::= [f] | [a] ac-1-01 | [c] ac-1-03 | [e] ac-1-05 | [^acef] ac-1
root ::= ac-1
space ::= | " " | "\n"{1,2} [ \t]{0,20}
)""", gbnf);
});
t.test("complex expressions with parentheses", [](testing &t) {
auto parser = build_peg_parser([](common_peg_parser_builder & p) {
return p.one_or_more(p.literal("a") | p.literal("b"));
+2 -2
View File
@@ -360,9 +360,9 @@ int main(void) {
test_dry({0.2f, 0.2f, 0.2f, 0.2f, 0.2f}, {0, 1, 2, 0, 1}, {0.241818f, 0.241818f, 0.032727f, 0.241818f, 0.241818f}, 2.0f, 1.1f, 2, 5, {});
test_dry({0.2f, 0.2f, 0.2f, 0.2f, 0.2f}, {0, 1, 2, 3, 4, 0, 1}, {0.2f, 0.2f, 0.2f, 0.2f, 0.2f}, 1.0f, 1.1f, 4, 7, {});
test_top_n_sigma({0.1f, 0.2f, 0.3f, 0.4f}, {0.571429f, 0.428571f, 0.0f, 0.0f}, 1.00f);
test_top_n_sigma({0.1f, 0.2f, 0.3f, 0.4f}, {0.0f, 0.0f, 0.428571f, 0.571429f}, 1.00f);
test_top_n_sigma({0.1f, 0.2f, 0.3f, 0.4f}, {0.1f, 0.2f, 0.3f, 0.4f}, 0.00f); // top_n_sigma == 0 now represents a no-op rather than greedy decoding as of PR#13345
test_top_n_sigma({0.1f, 0.2f, 0.3f, 0.4f}, {0.4f, 0.3f, 0.2f, 0.1f}, 3.00f);
test_top_n_sigma({0.1f, 0.2f, 0.3f, 0.4f}, {0.1f, 0.2f, 0.3f, 0.4f}, 3.00f);
test_sampler_queue(10000, "k", 10000, 1.0f, 1.0f);
test_sampler_queue(10000, "k", 1, 1.0f, 1.0f);
+18 -7
View File
@@ -1230,8 +1230,6 @@ print(completion.choices[0].text)
Given a ChatML-formatted json description in `messages`, it returns the predicted completion. Both synchronous and streaming mode are supported, so scripted and interactive applications work fine. While no strong claims of compatibility with OpenAI API spec is being made, in our experience it suffices to support many apps. Only models with a [supported chat template](https://github.com/ggml-org/llama.cpp/wiki/Templates-supported-by-llama_chat_apply_template) can be used optimally with this endpoint. By default, the ChatML template will be used.
If model supports multimodal, you can input the media file via `image_url` content part. We support both base64 and remote URL as input. See OAI documentation for more.
*Options:*
See [OpenAI Chat Completions API documentation](https://platform.openai.com/docs/api-reference/chat). llama.cpp `/completion`-specific features such as `mirostat` are also supported.
@@ -1250,9 +1248,18 @@ The `response_format` parameter supports both plain JSON output (e.g. `{"type":
`parallel_tool_calls` : Whether to enable parallel/multiple tool calls (only supported on some models, verification is based on jinja template).
For multimodal input:
- Content type `image_url` and `input_audio` are the same as OAI schema
- Content type `input_video` is an extension from OAI schema. For now, it only accepts base64 input
For multimodal input (typed content, `messages[i].content[j]`):
- If `type == "image_url"`:
- `image_url.url` can be a remote URL, base64 (raw or URI-encoded via `data:image/...;base64`) or path to local file
- Accepts formats supported by `stb_image` (jpeg, png, tga, bmp, gif, ...)
- If `type == "input_audio"`:
- Either `input_audio.data` or `input_audio.url` can be specified, can be a remote URL, raw base64 or path to local file
- Accepts formats supported by `miniaudio` (mp3, wav, flac)
- `input_audio.format` will be ignored, the file format will be determined automatically
- If `type == "input_video"`:
- Either `input_video.data` or `input_video.url` can be specified, can be a remote URL, raw base64 or path to local file
- Accepts formats supported by `ffmpeg`
- Note: for local file, make sure to set `--media-path`. File path must be prefixed by `file://`
*Examples:*
@@ -1863,11 +1870,15 @@ Example events:
"data": {
"status": "loading",
"progress": {
"stage": "fit_params",
"value": 0.5 // from 0.0 to 1.0 ; note: not all stages have this "value"
"stages": ["text_model", "spec_model", "mmproj_model"],
"current": "text_model",
"value": 0.5
}
}
}
// note for "loading" status:
// - subsequent events will follow the same order of "stages" list
// - mmap is may report incorrect progress on some platforms; if you need exact progress, use --no-mmap
{
"model": "...",
+37 -25
View File
@@ -817,12 +817,21 @@ json oaicompat_completion_params_parse(const json & body) {
return llama_params;
}
// media_path always end with '/', see arg.cpp
// url can be
// - http(s):// for remote files
// - file:// for local files (only allowed if media_path is set)
// - data: for base64 encoded data with uri scheme (e.g. data:image/png;base64,...)
// - raw base64 encoded data
static void handle_media(
std::vector<raw_buffer> & out_files,
json & media_obj,
const std::string & media_path) {
std::string url = json_value(media_obj, "url", std::string());
const std::string & url,
const std::string & media_path,
bool accept_base64_uri) {
if (!media_path.empty()) {
// should already be enforced by arg.cpp, but checking just in case
GGML_ASSERT(media_path.back() == DIRECTORY_SEPARATOR);
}
if (string_starts_with(url, "http")) {
// download remote image
// TODO @ngxson : maybe make these params configurable
@@ -858,20 +867,28 @@ static void handle_media(
data.assign((std::istreambuf_iterator<char>(file)), std::istreambuf_iterator<char>());
out_files.push_back(data);
} else {
} else if (accept_base64_uri && string_starts_with(url, "data:")) {
// try to decode base64 image
std::vector<std::string> parts = string_split<std::string>(url, /*separator*/ ',');
if (parts.size() != 2) {
throw std::runtime_error("Invalid url value");
throw std::runtime_error("Invalid uri-encoded base64 value");
} else if (!string_starts_with(parts[0], "data:image/")) {
throw std::runtime_error("Invalid url format: " + parts[0]);
throw std::runtime_error("Invalid uri format: " + parts[0]);
} else if (!string_ends_with(parts[0], "base64")) {
throw std::runtime_error("url must be base64 encoded");
throw std::runtime_error("uri must be base64 encoded");
} else {
auto base64_data = parts[1];
auto decoded_data = base64_decode(base64_data);
out_files.push_back(decoded_data);
}
} else {
// try as raw base64 string
auto decoded_data = base64_decode(url);
if (decoded_data.empty()) {
throw std::runtime_error("Invalid base64 value");
}
out_files.push_back(decoded_data);
}
}
@@ -957,14 +974,15 @@ json oaicompat_chat_params_parse(
}
for (auto & p : content) {
std::string type = json_value(p, "type", std::string());
std::string type = json_value(p, "type", std::string());
if (type == "image_url") {
if (!opt.allow_image) {
throw std::runtime_error("image input is not supported - hint: if this is unexpected, you may need to provide the mmproj");
}
json image_url = json_value(p, "image_url", json::object());
handle_media(out_files, image_url, opt.media_path);
std::string url = json_value(image_url, "url", std::string());
handle_media(out_files, url, opt.media_path, true);
p["type"] = "media_marker";
p["text"] = get_media_marker();
@@ -975,17 +993,11 @@ json oaicompat_chat_params_parse(
throw std::runtime_error("audio input is not supported - hint: if this is unexpected, you may need to provide the mmproj");
}
json input_audio = json_value(p, "input_audio", json::object());
std::string data = json_value(input_audio, "data", std::string());
std::string format = json_value(input_audio, "format", std::string());
// while we also support flac, we don't allow it here so we matches the OAI spec
if (format != "wav" && format != "mp3") {
throw std::invalid_argument("input_audio.format must be either 'wav' or 'mp3'");
}
auto decoded_data = base64_decode(data); // expected to be base64 encoded
out_files.push_back(decoded_data);
// TODO: add audio_url support by reusing handle_media()
// note: don't need to validate "format", it's redundant
json input_audio = json_value(p, "input_audio", json::object());
std::string url = json_value(input_audio, "data",
json_value(input_audio, "url", std::string()));
handle_media(out_files, url, opt.media_path, false);
p["type"] = "media_marker";
p["text"] = get_media_marker();
@@ -996,10 +1008,10 @@ json oaicompat_chat_params_parse(
throw std::runtime_error("video input is not supported - hint: if this is unexpected, you may need to provide the mmproj");
}
json input_video = json_value(p, "input_video", json::object());
std::string data = json_value(input_video, "data", std::string());
auto decoded_data = base64_decode(data); // expected to be base64 encoded
out_files.push_back(decoded_data);
json input_video = json_value(p, "input_video", json::object());
std::string url = json_value(input_video, "data",
json_value(input_video, "url", std::string()));
handle_media(out_files, url, opt.media_path, false);
p["type"] = "media_marker";
p["text"] = get_media_marker();
+44 -27
View File
@@ -962,6 +962,7 @@ private:
struct load_progress_data {
server_context_impl * ctx;
std::string stage;
std::vector<std::string> stages;
int64_t t_last_load_progress_ms = 0;
load_progress_data(server_context_impl * ctx, const std::string & stage) : ctx(ctx), stage(stage) {}
};
@@ -982,7 +983,8 @@ private:
}
if (d->ctx->callback_state) {
d->ctx->callback_state(SERVER_STATE_LOADING, {
{"stage", d->stage},
{"stages", d->stages},
{"current", d->stage},
{"value", progress},
});
}
@@ -992,18 +994,42 @@ private:
// load the model and initialize llama_context
// this may also be called to resume from sleeping state
bool load_model(common_params & params) {
load_progress_data load_progress_text(this, "text_model");
load_progress_data load_progress_text (this, "text_model");
load_progress_data load_progress_mmproj(this, "mmproj_model");
load_progress_data load_progress_spec (this, "spec_model");
bool is_resume = sleeping;
SRV_INF("loading model '%s'\n", params.model.path.c_str());
const bool is_resume = sleeping;
params_base = params;
params_base.n_outputs_max = server_n_outputs_max(params_base);
const bool has_mmproj = !params.mmproj.path.empty();
const bool has_draft = params.speculative.has_dft();
const bool spec_mtp = std::find(params_base.speculative.types.begin(),
params_base.speculative.types.end(),
COMMON_SPECULATIVE_TYPE_DRAFT_MTP) != params_base.speculative.types.end();
const bool has_spec = has_draft || spec_mtp;
if (callback_state) {
std::vector<std::string> stages = {"text_model"};
if (has_spec) {
stages.push_back("spec_model");
}
if (has_mmproj) {
stages.push_back("mmproj_model");
}
load_progress_text.stages = stages;
load_progress_mmproj.stages = stages;
load_progress_spec.stages = stages;
// trigger 0% progress
load_progress_callback(0.0f, &load_progress_text);
}
SRV_INF("loading model '%s'\n", params.model.path.c_str());
std::string & mmproj_path = params_base.mmproj.path;
bool has_mmproj = !mmproj_path.empty();
mtmd_context_params mparams = mtmd_context_params_default();
if (has_mmproj) {
mparams.use_gpu = params_base.mmproj_use_gpu;
@@ -1050,16 +1076,7 @@ private:
// optionally reserve VRAM for the draft / MTP context before fitting the target model
if (params_base.fit_params) {
if (callback_state) {
callback_state(SERVER_STATE_LOADING, {{"stage", "fit_params"}});
}
const bool spec_mtp = std::find(params_base.speculative.types.begin(),
params_base.speculative.types.end(),
COMMON_SPECULATIVE_TYPE_DRAFT_MTP) != params_base.speculative.types.end();
const bool has_draft = params_base.speculative.has_dft();
if (has_draft || spec_mtp) {
if (has_spec) {
common_params params_dft = params_base;
bool measure_model_bytes = true;
@@ -1151,11 +1168,7 @@ private:
add_bos_token = llama_vocab_get_add_bos(vocab);
if (params_base.speculative.has_dft()) {
if (callback_state) {
callback_state(SERVER_STATE_LOADING, {{"stage", "spec_model"}});
}
if (has_draft) {
// TODO speculative: move to common/speculative.cpp?
const auto & params_spec = params_base.speculative.draft;
@@ -1178,6 +1191,10 @@ private:
auto mparams_dft = common_model_params_to_llama(params_dft);
// progress callback
mparams_dft.progress_callback = load_progress_callback;
mparams_dft.progress_callback_user_data = &load_progress_spec;
model_dft.reset(llama_model_load_from_file(params_dft.model.path.c_str(), mparams_dft));
if (model_dft == nullptr) {
SRV_ERR("failed to load draft model, '%s'\n", params_dft.model.path.c_str());
@@ -1186,10 +1203,6 @@ private:
auto cparams = common_context_params_to_llama(params_dft);
const bool spec_mtp = std::find(params_base.speculative.types.begin(),
params_base.speculative.types.end(),
COMMON_SPECULATIVE_TYPE_DRAFT_MTP) != params_base.speculative.types.end();
if (spec_mtp) {
cparams.ctx_type = LLAMA_CONTEXT_TYPE_MTP;
}
@@ -1203,8 +1216,10 @@ private:
params_base.speculative.draft.ctx_tgt = ctx_tgt;
params_base.speculative.draft.ctx_dft = ctx_dft.get();
} else if (std::find(params_base.speculative.types.begin(), params_base.speculative.types.end(),
COMMON_SPECULATIVE_TYPE_DRAFT_MTP) != params_base.speculative.types.end()) {
} else if (spec_mtp) {
// no new model load, so we simply report 0.0 and 1.0 progress
load_progress_callback(0.0f, &load_progress_spec);
SRV_INF("creating MTP draft context against the target model '%s'\n",
params_base.model.path.c_str());
@@ -1224,6 +1239,8 @@ private:
params_base.speculative.draft.ctx_tgt = ctx_tgt;
params_base.speculative.draft.ctx_dft = ctx_dft.get();
load_progress_callback(1.0f, &load_progress_spec);
}
if (has_mmproj) {
+9 -5
View File
@@ -569,9 +569,13 @@ struct server_tool_edit_file : server_tool {
}
int n = (int) lines.size();
if (e.line_start == -1) {
// -1 means end of file; line_end is ignored — normalize to point past last line
e.line_start = n + 1;
e.line_end = n + 1;
// -1 targets end of file -> valid for append only; line_end is ignored
if (e.mode != "append") {
return {{"error", "line_start -1 (end of file) is only valid for append mode"}};
}
// append at end of file: insert position is the current line count
e.line_start = n;
e.line_end = n;
} else {
if (e.line_start < 1 || e.line_end < e.line_start) {
return {{"error", string_format("invalid line range [%d, %d]", e.line_start, e.line_end)}};
@@ -612,8 +616,8 @@ struct server_tool_edit_file : server_tool {
} else if (e.mode == "delete") {
lines.erase(lines.begin() + idx_start, lines.begin() + idx_end + 1);
} else { // append
// idx_end + 1 may equal lines.size() when line_start == -1 (end of file)
lines.insert(lines.begin() + idx_end + 1, new_lines.begin(), new_lines.end());
// insert after idx_end; idx_end + 1 == lines.size() for end-of-file append
lines.insert(lines.begin() + (idx_end + 1), new_lines.begin(), new_lines.end());
}
}
+10
View File
@@ -19,6 +19,10 @@ import type {
ApiErrorResponse,
ApiLlamaCppServerProps,
ApiModelDataEntry,
ApiModelLoadStage,
ApiModelsSseProgress,
ApiModelsSseData,
ApiModelsSseEvent,
ApiModelListResponse,
ApiProcessingState,
ApiRouterModelMeta,
@@ -52,6 +56,7 @@ import type {
// Model types
ModelModalities,
ModelOption,
ModelLoadProgress,
// Settings types
SettingsChatServiceOptions,
SettingsConfigValue,
@@ -83,6 +88,10 @@ declare global {
ApiErrorResponse,
ApiLlamaCppServerProps,
ApiModelDataEntry,
ApiModelLoadStage,
ApiModelsSseProgress,
ApiModelsSseData,
ApiModelsSseEvent,
ApiModelListResponse,
ApiProcessingState,
ApiRouterModelMeta,
@@ -120,6 +129,7 @@ declare global {
// Model types
ModelModalities,
ModelOption,
ModelLoadProgress,
// Settings types
SettingsChatServiceOptions,
SettingsConfigValue,
@@ -10,7 +10,7 @@
import { getMessageEditContext } from '$lib/contexts';
import { useProcessingState } from '$lib/hooks/use-processing-state.svelte';
import { isLoading, isChatStreaming } from '$lib/stores/chat.svelte';
import { copyToClipboard, deriveAgenticSections } from '$lib/utils';
import { copyToClipboard, deriveAgenticSections, modelLoadProgressText } from '$lib/utils';
import { AgenticSectionType } from '$lib/enums';
import { REASONING_TAGS } from '$lib/constants/agentic';
import { tick } from 'svelte';
@@ -185,6 +185,13 @@
let hasNoContent = $derived(!message?.content?.trim());
let isActivelyProcessing = $derived(isCurrentlyLoading || isStreaming);
// during a router auto-load the message has no model yet, so target the selected one
let loadTargetModel = $derived(message.model ?? modelsStore.selectedModelName);
let modelLoadProgress = $derived(
isRouter && loadTargetModel ? modelsStore.getLoadProgress(loadTargetModel) : null
);
let modelLoadingText = $derived(modelLoadProgressText(modelLoadProgress));
let showProcessingInfoTop = $derived(
message?.role === MessageRole.ASSISTANT &&
isActivelyProcessing &&
@@ -220,7 +227,8 @@
<div class="mt-6 w-full max-w-[48rem]" in:fade>
<div class="processing-container">
<span class="processing-text">
{processingState.getPromptProgressText() ??
{modelLoadingText ??
processingState.getPromptProgressText() ??
processingState.getProcessingMessage() ??
'Processing...'}
</span>
@@ -252,7 +260,8 @@
<div class="mt-4 w-full max-w-[48rem]" in:fade>
<div class="processing-container">
<span class="processing-text">
{processingState.getPromptProgressText() ??
{modelLoadingText ??
processingState.getPromptProgressText() ??
processingState.getProcessingMessage() ??
'Processing...'}
</span>
@@ -13,6 +13,7 @@
import type { ModelOption } from '$lib/types/models';
import { ServerModelStatus } from '$lib/enums';
import { modelsStore, routerModels } from '$lib/stores/models.svelte';
import { modelLoadFraction, modelLoadProgressText } from '$lib/utils';
interface Props {
option: ModelOption;
@@ -50,11 +51,15 @@
(serverStatus === ServerModelStatus.LOADED || isSleeping) && !isOperationInProgress
);
let isLoading = $derived(serverStatus === ServerModelStatus.LOADING || isOperationInProgress);
let loadProgress = $derived(isLoading ? modelsStore.getLoadProgress(option.model) : null);
let loadPercent = $derived(Math.round(modelLoadFraction(loadProgress) * 100));
let loadTitle = $derived(modelLoadProgressText(loadProgress));
</script>
<div
class={[
'group flex w-full items-center gap-2 rounded-sm p-2 text-left text-sm transition focus:outline-none',
'group relative flex w-full items-center gap-2 rounded-sm p-2 text-left text-sm transition focus:outline-none',
'cursor-pointer hover:bg-muted focus:bg-muted',
(isSelected || isHighlighted) && 'bg-accent text-accent-foreground',
!(isSelected || isHighlighted) && 'hover:bg-accent hover:text-accent-foreground',
@@ -62,6 +67,7 @@
]}
role="option"
aria-selected={isSelected || isHighlighted}
title={loadTitle}
tabindex="0"
onclick={() => onSelect(option.id)}
onmouseenter={onMouseEnter}
@@ -188,4 +194,15 @@
</div>
{/if}
</div>
{#if isLoading}
<div
class="pointer-events-none absolute inset-x-0 bottom-0 h-0.5 overflow-hidden rounded-b-sm bg-muted"
>
<div
class="h-full bg-primary transition-[width] duration-200 ease-out"
style="width: {loadPercent}%"
></div>
</div>
{/if}
</div>
+2 -1
View File
@@ -1,7 +1,8 @@
export const API_MODELS = {
LIST: '/v1/models',
LOAD: '/models/load',
UNLOAD: '/models/unload'
UNLOAD: '/models/unload',
SSE: '/models/sse'
};
// chat completion routes, the control route drives realtime inference (e.g. end reasoning)
+2
View File
@@ -37,6 +37,8 @@ export * from './mcp-form';
export * from './mcp-resource';
export * from './message-export';
export * from './model-id';
export * from './model-loading';
export * from './sse';
export * from './precision';
export * from './processing-info';
export * from './pwa';
@@ -0,0 +1,14 @@
/**
* Labels shown while a model loads, keyed by the stage reported on /models/sse.
*/
export const MODEL_LOAD_STAGE_LABELS: Record<ApiModelLoadStage, string> = {
text_model: 'Loading weights',
spec_model: 'Loading draft',
mmproj_model: 'Loading projector'
};
/**
* Share of the bar reserved for each load phase after text_model.
* text_model fills the rest, so a plain model reaches 100% on its own.
*/
export const MODEL_LOAD_TAIL_SHARE = 0.1;
+16
View File
@@ -0,0 +1,16 @@
/**
* Server-sent events wire format, shared by the chat stream and the
* /models/sse status feed (text/event-stream).
*/
// blank line between two events
export const SSE_RECORD_SEPARATOR = '\n\n';
// line break inside an event
export const SSE_LINE_SEPARATOR = '\n';
// data field prefix, the value follows after an optional space
export const SSE_DATA_PREFIX = 'data:';
// end-of-stream marker on the chat completion stream
export const SSE_DONE_MARKER = '[DONE]';
+1 -1
View File
@@ -54,7 +54,7 @@ export {
export { ModelModality } from './model.enums';
export { ServerRole, ServerModelStatus } from './server.enums';
export { ServerRole, ServerModelStatus, ServerModelsSseEventType } from './server.enums';
export { ParameterSource, SyncableParameterType, SettingsFieldType } from './settings.enums';
+14
View File
@@ -19,3 +19,17 @@ export enum ServerModelStatus {
SLEEPING = 'sleeping',
FAILED = 'failed'
}
/**
* /models/sse event type enum - discriminates the records broadcast on the
* model status feed in ROUTER mode. Matches the event names emitted by
* tools/server/server-models.cpp from the C++ server.
*/
export enum ServerModelsSseEventType {
STATUS_CHANGE = 'status_change',
MODEL_STATUS = 'model_status',
STATUS_UPDATE = 'status_update',
MODELS_RELOAD = 'models_reload',
MODEL_REMOVE = 'model_remove',
DOWNLOAD_PROGRESS = 'download_progress'
}
+9 -7
View File
@@ -10,7 +10,10 @@ import {
SETTINGS_KEYS,
API_CHAT,
API_SLOTS,
CONTROL_ACTION
CONTROL_ACTION,
SSE_LINE_SEPARATOR,
SSE_DATA_PREFIX,
SSE_DONE_MARKER
} from '$lib/constants';
import {
AttachmentType,
@@ -18,8 +21,7 @@ import {
FileTypeAudio,
MessageRole,
MimeTypeAudio,
ReasoningFormat,
UrlProtocol
ReasoningFormat
} from '$lib/enums';
import type {
ApiChatMessageContentPart,
@@ -642,15 +644,15 @@ export class ChatService {
if (abortSignal?.aborted) break;
chunk += decoder.decode(value, { stream: true });
const lines = chunk.split('\n');
const lines = chunk.split(SSE_LINE_SEPARATOR);
chunk = lines.pop() || '';
for (const line of lines) {
if (abortSignal?.aborted) break;
if (line.startsWith(UrlProtocol.DATA)) {
const data = line.slice(6);
if (data === '[DONE]') {
if (line.startsWith(SSE_DATA_PREFIX)) {
const data = line.slice(SSE_DATA_PREFIX.length).trim();
if (data === SSE_DONE_MARKER) {
streamFinished = true;
continue;
+239 -41
View File
@@ -1,6 +1,7 @@
import { base } from '$app/paths';
import { SvelteMap, SvelteSet } from 'svelte/reactivity';
import { toast } from 'svelte-sonner';
import { ServerModelStatus, ModelModality } from '$lib/enums';
import { ServerModelStatus, ServerModelsSseEventType, ModelModality } from '$lib/enums';
import { ModelsService } from '$lib/services/models.service';
import { PropsService } from '$lib/services/props.service';
import { serverStore, isRouterMode } from '$lib/stores/server.svelte';
@@ -8,11 +9,15 @@ import {
detectThinkingSupport,
detectThinkingSupportWithReason
} from '$lib/utils/chat-template-thinking-detector';
import { TTLCache } from '$lib/utils';
import { TTLCache, getAuthHeaders } from '$lib/utils';
import {
MODEL_PROPS_CACHE_TTL_MS,
MODEL_PROPS_CACHE_MAX_ENTRIES,
FAVORITE_MODELS_LOCALSTORAGE_KEY
FAVORITE_MODELS_LOCALSTORAGE_KEY,
API_MODELS,
SSE_RECORD_SEPARATOR,
SSE_LINE_SEPARATOR,
SSE_DATA_PREFIX
} from '$lib/constants';
import { conversationsStore } from '$lib/stores/conversations.svelte';
@@ -55,6 +60,15 @@ class ModelsStore {
private modelUsage = $state<Map<string, SvelteSet<string>>>(new Map());
private modelLoadingStates = new SvelteMap<string, boolean>();
// /models/sse feed state, the single source of truth for status and load progress
private statusAbort: AbortController | null = null;
private statusReaderActive = false;
private loadProgress = new SvelteMap<string, ModelLoadProgress>();
private statusWaiters = new Map<
string,
{ target: ServerModelStatus; resolve: () => void; reject: (e: Error) => void }
>();
favoriteModelIds = $state<Set<string>>(this.loadFavoritesFromStorage());
/**
@@ -626,49 +640,218 @@ class ModelsStore {
*
*/
/**
* WORKAROUND: Polling for model status after load/unload operations.
*
* Currently, `/models/load` and `/models/unload` return success before
* the operation actually completes on the server.
*
* TODO: Remove polling once llama-server properly waits for the operation
* to complete before returning success.
*/
private static readonly STATUS_POLL_INTERVAL = 500;
// reconnect delay after the feed drops or the server is not ready yet
private static readonly SSE_RECONNECT_MS = 1000;
/**
* Poll for expected model status after load/unload operation.
* Keeps polling until the model reaches the expected status or fails.
* Open the /models/sse feed and keep it live with auto reconnect.
* Idempotent and router mode only. The feed drives status and progress,
* so it replaces any post-operation polling.
*/
private async pollForModelStatus(
modelId: string,
expectedStatus: ServerModelStatus
): Promise<void> {
let attempt = 0;
while (true) {
await this.fetchRouterModels();
subscribeStatus(): void {
if (this.statusReaderActive) return;
if (!isRouterMode()) return;
const currentStatus = this.getModelStatus(modelId);
if (currentStatus === expectedStatus) return;
this.statusReaderActive = true;
this.statusAbort = new AbortController();
void this.runStatusReader(this.statusAbort.signal);
}
if (currentStatus === ServerModelStatus.FAILED) {
throw new Error(
`Model failed to ${expectedStatus === ServerModelStatus.LOADED ? 'load' : 'unload'}`
);
/**
* Close the /models/sse feed and drop transient progress.
*/
unsubscribeStatus(): void {
this.statusReaderActive = false;
this.statusAbort?.abort();
this.statusAbort = null;
this.loadProgress.clear();
}
/**
* Current load progress for a model, or null when not loading.
*/
getLoadProgress(modelId: string): ModelLoadProgress | null {
return this.loadProgress.get(modelId) ?? null;
}
/**
* Read the feed and reconnect until unsubscribed. Splits the byte stream
* into SSE records on the blank line boundary.
*/
private async runStatusReader(signal: AbortSignal): Promise<void> {
const decoder = new TextDecoder();
while (!signal.aborted) {
try {
const response = await fetch(`${base}${API_MODELS.SSE}`, {
headers: getAuthHeaders(),
signal
});
if (response.ok && response.body) {
const reader = response.body.getReader();
let buffer = '';
while (!signal.aborted) {
const { value, done } = await reader.read();
if (done) break;
buffer += decoder.decode(value, { stream: true });
let boundary = buffer.indexOf(SSE_RECORD_SEPARATOR);
while (boundary !== -1) {
this.handleStatusRecord(buffer.slice(0, boundary));
buffer = buffer.slice(boundary + SSE_RECORD_SEPARATOR.length);
boundary = buffer.indexOf(SSE_RECORD_SEPARATOR);
}
}
}
} catch {
// network drop or abort falls through to the reconnect delay
}
if (
expectedStatus === ServerModelStatus.LOADED &&
currentStatus === ServerModelStatus.UNLOADED &&
attempt > 2
) {
throw new Error('Model was unloaded unexpectedly during loading');
}
if (signal.aborted) return;
attempt++;
await new Promise((resolve) => setTimeout(resolve, ModelsStore.STATUS_POLL_INTERVAL));
await new Promise((resolve) => setTimeout(resolve, ModelsStore.SSE_RECONNECT_MS));
}
}
/**
* Parse one SSE record. The payload rides in the data lines as a JSON
* envelope that carries its own model, event and data fields.
*/
private handleStatusRecord(record: string): void {
const payload = record
.split(SSE_LINE_SEPARATOR)
.filter((line) => line.startsWith(SSE_DATA_PREFIX))
.map((line) => line.slice(SSE_DATA_PREFIX.length).trim())
.join(SSE_LINE_SEPARATOR);
if (payload.length === 0) return;
let envelope: ApiModelsSseEvent;
try {
envelope = JSON.parse(payload);
} catch {
return;
}
this.applyStatusEvent(envelope);
}
/**
* Route one feed record by event kind. Only the status_* events carry a
* status payload, models_reload triggers a list refresh, model_remove drops
* the row, download_* belong to the download surface, not here.
*/
private applyStatusEvent(event: ApiModelsSseEvent): void {
switch (event.event) {
case ServerModelsSseEventType.STATUS_CHANGE:
case ServerModelsSseEventType.MODEL_STATUS:
case ServerModelsSseEventType.STATUS_UPDATE:
this.applyModelStatus(event);
break;
case ServerModelsSseEventType.MODELS_RELOAD:
void this.fetchRouterModels();
break;
case ServerModelsSseEventType.MODEL_REMOVE:
this.removeRouterModel(event.model);
break;
case ServerModelsSseEventType.DOWNLOAD_PROGRESS:
break;
}
}
/**
* Apply a status envelope: update the model row, track or clear progress,
* settle any pending load or unload awaiter.
*/
private applyModelStatus(event: ApiModelsSseEvent): void {
const model = event.model;
const data = event.data;
if (!model || !data?.status) return;
const status = data.status;
this.setRouterModelStatus(model, status);
if (status === ServerModelStatus.LOADING) {
if (data.progress) this.loadProgress.set(model, data.progress);
} else {
this.loadProgress.delete(model);
}
if (status === ServerModelStatus.LOADED) {
void this.updateModelModalities(model);
}
const failed =
status === ServerModelStatus.FAILED ||
(status === ServerModelStatus.UNLOADED && (data.exit_code ?? 0) !== 0);
if (failed) {
this.rejectStatus(model, new Error(`Model failed: ${this.toDisplayName(model)}`));
return;
}
this.settleStatus(model, status);
}
/**
* Drop a model row reported gone by the feed and settle its awaiters.
*/
private removeRouterModel(modelId: string): void {
if (this.routerModels.findIndex((m) => m.id === modelId) === -1) return;
this.routerModels = this.routerModels.filter((m) => m.id !== modelId);
this.loadProgress.delete(modelId);
this.rejectStatus(modelId, new Error(`Model removed: ${this.toDisplayName(modelId)}`));
}
/**
* Update one model row status in place, reassigning to trigger reactivity.
*/
private setRouterModelStatus(modelId: string, status: ServerModelStatus): void {
const idx = this.routerModels.findIndex((m) => m.id === modelId);
if (idx === -1) return;
const current = this.routerModels[idx];
if (current.status.value === status) return;
const next = [...this.routerModels];
next[idx] = { ...current, status: { ...current.status, value: status } };
this.routerModels = next;
}
/**
* Register an awaiter that resolves when the feed reports target status.
* One operation runs per model at a time, so one awaiter per model is kept.
*/
private waitForStatus(modelId: string, target: ServerModelStatus): Promise<void> {
return new Promise((resolve, reject) => {
this.statusWaiters.set(modelId, { target, resolve, reject });
});
}
/**
* Resolve and drop the awaiter when the model reaches its target status.
*/
private settleStatus(modelId: string, status: ServerModelStatus): void {
const waiter = this.statusWaiters.get(modelId);
if (waiter && waiter.target === status) {
this.statusWaiters.delete(modelId);
waiter.resolve();
}
}
/**
* Reject and drop the awaiter for a model.
*/
private rejectStatus(modelId: string, error: Error): void {
const waiter = this.statusWaiters.get(modelId);
if (waiter) {
this.statusWaiters.delete(modelId);
waiter.reject(error);
}
}
@@ -679,12 +862,18 @@ class ModelsStore {
this.modelLoadingStates.set(modelId, true);
this.error = null;
// the feed drives completion, so it must be live before the request
this.subscribeStatus();
const reachedLoaded = this.waitForStatus(modelId, ServerModelStatus.LOADED);
reachedLoaded.catch(() => {});
try {
await ModelsService.load(modelId);
await this.pollForModelStatus(modelId, ServerModelStatus.LOADED);
await this.updateModelModalities(modelId);
await reachedLoaded;
toast.success(`Model loaded: ${this.toDisplayName(modelId)}`);
} catch (error) {
this.rejectStatus(modelId, error instanceof Error ? error : new Error('load failed'));
this.error = error instanceof Error ? error.message : 'Failed to load model';
toast.error(`Failed to load model: ${this.toDisplayName(modelId)}`);
throw error;
@@ -700,11 +889,17 @@ class ModelsStore {
this.modelLoadingStates.set(modelId, true);
this.error = null;
this.subscribeStatus();
const reachedUnloaded = this.waitForStatus(modelId, ServerModelStatus.UNLOADED);
reachedUnloaded.catch(() => {});
try {
await ModelsService.unload(modelId);
await this.pollForModelStatus(modelId, ServerModelStatus.UNLOADED);
await reachedUnloaded;
toast.info(`Model unloaded: ${this.toDisplayName(modelId)}`);
} catch (error) {
this.rejectStatus(modelId, error instanceof Error ? error : new Error('unload failed'));
this.error = error instanceof Error ? error.message : 'Failed to unload model';
toast.error(`Failed to unload model: ${this.toDisplayName(modelId)}`);
throw error;
@@ -783,6 +978,9 @@ class ModelsStore {
}
clear(): void {
this.unsubscribeStatus();
this.statusWaiters.forEach((waiter) => waiter.reject(new Error('Models store cleared')));
this.statusWaiters.clear();
this.models = [];
this.routerModels = [];
this.loading = false;
+47 -1
View File
@@ -1,4 +1,10 @@
import type { ContentPartType, FileTypeAudio, ServerModelStatus, ServerRole } from '$lib/enums';
import type {
ContentPartType,
FileTypeAudio,
ServerModelStatus,
ServerModelsSseEventType,
ServerRole
} from '$lib/enums';
import type { ChatMessagePromptProgress, ChatRole } from './chat';
export type AudioInputFormat = FileTypeAudio.WAV | FileTypeAudio.MP3;
@@ -96,6 +102,46 @@ export interface ApiModelDataEntry {
meta?: Record<string, unknown> | null;
}
/**
* Load stage reported by the /models/sse feed, in load order.
*/
export type ApiModelLoadStage = 'text_model' | 'spec_model' | 'mmproj_model';
/**
* Load progress snapshot: the full ordered stage plan, the active stage,
* and its fractional value (0.0 -> 1.0).
*/
export interface ApiModelsSseProgress {
stages: ApiModelLoadStage[];
current: ApiModelLoadStage;
value: number;
}
/**
* Status payload carried by a /models/sse envelope.
* exit_code appears on unload.
*/
export interface ApiModelsSseData {
status: ServerModelStatus;
progress?: ApiModelsSseProgress;
exit_code?: number;
}
/**
* Event kind multiplexed on the /models/sse feed.
* Only the status_* events carry a status payload, models_reload signals a
* full list refresh, model_remove drops a row, download_* drive download UI.
*/
/**
* One /models/sse record. event discriminates the kind, model names the
* target instance, data carries the status payload when present.
*/
export interface ApiModelsSseEvent {
model: string;
event: ServerModelsSseEventType;
data: ApiModelsSseData;
}
export interface ApiModelDetails {
name: string;
model: string;
+10 -1
View File
@@ -11,6 +11,10 @@ export type {
ApiChatMessageData,
ApiModelStatus,
ApiModelDataEntry,
ApiModelLoadStage,
ApiModelsSseProgress,
ApiModelsSseData,
ApiModelsSseEvent,
ApiModelDetails,
ApiModelListResponse,
ApiLlamaCppServerProps,
@@ -70,7 +74,12 @@ export type {
} from './database';
// Model types
export type { ModelModalities, ModelOption, ModalityCapabilities } from './models';
export type {
ModelModalities,
ModelOption,
ModelLoadProgress,
ModalityCapabilities
} from './models';
// Settings types
export type {
+12 -1
View File
@@ -1,4 +1,4 @@
import type { ApiModelDataEntry, ApiModelDetails } from '$lib/types/api';
import type { ApiModelDataEntry, ApiModelDetails, ApiModelLoadStage } from '$lib/types/api';
export interface ModelModalities {
vision: boolean;
@@ -20,6 +20,17 @@ export interface ModelOption {
tags?: string[];
}
/**
* Ephemeral UI-only load progress for one model instance.
* Lives only while a load runs, driven by the /models/sse feed.
* stage is absent until the feed reports its first stage.
*/
export interface ModelLoadProgress {
stages: ApiModelLoadStage[];
current: ApiModelLoadStage;
value: number;
}
export interface ParsedModelId {
raw: string;
orgName: string | null;
+3
View File
@@ -44,6 +44,9 @@ export { buildProxiedUrl, buildProxiedHeaders } from './cors-proxy';
// URL utilities
export { extractRootDomain, sanitizeExternalUrl } from './url';
// Progress helpers
export { modelLoadFraction, modelLoadProgressText } from './progress';
// Conversation utilities
export { createMessageCountMap, getMessageCount } from './conversation-utils';
+43
View File
@@ -0,0 +1,43 @@
/**
* Model load progress helpers for the /models/sse surfaces
* (selector row and chat message).
*/
import { MODEL_LOAD_STAGE_LABELS, MODEL_LOAD_TAIL_SHARE } from '$lib/constants';
/**
* Human label for a model load stage.
*/
export function modelLoadStageLabel(stage: ApiModelLoadStage): string {
return MODEL_LOAD_STAGE_LABELS[stage];
}
/**
* Overall load fraction (0.0 -> 1.0) across the declared stage plan.
* text_model fills [0, 1 - tail], each later phase owns one tail slice.
*/
export function modelLoadFraction(progress: ModelLoadProgress | null): number {
if (!progress) return 0;
const { stages, current, value } = progress;
const tailCount = Math.max(stages.length - 1, 0);
const textCeiling = 1 - tailCount * MODEL_LOAD_TAIL_SHARE;
const idx = stages.indexOf(current);
if (idx <= 0) {
return value * textCeiling;
}
return textCeiling + (idx - 1 + value) * MODEL_LOAD_TAIL_SHARE;
}
/**
* Single line describing load progress: active stage label and overall percent.
* Returns null when there is no progress to show.
*/
export function modelLoadProgressText(progress: ModelLoadProgress | null): string | null {
if (!progress) return null;
const label = modelLoadStageLabel(progress.current);
return `${label} ${Math.round(modelLoadFraction(progress) * 100)}%`;
}
+14
View File
@@ -230,6 +230,20 @@
}
});
// Live model status and load progress via the /models/sse feed (router mode)
$effect(() => {
if (!browser) return;
if (!isRouterMode()) return;
untrack(() => {
modelsStore.subscribeStatus();
});
return () => {
modelsStore.unsubscribeStatus();
};
});
// Background MCP server health checks on app load
// Fetch enabled servers from settings and run health checks in background
$effect(() => {