mirror of https://github.com/ggml-org/llama.cpp.git
synced 2026-04-24 20:09:44 +02:00

Compare commits: 5 commits

| Author | SHA1 | Date |
|---|---|---|
|  | 13d36cf891 |  |
|  | f65bc34c68 |  |
|  | 15fa3c493b |  |
|  | dc80c5252a |  |
|  | e583f3b4f5 |  |
@@ -106,10 +106,16 @@ struct statement {
     size_t pos; // position in source, for debugging
     virtual ~statement() = default;
     virtual std::string type() const { return "Statement"; }

     // execute_impl must be overridden by derived classes
-    virtual value execute_impl(context &) { throw std::runtime_error("cannot exec " + type()); }
+    virtual value execute_impl(context &) { throw_exec_error(); }
     // execute is the public method to execute a statement with error handling
     value execute(context &);

+  private:
+    [[noreturn]] void throw_exec_error() const {
+        throw std::runtime_error("cannot exec " + type());
+    }
 };

 // Type Checking Utilities
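The hunk above hoists the throw into a single private `[[noreturn]]` helper, so each stub body shrinks to one call instead of inlining string concatenation and throw machinery at every site. A minimal sketch of the pattern, with simplified hypothetical names rather than the actual jinja types:

```cpp
#include <stdexcept>
#include <string>

// Sketch: a private [[noreturn]] helper centralizes the throw.
struct node {
    virtual ~node() = default;
    virtual std::string type() const { return "Node"; }

    // Stub that derived classes are expected to override.
    virtual int eval() { throw_eval_error(); }

  private:
    [[noreturn]] void throw_eval_error() const {
        // [[noreturn]] tells the compiler control never returns, so the
        // non-void eval() above needs no dummy return after the call.
        throw std::runtime_error("cannot eval " + type());
    }
};
```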
@@ -143,7 +149,7 @@ struct program : public statement {
     program() = default;
     explicit program(statements && body) : body(std::move(body)) {}
     std::string type() const override { return "Program"; }
-    value execute_impl(context &) override {
+    [[noreturn]] value execute_impl(context &) override {
         throw std::runtime_error("Cannot execute program directly, use jinja::runtime instead");
     }
 };
@@ -195,7 +201,7 @@ struct break_statement : public statement {
         }
     };

-    value execute_impl(context &) override {
+    [[noreturn]] value execute_impl(context &) override {
         throw break_statement::signal();
     }
 };
@@ -209,7 +215,7 @@ struct continue_statement : public statement {
         }
     };

-    value execute_impl(context &) override {
+    [[noreturn]] value execute_impl(context &) override {
         throw continue_statement::signal();
     }
 };
@@ -509,7 +515,7 @@ struct slice_expression : public expression {
         chk_type<expression>(this->step_expr);
     }
     std::string type() const override { return "SliceExpression"; }
-    value execute_impl(context &) override {
+    [[noreturn]] value execute_impl(context &) override {
        throw std::runtime_error("must be handled by MemberExpression");
    }
 };
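The four hunks above all attach `[[noreturn]]` to always-throwing virtual overrides. A small sketch, under simplified stand-in types, of what the attribute expresses:

```cpp
#include <stdexcept>

// Sketch of marking an always-throwing override [[noreturn]], as done
// above for program, break_statement, continue_statement and
// slice_expression. Simplified types, not the actual jinja classes.
struct expr {
    virtual ~expr() = default;
    virtual int eval() { return 0; }
};

struct unsupported_expr : expr {
    // The attribute is a property of this function, not of expr::eval:
    // calls through an expr* still assume a value may come back, but a
    // direct call to unsupported_expr::eval is known never to return.
    [[noreturn]] int eval() override {
        throw std::runtime_error("must be handled elsewhere");
    }
};
```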
@@ -590,6 +590,10 @@ static bool string_endswith(const std::string & str, const std::string & suffix)
     return str.compare(str.length() - suffix.length(), suffix.length(), suffix) == 0;
 }

+[[noreturn]] static value string_join_not_implemented(const func_args &) {
+    throw not_implemented_exception("String join builtin not implemented");
+}
+
 const func_builtins & value_string_t::get_builtins() const {
     static const func_builtins builtins = {
         {"default", default_value},
@@ -851,9 +855,7 @@ const func_builtins & value_string_t::get_builtins() const {
             res->val_str.mark_input_based_on(val_input->as_string());
             return res;
         }},
-        {"join", [](const func_args &) -> value {
-            throw not_implemented_exception("String join builtin not implemented");
-        }},
+        {"join", string_join_not_implemented},
     };
     return builtins;
 }
@@ -884,6 +886,9 @@ const func_builtins & value_bool_t::get_builtins() const {
     return builtins;
 }

+[[noreturn]] static value array_unique_not_implemented(const func_args &) {
+    throw not_implemented_exception("Array unique builtin not implemented");
+}
+
 const func_builtins & value_array_t::get_builtins() const {
     static const func_builtins builtins = {
@@ -1084,13 +1089,14 @@ const func_builtins & value_array_t::get_builtins() const {
             std::reverse(arr.begin(), arr.end());
             return is_val<value_tuple>(val) ? mk_val<value_tuple>(std::move(arr)) : mk_val<value_array>(std::move(arr));
         }},
-        {"unique", [](const func_args &) -> value {
-            throw not_implemented_exception("Array unique builtin not implemented");
-        }},
+        {"unique", array_unique_not_implemented},
     };
     return builtins;
 }

+[[noreturn]] static value object_join_not_implemented(const func_args &) {
+    throw not_implemented_exception("object join not implemented");
+}
+
 const func_builtins & value_object_t::get_builtins() const {
     if (!has_builtins) {
@@ -1183,9 +1189,7 @@ const func_builtins & value_object_t::get_builtins() const {
             });
             return result;
         }},
-        {"join", [](const func_args &) -> value {
-            throw not_implemented_exception("object join not implemented");
-        }},
+        {"join", object_join_not_implemented},
     };
     return builtins;
 }
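These three hunks follow the same recipe: a named `[[noreturn]]` static function replaces a per-entry lambda in the builtins table. A rough sketch of the shape, with stand-in types (the real `value`, `func_args` and `not_implemented_exception` differ):

```cpp
#include <stdexcept>
#include <string>
#include <unordered_map>

// Stand-in types for illustration only.
struct func_args {};
using value      = int;
using builtin_fn = value (*)(const func_args &);

// A named static function can carry [[noreturn]] cleanly, and one
// definition is shared instead of generating a fresh closure type per
// table entry ([[noreturn]] is not part of the function-pointer type,
// so it still fits builtin_fn).
[[noreturn]] static value join_not_implemented(const func_args &) {
    throw std::runtime_error("join builtin not implemented");
}

static const std::unordered_map<std::string, builtin_fn> builtins = {
    { "join", join_not_implemented },
};
```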
@@ -129,27 +129,25 @@ struct value_t {
     // Note: only for debugging and error reporting purposes
     virtual std::string type() const { return ""; }

-    virtual int64_t as_int() const { throw std::runtime_error(type() + " is not an int value"); }
-    virtual double as_float() const { throw std::runtime_error(type() + " is not a float value"); }
-    virtual string as_string() const { throw std::runtime_error(type() + " is not a string value"); }
-    virtual bool as_bool() const { throw std::runtime_error(type() + " is not a bool value"); }
-    virtual const std::vector<value> & as_array() const { throw std::runtime_error(type() + " is not an array value"); }
-    virtual const std::vector<std::pair<value, value>> & as_ordered_object() const { throw std::runtime_error(type() + " is not an object value"); }
-    virtual value invoke(const func_args &) const { throw std::runtime_error(type() + " is not a function value"); }
+    virtual int64_t as_int() const { throw_type_error("is not an int value"); }
+    virtual double as_float() const { throw_type_error("is not a float value"); }
+    virtual string as_string() const { throw_type_error("is not a string value"); }
+    virtual bool as_bool() const { throw_type_error("is not a bool value"); }
+    virtual const std::vector<value> & as_array() const { throw_type_error("is not an array value"); }
+    virtual const std::vector<std::pair<value, value>> & as_ordered_object() const { throw_type_error("is not an object value"); }
+    virtual value invoke(const func_args &) const { throw_type_error("is not a function value"); }
     virtual bool is_none() const { return false; }
     virtual bool is_undefined() const { return false; }
-    virtual const func_builtins & get_builtins() const {
-        throw std::runtime_error("No builtins available for type " + type());
-    }
+    virtual const func_builtins & get_builtins() const { throw_type_error("has no builtins"); }

-    virtual bool has_key(const value &) { throw std::runtime_error(type() + " is not an object value"); }
-    virtual void insert(const value & /* key */, const value & /* val */) { throw std::runtime_error(type() + " is not an object value"); }
-    virtual value & at(const value & /* key */, value & /* default_val */) { throw std::runtime_error(type() + " is not an object value"); }
-    virtual value & at(const value & /* key */) { throw std::runtime_error(type() + " is not an object value"); }
-    virtual value & at(const std::string & /* key */, value & /* default_val */) { throw std::runtime_error(type() + " is not an object value"); }
-    virtual value & at(const std::string & /* key */) { throw std::runtime_error(type() + " is not an object value"); }
-    virtual value & at(int64_t /* idx */, value & /* default_val */) { throw std::runtime_error(type() + " is not an array value"); }
-    virtual value & at(int64_t /* idx */) { throw std::runtime_error(type() + " is not an array value"); }
+    virtual bool has_key(const value &) { throw_type_error("is not an object value"); }
+    virtual void insert(const value & /* key */, const value & /* val */) { throw_type_error("is not an object value"); }
+    virtual value & at(const value & /* key */, value & /* default_val */) { throw_type_error("is not an object value"); }
+    virtual value & at(const value & /* key */) { throw_type_error("is not an object value"); }
+    virtual value & at(const std::string & /* key */, value & /* default_val */) { throw_type_error("is not an object value"); }
+    virtual value & at(const std::string & /* key */) { throw_type_error("is not an object value"); }
+    virtual value & at(int64_t /* idx */, value & /* default_val */) { throw_type_error("is not an array value"); }
+    virtual value & at(int64_t /* idx */) { throw_type_error("is not an array value"); }

     virtual bool is_numeric() const { return false; }
     virtual bool is_hashable() const { return false; }
@@ -163,6 +161,11 @@ struct value_t {
     // Note: only for debugging purposes
     virtual std::string as_repr() const { return as_string().str(); }

+  private:
+    [[noreturn]] void throw_type_error(const char* expected) const {
+        throw std::runtime_error(type() + " " + expected);
+    }
+
   protected:
     virtual bool equivalent(const value_t &) const = 0;
     virtual bool nonequal(const value_t & other) const { return !equivalent(other); }
@@ -8,7 +8,7 @@ CatalogFile = libggml-htp.cat
 PnpLockDown = 1

 [DestinationDirs]
-Drivers_Dir = 6
+Drivers_Dir = 13

 [SourceDisksNames]
 1 = %DiskId%
@@ -814,7 +814,7 @@ ggml_metal_device_t ggml_metal_device_init(int device) {
     }

     // print MTL GPU family:
-    GGML_LOG_INFO("%s: GPU name: %s\n", __func__, dev->props.name);
+    GGML_LOG_INFO("%s: GPU name: %s (%s)\n", __func__, dev->props.name, dev->props.desc);

     // determine max supported GPU family
     // https://developer.apple.com/metal/Metal-Shading-Language-Specification.pdf
@@ -436,19 +436,27 @@ struct ggml_webgpu_unary_pipeline_key_hash {

 /** FlashAttention */

+enum ggml_webgpu_flash_attn_path : uint32_t {
+    GGML_WEBGPU_FLASH_ATTN_PATH_SUBGROUP_MATRIX = 0u,
+    GGML_WEBGPU_FLASH_ATTN_PATH_TILE            = 1u,
+    GGML_WEBGPU_FLASH_ATTN_PATH_VEC             = 2u,
+};
+
 struct ggml_webgpu_flash_attn_pipeline_key {
     ggml_type kv_type;
     uint32_t  head_dim_qk;
     uint32_t  head_dim_v;
     bool      kv_direct;
+    bool      kv_overlap;
     bool      has_mask;
     bool      has_sinks;
     bool      uses_logit_softcap;
+    uint32_t  path;

     bool operator==(const ggml_webgpu_flash_attn_pipeline_key & other) const {
         return kv_type == other.kv_type && head_dim_qk == other.head_dim_qk && head_dim_v == other.head_dim_v &&
-               kv_direct == other.kv_direct && has_mask == other.has_mask && has_sinks == other.has_sinks &&
-               uses_logit_softcap == other.uses_logit_softcap;
+               kv_direct == other.kv_direct && kv_overlap == other.kv_overlap && has_mask == other.has_mask &&
+               has_sinks == other.has_sinks && uses_logit_softcap == other.uses_logit_softcap && path == other.path;
     }
 };
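The key grows two fields, `kv_overlap` and `path`, and both `operator==` and the hash functor in the next hunk change in lockstep; updating one without the other would let distinct specializations collide in the pipeline cache. For reference, a boost-style combine like the one `ggml_webgpu_hash_combine` appears to implement (an assumption; the real helper may differ in detail):

```cpp
#include <cstddef>
#include <functional>

// Boost-style hash combine: mixes each field's hash into a running seed.
template <typename T>
inline void hash_combine(std::size_t & seed, const T & v) {
    seed ^= std::hash<T>{}(v) + 0x9e3779b9 + (seed << 6) + (seed >> 2);
}
```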
@@ -459,39 +467,70 @@ struct ggml_webgpu_flash_attn_pipeline_key_hash {
         ggml_webgpu_hash_combine(seed, key.head_dim_qk);
         ggml_webgpu_hash_combine(seed, key.head_dim_v);
         ggml_webgpu_hash_combine(seed, key.kv_direct);
+        ggml_webgpu_hash_combine(seed, key.kv_overlap);
         ggml_webgpu_hash_combine(seed, key.has_mask);
         ggml_webgpu_hash_combine(seed, key.has_sinks);
         ggml_webgpu_hash_combine(seed, key.uses_logit_softcap);
+        ggml_webgpu_hash_combine(seed, key.path);
         return seed;
     }
 };

 struct ggml_webgpu_flash_attn_decisions {
-    uint32_t q_tile = 0;
-    uint32_t kv_tile = 0;
-    uint32_t wg_size = 0;
+    uint32_t path = GGML_WEBGPU_FLASH_ATTN_PATH_SUBGROUP_MATRIX;
+    uint32_t q_tile = 0;
+    uint32_t kv_tile = 0;
+    uint32_t wg_size = 0;
+    bool kv_direct = false;
 };

-struct ggml_webgpu_flash_attn_vec_decisions {
-    uint32_t kv_tile = 0;
-    uint32_t wg_size = 0;
-};
+inline constexpr uint32_t GGML_WEBGPU_FLASH_ATTN_TILE_KV_VEC_WIDTH = 4u;
+inline constexpr uint32_t GGML_WEBGPU_FLASH_ATTN_TILE_Q_TILE = 4u;

+inline uint32_t ggml_webgpu_flash_attn_pick_vec_ne(const ggml_webgpu_flash_attn_pipeline_key & key) {
+    if (key.path != GGML_WEBGPU_FLASH_ATTN_PATH_VEC || key.kv_type != GGML_TYPE_F16 ||
+        key.head_dim_qk != key.head_dim_v) {
+        return 1u;
+    }
+
+    switch (key.head_dim_qk) {
+        case 64:
+        case 192:
+        case 576:
+            return 2u;
+        case 96:
+            return 4u;
+        default:
+            return 1u;
+    }
+}
+
 inline ggml_webgpu_flash_attn_pipeline_key ggml_webgpu_flash_attn_make_pipeline_key(
-    const ggml_webgpu_shader_lib_context & context) {
+    const ggml_webgpu_shader_lib_context & context,
+    uint32_t path) {
     const bool has_mask = context.src3 != nullptr;
     const bool has_sinks = context.src4 != nullptr;
-    const bool kv_direct = (context.src1->type == GGML_TYPE_F16) && (context.src0->ne[0] % context.sg_mat_k == 0) &&
-                           (context.src1->ne[1] % GGML_WEBGPU_KV_SEQ_PAD == 0);
+    bool kv_direct = false;
+    if (path != GGML_WEBGPU_FLASH_ATTN_PATH_TILE) {
+        uint32_t kv_direct_align = GGML_WEBGPU_FLASH_ATTN_TILE_KV_VEC_WIDTH;
+        if (path == GGML_WEBGPU_FLASH_ATTN_PATH_SUBGROUP_MATRIX) {
+            kv_direct_align = context.sg_mat_k;
+        }
+        kv_direct = (context.src1->type == GGML_TYPE_F16) &&
+                    (context.src0->ne[0] % std::max(1u, kv_direct_align) == 0) &&
+                    (context.src1->ne[1] % GGML_WEBGPU_KV_SEQ_PAD == 0);
+    }

     ggml_webgpu_flash_attn_pipeline_key key = {};
     key.kv_type = context.src1->type;
     key.head_dim_qk = (uint32_t) context.src0->ne[0];
     key.head_dim_v = (uint32_t) context.src2->ne[0];
     key.kv_direct = kv_direct;
+    key.kv_overlap = context.src_overlap;
     key.has_mask = has_mask;
     key.has_sinks = has_sinks;
     key.uses_logit_softcap = ggml_get_op_params_f32(context.dst, 2) != 0.0f;
+    key.path = path;
     return key;
 }
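`ggml_webgpu_flash_attn_pick_vec_ne` only widens per-thread loads when the head dimension divides evenly by the chosen vector width. A quick spot-check of the table, assuming the shader-lib header above is in scope (values chosen to exercise each branch; expected results in comments):

```cpp
// Hypothetical spot-check, not part of the commit:
ggml_webgpu_flash_attn_pipeline_key key = {};
key.path    = GGML_WEBGPU_FLASH_ATTN_PATH_VEC;
key.kv_type = GGML_TYPE_F16;
key.head_dim_v = key.head_dim_qk = 96;  // 96 % 4 == 0 -> vec4 loads
// ggml_webgpu_flash_attn_pick_vec_ne(key) == 4u
key.head_dim_v = key.head_dim_qk = 64;  // 64/192/576 family -> vec2 loads
// ggml_webgpu_flash_attn_pick_vec_ne(key) == 2u
key.head_dim_v = key.head_dim_qk = 80;  // not in the table -> scalar
// ggml_webgpu_flash_attn_pick_vec_ne(key) == 1u
```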
@@ -554,8 +593,16 @@ inline size_t ggml_webgpu_flash_attn_wg_mem_bytes(uint32_t q_tile,

 inline uint32_t ggml_webgpu_flash_attn_max_kv_tile(const ggml_webgpu_shader_lib_context & context,
                                                    const ggml_webgpu_flash_attn_pipeline_key & key) {
-    const size_t limit_bytes = context.wg_mem_limit_bytes;
-    const size_t q_tile = context.sg_mat_m;
+    const size_t limit_bytes = context.wg_mem_limit_bytes;
+    uint32_t q_tile = context.sg_mat_m;
+    uint32_t kv_granularity = context.sg_mat_n;
+    if (key.path == GGML_WEBGPU_FLASH_ATTN_PATH_TILE) {
+        q_tile = GGML_WEBGPU_FLASH_ATTN_TILE_Q_TILE;
+        kv_granularity = std::max(1u, context.max_subgroup_size);
+    } else if (key.path == GGML_WEBGPU_FLASH_ATTN_PATH_VEC) {
+        q_tile = 1u;
+        kv_granularity = 8u;
+    }
     const size_t base_q_bytes = (key.head_dim_qk + key.head_dim_v) * q_tile * GGML_WEBGPU_F16_SIZE_BYTES +
                                 2 * q_tile * GGML_WEBGPU_F32_SIZE_BYTES;
     size_t bytes_per_kv = 0;
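The budget works out as `max_kv_tile = ((limit_bytes - base_q_bytes) / bytes_per_kv / kv_granularity) * kv_granularity`: a fixed Q-side cost plus a per-KV-row cost must fit the workgroup storage limit, rounded down to the path's granularity. For illustration only (made-up but plausible numbers): with head_dim_qk = head_dim_v = 128 and q_tile = 4 on the tile path, base_q_bytes = (128 + 128) * 4 * 2 + 2 * 4 * 4 = 2080 bytes; a 16384-byte limit and, say, 260 bytes per KV row then allow floor(14304 / 260) = 55 rows, which a granularity of 32 rounds down to 32.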
@@ -568,23 +615,90 @@ inline uint32_t ggml_webgpu_flash_attn_max_kv_tile(const ggml_webgpu_shader_lib_
     bytes_per_kv += q_tile;
     bytes_per_kv *= GGML_WEBGPU_F16_SIZE_BYTES;
     const uint32_t max_kv_tile = (limit_bytes - base_q_bytes) / bytes_per_kv;
-    return (max_kv_tile / context.sg_mat_n) * context.sg_mat_n;
+    return (max_kv_tile / kv_granularity) * kv_granularity;
 }

-inline uint32_t ggml_webgpu_flash_attn_vec_get_kv_tile(const ggml_webgpu_shader_lib_context & context) {
-    const ggml_webgpu_flash_attn_pipeline_key key = ggml_webgpu_flash_attn_make_pipeline_key(context);
-    const uint32_t min_kv_tile = ggml_webgpu_flash_attn_max_kv_tile(context, key);
-    uint32_t kv_tile = std::max(context.sg_mat_n, std::min(32u, min_kv_tile));
-    kv_tile = (kv_tile / context.sg_mat_n) * context.sg_mat_n;
-    if (key.kv_direct) {
-        kv_tile = std::min(kv_tile, GGML_WEBGPU_KV_SEQ_PAD);
-        while (GGML_WEBGPU_KV_SEQ_PAD % kv_tile != 0) {
-            kv_tile -= context.sg_mat_n;
-        }
-    }
-    return kv_tile;
-}
+inline ggml_webgpu_flash_attn_decisions ggml_webgpu_flash_attn_get_decisions(
+    const ggml_webgpu_shader_lib_context & context,
+    size_t storage_offset_alignment) {
+    ggml_webgpu_flash_attn_decisions decisions = {};
+    const size_t alignment = std::max<size_t>(1u, storage_offset_alignment);
+    const auto * K = context.src1;
+    const auto * V = context.src2;
+    GGML_ASSERT(K != nullptr);
+    GGML_ASSERT(V != nullptr);
+
+    const auto flash_attn_tensor_offset = [](const ggml_tensor * tensor) -> size_t {
+        constexpr uintptr_t ptr_base_addr = 0x1000u;
+        const ggml_tensor * base = tensor->view_src != nullptr ? tensor->view_src : tensor;
+        return reinterpret_cast<uintptr_t>(base->data) - ptr_base_addr + tensor->view_offs;
+    };
+
+    const uint32_t k_offset_elems =
+        (uint32_t) ((flash_attn_tensor_offset(K) & (alignment - 1)) / ggml_type_size(K->type));
+    const uint32_t v_offset_elems =
+        (uint32_t) ((flash_attn_tensor_offset(V) & (alignment - 1)) / ggml_type_size(V->type));
+    const bool f16_vec4_aligned = (k_offset_elems % GGML_WEBGPU_FLASH_ATTN_TILE_KV_VEC_WIDTH == 0u) &&
+                                  (v_offset_elems % GGML_WEBGPU_FLASH_ATTN_TILE_KV_VEC_WIDTH == 0u);
+    const bool kv_vec_type_supported =
+        K->type == GGML_TYPE_F16 || K->type == GGML_TYPE_Q4_0 || K->type == GGML_TYPE_Q8_0;
+    const bool use_vec = context.supports_subgroups && (context.src0->ne[1] < 20) && (context.src0->ne[0] % 32 == 0) &&
+                         (context.src2->ne[0] % GGML_WEBGPU_FLASH_ATTN_TILE_KV_VEC_WIDTH == 0) &&
+                         kv_vec_type_supported && (K->type != GGML_TYPE_F16 || f16_vec4_aligned) &&
+                         (context.src2->type == K->type);
+    const bool use_tile = context.supports_subgroups && !context.supports_subgroup_matrix && K->type == GGML_TYPE_F16 &&
+                          V->type == GGML_TYPE_F16 && f16_vec4_aligned &&
+                          (context.src0->ne[0] % GGML_WEBGPU_FLASH_ATTN_TILE_KV_VEC_WIDTH == 0) &&
+                          (context.src2->ne[0] % GGML_WEBGPU_FLASH_ATTN_TILE_KV_VEC_WIDTH == 0) && !use_vec;
+
+    decisions.path = use_vec  ? GGML_WEBGPU_FLASH_ATTN_PATH_VEC :
+                     use_tile ? GGML_WEBGPU_FLASH_ATTN_PATH_TILE :
+                                GGML_WEBGPU_FLASH_ATTN_PATH_SUBGROUP_MATRIX;
+
+    const ggml_webgpu_flash_attn_pipeline_key key = ggml_webgpu_flash_attn_make_pipeline_key(context, decisions.path);
+    decisions.kv_direct = key.kv_direct;
+
+    if (decisions.path == GGML_WEBGPU_FLASH_ATTN_PATH_VEC) {
+        const uint32_t min_kv_tile = ggml_webgpu_flash_attn_max_kv_tile(context, key);
+        decisions.q_tile = 1u;
+        decisions.kv_tile = std::max(8u, std::min(32u, min_kv_tile));
+        decisions.kv_tile = (decisions.kv_tile / 8u) * 8u;
+        decisions.wg_size = std::max(1u, std::min<uint32_t>(32u, context.max_subgroup_size));
+        if (decisions.kv_direct) {
+            decisions.kv_tile = std::min(decisions.kv_tile, GGML_WEBGPU_KV_SEQ_PAD);
+            while (GGML_WEBGPU_KV_SEQ_PAD % decisions.kv_tile != 0) {
+                decisions.kv_tile -= 8u;
+            }
+        }
+        return decisions;
+    }
+
+    decisions.q_tile =
+        decisions.path == GGML_WEBGPU_FLASH_ATTN_PATH_TILE ? GGML_WEBGPU_FLASH_ATTN_TILE_Q_TILE : context.sg_mat_m;
+    decisions.kv_tile = decisions.path == GGML_WEBGPU_FLASH_ATTN_PATH_TILE ?
+                            std::min(64u, ggml_webgpu_flash_attn_max_kv_tile(context, key)) :
+                            std::min(ggml_webgpu_flash_attn_max_kv_tile(context, key),
+                                     context.sg_mat_n * GGML_WEBGPU_FLASH_ATTN_PREFERRED_KV_SG_TILES);
+    decisions.wg_size = decisions.path == GGML_WEBGPU_FLASH_ATTN_PATH_TILE ?
+                            GGML_WEBGPU_FLASH_ATTN_PREFERRED_WG_SIZE :
+                            std::max(context.max_subgroup_size, GGML_WEBGPU_FLASH_ATTN_PREFERRED_WG_SIZE);
+
+    if (decisions.path == GGML_WEBGPU_FLASH_ATTN_PATH_TILE) {
+        const uint32_t tile_kv_granularity = std::max(1u, context.max_subgroup_size);
+        decisions.kv_tile =
+            std::max(tile_kv_granularity, (decisions.kv_tile / tile_kv_granularity) * tile_kv_granularity);
+    }
+
+    if (decisions.kv_direct) {
+        GGML_ASSERT(decisions.kv_tile <= GGML_WEBGPU_KV_SEQ_PAD);
+        while (GGML_WEBGPU_KV_SEQ_PAD % decisions.kv_tile != 0) {
+            decisions.kv_tile -= decisions.path == GGML_WEBGPU_FLASH_ATTN_PATH_TILE ?
+                                     std::max(1u, context.max_subgroup_size) :
+                                     context.sg_mat_n;
+        }
+    }
+
+    return decisions;
+}

 /** Matrix Multiplication **/
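In short, the new selector tries VEC first (query batches under 20 rows, supported KV types, aligned f16), then TILE (subgroups available but no subgroup-matrix support), and falls back to SUBGROUP_MATRIX. A hypothetical call site, mirroring how the backend consumes it further down (context setup elided):

```cpp
// Hypothetical usage of the helpers above; names from the diff.
const ggml_webgpu_flash_attn_decisions d =
    ggml_webgpu_flash_attn_get_decisions(shader_lib_ctx, storage_offset_alignment);
const ggml_webgpu_flash_attn_pipeline_key key =
    ggml_webgpu_flash_attn_make_pipeline_key(shader_lib_ctx, d.path);
// d.q_tile / d.kv_tile / d.wg_size now parameterize the chosen shader.
```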
@@ -821,8 +935,6 @@ class ggml_webgpu_shader_lib {
         repeat_pipelines;  // type
     std::unordered_map<ggml_webgpu_flash_attn_pipeline_key, webgpu_pipeline, ggml_webgpu_flash_attn_pipeline_key_hash>
         flash_attn_pipelines;
-    std::unordered_map<ggml_webgpu_flash_attn_pipeline_key, webgpu_pipeline, ggml_webgpu_flash_attn_pipeline_key_hash>
-        flash_attn_vec_pipelines;
     std::unordered_map<ggml_webgpu_flash_attn_vec_reduce_pipeline_key,
                        webgpu_pipeline,
                        ggml_webgpu_flash_attn_vec_reduce_pipeline_key_hash>
@@ -2044,14 +2156,19 @@ class ggml_webgpu_shader_lib {
         return repeat_pipelines[key];
     }

-    webgpu_pipeline get_flash_attn_pipeline(const ggml_webgpu_shader_lib_context & context) {
-        const ggml_webgpu_flash_attn_pipeline_key key = ggml_webgpu_flash_attn_make_pipeline_key(context);
-        auto it = flash_attn_pipelines.find(key);
+    webgpu_pipeline get_flash_attn_pipeline(const ggml_webgpu_shader_lib_context & context,
+                                            size_t storage_offset_alignment) {
+        const ggml_webgpu_flash_attn_decisions decisions =
+            ggml_webgpu_flash_attn_get_decisions(context, storage_offset_alignment);
+        ggml_webgpu_flash_attn_pipeline_key key = ggml_webgpu_flash_attn_make_pipeline_key(context, decisions.path);
+        auto it = flash_attn_pipelines.find(key);
         if (it != flash_attn_pipelines.end()) {
             return it->second;
         }
         std::vector<std::string> defines;
-        std::string variant = "flash_attn";
+        std::string variant = decisions.path == GGML_WEBGPU_FLASH_ATTN_PATH_VEC  ? "flash_attn_vec" :
+                              decisions.path == GGML_WEBGPU_FLASH_ATTN_PATH_TILE ? "flash_attn_tile" :
+                                                                                   "flash_attn";

         switch (key.kv_type) {
             case GGML_TYPE_F32:
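The shape here is the usual memoized-pipeline pattern: a key describes every axis of shader specialization, and the expensive compile happens only on a cache miss. A generic sketch with stand-in types (the real ggml_webgpu_* types differ):

```cpp
#include <unordered_map>

// Generic sketch of the memoization used by get_flash_attn_pipeline:
// look up a fully-specified key and compile only on a miss.
template <typename Key, typename Pipeline, typename Hash>
Pipeline & get_or_compile(std::unordered_map<Key, Pipeline, Hash> & cache,
                          const Key & key, Pipeline (*compile)(const Key &)) {
    auto it = cache.find(key);
    if (it == cache.end()) {
        it = cache.emplace(key, compile(key)).first;  // compile once per key
    }
    return it->second;
}
```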
@@ -2073,7 +2190,12 @@ class ggml_webgpu_shader_lib {

         if (key.has_mask) {
             defines.push_back("MASK");
-            variant += "_mask";
+            if (key.path == GGML_WEBGPU_FLASH_ATTN_PATH_VEC) {
+                defines.push_back("BLK");
+                variant += "_mask_blk";
+            } else {
+                variant += "_mask";
+            }
         }
         if (key.has_sinks) {
             defines.push_back("SINKS");
@@ -2087,6 +2209,10 @@ class ggml_webgpu_shader_lib {
             defines.push_back("KV_DIRECT");
             variant += "_kvdirect";
         }
+        if (key.kv_overlap) {
+            defines.push_back("KV_OVERLAP");
+            variant += "_kv_overlap";
+        }

         defines.push_back(std::string("HEAD_DIM_QK=") + std::to_string(key.head_dim_qk));
         variant += std::string("_hsqk") + std::to_string(key.head_dim_qk);
@@ -2094,129 +2220,37 @@ class ggml_webgpu_shader_lib {
         defines.push_back(std::string("HEAD_DIM_V=") + std::to_string(key.head_dim_v));
         variant += std::string("_hsv") + std::to_string(key.head_dim_v);

-        defines.push_back(std::string("SG_MAT_M=") + std::to_string(context.sg_mat_m));
-        defines.push_back(std::string("SG_MAT_N=") + std::to_string(context.sg_mat_n));
-        defines.push_back(std::string("SG_MAT_K=") + std::to_string(context.sg_mat_k));
-
-        auto decisions = std::make_shared<ggml_webgpu_flash_attn_decisions>();
-        decisions->q_tile = context.sg_mat_m;
-
-        const uint32_t min_kv_tile = ggml_webgpu_flash_attn_max_kv_tile(context, key);
-        uint32_t kv_tile = std::min(min_kv_tile, context.sg_mat_n * GGML_WEBGPU_FLASH_ATTN_PREFERRED_KV_SG_TILES);
-
-        if (key.kv_direct) {
-            kv_tile = std::min(kv_tile, GGML_WEBGPU_KV_SEQ_PAD);
-            while (GGML_WEBGPU_KV_SEQ_PAD % kv_tile != 0) {
-                kv_tile -= context.sg_mat_n;
-            }
-        }
+        const char * shader_src = wgsl_flash_attn;
+        if (key.path == GGML_WEBGPU_FLASH_ATTN_PATH_VEC) {
+            defines.push_back("KV_GRANULARITY=8");
+            defines.push_back(std::string("VEC_NE=") + std::to_string(ggml_webgpu_flash_attn_pick_vec_ne(key)) + "u");
+            shader_src = wgsl_flash_attn_vec_split;
+        } else if (key.path == GGML_WEBGPU_FLASH_ATTN_PATH_TILE) {
+            shader_src = wgsl_flash_attn_tile;
+            defines.push_back("MAX_SUBGROUP_SIZE=" + std::to_string(context.max_subgroup_size));
+            defines.push_back("KV_STAGE_STRIDE=" + std::to_string(std::max(key.head_dim_qk, key.head_dim_v)));
+            variant += "_tile";
+        } else {
+            defines.push_back(std::string("SG_MAT_M=") + std::to_string(context.sg_mat_m));
+            defines.push_back(std::string("SG_MAT_N=") + std::to_string(context.sg_mat_n));
+            defines.push_back(std::string("SG_MAT_K=") + std::to_string(context.sg_mat_k));
+        }

-        decisions->kv_tile = kv_tile;
-        decisions->wg_size = std::max(context.max_subgroup_size, GGML_WEBGPU_FLASH_ATTN_PREFERRED_WG_SIZE);
-
-        defines.push_back(std::string("Q_TILE=") + std::to_string(decisions->q_tile));
-        defines.push_back(std::string("KV_TILE=") + std::to_string(decisions->kv_tile));
-        defines.push_back(std::string("WG_SIZE=") + std::to_string(decisions->wg_size));
+        auto pipeline_decisions = std::make_shared<ggml_webgpu_flash_attn_decisions>(decisions);
+        defines.push_back(std::string("Q_TILE=") + std::to_string(decisions.q_tile));
+        defines.push_back(std::string("KV_TILE=") + std::to_string(decisions.kv_tile));
+        defines.push_back(std::string("WG_SIZE=") + std::to_string(decisions.wg_size));

         webgpu_pipeline pipeline =
-            ggml_webgpu_create_pipeline(device, preprocessor.preprocess(wgsl_flash_attn, defines), variant);
-        pipeline.context = decisions;
+            ggml_webgpu_create_pipeline(device, preprocessor.preprocess(shader_src, defines), variant);
+        pipeline.context = pipeline_decisions;
         flash_attn_pipelines[key] = pipeline;
         return flash_attn_pipelines[key];
     }

-    webgpu_pipeline get_flash_attn_vec_pipeline(const ggml_webgpu_shader_lib_context & context) {
-        const ggml_webgpu_flash_attn_pipeline_key key = ggml_webgpu_flash_attn_make_pipeline_key(context);
-        auto it = flash_attn_vec_pipelines.find(key);
-        if (it != flash_attn_vec_pipelines.end()) {
-            return it->second;
-        }
-
-        std::vector<std::string> defines;
-        std::string variant = "flash_attn_vec";
-
-        switch (key.kv_type) {
-            case GGML_TYPE_F32:
-                defines.push_back("KV_F32");
-                break;
-            case GGML_TYPE_F16:
-                defines.push_back("KV_F16");
-                break;
-            case GGML_TYPE_Q4_0:
-                defines.push_back("KV_Q4_0");
-                break;
-            case GGML_TYPE_Q8_0:
-                defines.push_back("KV_Q8_0");
-                break;
-            default:
-                GGML_ABORT("Unsupported KV type for flash attention shader");
-        }
-        variant += std::string("_") + ggml_type_name(key.kv_type);
-
-        if (key.has_mask) {
-            defines.push_back("MASK");
-            defines.push_back("BLK");
-            variant += "_mask_blk";
-        }
-        if (key.has_sinks) {
-            defines.push_back("SINKS");
-            variant += "_sinks";
-        }
-        if (key.uses_logit_softcap) {
-            defines.push_back("LOGIT_SOFTCAP");
-            variant += "_lgsc";
-        }
-        if (key.kv_direct) {
-            defines.push_back("KV_DIRECT");
-            variant += "_kvdirect";
-        }
-
-        defines.push_back(std::string("HEAD_DIM_QK=") + std::to_string(key.head_dim_qk));
-        variant += std::string("_hsqk") + std::to_string(key.head_dim_qk);
-
-        defines.push_back(std::string("HEAD_DIM_V=") + std::to_string(key.head_dim_v));
-        variant += std::string("_hsv") + std::to_string(key.head_dim_v);
-
-        defines.push_back(std::string("SG_MAT_M=") + std::to_string(context.sg_mat_m));
-        defines.push_back(std::string("SG_MAT_N=") + std::to_string(context.sg_mat_n));
-        defines.push_back(std::string("SG_MAT_K=") + std::to_string(context.sg_mat_k));
-        defines.push_back("Q_TILE=1");
-
-        auto decisions = std::make_shared<ggml_webgpu_flash_attn_vec_decisions>();
-        decisions->kv_tile = ggml_webgpu_flash_attn_vec_get_kv_tile(context);
-        decisions->wg_size = std::max(1u, std::min<uint32_t>(32u, context.max_subgroup_size));
-        uint32_t vec_ne = 1u;
-
-        // Keep conservative defaults unless this is the f16 vec-split shape family.
-        if (key.kv_type == GGML_TYPE_F16 && key.head_dim_qk == key.head_dim_v) {
-            switch (key.head_dim_qk) {
-                case 64:
-                case 192:
-                case 576:
-                    vec_ne = 2u;
-                    break;
-                case 96:
-                    vec_ne = 4u;
-                    break;
-                default:
-                    break;
-            }
-        }
-
-        defines.push_back(std::string("KV_TILE=") + std::to_string(decisions->kv_tile));
-        defines.push_back(std::string("WG_SIZE=") + std::to_string(decisions->wg_size));
-        defines.push_back(std::string("VEC_NE=") + std::to_string(vec_ne) + "u");
-
-        webgpu_pipeline pipeline =
-            ggml_webgpu_create_pipeline(device, preprocessor.preprocess(wgsl_flash_attn_vec_split, defines), variant);
-        pipeline.context = decisions;
-        flash_attn_vec_pipelines[key] = pipeline;
-        return flash_attn_vec_pipelines[key];
-    }
-
-    webgpu_pipeline get_flash_attn_blk_pipeline(const ggml_webgpu_shader_lib_context & context) {
+    webgpu_pipeline get_flash_attn_blk_pipeline(const ggml_webgpu_shader_lib_context & context, uint32_t kv_tile) {
         ggml_webgpu_flash_attn_blk_pipeline_key key = {};
-        key.kv_tile = ggml_webgpu_flash_attn_vec_get_kv_tile(context);
+        key.kv_tile = kv_tile;
         auto it = flash_attn_blk_pipelines.find(key);
         if (it != flash_attn_blk_pipelines.end()) {
             return it->second;
@@ -389,23 +389,6 @@ static size_t ggml_webgpu_tensor_misalignment(webgpu_context & ctx, const ggml_t
     return offset & (ctx->global_ctx->capabilities.limits.minStorageBufferOffsetAlignment - 1);
 }

-static bool ggml_webgpu_flash_attn_use_vec(webgpu_global_context & global_ctx,
-                                           const ggml_tensor * Q,
-                                           const ggml_tensor * K,
-                                           const ggml_tensor * V) {
-    const size_t alignment = global_ctx->capabilities.limits.minStorageBufferOffsetAlignment;
-    const uint32_t k_offset_elems =
-        (uint32_t) ((ggml_webgpu_tensor_offset(K) & (alignment - 1)) / ggml_type_size(K->type));
-    const uint32_t v_offset_elems =
-        (uint32_t) ((ggml_webgpu_tensor_offset(V) & (alignment - 1)) / ggml_type_size(V->type));
-    const bool f16_vec4_aligned = (k_offset_elems % 4u == 0u) && (v_offset_elems % 4u == 0u);
-    const bool kv_vec_type_supported =
-        K->type == GGML_TYPE_F16 || K->type == GGML_TYPE_Q4_0 || K->type == GGML_TYPE_Q8_0;
-
-    return (Q->ne[1] < 20) && (Q->ne[0] % 32 == 0) && (V->ne[0] % 4 == 0) && kv_vec_type_supported &&
-           (K->type != GGML_TYPE_F16 || f16_vec4_aligned) && (V->type == K->type);
-}
-
 static size_t ggml_webgpu_tensor_align_offset(webgpu_context & ctx, const ggml_tensor * t) {
     size_t offset = ggml_webgpu_tensor_offset(t);
     return offset & ~(ctx->global_ctx->capabilities.limits.minStorageBufferOffsetAlignment - 1);
@@ -1567,7 +1550,6 @@ static webgpu_encoded_op ggml_webgpu_mul_mat_id(webgpu_context & ctx,
     return ggml_backend_webgpu_build_multi(ctx, dispatches);
 }

-#ifndef __EMSCRIPTEN__
 static webgpu_encoded_op ggml_webgpu_flash_attn(webgpu_context & ctx,
                                                 ggml_tensor * Q,
                                                 ggml_tensor * K,
@@ -1585,13 +1567,29 @@ static webgpu_encoded_op ggml_webgpu_flash_attn(webgpu_context & ctx,
     float m0 = powf(2.0f, -(max_bias) / n_head_log2);
     float m1 = powf(2.0f, -(max_bias / 2.0f) / n_head_log2);

-    const int has_mask = (mask != nullptr);
-    const int has_sinks = (sinks != nullptr);
+    const int  has_mask   = (mask != nullptr);
+    const int  has_sinks  = (sinks != nullptr);
+    const bool kv_overlap = ggml_webgpu_tensor_overlap(K, V) && K->type == V->type;

+    uint32_t offset_k = (uint32_t) (ggml_webgpu_tensor_misalignment(ctx, K) / ggml_type_size(K->type));
+    uint32_t offset_v = (uint32_t) (ggml_webgpu_tensor_misalignment(ctx, V) / ggml_type_size(V->type));
+    size_t kv_bind_offset = 0;
+    size_t kv_bind_size = 0;
+    if (kv_overlap) {
+        const size_t k_bind_offset = ggml_webgpu_tensor_align_offset(ctx, K);
+        const size_t v_bind_offset = ggml_webgpu_tensor_align_offset(ctx, V);
+        const size_t k_bind_end = k_bind_offset + ggml_webgpu_tensor_binding_size(ctx, K);
+        const size_t v_bind_end = v_bind_offset + ggml_webgpu_tensor_binding_size(ctx, V);
+        kv_bind_offset = std::min(k_bind_offset, v_bind_offset);
+        kv_bind_size = std::max(k_bind_end, v_bind_end) - kv_bind_offset;
+        offset_k = (uint32_t) ((ggml_webgpu_tensor_offset(K) - kv_bind_offset) / ggml_type_size(K->type));
+        offset_v = (uint32_t) ((ggml_webgpu_tensor_offset(V) - kv_bind_offset) / ggml_type_size(V->type));
+    }

     std::vector<uint32_t> params = {
         (uint32_t) (ggml_webgpu_tensor_misalignment(ctx, Q) / ggml_type_size(Q->type)),
-        (uint32_t) (ggml_webgpu_tensor_misalignment(ctx, K) / ggml_type_size(K->type)),
-        (uint32_t) (ggml_webgpu_tensor_misalignment(ctx, V) / ggml_type_size(V->type)),
+        offset_k,
+        offset_v,
         has_mask ? (uint32_t) (ggml_webgpu_tensor_misalignment(ctx, mask) / ggml_type_size(mask->type)) : 0,
         has_sinks ? (uint32_t) (ggml_webgpu_tensor_misalignment(ctx, sinks) / ggml_type_size(sinks->type)) : 0,
         (uint32_t) (ggml_webgpu_tensor_misalignment(ctx, dst) / ggml_type_size(dst->type)),
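When K and V alias the same buffer, the code binds one range spanning both and rebases the per-tensor element offsets against the merged start. A worked example of that arithmetic, with made-up offsets:

```cpp
#include <algorithm>
#include <cstddef>
#include <cstdio>

// Made-up numbers: K's aligned binding starts at 0x100 and V's at
// 0x300, each 0x280 bytes long. One merged binding covers both, and
// the shader indexes K and V relative to the merged start.
int main() {
    const std::size_t k_off = 0x100, v_off = 0x300, bind_sz = 0x280;
    const std::size_t kv_bind_offset = std::min(k_off, v_off);             // 0x100
    const std::size_t kv_bind_size =
        std::max(k_off + bind_sz, v_off + bind_sz) - kv_bind_offset;       // 0x480
    std::printf("offset=0x%zx size=0x%zx\n", kv_bind_offset, kv_bind_size);
    return 0;
}
```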
@@ -1619,10 +1617,15 @@ static webgpu_encoded_op ggml_webgpu_flash_attn(webgpu_context & ctx,
     };
     std::vector<wgpu::BindGroupEntry> entries = {
         ggml_webgpu_make_tensor_bind_group_entry(ctx, 0, Q),
-        ggml_webgpu_make_tensor_bind_group_entry(ctx, 1, K),
-        ggml_webgpu_make_tensor_bind_group_entry(ctx, 2, V),
     };
-    uint32_t binding_index = 3;
+    if (kv_overlap) {
+        entries.push_back(
+            ggml_webgpu_make_bind_group_entry(1, ggml_webgpu_tensor_buf(K), kv_bind_offset, kv_bind_size));
+    } else {
+        entries.push_back(ggml_webgpu_make_tensor_bind_group_entry(ctx, 1, K));
+        entries.push_back(ggml_webgpu_make_tensor_bind_group_entry(ctx, 2, V));
+    }
+    uint32_t binding_index = kv_overlap ? 2u : 3u;
     if (has_mask) {
         entries.push_back(ggml_webgpu_make_tensor_bind_group_entry(ctx, binding_index++, mask));
     }
@@ -1638,25 +1641,25 @@ static webgpu_encoded_op ggml_webgpu_flash_attn(webgpu_context & ctx,
     shader_lib_ctx.src3 = mask;
     shader_lib_ctx.src4 = sinks;
     shader_lib_ctx.dst = dst;
+    shader_lib_ctx.src_overlap = kv_overlap;
+    shader_lib_ctx.supports_subgroups = ctx->global_ctx->capabilities.supports_subgroups;
     shader_lib_ctx.supports_subgroup_matrix = ctx->global_ctx->capabilities.supports_subgroup_matrix;
     shader_lib_ctx.max_wg_size = ctx->global_ctx->capabilities.limits.maxComputeInvocationsPerWorkgroup;
     shader_lib_ctx.wg_mem_limit_bytes = ctx->global_ctx->capabilities.limits.maxComputeWorkgroupStorageSize;
     shader_lib_ctx.sg_mat_m = ctx->global_ctx->capabilities.sg_mat_m;
     shader_lib_ctx.sg_mat_n = ctx->global_ctx->capabilities.sg_mat_n;
     shader_lib_ctx.sg_mat_k = ctx->global_ctx->capabilities.sg_mat_k;
     shader_lib_ctx.max_subgroup_size = ctx->global_ctx->capabilities.max_subgroup_size;
-    const bool use_vec = ggml_webgpu_flash_attn_use_vec(ctx->global_ctx, Q, K, V);
-    webgpu_pipeline pipeline = use_vec ? ctx->shader_lib->get_flash_attn_vec_pipeline(shader_lib_ctx) :
-                                         ctx->shader_lib->get_flash_attn_pipeline(shader_lib_ctx);
+    webgpu_pipeline pipeline = ctx->shader_lib->get_flash_attn_pipeline(
+        shader_lib_ctx, ctx->global_ctx->capabilities.limits.minStorageBufferOffsetAlignment);
+    auto * decisions = static_cast<ggml_webgpu_flash_attn_decisions *>(pipeline.context.get());

-    if (!use_vec) {
-        auto * decisions = static_cast<ggml_webgpu_flash_attn_decisions *>(pipeline.context.get());
+    if (decisions->path != GGML_WEBGPU_FLASH_ATTN_PATH_VEC) {
         uint32_t wg_per_head = CEIL_DIV(Q->ne[1], decisions->q_tile);
         uint32_t wg_x = wg_per_head * Q->ne[2] * Q->ne[3];  // wg per head * number of heads * number of batches
         return ggml_backend_webgpu_build(ctx, pipeline, params, entries, wg_x);
     }

-    auto * decisions = static_cast<ggml_webgpu_flash_attn_vec_decisions *>(pipeline.context.get());
-
     wgpu::Buffer blk_buf = {};
     uint64_t blk_size_bytes = 0;
     uint32_t blk_nblk0 = 0;
@@ -1695,10 +1698,12 @@ static webgpu_encoded_op ggml_webgpu_flash_attn(webgpu_context & ctx,
         tmp_bind_size = tmp_size_bytes;
         scratch_offset = ROUNDUP_POW2(scratch_offset + tmp_size_bytes, align_bytes);
     } else {
-        // nwg==1 writes final dst directly in vec-split; keep tmp binding valid without extra allocation.
+        // nwg==1 writes final dst directly in vec-split; bind tmp to a tiny non-overlapping scratch region.
+        tmp_size_bytes = WEBGPU_STORAGE_BUF_BINDING_MULT;
         tmp_buf = ggml_webgpu_tensor_buf(dst);
-        tmp_bind_offset = ggml_webgpu_tensor_align_offset(ctx, dst);
-        tmp_bind_size = ggml_webgpu_tensor_binding_size(ctx, dst);
+        tmp_bind_offset = scratch_offset;
+        tmp_bind_size = tmp_size_bytes;
+        scratch_offset = ROUNDUP_POW2(scratch_offset + tmp_size_bytes, align_bytes);
     }

     webgpu_pipeline blk_pipeline;
@@ -1713,7 +1718,7 @@ static webgpu_encoded_op ggml_webgpu_flash_attn(webgpu_context & ctx,
         const uint64_t blk_elems = (uint64_t) blk_nblk0 * blk_nblk1 * blk_batch_count;
         blk_size_bytes = ROUNDUP_POW2(blk_elems * sizeof(uint32_t), WEBGPU_STORAGE_BUF_BINDING_MULT);
         const ggml_webgpu_shader_lib_context blk_shader_ctx = shader_lib_ctx;
-        blk_pipeline = ctx->shader_lib->get_flash_attn_blk_pipeline(blk_shader_ctx);
+        blk_pipeline = ctx->shader_lib->get_flash_attn_blk_pipeline(blk_shader_ctx, decisions->kv_tile);

         blk_params = {
             (uint32_t) (ggml_webgpu_tensor_misalignment(ctx, mask) / ggml_type_size(mask->type)),  // offset_mask
@@ -1745,12 +1750,19 @@ static webgpu_encoded_op ggml_webgpu_flash_attn(webgpu_context & ctx,
     std::vector<wgpu::BindGroupEntry> split_entries = {
         ggml_webgpu_make_bind_group_entry(0, ggml_webgpu_tensor_buf(Q), ggml_webgpu_tensor_align_offset(ctx, Q),
                                           ggml_webgpu_tensor_binding_size(ctx, Q)),
-        ggml_webgpu_make_bind_group_entry(1, ggml_webgpu_tensor_buf(K), ggml_webgpu_tensor_align_offset(ctx, K),
-                                          ggml_webgpu_tensor_binding_size(ctx, K)),
-        ggml_webgpu_make_bind_group_entry(2, ggml_webgpu_tensor_buf(V), ggml_webgpu_tensor_align_offset(ctx, V),
-                                          ggml_webgpu_tensor_binding_size(ctx, V)),
     };
-    uint32_t split_binding_index = 3;
+    if (kv_overlap) {
+        split_entries.push_back(
+            ggml_webgpu_make_bind_group_entry(1, ggml_webgpu_tensor_buf(K), kv_bind_offset, kv_bind_size));
+    } else {
+        split_entries.push_back(ggml_webgpu_make_bind_group_entry(1, ggml_webgpu_tensor_buf(K),
+                                                                  ggml_webgpu_tensor_align_offset(ctx, K),
+                                                                  ggml_webgpu_tensor_binding_size(ctx, K)));
+        split_entries.push_back(ggml_webgpu_make_bind_group_entry(2, ggml_webgpu_tensor_buf(V),
+                                                                  ggml_webgpu_tensor_align_offset(ctx, V),
+                                                                  ggml_webgpu_tensor_binding_size(ctx, V)));
+    }
+    uint32_t split_binding_index = kv_overlap ? 2u : 3u;
     if (has_mask) {
         split_entries.push_back(ggml_webgpu_make_bind_group_entry(split_binding_index++, ggml_webgpu_tensor_buf(mask),
                                                                   ggml_webgpu_tensor_align_offset(ctx, mask),
@@ -1820,7 +1832,6 @@ static webgpu_encoded_op ggml_webgpu_flash_attn(webgpu_context & ctx,

     return ggml_backend_webgpu_build_multi(ctx, dispatches);
 }
-#endif  // __EMSCRIPTEN__

 static webgpu_encoded_op ggml_webgpu_unary_op(webgpu_context & ctx, ggml_tensor * src, ggml_tensor * dst) {
     bool is_unary = dst->op == GGML_OP_UNARY;
@@ -2710,11 +2721,7 @@ static std::optional<webgpu_encoded_op> ggml_webgpu_encode(webgpu_context ctx,
         case GGML_OP_MUL_MAT_ID:
             return ggml_webgpu_mul_mat_id(ctx, src0, src1, src2, node);
         case GGML_OP_FLASH_ATTN_EXT:
-#ifndef __EMSCRIPTEN__
             return ggml_webgpu_flash_attn(ctx, src0, src1, src2, node->src[3], node->src[4], node);
-#else
-            return std::nullopt;
-#endif
         case GGML_OP_ADD:
         case GGML_OP_SUB:
         case GGML_OP_MUL:
@@ -3257,13 +3264,19 @@ static size_t ggml_backend_webgpu_buffer_type_get_alloc_size(ggml_backend_buffer
                     ctx->webgpu_global_ctx->capabilities.limits.maxComputeInvocationsPerWorkgroup;
                 shader_lib_ctx.wg_mem_limit_bytes =
                     ctx->webgpu_global_ctx->capabilities.limits.maxComputeWorkgroupStorageSize;
+                shader_lib_ctx.supports_subgroups = ctx->webgpu_global_ctx->capabilities.supports_subgroups;
                 shader_lib_ctx.supports_subgroup_matrix =
                     ctx->webgpu_global_ctx->capabilities.supports_subgroup_matrix;
                 shader_lib_ctx.sg_mat_m = ctx->webgpu_global_ctx->capabilities.sg_mat_m;
                 shader_lib_ctx.sg_mat_n = ctx->webgpu_global_ctx->capabilities.sg_mat_n;
                 shader_lib_ctx.sg_mat_k = ctx->webgpu_global_ctx->capabilities.sg_mat_k;
                 shader_lib_ctx.max_subgroup_size = ctx->webgpu_global_ctx->capabilities.max_subgroup_size;

-                if (ggml_webgpu_flash_attn_use_vec(ctx->webgpu_global_ctx, Q, K, V)) {
-                    const uint32_t kv_tile = ggml_webgpu_flash_attn_vec_get_kv_tile(shader_lib_ctx);
+                const ggml_webgpu_flash_attn_decisions decisions = ggml_webgpu_flash_attn_get_decisions(
+                    shader_lib_ctx, ctx->webgpu_global_ctx->capabilities.limits.minStorageBufferOffsetAlignment);
+
+                if (decisions.path == GGML_WEBGPU_FLASH_ATTN_PATH_VEC) {
+                    const uint32_t kv_tile = decisions.kv_tile;

                     const uint32_t vec_nwg_cap = std::max(
                         1u, std::min<uint32_t>(32u, ctx->webgpu_global_ctx->capabilities.max_subgroup_size));
@@ -3283,6 +3296,8 @@ static size_t ggml_backend_webgpu_buffer_type_get_alloc_size(ggml_backend_buffer
                     const size_t tmp_size_bytes = ROUNDUP_POW2(
                         (tmp_data_elems + tmp_stats_elems) * sizeof(float), WEBGPU_STORAGE_BUF_BINDING_MULT);
                     res += tmp_size_bytes + align;
+                } else {
+                    res += WEBGPU_STORAGE_BUF_BINDING_MULT + align;
                 }
                 if (mask != nullptr) {
                     const uint32_t blk_nblk0 = CEIL_DIV((uint32_t) K->ne[1], kv_tile);
@@ -3431,12 +3446,12 @@ static bool create_webgpu_device(ggml_backend_webgpu_reg_context * ctx) {
     ctx->webgpu_global_ctx->capabilities.supports_subgroups =
         ctx->webgpu_global_ctx->adapter.HasFeature(wgpu::FeatureName::Subgroups);

+    bool valid_subgroup_matrix_config = false;
 #ifndef __EMSCRIPTEN__
     // Accept f16 subgroup matrix configurations (square or non-square).
     // NVIDIA GPUs typically report square configs (e.g. 16x16x16),
     // while Intel Xe2 GPUs report non-square configs (e.g. 8x16x16).
     // The shaders are already parameterized to handle any M/N/K dimensions.
-    bool valid_subgroup_matrix_config = false;
     if (ctx->webgpu_global_ctx->adapter.HasFeature(wgpu::FeatureName::ChromiumExperimentalSubgroupMatrix)) {
         for (size_t i = 0; i < subgroup_matrix_configs.configCount; i++) {
             const wgpu::SubgroupMatrixConfig config = subgroup_matrix_configs.configs[i];
@@ -3450,8 +3465,8 @@ static bool create_webgpu_device(ggml_backend_webgpu_reg_context * ctx) {
             }
         }
     }
-    ctx->webgpu_global_ctx->capabilities.supports_subgroup_matrix = valid_subgroup_matrix_config;
 #endif
+    ctx->webgpu_global_ctx->capabilities.supports_subgroup_matrix = valid_subgroup_matrix_config;

     // For subgroup matrix code to be the most efficient, we would like the subgroup size to be consistent and accurate.
     // Unfortunately, that is not possible, so we use the maximum subgroup size reported by the adapter.
@@ -3499,12 +3514,12 @@ static bool create_webgpu_device(ggml_backend_webgpu_reg_context * ctx) {
     // Enable Dawn-specific toggles to increase native performance
     // TODO: Maybe WebGPU needs a "fast" mode where you can request compilers skip adding checks like these,
     // only for native performance?
-    const char * const deviceEnabledToggles[] = { "skip_validation", "disable_robustness", "disable_workgroup_init",
-                                                  "disable_polyfills_on_integer_div_and_mod" };
-    const char * const deviceDisabledToggles[] = { "timestamp_quantization" };
+    const char * const deviceEnabledToggles[] = { "disable_robustness", "disable_workgroup_init",
+                                                  "disable_polyfills_on_integer_div_and_mod" };
+    const char * const deviceDisabledToggles[] = { "timestamp_quantization" };
     wgpu::DawnTogglesDescriptor deviceTogglesDesc;
     deviceTogglesDesc.enabledToggles = deviceEnabledToggles;
-    deviceTogglesDesc.enabledToggleCount = 4;
+    deviceTogglesDesc.enabledToggleCount = 3;
     deviceTogglesDesc.disabledToggles = deviceDisabledToggles;
     deviceTogglesDesc.disabledToggleCount = 1;
@@ -3782,33 +3797,63 @@ static bool ggml_backend_webgpu_device_supports_op(ggml_backend_dev_t dev, const
             break;
         case GGML_OP_FLASH_ATTN_EXT:
             {
-#ifndef __EMSCRIPTEN__
-                if (!ctx->webgpu_global_ctx->capabilities.supports_subgroup_matrix) {
-                    break;
-                }
-                // Head dimensions must be divisible by subgroup matrix dimensions
-                if (src0->ne[0] % ctx->webgpu_global_ctx->capabilities.sg_mat_k != 0 ||
-                    src2->ne[0] % ctx->webgpu_global_ctx->capabilities.sg_mat_n != 0) {
-                    break;
-                }
-                // Head dimensions must fit in workgroup memory with minimum tile sizes
-                size_t limit_bytes = ctx->webgpu_global_ctx->capabilities.limits.maxComputeWorkgroupStorageSize;
-                const bool has_mask = op->src[3] != nullptr;
-                const bool kv_direct = src1->type == GGML_TYPE_F16 &&
-                                       (src0->ne[0] % ctx->webgpu_global_ctx->capabilities.sg_mat_k) == 0 &&
-                                       (src1->ne[1] % GGML_WEBGPU_KV_SEQ_PAD) == 0;
-                const size_t min_bytes = ggml_webgpu_flash_attn_wg_mem_bytes(
-                    ctx->webgpu_global_ctx->capabilities.sg_mat_m, ctx->webgpu_global_ctx->capabilities.sg_mat_n,
-                    (uint32_t) src0->ne[0], (uint32_t) src2->ne[0], has_mask, kv_direct);
-                if (min_bytes > limit_bytes) {
-                    break;
-                }
-
                 supports_op = src0->type == GGML_TYPE_F32 &&
                               (src1->type == GGML_TYPE_F32 || src1->type == GGML_TYPE_F16 ||
                                src1->type == GGML_TYPE_Q4_0 || src1->type == GGML_TYPE_Q8_0) &&
                               src2->type == src1->type && op->type == GGML_TYPE_F32;
-#endif
+                if (!supports_op) {
+                    break;
+                }
+                ggml_webgpu_shader_lib_context shader_lib_ctx = {};
+                shader_lib_ctx.src0 = src0;
+                shader_lib_ctx.src1 = src1;
+                shader_lib_ctx.src2 = src2;
+                shader_lib_ctx.src3 = op->src[3];
+                shader_lib_ctx.src4 = op->src[4];
+                shader_lib_ctx.dst = const_cast<ggml_tensor *>(op);
+                shader_lib_ctx.supports_subgroups = ctx->webgpu_global_ctx->capabilities.supports_subgroups;
+                shader_lib_ctx.supports_subgroup_matrix = ctx->webgpu_global_ctx->capabilities.supports_subgroup_matrix;
+                shader_lib_ctx.wg_mem_limit_bytes =
+                    ctx->webgpu_global_ctx->capabilities.limits.maxComputeWorkgroupStorageSize;
+                shader_lib_ctx.sg_mat_m = ctx->webgpu_global_ctx->capabilities.sg_mat_m;
+                shader_lib_ctx.sg_mat_n = ctx->webgpu_global_ctx->capabilities.sg_mat_n;
+                shader_lib_ctx.sg_mat_k = ctx->webgpu_global_ctx->capabilities.sg_mat_k;
+                shader_lib_ctx.max_subgroup_size = ctx->webgpu_global_ctx->capabilities.max_subgroup_size;
+
+                const ggml_webgpu_flash_attn_decisions decisions = ggml_webgpu_flash_attn_get_decisions(
+                    shader_lib_ctx, ctx->webgpu_global_ctx->capabilities.limits.minStorageBufferOffsetAlignment);
+                const size_t limit_bytes = ctx->webgpu_global_ctx->capabilities.limits.maxComputeWorkgroupStorageSize;
+                const bool has_mask = op->src[3] != nullptr;
+                if (decisions.path == GGML_WEBGPU_FLASH_ATTN_PATH_VEC) {
+                    const size_t min_bytes =
+                        ggml_webgpu_flash_attn_wg_mem_bytes(decisions.q_tile, decisions.kv_tile, (uint32_t) src0->ne[0],
+                                                            (uint32_t) src2->ne[0], has_mask, decisions.kv_direct);
+                    if (min_bytes > limit_bytes) {
+                        supports_op = false;
+                    }
+                    break;
+                }
+
+                if (decisions.path == GGML_WEBGPU_FLASH_ATTN_PATH_TILE) {
+                    const size_t min_bytes =
+                        ggml_webgpu_flash_attn_wg_mem_bytes(decisions.q_tile, decisions.kv_tile, (uint32_t) src0->ne[0],
+                                                            (uint32_t) src2->ne[0], has_mask, decisions.kv_direct);
+                    if (min_bytes > limit_bytes) {
+                        supports_op = false;
+                    }
+                    break;
+                }
+
+                if (!ctx->webgpu_global_ctx->capabilities.supports_subgroup_matrix) {
+                    supports_op = false;
+                    break;
+                }
+                const size_t min_bytes =
+                    ggml_webgpu_flash_attn_wg_mem_bytes(decisions.q_tile, decisions.kv_tile, (uint32_t) src0->ne[0],
+                                                        (uint32_t) src2->ne[0], has_mask, decisions.kv_direct);
+                if (min_bytes > limit_bytes) {
+                    supports_op = false;
+                }
+                break;
             }
         case GGML_OP_RMS_NORM:
@@ -138,25 +138,54 @@ struct Params {
 };

 @group(0) @binding(0) var<storage, read_write> Q: array<f32>;
+#ifdef KV_OVERLAP
+@group(0) @binding(1) var<storage, read_write> K: array<KV_TYPE>;
+#define V K
+#else
 @group(0) @binding(1) var<storage, read_write> K: array<KV_TYPE>;
 @group(0) @binding(2) var<storage, read_write> V: array<KV_TYPE>;
+#endif

 #if defined(MASK) && defined(SINKS)
+#ifdef KV_OVERLAP
+@group(0) @binding(2) var<storage, read_write> mask: array<f16>;
+@group(0) @binding(3) var<storage, read_write> sinks: array<f32>;
+#define DST_BINDING 4
+#define PARAMS_BINDING 5
+#else
 @group(0) @binding(3) var<storage, read_write> mask: array<f16>;
 @group(0) @binding(4) var<storage, read_write> sinks: array<f32>;
 #define DST_BINDING 5
 #define PARAMS_BINDING 6
+#endif
 #elif defined(MASK)
+#ifdef KV_OVERLAP
+@group(0) @binding(2) var<storage, read_write> mask: array<f16>;
+#define DST_BINDING 3
+#define PARAMS_BINDING 4
+#else
 @group(0) @binding(3) var<storage, read_write> mask: array<f16>;
 #define DST_BINDING 4
 #define PARAMS_BINDING 5
+#endif
 #elif defined(SINKS)
+#ifdef KV_OVERLAP
+@group(0) @binding(2) var<storage, read_write> sinks: array<f32>;
+#define DST_BINDING 3
+#define PARAMS_BINDING 4
+#else
 @group(0) @binding(3) var<storage, read_write> sinks: array<f32>;
 #define DST_BINDING 4
 #define PARAMS_BINDING 5
+#endif
 #else
+#ifdef KV_OVERLAP
+#define DST_BINDING 2
+#define PARAMS_BINDING 3
+#else
 #define DST_BINDING 3
 #define PARAMS_BINDING 4
+#endif
 #endif

 @group(0) @binding(DST_BINDING) var<storage, read_write> dst: array<vec4<f32>>;
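The resulting binding layout, tabulated from the defines above (KV_OVERLAP folds V into K's binding, shifting everything after it down by one):

| variant | K | V | mask | sinks | dst | params |
|---|---|---|---|---|---|---|
| MASK + SINKS | 1 | 2 | 3 | 4 | 5 | 6 |
| MASK + SINKS, KV_OVERLAP | 1 | = K | 2 | 3 | 4 | 5 |
| MASK only | 1 | 2 | 3 | n/a | 4 | 5 |
| MASK only, KV_OVERLAP | 1 | = K | 2 | n/a | 3 | 4 |
| SINKS only | 1 | 2 | n/a | 3 | 4 | 5 |
| SINKS only, KV_OVERLAP | 1 | = K | n/a | 2 | 3 | 4 |
| neither | 1 | 2 | n/a | n/a | 3 | 4 |
| neither, KV_OVERLAP | 1 | = K | n/a | n/a | 2 | 3 |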
ggml/src/ggml-webgpu/wgsl-shaders/flash_attn_tile.wgsl (new file, 330 lines)
@@ -0,0 +1,330 @@
|
||||
enable f16;
|
||||
enable subgroups;
|
||||
|
||||
#define HEAD_DIM_QK 64
|
||||
#define HEAD_DIM_V 64
|
||||
#define KV_STAGE_STRIDE 64
|
||||
#define Q_TILE 4
|
||||
#define KV_TILE 64
|
||||
#define WG_SIZE 128
|
||||
|
||||
struct Params {
|
||||
offset_q: u32,
|
||||
offset_k: u32,
|
||||
offset_v: u32,
|
||||
offset_mask: u32,
|
||||
offset_sinks: u32,
|
||||
offset_dst: u32,
|
||||
|
||||
n_heads: u32,
|
||||
seq_len_q: u32,
|
||||
seq_len_kv: u32,
|
||||
|
||||
stride_q1: u32,
|
||||
stride_q2: u32,
|
||||
stride_q3: u32,
|
||||
stride_k1: u32,
|
||||
stride_k2: u32,
|
||||
stride_k3: u32,
|
||||
stride_v1: u32,
|
||||
stride_v2: u32,
|
||||
stride_v3: u32,
|
||||
stride_mask3: u32,
|
||||
|
||||
q_per_kv: u32,
|
||||
|
||||
scale: f32,
|
||||
max_bias: f32,
|
||||
logit_softcap: f32,
|
||||
n_head_log2: f32,
|
||||
m0: f32,
|
||||
m1: f32,
|
||||
};
|
||||
|
||||
@group(0) @binding(0) var<storage, read_write> Q: array<f32>;
|
||||
#ifdef KV_OVERLAP
|
||||
@group(0) @binding(1) var<storage, read_write> K: array<vec4<f16>>;
|
||||
#define V K
|
||||
#else
|
||||
@group(0) @binding(1) var<storage, read_write> K: array<vec4<f16>>;
|
||||
@group(0) @binding(2) var<storage, read_write> V: array<vec4<f16>>;
|
||||
#endif
|
||||
|
||||
#if defined(MASK) && defined(SINKS)
|
||||
#ifdef KV_OVERLAP
|
||||
@group(0) @binding(2) var<storage, read_write> mask: array<f16>;
|
||||
@group(0) @binding(3) var<storage, read_write> sinks: array<f32>;
|
||||
#define DST_BINDING 4
|
||||
#define PARAMS_BINDING 5
|
||||
#else
|
||||
@group(0) @binding(3) var<storage, read_write> mask: array<f16>;
|
||||
@group(0) @binding(4) var<storage, read_write> sinks: array<f32>;
|
||||
#define DST_BINDING 5
|
||||
#define PARAMS_BINDING 6
|
||||
#endif
|
||||
#elif defined(MASK)
|
||||
#ifdef KV_OVERLAP
|
||||
@group(0) @binding(2) var<storage, read_write> mask: array<f16>;
|
||||
#define DST_BINDING 3
|
||||
#define PARAMS_BINDING 4
|
||||
#else
|
||||
@group(0) @binding(3) var<storage, read_write> mask: array<f16>;
|
||||
#define DST_BINDING 4
|
||||
#define PARAMS_BINDING 5
|
||||
#endif
|
||||
#elif defined(SINKS)
|
||||
#ifdef KV_OVERLAP
|
||||
@group(0) @binding(2) var<storage, read_write> sinks: array<f32>;
|
||||
#define DST_BINDING 3
|
||||
#define PARAMS_BINDING 4
|
||||
#else
|
||||
@group(0) @binding(3) var<storage, read_write> sinks: array<f32>;
|
||||
#define DST_BINDING 4
|
||||
#define PARAMS_BINDING 5
|
||||
#endif
|
||||
#else
|
||||
#ifdef KV_OVERLAP
|
||||
#define DST_BINDING 2
|
||||
#define PARAMS_BINDING 3
|
||||
#else
|
||||
#define DST_BINDING 3
|
||||
#define PARAMS_BINDING 4
|
||||
#endif
|
||||
#endif
|
||||
|
||||
@group(0) @binding(DST_BINDING) var<storage, read_write> dst: array<vec4<f32>>;
|
||||
@group(0) @binding(PARAMS_BINDING) var<uniform> params: Params;
|
||||
|
||||
const FLOAT_MIN: f32 = -1.0e9;
|
||||
const Q_CHUNKS: u32 = HEAD_DIM_QK / 4u;
|
||||
const V_CHUNKS: u32 = HEAD_DIM_V / 4u;
|
||||
const SCORE_REGS_PER_LANE: u32 = (KV_TILE + MAX_SUBGROUP_SIZE - 1u) / MAX_SUBGROUP_SIZE;
|
||||
const OUT_REGS_PER_LANE: u32 = (V_CHUNKS + MAX_SUBGROUP_SIZE - 1u) / MAX_SUBGROUP_SIZE;
|
||||
|
||||
var<workgroup> q_shmem: array<f16, Q_TILE * HEAD_DIM_QK>;
|
||||
var<workgroup> kv_shmem: array<f16, KV_TILE * KV_STAGE_STRIDE>;
|
||||
var<workgroup> p_shmem: array<f32, Q_TILE * KV_TILE>;
|
||||
|
||||
@compute @workgroup_size(WG_SIZE)
|
||||
fn main(@builtin(workgroup_id) wg_id: vec3<u32>,
|
||||
@builtin(local_invocation_id) local_id: vec3<u32>,
|
||||
@builtin(subgroup_id) subgroup_id: u32,
|
||||
@builtin(subgroup_size) subgroup_size: u32,
|
||||
@builtin(num_subgroups) num_subgroups: u32,
|
||||
@builtin(subgroup_invocation_id) sg_inv_id: u32) {
|
||||
if (subgroup_size == 0u || num_subgroups < Q_TILE) {
|
||||
return;
|
||||
}
|
||||
|
||||
    let wg_per_head = (params.seq_len_q + Q_TILE - 1u) / Q_TILE;
    let wg_per_batch = wg_per_head * params.n_heads;

    let dst2_stride = HEAD_DIM_V * params.n_heads;
    let dst3_stride = dst2_stride * params.seq_len_q;

    let batch_idx = wg_id.x / wg_per_batch;
    let q_batch_offset = params.offset_q + batch_idx * params.stride_q3;
    let k_batch_offset = params.offset_k + batch_idx * params.stride_k3;
    let v_batch_offset = params.offset_v + batch_idx * params.stride_v3;
    let dst_batch_offset = params.offset_dst + batch_idx * dst3_stride;
    let wg_in_batch = wg_id.x % wg_per_batch;

    let head_idx = wg_in_batch / wg_per_head;
    let q_head_offset = q_batch_offset + head_idx * params.stride_q2;
    let k_head_idx = head_idx / params.q_per_kv;
    let v_head_offset = v_batch_offset + k_head_idx * params.stride_v2;
    let k_head_offset = k_batch_offset + k_head_idx * params.stride_k2;

    let wg_in_head = wg_in_batch % wg_per_head;
    let q_row_start = wg_in_head * Q_TILE;
    let global_q_row = q_row_start + subgroup_id;
    let row_active = subgroup_id < Q_TILE && global_q_row < params.seq_len_q;

#ifdef MASK
    let mask_global_offset = params.offset_mask + batch_idx * params.stride_mask3 + q_row_start * params.seq_len_kv;
#endif

    let dst_global_offset = dst_batch_offset + q_row_start * dst2_stride + head_idx * HEAD_DIM_V;

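    // ALiBi slope for this head: m0^(head + 1) for the first n_head_log2
    // heads, m1^(2 * (head - n_head_log2) + 1) for the rest, and 1.0 when
    // max_bias is disabled.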
    let head = f32(head_idx);
    let slope = select(1.0,
                       select(pow(params.m1, 2.0 * (head - params.n_head_log2) + 1.0),
                              pow(params.m0, head + 1.0),
                              head < params.n_head_log2),
                       params.max_bias > 0.0);

    for (var elem_idx = local_id.x; elem_idx < Q_TILE * HEAD_DIM_QK; elem_idx += WG_SIZE) {
        let q_tile_row = elem_idx / HEAD_DIM_QK;
        let q_col = elem_idx % HEAD_DIM_QK;
        let head_q_row = q_row_start + q_tile_row;
        let global_q_row_offset = q_head_offset + head_q_row * params.stride_q1;
        q_shmem[elem_idx] = f16(select(
            0.0,
            Q[global_q_row_offset + q_col] * params.scale,
            head_q_row < params.seq_len_q));
    }

    workgroupBarrier();

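    // Online-softmax state for this subgroup's query row: row_max is the
    // running maximum logit m_i, exp_sum the running denominator
    // l_i = sum_j exp(s_j - m_i), and out_regs the unnormalized output
    // accumulator O_i, distributed across the lanes of the subgroup.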
    var row_max = FLOAT_MIN;
    var exp_sum = 0.0;
    var out_regs: array<vec4<f32>, OUT_REGS_PER_LANE>;
    for (var reg_idx = 0u; reg_idx < OUT_REGS_PER_LANE; reg_idx += 1u) {
        out_regs[reg_idx] = vec4<f32>(0.0);
    }

    let q_base = subgroup_id * HEAD_DIM_QK;
    let subgroup_p_offset = subgroup_id * KV_TILE;

    for (var kv_tile = 0u; kv_tile < params.seq_len_kv; kv_tile += KV_TILE) {
        let kv_count = min(KV_TILE, params.seq_len_kv - kv_tile);
        let score_slots = min(SCORE_REGS_PER_LANE, (kv_count + subgroup_size - 1u) / subgroup_size);
        let out_slots = min(OUT_REGS_PER_LANE, (V_CHUNKS + subgroup_size - 1u) / subgroup_size);
        var local_scores: array<f32, SCORE_REGS_PER_LANE>;
        for (var slot = 0u; slot < SCORE_REGS_PER_LANE; slot += 1u) {
            local_scores[slot] = FLOAT_MIN;
        }

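        // Stage this K tile into shared memory cooperatively, one vec4<f16>
        // per thread per step.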
        for (var vec_idx_local = local_id.x; vec_idx_local < kv_count * Q_CHUNKS; vec_idx_local += WG_SIZE) {
            let kv_local = vec_idx_local / Q_CHUNKS;
            let chunk = vec_idx_local % Q_CHUNKS;
            let global_k_row = kv_tile + kv_local;
            let k_vec_index = (k_head_offset + global_k_row * params.stride_k1 + chunk * 4u) >> 2u;
            let k4 = K[k_vec_index];
            let kv_off = kv_local * KV_STAGE_STRIDE + chunk * 4u;
            kv_shmem[kv_off + 0u] = k4.x;
            kv_shmem[kv_off + 1u] = k4.y;
            kv_shmem[kv_off + 2u] = k4.z;
            kv_shmem[kv_off + 3u] = k4.w;
        }

        workgroupBarrier();

        var local_max = FLOAT_MIN;
        if (row_active) {
            for (var slot = 0u; slot < score_slots; slot += 1u) {
                let kv_local = sg_inv_id + slot * subgroup_size;
                if (kv_local >= kv_count) {
                    continue;
                }

                let global_k_row = kv_tile + kv_local;
                var dot_val = 0.0;
                for (var chunk = 0u; chunk < Q_CHUNKS; chunk += 1u) {
                    let q_off = q_base + chunk * 4u;
                    let qv = vec4<f32>(
                        f32(q_shmem[q_off + 0u]),
                        f32(q_shmem[q_off + 1u]),
                        f32(q_shmem[q_off + 2u]),
                        f32(q_shmem[q_off + 3u]));
                    let kv_off = kv_local * KV_STAGE_STRIDE + chunk * 4u;
                    let kv = vec4<f32>(
                        f32(kv_shmem[kv_off + 0u]),
                        f32(kv_shmem[kv_off + 1u]),
                        f32(kv_shmem[kv_off + 2u]),
                        f32(kv_shmem[kv_off + 3u]));
                    dot_val += dot(qv, kv);
                }
#ifdef LOGIT_SOFTCAP
                dot_val = params.logit_softcap * tanh(dot_val);
#endif
#ifdef MASK
                let mask_idx = mask_global_offset + subgroup_id * params.seq_len_kv + global_k_row;
                dot_val += slope * f32(mask[mask_idx]);
#endif
                local_scores[slot] = dot_val;
                local_max = max(local_max, dot_val);
            }
        }

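        // Merge the tile into the running softmax state. With
        // new_max = max(row_max, tile_max), the identity
        // exp(s - new_max) = exp(s - row_max) * exp(row_max - new_max)
        // lets us rescale the old denominator and output accumulator by
        // cur_exp instead of recomputing them.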
        let tile_max = subgroupMax(local_max);
        let new_max = max(row_max, tile_max);
        let cur_exp = exp(row_max - new_max);
        exp_sum *= cur_exp;
        for (var reg_idx = 0u; reg_idx < OUT_REGS_PER_LANE; reg_idx += 1u) {
            out_regs[reg_idx] *= cur_exp;
        }

        var local_sum = 0.0;
        for (var slot = 0u; slot < score_slots; slot += 1u) {
            let kv_local = sg_inv_id + slot * subgroup_size;
            if (row_active && kv_local < kv_count) {
                let p = exp(local_scores[slot] - new_max);
                p_shmem[subgroup_p_offset + kv_local] = p;
                local_sum += p;
            }
        }

        workgroupBarrier();

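        // Reuse kv_shmem to stage the V tile; the probabilities derived from
        // K are already in p_shmem, and the barrier above keeps the overwrite
        // safe.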
        for (var vec_idx_local = local_id.x; vec_idx_local < kv_count * V_CHUNKS; vec_idx_local += WG_SIZE) {
            let kv_local = vec_idx_local / V_CHUNKS;
            let chunk = vec_idx_local % V_CHUNKS;
            let global_v_row = kv_tile + kv_local;
            let v_vec_index = (v_head_offset + global_v_row * params.stride_v1 + chunk * 4u) >> 2u;
            let v4 = V[v_vec_index];
            let kv_off = kv_local * KV_STAGE_STRIDE + chunk * 4u;
            kv_shmem[kv_off + 0u] = v4.x;
            kv_shmem[kv_off + 1u] = v4.y;
            kv_shmem[kv_off + 2u] = v4.z;
            kv_shmem[kv_off + 3u] = v4.w;
        }

        workgroupBarrier();

        let tile_sum = subgroupAdd(local_sum);
        exp_sum += tile_sum;
        row_max = new_max;

        if (row_active) {
            for (var reg_idx = 0u; reg_idx < out_slots; reg_idx += 1u) {
                let chunk = sg_inv_id + reg_idx * subgroup_size;
                if (chunk >= V_CHUNKS) {
                    continue;
                }

                var acc = out_regs[reg_idx];
                for (var kv_local = 0u; kv_local < kv_count; kv_local += 1u) {
                    let p = p_shmem[subgroup_p_offset + kv_local];
                    let kv_off = kv_local * KV_STAGE_STRIDE + chunk * 4u;
                    let v4 = vec4<f32>(
                        f32(kv_shmem[kv_off + 0u]),
                        f32(kv_shmem[kv_off + 1u]),
                        f32(kv_shmem[kv_off + 2u]),
                        f32(kv_shmem[kv_off + 3u]));
                    acc += p * v4;
                }
                out_regs[reg_idx] = acc;
            }
        }

        workgroupBarrier();
    }

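    // Attention sinks add one extra per-head logit to the softmax: fold it
    // into the running max and denominator, rescaling the accumulator without
    // adding any V contribution.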
#ifdef SINKS
    if (row_active) {
        let sink_score = sinks[params.offset_sinks + head_idx];
        let sink_max = max(row_max, sink_score);
        let sink_scale = exp(row_max - sink_max);
        for (var reg_idx = 0u; reg_idx < OUT_REGS_PER_LANE; reg_idx += 1u) {
            out_regs[reg_idx] *= sink_scale;
        }
        exp_sum = exp_sum * sink_scale + exp(sink_score - sink_max);
        row_max = sink_max;
    }
#endif

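    // Final normalization: divide the accumulated numerator by the softmax
    // denominator and write one vec4 per chunk of the output row.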
    if (row_active) {
        let inv_exp_sum = select(0.0, 1.0 / exp_sum, exp_sum != 0.0);
        let row_base = dst_global_offset + subgroup_id * dst2_stride;
        let out_slots = min(OUT_REGS_PER_LANE, (V_CHUNKS + subgroup_size - 1u) / subgroup_size);
        for (var reg_idx = 0u; reg_idx < out_slots; reg_idx += 1u) {
            let chunk = sg_inv_id + reg_idx * subgroup_size;
            if (chunk >= V_CHUNKS) {
                continue;
            }
            let dst_vec_index = (row_base + chunk * 4u) >> 2u;
            dst[dst_vec_index] = out_regs[reg_idx] * inv_exp_sum;
        }
    }
}

@@ -15,7 +15,7 @@ struct Params {
    nblk1: u32,
};

@group(0) @binding(0) var<storage, read> mask: array<f16>;
@group(0) @binding(0) var<storage, read_write> mask: array<f16>;
@group(0) @binding(1) var<storage, read_write> blk: array<u32>;
@group(0) @binding(2) var<uniform> params: Params;

@@ -1,8 +1,6 @@
diagnostic(off, chromium.subgroup_matrix_uniformity);
diagnostic(off, subgroup_uniformity);
enable f16;
enable subgroups;
enable chromium_experimental_subgroup_matrix;

#ifdef KV_F32
#define KV_TYPE f32
@@ -13,19 +11,14 @@ enable chromium_experimental_subgroup_matrix;
#define HEAD_DIM_QK 64
#define HEAD_DIM_V 64

#define SG_MAT_M 8
#define SG_MAT_N 8
#define SG_MAT_K 8

#define Q_TILE SG_MAT_M
#define KV_GRANULARITY 8
#define KV_TILE 16
#define WG_SIZE 64
#ifndef VEC_NE
#define VEC_NE 4u
#endif

#define KV_BLOCKS (KV_TILE / SG_MAT_N)
#define KV_BLOCKS (KV_TILE / KV_GRANULARITY)

#define BLOCK_SIZE 32
#define BLOCKS_K ((HEAD_DIM_QK + BLOCK_SIZE - 1) / BLOCK_SIZE)
@@ -97,6 +90,14 @@ struct Params {
};

@group(0) @binding(0) var<storage, read_write> Q: array<f32>;
#ifdef KV_OVERLAP
#if defined(KV_Q4_0) || defined(KV_Q8_0)
@group(0) @binding(1) var<storage, read_write> K: array<KV_TYPE>;
#else
@group(0) @binding(1) var<storage, read_write> K: array<vec4<KV_TYPE>>;
#endif
#define V K
#else
#if defined(KV_Q4_0) || defined(KV_Q8_0)
@group(0) @binding(1) var<storage, read_write> K: array<KV_TYPE>;
#else
@@ -107,7 +108,22 @@ struct Params {
#else
@group(0) @binding(2) var<storage, read_write> V: array<vec4<KV_TYPE>>;
#endif
#endif
#if defined(MASK) && defined(SINKS)
#ifdef KV_OVERLAP
@group(0) @binding(2) var<storage, read_write> mask: array<f16>;
@group(0) @binding(3) var<storage, read_write> sinks: array<f32>;
#ifdef BLK
#define BLK_BINDING 4
#define TMP_BINDING 5
#define DST_BINDING 6
#define PARAMS_BINDING 7
#else
#define TMP_BINDING 4
#define DST_BINDING 5
#define PARAMS_BINDING 6
#endif
#else
@group(0) @binding(3) var<storage, read_write> mask: array<f16>;
@group(0) @binding(4) var<storage, read_write> sinks: array<f32>;
#ifdef BLK
@@ -120,7 +136,21 @@ struct Params {
#define DST_BINDING 6
#define PARAMS_BINDING 7
#endif
#endif
#elif defined(MASK)
#ifdef KV_OVERLAP
@group(0) @binding(2) var<storage, read_write> mask: array<f16>;
#ifdef BLK
#define BLK_BINDING 3
#define TMP_BINDING 4
#define DST_BINDING 5
#define PARAMS_BINDING 6
#else
#define TMP_BINDING 3
#define DST_BINDING 4
#define PARAMS_BINDING 5
#endif
#else
@group(0) @binding(3) var<storage, read_write> mask: array<f16>;
#ifdef BLK
#define BLK_BINDING 4
@@ -132,16 +162,30 @@ struct Params {
#define DST_BINDING 5
#define PARAMS_BINDING 6
#endif
#endif
#elif defined(SINKS)
#ifdef KV_OVERLAP
@group(0) @binding(2) var<storage, read_write> sinks: array<f32>;
#define TMP_BINDING 3
#define DST_BINDING 4
#define PARAMS_BINDING 5
#else
@group(0) @binding(3) var<storage, read_write> sinks: array<f32>;
#define TMP_BINDING 4
#define DST_BINDING 5
#define PARAMS_BINDING 6
#endif
#else
#ifdef KV_OVERLAP
#define TMP_BINDING 2
#define DST_BINDING 3
#define PARAMS_BINDING 4
#else
#define TMP_BINDING 3
#define DST_BINDING 4
#define PARAMS_BINDING 5
#endif
#endif

#ifdef BLK
@group(0) @binding(BLK_BINDING) var<storage, read_write> blk: array<u32>;
@@ -153,7 +197,7 @@ struct Params {
// Just a very small float value.
const FLOAT_MIN: f32 = -1.0e9;

var<workgroup> q_shmem: array<f16, Q_TILE * HEAD_DIM_QK>;
var<workgroup> q_shmem: array<f16, HEAD_DIM_QK>;

#ifndef KV_DIRECT
const kv_shmem_size = KV_TILE * max(HEAD_DIM_QK, HEAD_DIM_V);
@@ -161,31 +205,27 @@ const kv_shmem_size = KV_TILE * max(HEAD_DIM_QK, HEAD_DIM_V);
var<workgroup> kv_shmem: array<f16, kv_shmem_size>;
#endif

var<workgroup> o_shmem: array<f16, Q_TILE * HEAD_DIM_V>;
var<workgroup> o_shmem: array<f16, HEAD_DIM_V>;

#ifdef MASK
// storage for mask values
var<workgroup> mask_shmem: array<f16, Q_TILE * KV_TILE>;
var<workgroup> mask_shmem: array<f16, KV_TILE>;
#endif

// note that we reuse the same storage for both since we only need one at a time
var<workgroup> inter_shmem: array<f16, Q_TILE * KV_TILE>;
var<workgroup> inter_shmem: array<f16, KV_TILE>;

// Storage for row max and exp sum during online softmax
var<workgroup> row_max_shmem: array<f32, Q_TILE>;
var<workgroup> exp_sum_shmem: array<f32, Q_TILE>;
var<workgroup> blk_state_wg: u32;

fn calc_softmax_term(kv_idx: u32, q_tile_row: u32, slope: f32, has_bias: bool, apply_mask: bool) -> f32 {
fn calc_softmax_term(kv_idx: u32, slope: f32, has_bias: bool, apply_mask: bool) -> f32 {
    var v = select(FLOAT_MIN,
                   f32(inter_shmem[kv_idx + q_tile_row * KV_TILE]) * params.scale,
                   f32(inter_shmem[kv_idx]) * params.scale,
                   kv_idx < KV_TILE);
#ifdef LOGIT_SOFTCAP
    v = params.logit_softcap * tanh(v);
#endif
#ifdef MASK
    if (apply_mask) {
        var mask_val = select(0.0,f32(mask_shmem[q_tile_row * KV_TILE + kv_idx]), kv_idx < KV_TILE);
        var mask_val = select(0.0, f32(mask_shmem[kv_idx]), kv_idx < KV_TILE);
        v += select(mask_val, slope * mask_val, has_bias);
    }
#endif
@@ -199,19 +239,17 @@ fn main(@builtin(workgroup_id) wg_id: vec3<u32>,
        @builtin(subgroup_size) subgroup_size: u32,
        @builtin(num_subgroups) num_subgroups: u32,
        @builtin(subgroup_invocation_id) sg_inv_id: u32) {
    // Vec path processes exactly one query row per workgroup, so subgroup 0 can
    // keep the running softmax state in private storage.
    var row_max = FLOAT_MIN;
    var exp_sum = 0.0;

    // initialize row max for online softmax
    for (var i = local_id.x; i < Q_TILE; i += WG_SIZE) {
        row_max_shmem[i] = FLOAT_MIN;
        exp_sum_shmem[i] = 0.0;
    }

    for (var i = local_id.x; i < Q_TILE * HEAD_DIM_V; i += WG_SIZE) {
    for (var i = local_id.x; i < HEAD_DIM_V; i += WG_SIZE) {
        o_shmem[i] = 0.0;
    }

    // workgroups per head/batch
    let wg_per_head = (params.seq_len_q + Q_TILE - 1u) / Q_TILE;
    let wg_per_head = params.seq_len_q;
    let wg_per_batch = wg_per_head * params.n_heads;

    let dst2_stride = HEAD_DIM_V * params.n_heads;
@@ -235,9 +273,9 @@ fn main(@builtin(workgroup_id) wg_id: vec3<u32>,
    let k_head_offset = k_batch_offset + k_head_idx * params.stride_k2;
    let v_head_offset = v_batch_offset + v_head_idx * params.stride_v2;

    // starting Q row for this workgroup
    // Vec path handles one Q row per workgroup.
    let wg_in_head = wg_in_batch % wg_per_head;
    let q_row_start = wg_in_head * Q_TILE;
    let q_row_start = wg_in_head;

#ifdef MASK
    // mask offset
@@ -248,21 +286,18 @@ fn main(@builtin(workgroup_id) wg_id: vec3<u32>,
    let has_bias = params.max_bias > 0.0;
    let slope = select(1.0, select(pow(params.m1, 2.0 * (head - params.n_head_log2) + 1.0), pow(params.m0, head + 1.0), head < params.n_head_log2), has_bias);

    // load q tile into shared memory
    for (var elem_idx = local_id.x; elem_idx < Q_TILE * HEAD_DIM_QK; elem_idx += WG_SIZE) {
        let q_row = elem_idx / HEAD_DIM_QK;
        let q_col = elem_idx % HEAD_DIM_QK;
        let head_q_row = q_row_start + q_row;
        let global_q_row_offset = q_head_offset + head_q_row * params.stride_q1;
    // load the single Q row into shared memory
    for (var elem_idx = local_id.x; elem_idx < HEAD_DIM_QK; elem_idx += WG_SIZE) {
        let global_q_row_offset = q_head_offset + q_row_start * params.stride_q1;
        q_shmem[elem_idx] = f16(select(
            0.0,
            Q[global_q_row_offset + q_col],
            head_q_row < params.seq_len_q && q_col < HEAD_DIM_QK));
            Q[global_q_row_offset + elem_idx],
            q_row_start < params.seq_len_q));
    }

    for (var kv_tile = iwg * KV_TILE; kv_tile < params.seq_len_kv; kv_tile += KV_TILE * params.nwg) {
#ifdef BLK
        let q_blk = q_row_start / Q_TILE;
        let q_blk = q_row_start;
        let kv_blk = kv_tile / KV_TILE;
        let blk_batch = select(0u, batch_idx, params.stride_mask3 > 0u);
        let blk_idx = params.blk_base + (blk_batch * params.blk_nblk1 + q_blk) * params.blk_nblk0 + kv_blk;
@@ -270,13 +305,9 @@ fn main(@builtin(workgroup_id) wg_id: vec3<u32>,
#else
        let blk_state_local = 1u;
#endif
        if (local_id.x == 0u) {
            blk_state_wg = blk_state_local;
        }
        workgroupBarrier();
        let blk_state = blk_state_wg;
        let blk_state = blk_state_local;
        let skip_tile = blk_state == 0u;
        for (var elem_idx = local_id.x; elem_idx < Q_TILE * KV_TILE; elem_idx += WG_SIZE) {
        for (var elem_idx = local_id.x; elem_idx < KV_TILE; elem_idx += WG_SIZE) {
            inter_shmem[elem_idx] = f16(0.0);
        }

@@ -360,20 +391,14 @@ fn main(@builtin(workgroup_id) wg_id: vec3<u32>,
        let num_of_threads = subgroup_size / VEC_NE;
        let tx = sg_inv_id % num_of_threads;
        let ty = sg_inv_id / num_of_threads;
        for (var q_tile_row = subgroup_id; q_tile_row < Q_TILE; q_tile_row += num_subgroups) {
            let global_q_row = q_row_start + q_tile_row;
            if (global_q_row >= params.seq_len_q) {
                continue;
            }
            let local_q_row_offset = q_tile_row * HEAD_DIM_QK;

        if (subgroup_id == 0u && q_row_start < params.seq_len_q) {
            for (var kv_base : u32 = 0u; kv_base < KV_TILE; kv_base += VEC_NE) {
                let kv_idx = kv_base + ty;
                var partial_sum: f32 = 0.0;
                let kv_valid = kv_idx < KV_TILE && (kv_tile + kv_idx) < params.seq_len_kv;
                if (kv_valid) {
                    for (var i = tx; i < (HEAD_DIM_QK / 4u); i += num_of_threads) {
                        let q_off = local_q_row_offset + i * 4u;
                        let q_off = i * 4u;

                        let qv = vec4<f32>(
                            f32(q_shmem[q_off + 0u]),
@@ -410,8 +435,7 @@ fn main(@builtin(workgroup_id) wg_id: vec3<u32>,

                let sum_bcast = subgroupShuffle(sum, num_of_threads * ty);
                if (tx == 0u && kv_valid) {
                    let dst_idx = q_tile_row * KV_TILE + kv_idx;
                    inter_shmem[dst_idx] = f16(sum_bcast);
                    inter_shmem[kv_idx] = f16(sum_bcast);
                }
            }
        }
@@ -422,13 +446,10 @@ fn main(@builtin(workgroup_id) wg_id: vec3<u32>,
        let apply_mask = !skip_tile && (blk_state != 2u);
        if (apply_mask) {
            // load mask tile into shared memory for this KV block
            for (var elem_idx = local_id.x; elem_idx < Q_TILE * KV_TILE; elem_idx += WG_SIZE) {
                let mask_row = elem_idx / KV_TILE;
                let mask_col = elem_idx % KV_TILE;
                let global_q_row = q_row_start + mask_row;
                let global_k_col = kv_tile + mask_col;
                let mask_in_bounds = global_q_row < params.seq_len_q && global_k_col < params.seq_len_kv;
                let mask_idx = mask_global_offset + mask_row * params.seq_len_kv + global_k_col;
            for (var elem_idx = local_id.x; elem_idx < KV_TILE; elem_idx += WG_SIZE) {
                let global_k_col = kv_tile + elem_idx;
                let mask_in_bounds = q_row_start < params.seq_len_q && global_k_col < params.seq_len_kv;
                let mask_idx = mask_global_offset + global_k_col;
                mask_shmem[elem_idx] = select(0.0, mask[mask_idx], mask_in_bounds);
            }
        }
@@ -439,50 +460,40 @@ fn main(@builtin(workgroup_id) wg_id: vec3<u32>,
        workgroupBarrier();

        // online softmax
        if (!skip_tile) {
            for (var q_tile_row = subgroup_id; q_tile_row < Q_TILE; q_tile_row += num_subgroups) {
                let global_q_row = q_row_start + q_tile_row;
                if (global_q_row >= params.seq_len_q) {
                    break;
                }
        if (!skip_tile && subgroup_id == 0u && q_row_start < params.seq_len_q) {
            var prev_max = row_max;
            var final_max = prev_max;
            // pass 1: compute final max across the full KV tile in chunks
            for (var kv_offset = 0u; kv_offset < KV_TILE; kv_offset += subgroup_size) {
                let kv_idx = kv_offset + sg_inv_id;
                let kv_valid = kv_tile + kv_idx < params.seq_len_kv && kv_idx < KV_TILE;
                let softmax_term = select(FLOAT_MIN,
                                          calc_softmax_term(kv_idx, slope, has_bias, apply_mask),
                                          kv_valid);
                final_max = subgroupMax(max(final_max, softmax_term));
            }

            var prev_max = row_max_shmem[q_tile_row];
            var final_max = prev_max;
            // pass 1: compute final max across the full KV tile in chunks
            for (var kv_offset = 0u; kv_offset < KV_TILE; kv_offset += subgroup_size) {
                let kv_idx = kv_offset + sg_inv_id;
                let kv_valid = kv_tile + kv_idx < params.seq_len_kv && kv_idx < KV_TILE;
                let softmax_term = select(FLOAT_MIN,
                                          calc_softmax_term(kv_idx, q_tile_row, slope, has_bias, apply_mask),
                                          kv_valid);
                final_max = subgroupMax(max(final_max, softmax_term));
            var total_exp_term: f32 = 0.0;
            // pass 2: compute exp sum and write P using final_max
            for (var kv_offset = 0u; kv_offset < KV_TILE; kv_offset += subgroup_size) {
                let kv_idx = kv_offset + sg_inv_id;
                let softmax_term = calc_softmax_term(kv_idx, slope, has_bias, apply_mask);
                let cur_p = select(0.0,
                                   exp(softmax_term - final_max),
                                   kv_tile + kv_idx < params.seq_len_kv && kv_idx < KV_TILE);
                total_exp_term += subgroupAdd(cur_p);
                if (kv_idx < KV_TILE) {
                    inter_shmem[kv_idx] = f16(cur_p);
                }
            }

            var total_exp_term: f32 = 0.0;
            // pass 2: compute exp sum and write P using final_max
            for (var kv_offset = 0u; kv_offset < KV_TILE; kv_offset += subgroup_size) {
                let kv_idx = kv_offset + sg_inv_id;
                let softmax_term = calc_softmax_term(kv_idx, q_tile_row, slope, has_bias, apply_mask);
                let cur_p = select(0.0,
                                   exp(softmax_term - final_max),
                                   kv_tile + kv_idx < params.seq_len_kv && kv_idx < KV_TILE);
                total_exp_term += subgroupAdd(cur_p);
                if (kv_idx < KV_TILE) {
                    inter_shmem[kv_idx + q_tile_row * KV_TILE] = f16(cur_p);
                }
            }
            let cur_exp = exp(prev_max - final_max);

            let cur_exp = exp(prev_max - final_max);
            row_max = final_max;
            exp_sum = exp_sum * cur_exp + total_exp_term;

            if (sg_inv_id == 0) {
                row_max_shmem[q_tile_row] = final_max;
                exp_sum_shmem[q_tile_row] = exp_sum_shmem[q_tile_row] * cur_exp + total_exp_term;
            }

            for (var elem_idx = sg_inv_id; elem_idx < HEAD_DIM_V; elem_idx += subgroup_size) {
                let idx = q_tile_row * HEAD_DIM_V + elem_idx;
                o_shmem[idx] = f16(f32(o_shmem[idx]) * cur_exp);
            }
            for (var elem_idx = sg_inv_id; elem_idx < HEAD_DIM_V; elem_idx += subgroup_size) {
                o_shmem[elem_idx] = f16(f32(o_shmem[elem_idx]) * cur_exp);
            }
        }

@@ -562,15 +573,13 @@ fn main(@builtin(workgroup_id) wg_id: vec3<u32>,
        workgroupBarrier();

        if (!skip_tile) {
            // we have P (Q_TILE x KV_TILE) in inter_shmem and V (KV_TILE x head_dim_v) in kv_shmem
            // we have P (KV_TILE) in inter_shmem and V (KV_TILE x head_dim_v) in kv_shmem
            // we want to compute O += P * V across the full KV tile
            let ne_threads : u32 = VEC_NE;
            let nl_threads = max(1u, subgroup_size / ne_threads);
            let tx_pv = sg_inv_id % nl_threads;
            let ty_pv = sg_inv_id / nl_threads;
            for (var q_tile_row = subgroup_id;
                 q_tile_row < Q_TILE;
                 q_tile_row += num_subgroups) {
            if (subgroup_id == 0u && q_row_start < params.seq_len_q) {
                for (var vec_col = tx_pv; vec_col < (HEAD_DIM_V / 4u); vec_col += nl_threads) {
                    var lo = vec4<f32>(0.0, 0.0, 0.0, 0.0);
                    for (var cc = 0u; cc < KV_TILE / ne_threads; cc += 1u) {
@@ -580,7 +589,7 @@ fn main(@builtin(workgroup_id) wg_id: vec3<u32>,
                            continue;
                        }

                        let p = f32(inter_shmem[kv_idx + q_tile_row * KV_TILE]);
                        let p = f32(inter_shmem[kv_idx]);
#ifdef KV_DIRECT
                        let v_idx = v_head_offset + v_row * params.stride_v1 + vec_col * 4u;
                        let v4 = vec4<f32>(V[v_idx >> 2u]);
@@ -621,11 +630,10 @@ fn main(@builtin(workgroup_id) wg_id: vec3<u32>,

                    if (ty_pv == 0u) {
                        let elem_base = vec_col * 4u;
                        let o_base_idx = q_tile_row * HEAD_DIM_V + elem_base;
                        o_shmem[o_base_idx + 0u] = f16(f32(o_shmem[o_base_idx + 0u]) + lo_x);
                        o_shmem[o_base_idx + 1u] = f16(f32(o_shmem[o_base_idx + 1u]) + lo_y);
                        o_shmem[o_base_idx + 2u] = f16(f32(o_shmem[o_base_idx + 2u]) + lo_z);
                        o_shmem[o_base_idx + 3u] = f16(f32(o_shmem[o_base_idx + 3u]) + lo_w);
                        o_shmem[elem_base + 0u] = f16(f32(o_shmem[elem_base + 0u]) + lo_x);
                        o_shmem[elem_base + 1u] = f16(f32(o_shmem[elem_base + 1u]) + lo_y);
                        o_shmem[elem_base + 2u] = f16(f32(o_shmem[elem_base + 2u]) + lo_z);
                        o_shmem[elem_base + 3u] = f16(f32(o_shmem[elem_base + 3u]) + lo_w);
                    }
                }
            }
@@ -637,70 +645,46 @@ fn main(@builtin(workgroup_id) wg_id: vec3<u32>,

#ifdef SINKS
        // Sinks are global terms and must be applied exactly once across split workgroups.
        if (iwg == 0u) {
            for (var q_tile_row = subgroup_id;
                 q_tile_row < Q_TILE;
                 q_tile_row += num_subgroups) {
                let global_q_row = q_row_start + q_tile_row;
                if (global_q_row >= params.seq_len_q) {
                    break;
                }
        if (iwg == 0u && subgroup_id == 0u && q_row_start < params.seq_len_q) {
            var prev_max = row_max;

                var prev_max = row_max_shmem[q_tile_row];
            // for non-sink threads, exp(FLOAT_MIN) effectively zeroes out their contribution to the sum
            let sink_val = select(FLOAT_MIN, sinks[params.offset_sinks + head_idx], sg_inv_id == 0u);
            let new_max = subgroupMax(max(prev_max, sink_val));
            let max_exp = exp(prev_max - new_max);
            let sink_exp = exp(sink_val - new_max);

                // for non-sink threads, exp(FLOAT_MIN) effectively zeroes out their contribution to the sum
                let sink_val = select(FLOAT_MIN, sinks[params.offset_sinks + head_idx], sg_inv_id == 0);
                let new_max = subgroupMax(max(prev_max, sink_val));
                let max_exp = exp(prev_max - new_max);
                let sink_exp = exp(sink_val - new_max);
            let sink_exp_sum = subgroupAdd(sink_exp);

                let sink_exp_sum = subgroupAdd(sink_exp);
            row_max = new_max;
            exp_sum = exp_sum * max_exp + sink_exp_sum;

                if (sg_inv_id == 0) {
                    row_max_shmem[q_tile_row] = new_max;
                    exp_sum_shmem[q_tile_row] = exp_sum_shmem[q_tile_row] * max_exp + sink_exp_sum;
                }

                for (var elem_idx = sg_inv_id; elem_idx < HEAD_DIM_V; elem_idx += subgroup_size) {
                    let idx = q_tile_row * HEAD_DIM_V + elem_idx;
                    o_shmem[idx] = f16(f32(o_shmem[idx]) * max_exp);
                }
            for (var elem_idx = sg_inv_id; elem_idx < HEAD_DIM_V; elem_idx += subgroup_size) {
                o_shmem[elem_idx] = f16(f32(o_shmem[elem_idx]) * max_exp);
            }
            workgroupBarrier();
        }
        workgroupBarrier();
#endif
        let rows_per_batch = params.n_heads * params.seq_len_q;
        for (var q_tile_row = subgroup_id;
             q_tile_row < Q_TILE;
             q_tile_row += num_subgroups) {

            let global_q_row = q_row_start + q_tile_row;
            if (global_q_row >= params.seq_len_q) { break; }

        if (subgroup_id == 0u && q_row_start < params.seq_len_q) {
            if (params.nwg == 1u) {
                let exp_sum = exp_sum_shmem[q_tile_row];
                let scale = select(0.0, 1.0 / exp_sum, exp_sum != 0.0);
                let row_base: u32 =
                    params.offset_dst + batch_idx * dst3_stride + global_q_row * dst2_stride + head_idx * HEAD_DIM_V;
                let row_base: u32 = params.offset_dst + batch_idx * dst3_stride + q_row_start * dst2_stride +
                                    head_idx * HEAD_DIM_V;

                for (var elem_base = sg_inv_id * 4u; elem_base < HEAD_DIM_V; elem_base += subgroup_size * 4u) {
                    let i0 = q_tile_row * HEAD_DIM_V + (elem_base + 0u);
                    let i1 = q_tile_row * HEAD_DIM_V + (elem_base + 1u);
                    let i2 = q_tile_row * HEAD_DIM_V + (elem_base + 2u);
                    let i3 = q_tile_row * HEAD_DIM_V + (elem_base + 3u);

                    let v = vec4<f32>(
                        f32(o_shmem[i0]) * scale,
                        f32(o_shmem[i1]) * scale,
                        f32(o_shmem[i2]) * scale,
                        f32(o_shmem[i3]) * scale
                        f32(o_shmem[elem_base + 0u]) * scale,
                        f32(o_shmem[elem_base + 1u]) * scale,
                        f32(o_shmem[elem_base + 2u]) * scale,
                        f32(o_shmem[elem_base + 3u]) * scale
                    );

                    let dst_vec_index: u32 = (row_base + elem_base) >> 2u;
                    dst[dst_vec_index] = v;
                }
            } else {
                let rid = batch_idx * rows_per_batch + head_idx * params.seq_len_q + global_q_row;
                let rid = batch_idx * rows_per_batch + head_idx * params.seq_len_q + q_row_start;
                let tmp_row_data_base = params.tmp_data_base + rid * (HEAD_DIM_V * params.nwg) + iwg * HEAD_DIM_V;
                let tmp_row_stats_base = params.tmp_stats_base + rid * (2u * params.nwg) + 2u * iwg;

@@ -708,21 +692,16 @@ fn main(@builtin(workgroup_id) wg_id: vec3<u32>,
                     elem_base < HEAD_DIM_V;
                     elem_base += subgroup_size * 4u) {

                    let i0 = q_tile_row * HEAD_DIM_V + (elem_base + 0u);
                    let i1 = q_tile_row * HEAD_DIM_V + (elem_base + 1u);
                    let i2 = q_tile_row * HEAD_DIM_V + (elem_base + 2u);
                    let i3 = q_tile_row * HEAD_DIM_V + (elem_base + 3u);

                    let tbase = tmp_row_data_base + elem_base;
                    tmp[tbase + 0u] = f32(o_shmem[i0]);
                    tmp[tbase + 1u] = f32(o_shmem[i1]);
                    tmp[tbase + 2u] = f32(o_shmem[i2]);
                    tmp[tbase + 3u] = f32(o_shmem[i3]);
                    tmp[tbase + 0u] = f32(o_shmem[elem_base + 0u]);
                    tmp[tbase + 1u] = f32(o_shmem[elem_base + 1u]);
                    tmp[tbase + 2u] = f32(o_shmem[elem_base + 2u]);
                    tmp[tbase + 3u] = f32(o_shmem[elem_base + 3u]);
                }

                if (sg_inv_id == 0u) {
                    tmp[tmp_row_stats_base + 0u] = exp_sum_shmem[q_tile_row];
                    tmp[tmp_row_stats_base + 1u] = row_max_shmem[q_tile_row];
                    tmp[tmp_row_stats_base + 0u] = exp_sum;
                    tmp[tmp_row_stats_base + 1u] = row_max;
                }
            }
        }

@@ -7656,7 +7656,7 @@ size_t ggml_quantize_chunk(
        int64_t nrows,
        int64_t n_per_row,
        const float * imatrix) {
    const int64_t n = (int64_t) nrows * n_per_row;
    const int64_t n = nrows * n_per_row;

    if (ggml_quantize_requires_imatrix(type)) {
        GGML_ASSERT(imatrix != NULL);
@@ -7673,21 +7673,21 @@
    size_t result = 0;

    switch (type) {
        case GGML_TYPE_Q1_0: result = quantize_q1_0(src + start, (char *) dst + start_row * row_size, nrows, n_per_row, imatrix); break;
        case GGML_TYPE_Q4_0: result = quantize_q4_0(src + start, (char *) dst + start_row * row_size, nrows, n_per_row, imatrix); break;
        case GGML_TYPE_Q4_1: result = quantize_q4_1(src + start, (char *) dst + start_row * row_size, nrows, n_per_row, imatrix); break;
        case GGML_TYPE_Q5_0: result = quantize_q5_0(src + start, (char *) dst + start_row * row_size, nrows, n_per_row, imatrix); break;
        case GGML_TYPE_Q5_1: result = quantize_q5_1(src + start, (char *) dst + start_row * row_size, nrows, n_per_row, imatrix); break;
        case GGML_TYPE_Q8_0: result = quantize_q8_0(src + start, (char *) dst + start_row * row_size, nrows, n_per_row, imatrix); break;
        case GGML_TYPE_MXFP4: result = quantize_mxfp4(src + start, (char *) dst + start_row * row_size, nrows, n_per_row, imatrix); break;
        case GGML_TYPE_NVFP4: result = quantize_nvfp4(src + start, (char *) dst + start_row * row_size, nrows, n_per_row, imatrix); break;
        case GGML_TYPE_Q2_K: result = quantize_q2_K(src + start, (char *) dst + start_row * row_size, nrows, n_per_row, imatrix); break;
        case GGML_TYPE_Q3_K: result = quantize_q3_K(src + start, (char *) dst + start_row * row_size, nrows, n_per_row, imatrix); break;
        case GGML_TYPE_Q4_K: result = quantize_q4_K(src + start, (char *) dst + start_row * row_size, nrows, n_per_row, imatrix); break;
        case GGML_TYPE_Q5_K: result = quantize_q5_K(src + start, (char *) dst + start_row * row_size, nrows, n_per_row, imatrix); break;
        case GGML_TYPE_Q6_K: result = quantize_q6_K(src + start, (char *) dst + start_row * row_size, nrows, n_per_row, imatrix); break;
        case GGML_TYPE_TQ1_0: result = quantize_tq1_0(src + start, (char *) dst + start_row * row_size, nrows, n_per_row, imatrix); break;
        case GGML_TYPE_TQ2_0: result = quantize_tq2_0(src + start, (char *) dst + start_row * row_size, nrows, n_per_row, imatrix); break;
        case GGML_TYPE_Q1_0: result = quantize_q1_0 (src + start, (char *) dst + start_row * row_size, nrows, n_per_row, imatrix); break;
        case GGML_TYPE_Q4_0: result = quantize_q4_0 (src + start, (char *) dst + start_row * row_size, nrows, n_per_row, imatrix); break;
        case GGML_TYPE_Q4_1: result = quantize_q4_1 (src + start, (char *) dst + start_row * row_size, nrows, n_per_row, imatrix); break;
        case GGML_TYPE_Q5_0: result = quantize_q5_0 (src + start, (char *) dst + start_row * row_size, nrows, n_per_row, imatrix); break;
        case GGML_TYPE_Q5_1: result = quantize_q5_1 (src + start, (char *) dst + start_row * row_size, nrows, n_per_row, imatrix); break;
        case GGML_TYPE_Q8_0: result = quantize_q8_0 (src + start, (char *) dst + start_row * row_size, nrows, n_per_row, imatrix); break;
        case GGML_TYPE_MXFP4: result = quantize_mxfp4 (src + start, (char *) dst + start_row * row_size, nrows, n_per_row, imatrix); break;
        case GGML_TYPE_NVFP4: result = quantize_nvfp4 (src + start, (char *) dst + start_row * row_size, nrows, n_per_row, imatrix); break;
        case GGML_TYPE_Q2_K: result = quantize_q2_K (src + start, (char *) dst + start_row * row_size, nrows, n_per_row, imatrix); break;
        case GGML_TYPE_Q3_K: result = quantize_q3_K (src + start, (char *) dst + start_row * row_size, nrows, n_per_row, imatrix); break;
        case GGML_TYPE_Q4_K: result = quantize_q4_K (src + start, (char *) dst + start_row * row_size, nrows, n_per_row, imatrix); break;
        case GGML_TYPE_Q5_K: result = quantize_q5_K (src + start, (char *) dst + start_row * row_size, nrows, n_per_row, imatrix); break;
        case GGML_TYPE_Q6_K: result = quantize_q6_K (src + start, (char *) dst + start_row * row_size, nrows, n_per_row, imatrix); break;
        case GGML_TYPE_TQ1_0: result = quantize_tq1_0 (src + start, (char *) dst + start_row * row_size, nrows, n_per_row, imatrix); break;
        case GGML_TYPE_TQ2_0: result = quantize_tq2_0 (src + start, (char *) dst + start_row * row_size, nrows, n_per_row, imatrix); break;
        case GGML_TYPE_IQ2_XXS: result = quantize_iq2_xxs(src + start, (char *) dst + start_row * row_size, nrows, n_per_row, imatrix); break;
        case GGML_TYPE_IQ2_XS: result = quantize_iq2_xs (src + start, (char *) dst + start_row * row_size, nrows, n_per_row, imatrix); break;
        case GGML_TYPE_IQ3_XXS: result = quantize_iq3_xxs(src + start, (char *) dst + start_row * row_size, nrows, n_per_row, imatrix); break;
@@ -7752,9 +7752,9 @@ struct ggml_threadpool_params ggml_threadpool_params_default(int n_threads) {
}

bool ggml_threadpool_params_match(const struct ggml_threadpool_params * p0, const struct ggml_threadpool_params * p1) {
    if (p0->n_threads != p1->n_threads ) return false;
    if (p0->prio != p1->prio ) return false;
    if (p0->poll != p1->poll ) return false;
    if (p0->strict_cpu != p1->strict_cpu ) return false;
    if (p0->n_threads != p1->n_threads ) return false;
    if (p0->prio != p1->prio ) return false;
    if (p0->poll != p1->poll ) return false;
    if (p0->strict_cpu != p1->strict_cpu ) return false;
    return memcmp(p0->cpumask, p1->cpumask, GGML_MAX_N_THREADS) == 0;
}