Compare commits

...

5 Commits
b9105 ... b9110

Author SHA1 Message Date
willjoha
ef22b3e4ac docs: fix metrics endpoint description in server README (#22879)
* docs: fix metrics endpoint description in server README

Required model query parameter for router mode described.

Removed metrics:
- llamacpp:kv_cache_usage_ratio
- llamacpp:kv_cache_tokens

Added metrics:
- llamacpp:prompt_seconds_total
- llamacpp:tokens_predicted_seconds_total
- llamacpp:n_decode_total
- llamacpp:n_busy_slots_per_decode

* server: fix metrics type for n_busy_slots_per_decode metric
2026-05-11 18:32:26 +02:00
Georgi Gerganov
68e7ea3eab spec : parallel drafting support (#22838)
* spec : refactor

* spec : drop support for incompatible vocabs

* spec : update common_speculative_init()

* cont : pass seq_id

* cont : dedup ctx_seq_rm_type

* server : sketch the ctx_dft decode loop

* server : draft prompt cache and checkpoints

* server : improve ctx names

* server, spec : transition to unified spec context

* cont : sync main and drft contexts

* cont : async drft eval when possible

* cont : handle non-ckpt models

* cont : pass correct n_past for drafting

* cont : process images throught the draft context

* spec : handle draft running out of context

* server : fix mtmd draft processing

* server : fix URL for draft model

* server : add comment

* server : clean-up + dry

* speculative-simple : update

* spec : fix n_past type

* server : fix slot ctx_drft ptr

* tools : update readme

* naming : improve consistency

* spec : refactor for multi-sequence speculative context

* cont : prepare params

* cont : prepare params

* spec : support parallel drafts

* server : support parallel drafting

* llama : reuse device buffers when possible

* server, spec : clean-up

* cont : clean-up

* cont : minor

* spec : reset `drafting` flag at the end

* spec : introduce `common_speculative_process()`

* spec : allow for multiple spec types (chain of speculators)

* replace old type field of type common_speculative_type in the
  common_params_speculative struct with a vector to allow multiple
  types to be specified

* introduce common_get_enabled_speculative_impls(const std::vector<enum common_speculative_type>)
  to figure out which implementations the user has enabled

* introduce common_speculative_type_from_names(const std::vector<std::string> & names)
  to parse the already user provided spec types

* all speculators run sequentially, best one wins (we verify its drafted tokens)

* maximize expected accepted tokens for current round by calculating the
  product between the probability of accepting current token (n_acc_tokens / n_gen_drafts)
  and the draft's length

---------

Co-authored-by: Petros Sideris <petros.sideris@nokia.com>
2026-05-11 19:09:43 +03:00
Kevin Pouget
928b486b0c ggml-virtgpu: Add a GHA build check (#22943)
* [ggml-virtgpu] Add a GHA build check

* Apply suggestions from code review

Co-authored-by: Sigbjørn Skjæret <sigbjorn.skjaeret@scala.com>

---------

Co-authored-by: Sigbjørn Skjæret <sigbjorn.skjaeret@scala.com>
2026-05-11 21:38:22 +08:00
Daniel Bevenius
7dbb0e998a examples : update args speculative-simple README.md [no ci] (#22938)
This commit updates the command line arguments to use the correct names
and values which are now required.

The motivation for this change is that currently running the example
command as is will generate the following errors:
```console
error while handling argument "--color": error: unknown value for --color: '--sampling-seq'

usage:
-co,   --color [on|off|auto]            Colorize output to distinguish prompt and user input from generations
                                        ('on', 'off', or 'auto', default: 'auto')
                                        'auto' enables colors when output is to a terminal

error while handling argument "-fa": error: unknown value for --flash-attn: '--temp'

usage:
-fa,   --flash-attn [on|off|auto]       set Flash Attention use ('on', 'off', or 'auto', default: 'auto')
                                        (env: LLAMA_ARG_FLASH_ATTN)

error while handling argument "--draft-max": the argument has been removed. use --spec-draft-n-max or --spec-ngram-mod-n-max

usage:
--draft, --draft-n, --draft-max N       the argument has been removed. use --spec-draft-n-max or
                                        --spec-ngram-mod-n-max
                                        (env: LLAMA_ARG_DRAFT_MAX)

error while handling argument "--draft-min": the argument has been removed. use --spec-draft-n-min or --spec-ngram-mod-n-min

usage:
--draft-min, --draft-n-min N            the argument has been removed. use --spec-draft-n-min or
                                        --spec-ngram-mod-n-min
                                        (env: LLAMA_ARG_DRAFT_MIN)
```
2026-05-11 14:00:57 +03:00
Jeff Bolz
dd9280a664 vulkan: Support asymmetric FA in scalar/mmq/coopmat1 paths (#22589) 2026-05-11 12:49:03 +02:00
25 changed files with 1992 additions and 1768 deletions

50
.github/workflows/build-virtgpu.yml vendored Normal file
View File

@@ -0,0 +1,50 @@
name: CI (virtgpu)
on:
workflow_dispatch: # allows manual triggering
push:
branches:
- master
paths: [
'.github/workflows/build-virtgpu.yml',
'**/CMakeLists.txt',
'**/.cmake',
'**/*.h',
'**/*.hpp',
'**/*.c',
'**/*.cpp'
]
pull_request:
types: [opened, synchronize, reopened]
paths: [
'.github/workflows/build-virtgpu.yml',
'ggml/src/ggml-virtgpu/**'
]
concurrency:
group: ${{ github.workflow }}-${{ github.head_ref && github.ref || github.run_id }}
cancel-in-progress: true
jobs:
ubuntu-24-virtgpu:
runs-on: ${{ 'ubuntu-24.04-arm' || 'ubuntu-24.04' }}
steps:
- name: Clone
id: checkout
uses: actions/checkout@v6
- name: Dependencies
id: depends
run: |
sudo apt-get update
sudo apt-get install -y build-essential libdrm-dev pkg-config libssl-dev
- name: Build
id: cmake_build
run: |
cmake -B build \
-DGGML_VIRTGPU=ON \
-DGGML_VIRTGPU_BACKEND=ON
cmake --build build --config Release -j $(nproc)

View File

@@ -622,10 +622,6 @@ static bool common_params_parse_ex(int argc, char ** argv, common_params_context
for (auto & seq_breaker : params.sampling.dry_sequence_breakers) {
string_process_escapes(seq_breaker);
}
for (auto & pair : params.speculative.draft.replacements) {
string_process_escapes(pair.first);
string_process_escapes(pair.second);
}
}
if (!params.kv_overrides.empty()) {
@@ -3518,13 +3514,6 @@ common_params_context common_params_parser_init(common_params & params, llama_ex
params.speculative.draft.p_min = std::stof(value);
}
).set_spec().set_examples({LLAMA_EXAMPLE_SPECULATIVE, LLAMA_EXAMPLE_SERVER, LLAMA_EXAMPLE_CLI}).set_env("LLAMA_ARG_SPEC_DRAFT_P_MIN"));
add_opt(common_arg(
{"--spec-draft-ctx-size", "-cd", "--ctx-size-draft"}, "N",
string_format("size of the prompt context for the draft model (default: %d, 0 = loaded from model)", params.speculative.draft.n_ctx),
[](common_params & params, int value) {
params.speculative.draft.n_ctx = value;
}
).set_spec().set_examples({LLAMA_EXAMPLE_SPECULATIVE, LLAMA_EXAMPLE_SERVER, LLAMA_EXAMPLE_CLI}).set_env("LLAMA_ARG_SPEC_DRAFT_CTX_SIZE"));
add_opt(common_arg(
{"--spec-draft-device", "-devd", "--device-draft"}, "<dev1,dev2,..>",
"comma-separated list of devices to use for offloading the draft model (none = don't offload)\n"
@@ -3561,32 +3550,12 @@ common_params_context common_params_parser_init(common_params & params, llama_ex
}
).set_spec().set_examples({LLAMA_EXAMPLE_SPECULATIVE, LLAMA_EXAMPLE_SERVER, LLAMA_EXAMPLE_CLI}).set_env("LLAMA_ARG_SPEC_DRAFT_MODEL"));
add_opt(common_arg(
{"--spec-draft-replace", "--spec-replace"}, "TARGET", "DRAFT",
"translate the string in TARGET into DRAFT if the draft model and main model are not compatible",
[](common_params & params, const std::string & tgt, const std::string & dft) {
params.speculative.draft.replacements.push_back({ tgt, dft });
}
).set_spec().set_examples({LLAMA_EXAMPLE_SPECULATIVE, LLAMA_EXAMPLE_SERVER, LLAMA_EXAMPLE_CLI}));
add_opt(common_arg(
{"--spec-type"}, "[none|ngram-cache|ngram-simple|ngram-map-k|ngram-map-k4v|ngram-mod]",
{"--spec-type"}, common_speculative_all_types_str(),
string_format("type of speculative decoding to use when no draft model is provided (default: %s)\n",
common_speculative_type_to_str(params.speculative.type).c_str()),
common_speculative_type_name_str(params.speculative.types).c_str()),
[](common_params & params, const std::string & value) {
if (value == "none") {
params.speculative.type = COMMON_SPECULATIVE_TYPE_NONE;
} else if (value == "ngram-cache") {
params.speculative.type = COMMON_SPECULATIVE_TYPE_NGRAM_CACHE;
} else if (value == "ngram-simple") {
params.speculative.type = COMMON_SPECULATIVE_TYPE_NGRAM_SIMPLE;
} else if (value == "ngram-map-k") {
params.speculative.type = COMMON_SPECULATIVE_TYPE_NGRAM_MAP_K;
} else if (value == "ngram-map-k4v") {
params.speculative.type = COMMON_SPECULATIVE_TYPE_NGRAM_MAP_K4V;
} else if (value == "ngram-mod") {
params.speculative.type = COMMON_SPECULATIVE_TYPE_NGRAM_MOD;
} else {
throw std::invalid_argument("unknown speculative decoding type without draft model");
}
const auto enabled_types = string_split<std::string>(value, ',');
params.speculative.types = common_speculative_types_from_names(enabled_types);
}
).set_spec().set_examples({LLAMA_EXAMPLE_SPECULATIVE, LLAMA_EXAMPLE_SERVER, LLAMA_EXAMPLE_CLI}).set_env("LLAMA_ARG_SPEC_TYPE"));
add_opt(common_arg(
@@ -4075,7 +4044,7 @@ common_params_context common_params_parser_init(common_params & params, llama_ex
{"--spec-default"},
string_format("enable default speculative decoding config"),
[](common_params & params) {
params.speculative.type = COMMON_SPECULATIVE_TYPE_NGRAM_MOD;
params.speculative.types = { COMMON_SPECULATIVE_TYPE_NGRAM_MOD };
params.speculative.ngram_mod.n_match = 24;
params.speculative.ngram_mod.n_min = 48;
params.speculative.ngram_mod.n_max = 64;

View File

@@ -1422,7 +1422,7 @@ common_context_seq_rm_type common_context_can_seq_rm(llama_context * ctx) {
// try to remove the last tokens
if (!llama_memory_seq_rm(mem, 0, 1, -1)) {
LOG_WRN("%s: the target context does not support partial sequence removal\n", __func__);
LOG_WRN("%s: the context does not support partial sequence removal\n", __func__);
res = COMMON_CONTEXT_SEQ_RM_TYPE_FULL;
goto done;
}
@@ -1960,3 +1960,102 @@ bool common_prompt_batch_decode(
return true;
}
size_t common_prompt_checkpoint::size() const {
return data_tgt.size() + data_dft.size();
}
bool common_prompt_checkpoint::empty() const {
return data_tgt.empty();
}
void common_prompt_checkpoint::clear() {
n_tokens = 0;
pos_min = 0;
pos_max = 0;
data_tgt.clear();
data_dft.clear();
}
void common_prompt_checkpoint::update_pos(
int64_t n_tokens,
llama_pos pos_min,
llama_pos pos_max) {
this->n_tokens = n_tokens;
this->pos_min = pos_min;
this->pos_max = pos_max;
}
void common_prompt_checkpoint::update_tgt(
llama_context * ctx,
llama_seq_id seq_id,
llama_state_seq_flags flags) {
if (ctx == nullptr) {
return;
}
const size_t ckpt_size = llama_state_seq_get_size_ext(ctx, seq_id, flags);
data_tgt.resize(ckpt_size);
const size_t n = llama_state_seq_get_data_ext(ctx, data_tgt.data(), ckpt_size, seq_id, flags);
if (n != ckpt_size) {
GGML_ABORT("checkpoint size mismatch: expected %zu, got %zu\n", ckpt_size, n);
}
}
void common_prompt_checkpoint::update_dft(
llama_context * ctx,
llama_seq_id seq_id,
llama_state_seq_flags flags) {
if (ctx == nullptr) {
return;
}
const size_t ckpt_size = llama_state_seq_get_size_ext(ctx, seq_id, flags);
data_dft.resize(ckpt_size);
const size_t n = llama_state_seq_get_data_ext(ctx, data_dft.data(), ckpt_size, seq_id, flags);
if (n != ckpt_size) {
GGML_ABORT("checkpoint size mismatch: expected %zu, got %zu\n", ckpt_size, n);
}
}
void common_prompt_checkpoint::load_tgt(
llama_context * ctx,
llama_seq_id seq_id,
llama_state_seq_flags flags) const {
if (ctx == nullptr) {
return;
}
if (data_tgt.empty()) {
return;
}
const size_t n = llama_state_seq_set_data_ext(ctx, data_tgt.data(), data_tgt.size(), seq_id, flags);
if (n != data_tgt.size()) {
GGML_ABORT("checkpoint size mismatch: expected %zu, got %zu\n", data_tgt.size(), n);
}
}
void common_prompt_checkpoint::load_dft(
llama_context * ctx,
llama_seq_id seq_id,
llama_state_seq_flags flags) const {
if (ctx == nullptr) {
return;
}
if (data_dft.empty()) {
return;
}
const size_t n = llama_state_seq_set_data_ext(ctx, data_dft.data(), data_dft.size(), seq_id, flags);
if (n != data_dft.size()) {
GGML_ABORT("checkpoint size mismatch: expected %zu, got %zu\n", data_dft.size(), n);
}
}

View File

@@ -295,8 +295,6 @@ struct common_params_model {
std::string name = ""; // in format <user>/<model>[:<tag>] (tag is optional) // NOLINT
};
struct common_ngram_mod;
// draft-model-based speculative decoding parameters
struct common_params_speculative_draft {
int32_t n_max = 16; // maximum number of tokens to draft during speculative decoding
@@ -307,11 +305,9 @@ struct common_params_speculative_draft {
common_params_model mparams;
llama_model * model = nullptr; // a llama_model that can be shared by multiple speculative contexts
llama_context * ctx_tgt = nullptr;
llama_context * ctx_dft = nullptr;
llama_context_params cparams; // these are the parameters for the draft llama_context
int32_t n_ctx = 0; // draft context size
int32_t n_gpu_layers = -1; // number of layers to store in VRAM for the draft model (-1 - use default)
ggml_type cache_type_k = GGML_TYPE_F16; // KV cache data type for the K
@@ -322,7 +318,6 @@ struct common_params_speculative_draft {
std::vector<ggml_backend_dev_t> devices; // devices to use for offloading
std::vector<std::pair<std::string, std::string>> replacements; // main to speculative model replacements
std::vector<llama_model_tensor_buft_override> tensor_buft_overrides;
};
@@ -331,9 +326,6 @@ struct common_params_speculative_ngram_mod {
int32_t n_max = 64;
int32_t n_min = 48;
// shared instance of the ngram container for all speculative decoding contexts
std::shared_ptr<common_ngram_mod> obj;
};
struct common_params_speculative_ngram_map {
@@ -348,8 +340,7 @@ struct common_params_speculative_ngram_cache {
};
struct common_params_speculative {
// TODO: become a vector in order to support "chains of speculators"
common_speculative_type type = COMMON_SPECULATIVE_TYPE_NONE;
std::vector<enum common_speculative_type> types = { COMMON_SPECULATIVE_TYPE_NONE };
common_params_speculative_draft draft;
@@ -1026,3 +1017,47 @@ ggml_opt_dataset_t common_opt_dataset_init(struct llama_context * ctx, const std
// "adamw" or "sgd" (case insensitive)
enum ggml_opt_optimizer_type common_opt_get_optimizer(const char *);
//
// prompt utils
//
struct common_prompt_checkpoint {
int64_t n_tokens;
llama_pos pos_min;
llama_pos pos_max;
std::vector<uint8_t> data_tgt;
std::vector<uint8_t> data_dft;
size_t size() const;
bool empty() const;
void clear();
void update_pos(
int64_t n_tokens,
llama_pos pos_min,
llama_pos pos_max);
void update_tgt(
llama_context * ctx,
llama_seq_id seq_id,
llama_state_seq_flags flags);
void update_dft(
llama_context * ctx,
llama_seq_id seq_id,
llama_state_seq_flags flags);
void load_tgt(
llama_context * ctx,
llama_seq_id seq_id,
llama_state_seq_flags flags) const;
void load_dft(
llama_context * ctx,
llama_seq_id seq_id,
llama_state_seq_flags flags) const;
};

File diff suppressed because it is too large Load Diff

View File

@@ -5,8 +5,14 @@
struct common_speculative;
// comma separated list the provided types
std::string common_speculative_type_name_str(const std::vector<enum common_speculative_type> & types);
// comma separated list of all types
std::string common_speculative_type_name_str();
const char * common_speculative_all_types_str();
// parse user provided types
std::vector<enum common_speculative_type> common_speculative_types_from_names(const std::vector<std::string> & names);
// convert string to type
enum common_speculative_type common_speculative_type_from_name(const std::string & name);
@@ -14,27 +20,44 @@ enum common_speculative_type common_speculative_type_from_name(const std::string
// convert type to string
std::string common_speculative_type_to_str(enum common_speculative_type type);
common_speculative * common_speculative_init(
common_params_speculative & params,
llama_context * ctx_tgt);
common_speculative * common_speculative_init(common_params_speculative & params, uint32_t n_seq);
void common_speculative_free(common_speculative * spec);
struct common_speculative_draft_params {
// this flag is used to chain the drafts through all the available implementations
// after the first successful draft from an implementation, we set it
// to false to prevent further drafts for that sequence
// at the end of the draft() call, all drafting flags will be reset to false
bool drafting = false;
// overrides individual configurations (-1 disabled)
// can be used to constraint the max draft based on the remaining context size
int32_t n_max = -1;
llama_pos n_past;
llama_token id_last;
// TODO: remove in the future by keeping track of the prompt from the _begin() call and the consecutive accept calls
const llama_tokens * prompt;
// the generated draft from the last _draft() call
llama_tokens * result;
};
common_speculative_draft_params & common_speculative_get_draft_params(common_speculative * spec, llama_seq_id seq_id);
// optionally call once at the beginning of a new generation
void common_speculative_begin(common_speculative * spec, const llama_tokens & prompt);
void common_speculative_begin(common_speculative * spec, llama_seq_id seq_id, const llama_tokens & prompt);
// sample up to n_draft tokens and add them to the batch using the draft model
llama_tokens common_speculative_draft(
common_speculative * spec,
const common_params_speculative & params,
const llama_tokens & prompt,
llama_token id_last);
// process the batch and update the internal state of the speculative context
bool common_speculative_process(common_speculative * spec, const llama_batch & batch);
// informs the speculative decoder that n_accepted tokens were accepted by the target model
void common_speculative_accept(common_speculative * spec, uint16_t n_accepted);
// generate drafts for the sequences specified with `common_speculative_get_draft_params`
void common_speculative_draft(common_speculative * spec);
int32_t common_speculative_n_max(const common_speculative * spec, const common_params_speculative & params);
int32_t common_speculative_n_min(const common_speculative * spec, const common_params_speculative & params);
// informs the speculative context that n_accepted tokens were accepted by the target model
void common_speculative_accept(common_speculative * spec, llama_seq_id, uint16_t n_accepted);
// print statistics about the speculative decoding
void common_speculative_print_stats(const common_speculative * spec);

View File

@@ -6,7 +6,7 @@ Demonstration of basic greedy speculative decoding
./bin/llama-speculative-simple \
-m ../models/qwen2.5-32b-coder-instruct/ggml-model-q8_0.gguf \
-md ../models/qwen2.5-1.5b-coder-instruct/ggml-model-q4_0.gguf \
-f test.txt -c 0 -ngl 99 --color \
--sampling-seq k --top-k 1 -fa --temp 0.0 \
-ngld 99 --draft-max 16 --draft-min 5 --draft-p-min 0.9
-f test.txt -c 0 -ngl 99 --color on \
--sampling-seq k --top-k 1 -fa on --temp 0.0 \
-ngld 99 --spec-draft-n-max 16 --spec-draft-n-draft-min 5 --draft-p-min 0.9
```

View File

@@ -13,20 +13,6 @@
#include <vector>
#include <utility>
struct spec_checkpoint {
int64_t n_tokens = 0;
std::vector<uint8_t> data;
size_t size() const {
return data.size();
}
bool empty() const {
return data.empty();
}
};
int main(int argc, char ** argv) {
std::setlocale(LC_NUMERIC, "C");
@@ -43,11 +29,6 @@ int main(int argc, char ** argv) {
return 1;
}
if (params.speculative.draft.mparams.path.empty()) {
LOG_ERR("%s: --model-draft is required\n", __func__);
return 1;
}
// init llama.cpp
llama_backend_init();
llama_numa_init(params.numa);
@@ -62,18 +43,11 @@ int main(int argc, char ** argv) {
model_tgt = llama_init_tgt->model();
ctx_tgt = llama_init_tgt->context();
// check if the context supports partial sequence removal
const auto ctx_seq_rm = common_context_can_seq_rm(ctx_tgt);
const bool use_ckpt = (ctx_seq_rm == COMMON_CONTEXT_SEQ_RM_TYPE_FULL);
if (use_ckpt) {
LOG_INF("speculative decoding will use checkpoints (context does not support partial sequence removal)\n");
}
const llama_vocab * vocab = llama_model_get_vocab(model_tgt);
// load the draft model
llama_model_ptr model_dft;
llama_context_ptr ctx_dft;
// TODO: simplify this logic
{
@@ -81,9 +55,6 @@ int main(int argc, char ** argv) {
auto params_dft = params;
params_dft.n_parallel = 1;
params_dft.n_ctx = params_spec.n_ctx;
params_dft.n_batch = llama_n_ctx_seq(ctx_tgt);
params_dft.devices = params_spec.devices;
params_dft.model = params_spec.mparams;
params_dft.n_gpu_layers = params_spec.n_gpu_layers;
@@ -103,8 +74,19 @@ int main(int argc, char ** argv) {
return 1;
}
params.speculative.draft.model = model_dft.get();
params.speculative.draft.cparams = common_context_params_to_llama(params_dft);
auto cparams = common_context_params_to_llama(params_dft);
ctx_dft.reset(llama_init_from_model(model_dft.get(), cparams));
params.speculative.draft.ctx_tgt = ctx_tgt;
params.speculative.draft.ctx_dft = ctx_dft.get();
}
// check if the context supports partial sequence removal
const bool use_ckpt_tgt = (common_context_can_seq_rm(ctx_tgt) == COMMON_CONTEXT_SEQ_RM_TYPE_FULL);
const bool use_ckpt_dft = (common_context_can_seq_rm(ctx_dft.get()) == COMMON_CONTEXT_SEQ_RM_TYPE_FULL);
if (use_ckpt_tgt) {
LOG_INF("speculative decoding will use checkpoints (context does not support partial sequence removal)\n");
}
// Tokenize the prompt
@@ -136,6 +118,8 @@ int main(int argc, char ** argv) {
// used to determine end of generation
bool has_eos = false;
llama_seq_id seq_id = 0;
// ================================================
// everything until here is standard initialization
// the relevant stuff for speculative decoding starts here
@@ -146,7 +130,8 @@ int main(int argc, char ** argv) {
common_sampler_ptr smpl(common_sampler_init(model_tgt, params.sampling));
// eval the prompt
llama_decode(ctx_tgt, llama_batch_get_one(inp.data(), inp.size() - 1));
llama_decode(ctx_tgt, llama_batch_get_one(inp.data(), inp.size() - 1));
llama_decode(ctx_dft.get(), llama_batch_get_one(inp.data(), inp.size() - 1));
// note: keep the last token separate!
llama_token id_last = inp.back();
@@ -160,16 +145,16 @@ int main(int argc, char ** argv) {
// init the speculator
const auto & params_spec = params.speculative;
struct common_speculative * spec = common_speculative_init(params.speculative, ctx_tgt);
struct common_speculative * spec = common_speculative_init(params.speculative, 1);
common_speculative_begin(spec, prompt_tgt);
common_speculative_begin(spec, seq_id, prompt_tgt);
llama_batch batch_tgt = llama_batch_init(llama_n_batch(ctx_tgt), 0, 1);
size_t n_draft = 0;
llama_tokens draft;
spec_checkpoint spec_ckpt;
common_prompt_checkpoint ckpt;
const auto t_enc_end = ggml_time_us();
@@ -184,40 +169,57 @@ int main(int argc, char ** argv) {
// from a cache or lookup tables.
//
if (draft.empty()) {
ckpt.update_pos(
prompt_tgt.size(),
llama_memory_seq_pos_min(llama_get_memory(ctx_tgt), seq_id),
llama_memory_seq_pos_max(llama_get_memory(ctx_tgt), seq_id));
if (use_ckpt_dft) {
ckpt.update_dft(ctx_dft.get(), seq_id, LLAMA_STATE_SEQ_FLAGS_PARTIAL_ONLY | LLAMA_STATE_SEQ_FLAGS_ON_DEVICE);
}
// generate a new draft
draft = common_speculative_draft(spec, params_spec, prompt_tgt, id_last);
common_speculative_get_draft_params(spec, seq_id) = {
/* .drafting = */ true,
/* .n_max = */ -1,
/* .n_past = */ n_past,
/* .id_last = */ id_last,
/* .prompt = */ &prompt_tgt,
/* .result = */ &draft, // output
};
common_speculative_draft(spec);
// save the original draft size
n_draft = draft.size();
// save a checkpoint of the target context before evaluating the draft
// this allows us to restore the state if partial draft acceptance occurs
if (!draft.empty() && use_ckpt) {
const size_t ckpt_size = llama_state_seq_get_size_ext(ctx_tgt, 0, LLAMA_STATE_SEQ_FLAGS_PARTIAL_ONLY);
spec_ckpt.data.resize(ckpt_size);
if (!draft.empty()) {
if (use_ckpt_tgt) {
ckpt.update_tgt(ctx_tgt, seq_id, LLAMA_STATE_SEQ_FLAGS_PARTIAL_ONLY | LLAMA_STATE_SEQ_FLAGS_ON_DEVICE);
}
}
const size_t n = llama_state_seq_get_data_ext(ctx_tgt, spec_ckpt.data.data(), ckpt_size, 0, LLAMA_STATE_SEQ_FLAGS_PARTIAL_ONLY);
GGML_ASSERT(n == ckpt_size);
{
ckpt.load_dft(ctx_dft.get(), seq_id, LLAMA_STATE_SEQ_FLAGS_PARTIAL_ONLY | LLAMA_STATE_SEQ_FLAGS_ON_DEVICE);
spec_ckpt.n_tokens = (int64_t) prompt_tgt.size();
LOG_DBG("created speculative checkpoint (n_tokens = %" PRId64 ", size = %.3f MiB)\n",
spec_ckpt.n_tokens, (float) spec_ckpt.data.size() / 1024 / 1024);
llama_memory_seq_rm(llama_get_memory(ctx_dft.get()), seq_id, ckpt.pos_max + 1, -1);
}
} else {
// we have a previous (partial) draft to reuse from checkpoint restoration
if (use_ckpt) {
GGML_ASSERT(!spec_ckpt.empty());
if (use_ckpt_tgt) {
GGML_ASSERT(!ckpt.empty());
}
}
// always have a token to evaluate from before - id_last
common_batch_clear(batch_tgt);
common_batch_add (batch_tgt, id_last, n_past++, { 0 }, true);
common_batch_add (batch_tgt, id_last, n_past++, { seq_id }, true);
// evaluate the target model on [id_last, draft0, draft1, ..., draftN-1]
{
for (size_t i = 0; i < draft.size(); ++i) {
common_batch_add(batch_tgt, draft[i], n_past + i, { 0 }, true);
common_batch_add(batch_tgt, draft[i], n_past + i, { seq_id }, true);
}
//LOG_DBG("target batch: %s\n", string_from(ctx_tgt, batch_tgt).c_str());
@@ -225,9 +227,15 @@ int main(int argc, char ** argv) {
llama_decode(ctx_tgt, batch_tgt);
}
// evaluate the same batch with the draft model
{
// TODO: extend to support MTP, Eagle, etc. See server code for reference
llama_decode(ctx_dft.get(), batch_tgt);
}
// only save the sampler sampler state if we use checkpoints
common_sampler_ptr smpl_save;
if (use_ckpt) {
if (use_ckpt_tgt) {
smpl_save.reset(common_sampler_clone(smpl.get()));
}
@@ -247,17 +255,24 @@ int main(int argc, char ** argv) {
// check for partial draft acceptance:
// if the context doesn't support partial sequence removal, restore the checkpoint
// and make the accepted tokens the new partial draft for the next iteration
if (use_ckpt && ids.size() - 1 < draft.size()) {
if (use_ckpt_tgt && ids.size() - 1 < draft.size()) {
LOG_DBG("partial acceptance: %zu < %zu, restoring checkpoint\n", ids.size() - 1, draft.size());
draft = std::move(ids);
const size_t n = llama_state_seq_set_data_ext(ctx_tgt, spec_ckpt.data.data(), spec_ckpt.size(), 0, LLAMA_STATE_SEQ_FLAGS_PARTIAL_ONLY);
GGML_ASSERT(n == spec_ckpt.size());
{
ckpt.load_tgt(ctx_tgt, seq_id, LLAMA_STATE_SEQ_FLAGS_PARTIAL_ONLY | LLAMA_STATE_SEQ_FLAGS_ON_DEVICE);
llama_memory_seq_rm(llama_get_memory(ctx_tgt), 0, spec_ckpt.n_tokens, -1);
llama_memory_seq_rm(llama_get_memory(ctx_tgt), seq_id, ckpt.pos_max + 1, -1);
}
prompt_tgt.resize(spec_ckpt.n_tokens);
{
ckpt.load_dft(ctx_dft.get(), seq_id, LLAMA_STATE_SEQ_FLAGS_PARTIAL_ONLY | LLAMA_STATE_SEQ_FLAGS_ON_DEVICE);
llama_memory_seq_rm(llama_get_memory(ctx_dft.get()), seq_id, ckpt.pos_max + 1, -1);
}
prompt_tgt.resize(ckpt.n_tokens);
smpl = std::move(smpl_save);
n_past = (int) prompt_tgt.size();
@@ -265,7 +280,7 @@ int main(int argc, char ** argv) {
continue;
}
common_speculative_accept(spec, ids.size() - 1);
common_speculative_accept(spec, seq_id, ids.size() - 1);
// full acceptance: consume the draft and commit accepted tokens
n_past += ids.size() - 1;
@@ -305,7 +320,8 @@ int main(int argc, char ** argv) {
{
LOG_DBG("clear kv cache from any extra tokens, n_past = %d\n", n_past);
llama_memory_seq_rm(llama_get_memory(ctx_tgt), 0, n_past, -1);
llama_memory_seq_rm(llama_get_memory(ctx_tgt), seq_id, n_past, -1);
llama_memory_seq_rm(llama_get_memory(ctx_dft.get()), seq_id, n_past, -1);
}
if ((params.n_predict >= 0 && n_predict > params.n_predict) || has_eos) {

View File

@@ -855,7 +855,7 @@ struct vk_device_struct {
vk_pipeline pipeline_conv2d_dw_whcn_f32, pipeline_conv2d_dw_whcn_f16_f32;
vk_pipeline pipeline_conv2d_dw_cwhn_f32, pipeline_conv2d_dw_cwhn_f16_f32;
std::map<vk_fa_pipeline_state, vk_pipeline> pipeline_flash_attn_f32_f16[GGML_TYPE_COUNT];
std::map<vk_fa_pipeline_state, vk_pipeline> pipeline_flash_attn_f32_f16;
std::map<std::pair<uint32_t, uint32_t>, vk_pipeline> pipeline_fa_mask_opt;
@@ -2933,10 +2933,10 @@ struct vk_fa_tuning_params {
}
};
static bool ggml_vk_flash_attn_scalar_shmem_support(const vk_device& device, const vk_fa_tuning_params& params, uint32_t hsk, uint32_t hsv, bool f32acc, ggml_type kv_type);
static bool ggml_vk_flash_attn_scalar_shmem_support(const vk_device& device, const vk_fa_tuning_params& params, uint32_t hsk, uint32_t hsv, bool f32acc, ggml_type k_type, ggml_type v_type);
static bool ggml_vk_flash_attn_coopmat_shmem_support(const vk_device& device, const vk_fa_tuning_params& params, uint32_t hsk, uint32_t hsv, bool f32acc);
static vk_fa_tuning_params get_fa_tuning_params_scalar(const vk_device& device, uint32_t hsk, uint32_t hsv, uint32_t n_rows, uint32_t n_kv, ggml_type kv_type, bool f32acc) {
static vk_fa_tuning_params get_fa_tuning_params_scalar(const vk_device& device, uint32_t hsk, uint32_t hsv, uint32_t n_rows, uint32_t n_kv, ggml_type k_type, ggml_type v_type, bool f32acc) {
vk_fa_tuning_params result{};
result.path = FA_SCALAR;
@@ -2988,7 +2988,7 @@ static vk_fa_tuning_params get_fa_tuning_params_scalar(const vk_device& device,
result.shmem_staging = (device->vendor_id == VK_VENDOR_ID_NVIDIA && hsk < 256 && hsv < 256) ? 1 : 0;
if (!reduce_block_rows && !ggml_vk_flash_attn_scalar_shmem_support(device, result, hsk, hsv, f32acc, kv_type)) {
if (!reduce_block_rows && !ggml_vk_flash_attn_scalar_shmem_support(device, result, hsk, hsv, f32acc, k_type, v_type)) {
result.block_rows /= 2;
}
@@ -3011,10 +3011,11 @@ static vk_fa_tuning_params get_fa_tuning_params_scalar(const vk_device& device,
return result;
}
static vk_fa_tuning_params get_fa_tuning_params_coopmat1(const vk_device& device, uint32_t hsk, uint32_t hsv, uint32_t n_rows, uint32_t n_kv, ggml_type kv_type, bool f32acc) {
static vk_fa_tuning_params get_fa_tuning_params_coopmat1(const vk_device& device, uint32_t hsk, uint32_t hsv, uint32_t n_rows, uint32_t n_kv, ggml_type k_type, ggml_type v_type, bool f32acc) {
GGML_UNUSED(n_rows);
GGML_UNUSED(n_kv);
GGML_UNUSED(kv_type);
GGML_UNUSED(k_type);
GGML_UNUSED(v_type);
GGML_UNUSED(f32acc);
vk_fa_tuning_params result{};
@@ -3070,12 +3071,6 @@ static vk_fa_tuning_params get_fa_tuning_params_coopmat2(const vk_device& device
}
static vk_fa_tuning_params get_fa_tuning_params(const vk_device& device, uint32_t hsk, uint32_t hsv, uint32_t n_rows, uint32_t n_kv, ggml_type k_type, ggml_type v_type, bool f32acc) {
// Mixed K/V is only implemented on the coopmat2 (flash_attn_cm2) path; never use scalar/cm1.
if (k_type != v_type) {
GGML_ASSERT(device->coopmat2);
return get_fa_tuning_params_coopmat2(device, hsk, hsv, n_rows, n_kv, k_type, v_type, f32acc);
}
FaCodePath path = device->coopmat2 ? FA_COOPMAT2 :
device->coopmat1_fa_support ? FA_COOPMAT1 : FA_SCALAR;
@@ -3087,7 +3082,7 @@ static vk_fa_tuning_params get_fa_tuning_params(const vk_device& device, uint32_
if (path == FA_COOPMAT1) {
bool shape_ok = (f32acc && device->coopmat_support_16x16x16_f32acc) ||
(!f32acc && device->coopmat_support_16x16x16_f16acc);
const vk_fa_tuning_params params = get_fa_tuning_params_coopmat1(device, hsk, hsv, n_rows, n_kv, k_type, f32acc);
const vk_fa_tuning_params params = get_fa_tuning_params_coopmat1(device, hsk, hsv, n_rows, n_kv, k_type, v_type, f32acc);
bool shmem_ok = ggml_vk_flash_attn_coopmat_shmem_support(device, params, hsk, hsv, f32acc);
if (!shape_ok || !shmem_ok) {
@@ -3107,9 +3102,9 @@ static vk_fa_tuning_params get_fa_tuning_params(const vk_device& device, uint32_
switch (path) {
case FA_SCALAR:
return get_fa_tuning_params_scalar(device, hsk, hsv, n_rows, n_kv, k_type, f32acc);
return get_fa_tuning_params_scalar(device, hsk, hsv, n_rows, n_kv, k_type, v_type, f32acc);
case FA_COOPMAT1:
return get_fa_tuning_params_coopmat1(device, hsk, hsv, n_rows, n_kv, k_type, f32acc);
return get_fa_tuning_params_coopmat1(device, hsk, hsv, n_rows, n_kv, k_type, v_type, f32acc);
case FA_COOPMAT2:
return get_fa_tuning_params_coopmat2(device, hsk, hsv, n_rows, n_kv, k_type, v_type, f32acc);
default:
@@ -3279,6 +3274,20 @@ static uint32_t get_subgroup_size(const std::string &pipeline_name, const vk_dev
return 0; // If no matching configuration is found
}
// Whether scalar flash attention will use the MMQ path for the given k_type.
static bool ggml_vk_fa_scalar_uses_mmq(const vk_device& device, ggml_type k_type) {
#if defined(GGML_VULKAN_INTEGER_DOT_GLSLC_SUPPORT)
return device->integer_dot_product && device->subgroup_clustered &&
(k_type == GGML_TYPE_Q4_0 || k_type == GGML_TYPE_Q4_1 ||
k_type == GGML_TYPE_Q5_0 || k_type == GGML_TYPE_Q5_1 ||
k_type == GGML_TYPE_Q8_0);
#else
GGML_UNUSED(device);
GGML_UNUSED(k_type);
return false;
#endif
}
static void ggml_vk_load_shaders(vk_device& device) {
VK_LOG_DEBUG("ggml_vk_load_shaders(" << device->name << ")");
@@ -3525,121 +3534,96 @@ static void ggml_vk_load_shaders(vk_device& device) {
align, disable_robustness, require_full_subgroups, required_subgroup_size);
};
#define CREATE_FA(TYPE, NAMELC, FAPATH, SUFFIX) \
for (auto &fa : device->pipeline_flash_attn_f32_f16[TYPE]) { \
FaCodePath path = fa.first.path; \
uint32_t Br = fa.first.Br; \
uint32_t Bc = fa.first.Bc; \
bool aligned = fa.first.aligned; \
bool f32acc = fa.first.f32acc; \
uint32_t fa_sgs = fa.first.subgroup_size; \
bool fa_ds = fa.first.subgroup_size == 0; \
if (path == FAPATH) { \
if (aligned) { \
if (f32acc) { \
ggml_vk_create_pipeline(device, fa.second, "flash_attn_f32_f16_aligned_f32acc" #NAMELC, flash_attn_f32_f16_ ## NAMELC ## SUFFIX ## _len, flash_attn_f32_f16_ ## NAMELC ## SUFFIX ## _data, "main", 7, sizeof(vk_flash_attn_push_constants), {Br, 1, 1}, get_fa_spec_constants(fa.first), Bc, true, (!fa_ds && (FAPATH!=FA_COOPMAT2)), ((!fa_ds && (FAPATH!=FA_COOPMAT2)) ? fa_sgs : 0)); \
} else { \
ggml_vk_create_pipeline(device, fa.second, "flash_attn_f32_f16_aligned_f16acc" #NAMELC, flash_attn_f32_f16_ ## NAMELC ## _f16acc ## SUFFIX ## _len, flash_attn_f32_f16_ ## NAMELC ## _f16acc ## SUFFIX ## _data, "main", 7, sizeof(vk_flash_attn_push_constants), {Br, 1, 1}, get_fa_spec_constants(fa.first), Bc, true, (!fa_ds && (FAPATH!=FA_COOPMAT2)), ((!fa_ds && (FAPATH!=FA_COOPMAT2)) ? fa_sgs : 0)); \
} \
} else { \
if (f32acc) { \
ggml_vk_create_pipeline(device, fa.second, "flash_attn_f32_f16_f32acc" #NAMELC, flash_attn_f32_f16_ ## NAMELC ## SUFFIX ## _len, flash_attn_f32_f16_ ## NAMELC ## SUFFIX ## _data, "main", 7, sizeof(vk_flash_attn_push_constants), {Br, 1, 1}, get_fa_spec_constants(fa.first), 1, true, (!fa_ds && (FAPATH!=FA_COOPMAT2)), ((!fa_ds && (FAPATH!=FA_COOPMAT2)) ? fa_sgs : 0)); \
} else { \
ggml_vk_create_pipeline(device, fa.second, "flash_attn_f32_f16_f16acc" #NAMELC, flash_attn_f32_f16_ ## NAMELC ## _f16acc ## SUFFIX ## _len, flash_attn_f32_f16_ ## NAMELC ## _f16acc ## SUFFIX ## _data, "main", 7, sizeof(vk_flash_attn_push_constants), {Br, 1, 1}, get_fa_spec_constants(fa.first), 1, true, (!fa_ds && (FAPATH!=FA_COOPMAT2)), ((!fa_ds && (FAPATH!=FA_COOPMAT2)) ? fa_sgs : 0)); \
} \
} \
} \
}
// FA scalar has two SPIR-V modules (MMQ vs non-MMQ); FA cm1 has one. K/V
// quant type is selected at runtime via the FaTypeK / FaTypeV spec constants.
if (device->fp16) {
CREATE_FA(GGML_TYPE_F32, f32, FA_SCALAR, )
CREATE_FA(GGML_TYPE_F16, f16, FA_SCALAR, )
for (auto &fa : device->pipeline_flash_attn_f32_f16) {
if (fa.first.path != FA_SCALAR) continue;
const uint32_t Br = fa.first.Br;
const uint32_t Bc = fa.first.Bc;
const bool aligned = fa.first.aligned;
const bool f32acc = fa.first.f32acc;
const uint32_t fa_sgs = fa.first.subgroup_size;
const bool fa_ds = fa.first.subgroup_size == 0;
const bool use_mmq = ggml_vk_fa_scalar_uses_mmq(device, fa.first.k_type);
const void * spv_data = nullptr;
size_t spv_size = 0;
if (use_mmq) {
#if defined(GGML_VULKAN_INTEGER_DOT_GLSLC_SUPPORT)
if (device->integer_dot_product && device->subgroup_clustered) {
CREATE_FA(GGML_TYPE_Q4_0, q4_0, FA_SCALAR, _int8)
CREATE_FA(GGML_TYPE_Q8_0, q8_0, FA_SCALAR, _int8)
CREATE_FA(GGML_TYPE_Q4_1, q4_1, FA_SCALAR, _int8)
CREATE_FA(GGML_TYPE_Q5_0, q5_0, FA_SCALAR, _int8)
CREATE_FA(GGML_TYPE_Q5_1, q5_1, FA_SCALAR, _int8)
CREATE_FA(GGML_TYPE_IQ4_NL, iq4_nl, FA_SCALAR, _int8)
} else
if (device->fp16) {
if (f32acc) { spv_data = flash_attn_f32_f16_int8_data; spv_size = flash_attn_f32_f16_int8_len; }
else { spv_data = flash_attn_f32_f16_f16acc_int8_data; spv_size = flash_attn_f32_f16_f16acc_int8_len; }
} else {
spv_data = flash_attn_f32_f16_fp32_int8_data;
spv_size = flash_attn_f32_f16_fp32_int8_len;
}
#endif
{
CREATE_FA(GGML_TYPE_Q4_0, q4_0, FA_SCALAR, )
CREATE_FA(GGML_TYPE_Q8_0, q8_0, FA_SCALAR, )
CREATE_FA(GGML_TYPE_Q4_1, q4_1, FA_SCALAR, )
CREATE_FA(GGML_TYPE_Q5_0, q5_0, FA_SCALAR, )
CREATE_FA(GGML_TYPE_Q5_1, q5_1, FA_SCALAR, )
CREATE_FA(GGML_TYPE_IQ4_NL, iq4_nl, FA_SCALAR, )
}
} else {
CREATE_FA(GGML_TYPE_F32, f32, FA_SCALAR, _fp32)
CREATE_FA(GGML_TYPE_F16, f16, FA_SCALAR, _fp32)
#if defined(GGML_VULKAN_INTEGER_DOT_GLSLC_SUPPORT)
if (device->integer_dot_product && device->subgroup_clustered) {
CREATE_FA(GGML_TYPE_Q4_0, q4_0, FA_SCALAR, _fp32_int8)
CREATE_FA(GGML_TYPE_Q8_0, q8_0, FA_SCALAR, _fp32_int8)
CREATE_FA(GGML_TYPE_Q4_1, q4_1, FA_SCALAR, _fp32_int8)
CREATE_FA(GGML_TYPE_Q5_0, q5_0, FA_SCALAR, _fp32_int8)
CREATE_FA(GGML_TYPE_Q5_1, q5_1, FA_SCALAR, _fp32_int8)
CREATE_FA(GGML_TYPE_IQ4_NL, iq4_nl, FA_SCALAR, _fp32_int8)
} else
#endif
{
CREATE_FA(GGML_TYPE_Q4_0, q4_0, FA_SCALAR, _fp32)
CREATE_FA(GGML_TYPE_Q8_0, q8_0, FA_SCALAR, _fp32)
CREATE_FA(GGML_TYPE_Q4_1, q4_1, FA_SCALAR, _fp32)
CREATE_FA(GGML_TYPE_Q5_0, q5_0, FA_SCALAR, _fp32)
CREATE_FA(GGML_TYPE_Q5_1, q5_1, FA_SCALAR, _fp32)
CREATE_FA(GGML_TYPE_IQ4_NL, iq4_nl, FA_SCALAR, _fp32)
} else {
if (device->fp16) {
if (f32acc) { spv_data = flash_attn_f32_f16_data; spv_size = flash_attn_f32_f16_len; }
else { spv_data = flash_attn_f32_f16_f16acc_data; spv_size = flash_attn_f32_f16_f16acc_len; }
} else {
spv_data = flash_attn_f32_f16_fp32_data;
spv_size = flash_attn_f32_f16_fp32_len;
}
}
const char *name = aligned ? "flash_attn_f32_f16_aligned" : "flash_attn_f32_f16";
ggml_vk_create_pipeline(device, fa.second, name, spv_size, spv_data, "main", 7,
sizeof(vk_flash_attn_push_constants), {Br, 1, 1},
get_fa_spec_constants(fa.first), aligned ? Bc : 1, true,
!fa_ds, !fa_ds ? fa_sgs : 0);
}
#if defined(VK_KHR_cooperative_matrix) && defined(GGML_VULKAN_COOPMAT_GLSLC_SUPPORT)
if (device->coopmat1_fa_support) {
CREATE_FA(GGML_TYPE_F32, f32, FA_COOPMAT1, _cm1)
CREATE_FA(GGML_TYPE_F16, f16, FA_COOPMAT1, _cm1)
CREATE_FA(GGML_TYPE_Q4_0, q4_0, FA_COOPMAT1, _cm1)
CREATE_FA(GGML_TYPE_Q8_0, q8_0, FA_COOPMAT1, _cm1)
CREATE_FA(GGML_TYPE_Q4_1, q4_1, FA_COOPMAT1, _cm1)
CREATE_FA(GGML_TYPE_Q5_0, q5_0, FA_COOPMAT1, _cm1)
CREATE_FA(GGML_TYPE_Q5_1, q5_1, FA_COOPMAT1, _cm1)
CREATE_FA(GGML_TYPE_IQ4_NL, iq4_nl, FA_COOPMAT1, _cm1)
}
#endif
#if defined(VK_NV_cooperative_matrix2) && defined(GGML_VULKAN_COOPMAT2_GLSLC_SUPPORT)
#define CREATE_FA_CM2_MIXED() \
for (int fa_k_ty = 0; fa_k_ty < (int)GGML_TYPE_COUNT; ++fa_k_ty) { \
for (auto &fa : device->pipeline_flash_attn_f32_f16[fa_k_ty]) { \
FaCodePath path = fa.first.path; \
uint32_t Br = fa.first.Br; \
uint32_t Bc = fa.first.Bc; \
bool aligned = fa.first.aligned; \
bool f32acc = fa.first.f32acc; \
if (path == FA_COOPMAT2) { \
if (aligned) { \
if (f32acc) { \
ggml_vk_create_pipeline(device, fa.second, "flash_attn_f32_f16_mixed_aligned_f32acc_cm2", flash_attn_f32_f16_mixed_cm2_len, flash_attn_f32_f16_mixed_cm2_data, "main", 7, sizeof(vk_flash_attn_push_constants), {Br, 1, 1}, get_fa_spec_constants(fa.first), Bc, true, false, 0); \
} else { \
ggml_vk_create_pipeline(device, fa.second, "flash_attn_f32_f16_mixed_aligned_f16acc_cm2", flash_attn_f32_f16_mixed_f16acc_cm2_len, flash_attn_f32_f16_mixed_f16acc_cm2_data, "main", 7, sizeof(vk_flash_attn_push_constants), {Br, 1, 1}, get_fa_spec_constants(fa.first), Bc, true, false, 0); \
} \
} else { \
if (f32acc) { \
ggml_vk_create_pipeline(device, fa.second, "flash_attn_f32_f16_mixed_f32acc_cm2", flash_attn_f32_f16_mixed_cm2_len, flash_attn_f32_f16_mixed_cm2_data, "main", 7, sizeof(vk_flash_attn_push_constants), {Br, 1, 1}, get_fa_spec_constants(fa.first), 1, true, false, 0); \
} else { \
ggml_vk_create_pipeline(device, fa.second, "flash_attn_f32_f16_mixed_f16acc_cm2", flash_attn_f32_f16_mixed_f16acc_cm2_len, flash_attn_f32_f16_mixed_f16acc_cm2_data, "main", 7, sizeof(vk_flash_attn_push_constants), {Br, 1, 1}, get_fa_spec_constants(fa.first), 1, true, false, 0); \
} \
} \
} \
} \
for (auto &fa : device->pipeline_flash_attn_f32_f16) {
if (fa.first.path != FA_COOPMAT1) continue;
const uint32_t Br = fa.first.Br;
const uint32_t Bc = fa.first.Bc;
const bool aligned = fa.first.aligned;
const bool f32acc = fa.first.f32acc;
const uint32_t fa_sgs = fa.first.subgroup_size;
const bool fa_ds = fa.first.subgroup_size == 0;
const void * spv_data;
size_t spv_size;
if (f32acc) { spv_data = flash_attn_f32_f16_cm1_data; spv_size = flash_attn_f32_f16_cm1_len; }
else { spv_data = flash_attn_f32_f16_f16acc_cm1_data; spv_size = flash_attn_f32_f16_f16acc_cm1_len; }
const char *name = aligned ? "flash_attn_f32_f16_aligned_cm1" : "flash_attn_f32_f16_cm1";
ggml_vk_create_pipeline(device, fa.second, name, spv_size, spv_data, "main", 7,
sizeof(vk_flash_attn_push_constants), {Br, 1, 1},
get_fa_spec_constants(fa.first), aligned ? Bc : 1, true,
!fa_ds, !fa_ds ? fa_sgs : 0);
}
}
#endif
#if defined(VK_NV_cooperative_matrix2) && defined(GGML_VULKAN_COOPMAT2_GLSLC_SUPPORT)
if (device->coopmat2) {
for (auto &fa : device->pipeline_flash_attn_f32_f16) {
if (fa.first.path != FA_COOPMAT2) continue;
const uint32_t Br = fa.first.Br;
const uint32_t Bc = fa.first.Bc;
const bool aligned = fa.first.aligned;
const bool f32acc = fa.first.f32acc;
const void * spv_data;
size_t spv_size;
const char * name;
if (aligned) {
if (f32acc) { spv_data = flash_attn_f32_f16_cm2_data; spv_size = flash_attn_f32_f16_cm2_len; name = "flash_attn_f32_f16_aligned_f32acc_cm2"; }
else { spv_data = flash_attn_f32_f16_f16acc_cm2_data; spv_size = flash_attn_f32_f16_f16acc_cm2_len; name = "flash_attn_f32_f16_aligned_f16acc_cm2"; }
} else {
if (f32acc) { spv_data = flash_attn_f32_f16_cm2_data; spv_size = flash_attn_f32_f16_cm2_len; name = "flash_attn_f32_f16_f32acc_cm2"; }
else { spv_data = flash_attn_f32_f16_f16acc_cm2_data; spv_size = flash_attn_f32_f16_f16acc_cm2_len; name = "flash_attn_f32_f16_f16acc_cm2"; }
}
ggml_vk_create_pipeline(device, fa.second, name, spv_size, spv_data, "main", 7,
sizeof(vk_flash_attn_push_constants), {Br, 1, 1},
get_fa_spec_constants(fa.first), aligned ? Bc : 1, true, false, 0);
}
if (device->coopmat2) {
CREATE_FA_CM2_MIXED();
}
#undef CREATE_FA_CM2_MIXED
#endif
#undef CREATE_FA
const int mul_mat_id_param_count = 5;
@@ -8940,8 +8924,9 @@ static void ggml_vk_mul_mat_id(ggml_backend_vk_context * ctx, vk_context& subctx
}
}
static bool ggml_vk_flash_attn_scalar_shmem_support(const vk_device& device, const vk_fa_tuning_params& params, uint32_t hsk, uint32_t hsv, bool f32acc, ggml_type kv_type) {
static bool ggml_vk_flash_attn_scalar_shmem_support(const vk_device& device, const vk_fa_tuning_params& params, uint32_t hsk, uint32_t hsv, bool f32acc, ggml_type k_type, ggml_type v_type) {
GGML_UNUSED(f32acc);
GGML_UNUSED(v_type);
// Needs to be kept up to date on shader changes
const uint32_t wg_size = params.workgroup_size;
const uint32_t Br = params.block_rows;
@@ -8949,10 +8934,7 @@ static bool ggml_vk_flash_attn_scalar_shmem_support(const vk_device& device, con
const uint32_t float_type_size = device->fp16 ? sizeof(ggml_fp16_t) : sizeof(float);
const bool mmq = device->integer_dot_product && device->subgroup_clustered &&
(kv_type == GGML_TYPE_Q4_0 || kv_type == GGML_TYPE_Q4_1 ||
kv_type == GGML_TYPE_Q5_0 || kv_type == GGML_TYPE_Q5_1 ||
kv_type == GGML_TYPE_Q8_0 || kv_type == GGML_TYPE_IQ4_NL);
const bool mmq = ggml_vk_fa_scalar_uses_mmq(device, k_type);
// tmpsh is overestimated slightly
const uint32_t tmpsh = wg_size * sizeof(float);
@@ -8969,17 +8951,10 @@ static bool ggml_vk_flash_attn_scalar_shmem_support(const vk_device& device, con
// kvsh uses D = HSV (K goes through kblocksh instead)
kvsh = params.shmem_staging ? Bc * (hsv / 4 + 1) * 4 * float_type_size : 4 * float_type_size;
// block_a_cache size depends on quant type
uint32_t block_a_size;
switch (kv_type) {
case GGML_TYPE_Q4_0: block_a_size = 4 * sizeof(uint32_t) + float_type_size; break;
case GGML_TYPE_Q4_1: block_a_size = 4 * sizeof(uint32_t) + 2 * float_type_size; break;
case GGML_TYPE_Q5_0: block_a_size = 4 * sizeof(uint32_t) + sizeof(uint32_t) + float_type_size; break;
case GGML_TYPE_Q5_1: block_a_size = 4 * sizeof(uint32_t) + sizeof(uint32_t) + 2 * float_type_size; break;
case GGML_TYPE_Q8_0:
case GGML_TYPE_IQ4_NL: block_a_size = 8 * sizeof(int32_t) + float_type_size; break;
default: block_a_size = 0; break;
}
// The mixed MMQ shader uses a superset block_a_cache that fits every
// FA-supported quant: int32_t qs[8] + uint32_t qh + FLOAT_TYPEV2 dm.
// Single-scale types leave dm.y unused; non-Q5_* leave qh unused.
const uint32_t block_a_size = 8 * sizeof(int32_t) + sizeof(uint32_t) + 2 * float_type_size;
kblocksh_size = params.shmem_staging ? Bc * (hsk / 32) * block_a_size : block_a_size;
} else {
Qf = Br * (hsk / 4 + 1) * 4 * float_type_size;
@@ -9117,10 +9092,6 @@ static void ggml_vk_flash_attn(ggml_backend_vk_context * ctx, vk_context& subctx
tuning_params = get_fa_tuning_params(ctx->device, HSK, HSV, N, KV, k->type, v->type, f32acc);
if (tuning_params.path != FA_COOPMAT2) {
GGML_ASSERT(k->type == v->type);
}
const uint32_t q_stride = (uint32_t)(nbq1 / ggml_type_size(q->type));
uint32_t k_stride = (uint32_t)(nbk1 / ggml_type_size(k->type));
uint32_t v_stride = (uint32_t)(nbv1 / ggml_type_size(v->type));
@@ -9164,7 +9135,7 @@ static void ggml_vk_flash_attn(ggml_backend_vk_context * ctx, vk_context& subctx
{
std::lock_guard<std::recursive_mutex> guard(ctx->device->mutex);
auto &pipelines = ctx->device->pipeline_flash_attn_f32_f16[k->type];
auto &pipelines = ctx->device->pipeline_flash_attn_f32_f16;
auto it = pipelines.find(fa_pipeline_state);
if (it != pipelines.end()) {
pipeline = it->second;
@@ -15642,10 +15613,6 @@ static bool ggml_backend_vk_device_supports_op(ggml_backend_dev_t dev, const ggm
if (op->src[3] && op->src[3]->type != GGML_TYPE_F16) {
return false;
}
// mismatching K/V type is currently supported for coopmat2 only.
if (op->src[1]->type != op->src[2]->type && !coopmat2) {
return false;
}
auto fa_kv_ok = [coopmat2](ggml_type t) {
switch (t) {
case GGML_TYPE_F32:

View File

@@ -22,6 +22,7 @@
#include "types.glsl"
#include "flash_attn_base.glsl"
#include "flash_attn_dequant.glsl"
const uint32_t HSK_per_thread = HSK / D_split;
const uint32_t HSV_per_thread = HSV / D_split;
@@ -128,18 +129,20 @@ void main() {
Qf[buf_ib].qs[buf_iqs] = pack32(i8vec4(vals));
#if defined(DATA_A_Q8_0) || defined(DATA_A_IQ4_NL)
if (buf_iqs == 0) {
Qf[buf_ib].ds = FLOAT_TYPEV2(qd, 0.0);
}
#else // Q4_0, Q4_1, Q5_0, Q5_1
const FLOAT_TYPE thread_sum = vals.x + vals.y + vals.z + vals.w;
const FLOAT_TYPE sum = subgroupClusteredAdd(thread_sum, 8);
// Q8_0 K only needs (qd, _); the asymmetric Q4_*/Q5_* family also stores
// the row-sum scaled by qd, used in k_dot_correction.
if (FaTypeK == FA_TYPE_Q8_0) {
if (buf_iqs == 0) {
Qf[buf_ib].ds = FLOAT_TYPEV2(qd, 0.0);
}
} else {
const FLOAT_TYPE thread_sum = vals.x + vals.y + vals.z + vals.w;
const FLOAT_TYPE sum = subgroupClusteredAdd(thread_sum, 8);
if (buf_iqs == 0) {
Qf[buf_ib].ds = FLOAT_TYPEV2(qd, sum * qd);
if (buf_iqs == 0) {
Qf[buf_ib].ds = FLOAT_TYPEV2(qd, sum * qd);
}
}
#endif
#endif
}
barrier();
@@ -177,13 +180,9 @@ void main() {
// mo_offset will point to the tile starting at row i*Br and col 0
uint32_t mo_offset = mo_stride * i;
#if BLOCK_SIZE > 1
uint32_t k_offset = (ik2*p.nb12 + ik3*p.nb13) / BLOCK_BYTE_SIZE;
uint32_t v_offset = (iv2*p.nb22 + iv3*p.nb23) / BLOCK_BYTE_SIZE;
#else
uint32_t k_offset = (ik2*p.nb12 + ik3*p.nb13) / 2;
uint32_t v_offset = (iv2*p.nb22 + iv3*p.nb23) / 2;
#endif
// FaBlockBytesK/V == 2 for f16, 16 for f32, ggml block byte size for quants.
uint32_t k_offset = (ik2*p.nb12 + ik3*p.nb13) / FaBlockBytesK;
uint32_t v_offset = (iv2*p.nb22 + iv3*p.nb23) / FaBlockBytesV;
uint32_t m_offset = gqa_iq1*KV;
if (p.nem2 != 1 || p.nem3 != 1) {
m_offset += ((iq3 % p.nem3) * p.nem2 + (iq2 % p.nem2)) * p.nem1 * KV;
@@ -257,21 +256,21 @@ void main() {
if (idx + gl_WorkGroupSize.x <= Bc * HSK / 4 || c < Bc) {
FLOAT_TYPEV4 K_Tf = FLOAT_TYPEV4(0);
if (!KV_bounds_check || j * Bc + c < KV) {
#if BLOCK_SIZE > 1
uint coord = (j * Bc + c) * k_stride * BLOCK_SIZE + 4 * d;
uint ib = coord / BLOCK_SIZE;
uint iqs = (coord % BLOCK_SIZE);
K_Tf = dequantize4(ib, iqs, k_offset, BINDING_IDX_K);
#else
K_Tf = FLOAT_TYPEV4(data_kv4[k_offset / 4 + (j * Bc + c) * k_stride / 4 + d]);
#endif
if (USE_DECODE_K) {
uint coord = (j * Bc + c) * k_stride * BLOCK_SIZE_K + 4 * d;
uint ib = coord / BLOCK_SIZE_K;
uint iqs = (coord % BLOCK_SIZE_K);
K_Tf = dequantize4(ib, iqs, k_offset, BINDING_IDX_K);
} else {
K_Tf = FLOAT_TYPEV4(data_kv4[k_offset / 4 + (j * Bc + c) * k_stride / 4 + d]);
}
}
kvsh[c * kvsh_stride + d] = K_Tf;
}
}
#else // MMQ
const uint ints_per_block = 8 / QUANT_R_MMQ;
const uint ints_per_block = 8u / fa_quant_r_mmq(FaTypeK);
const uint quant_iters = Bc * HSK / 32 * ints_per_block;
[[unroll]] for (uint32_t idx = 0; idx < quant_iters; idx += gl_WorkGroupSize.x) {
const uint32_t iqs = (idx + tid) % ints_per_block;
@@ -310,15 +309,13 @@ void main() {
FLOAT_TYPEV4 K_Tf;
if (SHMEM_STAGING != 0) {
K_Tf = kvsh[(c * cols_per_iter + col_tid) * kvsh_stride + (d * D_split + d_tid)];
} else {
#if BLOCK_SIZE > 1
uint coord = (j * Bc + c * cols_per_iter + col_tid) * k_stride * BLOCK_SIZE + 4 * (d * D_split + d_tid);
uint ib = coord / BLOCK_SIZE;
uint iqs = (coord % BLOCK_SIZE);
} else if (USE_DECODE_K) {
uint coord = (j * Bc + c * cols_per_iter + col_tid) * k_stride * BLOCK_SIZE_K + 4 * (d * D_split + d_tid);
uint ib = coord / BLOCK_SIZE_K;
uint iqs = (coord % BLOCK_SIZE_K);
K_Tf = dequantize4(ib, iqs, k_offset, BINDING_IDX_K);
#else
} else {
K_Tf = FLOAT_TYPEV4(data_kv4[k_offset / 4 + (j * Bc + c * cols_per_iter + col_tid) * k_stride / 4 + d * D_split + d_tid]);
#endif
}
[[unroll]] for (uint32_t r = 0; r < rows_per_thread; ++r) {
Sf[r][c] += dot(ACC_TYPEV4(Q_cache[r]), ACC_TYPEV4(K_Tf));
@@ -335,15 +332,13 @@ void main() {
FLOAT_TYPEV4 K_Tf;
if (SHMEM_STAGING != 0) {
K_Tf = kvsh[(c * cols_per_iter + col_tid) * kvsh_stride + (d * D_split + d_tid)];
} else {
#if BLOCK_SIZE > 1
uint coord = (j * Bc + c * cols_per_iter + col_tid) * k_stride * BLOCK_SIZE + 4 * (d * D_split + d_tid);
uint ib = coord / BLOCK_SIZE;
uint iqs = (coord % BLOCK_SIZE);
} else if (USE_DECODE_K) {
uint coord = (j * Bc + c * cols_per_iter + col_tid) * k_stride * BLOCK_SIZE_K + 4 * (d * D_split + d_tid);
uint ib = coord / BLOCK_SIZE_K;
uint iqs = (coord % BLOCK_SIZE_K);
K_Tf = dequantize4(ib, iqs, k_offset, BINDING_IDX_K);
#else
} else {
K_Tf = FLOAT_TYPEV4(data_kv4[k_offset / 4 + (j * Bc + c * cols_per_iter + col_tid) * k_stride / 4 + d * D_split + d_tid]);
#endif
}
[[unroll]] for (uint32_t r = 0; r < rows_per_thread; ++r) {
Sf[r][c] += dot(ACC_TYPEV4(Qf[tile_row(r) * qf_stride + d * D_split + d_tid]), ACC_TYPEV4(K_Tf));
@@ -366,72 +361,47 @@ void main() {
int32_t k_quants[d_per_step];
ACC_TYPEV2 k_dm;
// Q4_*/Q5_* take the block-8 fast path when one step covers a full
// block; Q8_0 always goes through the per-int get_k_qs* helpers
// (its qs is byte-packed, not nibble-packed).
const bool block8_fast = (d_per_step == 8) && (FaTypeK != FA_TYPE_Q8_0);
if (SHMEM_STAGING != 0) {
const uint k_block_idx = (d_tid * (HSK_per_thread / 4) + d_block) / 8;
const uint buf_ib = (c * cols_per_iter + col_tid) * qf_stride + k_block_idx;
#if QUANT_AUXF == 1
k_dm = ACC_TYPEV2(kblocksh[buf_ib].dm, 0.0);
#else
k_dm = ACC_TYPEV2(kblocksh[buf_ib].dm);
#endif
#if defined(DATA_A_Q4_0) || defined(DATA_A_Q4_1) || defined(DATA_A_Q5_0) || defined(DATA_A_Q5_1)
if (d_per_step == 8) {
if (block8_fast) {
const bool has_qh = (FaTypeK == FA_TYPE_Q5_0) || (FaTypeK == FA_TYPE_Q5_1);
[[unroll]] for (uint32_t d = 0; d < 4; d++) {
uint vui = kblocksh[buf_ib].qs[d];
k_quants[d ] = int32_t( vui & 0x0F0F0F0F);
k_quants[d + 4] = int32_t((vui >> 4) & 0x0F0F0F0F);
#if defined(DATA_A_Q5_0) || defined(DATA_A_Q5_1)
uint qh_lo = (kblocksh[buf_ib].qh >> (d * 4)) & 0xF;
uint qh_hi = (kblocksh[buf_ib].qh >> (d * 4 + 16)) & 0xF;
k_quants[d ] |= int32_t((qh_lo * 0x02040810u) & 0x10101010u);
k_quants[d + 4] |= int32_t((qh_hi * 0x02040810u) & 0x10101010u);
#endif
if (has_qh) {
uint qh_lo = (kblocksh[buf_ib].qh >> (d * 4)) & 0xF;
uint qh_hi = (kblocksh[buf_ib].qh >> (d * 4 + 16)) & 0xF;
k_quants[d ] |= int32_t((qh_lo * 0x02040810u) & 0x10101010u);
k_quants[d + 4] |= int32_t((qh_hi * 0x02040810u) & 0x10101010u);
}
}
} else
#endif
{
} else {
[[unroll]] for (uint32_t d = 0; d < d_per_step; d++) {
k_quants[d] = get_k_qs_shmem(buf_ib, (d_tid * (HSK_per_thread / 4) + d_block) % 8 + d);
}
}
} else {
const uint coord = (j * Bc + c * cols_per_iter + col_tid) * k_stride * BLOCK_SIZE + 4 * (d_tid * (HSK_per_thread / 4) + d_block);
const uint ib = coord / BLOCK_SIZE;
const uint iqs = (coord % BLOCK_SIZE);
const uint coord = (j * Bc + c * cols_per_iter + col_tid) * k_stride * BLOCK_SIZE_K + 4 * (d_tid * (HSK_per_thread / 4) + d_block);
const uint ib = coord / BLOCK_SIZE_K;
const uint iqs = (coord % BLOCK_SIZE_K);
#if QUANT_AUXF == 1
k_dm = ACC_TYPEV2(get_k_d(ib, k_offset), 0.0);
#else
k_dm = ACC_TYPEV2(get_k_dm(ib, k_offset));
#endif
#if defined(DATA_A_Q4_0) || defined(DATA_A_Q4_1) || defined(DATA_A_Q5_0) || defined(DATA_A_Q5_1)
if (d_per_step == 8) {
#if defined(DATA_A_Q5_0)
uint qh = pack32(u16vec2(k_packed.k_data_packed16[k_offset + ib].qh[0],
k_packed.k_data_packed16[k_offset + ib].qh[1]));
#elif defined(DATA_A_Q5_1)
uint qh = k_packed.k_data_packed16[k_offset + ib].qh;
#endif
[[unroll]] for (uint32_t d = 0; d < 4; d++) {
#if defined(A_TYPE_PACKED32)
uint vui = k_packed32.k_data_packed32[k_offset + ib].qs[d];
#else
uint vui = pack32(u16vec2(k_packed.k_data_packed16[k_offset + ib].qs[iqs / 2 + d * 2 + 0],
k_packed.k_data_packed16[k_offset + ib].qs[iqs / 2 + d * 2 + 1]));
#endif
k_quants[d ] = int32_t( vui & 0x0F0F0F0F);
k_quants[d + 4] = int32_t((vui >> 4) & 0x0F0F0F0F);
#if defined(DATA_A_Q5_0) || defined(DATA_A_Q5_1)
uint qh_lo = (qh >> (d * 4)) & 0xF;
uint qh_hi = (qh >> (d * 4 + 16)) & 0xF;
k_quants[d ] |= int32_t((qh_lo * 0x02040810u) & 0x10101010u);
k_quants[d + 4] |= int32_t((qh_hi * 0x02040810u) & 0x10101010u);
#endif
k_dm = ACC_TYPEV2(get_k_scale(ib, k_offset));
if (block8_fast) {
fa_k_qs_block8 blk = get_k_qs_block8(ib, k_offset);
[[unroll]] for (uint32_t d = 0; d < 8; d++) {
k_quants[d] = blk.qs[d];
}
} else
#endif
{
} else {
[[unroll]] for (uint32_t d = 0; d < d_per_step; d++) {
k_quants[d] = get_k_qs(ib, iqs + d * 4, k_offset);
}
@@ -516,14 +486,14 @@ void main() {
if (idx + gl_WorkGroupSize.x <= Bc * HSV / 4 || c < Bc) {
FLOAT_TYPEV4 V_Tf = FLOAT_TYPEV4(0);
if (!KV_bounds_check || j * Bc + c < KV) {
#if BLOCK_SIZE > 1
uint coord = (j * Bc + c) * v_stride * BLOCK_SIZE + 4 * d;
uint ib = coord / BLOCK_SIZE;
uint iqs = (coord % BLOCK_SIZE);
V_Tf = dequantize4(ib, iqs, v_offset, BINDING_IDX_V);
#else
V_Tf = FLOAT_TYPEV4(data_vv4[v_offset / 4 + (j * Bc + c) * v_stride / 4 + d]);
#endif
if (USE_DECODE_V) {
uint coord = (j * Bc + c) * v_stride * BLOCK_SIZE_V + 4 * d;
uint ib = coord / BLOCK_SIZE_V;
uint iqs = (coord % BLOCK_SIZE_V);
V_Tf = dequantize4(ib, iqs, v_offset, BINDING_IDX_V);
} else {
V_Tf = FLOAT_TYPEV4(data_vv4[v_offset / 4 + (j * Bc + c) * v_stride / 4 + d]);
}
}
kvsh[c * kvsh_stride + d] = V_Tf;
@@ -547,15 +517,13 @@ void main() {
FLOAT_TYPEV4 Vf;
if (SHMEM_STAGING != 0) {
Vf = kvsh[(c * cols_per_iter + col_tid) * kvsh_stride + (d * D_split + d_tid)];
} else {
#if BLOCK_SIZE > 1
uint coord = (j * Bc + c * cols_per_iter + col_tid) * v_stride * BLOCK_SIZE + 4 * (d * D_split + d_tid);
uint ib = coord / BLOCK_SIZE;
uint iqs = (coord % BLOCK_SIZE);
} else if (USE_DECODE_V) {
uint coord = (j * Bc + c * cols_per_iter + col_tid) * v_stride * BLOCK_SIZE_V + 4 * (d * D_split + d_tid);
uint ib = coord / BLOCK_SIZE_V;
uint iqs = (coord % BLOCK_SIZE_V);
Vf = dequantize4(ib, iqs, v_offset, BINDING_IDX_V);
#else
} else {
Vf = FLOAT_TYPEV4(data_vv4[v_offset / 4 + (j * Bc + c * cols_per_iter + col_tid) * v_stride / 4 + d * D_split + d_tid]);
#endif
}
[[unroll]] for (uint32_t r = 0; r < rows_per_thread; ++r) {
Of[r][d] += FLOAT_TYPEV4(Pf[r] * Vf);

View File

@@ -87,176 +87,58 @@ layout (binding = 6) readonly buffer MO {uint32_t data_mask_opt[];};
#define BINDING_IDX_K 0
#define BINDING_IDX_V 1
#if defined(DATA_A_F32)
layout (binding = 1) readonly buffer K_PACKED {vec4 k_data_packed[];} k_packed;
layout (binding = 2) readonly buffer V_PACKED {vec4 v_data_packed[];} v_packed;
#elif defined(A_TYPE_PACKED16)
layout (binding = 1) readonly buffer K_PACKED16 {A_TYPE_PACKED16 k_data_packed16[];} k_packed;
layout (binding = 2) readonly buffer V_PACKED16 {A_TYPE_PACKED16 v_data_packed16[];} v_packed;
#endif
#if defined(A_TYPE_PACKED32)
layout (binding = 1) readonly buffer K_PACKED32 {A_TYPE_PACKED32 k_data_packed32[];} k_packed32;
layout (binding = 2) readonly buffer V_PACKED32 {A_TYPE_PACKED32 v_data_packed32[];} v_packed32;
#endif
// FaTypeK / FaTypeV spec constant values. These mirror enum ggml_type so the
// host can pass the type directly. Keep in sync with ggml.h.
#define FA_TYPE_F32 0u
#define FA_TYPE_F16 1u
#define FA_TYPE_Q4_0 2u
#define FA_TYPE_Q4_1 3u
#define FA_TYPE_Q5_0 6u
#define FA_TYPE_Q5_1 7u
#define FA_TYPE_Q8_0 8u
#define FA_TYPE_Q1_0 41u
#ifndef BLOCK_SIZE
#define BLOCK_SIZE 1
#endif
#if defined(DATA_A_F32)
#undef BLOCK_SIZE
#define BLOCK_SIZE 4
#define BLOCK_BYTE_SIZE 16
FLOAT_TYPEV4 dequantize4(uint ib, uint iqs, uint a_offset, uint binding_idx) {
// iqs is currently always zero in the flash attention shaders
if (binding_idx == BINDING_IDX_K) {
return FLOAT_TYPEV4(k_packed.k_data_packed[a_offset + ib]);
} else {
return FLOAT_TYPEV4(v_packed.v_data_packed[a_offset + ib]);
// Number of matrix elements per buffer block, derived from the K/V type spec
// constant. F32 is treated as a vec4 "block" of 4 floats. F16 uses block size 1
// and bypasses the dequant path entirely. Quants follow their ggml block sizes.
uint fa_block_elems(uint ty) {
switch (ty) {
case FA_TYPE_F32: return 4u;
case FA_TYPE_F16: return 1u;
case FA_TYPE_Q4_0: return uint(QUANT_K_Q4_0);
case FA_TYPE_Q4_1: return uint(QUANT_K_Q4_1);
case FA_TYPE_Q5_0: return uint(QUANT_K_Q5_0);
case FA_TYPE_Q5_1: return uint(QUANT_K_Q5_1);
case FA_TYPE_Q8_0: return uint(QUANT_K_Q8_0);
case FA_TYPE_Q1_0: return uint(QUANT_K_Q1_0); // cm2-only, harmless elsewhere
default: return 1u;
}
}
#endif
#if defined(DATA_A_Q4_0)
#define BLOCK_BYTE_SIZE 18
#elif defined(DATA_A_Q4_1)
#define BLOCK_BYTE_SIZE 20
#endif
#if defined(DATA_A_Q4_0) || defined(DATA_A_Q4_1)
FLOAT_TYPEV4 dequantize4(uint ib, uint iqs, uint a_offset, uint binding_idx) {
if (binding_idx == BINDING_IDX_K) {
uint vui_lo = uint(k_packed.k_data_packed16[a_offset + ib].qs[(iqs & 0xF) / 2 + 0]);
uint vui_hi = uint(k_packed.k_data_packed16[a_offset + ib].qs[(iqs & 0xF) / 2 + 1]);
uint shift = (iqs & 0x10) >> 2;
vui_lo >>= shift;
vui_hi >>= shift;
FLOAT_TYPEV4 nibbles = FLOAT_TYPEV4(vui_lo & 0xF, (vui_lo >> 8) & 0xF, vui_hi & 0xF, (vui_hi >> 8) & 0xF);
#ifdef DATA_A_Q4_1
return FLOAT_TYPE(k_packed.k_data_packed16[a_offset + ib].d) * nibbles + FLOAT_TYPE(k_packed.k_data_packed16[a_offset + ib].m);
#else
return FLOAT_TYPE(k_packed.k_data_packed16[a_offset + ib].d) * (nibbles - FLOAT_TYPE(8.0f));
#endif
} else {
uint vui_lo = uint(v_packed.v_data_packed16[a_offset + ib].qs[(iqs & 0xF) / 2 + 0]);
uint vui_hi = uint(v_packed.v_data_packed16[a_offset + ib].qs[(iqs & 0xF) / 2 + 1]);
uint shift = (iqs & 0x10) >> 2;
vui_lo >>= shift;
vui_hi >>= shift;
FLOAT_TYPEV4 nibbles = FLOAT_TYPEV4(vui_lo & 0xF, (vui_lo >> 8) & 0xF, vui_hi & 0xF, (vui_hi >> 8) & 0xF);
#ifdef DATA_A_Q4_1
return FLOAT_TYPE(v_packed.v_data_packed16[a_offset + ib].d) * nibbles + FLOAT_TYPE(v_packed.v_data_packed16[a_offset + ib].m);
#else
return FLOAT_TYPE(v_packed.v_data_packed16[a_offset + ib].d) * (nibbles - FLOAT_TYPE(8.0f));
#endif
// QUANT_R_MMQ for FA-eligible K types. Q4_*/Q5_* store two nibbles per byte
// (R==2); Q8_0 stores one byte per element (R==1). Used to derive the number
// of int32s per 32-element block on the MMQ K path: ints_per_block == 8 / R.
uint fa_quant_r_mmq(uint ty) {
switch (ty) {
case FA_TYPE_Q4_0: return uint(QUANT_R_Q4_0);
case FA_TYPE_Q4_1: return uint(QUANT_R_Q4_1);
case FA_TYPE_Q5_0: return uint(QUANT_R_Q5_0);
case FA_TYPE_Q5_1: return uint(QUANT_R_Q5_1);
case FA_TYPE_Q8_0: return uint(QUANT_R_Q8_0);
default: return 1u;
}
}
#endif
#if defined(DATA_A_Q5_0)
#define BLOCK_BYTE_SIZE 22
#elif defined(DATA_A_Q5_1)
#define BLOCK_BYTE_SIZE 24
#endif
#if defined(DATA_A_Q5_0) || defined(DATA_A_Q5_1)
FLOAT_TYPEV4 dequantize4(uint ib, uint iqs, uint a_offset, uint binding_idx) {
if (binding_idx == BINDING_IDX_K) {
uint vui_lo = uint(k_packed.k_data_packed16[a_offset + ib].qs[(iqs & 0xF) / 2 + 0]);
uint vui_hi = uint(k_packed.k_data_packed16[a_offset + ib].qs[(iqs & 0xF) / 2 + 1]);
uint shift = (iqs & 0x10) >> 2;
vui_lo >>= shift;
vui_hi >>= shift;
#ifdef DATA_A_Q5_1
uint qh = k_packed.k_data_packed16[a_offset + ib].qh;
#else
uint qh = uint(k_packed.k_data_packed16[a_offset + ib].qh[0]) | (uint(k_packed.k_data_packed16[a_offset + ib].qh[1]) << 16);
#endif
FLOAT_TYPEV4 hb = FLOAT_TYPEV4((qh >> iqs) & 1, (qh >> (iqs + 1)) & 1, (qh >> (iqs + 2)) & 1, (qh >> (iqs + 3)) & 1) * FLOAT_TYPE(16.0f);
FLOAT_TYPEV4 nibbles = FLOAT_TYPEV4(vui_lo & 0xF, (vui_lo >> 8) & 0xF, vui_hi & 0xF, (vui_hi >> 8) & 0xF);
#ifdef DATA_A_Q5_1
return FLOAT_TYPE(k_packed.k_data_packed16[a_offset + ib].d) * (nibbles + hb) + FLOAT_TYPE(k_packed.k_data_packed16[a_offset + ib].m);
#else
return FLOAT_TYPE(k_packed.k_data_packed16[a_offset + ib].d) * (nibbles + hb - FLOAT_TYPE(16.0f));
#endif
} else {
uint vui_lo = uint(v_packed.v_data_packed16[a_offset + ib].qs[(iqs & 0xF) / 2 + 0]);
uint vui_hi = uint(v_packed.v_data_packed16[a_offset + ib].qs[(iqs & 0xF) / 2 + 1]);
uint shift = (iqs & 0x10) >> 2;
vui_lo >>= shift;
vui_hi >>= shift;
#ifdef DATA_A_Q5_1
uint qh = v_packed.v_data_packed16[a_offset + ib].qh;
#else
uint qh = uint(v_packed.v_data_packed16[a_offset + ib].qh[0]) | (uint(v_packed.v_data_packed16[a_offset + ib].qh[1]) << 16);
#endif
FLOAT_TYPEV4 hb = FLOAT_TYPEV4((qh >> iqs) & 1, (qh >> (iqs + 1)) & 1, (qh >> (iqs + 2)) & 1, (qh >> (iqs + 3)) & 1) * FLOAT_TYPE(16.0f);
FLOAT_TYPEV4 nibbles = FLOAT_TYPEV4(vui_lo & 0xF, (vui_lo >> 8) & 0xF, vui_hi & 0xF, (vui_hi >> 8) & 0xF);
#ifdef DATA_A_Q5_1
return FLOAT_TYPE(v_packed.v_data_packed16[a_offset + ib].d) * (nibbles + hb) + FLOAT_TYPE(v_packed.v_data_packed16[a_offset + ib].m);
#else
return FLOAT_TYPE(v_packed.v_data_packed16[a_offset + ib].d) * (nibbles + hb - FLOAT_TYPE(16.0f));
#endif
}
}
#endif
#if defined(DATA_A_IQ4_NL)
#define BLOCK_BYTE_SIZE 18
FLOAT_TYPEV4 dequantize4(uint ib, uint iqs, uint a_offset, uint binding_idx) {
if (binding_idx == BINDING_IDX_K) {
uint vui_lo = uint(k_packed.k_data_packed16[a_offset + ib].qs[(iqs & 0xF) / 2 + 0]);
uint vui_hi = uint(k_packed.k_data_packed16[a_offset + ib].qs[(iqs & 0xF) / 2 + 1]);
uint shift = (iqs & 0x10) >> 2;
vui_lo >>= shift;
vui_hi >>= shift;
return FLOAT_TYPE(k_packed.k_data_packed16[a_offset + ib].d) * FLOAT_TYPEV4(
kvalues_iq4nl[vui_lo & 0xF],
kvalues_iq4nl[(vui_lo >> 8) & 0xF],
kvalues_iq4nl[vui_hi & 0xF],
kvalues_iq4nl[(vui_hi >> 8) & 0xF]);
} else {
uint vui_lo = uint(v_packed.v_data_packed16[a_offset + ib].qs[(iqs & 0xF) / 2 + 0]);
uint vui_hi = uint(v_packed.v_data_packed16[a_offset + ib].qs[(iqs & 0xF) / 2 + 1]);
uint shift = (iqs & 0x10) >> 2;
vui_lo >>= shift;
vui_hi >>= shift;
return FLOAT_TYPE(v_packed.v_data_packed16[a_offset + ib].d) * FLOAT_TYPEV4(
kvalues_iq4nl[vui_lo & 0xF],
kvalues_iq4nl[(vui_lo >> 8) & 0xF],
kvalues_iq4nl[vui_hi & 0xF],
kvalues_iq4nl[(vui_hi >> 8) & 0xF]);
}
}
#endif
#if defined(DATA_A_Q8_0)
#define BLOCK_BYTE_SIZE 34
FLOAT_TYPEV4 dequantize4(uint ib, uint iqs, uint a_offset, uint binding_idx) {
if (binding_idx == BINDING_IDX_K) {
const i8vec2 v0 = unpack8(int32_t(k_packed.k_data_packed16[a_offset + ib].qs[iqs / 2])).xy; // vec4 used due to #12147
const i8vec2 v1 = unpack8(int32_t(k_packed.k_data_packed16[a_offset + ib].qs[iqs / 2 + 1])).xy;
return FLOAT_TYPE(k_packed.k_data_packed16[a_offset + ib].d) * FLOAT_TYPEV4(v0.x, v0.y, v1.x, v1.y);
} else {
const i8vec2 v0 = unpack8(int32_t(v_packed.v_data_packed16[a_offset + ib].qs[iqs / 2])).xy; // vec4 used due to #12147
const i8vec2 v1 = unpack8(int32_t(v_packed.v_data_packed16[a_offset + ib].qs[iqs / 2 + 1])).xy;
return FLOAT_TYPE(v_packed.v_data_packed16[a_offset + ib].d) * FLOAT_TYPEV4(v0.x, v0.y, v1.x, v1.y);
}
}
#endif
// These can't be `const` globals because GLSL forbids function calls in global
// const initializers, even when the spec constants would let the driver fold
// them. Macros expand at the use site and fold after specialization.
#define BLOCK_SIZE_K fa_block_elems(FaTypeK)
#define BLOCK_SIZE_V fa_block_elems(FaTypeV)
// F16 reads f16 elements directly from the binding; everything else routes
// through dequantize4 / the MMQ helpers to unpack from the packed block layout.
#define USE_DECODE_K (FaTypeK != FA_TYPE_F16)
#define USE_DECODE_V (FaTypeV != FA_TYPE_F16)
#define CEIL_DIV(a, b) (((a) + (b) - 1) / (b))

View File

@@ -14,6 +14,7 @@
#include "types.glsl"
#include "flash_attn_base.glsl"
#include "flash_attn_dequant.glsl"
// These need to be supported N,M values for a MatBc x MatBr x 16 coopmatmuladd
const uint32_t MatBr = 16;
@@ -127,13 +128,9 @@ void main() {
// mo_offset will point to the tile starting at row i*Br and col 0
uint32_t mo_offset = mo_stride * i;
#if BLOCK_SIZE > 1
uint32_t k_offset = (ik2*p.nb12 + ik3*p.nb13) / BLOCK_BYTE_SIZE;
uint32_t v_offset = (iv2*p.nb22 + iv3*p.nb23) / BLOCK_BYTE_SIZE;
#else
uint32_t k_offset = (ik2*p.nb12 + ik3*p.nb13) / 2;
uint32_t v_offset = (iv2*p.nb22 + iv3*p.nb23) / 2;
#endif
// FaBlockBytesK/V == 2 for f16 (sizeof f16) and == 16 for f32 (vec4) and == ggml block size for quants.
uint32_t k_offset = (ik2*p.nb12 + ik3*p.nb13) / FaBlockBytesK;
uint32_t v_offset = (iv2*p.nb22 + iv3*p.nb23) / FaBlockBytesV;
uint32_t m_offset = gqa_iq1*KV;
if (p.nem2 != 1 || p.nem3 != 1) {
m_offset += ((iq3 % p.nem3) * p.nem2 + (iq2 % p.nem2)) * p.nem1 * KV;
@@ -227,14 +224,14 @@ void main() {
if (idx + gl_WorkGroupSize.x <= Bc * HSK_pad / 4 || c < Bc) {
f16vec4 K_Tf = f16vec4(0);
if ((!KV_bounds_check || j * Bc + c < KV) && (HSK == HSK_pad || d < HSK / 4)) {
#if BLOCK_SIZE > 1
uint coord = (j * Bc + c) * k_stride * BLOCK_SIZE + 4 * d;
uint ib = coord / BLOCK_SIZE;
uint iqs = (coord % BLOCK_SIZE);
K_Tf = dequantize4(ib, iqs, k_offset, BINDING_IDX_K);
#else
K_Tf = f16vec4(data_kv4[k_offset / 4 + (j * Bc + c) * k_stride / 4 + d]);
#endif
if (USE_DECODE_K) {
uint coord = (j * Bc + c) * k_stride * BLOCK_SIZE_K + 4 * d;
uint ib = coord / BLOCK_SIZE_K;
uint iqs = (coord % BLOCK_SIZE_K);
K_Tf = dequantize4(ib, iqs, k_offset, BINDING_IDX_K);
} else {
K_Tf = f16vec4(data_kv4[k_offset / 4 + (j * Bc + c) * k_stride / 4 + d]);
}
}
kvsh[c * kvsh_stride + d] = K_Tf;
@@ -256,47 +253,40 @@ void main() {
// staged through a Bc * MatBr size staging buffer.
// If K is not type f16, then it is always staged for dequantization.
if (SHMEM_STAGING == 0) {
#if BLOCK_SIZE == 1
if (KV_bounds_check || d * 16 + 16 > HSK) {
#endif
barrier();
[[unroll]] for (uint32_t idx = 0; idx < Bc * MatBr / 4; idx += gl_WorkGroupSize.x) {
uint32_t col_vec = (idx + tid) % (MatBr / 4);
uint32_t row = (idx + tid) / (MatBr / 4);
if (idx + tid < Bc * MatBr / 4) {
f16vec4 K_Tf = f16vec4(0);
if ((!KV_bounds_check || j * Bc + row < KV) && (HSK == HSK_pad || d * 16 + col_vec * 4 < HSK)) {
#if BLOCK_SIZE > 1
uint coord = (j * Bc + row) * k_stride * BLOCK_SIZE + d * 16 + col_vec * 4;
uint ib = coord / BLOCK_SIZE;
uint iqs = (coord % BLOCK_SIZE);
K_Tf = dequantize4(ib, iqs, k_offset, BINDING_IDX_K);
#else
K_Tf = f16vec4(data_kv4[k_offset / 4 + (j * Bc + row) * k_stride / 4 + d * 16 / 4 + col_vec]);
#endif
// For quants we always need to dequant into kvsh; for f16 we can load
// directly from global memory when alignment / bounds allow it.
const bool stage_k = USE_DECODE_K || KV_bounds_check || d * 16 + 16 > HSK;
if (stage_k) {
barrier();
[[unroll]] for (uint32_t idx = 0; idx < Bc * MatBr / 4; idx += gl_WorkGroupSize.x) {
uint32_t col_vec = (idx + tid) % (MatBr / 4);
uint32_t row = (idx + tid) / (MatBr / 4);
if (idx + tid < Bc * MatBr / 4) {
f16vec4 K_Tf = f16vec4(0);
if ((!KV_bounds_check || j * Bc + row < KV) && (HSK == HSK_pad || d * 16 + col_vec * 4 < HSK)) {
if (USE_DECODE_K) {
uint coord = (j * Bc + row) * k_stride * BLOCK_SIZE_K + d * 16 + col_vec * 4;
uint ib = coord / BLOCK_SIZE_K;
uint iqs = (coord % BLOCK_SIZE_K);
K_Tf = dequantize4(ib, iqs, k_offset, BINDING_IDX_K);
} else {
K_Tf = f16vec4(data_kv4[k_offset / 4 + (j * Bc + row) * k_stride / 4 + d * 16 / 4 + col_vec]);
}
}
kvsh[row * kvsh_stride + col_vec] = K_Tf;
}
kvsh[row * kvsh_stride + col_vec] = K_Tf;
}
barrier();
}
barrier();
#if BLOCK_SIZE == 1
}
#endif
#if BLOCK_SIZE == 1
if (KV_bounds_check || d * 16 + 16 > HSK)
#endif
{
if (stage_k) {
uint coord = (gl_SubgroupID * MatBc) * kvsh_stride;
coopMatLoad(KMat, kvsh, coord, kvsh_stride, gl_CooperativeMatrixLayoutRowMajor);
}
#if BLOCK_SIZE == 1
else {
} else {
const uint coord = k_offset / 4 + (j * Bc + gl_SubgroupID * MatBc) * k_stride / 4 + d * 16 / 4;
coopMatLoad(KMat, data_kv4, coord, k_stride / 4, gl_CooperativeMatrixLayoutRowMajor);
}
#endif
} else {
uint coord = (gl_SubgroupID * MatBc) * kvsh_stride + d * 16 / 4;
coopMatLoad(KMat, kvsh, coord, kvsh_stride, gl_CooperativeMatrixLayoutRowMajor);
@@ -397,14 +387,14 @@ void main() {
if (idx + gl_WorkGroupSize.x <= Bc * HSV_pad / 4 || c < Bc) {
f16vec4 V_Tf = f16vec4(0);
if ((!KV_bounds_check || j * Bc + c < KV) && (HSV == HSV_pad || d < HSV / 4)) {
#if BLOCK_SIZE > 1
uint coord = (j * Bc + c) * v_stride * BLOCK_SIZE + 4 * d;
uint ib = coord / BLOCK_SIZE;
uint iqs = (coord % BLOCK_SIZE);
V_Tf = dequantize4(ib, iqs, v_offset, BINDING_IDX_V);
#else
V_Tf = f16vec4(data_vv4[v_offset / 4 + (j * Bc + c) * v_stride / 4 + d]);
#endif
if (USE_DECODE_V) {
uint coord = (j * Bc + c) * v_stride * BLOCK_SIZE_V + 4 * d;
uint ib = coord / BLOCK_SIZE_V;
uint iqs = (coord % BLOCK_SIZE_V);
V_Tf = dequantize4(ib, iqs, v_offset, BINDING_IDX_V);
} else {
V_Tf = f16vec4(data_vv4[v_offset / 4 + (j * Bc + c) * v_stride / 4 + d]);
}
}
kvsh[c * kvsh_stride + d] = V_Tf;
@@ -431,36 +421,33 @@ void main() {
// staged through a Bc * MatBr size staging buffer.
// If V is not type f16, then it is always staged for dequantization.
if (SHMEM_STAGING == 0) {
#if BLOCK_SIZE == 1
// For f16, only preload if not aligned
if (KV_bounds_check) {
#endif
[[unroll]] for (uint32_t i = 0; i < v_loads_per_thread; ++i) {
const uint idx = i * gl_WorkGroupSize.x + tid;
const uint row = idx / v_cols;
const uint col = idx % v_cols;
// For quants we always preload via kvsh. For f16 we only preload when
// alignment / bounds force it (otherwise we coopMatLoad direct from data_vv4).
const bool stage_v = USE_DECODE_V || KV_bounds_check;
if (stage_v) {
[[unroll]] for (uint32_t i = 0; i < v_loads_per_thread; ++i) {
const uint idx = i * gl_WorkGroupSize.x + tid;
const uint row = idx / v_cols;
const uint col = idx % v_cols;
const uint v_row = j * Bc + row;
const uint v_col = hsv_tile * MatBc * row_split + col * 4;
const uint v_row = j * Bc + row;
const uint v_col = hsv_tile * MatBc * row_split + col * 4;
const uint coord = v_row * v_stride * BLOCK_SIZE + v_col;
const uint ib = coord / BLOCK_SIZE;
const uint iqs = coord % BLOCK_SIZE;
const uint coord = v_row * v_stride * BLOCK_SIZE_V + v_col;
const uint ib = coord / BLOCK_SIZE_V;
const uint iqs = coord % BLOCK_SIZE_V;
if (!KV_bounds_check || (v_row < KV && v_col < HSV)) {
#if BLOCK_SIZE > 1
kvsh[row * vsh_stride + col] = dequantize4(ib, iqs, v_offset, BINDING_IDX_V);
#else
kvsh[row * vsh_stride + col] = data_vv4[(v_offset + v_row * v_stride + v_col) / 4];
#endif
} else {
kvsh[row * vsh_stride + col] = f16vec4(0.0f);
if (!KV_bounds_check || (v_row < KV && v_col < HSV)) {
if (USE_DECODE_V) {
kvsh[row * vsh_stride + col] = dequantize4(ib, iqs, v_offset, BINDING_IDX_V);
} else {
kvsh[row * vsh_stride + col] = data_vv4[(v_offset + v_row * v_stride + v_col) / 4];
}
} else {
kvsh[row * vsh_stride + col] = f16vec4(0.0f);
}
}
}
#if BLOCK_SIZE == 1
}
#endif
}
barrier();
@@ -471,15 +458,12 @@ void main() {
coopMatLoad(KMat, Psh, bc_chunk * MatBc * psh_stride, psh_stride, gl_CooperativeMatrixLayoutColumnMajor);
if (SHMEM_STAGING == 0) {
#if BLOCK_SIZE == 1
if (!KV_bounds_check) {
if (!USE_DECODE_V && !KV_bounds_check) {
// F16 values can be loaded directly from global memory
const uint v_tile_row = j * Bc + bc_chunk * MatBc;
const uint v_tile_offset = v_offset / 4 + v_tile_row * v_stride / 4 + hsv_offset / 4;
coopMatLoad(QMat, data_vv4, v_tile_offset, v_stride / 4, gl_CooperativeMatrixLayoutRowMajor);
} else
#endif
{
} else {
const uint v_tile_offset = bc_chunk * MatBr * v_cols + gl_SubgroupID * (MatBc / 4);
coopMatLoad(QMat, kvsh, v_tile_offset, vsh_stride, gl_CooperativeMatrixLayoutRowMajor);
}

View File

@@ -28,43 +28,28 @@ layout(buffer_reference, std430, buffer_reference_align = 1) buffer decodeBufFA_
uint8_t raw[FaBlockBytesV];
};
uint fa_block_elems(uint ty) {
switch (ty) {
case 0u: return 4u; // GGML_TYPE_F32: vec4 block (matches decodeBufF32 / dequantFuncF32)
case 1u: return 1u; // GGML_TYPE_F16
case 2u: return uint(QUANT_K_Q4_0);
case 3u: return uint(QUANT_K_Q4_1);
case 6u: return uint(QUANT_K_Q5_0);
case 7u: return uint(QUANT_K_Q5_1);
case 8u: return uint(QUANT_K_Q8_0);
case 41u: return uint(QUANT_K_Q1_0);
default:
return 1u;
}
}
float16_t faDecodeK(const decodeBufFA_K bl_in, const uint blockCoords[2], const uint coordInBlock[2]) {
switch (FaTypeK) {
case 0u: return dequantFuncF32(decodeBufF32(bl_in), blockCoords, coordInBlock);
case 2u: return dequantFuncQ4_0(decodeBufQ4_0(bl_in), blockCoords, coordInBlock);
case 3u: return dequantFuncQ4_1(decodeBufQ4_1(bl_in), blockCoords, coordInBlock);
case 6u: return dequantFuncQ5_0(decodeBufQ5_0(bl_in), blockCoords, coordInBlock);
case 7u: return dequantFuncQ5_1(decodeBufQ5_1(bl_in), blockCoords, coordInBlock);
case 8u: return dequantFuncQ8_0(decodeBufQ8_0(bl_in), blockCoords, coordInBlock);
case 41u: return dequantFuncQ1_0(decodeBufQ1_0(bl_in), blockCoords, coordInBlock);
case FA_TYPE_F32: return dequantFuncF32 (decodeBufF32 (bl_in), blockCoords, coordInBlock);
case FA_TYPE_Q4_0: return dequantFuncQ4_0(decodeBufQ4_0(bl_in), blockCoords, coordInBlock);
case FA_TYPE_Q4_1: return dequantFuncQ4_1(decodeBufQ4_1(bl_in), blockCoords, coordInBlock);
case FA_TYPE_Q5_0: return dequantFuncQ5_0(decodeBufQ5_0(bl_in), blockCoords, coordInBlock);
case FA_TYPE_Q5_1: return dequantFuncQ5_1(decodeBufQ5_1(bl_in), blockCoords, coordInBlock);
case FA_TYPE_Q8_0: return dequantFuncQ8_0(decodeBufQ8_0(bl_in), blockCoords, coordInBlock);
case FA_TYPE_Q1_0: return dequantFuncQ1_0(decodeBufQ1_0(bl_in), blockCoords, coordInBlock);
default: return float16_t(0);
}
}
float16_t faDecodeV(const decodeBufFA_V bl_in, const uint blockCoords[2], const uint coordInBlock[2]) {
switch (FaTypeV) {
case 0u: return dequantFuncF32(decodeBufF32(bl_in), blockCoords, coordInBlock);
case 2u: return dequantFuncQ4_0(decodeBufQ4_0(bl_in), blockCoords, coordInBlock);
case 3u: return dequantFuncQ4_1(decodeBufQ4_1(bl_in), blockCoords, coordInBlock);
case 6u: return dequantFuncQ5_0(decodeBufQ5_0(bl_in), blockCoords, coordInBlock);
case 7u: return dequantFuncQ5_1(decodeBufQ5_1(bl_in), blockCoords, coordInBlock);
case 8u: return dequantFuncQ8_0(decodeBufQ8_0(bl_in), blockCoords, coordInBlock);
case 41u: return dequantFuncQ1_0(decodeBufQ1_0(bl_in), blockCoords, coordInBlock);
case FA_TYPE_F32: return dequantFuncF32 (decodeBufF32 (bl_in), blockCoords, coordInBlock);
case FA_TYPE_Q4_0: return dequantFuncQ4_0(decodeBufQ4_0(bl_in), blockCoords, coordInBlock);
case FA_TYPE_Q4_1: return dequantFuncQ4_1(decodeBufQ4_1(bl_in), blockCoords, coordInBlock);
case FA_TYPE_Q5_0: return dequantFuncQ5_0(decodeBufQ5_0(bl_in), blockCoords, coordInBlock);
case FA_TYPE_Q5_1: return dequantFuncQ5_1(decodeBufQ5_1(bl_in), blockCoords, coordInBlock);
case FA_TYPE_Q8_0: return dequantFuncQ8_0(decodeBufQ8_0(bl_in), blockCoords, coordInBlock);
case FA_TYPE_Q1_0: return dequantFuncQ1_0(decodeBufQ1_0(bl_in), blockCoords, coordInBlock);
default: return float16_t(0);
}
}

View File

@@ -0,0 +1,123 @@
// Asymmetric K/V flash attention: aliased SSBO views of bindings 1 (K) and 2 (V)
// covering every supported FA element type, plus an uber dequantize4() that
// switches on FaTypeK / FaTypeV. After spec-constant specialization the driver
// folds away every path except the one matching the K/V type for this pipeline.
//
// Included by flash_attn.comp and flash_attn_cm1.comp. Not included by
// flash_attn_cm2.comp, which has its own buffer_reference-based decode path.
//
// We use macros (rather than per-quant decode functions taking a struct) on
// purpose: the FA shaders don't enable GL_EXT_shader_explicit_arithmetic_types_float16
// when FLOAT16 isn't defined, which makes float16-containing struct values
// illegal to return from / pass to functions. Macros expand inline where the
// float16 stays in storage and is converted to FLOAT_TYPE at use.
// F32 is fed as a vec4 "block" (4 floats), matching what dequant_funcs_cm2.glsl
// does for F32 in the cm2 shader. FaBlockBytesK/V == 16 for F32.
layout (binding = 1) readonly buffer K_PACKED_F32 { vec4 data[]; } k_packed_f32;
layout (binding = 2) readonly buffer V_PACKED_F32 { vec4 data[]; } v_packed_f32;
layout (binding = 1) readonly buffer K_PACKED_Q4_0 { block_q4_0_packed16 data[]; } k_packed_q4_0;
layout (binding = 2) readonly buffer V_PACKED_Q4_0 { block_q4_0_packed16 data[]; } v_packed_q4_0;
layout (binding = 1) readonly buffer K_PACKED_Q4_1 { block_q4_1_packed16 data[]; } k_packed_q4_1;
layout (binding = 2) readonly buffer V_PACKED_Q4_1 { block_q4_1_packed16 data[]; } v_packed_q4_1;
layout (binding = 1) readonly buffer K_PACKED_Q5_0 { block_q5_0_packed16 data[]; } k_packed_q5_0;
layout (binding = 2) readonly buffer V_PACKED_Q5_0 { block_q5_0_packed16 data[]; } v_packed_q5_0;
layout (binding = 1) readonly buffer K_PACKED_Q5_1 { block_q5_1_packed16 data[]; } k_packed_q5_1;
layout (binding = 2) readonly buffer V_PACKED_Q5_1 { block_q5_1_packed16 data[]; } v_packed_q5_1;
layout (binding = 1) readonly buffer K_PACKED_Q8_0 { block_q8_0_packed16 data[]; } k_packed_q8_0;
layout (binding = 2) readonly buffer V_PACKED_Q8_0 { block_q8_0_packed16 data[]; } v_packed_q8_0;
// Q4_1 and Q5_1 packed32 views: aliased to the same memory as the packed16
// views, used by the MMQ K-side hot path for fast 4-uint loads.
layout (binding = 1) readonly buffer K_PACKED_Q4_1_P32 { block_q4_1_packed32 data[]; } k_packed_q4_1_p32;
layout (binding = 1) readonly buffer K_PACKED_Q5_1_P32 { block_q5_1_packed32 data[]; } k_packed_q5_1_p32;
// Per-quant decode bodies are expanded once for the K view set and once for
// the V view set. The macros take the buffer name as a parameter.
#define FA_DEQUANT4_F32(BUF) \
return FLOAT_TYPEV4(BUF.data[a_offset + ib]);
#define FA_DEQUANT4_Q4_0(BUF) { \
uint vui_lo = uint(BUF.data[a_offset + ib].qs[(iqs & 0xF) / 2 + 0]); \
uint vui_hi = uint(BUF.data[a_offset + ib].qs[(iqs & 0xF) / 2 + 1]); \
uint shift = (iqs & 0x10) >> 2; \
vui_lo >>= shift; \
vui_hi >>= shift; \
FLOAT_TYPEV4 nibbles = FLOAT_TYPEV4(vui_lo & 0xF, (vui_lo >> 8) & 0xF, \
vui_hi & 0xF, (vui_hi >> 8) & 0xF); \
return FLOAT_TYPE(BUF.data[a_offset + ib].d) * (nibbles - FLOAT_TYPE(8.0f)); \
}
#define FA_DEQUANT4_Q4_1(BUF) { \
uint vui_lo = uint(BUF.data[a_offset + ib].qs[(iqs & 0xF) / 2 + 0]); \
uint vui_hi = uint(BUF.data[a_offset + ib].qs[(iqs & 0xF) / 2 + 1]); \
uint shift = (iqs & 0x10) >> 2; \
vui_lo >>= shift; \
vui_hi >>= shift; \
FLOAT_TYPEV4 nibbles = FLOAT_TYPEV4(vui_lo & 0xF, (vui_lo >> 8) & 0xF, \
vui_hi & 0xF, (vui_hi >> 8) & 0xF); \
return FLOAT_TYPE(BUF.data[a_offset + ib].d) * nibbles \
+ FLOAT_TYPE(BUF.data[a_offset + ib].m); \
}
#define FA_DEQUANT4_Q5_0(BUF) { \
uint vui_lo = uint(BUF.data[a_offset + ib].qs[(iqs & 0xF) / 2 + 0]); \
uint vui_hi = uint(BUF.data[a_offset + ib].qs[(iqs & 0xF) / 2 + 1]); \
uint shift = (iqs & 0x10) >> 2; \
vui_lo >>= shift; \
vui_hi >>= shift; \
uint qh = uint(BUF.data[a_offset + ib].qh[0]) \
| (uint(BUF.data[a_offset + ib].qh[1]) << 16); \
FLOAT_TYPEV4 hb = FLOAT_TYPEV4((qh >> iqs) & 1, (qh >> (iqs + 1)) & 1, \
(qh >> (iqs + 2)) & 1, (qh >> (iqs + 3)) & 1) \
* FLOAT_TYPE(16.0f); \
FLOAT_TYPEV4 nibbles = FLOAT_TYPEV4(vui_lo & 0xF, (vui_lo >> 8) & 0xF, \
vui_hi & 0xF, (vui_hi >> 8) & 0xF); \
return FLOAT_TYPE(BUF.data[a_offset + ib].d) * (nibbles + hb - FLOAT_TYPE(16.0f)); \
}
#define FA_DEQUANT4_Q5_1(BUF) { \
uint vui_lo = uint(BUF.data[a_offset + ib].qs[(iqs & 0xF) / 2 + 0]); \
uint vui_hi = uint(BUF.data[a_offset + ib].qs[(iqs & 0xF) / 2 + 1]); \
uint shift = (iqs & 0x10) >> 2; \
vui_lo >>= shift; \
vui_hi >>= shift; \
uint qh = BUF.data[a_offset + ib].qh; \
FLOAT_TYPEV4 hb = FLOAT_TYPEV4((qh >> iqs) & 1, (qh >> (iqs + 1)) & 1, \
(qh >> (iqs + 2)) & 1, (qh >> (iqs + 3)) & 1) \
* FLOAT_TYPE(16.0f); \
FLOAT_TYPEV4 nibbles = FLOAT_TYPEV4(vui_lo & 0xF, (vui_lo >> 8) & 0xF, \
vui_hi & 0xF, (vui_hi >> 8) & 0xF); \
return FLOAT_TYPE(BUF.data[a_offset + ib].d) * (nibbles + hb) \
+ FLOAT_TYPE(BUF.data[a_offset + ib].m); \
}
#define FA_DEQUANT4_Q8_0(BUF) { \
const i8vec2 v0 = unpack8(int32_t(BUF.data[a_offset + ib].qs[iqs / 2 ])).xy; \
const i8vec2 v1 = unpack8(int32_t(BUF.data[a_offset + ib].qs[iqs / 2 + 1])).xy; \
return FLOAT_TYPE(BUF.data[a_offset + ib].d) * FLOAT_TYPEV4(v0.x, v0.y, v1.x, v1.y); \
}
FLOAT_TYPEV4 dequantize4(uint ib, uint iqs, uint a_offset, uint binding_idx) {
if (binding_idx == BINDING_IDX_K) {
switch (FaTypeK) {
case FA_TYPE_F32: FA_DEQUANT4_F32 (k_packed_f32)
case FA_TYPE_Q4_0: FA_DEQUANT4_Q4_0(k_packed_q4_0)
case FA_TYPE_Q4_1: FA_DEQUANT4_Q4_1(k_packed_q4_1)
case FA_TYPE_Q5_0: FA_DEQUANT4_Q5_0(k_packed_q5_0)
case FA_TYPE_Q5_1: FA_DEQUANT4_Q5_1(k_packed_q5_1)
case FA_TYPE_Q8_0: FA_DEQUANT4_Q8_0(k_packed_q8_0)
}
} else {
switch (FaTypeV) {
case FA_TYPE_F32: FA_DEQUANT4_F32 (v_packed_f32)
case FA_TYPE_Q4_0: FA_DEQUANT4_Q4_0(v_packed_q4_0)
case FA_TYPE_Q4_1: FA_DEQUANT4_Q4_1(v_packed_q4_1)
case FA_TYPE_Q5_0: FA_DEQUANT4_Q5_0(v_packed_q5_0)
case FA_TYPE_Q5_1: FA_DEQUANT4_Q5_1(v_packed_q5_1)
case FA_TYPE_Q8_0: FA_DEQUANT4_Q8_0(v_packed_q8_0)
}
}
return FLOAT_TYPEV4(0);
}

View File

@@ -1,149 +1,203 @@
#if defined(DATA_A_Q4_0) || defined(DATA_A_Q4_1)
// MMQ K-side helpers, asymmetric form. Each function dispatches on FaTypeK and
// reads from the matching aliased K binding declared in flash_attn_dequant.glsl.
// Spec-constant specialization folds the unused paths.
int32_t get_k_qs(uint ib, uint iqs, uint a_offset) {
#ifdef DATA_A_Q4_0
uint vui = pack32(u16vec2(k_packed.k_data_packed16[a_offset + ib].qs[(iqs & 0xF) / 2 + 0],
k_packed.k_data_packed16[a_offset + ib].qs[(iqs & 0xF) / 2 + 1]));
#else
uint vui = k_packed32.k_data_packed32[a_offset + ib].qs[(iqs & 0xF) / 4];
#endif
uint shift = (iqs & 0x10) >> 2;
vui >>= shift;
return int32_t(vui & 0x0F0F0F0F);
switch (FaTypeK) {
case FA_TYPE_Q4_0: {
uint vui = pack32(u16vec2(k_packed_q4_0.data[a_offset + ib].qs[(iqs & 0xF) / 2 + 0],
k_packed_q4_0.data[a_offset + ib].qs[(iqs & 0xF) / 2 + 1]));
uint shift = (iqs & 0x10) >> 2;
vui >>= shift;
return int32_t(vui & 0x0F0F0F0F);
}
case FA_TYPE_Q4_1: { // uses packed32 alias
uint vui = k_packed_q4_1_p32.data[a_offset + ib].qs[(iqs & 0xF) / 4];
uint shift = (iqs & 0x10) >> 2;
vui >>= shift;
return int32_t(vui & 0x0F0F0F0F);
}
case FA_TYPE_Q5_0: {
uint vui = pack32(u16vec2(k_packed_q5_0.data[a_offset + ib].qs[(iqs & 0xF) / 2 + 0],
k_packed_q5_0.data[a_offset + ib].qs[(iqs & 0xF) / 2 + 1]));
uint qh = pack32(u16vec2(k_packed_q5_0.data[a_offset + ib].qh[0],
k_packed_q5_0.data[a_offset + ib].qh[1]));
uint shift = (iqs & 0x10) >> 2;
vui >>= shift;
uint qh_bits = (qh >> iqs) & 0xF;
return int32_t(vui & 0x0F0F0F0F) | int32_t((qh_bits * 0x02040810u) & 0x10101010u);
}
case FA_TYPE_Q5_1: { // qs via packed32, qh via packed16
uint vui = k_packed_q5_1_p32.data[a_offset + ib].qs[(iqs & 0xF) / 4];
uint qh = k_packed_q5_1.data[a_offset + ib].qh;
uint shift = (iqs & 0x10) >> 2;
vui >>= shift;
uint qh_bits = (qh >> iqs) & 0xF;
return int32_t(vui & 0x0F0F0F0F) | int32_t((qh_bits * 0x02040810u) & 0x10101010u);
}
case FA_TYPE_Q8_0: {
return pack32(i16vec2(k_packed_q8_0.data[a_offset + ib].qs[iqs / 2],
k_packed_q8_0.data[a_offset + ib].qs[iqs / 2 + 1]));
}
default: return 0;
}
}
#endif
#if defined(DATA_A_Q5_0) || defined(DATA_A_Q5_1)
int32_t get_k_qs(uint ib, uint iqs, uint a_offset) {
#ifdef DATA_A_Q5_0
uint vui = pack32(u16vec2(k_packed.k_data_packed16[a_offset + ib].qs[(iqs & 0xF) / 2 + 0],
k_packed.k_data_packed16[a_offset + ib].qs[(iqs & 0xF) / 2 + 1]));
uint qh = pack32(u16vec2(k_packed.k_data_packed16[a_offset + ib].qh[0],
k_packed.k_data_packed16[a_offset + ib].qh[1]));
#else
uint vui = k_packed32.k_data_packed32[a_offset + ib].qs[(iqs & 0xF) / 4];
uint qh = k_packed.k_data_packed16[a_offset + ib].qh;
#endif
uint shift = (iqs & 0x10) >> 2;
vui >>= shift;
uint qh_bits = (qh >> iqs) & 0xF;
return int32_t(vui & 0x0F0F0F0F) | int32_t((qh_bits * 0x02040810u) & 0x10101010u);
// Per-block scale/min, packed as (d, m). Single-scale types (Q4_0, Q5_0, Q8_0)
// return (d, 0) so call sites always see the same shape.
FLOAT_TYPEV2 get_k_scale(uint ib, uint a_offset) {
switch (FaTypeK) {
case FA_TYPE_Q4_0: return FLOAT_TYPEV2(FLOAT_TYPE(k_packed_q4_0.data[a_offset + ib].d), 0.0);
case FA_TYPE_Q4_1: return FLOAT_TYPEV2(k_packed_q4_1_p32.data[a_offset + ib].dm);
case FA_TYPE_Q5_0: return FLOAT_TYPEV2(FLOAT_TYPE(k_packed_q5_0.data[a_offset + ib].d), 0.0);
case FA_TYPE_Q5_1: return FLOAT_TYPEV2(k_packed_q5_1_p32.data[a_offset + ib].dm);
case FA_TYPE_Q8_0: return FLOAT_TYPEV2(FLOAT_TYPE(k_packed_q8_0.data[a_offset + ib].d), 0.0);
default: return FLOAT_TYPEV2(0);
}
}
#endif
#if defined(DATA_A_Q8_0)
int32_t get_k_qs(uint ib, uint iqs, uint a_offset) {
return pack32(i16vec2(k_packed.k_data_packed16[a_offset + ib].qs[iqs / 2], k_packed.k_data_packed16[a_offset + ib].qs[iqs / 2 + 1]));
}
#endif
#if defined(DATA_A_IQ4_NL)
int32_t get_k_qs(uint ib, uint iqs, uint a_offset) {
uint vui = pack32(u16vec2(k_packed.k_data_packed16[a_offset + ib].qs[(iqs & 0xF) / 2 + 0],
k_packed.k_data_packed16[a_offset + ib].qs[(iqs & 0xF) / 2 + 1]));
uint shift = (iqs & 0x10) >> 2;
vui >>= shift;
u8vec4 idx = unpack8(vui & 0x0F0F0F0F);
return pack32(i8vec4(kvalues_iq4nl_const[idx.x],
kvalues_iq4nl_const[idx.y],
kvalues_iq4nl_const[idx.z],
kvalues_iq4nl_const[idx.w]));
}
#endif
#if QUANT_AUXF == 1
FLOAT_TYPE get_k_d(uint ib, uint a_offset) {
return FLOAT_TYPE(k_packed.k_data_packed16[a_offset + ib].d);
}
#else
FLOAT_TYPEV2 get_k_dm(uint ib, uint a_offset) {
return FLOAT_TYPEV2(k_packed32.k_data_packed32[a_offset + ib].dm);
}
#endif
void k_block_to_shmem(const uint buf_ib, const uint global_ib, const uint iqs, const uint a_offset) {
#if defined(DATA_A_Q4_0)
kblocksh[buf_ib].qs[iqs] = pack32(u16vec2(k_packed.k_data_packed16[a_offset + global_ib].qs[iqs * 2],
k_packed.k_data_packed16[a_offset + global_ib].qs[iqs * 2 + 1]));
#elif defined(DATA_A_Q4_1)
kblocksh[buf_ib].qs[iqs] = k_packed32.k_data_packed32[a_offset + global_ib].qs[iqs];
#elif defined(DATA_A_Q5_0)
kblocksh[buf_ib].qs[iqs] = pack32(u16vec2(k_packed.k_data_packed16[a_offset + global_ib].qs[iqs * 2],
k_packed.k_data_packed16[a_offset + global_ib].qs[iqs * 2 + 1]));
if (iqs == 0) {
kblocksh[buf_ib].qh = pack32(u16vec2(k_packed.k_data_packed16[a_offset + global_ib].qh[0],
k_packed.k_data_packed16[a_offset + global_ib].qh[1]));
// kblocksh[].qs is int32_t for the unified MMQ struct; uint sources need
// explicit casts. The bit pattern is what we care about here -- the actual
// signed/unsigned interpretation happens downstream in the dot product.
switch (FaTypeK) {
case FA_TYPE_Q4_0: {
kblocksh[buf_ib].qs[iqs] = int32_t(pack32(u16vec2(k_packed_q4_0.data[a_offset + global_ib].qs[iqs * 2],
k_packed_q4_0.data[a_offset + global_ib].qs[iqs * 2 + 1])));
break;
}
case FA_TYPE_Q4_1: {
kblocksh[buf_ib].qs[iqs] = int32_t(k_packed_q4_1_p32.data[a_offset + global_ib].qs[iqs]);
break;
}
case FA_TYPE_Q5_0: {
kblocksh[buf_ib].qs[iqs] = int32_t(pack32(u16vec2(k_packed_q5_0.data[a_offset + global_ib].qs[iqs * 2],
k_packed_q5_0.data[a_offset + global_ib].qs[iqs * 2 + 1])));
if (iqs == 0) {
kblocksh[buf_ib].qh = pack32(u16vec2(k_packed_q5_0.data[a_offset + global_ib].qh[0],
k_packed_q5_0.data[a_offset + global_ib].qh[1]));
}
break;
}
case FA_TYPE_Q5_1: {
kblocksh[buf_ib].qs[iqs] = int32_t(k_packed_q5_1_p32.data[a_offset + global_ib].qs[iqs]);
if (iqs == 0) {
kblocksh[buf_ib].qh = k_packed_q5_1.data[a_offset + global_ib].qh;
}
break;
}
case FA_TYPE_Q8_0: {
kblocksh[buf_ib].qs[iqs] = pack32(i16vec2(k_packed_q8_0.data[a_offset + global_ib].qs[iqs * 2],
k_packed_q8_0.data[a_offset + global_ib].qs[iqs * 2 + 1]));
break;
}
}
#elif defined(DATA_A_Q5_1)
kblocksh[buf_ib].qs[iqs] = k_packed32.k_data_packed32[a_offset + global_ib].qs[iqs];
if (iqs == 0) {
kblocksh[buf_ib].qh = k_packed.k_data_packed16[a_offset + global_ib].qh;
}
#elif defined(DATA_A_Q8_0)
kblocksh[buf_ib].qs[iqs] = pack32(i16vec2(k_packed.k_data_packed16[a_offset + global_ib].qs[iqs * 2],
k_packed.k_data_packed16[a_offset + global_ib].qs[iqs * 2 + 1]));
#elif defined(DATA_A_IQ4_NL)
const uint qs = pack32(u16vec2(k_packed.k_data_packed16[a_offset + global_ib].qs[iqs * 2],
k_packed.k_data_packed16[a_offset + global_ib].qs[iqs * 2 + 1]));
const u8vec4 i_a0 = unpack8( qs & 0x0F0F0F0F);
const u8vec4 i_a1 = unpack8((qs >> 4) & 0x0F0F0F0F);
kblocksh[buf_ib].qs[iqs ] = pack32(i8vec4(kvalues_iq4nl_const[i_a0.x], kvalues_iq4nl_const[i_a0.y],
kvalues_iq4nl_const[i_a0.z], kvalues_iq4nl_const[i_a0.w]));
kblocksh[buf_ib].qs[iqs + 4] = pack32(i8vec4(kvalues_iq4nl_const[i_a1.x], kvalues_iq4nl_const[i_a1.y],
kvalues_iq4nl_const[i_a1.z], kvalues_iq4nl_const[i_a1.w]));
#endif
if (iqs == 0) {
#if QUANT_AUXF == 1
kblocksh[buf_ib].dm = FLOAT_TYPE(k_packed.k_data_packed16[a_offset + global_ib].d);
#else
kblocksh[buf_ib].dm = FLOAT_TYPEV2(k_packed32.k_data_packed32[a_offset + global_ib].dm);
#endif
// Q4_0/Q5_0/Q8_0 store dm.x = d; Q4_1/Q5_1 store dm = (d, m) pair.
switch (FaTypeK) {
case FA_TYPE_Q4_0: kblocksh[buf_ib].dm = FLOAT_TYPEV2(FLOAT_TYPE(k_packed_q4_0.data[a_offset + global_ib].d), 0.0); break;
case FA_TYPE_Q4_1: kblocksh[buf_ib].dm = FLOAT_TYPEV2(k_packed_q4_1_p32.data[a_offset + global_ib].dm); break;
case FA_TYPE_Q5_0: kblocksh[buf_ib].dm = FLOAT_TYPEV2(FLOAT_TYPE(k_packed_q5_0.data[a_offset + global_ib].d), 0.0); break;
case FA_TYPE_Q5_1: kblocksh[buf_ib].dm = FLOAT_TYPEV2(k_packed_q5_1_p32.data[a_offset + global_ib].dm); break;
case FA_TYPE_Q8_0: kblocksh[buf_ib].dm = FLOAT_TYPEV2(FLOAT_TYPE(k_packed_q8_0.data[a_offset + global_ib].d), 0.0); break;
}
}
}
// d_per_step==8 hot path: read one full 32-element block worth of nibble-packed
// int32 quants. Equivalent to 8 calls to get_k_qs(ib, d*4, a_offset) but reads
// qh (Q5_*) and runs pack32 (Q4_0/Q5_0) once per block instead of per nibble
// quad. iqs is always 0 in this path (hsk4 % 8 == 0 implies block-aligned).
// Q8_0 takes the generic get_k_qs path because its qs layout (i8 pairs) doesn't
// share this nibble shape.
//
// Returned via a struct so the caller's k_quants array (sized from spec
// constants) doesn't need to match a fixed[8] out-parameter type.
struct fa_k_qs_block8 {
int32_t qs[8];
};
fa_k_qs_block8 get_k_qs_block8(uint ib, uint a_offset) {
fa_k_qs_block8 r;
uint qh = 0;
if (FaTypeK == FA_TYPE_Q5_0) {
qh = pack32(u16vec2(k_packed_q5_0.data[a_offset + ib].qh[0],
k_packed_q5_0.data[a_offset + ib].qh[1]));
} else if (FaTypeK == FA_TYPE_Q5_1) {
qh = k_packed_q5_1.data[a_offset + ib].qh;
}
const bool has_qh = (FaTypeK == FA_TYPE_Q5_0) || (FaTypeK == FA_TYPE_Q5_1);
[[unroll]] for (uint32_t d = 0; d < 4; d++) {
uint vui = 0;
switch (FaTypeK) {
case FA_TYPE_Q4_0: { // packed16
vui = pack32(u16vec2(k_packed_q4_0.data[a_offset + ib].qs[d * 2 + 0],
k_packed_q4_0.data[a_offset + ib].qs[d * 2 + 1]));
break;
}
case FA_TYPE_Q4_1: { // packed32 alias
vui = k_packed_q4_1_p32.data[a_offset + ib].qs[d];
break;
}
case FA_TYPE_Q5_0: { // packed16
vui = pack32(u16vec2(k_packed_q5_0.data[a_offset + ib].qs[d * 2 + 0],
k_packed_q5_0.data[a_offset + ib].qs[d * 2 + 1]));
break;
}
case FA_TYPE_Q5_1: { // packed32 alias
vui = k_packed_q5_1_p32.data[a_offset + ib].qs[d];
break;
}
}
r.qs[d ] = int32_t( vui & 0x0F0F0F0F);
r.qs[d + 4] = int32_t((vui >> 4) & 0x0F0F0F0F);
if (has_qh) {
uint qh_lo = (qh >> (d * 4)) & 0xFu;
uint qh_hi = (qh >> (d * 4 + 16)) & 0xFu;
r.qs[d ] |= int32_t((qh_lo * 0x02040810u) & 0x10101010u);
r.qs[d + 4] |= int32_t((qh_hi * 0x02040810u) & 0x10101010u);
}
}
return r;
}
int32_t get_k_qs_shmem(const uint buf_ib, const uint pos) {
#if defined(DATA_A_Q4_0) || defined(DATA_A_Q4_1)
uint sub = pos % 4;
uint shift = ((pos % 8) >= 4) ? 4 : 0;
return int32_t((kblocksh[buf_ib].qs[sub] >> shift) & 0x0F0F0F0F);
#elif defined(DATA_A_Q5_0) || defined(DATA_A_Q5_1)
uint sub = pos % 4;
uint shift = ((pos % 8) >= 4) ? 4 : 0;
int32_t result = int32_t((kblocksh[buf_ib].qs[sub] >> shift) & 0x0F0F0F0F);
uint qh_bits = (kblocksh[buf_ib].qh >> (pos * 4)) & 0xF;
return result | int32_t((qh_bits * 0x02040810u) & 0x10101010u);
#elif defined(DATA_A_Q8_0) || defined(DATA_A_IQ4_NL)
return kblocksh[buf_ib].qs[pos];
#endif
switch (FaTypeK) {
case FA_TYPE_Q4_0:
case FA_TYPE_Q4_1: {
uint sub = pos % 4;
uint shift = ((pos % 8) >= 4) ? 4u : 0u;
return int32_t((uint(kblocksh[buf_ib].qs[sub]) >> shift) & 0x0F0F0F0Fu);
}
case FA_TYPE_Q5_0:
case FA_TYPE_Q5_1: {
uint sub = pos % 4;
uint shift = ((pos % 8) >= 4) ? 4u : 0u;
int32_t result = int32_t((uint(kblocksh[buf_ib].qs[sub]) >> shift) & 0x0F0F0F0Fu);
uint qh_bits = (kblocksh[buf_ib].qh >> (pos * 4u)) & 0xFu;
return result | int32_t((qh_bits * 0x02040810u) & 0x10101010u);
}
case FA_TYPE_Q8_0: {
return kblocksh[buf_ib].qs[pos];
}
default: return 0;
}
}
ACC_TYPE k_dot_correction(const uint qib, const ACC_TYPEV2 k_dm) {
#if defined(DATA_A_Q4_0)
return -ACC_TYPE(8.0) * ACC_TYPE(Qf[qib].ds.y) * k_dm.x;
#elif defined(DATA_A_Q5_0)
return -ACC_TYPE(16.0) * ACC_TYPE(Qf[qib].ds.y) * k_dm.x;
#elif defined(DATA_A_Q4_1) || defined(DATA_A_Q5_1)
return ACC_TYPE(Qf[qib].ds.y) * k_dm.y;
#else
return ACC_TYPE(0.0);
#endif
switch (FaTypeK) {
case FA_TYPE_Q4_0: return -ACC_TYPE(8.0) * ACC_TYPE(Qf[qib].ds.y) * k_dm.x;
case FA_TYPE_Q5_0: return -ACC_TYPE(16.0) * ACC_TYPE(Qf[qib].ds.y) * k_dm.x;
case FA_TYPE_Q4_1:
case FA_TYPE_Q5_1: return ACC_TYPE(Qf[qib].ds.y) * k_dm.y;
default: return ACC_TYPE(0.0);
}
}
void k_block_to_shmem_zero(const uint buf_ib, const uint iqs) {
kblocksh[buf_ib].qs[iqs] = 0;
#if defined(DATA_A_IQ4_NL)
kblocksh[buf_ib].qs[iqs + 4] = 0;
#endif
if (iqs == 0) {
#if QUANT_AUXF == 1
kblocksh[buf_ib].dm = FLOAT_TYPE(0.0f);
#else
kblocksh[buf_ib].dm = FLOAT_TYPEV2(0.0f);
#endif
}
}

View File

@@ -1,4 +1,13 @@
#if defined(DATA_A_Q4_0)
#if defined(FA_MMQ_MIXED)
// Mixed-K flash attention MMQ: superset cache that fits Q4_0/Q4_1/Q5_0/Q5_1/Q8_0.
// Q4_*/Q5_* only use qs[0..3] and (for Q5_*) qh. Q8_0 uses qs[0..7]. Single-scale
// types (Q4_0/Q5_0/Q8_0) leave dm.y unused.
struct block_a_cache {
int32_t qs[8];
uint32_t qh;
FLOAT_TYPEV2 dm;
};
#elif defined(DATA_A_Q4_0)
#define QUANT_R_MMQ 2
struct block_a_cache {
uint32_t qs[16/4];

View File

@@ -643,42 +643,22 @@ void process_shaders() {
if (fp16) {
#if defined(GGML_VULKAN_COOPMAT2_GLSLC_SUPPORT)
string_to_spv("flash_attn_f32_f16_mixed", "flash_attn_cm2.comp",
string_to_spv("flash_attn_f32_f16", "flash_attn_cm2.comp",
merge_maps(fa_base_dict, {{"Q_TYPE", "float"}, {"D_TYPE", "float"}, {"D_TYPEV4", "vec4"}}), fp16, false, true, f16acc);
#endif
}
for (const auto& tname : type_names) {
if (tname == "bf16") continue;
if (fp16) {
#if defined(GGML_VULKAN_COOPMAT_GLSLC_SUPPORT)
if (tname == "f16") {
string_to_spv("flash_attn_f32_f16_" + tname, "flash_attn_cm1.comp",
merge_maps(fa_base_dict, {{"Q_TYPE", "float"}, {"D_TYPE", "float"}, {"D_TYPEV4", "vec4"}, {"COOPMAT", "1"}}), fp16, true, false, f16acc);
} else if (tname == "q4_0" || tname == "q4_1" || tname == "q5_0" || tname == "q5_1" || tname == "iq4_nl" || tname == "q8_0" || tname == "f32") {
std::string data_a_key = "DATA_A_" + to_uppercase(tname);
string_to_spv("flash_attn_f32_f16_" + tname, "flash_attn_cm1.comp",
merge_maps(fa_base_dict, {{data_a_key, "1"}, {"Q_TYPE", "float"}, {"D_TYPE", "float"}, {"D_TYPEV4", "vec4"}, {"BLOCK_SIZE", "QUANT_K_"+to_uppercase(tname)}, {"COOPMAT", "1"}}), fp16, true, false, f16acc);
}
string_to_spv("flash_attn_f32_f16", "flash_attn_cm1.comp",
merge_maps(fa_base_dict, {{"Q_TYPE", "float"}, {"D_TYPE", "float"}, {"D_TYPEV4", "vec4"}, {"COOPMAT", "1"}}), fp16, true, false, f16acc);
#endif
}
if (tname == "f16") {
string_to_spv("flash_attn_f32_f16_" + tname, "flash_attn.comp",
merge_maps(fa_base_dict, {{"Q_TYPE", "float"}, {"D_TYPE", "float"}, {"D_TYPEV4", "vec4"}}), fp16, false, false, f16acc);
} else if (tname == "q4_0" || tname == "q4_1" || tname == "q5_0" || tname == "q5_1" || tname == "iq4_nl" || tname == "q8_0" || tname == "f32") {
std::string data_a_key = "DATA_A_" + to_uppercase(tname);
string_to_spv("flash_attn_f32_f16_" + tname, "flash_attn.comp",
merge_maps(fa_base_dict, {{data_a_key, "1"}, {"Q_TYPE", "float"}, {"D_TYPE", "float"}, {"D_TYPEV4", "vec4"}, {"BLOCK_SIZE", "QUANT_K_"+to_uppercase(tname) }}), fp16, false, false, f16acc);
#if defined(GGML_VULKAN_INTEGER_DOT_GLSLC_SUPPORT)
if (tname != "f32") {
string_to_spv("flash_attn_f32_f16_" + tname, "flash_attn.comp",
merge_maps(fa_base_dict, {{data_a_key, "1"}, {"Q_TYPE", "float"}, {"D_TYPE", "float"}, {"D_TYPEV4", "vec4"}, {"BLOCK_SIZE", "QUANT_K_"+to_uppercase(tname) }, {"MMQ", "1"}}), fp16, false, false, f16acc, "_int8");
}
#endif
}
}
string_to_spv("flash_attn_f32_f16", "flash_attn.comp",
merge_maps(fa_base_dict, {{"Q_TYPE", "float"}, {"D_TYPE", "float"}, {"D_TYPEV4", "vec4"}}), fp16, false, false, f16acc);
#if defined(GGML_VULKAN_INTEGER_DOT_GLSLC_SUPPORT)
string_to_spv("flash_attn_f32_f16", "flash_attn.comp",
merge_maps(fa_base_dict, {{"Q_TYPE", "float"}, {"D_TYPE", "float"}, {"D_TYPEV4", "vec4"}, {"MMQ", "1"}, {"FA_MMQ_MIXED", "1"}}), fp16, false, false, f16acc, "_int8");
#endif
}
}

View File

@@ -858,6 +858,8 @@ extern "C" {
size_t n_token_capacity,
size_t * n_token_count_out);
#define LLAMA_STATE_SEQ_FLAGS_NONE 0
// for backwards-compat
#define LLAMA_STATE_SEQ_FLAGS_SWA_ONLY 1

View File

@@ -2475,11 +2475,29 @@ public:
}
if (need_alloc) {
mbuf_cur = std::move(mbuf);
if (!mbuf_cur.buf || mbuf_cur.total_size != mbuf.total_size) {
mbuf_cur = std::move(mbuf);
mbuf_cur.buf.reset(ggml_backend_alloc_ctx_tensors_from_buft(mbuf_cur.ctx.get(), buft));
mbuf_cur.buf.reset(ggml_backend_alloc_ctx_tensors_from_buft(mbuf_cur.ctx.get(), buft));
LLAMA_LOG_INFO("%s: allocated '%s' buffer %.3f MiB\n", __func__, ggml_backend_buft_name(buft), mbuf.total_size/1024.0/1024.0);
LLAMA_LOG_INFO("%s: allocated '%s' buffer %.3f MiB\n", __func__, ggml_backend_buft_name(buft), mbuf.total_size/1024.0/1024.0);
} else {
//LLAMA_LOG_INFO("%s: reallocating tensors in '%s' buffer %.3f MiB\n", __func__, ggml_backend_buft_name(buft), mbuf.total_size/1024.0/1024.0);
// save the old buffer and allocate the new tensors in it
auto buf = std::move(mbuf_cur.buf);
mbuf_cur = std::move(mbuf);
ggml_tallocr talloc = ggml_tallocr_new(buf.get());
for (size_t i = 0; i < mbuf_cur.org.size(); ++i) {
ggml_backend_view_init(mbuf_cur.org[i]);
ggml_tallocr_alloc(&talloc, mbuf_cur.cpy[i]);
}
mbuf_cur.buf = std::move(buf);
}
}
for (size_t i = 0; i < mbuf_cur.org.size(); ++i) {
@@ -2559,8 +2577,7 @@ public:
mbuf.org.push_back(ggml_view_1d(mbuf.ctx.get(), rinfo.tensor, n, rinfo.offset));
auto & view = mbuf.org.back();
view->buffer = rinfo.tensor->buffer;
ggml_backend_view_init(mbuf.org.back());
}
for (auto & [buft, mbuf] : mbufs_new) {

View File

@@ -195,11 +195,9 @@
| `--spec-draft-n-min N` | minimum number of draft tokens to use for speculative decoding (default: 0)<br/>(env: LLAMA_ARG_SPEC_DRAFT_N_MIN) |
| `--spec-draft-p-split, --draft-p-split P` | speculative decoding split probability (default: 0.10)<br/>(env: LLAMA_ARG_SPEC_DRAFT_P_SPLIT) |
| `--spec-draft-p-min, --draft-p-min P` | minimum speculative decoding probability (greedy) (default: 0.75)<br/>(env: LLAMA_ARG_SPEC_DRAFT_P_MIN) |
| `--spec-draft-ctx-size, -cd, --ctx-size-draft N` | size of the prompt context for the draft model (default: 0, 0 = loaded from model)<br/>(env: LLAMA_ARG_SPEC_DRAFT_CTX_SIZE) |
| `--spec-draft-device, -devd, --device-draft <dev1,dev2,..>` | comma-separated list of devices to use for offloading the draft model (none = don't offload)<br/>use --list-devices to see a list of available devices |
| `--spec-draft-ngl, -ngld, --gpu-layers-draft, --n-gpu-layers-draft N` | max. number of draft model layers to store in VRAM, either an exact number, 'auto', or 'all' (default: auto)<br/>(env: LLAMA_ARG_N_GPU_LAYERS_DRAFT) |
| `--spec-draft-model, -md, --model-draft FNAME` | draft model for speculative decoding (default: unused)<br/>(env: LLAMA_ARG_SPEC_DRAFT_MODEL) |
| `--spec-draft-replace, --spec-replace TARGET DRAFT` | translate the string in TARGET into DRAFT if the draft model and main model are not compatible |
| `--spec-type [none\|ngram-cache\|ngram-simple\|ngram-map-k\|ngram-map-k4v\|ngram-mod]` | type of speculative decoding to use when no draft model is provided (default: none)<br/><br/>(env: LLAMA_ARG_SPEC_TYPE) |
| `--spec-ngram-mod-n-min N` | minimum number of ngram tokens to use for ngram-based speculative decoding (default: 48) |
| `--spec-ngram-mod-n-max N` | maximum number of ngram tokens to use for ngram-based speculative decoding (default: 64) |

View File

@@ -244,11 +244,9 @@ For the full list of features, please refer to [server's changelog](https://gith
| `--spec-draft-n-min N` | minimum number of draft tokens to use for speculative decoding (default: 0)<br/>(env: LLAMA_ARG_SPEC_DRAFT_N_MIN) |
| `--spec-draft-p-split, --draft-p-split P` | speculative decoding split probability (default: 0.10)<br/>(env: LLAMA_ARG_SPEC_DRAFT_P_SPLIT) |
| `--spec-draft-p-min, --draft-p-min P` | minimum speculative decoding probability (greedy) (default: 0.75)<br/>(env: LLAMA_ARG_SPEC_DRAFT_P_MIN) |
| `--spec-draft-ctx-size, -cd, --ctx-size-draft N` | size of the prompt context for the draft model (default: 0, 0 = loaded from model)<br/>(env: LLAMA_ARG_SPEC_DRAFT_CTX_SIZE) |
| `--spec-draft-device, -devd, --device-draft <dev1,dev2,..>` | comma-separated list of devices to use for offloading the draft model (none = don't offload)<br/>use --list-devices to see a list of available devices |
| `--spec-draft-ngl, -ngld, --gpu-layers-draft, --n-gpu-layers-draft N` | max. number of draft model layers to store in VRAM, either an exact number, 'auto', or 'all' (default: auto)<br/>(env: LLAMA_ARG_N_GPU_LAYERS_DRAFT) |
| `--spec-draft-model, -md, --model-draft FNAME` | draft model for speculative decoding (default: unused)<br/>(env: LLAMA_ARG_SPEC_DRAFT_MODEL) |
| `--spec-draft-replace, --spec-replace TARGET DRAFT` | translate the string in TARGET into DRAFT if the draft model and main model are not compatible |
| `--spec-type [none\|ngram-cache\|ngram-simple\|ngram-map-k\|ngram-map-k4v\|ngram-mod]` | type of speculative decoding to use when no draft model is provided (default: none)<br/><br/>(env: LLAMA_ARG_SPEC_TYPE) |
| `--spec-ngram-mod-n-min N` | minimum number of ngram tokens to use for ngram-based speculative decoding (default: 48) |
| `--spec-ngram-mod-n-max N` | maximum number of ngram tokens to use for ngram-based speculative decoding (default: 64) |
@@ -1045,16 +1043,23 @@ If query param `?fail_on_no_slot=1` is set, this endpoint will respond with stat
This endpoint is only accessible if `--metrics` is set.
Available metrics:
- `llamacpp:prompt_tokens_total`: Number of prompt tokens processed.
- `llamacpp:tokens_predicted_total`: Number of generation tokens processed.
- `llamacpp:prompt_tokens_seconds`: Average prompt throughput in tokens/s.
- `llamacpp:predicted_tokens_seconds`: Average generation throughput in tokens/s.
- `llamacpp:kv_cache_usage_ratio`: KV-cache usage. `1` means 100 percent usage.
- `llamacpp:kv_cache_tokens`: KV-cache tokens.
- `llamacpp:requests_processing`: Number of requests processing.
- `llamacpp:requests_deferred`: Number of requests deferred.
- `llamacpp:n_tokens_max`: High watermark of the context size observed.
In *router mode* the query param `?model={model_id}` has to be set. This endpoint will respond with status code 400 `model name is missing from the request` if not set.
#### Available metrics
| Metric | Type | Description |
| ------ | ---------------------- | ----------- |
| `llamacpp:prompt_tokens_total` | Counter | Number of prompt tokens processed. |
| `llamacpp:prompt_seconds_total` | Counter | Prompt process time in seconds. |
| `llamacpp:prompt_tokens_seconds` | Gauge | Average prompt throughput in tokens/s. |
| `llamacpp:tokens_predicted_total` | Counter | Number of generation tokens processed. |
| `llamacpp:tokens_predicted_seconds_total` | Counter | Predict process time in seconds. |
| `llamacpp:predicted_tokens_seconds` | Gauge | Average generation throughput in tokens/s. |
| `llamacpp:requests_processing` | Gauge | Number of requests processing. |
| `llamacpp:requests_deferred` | Gauge | Number of requests deferred. |
| `llamacpp:n_tokens_max` | Counter | High watermark of the context size observed. |
| `llamacpp:n_decode_total` | Counter | Total Number of llama_decode() calls. |
| `llamacpp:n_busy_slots_per_decode` | Gauge | Average number of busy slots per llama_decode() call. |
### POST `/slots/{id_slot}?action=save`: Save the prompt cache of the specified slot to a file.

File diff suppressed because it is too large Load Diff

View File

@@ -76,7 +76,7 @@ json task_params::to_json(bool only_metrics) const {
{"reasoning_in_content", chat_parser_params.reasoning_in_content},
{"generation_prompt", chat_parser_params.generation_prompt},
{"samplers", samplers},
{"speculative.type", common_speculative_type_to_str(speculative.type)},
{"speculative.types", common_speculative_type_name_str(speculative.types)},
{"timings_per_token", timings_per_token},
{"post_sampling_probs", post_sampling_probs},
{"backend_sampling", sampling.backend_sampling},
@@ -133,7 +133,7 @@ json task_params::to_json(bool only_metrics) const {
{"reasoning_in_content", chat_parser_params.reasoning_in_content},
{"generation_prompt", chat_parser_params.generation_prompt},
{"samplers", samplers},
{"speculative.type", common_speculative_type_to_str(speculative.type)},
{"speculative.types", common_speculative_type_name_str(speculative.types)},
{"timings_per_token", timings_per_token},
{"post_sampling_probs", post_sampling_probs},
{"backend_sampling", sampling.backend_sampling},
@@ -296,6 +296,8 @@ task_params server_task::params_from_json_cmpl(
params.speculative = defaults.speculative;
// TODO: to keep things simple, we disable speculative parameter adjustments for now
#if 0
// TODO: for now, be able to adjust only the draft-model based speculative parameters
params.speculative.draft.n_min = json_value(data, "speculative.n_min", defaults.speculative.draft.n_min);
params.speculative.draft.n_max = json_value(data, "speculative.n_max", defaults.speculative.draft.n_max);
@@ -305,7 +307,6 @@ task_params server_task::params_from_json_cmpl(
params.speculative.draft.n_min = std::max(params.speculative.draft.n_min, 0);
params.speculative.draft.n_max = std::max(params.speculative.draft.n_max, 0);
#if 0
// for debugging and research purposes
params.speculative.type = common_speculative_type_from_name(json_value(data, "speculative.type", common_speculative_type_to_str(defaults.speculative.type)));
@@ -1981,7 +1982,7 @@ size_t server_prompt_cache::n_tokens() const {
return res;
}
server_prompt * server_prompt_cache::alloc(const server_prompt & prompt, size_t state_size) {
server_prompt * server_prompt_cache::alloc(const server_prompt & prompt, size_t state_size_tgt, size_t state_size_dft) {
// first check if the current state is contained fully in the cache
for (auto it = states.begin(); it != states.end(); ++it) {
const int cur_lcp_len = it->tokens.get_common_prefix(prompt.tokens);
@@ -2005,11 +2006,13 @@ server_prompt * server_prompt_cache::alloc(const server_prompt & prompt, size_t
}
}
std::vector<uint8_t> state_data;
std::vector<uint8_t> state_data_tgt;
std::vector<uint8_t> state_data_dft;
// check if we can allocate enough memory for the new state
try {
state_data.resize(state_size);
state_data_tgt.resize(state_size_tgt);
state_data_dft.resize(state_size_dft);
} catch (const std::bad_alloc & e) {
SRV_ERR("failed to allocate memory for prompt cache state: %s\n", e.what());
@@ -2022,17 +2025,19 @@ server_prompt * server_prompt_cache::alloc(const server_prompt & prompt, size_t
return nullptr;
}
auto & cur = states.emplace_back();
cur = {
states.push_back({
/*.tokens =*/ prompt.tokens.clone(),
/*.data =*/ std::move(state_data),
/*.data =*/ {
/*.main =*/ std::move(state_data_tgt),
/*.drft =*/ std::move(state_data_dft),
},
/*.checkpoints =*/ prompt.checkpoints,
};
});
return &cur;
return &states.back();
}
bool server_prompt_cache::load(server_prompt & prompt, const server_tokens & tokens_new, llama_context * ctx, int32_t id_slot) {
bool server_prompt_cache::load(server_prompt & prompt, const server_tokens & tokens_new, llama_context * ctx_tgt, llama_context * ctx_dft, int32_t id_slot) {
const int lcp_best = prompt.tokens.get_common_prefix(tokens_new);
float f_keep_best = prompt.tokens.size() > 0 ? float(lcp_best) / prompt.tokens.size() : -1.0f; // empty slot: any cache entry wins
@@ -2065,16 +2070,39 @@ bool server_prompt_cache::load(server_prompt & prompt, const server_tokens & tok
if (it_best != states.end()) {
SRV_WRN(" - found better prompt with f_keep = %.3f, sim = %.3f\n", f_keep_best, sim_best);
const size_t size = it_best->data.size();
const size_t n = llama_state_seq_set_data_ext(ctx, it_best->data.data(), size, id_slot, 0);
if (n != size) {
SRV_WRN("failed to restore state with size %zu\n", size);
{
auto & data = it_best->data.main;
return false;
const size_t size = data.size();
const size_t n = llama_state_seq_set_data_ext(ctx_tgt, data.data(), size, id_slot, 0);
if (n != size) {
SRV_WRN("failed to restore state with size %zu\n", size);
return false;
}
data.clear();
data.shrink_to_fit();
}
it_best->data.clear();
it_best->data.shrink_to_fit();
{
auto & data = it_best->data.drft;
if (!data.empty()) {
GGML_ASSERT(ctx_dft);
const size_t size = data.size();
const size_t n = llama_state_seq_set_data_ext(ctx_dft, data.data(), size, id_slot, 0);
if (n != size) {
SRV_WRN("failed to restore state with size %zu\n", size);
return false;
}
data.clear();
data.shrink_to_fit();
}
}
prompt = std::move(*it_best);

View File

@@ -565,42 +565,29 @@ struct server_task_result_apply_lora : server_task_result {
virtual json to_json() override;
};
struct server_prompt_checkpoint {
llama_pos pos_min;
llama_pos pos_max;
int64_t n_tokens;
std::vector<uint8_t> data;
struct server_prompt_data {
std::vector<uint8_t> main;
std::vector<uint8_t> drft;
size_t size() const {
return data.size();
}
bool empty() const {
return data.empty();
}
void clear() {
pos_min = 0;
pos_max = 0;
n_tokens = 0;
data.clear();
return main.size() + drft.size();
}
};
struct server_prompt {
server_tokens tokens;
std::vector<uint8_t> data;
server_prompt_data data;
std::list<server_prompt_checkpoint> checkpoints;
std::list<common_prompt_checkpoint> checkpoints;
size_t size() const {
size_t res = data.size();
size_t res = 0;
for (const auto & checkpoint : checkpoints) {
res += checkpoint.size();
res += data.size();
for (const auto & ckpt : checkpoints) {
res += ckpt.size();
}
return res;
@@ -614,7 +601,7 @@ struct server_prompt {
return server_prompt {
tokens.clone(),
data,
checkpoints
checkpoints,
};
}
};
@@ -637,9 +624,9 @@ struct server_prompt_cache {
size_t n_tokens() const;
server_prompt * alloc(const server_prompt & prompt, size_t state_size);
server_prompt * alloc(const server_prompt & prompt, size_t state_size_main, size_t state_size_drft);
bool load(server_prompt & prompt, const server_tokens & tokens_new, llama_context * ctx, int32_t id_slot);
bool load(server_prompt & prompt, const server_tokens & tokens_new, llama_context * ctx_main, llama_context * ctx_drft, int32_t id_slot);
void update();
};

View File

@@ -5,7 +5,7 @@ from utils import *
server = ServerPreset.stories15m_moe()
MODEL_DRAFT_FILE_URL = "https://huggingface.co/ggml-org/models/resolve/main/tinyllamas/stories15M-q4_0.gguf"
MODEL_DRAFT_FILE_URL = "https://huggingface.co/ggml-org/tiny-llamas/resolve/main/stories15M-q4_0.gguf"
def create_server():
global server