Compare commits

..

4 Commits

Author SHA1 Message Date
Xuan Son Nguyen 095058ca19 add arg --threads-sampling 2026-06-22 20:03:49 +02:00
Xuan Son Nguyen c62fdd5fd0 working 2026-06-22 19:38:25 +02:00
Xuan Son Nguyen 41ed530be2 wip 2026-06-22 19:30:11 +02:00
Xuan Son Nguyen fe03cce8db server: run sampling in a threadpool 2026-06-22 19:05:39 +02:00
152 changed files with 3517 additions and 4527 deletions
+1 -1
View File
@@ -10,7 +10,7 @@
# ggml-org/ggml-rpc : rgerganov
# ggml-org/ggml-sycl : arthw
# ggml-org/ggml-vulkan : 0cc4m, jeffbolznv
# ggml-org/ggml-webgpu : reeselevine, yomaytk
# ggml-org/ggml-webgpu : reeselevine
# ggml-org/ggml-zdnn : taronaeo
# ggml-org/llama-common : ggerganov, aldehir, angt, danbev, ngxson, pwilkin
# ggml-org/llama-mtmd : ngxson
+1 -3
View File
@@ -142,9 +142,7 @@ Instructions for adding support for new models: [HOWTO-add-model.md](docs/develo
- [x] [GigaChat-20B-A3B](https://huggingface.co/ai-sage/GigaChat-20B-A3B-instruct)
- [X] [Trillion-7B-preview](https://huggingface.co/trillionlabs/Trillion-7B-preview)
- [x] [Ling models](https://huggingface.co/collections/inclusionAI/ling-67c51c85b34a7ea0aba94c32)
- [x] [Liquid LFM2 models](https://huggingface.co/collections/LiquidAI/lfm2)
- [x] [Liquid LFM2.5 models](https://huggingface.co/collections/LiquidAI/lfm25)
- [x] [Liquid Nanos](https://huggingface.co/collections/LiquidAI/liquid-nanos)
- [x] [LFM2 models](https://huggingface.co/collections/LiquidAI/lfm2-686d721927015b2ad73eaa38)
- [x] [Hunyuan models](https://huggingface.co/collections/tencent/hunyuan-dense-model-6890632cda26b19119c9c5e7)
- [x] [BailingMoeV2 (Ring/Ling 2.0) models](https://huggingface.co/collections/inclusionAI/ling-v2-68bf1dd2fc34c306c1fa6f86)
- [x] [Mellum models](https://huggingface.co/JetBrains/models?search=mellum)
+14 -7
View File
@@ -301,8 +301,6 @@ static handle_model_result common_params_handle_model(struct common_params_model
const common_download_opts & opts) {
handle_model_result result;
// TODO @ngxson : refactor this into a new common_model_download_context
if (!model.docker_repo.empty()) {
model.path = common_docker_resolve_model(model.docker_repo);
} else if (!model.hf_repo.empty()) {
@@ -398,7 +396,7 @@ static bool parse_bool_value(const std::string & value) {
// CLI argument parsing functions
//
bool common_params_handle_models(common_params & params, llama_example curr_ex, const common_params_handle_models_params & handle_params) {
bool common_params_handle_models(common_params & params, llama_example curr_ex, common_download_callback * callback) {
const bool spec_type_draft_mtp = std::find(params.speculative.types.begin(),
params.speculative.types.end(),
COMMON_SPECULATIVE_TYPE_DRAFT_MTP) != params.speculative.types.end();
@@ -409,10 +407,9 @@ bool common_params_handle_models(common_params & params, llama_example curr_ex,
opts.skip_download = params.skip_download;
opts.download_mtp = spec_type_draft_mtp;
opts.download_mmproj = !params.no_mmproj && params.mmproj.path.empty() && params.mmproj.url.empty();
opts.preset_only = handle_params.preset_only;
if (handle_params.callback) {
opts.callback = handle_params.callback;
if (callback) {
opts.callback = callback;
}
// sub-models (draft, mmproj, vocoder) are explicitly specified by the user,
@@ -599,7 +596,7 @@ static bool common_params_parse_ex(int argc, char ** argv, common_params_context
if (!skip_model_download) {
// handle model and download
common_params_handle_models(params, ctx_arg.ex, {});
common_params_handle_models(params, ctx_arg.ex);
// model is required (except for server)
// TODO @ngxson : maybe show a list of available models in CLI in this case
@@ -1171,6 +1168,16 @@ common_params_context common_params_parser_init(common_params & params, llama_ex
}
}
));
add_opt(common_arg(
{"--threads-sampling"}, "N",
"number of threads to use during sampling (default: same as --threads)",
[](common_params & params, int value) {
params.sampling_n_threads = value;
if (params.sampling_n_threads <= 0) {
params.sampling_n_threads = std::thread::hardware_concurrency();
}
}
).set_examples({LLAMA_EXAMPLE_SERVER}));
add_opt(common_arg(
{"-C", "--cpu-mask"}, "M",
"CPU affinity mask: arbitrarily long hex. Complements cpu-range (default: \"\")",
+1 -6
View File
@@ -130,11 +130,6 @@ bool common_params_to_map(int argc, char ** argv, llama_example ex, std::map<com
// see: https://github.com/ggml-org/llama.cpp/issues/18163
void common_params_add_preset_options(std::vector<common_arg> & args);
struct common_params_handle_models_params {
common_download_callback * callback = nullptr;
bool preset_only = false; // if true, only check & download remote preset (for router mode)
};
// populate model paths (main model, mmproj, etc) from -hf if necessary
// return true if the model is ready to use
// throw an exception if there is an error that prevents the model from being used (e.g. network error, model not found, etc)
@@ -142,7 +137,7 @@ struct common_params_handle_models_params {
bool common_params_handle_models(
common_params & params,
llama_example curr_ex,
const common_params_handle_models_params & handle_params);
common_download_callback * callback = nullptr);
// initialize argument parser context - used by test-arg-parser and preset
common_params_context common_params_parser_init(common_params & params, llama_example ex, void(*print_usage)(int, char **) = nullptr);
+53 -103
View File
@@ -90,93 +90,41 @@ std::string common_chat_msg::render_content(const std::string & delimiter) const
return text;
}
common_chat_role common_chat_role_from_string(const std::string & role) {
if (role == "system") { return COMMON_CHAT_ROLE_SYSTEM; }
if (role == "assistant") { return COMMON_CHAT_ROLE_ASSISTANT; }
if (role == "user") { return COMMON_CHAT_ROLE_USER; }
if (role == "tool") { return COMMON_CHAT_ROLE_TOOL; }
return COMMON_CHAT_ROLE_UNKNOWN;
}
const char * common_chat_role_to_string(common_chat_role role) {
switch (role) {
case COMMON_CHAT_ROLE_SYSTEM: return "system";
case COMMON_CHAT_ROLE_ASSISTANT: return "assistant";
case COMMON_CHAT_ROLE_USER: return "user";
case COMMON_CHAT_ROLE_TOOL: return "tool";
case COMMON_CHAT_ROLE_UNKNOWN: return "";
}
return "";
}
json common_chat_msg_delimiters::to_json() const {
json result = json::array();
for (const auto & d : delimiters) {
result.push_back({
{ "role", common_chat_role_to_string(d.role) },
{ "delimiter", d.delimiter },
});
}
return result;
}
common_chat_msg_delimiters common_chat_msg_delimiters_parse(const json & delimiters) {
common_chat_msg_delimiters result;
if (!delimiters.is_array()) {
return result;
std::vector<common_chat_msg_span> common_chat_split_by_role(const std::string & prompt, const std::vector<common_chat_msg_delimiter> & delims) {
if (delims.empty() || prompt.empty()) {
return {};
}
result.delimiters.reserve(delimiters.size());
for (const auto & d : delimiters) {
if (!d.is_object()) {
continue;
auto parser = build_peg_parser([&](common_peg_parser_builder & p) {
std::vector<std::string> all_delims;
std::vector<common_peg_parser> tagged_messages;
all_delims.reserve(delims.size());
tagged_messages.reserve(delims.size());
for (const auto & d : delims) {
all_delims.push_back(d.delimiter);
}
result.delimiters.push_back({
common_chat_role_from_string(d.value("role", std::string())),
d.value("delimiter", std::string()),
});
}
return result;
}
void common_chat_msg_delimiters::tokenize(const llama_vocab * vocab) {
for (auto & d : delimiters) {
d.tokens = common_tokenize(vocab, d.delimiter, false, true);
}
}
common_chat_msg_spans common_chat_msg_delimiters::split(const llama_tokens & tokens, const std::map<size_t, size_t> & skips) const {
std::vector<std::pair<common_chat_role, size_t>> matches;
auto skip = skips.begin();
for (size_t i = 0; i < tokens.size();) {
if (skip != skips.end() && i == skip->first) {
i += skip->second;
++skip;
continue;
auto any_delim = p.until_one_of(all_delims);
for (const auto & d : delims) {
tagged_messages.push_back(p.tag(d.role, p.literal(d.delimiter) + any_delim));
}
for (const auto & d : delimiters) {
if (i + d.tokens.size() > tokens.size()) {
continue;
}
if (std::equal(d.tokens.begin(), d.tokens.end(), tokens.begin() + i)) {
matches.emplace_back(d.role, i);
break;
}
return any_delim + p.zero_or_more(p.choice(tagged_messages)) + p.end();
});
common_peg_parse_context ctx(prompt);
const auto result = parser.parse(ctx);
if (!result.success()) {
return {};
}
std::vector<common_chat_msg_span> spans;
ctx.ast.visit(result, [&](const common_peg_ast_node & node) {
if (!node.tag.empty()) {
spans.push_back({ node.tag, node.start, node.end - node.start });
}
i++;
}
matches.emplace_back(COMMON_CHAT_ROLE_UNKNOWN, tokens.size());
common_chat_msg_spans spans;
for (size_t i = 0; i + 1 < matches.size(); i++) {
const auto & curr = matches[i];
const auto & next = matches[i + 1];
spans.add(curr.first, curr.second, next.second - curr.second);
}
});
return spans;
}
@@ -1133,13 +1081,13 @@ static common_chat_params common_chat_params_init_gpt_oss(const common_chat_temp
data.prompt = prompt;
data.generation_prompt = common_chat_template_generation_prompt_impl(tmpl, inputs, /* messages_override= */ adjusted_messages);
data.message_delimiters = {
{ COMMON_CHAT_ROLE_ASSISTANT, "<|start|>assistant" },
{ COMMON_CHAT_ROLE_USER, "<|start|>user" },
{ COMMON_CHAT_ROLE_SYSTEM, "<|start|>developer" },
{ COMMON_CHAT_ROLE_SYSTEM, "<|start|>system" },
{ COMMON_CHAT_ROLE_TOOL, "<|start|>functions" },
};
data.message_spans = common_chat_split_by_role(prompt, {
{ "assistant", "<|start|>assistant" },
{ "user", "<|start|>user" },
{ "system", "<|start|>developer" },
{ "system", "<|start|>system" },
{ "tool", "<|start|>functions" },
});
data.format = COMMON_CHAT_FORMAT_PEG_NATIVE;
data.supports_thinking = true;
@@ -1280,10 +1228,10 @@ static common_chat_params common_chat_params_init_gemma4(const common_chat_templ
data.prompt += data.generation_prompt;
}
data.message_delimiters = {
{ COMMON_CHAT_ROLE_USER, "<|turn>user" },
{ COMMON_CHAT_ROLE_ASSISTANT, "<|turn>model" },
};
data.message_spans = common_chat_split_by_role(data.prompt, {
{ "user", "<|turn>user\n" },
{ "assistant", "<|turn>model\n" },
});
data.format = COMMON_CHAT_FORMAT_PEG_GEMMA4;
data.supports_thinking = true;
@@ -2082,15 +2030,15 @@ static common_chat_params common_chat_params_init_cohere2moe(const common_chat_t
RESULT_START, RESULT_END,
};
// Declare per-role message delimiters. Tool results are rendered with the
// Split the rendered prompt into per-role message spans. Tool results are rendered with the
// system token followed by <|START_TOOL_RESULT|>, so the "tool" delimiter must be listed before
// the plain "system" one (it is a strict superset, and the role split tries delimiters in order).
data.message_delimiters = {
{ COMMON_CHAT_ROLE_ASSISTANT, GEN_PREFIX },
{ COMMON_CHAT_ROLE_USER, TURN_START + USER },
{ COMMON_CHAT_ROLE_TOOL, TURN_START + SYSTEM + RESULT_START },
{ COMMON_CHAT_ROLE_SYSTEM, TURN_START + SYSTEM },
};
data.message_spans = common_chat_split_by_role(data.prompt, {
{ "assistant", GEN_PREFIX },
{ "user", TURN_START + USER },
{ "tool", TURN_START + SYSTEM + RESULT_START },
{ "system", TURN_START + SYSTEM },
});
auto has_tools = inputs.tools.is_array() && !inputs.tools.empty();
auto extract_reasoning = inputs.reasoning_format != COMMON_REASONING_FORMAT_NONE;
@@ -2578,15 +2526,17 @@ static common_chat_params common_chat_templates_apply_jinja(const struct common_
autoparser.analyze_template(tmpl);
auto auto_params = autoparser::peg_generator::generate_parser(tmpl, params, autoparser);
common_chat_msg_delimiters delimiters;
std::vector<common_chat_msg_delimiter> delimiters;
if (!autoparser.assistant_start.empty()) {
delimiters.add(COMMON_CHAT_ROLE_ASSISTANT, autoparser.assistant_start);
delimiters.push_back({ "assistant", autoparser.assistant_start });
}
if (!autoparser.user_start.empty()) {
delimiters.add(COMMON_CHAT_ROLE_USER, autoparser.user_start);
delimiters.push_back({ "user", autoparser.user_start });
}
auto_params.message_delimiters = std::move(delimiters);
if (!delimiters.empty()) {
auto_params.message_spans = common_chat_split_by_role(auto_params.prompt, delimiters);
}
auto_params.supports_thinking = autoparser.reasoning.mode != autoparser::reasoning_mode::NONE;
if (auto_params.supports_thinking) {
+6 -65
View File
@@ -143,75 +143,15 @@ struct common_chat_msg_diff {
}
};
enum common_chat_role {
COMMON_CHAT_ROLE_UNKNOWN,
COMMON_CHAT_ROLE_SYSTEM,
COMMON_CHAT_ROLE_ASSISTANT,
COMMON_CHAT_ROLE_USER,
COMMON_CHAT_ROLE_TOOL
};
common_chat_role common_chat_role_from_string(const std::string & role);
const char * common_chat_role_to_string(common_chat_role role);
struct common_chat_msg_span {
common_chat_role role = COMMON_CHAT_ROLE_UNKNOWN;
std::string role;
std::size_t pos = 0;
std::size_t len = 0;
bool valid() const {
return role != COMMON_CHAT_ROLE_UNKNOWN;
}
};
struct common_chat_msg_spans {
std::vector<common_chat_msg_span> spans;
void add(common_chat_role role, size_t pos, size_t len) {
spans.push_back({ role, pos, len });
}
bool is_user_start(int32_t pos) const {
for (auto it = spans.begin(); it != spans.end(); ++it) {
if (it->role == COMMON_CHAT_ROLE_USER && pos == (int32_t) it->pos) {
return true;
}
}
return false;
}
int32_t last_user_message_pos() const {
for (auto it = spans.rbegin(); it != spans.rend(); ++it) {
if (it->role == COMMON_CHAT_ROLE_USER) {
return (int32_t) it->pos;
}
}
return -1;
}
};
struct common_chat_msg_delimiter {
common_chat_role role = COMMON_CHAT_ROLE_UNKNOWN;
std::string delimiter;
llama_tokens tokens = {};
};
struct common_chat_msg_delimiters {
std::vector<common_chat_msg_delimiter> delimiters;
common_chat_msg_delimiters() = default;
common_chat_msg_delimiters(std::initializer_list<common_chat_msg_delimiter> delims) : delimiters(delims) {}
void add(common_chat_role role, const std::string & delimiter) {
delimiters.push_back({ role, delimiter });
}
void tokenize(const llama_vocab * vocab);
// split tokens into message spans. skips maps a start index to a length of a region to jump over without matching
common_chat_msg_spans split(const llama_tokens & tokens, const std::map<size_t, size_t> & skips = {}) const;
nlohmann::ordered_json to_json() const;
std::string role;
std::string delimiter;
};
struct common_chat_tool {
@@ -279,7 +219,7 @@ struct common_chat_params {
std::vector<std::string> preserved_tokens;
std::vector<std::string> additional_stops;
std::string parser;
common_chat_msg_delimiters message_delimiters;
std::vector<common_chat_msg_span> message_spans;
};
// per-message parsing syntax
@@ -385,4 +325,5 @@ struct common_chat_prompt_preset {
common_chat_prompt_preset common_chat_get_asr_prompt(const common_chat_templates * chat_templates);
common_chat_msg_delimiters common_chat_msg_delimiters_parse(const nlohmann::ordered_json & delimiters);
std::vector<common_chat_msg_span> common_chat_split_by_role(const std::string & prompt, const std::vector<common_chat_msg_delimiter> & delims);
+3 -1
View File
@@ -471,6 +471,8 @@ struct common_params {
common_cpu_params cpuparams;
common_cpu_params cpuparams_batch;
int sampling_n_threads = -1; // number of threads for sampling, used by server
ggml_backend_sched_eval_callback cb_eval = nullptr;
void * cb_eval_user_data = nullptr;
@@ -609,7 +611,7 @@ struct common_params {
bool cache_prompt = true; // whether to enable prompt caching
bool cache_idle_slots = true; // save and clear idle slots upon starting a new task
int32_t n_ctx_checkpoints = 32; // max number of context checkpoints per slot
int32_t checkpoint_min_step = 8192; // minimum spacing between context checkpoints
int32_t checkpoint_min_step = 256; // minimum spacing between context checkpoints
int32_t cache_ram_mib = 8192; // -1 = no limit, 0 - disable, 1 = 1 MiB, etc.
std::string hostname = "127.0.0.1";
+1 -3
View File
@@ -799,7 +799,6 @@ common_download_model_result common_download_model(const common_params_model &
bool download_mmproj = opts.download_mmproj;
bool download_mtp = opts.download_mtp;
bool preset_only = opts.preset_only;
bool is_hf = !model.hf_repo.empty();
if (is_hf) {
@@ -807,8 +806,7 @@ common_download_model_result common_download_model(const common_params_model &
if (!hf.preset.path.empty()) {
// if preset.ini exists, only download that file alone
tasks.push_back({hf.preset.url, hf.preset.local_path});
} else if (!preset_only) {
// only add other files if we're NOT in preset-only mode (normal run, non-router)
} else {
for (const auto & f : hf.model_files) {
tasks.push_back({f.url, f.local_path});
}
-1
View File
@@ -55,7 +55,6 @@ struct common_download_opts {
bool skip_download = false; // if true, only validation is performed, common_skip_download_exception may be thrown if the file is missing or invalid
bool download_mmproj = false;
bool download_mtp = false;
bool preset_only = false; // if true, only check & download remote preset (for router mode)
common_download_callback * callback = nullptr;
};
-3
View File
@@ -96,7 +96,6 @@ TEXT_MODEL_MAP: dict[str, str] = {
"GraniteMoeHybridForCausalLM": "granite",
"GraniteMoeSharedForCausalLM": "granite",
"GraniteSpeechForConditionalGeneration": "granite",
"GraniteSpeechPlusForConditionalGeneration": "granite",
"Grok1ForCausalLM": "grok",
"GrokForCausalLM": "grok",
"GroveMoeForCausalLM": "grovemoe",
@@ -124,7 +123,6 @@ TEXT_MODEL_MAP: dict[str, str] = {
"LLaDAModelLM": "llada",
"LLaMAForCausalLM": "llama",
"Lfm25AudioTokenizer": "lfm2",
"Lfm2BidirectionalModel": "lfm2",
"Lfm2ForCausalLM": "lfm2",
"Lfm2Model": "lfm2",
"Lfm2MoeForCausalLM": "lfm2",
@@ -263,7 +261,6 @@ MMPROJ_MODEL_MAP: dict[str, str] = {
"GlmasrModel": "ultravox",
"Granite4VisionForConditionalGeneration": "granite",
"GraniteSpeechForConditionalGeneration": "granite",
"GraniteSpeechPlusForConditionalGeneration": "granite",
"HunYuanVLForConditionalGeneration": "hunyuan",
"Idefics3ForConditionalGeneration": "smolvlm",
"InternVisionModel": "internvl",
-28
View File
@@ -348,34 +348,6 @@ class GraniteSpeechMmprojModel(MmprojModel):
yield from super().modify_tensors(data_torch, name, bid)
@ModelBase.register("GraniteSpeechPlusForConditionalGeneration")
class GraniteSpeechPlusMmprojModel(GraniteSpeechMmprojModel):
"""Conversion for GraniteSpeechPlus - extends GraniteSpeech with feature layer concatenation"""
has_vision_encoder = False
has_audio_encoder = True
def set_gguf_parameters(self):
assert self.hparams_audio is not None
super().set_gguf_parameters()
# Add feature_layer if present in encoder config
if feature_layers := self.hparams_audio.get("cat_hidden_layers"):
self.gguf_writer.add_audio_feature_layers(feature_layers)
logger.info(f"gguf: audio feature_layers = {feature_layers}")
# Validate projector dimension matches concatenated encoder output
hidden_dim = self.hparams_audio["hidden_dim"]
expected_dim = hidden_dim * (len(feature_layers) + 1)
projector_dim = self.global_config["projector_config"]["encoder_hidden_size"]
if projector_dim != expected_dim:
raise ValueError(
f"Projector encoder_hidden_size ({projector_dim}) does not match "
f"expected concatenated dimension ({expected_dim}). "
f"Expected: hidden_dim ({hidden_dim}) * (len(feature_layers) + 1) = {expected_dim}"
)
@ModelBase.register("Granite4VisionForConditionalGeneration")
class Granite4VisionMmprojModel(MmprojModel):
has_vision_encoder = True
+3 -10
View File
@@ -64,17 +64,11 @@ class LFM2Model(TextModel):
yield from super().modify_tensors(data_torch, name, bid)
@ModelBase.register("Lfm2Model", "Lfm2BidirectionalModel")
@ModelBase.register("Lfm2Model")
class LFM2ColBertModel(LFM2Model):
model_arch = gguf.MODEL_ARCH.LFM2
dense_tensor_name = "dense_2"
def set_gguf_parameters(self):
super().set_gguf_parameters()
if self.hf_arch == "Lfm2BidirectionalModel":
self.gguf_writer.add_causal_attention(False)
self._try_set_pooling_type()
def modify_tensors(self, data_torch: Tensor, name: str, bid: int | None) -> Iterable[tuple[str, Tensor]]:
if not name.startswith(self.dense_tensor_name):
name = "model." + name
@@ -82,11 +76,10 @@ class LFM2ColBertModel(LFM2Model):
yield from super().modify_tensors(data_torch, name, bid)
def generate_extra_tensors(self) -> Iterable[tuple[str, Tensor]]:
# optional dense tensor is stored in a separate safetensors file
# dense tensor is stored in a separate safetensors file
from safetensors.torch import load_file
tensors_file = self.dir_model / "1_Dense" / "model.safetensors"
if not tensors_file.is_file():
return
assert tensors_file.is_file()
tensor = load_file(tensors_file)["linear.weight"]
self.gguf_writer.add_embedding_length_out(tensor.shape[0])
yield f"{self.dense_tensor_name}.weight", tensor.clone()
+23 -50
View File
@@ -3688,6 +3688,8 @@ static void ggml_compute_forward_norm_f32(
GGML_ASSERT(ggml_are_same_shape(src0, dst));
GGML_ASSERT(src0->nb[0] == sizeof(float));
const int ith = params->ith;
const int nth = params->nth;
@@ -3701,49 +3703,25 @@ static void ggml_compute_forward_norm_f32(
for (int64_t i03 = 0; i03 < ne03; i03++) {
for (int64_t i02 = 0; i02 < ne02; i02++) {
for (int64_t i01 = ith; i01 < ne01; i01 += nth) {
const char * x = (const char *) src0->data + i01*nb01 + i02*nb02 + i03*nb03;
char * y = (char *) dst->data + i01*nb1 + i02*nb2 + i03*nb3;
const float * x = (float *) ((char *) src0->data + i01*nb01 + i02*nb02 + i03*nb03);
if (nb00 == sizeof(float) && nb0 == sizeof(float)) {
const float * xf = (const float *) x;
float sum = 0.0;
ggml_vec_sum_f32(ne00, &sum, x);
float mean = sum/ne00;
float sum = 0.0;
ggml_vec_sum_f32(ne00, &sum, xf);
float mean = sum/ne00;
float * yf = (float *) y;
float variance = 0;
float * y = (float *) ((char *) dst->data + i01*nb1 + i02*nb2 + i03*nb3);
float variance = 0;
#ifdef GGML_USE_ACCELERATE
mean = -mean;
vDSP_vsadd(xf, 1, &mean, yf, 1, ne00);
vDSP_measqv(yf, 1, &variance, ne00);
mean = -mean;
vDSP_vsadd(x, 1, &mean, y, 1, ne00);
vDSP_measqv(y, 1, &variance, ne00);
#else
variance = ggml_vec_cvar_f32(ne00, yf, xf, mean);
variance = ggml_vec_cvar_f32(ne00, y, x, mean);
#endif //GGML_USE_ACCELERATE
const float scale = 1.0f/sqrtf(variance + eps);
ggml_vec_scale_f32(ne00, yf, scale);
} else {
float sum = 0.0;
for (int64_t i00 = 0; i00 < ne00; i00++) {
sum += *(const float *) (x + i00*nb00);
}
const float mean = sum/ne00;
float variance = 0.0f;
for (int64_t i00 = 0; i00 < ne00; i00++) {
const float v = *(const float *) (x + i00*nb00) - mean;
*(float *) (y + i00*nb0) = v;
variance += v * v;
}
variance /= ne00;
const float scale = 1.0f/sqrtf(variance + eps);
for (int64_t i00 = 0; i00 < ne00; i00++) {
*(float *) (y + i00*nb0) *= scale;
}
}
const float scale = 1.0f/sqrtf(variance + eps);
ggml_vec_scale_f32(ne00, y, scale);
}
}
}
@@ -4164,6 +4142,8 @@ static void ggml_compute_forward_l2_norm_f32(
GGML_ASSERT(ggml_are_same_shape(src0, dst));
GGML_ASSERT(src0->nb[0] == sizeof(float));
const int ith = params->ith;
const int nth = params->nth;
@@ -4178,27 +4158,20 @@ static void ggml_compute_forward_l2_norm_f32(
for (int64_t i03 = 0; i03 < ne03; i03++) {
for (int64_t i02 = 0; i02 < ne02; i02++) {
for (int64_t i01 = ith; i01 < ne01; i01 += nth) {
const char * x = (const char *) src0->data + i01*nb01 + i02*nb02 + i03*nb03;
const float * x = (float *) ((char *) src0->data + i01*nb01 + i02*nb02 + i03*nb03);
ggml_float sum = 0.0;
for (int64_t i00 = 0; i00 < ne00; i00++) {
const float xi = *(const float *) (x + i00*nb00);
sum += (ggml_float)(xi * xi);
sum += (ggml_float)(x[i00] * x[i00]);
}
float * y = (float *) ((char *) dst->data + i01*nb1 + i02*nb2 + i03*nb3);
memcpy(y, x, ne00 * sizeof(float));
const float scale = 1.0f/fmaxf(sqrtf(sum), eps);
char * y = (char *) dst->data + i01*nb1 + i02*nb2 + i03*nb3;
if (nb00 == sizeof(float) && nb0 == sizeof(float)) {
memcpy(y, x, ne00 * sizeof(float));
ggml_vec_scale_f32(ne00, (float *) y, scale);
} else {
for (int64_t i00 = 0; i00 < ne00; i00++) {
const float xi = *(const float *) (x + i00*nb00);
*(float *) (y + i00*nb0) = xi * scale;
}
}
ggml_vec_scale_f32(ne00, y, scale);
}
}
}
+1 -1
View File
@@ -5334,7 +5334,7 @@ static bool ggml_backend_cuda_device_supports_op(ggml_backend_dev_t dev, const g
case GGML_OP_NORM:
case GGML_OP_RMS_NORM:
case GGML_OP_L2_NORM:
return ggml_is_contiguous_rows(op->src[0]);
return true;
case GGML_OP_RMS_NORM_BACK:
return ggml_is_contiguous(op->src[0]);
break;
@@ -174,7 +174,7 @@ __kernel void kernel_gemv_noshuffle_q8_0_f32(
regA.s6 = read_imageui(src0_q, (gid + k * BLOCK_STRIDE_A + LINE_STRIDE_A * 6)).x;
regA.s7 = read_imageui(src0_q, (gid + k * BLOCK_STRIDE_A + LINE_STRIDE_A * 7)).x;
dequantizeBlockAccum_ns_sgbroadcast_1(totalSum, regA, convert_float(regS), regB);
dequantizeBlockAccum_ns_sgbroadcast_1(totalSum, regA, regS, regB);
}
// reduction in local memory, assumes #wave=4
-5
View File
@@ -108,9 +108,6 @@ if (Vulkan_FOUND)
if (GGML_VULKAN_CHECK_RESULTS)
add_compile_definitions(GGML_VULKAN_CHECK_RESULTS)
# the result-checking path computes a CPU reference graph via
# ggml_graph_compute_with_ctx(), which is defined in ggml-cpu
target_link_libraries(ggml-vulkan PRIVATE ggml-cpu)
endif()
if (GGML_VULKAN_DEBUG)
@@ -132,8 +129,6 @@ if (Vulkan_FOUND)
if (GGML_VULKAN_RUN_TESTS)
add_compile_definitions(GGML_VULKAN_RUN_TESTS)
# the test path also calls ggml_graph_compute_with_ctx() (ggml-cpu)
target_link_libraries(ggml-vulkan PRIVATE ggml-cpu)
endif()
# Set up toolchain for host compilation whether cross-compiling or not
+66 -362
View File
@@ -493,20 +493,6 @@ struct vk_conv2d_pipeline_state {
}
};
struct vk_conv3d_pipeline_state {
vk_conv3d_pipeline_state(uint32_t s0, uint32_t s1, uint32_t s2, uint32_t p0, uint32_t p1, uint32_t p2,
uint32_t d0, uint32_t d1, uint32_t d2, uint32_t KW, uint32_t KH, uint32_t KD, uint32_t aligned)
: s0(s0), s1(s1), s2(s2), p0(p0), p1(p1), p2(p2), d0(d0), d1(d1), d2(d2), KW(KW), KH(KH), KD(KD), aligned(aligned) {}
uint32_t s0, s1, s2, p0, p1, p2, d0, d1, d2, KW, KH, KD;
uint32_t aligned;
bool operator<(const vk_conv3d_pipeline_state &b) const {
return std::tie(s0, s1, s2, p0, p1, p2, d0, d1, d2, KW, KH, KD, aligned) <
std::tie(b.s0, b.s1, b.s2, b.p0, b.p1, b.p2, b.d0, b.d1, b.d2, b.KW, b.KH, b.KD, b.aligned);
}
};
struct vk_solve_tri_pipeline_state {
vk_solve_tri_pipeline_state(uint32_t N, uint32_t K)
: N(N), K(K) {}
@@ -791,7 +777,6 @@ struct vk_device_struct {
vk_pipeline pipeline_mul_mat_vec_nc_f16_f32;
vk_pipeline pipeline_get_rows[GGML_TYPE_COUNT];
vk_pipeline pipeline_get_rows_f32[GGML_TYPE_COUNT];
vk_pipeline pipeline_get_rows_back_f32;
vk_pipeline pipeline_acc_f32;
vk_pipeline pipeline_set_f32;
@@ -816,10 +801,14 @@ struct vk_device_struct {
vk_pipeline pipeline_concat_i8, pipeline_concat_i16, pipeline_concat_i32, pipeline_concat_i64;
vk_pipeline pipeline_upscale_nearest_f32, pipeline_upscale_bilinear_f32, pipeline_upscale_bicubic_f32, pipeline_upscale_bilinear_antialias_f32;
vk_pipeline pipeline_scale_f32;
vk_pipeline pipeline_sqr_f32;
vk_pipeline pipeline_sqrt_f32;
vk_pipeline pipeline_sin_f32;
vk_pipeline pipeline_cos_f32;
vk_pipeline pipeline_log[2];
vk_pipeline pipeline_tri[2];
vk_pipeline pipeline_diag[2];
vk_pipeline pipeline_clamp[2];
vk_pipeline pipeline_clamp_f32;
vk_pipeline pipeline_pad_f32;
vk_pipeline pipeline_roll_f32;
vk_pipeline pipeline_repeat_i32, pipeline_repeat_back_f32;
@@ -851,10 +840,6 @@ struct vk_device_struct {
vk_pipeline pipeline_gelu_quick[2];
vk_pipeline pipeline_silu[2];
vk_pipeline pipeline_relu[2];
vk_pipeline pipeline_sqr[2];
vk_pipeline pipeline_sqrt[2];
vk_pipeline pipeline_sin[2];
vk_pipeline pipeline_cos[2];
vk_pipeline pipeline_xielu[2];
vk_pipeline pipeline_neg[2];
vk_pipeline pipeline_tanh[2];
@@ -886,7 +871,7 @@ struct vk_device_struct {
vk_pipeline pipeline_geglu_erf[2];
vk_pipeline pipeline_geglu_quick[2];
vk_pipeline pipeline_leaky_relu[2];
vk_pipeline pipeline_leaky_relu_f32;
vk_pipeline pipeline_silu_back_f32;
vk_pipeline pipeline_diag_mask_inf_f32;
vk_pipeline pipeline_soft_max_f32, pipeline_soft_max_f32_f16;
@@ -939,8 +924,6 @@ struct vk_device_struct {
std::map<vk_conv2d_pipeline_state, vk_pipeline> pipeline_conv2d_f16_f32[CONV_SHAPE_COUNT];
std::map<vk_conv2d_pipeline_state, vk_pipeline> pipeline_conv_transpose_2d_f32[CONV_SHAPE_COUNT];
std::map<vk_conv2d_pipeline_state, vk_pipeline> pipeline_conv_transpose_2d_f16_f32[CONV_SHAPE_COUNT];
std::map<vk_conv3d_pipeline_state, vk_pipeline> pipeline_conv3d_f32[CONV_SHAPE_COUNT];
std::map<vk_conv3d_pipeline_state, vk_pipeline> pipeline_conv3d_f16_f32[CONV_SHAPE_COUNT];
vk_pipeline pipeline_conv2d_dw_whcn_f32, pipeline_conv2d_dw_whcn_f16_f32;
vk_pipeline pipeline_conv2d_dw_cwhn_f32, pipeline_conv2d_dw_cwhn_f16_f32;
@@ -1686,41 +1669,6 @@ template <> void init_pushconst_fastdiv(vk_op_conv2d_push_constants &p) {
init_fastdiv_values(p.OW*p.OH, p.OWOHmp, p.OWOHL);
}
struct vk_op_conv3d_push_constants {
uint32_t OC;
uint32_t IC;
uint32_t N;
uint32_t IW;
uint32_t IH;
uint32_t ID;
uint32_t OW;
uint32_t OH;
uint32_t OD;
uint32_t nb01;
uint32_t nb02;
uint32_t nb03;
uint32_t nb11;
uint32_t nb12;
uint32_t nb13;
uint32_t nb1;
uint32_t nb2;
uint32_t nb3;
uint32_t OWmp; uint32_t OWL;
uint32_t OWOHmp; uint32_t OWOHL;
uint32_t OWOHODmp; uint32_t OWOHODL;
};
template <> void init_pushconst_fastdiv(vk_op_conv3d_push_constants &p) {
init_fastdiv_values(p.OW, p.OWmp, p.OWL);
init_fastdiv_values(p.OW*p.OH, p.OWOHmp, p.OWOHL);
init_fastdiv_values(p.OW*p.OH*p.OD, p.OWOHODmp, p.OWOHODL);
}
struct vk_op_conv2d_dw_push_constants {
uint32_t ne;
uint32_t batches;
@@ -4126,35 +4074,19 @@ static void ggml_vk_load_shaders(vk_device& device, vk_pipeline requested) {
}
#endif
auto const &ggml_vk_mul_mm_spec = [](std::vector<uint32_t> spec, bool aligned) {
spec.push_back(aligned ? 1u : 0u);
return spec;
};
const int mul_mat_id_param_count = 5;
#if defined(VK_NV_cooperative_matrix2) && defined(GGML_VULKAN_COOPMAT2_GLSLC_SUPPORT)
if (device->coopmat2) {
auto const &ggml_vk_mul_mm_cm2_spec = [](std::vector<uint32_t> spec, bool aligned, bool mul_mat_id) {
if (mul_mat_id && spec.size() > 5) {
spec.insert(spec.begin() + 5, aligned ? 1u : 0u);
} else {
spec.push_back(aligned ? 1u : 0u);
}
if (mul_mat_id && spec.size() == 6) {
spec.push_back(32);
}
return spec;
};
// Create 6 variants, {s,m,l}x{unaligned,aligned}
#define CREATE_MM(PIPELINE_NAME, NAMELC, F16ACC, WG_DENOMS, WARPTILE, PUSHCONST, PARAMCOUNT) \
ggml_vk_create_pipeline(device, device-> PIPELINE_NAME ->l, #NAMELC #F16ACC "_l", NAMELC ## F16ACC ## _cm2_len, NAMELC ## F16ACC ## _cm2_data, "main", PARAMCOUNT, sizeof(PUSHCONST), l_ ## WG_DENOMS, ggml_vk_mul_mm_cm2_spec(l_ ## WARPTILE, false, PARAMCOUNT == mul_mat_id_param_count), 1, true); \
ggml_vk_create_pipeline(device, device-> PIPELINE_NAME ->m, #NAMELC #F16ACC "_m", NAMELC ## F16ACC ## _cm2_len, NAMELC ## F16ACC ## _cm2_data, "main", PARAMCOUNT, sizeof(PUSHCONST), m_ ## WG_DENOMS, ggml_vk_mul_mm_cm2_spec(m_ ## WARPTILE, false, PARAMCOUNT == mul_mat_id_param_count), 1, true); \
ggml_vk_create_pipeline(device, device-> PIPELINE_NAME ->s, #NAMELC #F16ACC "_s", NAMELC ## F16ACC ## _cm2_len, NAMELC ## F16ACC ## _cm2_data, "main", PARAMCOUNT, sizeof(PUSHCONST), s_ ## WG_DENOMS, ggml_vk_mul_mm_cm2_spec(s_ ## WARPTILE, false, PARAMCOUNT == mul_mat_id_param_count), 1, true); \
ggml_vk_create_pipeline(device, device-> PIPELINE_NAME ->a_l, #NAMELC #F16ACC "_aligned_l", NAMELC ## F16ACC ## _cm2_len, NAMELC ## F16ACC ## _cm2_data, "main", PARAMCOUNT, sizeof(PUSHCONST), l_ ## WG_DENOMS, ggml_vk_mul_mm_cm2_spec(l_ ## WARPTILE, true, PARAMCOUNT == mul_mat_id_param_count), l_align, true); \
ggml_vk_create_pipeline(device, device-> PIPELINE_NAME ->a_m, #NAMELC #F16ACC "_aligned_m", NAMELC ## F16ACC ## _cm2_len, NAMELC ## F16ACC ## _cm2_data, "main", PARAMCOUNT, sizeof(PUSHCONST), m_ ## WG_DENOMS, ggml_vk_mul_mm_cm2_spec(m_ ## WARPTILE, true, PARAMCOUNT == mul_mat_id_param_count), m_align, true); \
ggml_vk_create_pipeline(device, device-> PIPELINE_NAME ->a_s, #NAMELC #F16ACC "_aligned_s", NAMELC ## F16ACC ## _cm2_len, NAMELC ## F16ACC ## _cm2_data, "main", PARAMCOUNT, sizeof(PUSHCONST), s_ ## WG_DENOMS, ggml_vk_mul_mm_cm2_spec(s_ ## WARPTILE, true, PARAMCOUNT == mul_mat_id_param_count), s_align, true); \
ggml_vk_create_pipeline(device, device-> PIPELINE_NAME ->l, #NAMELC #F16ACC "_l", NAMELC ## F16ACC ## _cm2_len, NAMELC ## F16ACC ## _cm2_data, "main", PARAMCOUNT, sizeof(PUSHCONST), l_ ## WG_DENOMS, l_ ## WARPTILE, 1, true); \
ggml_vk_create_pipeline(device, device-> PIPELINE_NAME ->m, #NAMELC #F16ACC "_m", NAMELC ## F16ACC ## _cm2_len, NAMELC ## F16ACC ## _cm2_data, "main", PARAMCOUNT, sizeof(PUSHCONST), m_ ## WG_DENOMS, m_ ## WARPTILE, 1, true); \
ggml_vk_create_pipeline(device, device-> PIPELINE_NAME ->s, #NAMELC #F16ACC "_s", NAMELC ## F16ACC ## _cm2_len, NAMELC ## F16ACC ## _cm2_data, "main", PARAMCOUNT, sizeof(PUSHCONST), s_ ## WG_DENOMS, s_ ## WARPTILE, 1, true); \
ggml_vk_create_pipeline(device, device-> PIPELINE_NAME ->a_l, #NAMELC #F16ACC "_aligned_l", NAMELC ## _aligned ## F16ACC ## _cm2_len, NAMELC ## _aligned ## F16ACC ## _cm2_data, "main", PARAMCOUNT, sizeof(PUSHCONST), l_ ## WG_DENOMS, l_ ## WARPTILE, l_align, true); \
ggml_vk_create_pipeline(device, device-> PIPELINE_NAME ->a_m, #NAMELC #F16ACC "_aligned_m", NAMELC ## _aligned ## F16ACC ## _cm2_len, NAMELC ## _aligned ## F16ACC ## _cm2_data, "main", PARAMCOUNT, sizeof(PUSHCONST), m_ ## WG_DENOMS, m_ ## WARPTILE, m_align, true); \
ggml_vk_create_pipeline(device, device-> PIPELINE_NAME ->a_s, #NAMELC #F16ACC "_aligned_s", NAMELC ## _aligned ## F16ACC ## _cm2_len, NAMELC ## _aligned ## F16ACC ## _cm2_data, "main", PARAMCOUNT, sizeof(PUSHCONST), s_ ## WG_DENOMS, s_ ## WARPTILE, s_align, true); \
// Create 2 variants, {f16,f32} accumulator
#define CREATE_MM2(PIPELINE_NAME, NAMELC, WG_DENOMS, WARPTILE, PUSHCONST, PARAMCOUNT) \
@@ -4229,17 +4161,17 @@ static void ggml_vk_load_shaders(vk_device& device, vk_pipeline requested) {
// Create 6 variants, {s,m,l}x{unaligned,aligned}
#define CREATE_MM(TYPE, PIPELINE_NAME, NAMELC, F16ACC, WG_DENOMS, WARPTILE, PUSHCONST, PARAMCOUNT, ID) \
if (device->mul_mat ## ID ## _l[TYPE]) \
ggml_vk_create_pipeline(device, device-> PIPELINE_NAME ->l, #NAMELC #F16ACC "_l", NAMELC ## F16ACC ## _cm1_len, NAMELC ## F16ACC ## _cm1_data, "main", PARAMCOUNT, sizeof(PUSHCONST), l_ ## WG_DENOMS, ggml_vk_mul_mm_spec(l_ ## WARPTILE, false), 1, false, true); \
ggml_vk_create_pipeline(device, device-> PIPELINE_NAME ->l, #NAMELC #F16ACC "_l", NAMELC ## F16ACC ## _cm1_len, NAMELC ## F16ACC ## _cm1_data, "main", PARAMCOUNT, sizeof(PUSHCONST), l_ ## WG_DENOMS, l_ ## WARPTILE, 1, false, true); \
if (device->mul_mat ## ID ## _m[TYPE]) \
ggml_vk_create_pipeline(device, device-> PIPELINE_NAME ->m, #NAMELC #F16ACC "_m", NAMELC ## F16ACC ## _cm1_len, NAMELC ## F16ACC ## _cm1_data, "main", PARAMCOUNT, sizeof(PUSHCONST), m_ ## WG_DENOMS, ggml_vk_mul_mm_spec(m_ ## WARPTILE, false), 1, false, true); \
ggml_vk_create_pipeline(device, device-> PIPELINE_NAME ->m, #NAMELC #F16ACC "_m", NAMELC ## F16ACC ## _cm1_len, NAMELC ## F16ACC ## _cm1_data, "main", PARAMCOUNT, sizeof(PUSHCONST), m_ ## WG_DENOMS, m_ ## WARPTILE, 1, false, true); \
if (device->mul_mat ## ID ## _s[TYPE]) \
ggml_vk_create_pipeline(device, device-> PIPELINE_NAME ->s, #NAMELC #F16ACC "_s", NAMELC ## F16ACC ## _cm1_len, NAMELC ## F16ACC ## _cm1_data, "main", PARAMCOUNT, sizeof(PUSHCONST), s_ ## WG_DENOMS, ggml_vk_mul_mm_spec(s_ ## WARPTILE, false), 1, false, true); \
ggml_vk_create_pipeline(device, device-> PIPELINE_NAME ->s, #NAMELC #F16ACC "_s", NAMELC ## F16ACC ## _cm1_len, NAMELC ## F16ACC ## _cm1_data, "main", PARAMCOUNT, sizeof(PUSHCONST), s_ ## WG_DENOMS, s_ ## WARPTILE, 1, false, true); \
if (device->mul_mat ## ID ## _l[TYPE]) \
ggml_vk_create_pipeline(device, device-> PIPELINE_NAME ->a_l, #NAMELC #F16ACC "_aligned_l", NAMELC ## F16ACC ## _cm1_len, NAMELC ## F16ACC ## _cm1_data, "main", PARAMCOUNT, sizeof(PUSHCONST), l_ ## WG_DENOMS, ggml_vk_mul_mm_spec(l_ ## WARPTILE, true), l_align, false, true); \
ggml_vk_create_pipeline(device, device-> PIPELINE_NAME ->a_l, #NAMELC #F16ACC "_aligned_l", NAMELC ## _aligned ## F16ACC ## _cm1_len, NAMELC ## _aligned ## F16ACC ## _cm1_data, "main", PARAMCOUNT, sizeof(PUSHCONST), l_ ## WG_DENOMS, l_ ## WARPTILE, l_align, false, true); \
if (device->mul_mat ## ID ## _m[TYPE]) \
ggml_vk_create_pipeline(device, device-> PIPELINE_NAME ->a_m, #NAMELC #F16ACC "_aligned_m", NAMELC ## F16ACC ## _cm1_len, NAMELC ## F16ACC ## _cm1_data, "main", PARAMCOUNT, sizeof(PUSHCONST), m_ ## WG_DENOMS, ggml_vk_mul_mm_spec(m_ ## WARPTILE, true), m_align, false, true); \
ggml_vk_create_pipeline(device, device-> PIPELINE_NAME ->a_m, #NAMELC #F16ACC "_aligned_m", NAMELC ## _aligned ## F16ACC ## _cm1_len, NAMELC ## _aligned ## F16ACC ## _cm1_data, "main", PARAMCOUNT, sizeof(PUSHCONST), m_ ## WG_DENOMS, m_ ## WARPTILE, m_align, false, true); \
if (device->mul_mat ## ID ## _s[TYPE]) \
ggml_vk_create_pipeline(device, device-> PIPELINE_NAME ->a_s, #NAMELC #F16ACC "_aligned_s", NAMELC ## F16ACC ## _cm1_len, NAMELC ## F16ACC ## _cm1_data, "main", PARAMCOUNT, sizeof(PUSHCONST), s_ ## WG_DENOMS, ggml_vk_mul_mm_spec(s_ ## WARPTILE, true), s_align, false, true); \
ggml_vk_create_pipeline(device, device-> PIPELINE_NAME ->a_s, #NAMELC #F16ACC "_aligned_s", NAMELC ## _aligned ## F16ACC ## _cm1_len, NAMELC ## _aligned ## F16ACC ## _cm1_data, "main", PARAMCOUNT, sizeof(PUSHCONST), s_ ## WG_DENOMS, s_ ## WARPTILE, s_align, false, true); \
// Create 2 variants, {f16,f32} accumulator
#define CREATE_MM2(TYPE, PIPELINE_NAME, NAMELC, WG_DENOMS, WARPTILE, PUSHCONST, PARAMCOUNT, ID) \
@@ -4352,32 +4284,32 @@ static void ggml_vk_load_shaders(vk_device& device, vk_pipeline requested) {
// Selects dot2 SPIR-V variant at runtime when device->dot2_f16 is true
#define CREATE_MM(TYPE, PIPELINE_NAME, NAMELC, F16ACC, WG_DENOMS, WARPTILE, PUSHCONST, PARAMCOUNT, ID, REQSUBGROUPSIZE) \
if (device->mul_mat ## ID ## _l[TYPE]) \
ggml_vk_create_pipeline(device, device-> PIPELINE_NAME ->l, #NAMELC #F16ACC "_l", (device->dot2_f16 ? NAMELC ## _dot2 ## F16ACC ## _len : NAMELC ## F16ACC ## _len), (device->dot2_f16 ? NAMELC ## _dot2 ## F16ACC ## _data : NAMELC ## F16ACC ## _data), "main", PARAMCOUNT, sizeof(PUSHCONST), l_ ## WG_DENOMS, ggml_vk_mul_mm_spec(l_ ## WARPTILE, false), 1, false, REQSUBGROUPSIZE > 0, REQSUBGROUPSIZE); \
ggml_vk_create_pipeline(device, device-> PIPELINE_NAME ->l, #NAMELC #F16ACC "_l", (device->dot2_f16 ? NAMELC ## _dot2 ## F16ACC ## _len : NAMELC ## F16ACC ## _len), (device->dot2_f16 ? NAMELC ## _dot2 ## F16ACC ## _data : NAMELC ## F16ACC ## _data), "main", PARAMCOUNT, sizeof(PUSHCONST), l_ ## WG_DENOMS, l_ ## WARPTILE, 1, false, REQSUBGROUPSIZE > 0, REQSUBGROUPSIZE); \
if (device->mul_mat ## ID ## _m[TYPE]) \
ggml_vk_create_pipeline(device, device-> PIPELINE_NAME ->m, #NAMELC #F16ACC "_m", (device->dot2_f16 ? NAMELC ## _dot2 ## F16ACC ## _len : NAMELC ## F16ACC ## _len), (device->dot2_f16 ? NAMELC ## _dot2 ## F16ACC ## _data : NAMELC ## F16ACC ## _data), "main", PARAMCOUNT, sizeof(PUSHCONST), m_ ## WG_DENOMS, ggml_vk_mul_mm_spec(m_ ## WARPTILE, false), 1, false, REQSUBGROUPSIZE > 0, REQSUBGROUPSIZE); \
ggml_vk_create_pipeline(device, device-> PIPELINE_NAME ->m, #NAMELC #F16ACC "_m", (device->dot2_f16 ? NAMELC ## _dot2 ## F16ACC ## _len : NAMELC ## F16ACC ## _len), (device->dot2_f16 ? NAMELC ## _dot2 ## F16ACC ## _data : NAMELC ## F16ACC ## _data), "main", PARAMCOUNT, sizeof(PUSHCONST), m_ ## WG_DENOMS, m_ ## WARPTILE, 1, false, REQSUBGROUPSIZE > 0, REQSUBGROUPSIZE); \
if (device->mul_mat ## ID ## _s[TYPE]) \
ggml_vk_create_pipeline(device, device-> PIPELINE_NAME ->s, #NAMELC #F16ACC "_s", (device->dot2_f16 ? NAMELC ## _dot2 ## F16ACC ## _len : NAMELC ## F16ACC ## _len), (device->dot2_f16 ? NAMELC ## _dot2 ## F16ACC ## _data : NAMELC ## F16ACC ## _data), "main", PARAMCOUNT, sizeof(PUSHCONST), s_ ## WG_DENOMS, ggml_vk_mul_mm_spec(s_ ## WARPTILE, false), 1, false, REQSUBGROUPSIZE > 0, REQSUBGROUPSIZE); \
ggml_vk_create_pipeline(device, device-> PIPELINE_NAME ->s, #NAMELC #F16ACC "_s", (device->dot2_f16 ? NAMELC ## _dot2 ## F16ACC ## _len : NAMELC ## F16ACC ## _len), (device->dot2_f16 ? NAMELC ## _dot2 ## F16ACC ## _data : NAMELC ## F16ACC ## _data), "main", PARAMCOUNT, sizeof(PUSHCONST), s_ ## WG_DENOMS, s_ ## WARPTILE, 1, false, REQSUBGROUPSIZE > 0, REQSUBGROUPSIZE); \
if (device->mul_mat ## ID ## _l[TYPE]) \
ggml_vk_create_pipeline(device, device-> PIPELINE_NAME ->a_l, #NAMELC #F16ACC "_aligned_l", (device->dot2_f16 ? NAMELC ## _dot2 ## F16ACC ## _len : NAMELC ## F16ACC ## _len), (device->dot2_f16 ? NAMELC ## _dot2 ## F16ACC ## _data : NAMELC ## F16ACC ## _data), "main", PARAMCOUNT, sizeof(PUSHCONST), l_ ## WG_DENOMS, ggml_vk_mul_mm_spec(l_ ## WARPTILE, true), l_align, false, REQSUBGROUPSIZE > 0, REQSUBGROUPSIZE); \
ggml_vk_create_pipeline(device, device-> PIPELINE_NAME ->a_l, #NAMELC #F16ACC "_aligned_l", (device->dot2_f16 ? NAMELC ## _dot2_aligned ## F16ACC ## _len : NAMELC ## _aligned ## F16ACC ## _len), (device->dot2_f16 ? NAMELC ## _dot2_aligned ## F16ACC ## _data : NAMELC ## _aligned ## F16ACC ## _data), "main", PARAMCOUNT, sizeof(PUSHCONST), l_ ## WG_DENOMS, l_ ## WARPTILE, l_align, false, REQSUBGROUPSIZE > 0, REQSUBGROUPSIZE); \
if (device->mul_mat ## ID ## _m[TYPE]) \
ggml_vk_create_pipeline(device, device-> PIPELINE_NAME ->a_m, #NAMELC #F16ACC "_aligned_m", (device->dot2_f16 ? NAMELC ## _dot2 ## F16ACC ## _len : NAMELC ## F16ACC ## _len), (device->dot2_f16 ? NAMELC ## _dot2 ## F16ACC ## _data : NAMELC ## F16ACC ## _data), "main", PARAMCOUNT, sizeof(PUSHCONST), m_ ## WG_DENOMS, ggml_vk_mul_mm_spec(m_ ## WARPTILE, true), m_align, false, REQSUBGROUPSIZE > 0, REQSUBGROUPSIZE); \
ggml_vk_create_pipeline(device, device-> PIPELINE_NAME ->a_m, #NAMELC #F16ACC "_aligned_m", (device->dot2_f16 ? NAMELC ## _dot2_aligned ## F16ACC ## _len : NAMELC ## _aligned ## F16ACC ## _len), (device->dot2_f16 ? NAMELC ## _dot2_aligned ## F16ACC ## _data : NAMELC ## _aligned ## F16ACC ## _data), "main", PARAMCOUNT, sizeof(PUSHCONST), m_ ## WG_DENOMS, m_ ## WARPTILE, m_align, false, REQSUBGROUPSIZE > 0, REQSUBGROUPSIZE); \
if (device->mul_mat ## ID ## _s[TYPE]) \
ggml_vk_create_pipeline(device, device-> PIPELINE_NAME ->a_s, #NAMELC #F16ACC "_aligned_s", (device->dot2_f16 ? NAMELC ## _dot2 ## F16ACC ## _len : NAMELC ## F16ACC ## _len), (device->dot2_f16 ? NAMELC ## _dot2 ## F16ACC ## _data : NAMELC ## F16ACC ## _data), "main", PARAMCOUNT, sizeof(PUSHCONST), s_ ## WG_DENOMS, ggml_vk_mul_mm_spec(s_ ## WARPTILE, true), s_align, false, REQSUBGROUPSIZE > 0, REQSUBGROUPSIZE); \
ggml_vk_create_pipeline(device, device-> PIPELINE_NAME ->a_s, #NAMELC #F16ACC "_aligned_s", (device->dot2_f16 ? NAMELC ## _dot2_aligned ## F16ACC ## _len : NAMELC ## _aligned ## F16ACC ## _len), (device->dot2_f16 ? NAMELC ## _dot2_aligned ## F16ACC ## _data : NAMELC ## _aligned ## F16ACC ## _data), "main", PARAMCOUNT, sizeof(PUSHCONST), s_ ## WG_DENOMS, s_ ## WARPTILE, s_align, false, REQSUBGROUPSIZE > 0, REQSUBGROUPSIZE); \
// bf16 scalar path promotes to f32, no dot2 variant
#define CREATE_MM_NODOT2(TYPE, PIPELINE_NAME, NAMELC, F16ACC, WG_DENOMS, WARPTILE, PUSHCONST, PARAMCOUNT, ID, REQSUBGROUPSIZE) \
if (device->mul_mat ## ID ## _l[TYPE]) \
ggml_vk_create_pipeline(device, device-> PIPELINE_NAME ->l, #NAMELC #F16ACC "_l", NAMELC ## F16ACC ## _len, NAMELC ## F16ACC ## _data, "main", PARAMCOUNT, sizeof(PUSHCONST), l_ ## WG_DENOMS, ggml_vk_mul_mm_spec(l_ ## WARPTILE, false), 1, false, REQSUBGROUPSIZE > 0, REQSUBGROUPSIZE); \
ggml_vk_create_pipeline(device, device-> PIPELINE_NAME ->l, #NAMELC #F16ACC "_l", NAMELC ## F16ACC ## _len, NAMELC ## F16ACC ## _data, "main", PARAMCOUNT, sizeof(PUSHCONST), l_ ## WG_DENOMS, l_ ## WARPTILE, 1, false, REQSUBGROUPSIZE > 0, REQSUBGROUPSIZE); \
if (device->mul_mat ## ID ## _m[TYPE]) \
ggml_vk_create_pipeline(device, device-> PIPELINE_NAME ->m, #NAMELC #F16ACC "_m", NAMELC ## F16ACC ## _len, NAMELC ## F16ACC ## _data, "main", PARAMCOUNT, sizeof(PUSHCONST), m_ ## WG_DENOMS, ggml_vk_mul_mm_spec(m_ ## WARPTILE, false), 1, false, REQSUBGROUPSIZE > 0, REQSUBGROUPSIZE); \
ggml_vk_create_pipeline(device, device-> PIPELINE_NAME ->m, #NAMELC #F16ACC "_m", NAMELC ## F16ACC ## _len, NAMELC ## F16ACC ## _data, "main", PARAMCOUNT, sizeof(PUSHCONST), m_ ## WG_DENOMS, m_ ## WARPTILE, 1, false, REQSUBGROUPSIZE > 0, REQSUBGROUPSIZE); \
if (device->mul_mat ## ID ## _s[TYPE]) \
ggml_vk_create_pipeline(device, device-> PIPELINE_NAME ->s, #NAMELC #F16ACC "_s", NAMELC ## F16ACC ## _len, NAMELC ## F16ACC ## _data, "main", PARAMCOUNT, sizeof(PUSHCONST), s_ ## WG_DENOMS, ggml_vk_mul_mm_spec(s_ ## WARPTILE, false), 1, false, REQSUBGROUPSIZE > 0, REQSUBGROUPSIZE); \
ggml_vk_create_pipeline(device, device-> PIPELINE_NAME ->s, #NAMELC #F16ACC "_s", NAMELC ## F16ACC ## _len, NAMELC ## F16ACC ## _data, "main", PARAMCOUNT, sizeof(PUSHCONST), s_ ## WG_DENOMS, s_ ## WARPTILE, 1, false, REQSUBGROUPSIZE > 0, REQSUBGROUPSIZE); \
if (device->mul_mat ## ID ## _l[TYPE]) \
ggml_vk_create_pipeline(device, device-> PIPELINE_NAME ->a_l, #NAMELC #F16ACC "_aligned_l", NAMELC ## F16ACC ## _len, NAMELC ## F16ACC ## _data, "main", PARAMCOUNT, sizeof(PUSHCONST), l_ ## WG_DENOMS, ggml_vk_mul_mm_spec(l_ ## WARPTILE, true), l_align, false, REQSUBGROUPSIZE > 0, REQSUBGROUPSIZE); \
ggml_vk_create_pipeline(device, device-> PIPELINE_NAME ->a_l, #NAMELC #F16ACC "_aligned_l", NAMELC ## _aligned ## F16ACC ## _len, NAMELC ## _aligned ## F16ACC ## _data, "main", PARAMCOUNT, sizeof(PUSHCONST), l_ ## WG_DENOMS, l_ ## WARPTILE, l_align, false, REQSUBGROUPSIZE > 0, REQSUBGROUPSIZE); \
if (device->mul_mat ## ID ## _m[TYPE]) \
ggml_vk_create_pipeline(device, device-> PIPELINE_NAME ->a_m, #NAMELC #F16ACC "_aligned_m", NAMELC ## F16ACC ## _len, NAMELC ## F16ACC ## _data, "main", PARAMCOUNT, sizeof(PUSHCONST), m_ ## WG_DENOMS, ggml_vk_mul_mm_spec(m_ ## WARPTILE, true), m_align, false, REQSUBGROUPSIZE > 0, REQSUBGROUPSIZE); \
ggml_vk_create_pipeline(device, device-> PIPELINE_NAME ->a_m, #NAMELC #F16ACC "_aligned_m", NAMELC ## _aligned ## F16ACC ## _len, NAMELC ## _aligned ## F16ACC ## _data, "main", PARAMCOUNT, sizeof(PUSHCONST), m_ ## WG_DENOMS, m_ ## WARPTILE, m_align, false, REQSUBGROUPSIZE > 0, REQSUBGROUPSIZE); \
if (device->mul_mat ## ID ## _s[TYPE]) \
ggml_vk_create_pipeline(device, device-> PIPELINE_NAME ->a_s, #NAMELC #F16ACC "_aligned_s", NAMELC ## F16ACC ## _len, NAMELC ## F16ACC ## _data, "main", PARAMCOUNT, sizeof(PUSHCONST), s_ ## WG_DENOMS, ggml_vk_mul_mm_spec(s_ ## WARPTILE, true), s_align, false, REQSUBGROUPSIZE > 0, REQSUBGROUPSIZE); \
ggml_vk_create_pipeline(device, device-> PIPELINE_NAME ->a_s, #NAMELC #F16ACC "_aligned_s", NAMELC ## _aligned ## F16ACC ## _len, NAMELC ## _aligned ## F16ACC ## _data, "main", PARAMCOUNT, sizeof(PUSHCONST), s_ ## WG_DENOMS, s_ ## WARPTILE, s_align, false, REQSUBGROUPSIZE > 0, REQSUBGROUPSIZE); \
#define CREATE_MMQ(TYPE, PIPELINE_NAME, NAMELC, WG_DENOMS, WARPTILE, PUSHCONST, PARAMCOUNT, ID, REQSUBGROUPSIZE) \
if (device->mul_mat ## ID ## _l_int[TYPE]) { \
@@ -4542,17 +4474,17 @@ static void ggml_vk_load_shaders(vk_device& device, vk_pipeline requested) {
// Create 6 variants, {s,m,l}x{unaligned,aligned}
#define CREATE_MM(TYPE, PIPELINE_NAME, NAMELC, F16ACC, WG_DENOMS, WARPTILE, PUSHCONST, PARAMCOUNT, ID, REQSUBGROUPSIZE) \
if (device->mul_mat ## ID ## _l[TYPE]) \
ggml_vk_create_pipeline(device, device-> PIPELINE_NAME ->l, #NAMELC #F16ACC "_l", NAMELC ## F16ACC ## _fp32_len, NAMELC ## F16ACC ## _fp32_data, "main", PARAMCOUNT, sizeof(PUSHCONST), l_ ## WG_DENOMS, ggml_vk_mul_mm_spec(l_ ## WARPTILE, false), 1, false, REQSUBGROUPSIZE > 0, REQSUBGROUPSIZE); \
ggml_vk_create_pipeline(device, device-> PIPELINE_NAME ->l, #NAMELC #F16ACC "_l", NAMELC ## F16ACC ## _fp32_len, NAMELC ## F16ACC ## _fp32_data, "main", PARAMCOUNT, sizeof(PUSHCONST), l_ ## WG_DENOMS, l_ ## WARPTILE, 1, false, REQSUBGROUPSIZE > 0, REQSUBGROUPSIZE); \
if (device->mul_mat ## ID ## _m[TYPE]) \
ggml_vk_create_pipeline(device, device-> PIPELINE_NAME ->m, #NAMELC #F16ACC "_m", NAMELC ## F16ACC ## _fp32_len, NAMELC ## F16ACC ## _fp32_data, "main", PARAMCOUNT, sizeof(PUSHCONST), m_ ## WG_DENOMS, ggml_vk_mul_mm_spec(m_ ## WARPTILE, false), 1, false, REQSUBGROUPSIZE > 0, REQSUBGROUPSIZE); \
ggml_vk_create_pipeline(device, device-> PIPELINE_NAME ->m, #NAMELC #F16ACC "_m", NAMELC ## F16ACC ## _fp32_len, NAMELC ## F16ACC ## _fp32_data, "main", PARAMCOUNT, sizeof(PUSHCONST), m_ ## WG_DENOMS, m_ ## WARPTILE, 1, false, REQSUBGROUPSIZE > 0, REQSUBGROUPSIZE); \
if (device->mul_mat ## ID ## _s[TYPE]) \
ggml_vk_create_pipeline(device, device-> PIPELINE_NAME ->s, #NAMELC #F16ACC "_s", NAMELC ## F16ACC ## _fp32_len, NAMELC ## F16ACC ## _fp32_data, "main", PARAMCOUNT, sizeof(PUSHCONST), s_ ## WG_DENOMS, ggml_vk_mul_mm_spec(s_ ## WARPTILE, false), 1, false, REQSUBGROUPSIZE > 0, REQSUBGROUPSIZE); \
ggml_vk_create_pipeline(device, device-> PIPELINE_NAME ->s, #NAMELC #F16ACC "_s", NAMELC ## F16ACC ## _fp32_len, NAMELC ## F16ACC ## _fp32_data, "main", PARAMCOUNT, sizeof(PUSHCONST), s_ ## WG_DENOMS, s_ ## WARPTILE, 1, false, REQSUBGROUPSIZE > 0, REQSUBGROUPSIZE); \
if (device->mul_mat ## ID ## _l[TYPE]) \
ggml_vk_create_pipeline(device, device-> PIPELINE_NAME ->a_l, #NAMELC #F16ACC "_aligned_l", NAMELC ## F16ACC ## _fp32_len, NAMELC ## F16ACC ## _fp32_data, "main", PARAMCOUNT, sizeof(PUSHCONST), l_ ## WG_DENOMS, ggml_vk_mul_mm_spec(l_ ## WARPTILE, true), l_align, false, REQSUBGROUPSIZE > 0, REQSUBGROUPSIZE); \
ggml_vk_create_pipeline(device, device-> PIPELINE_NAME ->a_l, #NAMELC #F16ACC "_aligned_l", NAMELC ## _aligned ## F16ACC ## _fp32_len, NAMELC ## _aligned ## F16ACC ## _fp32_data, "main", PARAMCOUNT, sizeof(PUSHCONST), l_ ## WG_DENOMS, l_ ## WARPTILE, l_align, false, REQSUBGROUPSIZE > 0, REQSUBGROUPSIZE); \
if (device->mul_mat ## ID ## _m[TYPE]) \
ggml_vk_create_pipeline(device, device-> PIPELINE_NAME ->a_m, #NAMELC #F16ACC "_aligned_m", NAMELC ## F16ACC ## _fp32_len, NAMELC ## F16ACC ## _fp32_data, "main", PARAMCOUNT, sizeof(PUSHCONST), m_ ## WG_DENOMS, ggml_vk_mul_mm_spec(m_ ## WARPTILE, true), m_align, false, REQSUBGROUPSIZE > 0, REQSUBGROUPSIZE); \
ggml_vk_create_pipeline(device, device-> PIPELINE_NAME ->a_m, #NAMELC #F16ACC "_aligned_m", NAMELC ## _aligned ## F16ACC ## _fp32_len, NAMELC ## _aligned ## F16ACC ## _fp32_data, "main", PARAMCOUNT, sizeof(PUSHCONST), m_ ## WG_DENOMS, m_ ## WARPTILE, m_align, false, REQSUBGROUPSIZE > 0, REQSUBGROUPSIZE); \
if (device->mul_mat ## ID ## _s[TYPE]) \
ggml_vk_create_pipeline(device, device-> PIPELINE_NAME ->a_s, #NAMELC #F16ACC "_aligned_s", NAMELC ## F16ACC ## _fp32_len, NAMELC ## F16ACC ## _fp32_data, "main", PARAMCOUNT, sizeof(PUSHCONST), s_ ## WG_DENOMS, ggml_vk_mul_mm_spec(s_ ## WARPTILE, true), s_align, false, REQSUBGROUPSIZE > 0, REQSUBGROUPSIZE); \
ggml_vk_create_pipeline(device, device-> PIPELINE_NAME ->a_s, #NAMELC #F16ACC "_aligned_s", NAMELC ## _aligned ## F16ACC ## _fp32_len, NAMELC ## _aligned ## F16ACC ## _fp32_data, "main", PARAMCOUNT, sizeof(PUSHCONST), s_ ## WG_DENOMS, s_ ## WARPTILE, s_align, false, REQSUBGROUPSIZE > 0, REQSUBGROUPSIZE); \
#define CREATE_MMQ(TYPE, PIPELINE_NAME, NAMELC, WG_DENOMS, WARPTILE, PUSHCONST, PARAMCOUNT, ID) \
if (device->mul_mat ## ID ## _l_int[TYPE]) \
@@ -4947,7 +4879,6 @@ static void ggml_vk_load_shaders(vk_device& device, vk_pipeline requested) {
ggml_vk_create_pipeline(device, device->pipeline_get_rows_f32[GGML_TYPE_IQ4_NL], "get_rows_iq4_nl_f32", get_rows_iq4_nl_f32_len, get_rows_iq4_nl_f32_data, "main", 3, sizeof(vk_op_binary_push_constants), {1024, 1, 1}, {}, 1);
ggml_vk_create_pipeline(device, device->pipeline_get_rows_f32[GGML_TYPE_MXFP4], "get_rows_mxfp4_f32", get_rows_mxfp4_f32_len, get_rows_mxfp4_f32_data, "main", 3, sizeof(vk_op_binary_push_constants), {1024, 1, 1}, {}, 1);
ggml_vk_create_pipeline(device, device->pipeline_get_rows_f32[GGML_TYPE_NVFP4], "get_rows_nvfp4_f32", get_rows_nvfp4_f32_len, get_rows_nvfp4_f32_data, "main", 3, sizeof(vk_op_binary_push_constants), {1024, 1, 1}, {}, 1);
ggml_vk_create_pipeline(device, device->pipeline_get_rows_back_f32, "get_rows_back_f32", get_rows_back_f32_len, get_rows_back_f32_data, "main", 3, sizeof(vk_op_binary_push_constants), {256, 1, 1}, {}, 1, true);
ggml_vk_create_pipeline(device, device->pipeline_matmul_split_k_reduce, "split_k_reduce", split_k_reduce_len, split_k_reduce_data, "main", 2, 2 * sizeof(uint32_t), {256 * 4, 1, 1}, {}, 1);
ggml_vk_create_pipeline(device, device->pipeline_flash_attn_split_k_reduce, "fa_split_k_reduce", fa_split_k_reduce_len, fa_split_k_reduce_data, "main", 3, sizeof(vk_op_flash_attn_split_k_reduce_push_constants), {1, device->subgroup_size, 1}, {device->subgroup_size}, 1, true);
@@ -4972,7 +4903,7 @@ static void ggml_vk_load_shaders(vk_device& device, vk_pipeline requested) {
}
ggml_vk_create_pipeline(device, device->pipeline_mul_mat_vec_nc_f16_f32, "mul_mat_vec_nc_f16_f32", mul_mat_vec_nc_f16_f32_len, mul_mat_vec_nc_f16_f32_data, "main", mul_mat_vec_num_bindings, sizeof(vk_mat_vec_nc_push_constants), {1, 1, 1}, {}, 1);
ggml_vk_create_pipeline(device, device->pipeline_norm_f32, "norm_f32", norm_f32_len, norm_f32_data, "main", 2, sizeof(vk_op_unary_push_constants), {1, 1, 1}, {}, 1);
ggml_vk_create_pipeline(device, device->pipeline_norm_f32, "norm_f32", norm_f32_len, norm_f32_data, "main", 2, sizeof(vk_op_push_constants), {1, 1, 1}, {}, 1);
ggml_vk_create_pipeline(device, device->pipeline_group_norm_f32, "group_norm_f32", group_norm_f32_len, group_norm_f32_data, "main", 2, sizeof(vk_op_push_constants), {1, 1, 1}, {}, 1);
ggml_vk_create_pipeline(device, device->pipeline_rms_norm_f32, "rms_norm_f32", rms_norm_f32_len, rms_norm_f32_data, "main", 4, sizeof(vk_op_binary_push_constants), {1, 1, 1}, {0, 0}, 1, true);
@@ -5092,6 +5023,11 @@ static void ggml_vk_load_shaders(vk_device& device, vk_pipeline requested) {
ggml_vk_create_pipeline(device, device->pipeline_scale_f32, "scale_f32", scale_f32_len, scale_f32_data, "main", 2, sizeof(vk_op_unary_push_constants), {512, 1, 1}, {}, 1);
ggml_vk_create_pipeline(device, device->pipeline_sqr_f32, "sqr_f32", sqr_f32_len, sqr_f32_data, "main", 2, sizeof(vk_op_unary_push_constants), {512, 1, 1}, {}, 1);
ggml_vk_create_pipeline(device, device->pipeline_sqrt_f32, "sqrt_f32", sqrt_f32_len, sqrt_f32_data, "main", 2, sizeof(vk_op_unary_push_constants), {512, 1, 1}, {}, 1);
ggml_vk_create_pipeline(device, device->pipeline_sin_f32, "sin_f32", sin_f32_len, sin_f32_data, "main", 2, sizeof(vk_op_unary_push_constants), {512, 1, 1}, {}, 1);
ggml_vk_create_pipeline(device, device->pipeline_cos_f32, "cos_f32", cos_f32_len, cos_f32_data, "main", 2, sizeof(vk_op_unary_push_constants), {512, 1, 1}, {}, 1);
ggml_vk_create_pipeline(device, device->pipeline_log[0], "log_f32", log_f32_len, log_f32_data, "main", 2, sizeof(vk_op_unary_push_constants), {512, 1, 1}, {}, 1);
ggml_vk_create_pipeline(device, device->pipeline_log[1], "log_f16", log_f16_len, log_f16_data, "main", 2, sizeof(vk_op_unary_push_constants), {512, 1, 1}, {}, 1);
@@ -5101,6 +5037,8 @@ static void ggml_vk_load_shaders(vk_device& device, vk_pipeline requested) {
ggml_vk_create_pipeline(device, device->pipeline_diag[0], "diag_f32", diag_f32_len, diag_f32_data, "main", 2, sizeof(vk_op_unary_push_constants), {512, 1, 1}, {}, 1);
ggml_vk_create_pipeline(device, device->pipeline_diag[1], "diag_f16", diag_f16_len, diag_f16_data, "main", 2, sizeof(vk_op_unary_push_constants), {512, 1, 1}, {}, 1);
ggml_vk_create_pipeline(device, device->pipeline_clamp_f32, "clamp_f32", clamp_f32_len, clamp_f32_data, "main", 2, sizeof(vk_op_unary_push_constants), {512, 1, 1}, {}, 1);
ggml_vk_create_pipeline(device, device->pipeline_pad_f32, "pad_f32", pad_f32_len, pad_f32_data, "main", 2, sizeof(vk_op_pad_push_constants), {512, 1, 1}, {}, 1);
ggml_vk_create_pipeline(device, device->pipeline_roll_f32, "roll_f32", roll_f32_len, roll_f32_data, "main", 2, sizeof(vk_op_unary_push_constants), {512, 1, 1}, {}, 1);
@@ -5120,12 +5058,6 @@ static void ggml_vk_load_shaders(vk_device& device, vk_pipeline requested) {
CREATE_UNARY(gelu_quick)
CREATE_UNARY(silu)
CREATE_UNARY(relu)
CREATE_UNARY(sqr)
CREATE_UNARY(sqrt)
CREATE_UNARY(sin)
CREATE_UNARY(cos)
CREATE_UNARY(clamp)
CREATE_UNARY(leaky_relu)
CREATE_UNARY(xielu)
CREATE_UNARY(neg)
CREATE_UNARY(tanh)
@@ -5165,6 +5097,7 @@ static void ggml_vk_load_shaders(vk_device& device, vk_pipeline requested) {
CREATE_GLU(geglu_quick)
#undef CREATE_GLU
ggml_vk_create_pipeline(device, device->pipeline_leaky_relu_f32, "leaky_relu_f32", leaky_relu_f32_len, leaky_relu_f32_data, "main", 2, sizeof(vk_op_push_constants), {512, 1, 1}, {}, 1);
ggml_vk_create_pipeline(device, device->pipeline_silu_back_f32, "silu_back_f32", silu_back_f32_len, silu_back_f32_data, "main", 3, sizeof(vk_op_push_constants), {512, 1, 1}, {}, 1);
ggml_vk_create_pipeline(device, device->pipeline_diag_mask_inf_f32, "diag_mask_inf_f32", diag_mask_inf_f32_len, diag_mask_inf_f32_data, "main", 2, sizeof(vk_op_diag_mask_push_constants), {1, 512, 1}, {}, 1, true);
@@ -5381,7 +5314,7 @@ static void ggml_vk_load_shaders(vk_device& device, vk_pipeline requested) {
ggml_vk_create_pipeline(device, device->pipeline_opt_step_sgd_f32, "opt_step_sgd_f32", opt_step_sgd_f32_len, opt_step_sgd_f32_data, "main", 3, sizeof(vk_op_push_constants), {512, 1, 1}, {}, 1);
// conv2d, conv_transpose_2d, conv3d
// conv2d, conv_transpose_2d
for (uint32_t s = 0; s < CONV_SHAPE_COUNT; ++s) {
// smaller WG for the small-tile fallback gives more concurrent WGs per SM
uint32_t conv2d_WG_SIZE = (s == CONV_SHAPE_64x32) ? 128 : 256;
@@ -5444,8 +5377,8 @@ static void ggml_vk_load_shaders(vk_device& device, vk_pipeline requested) {
return (conv2d_BS.K * (conv2d_BS.CRS + pad) + conv2d_BS.CRS * (conv2d_BS.NPQ + pad) + csh_elems) * elem_size;
};
// 2D, transpose-2D, and 3D conv use the same KxCRS @ CRSxNPQ shmem
// layout. cm1 needs Csh for output, so check before applying cm1 params.
// coopmat1 needs to store the output through shared memory, so check up front
// whether it'll fit and disable it before applying coopmat1 parameters.
if (conv2d_use_cm1 && device->properties.limits.maxComputeSharedMemorySize < shmem_req(conv2d_cm1_shmem_pad, true, true)) {
conv2d_use_cm1 = false;
}
@@ -5537,53 +5470,6 @@ static void ggml_vk_load_shaders(vk_device& device, vk_pipeline requested) {
}
#undef CREATE_CONV
#undef CREATE_CONVS
std::vector<uint32_t> conv3d_spec_constants = { conv2d_WG_SIZE, conv2d_BS.K, conv2d_BS.CRS, conv2d_BS.NPQ, conv2d_TS_K, conv2d_SHMEM_PAD };
#define CREATE_CONV3D(type_suffix, spv_suffix) \
for (auto &c : device->pipeline_conv3d##type_suffix[s]) { \
const vk_conv3d_pipeline_state &state = c.first; \
std::vector<uint32_t> spec_constants_cpy = conv3d_spec_constants; \
spec_constants_cpy.push_back(state.s0); \
spec_constants_cpy.push_back(state.s1); \
spec_constants_cpy.push_back(state.s2); \
spec_constants_cpy.push_back(state.p0); \
spec_constants_cpy.push_back(state.p1); \
spec_constants_cpy.push_back(state.p2); \
spec_constants_cpy.push_back(state.d0); \
spec_constants_cpy.push_back(state.d1); \
spec_constants_cpy.push_back(state.d2); \
spec_constants_cpy.push_back(state.KW); \
spec_constants_cpy.push_back(state.KH); \
spec_constants_cpy.push_back(state.KD); \
spec_constants_cpy.push_back(state.aligned); \
spec_constants_cpy.push_back(conv2d_csh_store); \
spec_constants_cpy.push_back(conv2d_WM); \
spec_constants_cpy.push_back(conv2d_WN); \
ggml_vk_create_pipeline( \
device, c.second, "conv3d" #type_suffix, \
conv3d##type_suffix##spv_suffix##_len, conv3d##type_suffix##spv_suffix##_data, "main", 3, \
sizeof(vk_op_conv3d_push_constants), wg_denoms, spec_constants_cpy, 1, true, conv2d_required_subgroup_size != 0, conv2d_required_subgroup_size); \
}
#if defined(GGML_VULKAN_COOPMAT2_GLSLC_SUPPORT)
if (device->coopmat2) {
CREATE_CONV3D(_f32, _cm2)
CREATE_CONV3D(_f16_f32, _cm2)
} else
#endif
#if defined(VK_KHR_cooperative_matrix) && defined(GGML_VULKAN_COOPMAT_GLSLC_SUPPORT)
if (conv2d_use_cm1) {
CREATE_CONV3D(_f32, _cm1)
CREATE_CONV3D(_f16_f32, _cm1)
} else
#endif
if (conv2d_UNROLL) {
CREATE_CONV3D(_f32, _unroll)
CREATE_CONV3D(_f16_f32, _unroll)
} else {
CREATE_CONV3D(_f32, )
CREATE_CONV3D(_f16_f32, )
}
#undef CREATE_CONV3D
}
ggml_vk_create_pipeline(device, device->pipeline_conv2d_dw_whcn_f32, "conv2d_dw_whcn_f32", conv2d_dw_whcn_f32_len, conv2d_dw_whcn_f32_data, "main", 3, sizeof(vk_op_conv2d_dw_push_constants), {512, 1, 1}, {}, 1);
@@ -10408,11 +10294,6 @@ static vk_pipeline ggml_vk_op_get_pipeline(ggml_backend_vk_context * ctx, const
return ctx->device->pipeline_get_rows_f32[src0->type];
}
return nullptr;
case GGML_OP_GET_ROWS_BACK:
if (src0->type == GGML_TYPE_F32 && src1->type == GGML_TYPE_I32 && dst->type == GGML_TYPE_F32) {
return ctx->device->pipeline_get_rows_back_f32;
}
return nullptr;
case GGML_OP_ACC:
if (src0->type == GGML_TYPE_F32 && src1->type == GGML_TYPE_F32 && dst->type == GGML_TYPE_F32) {
return ctx->device->pipeline_acc_f32;
@@ -10519,27 +10400,23 @@ static vk_pipeline ggml_vk_op_get_pipeline(ggml_backend_vk_context * ctx, const
}
return nullptr;
case GGML_OP_SQR:
if (src0->type == dst->type &&
(src0->type == GGML_TYPE_F32 || src0->type == GGML_TYPE_F16)) {
return ctx->device->pipeline_sqr[dst->type == GGML_TYPE_F16];
if (src0->type == GGML_TYPE_F32 && dst->type == GGML_TYPE_F32) {
return ctx->device->pipeline_sqr_f32;
}
return nullptr;
case GGML_OP_SQRT:
if (src0->type == dst->type &&
(src0->type == GGML_TYPE_F32 || src0->type == GGML_TYPE_F16)) {
return ctx->device->pipeline_sqrt[dst->type == GGML_TYPE_F16];
if (src0->type == GGML_TYPE_F32 && dst->type == GGML_TYPE_F32) {
return ctx->device->pipeline_sqrt_f32;
}
return nullptr;
case GGML_OP_SIN:
if (src0->type == dst->type &&
(src0->type == GGML_TYPE_F32 || src0->type == GGML_TYPE_F16)) {
return ctx->device->pipeline_sin[dst->type == GGML_TYPE_F16];
if (src0->type == GGML_TYPE_F32 && dst->type == GGML_TYPE_F32) {
return ctx->device->pipeline_sin_f32;
}
return nullptr;
case GGML_OP_COS:
if (src0->type == dst->type &&
(src0->type == GGML_TYPE_F32 || src0->type == GGML_TYPE_F16)) {
return ctx->device->pipeline_cos[dst->type == GGML_TYPE_F16];
if (src0->type == GGML_TYPE_F32 && dst->type == GGML_TYPE_F32) {
return ctx->device->pipeline_cos_f32;
}
return nullptr;
case GGML_OP_LOG:
@@ -10561,9 +10438,8 @@ static vk_pipeline ggml_vk_op_get_pipeline(ggml_backend_vk_context * ctx, const
}
return nullptr;
case GGML_OP_CLAMP:
if (src0->type == dst->type &&
(src0->type == GGML_TYPE_F32 || src0->type == GGML_TYPE_F16)) {
return ctx->device->pipeline_clamp[dst->type == GGML_TYPE_F16];
if (src0->type == GGML_TYPE_F32 && dst->type == GGML_TYPE_F32) {
return ctx->device->pipeline_clamp_f32;
}
return nullptr;
case GGML_OP_PAD:
@@ -10931,9 +10807,8 @@ static vk_pipeline ggml_vk_op_get_pipeline(ggml_backend_vk_context * ctx, const
}
return nullptr;
case GGML_OP_LEAKY_RELU:
if (src0->type == dst->type &&
(src0->type == GGML_TYPE_F32 || src0->type == GGML_TYPE_F16)) {
return ctx->device->pipeline_leaky_relu[dst->type == GGML_TYPE_F16];
if (src0->type == GGML_TYPE_F32 && dst->type == GGML_TYPE_F32) {
return ctx->device->pipeline_leaky_relu_f32;
}
return nullptr;
case GGML_OP_CONV_2D:
@@ -11010,61 +10885,6 @@ static vk_pipeline ggml_vk_op_get_pipeline(ggml_backend_vk_context * ctx, const
}
}
return nullptr;
case GGML_OP_CONV_3D:
if (src1->type == GGML_TYPE_F32 && dst->type == GGML_TYPE_F32) {
const uint32_t OC = (uint32_t)ggml_get_op_params_i32(dst, 11);
const uint32_t IC = (uint32_t)ggml_get_op_params_i32(dst, 9);
const uint32_t N = (uint32_t)ggml_get_op_params_i32(dst, 10);
const uint32_t NPQ = N * dst->ne[2] * dst->ne[1] * dst->ne[0];
const vk_conv_shapes shape = ggml_vk_conv_select_shape(ctx, OC, NPQ);
const uint32_t KW = (uint32_t)src0->ne[0];
const uint32_t KH = (uint32_t)src0->ne[1];
const uint32_t KD = (uint32_t)src0->ne[2];
const uint32_t s0 = (uint32_t)ggml_get_op_params_i32(dst, 0);
const uint32_t s1 = (uint32_t)ggml_get_op_params_i32(dst, 1);
const uint32_t s2 = (uint32_t)ggml_get_op_params_i32(dst, 2);
const uint32_t p0 = (uint32_t)ggml_get_op_params_i32(dst, 3);
const uint32_t p1 = (uint32_t)ggml_get_op_params_i32(dst, 4);
const uint32_t p2 = (uint32_t)ggml_get_op_params_i32(dst, 5);
const uint32_t d0 = (uint32_t)ggml_get_op_params_i32(dst, 6);
const uint32_t d1 = (uint32_t)ggml_get_op_params_i32(dst, 7);
const uint32_t d2 = (uint32_t)ggml_get_op_params_i32(dst, 8);
const uint32_t CRS = IC * KW * KH * KD;
const uint32_t BS_K = vk_conv_block_sizes[shape].K;
const uint32_t BS_CRS = vk_conv_block_sizes[shape].CRS;
const uint32_t BS_NPQ = vk_conv_block_sizes[shape].NPQ;
const uint32_t aligned = ((OC % BS_K == 0) &&
(CRS % BS_CRS == 0) &&
(NPQ % BS_NPQ == 0)) ? 1u : 0u;
vk_conv3d_pipeline_state conv3d_pipeline_state(s0, s1, s2, p0, p1, p2, d0, d1, d2, KW, KH, KD, aligned);
std::map<vk_conv3d_pipeline_state, vk_pipeline> *pipelines = nullptr;
if (src0->type == GGML_TYPE_F32) {
pipelines = &ctx->device->pipeline_conv3d_f32[shape];
} else if (src0->type == GGML_TYPE_F16) {
pipelines = &ctx->device->pipeline_conv3d_f16_f32[shape];
} else {
return nullptr;
}
vk_pipeline pipeline = nullptr;
{
std::lock_guard<std::mutex> guard(ctx->device->compile_mutex);
auto it = pipelines->find(conv3d_pipeline_state);
if (it != pipelines->end()) {
pipeline = it->second;
} else {
(*pipelines)[conv3d_pipeline_state] = pipeline = std::make_shared<vk_pipeline_struct>();
}
}
return pipeline;
}
return nullptr;
case GGML_OP_ADD1:
if (src0->type == GGML_TYPE_F16 && src1->type == GGML_TYPE_F16 && dst->type == GGML_TYPE_F16) {
return ctx->device->pipeline_add1_f16_f16;
@@ -11315,10 +11135,6 @@ static void ggml_vk_op_f32(ggml_backend_vk_context * ctx, vk_context& subctx, co
elements[1] = std::min(elements[1], ctx->device->properties.limits.maxComputeWorkGroupCount[1]);
elements[2] = std::min(elements[2], ctx->device->properties.limits.maxComputeWorkGroupCount[2]);
break;
case GGML_OP_GET_ROWS_BACK:
elements = { (uint32_t)dst->ne[0], (uint32_t)dst->ne[1], 1 };
elements[1] = std::min(elements[1], ctx->device->properties.limits.maxComputeWorkGroupCount[1]);
break;
case GGML_OP_ARGSORT:
GGML_ASSERT(0);
break;
@@ -11404,21 +11220,6 @@ static void ggml_vk_op_f32(ggml_backend_vk_context * ctx, vk_context& subctx, co
GGML_ABORT("invalid push constant type for CONV_2D");
}
break;
case GGML_OP_CONV_3D:
if constexpr (std::is_same_v<PC, vk_op_conv3d_push_constants>) {
const uint32_t NPQ = pc.N * pc.OD * pc.OH * pc.OW;
const vk_conv_shapes shape = ggml_vk_conv_select_shape(ctx, pc.OC, NPQ);
const uint32_t NPQ_blocks = CEIL_DIV(NPQ, vk_conv_block_sizes[shape].NPQ);
elements = { pc.OC, NPQ_blocks, 1 };
if (elements[1] > 512) {
elements[2] = CEIL_DIV(elements[1], 512);
elements[1] = 512;
}
} else {
GGML_ABORT("invalid push constant type for CONV_3D");
}
break;
case GGML_OP_ADD:
case GGML_OP_SUB:
case GGML_OP_DIV:
@@ -11435,7 +11236,6 @@ static void ggml_vk_op_f32(ggml_backend_vk_context * ctx, vk_context& subctx, co
case GGML_OP_TRI:
case GGML_OP_DIAG:
case GGML_OP_CLAMP:
case GGML_OP_LEAKY_RELU:
case GGML_OP_PAD:
case GGML_OP_ROLL:
case GGML_OP_REPEAT:
@@ -11580,21 +11380,6 @@ static void ggml_vk_get_rows(ggml_backend_vk_context * ctx, vk_context& subctx,
});
}
static void ggml_vk_get_rows_back(ggml_backend_vk_context * ctx, vk_context& subctx, const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst) {
const uint32_t src0_type_size = ggml_type_size(src0->type);
const uint32_t src1_type_size = ggml_type_size(src1->type);
const uint32_t dst_type_size = ggml_type_size(dst->type);
ggml_vk_op_f32<vk_op_binary_push_constants>(ctx, subctx, src0, src1, nullptr, nullptr, dst, GGML_OP_GET_ROWS_BACK, {
(uint32_t)ggml_nelements(src0),
(uint32_t)src0->ne[0], (uint32_t)src0->ne[1], (uint32_t)src0->ne[2], (uint32_t)src0->ne[3], (uint32_t)src0->nb[0] / src0_type_size, (uint32_t)src0->nb[1] / src0_type_size, (uint32_t)src0->nb[2] / src0_type_size, (uint32_t)src0->nb[3] / src0_type_size,
(uint32_t)src1->ne[0], (uint32_t)src1->ne[1], (uint32_t)src1->ne[2], (uint32_t)src1->ne[3], (uint32_t)src1->nb[0] / src1_type_size, (uint32_t)src1->nb[1] / src1_type_size, (uint32_t)src1->nb[2] / src1_type_size, (uint32_t)src1->nb[3] / src1_type_size,
(uint32_t) dst->ne[0], (uint32_t) dst->ne[1], (uint32_t) dst->ne[2], (uint32_t) dst->ne[3], (uint32_t) dst->nb[0] / dst_type_size, (uint32_t) dst->nb[1] / dst_type_size, (uint32_t) dst->nb[2] / dst_type_size, (uint32_t) dst->nb[3] / dst_type_size,
0,
0.0f, 0.0f, 0,
});
}
static void ggml_vk_acc(ggml_backend_vk_context * ctx, vk_context& subctx, const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst) {
const uint32_t src0_type_size = ggml_type_size(src0->type);
const uint32_t src1_type_size = ggml_type_size(src1->type);
@@ -12302,10 +12087,8 @@ static void ggml_vk_silu_back(ggml_backend_vk_context * ctx, vk_context& subctx,
static void ggml_vk_norm(ggml_backend_vk_context * ctx, vk_context& subctx, const ggml_tensor * src0, ggml_tensor * dst) {
float * op_params = (float *)dst->op_params;
vk_op_unary_push_constants p = vk_op_unary_push_constants_init(src0, dst);
p.param1 = op_params[0];
ggml_vk_op_f32(ctx, subctx, src0, nullptr, nullptr, nullptr, dst, GGML_OP_NORM, std::move(p));
ggml_vk_op_f32<vk_op_push_constants>(ctx, subctx, src0, nullptr, nullptr, nullptr, dst, GGML_OP_NORM, { (uint32_t)src0->ne[0], (uint32_t)src0->ne[1], op_params[0], 0.0f, 0.0f, 0.0f });
}
static void ggml_vk_group_norm(ggml_backend_vk_context * ctx, vk_context& subctx, const ggml_tensor * src0, ggml_tensor * dst) {
@@ -13335,51 +13118,6 @@ static void ggml_vk_conv_2d(ggml_backend_vk_context * ctx, vk_context & subctx,
ggml_vk_op_f32(ctx, subctx, src0, src1, nullptr, nullptr, dst, dst->op, std::move(p));
}
static void ggml_vk_conv_3d(ggml_backend_vk_context * ctx, vk_context & subctx, const ggml_tensor * src0,
const ggml_tensor * src1, ggml_tensor * dst) {
GGML_ASSERT(src0->type == GGML_TYPE_F32 || src0->type == GGML_TYPE_F16);
GGML_ASSERT(src1->type == GGML_TYPE_F32);
GGML_ASSERT(dst->type == GGML_TYPE_F32);
GGML_TENSOR_BINARY_OP_LOCALS
GGML_ASSERT(nb00 == sizeof(float) || nb00 == sizeof(ggml_fp16_t));
GGML_ASSERT(nb10 == sizeof(float));
GGML_ASSERT(nb0 == sizeof(float));
vk_op_conv3d_push_constants p{};
p.IC = static_cast<uint32_t>(ggml_get_op_params_i32(dst, 9));
p.N = static_cast<uint32_t>(ggml_get_op_params_i32(dst, 10));
p.OC = static_cast<uint32_t>(ggml_get_op_params_i32(dst, 11));
GGML_ASSERT(src0->ne[3] == (int64_t)p.IC * p.OC);
GGML_ASSERT(src1->ne[3] == (int64_t)p.IC * p.N);
GGML_ASSERT(dst->ne[3] == (int64_t)p.OC * p.N);
p.IW = static_cast<uint32_t>(ne10);
p.IH = static_cast<uint32_t>(ne11);
p.ID = static_cast<uint32_t>(ne12);
p.OW = static_cast<uint32_t>(ne0);
p.OH = static_cast<uint32_t>(ne1);
p.OD = static_cast<uint32_t>(ne2);
// the shader clamps src addresses to p.IC * p.N * p.IW * p.IH * p.ID - 1 in uint32, so the
// total input element count must fit in a uint32.
GGML_ASSERT((uint64_t)p.IC * p.N * p.IW * p.IH * p.ID <= 0xFFFFFFFFull);
p.nb01 = static_cast<uint32_t>(nb01 / nb00);
p.nb02 = static_cast<uint32_t>(nb02 / nb00);
p.nb03 = static_cast<uint32_t>(nb03 / nb00);
p.nb11 = static_cast<uint32_t>(nb11 / nb10);
p.nb12 = static_cast<uint32_t>(nb12 / nb10);
p.nb13 = static_cast<uint32_t>(nb13 / nb10);
p.nb1 = static_cast<uint32_t>(nb1 / nb0);
p.nb2 = static_cast<uint32_t>(nb2 / nb0);
p.nb3 = static_cast<uint32_t>(nb3 / nb0);
ggml_vk_op_f32(ctx, subctx, src0, src1, nullptr, nullptr, dst, GGML_OP_CONV_3D, std::move(p));
}
static void ggml_vk_conv_2d_dw(ggml_backend_vk_context * ctx, vk_context& subctx, const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst) {
vk_op_conv2d_dw_push_constants p{};
p.ne = ggml_nelements(dst);
@@ -13406,10 +13144,7 @@ static void ggml_vk_conv_2d_dw(ggml_backend_vk_context * ctx, vk_context& subctx
static void ggml_vk_leaky_relu(ggml_backend_vk_context * ctx, vk_context& subctx, const ggml_tensor * src0, ggml_tensor * dst) {
const float * op_params = (const float *)dst->op_params;
vk_op_unary_push_constants p = vk_op_unary_push_constants_init(src0, dst);
p.param1 = op_params[0];
ggml_vk_op_f32(ctx, subctx, src0, nullptr, nullptr, nullptr, dst, GGML_OP_LEAKY_RELU, std::move(p));
ggml_vk_op_f32<vk_op_push_constants>(ctx, subctx, src0, nullptr, nullptr, nullptr, dst, GGML_OP_LEAKY_RELU, { (uint32_t)ggml_nelements(src0), 0, op_params[0], 0.0f, 0.0f, 0.0f });
}
#ifdef GGML_VULKAN_RUN_TESTS
@@ -14512,10 +14247,6 @@ static bool ggml_vk_build_graph(ggml_backend_vk_context * ctx, ggml_cgraph * cgr
case GGML_OP_GET_ROWS:
ggml_vk_get_rows(ctx, compute_ctx, src0, src1, node);
break;
case GGML_OP_GET_ROWS_BACK:
ggml_vk_get_rows_back(ctx, compute_ctx, src0, src1, node);
break;
case GGML_OP_ADD:
if (ctx->num_additional_fused_ops) {
@@ -14784,10 +14515,6 @@ static bool ggml_vk_build_graph(ggml_backend_vk_context * ctx, ggml_cgraph * cgr
case GGML_OP_CONV_TRANSPOSE_2D:
ggml_vk_conv_2d(ctx, compute_ctx, src0, src1, node);
break;
case GGML_OP_CONV_3D:
ggml_vk_conv_3d(ctx, compute_ctx, src0, src1, node);
break;
case GGML_OP_CONV_2D_DW:
ggml_vk_conv_2d_dw(ctx, compute_ctx, src0, src1, node);
@@ -17237,8 +16964,6 @@ static bool ggml_backend_vk_device_supports_op(ggml_backend_dev_t dev, const ggm
return false;
}
}
case GGML_OP_GET_ROWS_BACK:
return op->type == GGML_TYPE_F32 && op->src[0]->type == GGML_TYPE_F32;
case GGML_OP_SET_ROWS:
{
switch (op->type) {
@@ -17335,11 +17060,12 @@ static bool ggml_backend_vk_device_supports_op(ggml_backend_dev_t dev, const ggm
case GGML_OP_TRANSPOSE:
case GGML_OP_RMS_NORM:
return true;
case GGML_OP_NORM:
case GGML_OP_GROUP_NORM:
return ggml_is_contiguous(op->src[0]);
case GGML_OP_NORM:
case GGML_OP_L2_NORM:
return op->src[0]->type == GGML_TYPE_F32 && op->type == GGML_TYPE_F32;
return ggml_is_contiguous_rows(op->src[0]) &&
op->src[0]->type == GGML_TYPE_F32 && op->type == GGML_TYPE_F32;
case GGML_OP_ADD:
case GGML_OP_SUB:
case GGML_OP_MUL:
@@ -17358,9 +17084,8 @@ static bool ggml_backend_vk_device_supports_op(ggml_backend_dev_t dev, const ggm
case GGML_OP_SIN:
case GGML_OP_COS:
case GGML_OP_CLAMP:
return op->src[0]->type == GGML_TYPE_F32;
case GGML_OP_LEAKY_RELU:
return (op->src[0]->type == GGML_TYPE_F32 || op->src[0]->type == GGML_TYPE_F16) &&
op->type == op->src[0]->type;
case GGML_OP_OPT_STEP_ADAMW:
case GGML_OP_OPT_STEP_SGD:
return ggml_is_contiguous(op->src[0]) && op->src[0]->type == GGML_TYPE_F32;
@@ -17560,13 +17285,6 @@ static bool ggml_backend_vk_device_supports_op(ggml_backend_dev_t dev, const ggm
ggml_is_contiguous(op->src[1]) &&
ggml_is_contiguous(op));
}
case GGML_OP_CONV_3D:
return (op->src[0]->type == GGML_TYPE_F32 || op->src[0]->type == GGML_TYPE_F16) &&
op->src[1]->type == GGML_TYPE_F32 &&
op->type == GGML_TYPE_F32 &&
ggml_is_contiguous(op->src[0]) &&
ggml_is_contiguous(op->src[1]) &&
ggml_is_contiguous(op);
default:
return false;
}
@@ -18410,20 +18128,6 @@ static void ggml_vk_check_results_0(ggml_backend_vk_context * ctx, ggml_cgraph *
const int32_t d0 = tensor->op_params[4];
const int32_t d1 = tensor->op_params[5];
tensor_clone = ggml_conv_2d(ggml_ctx, src_clone[0], src_clone[1], s0, s1, p0, p1, d0, d1);
} else if (tensor->op == GGML_OP_CONV_3D) {
const int32_t s0 = tensor->op_params[0];
const int32_t s1 = tensor->op_params[1];
const int32_t s2 = tensor->op_params[2];
const int32_t p0 = tensor->op_params[3];
const int32_t p1 = tensor->op_params[4];
const int32_t p2 = tensor->op_params[5];
const int32_t d0 = tensor->op_params[6];
const int32_t d1 = tensor->op_params[7];
const int32_t d2 = tensor->op_params[8];
const int32_t IC = tensor->op_params[9];
const int32_t N = tensor->op_params[10];
const int32_t OC = tensor->op_params[11];
tensor_clone = ggml_conv_3d_direct(ggml_ctx, src_clone[0], src_clone[1], s0, s1, s2, p0, p1, p2, d0, d1, d2, IC, N, OC);
} else if (tensor->op == GGML_OP_CONV_2D_DW) {
const int32_t s0 = tensor->op_params[0];
const int32_t s1 = tensor->op_params[1];
@@ -0,0 +1,17 @@
#version 450
#include "types.glsl"
#include "generic_unary_head.glsl"
layout(local_size_x = 512, local_size_y = 1, local_size_z = 1) in;
void main() {
const uint idx = get_idx();
if (idx >= p.ne) {
return;
}
const FLOAT_TYPE val = FLOAT_TYPE(data_a[get_aoffset() + src0_idx(idx)]);
data_d[get_doffset() + dst_idx(idx)] = D_TYPE(val < p.param1 ? p.param1 : (val > p.param2 ? p.param2 : val));
}
@@ -1,431 +0,0 @@
#version 450
#extension GL_EXT_control_flow_attributes : enable
#ifdef COOPMAT2
#extension GL_NV_cooperative_matrix2 : enable
#extension GL_EXT_shader_explicit_arithmetic_types_float16 : require
#extension GL_KHR_memory_scope_semantics : enable
#endif
#ifdef COOPMAT
#extension GL_KHR_cooperative_matrix : enable
#extension GL_KHR_shader_subgroup_basic : enable
#extension GL_EXT_shader_explicit_arithmetic_types_float16 : require
#extension GL_KHR_memory_scope_semantics : enable
#endif
#include "types.glsl"
// shape notation: [dim(N), ..., dim(0)] -- stride(dim(j)) >= stride(dim(i)) if i > j
layout(binding = 0) readonly buffer A {
A_TYPE knl_data[];
}; // src0 - kernel: [KW, KH, KD, IC*OC]
layout(binding = 1) readonly buffer B {
B_TYPE src_data[];
}; // src1 - input: [IW, IH, ID, IC*N] -- channel_first format
layout(binding = 2) writeonly buffer D {
D_TYPE dst_data[];
}; // dst - result: [OW, OH, OD, OC*N]
layout(push_constant) uniform parameter {
// I/O channels, batch size
uint32_t OC;
uint32_t IC;
uint32_t N;
// Tensor spatial sizes: input, output
uint32_t IW;
uint32_t IH;
uint32_t ID;
uint32_t OW;
uint32_t OH;
uint32_t OD;
// Strides in elements
uint32_t nb01;
uint32_t nb02;
uint32_t nb03;
uint32_t nb11;
uint32_t nb12;
uint32_t nb13;
uint32_t nb1;
uint32_t nb2;
uint32_t nb3;
// fastdiv helper values
uint32_t OWmp; uint32_t OWL;
uint32_t OWOHmp; uint32_t OWOHL;
uint32_t OWOHODmp; uint32_t OWOHODL;
}
p;
layout(local_size_x_id = 0, local_size_y = 1, local_size_z = 1) in;
// Blocktile sizes
layout(constant_id = 1) const uint BS_K = 128;
layout(constant_id = 2) const uint BS_CRS = 16;
layout(constant_id = 3) const uint BS_NPQ = 128;
// Thread-tile sizes
layout(constant_id = 4) const uint TS_K = 8;
layout(constant_id = 5) const uint SHMEM_PAD = 4;
// Stride, padding, dilation
layout(constant_id = 6) const uint s0 = 1;
layout(constant_id = 7) const uint s1 = 1;
layout(constant_id = 8) const uint s2 = 1;
layout(constant_id = 9) const uint p0 = 0;
layout(constant_id = 10) const uint p1 = 0;
layout(constant_id = 11) const uint p2 = 0;
layout(constant_id = 12) const uint d0 = 1;
layout(constant_id = 13) const uint d1 = 1;
layout(constant_id = 14) const uint d2 = 1;
// Kernel spatial sizes
layout(constant_id = 15) const uint KW = 1;
layout(constant_id = 16) const uint KH = 1;
layout(constant_id = 17) const uint KD = 1;
// when set, skip bounds checks and address clamps (K/CRS/NPQ are tile-aligned)
layout(constant_id = 18) const uint aligned = 0;
// stage cm2 result through shmem (Csh) for coalesced stores. cm1 always does this.
layout(constant_id = 19) const uint csh_store = 0;
#ifdef COOPMAT
// cm1 subgroup tile: each subgroup computes a WM x WN region as a grid of
// TM x TN x TK fragments. Requires WM%TM == WN%TN == BS_K%WM == BS_NPQ%WN ==
// BS_CRS%TK == 0, and WG_SIZE == (BS_K/WM) * (BS_NPQ/WN) * subgroup_size.
layout(constant_id = 20) const uint WM = 32;
layout(constant_id = 21) const uint WN = 32;
const uint TM = 16;
const uint TN = 16;
const uint TK = 16;
const uint cms_per_row = WM / TM;
const uint cms_per_col = WN / TN;
const uint warps_M = BS_K / WM;
const uint warps_N = BS_NPQ / WN;
#endif
// without padding, ID_idx/IH_idx/IW_idx are in bounds by construction
const bool dhw_in_bounds = (p0 == 0) && (p1 == 0) && (p2 == 0);
uint32_t tid = gl_LocalInvocationID.x;
const uint32_t WG_SIZE = gl_WorkGroupSize.x;
uint splitWork(uint work_size, uint block_size) {
return (block_size + work_size - 1) / block_size;
}
uint32_t K = p.OC;
uint32_t CRS = p.IC * KD * KH * KW;
uint32_t NPQ = p.N * p.OD * p.OH * p.OW;
// Number of blocktiles per input
uint32_t NB_CRS = splitWork(CRS, BS_CRS);
#if defined(COOPMAT2) || defined(COOPMAT)
#define SHMEM_TYPE float16_t
#else
#define SHMEM_TYPE float
#endif
const uint32_t Ash_stride = BS_CRS + SHMEM_PAD;
const uint32_t Bsh_stride = BS_NPQ + SHMEM_PAD;
const uint32_t Ash_len = BS_K * Ash_stride;
const uint32_t Bsh_len = BS_CRS * Bsh_stride;
shared SHMEM_TYPE Ash[Ash_len]; // K x CRS
shared SHMEM_TYPE Bsh[Bsh_len]; // CRS x NPQ
#if defined(COOPMAT2) || defined(COOPMAT)
// stage matC through shmem so global stores are row-major (NPQ-contiguous)
const uint32_t Csh_stride = BS_NPQ;
#ifdef COOPMAT
const uint32_t Csh_len = BS_K * Csh_stride;
#else
const uint32_t Csh_len = csh_store != 0 ? BS_K * Csh_stride : 1;
#endif
shared SHMEM_TYPE Csh[Csh_len]; // K x NPQ
#endif
// Threadtile sizes
const uint32_t TS_NPQ = BS_K * BS_NPQ / WG_SIZE / TS_K;
// Number of threadtiles per blocktile
const uint32_t NT_NPQ = BS_NPQ / TS_NPQ;
/*
Compute
KxCRS @ CRSxNPQ = K x NPQ
K=OC
C=IC
D,R,S=KD,KH,KW
Z,P,Q=OD,OH,OW
*/
uint32_t B_idx_K = gl_WorkGroupID.x;
uint32_t B_idx_NPQ = gl_WorkGroupID.y + gl_WorkGroupID.z * 512;
uint32_t T_y = tid / NT_NPQ;
uint32_t T_x = tid % NT_NPQ;
uint32_t Ar = tid / BS_CRS;
uint32_t Ac = tid % BS_CRS;
const uint32_t ArpWg = WG_SIZE / BS_CRS;
uint32_t Br = tid / BS_NPQ;
uint32_t Bc = tid % BS_NPQ;
const uint32_t BrpWg = WG_SIZE / BS_NPQ;
// see init_fastdiv_values in ggml-vulkan.cpp
uint fastdiv(uint n, uint mp, uint L) {
uint msbs, lsbs;
// msbs = mulhi(n, mp)
umulExtended(n, mp, msbs, lsbs);
return (msbs + n) >> L;
}
void split_crs(uint32_t crs_idx, out uint32_t ic, out uint32_t kd, out uint32_t kh, out uint32_t kw) {
const uint32_t KHKW = KH * KW;
const uint32_t KDKHKW = KD * KHKW;
ic = crs_idx / KDKHKW;
uint32_t rem = crs_idx - ic * KDKHKW;
kd = rem / KHKW;
rem = rem - kd * KHKW;
kh = rem / KW;
kw = rem - kh * KW;
}
void split_npq(uint32_t npq_idx, out uint32_t n, out uint32_t od, out uint32_t oh, out uint32_t ow) {
const uint32_t OWOH = p.OW * p.OH;
n = fastdiv(npq_idx, p.OWOHODmp, p.OWOHODL);
uint32_t rem = npq_idx - n * p.OD * OWOH;
od = fastdiv(rem, p.OWOHmp, p.OWOHL);
rem = rem - od * OWOH;
oh = fastdiv(rem, p.OWmp, p.OWL);
ow = rem - oh * p.OW;
}
#ifdef COOPMAT2
#define ACC_TYPE float16_t
ACC_TYPE perElemOpStore(const in uint32_t r, const in uint32_t c, const in ACC_TYPE elem)
{
uint32_t K_idx = B_idx_K * BS_K + r;
uint32_t NPQ_idx = B_idx_NPQ * BS_NPQ + c;
uint32_t N_idx;
uint32_t OD_idx;
uint32_t OH_idx;
uint32_t OW_idx;
split_npq(NPQ_idx, N_idx, OD_idx, OH_idx, OW_idx);
uint32_t dst_idx = OW_idx + OH_idx * p.nb1 + OD_idx * p.nb2 + (N_idx * p.OC + K_idx) * p.nb3;
if (aligned != 0 || (K_idx < K && NPQ_idx < NPQ)) {
dst_data[dst_idx] = D_TYPE(elem);
}
return elem;
}
#endif
void main() {
if (B_idx_NPQ * BS_NPQ >= NPQ) {
return;
}
#ifdef COOPMAT2
coopmat<ACC_TYPE, gl_ScopeWorkgroup, BS_K, BS_NPQ, gl_MatrixUseAccumulator> matC;
matC = coopmat<ACC_TYPE, gl_ScopeWorkgroup, BS_K, BS_NPQ, gl_MatrixUseAccumulator>(0.0);
#elif defined(COOPMAT)
coopmat<float16_t, gl_ScopeSubgroup, TM, TN, gl_MatrixUseAccumulator> sums[cms_per_row * cms_per_col];
[[unroll]] for (uint i = 0; i < cms_per_row * cms_per_col; i++) {
sums[i] = coopmat<float16_t, gl_ScopeSubgroup, TM, TN, gl_MatrixUseAccumulator>(0.0);
}
const uint warp_r = gl_SubgroupID / warps_N;
const uint warp_c = gl_SubgroupID % warps_N;
#else
float regC[TS_K][TS_NPQ];
for (uint32_t T_ly = 0; T_ly < TS_K; T_ly++) {
for (uint32_t T_lx = 0; T_lx < TS_NPQ; T_lx++) {
regC[T_ly][T_lx] = 0.0;
}
}
#endif
/* Advance block in CRS dim */
[[dont_unroll]] for (uint32_t B_idx_CRS = 0; B_idx_CRS < NB_CRS; B_idx_CRS++) {
uint32_t CRS_idx_a = B_idx_CRS * BS_CRS + Ac;
uint32_t IC_idx_a;
uint32_t KD_idx_a;
uint32_t KH_idx_a;
uint32_t KW_idx_a;
split_crs(CRS_idx_a, IC_idx_a, KD_idx_a, KH_idx_a, KW_idx_a);
/* Load kernel to A_block: (BS_K x BS_CRS)*/
UNROLL for (uint32_t r_offset = 0; r_offset < BS_K; r_offset += ArpWg) {
uint32_t B_ly = r_offset + Ar;
uint32_t B_lx = Ac;
uint32_t K_idx = B_idx_K * BS_K + B_ly; /* Global K_idx (row index of A)*/
uint32_t knl_idx = KW_idx_a + KH_idx_a * p.nb01 + KD_idx_a * p.nb02 + (K_idx * p.IC + IC_idx_a) * p.nb03;
if (aligned == 0) {
knl_idx = min(knl_idx, K * CRS - 1);
}
float val = knl_data[knl_idx];
if (aligned == 0 && (K_idx >= K || CRS_idx_a >= CRS)) {
val = 0.0;
}
Ash[B_ly * Ash_stride + B_lx] = SHMEM_TYPE(val);
}
/* Load input to B_block: (BS_CRS x BS_NPQ) */
UNROLL for (uint32_t r_offset = 0; r_offset < BS_CRS; r_offset += BrpWg) {
uint32_t B_ly = r_offset + Br; /* Row index of B block */
uint32_t B_lx = Bc;
uint32_t NPQ_idx = B_idx_NPQ * BS_NPQ + B_lx; /* Global NPQ index (column index of B) */
uint32_t N_idx;
uint32_t OD_idx;
uint32_t OH_idx;
uint32_t OW_idx;
split_npq(NPQ_idx, N_idx, OD_idx, OH_idx, OW_idx);
uint32_t CRS_idx_b = B_idx_CRS * BS_CRS + B_ly;
uint32_t IC_idx_b;
uint32_t KD_idx_b;
uint32_t KH_idx_b;
uint32_t KW_idx_b;
split_crs(CRS_idx_b, IC_idx_b, KD_idx_b, KH_idx_b, KW_idx_b);
uint32_t ID_idx = OD_idx * s2 + KD_idx_b * d2 - p2;
uint32_t IH_idx = OH_idx * s1 + KH_idx_b * d1 - p1;
uint32_t IW_idx = OW_idx * s0 + KW_idx_b * d0 - p0;
uint32_t src_idx = IW_idx + IH_idx * p.nb11 + ID_idx * p.nb12 + (N_idx * p.IC + IC_idx_b) * p.nb13;
// skip clamp when address can't go OOB
if (aligned == 0 || !dhw_in_bounds) {
src_idx = min(src_idx, p.IC * p.N * p.IW * p.IH * p.ID - 1);
}
float val = src_data[src_idx];
bool oob = false;
if (aligned == 0 && (CRS_idx_b >= CRS || NPQ_idx >= NPQ)) {
oob = true;
}
// also catches lower-bound underflow (idx wraps to 0x80000000+)
if (!dhw_in_bounds && (ID_idx >= p.ID || IH_idx >= p.IH || IW_idx >= p.IW)) {
oob = true;
}
if (oob) {
val = 0.0;
}
Bsh[B_ly * Bsh_stride + B_lx] = SHMEM_TYPE(val);
}
barrier();
#ifdef COOPMAT2
coopmat<float16_t, gl_ScopeWorkgroup, BS_K, BS_CRS, gl_MatrixUseA> matA;
coopmat<float16_t, gl_ScopeWorkgroup, BS_CRS, BS_NPQ, gl_MatrixUseB> matB;
coopMatLoad(matA, Ash, 0, Ash_stride, gl_CooperativeMatrixLayoutRowMajor);
coopMatLoad(matB, Bsh, 0, Bsh_stride, gl_CooperativeMatrixLayoutRowMajor);
matC = coopMatMulAdd(matA, matB, matC);
#elif defined(COOPMAT)
// each subgroup multiplies its grid of fragments per TK-sized CRS chunk
[[unroll]] for (uint k_step = 0; k_step < BS_CRS / TK; k_step++) {
coopmat<float16_t, gl_ScopeSubgroup, TM, TK, gl_MatrixUseA> cache_a[cms_per_row];
[[unroll]] for (uint cm_row = 0; cm_row < cms_per_row; cm_row++) {
const uint a_off = (warp_r * WM + cm_row * TM) * Ash_stride + k_step * TK;
coopMatLoad(cache_a[cm_row], Ash, a_off, Ash_stride, gl_CooperativeMatrixLayoutRowMajor);
}
[[unroll]] for (uint cm_col = 0; cm_col < cms_per_col; cm_col++) {
coopmat<float16_t, gl_ScopeSubgroup, TK, TN, gl_MatrixUseB> cache_b;
const uint b_off = k_step * TK * Bsh_stride + warp_c * WN + cm_col * TN;
coopMatLoad(cache_b, Bsh, b_off, Bsh_stride, gl_CooperativeMatrixLayoutRowMajor);
[[unroll]] for (uint cm_row = 0; cm_row < cms_per_row; cm_row++) {
sums[cm_col * cms_per_row + cm_row] = coopMatMulAdd(cache_a[cm_row], cache_b, sums[cm_col * cms_per_row + cm_row]);
}
}
}
#else
if (T_y * TS_K < K) {
UNROLL for (uint32_t CRS_lidx = 0; CRS_lidx < BS_CRS; CRS_lidx++) {
float regA[TS_K];
float regB[TS_NPQ];
for (uint32_t T_ly = 0; T_ly < TS_K; T_ly++) {
regA[T_ly] = Ash[(T_y * TS_K + T_ly) * Ash_stride + CRS_lidx];
}
for (uint32_t T_lx = 0; T_lx < TS_NPQ; T_lx++) {
regB[T_lx] = Bsh[CRS_lidx * Bsh_stride + T_x * TS_NPQ + T_lx];
}
for (uint32_t T_ly = 0; T_ly < TS_K; T_ly++) {
for (uint32_t T_lx = 0; T_lx < TS_NPQ; T_lx++) {
regC[T_ly][T_lx] = fma(regA[T_ly], regB[T_lx], regC[T_ly][T_lx]);
}
}
}
}
#endif
barrier();
}
/* Save C* */
#if defined(COOPMAT2) || defined(COOPMAT)
// stage matC into Csh, then write to dst with coalesced NPQ-contiguous stores
#ifdef COOPMAT
const bool use_staged_store = true;
#else
const bool use_staged_store = (csh_store != 0);
#endif
if (use_staged_store) {
#ifdef COOPMAT
// cm1: each subgroup stores its fragment grid into its Csh slot
[[unroll]] for (uint cm_row = 0; cm_row < cms_per_row; cm_row++) {
[[unroll]] for (uint cm_col = 0; cm_col < cms_per_col; cm_col++) {
const uint csh_off = (warp_r * WM + cm_row * TM) * Csh_stride + warp_c * WN + cm_col * TN;
coopMatStore(sums[cm_col * cms_per_row + cm_row], Csh, csh_off, Csh_stride, gl_CooperativeMatrixLayoutRowMajor);
}
}
#else
coopMatStore(matC, Csh, 0, Csh_stride, gl_CooperativeMatrixLayoutRowMajor);
#endif
barrier();
// cooperative shmem->global: WG threads spread across BS_NPQ (the
// contiguous direction of dst), each iter covers store_rows_per_iter K-rows
const uint32_t store_rows_per_iter = WG_SIZE / BS_NPQ;
const uint32_t store_iters = BS_K / store_rows_per_iter;
const uint32_t k_thread_offset = tid / BS_NPQ;
const uint32_t npq_thread = tid % BS_NPQ;
[[unroll]] for (uint32_t i = 0; i < store_iters; i++) {
uint32_t k_local = i * store_rows_per_iter + k_thread_offset;
uint32_t K_idx = B_idx_K * BS_K + k_local;
uint32_t NPQ_idx = B_idx_NPQ * BS_NPQ + npq_thread;
uint32_t N_idx;
uint32_t OD_idx;
uint32_t OH_idx;
uint32_t OW_idx;
split_npq(NPQ_idx, N_idx, OD_idx, OH_idx, OW_idx);
uint32_t dst_idx = OW_idx + OH_idx * p.nb1 + OD_idx * p.nb2 + (N_idx * p.OC + K_idx) * p.nb3;
if (aligned != 0 || (K_idx < K && NPQ_idx < NPQ)) {
dst_data[dst_idx] = D_TYPE(Csh[k_local * Csh_stride + npq_thread]);
}
}
}
#ifdef COOPMAT2
else {
coopMatPerElementNV(matC, matC, perElemOpStore);
}
#endif
#else
if (T_y * TS_K < K) {
for (uint32_t T_ly = 0; T_ly < TS_K; T_ly++) {
for (uint32_t T_lx = 0; T_lx < TS_NPQ; T_lx++) {
uint32_t K_idx = B_idx_K * BS_K + T_y * TS_K + T_ly;
uint32_t NPQ_idx = B_idx_NPQ * BS_NPQ + T_x * TS_NPQ + T_lx;
uint32_t N_idx;
uint32_t OD_idx;
uint32_t OH_idx;
uint32_t OW_idx;
split_npq(NPQ_idx, N_idx, OD_idx, OH_idx, OW_idx);
uint32_t dst_idx = OW_idx + OH_idx * p.nb1 + OD_idx * p.nb2 + (N_idx * p.OC + K_idx) * p.nb3;
if (aligned != 0 || (K_idx < K && NPQ_idx < NPQ)) {
dst_data[dst_idx] = D_TYPE(regC[T_ly][T_lx]);
}
}
}
}
#endif
}
@@ -0,0 +1,17 @@
#version 450
#include "types.glsl"
#include "generic_unary_head.glsl"
layout(local_size_x = 512, local_size_y = 1, local_size_z = 1) in;
void main() {
const uint idx = get_idx();
if (idx >= p.ne) {
return;
}
const FLOAT_TYPE val = FLOAT_TYPE(data_a[get_aoffset() + src0_idx(idx)]);
data_d[get_doffset() + dst_idx(idx)] = D_TYPE(cos(val));
}
@@ -463,7 +463,6 @@ void main() {
}
rowmaxf = max(rowmaxf, float(Sf[r][c]));
}
rowmaxf += FATTN_KQ_MAX_OFFSET;
float Moldf = Mf[r];
// M = max(rowmax, Mold)
@@ -352,7 +352,6 @@ void main() {
}
rowmaxf = max(rowmaxf, float(sfsh[r_vec + (c * cols_per_iter + col_tid) * sfshstride][r_comp]));
}
rowmaxf += FATTN_KQ_MAX_OFFSET;
float Moldf = Mf[r];
// Compute max across the row
@@ -1,25 +0,0 @@
#version 450
#include "types.glsl"
#include "generic_binary_head.glsl"
layout(local_size_x = 256, local_size_y = 1, local_size_z = 1) in;
void main() {
const uint col = gl_GlobalInvocationID.x;
if (col >= p.ne20) {
return;
}
for (uint row = gl_GlobalInvocationID.y; row < p.ne21; row += gl_WorkGroupSize.y * gl_NumWorkGroups.y) {
float sum = 0.0f;
for (uint i = 0; i < p.ne10; ++i) {
if (data_b[get_boffset() + i*p.nb10] == int(row)) {
sum += data_a[get_aoffset() + i*p.nb01 + col*p.nb00];
}
}
data_d[get_doffset() + row*p.nb21 + col*p.nb20] = sum;
}
}
@@ -14,13 +14,16 @@ void main() {
const uint row = gl_WorkGroupID.z * 262144 + gl_WorkGroupID.y * 512 + gl_WorkGroupID.x;
const uint tid = gl_LocalInvocationID.x;
const uint a_base = get_aoffset() + src0_idx(row * p.ne00);
const uint d_base = get_doffset() + dst_idx(row * p.ne10);
const uint i3 = row / (p.ne11 * p.ne12);
const uint i3_offset = i3 * p.ne12 * p.ne11;
const uint i2 = (row - i3_offset) / p.ne11;
const uint i2_offset = i2 * p.ne11;
const uint i1 = row - i3_offset - i2_offset;
sum[tid] = FLOAT_TYPE(0.0f); // partial sum for thread in warp
[[unroll]] for (uint i0 = tid; i0 < p.ne00; i0 += BLOCK_SIZE) {
const FLOAT_TYPE xi = FLOAT_TYPE(data_a[a_base + i0*p.nb00]);
const FLOAT_TYPE xi = FLOAT_TYPE(data_a[i3*p.nb03 + i2*p.nb02 + i1*p.nb01 + i0]);
sum[tid] += xi * xi;
}
@@ -36,6 +39,6 @@ void main() {
const FLOAT_TYPE scale = 1.0f / max(sqrt(sum[0]), FLOAT_TYPE(p.param1));
[[unroll]] for (uint i0 = tid; i0 < p.ne00; i0 += BLOCK_SIZE) {
data_d[d_base + i0*p.nb10] = D_TYPE(scale * FLOAT_TYPE(data_a[a_base + i0*p.nb00]));
data_d[i3*p.nb13 + i2*p.nb12 + i1*p.nb11 + i0] = D_TYPE(scale * FLOAT_TYPE(data_a[i3*p.nb03 + i2*p.nb02 + i1*p.nb01 + i0]));
}
}
@@ -0,0 +1,22 @@
#version 450
#include "generic_head.glsl"
#include "types.glsl"
#extension GL_EXT_control_flow_attributes : enable
layout(local_size_x = 512, local_size_y = 1, local_size_z = 1) in;
layout (binding = 0) readonly buffer X {A_TYPE data_a[];};
layout (binding = 1) writeonly buffer D {D_TYPE data_d[];};
void main() {
const uint i = gl_GlobalInvocationID.z * 262144 + gl_GlobalInvocationID.y * 512 + gl_GlobalInvocationID.x;
if (i >= p.KX) {
return;
}
const float val = float(data_a[i]);
data_d[i] = D_TYPE(max(val, 0.0f) + min(val, 0.0f) * p.param1);
}
+23 -31
View File
@@ -38,7 +38,17 @@
#define LOAD_VEC_B 1
#endif
layout (constant_id = 11) const uint ALIGNED = 0;
// Load 2 values at once without affecting index calculations through LOAD_VEC
#if (defined(DATA_A_F32) || defined(DATA_A_F16) || defined(DATA_A_BF16)) && !defined(ALIGNED)
#define LOAD_VEC_BATCH_A 2
#else
#define LOAD_VEC_BATCH_A 1
#endif
#if !defined(ALIGNED)
#define LOAD_VEC_BATCH_B 2
#else
#define LOAD_VEC_BATCH_B 1
#endif
#if !defined(TO_FLOAT_TYPE)
#define TO_FLOAT_TYPE FLOAT_TYPE
@@ -47,13 +57,6 @@ layout (constant_id = 11) const uint ALIGNED = 0;
layout(local_size_x_id = 0, local_size_y = 1, local_size_z = 1) in;
layout (binding = 0) readonly buffer A {A_TYPE data_a[];};
#if defined(DATA_A_F32)
layout (binding = 0) readonly buffer A_SCALAR {float data_a_scalar[];};
#elif defined(DATA_A_F16)
layout (binding = 0) readonly buffer A_SCALAR {float16_t data_a_scalar[];};
#elif defined(DATA_A_BF16)
layout (binding = 0) readonly buffer A_SCALAR {uint16_t data_a_scalar[];};
#endif
#if defined(A_TYPE_PACKED16)
layout (binding = 0) readonly buffer A_PACKED16 {A_TYPE_PACKED16 data_a_packed16[];};
#endif
@@ -62,7 +65,6 @@ layout (binding = 0) readonly buffer A_PACKED32 {A_TYPE_PACKED32 data_a_packed32
#endif
layout (binding = 1) readonly buffer B {B_TYPE data_b[];};
layout (binding = 1) readonly buffer B_SCALAR {B_TYPE_SCALAR data_b_scalar[];};
layout (binding = 2) writeonly buffer D {D_TYPE data_d[];};
#ifdef MUL_MAT_ID
@@ -192,23 +194,13 @@ void main() {
const uint warp_r = warp_i % (BM / WM);
const uint warp_c = warp_i / (BM / WM);
#if defined(DATA_A_F32) || defined(DATA_A_F16) || defined(DATA_A_BF16)
const uint LOAD_VEC_A_EFF = (ALIGNED != 0) ? LOAD_VEC_A : 1;
const uint LOAD_VEC_BATCH_A = (ALIGNED != 0) ? 1 : 2;
#else
const uint LOAD_VEC_A_EFF = LOAD_VEC_A;
const uint LOAD_VEC_BATCH_A = 1;
#endif
const uint LOAD_VEC_B_EFF = (ALIGNED != 0) ? LOAD_VEC_B : 1;
const uint LOAD_VEC_BATCH_B = (ALIGNED != 0) ? 1 : 2;
const uint loadr_a = gl_LocalInvocationID.x % (BK / LOAD_VEC_A / LOAD_VEC_BATCH_A);
const uint loadc_a = gl_LocalInvocationID.x / (BK / LOAD_VEC_A / LOAD_VEC_BATCH_A);
const uint loadr_b = gl_LocalInvocationID.x % (BK / LOAD_VEC_B / LOAD_VEC_BATCH_B);
const uint loadc_b = gl_LocalInvocationID.x / (BK / LOAD_VEC_B / LOAD_VEC_BATCH_B);
const uint loadr_a = gl_LocalInvocationID.x % (BK / LOAD_VEC_A_EFF / LOAD_VEC_BATCH_A);
const uint loadc_a = gl_LocalInvocationID.x / (BK / LOAD_VEC_A_EFF / LOAD_VEC_BATCH_A);
const uint loadr_b = gl_LocalInvocationID.x % (BK / LOAD_VEC_B_EFF / LOAD_VEC_BATCH_B);
const uint loadc_b = gl_LocalInvocationID.x / (BK / LOAD_VEC_B_EFF / LOAD_VEC_BATCH_B);
const uint loadstride_a = gl_WorkGroupSize.x * LOAD_VEC_A_EFF * LOAD_VEC_BATCH_A / BK;
const uint loadstride_b = gl_WorkGroupSize.x * LOAD_VEC_B_EFF * LOAD_VEC_BATCH_B / BK;
const uint loadstride_a = gl_WorkGroupSize.x * LOAD_VEC_A * LOAD_VEC_BATCH_A / BK;
const uint loadstride_b = gl_WorkGroupSize.x * LOAD_VEC_B * LOAD_VEC_BATCH_B / BK;
#ifdef MUL_MAT_ID
#ifdef MUL_MAT_ID_USE_SUBGROUPS
@@ -247,15 +239,15 @@ void main() {
uint pos_a =
#ifdef MUL_MAT_ID
expert_idx * (p.batch_stride_a / LOAD_VEC_A_EFF) +
expert_idx * (p.batch_stride_a / LOAD_VEC_A) +
#else
batch_idx_a * (p.batch_stride_a / LOAD_VEC_A_EFF) +
batch_idx_a * (p.batch_stride_a / LOAD_VEC_A) +
#endif
(ir * BM * p.stride_a + start_k) / LOAD_VEC_A_EFF;
(ir * BM * p.stride_a + start_k) / LOAD_VEC_A;
#ifdef MUL_MAT_ID
uint pos_b = 0;
#else
uint pos_b = (batch_idx * p.batch_stride_b + ic * BN * p.stride_b + start_k) / LOAD_VEC_B_EFF;
uint pos_b = (batch_idx * p.batch_stride_b + ic * BN * p.stride_b + start_k) / LOAD_VEC_B;
#endif
#ifdef COOPMAT
@@ -295,8 +287,8 @@ void main() {
barrier();
pos_a += BK / LOAD_VEC_A_EFF;
pos_b += BK / LOAD_VEC_B_EFF;
pos_a += BK / LOAD_VEC_A;
pos_b += BK / LOAD_VEC_B;
#ifdef COOPMAT
[[unroll]] for (uint i = 0; i < BK; i += TK) {
@@ -36,7 +36,6 @@ layout (constant_id = 3) const uint BK = 16; // Assumed to be 32 if working wit
layout (constant_id = 4) const bool enable_smaller_matrices = false;
const uint BNover2 = enable_smaller_matrices ? (BN / 2) : BN;
const uint BNover4 = enable_smaller_matrices ? (BN / 4) : BN;
layout (constant_id = 5) const uint ALIGNED = 0;
layout (push_constant) uniform parameter
{
@@ -112,7 +111,7 @@ layout(buffer_reference, std430, buffer_reference_align = 2) buffer decodeBufB {
};
uint _ne1;
layout (constant_id = 6) const uint subgroup_size = 32;
layout (constant_id = 5) const uint subgroup_size = 32;
shared uvec4 ballots_sh[BLOCK_SIZE / subgroup_size];
B_TYPE decodeFuncB(const in decodeBufB bl, const in uint blockCoords[2], const in uint coordInBlock[2])
@@ -298,12 +297,12 @@ void main() {
// Hint to the compiler that values are aligned (want 16B alignment).
// Quants are always block-aligned, no alignment needed.
if (ALIGNED != 0) {
#if ALIGNED
#if QUANT_K == 1
stride_a &= ~7;
stride_a &= ~7;
#endif
stride_b &= ~7;
#endif
stride_b &= ~7;
}
// Create layouts for both clamped and unclamped accesses
tensorLayoutNV<2> tensorLayoutA = createTensorLayoutNV(2);
@@ -1,57 +1,50 @@
void load_a_to_shmem(const uint pos_a, const uint row, const uint col, const uint idx_m, const uint block, const uint end_k) {
#if defined(DATA_A_F32) || defined(DATA_A_F16)
#if LOAD_VEC_A == 8
if (ALIGNED != 0) {
const uint idx = pos_a + col * p.stride_a / LOAD_VEC_A + row;
const uint buf_idx = col * SHMEM_STRIDE + row * LOAD_VEC_A / 2;
FLOAT_TYPEV8 aa = FLOAT_TYPEV8(data_a[idx]);
buf_a[buf_idx ] = aa[0].xy;
buf_a[buf_idx + 1] = aa[0].zw;
buf_a[buf_idx + 2] = aa[1].xy;
buf_a[buf_idx + 3] = aa[1].zw;
return;
}
const uint idx = pos_a + col * p.stride_a / LOAD_VEC_A + row;
const uint buf_idx = col * SHMEM_STRIDE + row * LOAD_VEC_A / 2;
FLOAT_TYPEV8 aa = FLOAT_TYPEV8(data_a[idx]);
buf_a[buf_idx ] = aa[0].xy;
buf_a[buf_idx + 1] = aa[0].zw;
buf_a[buf_idx + 2] = aa[1].xy;
buf_a[buf_idx + 3] = aa[1].zw;
#elif LOAD_VEC_A == 4
if (ALIGNED != 0) {
const uint idx = pos_a + col * p.stride_a / LOAD_VEC_A + row;
const uint buf_idx = col * SHMEM_STRIDE + row * LOAD_VEC_A / 2;
FLOAT_TYPEV4 aa = FLOAT_TYPEV4(data_a[idx]);
buf_a[buf_idx ] = aa.xy;
buf_a[buf_idx + 1] = aa.zw;
return;
}
#endif
const uint idx = pos_a + col * p.stride_a / LOAD_VEC_A + row;
const uint buf_idx = col * SHMEM_STRIDE + row * LOAD_VEC_A / 2;
FLOAT_TYPEV4 aa = FLOAT_TYPEV4(data_a[idx]);
buf_a[buf_idx ] = aa.xy;
buf_a[buf_idx + 1] = aa.zw;
#else // LOAD_VEC_BATCH_A == 2
const uint idx = pos_a + col * p.stride_a + row * 2;
const uint buf_idx = col * SHMEM_STRIDE + row;
if (idx_m < p.M && block + row * 2 + 1 < end_k) {
buf_a[buf_idx] = FLOAT_TYPEV2(data_a_scalar[idx],
data_a_scalar[idx + 1]);
buf_a[buf_idx] = FLOAT_TYPEV2(data_a[idx],
data_a[idx + 1]);
} else if (idx_m < p.M && block + row * 2 < end_k) {
buf_a[buf_idx] = FLOAT_TYPEV2(data_a_scalar[idx], 0.0f);
buf_a[buf_idx] = FLOAT_TYPEV2(data_a[idx], 0.0f);
} else {
buf_a[buf_idx] = FLOAT_TYPEV2(0.0f);
}
#endif
#elif defined(DATA_A_BF16)
#if LOAD_VEC_A == 4
if (ALIGNED != 0) {
const uint idx = pos_a + col * p.stride_a / LOAD_VEC_A + row;
const uint buf_idx = col * SHMEM_STRIDE + row * LOAD_VEC_A / 2;
FLOAT_TYPEV4 aa = FLOAT_TYPEV4(TO_FLOAT_TYPE(data_a[idx]));
buf_a[buf_idx ] = aa.xy;
buf_a[buf_idx + 1] = aa.zw;
return;
}
#endif
const uint idx = pos_a + col * p.stride_a / LOAD_VEC_A + row;
const uint buf_idx = col * SHMEM_STRIDE + row * LOAD_VEC_A / 2;
FLOAT_TYPEV4 aa = FLOAT_TYPEV4(TO_FLOAT_TYPE(data_a[idx]));
buf_a[buf_idx ] = aa.xy;
buf_a[buf_idx + 1] = aa.zw;
#else // LOAD_VEC_BATCH_A == 2
const uint idx = pos_a + col * p.stride_a + row * 2;
const uint buf_idx = col * SHMEM_STRIDE + row;
if (idx_m < p.M && block + row * 2 + 1 < end_k) {
buf_a[buf_idx] = FLOAT_TYPEV2(TO_FLOAT_TYPE(data_a_scalar[idx]),
TO_FLOAT_TYPE(data_a_scalar[idx + 1]));
buf_a[buf_idx] = FLOAT_TYPEV2(TO_FLOAT_TYPE(data_a[idx]),
TO_FLOAT_TYPE(data_a[idx + 1]));
} else if (idx_m < p.M && block + row * 2 < end_k) {
buf_a[buf_idx] = FLOAT_TYPEV2(TO_FLOAT_TYPE(data_a_scalar[idx]), 0.0f);
buf_a[buf_idx] = FLOAT_TYPEV2(TO_FLOAT_TYPE(data_a[idx]), 0.0f);
} else {
buf_a[buf_idx] = FLOAT_TYPEV2(0.0f);
}
#endif
#elif defined(DATA_A_Q4_0)
const uint idx = pos_a + col * p.stride_a / LOAD_VEC_A + row;
const uint buf_idx = col * SHMEM_STRIDE + row * LOAD_VEC_A / 4;
@@ -533,85 +526,75 @@ void load_a_to_shmem(const uint pos_a, const uint row, const uint col, const uin
#if !defined(MUL_MAT_ID)
void load_b_to_shmem(const uint pos_b, const uint row, const uint col, const uint idx_n, const uint block, const uint end_k) {
#if LOAD_VEC_B == 8
if (ALIGNED != 0) {
// Not supported for b_type bf16 because bf16mat2x4 does not exist
const uint idx = pos_b + col * p.stride_b / LOAD_VEC_B + row;
const uint buf_idx = col * SHMEM_STRIDE + row * LOAD_VEC_B / 2;
FLOAT_TYPEV8 bb = FLOAT_TYPEV8(data_b[idx]);
buf_b[buf_idx + 0] = bb[0].xy;
buf_b[buf_idx + 1] = bb[0].zw;
buf_b[buf_idx + 2] = bb[1].xy;
buf_b[buf_idx + 3] = bb[1].zw;
return;
}
// Not supported for b_type bf16 because bf16mat2x4 does not exist
const uint idx = pos_b + col * p.stride_b / LOAD_VEC_B + row;
const uint buf_idx = col * SHMEM_STRIDE + row * LOAD_VEC_B / 2;
FLOAT_TYPEV8 bb = FLOAT_TYPEV8(data_b[idx]);
buf_b[buf_idx + 0] = bb[0].xy;
buf_b[buf_idx + 1] = bb[0].zw;
buf_b[buf_idx + 2] = bb[1].xy;
buf_b[buf_idx + 3] = bb[1].zw;
#elif LOAD_VEC_B == 4
if (ALIGNED != 0) {
const uint idx = pos_b + col * p.stride_b / LOAD_VEC_B + row;
const uint buf_idx = col * SHMEM_STRIDE + row * LOAD_VEC_B / 2;
const uint idx = pos_b + col * p.stride_b / LOAD_VEC_B + row;
const uint buf_idx = col * SHMEM_STRIDE + row * LOAD_VEC_B / 2;
#if defined(DATA_B_BF16)
FLOAT_TYPEV4 bb = FLOAT_TYPEV4(TO_FLOAT_TYPE(data_b[idx]));
FLOAT_TYPEV4 bb = FLOAT_TYPEV4(TO_FLOAT_TYPE(data_b[idx]));
#else
FLOAT_TYPEV4 bb = FLOAT_TYPEV4(data_b[idx]);
#endif
buf_b[buf_idx + 0] = bb.xy;
buf_b[buf_idx + 1] = bb.zw;
return;
}
FLOAT_TYPEV4 bb = FLOAT_TYPEV4(data_b[idx]);
#endif
buf_b[buf_idx + 0] = bb.xy;
buf_b[buf_idx + 1] = bb.zw;
#else // LOAD_VEC_BATCH_B == 2
const uint idx = pos_b + col * p.stride_b + row * 2;
const uint buf_idx = col * SHMEM_STRIDE + row;
if (idx_n < p.N && block + row * 2 + 1 < end_k) {
buf_b[buf_idx] = FLOAT_TYPEV2(TO_FLOAT_TYPE(data_b_scalar[idx]),
TO_FLOAT_TYPE(data_b_scalar[idx + 1]));
buf_b[buf_idx] = FLOAT_TYPEV2(TO_FLOAT_TYPE(data_b[idx]),
TO_FLOAT_TYPE(data_b[idx + 1]));
} else if (idx_n < p.N && block + row * 2 < end_k) {
buf_b[buf_idx] = FLOAT_TYPEV2(TO_FLOAT_TYPE(data_b_scalar[idx]), 0.0f);
buf_b[buf_idx] = FLOAT_TYPEV2(TO_FLOAT_TYPE(data_b[idx]), 0.0f);
} else {
buf_b[buf_idx] = FLOAT_TYPEV2(0.0f);
}
#endif
}
#else
void load_b_to_shmem(const uint pos_b, const uint row, const uint col, const uint ic, const uint _ne1, const uint block, const uint end_k) {
#if LOAD_VEC_B == 8
if (ALIGNED != 0) {
// Not supported for b_type bf16 because bf16mat2x4 does not exist
const u16vec2 row_idx = row_ids[col];
const uint idx = pos_b + row_idx.y * p.batch_stride_b / LOAD_VEC_B + (row_idx.x % p.ne11) * p.stride_b / LOAD_VEC_B + row;
const uint buf_idx = col * SHMEM_STRIDE + row * LOAD_VEC_B / 2;
FLOAT_TYPEV8 bb = FLOAT_TYPEV8(data_b[idx]);
buf_b[buf_idx + 0] = bb[0].xy;
buf_b[buf_idx + 1] = bb[0].zw;
buf_b[buf_idx + 2] = bb[1].xy;
buf_b[buf_idx + 3] = bb[1].zw;
return;
}
// Not supported for b_type bf16 because bf16mat2x4 does not exist
const u16vec2 row_idx = row_ids[col];
const uint idx = pos_b + row_idx.y * p.batch_stride_b / LOAD_VEC_B + (row_idx.x % p.ne11) * p.stride_b / LOAD_VEC_B + row;
const uint buf_idx = col * SHMEM_STRIDE + row * LOAD_VEC_B / 2;
FLOAT_TYPEV8 bb = FLOAT_TYPEV8(data_b[idx]);
buf_b[buf_idx + 0] = bb[0].xy;
buf_b[buf_idx + 1] = bb[0].zw;
buf_b[buf_idx + 2] = bb[1].xy;
buf_b[buf_idx + 3] = bb[1].zw;
#elif LOAD_VEC_B == 4
if (ALIGNED != 0) {
const u16vec2 row_idx = row_ids[col];
const uint idx = pos_b + row_idx.y * p.batch_stride_b / LOAD_VEC_B + (row_idx.x % p.ne11) * p.stride_b / LOAD_VEC_B + row;
const uint buf_idx = col * SHMEM_STRIDE + row * LOAD_VEC_B / 2;
const u16vec2 row_idx = row_ids[col];
const uint idx = pos_b + row_idx.y * p.batch_stride_b / LOAD_VEC_B + (row_idx.x % p.ne11) * p.stride_b / LOAD_VEC_B + row;
const uint buf_idx = col * SHMEM_STRIDE + row * LOAD_VEC_B / 2;
#if defined(DATA_B_BF16)
FLOAT_TYPEV4 bb = FLOAT_TYPEV4(TO_FLOAT_TYPE(data_b[idx]));
FLOAT_TYPEV4 bb = FLOAT_TYPEV4(TO_FLOAT_TYPE(data_b[idx]));
#else
FLOAT_TYPEV4 bb = FLOAT_TYPEV4(data_b[idx]);
#endif
buf_b[buf_idx + 0] = bb.xy;
buf_b[buf_idx + 1] = bb.zw;
return;
}
FLOAT_TYPEV4 bb = FLOAT_TYPEV4(data_b[idx]);
#endif
buf_b[buf_idx + 0] = bb.xy;
buf_b[buf_idx + 1] = bb.zw;
#else // LOAD_VEC_BATCH_B == 2
const uint row_i = ic * BN + col;
const uint buf_idx = col * SHMEM_STRIDE + row;
if (row_i < _ne1 && block + row * 2 + 1 < end_k) {
const u16vec2 row_idx = row_ids[col];
const uint idx = pos_b + row_idx.y * p.batch_stride_b + (row_idx.x % p.ne11) * p.stride_b + row * 2;
buf_b[buf_idx] = FLOAT_TYPEV2(TO_FLOAT_TYPE(data_b_scalar[idx]),
TO_FLOAT_TYPE(data_b_scalar[idx + 1]));
buf_b[buf_idx] = FLOAT_TYPEV2(TO_FLOAT_TYPE(data_b[idx]),
TO_FLOAT_TYPE(data_b[idx + 1]));
} else if (row_i < _ne1 && block + row * 2 < end_k) {
const u16vec2 row_idx = row_ids[col];
const uint idx = pos_b + row_idx.y * p.batch_stride_b + (row_idx.x % p.ne11) * p.stride_b + row * 2;
buf_b[buf_idx] = FLOAT_TYPEV2(TO_FLOAT_TYPE(data_b_scalar[idx]), 0.0f);
buf_b[buf_idx] = FLOAT_TYPEV2(TO_FLOAT_TYPE(data_b[idx]), 0.0f);
} else {
buf_b[buf_idx] = FLOAT_TYPEV2(0.0f);
}
#endif
}
#endif
+10 -10
View File
@@ -1,26 +1,26 @@
#version 450
#include "generic_head.glsl"
#include "types.glsl"
#include "generic_unary_head.glsl"
#extension GL_EXT_control_flow_attributes : enable
#define BLOCK_SIZE 512
layout(local_size_x = BLOCK_SIZE, local_size_y = 1, local_size_z = 1) in;
layout (binding = 0) readonly buffer X {A_TYPE data_a[];};
layout (binding = 1) writeonly buffer D {D_TYPE data_d[];};
shared vec2 sum[BLOCK_SIZE];
void main() {
const uint row = gl_WorkGroupID.z * 262144 + gl_WorkGroupID.y * 512 + gl_WorkGroupID.x;
const uint tid = gl_LocalInvocationID.x;
const uint a_base = get_aoffset() + src0_idx(row * p.ne00);
const uint d_base = get_doffset() + dst_idx(row * p.ne10);
sum[tid] = vec2(0.0f, 0.0f);
[[unroll]] for (uint i0 = tid; i0 < p.ne00; i0 += BLOCK_SIZE) {
const float xi = float(data_a[a_base + i0*p.nb00]);
[[unroll]] for (uint col = tid; col < p.KX; col += BLOCK_SIZE) {
const float xi = float(data_a[row*p.KX + col]);
sum[tid].x += xi;
sum[tid].y += xi * xi;
}
@@ -34,11 +34,11 @@ void main() {
barrier();
}
const float mean = sum[0].x / p.ne00;
const float var = sum[0].y / p.ne00 - mean * mean;
const float mean = sum[0].x / p.KX;
const float var = sum[0].y / p.KX - mean * mean;
const float inv_std = inversesqrt(var + p.param1);
[[unroll]] for (uint i0 = tid; i0 < p.ne00; i0 += BLOCK_SIZE) {
data_d[d_base + i0*p.nb10] = D_TYPE((float(data_a[a_base + i0*p.nb00]) - mean) * inv_std);
[[unroll]] for (uint col = tid; col < p.KX; col += BLOCK_SIZE) {
data_d[row*p.KX + col] = D_TYPE((float(data_a[row*p.KX + col]) - mean) * inv_std);
}
}
@@ -0,0 +1,17 @@
#version 450
#include "types.glsl"
#include "generic_unary_head.glsl"
layout(local_size_x = 512, local_size_y = 1, local_size_z = 1) in;
void main() {
const uint idx = get_idx();
if (idx >= p.ne) {
return;
}
const FLOAT_TYPE val = FLOAT_TYPE(data_a[get_aoffset() + src0_idx(idx)]);
data_d[get_doffset() + dst_idx(idx)] = D_TYPE(sin(val));
}
@@ -0,0 +1,17 @@
#version 450
#include "types.glsl"
#include "generic_unary_head.glsl"
layout(local_size_x = 512, local_size_y = 1, local_size_z = 1) in;
void main() {
const uint idx = get_idx();
if (idx >= p.ne) {
return;
}
const FLOAT_TYPE val = FLOAT_TYPE(data_a[get_aoffset() + src0_idx(idx)]);
data_d[get_doffset() + dst_idx(idx)] = D_TYPE(sqrt(val));
}
@@ -0,0 +1,17 @@
#version 450
#include "types.glsl"
#include "generic_unary_head.glsl"
layout(local_size_x = 512, local_size_y = 1, local_size_z = 1) in;
void main() {
const uint idx = get_idx();
if (idx >= p.ne) {
return;
}
const FLOAT_TYPE val = FLOAT_TYPE(data_a[get_aoffset() + src0_idx(idx)]);
data_d[get_doffset() + dst_idx(idx)] = D_TYPE(val * val);
}
@@ -17,30 +17,6 @@ float op_neg(float x) {
return -x;
}
float op_sqr(float x) {
return x * x;
}
float op_sqrt(float x) {
return sqrt(x);
}
float op_sin(float x) {
return sin(x);
}
float op_cos(float x) {
return cos(x);
}
float op_clamp(float x) {
return clamp(x, p.param1, p.param2);
}
float op_leaky_relu(float x) {
return max(x, 0.0f) + min(x, 0.0f) * p.param1;
}
float op_step(float x) {
return x >= 0.0f ? 1.0f : 0.0f;
}
@@ -11,7 +11,6 @@
#include <future>
#include <queue>
#include <condition_variable>
#include <atomic>
#include <cstdio>
#include <cstring>
#include <cstdlib>
@@ -35,9 +34,6 @@
std::mutex lock;
std::vector<std::pair<std::string, std::string>> shader_fnames;
// Set when any shader subprocess fails (non-zero exit / stderr / launch failure) so the
// build is stopped instead of silently producing a broken libggml-vulkan. (issue #24393)
static std::atomic<bool> compile_failed{false};
std::locale c_locale("C");
std::string GLSLC = "glslc";
@@ -82,7 +78,7 @@ enum MatMulIdType {
namespace {
int execute_command(std::vector<std::string>& command, std::string& stdout_str, std::string& stderr_str) {
void execute_command(std::vector<std::string>& command, std::string& stdout_str, std::string& stderr_str) {
#ifdef _WIN32
HANDLE stdout_read, stdout_write;
HANDLE stderr_read, stderr_write;
@@ -131,11 +127,8 @@ int execute_command(std::vector<std::string>& command, std::string& stdout_str,
CloseHandle(stdout_read);
CloseHandle(stderr_read);
WaitForSingleObject(pi.hProcess, INFINITE);
DWORD exit_code = 1;
GetExitCodeProcess(pi.hProcess, &exit_code);
CloseHandle(pi.hProcess);
CloseHandle(pi.hThread);
return (int)exit_code;
#else
int stdout_pipe[2];
int stderr_pipe[2];
@@ -182,9 +175,7 @@ int execute_command(std::vector<std::string>& command, std::string& stdout_str,
close(stdout_pipe[0]);
close(stderr_pipe[0]);
int status = 0;
waitpid(pid, &status, 0);
return WIFEXITED(status) ? WEXITSTATUS(status) : -1;
waitpid(pid, nullptr, 0);
}
#endif
}
@@ -381,14 +372,13 @@ void string_to_spv_func(std::string name, std::string in_path, std::string out_p
// }
// std::cout << std::endl;
int exit_code = execute_command(cmd, stdout_str, stderr_str);
if (exit_code != 0 || !stderr_str.empty()) {
std::cerr << "cannot compile " << name << " (exit code " << exit_code << ")\n\n";
execute_command(cmd, stdout_str, stderr_str);
if (!stderr_str.empty()) {
std::cerr << "cannot compile " << name << "\n\n";
for (const auto& part : cmd) {
std::cerr << part << " ";
}
std::cerr << "\n\n" << stderr_str << std::endl;
compile_failed = true;
return;
}
@@ -408,7 +398,6 @@ void string_to_spv_func(std::string name, std::string in_path, std::string out_p
shader_fnames.push_back(std::make_pair(name, out_path));
} catch (const std::exception& e) {
std::cerr << "Error executing command for " << name << ": " << e.what() << std::endl;
compile_failed = true;
}
}
@@ -550,9 +539,11 @@ void matmul_shaders(bool fp16, MatMulIdType matmul_id_type, bool coopmat, bool c
};
// Shaders with f16 B_TYPE
string_to_spv(shader_name + "_f32_f16" + dot2_sfx, source_name, merge_maps(merge_maps(base_dict, float_type_dict_f16), {{"DATA_A_F32", "1"}, {"LOAD_VEC_A", load_vec}, {"LOAD_VEC_B", load_vec}, {"B_TYPE", aligned_b_type_f16}, {"B_TYPE_SCALAR", "float16_t"}, {"B_TYPEV4", "f16vec4"}, {"D_TYPE", "float"}}), fp16, coopmat, coopmat2, f16acc);
string_to_spv(shader_name + "_f32_f16" + dot2_sfx, source_name, merge_maps(merge_maps(base_dict, float_type_dict_f16), {{"DATA_A_F32", "1"}, {"B_TYPE", "float16_t"}, {"B_TYPEV4", "f16vec4"}, {"D_TYPE", "float"}, }), fp16, coopmat, coopmat2, f16acc);
string_to_spv(shader_name + "_f32_f16" + dot2_sfx + "_aligned", source_name, merge_maps(merge_maps(base_dict, float_type_dict_f16), {{"DATA_A_F32", "1"}, {"LOAD_VEC_A", load_vec}, {"LOAD_VEC_B", load_vec}, {"B_TYPE", aligned_b_type_f16}, {"B_TYPEV4", "f16vec4"}, {"D_TYPE", "float"}, {"ALIGNED", "1"}}), fp16, coopmat, coopmat2, f16acc);
string_to_spv(shader_name + "_f16" + dot2_sfx, source_name, merge_maps(merge_maps(base_dict, float_type_dict_f16), {{"DATA_A_F16", "1"}, {"LOAD_VEC_A", load_vec}, {"LOAD_VEC_B", load_vec}, {"B_TYPE", aligned_b_type_f16}, {"B_TYPE_SCALAR", "float16_t"}, {"B_TYPEV4", "f16vec4"}, {"D_TYPE", "float"}}), fp16, coopmat, coopmat2, f16acc);
string_to_spv(shader_name + "_f16" + dot2_sfx, source_name, merge_maps(merge_maps(base_dict, float_type_dict_f16), {{"DATA_A_F16", "1"}, {"B_TYPE", "float16_t"}, {"B_TYPEV4", "f16vec4"}, {"D_TYPE", "float"}}), fp16, coopmat, coopmat2, f16acc);
string_to_spv(shader_name + "_f16" + dot2_sfx + "_aligned", source_name, merge_maps(merge_maps(base_dict, float_type_dict_f16), {{"DATA_A_F16", "1"}, {"LOAD_VEC_A", load_vec}, {"LOAD_VEC_B", load_vec}, {"B_TYPE", aligned_b_type_f16}, {"B_TYPEV4", "f16vec4"}, {"D_TYPE", "float"}, {"ALIGNED", "1"}}), fp16, coopmat, coopmat2, f16acc);
// bf16
{
@@ -574,7 +565,8 @@ void matmul_shaders(bool fp16, MatMulIdType matmul_id_type, bool coopmat, bool c
#endif
{
if (!dot2) {
string_to_spv(shader_name + "_bf16", source_name, merge_maps(merge_maps(base_dict, float_type_dict_bf16), {{"TO_FLOAT_TYPE", to_float_type}, {"DATA_A_BF16", "1"}, {"LOAD_VEC_A", load_vec_a}, {"LOAD_VEC_B", "4"}, {"B_TYPE", coopmat2 ? "bfloat16_t" : "u16vec4"}, {"B_TYPE_SCALAR", coopmat2 ? "bfloat16_t" : "uint16_t"}, {"B_TYPEV4", "bf16vec4"}, {"D_TYPE", "float"}, {"B_IS_FLOAT", "1"}, {"DATA_B_BF16", "1"}}), fp16, coopmat, coopmat2, f16acc);
string_to_spv(shader_name + "_bf16", source_name, merge_maps(merge_maps(base_dict, float_type_dict_bf16), {{"TO_FLOAT_TYPE", to_float_type}, {"DATA_A_BF16", "1"}, {"B_TYPE", coopmat2 ? "bfloat16_t" : "uint16_t"}, {"B_TYPEV4", "bf16vec4"}, {"D_TYPE", "float"}, {"B_IS_FLOAT", "1"}, {"DATA_B_BF16", "1"}}), fp16, coopmat, coopmat2, f16acc);
string_to_spv(shader_name + "_bf16_aligned", source_name, merge_maps(merge_maps(base_dict, float_type_dict_bf16), {{"TO_FLOAT_TYPE", to_float_type}, {"DATA_A_BF16", "1"}, {"LOAD_VEC_A", load_vec_a}, {"LOAD_VEC_B", "4"}, {"B_TYPE", coopmat2 ? "bfloat16_t" : "u16vec4"}, {"B_TYPEV4", "bf16vec4"}, {"D_TYPE", "float"}, {"B_IS_FLOAT", "1"}, {"DATA_B_BF16", "1"}, {"ALIGNED", "1"}}), fp16, coopmat, coopmat2, f16acc);
}
}
}
@@ -591,6 +583,8 @@ void matmul_shaders(bool fp16, MatMulIdType matmul_id_type, bool coopmat, bool c
}
std::string data_a_key = "DATA_A_" + to_uppercase(tname);
// For unaligned, load one at a time for f32/f16, or two at a time for quants
std::string load_vec_a_unaligned = (coopmat2 || tname == "f32" || tname == "f16" || tname == "bf16") ? "1" : load_vec_quant;
// For aligned matmul loads
std::string load_vec_a = (coopmat2 || tname == "f32" || tname == "f16" || tname == "bf16") ? load_vec : load_vec_quant;
@@ -603,11 +597,13 @@ void matmul_shaders(bool fp16, MatMulIdType matmul_id_type, bool coopmat, bool c
// don't generate f32 variants for coopmat2
if (!coopmat2) {
string_to_spv(shader_name + "_" + tname + "_f32" + dot2_sfx, source_name, merge_maps(merge_maps(base_dict, float_type_dict), {{data_a_key, "1"}, {"LOAD_VEC_A", load_vec_a}, {"LOAD_VEC_B", load_vec}, {"B_TYPE", aligned_b_type_f32}, {"B_TYPE_SCALAR", "float"}, {"B_TYPEV4", "vec4"}, {"D_TYPE", "float"}}), fp16, coopmat, coopmat2, f16acc);
string_to_spv(shader_name + "_" + tname + "_f32" + dot2_sfx, source_name, merge_maps(merge_maps(base_dict, float_type_dict), {{data_a_key, "1"}, {"LOAD_VEC_A", load_vec_a_unaligned}, {"B_TYPE", "float"}, {"B_TYPEV4", "vec4"}, {"D_TYPE", "float"}}), fp16, coopmat, coopmat2, f16acc);
string_to_spv(shader_name + "_" + tname + "_f32" + dot2_sfx + "_aligned", source_name, merge_maps(merge_maps(base_dict, float_type_dict), {{data_a_key, "1"}, {"LOAD_VEC_A", load_vec_a}, {"LOAD_VEC_B", load_vec}, {"B_TYPE", aligned_b_type_f32}, {"B_TYPEV4", "vec4"}, {"D_TYPE", "float"}, {"ALIGNED", "1"}}), fp16, coopmat, coopmat2, f16acc);
}
if (tname != "f16" && tname != "f32") {
string_to_spv(shader_name + "_" + tname + "_f16" + dot2_sfx, source_name, merge_maps(merge_maps(base_dict, float_type_dict), {{data_a_key, "1"}, {"LOAD_VEC_A", load_vec_a}, {"LOAD_VEC_B", load_vec}, {"B_TYPE", aligned_b_type_f16}, {"B_TYPE_SCALAR", "float16_t"}, {"B_TYPEV4", "f16vec4"}, {"D_TYPE", "float"}}), fp16, coopmat, coopmat2, f16acc);
string_to_spv(shader_name + "_" + tname + "_f16" + dot2_sfx, source_name, merge_maps(merge_maps(base_dict, float_type_dict), {{data_a_key, "1"}, {"LOAD_VEC_A", load_vec_a_unaligned}, {"B_TYPE", "float16_t"}, {"B_TYPEV4", "f16vec4"}, {"D_TYPE", "float"}}), fp16, coopmat, coopmat2, f16acc);
string_to_spv(shader_name + "_" + tname + "_f16" + dot2_sfx + "_aligned", source_name, merge_maps(merge_maps(base_dict, float_type_dict), {{data_a_key, "1"}, {"LOAD_VEC_A", load_vec_a}, {"LOAD_VEC_B", load_vec}, {"B_TYPE", aligned_b_type_f16}, {"B_TYPEV4", "f16vec4"}, {"D_TYPE", "float"}, {"ALIGNED", "1"}}), fp16, coopmat, coopmat2, f16acc);
}
#if defined(GGML_VULKAN_INTEGER_DOT_GLSLC_SUPPORT)
@@ -854,12 +850,21 @@ void process_shaders() {
string_to_spv("repeat_i32", "repeat.comp", {{"A_TYPE", "int32_t"}, {"D_TYPE", "int32_t"}});
string_to_spv("repeat_back_f32", "repeat_back.comp", {{"A_TYPE", "float"}, {"D_TYPE", "float"}});
string_to_spv("get_rows_back_f32", "get_rows_back.comp", {{"A_TYPE", "float"}, {"B_TYPE", "int"}, {"D_TYPE", "float"}});
string_to_spv("repeat_i16", "repeat.comp", {{"A_TYPE", "int16_t"}, {"D_TYPE", "int16_t"}});
string_to_spv("scale_f32", "scale.comp", {{"A_TYPE", "float"}, {"D_TYPE", "float"}, {"FLOAT_TYPE", "float"}});
string_to_spv("sqr_f32", "square.comp", {{"A_TYPE", "float"}, {"D_TYPE", "float"}, {"FLOAT_TYPE", "float"}});
string_to_spv("sqrt_f32", "sqrt.comp", {{"A_TYPE", "float"}, {"D_TYPE", "float"}, {"FLOAT_TYPE", "float"}});
string_to_spv("sin_f32", "sin.comp", {{"A_TYPE", "float"}, {"D_TYPE", "float"}, {"FLOAT_TYPE", "float"}});
string_to_spv("cos_f32", "cos.comp", {{"A_TYPE", "float"}, {"D_TYPE", "float"}, {"FLOAT_TYPE", "float"}});
string_to_spv("clamp_f32", "clamp.comp", {{"A_TYPE", "float"}, {"D_TYPE", "float"}, {"FLOAT_TYPE", "float"}});
string_to_spv("pad_f32", "pad.comp", {{"A_TYPE", "float"}, {"D_TYPE", "float"}});
string_to_spv("concat_i8", "concat.comp", {{"A_TYPE", "uint8_t"}, {"B_TYPE", "uint8_t"}, {"D_TYPE", "uint8_t"}});
@@ -886,18 +891,6 @@ void process_shaders() {
string_to_spv("silu_f32", "unary.comp", {{"A_TYPE", "float"}, {"D_TYPE", "float"}, {"OP", "op_silu"}});
string_to_spv("relu_f16", "unary.comp", {{"A_TYPE", "float16_t"}, {"D_TYPE", "float16_t"}, {"OP", "op_relu"}});
string_to_spv("relu_f32", "unary.comp", {{"A_TYPE", "float"}, {"D_TYPE", "float"}, {"OP", "op_relu"}});
string_to_spv("sqr_f16", "unary.comp", {{"A_TYPE", "float16_t"}, {"D_TYPE", "float16_t"}, {"OP", "op_sqr"}});
string_to_spv("sqr_f32", "unary.comp", {{"A_TYPE", "float"}, {"D_TYPE", "float"}, {"OP", "op_sqr"}});
string_to_spv("sqrt_f16", "unary.comp", {{"A_TYPE", "float16_t"}, {"D_TYPE", "float16_t"}, {"OP", "op_sqrt"}});
string_to_spv("sqrt_f32", "unary.comp", {{"A_TYPE", "float"}, {"D_TYPE", "float"}, {"OP", "op_sqrt"}});
string_to_spv("sin_f16", "unary.comp", {{"A_TYPE", "float16_t"}, {"D_TYPE", "float16_t"}, {"OP", "op_sin"}});
string_to_spv("sin_f32", "unary.comp", {{"A_TYPE", "float"}, {"D_TYPE", "float"}, {"OP", "op_sin"}});
string_to_spv("cos_f16", "unary.comp", {{"A_TYPE", "float16_t"}, {"D_TYPE", "float16_t"}, {"OP", "op_cos"}});
string_to_spv("cos_f32", "unary.comp", {{"A_TYPE", "float"}, {"D_TYPE", "float"}, {"OP", "op_cos"}});
string_to_spv("clamp_f16", "unary.comp", {{"A_TYPE", "float16_t"}, {"D_TYPE", "float16_t"}, {"OP", "op_clamp"}});
string_to_spv("clamp_f32", "unary.comp", {{"A_TYPE", "float"}, {"D_TYPE", "float"}, {"OP", "op_clamp"}});
string_to_spv("leaky_relu_f16", "unary.comp", {{"A_TYPE", "float16_t"}, {"D_TYPE", "float16_t"}, {"OP", "op_leaky_relu"}});
string_to_spv("leaky_relu_f32", "unary.comp", {{"A_TYPE", "float"}, {"D_TYPE", "float"}, {"OP", "op_leaky_relu"}});
string_to_spv("neg_f16", "unary.comp", {{"A_TYPE", "float16_t"}, {"D_TYPE", "float16_t"}, {"OP", "op_neg"}});
string_to_spv("neg_f32", "unary.comp", {{"A_TYPE", "float"}, {"D_TYPE", "float"}, {"OP", "op_neg"}});
string_to_spv("tanh_f16", "unary.comp", {{"A_TYPE", "float16_t"}, {"D_TYPE", "float16_t"}, {"OP", "op_tanh"}});
@@ -955,6 +948,7 @@ void process_shaders() {
string_to_spv("geglu_quick_f16","geglu_quick.comp", {{"A_TYPE", "float16_t"}, {"D_TYPE", "float16_t"}});
string_to_spv("geglu_quick_f32","geglu_quick.comp", {{"A_TYPE", "float"}, {"D_TYPE", "float"}});
string_to_spv("leaky_relu_f32", "leaky_relu.comp", {{"A_TYPE", "float"}, {"D_TYPE", "float"}});
string_to_spv("silu_back_f32", "silu_back.comp", {{"A_TYPE", "float"}, {"B_TYPE", "float"}, {"D_TYPE", "float"}});
string_to_spv("diag_mask_inf_f32", "diag_mask_inf.comp", {{"A_TYPE", "float"}, {"D_TYPE", "float"}});
@@ -1066,31 +1060,6 @@ void process_shaders() {
}
}
for (auto unroll : {false, true}) {
for (auto a_f16 : {false, true}) {
std::map<std::string, std::string> defines = {
{"A_TYPE", a_f16 ? "float16_t" : "float"}, {"B_TYPE", "float"}, {"D_TYPE", "float"},
{"UNROLL", unroll ? "[[unroll]]" : ""},
};
std::string name = std::string("conv3d") + (a_f16 ? "_f16" : "") + "_f32";
string_to_spv(name + (unroll ? "_unroll" : ""), "conv3d_mm.comp", defines);
#if defined(GGML_VULKAN_COOPMAT2_GLSLC_SUPPORT)
if (unroll) {
auto cm2_defines = defines;
cm2_defines["COOPMAT2"] = "1";
string_to_spv(name, "conv3d_mm.comp", cm2_defines, true, false, true);
}
#endif
#if defined(GGML_VULKAN_COOPMAT_GLSLC_SUPPORT)
if (unroll) {
auto cm1_defines = defines;
cm1_defines["COOPMAT"] = "1";
string_to_spv(name, "conv3d_mm.comp", cm1_defines, true, true, false);
}
#endif
}
}
string_to_spv("conv2d_dw_whcn_f32", "conv2d_dw.comp", merge_maps(base_dict, {{"A_TYPE", "float"}, {"B_TYPE", "float"}, {"D_TYPE", "float"}, {"WHCN", "1"}}));
string_to_spv("conv2d_dw_cwhn_f32", "conv2d_dw.comp", merge_maps(base_dict, {{"A_TYPE", "float"}, {"B_TYPE", "float"}, {"D_TYPE", "float"}, {"CWHN", "1"}}));
string_to_spv("conv2d_dw_whcn_f16_f32", "conv2d_dw.comp", merge_maps(base_dict, {{"A_TYPE", "float16_t"}, {"B_TYPE", "float"}, {"D_TYPE", "float"}, {"WHCN", "1"}}));
@@ -1282,11 +1251,6 @@ int main(int argc, char** argv) {
process_shaders();
if (compile_failed) {
std::cerr << "vulkan-shaders-gen: one or more shaders failed to compile" << std::endl;
return EXIT_FAILURE;
}
write_output_files();
return EXIT_SUCCESS;
@@ -905,12 +905,11 @@ struct ggml_webgpu_mul_mat_vec_pipeline_key {
ggml_type src0_type;
ggml_type src1_type;
int vectorized;
uint32_t num_cols;
bool use_mmvq;
bool operator==(const ggml_webgpu_mul_mat_vec_pipeline_key & other) const {
return src0_type == other.src0_type && src1_type == other.src1_type && vectorized == other.vectorized &&
num_cols == other.num_cols && use_mmvq == other.use_mmvq;
use_mmvq == other.use_mmvq;
}
};
@@ -920,7 +919,6 @@ struct ggml_webgpu_mul_mat_vec_pipeline_key_hash {
ggml_webgpu_hash_combine(seed, key.src0_type);
ggml_webgpu_hash_combine(seed, key.src1_type);
ggml_webgpu_hash_combine(seed, key.vectorized);
ggml_webgpu_hash_combine(seed, key.num_cols);
ggml_webgpu_hash_combine(seed, key.use_mmvq);
return seed;
}
@@ -995,12 +993,11 @@ struct ggml_webgpu_mul_mat_id_pipeline_key {
ggml_type src0_type;
ggml_type src1_type;
uint32_t n_experts;
uint32_t num_cols;
int vectorized;
bool operator==(const ggml_webgpu_mul_mat_id_pipeline_key & other) const {
return src0_type == other.src0_type && src1_type == other.src1_type && n_experts == other.n_experts &&
num_cols == other.num_cols && vectorized == other.vectorized;
vectorized == other.vectorized;
}
};
@@ -1010,7 +1007,6 @@ struct ggml_webgpu_mul_mat_id_pipeline_key_hash {
ggml_webgpu_hash_combine(seed, key.src0_type);
ggml_webgpu_hash_combine(seed, key.src1_type);
ggml_webgpu_hash_combine(seed, key.n_experts);
ggml_webgpu_hash_combine(seed, key.num_cols);
ggml_webgpu_hash_combine(seed, key.vectorized);
return seed;
}
@@ -1111,7 +1107,7 @@ inline bool ggml_webgpu_can_use_mmvq(const ggml_tensor * src0,
const ggml_tensor * src1,
bool supports_dot_product,
const std::string & vendor) {
if (src1->ne[1] <= 4) {
if (src1->ne[1] == 1) {
bool supports_dp4a = vendor == "amd" || vendor == "intel" || vendor == "nvidia";
if (supports_dp4a && supports_dot_product) {
switch (src1->type) {
@@ -1893,7 +1889,6 @@ class ggml_webgpu_shader_lib {
(context.src0->type == GGML_TYPE_F32 || context.src0->type == GGML_TYPE_F16)) ?
1 :
0;
key.num_cols = context.dst->ne[1];
key.use_mmvq =
ggml_webgpu_can_use_mmvq(context.src0, context.src1, context.supports_dot_product, context.vendor);
@@ -2009,7 +2004,6 @@ class ggml_webgpu_shader_lib {
if (key.vectorized) {
variant += "_vectorized";
}
defines.push_back(std::string("NUM_COLS=") + std::to_string(key.num_cols));
auto processed = preprocessor.preprocess(shader_src, defines);
auto decisions = std::make_shared<ggml_webgpu_mul_mat_vec_shader_decisions>();
@@ -2427,7 +2421,6 @@ class ggml_webgpu_shader_lib {
if (key.vectorized) {
variant += "_vectorized";
}
defines.push_back(std::string("NUM_COLS=1"));
defines.push_back(std::string("N_EXPERTS=") + std::to_string(key.n_experts));
+10 -12
View File
@@ -1418,17 +1418,15 @@ static void ggml_webgpu_quantize_q8_dispatch(webgpu_context &
const size_t dst_offset = ggml_webgpu_tensor_offset(dst);
const size_t q8_src1_align_offset = ROUNDUP_POW2(
dst_offset + ggml_nbytes(dst), ctx->global_ctx->capabilities.limits.minStorageBufferOffsetAlignment);
const size_t q8_src1_binding_size = ROUNDUP_POW2(
src1->ne[3] * src1->ne[2] * src1->ne[1] * (36 /* sizeof(q8_1) */ * (src1->ne[0] / /* block_size */ 32)),
WEBGPU_STORAGE_BUF_BINDING_MULT);
const size_t q8_src1_binding_size =
ROUNDUP_POW2(src1->ne[3] * src1->ne[2] * (36 /* sizeof(q8_1) */ * (src1->ne[0] / /* block_size */ 32)),
WEBGPU_STORAGE_BUF_BINDING_MULT);
std::vector<uint32_t> q8_params = {
(uint32_t) (ggml_webgpu_tensor_misalignment(ctx, src1) / ggml_type_size(src1->type)),
(uint32_t) (src1->nb[1] / ggml_type_size(src1->type)),
(uint32_t) (src1->nb[2] / ggml_type_size(src1->type)),
(uint32_t) (src1->nb[3] / ggml_type_size(src1->type)),
(uint32_t) src1->ne[0],
(uint32_t) src1->ne[1],
(uint32_t) src1->ne[2],
(uint32_t) src1->ne[3],
};
@@ -1444,7 +1442,7 @@ static void ggml_webgpu_quantize_q8_dispatch(webgpu_context &
uint32_t q8_wg_x = 1;
uint32_t q8_wg_y = 1;
const uint32_t wg_per_vec = (src0->ne[0] / 4 + (q8_wg_size - 1)) / q8_wg_size;
const uint32_t q8_total_wg = src1->ne[1] * src1->ne[2] * src1->ne[3] * wg_per_vec;
const uint32_t q8_total_wg = src1->ne[2] * src1->ne[3] * wg_per_vec;
const uint32_t max_wg_per_dim = ctx->global_ctx->capabilities.limits.maxComputeWorkgroupsPerDimension;
compute_2d_workgroups(q8_total_wg, max_wg_per_dim, q8_wg_x, q8_wg_y);
@@ -1458,7 +1456,7 @@ static webgpu_encoded_op ggml_webgpu_mul_mat(webgpu_context & ctx,
ggml_tensor * src1,
ggml_tensor * dst) {
// Determine if this is a mat-vec operation
bool use_mat_vec = (dst->ne[1] <= 4);
bool is_vec = (dst->ne[1] == 1);
// use MMVQ path for mat-vec
bool use_mmvq = ggml_webgpu_can_use_mmvq(src0, src1, ctx->global_ctx->capabilities.supports_dot_product,
@@ -1484,7 +1482,7 @@ static webgpu_encoded_op ggml_webgpu_mul_mat(webgpu_context & ctx,
webgpu_pipeline pipeline;
std::vector<webgpu_dispatch_desc> dispatches;
if (use_mat_vec) {
if (is_vec) {
if (use_mmvq) {
ggml_webgpu_quantize_q8_dispatch(ctx, src0, src1, dst, dispatches);
}
@@ -1531,7 +1529,7 @@ static webgpu_encoded_op ggml_webgpu_mul_mat(webgpu_context & ctx,
uint32_t wg_y = 1;
const uint32_t max_wg_per_dim = ctx->global_ctx->capabilities.limits.maxComputeWorkgroupsPerDimension;
if (use_mat_vec) {
if (is_vec) {
auto * decisions = static_cast<ggml_webgpu_mul_mat_vec_shader_decisions *>(pipeline.context.get());
uint32_t batches = dst->ne[2] * dst->ne[3];
@@ -3693,8 +3691,8 @@ static size_t ggml_backend_webgpu_buffer_type_get_alloc_size(ggml_backend_buffer
ggml_webgpu_can_use_mmvq(src0, src1, ctx->webgpu_global_ctx->capabilities.supports_dot_product,
ctx->webgpu_global_ctx->vendor);
if (use_mmvq) {
const size_t q8_src1_size = src1->ne[3] * src1->ne[2] * src1->ne[1] *
(36 /* sizeof(q8_1) */ * (src1->ne[0] / /* block_size */ 32));
const size_t q8_src1_size =
src1->ne[3] * src1->ne[2] * (36 /* sizeof(q8_1) */ * (src1->ne[0] / /* block_size */ 32));
res = ROUNDUP_POW2(res + q8_src1_size +
ctx->webgpu_global_ctx->capabilities.limits.minStorageBufferOffsetAlignment,
WEBGPU_STORAGE_BUF_BINDING_MULT);
@@ -4270,7 +4268,7 @@ static bool ggml_backend_webgpu_device_supports_op(ggml_backend_dev_t dev, const
case GGML_OP_RMS_NORM:
case GGML_OP_NORM:
case GGML_OP_L2_NORM:
supports_op = (op->type == GGML_TYPE_F32 && src0->type == GGML_TYPE_F32) && ggml_is_contiguous_rows(src0);
supports_op = op->type == GGML_TYPE_F32 && src0->type == GGML_TYPE_F32;
break;
case GGML_OP_ROPE:
supports_op = op->type == GGML_TYPE_F32 || op->type == GGML_TYPE_F16;
@@ -103,7 +103,7 @@ fn main(
#ifdef USE_SUBGROUP_REDUCTION
for (var row = 0u; row < OUTPUTS_PER_WG; row++) {
let subgroup_total = subgroupAdd(acc[0][row]);
let subgroup_total = subgroupAdd(acc[row]);
if (subgroup_invocation_id == 0u) {
partial_sums[partial_index(row, subgroup_id)] = subgroup_total;
}
@@ -126,7 +126,7 @@ fn main(
#ifdef USE_WORKGROUP_REDUCTION
for (var row = 0u; row < OUTPUTS_PER_WG; row++) {
partial_sums[partial_index(row, thread_id)] = acc[0][row];
partial_sums[partial_index(row, thread_id)] = acc[row];
}
workgroupBarrier();
@@ -91,67 +91,61 @@ fn main(
let dst_idx_base = params.offset_dst + dst3_idx * dst3_stride + dst2_idx * dst2_stride + row_base;
#ifdef MMVQ
let src1q_idx_base = (src13_idx * params.bs02 * params.broadcast2 + src12_idx) * params.n * (params.k / 32u);
let src1q_idx_base = (src13_idx * params.bs02 * params.broadcast2 + src12_idx) * (params.k / 32u);
let acc = accumulate_vec_q_dot(thread_id, row_base, src0_batch_offset, src1q_idx_base);
#else
let src1_idx_base = params.offset_src1 + src13_idx * params.stride_13 + src12_idx * params.stride_12;
let acc = accumulate_vec_dot(thread_id, row_base, src0_batch_offset, src1_idx_base);
#endif
for (var col = 0u;col < NUM_COLS;col += 1) {
#ifdef USE_SUBGROUP_REDUCTION
for (var row = 0u; row < OUTPUTS_PER_WG; row++) {
let subgroup_total = subgroupAdd(acc[col][row]);
if (subgroup_invocation_id == 0u) {
partial_sums[partial_index(row, subgroup_id)] = subgroup_total;
}
}
workgroupBarrier();
for (var row = subgroup_id; (row < OUTPUTS_PER_WG) && (row_base + row < params.m); row += num_subgroups) {
let output_row = row_base + row;
var row_acc = 0.0f;
for (var k = subgroup_invocation_id; k < num_subgroups; k += subgroup_size) {
row_acc += partial_sums[partial_index(row, k)];
}
let row_total = subgroupAdd(row_acc);
if (subgroup_invocation_id == 0) {
dst[dst_idx_base + col * params.m + row] = row_total;
}
}
#endif
#ifdef USE_WORKGROUP_REDUCTION
for (var row = 0u; row < OUTPUTS_PER_WG; row++) {
partial_sums[partial_index(row, thread_id)] = acc[col][row];
}
workgroupBarrier();
var stride = WG_SIZE / 2u;
while (stride > 0) {
if (thread_id < stride) {
for (var row = 0u; row < OUTPUTS_PER_WG; row++) {
partial_sums[partial_index(row, thread_id)] += partial_sums[partial_index(row, thread_id + stride)];
}
}
workgroupBarrier();
stride = stride / 2;
}
if (thread_id < OUTPUTS_PER_WG) {
let output_row = row_base + thread_id;
if (output_row < params.m) {
dst[dst_idx_base + col * params.m + thread_id] = partial_sums[partial_index(thread_id, 0)];
}
}
#endif
for (var row = 0u; row < OUTPUTS_PER_WG; row++) {
let subgroup_total = subgroupAdd(acc[row]);
if (subgroup_invocation_id == 0u) {
partial_sums[partial_index(row, subgroup_id)] = subgroup_total;
}
}
workgroupBarrier();
for (var row = subgroup_id; (row < OUTPUTS_PER_WG) && (row_base + row < params.m); row += num_subgroups) {
let output_row = row_base + row;
var row_acc = 0.0f;
for (var k = subgroup_invocation_id; k < num_subgroups; k += subgroup_size) {
row_acc += partial_sums[partial_index(row, k)];
}
let row_total = subgroupAdd(row_acc);
if (subgroup_invocation_id == 0) {
dst[dst_idx_base + row] = row_total;
}
}
#endif
#ifdef USE_WORKGROUP_REDUCTION
for (var row = 0u; row < OUTPUTS_PER_WG; row++) {
partial_sums[partial_index(row, thread_id)] = acc[row];
}
workgroupBarrier();
var stride = WG_SIZE / 2u;
while (stride > 0) {
if (thread_id < stride) {
for (var row = 0u; row < OUTPUTS_PER_WG; row++) {
partial_sums[partial_index(row, thread_id)] += partial_sums[partial_index(row, thread_id + stride)];
}
}
workgroupBarrier();
stride = stride / 2;
}
if (thread_id < OUTPUTS_PER_WG) {
let output_row = row_base + thread_id;
if (output_row < params.m) {
dst[dst_idx_base + thread_id] = partial_sums[partial_index(thread_id, 0)];
}
}
#endif
}
File diff suppressed because it is too large Load Diff
@@ -51,7 +51,10 @@ fn repack_b_dm(block: u32) -> B_DS_TYPE {
fn get_dm(block_byte_base: u32) -> f32 {
return f32(load_f16_at_src0(block_byte_base));
}
#endif // MUL_ACC_Q4_0
fn mul_q8_1(row_sum: i32, da: f32, b_ds: B_DS_TYPE) -> f32 {
return f32(row_sum) * (da * b_ds.x) - 8.0 * da * b_ds.y / THREADS_PER_BLOCK;
}
#endif
#ifdef MUL_ACC_Q4_1
#define BLOCK_SIZE_BYTES 20
@@ -82,7 +85,10 @@ fn get_dm(block_byte_base: u32) -> vec2<f32> {
f32(load_f16_at_src0(block_byte_base + 2u))
);
}
#endif // MUL_ACC_Q4_1
fn mul_q8_1(row_sum: i32, dma: vec2<f32>, b_ds: B_DS_TYPE) -> f32 {
return f32(row_sum) * (dma.x * b_ds.x) + dma.y * b_ds.y / THREADS_PER_BLOCK;
}
#endif
#ifdef MUL_ACC_Q8_0
#define BLOCK_SIZE_BYTES 34
@@ -105,48 +111,46 @@ fn repack_b_dm(block: u32) -> B_DS_TYPE {
fn get_dm(block_byte_base: u32) -> f32 {
return f32(load_f16_at_src0(block_byte_base));
}
#endif // MUL_ACC_Q8_0
fn mul_q8_1(row_sum: i32, da: f32, b_ds: B_DS_TYPE) -> f32 {
return f32(row_sum) * (da * b_ds);
}
#endif
#if defined(LEGACY_QUANTS)
fn accumulate_vec_q_dot(thread_id: u32, row_base: u32, src0_batch_offset: u32, src1q_idx_base: u32) -> array<array<f32, OUTPUTS_PER_WG>, NUM_COLS> {
var acc: array<array<f32, OUTPUTS_PER_WG>, NUM_COLS>;
#ifdef LEGACY_QUANTS
fn mmvq_dot_product(a_byte_base: u32, b_inner_id: u32, b_repacked: vec2<u32>, b_ds: B_DS_TYPE) -> f32 {
var row_sum = 0;
let a_repacked = repack_a(a_byte_base, b_inner_id);
row_sum += dot4I8Packed(a_repacked[0], b_repacked[0]);
row_sum += dot4I8Packed(a_repacked[1], b_repacked[1]);
return mul_q8_1(row_sum, get_dm(a_byte_base), b_ds);
}
fn accumulate_vec_q_dot(thread_id: u32, row_base: u32, src0_batch_offset: u32, src1q_idx_base: u32) -> array<f32, OUTPUTS_PER_WG> {
var acc: array<f32, OUTPUTS_PER_WG>;
let num_blocks = params.k / BLOCK_SIZE;
for (var block = thread_id / THREADS_PER_BLOCK; block < num_blocks; block += WG_SIZE / THREADS_PER_BLOCK) {
let inner_id = thread_id % THREADS_PER_BLOCK;
let b_inner_id = thread_id % THREADS_PER_BLOCK;
let b_block_idx = src1q_idx_base + block;
let b_repacked = repack_b_qs(b_block_idx, b_inner_id);
let b_ds = repack_b_dm(b_block_idx);
for (var row = 0u; row < OUTPUTS_PER_WG; row++) {
let output_row = row_base + row;
if (output_row < params.m) {
let block_byte_base = (src0_batch_offset + output_row * params.stride_01 + block) * BLOCK_SIZE_BYTES;
let a_repacked = repack_a(block_byte_base, inner_id);
let da = get_dm(block_byte_base);
for (var col = 0u;col < NUM_COLS;col += 1) {
let src1q_idx = src1q_idx_base + col * (params.k / Q8_BLOCK_SIZE) + block;
let b_repacked = repack_b_qs(src1q_idx, inner_id);
let b_ds = repack_b_dm(src1q_idx);
let row_sum = dot4I8Packed(a_repacked[0], b_repacked[0]) + dot4I8Packed(a_repacked[1], b_repacked[1]);
#if defined(MUL_ACC_Q4_0)
acc[col][row] += f32(row_sum) * (da * b_ds.x) - 8.0 * da * b_ds.y / THREADS_PER_BLOCK;
#endif // MUL_ACC_Q4_0
#if defined(MUL_ACC_Q4_1)
acc[col][row] += f32(row_sum) * (da.x * b_ds.x) + da.y * b_ds.y / THREADS_PER_BLOCK;
#endif // MUL_ACC_Q4_1
#if defined(MUL_ACC_Q8_0)
acc[col][row] += f32(row_sum) * (da * b_ds);
#endif // MUL_ACC_Q8_0
}
acc[row] += mmvq_dot_product(block_byte_base, b_inner_id, b_repacked, b_ds);
}
}
}
return acc;
}
#endif // LEGACY_QUANTS
#endif
#ifdef MUL_ACC_Q2_K
#define BLOCK_SIZE_BYTES 84
@@ -187,7 +191,22 @@ fn get_scale_min(block_byte_base: u32, tid: u32) -> vec2<f32> {
let scale = byte_of(load_u32_at_src0_aligned(scale_byte), scale_byte & 3u);
return vec2<f32>(f32(scale & 0xFu), f32(scale >> 4u));
}
#endif // MUL_ACC_Q2_K
fn mmvq_dot_product(a_byte_base: u32, tid: u32, b_repacked: vec4<u32>, b_ds: B_DS_TYPE) -> f32 {
let a_repacked = repack_a(a_byte_base, tid);
let dm = get_dm(a_byte_base);
let scale_min = get_scale_min(a_byte_base, tid);
let scale_q = i32(scale_min.x);
let scale_m_i8x4 = u32(scale_min.y) * 0x01010101u;
let row_sum_d = (dot4I8Packed(b_repacked[0], a_repacked[0]) + dot4I8Packed(b_repacked[1], a_repacked[1])
+ dot4I8Packed(b_repacked[2], a_repacked[2]) + dot4I8Packed(b_repacked[3], a_repacked[3])) * scale_q;
let row_sum_m = dot4I8Packed(b_repacked[0], scale_m_i8x4) + dot4I8Packed(b_repacked[1], scale_m_i8x4)
+ dot4I8Packed(b_repacked[2], scale_m_i8x4) + dot4I8Packed(b_repacked[3], scale_m_i8x4);
return b_ds * (dm.x * f32(row_sum_d) - dm.y * f32(row_sum_m));
}
#endif
#ifdef MUL_ACC_Q4_K
#define BLOCK_SIZE_BYTES 144
@@ -246,52 +265,39 @@ fn get_scale_min(block_byte_base: u32, tid: u32) -> vec2<f32> {
return vec2<f32>(scale, min_val);
}
#endif // MUL_ACC_Q4_K
fn mmvq_dot_product(a_byte_base: u32, tid: u32, b_repacked: vec4<u32>, b_ds: B_DS_TYPE) -> f32 {
let a_repacked = repack_a(a_byte_base, tid);
let dm = get_dm(a_byte_base);
let scale_min = get_scale_min(a_byte_base, tid);
let row_sum = dot4I8Packed(a_repacked[0], b_repacked[0]) + dot4I8Packed(a_repacked[1], b_repacked[1])
+ dot4I8Packed(a_repacked[2], b_repacked[2]) + dot4I8Packed(a_repacked[3], b_repacked[3]);
// Each thread covers half of the Q8_1 block, so add only b_ds.y/2.
return b_ds.x * dm.x * scale_min.x * f32(row_sum) - dm.y * scale_min.y * (b_ds.y / (Q8_BLOCK_SIZE / ELEMS_PER_THREAD));
}
#endif
#ifdef K_QUANTS
fn accumulate_vec_q_dot(thread_id: u32, row_base: u32, src0_batch_offset: u32, src1q_idx_base: u32) -> array<array<f32, OUTPUTS_PER_WG>, NUM_COLS> {
var acc: array<array<f32, OUTPUTS_PER_WG>, NUM_COLS>;
fn accumulate_vec_q_dot(thread_id: u32, row_base: u32, src0_batch_offset: u32, src1q_idx_base: u32) -> array<f32, OUTPUTS_PER_WG> {
var acc: array<f32, OUTPUTS_PER_WG>;
let tid = thread_id % THREADS_PER_BLOCK;
for (var block = thread_id / THREADS_PER_BLOCK; block < params.k / BLOCK_SIZE; block += WG_SIZE / THREADS_PER_BLOCK) {
let src1q_idx = src1q_idx_base + (block * BLOCK_SIZE + ELEMS_PER_THREAD * tid) / Q8_BLOCK_SIZE;
let b_repacked = repack_b_qs(src1q_idx, tid);
let b_ds = repack_b_dm(src1q_idx);
for (var row = 0u; row < OUTPUTS_PER_WG; row++) {
let output_row = row_base + row;
if (output_row < params.m) {
let block_byte_base = (src0_batch_offset + output_row * params.stride_01 + block) * BLOCK_SIZE_BYTES;
let a_repacked = repack_a(block_byte_base, tid);
let dm = get_dm(block_byte_base);
let scale_min = get_scale_min(block_byte_base, tid);
for (var col = 0u;col < NUM_COLS;col += 1) {
let src1q_idx = src1q_idx_base + col * (params.k / Q8_BLOCK_SIZE) + (block * BLOCK_SIZE + ELEMS_PER_THREAD * tid) / Q8_BLOCK_SIZE;
let b_repacked = repack_b_qs(src1q_idx, tid);
let b_ds = repack_b_dm(src1q_idx);
#if defined(MUL_ACC_Q2_K)
let scale_q = i32(scale_min.x);
let scale_m_i8x4 = u32(scale_min.y) * 0x01010101u;
let row_sum_d = (dot4I8Packed(b_repacked[0], a_repacked[0]) + dot4I8Packed(b_repacked[1], a_repacked[1])
+ dot4I8Packed(b_repacked[2], a_repacked[2]) + dot4I8Packed(b_repacked[3], a_repacked[3])) * scale_q;
let row_sum_m = dot4I8Packed(b_repacked[0], scale_m_i8x4) + dot4I8Packed(b_repacked[1], scale_m_i8x4)
+ dot4I8Packed(b_repacked[2], scale_m_i8x4) + dot4I8Packed(b_repacked[3], scale_m_i8x4);
acc[col][row] += b_ds * (dm.x * f32(row_sum_d) - dm.y * f32(row_sum_m));
#endif // MUL_ACC_Q2_K
#if defined(MUL_ACC_Q4_K)
let row_sum = dot4I8Packed(a_repacked[0], b_repacked[0]) + dot4I8Packed(a_repacked[1], b_repacked[1])
+ dot4I8Packed(a_repacked[2], b_repacked[2]) + dot4I8Packed(a_repacked[3], b_repacked[3]);
// Each thread covers half of the Q8_1 block, so add only b_ds.y/2.
acc[col][row] += b_ds.x * dm.x * scale_min.x * f32(row_sum) - dm.y * scale_min.y * (b_ds.y / (Q8_BLOCK_SIZE / ELEMS_PER_THREAD));
#endif // MUL_ACC_Q4_K
}
acc[row] += mmvq_dot_product(block_byte_base, tid, b_repacked, b_ds);
}
}
}
return acc;
}
#endif // K_QUANTS
#endif
@@ -9,11 +9,9 @@ requires packed_4x8_integer_dot_product;
struct Params {
offset_src1: u32,
stride_11: u32,
stride_12: u32,
stride_13: u32,
ne0: u32,
ne1: u32,
ne2: u32,
ne3: u32,
};
@@ -59,28 +57,25 @@ fn main(
@builtin(num_workgroups) num_wg: vec3<u32>
) {
let thread_id = local_id.x;
let ne0_vec4 = params.ne0 / 4u;
let num_vec4 = params.ne0 / 4u;
let wg_per_vec = (ne0_vec4 + (WG_SIZE - 1u)) / WG_SIZE;
let total_batches = wg_per_vec * params.ne1 * params.ne2 * params.ne3;
let wg_per_vec = (num_vec4 + (WG_SIZE - 1u)) / WG_SIZE;
let total_batches = wg_per_vec * params.ne2 * params.ne3;
let wg_linear = wg_id.y * num_wg.x + wg_id.x;
if (wg_linear >= total_batches) {
return;
}
let vec_idx = wg_linear / wg_per_vec;
let src13_idx = vec_idx / (params.ne2 * params.ne1);
let vec_ne12_num = vec_idx % (params.ne2 * params.ne1);
let src12_idx = vec_ne12_num / params.ne1;
let src11_idx = vec_ne12_num % params.ne1;
let src1_idx_base = params.offset_src1 + src13_idx * params.stride_13 + src12_idx * params.stride_12 + src11_idx * params.stride_11;
let src13_idx = wg_linear / (params.ne2 * wg_per_vec);
let src12_idx = (wg_linear - src13_idx * (params.ne2 * wg_per_vec)) / wg_per_vec;
let src11_wg_idx = wg_linear % wg_per_vec;
let src1_idx_base = params.offset_src1 + src13_idx * params.stride_13 + src12_idx * params.stride_12;
let src1_idx_vec4_base = src1_idx_base / 4u;
let blocks_per_row = params.ne0 / 32u;
let blocks_per_wg = (WG_SIZE * 4u) / 32u;
let src1q_idx_base = ((src13_idx * params.ne2 + src12_idx) * params.ne1 + src11_idx) * blocks_per_row;
let src11_wg_idx = wg_linear % wg_per_vec;
let src1q_idx_base = (src13_idx * params.ne2 + src12_idx) * blocks_per_row;
let src1q_idx = src1q_idx_base + src11_wg_idx * blocks_per_wg + thread_id / 8u;
let qs_idx = thread_id % 8u;
@@ -90,7 +85,7 @@ fn main(
var thread_amax = 0.0;
let src11_vec4_idx = src11_wg_idx * WG_SIZE + thread_id;
let is_valid = src11_vec4_idx < ne0_vec4;
let is_valid = src11_vec4_idx < num_vec4;
#ifdef USE_SUBGROUP_REDUCTION
-1
View File
@@ -359,7 +359,6 @@ class Keys:
CHUNK_SIZE = "clip.audio.chunk_size"
CONV_KERNEL_SIZE = "clip.audio.conv_kernel_size"
MAX_POS_EMB = "clip.audio.max_pos_emb"
FEATURE_LAYERS = "clip.audio.feature_layer" # Granite Speech Plus
class Attention:
HEAD_COUNT = "clip.audio.attention.head_count"
-3
View File
@@ -1310,9 +1310,6 @@ class GGUFWriter:
def add_audio_max_pos_emb(self, value: int) -> None:
self.add_uint32(Keys.ClipAudio.MAX_POS_EMB, value)
def add_audio_feature_layers(self, layers: Sequence[int]) -> None:
self.add_array(Keys.ClipAudio.FEATURE_LAYERS, layers)
def add_audio_projector_window_size(self, value: int) -> None:
self.add_uint32(Keys.ClipAudio.Projector.WINDOW_SIZE, value)
+4 -14
View File
@@ -190,15 +190,7 @@ llama_model_lfm2::graph<iswa>::graph(const llama_model & model, const llm_graph_
auto * conv_rs = build_rs(inp_recr, conv_state, hparams.n_embd_r(), n_seqs);
auto * conv = ggml_reshape_3d(ctx0, conv_rs, d_conv, hparams.n_embd, n_seqs);
// causal prepends the state, non-causal pads symmetrically for a centered window
if (hparams.causal_attn) {
bx = ggml_concat(ctx0, conv, bx, 0);
} else {
const int64_t pad = (hparams.n_shortconv_l_cache - 1) / 2;
auto * left = ggml_cont(ctx0,
ggml_view_3d(ctx0, conv, pad, hparams.n_embd, n_seqs, conv->nb[1], conv->nb[2], (d_conv - pad) * conv->nb[0]));
bx = ggml_pad_ext(ctx0, ggml_concat(ctx0, left, bx, 0), 0, pad, 0, 0, 0, 0, 0, 0);
}
bx = ggml_concat(ctx0, conv, bx, 0);
GGML_ASSERT(bx->ne[0] > conv->ne[0]);
// last d_conv columns is a new conv state
@@ -274,12 +266,10 @@ llama_model_lfm2::graph<iswa>::graph(const llama_model & model, const llm_graph_
cb(cur, "result_norm", -1);
res->t_embd = cur;
if (!cparams.embeddings) {
cur = build_lora_mm(model.output, cur, model.output_s);
cb(cur, "result_output", -1);
cur = build_lora_mm(model.output, cur, model.output_s);
cb(cur, "result_output", -1);
res->t_logits = cur;
}
res->t_logits = cur;
ggml_build_forward_expand(gf, cur);
}
+8 -56
View File
@@ -3298,29 +3298,21 @@ struct test_norm : public test_case {
const std::array<int64_t, 4> ne;
const bool v; // whether a is a non-contiguous view
const float eps;
const bool noncontig_rows;
std::string vars() override {
return VARS_TO_STR5(type, ne, v, eps, noncontig_rows);
return VARS_TO_STR4(type, ne, v, eps);
}
test_norm(ggml_type type = GGML_TYPE_F32,
std::array<int64_t, 4> ne = {64, 5, 4, 3},
bool v = false,
float eps = 1e-6f,
bool noncontig_rows = false)
: type(type), ne(ne), v(v), eps(eps), noncontig_rows(noncontig_rows) {}
float eps = 1e-6f)
: type(type), ne(ne), v(v), eps(eps) {}
ggml_tensor * build_graph(ggml_context * ctx) override {
const std::array<int64_t, 4> ne_a = noncontig_rows ?
std::array<int64_t, 4>{ ne[1], ne[0], ne[2], ne[3] } : ne;
ggml_tensor * a = ggml_new_tensor(ctx, type, 4, ne_a.data());
ggml_tensor * a = ggml_new_tensor(ctx, type, 4, ne.data());
ggml_set_name(a, "a");
if (noncontig_rows) {
a = ggml_permute(ctx, a, 1, 0, 2, 3);
ggml_set_name(a, "permuted a");
}
if (v) {
a = ggml_view_4d(ctx, a, a->ne[0]/2, a->ne[1]/2, a->ne[2]/2, a->ne[3]/2, a->nb[1], a->nb[2], a->nb[3], 0);
ggml_set_name(a, "view of a");
@@ -6201,29 +6193,21 @@ struct test_l2_norm : public test_case {
const std::array<int64_t, 4> ne;
const float eps;
bool v;
bool noncontig_rows;
std::string vars() override {
return VARS_TO_STR5(type, ne, eps, v, noncontig_rows);
return VARS_TO_STR4(type, ne, eps, v);
}
test_l2_norm(ggml_type type = GGML_TYPE_F32,
std::array<int64_t, 4> ne = {64, 64, 320, 1},
float eps = 1e-12f,
bool v = false,
bool noncontig_rows = false)
: type(type), ne(ne), eps(eps), v(v), noncontig_rows(noncontig_rows) {}
bool v = false)
: type(type), ne(ne), eps(eps), v(v) {}
ggml_tensor * build_graph(ggml_context * ctx) override {
const std::array<int64_t, 4> ne_a = noncontig_rows ?
std::array<int64_t, 4>{ ne[1], ne[0], ne[2], ne[3] } : ne;
ggml_tensor * a = ggml_new_tensor(ctx, type, 4, ne_a.data());
ggml_tensor * a = ggml_new_tensor(ctx, type, 4, ne.data());
ggml_set_name(a, "a");
if (noncontig_rows) {
a = ggml_permute(ctx, a, 1, 0, 2, 3);
ggml_set_name(a, "permuted a");
}
if (v) {
a = ggml_view_4d(ctx, a, a->ne[0]/2, a->ne[1]/2, a->ne[2]/2, a->ne[3]/2, a->nb[1], a->nb[2], a->nb[3], 0);
ggml_set_name(a, "view of a");
@@ -8298,11 +8282,9 @@ static std::vector<std::unique_ptr<test_case>> make_test_cases_eval() {
test_cases.emplace_back(new test_norm(GGML_TYPE_F32, { n, 5, 4, 3 }, v, eps));
test_cases.emplace_back(new test_rms_norm(GGML_TYPE_F32, { n, 5, 4, 3 }, v, eps));
}
test_cases.emplace_back(new test_norm(GGML_TYPE_F32, { n, 5, 4, 3 }, false, eps, true));
test_cases.emplace_back(new test_rms_norm_back(GGML_TYPE_F32, { n, 5, 4, 3 }, eps));
test_cases.emplace_back(new test_l2_norm(GGML_TYPE_F32, { n, 5, 4, 3 }, eps, false));
test_cases.emplace_back(new test_l2_norm(GGML_TYPE_F32, { n, 5, 4, 3 }, eps, true));
test_cases.emplace_back(new test_l2_norm(GGML_TYPE_F32, { n, 5, 4, 3 }, eps, false, true));
}
}
@@ -8451,7 +8433,6 @@ static std::vector<std::unique_ptr<test_case>> make_test_cases_eval() {
test_cases.emplace_back(new test_mul_mat(type_a, type_b, 16, 1, k, {3, 2}, {2, 1}));
test_cases.emplace_back(new test_mul_mat(type_a, type_b, 16, 1, k, {3, 2}, {1, 2}));
test_cases.emplace_back(new test_mul_mat(type_a, type_b, 16, 1, k, {3, 2}, {2, 2}));
test_cases.emplace_back(new test_mul_mat(type_a, type_b, 16, 4, k, {3, 2}, {2, 2}));
test_cases.emplace_back(new test_mul_mat(type_a, type_b, 16, 16, k, {1, 1}, {1, 1}));
test_cases.emplace_back(new test_mul_mat(type_a, type_b, 16, 16, k, {1, 1}, {2, 1}));
@@ -8468,7 +8449,6 @@ static std::vector<std::unique_ptr<test_case>> make_test_cases_eval() {
test_cases.emplace_back(new test_mul_mat(type_a, type_b, 16, 1, k, {2, 3}, {1, 1}, {0, 1, 3, 2}));
test_cases.emplace_back(new test_mul_mat(type_a, type_b, 16, 1, k, {2, 3}, {1, 1}, {0, 3, 2, 1}));
test_cases.emplace_back(new test_mul_mat(type_a, type_b, 16, 4, k, {2, 3}, {1, 1}, {0, 3, 2, 1}));
test_cases.emplace_back(new test_mul_mat(type_a, type_b, 16, 8, k, {2, 3}, {1, 1}, {0, 2, 1, 3}));
test_cases.emplace_back(new test_mul_mat(type_a, type_b, 16, 8, k, {2, 3}, {1, 1}, {0, 1, 3, 2}));
test_cases.emplace_back(new test_mul_mat(type_a, type_b, 16, 8, k, {2, 3}, {1, 1}, {0, 3, 2, 1}));
@@ -9290,34 +9270,6 @@ static std::vector<std::unique_ptr<test_case>> make_test_cases_perf() {
}
}
struct conv3d_perf_case {
int N, IC, ID, IH, IW, OC, KD, KH, KW, s0, s1, s2, p0, p1, p2, d0, d1, d2;
};
const std::vector<conv3d_perf_case> conv3d_cases = {
{1, 320, 8, 38, 26, 1280, 3, 3, 3, 1, 1, 1, 1, 1, 1, 1, 1, 1},
{1, 1280, 8, 38, 26, 1280, 3, 3, 3, 1, 1, 1, 1, 1, 1, 1, 1, 1},
{1, 320, 8, 76, 52, 1280, 3, 3, 3, 1, 1, 1, 1, 1, 1, 1, 1, 1},
{1, 1280, 8, 76, 52, 1280, 3, 3, 3, 1, 1, 1, 1, 1, 1, 1, 1, 1},
{1, 320, 8, 152, 104, 1280, 3, 3, 3, 1, 1, 1, 1, 1, 1, 1, 1, 1},
#if 0
// too slow on some devices
{1, 1280, 8, 152, 104, 1280, 3, 3, 3, 1, 1, 1, 1, 1, 1, 1, 1, 1},
{1, 320, 4, 304, 208, 1280, 3, 3, 3, 1, 1, 1, 1, 1, 1, 1, 1, 1},
{1, 640, 4, 304, 208, 1280, 3, 3, 3, 1, 1, 1, 1, 1, 1, 1, 1, 1},
#endif
};
for (ggml_type kernel_type : {GGML_TYPE_F32, GGML_TYPE_F16}) {
for (const conv3d_perf_case & c : conv3d_cases) {
test_cases.emplace_back(new test_conv_3d(
c.N, c.IC, c.ID, c.IH, c.IW,
c.OC, c.KD, c.KH, c.KW,
c.s0, c.s1, c.s2, c.p0, c.p1, c.p2, c.d0, c.d1, c.d2,
kernel_type));
}
}
test_cases.emplace_back(new test_bin_bcast(ggml_add, GGML_TYPE_F32, {4096, 1, 1, 1}, {1, 1, 1, 1}));
test_cases.emplace_back(new test_bin_bcast(ggml_add, GGML_TYPE_F32, {4096, 1, 1, 1}, {1, 512, 1, 1}));
+24 -99
View File
@@ -1562,112 +1562,37 @@ static void test_msgs_oaicompat_json_conversion() {
}
}
static void test_msg_token_delimiters_split() {
static void test_split_by_role() {
LOG_DBG("%s\n", __func__);
// Delimiters that share a leading token, distinguished by the second token,
// to exercise the per-position token matching.
const common_chat_msg_delimiters delims = {
{ { COMMON_CHAT_ROLE_USER, "", { 10, 11 } },
{ COMMON_CHAT_ROLE_ASSISTANT, "", { 10, 12 } } }
};
// Empty inputs
assert_equals<size_t>(0, common_chat_msg_delimiters{}.split({}).spans.size());
assert_equals<size_t>(0, common_chat_msg_delimiters{}.split({ 10, 11 }).spans.size());
assert_equals<size_t>(0, delims.split({}).spans.size());
assert_equals<size_t>(0, common_chat_split_by_role("", {}).size());
assert_equals<size_t>(0, common_chat_split_by_role("hello", {}).size());
assert_equals<size_t>(0, common_chat_split_by_role("", { { "user", "<|user|>" } }).size());
// No delimiters match -> no spans
assert_equals<size_t>(0, delims.split({ 100, 101, 102 }).spans.size());
// Multi-role conversation: <user>Hi<assistant>Hello<user>Bye
// Multi-role conversation, no leading/trailing content
{
const llama_tokens tokens = {
10, 11, // <user>
100, 101, // Hi
10, 12, // <assistant>
200, 201, 202, // Hello
10, 11, // <user>
300, 301, // Bye
};
const std::string prompt = "<|user|>Hi<|assistant|>Hello<|user|>Bye";
const auto splits = common_chat_split_by_role(prompt, {
{ "user", "<|user|>" },
{ "assistant", "<|assistant|>" },
});
assert_equals<size_t>(3, splits.size());
const auto result = delims.split(tokens);
const auto & spans = result.spans;
assert_equals<size_t>(3, spans.size());
assert_equals<std::string>("user", splits[0].role);
assert_equals<size_t>(0, splits[0].pos);
assert_equals<size_t>(10, splits[0].len);
assert_equals<std::string>("<|user|>Hi", prompt.substr(splits[0].pos, splits[0].len));
assert_equals(COMMON_CHAT_ROLE_USER, spans[0].role);
assert_equals<size_t>(0, spans[0].pos);
assert_equals<size_t>(4, spans[0].len);
assert_equals<std::string>("assistant", splits[1].role);
assert_equals<size_t>(10, splits[1].pos);
assert_equals<size_t>(18, splits[1].len);
assert_equals<std::string>("<|assistant|>Hello", prompt.substr(splits[1].pos, splits[1].len));
assert_equals(COMMON_CHAT_ROLE_ASSISTANT, spans[1].role);
assert_equals<size_t>(4, spans[1].pos);
assert_equals<size_t>(5, spans[1].len);
assert_equals(COMMON_CHAT_ROLE_USER, spans[2].role);
assert_equals<size_t>(9, spans[2].pos);
assert_equals<size_t>(4, spans[2].len);
// is_user_start() is true at the token position where a user span begins
assert_equals(true, result.is_user_start(0));
assert_equals(false, result.is_user_start(4)); // assistant span
assert_equals(true, result.is_user_start(9));
}
// Content before the first delimiter is not captured as a span
{
const llama_tokens tokens = {
500, 501, // leading content (dropped)
10, 11, // <user>
100, // Hi
};
const auto spans = delims.split(tokens).spans;
assert_equals<size_t>(1, spans.size());
assert_equals(COMMON_CHAT_ROLE_USER, spans[0].role);
assert_equals<size_t>(2, spans[0].pos);
assert_equals<size_t>(3, spans[0].len);
}
// Skipped regions (media chunks) are jumped over but still count as span content
{
const llama_tokens tokens = {
10, 11, // <user>
LLAMA_TOKEN_NULL, // media chunk (3 tokens)
LLAMA_TOKEN_NULL,
LLAMA_TOKEN_NULL,
100, // Hi
10, 12, // <assistant>
};
const std::map<size_t, size_t> skips = { { 2, 3 } };
const auto spans = delims.split(tokens, skips).spans;
assert_equals<size_t>(2, spans.size());
assert_equals(COMMON_CHAT_ROLE_USER, spans[0].role);
assert_equals<size_t>(0, spans[0].pos);
assert_equals<size_t>(6, spans[0].len);
assert_equals(COMMON_CHAT_ROLE_ASSISTANT, spans[1].role);
assert_equals<size_t>(6, spans[1].pos);
assert_equals<size_t>(2, spans[1].len);
}
// A delimiter sequence inside a skipped region is not matched
{
const llama_tokens tokens = {
10, 11, // <user>
10, 12, // skipped region that happens to contain delimiter tokens
100, // Hi
};
const std::map<size_t, size_t> skips = { { 2, 2 } };
const auto spans = delims.split(tokens, skips).spans;
assert_equals<size_t>(1, spans.size());
assert_equals(COMMON_CHAT_ROLE_USER, spans[0].role);
assert_equals<size_t>(0, spans[0].pos);
assert_equals<size_t>(5, spans[0].len);
assert_equals<std::string>("user", splits[2].role);
assert_equals<size_t>(28, splits[2].pos);
assert_equals<size_t>(11, splits[2].len);
assert_equals<std::string>("<|user|>Bye", prompt.substr(splits[2].pos, splits[2].len));
}
}
@@ -5932,7 +5857,7 @@ int main(int argc, char ** argv) {
{
test_msg_diffs_compute();
test_msgs_oaicompat_json_conversion();
test_msg_token_delimiters_split();
test_split_by_role();
test_tools_oaicompat_json_conversion();
test_convert_responses_to_chatcmpl();
test_developer_role_to_system_workaround();
+1 -1
View File
@@ -42,7 +42,6 @@
#define KEY_N_HEAD "clip.%s.attention.head_count"
#define KEY_N_HEAD_KV "clip.%s.attention.head_count_kv"
#define KEY_LAYER_NORM_EPS "clip.%s.attention.layer_norm_epsilon"
#define KEY_FEATURE_LAYERS "clip.%s.feature_layer"
// vision-specific
#define KEY_VISION_PROJ_TYPE "clip.vision.projector_type" // for models with mixed modalities
@@ -55,6 +54,7 @@
#define KEY_PATCH_SIZE "clip.vision.patch_size"
#define KEY_IMAGE_MEAN "clip.vision.image_mean"
#define KEY_IMAGE_STD "clip.vision.image_std"
#define KEY_FEATURE_LAYER "clip.vision.feature_layer"
#define KEY_PROJ_SCALE_FACTOR "clip.vision.projector.scale_factor"
#define KEY_PROJ_SAMPLE_QUERY_SIDE "clip.vision.projector.query_side"
#define KEY_PROJ_SAMPLE_WINDOW_SIDE "clip.vision.projector.window_side"
+3 -3
View File
@@ -91,7 +91,7 @@ struct clip_hparams {
float eps = 1e-6;
float rope_theta = 0.0;
std::vector<int32_t> feature_layers;
std::vector<int32_t> vision_feature_layer;
int32_t attn_window_size = 0;
int32_t n_wa_pattern = 0;
std::unordered_set<int32_t> wa_layer_indexes; // explicit layer indexes that use full attention (for irregular patterns like YoutuVL)
@@ -165,8 +165,8 @@ struct clip_hparams {
return false;
}
bool is_feature_layer(int32_t layer) const {
return std::find(feature_layers.begin(), feature_layers.end(), layer) != feature_layers.end();
bool is_vision_feature_layer(int32_t layer) const {
return std::find(vision_feature_layer.begin(), vision_feature_layer.end(), layer) != vision_feature_layer.end();
}
};
+10 -9
View File
@@ -1264,10 +1264,12 @@ struct clip_model_loader {
}
}
// Load the vision/audio feature layer indices if they are explicitly provided
// Load the vision feature layer indices if they are explicitly provided;
// if multiple vision feature layers are present, the values will be concatenated
// to form the final visual features.
// NOTE: gguf conversions should standardize the values of the vision feature layer to
// be non-negative, since we use -1 to mark values as unset here.
get_arr_int(string_format(KEY_FEATURE_LAYERS, prefix), hparams.feature_layers, false);
get_arr_int(KEY_FEATURE_LAYER, hparams.vision_feature_layer, false);
// model-specific params
switch (model.proj_type) {
@@ -1649,7 +1651,6 @@ struct clip_model_loader {
get_u32(KEY_A_PROJ_WINDOW_SIZE, hparams.audio_proj_window_size);
get_u32(KEY_A_PROJ_DOWNSAMPLE_RATE, hparams.audio_proj_downsample_rate);
get_u32(KEY_A_PROJ_HEAD_COUNT, hparams.audio_proj_head_count);
// NOTE: feature layers loaded above in common path
} break;
case PROJECTOR_TYPE_JANUS_PRO:
{
@@ -1662,11 +1663,11 @@ struct clip_model_loader {
hparams.image_resize_algo = RESIZE_ALGO_BICUBIC_PILLOW;
hparams.image_resize_pad = PAD_CEIL;
// NOTE: feature_layers loaded in common path as optional
get_arr_int(KEY_FEATURE_LAYER, hparams.vision_feature_layer);
get_arr_int(KEY_PROJ_SPATIAL_OFFSETS, hparams.proj_spatial_offsets);
if (hparams.feature_layers.size() != hparams.proj_spatial_offsets.size()) {
throw std::runtime_error(string_format("%s: feature_layers.size() %d != proj_spatial_offsets.size() %d",
hparams.feature_layers.size(), hparams.proj_spatial_offsets.size()));
if (hparams.vision_feature_layer.size() != hparams.proj_spatial_offsets.size()) {
throw std::runtime_error(string_format("%s: vision_feature_layer.size() %d != proj_spatial_offsets.size() %d",
hparams.vision_feature_layer.size(), hparams.proj_spatial_offsets.size()));
}
get_u32(KEY_PROJ_SAMPLE_QUERY_SIDE, hparams.downsample_query_side);
@@ -2739,7 +2740,7 @@ struct clip_model_loader {
model.image_newline = get_tensor(TN_IMAGE_NEWLINE);
// Load separate layerwise and spatial projector tensors
const auto projector_count = hparams.feature_layers.size();
const auto projector_count = hparams.vision_feature_layer.size();
model.qf_proj_blocks.resize(projector_count);
for (size_t bid = 0; bid < projector_count; ++bid) {
auto & b = model.qf_proj_blocks[bid];
@@ -4387,7 +4388,7 @@ bool clip_image_batch_encode(clip_ctx * ctx, int n_threads, const clip_image_f32
// Stage 1b only uses block 0's permutations; future stages
// will upload all blocks.
for (size_t bid = 0; bid < hparams.feature_layers.size(); ++bid) {
for (size_t bid = 0; bid < hparams.vision_feature_layer.size(); ++bid) {
const std::string prefix = "g4v_blk" + std::to_string(bid) + "_";
upload(prefix + "win_idx", make_win_idx(image_side, window_side));
upload(prefix + "qwin_idx", make_win_idx(new_side, query_side));
+1 -35
View File
@@ -1,7 +1,5 @@
#include "models.h"
#include <algorithm>
ggml_cgraph * clip_graph_granite_speech::build() {
const int n_frames = img.nx();
const int context_size = hparams.audio_chunk_size;
@@ -13,10 +11,6 @@ ggml_cgraph * clip_graph_granite_speech::build() {
const int padded_len = num_blocks * context_size;
const int remainder = n_frames % context_size;
// Calculate projector input dimension based on feature layers
const int proj_input_dim = n_embd * (hparams.feature_layers.size() + 1);
const bool use_feature_concat = !hparams.feature_layers.empty();
ggml_tensor * attn_dists = ggml_new_tensor_1d(ctx0, GGML_TYPE_I32, context_size * context_size);
ggml_set_name(attn_dists, "attn_dists");
ggml_set_input(attn_dists);
@@ -37,15 +31,6 @@ ggml_cgraph * clip_graph_granite_speech::build() {
cur = ggml_add(ctx0, cur, model.inp_proj_b);
cb(cur, "inp_linear", -1);
// Capture layer 0 if requested (after input_linear)
ggml_tensor * concat_result = nullptr;
if (use_feature_concat) {
if (std::find(hparams.feature_layers.begin(), hparams.feature_layers.end(), 0) != hparams.feature_layers.end()) {
concat_result = cur;
cb(concat_result, "feature_layer_0", -1);
}
}
for (int il = 0; il < n_layer; il++) {
const auto & layer = model.layers[il];
auto * residual = cur;
@@ -183,18 +168,6 @@ ggml_cgraph * clip_graph_granite_speech::build() {
NORM_TYPE_NORMAL, eps, il);
cb(cur, "layer_out", il);
// Capture intermediate layer (il + 1) if requested
if (use_feature_concat) {
if (hparams.is_feature_layer(il + 1)) {
if (concat_result == nullptr) {
concat_result = cur;
} else {
concat_result = ggml_concat(ctx0, concat_result, cur, 0);
}
cb(concat_result, string_format("feature_layer_%d", il + 1).c_str(), il);
}
}
// CTC branch
if (il + 1 == ctc_layer) {
auto * mid = build_mm(model.ctc_out_w, cur);
@@ -207,13 +180,6 @@ ggml_cgraph * clip_graph_granite_speech::build() {
}
}
// Append final output to concatenated features if using feature concatenation
if (use_feature_concat && concat_result != nullptr) {
concat_result = ggml_concat(ctx0, concat_result, cur, 0);
cb(concat_result, "concat_final", -1);
cur = concat_result;
}
cb(cur, "encoder_out", -1);
// QFormer projector
@@ -231,7 +197,7 @@ ggml_cgraph * clip_graph_granite_speech::build() {
cur = ggml_pad(ctx0, cur, 0, padded_proj - n_frames, 0, 0);
}
ggml_tensor * enc_windows = ggml_reshape_3d(ctx0, cur, proj_input_dim, window_size, nblocks_proj);
ggml_tensor * enc_windows = ggml_reshape_3d(ctx0, cur, n_embd, window_size, nblocks_proj);
ggml_tensor * queries = build_norm(model.qf_proj_blocks[0].qf_proj_query,
model.qf_proj_blocks[0].qf_proj_norm_w, model.qf_proj_blocks[0].qf_proj_norm_b,
+2 -2
View File
@@ -304,14 +304,14 @@ ggml_cgraph * clip_graph_granite4_vision::build() {
}
// --- Stage 1b/1c: WindowQFormer blocks ---
const int projector_count = hparams.feature_layers.size();
const int projector_count = hparams.vision_feature_layer.size();
const float qformer_eps = 1e-12f;
ggml_tensor * mmproj = nullptr;
for (int bid = 0; bid < projector_count; ++bid) {
const auto & blk = model.qf_proj_blocks[bid];
int vlayer = hparams.feature_layers[bid];
int vlayer = hparams.vision_feature_layer[bid];
GGML_ASSERT(vlayer >= 0 && vlayer < n_layer);
ggml_tensor * h = layer_outs[vlayer];
+3 -3
View File
@@ -21,7 +21,7 @@ ggml_cgraph * clip_graph_llava::build() {
// If we set explicit vision feature layers, only go up to the deepest one
// NOTE: only used by granite-vision models for now
for (const auto & feature_layer : hparams.feature_layers) {
for (const auto & feature_layer : hparams.vision_feature_layer) {
if (feature_layer > deepest_feature_layer) {
deepest_feature_layer = feature_layer;
}
@@ -59,7 +59,7 @@ ggml_cgraph * clip_graph_llava::build() {
// If this is an embedding feature layer, save the output.
// NOTE: 0 index here refers to the input to the encoder.
if (hparams.is_feature_layer(il)) {
if (hparams.is_vision_feature_layer(il)) {
embedding_stack.push_back(cur);
}
@@ -134,7 +134,7 @@ ggml_cgraph * clip_graph_llava::build() {
// process vision feature layers (used by granite)
{
// final layer is a vision feature layer
if (hparams.is_feature_layer(max_feature_layer)) {
if (hparams.is_vision_feature_layer(max_feature_layer)) {
embedding_stack.push_back(inpL);
}
+88 -9
View File
@@ -518,14 +518,6 @@ size_t server_tokens::get_common_prefix(const server_tokens & b) const {
return max_idx; // all tokens are equal
}
common_chat_msg_spans server_tokens::find_message_spans(const common_chat_msg_delimiters & delims) const {
std::map<size_t, size_t> skips;
for (const auto & it : map_idx_to_media) {
skips[it.first] = mtmd_input_chunk_get_n_tokens(it.second.get());
}
return delims.split(tokens, skips);
}
bool server_tokens::validate(const struct llama_context * ctx) const {
const llama_model * model = llama_get_model(ctx);
const llama_vocab * vocab = llama_model_get_vocab(model);
@@ -1112,7 +1104,15 @@ json oaicompat_chat_params_parse(
llama_params["chat_parser"] = chat_params.parser;
}
llama_params["message_delimiters"] = chat_params.message_delimiters.to_json();
llama_params["message_spans"] = json::array();
for (const auto & span : chat_params.message_spans) {
llama_params["message_spans"].push_back({
{ "role", span.role },
{ "pos", span.pos },
{ "len", span.len },
});
}
// Reasoning budget: pass parameters through to sampling layer
{
@@ -1583,3 +1583,82 @@ server_tokens format_prompt_rerank(
return result;
}
//
// threadpool
//
server_threadpool::~server_threadpool() {
{
std::lock_guard<std::mutex> lock(mtx);
stop = true;
}
cv.notify_all();
for (auto & t : threads) t.join();
}
void server_threadpool::init(int n) {
// the caller (main thread) participates as a worker, so spawn n-1 threads
const int n_workers = std::max(1, n) - 1;
for (int i = 0; i < n_workers; i++) {
threads.emplace_back([this]() { run_worker(); });
}
}
void server_threadpool::run_worker() {
while (true) {
std::function<void()> task;
{
std::unique_lock<std::mutex> lock(mtx);
cv.wait(lock, [this]() { return stop || !tasks.empty(); });
if (stop && tasks.empty()) return;
task = std::move(tasks.front());
tasks.pop();
}
task();
{
std::lock_guard<std::mutex> lock(mtx);
pending--;
}
cv_done.notify_all();
}
}
void server_threadpool::enqueue(std::function<void()> fn) {
{
std::lock_guard<std::mutex> lock(mtx);
GGML_ASSERT(!stop);
tasks.push(std::move(fn));
pending++;
}
cv.notify_one();
}
void server_threadpool::wait_all() {
// the calling thread helps drain the queue until no tasks remain pending
while (true) {
std::function<void()> task;
{
std::lock_guard<std::mutex> lock(mtx);
if (pending == 0) {
return;
}
if (!tasks.empty()) {
task = std::move(tasks.front());
tasks.pop();
}
}
if (task) {
task();
{
std::lock_guard<std::mutex> lock(mtx);
pending--;
}
cv_done.notify_all();
} else {
// no task available right now, but some are still pending (being run by workers)
std::unique_lock<std::mutex> lock(mtx);
cv_done.wait(lock, [this]() { return pending == 0 || !tasks.empty(); });
}
}
}
+41 -3
View File
@@ -12,6 +12,11 @@
#include <string>
#include <vector>
#include <cinttypes>
#include <thread>
#include <mutex>
#include <condition_variable>
#include <queue>
#include <functional>
using json = nlohmann::ordered_json;
@@ -218,9 +223,6 @@ public:
size_t get_common_prefix(const server_tokens & b) const;
// split the tokens into message spans, skipping over media chunks
common_chat_msg_spans find_message_spans(const common_chat_msg_delimiters & delims) const;
// make sure all text tokens are within the vocab range
bool validate(const struct llama_context * ctx) const;
@@ -373,3 +375,39 @@ server_tokens format_prompt_rerank(
mtmd_context * mctx,
const std::string & query,
const std::string & doc);
//
// threadpool utils
// to be used for multi-threaded sampling
//
// the main thread participates as one of the pool's workers, so init(n)
// only spawns n-1 background threads (the caller is the nth)
struct server_threadpool {
std::vector<std::thread> threads;
std::queue<std::function<void()>> tasks;
std::mutex mtx;
std::condition_variable cv;
std::condition_variable cv_done;
int pending = 0;
bool stop = false;
~server_threadpool();
void init(int n);
template<typename T>
void run_all(std::vector<T> & tasks, std::function<void(T&)> handler) {
for (auto & item : tasks) {
enqueue([&handler, &item]() {
handler(item);
});
}
// the calling thread runs tasks too, until all are done
wait_all();
}
private:
void enqueue(std::function<void()> fn);
void wait_all();
void run_worker();
};
+143 -32
View File
@@ -89,9 +89,7 @@ struct server_batch {
}
~server_batch() {
if (batch.token != nullptr) {
llama_batch_free(batch);
}
llama_batch_free(batch);
}
void init(int32_t n_tokens_alloc) {
@@ -868,6 +866,9 @@ public:
// note: chat_params must not be refreshed upon existing sleeping state
server_chat_params chat_params;
// threadpool for parallel sampling
server_threadpool threadpool;
server_state_callback_t callback_state = [](server_state, json) -> void {};
server_context_impl() {
@@ -1217,10 +1218,6 @@ private:
cparams.ctx_other = ctx_tgt;
ctx_dft.reset(llama_init_from_model(model_dft.get(), cparams));
if (ctx_dft == nullptr) {
SRV_ERR("%s", "failed to create draft context\n");
return false;
}
params_base.speculative.draft.ctx_tgt = ctx_tgt;
params_base.speculative.draft.ctx_dft = ctx_dft.get();
@@ -1463,6 +1460,16 @@ private:
metrics.init();
// initialize threadpool
{
int threadpool_size = params_base.sampling_n_threads;
if (threadpool_size <= 0) {
threadpool_size = params_base.cpuparams.n_threads;
}
SRV_DBG("%s: initializing threadpool, size = %d\n", __func__, threadpool_size);
threadpool.init(threadpool_size);
}
if (params_base.cache_idle_slots) {
if (params_base.cache_ram_mib == 0) {
SRV_WRN("%s", "--cache-idle-slots requires --cache-ram, disabling\n");
@@ -3442,8 +3449,8 @@ private:
has_mtmd = true;
}
const auto & spans = slot.task->params.message_spans;
const auto last_user_pos = spans.last_user_message_pos();
const int32_t n_before_user = slot.task->params.n_before_user;
const bool n_before_user_known = n_before_user > 0;
// add prompt tokens for processing in the current batch
while (slot.prompt.n_tokens() < slot.task->n_tokens() && batch.size() < n_batch) {
@@ -3472,8 +3479,10 @@ private:
slot.n_prompt_tokens_processed++;
// stop the prompt batch exactly before a user message
if (spans.is_user_start(slot.prompt.n_tokens())) {
// stop the prompt batch exactly before the latest user input, so a checkpoint
// can be created after the previous messages
if (n_before_user_known &&
slot.prompt.n_tokens() == n_before_user) {
break;
}
@@ -3502,13 +3511,8 @@ private:
// the number of tokens added to the batch for the current slot
const auto n_tokens_cur = batch.size() - n_tokens_prev;
const auto n_tokens_start = slot.prompt.n_tokens() - n_tokens_cur;
const bool near_prompt_end = slot.task->n_tokens() < slot.prompt.n_tokens() + n_ubatch;
const bool is_user_start = spans.is_user_start(n_tokens_start);
const bool is_last_user_message = n_tokens_start == last_user_pos;
// entire prompt has been processed
if (slot.prompt.n_tokens() == slot.task->n_tokens()) {
slot.state = SLOT_STATE_DONE_PROMPT;
@@ -3523,9 +3527,8 @@ private:
slot.init_sampler();
} else {
// skip ordinary mid-prompt checkpoints, unless the batch starts a user
// message or we are near the end of the prompt
if (!is_user_start && !near_prompt_end) {
// skip ordinary mid-prompt checkpoints
if (!n_before_user_known && !near_prompt_end) {
do_checkpoint = false;
}
}
@@ -3533,6 +3536,29 @@ private:
const auto pos_min = llama_memory_seq_pos_min(llama_get_memory(ctx_tgt), slot.id);
const auto pos_max = llama_memory_seq_pos_max(llama_get_memory(ctx_tgt), slot.id);
// checkpoints are created before the current batch is decoded, so
// their token position is the batch start rather than the prompt end
const int32_t n_tokens_start = slot.prompt.n_tokens() - n_tokens_cur;
{
const bool is_on_user =
n_before_user_known &&
n_tokens_start == n_before_user;
const bool is_after_user =
n_before_user_known &&
n_tokens_start > n_before_user;
const bool is_allowed =
!n_before_user_known ||
is_on_user ||
(is_after_user && near_prompt_end);
if (do_checkpoint && !is_allowed) {
do_checkpoint = false;
}
}
// nothing to checkpoint yet
// TODO: is this check needed?
if (do_checkpoint && pos_min < 0) {
@@ -3542,8 +3568,8 @@ private:
// do not checkpoint after mtmd chunks
do_checkpoint = do_checkpoint && !has_mtmd;
// no need to create checkpoints that are too close together, unless it's the last user message
do_checkpoint = do_checkpoint && (slot.prompt.checkpoints.empty() || is_last_user_message || n_tokens_start > slot.prompt.checkpoints.back().n_tokens + params_base.checkpoint_min_step);
// no need to create checkpoints that are too close together
do_checkpoint = do_checkpoint && (slot.prompt.checkpoints.empty() || n_tokens_start > slot.prompt.checkpoints.back().n_tokens + params_base.checkpoint_min_step);
SLT_DBG(slot, "main/do_checkpoint = %s, pos_min = %d, pos_max = %d\n", do_checkpoint ? "yes" : "no", pos_min, pos_max);
// note: we create the checkpoint before calling llama_decode(), so the current batch is not
@@ -3686,6 +3712,12 @@ private:
return true;
}
struct sampling_task {
server_slot * slot = nullptr;
int32_t tok_idx = 0;
llama_token sampled_id = LLAMA_TOKEN_NULL; // result
};
void post_decode(int32_t n_batch_tokens, int32_t off, llama_batch & batch_view) {
// for checking if a given batch index is inside batch_view
auto is_inside_view = [&](int32_t idx) {
@@ -3707,7 +3739,13 @@ private:
slot.task->params.sampling.preserved_tokens.find(token) != slot.task->params.sampling.preserved_tokens.end();
};
std::vector<sampling_task> smpl_tasks;
smpl_tasks.resize(slots.size());
bool need_sampling = false;
iterate(slots, [&](server_slot & slot) {
auto & smpl_task = smpl_tasks[slot.id];
// optionally send prompt processing progress
if (slot.state == SLOT_STATE_PROCESSING_PROMPT || slot.state == SLOT_STATE_DONE_PROMPT) {
if (slot.task->params.stream && slot.task->params.return_progress) {
@@ -3752,15 +3790,35 @@ private:
return; // sample using speculative decoding
}
// shifted according to the current sub-batch
const int tok_idx = slot.i_batch - off;
// otherwise, we must sample the next token
// also shift batch idx according to the current sub-batch
smpl_task.slot = &slot;
smpl_task.tok_idx = slot.i_batch - off;
need_sampling = true;
});
llama_token id;
{
scoped_timer timer(t_sampl, n_sampl);
id = common_sampler_sample(slot.smpl.get(), slot.ctx_tgt, tok_idx);
// run multiple sampling tasks in parallel
GGML_ASSERT(smpl_tasks.size() == slots.size());
if (need_sampling) {
llama_synchronize(ctx_tgt);
threadpool.run_all<sampling_task>(smpl_tasks, [](sampling_task & task) {
if (task.slot) {
task.sampled_id = common_sampler_sample(task.slot->smpl.get(),
task.slot->ctx_tgt, task.tok_idx);
}
});
}
iterate(slots, [&](server_slot & slot) {
auto & smpl_task = smpl_tasks[slot.id];
if (!smpl_task.slot) {
return;
}
auto tok_idx = smpl_task.tok_idx;
auto id = smpl_task.sampled_id;
slot.i_batch = -1;
common_sampler_accept(slot.smpl.get(), id, true);
@@ -4042,6 +4100,54 @@ void server_context::set_state_callback(server_state_callback_t callback) {
});
}
// compute the number of tokens before the last user message in the prompt
static int32_t prompt_get_n_before_user(
const json & message_spans,
const std::string & prompt,
const std::vector<raw_buffer> & files,
const llama_vocab * vocab,
mtmd_context * mctx) {
int32_t result = -1;
int32_t byte_pos = -1;
for (const auto & span : message_spans) {
const std::string role = json_value(span, "role", std::string());
if (role == "user") {
byte_pos = json_value(span, "pos", -1);
}
}
if (byte_pos >= 0) {
GGML_ASSERT((size_t) byte_pos <= prompt.size());
const std::string prefix = prompt.substr(0, (size_t) byte_pos);
const std::string marker = get_media_marker();
size_t n_prefix_media = 0;
for (size_t pos = 0; (pos = prefix.find(marker, pos)) != std::string::npos; pos += marker.size()) {
n_prefix_media++;
}
GGML_ASSERT(n_prefix_media <= files.size());
if (mctx != nullptr && n_prefix_media > 0) {
// TODO: this makes a copy - avoid it
std::vector<raw_buffer> prefix_files(files.begin(), files.begin() + n_prefix_media);
result = (int32_t) process_mtmd_prompt(mctx, prefix, prefix_files).size();
} else {
result = (int32_t) tokenize_input_prompts(vocab, nullptr, prefix, true, true)[0].size();
}
SRV_TRC("message_spans: last user message: byte_pos=%d, media=%zu, n_before_user=%d\n",
byte_pos, n_prefix_media, result);
}
return result;
}
//
// server_routes
//
@@ -4089,10 +4195,6 @@ std::unique_ptr<server_res_generator> server_routes::handle_completions_impl(
// tasks.reserve(inputs.size()); // TODO: this is inaccurate due to child tasks
// message delimiters for checkpointing
auto delimiters = common_chat_msg_delimiters_parse(json_value(data, "message_delimiters", json::array()));
delimiters.tokenize(ctx_server.vocab);
for (size_t i = 0; i < inputs.size(); i++) {
server_task task = server_task(type);
@@ -4106,7 +4208,16 @@ std::unique_ptr<server_res_generator> server_routes::handle_completions_impl(
meta->logit_bias_eog,
data);
task.params.message_spans = task.tokens.find_message_spans(delimiters);
const auto message_spans = json_value(data, "message_spans", json::array());
if (prompt.is_string() && message_spans.is_array()) {
task.params.n_before_user =
prompt_get_n_before_user(
message_spans,
prompt.get<std::string>(),
files,
ctx_server.vocab,
ctx_server.mctx);
}
task.id_slot = json_value(data, "id_slot", -1);
+2 -4
View File
@@ -224,7 +224,7 @@ void server_model_meta::update_caps() {
});
params.offline = true;
// params.skip_download = true; // TODO: ideally, we should validate the model here, but it takes too much time
common_params_handle_models(params, LLAMA_EXAMPLE_SERVER, {});
common_params_handle_models(params, LLAMA_EXAMPLE_SERVER);
if (params.mmproj.path.empty()) {
multimodal = { false, false };
} else {
@@ -1393,9 +1393,7 @@ struct server_download_state : public common_download_callback {
bool run(common_params & params) {
try {
common_params_handle_models_params p;
p.callback = this;
common_params_handle_models(params, LLAMA_EXAMPLE_SERVER, p);
common_params_handle_models(params, LLAMA_EXAMPLE_SERVER, this);
is_ok = true;
} catch (const std::exception & e) {
auto model_name = params.model.get_name();
+3 -6
View File
@@ -591,11 +591,10 @@ json server_task_result_cmpl_final::to_json_oaicompat_resp() {
for (const common_chat_tool_call & tool_call : oaicompat_msg.tool_calls) {
output.push_back(json {
{"id", "fc_" + tool_call.id},
{"type", "function_call"},
{"status", "completed"},
{"arguments", tool_call.arguments},
{"call_id", "call_" + tool_call.id},
{"call_id", "fc_" + tool_call.id},
{"name", tool_call.name},
});
}
@@ -691,11 +690,10 @@ json server_task_result_cmpl_final::to_json_oaicompat_resp_stream() {
for (const common_chat_tool_call & tool_call : oaicompat_msg.tool_calls) {
const json output_item = {
{"id", "fc_" + tool_call.id},
{"type", "function_call"},
{"status", "completed"},
{"arguments", tool_call.arguments},
{"call_id", "call_" + tool_call.id},
{"call_id", "fc_" + tool_call.id},
{"name", tool_call.name}
};
server_sent_events.push_back(json {
@@ -1279,9 +1277,8 @@ json server_task_result_cmpl_partial::to_json_oaicompat_resp() {
{"data", json {
{"type", "response.output_item.added"},
{"item", json {
{"id", "fc_" + diff.tool_call_delta.id},
{"arguments", ""},
{"call_id", "call_" + diff.tool_call_delta.id},
{"call_id", "fc_" + diff.tool_call_delta.id},
{"name", diff.tool_call_delta.name},
{"type", "function_call"},
{"status", "in_progress"},
+3 -3
View File
@@ -62,6 +62,9 @@ struct task_params {
int32_t n_cache_reuse = 0; // min chunk size to attempt reusing from the cache via KV shifting (0 = disabled)
// number of prompt tokens before the latest user message
int32_t n_before_user = -1;
int64_t t_max_prompt_ms = -1; // TODO: implement
int64_t t_max_predict_ms = -1; // if positive, limit the generation phase to this time limit
@@ -89,9 +92,6 @@ struct task_params {
// per-request parameters for chat parsing
common_chat_parser_params chat_parser_params;
// message spans for checkpointing
common_chat_msg_spans message_spans;
// Embeddings
int32_t embd_normalize = 2; // (-1=none, 0=max absolute int16, 1=taxicab, 2=Euclidean/L2, >2=p-norm)
+1 -12
View File
@@ -89,17 +89,6 @@ int llama_server(int argc, char ** argv) {
llama_backend_init();
llama_numa_init(params.numa);
// note: router mode also accepts -hf remote-preset, so we need to check that first
if (!params.model.hf_repo.empty()) {
try {
common_params_handle_models_params handle_params;
handle_params.preset_only = true;
common_params_handle_models(params, LLAMA_EXAMPLE_SERVER, handle_params);
} catch (const std::exception & e) {
// ignored for now
}
}
// router server never loads a model and must not touch the GPU
const bool is_router_server = params.model.path.empty()
&& params.model.hf_repo.empty();
@@ -274,7 +263,7 @@ int llama_server(int argc, char ** argv) {
return child.run_download(params);
} else if (!is_router_server) {
// single-model mode (NOT spawned by router)
common_params_handle_models(params, LLAMA_EXAMPLE_SERVER, {});
common_params_handle_models(params, LLAMA_EXAMPLE_SERVER);
}
//
-19
View File
@@ -256,25 +256,6 @@ def test_router_reload_models():
os.remove(preset_path)
def test_router_remote_preset():
global server
server.model_hf_repo = "ggml-org/test-preset-ci"
server.model_hf_file = None
server.offline = False
server.start()
# Should see preset models in GET /models
res = server.make_request("GET", "/models")
assert res.status_code == 200
ids = {item["id"] for item in res.body.get("data", [])}
assert "tinygemma3-preset" in ids
assert "stories260K-test" in ids
# Should be able to load a preset model
model_id = "tinygemma3-preset"
_load_model_and_wait(model_id)
MODEL_DOWNLOAD_ID = "ggml-org/test-model-router-download:F16"
MODEL_DOWNLOAD_TIMEOUT = 30
+2 -1
View File
@@ -28,9 +28,10 @@ vite.config.ts.timestamp-*
# PWA Artifacts
apple-splash-*.png
apple-touch-icon-*.png
favicon.ico
favicon-dark.ico
maskable-icon-*.png
pwa-*.png
static/favicon*
# Storybook
*storybook.log
+7 -7
View File
@@ -35,7 +35,7 @@
"bits-ui": "2.18.1",
"clsx": "2.1.1",
"dexie": "4.4.3",
"dompurify": "3.4.11",
"dompurify": "3.4.5",
"eslint": "9.39.4",
"eslint-config-prettier": "10.1.8",
"eslint-plugin-storybook": "10.4.2",
@@ -8653,9 +8653,9 @@
"peer": true
},
"node_modules/dompurify": {
"version": "3.4.11",
"resolved": "https://registry.npmjs.org/dompurify/-/dompurify-3.4.11.tgz",
"integrity": "sha512-zhlUV12GsaRzMsf9q5M254YhA4+VuF0fG+QFqu6aYpoGlKtz+w8//jBcGVYBgQkR5GHjUomejY84AV+/uPbWdw==",
"version": "3.4.5",
"resolved": "https://registry.npmjs.org/dompurify/-/dompurify-3.4.5.tgz",
"integrity": "sha512-OrwIBKsdNSVEeubdJ1HBv/wNENRM9ytAVCv7YXt//A3vPdVMNuACRqK9mXCGCBW2ln7BT/A4X0jXHo2Gu89miA==",
"dev": true,
"license": "(MPL-2.0 OR Apache-2.0)",
"optionalDependencies": {
@@ -10226,9 +10226,9 @@
}
},
"node_modules/hono": {
"version": "4.12.26",
"resolved": "https://registry.npmjs.org/hono/-/hono-4.12.26.tgz",
"integrity": "sha512-uyZtpnYxM9CmQ7QsQknM4zN8EftNqhON1qYeIKM0Se67CCEe2c44xyGURwB0axX2fBDu1dqHrHAc1hmNT8ITkw==",
"version": "4.12.23",
"resolved": "https://registry.npmjs.org/hono/-/hono-4.12.23.tgz",
"integrity": "sha512-eIaZ9qDgu7XV0pxOCrg7/WhnQ6Ivm22UcxhXx/A3dcbqbbYgBEkc6e/J/s7j2tS96zoB0S9VBdLwQNCWwUo4LA==",
"dev": true,
"license": "MIT",
"engines": {
+1 -1
View File
@@ -54,7 +54,7 @@
"bits-ui": "2.18.1",
"clsx": "2.1.1",
"dexie": "4.4.3",
"dompurify": "3.4.11",
"dompurify": "3.4.5",
"eslint": "9.39.4",
"eslint-config-prettier": "10.1.8",
"eslint-plugin-storybook": "10.4.2",
+1 -8
View File
@@ -1,10 +1,4 @@
import { defineConfig } from '@vite-pwa/assets-generator/config';
import { FAVICON_COLORS, PWA_ASSET_GENERATOR } from './src/lib/constants/pwa';
import { writeThemeFavicons } from './scripts/favicon-colorize';
writeThemeFavicons(FAVICON_COLORS.LIGHT, FAVICON_COLORS.DARK, {
padding: PWA_ASSET_GENERATOR.FAVICON_PADDING
});
export default defineConfig({
headLinkOptions: {
@@ -13,8 +7,7 @@ export default defineConfig({
preset: {
transparent: {
sizes: [],
favicons: [[48, 'favicon-dark.ico']],
padding: PWA_ASSET_GENERATOR.FAVICON_PADDING
favicons: [[48, 'favicon-dark.ico']]
},
maskable: {
sizes: []
+2 -19
View File
@@ -5,32 +5,15 @@ import {
} from '@vite-pwa/assets-generator/config';
import { readFileSync } from 'node:fs';
import { resolve } from 'node:path';
import {
THEME_COLORS,
PWA_GENERATOR_DEVICES,
PWA_ASSET_GENERATOR,
FAVICON_COLORS
} from './src/lib/constants/pwa';
import { THEME_COLORS, PWA_GENERATOR_DEVICES, PWA_ASSET_GENERATOR } from './src/lib/constants/pwa';
import { SplashOrientation } from './src/lib/enums/splash.enums';
import { writeThemeFavicons } from './scripts/favicon-colorize';
writeThemeFavicons(FAVICON_COLORS.LIGHT, FAVICON_COLORS.DARK, {
padding: PWA_ASSET_GENERATOR.FAVICON_PADDING
});
export default defineConfig({
headLinkOptions: {
preset: PWA_ASSET_GENERATOR.LINK_PRESET
},
preset: combinePresetAndAppleSplashScreens(
{
...minimal2023Preset,
// tiny margin so favicon.ico / pwa-*.png breathe inside the canvas
transparent: {
...minimal2023Preset.transparent,
padding: PWA_ASSET_GENERATOR.FAVICON_PADDING
}
},
minimal2023Preset,
{
padding: PWA_ASSET_GENERATOR.SPLASH_PADDING,
resizeOptions: {
-107
View File
@@ -1,107 +0,0 @@
import { mkdirSync, readFileSync, writeFileSync } from 'node:fs';
import { dirname, resolve } from 'node:path';
import { fileURLToPath } from 'node:url';
const HERE = dirname(fileURLToPath(import.meta.url));
const PROJECT_ROOT = resolve(HERE, '..');
const DEFAULT_LOGO = resolve(PROJECT_ROOT, 'src/lib/assets/logo.svg');
const DEFAULT_OUT_DIR = resolve(PROJECT_ROOT, 'static');
const DEFAULT_OUT_LIGHT = resolve(DEFAULT_OUT_DIR, 'favicon.svg');
const DEFAULT_OUT_DARK = resolve(DEFAULT_OUT_DIR, 'favicon-dark.svg');
const CURRENT_COLOR = 'currentColor';
export interface ColorizedFavicon {
light: string;
dark: string;
}
export interface WriteThemeFaviconsOptions {
sourcePath?: string;
lightOutPath?: string;
darkOutPath?: string;
/**
* Fraction of the icon (0..1) to leave as an even margin on each side.
* Applied by wrapping the inner content in a `<g transform="...">` so the
* source `src/lib/assets/logo.svg` is not modified. Pass 0 to disable.
*/
padding?: number;
}
/**
* Replace every `currentColor` occurrence in the SVG with the given color.
* Pure: no filesystem access, so it is straightforward to unit-test.
*/
export function colorizeFaviconSvg(
svg: string,
lightColor: string,
darkColor: string
): ColorizedFavicon {
return {
light: svg.replaceAll(CURRENT_COLOR, lightColor),
dark: svg.replaceAll(CURRENT_COLOR, darkColor)
};
}
/**
* Shrink the inner SVG content uniformly and re-center it so `padding` (a
* 0..1 fraction) is reserved as equal margin on each side. Returns the input
* unchanged for non-positive padding, missing/invalid `viewBox`, or unexpected
* markup so the caller always gets a renderable SVG.
*/
export function padFaviconSvg(svg: string, padding: number): string {
if (!(padding > 0) || padding >= 1) return svg;
const viewBoxMatch = svg.match(/viewBox\s*=\s*["']([^"']+)["']/i);
if (!viewBoxMatch) return svg;
const parts = viewBoxMatch[1]
.trim()
.split(/[\s,]+/)
.map(Number);
if (parts.length !== 4 || parts.some((n) => !Number.isFinite(n))) return svg;
const [, , width, height] = parts;
if (width <= 0 || height <= 0) return svg;
const scale = 1 - padding;
const translateX = (padding * width) / 2;
const translateY = (padding * height) / 2;
const openTagStart = svg.search(/<svg\b/i);
if (openTagStart === -1) return svg;
const openTagEnd = svg.indexOf('>', openTagStart);
if (openTagEnd === -1) return svg;
const closeStart = svg.lastIndexOf('</svg');
if (closeStart === -1 || closeStart <= openTagEnd) return svg;
const openTag = svg.slice(0, openTagEnd + 1);
const inner = svg.slice(openTagEnd + 1, closeStart);
const closeTag = svg.slice(closeStart);
const group = `<g transform="translate(${translateX} ${translateY}) scale(${scale})">`;
return `${openTag}${group}${inner}</g>${closeTag}`;
}
/**
* Read `src/lib/assets/logo.svg`, colorize it for both themes, and write
* the results to the static directory so the PWA asset generator can consume
* them. Paths can be overridden for tests.
*/
export function writeThemeFavicons(
lightColor: string,
darkColor: string,
{
sourcePath = DEFAULT_LOGO,
lightOutPath = DEFAULT_OUT_LIGHT,
darkOutPath = DEFAULT_OUT_DARK,
padding = 0
}: WriteThemeFaviconsOptions = {}
): void {
const source = readFileSync(sourcePath, 'utf-8');
const { light, dark } = colorizeFaviconSvg(source, lightColor, darkColor);
mkdirSync(dirname(lightOutPath), { recursive: true });
writeFileSync(lightOutPath, padFaviconSvg(light, padding));
writeFileSync(darkOutPath, padFaviconSvg(dark, padding));
}
+1 -6
View File
@@ -48,7 +48,6 @@
--chat-form-area-height: 8rem;
--chat-form-area-offset: 2rem;
--chat-form-padding-top: 6rem;
--max-message-height: max(24rem, min(80dvh, calc(100dvh - var(--chat-form-area-height) - 12rem)));
}
@@ -56,7 +55,6 @@
:root {
--chat-form-area-height: 24rem;
--chat-form-area-offset: 12rem;
--chat-form-padding-top: 6rem;
}
}
@@ -143,6 +141,7 @@
@apply bg-background text-foreground;
scrollbar-width: thin;
scrollbar-gutter: stable;
overflow: hidden; /* Added due to Mermaid rendering somehow causing the double scrollbar */
}
/* Global scrollbar styling - visible only on hover */
@@ -194,7 +193,3 @@
scrollbar-width: none;
}
}
.mermaidTooltip {
display: none !important;
}
@@ -10,9 +10,9 @@ import { isElementInViewport } from '$lib/utils/viewport';
*/
export function fadeInView(
node: HTMLElement,
options: { duration?: number; y?: number; delay?: number; skipIfVisible?: boolean } = {}
options: { duration?: number; y?: number; skipIfVisible?: boolean } = {}
) {
const { duration = 300, y = 0, delay = 0, skipIfVisible = false } = options;
const { duration = 300, y = 0, skipIfVisible = false } = options;
if (skipIfVisible && isElementInViewport(node)) {
return;
@@ -27,12 +27,10 @@ export function fadeInView(
(entries) => {
for (const entry of entries) {
if (entry.isIntersecting) {
setTimeout(() => {
requestAnimationFrame(() => {
node.style.opacity = '1';
node.style.transform = 'translateY(0)';
});
}, delay);
requestAnimationFrame(() => {
node.style.opacity = '1';
node.style.transform = 'translateY(0)';
});
observer.disconnect();
}
}
-7
View File
@@ -1,7 +0,0 @@
<svg width="512" height="512" viewBox="0 0 512 512" fill="none" xmlns="http://www.w3.org/2000/svg">
<path d="M244.95 8C215.233 8 187.774 23.8591 172.923 49.5999L95.6009 183.625C60.2162 244.959 104.481 321.6 175.29 321.6H208L316.977 132.708C348.959 77.2719 308.95 8 244.95 8ZM208 321.6H351.947C415.982 321.6 456.013 390.91 424.013 446.377C409.155 472.132 381.681 488 351.947 488H271.29C200.481 488 156.216 411.359 191.601 350.026L208 321.6Z" fill="currentColor"/>
<path d="M208 321.6H16L106.462 164.8L208 321.6Z" fill="currentColor"/>
<path d="M388.923 8L208 321.6L253.6 8H388.923Z" fill="currentColor"/>
<path d="M304 488H112L202.462 331.2L304 488Z" fill="currentColor"/>
<path d="M496 321.6H208L419.399 454.4L496 321.6Z" fill="currentColor"/>
</svg>

Before

Width:  |  Height:  |  Size: 771 B

@@ -8,13 +8,12 @@
ariaLabel?: string;
class?: string;
disabled?: boolean;
href?: string;
icon: Component;
iconSize?: string;
onclick?: (e?: MouseEvent) => void;
onclick: (e?: MouseEvent) => void;
size?: ButtonSize;
stopPropagationOnClick?: boolean;
tooltip?: string;
tooltip: string;
variant?: ButtonVariant;
tooltipSide?: TooltipSide;
}
@@ -23,7 +22,6 @@
icon,
tooltip,
variant = 'ghost',
href = '',
size = 'sm',
class: className = '',
disabled = false,
@@ -33,49 +31,34 @@
onclick,
ariaLabel
}: Props = $props();
let innerWidth = $state(0);
const showTooltip = $derived(!!tooltip && innerWidth > 768);
</script>
{#snippet button(props = {})}
<Button
{...props}
{href}
{variant}
{size}
{disabled}
onclick={(e: MouseEvent) => {
if (stopPropagationOnClick) e.stopPropagation();
<Tooltip.Root>
<Tooltip.Trigger>
<!-- prevent another nested button element -->
{#snippet child({ props })}
<Button
{...props}
{variant}
{size}
{disabled}
onclick={(e: MouseEvent) => {
if (stopPropagationOnClick) e.stopPropagation();
onclick?.(e);
}}
class="h-6 w-6 p-0 {className} flex hover:bg-transparent data-[state=open]:bg-transparent!"
aria-label={ariaLabel || tooltip}
>
{#if icon}
{@const IconComponent = icon}
onclick?.(e);
}}
class="h-6 w-6 p-0 {className} flex hover:bg-transparent data-[state=open]:bg-transparent!"
aria-label={ariaLabel || tooltip}
>
{#if icon}
{@const IconComponent = icon}
<IconComponent class={iconSize} />
{/if}
</Button>
{/snippet}
</Tooltip.Trigger>
<IconComponent class={iconSize} />
{/if}
</Button>
{/snippet}
{#if showTooltip}
<Tooltip.Root>
<Tooltip.Trigger>
<!-- prevent another nested button element -->
{#snippet child({ props })}
{@render button(props)}
{/snippet}
</Tooltip.Trigger>
<Tooltip.Content side={tooltipSide}>
<p>{tooltip}</p>
</Tooltip.Content>
</Tooltip.Root>
{:else}
{@render button({ href })}
{/if}
<svelte:window bind:innerWidth />
<Tooltip.Content side={tooltipSide}>
<p>{tooltip}</p>
</Tooltip.Content>
</Tooltip.Root>
@@ -494,7 +494,7 @@
/>
<div
class="{INPUT_CLASSES} overflow-hidden rounded-4xl md:rounded-3xl backdrop-blur-md {disabled
class="{INPUT_CLASSES} overflow-hidden rounded-3xl backdrop-blur-md {disabled
? 'cursor-not-allowed opacity-60'
: ''}"
data-slot="input-area"
@@ -510,7 +510,7 @@
/>
<div
class="flex-column relative min-h-12 items-center rounded-4xl md:rounded-3xl py-2 pb-2.25 shadow-sm transition-all focus-within:shadow-md md:py-3!"
class="flex-column relative min-h-[48px] items-center rounded-3xl py-2 pb-2.25 shadow-sm transition-all focus-within:shadow-md md:!py-3"
onpaste={handlePaste}
>
<ChatFormTextarea
@@ -15,7 +15,7 @@
<Tooltip.Root>
<Tooltip.Trigger class="w-full">
<Button
class="file-upload-button md:h-8 md:w-8 h-9 w-9 rounded-full p-0"
class="file-upload-button h-8 w-8 rounded-full p-0"
{disabled}
{onclick}
variant="secondary"
@@ -15,7 +15,6 @@
import { McpLogo } from '$lib/components/app';
import { PencilRuler, ChevronDown, ChevronRight } from '@lucide/svelte';
import { HealthCheckStatus } from '$lib/enums';
import { AttachmentAction } from '$lib/enums/attachment.enums';
interface Props {
class?: string;
@@ -271,22 +270,14 @@
</Collapsible.Root>
{/if}
<button
type="button"
class={sheetItemClass}
onclick={() => attachmentMenu.callbacks[AttachmentAction.SYSTEM_PROMPT_CLICK]()}
>
<button type="button" class={sheetItemClass} onclick={onSystemPromptClick}>
<MessageSquare class="h-4 w-4 shrink-0" />
<span>System Message</span>
</button>
{#if hasMcpPromptsSupport}
<button
type="button"
class={sheetItemClass}
onclick={() => attachmentMenu.callbacks[AttachmentAction.MCP_PROMPT_CLICK]()}
>
<button type="button" class={sheetItemClass} onclick={onMcpPromptClick}>
<Zap class="h-4 w-4 shrink-0" />
<span>MCP Prompt</span>
@@ -294,11 +285,7 @@
{/if}
{#if hasMcpResourcesSupport}
<button
type="button"
class={sheetItemClass}
onclick={() => attachmentMenu.callbacks[AttachmentAction.MCP_RESOURCES_CLICK]()}
>
<button type="button" class={sheetItemClass} onclick={onMcpResourcesClick}>
<FolderOpen class="h-4 w-4 shrink-0" />
<span>MCP Resources</span>
@@ -42,7 +42,6 @@
{hasMcpPromptsSupport}
{hasMcpResourcesSupport}
{onFileUpload}
{onSystemPromptClick}
{onMcpPromptClick}
{onMcpResourcesClick}
>
@@ -20,7 +20,7 @@
type="submit"
disabled={isDisabled}
class={[
'md:h-8 md:w-8 h-9 w-9 rounded-full p-0',
'h-8 w-8 rounded-full p-0',
showErrorState &&
'bg-red-400/10 text-red-400 hover:bg-red-400/20 hover:text-red-400 disabled:opacity-100'
]}
@@ -1,5 +1,4 @@
<script lang="ts">
import { isMobile } from '$lib/stores/viewport.svelte';
import { autoResizeTextarea } from '$lib/utils';
import { onMount } from 'svelte';
@@ -38,9 +37,7 @@
}
export function focus() {
if (isMobile.current) return;
textareaElement?.focus({ preventScroll: true });
textareaElement?.focus();
}
export function resetHeight() {
@@ -231,7 +231,7 @@
editedContent = message.content;
}
textareaElement?.focus({ preventScroll: true });
textareaElement?.focus();
editedExtras = message.extra ? [...message.extra] : [];
editedUploadedFiles = [];
@@ -324,7 +324,7 @@
}
</script>
<div use:fadeInView class="chat-message">
<div use:fadeInView>
{#if message.role === MessageRole.SYSTEM}
<ChatMessageSystem
bind:textareaElement
@@ -180,9 +180,6 @@
let displayedModel = $derived(message.model ?? null);
// model being switched to while it loads, so the selector bar tracks it
let pendingModel = $state<string | null>(null);
let isCurrentlyLoading = $derived(isLoading());
let isStreaming = $derived(isChatStreaming());
let hasNoContent = $derived(!message?.content?.trim());
@@ -210,42 +207,6 @@
isLastAssistantMessage
);
let assistantEl: HTMLDivElement | undefined = $state();
let lastUserMessageHeight = $state(0);
let assistantMarginTop = $state(0);
$effect(() => {
if (!assistantEl) return;
assistantMarginTop = Math.round(parseFloat(getComputedStyle(assistantEl).marginTop));
const chatMessageEl = assistantEl.closest('.chat-message');
const previousChatMessage = chatMessageEl?.previousElementSibling;
const userMessageEl = previousChatMessage?.querySelector(
'.chat-message-user'
) as HTMLElement | null;
if (!userMessageEl) {
lastUserMessageHeight = 0;
return;
}
const updateHeight = () => {
const rect = userMessageEl.getBoundingClientRect();
const marginTop = Math.round(parseFloat(getComputedStyle(userMessageEl).marginTop));
lastUserMessageHeight = Math.round(rect.height + marginTop);
};
updateHeight();
const resizeObserver = new ResizeObserver(updateHeight);
resizeObserver.observe(userMessageEl);
return () => {
resizeObserver.disconnect();
};
});
function handleCopyModel() {
void copyToClipboard(displayedModel ?? '');
}
@@ -258,17 +219,12 @@
</script>
<div
bind:this={assistantEl}
class="chat-message-assistant text-md group w-full leading-7.5 {className}"
style:--last-user-message-height={lastUserMessageHeight > 0
? `${lastUserMessageHeight}px`
: undefined}
style:--assistant-margin-top={assistantMarginTop > 0 ? `${assistantMarginTop}px` : undefined}
class="text-md group w-full leading-7.5 {className}"
role="group"
aria-label="Assistant message with actions"
>
{#if showProcessingInfoTop}
<div class="mt-6 w-full max-w-3xl" in:fade>
<div class="mt-6 w-full max-w-[48rem]" in:fade>
<div class="processing-container">
<span class="processing-text">
{modelLoadingText ??
@@ -301,7 +257,7 @@
{/if}
{#if showProcessingInfoBottom}
<div class="mt-4 w-full max-w-3xl" in:fade>
<div class="mt-4 w-full max-w-[48rem]" in:fade>
<div class="processing-container">
<span class="processing-text">
{modelLoadingText ??
@@ -321,19 +277,13 @@
>
{#if isRouter}
<ModelsSelectorDropdown
currentModel={pendingModel ?? displayedModel}
currentModel={displayedModel}
disabled={isLoading()}
onModelChange={async (modelId: string, modelName: string) => {
const status = modelsStore.getModelStatus(modelId);
if (status !== ServerModelStatus.LOADED) {
pendingModel = modelId;
try {
await modelsStore.loadModel(modelId);
} finally {
pendingModel = null;
}
await modelsStore.loadModel(modelId);
}
onRegenerate(modelName);
@@ -401,23 +351,6 @@
</div>
<style>
:global(.chat-message):last-child .chat-message-assistant {
--assistant-min-height-offset: calc(
var(--last-user-message-height, 19rem) + var(--chat-form-height, 6rem) +
var(--chat-form-bottom-position, 0.5rem) + var(--chat-form-padding-top, 6rem) +
var(--assistant-margin-top, 3rem)
);
min-height: calc(100dvh - var(--assistant-min-height-offset));
@media (width > 768px) {
--assistant-min-height-offset: calc(
var(--last-user-message-height, 18rem) + var(--chat-form-height, 6rem) +
var(--chat-form-bottom-position, 1rem) + var(--chat-form-padding-top, 6rem) +
var(--assistant-margin-top, 3rem)
);
}
}
.processing-container {
display: flex;
flex-direction: column;
@@ -48,7 +48,7 @@
<div
aria-label="User message with actions"
class="chat-message-user group flex flex-col items-end gap-3 md:gap-2 {className}"
class="group flex flex-col items-end gap-3 md:gap-2 {className}"
role="group"
>
{#if editCtx.isEditing}
@@ -19,7 +19,7 @@
renderMarkdown = false,
textColorClass = 'text-foreground',
cardBgClass = 'dark:bg-primary/15',
maxHeightStyle = ''
maxHeightStyle = 'max-height: var(--max-message-height);'
}: Props = $props();
let isMultiline = $state(false);
@@ -59,7 +59,7 @@
{#if content.trim()}
<Card
class="chat-message-user-bubble max-w-[80%] overflow-y-auto rounded-[1.125rem] border-none bg-primary/5 px-3.75 py-1.5 {textColorClass} backdrop-blur-md data-multiline:py-2.5 {cardBgClass}"
class="max-w-[80%] overflow-y-auto rounded-[1.125rem] border-none bg-primary/5 px-3.75 py-1.5 {textColorClass} backdrop-blur-md data-[multiline]:py-2.5 {cardBgClass}"
data-multiline={isMultiline ? '' : undefined}
style="{maxHeightStyle} overflow-wrap: anywhere; word-break: break-word;"
>
@@ -37,7 +37,6 @@
let allConversationMessages = $state<DatabaseMessage[]>([]);
let isVisible = $state(false);
let previousConversationId = $state<string | null>(null);
let previousRouteId = $state<string | null>(null);
const currentConfig = config();
@@ -158,9 +157,8 @@
});
});
beforeNavigate((navigation) => {
beforeNavigate(() => {
isVisible = false;
previousRouteId = navigation.from?.route.id ?? null;
});
afterNavigate(() => {
@@ -251,13 +249,12 @@
</script>
<div
class="transition-opacity duration-500 ease-out
{isVisible ? 'opacity-100' : 'opacity-0'}
{previousRouteId === '/(chat)/chat/[id]' ? '' : 'delay-300'}"
class="transition-opacity delay-300 duration-500 ease-out
{isVisible ? 'opacity-100' : 'opacity-0'}"
>
{#each displayMessages as { message, toolMessages, isLastAssistantMessage, siblingInfo } (message.id)}
<ChatMessage
class="mx-auto mt-12 w-full max-w-3xl"
class="mx-auto mt-12 w-full max-w-[48rem]"
{message}
{toolMessages}
{isLastAssistantMessage}
@@ -1,28 +1,31 @@
<script lang="ts">
import { Trash2 } from '@lucide/svelte';
import { afterNavigate } from '$app/navigation';
import { page } from '$app/state';
import {
ChatScreenForm,
ChatMessages,
ChatScreenDragOverlay,
ChatScreenProcessingInfo,
ChatScreenActionScrollDown,
DialogEmptyFileAlert,
DialogFileUploadError,
DialogChatError,
ServerLoadingSplash,
DialogConfirmation,
ChatScreenServerError
} from '$lib/components/app';
import { setProcessingInfoContext } from '$lib/contexts';
import { ErrorDialogType } from '$lib/enums';
import { createAutoScrollController } from '$lib/hooks/use-auto-scroll.svelte';
import { useChatScreenActiveModel } from '$lib/hooks/use-chat-screen-active-model.svelte';
import { useChatScreenDragAndDrop } from '$lib/hooks/use-chat-screen-drag-and-drop.svelte';
import { useChatScreenFileUpload } from '$lib/hooks/use-chat-screen-file-upload.svelte';
import { useChatScreenScroll } from '$lib/hooks/use-chat-screen-scroll.svelte';
import { useKeyboardShortcuts } from '$lib/hooks/use-keyboard-shortcuts.svelte';
import { device } from '$lib/stores/device.svelte';
import { isMobile } from '$lib/stores/viewport.svelte';
import {
chatStore,
errorDialog,
isLoading,
isChatStreaming,
isEditing,
getAddFilesHandler,
activeProcessingState
} from '$lib/stores/chat.svelte';
import {
@@ -31,81 +34,138 @@
activeConversation
} from '$lib/stores/conversations.svelte';
import { config } from '$lib/stores/settings.svelte';
import { serverLoading, serverError } from '$lib/stores/server.svelte';
import { parseFilesToMessageExtras } from '$lib/utils/browser-only';
import { onDestroy, onMount } from 'svelte';
import { serverLoading, serverError, isRouterMode } from '$lib/stores/server.svelte';
import { modelsStore, modelOptions, selectedModelId } from '$lib/stores/models.svelte';
import { isFileTypeSupported, filterFilesByModalities } from '$lib/utils';
import { parseFilesToMessageExtras, processFilesToChatUploaded } from '$lib/utils/browser-only';
import { onMount } from 'svelte';
import ChatScreenGreeting from './ChatScreenGreeting.svelte';
import ChatScreenActionScrollDown from './ChatScreenActionScrollDown.svelte';
import ChatScreenDialogsAndAlerts from './ChatScreenDialogsAndAlerts.svelte';
import { ROUTES } from '$lib/constants';
let { showCenteredEmpty = false } = $props();
const autoScroll = createAutoScrollController();
let disableAutoScroll = $derived(Boolean(config().disableAutoScroll));
let chatScrollContainer: HTMLDivElement | undefined = $state();
let dragCounter = $state(0);
let isDragOver = $state(false);
let showFileErrorDialog = $state(false);
let uploadedFiles = $state<ChatUploadedFile[]>([]);
let fileErrorData = $state<{
generallyUnsupported: File[];
modalityUnsupported: File[];
modalityReasons: Record<string, string>;
supportedTypes: string[];
}>({
generallyUnsupported: [],
modalityUnsupported: [],
modalityReasons: {},
supportedTypes: []
});
let showDeleteDialog = $state(false);
let showEmptyFileDialog = $state(false);
let processingInfoVisible = $state(false);
let emptyFileNames = $state<string[]>([]);
let initialMessage = $state('');
let isEmpty = $derived(
showCenteredEmpty && !activeConversation() && activeMessages().length === 0 && !isLoading()
);
let activeErrorDialog = $derived(errorDialog());
let isServerLoading = $derived(serverLoading());
let hasPropsError = $derived(!!serverError());
let isCurrentConversationLoading = $derived(isLoading() || isChatStreaming());
let showProcessingInfo = $derived(
isCurrentConversationLoading ||
(config().keepStatsVisible && !!page.params.id) ||
activeProcessingState() !== null
);
let isRouter = $derived(isRouterMode());
let conversationModel = $derived(
chatStore.getConversationModel(activeMessages() as DatabaseMessage[])
);
let activeModelId = $derived.by(() => {
const options = modelOptions();
if (!isRouter) {
return options.length > 0 ? options[0].model : null;
}
const selectedId = selectedModelId();
if (selectedId) {
const model = options.find((m) => m.id === selectedId);
if (model) return model.model;
}
if (conversationModel) {
const model = options.find((m) => m.model === conversationModel);
if (model) return model.model;
}
return null;
});
let modelPropsVersion = $state(0);
setProcessingInfoContext({
get showProcessingInfo() {
return showProcessingInfo;
}
});
let disableAutoScroll = $derived(Boolean(config().disableAutoScroll) || isMobile.current);
let isMobileUserScrolledUp = $state(false);
let mobileScrollDownHint = $state(false);
let mobileScrollDownHintLockedUntil = $state(0);
let emptyFileNames = $state<string[]>([]);
let initialMessage = $state('');
let showDeleteDialog = $state(false);
let showEmptyFileDialog = $state(false);
let isEmpty = $derived(
showCenteredEmpty && !activeConversation() && activeMessages().length === 0 && !isLoading()
);
let activeErrorDialog = $derived(errorDialog());
let isServerLoading = $derived(serverLoading());
let hasPropsError = $derived(!!serverError());
let isCurrentConversationLoading = $derived(isLoading() || isChatStreaming());
let showProcessingInfo = $derived(
isCurrentConversationLoading ||
(config().keepStatsVisible && !!page.params.id) ||
activeProcessingState() !== null
);
let chatFormBottomPosition = $derived.by(() => {
if (!isMobile.current) return '1rem';
if (device.isStandalone) return '1.5rem';
if (device.isIOSSafari) return '0.25rem';
return '0.5rem';
});
$effect(() => {
if (activeModelId) {
const cached = modelsStore.getModelProps(activeModelId);
const autoScroll = createAutoScrollController();
const scroll = useChatScreenScroll(autoScroll);
const activeModel = useChatScreenActiveModel();
const fileUpload = useChatScreenFileUpload({
capabilities: () => ({
hasVision: activeModel.hasVisionModality,
hasAudio: activeModel.hasAudioModality,
hasVideo: activeModel.hasVideoModality
}),
activeModelId: () => activeModel.activeModelId
});
const dragAndDrop = useChatScreenDragAndDrop({
onDrop: fileUpload.handleFileUpload
});
const { handleKeydown } = useKeyboardShortcuts({
deleteActiveConversation: () => {
if (activeConversation()) {
showDeleteDialog = true;
if (!cached) {
modelsStore.fetchModelProps(activeModelId).then(() => {
modelPropsVersion++;
});
}
}
});
function handleMobileScroll() {
if (!isMobile.current) return;
let hasAudioModality = $derived.by(() => {
if (activeModelId) {
void modelPropsVersion;
const container = scroll.chatScrollContainer;
if (!container) return;
return modelsStore.modelSupportsAudio(activeModelId);
}
const distanceFromBottom =
container.scrollHeight - container.clientHeight - container.scrollTop;
isMobileUserScrolledUp = distanceFromBottom > 300;
}
return false;
});
let hasVideoModality = $derived.by(() => {
if (activeModelId) {
void modelPropsVersion;
return modelsStore.modelSupportsVideo(activeModelId);
}
return false;
});
let hasVisionModality = $derived.by(() => {
if (activeModelId) {
void modelPropsVersion;
return modelsStore.modelSupportsVision(activeModelId);
}
return false;
});
async function handleDeleteConfirm() {
const conversation = activeConversation();
@@ -117,69 +177,27 @@
showDeleteDialog = false;
}
async function handleSendMessage(message: string, files?: ChatUploadedFile[]): Promise<boolean> {
const plainFiles = files ? $state.snapshot(files) : undefined;
const result = plainFiles
? await parseFilesToMessageExtras(plainFiles, activeModel.activeModelId ?? undefined)
: undefined;
if (result?.emptyFiles && result.emptyFiles.length > 0) {
emptyFileNames = result.emptyFiles;
showEmptyFileDialog = true;
if (files) {
const emptyFileNamesSet = new Set(result.emptyFiles);
fileUpload.uploadedFiles = fileUpload.uploadedFiles.filter(
(file) => !emptyFileNamesSet.has(file.name)
);
}
return false;
}
handleSendLikeScroll();
await chatStore.sendMessage(message, result?.extras);
return true;
function handleProcessingInfoVisibility(visible: boolean) {
processingInfoVisible = visible;
}
function handleSendLikeScroll() {
if (!isMobile.current) {
autoScroll.enable();
function handleDragEnter(event: DragEvent) {
event.preventDefault();
dragCounter++;
if (event.dataTransfer?.types.includes('Files')) {
isDragOver = true;
}
}
setTimeout(() => {
const container = scroll.chatScrollContainer;
if (!container) return;
function handleDragLeave(event: DragEvent) {
event.preventDefault();
const lastUserBubble = container.querySelector(
'.chat-message:nth-last-child(2) .chat-message-user .chat-message-user-bubble'
) as HTMLElement | null;
dragCounter--;
if (isMobile.current) {
// Keep the last user message bubble just above the input on mobile
const bubbleHeight = lastUserBubble?.scrollHeight ?? 0;
const baseHeight = container.scrollHeight - innerHeight;
container.scrollTo({
top: bubbleHeight > 0 ? baseHeight - bubbleHeight : baseHeight,
behavior: 'smooth'
});
} else if (lastUserBubble) {
// On desktop, place the last user message near the top of the viewport
const topPadding = 24;
const bubbleRect = lastUserBubble.getBoundingClientRect();
container.scrollTo({
top: Math.max(0, container.scrollTop + bubbleRect.top - topPadding),
behavior: 'smooth'
});
} else {
autoScroll.scrollToBottom();
}
}, 100);
if (isMobile.current) {
autoScroll.setDisabled(disableAutoScroll);
mobileScrollDownHint = true;
mobileScrollDownHintLockedUntil = Date.now() + 500;
if (dragCounter === 0) {
isDragOver = false;
}
}
@@ -189,138 +207,273 @@
}
}
function handleDragOver(event: DragEvent) {
event.preventDefault();
}
function handleDrop(event: DragEvent) {
event.preventDefault();
isDragOver = false;
dragCounter = 0;
if (event.dataTransfer?.files) {
const files = Array.from(event.dataTransfer.files);
if (isEditing()) {
const handler = getAddFilesHandler();
if (handler) {
handler(files);
return;
}
}
processFiles(files);
}
}
function handleFileRemove(fileId: string) {
uploadedFiles = uploadedFiles.filter((f) => f.id !== fileId);
}
function handleFileUpload(files: File[]) {
processFiles(files);
}
const { handleKeydown } = useKeyboardShortcuts({
deleteActiveConversation: () => {
if (activeConversation()) {
showDeleteDialog = true;
}
}
});
async function handleSystemPromptAdd(draft: { message: string; files: ChatUploadedFile[] }) {
if (draft.message || draft.files.length > 0) {
chatStore.savePendingDraft(draft.message, draft.files);
}
await chatStore.addSystemPrompt();
}
$effect(() => {
const shouldDisableAutoScroll =
config().disableAutoScroll || (isMobile.current && isCurrentConversationLoading);
autoScroll.setDisabled(shouldDisableAutoScroll);
if (!shouldDisableAutoScroll) {
function handleScroll() {
autoScroll.handleScroll();
}
async function handleSendMessage(message: string, files?: ChatUploadedFile[]): Promise<boolean> {
const plainFiles = files ? $state.snapshot(files) : undefined;
const result = plainFiles
? await parseFilesToMessageExtras(plainFiles, activeModelId ?? undefined)
: undefined;
if (result?.emptyFiles && result.emptyFiles.length > 0) {
emptyFileNames = result.emptyFiles;
showEmptyFileDialog = true;
if (files) {
const emptyFileNamesSet = new Set(result.emptyFiles);
uploadedFiles = uploadedFiles.filter((file) => !emptyFileNamesSet.has(file.name));
}
return false;
}
const extras = result?.extras;
// Enable autoscroll for user-initiated message sending
autoScroll.enable();
await chatStore.sendMessage(message, extras);
autoScroll.scrollToBottom();
return true;
}
async function processFiles(files: File[]) {
const generallySupported: File[] = [];
const generallyUnsupported: File[] = [];
for (const file of files) {
if (isFileTypeSupported(file.name, file.type)) {
generallySupported.push(file);
} else {
generallyUnsupported.push(file);
}
}
// Use model-specific capabilities for file validation
const capabilities = {
hasVision: hasVisionModality,
hasAudio: hasAudioModality,
hasVideo: hasVideoModality
};
const { supportedFiles, unsupportedFiles, modalityReasons } = filterFilesByModalities(
generallySupported,
capabilities
);
const allUnsupportedFiles = [...generallyUnsupported, ...unsupportedFiles];
if (allUnsupportedFiles.length > 0) {
const supportedTypes: string[] = ['text files', 'PDFs'];
if (hasVisionModality) supportedTypes.push('images');
if (hasAudioModality) supportedTypes.push('audio files');
if (hasVideoModality) supportedTypes.push('video files');
fileErrorData = {
generallyUnsupported,
modalityUnsupported: unsupportedFiles,
modalityReasons,
supportedTypes
};
showFileErrorDialog = true;
}
if (supportedFiles.length > 0) {
const processed = await processFilesToChatUploaded(
supportedFiles,
activeModelId ?? undefined
);
uploadedFiles = [...uploadedFiles, ...processed];
}
}
afterNavigate(() => {
if (!disableAutoScroll) {
autoScroll.enable();
}
});
onMount(() => {
const pendingDraft = chatStore.consumePendingDraft();
if (pendingDraft) {
initialMessage = pendingDraft.message;
fileUpload.uploadedFiles = pendingDraft.files;
}
function handleMessagesReady() {
if (disableAutoScroll) return;
if (!autoScroll.userScrolledUp) {
requestAnimationFrame(() => {
autoScroll.scrollToBottom('instant');
});
}
}
onMount(() => {
autoScroll.startObserving();
if (!disableAutoScroll) {
autoScroll.enable();
}
if (isMobile.current && isCurrentConversationLoading) {
mobileScrollDownHint = true;
mobileScrollDownHintLockedUntil = Date.now() + 500;
const pendingDraft = chatStore.consumePendingDraft();
if (pendingDraft) {
initialMessage = pendingDraft.message;
uploadedFiles = pendingDraft.files;
}
handleMobileScroll();
});
onDestroy(() => autoScroll.destroy());
$effect(() => {
autoScroll.setContainer(chatScrollContainer);
});
$effect(() => {
autoScroll.setDisabled(disableAutoScroll);
});
</script>
{#if dragAndDrop.isDragOver}
{#if isDragOver}
<ChatScreenDragOverlay />
{/if}
<svelte:window
onkeydown={handleKeydown}
onscroll={(e) => {
scroll.handleScroll(e);
handleMobileScroll();
if (e.isTrusted && Date.now() > mobileScrollDownHintLockedUntil) {
mobileScrollDownHint = false;
}
}}
/>
<svelte:window onkeydown={handleKeydown} />
{#if isServerLoading}
<ServerLoadingSplash />
{:else}
<div
class="chat-screen flex grow flex-col min-h-[calc(100dvh-1rem)] md:min-h-full px-4 md:py-0 pt-12 pb-48 md:pb-4"
style:--chat-form-bottom-position={chatFormBottomPosition}
ondragenter={dragAndDrop.dragHandlers.dragenter}
ondragleave={dragAndDrop.dragHandlers.dragleave}
ondragover={dragAndDrop.dragHandlers.dragover}
ondrop={dragAndDrop.dragHandlers.drop}
bind:this={chatScrollContainer}
aria-label="Chat interface with file drop zone"
class="flex h-full flex-col overflow-y-auto px-4 md:px-6"
ondragenter={handleDragEnter}
ondragleave={handleDragLeave}
ondragover={handleDragOver}
ondrop={handleDrop}
onscroll={handleScroll}
role="main"
>
{#if !isEmpty}
<ChatMessages
messages={activeMessages()}
onUserAction={() => {
handleSendLikeScroll();
}}
/>
{/if}
<div class="flex grow flex-col pt-14">
{#if !isEmpty}
<ChatMessages
messages={activeMessages()}
onMessagesReady={handleMessagesReady}
onUserAction={() => {
autoScroll.enable();
if (!autoScroll.userScrolledUp) {
autoScroll.scrollToBottom();
}
}}
/>
{/if}
<div
class={[
'pointer-events-none md:sticky fixed mt-auto transition-all duration-200',
device.isStandalone
? 'bottom-6 right-4 left-4'
: device.isIOSSafari
? 'bottom-1 left-2 right-2'
: 'bottom-2 right-2 left-2',
isEmpty ? 'md:bottom-[calc(50dvh-7rem)] 2xl:bottom-[calc(50dvh-4rem)]' : 'md:bottom-4'
]}
style:padding-top={!isEmpty ? 'var(--chat-form-padding-top)' : undefined}
>
<ChatScreenGreeting {isEmpty} />
<div
class={[
'pointer-events-none sticky right-4 left-4 mt-auto transition-all duration-200',
isEmpty ? 'bottom-[calc(50dvh-7rem)]' : 'bottom-4 pt-24 md:pt-32'
]}
>
<ChatScreenGreeting {isEmpty} />
<ChatScreenServerError />
<ChatScreenActionScrollDown
container={chatScrollContainer}
hasProcessingInfoVisible={processingInfoVisible}
/>
<div class="pointer-events-none flex flex-col gap-6 items-center w-full">
{#if (isMobile.current ? mobileScrollDownHint || isMobileUserScrolledUp : autoScroll.userScrolledUp) && page.url.hash.includes(ROUTES.CHAT) && page.params.id}
<ChatScreenActionScrollDown
onclick={() => {
mobileScrollDownHint = false;
scroll.chatScrollContainer?.scrollTo({
top: scroll.chatScrollContainer.scrollHeight,
behavior: 'smooth'
});
}}
<ChatScreenProcessingInfo onVisibilityChange={handleProcessingInfoVisibility} />
<ChatScreenServerError />
<div class="conversation-chat-form pointer-events-auto rounded-t-3xl">
<ChatScreenForm
disabled={hasPropsError || isEditing()}
{initialMessage}
isLoading={isCurrentConversationLoading}
onFileRemove={handleFileRemove}
onFileUpload={handleFileUpload}
onSend={handleSendMessage}
onStop={() => chatStore.stopGeneration()}
onSystemPromptAdd={handleSystemPromptAdd}
bind:uploadedFiles
/>
{/if}
{#if showProcessingInfo}
<ChatScreenProcessingInfo />
{/if}
</div>
</div>
<ChatScreenForm
class="pointer-events-auto conversation-chat-form"
disabled={hasPropsError || isEditing()}
{initialMessage}
isLoading={isCurrentConversationLoading}
onFileRemove={fileUpload.handleFileRemove}
onFileUpload={fileUpload.handleFileUpload}
onSend={handleSendMessage}
onStop={() => chatStore.stopGeneration()}
onSystemPromptAdd={handleSystemPromptAdd}
bind:uploadedFiles={fileUpload.uploadedFiles}
/>
</div>
</div>
{/if}
<ChatScreenDialogsAndAlerts
{showDeleteDialog}
{handleDeleteConfirm}
{showEmptyFileDialog}
{emptyFileNames}
{activeErrorDialog}
{handleErrorDialogOpenChange}
{fileUpload}
<DialogFileUploadError bind:open={showFileErrorDialog} {fileErrorData} />
<DialogConfirmation
bind:open={showDeleteDialog}
title="Delete Conversation"
description="Are you sure you want to delete this conversation? This action cannot be undone and will permanently remove all messages in this conversation."
confirmText="Delete"
cancelText="Cancel"
variant="destructive"
icon={Trash2}
onConfirm={handleDeleteConfirm}
onCancel={() => (showDeleteDialog = false)}
/>
<DialogEmptyFileAlert
bind:open={showEmptyFileDialog}
emptyFiles={emptyFileNames}
onOpenChange={(open) => {
if (!open) {
emptyFileNames = [];
}
}}
/>
<DialogChatError
message={activeErrorDialog?.message ?? ''}
contextInfo={activeErrorDialog?.contextInfo}
onOpenChange={handleErrorDialogOpenChange}
open={Boolean(activeErrorDialog)}
type={activeErrorDialog?.type ?? ErrorDialogType.SERVER}
/>
@@ -1,18 +1,58 @@
<script lang="ts">
import { ArrowDown } from '@lucide/svelte';
import ActionIcon from '$lib/components/app/actions/ActionIcon.svelte';
import { Button } from '$lib/components/ui/button';
let { onclick }: { onclick: (e?: MouseEvent) => void } = $props();
interface Props {
container: HTMLDivElement | undefined;
hasProcessingInfoVisible: boolean;
}
let { container, hasProcessingInfoVisible }: Props = $props();
let show = $state(false);
let buttonBottom = $derived(hasProcessingInfoVisible ? '2rem' : '0');
function checkVisibility() {
if (!container) return;
const { scrollTop, scrollHeight, clientHeight } = container;
const distanceFromBottom = scrollHeight - clientHeight - scrollTop;
show = distanceFromBottom > clientHeight * 0.5;
}
function scrollToBottom() {
if (container) {
container.scrollTo({
top: container.scrollHeight,
behavior: 'smooth'
});
}
}
$effect(() => {
const c = container;
if (c) {
c.addEventListener('scroll', checkVisibility);
checkVisibility();
return () => {
c.removeEventListener('scroll', checkVisibility);
};
}
});
</script>
<div class="pointer-events-auto flex justify-center relative h-0">
<ActionIcon
icon={ArrowDown}
{onclick}
ariaLabel="Scroll to bottom"
tooltip="Scroll to bottom"
size="lg"
iconSize="h-4 w-4"
class="h-9 w-9 rounded-full bg-accent text-accent-foreground absolute bottom-4 shadow-md"
/>
<div class="relative z-50 mx-auto mb-4 flex max-w-[48rem] justify-center">
<Button
onclick={scrollToBottom}
variant="secondary"
size="icon"
disabled={!show}
class="pointer-events-auto absolute h-10 w-10 rounded-full bg-background/80 shadow-lg backdrop-blur-sm transition-all duration-200 hover:bg-muted/80"
style="bottom: {buttonBottom}; transform: translateY({show ? '0' : '2rem'}); opacity: {show
? 1
: 0};"
aria-label="Scroll to bottom"
>
<ArrowDown class="h-4 w-4" />
</Button>
</div>
@@ -1,55 +0,0 @@
<script lang="ts">
import { Trash2 } from '@lucide/svelte';
import { ErrorDialogType } from '$lib/enums';
import {
DialogChatError,
DialogConfirmation,
DialogEmptyFileAlert,
DialogFileUploadError
} from '$lib/components/app';
let {
showDeleteDialog,
handleDeleteConfirm,
showEmptyFileDialog,
emptyFileNames,
activeErrorDialog,
handleErrorDialogOpenChange,
fileUpload
} = $props();
</script>
<DialogFileUploadError
bind:open={fileUpload.showFileErrorDialog}
fileErrorData={fileUpload.fileErrorData}
/>
<DialogConfirmation
bind:open={showDeleteDialog}
title="Delete Conversation"
description="Are you sure you want to delete this conversation? This action cannot be undone and will permanently remove all messages in this conversation."
confirmText="Delete"
cancelText="Cancel"
variant="destructive"
icon={Trash2}
onConfirm={handleDeleteConfirm}
onCancel={() => (showDeleteDialog = false)}
/>
<DialogEmptyFileAlert
bind:open={showEmptyFileDialog}
emptyFiles={emptyFileNames}
onOpenChange={(open) => {
if (!open) {
emptyFileNames = [];
}
}}
/>
<DialogChatError
message={activeErrorDialog?.message ?? ''}
contextInfo={activeErrorDialog?.contextInfo}
onOpenChange={handleErrorDialogOpenChange}
open={Boolean(activeErrorDialog)}
type={activeErrorDialog?.type ?? ErrorDialogType.SERVER}
/>
@@ -2,7 +2,6 @@
import { afterNavigate } from '$app/navigation';
import { page } from '$app/state';
import { ChatForm } from '$lib/components/app';
import { isMobile } from '$lib/stores/viewport.svelte';
import { onMount } from 'svelte';
import { useDraftMessages } from '$lib/hooks/use-draft-messages.svelte';
@@ -33,30 +32,7 @@
}: Props = $props();
let chatFormRef: ChatForm | undefined = $state(undefined);
let formWrapperEl: HTMLDivElement | undefined = $state();
let chatId = $derived(page.params.id as string | undefined);
$effect(() => {
if (!formWrapperEl) return;
const formEl = formWrapperEl.querySelector('form') as HTMLElement | null;
if (!formEl) return;
const updateHeight = () => {
const height = Math.round(formEl.getBoundingClientRect().height);
document.documentElement.style.setProperty('--chat-form-height', `${height}px`);
};
updateHeight();
const resizeObserver = new ResizeObserver(updateHeight);
resizeObserver.observe(formEl);
return () => {
resizeObserver.disconnect();
document.documentElement.style.removeProperty('--chat-form-height');
};
});
let hasLoadingAttachments = $derived(uploadedFiles.some((f) => f.isLoading));
let message = $derived(initialMessage);
let previousIsLoading = $derived(isLoading);
@@ -107,14 +83,12 @@
}
onMount(() => {
if (!isMobile.current) {
setTimeout(() => chatFormRef?.focus(), 100);
}
setTimeout(() => chatFormRef?.focus(), 10);
});
afterNavigate((navigation) => {
if (navigation?.from != null && !isMobile.current) {
setTimeout(() => chatFormRef?.focus(), 100);
if (navigation?.from != null) {
setTimeout(() => chatFormRef?.focus(), 10);
}
});
@@ -134,12 +108,12 @@
});
</script>
<div class="chat-screen-form-wrapper" bind:this={formWrapperEl}>
<div class="relative mx-auto max-w-[48rem]">
<ChatForm
class="mx-auto max-w-3xl {className}"
bind:this={chatFormRef}
bind:value={message}
bind:uploadedFiles
class={className}
{disabled}
{isLoading}
showMcpPromptButton
@@ -1,4 +1,5 @@
<script lang="ts">
import { fadeInView } from '$lib/actions/fade-in-view.svelte';
import { serverStore } from '$lib/stores/server.svelte';
interface Props {
@@ -10,9 +11,10 @@
<div
class={[
'pointer-events-none mb-4 hidden px-4 text-center text-balance',
isEmpty && 'mb-[calc(50dvh-8rem)] md:mb-6 pointer-events-auto block!'
'pointer-events-none mb-4 hidden px-4 text-center',
isEmpty && 'pointer-events-auto block!'
]}
use:fadeInView={{ duration: 300 }}
>
<h1 class="mb-2 text-2xl font-semibold tracking-tight md:text-3xl">Hello there</h1>
@@ -5,8 +5,13 @@
import { chatStore, isLoading, isChatStreaming } from '$lib/stores/chat.svelte';
import { activeMessages, activeConversation } from '$lib/stores/conversations.svelte';
import { config } from '$lib/stores/settings.svelte';
import { getProcessingInfoContext } from '$lib/contexts';
import { page } from '$app/state';
const processingState = useProcessingState();
const processingInfoCtx = getProcessingInfoContext();
let showProcessingInfo = $derived(processingInfoCtx.showProcessingInfo);
let isCurrentConversationLoading = $derived(isLoading());
let isStreaming = $derived(isChatStreaming());
@@ -65,8 +70,8 @@
<div
class={[
'chat-processing-info-container pointer-events-none relative w-full hidden md:block',
processingVisible && 'visible'
'chat-processing-info-container pointer-events-none relative',
page.params.id && showProcessingInfo && 'visible'
]}
>
<div class="chat-processing-info-content absolute bottom-4 left-1/2 -translate-x-1/2">
@@ -677,6 +677,13 @@ export { default as ChatScreenForm } from './ChatScreen/ChatScreenForm.svelte';
*/
export { default as ChatScreenProcessingInfo } from './ChatScreen/ChatScreenProcessingInfo.svelte';
/**
* Scroll-to-bottom action button. Displays a floating button when the user
* has scrolled up more than half a viewport height from the bottom.
* Takes the chat container element as a prop to manage scroll state internally.
*/
export { default as ChatScreenActionScrollDown } from './ChatScreen/ChatScreenActionScrollDown.svelte';
/**
* Server error alert displayed when the server is unreachable.
* Shows the error message with a retry button.
@@ -3,7 +3,6 @@
import { Search, X } from '@lucide/svelte';
interface Props {
autofocus?: boolean;
value?: string;
placeholder?: string;
onInput?: (value: string) => void;
@@ -16,7 +15,6 @@
}
let {
autofocus,
value = $bindable(''),
placeholder = 'Search...',
onInput,
@@ -41,7 +39,7 @@
if (value) {
value = '';
onInput?.('');
ref?.focus({ preventScroll: true });
ref?.focus();
} else {
onClose?.();
}
@@ -54,7 +52,6 @@
/>
<Input
{autofocus}
{id}
bind:value
bind:ref
@@ -1,15 +0,0 @@
<script>
import logoMark from '$lib/assets/logo.svg?raw';
let { class: className = '', style = '' } = $props();
</script>
<div class={className} {style}>
{@html logoMark}
</div>
<style>
div :global(svg) {
width: var(--size, 1rem);
height: var(--size, 1rem);
}
</style>
@@ -51,11 +51,3 @@ export { default as KeyboardShortcutInfo } from './KeyboardShortcutInfo.svelte';
* Preview button is shown only for HTML code blocks.
*/
export { default as CodeBlockActions } from './CodeBlockActions.svelte';
/**
* **Logo** - Application brand mark
*
* Inline SVG of the application logo. Accepts styling via the standard
* `class` and `style` props and inherits color via `currentColor`.
*/
export { default as Logo } from './Logo.svelte';
@@ -1,11 +0,0 @@
<script lang="ts">
let { percent }: { percent: number } = $props();
</script>
<!-- thin determinate load bar pinned to the bottom edge, pulsing while it fills -->
<div class="pointer-events-none absolute inset-x-0 bottom-0 h-0.5 overflow-hidden rounded-b-sm">
<div
class="h-full animate-pulse bg-primary transition-[width] duration-200 ease-out"
style="width: {percent}%"
></div>
</div>
@@ -2,10 +2,8 @@
import { ChevronDown, Loader2, Package } from '@lucide/svelte';
import * as DropdownMenu from '$lib/components/ui/dropdown-menu';
import * as Tooltip from '$lib/components/ui/tooltip';
import { KeyboardKey, ServerModelStatus } from '$lib/enums';
import { KeyboardKey } from '$lib/enums';
import { useModelsSelector } from '$lib/hooks/use-models-selector.svelte';
import { modelsStore, routerModels } from '$lib/stores/models.svelte';
import { modelLoadFraction } from '$lib/utils';
import {
DialogModelInformation,
DropdownMenuSearchable,
@@ -13,7 +11,6 @@
ModelsSelectorList,
ModelsSelectorOption
} from '$lib/components/app';
import ModelLoadHighlight from './ModelLoadHighlight.svelte';
import type { ModelItem } from './utils';
interface Props {
@@ -116,17 +113,6 @@
{/if}
{:else}
{@const selectedOption = ms.getDisplayOption()}
{@const triggerModel = selectedOption?.model}
{@const triggerStatus = triggerModel
? routerModels().find((m) => m.id === triggerModel)?.status?.value
: undefined}
{@const triggerLoading =
!!triggerModel &&
(triggerStatus === ServerModelStatus.LOADING ||
modelsStore.isModelOperationInProgress(triggerModel))}
{@const triggerLoadPercent = triggerLoading
? Math.round(modelLoadFraction(modelsStore.getLoadProgress(triggerModel)) * 100)
: 0}
{#if ms.isRouter}
<DropdownMenu.Root bind:open={isOpen} onOpenChange={ms.handleOpenChange}>
@@ -137,7 +123,7 @@
<DropdownMenu.Trigger
{...props}
class={[
`relative inline-grid cursor-pointer grid-cols-[1fr_auto_1fr] items-center gap-1.5 rounded-sm bg-background px-1.5 py-1 text-xs shadow-sm transition hover:bg-muted-foreground/20 focus:outline-none focus-visible:ring-2 focus-visible:ring-ring focus-visible:ring-offset-2 disabled:cursor-not-allowed disabled:opacity-60 dark:bg-muted-foreground/15 dark:text-secondary-foreground`,
`inline-grid cursor-pointer grid-cols-[1fr_auto_1fr] items-center gap-1.5 rounded-sm bg-background px-1.5 py-1 text-xs shadow-sm transition hover:bg-muted-foreground/20 focus:outline-none focus-visible:ring-2 focus-visible:ring-ring focus-visible:ring-offset-2 disabled:cursor-not-allowed disabled:opacity-60 dark:bg-muted-foreground/15 dark:text-secondary-foreground`,
!ms.isCurrentModelInCache
? 'bg-red-400/10 !text-red-400 hover:bg-red-400/20 hover:text-red-400'
: forceForegroundText
@@ -168,10 +154,6 @@
{:else}
<ChevronDown class="h-3 w-3.5 shrink-0" />
{/if}
{#if triggerLoading}
<ModelLoadHighlight percent={triggerLoadPercent} />
{/if}
</DropdownMenu.Trigger>
{/snippet}
</Tooltip.Trigger>
@@ -10,7 +10,6 @@
RotateCw
} from '@lucide/svelte';
import { ActionIcon, ModelId } from '$lib/components/app';
import ModelLoadHighlight from './ModelLoadHighlight.svelte';
import type { ModelOption } from '$lib/types/models';
import { ServerModelStatus } from '$lib/enums';
import { modelsStore, routerModels } from '$lib/stores/models.svelte';
@@ -120,11 +119,11 @@
</div>
{#if isLoading}
<div class="flex w-4 items-center justify-center [@media(pointer:coarse)]:w-5">
<div class="flex w-4 [@media(pointer:coarse)]:w-5 items-center justify-center">
<Loader2 class="h-4 w-4 animate-spin text-muted-foreground" />
</div>
{:else if isFailed}
<div class="flex w-4 items-center justify-center [@media(pointer:coarse)]:w-auto">
<div class="flex w-4 [@media(pointer:coarse)]:w-auto items-center justify-center">
<CircleAlert
class="h-3.5 w-3.5 text-red-500 group-hover:hidden [@media(pointer:coarse)]:hidden"
/>
@@ -141,7 +140,7 @@
</div>
</div>
{:else if isSleeping}
<div class="flex w-4 items-center justify-center [@media(pointer:coarse)]:w-auto">
<div class="flex w-4 [@media(pointer:coarse)]:w-auto items-center justify-center">
<span
class="h-2 w-2 rounded-full bg-orange-400 group-hover:hidden [@media(pointer:coarse)]:hidden"
></span>
@@ -160,7 +159,7 @@
</div>
</div>
{:else if isLoaded}
<div class="flex w-4 items-center justify-center [@media(pointer:coarse)]:w-auto">
<div class="flex w-4 [@media(pointer:coarse)]:w-auto items-center justify-center">
<span
class="h-2 w-2 rounded-full bg-green-500 group-hover:hidden [@media(pointer:coarse)]:hidden"
></span>
@@ -177,7 +176,7 @@
</div>
</div>
{:else}
<div class="flex w-4 items-center justify-center [@media(pointer:coarse)]:w-auto">
<div class="flex w-4 [@media(pointer:coarse)]:w-auto items-center justify-center">
<span
class="h-2 w-2 rounded-full bg-muted-foreground/50 group-hover:hidden [@media(pointer:coarse)]:hidden"
></span>
@@ -197,6 +196,13 @@
</div>
{#if isLoading}
<ModelLoadHighlight percent={loadPercent} />
<div
class="pointer-events-none absolute inset-x-0 bottom-0 h-0.5 overflow-hidden rounded-b-sm bg-muted"
>
<div
class="h-full bg-primary transition-[width] duration-200 ease-out"
style="width: {loadPercent}%"
></div>
</div>
{/if}
</div>
@@ -8,10 +8,6 @@
ModelsSelectorList,
SearchInput
} from '$lib/components/app';
import ModelLoadHighlight from './ModelLoadHighlight.svelte';
import { ServerModelStatus } from '$lib/enums';
import { modelsStore, routerModels } from '$lib/stores/models.svelte';
import { modelLoadFraction } from '$lib/utils';
interface Props {
class?: string;
@@ -65,23 +61,12 @@
<p class="text-xs text-muted-foreground">No models available.</p>
{:else}
{@const selectedOption = ms.getDisplayOption()}
{@const triggerModel = selectedOption?.model}
{@const triggerStatus = triggerModel
? routerModels().find((m) => m.id === triggerModel)?.status?.value
: undefined}
{@const triggerLoading =
!!triggerModel &&
(triggerStatus === ServerModelStatus.LOADING ||
modelsStore.isModelOperationInProgress(triggerModel))}
{@const triggerLoadPercent = triggerLoading
? Math.round(modelLoadFraction(modelsStore.getLoadProgress(triggerModel)) * 100)
: 0}
{#if ms.isRouter}
<button
type="button"
class={[
`relative inline-flex cursor-pointer items-center gap-1.5 rounded-sm bg-background px-1.5 py-1 text-xs shadow-sm transition hover:bg-muted-foreground/20 focus:outline-none focus-visible:ring-2 focus-visible:ring-ring focus-visible:ring-offset-2 disabled:cursor-not-allowed disabled:opacity-60 max-sm:px-3 max-sm:py-2 max-sm:text-sm dark:bg-muted-foreground/15 dark:text-secondary-foreground`,
`inline-flex cursor-pointer items-center gap-1.5 rounded-sm bg-background px-1.5 py-1 max-sm:px-3 max-sm:py-2 text-xs max-sm:text-sm shadow-sm transition hover:bg-muted-foreground/20 focus:outline-none focus-visible:ring-2 focus-visible:ring-ring focus-visible:ring-offset-2 disabled:cursor-not-allowed disabled:opacity-60 dark:bg-muted-foreground/15 dark:text-secondary-foreground`,
!ms.isCurrentModelInCache
? 'bg-red-400/10 !text-red-400 hover:bg-red-400/20 hover:text-red-400'
: forceForegroundText
@@ -114,10 +99,6 @@
{:else}
<ChevronDown class="h-3 w-3.5 shrink-0" />
{/if}
{#if triggerLoading}
<ModelLoadHighlight percent={triggerLoadPercent} />
{/if}
</button>
<Sheet.Root bind:open={sheetOpen} onOpenChange={handleSheetOpenChange}>
@@ -0,0 +1,84 @@
<script lang="ts">
import { goto } from '$app/navigation';
import { page } from '$app/state';
import { ActionIcon } from '$lib/components/app';
import {
ICON_STRIP_TRANSITION_DURATION,
ICON_STRIP_TRANSITION_DELAY_MULTIPLIER,
SIDEBAR_ACTIONS_ITEMS
} from '$lib/constants';
import { TooltipSide } from '$lib/enums';
import { fade } from 'svelte/transition';
import { circIn } from 'svelte/easing';
import { onMount } from 'svelte';
import { useKeyboardShortcuts } from '$lib/hooks/use-keyboard-shortcuts.svelte';
interface Props {
sidebarOpen: boolean;
onSearchClick: () => void;
}
let { sidebarOpen = false, onSearchClick }: Props = $props();
const { handleKeydown } = useKeyboardShortcuts({ activateSearchMode: () => onSearchClick() });
let initialized = $state(false);
let showIcons = $derived(!sidebarOpen);
showIcons = false;
onMount(() => {
showIcons = !sidebarOpen;
setTimeout(() => {
initialized = true;
}, ICON_STRIP_TRANSITION_DELAY_MULTIPLIER * SIDEBAR_ACTIONS_ITEMS.length);
});
</script>
<svelte:window onkeydown={handleKeydown} />
<div
class="hidden shrink-0 transition-[width] duration-200 ease-linear md:block {sidebarOpen
? 'w-0'
: 'w-[calc(var(--sidebar-width-icon)+1.5rem)]'}"
></div>
<aside
class="fixed top-0 bottom-0 left-0 z-10 hidden w-[calc(var(--sidebar-width-icon)+1.5rem)] flex-col items-center justify-between py-3 transition-opacity duration-200 ease-linear md:flex {sidebarOpen
? 'pointer-events-none opacity-0'
: 'opacity-100'}"
>
<div class="mt-12 flex flex-col items-center gap-1">
{#each SIDEBAR_ACTIONS_ITEMS as item, i (item.tooltip)}
{@const onclick = item.route ? () => goto(item.route!) : onSearchClick}
{@const isActive = item.activeRouteId
? page.route.id === item.activeRouteId
: item.activeRoutePrefix
? !!page.route.id?.startsWith(item.activeRoutePrefix)
: false}
{#if showIcons}
<div
in:fade={{
duration: ICON_STRIP_TRANSITION_DURATION,
delay: !initialized
? ICON_STRIP_TRANSITION_DELAY_MULTIPLIER + i * ICON_STRIP_TRANSITION_DELAY_MULTIPLIER
: 0,
easing: circIn
}}
>
<ActionIcon
icon={item.icon}
tooltip={item.tooltip}
tooltipSide={TooltipSide.RIGHT}
size="lg"
iconSize="h-4 w-4"
class="h-9 w-9 rounded-full hover:bg-accent! {isActive
? 'bg-accent text-accent-foreground'
: ''}"
{onclick}
/>
</div>
{/if}
{/each}
</div>
</aside>
@@ -1,67 +1,40 @@
<script lang="ts">
import { goto } from '$app/navigation';
import { page } from '$app/state';
import { PanelLeftClose, PanelLeftOpen, X } from '@lucide/svelte';
import {
ActionIcon,
Logo,
SidebarNavigationConversationList,
SidebarNavigationActions
} from '$lib/components/app';
import { ROUTES } from '$lib/constants';
import { fade } from 'svelte/transition';
import { useKeyboardShortcuts } from '$lib/hooks/use-keyboard-shortcuts.svelte';
import { conversationsStore, conversations } from '$lib/stores/conversations.svelte';
import { chatStore } from '$lib/stores/chat.svelte';
import { Trash2, Pencil, Pin, X } from '@lucide/svelte';
import { Button } from '$lib/components/ui/button';
import { DialogConfirmation } from '$lib/components/app';
import SidebarNavigationActions from './SidebarNavigationActions.svelte';
import SidebarNavigationConversationItem from './SidebarNavigationConversationItem.svelte';
import { Checkbox } from '$lib/components/ui/checkbox';
import Label from '$lib/components/ui/label/label.svelte';
import ScrollArea from '$lib/components/ui/scroll-area/scroll-area.svelte';
import * as Sidebar from '$lib/components/ui/sidebar';
import Input from '$lib/components/ui/input/input.svelte';
import { ROUTES } from '$lib/constants/routes';
import { RouterService } from '$lib/services/router.service';
import { isMobile } from '$lib/stores/viewport.svelte';
import { TooltipSide } from '$lib/enums';
import { device } from '$lib/stores/device.svelte';
import { circIn } from 'svelte/easing';
import {
conversationsStore,
conversations,
buildConversationTree
} from '$lib/stores/conversations.svelte';
import { chatStore } from '$lib/stores/chat.svelte';
import { getPreviewText } from '$lib/utils';
import { APP_NAME } from '$lib/constants';
interface Props {
onSearchClick?: () => void;
}
let { onSearchClick = () => {} }: Props = $props();
const { handleKeydown } = useKeyboardShortcuts({ activateSearchMode: () => onSearchClick() });
let isExpandedMode = $state(false);
let hoveredTooltip = $state<string | null>(null);
let logoHovered = $state(false);
const isStripExpanded = $derived(isExpandedMode || hoveredTooltip !== null);
const isOnMobile = $derived(isMobile.current);
function toggleExpandedMode() {
isExpandedMode = !isExpandedMode;
if (!isExpandedMode) {
hoveredTooltip = null;
}
}
$effect(() => {
if (!isExpandedMode) {
isSearchModeActive = false;
searchQuery = '';
cancelMobileCollapse();
}
});
// On mobile the dedicated /search route hides the sidebar (see the aside
// render guard below). Collapse it as we enter /search so it doesn't
// reappear expanded when the user navigates back via the back button.
$effect(() => {
if (isMobile.current && page.url.hash.includes(ROUTES.SEARCH)) {
isExpandedMode = false;
}
});
const sidebar = Sidebar.useSidebar();
let currentChatId = $derived(page.params.id);
let isSearchModeActive = $state(false);
let searchQuery = $state('');
let showDeleteDialog = $state(false);
let deleteWithForks = $state(false);
let showEditDialog = $state(false);
let selectedConversation = $state<DatabaseConversation | null>(null);
let editedName = $state('');
let selectedConversationNamePreview = $derived.by(() =>
selectedConversation ? getPreviewText(selectedConversation.name) : ''
);
let filteredConversations = $derived.by(() => {
if (isSearchModeActive) {
@@ -77,206 +50,294 @@
return conversations();
});
async function selectConversation(id: string) {
if (isMobile.current) {
scheduleMobileCollapse();
let conversationTree = $derived(buildConversationTree(filteredConversations));
let pinnedConversations = $derived.by(() => {
return conversationTree.filter(({ conversation }) => conversation.pinned);
});
let unpinnedConversations = $derived.by(() => {
return conversationTree.filter(({ conversation }) => !conversation.pinned);
});
let selectedConversationHasDescendants = $derived.by(() => {
if (!selectedConversation) return false;
const allConvs = conversations();
const queue = [selectedConversation.id];
while (queue.length > 0) {
const parentId = queue.pop()!;
for (const c of allConvs) {
if (c.forkedFromConversationId === parentId) return true;
}
}
return false;
});
async function handleDeleteConversation(id: string) {
const conversation = conversations().find((conv) => conv.id === id);
if (conversation) {
selectedConversation = conversation;
deleteWithForks = false;
showDeleteDialog = true;
}
await goto(RouterService.chat(id));
}
async function handleEditConversation(id: string) {
const conversation = conversations().find((conv) => conv.id === id);
if (!conversation) return;
const newName = window.prompt('Rename conversation', conversation.name);
if (newName && newName.trim()) {
await conversationsStore.updateConversationName(id, newName.trim());
if (conversation) {
selectedConversation = conversation;
editedName = conversation.name;
showEditDialog = true;
}
}
async function handleDeleteConversation(id: string) {
const conversation = conversations().find((conv) => conv.id === id);
if (!conversation) return;
function handleConfirmDelete() {
if (selectedConversation) {
const convId = selectedConversation.id;
const withForks = deleteWithForks;
showDeleteDialog = false;
const confirmed = window.confirm(
`Delete "${conversation.name}"? This action cannot be undone.`
);
if (!confirmed) return;
setTimeout(() => {
conversationsStore.deleteConversation(convId, {
deleteWithForks: withForks
});
}, 100); // Wait for animation to finish
}
}
await conversationsStore.deleteConversation(id, { deleteWithForks: false });
function handleConfirmEdit() {
if (!editedName.trim() || !selectedConversation) return;
showEditDialog = false;
conversationsStore.updateConversationName(selectedConversation.id, editedName);
selectedConversation = null;
}
export function handleMobileSidebarItemClick() {
if (sidebar.isMobile) {
sidebar.toggle();
}
}
let chatSidebarActions: { activateSearch?: () => void } | undefined = $state();
let openedForSearch = $state(false);
export function activateSearchMode() {
if (!sidebar.open) {
openedForSearch = true;
}
chatSidebarActions?.activateSearch?.();
}
function handleSearchDeactivated() {
if (openedForSearch) {
openedForSearch = false;
sidebar.toggle();
}
}
$effect(() => {
if (!sidebar.open) {
isSearchModeActive = false;
searchQuery = '';
openedForSearch = false;
}
});
export function editActiveConversation() {
if (currentChatId) {
const activeConversation = filteredConversations.find((conv) => conv.id === currentChatId);
if (activeConversation) {
const event = new CustomEvent('edit-active-conversation', {
detail: { conversationId: currentChatId }
});
document.dispatchEvent(event);
}
}
}
async function selectConversation(id: string) {
if (isSearchModeActive) {
isSearchModeActive = false;
searchQuery = '';
}
handleMobileSidebarItemClick();
await goto(RouterService.chat(id));
}
function handleStopGeneration(id: string) {
chatStore.stopGenerationForChat(id);
}
let innerWidth = $state(0);
let pendingCollapse = $state<ReturnType<typeof setTimeout> | null>(null);
function scheduleMobileCollapse() {
if (pendingCollapse) {
clearTimeout(pendingCollapse);
}
pendingCollapse = setTimeout(() => {
isExpandedMode = false;
pendingCollapse = null;
}, 100);
}
function cancelMobileCollapse() {
if (pendingCollapse) {
clearTimeout(pendingCollapse);
pendingCollapse = null;
}
}
</script>
<svelte:window onkeydown={handleKeydown} bind:innerWidth />
<div class="flex h-full flex-col">
<ScrollArea class="h-full flex-1">
<Sidebar.Header class="gap-4 bg-sidebar/50 p-3 backdrop-blur-lg md:pt-4 md:pb-2">
<div class="flex items-center justify-between">
<a href={ROUTES.START} onclick={handleMobileSidebarItemClick}>
<h1 class="inline-flex items-center gap-1 px-2 text-xl font-semibold">
{APP_NAME}
</h1>
</a>
{#if innerWidth > 768 || (!page.url.hash.includes(ROUTES.SETTINGS) && !page.url.hash.includes(ROUTES.MCP_SERVERS) && !page.url.hash.includes(ROUTES.SEARCH))}
<aside
class={[
// Layout & positioning
'fixed md:sticky top-2 left-2 md:left-0 md:ml-2 md:mt-2 pt-2 z-10 w-[calc(100dvw-1rem)]',
// Dimensions & overflow
'md:h-[calc(100dvh-1.125rem)]',
isExpandedMode &&
(device.isStandalone
? 'h-[calc(100dvh-2rem)]'
: device.isIOSDevice
? 'h-[calc(100dvh-0.5rem)]'
: 'h-[calc(100dvh-1rem)]'),
// Shape & depth
'rounded-3xl md:rounded-2xl',
// Flex layout
'flex flex-col justify-between',
// Transition
'md:transition-[width,padding] duration-200 ease-out',
// Expanded state: width, surface, depth
isStripExpanded && 'md:w-72 md:bg-muted/60 md:backdrop-blur-xl border-border shadow-md',
// Collapsed state
!isStripExpanded && 'md:w-12',
// Expanded mode flag (for mobile ::before overlay)
isExpandedMode && 'is-expanded'
]}
>
<div class="px-2 flex items-center justify-between">
<div
role="button"
tabindex="0"
class="relative"
onmouseenter={() => (logoHovered = true)}
onmouseleave={() => (logoHovered = false)}
>
<ActionIcon
icon={!isExpandedMode && logoHovered && innerWidth > 768 ? PanelLeftOpen : Logo}
size="lg"
iconSize="h-4.5 w-4.5 md:h-4 md:w-4"
class="{isExpandedMode
? 'bg-muted! md:bg-foreground/5!'
: 'bg-transparent!'} md:h-9 md:w-9 h-10 w-10 rounded-full md:hover:bg-foreground/10! pointer-events-auto"
href={isExpandedMode ? ROUTES.START : undefined}
onclick={isExpandedMode ? undefined : toggleExpandedMode}
tooltip={isExpandedMode ? undefined : 'Open Sidebar'}
tooltipSide={TooltipSide.RIGHT}
ariaLabel={isExpandedMode ? 'Go to start' : 'Expand navigation'}
/>
</div>
{#if isExpandedMode || isOnMobile}
<div
class="flex items-center transition-all duration-150 ease-out {isMobile.current &&
!isExpandedMode
? 'opacity-0 h-0!'
: ''}"
in:fade={{ duration: 150, easing: circIn, delay: 50 }}
out:fade={{ duration: 100 }}
<Button
class="rounded-full md:hidden"
variant="ghost"
size="icon"
onclick={() => sidebar.toggle()}
>
<ActionIcon
icon={isMobile.current ? X : PanelLeftClose}
size="lg"
iconSize="h-4.5 w-4.5 md:h-4 md:w-4"
class="backdrop-blur-none md:h-9 md:w-9 h-10 w-10 rounded-full mr-1 hover:bg-accent!"
onclick={toggleExpandedMode}
tooltip="Close Sidebar"
tooltipSide={TooltipSide.LEFT}
ariaLabel="Collapse navigation"
/>
</div>
{/if}
</div>
<div class="mt-2 flex min-h-0 flex-1 flex-col gap-4 md:gap-1 overflow-y-auto">
<div
class="flex min-h-0 flex-1 flex-col gap-4 md:gap-1 {isMobile.current
? 'transition-[opacity,height] duration-200 ease-out'
: ''} {isMobile.current && !isExpandedMode ? 'opacity-0 !h-0' : ''}"
in:fade={{ duration: 200 }}
out:fade={{ duration: 200 }}
>
<SidebarNavigationActions
isExpandedMode={innerWidth > 768 ? isExpandedMode : true}
class="px-2"
bind:isSearchModeActive
bind:searchQuery
onSearchDeactivated={() => {
isSearchModeActive = false;
searchQuery = '';
}}
onSearchClick={() => {
isExpandedMode = true;
isSearchModeActive = true;
}}
onNewChat={() => {
if (isMobile.current) {
scheduleMobileCollapse();
}
}}
/>
{#if isExpandedMode || isOnMobile}
<SidebarNavigationConversationList
class="px-2"
{filteredConversations}
{currentChatId}
{isSearchModeActive}
{searchQuery}
onSelect={selectConversation}
onEdit={handleEditConversation}
onDelete={handleDeleteConversation}
onStop={handleStopGeneration}
/>
{/if}
<X class="h-4 w-4" />
<span class="sr-only">Close sidebar</span>
</Button>
</div>
<SidebarNavigationActions
bind:this={chatSidebarActions}
{handleMobileSidebarItemClick}
bind:isSearchModeActive
bind:searchQuery
onSearchDeactivated={handleSearchDeactivated}
/>
</Sidebar.Header>
{#if !isSearchModeActive && pinnedConversations.length > 0}
<Sidebar.Group class="p-0 px-4">
<Sidebar.GroupLabel>
<div class="flex items-center gap-1">
<Pin class="h-3.5 w-3.5" />
<span>Pinned</span>
</div>
</Sidebar.GroupLabel>
<Sidebar.GroupContent>
<Sidebar.Menu>
{#each pinnedConversations as { conversation, depth } (conversation.id)}
<Sidebar.MenuItem class="mb-1 p-0">
<SidebarNavigationConversationItem
conversation={{
id: conversation.id,
name: conversation.name,
lastModified: conversation.lastModified,
currNode: conversation.currNode,
forkedFromConversationId: conversation.forkedFromConversationId,
pinned: conversation.pinned
}}
{depth}
isActive={currentChatId === conversation.id}
onSelect={selectConversation}
onEdit={handleEditConversation}
onDelete={handleDeleteConversation}
onStop={handleStopGeneration}
/>
</Sidebar.MenuItem>
{/each}
</Sidebar.Menu>
</Sidebar.GroupContent>
</Sidebar.Group>
{/if}
<Sidebar.Group class="mt-2 h-[calc(100vh-21rem)] space-y-2 p-0 px-3">
{#if (filteredConversations.length > 0 && isSearchModeActive) || !isSearchModeActive}
<Sidebar.GroupLabel>
{isSearchModeActive ? 'Search results' : 'Recent conversations'}
</Sidebar.GroupLabel>
{/if}
<Sidebar.GroupContent>
<Sidebar.Menu>
{#each isSearchModeActive ? conversationTree : unpinnedConversations as { conversation, depth } (conversation.id)}
<Sidebar.MenuItem class="mb-1 p-0">
<SidebarNavigationConversationItem
conversation={{
id: conversation.id,
name: conversation.name,
lastModified: conversation.lastModified,
currNode: conversation.currNode,
forkedFromConversationId: conversation.forkedFromConversationId,
pinned: conversation.pinned
}}
{depth}
isActive={currentChatId === conversation.id}
onSelect={selectConversation}
onEdit={handleEditConversation}
onDelete={handleDeleteConversation}
onStop={handleStopGeneration}
/>
</Sidebar.MenuItem>
{/each}
{#if (isSearchModeActive ? conversationTree : unpinnedConversations).length === 0}
<div class="px-2 py-4 text-center">
<p class="mb-4 p-4 text-sm text-muted-foreground">
{searchQuery.length > 0
? 'No results found'
: isSearchModeActive
? 'Start typing to see results'
: 'No conversations yet'}
</p>
</div>
{/if}
</Sidebar.Menu>
</Sidebar.GroupContent>
</Sidebar.Group>
</ScrollArea>
</div>
<DialogConfirmation
bind:open={showDeleteDialog}
title="Delete Conversation"
description={selectedConversation
? `Are you sure you want to delete "${selectedConversationNamePreview}"? This action cannot be undone and will permanently remove all messages in this conversation.`
: ''}
confirmText="Delete"
cancelText="Cancel"
variant="destructive"
icon={Trash2}
onConfirm={handleConfirmDelete}
onCancel={() => {
showDeleteDialog = false;
selectedConversation = null;
}}
>
{#if selectedConversationHasDescendants}
<div class="flex items-center gap-2 py-2">
<Checkbox id="delete-with-forks" bind:checked={deleteWithForks} />
<Label for="delete-with-forks" class="text-sm">Also delete all forked conversations</Label>
</div>
</aside>
{/if}
{/if}
</DialogConfirmation>
<style>
aside {
@media (max-width: 768px) {
--size: 1.125rem;
<DialogConfirmation
bind:open={showEditDialog}
title="Edit Conversation Name"
description=""
confirmText="Save"
cancelText="Cancel"
icon={Pencil}
onConfirm={handleConfirmEdit}
onCancel={() => {
showEditDialog = false;
selectedConversation = null;
}}
onKeydown={(event) => {
if (event.key === 'Enter') {
event.preventDefault();
event.stopImmediatePropagation();
handleConfirmEdit();
}
}
@media (max-width: 768px) {
aside {
&:not(.is-expanded) {
pointer-events: none;
}
}
aside.is-expanded::before {
content: '';
position: fixed;
top: -0.5rem;
bottom: -0.25rem;
left: -0.5rem;
right: -0.5rem;
z-index: -1;
background: var(--background);
backdrop-filter: blur(1rem);
pointer-events: none;
}
}
</style>
}}
>
<Input
class="text-foreground"
placeholder="Enter a new name"
type="text"
bind:value={editedName}
/>
</DialogConfirmation>
@@ -1,86 +1,39 @@
<script lang="ts">
import { goto } from '$app/navigation';
import { page } from '$app/state';
import { Search } from '@lucide/svelte';
import { ActionIcon, KeyboardShortcutInfo, SearchInput } from '$lib/components/app';
import { KeyboardShortcutInfo } from '$lib/components/app';
import { Button } from '$lib/components/ui/button';
import {
ICON_STRIP_TRANSITION_DURATION,
ICON_STRIP_TRANSITION_DELAY_MULTIPLIER,
ROUTES,
SIDEBAR_ACTIONS_ITEMS
} from '$lib/constants';
import { isMobile } from '$lib/stores/viewport.svelte';
import { TooltipSide } from '$lib/enums';
import { fade } from 'svelte/transition';
import { circIn } from 'svelte/easing';
import { onMount } from 'svelte';
import type { Component } from 'svelte';
import { SearchInput } from '$lib/components/app';
import { page } from '$app/state';
import { SIDEBAR_ACTIONS_ITEMS } from '$lib/constants/ui';
interface Props {
class: string;
isExpandedMode: boolean;
handleMobileSidebarItemClick: () => void;
isSearchModeActive: boolean;
searchQuery: string;
isCancelAlwaysVisible?: boolean;
onSearchDeactivated?: () => void;
onSearchClick?: () => void;
onNewChat?: () => void;
}
let {
class: className,
isExpandedMode = false,
isSearchModeActive = $bindable(false),
searchQuery = $bindable(''),
onSearchDeactivated,
onSearchClick,
onNewChat
handleMobileSidebarItemClick,
isSearchModeActive = $bindable(),
searchQuery = $bindable(),
isCancelAlwaysVisible = false,
onSearchDeactivated
}: Props = $props();
let initialized = $state(false);
let showIcons = $state(false);
let searchInputRef = $state<HTMLInputElement | null>(null);
const isOnMobile = $derived(isMobile.current);
$effect(() => {
if (isSearchModeActive && searchInputRef) {
searchInputRef.focus();
}
});
onMount(() => {
showIcons = true;
setTimeout(() => {
initialized = true;
}, ICON_STRIP_TRANSITION_DELAY_MULTIPLIER * SIDEBAR_ACTIONS_ITEMS.length);
});
function handleSearchModeDeactivate() {
isSearchModeActive = false;
searchQuery = '';
onSearchDeactivated?.();
}
function isItemActive(item: {
activeRouteId?: string;
activeRoutePrefix?: string;
activeUrlIncludes?: string;
}): boolean {
if (item.activeRouteId) {
return page.route.id === item.activeRouteId;
}
if (item.activeRoutePrefix) {
return !!page.route.id?.startsWith(item.activeRoutePrefix);
}
if (item.activeUrlIncludes) {
return page.url?.hash?.includes(item.activeUrlIncludes) ?? false;
}
return false;
export function activateSearch() {
isSearchModeActive = true;
// Focus after Svelte renders the input
queueMicrotask(() => searchInputRef?.focus());
}
</script>
@@ -88,109 +41,56 @@
<IconComponent class="h-4 w-4" />
{/snippet}
{#if isSearchModeActive}
<div class="px-4 my-2">
<div class="my-1 space-y-1">
{#if isSearchModeActive}
<SearchInput
bind:value={searchQuery}
bind:ref={searchInputRef}
onClose={handleSearchModeDeactivate}
onKeyDown={(e) => e.key === 'Escape' && handleSearchModeDeactivate()}
placeholder="Search conversations..."
{isCancelAlwaysVisible}
/>
</div>
{:else if isExpandedMode || isOnMobile}
<div
class="{className} flex flex-col gap-5 md:gap-1 mt-2 md:mt-0 {!isExpandedMode && isOnMobile
? 'hidden pointer-events-none'
: ''}"
>
{#each SIDEBAR_ACTIONS_ITEMS as item, i (item.tooltip)}
{@const isActive = isItemActive(item)}
{@const isSearchOnMobile = item.icon === Search && isMobile.current}
{@const itemHref = isSearchOnMobile ? ROUTES.SEARCH : item.route}
{@const itemOnClick = item.route
? () => {
onNewChat?.();
goto(item.route!);
}
: isSearchOnMobile
? undefined
: onSearchClick}
{@const itemTransition = {
duration: ICON_STRIP_TRANSITION_DURATION,
delay: !initialized
? ICON_STRIP_TRANSITION_DELAY_MULTIPLIER + i * ICON_STRIP_TRANSITION_DELAY_MULTIPLIER
: 0,
easing: circIn
}}
{:else}
{#each SIDEBAR_ACTIONS_ITEMS as item (item.route)}
{#if !item.route}
<Button
class="w-full justify-between px-2 backdrop-blur-none! hover:[&>kbd]:opacity-100"
onclick={activateSearch}
variant="ghost"
>
<div class="flex items-center gap-2">
{@render itemIcon(item.icon)}
{#if showIcons}
<div transition:fade={itemTransition}>
<Button
class="w-full min-w-9 justify-between px-2 backdrop-blur-none! hover:[&>kbd]:opacity-100 {isActive
? 'bg-accent text-accent-foreground'
: ''}"
href={itemHref}
onclick={itemOnClick}
variant="ghost"
size="default"
>
<span class="flex min-w-0 items-center px-0.5 gap-2">
{@render itemIcon(item.icon)}
{item.tooltip}
</div>
{#if showIcons}
<span
in:fade={{ duration: 150, easing: circIn, delay: 50 }}
out:fade={{ duration: 100 }}
class="min-w-0 truncate">{item.tooltip}</span
>
{/if}
</span>
{#if item.keys}
<KeyboardShortcutInfo keys={item.keys} />
{/if}
</Button>
{:else}
<Button
class="w-full justify-between px-2 backdrop-blur-none! hover:[&>kbd]:opacity-100 {(item.activeRouteId &&
page.route.id === item.activeRouteId) ||
(item.activeRoutePrefix && page.route.id?.startsWith(item.activeRoutePrefix))
? 'bg-accent text-accent-foreground'
: ''}"
href={item.route}
onclick={handleMobileSidebarItemClick}
variant="ghost"
>
<div class="flex items-center gap-2">
{@render itemIcon(item.icon)}
{#if item.keys}
<KeyboardShortcutInfo keys={item.keys} />
{/if}
</Button>
</div>
{item.tooltip}
</div>
{#if item.keys}
<KeyboardShortcutInfo keys={item.keys} />
{/if}
</Button>
{/if}
{/each}
</div>
{:else}
<div class="{className} flex-col gap-1 hidden md:flex">
{#each SIDEBAR_ACTIONS_ITEMS as item, i (item.tooltip)}
{@const isActive = isItemActive(item)}
{@const isSearchOnMobile = item.icon === Search && isMobile.current}
{@const itemOnClick = item.route
? () => {
onNewChat?.();
goto(item.route!);
}
: isSearchOnMobile
? undefined
: onSearchClick}
{@const itemTransition = {
duration: ICON_STRIP_TRANSITION_DURATION,
delay: !initialized
? ICON_STRIP_TRANSITION_DELAY_MULTIPLIER + i * ICON_STRIP_TRANSITION_DELAY_MULTIPLIER
: 0,
easing: circIn
}}
{#if showIcons}
<div transition:fade={itemTransition}>
<ActionIcon
icon={item.icon}
tooltip={item.tooltip}
tooltipSide={TooltipSide.RIGHT}
size="lg"
iconSize="h-4 w-4"
class="h-9 w-9 rounded-full hover:bg-accent! {isActive
? 'bg-accent text-accent-foreground'
: ''}"
onclick={itemOnClick}
/>
</div>
{/if}
{/each}
</div>
{/if}
{/if}
</div>
@@ -1,135 +0,0 @@
<script lang="ts">
import { Pin } from '@lucide/svelte';
import { buildConversationTree } from '$lib/stores/conversations.svelte';
import SidebarNavigationConversationItem from './SidebarNavigationConversationItem.svelte';
import SidebarNavigationSearchResults from './SidebarNavigationSearchResults.svelte';
interface Props {
class: string;
filteredConversations: DatabaseConversation[];
currentChatId: string | undefined;
isSearchModeActive: boolean;
searchQuery: string;
onSelect: (id: string) => void;
onEdit: (id: string) => void;
onDelete: (id: string) => void;
onStop: (id: string) => void;
}
let {
class: className,
filteredConversations,
currentChatId,
isSearchModeActive,
searchQuery,
onSelect,
onEdit,
onDelete,
onStop
}: Props = $props();
let conversationTree = $derived(buildConversationTree(filteredConversations));
let pinnedConversations = $derived(
conversationTree.filter(({ conversation }) => conversation.pinned)
);
let unpinnedConversations = $derived(
conversationTree.filter(({ conversation }) => !conversation.pinned)
);
const recentEmptyMessage = $derived(
searchQuery.length > 0 ? 'No results found' : 'No conversations yet'
);
</script>
{#if isSearchModeActive}
<SidebarNavigationSearchResults
class={className}
{searchQuery}
{filteredConversations}
{currentChatId}
{onSelect}
{onEdit}
{onDelete}
{onStop}
/>
{:else}
{#if pinnedConversations.length > 0}
<div class="py-2 flex whitespace-nowrap {className}">
<div
class="text-muted-foreground inline-flex h-8 shrink-0 items-center rounded-md px-2 text-xs font-medium gap-1"
>
<Pin class="h-3.5 w-3.5" />
<span>Pinned</span>
</div>
</div>
<ul class="flex w-full min-w-0 flex-col gap-4 md:gap-1 {className}">
{#each pinnedConversations as { conversation, depth } (conversation.id)}
<li class="group/item relative mb-1 p-0">
<SidebarNavigationConversationItem
conversation={{
id: conversation.id,
name: conversation.name,
lastModified: conversation.lastModified,
currNode: conversation.currNode,
forkedFromConversationId: conversation.forkedFromConversationId,
pinned: conversation.pinned
}}
{depth}
isActive={currentChatId === conversation.id}
{onSelect}
{onEdit}
{onDelete}
{onStop}
/>
</li>
{/each}
</ul>
{/if}
<div class="mt-2 flex min-h-0 flex-1 flex-col gap-4 md:gap-2 whitespace-nowrap {className}">
{#if filteredConversations.length > 0}
<div
class="text-muted-foreground flex h-8 shrink-0 items-center rounded-md px-2 text-xs font-medium"
>
Recent conversations
</div>
{/if}
<div class="min-h-0 flex-1 md:overflow-y-auto">
<ul class="flex w-full min-w-0 flex-col gap-4 md:gap-1">
{#each unpinnedConversations as { conversation, depth } (conversation.id)}
<li class="group/item relative mb-1 p-0">
<SidebarNavigationConversationItem
conversation={{
id: conversation.id,
name: conversation.name,
lastModified: conversation.lastModified,
currNode: conversation.currNode,
forkedFromConversationId: conversation.forkedFromConversationId,
pinned: conversation.pinned
}}
{depth}
isActive={currentChatId === conversation.id}
{onSelect}
{onEdit}
{onDelete}
{onStop}
/>
</li>
{/each}
{#if unpinnedConversations.length === 0}
<li class="px-2 py-4 text-center">
<p class="mb-4 p-4 text-sm text-muted-foreground">
{recentEmptyMessage}
</p>
</li>
{/if}
</ul>
</div>
</div>
{/if}
@@ -16,6 +16,4 @@
}: Props = $props();
</script>
<div class="mb-4 px-2 {className}">
<SearchInput bind:value {placeholder} {onInput} />
</div>
<SearchInput bind:value {placeholder} {onInput} class="mb-4 {className}" />

Some files were not shown because too many files have changed in this diff Show More