mirror of
https://github.com/ggml-org/llama.cpp.git
synced 2026-05-21 09:10:58 +02:00
Compare commits
5 Commits
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
ef22b3e4ac | ||
|
|
68e7ea3eab | ||
|
|
928b486b0c | ||
|
|
7dbb0e998a | ||
|
|
dd9280a664 |
50
.github/workflows/build-virtgpu.yml
vendored
Normal file
50
.github/workflows/build-virtgpu.yml
vendored
Normal file
@@ -0,0 +1,50 @@
|
||||
# Builds llama.cpp with the virtgpu backend enabled (build-only check, no tests are run).
name: CI (virtgpu)

on:
  workflow_dispatch: # allows manual triggering
  push:
    branches:
      - master
    # only rebuild on pushes that touch this workflow, build scripts or C/C++ sources
    paths: [
      '.github/workflows/build-virtgpu.yml',
      '**/CMakeLists.txt',
      '**/*.cmake',
      '**/*.h',
      '**/*.hpp',
      '**/*.c',
      '**/*.cpp'
    ]

  pull_request:
    types: [opened, synchronize, reopened]
    # pull requests only trigger the build when the workflow or the virtgpu backend changes
    paths: [
      '.github/workflows/build-virtgpu.yml',
      'ggml/src/ggml-virtgpu/**'
    ]

# one run per workflow + ref for pull requests (older runs are cancelled);
# other events fall back to the unique run id, so they never cancel each other
concurrency:
  group: ${{ github.workflow }}-${{ github.head_ref && github.ref || github.run_id }}
  cancel-in-progress: true

jobs:
  ubuntu-24-virtgpu:
    # NOTE(review): a non-empty string literal is always truthy, so this expression
    # always selects 'ubuntu-24.04-arm' - confirm that the fallback is intentional
    runs-on: ${{ 'ubuntu-24.04-arm' || 'ubuntu-24.04' }}

    steps:
      - name: Clone
        id: checkout
        uses: actions/checkout@v6

      - name: Dependencies
        id: depends
        run: |
          sudo apt-get update
          sudo apt-get install -y build-essential libdrm-dev pkg-config libssl-dev

      - name: Build
        id: cmake_build
        run: |
          cmake -B build \
            -DGGML_VIRTGPU=ON \
            -DGGML_VIRTGPU_BACKEND=ON
          cmake --build build --config Release -j $(nproc)
|
||||
@@ -622,10 +622,6 @@ static bool common_params_parse_ex(int argc, char ** argv, common_params_context
|
||||
for (auto & seq_breaker : params.sampling.dry_sequence_breakers) {
|
||||
string_process_escapes(seq_breaker);
|
||||
}
|
||||
for (auto & pair : params.speculative.draft.replacements) {
|
||||
string_process_escapes(pair.first);
|
||||
string_process_escapes(pair.second);
|
||||
}
|
||||
}
|
||||
|
||||
if (!params.kv_overrides.empty()) {
|
||||
@@ -3518,13 +3514,6 @@ common_params_context common_params_parser_init(common_params & params, llama_ex
|
||||
params.speculative.draft.p_min = std::stof(value);
|
||||
}
|
||||
).set_spec().set_examples({LLAMA_EXAMPLE_SPECULATIVE, LLAMA_EXAMPLE_SERVER, LLAMA_EXAMPLE_CLI}).set_env("LLAMA_ARG_SPEC_DRAFT_P_MIN"));
|
||||
add_opt(common_arg(
|
||||
{"--spec-draft-ctx-size", "-cd", "--ctx-size-draft"}, "N",
|
||||
string_format("size of the prompt context for the draft model (default: %d, 0 = loaded from model)", params.speculative.draft.n_ctx),
|
||||
[](common_params & params, int value) {
|
||||
params.speculative.draft.n_ctx = value;
|
||||
}
|
||||
).set_spec().set_examples({LLAMA_EXAMPLE_SPECULATIVE, LLAMA_EXAMPLE_SERVER, LLAMA_EXAMPLE_CLI}).set_env("LLAMA_ARG_SPEC_DRAFT_CTX_SIZE"));
|
||||
add_opt(common_arg(
|
||||
{"--spec-draft-device", "-devd", "--device-draft"}, "<dev1,dev2,..>",
|
||||
"comma-separated list of devices to use for offloading the draft model (none = don't offload)\n"
|
||||
@@ -3561,32 +3550,12 @@ common_params_context common_params_parser_init(common_params & params, llama_ex
|
||||
}
|
||||
).set_spec().set_examples({LLAMA_EXAMPLE_SPECULATIVE, LLAMA_EXAMPLE_SERVER, LLAMA_EXAMPLE_CLI}).set_env("LLAMA_ARG_SPEC_DRAFT_MODEL"));
|
||||
add_opt(common_arg(
|
||||
{"--spec-draft-replace", "--spec-replace"}, "TARGET", "DRAFT",
|
||||
"translate the string in TARGET into DRAFT if the draft model and main model are not compatible",
|
||||
[](common_params & params, const std::string & tgt, const std::string & dft) {
|
||||
params.speculative.draft.replacements.push_back({ tgt, dft });
|
||||
}
|
||||
).set_spec().set_examples({LLAMA_EXAMPLE_SPECULATIVE, LLAMA_EXAMPLE_SERVER, LLAMA_EXAMPLE_CLI}));
|
||||
add_opt(common_arg(
|
||||
{"--spec-type"}, "[none|ngram-cache|ngram-simple|ngram-map-k|ngram-map-k4v|ngram-mod]",
|
||||
{"--spec-type"}, common_speculative_all_types_str(),
|
||||
string_format("type of speculative decoding to use when no draft model is provided (default: %s)\n",
|
||||
common_speculative_type_to_str(params.speculative.type).c_str()),
|
||||
common_speculative_type_name_str(params.speculative.types).c_str()),
|
||||
[](common_params & params, const std::string & value) {
|
||||
if (value == "none") {
|
||||
params.speculative.type = COMMON_SPECULATIVE_TYPE_NONE;
|
||||
} else if (value == "ngram-cache") {
|
||||
params.speculative.type = COMMON_SPECULATIVE_TYPE_NGRAM_CACHE;
|
||||
} else if (value == "ngram-simple") {
|
||||
params.speculative.type = COMMON_SPECULATIVE_TYPE_NGRAM_SIMPLE;
|
||||
} else if (value == "ngram-map-k") {
|
||||
params.speculative.type = COMMON_SPECULATIVE_TYPE_NGRAM_MAP_K;
|
||||
} else if (value == "ngram-map-k4v") {
|
||||
params.speculative.type = COMMON_SPECULATIVE_TYPE_NGRAM_MAP_K4V;
|
||||
} else if (value == "ngram-mod") {
|
||||
params.speculative.type = COMMON_SPECULATIVE_TYPE_NGRAM_MOD;
|
||||
} else {
|
||||
throw std::invalid_argument("unknown speculative decoding type without draft model");
|
||||
}
|
||||
const auto enabled_types = string_split<std::string>(value, ',');
|
||||
params.speculative.types = common_speculative_types_from_names(enabled_types);
|
||||
}
|
||||
).set_spec().set_examples({LLAMA_EXAMPLE_SPECULATIVE, LLAMA_EXAMPLE_SERVER, LLAMA_EXAMPLE_CLI}).set_env("LLAMA_ARG_SPEC_TYPE"));
|
||||
add_opt(common_arg(
|
||||
@@ -4075,7 +4044,7 @@ common_params_context common_params_parser_init(common_params & params, llama_ex
|
||||
{"--spec-default"},
|
||||
string_format("enable default speculative decoding config"),
|
||||
[](common_params & params) {
|
||||
params.speculative.type = COMMON_SPECULATIVE_TYPE_NGRAM_MOD;
|
||||
params.speculative.types = { COMMON_SPECULATIVE_TYPE_NGRAM_MOD };
|
||||
params.speculative.ngram_mod.n_match = 24;
|
||||
params.speculative.ngram_mod.n_min = 48;
|
||||
params.speculative.ngram_mod.n_max = 64;
|
||||
|
||||
@@ -1422,7 +1422,7 @@ common_context_seq_rm_type common_context_can_seq_rm(llama_context * ctx) {
|
||||
|
||||
// try to remove the last tokens
|
||||
if (!llama_memory_seq_rm(mem, 0, 1, -1)) {
|
||||
LOG_WRN("%s: the target context does not support partial sequence removal\n", __func__);
|
||||
LOG_WRN("%s: the context does not support partial sequence removal\n", __func__);
|
||||
res = COMMON_CONTEXT_SEQ_RM_TYPE_FULL;
|
||||
goto done;
|
||||
}
|
||||
@@ -1960,3 +1960,102 @@ bool common_prompt_batch_decode(
|
||||
|
||||
return true;
|
||||
}
|
||||
|
||||
size_t common_prompt_checkpoint::size() const {
|
||||
return data_tgt.size() + data_dft.size();
|
||||
}
|
||||
|
||||
bool common_prompt_checkpoint::empty() const {
|
||||
return data_tgt.empty();
|
||||
}
|
||||
|
||||
// Reset the checkpoint: drop both state buffers and zero the token count and position range.
void common_prompt_checkpoint::clear() {
    data_tgt.clear();
    data_dft.clear();

    pos_min = 0;
    pos_max = 0;

    n_tokens = 0;
}
|
||||
|
||||
void common_prompt_checkpoint::update_pos(
|
||||
int64_t n_tokens,
|
||||
llama_pos pos_min,
|
||||
llama_pos pos_max) {
|
||||
this->n_tokens = n_tokens;
|
||||
this->pos_min = pos_min;
|
||||
this->pos_max = pos_max;
|
||||
}
|
||||
|
||||
// Capture the state of sequence `seq_id` of `ctx` into the target buffer (data_tgt).
//
// The required size is queried with llama_state_seq_get_size_ext(), data_tgt is resized to it
// and then filled with llama_state_seq_get_data_ext(). Does nothing when `ctx` is null.
// Aborts if the number of bytes reported by the getter differs from the queried size.
void common_prompt_checkpoint::update_tgt(
        llama_context * ctx,
        llama_seq_id seq_id,
        llama_state_seq_flags flags) {
    if (!ctx) {
        return;
    }

    const size_t n_expected = llama_state_seq_get_size_ext(ctx, seq_id, flags);
    data_tgt.resize(n_expected);

    const size_t n_written = llama_state_seq_get_data_ext(ctx, data_tgt.data(), n_expected, seq_id, flags);
    if (n_written == n_expected) {
        return;
    }

    GGML_ABORT("checkpoint size mismatch: expected %zu, got %zu\n", n_expected, n_written);
}
|
||||
|
||||
// Capture the state of sequence `seq_id` of `ctx` into the draft buffer (data_dft).
//
// The required size is queried with llama_state_seq_get_size_ext(), data_dft is resized to it
// and then filled with llama_state_seq_get_data_ext(). Does nothing when `ctx` is null.
// Aborts if the number of bytes reported by the getter differs from the queried size.
void common_prompt_checkpoint::update_dft(
        llama_context * ctx,
        llama_seq_id seq_id,
        llama_state_seq_flags flags) {
    if (!ctx) {
        return;
    }

    const size_t n_expected = llama_state_seq_get_size_ext(ctx, seq_id, flags);
    data_dft.resize(n_expected);

    const size_t n_written = llama_state_seq_get_data_ext(ctx, data_dft.data(), n_expected, seq_id, flags);
    if (n_written == n_expected) {
        return;
    }

    GGML_ABORT("checkpoint size mismatch: expected %zu, got %zu\n", n_expected, n_written);
}
|
||||
|
||||
// Feed the stored target buffer (data_tgt) back into sequence `seq_id` of `ctx`
// via llama_state_seq_set_data_ext().
//
// Does nothing when `ctx` is null or when no target state has been captured.
// Aborts if the setter reports a byte count different from the buffer size.
void common_prompt_checkpoint::load_tgt(
        llama_context * ctx,
        llama_seq_id seq_id,
        llama_state_seq_flags flags) const {
    const bool nothing_to_do = (ctx == nullptr) || data_tgt.empty();
    if (nothing_to_do) {
        return;
    }

    const size_t n_expected = data_tgt.size();

    const size_t n_read = llama_state_seq_set_data_ext(ctx, data_tgt.data(), n_expected, seq_id, flags);
    if (n_read == n_expected) {
        return;
    }

    GGML_ABORT("checkpoint size mismatch: expected %zu, got %zu\n", n_expected, n_read);
}
|
||||
|
||||
// Feed the stored draft buffer (data_dft) back into sequence `seq_id` of `ctx`
// via llama_state_seq_set_data_ext().
//
// Does nothing when `ctx` is null or when no draft state has been captured.
// Aborts if the setter reports a byte count different from the buffer size.
void common_prompt_checkpoint::load_dft(
        llama_context * ctx,
        llama_seq_id seq_id,
        llama_state_seq_flags flags) const {
    const bool nothing_to_do = (ctx == nullptr) || data_dft.empty();
    if (nothing_to_do) {
        return;
    }

    const size_t n_expected = data_dft.size();

    const size_t n_read = llama_state_seq_set_data_ext(ctx, data_dft.data(), n_expected, seq_id, flags);
    if (n_read == n_expected) {
        return;
    }

    GGML_ABORT("checkpoint size mismatch: expected %zu, got %zu\n", n_expected, n_read);
}
|
||||
|
||||
@@ -295,8 +295,6 @@ struct common_params_model {
|
||||
std::string name = ""; // in format <user>/<model>[:<tag>] (tag is optional) // NOLINT
|
||||
};
|
||||
|
||||
struct common_ngram_mod;
|
||||
|
||||
// draft-model-based speculative decoding parameters
|
||||
struct common_params_speculative_draft {
|
||||
int32_t n_max = 16; // maximum number of tokens to draft during speculative decoding
|
||||
@@ -307,11 +305,9 @@ struct common_params_speculative_draft {
|
||||
|
||||
common_params_model mparams;
|
||||
|
||||
llama_model * model = nullptr; // a llama_model that can be shared by multiple speculative contexts
|
||||
llama_context * ctx_tgt = nullptr;
|
||||
llama_context * ctx_dft = nullptr;
|
||||
|
||||
llama_context_params cparams; // these are the parameters for the draft llama_context
|
||||
|
||||
int32_t n_ctx = 0; // draft context size
|
||||
int32_t n_gpu_layers = -1; // number of layers to store in VRAM for the draft model (-1 - use default)
|
||||
|
||||
ggml_type cache_type_k = GGML_TYPE_F16; // KV cache data type for the K
|
||||
@@ -322,7 +318,6 @@ struct common_params_speculative_draft {
|
||||
|
||||
std::vector<ggml_backend_dev_t> devices; // devices to use for offloading
|
||||
|
||||
std::vector<std::pair<std::string, std::string>> replacements; // main to speculative model replacements
|
||||
std::vector<llama_model_tensor_buft_override> tensor_buft_overrides;
|
||||
};
|
||||
|
||||
@@ -331,9 +326,6 @@ struct common_params_speculative_ngram_mod {
|
||||
|
||||
int32_t n_max = 64;
|
||||
int32_t n_min = 48;
|
||||
|
||||
// shared instance of the ngram container for all speculative decoding contexts
|
||||
std::shared_ptr<common_ngram_mod> obj;
|
||||
};
|
||||
|
||||
struct common_params_speculative_ngram_map {
|
||||
@@ -348,8 +340,7 @@ struct common_params_speculative_ngram_cache {
|
||||
};
|
||||
|
||||
struct common_params_speculative {
|
||||
// TODO: become a vector in order to support "chains of speculators"
|
||||
common_speculative_type type = COMMON_SPECULATIVE_TYPE_NONE;
|
||||
std::vector<enum common_speculative_type> types = { COMMON_SPECULATIVE_TYPE_NONE };
|
||||
|
||||
common_params_speculative_draft draft;
|
||||
|
||||
@@ -1026,3 +1017,47 @@ ggml_opt_dataset_t common_opt_dataset_init(struct llama_context * ctx, const std
|
||||
|
||||
// "adamw" or "sgd" (case insensitive)
|
||||
enum ggml_opt_optimizer_type common_opt_get_optimizer(const char *);
|
||||
|
||||
//
|
||||
// prompt utils
|
||||
//
|
||||
|
||||
struct common_prompt_checkpoint {
|
||||
int64_t n_tokens;
|
||||
|
||||
llama_pos pos_min;
|
||||
llama_pos pos_max;
|
||||
|
||||
std::vector<uint8_t> data_tgt;
|
||||
std::vector<uint8_t> data_dft;
|
||||
|
||||
size_t size() const;
|
||||
|
||||
bool empty() const;
|
||||
void clear();
|
||||
|
||||
void update_pos(
|
||||
int64_t n_tokens,
|
||||
llama_pos pos_min,
|
||||
llama_pos pos_max);
|
||||
|
||||
void update_tgt(
|
||||
llama_context * ctx,
|
||||
llama_seq_id seq_id,
|
||||
llama_state_seq_flags flags);
|
||||
|
||||
void update_dft(
|
||||
llama_context * ctx,
|
||||
llama_seq_id seq_id,
|
||||
llama_state_seq_flags flags);
|
||||
|
||||
void load_tgt(
|
||||
llama_context * ctx,
|
||||
llama_seq_id seq_id,
|
||||
llama_state_seq_flags flags) const;
|
||||
|
||||
void load_dft(
|
||||
llama_context * ctx,
|
||||
llama_seq_id seq_id,
|
||||
llama_state_seq_flags flags) const;
|
||||
};
|
||||
|
||||
File diff suppressed because it is too large
Load Diff
@@ -5,8 +5,14 @@
|
||||
|
||||
struct common_speculative;
|
||||
|
||||
// comma-separated list of the provided types
|
||||
std::string common_speculative_type_name_str(const std::vector<enum common_speculative_type> & types);
|
||||
|
||||
// comma separated list of all types
|
||||
std::string common_speculative_type_name_str();
|
||||
const char * common_speculative_all_types_str();
|
||||
|
||||
// parse user provided types
|
||||
std::vector<enum common_speculative_type> common_speculative_types_from_names(const std::vector<std::string> & names);
|
||||
|
||||
// convert string to type
|
||||
enum common_speculative_type common_speculative_type_from_name(const std::string & name);
|
||||
@@ -14,27 +20,44 @@ enum common_speculative_type common_speculative_type_from_name(const std::string
|
||||
// convert type to string
|
||||
std::string common_speculative_type_to_str(enum common_speculative_type type);
|
||||
|
||||
common_speculative * common_speculative_init(
|
||||
common_params_speculative & params,
|
||||
llama_context * ctx_tgt);
|
||||
common_speculative * common_speculative_init(common_params_speculative & params, uint32_t n_seq);
|
||||
|
||||
void common_speculative_free(common_speculative * spec);
|
||||
|
||||
struct common_speculative_draft_params {
|
||||
// this flag is used to chain the drafts through all the available implementations
|
||||
// after the first successful draft from an implementation, we set it
|
||||
// to false to prevent further drafts for that sequence
|
||||
// at the end of the draft() call, all drafting flags will be reset to false
|
||||
bool drafting = false;
|
||||
|
||||
// overrides individual configurations (-1 disabled)
|
||||
// can be used to constraint the max draft based on the remaining context size
|
||||
int32_t n_max = -1;
|
||||
|
||||
llama_pos n_past;
|
||||
llama_token id_last;
|
||||
|
||||
// TODO: remove in the future by keeping track of the prompt from the _begin() call and the consecutive accept calls
|
||||
const llama_tokens * prompt;
|
||||
|
||||
// the generated draft from the last _draft() call
|
||||
llama_tokens * result;
|
||||
};
|
||||
|
||||
common_speculative_draft_params & common_speculative_get_draft_params(common_speculative * spec, llama_seq_id seq_id);
|
||||
|
||||
// optionally call once at the beginning of a new generation
|
||||
void common_speculative_begin(common_speculative * spec, const llama_tokens & prompt);
|
||||
void common_speculative_begin(common_speculative * spec, llama_seq_id seq_id, const llama_tokens & prompt);
|
||||
|
||||
// sample up to n_draft tokens and add them to the batch using the draft model
|
||||
llama_tokens common_speculative_draft(
|
||||
common_speculative * spec,
|
||||
const common_params_speculative & params,
|
||||
const llama_tokens & prompt,
|
||||
llama_token id_last);
|
||||
// process the batch and update the internal state of the speculative context
|
||||
bool common_speculative_process(common_speculative * spec, const llama_batch & batch);
|
||||
|
||||
// informs the speculative decoder that n_accepted tokens were accepted by the target model
|
||||
void common_speculative_accept(common_speculative * spec, uint16_t n_accepted);
|
||||
// generate drafts for the sequences specified with `common_speculative_get_draft_params`
|
||||
void common_speculative_draft(common_speculative * spec);
|
||||
|
||||
int32_t common_speculative_n_max(const common_speculative * spec, const common_params_speculative & params);
|
||||
int32_t common_speculative_n_min(const common_speculative * spec, const common_params_speculative & params);
|
||||
// informs the speculative context that n_accepted tokens were accepted by the target model
|
||||
void common_speculative_accept(common_speculative * spec, llama_seq_id, uint16_t n_accepted);
|
||||
|
||||
// print statistics about the speculative decoding
|
||||
void common_speculative_print_stats(const common_speculative * spec);
|
||||
|
||||
@@ -6,7 +6,7 @@ Demonstration of basic greedy speculative decoding
|
||||
./bin/llama-speculative-simple \
|
||||
-m ../models/qwen2.5-32b-coder-instruct/ggml-model-q8_0.gguf \
|
||||
-md ../models/qwen2.5-1.5b-coder-instruct/ggml-model-q4_0.gguf \
|
||||
-f test.txt -c 0 -ngl 99 --color \
|
||||
--sampling-seq k --top-k 1 -fa --temp 0.0 \
|
||||
-ngld 99 --draft-max 16 --draft-min 5 --draft-p-min 0.9
|
||||
-f test.txt -c 0 -ngl 99 --color on \
|
||||
--sampling-seq k --top-k 1 -fa on --temp 0.0 \
|
||||
-ngld 99 --spec-draft-n-max 16 --spec-draft-n-draft-min 5 --draft-p-min 0.9
|
||||
```
|
||||
|
||||
@@ -13,20 +13,6 @@
|
||||
#include <vector>
|
||||
#include <utility>
|
||||
|
||||
// Serialized context state together with the token count it corresponds to.
struct spec_checkpoint {
    // number of tokens at the time the state was captured
    int64_t n_tokens = 0;

    // raw serialized state
    std::vector<uint8_t> data;

    // number of bytes of serialized state
    size_t size() const { return data.size(); }

    // true when no state has been captured
    bool empty() const { return size() == 0; }
};
|
||||
|
||||
int main(int argc, char ** argv) {
|
||||
std::setlocale(LC_NUMERIC, "C");
|
||||
|
||||
@@ -43,11 +29,6 @@ int main(int argc, char ** argv) {
|
||||
return 1;
|
||||
}
|
||||
|
||||
if (params.speculative.draft.mparams.path.empty()) {
|
||||
LOG_ERR("%s: --model-draft is required\n", __func__);
|
||||
return 1;
|
||||
}
|
||||
|
||||
// init llama.cpp
|
||||
llama_backend_init();
|
||||
llama_numa_init(params.numa);
|
||||
@@ -62,18 +43,11 @@ int main(int argc, char ** argv) {
|
||||
model_tgt = llama_init_tgt->model();
|
||||
ctx_tgt = llama_init_tgt->context();
|
||||
|
||||
// check if the context supports partial sequence removal
|
||||
const auto ctx_seq_rm = common_context_can_seq_rm(ctx_tgt);
|
||||
const bool use_ckpt = (ctx_seq_rm == COMMON_CONTEXT_SEQ_RM_TYPE_FULL);
|
||||
|
||||
if (use_ckpt) {
|
||||
LOG_INF("speculative decoding will use checkpoints (context does not support partial sequence removal)\n");
|
||||
}
|
||||
|
||||
const llama_vocab * vocab = llama_model_get_vocab(model_tgt);
|
||||
|
||||
// load the draft model
|
||||
llama_model_ptr model_dft;
|
||||
llama_context_ptr ctx_dft;
|
||||
|
||||
// TODO: simplify this logic
|
||||
{
|
||||
@@ -81,9 +55,6 @@ int main(int argc, char ** argv) {
|
||||
|
||||
auto params_dft = params;
|
||||
|
||||
params_dft.n_parallel = 1;
|
||||
params_dft.n_ctx = params_spec.n_ctx;
|
||||
params_dft.n_batch = llama_n_ctx_seq(ctx_tgt);
|
||||
params_dft.devices = params_spec.devices;
|
||||
params_dft.model = params_spec.mparams;
|
||||
params_dft.n_gpu_layers = params_spec.n_gpu_layers;
|
||||
@@ -103,8 +74,19 @@ int main(int argc, char ** argv) {
|
||||
return 1;
|
||||
}
|
||||
|
||||
params.speculative.draft.model = model_dft.get();
|
||||
params.speculative.draft.cparams = common_context_params_to_llama(params_dft);
|
||||
auto cparams = common_context_params_to_llama(params_dft);
|
||||
ctx_dft.reset(llama_init_from_model(model_dft.get(), cparams));
|
||||
|
||||
params.speculative.draft.ctx_tgt = ctx_tgt;
|
||||
params.speculative.draft.ctx_dft = ctx_dft.get();
|
||||
}
|
||||
|
||||
// check if the context supports partial sequence removal
|
||||
const bool use_ckpt_tgt = (common_context_can_seq_rm(ctx_tgt) == COMMON_CONTEXT_SEQ_RM_TYPE_FULL);
|
||||
const bool use_ckpt_dft = (common_context_can_seq_rm(ctx_dft.get()) == COMMON_CONTEXT_SEQ_RM_TYPE_FULL);
|
||||
|
||||
if (use_ckpt_tgt) {
|
||||
LOG_INF("speculative decoding will use checkpoints (context does not support partial sequence removal)\n");
|
||||
}
|
||||
|
||||
// Tokenize the prompt
|
||||
@@ -136,6 +118,8 @@ int main(int argc, char ** argv) {
|
||||
// used to determine end of generation
|
||||
bool has_eos = false;
|
||||
|
||||
llama_seq_id seq_id = 0;
|
||||
|
||||
// ================================================
|
||||
// everything until here is standard initialization
|
||||
// the relevant stuff for speculative decoding starts here
|
||||
@@ -146,7 +130,8 @@ int main(int argc, char ** argv) {
|
||||
common_sampler_ptr smpl(common_sampler_init(model_tgt, params.sampling));
|
||||
|
||||
// eval the prompt
|
||||
llama_decode(ctx_tgt, llama_batch_get_one(inp.data(), inp.size() - 1));
|
||||
llama_decode(ctx_tgt, llama_batch_get_one(inp.data(), inp.size() - 1));
|
||||
llama_decode(ctx_dft.get(), llama_batch_get_one(inp.data(), inp.size() - 1));
|
||||
|
||||
// note: keep the last token separate!
|
||||
llama_token id_last = inp.back();
|
||||
@@ -160,16 +145,16 @@ int main(int argc, char ** argv) {
|
||||
// init the speculator
|
||||
const auto & params_spec = params.speculative;
|
||||
|
||||
struct common_speculative * spec = common_speculative_init(params.speculative, ctx_tgt);
|
||||
struct common_speculative * spec = common_speculative_init(params.speculative, 1);
|
||||
|
||||
common_speculative_begin(spec, prompt_tgt);
|
||||
common_speculative_begin(spec, seq_id, prompt_tgt);
|
||||
|
||||
llama_batch batch_tgt = llama_batch_init(llama_n_batch(ctx_tgt), 0, 1);
|
||||
|
||||
size_t n_draft = 0;
|
||||
|
||||
llama_tokens draft;
|
||||
spec_checkpoint spec_ckpt;
|
||||
common_prompt_checkpoint ckpt;
|
||||
|
||||
const auto t_enc_end = ggml_time_us();
|
||||
|
||||
@@ -184,40 +169,57 @@ int main(int argc, char ** argv) {
|
||||
// from a cache or lookup tables.
|
||||
//
|
||||
if (draft.empty()) {
|
||||
ckpt.update_pos(
|
||||
prompt_tgt.size(),
|
||||
llama_memory_seq_pos_min(llama_get_memory(ctx_tgt), seq_id),
|
||||
llama_memory_seq_pos_max(llama_get_memory(ctx_tgt), seq_id));
|
||||
|
||||
if (use_ckpt_dft) {
|
||||
ckpt.update_dft(ctx_dft.get(), seq_id, LLAMA_STATE_SEQ_FLAGS_PARTIAL_ONLY | LLAMA_STATE_SEQ_FLAGS_ON_DEVICE);
|
||||
}
|
||||
|
||||
// generate a new draft
|
||||
draft = common_speculative_draft(spec, params_spec, prompt_tgt, id_last);
|
||||
common_speculative_get_draft_params(spec, seq_id) = {
|
||||
/* .drafting = */ true,
|
||||
/* .n_max = */ -1,
|
||||
/* .n_past = */ n_past,
|
||||
/* .id_last = */ id_last,
|
||||
/* .prompt = */ &prompt_tgt,
|
||||
/* .result = */ &draft, // output
|
||||
};
|
||||
common_speculative_draft(spec);
|
||||
|
||||
// save the original draft size
|
||||
n_draft = draft.size();
|
||||
|
||||
// save a checkpoint of the target context before evaluating the draft
|
||||
// this allows us to restore the state if partial draft acceptance occurs
|
||||
if (!draft.empty() && use_ckpt) {
|
||||
const size_t ckpt_size = llama_state_seq_get_size_ext(ctx_tgt, 0, LLAMA_STATE_SEQ_FLAGS_PARTIAL_ONLY);
|
||||
spec_ckpt.data.resize(ckpt_size);
|
||||
if (!draft.empty()) {
|
||||
if (use_ckpt_tgt) {
|
||||
ckpt.update_tgt(ctx_tgt, seq_id, LLAMA_STATE_SEQ_FLAGS_PARTIAL_ONLY | LLAMA_STATE_SEQ_FLAGS_ON_DEVICE);
|
||||
}
|
||||
}
|
||||
|
||||
const size_t n = llama_state_seq_get_data_ext(ctx_tgt, spec_ckpt.data.data(), ckpt_size, 0, LLAMA_STATE_SEQ_FLAGS_PARTIAL_ONLY);
|
||||
GGML_ASSERT(n == ckpt_size);
|
||||
{
|
||||
ckpt.load_dft(ctx_dft.get(), seq_id, LLAMA_STATE_SEQ_FLAGS_PARTIAL_ONLY | LLAMA_STATE_SEQ_FLAGS_ON_DEVICE);
|
||||
|
||||
spec_ckpt.n_tokens = (int64_t) prompt_tgt.size();
|
||||
LOG_DBG("created speculative checkpoint (n_tokens = %" PRId64 ", size = %.3f MiB)\n",
|
||||
spec_ckpt.n_tokens, (float) spec_ckpt.data.size() / 1024 / 1024);
|
||||
llama_memory_seq_rm(llama_get_memory(ctx_dft.get()), seq_id, ckpt.pos_max + 1, -1);
|
||||
}
|
||||
} else {
|
||||
// we have a previous (partial) draft to reuse from checkpoint restoration
|
||||
if (use_ckpt) {
|
||||
GGML_ASSERT(!spec_ckpt.empty());
|
||||
if (use_ckpt_tgt) {
|
||||
GGML_ASSERT(!ckpt.empty());
|
||||
}
|
||||
}
|
||||
|
||||
// always have a token to evaluate from before - id_last
|
||||
common_batch_clear(batch_tgt);
|
||||
common_batch_add (batch_tgt, id_last, n_past++, { 0 }, true);
|
||||
common_batch_add (batch_tgt, id_last, n_past++, { seq_id }, true);
|
||||
|
||||
// evaluate the target model on [id_last, draft0, draft1, ..., draftN-1]
|
||||
{
|
||||
for (size_t i = 0; i < draft.size(); ++i) {
|
||||
common_batch_add(batch_tgt, draft[i], n_past + i, { 0 }, true);
|
||||
common_batch_add(batch_tgt, draft[i], n_past + i, { seq_id }, true);
|
||||
}
|
||||
|
||||
//LOG_DBG("target batch: %s\n", string_from(ctx_tgt, batch_tgt).c_str());
|
||||
@@ -225,9 +227,15 @@ int main(int argc, char ** argv) {
|
||||
llama_decode(ctx_tgt, batch_tgt);
|
||||
}
|
||||
|
||||
// evaluate the same batch with the draft model
|
||||
{
|
||||
// TODO: extend to support MTP, Eagle, etc. See server code for reference
|
||||
llama_decode(ctx_dft.get(), batch_tgt);
|
||||
}
|
||||
|
||||
// only save the sampler state if we use checkpoints
|
||||
common_sampler_ptr smpl_save;
|
||||
if (use_ckpt) {
|
||||
if (use_ckpt_tgt) {
|
||||
smpl_save.reset(common_sampler_clone(smpl.get()));
|
||||
}
|
||||
|
||||
@@ -247,17 +255,24 @@ int main(int argc, char ** argv) {
|
||||
// check for partial draft acceptance:
|
||||
// if the context doesn't support partial sequence removal, restore the checkpoint
|
||||
// and make the accepted tokens the new partial draft for the next iteration
|
||||
if (use_ckpt && ids.size() - 1 < draft.size()) {
|
||||
if (use_ckpt_tgt && ids.size() - 1 < draft.size()) {
|
||||
LOG_DBG("partial acceptance: %zu < %zu, restoring checkpoint\n", ids.size() - 1, draft.size());
|
||||
|
||||
draft = std::move(ids);
|
||||
|
||||
const size_t n = llama_state_seq_set_data_ext(ctx_tgt, spec_ckpt.data.data(), spec_ckpt.size(), 0, LLAMA_STATE_SEQ_FLAGS_PARTIAL_ONLY);
|
||||
GGML_ASSERT(n == spec_ckpt.size());
|
||||
{
|
||||
ckpt.load_tgt(ctx_tgt, seq_id, LLAMA_STATE_SEQ_FLAGS_PARTIAL_ONLY | LLAMA_STATE_SEQ_FLAGS_ON_DEVICE);
|
||||
|
||||
llama_memory_seq_rm(llama_get_memory(ctx_tgt), 0, spec_ckpt.n_tokens, -1);
|
||||
llama_memory_seq_rm(llama_get_memory(ctx_tgt), seq_id, ckpt.pos_max + 1, -1);
|
||||
}
|
||||
|
||||
prompt_tgt.resize(spec_ckpt.n_tokens);
|
||||
{
|
||||
ckpt.load_dft(ctx_dft.get(), seq_id, LLAMA_STATE_SEQ_FLAGS_PARTIAL_ONLY | LLAMA_STATE_SEQ_FLAGS_ON_DEVICE);
|
||||
|
||||
llama_memory_seq_rm(llama_get_memory(ctx_dft.get()), seq_id, ckpt.pos_max + 1, -1);
|
||||
}
|
||||
|
||||
prompt_tgt.resize(ckpt.n_tokens);
|
||||
smpl = std::move(smpl_save);
|
||||
|
||||
n_past = (int) prompt_tgt.size();
|
||||
@@ -265,7 +280,7 @@ int main(int argc, char ** argv) {
|
||||
continue;
|
||||
}
|
||||
|
||||
common_speculative_accept(spec, ids.size() - 1);
|
||||
common_speculative_accept(spec, seq_id, ids.size() - 1);
|
||||
|
||||
// full acceptance: consume the draft and commit accepted tokens
|
||||
n_past += ids.size() - 1;
|
||||
@@ -305,7 +320,8 @@ int main(int argc, char ** argv) {
|
||||
{
|
||||
LOG_DBG("clear kv cache from any extra tokens, n_past = %d\n", n_past);
|
||||
|
||||
llama_memory_seq_rm(llama_get_memory(ctx_tgt), 0, n_past, -1);
|
||||
llama_memory_seq_rm(llama_get_memory(ctx_tgt), seq_id, n_past, -1);
|
||||
llama_memory_seq_rm(llama_get_memory(ctx_dft.get()), seq_id, n_past, -1);
|
||||
}
|
||||
|
||||
if ((params.n_predict >= 0 && n_predict > params.n_predict) || has_eos) {
|
||||
|
||||
@@ -855,7 +855,7 @@ struct vk_device_struct {
|
||||
vk_pipeline pipeline_conv2d_dw_whcn_f32, pipeline_conv2d_dw_whcn_f16_f32;
|
||||
vk_pipeline pipeline_conv2d_dw_cwhn_f32, pipeline_conv2d_dw_cwhn_f16_f32;
|
||||
|
||||
std::map<vk_fa_pipeline_state, vk_pipeline> pipeline_flash_attn_f32_f16[GGML_TYPE_COUNT];
|
||||
std::map<vk_fa_pipeline_state, vk_pipeline> pipeline_flash_attn_f32_f16;
|
||||
|
||||
std::map<std::pair<uint32_t, uint32_t>, vk_pipeline> pipeline_fa_mask_opt;
|
||||
|
||||
@@ -2933,10 +2933,10 @@ struct vk_fa_tuning_params {
|
||||
}
|
||||
};
|
||||
|
||||
static bool ggml_vk_flash_attn_scalar_shmem_support(const vk_device& device, const vk_fa_tuning_params& params, uint32_t hsk, uint32_t hsv, bool f32acc, ggml_type kv_type);
|
||||
static bool ggml_vk_flash_attn_scalar_shmem_support(const vk_device& device, const vk_fa_tuning_params& params, uint32_t hsk, uint32_t hsv, bool f32acc, ggml_type k_type, ggml_type v_type);
|
||||
static bool ggml_vk_flash_attn_coopmat_shmem_support(const vk_device& device, const vk_fa_tuning_params& params, uint32_t hsk, uint32_t hsv, bool f32acc);
|
||||
|
||||
static vk_fa_tuning_params get_fa_tuning_params_scalar(const vk_device& device, uint32_t hsk, uint32_t hsv, uint32_t n_rows, uint32_t n_kv, ggml_type kv_type, bool f32acc) {
|
||||
static vk_fa_tuning_params get_fa_tuning_params_scalar(const vk_device& device, uint32_t hsk, uint32_t hsv, uint32_t n_rows, uint32_t n_kv, ggml_type k_type, ggml_type v_type, bool f32acc) {
|
||||
|
||||
vk_fa_tuning_params result{};
|
||||
result.path = FA_SCALAR;
|
||||
@@ -2988,7 +2988,7 @@ static vk_fa_tuning_params get_fa_tuning_params_scalar(const vk_device& device,
|
||||
|
||||
result.shmem_staging = (device->vendor_id == VK_VENDOR_ID_NVIDIA && hsk < 256 && hsv < 256) ? 1 : 0;
|
||||
|
||||
if (!reduce_block_rows && !ggml_vk_flash_attn_scalar_shmem_support(device, result, hsk, hsv, f32acc, kv_type)) {
|
||||
if (!reduce_block_rows && !ggml_vk_flash_attn_scalar_shmem_support(device, result, hsk, hsv, f32acc, k_type, v_type)) {
|
||||
result.block_rows /= 2;
|
||||
}
|
||||
|
||||
@@ -3011,10 +3011,11 @@ static vk_fa_tuning_params get_fa_tuning_params_scalar(const vk_device& device,
|
||||
return result;
|
||||
}
|
||||
|
||||
static vk_fa_tuning_params get_fa_tuning_params_coopmat1(const vk_device& device, uint32_t hsk, uint32_t hsv, uint32_t n_rows, uint32_t n_kv, ggml_type kv_type, bool f32acc) {
|
||||
static vk_fa_tuning_params get_fa_tuning_params_coopmat1(const vk_device& device, uint32_t hsk, uint32_t hsv, uint32_t n_rows, uint32_t n_kv, ggml_type k_type, ggml_type v_type, bool f32acc) {
|
||||
GGML_UNUSED(n_rows);
|
||||
GGML_UNUSED(n_kv);
|
||||
GGML_UNUSED(kv_type);
|
||||
GGML_UNUSED(k_type);
|
||||
GGML_UNUSED(v_type);
|
||||
GGML_UNUSED(f32acc);
|
||||
|
||||
vk_fa_tuning_params result{};
|
||||
@@ -3070,12 +3071,6 @@ static vk_fa_tuning_params get_fa_tuning_params_coopmat2(const vk_device& device
|
||||
}
|
||||
|
||||
static vk_fa_tuning_params get_fa_tuning_params(const vk_device& device, uint32_t hsk, uint32_t hsv, uint32_t n_rows, uint32_t n_kv, ggml_type k_type, ggml_type v_type, bool f32acc) {
|
||||
// Mixed K/V is only implemented on the coopmat2 (flash_attn_cm2) path; never use scalar/cm1.
|
||||
if (k_type != v_type) {
|
||||
GGML_ASSERT(device->coopmat2);
|
||||
return get_fa_tuning_params_coopmat2(device, hsk, hsv, n_rows, n_kv, k_type, v_type, f32acc);
|
||||
}
|
||||
|
||||
FaCodePath path = device->coopmat2 ? FA_COOPMAT2 :
|
||||
device->coopmat1_fa_support ? FA_COOPMAT1 : FA_SCALAR;
|
||||
|
||||
@@ -3087,7 +3082,7 @@ static vk_fa_tuning_params get_fa_tuning_params(const vk_device& device, uint32_
|
||||
if (path == FA_COOPMAT1) {
|
||||
bool shape_ok = (f32acc && device->coopmat_support_16x16x16_f32acc) ||
|
||||
(!f32acc && device->coopmat_support_16x16x16_f16acc);
|
||||
const vk_fa_tuning_params params = get_fa_tuning_params_coopmat1(device, hsk, hsv, n_rows, n_kv, k_type, f32acc);
|
||||
const vk_fa_tuning_params params = get_fa_tuning_params_coopmat1(device, hsk, hsv, n_rows, n_kv, k_type, v_type, f32acc);
|
||||
bool shmem_ok = ggml_vk_flash_attn_coopmat_shmem_support(device, params, hsk, hsv, f32acc);
|
||||
|
||||
if (!shape_ok || !shmem_ok) {
|
||||
@@ -3107,9 +3102,9 @@ static vk_fa_tuning_params get_fa_tuning_params(const vk_device& device, uint32_
|
||||
|
||||
switch (path) {
|
||||
case FA_SCALAR:
|
||||
return get_fa_tuning_params_scalar(device, hsk, hsv, n_rows, n_kv, k_type, f32acc);
|
||||
return get_fa_tuning_params_scalar(device, hsk, hsv, n_rows, n_kv, k_type, v_type, f32acc);
|
||||
case FA_COOPMAT1:
|
||||
return get_fa_tuning_params_coopmat1(device, hsk, hsv, n_rows, n_kv, k_type, f32acc);
|
||||
return get_fa_tuning_params_coopmat1(device, hsk, hsv, n_rows, n_kv, k_type, v_type, f32acc);
|
||||
case FA_COOPMAT2:
|
||||
return get_fa_tuning_params_coopmat2(device, hsk, hsv, n_rows, n_kv, k_type, v_type, f32acc);
|
||||
default:
|
||||
@@ -3279,6 +3274,20 @@ static uint32_t get_subgroup_size(const std::string &pipeline_name, const vk_dev
|
||||
return 0; // If no matching configuration is found
|
||||
}
|
||||
|
||||
// Whether scalar flash attention will use the MMQ path for the given k_type.
// Returns true only when all of the following hold:
//   - the build defines GGML_VULKAN_INTEGER_DOT_GLSLC_SUPPORT,
//   - the device reports both integer_dot_product and subgroup_clustered,
//   - k_type is one of Q4_0, Q4_1, Q5_0, Q5_1 or Q8_0.
// Only the K type is inspected; the V type plays no part in this decision.
|
||||
static bool ggml_vk_fa_scalar_uses_mmq(const vk_device& device, ggml_type k_type) {
|
||||
#if defined(GGML_VULKAN_INTEGER_DOT_GLSLC_SUPPORT)
|
||||
return device->integer_dot_product && device->subgroup_clustered &&
|
||||
(k_type == GGML_TYPE_Q4_0 || k_type == GGML_TYPE_Q4_1 ||
|
||||
k_type == GGML_TYPE_Q5_0 || k_type == GGML_TYPE_Q5_1 ||
|
||||
k_type == GGML_TYPE_Q8_0);
|
||||
#else
// Built without the integer-dot shader support macro: the MMQ path is never
// selected, and both parameters are marked unused to keep the build warning-free.
|
||||
GGML_UNUSED(device);
|
||||
GGML_UNUSED(k_type);
|
||||
return false;
|
||||
#endif
|
||||
}
|
||||
|
||||
static void ggml_vk_load_shaders(vk_device& device) {
|
||||
VK_LOG_DEBUG("ggml_vk_load_shaders(" << device->name << ")");
|
||||
|
||||
@@ -3525,121 +3534,96 @@ static void ggml_vk_load_shaders(vk_device& device) {
|
||||
align, disable_robustness, require_full_subgroups, required_subgroup_size);
|
||||
};
|
||||
|
||||
#define CREATE_FA(TYPE, NAMELC, FAPATH, SUFFIX) \
|
||||
for (auto &fa : device->pipeline_flash_attn_f32_f16[TYPE]) { \
|
||||
FaCodePath path = fa.first.path; \
|
||||
uint32_t Br = fa.first.Br; \
|
||||
uint32_t Bc = fa.first.Bc; \
|
||||
bool aligned = fa.first.aligned; \
|
||||
bool f32acc = fa.first.f32acc; \
|
||||
uint32_t fa_sgs = fa.first.subgroup_size; \
|
||||
bool fa_ds = fa.first.subgroup_size == 0; \
|
||||
if (path == FAPATH) { \
|
||||
if (aligned) { \
|
||||
if (f32acc) { \
|
||||
ggml_vk_create_pipeline(device, fa.second, "flash_attn_f32_f16_aligned_f32acc" #NAMELC, flash_attn_f32_f16_ ## NAMELC ## SUFFIX ## _len, flash_attn_f32_f16_ ## NAMELC ## SUFFIX ## _data, "main", 7, sizeof(vk_flash_attn_push_constants), {Br, 1, 1}, get_fa_spec_constants(fa.first), Bc, true, (!fa_ds && (FAPATH!=FA_COOPMAT2)), ((!fa_ds && (FAPATH!=FA_COOPMAT2)) ? fa_sgs : 0)); \
|
||||
} else { \
|
||||
ggml_vk_create_pipeline(device, fa.second, "flash_attn_f32_f16_aligned_f16acc" #NAMELC, flash_attn_f32_f16_ ## NAMELC ## _f16acc ## SUFFIX ## _len, flash_attn_f32_f16_ ## NAMELC ## _f16acc ## SUFFIX ## _data, "main", 7, sizeof(vk_flash_attn_push_constants), {Br, 1, 1}, get_fa_spec_constants(fa.first), Bc, true, (!fa_ds && (FAPATH!=FA_COOPMAT2)), ((!fa_ds && (FAPATH!=FA_COOPMAT2)) ? fa_sgs : 0)); \
|
||||
} \
|
||||
} else { \
|
||||
if (f32acc) { \
|
||||
ggml_vk_create_pipeline(device, fa.second, "flash_attn_f32_f16_f32acc" #NAMELC, flash_attn_f32_f16_ ## NAMELC ## SUFFIX ## _len, flash_attn_f32_f16_ ## NAMELC ## SUFFIX ## _data, "main", 7, sizeof(vk_flash_attn_push_constants), {Br, 1, 1}, get_fa_spec_constants(fa.first), 1, true, (!fa_ds && (FAPATH!=FA_COOPMAT2)), ((!fa_ds && (FAPATH!=FA_COOPMAT2)) ? fa_sgs : 0)); \
|
||||
} else { \
|
||||
ggml_vk_create_pipeline(device, fa.second, "flash_attn_f32_f16_f16acc" #NAMELC, flash_attn_f32_f16_ ## NAMELC ## _f16acc ## SUFFIX ## _len, flash_attn_f32_f16_ ## NAMELC ## _f16acc ## SUFFIX ## _data, "main", 7, sizeof(vk_flash_attn_push_constants), {Br, 1, 1}, get_fa_spec_constants(fa.first), 1, true, (!fa_ds && (FAPATH!=FA_COOPMAT2)), ((!fa_ds && (FAPATH!=FA_COOPMAT2)) ? fa_sgs : 0)); \
|
||||
} \
|
||||
} \
|
||||
} \
|
||||
}
|
||||
// FA scalar has two SPIR-V modules (MMQ vs non-MMQ); FA cm1 has one. K/V
|
||||
// quant type is selected at runtime via the FaTypeK / FaTypeV spec constants.
|
||||
|
||||
if (device->fp16) {
|
||||
CREATE_FA(GGML_TYPE_F32, f32, FA_SCALAR, )
|
||||
CREATE_FA(GGML_TYPE_F16, f16, FA_SCALAR, )
|
||||
for (auto &fa : device->pipeline_flash_attn_f32_f16) {
|
||||
if (fa.first.path != FA_SCALAR) continue;
|
||||
const uint32_t Br = fa.first.Br;
|
||||
const uint32_t Bc = fa.first.Bc;
|
||||
const bool aligned = fa.first.aligned;
|
||||
const bool f32acc = fa.first.f32acc;
|
||||
const uint32_t fa_sgs = fa.first.subgroup_size;
|
||||
const bool fa_ds = fa.first.subgroup_size == 0;
|
||||
|
||||
const bool use_mmq = ggml_vk_fa_scalar_uses_mmq(device, fa.first.k_type);
|
||||
const void * spv_data = nullptr;
|
||||
size_t spv_size = 0;
|
||||
if (use_mmq) {
|
||||
#if defined(GGML_VULKAN_INTEGER_DOT_GLSLC_SUPPORT)
|
||||
if (device->integer_dot_product && device->subgroup_clustered) {
|
||||
CREATE_FA(GGML_TYPE_Q4_0, q4_0, FA_SCALAR, _int8)
|
||||
CREATE_FA(GGML_TYPE_Q8_0, q8_0, FA_SCALAR, _int8)
|
||||
CREATE_FA(GGML_TYPE_Q4_1, q4_1, FA_SCALAR, _int8)
|
||||
CREATE_FA(GGML_TYPE_Q5_0, q5_0, FA_SCALAR, _int8)
|
||||
CREATE_FA(GGML_TYPE_Q5_1, q5_1, FA_SCALAR, _int8)
|
||||
CREATE_FA(GGML_TYPE_IQ4_NL, iq4_nl, FA_SCALAR, _int8)
|
||||
} else
|
||||
if (device->fp16) {
|
||||
if (f32acc) { spv_data = flash_attn_f32_f16_int8_data; spv_size = flash_attn_f32_f16_int8_len; }
|
||||
else { spv_data = flash_attn_f32_f16_f16acc_int8_data; spv_size = flash_attn_f32_f16_f16acc_int8_len; }
|
||||
} else {
|
||||
spv_data = flash_attn_f32_f16_fp32_int8_data;
|
||||
spv_size = flash_attn_f32_f16_fp32_int8_len;
|
||||
}
|
||||
#endif
|
||||
{
|
||||
CREATE_FA(GGML_TYPE_Q4_0, q4_0, FA_SCALAR, )
|
||||
CREATE_FA(GGML_TYPE_Q8_0, q8_0, FA_SCALAR, )
|
||||
CREATE_FA(GGML_TYPE_Q4_1, q4_1, FA_SCALAR, )
|
||||
CREATE_FA(GGML_TYPE_Q5_0, q5_0, FA_SCALAR, )
|
||||
CREATE_FA(GGML_TYPE_Q5_1, q5_1, FA_SCALAR, )
|
||||
CREATE_FA(GGML_TYPE_IQ4_NL, iq4_nl, FA_SCALAR, )
|
||||
}
|
||||
} else {
|
||||
CREATE_FA(GGML_TYPE_F32, f32, FA_SCALAR, _fp32)
|
||||
CREATE_FA(GGML_TYPE_F16, f16, FA_SCALAR, _fp32)
|
||||
|
||||
#if defined(GGML_VULKAN_INTEGER_DOT_GLSLC_SUPPORT)
|
||||
if (device->integer_dot_product && device->subgroup_clustered) {
|
||||
CREATE_FA(GGML_TYPE_Q4_0, q4_0, FA_SCALAR, _fp32_int8)
|
||||
CREATE_FA(GGML_TYPE_Q8_0, q8_0, FA_SCALAR, _fp32_int8)
|
||||
CREATE_FA(GGML_TYPE_Q4_1, q4_1, FA_SCALAR, _fp32_int8)
|
||||
CREATE_FA(GGML_TYPE_Q5_0, q5_0, FA_SCALAR, _fp32_int8)
|
||||
CREATE_FA(GGML_TYPE_Q5_1, q5_1, FA_SCALAR, _fp32_int8)
|
||||
CREATE_FA(GGML_TYPE_IQ4_NL, iq4_nl, FA_SCALAR, _fp32_int8)
|
||||
} else
|
||||
#endif
|
||||
{
|
||||
CREATE_FA(GGML_TYPE_Q4_0, q4_0, FA_SCALAR, _fp32)
|
||||
CREATE_FA(GGML_TYPE_Q8_0, q8_0, FA_SCALAR, _fp32)
|
||||
CREATE_FA(GGML_TYPE_Q4_1, q4_1, FA_SCALAR, _fp32)
|
||||
CREATE_FA(GGML_TYPE_Q5_0, q5_0, FA_SCALAR, _fp32)
|
||||
CREATE_FA(GGML_TYPE_Q5_1, q5_1, FA_SCALAR, _fp32)
|
||||
CREATE_FA(GGML_TYPE_IQ4_NL, iq4_nl, FA_SCALAR, _fp32)
|
||||
} else {
|
||||
if (device->fp16) {
|
||||
if (f32acc) { spv_data = flash_attn_f32_f16_data; spv_size = flash_attn_f32_f16_len; }
|
||||
else { spv_data = flash_attn_f32_f16_f16acc_data; spv_size = flash_attn_f32_f16_f16acc_len; }
|
||||
} else {
|
||||
spv_data = flash_attn_f32_f16_fp32_data;
|
||||
spv_size = flash_attn_f32_f16_fp32_len;
|
||||
}
|
||||
}
|
||||
const char *name = aligned ? "flash_attn_f32_f16_aligned" : "flash_attn_f32_f16";
|
||||
ggml_vk_create_pipeline(device, fa.second, name, spv_size, spv_data, "main", 7,
|
||||
sizeof(vk_flash_attn_push_constants), {Br, 1, 1},
|
||||
get_fa_spec_constants(fa.first), aligned ? Bc : 1, true,
|
||||
!fa_ds, !fa_ds ? fa_sgs : 0);
|
||||
}
|
||||
|
||||
#if defined(VK_KHR_cooperative_matrix) && defined(GGML_VULKAN_COOPMAT_GLSLC_SUPPORT)
|
||||
if (device->coopmat1_fa_support) {
|
||||
CREATE_FA(GGML_TYPE_F32, f32, FA_COOPMAT1, _cm1)
|
||||
CREATE_FA(GGML_TYPE_F16, f16, FA_COOPMAT1, _cm1)
|
||||
CREATE_FA(GGML_TYPE_Q4_0, q4_0, FA_COOPMAT1, _cm1)
|
||||
CREATE_FA(GGML_TYPE_Q8_0, q8_0, FA_COOPMAT1, _cm1)
|
||||
CREATE_FA(GGML_TYPE_Q4_1, q4_1, FA_COOPMAT1, _cm1)
|
||||
CREATE_FA(GGML_TYPE_Q5_0, q5_0, FA_COOPMAT1, _cm1)
|
||||
CREATE_FA(GGML_TYPE_Q5_1, q5_1, FA_COOPMAT1, _cm1)
|
||||
CREATE_FA(GGML_TYPE_IQ4_NL, iq4_nl, FA_COOPMAT1, _cm1)
|
||||
}
|
||||
#endif
|
||||
#if defined(VK_NV_cooperative_matrix2) && defined(GGML_VULKAN_COOPMAT2_GLSLC_SUPPORT)
|
||||
#define CREATE_FA_CM2_MIXED() \
|
||||
for (int fa_k_ty = 0; fa_k_ty < (int)GGML_TYPE_COUNT; ++fa_k_ty) { \
|
||||
for (auto &fa : device->pipeline_flash_attn_f32_f16[fa_k_ty]) { \
|
||||
FaCodePath path = fa.first.path; \
|
||||
uint32_t Br = fa.first.Br; \
|
||||
uint32_t Bc = fa.first.Bc; \
|
||||
bool aligned = fa.first.aligned; \
|
||||
bool f32acc = fa.first.f32acc; \
|
||||
if (path == FA_COOPMAT2) { \
|
||||
if (aligned) { \
|
||||
if (f32acc) { \
|
||||
ggml_vk_create_pipeline(device, fa.second, "flash_attn_f32_f16_mixed_aligned_f32acc_cm2", flash_attn_f32_f16_mixed_cm2_len, flash_attn_f32_f16_mixed_cm2_data, "main", 7, sizeof(vk_flash_attn_push_constants), {Br, 1, 1}, get_fa_spec_constants(fa.first), Bc, true, false, 0); \
|
||||
} else { \
|
||||
ggml_vk_create_pipeline(device, fa.second, "flash_attn_f32_f16_mixed_aligned_f16acc_cm2", flash_attn_f32_f16_mixed_f16acc_cm2_len, flash_attn_f32_f16_mixed_f16acc_cm2_data, "main", 7, sizeof(vk_flash_attn_push_constants), {Br, 1, 1}, get_fa_spec_constants(fa.first), Bc, true, false, 0); \
|
||||
} \
|
||||
} else { \
|
||||
if (f32acc) { \
|
||||
ggml_vk_create_pipeline(device, fa.second, "flash_attn_f32_f16_mixed_f32acc_cm2", flash_attn_f32_f16_mixed_cm2_len, flash_attn_f32_f16_mixed_cm2_data, "main", 7, sizeof(vk_flash_attn_push_constants), {Br, 1, 1}, get_fa_spec_constants(fa.first), 1, true, false, 0); \
|
||||
} else { \
|
||||
ggml_vk_create_pipeline(device, fa.second, "flash_attn_f32_f16_mixed_f16acc_cm2", flash_attn_f32_f16_mixed_f16acc_cm2_len, flash_attn_f32_f16_mixed_f16acc_cm2_data, "main", 7, sizeof(vk_flash_attn_push_constants), {Br, 1, 1}, get_fa_spec_constants(fa.first), 1, true, false, 0); \
|
||||
} \
|
||||
} \
|
||||
} \
|
||||
} \
|
||||
for (auto &fa : device->pipeline_flash_attn_f32_f16) {
|
||||
if (fa.first.path != FA_COOPMAT1) continue;
|
||||
const uint32_t Br = fa.first.Br;
|
||||
const uint32_t Bc = fa.first.Bc;
|
||||
const bool aligned = fa.first.aligned;
|
||||
const bool f32acc = fa.first.f32acc;
|
||||
const uint32_t fa_sgs = fa.first.subgroup_size;
|
||||
const bool fa_ds = fa.first.subgroup_size == 0;
|
||||
|
||||
const void * spv_data;
|
||||
size_t spv_size;
|
||||
if (f32acc) { spv_data = flash_attn_f32_f16_cm1_data; spv_size = flash_attn_f32_f16_cm1_len; }
|
||||
else { spv_data = flash_attn_f32_f16_f16acc_cm1_data; spv_size = flash_attn_f32_f16_f16acc_cm1_len; }
|
||||
const char *name = aligned ? "flash_attn_f32_f16_aligned_cm1" : "flash_attn_f32_f16_cm1";
|
||||
ggml_vk_create_pipeline(device, fa.second, name, spv_size, spv_data, "main", 7,
|
||||
sizeof(vk_flash_attn_push_constants), {Br, 1, 1},
|
||||
get_fa_spec_constants(fa.first), aligned ? Bc : 1, true,
|
||||
!fa_ds, !fa_ds ? fa_sgs : 0);
|
||||
}
|
||||
}
|
||||
#endif
|
||||
|
||||
#if defined(VK_NV_cooperative_matrix2) && defined(GGML_VULKAN_COOPMAT2_GLSLC_SUPPORT)
|
||||
if (device->coopmat2) {
|
||||
for (auto &fa : device->pipeline_flash_attn_f32_f16) {
|
||||
if (fa.first.path != FA_COOPMAT2) continue;
|
||||
const uint32_t Br = fa.first.Br;
|
||||
const uint32_t Bc = fa.first.Bc;
|
||||
const bool aligned = fa.first.aligned;
|
||||
const bool f32acc = fa.first.f32acc;
|
||||
|
||||
const void * spv_data;
|
||||
size_t spv_size;
|
||||
const char * name;
|
||||
if (aligned) {
|
||||
if (f32acc) { spv_data = flash_attn_f32_f16_cm2_data; spv_size = flash_attn_f32_f16_cm2_len; name = "flash_attn_f32_f16_aligned_f32acc_cm2"; }
|
||||
else { spv_data = flash_attn_f32_f16_f16acc_cm2_data; spv_size = flash_attn_f32_f16_f16acc_cm2_len; name = "flash_attn_f32_f16_aligned_f16acc_cm2"; }
|
||||
} else {
|
||||
if (f32acc) { spv_data = flash_attn_f32_f16_cm2_data; spv_size = flash_attn_f32_f16_cm2_len; name = "flash_attn_f32_f16_f32acc_cm2"; }
|
||||
else { spv_data = flash_attn_f32_f16_f16acc_cm2_data; spv_size = flash_attn_f32_f16_f16acc_cm2_len; name = "flash_attn_f32_f16_f16acc_cm2"; }
|
||||
}
|
||||
ggml_vk_create_pipeline(device, fa.second, name, spv_size, spv_data, "main", 7,
|
||||
sizeof(vk_flash_attn_push_constants), {Br, 1, 1},
|
||||
get_fa_spec_constants(fa.first), aligned ? Bc : 1, true, false, 0);
|
||||
}
|
||||
if (device->coopmat2) {
|
||||
CREATE_FA_CM2_MIXED();
|
||||
}
|
||||
#undef CREATE_FA_CM2_MIXED
|
||||
#endif
|
||||
#undef CREATE_FA
|
||||
|
||||
const int mul_mat_id_param_count = 5;
|
||||
|
||||
@@ -8940,8 +8924,9 @@ static void ggml_vk_mul_mat_id(ggml_backend_vk_context * ctx, vk_context& subctx
|
||||
}
|
||||
}
|
||||
|
||||
static bool ggml_vk_flash_attn_scalar_shmem_support(const vk_device& device, const vk_fa_tuning_params& params, uint32_t hsk, uint32_t hsv, bool f32acc, ggml_type kv_type) {
|
||||
static bool ggml_vk_flash_attn_scalar_shmem_support(const vk_device& device, const vk_fa_tuning_params& params, uint32_t hsk, uint32_t hsv, bool f32acc, ggml_type k_type, ggml_type v_type) {
|
||||
GGML_UNUSED(f32acc);
|
||||
GGML_UNUSED(v_type);
|
||||
// Needs to be kept up to date on shader changes
|
||||
const uint32_t wg_size = params.workgroup_size;
|
||||
const uint32_t Br = params.block_rows;
|
||||
@@ -8949,10 +8934,7 @@ static bool ggml_vk_flash_attn_scalar_shmem_support(const vk_device& device, con
|
||||
|
||||
const uint32_t float_type_size = device->fp16 ? sizeof(ggml_fp16_t) : sizeof(float);
|
||||
|
||||
const bool mmq = device->integer_dot_product && device->subgroup_clustered &&
|
||||
(kv_type == GGML_TYPE_Q4_0 || kv_type == GGML_TYPE_Q4_1 ||
|
||||
kv_type == GGML_TYPE_Q5_0 || kv_type == GGML_TYPE_Q5_1 ||
|
||||
kv_type == GGML_TYPE_Q8_0 || kv_type == GGML_TYPE_IQ4_NL);
|
||||
const bool mmq = ggml_vk_fa_scalar_uses_mmq(device, k_type);
|
||||
|
||||
// tmpsh is overestimated slightly
|
||||
const uint32_t tmpsh = wg_size * sizeof(float);
|
||||
@@ -8969,17 +8951,10 @@ static bool ggml_vk_flash_attn_scalar_shmem_support(const vk_device& device, con
|
||||
// kvsh uses D = HSV (K goes through kblocksh instead)
|
||||
kvsh = params.shmem_staging ? Bc * (hsv / 4 + 1) * 4 * float_type_size : 4 * float_type_size;
|
||||
|
||||
// block_a_cache size depends on quant type
|
||||
uint32_t block_a_size;
|
||||
switch (kv_type) {
|
||||
case GGML_TYPE_Q4_0: block_a_size = 4 * sizeof(uint32_t) + float_type_size; break;
|
||||
case GGML_TYPE_Q4_1: block_a_size = 4 * sizeof(uint32_t) + 2 * float_type_size; break;
|
||||
case GGML_TYPE_Q5_0: block_a_size = 4 * sizeof(uint32_t) + sizeof(uint32_t) + float_type_size; break;
|
||||
case GGML_TYPE_Q5_1: block_a_size = 4 * sizeof(uint32_t) + sizeof(uint32_t) + 2 * float_type_size; break;
|
||||
case GGML_TYPE_Q8_0:
|
||||
case GGML_TYPE_IQ4_NL: block_a_size = 8 * sizeof(int32_t) + float_type_size; break;
|
||||
default: block_a_size = 0; break;
|
||||
}
|
||||
// The mixed MMQ shader uses a superset block_a_cache that fits every
|
||||
// FA-supported quant: int32_t qs[8] + uint32_t qh + FLOAT_TYPEV2 dm.
|
||||
// Single-scale types leave dm.y unused; non-Q5_* leave qh unused.
|
||||
const uint32_t block_a_size = 8 * sizeof(int32_t) + sizeof(uint32_t) + 2 * float_type_size;
|
||||
kblocksh_size = params.shmem_staging ? Bc * (hsk / 32) * block_a_size : block_a_size;
|
||||
} else {
|
||||
Qf = Br * (hsk / 4 + 1) * 4 * float_type_size;
|
||||
@@ -9117,10 +9092,6 @@ static void ggml_vk_flash_attn(ggml_backend_vk_context * ctx, vk_context& subctx
|
||||
|
||||
tuning_params = get_fa_tuning_params(ctx->device, HSK, HSV, N, KV, k->type, v->type, f32acc);
|
||||
|
||||
if (tuning_params.path != FA_COOPMAT2) {
|
||||
GGML_ASSERT(k->type == v->type);
|
||||
}
|
||||
|
||||
const uint32_t q_stride = (uint32_t)(nbq1 / ggml_type_size(q->type));
|
||||
uint32_t k_stride = (uint32_t)(nbk1 / ggml_type_size(k->type));
|
||||
uint32_t v_stride = (uint32_t)(nbv1 / ggml_type_size(v->type));
|
||||
@@ -9164,7 +9135,7 @@ static void ggml_vk_flash_attn(ggml_backend_vk_context * ctx, vk_context& subctx
|
||||
|
||||
{
|
||||
std::lock_guard<std::recursive_mutex> guard(ctx->device->mutex);
|
||||
auto &pipelines = ctx->device->pipeline_flash_attn_f32_f16[k->type];
|
||||
auto &pipelines = ctx->device->pipeline_flash_attn_f32_f16;
|
||||
auto it = pipelines.find(fa_pipeline_state);
|
||||
if (it != pipelines.end()) {
|
||||
pipeline = it->second;
|
||||
@@ -15642,10 +15613,6 @@ static bool ggml_backend_vk_device_supports_op(ggml_backend_dev_t dev, const ggm
|
||||
if (op->src[3] && op->src[3]->type != GGML_TYPE_F16) {
|
||||
return false;
|
||||
}
|
||||
// mismatching K/V type is currently supported for coopmat2 only.
|
||||
if (op->src[1]->type != op->src[2]->type && !coopmat2) {
|
||||
return false;
|
||||
}
|
||||
auto fa_kv_ok = [coopmat2](ggml_type t) {
|
||||
switch (t) {
|
||||
case GGML_TYPE_F32:
|
||||
|
||||
@@ -22,6 +22,7 @@
|
||||
|
||||
#include "types.glsl"
|
||||
#include "flash_attn_base.glsl"
|
||||
#include "flash_attn_dequant.glsl"
|
||||
|
||||
const uint32_t HSK_per_thread = HSK / D_split;
|
||||
const uint32_t HSV_per_thread = HSV / D_split;
|
||||
@@ -128,18 +129,20 @@ void main() {
|
||||
|
||||
Qf[buf_ib].qs[buf_iqs] = pack32(i8vec4(vals));
|
||||
|
||||
#if defined(DATA_A_Q8_0) || defined(DATA_A_IQ4_NL)
|
||||
if (buf_iqs == 0) {
|
||||
Qf[buf_ib].ds = FLOAT_TYPEV2(qd, 0.0);
|
||||
}
|
||||
#else // Q4_0, Q4_1, Q5_0, Q5_1
|
||||
const FLOAT_TYPE thread_sum = vals.x + vals.y + vals.z + vals.w;
|
||||
const FLOAT_TYPE sum = subgroupClusteredAdd(thread_sum, 8);
|
||||
// Q8_0 K only needs (qd, _); the asymmetric Q4_*/Q5_* family also stores
|
||||
// the row-sum scaled by qd, used in k_dot_correction.
|
||||
if (FaTypeK == FA_TYPE_Q8_0) {
|
||||
if (buf_iqs == 0) {
|
||||
Qf[buf_ib].ds = FLOAT_TYPEV2(qd, 0.0);
|
||||
}
|
||||
} else {
|
||||
const FLOAT_TYPE thread_sum = vals.x + vals.y + vals.z + vals.w;
|
||||
const FLOAT_TYPE sum = subgroupClusteredAdd(thread_sum, 8);
|
||||
|
||||
if (buf_iqs == 0) {
|
||||
Qf[buf_ib].ds = FLOAT_TYPEV2(qd, sum * qd);
|
||||
if (buf_iqs == 0) {
|
||||
Qf[buf_ib].ds = FLOAT_TYPEV2(qd, sum * qd);
|
||||
}
|
||||
}
|
||||
#endif
|
||||
#endif
|
||||
}
|
||||
barrier();
|
||||
@@ -177,13 +180,9 @@ void main() {
|
||||
// mo_offset will point to the tile starting at row i*Br and col 0
|
||||
uint32_t mo_offset = mo_stride * i;
|
||||
|
||||
#if BLOCK_SIZE > 1
|
||||
uint32_t k_offset = (ik2*p.nb12 + ik3*p.nb13) / BLOCK_BYTE_SIZE;
|
||||
uint32_t v_offset = (iv2*p.nb22 + iv3*p.nb23) / BLOCK_BYTE_SIZE;
|
||||
#else
|
||||
uint32_t k_offset = (ik2*p.nb12 + ik3*p.nb13) / 2;
|
||||
uint32_t v_offset = (iv2*p.nb22 + iv3*p.nb23) / 2;
|
||||
#endif
|
||||
// FaBlockBytesK/V == 2 for f16, 16 for f32, ggml block byte size for quants.
|
||||
uint32_t k_offset = (ik2*p.nb12 + ik3*p.nb13) / FaBlockBytesK;
|
||||
uint32_t v_offset = (iv2*p.nb22 + iv3*p.nb23) / FaBlockBytesV;
|
||||
uint32_t m_offset = gqa_iq1*KV;
|
||||
if (p.nem2 != 1 || p.nem3 != 1) {
|
||||
m_offset += ((iq3 % p.nem3) * p.nem2 + (iq2 % p.nem2)) * p.nem1 * KV;
|
||||
@@ -257,21 +256,21 @@ void main() {
|
||||
if (idx + gl_WorkGroupSize.x <= Bc * HSK / 4 || c < Bc) {
|
||||
FLOAT_TYPEV4 K_Tf = FLOAT_TYPEV4(0);
|
||||
if (!KV_bounds_check || j * Bc + c < KV) {
|
||||
#if BLOCK_SIZE > 1
|
||||
uint coord = (j * Bc + c) * k_stride * BLOCK_SIZE + 4 * d;
|
||||
uint ib = coord / BLOCK_SIZE;
|
||||
uint iqs = (coord % BLOCK_SIZE);
|
||||
K_Tf = dequantize4(ib, iqs, k_offset, BINDING_IDX_K);
|
||||
#else
|
||||
K_Tf = FLOAT_TYPEV4(data_kv4[k_offset / 4 + (j * Bc + c) * k_stride / 4 + d]);
|
||||
#endif
|
||||
if (USE_DECODE_K) {
|
||||
uint coord = (j * Bc + c) * k_stride * BLOCK_SIZE_K + 4 * d;
|
||||
uint ib = coord / BLOCK_SIZE_K;
|
||||
uint iqs = (coord % BLOCK_SIZE_K);
|
||||
K_Tf = dequantize4(ib, iqs, k_offset, BINDING_IDX_K);
|
||||
} else {
|
||||
K_Tf = FLOAT_TYPEV4(data_kv4[k_offset / 4 + (j * Bc + c) * k_stride / 4 + d]);
|
||||
}
|
||||
}
|
||||
|
||||
kvsh[c * kvsh_stride + d] = K_Tf;
|
||||
}
|
||||
}
|
||||
#else // MMQ
|
||||
const uint ints_per_block = 8 / QUANT_R_MMQ;
|
||||
const uint ints_per_block = 8u / fa_quant_r_mmq(FaTypeK);
|
||||
const uint quant_iters = Bc * HSK / 32 * ints_per_block;
|
||||
[[unroll]] for (uint32_t idx = 0; idx < quant_iters; idx += gl_WorkGroupSize.x) {
|
||||
const uint32_t iqs = (idx + tid) % ints_per_block;
|
||||
@@ -310,15 +309,13 @@ void main() {
|
||||
FLOAT_TYPEV4 K_Tf;
|
||||
if (SHMEM_STAGING != 0) {
|
||||
K_Tf = kvsh[(c * cols_per_iter + col_tid) * kvsh_stride + (d * D_split + d_tid)];
|
||||
} else {
|
||||
#if BLOCK_SIZE > 1
|
||||
uint coord = (j * Bc + c * cols_per_iter + col_tid) * k_stride * BLOCK_SIZE + 4 * (d * D_split + d_tid);
|
||||
uint ib = coord / BLOCK_SIZE;
|
||||
uint iqs = (coord % BLOCK_SIZE);
|
||||
} else if (USE_DECODE_K) {
|
||||
uint coord = (j * Bc + c * cols_per_iter + col_tid) * k_stride * BLOCK_SIZE_K + 4 * (d * D_split + d_tid);
|
||||
uint ib = coord / BLOCK_SIZE_K;
|
||||
uint iqs = (coord % BLOCK_SIZE_K);
|
||||
K_Tf = dequantize4(ib, iqs, k_offset, BINDING_IDX_K);
|
||||
#else
|
||||
} else {
|
||||
K_Tf = FLOAT_TYPEV4(data_kv4[k_offset / 4 + (j * Bc + c * cols_per_iter + col_tid) * k_stride / 4 + d * D_split + d_tid]);
|
||||
#endif
|
||||
}
|
||||
[[unroll]] for (uint32_t r = 0; r < rows_per_thread; ++r) {
|
||||
Sf[r][c] += dot(ACC_TYPEV4(Q_cache[r]), ACC_TYPEV4(K_Tf));
|
||||
@@ -335,15 +332,13 @@ void main() {
|
||||
FLOAT_TYPEV4 K_Tf;
|
||||
if (SHMEM_STAGING != 0) {
|
||||
K_Tf = kvsh[(c * cols_per_iter + col_tid) * kvsh_stride + (d * D_split + d_tid)];
|
||||
} else {
|
||||
#if BLOCK_SIZE > 1
|
||||
uint coord = (j * Bc + c * cols_per_iter + col_tid) * k_stride * BLOCK_SIZE + 4 * (d * D_split + d_tid);
|
||||
uint ib = coord / BLOCK_SIZE;
|
||||
uint iqs = (coord % BLOCK_SIZE);
|
||||
} else if (USE_DECODE_K) {
|
||||
uint coord = (j * Bc + c * cols_per_iter + col_tid) * k_stride * BLOCK_SIZE_K + 4 * (d * D_split + d_tid);
|
||||
uint ib = coord / BLOCK_SIZE_K;
|
||||
uint iqs = (coord % BLOCK_SIZE_K);
|
||||
K_Tf = dequantize4(ib, iqs, k_offset, BINDING_IDX_K);
|
||||
#else
|
||||
} else {
|
||||
K_Tf = FLOAT_TYPEV4(data_kv4[k_offset / 4 + (j * Bc + c * cols_per_iter + col_tid) * k_stride / 4 + d * D_split + d_tid]);
|
||||
#endif
|
||||
}
|
||||
[[unroll]] for (uint32_t r = 0; r < rows_per_thread; ++r) {
|
||||
Sf[r][c] += dot(ACC_TYPEV4(Qf[tile_row(r) * qf_stride + d * D_split + d_tid]), ACC_TYPEV4(K_Tf));
|
||||
@@ -366,72 +361,47 @@ void main() {
|
||||
int32_t k_quants[d_per_step];
|
||||
ACC_TYPEV2 k_dm;
|
||||
|
||||
// Q4_*/Q5_* take the block-8 fast path when one step covers a full
|
||||
// block; Q8_0 always goes through the per-int get_k_qs* helpers
|
||||
// (its qs is byte-packed, not nibble-packed).
|
||||
const bool block8_fast = (d_per_step == 8) && (FaTypeK != FA_TYPE_Q8_0);
|
||||
|
||||
if (SHMEM_STAGING != 0) {
|
||||
const uint k_block_idx = (d_tid * (HSK_per_thread / 4) + d_block) / 8;
|
||||
const uint buf_ib = (c * cols_per_iter + col_tid) * qf_stride + k_block_idx;
|
||||
#if QUANT_AUXF == 1
|
||||
k_dm = ACC_TYPEV2(kblocksh[buf_ib].dm, 0.0);
|
||||
#else
|
||||
k_dm = ACC_TYPEV2(kblocksh[buf_ib].dm);
|
||||
#endif
|
||||
|
||||
#if defined(DATA_A_Q4_0) || defined(DATA_A_Q4_1) || defined(DATA_A_Q5_0) || defined(DATA_A_Q5_1)
|
||||
if (d_per_step == 8) {
|
||||
if (block8_fast) {
|
||||
const bool has_qh = (FaTypeK == FA_TYPE_Q5_0) || (FaTypeK == FA_TYPE_Q5_1);
|
||||
[[unroll]] for (uint32_t d = 0; d < 4; d++) {
|
||||
uint vui = kblocksh[buf_ib].qs[d];
|
||||
k_quants[d ] = int32_t( vui & 0x0F0F0F0F);
|
||||
k_quants[d + 4] = int32_t((vui >> 4) & 0x0F0F0F0F);
|
||||
#if defined(DATA_A_Q5_0) || defined(DATA_A_Q5_1)
|
||||
uint qh_lo = (kblocksh[buf_ib].qh >> (d * 4)) & 0xF;
|
||||
uint qh_hi = (kblocksh[buf_ib].qh >> (d * 4 + 16)) & 0xF;
|
||||
k_quants[d ] |= int32_t((qh_lo * 0x02040810u) & 0x10101010u);
|
||||
k_quants[d + 4] |= int32_t((qh_hi * 0x02040810u) & 0x10101010u);
|
||||
#endif
|
||||
if (has_qh) {
|
||||
uint qh_lo = (kblocksh[buf_ib].qh >> (d * 4)) & 0xF;
|
||||
uint qh_hi = (kblocksh[buf_ib].qh >> (d * 4 + 16)) & 0xF;
|
||||
k_quants[d ] |= int32_t((qh_lo * 0x02040810u) & 0x10101010u);
|
||||
k_quants[d + 4] |= int32_t((qh_hi * 0x02040810u) & 0x10101010u);
|
||||
}
|
||||
}
|
||||
} else
|
||||
#endif
|
||||
{
|
||||
} else {
|
||||
[[unroll]] for (uint32_t d = 0; d < d_per_step; d++) {
|
||||
k_quants[d] = get_k_qs_shmem(buf_ib, (d_tid * (HSK_per_thread / 4) + d_block) % 8 + d);
|
||||
}
|
||||
}
|
||||
} else {
|
||||
const uint coord = (j * Bc + c * cols_per_iter + col_tid) * k_stride * BLOCK_SIZE + 4 * (d_tid * (HSK_per_thread / 4) + d_block);
|
||||
const uint ib = coord / BLOCK_SIZE;
|
||||
const uint iqs = (coord % BLOCK_SIZE);
|
||||
const uint coord = (j * Bc + c * cols_per_iter + col_tid) * k_stride * BLOCK_SIZE_K + 4 * (d_tid * (HSK_per_thread / 4) + d_block);
|
||||
const uint ib = coord / BLOCK_SIZE_K;
|
||||
const uint iqs = (coord % BLOCK_SIZE_K);
|
||||
|
||||
#if QUANT_AUXF == 1
|
||||
k_dm = ACC_TYPEV2(get_k_d(ib, k_offset), 0.0);
|
||||
#else
|
||||
k_dm = ACC_TYPEV2(get_k_dm(ib, k_offset));
|
||||
#endif
|
||||
#if defined(DATA_A_Q4_0) || defined(DATA_A_Q4_1) || defined(DATA_A_Q5_0) || defined(DATA_A_Q5_1)
|
||||
if (d_per_step == 8) {
|
||||
#if defined(DATA_A_Q5_0)
|
||||
uint qh = pack32(u16vec2(k_packed.k_data_packed16[k_offset + ib].qh[0],
|
||||
k_packed.k_data_packed16[k_offset + ib].qh[1]));
|
||||
#elif defined(DATA_A_Q5_1)
|
||||
uint qh = k_packed.k_data_packed16[k_offset + ib].qh;
|
||||
#endif
|
||||
[[unroll]] for (uint32_t d = 0; d < 4; d++) {
|
||||
#if defined(A_TYPE_PACKED32)
|
||||
uint vui = k_packed32.k_data_packed32[k_offset + ib].qs[d];
|
||||
#else
|
||||
uint vui = pack32(u16vec2(k_packed.k_data_packed16[k_offset + ib].qs[iqs / 2 + d * 2 + 0],
|
||||
k_packed.k_data_packed16[k_offset + ib].qs[iqs / 2 + d * 2 + 1]));
|
||||
#endif
|
||||
k_quants[d ] = int32_t( vui & 0x0F0F0F0F);
|
||||
k_quants[d + 4] = int32_t((vui >> 4) & 0x0F0F0F0F);
|
||||
#if defined(DATA_A_Q5_0) || defined(DATA_A_Q5_1)
|
||||
uint qh_lo = (qh >> (d * 4)) & 0xF;
|
||||
uint qh_hi = (qh >> (d * 4 + 16)) & 0xF;
|
||||
k_quants[d ] |= int32_t((qh_lo * 0x02040810u) & 0x10101010u);
|
||||
k_quants[d + 4] |= int32_t((qh_hi * 0x02040810u) & 0x10101010u);
|
||||
#endif
|
||||
k_dm = ACC_TYPEV2(get_k_scale(ib, k_offset));
|
||||
|
||||
if (block8_fast) {
|
||||
fa_k_qs_block8 blk = get_k_qs_block8(ib, k_offset);
|
||||
[[unroll]] for (uint32_t d = 0; d < 8; d++) {
|
||||
k_quants[d] = blk.qs[d];
|
||||
}
|
||||
} else
|
||||
#endif
|
||||
{
|
||||
} else {
|
||||
[[unroll]] for (uint32_t d = 0; d < d_per_step; d++) {
|
||||
k_quants[d] = get_k_qs(ib, iqs + d * 4, k_offset);
|
||||
}
|
||||
@@ -516,14 +486,14 @@ void main() {
|
||||
if (idx + gl_WorkGroupSize.x <= Bc * HSV / 4 || c < Bc) {
|
||||
FLOAT_TYPEV4 V_Tf = FLOAT_TYPEV4(0);
|
||||
if (!KV_bounds_check || j * Bc + c < KV) {
|
||||
#if BLOCK_SIZE > 1
|
||||
uint coord = (j * Bc + c) * v_stride * BLOCK_SIZE + 4 * d;
|
||||
uint ib = coord / BLOCK_SIZE;
|
||||
uint iqs = (coord % BLOCK_SIZE);
|
||||
V_Tf = dequantize4(ib, iqs, v_offset, BINDING_IDX_V);
|
||||
#else
|
||||
V_Tf = FLOAT_TYPEV4(data_vv4[v_offset / 4 + (j * Bc + c) * v_stride / 4 + d]);
|
||||
#endif
|
||||
if (USE_DECODE_V) {
|
||||
uint coord = (j * Bc + c) * v_stride * BLOCK_SIZE_V + 4 * d;
|
||||
uint ib = coord / BLOCK_SIZE_V;
|
||||
uint iqs = (coord % BLOCK_SIZE_V);
|
||||
V_Tf = dequantize4(ib, iqs, v_offset, BINDING_IDX_V);
|
||||
} else {
|
||||
V_Tf = FLOAT_TYPEV4(data_vv4[v_offset / 4 + (j * Bc + c) * v_stride / 4 + d]);
|
||||
}
|
||||
}
|
||||
|
||||
kvsh[c * kvsh_stride + d] = V_Tf;
|
||||
@@ -547,15 +517,13 @@ void main() {
|
||||
FLOAT_TYPEV4 Vf;
|
||||
if (SHMEM_STAGING != 0) {
|
||||
Vf = kvsh[(c * cols_per_iter + col_tid) * kvsh_stride + (d * D_split + d_tid)];
|
||||
} else {
|
||||
#if BLOCK_SIZE > 1
|
||||
uint coord = (j * Bc + c * cols_per_iter + col_tid) * v_stride * BLOCK_SIZE + 4 * (d * D_split + d_tid);
|
||||
uint ib = coord / BLOCK_SIZE;
|
||||
uint iqs = (coord % BLOCK_SIZE);
|
||||
} else if (USE_DECODE_V) {
|
||||
uint coord = (j * Bc + c * cols_per_iter + col_tid) * v_stride * BLOCK_SIZE_V + 4 * (d * D_split + d_tid);
|
||||
uint ib = coord / BLOCK_SIZE_V;
|
||||
uint iqs = (coord % BLOCK_SIZE_V);
|
||||
Vf = dequantize4(ib, iqs, v_offset, BINDING_IDX_V);
|
||||
#else
|
||||
} else {
|
||||
Vf = FLOAT_TYPEV4(data_vv4[v_offset / 4 + (j * Bc + c * cols_per_iter + col_tid) * v_stride / 4 + d * D_split + d_tid]);
|
||||
#endif
|
||||
}
|
||||
[[unroll]] for (uint32_t r = 0; r < rows_per_thread; ++r) {
|
||||
Of[r][d] += FLOAT_TYPEV4(Pf[r] * Vf);
|
||||
|
||||
@@ -87,176 +87,58 @@ layout (binding = 6) readonly buffer MO {uint32_t data_mask_opt[];};
|
||||
|
||||
#define BINDING_IDX_K 0
|
||||
#define BINDING_IDX_V 1
|
||||
#if defined(DATA_A_F32)
|
||||
layout (binding = 1) readonly buffer K_PACKED {vec4 k_data_packed[];} k_packed;
|
||||
layout (binding = 2) readonly buffer V_PACKED {vec4 v_data_packed[];} v_packed;
|
||||
#elif defined(A_TYPE_PACKED16)
|
||||
layout (binding = 1) readonly buffer K_PACKED16 {A_TYPE_PACKED16 k_data_packed16[];} k_packed;
|
||||
layout (binding = 2) readonly buffer V_PACKED16 {A_TYPE_PACKED16 v_data_packed16[];} v_packed;
|
||||
#endif
|
||||
|
||||
#if defined(A_TYPE_PACKED32)
|
||||
layout (binding = 1) readonly buffer K_PACKED32 {A_TYPE_PACKED32 k_data_packed32[];} k_packed32;
|
||||
layout (binding = 2) readonly buffer V_PACKED32 {A_TYPE_PACKED32 v_data_packed32[];} v_packed32;
|
||||
#endif
|
||||
// FaTypeK / FaTypeV spec constant values. These mirror enum ggml_type so the
|
||||
// host can pass the type directly. Keep in sync with ggml.h.
|
||||
#define FA_TYPE_F32 0u
|
||||
#define FA_TYPE_F16 1u
|
||||
#define FA_TYPE_Q4_0 2u
|
||||
#define FA_TYPE_Q4_1 3u
|
||||
#define FA_TYPE_Q5_0 6u
|
||||
#define FA_TYPE_Q5_1 7u
|
||||
#define FA_TYPE_Q8_0 8u
|
||||
#define FA_TYPE_Q1_0 41u
|
||||
|
||||
#ifndef BLOCK_SIZE
|
||||
#define BLOCK_SIZE 1
|
||||
#endif
|
||||
|
||||
#if defined(DATA_A_F32)
|
||||
#undef BLOCK_SIZE
|
||||
#define BLOCK_SIZE 4
|
||||
#define BLOCK_BYTE_SIZE 16
|
||||
|
||||
FLOAT_TYPEV4 dequantize4(uint ib, uint iqs, uint a_offset, uint binding_idx) {
|
||||
// iqs is currently always zero in the flash attention shaders
|
||||
if (binding_idx == BINDING_IDX_K) {
|
||||
return FLOAT_TYPEV4(k_packed.k_data_packed[a_offset + ib]);
|
||||
} else {
|
||||
return FLOAT_TYPEV4(v_packed.v_data_packed[a_offset + ib]);
|
||||
// Number of matrix elements per buffer block, derived from the K/V type spec
|
||||
// constant. F32 is treated as a vec4 "block" of 4 floats. F16 uses block size 1
|
||||
// and bypasses the dequant path entirely. Quants follow their ggml block sizes.
// `ty` is expected to be one of the FA_TYPE_* values; any value not listed in
// the switch below falls through to the default case and yields 1.
|
||||
uint fa_block_elems(uint ty) {
|
||||
switch (ty) {
|
||||
case FA_TYPE_F32: return 4u;
|
||||
case FA_TYPE_F16: return 1u;
|
||||
case FA_TYPE_Q4_0: return uint(QUANT_K_Q4_0);
|
||||
case FA_TYPE_Q4_1: return uint(QUANT_K_Q4_1);
|
||||
case FA_TYPE_Q5_0: return uint(QUANT_K_Q5_0);
|
||||
case FA_TYPE_Q5_1: return uint(QUANT_K_Q5_1);
|
||||
case FA_TYPE_Q8_0: return uint(QUANT_K_Q8_0);
|
||||
case FA_TYPE_Q1_0: return uint(QUANT_K_Q1_0); // cm2-only, harmless elsewhere
|
||||
// Unlisted type values are given a block size of 1, the same value as F16.
default: return 1u;
|
||||
}
|
||||
}
|
||||
#endif
|
||||
|
||||
#if defined(DATA_A_Q4_0)
|
||||
#define BLOCK_BYTE_SIZE 18
|
||||
#elif defined(DATA_A_Q4_1)
|
||||
#define BLOCK_BYTE_SIZE 20
|
||||
#endif
|
||||
|
||||
#if defined(DATA_A_Q4_0) || defined(DATA_A_Q4_1)
|
||||
FLOAT_TYPEV4 dequantize4(uint ib, uint iqs, uint a_offset, uint binding_idx) {
|
||||
if (binding_idx == BINDING_IDX_K) {
|
||||
uint vui_lo = uint(k_packed.k_data_packed16[a_offset + ib].qs[(iqs & 0xF) / 2 + 0]);
|
||||
uint vui_hi = uint(k_packed.k_data_packed16[a_offset + ib].qs[(iqs & 0xF) / 2 + 1]);
|
||||
uint shift = (iqs & 0x10) >> 2;
|
||||
vui_lo >>= shift;
|
||||
vui_hi >>= shift;
|
||||
|
||||
FLOAT_TYPEV4 nibbles = FLOAT_TYPEV4(vui_lo & 0xF, (vui_lo >> 8) & 0xF, vui_hi & 0xF, (vui_hi >> 8) & 0xF);
|
||||
#ifdef DATA_A_Q4_1
|
||||
return FLOAT_TYPE(k_packed.k_data_packed16[a_offset + ib].d) * nibbles + FLOAT_TYPE(k_packed.k_data_packed16[a_offset + ib].m);
|
||||
#else
|
||||
return FLOAT_TYPE(k_packed.k_data_packed16[a_offset + ib].d) * (nibbles - FLOAT_TYPE(8.0f));
|
||||
#endif
|
||||
} else {
|
||||
uint vui_lo = uint(v_packed.v_data_packed16[a_offset + ib].qs[(iqs & 0xF) / 2 + 0]);
|
||||
uint vui_hi = uint(v_packed.v_data_packed16[a_offset + ib].qs[(iqs & 0xF) / 2 + 1]);
|
||||
uint shift = (iqs & 0x10) >> 2;
|
||||
vui_lo >>= shift;
|
||||
vui_hi >>= shift;
|
||||
|
||||
FLOAT_TYPEV4 nibbles = FLOAT_TYPEV4(vui_lo & 0xF, (vui_lo >> 8) & 0xF, vui_hi & 0xF, (vui_hi >> 8) & 0xF);
|
||||
#ifdef DATA_A_Q4_1
|
||||
return FLOAT_TYPE(v_packed.v_data_packed16[a_offset + ib].d) * nibbles + FLOAT_TYPE(v_packed.v_data_packed16[a_offset + ib].m);
|
||||
#else
|
||||
return FLOAT_TYPE(v_packed.v_data_packed16[a_offset + ib].d) * (nibbles - FLOAT_TYPE(8.0f));
|
||||
#endif
|
||||
// QUANT_R_MMQ for FA-eligible K types. Q4_*/Q5_* store two nibbles per byte
|
||||
// (R==2); Q8_0 stores one byte per element (R==1). Used to derive the number
|
||||
// of int32s per 32-element block on the MMQ K path: ints_per_block == 8 / R.
|
||||
// Returns the QUANT_R_* constant for FA type id `ty`.
//   ty: one of the FA_TYPE_* ids.
// Only Q4_0, Q4_1, Q5_0, Q5_1 and Q8_0 have cases; every other id yields 1.
uint fa_quant_r_mmq(uint ty) {
|
||||
switch (ty) {
|
||||
case FA_TYPE_Q4_0: return uint(QUANT_R_Q4_0);
|
||||
case FA_TYPE_Q4_1: return uint(QUANT_R_Q4_1);
|
||||
case FA_TYPE_Q5_0: return uint(QUANT_R_Q5_0);
|
||||
case FA_TYPE_Q5_1: return uint(QUANT_R_Q5_1);
|
||||
case FA_TYPE_Q8_0: return uint(QUANT_R_Q8_0);
|
||||
// F32, F16, Q1_0 and unknown ids all land here.
default: return 1u;
|
||||
}
|
||||
}
|
||||
#endif
|
||||
|
||||
#if defined(DATA_A_Q5_0)
|
||||
#define BLOCK_BYTE_SIZE 22
|
||||
#elif defined(DATA_A_Q5_1)
|
||||
#define BLOCK_BYTE_SIZE 24
|
||||
#endif
|
||||
|
||||
#if defined(DATA_A_Q5_0) || defined(DATA_A_Q5_1)
|
||||
// Dequantizes four values of a Q5_0 / Q5_1 block from the K or V buffer.
//   ib          : block index, added to a_offset to address the block.
//   iqs         : element offset inside the block; its low 4 bits pick the qs
//                 words, bit 4 picks the nibble, and it is also the first qh bit read.
//   a_offset    : base block offset into the packed16 array.
//   binding_idx : BINDING_IDX_K reads k_packed; any other value reads v_packed.
// Returns d * (nibbles + hb) + m for Q5_1, or d * (nibbles + hb - 16) for Q5_0.
FLOAT_TYPEV4 dequantize4(uint ib, uint iqs, uint a_offset, uint binding_idx) {
|
||||
if (binding_idx == BINDING_IDX_K) {
|
||||
// Two adjacent 16-bit qs words; each contributes two bytes.
uint vui_lo = uint(k_packed.k_data_packed16[a_offset + ib].qs[(iqs & 0xF) / 2 + 0]);
|
||||
uint vui_hi = uint(k_packed.k_data_packed16[a_offset + ib].qs[(iqs & 0xF) / 2 + 1]);
|
||||
// shift is 4 when bit 4 of iqs is set (selects the upper nibble of each byte), else 0.
uint shift = (iqs & 0x10) >> 2;
|
||||
vui_lo >>= shift;
|
||||
vui_hi >>= shift;
|
||||
|
||||
// Q5_1 stores qh as a single value; Q5_0 stores it as two 16-bit halves.
#ifdef DATA_A_Q5_1
|
||||
uint qh = k_packed.k_data_packed16[a_offset + ib].qh;
|
||||
#else
|
||||
uint qh = uint(k_packed.k_data_packed16[a_offset + ib].qh[0]) | (uint(k_packed.k_data_packed16[a_offset + ib].qh[1]) << 16);
|
||||
#endif
|
||||
// Bits iqs..iqs+3 of qh, each scaled to 16 (the fifth bit of the quant value).
FLOAT_TYPEV4 hb = FLOAT_TYPEV4((qh >> iqs) & 1, (qh >> (iqs + 1)) & 1, (qh >> (iqs + 2)) & 1, (qh >> (iqs + 3)) & 1) * FLOAT_TYPE(16.0f);
|
||||
|
||||
// Low 4 bits of each of the four bytes taken from the two qs words.
FLOAT_TYPEV4 nibbles = FLOAT_TYPEV4(vui_lo & 0xF, (vui_lo >> 8) & 0xF, vui_hi & 0xF, (vui_hi >> 8) & 0xF);
|
||||
#ifdef DATA_A_Q5_1
|
||||
return FLOAT_TYPE(k_packed.k_data_packed16[a_offset + ib].d) * (nibbles + hb) + FLOAT_TYPE(k_packed.k_data_packed16[a_offset + ib].m);
|
||||
#else
|
||||
return FLOAT_TYPE(k_packed.k_data_packed16[a_offset + ib].d) * (nibbles + hb - FLOAT_TYPE(16.0f));
|
||||
#endif
|
||||
} else {
|
||||
// Same decode as the K branch, reading from the V buffer.
uint vui_lo = uint(v_packed.v_data_packed16[a_offset + ib].qs[(iqs & 0xF) / 2 + 0]);
|
||||
uint vui_hi = uint(v_packed.v_data_packed16[a_offset + ib].qs[(iqs & 0xF) / 2 + 1]);
|
||||
uint shift = (iqs & 0x10) >> 2;
|
||||
vui_lo >>= shift;
|
||||
vui_hi >>= shift;
|
||||
|
||||
#ifdef DATA_A_Q5_1
|
||||
uint qh = v_packed.v_data_packed16[a_offset + ib].qh;
|
||||
#else
|
||||
uint qh = uint(v_packed.v_data_packed16[a_offset + ib].qh[0]) | (uint(v_packed.v_data_packed16[a_offset + ib].qh[1]) << 16);
|
||||
#endif
|
||||
FLOAT_TYPEV4 hb = FLOAT_TYPEV4((qh >> iqs) & 1, (qh >> (iqs + 1)) & 1, (qh >> (iqs + 2)) & 1, (qh >> (iqs + 3)) & 1) * FLOAT_TYPE(16.0f);
|
||||
|
||||
FLOAT_TYPEV4 nibbles = FLOAT_TYPEV4(vui_lo & 0xF, (vui_lo >> 8) & 0xF, vui_hi & 0xF, (vui_hi >> 8) & 0xF);
|
||||
#ifdef DATA_A_Q5_1
|
||||
return FLOAT_TYPE(v_packed.v_data_packed16[a_offset + ib].d) * (nibbles + hb) + FLOAT_TYPE(v_packed.v_data_packed16[a_offset + ib].m);
|
||||
#else
|
||||
return FLOAT_TYPE(v_packed.v_data_packed16[a_offset + ib].d) * (nibbles + hb - FLOAT_TYPE(16.0f));
|
||||
#endif
|
||||
}
|
||||
}
|
||||
#endif
|
||||
|
||||
|
||||
#if defined(DATA_A_IQ4_NL)
|
||||
#define BLOCK_BYTE_SIZE 18
|
||||
|
||||
// Dequantizes four values of an IQ4_NL block from the K or V buffer.
//   ib          : block index, added to a_offset to address the block.
//   iqs         : element offset inside the block; low 4 bits pick the qs words,
//                 bit 4 picks which nibble of each byte is used.
//   a_offset    : base block offset into the packed16 array.
//   binding_idx : BINDING_IDX_K reads k_packed; any other value reads v_packed.
// Returns the block scale d times the four kvalues_iq4nl table entries selected
// by the 4-bit indices.
FLOAT_TYPEV4 dequantize4(uint ib, uint iqs, uint a_offset, uint binding_idx) {
|
||||
if (binding_idx == BINDING_IDX_K) {
|
||||
uint vui_lo = uint(k_packed.k_data_packed16[a_offset + ib].qs[(iqs & 0xF) / 2 + 0]);
|
||||
uint vui_hi = uint(k_packed.k_data_packed16[a_offset + ib].qs[(iqs & 0xF) / 2 + 1]);
|
||||
// shift is 4 when bit 4 of iqs is set, else 0.
uint shift = (iqs & 0x10) >> 2;
|
||||
vui_lo >>= shift;
|
||||
vui_hi >>= shift;
|
||||
|
||||
// Each 4-bit index is mapped through the kvalues_iq4nl lookup table.
return FLOAT_TYPE(k_packed.k_data_packed16[a_offset + ib].d) * FLOAT_TYPEV4(
|
||||
kvalues_iq4nl[vui_lo & 0xF],
|
||||
kvalues_iq4nl[(vui_lo >> 8) & 0xF],
|
||||
kvalues_iq4nl[vui_hi & 0xF],
|
||||
kvalues_iq4nl[(vui_hi >> 8) & 0xF]);
|
||||
} else {
|
||||
// Same decode as the K branch, reading from the V buffer.
uint vui_lo = uint(v_packed.v_data_packed16[a_offset + ib].qs[(iqs & 0xF) / 2 + 0]);
|
||||
uint vui_hi = uint(v_packed.v_data_packed16[a_offset + ib].qs[(iqs & 0xF) / 2 + 1]);
|
||||
uint shift = (iqs & 0x10) >> 2;
|
||||
vui_lo >>= shift;
|
||||
vui_hi >>= shift;
|
||||
|
||||
return FLOAT_TYPE(v_packed.v_data_packed16[a_offset + ib].d) * FLOAT_TYPEV4(
|
||||
kvalues_iq4nl[vui_lo & 0xF],
|
||||
kvalues_iq4nl[(vui_lo >> 8) & 0xF],
|
||||
kvalues_iq4nl[vui_hi & 0xF],
|
||||
kvalues_iq4nl[(vui_hi >> 8) & 0xF]);
|
||||
}
|
||||
}
|
||||
#endif
|
||||
#if defined(DATA_A_Q8_0)
|
||||
#define BLOCK_BYTE_SIZE 34
|
||||
// Dequantizes four values of a Q8_0 block from the K or V buffer.
//   ib          : block index, added to a_offset to address the block.
//   iqs         : element offset inside the block; qs words iqs/2 and iqs/2 + 1 are read.
//   a_offset    : base block offset into the packed16 array.
//   binding_idx : BINDING_IDX_K reads k_packed; any other value reads v_packed.
// Returns the block scale d times the four signed 8-bit values.
FLOAT_TYPEV4 dequantize4(uint ib, uint iqs, uint a_offset, uint binding_idx) {
|
||||
if (binding_idx == BINDING_IDX_K) {
|
||||
// Each qs word is unpacked to 8-bit lanes and only the first two (.xy) are kept.
const i8vec2 v0 = unpack8(int32_t(k_packed.k_data_packed16[a_offset + ib].qs[iqs / 2])).xy; // vec4 used due to #12147
|
||||
const i8vec2 v1 = unpack8(int32_t(k_packed.k_data_packed16[a_offset + ib].qs[iqs / 2 + 1])).xy;
|
||||
|
||||
return FLOAT_TYPE(k_packed.k_data_packed16[a_offset + ib].d) * FLOAT_TYPEV4(v0.x, v0.y, v1.x, v1.y);
|
||||
} else {
|
||||
// Same decode as the K branch, reading from the V buffer.
const i8vec2 v0 = unpack8(int32_t(v_packed.v_data_packed16[a_offset + ib].qs[iqs / 2])).xy; // vec4 used due to #12147
|
||||
const i8vec2 v1 = unpack8(int32_t(v_packed.v_data_packed16[a_offset + ib].qs[iqs / 2 + 1])).xy;
|
||||
|
||||
return FLOAT_TYPE(v_packed.v_data_packed16[a_offset + ib].d) * FLOAT_TYPEV4(v0.x, v0.y, v1.x, v1.y);
|
||||
}
|
||||
}
|
||||
#endif
|
||||
// These can't be `const` globals because GLSL forbids function calls in global
|
||||
// const initializers, even when the spec constants would let the driver fold
|
||||
// them. Macros expand at the use site and fold after specialization.
|
||||
#define BLOCK_SIZE_K fa_block_elems(FaTypeK)
|
||||
#define BLOCK_SIZE_V fa_block_elems(FaTypeV)
|
||||
// F16 reads f16 elements directly from the binding; everything else routes
|
||||
// through dequantize4 / the MMQ helpers to unpack from the packed block layout.
|
||||
#define USE_DECODE_K (FaTypeK != FA_TYPE_F16)
|
||||
#define USE_DECODE_V (FaTypeV != FA_TYPE_F16)
|
||||
|
||||
#define CEIL_DIV(a, b) (((a) + (b) - 1) / (b))
|
||||
|
||||
|
||||
@@ -14,6 +14,7 @@
|
||||
|
||||
#include "types.glsl"
|
||||
#include "flash_attn_base.glsl"
|
||||
#include "flash_attn_dequant.glsl"
|
||||
|
||||
// These need to be supported N,M values for a MatBc x MatBr x 16 coopmatmuladd
|
||||
const uint32_t MatBr = 16;
|
||||
@@ -127,13 +128,9 @@ void main() {
|
||||
// mo_offset will point to the tile starting at row i*Br and col 0
|
||||
uint32_t mo_offset = mo_stride * i;
|
||||
|
||||
#if BLOCK_SIZE > 1
|
||||
uint32_t k_offset = (ik2*p.nb12 + ik3*p.nb13) / BLOCK_BYTE_SIZE;
|
||||
uint32_t v_offset = (iv2*p.nb22 + iv3*p.nb23) / BLOCK_BYTE_SIZE;
|
||||
#else
|
||||
uint32_t k_offset = (ik2*p.nb12 + ik3*p.nb13) / 2;
|
||||
uint32_t v_offset = (iv2*p.nb22 + iv3*p.nb23) / 2;
|
||||
#endif
|
||||
// FaBlockBytesK/V is 2 for f16 (sizeof f16), 16 for f32 (one vec4), and the ggml block size in bytes for quants.
|
||||
uint32_t k_offset = (ik2*p.nb12 + ik3*p.nb13) / FaBlockBytesK;
|
||||
uint32_t v_offset = (iv2*p.nb22 + iv3*p.nb23) / FaBlockBytesV;
|
||||
uint32_t m_offset = gqa_iq1*KV;
|
||||
if (p.nem2 != 1 || p.nem3 != 1) {
|
||||
m_offset += ((iq3 % p.nem3) * p.nem2 + (iq2 % p.nem2)) * p.nem1 * KV;
|
||||
@@ -227,14 +224,14 @@ void main() {
|
||||
if (idx + gl_WorkGroupSize.x <= Bc * HSK_pad / 4 || c < Bc) {
|
||||
f16vec4 K_Tf = f16vec4(0);
|
||||
if ((!KV_bounds_check || j * Bc + c < KV) && (HSK == HSK_pad || d < HSK / 4)) {
|
||||
#if BLOCK_SIZE > 1
|
||||
uint coord = (j * Bc + c) * k_stride * BLOCK_SIZE + 4 * d;
|
||||
uint ib = coord / BLOCK_SIZE;
|
||||
uint iqs = (coord % BLOCK_SIZE);
|
||||
K_Tf = dequantize4(ib, iqs, k_offset, BINDING_IDX_K);
|
||||
#else
|
||||
K_Tf = f16vec4(data_kv4[k_offset / 4 + (j * Bc + c) * k_stride / 4 + d]);
|
||||
#endif
|
||||
if (USE_DECODE_K) {
|
||||
uint coord = (j * Bc + c) * k_stride * BLOCK_SIZE_K + 4 * d;
|
||||
uint ib = coord / BLOCK_SIZE_K;
|
||||
uint iqs = (coord % BLOCK_SIZE_K);
|
||||
K_Tf = dequantize4(ib, iqs, k_offset, BINDING_IDX_K);
|
||||
} else {
|
||||
K_Tf = f16vec4(data_kv4[k_offset / 4 + (j * Bc + c) * k_stride / 4 + d]);
|
||||
}
|
||||
}
|
||||
|
||||
kvsh[c * kvsh_stride + d] = K_Tf;
|
||||
@@ -256,47 +253,40 @@ void main() {
|
||||
// staged through a Bc * MatBr size staging buffer.
|
||||
// If K is not type f16, then it is always staged for dequantization.
|
||||
if (SHMEM_STAGING == 0) {
|
||||
#if BLOCK_SIZE == 1
|
||||
if (KV_bounds_check || d * 16 + 16 > HSK) {
|
||||
#endif
|
||||
barrier();
|
||||
[[unroll]] for (uint32_t idx = 0; idx < Bc * MatBr / 4; idx += gl_WorkGroupSize.x) {
|
||||
uint32_t col_vec = (idx + tid) % (MatBr / 4);
|
||||
uint32_t row = (idx + tid) / (MatBr / 4);
|
||||
if (idx + tid < Bc * MatBr / 4) {
|
||||
f16vec4 K_Tf = f16vec4(0);
|
||||
if ((!KV_bounds_check || j * Bc + row < KV) && (HSK == HSK_pad || d * 16 + col_vec * 4 < HSK)) {
|
||||
#if BLOCK_SIZE > 1
|
||||
uint coord = (j * Bc + row) * k_stride * BLOCK_SIZE + d * 16 + col_vec * 4;
|
||||
uint ib = coord / BLOCK_SIZE;
|
||||
uint iqs = (coord % BLOCK_SIZE);
|
||||
K_Tf = dequantize4(ib, iqs, k_offset, BINDING_IDX_K);
|
||||
#else
|
||||
K_Tf = f16vec4(data_kv4[k_offset / 4 + (j * Bc + row) * k_stride / 4 + d * 16 / 4 + col_vec]);
|
||||
#endif
|
||||
// For quants we always need to dequant into kvsh; for f16 we can load
|
||||
// directly from global memory when alignment / bounds allow it.
|
||||
const bool stage_k = USE_DECODE_K || KV_bounds_check || d * 16 + 16 > HSK;
|
||||
if (stage_k) {
|
||||
barrier();
|
||||
[[unroll]] for (uint32_t idx = 0; idx < Bc * MatBr / 4; idx += gl_WorkGroupSize.x) {
|
||||
uint32_t col_vec = (idx + tid) % (MatBr / 4);
|
||||
uint32_t row = (idx + tid) / (MatBr / 4);
|
||||
if (idx + tid < Bc * MatBr / 4) {
|
||||
f16vec4 K_Tf = f16vec4(0);
|
||||
if ((!KV_bounds_check || j * Bc + row < KV) && (HSK == HSK_pad || d * 16 + col_vec * 4 < HSK)) {
|
||||
if (USE_DECODE_K) {
|
||||
uint coord = (j * Bc + row) * k_stride * BLOCK_SIZE_K + d * 16 + col_vec * 4;
|
||||
uint ib = coord / BLOCK_SIZE_K;
|
||||
uint iqs = (coord % BLOCK_SIZE_K);
|
||||
K_Tf = dequantize4(ib, iqs, k_offset, BINDING_IDX_K);
|
||||
} else {
|
||||
K_Tf = f16vec4(data_kv4[k_offset / 4 + (j * Bc + row) * k_stride / 4 + d * 16 / 4 + col_vec]);
|
||||
}
|
||||
}
|
||||
|
||||
kvsh[row * kvsh_stride + col_vec] = K_Tf;
|
||||
}
|
||||
|
||||
kvsh[row * kvsh_stride + col_vec] = K_Tf;
|
||||
}
|
||||
barrier();
|
||||
}
|
||||
barrier();
|
||||
#if BLOCK_SIZE == 1
|
||||
}
|
||||
#endif
|
||||
|
||||
#if BLOCK_SIZE == 1
|
||||
if (KV_bounds_check || d * 16 + 16 > HSK)
|
||||
#endif
|
||||
{
|
||||
if (stage_k) {
|
||||
uint coord = (gl_SubgroupID * MatBc) * kvsh_stride;
|
||||
coopMatLoad(KMat, kvsh, coord, kvsh_stride, gl_CooperativeMatrixLayoutRowMajor);
|
||||
}
|
||||
#if BLOCK_SIZE == 1
|
||||
else {
|
||||
} else {
|
||||
const uint coord = k_offset / 4 + (j * Bc + gl_SubgroupID * MatBc) * k_stride / 4 + d * 16 / 4;
|
||||
coopMatLoad(KMat, data_kv4, coord, k_stride / 4, gl_CooperativeMatrixLayoutRowMajor);
|
||||
}
|
||||
#endif
|
||||
} else {
|
||||
uint coord = (gl_SubgroupID * MatBc) * kvsh_stride + d * 16 / 4;
|
||||
coopMatLoad(KMat, kvsh, coord, kvsh_stride, gl_CooperativeMatrixLayoutRowMajor);
|
||||
@@ -397,14 +387,14 @@ void main() {
|
||||
if (idx + gl_WorkGroupSize.x <= Bc * HSV_pad / 4 || c < Bc) {
|
||||
f16vec4 V_Tf = f16vec4(0);
|
||||
if ((!KV_bounds_check || j * Bc + c < KV) && (HSV == HSV_pad || d < HSV / 4)) {
|
||||
#if BLOCK_SIZE > 1
|
||||
uint coord = (j * Bc + c) * v_stride * BLOCK_SIZE + 4 * d;
|
||||
uint ib = coord / BLOCK_SIZE;
|
||||
uint iqs = (coord % BLOCK_SIZE);
|
||||
V_Tf = dequantize4(ib, iqs, v_offset, BINDING_IDX_V);
|
||||
#else
|
||||
V_Tf = f16vec4(data_vv4[v_offset / 4 + (j * Bc + c) * v_stride / 4 + d]);
|
||||
#endif
|
||||
if (USE_DECODE_V) {
|
||||
uint coord = (j * Bc + c) * v_stride * BLOCK_SIZE_V + 4 * d;
|
||||
uint ib = coord / BLOCK_SIZE_V;
|
||||
uint iqs = (coord % BLOCK_SIZE_V);
|
||||
V_Tf = dequantize4(ib, iqs, v_offset, BINDING_IDX_V);
|
||||
} else {
|
||||
V_Tf = f16vec4(data_vv4[v_offset / 4 + (j * Bc + c) * v_stride / 4 + d]);
|
||||
}
|
||||
}
|
||||
|
||||
kvsh[c * kvsh_stride + d] = V_Tf;
|
||||
@@ -431,36 +421,33 @@ void main() {
|
||||
// staged through a Bc * MatBr size staging buffer.
|
||||
// If V is not type f16, then it is always staged for dequantization.
|
||||
if (SHMEM_STAGING == 0) {
|
||||
#if BLOCK_SIZE == 1
|
||||
// For f16, only preload if not aligned
|
||||
if (KV_bounds_check) {
|
||||
#endif
|
||||
[[unroll]] for (uint32_t i = 0; i < v_loads_per_thread; ++i) {
|
||||
const uint idx = i * gl_WorkGroupSize.x + tid;
|
||||
const uint row = idx / v_cols;
|
||||
const uint col = idx % v_cols;
|
||||
// For quants we always preload via kvsh. For f16 we only preload when
|
||||
// alignment / bounds force it (otherwise we coopMatLoad direct from data_vv4).
|
||||
const bool stage_v = USE_DECODE_V || KV_bounds_check;
|
||||
if (stage_v) {
|
||||
[[unroll]] for (uint32_t i = 0; i < v_loads_per_thread; ++i) {
|
||||
const uint idx = i * gl_WorkGroupSize.x + tid;
|
||||
const uint row = idx / v_cols;
|
||||
const uint col = idx % v_cols;
|
||||
|
||||
const uint v_row = j * Bc + row;
|
||||
const uint v_col = hsv_tile * MatBc * row_split + col * 4;
|
||||
const uint v_row = j * Bc + row;
|
||||
const uint v_col = hsv_tile * MatBc * row_split + col * 4;
|
||||
|
||||
const uint coord = v_row * v_stride * BLOCK_SIZE + v_col;
|
||||
const uint ib = coord / BLOCK_SIZE;
|
||||
const uint iqs = coord % BLOCK_SIZE;
|
||||
const uint coord = v_row * v_stride * BLOCK_SIZE_V + v_col;
|
||||
const uint ib = coord / BLOCK_SIZE_V;
|
||||
const uint iqs = coord % BLOCK_SIZE_V;
|
||||
|
||||
if (!KV_bounds_check || (v_row < KV && v_col < HSV)) {
|
||||
#if BLOCK_SIZE > 1
|
||||
kvsh[row * vsh_stride + col] = dequantize4(ib, iqs, v_offset, BINDING_IDX_V);
|
||||
#else
|
||||
kvsh[row * vsh_stride + col] = data_vv4[(v_offset + v_row * v_stride + v_col) / 4];
|
||||
#endif
|
||||
} else {
|
||||
kvsh[row * vsh_stride + col] = f16vec4(0.0f);
|
||||
if (!KV_bounds_check || (v_row < KV && v_col < HSV)) {
|
||||
if (USE_DECODE_V) {
|
||||
kvsh[row * vsh_stride + col] = dequantize4(ib, iqs, v_offset, BINDING_IDX_V);
|
||||
} else {
|
||||
kvsh[row * vsh_stride + col] = data_vv4[(v_offset + v_row * v_stride + v_col) / 4];
|
||||
}
|
||||
} else {
|
||||
kvsh[row * vsh_stride + col] = f16vec4(0.0f);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
#if BLOCK_SIZE == 1
|
||||
}
|
||||
#endif
|
||||
}
|
||||
barrier();
|
||||
|
||||
@@ -471,15 +458,12 @@ void main() {
|
||||
coopMatLoad(KMat, Psh, bc_chunk * MatBc * psh_stride, psh_stride, gl_CooperativeMatrixLayoutColumnMajor);
|
||||
|
||||
if (SHMEM_STAGING == 0) {
|
||||
#if BLOCK_SIZE == 1
|
||||
if (!KV_bounds_check) {
|
||||
if (!USE_DECODE_V && !KV_bounds_check) {
|
||||
// F16 values can be loaded directly from global memory
|
||||
const uint v_tile_row = j * Bc + bc_chunk * MatBc;
|
||||
const uint v_tile_offset = v_offset / 4 + v_tile_row * v_stride / 4 + hsv_offset / 4;
|
||||
coopMatLoad(QMat, data_vv4, v_tile_offset, v_stride / 4, gl_CooperativeMatrixLayoutRowMajor);
|
||||
} else
|
||||
#endif
|
||||
{
|
||||
} else {
|
||||
const uint v_tile_offset = bc_chunk * MatBr * v_cols + gl_SubgroupID * (MatBc / 4);
|
||||
coopMatLoad(QMat, kvsh, v_tile_offset, vsh_stride, gl_CooperativeMatrixLayoutRowMajor);
|
||||
}
|
||||
|
||||
@@ -28,43 +28,28 @@ layout(buffer_reference, std430, buffer_reference_align = 1) buffer decodeBufFA_
|
||||
uint8_t raw[FaBlockBytesV];
|
||||
};
|
||||
|
||||
// Returns how many matrix elements one buffer block holds for numeric type id `ty`.
// The case labels are raw type ids; the trailing comments on the first two cases
// name the ggml types they correspond to. Ids without a case yield 1.
uint fa_block_elems(uint ty) {
|
||||
switch (ty) {
|
||||
case 0u: return 4u; // GGML_TYPE_F32: vec4 block (matches decodeBufF32 / dequantFuncF32)
|
||||
case 1u: return 1u; // GGML_TYPE_F16
|
||||
case 2u: return uint(QUANT_K_Q4_0);
|
||||
case 3u: return uint(QUANT_K_Q4_1);
|
||||
case 6u: return uint(QUANT_K_Q5_0);
|
||||
case 7u: return uint(QUANT_K_Q5_1);
|
||||
case 8u: return uint(QUANT_K_Q8_0);
|
||||
case 41u: return uint(QUANT_K_Q1_0);
|
||||
default:
|
||||
return 1u;
|
||||
}
|
||||
}
|
||||
|
||||
// Decode callback for K: switches on FaTypeK, reinterprets bl_in as the matching
// per-type decodeBuf* reference and forwards to that type's dequantFunc*.
//   bl_in        : reference to the raw K block.
//   blockCoords  : passed through unchanged to the per-type function.
//   coordInBlock : passed through unchanged to the per-type function.
// Returns float16_t(0) when FaTypeK has no case here (F16 has none).
float16_t faDecodeK(const decodeBufFA_K bl_in, const uint blockCoords[2], const uint coordInBlock[2]) {
|
||||
switch (FaTypeK) {
|
||||
case 0u: return dequantFuncF32(decodeBufF32(bl_in), blockCoords, coordInBlock);
|
||||
case 2u: return dequantFuncQ4_0(decodeBufQ4_0(bl_in), blockCoords, coordInBlock);
|
||||
case 3u: return dequantFuncQ4_1(decodeBufQ4_1(bl_in), blockCoords, coordInBlock);
|
||||
case 6u: return dequantFuncQ5_0(decodeBufQ5_0(bl_in), blockCoords, coordInBlock);
|
||||
case 7u: return dequantFuncQ5_1(decodeBufQ5_1(bl_in), blockCoords, coordInBlock);
|
||||
case 8u: return dequantFuncQ8_0(decodeBufQ8_0(bl_in), blockCoords, coordInBlock);
|
||||
case 41u: return dequantFuncQ1_0(decodeBufQ1_0(bl_in), blockCoords, coordInBlock);
|
||||
case FA_TYPE_F32: return dequantFuncF32 (decodeBufF32 (bl_in), blockCoords, coordInBlock);
|
||||
case FA_TYPE_Q4_0: return dequantFuncQ4_0(decodeBufQ4_0(bl_in), blockCoords, coordInBlock);
|
||||
case FA_TYPE_Q4_1: return dequantFuncQ4_1(decodeBufQ4_1(bl_in), blockCoords, coordInBlock);
|
||||
case FA_TYPE_Q5_0: return dequantFuncQ5_0(decodeBufQ5_0(bl_in), blockCoords, coordInBlock);
|
||||
case FA_TYPE_Q5_1: return dequantFuncQ5_1(decodeBufQ5_1(bl_in), blockCoords, coordInBlock);
|
||||
case FA_TYPE_Q8_0: return dequantFuncQ8_0(decodeBufQ8_0(bl_in), blockCoords, coordInBlock);
|
||||
case FA_TYPE_Q1_0: return dequantFuncQ1_0(decodeBufQ1_0(bl_in), blockCoords, coordInBlock);
|
||||
default: return float16_t(0);
|
||||
}
|
||||
}
|
||||
|
||||
// Decode callback for V: switches on FaTypeV, reinterprets bl_in as the matching
// per-type decodeBuf* reference and forwards to that type's dequantFunc*.
//   bl_in        : reference to the raw V block.
//   blockCoords  : passed through unchanged to the per-type function.
//   coordInBlock : passed through unchanged to the per-type function.
// Returns float16_t(0) when FaTypeV has no case here (F16 has none).
float16_t faDecodeV(const decodeBufFA_V bl_in, const uint blockCoords[2], const uint coordInBlock[2]) {
|
||||
switch (FaTypeV) {
|
||||
case 0u: return dequantFuncF32(decodeBufF32(bl_in), blockCoords, coordInBlock);
|
||||
case 2u: return dequantFuncQ4_0(decodeBufQ4_0(bl_in), blockCoords, coordInBlock);
|
||||
case 3u: return dequantFuncQ4_1(decodeBufQ4_1(bl_in), blockCoords, coordInBlock);
|
||||
case 6u: return dequantFuncQ5_0(decodeBufQ5_0(bl_in), blockCoords, coordInBlock);
|
||||
case 7u: return dequantFuncQ5_1(decodeBufQ5_1(bl_in), blockCoords, coordInBlock);
|
||||
case 8u: return dequantFuncQ8_0(decodeBufQ8_0(bl_in), blockCoords, coordInBlock);
|
||||
case 41u: return dequantFuncQ1_0(decodeBufQ1_0(bl_in), blockCoords, coordInBlock);
|
||||
case FA_TYPE_F32: return dequantFuncF32 (decodeBufF32 (bl_in), blockCoords, coordInBlock);
|
||||
case FA_TYPE_Q4_0: return dequantFuncQ4_0(decodeBufQ4_0(bl_in), blockCoords, coordInBlock);
|
||||
case FA_TYPE_Q4_1: return dequantFuncQ4_1(decodeBufQ4_1(bl_in), blockCoords, coordInBlock);
|
||||
case FA_TYPE_Q5_0: return dequantFuncQ5_0(decodeBufQ5_0(bl_in), blockCoords, coordInBlock);
|
||||
case FA_TYPE_Q5_1: return dequantFuncQ5_1(decodeBufQ5_1(bl_in), blockCoords, coordInBlock);
|
||||
case FA_TYPE_Q8_0: return dequantFuncQ8_0(decodeBufQ8_0(bl_in), blockCoords, coordInBlock);
|
||||
case FA_TYPE_Q1_0: return dequantFuncQ1_0(decodeBufQ1_0(bl_in), blockCoords, coordInBlock);
|
||||
default: return float16_t(0);
|
||||
}
|
||||
}
|
||||
|
||||
123
ggml/src/ggml-vulkan/vulkan-shaders/flash_attn_dequant.glsl
Normal file
123
ggml/src/ggml-vulkan/vulkan-shaders/flash_attn_dequant.glsl
Normal file
@@ -0,0 +1,123 @@
|
||||
// Asymmetric K/V flash attention: aliased SSBO views of bindings 1 (K) and 2 (V)
|
||||
// covering every supported FA element type, plus an uber dequantize4() that
|
||||
// switches on FaTypeK / FaTypeV. After spec-constant specialization the driver
|
||||
// folds away every path except the one matching the K/V type for this pipeline.
|
||||
//
|
||||
// Included by flash_attn.comp and flash_attn_cm1.comp. Not included by
|
||||
// flash_attn_cm2.comp, which has its own buffer_reference-based decode path.
|
||||
//
|
||||
// We use macros (rather than per-quant decode functions taking a struct) on
|
||||
// purpose: the FA shaders don't enable GL_EXT_shader_explicit_arithmetic_types_float16
|
||||
// when FLOAT16 isn't defined, which makes float16-containing struct values
|
||||
// illegal to return from / pass to functions. Macros expand inline where the
|
||||
// float16 stays in storage and is converted to FLOAT_TYPE at use.
|
||||
|
||||
// F32 is fed as a vec4 "block" (4 floats), matching what dequant_funcs_cm2.glsl
|
||||
// does for F32 in the cm2 shader. FaBlockBytesK/V == 16 for F32.
|
||||
layout (binding = 1) readonly buffer K_PACKED_F32 { vec4 data[]; } k_packed_f32;
|
||||
layout (binding = 2) readonly buffer V_PACKED_F32 { vec4 data[]; } v_packed_f32;
|
||||
|
||||
layout (binding = 1) readonly buffer K_PACKED_Q4_0 { block_q4_0_packed16 data[]; } k_packed_q4_0;
|
||||
layout (binding = 2) readonly buffer V_PACKED_Q4_0 { block_q4_0_packed16 data[]; } v_packed_q4_0;
|
||||
layout (binding = 1) readonly buffer K_PACKED_Q4_1 { block_q4_1_packed16 data[]; } k_packed_q4_1;
|
||||
layout (binding = 2) readonly buffer V_PACKED_Q4_1 { block_q4_1_packed16 data[]; } v_packed_q4_1;
|
||||
layout (binding = 1) readonly buffer K_PACKED_Q5_0 { block_q5_0_packed16 data[]; } k_packed_q5_0;
|
||||
layout (binding = 2) readonly buffer V_PACKED_Q5_0 { block_q5_0_packed16 data[]; } v_packed_q5_0;
|
||||
layout (binding = 1) readonly buffer K_PACKED_Q5_1 { block_q5_1_packed16 data[]; } k_packed_q5_1;
|
||||
layout (binding = 2) readonly buffer V_PACKED_Q5_1 { block_q5_1_packed16 data[]; } v_packed_q5_1;
|
||||
layout (binding = 1) readonly buffer K_PACKED_Q8_0 { block_q8_0_packed16 data[]; } k_packed_q8_0;
|
||||
layout (binding = 2) readonly buffer V_PACKED_Q8_0 { block_q8_0_packed16 data[]; } v_packed_q8_0;
|
||||
|
||||
// Q4_1 and Q5_1 packed32 views: aliased to the same memory as the packed16
|
||||
// views, used by the MMQ K-side hot path for fast 4-uint loads.
|
||||
layout (binding = 1) readonly buffer K_PACKED_Q4_1_P32 { block_q4_1_packed32 data[]; } k_packed_q4_1_p32;
|
||||
layout (binding = 1) readonly buffer K_PACKED_Q5_1_P32 { block_q5_1_packed32 data[]; } k_packed_q5_1_p32;
|
||||
|
||||
// Per-quant decode bodies are expanded once for the K view set and once for
|
||||
// the V view set. The macros take the buffer name as a parameter.
|
||||
#define FA_DEQUANT4_F32(BUF) \
|
||||
return FLOAT_TYPEV4(BUF.data[a_offset + ib]);
|
||||
|
||||
#define FA_DEQUANT4_Q4_0(BUF) { \
|
||||
uint vui_lo = uint(BUF.data[a_offset + ib].qs[(iqs & 0xF) / 2 + 0]); \
|
||||
uint vui_hi = uint(BUF.data[a_offset + ib].qs[(iqs & 0xF) / 2 + 1]); \
|
||||
uint shift = (iqs & 0x10) >> 2; \
|
||||
vui_lo >>= shift; \
|
||||
vui_hi >>= shift; \
|
||||
FLOAT_TYPEV4 nibbles = FLOAT_TYPEV4(vui_lo & 0xF, (vui_lo >> 8) & 0xF, \
|
||||
vui_hi & 0xF, (vui_hi >> 8) & 0xF); \
|
||||
return FLOAT_TYPE(BUF.data[a_offset + ib].d) * (nibbles - FLOAT_TYPE(8.0f)); \
|
||||
}
|
||||
|
||||
#define FA_DEQUANT4_Q4_1(BUF) { \
|
||||
uint vui_lo = uint(BUF.data[a_offset + ib].qs[(iqs & 0xF) / 2 + 0]); \
|
||||
uint vui_hi = uint(BUF.data[a_offset + ib].qs[(iqs & 0xF) / 2 + 1]); \
|
||||
uint shift = (iqs & 0x10) >> 2; \
|
||||
vui_lo >>= shift; \
|
||||
vui_hi >>= shift; \
|
||||
FLOAT_TYPEV4 nibbles = FLOAT_TYPEV4(vui_lo & 0xF, (vui_lo >> 8) & 0xF, \
|
||||
vui_hi & 0xF, (vui_hi >> 8) & 0xF); \
|
||||
return FLOAT_TYPE(BUF.data[a_offset + ib].d) * nibbles \
|
||||
+ FLOAT_TYPE(BUF.data[a_offset + ib].m); \
|
||||
}
|
||||
|
||||
#define FA_DEQUANT4_Q5_0(BUF) { \
|
||||
uint vui_lo = uint(BUF.data[a_offset + ib].qs[(iqs & 0xF) / 2 + 0]); \
|
||||
uint vui_hi = uint(BUF.data[a_offset + ib].qs[(iqs & 0xF) / 2 + 1]); \
|
||||
uint shift = (iqs & 0x10) >> 2; \
|
||||
vui_lo >>= shift; \
|
||||
vui_hi >>= shift; \
|
||||
uint qh = uint(BUF.data[a_offset + ib].qh[0]) \
|
||||
| (uint(BUF.data[a_offset + ib].qh[1]) << 16); \
|
||||
FLOAT_TYPEV4 hb = FLOAT_TYPEV4((qh >> iqs) & 1, (qh >> (iqs + 1)) & 1, \
|
||||
(qh >> (iqs + 2)) & 1, (qh >> (iqs + 3)) & 1) \
|
||||
* FLOAT_TYPE(16.0f); \
|
||||
FLOAT_TYPEV4 nibbles = FLOAT_TYPEV4(vui_lo & 0xF, (vui_lo >> 8) & 0xF, \
|
||||
vui_hi & 0xF, (vui_hi >> 8) & 0xF); \
|
||||
return FLOAT_TYPE(BUF.data[a_offset + ib].d) * (nibbles + hb - FLOAT_TYPE(16.0f)); \
|
||||
}
|
||||
|
||||
#define FA_DEQUANT4_Q5_1(BUF) { \
|
||||
uint vui_lo = uint(BUF.data[a_offset + ib].qs[(iqs & 0xF) / 2 + 0]); \
|
||||
uint vui_hi = uint(BUF.data[a_offset + ib].qs[(iqs & 0xF) / 2 + 1]); \
|
||||
uint shift = (iqs & 0x10) >> 2; \
|
||||
vui_lo >>= shift; \
|
||||
vui_hi >>= shift; \
|
||||
uint qh = BUF.data[a_offset + ib].qh; \
|
||||
FLOAT_TYPEV4 hb = FLOAT_TYPEV4((qh >> iqs) & 1, (qh >> (iqs + 1)) & 1, \
|
||||
(qh >> (iqs + 2)) & 1, (qh >> (iqs + 3)) & 1) \
|
||||
* FLOAT_TYPE(16.0f); \
|
||||
FLOAT_TYPEV4 nibbles = FLOAT_TYPEV4(vui_lo & 0xF, (vui_lo >> 8) & 0xF, \
|
||||
vui_hi & 0xF, (vui_hi >> 8) & 0xF); \
|
||||
return FLOAT_TYPE(BUF.data[a_offset + ib].d) * (nibbles + hb) \
|
||||
+ FLOAT_TYPE(BUF.data[a_offset + ib].m); \
|
||||
}
|
||||
|
||||
#define FA_DEQUANT4_Q8_0(BUF) { \
|
||||
const i8vec2 v0 = unpack8(int32_t(BUF.data[a_offset + ib].qs[iqs / 2 ])).xy; \
|
||||
const i8vec2 v1 = unpack8(int32_t(BUF.data[a_offset + ib].qs[iqs / 2 + 1])).xy; \
|
||||
return FLOAT_TYPE(BUF.data[a_offset + ib].d) * FLOAT_TYPEV4(v0.x, v0.y, v1.x, v1.y); \
|
||||
}
|
||||
|
||||
// Type-generic dequantize4: picks the K or V view set from binding_idx, then
// switches on FaTypeK / FaTypeV and expands the matching FA_DEQUANT4_* macro
// against the aliased buffer for that type.
//   ib          : block index, added to a_offset inside the macros.
//   iqs         : element offset inside the block (used by the quant macros).
//   a_offset    : base block offset into the selected buffer.
//   binding_idx : BINDING_IDX_K selects the k_packed_* views; any other value
//                 selects the v_packed_* views.
// Every FA_DEQUANT4_* body ends in a `return`, so no case falls through to the next.
// Types with no case (F16 among them) reach the final return and yield a zero vector.
FLOAT_TYPEV4 dequantize4(uint ib, uint iqs, uint a_offset, uint binding_idx) {
|
||||
if (binding_idx == BINDING_IDX_K) {
|
||||
switch (FaTypeK) {
|
||||
case FA_TYPE_F32: FA_DEQUANT4_F32 (k_packed_f32)
|
||||
case FA_TYPE_Q4_0: FA_DEQUANT4_Q4_0(k_packed_q4_0)
|
||||
case FA_TYPE_Q4_1: FA_DEQUANT4_Q4_1(k_packed_q4_1)
|
||||
case FA_TYPE_Q5_0: FA_DEQUANT4_Q5_0(k_packed_q5_0)
|
||||
case FA_TYPE_Q5_1: FA_DEQUANT4_Q5_1(k_packed_q5_1)
|
||||
case FA_TYPE_Q8_0: FA_DEQUANT4_Q8_0(k_packed_q8_0)
|
||||
}
|
||||
} else {
|
||||
switch (FaTypeV) {
|
||||
case FA_TYPE_F32: FA_DEQUANT4_F32 (v_packed_f32)
|
||||
case FA_TYPE_Q4_0: FA_DEQUANT4_Q4_0(v_packed_q4_0)
|
||||
case FA_TYPE_Q4_1: FA_DEQUANT4_Q4_1(v_packed_q4_1)
|
||||
case FA_TYPE_Q5_0: FA_DEQUANT4_Q5_0(v_packed_q5_0)
|
||||
case FA_TYPE_Q5_1: FA_DEQUANT4_Q5_1(v_packed_q5_1)
|
||||
case FA_TYPE_Q8_0: FA_DEQUANT4_Q8_0(v_packed_q8_0)
|
||||
}
|
||||
}
|
||||
// Reached only when the selected type had no case above.
return FLOAT_TYPEV4(0);
|
||||
}
|
||||
@@ -1,149 +1,203 @@
|
||||
#if defined(DATA_A_Q4_0) || defined(DATA_A_Q4_1)
|
||||
// MMQ K-side helpers, asymmetric form. Each function dispatches on FaTypeK and
|
||||
// reads from the matching aliased K binding declared in flash_attn_dequant.glsl.
|
||||
// Spec-constant specialization folds the unused paths.
|
||||
|
||||
int32_t get_k_qs(uint ib, uint iqs, uint a_offset) {
|
||||
#ifdef DATA_A_Q4_0
|
||||
uint vui = pack32(u16vec2(k_packed.k_data_packed16[a_offset + ib].qs[(iqs & 0xF) / 2 + 0],
|
||||
k_packed.k_data_packed16[a_offset + ib].qs[(iqs & 0xF) / 2 + 1]));
|
||||
#else
|
||||
uint vui = k_packed32.k_data_packed32[a_offset + ib].qs[(iqs & 0xF) / 4];
|
||||
#endif
|
||||
|
||||
uint shift = (iqs & 0x10) >> 2;
|
||||
vui >>= shift;
|
||||
|
||||
return int32_t(vui & 0x0F0F0F0F);
|
||||
switch (FaTypeK) {
|
||||
case FA_TYPE_Q4_0: {
|
||||
uint vui = pack32(u16vec2(k_packed_q4_0.data[a_offset + ib].qs[(iqs & 0xF) / 2 + 0],
|
||||
k_packed_q4_0.data[a_offset + ib].qs[(iqs & 0xF) / 2 + 1]));
|
||||
uint shift = (iqs & 0x10) >> 2;
|
||||
vui >>= shift;
|
||||
return int32_t(vui & 0x0F0F0F0F);
|
||||
}
|
||||
case FA_TYPE_Q4_1: { // uses packed32 alias
|
||||
uint vui = k_packed_q4_1_p32.data[a_offset + ib].qs[(iqs & 0xF) / 4];
|
||||
uint shift = (iqs & 0x10) >> 2;
|
||||
vui >>= shift;
|
||||
return int32_t(vui & 0x0F0F0F0F);
|
||||
}
|
||||
case FA_TYPE_Q5_0: {
|
||||
uint vui = pack32(u16vec2(k_packed_q5_0.data[a_offset + ib].qs[(iqs & 0xF) / 2 + 0],
|
||||
k_packed_q5_0.data[a_offset + ib].qs[(iqs & 0xF) / 2 + 1]));
|
||||
uint qh = pack32(u16vec2(k_packed_q5_0.data[a_offset + ib].qh[0],
|
||||
k_packed_q5_0.data[a_offset + ib].qh[1]));
|
||||
uint shift = (iqs & 0x10) >> 2;
|
||||
vui >>= shift;
|
||||
uint qh_bits = (qh >> iqs) & 0xF;
|
||||
return int32_t(vui & 0x0F0F0F0F) | int32_t((qh_bits * 0x02040810u) & 0x10101010u);
|
||||
}
|
||||
case FA_TYPE_Q5_1: { // qs via packed32, qh via packed16
|
||||
uint vui = k_packed_q5_1_p32.data[a_offset + ib].qs[(iqs & 0xF) / 4];
|
||||
uint qh = k_packed_q5_1.data[a_offset + ib].qh;
|
||||
uint shift = (iqs & 0x10) >> 2;
|
||||
vui >>= shift;
|
||||
uint qh_bits = (qh >> iqs) & 0xF;
|
||||
return int32_t(vui & 0x0F0F0F0F) | int32_t((qh_bits * 0x02040810u) & 0x10101010u);
|
||||
}
|
||||
case FA_TYPE_Q8_0: {
|
||||
return pack32(i16vec2(k_packed_q8_0.data[a_offset + ib].qs[iqs / 2],
|
||||
k_packed_q8_0.data[a_offset + ib].qs[iqs / 2 + 1]));
|
||||
}
|
||||
default: return 0;
|
||||
}
|
||||
}
|
||||
#endif
|
||||
|
||||
#if defined(DATA_A_Q5_0) || defined(DATA_A_Q5_1)
|
||||
int32_t get_k_qs(uint ib, uint iqs, uint a_offset) {
|
||||
#ifdef DATA_A_Q5_0
|
||||
uint vui = pack32(u16vec2(k_packed.k_data_packed16[a_offset + ib].qs[(iqs & 0xF) / 2 + 0],
|
||||
k_packed.k_data_packed16[a_offset + ib].qs[(iqs & 0xF) / 2 + 1]));
|
||||
uint qh = pack32(u16vec2(k_packed.k_data_packed16[a_offset + ib].qh[0],
|
||||
k_packed.k_data_packed16[a_offset + ib].qh[1]));
|
||||
#else
|
||||
uint vui = k_packed32.k_data_packed32[a_offset + ib].qs[(iqs & 0xF) / 4];
|
||||
uint qh = k_packed.k_data_packed16[a_offset + ib].qh;
|
||||
#endif
|
||||
|
||||
uint shift = (iqs & 0x10) >> 2;
|
||||
vui >>= shift;
|
||||
|
||||
uint qh_bits = (qh >> iqs) & 0xF;
|
||||
return int32_t(vui & 0x0F0F0F0F) | int32_t((qh_bits * 0x02040810u) & 0x10101010u);
|
||||
// Per-block scale/min, packed as (d, m). Single-scale types (Q4_0, Q5_0, Q8_0)
|
||||
// return (d, 0) so call sites always see the same shape.
|
||||
// Returns the per-block scale pair for K block (a_offset + ib), chosen by FaTypeK.
//   ib       : block index.
//   a_offset : base block offset into the K buffer.
// Q4_1 and Q5_1 return their stored `dm` pair, read through the packed32 alias.
// Q4_0, Q5_0 and Q8_0 return (d, 0). Any other type returns a zero pair.
FLOAT_TYPEV2 get_k_scale(uint ib, uint a_offset) {
|
||||
switch (FaTypeK) {
|
||||
case FA_TYPE_Q4_0: return FLOAT_TYPEV2(FLOAT_TYPE(k_packed_q4_0.data[a_offset + ib].d), 0.0);
|
||||
case FA_TYPE_Q4_1: return FLOAT_TYPEV2(k_packed_q4_1_p32.data[a_offset + ib].dm);
|
||||
case FA_TYPE_Q5_0: return FLOAT_TYPEV2(FLOAT_TYPE(k_packed_q5_0.data[a_offset + ib].d), 0.0);
|
||||
case FA_TYPE_Q5_1: return FLOAT_TYPEV2(k_packed_q5_1_p32.data[a_offset + ib].dm);
|
||||
case FA_TYPE_Q8_0: return FLOAT_TYPEV2(FLOAT_TYPE(k_packed_q8_0.data[a_offset + ib].d), 0.0);
|
||||
default: return FLOAT_TYPEV2(0);
|
||||
}
|
||||
}
|
||||
#endif
|
||||
|
||||
#if defined(DATA_A_Q8_0)
|
||||
// Q8_0: return one 32-bit word of quants from K block (a_offset + ib), built
// from the two adjacent 16-bit qs words starting at index iqs / 2.
int32_t get_k_qs(uint ib, uint iqs, uint a_offset) {
    const uint blk   = a_offset + ib;
    const uint first = iqs / 2;
    const i16vec2 pair = i16vec2(k_packed.k_data_packed16[blk].qs[first],
                                 k_packed.k_data_packed16[blk].qs[first + 1]);
    return pack32(pair);
}
|
||||
#endif
|
||||
|
||||
#if defined(DATA_A_IQ4_NL)
|
||||
// IQ4_NL: return four dequantized int8 values of K block (a_offset + ib),
// packed into one 32-bit word. The low four bits of iqs pick the 32-bit group
// of nibble-packed quants; bit 4 of iqs picks the high nibbles instead of the
// low ones. Each 4-bit code is translated through kvalues_iq4nl_const.
int32_t get_k_qs(uint ib, uint iqs, uint a_offset) {
    const uint blk   = a_offset + ib;
    const uint first = (iqs & 0xF) / 2;
    const uint raw   = pack32(u16vec2(k_packed.k_data_packed16[blk].qs[first],
                                      k_packed.k_data_packed16[blk].qs[first + 1]));

    // (iqs & 0x10) >> 2 is 4 when bit 4 of iqs is set, otherwise 0.
    const uint nibble_shift = (iqs & 0x10) >> 2;
    const u8vec4 code = unpack8((raw >> nibble_shift) & 0x0F0F0F0F);

    const i8vec4 vals = i8vec4(kvalues_iq4nl_const[code.x],
                               kvalues_iq4nl_const[code.y],
                               kvalues_iq4nl_const[code.z],
                               kvalues_iq4nl_const[code.w]);
    return pack32(vals);
}
|
||||
#endif
|
||||
|
||||
#if QUANT_AUXF == 1
|
||||
// Scale d of K block (a_offset + ib), converted to FLOAT_TYPE.
FLOAT_TYPE get_k_d(uint ib, uint a_offset) {
    const uint blk = a_offset + ib;
    return FLOAT_TYPE(k_packed.k_data_packed16[blk].d);
}
|
||||
#else
|
||||
// dm pair of K block (a_offset + ib), read through the 32-bit packed view and
// converted to FLOAT_TYPEV2.
FLOAT_TYPEV2 get_k_dm(uint ib, uint a_offset) {
    const uint blk = a_offset + ib;
    return FLOAT_TYPEV2(k_packed32.k_data_packed32[blk].dm);
}
|
||||
#endif
|
||||
|
||||
void k_block_to_shmem(const uint buf_ib, const uint global_ib, const uint iqs, const uint a_offset) {
|
||||
#if defined(DATA_A_Q4_0)
|
||||
kblocksh[buf_ib].qs[iqs] = pack32(u16vec2(k_packed.k_data_packed16[a_offset + global_ib].qs[iqs * 2],
|
||||
k_packed.k_data_packed16[a_offset + global_ib].qs[iqs * 2 + 1]));
|
||||
#elif defined(DATA_A_Q4_1)
|
||||
kblocksh[buf_ib].qs[iqs] = k_packed32.k_data_packed32[a_offset + global_ib].qs[iqs];
|
||||
#elif defined(DATA_A_Q5_0)
|
||||
kblocksh[buf_ib].qs[iqs] = pack32(u16vec2(k_packed.k_data_packed16[a_offset + global_ib].qs[iqs * 2],
|
||||
k_packed.k_data_packed16[a_offset + global_ib].qs[iqs * 2 + 1]));
|
||||
if (iqs == 0) {
|
||||
kblocksh[buf_ib].qh = pack32(u16vec2(k_packed.k_data_packed16[a_offset + global_ib].qh[0],
|
||||
k_packed.k_data_packed16[a_offset + global_ib].qh[1]));
|
||||
// kblocksh[].qs is int32_t for the unified MMQ struct; uint sources need
|
||||
// explicit casts. The bit pattern is what we care about here -- the actual
|
||||
// signed/unsigned interpretation happens downstream in the dot product.
|
||||
switch (FaTypeK) {
|
||||
case FA_TYPE_Q4_0: {
|
||||
kblocksh[buf_ib].qs[iqs] = int32_t(pack32(u16vec2(k_packed_q4_0.data[a_offset + global_ib].qs[iqs * 2],
|
||||
k_packed_q4_0.data[a_offset + global_ib].qs[iqs * 2 + 1])));
|
||||
break;
|
||||
}
|
||||
case FA_TYPE_Q4_1: {
|
||||
kblocksh[buf_ib].qs[iqs] = int32_t(k_packed_q4_1_p32.data[a_offset + global_ib].qs[iqs]);
|
||||
break;
|
||||
}
|
||||
case FA_TYPE_Q5_0: {
|
||||
kblocksh[buf_ib].qs[iqs] = int32_t(pack32(u16vec2(k_packed_q5_0.data[a_offset + global_ib].qs[iqs * 2],
|
||||
k_packed_q5_0.data[a_offset + global_ib].qs[iqs * 2 + 1])));
|
||||
if (iqs == 0) {
|
||||
kblocksh[buf_ib].qh = pack32(u16vec2(k_packed_q5_0.data[a_offset + global_ib].qh[0],
|
||||
k_packed_q5_0.data[a_offset + global_ib].qh[1]));
|
||||
}
|
||||
break;
|
||||
}
|
||||
case FA_TYPE_Q5_1: {
|
||||
kblocksh[buf_ib].qs[iqs] = int32_t(k_packed_q5_1_p32.data[a_offset + global_ib].qs[iqs]);
|
||||
if (iqs == 0) {
|
||||
kblocksh[buf_ib].qh = k_packed_q5_1.data[a_offset + global_ib].qh;
|
||||
}
|
||||
break;
|
||||
}
|
||||
case FA_TYPE_Q8_0: {
|
||||
kblocksh[buf_ib].qs[iqs] = pack32(i16vec2(k_packed_q8_0.data[a_offset + global_ib].qs[iqs * 2],
|
||||
k_packed_q8_0.data[a_offset + global_ib].qs[iqs * 2 + 1]));
|
||||
break;
|
||||
}
|
||||
}
|
||||
#elif defined(DATA_A_Q5_1)
|
||||
kblocksh[buf_ib].qs[iqs] = k_packed32.k_data_packed32[a_offset + global_ib].qs[iqs];
|
||||
if (iqs == 0) {
|
||||
kblocksh[buf_ib].qh = k_packed.k_data_packed16[a_offset + global_ib].qh;
|
||||
}
|
||||
#elif defined(DATA_A_Q8_0)
|
||||
kblocksh[buf_ib].qs[iqs] = pack32(i16vec2(k_packed.k_data_packed16[a_offset + global_ib].qs[iqs * 2],
|
||||
k_packed.k_data_packed16[a_offset + global_ib].qs[iqs * 2 + 1]));
|
||||
#elif defined(DATA_A_IQ4_NL)
|
||||
const uint qs = pack32(u16vec2(k_packed.k_data_packed16[a_offset + global_ib].qs[iqs * 2],
|
||||
k_packed.k_data_packed16[a_offset + global_ib].qs[iqs * 2 + 1]));
|
||||
const u8vec4 i_a0 = unpack8( qs & 0x0F0F0F0F);
|
||||
const u8vec4 i_a1 = unpack8((qs >> 4) & 0x0F0F0F0F);
|
||||
kblocksh[buf_ib].qs[iqs ] = pack32(i8vec4(kvalues_iq4nl_const[i_a0.x], kvalues_iq4nl_const[i_a0.y],
|
||||
kvalues_iq4nl_const[i_a0.z], kvalues_iq4nl_const[i_a0.w]));
|
||||
kblocksh[buf_ib].qs[iqs + 4] = pack32(i8vec4(kvalues_iq4nl_const[i_a1.x], kvalues_iq4nl_const[i_a1.y],
|
||||
kvalues_iq4nl_const[i_a1.z], kvalues_iq4nl_const[i_a1.w]));
|
||||
#endif
|
||||
|
||||
if (iqs == 0) {
|
||||
#if QUANT_AUXF == 1
|
||||
kblocksh[buf_ib].dm = FLOAT_TYPE(k_packed.k_data_packed16[a_offset + global_ib].d);
|
||||
#else
|
||||
kblocksh[buf_ib].dm = FLOAT_TYPEV2(k_packed32.k_data_packed32[a_offset + global_ib].dm);
|
||||
#endif
|
||||
// Q4_0/Q5_0/Q8_0 store dm.x = d; Q4_1/Q5_1 store dm = (d, m) pair.
|
||||
switch (FaTypeK) {
|
||||
case FA_TYPE_Q4_0: kblocksh[buf_ib].dm = FLOAT_TYPEV2(FLOAT_TYPE(k_packed_q4_0.data[a_offset + global_ib].d), 0.0); break;
|
||||
case FA_TYPE_Q4_1: kblocksh[buf_ib].dm = FLOAT_TYPEV2(k_packed_q4_1_p32.data[a_offset + global_ib].dm); break;
|
||||
case FA_TYPE_Q5_0: kblocksh[buf_ib].dm = FLOAT_TYPEV2(FLOAT_TYPE(k_packed_q5_0.data[a_offset + global_ib].d), 0.0); break;
|
||||
case FA_TYPE_Q5_1: kblocksh[buf_ib].dm = FLOAT_TYPEV2(k_packed_q5_1_p32.data[a_offset + global_ib].dm); break;
|
||||
case FA_TYPE_Q8_0: kblocksh[buf_ib].dm = FLOAT_TYPEV2(FLOAT_TYPE(k_packed_q8_0.data[a_offset + global_ib].d), 0.0); break;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Result type of the d_per_step == 8 hot path: one full 32-element K block
// expressed as eight int32 words of nibble-packed quants. Producing the eight
// words together is equivalent to eight get_k_qs(ib, d*4, a_offset) calls, but
// avoids repeating the qh load (Q5_*) and the pack32 of the qs words
// (Q4_0/Q5_0) for every nibble quad. iqs is always 0 on this path, because
// hsk4 % 8 == 0 implies block alignment. Q8_0 stays on the generic get_k_qs
// path: its qs layout (i8 pairs) does not have this nibble shape.
//
// The words are wrapped in a struct so that the caller's k_quants array, whose
// size comes from spec constants, does not have to match a fixed [8]
// out-parameter type.
struct fa_k_qs_block8 {
    int32_t qs[8];
};
|
||||
|
||||
// Load all nibble-packed quants of K block (a_offset + ib) in one go.
//
// For each of the four 32-bit qs words w of the block:
//   qs[w]     receives the low nibble of every byte,
//   qs[w + 4] receives the high nibble of every byte.
// For Q5_0 / Q5_1 the extra bit of each quant is taken from qh and merged in
// as bit 4 of the corresponding byte: qh bits [4w, 4w+3] go to qs[w] and qh
// bits [4w+16, 4w+19] go to qs[w + 4].
// A FaTypeK that is not Q4_0/Q4_1/Q5_0/Q5_1 produces all-zero words.
fa_k_qs_block8 get_k_qs_block8(uint ib, uint a_offset) {
    const uint blk = a_offset + ib;

    const bool is_q5_0 = (FaTypeK == FA_TYPE_Q5_0);
    const bool is_q5_1 = (FaTypeK == FA_TYPE_Q5_1);
    const bool has_qh  = is_q5_0 || is_q5_1;

    // High-bit word is fetched once per block; it stays 0 for the 4-bit types.
    uint qh = 0;
    if (is_q5_0) {
        qh = pack32(u16vec2(k_packed_q5_0.data[blk].qh[0],
                            k_packed_q5_0.data[blk].qh[1]));
    } else if (is_q5_1) {
        qh = k_packed_q5_1.data[blk].qh;
    }

    fa_k_qs_block8 res;
    [[unroll]] for (uint32_t w = 0; w < 4; w++) {
        // Raw 32-bit word holding eight 4-bit quants (two per byte).
        uint raw = 0;
        switch (FaTypeK) {
            case FA_TYPE_Q4_0: // 16-bit packed view: join two adjacent words
                raw = pack32(u16vec2(k_packed_q4_0.data[blk].qs[w * 2 + 0],
                                     k_packed_q4_0.data[blk].qs[w * 2 + 1]));
                break;
            case FA_TYPE_Q4_1: // 32-bit packed alias: read the word directly
                raw = k_packed_q4_1_p32.data[blk].qs[w];
                break;
            case FA_TYPE_Q5_0: // 16-bit packed view: join two adjacent words
                raw = pack32(u16vec2(k_packed_q5_0.data[blk].qs[w * 2 + 0],
                                     k_packed_q5_0.data[blk].qs[w * 2 + 1]));
                break;
            case FA_TYPE_Q5_1: // 32-bit packed alias: read the word directly
                raw = k_packed_q5_1_p32.data[blk].qs[w];
                break;
        }

        int32_t lo = int32_t( raw       & 0x0F0F0F0F);
        int32_t hi = int32_t((raw >> 4) & 0x0F0F0F0F);

        if (has_qh) {
            const uint bits_lo = (qh >> (w * 4))      & 0xFu;
            const uint bits_hi = (qh >> (w * 4 + 16)) & 0xFu;
            // Multiplying by 0x02040810 and masking with 0x10101010 moves
            // bit i of the 4-bit group to bit 4 of byte i.
            lo |= int32_t((bits_lo * 0x02040810u) & 0x10101010u);
            hi |= int32_t((bits_hi * 0x02040810u) & 0x10101010u);
        }

        res.qs[w    ] = lo;
        res.qs[w + 4] = hi;
    }
    return res;
}
|
||||
|
||||
int32_t get_k_qs_shmem(const uint buf_ib, const uint pos) {
|
||||
#if defined(DATA_A_Q4_0) || defined(DATA_A_Q4_1)
|
||||
uint sub = pos % 4;
|
||||
uint shift = ((pos % 8) >= 4) ? 4 : 0;
|
||||
return int32_t((kblocksh[buf_ib].qs[sub] >> shift) & 0x0F0F0F0F);
|
||||
#elif defined(DATA_A_Q5_0) || defined(DATA_A_Q5_1)
|
||||
uint sub = pos % 4;
|
||||
uint shift = ((pos % 8) >= 4) ? 4 : 0;
|
||||
int32_t result = int32_t((kblocksh[buf_ib].qs[sub] >> shift) & 0x0F0F0F0F);
|
||||
uint qh_bits = (kblocksh[buf_ib].qh >> (pos * 4)) & 0xF;
|
||||
return result | int32_t((qh_bits * 0x02040810u) & 0x10101010u);
|
||||
#elif defined(DATA_A_Q8_0) || defined(DATA_A_IQ4_NL)
|
||||
return kblocksh[buf_ib].qs[pos];
|
||||
#endif
|
||||
switch (FaTypeK) {
|
||||
case FA_TYPE_Q4_0:
|
||||
case FA_TYPE_Q4_1: {
|
||||
uint sub = pos % 4;
|
||||
uint shift = ((pos % 8) >= 4) ? 4u : 0u;
|
||||
return int32_t((uint(kblocksh[buf_ib].qs[sub]) >> shift) & 0x0F0F0F0Fu);
|
||||
}
|
||||
case FA_TYPE_Q5_0:
|
||||
case FA_TYPE_Q5_1: {
|
||||
uint sub = pos % 4;
|
||||
uint shift = ((pos % 8) >= 4) ? 4u : 0u;
|
||||
int32_t result = int32_t((uint(kblocksh[buf_ib].qs[sub]) >> shift) & 0x0F0F0F0Fu);
|
||||
uint qh_bits = (kblocksh[buf_ib].qh >> (pos * 4u)) & 0xFu;
|
||||
return result | int32_t((qh_bits * 0x02040810u) & 0x10101010u);
|
||||
}
|
||||
case FA_TYPE_Q8_0: {
|
||||
return kblocksh[buf_ib].qs[pos];
|
||||
}
|
||||
default: return 0;
|
||||
}
|
||||
}
|
||||
|
||||
ACC_TYPE k_dot_correction(const uint qib, const ACC_TYPEV2 k_dm) {
|
||||
#if defined(DATA_A_Q4_0)
|
||||
return -ACC_TYPE(8.0) * ACC_TYPE(Qf[qib].ds.y) * k_dm.x;
|
||||
#elif defined(DATA_A_Q5_0)
|
||||
return -ACC_TYPE(16.0) * ACC_TYPE(Qf[qib].ds.y) * k_dm.x;
|
||||
#elif defined(DATA_A_Q4_1) || defined(DATA_A_Q5_1)
|
||||
return ACC_TYPE(Qf[qib].ds.y) * k_dm.y;
|
||||
#else
|
||||
return ACC_TYPE(0.0);
|
||||
#endif
|
||||
switch (FaTypeK) {
|
||||
case FA_TYPE_Q4_0: return -ACC_TYPE(8.0) * ACC_TYPE(Qf[qib].ds.y) * k_dm.x;
|
||||
case FA_TYPE_Q5_0: return -ACC_TYPE(16.0) * ACC_TYPE(Qf[qib].ds.y) * k_dm.x;
|
||||
case FA_TYPE_Q4_1:
|
||||
case FA_TYPE_Q5_1: return ACC_TYPE(Qf[qib].ds.y) * k_dm.y;
|
||||
default: return ACC_TYPE(0.0);
|
||||
}
|
||||
}
|
||||
|
||||
// Clear the shared-memory copy of a K block: zero the qs word(s) addressed by
// iqs and, on the iqs == 0 invocation only, the block's dm value as well.
void k_block_to_shmem_zero(const uint buf_ib, const uint iqs) {
    kblocksh[buf_ib].qs[iqs] = 0;
#if defined(DATA_A_IQ4_NL)
    // IQ4_NL also clears the word four slots further on.
    kblocksh[buf_ib].qs[iqs + 4] = 0;
#endif

    // Only the invocation with iqs == 0 resets dm.
    if (iqs != 0) {
        return;
    }

#if QUANT_AUXF == 1
    kblocksh[buf_ib].dm = FLOAT_TYPE(0.0f);
#else
    kblocksh[buf_ib].dm = FLOAT_TYPEV2(0.0f);
#endif
}
|
||||
|
||||
@@ -1,4 +1,13 @@
|
||||
#if defined(DATA_A_Q4_0)
|
||||
#if defined(FA_MMQ_MIXED)
|
||||
// Mixed-K flash attention MMQ: superset cache that fits Q4_0/Q4_1/Q5_0/Q5_1/Q8_0.
|
||||
// Q4_*/Q5_* only use qs[0..3] and (for Q5_*) qh. Q8_0 uses qs[0..7]. Single-scale
|
||||
// types (Q4_0/Q5_0/Q8_0) leave dm.y unused.
|
||||
struct block_a_cache {
|
||||
int32_t qs[8];
|
||||
uint32_t qh;
|
||||
FLOAT_TYPEV2 dm;
|
||||
};
|
||||
#elif defined(DATA_A_Q4_0)
|
||||
#define QUANT_R_MMQ 2
|
||||
struct block_a_cache {
|
||||
uint32_t qs[16/4];
|
||||
|
||||
@@ -643,42 +643,22 @@ void process_shaders() {
|
||||
|
||||
if (fp16) {
|
||||
#if defined(GGML_VULKAN_COOPMAT2_GLSLC_SUPPORT)
|
||||
string_to_spv("flash_attn_f32_f16_mixed", "flash_attn_cm2.comp",
|
||||
string_to_spv("flash_attn_f32_f16", "flash_attn_cm2.comp",
|
||||
merge_maps(fa_base_dict, {{"Q_TYPE", "float"}, {"D_TYPE", "float"}, {"D_TYPEV4", "vec4"}}), fp16, false, true, f16acc);
|
||||
#endif
|
||||
}
|
||||
|
||||
for (const auto& tname : type_names) {
|
||||
if (tname == "bf16") continue;
|
||||
|
||||
if (fp16) {
|
||||
#if defined(GGML_VULKAN_COOPMAT_GLSLC_SUPPORT)
|
||||
if (tname == "f16") {
|
||||
string_to_spv("flash_attn_f32_f16_" + tname, "flash_attn_cm1.comp",
|
||||
merge_maps(fa_base_dict, {{"Q_TYPE", "float"}, {"D_TYPE", "float"}, {"D_TYPEV4", "vec4"}, {"COOPMAT", "1"}}), fp16, true, false, f16acc);
|
||||
} else if (tname == "q4_0" || tname == "q4_1" || tname == "q5_0" || tname == "q5_1" || tname == "iq4_nl" || tname == "q8_0" || tname == "f32") {
|
||||
std::string data_a_key = "DATA_A_" + to_uppercase(tname);
|
||||
string_to_spv("flash_attn_f32_f16_" + tname, "flash_attn_cm1.comp",
|
||||
merge_maps(fa_base_dict, {{data_a_key, "1"}, {"Q_TYPE", "float"}, {"D_TYPE", "float"}, {"D_TYPEV4", "vec4"}, {"BLOCK_SIZE", "QUANT_K_"+to_uppercase(tname)}, {"COOPMAT", "1"}}), fp16, true, false, f16acc);
|
||||
}
|
||||
string_to_spv("flash_attn_f32_f16", "flash_attn_cm1.comp",
|
||||
merge_maps(fa_base_dict, {{"Q_TYPE", "float"}, {"D_TYPE", "float"}, {"D_TYPEV4", "vec4"}, {"COOPMAT", "1"}}), fp16, true, false, f16acc);
|
||||
#endif
|
||||
}
|
||||
|
||||
if (tname == "f16") {
|
||||
string_to_spv("flash_attn_f32_f16_" + tname, "flash_attn.comp",
|
||||
merge_maps(fa_base_dict, {{"Q_TYPE", "float"}, {"D_TYPE", "float"}, {"D_TYPEV4", "vec4"}}), fp16, false, false, f16acc);
|
||||
} else if (tname == "q4_0" || tname == "q4_1" || tname == "q5_0" || tname == "q5_1" || tname == "iq4_nl" || tname == "q8_0" || tname == "f32") {
|
||||
std::string data_a_key = "DATA_A_" + to_uppercase(tname);
|
||||
string_to_spv("flash_attn_f32_f16_" + tname, "flash_attn.comp",
|
||||
merge_maps(fa_base_dict, {{data_a_key, "1"}, {"Q_TYPE", "float"}, {"D_TYPE", "float"}, {"D_TYPEV4", "vec4"}, {"BLOCK_SIZE", "QUANT_K_"+to_uppercase(tname) }}), fp16, false, false, f16acc);
|
||||
#if defined(GGML_VULKAN_INTEGER_DOT_GLSLC_SUPPORT)
|
||||
if (tname != "f32") {
|
||||
string_to_spv("flash_attn_f32_f16_" + tname, "flash_attn.comp",
|
||||
merge_maps(fa_base_dict, {{data_a_key, "1"}, {"Q_TYPE", "float"}, {"D_TYPE", "float"}, {"D_TYPEV4", "vec4"}, {"BLOCK_SIZE", "QUANT_K_"+to_uppercase(tname) }, {"MMQ", "1"}}), fp16, false, false, f16acc, "_int8");
|
||||
}
|
||||
#endif
|
||||
}
|
||||
}
|
||||
|
||||
string_to_spv("flash_attn_f32_f16", "flash_attn.comp",
|
||||
merge_maps(fa_base_dict, {{"Q_TYPE", "float"}, {"D_TYPE", "float"}, {"D_TYPEV4", "vec4"}}), fp16, false, false, f16acc);
|
||||
#if defined(GGML_VULKAN_INTEGER_DOT_GLSLC_SUPPORT)
|
||||
string_to_spv("flash_attn_f32_f16", "flash_attn.comp",
|
||||
merge_maps(fa_base_dict, {{"Q_TYPE", "float"}, {"D_TYPE", "float"}, {"D_TYPEV4", "vec4"}, {"MMQ", "1"}, {"FA_MMQ_MIXED", "1"}}), fp16, false, false, f16acc, "_int8");
|
||||
#endif
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
@@ -858,6 +858,8 @@ extern "C" {
|
||||
size_t n_token_capacity,
|
||||
size_t * n_token_count_out);
|
||||
|
||||
#define LLAMA_STATE_SEQ_FLAGS_NONE 0
|
||||
|
||||
// for backwards-compat
|
||||
#define LLAMA_STATE_SEQ_FLAGS_SWA_ONLY 1
|
||||
|
||||
|
||||
@@ -2475,11 +2475,29 @@ public:
|
||||
}
|
||||
|
||||
if (need_alloc) {
|
||||
mbuf_cur = std::move(mbuf);
|
||||
if (!mbuf_cur.buf || mbuf_cur.total_size != mbuf.total_size) {
|
||||
mbuf_cur = std::move(mbuf);
|
||||
|
||||
mbuf_cur.buf.reset(ggml_backend_alloc_ctx_tensors_from_buft(mbuf_cur.ctx.get(), buft));
|
||||
mbuf_cur.buf.reset(ggml_backend_alloc_ctx_tensors_from_buft(mbuf_cur.ctx.get(), buft));
|
||||
|
||||
LLAMA_LOG_INFO("%s: allocated '%s' buffer %.3f MiB\n", __func__, ggml_backend_buft_name(buft), mbuf.total_size/1024.0/1024.0);
|
||||
LLAMA_LOG_INFO("%s: allocated '%s' buffer %.3f MiB\n", __func__, ggml_backend_buft_name(buft), mbuf.total_size/1024.0/1024.0);
|
||||
} else {
|
||||
//LLAMA_LOG_INFO("%s: reallocating tensors in '%s' buffer %.3f MiB\n", __func__, ggml_backend_buft_name(buft), mbuf.total_size/1024.0/1024.0);
|
||||
|
||||
// save the old buffer and allocate the new tensors in it
|
||||
auto buf = std::move(mbuf_cur.buf);
|
||||
|
||||
mbuf_cur = std::move(mbuf);
|
||||
|
||||
ggml_tallocr talloc = ggml_tallocr_new(buf.get());
|
||||
|
||||
for (size_t i = 0; i < mbuf_cur.org.size(); ++i) {
|
||||
ggml_backend_view_init(mbuf_cur.org[i]);
|
||||
ggml_tallocr_alloc(&talloc, mbuf_cur.cpy[i]);
|
||||
}
|
||||
|
||||
mbuf_cur.buf = std::move(buf);
|
||||
}
|
||||
}
|
||||
|
||||
for (size_t i = 0; i < mbuf_cur.org.size(); ++i) {
|
||||
@@ -2559,8 +2577,7 @@ public:
|
||||
|
||||
mbuf.org.push_back(ggml_view_1d(mbuf.ctx.get(), rinfo.tensor, n, rinfo.offset));
|
||||
|
||||
auto & view = mbuf.org.back();
|
||||
view->buffer = rinfo.tensor->buffer;
|
||||
ggml_backend_view_init(mbuf.org.back());
|
||||
}
|
||||
|
||||
for (auto & [buft, mbuf] : mbufs_new) {
|
||||
|
||||
@@ -195,11 +195,9 @@
|
||||
| `--spec-draft-n-min N` | minimum number of draft tokens to use for speculative decoding (default: 0)<br/>(env: LLAMA_ARG_SPEC_DRAFT_N_MIN) |
|
||||
| `--spec-draft-p-split, --draft-p-split P` | speculative decoding split probability (default: 0.10)<br/>(env: LLAMA_ARG_SPEC_DRAFT_P_SPLIT) |
|
||||
| `--spec-draft-p-min, --draft-p-min P` | minimum speculative decoding probability (greedy) (default: 0.75)<br/>(env: LLAMA_ARG_SPEC_DRAFT_P_MIN) |
|
||||
| `--spec-draft-ctx-size, -cd, --ctx-size-draft N` | size of the prompt context for the draft model (default: 0, 0 = loaded from model)<br/>(env: LLAMA_ARG_SPEC_DRAFT_CTX_SIZE) |
|
||||
| `--spec-draft-device, -devd, --device-draft <dev1,dev2,..>` | comma-separated list of devices to use for offloading the draft model (none = don't offload)<br/>use --list-devices to see a list of available devices |
|
||||
| `--spec-draft-ngl, -ngld, --gpu-layers-draft, --n-gpu-layers-draft N` | max. number of draft model layers to store in VRAM, either an exact number, 'auto', or 'all' (default: auto)<br/>(env: LLAMA_ARG_N_GPU_LAYERS_DRAFT) |
|
||||
| `--spec-draft-model, -md, --model-draft FNAME` | draft model for speculative decoding (default: unused)<br/>(env: LLAMA_ARG_SPEC_DRAFT_MODEL) |
|
||||
| `--spec-draft-replace, --spec-replace TARGET DRAFT` | translate the string in TARGET into DRAFT if the draft model and main model are not compatible |
|
||||
| `--spec-type [none\|ngram-cache\|ngram-simple\|ngram-map-k\|ngram-map-k4v\|ngram-mod]` | type of speculative decoding to use when no draft model is provided (default: none)<br/><br/>(env: LLAMA_ARG_SPEC_TYPE) |
|
||||
| `--spec-ngram-mod-n-min N` | minimum number of ngram tokens to use for ngram-based speculative decoding (default: 48) |
|
||||
| `--spec-ngram-mod-n-max N` | maximum number of ngram tokens to use for ngram-based speculative decoding (default: 64) |
|
||||
|
||||
@@ -244,11 +244,9 @@ For the full list of features, please refer to [server's changelog](https://gith
|
||||
| `--spec-draft-n-min N` | minimum number of draft tokens to use for speculative decoding (default: 0)<br/>(env: LLAMA_ARG_SPEC_DRAFT_N_MIN) |
|
||||
| `--spec-draft-p-split, --draft-p-split P` | speculative decoding split probability (default: 0.10)<br/>(env: LLAMA_ARG_SPEC_DRAFT_P_SPLIT) |
|
||||
| `--spec-draft-p-min, --draft-p-min P` | minimum speculative decoding probability (greedy) (default: 0.75)<br/>(env: LLAMA_ARG_SPEC_DRAFT_P_MIN) |
|
||||
| `--spec-draft-ctx-size, -cd, --ctx-size-draft N` | size of the prompt context for the draft model (default: 0, 0 = loaded from model)<br/>(env: LLAMA_ARG_SPEC_DRAFT_CTX_SIZE) |
|
||||
| `--spec-draft-device, -devd, --device-draft <dev1,dev2,..>` | comma-separated list of devices to use for offloading the draft model (none = don't offload)<br/>use --list-devices to see a list of available devices |
|
||||
| `--spec-draft-ngl, -ngld, --gpu-layers-draft, --n-gpu-layers-draft N` | max. number of draft model layers to store in VRAM, either an exact number, 'auto', or 'all' (default: auto)<br/>(env: LLAMA_ARG_N_GPU_LAYERS_DRAFT) |
|
||||
| `--spec-draft-model, -md, --model-draft FNAME` | draft model for speculative decoding (default: unused)<br/>(env: LLAMA_ARG_SPEC_DRAFT_MODEL) |
|
||||
| `--spec-draft-replace, --spec-replace TARGET DRAFT` | translate the string in TARGET into DRAFT if the draft model and main model are not compatible |
|
||||
| `--spec-type [none\|ngram-cache\|ngram-simple\|ngram-map-k\|ngram-map-k4v\|ngram-mod]` | type of speculative decoding to use when no draft model is provided (default: none)<br/><br/>(env: LLAMA_ARG_SPEC_TYPE) |
|
||||
| `--spec-ngram-mod-n-min N` | minimum number of ngram tokens to use for ngram-based speculative decoding (default: 48) |
|
||||
| `--spec-ngram-mod-n-max N` | maximum number of ngram tokens to use for ngram-based speculative decoding (default: 64) |
|
||||
@@ -1045,16 +1043,23 @@ If query param `?fail_on_no_slot=1` is set, this endpoint will respond with stat
|
||||
|
||||
This endpoint is only accessible if `--metrics` is set.
|
||||
|
||||
Available metrics:
|
||||
- `llamacpp:prompt_tokens_total`: Number of prompt tokens processed.
|
||||
- `llamacpp:tokens_predicted_total`: Number of generation tokens processed.
|
||||
- `llamacpp:prompt_tokens_seconds`: Average prompt throughput in tokens/s.
|
||||
- `llamacpp:predicted_tokens_seconds`: Average generation throughput in tokens/s.
|
||||
- `llamacpp:kv_cache_usage_ratio`: KV-cache usage. `1` means 100 percent usage.
|
||||
- `llamacpp:kv_cache_tokens`: KV-cache tokens.
|
||||
- `llamacpp:requests_processing`: Number of requests processing.
|
||||
- `llamacpp:requests_deferred`: Number of requests deferred.
|
||||
- `llamacpp:n_tokens_max`: High watermark of the context size observed.
|
||||
In *router mode* the query param `?model={model_id}` has to be set. This endpoint will respond with status code 400 `model name is missing from the request` if not set.
|
||||
|
||||
#### Available metrics
|
||||
|
||||
| Metric | Type | Description |
|
||||
| ------ | ---------------------- | ----------- |
|
||||
| `llamacpp:prompt_tokens_total` | Counter | Number of prompt tokens processed. |
|
||||
| `llamacpp:prompt_seconds_total` | Counter | Total prompt processing time in seconds. |
|
||||
| `llamacpp:prompt_tokens_seconds` | Gauge | Average prompt throughput in tokens/s. |
|
||||
| `llamacpp:tokens_predicted_total` | Counter | Number of generation tokens processed. |
|
||||
| `llamacpp:tokens_predicted_seconds_total` | Counter | Total generation (prediction) processing time in seconds. |
|
||||
| `llamacpp:predicted_tokens_seconds` | Gauge | Average generation throughput in tokens/s. |
|
||||
| `llamacpp:requests_processing` | Gauge | Number of requests processing. |
|
||||
| `llamacpp:requests_deferred` | Gauge | Number of requests deferred. |
|
||||
| `llamacpp:n_tokens_max` | Counter | High watermark of the context size observed. |
|
||||
| `llamacpp:n_decode_total` | Counter | Total number of `llama_decode()` calls. |
|
||||
| `llamacpp:n_busy_slots_per_decode` | Gauge | Average number of busy slots per llama_decode() call. |
|
||||
|
||||
### POST `/slots/{id_slot}?action=save`: Save the prompt cache of the specified slot to a file.
|
||||
|
||||
|
||||
File diff suppressed because it is too large
Load Diff
@@ -76,7 +76,7 @@ json task_params::to_json(bool only_metrics) const {
|
||||
{"reasoning_in_content", chat_parser_params.reasoning_in_content},
|
||||
{"generation_prompt", chat_parser_params.generation_prompt},
|
||||
{"samplers", samplers},
|
||||
{"speculative.type", common_speculative_type_to_str(speculative.type)},
|
||||
{"speculative.types", common_speculative_type_name_str(speculative.types)},
|
||||
{"timings_per_token", timings_per_token},
|
||||
{"post_sampling_probs", post_sampling_probs},
|
||||
{"backend_sampling", sampling.backend_sampling},
|
||||
@@ -133,7 +133,7 @@ json task_params::to_json(bool only_metrics) const {
|
||||
{"reasoning_in_content", chat_parser_params.reasoning_in_content},
|
||||
{"generation_prompt", chat_parser_params.generation_prompt},
|
||||
{"samplers", samplers},
|
||||
{"speculative.type", common_speculative_type_to_str(speculative.type)},
|
||||
{"speculative.types", common_speculative_type_name_str(speculative.types)},
|
||||
{"timings_per_token", timings_per_token},
|
||||
{"post_sampling_probs", post_sampling_probs},
|
||||
{"backend_sampling", sampling.backend_sampling},
|
||||
@@ -296,6 +296,8 @@ task_params server_task::params_from_json_cmpl(
|
||||
|
||||
params.speculative = defaults.speculative;
|
||||
|
||||
// TODO: to keep things simple, we disable speculative parameter adjustments for now
|
||||
#if 0
|
||||
// TODO: for now, be able to adjust only the draft-model based speculative parameters
|
||||
params.speculative.draft.n_min = json_value(data, "speculative.n_min", defaults.speculative.draft.n_min);
|
||||
params.speculative.draft.n_max = json_value(data, "speculative.n_max", defaults.speculative.draft.n_max);
|
||||
@@ -305,7 +307,6 @@ task_params server_task::params_from_json_cmpl(
|
||||
params.speculative.draft.n_min = std::max(params.speculative.draft.n_min, 0);
|
||||
params.speculative.draft.n_max = std::max(params.speculative.draft.n_max, 0);
|
||||
|
||||
#if 0
|
||||
// for debugging and research purposes
|
||||
params.speculative.type = common_speculative_type_from_name(json_value(data, "speculative.type", common_speculative_type_to_str(defaults.speculative.type)));
|
||||
|
||||
@@ -1981,7 +1982,7 @@ size_t server_prompt_cache::n_tokens() const {
|
||||
return res;
|
||||
}
|
||||
|
||||
server_prompt * server_prompt_cache::alloc(const server_prompt & prompt, size_t state_size) {
|
||||
server_prompt * server_prompt_cache::alloc(const server_prompt & prompt, size_t state_size_tgt, size_t state_size_dft) {
|
||||
// first check if the current state is contained fully in the cache
|
||||
for (auto it = states.begin(); it != states.end(); ++it) {
|
||||
const int cur_lcp_len = it->tokens.get_common_prefix(prompt.tokens);
|
||||
@@ -2005,11 +2006,13 @@ server_prompt * server_prompt_cache::alloc(const server_prompt & prompt, size_t
|
||||
}
|
||||
}
|
||||
|
||||
std::vector<uint8_t> state_data;
|
||||
std::vector<uint8_t> state_data_tgt;
|
||||
std::vector<uint8_t> state_data_dft;
|
||||
|
||||
// check if we can allocate enough memory for the new state
|
||||
try {
|
||||
state_data.resize(state_size);
|
||||
state_data_tgt.resize(state_size_tgt);
|
||||
state_data_dft.resize(state_size_dft);
|
||||
} catch (const std::bad_alloc & e) {
|
||||
SRV_ERR("failed to allocate memory for prompt cache state: %s\n", e.what());
|
||||
|
||||
@@ -2022,17 +2025,19 @@ server_prompt * server_prompt_cache::alloc(const server_prompt & prompt, size_t
|
||||
return nullptr;
|
||||
}
|
||||
|
||||
auto & cur = states.emplace_back();
|
||||
cur = {
|
||||
states.push_back({
|
||||
/*.tokens =*/ prompt.tokens.clone(),
|
||||
/*.data =*/ std::move(state_data),
|
||||
/*.data =*/ {
|
||||
/*.main =*/ std::move(state_data_tgt),
|
||||
/*.drft =*/ std::move(state_data_dft),
|
||||
},
|
||||
/*.checkpoints =*/ prompt.checkpoints,
|
||||
};
|
||||
});
|
||||
|
||||
return &cur;
|
||||
return &states.back();
|
||||
}
|
||||
|
||||
bool server_prompt_cache::load(server_prompt & prompt, const server_tokens & tokens_new, llama_context * ctx, int32_t id_slot) {
|
||||
bool server_prompt_cache::load(server_prompt & prompt, const server_tokens & tokens_new, llama_context * ctx_tgt, llama_context * ctx_dft, int32_t id_slot) {
|
||||
const int lcp_best = prompt.tokens.get_common_prefix(tokens_new);
|
||||
|
||||
float f_keep_best = prompt.tokens.size() > 0 ? float(lcp_best) / prompt.tokens.size() : -1.0f; // empty slot: any cache entry wins
|
||||
@@ -2065,16 +2070,39 @@ bool server_prompt_cache::load(server_prompt & prompt, const server_tokens & tok
|
||||
if (it_best != states.end()) {
|
||||
SRV_WRN(" - found better prompt with f_keep = %.3f, sim = %.3f\n", f_keep_best, sim_best);
|
||||
|
||||
const size_t size = it_best->data.size();
|
||||
const size_t n = llama_state_seq_set_data_ext(ctx, it_best->data.data(), size, id_slot, 0);
|
||||
if (n != size) {
|
||||
SRV_WRN("failed to restore state with size %zu\n", size);
|
||||
{
|
||||
auto & data = it_best->data.main;
|
||||
|
||||
return false;
|
||||
const size_t size = data.size();
|
||||
const size_t n = llama_state_seq_set_data_ext(ctx_tgt, data.data(), size, id_slot, 0);
|
||||
if (n != size) {
|
||||
SRV_WRN("failed to restore state with size %zu\n", size);
|
||||
|
||||
return false;
|
||||
}
|
||||
|
||||
data.clear();
|
||||
data.shrink_to_fit();
|
||||
}
|
||||
|
||||
it_best->data.clear();
|
||||
it_best->data.shrink_to_fit();
|
||||
{
|
||||
auto & data = it_best->data.drft;
|
||||
|
||||
if (!data.empty()) {
|
||||
GGML_ASSERT(ctx_dft);
|
||||
|
||||
const size_t size = data.size();
|
||||
const size_t n = llama_state_seq_set_data_ext(ctx_dft, data.data(), size, id_slot, 0);
|
||||
if (n != size) {
|
||||
SRV_WRN("failed to restore state with size %zu\n", size);
|
||||
|
||||
return false;
|
||||
}
|
||||
|
||||
data.clear();
|
||||
data.shrink_to_fit();
|
||||
}
|
||||
}
|
||||
|
||||
prompt = std::move(*it_best);
|
||||
|
||||
|
||||
@@ -565,42 +565,29 @@ struct server_task_result_apply_lora : server_task_result {
|
||||
virtual json to_json() override;
|
||||
};
|
||||
|
||||
struct server_prompt_checkpoint {
|
||||
llama_pos pos_min;
|
||||
llama_pos pos_max;
|
||||
|
||||
int64_t n_tokens;
|
||||
|
||||
std::vector<uint8_t> data;
|
||||
struct server_prompt_data {
|
||||
std::vector<uint8_t> main;
|
||||
std::vector<uint8_t> drft;
|
||||
|
||||
size_t size() const {
|
||||
return data.size();
|
||||
}
|
||||
|
||||
bool empty() const {
|
||||
return data.empty();
|
||||
}
|
||||
|
||||
void clear() {
|
||||
pos_min = 0;
|
||||
pos_max = 0;
|
||||
n_tokens = 0;
|
||||
data.clear();
|
||||
return main.size() + drft.size();
|
||||
}
|
||||
};
|
||||
|
||||
struct server_prompt {
|
||||
server_tokens tokens;
|
||||
|
||||
std::vector<uint8_t> data;
|
||||
server_prompt_data data;
|
||||
|
||||
std::list<server_prompt_checkpoint> checkpoints;
|
||||
std::list<common_prompt_checkpoint> checkpoints;
|
||||
|
||||
size_t size() const {
|
||||
size_t res = data.size();
|
||||
size_t res = 0;
|
||||
|
||||
for (const auto & checkpoint : checkpoints) {
|
||||
res += checkpoint.size();
|
||||
res += data.size();
|
||||
|
||||
for (const auto & ckpt : checkpoints) {
|
||||
res += ckpt.size();
|
||||
}
|
||||
|
||||
return res;
|
||||
@@ -614,7 +601,7 @@ struct server_prompt {
|
||||
return server_prompt {
|
||||
tokens.clone(),
|
||||
data,
|
||||
checkpoints
|
||||
checkpoints,
|
||||
};
|
||||
}
|
||||
};
|
||||
@@ -637,9 +624,9 @@ struct server_prompt_cache {
|
||||
|
||||
size_t n_tokens() const;
|
||||
|
||||
server_prompt * alloc(const server_prompt & prompt, size_t state_size);
|
||||
server_prompt * alloc(const server_prompt & prompt, size_t state_size_main, size_t state_size_drft);
|
||||
|
||||
bool load(server_prompt & prompt, const server_tokens & tokens_new, llama_context * ctx, int32_t id_slot);
|
||||
bool load(server_prompt & prompt, const server_tokens & tokens_new, llama_context * ctx_main, llama_context * ctx_drft, int32_t id_slot);
|
||||
|
||||
void update();
|
||||
};
|
||||
|
||||
@@ -5,7 +5,7 @@ from utils import *
|
||||
|
||||
server = ServerPreset.stories15m_moe()
|
||||
|
||||
MODEL_DRAFT_FILE_URL = "https://huggingface.co/ggml-org/models/resolve/main/tinyllamas/stories15M-q4_0.gguf"
|
||||
MODEL_DRAFT_FILE_URL = "https://huggingface.co/ggml-org/tiny-llamas/resolve/main/stories15M-q4_0.gguf"
|
||||
|
||||
def create_server():
|
||||
global server
|
||||
|
||||
Reference in New Issue
Block a user