mirror of
https://github.com/ggml-org/llama.cpp.git
synced 2026-06-23 12:09:46 +02:00
Compare commits
7 Commits
| Author | SHA1 | Date | |
|---|---|---|---|
| a3900a6694 | |||
| 7c908502ea | |||
| 035cd8f9a6 | |||
| 73618f27a8 | |||
| 23ee8797e1 | |||
| dec5ca5577 | |||
| 9c0ac887f3 |
+1
-1
@@ -10,7 +10,7 @@
|
||||
# ggml-org/ggml-rpc : rgerganov
|
||||
# ggml-org/ggml-sycl : arthw
|
||||
# ggml-org/ggml-vulkan : 0cc4m, jeffbolznv
|
||||
# ggml-org/ggml-webgpu : reeselevine
|
||||
# ggml-org/ggml-webgpu : reeselevine, yomaytk
|
||||
# ggml-org/ggml-zdnn : taronaeo
|
||||
# ggml-org/llama-common : ggerganov, aldehir, angt, danbev, ngxson, pwilkin
|
||||
# ggml-org/llama-mtmd : ngxson
|
||||
|
||||
+103
-53
@@ -90,41 +90,93 @@ std::string common_chat_msg::render_content(const std::string & delimiter) const
|
||||
return text;
|
||||
}
|
||||
|
||||
std::vector<common_chat_msg_span> common_chat_split_by_role(const std::string & prompt, const std::vector<common_chat_msg_delimiter> & delims) {
|
||||
if (delims.empty() || prompt.empty()) {
|
||||
return {};
|
||||
common_chat_role common_chat_role_from_string(const std::string & role) {
|
||||
if (role == "system") { return COMMON_CHAT_ROLE_SYSTEM; }
|
||||
if (role == "assistant") { return COMMON_CHAT_ROLE_ASSISTANT; }
|
||||
if (role == "user") { return COMMON_CHAT_ROLE_USER; }
|
||||
if (role == "tool") { return COMMON_CHAT_ROLE_TOOL; }
|
||||
return COMMON_CHAT_ROLE_UNKNOWN;
|
||||
}
|
||||
|
||||
const char * common_chat_role_to_string(common_chat_role role) {
|
||||
switch (role) {
|
||||
case COMMON_CHAT_ROLE_SYSTEM: return "system";
|
||||
case COMMON_CHAT_ROLE_ASSISTANT: return "assistant";
|
||||
case COMMON_CHAT_ROLE_USER: return "user";
|
||||
case COMMON_CHAT_ROLE_TOOL: return "tool";
|
||||
case COMMON_CHAT_ROLE_UNKNOWN: return "";
|
||||
}
|
||||
return "";
|
||||
}
|
||||
|
||||
json common_chat_msg_delimiters::to_json() const {
|
||||
json result = json::array();
|
||||
for (const auto & d : delimiters) {
|
||||
result.push_back({
|
||||
{ "role", common_chat_role_to_string(d.role) },
|
||||
{ "delimiter", d.delimiter },
|
||||
});
|
||||
}
|
||||
return result;
|
||||
}
|
||||
|
||||
common_chat_msg_delimiters common_chat_msg_delimiters_parse(const json & delimiters) {
|
||||
common_chat_msg_delimiters result;
|
||||
|
||||
if (!delimiters.is_array()) {
|
||||
return result;
|
||||
}
|
||||
|
||||
auto parser = build_peg_parser([&](common_peg_parser_builder & p) {
|
||||
std::vector<std::string> all_delims;
|
||||
std::vector<common_peg_parser> tagged_messages;
|
||||
|
||||
all_delims.reserve(delims.size());
|
||||
tagged_messages.reserve(delims.size());
|
||||
for (const auto & d : delims) {
|
||||
all_delims.push_back(d.delimiter);
|
||||
result.delimiters.reserve(delimiters.size());
|
||||
for (const auto & d : delimiters) {
|
||||
if (!d.is_object()) {
|
||||
continue;
|
||||
}
|
||||
|
||||
auto any_delim = p.until_one_of(all_delims);
|
||||
for (const auto & d : delims) {
|
||||
tagged_messages.push_back(p.tag(d.role, p.literal(d.delimiter) + any_delim));
|
||||
}
|
||||
|
||||
return any_delim + p.zero_or_more(p.choice(tagged_messages)) + p.end();
|
||||
});
|
||||
|
||||
common_peg_parse_context ctx(prompt);
|
||||
const auto result = parser.parse(ctx);
|
||||
if (!result.success()) {
|
||||
return {};
|
||||
result.delimiters.push_back({
|
||||
common_chat_role_from_string(d.value("role", std::string())),
|
||||
d.value("delimiter", std::string()),
|
||||
});
|
||||
}
|
||||
|
||||
std::vector<common_chat_msg_span> spans;
|
||||
ctx.ast.visit(result, [&](const common_peg_ast_node & node) {
|
||||
if (!node.tag.empty()) {
|
||||
spans.push_back({ node.tag, node.start, node.end - node.start });
|
||||
return result;
|
||||
}
|
||||
|
||||
void common_chat_msg_delimiters::tokenize(const llama_vocab * vocab) {
|
||||
for (auto & d : delimiters) {
|
||||
d.tokens = common_tokenize(vocab, d.delimiter, false, true);
|
||||
}
|
||||
}
|
||||
|
||||
common_chat_msg_spans common_chat_msg_delimiters::split(const llama_tokens & tokens, const std::map<size_t, size_t> & skips) const {
|
||||
std::vector<std::pair<common_chat_role, size_t>> matches;
|
||||
|
||||
auto skip = skips.begin();
|
||||
for (size_t i = 0; i < tokens.size();) {
|
||||
if (skip != skips.end() && i == skip->first) {
|
||||
i += skip->second;
|
||||
++skip;
|
||||
continue;
|
||||
}
|
||||
});
|
||||
for (const auto & d : delimiters) {
|
||||
if (i + d.tokens.size() > tokens.size()) {
|
||||
continue;
|
||||
}
|
||||
if (std::equal(d.tokens.begin(), d.tokens.end(), tokens.begin() + i)) {
|
||||
matches.emplace_back(d.role, i);
|
||||
break;
|
||||
}
|
||||
}
|
||||
i++;
|
||||
}
|
||||
|
||||
matches.emplace_back(COMMON_CHAT_ROLE_UNKNOWN, tokens.size());
|
||||
|
||||
common_chat_msg_spans spans;
|
||||
for (size_t i = 0; i + 1 < matches.size(); i++) {
|
||||
const auto & curr = matches[i];
|
||||
const auto & next = matches[i + 1];
|
||||
spans.add(curr.first, curr.second, next.second - curr.second);
|
||||
}
|
||||
|
||||
return spans;
|
||||
}
|
||||
@@ -1081,13 +1133,13 @@ static common_chat_params common_chat_params_init_gpt_oss(const common_chat_temp
|
||||
|
||||
data.prompt = prompt;
|
||||
data.generation_prompt = common_chat_template_generation_prompt_impl(tmpl, inputs, /* messages_override= */ adjusted_messages);
|
||||
data.message_spans = common_chat_split_by_role(prompt, {
|
||||
{ "assistant", "<|start|>assistant" },
|
||||
{ "user", "<|start|>user" },
|
||||
{ "system", "<|start|>developer" },
|
||||
{ "system", "<|start|>system" },
|
||||
{ "tool", "<|start|>functions" },
|
||||
});
|
||||
data.message_delimiters = {
|
||||
{ COMMON_CHAT_ROLE_ASSISTANT, "<|start|>assistant" },
|
||||
{ COMMON_CHAT_ROLE_USER, "<|start|>user" },
|
||||
{ COMMON_CHAT_ROLE_SYSTEM, "<|start|>developer" },
|
||||
{ COMMON_CHAT_ROLE_SYSTEM, "<|start|>system" },
|
||||
{ COMMON_CHAT_ROLE_TOOL, "<|start|>functions" },
|
||||
};
|
||||
|
||||
data.format = COMMON_CHAT_FORMAT_PEG_NATIVE;
|
||||
data.supports_thinking = true;
|
||||
@@ -1228,10 +1280,10 @@ static common_chat_params common_chat_params_init_gemma4(const common_chat_templ
|
||||
data.prompt += data.generation_prompt;
|
||||
}
|
||||
|
||||
data.message_spans = common_chat_split_by_role(data.prompt, {
|
||||
{ "user", "<|turn>user\n" },
|
||||
{ "assistant", "<|turn>model\n" },
|
||||
});
|
||||
data.message_delimiters = {
|
||||
{ COMMON_CHAT_ROLE_USER, "<|turn>user" },
|
||||
{ COMMON_CHAT_ROLE_ASSISTANT, "<|turn>model" },
|
||||
};
|
||||
|
||||
data.format = COMMON_CHAT_FORMAT_PEG_GEMMA4;
|
||||
data.supports_thinking = true;
|
||||
@@ -2030,15 +2082,15 @@ static common_chat_params common_chat_params_init_cohere2moe(const common_chat_t
|
||||
RESULT_START, RESULT_END,
|
||||
};
|
||||
|
||||
// Split the rendered prompt into per-role message spans. Tool results are rendered with the
|
||||
// Declare per-role message delimiters. Tool results are rendered with the
|
||||
// system token followed by <|START_TOOL_RESULT|>, so the "tool" delimiter must be listed before
|
||||
// the plain "system" one (it is a strict superset, and the role split tries delimiters in order).
|
||||
data.message_spans = common_chat_split_by_role(data.prompt, {
|
||||
{ "assistant", GEN_PREFIX },
|
||||
{ "user", TURN_START + USER },
|
||||
{ "tool", TURN_START + SYSTEM + RESULT_START },
|
||||
{ "system", TURN_START + SYSTEM },
|
||||
});
|
||||
data.message_delimiters = {
|
||||
{ COMMON_CHAT_ROLE_ASSISTANT, GEN_PREFIX },
|
||||
{ COMMON_CHAT_ROLE_USER, TURN_START + USER },
|
||||
{ COMMON_CHAT_ROLE_TOOL, TURN_START + SYSTEM + RESULT_START },
|
||||
{ COMMON_CHAT_ROLE_SYSTEM, TURN_START + SYSTEM },
|
||||
};
|
||||
|
||||
auto has_tools = inputs.tools.is_array() && !inputs.tools.empty();
|
||||
auto extract_reasoning = inputs.reasoning_format != COMMON_REASONING_FORMAT_NONE;
|
||||
@@ -2526,17 +2578,15 @@ static common_chat_params common_chat_templates_apply_jinja(const struct common_
|
||||
autoparser.analyze_template(tmpl);
|
||||
auto auto_params = autoparser::peg_generator::generate_parser(tmpl, params, autoparser);
|
||||
|
||||
std::vector<common_chat_msg_delimiter> delimiters;
|
||||
common_chat_msg_delimiters delimiters;
|
||||
if (!autoparser.assistant_start.empty()) {
|
||||
delimiters.push_back({ "assistant", autoparser.assistant_start });
|
||||
delimiters.add(COMMON_CHAT_ROLE_ASSISTANT, autoparser.assistant_start);
|
||||
}
|
||||
if (!autoparser.user_start.empty()) {
|
||||
delimiters.push_back({ "user", autoparser.user_start });
|
||||
delimiters.add(COMMON_CHAT_ROLE_USER, autoparser.user_start);
|
||||
}
|
||||
|
||||
if (!delimiters.empty()) {
|
||||
auto_params.message_spans = common_chat_split_by_role(auto_params.prompt, delimiters);
|
||||
}
|
||||
auto_params.message_delimiters = std::move(delimiters);
|
||||
|
||||
auto_params.supports_thinking = autoparser.reasoning.mode != autoparser::reasoning_mode::NONE;
|
||||
if (auto_params.supports_thinking) {
|
||||
|
||||
+65
-6
@@ -143,15 +143,75 @@ struct common_chat_msg_diff {
|
||||
}
|
||||
};
|
||||
|
||||
enum common_chat_role {
|
||||
COMMON_CHAT_ROLE_UNKNOWN,
|
||||
COMMON_CHAT_ROLE_SYSTEM,
|
||||
COMMON_CHAT_ROLE_ASSISTANT,
|
||||
COMMON_CHAT_ROLE_USER,
|
||||
COMMON_CHAT_ROLE_TOOL
|
||||
};
|
||||
|
||||
common_chat_role common_chat_role_from_string(const std::string & role);
|
||||
const char * common_chat_role_to_string(common_chat_role role);
|
||||
|
||||
struct common_chat_msg_span {
|
||||
std::string role;
|
||||
common_chat_role role = COMMON_CHAT_ROLE_UNKNOWN;
|
||||
std::size_t pos = 0;
|
||||
std::size_t len = 0;
|
||||
|
||||
bool valid() const {
|
||||
return role != COMMON_CHAT_ROLE_UNKNOWN;
|
||||
}
|
||||
};
|
||||
|
||||
struct common_chat_msg_spans {
|
||||
std::vector<common_chat_msg_span> spans;
|
||||
|
||||
void add(common_chat_role role, size_t pos, size_t len) {
|
||||
spans.push_back({ role, pos, len });
|
||||
}
|
||||
|
||||
bool is_user_start(int32_t pos) const {
|
||||
for (auto it = spans.begin(); it != spans.end(); ++it) {
|
||||
if (it->role == COMMON_CHAT_ROLE_USER && pos == (int32_t) it->pos) {
|
||||
return true;
|
||||
}
|
||||
}
|
||||
return false;
|
||||
}
|
||||
|
||||
int32_t last_user_message_pos() const {
|
||||
for (auto it = spans.rbegin(); it != spans.rend(); ++it) {
|
||||
if (it->role == COMMON_CHAT_ROLE_USER) {
|
||||
return (int32_t) it->pos;
|
||||
}
|
||||
}
|
||||
return -1;
|
||||
}
|
||||
};
|
||||
|
||||
struct common_chat_msg_delimiter {
|
||||
std::string role;
|
||||
std::string delimiter;
|
||||
common_chat_role role = COMMON_CHAT_ROLE_UNKNOWN;
|
||||
std::string delimiter;
|
||||
llama_tokens tokens = {};
|
||||
};
|
||||
|
||||
struct common_chat_msg_delimiters {
|
||||
std::vector<common_chat_msg_delimiter> delimiters;
|
||||
|
||||
common_chat_msg_delimiters() = default;
|
||||
common_chat_msg_delimiters(std::initializer_list<common_chat_msg_delimiter> delims) : delimiters(delims) {}
|
||||
|
||||
void add(common_chat_role role, const std::string & delimiter) {
|
||||
delimiters.push_back({ role, delimiter });
|
||||
}
|
||||
|
||||
void tokenize(const llama_vocab * vocab);
|
||||
|
||||
// split tokens into message spans. skips maps a start index to a length of a region to jump over without matching
|
||||
common_chat_msg_spans split(const llama_tokens & tokens, const std::map<size_t, size_t> & skips = {}) const;
|
||||
|
||||
nlohmann::ordered_json to_json() const;
|
||||
};
|
||||
|
||||
struct common_chat_tool {
|
||||
@@ -219,7 +279,7 @@ struct common_chat_params {
|
||||
std::vector<std::string> preserved_tokens;
|
||||
std::vector<std::string> additional_stops;
|
||||
std::string parser;
|
||||
std::vector<common_chat_msg_span> message_spans;
|
||||
common_chat_msg_delimiters message_delimiters;
|
||||
};
|
||||
|
||||
// per-message parsing syntax
|
||||
@@ -325,5 +385,4 @@ struct common_chat_prompt_preset {
|
||||
|
||||
common_chat_prompt_preset common_chat_get_asr_prompt(const common_chat_templates * chat_templates);
|
||||
|
||||
std::vector<common_chat_msg_span> common_chat_split_by_role(const std::string & prompt, const std::vector<common_chat_msg_delimiter> & delims);
|
||||
|
||||
common_chat_msg_delimiters common_chat_msg_delimiters_parse(const nlohmann::ordered_json & delimiters);
|
||||
|
||||
+1
-1
@@ -609,7 +609,7 @@ struct common_params {
|
||||
bool cache_prompt = true; // whether to enable prompt caching
|
||||
bool cache_idle_slots = true; // save and clear idle slots upon starting a new task
|
||||
int32_t n_ctx_checkpoints = 32; // max number of context checkpoints per slot
|
||||
int32_t checkpoint_min_step = 256; // minimum spacing between context checkpoints
|
||||
int32_t checkpoint_min_step = 8192; // minimum spacing between context checkpoints
|
||||
int32_t cache_ram_mib = 8192; // -1 = no limit, 0 - disable, 1 = 1 MiB, etc.
|
||||
|
||||
std::string hostname = "127.0.0.1";
|
||||
|
||||
@@ -96,6 +96,7 @@ TEXT_MODEL_MAP: dict[str, str] = {
|
||||
"GraniteMoeHybridForCausalLM": "granite",
|
||||
"GraniteMoeSharedForCausalLM": "granite",
|
||||
"GraniteSpeechForConditionalGeneration": "granite",
|
||||
"GraniteSpeechPlusForConditionalGeneration": "granite",
|
||||
"Grok1ForCausalLM": "grok",
|
||||
"GrokForCausalLM": "grok",
|
||||
"GroveMoeForCausalLM": "grovemoe",
|
||||
@@ -261,6 +262,7 @@ MMPROJ_MODEL_MAP: dict[str, str] = {
|
||||
"GlmasrModel": "ultravox",
|
||||
"Granite4VisionForConditionalGeneration": "granite",
|
||||
"GraniteSpeechForConditionalGeneration": "granite",
|
||||
"GraniteSpeechPlusForConditionalGeneration": "granite",
|
||||
"HunYuanVLForConditionalGeneration": "hunyuan",
|
||||
"Idefics3ForConditionalGeneration": "smolvlm",
|
||||
"InternVisionModel": "internvl",
|
||||
|
||||
@@ -348,6 +348,34 @@ class GraniteSpeechMmprojModel(MmprojModel):
|
||||
yield from super().modify_tensors(data_torch, name, bid)
|
||||
|
||||
|
||||
@ModelBase.register("GraniteSpeechPlusForConditionalGeneration")
|
||||
class GraniteSpeechPlusMmprojModel(GraniteSpeechMmprojModel):
|
||||
"""Conversion for GraniteSpeechPlus - extends GraniteSpeech with feature layer concatenation"""
|
||||
has_vision_encoder = False
|
||||
has_audio_encoder = True
|
||||
|
||||
def set_gguf_parameters(self):
|
||||
assert self.hparams_audio is not None
|
||||
super().set_gguf_parameters()
|
||||
|
||||
# Add feature_layer if present in encoder config
|
||||
if feature_layers := self.hparams_audio.get("cat_hidden_layers"):
|
||||
self.gguf_writer.add_audio_feature_layers(feature_layers)
|
||||
logger.info(f"gguf: audio feature_layers = {feature_layers}")
|
||||
|
||||
# Validate projector dimension matches concatenated encoder output
|
||||
hidden_dim = self.hparams_audio["hidden_dim"]
|
||||
expected_dim = hidden_dim * (len(feature_layers) + 1)
|
||||
projector_dim = self.global_config["projector_config"]["encoder_hidden_size"]
|
||||
|
||||
if projector_dim != expected_dim:
|
||||
raise ValueError(
|
||||
f"Projector encoder_hidden_size ({projector_dim}) does not match "
|
||||
f"expected concatenated dimension ({expected_dim}). "
|
||||
f"Expected: hidden_dim ({hidden_dim}) * (len(feature_layers) + 1) = {expected_dim}"
|
||||
)
|
||||
|
||||
|
||||
@ModelBase.register("Granite4VisionForConditionalGeneration")
|
||||
class Granite4VisionMmprojModel(MmprojModel):
|
||||
has_vision_encoder = True
|
||||
|
||||
@@ -174,7 +174,7 @@ __kernel void kernel_gemv_noshuffle_q8_0_f32(
|
||||
regA.s6 = read_imageui(src0_q, (gid + k * BLOCK_STRIDE_A + LINE_STRIDE_A * 6)).x;
|
||||
regA.s7 = read_imageui(src0_q, (gid + k * BLOCK_STRIDE_A + LINE_STRIDE_A * 7)).x;
|
||||
|
||||
dequantizeBlockAccum_ns_sgbroadcast_1(totalSum, regA, regS, regB);
|
||||
dequantizeBlockAccum_ns_sgbroadcast_1(totalSum, regA, convert_float(regS), regB);
|
||||
}
|
||||
|
||||
// reduction in local memory, assumes #wave=4
|
||||
|
||||
@@ -905,11 +905,12 @@ struct ggml_webgpu_mul_mat_vec_pipeline_key {
|
||||
ggml_type src0_type;
|
||||
ggml_type src1_type;
|
||||
int vectorized;
|
||||
uint32_t num_cols;
|
||||
bool use_mmvq;
|
||||
|
||||
bool operator==(const ggml_webgpu_mul_mat_vec_pipeline_key & other) const {
|
||||
return src0_type == other.src0_type && src1_type == other.src1_type && vectorized == other.vectorized &&
|
||||
use_mmvq == other.use_mmvq;
|
||||
num_cols == other.num_cols && use_mmvq == other.use_mmvq;
|
||||
}
|
||||
};
|
||||
|
||||
@@ -919,6 +920,7 @@ struct ggml_webgpu_mul_mat_vec_pipeline_key_hash {
|
||||
ggml_webgpu_hash_combine(seed, key.src0_type);
|
||||
ggml_webgpu_hash_combine(seed, key.src1_type);
|
||||
ggml_webgpu_hash_combine(seed, key.vectorized);
|
||||
ggml_webgpu_hash_combine(seed, key.num_cols);
|
||||
ggml_webgpu_hash_combine(seed, key.use_mmvq);
|
||||
return seed;
|
||||
}
|
||||
@@ -993,11 +995,12 @@ struct ggml_webgpu_mul_mat_id_pipeline_key {
|
||||
ggml_type src0_type;
|
||||
ggml_type src1_type;
|
||||
uint32_t n_experts;
|
||||
uint32_t num_cols;
|
||||
int vectorized;
|
||||
|
||||
bool operator==(const ggml_webgpu_mul_mat_id_pipeline_key & other) const {
|
||||
return src0_type == other.src0_type && src1_type == other.src1_type && n_experts == other.n_experts &&
|
||||
vectorized == other.vectorized;
|
||||
num_cols == other.num_cols && vectorized == other.vectorized;
|
||||
}
|
||||
};
|
||||
|
||||
@@ -1007,6 +1010,7 @@ struct ggml_webgpu_mul_mat_id_pipeline_key_hash {
|
||||
ggml_webgpu_hash_combine(seed, key.src0_type);
|
||||
ggml_webgpu_hash_combine(seed, key.src1_type);
|
||||
ggml_webgpu_hash_combine(seed, key.n_experts);
|
||||
ggml_webgpu_hash_combine(seed, key.num_cols);
|
||||
ggml_webgpu_hash_combine(seed, key.vectorized);
|
||||
return seed;
|
||||
}
|
||||
@@ -1107,7 +1111,7 @@ inline bool ggml_webgpu_can_use_mmvq(const ggml_tensor * src0,
|
||||
const ggml_tensor * src1,
|
||||
bool supports_dot_product,
|
||||
const std::string & vendor) {
|
||||
if (src1->ne[1] == 1) {
|
||||
if (src1->ne[1] <= 4) {
|
||||
bool supports_dp4a = vendor == "amd" || vendor == "intel" || vendor == "nvidia";
|
||||
if (supports_dp4a && supports_dot_product) {
|
||||
switch (src1->type) {
|
||||
@@ -1889,6 +1893,7 @@ class ggml_webgpu_shader_lib {
|
||||
(context.src0->type == GGML_TYPE_F32 || context.src0->type == GGML_TYPE_F16)) ?
|
||||
1 :
|
||||
0;
|
||||
key.num_cols = context.dst->ne[1];
|
||||
key.use_mmvq =
|
||||
ggml_webgpu_can_use_mmvq(context.src0, context.src1, context.supports_dot_product, context.vendor);
|
||||
|
||||
@@ -2004,6 +2009,7 @@ class ggml_webgpu_shader_lib {
|
||||
if (key.vectorized) {
|
||||
variant += "_vectorized";
|
||||
}
|
||||
defines.push_back(std::string("NUM_COLS=") + std::to_string(key.num_cols));
|
||||
|
||||
auto processed = preprocessor.preprocess(shader_src, defines);
|
||||
auto decisions = std::make_shared<ggml_webgpu_mul_mat_vec_shader_decisions>();
|
||||
@@ -2421,6 +2427,7 @@ class ggml_webgpu_shader_lib {
|
||||
if (key.vectorized) {
|
||||
variant += "_vectorized";
|
||||
}
|
||||
defines.push_back(std::string("NUM_COLS=1"));
|
||||
|
||||
defines.push_back(std::string("N_EXPERTS=") + std::to_string(key.n_experts));
|
||||
|
||||
|
||||
@@ -1418,15 +1418,17 @@ static void ggml_webgpu_quantize_q8_dispatch(webgpu_context &
|
||||
const size_t dst_offset = ggml_webgpu_tensor_offset(dst);
|
||||
const size_t q8_src1_align_offset = ROUNDUP_POW2(
|
||||
dst_offset + ggml_nbytes(dst), ctx->global_ctx->capabilities.limits.minStorageBufferOffsetAlignment);
|
||||
const size_t q8_src1_binding_size =
|
||||
ROUNDUP_POW2(src1->ne[3] * src1->ne[2] * (36 /* sizeof(q8_1) */ * (src1->ne[0] / /* block_size */ 32)),
|
||||
WEBGPU_STORAGE_BUF_BINDING_MULT);
|
||||
const size_t q8_src1_binding_size = ROUNDUP_POW2(
|
||||
src1->ne[3] * src1->ne[2] * src1->ne[1] * (36 /* sizeof(q8_1) */ * (src1->ne[0] / /* block_size */ 32)),
|
||||
WEBGPU_STORAGE_BUF_BINDING_MULT);
|
||||
|
||||
std::vector<uint32_t> q8_params = {
|
||||
(uint32_t) (ggml_webgpu_tensor_misalignment(ctx, src1) / ggml_type_size(src1->type)),
|
||||
(uint32_t) (src1->nb[1] / ggml_type_size(src1->type)),
|
||||
(uint32_t) (src1->nb[2] / ggml_type_size(src1->type)),
|
||||
(uint32_t) (src1->nb[3] / ggml_type_size(src1->type)),
|
||||
(uint32_t) src1->ne[0],
|
||||
(uint32_t) src1->ne[1],
|
||||
(uint32_t) src1->ne[2],
|
||||
(uint32_t) src1->ne[3],
|
||||
};
|
||||
@@ -1442,7 +1444,7 @@ static void ggml_webgpu_quantize_q8_dispatch(webgpu_context &
|
||||
uint32_t q8_wg_x = 1;
|
||||
uint32_t q8_wg_y = 1;
|
||||
const uint32_t wg_per_vec = (src0->ne[0] / 4 + (q8_wg_size - 1)) / q8_wg_size;
|
||||
const uint32_t q8_total_wg = src1->ne[2] * src1->ne[3] * wg_per_vec;
|
||||
const uint32_t q8_total_wg = src1->ne[1] * src1->ne[2] * src1->ne[3] * wg_per_vec;
|
||||
const uint32_t max_wg_per_dim = ctx->global_ctx->capabilities.limits.maxComputeWorkgroupsPerDimension;
|
||||
compute_2d_workgroups(q8_total_wg, max_wg_per_dim, q8_wg_x, q8_wg_y);
|
||||
|
||||
@@ -1456,7 +1458,7 @@ static webgpu_encoded_op ggml_webgpu_mul_mat(webgpu_context & ctx,
|
||||
ggml_tensor * src1,
|
||||
ggml_tensor * dst) {
|
||||
// Determine if this is a mat-vec operation
|
||||
bool is_vec = (dst->ne[1] == 1);
|
||||
bool use_mat_vec = (dst->ne[1] <= 4);
|
||||
|
||||
// use MMVQ path for mat-vec
|
||||
bool use_mmvq = ggml_webgpu_can_use_mmvq(src0, src1, ctx->global_ctx->capabilities.supports_dot_product,
|
||||
@@ -1482,7 +1484,7 @@ static webgpu_encoded_op ggml_webgpu_mul_mat(webgpu_context & ctx,
|
||||
webgpu_pipeline pipeline;
|
||||
std::vector<webgpu_dispatch_desc> dispatches;
|
||||
|
||||
if (is_vec) {
|
||||
if (use_mat_vec) {
|
||||
if (use_mmvq) {
|
||||
ggml_webgpu_quantize_q8_dispatch(ctx, src0, src1, dst, dispatches);
|
||||
}
|
||||
@@ -1529,7 +1531,7 @@ static webgpu_encoded_op ggml_webgpu_mul_mat(webgpu_context & ctx,
|
||||
uint32_t wg_y = 1;
|
||||
const uint32_t max_wg_per_dim = ctx->global_ctx->capabilities.limits.maxComputeWorkgroupsPerDimension;
|
||||
|
||||
if (is_vec) {
|
||||
if (use_mat_vec) {
|
||||
auto * decisions = static_cast<ggml_webgpu_mul_mat_vec_shader_decisions *>(pipeline.context.get());
|
||||
|
||||
uint32_t batches = dst->ne[2] * dst->ne[3];
|
||||
@@ -3691,8 +3693,8 @@ static size_t ggml_backend_webgpu_buffer_type_get_alloc_size(ggml_backend_buffer
|
||||
ggml_webgpu_can_use_mmvq(src0, src1, ctx->webgpu_global_ctx->capabilities.supports_dot_product,
|
||||
ctx->webgpu_global_ctx->vendor);
|
||||
if (use_mmvq) {
|
||||
const size_t q8_src1_size =
|
||||
src1->ne[3] * src1->ne[2] * (36 /* sizeof(q8_1) */ * (src1->ne[0] / /* block_size */ 32));
|
||||
const size_t q8_src1_size = src1->ne[3] * src1->ne[2] * src1->ne[1] *
|
||||
(36 /* sizeof(q8_1) */ * (src1->ne[0] / /* block_size */ 32));
|
||||
res = ROUNDUP_POW2(res + q8_src1_size +
|
||||
ctx->webgpu_global_ctx->capabilities.limits.minStorageBufferOffsetAlignment,
|
||||
WEBGPU_STORAGE_BUF_BINDING_MULT);
|
||||
|
||||
@@ -103,7 +103,7 @@ fn main(
|
||||
|
||||
#ifdef USE_SUBGROUP_REDUCTION
|
||||
for (var row = 0u; row < OUTPUTS_PER_WG; row++) {
|
||||
let subgroup_total = subgroupAdd(acc[row]);
|
||||
let subgroup_total = subgroupAdd(acc[0][row]);
|
||||
if (subgroup_invocation_id == 0u) {
|
||||
partial_sums[partial_index(row, subgroup_id)] = subgroup_total;
|
||||
}
|
||||
@@ -126,7 +126,7 @@ fn main(
|
||||
|
||||
#ifdef USE_WORKGROUP_REDUCTION
|
||||
for (var row = 0u; row < OUTPUTS_PER_WG; row++) {
|
||||
partial_sums[partial_index(row, thread_id)] = acc[row];
|
||||
partial_sums[partial_index(row, thread_id)] = acc[0][row];
|
||||
}
|
||||
|
||||
workgroupBarrier();
|
||||
|
||||
@@ -91,61 +91,67 @@ fn main(
|
||||
let dst_idx_base = params.offset_dst + dst3_idx * dst3_stride + dst2_idx * dst2_stride + row_base;
|
||||
|
||||
#ifdef MMVQ
|
||||
let src1q_idx_base = (src13_idx * params.bs02 * params.broadcast2 + src12_idx) * (params.k / 32u);
|
||||
let src1q_idx_base = (src13_idx * params.bs02 * params.broadcast2 + src12_idx) * params.n * (params.k / 32u);
|
||||
let acc = accumulate_vec_q_dot(thread_id, row_base, src0_batch_offset, src1q_idx_base);
|
||||
#else
|
||||
let src1_idx_base = params.offset_src1 + src13_idx * params.stride_13 + src12_idx * params.stride_12;
|
||||
let acc = accumulate_vec_dot(thread_id, row_base, src0_batch_offset, src1_idx_base);
|
||||
#endif
|
||||
|
||||
for (var col = 0u;col < NUM_COLS;col += 1) {
|
||||
|
||||
#ifdef USE_SUBGROUP_REDUCTION
|
||||
for (var row = 0u; row < OUTPUTS_PER_WG; row++) {
|
||||
let subgroup_total = subgroupAdd(acc[row]);
|
||||
if (subgroup_invocation_id == 0u) {
|
||||
partial_sums[partial_index(row, subgroup_id)] = subgroup_total;
|
||||
}
|
||||
}
|
||||
for (var row = 0u; row < OUTPUTS_PER_WG; row++) {
|
||||
let subgroup_total = subgroupAdd(acc[col][row]);
|
||||
if (subgroup_invocation_id == 0u) {
|
||||
partial_sums[partial_index(row, subgroup_id)] = subgroup_total;
|
||||
}
|
||||
}
|
||||
|
||||
workgroupBarrier();
|
||||
workgroupBarrier();
|
||||
|
||||
for (var row = subgroup_id; (row < OUTPUTS_PER_WG) && (row_base + row < params.m); row += num_subgroups) {
|
||||
let output_row = row_base + row;
|
||||
var row_acc = 0.0f;
|
||||
for (var k = subgroup_invocation_id; k < num_subgroups; k += subgroup_size) {
|
||||
row_acc += partial_sums[partial_index(row, k)];
|
||||
}
|
||||
let row_total = subgroupAdd(row_acc);
|
||||
if (subgroup_invocation_id == 0) {
|
||||
dst[dst_idx_base + row] = row_total;
|
||||
}
|
||||
}
|
||||
for (var row = subgroup_id; (row < OUTPUTS_PER_WG) && (row_base + row < params.m); row += num_subgroups) {
|
||||
let output_row = row_base + row;
|
||||
var row_acc = 0.0f;
|
||||
for (var k = subgroup_invocation_id; k < num_subgroups; k += subgroup_size) {
|
||||
row_acc += partial_sums[partial_index(row, k)];
|
||||
}
|
||||
let row_total = subgroupAdd(row_acc);
|
||||
if (subgroup_invocation_id == 0) {
|
||||
dst[dst_idx_base + col * params.m + row] = row_total;
|
||||
}
|
||||
}
|
||||
#endif
|
||||
|
||||
#ifdef USE_WORKGROUP_REDUCTION
|
||||
for (var row = 0u; row < OUTPUTS_PER_WG; row++) {
|
||||
partial_sums[partial_index(row, thread_id)] = acc[row];
|
||||
}
|
||||
for (var row = 0u; row < OUTPUTS_PER_WG; row++) {
|
||||
partial_sums[partial_index(row, thread_id)] = acc[col][row];
|
||||
}
|
||||
|
||||
workgroupBarrier();
|
||||
|
||||
var stride = WG_SIZE / 2u;
|
||||
|
||||
while (stride > 0) {
|
||||
if (thread_id < stride) {
|
||||
for (var row = 0u; row < OUTPUTS_PER_WG; row++) {
|
||||
partial_sums[partial_index(row, thread_id)] += partial_sums[partial_index(row, thread_id + stride)];
|
||||
}
|
||||
}
|
||||
|
||||
workgroupBarrier();
|
||||
stride = stride / 2;
|
||||
}
|
||||
|
||||
if (thread_id < OUTPUTS_PER_WG) {
|
||||
let output_row = row_base + thread_id;
|
||||
if (output_row < params.m) {
|
||||
dst[dst_idx_base + col * params.m + thread_id] = partial_sums[partial_index(thread_id, 0)];
|
||||
}
|
||||
}
|
||||
#endif
|
||||
|
||||
workgroupBarrier();
|
||||
|
||||
var stride = WG_SIZE / 2u;
|
||||
|
||||
while (stride > 0) {
|
||||
if (thread_id < stride) {
|
||||
for (var row = 0u; row < OUTPUTS_PER_WG; row++) {
|
||||
partial_sums[partial_index(row, thread_id)] += partial_sums[partial_index(row, thread_id + stride)];
|
||||
}
|
||||
}
|
||||
|
||||
workgroupBarrier();
|
||||
stride = stride / 2;
|
||||
}
|
||||
|
||||
if (thread_id < OUTPUTS_PER_WG) {
|
||||
let output_row = row_base + thread_id;
|
||||
if (output_row < params.m) {
|
||||
dst[dst_idx_base + thread_id] = partial_sums[partial_index(thread_id, 0)];
|
||||
}
|
||||
}
|
||||
#endif
|
||||
}
|
||||
|
||||
File diff suppressed because it is too large
Load Diff
@@ -51,10 +51,7 @@ fn repack_b_dm(block: u32) -> B_DS_TYPE {
|
||||
fn get_dm(block_byte_base: u32) -> f32 {
|
||||
return f32(load_f16_at_src0(block_byte_base));
|
||||
}
|
||||
fn mul_q8_1(row_sum: i32, da: f32, b_ds: B_DS_TYPE) -> f32 {
|
||||
return f32(row_sum) * (da * b_ds.x) - 8.0 * da * b_ds.y / THREADS_PER_BLOCK;
|
||||
}
|
||||
#endif
|
||||
#endif // MUL_ACC_Q4_0
|
||||
|
||||
#ifdef MUL_ACC_Q4_1
|
||||
#define BLOCK_SIZE_BYTES 20
|
||||
@@ -85,10 +82,7 @@ fn get_dm(block_byte_base: u32) -> vec2<f32> {
|
||||
f32(load_f16_at_src0(block_byte_base + 2u))
|
||||
);
|
||||
}
|
||||
fn mul_q8_1(row_sum: i32, dma: vec2<f32>, b_ds: B_DS_TYPE) -> f32 {
|
||||
return f32(row_sum) * (dma.x * b_ds.x) + dma.y * b_ds.y / THREADS_PER_BLOCK;
|
||||
}
|
||||
#endif
|
||||
#endif // MUL_ACC_Q4_1
|
||||
|
||||
#ifdef MUL_ACC_Q8_0
|
||||
#define BLOCK_SIZE_BYTES 34
|
||||
@@ -111,46 +105,48 @@ fn repack_b_dm(block: u32) -> B_DS_TYPE {
|
||||
fn get_dm(block_byte_base: u32) -> f32 {
|
||||
return f32(load_f16_at_src0(block_byte_base));
|
||||
}
|
||||
fn mul_q8_1(row_sum: i32, da: f32, b_ds: B_DS_TYPE) -> f32 {
|
||||
return f32(row_sum) * (da * b_ds);
|
||||
}
|
||||
#endif
|
||||
#endif // MUL_ACC_Q8_0
|
||||
|
||||
#ifdef LEGACY_QUANTS
|
||||
fn mmvq_dot_product(a_byte_base: u32, b_inner_id: u32, b_repacked: vec2<u32>, b_ds: B_DS_TYPE) -> f32 {
|
||||
var row_sum = 0;
|
||||
let a_repacked = repack_a(a_byte_base, b_inner_id);
|
||||
|
||||
row_sum += dot4I8Packed(a_repacked[0], b_repacked[0]);
|
||||
row_sum += dot4I8Packed(a_repacked[1], b_repacked[1]);
|
||||
|
||||
return mul_q8_1(row_sum, get_dm(a_byte_base), b_ds);
|
||||
}
|
||||
|
||||
fn accumulate_vec_q_dot(thread_id: u32, row_base: u32, src0_batch_offset: u32, src1q_idx_base: u32) -> array<f32, OUTPUTS_PER_WG> {
|
||||
var acc: array<f32, OUTPUTS_PER_WG>;
|
||||
#if defined(LEGACY_QUANTS)
|
||||
fn accumulate_vec_q_dot(thread_id: u32, row_base: u32, src0_batch_offset: u32, src1q_idx_base: u32) -> array<array<f32, OUTPUTS_PER_WG>, NUM_COLS> {
|
||||
var acc: array<array<f32, OUTPUTS_PER_WG>, NUM_COLS>;
|
||||
|
||||
let num_blocks = params.k / BLOCK_SIZE;
|
||||
|
||||
for (var block = thread_id / THREADS_PER_BLOCK; block < num_blocks; block += WG_SIZE / THREADS_PER_BLOCK) {
|
||||
let b_inner_id = thread_id % THREADS_PER_BLOCK;
|
||||
let b_block_idx = src1q_idx_base + block;
|
||||
|
||||
let b_repacked = repack_b_qs(b_block_idx, b_inner_id);
|
||||
let b_ds = repack_b_dm(b_block_idx);
|
||||
|
||||
let inner_id = thread_id % THREADS_PER_BLOCK;
|
||||
for (var row = 0u; row < OUTPUTS_PER_WG; row++) {
|
||||
let output_row = row_base + row;
|
||||
if (output_row < params.m) {
|
||||
let block_byte_base = (src0_batch_offset + output_row * params.stride_01 + block) * BLOCK_SIZE_BYTES;
|
||||
acc[row] += mmvq_dot_product(block_byte_base, b_inner_id, b_repacked, b_ds);
|
||||
let a_repacked = repack_a(block_byte_base, inner_id);
|
||||
let da = get_dm(block_byte_base);
|
||||
for (var col = 0u;col < NUM_COLS;col += 1) {
|
||||
let src1q_idx = src1q_idx_base + col * (params.k / Q8_BLOCK_SIZE) + block;
|
||||
let b_repacked = repack_b_qs(src1q_idx, inner_id);
|
||||
let b_ds = repack_b_dm(src1q_idx);
|
||||
|
||||
let row_sum = dot4I8Packed(a_repacked[0], b_repacked[0]) + dot4I8Packed(a_repacked[1], b_repacked[1]);
|
||||
|
||||
#if defined(MUL_ACC_Q4_0)
|
||||
acc[col][row] += f32(row_sum) * (da * b_ds.x) - 8.0 * da * b_ds.y / THREADS_PER_BLOCK;
|
||||
#endif // MUL_ACC_Q4_0
|
||||
|
||||
#if defined(MUL_ACC_Q4_1)
|
||||
acc[col][row] += f32(row_sum) * (da.x * b_ds.x) + da.y * b_ds.y / THREADS_PER_BLOCK;
|
||||
#endif // MUL_ACC_Q4_1
|
||||
|
||||
#if defined(MUL_ACC_Q8_0)
|
||||
acc[col][row] += f32(row_sum) * (da * b_ds);
|
||||
#endif // MUL_ACC_Q8_0
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
return acc;
|
||||
}
|
||||
#endif
|
||||
#endif // LEGACY_QUANTS
|
||||
|
||||
#ifdef MUL_ACC_Q2_K
|
||||
#define BLOCK_SIZE_BYTES 84
|
||||
@@ -191,22 +187,7 @@ fn get_scale_min(block_byte_base: u32, tid: u32) -> vec2<f32> {
|
||||
let scale = byte_of(load_u32_at_src0_aligned(scale_byte), scale_byte & 3u);
|
||||
return vec2<f32>(f32(scale & 0xFu), f32(scale >> 4u));
|
||||
}
|
||||
fn mmvq_dot_product(a_byte_base: u32, tid: u32, b_repacked: vec4<u32>, b_ds: B_DS_TYPE) -> f32 {
|
||||
let a_repacked = repack_a(a_byte_base, tid);
|
||||
let dm = get_dm(a_byte_base);
|
||||
let scale_min = get_scale_min(a_byte_base, tid);
|
||||
|
||||
let scale_q = i32(scale_min.x);
|
||||
let scale_m_i8x4 = u32(scale_min.y) * 0x01010101u;
|
||||
|
||||
let row_sum_d = (dot4I8Packed(b_repacked[0], a_repacked[0]) + dot4I8Packed(b_repacked[1], a_repacked[1])
|
||||
+ dot4I8Packed(b_repacked[2], a_repacked[2]) + dot4I8Packed(b_repacked[3], a_repacked[3])) * scale_q;
|
||||
let row_sum_m = dot4I8Packed(b_repacked[0], scale_m_i8x4) + dot4I8Packed(b_repacked[1], scale_m_i8x4)
|
||||
+ dot4I8Packed(b_repacked[2], scale_m_i8x4) + dot4I8Packed(b_repacked[3], scale_m_i8x4);
|
||||
|
||||
return b_ds * (dm.x * f32(row_sum_d) - dm.y * f32(row_sum_m));
|
||||
}
|
||||
#endif
|
||||
#endif // MUL_ACC_Q2_K
|
||||
|
||||
#ifdef MUL_ACC_Q4_K
|
||||
#define BLOCK_SIZE_BYTES 144
|
||||
@@ -265,39 +246,52 @@ fn get_scale_min(block_byte_base: u32, tid: u32) -> vec2<f32> {
|
||||
|
||||
return vec2<f32>(scale, min_val);
|
||||
}
|
||||
fn mmvq_dot_product(a_byte_base: u32, tid: u32, b_repacked: vec4<u32>, b_ds: B_DS_TYPE) -> f32 {
|
||||
let a_repacked = repack_a(a_byte_base, tid);
|
||||
let dm = get_dm(a_byte_base);
|
||||
let scale_min = get_scale_min(a_byte_base, tid);
|
||||
|
||||
let row_sum = dot4I8Packed(a_repacked[0], b_repacked[0]) + dot4I8Packed(a_repacked[1], b_repacked[1])
|
||||
+ dot4I8Packed(a_repacked[2], b_repacked[2]) + dot4I8Packed(a_repacked[3], b_repacked[3]);
|
||||
|
||||
// Each thread covers half of the Q8_1 block, so add only b_ds.y/2.
|
||||
return b_ds.x * dm.x * scale_min.x * f32(row_sum) - dm.y * scale_min.y * (b_ds.y / (Q8_BLOCK_SIZE / ELEMS_PER_THREAD));
|
||||
}
|
||||
#endif
|
||||
#endif // MUL_ACC_Q4_K
|
||||
|
||||
#ifdef K_QUANTS
|
||||
fn accumulate_vec_q_dot(thread_id: u32, row_base: u32, src0_batch_offset: u32, src1q_idx_base: u32) -> array<f32, OUTPUTS_PER_WG> {
|
||||
var acc: array<f32, OUTPUTS_PER_WG>;
|
||||
fn accumulate_vec_q_dot(thread_id: u32, row_base: u32, src0_batch_offset: u32, src1q_idx_base: u32) -> array<array<f32, OUTPUTS_PER_WG>, NUM_COLS> {
|
||||
var acc: array<array<f32, OUTPUTS_PER_WG>, NUM_COLS>;
|
||||
|
||||
let tid = thread_id % THREADS_PER_BLOCK;
|
||||
|
||||
for (var block = thread_id / THREADS_PER_BLOCK; block < params.k / BLOCK_SIZE; block += WG_SIZE / THREADS_PER_BLOCK) {
|
||||
let src1q_idx = src1q_idx_base + (block * BLOCK_SIZE + ELEMS_PER_THREAD * tid) / Q8_BLOCK_SIZE;
|
||||
let b_repacked = repack_b_qs(src1q_idx, tid);
|
||||
let b_ds = repack_b_dm(src1q_idx);
|
||||
|
||||
for (var row = 0u; row < OUTPUTS_PER_WG; row++) {
|
||||
let output_row = row_base + row;
|
||||
if (output_row < params.m) {
|
||||
let block_byte_base = (src0_batch_offset + output_row * params.stride_01 + block) * BLOCK_SIZE_BYTES;
|
||||
acc[row] += mmvq_dot_product(block_byte_base, tid, b_repacked, b_ds);
|
||||
let a_repacked = repack_a(block_byte_base, tid);
|
||||
let dm = get_dm(block_byte_base);
|
||||
let scale_min = get_scale_min(block_byte_base, tid);
|
||||
for (var col = 0u;col < NUM_COLS;col += 1) {
|
||||
let src1q_idx = src1q_idx_base + col * (params.k / Q8_BLOCK_SIZE) + (block * BLOCK_SIZE + ELEMS_PER_THREAD * tid) / Q8_BLOCK_SIZE;
|
||||
let b_repacked = repack_b_qs(src1q_idx, tid);
|
||||
let b_ds = repack_b_dm(src1q_idx);
|
||||
|
||||
#if defined(MUL_ACC_Q2_K)
|
||||
let scale_q = i32(scale_min.x);
|
||||
let scale_m_i8x4 = u32(scale_min.y) * 0x01010101u;
|
||||
|
||||
let row_sum_d = (dot4I8Packed(b_repacked[0], a_repacked[0]) + dot4I8Packed(b_repacked[1], a_repacked[1])
|
||||
+ dot4I8Packed(b_repacked[2], a_repacked[2]) + dot4I8Packed(b_repacked[3], a_repacked[3])) * scale_q;
|
||||
let row_sum_m = dot4I8Packed(b_repacked[0], scale_m_i8x4) + dot4I8Packed(b_repacked[1], scale_m_i8x4)
|
||||
+ dot4I8Packed(b_repacked[2], scale_m_i8x4) + dot4I8Packed(b_repacked[3], scale_m_i8x4);
|
||||
|
||||
acc[col][row] += b_ds * (dm.x * f32(row_sum_d) - dm.y * f32(row_sum_m));
|
||||
#endif // MUL_ACC_Q2_K
|
||||
|
||||
#if defined(MUL_ACC_Q4_K)
|
||||
let row_sum = dot4I8Packed(a_repacked[0], b_repacked[0]) + dot4I8Packed(a_repacked[1], b_repacked[1])
|
||||
+ dot4I8Packed(a_repacked[2], b_repacked[2]) + dot4I8Packed(a_repacked[3], b_repacked[3]);
|
||||
|
||||
// Each thread covers half of the Q8_1 block, so add only b_ds.y/2.
|
||||
acc[col][row] += b_ds.x * dm.x * scale_min.x * f32(row_sum) - dm.y * scale_min.y * (b_ds.y / (Q8_BLOCK_SIZE / ELEMS_PER_THREAD));
|
||||
#endif // MUL_ACC_Q4_K
|
||||
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
return acc;
|
||||
}
|
||||
#endif
|
||||
#endif // K_QUANTS
|
||||
|
||||
@@ -9,9 +9,11 @@ requires packed_4x8_integer_dot_product;
|
||||
|
||||
struct Params {
|
||||
offset_src1: u32,
|
||||
stride_11: u32,
|
||||
stride_12: u32,
|
||||
stride_13: u32,
|
||||
ne0: u32,
|
||||
ne1: u32,
|
||||
ne2: u32,
|
||||
ne3: u32,
|
||||
};
|
||||
@@ -57,25 +59,28 @@ fn main(
|
||||
@builtin(num_workgroups) num_wg: vec3<u32>
|
||||
) {
|
||||
let thread_id = local_id.x;
|
||||
let num_vec4 = params.ne0 / 4u;
|
||||
let ne0_vec4 = params.ne0 / 4u;
|
||||
|
||||
let wg_per_vec = (num_vec4 + (WG_SIZE - 1u)) / WG_SIZE;
|
||||
let total_batches = wg_per_vec * params.ne2 * params.ne3;
|
||||
let wg_per_vec = (ne0_vec4 + (WG_SIZE - 1u)) / WG_SIZE;
|
||||
let total_batches = wg_per_vec * params.ne1 * params.ne2 * params.ne3;
|
||||
|
||||
let wg_linear = wg_id.y * num_wg.x + wg_id.x;
|
||||
if (wg_linear >= total_batches) {
|
||||
return;
|
||||
}
|
||||
|
||||
let src13_idx = wg_linear / (params.ne2 * wg_per_vec);
|
||||
let src12_idx = (wg_linear - src13_idx * (params.ne2 * wg_per_vec)) / wg_per_vec;
|
||||
let src11_wg_idx = wg_linear % wg_per_vec;
|
||||
let src1_idx_base = params.offset_src1 + src13_idx * params.stride_13 + src12_idx * params.stride_12;
|
||||
let vec_idx = wg_linear / wg_per_vec;
|
||||
let src13_idx = vec_idx / (params.ne2 * params.ne1);
|
||||
let vec_ne12_num = vec_idx % (params.ne2 * params.ne1);
|
||||
let src12_idx = vec_ne12_num / params.ne1;
|
||||
let src11_idx = vec_ne12_num % params.ne1;
|
||||
let src1_idx_base = params.offset_src1 + src13_idx * params.stride_13 + src12_idx * params.stride_12 + src11_idx * params.stride_11;
|
||||
let src1_idx_vec4_base = src1_idx_base / 4u;
|
||||
|
||||
let blocks_per_row = params.ne0 / 32u;
|
||||
let blocks_per_wg = (WG_SIZE * 4u) / 32u;
|
||||
let src1q_idx_base = (src13_idx * params.ne2 + src12_idx) * blocks_per_row;
|
||||
let src1q_idx_base = ((src13_idx * params.ne2 + src12_idx) * params.ne1 + src11_idx) * blocks_per_row;
|
||||
let src11_wg_idx = wg_linear % wg_per_vec;
|
||||
let src1q_idx = src1q_idx_base + src11_wg_idx * blocks_per_wg + thread_id / 8u;
|
||||
let qs_idx = thread_id % 8u;
|
||||
|
||||
@@ -85,7 +90,7 @@ fn main(
|
||||
var thread_amax = 0.0;
|
||||
|
||||
let src11_vec4_idx = src11_wg_idx * WG_SIZE + thread_id;
|
||||
let is_valid = src11_vec4_idx < num_vec4;
|
||||
let is_valid = src11_vec4_idx < ne0_vec4;
|
||||
|
||||
#ifdef USE_SUBGROUP_REDUCTION
|
||||
|
||||
|
||||
@@ -359,6 +359,7 @@ class Keys:
|
||||
CHUNK_SIZE = "clip.audio.chunk_size"
|
||||
CONV_KERNEL_SIZE = "clip.audio.conv_kernel_size"
|
||||
MAX_POS_EMB = "clip.audio.max_pos_emb"
|
||||
FEATURE_LAYERS = "clip.audio.feature_layer" # Granite Speech Plus
|
||||
|
||||
class Attention:
|
||||
HEAD_COUNT = "clip.audio.attention.head_count"
|
||||
|
||||
@@ -1310,6 +1310,9 @@ class GGUFWriter:
|
||||
def add_audio_max_pos_emb(self, value: int) -> None:
|
||||
self.add_uint32(Keys.ClipAudio.MAX_POS_EMB, value)
|
||||
|
||||
def add_audio_feature_layers(self, layers: Sequence[int]) -> None:
|
||||
self.add_array(Keys.ClipAudio.FEATURE_LAYERS, layers)
|
||||
|
||||
def add_audio_projector_window_size(self, value: int) -> None:
|
||||
self.add_uint32(Keys.ClipAudio.Projector.WINDOW_SIZE, value)
|
||||
|
||||
|
||||
@@ -8433,6 +8433,7 @@ static std::vector<std::unique_ptr<test_case>> make_test_cases_eval() {
|
||||
test_cases.emplace_back(new test_mul_mat(type_a, type_b, 16, 1, k, {3, 2}, {2, 1}));
|
||||
test_cases.emplace_back(new test_mul_mat(type_a, type_b, 16, 1, k, {3, 2}, {1, 2}));
|
||||
test_cases.emplace_back(new test_mul_mat(type_a, type_b, 16, 1, k, {3, 2}, {2, 2}));
|
||||
test_cases.emplace_back(new test_mul_mat(type_a, type_b, 16, 4, k, {3, 2}, {2, 2}));
|
||||
|
||||
test_cases.emplace_back(new test_mul_mat(type_a, type_b, 16, 16, k, {1, 1}, {1, 1}));
|
||||
test_cases.emplace_back(new test_mul_mat(type_a, type_b, 16, 16, k, {1, 1}, {2, 1}));
|
||||
@@ -8449,6 +8450,7 @@ static std::vector<std::unique_ptr<test_case>> make_test_cases_eval() {
|
||||
test_cases.emplace_back(new test_mul_mat(type_a, type_b, 16, 1, k, {2, 3}, {1, 1}, {0, 1, 3, 2}));
|
||||
test_cases.emplace_back(new test_mul_mat(type_a, type_b, 16, 1, k, {2, 3}, {1, 1}, {0, 3, 2, 1}));
|
||||
|
||||
test_cases.emplace_back(new test_mul_mat(type_a, type_b, 16, 4, k, {2, 3}, {1, 1}, {0, 3, 2, 1}));
|
||||
test_cases.emplace_back(new test_mul_mat(type_a, type_b, 16, 8, k, {2, 3}, {1, 1}, {0, 2, 1, 3}));
|
||||
test_cases.emplace_back(new test_mul_mat(type_a, type_b, 16, 8, k, {2, 3}, {1, 1}, {0, 1, 3, 2}));
|
||||
test_cases.emplace_back(new test_mul_mat(type_a, type_b, 16, 8, k, {2, 3}, {1, 1}, {0, 3, 2, 1}));
|
||||
|
||||
+99
-24
@@ -1562,37 +1562,112 @@ static void test_msgs_oaicompat_json_conversion() {
|
||||
}
|
||||
}
|
||||
|
||||
static void test_split_by_role() {
|
||||
static void test_msg_token_delimiters_split() {
|
||||
LOG_DBG("%s\n", __func__);
|
||||
|
||||
// Delimiters that share a leading token, distinguished by the second token,
|
||||
// to exercise the per-position token matching.
|
||||
const common_chat_msg_delimiters delims = {
|
||||
{ { COMMON_CHAT_ROLE_USER, "", { 10, 11 } },
|
||||
{ COMMON_CHAT_ROLE_ASSISTANT, "", { 10, 12 } } }
|
||||
};
|
||||
|
||||
// Empty inputs
|
||||
assert_equals<size_t>(0, common_chat_split_by_role("", {}).size());
|
||||
assert_equals<size_t>(0, common_chat_split_by_role("hello", {}).size());
|
||||
assert_equals<size_t>(0, common_chat_split_by_role("", { { "user", "<|user|>" } }).size());
|
||||
assert_equals<size_t>(0, common_chat_msg_delimiters{}.split({}).spans.size());
|
||||
assert_equals<size_t>(0, common_chat_msg_delimiters{}.split({ 10, 11 }).spans.size());
|
||||
assert_equals<size_t>(0, delims.split({}).spans.size());
|
||||
|
||||
// Multi-role conversation, no leading/trailing content
|
||||
// No delimiters match -> no spans
|
||||
assert_equals<size_t>(0, delims.split({ 100, 101, 102 }).spans.size());
|
||||
|
||||
// Multi-role conversation: <user>Hi<assistant>Hello<user>Bye
|
||||
{
|
||||
const std::string prompt = "<|user|>Hi<|assistant|>Hello<|user|>Bye";
|
||||
const auto splits = common_chat_split_by_role(prompt, {
|
||||
{ "user", "<|user|>" },
|
||||
{ "assistant", "<|assistant|>" },
|
||||
});
|
||||
assert_equals<size_t>(3, splits.size());
|
||||
const llama_tokens tokens = {
|
||||
10, 11, // <user>
|
||||
100, 101, // Hi
|
||||
10, 12, // <assistant>
|
||||
200, 201, 202, // Hello
|
||||
10, 11, // <user>
|
||||
300, 301, // Bye
|
||||
};
|
||||
|
||||
assert_equals<std::string>("user", splits[0].role);
|
||||
assert_equals<size_t>(0, splits[0].pos);
|
||||
assert_equals<size_t>(10, splits[0].len);
|
||||
assert_equals<std::string>("<|user|>Hi", prompt.substr(splits[0].pos, splits[0].len));
|
||||
const auto result = delims.split(tokens);
|
||||
const auto & spans = result.spans;
|
||||
assert_equals<size_t>(3, spans.size());
|
||||
|
||||
assert_equals<std::string>("assistant", splits[1].role);
|
||||
assert_equals<size_t>(10, splits[1].pos);
|
||||
assert_equals<size_t>(18, splits[1].len);
|
||||
assert_equals<std::string>("<|assistant|>Hello", prompt.substr(splits[1].pos, splits[1].len));
|
||||
assert_equals(COMMON_CHAT_ROLE_USER, spans[0].role);
|
||||
assert_equals<size_t>(0, spans[0].pos);
|
||||
assert_equals<size_t>(4, spans[0].len);
|
||||
|
||||
assert_equals<std::string>("user", splits[2].role);
|
||||
assert_equals<size_t>(28, splits[2].pos);
|
||||
assert_equals<size_t>(11, splits[2].len);
|
||||
assert_equals<std::string>("<|user|>Bye", prompt.substr(splits[2].pos, splits[2].len));
|
||||
assert_equals(COMMON_CHAT_ROLE_ASSISTANT, spans[1].role);
|
||||
assert_equals<size_t>(4, spans[1].pos);
|
||||
assert_equals<size_t>(5, spans[1].len);
|
||||
|
||||
assert_equals(COMMON_CHAT_ROLE_USER, spans[2].role);
|
||||
assert_equals<size_t>(9, spans[2].pos);
|
||||
assert_equals<size_t>(4, spans[2].len);
|
||||
|
||||
// is_user_start() is true at the token position where a user span begins
|
||||
assert_equals(true, result.is_user_start(0));
|
||||
assert_equals(false, result.is_user_start(4)); // assistant span
|
||||
assert_equals(true, result.is_user_start(9));
|
||||
}
|
||||
|
||||
// Content before the first delimiter is not captured as a span
|
||||
{
|
||||
const llama_tokens tokens = {
|
||||
500, 501, // leading content (dropped)
|
||||
10, 11, // <user>
|
||||
100, // Hi
|
||||
};
|
||||
|
||||
const auto spans = delims.split(tokens).spans;
|
||||
assert_equals<size_t>(1, spans.size());
|
||||
assert_equals(COMMON_CHAT_ROLE_USER, spans[0].role);
|
||||
assert_equals<size_t>(2, spans[0].pos);
|
||||
assert_equals<size_t>(3, spans[0].len);
|
||||
}
|
||||
|
||||
// Skipped regions (media chunks) are jumped over but still count as span content
|
||||
{
|
||||
const llama_tokens tokens = {
|
||||
10, 11, // <user>
|
||||
LLAMA_TOKEN_NULL, // media chunk (3 tokens)
|
||||
LLAMA_TOKEN_NULL,
|
||||
LLAMA_TOKEN_NULL,
|
||||
100, // Hi
|
||||
10, 12, // <assistant>
|
||||
};
|
||||
|
||||
const std::map<size_t, size_t> skips = { { 2, 3 } };
|
||||
|
||||
const auto spans = delims.split(tokens, skips).spans;
|
||||
assert_equals<size_t>(2, spans.size());
|
||||
|
||||
assert_equals(COMMON_CHAT_ROLE_USER, spans[0].role);
|
||||
assert_equals<size_t>(0, spans[0].pos);
|
||||
assert_equals<size_t>(6, spans[0].len);
|
||||
|
||||
assert_equals(COMMON_CHAT_ROLE_ASSISTANT, spans[1].role);
|
||||
assert_equals<size_t>(6, spans[1].pos);
|
||||
assert_equals<size_t>(2, spans[1].len);
|
||||
}
|
||||
|
||||
// A delimiter sequence inside a skipped region is not matched
|
||||
{
|
||||
const llama_tokens tokens = {
|
||||
10, 11, // <user>
|
||||
10, 12, // skipped region that happens to contain delimiter tokens
|
||||
100, // Hi
|
||||
};
|
||||
|
||||
const std::map<size_t, size_t> skips = { { 2, 2 } };
|
||||
|
||||
const auto spans = delims.split(tokens, skips).spans;
|
||||
assert_equals<size_t>(1, spans.size());
|
||||
assert_equals(COMMON_CHAT_ROLE_USER, spans[0].role);
|
||||
assert_equals<size_t>(0, spans[0].pos);
|
||||
assert_equals<size_t>(5, spans[0].len);
|
||||
}
|
||||
}
|
||||
|
||||
@@ -5857,7 +5932,7 @@ int main(int argc, char ** argv) {
|
||||
{
|
||||
test_msg_diffs_compute();
|
||||
test_msgs_oaicompat_json_conversion();
|
||||
test_split_by_role();
|
||||
test_msg_token_delimiters_split();
|
||||
test_tools_oaicompat_json_conversion();
|
||||
test_convert_responses_to_chatcmpl();
|
||||
test_developer_role_to_system_workaround();
|
||||
|
||||
@@ -42,6 +42,7 @@
|
||||
#define KEY_N_HEAD "clip.%s.attention.head_count"
|
||||
#define KEY_N_HEAD_KV "clip.%s.attention.head_count_kv"
|
||||
#define KEY_LAYER_NORM_EPS "clip.%s.attention.layer_norm_epsilon"
|
||||
#define KEY_FEATURE_LAYERS "clip.%s.feature_layer"
|
||||
|
||||
// vision-specific
|
||||
#define KEY_VISION_PROJ_TYPE "clip.vision.projector_type" // for models with mixed modalities
|
||||
@@ -54,7 +55,6 @@
|
||||
#define KEY_PATCH_SIZE "clip.vision.patch_size"
|
||||
#define KEY_IMAGE_MEAN "clip.vision.image_mean"
|
||||
#define KEY_IMAGE_STD "clip.vision.image_std"
|
||||
#define KEY_FEATURE_LAYER "clip.vision.feature_layer"
|
||||
#define KEY_PROJ_SCALE_FACTOR "clip.vision.projector.scale_factor"
|
||||
#define KEY_PROJ_SAMPLE_QUERY_SIDE "clip.vision.projector.query_side"
|
||||
#define KEY_PROJ_SAMPLE_WINDOW_SIDE "clip.vision.projector.window_side"
|
||||
|
||||
@@ -91,7 +91,7 @@ struct clip_hparams {
|
||||
|
||||
float eps = 1e-6;
|
||||
float rope_theta = 0.0;
|
||||
std::vector<int32_t> vision_feature_layer;
|
||||
std::vector<int32_t> feature_layers;
|
||||
int32_t attn_window_size = 0;
|
||||
int32_t n_wa_pattern = 0;
|
||||
std::unordered_set<int32_t> wa_layer_indexes; // explicit layer indexes that use full attention (for irregular patterns like YoutuVL)
|
||||
@@ -165,8 +165,8 @@ struct clip_hparams {
|
||||
return false;
|
||||
}
|
||||
|
||||
bool is_vision_feature_layer(int32_t layer) const {
|
||||
return std::find(vision_feature_layer.begin(), vision_feature_layer.end(), layer) != vision_feature_layer.end();
|
||||
bool is_feature_layer(int32_t layer) const {
|
||||
return std::find(feature_layers.begin(), feature_layers.end(), layer) != feature_layers.end();
|
||||
}
|
||||
};
|
||||
|
||||
|
||||
+9
-10
@@ -1264,12 +1264,10 @@ struct clip_model_loader {
|
||||
}
|
||||
}
|
||||
|
||||
// Load the vision feature layer indices if they are explicitly provided;
|
||||
// if multiple vision feature layers are present, the values will be concatenated
|
||||
// to form the final visual features.
|
||||
// Load the vision/audio feature layer indices if they are explicitly provided
|
||||
// NOTE: gguf conversions should standardize the values of the vision feature layer to
|
||||
// be non-negative, since we use -1 to mark values as unset here.
|
||||
get_arr_int(KEY_FEATURE_LAYER, hparams.vision_feature_layer, false);
|
||||
get_arr_int(string_format(KEY_FEATURE_LAYERS, prefix), hparams.feature_layers, false);
|
||||
|
||||
// model-specific params
|
||||
switch (model.proj_type) {
|
||||
@@ -1651,6 +1649,7 @@ struct clip_model_loader {
|
||||
get_u32(KEY_A_PROJ_WINDOW_SIZE, hparams.audio_proj_window_size);
|
||||
get_u32(KEY_A_PROJ_DOWNSAMPLE_RATE, hparams.audio_proj_downsample_rate);
|
||||
get_u32(KEY_A_PROJ_HEAD_COUNT, hparams.audio_proj_head_count);
|
||||
// NOTE: feature layers loaded above in common path
|
||||
} break;
|
||||
case PROJECTOR_TYPE_JANUS_PRO:
|
||||
{
|
||||
@@ -1663,11 +1662,11 @@ struct clip_model_loader {
|
||||
hparams.image_resize_algo = RESIZE_ALGO_BICUBIC_PILLOW;
|
||||
hparams.image_resize_pad = PAD_CEIL;
|
||||
|
||||
get_arr_int(KEY_FEATURE_LAYER, hparams.vision_feature_layer);
|
||||
// NOTE: feature_layers loaded in common path as optional
|
||||
get_arr_int(KEY_PROJ_SPATIAL_OFFSETS, hparams.proj_spatial_offsets);
|
||||
if (hparams.vision_feature_layer.size() != hparams.proj_spatial_offsets.size()) {
|
||||
throw std::runtime_error(string_format("%s: vision_feature_layer.size() %d != proj_spatial_offsets.size() %d",
|
||||
hparams.vision_feature_layer.size(), hparams.proj_spatial_offsets.size()));
|
||||
if (hparams.feature_layers.size() != hparams.proj_spatial_offsets.size()) {
|
||||
throw std::runtime_error(string_format("%s: feature_layers.size() %d != proj_spatial_offsets.size() %d",
|
||||
hparams.feature_layers.size(), hparams.proj_spatial_offsets.size()));
|
||||
}
|
||||
|
||||
get_u32(KEY_PROJ_SAMPLE_QUERY_SIDE, hparams.downsample_query_side);
|
||||
@@ -2740,7 +2739,7 @@ struct clip_model_loader {
|
||||
model.image_newline = get_tensor(TN_IMAGE_NEWLINE);
|
||||
|
||||
// Load separate layerwise and spatial projector tensors
|
||||
const auto projector_count = hparams.vision_feature_layer.size();
|
||||
const auto projector_count = hparams.feature_layers.size();
|
||||
model.qf_proj_blocks.resize(projector_count);
|
||||
for (size_t bid = 0; bid < projector_count; ++bid) {
|
||||
auto & b = model.qf_proj_blocks[bid];
|
||||
@@ -4388,7 +4387,7 @@ bool clip_image_batch_encode(clip_ctx * ctx, int n_threads, const clip_image_f32
|
||||
|
||||
// Stage 1b only uses block 0's permutations; future stages
|
||||
// will upload all blocks.
|
||||
for (size_t bid = 0; bid < hparams.vision_feature_layer.size(); ++bid) {
|
||||
for (size_t bid = 0; bid < hparams.feature_layers.size(); ++bid) {
|
||||
const std::string prefix = "g4v_blk" + std::to_string(bid) + "_";
|
||||
upload(prefix + "win_idx", make_win_idx(image_side, window_side));
|
||||
upload(prefix + "qwin_idx", make_win_idx(new_side, query_side));
|
||||
|
||||
@@ -1,5 +1,7 @@
|
||||
#include "models.h"
|
||||
|
||||
#include <algorithm>
|
||||
|
||||
ggml_cgraph * clip_graph_granite_speech::build() {
|
||||
const int n_frames = img.nx();
|
||||
const int context_size = hparams.audio_chunk_size;
|
||||
@@ -11,6 +13,10 @@ ggml_cgraph * clip_graph_granite_speech::build() {
|
||||
const int padded_len = num_blocks * context_size;
|
||||
const int remainder = n_frames % context_size;
|
||||
|
||||
// Calculate projector input dimension based on feature layers
|
||||
const int proj_input_dim = n_embd * (hparams.feature_layers.size() + 1);
|
||||
const bool use_feature_concat = !hparams.feature_layers.empty();
|
||||
|
||||
ggml_tensor * attn_dists = ggml_new_tensor_1d(ctx0, GGML_TYPE_I32, context_size * context_size);
|
||||
ggml_set_name(attn_dists, "attn_dists");
|
||||
ggml_set_input(attn_dists);
|
||||
@@ -31,6 +37,15 @@ ggml_cgraph * clip_graph_granite_speech::build() {
|
||||
cur = ggml_add(ctx0, cur, model.inp_proj_b);
|
||||
cb(cur, "inp_linear", -1);
|
||||
|
||||
// Capture layer 0 if requested (after input_linear)
|
||||
ggml_tensor * concat_result = nullptr;
|
||||
if (use_feature_concat) {
|
||||
if (std::find(hparams.feature_layers.begin(), hparams.feature_layers.end(), 0) != hparams.feature_layers.end()) {
|
||||
concat_result = cur;
|
||||
cb(concat_result, "feature_layer_0", -1);
|
||||
}
|
||||
}
|
||||
|
||||
for (int il = 0; il < n_layer; il++) {
|
||||
const auto & layer = model.layers[il];
|
||||
auto * residual = cur;
|
||||
@@ -168,6 +183,18 @@ ggml_cgraph * clip_graph_granite_speech::build() {
|
||||
NORM_TYPE_NORMAL, eps, il);
|
||||
cb(cur, "layer_out", il);
|
||||
|
||||
// Capture intermediate layer (il + 1) if requested
|
||||
if (use_feature_concat) {
|
||||
if (hparams.is_feature_layer(il + 1)) {
|
||||
if (concat_result == nullptr) {
|
||||
concat_result = cur;
|
||||
} else {
|
||||
concat_result = ggml_concat(ctx0, concat_result, cur, 0);
|
||||
}
|
||||
cb(concat_result, string_format("feature_layer_%d", il + 1).c_str(), il);
|
||||
}
|
||||
}
|
||||
|
||||
// CTC branch
|
||||
if (il + 1 == ctc_layer) {
|
||||
auto * mid = build_mm(model.ctc_out_w, cur);
|
||||
@@ -180,6 +207,13 @@ ggml_cgraph * clip_graph_granite_speech::build() {
|
||||
}
|
||||
}
|
||||
|
||||
// Append final output to concatenated features if using feature concatenation
|
||||
if (use_feature_concat && concat_result != nullptr) {
|
||||
concat_result = ggml_concat(ctx0, concat_result, cur, 0);
|
||||
cb(concat_result, "concat_final", -1);
|
||||
cur = concat_result;
|
||||
}
|
||||
|
||||
cb(cur, "encoder_out", -1);
|
||||
|
||||
// QFormer projector
|
||||
@@ -197,7 +231,7 @@ ggml_cgraph * clip_graph_granite_speech::build() {
|
||||
cur = ggml_pad(ctx0, cur, 0, padded_proj - n_frames, 0, 0);
|
||||
}
|
||||
|
||||
ggml_tensor * enc_windows = ggml_reshape_3d(ctx0, cur, n_embd, window_size, nblocks_proj);
|
||||
ggml_tensor * enc_windows = ggml_reshape_3d(ctx0, cur, proj_input_dim, window_size, nblocks_proj);
|
||||
|
||||
ggml_tensor * queries = build_norm(model.qf_proj_blocks[0].qf_proj_query,
|
||||
model.qf_proj_blocks[0].qf_proj_norm_w, model.qf_proj_blocks[0].qf_proj_norm_b,
|
||||
|
||||
@@ -304,14 +304,14 @@ ggml_cgraph * clip_graph_granite4_vision::build() {
|
||||
}
|
||||
|
||||
// --- Stage 1b/1c: WindowQFormer blocks ---
|
||||
const int projector_count = hparams.vision_feature_layer.size();
|
||||
const int projector_count = hparams.feature_layers.size();
|
||||
const float qformer_eps = 1e-12f;
|
||||
|
||||
ggml_tensor * mmproj = nullptr;
|
||||
for (int bid = 0; bid < projector_count; ++bid) {
|
||||
const auto & blk = model.qf_proj_blocks[bid];
|
||||
|
||||
int vlayer = hparams.vision_feature_layer[bid];
|
||||
int vlayer = hparams.feature_layers[bid];
|
||||
GGML_ASSERT(vlayer >= 0 && vlayer < n_layer);
|
||||
ggml_tensor * h = layer_outs[vlayer];
|
||||
|
||||
|
||||
@@ -21,7 +21,7 @@ ggml_cgraph * clip_graph_llava::build() {
|
||||
|
||||
// If we set explicit vision feature layers, only go up to the deepest one
|
||||
// NOTE: only used by granite-vision models for now
|
||||
for (const auto & feature_layer : hparams.vision_feature_layer) {
|
||||
for (const auto & feature_layer : hparams.feature_layers) {
|
||||
if (feature_layer > deepest_feature_layer) {
|
||||
deepest_feature_layer = feature_layer;
|
||||
}
|
||||
@@ -59,7 +59,7 @@ ggml_cgraph * clip_graph_llava::build() {
|
||||
|
||||
// If this is an embedding feature layer, save the output.
|
||||
// NOTE: 0 index here refers to the input to the encoder.
|
||||
if (hparams.is_vision_feature_layer(il)) {
|
||||
if (hparams.is_feature_layer(il)) {
|
||||
embedding_stack.push_back(cur);
|
||||
}
|
||||
|
||||
@@ -134,7 +134,7 @@ ggml_cgraph * clip_graph_llava::build() {
|
||||
// process vision feature layers (used by granite)
|
||||
{
|
||||
// final layer is a vision feature layer
|
||||
if (hparams.is_vision_feature_layer(max_feature_layer)) {
|
||||
if (hparams.is_feature_layer(max_feature_layer)) {
|
||||
embedding_stack.push_back(inpL);
|
||||
}
|
||||
|
||||
|
||||
@@ -518,6 +518,14 @@ size_t server_tokens::get_common_prefix(const server_tokens & b) const {
|
||||
return max_idx; // all tokens are equal
|
||||
}
|
||||
|
||||
common_chat_msg_spans server_tokens::find_message_spans(const common_chat_msg_delimiters & delims) const {
|
||||
std::map<size_t, size_t> skips;
|
||||
for (const auto & it : map_idx_to_media) {
|
||||
skips[it.first] = mtmd_input_chunk_get_n_tokens(it.second.get());
|
||||
}
|
||||
return delims.split(tokens, skips);
|
||||
}
|
||||
|
||||
bool server_tokens::validate(const struct llama_context * ctx) const {
|
||||
const llama_model * model = llama_get_model(ctx);
|
||||
const llama_vocab * vocab = llama_model_get_vocab(model);
|
||||
@@ -1104,15 +1112,7 @@ json oaicompat_chat_params_parse(
|
||||
llama_params["chat_parser"] = chat_params.parser;
|
||||
}
|
||||
|
||||
llama_params["message_spans"] = json::array();
|
||||
|
||||
for (const auto & span : chat_params.message_spans) {
|
||||
llama_params["message_spans"].push_back({
|
||||
{ "role", span.role },
|
||||
{ "pos", span.pos },
|
||||
{ "len", span.len },
|
||||
});
|
||||
}
|
||||
llama_params["message_delimiters"] = chat_params.message_delimiters.to_json();
|
||||
|
||||
// Reasoning budget: pass parameters through to sampling layer
|
||||
{
|
||||
|
||||
@@ -218,6 +218,9 @@ public:
|
||||
|
||||
size_t get_common_prefix(const server_tokens & b) const;
|
||||
|
||||
// split the tokens into message spans, skipping over media chunks
|
||||
common_chat_msg_spans find_message_spans(const common_chat_msg_delimiters & delims) const;
|
||||
|
||||
// make sure all text tokens are within the vocab range
|
||||
bool validate(const struct llama_context * ctx) const;
|
||||
|
||||
|
||||
@@ -3436,8 +3436,8 @@ private:
|
||||
has_mtmd = true;
|
||||
}
|
||||
|
||||
const int32_t n_before_user = slot.task->params.n_before_user;
|
||||
const bool n_before_user_known = n_before_user > 0;
|
||||
const auto & spans = slot.task->params.message_spans;
|
||||
const auto last_user_pos = spans.last_user_message_pos();
|
||||
|
||||
// add prompt tokens for processing in the current batch
|
||||
while (slot.prompt.n_tokens() < slot.task->n_tokens() && batch.size() < n_batch) {
|
||||
@@ -3466,10 +3466,8 @@ private:
|
||||
|
||||
slot.n_prompt_tokens_processed++;
|
||||
|
||||
// stop the prompt batch exactly before the latest user input, so a checkpoint
|
||||
// can be created after the previous messages
|
||||
if (n_before_user_known &&
|
||||
slot.prompt.n_tokens() == n_before_user) {
|
||||
// stop the prompt batch exactly before a user message
|
||||
if (spans.is_user_start(slot.prompt.n_tokens())) {
|
||||
break;
|
||||
}
|
||||
|
||||
@@ -3498,8 +3496,13 @@ private:
|
||||
// the number of tokens added to the batch for the current slot
|
||||
const auto n_tokens_cur = batch.size() - n_tokens_prev;
|
||||
|
||||
const auto n_tokens_start = slot.prompt.n_tokens() - n_tokens_cur;
|
||||
|
||||
const bool near_prompt_end = slot.task->n_tokens() < slot.prompt.n_tokens() + n_ubatch;
|
||||
|
||||
const bool is_user_start = spans.is_user_start(n_tokens_start);
|
||||
const bool is_last_user_message = n_tokens_start == last_user_pos;
|
||||
|
||||
// entire prompt has been processed
|
||||
if (slot.prompt.n_tokens() == slot.task->n_tokens()) {
|
||||
slot.state = SLOT_STATE_DONE_PROMPT;
|
||||
@@ -3514,8 +3517,9 @@ private:
|
||||
|
||||
slot.init_sampler();
|
||||
} else {
|
||||
// skip ordinary mid-prompt checkpoints
|
||||
if (!n_before_user_known && !near_prompt_end) {
|
||||
// skip ordinary mid-prompt checkpoints, unless the batch starts a user
|
||||
// message or we are near the end of the prompt
|
||||
if (!is_user_start && !near_prompt_end) {
|
||||
do_checkpoint = false;
|
||||
}
|
||||
}
|
||||
@@ -3523,29 +3527,6 @@ private:
|
||||
const auto pos_min = llama_memory_seq_pos_min(llama_get_memory(ctx_tgt), slot.id);
|
||||
const auto pos_max = llama_memory_seq_pos_max(llama_get_memory(ctx_tgt), slot.id);
|
||||
|
||||
// checkpoints are created before the current batch is decoded, so
|
||||
// their token position is the batch start rather than the prompt end
|
||||
const int32_t n_tokens_start = slot.prompt.n_tokens() - n_tokens_cur;
|
||||
|
||||
{
|
||||
const bool is_on_user =
|
||||
n_before_user_known &&
|
||||
n_tokens_start == n_before_user;
|
||||
|
||||
const bool is_after_user =
|
||||
n_before_user_known &&
|
||||
n_tokens_start > n_before_user;
|
||||
|
||||
const bool is_allowed =
|
||||
!n_before_user_known ||
|
||||
is_on_user ||
|
||||
(is_after_user && near_prompt_end);
|
||||
|
||||
if (do_checkpoint && !is_allowed) {
|
||||
do_checkpoint = false;
|
||||
}
|
||||
}
|
||||
|
||||
// nothing to checkpoint yet
|
||||
// TODO: is this check needed?
|
||||
if (do_checkpoint && pos_min < 0) {
|
||||
@@ -3555,8 +3536,8 @@ private:
|
||||
// do not checkpoint after mtmd chunks
|
||||
do_checkpoint = do_checkpoint && !has_mtmd;
|
||||
|
||||
// no need to create checkpoints that are too close together
|
||||
do_checkpoint = do_checkpoint && (slot.prompt.checkpoints.empty() || n_tokens_start > slot.prompt.checkpoints.back().n_tokens + params_base.checkpoint_min_step);
|
||||
// no need to create checkpoints that are too close together, unless it's the last user message
|
||||
do_checkpoint = do_checkpoint && (slot.prompt.checkpoints.empty() || is_last_user_message || n_tokens_start > slot.prompt.checkpoints.back().n_tokens + params_base.checkpoint_min_step);
|
||||
SLT_DBG(slot, "main/do_checkpoint = %s, pos_min = %d, pos_max = %d\n", do_checkpoint ? "yes" : "no", pos_min, pos_max);
|
||||
|
||||
// note: we create the checkpoint before calling llama_decode(), so the current batch is not
|
||||
@@ -4055,54 +4036,6 @@ void server_context::set_state_callback(server_state_callback_t callback) {
|
||||
});
|
||||
}
|
||||
|
||||
// compute the number of tokens before the last user message in the prompt
|
||||
static int32_t prompt_get_n_before_user(
|
||||
const json & message_spans,
|
||||
const std::string & prompt,
|
||||
const std::vector<raw_buffer> & files,
|
||||
const llama_vocab * vocab,
|
||||
mtmd_context * mctx) {
|
||||
int32_t result = -1;
|
||||
int32_t byte_pos = -1;
|
||||
|
||||
for (const auto & span : message_spans) {
|
||||
const std::string role = json_value(span, "role", std::string());
|
||||
|
||||
if (role == "user") {
|
||||
byte_pos = json_value(span, "pos", -1);
|
||||
}
|
||||
}
|
||||
|
||||
if (byte_pos >= 0) {
|
||||
GGML_ASSERT((size_t) byte_pos <= prompt.size());
|
||||
|
||||
const std::string prefix = prompt.substr(0, (size_t) byte_pos);
|
||||
|
||||
const std::string marker = get_media_marker();
|
||||
size_t n_prefix_media = 0;
|
||||
for (size_t pos = 0; (pos = prefix.find(marker, pos)) != std::string::npos; pos += marker.size()) {
|
||||
n_prefix_media++;
|
||||
}
|
||||
|
||||
GGML_ASSERT(n_prefix_media <= files.size());
|
||||
|
||||
if (mctx != nullptr && n_prefix_media > 0) {
|
||||
// TODO: this makes a copy - avoid it
|
||||
std::vector<raw_buffer> prefix_files(files.begin(), files.begin() + n_prefix_media);
|
||||
|
||||
result = (int32_t) process_mtmd_prompt(mctx, prefix, prefix_files).size();
|
||||
} else {
|
||||
result = (int32_t) tokenize_input_prompts(vocab, nullptr, prefix, true, true)[0].size();
|
||||
}
|
||||
|
||||
SRV_TRC("message_spans: last user message: byte_pos=%d, media=%zu, n_before_user=%d\n",
|
||||
byte_pos, n_prefix_media, result);
|
||||
}
|
||||
|
||||
return result;
|
||||
}
|
||||
|
||||
|
||||
//
|
||||
// server_routes
|
||||
//
|
||||
@@ -4150,6 +4083,10 @@ std::unique_ptr<server_res_generator> server_routes::handle_completions_impl(
|
||||
|
||||
// tasks.reserve(inputs.size()); // TODO: this is inaccurate due to child tasks
|
||||
|
||||
// message delimiters for checkpointing
|
||||
auto delimiters = common_chat_msg_delimiters_parse(json_value(data, "message_delimiters", json::array()));
|
||||
delimiters.tokenize(ctx_server.vocab);
|
||||
|
||||
for (size_t i = 0; i < inputs.size(); i++) {
|
||||
server_task task = server_task(type);
|
||||
|
||||
@@ -4163,16 +4100,7 @@ std::unique_ptr<server_res_generator> server_routes::handle_completions_impl(
|
||||
meta->logit_bias_eog,
|
||||
data);
|
||||
|
||||
const auto message_spans = json_value(data, "message_spans", json::array());
|
||||
if (prompt.is_string() && message_spans.is_array()) {
|
||||
task.params.n_before_user =
|
||||
prompt_get_n_before_user(
|
||||
message_spans,
|
||||
prompt.get<std::string>(),
|
||||
files,
|
||||
ctx_server.vocab,
|
||||
ctx_server.mctx);
|
||||
}
|
||||
task.params.message_spans = task.tokens.find_message_spans(delimiters);
|
||||
|
||||
task.id_slot = json_value(data, "id_slot", -1);
|
||||
|
||||
|
||||
@@ -591,10 +591,11 @@ json server_task_result_cmpl_final::to_json_oaicompat_resp() {
|
||||
|
||||
for (const common_chat_tool_call & tool_call : oaicompat_msg.tool_calls) {
|
||||
output.push_back(json {
|
||||
{"id", "fc_" + tool_call.id},
|
||||
{"type", "function_call"},
|
||||
{"status", "completed"},
|
||||
{"arguments", tool_call.arguments},
|
||||
{"call_id", "fc_" + tool_call.id},
|
||||
{"call_id", "call_" + tool_call.id},
|
||||
{"name", tool_call.name},
|
||||
});
|
||||
}
|
||||
@@ -690,10 +691,11 @@ json server_task_result_cmpl_final::to_json_oaicompat_resp_stream() {
|
||||
|
||||
for (const common_chat_tool_call & tool_call : oaicompat_msg.tool_calls) {
|
||||
const json output_item = {
|
||||
{"id", "fc_" + tool_call.id},
|
||||
{"type", "function_call"},
|
||||
{"status", "completed"},
|
||||
{"arguments", tool_call.arguments},
|
||||
{"call_id", "fc_" + tool_call.id},
|
||||
{"call_id", "call_" + tool_call.id},
|
||||
{"name", tool_call.name}
|
||||
};
|
||||
server_sent_events.push_back(json {
|
||||
@@ -1277,8 +1279,9 @@ json server_task_result_cmpl_partial::to_json_oaicompat_resp() {
|
||||
{"data", json {
|
||||
{"type", "response.output_item.added"},
|
||||
{"item", json {
|
||||
{"id", "fc_" + diff.tool_call_delta.id},
|
||||
{"arguments", ""},
|
||||
{"call_id", "fc_" + diff.tool_call_delta.id},
|
||||
{"call_id", "call_" + diff.tool_call_delta.id},
|
||||
{"name", diff.tool_call_delta.name},
|
||||
{"type", "function_call"},
|
||||
{"status", "in_progress"},
|
||||
|
||||
@@ -62,9 +62,6 @@ struct task_params {
|
||||
|
||||
int32_t n_cache_reuse = 0; // min chunk size to attempt reusing from the cache via KV shifting (0 = disabled)
|
||||
|
||||
// number of prompt tokens before the latest user message
|
||||
int32_t n_before_user = -1;
|
||||
|
||||
int64_t t_max_prompt_ms = -1; // TODO: implement
|
||||
int64_t t_max_predict_ms = -1; // if positive, limit the generation phase to this time limit
|
||||
|
||||
@@ -92,6 +89,9 @@ struct task_params {
|
||||
// per-request parameters for chat parsing
|
||||
common_chat_parser_params chat_parser_params;
|
||||
|
||||
// message spans for checkpointing
|
||||
common_chat_msg_spans message_spans;
|
||||
|
||||
// Embeddings
|
||||
int32_t embd_normalize = 2; // (-1=none, 0=max absolute int16, 1=taxicab, 2=Euclidean/L2, >2=p-norm)
|
||||
|
||||
|
||||
@@ -545,7 +545,8 @@ class ModelsStore {
|
||||
* 1. Model from active conversation's last assistant response (if loaded)
|
||||
* 2. Model from active conversation's last assistant response (if not loaded)
|
||||
* 3. First loaded model (not from active conversation)
|
||||
* 4. First available model
|
||||
* 4. A favorite model
|
||||
* 5. First available model
|
||||
*/
|
||||
async ensureFirstModelSelected(): Promise<void> {
|
||||
if (this.selectedModelName) return;
|
||||
@@ -574,6 +575,13 @@ class ModelsStore {
|
||||
return;
|
||||
}
|
||||
|
||||
// Try loading a favorite model
|
||||
const favorite = this.favoriteModelIds.values().next()?.value
|
||||
if (favorite) {
|
||||
await this.selectModelById(favorite);
|
||||
return;
|
||||
}
|
||||
|
||||
// Fall back to the first available model
|
||||
await this.selectModelById(availableModels[0].id);
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user