Compare commits

...

2 Commits
b9529 ... b9531

Author SHA1 Message Date
Johannes Gäßler
6effcecd0b TP: round up granularity to 128 (#24180)
* TP: round up granularity to 128

* remove assert
2026-06-05 17:35:13 +02:00
therealkenc
86591c7536 cli: fix model params not propagated (#23893)
Fixes #23847
2026-06-05 17:29:41 +02:00
2 changed files with 20 additions and 10 deletions

View File

@@ -553,10 +553,12 @@ struct ggml_backend_meta_split_state llama_meta_device_get_split_state(const str
};
auto get_split_granularity = [&](int64_t blck_size, uint32_t il, const std::vector<std::pair<int64_t, uint32_t>> & segments) -> std::vector<int64_t> {
// for better performance it may make sense to round up blck_size to a higher power of 2 so that more efficient kernels can be used
if (hparams.is_recr(il)) {
// linear attention
const int64_t head_dim = hparams.ssm_d_state;
const int64_t granularity_qkv = std::lcm(blck_size, head_dim);
const int64_t head_dim = hparams.ssm_d_state;
const int64_t blck_size_perf = std::lcm(blck_size, 128);
const int64_t granularity_qkv = std::lcm(blck_size_perf, head_dim);
if (std::regex_match(tensor_name, pattern_qkv_weight) || std::regex_match(tensor_name, pattern_attn_gate_weight) ||
std::regex_match(tensor_name, pattern_ssm_conv1d) || std::regex_match(tensor_name, pattern_ssm_out_weight)) {
return std::vector<int64_t>(segments.size(), granularity_qkv);
@@ -578,17 +580,24 @@ struct ggml_backend_meta_split_state llama_meta_device_get_split_state(const str
// regular attention
const uint32_t n_gqa = hparams.n_gqa(il);
const uint32_t n_embd_q = n_gqa * hparams.n_embd_head_k(il);
if (std::regex_match(tensor_name, pattern_attn_sinks)) {
GGML_ASSERT(segments.size() == 1);
return {std::lcm(n_embd_q, blck_size)/n_embd_q * n_gqa};
// to handle head sizes like 80, only increase granularity while it doesn't cause underutilization
int64_t blck_size_perf = blck_size;
while (blck_size_perf < 128 && blck_size_perf*ud->n_devices < n_embd_q) {
blck_size_perf *= 2;
}
const int64_t granularity_q = std::lcm(n_embd_q, blck_size);
if (std::regex_match(tensor_name, pattern_attn_sinks)) {
GGML_ASSERT(segments.size() == 1);
return {std::lcm(n_embd_q, blck_size_perf)/n_embd_q * n_gqa};
}
const int64_t granularity_q = std::lcm(n_embd_q, blck_size_perf);
if (std::regex_match(tensor_name, pattern_q_weight) || std::regex_match(tensor_name, pattern_q_bias)) {
GGML_ASSERT(segments.size() == 1);
// some models have Q gate tensors, for those cases the granularity needs to be doubled:
if (ud->model->arch == LLM_ARCH_QWEN3NEXT || ud->model->arch == LLM_ARCH_QWEN35 || ud->model->arch == LLM_ARCH_QWEN35MOE) {
return {std::lcm(2*n_embd_q, blck_size)};
return {std::lcm(2*n_embd_q, blck_size_perf)};
}
return {granularity_q};
}
@@ -613,8 +622,9 @@ struct ggml_backend_meta_split_state llama_meta_device_get_split_state(const str
// FFN
if (std::regex_match(tensor_name, pattern_ffn_up_gate_weight) || std::regex_match(tensor_name, pattern_ffn_up_gate_bias) ||
std::regex_match(tensor_name, pattern_ffn_gate_up_weight) || std::regex_match(tensor_name, pattern_ffn_down_weight)) {
const int64_t blck_size_perf = std::lcm(blck_size, 128);
GGML_ASSERT(segments.size() == 1);
return {blck_size};
return {blck_size_perf};
}
// everything else
@@ -627,7 +637,6 @@ struct ggml_backend_meta_split_state llama_meta_device_get_split_state(const str
tensor_config tc = get_tensor_config();
split_state.axis = tc.axis;
if (split_state.axis >= 0 && split_state.axis < GGML_MAX_DIMS) {
const int64_t ne_full = tensor->ne[split_state.axis];
const int64_t blck_size = ggml_blck_size(tc.tensor_axis_0->type);
const float * tensor_split = ud->model->tensor_split();
std::vector<float> tensor_split_scan;
@@ -644,7 +653,6 @@ struct ggml_backend_meta_split_state llama_meta_device_get_split_state(const str
const int64_t ne_s = segments[is].first;
const uint32_t nr_s = segments[is].second;
const int64_t g_s = granularity[is];
GGML_ASSERT(ne_full % g_s == 0);
int64_t low = 0;
size_t j = 0;
for (; j < ud->n_devices - 1; j++) {

View File

@@ -397,6 +397,8 @@ int llama_cli(int argc, char ** argv) {
return 1;
}
ctx_cli.defaults.sampling = params.sampling;
console::spinner::stop();
console::log("\n");