mirror of
https://github.com/ggml-org/llama.cpp.git
synced 2026-05-27 21:31:14 +02:00
Compare commits
12 Commits
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
328874d054 | ||
|
|
c1f1e28d29 | ||
|
|
5a4126adc1 | ||
|
|
a4d2d4ae41 | ||
|
|
d161ea7071 | ||
|
|
45158f460e | ||
|
|
22307b3e8b | ||
|
|
ce5890b5f7 | ||
|
|
b251f74f49 | ||
|
|
fa97041524 | ||
|
|
ae251b5ff2 | ||
|
|
66efd13375 |
@@ -467,7 +467,14 @@ class ModelBase:
|
||||
elif quant_method == "compressed-tensors":
|
||||
quant_format = quant_config["format"]
|
||||
groups = quant_config["config_groups"]
|
||||
if len(groups) > 1:
|
||||
nvfp4_compressed_tensors = (
|
||||
quant_format == "nvfp4-pack-quantized"
|
||||
or quant_format == "mixed-precision"
|
||||
and bool(groups)
|
||||
and all(g.get("format") == "nvfp4-pack-quantized" for g in groups.values() if isinstance(g, dict))
|
||||
)
|
||||
|
||||
if len(groups) > 1 and not nvfp4_compressed_tensors:
|
||||
raise NotImplementedError("Can't handle multiple config groups for compressed-tensors yet")
|
||||
weight_config = tuple(groups.values())[0]["weights"]
|
||||
|
||||
@@ -505,6 +512,9 @@ class ModelBase:
|
||||
tensors_to_remove += [base_name + n for n in ("_packed", "_shape", "_scale")]
|
||||
if (base_name + "_zero_point") in self.model_tensors:
|
||||
tensors_to_remove.append(base_name + "_zero_point")
|
||||
elif nvfp4_compressed_tensors:
|
||||
# Don't error from compressed-tensors, we'll handle them in _generate_nvfp4_tensors
|
||||
pass
|
||||
else:
|
||||
raise NotImplementedError(f"Quant format {quant_format!r} for method {quant_method!r} is not yet supported")
|
||||
elif quant_method == "modelopt":
|
||||
@@ -746,10 +756,13 @@ class ModelBase:
|
||||
del experts, merged
|
||||
|
||||
def prepare_tensors(self):
|
||||
# detect NVFP4 quantization (ModelOpt format)
|
||||
quant_algo = (self.hparams.get("quantization_config") or {}).get("quant_algo")
|
||||
quant_method = (self.hparams.get("quantization_config") or {}).get("quant_method")
|
||||
quant_layers = (self.hparams.get("quantization_config") or {}).get("quantized_layers") or {}
|
||||
# detect NVFP4 quantization (ModelOpt and Compressed-tensors formats)
|
||||
quantization_config = self.hparams.get("quantization_config") or {}
|
||||
quant_algo = quantization_config.get("quant_algo")
|
||||
quant_method = quantization_config.get("quant_method")
|
||||
quant_format = quantization_config.get("format")
|
||||
quant_groups = quantization_config.get("config_groups") or {}
|
||||
quant_layers = quantization_config.get("quantized_layers") or {}
|
||||
quant_config_file = self.dir_model / "hf_quant_config.json"
|
||||
|
||||
if (not quant_algo or not quant_layers) and quant_config_file.is_file():
|
||||
@@ -760,13 +773,25 @@ class ModelBase:
|
||||
producer_name = (producer.get("name") or "").lower()
|
||||
if quant_method is None:
|
||||
self.hparams.setdefault("quantization_config", {})["quant_method"] = producer_name
|
||||
quant_method = producer_name
|
||||
quant_algo = quant_config.get("quant_algo", quant_algo)
|
||||
quant_method = quant_config.get("quant_method", quant_method)
|
||||
quant_format = quant_config.get("format", quant_format)
|
||||
quant_groups = quant_config.get("config_groups", quant_groups) or {}
|
||||
quant_layers = quant_config.get("quantized_layers", quant_layers) or {}
|
||||
|
||||
# Some models use per-tensor quant_algo (e.g. "MIXED_PRECISION" with
|
||||
# per-layer NVFP4/FP8) instead of a single global "NVFP4" value.
|
||||
nvfp4_compressed_tensors = quant_method == "compressed-tensors" and (
|
||||
quant_format == "nvfp4-pack-quantized"
|
||||
or quant_format == "mixed-precision"
|
||||
and bool(quant_groups)
|
||||
and all(g.get("format") == "nvfp4-pack-quantized" for g in quant_groups.values() if isinstance(g, dict))
|
||||
)
|
||||
if quant_algo != "NVFP4":
|
||||
if any(v.get("quant_algo") == "NVFP4" for v in quant_layers.values() if isinstance(v, dict)):
|
||||
if nvfp4_compressed_tensors:
|
||||
quant_algo = "NVFP4"
|
||||
elif any(v.get("quant_algo") == "NVFP4" for v in quant_layers.values() if isinstance(v, dict)):
|
||||
quant_algo = "NVFP4"
|
||||
|
||||
self._is_nvfp4 = quant_algo == "NVFP4"
|
||||
@@ -776,6 +801,28 @@ class ModelBase:
|
||||
# This must run before dequant_model so NVFP4 tensors are removed
|
||||
# from model_tensors, leaving only non-NVFP4 (e.g. FP8) for dequant.
|
||||
if self._is_nvfp4:
|
||||
if nvfp4_compressed_tensors:
|
||||
# Convert compressed-tensors 'global' scales into the reciprocal
|
||||
def inverse_scale(gen):
|
||||
def load():
|
||||
scale = LazyTorchTensor.to_eager(gen()).float()
|
||||
return 1.0 / scale
|
||||
return load
|
||||
|
||||
# Change the compressed-tensors names to the ModelOpt names for handling consistently later
|
||||
for name in list(self.model_tensors.keys()):
|
||||
if name.endswith(".weight_packed"):
|
||||
weight_name = name.removesuffix("_packed")
|
||||
if weight_name not in self.model_tensors:
|
||||
self.model_tensors[weight_name] = self.model_tensors.pop(name)
|
||||
elif name.endswith(".weight_global_scale"):
|
||||
scale2_name = name.replace(".weight_global_scale", ".weight_scale_2")
|
||||
if scale2_name not in self.model_tensors:
|
||||
self.model_tensors[scale2_name] = inverse_scale(self.model_tensors.pop(name))
|
||||
elif name.endswith(".input_global_scale"):
|
||||
input_scale_name = name.replace(".input_global_scale", ".input_scale")
|
||||
if input_scale_name not in self.model_tensors:
|
||||
self.model_tensors[input_scale_name] = inverse_scale(self.model_tensors.pop(name))
|
||||
self._generate_nvfp4_tensors()
|
||||
|
||||
self.dequant_model()
|
||||
|
||||
@@ -1,6 +1,5 @@
|
||||
from __future__ import annotations
|
||||
|
||||
from pathlib import Path
|
||||
from typing import Any, Callable, Iterable, TYPE_CHECKING
|
||||
|
||||
import torch
|
||||
@@ -549,6 +548,7 @@ class _Qwen35MtpMixin:
|
||||
tensor_map: gguf.TensorNameMap
|
||||
no_mtp: bool
|
||||
mtp_only: bool
|
||||
_original_block_count: int | None = None
|
||||
|
||||
def __init__(self, *args, **kwargs):
|
||||
super().__init__(*args, **kwargs)
|
||||
@@ -557,22 +557,44 @@ class _Qwen35MtpMixin:
|
||||
self.block_count += self.hparams.get("mtp_num_hidden_layers", 0)
|
||||
self.tensor_map = gguf.get_tensor_name_map(self.model_arch, self.block_count)
|
||||
|
||||
def index_tensors(self, remote_hf_model_id: str | None = None) -> dict[str, Callable[[], Tensor]]:
|
||||
hparams = {**self.hparams, **self.hparams.get("text_config", {})}
|
||||
key = next((k for k in ["n_layers", "num_hidden_layers", "n_layer", "num_layers"] if k in hparams), None)
|
||||
type(self)._original_block_count = hparams.get(key)
|
||||
return super().index_tensors(remote_hf_model_id=remote_hf_model_id) # ty: ignore[unresolved-attribute]
|
||||
|
||||
@classmethod
|
||||
def filter_tensors(cls, item):
|
||||
name, _ = item
|
||||
assert cls._original_block_count is not None
|
||||
# TODO: change TextModel to super()
|
||||
if (titem := TextModel.filter_tensors(item)) is None:
|
||||
return None
|
||||
name, gen = titem
|
||||
if name.startswith("model.mtp."):
|
||||
name = name.replace("model.", "", 1)
|
||||
if name.startswith("mtp."):
|
||||
if cls.no_mtp:
|
||||
return None
|
||||
return item
|
||||
if cls.mtp_only:
|
||||
canonical = name.replace("language_model.", "")
|
||||
keep = canonical in (
|
||||
remapper = {
|
||||
"fc": "eh_proj",
|
||||
"pre_fc_norm_embedding": "enorm",
|
||||
"pre_fc_norm_hidden": "hnorm",
|
||||
"norm": "shared_head.norm",
|
||||
}
|
||||
parts = name.split(".", 3)
|
||||
if len(parts) == 4 and parts[1] == "layers" and parts[2].isdecimal():
|
||||
mtp_idx = int(parts[2])
|
||||
name = f"model.layers.{cls._original_block_count + mtp_idx}.{parts[3]}"
|
||||
elif len(parts) == 3 and parts[1] in remapper:
|
||||
name = f"model.layers.{cls._original_block_count}.{remapper[parts[1]]}.{parts[2]}"
|
||||
elif cls.mtp_only:
|
||||
keep = name in (
|
||||
"model.embed_tokens.weight", "model.norm.weight", "lm_head.weight",
|
||||
"embed_tokens.weight", "norm.weight",
|
||||
)
|
||||
if not keep:
|
||||
return None
|
||||
return super().filter_tensors(item) # ty: ignore[unresolved-attribute]
|
||||
return name, gen
|
||||
|
||||
def set_gguf_parameters(self):
|
||||
super().set_gguf_parameters() # ty: ignore[unresolved-attribute]
|
||||
@@ -594,29 +616,6 @@ class _Qwen35MtpMixin:
|
||||
self.metadata.version, size_label=None, output_type=output_type, model_type=None) # pyright: ignore[reportAttributeAccessIssue] # ty: ignore[unresolved-attribute]
|
||||
self.fname_out = self.fname_out.parent / f"mtp-{fname_default}.gguf"
|
||||
|
||||
def modify_tensors(self, data_torch: Tensor, name: str, bid: int | None) -> Iterable[tuple[str, Tensor]]:
|
||||
if name.startswith("mtp."):
|
||||
n_layer = self.hparams["num_hidden_layers"]
|
||||
if name.find("layers.") != -1:
|
||||
assert bid is not None
|
||||
name = name.replace(f"mtp.layers.{bid}", f"model.layers.{bid + n_layer}")
|
||||
bid = bid + n_layer
|
||||
else:
|
||||
remapper = {
|
||||
"mtp.fc": "model.layers.{bid}.eh_proj",
|
||||
"mtp.pre_fc_norm_embedding": "model.layers.{bid}.enorm",
|
||||
"mtp.pre_fc_norm_hidden": "model.layers.{bid}.hnorm",
|
||||
"mtp.norm": "model.layers.{bid}.shared_head.norm",
|
||||
}
|
||||
stem = Path(name).stem
|
||||
suffix = Path(name).suffix
|
||||
tmpl = remapper[stem] + suffix
|
||||
for b in range(n_layer, self.block_count):
|
||||
yield from super().modify_tensors(data_torch, tmpl.format(bid=b), b) # ty: ignore[unresolved-attribute]
|
||||
return
|
||||
|
||||
yield from super().modify_tensors(data_torch, name, bid) # ty: ignore[unresolved-attribute]
|
||||
|
||||
|
||||
@ModelBase.register("Qwen3_5ForConditionalGeneration", "Qwen3_5ForCausalLM")
|
||||
class Qwen3_5TextModel(_Qwen35MtpMixin, _Qwen35MRopeMixin, _LinearAttentionVReorderBase):
|
||||
|
||||
@@ -4,7 +4,7 @@ project("ggml" C CXX ASM)
|
||||
|
||||
### GGML Version
|
||||
set(GGML_VERSION_MAJOR 0)
|
||||
set(GGML_VERSION_MINOR 12)
|
||||
set(GGML_VERSION_MINOR 13)
|
||||
set(GGML_VERSION_PATCH 0)
|
||||
set(GGML_VERSION_BASE "${GGML_VERSION_MAJOR}.${GGML_VERSION_MINOR}.${GGML_VERSION_PATCH}")
|
||||
|
||||
|
||||
@@ -1189,8 +1189,8 @@ extern "C" {
|
||||
struct ggml_context * ctx,
|
||||
struct ggml_tensor * a);
|
||||
|
||||
// a - x
|
||||
// b - dy
|
||||
// a - dy
|
||||
// b - x
|
||||
GGML_API struct ggml_tensor * ggml_silu_back(
|
||||
struct ggml_context * ctx,
|
||||
struct ggml_tensor * a,
|
||||
|
||||
@@ -76,10 +76,16 @@ extern "C" {
|
||||
struct ggml_context ** ctx;
|
||||
};
|
||||
|
||||
// callback to simulate or wrap a FILE pointer - read up to `len` bytes at `offset` into `output` and return the number of bytes read
|
||||
typedef size_t (*gguf_reader_callback_t)(void * userdata, void * output, uint64_t offset, size_t len);
|
||||
|
||||
GGML_API struct gguf_context * gguf_init_empty(void);
|
||||
GGML_API struct gguf_context * gguf_init_from_file_ptr(FILE * file, struct gguf_init_params params);
|
||||
GGML_API struct gguf_context * gguf_init_from_file(const char * fname, struct gguf_init_params params);
|
||||
//GGML_API struct gguf_context * gguf_init_from_buffer(..);
|
||||
GGML_API struct gguf_context * gguf_init_from_buffer(const void * data, size_t size, struct gguf_init_params params);
|
||||
|
||||
// max_chunk_read is the maximum number of bytes that the GGUF code will read at once from the callback, a value of 0 means no limit
|
||||
GGML_API struct gguf_context * gguf_init_from_callback(gguf_reader_callback_t callback, void * userdata, size_t max_chunk_read, uint64_t max_expected_size, struct gguf_init_params params);
|
||||
|
||||
GGML_API void gguf_free(struct gguf_context * ctx);
|
||||
|
||||
@@ -87,7 +93,7 @@ extern "C" {
|
||||
|
||||
GGML_API uint32_t gguf_get_version (const struct gguf_context * ctx);
|
||||
GGML_API size_t gguf_get_alignment (const struct gguf_context * ctx);
|
||||
GGML_API size_t gguf_get_data_offset(const struct gguf_context * ctx);
|
||||
GGML_API size_t gguf_get_data_offset(const struct gguf_context * ctx); // padded to gguf_get_alignment if and only if the gguf_context contains at least one tensor
|
||||
|
||||
GGML_API int64_t gguf_get_n_kv(const struct gguf_context * ctx);
|
||||
GGML_API int64_t gguf_find_key(const struct gguf_context * ctx, const char * key); // returns -1 if key is not found
|
||||
|
||||
@@ -150,7 +150,7 @@ static void ggml_dyn_tallocr_insert_block(struct tallocr_chunk * chunk, size_t o
|
||||
|
||||
static void ggml_dyn_tallocr_remove_block(struct tallocr_chunk * chunk, int idx) {
|
||||
// shift all elements after idx by 1 to the left, overwriting the element at idx
|
||||
for (int i = idx; i < chunk->n_free_blocks; i++) {
|
||||
for (int i = idx; i < chunk->n_free_blocks - 1; i++) {
|
||||
chunk->free_blocks[i] = chunk->free_blocks[i+1];
|
||||
}
|
||||
chunk->n_free_blocks--;
|
||||
|
||||
@@ -13,6 +13,7 @@
|
||||
#include <cstring>
|
||||
#include <map>
|
||||
#include <memory>
|
||||
#include <set>
|
||||
#include <string>
|
||||
#include <tuple>
|
||||
#include <utility>
|
||||
@@ -392,64 +393,100 @@ static ggml_backend_buffer_type_t ggml_backend_meta_device_get_host_buffer_type(
|
||||
// meta backend buffer
|
||||
//
|
||||
|
||||
// Container to hold the tensor slices per simple ggml backend buffer.
|
||||
struct ggml_backend_meta_simple_tensor_container {
|
||||
std::vector<ggml_context_ptr> ctxs;
|
||||
std::map<const ggml_tensor *, std::vector<ggml_tensor *>> simple_tensors;
|
||||
|
||||
ggml_backend_meta_simple_tensor_container(const ggml_init_params & params, const int n_simple) {
|
||||
ctxs.reserve(n_simple);
|
||||
for (int i = 0; i < n_simple; i++) {
|
||||
ctxs.emplace_back(ggml_init(params));
|
||||
}
|
||||
}
|
||||
ggml_backend_meta_simple_tensor_container() {}
|
||||
};
|
||||
|
||||
struct ggml_backend_meta_buffer_context {
|
||||
// FIXME
|
||||
// Most tensors can simply be stored statically in their own buffer.
|
||||
// Externally created views however also need a mapping to simple tensors but they use the buffer of the view source.
|
||||
// If external views are simply using that buffer they will slowly deplete its memory.
|
||||
// Current solution: rotating set of 2 "compute" containers to hold external views, works correctly for llama.cpp.
|
||||
// Long-term: tie the lifetime of external views to the meta backend executing the graph instead,
|
||||
// currently not possible due to graph-external operations in the backend scheduler.
|
||||
ggml_backend_meta_simple_tensor_container stc_static;
|
||||
ggml_backend_meta_simple_tensor_container stc_compute[2];
|
||||
int stc_compute_index = 0;
|
||||
int stc_compute_index_next = 0;
|
||||
std::vector<ggml_backend_buffer_ptr> bufs;
|
||||
|
||||
// FIXME
|
||||
// The size of the split state cache is unbounded and can theoretically grow infinitely large.
|
||||
// However, it is also expensive to build and clearing it on every rebuild in ggml_backend_meta_graph_compute is too expensive.
|
||||
static constexpr size_t nbtc = GGML_TENSOR_SIZE - sizeof(ggml_tensor::padding);
|
||||
|
||||
std::map<std::pair<const ggml_tensor *, bool>, std::pair<ggml_backend_meta_split_state, char[nbtc]>> split_state_cache;
|
||||
std::map< const ggml_tensor *, std::vector<ggml_tensor *>> simple_tensors;
|
||||
|
||||
struct buffer_config {
|
||||
ggml_context * ctx;
|
||||
ggml_backend_buffer_t buf;
|
||||
|
||||
buffer_config(ggml_context * ctx, ggml_backend_buffer_t buf) : ctx(ctx), buf(buf) {}
|
||||
};
|
||||
std::vector<buffer_config> buf_configs;
|
||||
|
||||
int debug;
|
||||
|
||||
ggml_backend_meta_buffer_context() {
|
||||
ggml_backend_meta_buffer_context(
|
||||
ggml_backend_meta_simple_tensor_container & stc_static,
|
||||
ggml_backend_meta_simple_tensor_container & stc_compute_0,
|
||||
ggml_backend_meta_simple_tensor_container & stc_compute_1,
|
||||
const std::vector<ggml_backend_buffer_t> & bufs)
|
||||
: stc_static(std::move(stc_static)), stc_compute{std::move(stc_compute_0), std::move(stc_compute_1)} {
|
||||
this->bufs.reserve(bufs.size());
|
||||
for (ggml_backend_buffer_t buf : bufs) {
|
||||
this->bufs.emplace_back(buf);
|
||||
}
|
||||
const char * GGML_META_DEBUG = getenv("GGML_META_DEBUG");
|
||||
debug = GGML_META_DEBUG ? atoi(GGML_META_DEBUG) : 0;
|
||||
}
|
||||
|
||||
ggml_backend_meta_simple_tensor_container & get_simple_tensor_container(const ggml_tensor * tensor) {
|
||||
if (stc_static.simple_tensors.find(tensor) != stc_static.simple_tensors.end()) {
|
||||
return stc_static;
|
||||
}
|
||||
return stc_compute[stc_compute_index];
|
||||
}
|
||||
};
|
||||
|
||||
static void ggml_backend_meta_buffer_free_buffer(ggml_backend_buffer_t buffer) {
|
||||
GGML_ASSERT(ggml_backend_buffer_is_meta(buffer));
|
||||
ggml_backend_meta_buffer_context * buf_ctx = (ggml_backend_meta_buffer_context *) buffer->context;
|
||||
for (auto & [ctx, buf] : buf_ctx->buf_configs) {
|
||||
ggml_backend_buffer_free(buf);
|
||||
ggml_free(ctx);
|
||||
}
|
||||
delete buf_ctx;
|
||||
}
|
||||
|
||||
static size_t ggml_backend_meta_buffer_n_bufs(ggml_backend_buffer_t meta_buf) {
|
||||
GGML_ASSERT(ggml_backend_buffer_is_meta(meta_buf));
|
||||
ggml_backend_meta_buffer_context * buf_ctx = (ggml_backend_meta_buffer_context *) meta_buf->context;
|
||||
return buf_ctx->buf_configs.size();
|
||||
return buf_ctx->bufs.size();
|
||||
}
|
||||
|
||||
static ggml_backend_buffer_t ggml_backend_meta_buffer_simple_buffer(ggml_backend_buffer_t meta_buf, size_t index) {
|
||||
GGML_ASSERT(ggml_backend_buffer_is_meta(meta_buf));
|
||||
ggml_backend_meta_buffer_context * buf_ctx = (ggml_backend_meta_buffer_context *) meta_buf->context;
|
||||
GGML_ASSERT(index < buf_ctx->buf_configs.size());
|
||||
return buf_ctx->buf_configs[index].buf;
|
||||
GGML_ASSERT(index < buf_ctx->bufs.size());
|
||||
return buf_ctx->bufs[index].get();
|
||||
}
|
||||
|
||||
static struct ggml_tensor * ggml_backend_meta_buffer_simple_tensor(const struct ggml_tensor * tensor, size_t index) {
|
||||
GGML_ASSERT(ggml_backend_buffer_is_meta(tensor->buffer));
|
||||
ggml_backend_meta_buffer_context * buf_ctx = (ggml_backend_meta_buffer_context *) tensor->buffer->context;
|
||||
GGML_ASSERT(index < buf_ctx->buf_configs.size());
|
||||
GGML_ASSERT(index < buf_ctx->bufs.size());
|
||||
|
||||
auto it = buf_ctx->simple_tensors.find(tensor);
|
||||
if (it == buf_ctx->simple_tensors.end()) {
|
||||
ggml_backend_meta_simple_tensor_container & stc = buf_ctx->get_simple_tensor_container(tensor);
|
||||
auto it = stc.simple_tensors.find(tensor);
|
||||
if (it == stc.simple_tensors.end()) {
|
||||
return nullptr;
|
||||
}
|
||||
return it->second[index];
|
||||
}
|
||||
|
||||
static struct ggml_backend_meta_split_state ggml_backend_meta_get_split_state(const struct ggml_tensor * tensor, bool assume_sync) {
|
||||
static struct ggml_backend_meta_split_state ggml_backend_meta_get_split_state(const struct ggml_tensor * tensor, bool assume_sync);
|
||||
|
||||
static struct ggml_backend_meta_split_state ggml_backend_meta_get_split_state(
|
||||
ggml_backend_meta_simple_tensor_container & stc, const struct ggml_tensor * tensor, bool assume_sync) {
|
||||
const size_t n_bufs = ggml_backend_meta_buffer_n_bufs(tensor->buffer);
|
||||
ggml_backend_meta_buffer_context * buf_ctx = (ggml_backend_meta_buffer_context *) tensor->buffer->context;
|
||||
|
||||
@@ -785,7 +822,7 @@ static struct ggml_backend_meta_split_state ggml_backend_meta_get_split_state(co
|
||||
src_ss[i] = {GGML_BACKEND_SPLIT_AXIS_UNKNOWN, {0}, 1};
|
||||
continue;
|
||||
}
|
||||
src_ss[i] = ggml_backend_meta_get_split_state(tensor->src[i], /*assume_sync =*/ true);
|
||||
src_ss[i] = ggml_backend_meta_get_split_state(stc, tensor->src[i], /*assume_sync =*/ true);
|
||||
GGML_ASSERT(src_ss[i].axis != GGML_BACKEND_SPLIT_AXIS_UNKNOWN);
|
||||
}
|
||||
|
||||
@@ -1079,17 +1116,23 @@ static struct ggml_backend_meta_split_state ggml_backend_meta_get_split_state(co
|
||||
return ret;
|
||||
}
|
||||
|
||||
static struct ggml_backend_meta_split_state ggml_backend_meta_get_split_state(const struct ggml_tensor * tensor, bool assume_sync) {
|
||||
GGML_ASSERT(ggml_backend_buffer_is_meta(tensor->buffer));
|
||||
ggml_backend_meta_buffer_context * buf_ctx = (ggml_backend_meta_buffer_context *) tensor->buffer->context;
|
||||
return ggml_backend_meta_get_split_state(buf_ctx->get_simple_tensor_container(tensor), tensor, assume_sync);
|
||||
}
|
||||
|
||||
static void * ggml_backend_meta_buffer_get_base(ggml_backend_buffer_t buffer) {
|
||||
GGML_UNUSED(buffer);
|
||||
return (void *) 0x1000000000000000; // FIXME
|
||||
}
|
||||
|
||||
static enum ggml_status ggml_backend_meta_buffer_init_tensor(ggml_backend_buffer_t buffer, ggml_tensor * tensor) {
|
||||
GGML_ASSERT(ggml_backend_buffer_is_meta(buffer));
|
||||
ggml_backend_meta_buffer_context * buf_ctx = (ggml_backend_meta_buffer_context *) buffer->context;
|
||||
const size_t n_simple_bufs = ggml_backend_meta_buffer_n_bufs(buffer);
|
||||
static enum ggml_status ggml_backend_meta_buffer_init_tensor_impl(ggml_backend_meta_simple_tensor_container & stc, ggml_tensor * tensor) {
|
||||
GGML_ASSERT(ggml_backend_buffer_is_meta(tensor->buffer));
|
||||
ggml_backend_meta_buffer_context * buf_ctx = (ggml_backend_meta_buffer_context *) tensor->buffer->context;
|
||||
const size_t n_simple_bufs = ggml_backend_meta_buffer_n_bufs(tensor->buffer);
|
||||
|
||||
const ggml_backend_meta_split_state split_state = ggml_backend_meta_get_split_state(tensor, /*assume_sync =*/ true);
|
||||
const ggml_backend_meta_split_state split_state = ggml_backend_meta_get_split_state(stc, tensor, /*assume_sync =*/ true);
|
||||
GGML_ASSERT(ggml_nelements(tensor) == 0 || split_state.axis != GGML_BACKEND_SPLIT_AXIS_UNKNOWN);
|
||||
GGML_ASSERT(split_state.n_segments <= 16);
|
||||
|
||||
@@ -1104,8 +1147,8 @@ static enum ggml_status ggml_backend_meta_buffer_init_tensor(ggml_backend_buffer
|
||||
std::vector<ggml_tensor *> simple_tensors;
|
||||
simple_tensors.reserve(n_simple_bufs);
|
||||
for (size_t j = 0; j < n_simple_bufs; j++) {
|
||||
ggml_context * simple_ctx = buf_ctx->buf_configs[j].ctx;
|
||||
ggml_backend_buffer_t simple_buf = buf_ctx->buf_configs[j].buf;
|
||||
ggml_context * simple_ctx = stc.ctxs[j].get();
|
||||
ggml_backend_buffer_t simple_buf = buf_ctx->bufs[j].get();
|
||||
|
||||
if (split_dim >= 0 && split_dim < GGML_MAX_DIMS) {
|
||||
// TODO: the following assert fails for llama-parallel even though the results are correct:
|
||||
@@ -1158,7 +1201,7 @@ static enum ggml_status ggml_backend_meta_buffer_init_tensor(ggml_backend_buffer
|
||||
t_ij->data = (char *) t_ij->view_src->data + t_ij->view_offs;
|
||||
} else if (simple_buf != nullptr) {
|
||||
t_ij->data = (char *) ggml_backend_buffer_get_base(simple_buf)
|
||||
+ size_t(tensor->data) - size_t(ggml_backend_buffer_get_base(buffer));
|
||||
+ size_t(tensor->data) - size_t(ggml_backend_buffer_get_base(tensor->buffer));
|
||||
}
|
||||
t_ij->extra = tensor->extra;
|
||||
for (int i = 0; i < GGML_MAX_SRC; i++) {
|
||||
@@ -1194,11 +1237,18 @@ static enum ggml_status ggml_backend_meta_buffer_init_tensor(ggml_backend_buffer
|
||||
}
|
||||
}
|
||||
|
||||
buf_ctx->simple_tensors[tensor] = simple_tensors;
|
||||
stc.simple_tensors[tensor] = simple_tensors;
|
||||
|
||||
return GGML_STATUS_SUCCESS;
|
||||
}
|
||||
|
||||
static enum ggml_status ggml_backend_meta_buffer_init_tensor(ggml_backend_buffer_t buffer, ggml_tensor * tensor) {
|
||||
GGML_ASSERT(ggml_backend_buffer_is_meta(buffer));
|
||||
ggml_backend_meta_buffer_context * buf_ctx = (ggml_backend_meta_buffer_context *) buffer->context;
|
||||
buf_ctx->stc_compute_index = buf_ctx->stc_compute_index_next;
|
||||
return ggml_backend_meta_buffer_init_tensor_impl(buf_ctx->get_simple_tensor_container(tensor), tensor);
|
||||
}
|
||||
|
||||
static void ggml_backend_meta_buffer_set_tensor(ggml_backend_buffer_t buffer, ggml_tensor * tensor, const void * data, size_t offset, size_t size) {
|
||||
const size_t n_bufs = ggml_backend_meta_buffer_n_bufs(buffer);
|
||||
GGML_ASSERT(ggml_is_contiguous(tensor));
|
||||
@@ -1413,8 +1463,9 @@ static void ggml_backend_meta_buffer_clear(ggml_backend_buffer_t buffer, uint8_t
|
||||
}
|
||||
|
||||
static void ggml_backend_meta_buffer_reset(ggml_backend_buffer_t buffer) {
|
||||
const size_t n_buffers = ggml_backend_meta_buffer_n_bufs(buffer);
|
||||
for (size_t i = 0; i < n_buffers; i++) {
|
||||
GGML_ASSERT(ggml_backend_buffer_is_meta(buffer));
|
||||
ggml_backend_meta_buffer_context * buf_ctx = (ggml_backend_meta_buffer_context *) buffer->context;
|
||||
for (size_t i = 0; i < buf_ctx->bufs.size(); i++) {
|
||||
ggml_backend_buffer_reset(ggml_backend_meta_buffer_simple_buffer(buffer, i));
|
||||
}
|
||||
}
|
||||
@@ -1440,21 +1491,24 @@ bool ggml_backend_buffer_is_meta(ggml_backend_buffer_t buf) {
|
||||
static ggml_backend_buffer_t ggml_backend_meta_buffer_type_alloc_buffer(ggml_backend_buffer_type_t buft, size_t size) {
|
||||
const size_t n_simple_bufts = ggml_backend_meta_buft_n_bufts(buft);
|
||||
|
||||
ggml_init_params params = {
|
||||
/*.mem_size =*/ 1024*1024*1024, // FIXME
|
||||
const ggml_init_params params = {
|
||||
/*.mem_size =*/ 1024*1024*ggml_tensor_overhead(), // FIXME
|
||||
/*.mem_buffer =*/ nullptr,
|
||||
/*.no_alloc =*/ true,
|
||||
};
|
||||
ggml_backend_meta_simple_tensor_container stc_static;
|
||||
ggml_backend_meta_simple_tensor_container stc_compute_0(params, n_simple_bufts);
|
||||
ggml_backend_meta_simple_tensor_container stc_compute_1(params, n_simple_bufts);
|
||||
|
||||
ggml_backend_meta_buffer_context * buf_ctx = new ggml_backend_meta_buffer_context();
|
||||
size_t max_size = 0;
|
||||
buf_ctx->buf_configs.reserve(n_simple_bufts);
|
||||
std::vector<ggml_backend_buffer_t> bufs;
|
||||
bufs.reserve(n_simple_bufts);
|
||||
for (size_t i = 0; i < n_simple_bufts; i++) {
|
||||
ggml_backend_buffer_t simple_buf = ggml_backend_buft_alloc_buffer(ggml_backend_meta_buft_simple_buft(buft, i), size);
|
||||
GGML_ASSERT(simple_buf != nullptr);
|
||||
max_size = std::max(max_size, ggml_backend_buffer_get_size(simple_buf));
|
||||
buf_ctx->buf_configs.emplace_back(ggml_init(params), simple_buf);
|
||||
bufs.push_back(ggml_backend_buft_alloc_buffer(ggml_backend_meta_buft_simple_buft(buft, i), size));
|
||||
GGML_ASSERT(bufs.back() != nullptr);
|
||||
max_size = std::max(max_size, ggml_backend_buffer_get_size(bufs.back()));
|
||||
}
|
||||
ggml_backend_meta_buffer_context * buf_ctx = new ggml_backend_meta_buffer_context(stc_static, stc_compute_0, stc_compute_1, bufs);
|
||||
|
||||
return ggml_backend_buffer_init(buft, ggml_backend_meta_buffer_iface, buf_ctx, max_size);
|
||||
}
|
||||
@@ -1462,26 +1516,32 @@ static ggml_backend_buffer_t ggml_backend_meta_buffer_type_alloc_buffer(ggml_bac
|
||||
struct ggml_backend_buffer * ggml_backend_meta_alloc_ctx_tensors_from_buft(struct ggml_context * ctx, ggml_backend_buffer_type_t buft) {
|
||||
const size_t n_simple_bufts = ggml_backend_meta_buft_n_bufts(buft);
|
||||
|
||||
ggml_init_params params = {
|
||||
/*.mem_size =*/ 1024*1024*1024, // FIXME
|
||||
constexpr size_t compute_headroom = 16; // Maximum number of views per statically allocated tensor that can be created between evals.
|
||||
const ggml_init_params params_static = {
|
||||
/*.mem_size =*/ ggml_get_mem_size(ctx),
|
||||
/*.mem_buffer =*/ nullptr,
|
||||
/*.no_alloc =*/ true,
|
||||
};
|
||||
const ggml_init_params params_compute = {
|
||||
/*.mem_size =*/ compute_headroom*ggml_get_mem_size(ctx),
|
||||
/*.mem_buffer =*/ nullptr,
|
||||
/*.no_alloc =*/ true,
|
||||
};
|
||||
ggml_backend_meta_simple_tensor_container stc_static (params_static, n_simple_bufts);
|
||||
ggml_backend_meta_simple_tensor_container stc_compute_0(params_compute, n_simple_bufts);
|
||||
ggml_backend_meta_simple_tensor_container stc_compute_1(params_compute, n_simple_bufts);
|
||||
|
||||
ggml_backend_meta_buffer_context * meta_buf_ctx = new ggml_backend_meta_buffer_context();
|
||||
meta_buf_ctx->buf_configs.reserve(n_simple_bufts);
|
||||
for (size_t i = 0; i < n_simple_bufts; i++) {
|
||||
meta_buf_ctx->buf_configs.emplace_back(ggml_init(params), nullptr);
|
||||
}
|
||||
std::vector<ggml_backend_buffer_t> bufs(n_simple_bufts, nullptr);
|
||||
ggml_backend_meta_buffer_context * meta_buf_ctx = new ggml_backend_meta_buffer_context(stc_static, stc_compute_0, stc_compute_1, bufs);
|
||||
|
||||
ggml_backend_buffer_t meta_buf = ggml_backend_buffer_init(buft, ggml_backend_meta_buffer_iface, meta_buf_ctx, 0);
|
||||
for (ggml_tensor * t = ggml_get_first_tensor(ctx); t != nullptr; t = ggml_get_next_tensor(ctx, t)) {
|
||||
t->buffer = meta_buf;
|
||||
ggml_backend_meta_buffer_init_tensor(meta_buf, t);
|
||||
ggml_backend_meta_buffer_init_tensor_impl(meta_buf_ctx->stc_static, t);
|
||||
t->data = (void *) 0x2000000000000000; // FIXME
|
||||
}
|
||||
for (size_t i = 0; i < n_simple_bufts; i++) {
|
||||
ggml_context * ctx = meta_buf_ctx->buf_configs[i].ctx;
|
||||
ggml_context * ctx = meta_buf_ctx->stc_static.ctxs[i].get();
|
||||
ggml_backend_buffer_type_t simple_buft = ggml_backend_meta_buft_simple_buft(buft, i);
|
||||
|
||||
// If a ggml_context only has zero-sized tensors, ggml_backend_alloc_ctx_tensors_from_buft returns NULL.
|
||||
@@ -1494,15 +1554,15 @@ struct ggml_backend_buffer * ggml_backend_meta_alloc_ctx_tensors_from_buft(struc
|
||||
}
|
||||
}
|
||||
if (any_nonzero_slice) {
|
||||
meta_buf_ctx->buf_configs[i].buf = ggml_backend_alloc_ctx_tensors_from_buft(ctx, simple_buft);
|
||||
meta_buf_ctx->bufs[i].reset(ggml_backend_alloc_ctx_tensors_from_buft(ctx, simple_buft));
|
||||
} else {
|
||||
meta_buf_ctx->buf_configs[i].buf = ggml_backend_buft_alloc_buffer(simple_buft, 0);
|
||||
meta_buf_ctx->bufs[i].reset(ggml_backend_buft_alloc_buffer(simple_buft, 0));
|
||||
for (ggml_tensor * t = ggml_get_first_tensor(ctx); t != nullptr; t = ggml_get_next_tensor(ctx, t)) {
|
||||
t->buffer = meta_buf_ctx->buf_configs[i].buf;
|
||||
t->buffer = meta_buf_ctx->bufs[i].get();
|
||||
}
|
||||
}
|
||||
GGML_ASSERT(meta_buf_ctx->buf_configs[i].buf != nullptr);
|
||||
meta_buf->size = std::max(meta_buf->size, ggml_backend_buffer_get_size(meta_buf_ctx->buf_configs[i].buf));
|
||||
GGML_ASSERT(meta_buf_ctx->bufs[i]);
|
||||
meta_buf->size = std::max(meta_buf->size, ggml_backend_buffer_get_size(meta_buf_ctx->bufs[i].get()));
|
||||
}
|
||||
return meta_buf;
|
||||
}
|
||||
@@ -1724,6 +1784,26 @@ static enum ggml_status ggml_backend_meta_graph_compute(ggml_backend_t backend,
|
||||
}
|
||||
|
||||
if (needs_rebuild) {
|
||||
std::set<ggml_backend_buffer_t> used_buffers;
|
||||
for (int i = 0; i < cgraph->n_leafs; i++) {
|
||||
if (ggml_backend_buffer_is_meta(cgraph->leafs[i]->buffer)) {
|
||||
used_buffers.emplace(cgraph->leafs[i]->buffer);
|
||||
}
|
||||
}
|
||||
for (int i = 0; i < cgraph->n_nodes; i++) {
|
||||
if (ggml_backend_buffer_is_meta(cgraph->nodes[i]->buffer)) {
|
||||
used_buffers.emplace(cgraph->nodes[i]->buffer);
|
||||
}
|
||||
}
|
||||
for (ggml_backend_buffer_t buf : used_buffers) {
|
||||
ggml_backend_meta_buffer_context * buf_ctx = (ggml_backend_meta_buffer_context *) buf->context;
|
||||
buf_ctx->stc_compute_index_next = buf_ctx->stc_compute_index ^ 1;
|
||||
ggml_backend_meta_simple_tensor_container & stc = buf_ctx->stc_compute[buf_ctx->stc_compute_index_next];
|
||||
for (ggml_context_ptr & ctx : stc.ctxs) {
|
||||
ggml_reset(ctx.get());
|
||||
}
|
||||
stc.simple_tensors.clear();
|
||||
}
|
||||
size_t n_subgraphs = 0;
|
||||
size_t max_tmp_size = 0;
|
||||
|
||||
@@ -1909,7 +1989,7 @@ static enum ggml_status ggml_backend_meta_graph_compute(ggml_backend_t backend,
|
||||
const size_t mem_per_device_graphs_main = backend_ctx->max_subgraphs*ggml_graph_overhead_custom(backend_ctx->max_nnodes, cgraph->grads);
|
||||
const size_t mem_per_device_graphs_aux = n_cgraphs_per_device*backend_ctx->max_subgraphs*ggml_graph_overhead_custom(1, cgraph->grads);
|
||||
const size_t mem_per_device_nodes_aux = n_nodes_per_device*backend_ctx->max_subgraphs*ggml_tensor_overhead();
|
||||
ggml_init_params params = {
|
||||
const ggml_init_params params = {
|
||||
/*.mem_size =*/ n_backends * (mem_per_device_graphs_main + mem_per_device_graphs_aux + mem_per_device_nodes_aux),
|
||||
/*.mem_buffer =*/ nullptr,
|
||||
/*.no_alloc =*/ true,
|
||||
|
||||
108
ggml/src/ggml-cuda/fwht.cu
Normal file
108
ggml/src/ggml-cuda/fwht.cu
Normal file
@@ -0,0 +1,108 @@
|
||||
#include "common.cuh"
|
||||
#include "fwht.cuh"
|
||||
|
||||
template <int N>
|
||||
__launch_bounds__(4*ggml_cuda_get_physical_warp_size(), 1)
|
||||
__global__ void fwht_cuda(const float * src, float * dst, const int64_t n_rows, const float scale) {
|
||||
constexpr int warp_size = ggml_cuda_get_physical_warp_size();
|
||||
|
||||
const int64_t r = (int64_t) blockIdx.x * blockDim.y + threadIdx.y;
|
||||
|
||||
if (r >= n_rows) {
|
||||
return;
|
||||
}
|
||||
|
||||
src += r * N;
|
||||
dst += r * N;
|
||||
|
||||
static constexpr int el_w = N / warp_size;
|
||||
float reg[el_w];
|
||||
const int lane = threadIdx.x;
|
||||
|
||||
#pragma unroll
|
||||
for (int i = 0; i < el_w; ++i) {
|
||||
reg[i] = src[i * warp_size + lane] * scale;
|
||||
}
|
||||
|
||||
#pragma unroll
|
||||
for (int h = 1; h < warp_size; h *= 2) {
|
||||
#pragma unroll
|
||||
for (int j = 0; j < el_w; j++) {
|
||||
const float val = reg[j];
|
||||
const float val2 = __shfl_xor_sync(0xFFFFFFFF, val, h, warp_size);
|
||||
|
||||
reg[j] = (lane & h) == 0 ? val + val2 : val2 - val;
|
||||
}
|
||||
}
|
||||
|
||||
#pragma unroll
|
||||
for (int h = warp_size; h < N; h *= 2) {
|
||||
const int step = h / warp_size;
|
||||
#pragma unroll
|
||||
for (int j = 0; j < el_w; j += 2 * step) {
|
||||
#pragma unroll
|
||||
for (int k = 0; k < step; k++) {
|
||||
const float x = reg[j + k];
|
||||
const float y = reg[j + k + step];
|
||||
|
||||
reg[j + k] = x + y;
|
||||
reg[j + k + step] = x - y;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
#pragma unroll
|
||||
for (int i = 0; i < el_w; ++i) {
|
||||
dst[i * warp_size + lane] = reg[i];
|
||||
}
|
||||
}
|
||||
|
||||
void ggml_cuda_op_fwht(ggml_backend_cuda_context & ctx, const ggml_tensor * src, ggml_tensor * dst) {
|
||||
GGML_ASSERT(ggml_are_same_shape(src, dst));
|
||||
GGML_ASSERT(ggml_is_contiguous(src));
|
||||
GGML_ASSERT(ggml_is_contiguous(dst));
|
||||
const int n = src->ne[0];
|
||||
const int64_t rows = ggml_nrows(src);
|
||||
|
||||
const float * src_d = (const float *) src->data;
|
||||
float * dst_d = (float *) dst->data;
|
||||
|
||||
const int warp_size = ggml_cuda_info().devices[ggml_cuda_get_device()].warp_size;
|
||||
GGML_ASSERT(n % warp_size == 0);
|
||||
const int rows_per_block = 4;
|
||||
|
||||
const int64_t num_blocks = (rows + rows_per_block - 1) / rows_per_block;
|
||||
|
||||
cudaStream_t stream = ctx.stream();
|
||||
dim3 grid_dims(num_blocks, 1, 1);
|
||||
dim3 block_dims(warp_size, rows_per_block, 1);
|
||||
const ggml_cuda_kernel_launch_params launch_params =
|
||||
ggml_cuda_kernel_launch_params(grid_dims, block_dims, 0, stream);
|
||||
|
||||
const float scale = 1 / sqrtf(n);
|
||||
|
||||
switch (n) {
|
||||
case 64:
|
||||
{
|
||||
ggml_cuda_kernel_launch(fwht_cuda<64>, launch_params, src_d, dst_d, rows, scale);
|
||||
break;
|
||||
}
|
||||
case 128:
|
||||
{
|
||||
ggml_cuda_kernel_launch(fwht_cuda<128>, launch_params, src_d, dst_d, rows, scale);
|
||||
break;
|
||||
}
|
||||
case 256:
|
||||
{
|
||||
ggml_cuda_kernel_launch(fwht_cuda<256>, launch_params, src_d, dst_d, rows, scale);
|
||||
break;
|
||||
}
|
||||
case 512:
|
||||
{
|
||||
ggml_cuda_kernel_launch(fwht_cuda<512>, launch_params, src_d, dst_d, rows, scale);
|
||||
break;
|
||||
}
|
||||
default:
|
||||
GGML_ABORT("fatal error");
|
||||
}
|
||||
}
|
||||
3
ggml/src/ggml-cuda/fwht.cuh
Normal file
3
ggml/src/ggml-cuda/fwht.cuh
Normal file
@@ -0,0 +1,3 @@
|
||||
#include "common.cuh"
|
||||
|
||||
void ggml_cuda_op_fwht(ggml_backend_cuda_context & ctx, const ggml_tensor * src, ggml_tensor * dst);
|
||||
@@ -24,6 +24,7 @@
|
||||
#include "ggml-cuda/diagmask.cuh"
|
||||
#include "ggml-cuda/diag.cuh"
|
||||
#include "ggml-cuda/fattn.cuh"
|
||||
#include "ggml-cuda/fwht.cuh"
|
||||
#include "ggml-cuda/getrows.cuh"
|
||||
#include "ggml-cuda/im2col.cuh"
|
||||
#include "ggml-cuda/mmf.cuh"
|
||||
@@ -2594,6 +2595,13 @@ static void ggml_cuda_mul_mat(ggml_backend_cuda_context & ctx, const ggml_tensor
|
||||
bool use_batched_cublas_bf16 = src0->type == GGML_TYPE_BF16 && bf16_mma_hardware_available(cc);
|
||||
bool use_batched_cublas_f32 = src0->type == GGML_TYPE_F32;
|
||||
|
||||
const int32_t hint = ggml_get_op_params_i32(dst, 1);
|
||||
if (hint == GGML_HINT_SRC0_IS_HADAMARD) {
|
||||
GGML_ASSERT(!split);
|
||||
ggml_cuda_op_fwht(ctx, src1, dst);
|
||||
return;
|
||||
}
|
||||
|
||||
if (!split && use_mul_mat_vec_f) {
|
||||
// the custom F16 vector kernel can be used over batched cuBLAS GEMM
|
||||
// but this is only faster for GPUs without tensor cores or with a thin src0 matrix (particularly KQV in attention)
|
||||
|
||||
@@ -228,9 +228,18 @@ struct gguf_context {
|
||||
};
|
||||
|
||||
struct gguf_reader {
|
||||
gguf_reader(FILE * file) : file(file) {
|
||||
// read the remaining bytes once and update on each read
|
||||
nbytes_remain = file_remain(file);
|
||||
gguf_reader(
|
||||
gguf_reader_callback_t callback,
|
||||
void * userdata,
|
||||
size_t max_chunk_read,
|
||||
uint64_t data_offset = 0,
|
||||
uint64_t nbytes_remain = 0)
|
||||
: callback(callback),
|
||||
userdata(userdata),
|
||||
max_chunk_read(max_chunk_read),
|
||||
data_offset(data_offset),
|
||||
nbytes_remain(nbytes_remain) {
|
||||
GGML_ASSERT(max_chunk_read > 0);
|
||||
}
|
||||
|
||||
// helper for remaining bytes in a file
|
||||
@@ -257,12 +266,10 @@ struct gguf_reader {
|
||||
template <typename T>
|
||||
bool read(T & dst) const {
|
||||
const size_t size = sizeof(dst);
|
||||
if (nbytes_remain < size) {
|
||||
if (size > nbytes_remain) {
|
||||
return false;
|
||||
}
|
||||
const size_t nread = fread(&dst, 1, size, file);
|
||||
nbytes_remain -= nread;
|
||||
return nread == size;
|
||||
return read_raw(&dst, size) == size;
|
||||
}
|
||||
|
||||
template <typename T>
|
||||
@@ -344,24 +351,71 @@ struct gguf_reader {
|
||||
return false;
|
||||
}
|
||||
dst.resize(static_cast<size_t>(size));
|
||||
const size_t nread = fread(dst.data(), 1, size, file);
|
||||
nbytes_remain -= nread;
|
||||
return nread == size;
|
||||
return read_raw(dst.data(), static_cast<size_t>(size)) == size;
|
||||
}
|
||||
|
||||
bool read(void * dst, const size_t size) const {
|
||||
if (size > nbytes_remain) {
|
||||
return false;
|
||||
}
|
||||
const size_t nread = fread(dst, 1, size, file);
|
||||
nbytes_remain -= nread;
|
||||
return nread == size;
|
||||
return read_raw(dst, size) == size;
|
||||
}
|
||||
|
||||
uint64_t tell() const {
|
||||
return data_offset;
|
||||
}
|
||||
|
||||
bool seek(uint64_t absolute_offset) const {
|
||||
const uint64_t end_offset = uint64_t(data_offset) + nbytes_remain;
|
||||
if (absolute_offset > end_offset) {
|
||||
return false;
|
||||
}
|
||||
|
||||
data_offset = absolute_offset;
|
||||
nbytes_remain = end_offset - absolute_offset;
|
||||
|
||||
return true;
|
||||
}
|
||||
|
||||
private:
|
||||
FILE * file;
|
||||
size_t read_raw(void * dst, size_t size) const {
|
||||
if (callback == nullptr || size == 0) {
|
||||
return 0;
|
||||
}
|
||||
|
||||
mutable uint64_t nbytes_remain;
|
||||
uint8_t * data = static_cast<uint8_t *>(dst);
|
||||
size_t total_nread = 0;
|
||||
bool reached_eof = false;
|
||||
|
||||
while (total_nread < size) {
|
||||
const size_t chunk_size = std::min(max_chunk_read, size - total_nread);
|
||||
if (data_offset + total_nread < data_offset) {
|
||||
break;
|
||||
}
|
||||
const size_t nread = callback(userdata, static_cast<void *>(data + total_nread), data_offset + total_nread, chunk_size);
|
||||
total_nread += nread;
|
||||
if (nread != chunk_size) {
|
||||
reached_eof = true;
|
||||
break;
|
||||
}
|
||||
}
|
||||
|
||||
data_offset += total_nread;
|
||||
GGML_ASSERT(total_nread <= nbytes_remain);
|
||||
nbytes_remain -= total_nread;
|
||||
|
||||
if (reached_eof) {
|
||||
nbytes_remain = 0;
|
||||
}
|
||||
|
||||
return total_nread;
|
||||
}
|
||||
|
||||
gguf_reader_callback_t callback = nullptr;
|
||||
void * userdata = nullptr;
|
||||
size_t max_chunk_read = 0;
|
||||
mutable uint64_t data_offset = 0;
|
||||
mutable uint64_t nbytes_remain = 0;
|
||||
};
|
||||
|
||||
struct gguf_context * gguf_init_empty(void) {
|
||||
@@ -394,12 +448,7 @@ bool gguf_read_emplace_helper(const struct gguf_reader & gr, std::vector<struct
|
||||
return true;
|
||||
}
|
||||
|
||||
struct gguf_context * gguf_init_from_file_ptr(FILE * file, struct gguf_init_params params) {
|
||||
if (!file) {
|
||||
return nullptr;
|
||||
}
|
||||
|
||||
const struct gguf_reader gr(file);
|
||||
static struct gguf_context * gguf_init_from_reader(const struct gguf_reader & gr, struct gguf_init_params params) {
|
||||
struct gguf_context * ctx = new gguf_context;
|
||||
|
||||
bool ok = true;
|
||||
@@ -700,14 +749,14 @@ struct gguf_context * gguf_init_from_file_ptr(FILE * file, struct gguf_init_para
|
||||
GGML_ASSERT(int64_t(ctx->info.size()) == n_tensors);
|
||||
|
||||
// we require the data section to be aligned, so take into account any padding
|
||||
if (gguf_fseek(file, GGML_PAD(gguf_ftell(file), ctx->alignment), SEEK_SET) != 0) {
|
||||
if (n_tensors > 0 && !gr.seek(GGML_PAD(gr.tell(), ctx->alignment))) {
|
||||
GGML_LOG_ERROR("%s: failed to seek to beginning of data section\n", __func__);
|
||||
gguf_free(ctx);
|
||||
return nullptr;
|
||||
}
|
||||
|
||||
// store the current file offset - this is where the data section starts
|
||||
ctx->offset = gguf_ftell(file);
|
||||
ctx->offset = gr.tell();
|
||||
|
||||
// compute the total size of the data section, taking into account the alignment
|
||||
{
|
||||
@@ -844,6 +893,89 @@ struct gguf_context * gguf_init_from_file_ptr(FILE * file, struct gguf_init_para
|
||||
return ctx;
|
||||
}
|
||||
|
||||
struct gguf_context * gguf_init_from_callback(gguf_reader_callback_t callback, void * userdata, size_t max_chunk_read, uint64_t max_expected_size, struct gguf_init_params params) {
|
||||
if (callback == nullptr) {
|
||||
return nullptr;
|
||||
}
|
||||
|
||||
const struct gguf_reader gr(callback, userdata, max_chunk_read == 0 ? SIZE_MAX : max_chunk_read, 0, max_expected_size);
|
||||
return gguf_init_from_reader(gr, params);
|
||||
}
|
||||
|
||||
struct gguf_file_reader {
|
||||
FILE * file;
|
||||
uint64_t offset;
|
||||
};
|
||||
|
||||
static size_t gguf_file_reader_callback(void * userdata, void * output, uint64_t offset, size_t len) {
|
||||
GGML_ASSERT(len > 0);
|
||||
|
||||
gguf_file_reader & reader = *static_cast<gguf_file_reader *>(userdata);
|
||||
|
||||
if (reader.offset != offset) {
|
||||
if (offset > INT64_MAX || gguf_fseek(reader.file, static_cast<int64_t>(offset), SEEK_SET) != 0) {
|
||||
return 0;
|
||||
}
|
||||
|
||||
reader.offset = offset;
|
||||
}
|
||||
|
||||
const size_t nread = fread(static_cast<uint8_t *>(output), 1, len, reader.file);
|
||||
reader.offset += nread;
|
||||
return nread;
|
||||
}
|
||||
|
||||
struct gguf_context * gguf_init_from_file_ptr(FILE * file, struct gguf_init_params params) {
|
||||
if (!file) {
|
||||
return nullptr;
|
||||
}
|
||||
|
||||
const int64_t cur = gguf_ftell(file);
|
||||
if (cur < 0) {
|
||||
return nullptr;
|
||||
}
|
||||
|
||||
gguf_file_reader reader = {
|
||||
/*.file = */ file,
|
||||
/*.offset = */ static_cast<uint64_t>(cur),
|
||||
};
|
||||
const struct gguf_reader gr(gguf_file_reader_callback, &reader, SIZE_MAX, reader.offset, gguf_reader::file_remain(file));
|
||||
return gguf_init_from_reader(gr, params);
|
||||
}
|
||||
|
||||
struct gguf_buffer_reader {
|
||||
const uint8_t * data;
|
||||
size_t size;
|
||||
};
|
||||
|
||||
static size_t gguf_buffer_reader_callback(void * userdata, void * output, uint64_t offset, size_t len) {
|
||||
GGML_ASSERT(len > 0);
|
||||
|
||||
const gguf_buffer_reader & reader = *static_cast<gguf_buffer_reader *>(userdata);
|
||||
|
||||
if (offset > reader.size || len > reader.size - offset) {
|
||||
return 0;
|
||||
}
|
||||
|
||||
const size_t data_offset = static_cast<size_t>(offset);
|
||||
const size_t nread = std::min(len, reader.size - data_offset);
|
||||
memcpy(static_cast<uint8_t *>(output), reader.data + data_offset, nread);
|
||||
return nread;
|
||||
}
|
||||
|
||||
struct gguf_context * gguf_init_from_buffer(const void * data, size_t size, struct gguf_init_params params) {
|
||||
if (data == nullptr || size == 0) {
|
||||
return nullptr;
|
||||
}
|
||||
|
||||
gguf_buffer_reader reader = {
|
||||
/*.data = */ static_cast<const uint8_t *>(data),
|
||||
/*.size = */ size,
|
||||
};
|
||||
const struct gguf_reader gr(gguf_buffer_reader_callback, &reader, SIZE_MAX, 0, size);
|
||||
return gguf_init_from_reader(gr, params);
|
||||
}
|
||||
|
||||
struct gguf_context * gguf_init_from_file(const char * fname, struct gguf_init_params params) {
|
||||
FILE * file = ggml_fopen(fname, "rb");
|
||||
|
||||
|
||||
@@ -1 +1 @@
|
||||
0ce7ad348a3151e1da9f65d962044546bcaad421
|
||||
e705c5fed490514458bdd2eaddc43bd098fcce9b
|
||||
|
||||
@@ -767,8 +767,9 @@ static const std::map<llm_tensor, llm_tensor_info> LLM_TENSOR_INFOS = {
|
||||
{LLM_TENSOR_NEXTN_SHARED_HEAD_HEAD, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL_MAT}},
|
||||
{LLM_TENSOR_NEXTN_SHARED_HEAD_NORM, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL}},
|
||||
// Nemotron 3 Super
|
||||
{LLM_TENSOR_FFN_LATENT_DOWN, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL}},
|
||||
{LLM_TENSOR_FFN_LATENT_UP, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL}},
|
||||
// latent projections feed ggml_mul_mat, the buft probe must use MUL_MAT to keep them on GPU
|
||||
{LLM_TENSOR_FFN_LATENT_DOWN, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL_MAT}},
|
||||
{LLM_TENSOR_FFN_LATENT_UP, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL_MAT}},
|
||||
};
|
||||
|
||||
LLM_KV::LLM_KV(llm_arch arch, const char * suffix) : arch(arch), suffix(suffix) {}
|
||||
|
||||
@@ -8308,6 +8308,7 @@ static std::vector<std::unique_ptr<test_case>> make_test_cases_eval() {
|
||||
test_cases.emplace_back(new test_mul_mat_hadamard(GGML_TYPE_F32, GGML_TYPE_F32, 64, 1, 64));
|
||||
test_cases.emplace_back(new test_mul_mat_hadamard(GGML_TYPE_F32, GGML_TYPE_F32, 256, 1, 256));
|
||||
test_cases.emplace_back(new test_mul_mat_hadamard(GGML_TYPE_F32, GGML_TYPE_F32, 128, 32, 128));
|
||||
test_cases.emplace_back(new test_mul_mat_hadamard(GGML_TYPE_F32, GGML_TYPE_F32, 128, 4, 128, {2, 3}));
|
||||
|
||||
#if 0
|
||||
// > 4GB A matrix. Too slow to be enabled by default.
|
||||
|
||||
@@ -162,6 +162,42 @@ static void helper_write(FILE * file, const void * data, const size_t nbytes) {
|
||||
GGML_ASSERT(fwrite(data, 1, nbytes, file) == nbytes);
|
||||
}
|
||||
|
||||
static std::vector<uint8_t> read_file_to_buffer(FILE * file) {
|
||||
GGML_ASSERT(file != nullptr);
|
||||
GGML_ASSERT(fseek(file, 0, SEEK_END) == 0);
|
||||
|
||||
const long size = ftell(file);
|
||||
GGML_ASSERT(size >= 0);
|
||||
|
||||
rewind(file);
|
||||
|
||||
std::vector<uint8_t> data(static_cast<size_t>(size));
|
||||
GGML_ASSERT(fread(data.data(), 1, data.size(), file) == data.size());
|
||||
|
||||
rewind(file);
|
||||
return data;
|
||||
}
|
||||
|
||||
struct callback_reader_data {
|
||||
const uint8_t * data;
|
||||
size_t size;
|
||||
};
|
||||
|
||||
static size_t read_buffer_callback(void * userdata, void * output, uint64_t offset, size_t len) {
|
||||
GGML_ASSERT(len > 0);
|
||||
|
||||
const callback_reader_data & reader = *static_cast<callback_reader_data *>(userdata);
|
||||
|
||||
if (offset > reader.size || len > reader.size - offset) {
|
||||
return 0;
|
||||
}
|
||||
|
||||
const size_t data_offset = static_cast<size_t>(offset);
|
||||
const size_t nread = std::min(len, reader.size - data_offset);
|
||||
memcpy(static_cast<uint8_t *>(output), reader.data + data_offset, nread);
|
||||
return nread;
|
||||
}
|
||||
|
||||
static FILE * get_handcrafted_file(const unsigned int seed, const enum handcrafted_file_type hft, const int extra_bytes = 0) {
|
||||
FILE * file = tmpfile();
|
||||
|
||||
@@ -1095,10 +1131,29 @@ static bool same_tensor_data(const struct ggml_context * orig, const struct ggml
|
||||
return ok;
|
||||
}
|
||||
|
||||
static std::pair<int, int> test_roundtrip(ggml_backend_dev_t dev, const unsigned int seed, const bool only_meta) {
|
||||
enum roundtrip_read_mode {
|
||||
ROUNDTRIP_READ_MODE_FILE,
|
||||
ROUNDTRIP_READ_MODE_BUFFER,
|
||||
ROUNDTRIP_READ_MODE_CALLBACK,
|
||||
};
|
||||
|
||||
static const char * roundtrip_read_mode_name(const roundtrip_read_mode mode) {
|
||||
switch (mode) {
|
||||
case ROUNDTRIP_READ_MODE_FILE: return "file";
|
||||
case ROUNDTRIP_READ_MODE_BUFFER: return "buffer";
|
||||
case ROUNDTRIP_READ_MODE_CALLBACK: return "callback";
|
||||
}
|
||||
|
||||
GGML_ABORT("fatal error");
|
||||
}
|
||||
|
||||
static std::pair<int, int> test_roundtrip(
|
||||
ggml_backend_dev_t dev, const unsigned int seed, const bool only_meta,
|
||||
const roundtrip_read_mode read_mode) {
|
||||
ggml_backend_t backend = ggml_backend_dev_init(dev, nullptr);
|
||||
printf("%s: device=%s, backend=%s, only_meta=%s\n",
|
||||
__func__, ggml_backend_dev_description(dev), ggml_backend_name(backend), only_meta ? "yes" : "no");
|
||||
printf("%s: device=%s, backend=%s, only_meta=%s, read_mode=%s\n",
|
||||
__func__, ggml_backend_dev_description(dev), ggml_backend_name(backend),
|
||||
only_meta ? "yes" : "no", roundtrip_read_mode_name(read_mode));
|
||||
|
||||
int npass = 0;
|
||||
int ntest = 0;
|
||||
@@ -1133,7 +1188,22 @@ static std::pair<int, int> test_roundtrip(ggml_backend_dev_t dev, const unsigned
|
||||
/*no_alloc =*/ false,
|
||||
/*ctx =*/ only_meta ? nullptr : &ctx_1,
|
||||
};
|
||||
struct gguf_context * gguf_ctx_1 = gguf_init_from_file_ptr(file, gguf_params);
|
||||
struct gguf_context * gguf_ctx_1 = nullptr;
|
||||
const std::vector<uint8_t> data = read_mode == ROUNDTRIP_READ_MODE_FILE
|
||||
? std::vector<uint8_t>()
|
||||
: read_file_to_buffer(file);
|
||||
|
||||
if (read_mode == ROUNDTRIP_READ_MODE_BUFFER) {
|
||||
gguf_ctx_1 = gguf_init_from_buffer(data.data(), data.size(), gguf_params);
|
||||
} else if (read_mode == ROUNDTRIP_READ_MODE_CALLBACK) {
|
||||
callback_reader_data reader = {
|
||||
/*.data = */ data.data(),
|
||||
/*.size = */ data.size(),
|
||||
};
|
||||
gguf_ctx_1 = gguf_init_from_callback(read_buffer_callback, &reader, 4096, 4ull << 30 /* 4GB */, gguf_params);
|
||||
} else {
|
||||
gguf_ctx_1 = gguf_init_from_file_ptr(file, gguf_params);
|
||||
}
|
||||
|
||||
printf("%s: same_version: ", __func__);
|
||||
if (gguf_get_version(gguf_ctx_0) == gguf_get_version(gguf_ctx_1)) {
|
||||
@@ -1343,7 +1413,17 @@ int main(int argc, char ** argv) {
|
||||
ggml_backend_dev_t dev = ggml_backend_dev_get(i);
|
||||
|
||||
for (bool only_meta : {true, false}) {
|
||||
std::pair<int, int> result = test_roundtrip(dev, seed, only_meta);
|
||||
std::pair<int, int> result = test_roundtrip(dev, seed, only_meta, ROUNDTRIP_READ_MODE_FILE);
|
||||
npass += result.first;
|
||||
ntest += result.second;
|
||||
}
|
||||
{
|
||||
std::pair<int, int> result = test_roundtrip(dev, seed, /*only_meta=*/false, ROUNDTRIP_READ_MODE_BUFFER);
|
||||
npass += result.first;
|
||||
ntest += result.second;
|
||||
}
|
||||
{
|
||||
std::pair<int, int> result = test_roundtrip(dev, seed, /*only_meta=*/false, ROUNDTRIP_READ_MODE_CALLBACK);
|
||||
npass += result.first;
|
||||
ntest += result.second;
|
||||
}
|
||||
|
||||
@@ -16,3 +16,12 @@ export enum AgenticSectionType {
|
||||
REASONING = 'reasoning',
|
||||
REASONING_PENDING = 'reasoning_pending'
|
||||
}
|
||||
|
||||
/**
|
||||
* How a Continue click on an assistant message resumes generation.
|
||||
*/
|
||||
export enum ContinueIntentKind {
|
||||
APPEND_TEXT = 'append_text',
|
||||
RERUN_TURN = 'rerun_turn',
|
||||
NEXT_TURN = 'next_turn'
|
||||
}
|
||||
|
||||
@@ -6,7 +6,7 @@ export {
|
||||
AttachmentItemVisibleWhen
|
||||
} from './attachment.enums';
|
||||
|
||||
export { AgenticSectionType, ToolCallType } from './agentic.enums';
|
||||
export { AgenticSectionType, ContinueIntentKind, ToolCallType } from './agentic.enums';
|
||||
|
||||
export {
|
||||
ChatMessageStatsView,
|
||||
|
||||
@@ -33,6 +33,7 @@ import {
|
||||
isAbortError,
|
||||
generateConversationTitle
|
||||
} from '$lib/utils';
|
||||
import { classifyContinueIntent } from '$lib/utils/agentic';
|
||||
import {
|
||||
MAX_INACTIVE_CONVERSATION_STATES,
|
||||
INACTIVE_CONVERSATION_STATE_MAX_AGE_MS,
|
||||
@@ -51,7 +52,7 @@ import type {
|
||||
DatabaseMessage,
|
||||
DatabaseMessageExtra
|
||||
} from '$lib/types';
|
||||
import { ErrorDialogType, MessageRole, MessageType } from '$lib/enums';
|
||||
import { ContinueIntentKind, ErrorDialogType, MessageRole, MessageType } from '$lib/enums';
|
||||
|
||||
interface ConversationStateEntry {
|
||||
lastAccessed: number;
|
||||
@@ -1259,6 +1260,57 @@ class ChatStore {
|
||||
}
|
||||
}
|
||||
|
||||
/**
|
||||
* Open a fresh assistant turn anchored at the last tool result of a resolved
|
||||
* agentic round and let streamChatCompletion route through runAgenticFlow.
|
||||
* Used by continueAssistantMessage when classifyContinueIntent returns
|
||||
* next_turn, meaning the target assistant already has its tool_calls paired
|
||||
* with trailing tool results and the next thing to generate is a brand new
|
||||
* turn rather than a token level continuation.
|
||||
*/
|
||||
private async continueAsNextAgenticTurn(anchorIndex: number): Promise<void> {
|
||||
const activeConv = conversationsStore.activeConversation;
|
||||
if (!activeConv) return;
|
||||
const anchor = conversationsStore.activeMessages[anchorIndex];
|
||||
if (!anchor) return;
|
||||
this.cancelPreEncode();
|
||||
this.setChatLoading(activeConv.id, true);
|
||||
this.clearChatStreaming(activeConv.id);
|
||||
try {
|
||||
const allMessages = await conversationsStore.getConversationMessages(activeConv.id);
|
||||
const anchorMessage = findMessageById(allMessages, anchor.id);
|
||||
if (!anchorMessage) {
|
||||
this.setChatLoading(activeConv.id, false);
|
||||
return;
|
||||
}
|
||||
const newAssistantMessage = await DatabaseService.createMessageBranch(
|
||||
{
|
||||
convId: activeConv.id,
|
||||
type: MessageType.TEXT,
|
||||
timestamp: Date.now(),
|
||||
role: MessageRole.ASSISTANT,
|
||||
content: '',
|
||||
toolCalls: '',
|
||||
children: [],
|
||||
model: null
|
||||
},
|
||||
anchorMessage.id
|
||||
);
|
||||
await conversationsStore.updateCurrentNode(newAssistantMessage.id);
|
||||
conversationsStore.updateConversationTimestamp();
|
||||
await conversationsStore.refreshActiveMessages();
|
||||
const conversationPath = filterByLeafNodeId(
|
||||
allMessages,
|
||||
anchorMessage.id,
|
||||
false
|
||||
) as DatabaseMessage[];
|
||||
await this.streamChatCompletion(conversationPath, newAssistantMessage);
|
||||
} catch (error) {
|
||||
if (!isAbortError(error)) console.error('Failed to continue agentic turn:', error);
|
||||
this.setChatLoading(activeConv.id, false);
|
||||
}
|
||||
}
|
||||
|
||||
async continueAssistantMessage(messageId: string): Promise<void> {
|
||||
const activeConv = conversationsStore.activeConversation;
|
||||
if (!activeConv || this.isChatLoadingInternal(activeConv.id)) return;
|
||||
@@ -1268,6 +1320,18 @@ class ChatStore {
|
||||
|
||||
const { message: msg, index: idx } = result;
|
||||
|
||||
// Decide which resume path applies. tool_calls without tool results can
|
||||
// not be resumed mid sequence by continue_final_message, branch instead.
|
||||
// tool_calls already paired with tool results need a fresh next turn,
|
||||
// not a token level continuation of the target assistant.
|
||||
const intent = classifyContinueIntent(conversationsStore.activeMessages, idx);
|
||||
if (intent.kind === ContinueIntentKind.RERUN_TURN) {
|
||||
return this.regenerateMessageWithBranching(messageId);
|
||||
}
|
||||
if (intent.kind === ContinueIntentKind.NEXT_TURN) {
|
||||
return this.continueAsNextAgenticTurn(intent.truncateAfter);
|
||||
}
|
||||
|
||||
try {
|
||||
this.showErrorDialog(null);
|
||||
this.setChatLoading(activeConv.id, true);
|
||||
@@ -1283,15 +1347,11 @@ class ChatStore {
|
||||
|
||||
const originalContent = dbMessage.content;
|
||||
const originalReasoning = dbMessage.reasoningContent || '';
|
||||
const conversationContext = conversationsStore.activeMessages.slice(0, idx);
|
||||
const contextWithContinue = [
|
||||
...conversationContext,
|
||||
{
|
||||
role: MessageRole.ASSISTANT as const,
|
||||
content: originalContent,
|
||||
reasoning_content: originalReasoning || undefined
|
||||
}
|
||||
];
|
||||
// Hand the persisted DatabaseMessage straight to sendMessage so its
|
||||
// internal converter preserves tool_calls and extras when present.
|
||||
// Reconstructing a bare {role, content} here would drop those fields
|
||||
// and break continue_final_message for messages with tool calls.
|
||||
const contextWithContinue = conversationsStore.activeMessages.slice(0, idx + 1);
|
||||
|
||||
let appendedContent = '';
|
||||
let appendedReasoning = '';
|
||||
|
||||
@@ -1,4 +1,4 @@
|
||||
import { AgenticSectionType, MessageRole } from '$lib/enums';
|
||||
import { AgenticSectionType, ContinueIntentKind, MessageRole } from '$lib/enums';
|
||||
import { ATTACHMENT_SAVED_REGEX, NEWLINE_SEPARATOR } from '$lib/constants';
|
||||
import type { ApiChatCompletionToolCall } from '$lib/types/api';
|
||||
import type {
|
||||
@@ -225,3 +225,62 @@ export function hasAgenticContent(
|
||||
|
||||
return toolMessages.length > 0;
|
||||
}
|
||||
|
||||
/**
|
||||
* Classification of how a Continue click on an assistant message should resume
|
||||
* generation. The caller dispatches the resume path based on this value.
|
||||
*
|
||||
* append_text -> the target is a plain text turn, resume with
|
||||
* continue_final_message and rehydrate the persisted
|
||||
* tool_calls and attachments through the regular DB to API
|
||||
* message converter.
|
||||
* rerun_turn -> the target carries tool_calls that were never resolved by
|
||||
* tool result messages. The agentic stream was cut mid turn,
|
||||
* so we drop the target and rerun the loop from the previous
|
||||
* history. truncateAfter is the last kept index, inclusive.
|
||||
* next_turn -> the target's tool_calls were already resolved by trailing
|
||||
* tool results. Hand the history up to and including the
|
||||
* last consecutive tool result back to the agentic loop so it
|
||||
* starts the next turn naturally. truncateAfter points at
|
||||
* that last tool result.
|
||||
*/
|
||||
export type ContinueIntent =
|
||||
| { kind: ContinueIntentKind.APPEND_TEXT }
|
||||
| { kind: ContinueIntentKind.RERUN_TURN; truncateAfter: number }
|
||||
| { kind: ContinueIntentKind.NEXT_TURN; truncateAfter: number };
|
||||
|
||||
/**
|
||||
* Decide how a Continue click on messages[idx] should resume generation.
|
||||
* Pure function over the persisted history snapshot.
|
||||
*/
|
||||
export function classifyContinueIntent(messages: DatabaseMessage[], idx: number): ContinueIntent {
|
||||
const target = messages[idx];
|
||||
|
||||
// Defensive default: callers already filter by role, stay deterministic.
|
||||
if (!target || target.role !== MessageRole.ASSISTANT) {
|
||||
return { kind: ContinueIntentKind.APPEND_TEXT };
|
||||
}
|
||||
|
||||
const hasToolCalls = parseToolCalls(target.toolCalls).length > 0;
|
||||
if (!hasToolCalls) {
|
||||
return { kind: ContinueIntentKind.APPEND_TEXT };
|
||||
}
|
||||
|
||||
// Walk consecutive trailing tool results. The agentic loop only emits tool
|
||||
// messages directly after the assistant turn that owns them, so the first
|
||||
// non tool message marks the boundary.
|
||||
let lastTrailingTool = idx;
|
||||
for (let i = idx + 1; i < messages.length; i++) {
|
||||
if (messages[i].role === MessageRole.TOOL) {
|
||||
lastTrailingTool = i;
|
||||
} else {
|
||||
break;
|
||||
}
|
||||
}
|
||||
|
||||
if (lastTrailingTool > idx) {
|
||||
return { kind: ContinueIntentKind.NEXT_TURN, truncateAfter: lastTrailingTool };
|
||||
}
|
||||
|
||||
return { kind: ContinueIntentKind.RERUN_TURN, truncateAfter: idx - 1 };
|
||||
}
|
||||
|
||||
166
tools/ui/tests/unit/continue-intent.test.ts
Normal file
166
tools/ui/tests/unit/continue-intent.test.ts
Normal file
@@ -0,0 +1,166 @@
|
||||
import { describe, it, expect } from 'vitest';
|
||||
import { classifyContinueIntent } from '$lib/utils/agentic';
|
||||
import { ContinueIntentKind, MessageRole, MessageType } from '$lib/enums';
|
||||
import type { DatabaseMessage } from '$lib/types/database';
|
||||
|
||||
/**
|
||||
* Tests for the Continue button intent classifier.
|
||||
*
|
||||
* The classifier walks the persisted message history to decide which of three
|
||||
* resume paths a Continue click should take:
|
||||
*
|
||||
* A. append_text -> plain text assistant turn, resume with
|
||||
* continue_final_message.
|
||||
* B. rerun_turn -> assistant turn with tool_calls but no tool results yet,
|
||||
* the stream was cut mid turn and the tool_calls are
|
||||
* unrecoverable as a token level continuation. Drop the
|
||||
* target and rerun from the previous history.
|
||||
* C. next_turn -> assistant turn with tool_calls that were already
|
||||
* resolved by trailing tool results. Hand the history
|
||||
* back to the agentic loop so it starts the next turn.
|
||||
*/
|
||||
|
||||
let nextId = 0;
|
||||
function makeMsg(role: MessageRole, opts: Partial<DatabaseMessage> = {}): DatabaseMessage {
|
||||
nextId++;
|
||||
return {
|
||||
id: `msg-${nextId}`,
|
||||
convId: 'conv-1',
|
||||
type: MessageType.TEXT,
|
||||
timestamp: nextId,
|
||||
role,
|
||||
content: '',
|
||||
parent: null,
|
||||
children: [],
|
||||
...opts
|
||||
};
|
||||
}
|
||||
|
||||
function toolCall(id: string, name: string, args: string = '{}'): string {
|
||||
return JSON.stringify([{ id, type: 'function', function: { name, arguments: args } }]);
|
||||
}
|
||||
|
||||
describe('classifyContinueIntent', () => {
|
||||
it('returns append_text for a plain text assistant turn at the tail', () => {
|
||||
const messages = [
|
||||
makeMsg(MessageRole.USER, { content: 'hello' }),
|
||||
makeMsg(MessageRole.ASSISTANT, { content: 'hi there' })
|
||||
];
|
||||
|
||||
const intent = classifyContinueIntent(messages, 1);
|
||||
|
||||
expect(intent).toEqual({ kind: ContinueIntentKind.APPEND_TEXT });
|
||||
});
|
||||
|
||||
it('returns append_text for a plain text assistant turn in the middle', () => {
|
||||
const messages = [
|
||||
makeMsg(MessageRole.USER, { content: 'q1' }),
|
||||
makeMsg(MessageRole.ASSISTANT, { content: 'a1' }),
|
||||
makeMsg(MessageRole.USER, { content: 'q2' }),
|
||||
makeMsg(MessageRole.ASSISTANT, { content: 'a2' })
|
||||
];
|
||||
|
||||
expect(classifyContinueIntent(messages, 1)).toEqual({ kind: ContinueIntentKind.APPEND_TEXT });
|
||||
});
|
||||
|
||||
it('returns rerun_turn when the assistant has tool_calls without results', () => {
|
||||
const messages = [
|
||||
makeMsg(MessageRole.USER, { content: 'list files' }),
|
||||
makeMsg(MessageRole.ASSISTANT, {
|
||||
content: '',
|
||||
toolCalls: toolCall('call_1', 'bash_tool', '{"command":"ls"}')
|
||||
})
|
||||
];
|
||||
|
||||
const intent = classifyContinueIntent(messages, 1);
|
||||
|
||||
expect(intent).toEqual({ kind: ContinueIntentKind.RERUN_TURN, truncateAfter: 0 });
|
||||
});
|
||||
|
||||
it('returns next_turn when trailing tool results resolve the tool_calls', () => {
|
||||
const messages = [
|
||||
makeMsg(MessageRole.USER, { content: 'list files' }),
|
||||
makeMsg(MessageRole.ASSISTANT, {
|
||||
content: '',
|
||||
toolCalls: toolCall('call_1', 'bash_tool')
|
||||
}),
|
||||
makeMsg(MessageRole.TOOL, { content: 'file1\nfile2', toolCallId: 'call_1' })
|
||||
];
|
||||
|
||||
const intent = classifyContinueIntent(messages, 1);
|
||||
|
||||
expect(intent).toEqual({ kind: ContinueIntentKind.NEXT_TURN, truncateAfter: 2 });
|
||||
});
|
||||
|
||||
it('next_turn keeps all consecutive trailing tool results, not just one', () => {
|
||||
const messages = [
|
||||
makeMsg(MessageRole.USER, { content: 'do many things' }),
|
||||
makeMsg(MessageRole.ASSISTANT, {
|
||||
content: '',
|
||||
toolCalls: JSON.stringify([
|
||||
{ id: 'call_1', type: 'function', function: { name: 'a', arguments: '{}' } },
|
||||
{ id: 'call_2', type: 'function', function: { name: 'b', arguments: '{}' } }
|
||||
])
|
||||
}),
|
||||
makeMsg(MessageRole.TOOL, { content: 'r1', toolCallId: 'call_1' }),
|
||||
makeMsg(MessageRole.TOOL, { content: 'r2', toolCallId: 'call_2' })
|
||||
];
|
||||
|
||||
const intent = classifyContinueIntent(messages, 1);
|
||||
|
||||
expect(intent).toEqual({ kind: ContinueIntentKind.NEXT_TURN, truncateAfter: 3 });
|
||||
});
|
||||
|
||||
it('next_turn stops at the first non tool message after the target', () => {
|
||||
const messages = [
|
||||
makeMsg(MessageRole.USER, { content: 'go' }),
|
||||
makeMsg(MessageRole.ASSISTANT, {
|
||||
content: '',
|
||||
toolCalls: toolCall('call_1', 'a')
|
||||
}),
|
||||
makeMsg(MessageRole.TOOL, { content: 'r1', toolCallId: 'call_1' }),
|
||||
makeMsg(MessageRole.USER, { content: 'wait' }),
|
||||
makeMsg(MessageRole.TOOL, { content: 'late', toolCallId: 'call_1' })
|
||||
];
|
||||
|
||||
const intent = classifyContinueIntent(messages, 1);
|
||||
|
||||
// truncateAfter must point at the contiguous tool block, not jump over
|
||||
// the user message to grab the dangling late tool.
|
||||
expect(intent).toEqual({ kind: ContinueIntentKind.NEXT_TURN, truncateAfter: 2 });
|
||||
});
|
||||
|
||||
it('returns append_text when toolCalls is set but parses to empty array', () => {
|
||||
const messages = [
|
||||
makeMsg(MessageRole.USER, { content: 'q' }),
|
||||
makeMsg(MessageRole.ASSISTANT, { content: 'a', toolCalls: '[]' })
|
||||
];
|
||||
|
||||
expect(classifyContinueIntent(messages, 1)).toEqual({ kind: ContinueIntentKind.APPEND_TEXT });
|
||||
});
|
||||
|
||||
it('returns append_text when toolCalls is malformed JSON', () => {
|
||||
const messages = [
|
||||
makeMsg(MessageRole.USER, { content: 'q' }),
|
||||
makeMsg(MessageRole.ASSISTANT, { content: 'a', toolCalls: '{not json' })
|
||||
];
|
||||
|
||||
expect(classifyContinueIntent(messages, 1)).toEqual({ kind: ContinueIntentKind.APPEND_TEXT });
|
||||
});
|
||||
|
||||
it('returns append_text defensively when idx points at a non assistant message', () => {
|
||||
const messages = [
|
||||
makeMsg(MessageRole.USER, { content: 'q' }),
|
||||
makeMsg(MessageRole.ASSISTANT, { content: 'a' })
|
||||
];
|
||||
|
||||
expect(classifyContinueIntent(messages, 0)).toEqual({ kind: ContinueIntentKind.APPEND_TEXT });
|
||||
});
|
||||
|
||||
it('returns append_text defensively when idx is out of bounds', () => {
|
||||
const messages = [makeMsg(MessageRole.ASSISTANT, { content: 'a' })];
|
||||
|
||||
expect(classifyContinueIntent(messages, 5)).toEqual({ kind: ContinueIntentKind.APPEND_TEXT });
|
||||
expect(classifyContinueIntent([], 0)).toEqual({ kind: ContinueIntentKind.APPEND_TEXT });
|
||||
});
|
||||
});
|
||||
Reference in New Issue
Block a user