mirror of
https://github.com/ggml-org/llama.cpp.git
synced 2026-06-30 11:13:02 +02:00
Compare commits
4 Commits
| Author | SHA1 | Date | |
|---|---|---|---|
| c1a1c8ee94 | |||
| 27c8bb4f63 | |||
| ebd048fc5e | |||
| 0ed235ea2c |
+2
-2
@@ -467,7 +467,7 @@ void common_models_handler_apply(common_models_handler & handler, common_params
|
||||
// the first part is what gets loaded, so point params.model.path at it
|
||||
if (!url_tasks.empty()) {
|
||||
std::string first_path = url_tasks.front().local_path;
|
||||
url_tasks.front().on_done = [&]() { params.model.path = first_path; };
|
||||
url_tasks.front().on_done = [&, first_path]() { params.model.path = first_path; };
|
||||
}
|
||||
for (auto & task : url_tasks) {
|
||||
tasks.push_back(std::move(task));
|
||||
@@ -3471,7 +3471,7 @@ common_params_context common_params_parser_init(common_params & params, llama_ex
|
||||
[](common_params & params) {
|
||||
params.offline = true;
|
||||
}
|
||||
).set_env("LLAMA_ARG_OFFLINE"));
|
||||
).set_examples({LLAMA_EXAMPLE_COMMON, LLAMA_EXAMPLE_DOWNLOAD}).set_env("LLAMA_ARG_OFFLINE"));
|
||||
add_opt(common_arg(
|
||||
{"-lv", "--verbosity", "--log-verbosity"}, "N",
|
||||
string_format("Set the verbosity threshold. Messages with a higher verbosity will be ignored. Values:\n"
|
||||
|
||||
+47
-47
@@ -225,7 +225,7 @@ bool set_process_priority(enum ggml_sched_priority prio) {
|
||||
}
|
||||
|
||||
if (!SetPriorityClass(GetCurrentProcess(), p)) {
|
||||
LOG_WRN("failed to set process priority class %d : (%d)\n", prio, (int) GetLastError());
|
||||
COM_WRN("failed to set process priority class %d : (%d)\n", prio, (int) GetLastError());
|
||||
return false;
|
||||
}
|
||||
|
||||
@@ -251,7 +251,7 @@ bool set_process_priority(enum ggml_sched_priority prio) {
|
||||
}
|
||||
|
||||
if (setpriority(PRIO_PROCESS, 0, p) != 0) {
|
||||
LOG_WRN("failed to set process priority %d : %s (%d)\n", prio, strerror(errno), errno);
|
||||
COM_WRN("failed to set process priority %d : %s (%d)\n", prio, strerror(errno), errno);
|
||||
return false;
|
||||
}
|
||||
return true;
|
||||
@@ -284,14 +284,14 @@ void postprocess_cpu_params(common_cpu_params & cpuparams, const common_cpu_para
|
||||
|
||||
if (n_set && n_set < cpuparams.n_threads) {
|
||||
// Not enough set bits, may experience performance issues.
|
||||
LOG_WRN("Not enough set bits in CPU mask (%d) to satisfy requested thread count: %d\n", n_set, cpuparams.n_threads);
|
||||
COM_WRN("Not enough set bits in CPU mask (%d) to satisfy requested thread count: %d\n", n_set, cpuparams.n_threads);
|
||||
}
|
||||
}
|
||||
|
||||
bool parse_cpu_range(const std::string & range, bool (&boolmask)[GGML_MAX_N_THREADS]) {
|
||||
size_t dash_loc = range.find('-');
|
||||
if (dash_loc == std::string::npos) {
|
||||
LOG_ERR("Format of CPU range is invalid! Expected [<start>]-[<end>].\n");
|
||||
COM_ERR("%s", "Format of CPU range is invalid! Expected [<start>]-[<end>].\n");
|
||||
return false;
|
||||
}
|
||||
|
||||
@@ -303,7 +303,7 @@ bool parse_cpu_range(const std::string & range, bool (&boolmask)[GGML_MAX_N_THRE
|
||||
} else {
|
||||
start_i = std::stoull(range.substr(0, dash_loc));
|
||||
if (start_i >= GGML_MAX_N_THREADS) {
|
||||
LOG_ERR("Start index out of bounds!\n");
|
||||
COM_ERR("%s", "Start index out of bounds!\n");
|
||||
return false;
|
||||
}
|
||||
}
|
||||
@@ -313,7 +313,7 @@ bool parse_cpu_range(const std::string & range, bool (&boolmask)[GGML_MAX_N_THRE
|
||||
} else {
|
||||
end_i = std::stoull(range.substr(dash_loc + 1));
|
||||
if (end_i >= GGML_MAX_N_THREADS) {
|
||||
LOG_ERR("End index out of bounds!\n");
|
||||
COM_ERR("%s", "End index out of bounds!\n");
|
||||
return false;
|
||||
}
|
||||
}
|
||||
@@ -333,7 +333,7 @@ bool parse_cpu_mask(const std::string & mask, bool (&boolmask)[GGML_MAX_N_THREAD
|
||||
}
|
||||
|
||||
size_t num_digits = mask.length() - start_i;
|
||||
if (num_digits > 128) num_digits = 128;
|
||||
num_digits = std::min<size_t>(num_digits, 128);
|
||||
|
||||
size_t end_i = num_digits + start_i;
|
||||
|
||||
@@ -348,7 +348,7 @@ bool parse_cpu_mask(const std::string & mask, bool (&boolmask)[GGML_MAX_N_THREAD
|
||||
} else if (c >= 'A' && c <= 'F') {
|
||||
id -= 'A' - 10;
|
||||
} else {
|
||||
LOG_ERR("Invalid hex character '%c' at position %d\n", c, int32_t(i));
|
||||
COM_ERR("Invalid hex character '%c' at position %d\n", c, int32_t(i));
|
||||
return false;
|
||||
}
|
||||
|
||||
@@ -379,21 +379,21 @@ void common_params_print_info(const common_params & params, bool print_devices)
|
||||
#else
|
||||
const char * build_type = " (debug)";
|
||||
#endif
|
||||
LOG_TRC("%s: build %d (%s) with %s for %s%s\n", __func__, llama_build_number(), llama_commit(), llama_compiler(), llama_build_target(), build_type);
|
||||
COM_TRC("%s: build %d (%s) with %s for %s%s\n", __func__, llama_build_number(), llama_commit(), llama_compiler(), llama_build_target(), build_type);
|
||||
|
||||
LOG_INF("log_info: verbosity = %d (adjust with the `-lv N` CLI arg)\n", common_log_get_verbosity_thold());
|
||||
COM_INF("%s: verbosity = %d (adjust with the `-lv N` CLI arg)\n", __func__, common_log_get_verbosity_thold());
|
||||
|
||||
// device enumeration creates a primary context on CUDA backends, skip it when the caller does not own any device
|
||||
if (print_devices) {
|
||||
LOG_INF("device_info:\n");
|
||||
COM_TRC("%s", "device_info:\n");
|
||||
for (size_t i = 0; i < ggml_backend_dev_count(); ++i) {
|
||||
auto * dev = ggml_backend_dev_get(i);
|
||||
size_t free, total;
|
||||
ggml_backend_dev_memory(dev, &free, &total);
|
||||
LOG_INF(" - %-8s: %s (%zu MiB, %zu MiB free)\n", ggml_backend_dev_name(dev), ggml_backend_dev_description(dev), total / 1024 / 1024, free / 1024 / 1024);
|
||||
COM_TRC(" - %-8s: %s (%zu MiB, %zu MiB free)\n", ggml_backend_dev_name(dev), ggml_backend_dev_description(dev), total / 1024 / 1024, free / 1024 / 1024);
|
||||
}
|
||||
}
|
||||
LOG_INF("%s\n", common_params_get_system_info(params).c_str());
|
||||
COM_TRC("%s\n", common_params_get_system_info(params).c_str());
|
||||
}
|
||||
|
||||
std::string common_params_get_system_info(const common_params & params) {
|
||||
@@ -660,7 +660,7 @@ void string_process_escapes(std::string & input) {
|
||||
bool string_parse_kv_override(const char * data, std::vector<llama_model_kv_override> & overrides) {
|
||||
const char * sep = strchr(data, '=');
|
||||
if (sep == nullptr || sep - data >= 128) {
|
||||
LOG_ERR("%s: malformed KV override '%s'\n", __func__, data);
|
||||
COM_ERR("%s: malformed KV override '%s'\n", __func__, data);
|
||||
return false;
|
||||
}
|
||||
llama_model_kv_override kvo;
|
||||
@@ -683,20 +683,20 @@ bool string_parse_kv_override(const char * data, std::vector<llama_model_kv_over
|
||||
} else if (std::strcmp(sep, "false") == 0) {
|
||||
kvo.val_bool = false;
|
||||
} else {
|
||||
LOG_ERR("%s: invalid boolean value for KV override '%s'\n", __func__, data);
|
||||
COM_ERR("%s: invalid boolean value for KV override '%s'\n", __func__, data);
|
||||
return false;
|
||||
}
|
||||
} else if (strncmp(sep, "str:", 4) == 0) {
|
||||
sep += 4;
|
||||
kvo.tag = LLAMA_KV_OVERRIDE_TYPE_STR;
|
||||
if (strlen(sep) > 127) {
|
||||
LOG_ERR("%s: malformed KV override '%s', value cannot exceed 127 chars\n", __func__, data);
|
||||
COM_ERR("%s: malformed KV override '%s', value cannot exceed 127 chars\n", __func__, data);
|
||||
return false;
|
||||
}
|
||||
strncpy(kvo.val_str, sep, 127);
|
||||
kvo.val_str[127] = '\0';
|
||||
} else {
|
||||
LOG_ERR("%s: invalid type for KV override '%s'\n", __func__, data);
|
||||
COM_ERR("%s: invalid type for KV override '%s'\n", __func__, data);
|
||||
return false;
|
||||
}
|
||||
overrides.emplace_back(std::move(kvo));
|
||||
@@ -1199,8 +1199,8 @@ common_init_result::common_init_result(common_params & params, bool model_only)
|
||||
auto cparams = common_context_params_to_llama(params);
|
||||
|
||||
if (params.fit_params) {
|
||||
LOG_INF("%s: fitting params to device memory ...\n", __func__);
|
||||
LOG_INF("%s: (for bugs during this step try to reproduce them with -fit off, or provide --verbose logs if the bug only occurs with -fit on)\n", __func__);
|
||||
COM_TRC("%s", "fitting params to device memory ...\n");
|
||||
COM_TRC("%s", "(for bugs during this step try to reproduce them with -fit off, or provide --verbose logs if the bug only occurs with -fit on)\n");
|
||||
common_fit_params(params.model.path.c_str(), &mparams, &cparams,
|
||||
params.tensor_split,
|
||||
params.tensor_buft_overrides.data(),
|
||||
@@ -1227,7 +1227,7 @@ common_init_result::common_init_result(common_params & params, bool model_only)
|
||||
llama_adapter_lora_ptr lora;
|
||||
lora.reset(llama_adapter_lora_init(model, la.path.c_str()));
|
||||
if (lora == nullptr) {
|
||||
LOG_ERR("%s: failed to load lora adapter '%s'\n", __func__, la.path.c_str());
|
||||
COM_ERR("failed to load lora adapter '%s'\n", la.path.c_str());
|
||||
pimpl->model.reset(model);
|
||||
return;
|
||||
}
|
||||
@@ -1246,14 +1246,14 @@ common_init_result::common_init_result(common_params & params, bool model_only)
|
||||
common_init_sampler_from_model(model, params.sampling);
|
||||
|
||||
if (params.sampling.ignore_eos && llama_vocab_eos(vocab) == LLAMA_TOKEN_NULL) {
|
||||
LOG_WRN("%s: warning: vocab does not have an EOS token, ignoring --ignore-eos\n", __func__);
|
||||
COM_WRN("%s", "vocab does not have an EOS token, ignoring --ignore-eos\n");
|
||||
params.sampling.ignore_eos = false;
|
||||
}
|
||||
|
||||
// initialize once
|
||||
for (llama_token i = 0; i < llama_vocab_n_tokens(vocab); i++) {
|
||||
if (llama_vocab_is_eog(vocab, i)) {
|
||||
LOG_TRC("%s: added %s logit bias = %f\n", __func__, common_token_to_piece(vocab, i).c_str(), -INFINITY);
|
||||
COM_TRC("added %s logit bias = %f\n", common_token_to_piece(vocab, i).c_str(), -INFINITY);
|
||||
params.sampling.logit_bias_eog.push_back({i, -INFINITY});
|
||||
}
|
||||
}
|
||||
@@ -1291,7 +1291,7 @@ common_init_result::common_init_result(common_params & params, bool model_only)
|
||||
|
||||
llama_context * lctx = llama_init_from_model(model, cparams);
|
||||
if (lctx == NULL) {
|
||||
LOG_ERR("%s: failed to create context with model '%s'\n", __func__, params.model.path.c_str());
|
||||
COM_ERR("failed to create context with model '%s'\n", params.model.path.c_str());
|
||||
return;
|
||||
}
|
||||
|
||||
@@ -1328,7 +1328,7 @@ common_init_result_ptr common_init_from_params(common_params & params, bool mode
|
||||
|
||||
llama_model * model = res->model();
|
||||
if (model == NULL) {
|
||||
LOG_ERR("%s: failed to load model '%s'\n", __func__, params.model.path.c_str());
|
||||
COM_ERR("failed to load model '%s'\n", params.model.path.c_str());
|
||||
return res;
|
||||
}
|
||||
|
||||
@@ -1338,14 +1338,14 @@ common_init_result_ptr common_init_from_params(common_params & params, bool mode
|
||||
|
||||
llama_context * lctx = res->context();
|
||||
if (lctx == NULL) {
|
||||
LOG_ERR("%s: failed to create context with model '%s'\n", __func__, params.model.path.c_str());
|
||||
COM_ERR("failed to create context with model '%s'\n", params.model.path.c_str());
|
||||
return res;
|
||||
}
|
||||
|
||||
const llama_vocab * vocab = llama_model_get_vocab(model);
|
||||
|
||||
if (params.ctx_shift && !llama_memory_can_shift(llama_get_memory(lctx))) {
|
||||
LOG_WRN("%s: KV cache shifting is not supported for this context, disabling KV cache shifting\n", __func__);
|
||||
COM_WRN("%s", "KV cache shifting is not supported for this context, disabling KV cache shifting\n");
|
||||
params.ctx_shift = false;
|
||||
}
|
||||
|
||||
@@ -1374,7 +1374,7 @@ common_init_result_ptr common_init_from_params(common_params & params, bool mode
|
||||
bool ok = true;
|
||||
|
||||
if (llama_vocab_bos(vocab) == LLAMA_TOKEN_NULL) {
|
||||
LOG_WRN("%s: warning: vocab does not have a BOS token, reranking will not work\n", __func__);
|
||||
COM_WRN("%s", "vocab does not have a BOS token, reranking will not work\n");
|
||||
ok = false;
|
||||
}
|
||||
|
||||
@@ -1383,10 +1383,10 @@ common_init_result_ptr common_init_from_params(common_params & params, bool mode
|
||||
bool has_rerank_prompt = llama_model_chat_template(model, "rerank") != NULL;
|
||||
|
||||
if (!has_eos && !has_sep && !has_rerank_prompt) {
|
||||
LOG_WRN("%s: warning: vocab does not have an EOS token, SEP token, or rerank prompt. Reranking will not work\n", __func__);
|
||||
COM_WRN("%s", "vocab does not have an EOS token, SEP token, or rerank prompt. Reranking will not work\n");
|
||||
ok = false;
|
||||
} else if (!has_eos) {
|
||||
LOG_WRN("%s: warning: vocab does not have an EOS token, using SEP token as fallback\n", __func__);
|
||||
COM_WRN("%s", "vocab does not have an EOS token, using SEP token as fallback\n");
|
||||
}
|
||||
|
||||
if (!ok) {
|
||||
@@ -1399,7 +1399,7 @@ common_init_result_ptr common_init_from_params(common_params & params, bool mode
|
||||
}
|
||||
|
||||
if (params.warmup) {
|
||||
LOG_INF("%s: warming up the model with an empty run - please wait ... (--no-warmup to disable)\n", __func__);
|
||||
COM_TRC("%s", "warming up the model with an empty run - please wait ... (--no-warmup to disable)\n");
|
||||
|
||||
std::vector<llama_token> tmp;
|
||||
llama_token bos = llama_vocab_bos(vocab);
|
||||
@@ -1473,20 +1473,20 @@ common_context_seq_rm_type common_context_can_seq_rm(llama_context * ctx) {
|
||||
|
||||
int ret = llama_decode(ctx, llama_batch_get_one(tmp.data(), tmp.size()));
|
||||
if (ret != 0) {
|
||||
LOG_ERR("%s: llama_decode() failed: %d\n", __func__, ret);
|
||||
COM_ERR("llama_decode() failed: %d\n", ret);
|
||||
res = COMMON_CONTEXT_SEQ_RM_TYPE_NO;
|
||||
goto done;
|
||||
}
|
||||
|
||||
if (llama_n_rs_seq(ctx) > 0) {
|
||||
LOG_INF("%s: the context supports bounded partial sequence removal\n", __func__);
|
||||
COM_TRC("%s", "the context supports bounded partial sequence removal\n");
|
||||
res = COMMON_CONTEXT_SEQ_RM_TYPE_RS;
|
||||
goto done;
|
||||
}
|
||||
|
||||
// try to remove the last tokens
|
||||
if (!llama_memory_seq_rm(mem, 0, 1, -1)) {
|
||||
LOG_TRC("%s: the context does not support partial sequence removal\n", __func__);
|
||||
COM_TRC("%s", "the context does not support partial sequence removal\n");
|
||||
res = COMMON_CONTEXT_SEQ_RM_TYPE_FULL;
|
||||
goto done;
|
||||
}
|
||||
@@ -1803,13 +1803,13 @@ static common_control_vector_data common_control_vector_load_one(const common_co
|
||||
};
|
||||
struct gguf_context * ctx_gguf = gguf_init_from_file(load_info.fname.c_str(), meta_gguf_params);
|
||||
if (!ctx_gguf) {
|
||||
LOG_ERR("%s: failed to load control vector file from %s\n", __func__, load_info.fname.c_str());
|
||||
COM_ERR("failed to load control vector file from %s\n", load_info.fname.c_str());
|
||||
return result;
|
||||
}
|
||||
|
||||
int32_t n_tensors = gguf_get_n_tensors(ctx_gguf);
|
||||
if (n_tensors == 0) {
|
||||
LOG_WRN("%s: no direction tensors found in %s\n", __func__, load_info.fname.c_str());
|
||||
COM_WRN("no direction tensors found in %s\n", load_info.fname.c_str());
|
||||
}
|
||||
|
||||
for (int i = 0; i < n_tensors; i++) {
|
||||
@@ -1827,23 +1827,23 @@ static common_control_vector_data common_control_vector_load_one(const common_co
|
||||
}
|
||||
}
|
||||
if (layer_idx < 0) {
|
||||
LOG_ERR("%s: invalid/unparsable direction tensor layer index in %s\n", __func__, load_info.fname.c_str());
|
||||
COM_ERR("invalid/unparsable direction tensor layer index in %s\n", load_info.fname.c_str());
|
||||
result.n_embd = -1;
|
||||
break;
|
||||
} else if (layer_idx == 0) {
|
||||
LOG_ERR("%s: invalid (zero) direction tensor layer index in %s\n", __func__, load_info.fname.c_str());
|
||||
COM_ERR("invalid (zero) direction tensor layer index in %s\n", load_info.fname.c_str());
|
||||
result.n_embd = -1;
|
||||
break;
|
||||
}
|
||||
|
||||
struct ggml_tensor * tensor = ggml_get_tensor(ctx, name.c_str());
|
||||
if (tensor->type != GGML_TYPE_F32) {
|
||||
LOG_ERR("%s: invalid (non-F32) direction tensor type in %s\n", __func__, load_info.fname.c_str());
|
||||
COM_ERR("invalid (non-F32) direction tensor type in %s\n", load_info.fname.c_str());
|
||||
result.n_embd = -1;
|
||||
break;
|
||||
}
|
||||
if (ggml_n_dims(tensor) != 1) {
|
||||
LOG_ERR("%s: invalid (non-1D) direction tensor shape in %s\n", __func__, load_info.fname.c_str());
|
||||
COM_ERR("invalid (non-1D) direction tensor shape in %s\n", load_info.fname.c_str());
|
||||
result.n_embd = -1;
|
||||
break;
|
||||
}
|
||||
@@ -1851,7 +1851,7 @@ static common_control_vector_data common_control_vector_load_one(const common_co
|
||||
if (result.n_embd == -1) {
|
||||
result.n_embd = ggml_nelements(tensor);
|
||||
} else if (ggml_nelements(tensor) != result.n_embd) {
|
||||
LOG_ERR("%s: direction tensor in %s does not match previous dimensions\n", __func__, load_info.fname.c_str());
|
||||
COM_ERR("direction tensor in %s does not match previous dimensions\n", load_info.fname.c_str());
|
||||
result.n_embd = -1;
|
||||
break;
|
||||
}
|
||||
@@ -1868,7 +1868,7 @@ static common_control_vector_data common_control_vector_load_one(const common_co
|
||||
}
|
||||
|
||||
if (result.n_embd == -1) {
|
||||
LOG_WRN("%s: skipping %s due to invalid direction tensors\n", __func__, load_info.fname.c_str());
|
||||
COM_WRN("skipping %s due to invalid direction tensors\n", load_info.fname.c_str());
|
||||
result.data.clear();
|
||||
}
|
||||
|
||||
@@ -1889,7 +1889,7 @@ common_control_vector_data common_control_vector_load(const std::vector<common_c
|
||||
break;
|
||||
}
|
||||
if (result.n_embd != -1 && result.n_embd != cur.n_embd) {
|
||||
LOG_ERR("%s: control vectors in %s does not match previous dimensions\n", __func__, info.fname.c_str());
|
||||
COM_ERR("control vectors in %s does not match previous dimensions\n", info.fname.c_str());
|
||||
result.n_embd = -1;
|
||||
break;
|
||||
}
|
||||
@@ -1905,7 +1905,7 @@ common_control_vector_data common_control_vector_load(const std::vector<common_c
|
||||
}
|
||||
|
||||
if (result.n_embd == -1) {
|
||||
LOG_ERR("%s: no valid control vector files passed\n", __func__);
|
||||
COM_ERR("%s", "no valid control vector files passed\n");
|
||||
result.data.clear();
|
||||
}
|
||||
|
||||
@@ -2016,13 +2016,13 @@ bool common_prompt_batch_decode(
|
||||
// memory, so we can't just remove the last token from the memory and replay the last token which
|
||||
// is the reason for this logic.
|
||||
if (llama_decode(ctx, llama_batch_get_one(const_cast<llama_token*>(all_tokens.data() + offset), n_tokens_before_last))) {
|
||||
LOG_ERR("%s : failed to eval\n", __func__);
|
||||
COM_ERR("%s", "failed to eval\n");
|
||||
return false;
|
||||
}
|
||||
n_past += n_tokens_before_last;
|
||||
|
||||
llama_state_save_file(ctx, state_path.data(), all_tokens.data(), all_tokens.size());
|
||||
LOG_INF("saved session before last token to %s, n_new = %zu\n", state_path.data(), all_tokens.size());
|
||||
COM_INF("saved session before last token to %s, n_new = %zu\n", state_path.data(), all_tokens.size());
|
||||
|
||||
llama_token last_token = all_tokens.back();
|
||||
llama_batch batch = llama_batch_get_one(&last_token, 1);
|
||||
@@ -2030,13 +2030,13 @@ bool common_prompt_batch_decode(
|
||||
batch.pos = &pos;
|
||||
|
||||
if (llama_decode(ctx, batch)) {
|
||||
LOG_ERR("%s : failed to eval last token\n", __func__);
|
||||
COM_ERR("%s", "failed to eval last token\n");
|
||||
return false;
|
||||
}
|
||||
n_past++;
|
||||
} else {
|
||||
if (llama_decode(ctx, llama_batch_get_one(const_cast<llama_token*>(all_tokens.data() + offset), n_new))) {
|
||||
LOG_ERR("%s : failed to eval\n", __func__);
|
||||
COM_ERR("%s", "failed to eval\n");
|
||||
return false;
|
||||
}
|
||||
n_past += n_new;
|
||||
|
||||
@@ -25,6 +25,13 @@
|
||||
#define DIRECTORY_SEPARATOR '/'
|
||||
#endif // _WIN32
|
||||
|
||||
#define COM_DBG(fmt, ...) LOG_DBG("cmn %12.*s: " fmt, 12, __func__, __VA_ARGS__)
|
||||
#define COM_TRC(fmt, ...) LOG_TRC("cmn %12.*s: " fmt, 12, __func__, __VA_ARGS__)
|
||||
#define COM_INF(fmt, ...) LOG_INF("cmn %12.*s: " fmt, 12, __func__, __VA_ARGS__)
|
||||
#define COM_WRN(fmt, ...) LOG_WRN("cmn %12.*s: " fmt, 12, __func__, __VA_ARGS__)
|
||||
#define COM_ERR(fmt, ...) LOG_ERR("cmn %12.*s: " fmt, 12, __func__, __VA_ARGS__)
|
||||
#define COM_CNT(fmt, ...) LOG_CNT("" fmt, __VA_ARGS__)
|
||||
|
||||
#define die(msg) do { fputs("error: " msg "\n", stderr); exit(1); } while (0)
|
||||
#define die_fmt(fmt, ...) do { fprintf(stderr, "error: " fmt "\n", __VA_ARGS__); exit(1); } while (0)
|
||||
|
||||
|
||||
+1
-1
@@ -233,7 +233,7 @@ static void common_params_fit_impl(
|
||||
sum_projected_used = dmds_full.back().mb.total();
|
||||
sum_free = dmds_full.back().total;
|
||||
sum_projected_free = sum_free - sum_projected_used;
|
||||
LOG_INF("%s: projected to use %" PRId64 " MiB of host memory vs. %" PRId64 " MiB of total host memory\n",
|
||||
LOG_TRC("%s: projected to use %" PRId64 " MiB of host memory vs. %" PRId64 " MiB of total host memory\n",
|
||||
__func__, sum_projected_used/MiB, sum_free/MiB);
|
||||
if (sum_projected_free >= margins[0]) {
|
||||
LOG_TRC("%s: will leave %" PRId64 " >= %" PRId64 " MiB of system memory, no changes needed\n",
|
||||
|
||||
+10
-10
@@ -65,12 +65,12 @@ static void common_reasoning_budget_accept(struct llama_sampler * smpl, llama_to
|
||||
if (ctx->start_matcher.advance(token)) {
|
||||
ctx->state = REASONING_BUDGET_COUNTING;
|
||||
ctx->remaining = ctx->budget;
|
||||
LOG_INF("reasoning-budget: activated, budget=%d tokens\n", ctx->budget);
|
||||
COM_TRC("activated, budget=%d tokens\n", ctx->budget);
|
||||
|
||||
if (ctx->remaining <= 0) {
|
||||
ctx->state = REASONING_BUDGET_FORCING;
|
||||
ctx->force_pos = 0;
|
||||
LOG_INF("reasoning-budget: budget=0, forcing immediately\n");
|
||||
COM_TRC("%s", "budget=0, forcing immediately\n");
|
||||
}
|
||||
}
|
||||
break;
|
||||
@@ -80,7 +80,7 @@ static void common_reasoning_budget_accept(struct llama_sampler * smpl, llama_to
|
||||
{
|
||||
if (ctx->end_matcher.advance(token)) {
|
||||
ctx->state = REASONING_BUDGET_DONE;
|
||||
LOG_INF("reasoning-budget: deactivated (natural end)\n");
|
||||
COM_TRC("%s", "deactivated (natural end)\n");
|
||||
break;
|
||||
}
|
||||
|
||||
@@ -95,7 +95,7 @@ static void common_reasoning_budget_accept(struct llama_sampler * smpl, llama_to
|
||||
ctx->state = REASONING_BUDGET_FORCING;
|
||||
ctx->force_pos = 0;
|
||||
ctx->end_matcher.reset();
|
||||
LOG_INF("reasoning-budget: UTF-8 complete, now forcing end sequence\n");
|
||||
COM_TRC("%s", "UTF-8 complete, now forcing end sequence\n");
|
||||
}
|
||||
} else if (ctx->state == REASONING_BUDGET_COUNTING) {
|
||||
ctx->remaining--;
|
||||
@@ -104,11 +104,11 @@ static void common_reasoning_budget_accept(struct llama_sampler * smpl, llama_to
|
||||
ctx->state = REASONING_BUDGET_FORCING;
|
||||
ctx->force_pos = 0;
|
||||
ctx->end_matcher.reset();
|
||||
LOG_INF("reasoning-budget: budget exhausted, forcing end sequence\n");
|
||||
COM_TRC("%s", "budget exhausted, forcing end sequence\n");
|
||||
} else {
|
||||
ctx->state = REASONING_BUDGET_WAITING_UTF8;
|
||||
ctx->end_matcher.reset();
|
||||
LOG_INF("reasoning-budget: budget exhausted, waiting for UTF-8 completion\n");
|
||||
COM_TRC("%s", "budget exhausted, waiting for UTF-8 completion\n");
|
||||
}
|
||||
}
|
||||
}
|
||||
@@ -118,7 +118,7 @@ static void common_reasoning_budget_accept(struct llama_sampler * smpl, llama_to
|
||||
ctx->force_pos++;
|
||||
if (ctx->force_pos >= ctx->forced_tokens.size()) {
|
||||
ctx->state = REASONING_BUDGET_DONE;
|
||||
LOG_INF("reasoning-budget: forced sequence complete, done\n");
|
||||
COM_TRC("%s", "forced sequence complete, done\n");
|
||||
}
|
||||
break;
|
||||
case REASONING_BUDGET_DONE:
|
||||
@@ -128,12 +128,12 @@ static void common_reasoning_budget_accept(struct llama_sampler * smpl, llama_to
|
||||
ctx->state = REASONING_BUDGET_COUNTING;
|
||||
ctx->remaining = ctx->budget;
|
||||
ctx->end_matcher.reset();
|
||||
LOG_INF("reasoning-budget: re-activated on new start tag, budget=%d tokens\n", ctx->budget);
|
||||
COM_TRC("re-activated on new start tag, budget=%d tokens\n", ctx->budget);
|
||||
|
||||
if (ctx->remaining <= 0) {
|
||||
ctx->state = REASONING_BUDGET_FORCING;
|
||||
ctx->force_pos = 0;
|
||||
LOG_INF("reasoning-budget: budget=0, forcing immediately\n");
|
||||
COM_TRC("%s", "budget=0, forcing immediately\n");
|
||||
}
|
||||
}
|
||||
break;
|
||||
@@ -264,7 +264,7 @@ bool common_reasoning_budget_force(struct llama_sampler * smpl) {
|
||||
ctx->state = REASONING_BUDGET_FORCING;
|
||||
ctx->force_pos = 0;
|
||||
ctx->end_matcher.reset();
|
||||
LOG_INF("reasoning-budget: forced into forcing state (manual transition)\n");
|
||||
COM_TRC("%s", "forced into forcing state (manual transition)\n");
|
||||
|
||||
return true;
|
||||
}
|
||||
|
||||
+69
-64
@@ -18,6 +18,13 @@
|
||||
#include <map>
|
||||
#include <cinttypes>
|
||||
|
||||
#define SPC_DBG(fmt, ...) LOG_DBG("spec %12.*s: " fmt, 12, __func__, __VA_ARGS__)
|
||||
#define SPC_TRC(fmt, ...) LOG_TRC("spec %12.*s: " fmt, 12, __func__, __VA_ARGS__)
|
||||
#define SPC_INF(fmt, ...) LOG_INF("spec %12.*s: " fmt, 12, __func__, __VA_ARGS__)
|
||||
#define SPC_WRN(fmt, ...) LOG_WRN("spec %12.*s: " fmt, 12, __func__, __VA_ARGS__)
|
||||
#define SPC_ERR(fmt, ...) LOG_ERR("spec %12.*s: " fmt, 12, __func__, __VA_ARGS__)
|
||||
#define SPC_CNT(fmt, ...) LOG_CNT("" fmt, __VA_ARGS__)
|
||||
|
||||
#define SPEC_VOCAB_MAX_SIZE_DIFFERENCE 128
|
||||
#define SPEC_VOCAB_CHECK_START_TOKEN_ID 5
|
||||
|
||||
@@ -60,21 +67,20 @@ static bool common_speculative_are_compatible(
|
||||
const llama_vocab * vocab_dft = llama_model_get_vocab(model_dft);
|
||||
|
||||
const auto vocab_type_tgt = llama_vocab_type(vocab_tgt);
|
||||
LOG_DBG("%s: vocab_type tgt: %d\n", __func__, vocab_type_tgt);
|
||||
SPC_DBG("vocab_type tgt: %d\n", vocab_type_tgt);
|
||||
|
||||
const auto vocab_type_dft = llama_vocab_type(vocab_dft);
|
||||
LOG_DBG("%s: vocab_type dft: %d\n", __func__, vocab_type_dft);
|
||||
SPC_DBG("vocab_type dft: %d\n", vocab_type_dft);
|
||||
|
||||
if (vocab_type_tgt != vocab_type_dft) {
|
||||
LOG_WRN("%s: draft model vocab type must match target model to use speculation but "
|
||||
"vocab_type_dft = %d while vocab_type_tgt = %d\n", __func__, vocab_type_dft, vocab_type_tgt);
|
||||
SPC_WRN("draft model vocab type must match target model to use speculation but "
|
||||
"vocab_type_dft = %d while vocab_type_tgt = %d\n", vocab_type_dft, vocab_type_tgt);
|
||||
return false;
|
||||
}
|
||||
|
||||
if (llama_vocab_get_add_bos(vocab_tgt) != llama_vocab_get_add_bos(vocab_dft) ||
|
||||
(llama_vocab_get_add_bos(vocab_tgt) && llama_vocab_bos(vocab_tgt) != llama_vocab_bos(vocab_dft))) {
|
||||
LOG_WRN("%s: draft model bos tokens must match target model to use speculation. add: %d - %d, id: %d - %d)\n",
|
||||
__func__,
|
||||
SPC_WRN("draft model bos tokens must match target model to use speculation. add: %d - %d, id: %d - %d)\n",
|
||||
llama_vocab_get_add_bos(vocab_tgt), llama_vocab_get_add_bos(vocab_dft),
|
||||
llama_vocab_bos(vocab_tgt), llama_vocab_bos(vocab_dft));
|
||||
return false;
|
||||
@@ -82,8 +88,7 @@ static bool common_speculative_are_compatible(
|
||||
|
||||
if (llama_vocab_get_add_eos(vocab_tgt) != llama_vocab_get_add_eos(vocab_dft) ||
|
||||
(llama_vocab_get_add_eos(vocab_tgt) && llama_vocab_eos(vocab_tgt) != llama_vocab_eos(vocab_dft))) {
|
||||
LOG_WRN("%s: draft model eos tokens must match target model to use speculation. add: %d - %d, id: %d - %d)\n",
|
||||
__func__,
|
||||
SPC_WRN("draft model eos tokens must match target model to use speculation. add: %d - %d, id: %d - %d)\n",
|
||||
llama_vocab_get_add_eos(vocab_tgt), llama_vocab_get_add_eos(vocab_dft),
|
||||
llama_vocab_eos(vocab_tgt), llama_vocab_eos(vocab_dft));
|
||||
return false;
|
||||
@@ -97,8 +102,8 @@ static bool common_speculative_are_compatible(
|
||||
: n_vocab_dft - n_vocab_tgt;
|
||||
|
||||
if (vocab_diff > SPEC_VOCAB_MAX_SIZE_DIFFERENCE) {
|
||||
LOG_DBG("%s: draft model vocab must closely match target model to use speculation but ", __func__);
|
||||
LOG_DBG("target vocab size %d does not match draft vocab size %d - difference %d, max allowed %d\n",
|
||||
SPC_DBG("draft model vocab must closely match target model to use speculation but "
|
||||
"target vocab size %d does not match draft vocab size %d - difference %d, max allowed %d\n",
|
||||
n_vocab_tgt, llama_vocab_n_tokens(vocab_dft), vocab_diff, SPEC_VOCAB_MAX_SIZE_DIFFERENCE);
|
||||
return false;
|
||||
}
|
||||
@@ -108,8 +113,8 @@ static bool common_speculative_are_compatible(
|
||||
const char * token_text_dft = llama_vocab_get_text(vocab_dft, i);
|
||||
|
||||
if (std::strcmp(token_text_tgt, token_text_dft) != 0) {
|
||||
LOG_DBG("%s: draft model vocab must match target model to use speculation but ", __func__);
|
||||
LOG_DBG("token %d content differs - target '%s', draft '%s'\n", i,
|
||||
SPC_DBG("draft model vocab must match target model to use speculation but "
|
||||
"token %d content differs - target '%s', draft '%s'\n", i,
|
||||
common_token_to_piece(vocab_tgt, i).c_str(),
|
||||
common_token_to_piece(vocab_dft, i).c_str());
|
||||
return false;
|
||||
@@ -186,9 +191,9 @@ struct common_speculative_impl_draft_simple : public common_speculative_impl {
|
||||
auto * ctx_dft = this->params.ctx_dft;
|
||||
auto * ctx_tgt = this->params.ctx_tgt;
|
||||
|
||||
LOG_INF("%s: adding speculative implementation 'draft-simple'\n", __func__);
|
||||
LOG_INF("%s: - n_max=%d, n_min=%d, p_min=%f\n", __func__, this->params.n_max, this->params.n_min, this->params.p_min);
|
||||
LOG_INF("%s: - gpu_layers=%d, cache_k=%s, cache_v=%s, ctx_tgt=%s, ctx_dft=%s, devices=[%s]\n", __func__,
|
||||
SPC_TRC("%s", "adding speculative implementation 'draft-simple'\n");
|
||||
SPC_TRC("- n_max=%d, n_min=%d, p_min=%f\n", this->params.n_max, this->params.n_min, this->params.p_min);
|
||||
SPC_TRC("- gpu_layers=%d, cache_k=%s, cache_v=%s, ctx_tgt=%s, ctx_dft=%s, devices=[%s]\n",
|
||||
this->params.n_gpu_layers,
|
||||
ggml_type_name(this->params.cache_type_k),
|
||||
ggml_type_name(this->params.cache_type_v),
|
||||
@@ -228,16 +233,16 @@ struct common_speculative_impl_draft_simple : public common_speculative_impl {
|
||||
}
|
||||
|
||||
const bool vocab_cmpt = common_speculative_are_compatible(llama_get_model(ctx_tgt), llama_get_model(ctx_dft));
|
||||
LOG_DBG("%s: vocab_cmpt = %d\n", __func__, vocab_cmpt);
|
||||
SPC_DBG("vocab_cmpt = %d\n", vocab_cmpt);
|
||||
|
||||
if (!vocab_cmpt) {
|
||||
LOG_ERR("%s: the target and draft vocabs are not compatible\n", __func__);
|
||||
SPC_ERR("%s", "the target and draft vocabs are not compatible\n");
|
||||
|
||||
throw std::runtime_error("draft model vocab type must match target model to use speculation");
|
||||
}
|
||||
|
||||
if (n_seq != llama_n_seq_max(ctx_dft)) {
|
||||
LOG_ERR("%s: n_seq mismatch: %d != %d\n", __func__, n_seq, llama_n_seq_max(ctx_dft));
|
||||
SPC_ERR("n_seq mismatch: %d != %d\n", n_seq, llama_n_seq_max(ctx_dft));
|
||||
|
||||
throw std::runtime_error("the draft model number of sequences is incompatible with the speculative n_seq");
|
||||
}
|
||||
@@ -257,7 +262,7 @@ struct common_speculative_impl_draft_simple : public common_speculative_impl {
|
||||
const int ret = llama_decode(ctx_dft, batch);
|
||||
|
||||
if (ret != 0) {
|
||||
LOG_ERR("%s: failed to decode draft batch, ret = %d\n", __func__, ret);
|
||||
SPC_ERR("failed to decode draft batch, ret = %d\n", ret);
|
||||
|
||||
return false;
|
||||
}
|
||||
@@ -290,7 +295,7 @@ struct common_speculative_impl_draft_simple : public common_speculative_impl {
|
||||
|
||||
int ret = llama_decode(ctx_dft, batch);
|
||||
if (ret != 0) {
|
||||
LOG_WRN("%s: llama_decode returned %d\n", __func__, ret);
|
||||
SPC_ERR("llama_decode returned %d\n", ret);
|
||||
return;
|
||||
}
|
||||
|
||||
@@ -314,7 +319,7 @@ struct common_speculative_impl_draft_simple : public common_speculative_impl {
|
||||
const auto * cur_p = common_sampler_get_candidates(smpl, true);
|
||||
|
||||
for (int k = 0; k < std::min(3, (int) cur_p->size); ++k) {
|
||||
LOG_DBG(" - seq_id %d, draft candidate %3d, pos %3d: %6d (%8.3f) '%s'\n",
|
||||
SPC_DBG(" - seq_id %d, draft candidate %3d, pos %3d: %6d (%8.3f) '%s'\n",
|
||||
seq_id, k, i, cur_p->data[k].id, cur_p->data[k].p,
|
||||
common_token_to_piece(ctx_dft, cur_p->data[k].id).c_str());
|
||||
}
|
||||
@@ -354,7 +359,7 @@ struct common_speculative_impl_draft_simple : public common_speculative_impl {
|
||||
// evaluate the drafted tokens on the draft model
|
||||
ret = llama_decode(ctx_dft, batch);
|
||||
if (ret != 0) {
|
||||
LOG_WRN("%s: llama_decode[%d] returned %d\n", __func__, i, ret);
|
||||
SPC_ERR("llama_decode[%d] returned %d\n", i, ret);
|
||||
break;
|
||||
}
|
||||
|
||||
@@ -449,8 +454,8 @@ struct common_speculative_impl_draft_eagle3 : public common_speculative_impl {
|
||||
: common_speculative_impl(COMMON_SPECULATIVE_TYPE_DRAFT_EAGLE3, n_seq)
|
||||
, params(params.draft)
|
||||
{
|
||||
LOG_INF("%s: adding speculative implementation 'draft-eagle3'\n", __func__);
|
||||
LOG_INF("%s: - n_max=%d, n_min=%d, p_min=%f, backend_sampling=%d\n", __func__, params.draft.n_max, params.draft.n_min, params.draft.p_min, (int) params.draft.backend_sampling);
|
||||
SPC_TRC("%s", "adding speculative implementation 'draft-eagle3'\n");
|
||||
SPC_TRC("- n_max=%d, n_min=%d, p_min=%f, backend_sampling=%d\n", params.draft.n_max, params.draft.n_min, params.draft.p_min, (int) params.draft.backend_sampling);
|
||||
|
||||
auto * ctx_tgt = this->params.ctx_tgt;
|
||||
auto * ctx_dft = this->params.ctx_dft;
|
||||
@@ -493,7 +498,7 @@ struct common_speculative_impl_draft_eagle3 : public common_speculative_impl {
|
||||
llama_sampler_chain_add(chain, llama_sampler_init_top_k(10));
|
||||
|
||||
if (!llama_set_sampler(ctx_dft, seq_id, chain)) {
|
||||
LOG_WRN("%s: backend offload failed for seq_id=%d; using CPU sampler\n", __func__, (int) seq_id);
|
||||
SPC_WRN("backend offload failed for seq_id=%d; using CPU sampler\n", (int) seq_id);
|
||||
llama_sampler_free(chain);
|
||||
chain = nullptr;
|
||||
}
|
||||
@@ -548,9 +553,9 @@ struct common_speculative_impl_draft_eagle3 : public common_speculative_impl {
|
||||
auto * ctx_dft = this->params.ctx_dft;
|
||||
const llama_pos pos_max = llama_memory_seq_pos_max(llama_get_memory(ctx_dft), seq_id);
|
||||
if (pos_max < N - 2) {
|
||||
LOG_WRN("%s: ctx_dft pos_max=%d < N-2=%d — process() did not run on every prefill ubatch. "
|
||||
SPC_WRN("ctx_dft pos_max=%d < N-2=%d — process() did not run on every prefill ubatch. "
|
||||
"Drafts may degrade.\n",
|
||||
__func__, (int) pos_max, N - 2);
|
||||
(int) pos_max, N - 2);
|
||||
}
|
||||
}
|
||||
|
||||
@@ -621,8 +626,8 @@ struct common_speculative_impl_draft_eagle3 : public common_speculative_impl {
|
||||
};
|
||||
const int32_t rc = llama_encode(ctx_dft, enc_batch);
|
||||
if (rc != 0) {
|
||||
LOG_ERR("%s: llama_encode(ctx_dft) failed rc=%d (n_tokens=%d, offset=%d)\n",
|
||||
__func__, rc, (int) n_chunk, (int) i);
|
||||
SPC_ERR("llama_encode(ctx_dft) failed rc=%d (n_tokens=%d, offset=%d)\n",
|
||||
rc, (int) n_chunk, (int) i);
|
||||
return false;
|
||||
}
|
||||
|
||||
@@ -692,8 +697,8 @@ struct common_speculative_impl_draft_eagle3 : public common_speculative_impl {
|
||||
if (batch.n_tokens > 0) {
|
||||
const int32_t rc = llama_decode(ctx_dft, batch);
|
||||
if (rc != 0) {
|
||||
LOG_ERR("%s: llama_decode(ctx_dft) failed rc=%d (n_tokens=%d, ubatch_pos[0]=%d)\n",
|
||||
__func__, rc, (int) batch.n_tokens, (int) batch_in.pos[0]);
|
||||
SPC_ERR("llama_decode(ctx_dft) failed rc=%d (n_tokens=%d, ubatch_pos[0]=%d)\n",
|
||||
rc, (int) batch.n_tokens, (int) batch_in.pos[0]);
|
||||
return false;
|
||||
}
|
||||
}
|
||||
@@ -744,7 +749,7 @@ struct common_speculative_impl_draft_eagle3 : public common_speculative_impl {
|
||||
|
||||
int ret = llama_decode(ctx_dft, batch);
|
||||
if (ret != 0) {
|
||||
LOG_WRN("%s: llama_decode returned %d\n", __func__, ret);
|
||||
SPC_ERR("llama_decode returned %d\n", ret);
|
||||
return;
|
||||
}
|
||||
|
||||
@@ -770,7 +775,7 @@ struct common_speculative_impl_draft_eagle3 : public common_speculative_impl {
|
||||
const auto * cur_p = common_sampler_get_candidates(smpl, true);
|
||||
|
||||
for (int k = 0; k < std::min(3, (int) cur_p->size); ++k) {
|
||||
LOG_DBG(" - seq_id %d, draft candidate %3d, pos %3d: %6d (%8.3f) '%s'\n",
|
||||
SPC_DBG(" - seq_id %d, draft candidate %3d, pos %3d: %6d (%8.3f) '%s'\n",
|
||||
seq_id, k, i, cur_p->data[k].id, cur_p->data[k].p,
|
||||
common_token_to_piece(ctx_dft, cur_p->data[k].id).c_str());
|
||||
}
|
||||
@@ -809,7 +814,7 @@ struct common_speculative_impl_draft_eagle3 : public common_speculative_impl {
|
||||
|
||||
ret = llama_decode(ctx_dft, batch);
|
||||
if (ret != 0) {
|
||||
LOG_WRN("%s: llama_decode[%d] returned %d\n", __func__, i, ret);
|
||||
SPC_ERR("llama_decode[%d] returned %d\n", i, ret);
|
||||
break;
|
||||
}
|
||||
|
||||
@@ -942,9 +947,9 @@ struct common_speculative_impl_draft_mtp : public common_speculative_impl {
|
||||
"MTP input row width must match the target h_nextn width");
|
||||
n_mtp_layers = std::max(1, (int) llama_model_n_layer_nextn(llama_get_model(ctx_dft)));
|
||||
|
||||
LOG_INF("%s: adding speculative implementation 'draft-mtp'\n", __func__);
|
||||
LOG_INF("%s: - n_max=%d, n_min=%d, p_min=%.2f, n_embd=%d, backend_sampling=%d\n", __func__, this->params.n_max, this->params.n_min, this->params.p_min, n_embd, (int) this->params.backend_sampling);
|
||||
LOG_INF("%s: - gpu_layers=%d, cache_k=%s, cache_v=%s, ctx_tgt=%s, ctx_dft=%s, devices=[%s]\n", __func__,
|
||||
SPC_TRC("%s", "adding speculative implementation 'draft-mtp'\n");
|
||||
SPC_TRC("- n_max=%d, n_min=%d, p_min=%.2f, n_embd=%d, backend_sampling=%d\n", this->params.n_max, this->params.n_min, this->params.p_min, n_embd, (int) this->params.backend_sampling);
|
||||
SPC_TRC("- gpu_layers=%d, cache_k=%s, cache_v=%s, ctx_tgt=%s, ctx_dft=%s, devices=[%s]\n",
|
||||
this->params.n_gpu_layers,
|
||||
ggml_type_name(this->params.cache_type_k),
|
||||
ggml_type_name(this->params.cache_type_v),
|
||||
@@ -975,7 +980,7 @@ struct common_speculative_impl_draft_mtp : public common_speculative_impl {
|
||||
llama_sampler_chain_add(chain, llama_sampler_init_top_k(10));
|
||||
|
||||
if (!llama_set_sampler(ctx_dft, seq_id, chain)) {
|
||||
LOG_WRN("%s: backend offload failed for seq_id=%d; using CPU sampler\n", __func__, (int) seq_id);
|
||||
SPC_WRN("backend offload failed for seq_id=%d; using CPU sampler\n", (int) seq_id);
|
||||
llama_sampler_free(chain);
|
||||
chain = nullptr;
|
||||
}
|
||||
@@ -1038,11 +1043,11 @@ struct common_speculative_impl_draft_mtp : public common_speculative_impl {
|
||||
const llama_pos pos_max = llama_memory_seq_pos_max(llama_get_memory(ctx_dft), seq_id);
|
||||
|
||||
if (pos_max < N - 1 && !is_mem_shared) {
|
||||
LOG_WRN("%s: ctx_dft pos_max=%d < N-1=%d - "
|
||||
SPC_WRN("ctx_dft pos_max=%d < N-1=%d - "
|
||||
"process() hook may not have run on every prefill ubatch "
|
||||
"(need_embd / logits=1 on every prompt position?). "
|
||||
"Drafts may degrade.\n",
|
||||
__func__, (int) pos_max, N - 1);
|
||||
(int) pos_max, N - 1);
|
||||
}
|
||||
}
|
||||
|
||||
@@ -1128,8 +1133,8 @@ struct common_speculative_impl_draft_mtp : public common_speculative_impl {
|
||||
|
||||
const int32_t rc = llama_decode(ctx_dft, batch);
|
||||
if (rc != 0) {
|
||||
LOG_ERR("%s: llama_decode(ctx_dft) head=%d failed rc=%d (pos=%d)\n",
|
||||
__func__, head, (int) rc, (int) batch_in.pos[0]);
|
||||
SPC_ERR("llama_decode(ctx_dft) head=%d failed rc=%d (pos=%d)\n",
|
||||
head, (int) rc, (int) batch_in.pos[0]);
|
||||
ok = false;
|
||||
break;
|
||||
}
|
||||
@@ -1217,7 +1222,7 @@ struct common_speculative_impl_draft_mtp : public common_speculative_impl {
|
||||
|
||||
int ret = llama_decode(ctx_dft, batch);
|
||||
if (ret != 0) {
|
||||
LOG_WRN("%s: llama_decode[%d] returned %d\n", __func__, i, ret);
|
||||
SPC_ERR("llama_decode[%d] returned %d\n", i, ret);
|
||||
break;
|
||||
}
|
||||
|
||||
@@ -1239,7 +1244,7 @@ struct common_speculative_impl_draft_mtp : public common_speculative_impl {
|
||||
const auto * cur_p = common_sampler_get_candidates(smpl, true);
|
||||
|
||||
for (int k = 0; k < std::min(3, (int) cur_p->size); ++k) {
|
||||
LOG_DBG(" - seq_id %d, draft candidate %3d, pos %3d: %6d (%8.3f) '%s'\n",
|
||||
SPC_DBG(" - seq_id %d, draft candidate %3d, pos %3d: %6d (%8.3f) '%s'\n",
|
||||
seq_id, k, i, cur_p->data[k].id, cur_p->data[k].p,
|
||||
common_token_to_piece(ctx_dft, cur_p->data[k].id).c_str());
|
||||
}
|
||||
@@ -1353,8 +1358,8 @@ struct common_speculative_impl_ngram_simple : public common_speculative_impl {
|
||||
, params(params.ngram_simple)
|
||||
, config(config)
|
||||
{
|
||||
LOG_INF("%s: adding speculative implementation 'ngram-simple'\n", __func__);
|
||||
LOG_INF("%s: - size_n=%d, size_m=%d, min_hits=%d\n", __func__,
|
||||
SPC_TRC("%s", "adding speculative implementation 'ngram-simple'\n");
|
||||
SPC_TRC("- size_n=%d, size_m=%d, min_hits=%d\n",
|
||||
this->params.size_n, this->params.size_m, this->params.min_hits);
|
||||
}
|
||||
|
||||
@@ -1403,8 +1408,8 @@ struct common_speculative_impl_ngram_map_k : public common_speculative_impl {
|
||||
this->config.push_back(config);
|
||||
}
|
||||
|
||||
LOG_INF("%s: adding speculative implementation '%s'\n", __func__, common_speculative_type_to_str(this->type).c_str());
|
||||
LOG_INF("%s: - size_key=%d, size_value=%d, key_only=%d, min_hits=%d\n", __func__,
|
||||
SPC_TRC("adding speculative implementation '%s'\n", common_speculative_type_to_str(this->type).c_str());
|
||||
SPC_TRC("- size_key=%d, size_value=%d, key_only=%d, min_hits=%d\n",
|
||||
config.size_key, config.size_value, config.key_only, config.min_hits);
|
||||
}
|
||||
|
||||
@@ -1478,15 +1483,15 @@ struct common_speculative_impl_ngram_mod : public common_speculative_impl {
|
||||
, verbose(std::getenv("LLAMA_TRACE") != nullptr) {
|
||||
static_assert(sizeof(llama_token) == sizeof(common_ngram_mod::entry_t));
|
||||
|
||||
LOG_INF("%s: adding speculative implementation 'ngram-mod'\n", __func__);
|
||||
LOG_INF("%s: - n_match=%d, n_max=%d, n_min=%d\n", __func__,
|
||||
SPC_TRC("%s", "adding speculative implementation 'ngram-mod'\n");
|
||||
SPC_TRC("- n_match=%d, n_max=%d, n_min=%d\n",
|
||||
this->params.n_match, this->params.n_max, this->params.n_min);
|
||||
LOG_INF("%s: - mod size=%zu (%.3f MB)\n", __func__,
|
||||
SPC_TRC("- mod size=%zu (%.3f MB)\n",
|
||||
mod.size(), (float)(mod.size_bytes())/1024/1024);
|
||||
|
||||
if (this->params.n_match < 16) {
|
||||
LOG_WRN("%s: ngram_mod n_match=%d is too small - poor quality is possible, "
|
||||
"see: https://github.com/ggml-org/llama.cpp/pull/19164\n", __func__, this->params.n_match);
|
||||
SPC_WRN("ngram_mod n_match=%d is too small - poor quality is possible, "
|
||||
"see: https://github.com/ggml-org/llama.cpp/pull/19164\n", this->params.n_match);
|
||||
}
|
||||
|
||||
sinfos.resize(n_seq);
|
||||
@@ -1510,11 +1515,11 @@ struct common_speculative_impl_ngram_mod : public common_speculative_impl {
|
||||
sinfo.i_last = prompt.size() - n;
|
||||
|
||||
const double f = (double)mod.get_used() / (double)mod.size();
|
||||
LOG_INF("%s: ngram_mod occupancy = %zu/%zu (%.2f)\n", __func__, mod.get_used(), mod.size(), f);
|
||||
SPC_TRC("ngram_mod occupancy = %zu/%zu (%.2f)\n", mod.get_used(), mod.size(), f);
|
||||
|
||||
constexpr double f_thold = 0.25;
|
||||
if (f > f_thold) {
|
||||
LOG_WRN("%s: ngram_mod occupancy %.2f exceeds threshold (%.2f) - resetting\n", __func__, f, f_thold);
|
||||
SPC_WRN("ngram_mod occupancy %.2f exceeds threshold (%.2f) - resetting\n", f, f_thold);
|
||||
|
||||
mod.reset();
|
||||
}
|
||||
@@ -1608,7 +1613,7 @@ struct common_speculative_impl_ngram_mod : public common_speculative_impl {
|
||||
sinfo.n_low++;
|
||||
if (sinfo.n_low >= 5) {
|
||||
if (verbose) {
|
||||
LOG_WRN("%s: low acceptance streak (%d) - resetting ngram_mod\n", __func__, sinfo.n_low);
|
||||
SPC_TRC("low acceptance streak (%d) - resetting ngram_mod\n", sinfo.n_low);
|
||||
}
|
||||
|
||||
mod.reset();
|
||||
@@ -1658,8 +1663,8 @@ struct common_speculative_impl_ngram_cache : public common_speculative_impl {
|
||||
, save_dynamic(save_dynamic)
|
||||
, save_static(save_static)
|
||||
{
|
||||
LOG_INF("%s: adding speculative implementation 'ngram-cache'\n", __func__);
|
||||
LOG_INF("%s: - n_draft=%d, cache_static=%s, cache_dynamic=%s\n", __func__,
|
||||
SPC_TRC("%s", "adding speculative implementation 'ngram-cache'\n");
|
||||
SPC_TRC("- n_draft=%d, cache_static=%s, cache_dynamic=%s\n",
|
||||
n_draft,
|
||||
path_static.empty() ? "none" : path_static.c_str(),
|
||||
path_dynamic.empty() ? "none" : path_dynamic.c_str());
|
||||
@@ -1674,7 +1679,7 @@ struct common_speculative_impl_ngram_cache : public common_speculative_impl {
|
||||
sinfo.ngram_cache_static = ngram_cache_static;
|
||||
}
|
||||
} catch (...) {
|
||||
LOG_ERR("failed to open static lookup cache: %s", path_static.c_str());
|
||||
SPC_ERR("failed to open static lookup cache: %s", path_static.c_str());
|
||||
GGML_ABORT("Couldn't read static lookup cache");
|
||||
}
|
||||
}
|
||||
@@ -1687,7 +1692,7 @@ struct common_speculative_impl_ngram_cache : public common_speculative_impl {
|
||||
sinfo.ngram_cache_dynamic = ngram_cache_dynamic;
|
||||
}
|
||||
} catch (...) {
|
||||
LOG_ERR("failed to open dynamic lookup cache: %s", path_dynamic.c_str());
|
||||
SPC_ERR("failed to open dynamic lookup cache: %s", path_dynamic.c_str());
|
||||
GGML_ABORT("Couldn't read dynamic lookup cache");
|
||||
}
|
||||
}
|
||||
@@ -2034,7 +2039,7 @@ common_speculative * common_speculative_init(common_params_speculative & params,
|
||||
}
|
||||
|
||||
if (impls.empty()) {
|
||||
LOG_WRN("%s: no implementations specified for speculative decoding\n", __func__);
|
||||
SPC_TRC("%s", "no implementations specified for speculative decoding\n");
|
||||
return nullptr;
|
||||
}
|
||||
|
||||
@@ -2161,13 +2166,13 @@ void common_speculative_draft(common_speculative * spec) {
|
||||
|
||||
if (dp.n_max > 0) {
|
||||
if (!result.empty() && (int) result.size() > dp.n_max) {
|
||||
LOG_DBG("%s: truncating draft to %d tokens\n", __func__, dp.n_max);
|
||||
SPC_DBG("truncating draft to %d tokens\n", dp.n_max);
|
||||
result.resize(dp.n_max);
|
||||
}
|
||||
}
|
||||
|
||||
if (!result.empty()) {
|
||||
LOG_DBG("%s: called impl %s, hist size = %zu, call_count = %zu, gen = %zu\n", __func__,
|
||||
SPC_DBG("called impl %s, hist size = %zu, call_count = %zu, gen = %zu\n",
|
||||
common_speculative_type_to_str(impl.get()->type).c_str(), dp.prompt->size(),
|
||||
impl.get()->n_call_draft, result.size());
|
||||
|
||||
@@ -2291,7 +2296,7 @@ void common_speculative_print_stats(const common_speculative * spec) {
|
||||
str_stats = ", #mean acc len = " + oss.str() + ", #acc rate/pos = (" + tmp.str() + ")";
|
||||
}
|
||||
|
||||
LOG_INF("statistics %16s: #calls(b,g,a) = %4zu %6zu %6zu, #gen drafts = %6zu, #acc drafts = %5zu, #gen tokens = %6zu, #acc tokens = %5zu%s%s\n",
|
||||
SPC_TRC("statistics %16s: #calls(b,g,a) = %4zu %6zu %6zu, #gen drafts = %6zu, #acc drafts = %5zu, #gen tokens = %6zu, #acc tokens = %5zu%s%s\n",
|
||||
common_speculative_type_to_str(impl->type).c_str(),
|
||||
impl->n_call_begin, impl->n_call_draft, impl->n_call_accept,
|
||||
impl->n_gen_drafts,
|
||||
|
||||
@@ -386,6 +386,46 @@ static void ggml_cpy_f32_iq4_nl_cuda(
|
||||
(cx, cdst, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13);
|
||||
}
|
||||
|
||||
// check if a same-type copy reduces to a 2D strided copy (height rows of width
|
||||
// contiguous bytes), so it can use cudaMemcpy2DAsync instead of the scalar kernel
|
||||
static bool ggml_cuda_cpy_as_memcpy_2d(const ggml_tensor * src0, const ggml_tensor * src1,
|
||||
size_t & width, size_t & height, size_t & spitch, size_t & dpitch) {
|
||||
// require matching shape: a reshaped copy maps elements by flat order, which the
|
||||
// prefix walk below does not handle
|
||||
if (src0->type != src1->type || !ggml_are_same_shape(src0, src1)) {
|
||||
return false;
|
||||
}
|
||||
|
||||
// grow the contiguous prefix block shared by both tensors
|
||||
size_t block_nb = ggml_element_size(src0);
|
||||
int d = 0;
|
||||
for (; d < GGML_MAX_DIMS; ++d) {
|
||||
if (src0->nb[d] != block_nb || src1->nb[d] != block_nb) {
|
||||
break;
|
||||
}
|
||||
block_nb *= src0->ne[d];
|
||||
}
|
||||
|
||||
// d == 0: nothing contiguous; d == GGML_MAX_DIMS: fully contiguous (handled by memcpy)
|
||||
if (d == 0 || d == GGML_MAX_DIMS) {
|
||||
return false;
|
||||
}
|
||||
|
||||
// dim d carries the rows; everything above it must be a single element
|
||||
for (int i = d + 1; i < GGML_MAX_DIMS; ++i) {
|
||||
if (src0->ne[i] != 1) {
|
||||
return false;
|
||||
}
|
||||
}
|
||||
|
||||
width = block_nb;
|
||||
height = src0->ne[d];
|
||||
spitch = src0->nb[d];
|
||||
dpitch = src1->nb[d];
|
||||
|
||||
return spitch >= width && dpitch >= width;
|
||||
}
|
||||
|
||||
void ggml_cuda_cpy(ggml_backend_cuda_context & ctx, const ggml_tensor * src0, ggml_tensor * src1) {
|
||||
const int64_t ne = ggml_nelements(src0);
|
||||
GGML_ASSERT(ne == ggml_nelements(src1));
|
||||
@@ -421,6 +461,8 @@ void ggml_cuda_cpy(ggml_backend_cuda_context & ctx, const ggml_tensor * src0, gg
|
||||
const bool can_be_transposed = nb01 == (int64_t)ggml_element_size(src0) &&
|
||||
src0->ne[3] == 1 && nb02 == ne00 * ne01 * (int64_t)ggml_element_size(src0);
|
||||
|
||||
size_t mc_width = 0, mc_height = 0, mc_spitch = 0, mc_dpitch = 0;
|
||||
|
||||
if (src0->type == src1->type && contiguous_srcs) {
|
||||
GGML_ASSERT(ggml_nbytes(src0) == ggml_nbytes(src1));
|
||||
#if defined(GGML_USE_MUSA) && defined(GGML_MUSA_MUDNN_COPY)
|
||||
@@ -431,6 +473,9 @@ void ggml_cuda_cpy(ggml_backend_cuda_context & ctx, const ggml_tensor * src0, gg
|
||||
{
|
||||
CUDA_CHECK(cudaMemcpyAsync(src1_ddc, src0_ddc, ggml_nbytes(src0), cudaMemcpyDeviceToDevice, main_stream));
|
||||
}
|
||||
} else if (ggml_cuda_cpy_as_memcpy_2d(src0, src1, mc_width, mc_height, mc_spitch, mc_dpitch)) {
|
||||
CUDA_CHECK(cudaMemcpy2DAsync(src1_ddc, mc_dpitch, src0_ddc, mc_spitch,
|
||||
mc_width, mc_height, cudaMemcpyDeviceToDevice, main_stream));
|
||||
} else if (src0->type == GGML_TYPE_F32 && src1->type == GGML_TYPE_F32) {
|
||||
if (can_be_transposed) {
|
||||
ggml_cpy_scalar_cuda<float, float, true>
|
||||
|
||||
@@ -192,7 +192,10 @@ set(GGML_OPENCL_KERNELS
|
||||
mul_mm_f16_f32_kq_kqv
|
||||
conv2d
|
||||
conv2d_f16_f32
|
||||
flash_attn_pre_f16
|
||||
flash_attn_f32_f16
|
||||
flash_attn_f32_q8_0
|
||||
flash_attn_f32_q4_0
|
||||
flash_attn_f16
|
||||
flash_attn_f32
|
||||
)
|
||||
|
||||
@@ -0,0 +1,91 @@
|
||||
#pragma once
|
||||
|
||||
// Flash-attention per-(dk,dv) tile tuning for the Adreno OpenCL backend.
|
||||
// Isolated from ggml-opencl.cpp so the tuning numbers are easy to find and
|
||||
// edit; the FA dispatch and kernel-compile logic stay in the main file.
|
||||
// This header is a file section — it is #included exactly once, at the point
|
||||
// in ggml-opencl.cpp where the ggml logging macros are already in scope.
|
||||
|
||||
// Per-(dk, dv) FA config; shared by dispatch and supports_op.
|
||||
struct ggml_opencl_fa_dim {
|
||||
int dk; int dv; int bm; int bn; int n_split; int nkv_split_threshold;
|
||||
};
|
||||
|
||||
// Split variant fires when n_kv >= threshold (threshold=0 -> always split).
|
||||
// Default tuning covers Adreno 7xx/8xx mobile and X1-series laptop GPUs.
|
||||
static const ggml_opencl_fa_dim g_fa_dims_adreno_default[] = {
|
||||
{ 40, 40, 64, 32, 1, 0}, { 64, 64, 64, 32, 2, 64},
|
||||
{ 80, 80, 64, 32, 2, 64}, { 96, 96, 64, 32, 2, 64},
|
||||
{112, 112, 64, 32, 2, 64}, {128, 128, 64, 32, 2, 64},
|
||||
{192, 128, 16, 16, 1, 0},
|
||||
{192, 192, 16, 16, 1, 0},
|
||||
{256, 256, 16, 16, 16, 0},
|
||||
};
|
||||
|
||||
struct ggml_opencl_fa_dim_table {
|
||||
const ggml_opencl_fa_dim * data;
|
||||
size_t count;
|
||||
|
||||
const ggml_opencl_fa_dim * begin() const { return data; }
|
||||
const ggml_opencl_fa_dim * end() const { return data + count; }
|
||||
};
|
||||
|
||||
// Mutable copy of the active table; GGML_OPENCL_FA_TUNE patches entries here
|
||||
// at backend init without touching the const source table.
|
||||
static ggml_opencl_fa_dim g_fa_dims_runtime[
|
||||
sizeof(g_fa_dims_adreno_default) / sizeof(g_fa_dims_adreno_default[0])];
|
||||
|
||||
static ggml_opencl_fa_dim_table g_opencl_fa_dims = {
|
||||
g_fa_dims_adreno_default,
|
||||
sizeof(g_fa_dims_adreno_default) / sizeof(g_fa_dims_adreno_default[0]),
|
||||
};
|
||||
|
||||
// GGML_OPENCL_FA_TUNE=dk:dv:bm:bn:nsplit:thr[,…] — patches matching entries
|
||||
// in the active table at backend init, before the first FA kernel compiles.
|
||||
// Unmatched (dk,dv) pairs are warned and ignored.
|
||||
static void ggml_opencl_fa_apply_env_overrides() {
|
||||
const char * e = std::getenv("GGML_OPENCL_FA_TUNE");
|
||||
if (!e || !e[0]) {
|
||||
return;
|
||||
}
|
||||
|
||||
std::string s = e;
|
||||
size_t pos = 0;
|
||||
while (pos < s.size()) {
|
||||
size_t comma = s.find(',', pos);
|
||||
std::string entry = s.substr(pos, comma == std::string::npos ? std::string::npos : comma - pos);
|
||||
int dk, dv, bm, bn, nsplit, thr;
|
||||
if (std::sscanf(entry.c_str(), "%d:%d:%d:%d:%d:%d", &dk, &dv, &bm, &bn, &nsplit, &thr) == 6) {
|
||||
bool patched = false;
|
||||
for (size_t i = 0; i < g_opencl_fa_dims.count; ++i) {
|
||||
ggml_opencl_fa_dim & d = g_fa_dims_runtime[i];
|
||||
if (d.dk == dk && d.dv == dv) {
|
||||
d.bm = bm; d.bn = bn; d.n_split = nsplit; d.nkv_split_threshold = thr;
|
||||
GGML_LOG_INFO("ggml_opencl: FA tune override DK=%d DV=%d -> bm=%d bn=%d n_split=%d thr=%d\n",
|
||||
dk, dv, bm, bn, nsplit, thr);
|
||||
patched = true;
|
||||
break;
|
||||
}
|
||||
}
|
||||
if (!patched) {
|
||||
GGML_LOG_WARN("ggml_opencl: FA tune override DK=%d DV=%d ignored (no matching dim)\n", dk, dv);
|
||||
}
|
||||
} else {
|
||||
GGML_LOG_WARN("ggml_opencl: FA tune override entry malformed: '%s'\n", entry.c_str());
|
||||
}
|
||||
if (comma == std::string::npos) break;
|
||||
pos = comma + 1;
|
||||
}
|
||||
}
|
||||
|
||||
// Copy the default table into the mutable runtime buffer and apply any
|
||||
// GGML_OPENCL_FA_TUNE overrides. A per-generation table can be added here
|
||||
// once it has been tuned on hardware.
|
||||
static void ggml_cl_init_fa_dims_table() {
|
||||
const size_t count = sizeof(g_fa_dims_adreno_default) / sizeof(g_fa_dims_adreno_default[0]);
|
||||
for (size_t i = 0; i < count; ++i) {
|
||||
g_fa_dims_runtime[i] = g_fa_dims_adreno_default[i];
|
||||
}
|
||||
g_opencl_fa_dims = { g_fa_dims_runtime, count };
|
||||
ggml_opencl_fa_apply_env_overrides();
|
||||
}
|
||||
+1787
-252
File diff suppressed because it is too large
Load Diff
@@ -1582,6 +1582,158 @@ kernel void kernel_restore_block_q8_0(
|
||||
}
|
||||
}
|
||||
|
||||
// View-aware AoS q8_0 -> f32 dequant (f32/f32 FA path).
|
||||
kernel void kernel_dequant_q8_0_f32_view_aos(
|
||||
global char * src,
|
||||
ulong src_offset,
|
||||
ulong src_nb1,
|
||||
ulong src_nb2,
|
||||
ulong src_nb3,
|
||||
int nblk0,
|
||||
int ne1,
|
||||
int ne2,
|
||||
int ne3,
|
||||
global float * dst
|
||||
) {
|
||||
int blk_i0 = get_global_id(0);
|
||||
int i1 = get_global_id(1);
|
||||
int batch = get_global_id(2);
|
||||
|
||||
if (blk_i0 >= nblk0) return;
|
||||
if (i1 >= ne1) return;
|
||||
|
||||
int i2 = batch % ne2;
|
||||
int i3 = batch / ne2;
|
||||
if (i3 >= ne3) return;
|
||||
|
||||
global char * block = src + src_offset + (ulong)i3*src_nb3 + (ulong)i2*src_nb2 + (ulong)i1*src_nb1 + (ulong)blk_i0 * (2 + QK8_0);
|
||||
float d = vload_half(0, (global half *)block);
|
||||
global char * qs = block + 2;
|
||||
|
||||
ulong dst_row_base = ((ulong)i3 * ne2 * ne1 + (ulong)i2 * ne1 + (ulong)i1) * nblk0;
|
||||
global float * out = dst + (dst_row_base + blk_i0) * QK8_0;
|
||||
|
||||
for (int i = 0; i < QK8_0; ++i) {
|
||||
out[i] = d * (float)qs[i];
|
||||
}
|
||||
}
|
||||
|
||||
// View-aware AoS q8_0 -> f16 dequant. Rows tight, batch strides may be gapped.
|
||||
kernel void kernel_dequant_q8_0_f16_view_aos(
|
||||
global char * src,
|
||||
ulong src_offset,
|
||||
ulong src_nb1,
|
||||
ulong src_nb2,
|
||||
ulong src_nb3,
|
||||
int nblk0,
|
||||
int ne1,
|
||||
int ne2,
|
||||
int ne3,
|
||||
global half * dst
|
||||
) {
|
||||
int blk_i0 = get_global_id(0);
|
||||
int i1 = get_global_id(1);
|
||||
int batch = get_global_id(2);
|
||||
|
||||
if (blk_i0 >= nblk0) return;
|
||||
if (i1 >= ne1) return;
|
||||
|
||||
int i2 = batch % ne2;
|
||||
int i3 = batch / ne2;
|
||||
if (i3 >= ne3) return;
|
||||
|
||||
global char * block = src + src_offset + (ulong)i3*src_nb3 + (ulong)i2*src_nb2 + (ulong)i1*src_nb1 + (ulong)blk_i0 * (2 + QK8_0);
|
||||
float d = vload_half(0, (global half *)block);
|
||||
global char * qs = block + 2;
|
||||
|
||||
ulong dst_row_base = ((ulong)i3 * ne2 * ne1 + (ulong)i2 * ne1 + (ulong)i1) * nblk0;
|
||||
global half * out = dst + (dst_row_base + blk_i0) * QK8_0;
|
||||
|
||||
for (int i = 0; i < QK8_0; ++i) {
|
||||
out[i] = (half)(d * (float)qs[i]);
|
||||
}
|
||||
}
|
||||
|
||||
// View-aware AoS q4_0 -> f32 dequant (mirrors the q8_0 view variant).
|
||||
kernel void kernel_dequant_q4_0_f32_view_aos(
|
||||
global char * src,
|
||||
ulong src_offset,
|
||||
ulong src_nb1,
|
||||
ulong src_nb2,
|
||||
ulong src_nb3,
|
||||
int nblk0,
|
||||
int ne1,
|
||||
int ne2,
|
||||
int ne3,
|
||||
global float * dst
|
||||
) {
|
||||
int blk_i0 = get_global_id(0);
|
||||
int i1 = get_global_id(1);
|
||||
int batch = get_global_id(2);
|
||||
|
||||
if (blk_i0 >= nblk0) return;
|
||||
if (i1 >= ne1) return;
|
||||
|
||||
int i2 = batch % ne2;
|
||||
int i3 = batch / ne2;
|
||||
if (i3 >= ne3) return;
|
||||
|
||||
global char * block = src + src_offset + (ulong)i3*src_nb3 + (ulong)i2*src_nb2 + (ulong)i1*src_nb1 + (ulong)blk_i0 * (2 + QK4_0/2);
|
||||
float d = vload_half(0, (global half *)block);
|
||||
global uchar * qs = (global uchar *)(block + 2);
|
||||
|
||||
ulong dst_row_base = ((ulong)i3 * ne2 * ne1 + (ulong)i2 * ne1 + (ulong)i1) * nblk0;
|
||||
global float * out = dst + (dst_row_base + blk_i0) * QK4_0;
|
||||
|
||||
for (int i = 0; i < QK4_0/2; ++i) {
|
||||
uchar byte = qs[i];
|
||||
int q0 = (int)(byte & 0x0F) - 8;
|
||||
int q1 = (int)(byte >> 4) - 8;
|
||||
out[i] = d * (float)q0;
|
||||
out[i + QK4_0/2] = d * (float)q1;
|
||||
}
|
||||
}
|
||||
|
||||
// View-aware AoS q4_0 -> f16 dequant (mirrors the q8_0 view variant).
|
||||
kernel void kernel_dequant_q4_0_f16_view_aos(
|
||||
global char * src,
|
||||
ulong src_offset,
|
||||
ulong src_nb1,
|
||||
ulong src_nb2,
|
||||
ulong src_nb3,
|
||||
int nblk0,
|
||||
int ne1,
|
||||
int ne2,
|
||||
int ne3,
|
||||
global half * dst
|
||||
) {
|
||||
int blk_i0 = get_global_id(0);
|
||||
int i1 = get_global_id(1);
|
||||
int batch = get_global_id(2);
|
||||
|
||||
if (blk_i0 >= nblk0) return;
|
||||
if (i1 >= ne1) return;
|
||||
|
||||
int i2 = batch % ne2;
|
||||
int i3 = batch / ne2;
|
||||
if (i3 >= ne3) return;
|
||||
|
||||
global char * block = src + src_offset + (ulong)i3*src_nb3 + (ulong)i2*src_nb2 + (ulong)i1*src_nb1 + (ulong)blk_i0 * (2 + QK4_0/2);
|
||||
float d = vload_half(0, (global half *)block);
|
||||
global uchar * qs = (global uchar *)(block + 2);
|
||||
|
||||
ulong dst_row_base = ((ulong)i3 * ne2 * ne1 + (ulong)i2 * ne1 + (ulong)i1) * nblk0;
|
||||
global half * out = dst + (dst_row_base + blk_i0) * QK4_0;
|
||||
|
||||
for (int i = 0; i < QK4_0/2; ++i) {
|
||||
uchar byte = qs[i];
|
||||
int q0 = (int)(byte & 0x0F) - 8;
|
||||
int q1 = (int)(byte >> 4) - 8;
|
||||
out[i] = (half)(d * (float)q0);
|
||||
out[i + QK4_0/2] = (half)(d * (float)q1);
|
||||
}
|
||||
}
|
||||
|
||||
kernel void kernel_restore_block_q8_0_trans(
|
||||
global uchar * src_q,
|
||||
global half * src_d,
|
||||
|
||||
@@ -4,14 +4,26 @@
|
||||
#define ACC_TYPE4 float4
|
||||
#define DATA_TYPE half
|
||||
#define DATA_TYPE4 half4
|
||||
#define CONVERT_ACC4(x) convert_float4(x)
|
||||
#define CONVERT_DATA4(x) convert_half4(x)
|
||||
#define CONVERT_ACC4(x) ((float4)((float)(x).s0, (float)(x).s1, (float)(x).s2, (float)(x).s3))
|
||||
#define CONVERT_DATA4(x) ((half4)((half)(x).s0, (half)(x).s1, (half)(x).s2, (half)(x).s3))
|
||||
|
||||
#define DK_VEC (DK/4)
|
||||
#define DV_VEC (DV/4)
|
||||
#define WG_SIZE (BLOCK_M)
|
||||
#define Q1_WG_SIZE 64
|
||||
|
||||
// The kernels are built with -cl-finite-math-only. On some older Adreno GPUs,
|
||||
// infinite operand can cause undefined behavior and miscompilation for exp.
|
||||
// Therefore, a large negative value is used instead.
|
||||
#define FA_M_INIT (-3.0e38f)
|
||||
|
||||
// Drop full unroll at DK>=192 — Adreno compiler host-memory budget.
|
||||
#if DK >= 192
|
||||
#define FA_UNROLL
|
||||
#else
|
||||
#define FA_UNROLL _Pragma("unroll")
|
||||
#endif
|
||||
|
||||
inline float get_alibi_slope(
|
||||
const float max_bias, const uint h, const uint n_head_log2, const float m0, const float m1
|
||||
) {
|
||||
@@ -81,18 +93,18 @@ __kernel void flash_attn_f16(
|
||||
if (my_query_row < n_q) {
|
||||
const ulong q_row_offset = batch_idx * q_nb3 + head_idx * q_nb2 + my_query_row * q_nb1;
|
||||
const global DATA_TYPE4* q_ptr = (const global DATA_TYPE4*)(q_base + q_row_offset);
|
||||
#pragma unroll
|
||||
FA_UNROLL
|
||||
for (int i = 0; i < DK_VEC; ++i) {
|
||||
q_priv[i] = CONVERT_ACC4(q_ptr[i]);
|
||||
}
|
||||
}
|
||||
|
||||
ACC_TYPE4 o_acc[DV_VEC];
|
||||
#pragma unroll
|
||||
FA_UNROLL
|
||||
for (int i = 0; i < DV_VEC; ++i) {
|
||||
o_acc[i] = (ACC_TYPE4)(0.0f);
|
||||
}
|
||||
ACC_TYPE m_i = -INFINITY;
|
||||
ACC_TYPE m_i = FA_M_INIT;
|
||||
ACC_TYPE l_i = 0.0f;
|
||||
|
||||
float slope = get_alibi_slope(max_bias, head_idx, n_head_log2, m0, m1);
|
||||
@@ -125,49 +137,72 @@ __kernel void flash_attn_f16(
|
||||
continue;
|
||||
}
|
||||
|
||||
for (int j = 0; j < BLOCK_N; j += 2) {
|
||||
for (int j = 0; j < BLOCK_N; j += 4) {
|
||||
const int k_row0 = k_start + j;
|
||||
const int k_row1 = k_start + j + 1;
|
||||
const int k_row2 = k_start + j + 2;
|
||||
const int k_row3 = k_start + j + 3;
|
||||
|
||||
ACC_TYPE4 dot_acc0 = (ACC_TYPE4)(0.0f);
|
||||
ACC_TYPE4 dot_acc1 = (ACC_TYPE4)(0.0f);
|
||||
#pragma unroll
|
||||
ACC_TYPE4 dot_acc2 = (ACC_TYPE4)(0.0f);
|
||||
ACC_TYPE4 dot_acc3 = (ACC_TYPE4)(0.0f);
|
||||
FA_UNROLL
|
||||
for (int k = 0; k < DK_VEC; k++) {
|
||||
dot_acc0 = mad(q_priv[k], CONVERT_ACC4(l_k[j][k]), dot_acc0);
|
||||
dot_acc1 = mad(q_priv[k], CONVERT_ACC4(l_k[j+1][k]), dot_acc1);
|
||||
const ACC_TYPE4 qk = q_priv[k];
|
||||
dot_acc0 = mad(qk, CONVERT_ACC4(l_k[j][k]), dot_acc0);
|
||||
dot_acc1 = mad(qk, CONVERT_ACC4(l_k[j+1][k]), dot_acc1);
|
||||
dot_acc2 = mad(qk, CONVERT_ACC4(l_k[j+2][k]), dot_acc2);
|
||||
dot_acc3 = mad(qk, CONVERT_ACC4(l_k[j+3][k]), dot_acc3);
|
||||
}
|
||||
ACC_TYPE score0 = (dot_acc0.s0 + dot_acc0.s1 + dot_acc0.s2 + dot_acc0.s3) * scale;
|
||||
ACC_TYPE score1 = (dot_acc1.s0 + dot_acc1.s1 + dot_acc1.s2 + dot_acc1.s3) * scale;
|
||||
ACC_TYPE s0 = (dot_acc0.s0 + dot_acc0.s1 + dot_acc0.s2 + dot_acc0.s3) * scale;
|
||||
ACC_TYPE s1 = (dot_acc1.s0 + dot_acc1.s1 + dot_acc1.s2 + dot_acc1.s3) * scale;
|
||||
ACC_TYPE s2 = (dot_acc2.s0 + dot_acc2.s1 + dot_acc2.s2 + dot_acc2.s3) * scale;
|
||||
ACC_TYPE s3 = (dot_acc3.s0 + dot_acc3.s1 + dot_acc3.s2 + dot_acc3.s3) * scale;
|
||||
|
||||
if (is_causal) {
|
||||
if (k_row0 > (n_kv - n_q + my_query_row)) score0 = -INFINITY;
|
||||
if (k_row1 > (n_kv - n_q + my_query_row)) score1 = -INFINITY;
|
||||
const int causal_limit = n_kv - n_q + my_query_row;
|
||||
if (k_row0 > causal_limit) s0 = FA_M_INIT;
|
||||
if (k_row1 > causal_limit) s1 = FA_M_INIT;
|
||||
if (k_row2 > causal_limit) s2 = FA_M_INIT;
|
||||
if (k_row3 > causal_limit) s3 = FA_M_INIT;
|
||||
}
|
||||
|
||||
if (k_row0 >= n_kv) score0 = -INFINITY;
|
||||
if (k_row1 >= n_kv) score1 = -INFINITY;
|
||||
if (k_row0 >= n_kv) s0 = FA_M_INIT;
|
||||
if (k_row1 >= n_kv) s1 = FA_M_INIT;
|
||||
if (k_row2 >= n_kv) s2 = FA_M_INIT;
|
||||
if (k_row3 >= n_kv) s3 = FA_M_INIT;
|
||||
|
||||
if (mask_base != NULL) {
|
||||
const global DATA_TYPE* mask_ptr = (const global DATA_TYPE*)(mask_base + my_query_row * mask_nb1);
|
||||
if (k_row0 < n_kv) score0 += slope * (ACC_TYPE)mask_ptr[k_row0];
|
||||
if (k_row1 < n_kv) score1 += slope * (ACC_TYPE)mask_ptr[k_row1];
|
||||
if (k_row0 < n_kv) s0 += slope * (ACC_TYPE)mask_ptr[k_row0];
|
||||
if (k_row1 < n_kv) s1 += slope * (ACC_TYPE)mask_ptr[k_row1];
|
||||
if (k_row2 < n_kv) s2 += slope * (ACC_TYPE)mask_ptr[k_row2];
|
||||
if (k_row3 < n_kv) s3 += slope * (ACC_TYPE)mask_ptr[k_row3];
|
||||
}
|
||||
|
||||
if (logit_softcap > 0.0f) {
|
||||
score0 = logit_softcap * tanh(score0 / logit_softcap);
|
||||
score1 = logit_softcap * tanh(score1 / logit_softcap);
|
||||
s0 = logit_softcap * tanh(s0 / logit_softcap);
|
||||
s1 = logit_softcap * tanh(s1 / logit_softcap);
|
||||
s2 = logit_softcap * tanh(s2 / logit_softcap);
|
||||
s3 = logit_softcap * tanh(s3 / logit_softcap);
|
||||
}
|
||||
|
||||
const ACC_TYPE m_new = max(m_i, max(score0, score1));
|
||||
const ACC_TYPE p0 = exp(score0 - m_new);
|
||||
const ACC_TYPE p1 = exp(score1 - m_new);
|
||||
const ACC_TYPE scale_prev = exp(m_i - m_new);
|
||||
const ACC_TYPE m_new = max(m_i, max(max(s0, s1), max(s2, s3)));
|
||||
const ACC_TYPE scale_prev = native_exp(m_i - m_new);
|
||||
const ACC_TYPE p0 = native_exp(s0 - m_new);
|
||||
const ACC_TYPE p1 = native_exp(s1 - m_new);
|
||||
const ACC_TYPE p2 = native_exp(s2 - m_new);
|
||||
const ACC_TYPE p3 = native_exp(s3 - m_new);
|
||||
|
||||
#pragma unroll
|
||||
FA_UNROLL
|
||||
for (int i = 0; i < DV_VEC; ++i) {
|
||||
o_acc[i] = o_acc[i] * scale_prev + p0 * CONVERT_ACC4(l_v[j][i]) + p1 * CONVERT_ACC4(l_v[j+1][i]);
|
||||
o_acc[i] = mad(p3, CONVERT_ACC4(l_v[j+3][i]),
|
||||
mad(p2, CONVERT_ACC4(l_v[j+2][i]),
|
||||
mad(p1, CONVERT_ACC4(l_v[j+1][i]),
|
||||
mad(p0, CONVERT_ACC4(l_v[j][i]),
|
||||
o_acc[i] * scale_prev))));
|
||||
}
|
||||
l_i = l_i * scale_prev + p0 + p1;
|
||||
l_i = l_i * scale_prev + p0 + p1 + p2 + p3;
|
||||
m_i = m_new;
|
||||
}
|
||||
}
|
||||
@@ -179,7 +214,7 @@ __kernel void flash_attn_f16(
|
||||
const ACC_TYPE m_final = max(m_i, m_sink);
|
||||
|
||||
const ACC_TYPE scale_o = exp(m_i - m_final);
|
||||
#pragma unroll
|
||||
FA_UNROLL
|
||||
for (int i = 0; i < DV_VEC; ++i) {
|
||||
o_acc[i] *= scale_o;
|
||||
}
|
||||
@@ -191,12 +226,12 @@ __kernel void flash_attn_f16(
|
||||
global DATA_TYPE4 *o_row = (global DATA_TYPE4 *)(o_base + o_row_offset);
|
||||
if (l_i > 0.0f) {
|
||||
const ACC_TYPE l_inv = 1.0f / l_i;
|
||||
#pragma unroll
|
||||
FA_UNROLL
|
||||
for (int i = 0; i < DV_VEC; ++i) {
|
||||
o_row[i] = CONVERT_DATA4(o_acc[i] * l_inv);
|
||||
}
|
||||
} else {
|
||||
#pragma unroll
|
||||
FA_UNROLL
|
||||
for (int i = 0; i < DV_VEC; ++i) {
|
||||
o_row[i] = (DATA_TYPE4)(0.0f);
|
||||
}
|
||||
@@ -258,7 +293,7 @@ __kernel void flash_attn_f16_q1(
|
||||
ACC_TYPE4 q_priv[DK_VEC];
|
||||
const ulong q_row_offset = batch_idx * q_nb3 + head_idx * q_nb2;
|
||||
const global DATA_TYPE4* q_ptr = (const global DATA_TYPE4*)(q_base + q_row_offset);
|
||||
#pragma unroll
|
||||
FA_UNROLL
|
||||
for (int i = 0; i < DK_VEC; ++i) {
|
||||
q_priv[i] = CONVERT_ACC4(q_ptr[i]);
|
||||
}
|
||||
@@ -270,12 +305,12 @@ __kernel void flash_attn_f16_q1(
|
||||
sinks_ptr = (const global ACC_TYPE*)((const global char*)sinks_void + sinks_offset);
|
||||
}
|
||||
|
||||
ACC_TYPE m_i = (sinks_ptr != NULL) ? sinks_ptr[head_idx] : -INFINITY;
|
||||
ACC_TYPE m_i = (sinks_ptr != NULL) ? sinks_ptr[head_idx] : FA_M_INIT;
|
||||
for (int k_idx = tid; k_idx < n_kv; k_idx += Q1_WG_SIZE) {
|
||||
const ulong k_row_offset = batch_idx * k_nb3 + head_kv_idx * k_nb2 + k_idx * k_nb1;
|
||||
const global DATA_TYPE4* k_ptr = (const global DATA_TYPE4*)(k_base + k_row_offset);
|
||||
ACC_TYPE4 dot_acc = (ACC_TYPE4)(0.0f);
|
||||
#pragma unroll
|
||||
FA_UNROLL
|
||||
for (int k = 0; k < DK_VEC; k++) {
|
||||
dot_acc = mad(q_priv[k], CONVERT_ACC4(k_ptr[k]), dot_acc);
|
||||
}
|
||||
@@ -293,7 +328,7 @@ __kernel void flash_attn_f16_q1(
|
||||
__local ACC_TYPE local_m[Q1_WG_SIZE];
|
||||
local_m[tid] = m_i;
|
||||
barrier(CLK_LOCAL_MEM_FENCE);
|
||||
#pragma unroll
|
||||
FA_UNROLL
|
||||
for (int s = Q1_WG_SIZE / 2; s > 0; s >>= 1) {
|
||||
if (tid < s) local_m[tid] = max(local_m[tid], local_m[tid + s]);
|
||||
barrier(CLK_LOCAL_MEM_FENCE);
|
||||
@@ -301,7 +336,7 @@ __kernel void flash_attn_f16_q1(
|
||||
const ACC_TYPE m_final = local_m[0];
|
||||
|
||||
ACC_TYPE4 o_acc[DV_VEC];
|
||||
#pragma unroll
|
||||
FA_UNROLL
|
||||
for (int i = 0; i < DV_VEC; ++i) o_acc[i] = (ACC_TYPE4)(0.0f);
|
||||
ACC_TYPE l_i = 0.0f;
|
||||
|
||||
@@ -311,7 +346,7 @@ __kernel void flash_attn_f16_q1(
|
||||
const global DATA_TYPE4* k_ptr = (const global DATA_TYPE4*)(k_base + k_row_offset);
|
||||
const global DATA_TYPE4* v_ptr = (const global DATA_TYPE4*)(v_base + v_row_offset);
|
||||
ACC_TYPE4 dot_acc = (ACC_TYPE4)(0.0f);
|
||||
#pragma unroll
|
||||
FA_UNROLL
|
||||
for (int k = 0; k < DK_VEC; k++) {
|
||||
dot_acc = mad(q_priv[k], CONVERT_ACC4(k_ptr[k]), dot_acc);
|
||||
}
|
||||
@@ -325,7 +360,7 @@ __kernel void flash_attn_f16_q1(
|
||||
}
|
||||
const ACC_TYPE p = exp(score - m_final);
|
||||
l_i += p;
|
||||
#pragma unroll
|
||||
FA_UNROLL
|
||||
for (int i = 0; i < DV_VEC; i++) {
|
||||
o_acc[i] = mad(p, CONVERT_ACC4(v_ptr[i]), o_acc[i]);
|
||||
}
|
||||
@@ -335,7 +370,7 @@ __kernel void flash_attn_f16_q1(
|
||||
__local ACC_TYPE4 local_o_comp[Q1_WG_SIZE];
|
||||
local_l[tid] = l_i;
|
||||
barrier(CLK_LOCAL_MEM_FENCE);
|
||||
#pragma unroll
|
||||
FA_UNROLL
|
||||
for (int s = Q1_WG_SIZE / 2; s > 0; s >>= 1) {
|
||||
if (tid < s) local_l[tid] += local_l[tid + s];
|
||||
barrier(CLK_LOCAL_MEM_FENCE);
|
||||
@@ -354,7 +389,7 @@ __kernel void flash_attn_f16_q1(
|
||||
for (int i = 0; i < DV_VEC; i++) {
|
||||
local_o_comp[tid] = o_acc[i];
|
||||
barrier(CLK_LOCAL_MEM_FENCE);
|
||||
#pragma unroll
|
||||
FA_UNROLL
|
||||
for (int s = Q1_WG_SIZE / 2; s > 0; s >>= 1) {
|
||||
if (tid < s) local_o_comp[tid] += local_o_comp[tid + s];
|
||||
barrier(CLK_LOCAL_MEM_FENCE);
|
||||
@@ -364,7 +399,7 @@ __kernel void flash_attn_f16_q1(
|
||||
}
|
||||
}
|
||||
} else if (tid == 0) {
|
||||
#pragma unroll
|
||||
FA_UNROLL
|
||||
for (int i = 0; i < DV_VEC; ++i) o_row[i] = (DATA_TYPE4)(0.0f);
|
||||
}
|
||||
}
|
||||
|
||||
@@ -13,6 +13,18 @@
|
||||
#define WG_SIZE (BLOCK_M)
|
||||
#define Q1_WG_SIZE 64
|
||||
|
||||
// The kernels are built with -cl-finite-math-only. On some older Adreno GPUs,
|
||||
// infinite operand can cause undefined behavior and miscompilation for exp.
|
||||
// Therefore, a large negative value is used instead.
|
||||
#define FA_M_INIT (-3.0e38f)
|
||||
|
||||
// Drop full unroll at DK>=192 — Adreno compiler host-memory budget.
|
||||
#if DK >= 192
|
||||
#define FA_UNROLL
|
||||
#else
|
||||
#define FA_UNROLL _Pragma("unroll")
|
||||
#endif
|
||||
|
||||
inline float get_alibi_slope(
|
||||
const float max_bias, const uint h, const uint n_head_log2, const float m0, const float m1
|
||||
) {
|
||||
@@ -82,18 +94,18 @@ __kernel void flash_attn_f32(
|
||||
if (my_query_row < n_q) {
|
||||
const ulong q_row_offset = batch_idx * q_nb3 + head_idx * q_nb2 + my_query_row * q_nb1;
|
||||
const global DATA_TYPE4* q_ptr = (const global DATA_TYPE4*)(q_base + q_row_offset);
|
||||
#pragma unroll
|
||||
FA_UNROLL
|
||||
for (int i = 0; i < DK_VEC; ++i) {
|
||||
q_priv[i] = CONVERT_ACC4(q_ptr[i]);
|
||||
}
|
||||
}
|
||||
|
||||
ACC_TYPE4 o_acc[DV_VEC];
|
||||
#pragma unroll
|
||||
FA_UNROLL
|
||||
for (int i = 0; i < DV_VEC; ++i) {
|
||||
o_acc[i] = (ACC_TYPE4)(0.0f);
|
||||
}
|
||||
ACC_TYPE m_i = -INFINITY;
|
||||
ACC_TYPE m_i = FA_M_INIT;
|
||||
ACC_TYPE l_i = 0.0f;
|
||||
|
||||
float slope = get_alibi_slope(max_bias, head_idx, n_head_log2, m0, m1);
|
||||
@@ -126,49 +138,72 @@ __kernel void flash_attn_f32(
|
||||
continue;
|
||||
}
|
||||
|
||||
for (int j = 0; j < BLOCK_N; j += 2) {
|
||||
for (int j = 0; j < BLOCK_N; j += 4) {
|
||||
const int k_row0 = k_start + j;
|
||||
const int k_row1 = k_start + j + 1;
|
||||
const int k_row2 = k_start + j + 2;
|
||||
const int k_row3 = k_start + j + 3;
|
||||
|
||||
ACC_TYPE4 dot_acc0 = (ACC_TYPE4)(0.0f);
|
||||
ACC_TYPE4 dot_acc1 = (ACC_TYPE4)(0.0f);
|
||||
#pragma unroll
|
||||
ACC_TYPE4 dot_acc2 = (ACC_TYPE4)(0.0f);
|
||||
ACC_TYPE4 dot_acc3 = (ACC_TYPE4)(0.0f);
|
||||
FA_UNROLL
|
||||
for (int k = 0; k < DK_VEC; k++) {
|
||||
dot_acc0 = mad(q_priv[k], CONVERT_ACC4(l_k[j][k]), dot_acc0);
|
||||
dot_acc1 = mad(q_priv[k], CONVERT_ACC4(l_k[j+1][k]), dot_acc1);
|
||||
const ACC_TYPE4 qk = q_priv[k];
|
||||
dot_acc0 = mad(qk, CONVERT_ACC4(l_k[j][k]), dot_acc0);
|
||||
dot_acc1 = mad(qk, CONVERT_ACC4(l_k[j+1][k]), dot_acc1);
|
||||
dot_acc2 = mad(qk, CONVERT_ACC4(l_k[j+2][k]), dot_acc2);
|
||||
dot_acc3 = mad(qk, CONVERT_ACC4(l_k[j+3][k]), dot_acc3);
|
||||
}
|
||||
ACC_TYPE score0 = (dot_acc0.s0 + dot_acc0.s1 + dot_acc0.s2 + dot_acc0.s3) * scale;
|
||||
ACC_TYPE score1 = (dot_acc1.s0 + dot_acc1.s1 + dot_acc1.s2 + dot_acc1.s3) * scale;
|
||||
ACC_TYPE s0 = (dot_acc0.s0 + dot_acc0.s1 + dot_acc0.s2 + dot_acc0.s3) * scale;
|
||||
ACC_TYPE s1 = (dot_acc1.s0 + dot_acc1.s1 + dot_acc1.s2 + dot_acc1.s3) * scale;
|
||||
ACC_TYPE s2 = (dot_acc2.s0 + dot_acc2.s1 + dot_acc2.s2 + dot_acc2.s3) * scale;
|
||||
ACC_TYPE s3 = (dot_acc3.s0 + dot_acc3.s1 + dot_acc3.s2 + dot_acc3.s3) * scale;
|
||||
|
||||
if (is_causal) {
|
||||
if (k_row0 > (n_kv - n_q + my_query_row)) score0 = -INFINITY;
|
||||
if (k_row1 > (n_kv - n_q + my_query_row)) score1 = -INFINITY;
|
||||
const int causal_limit = n_kv - n_q + my_query_row;
|
||||
if (k_row0 > causal_limit) s0 = FA_M_INIT;
|
||||
if (k_row1 > causal_limit) s1 = FA_M_INIT;
|
||||
if (k_row2 > causal_limit) s2 = FA_M_INIT;
|
||||
if (k_row3 > causal_limit) s3 = FA_M_INIT;
|
||||
}
|
||||
|
||||
if (k_row0 >= n_kv) score0 = -INFINITY;
|
||||
if (k_row1 >= n_kv) score1 = -INFINITY;
|
||||
if (k_row0 >= n_kv) s0 = FA_M_INIT;
|
||||
if (k_row1 >= n_kv) s1 = FA_M_INIT;
|
||||
if (k_row2 >= n_kv) s2 = FA_M_INIT;
|
||||
if (k_row3 >= n_kv) s3 = FA_M_INIT;
|
||||
|
||||
if (mask_base != NULL) {
|
||||
const global MASK_DATA_TYPE* mask_ptr = (const global MASK_DATA_TYPE*)(mask_base + my_query_row * mask_nb1);
|
||||
if (k_row0 < n_kv) score0 += slope * (ACC_TYPE)mask_ptr[k_row0];
|
||||
if (k_row1 < n_kv) score1 += slope * (ACC_TYPE)mask_ptr[k_row1];
|
||||
if (k_row0 < n_kv) s0 += slope * (ACC_TYPE)mask_ptr[k_row0];
|
||||
if (k_row1 < n_kv) s1 += slope * (ACC_TYPE)mask_ptr[k_row1];
|
||||
if (k_row2 < n_kv) s2 += slope * (ACC_TYPE)mask_ptr[k_row2];
|
||||
if (k_row3 < n_kv) s3 += slope * (ACC_TYPE)mask_ptr[k_row3];
|
||||
}
|
||||
|
||||
if (logit_softcap > 0.0f) {
|
||||
score0 = logit_softcap * tanh(score0 / logit_softcap);
|
||||
score1 = logit_softcap * tanh(score1 / logit_softcap);
|
||||
s0 = logit_softcap * tanh(s0 / logit_softcap);
|
||||
s1 = logit_softcap * tanh(s1 / logit_softcap);
|
||||
s2 = logit_softcap * tanh(s2 / logit_softcap);
|
||||
s3 = logit_softcap * tanh(s3 / logit_softcap);
|
||||
}
|
||||
|
||||
const ACC_TYPE m_new = max(m_i, max(score0, score1));
|
||||
const ACC_TYPE p0 = exp(score0 - m_new);
|
||||
const ACC_TYPE p1 = exp(score1 - m_new);
|
||||
const ACC_TYPE scale_prev = exp(m_i - m_new);
|
||||
const ACC_TYPE m_new = max(m_i, max(max(s0, s1), max(s2, s3)));
|
||||
const ACC_TYPE scale_prev = native_exp(m_i - m_new);
|
||||
const ACC_TYPE p0 = native_exp(s0 - m_new);
|
||||
const ACC_TYPE p1 = native_exp(s1 - m_new);
|
||||
const ACC_TYPE p2 = native_exp(s2 - m_new);
|
||||
const ACC_TYPE p3 = native_exp(s3 - m_new);
|
||||
|
||||
#pragma unroll
|
||||
FA_UNROLL
|
||||
for (int i = 0; i < DV_VEC; ++i) {
|
||||
o_acc[i] = o_acc[i] * scale_prev + p0 * CONVERT_ACC4(l_v[j][i]) + p1 * CONVERT_ACC4(l_v[j+1][i]);
|
||||
o_acc[i] = mad(p3, CONVERT_ACC4(l_v[j+3][i]),
|
||||
mad(p2, CONVERT_ACC4(l_v[j+2][i]),
|
||||
mad(p1, CONVERT_ACC4(l_v[j+1][i]),
|
||||
mad(p0, CONVERT_ACC4(l_v[j][i]),
|
||||
o_acc[i] * scale_prev))));
|
||||
}
|
||||
l_i = l_i * scale_prev + p0 + p1;
|
||||
l_i = l_i * scale_prev + p0 + p1 + p2 + p3;
|
||||
m_i = m_new;
|
||||
}
|
||||
}
|
||||
@@ -180,7 +215,7 @@ __kernel void flash_attn_f32(
|
||||
const ACC_TYPE m_final = max(m_i, m_sink);
|
||||
|
||||
const ACC_TYPE scale_o = exp(m_i - m_final);
|
||||
#pragma unroll
|
||||
FA_UNROLL
|
||||
for (int i = 0; i < DV_VEC; ++i) {
|
||||
o_acc[i] *= scale_o;
|
||||
}
|
||||
@@ -192,12 +227,12 @@ __kernel void flash_attn_f32(
|
||||
global DATA_TYPE4 *o_row = (global DATA_TYPE4 *)(o_base + o_row_offset);
|
||||
if (l_i > 0.0f) {
|
||||
const ACC_TYPE l_inv = 1.0f / l_i;
|
||||
#pragma unroll
|
||||
FA_UNROLL
|
||||
for (int i = 0; i < DV_VEC; ++i) {
|
||||
o_row[i] = CONVERT_DATA4(o_acc[i] * l_inv);
|
||||
}
|
||||
} else {
|
||||
#pragma unroll
|
||||
FA_UNROLL
|
||||
for (int i = 0; i < DV_VEC; ++i) {
|
||||
o_row[i] = (DATA_TYPE4)(0.0f);
|
||||
}
|
||||
@@ -259,7 +294,7 @@ __kernel void flash_attn_f32_q1(
|
||||
ACC_TYPE4 q_priv[DK_VEC];
|
||||
const ulong q_row_offset = batch_idx * q_nb3 + head_idx * q_nb2;
|
||||
const global DATA_TYPE4* q_ptr = (const global DATA_TYPE4*)(q_base + q_row_offset);
|
||||
#pragma unroll
|
||||
FA_UNROLL
|
||||
for (int i = 0; i < DK_VEC; ++i) {
|
||||
q_priv[i] = CONVERT_ACC4(q_ptr[i]);
|
||||
}
|
||||
@@ -271,12 +306,12 @@ __kernel void flash_attn_f32_q1(
|
||||
sinks_ptr = (const global ACC_TYPE*)((const global char*)sinks_void + sinks_offset);
|
||||
}
|
||||
|
||||
ACC_TYPE m_i = (sinks_ptr != NULL) ? sinks_ptr[head_idx] : -INFINITY;
|
||||
ACC_TYPE m_i = (sinks_ptr != NULL) ? sinks_ptr[head_idx] : FA_M_INIT;
|
||||
for (int k_idx = tid; k_idx < n_kv; k_idx += Q1_WG_SIZE) {
|
||||
const ulong k_row_offset = batch_idx * k_nb3 + head_kv_idx * k_nb2 + k_idx * k_nb1;
|
||||
const global DATA_TYPE4* k_ptr = (const global DATA_TYPE4*)(k_base + k_row_offset);
|
||||
ACC_TYPE4 dot_acc = (ACC_TYPE4)(0.0f);
|
||||
#pragma unroll
|
||||
FA_UNROLL
|
||||
for (int k = 0; k < DK_VEC; k++) {
|
||||
dot_acc = mad(q_priv[k], CONVERT_ACC4(k_ptr[k]), dot_acc);
|
||||
}
|
||||
@@ -294,7 +329,7 @@ __kernel void flash_attn_f32_q1(
|
||||
__local ACC_TYPE local_m[Q1_WG_SIZE];
|
||||
local_m[tid] = m_i;
|
||||
barrier(CLK_LOCAL_MEM_FENCE);
|
||||
#pragma unroll
|
||||
FA_UNROLL
|
||||
for (int s = Q1_WG_SIZE / 2; s > 0; s >>= 1) {
|
||||
if (tid < s) local_m[tid] = max(local_m[tid], local_m[tid + s]);
|
||||
barrier(CLK_LOCAL_MEM_FENCE);
|
||||
@@ -302,7 +337,7 @@ __kernel void flash_attn_f32_q1(
|
||||
const ACC_TYPE m_final = local_m[0];
|
||||
|
||||
ACC_TYPE4 o_acc[DV_VEC];
|
||||
#pragma unroll
|
||||
FA_UNROLL
|
||||
for (int i = 0; i < DV_VEC; ++i) o_acc[i] = (ACC_TYPE4)(0.0f);
|
||||
ACC_TYPE l_i = 0.0f;
|
||||
|
||||
@@ -312,7 +347,7 @@ __kernel void flash_attn_f32_q1(
|
||||
const global DATA_TYPE4* k_ptr = (const global DATA_TYPE4*)(k_base + k_row_offset);
|
||||
const global DATA_TYPE4* v_ptr = (const global DATA_TYPE4*)(v_base + v_row_offset);
|
||||
ACC_TYPE4 dot_acc = (ACC_TYPE4)(0.0f);
|
||||
#pragma unroll
|
||||
FA_UNROLL
|
||||
for (int k = 0; k < DK_VEC; k++) {
|
||||
dot_acc = mad(q_priv[k], CONVERT_ACC4(k_ptr[k]), dot_acc);
|
||||
}
|
||||
@@ -326,7 +361,7 @@ __kernel void flash_attn_f32_q1(
|
||||
}
|
||||
const ACC_TYPE p = exp(score - m_final);
|
||||
l_i += p;
|
||||
#pragma unroll
|
||||
FA_UNROLL
|
||||
for (int i = 0; i < DV_VEC; i++) {
|
||||
o_acc[i] = mad(p, CONVERT_ACC4(v_ptr[i]), o_acc[i]);
|
||||
}
|
||||
@@ -336,7 +371,7 @@ __kernel void flash_attn_f32_q1(
|
||||
__local ACC_TYPE4 local_o_comp[Q1_WG_SIZE];
|
||||
local_l[tid] = l_i;
|
||||
barrier(CLK_LOCAL_MEM_FENCE);
|
||||
#pragma unroll
|
||||
FA_UNROLL
|
||||
for (int s = Q1_WG_SIZE / 2; s > 0; s >>= 1) {
|
||||
if (tid < s) local_l[tid] += local_l[tid + s];
|
||||
barrier(CLK_LOCAL_MEM_FENCE);
|
||||
@@ -355,7 +390,7 @@ __kernel void flash_attn_f32_q1(
|
||||
for (int i = 0; i < DV_VEC; i++) {
|
||||
local_o_comp[tid] = o_acc[i];
|
||||
barrier(CLK_LOCAL_MEM_FENCE);
|
||||
#pragma unroll
|
||||
FA_UNROLL
|
||||
for (int s = Q1_WG_SIZE / 2; s > 0; s >>= 1) {
|
||||
if (tid < s) local_o_comp[tid] += local_o_comp[tid + s];
|
||||
barrier(CLK_LOCAL_MEM_FENCE);
|
||||
@@ -365,7 +400,7 @@ __kernel void flash_attn_f32_q1(
|
||||
}
|
||||
}
|
||||
} else if (tid == 0) {
|
||||
#pragma unroll
|
||||
FA_UNROLL
|
||||
for (int i = 0; i < DV_VEC; ++i) o_row[i] = (DATA_TYPE4)(0.0f);
|
||||
}
|
||||
}
|
||||
|
||||
@@ -1,5 +1,13 @@
|
||||
#pragma OPENCL EXTENSION cl_khr_fp16 : enable
|
||||
|
||||
#ifdef cl_khr_subgroup_shuffle
|
||||
#pragma OPENCL EXTENSION cl_khr_subgroup_shuffle : enable
|
||||
#define HAS_SUBGROUP_SHUFFLE 1
|
||||
#elif defined(cl_qcom_subgroup_shuffle)
|
||||
#pragma OPENCL EXTENSION cl_qcom_subgroup_shuffle : enable
|
||||
#define HAS_SUBGROUP_SHUFFLE 1
|
||||
#endif
|
||||
|
||||
#define ACC_TYPE float
|
||||
#define ACC_TYPE4 float4
|
||||
#define Q_DATA_TYPE4 float4
|
||||
@@ -12,9 +20,34 @@
|
||||
|
||||
#define DK_VEC (DK/4)
|
||||
#define DV_VEC (DV/4)
|
||||
#define WG_SIZE (BLOCK_M)
|
||||
#define Q1_WG_SIZE 64
|
||||
|
||||
// The kernels are built with -cl-finite-math-only. On some older Adreno GPUs,
|
||||
// infinite operand can cause undefined behavior and miscompilation for exp.
|
||||
// Therefore, a large negative value is used instead.
|
||||
#define FA_M_INIT (-3.0e38f)
|
||||
|
||||
// Drop full unroll at DK>=192 — Adreno compiler host-memory budget.
|
||||
#if DK >= 192
|
||||
#define FA_UNROLL
|
||||
#else
|
||||
#define FA_UNROLL _Pragma("unroll")
|
||||
#endif
|
||||
|
||||
// N_SPLIT>1 splits DK/DV across threads to cut per-thread register use.
|
||||
#ifndef N_SPLIT
|
||||
#define N_SPLIT 1
|
||||
#endif
|
||||
|
||||
#define SPLIT_DK_VEC (DK_VEC / N_SPLIT)
|
||||
#define SPLIT_DV_VEC (DV_VEC / N_SPLIT)
|
||||
|
||||
#if N_SPLIT > 1
|
||||
#define WG_SIZE (BLOCK_M * N_SPLIT)
|
||||
#else
|
||||
#define WG_SIZE (BLOCK_M)
|
||||
#endif
|
||||
|
||||
inline float get_alibi_slope(
|
||||
const float max_bias, const uint h, const uint n_head_log2, const float m0, const float m1
|
||||
) {
|
||||
@@ -54,19 +87,38 @@ __kernel void flash_attn_f32_f16(
|
||||
const int mask_ne2,
|
||||
const int mask_ne3,
|
||||
const global void* sinks_void,
|
||||
const ulong sinks_offset
|
||||
const ulong sinks_offset,
|
||||
const global void * k_pad_void,
|
||||
const global void * v_pad_void,
|
||||
const global void * mask_pad_void,
|
||||
const global char * blk,
|
||||
const int n_kv_blocks,
|
||||
const ulong mask_pad_nb1,
|
||||
const ulong mask_pad_nb2,
|
||||
const ulong mask_pad_nb3
|
||||
) {
|
||||
const int tid = get_local_id(0);
|
||||
const int block_q_idx = get_group_id(0);
|
||||
const int head_batch_idx = get_global_id(1);
|
||||
|
||||
const int my_query_row = block_q_idx * BLOCK_M + tid;
|
||||
#if N_SPLIT > 1
|
||||
const int q_lane = tid / N_SPLIT;
|
||||
const int split_idx = tid % N_SPLIT;
|
||||
#else
|
||||
const int q_lane = tid;
|
||||
const int split_idx = 0;
|
||||
#endif
|
||||
|
||||
const int my_query_row = block_q_idx * BLOCK_M + q_lane;
|
||||
const int query_valid = my_query_row < n_q;
|
||||
|
||||
const int batch_idx = head_batch_idx / n_head;
|
||||
const int head_idx = head_batch_idx % n_head;
|
||||
|
||||
const int gqa_ratio = n_head / n_head_kv;
|
||||
const int head_kv_idx = head_idx / gqa_ratio;
|
||||
const int mask_head_idx = mask_void != NULL ? head_idx % mask_ne2 : 0;
|
||||
const int mask_batch_idx = mask_void != NULL ? batch_idx % mask_ne3 : 0;
|
||||
|
||||
const global char* q_base = (const global char*)q_void + q_offset;
|
||||
const global char* k_base = (const global char*)k_void + k_offset;
|
||||
@@ -75,27 +127,41 @@ __kernel void flash_attn_f32_f16(
|
||||
|
||||
const global char* mask_base = NULL;
|
||||
if (mask_void != NULL) {
|
||||
const int mask_head_idx = head_idx % mask_ne2;
|
||||
const int mask_batch_idx = batch_idx % mask_ne3;
|
||||
mask_base = (const global char*)mask_void + mask_offset + mask_batch_idx * mask_nb3 + mask_head_idx * mask_nb2;
|
||||
}
|
||||
const global char* mask_pad_base = NULL;
|
||||
if (mask_pad_void != NULL) {
|
||||
mask_pad_base = (const global char*)mask_pad_void + mask_batch_idx * mask_pad_nb3 + mask_head_idx * mask_pad_nb2;
|
||||
}
|
||||
const global char* blk_base = NULL;
|
||||
if (blk != NULL) {
|
||||
const int n_q_blocks = (n_q + BLOCK_M - 1) / BLOCK_M;
|
||||
blk_base = blk + (((mask_batch_idx * mask_ne2) + mask_head_idx) * n_q_blocks + block_q_idx) * n_kv_blocks;
|
||||
}
|
||||
|
||||
ACC_TYPE4 q_priv[DK_VEC];
|
||||
if (my_query_row < n_q) {
|
||||
ACC_TYPE4 q_priv[SPLIT_DK_VEC];
|
||||
const int dk_off = split_idx * SPLIT_DK_VEC;
|
||||
if (query_valid) {
|
||||
const ulong q_row_offset = batch_idx * q_nb3 + head_idx * q_nb2 + my_query_row * q_nb1;
|
||||
const global Q_DATA_TYPE4* q_ptr = (const global Q_DATA_TYPE4*)(q_base + q_row_offset);
|
||||
#pragma unroll
|
||||
for (int i = 0; i < DK_VEC; ++i) {
|
||||
q_priv[i] = CONVERT_Q_ACC4(q_ptr[i]);
|
||||
FA_UNROLL
|
||||
for (int i = 0; i < SPLIT_DK_VEC; ++i) {
|
||||
q_priv[i] = CONVERT_Q_ACC4(q_ptr[dk_off + i]);
|
||||
}
|
||||
} else {
|
||||
FA_UNROLL
|
||||
for (int i = 0; i < SPLIT_DK_VEC; ++i) {
|
||||
q_priv[i] = (ACC_TYPE4)(0.0f);
|
||||
}
|
||||
}
|
||||
|
||||
ACC_TYPE4 o_acc[DV_VEC];
|
||||
#pragma unroll
|
||||
for (int i = 0; i < DV_VEC; ++i) {
|
||||
ACC_TYPE4 o_acc[SPLIT_DV_VEC];
|
||||
FA_UNROLL
|
||||
for (int i = 0; i < SPLIT_DV_VEC; ++i) {
|
||||
o_acc[i] = (ACC_TYPE4)(0.0f);
|
||||
}
|
||||
ACC_TYPE m_i = -INFINITY;
|
||||
|
||||
ACC_TYPE m_i = FA_M_INIT;
|
||||
ACC_TYPE l_i = 0.0f;
|
||||
|
||||
float slope = get_alibi_slope(max_bias, head_idx, n_head_log2, m0, m1);
|
||||
@@ -103,86 +169,369 @@ __kernel void flash_attn_f32_f16(
|
||||
__local KV_DATA_TYPE4 l_k[BLOCK_N][DK_VEC];
|
||||
__local KV_DATA_TYPE4 l_v[BLOCK_N][DV_VEC];
|
||||
|
||||
#if N_SPLIT > 1 && !defined(HAS_SUBGROUP_SHUFFLE)
|
||||
__local ACC_TYPE local_partial[BLOCK_N][WG_SIZE];
|
||||
__local ACC_TYPE local_p[BLOCK_M][BLOCK_N];
|
||||
__local ACC_TYPE local_softmax_scale[BLOCK_M];
|
||||
__local ACC_TYPE local_l_inv[BLOCK_M];
|
||||
#endif
|
||||
|
||||
for (int k_start = 0; k_start < n_kv; k_start += BLOCK_N) {
|
||||
char blk_cur = 1;
|
||||
if (blk_base != NULL) {
|
||||
blk_cur = blk_base[k_start / BLOCK_N];
|
||||
if (blk_cur == 0) continue;
|
||||
}
|
||||
|
||||
const int use_kv_pad = k_pad_void != NULL && k_start + BLOCK_N > n_kv;
|
||||
const int k_tile_start = use_kv_pad ? 0 : k_start;
|
||||
const ulong k_tile_nb2 = use_kv_pad ? (ulong) BLOCK_N * k_nb1 : k_nb2;
|
||||
const ulong k_tile_nb3 = use_kv_pad ? (ulong) n_head_kv * k_tile_nb2 : k_nb3;
|
||||
const ulong v_tile_nb2 = use_kv_pad ? (ulong) BLOCK_N * v_nb1 : v_nb2;
|
||||
const ulong v_tile_nb3 = use_kv_pad ? (ulong) n_head_kv * v_tile_nb2 : v_nb3;
|
||||
const global char* k_tile_base = use_kv_pad ? (const global char*) k_pad_void : k_base;
|
||||
const global char* v_tile_base = use_kv_pad ? (const global char*) v_pad_void : v_base;
|
||||
|
||||
for (int i = tid; i < BLOCK_N * DK_VEC; i += WG_SIZE) {
|
||||
const int row = i / DK_VEC;
|
||||
const int col = i % DK_VEC;
|
||||
const int k_row_idx = k_start + row;
|
||||
if (k_row_idx < n_kv) {
|
||||
const ulong k_row_offset = batch_idx * k_nb3 + head_kv_idx * k_nb2 + k_row_idx * k_nb1;
|
||||
l_k[row][col] = ((__global KV_DATA_TYPE4*)(k_base + k_row_offset))[col];
|
||||
const int k_row_idx = k_tile_start + row;
|
||||
if (use_kv_pad || k_row_idx < n_kv) {
|
||||
const ulong k_row_offset = batch_idx * k_tile_nb3 + head_kv_idx * k_tile_nb2 + k_row_idx * k_nb1;
|
||||
l_k[row][col] = ((__global KV_DATA_TYPE4*)(k_tile_base + k_row_offset))[col];
|
||||
} else {
|
||||
l_k[row][col] = (KV_DATA_TYPE4)(0.0h);
|
||||
}
|
||||
}
|
||||
for (int i = tid; i < BLOCK_N * DV_VEC; i += WG_SIZE) {
|
||||
const int row = i / DV_VEC;
|
||||
const int col = i % DV_VEC;
|
||||
const int v_row_idx = k_start + row;
|
||||
if (v_row_idx < n_kv) {
|
||||
const ulong v_row_offset = batch_idx * v_nb3 + head_kv_idx * v_nb2 + v_row_idx * v_nb1;
|
||||
l_v[row][col] = ((__global KV_DATA_TYPE4*)(v_base + v_row_offset))[col];
|
||||
const int v_row_idx = k_tile_start + row;
|
||||
if (use_kv_pad || v_row_idx < n_kv) {
|
||||
const ulong v_row_offset = batch_idx * v_tile_nb3 + head_kv_idx * v_tile_nb2 + v_row_idx * v_nb1;
|
||||
l_v[row][col] = ((__global KV_DATA_TYPE4*)(v_tile_base + v_row_offset))[col];
|
||||
} else {
|
||||
l_v[row][col] = (KV_DATA_TYPE4)(0.0h);
|
||||
}
|
||||
}
|
||||
barrier(CLK_LOCAL_MEM_FENCE);
|
||||
|
||||
if (my_query_row >= n_q) {
|
||||
continue;
|
||||
#if N_SPLIT > 1 && defined(HAS_SUBGROUP_SHUFFLE)
|
||||
{
|
||||
const int dv_off = split_idx * SPLIT_DV_VEC;
|
||||
for (int j = 0; j < BLOCK_N; j += 2) {
|
||||
const int k_row0 = k_start + j;
|
||||
const int k_row1 = k_start + j + 1;
|
||||
|
||||
ACC_TYPE partial0 = 0.0f;
|
||||
ACC_TYPE partial1 = 0.0f;
|
||||
FA_UNROLL
|
||||
for (int k = 0; k < SPLIT_DK_VEC; k++) {
|
||||
const ACC_TYPE4 qk = q_priv[k];
|
||||
ACC_TYPE4 dot0 = qk * CONVERT_KV_ACC4(l_k[j ][dk_off + k]);
|
||||
ACC_TYPE4 dot1 = qk * CONVERT_KV_ACC4(l_k[j+1][dk_off + k]);
|
||||
partial0 += dot0.s0 + dot0.s1 + dot0.s2 + dot0.s3;
|
||||
partial1 += dot1.s0 + dot1.s1 + dot1.s2 + dot1.s3;
|
||||
}
|
||||
|
||||
FA_UNROLL
|
||||
for (int step = 1; step < N_SPLIT; step <<= 1) {
|
||||
partial0 += sub_group_shuffle_xor(partial0, step);
|
||||
partial1 += sub_group_shuffle_xor(partial1, step);
|
||||
}
|
||||
|
||||
ACC_TYPE score0 = partial0 * scale;
|
||||
ACC_TYPE score1 = partial1 * scale;
|
||||
|
||||
if (!query_valid) { score0 = FA_M_INIT; score1 = FA_M_INIT; }
|
||||
if (is_causal) {
|
||||
if (k_row0 > (n_kv - n_q + my_query_row)) score0 = FA_M_INIT;
|
||||
if (k_row1 > (n_kv - n_q + my_query_row)) score1 = FA_M_INIT;
|
||||
}
|
||||
if (k_row0 >= n_kv) score0 = FA_M_INIT;
|
||||
if (k_row1 >= n_kv) score1 = FA_M_INIT;
|
||||
|
||||
if (query_valid && mask_base != NULL && blk_cur != 2) {
|
||||
if (use_kv_pad && mask_pad_base != NULL) {
|
||||
const global MASK_DATA_TYPE* mask_ptr =
|
||||
(const global MASK_DATA_TYPE*)(mask_pad_base + my_query_row * mask_pad_nb1);
|
||||
score0 += slope * (ACC_TYPE)mask_ptr[j];
|
||||
score1 += slope * (ACC_TYPE)mask_ptr[j + 1];
|
||||
} else {
|
||||
const global MASK_DATA_TYPE* mask_ptr =
|
||||
(const global MASK_DATA_TYPE*)(mask_base + my_query_row * mask_nb1);
|
||||
if (k_row0 < n_kv) score0 += slope * (ACC_TYPE)mask_ptr[k_row0];
|
||||
if (k_row1 < n_kv) score1 += slope * (ACC_TYPE)mask_ptr[k_row1];
|
||||
}
|
||||
}
|
||||
|
||||
if (logit_softcap > 0.0f) {
|
||||
score0 = logit_softcap * tanh(score0 / logit_softcap);
|
||||
score1 = logit_softcap * tanh(score1 / logit_softcap);
|
||||
}
|
||||
|
||||
const ACC_TYPE m_new = max(m_i, max(score0, score1));
|
||||
// Whole tile masked (m_new == FA_M_INIT): force the exp() args
|
||||
// far negative so the tile contributes 0, not exp(0)=1.
|
||||
const ACC_TYPE m_exp = (m_new == FA_M_INIT) ? 0.0f : m_new;
|
||||
const ACC_TYPE sp = native_exp(m_i - m_exp);
|
||||
const ACC_TYPE p0 = native_exp(score0 - m_exp);
|
||||
const ACC_TYPE p1 = native_exp(score1 - m_exp);
|
||||
|
||||
FA_UNROLL
|
||||
for (int i = 0; i < SPLIT_DV_VEC; ++i) {
|
||||
o_acc[i] = o_acc[i] * sp
|
||||
+ p0 * CONVERT_KV_ACC4(l_v[j ][dv_off + i])
|
||||
+ p1 * CONVERT_KV_ACC4(l_v[j+1][dv_off + i]);
|
||||
}
|
||||
l_i = l_i * sp + p0 + p1;
|
||||
m_i = m_new;
|
||||
}
|
||||
}
|
||||
|
||||
for (int j = 0; j < BLOCK_N; j += 2) {
|
||||
const int k_row0 = k_start + j;
|
||||
const int k_row1 = k_start + j + 1;
|
||||
|
||||
ACC_TYPE4 dot_acc0 = (ACC_TYPE4)(0.0f);
|
||||
ACC_TYPE4 dot_acc1 = (ACC_TYPE4)(0.0f);
|
||||
#pragma unroll
|
||||
for (int k = 0; k < DK_VEC; k++) {
|
||||
dot_acc0 = mad(q_priv[k], CONVERT_KV_ACC4(l_k[j][k]), dot_acc0);
|
||||
dot_acc1 = mad(q_priv[k], CONVERT_KV_ACC4(l_k[j+1][k]), dot_acc1);
|
||||
#elif N_SPLIT > 1
|
||||
// N_SPLIT>1 fallback (no shuffle): 3-phase local-memory reduction.
|
||||
// Phase 1 — partial dots for all BLOCK_N tokens.
|
||||
for (int j = 0; j < BLOCK_N; ++j) {
|
||||
ACC_TYPE4 dot_acc = (ACC_TYPE4)(0.0f);
|
||||
FA_UNROLL
|
||||
for (int k = 0; k < SPLIT_DK_VEC; k++) {
|
||||
dot_acc = mad(q_priv[k], CONVERT_KV_ACC4(l_k[j][dk_off + k]), dot_acc);
|
||||
}
|
||||
ACC_TYPE score0 = (dot_acc0.s0 + dot_acc0.s1 + dot_acc0.s2 + dot_acc0.s3) * scale;
|
||||
ACC_TYPE score1 = (dot_acc1.s0 + dot_acc1.s1 + dot_acc1.s2 + dot_acc1.s3) * scale;
|
||||
|
||||
if (is_causal) {
|
||||
if (k_row0 > (n_kv - n_q + my_query_row)) score0 = -INFINITY;
|
||||
if (k_row1 > (n_kv - n_q + my_query_row)) score1 = -INFINITY;
|
||||
}
|
||||
|
||||
if (k_row0 >= n_kv) score0 = -INFINITY;
|
||||
if (k_row1 >= n_kv) score1 = -INFINITY;
|
||||
|
||||
if (mask_base != NULL) {
|
||||
const global MASK_DATA_TYPE* mask_ptr = (const global MASK_DATA_TYPE*)(mask_base + my_query_row * mask_nb1);
|
||||
if (k_row0 < n_kv) score0 += slope * (ACC_TYPE)mask_ptr[k_row0];
|
||||
if (k_row1 < n_kv) score1 += slope * (ACC_TYPE)mask_ptr[k_row1];
|
||||
}
|
||||
|
||||
if (logit_softcap > 0.0f) {
|
||||
score0 = logit_softcap * tanh(score0 / logit_softcap);
|
||||
score1 = logit_softcap * tanh(score1 / logit_softcap);
|
||||
}
|
||||
|
||||
const ACC_TYPE m_new = max(m_i, max(score0, score1));
|
||||
const ACC_TYPE p0 = exp(score0 - m_new);
|
||||
const ACC_TYPE p1 = exp(score1 - m_new);
|
||||
const ACC_TYPE scale_prev = exp(m_i - m_new);
|
||||
|
||||
#pragma unroll
|
||||
for (int i = 0; i < DV_VEC; ++i) {
|
||||
o_acc[i] = o_acc[i] * scale_prev + p0 * CONVERT_KV_ACC4(l_v[j][i]) + p1 * CONVERT_KV_ACC4(l_v[j+1][i]);
|
||||
}
|
||||
l_i = l_i * scale_prev + p0 + p1;
|
||||
m_i = m_new;
|
||||
local_partial[j][tid] =
|
||||
dot_acc.s0 + dot_acc.s1 + dot_acc.s2 + dot_acc.s3;
|
||||
}
|
||||
barrier(CLK_LOCAL_MEM_FENCE); // 1 barrier: partial dots visible
|
||||
|
||||
// Phase 2 — split_idx==0 reduces partial sums and computes block softmax.
|
||||
if (split_idx == 0) {
|
||||
if (query_valid) {
|
||||
ACC_TYPE m_new = m_i;
|
||||
for (int j = 0; j < BLOCK_N; ++j) {
|
||||
const int k_row = k_start + j;
|
||||
ACC_TYPE score = 0.0f;
|
||||
FA_UNROLL
|
||||
for (int s = 0; s < N_SPLIT; s++) {
|
||||
score += local_partial[j][q_lane * N_SPLIT + s];
|
||||
}
|
||||
score *= scale;
|
||||
|
||||
if (is_causal && k_row > (n_kv - n_q + my_query_row)) score = FA_M_INIT;
|
||||
if (k_row >= n_kv) score = FA_M_INIT;
|
||||
|
||||
if (mask_base != NULL && blk_cur != 2) {
|
||||
if (use_kv_pad && mask_pad_base != NULL) {
|
||||
const global MASK_DATA_TYPE* mask_ptr =
|
||||
(const global MASK_DATA_TYPE*)(mask_pad_base + my_query_row * mask_pad_nb1);
|
||||
score += slope * (ACC_TYPE)mask_ptr[j];
|
||||
} else {
|
||||
const global MASK_DATA_TYPE* mask_ptr =
|
||||
(const global MASK_DATA_TYPE*)(mask_base + my_query_row * mask_nb1);
|
||||
if (k_row < n_kv) score += slope * (ACC_TYPE)mask_ptr[k_row];
|
||||
}
|
||||
}
|
||||
|
||||
if (logit_softcap > 0.0f) {
|
||||
score = logit_softcap * tanh(score / logit_softcap);
|
||||
}
|
||||
|
||||
m_new = max(m_new, score);
|
||||
local_p[q_lane][j] = score;
|
||||
}
|
||||
|
||||
const ACC_TYPE m_exp = (m_new == FA_M_INIT) ? 0.0f : m_new;
|
||||
const ACC_TYPE sp = native_exp(m_i - m_exp);
|
||||
ACC_TYPE l_new = l_i * sp;
|
||||
for (int j = 0; j < BLOCK_N; ++j) {
|
||||
const ACC_TYPE p = native_exp(local_p[q_lane][j] - m_exp);
|
||||
local_p[q_lane][j] = p;
|
||||
l_new += p;
|
||||
}
|
||||
local_softmax_scale[q_lane] = sp;
|
||||
l_i = l_new;
|
||||
m_i = m_new;
|
||||
} else {
|
||||
local_softmax_scale[q_lane] = 1.0f;
|
||||
for (int j = 0; j < BLOCK_N; ++j) local_p[q_lane][j] = 0.0f;
|
||||
}
|
||||
}
|
||||
barrier(CLK_LOCAL_MEM_FENCE);
|
||||
|
||||
// Phase 3 — V accumulate using broadcast probabilities.
|
||||
{
|
||||
const ACC_TYPE sp_block = local_softmax_scale[q_lane];
|
||||
const int dv_off = split_idx * SPLIT_DV_VEC;
|
||||
FA_UNROLL
|
||||
for (int i = 0; i < SPLIT_DV_VEC; ++i) {
|
||||
o_acc[i] *= sp_block;
|
||||
}
|
||||
for (int j = 0; j < BLOCK_N; ++j) {
|
||||
const ACC_TYPE p = local_p[q_lane][j];
|
||||
FA_UNROLL
|
||||
for (int i = 0; i < SPLIT_DV_VEC; ++i) {
|
||||
o_acc[i] = mad(p, CONVERT_KV_ACC4(l_v[j][dv_off + i]), o_acc[i]);
|
||||
}
|
||||
}
|
||||
}
|
||||
#else
|
||||
// N_SPLIT==1: j+=4 unroll. Requires BLOCK_N % 4 == 0.
|
||||
if (query_valid) {
|
||||
for (int j = 0; j < BLOCK_N; j += 4) {
|
||||
const int k_row0 = k_start + j;
|
||||
const int k_row1 = k_start + j + 1;
|
||||
const int k_row2 = k_start + j + 2;
|
||||
const int k_row3 = k_start + j + 3;
|
||||
|
||||
ACC_TYPE4 dot_acc0 = (ACC_TYPE4)(0.0f);
|
||||
ACC_TYPE4 dot_acc1 = (ACC_TYPE4)(0.0f);
|
||||
ACC_TYPE4 dot_acc2 = (ACC_TYPE4)(0.0f);
|
||||
ACC_TYPE4 dot_acc3 = (ACC_TYPE4)(0.0f);
|
||||
FA_UNROLL
|
||||
for (int k = 0; k < DK_VEC; k++) {
|
||||
const ACC_TYPE4 qk = q_priv[k];
|
||||
dot_acc0 = mad(qk, CONVERT_KV_ACC4(l_k[j][k]), dot_acc0);
|
||||
dot_acc1 = mad(qk, CONVERT_KV_ACC4(l_k[j+1][k]), dot_acc1);
|
||||
dot_acc2 = mad(qk, CONVERT_KV_ACC4(l_k[j+2][k]), dot_acc2);
|
||||
dot_acc3 = mad(qk, CONVERT_KV_ACC4(l_k[j+3][k]), dot_acc3);
|
||||
}
|
||||
ACC_TYPE s0 = (dot_acc0.s0 + dot_acc0.s1 + dot_acc0.s2 + dot_acc0.s3) * scale;
|
||||
ACC_TYPE s1 = (dot_acc1.s0 + dot_acc1.s1 + dot_acc1.s2 + dot_acc1.s3) * scale;
|
||||
ACC_TYPE s2 = (dot_acc2.s0 + dot_acc2.s1 + dot_acc2.s2 + dot_acc2.s3) * scale;
|
||||
ACC_TYPE s3 = (dot_acc3.s0 + dot_acc3.s1 + dot_acc3.s2 + dot_acc3.s3) * scale;
|
||||
|
||||
if (is_causal) {
|
||||
const int causal_limit = n_kv - n_q + my_query_row;
|
||||
if (k_row0 > causal_limit) s0 = FA_M_INIT;
|
||||
if (k_row1 > causal_limit) s1 = FA_M_INIT;
|
||||
if (k_row2 > causal_limit) s2 = FA_M_INIT;
|
||||
if (k_row3 > causal_limit) s3 = FA_M_INIT;
|
||||
}
|
||||
if (k_row0 >= n_kv) s0 = FA_M_INIT;
|
||||
if (k_row1 >= n_kv) s1 = FA_M_INIT;
|
||||
if (k_row2 >= n_kv) s2 = FA_M_INIT;
|
||||
if (k_row3 >= n_kv) s3 = FA_M_INIT;
|
||||
|
||||
if (mask_base != NULL && blk_cur != 2) {
|
||||
if (use_kv_pad && mask_pad_base != NULL) {
|
||||
const global MASK_DATA_TYPE* mask_ptr = (const global MASK_DATA_TYPE*)(mask_pad_base + my_query_row * mask_pad_nb1);
|
||||
s0 += slope * (ACC_TYPE)mask_ptr[j];
|
||||
s1 += slope * (ACC_TYPE)mask_ptr[j + 1];
|
||||
s2 += slope * (ACC_TYPE)mask_ptr[j + 2];
|
||||
s3 += slope * (ACC_TYPE)mask_ptr[j + 3];
|
||||
} else {
|
||||
const global MASK_DATA_TYPE* mask_ptr = (const global MASK_DATA_TYPE*)(mask_base + my_query_row * mask_nb1);
|
||||
if (k_row0 < n_kv) s0 += slope * (ACC_TYPE)mask_ptr[k_row0];
|
||||
if (k_row1 < n_kv) s1 += slope * (ACC_TYPE)mask_ptr[k_row1];
|
||||
if (k_row2 < n_kv) s2 += slope * (ACC_TYPE)mask_ptr[k_row2];
|
||||
if (k_row3 < n_kv) s3 += slope * (ACC_TYPE)mask_ptr[k_row3];
|
||||
}
|
||||
}
|
||||
|
||||
if (logit_softcap > 0.0f) {
|
||||
s0 = logit_softcap * tanh(s0 / logit_softcap);
|
||||
s1 = logit_softcap * tanh(s1 / logit_softcap);
|
||||
s2 = logit_softcap * tanh(s2 / logit_softcap);
|
||||
s3 = logit_softcap * tanh(s3 / logit_softcap);
|
||||
}
|
||||
|
||||
const ACC_TYPE m_new = max(m_i, max(max(s0, s1), max(s2, s3)));
|
||||
// Whole tile masked (m_new == FA_M_INIT): force the exp() args
|
||||
// far negative so the tile contributes 0, not exp(0)=1.
|
||||
const ACC_TYPE m_exp = (m_new == FA_M_INIT) ? 0.0f : m_new;
|
||||
const ACC_TYPE scale_prev = native_exp(m_i - m_exp);
|
||||
const ACC_TYPE p0 = native_exp(s0 - m_exp);
|
||||
const ACC_TYPE p1 = native_exp(s1 - m_exp);
|
||||
const ACC_TYPE p2 = native_exp(s2 - m_exp);
|
||||
const ACC_TYPE p3 = native_exp(s3 - m_exp);
|
||||
|
||||
FA_UNROLL
|
||||
for (int i = 0; i < DV_VEC; ++i) {
|
||||
o_acc[i] = mad(p3, CONVERT_KV_ACC4(l_v[j+3][i]),
|
||||
mad(p2, CONVERT_KV_ACC4(l_v[j+2][i]),
|
||||
mad(p1, CONVERT_KV_ACC4(l_v[j+1][i]),
|
||||
mad(p0, CONVERT_KV_ACC4(l_v[j][i]),
|
||||
o_acc[i] * scale_prev))));
|
||||
}
|
||||
l_i = l_i * scale_prev + p0 + p1 + p2 + p3;
|
||||
m_i = m_new;
|
||||
}
|
||||
}
|
||||
#endif
|
||||
// End of tile: every thread must finish reading l_k/l_v before the
|
||||
// next iteration's load overwrites them (WAR hazard on local memory).
|
||||
barrier(CLK_LOCAL_MEM_FENCE);
|
||||
}
|
||||
|
||||
if (my_query_row < n_q) {
|
||||
// Write output.
|
||||
#if N_SPLIT > 1 && defined(HAS_SUBGROUP_SHUFFLE)
|
||||
if (query_valid) {
|
||||
ACC_TYPE sinks_sp = 1.0f;
|
||||
if (sinks_void != NULL) {
|
||||
const global ACC_TYPE* sinks_ptr = (const global ACC_TYPE*)((const global char*)sinks_void + sinks_offset);
|
||||
const ACC_TYPE m_sink = sinks_ptr[head_idx];
|
||||
const ACC_TYPE m_final = max(m_i, m_sink);
|
||||
sinks_sp = exp(m_i - m_final);
|
||||
l_i = l_i * sinks_sp + exp(m_sink - m_final);
|
||||
m_i = m_final;
|
||||
}
|
||||
const ACC_TYPE l_inv = (l_i > 0.0f) ? (1.0f / l_i) : 0.0f;
|
||||
const int dv_off = split_idx * SPLIT_DV_VEC;
|
||||
const ulong o_row_offset = batch_idx * o_nb3 + my_query_row * o_nb2 + head_idx * o_nb1;
|
||||
global O_DATA_TYPE4 *o_row = (global O_DATA_TYPE4 *)(o_base + o_row_offset);
|
||||
if (l_inv > 0.0f) {
|
||||
FA_UNROLL
|
||||
for (int i = 0; i < SPLIT_DV_VEC; ++i) {
|
||||
o_row[dv_off + i] = CONVERT_O_DATA4(o_acc[i] * sinks_sp * l_inv);
|
||||
}
|
||||
} else {
|
||||
FA_UNROLL
|
||||
for (int i = 0; i < SPLIT_DV_VEC; ++i) {
|
||||
o_row[dv_off + i] = (O_DATA_TYPE4)(0.0f);
|
||||
}
|
||||
}
|
||||
}
|
||||
#elif N_SPLIT > 1
|
||||
if (split_idx == 0) {
|
||||
ACC_TYPE sinks_sp = 1.0f;
|
||||
if (query_valid && sinks_void != NULL) {
|
||||
const global ACC_TYPE* sinks_ptr = (const global ACC_TYPE*)((const global char*)sinks_void + sinks_offset);
|
||||
const ACC_TYPE m_sink = sinks_ptr[head_idx];
|
||||
const ACC_TYPE m_final = max(m_i, m_sink);
|
||||
sinks_sp = exp(m_i - m_final);
|
||||
l_i = l_i * sinks_sp + exp(m_sink - m_final);
|
||||
m_i = m_final;
|
||||
}
|
||||
local_softmax_scale[q_lane] = sinks_sp;
|
||||
local_l_inv[q_lane] = (query_valid && l_i > 0.0f) ? (1.0f / l_i) : 0.0f;
|
||||
}
|
||||
barrier(CLK_LOCAL_MEM_FENCE);
|
||||
|
||||
if (query_valid) {
|
||||
const ACC_TYPE sinks_sp = local_softmax_scale[q_lane];
|
||||
const ACC_TYPE l_inv = local_l_inv[q_lane];
|
||||
const int dv_off = split_idx * SPLIT_DV_VEC;
|
||||
const ulong o_row_offset = batch_idx * o_nb3 + my_query_row * o_nb2 + head_idx * o_nb1;
|
||||
global O_DATA_TYPE4 *o_row = (global O_DATA_TYPE4 *)(o_base + o_row_offset);
|
||||
if (l_inv > 0.0f) {
|
||||
FA_UNROLL
|
||||
for (int i = 0; i < SPLIT_DV_VEC; ++i) {
|
||||
o_row[dv_off + i] = CONVERT_O_DATA4(o_acc[i] * sinks_sp * l_inv);
|
||||
}
|
||||
} else {
|
||||
FA_UNROLL
|
||||
for (int i = 0; i < SPLIT_DV_VEC; ++i) {
|
||||
o_row[dv_off + i] = (O_DATA_TYPE4)(0.0f);
|
||||
}
|
||||
}
|
||||
}
|
||||
#else
|
||||
if (query_valid) {
|
||||
if (sinks_void != NULL) {
|
||||
const global ACC_TYPE* sinks_ptr = (const global ACC_TYPE*)((const global char*)sinks_void + sinks_offset);
|
||||
const ACC_TYPE m_sink = sinks_ptr[head_idx];
|
||||
const ACC_TYPE m_final = max(m_i, m_sink);
|
||||
|
||||
const ACC_TYPE scale_o = exp(m_i - m_final);
|
||||
#pragma unroll
|
||||
FA_UNROLL
|
||||
for (int i = 0; i < DV_VEC; ++i) {
|
||||
o_acc[i] *= scale_o;
|
||||
}
|
||||
@@ -194,17 +543,18 @@ __kernel void flash_attn_f32_f16(
|
||||
global O_DATA_TYPE4 *o_row = (global O_DATA_TYPE4 *)(o_base + o_row_offset);
|
||||
if (l_i > 0.0f) {
|
||||
const ACC_TYPE l_inv = 1.0f / l_i;
|
||||
#pragma unroll
|
||||
FA_UNROLL
|
||||
for (int i = 0; i < DV_VEC; ++i) {
|
||||
o_row[i] = CONVERT_O_DATA4(o_acc[i] * l_inv);
|
||||
}
|
||||
} else {
|
||||
#pragma unroll
|
||||
FA_UNROLL
|
||||
for (int i = 0; i < DV_VEC; ++i) {
|
||||
o_row[i] = (O_DATA_TYPE4)(0.0f);
|
||||
}
|
||||
}
|
||||
}
|
||||
#endif
|
||||
}
|
||||
|
||||
__kernel void flash_attn_f32_f16_q1(
|
||||
@@ -258,13 +608,16 @@ __kernel void flash_attn_f32_f16_q1(
|
||||
mask_base = (const global char*)mask_void + mask_offset + mask_batch_idx * mask_nb3 + mask_head_idx * mask_nb2;
|
||||
}
|
||||
|
||||
ACC_TYPE4 q_priv[DK_VEC];
|
||||
// Q is uniform across WG threads (n_q=1). Share via local memory to
|
||||
// avoid per-thread q_priv[DK_VEC] dynamic-indexed private array that
|
||||
// spills to DDR on Adreno.
|
||||
__local ACC_TYPE4 q_shared[DK_VEC];
|
||||
const ulong q_row_offset = batch_idx * q_nb3 + head_idx * q_nb2;
|
||||
const global Q_DATA_TYPE4* q_ptr = (const global Q_DATA_TYPE4*)(q_base + q_row_offset);
|
||||
#pragma unroll
|
||||
for (int i = 0; i < DK_VEC; ++i) {
|
||||
q_priv[i] = CONVERT_Q_ACC4(q_ptr[i]);
|
||||
for (int i = tid; i < DK_VEC; i += Q1_WG_SIZE) {
|
||||
q_shared[i] = CONVERT_Q_ACC4(q_ptr[i]);
|
||||
}
|
||||
barrier(CLK_LOCAL_MEM_FENCE);
|
||||
|
||||
float slope = get_alibi_slope(max_bias, head_idx, n_head_log2, m0, m1);
|
||||
|
||||
@@ -273,14 +626,14 @@ __kernel void flash_attn_f32_f16_q1(
|
||||
sinks_ptr = (const global ACC_TYPE*)((const global char*)sinks_void + sinks_offset);
|
||||
}
|
||||
|
||||
ACC_TYPE m_i = (sinks_ptr != NULL) ? sinks_ptr[head_idx] : -INFINITY;
|
||||
ACC_TYPE m_i = (sinks_ptr != NULL) ? sinks_ptr[head_idx] : FA_M_INIT;
|
||||
for (int k_idx = tid; k_idx < n_kv; k_idx += Q1_WG_SIZE) {
|
||||
const ulong k_row_offset = batch_idx * k_nb3 + head_kv_idx * k_nb2 + k_idx * k_nb1;
|
||||
const global KV_DATA_TYPE4* k_ptr = (const global KV_DATA_TYPE4*)(k_base + k_row_offset);
|
||||
ACC_TYPE4 dot_acc = (ACC_TYPE4)(0.0f);
|
||||
#pragma unroll
|
||||
FA_UNROLL
|
||||
for (int k = 0; k < DK_VEC; k++) {
|
||||
dot_acc = mad(q_priv[k], CONVERT_KV_ACC4(k_ptr[k]), dot_acc);
|
||||
dot_acc = mad(q_shared[k], CONVERT_KV_ACC4(k_ptr[k]), dot_acc);
|
||||
}
|
||||
ACC_TYPE score = (dot_acc.s0 + dot_acc.s1 + dot_acc.s2 + dot_acc.s3) * scale;
|
||||
if (mask_base != NULL) {
|
||||
@@ -296,7 +649,7 @@ __kernel void flash_attn_f32_f16_q1(
|
||||
__local ACC_TYPE local_m[Q1_WG_SIZE];
|
||||
local_m[tid] = m_i;
|
||||
barrier(CLK_LOCAL_MEM_FENCE);
|
||||
#pragma unroll
|
||||
FA_UNROLL
|
||||
for (int s = Q1_WG_SIZE / 2; s > 0; s >>= 1) {
|
||||
if (tid < s) local_m[tid] = max(local_m[tid], local_m[tid + s]);
|
||||
barrier(CLK_LOCAL_MEM_FENCE);
|
||||
@@ -304,7 +657,7 @@ __kernel void flash_attn_f32_f16_q1(
|
||||
const ACC_TYPE m_final = local_m[0];
|
||||
|
||||
ACC_TYPE4 o_acc[DV_VEC];
|
||||
#pragma unroll
|
||||
FA_UNROLL
|
||||
for (int i = 0; i < DV_VEC; ++i) o_acc[i] = (ACC_TYPE4)(0.0f);
|
||||
ACC_TYPE l_i = 0.0f;
|
||||
|
||||
@@ -314,9 +667,9 @@ __kernel void flash_attn_f32_f16_q1(
|
||||
const global KV_DATA_TYPE4* k_ptr = (const global KV_DATA_TYPE4*)(k_base + k_row_offset);
|
||||
const global KV_DATA_TYPE4* v_ptr = (const global KV_DATA_TYPE4*)(v_base + v_row_offset);
|
||||
ACC_TYPE4 dot_acc = (ACC_TYPE4)(0.0f);
|
||||
#pragma unroll
|
||||
FA_UNROLL
|
||||
for (int k = 0; k < DK_VEC; k++) {
|
||||
dot_acc = mad(q_priv[k], CONVERT_KV_ACC4(k_ptr[k]), dot_acc);
|
||||
dot_acc = mad(q_shared[k], CONVERT_KV_ACC4(k_ptr[k]), dot_acc);
|
||||
}
|
||||
ACC_TYPE score = (dot_acc.s0 + dot_acc.s1 + dot_acc.s2 + dot_acc.s3) * scale;
|
||||
if (mask_base != NULL) {
|
||||
@@ -328,7 +681,7 @@ __kernel void flash_attn_f32_f16_q1(
|
||||
}
|
||||
const ACC_TYPE p = exp(score - m_final);
|
||||
l_i += p;
|
||||
#pragma unroll
|
||||
FA_UNROLL
|
||||
for (int i = 0; i < DV_VEC; i++) {
|
||||
o_acc[i] = mad(p, CONVERT_KV_ACC4(v_ptr[i]), o_acc[i]);
|
||||
}
|
||||
@@ -338,7 +691,7 @@ __kernel void flash_attn_f32_f16_q1(
|
||||
__local ACC_TYPE4 local_o_comp[Q1_WG_SIZE];
|
||||
local_l[tid] = l_i;
|
||||
barrier(CLK_LOCAL_MEM_FENCE);
|
||||
#pragma unroll
|
||||
FA_UNROLL
|
||||
for (int s = Q1_WG_SIZE / 2; s > 0; s >>= 1) {
|
||||
if (tid < s) local_l[tid] += local_l[tid + s];
|
||||
barrier(CLK_LOCAL_MEM_FENCE);
|
||||
@@ -357,7 +710,7 @@ __kernel void flash_attn_f32_f16_q1(
|
||||
for (int i = 0; i < DV_VEC; i++) {
|
||||
local_o_comp[tid] = o_acc[i];
|
||||
barrier(CLK_LOCAL_MEM_FENCE);
|
||||
#pragma unroll
|
||||
FA_UNROLL
|
||||
for (int s = Q1_WG_SIZE / 2; s > 0; s >>= 1) {
|
||||
if (tid < s) local_o_comp[tid] += local_o_comp[tid + s];
|
||||
barrier(CLK_LOCAL_MEM_FENCE);
|
||||
@@ -367,7 +720,257 @@ __kernel void flash_attn_f32_f16_q1(
|
||||
}
|
||||
}
|
||||
} else if (tid == 0) {
|
||||
#pragma unroll
|
||||
FA_UNROLL
|
||||
for (int i = 0; i < DV_VEC; ++i) o_row[i] = (O_DATA_TYPE4)(0.0f);
|
||||
}
|
||||
}
|
||||
|
||||
// Flash-decoding split pass. gid(2) = q_idx * n_splits + split_idx.
|
||||
// Partial record per split: [m, l, O[DV]]. Merge kernel applies sink + norm.
|
||||
#define FA_PARTIAL_FLOATS (2 + DV)
|
||||
|
||||
__kernel void flash_attn_f32_f16_q1_split(
|
||||
const global void * q_void, ulong q_offset,
|
||||
const global void * k_void, ulong k_offset,
|
||||
const global void * v_void, ulong v_offset,
|
||||
const float scale,
|
||||
const int n_q,
|
||||
const int n_kv,
|
||||
const int n_head,
|
||||
const ulong q_nb1, const ulong q_nb2, const ulong q_nb3,
|
||||
const ulong k_nb1, const ulong k_nb2, const ulong k_nb3,
|
||||
const ulong v_nb1, const ulong v_nb2, const ulong v_nb3,
|
||||
const float max_bias,
|
||||
const float m0,
|
||||
const float m1,
|
||||
const int n_head_log2,
|
||||
const float logit_softcap,
|
||||
const int n_head_kv,
|
||||
const global void * mask_void,
|
||||
const ulong mask_offset,
|
||||
const ulong mask_nb1,
|
||||
const ulong mask_nb2,
|
||||
const ulong mask_nb3,
|
||||
const int mask_ne2,
|
||||
const int mask_ne3,
|
||||
global float * partial_void,
|
||||
const int n_splits,
|
||||
const int kv_per_split
|
||||
) {
|
||||
const int tid = get_local_id(0);
|
||||
const int head_batch_idx = get_global_id(1);
|
||||
const int split_q_idx = get_global_id(2);
|
||||
const int split_idx = split_q_idx % n_splits;
|
||||
const int q_idx = split_q_idx / n_splits;
|
||||
const int batch_idx = head_batch_idx / n_head;
|
||||
const int head_idx = head_batch_idx % n_head;
|
||||
const int gqa_ratio = n_head / n_head_kv;
|
||||
const int head_kv_idx = head_idx / gqa_ratio;
|
||||
|
||||
const int kv_start = split_idx * kv_per_split;
|
||||
const int kv_end = min(kv_start + kv_per_split, n_kv);
|
||||
|
||||
const ulong record_stride = (ulong) FA_PARTIAL_FLOATS;
|
||||
const ulong record_idx = ((((ulong) batch_idx * n_head + head_idx) * n_q + q_idx)
|
||||
* n_splits + split_idx);
|
||||
global float * rec = partial_void + record_idx * record_stride;
|
||||
global float4 * rec_o = (global float4 *) (rec + 2);
|
||||
|
||||
if (kv_start >= kv_end) {
|
||||
// Empty split: leave sentinel partial for merge.
|
||||
if (tid == 0) {
|
||||
rec[0] = FA_M_INIT;
|
||||
rec[1] = 0.0f;
|
||||
}
|
||||
return;
|
||||
}
|
||||
|
||||
const global char * q_base = (const global char *) q_void + q_offset;
|
||||
const global char * k_base = (const global char *) k_void + k_offset;
|
||||
const global char * v_base = (const global char *) v_void + v_offset;
|
||||
|
||||
const global char * mask_base = NULL;
|
||||
if (mask_void != NULL) {
|
||||
const int mask_head_idx = head_idx % mask_ne2;
|
||||
const int mask_batch_idx = batch_idx % mask_ne3;
|
||||
mask_base = (const global char *) mask_void + mask_offset +
|
||||
mask_batch_idx * mask_nb3 + mask_head_idx * mask_nb2 +
|
||||
(ulong) q_idx * mask_nb1;
|
||||
}
|
||||
|
||||
// Share Q via local memory (n_q=1 per split -> uniform across WG).
|
||||
__local ACC_TYPE4 q_shared[DK_VEC];
|
||||
const ulong q_row_offset = batch_idx * q_nb3 + head_idx * q_nb2 + (ulong) q_idx * q_nb1;
|
||||
const global Q_DATA_TYPE4 * q_ptr = (const global Q_DATA_TYPE4 *) (q_base + q_row_offset);
|
||||
for (int i = tid; i < DK_VEC; i += Q1_WG_SIZE) {
|
||||
q_shared[i] = CONVERT_Q_ACC4(q_ptr[i]);
|
||||
}
|
||||
barrier(CLK_LOCAL_MEM_FENCE);
|
||||
|
||||
const float slope = get_alibi_slope(max_bias, head_idx, n_head_log2, m0, m1);
|
||||
|
||||
// Pass 1a — split-local max.
|
||||
ACC_TYPE m_i = FA_M_INIT;
|
||||
for (int k_idx = kv_start + tid; k_idx < kv_end; k_idx += Q1_WG_SIZE) {
|
||||
const ulong k_row_offset = batch_idx * k_nb3 + head_kv_idx * k_nb2 + k_idx * k_nb1;
|
||||
const global KV_DATA_TYPE4 * k_ptr = (const global KV_DATA_TYPE4 *) (k_base + k_row_offset);
|
||||
ACC_TYPE4 dot_acc = (ACC_TYPE4)(0.0f);
|
||||
#pragma unroll
|
||||
for (int k = 0; k < DK_VEC; ++k) {
|
||||
dot_acc = mad(q_shared[k], CONVERT_KV_ACC4(k_ptr[k]), dot_acc);
|
||||
}
|
||||
ACC_TYPE score = (dot_acc.s0 + dot_acc.s1 + dot_acc.s2 + dot_acc.s3) * scale;
|
||||
if (mask_base != NULL) {
|
||||
const global MASK_DATA_TYPE * mask_ptr = (const global MASK_DATA_TYPE *) (mask_base);
|
||||
score += slope * (ACC_TYPE) mask_ptr[k_idx];
|
||||
}
|
||||
if (logit_softcap > 0.0f) {
|
||||
score = logit_softcap * tanh(score / logit_softcap);
|
||||
}
|
||||
m_i = max(m_i, score);
|
||||
}
|
||||
|
||||
__local ACC_TYPE local_m[Q1_WG_SIZE];
|
||||
local_m[tid] = m_i;
|
||||
barrier(CLK_LOCAL_MEM_FENCE);
|
||||
#pragma unroll
|
||||
for (int s = Q1_WG_SIZE / 2; s > 0; s >>= 1) {
|
||||
if (tid < s) local_m[tid] = max(local_m[tid], local_m[tid + s]);
|
||||
barrier(CLK_LOCAL_MEM_FENCE);
|
||||
}
|
||||
const ACC_TYPE m_c = local_m[0];
|
||||
|
||||
// Pass 1b — softmax-weighted V accumulate.
|
||||
ACC_TYPE4 o_acc[DV_VEC];
|
||||
#pragma unroll
|
||||
for (int i = 0; i < DV_VEC; ++i) o_acc[i] = (ACC_TYPE4)(0.0f);
|
||||
ACC_TYPE l_i = 0.0f;
|
||||
|
||||
for (int k_idx = kv_start + tid; k_idx < kv_end; k_idx += Q1_WG_SIZE) {
|
||||
const ulong k_row_offset = batch_idx * k_nb3 + head_kv_idx * k_nb2 + k_idx * k_nb1;
|
||||
const ulong v_row_offset = batch_idx * v_nb3 + head_kv_idx * v_nb2 + k_idx * v_nb1;
|
||||
const global KV_DATA_TYPE4 * k_ptr = (const global KV_DATA_TYPE4 *) (k_base + k_row_offset);
|
||||
const global KV_DATA_TYPE4 * v_ptr = (const global KV_DATA_TYPE4 *) (v_base + v_row_offset);
|
||||
ACC_TYPE4 dot_acc = (ACC_TYPE4)(0.0f);
|
||||
#pragma unroll
|
||||
for (int k = 0; k < DK_VEC; ++k) {
|
||||
dot_acc = mad(q_shared[k], CONVERT_KV_ACC4(k_ptr[k]), dot_acc);
|
||||
}
|
||||
ACC_TYPE score = (dot_acc.s0 + dot_acc.s1 + dot_acc.s2 + dot_acc.s3) * scale;
|
||||
if (mask_base != NULL) {
|
||||
const global MASK_DATA_TYPE * mask_ptr = (const global MASK_DATA_TYPE *) (mask_base);
|
||||
score += slope * (ACC_TYPE) mask_ptr[k_idx];
|
||||
}
|
||||
if (logit_softcap > 0.0f) {
|
||||
score = logit_softcap * tanh(score / logit_softcap);
|
||||
}
|
||||
const ACC_TYPE p = exp(score - m_c);
|
||||
l_i += p;
|
||||
#pragma unroll
|
||||
for (int i = 0; i < DV_VEC; ++i) {
|
||||
o_acc[i] = mad(p, CONVERT_KV_ACC4(v_ptr[i]), o_acc[i]);
|
||||
}
|
||||
}
|
||||
|
||||
__local ACC_TYPE local_l[Q1_WG_SIZE];
|
||||
__local ACC_TYPE4 local_o[Q1_WG_SIZE];
|
||||
local_l[tid] = l_i;
|
||||
barrier(CLK_LOCAL_MEM_FENCE);
|
||||
#pragma unroll
|
||||
for (int s = Q1_WG_SIZE / 2; s > 0; s >>= 1) {
|
||||
if (tid < s) local_l[tid] += local_l[tid + s];
|
||||
barrier(CLK_LOCAL_MEM_FENCE);
|
||||
}
|
||||
const ACC_TYPE l_c = local_l[0];
|
||||
|
||||
if (tid == 0) {
|
||||
rec[0] = (float) m_c;
|
||||
rec[1] = (float) l_c;
|
||||
}
|
||||
for (int i = 0; i < DV_VEC; ++i) {
|
||||
local_o[tid] = o_acc[i];
|
||||
barrier(CLK_LOCAL_MEM_FENCE);
|
||||
#pragma unroll
|
||||
for (int s = Q1_WG_SIZE / 2; s > 0; s >>= 1) {
|
||||
if (tid < s) local_o[tid] += local_o[tid + s];
|
||||
barrier(CLK_LOCAL_MEM_FENCE);
|
||||
}
|
||||
if (tid == 0) {
|
||||
rec_o[i] = local_o[0];
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// FD Pass 2: merge per-split partials into final O. Empty splits drop via exp(-INF)=0.
|
||||
__kernel void flash_attn_f32_merge(
|
||||
const global float * partial_void,
|
||||
global void * o_void,
|
||||
const ulong o_offset,
|
||||
const int n_head,
|
||||
const int n_splits,
|
||||
const ulong o_nb1, const ulong o_nb2, const ulong o_nb3,
|
||||
const global void * sinks_void,
|
||||
const ulong sinks_offset,
|
||||
const int n_q
|
||||
) {
|
||||
const int lane = get_local_id(0); // 0..DV_VEC-1
|
||||
const int head_batch_idx = get_global_id(1);
|
||||
const int q_idx = get_global_id(2);
|
||||
const int batch_idx = head_batch_idx / n_head;
|
||||
const int head_idx = head_batch_idx % n_head;
|
||||
|
||||
const ulong record_stride = (ulong) FA_PARTIAL_FLOATS;
|
||||
const ulong record_idx_0 = (((ulong) batch_idx * n_head + head_idx) * n_q + q_idx) * n_splits;
|
||||
const global float * rec0 = partial_void + record_idx_0 * record_stride;
|
||||
|
||||
__local ACC_TYPE m_final_shared;
|
||||
__local ACC_TYPE l_final_shared;
|
||||
if (lane == 0) {
|
||||
ACC_TYPE m = FA_M_INIT;
|
||||
for (int c = 0; c < n_splits; ++c) {
|
||||
const ACC_TYPE m_c = rec0[c * record_stride + 0];
|
||||
m = max(m, m_c);
|
||||
}
|
||||
ACC_TYPE m_sink = 0.0f;
|
||||
bool has_sink = false;
|
||||
if (sinks_void != NULL) {
|
||||
const global ACC_TYPE * sinks_ptr =
|
||||
(const global ACC_TYPE *) ((const global char *) sinks_void + sinks_offset);
|
||||
m_sink = sinks_ptr[head_idx];
|
||||
has_sink = true;
|
||||
m = max(m, m_sink);
|
||||
}
|
||||
ACC_TYPE l = 0.0f;
|
||||
for (int c = 0; c < n_splits; ++c) {
|
||||
const ACC_TYPE m_c = rec0[c * record_stride + 0];
|
||||
const ACC_TYPE l_c = rec0[c * record_stride + 1];
|
||||
if (m_c > FA_M_INIT) {
|
||||
l += l_c * exp(m_c - m);
|
||||
}
|
||||
}
|
||||
if (has_sink) {
|
||||
l += exp(m_sink - m);
|
||||
}
|
||||
m_final_shared = m;
|
||||
l_final_shared = l;
|
||||
}
|
||||
barrier(CLK_LOCAL_MEM_FENCE);
|
||||
const ACC_TYPE m_final = m_final_shared;
|
||||
const ACC_TYPE l_final = l_final_shared;
|
||||
const ACC_TYPE l_inv = (l_final > 0.0f) ? (1.0f / l_final) : 0.0f;
|
||||
|
||||
ACC_TYPE4 o = (ACC_TYPE4)(0.0f);
|
||||
for (int c = 0; c < n_splits; ++c) {
|
||||
const global float * rec_c = rec0 + c * record_stride;
|
||||
const ACC_TYPE m_c = rec_c[0];
|
||||
if (m_c <= FA_M_INIT) continue;
|
||||
const global float4 * rec_oc = (const global float4 *) (rec_c + 2);
|
||||
const ACC_TYPE scale_c = exp(m_c - m_final);
|
||||
o = mad((ACC_TYPE4)(scale_c), rec_oc[lane], o);
|
||||
}
|
||||
o = o * l_inv;
|
||||
|
||||
const ulong o_row_offset = (ulong) batch_idx * o_nb3 + (ulong) q_idx * o_nb2 + (ulong) head_idx * o_nb1;
|
||||
global O_DATA_TYPE4 * o_row = (global O_DATA_TYPE4 *) ((global char *) o_void + o_offset + o_row_offset);
|
||||
o_row[lane] = CONVERT_O_DATA4(o);
|
||||
}
|
||||
|
||||
File diff suppressed because it is too large
Load Diff
File diff suppressed because it is too large
Load Diff
@@ -0,0 +1,156 @@
|
||||
#pragma OPENCL EXTENSION cl_khr_fp16 : enable
|
||||
|
||||
__kernel void flash_attn_kv_pad_f16(
|
||||
const global void * k_void, ulong k_offset,
|
||||
const global void * v_void, ulong v_offset,
|
||||
global void * k_pad_void,
|
||||
global void * v_pad_void,
|
||||
const int n_kv,
|
||||
const int n_head_kv,
|
||||
const int n_batch,
|
||||
const ulong k_nb1, const ulong k_nb2, const ulong k_nb3,
|
||||
const ulong v_nb1, const ulong v_nb2, const ulong v_nb3
|
||||
) {
|
||||
const int row_idx = get_global_id(0);
|
||||
const int head_kv_idx = get_global_id(1);
|
||||
const int batch_idx = get_global_id(2);
|
||||
|
||||
if (row_idx >= BLOCK_N || head_kv_idx >= n_head_kv || batch_idx >= n_batch) {
|
||||
return;
|
||||
}
|
||||
|
||||
const int tail_start = n_kv - (n_kv % BLOCK_N);
|
||||
const int src_row_idx = tail_start + row_idx;
|
||||
|
||||
const global char * k_src = (const global char *) k_void + k_offset;
|
||||
const global char * v_src = (const global char *) v_void + v_offset;
|
||||
global char * k_pad = (global char *) k_pad_void;
|
||||
global char * v_pad = (global char *) v_pad_void;
|
||||
|
||||
const ulong k_dst_offset = ((ulong) batch_idx * (ulong) n_head_kv + (ulong) head_kv_idx) * ((ulong) BLOCK_N * k_nb1) + (ulong) row_idx * k_nb1;
|
||||
const ulong v_dst_offset = ((ulong) batch_idx * (ulong) n_head_kv + (ulong) head_kv_idx) * ((ulong) BLOCK_N * v_nb1) + (ulong) row_idx * v_nb1;
|
||||
|
||||
if (src_row_idx < n_kv) {
|
||||
const ulong k_src_offset = (ulong) batch_idx * k_nb3 + (ulong) head_kv_idx * k_nb2 + (ulong) src_row_idx * k_nb1;
|
||||
const ulong v_src_offset = (ulong) batch_idx * v_nb3 + (ulong) head_kv_idx * v_nb2 + (ulong) src_row_idx * v_nb1;
|
||||
|
||||
for (ulong i = 0; i < k_nb1; ++i) {
|
||||
k_pad[k_dst_offset + i] = k_src[k_src_offset + i];
|
||||
}
|
||||
for (ulong i = 0; i < v_nb1; ++i) {
|
||||
v_pad[v_dst_offset + i] = v_src[v_src_offset + i];
|
||||
}
|
||||
} else {
|
||||
for (ulong i = 0; i < k_nb1; ++i) {
|
||||
k_pad[k_dst_offset + i] = 0;
|
||||
}
|
||||
for (ulong i = 0; i < v_nb1; ++i) {
|
||||
v_pad[v_dst_offset + i] = 0;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
__kernel void flash_attn_mask_pad_f16(
|
||||
const global void * mask_void, ulong mask_offset,
|
||||
global void * mask_pad_void,
|
||||
const int n_q,
|
||||
const int n_kv,
|
||||
const ulong mask_nb1,
|
||||
const ulong mask_nb2,
|
||||
const ulong mask_nb3,
|
||||
const int mask_ne2,
|
||||
const int mask_ne3
|
||||
) {
|
||||
const int col_idx = get_global_id(0);
|
||||
const int q_row = get_global_id(1);
|
||||
const int mask_slice = get_global_id(2);
|
||||
|
||||
if (col_idx >= BLOCK_N || q_row >= n_q || mask_slice >= mask_ne2 * mask_ne3) {
|
||||
return;
|
||||
}
|
||||
|
||||
const int tail_start = n_kv - (n_kv % BLOCK_N);
|
||||
const int src_col_idx = tail_start + col_idx;
|
||||
const int mask_head_idx = mask_slice % mask_ne2;
|
||||
const int mask_batch_idx = mask_slice / mask_ne2;
|
||||
|
||||
const global char * mask_src_base = (const global char *) mask_void + mask_offset +
|
||||
(ulong) mask_batch_idx * mask_nb3 +
|
||||
(ulong) mask_head_idx * mask_nb2 +
|
||||
(ulong) q_row * mask_nb1;
|
||||
const global half * mask_src = (const global half *) mask_src_base;
|
||||
|
||||
global half * mask_pad = (global half *) mask_pad_void;
|
||||
const ulong dst_idx =
|
||||
(((ulong) mask_batch_idx * (ulong) mask_ne2 + (ulong) mask_head_idx) * (ulong) n_q + (ulong) q_row) * (ulong) BLOCK_N +
|
||||
(ulong) col_idx;
|
||||
|
||||
mask_pad[dst_idx] = src_col_idx < n_kv ? mask_src[src_col_idx] : (half) (-INFINITY);
|
||||
}
|
||||
|
||||
// Per-KV-tile mask class. 0=all -inf (skip tile), 1=mixed (apply mask),
|
||||
// 2=all zero, no -inf (skip mask lookup). Causal diagonal tiles are class 1.
|
||||
__kernel void flash_attn_blk_f16(
|
||||
const global void * mask_void, ulong mask_offset,
|
||||
global char * blk,
|
||||
const int n_q,
|
||||
const int n_kv,
|
||||
const ulong mask_nb1,
|
||||
const ulong mask_nb2,
|
||||
const ulong mask_nb3,
|
||||
const int mask_ne2,
|
||||
const int mask_ne3
|
||||
) {
|
||||
const int kv_block_idx = get_global_id(0);
|
||||
const int q_block_idx = get_global_id(1);
|
||||
const int mask_slice = get_global_id(2);
|
||||
|
||||
const int n_q_blocks = (n_q + BLOCK_M - 1) / BLOCK_M;
|
||||
const int n_kv_blocks = (n_kv + BLOCK_N - 1) / BLOCK_N;
|
||||
if (kv_block_idx >= n_kv_blocks || q_block_idx >= n_q_blocks || mask_slice >= mask_ne2 * mask_ne3) {
|
||||
return;
|
||||
}
|
||||
|
||||
const int mask_head_idx = mask_slice % mask_ne2;
|
||||
const int mask_batch_idx = mask_slice / mask_ne2;
|
||||
const int q_start = q_block_idx * BLOCK_M;
|
||||
const int k_start = kv_block_idx * BLOCK_N;
|
||||
const int q_count = min(BLOCK_M, n_q - q_start);
|
||||
const int k_count = min(BLOCK_N, n_kv - k_start);
|
||||
|
||||
const half neg_max_half = (half) (-65504.0f);
|
||||
char has_unmasked = 0;
|
||||
char has_masked = 0;
|
||||
char has_nonzero = 0;
|
||||
|
||||
const global char * mask_base = (const global char *) mask_void + mask_offset +
|
||||
(ulong) mask_batch_idx * mask_nb3 +
|
||||
(ulong) mask_head_idx * mask_nb2;
|
||||
|
||||
for (int qi = 0; qi < q_count; ++qi) {
|
||||
const global half * mask_row = (const global half *) (mask_base + (ulong) (q_start + qi) * mask_nb1) + k_start;
|
||||
for (int ki = 0; ki < k_count; ++ki) {
|
||||
const half v = mask_row[ki];
|
||||
if (v <= neg_max_half) {
|
||||
has_masked = 1;
|
||||
} else {
|
||||
has_unmasked = 1;
|
||||
if (v != (half) 0.0f) {
|
||||
has_nonzero = 1;
|
||||
}
|
||||
}
|
||||
}
|
||||
if (has_masked && has_unmasked) break; // mixed tile — short-circuit.
|
||||
}
|
||||
|
||||
char res;
|
||||
if (has_unmasked == 0) {
|
||||
res = 0;
|
||||
} else if (has_masked || has_nonzero) {
|
||||
res = 1;
|
||||
} else {
|
||||
res = 2;
|
||||
}
|
||||
|
||||
blk[((ulong) mask_slice * (ulong) n_q_blocks + (ulong) q_block_idx) * (ulong) n_kv_blocks + (ulong) kv_block_idx] = res;
|
||||
}
|
||||
@@ -158,6 +158,239 @@ kernel void kernel_set_rows_f32_i32(
|
||||
}
|
||||
}
|
||||
|
||||
// f32 -> q8_0 quantize set_rows. Block = half d + char qs[32].
|
||||
#define QK8_0 32
|
||||
|
||||
inline void quantize_q8_0_block(global float * x, global char * qs, global half * d_out) {
|
||||
float amax = 0.0f;
|
||||
for (int j = 0; j < QK8_0; j++) {
|
||||
amax = fmax(amax, fabs(x[j]));
|
||||
}
|
||||
|
||||
float d = amax / 127.0f;
|
||||
float id = (d != 0.0f) ? 127.0f / amax : 0.0f;
|
||||
|
||||
vstore_half(d, 0, d_out);
|
||||
|
||||
for (int j = 0; j < QK8_0; j++) {
|
||||
qs[j] = (char)((int)round(x[j] * id));
|
||||
}
|
||||
}
|
||||
|
||||
kernel void kernel_set_rows_q8_0_i64(
|
||||
global char * src0,
|
||||
ulong offset0,
|
||||
global char * src1,
|
||||
ulong offset1,
|
||||
global char * dst,
|
||||
ulong offsetd,
|
||||
int ne01,
|
||||
ulong nb01,
|
||||
ulong nb02,
|
||||
ulong nb03,
|
||||
uint4 ne11,
|
||||
uint4 ne12,
|
||||
ulong nb10,
|
||||
ulong nb11,
|
||||
ulong nb12,
|
||||
int nblk0,
|
||||
ulong nb1,
|
||||
ulong nb2,
|
||||
ulong nb3
|
||||
) {
|
||||
src0 = src0 + offset0;
|
||||
src1 = src1 + offset1;
|
||||
dst = dst + offsetd;
|
||||
|
||||
int i03 = get_group_id(2);
|
||||
int i02 = get_group_id(1);
|
||||
int i01 = get_group_id(0)*get_local_size(1) + get_local_id(1);
|
||||
|
||||
if (i01 >= ne01) {
|
||||
return;
|
||||
}
|
||||
|
||||
int i12 = fastmod(i03, ne12);
|
||||
int i11 = fastmod(i02, ne11);
|
||||
|
||||
int i10 = i01;
|
||||
long i1 = ((global long *)(src1 + i10*nb10 + i11*nb11 + i12*nb12))[0];
|
||||
|
||||
global char * dst_row = (global char *) (dst + i1*nb1 + i02*nb2 + i03*nb3);
|
||||
global float * src_row = (global float *) (src0 + i01*nb01 + i02*nb02 + i03*nb03);
|
||||
|
||||
for (int blk = get_local_id(0); blk < nblk0; blk += get_local_size(0)) {
|
||||
global float * x = src_row + blk * QK8_0;
|
||||
global char * y = dst_row + blk * (2 + QK8_0);
|
||||
|
||||
quantize_q8_0_block(x, y + 2, (global half *)y);
|
||||
}
|
||||
}
|
||||
|
||||
kernel void kernel_set_rows_q8_0_i32(
|
||||
global char * src0,
|
||||
ulong offset0,
|
||||
global char * src1,
|
||||
ulong offset1,
|
||||
global char * dst,
|
||||
ulong offsetd,
|
||||
int ne01,
|
||||
ulong nb01,
|
||||
ulong nb02,
|
||||
ulong nb03,
|
||||
uint4 ne11,
|
||||
uint4 ne12,
|
||||
ulong nb10,
|
||||
ulong nb11,
|
||||
ulong nb12,
|
||||
int nblk0,
|
||||
ulong nb1,
|
||||
ulong nb2,
|
||||
ulong nb3
|
||||
) {
|
||||
src0 = src0 + offset0;
|
||||
src1 = src1 + offset1;
|
||||
dst = dst + offsetd;
|
||||
|
||||
int i03 = get_group_id(2);
|
||||
int i02 = get_group_id(1);
|
||||
int i01 = get_group_id(0)*get_local_size(1) + get_local_id(1);
|
||||
|
||||
if (i01 >= ne01) {
|
||||
return;
|
||||
}
|
||||
|
||||
int i12 = fastmod(i03, ne12);
|
||||
int i11 = fastmod(i02, ne11);
|
||||
|
||||
int i10 = i01;
|
||||
int i1 = ((global int *)(src1 + i10*nb10 + i11*nb11 + i12*nb12))[0];
|
||||
|
||||
global char * dst_row = (global char *) (dst + i1*nb1 + i02*nb2 + i03*nb3);
|
||||
global float * src_row = (global float *) (src0 + i01*nb01 + i02*nb02 + i03*nb03);
|
||||
|
||||
for (int blk = get_local_id(0); blk < nblk0; blk += get_local_size(0)) {
|
||||
global float * x = src_row + blk * QK8_0;
|
||||
global char * y = dst_row + blk * (2 + QK8_0);
|
||||
|
||||
quantize_q8_0_block(x, y + 2, (global half *)y);
|
||||
}
|
||||
}
|
||||
|
||||
// SoA q8_0 variants. dst_q: int8[QK8_0] per block; dst_d: fp16 scale per block.
|
||||
// Layout matches kernel_convert_block_q8_0; block index follows dst element order.
|
||||
kernel void kernel_set_rows_q8_0_soa_i64(
|
||||
global char * src0,
|
||||
ulong offset0,
|
||||
global char * src1,
|
||||
ulong offset1,
|
||||
global char * dst_q,
|
||||
ulong offset_q,
|
||||
global char * dst_d,
|
||||
ulong offset_d,
|
||||
int ne01,
|
||||
ulong nb01,
|
||||
ulong nb02,
|
||||
ulong nb03,
|
||||
uint4 ne11,
|
||||
uint4 ne12,
|
||||
ulong nb10,
|
||||
ulong nb11,
|
||||
ulong nb12,
|
||||
int nblk0,
|
||||
int ne1_dst,
|
||||
int ne2_dst,
|
||||
int ne3_dst
|
||||
) {
|
||||
src0 = src0 + offset0;
|
||||
src1 = src1 + offset1;
|
||||
dst_q = dst_q + offset_q;
|
||||
dst_d = dst_d + offset_d;
|
||||
|
||||
int i03 = get_group_id(2);
|
||||
int i02 = get_group_id(1);
|
||||
int i01 = get_group_id(0)*get_local_size(1) + get_local_id(1);
|
||||
|
||||
if (i01 >= ne01) {
|
||||
return;
|
||||
}
|
||||
|
||||
int i12 = fastmod(i03, ne12);
|
||||
int i11 = fastmod(i02, ne11);
|
||||
|
||||
int i10 = i01;
|
||||
long i1 = ((global long *)(src1 + i10*nb10 + i11*nb11 + i12*nb12))[0];
|
||||
|
||||
long row_blk_base = ((long)i03 * ne2_dst * ne1_dst + (long)i02 * ne1_dst + i1) * nblk0;
|
||||
|
||||
global half * d_row = (global half *)(dst_d) + row_blk_base;
|
||||
global char * q_row = (global char *)(dst_q) + row_blk_base * QK8_0;
|
||||
global float * src_row = (global float *)(src0 + i01*nb01 + i02*nb02 + i03*nb03);
|
||||
|
||||
for (int blk = get_local_id(0); blk < nblk0; blk += get_local_size(0)) {
|
||||
global float * x = src_row + blk * QK8_0;
|
||||
global char * q = q_row + blk * QK8_0;
|
||||
|
||||
quantize_q8_0_block(x, q, d_row + blk);
|
||||
}
|
||||
}
|
||||
|
||||
kernel void kernel_set_rows_q8_0_soa_i32(
|
||||
global char * src0,
|
||||
ulong offset0,
|
||||
global char * src1,
|
||||
ulong offset1,
|
||||
global char * dst_q,
|
||||
ulong offset_q,
|
||||
global char * dst_d,
|
||||
ulong offset_d,
|
||||
int ne01,
|
||||
ulong nb01,
|
||||
ulong nb02,
|
||||
ulong nb03,
|
||||
uint4 ne11,
|
||||
uint4 ne12,
|
||||
ulong nb10,
|
||||
ulong nb11,
|
||||
ulong nb12,
|
||||
int nblk0,
|
||||
int ne1_dst,
|
||||
int ne2_dst,
|
||||
int ne3_dst
|
||||
) {
|
||||
src0 = src0 + offset0;
|
||||
src1 = src1 + offset1;
|
||||
dst_q = dst_q + offset_q;
|
||||
dst_d = dst_d + offset_d;
|
||||
|
||||
int i03 = get_group_id(2);
|
||||
int i02 = get_group_id(1);
|
||||
int i01 = get_group_id(0)*get_local_size(1) + get_local_id(1);
|
||||
|
||||
if (i01 >= ne01) {
|
||||
return;
|
||||
}
|
||||
|
||||
int i12 = fastmod(i03, ne12);
|
||||
int i11 = fastmod(i02, ne11);
|
||||
|
||||
int i10 = i01;
|
||||
int i1 = ((global int *)(src1 + i10*nb10 + i11*nb11 + i12*nb12))[0];
|
||||
|
||||
long row_blk_base = ((long)i03 * ne2_dst * ne1_dst + (long)i02 * ne1_dst + i1) * nblk0;
|
||||
|
||||
global half * d_row = (global half *)(dst_d) + row_blk_base;
|
||||
global char * q_row = (global char *)(dst_q) + row_blk_base * QK8_0;
|
||||
global float * src_row = (global float *)(src0 + i01*nb01 + i02*nb02 + i03*nb03);
|
||||
|
||||
for (int blk = get_local_id(0); blk < nblk0; blk += get_local_size(0)) {
|
||||
global float * x = src_row + blk * QK8_0;
|
||||
global char * q = q_row + blk * QK8_0;
|
||||
|
||||
quantize_q8_0_block(x, q, d_row + blk);
|
||||
}
|
||||
}
|
||||
|
||||
kernel void kernel_set_rows_f16_i32(
|
||||
global char * src0,
|
||||
ulong offset0,
|
||||
@@ -206,3 +439,270 @@ kernel void kernel_set_rows_f16_i32(
|
||||
dst_row[ind] = src_row[ind];
|
||||
}
|
||||
}
|
||||
|
||||
// f32 -> q4_0 quantize set_rows. Block = half d + uchar qs[16] (shuffled
|
||||
// nibbles: qs[j] low/high = elem j / j+16).
|
||||
// Dequant: val[i] = d * (nibble_i - 8)
|
||||
// nblk0 = number of q4_0 blocks per row = ne00 / 32.
|
||||
#define QK4_0 32
|
||||
#define Q4_0_BLOCK_SIZE 18
|
||||
|
||||
inline void quantize_q4_0_block(global float * x, global uchar * qs, global half * d_out) {
|
||||
// Find the signed value with the largest absolute magnitude (matches ggml ref).
|
||||
float max = 0.0f;
|
||||
float amax = 0.0f;
|
||||
for (int j = 0; j < QK4_0; j++) {
|
||||
float v = x[j];
|
||||
float a = fabs(v);
|
||||
if (a > amax) {
|
||||
amax = a;
|
||||
max = v;
|
||||
}
|
||||
}
|
||||
|
||||
float d = max / -8.0f;
|
||||
float id = (d != 0.0f) ? 1.0f / d : 0.0f;
|
||||
|
||||
vstore_half(d, 0, d_out);
|
||||
|
||||
for (int j = 0; j < QK4_0/2; j++) {
|
||||
float x0 = x[j] * id;
|
||||
float x1 = x[j + QK4_0/2] * id;
|
||||
|
||||
int i0 = (int)(x0 + 8.5f);
|
||||
int i1 = (int)(x1 + 8.5f);
|
||||
if (i0 < 0) i0 = 0;
|
||||
if (i0 > 15) i0 = 15;
|
||||
if (i1 < 0) i1 = 0;
|
||||
if (i1 > 15) i1 = 15;
|
||||
|
||||
qs[j] = (uchar)i0 | ((uchar)i1 << 4);
|
||||
}
|
||||
}
|
||||
|
||||
kernel void kernel_set_rows_q4_0_i64(
|
||||
global char * src0,
|
||||
ulong offset0,
|
||||
global char * src1,
|
||||
ulong offset1,
|
||||
global char * dst,
|
||||
ulong offsetd,
|
||||
int ne01,
|
||||
ulong nb01,
|
||||
ulong nb02,
|
||||
ulong nb03,
|
||||
uint4 ne11,
|
||||
uint4 ne12,
|
||||
ulong nb10,
|
||||
ulong nb11,
|
||||
ulong nb12,
|
||||
int nblk0,
|
||||
ulong nb1,
|
||||
ulong nb2,
|
||||
ulong nb3
|
||||
) {
|
||||
src0 = src0 + offset0;
|
||||
src1 = src1 + offset1;
|
||||
dst = dst + offsetd;
|
||||
|
||||
int i03 = get_group_id(2);
|
||||
int i02 = get_group_id(1);
|
||||
int i01 = get_group_id(0)*get_local_size(1) + get_local_id(1);
|
||||
|
||||
if (i01 >= ne01) {
|
||||
return;
|
||||
}
|
||||
|
||||
int i12 = fastmod(i03, ne12);
|
||||
int i11 = fastmod(i02, ne11);
|
||||
|
||||
int i10 = i01;
|
||||
long i1 = ((global long *)(src1 + i10*nb10 + i11*nb11 + i12*nb12))[0];
|
||||
|
||||
global char * dst_row = (global char *) (dst + i1*nb1 + i02*nb2 + i03*nb3);
|
||||
global float * src_row = (global float *) (src0 + i01*nb01 + i02*nb02 + i03*nb03);
|
||||
|
||||
for (int blk = get_local_id(0); blk < nblk0; blk += get_local_size(0)) {
|
||||
global float * x = src_row + blk * QK4_0;
|
||||
global char * y = dst_row + blk * Q4_0_BLOCK_SIZE;
|
||||
global half * yd = (global half *)(y);
|
||||
global uchar * yqs = (global uchar *)(y + 2);
|
||||
|
||||
quantize_q4_0_block(x, yqs, yd);
|
||||
}
|
||||
}
|
||||
|
||||
kernel void kernel_set_rows_q4_0_i32(
|
||||
global char * src0,
|
||||
ulong offset0,
|
||||
global char * src1,
|
||||
ulong offset1,
|
||||
global char * dst,
|
||||
ulong offsetd,
|
||||
int ne01,
|
||||
ulong nb01,
|
||||
ulong nb02,
|
||||
ulong nb03,
|
||||
uint4 ne11,
|
||||
uint4 ne12,
|
||||
ulong nb10,
|
||||
ulong nb11,
|
||||
ulong nb12,
|
||||
int nblk0,
|
||||
ulong nb1,
|
||||
ulong nb2,
|
||||
ulong nb3
|
||||
) {
|
||||
src0 = src0 + offset0;
|
||||
src1 = src1 + offset1;
|
||||
dst = dst + offsetd;
|
||||
|
||||
int i03 = get_group_id(2);
|
||||
int i02 = get_group_id(1);
|
||||
int i01 = get_group_id(0)*get_local_size(1) + get_local_id(1);
|
||||
|
||||
if (i01 >= ne01) {
|
||||
return;
|
||||
}
|
||||
|
||||
int i12 = fastmod(i03, ne12);
|
||||
int i11 = fastmod(i02, ne11);
|
||||
|
||||
int i10 = i01;
|
||||
int i1 = ((global int *)(src1 + i10*nb10 + i11*nb11 + i12*nb12))[0];
|
||||
|
||||
global char * dst_row = (global char *) (dst + i1*nb1 + i02*nb2 + i03*nb3);
|
||||
global float * src_row = (global float *) (src0 + i01*nb01 + i02*nb02 + i03*nb03);
|
||||
|
||||
for (int blk = get_local_id(0); blk < nblk0; blk += get_local_size(0)) {
|
||||
global float * x = src_row + blk * QK4_0;
|
||||
global char * y = dst_row + blk * Q4_0_BLOCK_SIZE;
|
||||
global half * yd = (global half *)(y);
|
||||
global uchar * yqs = (global uchar *)(y + 2);
|
||||
|
||||
quantize_q4_0_block(x, yqs, yd);
|
||||
}
|
||||
}
|
||||
|
||||
// SoA variants for q4_0 dst. Used when the backend has split block_q4_0 records
|
||||
// into separate quant (dst_q) and scale (dst_d) sub-buffers — same pattern as
|
||||
// the q8_0 SoA variants above.
|
||||
//
|
||||
// Layout (matches kernel_convert_block_q4_0, the "shuffled" variant):
|
||||
// dst_q: contiguous 16 packed nibbles per block, block i at offset i * 16 bytes.
|
||||
// dst_d: contiguous fp16 scales, block i at offset i * 2 bytes.
|
||||
// Nibble layout inside each byte is unchanged from AoS: qs[j] low nibble = element j,
|
||||
// qs[j] high nibble = element j+16. kernel_restore_block_q4_0 copies bytes as-is.
|
||||
kernel void kernel_set_rows_q4_0_soa_i64(
|
||||
global char * src0,
|
||||
ulong offset0,
|
||||
global char * src1,
|
||||
ulong offset1,
|
||||
global char * dst_q,
|
||||
ulong offset_q,
|
||||
global char * dst_d,
|
||||
ulong offset_d,
|
||||
int ne01,
|
||||
ulong nb01,
|
||||
ulong nb02,
|
||||
ulong nb03,
|
||||
uint4 ne11,
|
||||
uint4 ne12,
|
||||
ulong nb10,
|
||||
ulong nb11,
|
||||
ulong nb12,
|
||||
int nblk0,
|
||||
int ne1_dst,
|
||||
int ne2_dst,
|
||||
int ne3_dst
|
||||
) {
|
||||
src0 = src0 + offset0;
|
||||
src1 = src1 + offset1;
|
||||
dst_q = dst_q + offset_q;
|
||||
dst_d = dst_d + offset_d;
|
||||
|
||||
int i03 = get_group_id(2);
|
||||
int i02 = get_group_id(1);
|
||||
int i01 = get_group_id(0)*get_local_size(1) + get_local_id(1);
|
||||
|
||||
if (i01 >= ne01) {
|
||||
return;
|
||||
}
|
||||
|
||||
int i12 = fastmod(i03, ne12);
|
||||
int i11 = fastmod(i02, ne11);
|
||||
|
||||
int i10 = i01;
|
||||
long i1 = ((global long *)(src1 + i10*nb10 + i11*nb11 + i12*nb12))[0];
|
||||
|
||||
long row_blk_base = ((long)i03 * ne2_dst * ne1_dst + (long)i02 * ne1_dst + i1) * nblk0;
|
||||
|
||||
global half * d_row = (global half *)(dst_d) + row_blk_base;
|
||||
global uchar * q_row = (global uchar *)(dst_q) + row_blk_base * (QK4_0/2);
|
||||
global float * src_row = (global float *)(src0 + i01*nb01 + i02*nb02 + i03*nb03);
|
||||
|
||||
for (int blk = get_local_id(0); blk < nblk0; blk += get_local_size(0)) {
|
||||
global float * x = src_row + blk * QK4_0;
|
||||
global uchar * qs = q_row + blk * (QK4_0/2);
|
||||
global half * d_bk = d_row + blk;
|
||||
|
||||
quantize_q4_0_block(x, qs, d_bk);
|
||||
}
|
||||
}
|
||||
|
||||
kernel void kernel_set_rows_q4_0_soa_i32(
|
||||
global char * src0,
|
||||
ulong offset0,
|
||||
global char * src1,
|
||||
ulong offset1,
|
||||
global char * dst_q,
|
||||
ulong offset_q,
|
||||
global char * dst_d,
|
||||
ulong offset_d,
|
||||
int ne01,
|
||||
ulong nb01,
|
||||
ulong nb02,
|
||||
ulong nb03,
|
||||
uint4 ne11,
|
||||
uint4 ne12,
|
||||
ulong nb10,
|
||||
ulong nb11,
|
||||
ulong nb12,
|
||||
int nblk0,
|
||||
int ne1_dst,
|
||||
int ne2_dst,
|
||||
int ne3_dst
|
||||
) {
|
||||
src0 = src0 + offset0;
|
||||
src1 = src1 + offset1;
|
||||
dst_q = dst_q + offset_q;
|
||||
dst_d = dst_d + offset_d;
|
||||
|
||||
int i03 = get_group_id(2);
|
||||
int i02 = get_group_id(1);
|
||||
int i01 = get_group_id(0)*get_local_size(1) + get_local_id(1);
|
||||
|
||||
if (i01 >= ne01) {
|
||||
return;
|
||||
}
|
||||
|
||||
int i12 = fastmod(i03, ne12);
|
||||
int i11 = fastmod(i02, ne11);
|
||||
|
||||
int i10 = i01;
|
||||
int i1 = ((global int *)(src1 + i10*nb10 + i11*nb11 + i12*nb12))[0];
|
||||
|
||||
long row_blk_base = ((long)i03 * ne2_dst * ne1_dst + (long)i02 * ne1_dst + i1) * nblk0;
|
||||
|
||||
global half * d_row = (global half *)(dst_d) + row_blk_base;
|
||||
global uchar * q_row = (global uchar *)(dst_q) + row_blk_base * (QK4_0/2);
|
||||
global float * src_row = (global float *)(src0 + i01*nb01 + i02*nb02 + i03*nb03);
|
||||
|
||||
for (int blk = get_local_id(0); blk < nblk0; blk += get_local_size(0)) {
|
||||
global float * x = src_row + blk * QK4_0;
|
||||
global uchar * qs = q_row + blk * (QK4_0/2);
|
||||
global half * d_bk = d_row + blk;
|
||||
|
||||
quantize_q4_0_block(x, qs, d_bk);
|
||||
}
|
||||
}
|
||||
|
||||
@@ -1053,6 +1053,10 @@ static bool is_op_unsupported_case(const ggml_tensor * op) {
|
||||
(op->ne[0] == 2 && op->ne[1] == 4 && op->ne[2] == 3 && op->ne[3] == 2)) {
|
||||
return true;
|
||||
}
|
||||
// CPY into a strided view of a larger buffer (recurrent-state snapshots) not supported
|
||||
if (op->view_src && ggml_nbytes(op) != ggml_nbytes(op->view_src)) {
|
||||
return true;
|
||||
}
|
||||
break;
|
||||
}
|
||||
case GGML_OP_MUL_MAT: {
|
||||
|
||||
@@ -256,7 +256,7 @@ llama_context::llama_context(
|
||||
LLAMA_LOG_INFO("%s: n_outputs_max = %u\n", __func__, cparams.n_outputs_max);
|
||||
|
||||
if (cparams.n_ctx_seq < hparams.n_ctx_train) {
|
||||
LLAMA_LOG_WARN("%s: n_ctx_seq (%u) < n_ctx_train (%u) -- the full capacity of the model will not be utilized\n",
|
||||
LLAMA_LOG_INFO("%s: n_ctx_seq (%u) < n_ctx_train (%u) -- the full capacity of the model will not be utilized\n",
|
||||
__func__, cparams.n_ctx_seq, hparams.n_ctx_train);
|
||||
}
|
||||
|
||||
|
||||
@@ -2890,12 +2890,17 @@ struct test_cpy : public test_case {
|
||||
const std::array<int64_t, 4> ne_dst;
|
||||
const std::array<int64_t, 4> permute_src;
|
||||
const std::array<int64_t, 4> permute_dst;
|
||||
const std::array<int64_t, 4> dst_alloc; // if set, dst is a view into a larger buffer (strided)
|
||||
bool _src_use_permute;
|
||||
bool _dst_use_permute;
|
||||
bool _src_transpose;
|
||||
bool _use_dst_shape;
|
||||
bool _use_dst_alloc;
|
||||
|
||||
std::string vars() override {
|
||||
if (_use_dst_alloc) {
|
||||
return VARS_TO_STR8(type_src, type_dst, ne_src, ne_dst, permute_src, permute_dst, _src_transpose, dst_alloc);
|
||||
}
|
||||
if (_use_dst_shape) {
|
||||
return VARS_TO_STR7(type_src, type_dst, ne_src, ne_dst, permute_src, permute_dst, _src_transpose);
|
||||
}
|
||||
@@ -2943,12 +2948,15 @@ struct test_cpy : public test_case {
|
||||
std::array<int64_t, 4> ne_dst = {-1, -1, -1, -1},
|
||||
std::array<int64_t, 4> permute_src = {0, 0, 0, 0},
|
||||
std::array<int64_t, 4> permute_dst = {0, 0, 0, 0},
|
||||
bool transpose_src = false)
|
||||
bool transpose_src = false,
|
||||
std::array<int64_t, 4> dst_alloc = {0, 0, 0, 0})
|
||||
: type_src(type_src), type_dst(type_dst), ne_src(ne_src), ne_dst(ne_dst), permute_src(permute_src), permute_dst(permute_dst),
|
||||
dst_alloc(dst_alloc),
|
||||
_src_use_permute(permute_src[0] + permute_src[1] + permute_src[2] + permute_src[3] > 0),
|
||||
_dst_use_permute(permute_dst[0] + permute_dst[1] + permute_dst[2] + permute_dst[3] > 0),
|
||||
_src_transpose(transpose_src),
|
||||
_use_dst_shape(ne_dst[0] >= 0 && ne_dst[1] >= 0 && ne_dst[2] >= 0 && ne_dst[3] >= 0){}
|
||||
_use_dst_shape(ne_dst[0] >= 0 && ne_dst[1] >= 0 && ne_dst[2] >= 0 && ne_dst[3] >= 0),
|
||||
_use_dst_alloc(dst_alloc[0] > 0){}
|
||||
|
||||
ggml_tensor * build_graph(ggml_context * ctx) override {
|
||||
ggml_tensor * src = ggml_new_tensor(ctx, type_src, 4, ne_src.data());
|
||||
@@ -2966,12 +2974,23 @@ struct test_cpy : public test_case {
|
||||
}
|
||||
|
||||
std::array<int64_t, 4> dst_ne = _use_dst_shape ? ne_dst : std::array<int64_t, 4>{src->ne[0], src->ne[1], src->ne[2], src->ne[3]};
|
||||
ggml_tensor * dst = ggml_new_tensor(ctx, type_dst, 4, dst_ne.data());
|
||||
ggml_set_name(dst, "dst");
|
||||
ggml_tensor * dst;
|
||||
|
||||
if (_dst_use_permute) {
|
||||
dst = ggml_permute(ctx, dst, permute_dst[0], permute_dst[1], permute_dst[2], permute_dst[3]);
|
||||
ggml_set_name(dst, "dst_permuted");
|
||||
if (_use_dst_alloc) {
|
||||
// view a sub-block of a larger buffer -> strided dst
|
||||
ggml_tensor * dst_buf = ggml_new_tensor(ctx, type_dst, 4, dst_alloc.data());
|
||||
ggml_set_name(dst_buf, "dst_buf");
|
||||
dst = ggml_view_4d(ctx, dst_buf, dst_ne[0], dst_ne[1], dst_ne[2], dst_ne[3],
|
||||
dst_buf->nb[1], dst_buf->nb[2], dst_buf->nb[3], 0);
|
||||
ggml_set_name(dst, "dst_view");
|
||||
} else {
|
||||
dst = ggml_new_tensor(ctx, type_dst, 4, dst_ne.data());
|
||||
ggml_set_name(dst, "dst");
|
||||
|
||||
if (_dst_use_permute) {
|
||||
dst = ggml_permute(ctx, dst, permute_dst[0], permute_dst[1], permute_dst[2], permute_dst[3]);
|
||||
ggml_set_name(dst, "dst_permuted");
|
||||
}
|
||||
}
|
||||
|
||||
ggml_tensor * out = ggml_cpy(ctx, src, dst);
|
||||
@@ -8181,6 +8200,8 @@ static std::vector<std::unique_ptr<test_case>> make_test_cases_eval() {
|
||||
test_cases.emplace_back(new test_cpy(GGML_TYPE_F32, GGML_TYPE_F32, {256, 1, 4, 1}, {-1,-1,-1,-1}, {1, 2, 0, 3}, {0, 0, 0, 0}));
|
||||
test_cases.emplace_back(new test_cpy(GGML_TYPE_F32, GGML_TYPE_F32, {2, 2097121, 1, 1}, {-1,-1,-1,-1}, {1, 0, 2, 3}));
|
||||
test_cases.emplace_back(new test_cpy(GGML_TYPE_F32, GGML_TYPE_F32, {2, 2, 524281, 1}, {-1,-1,-1,-1}, {1, 0, 2, 3}));
|
||||
test_cases.emplace_back(new test_cpy(GGML_TYPE_F32, GGML_TYPE_F32, {128, 2, 3, 1}, {128, 2, 3, 1}, {0, 0, 0, 0}, {0, 0, 0, 0}, false, {128, 4, 3, 1})); // strided dst
|
||||
test_cases.emplace_back(new test_cpy(GGML_TYPE_F16, GGML_TYPE_F16, {128, 2, 3, 1}, {128, 2, 3, 1}, {0, 0, 0, 0}, {0, 0, 0, 0}, false, {128, 4, 3, 1})); // strided dst
|
||||
|
||||
// CPY - different src/dst shapes (reshaping via CPY)
|
||||
// Use permutations of {3, 5, 7, 32}. Total elements: 3*5*7*32 = 3360.
|
||||
|
||||
@@ -106,7 +106,6 @@ struct server_batch {
|
||||
if ((int32_t)tokens.size() >= n_tokens_alloc) {
|
||||
return false;
|
||||
}
|
||||
// LOG_INF("adding token to batch: slot=%d, token=%d, pos=%d, output=%d\n", id_slot, token, pos, output);
|
||||
tokens.push_back({ id_slot, token, pos, output });
|
||||
return true;
|
||||
}
|
||||
@@ -228,7 +227,7 @@ struct server_slot {
|
||||
|
||||
const size_t cur_size = cur_size_tgt + cur_size_dft;
|
||||
|
||||
SRV_WRN(" - saving prompt with length %d, total state size = %.3f MiB (draft: %.3f MiB)\n",
|
||||
SRV_TRC(" - saving prompt with length %d, total state size = %.3f MiB (draft: %.3f MiB)\n",
|
||||
(int) prompt.tokens.size(), cur_size / (1024.0 * 1024.0), cur_size_dft / (1024.0 * 1024.0));
|
||||
|
||||
auto * cur = prompt_cache.alloc(prompt, cur_size_tgt, cur_size_dft);
|
||||
@@ -258,7 +257,7 @@ struct server_slot {
|
||||
GGML_ASSERT(!is_processing());
|
||||
}
|
||||
|
||||
SLT_INF(*this, "clearing prompt with %zu tokens\n", prompt.tokens.size());
|
||||
SLT_TRC(*this, "clearing prompt with %zu tokens\n", prompt.tokens.size());
|
||||
|
||||
common_context_seq_rm(ctx_tgt, id, -1, -1);
|
||||
if (ctx_dft) {
|
||||
@@ -627,8 +626,10 @@ struct server_slot {
|
||||
}
|
||||
|
||||
SLT_INF(*this,
|
||||
"draft acceptance = %0.5f (%5d accepted / %5d generated), mean acceptance length = %5.2f, acceptance rate per position = (%s)\n",
|
||||
draft_ratio, n_draft_accepted, n_draft_total, mean_acc_len, acceptance_rates_per_pos.c_str());
|
||||
"draft acceptance = %0.5f (%5d accepted / %5d generated), mean len = %5.2f\n",
|
||||
draft_ratio, n_draft_accepted, n_draft_total, mean_acc_len);
|
||||
SLT_TRC(*this,
|
||||
" acc per pos = (%s)\n", acceptance_rates_per_pos.c_str());
|
||||
}
|
||||
|
||||
common_speculative_print_stats(spec);
|
||||
@@ -771,7 +772,7 @@ struct server_slot {
|
||||
}
|
||||
|
||||
// TODO @ngxson : move this log line to debug when it become more stable
|
||||
SLT_INF(*this, "encoding mtmd batch from idx = %zu, n_chunks = %d\n", idx, n_added);
|
||||
SLT_TRC(*this, "encoding mtmd batch from idx = %zu, n_chunks = %d\n", idx, n_added);
|
||||
|
||||
res = mtmd_batch_encode(mbatch.get());
|
||||
if (res != 0) {
|
||||
@@ -1032,7 +1033,8 @@ private:
|
||||
}
|
||||
|
||||
|
||||
SRV_INF("loading model '%s'\n", params.model.path.c_str());
|
||||
SRV_INF("loading model '%s'\n", params.model.get_name().c_str());
|
||||
SRV_TRC("local path '%s'\n", params.model.path.c_str());
|
||||
|
||||
std::string & mmproj_path = params_base.mmproj.path;
|
||||
mtmd_context_params mparams = mtmd_context_params_default();
|
||||
@@ -1061,7 +1063,7 @@ private:
|
||||
for (auto & [dev, size] : mmproj_mem) {
|
||||
total += size;
|
||||
}
|
||||
SRV_INF("[mtmd] estimated worst-case memory usage of mmproj is %.2f MiB (took %.2f ms)\n", total / (1024.0 * 1024.0), t_elapsed / 1000.0);
|
||||
SRV_TRC("[mtmd] estimated worst-case memory usage of mmproj is %.2f MiB (took %.2f ms)\n", total / (1024.0 * 1024.0), t_elapsed / 1000.0);
|
||||
GGML_ASSERT(!params_base.fit_params_target.empty());
|
||||
for (auto & [dev, size] : mmproj_mem) {
|
||||
for (size_t i = 0; i < ggml_backend_dev_count(); i++) {
|
||||
@@ -1141,7 +1143,7 @@ private:
|
||||
}
|
||||
}
|
||||
}
|
||||
SRV_INF("[spec] estimated memory usage of %s is %.2f MiB\n",
|
||||
SRV_TRC("[spec] estimated memory usage of %s is %.2f MiB\n",
|
||||
has_draft ? "draft model" : "MTP context",
|
||||
total / (1024.0 * 1024.0));
|
||||
} catch (const std::exception & e) {
|
||||
@@ -1177,7 +1179,7 @@ private:
|
||||
// TODO speculative: move to common/speculative.cpp?
|
||||
const auto & params_spec = params_base.speculative.draft;
|
||||
|
||||
SRV_INF("loading draft model '%s'\n", params_spec.mparams.path.c_str());
|
||||
SRV_TRC("loading draft model '%s'\n", params_spec.mparams.path.c_str());
|
||||
|
||||
auto params_dft = params_base;
|
||||
|
||||
@@ -1229,7 +1231,7 @@ private:
|
||||
// no new model load, so we simply report 0.0 and 1.0 progress
|
||||
load_progress_callback(0.0f, &load_progress_spec);
|
||||
|
||||
SRV_INF("creating MTP draft context against the target model '%s'\n",
|
||||
SRV_TRC("creating MTP draft context against the target model '%s'\n",
|
||||
params_base.model.path.c_str());
|
||||
|
||||
auto cparams_mtp = common_context_params_to_llama(params_base);
|
||||
@@ -1303,9 +1305,6 @@ private:
|
||||
// Necessary similarity of prompt for slot selection
|
||||
slot_prompt_similarity = params_base.slot_prompt_similarity;
|
||||
|
||||
// setup slots
|
||||
SRV_INF("initializing slots, n_slots = %d\n", params_base.n_parallel);
|
||||
|
||||
const int n_ctx_train = llama_model_n_ctx_train(model_tgt);
|
||||
|
||||
int n_ctx_slot = llama_n_ctx_seq(ctx_tgt);
|
||||
@@ -1322,9 +1321,13 @@ private:
|
||||
}
|
||||
|
||||
if (ctx_tgt_seq_rm_type == COMMON_CONTEXT_SEQ_RM_TYPE_FULL) {
|
||||
SRV_WRN("%s", "speculative decoding will use checkpoints\n");
|
||||
SRV_TRC("%s", "speculative decoding will use checkpoints\n");
|
||||
}
|
||||
|
||||
// setup slots
|
||||
SRV_INF("initializing, n_slots = %d, n_ctx_slot = %d, kv_unified = '%s'\n",
|
||||
params_base.n_parallel, n_ctx_slot, params_base.kv_unified ? "true" : "false");
|
||||
|
||||
// initialize slots
|
||||
for (int i = 0; i < params_base.n_parallel; i++) {
|
||||
slots.emplace_back();
|
||||
@@ -1344,7 +1347,7 @@ private:
|
||||
}
|
||||
|
||||
if (spec) {
|
||||
SRV_INF("%s", "speculative decoding context initialized\n");
|
||||
SRV_TRC("%s", "speculative decoding context initialized\n");
|
||||
} else {
|
||||
ctx_dft.reset();
|
||||
}
|
||||
@@ -1361,7 +1364,7 @@ private:
|
||||
slot.mctx = mctx;
|
||||
slot.prompt.tokens.has_mtmd = mctx != nullptr;
|
||||
|
||||
SLT_INF(slot, "new slot, n_ctx = %d\n", slot.n_ctx);
|
||||
SLT_TRC(slot, "new slot, n_ctx = %d\n", slot.n_ctx);
|
||||
|
||||
slot.callback_on_release = [this](int id_slot) {
|
||||
queue_tasks.pop_deferred_task(id_slot);
|
||||
@@ -1397,23 +1400,23 @@ private:
|
||||
|
||||
if (params_base.cache_ram_mib != 0) {
|
||||
if (params_base.cache_ram_mib < 0) {
|
||||
SRV_INF("prompt cache is enabled, size limit: %s\n", "no limit");
|
||||
SRV_TRC("prompt cache is enabled, size limit: %s\n", "no limit");
|
||||
} else {
|
||||
SRV_INF("prompt cache is enabled, size limit: %d MiB\n", params_base.cache_ram_mib);
|
||||
SRV_TRC("prompt cache is enabled, size limit: %d MiB\n", params_base.cache_ram_mib);
|
||||
}
|
||||
SRV_INF("%s", "use `--cache-ram 0` to disable the prompt cache\n");
|
||||
SRV_TRC("%s", "use `--cache-ram 0` to disable the prompt cache\n");
|
||||
|
||||
prompt_cache = std::make_unique<server_prompt_cache>(params_base.cache_ram_mib, n_ctx);
|
||||
} else {
|
||||
SRV_INF("%s", "prompt cache is disabled - use `--cache-ram N` to enable it\n");
|
||||
SRV_TRC("%s", "prompt cache is disabled - use `--cache-ram N` to enable it\n");
|
||||
}
|
||||
SRV_INF("%s", "for more info see https://github.com/ggml-org/llama.cpp/pull/16391\n");
|
||||
SRV_TRC("%s", "for more info see https://github.com/ggml-org/llama.cpp/pull/16391\n");
|
||||
|
||||
if (params_base.n_ctx_checkpoints > 0) {
|
||||
SRV_INF("context checkpoints enabled, max = %d, min spacing = %d\n",
|
||||
SRV_TRC("context checkpoints enabled, max = %d, min spacing = %d\n",
|
||||
params_base.n_ctx_checkpoints, params_base.checkpoint_min_step);
|
||||
} else {
|
||||
SRV_INF("%s", "context checkpoints disabled\n");
|
||||
SRV_TRC("%s", "context checkpoints disabled\n");
|
||||
}
|
||||
|
||||
if (!params_base.model_alias.empty()) {
|
||||
@@ -1470,11 +1473,11 @@ private:
|
||||
params_base.cache_idle_slots = false;
|
||||
} else {
|
||||
if (params_base.kv_unified) {
|
||||
SRV_INF("%s", "idle slots will be saved to prompt cache and cleared upon starting a new task\n");
|
||||
SRV_TRC("%s", "idle slots will be saved to prompt cache and cleared upon starting a new task\n");
|
||||
} else {
|
||||
// without a unified KV cache, clearing a slot frees no reusable room, so we only
|
||||
// publish a RAM-cache copy of idle slots (their KV stays in VRAM) [TAG_IDLE_SLOT_CLEAR]
|
||||
SRV_INF("%s", "idle slots will be saved to prompt cache upon starting a new task\n");
|
||||
SRV_TRC("%s", "idle slots will be saved to prompt cache upon starting a new task\n");
|
||||
}
|
||||
SRV_DBG("%s", "__TEST_TAG_CACHE_IDLE_SLOTS_ENABLED__\n");
|
||||
}
|
||||
@@ -1500,7 +1503,7 @@ private:
|
||||
try {
|
||||
chat_templates = common_chat_templates_init(model_tgt, params_base.chat_template);
|
||||
|
||||
LOG_INF("%s: chat template, example_format: '%s'\n", __func__,
|
||||
SRV_TRC("%s: chat template, example_format: '%s'\n", __func__,
|
||||
common_chat_format_example(chat_templates.get(), params_base.use_jinja, params_base.default_template_kwargs).c_str());
|
||||
|
||||
} catch (const std::exception & e) {
|
||||
@@ -1515,7 +1518,7 @@ private:
|
||||
// 2. The chat template supports it
|
||||
const bool template_supports_thinking = params_base.use_jinja && common_chat_templates_support_enable_thinking(chat_templates.get());
|
||||
const bool enable_thinking = params_base.enable_reasoning != 0 && template_supports_thinking;
|
||||
SRV_INF("%s: chat template, thinking = %d\n", __func__, enable_thinking);
|
||||
SRV_TRC("%s: chat template, thinking = %d\n", __func__, enable_thinking);
|
||||
|
||||
// IMPORTANT: chat_params is reused across sleeping / resuming states,
|
||||
// never store llama_context/llama_model pointers in chat_params,
|
||||
@@ -1658,7 +1661,7 @@ private:
|
||||
update_cache = update_cache && task.type == SERVER_TASK_TYPE_COMPLETION;
|
||||
|
||||
if (update_cache) {
|
||||
SRV_INF("%s", "updating prompt cache\n");
|
||||
SRV_TRC("%s", "updating prompt cache\n");
|
||||
|
||||
const int64_t t_start = ggml_time_us();
|
||||
|
||||
@@ -1670,7 +1673,7 @@ private:
|
||||
|
||||
prompt_cache->update();
|
||||
|
||||
SRV_INF("prompt cache update took %.2f ms\n", (ggml_time_us() - t_start) / 1000.0);
|
||||
SRV_TRC("prompt cache update took %.2f ms\n", (ggml_time_us() - t_start) / 1000.0);
|
||||
}
|
||||
}
|
||||
|
||||
@@ -2290,7 +2293,7 @@ private:
|
||||
|
||||
int id_parent = parent_task.id;
|
||||
|
||||
SRV_INF("launching slots for parent task id_task = %d with %zu child tasks\n", id_parent, parent_task.child_tasks.size());
|
||||
SRV_TRC("launching slots for parent task id_task = %d with %zu child tasks\n", id_parent, parent_task.child_tasks.size());
|
||||
|
||||
// to be called in case of failure to release all launched slots
|
||||
auto release_slots = [this, id_parent]() {
|
||||
@@ -2351,7 +2354,7 @@ private:
|
||||
// stash the draft's speculative state with the checkpoint
|
||||
common_speculative_get_state(spec.get(), slot.id, cur.data_spec);
|
||||
|
||||
SLT_INF(slot,
|
||||
SLT_TRC(slot,
|
||||
"created context checkpoint %d of %d (pos_min = %d, pos_max = %d, n_tokens = %" PRId64 ", size = %.3f MiB)\n",
|
||||
(int) slot.prompt.checkpoints.size(), params_base.n_ctx_checkpoints, cur.pos_min,
|
||||
cur.pos_max, cur.n_tokens, (float) cur.size() / 1024 / 1024);
|
||||
@@ -2415,7 +2418,7 @@ private:
|
||||
if (params_base.cache_idle_slots) {
|
||||
for (auto & slot : slots) {
|
||||
if (!slot.is_processing()) {
|
||||
SLT_INF(slot, "%s", "saving idle slot to prompt cache\n");
|
||||
SLT_TRC(slot, "%s", "saving idle slot to prompt cache\n");
|
||||
|
||||
if (slot.prompt_save(*prompt_cache)) {
|
||||
SLT_DBG(slot, "%s", "__TEST_TAG_CACHE_IDLE_SLOT__\n");
|
||||
@@ -2671,7 +2674,7 @@ private:
|
||||
auto new_loras = construct_lora_list(task.set_lora);
|
||||
// logging
|
||||
for (size_t i = 0; i < new_loras.size(); ++i) {
|
||||
SRV_INF("set lora adapter idx=%zu scale=%f\n", i, new_loras[i].scale);
|
||||
SRV_TRC("set lora adapter idx=%zu scale=%f\n", i, new_loras[i].scale);
|
||||
}
|
||||
// TODO @ngxson : make lora_adapters a dedicated member of server_context
|
||||
params_base.lora_adapters = new_loras;
|
||||
@@ -2771,7 +2774,7 @@ private:
|
||||
}
|
||||
|
||||
if (all_idle) {
|
||||
SRV_INF("%s", "all slots are idle\n");
|
||||
SRV_TRC("%s", "all slots are idle\n");
|
||||
return; // skip further processing
|
||||
|
||||
} else {
|
||||
@@ -3287,10 +3290,9 @@ private:
|
||||
const auto it = std::find_if(
|
||||
slot.prompt.checkpoints.rbegin(),
|
||||
slot.prompt.checkpoints.rend(),
|
||||
[&, func_name = __func__](const auto & cur) {
|
||||
[&](const auto & cur) {
|
||||
// guarantee that a checkpoint will result in at least one token being processed [TAG_PROMPT_LOGITS]
|
||||
LOG_INF("slot %12.*s: id %2d | task %d | Checking checkpoint with [%d, %d] against %d...\n", 12,
|
||||
func_name, (slot).id, ((slot).task ? (slot).task->id : -1), cur.pos_min, cur.pos_max, pos_min_thold);
|
||||
SLT_TRC(slot, "checking checkpoint with [%d, %d] against %d...\n", cur.pos_min, cur.pos_max, pos_min_thold);
|
||||
// workaround for [TAG_CHECKPOINTS_FIX_POS_MIN]
|
||||
if (cur.pos_max > pos_next) {
|
||||
return false;
|
||||
@@ -3310,11 +3312,11 @@ private:
|
||||
|
||||
pos_next = std::min(pos_next, std::max(it->pos_min + 1, it->pos_max));
|
||||
n_past = std::min(slot.prompt.tokens.size_up_to_pos(pos_next), (size_t) it->n_tokens);
|
||||
SLT_WRN(slot, "restored context checkpoint (pos_min = %d, pos_max = %d, n_tokens = %" PRId64 ", n_past = %d, size = %.3f MiB)\n", it->pos_min, it->pos_max, it->n_tokens, n_past, (float) it->size() / 1024 / 1024);
|
||||
SLT_TRC(slot, "restored context checkpoint (pos_min = %d, pos_max = %d, n_tokens = %" PRId64 ", n_past = %d, size = %.3f MiB)\n", it->pos_min, it->pos_max, it->n_tokens, n_past, (float) it->size() / 1024 / 1024);
|
||||
}
|
||||
|
||||
if (do_reset) {
|
||||
SLT_WRN(slot, "forcing full prompt re-processing due to lack of cache data (likely due to SWA or hybrid/recurrent memory, see %s)\n",
|
||||
SLT_TRC(slot, "forcing full prompt re-processing due to lack of cache data (likely due to SWA or hybrid/recurrent memory, see %s)\n",
|
||||
"https://github.com/ggml-org/llama.cpp/pull/13194#issuecomment-2868343055");
|
||||
pos_next = 0;
|
||||
n_past = 0;
|
||||
@@ -3327,7 +3329,7 @@ private:
|
||||
for (auto it = slot.prompt.checkpoints.begin(); it != slot.prompt.checkpoints.end();) {
|
||||
const auto & cur = *it;
|
||||
if (cur.pos_max > pos_next) {
|
||||
SLT_WRN(slot, "erased invalidated context checkpoint (pos_min = %d, pos_max = %d, n_tokens = %" PRId64 ", n_swa = %d, pos_next = %d, size = %.3f MiB)\n", cur.pos_min, cur.pos_max, cur.n_tokens, n_swa, pos_next, (float) cur.size() / 1024 / 1024);
|
||||
SLT_TRC(slot, "erased invalidated context checkpoint (pos_min = %d, pos_max = %d, n_tokens = %" PRId64 ", n_swa = %d, pos_next = %d, size = %.3f MiB)\n", cur.pos_min, cur.pos_max, cur.n_tokens, n_swa, pos_next, (float) cur.size() / 1024 / 1024);
|
||||
it = slot.prompt.checkpoints.erase(it);
|
||||
} else {
|
||||
++it;
|
||||
@@ -3674,7 +3676,7 @@ private:
|
||||
// all children slots should already launched by launch_slots_with_parent_task()
|
||||
// copy state to the child slots
|
||||
for (auto & child : children) {
|
||||
SLT_INF(slot, " - copying state to child %d\n", child->id);
|
||||
SLT_TRC(slot, " - copying state to child %d\n", child->id);
|
||||
|
||||
GGML_ASSERT(child->state == SLOT_STATE_WAIT_OTHER);
|
||||
|
||||
|
||||
@@ -83,7 +83,7 @@ bool server_http_context::init(const common_params & params) {
|
||||
hostname = params.hostname;
|
||||
|
||||
if (gcp.enabled) {
|
||||
SRV_INF("Google Cloud Platform compat: health route = %s, predict route = %s, port = %d\n", gcp.path_health.c_str(), gcp.path_predict.c_str(), gcp.port);
|
||||
SRV_TRC("Google Cloud Platform compat: health route = %s, predict route = %s, port = %d\n", gcp.path_health.c_str(), gcp.path_predict.c_str(), gcp.port);
|
||||
|
||||
if (port != gcp.port) {
|
||||
SRV_WRN("Google Cloud Platform compat: overriding server port %d with AIP_HTTP_PORT %d\n", port, gcp.port);
|
||||
@@ -96,13 +96,13 @@ bool server_http_context::init(const common_params & params) {
|
||||
|
||||
#ifdef CPPHTTPLIB_OPENSSL_SUPPORT
|
||||
if (!params.ssl_file_key.empty() && !params.ssl_file_cert.empty()) {
|
||||
SRV_INF("running with SSL: key = %s, cert = %s\n", params.ssl_file_key.c_str(), params.ssl_file_cert.c_str());
|
||||
SRV_TRC("running with SSL: key = %s, cert = %s\n", params.ssl_file_key.c_str(), params.ssl_file_cert.c_str());
|
||||
srv = std::make_unique<httplib::SSLServer>(
|
||||
params.ssl_file_cert.c_str(), params.ssl_file_key.c_str()
|
||||
);
|
||||
is_ssl = true;
|
||||
} else {
|
||||
SRV_INF("%s", "running without SSL\n");
|
||||
SRV_TRC("%s", "running without SSL\n");
|
||||
srv = std::make_unique<httplib::Server>();
|
||||
}
|
||||
#else
|
||||
@@ -165,9 +165,9 @@ bool server_http_context::init(const common_params & params) {
|
||||
if (params.api_keys.size() == 1) {
|
||||
const auto key = params.api_keys[0];
|
||||
const std::string substr = key.substr(std::max(static_cast<int>(key.length() - 4), 0));
|
||||
SRV_INF("api_keys: ****%s\n", substr.c_str());
|
||||
SRV_TRC("api_keys: ****%s\n", substr.c_str());
|
||||
} else if (params.api_keys.size() > 1) {
|
||||
SRV_INF("api_keys: %zu keys loaded\n", params.api_keys.size());
|
||||
SRV_TRC("api_keys: %zu keys loaded\n", params.api_keys.size());
|
||||
}
|
||||
|
||||
//
|
||||
@@ -293,7 +293,7 @@ bool server_http_context::init(const common_params & params) {
|
||||
// +4 threads for monitoring, health and some threads reserved for MCP and other tasks in the future
|
||||
n_threads_http = std::max(params.n_parallel + 4, static_cast<int32_t>(std::thread::hardware_concurrency() - 1));
|
||||
}
|
||||
SRV_INF("using %d threads for HTTP server\n", n_threads_http);
|
||||
SRV_TRC("using %d threads for HTTP server\n", n_threads_http);
|
||||
srv->new_task_queue = [n_threads_http] {
|
||||
// spawn n_threads_http fixed thread (always alive), while allow up to 1024 max possible additional threads
|
||||
// when n_threads_http is used, server will create new "dynamic" threads that will be destroyed after processing each request
|
||||
@@ -412,13 +412,13 @@ bool server_http_context::start() {
|
||||
auto is_sock = false;
|
||||
if (string_ends_with(std::string(hostname), ".sock")) {
|
||||
is_sock = true;
|
||||
SRV_INF("%s", "setting address family to AF_UNIX\n");
|
||||
SRV_TRC("%s", "setting address family to AF_UNIX\n");
|
||||
srv->set_address_family(AF_UNIX);
|
||||
// bind_to_port requires a second arg, any value other than 0 should
|
||||
// simply get ignored
|
||||
was_bound = srv->bind_to_port(hostname, 8080);
|
||||
} else {
|
||||
SRV_INF("%s", "binding port with default address family\n");
|
||||
SRV_TRC("%s", "binding port with default address family\n");
|
||||
// bind HTTP listen port
|
||||
if (port == 0) {
|
||||
const auto bound_port = srv->bind_to_any_port(hostname);
|
||||
|
||||
@@ -287,7 +287,7 @@ std::vector<std::unique_ptr<field>> make_llama_cmpl_schema(const common_params &
|
||||
->set_desc("Chat format used internally by the server")
|
||||
->set_handler([&](field_eval_context & ctx, const json & data) {
|
||||
ctx.params.chat_parser_params.format = static_cast<common_chat_format>(data.at("chat_format").get<int>());
|
||||
SRV_INF("Chat format: %s\n", common_chat_format_name(ctx.params.chat_parser_params.format));
|
||||
SRV_TRC("chat format: %s\n", common_chat_format_name(ctx.params.chat_parser_params.format));
|
||||
}));
|
||||
|
||||
add((new field_str("reasoning_format"))
|
||||
|
||||
@@ -339,11 +339,11 @@ void stream_pipe_producer::close() {
|
||||
// httplib bails its content provider the moment is_peer_alive() goes false, so pump the rest
|
||||
// of the generation into the ring buffer here. a DELETE flips is_cancelled and cuts it short
|
||||
if (done_ || session_->is_cancelled()) {
|
||||
SRV_INF("stream_pipe close: skip drain (done=%d cancelled=%d) conv=%s\n",
|
||||
SRV_TRC("stream_pipe close: skip drain (done=%d cancelled=%d) conv=%s\n",
|
||||
done_ ? 1 : 0, session_->is_cancelled() ? 1 : 0, session_->conversation_id.c_str());
|
||||
return;
|
||||
}
|
||||
SRV_INF("stream_pipe close: draining conv=%s\n", session_->conversation_id.c_str());
|
||||
SRV_TRC("stream_pipe close: draining conv=%s\n", session_->conversation_id.c_str());
|
||||
size_t drained = 0;
|
||||
std::string chunk;
|
||||
while (true) {
|
||||
@@ -357,7 +357,7 @@ void stream_pipe_producer::close() {
|
||||
break;
|
||||
}
|
||||
}
|
||||
SRV_INF("stream_pipe close: drain ended conv=%s bytes=%zu\n", session_->conversation_id.c_str(), drained);
|
||||
SRV_TRC("stream_pipe close: drain ended conv=%s bytes=%zu\n", session_->conversation_id.c_str(), drained);
|
||||
}
|
||||
|
||||
std::shared_ptr<stream_pipe_producer> stream_pipe_producer::create(stream_session_ptr session,
|
||||
@@ -520,7 +520,7 @@ server_http_context::handler_t make_stream_delete_handler() {
|
||||
if (conv_id.empty()) {
|
||||
return make_error_response(400, "Missing conversation id in path", ERROR_TYPE_INVALID_REQUEST);
|
||||
}
|
||||
SRV_INF("DELETE /v1/stream/%s -> evict_and_cancel\n", conv_id.c_str());
|
||||
SRV_TRC("DELETE /v1/stream/%s -> evict_and_cancel\n", conv_id.c_str());
|
||||
g_stream_sessions.evict_and_cancel(conv_id);
|
||||
auto res = std::make_unique<server_http_res>();
|
||||
res->status = 204;
|
||||
@@ -550,8 +550,7 @@ std::string stream_conv_id_from_headers(const std::map<std::string, std::string>
|
||||
|
||||
void stream_session_attach_pipe(server_http_res & res, const std::map<std::string, std::string> & headers) {
|
||||
std::string conversation_id = stream_conv_id_from_headers(headers);
|
||||
SRV_INF("stream_session_attach_pipe: conv_id=%s (empty=%d)\n",
|
||||
conversation_id.c_str(), conversation_id.empty() ? 1 : 0);
|
||||
SRV_TRC("conv_id=%s (empty=%d)\n", conversation_id.c_str(), conversation_id.empty() ? 1 : 0);
|
||||
if (conversation_id.empty()) {
|
||||
return;
|
||||
}
|
||||
|
||||
@@ -1626,7 +1626,7 @@ server_prompt * server_prompt_cache::alloc(const server_prompt & prompt, size_t
|
||||
const int cur_lcp_len = it->tokens.get_common_prefix(prompt.tokens);
|
||||
|
||||
if (cur_lcp_len == (int) prompt.tokens.size()) {
|
||||
SRV_INF("%s", " - prompt is already in the cache, skipping\n");
|
||||
SRV_TRC("%s", " - prompt is already in the cache, skipping\n");
|
||||
return nullptr;
|
||||
}
|
||||
}
|
||||
@@ -1636,7 +1636,7 @@ server_prompt * server_prompt_cache::alloc(const server_prompt & prompt, size_t
|
||||
const int len = it->tokens.get_common_prefix(prompt.tokens);
|
||||
|
||||
if (len == (int) it->tokens.size()) {
|
||||
SRV_WRN(" - removing obsolete cached prompt with length %d\n", len);
|
||||
SRV_TRC(" - removing obsolete cached prompt with length %d\n", len);
|
||||
|
||||
it = states.erase(it);
|
||||
} else {
|
||||
@@ -1681,7 +1681,7 @@ bool server_prompt_cache::load(server_prompt & prompt, const server_tokens & tok
|
||||
float f_keep_best = prompt.tokens.size() > 0 ? float(lcp_best) / prompt.tokens.size() : -1.0f; // empty slot: any cache entry wins
|
||||
float sim_best = float(lcp_best) / tokens_new.size();
|
||||
|
||||
SRV_INF(" - looking for better prompt, base f_keep = %.3f, sim = %.3f\n", f_keep_best, sim_best);
|
||||
SRV_TRC(" - looking for better prompt, base f_keep = %.3f, sim = %.3f\n", f_keep_best, sim_best);
|
||||
|
||||
auto it_best = states.end();
|
||||
|
||||
@@ -1706,7 +1706,7 @@ bool server_prompt_cache::load(server_prompt & prompt, const server_tokens & tok
|
||||
}
|
||||
|
||||
if (it_best != states.end()) {
|
||||
SRV_INF(" - found better prompt with f_keep = %.3f, sim = %.3f\n", f_keep_best, sim_best);
|
||||
SRV_TRC(" - found better prompt with f_keep = %.3f, sim = %.3f\n", f_keep_best, sim_best);
|
||||
|
||||
{
|
||||
auto & data = it_best->data.main;
|
||||
@@ -1783,11 +1783,11 @@ void server_prompt_cache::update() {
|
||||
}
|
||||
}
|
||||
|
||||
SRV_INF(" - cache state: %zu prompts, %.3f MiB (limits: %.3f MiB, %zu tokens, %zu est)\n",
|
||||
SRV_TRC(" - cache state: %zu prompts, %.3f MiB (limits: %.3f MiB, %zu tokens, %zu est)\n",
|
||||
states.size(), size() / (1024.0 * 1024.0), limit_size / (1024.0 * 1024.0), limit_tokens, limit_tokens_cur);
|
||||
|
||||
for (const auto & state : states) {
|
||||
SRV_INF(" - prompt %p: %7d tokens, checkpoints: %2zu, %9.3f MiB\n",
|
||||
SRV_TRC(" - prompt %p: %7d tokens, checkpoints: %2zu, %9.3f MiB\n",
|
||||
(const void *)&state, state.n_tokens(), state.checkpoints.size(), state.size() / (1024.0 * 1024.0));
|
||||
}
|
||||
}
|
||||
|
||||
@@ -124,7 +124,7 @@ int llama_server(int argc, char ** argv) {
|
||||
}
|
||||
|
||||
if (params.n_parallel < 0) {
|
||||
SRV_INF("%s", "n_parallel is set to auto, using n_parallel = 4 and kv_unified = true\n");
|
||||
SRV_TRC("%s", "n_parallel is set to auto, using n_parallel = 4 and kv_unified = true\n");
|
||||
|
||||
params.n_parallel = 4;
|
||||
params.kv_unified = true;
|
||||
@@ -338,7 +338,7 @@ int llama_server(int argc, char ** argv) {
|
||||
std::function<void()> clean_up;
|
||||
|
||||
if (is_router_server) {
|
||||
SRV_INF("%s", "starting router server, no model will be loaded in this process\n");
|
||||
SRV_INF("%s", "starting server in router mode. models will be automatically loaded on-demand\n");
|
||||
|
||||
clean_up = [&models_routes]() {
|
||||
SRV_INF("%s: cleaning up before exit...\n", __func__);
|
||||
@@ -391,9 +391,6 @@ int llama_server(int argc, char ** argv) {
|
||||
});
|
||||
}
|
||||
|
||||
// load the model
|
||||
SRV_INF("%s", "loading model\n");
|
||||
|
||||
if (!ctx_server.load_model(params)) {
|
||||
clean_up();
|
||||
if (ctx_http.thread.joinable()) {
|
||||
@@ -429,8 +426,9 @@ int llama_server(int argc, char ** argv) {
|
||||
SetConsoleCtrlHandler(reinterpret_cast<PHANDLER_ROUTINE>(console_ctrl_handler), true);
|
||||
#endif
|
||||
|
||||
SRV_INF("listening on %s\n", ctx_http.listening_address.c_str());
|
||||
|
||||
if (is_router_server) {
|
||||
SRV_INF("router server is listening on %s\n", ctx_http.listening_address.c_str());
|
||||
SRV_WRN("%s", "NOTE: router mode is experimental\n");
|
||||
SRV_WRN("%s", " it is not recommended to use this mode in untrusted environments\n");
|
||||
|
||||
@@ -446,8 +444,6 @@ int llama_server(int argc, char ** argv) {
|
||||
// when the HTTP server stops, clean up and exit
|
||||
clean_up();
|
||||
} else {
|
||||
SRV_INF("server is listening on %s\n", ctx_http.listening_address.c_str());
|
||||
|
||||
// optionally, notify router server that this instance is ready
|
||||
std::thread monitor_thread;
|
||||
if (child.is_child()) {
|
||||
|
||||
Reference in New Issue
Block a user